Spaces:
Build error
Build error
#!/usr/bin/env python | |
# coding=utf-8 | |
# Basic Python imports | |
import os | |
import sys | |
import json | |
import argparse | |
import logging | |
from datetime import datetime | |
import time | |
import warnings | |
from importlib.util import find_spec | |
# Probe the hardware once at import time and expose the result as module-level
# flags (CUDA availability, visible GPU count, and a device-type string).
import torch

CUDA_AVAILABLE = torch.cuda.is_available()
if CUDA_AVAILABLE:
    NUM_GPUS = torch.cuda.device_count()
    DEVICE_TYPE = "cuda"
else:
    NUM_GPUS = 0
    DEVICE_TYPE = "cpu"
# Configure logging early: INFO and above, timestamped, written to stdout.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)

# Limit chatty third-party loggers to WARNING and above so this script's own
# INFO messages remain visible.
for _noisy_logger_name in ("transformers", "datasets", "accelerate", "torch", "bitsandbytes"):
    logging.getLogger(_noisy_logger_name).setLevel(logging.WARNING)
del _noisy_logger_name
# Import Unsloth first, before the other ML libraries (check_dependencies also
# warns if transformers ends up imported ahead of unsloth).
try:
    from unsloth import FastLanguageModel
    from unsloth.chat_templates import get_chat_template
    unsloth_available = True
    logger.info("Unsloth successfully imported")
except (ImportError, NotImplementedError) as unsloth_import_error:
    # NOTE(review): some Unsloth releases raise NotImplementedError (not
    # ImportError) at import time when no supported GPU is present -- confirm
    # for the pinned version. Both are treated as "unavailable" so the problem
    # is reported by check_dependencies / load_model_and_tokenizer instead of
    # aborting the module import with an unexplained traceback.
    unsloth_available = False
    logger.warning(f"Unsloth import failed: {unsloth_import_error}")
    logger.warning("Unsloth not available. Please install with: pip install unsloth")

# Now import other ML libraries. transformers is mandatory: TrainerCallback
# (base class of LoggingCallback) and set_seed are used unconditionally later
# in this file, so a failed import is logged and re-raised here rather than
# surfacing afterwards as an unrelated NameError.
try:
    import transformers
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        TrainingArguments,
        Trainer,
        TrainerCallback,
        set_seed,
        BitsAndBytesConfig,
    )
    logger.info(f"Transformers version: {transformers.__version__}")
except ImportError:
    logger.error("Transformers not available. This is a critical dependency.")
    raise
# PEFT is optional: probe for it without importing, and import it only to
# report its version when it is present.
peft_available = find_spec("peft") is not None
if not peft_available:
    logger.warning("PEFT not available. Parameter-efficient fine-tuning will not be used.")
else:
    import peft
    logger.info(f"PEFT version: {peft.__version__}")

# `datasets` provides load_dataset(). A failed import is logged but not fatal
# at this point; load_dataset is then simply left undefined.
try:
    from datasets import load_dataset
except ImportError:
    logger.error("Datasets library not available. This is required for loading training data.")
else:
    logger.info("Datasets library successfully imported")
# Logging helper used for progress messages that should appear immediately.
def log_info(message):
    """Log *message* at INFO level via the module logger and flush stdout.

    Flushing right away keeps lines appearing promptly when stdout is being
    streamed (for example in a Hugging Face Space log view).
    """
    logger.info(message)
    stream = sys.stdout
    stream.flush()
# Optional 4-bit quantization support. BitsAndBytesConfig is imported from
# transformers.
# NOTE(review): this import likely succeeds whenever transformers is installed,
# even if the bitsandbytes package itself is missing -- confirm before relying
# on this flag.
try:
    from transformers import BitsAndBytesConfig
except ImportError:
    bitsandbytes_available = False
    logger.warning("BitsAndBytes not available. 4-bit quantization will not be used.")
else:
    bitsandbytes_available = True

# Re-check PEFT by importing the LoRA helpers. This overrides the
# find_spec-based peft_available flag computed earlier in the file.
try:
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
except ImportError:
    peft_available = False
    logger.warning("PEFT not available. Parameter-efficient fine-tuning will not be used.")
else:
    peft_available = True
def _load_local_env_file():
    """Load the first existing ``.env`` (script dir, then ``../shared``) via python-dotenv."""
    try:
        from dotenv import load_dotenv
    except ImportError:
        logger.warning("python-dotenv not installed, not loading from .env file")
        return

    script_dir = os.path.dirname(os.path.abspath(__file__))
    # (path, label) pairs in lookup order; the label only appears in log lines.
    candidates = (
        (os.path.join(script_dir, ".env"), ".env file"),
        (os.path.join(os.path.dirname(script_dir), "shared", ".env"), "shared .env file"),
    )
    for env_path, label in candidates:
        if os.path.exists(env_path):
            load_dotenv(env_path)
            logger.info(f"Loaded environment variables from {env_path}")
            # Report presence only -- never log secret values.
            for var_name in ("HF_TOKEN", "HF_USERNAME", "HF_SPACE_NAME"):
                logger.info(f"{var_name} loaded from {label}: {bool(os.environ.get(var_name))}")
            return
    logger.warning("No .env file found in current or shared directory")


def load_env_variables():
    """Populate Hugging Face settings from the environment or a ``.env`` file.

    When running inside a Hugging Face Space (``SPACE_ID`` is set), the
    Space's own variables are used. If ``HF_USERNAME`` is missing, it is
    derived from the ``owner/name`` form of ``SPACE_ID``.

    Otherwise, a ``.env`` file in this script's directory is loaded, falling
    back to ``../shared/.env``. This needs python-dotenv and is skipped with a
    warning when it is not installed.

    Finally, warns about any missing ``HF_TOKEN`` / ``HF_USERNAME`` /
    ``HF_SPACE_NAME`` and copies ``HF_TOKEN`` into
    ``HUGGING_FACE_HUB_TOKEN``. Only the presence of secrets is logged,
    never their values.
    """
    if os.environ.get("SPACE_ID"):
        logger.info("Running in Hugging Face Space")
        logger.info(f"HF_TOKEN available: {bool(os.environ.get('HF_TOKEN'))}")
        logger.info(f"HF_USERNAME available: {bool(os.environ.get('HF_USERNAME'))}")
        space_id = os.environ.get("SPACE_ID", "")
        if not os.environ.get("HF_USERNAME") and "/" in space_id:
            username = space_id.split("/")[0]
            os.environ["HF_USERNAME"] = username
            logger.info(f"Set HF_USERNAME from SPACE_ID: {username}")
    else:
        _load_local_env_file()

    if not os.environ.get("HF_TOKEN"):
        logger.warning("HF_TOKEN is not set. Pushing to Hugging Face Hub will not work.")
    if not os.environ.get("HF_USERNAME"):
        logger.warning("HF_USERNAME is not set. Using default username.")
    if not os.environ.get("HF_SPACE_NAME"):
        logger.warning("HF_SPACE_NAME is not set. Using default space name.")

    # Mirror the token under the variable name huggingface_hub reads.
    if os.environ.get("HF_TOKEN"):
        os.environ["HUGGING_FACE_HUB_TOKEN"] = os.environ.get("HF_TOKEN")
