LLM Course documentation
Implementarea GRPO în TRL
Implementarea GRPO în TRL
În această pagină, vom învăța cum să implementăm Optimizarea Relativă a Politicii de Grup (GRPO) folosind biblioteca Transformer Reinforcement Learning (TRL). Ne vom concentra pe implementarea practică cu cod minimal.
Vom explora conceptele centrale ale GRPO așa cum sunt întruchipate în GRPOTrainer din TRL, folosind fragmente din documentația oficială TRL pentru a ne ghida.
Acest capitol este destinat începătorilor TRL. Dacă ești deja familiar cu TRL, ai putea de asemenea să consulți implementarea Open R1 a GRPO.
În primul rând, să ne reamintim unele dintre conceptele importante ale algoritmului GRPO:
- Formarea Grupului: Modelul generează multiple completări pentru fiecare prompt.
- Învățarea Preferințelor: Modelul învață dintr-o funcție de recompensă care compară grupuri de completări.
- Configurația Antrenamentului: Modelul folosește o configurație pentru a controla procesul de antrenare.
Ce trebuie să facem pentru a implementa GRPO?
- Să definim un set de date de prompt-uri.
- Să definim o funcție de recompensă care ia o listă de completări și returnează o listă de recompense.
- Să configurăm procesul de antrenare cu un GRPOConfig.
- Să antrenăm modelul folosind GRPOTrainer.
Iată un exemplu minimal pentru a începe antrenamentul GRPO:
from trl import GRPOTrainer, GRPOConfig
from datasets import load_dataset
# 1. Încarcă setul tău de date
dataset = load_dataset("setul_tau_de_date", split="train")
# 2. Definește o funcție de recompensă simplă
def reward_func(completions, **kwargs):
"""Exemplu: Recompensează completările mai lungi"""
return [float(len(completion)) for completion in completions]
# 3. Configurează antrenamentul
training_args = GRPOConfig(
output_dir="output",
num_train_epochs=3,
per_device_train_batch_size=4,
gradient_accumulation_steps=2,
logging_steps=10,
)
# 4. Inițializează și antrenează
trainer = GRPOTrainer(
model="modelul_tau", # de exemplu "Qwen/Qwen2-0.5B-Instruct"
args=training_args,
train_dataset=dataset,
reward_funcs=reward_func,
)
trainer.train()
Componentele Cheie
1. Formatul Setului de Date
Setul tău de date ar trebui să conțină prompt-uri la care modelul va răspunde. Antrenorul GRPO va genera multiple completări pentru fiecare prompt și va folosi funcția de recompensă pentru a le compara.
2. Funcția de Recompensă
Funcția de recompensă este crucială - determină cum învață modelul. Iată două exemple practice:
# Exemplul 1: Recompensă bazată pe lungimea completării
def reward_length(completions, **kwargs):
return [float(len(completion)) for completion in completions]
# Exemplul 2: Recompensă bazată pe potrivirea unui model
import re
def reward_format(completions, **kwargs):
pattern = r"^<think>.*?</think><answer>.*?</answer>$"
return [1.0 if re.match(pattern, c) else 0.0 for c in completions]
3. Configurația Antrenamentului
Parametrii cheie de considerat în GRPOConfig
:
training_args = GRPOConfig(
# Parametrii esențiali
output_dir="output",
num_train_epochs=3,
num_generation=4, # Numărul de completări de generat pentru fiecare prompt
per_device_train_batch_size=4, # Vrem să obținem toate generările într-un lot de dispozitiv
# Opțional dar util
gradient_accumulation_steps=2,
learning_rate=1e-5,
logging_steps=10,
# Specific GRPO (opțional)
use_vllm=True, # Accelerează generarea
)
Parametrul num_generation
este deosebit de important pentru GRPO deoarece definește dimensiunea grupului - câte completări diferite va genera modelul pentru fiecare prompt. Acesta este un diferențiator cheie de alte metode RL:
- Prea mic (de exemplu, 2-3): S-ar putea să nu ofere suficientă diversitate pentru comparații semnificative
- Recomandat (4-16): Oferă un echilibru bun între diversitate și eficiența computațională
- Valori mai mari: Pot îmbunătăți învățarea dar cresc semnificativ costul computațional
Dimensiunea grupului ar trebui aleasă în funcție de resursele tale computaționale și complexitatea sarcinii tale. Pentru sarcini simple, grupuri mai mici (4-8) pot fi suficiente, în timp ce sarcinile de raționament mai complexe ar putea beneficia de grupuri mai mari (8-16).
Sfaturi pentru Succes
- Gestionarea Memoriei: Ajustează
per_device_train_batch_size
șigradient_accumulation_steps
în funcție de memoria GPU-ului tău. - Viteza: Activează
use_vllm=True
pentru generare mai rapidă dacă modelul tău este suportat. - Monitorizarea: Urmărește metricile înregistrate în timpul antrenamentului:
reward
: Recompensa medie pe completărireward_std
: Deviația standard în cadrul grupurilor de recompensekl
: Divergența KL de la modelul de referință
Designul Funcției de Recompensă
Lucrarea DeepSeek R1 demonstrează mai multe abordări eficiente pentru designul funcției de recompensă pe care le poți adapta pentru propria ta implementare GRPO:
1. Recompense Bazate pe Lungime
Una dintre cele mai ușoare recompense de implementat este o recompensă bazată pe lungime. Poți recompensa completări mai lungi:
def reward_len(completions, **kwargs):
ideal_length = 20
return [-abs(ideal_length - len(completion)) for completion in completions]
Această funcție de recompensă penalizează completările care sunt prea scurte sau prea lungi, încurajând modelul să genereze completări care sunt aproape de lungimea ideală de 20 de token-uri.
2. Recompense Bazate pe Reguli pentru Sarcini Verificabile
Pentru sarcini cu răspunsuri obiectiv corecte (cum ar fi matematica sau codarea), poți implementa funcții de recompensă bazate pe reguli:
def problem_reward(completions, answers, **kwargs):
"""Funcție de recompensă pentru probleme de matematică cu răspunsuri verificabile
completions: lista de completări de evaluat
answers: lista de răspunsuri la problemele din setul de date
"""
rewards = []
for completion, correct_answer in zip(completions, answers):
# Extrage răspunsul din completare
try:
# Acesta este un exemplu simplificat - ai avea nevoie de parsing corespunzător
answer = extract_final_answer(completion)
# Recompensă binară: 1 pentru corect, 0 pentru incorect
reward = 1.0 if answer == correct_answer else 0.0
rewards.append(reward)
except:
# Dacă nu putem parsa un răspuns, dăm o recompensă mică
rewards.append(0.0)
return rewards
3. Recompense Bazate pe Format
Poți de asemenea să recompensezi formatarea corespunzătoare, care a fost importantă în antrenamentul DeepSeek R1:
def format_reward(completions, **kwargs):
"""Recompensează completările care urmează formatul dorit"""
# Exemplu: Verifică dacă completarea urmează un format gândește-apoi-răspunde
pattern = r"<think>(.*?)</think>\s*<answer>(.*?)</answer>"
rewards = []
for completion in completions:
match = re.search(pattern, completion, re.DOTALL)
if match:
# Verifică dacă există conținut substanțial în ambele secțiuni
think_content = match.group(1).strip()
answer_content = match.group(2).strip()
if len(think_content) > 20 and len(answer_content) > 0:
rewards.append(1.0)
else:
rewards.append(
0.5
) # Recompensă parțială pentru format corect dar conținut limitat
else:
rewards.append(0.0) # Nicio recompensă pentru format incorect
return rewards
Aceste exemple demonstrează cum poți implementa funcții de recompensă inspirate din procesul de antrenare DeepSeek R1, concentrându-se pe corectitudine, formatare și semnale combinate.
Asta e tot!
În următoarea secțiune, vei urma un exercițiu pentru a implementa GRPO în TRL.
< > Update on GitHub