Highlighting the matching words for text classification

This should work for now. I’ll add a proper API for this to Thinc eventually, but in the meantime this solution gives decent usability and doesn’t require you to run a different version of anything:


'''Get attention weights out of the ParametricAttention layer.'''
import spacy
from contextlib import contextmanager
from thinc.api import layerize
from spacy.gold import GoldParse


def find_attn_layer(model):
    '''Breadth-first search through the model's layer tree for the
    ParametricAttention layer (named 'para-attn' in Thinc). Returns the
    parent layer and the child's index, so the child can be swapped out.'''
    queue = [model]
    seen = set()
    for layer in queue:
        names = [child.name for child in layer._layers]
        if 'para-attn' in names:
            return layer, names.index('para-attn')
        if id(layer) not in seen:
            seen.add(id(layer))
            queue.extend(layer._layers)
    return None, -1

def create_attn_proxy(attn):
    '''Return a proxy to the attention layer which will fetch the attention
    weights on each call, appending them to the list 'output'. 
    '''
    output = []
    def get_weights(Xs_lengths, drop=0.):
        Xs, lengths = Xs_lengths
        # Stash the attention weights, then delegate to the wrapped layer's
        # normal forward pass, so training is unaffected.
        output.append(attn._get_attention(attn.Q, Xs, lengths)[0])
        return attn.begin_update(Xs_lengths, drop=drop)
    return output, layerize(get_weights)

@contextmanager
def get_attention_weights(textcat):
    '''Wrap the attention layer of the textcat with a function to
    intercept the attention weights. We replace the attention component
    with our wrapper in the pipeline for the duration of the context manager.
    On exit, we put everything back.
    '''
    parent, i = find_attn_layer(textcat.model)
    if parent is not None:
        original = parent._layers[i]
        output_vars, wrapped = create_attn_proxy(original)
        parent._layers[i] = wrapped
        try:
            yield output_vars
        finally:
            # Put the original attention layer back, as promised above.
            parent._layers[i] = original
    else:
        yield None

def main():
    nlp = spacy.blank('en')
    textcat = nlp.create_pipe('textcat')
    textcat.add_label('SPAM')
    nlp.add_pipe(textcat)
    opt = nlp.begin_training()
    docs = [nlp.make_doc('buy viagra')]
    golds = [GoldParse(docs[0], cats={'SPAM':1})]
    # All calls to the attention model made during this block will append
    # the attention weights to the list attn_weights.
    # The weights for a batch of documents will be a single concatenated
    # array -- so if you pass in a batch of lengths 4, 5 and 7, you'll get
    # the weights in an array of shape (16,). The value at index 3 will be
    # the attention weight of the last word of the first document. The value
    # at index 4 will be the attention weight of the first word of the second
    # document. (The loop at the end of this function shows how to unpack
    # the concatenated array back into per-document weights.)
    # The attention weights should be floats between 0 and 1, with 1 indicating
    # maximum relevance.
    # The attention layer is parametric (following Liang et al 2016's textcat
    # paper), which means the query vector is learned
    # jointly with the model. It's not too hard to substitute a different
    # attention layer instead, e.g. one which does attention by average of
    # the word vectors -- there's a rough sketch of that after this script,
    # and see thinc/neural/_classes/attention.py for the current layer.
    with get_attention_weights(textcat) as attn_weights:
        loss = textcat.update(docs, golds, sgd=opt)
    print(attn_weights)
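    # As noted in the comment above, the weights for the whole batch come
    # back as one concatenated array. A rough way to map them back onto
    # tokens (this assumes the attention layer was found, and that the
    # model ran a single forward pass, so attn_weights holds one array):
    weights = attn_weights[0].ravel()
    start = 0
    for doc in docs:
        for token, weight in zip(doc, weights[start : start + len(doc)]):
            print(token.text, float(weight))
        start += len(doc)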


if __name__ == '__main__':
    main()
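
To actually highlight the matching words, pair each token with its weight and mark up the ones above some cutoff. Here’s a minimal sketch; the highlight helper and the 0.1 cutoff are illustrations of mine, not part of spaCy or Thinc:

def highlight(doc, weights, threshold=0.1):
    '''Wrap tokens whose attention weight clears the threshold in
    asterisks, so the relevant words stand out in plain text.'''
    pieces = []
    for token, weight in zip(doc, weights):
        if weight >= threshold:
            pieces.append('*%s*' % token.text)
        else:
            pieces.append(token.text)
    return ' '.join(pieces)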
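As for swapping in a non-parametric attention layer: here’s roughly the computation the "average of the word vectors" variant would do. This is only a numpy sketch of the math, not a drop-in Thinc layer -- each word is scored by its dot product with the document’s mean vector instead of a learned query, and the scores are softmaxed into weights:

import numpy as np

def average_attention(Xs):
    '''Score each word vector by its dot product with the document's
    mean vector, then softmax the scores into attention weights.'''
    query = Xs.mean(axis=0)
    scores = Xs.dot(query)
    scores -= scores.max()  # shift for numerical stability
    weights = np.exp(scores)
    return weights / weights.sum()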