SFT model v2: gemma_3_800M_sft_v2_translation-kazparc_latest

June 23

Base model: SRP-base-model-training/gemma_3_800M_base_v2_multilingual_10B_data

Supervised fine-tuning (SFT) on the KazParC parallel corpus in four translation directions (kk_to_en, kk_to_ru, ru_to_kk, en_to_kk).

Inference example

import os

# The variable CUDA reads is CUDA_VISIBLE_DEVICES (plural); a misspelled name is
# silently ignored. Set it before importing torch so it is in place before any
# CUDA initialization can happen.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

import torch
from transformers import AutoTokenizer, Gemma3ForCausalLM

model_path = "SRP-base-model-training/gemma_3_800M_sft_v2_translation-kazparc_latest"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = Gemma3ForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Example requests: a system instruction, a user turn prefixed with the
# <src=..><tgt=..> language tags, and the reference translation.
# example = {"system": "Вы профессиональный переводчик. Переведите следующее предложение на қазақ язык.", "user": "<src=ru><tgt=kk>\nЗа один год с тех пор какие изменения произошли в Туркестане, какое дело доведено до конца?", "assistant": "Содан бергі бір жыл ішінде Түркістанда қандай өзгерістер болды, нендей іс тындырылды?"}
# example = {"system": "Сіз кәсіби аудармашысыз. Төмендегі сөйлемді English тіліне аударыңыз.", "user": "<src=kk><tgt=en>\nСауда-саттықта салқынқандылық басым.", "assistant": "Composure prevails in trade."}
example = {"system": "Сіз кәсіби аудармашысыз. Төмендегі сөйлемді English тіліне аударыңыз.", "user": "<src=kk><tgt=en>\nқала картасы", "assistant": "city map"}
s = example["system"]
u = example["user"]
a = example["assistant"]  # reference translation; not fed to the model

# Chat-formatted prompt. It must match the SFT training template exactly,
# including the "\n" after "<start_of_turn>assistant": training text is
# "<start_of_turn>assistant\n{answer}<end_of_turn>", so without the newline the
# model is asked to continue from a position it never saw during training.
prompt = (
    f"<start_of_turn>system\n{s}<end_of_turn>\n"
    f"<start_of_turn>user\n{u}<end_of_turn>\n"
    "<start_of_turn>assistant\n"
)

model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
input_len = model_inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(
        **model_inputs,
        max_new_tokens=64,
        do_sample=True,  # sampling; set False for deterministic (greedy) output
        top_p=0.9,
        # temperature=0.7,
        # repetition_penalty=1.2,
        # Stop at the end of the assistant turn.
        eos_token_id=tokenizer.convert_tokens_to_ids("<end_of_turn>"),
        pad_token_id=tokenizer.eos_token_id,
        # min_new_tokens=5,
    )
    # Keep only the newly generated tokens (drop the echoed prompt).
    generation = generation[0][input_len:]

decoded = tokenizer.decode(generation, skip_special_tokens=True)
print(decoded)

Train

Main script for training

# train_gemma_sft.py  🔧
import os, math, argparse, torch
from pathlib import Path
from datasets import load_dataset, concatenate_datasets
from transformers import (AutoTokenizer, Gemma3ForCausalLM)
from trl import SFTTrainer, SFTConfig, DataCollatorForCompletionOnlyLM

# ─── CLI ────────────────────────────────────────────────────────────────
def parse_args():
    """Parse the command-line options of the SFT training run.

    Options are registered from a single ordered table so that ``--help``
    lists them in the same order as they appear here.

    Returns:
        argparse.Namespace with the tokenizer/model/data/output paths,
        sequence length, batch and accumulation sizes, learning rate and
        Weights & Biases project/run names.
    """
    option_table = (
        ("--tokenizer_path", {"required": True}),
        ("--model_path", {
            "default": "/scratch/vladimir_albrekht/projects/smollm/output_checkpoints/test_1/checkpoint-300",
        }),
        # Directory of *.jsonl shards with system/user/assistant fields.
        ("--data_dir", {"required": True, "help": "Folder with SFT jsonl shards"}),
        ("--output_dir", {"default": "runs/gemma_sft"}),
        ("--max_seq_length", {"type": int, "default": 2048}),
        ("--per_device_batch_size", {"type": int, "default": 8}),
        ("--gradient_accumulation_steps", {"type": int, "default": 4}),
        ("--learning_rate", {"type": float, "default": 2e-4}),
        ("--wandb_project", {"default": "gemma-sft"}),
        ("--wandb_run_name", {"default": None}),
    )
    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()

