Update train script: collapse extra newlines after the assistant marker and set eval_steps=100
Browse files- train_with_unsloth.py +14 -1
train_with_unsloth.py
CHANGED
|
@@ -28,6 +28,7 @@ import os
|
|
| 28 |
import torch
|
| 29 |
import random
|
| 30 |
import logging
|
|
|
|
| 31 |
|
| 32 |
logging.basicConfig(level=logging.WARNING)
|
| 33 |
logger = logging.getLogger(__name__)
|
|
@@ -129,6 +130,18 @@ dataset = load_fitting_samples(
|
|
| 129 |
seed=3407,
|
| 130 |
)
|
| 131 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
new_dataset = dataset.train_test_split(test_size=0.01)
|
| 133 |
|
| 134 |
# Configure training arguments
|
|
@@ -139,7 +152,7 @@ training_args = SFTConfig(
|
|
| 139 |
per_device_eval_batch_size=1,
|
| 140 |
eval_accumulation_steps=4,
|
| 141 |
evaluation_strategy="steps",
|
| 142 |
-
eval_steps=
|
| 143 |
save_strategy="steps",
|
| 144 |
save_steps=500,
|
| 145 |
load_best_model_at_end=True,
|
|
|
|
| 28 |
import torch
|
| 29 |
import random
|
| 30 |
import logging
|
| 31 |
+
import re
|
| 32 |
|
| 33 |
logging.basicConfig(level=logging.WARNING)
|
| 34 |
logger = logging.getLogger(__name__)
|
|
|
|
| 130 |
seed=3407,
|
| 131 |
)
|
| 132 |
|
| 133 |
+
def clean_assistant_marker(example):
|
| 134 |
+
# Collapse any run of one or more newlines directly after "<|im_start|>assistant" into a single "\n"; the rest of example["text"] is unchanged, and the same example object is updated in place and returned.
|
| 135 |
+
example["text"] = re.sub(
|
| 136 |
+
r"(<\|im_start\|>assistant)\n+",
|
| 137 |
+
r"\1\n",
|
| 138 |
+
example["text"]
|
| 139 |
+
)
|
| 140 |
+
return example
|
| 141 |
+
|
| 142 |
+
# Normalize every sample's text: "<|im_start|>assistant\n\n" -> "<|im_start|>assistant\n" (batched=False, so the function presumably receives one example per call).
|
| 143 |
+
dataset = dataset.map(clean_assistant_marker, batched=False)
|
| 144 |
+
|
| 145 |
new_dataset = dataset.train_test_split(test_size=0.01)  # test_size=0.01 presumably holds out 1% of samples as the test split; NOTE(review): no seed is passed here, unlike seed=3407 in the loader above - confirm whether the split should be reproducible
|
| 146 |
|
| 147 |
# Configure training arguments
|
|
|
|
| 152 |
per_device_eval_batch_size=1,
|
| 153 |
eval_accumulation_steps=4,
|
| 154 |
evaluation_strategy="steps",
|
| 155 |
+
eval_steps=100,
|
| 156 |
save_strategy="steps",
|
| 157 |
save_steps=500,
|
| 158 |
load_best_model_at_end=True,
|