As written, it does not update the db, either when saving from the interface or after several batches have been processed, and I do not know why. I ended up updating the db inside the custom teach recipe (and also made it work with any patterns file).
import prodigy
from prodigy.recipes.ner import teach
from prodigy.components.db import connect
import spacy
from spacy.matcher import Matcher
import json
def _read_patterns(path, nlp):
    """Yield spaCy ``Matcher`` token patterns from a Prodigy JSONL patterns file.

    Each non-blank line must be a JSON object with a ``"pattern"`` key.
    Token patterns (lists of attribute dicts) are yielded unchanged. String
    ("phrase") patterns are converted into exact-text token patterns with
    ``nlp``'s tokenizer, because ``Matcher`` cannot take raw strings.
    Blank lines (e.g. a trailing newline) are skipped.
    """
    with open(str(path), "r", encoding="utf8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            pattern = json.loads(line)['pattern']
            if isinstance(pattern, str):
                pattern = [{'ORTH': token.text} for token in nlp.make_doc(pattern)]
            yield pattern


def _has_unmatched_span(eg, label, nlp, matcher):
    """Return True if any span in ``eg`` has ``label`` but matches no pattern.

    A span "matches" when the matcher finds at least one match anywhere in
    the span's text. Spans with other labels are never considered.
    """
    return any(
        span['label'] == label and not matcher(nlp(span['text']))
        for span in eg.get('spans', [])
    )


@prodigy.recipe('must_match_pattern.ner.teach',
                patterns=prodigy.recipe_args['patterns'],
                dataset=prodigy.recipe_args['dataset'],
                spacy_model=prodigy.recipe_args['spacy_model'],
                database=("Data source to annotate (passed to ner.teach as its source)",
                          "positional", None, str),
                label=prodigy.recipe_args['label'])
def custom_ner_teach(dataset, spacy_model, database, patterns, label):
    """Custom wrapper for the ner.teach recipe that filters the stream.

    Any suggestion with a ``label`` span that matches none of the patterns
    in ``patterns`` is auto-rejected: it is not shown in the web app.
    Instead it is added to the model update and saved to ``dataset``.

    Args:
        dataset: Name of the Prodigy dataset to save annotations to.
        spacy_model: Model passed to ``ner.teach``.
        database: Despite its name, the data source passed to ``ner.teach``
            as ``source``.
        patterns: Path to a JSONL patterns file.
        label: Entity label whose suggestions must match a pattern.

    Returns:
        The ``ner.teach`` components dict with ``stream`` and ``update``
        replaced.

    Note:
        Auto-rejected examples are only written to the database when
        ``update`` is next called. Rejects collected after the last update
        of a session are not saved.
    """
    components = teach(dataset=dataset, spacy_model=spacy_model,
                       source=database, label=label, patterns=patterns)
    original_stream = components['stream']
    original_update = components['update']

    # A separate pipeline is used only to tokenize/analyse span texts for the matcher.
    nlp = spacy.load('en')
    matcher = Matcher(nlp.vocab)
    for pattern in _read_patterns(patterns, nlp):
        matcher.add(label, None, pattern)

    db = connect()
    # Auto-rejected examples waiting to be flushed on the next update.
    bad_spans = []

    def get_modified_stream():
        """Yield examples from the original stream, diverting non-matching ones."""
        n_rejected = 0
        for eg in original_stream:
            if not _has_unmatched_span(eg, label, nlp, matcher):
                yield eg
                continue
            eg['answer'] = 'reject'  # auto-reject, never sent to the web app
            bad_spans.append(eg)
            n_rejected += 1
            if n_rejected % 10 == 0:
                print('rejected', n_rejected,
                      'examples that did not match the pattern so far')

    def modified_update(batch):
        """Save pending auto-rejects, then update the model with them plus ``batch``."""
        # Copy, then clear in place, so the stream keeps appending to the same list.
        rejected = list(bad_spans)
        bad_spans.clear()
        if rejected:
            # Prodigy only saves answers coming back from the web app, so
            # auto-rejected examples must be stored explicitly.
            db.add_examples(rejected, datasets=[dataset])
            print('added ', len(rejected), ' to db')
        return original_update(batch + rejected)

    components['stream'] = get_modified_stream()
    components['update'] = modified_update
    components['config']['label'] = label  # hack to fix incorrect labeling of label
    return components
Usage looks like this:
prodigy must_match_pattern.ner.teach [db-name] [model path] [data source path] --label [label name] --patterns [path to patternsfile] -F special_filter.py
I am getting a 30x speed-up in annotation rate. You should consider adding something like this in the next version, perhaps as a must_match_pattern flag on teach.
Also, sorry that I couldn't get the code formatting right. Maybe you can fix it?