mjschock committed
Commit 5bfd071 · unverified · 1 Parent(s): aecd650

Add Hydra integration and configuration support to train.py, allowing dynamic model loading and training control. Update requirements.txt to include the hydra-core dependency and introduce conf/config.yaml for model parameters and training settings.

Files changed (3)
  1. conf/config.yaml +6 -0
  2. requirements.txt +1 -0
  3. train.py +24 -17
conf/config.yaml ADDED
@@ -0,0 +1,6 @@
+defaults:
+  - _self_
+
+model_name: "unsloth/SmolLM2-135M-Instruct-bnb-4bit"
+train: false
+output_dir: "final_model"
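
For quick sanity checks, the same config can be composed outside of train.py with Hydra's compose API. A minimal sketch (not part of this commit; the `train=true` override is illustrative):

# Sketch: compose conf/config.yaml programmatically, roughly mirroring what
# @hydra.main does for train.py at startup.
from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(version_base=None, config_path="conf"):
    cfg = compose(config_name="config", overrides=["train=true"])
    print(OmegaConf.to_yaml(cfg))
    # model_name: unsloth/SmolLM2-135M-Instruct-bnb-4bit
    # train: true
    # output_dir: final_model

The same key=value overrides can be passed on the command line when running train.py directly.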
requirements.txt CHANGED
@@ -4,6 +4,7 @@ bitsandbytes>=0.45.5
 duckduckgo-search>=8.0.1
 gradio[oauth]>=5.26.0
 hf-xet>=1.0.5
+hydra-core>=1.3.2
 ipywidgets>=8.1.6
 isort>=6.0.1
 jupyter>=1.1.1
train.py CHANGED
@@ -19,6 +19,9 @@ from datetime import datetime
 from pathlib import Path
 from typing import Union
 
+import hydra
+from omegaconf import DictConfig, OmegaConf
+
 # isort: off
 from unsloth import FastLanguageModel, is_bfloat16_supported  # noqa: E402
 from unsloth.chat_templates import get_chat_template  # noqa: E402
@@ -41,7 +44,6 @@ from transformers import (
 from trl import SFTTrainer
 
 # Configuration
-DEFAULT_MODEL_NAME = "unsloth/SmolLM2-135M-Instruct-bnb-4bit"
 dtype = None  # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
 load_in_4bit = True  # Use 4bit quantization to reduce memory usage
 max_seq_length = 2048  # Auto supports RoPE Scaling internally
@@ -88,7 +90,7 @@ def install_dependencies():
         raise
 
 
-def load_model(model_name: str = DEFAULT_MODEL_NAME) -> tuple[FastLanguageModel, AutoTokenizer]:
+def load_model(model_name: str) -> tuple[FastLanguageModel, AutoTokenizer]:
     """Load and configure the model."""
     logger.info("Loading model and tokenizer...")
     try:
@@ -241,16 +243,18 @@ def create_trainer(
         raise
 
 
-def main():
+@hydra.main(version_base=None, config_path="conf", config_name="config")
+def main(cfg: DictConfig) -> None:
     """Main training function."""
     try:
         logger.info("Starting training process...")
+        logger.info(f"Configuration:\n{OmegaConf.to_yaml(cfg)}")
 
         # Install dependencies
         install_dependencies()
 
         # Load model and tokenizer
-        model, tokenizer = load_model()
+        model, tokenizer = load_model(cfg.model_name)
 
         # Load and prepare dataset
         dataset, tokenizer = load_and_format_dataset(tokenizer)
@@ -258,19 +262,22 @@ def main():
         # Create trainer
         trainer: Trainer = create_trainer(model, tokenizer, dataset)
 
-        # Train
-        logger.info("Starting training...")
-        trainer.train()
-
-        # Save model
-        logger.info("Saving final model...")
-        trainer.save_model("final_model")
-
-        # Print final metrics
-        final_metrics = trainer.state.log_history[-1]
-        logger.info("\nTraining completed!")
-        logger.info(f"Final training loss: {final_metrics.get('loss', 'N/A')}")
-        logger.info(f"Final validation loss: {final_metrics.get('eval_loss', 'N/A')}")
+        # Train if requested
+        if cfg.train:
+            logger.info("Starting training...")
+            trainer.train()
+
+            # Save model
+            logger.info(f"Saving final model to {cfg.output_dir}...")
+            trainer.save_model(cfg.output_dir)
+
+            # Print final metrics
+            final_metrics = trainer.state.log_history[-1]
+            logger.info("\nTraining completed!")
+            logger.info(f"Final training loss: {final_metrics.get('loss', 'N/A')}")
+            logger.info(f"Final validation loss: {final_metrics.get('eval_loss', 'N/A')}")
+        else:
+            logger.info("Training skipped as train=False")
 
     except Exception as e:
         logger.error(f"Error in main training process: {e}")
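
One usage note that follows from the decorator's version_base=None (an observation about Hydra defaults, not something changed in this diff): in that compatibility mode Hydra typically changes the working directory to a timestamped run directory, so a relative output_dir such as "final_model" will likely be created under Hydra's output folder rather than the project root. A minimal sketch of resolving it against the launch directory instead, using hydra.utils.to_absolute_path:

# Sketch: resolve a relative output_dir against the directory the script was
# launched from, rather than Hydra's per-run working directory.
import os

import hydra
from hydra.utils import to_absolute_path
from omegaconf import DictConfig


@hydra.main(version_base=None, config_path="conf", config_name="config")
def main(cfg: DictConfig) -> None:
    output_dir = to_absolute_path(cfg.output_dir)  # <launch dir>/final_model
    print(f"run cwd:    {os.getcwd()}")            # Hydra's run directory
    print(f"output_dir: {output_dir}")


if __name__ == "__main__":
    main()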