Spaces:
Running
Running
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import open_clip
|
3 |
+
import torch
|
4 |
+
import requests
|
5 |
+
import numpy as np
|
6 |
+
from PIL import Image
|
7 |
+
from io import BytesIO
|
8 |
+
|
9 |
+
# Sidebar content
|
10 |
+
sidebar_markdown = """
|
11 |
+
Note, this demo can classify 200 items. If you didn't find what you're looking for, reach out to us on our [Community](https://join.slack.com/t/marqo-community/shared_invite/zt-2iab0260n-QJrZLUSOJYUifVxf964Gdw) and request an item to be added.
|
12 |
+
|
13 |
+
## Documentation
|
14 |
+
📚 [Blog Post](https://www.marqo.ai/blog/search-model-for-fashion)
|
15 |
+
|
16 |
+
📝 [Use Case Blog Post](https://www.marqo.ai/blog/ecommerce-image-classification-with-marqo-fashionclip)
|
17 |
+
|
18 |
+
## Code
|
19 |
+
💻 [GitHub Repo](https://github.com/marqo-ai/marqo-FashionCLIP)
|
20 |
+
|
21 |
+
🤝 [Google Colab](https://colab.research.google.com/drive/1nq978xFJjJcnyrJ2aE5l82GHAXOvTmfd?usp=sharing)
|
22 |
+
|
23 |
+
🤗 [Hugging Face Collection](https://huggingface.co/collections/Marqo/marqo-fashionclip-and-marqo-fashionsiglip-66b43f2d09a06ad2368d4af6)
|
24 |
+
"""
|
25 |
+
|
26 |
+
# List of fashion items and their IDs
|
27 |
+
categories = [
|
28 |
+
{"name": "Nettoyants visage", "id": 101},
|
29 |
+
{"name": "Exfoliants visage", "id": 102},
|
30 |
+
{"name": "Hydratants visage", "id": 103},
|
31 |
+
{"name": "Masques visage", "id": 104},
|
32 |
+
{"name": "Soins ciblés visage", "id": 105},
|
33 |
+
{"name": "Protection solaire visage", "id": 106},
|
34 |
+
{"name": "Nettoyants visage homme", "id": 107},
|
35 |
+
{"name": "Crèmes hydratantes homme", "id": 108},
|
36 |
+
{"name": "Soins après-rasage", "id": 109},
|
37 |
+
{"name": "Hydratants corps", "id": 110},
|
38 |
+
{"name": "Exfoliants corps", "id": 111},
|
39 |
+
{"name": "Soins fermeté & minceur", "id": 112},
|
40 |
+
{"name": "Auto-bronzants", "id": 113},
|
41 |
+
{"name": "Soins des mains", "id": 114},
|
42 |
+
{"name": "Soins des pieds", "id": 115},
|
43 |
+
{"name": "Hydratants corps homme", "id": 116},
|
44 |
+
{"name": "Déodorants corps homme", "id": 117},
|
45 |
+
{"name": "Shampoings", "id": 118},
|
46 |
+
{"name": "Après-shampoings", "id": 119},
|
47 |
+
{"name": "Masques capillaires", "id": 120},
|
48 |
+
{"name": "Huiles capillaires", "id": 121},
|
49 |
+
{"name": "Coiffants", "id": 122},
|
50 |
+
{"name": "Accessoires cheveux", "id": 123},
|
51 |
+
{"name": "Soins cheveux homme", "id": 124},
|
52 |
+
{"name": "Produits coiffants homme", "id": 125},
|
53 |
+
{"name": "Fond de teint", "id": 126},
|
54 |
+
{"name": "BB/CC crèmes", "id": 127},
|
55 |
+
{"name": "Poudres", "id": 128},
|
56 |
+
{"name": "Fards à paupières", "id": 129},
|
57 |
+
{"name": "Mascaras", "id": 130},
|
58 |
+
{"name": "Eyeliners", "id": 131},
|
59 |
+
{"name": "Rouges à lèvres", "id": 132},
|
60 |
+
{"name": "Gloss", "id": 133},
|
61 |
+
{"name": "Crayons à sourcils", "id": 134},
|
62 |
+
{"name": "Accessoires maquillage", "id": 135},
|
63 |
+
{"name": "Correcteurs teint homme", "id": 136},
|
64 |
+
{"name": "Poudres matifiantes homme", "id": 137},
|
65 |
+
{"name": "Parfums", "id": 138},
|
66 |
+
{"name": "Brumes corporelles", "id": 139},
|
67 |
+
{"name": "Huiles essentielles", "id": 140},
|
68 |
+
{"name": "Diffuseurs d'huiles", "id": 141},
|
69 |
+
{"name": "Bougies parfumées", "id": 142},
|
70 |
+
{"name": "Déodorants solides", "id": 143},
|
71 |
+
{"name": "Déodorants sprays", "id": 144},
|
72 |
+
{"name": "Savons solides", "id": 145},
|
73 |
+
{"name": "Savons liquides", "id": 146},
|
74 |
+
{"name": "Produits bain", "id": 147},
|
75 |
+
{"name": "Hygiène intime", "id": 148},
|
76 |
+
{"name": "Cups menstruelles", "id": 149},
|
77 |
+
{"name": "Produits zéro déchet", "id": 150},
|
78 |
+
{"name": "Brosses nettoyantes visage", "id": 151},
|
79 |
+
{"name": "Pinces à épiler", "id": 152},
|
80 |
+
{"name": "Trousse de voyage", "id": 153},
|
81 |
+
{"name": "Huiles de CBD", "id": 154},
|
82 |
+
{"name": "Cosmétiques au CBD", "id": 155},
|
83 |
+
{"name": "Infusions au CBD", "id": 156},
|
84 |
+
{"name": "Bonbons au CBD", "id": 157},
|
85 |
+
{"name": "Accessoires CBD", "id": 158},
|
86 |
+
{"name": "Robes femme", "id": 201},
|
87 |
+
{"name": "Tops femme", "id": 202},
|
88 |
+
{"name": "Chemisiers femme", "id": 203},
|
89 |
+
{"name": "T-shirts femme", "id": 204},
|
90 |
+
{"name": "Pulls femme", "id": 205},
|
91 |
+
{"name": "Jeans femme", "id": 206},
|
92 |
+
{"name": "Pantalons femme", "id": 207},
|
93 |
+
{"name": "Jupes femme", "id": 208},
|
94 |
+
{"name": "Shorts femme", "id": 209},
|
95 |
+
{"name": "Vestes femme", "id": 210},
|
96 |
+
{"name": "Manteaux femme", "id": 211},
|
97 |
+
{"name": "Maillots de bain femme", "id": 212},
|
98 |
+
{"name": "Lingerie femme", "id": 213},
|
99 |
+
{"name": "Chaussures femme", "id": 214},
|
100 |
+
{"name": "Sacs femme", "id": 215},
|
101 |
+
{"name": "Bijoux femme", "id": 216},
|
102 |
+
{"name": "Chemises homme", "id": 301},
|
103 |
+
{"name": "T-shirts homme", "id": 302},
|
104 |
+
{"name": "Polos homme", "id": 303},
|
105 |
+
{"name": "Pulls homme", "id": 304},
|
106 |
+
{"name": "Jeans homme", "id": 305},
|
107 |
+
{"name": "Pantalons homme", "id": 306},
|
108 |
+
{"name": "Shorts homme", "id": 307},
|
109 |
+
{"name": "Vestes homme", "id": 308},
|
110 |
+
{"name": "Manteaux homme", "id": 309},
|
111 |
+
{"name": "Costumes homme", "id": 310},
|
112 |
+
{"name": "Maillots de bain homme", "id": 311},
|
113 |
+
{"name": "Sous-vêtements homme", "id": 312},
|
114 |
+
{"name": "Chaussures homme", "id": 313},
|
115 |
+
{"name": "Accessoires homme", "id": 314},
|
116 |
+
{"name": "Montres homme", "id": 315},
|
117 |
+
{"name": "Vêtements bébé (0-2 ans)", "id": 401},
|
118 |
+
{"name": "T-shirts enfant", "id": 402},
|
119 |
+
{"name": "Pulls enfant", "id": 403},
|
120 |
+
{"name": "Pantalons enfant", "id": 404},
|
121 |
+
{"name": "Robes enfant", "id": 405},
|
122 |
+
{"name": "Jeans enfant", "id": 406},
|
123 |
+
{"name": "Vestes enfant", "id": 407},
|
124 |
+
{"name": "Pyjamas enfant", "id": 408},
|
125 |
+
{"name": "Chaussures enfant", "id": 409},
|
126 |
+
{"name": "Accessoires enfant", "id": 410},
|
127 |
+
{"name": "Vêtements de sport enfant", "id": 411},
|
128 |
+
{"name": "Maillots de bain enfant", "id": 412},
|
129 |
+
{"name": "Sous-vêtements enfant", "id": 413},
|
130 |
+
{"name": "Déguisements enfant", "id": 414},
|
131 |
+
{"name": "Cartables et sacs enfant", "id": 415},
|
132 |
+
# Chaussures Femme détaillées
|
133 |
+
{"name": "Sneakers femme", "id": 217},
|
134 |
+
{"name": "Boots femme", "id": 218},
|
135 |
+
{"name": "Escarpins femme", "id": 219},
|
136 |
+
{"name": "Sandales femme", "id": 220},
|
137 |
+
{"name": "Ballerines femme", "id": 221},
|
138 |
+
{"name": "Mocassins femme", "id": 222},
|
139 |
+
{"name": "Bottines femme", "id": 223},
|
140 |
+
{"name": "Espadrilles femme", "id": 224},
|
141 |
+
{"name": "Mules femme", "id": 225},
|
142 |
+
{"name": "Chaussures de sport femme", "id": 226},
|
143 |
+
{"name": "Bottes hautes femme", "id": 227},
|
144 |
+
{"name": "Chaussures compensées femme", "id": 228},
|
145 |
+
# Chaussures Homme détaillées
|
146 |
+
{"name": "Sneakers homme", "id": 316},
|
147 |
+
{"name": "Boots homme", "id": 317},
|
148 |
+
{"name": "Chaussures de ville homme", "id": 318},
|
149 |
+
{"name": "Mocassins homme", "id": 319},
|
150 |
+
{"name": "Sandales homme", "id": 320},
|
151 |
+
{"name": "Chaussures bateau homme", "id": 321},
|
152 |
+
{"name": "Bottines homme", "id": 322},
|
153 |
+
{"name": "Chaussures de sport homme", "id": 323},
|
154 |
+
{"name": "Espadrilles homme", "id": 324},
|
155 |
+
{"name": "Derbies homme", "id": 325},
|
156 |
+
{"name": "Richelieus homme", "id": 326},
|
157 |
+
{"name": "Chaussures de randonnée homme", "id": 327},
|
158 |
+
# Chaussures Enfant détaillées
|
159 |
+
{"name": "Sneakers enfant", "id": 416},
|
160 |
+
{"name": "Bottes enfant", "id": 417},
|
161 |
+
{"name": "Sandales enfant", "id": 418},
|
162 |
+
{"name": "Chaussures de sport enfant", "id": 419},
|
163 |
+
{"name": "Chaussures premiers pas", "id": 420},
|
164 |
+
{"name": "Chaussures à scratch enfant", "id": 421},
|
165 |
+
{"name": "Chaussures d'école enfant", "id": 422},
|
166 |
+
{"name": "Pantoufles enfant", "id": 423},
|
167 |
+
{"name": "Chaussures de cérémonie enfant", "id": 424},
|
168 |
+
{"name": "Bottes de pluie enfant", "id": 425}
|
169 |
+
];
|
170 |
+
|
171 |
+
|
172 |
+
# Extract category names
|
173 |
+
items = [category["name"] for category in categories]
|
174 |
+
|
175 |
+
# Initialize the model and tokenizer
|
176 |
+
model_name = 'hf-hub:Marqo/marqo-fashionSigLIP'
|
177 |
+
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(model_name)
|
178 |
+
tokenizer = open_clip.get_tokenizer(model_name)
|
179 |
+
|
180 |
+
# Generate descriptions
|
181 |
+
def generate_description(item):
|
182 |
+
return f"A fashion item called {item}"
|
183 |
+
|
184 |
+
items_desc = [generate_description(item) for item in items]
|
185 |
+
text = tokenizer(items_desc)
|
186 |
+
|
187 |
+
# Encode text features
|
188 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
189 |
+
model.to(device)
|
190 |
+
|
191 |
+
torch.cuda.empty_cache() # Avant de charger le modèle
|
192 |
+
|
193 |
+
with torch.no_grad(), torch.amp.autocast(device_type=device):
|
194 |
+
text_features = model.encode_text(text.to(device))
|
195 |
+
text_features /= text_features.norm(dim=-1, keepdim=True)
|
196 |
+
|
197 |
+
# Prediction function
|
198 |
+
def predict(image, url):
|
199 |
+
if url:
|
200 |
+
response = requests.get(url)
|
201 |
+
image = Image.open(BytesIO(response.content))
|
202 |
+
|
203 |
+
processed_image = preprocess_val(image).unsqueeze(0).to(device)
|
204 |
+
|
205 |
+
with torch.no_grad(), torch.amp.autocast(device_type=device):
|
206 |
+
image_features = model.encode_image(processed_image)
|
207 |
+
image_features /= image_features.norm(dim=-1, keepdim=True)
|
208 |
+
|
209 |
+
text_probs = (100 * image_features @ text_features.T).softmax(dim=-1)
|
210 |
+
|
211 |
+
sorted_confidences = sorted(
|
212 |
+
{items[i]: float(text_probs[0, i]) for i in range(len(items))}.items(),
|
213 |
+
key=lambda x: x[1],
|
214 |
+
reverse=True
|
215 |
+
)
|
216 |
+
|
217 |
+
# Include category IDs in the response
|
218 |
+
top_10_categories = [
|
219 |
+
{
|
220 |
+
"category_name": category["name"],
|
221 |
+
"id": category["id"],
|
222 |
+
"confidence": confidence
|
223 |
+
}
|
224 |
+
for category_name, confidence in sorted_confidences[:10]
|
225 |
+
for category in categories if category["name"] == category_name
|
226 |
+
]
|
227 |
+
|
228 |
+
return image, top_10_categories
|
229 |
+
|
230 |
+
# Ajout de la fonction de prédiction par lots
|
231 |
+
def predict_batch(images, urls):
|
232 |
+
# Combiner les images provenant des URLs et des téléchargements directs
|
233 |
+
combined_images = []
|
234 |
+
for image, url in zip(images, urls):
|
235 |
+
if url:
|
236 |
+
response = requests.get(url)
|
237 |
+
image = Image.open(BytesIO(response.content))
|
238 |
+
combined_images.append(preprocess_val(image).unsqueeze(0).to(device))
|
239 |
+
|
240 |
+
# Empiler toutes les images traitées en un seul lot
|
241 |
+
batch_images = torch.cat(combined_images, dim=0)
|
242 |
+
|
243 |
+
with torch.no_grad(), torch.amp.autocast(device_type=device):
|
244 |
+
image_features = model.encode_image(batch_images)
|
245 |
+
image_features /= image_features.norm(dim=-1, keepdim=True)
|
246 |
+
|
247 |
+
text_probs = (100 * image_features @ text_features.T).softmax(dim=-1)
|
248 |
+
|
249 |
+
# Traiter chaque image dans le lot
|
250 |
+
batch_results = []
|
251 |
+
for i in range(len(images)):
|
252 |
+
sorted_confidences = sorted(
|
253 |
+
{items[j]: float(text_probs[i, j]) for j in range(len(items))}.items(),
|
254 |
+
key=lambda x: x[1],
|
255 |
+
reverse=True
|
256 |
+
)
|
257 |
+
|
258 |
+
# Inclure les IDs de catégorie dans la réponse
|
259 |
+
top_10_categories = [
|
260 |
+
{
|
261 |
+
"category_name": category["name"],
|
262 |
+
"id": category["id"],
|
263 |
+
"confidence": confidence
|
264 |
+
}
|
265 |
+
for category_name, confidence in sorted_confidences[:10]
|
266 |
+
for category in categories if category["name"] == category_name
|
267 |
+
]
|
268 |
+
batch_results.append(top_10_categories)
|
269 |
+
|
270 |
+
return batch_results
|
271 |
+
|
272 |
+
# Clear function
|
273 |
+
def clear_fields():
|
274 |
+
# return None, "", None, ""
|
275 |
+
return None, ""
|
276 |
+
# Gradio interface
|
277 |
+
title = "Fashion Item Classifier with Marqo-FashionSigLIP"
|
278 |
+
description = "Upload an image or provide a URL of a fashion item to classify it using [Marqo-FashionSigLIP](https://huggingface.co/Marqo/marqo-fashionSigLIP)!"
|
279 |
+
|
280 |
+
examples = [
|
281 |
+
["images/dress.jpg", "Dress"],
|
282 |
+
["images/sweatpants.jpg", "Sweatpants"],
|
283 |
+
["images/t-shirt.jpg", "T-Shirt"],
|
284 |
+
]
|
285 |
+
|
286 |
+
with gr.Blocks() as demo:
|
287 |
+
with gr.Row():
|
288 |
+
with gr.Column(scale=1):
|
289 |
+
gr.Markdown(f"# {title}")
|
290 |
+
gr.Markdown(description)
|
291 |
+
gr.Markdown(sidebar_markdown)
|
292 |
+
with gr.Column(scale=2):
|
293 |
+
input_image = gr.Image(type="pil", label="Upload Fashion Item Image", height=312)
|
294 |
+
input_url = gr.Textbox(label="Or provide an image URL")
|
295 |
+
# input_images = gr.Image(type="pil", label="Upload Fashion Item Images", height=312)
|
296 |
+
# input_urls = gr.Textbox(label="Or provide image URLs (comma-separated)", lines=2)
|
297 |
+
with gr.Row():
|
298 |
+
predict_button = gr.Button("Classify")
|
299 |
+
# predict_batch_button = gr.Button("Classify Batch")
|
300 |
+
clear_button = gr.Button("Clear")
|
301 |
+
gr.Markdown("Or click on one of the images below to classify it:")
|
302 |
+
gr.Examples(examples=examples, inputs=input_image)
|
303 |
+
output_label = gr.JSON(label="Top Categories")
|
304 |
+
# output_batch_label = gr.JSON(label="Top Categories for Each Image")
|
305 |
+
predict_button.click(predict, inputs=[input_image, input_url], outputs=[input_image, output_label])
|
306 |
+
# predict_batch_button.click(predict_batch, inputs=[input_images, input_urls], outputs=output_batch_label)
|
307 |
+
# clear_button.click(clear_fields, outputs=[input_image, input_url, input_images, input_urls])
|
308 |
+
|
309 |
+
# Launch the interface
|
310 |
+
demo.launch()
|