#!/usr/bin/env python
# coding=utf-8
"""Fine-tune a causal language model on a conversation dataset.

The script reads a JSON configuration file, loads the model/tokenizer
(optionally 4-bit quantized and wrapped with LoRA adapters), sorts the
training split by its ``id`` column, and trains with a Hugging Face
``Trainer`` whose dataloader is replaced by a sequential (non-shuffling) one.
"""

import os
import sys
import json
import argparse
import logging
from datetime import datetime

import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    TrainerCallback,
    set_seed,
    BitsAndBytesConfig
)

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)

# Check for BitsAndBytes
# NOTE(review): BitsAndBytesConfig is also imported unconditionally above, so
# this guard can only trigger if that import is removed; kept for compatibility.
try:
    from transformers import BitsAndBytesConfig
    bitsandbytes_available = True
except ImportError:
    bitsandbytes_available = False
    logger.warning("BitsAndBytes not available. 4-bit quantization will not be used.")

# Check for PEFT
try:
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
    peft_available = True
except ImportError:
    peft_available = False
    logger.warning("PEFT not available. Parameter-efficient fine-tuning will not be used.")


def load_env_variables():
    """Load environment variables from system, .env file, or Hugging Face Space variables.

    Inside a Space (``SPACE_ID`` set) ``HF_USERNAME`` is derived from the
    ``SPACE_ID`` prefix when missing.  Outside a Space, ``shared/.env`` located
    one directory above this script's directory is loaded if python-dotenv is
    installed.  Finally ``HF_TOKEN`` is mirrored into ``HUGGING_FACE_HUB_TOKEN``.
    Only the presence of variables is logged, never their values.
    """
    space_id = os.environ.get("SPACE_ID")

    if space_id:
        logger.info("Running in Hugging Face Space")

        # Log the presence of variables (without revealing values)
        logger.info(f"HF_TOKEN available: {bool(os.environ.get('HF_TOKEN'))}")
        logger.info(f"HF_USERNAME available: {bool(os.environ.get('HF_USERNAME'))}")

        # If username is not set, try to extract from SPACE_ID
        if not os.environ.get("HF_USERNAME") and "/" in space_id:
            username = space_id.split("/")[0]
            os.environ["HF_USERNAME"] = username
            logger.info(f"Set HF_USERNAME from SPACE_ID: {username}")
    else:
        # Try to load from .env file if not in a Space
        try:
            from dotenv import load_dotenv
        except ImportError:
            logger.warning("python-dotenv not installed, not loading from .env file")
        else:
            # .env lives in <parent of script dir>/shared/.env
            project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
            env_path = os.path.join(project_root, "shared", ".env")
            if os.path.exists(env_path):
                load_dotenv(env_path)
                logger.info(f"Loaded environment variables from {env_path}")
                logger.info(f"HF_TOKEN loaded from .env file: {bool(os.environ.get('HF_TOKEN'))}")
                logger.info(f"HF_USERNAME loaded from .env file: {bool(os.environ.get('HF_USERNAME'))}")
                logger.info(f"HF_SPACE_NAME loaded from .env file: {bool(os.environ.get('HF_SPACE_NAME'))}")
            else:
                logger.warning(f"No .env file found at {env_path}")

    if not os.environ.get("HF_USERNAME"):
        logger.warning("HF_USERNAME is not set. Using default username.")

    if not os.environ.get("HF_SPACE_NAME"):
        logger.warning("HF_SPACE_NAME is not set. Using default space name.")

    # Set HF_TOKEN for huggingface_hub
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        os.environ["HUGGING_FACE_HUB_TOKEN"] = hf_token


def parse_args():
    """Parse command-line arguments; ``--config`` is the path to the JSON config."""
    parser = argparse.ArgumentParser(description="Fine-tune a language model on a text dataset")
    parser.add_argument("--config", type=str, default="transformers_config.json",
                        help="Path to the configuration file")
    return parser.parse_args()


def _get_field(obj, key, default):
    """Read ``key`` from a dict or, for any other object, from an attribute."""
    if isinstance(obj, dict):
        return obj.get(key, default)
    return getattr(obj, key, default)


