FrezzyI commited on
Commit
9cbb8bc
·
verified ·
1 Parent(s): d5a3442

Upload 4 files

Browse files
Files changed (4) hide show
  1. app (2).py +68 -0
  2. index_to_attr (3).py +230 -0
  3. model_loader (1).py +33 -0
  4. requirements (1).txt +4 -0
app (2).py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gradio as gr
3
+ import torch
4
+ from PIL import Image
5
+ from torchvision import transforms
6
+ from model_loader import load_model
7
+ from index_to_attr import index_to_attr
8
+
9
+ # Modell laden
10
+ model = load_model("model/AttrPredModel_StateDict.pth")
11
+
12
+ # taskName pro Index extrahieren
13
+ def get_task_map(index_to_attr):
14
+ task_map = {}
15
+ for idx, desc in index_to_attr.items():
16
+ if "(" in desc and ")" in desc:
17
+ task = desc.split("(")[-1].split(")")[0]
18
+ task_map[idx] = task
19
+ return task_map
20
+
21
+ task_map = get_task_map(index_to_attr)
22
+
23
+ # Bildverarbeitungspipeline
24
+ preprocess = transforms.Compose([
25
+ transforms.Resize((512, 512)),
26
+ transforms.ToTensor(),
27
+ transforms.Normalize(mean=[0.6765, 0.6347, 0.6207],
28
+ std=[0.3284, 0.3371, 0.3379])
29
+ ])
30
+
31
+ # Inferenz-Funktion mit Markierung für unsichere Kategorien
32
+ def predict(image):
33
+ image = image.convert("RGB")
34
+ input_tensor = preprocess(image).unsqueeze(0)
35
+ with torch.no_grad():
36
+ output = model(input_tensor)
37
+ probs = torch.sigmoid(output).squeeze().numpy()
38
+
39
+ result = {}
40
+ threshold = 0.5
41
+ top_per_task = {}
42
+
43
+ for idx, score in enumerate(probs):
44
+ task = task_map.get(idx, "unknown")
45
+ if task not in top_per_task or score > top_per_task[task][1]:
46
+ top_per_task[task] = (idx, score)
47
+
48
+ for task, (idx, score) in top_per_task.items():
49
+ label = index_to_attr.get(idx, f"Unknown ({idx})").split(" (")[0]
50
+ result[task] = {
51
+ "label": label,
52
+ "score": round(float(score), 4),
53
+ "confidence": "low" if score < threshold else "high"
54
+ }
55
+
56
+ return result
57
+
58
+ # Gradio Interface – stabil und einfach
59
+ demo = gr.Interface(
60
+ fn=predict,
61
+ inputs=gr.Image(type="pil", label="Upload image"),
62
+ outputs="json",
63
+ title="Fashion Attribute Predictor (mit Confidence)",
64
+ description="Zeigt pro Attributgruppe die wahrscheinlichste Vorhersage + Confidence ('high' / 'low')."
65
+ )
66
+
67
+ if __name__ == "__main__":
68
+ demo.launch()
index_to_attr (3).py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ index_to_attr = {
2
+ 0: 'Argyle (pattern)',
3
+ 1: 'Asymmetric (style)',
4
+ 2: 'Athletic Pants (category)',
5
+ 3: 'Athletic Sets (category)',
6
+ 4: 'Athletic Shirts (category)',
7
+ 5: 'Athletic Shorts (category)',
8
+ 6: 'Backless Dresses (neckline)',
9
+ 7: 'Baggy Jeans (category)',
10
+ 8: 'Bandage (style)',
11
+ 9: 'Bandeaus (style)',
12
+ 10: 'Batwing Tops (category)',
13
+ 11: 'Beach & Swim Wear (category)',
14
+ 12: 'Beaded (style)',
15
+ 13: 'Beige (color)',
16
+ 14: 'Bikinis (category)',
17
+ 15: 'Binders (category)',
18
+ 16: 'Black (color)',
19
+ 17: 'Blouses (category)',
20
+ 18: 'Blue (color)',
21
+ 19: 'Bodycon (style)',
22
+ 20: 'Bodysuits (category)',
23
+ 21: 'Boots (category)',
24
+ 22: 'Bra Straps (category)',
25
+ 23: 'Bronze (color)',
26
+ 24: 'Brown (color)',
27
+ 25: 'Bubble Coats (category)',
28
+ 26: 'Business Shoes (category)',
29
+ 27: 'Camouflage (pattern)',
30
+ 28: 'Canvas (material)',
31
+ 29: 'Capes & Capelets (category)',
32
+ 30: 'Capri Pants (category)',
33
+ 31: 'Cardigans (category)',
34
+ 32: 'Cargo Pants (category)',
35
+ 33: 'Cargo Shorts (category)',
36
+ 34: 'Cashmere (material)',
37
+ 35: 'Casual Dresses (category)',
38
+ 36: 'Casual Pants (category)',
39
+ 37: 'Casual Shirts (category)',
40
+ 38: 'Casual Shoes (category)',
41
+ 39: 'Casual Shorts (category)',
42
+ 40: 'Chambray (material)',
43
+ 41: 'Checkered (pattern)',
44
+ 42: 'Chevron (pattern)',
45
+ 43: 'Chiffon (material)',
46
+ 44: 'Clear (color)',
47
+ 45: 'Cleats (category)',
48
+ 46: 'Clubbing Dresses (category)',
49
+ 47: 'Cocktail Dresses (category)',
50
+ 48: 'Collared (neckline)',
51
+ 49: 'Corduroy (material)',
52
+ 50: 'Corsets (category)',
53
+ 51: 'Costumes & Cosplay (category)',
54
+ 52: 'Cotton (material)',
55
+ 53: 'Criss Cross (style)',
56
+ 54: 'Crochet (pattern)',
57
+ 55: 'Crop Tops (category)',
58
+ 56: 'Custom Made Clothing (category)',
59
+ 57: 'Dance Wear (category)',
60
+ 58: 'Denim (material)',
61
+ 59: 'Drawstring Pants (category)',
62
+ 60: 'Dress Shirts (category)',
63
+ 61: 'Dresses (category)',
64
+ 62: 'Embroidered (style)',
65
+ 63: 'Fashion Sets (category)',
66
+ 64: 'Faux Fur (material)',
67
+ 65: 'Female (gender)',
68
+ 66: 'Flannel (material)',
69
+ 67: 'Flats (category)',
70
+ 68: 'Fleece (material)',
71
+ 69: 'Floral (pattern)',
72
+ 70: 'Formal Dresses (category)',
73
+ 71: 'Fringe (pattern)',
74
+ 72: 'Furry (style)',
75
+ 73: 'Galaxy (pattern)',
76
+ 74: 'Geometric (pattern)',
77
+ 75: 'Gingham (material)',
78
+ 76: 'Gold (color)',
79
+ 77: 'Gray (color)',
80
+ 78: 'Green (color)',
81
+ 79: 'Halter Tops (category)',
82
+ 80: 'Harem Pants (category)',
83
+ 81: 'Hearts (pattern)',
84
+ 82: 'Heels (category)',
85
+ 83: 'Herringbone (pattern)',
86
+ 84: 'Hi-Lo (style)',
87
+ 85: 'Hiking Boots (category)',
88
+ 86: 'Hollow-Out (style)',
89
+ 87: 'Hoodies & Sweatshirts (category)',
90
+ 88: 'Hosiery, Stockings, Tights (category)',
91
+ 89: 'Houndstooth (pattern)',
92
+ 90: 'Jackets (category)',
93
+ 91: 'Jeans (category)',
94
+ 92: 'Jerseys (category)',
95
+ 93: 'Jilbaab (category)',
96
+ 94: 'Jumpsuits Overalls & Rompers (category)',
97
+ 95: 'Kimonos (category)',
98
+ 96: 'Knit (material)',
99
+ 97: 'Lace (material)',
100
+ 98: 'Leather (material)',
101
+ 99: 'Leggings (category)',
102
+ 100: 'Leopard And Cheetah (pattern)',
103
+ 101: 'Linen (material)',
104
+ 102: 'Lingerie Sleepwear & Underwear (category)',
105
+ 103: 'Loafers & Slip-on Shoes (category)',
106
+ 104: 'Long Sleeved (sleeve)',
107
+ 105: 'Male (gender)',
108
+ 106: 'Marbled (pattern)',
109
+ 107: 'Maroon (color)',
110
+ 108: 'Maternity (category)',
111
+ 109: 'Mesh (pattern)',
112
+ 110: 'Multi Color (color)',
113
+ 111: 'Neoprene (material)',
114
+ 112: 'Neutral (gender)',
115
+ 113: 'Nightgowns (category)',
116
+ 114: 'Nylon (material)',
117
+ 115: 'Off The Shoulder (neckline)',
118
+ 116: 'Orange (color)',
119
+ 117: 'Organza (material)',
120
+ 118: 'Padded Bras (category)',
121
+ 119: 'Paisley (pattern)',
122
+ 120: 'Pajamas (category)',
123
+ 121: 'Party Dresses (category)',
124
+ 122: 'Pasties (category)',
125
+ 123: 'Patent (material)',
126
+ 124: 'Peach (color)',
127
+ 125: 'Peacoats (category)',
128
+ 126: 'Pencil Skirts (category)',
129
+ 127: 'Peplum (style)',
130
+ 128: 'Petticoats (category)',
131
+ 129: 'Pin Stripes (pattern)',
132
+ 130: 'Pink (color)',
133
+ 131: 'Plaid (pattern)',
134
+ 132: 'Pleated (style)',
135
+ 133: 'Plush (material)',
136
+ 134: 'Polka Dot (pattern)',
137
+ 135: 'Polos (category)',
138
+ 136: 'Polyester (material)',
139
+ 137: 'Printed (style)',
140
+ 138: 'Prom Dresses (category)',
141
+ 139: 'Puff Sleeves (sleeve)',
142
+ 140: 'Pullover Sweaters (category)',
143
+ 141: 'Purple (color)',
144
+ 142: 'Quilted (pattern)',
145
+ 143: 'Racerback (neckline)',
146
+ 144: 'Rain Boots (category)',
147
+ 145: 'Raincoats (category)',
148
+ 146: 'Rayon (material)',
149
+ 147: 'Red (color)',
150
+ 148: 'Reversible (style)',
151
+ 149: 'Rhinestone Studded (style)',
152
+ 150: 'Ripped (pattern)',
153
+ 151: 'Robes (category)',
154
+ 152: 'Round Neck (neckline)',
155
+ 153: 'Ruched (pattern)',
156
+ 154: 'Ruffles (pattern)',
157
+ 155: 'Running Shoes (category)',
158
+ 156: 'Sandals (category)',
159
+ 157: 'Satin (material)',
160
+ 158: 'Sequins (pattern)',
161
+ 159: 'Sheer Tops (category)',
162
+ 160: 'Shoe Accessories (category)',
163
+ 161: 'Shoe Inserts (category)',
164
+ 162: 'Shoelaces (category)',
165
+ 163: 'Short Sleeves (sleeve)',
166
+ 164: 'Shorts (category)',
167
+ 165: 'Shoulder Drapes (neckline)',
168
+ 166: 'Silk (material)',
169
+ 167: 'Silver (color)',
170
+ 168: 'Skinny Jeans (category)',
171
+ 169: 'Skirts (category)',
172
+ 170: 'Sleeveless (sleeve)',
173
+ 171: 'Slippers (category)',
174
+ 172: 'Snakeskin (pattern)',
175
+ 173: 'Sneakers (category)',
176
+ 174: 'Spaghetti Straps (style)',
177
+ 175: 'Spandex (material)',
178
+ 176: 'Sports Bras (category)',
179
+ 177: 'Square Necked (neckline)',
180
+ 178: 'Stilettos (category)',
181
+ 179: 'Strapless (sleeve)',
182
+ 180: 'Stripes (pattern)',
183
+ 181: 'Suede (material)',
184
+ 182: 'Suits & Blazers (category)',
185
+ 183: 'Summer (style)',
186
+ 184: 'Sweatpants (category)',
187
+ 185: 'Sweetheart Neckline (neckline)',
188
+ 186: 'Swim Trunks (category)',
189
+ 187: 'Swimsuit Cover-ups (category)',
190
+ 188: 'Swimsuits (category)',
191
+ 189: 'T-Shirts (category)',
192
+ 190: 'Taffeta (material)',
193
+ 191: 'Tan (color)',
194
+ 192: 'Tank Tops (category)',
195
+ 193: 'Teal (color)',
196
+ 194: 'Thermal Underwear (category)',
197
+ 195: 'Thigh Highs (category)',
198
+ 196: 'Thongs (category)',
199
+ 197: 'Three Piece Suits (category)',
200
+ 198: 'Tie Dye (pattern)',
201
+ 199: 'Trench Coats (category)',
202
+ 200: 'Trousers (category)',
203
+ 201: 'Tube Tops (category)',
204
+ 202: 'Tulle (material)',
205
+ 203: 'Tunic (style)',
206
+ 204: 'Turtlenecks (neckline)',
207
+ 205: 'Tutus (category)',
208
+ 206: 'Tweed (material)',
209
+ 207: 'Twill (material)',
210
+ 208: 'Two-Tone (style)',
211
+ 209: 'U-Necks (neckline)',
212
+ 210: 'Undershirts (category)',
213
+ 211: 'Underwear (category)',
214
+ 212: 'Uniforms (category)',
215
+ 213: 'V-Necks (neckline)',
216
+ 214: 'Velour (material)',
217
+ 215: 'Velvet (material)',
218
+ 216: 'Vests (category)',
219
+ 217: 'Vintage Retro (style)',
220
+ 218: 'Vinyl (material)',
221
+ 219: 'Wedding Dresses (category)',
222
+ 220: 'Wedges & Platforms (category)',
223
+ 221: 'White (color)',
224
+ 222: 'Winter Boots (category)',
225
+ 223: 'Wool (material)',
226
+ 224: 'Wrap (style)',
227
+ 225: 'Yellow (color)',
228
+ 226: 'Yoga Pants (category)',
229
+ 227: 'Zebra (pattern)',
230
+ }
model_loader (1).py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ import torchvision.models as models
5
+
6
+ def load_model(model_path):
7
+ # Architektur aufbauen
8
+ model = models.resnet50(pretrained=False)
9
+ model.fc = nn.Linear(2048, 228)
10
+
11
+ # State Dict laden
12
+ state_dict = torch.load(model_path, map_location=torch.device("cpu"))
13
+
14
+ # Keys ggf. anpassen
15
+ new_state_dict = {}
16
+ for k, v in state_dict.items():
17
+ if k.startswith("predictor."):
18
+ new_k = k.replace("predictor.", "")
19
+ else:
20
+ new_k = k
21
+ new_state_dict[new_k] = v
22
+
23
+ model.load_state_dict(new_state_dict)
24
+ model.eval()
25
+ return model
26
+
27
+ def predict_attributes(model, input_tensor):
28
+ with torch.no_grad():
29
+ output = model(input_tensor)
30
+ prediction = torch.sigmoid(output).squeeze().numpy()
31
+ threshold = 0.5
32
+ predicted_indices = [i for i, p in enumerate(prediction) if p > threshold]
33
+ return predicted_indices
requirements (1).txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ gradio
4
+ Pillow