args = parse_args()
os.environ["WANDB_PROJECT"] = args.wandb_project
# NOTE(review): setting this explicitly may keep HF tokenizers' Rust
# parallelism enabled in processes forked later (datasets num_proc, dataloader
# workers). If preprocessing ever hangs, try "false".
os.environ["TOKENIZERS_PARALLELISM"] = "true"

# ─── tokenizer / model ─────────────────────────────────────────────────
tok = AutoTokenizer.from_pretrained(args.tokenizer_path, use_fast=True)

# Make sure the chat-turn markers used by the prompt template exist as special
# tokens. Register all missing ones in ONE call and extend (not replace) the
# existing list. With the default replace_additional_special_tokens=True, each
# call would overwrite the tokenizer's additional_special_tokens list: earlier
# entries stay in the vocab but lose their "special" flag.
vocab = tok.get_vocab()
missing_tags = [t for t in ("<start_of_turn>", "<end_of_turn>") if t not in vocab]
if missing_tags:
    tok.add_special_tokens(
        {"additional_special_tokens": missing_tags},
        replace_additional_special_tokens=False,
    )

model = Gemma3ForCausalLM.from_pretrained(
    args.model_path,
    torch_dtype=torch.bfloat16,
    # Public from_pretrained keyword selecting the eager attention implementation.
    attn_implementation="eager",
)
# Match the embedding matrix to the tokenizer size (covers any tags added above).
model.resize_token_embeddings(len(tok))

# ─── dataset loading  ──────────────────────────────────────────────────
data_dir = Path(args.data_dir)
# Only *.jsonl files directly inside --data_dir are used (non-recursive glob),
# in sorted filename order.
jsonl_files = sorted(data_dir.glob("*.jsonl"))
if not jsonl_files:
    # Name the searched directory so a wrong --data_dir is immediately visible.
    raise ValueError(f"no jsonl found in {data_dir.resolve()} (pattern: *.jsonl)")

print(f"→ Loading {len(jsonl_files)} shards")
# Each shard is loaded as its own Dataset and then concatenated, so all shards
# must have compatible columns/features.
dsets = [load_dataset("json", data_files=str(f), split="train")
         for f in jsonl_files]
raw_ds = concatenate_datasets(dsets)

# Token budget; build_and_filter_batch approximates it as MAX_LEN * 4 characters
# when building the chat template and applying the rough length filter.
MAX_LEN = args.max_seq_length
def build_and_filter_batch(ex):
    """Render a batch of system/user/assistant rows as chat-formatted text.

    Rows whose combined character length exceeds ``MAX_LEN * 4`` (a rough
    4-characters-per-token proxy for the sequence budget) are dropped, so the
    returned batch can be shorter than the input batch.

    Args:
        ex: batched columns ``"system"``, ``"user"`` and ``"assistant"``.

    Returns:
        ``{"text": [...]}`` with one transcript per kept row, ending in
        ``<end_of_turn>`` followed by the tokenizer's EOS token.
    """
    char_budget = MAX_LEN * 4

    def render(system_msg, user_msg, reply):
        # Gemma-style turns; the assistant turn carries the target translation.
        return (f"<start_of_turn>system\n{system_msg}<end_of_turn>\n"
                f"<start_of_turn>user\n{user_msg}<end_of_turn>\n"
                f"<start_of_turn>assistant\n{reply}<end_of_turn>{tok.eos_token}")

    rows = zip(ex["system"], ex["user"], ex["assistant"])
    return {"text": [render(*row) for row in rows
                     if sum(map(len, row)) <= char_budget]}

# Worker count for dataset preprocessing; SFTConfig reuses it as dataset_num_proc.
# NOTE(review): under torchrun every rank runs this map with os.cpu_count() workers.
cpu = os.cpu_count()
map_options = {
    "batched": True,
    "batch_size": 1000,
    "num_proc": cpu,
    # Drop the source columns; only the "text" column produced by the function remains.
    "remove_columns": raw_ds.column_names,
}
ds = raw_ds.map(build_and_filter_batch, **map_options).shuffle(seed=42)

# ─── collator: mask *только* ответ ассистента ──────────────────────────
# Turn markers that delimit the user and assistant turns in the chat template.
# The collator is configured so the loss is meant to cover only assistant replies.
_user_turn = "<start_of_turn>user\n"
_assistant_turn = "<start_of_turn>assistant\n"
collator = DataCollatorForCompletionOnlyLM(
    response_template=_assistant_turn,
    instruction_template=_user_turn,
    tokenizer=tok,
    mlm=False,
)

