# Gradio application file for a Hugging Face Space.
# NOTE: the text that originally preceded the imports (Space status banner,
# file-size line, per-line commit hashes and the line-number gutter) was web
# page chrome from the file viewer, not part of the program.
import gradio as gr
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from huggingface_hub import hf_hub_download
########################################
# 1. Define Model Architecture
########################################
class MultiTaskModel(nn.Module):
    """Shared-backbone network with two linear classification heads.

    The features produced by ``backbone`` are fed to both heads: one scores
    ``num_obj_classes`` object categories, the other scores two classes
    (index 0: AI-generated, index 1: real).

    Args:
        backbone: Module that maps an input batch to feature vectors whose
            last dimension is ``feature_dim``.
        feature_dim: Size of the feature vector produced by ``backbone``.
        num_obj_classes: Number of categories scored by the object head.
    """

    def __init__(self, backbone, feature_dim, num_obj_classes):
        super().__init__()
        self.backbone = backbone
        # Object recognition head
        self.obj_head = nn.Linear(feature_dim, num_obj_classes)
        # Binary classification head (0: AI-generated, 1: Real)
        self.bin_head = nn.Linear(feature_dim, 2)

    def forward(self, x):
        """Run the backbone once and apply both heads to its features.

        Args:
            x: Input batch accepted by ``backbone``.

        Returns:
            A ``(obj_logits, bin_logits)`` tuple; the last dimension is
            ``num_obj_classes`` for the first element and 2 for the second.
        """
        feats = self.backbone(x)
        obj_logits = self.obj_head(feats)
        bin_logits = self.bin_head(feats)
        return obj_logits, bin_logits
########################################
# 2. Reconstruct the Model and Load Weights
########################################
# Set the number of object classes (update this to match your training)
num_obj_classes = 139  # Example; change as needed
device = torch.device("cpu")

# Instantiate the backbone (ResNet-50 without its final layer).
# Calling resnet50() with no arguments builds the architecture with randomly
# initialised weights, which is what `pretrained=False` asked for; the
# `pretrained` keyword is deprecated since torchvision 0.13 and emits a
# warning there. The actual weights come from the checkpoint loaded below.
resnet = models.resnet50()
# Replace the classification layer so the backbone returns its features.
resnet.fc = nn.Identity()
feature_dim = 2048

# Build the model architecture
model = MultiTaskModel(resnet, feature_dim, num_obj_classes)
model.to(device)

# Download the state dict from HF Hub
repo_id = "Abdu07/multitask-model"
filename = "best_model_new.pt"  # Make sure this is the state dict file
model_path = hf_hub_download(repo_id=repo_id, filename=filename)

# NOTE(review): torch.load unpickles the downloaded file, which can execute
# arbitrary code if the checkpoint is not trusted. If the file is a plain
# state dict and the installed torch supports it, consider passing
# `weights_only=True` here.
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)
# Switch to inference mode (affects layers such as batch norm and dropout).
model.eval()
########################################
# 3. Define Label Mappings and Transforms
########################################
# Replace the names below with the labels used during training.
# Names for the two outputs of the binary head, in index order.
bin_label_names = ["AI-Generated", "Real"]
# Object class names listed in class-index order: the position of a name in
# the tuple is the index it is stored under in the mapping.
idx_to_obj_label = dict(enumerate((
    "cat",
    "dog",
    "car",
    # ... add your object classes here ...
)))
# Preprocessing applied to every uploaded image before inference. Per the
# original author this is the same pipeline used during training/validation:
# resize to 256, centre-crop to 224x224, convert to a tensor, then normalise
# each channel with the given mean and standard deviation.
val_transforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.485, 0.456, 0.406],  # per-channel mean
            [0.229, 0.224, 0.225],  # per-channel std
        ),
    ]
)
########################################
# 4. Define the Inference Function
########################################
def predict_image(img: "Image.Image") -> str:
    """Classify an uploaded image and describe the result as text.

    Args:
        img: PIL image supplied by the Gradio image input. ``None`` is
            tolerated, since the form can be submitted without an image.

    Returns:
        ``"Prediction: <object> (<AI-Generated or Real>)"``, or a short
        notice when no image was supplied. Object indices that are missing
        from ``idx_to_obj_label`` are reported as ``"Unknown"``.
    """
    # Without this guard an empty submission fails on `img.convert` below.
    if img is None:
        return "No image provided. Please upload an image."

    # Ensure image is in RGB
    rgb_img = img.convert("RGB")
    # Apply validation transforms and add the batch dimension.
    img_tensor = val_transforms(rgb_img).unsqueeze(0).to(device)  # Shape: [1, 3, 224, 224]

    # Inference only: no gradients are needed.
    with torch.no_grad():
        obj_logits, bin_logits = model(img_tensor)

    obj_pred = torch.argmax(obj_logits, dim=1).item()
    bin_pred = torch.argmax(bin_logits, dim=1).item()
    obj_name = idx_to_obj_label.get(obj_pred, "Unknown")
    bin_name = bin_label_names[bin_pred]
    return f"Prediction: {obj_name} ({bin_name})"
########################################
# 5. Create Gradio UI
########################################
# Single-input, single-output interface: a PIL image goes to `predict_image`
# and the returned string is shown as text.
demo = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Multi-Task Image Classifier",
    description="\n".join(
        (
            "Upload an image to receive two predictions:",
            "1) The primary object in the image,",
            "2) Whether the image is AI-generated or Real.",
        )
    ),
)
if __name__ == "__main__":
    # Bind to all network interfaces so the app is reachable from outside the
    # host/container.
    # NOTE(review): share=True asks Gradio for a public share link; confirm
    # this is intended for the environment the app is deployed in.
    demo.launch(server_name="0.0.0.0", share=True)