liubangwei committed
Commit 3ee7efa · 1 Parent(s): d37eb96

fix retrieval

Key the cached image embeddings by filename and align them with the image
directory at query time, so top-k indices map back to the correct files;
also switch the demo UI to explicit Gradio components.
app.py CHANGED
@@ -22,8 +22,8 @@ def load_model():
     global IMAGE_TOKEN
 
     model_args = ModelArguments(
-        # model_name="/fs-computility/ai-shen/kilab-shared/liubangwei/ckpt/my_hf/IDMR-2B",
-        model_name="lbw18601752667/IDMR-2B",
+        model_name="/fs-computility/ai-shen/kilab-shared/liubangwei/ckpt/my_hf/IDMR-2B",
+        # model_name="lbw18601752667/IDMR-2B",
         model_backbone="internvl_2_5",
     )
 
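A hypothetical alternative to flipping this comment pair by hand (not part of the commit): resolve the checkpoint location at runtime, preferring the local path when it exists and falling back to the Hub id. The import path for ModelArguments is assumed from src/arguments.py.

    import os
    from src.arguments import ModelArguments  # assumed location of ModelArguments

    # Hypothetical helper: prefer the local checkpoint, fall back to the Hub id.
    LOCAL_CKPT = "/fs-computility/ai-shen/kilab-shared/liubangwei/ckpt/my_hf/IDMR-2B"
    HUB_ID = "lbw18601752667/IDMR-2B"

    model_args = ModelArguments(
        model_name=LOCAL_CKPT if os.path.isdir(LOCAL_CKPT) else HUB_ID,
        model_backbone="internvl_2_5",
    )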
@@ -81,15 +81,16 @@ def get_inputs(processor, text, image_path=None, image=None):
     return inputs
 
 def encode_image_library(image_paths):
-    embeddings = []
+    embeddings_dict = {}
     for img_path in image_paths:
         text = f"{IMAGE_TOKEN}\n Represent the given image."
         print(f"text: {text}")
         inputs = get_inputs(processor, text, image_path=img_path)
         with torch.no_grad(), torch.autocast(device_type=device, dtype=torch.bfloat16):
             output = model(tgt=inputs)
-        embeddings.append(output["tgt_reps"].float().cpu().numpy())
-    return np.stack(embeddings)
+        img_name = os.path.basename(img_path)
+        embeddings_dict[img_name] = output["tgt_reps"].float().cpu().numpy()
+    return embeddings_dict
 
 def save_embeddings(embeddings, file_path="image_embeddings.pkl"):
     with open(file_path, "wb") as f:
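The substantive change in this hunk: encode_image_library now returns a dict keyed by image filename instead of a positional stack, so each cached vector stays tied to the image it came from. A minimal sketch of the resulting on-disk format, assuming save_embeddings/load_embeddings pickle the object as-is (the 768 dimension is illustrative):

    import pickle
    import numpy as np

    def save_embeddings(embeddings_dict, file_path="image_embeddings.pkl"):
        # Pickle the {filename: (1, D) array} mapping as-is.
        with open(file_path, "wb") as f:
            pickle.dump(embeddings_dict, f)

    def load_embeddings(file_path="image_embeddings.pkl"):
        with open(file_path, "rb") as f:
            return pickle.load(f)

    # Illustrative library: filenames map to (1, 768) embedding arrays.
    library = {"cat.jpg": np.random.rand(1, 768), "dog.jpg": np.random.rand(1, 768)}
    save_embeddings(library)
    assert set(load_embeddings()) == {"cat.jpg", "dog.jpg"}

This format change is also why the commit regenerates image_embeddings.pkl below: the old pickle holds the list-format cache.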
@@ -115,22 +116,26 @@ def retrieve_images(query_text, query_image, top_n=TOP_N):
         image = None
     inputs = get_inputs(processor, query_text, image=image)
     print(f"inputs: {inputs}")
-    # with torch.no_grad():
     with torch.no_grad(), torch.autocast(device_type=device, dtype=torch.bfloat16):
         query_embedding = model(qry=inputs)["qry_reps"].float().cpu().numpy()
 
-    embeddings = load_embeddings()
+    embeddings_dict = load_embeddings()
+
+    img_names = []
+    embeddings = []
+    for img_name in os.listdir(IMAGE_DIR):
+        if img_name in embeddings_dict:
+            img_names.append(img_name)
+            embeddings.append(embeddings_dict[img_name])
+    embeddings = np.stack(embeddings)
 
     similarity = cosine_similarity(query_embedding, embeddings)
     similarity = similarity.T
     print(f"cosine_similarity: {similarity}")
     top_indices = np.argsort(-similarity).squeeze(0)[:top_n]
     print(f"top_indices: {top_indices}")
-
-    # similarity = model.compute_similarity(np.expand_dims(query_embedding.squeeze(0), axis=1), embeddings.squeeze(1))
-    # print(f"model.compute_similarity: {similarity}")
 
-    return [image_paths[i] for i in top_indices]
+    return [os.path.join(IMAGE_DIR, img_names[i]) for i in top_indices]
 
 def demo(query_text, query_image):
     # print(f"query_text: {query_text}, query_image: {query_image}, type(query_image): {type(query_image)}, image shape: {query_image.shape if query_image is not None else 'None'}")
@@ -157,13 +162,15 @@ def load_examples():
 
 iface = gr.Interface(
     fn=demo,
-    inputs=["text", "image"],
-    outputs=gr.Gallery(label=f"Retrieved Images (Top {TOP_N})"),
+    inputs=[
+        gr.Textbox(placeholder="Enter your query text here...", label="Query Text"),
+        gr.Image(label="Query Image", type="numpy")
+    ],
+    outputs=gr.Gallery(label=f"Retrieved Images (Top {TOP_N})", columns=3),
     examples=load_examples(),
-    title="Multimodal Retrieval Demo",
-    description="Enter a query and upload an image to retrieve relevant images from the library. You can click on the example below to use it as a query"
+    title="Instance-Driven Multi-modal Retrieval (IDMR) Demo",
+    description="Enter a query text or upload an image to retrieve relevant images from the library. You can click on the examples below to try them out."
 )
-
 if not os.path.exists("image_embeddings.pkl"):
     embeddings = encode_image_library(image_paths)
     save_embeddings(embeddings)
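The UI hunk replaces the "text"/"image" string shorthands with explicit components, which is what allows the placeholder text, the labels, type="numpy" for the image input, and the 3-column gallery. A standalone sketch of the same pattern (fake_demo stands in for the real demo function; gr.Gallery(columns=...) assumes a reasonably recent Gradio):

    import gradio as gr

    def fake_demo(text, image):
        # Placeholder for the real retrieval function; a Gallery output
        # expects a list of image paths or arrays.
        return []

    iface = gr.Interface(
        fn=fake_demo,
        inputs=[
            gr.Textbox(placeholder="Enter your query text here...", label="Query Text"),
            gr.Image(label="Query Image", type="numpy"),  # handler receives a numpy array
        ],
        outputs=gr.Gallery(label="Retrieved Images", columns=3),
    )

    if __name__ == "__main__":
        iface.launch()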
 
image_embeddings.pkl CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:b8dcedaab4e3bcc555795f56b15a7d830b74ffc707260c3b0152ba8d99a992bd
-size 409764
+oid sha256:05be7a8cbfdc77b64e473f3d2c30c47c6e522623abb11d288fb54fad07a5f8da
+size 412525
src/__pycache__/arguments.cpython-310.pyc CHANGED
Binary files a/src/__pycache__/arguments.cpython-310.pyc and b/src/__pycache__/arguments.cpython-310.pyc differ
 
src/__pycache__/model.cpython-310.pyc CHANGED
Binary files a/src/__pycache__/model.cpython-310.pyc and b/src/__pycache__/model.cpython-310.pyc differ
 
src/vlm_backbone/intern_vl/__pycache__/modeling_internvl_chat.cpython-310.pyc CHANGED
Binary files a/src/vlm_backbone/intern_vl/__pycache__/modeling_internvl_chat.cpython-310.pyc and b/src/vlm_backbone/intern_vl/__pycache__/modeling_internvl_chat.cpython-310.pyc differ
 
src/vlm_backbone/intern_vl/__pycache__/processing_internvl.cpython-310.pyc CHANGED
Binary files a/src/vlm_backbone/intern_vl/__pycache__/processing_internvl.cpython-310.pyc and b/src/vlm_backbone/intern_vl/__pycache__/processing_internvl.cpython-310.pyc differ
 
src/vlm_backbone/llava_next/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/src/vlm_backbone/llava_next/__pycache__/__init__.cpython-310.pyc and b/src/vlm_backbone/llava_next/__pycache__/__init__.cpython-310.pyc differ
 
src/vlm_backbone/llava_next/__pycache__/modeling_llava_next.cpython-310.pyc CHANGED
Binary files a/src/vlm_backbone/llava_next/__pycache__/modeling_llava_next.cpython-310.pyc and b/src/vlm_backbone/llava_next/__pycache__/modeling_llava_next.cpython-310.pyc differ