ZeeAI1 commited on
Commit
27d1a88
·
verified ·
1 Parent(s): 02927f4

Update train_flan_t5.py

Browse files
Files changed (1) hide show
  1. train_flan_t5.py +3 -47
train_flan_t5.py CHANGED
@@ -1,47 +1,3 @@
1
- from datasets import load_dataset
2
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainer, Seq2SeqTrainingArguments
3
-
4
- model_checkpoint = "google/flan-t5-large"
5
- output_dir = "./finetuned-flan-t5"
6
-
7
- # Load dataset
8
- dataset = load_dataset("json", data_files={"train": "train_data.jsonl"})
9
-
10
- # Tokenizer
11
- tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
12
-
13
- def preprocess_function(examples):
14
- inputs = examples["input"]
15
- targets = examples["output"]
16
- model_inputs = tokenizer(inputs, max_length=512, truncation=True)
17
- labels = tokenizer(targets, max_length=128, truncation=True)
18
- model_inputs["labels"] = labels["input_ids"]
19
- return model_inputs
20
-
21
- tokenized_datasets = dataset.map(preprocess_function, batched=True)
22
-
23
- # Model
24
- model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
25
-
26
- # Training arguments
27
- training_args = Seq2SeqTrainingArguments(
28
- output_dir=output_dir,
29
- evaluation_strategy="no",
30
- learning_rate=5e-5,
31
- per_device_train_batch_size=2,
32
- num_train_epochs=3,
33
- weight_decay=0.01,
34
- save_total_limit=2,
35
- push_to_hub=False
36
- )
37
-
38
- trainer = Seq2SeqTrainer(
39
- model=model,
40
- args=training_args,
41
- train_dataset=tokenized_datasets["train"]
42
- )
43
-
44
- trainer.train()
45
-
46
- model.save_pretrained(output_dir)
47
- tokenizer.save_pretrained(output_dir)
 
1
+ tokenizer_config.json
2
+ special_tokens_map.json
3
+ spiece.model (for T5 models)