{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "i72MPXVuA7Is" }, "source": [ "# Fine-tune ColPali for Multimodal RAG\n", "\n", "ColPali is a document-text query similarity model based on dual encoder architecture. The model is based on PaliGemma (image-encoder text decoder based image tower) and Gemma (decoder-only text tower) where image and text outputs of each model are projected to a joint space. The similarity is calculated between the projected embeddings and documents and texts are matched based on maximum similarity between them. This approach itself as of now is state-of-the-art for document retrieval. In this notebook we will fine-tune ColPali.\n", "\n", "This model can be used for any application where you would like to build document retrieval pipelines, including multimodal RAG. Normally for document retrieval, you would transcribe a complex document using brittle PDF parsers that include image captioner, table-to-markdown readers and OCR models. ColPali-like models remove the need for such brittle and slow pipelines.\n", "\n", "\n", "\n", "This notebook is a very minimal example to fine-tune ColPali on [UFO documents and queries](https://huggingface.co/datasets/davanstrien/ufo-Colpali) a dataset synthetically generated, if you want to generate a dataset of your own you can read [this blog post](https://danielvanstrien.xyz/posts/post-with-code/colpali/2024-09-23-generate_colpali_dataset.html). Then we will show a very minimal example on how to retrieve infographics. There's a notebook that already exists on how to fine-tune ColPali by Tony Wu, but this notebook is showing a much similar example and is focused on using transformers implementation of ColPali.\n", "\n", "We need to install transformers from main, peft and bitsandbytes for QLoRA and colpali-engine for contrastive loss implementation." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Xnh1fcmsbbzu", "outputId": "34a89b75-8321-45d3-a42d-43c761abf084" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m251.6/251.6 kB\u001b[0m \u001b[31m17.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m69.1/69.1 MB\u001b[0m \u001b[31m32.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.4/1.4 MB\u001b[0m \u001b[31m68.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m527.3/527.3 kB\u001b[0m \u001b[31m36.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.5/4.5 MB\u001b[0m \u001b[31m104.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m11.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m177.6/177.6 kB\u001b[0m \u001b[31m16.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m287.4/287.4 kB\u001b[0m \u001b[31m23.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m 
\u001b[31m12.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m17.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h Building wheel for transformers (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", " Building wheel for gputil (setup.py) ... \u001b[?25l\u001b[?25hdone\n", "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", "gcsfs 2024.10.0 requires fsspec==2024.10.0, but you have fsspec 2024.6.1 which is incompatible.\u001b[0m\u001b[31m\n", "\u001b[0m" ] } ], "source": [ "!pip install -q git+https://github.com/huggingface/transformers.git colpali-engine[train] peft bitsandbytes accelerate" ] }, { "cell_type": "markdown", "metadata": { "id": "ORqOHFwzlTDv" }, "source": [ "We will login to Hugging Face as ColPali itself is based on PaliGemma which has Gemma license. Make sure to agree to Gemma license once in a model repository that is gated by Gemma license before running this notebook." 
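If you prefer not to use the interactive login widget, you can authenticate non-interactively. A minimal stdlib-only sketch (the helper name is an assumption; the `HF_TOKEN` environment variable and the `hf_` token prefix are Hugging Face conventions) that resolves a token you would then pass to `huggingface_hub.login`:

```python
import os
from typing import Optional

def resolve_hf_token(env_var: str = "HF_TOKEN") -> Optional[str]:
    """Return a Hugging Face access token from the environment, or None.

    Reading the token from an environment variable avoids storing it in
    plain text inside the notebook file.
    """
    token = os.environ.get(env_var)
    # Hugging Face user access tokens start with the "hf_" prefix
    if token and token.startswith("hf_"):
        return token
    return None
```

Usage: `login(token=resolve_hf_token())` with `login` imported from `huggingface_hub`, falling back to the widget when the variable is unset.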
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 17, "referenced_widgets": [ "8e7b9370e14444ed8915ad413282d183", "8bd2e24a21d94b7aae1c70b649144b06", "c5fe090bb03e4e718f650a1096e9f007", "7345c3de2447467aa84028e323c569c0", "62642a4e87934fb490e62b0da64836bc", "e354bfcef92b45538b2a0a24d037c01e", "63cb2f1a8f844482bd10993f10ab08f2", "5fb50f40f4e24a3abbc7583fa86ed5ca", "32064c8580114d4a8d04eadc0a65dd95", "dea769f66fee4558892a0090d87478ae", "ca287043512648d391772a638c1858d1", "5a5c3cd017204a4b9b4e0a1fcc425ebf", "24747331cd9a4c3eaf7ab8a9ef0d9c28", "6d96e76d746d4a5e92d381e9153d237b", "11fab766798d4d6d8fcdd2480ac7eddf", "f800d99a2c3e4608ac8923800ab69264", "2d8df73745324a8bb95113c498928af7", "36d31ecc97be4e43962c79b2c66cbc84", "0b3f0bbbe80648738fa37b322bd4f545", "d0062151b3c0418894bc2c32ed055774" ] }, "id": "TzUccPzV9ZE2", "outputId": "5b83d3b7-a91e-4492-e2de-7b3b9103d36f" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8e7b9370e14444ed8915ad413282d183", "version_major": 2, "version_minor": 0 }, "text/plain": [ "VBox(children=(HTML(value='
| Step | Training Loss |
|---|---|
| 20 | 0.142100 |
| 40 | 0.114500 |
| 60 | 0.103400 |
| 80 | 0.096000 |
| 100 | 0.106300 |
| 120 | 0.120000 |
Copy a token from your Hugging Face tokens page and paste it below. Immediately click login after copying your token, or it might be stored in plain text in this notebook file.