#!/usr/bin/env python3
"""
Fine-tune “SmolLM2-360M-Instruct” on the TaiwanChat dataset using Unsloth’s 4-bit quantization
+ LoRA adapters, with evaluation on a 5% hold-out every 500 steps, early stopping,
explicit LR and optimizer, and push the merged model to Hugging Face.
Adjustments:
- LoRA rank remains r=16 (sufficient capacity for instruction data)
- No LoRA dropout (maximize capacity to avoid underfitting)
- Weight decay of 0.01 for slight regularization
- 5% validation split for robust hold-out
- Explicit learning_rate=2e-4 and warmup_steps=500
- logging_steps=50 for clearer loss trends
- optim="adamw_torch" for full-precision AdamW
- gradient_accumulation_steps=2 for more frequent updates
- num_train_epochs=5 to ensure sufficient training steps
- gradient_checkpointing disabled for stable gradient computation
- EarlyStoppingCallback to halt if no improvement over 4 evals
"""
from unsloth import FastLanguageModel
from trl import SFTTrainer, SFTConfig
from transformers import DataCollatorForLanguageModeling, EarlyStoppingCallback
from unsloth.chat_templates import train_on_responses_only
from transformers.integrations import WandbCallback
from datasets import load_dataset, Dataset
import os
import torch
import random
import logging
import re
logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)
class LoggingSFTTrainer(SFTTrainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # 0) During evaluation, flag batches whose labels are entirely masked (-100)
        labels = inputs.get("labels", None)
        if labels is not None:
            num_valid = (labels != -100).sum().item()
            if not model.training and num_valid == 0:
                input_ids = inputs.get("input_ids", None)
                if input_ids is not None:
                    texts = self.tokenizer.batch_decode(
                        input_ids, skip_special_tokens=False
                    )
                    for idx, txt in enumerate(texts):
                        logger.warning(
                            f"→ [Step {self.state.global_step}] Example {idx} has no valid labels:\n{txt!r}"
                        )
                else:
                    logger.warning(
                        f"→ [Step {self.state.global_step}] Zero-label batch but no input_ids to decode!"
                    )
        # 1) Always request both loss and outputs so the loss can be inspected
        loss_and_outputs = super().compute_loss(
            model, inputs, return_outputs=True, **kwargs
        )
        # Unpack depending on whether outputs were returned
        if isinstance(loss_and_outputs, tuple):
            loss, outputs = loss_and_outputs
        else:
            loss, outputs = loss_and_outputs, None
        # 2) During evaluation, catch infinite or NaN losses
        if not model.training:
            if torch.isnan(loss) or torch.isinf(loss):
                input_ids = inputs.get("input_ids", None)
                if input_ids is not None:
                    texts = self.tokenizer.batch_decode(
                        input_ids, skip_special_tokens=False
                    )
                    for idx, txt in enumerate(texts):
                        logger.warning(
                            f"→ [Step {self.state.global_step}] Example {idx} resulted in invalid loss ({loss.item()}):\n{txt!r}"
                        )
                else:
                    logger.warning(
                        f"→ [Step {self.state.global_step}] Invalid loss ({loss.item()}) but no input_ids to decode!"
                    )
        # 3) Return in the format the caller expects
        if return_outputs:
            return loss, outputs
        return loss
# Project and dataset settings
PROJECT_NAME = 'SmolLM2-360M-Instruct-TaiwanChat'
BASE_MODEL_ID = "unsloth/SmolLM2-360M-Instruct"
DATASET_ID = "yentinglin/TaiwanChat"
N_SAMPLES = 600000
MAX_LEN = 512
# CUDA and W&B setup
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:128"
os.environ["WANDB_PROJECT"] = f"{PROJECT_NAME}_CLOUD"
os.environ["WANDB_LOG_MODEL"] = "end"
# 1) Load 4-bit quantized model without full fine-tuning
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=BASE_MODEL_ID,
max_seq_length=MAX_LEN,
load_in_4bit=True,
full_finetuning=False,
)
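# With load_in_4bit=True and full_finetuning=False the base weights stay frozen in
# quantized 4-bit form; only the LoRA adapters attached below receive gradient updates.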
# 2) Attach LoRA adapters
model = FastLanguageModel.get_peft_model(
model,
r=16, # sufficient capacity for instruction tasks
target_modules=[
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",
],
lora_alpha=16,
lora_dropout=0.0, # no dropout to maximize capacity
bias="none",
use_gradient_checkpointing="unsloth",
random_state=3407,
max_seq_length=MAX_LEN,
use_rslora=False,
loftq_config=None,
)
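# Optional sanity check (not in the original script): report how many parameters the
# LoRA adapters make trainable. print_trainable_parameters() is the standard PEFT
# helper and is assumed to be available on the wrapped model.
if hasattr(model, "print_trainable_parameters"):
    model.print_trainable_parameters()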
# Prepare dataset; 10% of it will be held out for validation
def load_fitting_samples(dataset_id, tokenizer, max_len, n_samples, seed=3407):
    # 1) Open the HF dataset in streaming mode
    stream = load_dataset(dataset_id, split="train", streaming=True)
    selected = []
    for example in stream:
        # 2) Render the chat-template text
        text = tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
            add_generation_prompt=False,
        )
        # 3) Quick length check on token IDs
        tokens = tokenizer(text, add_special_tokens=False)["input_ids"]
        if len(tokens) <= max_len:
            selected.append({"text": text})
        # 4) Stop as soon as we have enough
        if len(selected) >= n_samples:
            break
    # 5) Shuffle and build a regular Dataset
    random.Random(seed).shuffle(selected)
    return Dataset.from_list(selected)
# Build the training pool by streaming TaiwanChat and keeping only samples that fit MAX_LEN
dataset = load_fitting_samples(
DATASET_ID,
tokenizer=tokenizer,
max_len=MAX_LEN,
n_samples=N_SAMPLES,
seed=3407,
)
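# Each "text" entry is the full conversation rendered with the model's chat template.
# For SmolLM2-Instruct this is ChatML-style markup (the same markers used by the
# cleanup and response-only masking below), roughly:
#   <|im_start|>user\n<question><|im_end|>\n<|im_start|>assistant\n<answer><|im_end|>\n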
def clean_assistant_marker(example):
    # Collapse any run of newlines right after "<|im_start|>assistant" into a single "\n"
    example["text"] = re.sub(
        r"(<\|im_start\|>assistant)\n+",
        r"\1\n",
        example["text"]
    )
    return example
# clean: <|im_start|>assistant\n\n -> <|im_start|>assistant\n
dataset = dataset.map(clean_assistant_marker, batched=False)
new_dataset = dataset.train_test_split(test_size=0.1)
# Configure training arguments
training_args = SFTConfig(
fp16_full_eval=False,
per_device_train_batch_size=40,
gradient_accumulation_steps=1,
per_device_eval_batch_size=1,
eval_accumulation_steps=1,
evaluation_strategy="steps",
eval_steps=100,
save_strategy="steps",
save_steps=1000,
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
greater_is_better=False,
dataset_text_field="text",
output_dir=PROJECT_NAME,
max_seq_length=MAX_LEN,
num_train_epochs=3,
learning_rate=2e-4,
weight_decay=0.01,
warmup_steps=500,
logging_steps=50,
logging_dir=f"{PROJECT_NAME}/logs",
report_to=["wandb"],
run_name=f"{PROJECT_NAME}_CLOUD",
optim="adamw_8bit",
push_to_hub=False,
gradient_checkpointing=False,
seed=3407,
)
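# Rough budget (assuming the stream yields the full N_SAMPLES): the effective batch is
# 40 * 1 = 40 samples per optimizer step, so the 90% train split (~540k examples) gives
# about 13,500 steps per epoch and ~40,500 steps over 3 epochs; warmup_steps=500 covers
# only the first ~1% of training, and evaluation runs every 100 steps.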
# Initialize Trainer with early stopping
torch.cuda.empty_cache()
trainer = LoggingSFTTrainer(
model=model,
args=training_args,
data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
tokenizer=tokenizer,
callbacks=[WandbCallback, EarlyStoppingCallback(early_stopping_patience=4)],
train_dataset=new_dataset["train"],
eval_dataset=new_dataset["test"],
)
# Mask user prompts and train
trainer = train_on_responses_only(
trainer,
instruction_part="<|im_start|>user\n",
response_part="<|im_start|>assistant\n",
)
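# train_on_responses_only masks every token outside the assistant turns by setting its
# label to -100, so the loss (and the "no valid labels" check in LoggingSFTTrainer)
# only ever considers assistant-response tokens.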
trainer.train()
# Merge LoRA weights and push merged model to Hugging Face
model.push_to_hub_merged(
f'Luigi/{PROJECT_NAME}',
tokenizer,
save_method="merged_16bit",
safe_serialization=None
)
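# The merged 16-bit checkpoint is a plain Hugging Face model; it can later be reloaded
# without Unsloth, e.g. (a sketch, assuming the repo id used in the push above):
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   merged = AutoModelForCausalLM.from_pretrained(f"Luigi/{PROJECT_NAME}")
#   merged_tokenizer = AutoTokenizer.from_pretrained(f"Luigi/{PROJECT_NAME}")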
# Example inference
test_prompt = "請問台北今天的天氣如何?"
inputs = tokenizer(test_prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
**inputs,
max_new_tokens=100,
do_sample=True,
temperature=0.8,
pad_token_id=tokenizer.eos_token_id
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
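# A minimal alternative sketch (not in the original script): because training rendered
# prompts with the chat template, generation is usually more faithful when the prompt
# is formatted the same way.
chat_inputs = tokenizer.apply_chat_template(
    [{"role": "user", "content": test_prompt}],
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)
chat_outputs = model.generate(
    chat_inputs,
    max_new_tokens=100,
    do_sample=True,
    temperature=0.8,
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(chat_outputs[0], skip_special_tokens=True))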