"""Gradio demo: predict the most likely fashion attribute per attribute group."""

import gradio as gr
import torch
from PIL import Image
from torchvision import transforms

from model_loader import load_model
from index_to_attr import index_to_attr

# Scores strictly below this value are reported with confidence "low".
CONFIDENCE_THRESHOLD = 0.5

# Load the model once at import time so every request reuses it.
# NOTE(review): presumably load_model() already puts the model in eval mode;
# confirm in model_loader.
model = load_model("model/AttrPredModel_StateDict.pth")


def get_task_map(index_to_attr: dict) -> dict:
    """Map each attribute index to the task name embedded in its description.

    The task name is the text between the last "(" of the description and the
    first ")" that follows it. Descriptions that do not contain both "(" and
    ")" are left out of the result.

    Args:
        index_to_attr: Mapping of attribute index to description string.

    Returns:
        Mapping of attribute index to task name.
    """
    return {
        idx: desc.split("(")[-1].split(")")[0]
        for idx, desc in index_to_attr.items()
        if "(" in desc and ")" in desc
    }


task_map = get_task_map(index_to_attr)

# Image preprocessing pipeline: fixed-size resize, tensor conversion and
# per-channel normalisation.
# NOTE(review): mean/std are presumably the training-set statistics; confirm.
preprocess = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.6765, 0.6347, 0.6207],
        std=[0.3284, 0.3371, 0.3379],
    ),
])


def predict(image) -> dict:
    """Return the best-scoring attribute for every task found in ``task_map``.

    The image is converted to RGB, preprocessed and passed through the model;
    the sigmoid of the model output is used as the per-attribute score.
    Indices without an entry in ``task_map`` are grouped under "unknown".

    Args:
        image: PIL image supplied by the Gradio image input, or ``None`` when
            the user submitted the form without an image.

    Returns:
        Mapping of task name to a dict with the keys "label" (description
        text before the first " ("), "score" (rounded to 4 decimals) and
        "confidence" ("low" below ``CONFIDENCE_THRESHOLD``, else "high").

    Raises:
        gr.Error: If no image was provided.
    """
    # Gradio passes None for an empty image input; report that to the user
    # instead of failing with an AttributeError on ``None.convert``.
    if image is None:
        raise gr.Error("Please upload an image first.")

    rgb_image = image.convert("RGB")
    input_tensor = preprocess(rgb_image).unsqueeze(0)  # add a batch dimension

    with torch.no_grad():
        output = model(input_tensor)

    # Flatten instead of squeeze(): squeeze() would yield a 0-d array (which
    # cannot be enumerated) if the model emitted a single score.
    # Assumes one score per attribute index for the single input image.
    probs = torch.sigmoid(output).reshape(-1).numpy()

    # Keep the highest-scoring index per task; strict ">" means the first
    # index wins on ties.
    top_per_task = {}
    for idx, score in enumerate(probs):
        task = task_map.get(idx, "unknown")
        if task not in top_per_task or score > top_per_task[task][1]:
            top_per_task[task] = (idx, score)

    result = {}
    for task, (idx, score) in top_per_task.items():
        label = index_to_attr.get(idx, f"Unknown ({idx})").split(" (")[0]
        result[task] = {
            "label": label,
            "score": round(float(score), 4),
            "confidence": "low" if score < CONFIDENCE_THRESHOLD else "high",
        }
    return result


# Gradio interface: single image input, JSON output.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload image"),
    outputs="json",
    title="Fashion Attribute Predictor (mit Confidence)",
    description="Zeigt pro Attributgruppe die wahrscheinlichste Vorhersage + Confidence ('high' / 'low').",
)

if __name__ == "__main__":
    demo.launch()