Update app.py

app.py CHANGED
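The commit adds a `match_image_to_concept` helper that embeds an uploaded concept image with CLIP, scores it against each concept's precomputed text embeddings, and auto-selects the closest concept; upload handlers then wire this to all three concept image inputs.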
@@ -111,6 +111,55 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
 def change_rank_default(concept_name):
     return RANKS_MAP.get(concept_name, 30)
 
+@spaces.GPU
+def match_image_to_concept(image):
+    """
+    Match an uploaded image to the closest concept type using CLIP embeddings
+    """
+    if image is None:
+        return None
+
+    # Get image embeddings
+    img_pil = Image.fromarray(image).convert("RGB")
+    img_embed = get_image_embeds(img_pil, clip_model, preprocess, device)
+
+    # Calculate similarity to each concept
+    similarities = {}
+    for concept_name, concept_file in CONCEPTS_MAP.items():
+        try:
+            # Load concept embeddings
+            embeds_path = f"./IP_Composer/text_embeddings/{concept_file}"
+            with open(embeds_path, "rb") as f:
+                concept_embeds = np.load(f)
+
+            # Calculate similarity to each text embedding
+            sim_scores = []
+            for embed in concept_embeds:
+                # Normalize both embeddings
+                img_embed_norm = img_embed / np.linalg.norm(img_embed)
+                text_embed_norm = embed / np.linalg.norm(embed)
+
+                # Calculate cosine similarity
+                similarity = np.dot(img_embed_norm.flatten(), text_embed_norm.flatten())
+                sim_scores.append(similarity)
+
+            # Use the average of top 5 similarities for better matching
+            sim_scores.sort(reverse=True)
+            top_similarities = sim_scores[:min(5, len(sim_scores))]
+            avg_similarity = sum(top_similarities) / len(top_similarities)
+
+            similarities[concept_name] = avg_similarity
+        except Exception as e:
+            print(f"Error processing concept {concept_name}: {e}")
+
+    # Return the concept with highest similarity
+    if similarities:
+        matched_concept = max(similarities.items(), key=lambda x: x[1])[0]
+        # Display a notification to the user
+        gr.Info(f"Image automatically matched to concept: {matched_concept}")
+        return matched_concept
+    return None
+
 @spaces.GPU
 def get_image_embeds(pil_image, model=clip_model, preproc=preprocess, dev=device):
     """Get CLIP image embeddings for a given PIL image"""
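A note on the similarity loop in the hunk above: `img_embed` is re-normalized on every iteration of the inner loop, and each cosine similarity is computed as a separate dot product. Since the row iteration suggests `concept_embeds` loads as a stacked N x D array, the same score can be computed in one matrix product. A minimal sketch, not part of this commit (the `top5_cosine_similarity` helper is hypothetical, and it assumes `img_embed` is a length-D or 1 x D float array):

import numpy as np

def top5_cosine_similarity(img_embed: np.ndarray, concept_embeds: np.ndarray) -> float:
    """Mean of the top-5 cosine similarities between one image embedding
    and a stack of text embeddings; vectorized sketch of the loop above."""
    img = img_embed.flatten()
    img = img / np.linalg.norm(img)                  # normalize the image embedding once
    txt = concept_embeds / np.linalg.norm(concept_embeds, axis=1, keepdims=True)
    sims = txt @ img                                 # all cosine similarities in one matvec
    k = min(5, sims.size)
    return float(np.sort(sims)[-k:].mean())         # average of the k largest scores

Numerically this matches the per-embedding version while moving the redundant normalization of `img_embed` out of the loop.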
@@ -464,9 +513,21 @@ Following the algorithm proposed in IP-Composer: Semantic Composition of Visual
         inputs=[concept_name3],
         outputs=[rank3]
     )
-
+    concept_image1.upload(
+        fn=match_image_to_concept,
+        inputs=[concept_image1],
+        outputs=[concept_name1]
+    )
+    concept_image2.upload(
+        fn=match_image_to_concept,
+        inputs=[concept_image2],
+        outputs=[concept_name2]
+    )
+    concept_image3.upload(
+        fn=match_image_to_concept,
+        inputs=[concept_image3],
+        outputs=[concept_name3]
+    )
+
 if __name__ == "__main__":
-    demo.launch()
-
-
-
+    demo.launch()
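The three `.upload()` registrations in the hunk above differ only in which image/dropdown pair they bind. A sketch of equivalent, more compact wiring (same component names as the diff; it would sit inside the same `gr.Blocks` context):

for image_comp, name_dropdown in [
    (concept_image1, concept_name1),
    (concept_image2, concept_name2),
    (concept_image3, concept_name3),
]:
    # Each upload re-runs the CLIP matcher and writes the best-matching
    # concept back into the paired dropdown, as in the explicit calls above.
    image_comp.upload(
        fn=match_image_to_concept,
        inputs=[image_comp],
        outputs=[name_dropdown],
    )

Behavior should be identical to the three explicit calls; this is a readability suggestion, not part of the commit.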