Spaces:
Sleeping
Sleeping
Upload 4 files
Browse files- app (2).py +68 -0
- index_to_attr (3).py +230 -0
- model_loader (1).py +33 -0
- requirements (1).txt +4 -0
app (2).py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import gradio as gr
|
3 |
+
import torch
|
4 |
+
from PIL import Image
|
5 |
+
from torchvision import transforms
|
6 |
+
from model_loader import load_model
|
7 |
+
from index_to_attr import index_to_attr
|
8 |
+
|
9 |
+
# Modell laden
|
10 |
+
model = load_model("model/AttrPredModel_StateDict.pth")
|
11 |
+
|
12 |
+
# taskName pro Index extrahieren
|
13 |
+
def get_task_map(index_to_attr):
|
14 |
+
task_map = {}
|
15 |
+
for idx, desc in index_to_attr.items():
|
16 |
+
if "(" in desc and ")" in desc:
|
17 |
+
task = desc.split("(")[-1].split(")")[0]
|
18 |
+
task_map[idx] = task
|
19 |
+
return task_map
|
20 |
+
|
21 |
+
task_map = get_task_map(index_to_attr)
|
22 |
+
|
23 |
+
# Bildverarbeitungspipeline
|
24 |
+
preprocess = transforms.Compose([
|
25 |
+
transforms.Resize((512, 512)),
|
26 |
+
transforms.ToTensor(),
|
27 |
+
transforms.Normalize(mean=[0.6765, 0.6347, 0.6207],
|
28 |
+
std=[0.3284, 0.3371, 0.3379])
|
29 |
+
])
|
30 |
+
|
31 |
+
# Inferenz-Funktion mit Markierung für unsichere Kategorien
|
32 |
+
def predict(image):
|
33 |
+
image = image.convert("RGB")
|
34 |
+
input_tensor = preprocess(image).unsqueeze(0)
|
35 |
+
with torch.no_grad():
|
36 |
+
output = model(input_tensor)
|
37 |
+
probs = torch.sigmoid(output).squeeze().numpy()
|
38 |
+
|
39 |
+
result = {}
|
40 |
+
threshold = 0.5
|
41 |
+
top_per_task = {}
|
42 |
+
|
43 |
+
for idx, score in enumerate(probs):
|
44 |
+
task = task_map.get(idx, "unknown")
|
45 |
+
if task not in top_per_task or score > top_per_task[task][1]:
|
46 |
+
top_per_task[task] = (idx, score)
|
47 |
+
|
48 |
+
for task, (idx, score) in top_per_task.items():
|
49 |
+
label = index_to_attr.get(idx, f"Unknown ({idx})").split(" (")[0]
|
50 |
+
result[task] = {
|
51 |
+
"label": label,
|
52 |
+
"score": round(float(score), 4),
|
53 |
+
"confidence": "low" if score < threshold else "high"
|
54 |
+
}
|
55 |
+
|
56 |
+
return result
|
57 |
+
|
58 |
+
# Gradio Interface – stabil und einfach
|
59 |
+
demo = gr.Interface(
|
60 |
+
fn=predict,
|
61 |
+
inputs=gr.Image(type="pil", label="Upload image"),
|
62 |
+
outputs="json",
|
63 |
+
title="Fashion Attribute Predictor (mit Confidence)",
|
64 |
+
description="Zeigt pro Attributgruppe die wahrscheinlichste Vorhersage + Confidence ('high' / 'low')."
|
65 |
+
)
|
66 |
+
|
67 |
+
if __name__ == "__main__":
|
68 |
+
demo.launch()
|
index_to_attr (3).py
ADDED
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
index_to_attr = {
|
2 |
+
0: 'Argyle (pattern)',
|
3 |
+
1: 'Asymmetric (style)',
|
4 |
+
2: 'Athletic Pants (category)',
|
5 |
+
3: 'Athletic Sets (category)',
|
6 |
+
4: 'Athletic Shirts (category)',
|
7 |
+
5: 'Athletic Shorts (category)',
|
8 |
+
6: 'Backless Dresses (neckline)',
|
9 |
+
7: 'Baggy Jeans (category)',
|
10 |
+
8: 'Bandage (style)',
|
11 |
+
9: 'Bandeaus (style)',
|
12 |
+
10: 'Batwing Tops (category)',
|
13 |
+
11: 'Beach & Swim Wear (category)',
|
14 |
+
12: 'Beaded (style)',
|
15 |
+
13: 'Beige (color)',
|
16 |
+
14: 'Bikinis (category)',
|
17 |
+
15: 'Binders (category)',
|
18 |
+
16: 'Black (color)',
|
19 |
+
17: 'Blouses (category)',
|
20 |
+
18: 'Blue (color)',
|
21 |
+
19: 'Bodycon (style)',
|
22 |
+
20: 'Bodysuits (category)',
|
23 |
+
21: 'Boots (category)',
|
24 |
+
22: 'Bra Straps (category)',
|
25 |
+
23: 'Bronze (color)',
|
26 |
+
24: 'Brown (color)',
|
27 |
+
25: 'Bubble Coats (category)',
|
28 |
+
26: 'Business Shoes (category)',
|
29 |
+
27: 'Camouflage (pattern)',
|
30 |
+
28: 'Canvas (material)',
|
31 |
+
29: 'Capes & Capelets (category)',
|
32 |
+
30: 'Capri Pants (category)',
|
33 |
+
31: 'Cardigans (category)',
|
34 |
+
32: 'Cargo Pants (category)',
|
35 |
+
33: 'Cargo Shorts (category)',
|
36 |
+
34: 'Cashmere (material)',
|
37 |
+
35: 'Casual Dresses (category)',
|
38 |
+
36: 'Casual Pants (category)',
|
39 |
+
37: 'Casual Shirts (category)',
|
40 |
+
38: 'Casual Shoes (category)',
|
41 |
+
39: 'Casual Shorts (category)',
|
42 |
+
40: 'Chambray (material)',
|
43 |
+
41: 'Checkered (pattern)',
|
44 |
+
42: 'Chevron (pattern)',
|
45 |
+
43: 'Chiffon (material)',
|
46 |
+
44: 'Clear (color)',
|
47 |
+
45: 'Cleats (category)',
|
48 |
+
46: 'Clubbing Dresses (category)',
|
49 |
+
47: 'Cocktail Dresses (category)',
|
50 |
+
48: 'Collared (neckline)',
|
51 |
+
49: 'Corduroy (material)',
|
52 |
+
50: 'Corsets (category)',
|
53 |
+
51: 'Costumes & Cosplay (category)',
|
54 |
+
52: 'Cotton (material)',
|
55 |
+
53: 'Criss Cross (style)',
|
56 |
+
54: 'Crochet (pattern)',
|
57 |
+
55: 'Crop Tops (category)',
|
58 |
+
56: 'Custom Made Clothing (category)',
|
59 |
+
57: 'Dance Wear (category)',
|
60 |
+
58: 'Denim (material)',
|
61 |
+
59: 'Drawstring Pants (category)',
|
62 |
+
60: 'Dress Shirts (category)',
|
63 |
+
61: 'Dresses (category)',
|
64 |
+
62: 'Embroidered (style)',
|
65 |
+
63: 'Fashion Sets (category)',
|
66 |
+
64: 'Faux Fur (material)',
|
67 |
+
65: 'Female (gender)',
|
68 |
+
66: 'Flannel (material)',
|
69 |
+
67: 'Flats (category)',
|
70 |
+
68: 'Fleece (material)',
|
71 |
+
69: 'Floral (pattern)',
|
72 |
+
70: 'Formal Dresses (category)',
|
73 |
+
71: 'Fringe (pattern)',
|
74 |
+
72: 'Furry (style)',
|
75 |
+
73: 'Galaxy (pattern)',
|
76 |
+
74: 'Geometric (pattern)',
|
77 |
+
75: 'Gingham (material)',
|
78 |
+
76: 'Gold (color)',
|
79 |
+
77: 'Gray (color)',
|
80 |
+
78: 'Green (color)',
|
81 |
+
79: 'Halter Tops (category)',
|
82 |
+
80: 'Harem Pants (category)',
|
83 |
+
81: 'Hearts (pattern)',
|
84 |
+
82: 'Heels (category)',
|
85 |
+
83: 'Herringbone (pattern)',
|
86 |
+
84: 'Hi-Lo (style)',
|
87 |
+
85: 'Hiking Boots (category)',
|
88 |
+
86: 'Hollow-Out (style)',
|
89 |
+
87: 'Hoodies & Sweatshirts (category)',
|
90 |
+
88: 'Hosiery, Stockings, Tights (category)',
|
91 |
+
89: 'Houndstooth (pattern)',
|
92 |
+
90: 'Jackets (category)',
|
93 |
+
91: 'Jeans (category)',
|
94 |
+
92: 'Jerseys (category)',
|
95 |
+
93: 'Jilbaab (category)',
|
96 |
+
94: 'Jumpsuits Overalls & Rompers (category)',
|
97 |
+
95: 'Kimonos (category)',
|
98 |
+
96: 'Knit (material)',
|
99 |
+
97: 'Lace (material)',
|
100 |
+
98: 'Leather (material)',
|
101 |
+
99: 'Leggings (category)',
|
102 |
+
100: 'Leopard And Cheetah (pattern)',
|
103 |
+
101: 'Linen (material)',
|
104 |
+
102: 'Lingerie Sleepwear & Underwear (category)',
|
105 |
+
103: 'Loafers & Slip-on Shoes (category)',
|
106 |
+
104: 'Long Sleeved (sleeve)',
|
107 |
+
105: 'Male (gender)',
|
108 |
+
106: 'Marbled (pattern)',
|
109 |
+
107: 'Maroon (color)',
|
110 |
+
108: 'Maternity (category)',
|
111 |
+
109: 'Mesh (pattern)',
|
112 |
+
110: 'Multi Color (color)',
|
113 |
+
111: 'Neoprene (material)',
|
114 |
+
112: 'Neutral (gender)',
|
115 |
+
113: 'Nightgowns (category)',
|
116 |
+
114: 'Nylon (material)',
|
117 |
+
115: 'Off The Shoulder (neckline)',
|
118 |
+
116: 'Orange (color)',
|
119 |
+
117: 'Organza (material)',
|
120 |
+
118: 'Padded Bras (category)',
|
121 |
+
119: 'Paisley (pattern)',
|
122 |
+
120: 'Pajamas (category)',
|
123 |
+
121: 'Party Dresses (category)',
|
124 |
+
122: 'Pasties (category)',
|
125 |
+
123: 'Patent (material)',
|
126 |
+
124: 'Peach (color)',
|
127 |
+
125: 'Peacoats (category)',
|
128 |
+
126: 'Pencil Skirts (category)',
|
129 |
+
127: 'Peplum (style)',
|
130 |
+
128: 'Petticoats (category)',
|
131 |
+
129: 'Pin Stripes (pattern)',
|
132 |
+
130: 'Pink (color)',
|
133 |
+
131: 'Plaid (pattern)',
|
134 |
+
132: 'Pleated (style)',
|
135 |
+
133: 'Plush (material)',
|
136 |
+
134: 'Polka Dot (pattern)',
|
137 |
+
135: 'Polos (category)',
|
138 |
+
136: 'Polyester (material)',
|
139 |
+
137: 'Printed (style)',
|
140 |
+
138: 'Prom Dresses (category)',
|
141 |
+
139: 'Puff Sleeves (sleeve)',
|
142 |
+
140: 'Pullover Sweaters (category)',
|
143 |
+
141: 'Purple (color)',
|
144 |
+
142: 'Quilted (pattern)',
|
145 |
+
143: 'Racerback (neckline)',
|
146 |
+
144: 'Rain Boots (category)',
|
147 |
+
145: 'Raincoats (category)',
|
148 |
+
146: 'Rayon (material)',
|
149 |
+
147: 'Red (color)',
|
150 |
+
148: 'Reversible (style)',
|
151 |
+
149: 'Rhinestone Studded (style)',
|
152 |
+
150: 'Ripped (pattern)',
|
153 |
+
151: 'Robes (category)',
|
154 |
+
152: 'Round Neck (neckline)',
|
155 |
+
153: 'Ruched (pattern)',
|
156 |
+
154: 'Ruffles (pattern)',
|
157 |
+
155: 'Running Shoes (category)',
|
158 |
+
156: 'Sandals (category)',
|
159 |
+
157: 'Satin (material)',
|
160 |
+
158: 'Sequins (pattern)',
|
161 |
+
159: 'Sheer Tops (category)',
|
162 |
+
160: 'Shoe Accessories (category)',
|
163 |
+
161: 'Shoe Inserts (category)',
|
164 |
+
162: 'Shoelaces (category)',
|
165 |
+
163: 'Short Sleeves (sleeve)',
|
166 |
+
164: 'Shorts (category)',
|
167 |
+
165: 'Shoulder Drapes (neckline)',
|
168 |
+
166: 'Silk (material)',
|
169 |
+
167: 'Silver (color)',
|
170 |
+
168: 'Skinny Jeans (category)',
|
171 |
+
169: 'Skirts (category)',
|
172 |
+
170: 'Sleeveless (sleeve)',
|
173 |
+
171: 'Slippers (category)',
|
174 |
+
172: 'Snakeskin (pattern)',
|
175 |
+
173: 'Sneakers (category)',
|
176 |
+
174: 'Spaghetti Straps (style)',
|
177 |
+
175: 'Spandex (material)',
|
178 |
+
176: 'Sports Bras (category)',
|
179 |
+
177: 'Square Necked (neckline)',
|
180 |
+
178: 'Stilettos (category)',
|
181 |
+
179: 'Strapless (sleeve)',
|
182 |
+
180: 'Stripes (pattern)',
|
183 |
+
181: 'Suede (material)',
|
184 |
+
182: 'Suits & Blazers (category)',
|
185 |
+
183: 'Summer (style)',
|
186 |
+
184: 'Sweatpants (category)',
|
187 |
+
185: 'Sweetheart Neckline (neckline)',
|
188 |
+
186: 'Swim Trunks (category)',
|
189 |
+
187: 'Swimsuit Cover-ups (category)',
|
190 |
+
188: 'Swimsuits (category)',
|
191 |
+
189: 'T-Shirts (category)',
|
192 |
+
190: 'Taffeta (material)',
|
193 |
+
191: 'Tan (color)',
|
194 |
+
192: 'Tank Tops (category)',
|
195 |
+
193: 'Teal (color)',
|
196 |
+
194: 'Thermal Underwear (category)',
|
197 |
+
195: 'Thigh Highs (category)',
|
198 |
+
196: 'Thongs (category)',
|
199 |
+
197: 'Three Piece Suits (category)',
|
200 |
+
198: 'Tie Dye (pattern)',
|
201 |
+
199: 'Trench Coats (category)',
|
202 |
+
200: 'Trousers (category)',
|
203 |
+
201: 'Tube Tops (category)',
|
204 |
+
202: 'Tulle (material)',
|
205 |
+
203: 'Tunic (style)',
|
206 |
+
204: 'Turtlenecks (neckline)',
|
207 |
+
205: 'Tutus (category)',
|
208 |
+
206: 'Tweed (material)',
|
209 |
+
207: 'Twill (material)',
|
210 |
+
208: 'Two-Tone (style)',
|
211 |
+
209: 'U-Necks (neckline)',
|
212 |
+
210: 'Undershirts (category)',
|
213 |
+
211: 'Underwear (category)',
|
214 |
+
212: 'Uniforms (category)',
|
215 |
+
213: 'V-Necks (neckline)',
|
216 |
+
214: 'Velour (material)',
|
217 |
+
215: 'Velvet (material)',
|
218 |
+
216: 'Vests (category)',
|
219 |
+
217: 'Vintage Retro (style)',
|
220 |
+
218: 'Vinyl (material)',
|
221 |
+
219: 'Wedding Dresses (category)',
|
222 |
+
220: 'Wedges & Platforms (category)',
|
223 |
+
221: 'White (color)',
|
224 |
+
222: 'Winter Boots (category)',
|
225 |
+
223: 'Wool (material)',
|
226 |
+
224: 'Wrap (style)',
|
227 |
+
225: 'Yellow (color)',
|
228 |
+
226: 'Yoga Pants (category)',
|
229 |
+
227: 'Zebra (pattern)',
|
230 |
+
}
|
model_loader (1).py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torchvision.models as models
|
5 |
+
|
6 |
+
def load_model(model_path):
|
7 |
+
# Architektur aufbauen
|
8 |
+
model = models.resnet50(pretrained=False)
|
9 |
+
model.fc = nn.Linear(2048, 228)
|
10 |
+
|
11 |
+
# State Dict laden
|
12 |
+
state_dict = torch.load(model_path, map_location=torch.device("cpu"))
|
13 |
+
|
14 |
+
# Keys ggf. anpassen
|
15 |
+
new_state_dict = {}
|
16 |
+
for k, v in state_dict.items():
|
17 |
+
if k.startswith("predictor."):
|
18 |
+
new_k = k.replace("predictor.", "")
|
19 |
+
else:
|
20 |
+
new_k = k
|
21 |
+
new_state_dict[new_k] = v
|
22 |
+
|
23 |
+
model.load_state_dict(new_state_dict)
|
24 |
+
model.eval()
|
25 |
+
return model
|
26 |
+
|
27 |
+
def predict_attributes(model, input_tensor):
|
28 |
+
with torch.no_grad():
|
29 |
+
output = model(input_tensor)
|
30 |
+
prediction = torch.sigmoid(output).squeeze().numpy()
|
31 |
+
threshold = 0.5
|
32 |
+
predicted_indices = [i for i, p in enumerate(prediction) if p > threshold]
|
33 |
+
return predicted_indices
|
requirements (1).txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
torchvision
|
3 |
+
gradio
|
4 |
+
Pillow
|