FrezzyI's picture
Rename app (2).py to app.py
c6393a1 verified
raw
history blame
2.02 kB
import gradio as gr
import torch
from PIL import Image
from torchvision import transforms
from model_loader import load_model
from index_to_attr import index_to_attr
# Load the model once at import time; every request reuses this instance.
model = load_model("model/AttrPredModel_StateDict.pth")
# Switch to inference behaviour (dropout disabled, batch-norm uses running
# statistics, if the network has such layers) in case load_model does not
# already do so; calling eval() on a module that is already in eval mode is a no-op.
# NOTE(review): assumes load_model returns a torch.nn.Module — confirm.
model.eval()
# Extract the task name for each attribute index.
def get_task_map(index_to_attr):
    """Build a mapping from attribute index to task name.

    The task name is the text after the last "(" of a description, up to the
    next ")" (or to the end of the string if there is none after it), e.g.
    "red (color)" -> "color". Descriptions that lack either "(" or ")"
    anywhere are left out of the result.

    Args:
        index_to_attr: Mapping of attribute index -> description string.

    Returns:
        Dict of index -> task name, in the iteration order of ``index_to_attr``.
    """
    return {
        idx: desc.rpartition("(")[2].partition(")")[0]
        for idx, desc in index_to_attr.items()
        if "(" in desc and ")" in desc
    }
# Attribute index -> task/group name, computed once from the attribute table.
task_map = get_task_map(index_to_attr)

# Image preprocessing pipeline: fixed 512x512 resize, conversion to a tensor,
# then per-channel normalization. The mean/std values are presumably the
# statistics of the training images — TODO confirm against training code.
_CHANNEL_MEAN = [0.6765, 0.6347, 0.6207]
_CHANNEL_STD = [0.3284, 0.3371, 0.3379]
preprocess = transforms.Compose(
    [
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
    ]
)
# Inference function that flags uncertain attribute groups as low confidence.
def predict(image, *, threshold=0.5):
    """Return the highest-scoring attribute per task group for one image.

    Args:
        image: PIL image from the Gradio input, or ``None`` when the user
            submits without uploading anything.
        threshold: Keyword-only. Scores below this value are reported with
            ``"confidence": "low"``; defaults to 0.5 as before.

    Returns:
        Dict keyed by task name (looked up in ``task_map``; indices without a
        task are pooled under ``"unknown"``). Each value is a dict with the
        attribute ``"label"`` (description text before the first " ("), its
        sigmoid ``"score"`` rounded to 4 decimals, and ``"confidence"``
        (``"high"`` or ``"low"``).

    Raises:
        gr.Error: If no image was supplied, so Gradio shows a readable message
            instead of failing with an AttributeError on ``None.convert``.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    image = image.convert("RGB")
    input_tensor = preprocess(image).unsqueeze(0)  # add a batch dimension of 1
    with torch.no_grad():
        output = model(input_tensor)
    # Sigmoid maps each logit to an independent probability. reshape(-1) is used
    # instead of squeeze(): squeeze() collapses a single-logit output to a 0-d
    # array, which cannot be iterated below. .cpu() keeps .numpy() safe if the
    # tensor lives on a GPU.
    probs = torch.sigmoid(output).reshape(-1).cpu().numpy()

    # Keep only the best-scoring index per task; the strict ">" means ties keep
    # the lowest index.
    top_per_task = {}
    for idx, score in enumerate(probs):
        task = task_map.get(idx, "unknown")
        if task not in top_per_task or score > top_per_task[task][1]:
            top_per_task[task] = (idx, score)

    result = {}
    for task, (idx, score) in top_per_task.items():
        label = index_to_attr.get(idx, f"Unknown ({idx})").split(" (")[0]
        result[task] = {
            "label": label,
            "score": round(float(score), 4),
            "confidence": "low" if score < threshold else "high",
        }
    return result
# Gradio UI: one uploaded PIL image in, the per-group prediction dict out as JSON.
_upload = gr.Image(type="pil", label="Upload image")
demo = gr.Interface(
    predict,
    _upload,
    "json",
    title="Fashion Attribute Predictor (mit Confidence)",
    description="Zeigt pro Attributgruppe die wahrscheinlichste Vorhersage + Confidence ('high' / 'low').",
)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()