import importlib.metadata
import os
from typing import List, Optional

import torch
import torch.nn as nn
from packaging import version
from peft import PeftModel
from torch.utils.data import Sampler
from transformers import Trainer
from transformers.trainer import (
    ALL_LAYERNORM_LAYERS,
    get_parameter_names,
    has_length,
    is_sagemaker_mp_enabled,
    logger,
)
from transformers.trainer_pt_utils import get_dataloader_sampler
from transformers.trainer_pt_utils import (
    get_length_grouped_indices as get_length_grouped_indices_hf,
)
from transformers.trainer_pt_utils import get_model_param_count, get_parameter_names
from transformers.trainer_utils import (
    HPSearchBackend,
    TrainOutput,
    has_length,
    speed_metrics,
)
from transformers.training_args import ParallelMode
from transformers.utils import (
    is_accelerate_available,
    is_peft_available,
    is_sagemaker_mp_enabled,
    is_torch_xla_available,
)

# Environment-driven settings, read once at import time.
TIME_STAMP = os.environ.get("TIME_STAMP", "default_value")
BYTENAS = os.environ.get("BYTENAS", "vl-research")


def maybe_zero_3(param, ignore_status=False, name=None):
    """Return a detached, cloned CPU copy of ``param``.

    If ``param`` has a ``ds_id`` attribute it is treated as a DeepSpeed-managed
    parameter and is gathered with ``zero.GatheredParameters`` before copying.

    Args:
        param: The parameter tensor to copy.
        ignore_status: When False, a message is printed for DeepSpeed parameters
            whose ``ds_status`` is ``ZeroParamStatus.NOT_AVAILABLE``.
        name: Label included in that printed message.

    Returns:
        A detached, cloned tensor on the CPU.
    """
    # Imported lazily so that this module can be imported without deepspeed.
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if hasattr(param, "ds_id"):
        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            if not ignore_status:
                print(name, "no ignore status")
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param


def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
    """Collect CPU copies of the parameters whose name contains any of ``keys_to_match``.

    Args:
        named_params: Iterable of ``(name, tensor)`` pairs, e.g. ``model.named_parameters()``.
        keys_to_match: Substrings; a parameter is kept if its name contains any of them.

    Returns:
        Dict mapping parameter name to a CPU tensor (gathered via ``maybe_zero_3``).
    """
    to_return = {
        k: t
        for k, t in named_params
        if any(key_match in k for key_match in keys_to_match)
    }
    to_return = {
        k: maybe_zero_3(v, ignore_status=True, name=k).cpu()
        for k, v in to_return.items()
    }
    return to_return


def split_to_even_chunks(indices, lengths, num_chunks):
    """Split ``indices`` into ``num_chunks`` chunks of roughly equal total length.

    If ``len(indices)`` is not divisible by ``num_chunks``, indices are dealt
    round-robin (``indices[i::num_chunks]``) and lengths are ignored. Otherwise
    each index is greedily appended to the chunk with the smallest running
    length sum. A chunk stops accepting indices once it holds
    ``len(indices) // num_chunks`` of them.

    Args:
        indices: Sequence of dataset indices.
        lengths: Indexable mapping from dataset index to sample length.
        num_chunks: Number of chunks to produce.

    Returns:
        List of ``num_chunks`` lists of indices.
    """
    if len(indices) % num_chunks != 0:
        return [indices[i::num_chunks] for i in range(num_chunks)]

    num_indices_per_chunk = len(indices) // num_chunks

    chunks = [[] for _ in range(num_chunks)]
    chunks_lengths = [0 for _ in range(num_chunks)]
    for index in indices:
        shortest_chunk = chunks_lengths.index(min(chunks_lengths))
        chunks[shortest_chunk].append(index)
        chunks_lengths[shortest_chunk] += lengths[index]
        if len(chunks[shortest_chunk]) == num_indices_per_chunk:
            # Mark the chunk as full so `min` never picks it again.
            chunks_lengths[shortest_chunk] = float("inf")

    return chunks


