"""Fine-tune a causal language model on a medical QA dataset with AutoTrain.

Running this file directly starts training with DEFAULT_TRAINING_PARAMS.
Importing it and calling ``train(...)`` with keyword arguments overrides
individual settings without editing the defaults.
"""

import os

from autotrain import logger
from autotrain.trainers.common import ALLOW_REMOTE_CODE
from autotrain.trainers.text_generation import LLMTrainingParams, LLMTrainer

# NOTE(review): `os`, `logger` and `ALLOW_REMOTE_CODE` are imported but not
# used in this file; kept in case other code relies on them being importable
# from here.
# NOTE(review): confirm the import path above and the parameter names below
# against the installed autotrain version. Some autotrain-advanced releases
# expose LLMTrainingParams from `autotrain.trainers.clm.params` with fields
# named e.g. `model`, `lr`, `epochs`, `mixed_precision`; unrecognised field
# names may be ignored rather than rejected - TODO verify.

# Keyword arguments passed to LLMTrainingParams when train() is called
# without overrides.
DEFAULT_TRAINING_PARAMS = {
    "model_name": "microsoft/phi-4",  # Replace with your model
    "data_path": "lavita/medical-qa-datasets",
    "project_name": "phi4-training",
    "learning_rate": 2e-5,
    "num_train_epochs": 3,
    "batch_size": 2,
    "fp16": True,
    # NOTE(review): pushing to the Hub normally requires an access token;
    # none is passed here - confirm how authentication is supplied.
    "push_to_hub": True,
    "repo_id": "hackergeek98/phi4-trained",
}


def train(**overrides):
    """Build the training parameters and run the trainer.

    Args:
        **overrides: Keyword arguments merged over DEFAULT_TRAINING_PARAMS
            (overrides win) before being passed to LLMTrainingParams.
            Calling ``train()`` with no arguments uses the defaults as-is.
    """
    # Build a fresh dict so the module-level defaults are never mutated.
    params = LLMTrainingParams(**{**DEFAULT_TRAINING_PARAMS, **overrides})

    # Initialize and run trainer
    trainer = LLMTrainer(params=params)
    trainer.train()


if __name__ == "__main__":
    train()