# DAM / train.py
# Fine-tune Helsinki-NLP/opus-mt-en-zh (English -> Chinese)
# Quick sanity test of the base translation model:
# from transformers import pipeline
# translator = pipeline("translation", model="Helsinki-NLP/opus-mt-en-zh", max_time=7)
# prediction = translator("FRST")[0]["translation_text"]
# print(prediction)
# Fine-tuning
from datasets import load_dataset, load_metric
import torch
import numpy as np
import os
raw_datasets = load_dataset("json", data_files="./more models/bank_en_zh_4.json")
split_datasets = raw_datasets["train"].train_test_split(train_size=0.9, seed=20)
split_datasets["validation"] = split_datasets.pop("test")
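# Each record in bank_en_zh_4.json is expected to look like the following
# (inferred from preprocess_function below):
#   {"translation": {"en": "...", "zh": "..."}}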
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq
model_checkpoint = "Helsinki-NLP/opus-mt-en-zh"
# return_tensors is a per-call argument, not a from_pretrained option, so it is omitted here
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
# model = torch.nn.DataParallel(model)  # optional: multi-GPU data parallelism
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # optional: pin training to one GPU
# Move the model to GPU when available instead of calling model.cuda() unconditionally
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
max_input_length = 23
max_target_length = 23
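# Short limits; the banking phrases in this corpus are presumably brief
# (assumption: raise both values if longer sentences get truncated).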
def preprocess_function(examples):
    inputs = [ex["en"] for ex in examples["translation"]]
    targets = [ex["zh"] for ex in examples["translation"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)
    # Set up the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=max_target_length, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
tokenized_datasets = split_datasets.map(
    preprocess_function,
    batched=True,
    remove_columns=split_datasets["train"].column_names,
)
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
# Sanity check: collate two examples to confirm padding and -100 label masking work
batch = data_collator([tokenized_datasets["train"][i] for i in range(1, 3)])
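# Minimal inspection of the collated batch (assumes the tensor keys produced by
# DataCollatorForSeq2Seq with a Marian model; uncomment to verify):
# print({k: v.shape for k, v in batch.items()})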
# sacreBLEU matches the nested-list reference format built below (assumption:
# the original script intended this metric, since `metric` was otherwise undefined)
metric = load_metric("sacrebleu")
def compute_metrics(eval_preds):
    preds, labels = eval_preds
    # In case the model returns more than the prediction logits
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    # Replace -100s in the labels as we can't decode them
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    # Some simple post-processing
    decoded_preds = [pred.strip() for pred in decoded_preds]
    decoded_labels = [[label.strip()] for label in decoded_labels]
    return metric.compute(predictions=decoded_preds, references=decoded_labels)
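# sacrebleu's compute() returns a dict; the headline BLEU number is result["score"].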
from transformers import Seq2SeqTrainingArguments
args = Seq2SeqTrainingArguments(
    "marian-finetuned-kde4-en-to-zh",
    evaluation_strategy="no",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=128,  # was 32
    per_device_eval_batch_size=64,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=60,
    predict_with_generate=True,
    fp16=True,
    push_to_hub=False,
)
from transformers import Seq2SeqTrainer
trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
trainer.train()
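# evaluation_strategy="no" skips in-training eval; BLEU on the held-out split
# can be checked manually after training (uncomment to run):
# print(trainer.evaluate(max_length=max_target_length))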
trainer.save_model("./more models/test-ml-trained_4")
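# To try the fine-tuned model, mirror the sanity test at the top of this file
# (path assumed to match the save_model() call above):
# from transformers import pipeline
# translator = pipeline("translation", model="./more models/test-ml-trained_4")
# print(translator("FRST")[0]["translation_text"])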