Transformers documentation

Knowledge Distillation for Computer Vision

You are viewing v4.40.2 version. A newer version v4.46.2 is available.
Hugging Face's logo
Join the Hugging Face community

and get access to the augmented documentation experience

to get started

Knowledge Distillation for Computer Vision

知識の蒸留は、より大規模で複雑なモデル (教師) からより小規模で単純なモデル (生徒) に知識を伝達するために使用される手法です。あるモデルから別のモデルに知識を抽出するには、特定のタスク (この場合は画像分類) でトレーニングされた事前トレーニング済み教師モデルを取得し、画像分類でトレーニングされる生徒モデルをランダムに初期化します。次に、学生モデルをトレーニングして、その出力と教師の出力の差を最小限に抑え、動作を模倣します。これは Distilling the Knowledge in a Neural Network by Hinton et al で最初に導入されました。このガイドでは、タスク固有の知識の蒸留を行います。これには Beans データセット を使用します。

このガイドでは、微調整された ViT モデル (教師モデル) を抽出して MobileNet (学生モデル) 🤗 Transformers の Trainer API を使用します。

蒸留とプロセスの評価に必要なライブラリをインストールしましょう。

pip install transformers datasets accelerate tensorboard evaluate --upgrade

この例では、教師モデルとしてmerve/beans-vit-224モデルを使用しています。これは、Bean データセットに基づいて微調整されたgoogle/vit-base-patch16-224-in21kに基づく画像分類モデルです。このモデルをランダムに初期化された MobileNetV2 に抽出します。

次に、データセットをロードします。

from datasets import load_dataset

dataset = load_dataset("beans")

この場合、同じ解像度で同じ出力が返されるため、どちらのモデルの画像プロセッサも使用できます。 datasetmap()メソッドを使用して、データセットのすべての分割に前処理を適用します。

from transformers import AutoImageProcessor
teacher_processor = AutoImageProcessor.from_pretrained("merve/beans-vit-224")

def process(examples):
    processed_inputs = teacher_processor(examples["image"])
    return processed_inputs

processed_datasets = dataset.map(process, batched=True)

基本的に、我々は生徒モデル(ランダムに初期化されたMobileNet)が教師モデル(微調整されたビジョン変換器)を模倣することを望む。これを実現するために、まず教師と生徒からロジット出力を得る。次に、それぞれのソフトターゲットの重要度を制御するパラメータtemperatureで分割する。lambdaと呼ばれるパラメータは蒸留ロスの重要度を量る。この例では、temperature=5lambda=0.5とする。生徒と教師の間の発散を計算するために、Kullback-Leibler発散損失を使用します。2つのデータPとQが与えられたとき、KLダイバージェンスはQを使ってPを表現するためにどれだけの余分な情報が必要かを説明します。もし2つが同じであれば、QからPを説明するために必要な他の情報はないので、それらのKLダイバージェンスはゼロになります。

from transformers import TrainingArguments, Trainer
import torch
import torch.nn as nn
import torch.nn.functional as F


class ImageDistilTrainer(Trainer):
    def __init__(self, *args, teacher_model=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher = teacher_model
        self.student = student_model
        self.loss_function = nn.KLDivLoss(reduction="batchmean")
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.teacher.to(device)
        self.teacher.eval()
        self.temperature = temperature
        self.lambda_param = lambda_param

    def compute_loss(self, student, inputs, return_outputs=False):
        student_output = self.student(**inputs)

        with torch.no_grad():
          teacher_output = self.teacher(**inputs)

        # Compute soft targets for teacher and student
        soft_teacher = F.softmax(teacher_output.logits / self.temperature, dim=-1)
        soft_student = F.log_softmax(student_output.logits / self.temperature, dim=-1)

        # Compute the loss
        distillation_loss = self.loss_function(soft_student, soft_teacher) * (self.temperature ** 2)

        # Compute the true label loss
        student_target_loss = student_output.loss

        # Calculate final loss
        loss = (1. - self.lambda_param) * student_target_loss + self.lambda_param * distillation_loss
        return (loss, student_output) if return_outputs else loss

次に、Hugging Face Hub にログインして、trainerを通じてモデルを Hugging Face Hub にプッシュできるようにします。

from huggingface_hub import notebook_login

notebook_login()

教師モデルと生徒モデルであるTrainingArgumentsを設定しましょう。

from transformers import AutoModelForImageClassification, MobileNetV2Config, MobileNetV2ForImageClassification

training_args = TrainingArguments(
    output_dir="my-awesome-model",
    num_train_epochs=30,
    fp16=True,
    logging_dir=f"{repo_name}/logs",
    logging_strategy="epoch",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    report_to="tensorboard",
    push_to_hub=True,
    hub_strategy="every_save",
    hub_model_id=repo_name,
    )

num_labels = len(processed_datasets["train"].features["labels"].names)

# initialize models
teacher_model = AutoModelForImageClassification.from_pretrained(
    "merve/beans-vit-224",
    num_labels=num_labels,
    ignore_mismatched_sizes=True
)

# training MobileNetV2 from scratch
student_config = MobileNetV2Config()
student_config.num_labels = num_labels
student_model = MobileNetV2ForImageClassification(student_config)

compute_metrics 関数を使用して、テスト セットでモデルを評価できます。この関数は、トレーニング プロセス中にモデルのaccuracyf1を計算するために使用されます。

import evaluate
import numpy as np

accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    acc = accuracy.compute(references=labels, predictions=np.argmax(predictions, axis=1))
    return {"accuracy": acc["accuracy"]}

定義したトレーニング引数を使用してTrainerを初期化しましょう。データ照合装置も初期化します。

from transformers import DefaultDataCollator

data_collator = DefaultDataCollator()
trainer = ImageDistilTrainer(
    student_model=student_model,
    teacher_model=teacher_model,
    training_args=training_args,
    train_dataset=processed_datasets["train"],
    eval_dataset=processed_datasets["validation"],
    data_collator=data_collator,
    tokenizer=teacher_extractor,
    compute_metrics=compute_metrics,
    temperature=5,
    lambda_param=0.5
)

これでモデルをトレーニングできるようになりました。

trainer.train()

テスト セットでモデルを評価できます。

trainer.evaluate(processed_datasets["test"])

テスト セットでは、モデルの精度は 72% に達します。蒸留効率の健全性チェックを行うために、同じハイパーパラメータを使用して Bean データセットで MobileNet を最初からトレーニングし、テスト セットで 63% の精度を観察しました。読者の皆様には、さまざまな事前トレーニング済み教師モデル、学生アーキテクチャ、蒸留パラメータを試していただき、その結果を報告していただくようお勧めします。抽出されたモデルのトレーニング ログとチェックポイントは このリポジトリ にあり、最初からトレーニングされた MobileNetV2 はこの リポジトリ

< > Update on GitHub