def get_variable_length_grouped_indices(
    lengths, batch_size, world_size, megabatch_mult=8, generator=None
):
    """Return a sampling order that groups samples of similar length.

    Steps:
    1. Sort all indices by length, longest first.
    2. Cut the sorted list into megabatches of
       ``world_size * batch_size * megabatch_mult``.
    3. Reorder each megabatch by a random key, taken from a torch random
       permutation.
    4. Cut the result into world batches of ``world_size * batch_size``.
    5. Shuffle the order of those world batches.

    Returns:
        Flat list of dataset indices.
    """
    # Use torch for randomness: a distributed sampler sets torch's random seed.
    indices = torch.randperm(len(lengths), generator=generator)
    sorted_indices = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    megabatch_size = world_size * batch_size * megabatch_mult
    megabatches = [
        sorted_indices[i : i + megabatch_size]
        for i in range(0, len(lengths), megabatch_size)
    ]
    megabatches = [
        sorted(megabatch, key=lambda i: indices[i], reverse=True)
        for megabatch in megabatches
    ]
    shuffled_indices = [i for megabatch in megabatches for i in megabatch]
    world_batch_size = world_size * batch_size
    batches = [
        shuffled_indices[i : i + world_batch_size]
        for i in range(0, len(lengths), world_batch_size)
    ]
    batch_indices = torch.randperm(len(batches), generator=generator)
    batches = [batches[i] for i in batch_indices]

    return [i for batch in batches for i in batch]


def get_modality_length_grouped_indices(
    lengths, batch_size, world_size, generator=None
):
    """Return a sampling order where each megabatch contains a single modality.

    Sign convention for ``lengths``:
    - positive: multimodal samples;
    - negative: language-only samples, with length ``-l``;
    - zero: not allowed.

    If every sample has the same sign, this defers to
    ``get_length_grouped_indices``. Otherwise:
    1. Each modality is ordered with ``get_length_grouped_indices``.
    2. Each ordering is cut into megabatches of ``world_size * batch_size``.
    3. The non-final megabatches of both modalities are shuffled together.
    4. The two trailing partial megabatches are merged, sorted by index and
       appended last.

    Returns:
        Flat list of dataset indices.
    """
    # Use torch for randomness: a distributed sampler sets torch's random seed.
    assert all(l != 0 for l in lengths), "Should not have zero length."
    if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
        # all samples are in the same modality
        return get_length_grouped_indices(
            lengths, batch_size, world_size, generator=generator
        )
    mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
    lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])

    mm_shuffle = [
        mm_indices[i]
        for i in get_length_grouped_indices(
            mm_lengths, batch_size, world_size, generator=None
        )
    ]
    lang_shuffle = [
        lang_indices[i]
        for i in get_length_grouped_indices(
            lang_lengths, batch_size, world_size, generator=None
        )
    ]
    megabatch_size = world_size * batch_size
    mm_megabatches = [
        mm_shuffle[i : i + megabatch_size]
        for i in range(0, len(mm_shuffle), megabatch_size)
    ]
    lang_megabatches = [
        lang_shuffle[i : i + megabatch_size]
        for i in range(0, len(lang_shuffle), megabatch_size)
    ]

    # Trailing (possibly partial) megabatches of both modalities are combined and
    # appended last instead of being shuffled with the rest.
    last_mm = mm_megabatches[-1]
    last_lang = lang_megabatches[-1]
    additional_batch = last_mm + last_lang
    megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
    megabatch_indices = torch.randperm(len(megabatches), generator=generator)
    megabatches = [megabatches[i] for i in megabatch_indices]

    if len(additional_batch) > 0:
        megabatches.append(sorted(additional_batch))

    return [i for megabatch in megabatches for i in megabatch]


def get_length_grouped_indices(
    lengths, batch_size, world_size, generator=None, merge=True
):
    """Return a sampling order whose megabatches are balanced across ranks by length.

    Steps:
    1. Randomly permute all indices.
    2. Cut the permutation into megabatches of ``world_size * batch_size``.
    3. Sort each megabatch by length, longest first.
    4. Split each megabatch into ``world_size`` chunks of roughly equal total
       length with ``split_to_even_chunks``.

    Args:
        lengths: Per-sample lengths.
        batch_size: Per-device batch size.
        world_size: Number of chunks per megabatch.
        generator: Optional ``torch.Generator`` for the permutation.
        merge: Unused; kept for signature compatibility.

    Returns:
        Flat list of dataset indices.
    """
    # Use torch for randomness: a distributed sampler sets torch's random seed.
    indices = torch.randperm(len(lengths), generator=generator)
    megabatch_size = world_size * batch_size
    megabatches = [
        indices[i : i + megabatch_size].tolist()
        for i in range(0, len(lengths), megabatch_size)
    ]
    megabatches = [
        sorted(megabatch, key=lambda i: lengths[i], reverse=True)
        for megabatch in megabatches
    ]
    megabatches = [
        split_to_even_chunks(megabatch, lengths, world_size)
        for megabatch in megabatches
    ]

    return [i for megabatch in megabatches for batch in megabatch for i in batch]


