This should work for now. I'll add a proper API for this to Thinc eventually, but in the meantime this solution gives decent usability and doesn't require you to run a different version of anything:

'''Get attention weights out of the ParametricAttention layer.'''
from contextlib import contextmanager

import spacy
from spacy.gold import GoldParse
from thinc.api import layerize

def find_attn_layer(model):
    '''Breadth-first search through the model's layers for the parametric
    attention layer, returning (parent, index), or (None, -1) if not found.'''
    queue = [model]
    seen = set()
    for layer in queue:
        names = [child.name for child in layer._layers]
        if 'para-attn' in names:
            return layer, names.index('para-attn')
        if id(layer) not in seen:
            queue.extend(layer._layers)
            seen.add(id(layer))
    return None, -1

def create_attn_proxy(attn):
    '''Return a proxy to the attention layer which will fetch the attention
    weights on each call, appending them to the list 'output'.
    '''
    output = []

    def get_weights(Xs_lengths, drop=0.):
        Xs, lengths = Xs_lengths
        output.append(attn._get_attention(attn.Q, Xs, lengths)[0])
        return attn.begin_update(Xs_lengths, drop=drop)

    return output, layerize(get_weights)

@contextmanager
def get_attention_weights(textcat):
    '''Wrap the attention layer of the textcat with a function that
    intercepts the attention weights, replacing the attention component
    in the pipeline for the duration of the context manager.
    On exit, we put the original layer back.
    '''
    parent, i = find_attn_layer(textcat.model)
    if parent is not None:
        original = parent._layers[i]
        output_vars, wrapped = create_attn_proxy(original)
        parent._layers[i] = wrapped
        try:
            yield output_vars
        finally:
            parent._layers[i] = original
    else:
        yield None

def main():
    nlp = spacy.blank('en')
    textcat = nlp.create_pipe('textcat')
    textcat.add_label('SPAM')
    nlp.add_pipe(textcat)
    opt = nlp.begin_training()
    docs = [nlp.make_doc('buy viagra')]
    golds = [GoldParse(docs[0], cats={'SPAM': 1})]
    # All calls to the attention model made during this block will append
    # the attention weights to the list attn_weights. The weights for a
    # batch of documents come back as a single concatenated array -- so if
    # you pass in a batch of lengths 4, 5 and 7, you'll get the weights in
    # an array of shape (16,). The value at index 3 will be the attention
    # weight of the last word of the first document, and the value at
    # index 4 will be the attention weight of the first word of the second
    # document. (See the sketch after the script for splitting these back
    # up per document.)
    # The attention weights should be floats between 0 and 1, with 1
    # indicating maximum relevance.
    # The attention layer is parametric (following Liang et al 2016's
    # textcat paper), which means the query vector is learned jointly with
    # the model. It's not too hard to substitute a different attention
    # layer instead, e.g. one which computes attention from an average of
    # the word vectors. See thinc/neural/_classes/attention.py.
    with get_attention_weights(textcat) as attn_weights:
        textcat.update(docs, golds, sgd=opt)
    print(attn_weights)

if __name__ == '__main__':
    main()
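
If you want the weights per document rather than as one concatenated array, you can split the batch array back up by the document lengths. Here's a minimal sketch of that (not part of the snippet above; the helper name split_attention_by_doc is just for illustration, and it assumes numpy and that you only ran a single batch, so attn_weights holds one array):

import numpy

def split_attention_by_doc(attn_weights, docs):
    '''Hypothetical helper: split one batch's concatenated weights array
    back into one array per Doc, using the document lengths.'''
    weights = numpy.asarray(attn_weights[0]).ravel()
    lengths = [len(doc) for doc in docs]
    # numpy.split wants the cut points, i.e. the cumulative lengths
    # without the final total.
    return numpy.split(weights, numpy.cumsum(lengths)[:-1])

# Pair each token with its weight, e.g.:
# for doc, doc_weights in zip(docs, split_attention_by_doc(attn_weights, docs)):
#     for token, weight in zip(doc, doc_weights):
#         print(token.text, float(weight))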