liubangwei committed on
Commit
1855cc2
1 Parent(s): a72a7d4

init IDMR demo

app.py CHANGED
@@ -8,27 +8,25 @@ from transformers import AutoProcessor
8
  from src.model import MMEBModel
9
  from src.arguments import ModelArguments
10
 
11
- # Assume the image library is stored in a local folder
12
  QUERY_DIR = "imgs/queries"
13
  IMAGE_DIR = "imgs/candidates"
14
- # IMAGE_DIR = "imgs"
15
  image_paths = [os.path.join(IMAGE_DIR, f) for f in os.listdir(IMAGE_DIR) if f.endswith((".jpg", ".png"))]
16
  global IMAGE_TOKEN, TOP_N
17
  IMAGE_TOKEN = "<|image_1|>"
18
  TOP_N = 5
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
  print(f"device: {device}")
21
- # Model loading and initialization
 
22
  def load_model():
23
  global IMAGE_TOKEN
24
- # Model arguments
25
  model_args = ModelArguments(
26
- # model_name="/fs-computility/ai-shen/kilab-shared/liubangwei/ckpt/IDMR/IDMR_InternVL2_5-2B", # replace with your model name
27
- model_name="/fs-computility/ai-shen/kilab-shared/liubangwei/ckpt/my_hf/IDMR-2B",
28
- model_backbone="internvl_2_5", # replace with your model backbone
29
  )
30
 
31
- # Load the processor
32
  if model_args.model_backbone == "phi35v":
33
  processor = AutoProcessor.from_pretrained(
34
  model_args.model_name,
@@ -54,14 +52,12 @@ def load_model():
54
  )
55
  IMAGE_TOKEN = "<image>"
56
 
57
- # Load the model
58
  model = MMEBModel.load(model_args)
59
  model = model.to(device, dtype=torch.bfloat16)
60
  model.eval()
61
 
62
  return model, processor
63
 
64
- # Load the model and processor
65
  model, processor = load_model()
66
 
67
  def get_inputs(processor, text, image_path=None, image=None):
@@ -84,8 +80,6 @@ def get_inputs(processor, text, image_path=None, image=None):
84
  del inputs['pixel_values']
85
  return inputs
86
 
87
-
88
- # Encode the images in the library into embeddings
89
  def encode_image_library(image_paths):
90
  embeddings = []
91
  for img_path in image_paths:
@@ -97,22 +91,18 @@ def encode_image_library(image_paths):
97
  embeddings.append(output["tgt_reps"].float().cpu().numpy())
98
  return np.stack(embeddings)
99
 
100
- # Save embeddings to a file
101
  def save_embeddings(embeddings, file_path="image_embeddings.pkl"):
102
  with open(file_path, "wb") as f:
103
  pickle.dump(embeddings, f)
104
 
105
- # Load embeddings from a file
106
  def load_embeddings(file_path="image_embeddings.pkl"):
107
  with open(file_path, "rb") as f:
108
  return pickle.load(f)
109
 
110
- # Compute similarity (cosine similarity)
111
  def cosine_similarity(query_embedding, embeddings):
112
  similarity = np.sum(query_embedding * embeddings, axis=-1)
113
  return similarity
114
 
115
- # Retrieval logic
116
  def retrieve_images(query_text, query_image, top_n=TOP_N):
117
  if query_text:
118
  query_text = f"{IMAGE_TOKEN}\n {query_text}"
@@ -129,11 +119,8 @@ def retrieve_images(query_text, query_image, top_n=TOP_N):
129
  with torch.no_grad(), torch.autocast(device_type=device, dtype=torch.bfloat16):
130
  query_embedding = model(qry=inputs)["qry_reps"].float().cpu().numpy()
131
 
132
-
133
- # Load the image library embeddings
134
  embeddings = load_embeddings()
135
 
136
- # Compute similarity
137
  similarity = cosine_similarity(query_embedding, embeddings)
138
  similarity = similarity.T
139
  print(f"cosine_similarity: {similarity}")
@@ -145,29 +132,22 @@ def retrieve_images(query_text, query_image, top_n=TOP_N):
145
 
146
  return [image_paths[i] for i in top_indices]
147
 
148
- # Interface logic
149
  def demo(query_text, query_image):
150
- # Run the retrieval
151
  # print(f"query_text: {query_text}, query_image: {query_image}, type(query_image): {type(query_image)}, image shape: {query_image.shape if query_image is not None else 'None'}")
152
 
153
  retrieved_images = retrieve_images(query_text, query_image)
154
- # Return the retrieval results (a list of images)
155
  return [Image.open(img) for img in retrieved_images]
156
 
157
- # Preset examples
158
  def load_examples():
159
  examples = []
160
- # Collect all image files in QUERY_DIR
161
  image_files = [f for f in os.listdir(QUERY_DIR) if f.endswith((".jpg", ".png"))]
162
 
163
  for img_file in image_files:
164
- # Build the full image path
165
  img_path = os.path.join(QUERY_DIR, img_file)
166
- # Derive the matching txt filename (replace the image extension with .txt)
167
  txt_file = os.path.splitext(img_file)[0] + ".txt"
168
  txt_path = os.path.join(QUERY_DIR, txt_file)
169
 
170
- # If a matching txt file exists, read the query text
171
  if os.path.exists(txt_path):
172
  with open(txt_path, 'r', encoding='utf-8') as f:
173
  query_text = f.read().strip().replace("<|image_1|>\n", "")
@@ -175,20 +155,17 @@ def load_examples():
175
 
176
  return examples
177
 
178
- # Build the Gradio interface
179
  iface = gr.Interface(
180
  fn=demo,
181
  inputs=["text", "image"],
182
  outputs=gr.Gallery(label=f"Retrieved Images (Top {TOP_N})"),
183
- examples=load_examples(), # use dynamically loaded examples
184
  title="Multimodal Retrieval Demo",
185
  description="Enter a query and upload an image to retrieve relevant images from the library. You can click on the example below to use it as a query"
186
  )
187
 
188
- # Encode the image library at startup and save the embeddings
189
  if not os.path.exists("image_embeddings.pkl"):
190
  embeddings = encode_image_library(image_paths)
191
  save_embeddings(embeddings)
192
 
193
- # Launch the Gradio app
194
  iface.launch()
 
8
  from src.model import MMEBModel
9
  from src.arguments import ModelArguments
10
 
 
11
  QUERY_DIR = "imgs/queries"
12
  IMAGE_DIR = "imgs/candidates"
 
13
  image_paths = [os.path.join(IMAGE_DIR, f) for f in os.listdir(IMAGE_DIR) if f.endswith((".jpg", ".png"))]
14
  global IMAGE_TOKEN, TOP_N
15
  IMAGE_TOKEN = "<|image_1|>"
16
  TOP_N = 5
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
  print(f"device: {device}")
19
+
20
+
21
  def load_model():
22
  global IMAGE_TOKEN
23
+
24
  model_args = ModelArguments(
25
+ # model_name="/fs-computility/ai-shen/kilab-shared/liubangwei/ckpt/my_hf/IDMR-2B",
26
+ model_name="lbw18601752667/IDMR-2B",
27
+ model_backbone="internvl_2_5",
28
  )
29
 
 
30
  if model_args.model_backbone == "phi35v":
31
  processor = AutoProcessor.from_pretrained(
32
  model_args.model_name,
 
52
  )
53
  IMAGE_TOKEN = "<image>"
54
 
 
55
  model = MMEBModel.load(model_args)
56
  model = model.to(device, dtype=torch.bfloat16)
57
  model.eval()
58
 
59
  return model, processor
60
 
 
61
  model, processor = load_model()
62
 
63
  def get_inputs(processor, text, image_path=None, image=None):
 
80
  del inputs['pixel_values']
81
  return inputs
82
 
 
 
83
  def encode_image_library(image_paths):
84
  embeddings = []
85
  for img_path in image_paths:
 
91
  embeddings.append(output["tgt_reps"].float().cpu().numpy())
92
  return np.stack(embeddings)
93
 
 
94
  def save_embeddings(embeddings, file_path="image_embeddings.pkl"):
95
  with open(file_path, "wb") as f:
96
  pickle.dump(embeddings, f)
97
 
 
98
  def load_embeddings(file_path="image_embeddings.pkl"):
99
  with open(file_path, "rb") as f:
100
  return pickle.load(f)
101
 
 
102
  def cosine_similarity(query_embedding, embeddings):
103
  similarity = np.sum(query_embedding * embeddings, axis=-1)
104
  return similarity
105
 
 
106
  def retrieve_images(query_text, query_image, top_n=TOP_N):
107
  if query_text:
108
  query_text = f"{IMAGE_TOKEN}\n {query_text}"
 
119
  with torch.no_grad(), torch.autocast(device_type=device, dtype=torch.bfloat16):
120
  query_embedding = model(qry=inputs)["qry_reps"].float().cpu().numpy()
121
 
 
 
122
  embeddings = load_embeddings()
123
 
 
124
  similarity = cosine_similarity(query_embedding, embeddings)
125
  similarity = similarity.T
126
  print(f"cosine_similarity: {similarity}")
 
132
 
133
  return [image_paths[i] for i in top_indices]
134
 
 
135
  def demo(query_text, query_image):
 
136
  # print(f"query_text: {query_text}, query_image: {query_image}, type(query_image): {type(query_image)}, image shape: {query_image.shape if query_image is not None else 'None'}")
137
 
138
  retrieved_images = retrieve_images(query_text, query_image)
 
139
  return [Image.open(img) for img in retrieved_images]
140
 
 
141
  def load_examples():
142
  examples = []
143
+
144
  image_files = [f for f in os.listdir(QUERY_DIR) if f.endswith((".jpg", ".png"))]
145
 
146
  for img_file in image_files:
 
147
  img_path = os.path.join(QUERY_DIR, img_file)
 
148
  txt_file = os.path.splitext(img_file)[0] + ".txt"
149
  txt_path = os.path.join(QUERY_DIR, txt_file)
150
 
 
151
  if os.path.exists(txt_path):
152
  with open(txt_path, 'r', encoding='utf-8') as f:
153
  query_text = f.read().strip().replace("<|image_1|>\n", "")
 
155
 
156
  return examples
157
 
 
158
  iface = gr.Interface(
159
  fn=demo,
160
  inputs=["text", "image"],
161
  outputs=gr.Gallery(label=f"Retrieved Images (Top {TOP_N})"),
162
+ examples=load_examples(),
163
  title="Multimodal Retrieval Demo",
164
  description="Enter a query and upload an image to retrieve relevant images from the library. You can click on the example below to use it as a query"
165
  )
166
 
 
167
  if not os.path.exists("image_embeddings.pkl"):
168
  embeddings = encode_image_library(image_paths)
169
  save_embeddings(embeddings)
170
 
 
171
  iface.launch()
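
Note on the retrieval path above: `retrieve_images` embeds the query, loads the cached library embeddings, and ranks candidates by the dot product computed in `cosine_similarity`. Below is a minimal, self-contained sketch of that ranking step, assuming the embeddings are already L2-normalized so the dot product equals cosine similarity; the function name and values are illustrative, not from the repo.

```python
import numpy as np

def rank_candidates(query_emb: np.ndarray, library_embs: np.ndarray, top_n: int = 5):
    """Return the indices of the top_n most similar library embeddings.

    query_emb:    shape (1, dim)
    library_embs: shape (num_images, dim)
    Both are assumed L2-normalized, so a dot product equals cosine similarity.
    """
    sims = library_embs @ query_emb.squeeze(0)   # (num_images,)
    return np.argsort(sims)[::-1][:top_n]        # most similar first

# Toy example: a random query against a random 10-image library
rng = np.random.default_rng(0)
lib = rng.normal(size=(10, 4)); lib /= np.linalg.norm(lib, axis=1, keepdims=True)
qry = rng.normal(size=(1, 4));  qry /= np.linalg.norm(qry, axis=1, keepdims=True)
print(rank_candidates(qry, lib, top_n=3))
```

In the app itself the library side is cached to `image_embeddings.pkl` on first launch, so only the query embedding is computed per request.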
src/collator.py CHANGED
@@ -19,8 +19,7 @@ class TrainCollator:
19
  """
20
  :param examples: [{qry:..., qry_image:..., pos_text:..., pos_image:...}] * batch_size
21
  """
22
- # import pdb; pdb.set_trace()
23
- qry_inputs = self._get_batch_inputs(examples, 0, 1) # qry_inputs: {'input_ids': tensor(batch_size, max_len), 'attention_mask': tensor(batch_size, max_len), 'pixel_values': tensor(batch_size, 4, 224, 224), 'image_sizes': tensor(batch_size, 2)}
24
  pos_inputs = self._get_batch_inputs(examples, 2, 3)
25
  if "hard_neg" in self.data_args.dataset_name:
26
  hard_neg_inputs = self._get_batch_inputs(examples, 4, 5)
@@ -45,15 +44,15 @@ class TrainCollator:
45
  max_length=self.data_args.max_len,
46
  truncation=True
47
  )
48
- elif self.model_args.model_backbone in ["qwen", "qwen2_vl"]: # Qwen family
49
  inputs = self.processor(
50
- text=[text], # Qwen expects a list input
51
  images=[image] if has_image else None,
52
  return_tensors="pt",
53
  max_length=self.data_args.max_len,
54
  truncation=True
55
  )
56
- else: # shared handling for Phi3/InternVL
57
  inputs = self.processor(
58
  text=text,
59
  images=[image] if has_image else None,
@@ -62,23 +61,19 @@ class TrainCollator:
62
  truncation=True
63
  )
64
 
65
- # Normalize the input format
66
  if has_image:
67
  if self.model_args.model_backbone == "qwen":
68
  pixel_values.append(inputs['pixel_values'].unsqueeze(0))
69
  else:
70
  pixel_values.append(inputs['pixel_values'])
71
 
72
- # Keep dimensions aligned with the original logic
73
  input_ids.append(inputs["input_ids"].squeeze(0).unsqueeze(1))
74
 
75
- # Handle multimodal metadata
76
  if "image_sizes" in inputs:
77
  image_sizes.append(inputs['image_sizes'])
78
  if "image_grid_thw" in inputs:
79
  image_grid_thw.append(inputs['image_grid_thw'])
80
 
81
- # Keep the original padding logic
82
  input_ids = torch._C._nn.pad_sequence(
83
  input_ids,
84
  batch_first=True,
@@ -87,89 +82,24 @@ class TrainCollator:
87
 
88
  attention_mask = input_ids.ne(self.processor.tokenizer.pad_token_id)
89
 
90
- # Build the return dict
91
  inputs = {
92
  'input_ids': input_ids,
93
  'attention_mask': attention_mask,
94
- 'image_mask': torch.tensor(image_mask, dtype=torch.float) # keep the original field name
95
  }
96
 
97
- # Handle image data
98
  if any(image_mask):
99
  if pixel_values:
100
  inputs['pixel_values'] = torch.cat(pixel_values, dim=0)
101
- if image_sizes: # LLaMA family only
102
  inputs['image_sizes'] = torch.cat(image_sizes, dim=0)
103
- if image_grid_thw: # Phi3 only
104
  inputs['image_grid_thw'] = torch.cat(image_grid_thw, dim=0)
105
 
106
- # InternVL-specific field adaptation
107
  if self.model_args.model_backbone == "internvl_2_5":
108
- inputs['image_flags'] = inputs['image_mask'].to(torch.long) # the model expects long dtype
109
- # del inputs['image_mask'] # rename the field to match the model interface
110
 
111
  return inputs
112
- """
113
- def _get_batch_inputs(self, examples, text_idx, image_idx):
114
- input_ids, pixel_values, image_sizes, image_grid_thw = [], [], [], []
115
- image_mask = []
116
- image_exist = False
117
- for example in examples:
118
- text, image = example[text_idx], example[image_idx] # text: str, image: PIL.Image.Image(765*512)
119
- if image is None:
120
- image_mask.append(0)
121
- if self.model_args.model_backbone == "llava_next":
122
- inputs = self.processor(images=None, text=text, return_tensors="pt")
123
- elif self.model_args.model_backbone == "qwen":
124
- inputs = self.processor(text=[text], images=None, return_tensors="pt",
125
- max_length=self.data_args.max_len, truncation=True)
126
- else: # 'phi', 'internvl'
127
- inputs = self.processor(text=text, images=None, return_tensors="pt",
128
- max_length=self.data_args.max_len, truncation=True)
129
- input_ids.append(inputs["input_ids"].squeeze(0).unsqueeze(1))
130
- else:
131
- image_mask.append(1)
132
- image_exist = True
133
- if self.model_args.model_backbone == "llava_next":
134
- inputs = self.processor(images=image, text=text, return_tensors="pt")
135
- pixel_values.append(inputs['pixel_values'])
136
- elif self.model_args.model_backbone == "qwen":
137
- inputs = self.processor(text=[text], images=[image], return_tensors="pt",
138
- max_length=self.data_args.max_len, truncation=True)
139
- pixel_values.append(inputs['pixel_values'].unsqueeze(0))
140
- else:
141
- inputs = self.processor(text=text, images=[image], return_tensors="pt",
142
- max_length=self.data_args.max_len, truncation=True)
143
- pixel_values.append(inputs['pixel_values'])
144
- input_ids.append(inputs["input_ids"].squeeze(0).unsqueeze(1))
145
- if "image_sizes" in inputs:
146
- image_sizes.append(inputs['image_sizes'])
147
- if "image_grid_thw" in inputs:
148
- image_grid_thw.append(inputs['image_grid_thw'])
149
-
150
- input_ids = torch._C._nn.pad_sequence(
151
- input_ids, batch_first=True, padding_value=self.processor.tokenizer.pad_token_id
152
- ).squeeze(2)
153
- attention_mask = input_ids.ne(self.processor.tokenizer.pad_token_id)
154
-
155
- inputs = {
156
- 'input_ids': input_ids,
157
- 'attention_mask': attention_mask,
158
- }
159
- if image_exist:
160
- inputs['image_mask'] = torch.Tensor(image_mask)
161
- pixel_values = torch.cat(pixel_values, dim=0)
162
- inputs['pixel_values'] = pixel_values
163
- if image_sizes:
164
- image_sizes = torch.cat(image_sizes, dim=0)
165
- inputs['image_sizes'] = image_sizes
166
- elif image_grid_thw:
167
- image_grid_thw = torch.cat(image_grid_thw, dim=0)
168
- inputs['image_grid_thw'] = image_grid_thw
169
-
170
- return inputs
171
- """
172
-
173
 
174
  @dataclass
175
  class EvalCollator:
@@ -183,72 +113,17 @@ class EvalCollator:
183
  """
184
  inputs = self._get_batch_inputs(examples)
185
  return inputs
186
- """
187
- def _get_batch_inputs(self, examples):
188
- input_ids, pixel_values, image_sizes = [], [], []
189
- image_exist = False
190
- for example in examples:
191
- text, image = example
192
- if image is None:
193
- if self.model_args.model_backbone == "llava_next":
194
- inputs = self.processor(images=None, text=text, return_tensors="pt")
195
- else:
196
- inputs = self.processor(text, None, return_tensors="pt", max_length=self.data_args.max_len,
197
- truncation=True)
198
- input_ids.append(inputs["input_ids"].squeeze(0).unsqueeze(1))
199
- pixel_values.append(None)
200
- image_sizes.append(None)
201
- else:
202
- image_exist = True
203
- if self.model_args.model_backbone == "llava_next":
204
- inputs = self.processor(images=image, text=text, return_tensors="pt")
205
- else:
206
- inputs = self.processor(text, [image], return_tensors="pt", max_length=self.data_args.max_len, truncation=True)
207
- input_ids.append(inputs["input_ids"].squeeze(0).unsqueeze(1))
208
- pixel_values.append(inputs['pixel_values'])
209
- image_sizes.append(inputs['image_sizes'])
210
 
211
- input_ids = torch._C._nn.pad_sequence(
212
- input_ids, batch_first=True, padding_value=self.processor.tokenizer.pad_token_id
213
- ).squeeze(2)
214
- attention_mask = input_ids.ne(self.processor.tokenizer.pad_token_id)
215
-
216
- if not image_exist:
217
- dummy_pixel_values = torch.zeros(input_ids.shape[0], 1)
218
- dummy_image_sizes = torch.ones(input_ids.shape[0], 1)
219
- inputs = {
220
- 'input_ids': input_ids,
221
- 'attention_mask': attention_mask,
222
- 'pixel_values': dummy_pixel_values,
223
- 'image_sizes': dummy_image_sizes,
224
- }
225
- else:
226
- pixel_values_shape = list(set(v.shape for v in pixel_values if v is not None))[0]
227
- pixel_values = [v if v is not None else torch.zeros(pixel_values_shape) for v in pixel_values]
228
- pixel_values = torch.cat(pixel_values, dim=0)
229
- image_sizes_shape = list(set(v.shape for v in image_sizes if v is not None))[0]
230
- image_sizes = [v if v is not None else torch.ones(image_sizes_shape) for v in image_sizes]
231
- image_sizes = torch.cat(image_sizes, dim=0)
232
- inputs = {
233
- 'input_ids': input_ids,
234
- 'attention_mask': attention_mask,
235
- 'pixel_values': pixel_values,
236
- 'image_sizes': image_sizes,
237
- }
238
-
239
- return inputs
240
- """
241
  def _get_batch_inputs(self, examples):
242
  input_ids, pixel_values, image_sizes = [], [], []
243
- image_mask = [] # added for internvl2_5
244
  image_exist = False
245
  for example in examples:
246
  text, image = example
247
- # print(text, image)
248
  has_image = image is not None
249
  image_mask.append(1 if has_image else 0)
250
 
251
- if self.model_args.model_backbone == "internvl_2_5": # shared handling for Phi3/InternVL
252
  inputs = self.processor(
253
  text=text,
254
  images=[image] if has_image else None,
@@ -289,22 +164,19 @@ class EvalCollator:
289
  attention_mask = input_ids.ne(self.processor.tokenizer.pad_token_id)
290
 
291
  if self.model_args.model_backbone == "internvl_2_5":
292
- # Build the return dict
293
  inputs = {
294
  'input_ids': input_ids,
295
  'attention_mask': attention_mask,
296
  'image_mask': torch.tensor(image_mask, dtype=torch.float)
297
  }
298
 
299
- # Handle image data
300
  if any(image_mask):
301
  if pixel_values:
302
  inputs['pixel_values'] = torch.cat(pixel_values, dim=0)
303
  if image_sizes:
304
  inputs['image_sizes'] = torch.cat(image_sizes, dim=0)
305
- # InternVL-specific field adaptation
306
  inputs['image_flags'] = inputs['image_mask'].to(torch.long)
307
- del inputs['image_mask'] # rename the field to match the model interface
308
  else:
309
  if not image_exist:
310
  dummy_pixel_values = torch.zeros(input_ids.shape[0], 1)
 
19
  """
20
  :param examples: [{qry:..., qry_image:..., pos_text:..., pos_image:...}] * batch_size
21
  """
22
+ qry_inputs = self._get_batch_inputs(examples, 0, 1)
 
23
  pos_inputs = self._get_batch_inputs(examples, 2, 3)
24
  if "hard_neg" in self.data_args.dataset_name:
25
  hard_neg_inputs = self._get_batch_inputs(examples, 4, 5)
 
44
  max_length=self.data_args.max_len,
45
  truncation=True
46
  )
47
+ elif self.model_args.model_backbone in ["qwen", "qwen2_vl"]:
48
  inputs = self.processor(
49
+ text=[text],
50
  images=[image] if has_image else None,
51
  return_tensors="pt",
52
  max_length=self.data_args.max_len,
53
  truncation=True
54
  )
55
+ else:
56
  inputs = self.processor(
57
  text=text,
58
  images=[image] if has_image else None,
 
61
  truncation=True
62
  )
63
 
 
64
  if has_image:
65
  if self.model_args.model_backbone == "qwen":
66
  pixel_values.append(inputs['pixel_values'].unsqueeze(0))
67
  else:
68
  pixel_values.append(inputs['pixel_values'])
69
 
 
70
  input_ids.append(inputs["input_ids"].squeeze(0).unsqueeze(1))
71
 
 
72
  if "image_sizes" in inputs:
73
  image_sizes.append(inputs['image_sizes'])
74
  if "image_grid_thw" in inputs:
75
  image_grid_thw.append(inputs['image_grid_thw'])
76
 
 
77
  input_ids = torch._C._nn.pad_sequence(
78
  input_ids,
79
  batch_first=True,
 
82
 
83
  attention_mask = input_ids.ne(self.processor.tokenizer.pad_token_id)
84
 
 
85
  inputs = {
86
  'input_ids': input_ids,
87
  'attention_mask': attention_mask,
88
+ 'image_mask': torch.tensor(image_mask, dtype=torch.float)
89
  }
90
 
 
91
  if any(image_mask):
92
  if pixel_values:
93
  inputs['pixel_values'] = torch.cat(pixel_values, dim=0)
94
+ if image_sizes:
95
  inputs['image_sizes'] = torch.cat(image_sizes, dim=0)
96
+ if image_grid_thw:
97
  inputs['image_grid_thw'] = torch.cat(image_grid_thw, dim=0)
98
 
 
99
  if self.model_args.model_backbone == "internvl_2_5":
100
+ inputs['image_flags'] = inputs['image_mask'].to(torch.long)
 
101
 
102
  return inputs
103
 
104
  @dataclass
105
  class EvalCollator:
 
113
  """
114
  inputs = self._get_batch_inputs(examples)
115
  return inputs
116
117
  def _get_batch_inputs(self, examples):
118
  input_ids, pixel_values, image_sizes = [], [], []
119
+ image_mask = []
120
  image_exist = False
121
  for example in examples:
122
  text, image = example
 
123
  has_image = image is not None
124
  image_mask.append(1 if has_image else 0)
125
 
126
+ if self.model_args.model_backbone == "internvl_2_5":
127
  inputs = self.processor(
128
  text=text,
129
  images=[image] if has_image else None,
 
164
  attention_mask = input_ids.ne(self.processor.tokenizer.pad_token_id)
165
 
166
  if self.model_args.model_backbone == "internvl_2_5":
 
167
  inputs = {
168
  'input_ids': input_ids,
169
  'attention_mask': attention_mask,
170
  'image_mask': torch.tensor(image_mask, dtype=torch.float)
171
  }
172
 
 
173
  if any(image_mask):
174
  if pixel_values:
175
  inputs['pixel_values'] = torch.cat(pixel_values, dim=0)
176
  if image_sizes:
177
  inputs['image_sizes'] = torch.cat(image_sizes, dim=0)
 
178
  inputs['image_flags'] = inputs['image_mask'].to(torch.long)
179
+ del inputs['image_mask']
180
  else:
181
  if not image_exist:
182
  dummy_pixel_values = torch.zeros(input_ids.shape[0], 1)
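
The collators above pad variable-length `input_ids`, derive `attention_mask` from the pad token, and expose a per-sample `image_mask`/`image_flags` for InternVL. A small sketch of that padding-and-masking pattern using the public `torch.nn.utils.rnn.pad_sequence` (the repo calls the private `torch._C._nn.pad_sequence`; token ids and values here are made up):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

pad_token_id = 0
# Two tokenized samples of different lengths; only the second carries an image.
input_ids = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
image_mask = [0, 1]

batch_ids = pad_sequence(input_ids, batch_first=True, padding_value=pad_token_id)
attention_mask = batch_ids.ne(pad_token_id)               # True wherever a real token sits
image_flags = torch.tensor(image_mask, dtype=torch.long)  # what InternVL's forward consumes

print(batch_ids)       # tensor([[5, 6, 7], [8, 9, 0]])
print(attention_mask)  # tensor([[ True,  True,  True], [ True,  True, False]])
print(image_flags)     # tensor([0, 1])
```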
src/dataset.py CHANGED
@@ -8,18 +8,8 @@ from PIL import Image
8
  import os
9
  from torchvision.transforms import RandAugment
10
 
11
- # Define RandAugment, used only for augmentation
12
  def get_randaugment_transform(n=2, m=9):
13
- """
14
- 创建 RandAugment 增强器。
15
-
16
- 参数:
17
- - n: 每次随机选择的增强操作数量。
18
- - m: 每种增强操作的强度。
19
-
20
- 返回:
21
- - RandAugment 对象。
22
- """
23
  return RandAugment(num_ops=n, magnitude=m)
24
 
25
 
@@ -39,7 +29,7 @@ class TrainDataset(Dataset):
39
  self.model_args = model_args
40
  self.transform = None
41
  if self.data_args.randaugment:
42
- self.transform = get_randaugment_transform() # RandAugment or another augmenter
43
  train_data = []
44
 
45
  if data_args.subset_name is not None:
@@ -103,13 +93,6 @@ class TrainDataset(Dataset):
103
  return image
104
 
105
  def __getitem__(self, item) -> Tuple[str, List[str]]:
106
- # qry_text, qry_image_path, pos_text, pos_image_path = (
107
- # self.train_data[item]["qry"], self.train_data[item]["qry_image_path"],
108
- # self.train_data[item]["pos_text"], self.train_data[item]["pos_image_path"],
109
- # )
110
-
111
- # return (qry_text, self._get_image(qry_image_path),
112
- # pos_text, self._get_image(pos_image_path))
113
 
114
  data_item = self.train_data[item]
115
  qry_text, qry_image_path, pos_text, pos_image_path = (
 
8
  import os
9
  from torchvision.transforms import RandAugment
10
 
11
+
12
  def get_randaugment_transform(n=2, m=9):
13
  return RandAugment(num_ops=n, magnitude=m)
14
 
15
 
 
29
  self.model_args = model_args
30
  self.transform = None
31
  if self.data_args.randaugment:
32
+ self.transform = get_randaugment_transform()
33
  train_data = []
34
 
35
  if data_args.subset_name is not None:
 
93
  return image
94
 
95
  def __getitem__(self, item) -> Tuple[str, List[str]]:
96
 
97
  data_item = self.train_data[item]
98
  qry_text, qry_image_path, pos_text, pos_image_path = (
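
`get_randaugment_transform` wraps torchvision's RandAugment with `num_ops=2, magnitude=9`; when `data_args.randaugment` is set, the dataset applies it to training images. A quick usage sketch (the image path is hypothetical):

```python
from PIL import Image
from torchvision.transforms import RandAugment

# Same defaults as get_randaugment_transform()
transform = RandAugment(num_ops=2, magnitude=9)

img = Image.open("imgs/queries/example.jpg").convert("RGB")  # hypothetical file
augmented = transform(img)                                    # returns an augmented PIL image
augmented.save("augmented_example.jpg")
```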
src/loss.py CHANGED
@@ -51,7 +51,7 @@ class HardNegativeContrastiveLoss:
51
  # y: positive embeddings
52
  # z: negative embeddings (optional)
53
 
54
- if z is None: # with no negatives, fall back to plain contrastive learning
55
  target_per_qry = y.size(0) // x.size(0)
56
  target = torch.arange(
57
  0, x.size(0) * target_per_qry, target_per_qry,
@@ -60,18 +60,12 @@ class HardNegativeContrastiveLoss:
60
  loss = F.cross_entropy(logits / self.temperature, target, reduction=reduction)
61
  return loss
62
 
63
- # Similarity between queries and positives
64
- pos_logits = torch.matmul(x, y.transpose(0, 1)) # [batch_size, batch_size]
65
- # Similarity between queries and negatives
66
- neg_logits = torch.matmul(x, z.transpose(0, 1)) # [batch_size, num_negs]
67
 
68
- # Concatenate positive and negative similarities
69
- logits = torch.cat([pos_logits, neg_logits], dim=1) # [batch_size, batch_size + num_negs]
70
-
71
- # Build target labels (indices of the positives)
72
  target = torch.arange(x.size(0), device=x.device)
73
-
74
- # Cross-entropy loss
75
  loss = F.cross_entropy(logits / self.temperature, target, reduction=reduction)
76
  return loss
77
 
 
51
  # y: positive embeddings
52
  # z: negative embeddings (optional)
53
 
54
+ if z is None:
55
  target_per_qry = y.size(0) // x.size(0)
56
  target = torch.arange(
57
  0, x.size(0) * target_per_qry, target_per_qry,
 
60
  loss = F.cross_entropy(logits / self.temperature, target, reduction=reduction)
61
  return loss
62
 
63
+ pos_logits = torch.matmul(x, y.transpose(0, 1))
64
+ neg_logits = torch.matmul(x, z.transpose(0, 1))
65
+ logits = torch.cat([pos_logits, neg_logits], dim=1)
 
66
 
 
 
 
 
67
  target = torch.arange(x.size(0), device=x.device)
68
+
 
69
  loss = F.cross_entropy(logits / self.temperature, target, reduction=reduction)
70
  return loss
71
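
The hard-negative branch of `HardNegativeContrastiveLoss` concatenates query-positive and query-negative similarities and treats the diagonal as the target under a temperature-scaled cross-entropy. A self-contained sketch of that computation (shapes and the temperature value are illustrative):

```python
import torch
import torch.nn.functional as F

temperature = 0.02                                   # illustrative value
batch, num_negs, dim = 4, 8, 16
x = F.normalize(torch.randn(batch, dim), dim=-1)     # query embeddings
y = F.normalize(torch.randn(batch, dim), dim=-1)     # positive embeddings
z = F.normalize(torch.randn(num_negs, dim), dim=-1)  # hard-negative embeddings

pos_logits = torch.matmul(x, y.transpose(0, 1))      # [batch, batch]
neg_logits = torch.matmul(x, z.transpose(0, 1))      # [batch, num_negs]
logits = torch.cat([pos_logits, neg_logits], dim=1)  # [batch, batch + num_negs]
target = torch.arange(batch)                         # each query's positive sits on the diagonal
loss = F.cross_entropy(logits / temperature, target)
print(loss.item())
```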
 
src/model.py CHANGED
@@ -118,20 +118,6 @@ class MMEBModel(nn.Module):
118
  trust_remote_code=True)
119
  base_model.padding_side = "right"
120
 
121
- # # Print all model parameters
122
- # import json
123
- # import os
124
-
125
- # param_info = {}
126
- # for name, param in base_model.named_parameters():
127
- # param_info[name] = {
128
- # "shape": list(param.shape),
129
- # "requires_grad": param.requires_grad
130
- # }
131
-
132
- # with open('./model_parameters.json', 'w') as f:
133
- # json.dump(param_info, f, indent=4)
134
- # import pdb; pdb.set_trace()
135
  if model_args.lora:
136
  if lora_target_modules is None:
137
  lora_target_modules = model_args.lora_target_modules.split(',')
@@ -192,7 +178,7 @@ class MMEBModel(nn.Module):
192
  trust_remote_code=True
193
  )
194
  config = InternVLChatConfig.from_pretrained(model_args.model_name)
195
- # config.vision_config.image_size = data_args.force_image_size # assumes data_args carries the image size
196
  config.use_flash_attn = False
197
  base_model = InternVLChatModel.from_pretrained(
198
  model_args.model_name,
 
118
  trust_remote_code=True)
119
  base_model.padding_side = "right"
120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  if model_args.lora:
122
  if lora_target_modules is None:
123
  lora_target_modules = model_args.lora_target_modules.split(',')
 
178
  trust_remote_code=True
179
  )
180
  config = InternVLChatConfig.from_pretrained(model_args.model_name)
181
+ # config.vision_config.image_size = data_args.force_image_size
182
  config.use_flash_attn = False
183
  base_model = InternVLChatModel.from_pretrained(
184
  model_args.model_name,
src/trainer.py CHANGED
@@ -87,11 +87,11 @@ def split_vlm_inputs(model_input: dict, chunk_size: int):
87
  if "image_grid_thw" in keys:
88
  image_grid_thw = arg_val["image_grid_thw"]
89
  chunked_tensors.append(torch.split(image_grid_thw, chunk_image_count))
90
- # Changed here: image_flags should be split by chunk_size, not chunk_image_count
91
  if "image_flags" in keys:
92
  image_flags = arg_val["image_flags"]
93
  chunked_tensors.append(torch.split(image_flags, chunk_size))
94
- keys.remove("image_flags") # remove from keys; handled separately below
95
 
96
 
97
  chunked_arg_val = []
@@ -148,7 +148,7 @@ class GradCacheTrainer(Trainer):
148
 
149
  def training_step(self, model, inputs, *args, **kwargs) -> torch.Tensor:
150
  model.train()
151
- # Support hard negative samples
152
  if self.args.hard_neg:
153
  queries, passages, negatives = inputs
154
  queries, passages, negatives = {'qry': queries}, {'tgt': passages}, {'neg': negatives}
@@ -165,7 +165,7 @@ class GradCacheTrainer(Trainer):
165
  print(f"neg_img.shape={negatives['neg']['pixel_values'].shape}")
166
 
167
  _distributed = self.args.local_rank > -1
168
- self.gc.models = [model, model, model] # add one more model for the negatives
169
  loss = self.gc(queries, passages, negatives, no_sync_except_last=_distributed)
170
  else:
171
  queries, passages = inputs
 
87
  if "image_grid_thw" in keys:
88
  image_grid_thw = arg_val["image_grid_thw"]
89
  chunked_tensors.append(torch.split(image_grid_thw, chunk_image_count))
90
+
91
  if "image_flags" in keys:
92
  image_flags = arg_val["image_flags"]
93
  chunked_tensors.append(torch.split(image_flags, chunk_size))
94
+ keys.remove("image_flags")
95
 
96
 
97
  chunked_arg_val = []
 
148
 
149
  def training_step(self, model, inputs, *args, **kwargs) -> torch.Tensor:
150
  model.train()
151
+
152
  if self.args.hard_neg:
153
  queries, passages, negatives = inputs
154
  queries, passages, negatives = {'qry': queries}, {'tgt': passages}, {'neg': negatives}
 
165
  print(f"neg_img.shape={negatives['neg']['pixel_values'].shape}")
166
 
167
  _distributed = self.args.local_rank > -1
168
+ self.gc.models = [model, model, model]
169
  loss = self.gc(queries, passages, negatives, no_sync_except_last=_distributed)
170
  else:
171
  queries, passages = inputs
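
The `split_vlm_inputs` change splits `image_flags` by `chunk_size` (one flag per sample), while per-image tensors such as `pixel_values` are still split by how many images fall into each chunk. A tiny sketch of the distinction (values are illustrative):

```python
import torch

chunk_size = 2
image_flags = torch.tensor([1, 0, 1, 1])      # one flag per sample in the batch
pixel_values = torch.randn(3, 3, 448, 448)    # only the 3 flagged samples carry images
chunk_image_count = [int(image_flags[i:i + chunk_size].sum())
                     for i in range(0, len(image_flags), chunk_size)]        # [1, 2]

# Per-sample tensors (like image_flags) follow the sample chunking...
print([t.tolist() for t in torch.split(image_flags, chunk_size)])            # [[1, 0], [1, 1]]
# ...while per-image tensors follow how many images land in each chunk.
print([t.shape[0] for t in torch.split(pixel_values, chunk_image_count)])    # [1, 2]
```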
src/vlm_backbone/intern_vl/modeling_internvl_chat.py CHANGED
@@ -172,53 +172,17 @@ class InternVLChatModel(PreTrainedModel):
172
  loss_reduction_all_gather: Optional[bool] = False,
173
  ) -> Union[Tuple, CausalLMOutputWithPast]:
174
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
175
- # import pdb; pdb.set_trace()
176
- # Get the original batch size and per-sample sequence length
177
  B, N = input_ids.shape
178
  input_embeds = self.language_model.get_input_embeddings()(input_ids).clone() # [B, N, C]
179
 
180
  if pixel_values is not None:
181
  vit_embeds = self.extract_feature(pixel_values) # [num_images, num_patches, C]
182
-
183
- # Locate the image-token positions in input_ids that need replacement
184
  selected = torch.eq(input_ids, self.img_context_token_id) # [B, N]
185
 
186
- # Make sure image_flags has the expected shape
187
  image_flags = image_flags.squeeze(-1) # [B]
188
 
189
- # # Time both approaches
190
- # import time
191
-
192
- # # Approach 1: loop-based replacement
193
- # start_time1 = time.time()
194
- # input_embeds2 = input_embeds.clone()
195
- # vit_idx = 0
196
- # for i in range(B):
197
- # if image_flags[i] == 1:
198
- # sample_selected = selected[i]
199
- # input_embeds2[i, sample_selected] = input_embeds2[i, sample_selected] * 0.0 + vit_embeds[vit_idx]
200
- # vit_idx += 1
201
- # time1 = time.time() - start_time1
202
-
203
- # Approach 2: vectorized replacement
204
- # start_time2 = time.time()
205
  mask = selected & (image_flags.unsqueeze(-1)) == 1
206
  input_embeds[mask] = vit_embeds.reshape(-1, vit_embeds.shape[-1])
207
- # time2 = time.time() - start_time2
208
-
209
- # print(f"循环替换用时: {time1:.6f}秒")
210
- # print(f"向量化替换用时: {time2:.6f}秒")
211
- # print(f"向量化方法比循环方法快 {time1/time2:.2f}倍")
212
-
213
- # print(f"input_ids.shape = {input_ids.shape}") # [B, N]
214
- # print(f"input_embeds.shape = {input_embeds.shape}") # [B, N, C]
215
- # print(f"pixel_values.shape = {pixel_values.shape}") # [num_images, ...]
216
- # print(f"vit_embeds.shape = {vit_embeds.shape}") # [num_images, num_patches, C]
217
- # print(f"image_flags.sum() = {image_flags.sum()}") # 应该等于num_images
218
-
219
- # print(torch.allclose(input_embeds2, input_embeds, rtol=1e-7))
220
- # assert torch.allclose(input_embeds2, input_embeds, rtol=1e-5), "input_embeds2 and input_embeds should have the same values"
221
-
222
 
223
  outputs = self.language_model(
224
  inputs_embeds=input_embeds,
 
172
  loss_reduction_all_gather: Optional[bool] = False,
173
  ) -> Union[Tuple, CausalLMOutputWithPast]:
174
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
175
  B, N = input_ids.shape
176
  input_embeds = self.language_model.get_input_embeddings()(input_ids).clone() # [B, N, C]
177
 
178
  if pixel_values is not None:
179
  vit_embeds = self.extract_feature(pixel_values) # [num_images, num_patches, C]
 
 
180
  selected = torch.eq(input_ids, self.img_context_token_id) # [B, N]
181
 
 
182
  image_flags = image_flags.squeeze(-1) # [B]
183
184
  mask = selected & (image_flags.unsqueeze(-1)) == 1
185
  input_embeds[mask] = vit_embeds.reshape(-1, vit_embeds.shape[-1])
186
 
187
  outputs = self.language_model(
188
  inputs_embeds=input_embeds,
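
The vectorized replacement above scatters ViT patch embeddings into the language-model embeddings at every image-context-token position belonging to a sample that actually carries an image. A toy sketch of the masking trick, written with explicit parentheses for clarity (dimensions and values are illustrative, not from the model):

```python
import torch

B, N, C = 2, 5, 4                  # batch, sequence length, hidden size
img_context_token_id = 99
input_ids = torch.tensor([[1, 99, 99, 2, 3],
                          [4, 5, 6, 7, 8]])   # only sample 0 contains image tokens
image_flags = torch.tensor([1, 0])            # sample 0 actually carries an image
input_embeds = torch.zeros(B, N, C)
vit_embeds = torch.ones(2, C)                 # two patch embeddings for that one image

selected = torch.eq(input_ids, img_context_token_id)   # [B, N] bool
mask = selected & (image_flags.unsqueeze(-1) == 1)     # drop positions from image-less samples
input_embeds[mask] = vit_embeds.reshape(-1, vit_embeds.shape[-1])

print(input_embeds[0, 1])  # tensor([1., 1., 1., 1.]) -- replaced by a ViT embedding
print(input_embeds[1, 0])  # tensor([0., 0., 0., 0.]) -- untouched
```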
src/vlm_backbone/intern_vl/processing_internvl.py CHANGED
@@ -11,70 +11,6 @@ IMG_START_TOKEN = "<img>"
11
  IMG_END_TOKEN = "</img>"
12
  IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"
13
 
14
- # class InternVLProcessor(ProcessorMixin):
15
- # attributes = ["image_processor", "tokenizer"]
16
- # image_processor_class = "AutoImageProcessor"
17
- # tokenizer_class = "AutoTokenizer"
18
-
19
- # def __init__(self, image_processor, tokenizer, num_img_tokens=256):
20
- # super().__init__(image_processor, tokenizer)
21
- # self.num_img_tokens = num_img_tokens
22
- # self._add_special_tokens()
23
-
24
- # def _add_special_tokens(self):
25
- # special_tokens = [IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN]
26
- # self.tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})
27
-
28
- # def __call__(
29
- # self,
30
- # text: Union[TextInput, List[TextInput]] = None,
31
- # images: ImageInput = None,
32
- # padding: Union[bool, str, PaddingStrategy] = False,
33
- # truncation: Union[bool, str, TruncationStrategy] = None,
34
- # max_length: Optional[int] = None,
35
- # return_tensors: Optional[str] = "pt",
36
- # ) -> BatchFeature:
37
-
38
- # # Process images
39
- # pixel_values = []
40
- # if images is not None:
41
- # image_inputs = self.image_processor(images, return_tensors=return_tensors)
42
- # pixel_values = image_inputs.pixel_values
43
-
44
- # # Process text with image tokens
45
- # processed_text = self._insert_image_tokens(text, num_images=len(pixel_values))
46
-
47
- # # Tokenize text
48
- # text_inputs = self.tokenizer(
49
- # processed_text,
50
- # padding=padding,
51
- # truncation=truncation,
52
- # max_length=max_length,
53
- # return_tensors=return_tensors,
54
- # add_special_tokens=False
55
- # )
56
-
57
- # # Build final inputs
58
- # inputs = BatchFeature(data={
59
- # **text_inputs,
60
- # "pixel_values": pixel_values,
61
- # })
62
-
63
- # return inputs
64
-
65
- # def _insert_image_tokens(self, text: str, num_images: int) -> str:
66
- # """Replace <image> tags with image context tokens"""
67
- # image_tokens = []
68
- # for _ in range(num_images):
69
- # image_tokens.append(
70
- # f"{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * self.num_img_tokens}{IMG_END_TOKEN}"
71
- # )
72
-
73
- # # Replace the first N occurrences of <image>
74
- # pattern = re.compile(r"<image>")
75
- # return pattern.sub(lambda x: image_tokens.pop(0) if image_tokens else "", text, count=num_images)
76
-
77
-
78
  class InternVLProcessor(ProcessorMixin):
79
  attributes = ["image_processor", "tokenizer"]
80
  image_processor_class = "AutoImageProcessor"
@@ -91,8 +27,7 @@ class InternVLProcessor(ProcessorMixin):
91
  num_added = self.tokenizer.add_special_tokens({
92
  "additional_special_tokens": special_tokens
93
  })
94
- # print(self.tokenizer)
95
- # assert num_added == 1, f"Failed to add IMG_CONTEXT token, added {num_added}"
96
 
97
  def __call__(
98
  self,
@@ -103,38 +38,25 @@ class InternVLProcessor(ProcessorMixin):
103
  max_length: Optional[int] = None,
104
  return_tensors: str = "pt"
105
  ) -> BatchFeature:
106
- # import pdb; pdb.set_trace()
107
-
108
- # Handle single-sample input
109
  if isinstance(text, str):
110
  text = [text]
111
 
112
  if not isinstance(images, list):
113
  images = [images] if images else []
114
 
115
- # Build image_flags
116
  image_flags = [1] if len(images) else [0]
117
 
118
- # Image preprocessing
119
  pixel_values = []
120
  if any(image_flags):
121
  pixel_values = self.image_processor(
122
- [img for img in images if img], # img.size(525, 704)
123
  return_tensors=return_tensors
124
- ).pixel_values # torch.Size([1, 3, 448, 448])
125
 
126
- # Text preprocessing
127
  processed_texts = [
128
  self._insert_image_tokens(t, count)
129
  for t, count in zip(text, image_flags)
130
  ]
131
- # print("process text:")
132
- # print(processed_texts)
133
- # print("text")
134
- # print(text)
135
- # print(images)
136
- # print(image_flags)
137
- # Tokenize the text
138
  text_inputs = self.tokenizer(
139
  processed_texts,
140
  padding=padding,
@@ -144,7 +66,6 @@ class InternVLProcessor(ProcessorMixin):
144
  add_special_tokens=True
145
  )
146
 
147
- # Build the final inputs
148
  return BatchFeature({
149
  **text_inputs,
150
  "pixel_values": pixel_values,
@@ -152,7 +73,6 @@ class InternVLProcessor(ProcessorMixin):
152
  }, tensor_type=return_tensors)
153
 
154
  def _insert_image_tokens(self, text: str, image_count: int) -> str:
155
- """动态插入图像token"""
156
  if image_count == 0:
157
  return text
158
 
 
11
  IMG_END_TOKEN = "</img>"
12
  IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  class InternVLProcessor(ProcessorMixin):
15
  attributes = ["image_processor", "tokenizer"]
16
  image_processor_class = "AutoImageProcessor"
 
27
  num_added = self.tokenizer.add_special_tokens({
28
  "additional_special_tokens": special_tokens
29
  })
30
+
 
31
 
32
  def __call__(
33
  self,
 
38
  max_length: Optional[int] = None,
39
  return_tensors: str = "pt"
40
  ) -> BatchFeature:
 
 
 
41
  if isinstance(text, str):
42
  text = [text]
43
 
44
  if not isinstance(images, list):
45
  images = [images] if images else []
46
 
 
47
  image_flags = [1] if len(images) else [0]
48
 
 
49
  pixel_values = []
50
  if any(image_flags):
51
  pixel_values = self.image_processor(
52
+ [img for img in images if img],
53
  return_tensors=return_tensors
54
+ ).pixel_values
55
 
 
56
  processed_texts = [
57
  self._insert_image_tokens(t, count)
58
  for t, count in zip(text, image_flags)
59
  ]
 
 
 
 
 
 
 
60
  text_inputs = self.tokenizer(
61
  processed_texts,
62
  padding=padding,
 
66
  add_special_tokens=True
67
  )
68
 
 
69
  return BatchFeature({
70
  **text_inputs,
71
  "pixel_values": pixel_values,
 
73
  }, tensor_type=return_tensors)
74
 
75
  def _insert_image_tokens(self, text: str, image_count: int) -> str:
 
76
  if image_count == 0:
77
  return text
78
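
`_insert_image_tokens` (shown only partially above) expands each image placeholder into `IMG_START_TOKEN` plus repeated `IMG_CONTEXT_TOKEN` plus `IMG_END_TOKEN` before tokenization. A rough sketch of that expansion, modeled on the removed commented-out version of the processor earlier in this diff; the 256-token count comes from that removed code's default, and the `<image>` placeholder is an assumption about the current behavior:

```python
import re

IMG_START_TOKEN = "<img>"
IMG_END_TOKEN = "</img>"
IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"
NUM_IMG_TOKENS = 256   # assumed per-image context-token count

def insert_image_tokens(text: str, image_count: int) -> str:
    """Replace each <image> placeholder with the expanded InternVL image-token block."""
    if image_count == 0:
        return text
    block = f"{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * NUM_IMG_TOKENS}{IMG_END_TOKEN}"
    return re.sub(r"<image>", block, text, count=image_count)

print(insert_image_tokens("<image>\nfind the same dog in the candidates", 1)[:40])
```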