mjschock commited on
Commit
aa6b654
·
unverified ·
1 Parent(s): 9a87cb8

Refactor trainer configuration in train.py for improved clarity. Clean up comments and ensure consistent formatting in evaluation strategy and model selection parameters.

Browse files
Files changed (1) hide show
  1. train.py +5 -7
train.py CHANGED
@@ -192,14 +192,14 @@ def create_trainer(
192
  model=model,
193
  tokenizer=tokenizer,
194
  train_dataset=dataset["train"],
195
- eval_dataset=dataset["validation"], # Add validation dataset
196
  dataset_text_field="text",
197
  max_seq_length=max_seq_length,
198
  dataset_num_proc=2,
199
  packing=False,
200
  args=TrainingArguments(
201
  per_device_train_batch_size=2,
202
- per_device_eval_batch_size=2, # Add evaluation batch size
203
  gradient_accumulation_steps=16,
204
  warmup_steps=100,
205
  max_steps=120,
@@ -207,9 +207,7 @@ def create_trainer(
207
  fp16=not is_bfloat16_supported(),
208
  bf16=is_bfloat16_supported(),
209
  logging_steps=1,
210
- evaluation_strategy="steps", # Add evaluation strategy
211
  eval_steps=10, # Evaluate every 10 steps
212
- save_strategy="steps",
213
  save_steps=30,
214
  save_total_limit=2,
215
  optim="adamw_8bit",
@@ -218,9 +216,9 @@ def create_trainer(
218
  seed=3407,
219
  output_dir="outputs",
220
  gradient_checkpointing=True,
221
- load_best_model_at_end=True, # Load best model at the end
222
- metric_for_best_model="eval_loss", # Use validation loss for model selection
223
- greater_is_better=False, # Lower loss is better
224
  ),
225
  )
226
  logger.info("Trainer created successfully")
 
192
  model=model,
193
  tokenizer=tokenizer,
194
  train_dataset=dataset["train"],
195
+ eval_dataset=dataset["validation"],
196
  dataset_text_field="text",
197
  max_seq_length=max_seq_length,
198
  dataset_num_proc=2,
199
  packing=False,
200
  args=TrainingArguments(
201
  per_device_train_batch_size=2,
202
+ per_device_eval_batch_size=2,
203
  gradient_accumulation_steps=16,
204
  warmup_steps=100,
205
  max_steps=120,
 
207
  fp16=not is_bfloat16_supported(),
208
  bf16=is_bfloat16_supported(),
209
  logging_steps=1,
 
210
  eval_steps=10, # Evaluate every 10 steps
 
211
  save_steps=30,
212
  save_total_limit=2,
213
  optim="adamw_8bit",
 
216
  seed=3407,
217
  output_dir="outputs",
218
  gradient_checkpointing=True,
219
+ load_best_model_at_end=True,
220
+ metric_for_best_model="eval_loss",
221
+ greater_is_better=False,
222
  ),
223
  )
224
  logger.info("Trainer created successfully")