mjschock committed
Commit 9a87cb8 · unverified · 1 Parent(s): 70eb9de

Refactor train.py to improve code readability and organization. Adjust logging setup for clarity, streamline dependency installation commands, and enhance dataset splitting and formatting processes. Ensure consistent formatting in log messages and code structure.

Files changed (1):
  1. train.py +32 -24
train.py CHANGED
@@ -13,8 +13,8 @@ To run this script:
 2. Run: python train.py
 """
 
-import os
 import logging
+import os
 from datetime import datetime
 from pathlib import Path
 from typing import Union
@@ -39,39 +39,41 @@ dtype = (
 load_in_4bit = True  # Use 4bit quantization to reduce memory usage
 validation_split = 0.1  # 10% of data for validation
 
+
 # Setup logging
 def setup_logging():
     """Configure logging for the training process."""
     # Create logs directory if it doesn't exist
     log_dir = Path("logs")
     log_dir.mkdir(exist_ok=True)
-
+
     # Create a unique log file name with timestamp
     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
     log_file = log_dir / f"training_{timestamp}.log"
-
+
     # Configure logging
     logging.basicConfig(
         level=logging.INFO,
-        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
-        handlers=[
-            logging.FileHandler(log_file),
-            logging.StreamHandler()
-        ]
+        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+        handlers=[logging.FileHandler(log_file), logging.StreamHandler()],
     )
-
+
     logger = logging.getLogger(__name__)
     logger.info(f"Logging initialized. Log file: {log_file}")
     return logger
 
+
 logger = setup_logging()
 
+
 def install_dependencies():
     """Install required dependencies."""
     logger.info("Installing dependencies...")
     try:
-        os.system('pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"')
-        os.system('pip install --no-deps xformers trl peft accelerate bitsandbytes')
+        os.system(
+            'pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"'
+        )
+        os.system("pip install --no-deps xformers trl peft accelerate bitsandbytes")
         logger.info("Dependencies installed successfully")
     except Exception as e:
         logger.error(f"Error installing dependencies: {e}")
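
Note on the install step: os.system returns the child process's exit status instead of raising on failure, so the except branch above will not actually catch a failed pip install. A minimal sketch of a stricter variant, assuming only the standard library (install_dependencies_strict is a hypothetical name, not part of this commit):

    import subprocess
    import sys


    def install_dependencies_strict():
        """Hypothetical variant: raise if any pip command fails."""
        # Same package lists as in the diff above.
        commands = [
            [sys.executable, "-m", "pip", "install",
             "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"],
            [sys.executable, "-m", "pip", "install", "--no-deps",
             "xformers", "trl", "peft", "accelerate", "bitsandbytes"],
        ]
        for cmd in commands:
            # check_call raises CalledProcessError on a nonzero exit code,
            # so a surrounding try/except really does see install failures.
            subprocess.check_call(cmd)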
@@ -133,7 +135,9 @@ def load_and_format_dataset(
 
     # Split into train and validation sets
     dataset = dataset.train_test_split(test_size=validation_split, seed=3407)
-    logger.info(f"Dataset split into train ({len(dataset['train'])} examples) and validation ({len(dataset['test'])} examples) sets")
+    logger.info(
+        f"Dataset split into train ({len(dataset['train'])} examples) and validation ({len(dataset['test'])} examples) sets"
+    )
 
     # Configure chat template
     tokenizer = get_chat_template(
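
For context, Dataset.train_test_split returns a DatasetDict keyed "train" and "test", which is why the log line above reads dataset["test"] for the validation count. A minimal sketch on toy data (not the project's dataset):

    from datasets import Dataset

    # Ten toy rows; test_size mirrors the script's validation_split = 0.1.
    toy = Dataset.from_dict({"text": [f"example {i}" for i in range(10)]})
    split = toy.train_test_split(test_size=0.1, seed=3407)

    # The result is keyed "train"/"test"; "test" serves as validation here.
    print(len(split["train"]), len(split["test"]))  # 9 1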
@@ -160,10 +164,14 @@ def load_and_format_dataset(
         return {"text": texts}
 
     # Apply formatting to both train and validation sets
-    dataset = DatasetDict({
-        "train": dataset["train"].map(formatting_prompts_func, batched=True),
-        "validation": dataset["test"].map(formatting_prompts_func, batched=True)
-    })
+    dataset = DatasetDict(
+        {
+            "train": dataset["train"].map(formatting_prompts_func, batched=True),
+            "validation": dataset["test"].map(
+                formatting_prompts_func, batched=True
+            ),
+        }
+    )
     logger.info("Dataset formatting completed successfully")
 
     return dataset, tokenizer
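
The hunk above also renames the "test" split to "validation" while mapping the formatting function over both splits. The same pattern on toy data, with fmt standing in for the script's formatting_prompts_func:

    from datasets import Dataset, DatasetDict

    split = Dataset.from_dict(
        {"text": [f"example {i}" for i in range(10)]}
    ).train_test_split(test_size=0.1, seed=3407)


    def fmt(batch):
        # Stand-in for formatting_prompts_func: with batched=True the
        # function receives a dict of columns and returns transformed ones.
        return {"text": [t.upper() for t in batch["text"]]}


    # Rebuild the DatasetDict so the "test" key becomes "validation".
    dataset = DatasetDict(
        {
            "train": split["train"].map(fmt, batched=True),
            "validation": split["test"].map(fmt, batched=True),
        }
    )
    print(dataset)  # train: 9 rows, validation: 1 row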
@@ -226,33 +234,33 @@ def main():
     """Main training function."""
     try:
         logger.info("Starting training process...")
-
+
         # Install dependencies
         install_dependencies()
-
+
         # Load model and tokenizer
         model, tokenizer = load_model()
-
+
         # Load and prepare dataset
         dataset, tokenizer = load_and_format_dataset(tokenizer)
-
+
         # Create trainer
         trainer: Trainer = create_trainer(model, tokenizer, dataset)
-
+
         # Train
         logger.info("Starting training...")
         trainer.train()
-
+
         # Save model
         logger.info("Saving final model...")
         trainer.save_model("final_model")
-
+
         # Print final metrics
         final_metrics = trainer.state.log_history[-1]
         logger.info("\nTraining completed!")
         logger.info(f"Final training loss: {final_metrics.get('loss', 'N/A')}")
         logger.info(f"Final validation loss: {final_metrics.get('eval_loss', 'N/A')}")
-
+
     except Exception as e:
         logger.error(f"Error in main training process: {e}")
         raise
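
One caveat on the final-metrics block: after trainer.train(), the last entry of trainer.state.log_history is typically an end-of-run summary that may contain "train_loss" but neither "loss" nor "eval_loss", which is what the .get(..., "N/A") fallbacks guard against. A hedged alternative is to scan the history for the most recent entry that actually has each key (last_metric is a hypothetical helper, not part of this commit):

    def last_metric(log_history, key):
        """Return the most recently logged value for `key`, or 'N/A'."""
        for entry in reversed(log_history):
            if key in entry:
                return entry[key]
        return "N/A"


    # Hypothetical usage with a transformers Trainer:
    # train_loss = last_metric(trainer.state.log_history, "loss")
    # eval_loss = last_metric(trainer.state.log_history, "eval_loss")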
 