Refactor train.py to improve code readability and organization. Adjust logging setup for clarity, streamline dependency installation commands, and enhance dataset splitting and formatting processes. Ensure consistent formatting in log messages and code structure.
train.py
CHANGED
@@ -13,8 +13,8 @@ To run this script:
 2. Run: python train.py
 """
 
-import os
 import logging
+import os
 from datetime import datetime
 from pathlib import Path
 from typing import Union
@@ -39,39 +39,41 @@ dtype = (
 load_in_4bit = True  # Use 4bit quantization to reduce memory usage
 validation_split = 0.1  # 10% of data for validation
 
+
 # Setup logging
 def setup_logging():
     """Configure logging for the training process."""
     # Create logs directory if it doesn't exist
     log_dir = Path("logs")
     log_dir.mkdir(exist_ok=True)
-
+
     # Create a unique log file name with timestamp
     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
     log_file = log_dir / f"training_{timestamp}.log"
-
+
     # Configure logging
     logging.basicConfig(
         level=logging.INFO,
-        format=
-        handlers=[
-            logging.FileHandler(log_file),
-            logging.StreamHandler()
-        ]
+        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+        handlers=[logging.FileHandler(log_file), logging.StreamHandler()],
     )
-
+
     logger = logging.getLogger(__name__)
     logger.info(f"Logging initialized. Log file: {log_file}")
     return logger
 
+
 logger = setup_logging()
 
+
 def install_dependencies():
     """Install required dependencies."""
     logger.info("Installing dependencies...")
     try:
-        os.system(
-
+        os.system(
+            'pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"'
+        )
+        os.system("pip install --no-deps xformers trl peft accelerate bitsandbytes")
         logger.info("Dependencies installed successfully")
     except Exception as e:
         logger.error(f"Error installing dependencies: {e}")
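A note on the new install block: os.system() reports failure through its integer exit status and never raises, so the try/except above will not actually catch a failed pip install. A minimal sketch of a stricter variant using subprocess.run(check=True), which does raise on a non-zero exit (illustrative only, not part of this commit):

import logging
import subprocess

logger = logging.getLogger(__name__)


def install_dependencies():
    """Install required dependencies, raising if any pip command fails."""
    logger.info("Installing dependencies...")
    commands = [
        'pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"',
        "pip install --no-deps xformers trl peft accelerate bitsandbytes",
    ]
    try:
        for cmd in commands:
            # check=True raises CalledProcessError on a non-zero exit status
            subprocess.run(cmd, shell=True, check=True)
        logger.info("Dependencies installed successfully")
    except subprocess.CalledProcessError as e:
        logger.error(f"Error installing dependencies: {e}")
        raise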
@@ -133,7 +135,9 @@ def load_and_format_dataset(
 
     # Split into train and validation sets
     dataset = dataset.train_test_split(test_size=validation_split, seed=3407)
-    logger.info(
+    logger.info(
+        f"Dataset split into train ({len(dataset['train'])} examples) and validation ({len(dataset['test'])} examples) sets"
+    )
 
     # Configure chat template
     tokenizer = get_chat_template(
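For context on the new log message: train_test_split from the datasets library returns a DatasetDict keyed "train" and "test", which is why the validation count is read from dataset["test"]. A quick toy illustration (assumes the datasets package; not part of the commit):

from datasets import Dataset

# 10 toy rows; test_size=0.1 yields 9 train / 1 test
ds = Dataset.from_dict({"text": [f"example {i}" for i in range(10)]})
ds = ds.train_test_split(test_size=0.1, seed=3407)
print(len(ds["train"]), len(ds["test"]))  # -> 9 1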
@@ -160,10 +164,14 @@ def load_and_format_dataset(
         return {"text": texts}
 
     # Apply formatting to both train and validation sets
-    dataset = DatasetDict(
-
-
-
+    dataset = DatasetDict(
+        {
+            "train": dataset["train"].map(formatting_prompts_func, batched=True),
+            "validation": dataset["test"].map(
+                formatting_prompts_func, batched=True
+            ),
+        }
+    )
     logger.info("Dataset formatting completed successfully")
 
     return dataset, tokenizer
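The DatasetDict block above is a layout-only change: map(..., batched=True) still hands formatting_prompts_func a dict of column lists and expects a dict of columns back. Since formatting_prompts_func itself is outside this hunk, the sketch below uses a hypothetical stand-in just to show that contract:

from datasets import Dataset, DatasetDict


def formatting_prompts_func(examples):
    # Batched mapping: receives {"text": [str, ...]}, returns replacement columns
    texts = [f"### Input: {t}\n### Response:" for t in examples["text"]]
    return {"text": texts}


raw = Dataset.from_dict({"text": ["hello", "world"]}).train_test_split(
    test_size=0.5, seed=3407
)
dataset = DatasetDict(
    {
        "train": raw["train"].map(formatting_prompts_func, batched=True),
        "validation": raw["test"].map(formatting_prompts_func, batched=True),
    }
)
print(dataset["validation"][0]["text"])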
@@ -226,33 +234,33 @@ def main():
     """Main training function."""
     try:
         logger.info("Starting training process...")
-
+
         # Install dependencies
         install_dependencies()
-
+
         # Load model and tokenizer
         model, tokenizer = load_model()
-
+
         # Load and prepare dataset
         dataset, tokenizer = load_and_format_dataset(tokenizer)
-
+
         # Create trainer
         trainer: Trainer = create_trainer(model, tokenizer, dataset)
-
+
         # Train
         logger.info("Starting training...")
         trainer.train()
-
+
         # Save model
         logger.info("Saving final model...")
         trainer.save_model("final_model")
-
+
         # Print final metrics
         final_metrics = trainer.state.log_history[-1]
         logger.info("\nTraining completed!")
         logger.info(f"Final training loss: {final_metrics.get('loss', 'N/A')}")
         logger.info(f"Final validation loss: {final_metrics.get('eval_loss', 'N/A')}")
-
+
     except Exception as e:
         logger.error(f"Error in main training process: {e}")
         raise
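One caution on the final-metrics lines in main(): with transformers-style trainers, trainer.state.log_history mixes training logs, eval logs, and a run summary, so the last entry is not guaranteed to contain both 'loss' and 'eval_loss'; the .get(..., 'N/A') fallback masks that. A hypothetical helper that scans backwards for the most recent entry carrying each key (a sketch, not part of the commit):

def last_metric(log_history, key, default="N/A"):
    """Return the most recent value of `key` across heterogeneous log entries."""
    for entry in reversed(log_history):
        if key in entry:
            return entry[key]
    return default


# Usage after trainer.train():
#   train_loss = last_metric(trainer.state.log_history, "loss")
#   eval_loss = last_metric(trainer.state.log_history, "eval_loss")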