Luigi committed
Commit fc65dac · 1 Parent(s): 4bf72b9

update train script

Files changed (1): train_with_unsloth.py (+5 −5)
train_with_unsloth.py CHANGED

@@ -147,12 +147,12 @@ new_dataset = dataset.train_test_split(test_size=0.1)
 # Configure training arguments
 training_args = SFTConfig(
     fp16_full_eval=False,
-    per_device_train_batch_size=1,
+    per_device_train_batch_size=40,
     gradient_accumulation_steps=1,
     per_device_eval_batch_size=1,
     eval_accumulation_steps=1,
     evaluation_strategy="steps",
-    eval_steps=1000,
+    eval_steps=10,
     save_strategy="steps",
     save_steps=1000,
     load_best_model_at_end=True,
@@ -161,7 +161,7 @@ training_args = SFTConfig(
     dataset_text_field="text",
     output_dir=PROJECT_NAME,
     max_seq_length=MAX_LEN,
-    num_train_epochs=5,
+    num_train_epochs=3,
     learning_rate=2e-4,
     weight_decay=0.01,
     warmup_steps=500,
@@ -169,7 +169,7 @@ training_args = SFTConfig(
     logging_dir=f"{PROJECT_NAME}/logs",
     report_to=["wandb"],
     run_name=f"{PROJECT_NAME}_CLOUD",
-    optim="adamw_torch",
+    optim="adamw_8bit",
     push_to_hub=False,
     gradient_checkpointing=False,
     seed=3407,
@@ -213,4 +213,4 @@ outputs = model.generate(
     temperature=0.8,
     pad_token_id=tokenizer.eos_token_id
 )
-print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+print(tokenizer.decode(outputs[0], skip_special_tokens=True))
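
Net effect of the commit: the per-device train batch grows from 1 to 40, evaluation runs every 10 steps instead of every 1000, training drops from 5 to 3 epochs, and the optimizer switches from adamw_torch to adamw_8bit (8-bit AdamW, backed by the bitsandbytes package). For reference, a minimal sketch of how the SFTConfig reads after this commit; PROJECT_NAME and MAX_LEN are defined earlier in train_with_unsloth.py, so the values below are hypothetical placeholders:

from trl import SFTConfig

# Hypothetical placeholders; the real script defines these earlier.
PROJECT_NAME = "my_project"
MAX_LEN = 2048

training_args = SFTConfig(
    fp16_full_eval=False,
    per_device_train_batch_size=40,   # was 1
    gradient_accumulation_steps=1,
    per_device_eval_batch_size=1,
    eval_accumulation_steps=1,
    evaluation_strategy="steps",
    eval_steps=10,                    # was 1000
    save_strategy="steps",
    save_steps=1000,
    load_best_model_at_end=True,
    dataset_text_field="text",
    output_dir=PROJECT_NAME,
    max_seq_length=MAX_LEN,
    num_train_epochs=3,               # was 5
    learning_rate=2e-4,
    weight_decay=0.01,
    warmup_steps=500,
    logging_dir=f"{PROJECT_NAME}/logs",
    report_to=["wandb"],
    run_name=f"{PROJECT_NAME}_CLOUD",
    optim="adamw_8bit",               # was adamw_torch; requires bitsandbytes
    push_to_hub=False,
    gradient_checkpointing=False,
    seed=3407,
)

One consistency check worth noting: with load_best_model_at_end=True, transformers expects save_steps to be a round multiple of eval_steps, which holds here (1000 is a multiple of 10).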