Active learning for a multilabel text classifier

If you pass a trained multilabel classifier to textcat.teach and ask for new annotations on a single label only, the TextClassifier model instantiated for the active learning part does not use the prediction score of the specified label. Instead, it uses the prediction score of the ‘first’ label (in alphabetical order) among all the labels of the ‘multi-label-classifier’ model.

What I mean is that model = TextClassifier(nlp, label.split(','), long_text=long_text) seems to restrict the predictive model to the ‘first’ label in nlp(text).cats.keys(), without considering the given label, when nlp = spacy.load('multi-label-classifier').
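To make the difference concrete, here is a minimal sketch in plain Python. The dict stands in for nlp(text).cats, and the labels, the scores and the two helper functions are made up for illustration; this is not Prodigy's actual code, only my reading of the behaviour.

```python
# Plain dict standing in for nlp(text).cats; labels and scores are made up.
cats = {"SPORTS": 0.12, "POLITICS": 0.91, "BUSINESS": 0.40}


def first_label_score(cats):
    """Score of the alphabetically first label (the behaviour I observe)."""
    first = sorted(cats)[0]
    return first, cats[first]


def requested_label_score(cats, label):
    """Score of the label given to the recipe (the behaviour I expected)."""
    return label, cats[label]


print(first_label_score(cats))                  # ('BUSINESS', 0.4)
print(requested_label_score(cats, "POLITICS"))  # ('POLITICS', 0.91)
```

With the first variant, the examples get sorted by the BUSINESS score even though I asked to annotate POLITICS.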

Is this the expected behaviour? I am using Prodigy v0.5.0 and spaCy 2.0.3.

If someone has the same issue, this is the little hack I used to solve the problem:

    from prodigy.models.textcat import TextClassifier
    import cytoolz

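    class MyTextClassifier(TextClassifier):
        """TextClassifier that scores examples on one label of an existing multilabel model."""
        def __init__(self, nlp, label, batch_size=128):
            """Keep all labels of a trained textcat pipe, but remember the single label to score on."""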
            if nlp.has_pipe("textcat"):
                # get existing labels from the trained model
                labels = nlp.get_pipe("textcat").labels
                # init TextClassifier with all labels
                super().__init__(nlp, labels)
                # keep our unique label as attribute
                self.label = label
                self.batch_size = batch_size
            else:
                super().__init__(nlp, label)

        def __call__(self, data_stream):
            """Yield (score, example) tuples, using the score of self.label when it is set."""
            if hasattr(self, "label"):
                for batch in cytoolz.partition_all(self.batch_size, data_stream):
                    data_list = list(batch)
                    texts = (x["text"] for x in data_list)
                    tuples = zip(texts, data_list)
                    for doc, context in self.nlp.pipe(tuples, as_tuples=True):
                        # Get the score of the single label we are interested in for active learning
                        score = doc.cats[self.label]
                        context['score'] = score
                        context['label'] = self.label
                        yield (score, context)
            else:
                # __call__ is a generator, so the parent's results must be re-yielded
                yield from super().__call__(data_stream)
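
One thing to keep in mind with the fallback branch of __call__: because the method contains a yield, it is a generator function, so the parent's results only reach the caller if they are re-yielded with yield from. A minimal, Prodigy-free sketch (Base, BrokenChild and FixedChild are hypothetical stand-in classes, not part of any library):

```python
class Base:
    def __call__(self, stream):
        for item in stream:
            yield (0.5, item)


class BrokenChild(Base):
    def __call__(self, stream, custom=False):
        if custom:
            for item in stream:
                yield (1.0, item)
        else:
            # Creates the parent's generator and throws it away: nothing is yielded.
            super().__call__(stream)


class FixedChild(Base):
    def __call__(self, stream, custom=False):
        if custom:
            for item in stream:
                yield (1.0, item)
        else:
            # Delegates to the parent's generator and passes its items through.
            yield from super().__call__(stream)


broken = list(BrokenChild()(["a", "b"]))
fixed = list(FixedChild()(["a", "b"]))
print(broken)  # []
print(fixed)   # [(0.5, 'a'), (0.5, 'b')]
```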