train spacy for text classification

After reading the docs and working through the tutorial, I figured I'd make a small demo. It turns out my model does not want to train. Here's the code:

import spacy
import random
import json

TRAINING_DATA = [
    ["My little kitty is so special", {"KAT": True}],
    ["Dude, Totally, Yeah, Video Games", {"KAT": False}],
    ["Should I pay $1,000 for the iPhone X?", {"KAT": False}],
    ["The iPhone 8 reviews are here", {"KAT": False}],
    ["Noa is a great cat name.", {"KAT": True}],
    ["We got a new kitten!", {"KAT": True}]
]

nlp = spacy.blank("en")
category = nlp.create_pipe("textcat")
nlp.add_pipe(category)
category.add_label("KAT")

# Start the training
nlp.begin_training()

# Loop for 100 iterations
for itn in range(100):
    # Shuffle the training data
    random.shuffle(TRAINING_DATA)
    losses = {}

    # Batch the examples and iterate over them
    for batch in spacy.util.minibatch(TRAINING_DATA, size=2):
        texts = [text for text, entities in batch]
        annotations = [{"textcat": [entities]} for text, entities in batch]
        nlp.update(texts, annotations, losses=losses)
    if itn % 20 == 0:
        print(losses)

When I run this the output suggests that very little is learned.

{'textcat': 0.0}
{'textcat': 0.0}
{'textcat': 0.0}
{'textcat': 0.0}
{'textcat': 0.0}

This feels wrong. There should be an error or a meaningful tag. The predictions confirm this.

for text, d in TRAINING_DATA:
    print(text, nlp(text).cats)

# Dude, Totally, Yeah, Video Games {'KAT': 0.45303162932395935}
# The iPhone 8 reviews are here {'KAT': 0.45303162932395935}
# Noa is a great cat name. {'KAT': 0.45303162932395935}
# Should I pay $1,000 for the iPhone X? {'KAT': 0.45303162932395935}
# We got a new kitten! {'KAT': 0.45303162932395935}
# My little kitty is so special {'KAT': 0.45303162932395935}

It feels like my code is missing something but I can't figure out what.

Rutharuthann answered 23/5, 2019 at 19:14 Comment(3)
Here they use 2000 examples. Are you sure that machine learning works with 6 examples? All three of your cat examples use different words for cats. I'd start with 10 different examples with only one word for a cat. – Fraktur
Sure, but the textcat category is reporting zero loss; this should not be so. – Rutharuthann
Your training loop and data look correct – and I think I found the problem: try changing {"textcat": [entities]} to {"cats": entities} (also see here for the expected keys if you're passing in a dict of annotations). When you're updating the text classifier, it'll look for a key "cats" – but that wasn't there, only "textcat". So you were basically updating the text classifier with nothing, and ended up with only the randomly initialized weights (resulting from nlp.begin_training). – Jointless
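
The key mismatch described in the last comment can be sketched without spaCy at all. The helper below, to_cats_annotation, is a hypothetical name introduced purely for illustration; it only reshapes the question's (text, labels) pairs into the {"cats": ...} dictionaries that the comment says the text classifier looks for.

```python
# Sketch only: `to_cats_annotation` is a hypothetical helper, not part of spaCy.
def to_cats_annotation(batch):
    """Split (text, labels) pairs into texts and {"cats": labels} annotations."""
    texts = [text for text, _ in batch]
    # The labels go under the "cats" key, not "textcat".
    annotations = [{"cats": dict(labels)} for _, labels in batch]
    return texts, annotations


batch = [
    ("My little kitty is so special", {"KAT": True}),
    ("The iPhone 8 reviews are here", {"KAT": False}),
]
texts, annotations = to_cats_annotation(batch)
print(annotations)
```

The resulting texts and annotations are in the shape that the spaCy 2 call nlp.update(texts, annotations, losses=losses) from the question takes.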

If you upgrade to spaCy 3, the code above will no longer work. The solution is to migrate it with a few changes. I've modified the example from cantdutchthis accordingly.

Summary of changes:

  • Use the config to choose the architecture. The old default was "bag of words"; the new default is "text ensemble", which uses attention. Keep this in mind when tuning the models.
  • Labels now need to be one-hot encoded.
  • The add_pipe interface has changed slightly: it takes the component's name as a string and returns the component.
  • nlp.update now requires Example objects rather than separate lists of texts and annotations.

import spacy
# Add imports for example, as well as textcat config...
from spacy.training import Example
from spacy.pipeline.textcat import single_label_bow_config, single_label_default_config
from thinc.api import Config
import random

# labels should be one-hot encoded
TRAINING_DATA = [
    ["My little kitty is so special", {"KAT0": True}],
    ["Dude, Totally, Yeah, Video Games", {"KAT1": True}],
    ["Should I pay $1,000 for the iPhone X?", {"KAT1": True}],
    ["The iPhone 8 reviews are here", {"KAT1": True}],
    ["Noa is a great cat name.", {"KAT0": True}],
    ["We got a new kitten!", {"KAT0": True}]
]


# bow
# config = Config().from_str(single_label_bow_config)

# textensemble with attention
config = Config().from_str(single_label_default_config)

nlp = spacy.blank("en")
# now uses `add_pipe` instead
category = nlp.add_pipe("textcat", last=True, config=config)
category.add_label("KAT0")
category.add_label("KAT1")


# Start the training (spaCy 3 deprecates nlp.begin_training in favor of nlp.initialize)
nlp.initialize()

# Loop for 100 iterations
for itn in range(100):
    # Shuffle the training data
    random.shuffle(TRAINING_DATA)
    losses = {}

    # Batch the examples and iterate over them
    for batch in spacy.util.minibatch(TRAINING_DATA, size=4):
        texts = [nlp.make_doc(text) for text, entities in batch]
        annotations = [{"cats": entities} for text, entities in batch]

        # uses an example object rather than text/annotation tuple
        examples = [Example.from_dict(doc, annotation) for doc, annotation in zip(
            texts, annotations
        )]
        nlp.update(examples, losses=losses)
    if itn % 20 == 0:
        print(losses)
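
The one-hot point from the summary can be made explicit with a small pure-Python sketch. It assumes the binary {"KAT": bool} annotations from the question and the label names KAT0 and KAT1 used above; the helper name to_one_hot is hypothetical. Unlike the training data above, it writes out both labels for every text, so exactly one of them is 1.0.

```python
# Sketch only: `to_one_hot` is a hypothetical helper, not part of spaCy.
def to_one_hot(labels, key="KAT", positive="KAT0", negative="KAT1"):
    """Turn {"KAT": bool} into a two-label dict with exactly one label set to 1.0."""
    is_positive = bool(labels[key])
    return {
        positive: 1.0 if is_positive else 0.0,
        negative: 0.0 if is_positive else 1.0,
    }


binary_data = [
    ("My little kitty is so special", {"KAT": True}),
    ("The iPhone 8 reviews are here", {"KAT": False}),
]
one_hot_data = [(text, to_one_hot(labels)) for text, labels in binary_data]
print(one_hot_data)
```

Each resulting dictionary could then be passed as the "cats" value when building an Example with Example.from_dict.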
Cover answered 11/6, 2021 at 21:7 Comment(1)
Looks like the config variable has been initialized but hasn't been used anywhere? How does the text cat model pick the config? – Benefaction

Based on the comment from Ines, this is the answer: pass the annotations under the "cats" key instead of "textcat".

import spacy
import random
import json

TRAINING_DATA = [
    ["My little kitty is so special", {"KAT": True}],
    ["Dude, Totally, Yeah, Video Games", {"KAT": False}],
    ["Should I pay $1,000 for the iPhone X?", {"KAT": False}],
    ["The iPhone 8 reviews are here", {"KAT": False}],
    ["Noa is a great cat name.", {"KAT": True}],
    ["We got a new kitten!", {"KAT": True}]
]

nlp = spacy.blank("en")
category = nlp.create_pipe("textcat")
category.add_label("KAT")
nlp.add_pipe(category)

# Start the training
nlp.begin_training()

# Loop for 100 iterations
for itn in range(100):
    # Shuffle the training data
    random.shuffle(TRAINING_DATA)
    losses = {}

    # Batch the examples and iterate over them
    for batch in spacy.util.minibatch(TRAINING_DATA, size=1):
        texts = [nlp(text) for text, entities in batch]
        annotations = [{"cats": entities} for text, entities in batch]
        nlp.update(texts, annotations, losses=losses)
    if itn % 20 == 0:
        print(losses)
Rutharuthann answered 24/5, 2019 at 14:37 Comment(0)