def load_configs(base_path):
    """Load the consolidated JSON training configuration.

    Args:
        base_path: Path to the JSON config file (for example
            ``transformers_config.json``).

    Returns:
        The parsed JSON document.

    Raises:
        Any error from opening, decoding or parsing the file (e.g. OSError,
        UnicodeDecodeError, json.JSONDecodeError) is logged and re-raised.
    """
    config_file = base_path
    try:
        # JSON text is UTF-8 by specification; don't depend on the platform's
        # locale encoding.
        with open(config_file, "r", encoding="utf-8") as f:
            config = json.load(f)
    except Exception as e:
        logger.error(f"Error loading {config_file}: {e}")
        raise
    logger.info(f"Loaded configuration from {config_file}")
    return config
def parse_args(argv=None):
    """Parse the training script's command-line options.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before; passing a list lets other
            Python code (and tests) drive the parser directly.

    Returns:
        argparse.Namespace with ``config``, the path to the JSON config file.
    """
    parser = argparse.ArgumentParser(description="Fine-tune a language model on a text dataset")
    parser.add_argument("--config", type=str, default="transformers_config.json", help="Path to configuration file")
    return parser.parse_args(argv)
def load_model_and_tokenizer(config):
    """Load the model and tokenizer through Unsloth and attach LoRA adapters.

    Args:
        config: The consolidated training config dict. Where two keys are
            listed, the top-level key is tried first and the nested one is the
            fallback:
            - ``model_name`` / ``model.name``
            - ``max_seq_length`` / ``tokenizer.max_seq_length`` (default 2048)
            - ``gradient_checkpointing`` / ``training.gradient_checkpointing``
              (default True)
            - ``chat_template`` / ``tokenizer.chat_template``
            Also read: ``use_flash_attention``,
            ``hardware.hardware_setup.device_map``,
            ``hardware.specs.vram_per_gpu``,
            ``unsloth.{r, alpha, dropout, target_modules}`` and ``seed``.

    Returns:
        ``(model, tokenizer)``, where ``model`` is the value returned by
        ``FastLanguageModel.get_peft_model``.

    Raises:
        ImportError: if Unsloth is not available.
        ValueError: if no model name is configured.
        Any other loading error is logged and re-raised.
    """
    try:
        if not unsloth_available:
            logger.error("Unsloth is required for training with pre-quantized model")
            logger.error("Please ensure unsloth is in requirements.txt")
            raise ImportError("Unsloth is required for this training setup")

        # Get model name correctly from config
        model_name = config.get("model_name") or config.get("model", {}).get("name")
        logger.info(f"Loading model: {model_name}")
        if not model_name:
            raise ValueError("Model name not found in configuration. Please check your transformers_config.json file.")
        logger.info("Using Unsloth optimizations with pre-quantized model")

        # Settings that may live at the top level or in a nested section. The
        # top-level lookups deliberately carry no default: a default there
        # (as before) made the nested fallback unreachable, and for
        # gradient_checkpointing made an explicit top-level False impossible.
        max_seq_length = config.get("max_seq_length") or config.get("tokenizer", {}).get("max_seq_length", 2048)
        gradient_checkpointing = config.get("gradient_checkpointing")
        if gradient_checkpointing is None:
            gradient_checkpointing = config.get("training", {}).get("gradient_checkpointing", True)

        # First detect if we have a GPU
        if torch.cuda.is_available():
            gpu_count = torch.cuda.device_count()
            logger.info(f"Found {gpu_count} CUDA devices")
        else:
            logger.warning("No CUDA devices detected. Training will be slow on CPU!")
            gpu_count = 0

        # Pick the compute dtype: bf16 on compute capability >= 8 (Ampere+),
        # fp16 on older GPUs, library default on CPU.
        if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
            dtype = torch.bfloat16
            logger.info("Using bfloat16 precision (Ampere+ GPU)")
        elif torch.cuda.is_available():
            dtype = torch.float16
            logger.info("Using float16 precision (pre-Ampere GPU)")
        else:
            dtype = None
            logger.info("Using default precision (CPU)")

        # Flash attention is only reported on here; Unsloth decides whether
        # to use it, so no flag is passed to from_pretrained.
        if config.get("use_flash_attention", True) and not find_spec("flash_attn"):
            logger.warning("flash-attn not found. Will continue without flash attention.")
            logger.warning("To use flash attention, install with: pip install flash-attn --no-build-isolation")

        # Set device map based on config or default to "auto"
        device_map = config.get("hardware", {}).get("hardware_setup", {}).get("device_map", "auto")

        # With several GPUs, cap each at 85% of the configured VRAM.
        max_memory = None
        if gpu_count > 1:
            memory_per_gpu = config.get("hardware", {}).get("specs", {}).get("vram_per_gpu", 24)
            max_memory = {i: f"{int(memory_per_gpu * 0.85)}GiB" for i in range(gpu_count)}
            max_memory["cpu"] = "64GiB"  # Allow CPU offloading if needed

        try:
            # NOTE(review): PyTorch reads PYTORCH_CUDA_ALLOC_CONF when its CUDA
            # allocator initialises; CUDA has already been queried above, so
            # this may have no effect -- consider setting it before importing
            # torch.
            os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
            model, tokenizer = FastLanguageModel.from_pretrained(
                model_name=model_name,
                max_seq_length=max_seq_length,
                dtype=dtype,
                device_map=device_map,
                max_memory=max_memory,
                # Don't explicitly use flash attention config here, let Unsloth handle it
            )
        except RuntimeError as e:
            if "CUDA out of memory" in str(e):
                logger.error("Out of GPU memory. Consider using a smaller batch size or gradient accumulation steps.")
                raise
            # Any other runtime error: retry entirely on CPU with the default dtype.
            logger.warning(f"Error loading model on default device: {str(e)}")
            logger.warning("Attempting to load with device_map='cpu' and no specific dtype")
            model, tokenizer = FastLanguageModel.from_pretrained(
                model_name=model_name,
                max_seq_length=max_seq_length,
                dtype=None,
                device_map={"": "cpu"},
            )
            logger.warning("Model loaded on CPU. Training will be very slow.")

        logger.info(f"Model device map: {model.hf_device_map if hasattr(model, 'hf_device_map') else 'Not available'}")

        # Apply Unsloth's LoRA adapters with config parameters
        unsloth_config = config.get("unsloth", {})
        model = FastLanguageModel.get_peft_model(
            model,
            r=unsloth_config.get("r", 32),
            target_modules=unsloth_config.get(
                "target_modules",
                ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            ),
            lora_alpha=unsloth_config.get("alpha", 16),
            lora_dropout=unsloth_config.get("dropout", 0.05),
            bias="none",
            use_gradient_checkpointing=gradient_checkpointing,
            random_state=config.get("seed", 42),
        )
        logger.info("Unsloth optimizations applied successfully")

        chat_template = config.get("chat_template") or config.get("tokenizer", {}).get("chat_template")
        if chat_template:
            try:
                # Unsloth's get_chat_template takes the tokenizer plus a template
                # name and returns the updated tokenizer. The previous call,
                # get_chat_template("phi"), passed the name in the tokenizer
                # slot, so the template was never applied.
                tokenizer = get_chat_template(tokenizer, chat_template=chat_template)
                logger.info(f"Set chat template: {chat_template}")
            except Exception as e:
                logger.warning(f"Failed to set chat template: {str(e)}")

        # Padding needs a pad id; reuse EOS when the tokenizer has none.
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
            logger.info(f"Set pad_token_id to eos_token_id: {tokenizer.pad_token_id}")

        return model, tokenizer

    except Exception as e:
        logger.error(f"Error in model/tokenizer loading: {str(e)}")
        logger.error("If missing dependencies, check the requirements.txt file")
        raise