# ─── training args ─────────────────────────────────────────────────────
train_cfg = SFTConfig(
    # Output, run naming, logging and checkpointing
    output_dir=args.output_dir,
    run_name=args.wandb_run_name,
    logging_steps=1,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=15,
    # Data and sequence handling
    max_seq_length=args.max_seq_length,
    packing=False,
    group_by_length=True,
    dataset_num_proc=cpu,
    dataloader_num_workers=8,
    # Optimisation schedule
    per_device_train_batch_size=args.per_device_batch_size,
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    learning_rate=args.learning_rate,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    weight_decay=0.01,
    # Precision, memory and distributed setup
    bf16=True,
    gradient_checkpointing=True,
    # NOTE(review): presumably resolved relative to the current working
    # directory, not to this script's location.
    deepspeed="../train_trl/ds_stage1.json",
    do_train=True,
)

# Combine the model, the config, the chat-formatted dataset and the completion-only
# collator into TRL's SFT trainer.
trainer = SFTTrainer(
    processing_class=tok,
    data_collator=collator,
    train_dataset=ds,
    args=train_cfg,
    model=model,
)

if __name__ == "__main__":
    # Under torchrun this script runs once per rank; report only from the main process.
    if trainer.is_world_process_zero():
        print(f"🚀 Start SFT: {len(ds):,} chat samples")
    trainer.train()

    final_dir = f"{args.output_dir}/checkpoint-final"
    # Trainer.save_model is called on every rank (the Trainer decides which
    # process writes the weights).
    trainer.save_model(final_dir)
    # tokenizer.save_pretrained is plain file I/O with no rank coordination: let
    # only the main process write, so ranks don't write the same files concurrently.
    if trainer.is_world_process_zero():
        tok.save_pretrained(final_dir)

To launch training, use a bash script similar to the following:

#!/usr/bin/env bash
# Launch SFT on 8 GPUs with torchrun. Adjust the paths below to your environment.
set -euo pipefail

export TRITON_CACHE_DIR=/scratch/vladimir_albrekht/projects/smollm/trl_italian_apporach/utils/cache/.triton
mkdir -p "$TRITON_CACHE_DIR"

# Fill in your Weights & Biases API key (intentionally left empty here).
export WANDB_API_KEY=""

OUTPUT_DIR='/scratch/vladimir_albrekht/projects/smollm/output_checkpoints/test_2_sft_with_base_model_v1_2'
WANDB_RUN_NAME='sft_translation_on_test_2_sft_with_base_model_v1_2'
# mkdir -p is a no-op when the directory already exists.
mkdir -p "$OUTPUT_DIR"

# Training script to run: the Python script shown above (its header names it
# train_gemma_sft.py). Override with TRAIN_SCRIPT=... if saved under another name.
TRAIN_SCRIPT="${TRAIN_SCRIPT:-test_sft_train.py}"

# Alternative starting checkpoint:
# --model_path "/scratch/vladimir_albrekht/projects/smollm/trl_italian_apporach/runs/my_experiment/checkpoint-final" \

# Effective batch per optimizer step: 32 per device x 8 accumulation x 8 GPUs = 2048 sequences.
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
torchrun --standalone --nproc_per_node 8 "${TRAIN_SCRIPT}" \
  --tokenizer_path "/scratch/vladimir_albrekht/projects/smollm/models/tokenizers/tok_best_version_50_000_vocab_abai_20_june" \
  --model_path "/scratch/vladimir_albrekht/projects/smollm/output_checkpoints/test_2_multiling/checkpoint-900" \
  --data_dir "/scratch/vladimir_albrekht/projects/smollm/data/sft/kazparc/jsonl/train" \
  --max_seq_length 2048 \
  --per_device_batch_size 32 \
  --gradient_accumulation_steps 8 \
  --learning_rate 4e-5 \
  --output_dir "${OUTPUT_DIR}" \
  --wandb_project "small_llm_SRP" \
  --wandb_run_name "${WANDB_RUN_NAME}"
Downloads last month
171
Safetensors
Model size
859M params
Tensor type
BF16
·
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Model tree for SRP-base-model-training/gemma_3_800M_sft_v2_translation-kazparc_latest

Dataset used to train SRP-base-model-training/gemma_3_800M_sft_v2_translation-kazparc_latest

Space using SRP-base-model-training/gemma_3_800M_sft_v2_translation-kazparc_latest 1

Collection including SRP-base-model-training/gemma_3_800M_sft_v2_translation-kazparc_latest