import logging
from dataclasses import dataclass

import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoProcessor, AutoTokenizer, ProcessorMixin

from src.arguments import DataArguments, ModelArguments

logger = logging.getLogger(__name__)


@dataclass
class TrainCollator:
    data_args: DataArguments
    model_args: ModelArguments
    processor: ProcessorMixin

    def __call__(self, examples):
        """
        :param examples: [(qry_text, qry_image, pos_text, pos_image, ...)] * batch_size;
            hard-negative datasets additionally carry (neg_text, neg_image) at indices 4 and 5.
        """
        qry_inputs = self._get_batch_inputs(examples, 0, 1)
        pos_inputs = self._get_batch_inputs(examples, 2, 3)
        if "hard_neg" in self.data_args.dataset_name:
            hard_neg_inputs = self._get_batch_inputs(examples, 4, 5)
            return qry_inputs, pos_inputs, hard_neg_inputs
        return qry_inputs, pos_inputs

    def _get_batch_inputs(self, examples, text_idx, image_idx):
        input_ids, pixel_values = [], []
        image_mask, image_sizes, image_grid_thw = [], [], []
        for example in examples:
            text, image = example[text_idx], example[image_idx]
            has_image = image is not None
            image_mask.append(1 if has_image else 0)
            # Unified processor call, dispatched on the model backbone.
            if self.model_args.model_backbone == "llava_next":
                inputs = self.processor(
                    text=text,
                    images=image if has_image else None,
                    return_tensors="pt",
                    max_length=self.data_args.max_len,
                    truncation=True,
                )
            elif self.model_args.model_backbone in ["qwen", "qwen2_vl"]:
                inputs = self.processor(
                    text=[text],
                    images=[image] if has_image else None,
                    return_tensors="pt",
                    max_length=self.data_args.max_len,
                    truncation=True,
                )
            else:
                inputs = self.processor(
                    text=text,
                    images=[image] if has_image else None,
                    return_tensors="pt",
                    max_length=self.data_args.max_len,
                    truncation=True,
                )
            if has_image:
                if self.model_args.model_backbone == "qwen":
                    # qwen returns unbatched pixel values; add the batch dim.
                    pixel_values.append(inputs['pixel_values'].unsqueeze(0))
                else:
                    pixel_values.append(inputs['pixel_values'])
            # Store ids as (seq_len, 1) so pad_sequence pads along the sequence dim.
            input_ids.append(inputs["input_ids"].squeeze(0).unsqueeze(1))
            if "image_sizes" in inputs:
                image_sizes.append(inputs['image_sizes'])
            if "image_grid_thw" in inputs:
                image_grid_thw.append(inputs['image_grid_thw'])

        input_ids = pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.processor.tokenizer.pad_token_id,
        ).squeeze(2)
        attention_mask = input_ids.ne(self.processor.tokenizer.pad_token_id)

        inputs = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'image_mask': torch.tensor(image_mask, dtype=torch.float),
        }
        if any(image_mask):
            if pixel_values:
                inputs['pixel_values'] = torch.cat(pixel_values, dim=0)
            if image_sizes:
                inputs['image_sizes'] = torch.cat(image_sizes, dim=0)
            if image_grid_thw:
                inputs['image_grid_thw'] = torch.cat(image_grid_thw, dim=0)
        if self.model_args.model_backbone == "internvl_2_5":
            inputs['image_flags'] = inputs['image_mask'].to(torch.long)
        return inputs


@dataclass
class EvalCollator:
    data_args: DataArguments
    model_args: ModelArguments
    processor: ProcessorMixin

    def __call__(self, examples):
        """
        :param examples: [(text, image)] * batch_size
        """
        return self._get_batch_inputs(examples)

    def _get_batch_inputs(self, examples):
        input_ids, pixel_values, image_sizes = [], [], []
        image_mask = []
        image_exist = False
        for example in examples:
            text, image = example
            has_image = image is not None
            image_mask.append(1 if has_image else 0)
            if self.model_args.model_backbone == "internvl_2_5":
                inputs = self.processor(
                    text=text,
                    images=[image] if has_image else None,
                    return_tensors="pt",
                    max_length=self.data_args.max_len,
                    truncation=True,
                )
                input_ids.append(inputs["input_ids"].squeeze(0).unsqueeze(1))
                if has_image:
                    pixel_values.append(inputs['pixel_values'])
                    if 'image_sizes' in inputs:
                        image_sizes.append(inputs['image_sizes'])
                continue
            if image is None:
                if self.model_args.model_backbone == "llava_next":
                    inputs = self.processor(images=None, text=text, return_tensors="pt")
                else:
                    inputs = self.processor(text, None, return_tensors="pt",
                                            max_length=self.data_args.max_len, truncation=True)
                input_ids.append(inputs["input_ids"].squeeze(0).unsqueeze(1))
                # Placeholders keep image tensors aligned with the batch order.
                pixel_values.append(None)
                image_sizes.append(None)
            else:
                image_exist = True
                if self.model_args.model_backbone == "llava_next":
                    inputs = self.processor(images=image, text=text, return_tensors="pt")
                else:
                    inputs = self.processor(text, [image], return_tensors="pt",
                                            max_length=self.data_args.max_len, truncation=True)
                input_ids.append(inputs["input_ids"].squeeze(0).unsqueeze(1))
                pixel_values.append(inputs['pixel_values'])
                image_sizes.append(inputs['image_sizes'])

        input_ids = pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.processor.tokenizer.pad_token_id,
        ).squeeze(2)
        attention_mask = input_ids.ne(self.processor.tokenizer.pad_token_id)

        if self.model_args.model_backbone == "internvl_2_5":
            inputs = {
                'input_ids': input_ids,
                'attention_mask': attention_mask,
                'image_mask': torch.tensor(image_mask, dtype=torch.float),
            }
            if any(image_mask):
                if pixel_values:
                    inputs['pixel_values'] = torch.cat(pixel_values, dim=0)
                if image_sizes:
                    inputs['image_sizes'] = torch.cat(image_sizes, dim=0)
            inputs['image_flags'] = inputs['image_mask'].to(torch.long)
            del inputs['image_mask']
        else:
            if not image_exist:
                # No image anywhere in the batch: emit dummy tensors so
                # downstream code can rely on the keys being present.
                dummy_pixel_values = torch.zeros(input_ids.shape[0], 1)
                dummy_image_sizes = torch.ones(input_ids.shape[0], 1)
                inputs = {
                    'input_ids': input_ids,
                    'attention_mask': attention_mask,
                    'pixel_values': dummy_pixel_values,
                    'image_sizes': dummy_image_sizes,
                }
            else:
                # Fill text-only rows with zeros/ones shaped like the real
                # image features so the per-example tensors can be concatenated.
                pixel_values_shape = list(set(v.shape for v in pixel_values if v is not None))[0]
                pixel_values = [v if v is not None else torch.zeros(pixel_values_shape)
                                for v in pixel_values]
                pixel_values = torch.cat(pixel_values, dim=0)
                image_sizes_shape = list(set(v.shape for v in image_sizes if v is not None))[0]
                image_sizes = [v if v is not None else torch.ones(image_sizes_shape)
                               for v in image_sizes]
                image_sizes = torch.cat(image_sizes, dim=0)
                inputs = {
                    'input_ids': input_ids,
                    'attention_mask': attention_mask,
                    'pixel_values': pixel_values,
                    'image_sizes': image_sizes,
                }
        return inputs
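# Usage sketch (illustrative only, not part of the original module): the
# collators above are meant to be handed to a torch DataLoader as its
# ``collate_fn``. Field names such as ``model_args.model_name`` and
# ``data_args.batch_size`` are assumptions about src.arguments, and the
# dataset is assumed to yield (qry_text, qry_image, pos_text, pos_image)
# tuples:
#
#     processor = AutoProcessor.from_pretrained(model_args.model_name)
#     collator = TrainCollator(data_args=data_args,
#                              model_args=model_args,
#                              processor=processor)
#     loader = torch.utils.data.DataLoader(train_dataset,
#                                          batch_size=data_args.batch_size,
#                                          collate_fn=collator)
#     qry_inputs, pos_inputs = next(iter(loader))  # plus hard negatives, if any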
@dataclass
class CLIPCollator:
    data_args: DataArguments
    vis_processors: AutoProcessor
    txt_processors: AutoTokenizer

    def __call__(self, examples):
        """
        :param examples: [(text, image)] * batch_size
        """
        return self._get_batch_inputs(examples)

    def _get_batch_inputs(self, examples):
        input_ids, pixel_values, attention_mask = [], [], []
        image_exist, text_exist = False, False
        for example in examples:
            text, image = example
            if image is not None:
                # CLIP expects three channels; convert grayscale images.
                if image.mode == 'L':
                    image = image.convert('RGB')
                image_inputs = self.vis_processors(images=image, return_tensors="pt")
                image_exist = True
                pixel_values.append(image_inputs['pixel_values'])
            if text:
                text_exist = True
                text_inputs = self.txt_processors(
                    text,
                    padding=getattr(self.data_args, "padding", True),
                    max_length=self.data_args.max_len,
                    truncation=True,
                    return_tensors="pt",
                )
                input_ids.append(text_inputs["input_ids"].squeeze(0))
        if text_exist:
            input_ids = pad_sequence(
                input_ids,
                batch_first=True,
                padding_value=self.txt_processors.pad_token_id,
            )
            attention_mask = input_ids.ne(self.txt_processors.pad_token_id)
        if image_exist:
            pixel_values = torch.cat(pixel_values, dim=0)
        if text_exist and image_exist:
            assert input_ids.size(0) == pixel_values.size(0)
        inputs = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'pixel_values': pixel_values,
        }
        return inputs
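# Usage sketch (illustrative only): CLIPCollator pairs a Hugging Face image
# processor with a tokenizer. The checkpoint name below is an assumption; any
# CLIP checkpoint exposing the same processor interface should work:
#
#     vis = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
#     txt = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
#     collator = CLIPCollator(data_args=data_args,
#                             vis_processors=vis, txt_processors=txt)
#     batch = collator([("a photo of a cat", cat_image),
#                       ("a photo of a dog", dog_image)])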
@dataclass
class OpenCLIPCollator:
    data_args: DataArguments
    vis_processors: AutoProcessor
    txt_processors: AutoTokenizer

    def __call__(self, examples):
        """
        :param examples: [(text, image)] * batch_size
        """
        return self._get_batch_inputs(examples)

    def _get_batch_inputs(self, examples):
        input_ids, pixel_values, attention_mask = [], [], []
        image_exist, text_exist = False, False
        for example in examples:
            text, image = example
            if image is not None:
                if image.mode == 'L':
                    image = image.convert('RGB')
                # open_clip preprocessing is a torchvision transform returning
                # a single (C, H, W) tensor; add the batch dim here.
                image_inputs = self.vis_processors(image).unsqueeze(0)
                image_exist = True
                pixel_values.append(image_inputs)
            if text:
                text_exist = True
                input_ids.append(self.txt_processors(text))
        if text_exist:
            input_ids = torch.cat(input_ids, dim=0)
            # open_clip tokenizers pad with token id 0.
            attention_mask = input_ids.ne(0)
        if image_exist:
            pixel_values = torch.cat(pixel_values, dim=0)
        if text_exist and image_exist:
            assert input_ids.size(0) == pixel_values.size(0)
        inputs = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'pixel_values': pixel_values,
        }
        return inputs
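# Minimal smoke test for OpenCLIPCollator, guarded so it only runs when this
# module is executed directly. It assumes the ``open_clip_torch`` package and
# Pillow are installed; the model name is illustrative (no pretrained weights
# are needed, since only the transforms and tokenizer are used), and data_args
# is passed as None because OpenCLIPCollator never reads it.
if __name__ == "__main__":
    import open_clip
    from PIL import Image

    _, _, preprocess = open_clip.create_model_and_transforms("ViT-B-32")
    tokenizer = open_clip.get_tokenizer("ViT-B-32")
    collator = OpenCLIPCollator(
        data_args=None,  # unused by this collator; placeholder only
        vis_processors=preprocess,
        txt_processors=tokenizer,
    )
    batch = collator([
        ("a photo of a cat", Image.new("RGB", (224, 224))),
        ("a photo of a dog", Image.new("RGB", (224, 224))),
    ])
    print({k: tuple(v.shape) for k, v in batch.items()})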