def load_dataset_with_mapping(dataset_config):
    """Load the pre-processed training split and sanity-check its layout.

    No column mapping or reformatting is applied; this function only loads,
    validates and logs. Steps:
    - Read ``dataset.name`` (required) and ``dataset.split`` (default
      ``"train"``) and load that split.
    - Check for the ``prompt_number`` / ``article_id`` / ``conversations``
      columns and their order.
    - Log up to five sample records.
    - Check whether the first ten ``prompt_number`` values run 1, 2, 3, ...
    - Force ``data_loading.shuffle`` back to False if it was enabled.

    Returns:
        The loaded dataset split.

    Raises:
        ValueError: if no dataset name is configured. Any error is logged and
            re-raised.
    """
    expected_columns = ["prompt_number", "article_id", "conversations"]
    try:
        source = dataset_config.get("dataset", {})
        name = source.get("name", "")
        split = source.get("split", "train")
        if not name:
            raise ValueError("Dataset name not provided in configuration")

        logger.info(f"Loading pre-processed dataset {name}, split {split}")
        dataset = load_dataset(name, split=split)
        columns = dataset.column_names
        size = len(dataset)

        # Validation only: report missing columns but keep going.
        absent = [column for column in expected_columns if column not in columns]
        if absent:
            logger.warning(f"Dataset is missing required fields: {absent}")
            logger.warning("This may cause issues with sequence integrity and metadata management")
        else:
            logger.info(f"Dataset has all required fields: {expected_columns}")

        if columns != expected_columns:
            logger.warning(f"Dataset column order ({', '.join(columns)}) differs from expected order ({', '.join(expected_columns)})")
            logger.warning("This should not affect processing but is noted for debugging purposes")
        else:
            logger.info("Dataset column order matches expected order (prompt_number, article_id, conversations)")

        if size > 0:
            previews = []
            for idx in range(min(5, size)):
                row = dataset[idx]
                preview = {
                    "prompt_number": row.get("prompt_number", "N/A"),
                    "article_id": row.get("article_id", "N/A"),
                }
                if "conversations" in row:
                    preview["conversations_length"] = len(row["conversations"])
                previews.append(preview)
            logger.info(f"Sample records: {previews}")

        # Sequence check: the first (up to) ten prompt numbers should be 1, 2, 3, ...
        if "prompt_number" in columns and size > 1:
            head = [dataset[idx]["prompt_number"] for idx in range(min(10, size))]
            if all(value == position for position, value in enumerate(head, start=1)):
                logger.info("Dataset prompt numbers are sequential (1-indexed) - sequence integrity preserved")
            else:
                logger.warning("Dataset prompt numbers are not sequential - sequence integrity may be compromised")
                logger.info(f"First few prompt numbers: {head}")

        logger.info(f"Dataset loaded successfully with {size} examples")
        logger.info(f"Dataset columns: {columns}")

        # Shuffling would break the sequential order checked above; force it off.
        loading_options = dataset_config.get("data_loading", {})
        if loading_options.get("shuffle", False):
            logger.error("CRITICAL: shuffle is enabled in the dataset config!")
            logger.error("This will RANDOMIZE your dataset and break sequential order.")
            logger.error("Setting shuffle to False to preserve order")
            loading_options["shuffle"] = False

        return dataset

    except Exception as e:
        logger.error(f"Error loading dataset: {str(e)}")
        raise