class SimpleDataCollator:
    """Collate conversation examples into padded ``input_ids``/``attention_mask``/``labels``.

    Each example is tokenized on its own (examples are never packed together).
    A tracking header (global prompt number, paper id, per-paper chunk number)
    is prepended to the text before tokenization, so it is part of the tokens
    the model is trained on.

    Args:
        tokenizer: Tokenizer providing ``encode`` and ``pad_token_id``.
        max_length: Maximum number of tokens kept per example; longer
            sequences are truncated on the right.
    """

    def __init__(self, tokenizer, max_length=2048):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.stats = {"processed": 0, "skipped": 0, "total_tokens": 0}
        self.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        self.prompt_counter = 0  # Global counter for all prompts
        self.paper_counters = {}  # Track prompts per paper ID
        logger.info("SimpleDataCollator initialized - processing entries independently")

    def __call__(self, features):
        """Build one padded batch of ``torch`` tensors from ``features``.

        Examples without a ``conversations`` field, examples that tokenize to
        nothing, and examples that raise while being processed are counted as
        skipped.  If nothing survives, a 1x1 all-zero dummy batch is returned.
        """
        batch = {"input_ids": [], "attention_mask": [], "labels": []}

        # Process each entry independently (no combining based on token size)
        for example in features:
            # Best-effort by design: one malformed example must not abort the batch.
            try:
                paper_id = _get_field(example, "id", "")
                conversation = _get_field(example, "conversations", [])

                # Skip empty entries
                if not conversation:
                    self.stats["skipped"] += 1
                    continue

                self.prompt_counter += 1
                self.paper_counters[paper_id] = self.paper_counters.get(paper_id, 0) + 1
                chunk_number = self.paper_counters[paper_id]

                # Tracking header followed by "role: content" paragraphs
                parts = [
                    f"Prompt #{self.prompt_counter} | Paper ID: {paper_id} | Paper Chunk: {chunk_number}\n\n"
                ]
                for message in conversation:
                    role = _get_field(message, "role", "")
                    content = _get_field(message, "content", "")
                    parts.append(f"{role}: {content}\n\n")
                full_content = "".join(parts)

                # Tokenize, then truncate on the right if necessary
                input_ids = self.tokenizer.encode(full_content, add_special_tokens=True)
                input_ids = input_ids[:self.max_length]

                if not input_ids:
                    self.stats["skipped"] += 1
                    continue

                batch["input_ids"].append(input_ids)
                batch["attention_mask"].append([1] * len(input_ids))
                # Causal LM objective: labels are a copy of the inputs
                batch["labels"].append(list(input_ids))

                self.stats["processed"] += 1
                self.stats["total_tokens"] += len(input_ids)

                # Debug logging for the first few examples
                if self.stats["processed"] <= 3:
                    logger.info(f"Example {self.stats['processed']} - Prompt #{self.prompt_counter} | Paper ID: {paper_id} | Paper Chunk: {chunk_number}")
                    logger.info(f"Token count: {len(input_ids)}")
                    if len(input_ids) < 50:  # Catch potentially short sequences
                        logger.info(f"WARNING: Short token sequence: {len(input_ids)} tokens")
                        logger.info(f"Content preview: {full_content[:200]}...")
            except Exception as e:
                logger.warning(f"Error processing example: {str(e)[:100]}...")
                self.stats["skipped"] += 1
                continue

        if not batch["input_ids"]:
            logger.warning("Empty batch, returning dummy tensors")
            return {
                "input_ids": torch.zeros((1, 1), dtype=torch.long),
                "attention_mask": torch.zeros((1, 1), dtype=torch.long),
                "labels": torch.zeros((1, 1), dtype=torch.long)
            }

        # Right-pad every sequence to the longest one in this batch
        longest = max(len(ids) for ids in batch["input_ids"])
        for ids, mask, labels in zip(batch["input_ids"], batch["attention_mask"], batch["labels"]):
            padding_length = longest - len(ids)
            if padding_length > 0:
                ids.extend([self.pad_token_id] * padding_length)
                mask.extend([0] * padding_length)
                labels.extend([-100] * padding_length)  # Don't compute loss on padding

        # Convert to tensors
        batch = {k: torch.tensor(v) for k, v in batch.items()}

        # Log stats whenever the running count of processed examples is a multiple of 100
        if self.stats["processed"] % 100 == 0 and self.stats["processed"] > 0:
            logger.info(f"Data collator stats: processed={self.stats['processed']}, "
                        f"skipped={self.stats['skipped']}, "
                        f"avg_tokens={self.stats['total_tokens']/self.stats['processed']:.1f}, "
                        f"unique_papers={len(self.paper_counters)}")

        return batch


