from typing import Dict, Optional

import torch
import torch.distributed as dist
from torch import nn, Tensor
from peft import LoraConfig, PeftModel, get_peft_model
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    PreTrainedModel,
    Qwen2VLForConditionalGeneration,
)

from src.arguments import ModelArguments
from src.vlm_backbone.llava_next import LlavaNextForConditionalGeneration
from src.vlm_backbone.phi3_v.modeling_phi3_v import Phi3VForCausalLM


class MMEBModel(nn.Module):
    """Embedding model that wraps a VLM backbone and pools its last hidden layer."""

    TRANSFORMER_CLS = AutoModelForCausalLM

    def __init__(self,
                 encoder: PreTrainedModel,
                 pooling: str = 'cls',
                 normalize: bool = False,
                 temperature: float = 1.0,
                 ):
        super().__init__()
        self.config = encoder.config
        self.encoder = encoder
        self.pooling = pooling
        self.normalize = normalize
        self.temperature = temperature
        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
        self.is_ddp = dist.is_initialized()
        if self.is_ddp:
            self.process_rank = dist.get_rank()
            self.world_size = dist.get_world_size()

    def encode_input(self, input):
        # Run the backbone, keep the final hidden layer, and pool it into one vector.
        hidden_states = self.encoder(**input, return_dict=True, output_hidden_states=True)
        hidden_states = hidden_states.hidden_states[-1]
        pooled_output = self._pooling(hidden_states, input['attention_mask'])
        return pooled_output

    def _pooling(self, last_hidden_state, attention_mask):
        # Only last-token pooling is implemented: take the hidden state at the
        # last attended position (mask sum minus one, which assumes right padding).
        if self.pooling == 'last' or self.pooling == 'eos':
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_state.shape[0]
            reps = last_hidden_state[
                torch.arange(batch_size, device=last_hidden_state.device), sequence_lengths]
        else:
            raise NotImplementedError
        if self.normalize:
            reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
        return reps
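    # Illustrative sketch (not part of the original code): the default pooling
    # argument is 'cls', yet `_pooling` above only implements last-token
    # pooling. If a masked mean-pooling variant were wanted, it could look
    # like the following; the method name `_mean_pooling` is hypothetical.
    def _mean_pooling(self, last_hidden_state, attention_mask):
        # Zero out padded positions, then average over the valid tokens only.
        mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
        reps = (last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        if self.normalize:
            reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
        return reps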
    @classmethod
    def build(cls, model_args: ModelArguments, **hf_kwargs):
        # Load the base backbone for the requested model family.
        lora_target_modules = None
        if model_args.model_backbone == "llava_next":
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            config.use_cache = False
            config.padding_side = "left"
            base_model = LlavaNextForConditionalGeneration.from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
            )
        elif model_args.model_backbone == "qwen":
            base_model = Qwen2VLForConditionalGeneration.from_pretrained(
                model_args.model_name,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
            )
            base_model.padding_side = "right"
        elif model_args.model_backbone == "phi35v":
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            # config._attn_implementation = "eager"
            config.attn_implementation = "flash_attention_2"
            config.padding_side = "right"
            config.use_cache = False
            base_model = Phi3VForCausalLM.from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
            )
        elif model_args.model_backbone == "internvl_2_5":
            from src.vlm_backbone.intern_vl import InternVLChatConfig, InternVLChatModel
            from transformers import AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained(
                model_args.model_name,
                trust_remote_code=True
            )
            config = InternVLChatConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            # config.vision_config.image_size = data_args.force_image_size  # assuming data_args carries the image size
            config.use_flash_attn = False
            base_model = InternVLChatModel.from_pretrained(
                model_args.model_name,
                config=config,
                tokenizer=tokenizer,
                # attn_implementation="flash_attention_2",
                torch_dtype=torch.bfloat16
            )
            lora_target_modules = base_model.get_lora_target_modules()
        else:
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            config.use_cache = False
            config.padding_side = "right"
            base_model = cls.TRANSFORMER_CLS.from_pretrained(
                model_args.model_name, **hf_kwargs, config=config,
                attn_implementation="flash_attention_2",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True)
            base_model.padding_side = "right"

        if model_args.lora:
            # Wrap the backbone with (Do)LoRA adapters for parameter-efficient training.
            if lora_target_modules is None:
                lora_target_modules = model_args.lora_target_modules.split(',')
            lora_config = LoraConfig(
                r=model_args.lora_r,
                lora_alpha=model_args.lora_alpha,
                target_modules=lora_target_modules,
                lora_dropout=model_args.lora_dropout,
                init_lora_weights="gaussian",
                use_dora=True,
                inference_mode=False
            )
            lora_model = get_peft_model(base_model, lora_config)
            model = cls(
                encoder=lora_model,
                pooling=model_args.pooling,
                normalize=model_args.normalize,
                temperature=model_args.temperature
            )
        else:
            model = cls(
                encoder=base_model,
                pooling=model_args.pooling,
                normalize=model_args.normalize,
                temperature=model_args.temperature
            )
        return model
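    # Usage sketch (argument values are illustrative assumptions, not taken
    # from the original repo; `ModelArguments` fields are those read above):
    #
    #   args = ModelArguments(model_name="microsoft/Phi-3.5-vision-instruct",
    #                         model_backbone="phi35v", pooling="last",
    #                         normalize=True, temperature=0.02)
    #   model = MMEBModel.build(args)   # fresh backbone, optionally LoRA-wrapped
    #   model = MMEBModel.load(args)    # trained checkpoint, LoRA merged (below)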
    @classmethod
    def load(cls, model_args: ModelArguments, **hf_kwargs):
        # Rebuild the backbone, then restore trained weights (merging LoRA if used).
        checkpoint_path = model_args.checkpoint_path if model_args.checkpoint_path else model_args.model_name
        if model_args.model_backbone == "llava_next":
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            config.use_cache = False
            base_model = LlavaNextForConditionalGeneration.from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                # attn_implementation="flash_attention_2"
            )
            base_model.padding_side = "left"
        elif model_args.model_backbone == "phi35v":
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            config.use_cache = False
            config.padding_side = "right"
            base_model = Phi3VForCausalLM.from_pretrained(
                model_args.model_name, **hf_kwargs, config=config,
                attn_implementation="flash_attention_2",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True)
            base_model.padding_side = "right"
        elif model_args.model_backbone == "internvl_2_5":
            print("loading model")
            from src.vlm_backbone.intern_vl import InternVLChatConfig, InternVLChatModel
            from transformers import AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained(
                model_args.model_name,
                trust_remote_code=True
            )
            config = InternVLChatConfig.from_pretrained(model_args.model_name)
            # config.vision_config.image_size = data_args.force_image_size
            config.use_flash_attn = False
            base_model = InternVLChatModel.from_pretrained(
                model_args.model_name,
                config=config,
                tokenizer=tokenizer,
                # attn_implementation="flash_attention_2",
                torch_dtype=torch.bfloat16
            )
        else:
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            config.use_cache = False
            config.padding_side = "right"
            base_model = cls.TRANSFORMER_CLS.from_pretrained(
                checkpoint_path, **hf_kwargs, config=config,
                attn_implementation="flash_attention_2",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True)
            base_model.padding_side = "right"

        # Build the embedding model on top of the base.
        if model_args.lora:
            print("loading lora parameters")
            lora_config = LoraConfig.from_pretrained(checkpoint_path)
            lora_model = PeftModel.from_pretrained(base_model, checkpoint_path, config=lora_config)
            # Fold the adapter weights into the base model for plain inference.
            merged_model = lora_model.merge_and_unload()
            model = cls(
                encoder=merged_model,
                pooling=model_args.pooling,
                normalize=model_args.normalize
            )
        else:
            model = cls(
                encoder=base_model,
                pooling=model_args.pooling,
                normalize=model_args.normalize
            )
        return model

    def save(self, output_dir: str):
        self.encoder.save_pretrained(output_dir)

    def forward(self,
                qry: Optional[Dict[str, Tensor]] = None,
                tgt: Optional[Dict[str, Tensor]] = None,
                neg: Optional[Dict[str, Tensor]] = None):
        qry_reps = self.encode_input(qry) if qry else None  # (bsz_per_device, dim)
        tgt_reps = self.encode_input(tgt) if tgt else None  # (bsz_per_device, dim)
        neg_reps = self.encode_input(neg) if neg else None  # (bsz_per_device, dim)

        # Encode-only mode: if either side is missing, return the representations.
        if qry_reps is None or tgt_reps is None:
            return {"qry_reps": qry_reps, "tgt_reps": tgt_reps}

        # Gather representations across ranks if using DDP, so every device
        # scores against the full batch of in-batch negatives.
        if self.is_ddp:
            all_qry_reps = self._dist_gather_tensor(qry_reps)
            all_tgt_reps = self._dist_gather_tensor(tgt_reps)
            all_neg_reps = self._dist_gather_tensor(neg_reps) if neg_reps is not None else None
        else:
            all_qry_reps = qry_reps
            all_tgt_reps = tgt_reps
            all_neg_reps = neg_reps

        # Compute similarity scores
        scores = self.compute_similarity(all_qry_reps, all_tgt_reps)
        scores = scores.view(all_qry_reps.size(0), -1)

        # Append scores against hard negatives if available
        if all_neg_reps is not None:
            qry_neg_cos = self.compute_similarity(all_qry_reps, all_neg_reps)
            scores = torch.cat([scores, qry_neg_cos], dim=1)

        # InfoNCE loss: query i's positive sits at column i * (targets per query).
        target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
        target = target * (all_tgt_reps.size(0) // all_qry_reps.size(0))
        loss = self.cross_entropy(scores / self.temperature, target)
        if self.is_ddp:
            # Counteract DDP's gradient averaging across ranks.
            loss = loss * self.world_size
        return loss

    def _dist_gather_tensor(self, t: Tensor):
        # all_gather returns detached copies; re-insert this rank's original
        # tensor so gradients still flow through the local shard.
        t = t.contiguous()
        all_tensors = [torch.empty_like(t) for _ in range(self.world_size)]
        dist.all_gather(all_tensors, t)
        all_tensors[self.process_rank] = t
        all_tensors = torch.cat(all_tensors, dim=0)
        return all_tensors

    def compute_similarity(self, q_reps, p_reps):
        return torch.matmul(q_reps, p_reps.transpose(0, 1))
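
# Minimal sketch (not part of the original module): a standalone check of the
# in-batch InfoNCE logic used in `forward`, with random unit vectors standing
# in for encoder outputs. Shapes and the temperature value are illustrative.
if __name__ == "__main__":
    torch.manual_seed(0)
    bsz, dim, temperature = 4, 16, 0.02
    qry_reps = torch.nn.functional.normalize(torch.randn(bsz, dim), dim=-1)
    tgt_reps = torch.nn.functional.normalize(torch.randn(bsz, dim), dim=-1)
    scores = torch.matmul(qry_reps, tgt_reps.transpose(0, 1))  # (bsz, bsz)
    # With one target per query, query i's positive is on the diagonal,
    # so the label for row i is simply i (stride tgt/qry == 1).
    target = torch.arange(bsz, dtype=torch.long)
    loss = nn.CrossEntropyLoss()(scores / temperature, target)
    print(f"demo in-batch contrastive loss: {loss.item():.4f}")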