TRL documentation

Iterative Trainer

Iterative fine-tuning is a training method that lets you perform custom actions (for example, generation and filtering) between optimization steps. TRL provides an easy-to-use API for fine-tuning your models iteratively in just a few lines of code.

Quickstart

To get started quickly, you can either pass a model identifier or a pre-instantiated model to the trainer:

from trl import IterativeSFTConfig, IterativeSFTTrainer

# Using a model identifier
trainer = IterativeSFTTrainer(
    "facebook/opt-350m",
    args=IterativeSFTConfig(
        max_length=512,
        output_dir="./output",
    ),
)

# Or using a pre-instantiated model
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

trainer = IterativeSFTTrainer(
    model,
    args=IterativeSFTConfig(
        max_length=512,
        output_dir="./output",
    ),
    processing_class=tokenizer,
)

Usage

The IterativeSFTTrainer supports two ways of providing input data to the step function:

Using a list of tensors as input:

inputs = {
    "input_ids": input_ids,
    "attention_mask": attention_mask,
}

trainer.step(**inputs)

Using a list of strings as input:

inputs = {
    "texts": texts,
    "texts_labels": texts_labels,  # Optional, defaults to texts
}

trainer.step(**inputs)

For causal language models, labels are created automatically from input_ids or from texts. When using sequence-to-sequence models, you must provide your own labels or texts_labels.
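The two input forms above are meant to be called repeatedly, with custom work in between. A minimal sketch of that generate, filter, step pattern follows. To stay runnable without downloading a model, it uses a stand-in trainer: `StubTrainer`, `generate_candidates`, and `keep` are hypothetical names invented for this illustration and are not part of TRL. With TRL installed, the same loop would presumably call `step(texts=...)` on an IterativeSFTTrainer instead.

```python
from typing import Optional


class StubTrainer:
    """Hypothetical stand-in that mimics the step(texts=...) call shape."""

    def __init__(self) -> None:
        self.seen: list = []

    def step(self, texts: list, texts_labels: Optional[list] = None) -> dict:
        # Mirror the documented default: labels fall back to the inputs.
        if texts_labels is None:
            texts_labels = texts
        self.seen.append(list(texts))
        return {"num_examples": len(texts)}


def generate_candidates(round_idx: int) -> list:
    # Hypothetical generation step; a real loop would sample from the model.
    return [f"round {round_idx} sample {i}" * (i + 1) for i in range(4)]


def keep(text: str) -> bool:
    # Hypothetical filter: discard generations longer than 40 characters.
    return len(text) <= 40


def run(trainer, num_rounds: int) -> list:
    stats = []
    for round_idx in range(num_rounds):
        candidates = generate_candidates(round_idx)
        selected = [text for text in candidates if keep(text)]
        if selected:  # skip the optimization step if nothing survived
            stats.append(trainer.step(texts=selected))
    return stats


trainer = StubTrainer()
history = run(trainer, num_rounds=3)
```

The filter runs between optimization steps, which is the part a standard training loop over a fixed dataset does not give you.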

Configuration

The IterativeSFTConfig class provides several parameters to customize the training:

from trl import IterativeSFTConfig

config = IterativeSFTConfig(
    # Model initialization parameters
    model_init_kwargs={"torch_dtype": "bfloat16"},

    # Data preprocessing parameters
    max_length=512,
    truncation_mode="keep_end",

    # Training parameters
    output_dir="./output",
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    max_steps=1000,
    logging_steps=10,
    save_steps=100,
    optim="adamw_torch",
    report_to="wandb",
)

Model Initialization

You can control how the model is initialized by passing keyword arguments to model_init_kwargs:

config = IterativeSFTConfig(
    model_init_kwargs={
        "torch_dtype": "bfloat16",
        "device_map": "auto",
        "trust_remote_code": True,
    }
)

Data Preprocessing

The trainer supports two truncation modes:

  • keep_end: Keeps the end of the sequence; tokens are truncated from the start.
  • keep_start: Keeps the start of the sequence; tokens are truncated from the end.

config = IterativeSFTConfig(
    max_length=512,
    truncation_mode="keep_end",  # or "keep_start"
)
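The difference between the two modes is easiest to see on a short list of token ids. The helper below is a plain-Python illustration of the behavior described above, not the trainer's actual implementation; the function name `truncate` is an assumption made for this sketch.

```python
def truncate(token_ids: list, max_length: int, truncation_mode: str = "keep_end") -> list:
    """Illustrative truncation: return at most max_length token ids."""
    if max_length <= 0:
        raise ValueError("max_length must be positive")
    if len(token_ids) <= max_length:
        return list(token_ids)  # nothing to truncate
    if truncation_mode == "keep_end":
        return token_ids[-max_length:]  # drop tokens from the start
    if truncation_mode == "keep_start":
        return token_ids[:max_length]  # drop tokens from the end
    raise ValueError(f"Unknown truncation mode: {truncation_mode}")


ids = list(range(10))
kept_end = truncate(ids, 4, "keep_end")
kept_start = truncate(ids, 4, "keep_start")
```

Here `kept_end` holds the last four ids and `kept_start` the first four.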

Training Optimization

You can optimize CUDA cache usage for more memory-efficient training:

config = IterativeSFTConfig(
    optimize_device_cache=True,
)

IterativeSFTTrainer

class trl.IterativeSFTTrainer

( model: typing.Union[str, transformers.modeling_utils.PreTrainedModel] args: typing.Union[trl.trainer.iterative_sft_config.IterativeSFTConfig, transformers.training_args.TrainingArguments, NoneType] = None data_collator: typing.Optional[transformers.data.data_collator.DataCollator] = None eval_dataset: typing.Union[datasets.arrow_dataset.Dataset, dict[str, datasets.arrow_dataset.Dataset], NoneType] = None processing_class: typing.Union[transformers.tokenization_utils_base.PreTrainedTokenizerBase, transformers.image_processing_utils.BaseImageProcessor, transformers.feature_extraction_utils.FeatureExtractionMixin, transformers.processing_utils.ProcessorMixin, NoneType] = None optimizers: tuple = (None, None) preprocess_logits_for_metrics: typing.Optional[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None compute_metrics: typing.Optional[typing.Callable[[transformers.trainer_utils.EvalLoopOutput], dict]] = None max_length: typing.Optional[int] = None truncation_mode: typing.Optional[str] = None optimize_device_cache: typing.Optional[bool] = None )

Parameters

  • model (Union[str, PreTrainedModel]) — Model to be trained. Can be either:

    • A string, being the model id of a pretrained model hosted inside a model repo on huggingface.co, or a path to a directory containing model weights saved using save_pretrained, e.g., './my_model_directory/'. The model is loaded using from_pretrained with the keyword arguments in args.model_init_kwargs.
    • A PreTrainedModel object. Only causal language models are supported.
  • args (IterativeSFTConfig, optional, defaults to None) — Configuration for this trainer. If None, a default configuration is used.
  • data_collator (DataCollator, optional) — Function used to form a batch from a list of elements of the processed train_dataset or eval_dataset. Defaults to default_data_collator if no processing_class is provided; otherwise, if the processing_class is a feature extractor or tokenizer, defaults to an instance of DataCollatorWithPadding.
  • eval_dataset (datasets.Dataset) — The dataset to use for evaluation.
  • processing_class (PreTrainedTokenizerBase, optional, defaults to None) — Processing class used to process the data. If None, the processing class is loaded from the model’s name with from_pretrained.
  • optimizers (tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]) — The optimizer and scheduler to use for training.
  • preprocess_logits_for_metrics (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]) — The function to use to preprocess the logits before computing the metrics.
  • compute_metrics (Callable[[EvalPrediction], dict], optional) — The function to use to compute the metrics. Must take an EvalPrediction and return a dictionary mapping strings to metric values.
  • max_length (int, optional, deprecated) — Maximum length of the tokenized sequence. Use args.max_length instead.
  • truncation_mode (str, optional, deprecated) — The truncation mode to use. Use args.truncation_mode instead.
  • optimize_device_cache (bool, optional, deprecated) — Whether to optimize CUDA cache. Use args.optimize_device_cache instead.

The IterativeSFTTrainer can be used to fine-tune models with methods that require additional steps between optimization steps.

create_model_card

( model_name: typing.Optional[str] = None dataset_name: typing.Optional[str] = None tags: typing.Union[str, list[str], NoneType] = None )

Parameters

  • model_name (str or None, optional, defaults to None) — Name of the model.
  • dataset_name (str or None, optional, defaults to None) — Name of the dataset used for training.
  • tags (str, list[str] or None, optional, defaults to None) — Tags to be associated with the model card.

Creates a draft of a model card using the information available to the Trainer.

step

( input_ids: typing.Optional[list[torch.LongTensor]] = None attention_mask: typing.Optional[list[torch.LongTensor]] = None labels: typing.Optional[list[torch.LongTensor]] = None texts: typing.Optional[list[str]] = None texts_labels: typing.Optional[list[str]] = None ) → dict[str, Any]

Parameters

  • input_ids (list[torch.LongTensor], optional) — List of tensors containing the input_ids (if not provided, texts will be used)
  • attention_mask (list[torch.LongTensor], optional) — List of tensors containing the attention_mask
  • labels (list[torch.LongTensor], optional) — List of tensors containing the labels (if set to None, will default to input_ids)
  • texts (list[str], optional) — List of strings containing the text input (if not provided, input_ids will be used directly)
  • texts_labels (list[str], optional) — List of strings containing the text labels (if set to None, will default to texts)

Returns

dict[str, Any]

A summary of the training statistics

Run an optimization step given either a list of input_ids, attention_mask, and labels, or a list of texts and texts_labels.

IterativeSFTConfig

class trl.IterativeSFTConfig

( output_dir: typing.Optional[str] = None overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: typing.Optional[int] = None per_gpu_eval_batch_size: typing.Optional[int] = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: typing.Optional[int] = None eval_delay: typing.Optional[float] = 0 torch_empty_cache_steps: typing.Optional[int] = None learning_rate: float = 5e-05 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear' lr_scheduler_kwargs: typing.Union[dict, str, NoneType] = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: typing.Optional[str] = 'passive' log_level_replica: typing.Optional[str] = 'warning' log_on_each_node: bool = True logging_dir: typing.Optional[str] = None logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' logging_first_step: bool = False logging_steps: float = 500 logging_nan_inf_filter: bool = True save_strategy: typing.Union[transformers.trainer_utils.SaveStrategy, str] = 'steps' save_steps: float = 500 save_total_limit: typing.Optional[int] = None save_safetensors: typing.Optional[bool] = True save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: typing.Optional[int] = None jit_mode_eval: bool = False use_ipex: bool = False bf16: bool = False fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto' 
bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: typing.Optional[bool] = None local_rank: int = -1 ddp_backend: typing.Optional[str] = None tpu_num_cores: typing.Optional[int] = None tpu_metrics_debug: bool = False debug: typing.Union[str, list[transformers.debug_utils.DebugOption]] = '' dataloader_drop_last: bool = False eval_steps: typing.Optional[float] = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: typing.Optional[int] = None past_index: int = -1 run_name: typing.Optional[str] = None disable_tqdm: typing.Optional[bool] = None remove_unused_columns: typing.Optional[bool] = True label_names: typing.Optional[list[str]] = None load_best_model_at_end: typing.Optional[bool] = False metric_for_best_model: typing.Optional[str] = None greater_is_better: typing.Optional[bool] = None ignore_data_skip: bool = False fsdp: typing.Union[list[transformers.trainer_utils.FSDPOption], str, NoneType] = '' fsdp_min_num_params: int = 0 fsdp_config: typing.Union[dict, str, NoneType] = None fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None accelerator_config: typing.Union[dict, str, NoneType] = None deepspeed: typing.Union[dict, str, NoneType] = None label_smoothing_factor: float = 0.0 optim: typing.Union[transformers.training_args.OptimizerNames, str] = 'adamw_torch' optim_args: typing.Optional[str] = None adafactor: bool = False group_by_length: bool = False length_column_name: typing.Optional[str] = 'length' report_to: typing.Union[NoneType, str, list[str]] = None ddp_find_unused_parameters: typing.Optional[bool] = None ddp_bucket_cap_mb: typing.Optional[int] = None ddp_broadcast_buffers: typing.Optional[bool] = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: typing.Optional[str] = None hub_model_id: typing.Optional[str] = None hub_strategy: 
typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save' hub_token: typing.Optional[str] = None hub_private_repo: typing.Optional[bool] = None hub_always_push: bool = False gradient_checkpointing: bool = False gradient_checkpointing_kwargs: typing.Union[dict, str, NoneType] = None include_inputs_for_metrics: bool = False include_for_metrics: list = <factory> eval_do_concat_batches: bool = True fp16_backend: str = 'auto' push_to_hub_model_id: typing.Optional[str] = None push_to_hub_organization: typing.Optional[str] = None push_to_hub_token: typing.Optional[str] = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: typing.Optional[str] = None ray_scope: typing.Optional[str] = 'last' ddp_timeout: typing.Optional[int] = 1800 torch_compile: bool = False torch_compile_backend: typing.Optional[str] = None torch_compile_mode: typing.Optional[str] = None include_tokens_per_second: typing.Optional[bool] = False include_num_input_tokens_seen: typing.Optional[bool] = False neftune_noise_alpha: typing.Optional[float] = None optim_target_modules: typing.Union[NoneType, str, list[str]] = None batch_eval_metrics: bool = False eval_on_start: bool = False use_liger_kernel: typing.Optional[bool] = False eval_use_gather_object: typing.Optional[bool] = False average_tokens_across_devices: typing.Optional[bool] = False model_init_kwargs: typing.Optional[dict[str, typing.Any]] = None max_length: typing.Optional[int] = None truncation_mode: str = 'keep_end' optimize_device_cache: bool = False )

Parameters that control the model

  • model_init_kwargs (dict[str, Any] or None, optional, defaults to None) — Keyword arguments for from_pretrained, used when the model argument of the IterativeSFTTrainer is provided as a string.

Parameters that control the data preprocessing

  • max_length (int or None, optional, defaults to None) — Maximum length of the tokenized sequence. Sequences longer than max_length are truncated.
  • truncation_mode (str, optional, defaults to "keep_end") — The truncation mode to use, either "keep_end" or "keep_start".
  • optimize_device_cache (bool, optional, defaults to False) — Whether to optimize CUDA cache for slightly more memory-efficient training.

Configuration class for the IterativeSFTTrainer.

Only the parameters specific to iterative SFT training are listed here. For details on other parameters, refer to the TrainingArguments documentation.

Using HfArgumentParser we can turn this class into argparse arguments that can be specified on the command line.
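To make the idea concrete, here is a standard-library sketch of what "turning a config class into argparse arguments" means: one command-line flag per dataclass field. It is an analogy, not how HfArgumentParser is implemented. `DemoConfig`, `build_parser`, and `parse_config` are hypothetical names, and `DemoConfig` only borrows three field names from IterativeSFTConfig.

```python
import argparse
from dataclasses import dataclass, fields


@dataclass
class DemoConfig:
    """Hypothetical stand-in with three fields borrowed from IterativeSFTConfig."""

    output_dir: str = "./output"
    max_length: int = 512
    truncation_mode: str = "keep_end"


def build_parser(config_cls) -> argparse.ArgumentParser:
    # Create one --flag per dataclass field, typed and defaulted from the field.
    parser = argparse.ArgumentParser()
    for field in fields(config_cls):
        parser.add_argument(f"--{field.name}", type=field.type, default=field.default)
    return parser


def parse_config(argv: list) -> DemoConfig:
    namespace = build_parser(DemoConfig).parse_args(argv)
    return DemoConfig(**vars(namespace))


config = parse_config(["--max_length", "256", "--truncation_mode", "keep_start"])
```

With transformers and TRL installed, the real equivalent would be along the lines of `HfArgumentParser(IterativeSFTConfig).parse_args_into_dataclasses()`, which returns a tuple containing the populated config.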