def get_length_grouped_indices_auto_single(
    lengths, batch_size, world_size, generator=None
):
    """Return a length-grouped order built on HF's ``get_length_grouped_indices``.

    Steps:
    1. Start from HF's ordering, computed with batch size
       ``batch_size * world_size``.
    2. Cut it into megabatches of ``world_size * batch_size``.
    3. Sort each megabatch by length, longest first.
    4. Split each into ``world_size`` balanced chunks.
    5. Shuffle the megabatch order.

    Returns:
        Flat list of dataset indices.
    """
    indices = get_length_grouped_indices_hf(
        lengths, batch_size * world_size, generator=generator
    )

    megabatch_size = world_size * batch_size
    megabatches = [
        indices[i : i + megabatch_size]
        for i in range(0, len(lengths), megabatch_size)
    ]
    megabatches = [
        sorted(megabatch, key=lambda i: lengths[i], reverse=True)
        for megabatch in megabatches
    ]
    megabatches = [
        split_to_even_chunks(megabatch, lengths, world_size)
        for megabatch in megabatches
    ]

    # Use torch for randomness: a distributed sampler sets torch's random seed.
    batch_indices = torch.randperm(len(megabatches), generator=generator)
    megabatches = [megabatches[i] for i in batch_indices]

    return [i for megabatch in megabatches for batch in megabatch for i in batch]


def get_modality_length_grouped_indices_auto(
    lengths, batch_size, world_size, generator=None
):
    """Same as ``get_modality_length_grouped_indices`` but per-modality ordering
    uses ``get_length_grouped_indices_auto_single``.

    Positive lengths are multimodal samples, negative lengths are language-only
    samples, and zero is not allowed.

    Returns:
        Flat list of dataset indices.
    """
    # Use torch for randomness: a distributed sampler sets torch's random seed.
    assert all(l != 0 for l in lengths), "Should not have zero length."
    if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
        # all samples are in the same modality
        return get_length_grouped_indices_auto_single(
            lengths, batch_size, world_size, generator=generator
        )
    mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
    lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])

    mm_shuffle = [
        mm_indices[i]
        for i in get_length_grouped_indices_auto_single(
            mm_lengths, batch_size, world_size, generator=None
        )
    ]
    lang_shuffle = [
        lang_indices[i]
        for i in get_length_grouped_indices_auto_single(
            lang_lengths, batch_size, world_size, generator=None
        )
    ]
    megabatch_size = world_size * batch_size
    mm_megabatches = [
        mm_shuffle[i : i + megabatch_size]
        for i in range(0, len(mm_shuffle), megabatch_size)
    ]
    lang_megabatches = [
        lang_shuffle[i : i + megabatch_size]
        for i in range(0, len(lang_shuffle), megabatch_size)
    ]

    # Trailing (possibly partial) megabatches are merged and appended last.
    last_mm = mm_megabatches[-1]
    last_lang = lang_megabatches[-1]
    additional_batch = last_mm + last_lang
    megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
    megabatch_indices = torch.randperm(len(megabatches), generator=generator)
    megabatches = [megabatches[i] for i in megabatch_indices]

    if len(additional_batch) > 0:
        megabatches.append(sorted(additional_batch))

    return [i for megabatch in megabatches for i in megabatch]


class LengthGroupedSampler(Sampler):
    r"""
    Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
    keeping a bit of randomness.

    The grouping strategy is selected by the flags:
    - ``variable_length``: ``get_variable_length_grouped_indices``;
    - ``group_by_modality``: ``get_modality_length_grouped_indices``;
    - ``group_by_modality_auto``: ``get_modality_length_grouped_indices_auto``;
    - none of the above: ``get_length_grouped_indices_auto_single``.
    """

    def __init__(
        self,
        batch_size: int,
        world_size: int,
        lengths: Optional[List[int]] = None,
        generator=None,
        variable_length: bool = False,
        group_by_modality: bool = False,
        group_by_modality_auto: bool = False,
    ):
        if lengths is None:
            raise ValueError("Lengths must be provided.")

        self.batch_size = batch_size
        self.world_size = world_size
        self.lengths = lengths
        self.generator = generator
        self.variable_length = variable_length
        self.group_by_modality = group_by_modality
        self.group_by_modality_auto = group_by_modality_auto

    def __len__(self):
        return len(self.lengths)

    def __iter__(self):
        if self.variable_length:
            assert (
                not self.group_by_modality
            ), "Variable length grouping is not supported with modality grouping."
            indices = get_variable_length_grouped_indices(
                self.lengths, self.batch_size, self.world_size, generator=self.generator
            )
        else:
            if self.group_by_modality:
                indices = get_modality_length_grouped_indices(
                    self.lengths,
                    self.batch_size,
                    self.world_size,
                    generator=self.generator,
                )
            elif self.group_by_modality_auto:
                indices = get_modality_length_grouped_indices_auto(
                    self.lengths,
                    self.batch_size,
                    self.world_size,
                    generator=self.generator,
                )
            else:
                indices = get_length_grouped_indices_auto_single(
                    self.lengths,
                    self.batch_size,
                    self.world_size,
                    generator=self.generator,
                )
        return iter(indices)


