from transformers import Trainer, TrainingArguments
def fine_ttm(
    new_data,
    *,
    model=None,
    output_dir: str = "./results",
    per_device_train_batch_size: int = 4,
    num_train_epochs: float = 3,
    save_dir: str = "updated_ttm",
    **training_kwargs,
):
    """Fine-tune a model on ``new_data`` with the Hugging Face ``Trainer`` and save it.

    All defaults reproduce the original hard-coded behavior:
    - train the module-level ``ttm_model``;
    - use per-device batch size 4 for 3 epochs;
    - use ``./results`` as the ``TrainingArguments`` output directory;
    - save the final model to ``updated_ttm``.

    Args:
        new_data: Dataset passed to ``Trainer`` as ``train_dataset``.
            NOTE(review): presumably a torch / ``datasets`` dataset whose items
            match the model's forward signature — confirm against callers.
        model: Model to fine-tune. When ``None`` (the default), falls back to
            the module-level ``ttm_model``, which must be defined elsewhere in
            this module (it is not visible here).
        output_dir: Value for ``TrainingArguments(output_dir=...)``.
        per_device_train_batch_size: Training batch size per device.
        num_train_epochs: Number of training epochs.
        save_dir: Path handed to ``Trainer.save_model`` after training.
        **training_kwargs: Extra keyword arguments forwarded verbatim to
            ``TrainingArguments`` (e.g. ``learning_rate``, ``seed``).

    Returns:
        Whatever ``trainer.train()`` returns (a ``TrainOutput`` in current
        transformers releases). The original returned ``None``, so callers
        that ignore the result are unaffected.

    Raises:
        ValueError: If ``new_data`` is ``None``.
        NameError: If ``model`` is ``None`` and no module-level ``ttm_model``
            is defined.
    """
    if new_data is None:
        # Fail fast with a clear message; Trainer.train() would only raise a
        # ValueError later, after the arguments and trainer were already built.
        raise ValueError("fine_ttm requires a training dataset; got None")

    if model is None:
        # Backward-compatible default: the original always trained the global.
        model = ttm_model

    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=per_device_train_batch_size,
        num_train_epochs=num_train_epochs,
        **training_kwargs,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=new_data,
    )
    train_output = trainer.train()
    trainer.save_model(save_dir)
    return train_output