# multitask-demo / app.py
# Author: Abdu07
# Last change: "Update app.py" (commit 6ea4dca, verified)
import gradio as gr
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from huggingface_hub import hf_hub_download
########################################
# 1. Define Model Architecture
########################################
class MultiTaskModel(nn.Module):
    """One shared backbone feeding two independent linear classification heads.

    Args:
        backbone: Module whose output's last dimension is ``feature_dim``;
            its features are shared by both heads.
        feature_dim: Width of the backbone output consumed by each head.
        num_obj_classes: Number of object-recognition classes.

    Note:
        The attribute names ``backbone``, ``obj_head`` and ``bin_head`` become
        the state_dict key prefixes, so renaming them breaks loading of
        previously saved checkpoints.
    """

    def __init__(self, backbone, feature_dim, num_obj_classes):
        super().__init__()
        self.backbone = backbone
        # Object-recognition head: one logit per object class.
        self.obj_head = nn.Linear(feature_dim, num_obj_classes)
        # Real-vs-synthetic head: index 0 = AI-generated, index 1 = Real.
        self.bin_head = nn.Linear(feature_dim, 2)

    def forward(self, x):
        """Return ``(obj_logits, bin_logits)`` computed from shared backbone features."""
        shared = self.backbone(x)
        return self.obj_head(shared), self.bin_head(shared)
########################################
# 2. Reconstruct the Model and Load Weights
########################################
# Inference runs on CPU; checkpoint tensors are mapped onto this device.
device = torch.device("cpu")

# Download the state dict from HF Hub.
repo_id = "Abdu07/multitask-model"
filename = "best_model_new.pt"  # must be a state dict (name -> tensor mapping), not a pickled model
model_path = hf_hub_download(repo_id=repo_id, filename=filename)

# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered file on the Hub cannot execute arbitrary code when it is loaded.
state_dict = torch.load(model_path, map_location=device, weights_only=True)

# Size the object head from the checkpoint itself so the architecture always
# matches the saved weights. Falls back to the hand-set value (update this to
# match your training) if the expected key is absent.
num_obj_classes = 139
_obj_head_weight = state_dict.get("obj_head.weight")
if _obj_head_weight is not None:
    # nn.Linear stores its weight as (out_features, in_features).
    num_obj_classes = _obj_head_weight.shape[0]

# Instantiate the backbone (ResNet-50 without its final layer).
# weights=None replaces the deprecated pretrained=False: parameters start
# randomly initialised and are all overwritten by the strict load below.
resnet = models.resnet50(weights=None)
resnet.fc = nn.Identity()  # expose the pooled 2048-d features instead of ImageNet logits
feature_dim = 2048

# Build the model architecture and load the trained weights.
model = MultiTaskModel(resnet, feature_dim, num_obj_classes)
model.to(device)
model.load_state_dict(state_dict)  # strict=True by default: keys and shapes must match exactly
model.eval()
########################################
# 3. Define Label Mappings and Transforms
########################################
# Object-head index -> human-readable name. Update with your actual training
# labels: each name's position in this tuple is the class index it maps to.
# ... add your object classes here ...
idx_to_obj_label = dict(enumerate((
    "cat",
    "dog",
    "car",
)))
# Binary-head index -> display name (0: AI-generated, 1: Real).
bin_label_names = ["AI-Generated", "Real"]
# Evaluation-time preprocessing; must stay identical to the validation pipeline
# used during training so inputs match what the network was fit on.
# Standard ImageNet per-channel mean / std.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

val_transforms = transforms.Compose(
    [
        transforms.Resize(256),      # shorter side -> 256 px, aspect ratio kept
        transforms.CenterCrop(224),  # central 224x224 patch
        transforms.ToTensor(),       # PIL image -> float CHW tensor in [0, 1]
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
########################################
# 4. Define the Inference Function
########################################
def predict_image(img: Image.Image) -> str:
    """
    Classify an uploaded image with both heads of the multi-task model.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input. Gradio passes
            ``None`` when the form is submitted without an uploaded image.

    Returns:
        ``"Prediction: <object> (<AI-Generated|Real>)"``. Object indices missing
        from ``idx_to_obj_label`` are reported as ``"Unknown"``. If no image
        was provided, a prompt asking the user to upload one is returned
        instead of raising.
    """
    if img is None:
        # Without this guard, None.convert() would raise AttributeError.
        return "Please upload an image."
    # Force 3 channels (e.g. RGBA PNGs, grayscale) to match the model input.
    img = img.convert("RGB")
    # Preprocess and add a batch dimension -> shape [1, 3, 224, 224].
    img_tensor = val_transforms(img).unsqueeze(0).to(device)
    with torch.no_grad():
        obj_logits, bin_logits = model(img_tensor)
    obj_pred = torch.argmax(obj_logits, dim=1).item()
    bin_pred = torch.argmax(bin_logits, dim=1).item()
    obj_name = idx_to_obj_label.get(obj_pred, "Unknown")
    bin_name = bin_label_names[bin_pred]
    return f"Prediction: {obj_name} ({bin_name})"
########################################
# 5. Create Gradio UI
########################################
_DESCRIPTION = (
    "Upload an image to receive two predictions:\n"
    "1) The primary object in the image,\n"
    "2) Whether the image is AI-generated or Real."
)

# Single image in (as a PIL object), single prediction string out.
demo = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Multi-Task Image Classifier",
    description=_DESCRIPTION,
)

if __name__ == "__main__":
    # Listen on all network interfaces and request a public share link.
    # NOTE(review): share=True publishes a public URL when run locally and is
    # reportedly ignored (with a warning) on Hugging Face Spaces - confirm intended.
    demo.launch(server_name="0.0.0.0", share=True)