class LoggingCallback(TrainerCallback):
    """Log step, loss, learning rate and elapsed minutes during training."""

    def __init__(self):
        self.last_log_time = datetime.now()
        self.training_start_time = datetime.now()

    def on_step_end(self, args, state, control, **kwargs):
        """Log every 50 steps or every 5 minutes, whichever comes first."""
        current_time = datetime.now()
        time_diff = (current_time - self.last_log_time).total_seconds()
        elapsed_time = (current_time - self.training_start_time).total_seconds() / 60  # in minutes

        if state.global_step % 50 == 0 or time_diff > 300:  # 300 seconds = 5 minutes
            # The latest log entry is not guaranteed to carry these keys,
            # so read them defensively instead of indexing.
            last_entry = state.log_history[-1] if state.log_history else {}
            loss = last_entry.get("loss", "N/A")
            lr = last_entry.get("learning_rate", "N/A")

            loss_str = f"{loss:.4f}" if isinstance(loss, float) else str(loss)
            lr_str = f"{lr:.8f}" if isinstance(lr, float) else str(lr)

            logger.info(f"Step: {state.global_step} | Loss: {loss_str} | LR: {lr_str} | Elapsed: {elapsed_time:.2f} min")
            self.last_log_time = current_time


def _resolve_hub_model_id(config):
    """Fill in ``config['hub_model_id']`` when running in a Space and pushing to the Hub.

    The username comes from ``HF_USERNAME`` or, failing that, from the part of
    ``SPACE_ID`` before the first ``/``.  Nothing is changed when not in a
    Space, when ``push_to_hub`` is off, or when ``hub_model_id`` is already set.
    """
    space_id = os.environ.get("SPACE_ID")
    if not space_id:
        return
    if not config.get("push_to_hub", False) or config.get("hub_model_id"):
        return

    username = os.environ.get("HF_USERNAME")
    if not username:
        username = space_id.split("/")[0]
        logger.info(f"Extracted username from SPACE_ID: {username}")

    model_name = (config.get("model_name") or "").split("/")[-1]
    config["hub_model_id"] = f"{username}/finetuned-{model_name}"
    logger.info(f"Set hub_model_id to {config['hub_model_id']}")


def _find_latest_checkpoint(output_dir):
    """Return the ``checkpoint-<step>`` directory with the highest step, or ``None``.

    Entries that are not directories or whose suffix is not a plain number
    are ignored.
    """
    if not os.path.isdir(output_dir):
        return None

    latest_step = -1
    latest_path = None
    for entry in os.listdir(output_dir):
        prefix, _, suffix = entry.partition("-")
        if prefix != "checkpoint" or not suffix.isdecimal():
            continue
        candidate = os.path.join(output_dir, entry)
        if not os.path.isdir(candidate):
            continue
        step = int(suffix)
        if step > latest_step:
            latest_step, latest_path = step, candidate
    return latest_path


def _build_training_arguments(config):
    """Create ``TrainingArguments`` from the config, applying the script defaults."""
    bf16 = config.get("bf16", False)
    # fp16 and bf16 are mutually exclusive; only default fp16 on when bf16 is off.
    fp16 = config.get("fp16", not bf16)

    return TrainingArguments(
        output_dir=config.get("output_dir", "./results"),
        num_train_epochs=config.get("num_train_epochs", 3),
        per_device_train_batch_size=config.get("per_device_train_batch_size", 4),  # Use config value, can be > 1
        gradient_accumulation_steps=config.get("gradient_accumulation_steps", 8),
        learning_rate=config.get("learning_rate", 5e-5),
        weight_decay=config.get("weight_decay", 0.01),
        warmup_ratio=config.get("warmup_ratio", 0.1),
        lr_scheduler_type=config.get("lr_scheduler_type", "linear"),
        logging_steps=config.get("logging_steps", 10),
        save_strategy=config.get("save_strategy", "steps"),
        save_steps=config.get("save_steps", 100),  # Save every 100 steps by default
        save_total_limit=config.get("save_total_limit", 3),  # Keep last 3 checkpoints
        fp16=fp16,
        bf16=bf16,
        max_grad_norm=config.get("max_grad_norm", 1.0),
        push_to_hub=config.get("push_to_hub", False),
        hub_model_id=config.get("hub_model_id", None),
        hub_token=os.environ.get("HF_TOKEN", None),
        report_to="tensorboard",
        remove_unused_columns=False,  # Keep the conversations column
        gradient_checkpointing=True,  # Enable gradient checkpointing
        dataloader_pin_memory=False,  # Reduce memory usage
        optim=config.get("optim", "adamw_torch"),
        ddp_find_unused_parameters=False,  # Improve distributed training efficiency
        dataloader_drop_last=False,  # Process all examples
        dataloader_num_workers=0,  # Sequential data loading
    )