def format_phi_chat(messages, dataset_config):
    """Render a list of chat messages as a single plain-text prompt.

    Role templates come from ``dataset_config["data_formatting"]["roles"]``,
    a mapping of role to a template containing ``{content}``. The defaults are
    ``System:``, ``Human:`` and ``Assistant:`` prefixes.

    Ordering rules:
    - The first dict whose string content contains ``[RESEARCH INTRODUCTION]``
      is rendered with the system template and placed first.
    - ``system`` messages are prepended to everything rendered so far.
    - ``human``/``user`` and ``assistant``/``bot`` messages are appended in
      order. Unknown roles are appended using the system template, with a
      warning.
    - Non-dict entries, and dicts without ``content``, are skipped with a
      warning.

    Returns:
        The rendered text with surrounding whitespace stripped.
    """
    formatted_chat = ""
    roles = dataset_config.get("data_formatting", {}).get("roles", {
        "system": "System: {content}\n\n",
        "human": "Human: {content}\n\n",
        "user": "Human: {content}\n\n",
        "assistant": "Assistant: {content}\n\n",
    })
    system_template = roles.get("system", "System: {content}\n\n")
    user_template = roles.get("user", roles.get("human", "Human: {content}\n\n"))
    assistant_template = roles.get("assistant", "Assistant: {content}\n\n")

    # Only string content can carry the marker; testing `in` against None or
    # an int used to raise TypeError.
    metadata = next(
        (
            msg for msg in messages
            if isinstance(msg, dict)
            and isinstance(msg.get("content"), str)
            and "[RESEARCH INTRODUCTION]" in msg["content"]
        ),
        None,
    )
    if metadata is not None:
        formatted_chat = system_template.format(content=metadata["content"])
        # Drop exactly this message, by identity. Filtering with != used to
        # discard every message equal to it, silently losing duplicates.
        messages = [msg for msg in messages if msg is not metadata]

    for message in messages:
        if not isinstance(message, dict) or "content" not in message:
            logger.warning(f"Skipping invalid message format: {message}")
            continue

        # `or ""` guards against an explicit None role, which crashed .lower().
        role = str(message.get("role") or "").lower()
        content = message.get("content", "")

        if role in ("human", "user"):
            formatted_chat += user_template.format(content=content)
        elif role in ("assistant", "bot"):
            formatted_chat += assistant_template.format(content=content)
        elif role == "system":
            # System messages go in front of everything rendered so far.
            formatted_chat = system_template.format(content=content) + formatted_chat
        else:
            logger.warning(f"Unknown role '{role}' - treating as system message")
            formatted_chat += system_template.format(content=content)

    return formatted_chat.strip()
class SimpleDataCollator:
    """Tokenize pre-structured chat examples into padded causal-LM batches.

    Each feature is expected to carry a ``conversations`` list of role/content
    messages (schema assumed from the dataset loader -- confirm against the
    data). Conversations go through the tokenizer's own chat template. If that
    fails, the message contents are joined with blank lines and tokenized
    plainly. Labels equal the input ids; padded positions get label -100 and
    attention 0.
    """

    def __init__(self, tokenizer, dataset_config):
        self.tokenizer = tokenizer
        self.dataset_config = dataset_config
        # Running counters for periodic progress logging.
        self.stats = {"processed": 0, "skipped": 0, "total_tokens": 0}
        self.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        self.max_seq_length = dataset_config.get("dataset", {}).get("processing", {}).get("max_seq_length", 2048)
        logger.info(f"SimpleDataCollator initialized - using pre-audited dataset with max_seq_length={self.max_seq_length}")
        logger.info("Using exact dataset structure without reformatting")
        # Check if we're on GPU
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"SimpleDataCollator using device: {self.device}")

    @staticmethod
    def _to_id_list(encoded):
        """Return a fresh list of token ids from a tokenizer or chat-template result."""
        # tokenizer(text) returns a dict-like BatchEncoding, not a list of ids.
        # Treating it as a list used to measure/slice/pad its *keys*.
        if hasattr(encoded, "keys"):
            encoded = encoded["input_ids"]
        return list(encoded)

    def __call__(self, features):
        """Build a batch dict of ``input_ids``, ``attention_mask`` and ``labels`` tensors."""
        batch = {"input_ids": [], "attention_mask": [], "labels": []}
        log_samples = self.dataset_config.get("validation", {}).get("log_samples", 3)

        for example in features:
            try:
                # The loaded dataset is validated for ``article_id`` rather than
                # ``id``; fall back to it so log lines identify the example.
                paper_id = example.get("id", example.get("article_id", ""))
                conversations = example.get("conversations", [])
                if not conversations:
                    self.stats["skipped"] += 1
                    continue

                try:
                    encoded = self.tokenizer.apply_chat_template(
                        conversations,
                        return_tensors=None,
                        add_generation_prompt=False,
                    )
                except Exception as chat_error:
                    logger.warning(f"Chat template application failed for example {paper_id}: {str(chat_error)[:100]}")
                    # Basic representation: message contents separated by blank lines.
                    conversation_text = ""
                    for msg in conversations:
                        if isinstance(msg, dict) and 'content' in msg:
                            conversation_text += msg.get('content', '') + "\n\n"
                    encoded = self.tokenizer(
                        conversation_text,
                        add_special_tokens=True,
                        return_tensors=None,
                    )
                inputs = self._to_id_list(encoded)

                # Apply length cap if needed (shouldn't be necessary for pre-audited data)
                if self.max_seq_length > 0 and len(inputs) > self.max_seq_length:
                    logger.warning(f"Example {paper_id} exceeds max_seq_length ({len(inputs)} > {self.max_seq_length})")
                    inputs = inputs[:self.max_seq_length]

                if len(inputs) > 0:
                    batch["input_ids"].append(inputs)
                    batch["attention_mask"].append([1] * len(inputs))
                    # Causal LM: labels are a copy of the inputs.
                    batch["labels"].append(list(inputs))
                    self.stats["processed"] += 1
                    self.stats["total_tokens"] += len(inputs)

                    if self.stats["processed"] <= log_samples:
                        logger.info(f"Example {self.stats['processed']}:")
                        logger.info(f"Paper ID: {paper_id}")
                        logger.info(f"Token count: {len(inputs)}")
                        logger.info(f"Conversation entries: {len(conversations)}")
                else:
                    self.stats["skipped"] += 1
            except Exception as e:
                logger.warning(f"Error processing example: {str(e)[:100]}...")
                logger.warning(f"Problematic example ID: {example.get('id', 'unknown')}")
                self.stats["skipped"] += 1
                continue

        if not batch["input_ids"]:
            logger.warning("Empty batch, returning dummy tensors")
            # Labels use the ignore index (-100), as for padding below, so the
            # dummy token is never treated as a training target.
            return {
                "input_ids": torch.zeros((1, 1), dtype=torch.long),
                "attention_mask": torch.zeros((1, 1), dtype=torch.long),
                "labels": torch.full((1, 1), -100, dtype=torch.long),
            }

        # Right-pad every sequence to the longest in the batch.
        max_length = max(len(ids) for ids in batch["input_ids"])
        for ids, mask, labels in zip(batch["input_ids"], batch["attention_mask"], batch["labels"]):
            padding_length = max_length - len(ids)
            if padding_length > 0:
                ids.extend([self.pad_token_id] * padding_length)
                mask.extend([0] * padding_length)
                labels.extend([-100] * padding_length)

        batch = {k: torch.tensor(v, dtype=torch.long) for k, v in batch.items()}

        # Log stats periodically
        log_interval = self.dataset_config.get("validation", {}).get("log_interval", 100)
        if self.stats["processed"] % log_interval == 0 and self.stats["processed"] > 0:
            logger.info(f"Data collator stats: processed={self.stats['processed']}, "
                        f"skipped={self.stats['skipped']}, "
                        f"avg_tokens={self.stats['total_tokens']/self.stats['processed']:.1f}")

        return batch
