{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "OZuBIVRs18lP" }, "source": [ "Inspired by: https://huggingface.co/blog/fine-tune-vit" ] }, { "cell_type": "markdown", "metadata": { "id": "rwb1Sdi0QASD" }, "source": [ "# Fine-Tuning Vision Transformers for Image Classification\n", "\n", "Just as transformers-based models have revolutionized NLP, we're now seeing an explosion of papers applying them to all sorts of other domains. One of the most revolutionary of these was the Vision Transformer (ViT), which was introduced in [June 2021](https://arxiv.org/abs/2010.11929) by a team of researchers at Google Brain.\n", "\n", "This paper explored how you can tokenize images, just as you would tokenize sentences, so that they can be passed to transformer models for training. Its quite a simple concept, really...\n", "\n", "1. Split an image into a grid of sub-image patches\n", "1. Embed each patch with a linear projection\n", "1. Each embedded patch becomes a token, and the resulting sequence of embedded patches is the sequence you pass to the model.\n", "\n", "\n", "\n", "\n", "It turns out that once you've done the above, you can pre-train and finetune transformers just as you're used to with NLP tasks. Pretty sweet π." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%capture\n", "\n", "! pip install datasets transformers evaluate" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from datasets import load_dataset\n", "\n", "dataset_name = \"jonathan-roberts1/Satellite-Images-of-Hurricane-Damage\"\n", "\n", "def get_ds():\n", " ds = load_dataset(dataset_name)\n", " ds = ds[\"train\"].train_test_split(test_size=0.5)\n", " ds[\"train\"][\"label\"].count(1), ds[\"test\"][\"label\"].count(0)\n", " ds_ = ds[\"test\"].train_test_split(test_size=0.5)\n", " ds[\"validation\"] = ds_[\"train\"]\n", " ds[\"test\"] = ds_[\"test\"]\n", " return ds" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ds = get_ds()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "image = ds['train'][400]['image']" ] }, { "cell_type": "markdown", "metadata": { "id": "ooXr_55XICXM" }, "source": [ "## Loading ViT Feature Extractor\n", "\n", "Now that we know what our images look like and have a better understanding of the problem we're trying to solve, let's see how we can prepare these images for our model.\n", "\n", "When ViT models are trained, specific transformations are applied to images being fed into them. Use the wrong transformations on your image and the model won't be able to understand what it's seeing! πΌ β‘οΈ π’\n", "\n", "To make sure we apply the correct transformations, we will use a [`ViTFeatureExtractor`](https://huggingface.co/docs/datasets/package_reference/main_classes.html?highlight=classlabel#datasets.ClassLabel.int2str) initialized with a configuration that was saved along with the pretrained model we plan to use. In our case, we'll be using the [google/vit-base-patch16-224-in21k](https://huggingface.co/google/vit-base-patch16-224-in21k) model, so lets load its feature extractor from the π€ Hub." 
] }, { "cell_type": "code", "execution_count": 66, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 103, "referenced_widgets": [ "57012c0224244c3080261b8d0ab34ce8", "2e7f50f2e01049978d4094f2353cdd0c", "0c5cb87527db42e8a992da9bb0976fab", "196ae5aee81345238928fc5a06f2faa8", "9758e18c65b44506aa8208504bc2cca6", "a2a454d1aa654002a52967f00343aa7c", "c73ebc46be0f49089a5cfb00e883fef8", "62695a863fb74564b636c1b891a24b8f", "4e0989c80f1f4494a60d3bb8f9569621", "8f62ab0b2f7442d581afe3909344c9b0", "d774a2342d3b438b99a00132e92c30a1" ] }, "executionInfo": { "elapsed": 14251, "status": "ok", "timestamp": 1734165764141, "user": { "displayName": "Till Wenke", "userId": "10971785981473027459" }, "user_tz": -60 }, "id": "Ct6zPRixIUoI", "outputId": "feb48e50-3c67-467e-81d9-77c3937a2af3" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "57012c0224244c3080261b8d0ab34ce8", "version_major": 2, "version_minor": 0 }, "text/plain": [ "preprocessor_config.json: 0%| | 0.00/160 [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/transformers/models/vit/feature_extraction_vit.py:28: FutureWarning: The class ViTFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please use ViTImageProcessor instead.\n", " warnings.warn(\n" ] } ], "source": [ "from transformers import ViTFeatureExtractor\n", "\n", "model_name_or_path = 'google/vit-base-patch16-224-in21k'\n", "feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)" ] }, { "cell_type": "markdown", "metadata": { "id": "EbGIOc_FIbU7" }, "source": [ "If we print a feature extractor, we can see its configuration." ] }, { "cell_type": "code", "execution_count": 67, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 7, "status": "ok", "timestamp": 1734165764142, "user": { "displayName": "Till Wenke", "userId": "10971785981473027459" }, "user_tz": -60 }, "id": "ea22sEWLIg4e", "outputId": "fb0fbea4-d61e-4f09-cea6-55d43ae7785c" }, "outputs": [ { "data": { "text/plain": [ "ViTFeatureExtractor {\n", " \"do_normalize\": true,\n", " \"do_rescale\": true,\n", " \"do_resize\": true,\n", " \"image_mean\": [\n", " 0.5,\n", " 0.5,\n", " 0.5\n", " ],\n", " \"image_processor_type\": \"ViTFeatureExtractor\",\n", " \"image_std\": [\n", " 0.5,\n", " 0.5,\n", " 0.5\n", " ],\n", " \"resample\": 2,\n", " \"rescale_factor\": 0.00392156862745098,\n", " \"size\": {\n", " \"height\": 224,\n", " \"width\": 224\n", " }\n", "}" ] }, "execution_count": 67, "metadata": {}, "output_type": "execute_result" } ], "source": [ "feature_extractor" ] }, { "cell_type": "markdown", "metadata": { "id": "DmhC6aSKIics" }, "source": [ "To process an image, simply pass it to the feature extractor's call function. 
This will return a dict containing `pixel_values`, which is the numeric representation of your image that we'll pass to the model.\n", "\n", "We get NumPy arrays by default, but if we add the `return_tensors='pt'` argument, we'll get back `torch` tensors instead.\n" ] }, { "cell_type": "code", "execution_count": 68, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 251, "status": "ok", "timestamp": 1734165764389, "user": { "displayName": "Till Wenke", "userId": "10971785981473027459" }, "user_tz": -60 }, "id": "3lne4VrXJRIe", "outputId": "4d1dc59e-8480-45cd-c011-7f126ef54d6d" }, "outputs": [ { "data": { "text/plain": [ "{'pixel_values': tensor([[[[-0.4118, -0.4039, -0.3882, ..., -0.1137, -0.1137, -0.1137],\n", " [-0.4353, -0.4275, -0.4118, ..., -0.1294, -0.1373, -0.1373],\n", " [-0.4745, -0.4667, -0.4431, ..., -0.1608, -0.1686, -0.1686],\n", " ...,\n", " [-0.6078, -0.6000, -0.5765, ..., -0.5294, -0.5373, -0.5451],\n", " [-0.6000, -0.6000, -0.5843, ..., -0.5216, -0.5294, -0.5373],\n", " [-0.6000, -0.6000, -0.5922, ..., -0.5216, -0.5294, -0.5373]],\n", "\n", " [[-0.3255, -0.3176, -0.3020, ..., -0.0824, -0.0824, -0.0824],\n", " [-0.3490, -0.3412, -0.3255, ..., -0.0980, -0.1059, -0.1059],\n", " [-0.3882, -0.3804, -0.3569, ..., -0.1294, -0.1373, -0.1373],\n", " ...,\n", " [-0.6000, -0.5922, -0.5686, ..., -0.4118, -0.4196, -0.4275],\n", " [-0.5922, -0.5922, -0.5765, ..., -0.4039, -0.4118, -0.4196],\n", " [-0.5922, -0.5922, -0.5843, ..., -0.4039, -0.4118, -0.4196]],\n", "\n", " [[-0.4510, -0.4431, -0.4275, ..., -0.1922, -0.1922, -0.1922],\n", " [-0.4745, -0.4667, -0.4510, ..., -0.2078, -0.2157, -0.2157],\n", " [-0.5137, -0.5059, -0.4824, ..., -0.2392, -0.2471, -0.2471],\n", " ...,\n", " [-0.7412, -0.7333, -0.7098, ..., -0.6392, -0.6471, -0.6549],\n", " [-0.7333, -0.7333, -0.7176, ..., -0.6314, -0.6392, -0.6471],\n", " [-0.7333, -0.7333, -0.7255, ..., -0.6314, -0.6392, -0.6471]]]])}" ] }, "execution_count": 68, "metadata": {}, "output_type": "execute_result" } ], "source": [ "feature_extractor(image, return_tensors='pt')" ] }, { "cell_type": "markdown", "metadata": { "id": "ujbbcaIPJiAW" }, "source": [ "## Processing the Dataset\n", "\n", "Now that we know how to read in images and transform them into inputs, let's write a function that will put those two things together to process a single example from the dataset."
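, "\n", "\n", "As a reminder, each raw example is a dict with a PIL image under `image` and an integer class id under `label`. A quick sanity check (a minimal sketch; the printed class names depend on how the dataset defines its label feature) could look like this:\n", "\n", "```python\n", "example = ds['train'][0]\n", "print(type(example['image']))         # a PIL image, decoded lazily by 🤗 Datasets\n", "print(example['label'])               # an integer class id\n", "print(ds['train'].features['label'])  # the label feature (with class names, if defined)\n", "```"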
] }, { "cell_type": "code", "execution_count": 69, "metadata": { "executionInfo": { "elapsed": 6, "status": "ok", "timestamp": 1734165764389, "user": { "displayName": "Till Wenke", "userId": "10971785981473027459" }, "user_tz": -60 }, "id": "0U48pAEuMLQh" }, "outputs": [], "source": [ "def process_example(example):\n", " inputs = feature_extractor(example['image'], return_tensors='pt')\n", " inputs['label'] = example['label']\n", " return inputs" ] }, { "cell_type": "code", "execution_count": 70, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 6, "status": "ok", "timestamp": 1734165764389, "user": { "displayName": "Till Wenke", "userId": "10971785981473027459" }, "user_tz": -60 }, "id": "WmIilnQ-MbhG", "outputId": "0555b530-3ef0-4563-db49-2f9c4aad474b" }, "outputs": [ { "data": { "text/plain": [ "{'pixel_values': tensor([[[[-0.3569, -0.3490, -0.3412, ..., -0.5294, -0.5137, -0.5059],\n", " [-0.3412, -0.3333, -0.3255, ..., -0.5294, -0.5216, -0.5137],\n", " [-0.3176, -0.3098, -0.3020, ..., -0.5373, -0.5294, -0.5216],\n", " ...,\n", " [-0.5686, -0.5765, -0.5843, ..., -0.4353, -0.4431, -0.4510],\n", " [-0.5686, -0.5686, -0.5765, ..., -0.3882, -0.3882, -0.3961],\n", " [-0.5686, -0.5686, -0.5686, ..., -0.3569, -0.3569, -0.3569]],\n", "\n", " [[-0.3961, -0.3882, -0.3804, ..., -0.4196, -0.4039, -0.3961],\n", " [-0.3804, -0.3725, -0.3647, ..., -0.4196, -0.4118, -0.4039],\n", " [-0.3569, -0.3490, -0.3412, ..., -0.4275, -0.4196, -0.4118],\n", " ...,\n", " [-0.3882, -0.3961, -0.4039, ..., -0.3490, -0.3569, -0.3647],\n", " [-0.3882, -0.3882, -0.3961, ..., -0.3020, -0.3020, -0.3098],\n", " [-0.3882, -0.3882, -0.3882, ..., -0.2706, -0.2706, -0.2706]],\n", "\n", " [[-0.5686, -0.5608, -0.5529, ..., -0.6235, -0.6078, -0.6000],\n", " [-0.5529, -0.5451, -0.5373, ..., -0.6235, -0.6157, -0.6078],\n", " [-0.5294, -0.5216, -0.5137, ..., -0.6314, -0.6235, -0.6157],\n", " ...,\n", " [-0.6157, -0.6235, -0.6314, ..., -0.4824, -0.4902, -0.4902],\n", " [-0.6157, -0.6157, -0.6235, ..., -0.4275, -0.4353, -0.4353],\n", " [-0.6157, -0.6157, -0.6157, ..., -0.3961, -0.3961, -0.3961]]]]), 'label': 0}" ] }, "execution_count": 70, "metadata": {}, "output_type": "execute_result" } ], "source": [ "process_example(ds['train'][0])" ] }, { "cell_type": "markdown", "metadata": { "id": "Fusnj3EHMk5g" }, "source": [ "While we could call `ds.map` and apply this to every example at once, this can be very slow, especially if you use a larger dataset. Instead, we'll apply a ***transform*** to the dataset. Transforms are only applied to examples as you index them.\n", "\n", "First, though, we'll need to update our last function to accept a batch of data, as that's what `ds.with_transform` expects." ] }, { "cell_type": "code", "execution_count": 73, "metadata": { "executionInfo": { "elapsed": 8, "status": "ok", "timestamp": 1734165773522, "user": { "displayName": "Till Wenke", "userId": "10971785981473027459" }, "user_tz": -60 }, "id": "Z_sF61AoM3X1" }, "outputs": [], "source": [ "def transform(example_batch):\n", " # Take a list of PIL images and turn them to pixel values\n", " inputs = feature_extractor([x for x in example_batch['image']], return_tensors='pt')\n", "\n", " # Don't forget to include the labels!\n", " inputs['label'] = example_batch['label']\n", " return inputs\n", "\n", "prepared_ds = ds.with_transform(transform)" ] }, { "cell_type": "markdown", "metadata": { "id": "p31_fIQ3N5ej" }, "source": [ "We can directly apply this to our dataset using `ds.with_transform(transform)`." 
] }, { "cell_type": "code", "execution_count": 74, "metadata": { "executionInfo": { "elapsed": 7, "status": "ok", "timestamp": 1734165773522, "user": { "displayName": "Till Wenke", "userId": "10971785981473027459" }, "user_tz": -60 }, "id": "VlCsAUJxOlZy" }, "outputs": [], "source": [ "prepared_ds = ds.with_transform(transform)" ] }, { "cell_type": "markdown", "metadata": { "id": "_Xng7C3pOq9Q" }, "source": [ "Now, whenever we get an example from the dataset, our transform will be\n", "applied in real time (on both samples and slices, as shown below)" ] }, { "cell_type": "code", "execution_count": 75, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 7, "status": "ok", "timestamp": 1734165773522, "user": { "displayName": "Till Wenke", "userId": "10971785981473027459" }, "user_tz": -60 }, "id": "SZEwL06H9IQr", "outputId": "841f9c46-dabb-4d0b-de7a-e7b94ab21016" }, "outputs": [ { "data": { "text/plain": [ "{'pixel_values': tensor([[[[-0.0275, -0.0118, 0.0118, ..., 0.1216, 0.0980, 0.0824],\n", " [-0.0196, -0.0039, 0.0196, ..., 0.1137, 0.0980, 0.0902],\n", " [-0.0039, 0.0118, 0.0431, ..., 0.0980, 0.0980, 0.0980],\n", " ...,\n", " [-0.1765, -0.1059, 0.0196, ..., 0.0275, 0.0118, 0.0039],\n", " [-0.2000, -0.1686, -0.1059, ..., 0.0275, 0.0353, 0.0353],\n", " [-0.2157, -0.2078, -0.1843, ..., 0.0275, 0.0510, 0.0588]],\n", "\n", " [[-0.0431, -0.0275, -0.0039, ..., 0.1137, 0.0902, 0.0745],\n", " [-0.0353, -0.0196, 0.0039, ..., 0.1059, 0.0902, 0.0824],\n", " [-0.0196, -0.0039, 0.0275, ..., 0.0902, 0.0902, 0.0902],\n", " ...,\n", " [-0.1843, -0.1137, 0.0118, ..., 0.0588, 0.0431, 0.0353],\n", " [-0.2078, -0.1765, -0.1137, ..., 0.0588, 0.0667, 0.0667],\n", " [-0.2235, -0.2157, -0.1922, ..., 0.0588, 0.0824, 0.0902]],\n", "\n", " [[-0.2235, -0.2078, -0.1843, ..., -0.1294, -0.1529, -0.1686],\n", " [-0.2157, -0.2000, -0.1765, ..., -0.1373, -0.1529, -0.1608],\n", " [-0.2000, -0.1843, -0.1529, ..., -0.1451, -0.1529, -0.1529],\n", " ...,\n", " [-0.3255, -0.2549, -0.1294, ..., -0.0510, -0.0745, -0.0745],\n", " [-0.3490, -0.3176, -0.2549, ..., -0.0510, -0.0431, -0.0431],\n", " [-0.3647, -0.3569, -0.3333, ..., -0.0510, -0.0275, -0.0196]]],\n", "\n", "\n", " [[[-0.4431, -0.4275, -0.4039, ..., 0.1373, 0.1843, 0.2157],\n", " [-0.4353, -0.4275, -0.4039, ..., 0.1216, 0.1686, 0.2000],\n", " [-0.4275, -0.4196, -0.3961, ..., 0.0980, 0.1451, 0.1765],\n", " ...,\n", " [-0.5529, -0.5529, -0.5451, ..., -0.6078, -0.5843, -0.5765],\n", " [-0.5843, -0.5843, -0.5765, ..., -0.6000, -0.5686, -0.5529],\n", " [-0.6078, -0.6078, -0.6000, ..., -0.6000, -0.5608, -0.5373]],\n", "\n", " [[-0.3647, -0.3490, -0.3255, ..., 0.0980, 0.1373, 0.1686],\n", " [-0.3569, -0.3490, -0.3255, ..., 0.0824, 0.1216, 0.1529],\n", " [-0.3490, -0.3412, -0.3176, ..., 0.0588, 0.0980, 0.1294],\n", " ...,\n", " [-0.4667, -0.4667, -0.4588, ..., -0.5373, -0.5137, -0.5059],\n", " [-0.4980, -0.4980, -0.4902, ..., -0.5294, -0.4980, -0.4824],\n", " [-0.5216, -0.5216, -0.5137, ..., -0.5294, -0.4902, -0.4667]],\n", "\n", " [[-0.6314, -0.6157, -0.5922, ..., -0.0745, -0.0353, -0.0039],\n", " [-0.6235, -0.6157, -0.5922, ..., -0.0902, -0.0510, -0.0196],\n", " [-0.6157, -0.6078, -0.5843, ..., -0.1137, -0.0745, -0.0431],\n", " ...,\n", " [-0.7333, -0.7333, -0.7255, ..., -0.6863, -0.6627, -0.6549],\n", " [-0.7647, -0.7647, -0.7569, ..., -0.6784, -0.6471, -0.6314],\n", " [-0.7882, -0.7882, -0.7804, ..., -0.6784, -0.6392, -0.6157]]]]), 'label': [1, 1]}" ] }, "execution_count": 75, "metadata": {}, "output_type": "execute_result" } ], 
"source": [ "prepared_ds['train'][0:2]" ] }, { "cell_type": "markdown", "metadata": { "id": "4ngOi7XiTFpB" }, "source": [ "# Training and Evaluation\n", "\n", "The data is processed and we are ready to start setting up the training pipeline. We will make use of π€'s Trainer, but that'll require us to do a few things first:\n", "\n", "- Define a collate function.\n", "\n", "- Define an evaluation metric. During training, the model should be evaluated on its prediction accuracy. We should define a compute_metrics function accordingly.\n", "\n", "- Load a pretrained checkpoint. We need to load a pretrained checkpoint and configure it correctly for training.\n", "\n", "- Define the training configuration.\n", "\n", "After having fine-tuned the model, we will correctly evaluate it on the evaluation data and verify that it has indeed learned to correctly classify our images." ] }, { "cell_type": "code", "execution_count": 76, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 17, "referenced_widgets": [ "80c2e52a73ba4d229efa479a04004c96", "cf3b95476d37406280d277739dc9d478", "2c1c56f0e3154c61b0da3884923dc13d", "181a76ecbad14b25a66d4a4a593271c2", "fd1b883f9a4b4947896d43470feecee6", "4bc343824e404ac0b33d33da97777abf", "c47eaf6d663f41e98de253b70454f11f", "837609239b4143979482857cc43710d8", "7ed6568e7a7b4cf1b7b895ff485a75ac", "282a26f0ae3b4c2e9f5bb369933e9b4a", "0a8f2fb303744d8c8253050b99b5f885", "20ffdb1a2d8644e6aa817dcd29375a11", "b475a8691ee54bc1b59f0aa441dd1db5", "5ea9aa98624e4ccfa7f7582ff666fcf5", "5455b872cbad468c936cf712d7734f0d", "a731040efce144308aae0bfd7cd4ad24", "f5a35dba3ff340fdb0a12b2d5741d7ad", "e4b82a180f7b4387a75cdbdedea814b7", "2485c0cb825340de8e12cb408cdda7c4", "e5135f975329493594fa5a406636a632" ] }, "executionInfo": { "elapsed": 219, "status": "ok", "timestamp": 1734165773735, "user": { "displayName": "Till Wenke", "userId": "10971785981473027459" }, "user_tz": -60 }, "id": "omHT-thePyhn", "outputId": "b86d4f65-dbd7-4e94-afdb-648ec0409a0f" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "80c2e52a73ba4d229efa479a04004c96", "version_major": 2, "version_minor": 0 }, "text/plain": [ "VBox(children=(HTML(value='
| Step | Training Loss | Validation Loss | Model Preparation Time | Accuracy |
|---|---|---|---|---|
| 100 | 0.111800 | 0.148582 | 0.005100 | 0.947600 |
| 200 | 0.111200 | 0.070119 | 0.005100 | 0.975200 |
| 300 | 0.069400 | 0.060849 | 0.005100 | 0.980800 |
| 400 | 0.004800 | 0.091668 | 0.005100 | 0.974400 |
| 500 | 0.036000 | 0.055198 | 0.005100 | 0.983600 |
| 600 | 0.059400 | 0.054691 | 0.005100 | 0.980800 |
| 700 | 0.011500 | 0.062730 | 0.005100 | 0.984400 |
| 800 | 0.001600 | 0.029573 | 0.005100 | 0.993600 |
| 900 | 0.004000 | 0.032514 | 0.005100 | 0.991600 |
| 1000 | 0.000900 | 0.022371 | 0.005100 | 0.994800 |
| 1100 | 0.000800 | 0.027039 | 0.005100 | 0.993600 |
| 1200 | 0.000800 | 0.025595 | 0.005100 | 0.994000 |
"
],
"text/plain": [
"