Knowledge Distillation
If you have access to unlabeled data, then you can use knowledge distillation to improve the performance of your small SetFit model. The approach involves training a larger SetFit model and using unlabeled data to distill its performance into your smaller SetFit model. As a result, your SetFit model will become stronger.
Additionally, you can use knowledge distillation to replace your trained SetFit model with a more efficient model while giving up only a small amount of performance.
This guide will show you how to proceed with knowledge distillation.
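Before diving in, it helps to see the core idea in isolation: the teacher's embeddings define similarity targets for sentence pairs drawn from the unlabeled data, and the student's embedding model is trained to reproduce them. The sketch below illustrates this idea with plain Sentence Transformers; the pair sampling, loss, and hyperparameters are simplified assumptions for illustration, not the exact internals of the DistillationTrainer used later in this guide.

from itertools import combinations

from sentence_transformers import InputExample, SentenceTransformer, losses, util
from torch.utils.data import DataLoader

# Illustrative unlabeled sentences; in practice this would be your real unlabeled data
texts = [
    "Stocks rallied after the central bank held rates steady.",
    "The striker scored twice in the final minutes of the match.",
    "A new smartphone chip promises faster on-device AI.",
    "Oil prices slipped amid concerns about global demand.",
]

teacher = SentenceTransformer("sentence-transformers/paraphrase-mpnet-base-v2")
student = SentenceTransformer("sentence-transformers/paraphrase-MiniLM-L3-v2")

# The teacher scores how similar each pair of unlabeled sentences is ...
teacher_embeddings = teacher.encode(texts, convert_to_tensor=True)
pairs = []
for i, j in combinations(range(len(texts)), 2):
    score = util.cos_sim(teacher_embeddings[i], teacher_embeddings[j]).item()
    pairs.append(InputExample(texts=[texts[i], texts[j]], label=score))

# ... and the student is trained to reproduce those similarity scores
loader = DataLoader(pairs, shuffle=True, batch_size=4)
loss = losses.CosineSimilarityLoss(student)
student.fit(train_objectives=[(loader, loss)], epochs=1, show_progress_bar=False)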
Data preparation
Let's consider a scenario where we only have a small amount of labeled training data (e.g. 64 sentences). For this guide, we will simulate this scenario using the ag_news dataset.
from datasets import load_dataset
from setfit import sample_dataset
# Load a dataset from the Hugging Face Hub
dataset = load_dataset("ag_news")
# Create a sample few-shot dataset to train with
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=16)
# Dataset({
# features: ['text', 'label'],
# num_rows: 64
# })
# Dataset for evaluation
eval_dataset = dataset["test"]
# Dataset({
# features: ['text', 'label'],
# num_rows: 7600
# })
Baseline model
We can use the standard SetFit training approach to prepare a baseline model.
from setfit import SetFitModel, TrainingArguments, Trainer
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-MiniLM-L3-v2")
args = TrainingArguments(
    batch_size=64,
    num_epochs=5,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()
metrics = trainer.evaluate()
print(metrics)
***** Running training *****
Num examples = 48
Num epochs = 5
Total optimization steps = 240
Total train batch size = 64
{'embedding_loss': 0.4173, 'learning_rate': 8.333333333333333e-07, 'epoch': 0.02}
{'embedding_loss': 0.1756, 'learning_rate': 1.7592592592592595e-05, 'epoch': 1.04}
{'embedding_loss': 0.119, 'learning_rate': 1.2962962962962964e-05, 'epoch': 2.08}
{'embedding_loss': 0.0872, 'learning_rate': 8.333333333333334e-06, 'epoch': 3.12}
{'embedding_loss': 0.0542, 'learning_rate': 3.7037037037037037e-06, 'epoch': 4.17}
{'train_runtime': 26.0837, 'train_samples_per_second': 588.873, 'train_steps_per_second': 9.201, 'epoch': 5.0}
100%|██████████| 240/240 [00:20<00:00, 11.97it/s]
***** Running evaluation *****
{'accuracy': 0.7818421052631579}
This model reaches 78.18% accuracy on the evaluation set. That is certainly respectable given the tiny amount of training data, but we can use knowledge distillation to squeeze more performance out of our model.
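If you want to spot-check the trained baseline beyond the accuracy metric, you can pass raw sentences to SetFitModel.predict. The headlines below are made up; for ag_news, the labels map to 0=World, 1=Sports, 2=Business, 3=Sci/Tech.

# Spot-check a few made-up headlines with the trained baseline model
preds = model.predict([
    "The home team clinched the championship with a last-minute goal.",
    "Shares tumbled after the company missed quarterly earnings estimates.",
])
print(preds)  # likely Sports (1) and Business (2), depending on the trained model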
Unlabeled Data Preparation
Alongside our labeled training data, we may also have a lot of unlabeled training data (e.g. 500 sentences). Let's prepare it:
# Create a dataset of unlabeled examples to perform knowledge distillation
unlabeled_train_dataset = dataset["train"].shuffle(seed=0).select(range(500))
unlabeled_train_dataset = unlabeled_train_dataset.remove_columns("label")
# Dataset({
# features: ['text'],
# num_rows: 500
# })
Teacher model
Next, we will prepare a larger SetFit model that will act as the teacher to our smaller student model. We will initialize it with the strong sentence-transformers/paraphrase-mpnet-base-v2 Sentence Transformer model.
from setfit import SetFitModel
teacher_model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")
We need to train this model on the labeled dataset first:
from setfit import TrainingArguments, Trainer
teacher_args = TrainingArguments(
    batch_size=16,
    num_epochs=2,
)

teacher_trainer = Trainer(
    model=teacher_model,
    args=teacher_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
# Train teacher model
teacher_trainer.train()
teacher_metrics = teacher_trainer.evaluate()
print(teacher_metrics)
***** Running training *****
Num examples = 192
Num epochs = 2
Total optimization steps = 384
Total train batch size = 16
{'embedding_loss': 0.4093, 'learning_rate': 5.128205128205128e-07, 'epoch': 0.01}
{'embedding_loss': 0.1087, 'learning_rate': 1.9362318840579713e-05, 'epoch': 0.26}
{'embedding_loss': 0.001, 'learning_rate': 1.6463768115942028e-05, 'epoch': 0.52}
{'embedding_loss': 0.0006, 'learning_rate': 1.3565217391304348e-05, 'epoch': 0.78}
{'embedding_loss': 0.0003, 'learning_rate': 1.0666666666666667e-05, 'epoch': 1.04}
{'embedding_loss': 0.0004, 'learning_rate': 7.768115942028987e-06, 'epoch': 1.3}
{'embedding_loss': 0.0002, 'learning_rate': 4.869565217391305e-06, 'epoch': 1.56}
{'embedding_loss': 0.0003, 'learning_rate': 1.9710144927536233e-06, 'epoch': 1.82}
{'train_runtime': 84.3703, 'train_samples_per_second': 72.822, 'train_steps_per_second': 4.551, 'epoch': 2.0}
100%|██████████| 384/384 [01:24<00:00, 4.55it/s]
***** Running evaluation *****
{'accuracy': 0.8378947368421052}
This large teacher model reaches 83.79% accuracy, which is quite strong given this little data, and noticeably stronger than the 78.18% from our smaller (but more efficient) model.
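That accuracy gap is the price we pay for efficiency: the MiniLM-L3 student body has far fewer parameters than the MPNet teacher body. A quick way to compare the two sizes is sketched below, assuming the usual model_body attribute of SetFitModel, which holds the underlying Sentence Transformer.

# Compare the size of the student and teacher Sentence Transformer bodies
def count_parameters(setfit_model):
    return sum(p.numel() for p in setfit_model.model_body.parameters())

print(f"Student parameters: {count_parameters(model):,}")
print(f"Teacher parameters: {count_parameters(teacher_model):,}")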
Knowledge Distillation
The performance of the stronger teacher_model can be distilled into the smaller model using the DistillationTrainer. It accepts a teacher model and a student model, as well as an unlabeled dataset.
Note that this trainer uses pairs of sentences as the training samples, so the number of possible training steps grows quadratically with the number of unlabeled examples. To avoid overfitting, consider setting max_steps relatively low.
from setfit import DistillationTrainer
distillation_args = TrainingArguments(
    batch_size=16,
    max_steps=500,
)

distillation_trainer = DistillationTrainer(
    teacher_model=teacher_model,
    student_model=model,
    args=distillation_args,
    train_dataset=unlabeled_train_dataset,
    eval_dataset=eval_dataset,
)
# Train student with knowledge distillation
distillation_trainer.train()
distillation_metrics = distillation_trainer.evaluate()
print(distillation_metrics)
***** Running training *****
Num examples = 7829
Num epochs = 1
Total optimization steps = 7829
Total train batch size = 16
{'embedding_loss': 0.5048, 'learning_rate': 2.554278416347382e-08, 'epoch': 0.0}
{'embedding_loss': 0.4514, 'learning_rate': 1.277139208173691e-06, 'epoch': 0.01}
{'embedding_loss': 0.33, 'learning_rate': 2.554278416347382e-06, 'epoch': 0.01}
{'embedding_loss': 0.1218, 'learning_rate': 3.831417624521073e-06, 'epoch': 0.02}
{'embedding_loss': 0.0213, 'learning_rate': 5.108556832694764e-06, 'epoch': 0.03}
{'embedding_loss': 0.016, 'learning_rate': 6.385696040868455e-06, 'epoch': 0.03}
{'embedding_loss': 0.0054, 'learning_rate': 7.662835249042147e-06, 'epoch': 0.04}
{'embedding_loss': 0.0049, 'learning_rate': 8.939974457215838e-06, 'epoch': 0.04}
{'embedding_loss': 0.002, 'learning_rate': 1.0217113665389528e-05, 'epoch': 0.05}
{'embedding_loss': 0.0019, 'learning_rate': 1.1494252873563218e-05, 'epoch': 0.06}
{'embedding_loss': 0.0012, 'learning_rate': 1.277139208173691e-05, 'epoch': 0.06}
{'train_runtime': 22.2725, 'train_samples_per_second': 359.188, 'train_steps_per_second': 22.449, 'epoch': 0.06}
100%|██████████| 500/500 [00:22<00:00, 22.45it/s]
***** Running evaluation *****
{'accuracy': 0.8084210526315789}
Using knowledge distillation, we were able to improve our small model from 78.18% to 80.84% accuracy in just a few minutes of training.
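The distilled student is now both small and stronger, so it is worth persisting for later use. A minimal sketch, assuming a local output directory of your choosing (the path below is just an example):

# Save the distilled student model ...
model.save_pretrained("setfit-minilm-l3-distilled")

# ... and load it again later for inference
from setfit import SetFitModel

loaded_model = SetFitModel.from_pretrained("setfit-minilm-l3-distilled")
preds = loaded_model.predict(["A rocket startup announced a new reusable launch vehicle."])
print(preds)  # likely Sci/Tech (3) for ag_news, depending on the trained model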
End-to-end
This snippet shows the entire knowledge distillation strategy in an end-to-end example:
from datasets import load_dataset
from setfit import sample_dataset
# Load a dataset from the Hugging Face Hub
dataset = load_dataset("ag_news")
# Create a sample few-shot dataset to train with
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=16)
# Dataset({
# features: ['text', 'label'],
# num_rows: 64
# })
# Dataset for evaluation
eval_dataset = dataset["test"]
# Dataset({
# features: ['text', 'label'],
# num_rows: 7600
# })
from setfit import SetFitModel, TrainingArguments, Trainer
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-MiniLM-L3-v2")
args = TrainingArguments(
    batch_size=64,
    num_epochs=5,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()
metrics = trainer.evaluate()
print(metrics)
# Create a dataset of unlabeled examples to perform knowledge distillation
unlabeled_train_dataset = dataset["train"].shuffle(seed=0).select(range(500))
unlabeled_train_dataset = unlabeled_train_dataset.remove_columns("label")
# Dataset({
# features: ['text'],
# num_rows: 500
# })
from setfit import SetFitModel
teacher_model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")
from setfit import TrainingArguments, Trainer
teacher_args = TrainingArguments(
    batch_size=16,
    num_epochs=2,
)

teacher_trainer = Trainer(
    model=teacher_model,
    args=teacher_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
# Train teacher model
teacher_trainer.train()
teacher_metrics = teacher_trainer.evaluate()
print(teacher_metrics)
from setfit import DistillationTrainer
distillation_args = TrainingArguments(
    batch_size=16,
    max_steps=500,
)

distillation_trainer = DistillationTrainer(
    teacher_model=teacher_model,
    student_model=model,
    args=distillation_args,
    train_dataset=unlabeled_train_dataset,
    eval_dataset=eval_dataset,
)
# Train student with knowledge distillation
distillation_trainer.train()
distillation_metrics = distillation_trainer.evaluate()
print(distillation_metrics)
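If the distilled model is worth sharing, you can also push it to the Hugging Face Hub after training. The repository name below is just a placeholder.

# Optionally, share the distilled student on the Hugging Face Hub
model.push_to_hub("your-username/setfit-minilm-l3-distilled-ag-news")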