Hi!
I have a question about the updated textcat.batch-train recipe in Prodigy v1.8. I see that you added an --exclusive parameter, which could be very useful. However, I don't see it applied anywhere in the recipe's code: it looks like exclusive is accepted as an argument but never actually used. Maybe I'm missing something, so I just wanted to double-check with you. (After the recipe code, I've added a short sketch of where I expected it to be wired in.) Here is the code of the recipe:
```python
@recipe(
    "textcat.batch-train",
    dataset=recipe_args["dataset"],
    input_model=recipe_args["spacy_model"],
    output_model=recipe_args["output"],
    init_tok2vec=recipe_args["init_tok2vec"],
    exclusive=recipe_args["exclusive"],
    lang=recipe_args["lang"],
    factor=recipe_args["factor"],
    dropout=recipe_args["dropout"],
    n_iter=recipe_args["n_iter"],
    batch_size=recipe_args["batch_size"],
    eval_id=recipe_args["eval_id"],
    eval_split=recipe_args["eval_split"],
    long_text=("Long text", "flag", "L", bool),
    silent=recipe_args["silent"],
)
def batch_train(
    dataset,
    input_model=None,
    output_model=None,
    init_tok2vec=None,
    lang="en",
    factor=1,
    dropout=0.2,
    n_iter=10,
    exclusive=False,
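    # ^ my annotation: `exclusive` is accepted here, but as far as I can
    # tell it's never referenced anywhere in the function body below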
    batch_size=10,
    eval_id=None,
    eval_split=None,
    long_text=False,
    silent=False,
):
    """
    Batch train a new text classification model from annotations. Prodigy will
    export the best result to the output directory, and include a JSONL file of
    the training and evaluation examples. You can either supply a dataset ID
    containing the evaluation data, or choose to split off a percentage of
    examples for evaluation.
    """
    log("RECIPE: Starting recipe textcat.batch-train", locals())
    fix_random_seed(0)
    DB = connect()
    if dataset not in DB:
        prints("Can't find dataset '{}'".format(dataset), exits=1, error=True)
    print_ = get_print(silent)
    random.seed(0)
    if input_model is not None:
        nlp = spacy.load(input_model)
        print_("\nLoaded model {}".format(input_model))
    else:
        nlp = spacy.blank(lang, pipeline=[])
        print_("\nLoaded blank model")
    examples = DB.get_dataset(dataset)
    # Make sure that examples in datasets created with a choice interface are
    # converted to "regular" text classification tasks with a "label" key
    examples = convert_options_to_cats(examples)
    labels = set()
    for eg in examples:
        for label, value in eg["cats"].items():
            labels.add(label)
    labels = list(sorted(labels))
    model = TextClassifier(
        nlp,
        labels,
        long_text=long_text,
        low_data=len(examples) < 1000,
        init_tok2vec=init_tok2vec,
    )
    log(
        "RECIPE: Initialised TextClassifier with model {}".format(input_model),
        model.nlp.meta,
    )
    other_pipes = [p for p in nlp.pipe_names if p not in ("textcat", "sentencizer")]
    if other_pipes:
        disabled = nlp.disable_pipes(*other_pipes)
        log("RECIPE: Temporarily disabled other pipes: {}".format(other_pipes))
    else:
        disabled = None
    random.shuffle(examples)
    if eval_id:
        evals = DB.get_dataset(eval_id)
        evals = convert_options_to_cats(evals)
        print_("Loaded {} evaluation examples from '{}'".format(len(evals), eval_id))
    else:
        examples, evals, eval_split = split_evals(examples, eval_split)
        print_(
            "Using {}% of examples ({}) for evaluation".format(
                round(eval_split * 100), len(evals)
            )
        )
    random.shuffle(examples)
    examples = examples[: int(len(examples) * factor)]
    print_(printers.trainconf(dropout, n_iter, batch_size, factor, len(examples)))
    if len(evals) > 0:
        print_(printers.tc_update_header())
    best_acc = {"accuracy": 0}
    best_model = None
    if long_text:
        examples = list(split_sentences(nlp, examples, min_length=False))
    for i in range(n_iter):
        loss = 0.0
        random.shuffle(examples)
        for batch in minibatch(tqdm.tqdm(examples, leave=False), size=batch_size):
            batch = list(batch)
            loss += model.update(batch, revise=False, drop=dropout)
        if len(evals) > 0:
            with nlp.use_params(model.optimizer.averages):
                acc = model.evaluate(tqdm.tqdm(evals, leave=False))
                if acc["accuracy"] > best_acc["accuracy"]:
                    best_acc = dict(acc)
                    best_model = nlp.to_bytes()
            print_(printers.tc_update(i, loss, acc))
    if len(evals) > 0:
        print_(printers.tc_result(best_acc))
    if output_model is not None:
        if best_model is not None:
            nlp = nlp.from_bytes(best_model)
        if disabled:
            log("RECIPE: Restoring disabled pipes: {}".format(other_pipes))
            disabled.restore()
        msg = export_model_data(output_model, nlp, examples, evals)
        print_(msg)
    return best_acc["accuracy"]
```
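
For reference, this is roughly what I expected to find somewhere in the body, e.g. where the model is created. To be clear, exclusive_classes is just my guess at the keyword name; I don't know the internal TextClassifier API, so please correct me if the flag is wired through some other way:

```python
# Hypothetical sketch, not the actual Prodigy source: I'd have expected the
# recipe to forward the flag into the model, something along these lines.
model = TextClassifier(
    nlp,
    labels,
    long_text=long_text,
    low_data=len(examples) < 1000,
    init_tok2vec=init_tok2vec,
    exclusive_classes=exclusive,  # assumed keyword; in 1.8 the flag seems unused
)
```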
Thanks,
Kasra