Matthew Hollings committed
Commit 5497d17 · 1 Parent(s): ec6ed9c
Files changed (1)
  1. example.py +659 -0
example.py ADDED
@@ -0,0 +1,659 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.

Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
https://huggingface.co/models?filter=text-generation
"""
# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
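#
# Example invocation (illustrative; the model name, dataset, and output path below are placeholders
# to adapt to your own setup):
#
#   python example.py \
#       --model_name_or_path gpt2 \
#       --dataset_name wikitext \
#       --dataset_config_name wikitext-2-raw-v1 \
#       --do_train \
#       --do_eval \
#       --output_dir /tmp/test-clm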
import logging
import math
import os
import sys
from dataclasses import dataclass, field
from itertools import chain
from typing import Optional

import datasets
from datasets import load_dataset

import evaluate
import transformers
from transformers import (
    CONFIG_MAPPING,
    MODEL_FOR_CAUSAL_LM_MAPPING,
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    default_data_collator,
    is_torch_tpu_available,
    set_seed,
)
from transformers.testing_utils import CaptureLogger
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version


# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
check_min_version("4.23.0.dev0")

require_version(
    "datasets>=1.8.0",
    "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt",
)

logger = logging.getLogger(__name__)


MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={
            "help": "If training from scratch, pass a model type from the list: "
            + ", ".join(MODEL_TYPES)
        },
    )
    config_overrides: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override some existing default config settings when a model is trained from scratch. Example: "
                "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
            )
        },
    )
    config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained config name or path if not the same as model_name"
        },
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained tokenizer name or path if not the same as model_name"
        },
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": "Where do you want to store the pretrained models downloaded from huggingface.co"
        },
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={
            "help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."
        },
    )
    model_revision: str = field(
        default="main",
        metadata={
            "help": "The specific model version to use (can be a branch name, tag name or commit id)."
        },
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": (
                "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
                "with private models)."
            )
        },
    )

    def __post_init__(self):
        if self.config_overrides is not None and (
            self.config_name is not None or self.model_name_or_path is not None
        ):
            raise ValueError(
                "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
            )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    dataset_name: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the dataset to use (via the datasets library)."},
    )
    dataset_config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The configuration name of the dataset to use (via the datasets library)."
        },
    )
    train_file: Optional[str] = field(
        default=None, metadata={"help": "The input training data file (a text file)."}
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={
            "help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )

    block_size: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "Optional input sequence length after tokenization. "
                "The training dataset will be truncated in blocks of this size for training. "
                "Defaults to the model max input length for single sentence inputs (take into account special tokens)."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets"},
    )
    validation_split_percentage: Optional[int] = field(
        default=5,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    keep_linebreaks: bool = field(
        default=True,
        metadata={"help": "Whether to keep line breaks when using TXT files or not."},
    )

    def __post_init__(self):
        if (
            self.dataset_name is None
            and self.train_file is None
            and self.validation_file is None
        ):
            raise ValueError(
                "Need either a dataset name or a training/validation file."
            )
        else:
            if self.train_file is not None:
                extension = self.train_file.split(".")[-1]
                assert extension in [
                    "csv",
                    "json",
                    "txt",
                ], "`train_file` should be a csv, a json or a txt file."
            if self.validation_file is not None:
                extension = self.validation_file.split(".")[-1]
                assert extension in [
                    "csv",
                    "json",
                    "txt",
                ], "`validation_file` should be a csv, a json or a txt file."


def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments)
    )
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1])
        )
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_clm", model_args, data_args)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Detecting last checkpoint.
    last_checkpoint = None
    if (
        os.path.isdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif (
            last_checkpoint is not None and training_args.resume_from_checkpoint is None
        ):
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
    # 'text' is found. You can easily tweak this behavior (see below).
    #
    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
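    # For example (illustrative), a JSON lines training file could hold one object per line, such as
    # {"text": "A sentence to train on."}; the "text" column would then be picked up automatically.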
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
        if "validation" not in raw_datasets.keys():
            raw_datasets["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
                use_auth_token=True if model_args.use_auth_token else None,
            )
            raw_datasets["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
                use_auth_token=True if model_args.use_auth_token else None,
            )
    else:
        data_files = {}
        dataset_args = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        extension = (
            data_args.train_file.split(".")[-1]
            if data_args.train_file is not None
            else data_args.validation_file.split(".")[-1]
        )
        if extension == "txt":
            extension = "text"
            dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
        raw_datasets = load_dataset(
            extension,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
            **dataset_args,
        )
        # If no validation data is there, validation_split_percentage will be used to divide the dataset.
        if "validation" not in raw_datasets.keys():
            raw_datasets["validation"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
                use_auth_token=True if model_args.use_auth_token else None,
                **dataset_args,
            )
            raw_datasets["train"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
                use_auth_token=True if model_args.use_auth_token else None,
                **dataset_args,
            )

    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.

    config_kwargs = {
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "use_auth_token": True if model_args.use_auth_token else None,
    }
    if model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(
            model_args.model_name_or_path, **config_kwargs
        )
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")
        if model_args.config_overrides is not None:
            logger.info(f"Overriding config: {model_args.config_overrides}")
            config.update_from_string(model_args.config_overrides)
            logger.info(f"New config: {config}")

    tokenizer_kwargs = {
        "cache_dir": model_args.cache_dir,
        "use_fast": model_args.use_fast_tokenizer,
        "revision": model_args.model_revision,
        "use_auth_token": True if model_args.use_auth_token else None,
    }
    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name, **tokenizer_kwargs
        )
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path, **tokenizer_kwargs
        )
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if model_args.model_name_or_path:
        model = AutoModelForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        model = AutoModelForCausalLM.from_config(config)
        n_params = sum(
            dict((p.data_ptr(), p.numel()) for p in model.parameters()).values()
        )
        logger.info(
            f"Training new model from scratch - Total size={n_params/2**20:.2f}M params"
        )

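    # Make sure the model's embedding matrix covers the tokenizer vocabulary (relevant when the
    # tokenizer was extended or differs from the one the checkpoint was trained with).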
    model.resize_token_embeddings(len(tokenizer))

    # Preprocessing the datasets.
    # First we tokenize all the texts.
    if training_args.do_train:
        column_names = raw_datasets["train"].column_names
    else:
        column_names = raw_datasets["validation"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
    tok_logger = transformers.utils.logging.get_logger(
        "transformers.tokenization_utils_base"
    )

    def tokenize_function(examples):
        with CaptureLogger(tok_logger) as cl:
            output = tokenizer(examples[text_column_name])
        # clm input could be much much longer than block_size
        if "Token indices sequence length is longer than the" in cl.out:
            tok_logger.warning(
                "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
                " before being passed to the model."
            )
        return output

    with training_args.main_process_first(desc="dataset map tokenization"):
        tokenized_datasets = raw_datasets.map(
            tokenize_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on dataset",
        )

    if data_args.block_size is None:
        block_size = tokenizer.model_max_length
        if block_size > 1024:
            logger.warning(
                f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
                "Picking 1024 instead. You can change that default value by passing --block_size xxx."
            )
            block_size = 1024
    else:
        if data_args.block_size > tokenizer.model_max_length:
            logger.warning(
                f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model "
                f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
            )
        block_size = min(data_args.block_size, tokenizer.model_max_length)

    # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
    def group_texts(examples):
        # Concatenate all texts.
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
        # customize this part to your needs.
        if total_length >= block_size:
            total_length = (total_length // block_size) * block_size
        # Split by chunks of max_len.
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result

    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
    # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
    # to preprocess.
    #
    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
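    # For illustration (hypothetical numbers): with block_size=4, two tokenized examples with ids
    # [1, 2, 3, 4, 5, 6] and [7, 8, 9, 10] are concatenated into [1..10] and split into the chunks
    # [1, 2, 3, 4] and [5, 6, 7, 8]; the trailing remainder [9, 10] is dropped, and "labels" is a
    # copy of "input_ids" (the model shifts it internally to predict the next token).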
    with training_args.main_process_first(desc="grouping texts together"):
        lm_datasets = tokenized_datasets.map(
            group_texts,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=not data_args.overwrite_cache,
            desc=f"Grouping texts in chunks of {block_size}",
        )

    if training_args.do_train:
        if "train" not in tokenized_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = lm_datasets["train"]
        if data_args.max_train_samples is not None:
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))

    if training_args.do_eval:
        if "validation" not in tokenized_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = lm_datasets["validation"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))

    def preprocess_logits_for_metrics(logits, labels):
        if isinstance(logits, tuple):
            # Depending on the model and config, logits may contain extra tensors,
            # like past_key_values, but logits always come first
            logits = logits[0]
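        # Returning only the predicted token ids keeps the Trainer from accumulating full-vocabulary logits in memory.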
        return logits.argmax(dim=-1)

    metric = evaluate.load("accuracy")

    def compute_metrics(eval_preds):
        preds, labels = eval_preds
        # preds have the same shape as the labels, after the argmax(-1) has been calculated
        # by preprocess_logits_for_metrics but we need to shift the labels
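        # (in causal language modeling, the prediction at position i is compared against the token at position i + 1)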
        labels = labels[:, 1:].reshape(-1)
        preds = preds[:, :-1].reshape(-1)
        return metric.compute(predictions=preds, references=labels)

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        # Data collator will default to DataCollatorWithPadding, so we change it.
        data_collator=default_data_collator,
        compute_metrics=compute_metrics
        if training_args.do_eval and not is_torch_tpu_available()
        else None,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics
        if training_args.do_eval and not is_torch_tpu_available()
        else None,
    )

    # Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()  # Saves the tokenizer too for easy upload

        metrics = train_result.metrics

        max_train_samples = (
            data_args.max_train_samples
            if data_args.max_train_samples is not None
            else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        metrics = trainer.evaluate()

        max_eval_samples = (
            data_args.max_eval_samples
            if data_args.max_eval_samples is not None
            else len(eval_dataset)
        )
        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
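        # Perplexity is the exponential of the average cross-entropy eval loss; very large losses overflow to inf.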
        try:
            perplexity = math.exp(metrics["eval_loss"])
        except OverflowError:
            perplexity = float("inf")
        metrics["perplexity"] = perplexity

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    kwargs = {
        "finetuned_from": model_args.model_name_or_path,
        "tasks": "text-generation",
    }
    if data_args.dataset_name is not None:
        kwargs["dataset_tags"] = data_args.dataset_name
        if data_args.dataset_config_name is not None:
            kwargs["dataset_args"] = data_args.dataset_config_name
            kwargs[
                "dataset"
            ] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
        else:
            kwargs["dataset"] = data_args.dataset_name

    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(**kwargs)


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()