Here's my complete, working solution (as of Prodigy 1.10.4):
from typing import List, Optional, Dict, Any, Union, Iterable
from pathlib import Path
import spacy
from prodigy.models.matcher import PatternMatcher
from prodigy.models.textcat import TextClassifier
from prodigy.components.loaders import get_stream
from prodigy.components.preprocess import add_label_options, add_labels_to_stream
from prodigy.components.sorters import prefer_uncertain
from prodigy.core import recipe
from prodigy.util import combine_models, log, msg, get_labels, split_string
# Restore deprecated recipes
from prodigy.deprecated.train import textcat_batch_train as batch_train # noqa: F401
from prodigy.deprecated.train import textcat_train_curve as train_curve # noqa: F401
@recipe(
    "textcat_multi",
    # fmt: off
    dataset=("Dataset to save annotations to", "positional", None, str),
    spacy_model=("Loadable spaCy model or blank:lang (e.g. blank:en)", "positional", None, str),
    source=("Data to annotate (file path or '-' to read from standard input)", "positional", None, str),
    loader=("Loader (guessed from file extension if not set)", "option", "lo", str),
    label=("Comma-separated label(s) to annotate or text file with one label per line", "option", "l", get_labels),
    patterns=("Path to match patterns file", "option", "pt", str),
    long_text=("DEPRECATED: Use long-text mode", "flag", "L", bool),
    init_tok2vec=("Path to pretrained weights for the token-to-vector parts of the model. See 'spacy pretrain'.", "option", "t2v", str),
    exclude=("Comma-separated list of dataset IDs whose annotations to exclude", "option", "e", split_string),
    # fmt: on
)
def textcat_multi(
    dataset: str,
    spacy_model: str,
    source: Union[str, Iterable[dict]],
    label: Optional[List[str]] = None,
    patterns: Optional[str] = None,
    init_tok2vec: Optional[Union[str, Path]] = None,
    loader: Optional[str] = None,
    long_text: bool = False,
    exclude: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """
    Collect the best possible training data for a text classification model
    with the model in the loop, presenting each example as a multiple-choice
    task. Based on your annotations, Prodigy will decide which questions to
    ask next (uncertainty sampling via prefer_uncertain).

    Returns the recipe components dict consumed by the Prodigy server:
    view_id, dataset, stream, exclude, the update callback, and config.
    """
    log("RECIPE: Starting recipe textcat_multi", locals())
    if label is None:
        msg.fail("textcat_multi requires at least one --label", exits=1)
    # "blank:lang" creates an untrained pipeline for that language; anything
    # else is treated as a loadable pipeline name or path.
    if spacy_model.startswith("blank:"):
        nlp = spacy.blank(spacy_model.replace("blank:", ""))
    else:
        nlp = spacy.load(spacy_model)
    log(f"RECIPE: Creating TextClassifier with model {spacy_model}")
    model = TextClassifier(nlp, label, long_text=long_text, init_tok2vec=init_tok2vec)
    # Deduplicate and rehash incoming examples, keyed on their "text" field.
    stream = get_stream(
        source, loader=loader, rehash=True, dedup=True, input_key="text"
    )
    if patterns is None:
        # No patterns: score and update with the text classifier alone.
        predict = model
        update = model.update
    else:
        matcher = PatternMatcher(
            model.nlp,
            prior_correct=5.0,
            prior_incorrect=5.0,
            label_span=False,
            label_task=True,
            filter_labels=label,
            combine_matches=True,
            task_hash_keys=("label",),
        )
        matcher = matcher.from_disk(patterns)
        log("RECIPE: Created PatternMatcher and loaded in patterns", patterns)
        # Combine the textcat model with the PatternMatcher to annotate both
        # match results and predictions, and update both models.
        predict, update = combine_models(model, matcher)

    def stream_pre_annotated(tasks):
        """Rewrite scored textcat tasks as "choice" tasks, pre-selecting the
        model's predicted label when it is confident enough."""
        # Build one option per requested label instead of hard-coding the
        # option IDs, so the recipe works with any --label set.
        options = [{"id": lbl, "text": lbl} for lbl in label]
        for task in tasks:
            choice_task = {"text": task["text"], "options": options}
            # Pre-select the predicted label only above the 0.5 confidence
            # threshold; tasks without a score get no pre-selection.
            if task.get("score", 0.0) >= 0.5:
                choice_task["accept"] = [task["label"]]
            yield choice_task

    # Sort the scored stream so uncertain (score near 0.5) examples come first.
    stream = prefer_uncertain(predict(stream))
    stream = stream_pre_annotated(stream)
    return {
        "view_id": "choice",
        "dataset": dataset,
        "stream": stream,
        "exclude": exclude,
        "update": update,
        "config": {
            "lang": nlp.lang,
            "labels": model.labels,
        },
    }
And this can be run like so:
prodigy textcat_multi your_dataset ./models/your_trained_textcat_model ./data/unlabeled_data.jsonl --label OPTION_1,OPTION_2,OPTION_3 -F ./src/multi_textcat_teach.py
There's surely some redundant code in here, but I don't have enough experience yet with the latest version of Prodigy's textcat
recipe to trim it down. Hope it helps someone.