def _is_peft_model(model):
    """Return True if ``model`` is a PEFT model (``PeftModel``, or ``PeftMixedModel`` on peft>=0.7.0)."""
    if not is_peft_available():
        return False
    classes_to_check = (PeftModel,)
    # Here we also check if the model is an instance of `PeftMixedModel` introduced in peft>=0.7.0: https://github.com/huggingface/transformers/pull/28321
    if version.parse(importlib.metadata.version("peft")) >= version.parse("0.7.0"):
        from peft import PeftMixedModel

        classes_to_check = (*classes_to_check, PeftMixedModel)
    return isinstance(model, classes_to_check)


TRAINER_STATE_NAME = "trainer_state.json"


class LLaVATrainer(Trainer):
    """HF ``Trainer`` with length/modality-grouped sampling, a separate learning rate
    for ``speech_projector`` parameters, and adapter-only checkpointing."""

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        """Pick a ``LengthGroupedSampler`` configuration from the grouping flags in ``self.args``.

        Returns None for a missing or unsized train dataset. Falls back to the
        default Trainer sampler when no grouping flag is set.

        NOTE(review): the original authors were unsure how gradient accumulation
        should be folded in; their TODOs read "seems that this may work?". Every
        mode multiplies the sampler's world size by ``gradient_accumulation_steps``.
        Only the variable-length mode also scales the per-device batch size.
        """
        if self.train_dataset is None or not has_length(self.train_dataset):
            return None

        if self.args.group_by_length:
            lengths = self.train_dataset.lengths
            return LengthGroupedSampler(
                self.args.train_batch_size,
                world_size=self.args.world_size * self.args.gradient_accumulation_steps,
                lengths=lengths,
            )
        elif self.args.group_by_modality_length:
            lengths = self.train_dataset.modality_lengths
            return LengthGroupedSampler(
                self.args.train_batch_size,
                world_size=self.args.world_size * self.args.gradient_accumulation_steps,
                lengths=lengths,
                group_by_modality=True,
            )
        elif self.args.group_by_modality_length_auto:
            lengths = self.train_dataset.modality_lengths
            return LengthGroupedSampler(
                self.args.train_batch_size,
                world_size=self.args.world_size * self.args.gradient_accumulation_steps,
                lengths=lengths,
                group_by_modality_auto=True,
            )
        elif self.args.group_by_varlen:
            lengths = self.train_dataset.lengths
            return LengthGroupedSampler(
                self.args.train_batch_size * self.args.gradient_accumulation_steps,
                world_size=self.args.world_size * self.args.gradient_accumulation_steps,
                lengths=lengths,
                variable_length=True,
            )
        else:
            return super()._get_train_sampler()

    def create_optimizer(self):
        """
        Setup the optimizer.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method in a subclass.

        Parameters are split into decay / no-decay groups. Layer-norm parameters
        and any parameter whose name contains "bias" get no weight decay. When
        ``args.speech_projector_lr`` is set, parameters whose name contains
        "speech_projector" get their own groups with that learning rate.
        """
        if is_sagemaker_mp_enabled():
            return super().create_optimizer()

        opt_model = self.model

        if self.optimizer is None:
            # Sets give O(1) membership tests in the comprehensions below; the
            # results are identical to the list-based lookups.
            decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
            decay_parameters = {name for name in decay_parameters if "bias" not in name}
            lr_mapper = {}
            if self.args.speech_projector_lr is not None:
                lr_mapper["speech_projector"] = self.args.speech_projector_lr
            if len(lr_mapper) > 0:
                special_lr_parameters = {
                    name
                    for name, _ in opt_model.named_parameters()
                    if any(module_keyword in name for module_keyword in lr_mapper)
                }
                optimizer_grouped_parameters = [
                    {
                        "params": [
                            p
                            for n, p in opt_model.named_parameters()
                            if (
                                n in decay_parameters
                                and n not in special_lr_parameters
                                and p.requires_grad
                            )
                        ],
                        "weight_decay": self.args.weight_decay,
                    },
                    {
                        "params": [
                            p
                            for n, p in opt_model.named_parameters()
                            if (
                                n not in decay_parameters
                                and n not in special_lr_parameters
                                and p.requires_grad
                            )
                        ],
                        "weight_decay": 0.0,
                    },
                ]
                for module_keyword, lr in lr_mapper.items():
                    module_parameters = {
                        name
                        for name, _ in opt_model.named_parameters()
                        if module_keyword in name
                    }
                    optimizer_grouped_parameters.extend(
                        [
                            {
                                "params": [
                                    p
                                    for n, p in opt_model.named_parameters()
                                    if (
                                        n in decay_parameters
                                        and n in module_parameters
                                        and p.requires_grad
                                    )
                                ],
                                "weight_decay": self.args.weight_decay,
                                "lr": lr,
                            },
                            {
                                "params": [
                                    p
                                    for n, p in opt_model.named_parameters()
                                    if (
                                        n not in decay_parameters
                                        and n in module_parameters
                                        and p.requires_grad
                                    )
                                ],
                                "weight_decay": 0.0,
                                "lr": lr,
                            },
                        ]
                    )
            else:
                optimizer_grouped_parameters = [
                    {
                        "params": [
                            p
                            for n, p in opt_model.named_parameters()
                            if (n in decay_parameters and p.requires_grad)
                        ],
                        "weight_decay": self.args.weight_decay,
                    },
                    {
                        "params": [
                            p
                            for n, p in opt_model.named_parameters()
                            if (n not in decay_parameters and p.requires_grad)
                        ],
                        "weight_decay": 0.0,
                    },
                ]

            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
                self.args
            )

            self.optimizer = optimizer_cls(
                optimizer_grouped_parameters, **optimizer_kwargs
            )
            if optimizer_cls.__name__ == "Adam8bit":
                import bitsandbytes

                manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

                # Embedding weights are registered to be optimized in 32-bit.
                skipped = 0
                for module in opt_model.modules():
                    if isinstance(module, nn.Embedding):
                        # Dict keyed by data_ptr counts shared storage once.
                        skipped += sum(
                            {
                                p.data_ptr(): p.numel() for p in module.parameters()
                            }.values()
                        )
                        logger.info(f"skipped {module}: {skipped/2**20}M params")
                        manager.register_module_override(
                            module, "weight", {"optim_bits": 32}
                        )
                        logger.debug(f"bitsandbytes: will optimize {module} in fp32")
                logger.info(f"skipped: {skipped/2**20}M params")

        return self.optimizer

    def _save_checkpoint(self, model, trial, metrics=None):
        """Save a checkpoint.

        When ``args.tune_mm_mlp_adapter`` is set, only ``speech_projector``
        weights are written to ``<run_dir>/checkpoint-<step>/speech_projector.bin``.
        Embedding weights are also included when ``use_im_start_end`` is set. The
        model config is written alongside them. This happens only on local rank 0
        or -1, and the parent ``_save_checkpoint`` is NOT called in this mode.
        Otherwise the parent implementation is used.
        """
        if getattr(self.args, "tune_mm_mlp_adapter", False):
            from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

            checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

            run_dir = self._get_output_dir(trial=trial)
            output_dir = os.path.join(run_dir, checkpoint_folder)

            # Only save Adapter
            keys_to_match = ["speech_projector"]
            if getattr(self.args, "use_im_start_end", False):
                keys_to_match.extend(["embed_tokens", "embed_in"])

            # Gathering is done on every rank (needed under ZeRO-3); only one rank writes.
            weight_to_save = get_mm_adapter_state_maybe_zero_3(
                self.model.named_parameters(), keys_to_match
            )

            if self.args.local_rank == 0 or self.args.local_rank == -1:
                self.model.config.save_pretrained(output_dir)
                torch.save(
                    weight_to_save, os.path.join(output_dir, "speech_projector.bin")
                )
        else:
            print("self.is_local_process_zero()", self.is_local_process_zero())
            super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics)

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Delegate to ``Trainer._save``.

        NOTE(review): the original had an ``if tune_mm_mlp_adapter: pass``
        branch, but both branches called the parent ``_save``. The model is
        therefore always saved, even in adapter-only mode. That behaviour is
        preserved here; confirm it is intended, since upstream LLaVA skips this
        save when tuning only the adapter.
        """
        super(LLaVATrainer, self)._save(output_dir, state_dict)