Enhance the SmolLM2-135M training script by adding logging, improving error handling, and introducing a validation split of the dataset. Refactor model loading and dataset preparation for clarity and robustness. Update the trainer configuration to evaluate during training and log final metrics.
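As a quick orientation before the full diff of train.py below, here is a minimal, self-contained sketch of the split-and-evaluate pattern this commit wires in: a datasets split via train_test_split, plus periodic evaluation configured through TrainingArguments. The toy data and the output directory are placeholders for illustration, not values taken from the script.

from datasets import Dataset
from transformers import TrainingArguments

# Toy stand-in for the real corpus; train.py loads "xingyaoww/code-act" instead.
toy = Dataset.from_dict({"text": [f"example {i}" for i in range(100)]})

# 90/10 split with a fixed seed, mirroring validation_split = 0.1 in the script.
splits = toy.train_test_split(test_size=0.1, seed=3407)
train_ds, eval_ds = splits["train"], splits["test"]

# Evaluate every few steps and keep the checkpoint with the lowest eval_loss.
args = TrainingArguments(
    output_dir="outputs-sketch",       # placeholder directory
    per_device_train_batch_size=2,
    evaluation_strategy="steps",       # newer transformers spell this eval_strategy
    eval_steps=10,
    save_strategy="steps",
    save_steps=30,                     # must stay a multiple of eval_steps here
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)
print(len(train_ds), len(eval_ds))     # 90 10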
train.py
CHANGED
@@ -14,6 +14,9 @@ To run this script:
 """
 
 import os
+import logging
+from datetime import datetime
+from pathlib import Path
 from typing import Union
 
 from datasets import (
@@ -34,45 +37,86 @@ dtype = (
     None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
 )
 load_in_4bit = True # Use 4bit quantization to reduce memory usage
+validation_split = 0.1 # 10% of data for validation
 
+# Setup logging
+def setup_logging():
+    """Configure logging for the training process."""
+    # Create logs directory if it doesn't exist
+    log_dir = Path("logs")
+    log_dir.mkdir(exist_ok=True)
+
+    # Create a unique log file name with timestamp
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    log_file = log_dir / f"training_{timestamp}.log"
+
+    # Configure logging
+    logging.basicConfig(
+        level=logging.INFO,
+        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+        handlers=[
+            logging.FileHandler(log_file),
+            logging.StreamHandler()
+        ]
+    )
+
+    logger = logging.getLogger(__name__)
+    logger.info(f"Logging initialized. Log file: {log_file}")
+    return logger
+
+logger = setup_logging()
+
+def install_dependencies():
+    """Install required dependencies."""
+    logger.info("Installing dependencies...")
+    try:
+        os.system('pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"')
+        os.system('pip install --no-deps xformers trl peft accelerate bitsandbytes')
+        logger.info("Dependencies installed successfully")
+    except Exception as e:
+        logger.error(f"Error installing dependencies: {e}")
+        raise
 
 
 def load_model() -> tuple[FastLanguageModel, AutoTokenizer]:
     """Load and configure the model."""
+    logger.info("Loading model and tokenizer...")
+    try:
+        model, tokenizer = FastLanguageModel.from_pretrained(
+            model_name="unsloth/SmolLM2-135M-Instruct-bnb-4bit",
+            max_seq_length=max_seq_length,
+            dtype=dtype,
+            load_in_4bit=load_in_4bit,
+        )
+        logger.info("Base model loaded successfully")
 
+        # Configure LoRA
+        model = FastLanguageModel.get_peft_model(
+            model,
+            r=64,
+            target_modules=[
+                "q_proj",
+                "k_proj",
+                "v_proj",
+                "o_proj",
+                "gate_proj",
+                "up_proj",
+                "down_proj",
+            ],
+            lora_alpha=128,
+            lora_dropout=0.05,
+            bias="none",
+            use_gradient_checkpointing="unsloth",
+            random_state=3407,
+            use_rslora=True,
+            loftq_config=None,
+        )
+        logger.info("LoRA configuration applied successfully")
 
+        return model, tokenizer
+    except Exception as e:
+        logger.error(f"Error loading model: {e}")
+        raise
 
 
 def load_and_format_dataset(
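A note on install_dependencies() in the hunk above: os.system() returns the command's exit status instead of raising, so a failed pip install never reaches the except branch and the log still reports success. Below is a sketch of a stricter variant using subprocess, which does raise on a non-zero exit code (same packages as the diff; on a Hugging Face Space, listing them in requirements.txt is the more usual route, so treat the in-script install itself as an assumption about intent):

import subprocess
import sys

def install_dependencies_strict() -> None:
    """Install the same packages as install_dependencies, but fail loudly."""
    unsloth_spec = 'unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git'
    extras = ["xformers", "trl", "peft", "accelerate", "bitsandbytes"]
    # check_call raises CalledProcessError on a non-zero exit code, so the
    # caller's try/except actually observes installation failures.
    subprocess.check_call([sys.executable, "-m", "pip", "install", unsloth_spec])
    subprocess.check_call([sys.executable, "-m", "pip", "install", "--no-deps", *extras])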
@@ -81,36 +125,51 @@ def load_and_format_dataset(
     Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset], AutoTokenizer
 ]:
     """Load and format the training dataset."""
+    logger.info("Loading and formatting dataset...")
+    try:
+        # Load the code-act dataset
+        dataset = load_dataset("xingyaoww/code-act", split="codeact")
+        logger.info(f"Dataset loaded successfully. Size: {len(dataset)} examples")
 
+        # Split into train and validation sets
+        dataset = dataset.train_test_split(test_size=validation_split, seed=3407)
+        logger.info(f"Dataset split into train ({len(dataset['train'])} examples) and validation ({len(dataset['test'])} examples) sets")
 
+        # Configure chat template
+        tokenizer = get_chat_template(
+            tokenizer,
+            chat_template="chatml", # Supports zephyr, chatml, mistral, llama, alpaca, vicuna, vicuna_old, unsloth
+            mapping={
+                "role": "from",
+                "content": "value",
+                "user": "human",
+                "assistant": "gpt",
+            }, # ShareGPT style
+            map_eos_token=True, # Maps <|im_end|> to </s> instead
+        )
+        logger.info("Chat template configured successfully")
 
+        def formatting_prompts_func(examples):
+            convos = examples["conversations"]
+            texts = [
+                tokenizer.apply_chat_template(
+                    convo, tokenize=False, add_generation_prompt=False
+                )
+                for convo in convos
+            ]
+            return {"text": texts}
+
+        # Apply formatting to both train and validation sets
+        dataset = DatasetDict({
+            "train": dataset["train"].map(formatting_prompts_func, batched=True),
+            "validation": dataset["test"].map(formatting_prompts_func, batched=True)
+        })
+        logger.info("Dataset formatting completed successfully")
+
+        return dataset, tokenizer
+    except Exception as e:
+        logger.error(f"Error loading/formatting dataset: {e}")
+        raise
 
 
 def create_trainer(
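For context on the mapping passed to get_chat_template in the hunk above: code-act conversations follow the ShareGPT layout, where each turn is a dict with "from"/"value" keys and "human"/"gpt" speakers, which the mapping translates to the "role"/"content" names the chat template expects. A sketch of the record shape formatting_prompts_func consumes, and roughly how a ChatML rendering looks (the sample turns are invented; the exact string depends on the tokenizer's template):

# A batched input as formatting_prompts_func sees it: one ShareGPT-style conversation.
examples = {
    "conversations": [
        [
            {"from": "human", "value": "Write a function that adds two numbers."},
            {"from": "gpt", "value": "def add(a, b):\n    return a + b"},
        ]
    ]
}

# After get_chat_template(..., chat_template="chatml", mapping=...), applying the
# template to that conversation yields ChatML-style text along these lines:
#   <|im_start|>user
#   Write a function that adds two numbers.<|im_end|>
#   <|im_start|>assistant
#   def add(a, b):
#       return a + b<|im_end|>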
@@ -119,55 +178,84 @@ def create_trainer(
     dataset: Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset],
 ) -> Trainer:
     """Create and configure the SFTTrainer."""
+    logger.info("Creating trainer...")
+    try:
+        trainer = SFTTrainer(
+            model=model,
+            tokenizer=tokenizer,
+            train_dataset=dataset["train"],
+            eval_dataset=dataset["validation"], # Add validation dataset
+            dataset_text_field="text",
+            max_seq_length=max_seq_length,
+            dataset_num_proc=2,
+            packing=False,
+            args=TrainingArguments(
+                per_device_train_batch_size=2,
+                per_device_eval_batch_size=2, # Add evaluation batch size
+                gradient_accumulation_steps=16,
+                warmup_steps=100,
+                max_steps=120,
+                learning_rate=5e-5,
+                fp16=not is_bfloat16_supported(),
+                bf16=is_bfloat16_supported(),
+                logging_steps=1,
+                evaluation_strategy="steps", # Add evaluation strategy
+                eval_steps=10, # Evaluate every 10 steps
+                save_strategy="steps",
+                save_steps=30,
+                save_total_limit=2,
+                optim="adamw_8bit",
+                weight_decay=0.01,
+                lr_scheduler_type="cosine_with_restarts",
+                seed=3407,
+                output_dir="outputs",
+                gradient_checkpointing=True,
+                load_best_model_at_end=True, # Load best model at the end
+                metric_for_best_model="eval_loss", # Use validation loss for model selection
+                greater_is_better=False, # Lower loss is better
+            ),
+        )
+        logger.info("Trainer created successfully")
+        return trainer
+    except Exception as e:
+        logger.error(f"Error creating trainer: {e}")
+        raise
 
 
 def main():
     """Main training function."""
+    try:
+        logger.info("Starting training process...")
+
+        # Install dependencies
+        install_dependencies()
+
+        # Load model and tokenizer
+        model, tokenizer = load_model()
+
+        # Load and prepare dataset
+        dataset, tokenizer = load_and_format_dataset(tokenizer)
+
+        # Create trainer
+        trainer: Trainer = create_trainer(model, tokenizer, dataset)
+
+        # Train
+        logger.info("Starting training...")
+        trainer.train()
+
+        # Save model
+        logger.info("Saving final model...")
+        trainer.save_model("final_model")
+
+        # Print final metrics
+        final_metrics = trainer.state.log_history[-1]
+        logger.info("\nTraining completed!")
+        logger.info(f"Final training loss: {final_metrics.get('loss', 'N/A')}")
+        logger.info(f"Final validation loss: {final_metrics.get('eval_loss', 'N/A')}")
+
+    except Exception as e:
+        logger.error(f"Error in main training process: {e}")
+        raise
 
 
 if __name__ == "__main__":
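Two version-sensitive details in the trainer hunk, flagged here because they are a plausible but unconfirmed cause of the Space's build error (the build log is not shown): newer transformers releases rename evaluation_strategy to eval_strategy, and newer trl releases move SFTTrainer options such as dataset_text_field and packing into SFTConfig. Note also the effective batch size implied above: per_device_train_batch_size * gradient_accumulation_steps = 2 * 16 = 32 sequences per optimizer step. A hedged sketch of the newer-API spelling, assuming recent transformers/trl; pinning older versions is the alternative if the argument names in the diff are intended:

from trl import SFTConfig  # newer trl: a TrainingArguments subclass that also carries SFT options

config = SFTConfig(
    output_dir="outputs",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=16,  # effective batch size: 2 * 16 = 32
    max_steps=120,
    eval_strategy="steps",           # newer spelling of evaluation_strategy
    eval_steps=10,
    dataset_text_field="text",       # moved from the SFTTrainer call into the config
    packing=False,
)
# SFTTrainer would then take args=config in place of args=TrainingArguments(...).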