"""Trainer subclasses and GradCache input-splitting helpers for contrastive training."""

from transformers.trainer import Trainer, TRAINING_ARGS_NAME
import torch.distributed as dist
from typing import Optional
import os
import torch
from src.loss import (
    SimpleContrastiveLoss,
    DistributedContrastiveLoss,
    HardNegativeContrastiveLoss,
    DistributedHardNegativeContrastiveLoss,
)
from itertools import repeat
from grad_cache.grad_cache import GradCache

MAX_INPUT_ID = int(1e9)
LLAVA_IMAGE_TOKEN_ID = 32000

# Every key of the wrapped model's state dict is expected to start with this
# prefix; it is stripped before handing the weights to the encoder.
_ENCODER_PREFIX = 'encoder.'

# Fields that are split by "number of images in the chunk" rather than by
# "number of rows in the chunk". Order matters: `pixel_values` comes first.
_PER_IMAGE_KEYS = ("pixel_values", "image_sizes", "image_grid_thw")


def _save_encoder_checkpoint(trainer, output_dir, state_dict):
    """Write the encoder weights, tokenizer and training args to disk.

    Shared implementation of ``_save`` for both trainers in this module.

    Args:
        trainer: The trainer whose ``model``, ``tokenizer`` and ``args`` are
            saved.
        output_dir: Target directory; ``None`` falls back to
            ``trainer.args.output_dir``.
        state_dict: Weights to save; ``None`` means
            ``trainer.model.state_dict()``.

    Returns:
        The directory that was actually written to.

    Raises:
        ValueError: If any state-dict key does not start with ``"encoder."``.
    """
    if output_dir is None:
        output_dir = trainer.args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    if state_dict is None:
        state_dict = trainer.model.state_dict()

    # An explicit raise (instead of `assert`) so the check survives `python -O`;
    # otherwise unrelated keys would silently lose their first characters.
    unexpected = [k for k in state_dict if not k.startswith(_ENCODER_PREFIX)]
    if unexpected:
        raise ValueError(
            f"state_dict keys without the {_ENCODER_PREFIX!r} prefix: {unexpected}"
        )
    state_dict = {k[len(_ENCODER_PREFIX):]: v for k, v in state_dict.items()}

    trainer.model.encoder.save_pretrained(
        output_dir,
        state_dict=state_dict,
        safe_serialization=trainer.args.save_safetensors,
    )
    if trainer.tokenizer is not None:
        trainer.tokenizer.save_pretrained(output_dir)
    torch.save(trainer.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
    return output_dir


class MMEBTrainer(Trainer):
    """Trainer that feeds (query, target[, negative]) batches to the model."""

    def __init__(self, *args, **kwargs):
        super(MMEBTrainer, self).__init__(*args, **kwargs)
        self.is_ddp = dist.is_initialized()
        self._dist_loss_scale_factor = dist.get_world_size() if self.is_ddp else 1

    def compute_loss(self, model, inputs, *args, **kwargs):
        """Run the model on one batch and return whatever the model returns.

        ``inputs`` is unpacked as ``(qry, tgt, neg)`` when
        ``self.args.hard_neg`` is truthy and as ``(qry, tgt)`` otherwise.
        Extra positional/keyword arguments are accepted and ignored.
        """
        if self.args.hard_neg:
            qry_inputs, tgt_inputs, neg_inputs = inputs
            return model(qry=qry_inputs, tgt=tgt_inputs, neg=neg_inputs)
        qry_inputs, tgt_inputs = inputs
        return model(qry=qry_inputs, tgt=tgt_inputs)

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Save encoder weights (prefix stripped), tokenizer and training args.

        ``output_dir=None`` falls back to ``self.args.output_dir``.
        """
        _save_encoder_checkpoint(self, output_dir, state_dict)


def split_dense_inputs(model_input: dict, chunk_size: int):
    """Split a single-entry ``{name: {field: tensor}}`` input into chunks.

    Every tensor of the inner dict is split along dim 0 into pieces of
    ``chunk_size`` rows.

    Returns:
        A list of ``{name: {field: tensor_chunk}}`` dicts, one per chunk.
    """
    assert len(model_input) == 1
    arg_key = list(model_input.keys())[0]
    arg_val = model_input[arg_key]

    keys = list(arg_val.keys())
    chunked_tensors = [arg_val[k].split(chunk_size, dim=0) for k in keys]
    chunked_arg_val = [dict(zip(keys, tt)) for tt in zip(*chunked_tensors)]

    return [{arg_key: c} for c in chunked_arg_val]


def split_vlm_inputs(model_input: dict, chunk_size: int):
    """Split a single-entry vision-language input into GradCache chunks.

    ``input_ids`` and ``attention_mask`` are split by rows. The per-image
    fields (``pixel_values``, ``image_sizes``, ``image_grid_thw``) are split
    by the number of image-bearing rows that fall into each row chunk, as
    given by ``image_mask``. ``image_flags`` (if present) is split by rows.
    ``image_mask`` itself is consumed here and not forwarded.

    Assumes each row flagged in ``image_mask`` contributes exactly one entry
    along dim 0 of the per-image fields -- TODO confirm against the collator.

    Args:
        model_input: ``{name: fields}`` with exactly one entry.
        chunk_size: Number of rows per chunk.

    Returns:
        A list of ``{name: chunk_fields}`` dicts, one per chunk. A chunk whose
        ``pixel_values`` slice is empty carries only ``input_ids`` and
        ``attention_mask`` (plus ``image_flags`` when present).

    Raises:
        ValueError: If per-image fields are present but ``image_mask`` is not.
    """
    assert len(model_input) == 1
    arg_key = list(model_input.keys())[0]
    arg_val = model_input[arg_key]

    # Names and chunk tuples are kept in lockstep so that each chunk is
    # labelled by construction rather than by the input dict's key order.
    names = ["input_ids", "attention_mask"]
    chunks = [arg_val[name].split(chunk_size, dim=0) for name in names]
    num_chunks = len(chunks[0])

    image_fields = [name for name in _PER_IMAGE_KEYS if name in arg_val]
    if image_fields:
        if "image_mask" not in arg_val:
            raise ValueError(
                f"'image_mask' is required to split image fields {image_fields}"
            )
        # Row indices that contain an image. Taking dim-0 indices and
        # de-duplicating keeps this correct for masks with trailing dims.
        rows_with_image = torch.unique(
            torch.nonzero(arg_val["image_mask"], as_tuple=True)[0]
        )
        image_counts = []
        for chunk_idx in range(num_chunks):
            start = chunk_idx * chunk_size
            in_chunk = (rows_with_image >= start) & (rows_with_image < start + chunk_size)
            image_counts.append(int(torch.sum(in_chunk).item()))
        for name in image_fields:
            names.append(name)
            chunks.append(torch.split(arg_val[name], image_counts))

    has_image_flags = "image_flags" in arg_val
    if has_image_flags:
        names.append("image_flags")
        chunks.append(torch.split(arg_val["image_flags"], chunk_size))

    chunked_arg_val = []
    for parts in zip(*chunks):
        fields = dict(zip(names, parts))
        # Add the basic fields first.
        if "pixel_values" in fields and fields["pixel_values"].numel() == 0:
            # This chunk doesn't contain any image: drop all per-image fields.
            chunk_dict = {
                "input_ids": fields["input_ids"],
                "attention_mask": fields["attention_mask"],
            }
            # If image_flags is present, add the corresponding chunk.
            if has_image_flags:
                chunk_dict["image_flags"] = fields["image_flags"]
        else:
            chunk_dict = fields
        chunked_arg_val.append(chunk_dict)

    return [{arg_key: c} for c in chunked_arg_val]


def get_dense_rep(x):
    """
    Get either qry_reps or tgt_reps.

    Returns ``x["tgt_reps"]`` when ``x["qry_reps"]`` is None, otherwise
    ``x["qry_reps"]``.
    """
    if x["qry_reps"] is None:
        return x["tgt_reps"]
    return x["qry_reps"]


class GradCacheTrainer(Trainer):
    """
    Adapted from gradcache repo.
    """

    def __init__(self, *args, **kwargs):
        super(GradCacheTrainer, self).__init__(*args, **kwargs)
        self.is_ddp = dist.is_initialized()
        self._dist_loss_scale_factor = dist.get_world_size() if self.is_ddp else 1
        # Use the hard-negative losses (they replaced DistributedContrastiveLoss /
        # SimpleContrastiveLoss here).
        loss_fn_cls = DistributedHardNegativeContrastiveLoss if self.is_ddp else HardNegativeContrastiveLoss
        loss_fn = loss_fn_cls(temperature=self.model.temperature)

        # NOTE(review): `self.scaler` is only read when fp16 is enabled; confirm the
        # installed Trainer version still exposes that attribute.
        self.gc = GradCache(
            models=[self.model, self.model],
            chunk_sizes=[self.args.gc_q_chunk_size, self.args.gc_p_chunk_size],
            loss_fn=loss_fn,
            split_input_fn=split_vlm_inputs,
            get_rep_fn=get_dense_rep,
            fp16=self.args.fp16,
            scaler=self.scaler if self.args.fp16 else None
        )

    def training_step(self, model, inputs, *args, **kwargs) -> torch.Tensor:
        """Run one GradCache step and return the loss scaled for DDP.

        ``inputs`` is unpacked as ``(queries, passages, negatives)`` when
        ``self.args.hard_neg`` is truthy and as ``(queries, passages)``
        otherwise. On local rank 0 the input shapes are printed.

        Returns:
            The GradCache loss divided by the world size (1 when not in DDP).
        """
        model.train()

        q_chunk = self.args.gc_q_chunk_size
        p_chunk = self.args.gc_p_chunk_size
        if self.args.hard_neg:
            queries, passages, negatives = inputs
            named_batches = [('qry', queries), ('tgt', passages), ('neg', negatives)]
            # Negatives are chunked like passages. Without a third chunk size
            # GradCache would pair inputs with only two sizes.
            chunk_sizes = [q_chunk, p_chunk, p_chunk]
        else:
            queries, passages = inputs
            named_batches = [('qry', queries), ('tgt', passages)]
            chunk_sizes = [q_chunk, p_chunk]

        if self.args.local_rank == 0:
            for name, batch in named_batches:
                print(f"{name}.shape={batch['input_ids'].shape}")
            for name, batch in named_batches:
                if 'pixel_values' in batch:
                    print(f"{name}_img.shape={batch['pixel_values'].shape}")

        _distributed = self.args.local_rank > -1
        # One (identical) model and one chunk size per input stream.
        self.gc.models = [model] * len(named_batches)
        self.gc.chunk_sizes = chunk_sizes
        model_inputs = [{name: batch} for name, batch in named_batches]
        loss = self.gc(*model_inputs, no_sync_except_last=_distributed)

        return loss / self._dist_loss_scale_factor

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Save encoder weights, tokenizer, training args and encoder config.

        ``output_dir=None`` falls back to ``self.args.output_dir``.
        """
        if output_dir is None:
            output_dir = self.args.output_dir
        print(f"Saving model to {output_dir}")
        _save_encoder_checkpoint(self, output_dir, state_dict)
        self.model.encoder.config.to_json_file(os.path.join(output_dir, 'config.json'))