# R1 / app.py
# Author: hackergeek98 — commit ef46523 ("Create app.py", verified), 715 bytes.
import os
from autotrain import logger
from autotrain.trainers.common import ALLOW_REMOTE_CODE
from autotrain.trainers.text_generation import LLMTrainingParams, LLMTrainer
def train(
    *,
    model_name="microsoft/phi-4",
    data_path="lavita/medical-qa-datasets",
    project_name="phi4-training",
    repo_id="hackergeek98/phi4-trained",
    **overrides,
):
    """Build AutoTrain LLM training parameters and run the trainer.

    Calling ``train()`` with no arguments reproduces the original hard-coded
    job: phi-4 on ``lavita/medical-qa-datasets``, learning rate 2e-5,
    3 epochs, batch size 2, fp16, pushed to ``hackergeek98/phi4-trained``.
    The keyword arguments let callers retarget the job without editing
    this file.

    Args:
        model_name: Forwarded to ``LLMTrainingParams`` as ``model_name``.
        data_path: Forwarded to ``LLMTrainingParams`` as ``data_path``.
        project_name: Forwarded to ``LLMTrainingParams`` as ``project_name``.
        repo_id: Forwarded to ``LLMTrainingParams`` as ``repo_id``.
        **overrides: Additional or replacement keyword arguments for
            ``LLMTrainingParams``; they take precedence over every default
            set below.

    Returns:
        None. Training runs via ``LLMTrainer.train()``; any exception it
        raises propagates to the caller.
    """
    # NOTE(review): in recent autotrain-advanced releases LLMTrainingParams is
    # defined in autotrain.trainers.clm.params and names these fields `model`,
    # `lr`, `epochs` and `mixed_precision` (with `username`/`token` for Hub
    # pushes); keys it does not recognise may be dropped with only a warning.
    # Confirm the field names and import path against the installed autotrain
    # version. Corrected names can be supplied through **overrides.
    settings = {
        "model_name": model_name,
        "data_path": data_path,
        "project_name": project_name,
        "learning_rate": 2e-5,
        "num_train_epochs": 3,
        "batch_size": 2,
        "fp16": True,
        "push_to_hub": True,
        "repo_id": repo_id,
    }
    # Caller-supplied values win over the defaults above.
    settings.update(overrides)

    params = LLMTrainingParams(**settings)

    # Initialize and run trainer
    trainer = LLMTrainer(params=params)
    trainer.train()
def _main():
    """Script entry point: launch the training job defined by ``train``."""
    train()


# Guarded so that importing this module never starts a training run.
if __name__ == "__main__":
    _main()