from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
)
from datasets import load_dataset


def train():
    # Load model and tokenizer
    model_name = "microsoft/phi-2"
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="cpu",
        trust_remote_code=True,
    )

    # Add a padding token if missing (phi-2's tokenizer has none by default)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load the dataset. "eswardivi/medical_qa" is a Hugging Face Hub dataset ID,
    # so load it directly rather than passing it to the "csv" loader as a file
    # path (update the ID, or switch to local CSV paths, as needed).
    dataset = load_dataset("eswardivi/medical_qa")

    # If the dataset ships only a train split, carve out a validation split
    if "validation" not in dataset:
        split = dataset["train"].train_test_split(test_size=0.1, seed=42)
        dataset["train"] = split["train"]
        dataset["validation"] = split["test"]

    # Tokenization function; assumes a "text" column — adjust the column name
    # (or build a text field from the question/answer columns) to match the
    # actual dataset schema.
    def tokenize_function(examples):
        return tokenizer(
            examples["text"],
            padding="max_length",
            truncation=True,
            max_length=256,
        )

    # Preprocess dataset, dropping the original columns so only model inputs remain;
    # the data collator handles tensor conversion, so no return_tensors here.
    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=dataset["train"].column_names,
    )

    # Data collator for causal LM: with mlm=False it builds labels from input_ids
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
    )

    # Training arguments
    training_args = TrainingArguments(
        output_dir="./phi2-cpu-results",
        overwrite_output_dir=True,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        num_train_epochs=3,
        logging_dir="./logs",
        logging_steps=100,
        evaluation_strategy="epoch",  # named eval_strategy on newer transformers releases
        save_strategy="epoch",
        fp16=False,  # keep full precision on CPU
        report_to="none",
    )

    # Initialize Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["validation"],
        data_collator=data_collator,
    )

    # Start training
    print("Starting training...")
    trainer.train()

    # Save model and tokenizer
    trainer.save_model("./phi2-trained-model")
    tokenizer.save_pretrained("./phi2-trained-model")
    print("Training complete! Model saved.")


if __name__ == "__main__":
    train()
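
# Optional follow-up: a minimal sketch of reloading the saved checkpoint and
# generating one completion, as a quick sanity check after training. It assumes
# the run above finished and wrote ./phi2-trained-model; the prompt string is
# only an illustration. Kept commented out so the script's behaviour is unchanged.
#
# from transformers import AutoModelForCausalLM, AutoTokenizer
#
# tokenizer = AutoTokenizer.from_pretrained("./phi2-trained-model")
# model = AutoModelForCausalLM.from_pretrained("./phi2-trained-model")
# prompt = "Question: What are common symptoms of anemia?\nAnswer:"
# inputs = tokenizer(prompt, return_tensors="pt")
# outputs = model.generate(**inputs, max_new_tokens=64)
# print(tokenizer.decode(outputs[0], skip_special_tokens=True))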