SetFit documentation

Zero-shot Text Classification

Your class names are likely already good descriptors of the text that you’re looking to classify. With 🤗 SetFit, you can use these class names with pretrained Sentence Transformer models to get a strong baseline model without any training samples.

This guide will show you how to perform zero-shot text classification.

Testing dataset

We’ll use the dair-ai/emotion dataset to test the performance of our zero-shot model.

from datasets import load_dataset

test_dataset = load_dataset("dair-ai/emotion", "split", split="test")

This dataset stores the class names within the dataset Features, so we’ll extract the classes like so:

classes = test_dataset.features["label"].names
# => ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']

Otherwise, we could manually set the list of classes.
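If a dataset does not expose its label names through its features, a hand-written list works just as well. A minimal sketch, assuming the same six emotion labels in the same order as the dataset's integer label ids (the `label2id` mapping is a hypothetical helper added only for illustration):

```python
# Manually defined class names; the order must match the integer label ids
# used by the dataset (0 -> "sadness", 1 -> "joy", ...).
classes = ["sadness", "joy", "love", "anger", "fear", "surprise"]

# Hypothetical helper mapping, handy for converting names back to ids.
label2id = {name: idx for idx, name in enumerate(classes)}
print(label2id["anger"])
# => 3
```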

Synthetic dataset

Then, we can use get_templated_dataset() to synthetically generate a dummy dataset given these class names.

from setfit import get_templated_dataset

# 6 classes x 8 templated samples per class = 48 rows
train_dataset = get_templated_dataset(candidate_labels=classes, sample_size=8)
print(train_dataset)
# => Dataset({
#     features: ['text', 'label'],
#     num_rows: 48
# })
print(train_dataset[0])
# => {'text': 'This sentence is sadness', 'label': 0}
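To make it concrete what such a synthetic dataset contains, here is a rough pure-Python sketch of the templating idea. It is not SetFit's implementation: `build_templated_rows` is a hypothetical helper, and the template string and row order are assumptions inferred from the example output above.

```python
def build_templated_rows(candidate_labels, sample_size=8, template="This sentence is {}"):
    """Return `sample_size` identical templated rows per label (illustrative only)."""
    rows = []
    for label_id, label_name in enumerate(candidate_labels):
        for _ in range(sample_size):
            rows.append({"text": template.format(label_name), "label": label_id})
    return rows

classes = ["sadness", "joy", "love", "anger", "fear", "surprise"]
rows = build_templated_rows(classes, sample_size=8)
print(len(rows))  # => 48 (6 classes x 8 samples each)
print(rows[0])    # => {'text': 'This sentence is sadness', 'label': 0}
```

Every row for a given class carries the same text, so the only signal the model receives during training is the class name itself.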

Training

We can use this dataset to train a SetFit model as usual:

from setfit import SetFitModel, Trainer, TrainingArguments

model = SetFitModel.from_pretrained("BAAI/bge-small-en-v1.5")

args = TrainingArguments(
    batch_size=32,
    num_epochs=1,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)
trainer.train()
***** Running training *****
  Num examples = 60
  Num epochs = 1
  Total optimization steps = 60
  Total train batch size = 32
{'embedding_loss': 0.2628, 'learning_rate': 3.3333333333333333e-06, 'epoch': 0.02}                                                                                 
{'embedding_loss': 0.0222, 'learning_rate': 3.7037037037037037e-06, 'epoch': 0.83}                                                                                 
{'train_runtime': 15.4717, 'train_samples_per_second': 124.098, 'train_steps_per_second': 3.878, 'epoch': 1.0}                                                     
100%|██████████| 60/60 [00:09<00:00,  6.35it/s]

Once trained, we can evaluate the model:

metrics = trainer.evaluate()
print(metrics)
***** Running evaluation *****
{'accuracy': 0.591}

And run predictions:

preds = model.predict([
    "i am just feeling cranky and blue",
    "i feel incredibly lucky just to be able to talk to her",
    "you're pissing me off right now",
    "i definitely have thalassophobia, don't get me near water like that",
    "i did not see that coming at all",
])
print([classes[idx] for idx in preds])
['sadness', 'joy', 'anger', 'fear', 'surprise']

These predictions all look right!

Baseline

To put SetFit’s zero-shot performance in context, we’ll compare it against a zero-shot classification model from transformers.

from transformers import pipeline
from datasets import load_dataset
import evaluate

# Prepare the testing dataset
test_dataset = load_dataset("dair-ai/emotion", "split", split="test")
classes = test_dataset.features["label"].names

# Set up the zero-shot classification pipeline from transformers
# Uses 'facebook/bart-large-mnli' by default
pipe = pipeline("zero-shot-classification", device=0)
zeroshot_preds = pipe(test_dataset["text"], batch_size=16, candidate_labels=classes)
preds = [classes.index(pred["labels"][0]) for pred in zeroshot_preds]

# Compute the accuracy
metric = evaluate.load("accuracy")
transformers_accuracy = metric.compute(predictions=preds, references=test_dataset["label"])
print(transformers_accuracy)
{'accuracy': 0.3765}

With 59.1% accuracy against 37.65%, zero-shot SetFit heavily outperforms the default zero-shot model from transformers.
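For reference, accuracy here is the fraction of test sentences whose predicted label matches the gold label. A small sketch with a hypothetical `accuracy` helper and made-up toy labels, followed by the gap between the two scores reported in this guide:

```python
def accuracy(predictions, references):
    """Fraction of positions where the prediction equals the reference."""
    if len(predictions) != len(references):
        raise ValueError("predictions and references must have the same length")
    correct = sum(p == r for p, r in zip(predictions, references))
    return correct / len(references)

# Toy example (made-up labels, not taken from the dataset)
print(accuracy([0, 1, 3, 4], [0, 1, 2, 4]))  # => 0.75

# Scores reported in this guide
setfit_acc, transformers_acc = 0.591, 0.3765
gap_points = (setfit_acc - transformers_acc) * 100
print(f"{gap_points:.2f} percentage points")  # => 21.45 percentage points
```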

Prediction latency

Beyond getting higher accuracies, SetFit is much faster too. Let’s compute the latency of SetFit with BAAI/bge-small-en-v1.5 versus the latency of transformers with facebook/bart-large-mnli. Both tests were performed on a GPU.

import time

start_t = time.time()
pipe(test_dataset["text"], batch_size=32, candidate_labels=classes)
delta_t = time.time() - start_t
print(f"`transformers` with `facebook/bart-large-mnli` latency: {delta_t / len(test_dataset['text']) * 1000:.4f}ms per sentence")
`transformers` with `facebook/bart-large-mnli` latency: 31.1765ms per sentence

import time

start_t = time.time()
model.predict(test_dataset["text"])
delta_t = time.time() - start_t
print(f"SetFit with `BAAI/bge-small-en-v1.5` latency: {delta_t / len(test_dataset['text']) * 1000:.4f}ms per sentence")
SetFit with `BAAI/bge-small-en-v1.5` latency: 0.4600ms per sentence
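The two per-sentence latencies printed above can be turned into a speedup factor with a single division. The sketch below recomputes it and also shows how the repeated timing pattern could be wrapped in a helper. `ms_per_item` is a hypothetical name, and `time.perf_counter` is swapped in for `time.time` because it is better suited to measuring short intervals.

```python
import time

def ms_per_item(fn, items):
    """Call fn(items) once and return the average latency in milliseconds per item."""
    start = time.perf_counter()
    fn(items)
    elapsed = time.perf_counter() - start
    return elapsed / len(items) * 1000

# Latencies reported in this guide (milliseconds per sentence)
transformers_ms = 31.1765
setfit_ms = 0.4600
speedup = transformers_ms / setfit_ms
print(f"{speedup:.1f}x")  # => 67.8x

# Usage with a stand-in function; in this guide the callable would be
# `model.predict` or a small wrapper around `pipe`.
latency = ms_per_item(lambda batch: [len(text) for text in batch], ["a", "bb", "ccc"])
```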

So, SetFit with BAAI/bge-small-en-v1.5 is about 68x faster than transformers with facebook/bart-large-mnli (31.1765ms versus 0.4600ms per sentence), alongside being more accurate.

[Figure: zero_shot_transformers_vs_setfit]