import os
from typing import Tuple

import datasets
from datasets import load_dataset, concatenate_datasets
from torch.utils.data import Dataset
from PIL import Image
from torchvision.transforms import RandAugment


def get_randaugment_transform(n=2, m=9):
    return RandAugment(num_ops=n, magnitude=m)


def add_prompt_template(data):
    data["qry"] = f"<|image_1|>{data['qry']}"
    data["pos_text"] = f"<|image_1|>{data['pos_text']}"
    data["hard_neg_text"] = f"<|image_1|>{data['hard_neg_text']}"
    return data


Phi_Image_token = "<|image_1|>"
Llava_Image_token = "<image>"
Qwen_Image_token = "<|image_pad|>"
Internvl_Image_token = "<image>"


class TrainDataset(Dataset):
    def __init__(self, data_args, model_args):
        self.data_args = data_args
        self.model_args = model_args
        self.transform = None
        if self.data_args.randaugment:
            self.transform = get_randaugment_transform()

        if data_args.subset_name is not None:
            print(f"Loading {len(data_args.subset_name)} datasets: {data_args.subset_name}")
            train_data = []
            for subset in data_args.subset_name:
                dataset_name = os.path.join(self.data_args.dataset_name, subset)
                subset_data = load_dataset(
                    dataset_name,
                    split=self.data_args.dataset_split,
                )
                train_data.append(subset_data)
            self.train_data = concatenate_datasets(train_data)
            self.train_data = self.train_data.shuffle(seed=42)
        else:
            train_data = load_dataset(
                self.data_args.dataset_name,
                split=self.data_args.dataset_split,
            )
            if "hard_neg" in self.data_args.dataset_name:
                # Optionally prepend the image-token template; disabled by default.
                # train_data = train_data.map(add_prompt_template, num_proc=8)
                print(train_data)
            self.train_data = train_data

        if self.data_args.num_samples:
            self.train_data = self.train_data.select(range(self.data_args.num_samples))
        print(f"len of train_data: {len(self.train_data)}")

    def __len__(self):
        return len(self.train_data)

    def _process_image(self, image, resolution):
        if image is None:
            return None
        if resolution == "high":
            image = image.resize((1344, 1344))
        elif resolution == "low":
            image = image.resize((336, 336))
        elif resolution == "clip":
            image = image.resize((224, 224))
        return image

    def _get_image(self, img_path):
        if img_path == "":
            return None
        if img_path.startswith("/"):
            full_img_path = img_path
        else:
            full_img_path = os.path.join(self.data_args.image_dir, img_path)
        image = Image.open(full_img_path)
        if self.model_args.model_backbone == "llava_next":
            # TODO: make the resolution configurable
            return self._process_image(image, "high")
        elif self.model_args.model_backbone == "qwen":
            return self._process_image(image, "low")
        elif self.model_args.model_backbone == "internvl_2_5":
            # TODO: make the resolution configurable
            return self._process_image(image, "high")
        else:
            return image

    def __getitem__(self, item) -> Tuple:
        data_item = self.train_data[item]
        qry_text, qry_image_path, pos_text, pos_image_path = (
            data_item["qry"],
            data_item["qry_image_path"],
            data_item["pos_text"],
            data_item["pos_image_path"],
        )
        qry_image = self._get_image(qry_image_path)
        if self.transform and qry_image is not None:
            qry_image = self.transform(qry_image)

        # Swap the Phi-style image token for the backbone-specific one.
        if self.model_args.model_backbone == "llava_next":
            qry_text = qry_text.replace(Phi_Image_token, Llava_Image_token)
            pos_text = pos_text.replace(Phi_Image_token, Llava_Image_token)
        elif self.model_args.model_backbone == "qwen":
            qry_text = qry_text.replace(Phi_Image_token, Qwen_Image_token)
            pos_text = pos_text.replace(Phi_Image_token, Qwen_Image_token)
        elif self.model_args.model_backbone == "internvl_2_5":
            qry_text = qry_text.replace(Phi_Image_token, Internvl_Image_token)
            pos_text = pos_text.replace(Phi_Image_token, Internvl_Image_token)

        if "hard_neg" in self.data_args.dataset_name:
            hard_neg_text, hard_neg_image_path = (
                data_item["hard_neg_text"],
                data_item["hard_neg_image_path"],
            )
            if self.model_args.model_backbone == "llava_next":
                hard_neg_text = hard_neg_text.replace(Phi_Image_token, Llava_Image_token)
            elif self.model_args.model_backbone == "qwen":
                hard_neg_text = hard_neg_text.replace(Phi_Image_token, Qwen_Image_token)
            elif self.model_args.model_backbone == "internvl_2_5":
                hard_neg_text = hard_neg_text.replace(Phi_Image_token, Internvl_Image_token)
            return (
                qry_text,
                qry_image,
                pos_text,
                self._get_image(pos_image_path),
                hard_neg_text,
                self._get_image(hard_neg_image_path),
            )
        return (
            qry_text,
            qry_image,
            pos_text,
            self._get_image(pos_image_path),
        )
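# --- Illustrative only: minimal batching helper for TrainDataset. ---
# __getitem__ above returns a 4-tuple (or a 6-tuple when hard negatives are
# present) of raw strings and PIL images, which the default DataLoader collate
# cannot stack. This transpose-style collate is an assumption about how a
# trainer might regroup samples; it is not part of the original pipeline.
def simple_train_collate(batch):
    """Transpose per-sample tuples into per-field lists.

    Handles both the 4-field and 6-field layouts; image slots may hold None
    for text-only samples, so downstream processors must tolerate that.
    """
    return tuple(list(field) for field in zip(*batch))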
class EvalDataset(Dataset):
    def __init__(self, data_args, model_args, subset, text_field, img_path_field):
        """
        (text_field, img_path_field) -> ("qry_text", "qry_img_path") or ("tgt_text", "tgt_img_path")
        """
        self.data_args = data_args
        self.model_args = model_args

        if data_args.subset_name is not None:
            self.eval_data = load_dataset(
                self.data_args.dataset_name,
                subset,
                split=self.data_args.dataset_split,
            )
        else:
            self.eval_data = load_dataset(
                self.data_args.dataset_name,
                split=self.data_args.dataset_split,
            )
        print(f"len of eval_data: {len(self.eval_data)}")
        self.paired_data = self.get_paired_data(text_field, img_path_field)
        self.paired_dataset = datasets.Dataset.from_dict({
            "text": [pair["text"] for pair in self.paired_data],
            "img_path": [pair["img_path"] for pair in self.paired_data],
        })

    def __len__(self):
        return len(self.paired_dataset)

    def __getitem__(self, item):
        text, img_path = self.paired_dataset[item]["text"], self.paired_dataset[item]["img_path"]
        # Swap the Phi-style image token for the backbone-specific one.
        if self.model_args.model_backbone == "llava_next":
            text = text.replace(Phi_Image_token, Llava_Image_token)
        elif self.model_args.model_backbone == "qwen":
            text = text.replace(Phi_Image_token, Qwen_Image_token)
        elif self.model_args.model_backbone == "internvl_2_5":
            text = text.replace(Phi_Image_token, Internvl_Image_token)
        return text, self._get_image(img_path)

    def _process_image(self, image, resolution):
        if image is None:
            return None
        if resolution == "high":
            image = image.resize((1344, 1344))
        else:
            image = image.resize((336, 336))
        return image

    def _get_image(self, img_path):
        if img_path == "":
            return None
        if img_path.startswith("/"):
            full_img_path = img_path
        else:
            full_img_path = os.path.join(self.data_args.image_dir, img_path)
        image = Image.open(full_img_path)
        if self.model_args.model_backbone == "llava_next":
            return self._process_image(image, "high")
        elif self.model_args.model_backbone == "internvl_2_5":
            return self._process_image(image, "high")
        else:
            return image

    def get_paired_data(self, text_field, img_path_field):
        """
        (text_field, img_path_field) -> ("qry_text", "qry_img_path") or ("tgt_text", "tgt_img_path")

        Deduplicate (text, image path) pairs across the eval split, expanding
        rows where either field holds a list.
        """
        unique_pair = set()
        for row in self.eval_data:
            if isinstance(row[text_field], str):
                if row[text_field]:
                    unique_pair.add((row[text_field], row[img_path_field]))
                else:
                    if isinstance(row[img_path_field], list):
                        for img_path in row[img_path_field]:
                            unique_pair.add((row[text_field], img_path))
                    else:
                        unique_pair.add((row[text_field], row[img_path_field]))
            elif isinstance(row[text_field], list):
                assert isinstance(row[img_path_field], list) and len(row[img_path_field]) == len(row[text_field])
                for text, img_path in zip(row[text_field], row[img_path_field]):
                    unique_pair.add((text, img_path))

        paired_data = [{"text": text, "img_path": img_path} for text, img_path in unique_pair]
        return paired_data
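# --- Illustrative only: how query and target sides pair up at eval time. ---
# An evaluator would typically build one EvalDataset per side (the field names
# come from the docstring above) and encode both for retrieval, e.g.:
#
#   qry_dataset = EvalDataset(data_args, model_args, subset, "qry_text", "qry_img_path")
#   tgt_dataset = EvalDataset(data_args, model_args, subset, "tgt_text", "tgt_img_path")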
class FlickrDataset(Dataset):
    def __init__(self, modality, model_backbone):
        self.model_backbone = model_backbone
        self.modality = modality
        self.raw_data = load_dataset("nlphuji/flickr_1k_test_image_text_retrieval", split="test")
        if modality == "image":
            self.eval_data, self.image_names = self.get_image_data()
        else:
            self.eval_data, self.image_names = self.get_text_data()

    def __len__(self):
        return len(self.eval_data)

    def __getitem__(self, idx):
        text, image = self.eval_data[idx]
        if self.model_backbone == "llava_next":
            # Update llava image token
            text = text.replace(Phi_Image_token, Llava_Image_token)
            image = self._process_image(image, "high")
        return text, image

    def _process_image(self, image, resolution):
        if image is None:
            return None
        if resolution == "high":
            image = image.resize((1344, 1344))
        else:
            image = image.resize((336, 336))
        return image

    def get_image_data(self):
        eval_data, image_names = [], []
        # i2t
        inst = "<|image_1|> Find an image caption describing the given image."  # llava-1344-step1k4, i2t=94.0, t2i=80.26
        # inst = "<|image_1|> Represent the given image for image caption retrieval."  # llava-1344-step1k4, i2t=94.6, t2i=78.98
        # t2i
        # inst = "<|image_1|> Represent the given image."  # MSCOCO t2i
        for row in self.raw_data:
            eval_data.append((inst, row["image"]))
            image_names.append(row["filename"])
        return eval_data, image_names

    def get_text_data(self):
        eval_data, image_names = [], []
        # i2t
        inst = ""
        # t2i
        # inst = "Retrieve an image that matches the given caption: "
        # inst = "Find me an everyday image that matches the given caption."  # MSCOCO t2i
        for row in self.raw_data:
            for caption in row["caption"]:
                eval_data.append((inst + caption, None))
                image_names.append(row["filename"])
        return eval_data, image_names
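# --- Illustrative only: smoke-test sketch, not part of the original module. ---
# The SimpleNamespace objects below stand in for the real data/model argument
# dataclasses this file expects; only the attributes read above are set, and
# the dataset id is a placeholder to swap for your own.
if __name__ == "__main__":
    from types import SimpleNamespace
    from torch.utils.data import DataLoader

    data_args = SimpleNamespace(
        dataset_name="your-org/your-train-dataset",  # hypothetical HF dataset id
        subset_name=None,
        dataset_split="train",
        num_samples=16,
        randaugment=False,
        image_dir="images/",
    )
    model_args = SimpleNamespace(model_backbone="qwen")

    dataset = TrainDataset(data_args, model_args)
    loader = DataLoader(dataset, batch_size=4, collate_fn=simple_train_collate)
    # Without "hard_neg" in the dataset name, each batch is a 4-tuple of lists.
    qry_texts, qry_images, pos_texts, pos_images = next(iter(loader))
    print(len(qry_texts), qry_texts[0][:80])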