class LoggingCallback(TrainerCallback):
    """Trainer callback reporting the training plan, progress and GPU memory via log_info."""

    def __init__(self):
        super().__init__()
        self.training_started = time.time()
        self.last_log_time = time.time()
        self.last_step = 0
        # Sequence verification was removed; these attributes are kept so any
        # external code reading them still finds them.
        self.verify_sequence = None
        self.sequence_samples = None
        self.sample_indices = None

    def on_train_begin(self, args, state, control, **kwargs):
        """Log start time, model size, batch geometry and initial GPU memory."""
        # Time the run from the real start of training, not from construction.
        self.training_started = time.time()
        log_info(f"=== Training started at {time.strftime('%Y-%m-%d %H:%M:%S')} ===")

        # The Trainer passes the model and dataloaders as keyword arguments.
        # The old code read module-level ``model`` / ``dataset`` globals, which
        # are not defined at module level, so this raised NameError.
        model = kwargs.get("model")
        if model is not None:
            log_info(f"Model parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")

        # Disable sequence verification
        self.verify_sequence = False
        log_info("=== Training is starting ===")

        # NUM_GPUS is 0 on CPU-only hosts; count at least one device so the
        # arithmetic below cannot divide by zero.
        device_count = max(NUM_GPUS, 1)
        total_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps * device_count
        train_dataset = getattr(kwargs.get("train_dataloader"), "dataset", None)
        if train_dataset is not None:
            num_examples = len(train_dataset)
            total_steps = int(num_examples / total_batch_size * args.num_train_epochs)
            log_info(f"Training plan: {num_examples} examples over {args.num_train_epochs} epochs ≈ {total_steps} steps")
        log_info(f"Batch size: {args.per_device_train_batch_size} × {args.gradient_accumulation_steps} steps × {device_count} device(s) = {total_batch_size} total")
        log_info(f"Learning rate: {args.learning_rate}")
        log_info(f"Epochs: {args.num_train_epochs}")

        # Log memory information in compact format
        if CUDA_AVAILABLE:
            memory_info = []
            for i in range(NUM_GPUS):
                allocated = torch.cuda.memory_allocated(i) / 1024**2
                max_mem = torch.cuda.max_memory_allocated(i) / 1024**2
                memory_info.append(f"GPU {i}: {allocated:.1f}MB (max: {max_mem:.1f}MB)")
            log_info(f"Initial memory usage - {', '.join(memory_info)}")

    def on_step_end(self, args, state, control, **kwargs):
        """Log the latest loss every 50 steps, or after 5 minutes without a log line."""
        current_time = time.time()
        if (state.global_step % 50 == 0) or (current_time - self.last_log_time > 300):
            if state.log_history:
                loss = state.log_history[-1].get('loss', 'N/A')
                log_info(f"Step {state.global_step}: Loss {loss}")
            else:
                log_info(f"Step {state.global_step}: No loss data available")
            self.last_log_time = current_time

    def on_train_end(self, args, state, control, **kwargs):
        """Log total duration, peak GPU memory, step count and final loss."""
        training_time = time.strftime("%H:%M:%S", time.gmtime(time.time() - self.training_started))
        log_info(f"=== Training completed in {training_time} ===")

        if CUDA_AVAILABLE:
            for i in range(NUM_GPUS):
                max_mem = torch.cuda.max_memory_allocated(i) / 1024**3  # GB
                log_info(f"GPU {i} max memory: {max_mem:.2f} GB")
            torch.cuda.empty_cache()
            log_info("GPU memory cleared")

        log_info(f"Total steps: {state.global_step}")
        log_info(f"Final loss: {state.log_history[-1].get('loss', 'N/A') if state.log_history else 'N/A'}")
def check_dependencies():
    """Check that required packages are importable and report optional ones.

    Required: unsloth, transformers, peft and accelerate. The unsloth and peft
    checks use module-level availability flags; transformers and accelerate
    are imported here. Also warns if transformers appears to have been
    imported before unsloth, and logs whether the optional flash-attn and
    bitsandbytes packages are present.

    Returns:
        False (after logging install instructions) if any required package is
        missing, otherwise True.
    """
    missing_packages = []
    order_issues = []

    # 1. unsloth should be imported before transformers.
    if not unsloth_available:
        missing_packages.append("unsloth>=2024.3")

    # 2. transformers
    try:
        import transformers
        logger.info(f"Using transformers version {transformers.__version__}")
    except ImportError:
        missing_packages.append("transformers>=4.38.0")

    # 3. peft
    if not peft_available:
        missing_packages.append("peft>=0.9.0")

    # 4. accelerate
    try:
        import accelerate
        logger.info(f"Using accelerate version {accelerate.__version__}")
    except ImportError:
        missing_packages.append("accelerate>=0.27.0")

    # Import-order check. sys.modules is a dict, so its keys follow insertion
    # order, which approximates first-import order. The old code called
    # .index() on a dict_keys view, which has no such method; the resulting
    # AttributeError was swallowed and the check never ran.
    try:
        loaded_modules = list(sys.modules)
        if "transformers" in loaded_modules and "unsloth" in loaded_modules:
            if loaded_modules.index("transformers") < loaded_modules.index("unsloth"):
                order_issues.append("For optimal performance, unsloth should be imported before transformers")
    except Exception as exc:
        # Best-effort diagnostic only; never block training on it.
        logger.debug(f"Could not check import order: {exc}")

    if missing_packages:
        logger.error("Critical dependencies missing:")
        for pkg in missing_packages:
            logger.error(f"  - {pkg}")
        logger.error("Please install the missing dependencies with:")
        logger.error(f"  pip install {' '.join(missing_packages)}")
        return False

    for issue in order_issues:
        logger.warning(issue)

    # Optional packages
    if find_spec("flash_attn"):
        logger.info("flash-attn found. Flash attention will be used for faster training.")
    else:
        logger.warning("flash-attn not found. Training will work but may be slower.")
        logger.warning("To use flash attention, install with: pip install flash-attn --no-build-isolation")

    if find_spec("bitsandbytes"):
        logger.info("bitsandbytes found. Quantization will be available.")
    else:
        logger.warning("bitsandbytes not found. Quantization may not be available.")
        logger.warning("To use quantization, install with: pip install bitsandbytes")

    return True
