multimodalart (HF Staff) committed a8ca9e4 (verified) · 1 parent: 2ff63f0

Update app.py

Files changed (1): app.py (+66 -5)
app.py CHANGED
@@ -111,6 +111,55 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
 def change_rank_default(concept_name):
     return RANKS_MAP.get(concept_name, 30)
 
+@spaces.GPU
+def match_image_to_concept(image):
+    """
+    Match an uploaded image to the closest concept type using CLIP embeddings
+    """
+    if image is None:
+        return None
+
+    # Get image embeddings
+    img_pil = Image.fromarray(image).convert("RGB")
+    img_embed = get_image_embeds(img_pil, clip_model, preprocess, device)
+
+    # Calculate similarity to each concept
+    similarities = {}
+    for concept_name, concept_file in CONCEPTS_MAP.items():
+        try:
+            # Load concept embeddings
+            embeds_path = f"./IP_Composer/text_embeddings/{concept_file}"
+            with open(embeds_path, "rb") as f:
+                concept_embeds = np.load(f)
+
+            # Calculate similarity to each text embedding
+            sim_scores = []
+            for embed in concept_embeds:
+                # Normalize both embeddings
+                img_embed_norm = img_embed / np.linalg.norm(img_embed)
+                text_embed_norm = embed / np.linalg.norm(embed)
+
+                # Calculate cosine similarity
+                similarity = np.dot(img_embed_norm.flatten(), text_embed_norm.flatten())
+                sim_scores.append(similarity)
+
+            # Use the average of top 5 similarities for better matching
+            sim_scores.sort(reverse=True)
+            top_similarities = sim_scores[:min(5, len(sim_scores))]
+            avg_similarity = sum(top_similarities) / len(top_similarities)
+
+            similarities[concept_name] = avg_similarity
+        except Exception as e:
+            print(f"Error processing concept {concept_name}: {e}")
+
+    # Return the concept with highest similarity
+    if similarities:
+        matched_concept = max(similarities.items(), key=lambda x: x[1])[0]
+        # Display a notification to the user
+        gr.Info(f"Image automatically matched to concept: {matched_concept}")
+        return matched_concept
+    return None
+
 @spaces.GPU
 def get_image_embeds(pil_image, model=clip_model, preproc=preprocess, dev=device):
     """Get CLIP image embeddings for a given PIL image"""
@@ -464,9 +513,21 @@ Following the algorithm proposed in IP-Composer: Semantic Composition of Visual
         inputs=[concept_name3],
         outputs=[rank3]
     )
-
+    concept_image1.upload(
+        fn=match_image_to_concept,
+        inputs=[concept_image1],
+        outputs=[concept_name1]
+    )
+    concept_image2.upload(
+        fn=match_image_to_concept,
+        inputs=[concept_image2],
+        outputs=[concept_name2]
+    )
+    concept_image3.upload(
+        fn=match_image_to_concept,
+        inputs=[concept_image3],
+        outputs=[concept_name3]
+    )
+
 if __name__ == "__main__":
-    demo.launch()
-
-
-
+    demo.launch()
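
Note on the matching loop: img_embed is renormalized on every pass through the inner loop even though it never changes, and each cosine similarity is a separate np.dot call. The same top-5 scoring can be expressed with one normalization and a single matrix product. A minimal sketch under the same assumptions as the diff (img_embed of shape (1, D) or (D,), concept_embeds of shape (N, D)); the helper name top_k_mean_similarity is hypothetical and not part of app.py:

import numpy as np

def top_k_mean_similarity(img_embed, concept_embeds, k=5):
    """Hypothetical helper: mean cosine similarity of the k best-matching
    text embeddings, mirroring the intent of the per-embedding loop in
    match_image_to_concept above."""
    img = img_embed.reshape(-1)
    img = img / np.linalg.norm(img)  # normalize the image embedding once
    # Normalize each concept text embedding along its feature axis
    text = concept_embeds / np.linalg.norm(concept_embeds, axis=1, keepdims=True)
    sims = text @ img                # (N,) cosine similarities in one matmul
    top_k = np.sort(sims)[-min(k, sims.size):]  # k largest scores
    return float(top_k.mean())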
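The three .upload() registrations follow Gradio's standard event wiring: when a file finishes uploading to a gr.Image, the callback runs and its return value becomes the new value of the output component, here the concept dropdown. A self-contained sketch of the pattern, not the Space's actual UI; the component labels and the two concept choices are placeholders:

import gradio as gr

def match_image_to_concept(image):
    # Stand-in for the real CLIP-based matcher: returns a fixed concept.
    return "age" if image is not None else None

with gr.Blocks() as demo:
    concept_image1 = gr.Image(label="Concept image")
    concept_name1 = gr.Dropdown(choices=["age", "emotion"], label="Concept")
    # .upload() fires after a file upload completes; the callback's return
    # value is written into the output component (the dropdown).
    concept_image1.upload(
        fn=match_image_to_concept,
        inputs=[concept_image1],
        outputs=[concept_name1],
    )

if __name__ == "__main__":
    demo.launch()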