#!/usr/bin/env python3
"""
Fine-tuning script for SmolLM2-135M model using Unsloth.
This script demonstrates how to:
1. Install and configure Unsloth
2. Prepare and format training data
3. Configure and run the training process
4. Save and evaluate the model
To run this script:
1. Install dependencies: pip install -r requirements.txt
2. Run: python train.py
"""
import logging
import subprocess
from datetime import datetime
from pathlib import Path
from typing import Union
from datasets import (
Dataset,
DatasetDict,
IterableDataset,
IterableDatasetDict,
load_dataset,
)
from transformers import AutoTokenizer, Trainer, TrainingArguments
from trl import SFTTrainer
from unsloth import FastLanguageModel, is_bfloat16_supported
from unsloth.chat_templates import get_chat_template
# Configuration
max_seq_length = 2048  # Unsloth auto-supports RoPE scaling internally
dtype = None  # None for auto-detection; float16 for Tesla T4/V100, bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage
validation_split = 0.1 # 10% of data for validation
# Setup logging
def setup_logging():
"""Configure logging for the training process."""
# Create logs directory if it doesn't exist
log_dir = Path("logs")
log_dir.mkdir(exist_ok=True)
# Create a unique log file name with timestamp
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
log_file = log_dir / f"training_{timestamp}.log"
# Configure logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=[logging.FileHandler(log_file), logging.StreamHandler()],
)
logger = logging.getLogger(__name__)
logger.info(f"Logging initialized. Log file: {log_file}")
return logger
logger = setup_logging()
def install_dependencies():
"""Install required dependencies."""
logger.info("Installing dependencies...")
    try:
        # subprocess.check_call raises CalledProcessError on a non-zero exit
        # code (unlike os.system), so failed installs are actually caught below.
        subprocess.check_call(
            'pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"',
            shell=True,
        )
        subprocess.check_call(
            "pip install --no-deps xformers trl peft accelerate bitsandbytes",
            shell=True,
        )
logger.info("Dependencies installed successfully")
except Exception as e:
logger.error(f"Error installing dependencies: {e}")
raise
def load_model() -> tuple[FastLanguageModel, AutoTokenizer]:
"""Load and configure the model."""
logger.info("Loading model and tokenizer...")
try:
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="unsloth/SmolLM2-135M-Instruct-bnb-4bit",
max_seq_length=max_seq_length,
dtype=dtype,
load_in_4bit=load_in_4bit,
)
logger.info("Base model loaded successfully")
# Configure LoRA
model = FastLanguageModel.get_peft_model(
model,
r=64,
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
],
lora_alpha=128,
lora_dropout=0.05,
bias="none",
use_gradient_checkpointing="unsloth",
random_state=3407,
use_rslora=True,
loftq_config=None,
)
logger.info("LoRA configuration applied successfully")
return model, tokenizer
except Exception as e:
logger.error(f"Error loading model: {e}")
raise
def load_and_format_dataset(
tokenizer: AutoTokenizer,
) -> tuple[
Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset], AutoTokenizer
]:
"""Load and format the training dataset."""
logger.info("Loading and formatting dataset...")
try:
# Load the code-act dataset
dataset = load_dataset("xingyaoww/code-act", split="codeact")
logger.info(f"Dataset loaded successfully. Size: {len(dataset)} examples")
# Split into train and validation sets
dataset = dataset.train_test_split(test_size=validation_split, seed=3407)
logger.info(
f"Dataset split into train ({len(dataset['train'])} examples) and validation ({len(dataset['test'])} examples) sets"
)
# Configure chat template
tokenizer = get_chat_template(
tokenizer,
chat_template="chatml", # Supports zephyr, chatml, mistral, llama, alpaca, vicuna, vicuna_old, unsloth
mapping={
"role": "from",
"content": "value",
"user": "human",
"assistant": "gpt",
}, # ShareGPT style
map_eos_token=True, # Maps <|im_end|> to </s> instead
)
logger.info("Chat template configured successfully")
def formatting_prompts_func(examples):
convos = examples["conversations"]
texts = [
tokenizer.apply_chat_template(
convo, tokenize=False, add_generation_prompt=False
)
for convo in convos
]
return {"text": texts}
# Apply formatting to both train and validation sets
dataset = DatasetDict(
{
"train": dataset["train"].map(formatting_prompts_func, batched=True),
"validation": dataset["test"].map(
formatting_prompts_func, batched=True
),
}
)
logger.info("Dataset formatting completed successfully")
return dataset, tokenizer
except Exception as e:
logger.error(f"Error loading/formatting dataset: {e}")
raise
def create_trainer(
model: FastLanguageModel,
tokenizer: AutoTokenizer,
dataset: Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset],
) -> Trainer:
"""Create and configure the SFTTrainer."""
logger.info("Creating trainer...")
try:
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset["train"],
eval_dataset=dataset["validation"],
dataset_text_field="text",
max_seq_length=max_seq_length,
dataset_num_proc=2,
packing=False,
            args=TrainingArguments(
                per_device_train_batch_size=2,
                per_device_eval_batch_size=2,
                gradient_accumulation_steps=16,  # effective train batch size: 2 * 16 = 32
                warmup_steps=100,
                max_steps=120,
                learning_rate=5e-5,
                fp16=not is_bfloat16_supported(),
                bf16=is_bfloat16_supported(),
                logging_steps=1,
                evaluation_strategy="steps",  # required for eval_steps and load_best_model_at_end
                eval_steps=10,  # Evaluate every 10 steps
                save_steps=30,  # must be a multiple of eval_steps for load_best_model_at_end
                save_total_limit=2,
                optim="adamw_8bit",
                weight_decay=0.01,
                lr_scheduler_type="cosine_with_restarts",
                seed=3407,
                output_dir="outputs",
                gradient_checkpointing=True,
                load_best_model_at_end=True,
                metric_for_best_model="eval_loss",
                greater_is_better=False,
            ),
)
logger.info("Trainer created successfully")
return trainer
except Exception as e:
logger.error(f"Error creating trainer: {e}")
raise
def main():
"""Main training function."""
try:
logger.info("Starting training process...")
# Install dependencies
install_dependencies()
# Load model and tokenizer
model, tokenizer = load_model()
# Load and prepare dataset
dataset, tokenizer = load_and_format_dataset(tokenizer)
# Create trainer
trainer: Trainer = create_trainer(model, tokenizer, dataset)
# Train
logger.info("Starting training...")
trainer.train()
# Save model
logger.info("Saving final model...")
trainer.save_model("final_model")
        # Report final metrics. The last log_history entry is the train summary
        # (keyed "train_loss"); eval losses appear in earlier entries.
        final_metrics = trainer.state.log_history[-1]
        eval_losses = [
            h["eval_loss"] for h in trainer.state.log_history if "eval_loss" in h
        ]
        logger.info("\nTraining completed!")
        logger.info(f"Final training loss: {final_metrics.get('train_loss', 'N/A')}")
        logger.info(
            f"Final validation loss: {eval_losses[-1] if eval_losses else 'N/A'}"
        )
except Exception as e:
logger.error(f"Error in main training process: {e}")
raise
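

def demo_inference(model_dir: str = "final_model", prompt: str = "Hello!") -> None:
    """Sketch of how the saved adapters could be used for generation.

    Not called by default. Assumes trainer.save_model() wrote a directory that
    FastLanguageModel.from_pretrained can reload, and uses ShareGPT-style
    message keys ("from"/"value") to match the chat template mapping configured
    above; model_dir and prompt are illustrative values.
    """
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_dir,
        max_seq_length=max_seq_length,
        dtype=dtype,
        load_in_4bit=load_in_4bit,
    )
    FastLanguageModel.for_inference(model)  # enable Unsloth's faster generation path
    input_ids = tokenizer.apply_chat_template(
        [{"from": "human", "value": prompt}],
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)
    output_ids = model.generate(input_ids=input_ids, max_new_tokens=256)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
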
if __name__ == "__main__":
main()