def update_huggingface_space():
    """Push the current code to the Space by running ``update_space.py --force``.

    The helper script is looked up next to this file and run with the current
    interpreter; its output is captured.

    Returns:
        True when the script exits with status 0. False when the script is
        missing, exits non-zero, or launching it raises.
    """
    log_info("Updating Hugging Face Space...")
    here = os.path.dirname(os.path.abspath(__file__))
    update_script = os.path.join(here, "update_space.py")
    if not os.path.exists(update_script):
        logger.warning(f"Update space script not found at {update_script}")
        return False

    try:
        import subprocess

        completed = subprocess.run(
            [sys.executable, update_script, "--force"],
            capture_output=True,
            text=True,
            check=False,
        )
        if completed.returncode != 0:
            logger.error(f"Failed to update Hugging Face Space: {completed.stderr}")
            return False

        log_info("Hugging Face Space updated successfully!")
        owner = os.environ.get('HF_USERNAME', 'George-API')
        space = os.environ.get('HF_SPACE_NAME', 'phi4training')
        log_info(f"Space URL: https://huggingface.co/spaces/{owner}/{space}")
        return True
    except Exception as e:
        logger.error(f"Error updating Hugging Face Space: {str(e)}")
        return False
def validate_huggingface_credentials():
    """Check that ``HF_TOKEN`` authenticates and that the target Space is reachable.

    Logs in with ``HF_TOKEN`` via huggingface_hub, calls ``whoami``, then
    looks up the Space ``HF_USERNAME/HF_SPACE_NAME`` (defaults
    ``George-API/phi4training``).

    Returns:
        True only if login, whoami and the Space lookup all succeed. False if
        the token is unset, huggingface_hub is missing, or any step fails;
        failures are logged as warnings and never raised.
    """
    if not os.environ.get("HF_TOKEN"):
        logger.warning("HF_TOKEN not found. Skipping Hugging Face credentials validation.")
        return False

    try:
        # Import here to avoid requiring huggingface_hub if not needed
        from huggingface_hub import HfApi, login

        login(token=os.environ.get("HF_TOKEN"))

        api = HfApi()
        username = os.environ.get("HF_USERNAME", "George-API")
        space_name = os.environ.get("HF_SPACE_NAME", "phi4training")

        user_info = api.whoami()
        logger.info(f"Successfully authenticated with Hugging Face as {user_info['name']}")

        space_id = f"{username}/{space_name}"
        try:
            # Only reachability matters; the returned metadata is not used.
            api.space_info(repo_id=space_id)
            logger.info(f"Space {space_id} is accessible")
            return True
        except Exception as e:
            logger.warning(f"Could not access Space {username}/{space_name}: {str(e)}")
            logger.warning("Space updating may not work correctly")
            return False

    except ImportError:
        logger.warning("huggingface_hub not installed. Cannot validate Hugging Face credentials.")
        return False
    except Exception as e:
        logger.warning(f"Error validating Hugging Face credentials: {str(e)}")
        return False
def _resolve_precision(hardware_config, transformers_config):
    """Pick the mixed-precision mode, returning ``(use_bf16, use_fp16)``.

    ``hardware_config["training_optimizations"]["mixed_precision"]`` ("bf16" or
    "fp16", case-insensitive) takes priority. Otherwise the transformers config
    is consulted:
    - ``bf16: true`` or ``torch_dtype: "bfloat16"`` selects bf16;
    - ``fp16: true`` selects fp16 only when bf16 was not selected.

    At most one of the returned flags is True.
    """
    # ``or ""`` tolerates an explicit null in the config, which would otherwise
    # make ``.lower()`` raise AttributeError.
    hardware_precision = str(
        hardware_config.get("training_optimizations", {}).get("mixed_precision") or ""
    ).lower()
    if hardware_precision == "bf16":
        log_info("Using BF16 precision from hardware config")
        return True, False
    if hardware_precision == "fp16":
        log_info("Using FP16 precision from hardware config")
        return False, True

    # Fall back to the transformers config; bf16 wins if both are requested.
    use_bf16 = bool(
        transformers_config.get("bf16", False)
        or transformers_config.get("torch_dtype", "") == "bfloat16"
    )
    use_fp16 = bool(transformers_config.get("fp16", False)) and not use_bf16
    return use_bf16, use_fp16


def _build_fsdp_settings(use_bf16, use_fp16):
    """Return the ``(fsdp, fsdp_config)`` TrainingArguments values that enable FSDP.

    ``TrainingArguments.fsdp`` takes an option string; the wrapping details go
    in the separate ``fsdp_config`` dict. If ``torch.distributed.fsdp`` cannot
    be imported, this returns ``("", None)``, the TrainingArguments defaults,
    i.e. no FSDP.
    """
    try:
        from torch.distributed.fsdp import FullyShardedDataParallel  # noqa: F401  (availability check)
    except ImportError:
        log_info("FSDP imports failed, falling back to standard DDP")
        return "", None

    log_info("Using FSDP for distributed training")
    fsdp_config = {
        # NOTE(review): must name the loaded model's decoder-layer class;
        # "LlamaDecoderLayer" only matches Llama-architecture checkpoints. Confirm.
        "fsdp_transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"],
        "fsdp_backward_prefetch": "backward_pre",
        # min_num_params is deliberately absent: transformers treats it as
        # mutually exclusive with transformer_layer_cls_to_wrap.
    }
    if use_bf16 or use_fp16:
        # Mixed precision itself is driven by TrainingArguments(bf16=/fp16=).
        precision_type = "bf16" if use_bf16 else "fp16"
        log_info(f"FSDP using mixed precision: {precision_type}")
    # "full_shard" selects the FULL_SHARD strategy; "auto_wrap" wraps the layer class above.
    return "full_shard auto_wrap", fsdp_config


