LLM Course documentation

Exercițiu Practic: Ajustează fin un model cu GRPO

Hugging Face's logo
Join the Hugging Face community

and get access to the augmented documentation experience

to get started

Ask a Question Open In Colab

Exercițiu Practic: Ajustează fin un model cu GRPO

Acum că ai văzut teoria, să o punem în practică! În acest exercițiu, vei ajusta fin un model cu GRPO.

Acest exercițiu a fost scris de expertul în ajustarea fină LLM @mlabonne.

Instalează dependențele

În primul rând, să instalăm dependențele pentru acest exercițiu.

!pip install -qqq datasets==3.2.0 transformers==4.47.1 trl==0.14.0 peft==0.14.0 accelerate==1.2.1 bitsandbytes==0.45.2 wandb==0.19.7 --progress-bar off
!pip install -qqq flash-attn --no-build-isolation --progress-bar off

Acum vom importa bibliotecile necesare.

import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOConfig, GRPOTrainer

Importă și loghează în Weights & Biases

Weights & Biases este un instrument pentru înregistrarea și monitorizarea experimentelor tale. Îl vom folosi pentru a înregistra procesul nostru de ajustare fină.

import wandb

wandb.login()

Poți face acest exercițiu fără să te loghezi în Weights & Biases, dar este recomandat să faci asta pentru a-ți urmări experimentele și pentru a interpreta rezultatele.

Încarcă setul de date

Acum, să încărcăm setul de date. În acest caz, vom folosi setul de date mlabonne/smoltldr, care conține o listă de povești scurte.

dataset = load_dataset("mlabonne/smoltldr")
print(dataset)

Încarcă modelul

Acum, să încărcăm modelul.

Pentru acest exercițiu, vom folosi modelul SmolLM2-135M.

Acesta este un model mic de 135M parametri care rulează pe hardware limitat. Aceasta face modelul ideal pentru învățare, dar nu este cel mai puternic model de acolo. Dacă ai acces la hardware mai puternic, poți încerca să ajustezi fin un model mai mare precum SmolLM2-1.7B.

model_id = "HuggingFaceTB/SmolLM-135M-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
    device_map="auto",
    attn_implementation="flash_attention_2",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

Încarcă LoRA

Acum, să încărcăm configurația LoRA. Vom profita de LoRA pentru a reduce numărul de parametri antrenabili, și prin urmare amprenta de memorie de care avem nevoie pentru a ajusta fin modelul.

Dacă nu ești familiar cu LoRA, poți citi mai multe despre aceasta în Capitolul 11.

# Încarcă LoRA
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    target_modules="all-linear",
)
model = get_peft_model(model, lora_config)
print(model.print_trainable_parameters())
Parametri antrenabili totali: 135M

Definește funcția de recompensă

După cum s-a menționat în secțiunea anterioară, GRPO poate folosi orice funcție de recompensă pentru a îmbunătăți modelul. În acest caz, vom folosi o funcție de recompensă simplă care încurajează modelul să genereze text lung de 50 de token-uri.

# Funcția de recompensă
ideal_length = 50


def reward_len(completions, **kwargs):
    return [-abs(ideal_length - len(completion)) for completion in completions]

Definește argumentele de antrenament

Acum, să definim argumentele de antrenament. Vom folosi clasa GRPOConfig pentru a defini argumentele de antrenament într-un stil tipic transformers.

Dacă aceasta este prima dată când definești argumente de antrenament, poți verifica clasa TrainingArguments pentru mai multe informații, sau Capitolul 2 pentru o introducere detaliată.

# Argumentele de antrenament
training_args = GRPOConfig(
    output_dir="GRPO",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    max_prompt_length=512,
    max_completion_length=96,
    num_generations=8,
    optim="adamw_8bit",
    num_train_epochs=1,
    bf16=True,
    report_to=["wandb"],
    remove_unused_columns=False,
    logging_steps=1,
)

Acum, putem inițializa antrenorul cu modelul, setul de date și argumentele de antrenament și să începem antrenamentul.

# Antrenorul
trainer = GRPOTrainer(
    model=model,
    reward_funcs=[reward_len],
    args=training_args,
    train_dataset=dataset["train"],
)

