"""
Fine-tune "SmolLM2-360M-Instruct" on the TaiwanChat dataset using Unsloth's 4-bit quantization
+ LoRA adapters, with evaluation on a 10% hold-out every 100 steps, early stopping,
an explicit LR and optimizer, and a final push of the merged model to Hugging Face.

Adjustments:
- LoRA rank remains r=16 (sufficient capacity for instruction data)
- No LoRA dropout (maximize capacity to avoid underfitting)
- Weight decay of 0.01 for slight regularization
- 10% validation split for a robust hold-out
- Explicit learning_rate=2e-4 and warmup_steps=500
- logging_steps=50 for clearer loss trends
- optim="adamw_8bit" for memory-efficient 8-bit AdamW
- gradient_accumulation_steps=1 for more frequent updates
- num_train_epochs=3 to ensure sufficient training steps
- Hugging Face gradient_checkpointing disabled; Unsloth's own checkpointing is used instead
- EarlyStoppingCallback to halt training if eval loss does not improve over 4 evaluations
"""

# Import Unsloth first so its performance patches are applied before transformers/trl load.
from unsloth import FastLanguageModel
from unsloth.chat_templates import train_on_responses_only
from trl import SFTTrainer, SFTConfig
from transformers import DataCollatorForLanguageModeling, EarlyStoppingCallback
from datasets import load_dataset, Dataset
import os
import torch
import random
import logging
import re

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

class LoggingSFTTrainer(SFTTrainer):
    """SFTTrainer that logs diagnostics for problematic evaluation batches.

    Flags batches in which every label is masked out (-100) and batches that
    produce a NaN/Inf loss, decoding the offending inputs so they can be inspected.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # During evaluation, warn about batches with no valid (unmasked) labels,
        # since they contribute no training signal and can yield undefined losses.
        labels = inputs.get("labels", None)
        if labels is not None:
            num_valid = (labels != -100).sum().item()
            if not model.training and num_valid == 0:
                input_ids = inputs.get("input_ids", None)
                if input_ids is not None:
                    texts = self.tokenizer.batch_decode(input_ids, skip_special_tokens=False)
                    for idx, txt in enumerate(texts):
                        logger.warning(
                            f"→ [Step {self.state.global_step}] Example {idx} has no valid labels:\n{txt!r}"
                        )
                else:
                    logger.warning(
                        f"→ [Step {self.state.global_step}] Zero-label batch but no input_ids to decode!"
                    )

        loss_and_outputs = super().compute_loss(model, inputs, return_outputs=True, **kwargs)
        if isinstance(loss_and_outputs, tuple):
            loss, outputs = loss_and_outputs
        else:
            loss, outputs = loss_and_outputs, None

        # During evaluation, also warn about batches that produce a non-finite loss.
        if not model.training and (torch.isnan(loss) or torch.isinf(loss)):
            input_ids = inputs.get("input_ids", None)
            if input_ids is not None:
                texts = self.tokenizer.batch_decode(input_ids, skip_special_tokens=False)
                for idx, txt in enumerate(texts):
                    logger.warning(
                        f"→ [Step {self.state.global_step}] Example {idx} resulted in invalid loss ({loss.item()}):\n{txt!r}"
                    )
            else:
                logger.warning(
                    f"→ [Step {self.state.global_step}] Invalid loss ({loss.item()}) but no input_ids to decode!"
                )

        if return_outputs:
            return loss, outputs
        return loss

# Run configuration
PROJECT_NAME = 'SmolLM2-360M-Instruct-TaiwanChat'
BASE_MODEL_ID = "unsloth/SmolLM2-360M-Instruct"
DATASET_ID = "yentinglin/TaiwanChat"
N_SAMPLES = 600000
MAX_LEN = 512

# Reduce CUDA memory fragmentation and configure Weights & Biases logging.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:128"
os.environ["WANDB_PROJECT"] = f"{PROJECT_NAME}_CLOUD"
os.environ["WANDB_LOG_MODEL"] = "end"

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=BASE_MODEL_ID,
    max_seq_length=MAX_LEN,
    load_in_4bit=True,
    full_finetuning=False,
)

model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha=16,
    lora_dropout=0.0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=3407,
    max_seq_length=MAX_LEN,
    use_rslora=False,
    loftq_config=None,
)

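# Optional sanity check (an assumption: the PEFT-wrapped model exposes the standard
# peft helper): report how many parameters the LoRA adapters leave trainable.
model.print_trainable_parameters()
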
def load_fitting_samples(dataset_id, tokenizer, max_len, n_samples, seed=3407):
    """Stream the dataset and keep the first n_samples conversations whose
    chat-templated text fits within max_len tokens, then shuffle deterministically."""
    stream = load_dataset(dataset_id, split="train", streaming=True)

    selected = []
    for example in stream:
        # Render the conversation with the model's chat template (no generation prompt).
        text = tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
            add_generation_prompt=False,
        )
        # Keep only conversations that fit entirely within the training context window.
        tokens = tokenizer(text, add_special_tokens=False)["input_ids"]
        if len(tokens) <= max_len:
            selected.append({"text": text})
        if len(selected) >= n_samples:
            break

    random.Random(seed).shuffle(selected)
    return Dataset.from_list(selected)

dataset = load_fitting_samples(
    DATASET_ID,
    tokenizer=tokenizer,
    max_len=MAX_LEN,
    n_samples=N_SAMPLES,
    seed=3407,
)

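# Quick sanity check (a sketch; assumes at least one conversation was collected):
# report how many samples fit within the context window and preview one of them.
print(f"Loaded {len(dataset)} conversations of at most {MAX_LEN} tokens")
print(dataset[0]["text"][:300])
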
def clean_assistant_marker(example):
    """Collapse repeated newlines after the assistant marker so every response
    starts with exactly '<|im_start|>assistant\\n'."""
    example["text"] = re.sub(
        r"(<\|im_start\|>assistant)\n+",
        r"\1\n",
        example["text"],
    )
    return example

dataset = dataset.map(clean_assistant_marker, batched=False)

# Hold out 10% of the samples for evaluation (seeded for reproducibility).
new_dataset = dataset.train_test_split(test_size=0.1, seed=3407)

training_args = SFTConfig(
    fp16_full_eval=False,
    per_device_train_batch_size=40,
    gradient_accumulation_steps=1,
    per_device_eval_batch_size=1,
    eval_accumulation_steps=1,
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=1000,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    dataset_text_field="text",
    output_dir=PROJECT_NAME,
    max_seq_length=MAX_LEN,
    num_train_epochs=3,
    learning_rate=2e-4,
    weight_decay=0.01,
    warmup_steps=500,
    logging_steps=50,
    logging_dir=f"{PROJECT_NAME}/logs",
    report_to=["wandb"],
    run_name=f"{PROJECT_NAME}_CLOUD",
    optim="adamw_8bit",
    push_to_hub=False,
    gradient_checkpointing=False,
    seed=3407,
)

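# Rough scale of a run (an estimate, assuming the loader collects the full N_SAMPLES=600000
# conversations): the 90% train split has ~540,000 examples, so with an effective batch size
# of 40 (batch 40 x accumulation 1) one epoch is ~13,500 optimizer steps per device, and
# warmup_steps=500 covers only the first few percent of training.
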
torch.cuda.empty_cache()

trainer = LoggingSFTTrainer(
    model=model,
    args=training_args,
    data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
    tokenizer=tokenizer,
    # WandbCallback is registered automatically through report_to=["wandb"] above,
    # so only early stopping is added explicitly here.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=4)],
    train_dataset=new_dataset["train"],
    eval_dataset=new_dataset["test"],
)

# Compute the loss only on assistant responses; all other tokens are masked to -100.
trainer = train_on_responses_only(
    trainer,
    instruction_part="<|im_start|>user\n",
    response_part="<|im_start|>assistant\n",
)

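# Optional sanity check (a sketch; assumes train_on_responses_only has produced tokenized
# "input_ids"/"labels" columns on the trainer's dataset): decode one example with the
# masked (-100) positions blanked out to confirm that only assistant text remains labeled.
if "labels" in trainer.train_dataset.column_names:
    sample = trainer.train_dataset[0]
    filler = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    print(tokenizer.decode(sample["input_ids"]))
    print(tokenizer.decode([filler if t == -100 else t for t in sample["labels"]]))
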
trainer.train()

# Merge the LoRA adapters into the base weights and push the merged model to the Hub.
model.push_to_hub_merged(
    f'Luigi/{PROJECT_NAME}',
    tokenizer,
    save_method="merged_16bit",
    safe_serialization=None,
)

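# A local merged copy can be kept as well (a sketch using Unsloth's companion saver;
# uncomment if a local export is wanted):
# model.save_pretrained_merged(PROJECT_NAME, tokenizer, save_method="merged_16bit")
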
# Quick smoke test: generate from a raw (non-templated) prompt.
test_prompt = "請問台北今天的天氣如何?"
inputs = tokenizer(test_prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
    **inputs,
    max_new_tokens=100,
    do_sample=True,
    temperature=0.8,
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
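
# The smoke test above bypasses the chat template the model was trained with. A minimal
# chat-templated generation sketch (assumes the tokenizer ships the same ChatML-style
# template used for the training data):
FastLanguageModel.for_inference(model)  # switch Unsloth into its faster inference mode
chat = [{"role": "user", "content": "請問台北今天的天氣如何?"}]
chat_input_ids = tokenizer.apply_chat_template(
    chat, tokenize=True, add_generation_prompt=True, return_tensors="pt"
).to(model.device)
chat_outputs = model.generate(
    chat_input_ids,
    max_new_tokens=100,
    do_sample=True,
    temperature=0.8,
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(chat_outputs[0], skip_special_tokens=True))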