def main():
    """Run the fine-tuning pipeline.

    Returns:
        ``0`` on success, ``1`` when configuration, model, tokenizer, PEFT or
        dataset setup fails.

    Raises:
        Exception: Any error raised during training is logged and re-raised
            (after the lock file has been removed).
    """
    logger.info("Starting training process")

    args = parse_args()
    load_env_variables()

    # Load configuration
    try:
        with open(args.config, "r", encoding="utf-8") as f:
            config = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and decoding errors
        logger.error(f"Error loading configuration: {e}")
        return 1
    if not isinstance(config, dict):
        logger.error("Error loading configuration: top-level JSON value must be an object")
        return 1
    logger.info(f"Loaded configuration from {args.config}")

    # Set random seed for reproducibility
    seed = config.get("seed", 42)
    set_seed(seed)
    logger.info(f"Set random seed to {seed}")

    # Default the Hub repository name when running in a Hugging Face Space
    _resolve_hub_model_id(config)

    # Load model and tokenizer
    logger.info(f"Loading model: {config.get('model_name')}")

    # Prepare BitsAndBytes config if 4-bit quantization is enabled
    quantization_config = None
    if config.get("load_in_4bit", False):
        if bitsandbytes_available:
            logger.info("Using 4-bit quantization")
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type=config.get("bnb_4bit_quant_type", "nf4"),
                bnb_4bit_compute_dtype=getattr(torch, config.get("bnb_4bit_compute_dtype", "float16")),
                bnb_4bit_use_double_quant=config.get("bnb_4bit_use_double_quant", True)
            )
        else:
            logger.warning("load_in_4bit requested but BitsAndBytes is not available; loading without quantization")

    # Load model with quantization config
    try:
        model = AutoModelForCausalLM.from_pretrained(
            config.get("model_name"),
            quantization_config=quantization_config,
            device_map="auto",
            trust_remote_code=config.get("trust_remote_code", False),
            use_cache=False  # For compatibility with gradient checkpointing
        )
        logger.info("Model loaded successfully")

        # Enable gradient checkpointing if available
        if hasattr(model, "gradient_checkpointing_enable"):
            try:
                # Newer versions take the checkpointing options as a kwargs dict
                model.gradient_checkpointing_enable(
                    gradient_checkpointing_kwargs={"use_reentrant": False}
                )
                logger.info("Gradient checkpointing enabled with use_reentrant=False")
            except TypeError:
                # Fall back to version without parameter (older versions)
                model.gradient_checkpointing_enable()
                logger.info("Gradient checkpointing enabled without parameters")
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        return 1

    # Load tokenizer
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            config.get("model_name"),
            use_fast=config.get("use_fast_tokenizer", True),
            trust_remote_code=config.get("trust_remote_code", False)
        )
        logger.info("Tokenizer loaded successfully")

        # Set chat template if specified
        if config.get("chat_template"):
            tokenizer.chat_template = config.get("chat_template")
            logger.info(f"Set chat template to {config.get('chat_template')}")

        # Ensure pad token is properly set
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
            logger.info(f"Set pad_token_id to eos_token_id: {tokenizer.pad_token_id}")
    except Exception as e:
        logger.error(f"Error loading tokenizer: {e}")
        return 1

    # Prepare model for k-bit training if using PEFT
    if config.get("use_peft", False):
        if peft_available:
            logger.info("Preparing model for parameter-efficient fine-tuning")
            try:
                model = prepare_model_for_kbit_training(model)

                target_modules = config.get(
                    "target_modules",
                    ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
                )
                lora_r = config.get("lora_r", 16)
                lora_alpha = config.get("lora_alpha", 32)

                lora_config = LoraConfig(
                    r=lora_r,
                    lora_alpha=lora_alpha,
                    lora_dropout=config.get("lora_dropout", 0.05),
                    bias="none",
                    task_type="CAUSAL_LM",
                    target_modules=target_modules
                )

                # Apply LoRA to model
                model = get_peft_model(model, lora_config)
                logger.info(f"Applied LoRA with r={lora_r}, alpha={lora_alpha}")
            except Exception as e:
                logger.error(f"Error setting up PEFT: {e}")
                return 1
        else:
            logger.warning("use_peft requested but PEFT is not available; training all model parameters")

    # Load dataset
    logger.info(f"Loading dataset: {config.get('dataset_name')}")
    try:
        dataset = load_dataset(config.get("dataset_name"))
        logger.info(f"Dataset loaded successfully with {len(dataset['train'])} training examples")

        # Sort by the raw 'id' column so chunks from the same paper stay together.
        # NOTE(review): string ids sort lexicographically, not numerically.
        logger.info("Sorting dataset by ID to maintain paper chunk order")
        dataset['train'] = dataset['train'].sort('id')
        logger.info("Dataset sorted by ID")

        # Log the first few IDs to verify sorting
        preview_count = min(5, len(dataset['train']))
        sample_ids = [example['id'] for example in dataset['train'].select(range(preview_count))]
        logger.info(f"First few IDs after sorting: {sample_ids}")
    except Exception as e:
        logger.error(f"Error loading or sorting dataset: {e}")
        return 1

    # Create data collator (each entry is processed independently, even when batch size > 1)
    data_collator = SimpleDataCollator(tokenizer, max_length=config.get("max_seq_length", 2048))

    # Set up training arguments
    logger.info("Setting up training arguments")
    training_args = _build_training_arguments(config)

    logger.info("Creating sequential sampler to maintain dataset order")
    logger.info("Creating trainer")

    # Check if we should resume from checkpoint
    output_dir = config.get("output_dir", "./results")
    resume_from_checkpoint = _find_latest_checkpoint(output_dir)
    if resume_from_checkpoint:
        logger.info(f"Found checkpoint: {resume_from_checkpoint}. Training will resume from this point.")

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        data_collator=data_collator,
        callbacks=[LoggingCallback()]
    )

    # Override the default data loader to disable shuffling
    def get_train_dataloader_no_shuffle():
        """Create a train DataLoader with shuffling disabled."""
        logger.info("Creating train dataloader with sequential sampler (no shuffling)")

        # Sequential sampler ensures the dataset is processed in sorted order
        train_sampler = torch.utils.data.SequentialSampler(dataset["train"])

        return torch.utils.data.DataLoader(
            dataset["train"],
            batch_size=training_args.per_device_train_batch_size,
            sampler=train_sampler,
            collate_fn=data_collator,
            drop_last=False,
            num_workers=0,
            pin_memory=False
        )

    # Replace the default data loader with our non-shuffling version
    trainer.get_train_dataloader = get_train_dataloader_no_shuffle

    # Start training
    logger.info("Starting training")
    logger.info(f"Processing with batch size = {training_args.per_device_train_batch_size}, each entry processed independently")

    # Create a lock file to indicate training is in progress
    lock_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "TRAINING_IN_PROGRESS.lock")
    with open(lock_file, "w", encoding="utf-8") as f:
        f.write(f"Training started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Expected completion: After {training_args.num_train_epochs} epochs\n")
        f.write("DO NOT UPDATE OR RESTART THIS SPACE UNTIL TRAINING COMPLETES\n")
    logger.info(f"Created lock file: {lock_file}")

    try:
        trainer.train(resume_from_checkpoint=resume_from_checkpoint)
        logger.info("Training completed successfully")

        # Save model
        if config.get("push_to_hub", False):
            logger.info(f"Pushing model to hub: {config.get('hub_model_id')}")
            trainer.push_to_hub()
            logger.info("Model pushed to hub successfully")
        else:
            logger.info(f"Saving model to {config.get('output_dir', './results')}")
            trainer.save_model()
            logger.info("Model saved successfully")
    except Exception as e:
        logger.error(f"Training failed with error: {str(e)}")
        raise
    finally:
        # Remove the lock file when training completes or fails
        if os.path.exists(lock_file):
            os.remove(lock_file)
            logger.info(f"Removed lock file: {lock_file}")

    return 0


if __name__ == "__main__":
    sys.exit(main())