Luigi committed on
Commit
afa5f94
·
1 Parent(s): 69d7616

re-implement dataset filtering in a more efficient way

Browse files
Files changed (1) hide show
  1. train_with_unsloth.py +36 -17
train_with_unsloth.py CHANGED
@@ -26,6 +26,8 @@ from transformers.integrations import WandbCallback
26
  from datasets import load_dataset
27
  import os
28
  import torch
 
 
29
 
30
  # Project and dataset settings
31
  PROJECT_NAME = 'SmolLM2-360M-Instruct-TaiwanChat'
@@ -66,23 +68,40 @@ model = FastLanguageModel.get_peft_model(
66
  )
67
 
68
  # Prepare dataset with 5% validation split
69
- dataset = load_dataset(DATASET_ID, split=f"train")
70
- def fmt(examples):
71
- return {"text": [
72
- tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
73
- for msgs in examples["messages"]
74
- ]}
75
- dataset = dataset.map(fmt, batched=True, remove_columns=["messages"])
76
- def is_within_max_len(example):
77
- toks = tokenizer(
78
- example["text"],
79
- add_special_tokens=False
80
- )["input_ids"]
81
- return len(toks) <= MAX_LEN
82
-
83
- # Filter out samples whose encoded length >= MAX_LEN
84
- filtered_ds = dataset.filter(is_within_max_len)
85
- dataset = filtered_ds.select(range(N_SAMPLES))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  new_dataset = dataset.train_test_split(test_size=0.2)
87
 
88
  # Configure training arguments
 
26
  from datasets import load_dataset
27
  import os
28
  import torch
29
+ from datasets import load_dataset, Dataset
30
+ import random
31
 
32
  # Project and dataset settings
33
  PROJECT_NAME = 'SmolLM2-360M-Instruct-TaiwanChat'
 
68
  )
69
 
70
  # Prepare dataset with 5% validation split
71
def load_fitting_samples(dataset_id, tokenizer, max_len, n_samples, seed=3407):
    """Stream a HF dataset and collect the first ``n_samples`` chat examples
    whose rendered text fits within ``max_len`` tokens.

    Streaming means we never download or tokenize the whole split: iteration
    stops as soon as enough fitting samples have been gathered. The kept
    samples are shuffled deterministically and materialized in memory.

    Args:
        dataset_id: Hugging Face dataset identifier (``train`` split is used).
        tokenizer: tokenizer exposing ``apply_chat_template`` and ``__call__``.
        max_len: maximum allowed token count per sample (inclusive).
        n_samples: number of fitting samples to collect before stopping.
        seed: RNG seed for the deterministic shuffle of the kept samples.

    Returns:
        A ``datasets.Dataset`` with a single ``"text"`` column.
    """
    streamed = load_dataset(dataset_id, split="train", streaming=True)

    kept = []
    for record in streamed:
        # Render the conversation through the chat template as plain text.
        # NOTE(review): add_generation_prompt=True appends the assistant
        # prompt header to each sample — confirm this is intended for
        # supervised fine-tuning text rather than inference prompts.
        rendered = tokenizer.apply_chat_template(
            record["messages"],
            tokenize=False,
            add_generation_prompt=True,
        )

        # Cheap length gate on token IDs (no special tokens, matching the
        # original filter) — drop anything over the budget.
        token_ids = tokenizer(rendered, add_special_tokens=False)["input_ids"]
        if len(token_ids) <= max_len:
            kept.append({"text": rendered})
            # Early exit once we have enough fitting examples.
            if len(kept) >= n_samples:
                break

    # Deterministic shuffle of the collected subset, then build a regular
    # in-memory Dataset from it.
    random.Random(seed).shuffle(kept)
    return Dataset.from_list(kept)
95
+
96
# Build the training pool: the first N_SAMPLES chat samples from the train
# split whose templated text fits within MAX_LEN tokens, shuffled with a
# fixed seed for reproducibility.
dataset = load_fitting_samples(
    DATASET_ID,
    tokenizer=tokenizer,
    max_len=MAX_LEN,
    n_samples=N_SAMPLES,
    seed=3407,
)
104
+
105
  new_dataset = dataset.train_test_split(test_size=0.2)
106
 
107
  # Configure training arguments