# Antrenează modelul
wandb.init(project="GRPO")
trainer.train()

Antrenamentul durează aproximativ 1 oră pe un singur GPU A10G care este disponibil pe Google Colab sau prin Hugging Face Spaces.

Împinge modelul pe Hub în timpul antrenamentului

Dacă setăm argumentul push_to_hub la True și argumentul model_id la un nume de model valid, modelul va fi împins pe Hugging Face Hub în timp ce antrenăm. Aceasta este utilă dacă vrei să începi să testezi modelul imediat!

Interpretează rezultatele antrenamentului

GRPOTrainer înregistrează recompensa din funcția ta de recompensă, pierderea, și o gamă de alte metrici.

Ne vom concentra pe recompensa din funcția de recompensă și pierderea.

După cum poți vedea, recompensa din funcția de recompensă se apropie de 0 pe măsură ce modelul învață. Acesta este un semn bun că modelul învață să genereze text de lungimea corectă.

Rezultatul oferit din funcția de recompensă

Ai putea observa că pierderea începe de la zero și apoi crește în timpul antrenamentului, ceea ce poate părea contraontuitiv. Acest comportament este așteptat în GRPO și este direct legat de formularea matematică a algoritmului. Pierderea în GRPO este proporțională cu divergența KL (capsula relativă la politica originală). Pe măsură ce antrenamentul progresează, modelul învață să genereze text care se potrivește mai bine cu funcția de recompensă, făcându-l să devieze mai mult de la politica sa inițială. Această divergență crescândă se reflectă în valoarea pierderii în creștere, care de fapt indică că modelul se adaptează cu succes pentru a optimiza funcția de recompensă.

Pierderea

Salvează și publică modelul

Să partajăm modelul cu comunitatea!

merged_model = trainer.model.merge_and_unload()
merged_model.push_to_hub(
    "SmolGRPO-135M", private=False, tags=["GRPO", "Reasoning-Course"]
)

Generează text

🎉 Ai ajustat fin cu succes un model cu GRPO! Acum, să generăm puțin text cu modelul.

În primul rând, vom defini un document foarte lung!

prompt = """
# Un document lung despre pisică

Pisica (Felis catus), numită și pisica domestică sau pisica de casă, este un mamifer carnivor 
mic domesticit. Este singura specie domesticită din familia Felidae. Progresele în arheologie 
și genetică au arătat că domesticirea pisicii a avut loc în Orientul Apropiat în jurul anului 
7500 î.Hr. Este ținută în mod obișnuit ca animal de companie și pisică de fermă, dar se află 
și în libertate ca pisică sălbatică evitând contactul uman. Este prețuită de oameni pentru 
companie și capacitatea sa de a ucide dăunători. Ghearele sale retractabile sunt adaptate 
pentru uciderea speciilor de pradă mici precum șoarecii și șobolanii. Are un corp puternic 
și flexibil, reflexe rapide și dinți ascuțiți, iar vederea nocturnă și simțul mirosului sunt 
bine dezvoltate. Este o specie socială, dar un vânător solitar și un prădător crepuscular. 
Comunicarea pisicii include vocalizări—incluzând mieunat, tors, triluri, șuierat, mormăit 
și grunțit—precum și limbajul corpului. Poate auzi sunete prea slabe sau prea înalte în 
frecvență pentru urechile umane, cum ar fi cele făcute de mamiferele mici. Secretă și 
percep feromoni.
"""

messages = [ {“role”: “user”, “content”: prompt}, ]

Acum, putem genera text cu modelul.

# Generează text
from transformers import pipeline

generator = pipeline("text-generation", model="SmolGRPO-135M")

## Sau folosește modelul și tokenizer-ul pe care le-am definit mai devreme
# generator = pipeline("text-generation", model=model, tokenizer=tokenizer)

generate_kwargs = {
    "max_new_tokens": 256,
    "do_sample": True,
    "temperature": 0.5,
    "min_p": 0.1,
}

generated_text = generator(messages, generate_kwargs=generate_kwargs)

print(generated_text)

Concluzie

În acest capitol, am văzut cum să ajustăm fin un model cu GRPO. Am văzut de asemenea cum să interpretăm rezultatele antrenamentului și să generăm text cu modelul.

Update on GitHub