def _make_sequential_train_dataloader(dataset, dataset_config, data_collator, training_args):
    """Return a zero-argument replacement for ``Trainer.get_train_dataloader``.

    The DataLoader it builds never shuffles. Examples are consumed in the order
    they are stored in ``dataset``, whatever ``dataset_config["data_loading"]["shuffle"]``
    says. When ``training_args.world_size > 1``, a non-shuffling
    ``DistributedSampler`` gives each rank a strided slice of that order.
    Without it, every rank would train on the full dataset.
    """

    def custom_get_train_dataloader():
        """Custom dataloader that preserves original dataset order"""
        log_info("Creating sequential dataloader to maintain original dataset order")

        data_loading_config = dataset_config.get("data_loading", {})
        if data_loading_config.get("shuffle", False):
            log_info("WARNING: Shuffle is enabled in configuration! This will be overridden to preserve order.")

        world_size = training_args.world_size
        if world_size > 1:
            sampler = torch.utils.data.DistributedSampler(
                dataset,
                num_replicas=world_size,
                rank=training_args.process_index,
                shuffle=False,
            )
            log_info(f"Using non-shuffling DistributedSampler across {world_size} processes to preserve dataset order")
        else:
            sampler = torch.utils.data.SequentialSampler(dataset)
            log_info("Using SequentialSampler to guarantee dataset order is preserved based on prompt_number")

        # Informational only: samples are accessed by field name, so column order is not load-bearing.
        expected_order = ["prompt_number", "article_id", "conversations"]
        if hasattr(dataset, 'column_names'):
            actual_order = dataset.column_names
            if actual_order == expected_order:
                log_info(f"Confirmed dataset columns are in expected order: {', '.join(expected_order)}")
            else:
                log_info(f"Note: Dataset columns ({', '.join(actual_order)}) are not in expected order ({', '.join(expected_order)})")
                log_info("This is handled correctly by field-based access, but noting for clarity")
        log_info("Dataset is pre-processed with prompt_number field indicating the correct sequence")

        # Per-process batch: per_device_train_batch_size * max(1, n_gpu).
        # Under single-process DataParallel n_gpu is the visible GPU count.
        # Under DDP/FSDP it is 1 per process, and 0 on CPU.
        batch_size = training_args.train_batch_size
        log_info(f"Using sequential sampler with batch size {batch_size}")

        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            sampler=sampler,
            collate_fn=data_collator,
            drop_last=training_args.dataloader_drop_last,
            num_workers=training_args.dataloader_num_workers,
            pin_memory=training_args.dataloader_pin_memory,
        )

    return custom_get_train_dataloader


def _log_cuda_memory_usage():
    """Log allocated/reserved/peak memory (MB) for every GPU at error level; no-op without CUDA."""
    if not CUDA_AVAILABLE:
        return
    memory_info = []
    for i in range(NUM_GPUS):
        allocated = torch.cuda.memory_allocated(i) / 1024**2
        reserved = torch.cuda.memory_reserved(i) / 1024**2
        max_mem = torch.cuda.max_memory_allocated(i) / 1024**2
        memory_info.append(f"GPU {i}: {allocated:.1f}MB/{reserved:.1f}MB (max: {max_mem:.1f}MB)")
    logger.error(f"GPU memory at failure: {', '.join(memory_info)}")


