mjschock committed
Commit 611c848 · unverified · 1 parent: 5bfd071

Refactor train.py to utilize a comprehensive configuration structure from config.yaml, enhancing model loading, dataset handling, and trainer setup. This update centralizes parameters for model, PEFT, dataset, and training settings, improving maintainability and flexibility.

Files changed (2):
  1. conf/config.yaml +68 -2
  2. train.py +40 -67
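
The main(cfg: DictConfig) signature in train.py and the conf/config.yaml layout (a defaults list with _self_) suggest a Hydra-driven entrypoint. Below is a minimal sketch of how the new nested groups would be consumed from such an entrypoint; the decorator arguments are assumptions, since the entrypoint itself is not part of this diff:

import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(config_path="conf", config_name="config", version_base=None)
def main(cfg: DictConfig) -> None:
    # Nested groups are reached by attribute access, e.g. cfg.model.name,
    # cfg.peft.r, cfg.training.args.learning_rate, cfg.output.dir.
    print(OmegaConf.to_yaml(cfg))
    if cfg.train:
        print(f"Would fine-tune {cfg.model.name} and save to {cfg.output.dir}")


if __name__ == "__main__":
    main()

With this layout, any value can be overridden from the command line without editing the file, for example: python train.py train=true training.args.max_steps=60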
conf/config.yaml CHANGED
@@ -1,6 +1,72 @@
 defaults:
   - _self_
 
-model_name: "unsloth/SmolLM2-135M-Instruct-bnb-4bit"
+# Model configuration
+model:
+  name: "unsloth/SmolLM2-135M-Instruct-bnb-4bit"
+  max_seq_length: 2048  # Auto supports RoPE Scaling internally
+  dtype: null  # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
+  load_in_4bit: true  # Use 4bit quantization to reduce memory usage
+
+# PEFT configuration
+peft:
+  r: 64
+  lora_alpha: 128
+  lora_dropout: 0.05
+  bias: "none"
+  use_gradient_checkpointing: "unsloth"
+  random_state: 3407
+  use_rslora: true
+  loftq_config: null
+  target_modules:
+    - "q_proj"
+    - "k_proj"
+    - "v_proj"
+    - "o_proj"
+    - "gate_proj"
+    - "up_proj"
+    - "down_proj"
+
+# Dataset configuration
+dataset:
+  validation_split: 0.1  # 10% of data for validation
+  seed: 3407  # Random seed for dataset splitting
+
+# Training configuration
+training:
+  args:
+    per_device_train_batch_size: 2
+    per_device_eval_batch_size: 2
+    gradient_accumulation_steps: 16
+    warmup_steps: 100
+    max_steps: 120
+    learning_rate: 5e-5
+    logging_steps: 1
+    save_strategy: "steps"
+    save_steps: 30
+    eval_strategy: "steps"
+    eval_steps: 30
+    save_total_limit: 2
+    optim: "adamw_8bit"
+    weight_decay: 0.01
+    lr_scheduler_type: "cosine_with_restarts"
+    seed: 3407
+    output_dir: "outputs"
+    gradient_checkpointing: true
+    load_best_model_at_end: true
+    metric_for_best_model: "eval_loss"
+    greater_is_better: false
+
+  sft:
+    dataset_num_proc: 2
+    packing: false
+    data_collator:
+      mlm: false
+      pad_to_multiple_of: 8
+
+# Output configuration
+output:
+  dir: "final_model"
+
+# Training control
 train: false
-output_dir: "final_model"
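
Because every hyperparameter now lives in this one file, the structure can be spot-checked without launching a training run. A small sketch using OmegaConf directly (the assertions simply mirror values from the config above):

from omegaconf import OmegaConf

# Load the raw YAML (no Hydra run) to inspect the new nested layout.
cfg = OmegaConf.load("conf/config.yaml")

assert cfg.model.name == "unsloth/SmolLM2-135M-Instruct-bnb-4bit"
assert cfg.peft.r == 64 and cfg.peft.use_rslora is True
assert cfg.dataset.validation_split == 0.1

# Dump just the trainer arguments for review.
print(OmegaConf.to_yaml(cfg.training.args))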
 
train.py CHANGED
@@ -43,13 +43,6 @@ from transformers import (
 )
 from trl import SFTTrainer
 
-# Configuration
-dtype = None  # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
-load_in_4bit = True  # Use 4bit quantization to reduce memory usage
-max_seq_length = 2048  # Auto supports RoPE Scaling internally
-validation_split = 0.1  # 10% of data for validation
-
-
 # Setup logging
 def setup_logging():
     """Configure logging for the training process."""
@@ -90,38 +83,30 @@ def install_dependencies():
         raise
 
 
-def load_model(model_name: str) -> tuple[FastLanguageModel, AutoTokenizer]:
+def load_model(cfg: DictConfig) -> tuple[FastLanguageModel, AutoTokenizer]:
     """Load and configure the model."""
     logger.info("Loading model and tokenizer...")
     try:
         model, tokenizer = FastLanguageModel.from_pretrained(
-            model_name=model_name,
-            max_seq_length=max_seq_length,
-            dtype=dtype,
-            load_in_4bit=load_in_4bit,
+            model_name=cfg.model.name,
+            max_seq_length=cfg.model.max_seq_length,
+            dtype=cfg.model.dtype,
+            load_in_4bit=cfg.model.load_in_4bit,
         )
         logger.info("Base model loaded successfully")
 
         # Configure LoRA
         model = FastLanguageModel.get_peft_model(
             model,
-            r=64,
-            target_modules=[
-                "q_proj",
-                "k_proj",
-                "v_proj",
-                "o_proj",
-                "gate_proj",
-                "up_proj",
-                "down_proj",
-            ],
-            lora_alpha=128,
-            lora_dropout=0.05,
-            bias="none",
-            use_gradient_checkpointing="unsloth",
-            random_state=3407,
-            use_rslora=True,
-            loftq_config=None,
+            r=cfg.peft.r,
+            target_modules=cfg.peft.target_modules,
+            lora_alpha=cfg.peft.lora_alpha,
+            lora_dropout=cfg.peft.lora_dropout,
+            bias=cfg.peft.bias,
+            use_gradient_checkpointing=cfg.peft.use_gradient_checkpointing,
+            random_state=cfg.peft.random_state,
+            use_rslora=cfg.peft.use_rslora,
+            loftq_config=cfg.peft.loftq_config,
         )
         logger.info("LoRA configuration applied successfully")
 
@@ -133,6 +118,7 @@ def load_model(model_name: str) -> tuple[FastLanguageModel, AutoTokenizer]:
 
 def load_and_format_dataset(
     tokenizer: AutoTokenizer,
+    cfg: DictConfig,
 ) -> tuple[
     Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset], AutoTokenizer
 ]:
@@ -144,7 +130,7 @@ def load_and_format_dataset(
         logger.info(f"Dataset loaded successfully. Size: {len(dataset)} examples")
 
         # Split into train and validation sets
-        dataset = dataset.train_test_split(test_size=validation_split, seed=3407)
+        dataset = dataset.train_test_split(test_size=cfg.dataset.validation_split, seed=cfg.dataset.seed)
         logger.info(
             f"Dataset split into train ({len(dataset['train'])} examples) and validation ({len(dataset['test'])} examples) sets"
         )
@@ -194,47 +180,34 @@ def create_trainer(
     model: FastLanguageModel,
     tokenizer: AutoTokenizer,
     dataset: Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset],
+    cfg: DictConfig,
 ) -> Trainer:
     """Create and configure the SFTTrainer."""
     logger.info("Creating trainer...")
     try:
+        # Create TrainingArguments from config
+        training_args_dict = OmegaConf.to_container(cfg.training.args, resolve=True)
+        # Add dynamic precision settings
+        training_args_dict.update({
+            "fp16": not is_bfloat16_supported(),
+            "bf16": is_bfloat16_supported(),
+        })
+        training_args = TrainingArguments(**training_args_dict)
+
+        # Create data collator from config
+        data_collator = DataCollatorForLanguageModeling(
+            tokenizer=tokenizer,
+            **cfg.training.sft.data_collator,
+        )
+
+        # Remaining SFT kwargs (dataset_num_proc, packing); drop the nested
+        # data_collator node, which is already passed explicitly below
+        sft_kwargs = OmegaConf.to_container(cfg.training.sft, resolve=True)
+        sft_kwargs.pop("data_collator", None)
+
         trainer = SFTTrainer(
             model=model,
             tokenizer=tokenizer,
             train_dataset=dataset["train"],
            eval_dataset=dataset["validation"],
-            dataset_num_proc=2,
-            packing=False,
-            args=TrainingArguments(
-                per_device_train_batch_size=2,
-                per_device_eval_batch_size=2,
-                gradient_accumulation_steps=16,
-                warmup_steps=100,
-                max_steps=120,
-                learning_rate=5e-5,
-                fp16=not is_bfloat16_supported(),
-                bf16=is_bfloat16_supported(),
-                logging_steps=1,
-                save_strategy="steps",
-                save_steps=30,
-                eval_strategy="steps",
-                eval_steps=30,
-                save_total_limit=2,
-                optim="adamw_8bit",
-                weight_decay=0.01,
-                lr_scheduler_type="cosine_with_restarts",
-                seed=3407,
-                output_dir="outputs",
-                gradient_checkpointing=True,
-                load_best_model_at_end=True,
-                metric_for_best_model="eval_loss",
-                greater_is_better=False,
-            ),
-            data_collator=DataCollatorForLanguageModeling(
-                tokenizer=tokenizer,
-                mlm=False,
-                pad_to_multiple_of=8,
-            ),
+            args=training_args,
+            data_collator=data_collator,
+            **sft_kwargs,
         )
         logger.info("Trainer created successfully")
         return trainer
@@ -254,13 +227,13 @@ def main(cfg: DictConfig) -> None:
     install_dependencies()
 
     # Load model and tokenizer
-    model, tokenizer = load_model(cfg.model_name)
+    model, tokenizer = load_model(cfg)
 
     # Load and prepare dataset
-    dataset, tokenizer = load_and_format_dataset(tokenizer)
+    dataset, tokenizer = load_and_format_dataset(tokenizer, cfg)
 
     # Create trainer
-    trainer: Trainer = create_trainer(model, tokenizer, dataset)
+    trainer: Trainer = create_trainer(model, tokenizer, dataset, cfg)
 
     # Train if requested
     if cfg.train:
@@ -268,8 +241,8 @@ def main(cfg: DictConfig) -> None:
         trainer.train()
 
         # Save model
-        logger.info(f"Saving final model to {cfg.output_dir}...")
-        trainer.save_model(cfg.output_dir)
+        logger.info(f"Saving final model to {cfg.output.dir}...")
+        trainer.save_model(cfg.output.dir)
 
         # Print final metrics
         final_metrics = trainer.state.log_history[-1]
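
Since TrainingArguments is a dataclass, one low-cost guard against config drift is to check that every key under training.args is still a valid argument name for the installed transformers version (for example, newer releases use eval_strategy rather than the older evaluation_strategy spelling). A minimal sketch of such a check; the fp16/bf16 flags are injected at runtime in create_trainer, so only the YAML keys are inspected:

import dataclasses

from omegaconf import OmegaConf
from transformers import TrainingArguments

cfg = OmegaConf.load("conf/config.yaml")

# Keys declared in the config vs. fields accepted by TrainingArguments.
configured = set(OmegaConf.to_container(cfg.training.args, resolve=True))
valid = {field.name for field in dataclasses.fields(TrainingArguments)}

unknown = sorted(configured - valid)
print("Unknown TrainingArguments keys:", unknown or "none")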