Update train script: set per-device train batch size to 40, evaluate every 10 steps, train for 3 epochs, and switch to the adamw_8bit optimizer
Browse files- train_with_unsloth.py +5 -5
train_with_unsloth.py
CHANGED
|
@@ -147,12 +147,12 @@ new_dataset = dataset.train_test_split(test_size=0.1)
|
|
| 147 |
# Configure training arguments
|
| 148 |
training_args = SFTConfig(
|
| 149 |
fp16_full_eval=False,
|
| 150 |
-
per_device_train_batch_size=
|
| 151 |
gradient_accumulation_steps=1,
|
| 152 |
per_device_eval_batch_size=1,
|
| 153 |
eval_accumulation_steps=1,
|
| 154 |
evaluation_strategy="steps",
|
| 155 |
-
eval_steps=
|
| 156 |
save_strategy="steps",
|
| 157 |
save_steps=1000,
|
| 158 |
load_best_model_at_end=True,
|
|
@@ -161,7 +161,7 @@ training_args = SFTConfig(
|
|
| 161 |
dataset_text_field="text",
|
| 162 |
output_dir=PROJECT_NAME,
|
| 163 |
max_seq_length=MAX_LEN,
|
| 164 |
-
num_train_epochs=
|
| 165 |
learning_rate=2e-4,
|
| 166 |
weight_decay=0.01,
|
| 167 |
warmup_steps=500,
|
|
@@ -169,7 +169,7 @@ training_args = SFTConfig(
|
|
| 169 |
logging_dir=f"{PROJECT_NAME}/logs",
|
| 170 |
report_to=["wandb"],
|
| 171 |
run_name=f"{PROJECT_NAME}_CLOUD",
|
| 172 |
-
optim="
|
| 173 |
push_to_hub=False,
|
| 174 |
gradient_checkpointing=False,
|
| 175 |
seed=3407,
|
|
@@ -213,4 +213,4 @@ outputs = model.generate(
|
|
| 213 |
temperature=0.8,
|
| 214 |
pad_token_id=tokenizer.eos_token_id
|
| 215 |
)
|
| 216 |
-
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
|
|
|
| 147 |
# Configure training arguments
|
| 148 |
training_args = SFTConfig(
|
| 149 |
fp16_full_eval=False,
|
| 150 |
+
per_device_train_batch_size=40,
|
| 151 |
gradient_accumulation_steps=1,
|
| 152 |
per_device_eval_batch_size=1,
|
| 153 |
eval_accumulation_steps=1,
|
| 154 |
evaluation_strategy="steps",
|
| 155 |
+
eval_steps=10,
|
| 156 |
save_strategy="steps",
|
| 157 |
save_steps=1000,
|
| 158 |
load_best_model_at_end=True,
|
|
|
|
| 161 |
dataset_text_field="text",
|
| 162 |
output_dir=PROJECT_NAME,
|
| 163 |
max_seq_length=MAX_LEN,
|
| 164 |
+
num_train_epochs=3,
|
| 165 |
learning_rate=2e-4,
|
| 166 |
weight_decay=0.01,
|
| 167 |
warmup_steps=500,
|
|
|
|
| 169 |
logging_dir=f"{PROJECT_NAME}/logs",
|
| 170 |
report_to=["wandb"],
|
| 171 |
run_name=f"{PROJECT_NAME}_CLOUD",
|
| 172 |
+
optim="adamw_8bit",
|
| 173 |
push_to_hub=False,
|
| 174 |
gradient_checkpointing=False,
|
| 175 |
seed=3407,
|
|
|
|
| 213 |
temperature=0.8,
|
| 214 |
pad_token_id=tokenizer.eos_token_id
|
| 215 |
)
|
| 216 |
+
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|