{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "initial_id",
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from typing import Dict, List\n",
    "\n",
    "# NOTE(review): GOOGLE_CLOUD_PROJECT is used as the Colab marker; confirm it is set in your runtime\n",
    "IS_COLAB = \"GOOGLE_CLOUD_PROJECT\" in os.environ\n",
    "if IS_COLAB:\n",
    "    # this needs to run before all other imports\n",
    "    os.environ[\"HF_HOME\"] = \"/content/cache/\"  # to avoid running out of disk space\n",
    "\n",
    "import mteb\n",
    "import numpy as np\n",
    "import torch\n",
    "from mteb.encoder_interface import PromptType\n",
    "from sentence_transformers import SentenceTransformer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5325acfb",
   "metadata": {},
   "source": [
    "### Notebook Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0189ff1e7814a5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each entry pins a model revision and names the teacher used for asymmetric mode (see below).\n",
    "MODELS = {\n",
    "    \"ir-prod\": {\n",
    "        \"name\": \"MongoDB/mdbr-leaf-ir\",\n",
    "        \"revision\": \"2e46f5aac796e621d51f678c306a66ede4712ecb\",\n",
    "        \"teacher\": \"Snowflake/snowflake-arctic-embed-m-v1.5\",\n",
    "    },\n",
    "    \"ir-paper\": {\n",
    "        \"name\": \"MongoDB/mdbr-leaf-ir\",\n",
    "        \"revision\": \"ea98995e96beac21b820aa8ad9afaa6fd29b243d\",\n",
    "        \"teacher\": \"Snowflake/snowflake-arctic-embed-m-v1.5\",\n",
    "    },\n",
    "    \"mt-prod\": {\n",
    "        \"name\": \"MongoDB/mdbr-leaf-mt\",\n",
    "        \"revision\": \"66c47ba6d753efc208d54412b5af6c744a39a4df\",\n",
    "        \"teacher\": \"mixedbread-ai/mxbai-embed-large-v1\",\n",
    "    },\n",
    "    \"mt-paper\": {\n",
    "        \"name\": \"MongoDB/mdbr-leaf-mt\",\n",
    "        \"revision\": \"c342f945a6855346bd5f48d5ee8b7e39120b0ce9\",\n",
    "        \"teacher\": \"mixedbread-ai/mxbai-embed-large-v1\",\n",
    "    },\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "371c6122efdf476a",
   "metadata": {},
   "source": [
    "In the cell below:\n",
    "* set the output folder,\n",
    "* select one of the models defined above, and\n",
    "* select the desired benchmark."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58d52a330febb9ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# on Colab write under /content; otherwise relative to the current working directory\n",
    "output_folder = \"/content/data/results/publish/\" if IS_COLAB else \"../../data/results/publish/\"\n",
    "\n",
    "model_selection = MODELS[\"ir-prod\"]\n",
    "benchmark_name = \"BEIR\"\n",
    "\n",
    "# model_selection = MODELS[\"mt-prod\"]\n",
    "# benchmark_name = \"MTEB(eng, v2)\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1b4367afc1278e",
   "metadata": {},
   "source": [
    "### Run Evals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c716c6344f9cd939",
   "metadata": {},
   "outputs": [],
   "source": [
    "benchmark = mteb.get_benchmark(benchmark_name)\n",
    "evaluation = mteb.MTEB(tasks=benchmark)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6f13945a94f7a85",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = SentenceTransformer(model_selection[\"name\"], revision=model_selection[\"revision\"])\n",
    "\n",
    "# alternative:\n",
    "# meta = mteb.get_model_meta(\n",
    "#     model_name=model_selection[\"name\"],\n",
    "#     revision=model_selection[\"revision\"],\n",
    "# )\n",
    "# model = meta.load_model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bd44e88fc360663",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "results = evaluation.run(\n",
    "    model=model,\n",
    "    verbosity=1,\n",
    "    output_folder=output_folder,\n",
    "    overwrite_results=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "733e52ca41cf92a7",
   "metadata": {},
   "source": [
    "### Evaluate Quora\n",
    "\n",
    "For `ir` models, re-run `QuoraRetrieval` without the query prompt. Because `overwrite_results=True`, this replaces any existing Quora result for this model in `output_folder`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61aea9a04468202f",
   "metadata": {},
   "outputs": [],
   "source": [
    "if model_selection[\"name\"].endswith(\"ir\"):\n",
    "    # quora is closer to a sentence similarity task than a retrieval one, as queries aren't proper user queries\n",
    "    # we thus embed them without the typical query prompt\n",
    "    # note: this clears the prompts of `model` in place\n",
    "    model.prompts = {}\n",
    "    tasks = mteb.get_tasks(tasks=[\"QuoraRetrieval\"])\n",
    "\n",
    "    # separate name so the benchmark-wide `evaluation` above is not shadowed\n",
    "    quora_evaluation = mteb.MTEB(tasks=tasks)\n",
    "    results = quora_evaluation.run(\n",
    "        model=model,\n",
    "        verbosity=1,\n",
    "        output_folder=output_folder,\n",
    "        overwrite_results=True,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a6c164e",
   "metadata": {},
   "source": [
    "### Asymmetric Mode\n",
    "\n",
    "Compute asymmetric mode scores: queries encoded by `leaf`, documents by the original teacher model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "487ba349",
   "metadata": {},
   "outputs": [],
   "source": [
    "class AsymmetricModel:\n",
    "    \"\"\"mteb-compatible encoder that routes queries and documents to different models.\n",
    "\n",
    "    Queries are embedded by `query_model` with its \"query\" prompt, documents by\n",
    "    `doc_model`; inputs without a prompt type fall back to `query_model`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        doc_model: SentenceTransformer,\n",
    "        query_model: SentenceTransformer,\n",
    "    ) -> None:\n",
    "        self.doc_model = doc_model\n",
    "        self.query_model = query_model\n",
    "\n",
    "    def encode(self, sentences: List[str], **kwargs) -> np.ndarray | torch.Tensor:\n",
    "        \"\"\"Encode `sentences` with the model selected by `kwargs[\"prompt_type\"]`.\n",
    "\n",
    "        All kwargs (including `prompt_type`) are forwarded to `SentenceTransformer.encode`.\n",
    "        The embeddings are always returned as a torch tensor on the CPU.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if `prompt_type` is not a query, a document, or None.\n",
    "        \"\"\"\n",
    "        if \"prompt_type\" not in kwargs:\n",
    "            kwargs[\"prompt_type\"] = None\n",
    "\n",
    "        # NOTE(review): some mteb releases name the document member `PromptType.passage`;\n",
    "        # confirm `PromptType.document` exists in the installed mteb version\n",
    "        match kwargs[\"prompt_type\"]:\n",
    "            case PromptType.query:\n",
    "                out = self.query_model.encode(sentences, prompt_name=\"query\", **kwargs)\n",
    "\n",
    "            case PromptType.document:\n",
    "                out = self.doc_model.encode(sentences, **kwargs)\n",
    "\n",
    "            case None:\n",
    "                print(\"No prompt type: using query (leaf) model for encoding\")\n",
    "                out = self.query_model.encode(sentences, **kwargs)\n",
    "            case _:\n",
    "                raise ValueError(f\"Encoding unknown type: {kwargs['prompt_type']}\")\n",
    "\n",
    "        if not isinstance(out, torch.Tensor):\n",
    "            out = torch.from_numpy(out)\n",
    "\n",
    "        out = out.to(\"cpu\")\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4162af7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# reload the student: `model` above may have had its prompts cleared for the Quora run\n",
    "leaf = SentenceTransformer(model_selection[\"name\"], revision=model_selection[\"revision\"])\n",
    "# the teacher is loaded at its default revision (not pinned)\n",
    "teacher = SentenceTransformer(model_selection[\"teacher\"])\n",
    "\n",
    "asymm_model = AsymmetricModel(\n",
    "    query_model=leaf,\n",
    "    doc_model=teacher,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "848d8a5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "# dedicated evaluation over the selected benchmark, independent of the Quora-only run above\n",
    "asymm_evaluation = mteb.MTEB(tasks=benchmark)\n",
    "results = asymm_evaluation.run(\n",
    "    model=asymm_model,\n",
    "    verbosity=1,\n",
    "    output_folder=output_folder,\n",
    "    overwrite_results=True,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}