{ "cells": [ { "cell_type": "markdown", "id": "0c4ecb49-ce58-4b65-849a-760980576e48", "metadata": {}, "source": [ "# Poro-34B LoRA fine-tuning with S-Group's data" ] }, { "cell_type": "code", "execution_count": null, "id": "5b686006-65a7-43af-8207-1c7309a5e423", "metadata": {}, "outputs": [], "source": [ "# This notebook fine-tunes the Poro-34B model with LoRA on 185 question-answer pairs" ] }, { "cell_type": "markdown", "id": "defcdb6f-3b69-4b03-b2dc-07c4b3027fd6", "metadata": {}, "source": [ "## Initialization" ] }, { "cell_type": "code", "execution_count": 2, "id": "67f730e6-3467-4a19-ab76-e8baace8e02e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: peft in /opt/conda/lib/python3.10/site-packages (0.9.0)\n", "Requirement already satisfied: numpy>=1.17 in /opt/conda/lib/python3.10/site-packages (from peft) (1.26.3)\n", "Requirement already satisfied: packaging>=20.0 in /opt/conda/lib/python3.10/site-packages (from peft) (23.2)\n", "Requirement already satisfied: psutil in /opt/conda/lib/python3.10/site-packages (from peft) (5.9.8)\n", "Requirement already satisfied: pyyaml in /opt/conda/lib/python3.10/site-packages (from peft) (6.0.1)\n", "Requirement already satisfied: torch>=1.13.0 in /opt/conda/lib/python3.10/site-packages (from peft) (2.0.0.post101)\n", "Requirement already satisfied: transformers in /opt/conda/lib/python3.10/site-packages (from peft) (4.31.0)\n", "Requirement already satisfied: tqdm in /opt/conda/lib/python3.10/site-packages (from peft) (4.66.1)\n", "Requirement already satisfied: accelerate>=0.21.0 in /opt/conda/lib/python3.10/site-packages (from peft) (0.21.0)\n", "Requirement already satisfied: safetensors in /opt/conda/lib/python3.10/site-packages (from peft) (0.3.3)\n", "Requirement already satisfied: huggingface-hub>=0.17.0 in /opt/conda/lib/python3.10/site-packages (from peft) (0.20.2)\n", "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages 
(from huggingface-hub>=0.17.0->peft) (3.13.1)\n", "Requirement already satisfied: fsspec>=2023.5.0 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub>=0.17.0->peft) (2023.6.0)\n", "Requirement already satisfied: requests in /opt/conda/lib/python3.10/site-packages (from huggingface-hub>=0.17.0->peft) (2.31.0)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub>=0.17.0->peft) (4.5.0)\n", "Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch>=1.13.0->peft) (1.12)\n", "Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch>=1.13.0->peft) (3.2.1)\n", "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch>=1.13.0->peft) (3.1.3)\n", "Requirement already satisfied: regex!=2019.12.17 in /opt/conda/lib/python3.10/site-packages (from transformers->peft) (2023.12.25)\n", "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /opt/conda/lib/python3.10/site-packages (from transformers->peft) (0.13.3)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch>=1.13.0->peft) (2.1.4)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.10/site-packages (from requests->huggingface-hub>=0.17.0->peft) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests->huggingface-hub>=0.17.0->peft) (3.6)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests->huggingface-hub>=0.17.0->peft) (1.26.18)\n", "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.10/site-packages (from requests->huggingface-hub>=0.17.0->peft) (2023.11.17)\n", "Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.10/site-packages (from 
sympy->torch>=1.13.0->peft) (1.3.0)\n" ] } ], "source": [ "# Install peft; all other required Python libraries are already included in the AWS image\n", "!pip install peft" ] }, { "cell_type": "code", "execution_count": 3, "id": "80b24df2-140b-4792-aaf1-6f6aff92ece8", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2024-02-29 16:17:22.775245: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", "To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n" ] } ], "source": [ "import torch\n", "import json\n", "from transformers import AutoModelForCausalLM, AutoTokenizer \n", "from transformers import TrainingArguments, Trainer\n", "from transformers import pipeline\n", "from peft import get_peft_model, PromptTuningConfig, TaskType, PromptTuningInit\n", "from datasets import load_dataset" ] }, { "cell_type": "code", "execution_count": 4, "id": "d31adfc6-a460-419e-871b-d0437501b026", "metadata": {}, "outputs": [], "source": [ "# Use the GPU if one is available, otherwise fall back to the CPU\n", "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")" ] }, { "cell_type": "code", "execution_count": 4, "id": "2c5a9b07-c92b-4d1d-b5b5-96e8c234e14f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cpu\n" ] } ], "source": [ "print(device)" ] }, { "cell_type": "markdown", "id": "6ea88a10-f5f1-4342-939b-60d2b9c5bb91", "metadata": {}, "source": [ "## Foundation model import" ] }, { "cell_type": "code", "execution_count": 5, "id": "2c0f7b3a-9d56-46ce-9dc8-5fe40b2628a6", "metadata": {}, "outputs": [], "source": [ "# Foundation model\n", "model_name='LumiOpen/Poro-34B'" ] }, { "cell_type": "code", "execution_count": 6, "id": "4e4c9089-a195-4fd7-91b2-6240cafb4989", "metadata": {}, "outputs": [], "source": [ "tokenizer = 
AutoTokenizer.from_pretrained(model_name)" ] }, { "cell_type": "code", "execution_count": 7, "id": "a42e0fb6-40d4-483b-a034-84ff351c021d", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "34a67d537808415ab77b583333186ae3", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/14 [00:00" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Tracking run with wandb version 0.16.2" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Run data is saved locally in /home/sagemaker-user/wandb/run-20240229_162718-2nyysvnx" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Syncing run trim-dust-9 to Weights & Biases (docs)
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View project at https://wandb.ai/timo-au-laine/huggingface" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run at https://wandb.ai/timo-au-laine/huggingface/runs/2nyysvnx" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "You're using a BloomTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n" ] }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [19/20 5:10:38 < 18:16, 0.00 it/s, Epoch 1.53/2]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining Loss
12.511700
22.543000
32.566400
42.472700
52.500000
62.500000
72.625000
82.277300
92.359400
102.175800
112.293000
122.132800
131.974600
142.076200
151.869100
161.640600
171.769500

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "trainer.train()" ] }, { "cell_type": "markdown", "id": "1ed2cf09-3683-4016-88d9-9ada1ddb4345", "metadata": {}, "source": [ "## Saving the finetuned model" ] }, { "cell_type": "code", "execution_count": null, "id": "c37902bf-47e5-4f89-9128-a6b7d91cb437", "metadata": {}, "outputs": [], "source": [ "model_id185 = \"Poro-34B-Lora-185\"" ] }, { "cell_type": "code", "execution_count": null, "id": "163b54c4-3027-4e0d-9d52-7e3d698020da", "metadata": {}, "outputs": [], "source": [ "peft_model.save_pretrained(model_id185)" ] }, { "cell_type": "code", "execution_count": null, "id": "ec432db5-4f0c-43c7-b4e4-ef087f057bd0", "metadata": {}, "outputs": [], "source": [ "!ls -lh {model_id185} # LoRA adapter files and their sizes" ] }, { "cell_type": "markdown", "id": "11460d4e-3e11-4fdb-b134-61b45bb84018", "metadata": {}, "source": [ "## Testing" ] }, { "cell_type": "code", "execution_count": 8, "id": "eb6a1213-a7ab-4bb5-8ffc-0e2666286dc6", "metadata": {}, "outputs": [], "source": [
 "def generate_output(model, inputs, max_new_tokens=100, temperature=0.1, **generate_kwargs):\n",
 "    \"\"\"Generate a continuation of a tokenized prompt and return the token ids.\n",
 "\n",
 "    Args:\n",
 "        model: causal LM (base model or PEFT-wrapped) used for generation.\n",
 "        inputs: tokenizer output with \"input_ids\" and, optionally, \"attention_mask\".\n",
 "        max_new_tokens: maximum number of tokens generated after the prompt.\n",
 "        temperature: sampling temperature. transformers only applies it when sampling\n",
 "            is enabled (e.g. pass do_sample=True); under greedy decoding, the generate()\n",
 "            default unless the model's generation config enables sampling, it is ignored.\n",
 "        **generate_kwargs: extra options forwarded to model.generate(), e.g. do_sample,\n",
 "            repetition_penalty or no_repeat_ngram_size to curb repetitive output.\n",
 "\n",
 "    Returns:\n",
 "        Tensor of token ids containing the prompt followed by the generated tokens.\n",
 "    \"\"\"\n",
 "    # Move the prompt tensors to the model's device (a no-op when running on CPU).\n",
 "    input_ids = inputs[\"input_ids\"].to(model.device)\n",
 "    # Pass the tokenizer's attention mask explicitly instead of letting generate() infer it.\n",
 "    attention_mask = inputs.get(\"attention_mask\")\n",
 "    if attention_mask is not None:\n",
 "        attention_mask = attention_mask.to(model.device)\n",
 "    outputs = model.generate(\n",
 "        input_ids=input_ids,\n",
 "        attention_mask=attention_mask,\n",
 "        max_new_tokens=max_new_tokens,\n",
 "        temperature=temperature,\n",
 "        **generate_kwargs,\n",
 "    )\n",
 "    return outputs"
 ] }, { "cell_type": "markdown", "id": "0a844312-2a1e-4c76-9078-96506b252522", "metadata": {}, "source": [ "### Original model" ] }, { "cell_type": "code", "execution_count": 9, "id": "d38bbed0-e938-43ef-b816-b5e0f9d066fd", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['Given the question delimited by triple backticks ```{ Mikä on osuusmaksu ja miksi se pitää maksaa? }```, what is the answer? Answer: Osuusmaksu on osuuskaupan osuusmaksu, joka maksetaan liittymisen yhteydessä. Osuusmaksu on sijoitus osuuskauppaan ja se palautetaan, kun jäsen eroaa osuuskaupasta. Osuusmaksun suuruus vaihtelee osuuskaupoittain. Osuusmaksun suuruus on 100 euroa. 
Osuusmaksu on sijoitus osuuskauppaan ja se palautetaan, kun jäsen eroaa osuuskaupasta. Osuusmaksun suuruus vaihtelee osuuskaupoittain. Osuusmaksun suuruus on 100 euroa. Osuusmaksu on sijoitus osuuskauppaan ja se palautetaan, kun jäsen eroaa osuuskaupasta. Osuusmaksun suuruus vaihtelee osuuskaupoittain.']\n" ] } ], "source": [ "prompt = tokenizer('Given the question delimited by triple backticks ```{ Mikä on osuusmaksu ja miksi se pitää maksaa? }```, what is the answer? Answer:', return_tensors=\"pt\")\n", "result = generate_output(model,prompt)\n", "print(tokenizer.batch_decode(result, skip_special_tokens=True))" ] }, { "cell_type": "code", "execution_count": 10, "id": "d6091480-e399-4890-bd32-7a51d1cbb50f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['Given the question delimited by triple backticks ```{ Mistä näen S-Tilini tapahtumat ja saldon? }```, what is the answer? Answer: S-Pankin verkkopankissa. The answer is the first sentence in the paragraph. The answer is the first sentence in the paragraph. The answer is the first sentence in the paragraph. The answer is the first sentence in the paragraph. The answer is the first sentence in the paragraph. The answer is the first sentence in the paragraph. The answer is the first sentence in the paragraph. The answer is the first sentence in the paragraph. The answer is the first sentence in the paragraph. The answer is the']\n" ] } ], "source": [ "prompt = tokenizer('Given the question delimited by triple backticks ```{ Mistä näen S-Tilini tapahtumat ja saldon? }```, what is the answer? 
Answer:', return_tensors=\"pt\")\n", "result = generate_output(model,prompt)\n", "print(tokenizer.batch_decode(result, skip_special_tokens=True))" ] }, { "cell_type": "markdown", "id": "ae3c3d6a-2b07-4e46-9ddc-dccadfd07196", "metadata": {}, "source": [ "### Finetuned model" ] }, { "cell_type": "code", "execution_count": 27, "id": "4cf53f39-ad3f-43e2-8daa-79853b054cd2", "metadata": {}, "outputs": [], "source": [ "from peft import PeftModel" ] }, { "cell_type": "code", "execution_count": 28, "id": "142bf57d-cffc-47b2-ae91-8a5420c46d32", "metadata": {}, "outputs": [], "source": [ "loaded_model = PeftModel.from_pretrained(model,model_id185,is_trainable=False)" ] }, { "cell_type": "code", "execution_count": 29, "id": "c3cacd26-edff-494c-9428-55b7659988de", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['Given the question delimited by triple backticks ```{ Mikä on osuusmaksu ja miksi se pitää maksaa? }```, what is the answer? Answer: {Osuusmaksun suuruus on 100 euroa. Se maksetaan vain kerran, ja se palautetaan osuuskaupan asiakasomistajuudesta luovuttaessa. Osuusmaksun voi maksaa kerralla kokonaan tai kerryttää sitä Bonuksilla. Osuusmaksun voi maksaa myös S-Etukortilla, jolloin se veloitetaan S-Etukorttiin liitetyltä maksuvälineeltä.}\\n\\n{Osuusmaksun voi maksaa myös S-ryhmän lahjakortilla. Lahjakortin saldo ei kuitenkaan kerrytä Bonusta.}\\n\\n{Osuusmaksun voi maksaa myös S-ryhmän lahjak']\n" ] } ], "source": [ "prompt = tokenizer('Given the question delimited by triple backticks ```{ Mikä on osuusmaksu ja miksi se pitää maksaa? }```, what is the answer? 
Answer:', return_tensors=\"pt\")\n", "result = generate_output(loaded_model,prompt)\n", "print(tokenizer.batch_decode(result, skip_special_tokens=True))" ] }, { "cell_type": "code", "execution_count": 30, "id": "57029580-6429-4c43-8482-1a052839bc05", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['Given the question delimited by triple backticks ```{ Mistä näen S-Tilini tapahtumat ja saldon? }```, what is the answer? Answer: {S-Tilin tapahtumat ja saldo ovat nähtävissä S-mobiilissa ja S-Pankin verkkopankissa.}\\n\\n{S-mobiilissa S-Tilin tapahtumat ja saldo ovat nähtävissä S-Etukortti-osiossa.}\\n\\n{S-Pankin verkkopankissa S-Tilin tapahtumat ja saldo ovat nähtävissä Tilit-osiossa.}\\n\\n{S-Tilin tapahtumat ja saldo ovat nähtävissä myös S-mobiilin ja S-Pankin verkkopankin S-Etukortti- ja Tilit-osiossa, kun olet']\n" ] } ], "source": [ "prompt = tokenizer('Given the question delimited by triple backticks ```{ Mistä näen S-Tilini tapahtumat ja saldon? }```, what is the answer? Answer:', return_tensors=\"pt\")\n", "result = generate_output(loaded_model,prompt)\n", "print(tokenizer.batch_decode(result, skip_special_tokens=True))" ] }, { "cell_type": "code", "execution_count": null, "id": "166c476c-01a2-49cc-b03f-6cb1d9ae6136", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "897448ea-f680-4a5b-a148-f65df6704bac", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.13" } }, "nbformat": 4, "nbformat_minor": 5 }