|
|
|
import gradio as gr |
|
import torch |
|
from PIL import Image |
|
from torchvision import transforms |
|
from model_loader import load_model |
|
from index_to_attr import index_to_attr |
|
|
|
|
|
# Load the trained attribute-prediction model once at import time; every
# request handled by ``predict`` reuses this single instance.
model = load_model("model/AttrPredModel_StateDict.pth")
# Put the model in inference mode (dropout off, batch-norm uses running stats).
# NOTE(review): load_model may already do this -- eval() is idempotent, so this
# only guards against a loader that leaves the module in training mode.
# Assumes load_model returns a torch.nn.Module; confirm in model_loader.
model.eval()
|
|
|
|
|
def get_task_map(index_to_attr):
    """Map attribute indices to the task (attribute-group) named in their description.

    A description such as ``"red (color)"`` yields the task ``"color"``: the
    text after the last ``"("`` up to the next ``")"``. Entries whose
    description lacks either parenthesis are omitted from the result.

    Args:
        index_to_attr: Mapping of attribute index -> description string.

    Returns:
        Dict of index -> task name, in the iteration order of *index_to_attr*.
    """
    return {
        attr_idx: description.rpartition("(")[2].partition(")")[0]
        for attr_idx, description in index_to_attr.items()
        if "(" in description and ")" in description
    }
|
|
|
# Attribute index -> task (attribute-group) name, built once at import time.
# Indices whose description has no parenthesised group are absent here.
task_map = get_task_map(index_to_attr)


# Per-channel normalisation statistics applied to every input image.
# NOTE(review): presumably the statistics used during training -- confirm
# against the training pipeline before changing them.
_NORM_MEAN = [0.6765, 0.6347, 0.6207]
_NORM_STD = [0.3284, 0.3371, 0.3379]

# Image -> model input: fixed 512x512 resize (aspect ratio is not preserved),
# conversion to a tensor, then per-channel normalisation.
preprocess = transforms.Compose(
    [
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)
|
|
|
|
|
def predict(image, *, threshold: float = 0.5) -> dict:
    """Predict the most likely attribute for every attribute group in *image*.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the form is submitted without an uploaded image.
        threshold: Probability below which a prediction is reported with
            ``"confidence": "low"``; at or above it, ``"high"``. Defaults to 0.5.

    Returns:
        Mapping of task (attribute-group) name -> dict with ``"label"`` (the
        winning attribute's description text before its first ``" ("``),
        ``"score"`` (its sigmoid probability rounded to 4 decimals) and
        ``"confidence"`` (``"high"``/``"low"``). Indices not present in
        ``task_map`` are pooled under the task ``"unknown"``. Returns an empty
        dict when *image* is ``None``.
    """
    # Gradio passes None when nothing was uploaded; previously this raised
    # AttributeError on ``None.convert``.
    if image is None:
        return {}

    # Add a leading batch dimension of 1 for the model.
    input_tensor = preprocess(image.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        output = model(input_tensor)

    # Remove only the batch dimension: a bare ``squeeze()`` turns a
    # single-output result into a 0-d array that cannot be enumerated.
    # ``.cpu()`` is a no-op on CPU tensors and is required before ``.numpy()``
    # if the model returns tensors on an accelerator.
    probs = torch.sigmoid(output).squeeze(0).cpu().numpy()

    # Keep the highest-probability index per task. Only a strictly greater
    # score replaces the current best, so on ties the lower index wins.
    top_per_task = {}
    for idx, score in enumerate(probs):
        task = task_map.get(idx, "unknown")
        best = top_per_task.get(task)
        if best is None or score > best[1]:
            top_per_task[task] = (idx, score)

    result = {}
    for task, (idx, score) in top_per_task.items():
        label = index_to_attr.get(idx, f"Unknown ({idx})").split(" (")[0]
        result[task] = {
            "label": label,
            "score": round(float(score), 4),
            "confidence": "low" if score < threshold else "high",
        }
    return result
|
|
|
|
|
# User-facing texts for the Gradio page (kept verbatim).
_TITLE = "Fashion Attribute Predictor (mit Confidence)"
_DESCRIPTION = "Zeigt pro Attributgruppe die wahrscheinlichste Vorhersage + Confidence ('high' / 'low')."

# Single-image web UI: the uploaded PIL image goes to ``predict`` and the
# returned dict is rendered as JSON.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload image"),
    outputs="json",
    title=_TITLE,
    description=_DESCRIPTION,
)


if __name__ == "__main__":
    # Start the server only when run as a script, not on import.
    demo.launch()
|
|