Re-implement dataset filtering in a more efficient way
Browse files- train_with_unsloth.py +36 -17
train_with_unsloth.py
CHANGED
|
@@ -26,6 +26,8 @@ from transformers.integrations import WandbCallback
|
|
| 26 |
from datasets import load_dataset
|
| 27 |
import os
|
| 28 |
import torch
|
|
|
|
|
|
|
| 29 |
|
| 30 |
# Project and dataset settings
|
| 31 |
PROJECT_NAME = 'SmolLM2-360M-Instruct-TaiwanChat'
|
|
@@ -66,23 +68,40 @@ model = FastLanguageModel.get_peft_model(
|
|
| 66 |
)
|
| 67 |
|
| 68 |
# Prepare dataset with 5% validation split
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
new_dataset = dataset.train_test_split(test_size=0.2)
|
| 87 |
|
| 88 |
# Configure training arguments
|
|
|
|
# Standard library
import os
import random

# Third-party
import torch
# Consolidated: the diff introduced a second `from datasets import ...` line;
# a single import brings both names into scope (Dataset is needed by
# load_fitting_samples to materialize the streamed samples).
from datasets import load_dataset, Dataset
|
| 31 |
|
| 32 |
# Project and dataset settings
|
| 33 |
PROJECT_NAME = 'SmolLM2-360M-Instruct-TaiwanChat'
|
|
|
|
| 68 |
)
|
| 69 |
|
| 70 |
# Prepare dataset with 20% validation split
|
| 71 |
def load_fitting_samples(dataset_id, tokenizer, max_len, n_samples, seed=3407):
    """Stream a Hugging Face dataset and collect up to ``n_samples`` chat
    examples whose rendered text fits within ``max_len`` tokens.

    Streaming avoids downloading/materializing the full corpus just to
    discard most of it; iteration stops as soon as enough samples fit.

    Args:
        dataset_id: Hugging Face dataset identifier (the ``train`` split is used).
        tokenizer: tokenizer exposing ``apply_chat_template`` and ``__call__``.
        max_len: maximum token count (special tokens excluded) per sample.
        n_samples: number of fitting samples to collect; fewer are returned
            if the stream is exhausted first.
        seed: RNG seed for the final shuffle, making the output order
            deterministic.

    Returns:
        An in-memory ``datasets.Dataset`` with a single ``"text"`` column.
    """
    # 1) Open the HF dataset in streaming mode so the full corpus is never
    #    held in memory.
    stream = load_dataset(dataset_id, split="train", streaming=True)

    selected = []
    for example in stream:
        # 2) Render the chat-template text. Bug fix: this must be
        #    add_generation_prompt=False for training data — True appends an
        #    empty assistant header after the final response, teaching the
        #    model to emit a spurious trailing prompt.
        text = tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
            add_generation_prompt=False,
        )
        # 3) Quick length check on token IDs. add_special_tokens=False —
        #    assumes the chat template already embeds any special tokens,
        #    so counting them again would double-count; TODO confirm for
        #    this tokenizer.
        tokens = tokenizer(text, add_special_tokens=False)["input_ids"]
        if len(tokens) <= max_len:
            selected.append({"text": text})

        # 4) Stop as soon as enough fitting samples are collected.
        if len(selected) >= n_samples:
            break

    # 5) Shuffle deterministically and build a regular (non-streaming) Dataset.
    random.Random(seed).shuffle(selected)
    return Dataset.from_list(selected)
|
| 95 |
+
|
| 96 |
# --- usage in your script ---
# Collect a fixed-size, length-filtered sample of the corpus via streaming;
# the explicit seed keeps the selection reproducible across runs.
dataset = load_fitting_samples(
    DATASET_ID,
    tokenizer=tokenizer,
    max_len=MAX_LEN,
    n_samples=N_SAMPLES,
    seed=3407,
)

# Hold out 20% of the collected samples for evaluation.
new_dataset = dataset.train_test_split(test_size=0.2)
|
| 106 |
|
| 107 |
# Configure training arguments
|