Apply labels with zero-shot classification

Use zero shot learning for labeling, classification and topic modeling

This article shows how zero-shot classification can be used to perform text classification, labeling and topic modeling. txtai provides a lightweight wrapper around the zero-shot-classification pipeline in Hugging Face Transformers. This method works impressively well out of the box. Kudos to the Hugging Face team for the phenomenal work on zero-shot classification!

For each snippet of text, the examples in this article pick the best-matching label from a list of candidate labels.
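For a given text, txtai's Labels pipeline returns a list of (label index, score) tuples with the highest score first, which is why the examples in this article index the result with `[0][0]`. The sketch below makes that "pick the best label" step explicit. `best_label` is a hypothetical helper, and the scores are made-up placeholders for illustration only; real scores come from the model.

```python
from typing import Sequence, Tuple


def best_label(results: Sequence[Tuple[int, float]], tags: Sequence[str]) -> Tuple[str, float]:
    """Return the (label name, score) pair with the highest score.

    Assumes `results` holds (label index, score) tuples, the shape txtai's
    Labels pipeline returns. Taking the max by score means this still works
    even if the list were not already sorted.
    """
    index, score = max(results, key=lambda pair: pair[1])
    return tags[index], score


# Hypothetical scores for illustration only.
tags = ["Baseball", "Football", "Hockey", "Basketball"]
results = [(3, 0.91), (1, 0.05), (0, 0.03), (2, 0.01)]
print(best_label(results, tags))  # ('Basketball', 0.91)
```

With a real model, the same helper can be called as `best_label(labels(text, tags), tags)`.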

tldrstory has a full-stack implementation of a zero-shot classification system using Streamlit, FastAPI and Hugging Face Transformers. There is also a Medium article describing tldrstory and zero-shot classification.

Install dependencies

Install txtai and all dependencies.

pip install txtai

Create a Labels instance

The Labels instance is the main entrypoint for zero-shot classification. It is a lightweight wrapper around the zero-shot-classification pipeline in Hugging Face Transformers.

Besides the default model, other zero-shot classification models are available on the Hugging Face model hub.

from txtai.pipeline import Labels

# Create labels model
labels = Labels()

Applying labels to text

The example below shows how a zero-shot classifier can be applied to arbitrary text. The default model for the zero-shot classification pipeline is bart-large-mnli.
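bart-large-mnli is a BART model fine-tuned for natural language inference (NLI) on MultiNLI. As we understand the Transformers zero-shot pipeline, each candidate label is inserted into a hypothesis template (by default "This example is {}."), the model scores how strongly the text entails each hypothesis, and for single-label classification the entailment logits are softmaxed across the labels. The sketch below reproduces only that scoring step. `zero_shot_scores` and `toy_entailment_logit` are hypothetical names, and the toy keyword-overlap "logit" is a stand-in for the real NLI model so the sketch runs without downloading anything.

```python
import math
from typing import Callable, List, Sequence, Tuple


def zero_shot_scores(
    text: str,
    labels: Sequence[str],
    entailment_logit: Callable[[str, str], float],
    template: str = "This example is {}.",
) -> List[Tuple[int, float]]:
    """Score labels NLI-style; return (label index, probability), best first."""
    # One hypothesis per candidate label, each scored against the text (the premise).
    logits = [entailment_logit(text, template.format(label)) for label in labels]

    # Softmax across labels so the probabilities sum to 1.
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Sort label indices by probability, highest first.
    return sorted(enumerate(probs), key=lambda pair: pair[1], reverse=True)


# Hypothetical stand-in for an NLI model: counts label keywords in the premise.
def toy_entailment_logit(premise: str, hypothesis: str) -> float:
    keywords = {"Basketball": ["dunk", "court", "3pt"], "Hockey": ["stick", "power play"]}
    for label, words in keywords.items():
        if label in hypothesis:
            return float(sum(word in premise.lower() for word in words))
    return 0.0


tags = ["Baseball", "Football", "Hockey", "Basketball"]
print(zero_shot_scores("Massive dunk!!! they are now up by 15", tags, toy_entailment_logit))
```

Because the softmax runs across the candidate labels, the scores are relative: changing the label list changes every score.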

Look at the results below. It's nothing short of amazing✨ how well it performs. Some of these aren't simple even for a human. For example, "Intercepted" was deliberately included because the word is more common in football than in basketball, yet the model still labels that text Basketball. The amount of knowledge stored in larger Transformer models continues to impress me.

data = ["Dodgers lose again, give up 3 HRs in a loss to the Giants",
        "Giants 5 Cardinals 4 final in extra innings",
        "Dodgers drop Game 2 against the Giants, 5-4",
        "Flyers 4 Lightning 1 final. 45 saves for the Lightning.",
        "Slashing, penalty, 2 minute power play coming up",
        "What a stick save!",
        "Leads the NFL in sacks with 9.5",
        "UCF 38 Temple 13",
        "With the 30 yard completion, down to the 10 yard line",
        "Drains the 3pt shot!!, 0:15 remaining in the game",
        "Intercepted! Drives down the court and shoots for the win",
        "Massive dunk!!! they are now up by 15 with 2 minutes to go"]

# List of labels
tags = ["Baseball", "Football", "Hockey", "Basketball"]

print("%-75s %s" % ("Text", "Label"))
print("-" * 100)

for text in data:
    print("%-75s %s" % (text, tags[labels(text, tags)[0][0]]))

Text                                                                        Label
----------------------------------------------------------------------------------------------------
Dodgers lose again, give up 3 HRs in a loss to the Giants                   Baseball
Giants 5 Cardinals 4 final in extra innings                                 Baseball
Dodgers drop Game 2 against the Giants, 5-4                                 Baseball
Flyers 4 Lightning 1 final. 45 saves for the Lightning.                     Hockey
Slashing, penalty, 2 minute power play coming up                            Hockey
What a stick save!                                                          Hockey
Leads the NFL in sacks with 9.5                                             Football
UCF 38 Temple 13                                                            Football
With the 30 yard completion, down to the 10 yard line                       Football
Drains the 3pt shot!!, 0:15 remaining in the game                           Basketball
Intercepted! Drives down the court and shoots for the win                   Basketball
Massive dunk!!! they are now up by 15 with 2 minutes to go                  Basketball
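Because each text gets exactly one tag, the same loop also works as a lightweight form of the topic modeling mentioned at the start: bucket texts under their best label. A minimal sketch, assuming a classifier callable with the same shape as the Labels instance, `(text, tags) -> [(index, score), ...]`. `group_by_label` and `stub_classify` are hypothetical names; the stub exists only so the sketch runs without downloading a model.

```python
from collections import defaultdict
from typing import Callable, Dict, List, Sequence, Tuple

# A classifier takes (text, tags) and returns (label index, score) tuples,
# like txtai's Labels instance.
Classifier = Callable[[str, Sequence[str]], List[Tuple[int, float]]]


def group_by_label(data: Sequence[str], tags: Sequence[str], classify: Classifier) -> Dict[str, List[str]]:
    """Bucket each text under its highest-scoring tag."""
    groups: Dict[str, List[str]] = defaultdict(list)
    for text in data:
        # Take the highest-scoring (index, score) pair.
        index, _ = max(classify(text, tags), key=lambda pair: pair[1])
        groups[tags[index]].append(text)
    return dict(groups)


# Hypothetical stub classifier for a dry run without a model.
def stub_classify(text: str, tags: Sequence[str]) -> List[Tuple[int, float]]:
    index = 3 if "dunk" in text.lower() else 2
    return [(index, 1.0)] + [(i, 0.0) for i in range(len(tags)) if i != index]


tags = ["Baseball", "Football", "Hockey", "Basketball"]
print(group_by_label(["What a stick save!", "Massive dunk!!!"], tags, stub_classify))
# {'Hockey': ['What a stick save!'], 'Basketball': ['Massive dunk!!!']}
```

In real use, pass the Labels instance created earlier, e.g. `group_by_label(data, tags, labels)`.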

Let's try emoji 😀

Does the model have knowledge of emoji? Check out the run below; it sure looks like it does! Notice that the labels reflect the perspective from which each text is written: both Dodgers losses get the angry face, while the Giants extra-innings win gets the smiley.

tags = ["😀", "😑"]

print("%-75s %s" % ("Text", "Label"))
print("-" * 100)

for text in data:
    print("%-75s %s" % (text, tags[labels(text, tags)[0][0]]))

Text                                                                        Label
----------------------------------------------------------------------------------------------------
Dodgers lose again, give up 3 HRs in a loss to the Giants                   😑
Giants 5 Cardinals 4 final in extra innings                                 😀
Dodgers drop Game 2 against the Giants, 5-4                                 😑
Flyers 4 Lightning 1 final. 45 saves for the Lightning.                     😀
Slashing, penalty, 2 minute power play coming up                            😑
What a stick save!                                                          😀
Leads the NFL in sacks with 9.5                                             😀
UCF 38 Temple 13                                                            😀
With the 30 yard completion, down to the 10 yard line                       😀
Drains the 3pt shot!!, 0:15 remaining in the game                           😀
Intercepted! Drives down the court and shoots for the win                   😀
Massive dunk!!! they are now up by 15 with 2 minutes to go                  😀