{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "XIyP_0r6zuVc"
},
"source": [
"# Training Large Language Models in 2bit with `aqlm`, `transformers` and `PEFT`\n",
"\n",
"\n",
" \n",
"\n",
"\n",
"Welcome to this notebook that goes through the recent `aqlm` integration that introduces minimal performance degradation 2bit quantization techniques.\n",
"\n",
"In this notebook, we will learn how to load a large model in 2bit (`Mixtral-8x7b`) and train it using Google Colab and PEFT library from Hugging Face 🤗.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "A_VgSpl4Dsr3"
},
"source": [
"**Install the `aqlm` library**\n",
"- It's the only extra dependency to run AQLM models.\n",
"- Add `[gpu]` to install the required CUDA specific dependencies.\n",
"- Install the latest `accelerate` and `transformers` releases to properly support it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FuXIFTFapAMI"
},
"outputs": [],
"source": [
"%%capture\n",
"!pip install aqlm[gpu]>=1.1.0\n",
"!pip install git+https://github.com/huggingface/peft.git@main\n",
"!pip install accelerate>=0.27.0\n",
"!pip install git+https://github.com/huggingface/transformers.git@main\n",
"!pip install datasets\n",
"!pip install bitsandbytes\n",
"# for 8-bit optimizer only"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MJ-5idQwzvg-"
},
"source": [
"First let's load the model we are going to use - `Mixtral-8x7b`! Note that the model itself is around 50GB in half precision"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"id": "E0Nl5mWL0k2T"
},
"outputs": [],
"source": [
"import torch\n",
"from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\n",
"\n",
"model_id = \"ISTA-DASLab/Meta-Llama-3-8B-Instruct-AQLM-2Bit-1x16\"\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
"model = AutoModelForCausalLM.from_pretrained(model_id, device_map=\"auto\", torch_dtype=\"bfloat16\", low_cpu_mem_usage=True)"
]
},
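{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check, let's print the quantized model's memory footprint and generate a few tokens. This snippet is illustrative: the prompt and generation settings are arbitrary placeholders."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Approximate memory taken by the model weights\n",
"print(f\"Memory footprint: {model.get_memory_footprint() / 1e9:.2f} GB\")\n",
"\n",
"# Short test generation with an arbitrary prompt\n",
"inputs = tokenizer(\"The capital of France is\", return_tensors=\"pt\").to(model.device)\n",
"with torch.no_grad():\n",
"    outputs = model.generate(**inputs, max_new_tokens=16)\n",
"print(tokenizer.decode(outputs[0], skip_special_tokens=True))"
]
},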
{
"cell_type": "markdown",
"metadata": {
"id": "Mp2gMi1ZzGET"
},
"source": [
"**Add LoRA**\n",
"\n",
"To alter model's behavior, we have to make it trainable. We can do that by addind a small set of trainable parameters on top of the untrainable quantized ones."
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Ybeyl20n3dYH",
"outputId": "0efda156-4886-4718-9877-e93a17dc02d2"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"trainable params: 41,943,040 || all params: 2,084,114,432 || trainable%: 2.0125\n"
]
}
],
"source": [
"from peft import LoraConfig, get_peft_model\n",
"\n",
"config = LoraConfig(\n",
" r=16,\n",
" lora_alpha=32,\n",
" target_modules=['q_proj','k_proj','v_proj','o_proj','gate_proj','down_proj','up_proj', ],\n",
" lora_dropout=0.05,\n",
" bias=\"none\",\n",
" task_type=\"CAUSAL_LM\"\n",
")\n",
"\n",
"model = get_peft_model(model, config)\n",
"model.print_trainable_parameters()\n",
"model.enable_input_require_grads() # it's needed for gradient checkpointing"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4xSPH1D_Wv9x"
},
"source": [
"Here we add a trainable adapter ontop of every `q_prok`, `k_proj` and `o_proj` linear layer."
]
},
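{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative check: list a few of the injected LoRA A-matrices\n",
"lora_modules = [name for name, _ in model.named_modules() if \"lora_A\" in name]\n",
"print(f\"{len(lora_modules)} LoRA modules injected, e.g.:\")\n",
"for name in lora_modules[:3]:\n",
"    print(\" \", name)"
]
},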
{
"cell_type": "markdown",
"metadata": {
"id": "FCc64bfnmd3j"
},
"source": [
"**Loading a dataset**\n",
"\n",
"Let's load a common dataset, english quotes, to fine tune our model on famous quotes."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"id": "s6f4z8EYmcJ6"
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9ef07f1bc62e4887817a81d4a3e15da1",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Resolving data files: 0%| | 0/114 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loaded dataset with 100000 examples\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "560c7be6397c4e3aac2318d97f1f8f86",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"tokenizer_config.json: 0%| | 0.00/26.0 [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b667958a3b3d4529b77baf5e5bc9c259",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"config.json: 0%| | 0.00/665 [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "10359f3b8d974be49da2d3fd87f89576",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"vocab.json: 0%| | 0.00/1.04M [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "97835946d4a44460bc1bd48276b8d3d0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"merges.txt: 0%| | 0.00/456k [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ed37faeff8914b369649cb514981991d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"tokenizer.json: 0%| | 0.00/1.36M [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be deprecated in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n",
" warnings.warn(\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7a56a1781f3347f8a056a18dc24ea7a9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Map: 0%| | 0/100000 [00:00, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processed dataset has 100000 examples\n",
"Features: {'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None), 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None)}\n"
]
}
],
"source": [
"from datasets import load_dataset, Dataset\n",
"import itertools\n",
"from transformers import AutoTokenizer\n",
"\n",
"# Load the dataset in streaming mode\n",
"ds = load_dataset(\"open-web-math/open-web-math\", split=\"train\", streaming=True)\n",
"\n",
"# Define the number of examples you want to load\n",
"num_examples = 100000 # Adjust this number as needed\n",
"\n",
"# Create a subset by taking the first num_examples\n",
"subset = list(itertools.islice(ds, num_examples))\n",
"\n",
"# Convert the subset to a Dataset object\n",
"data = Dataset.from_list(subset)\n",
"print(f\"Loaded dataset with {len(data)} examples\")\n",
"\n",
"# Initialize tokenizer (replace 'gpt2' with your specific model if different)\n",
"tokenizer = AutoTokenizer.from_pretrained('gpt2')\n",
"\n",
"max_seq_length = 2048\n",
"tokenizer.pad_token = tokenizer.eos_token\n",
"tokenizer.model_max_length = max_seq_length\n",
"\n",
"def preprocess_function(examples):\n",
" # Join the list of strings into a single string\n",
" texts = [\" \".join(text) for text in examples[\"text\"]]\n",
" return tokenizer(texts, truncation=True, max_length=max_seq_length, padding=\"max_length\")\n",
"\n",
"# Process the dataset\n",
"processed_dataset = data.map(preprocess_function, batched=True, remove_columns=data.column_names)\n",
"\n",
"print(f\"Processed dataset has {len(processed_dataset)} examples\")\n",
"print(f\"Features: {processed_dataset.features}\")\n"
]
},
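{
"cell_type": "markdown",
"metadata": {},
"source": [
"The processed dataset only contains `input_ids` and `attention_mask`. For causal language modeling the trainer also needs `labels`; a standard way to supply them (shown here as an assumption about the training setup) is `DataCollatorForLanguageModeling` with `mlm=False`, which copies `input_ids` into `labels` at batch time."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import DataCollatorForLanguageModeling\n",
"\n",
"# mlm=False -> causal LM: labels are a copy of input_ids (the shift happens inside the model)\n",
"data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)"
]
},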
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"import argparse\n",
"import torch\n",
"from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments\n",
"import transformers\n",
"from peft import LoraConfig, get_peft_model\n",
"from datasets import load_dataset\n",
"from transformers.trainer_callback import TrainerCallback\n",
"import os\n",
"import random\n",
"import subprocess\n",
"from huggingface_hub import HfApi, hf_hub_download\n",
"\n",
"\n",
"# Custom callback to push to Hub\n",
"class PushToHubCallback(TrainerCallback):\n",
" def __init__(self, trainer, push_frequency):\n",
" self.trainer = trainer\n",
" self.push_frequency = push_frequency\n",
"\n",
" def on_step_end(self, args, state, control, **kwargs):\n",
" if state.global_step % self.push_frequency == 0:\n",
" self.trainer.save_model()\n",
" self.trainer.push_to_hub(\n",
" commit_message=f\"Training in progress - Step {state.global_step}\"\n",
" )\n"
]
},
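{
"cell_type": "markdown",
"metadata": {},
"source": [
"Below is a minimal sketch of how the callback can be wired into a `Trainer`. It is an assumption about the training setup, not the recorded run: the hyperparameters, `output_dir` and `push_frequency` are placeholders to adapt to your own configuration, and pushing requires being logged in to the Hugging Face Hub."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import Trainer, TrainingArguments\n",
"\n",
"training_args = TrainingArguments(\n",
"    output_dir=\"aqlm-llama3-lora\",  # placeholder output directory\n",
"    per_device_train_batch_size=1,\n",
"    gradient_accumulation_steps=8,\n",
"    gradient_checkpointing=True,\n",
"    max_steps=10,\n",
"    learning_rate=1e-4,\n",
"    optim=\"paged_adamw_8bit\",  # the 8-bit optimizer bitsandbytes was installed for\n",
"    logging_steps=1,\n",
")\n",
"\n",
"trainer = Trainer(\n",
"    model=model,\n",
"    args=training_args,\n",
"    train_dataset=processed_dataset,\n",
"    data_collator=data_collator,\n",
")\n",
"\n",
"# The callback needs a reference to the trainer itself, so attach it after construction\n",
"trainer.add_callback(PushToHubCallback(trainer, push_frequency=50))\n",
"trainer.train()"
]
},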
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Step | Training Loss\n",
"   1 | 2.042200\n",
"   2 | 1.293400\n",
"   3 | 1.447500\n",
"   4 | 1.433600\n",
"   5 | 1.725900\n",
"   6 | 1.506400\n",
"   7 | 1.549600\n",
"   8 | 1.038300\n",
"   9 | 1.603300\n",
"  10 | 1.676400\n"