def main():
    """Run the end-to-end fine-tuning pipeline.

    The pipeline:
    1. Checks dependencies and parses the CLI.
    2. Loads environment variables and the config, and seeds RNGs.
    3. Loads the model and dataset.
    4. Builds TrainingArguments and a Trainer whose train dataloader preserves
       dataset order.
    5. Trains and saves, then optionally pushes to the Hub and updates the
       Hugging Face Space.

    Returns:
        int: process exit code. 0 on success. 1 on missing dependencies, a
        configuration error, an empty dataset, or any exception raised while
        building the trainer or training. Exceptions from
        ``load_model_and_tokenizer`` are not caught here and propagate.
    """
    logger.info("Starting training process")

    # Check dependencies first, before any other operations
    if not check_dependencies():
        logger.error("Aborting due to missing critical dependencies")
        return 1

    args = parse_args()
    load_env_variables()
    # Best-effort: the result is only logged; a failure does not abort training.
    validate_huggingface_credentials()

    try:
        transformers_config = load_configs(args.config)
        hardware_config = transformers_config.get("hardware", {})
        dataset_config = transformers_config.get("dataset", {})
        logger.info("Configuration loaded successfully")
    except Exception as e:
        logger.error(f"Error loading configuration: {e}")
        return 1

    training_config = transformers_config.get("training", {})
    checkpoint_config = transformers_config.get("checkpointing", {})
    hub_config = transformers_config.get("huggingface_hub", {})
    system_settings = hardware_config.get("system_settings", {})

    # WORLD_SIZE > 1 means we were started by a multi-process launcher.
    is_distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1
    if is_distributed:
        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        log_info(f"Running in distributed mode with {os.environ.get('WORLD_SIZE')} processes, local_rank: {local_rank}")
    else:
        log_info("Running in non-distributed mode (single process)")

    seed = transformers_config.get("seed", 42)
    set_seed(seed)
    logger.info(f"Set random seed to {seed}")

    # PYTORCH_CUDA_ALLOC_CONF is read when the CUDA caching allocator is
    # initialised, so it has to be set before the model is put on the GPU.
    # Setting it after loading, as before, had no effect.
    # NOTE(review): if an import at module load has already initialised CUDA,
    # this is still too late; set it in the launcher environment instead.
    if CUDA_AVAILABLE:
        cuda_memory_fraction = system_settings.get("cuda_memory_fraction", 0.85)
        if cuda_memory_fraction < 1.0:
            os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
            log_info("Set CUDA memory allocation limit to expandable with max_split_size_mb:128")

    model, tokenizer = load_model_and_tokenizer(transformers_config)

    if CUDA_AVAILABLE:
        torch.cuda.empty_cache()
        log_info("Cleared CUDA cache")

    try:
        log_info("Loading dataset...")
        dataset = load_dataset_with_mapping(dataset_config)
        # Validate before logging the size: len(None) would raise TypeError.
        if dataset is None or len(dataset) == 0:
            logger.error("Dataset is empty or None! Cannot proceed with training.")
            return 1
        log_info(f"Dataset loaded with {len(dataset)} examples")

        data_collator = SimpleDataCollator(tokenizer, dataset_config)

        use_bf16, use_fp16 = _resolve_precision(hardware_config, transformers_config)
        log_info(f"Using precision: {'bf16' if use_bf16 else 'fp16' if use_fp16 else 'full precision'}")

        per_device_batch_size = training_config.get("per_device_train_batch_size", 16)
        gradient_accumulation_steps = training_config.get("gradient_accumulation_steps", 3)

        multi_gpu_strategy = hardware_config.get("training_optimizations", {}).get("multi_gpu_strategy", "data_parallel")
        logger.info(f"Multi-GPU strategy: {multi_gpu_strategy}")

        # Defaults mean "no FSDP"; only overridden for fsdp + distributed + multi-GPU.
        fsdp_options, fsdp_config = "", None
        if CUDA_AVAILABLE and NUM_GPUS > 1:
            log_info(f"Multi-GPU setup: Adjusting for {NUM_GPUS} GPUs")
            if multi_gpu_strategy == "fsdp" and is_distributed:
                fsdp_options, fsdp_config = _build_fsdp_settings(use_bf16, use_fp16)
            elif multi_gpu_strategy == "fsdp":
                log_info("FSDP disabled: requires distributed environment (use torchrun or accelerate)")
                log_info("Using DataParallel for multi-GPU training instead")
            else:
                log_info(f"Using {multi_gpu_strategy} for multi-GPU training")

        dataloader_workers = system_settings.get("dataloader_num_workers", 2)
        pin_memory = system_settings.get("dataloader_pin_memory", True)

        log_info("Setting up training arguments")
        training_args = TrainingArguments(
            # Top-level output_dir wins; otherwise fall back to checkpointing.output_dir.
            output_dir=transformers_config.get("output_dir") or checkpoint_config.get("output_dir", "./results"),
            num_train_epochs=training_config.get("num_train_epochs", 3),
            per_device_train_batch_size=per_device_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            learning_rate=training_config.get("learning_rate", 2e-5),
            weight_decay=training_config.get("weight_decay", 0.01),
            warmup_ratio=training_config.get("warmup_ratio", 0.05),
            lr_scheduler_type=training_config.get("lr_scheduler_type", "cosine"),
            logging_steps=training_config.get("logging_steps", 10),
            save_strategy=checkpoint_config.get("save_strategy", "steps"),
            save_steps=checkpoint_config.get("save_steps", 100),
            save_total_limit=checkpoint_config.get("save_total_limit", 3),
            fp16=use_fp16,
            bf16=use_bf16,
            max_grad_norm=training_config.get("max_grad_norm", 1.0),
            push_to_hub=hub_config.get("push_to_hub", False),
            hub_model_id=hub_config.get("hub_model_id", None),
            hub_token=os.environ.get("HF_TOKEN", None),
            report_to="tensorboard",
            remove_unused_columns=False,  # Keep all columns
            gradient_checkpointing=training_config.get("gradient_checkpointing", True),
            dataloader_pin_memory=pin_memory,
            optim=training_config.get("optim", "adamw_torch"),
            ddp_find_unused_parameters=False,  # Improve distributed training efficiency
            dataloader_drop_last=False,  # Process all examples
            dataloader_num_workers=dataloader_workers,
            no_cuda=not CUDA_AVAILABLE,
            fsdp=fsdp_options,
            fsdp_config=fsdp_config,
        )

        log_info("Initializing Trainer")
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=dataset,  # the train dataloader itself is replaced below
            data_collator=data_collator,
            callbacks=[LoggingCallback()],
        )
        # An instance attribute shadows the bound method; it is invoked with no arguments.
        trainer.get_train_dataloader = _make_sequential_train_dataloader(
            dataset, dataset_config, data_collator, training_args
        )

        log_info("=== Starting Training ===")
        try:
            if CUDA_AVAILABLE:
                torch.cuda.empty_cache()
                log_info("Cleared CUDA cache before training")

            # Examples consumed per optimizer step across all processes. The old
            # NUM_GPUS-based divisor was 0 on CPU-only machines (ZeroDivisionError).
            examples_per_step = (
                training_args.train_batch_size
                * training_args.gradient_accumulation_steps
                * training_args.world_size
            )
            total_steps = int(len(dataset) / max(1, examples_per_step) * training_args.num_train_epochs)
            log_info(f"Training plan: {len(dataset)} examples over {training_args.num_train_epochs} epochs ≈ {total_steps} steps")

            trainer.train()
            log_info("Training completed successfully!")

            log_info("Saving final model...")
            trainer.save_model()
            log_info(f"Model saved to {training_args.output_dir}")

            if hub_config.get("push_to_hub", False):
                hub_id = hub_config.get("hub_model_id", "model")
                log_info(f"Pushing model to Hugging Face Hub as {hub_id}...")
                trainer.push_to_hub()
                log_info("Model successfully pushed to Hub")

            # Update the Hugging Face Space with current code
            if os.environ.get("HF_TOKEN") and os.environ.get("HF_USERNAME") and os.environ.get("HF_SPACE_NAME"):
                update_huggingface_space()

            return 0
        except Exception as e:
            logger.error(f"Training failed with error: {str(e)}")
            _log_cuda_memory_usage()
            raise
    except Exception as e:
        # logger.exception keeps the traceback, which the bare message dropped.
        logger.exception(f"Error in main training loop: {str(e)}")
        return 1
if __name__ == "__main__":
    # main() returns the process exit status (0 on success, 1 on failure);
    # raising SystemExit with it is exactly what sys.exit() does.
    raise SystemExit(main())