"""Gradio demo for a multi-task image classifier.

Builds a ResNet-50 backbone with two linear heads (object class, and
AI-generated vs. real), loads its weights from the Hugging Face Hub, and
serves predictions for uploaded images through a Gradio interface.
"""

import gradio as gr
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from huggingface_hub import hf_hub_download

########################################
# 1. Define Model Architecture
########################################
class MultiTaskModel(nn.Module):
    """Shared backbone feeding two independent linear classification heads.

    Args:
        backbone: Module that maps an input batch to feature vectors whose
            last dimension is ``feature_dim``.
        feature_dim: Input size of both linear heads.
        num_obj_classes: Number of outputs of the object-recognition head.
    """

    def __init__(self, backbone, feature_dim, num_obj_classes):
        super().__init__()
        self.backbone = backbone
        # Head 1: object recognition over ``num_obj_classes`` categories.
        self.obj_head = nn.Linear(feature_dim, num_obj_classes)
        # Head 2: two-way classification (index 0: AI-generated, index 1: Real).
        self.bin_head = nn.Linear(feature_dim, 2)

    def forward(self, x):
        """Return ``(obj_logits, bin_logits)`` computed from one backbone pass."""
        shared = self.backbone(x)
        return self.obj_head(shared), self.bin_head(shared)

########################################
# 2. Reconstruct the Model and Load Weights
########################################
# Number of object classes; this sizes the object head, so it has to agree
# with the checkpoint loaded below (update this to match your training).
num_obj_classes = 139  # Example; change as needed
device = torch.device("cpu")

# Instantiate the backbone: ResNet-50 with its final fc layer swapped for an
# identity, so the network returns the features that would have fed fc.
# No pretrained weights are requested here; the checkpoint below supplies them.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
# of `weights=None`; kept as-is because the installed version is not visible.
resnet = models.resnet50(pretrained=False)
resnet.fc = nn.Identity()
feature_dim = 2048  # input width of both heads; must match the backbone output

# Build the model architecture
model = MultiTaskModel(resnet, feature_dim, num_obj_classes)
model.to(device)

# Download the state dict from HF Hub
repo_id = "Abdu07/multitask-model"
filename = "best_model_new.pt"  # Make sure this is the state dict file
model_path = hf_hub_download(repo_id=repo_id, filename=filename)
# Map tensors onto the same `device` the model lives on instead of a second
# hard-coded "cpu", so changing `device` above is the only edit needed.
# SECURITY NOTE(review): torch.load unpickles the downloaded file; only load
# checkpoints from a trusted repo (consider `weights_only=True` where the
# installed torch supports it).
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)
# Switch to inference mode (affects layers such as batch norm / dropout).
model.eval()

########################################
# 3. Define Label Mappings and Transforms
########################################
# Object-class names in index order: the name at position i is the label for
# class index i. Update this with your actual training labels.
idx_to_obj_label = dict(enumerate((
    "cat",
    "dog",
    "car",
    # ... add your object classes here ...
)))
# Names for the binary head, indexed by predicted class (0 / 1).
bin_label_names = ["AI-Generated", "Real"]

# Per-channel mean / std handed to Normalize (RGB order).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Validation-time preprocessing (same as used during training/validation):
# resize, take the central 224x224 crop, convert to a tensor, normalize.
val_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])

########################################
# 4. Define the Inference Function
########################################
def predict_image(img: Image.Image) -> str:
    """Classify an uploaded image with both heads of the model.

    Args:
        img: PIL image supplied by the Gradio input component, or ``None``
            when the form is submitted without an image.

    Returns:
        ``"Prediction: <object> (<AI-Generated|Real>)"``. The object name is
        ``"Unknown"`` when the predicted index has no entry in
        ``idx_to_obj_label``. If ``img`` is ``None``, a short prompt asking
        the user to upload an image is returned instead.
    """
    # Gradio hands over None when nothing was uploaded; answer with a hint
    # rather than failing on ``None.convert``.
    if img is None:
        return "Please upload an image."

    # The normalization uses three channels, so force RGB (drops alpha,
    # expands grayscale).
    rgb = img.convert("RGB")
    # Preprocess and add the batch dimension. Shape: [1, 3, 224, 224]
    batch = val_transforms(rgb).unsqueeze(0).to(device)

    # Inference only: no gradients needed.
    with torch.no_grad():
        obj_logits, bin_logits = model(batch)

    # Highest-scoring class index of each head for the single batch item.
    obj_pred = torch.argmax(obj_logits, dim=1).item()
    bin_pred = torch.argmax(bin_logits, dim=1).item()

    obj_name = idx_to_obj_label.get(obj_pred, "Unknown")
    bin_name = bin_label_names[bin_pred]
    return f"Prediction: {obj_name} ({bin_name})"

########################################
# 5. Create Gradio UI
########################################
# Help text displayed with the interface.
_DESCRIPTION = (
    "Upload an image to receive two predictions:\n"
    "1) The primary object in the image,\n"
    "2) Whether the image is AI-generated or Real."
)

# Single-function UI: one PIL image in, one text prediction out.
demo = gr.Interface(
    predict_image,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Multi-Task Image Classifier",
    description=_DESCRIPTION,
)

# Start the server only when run as a script: listen on all interfaces and
# request a public share link.
if __name__ == "__main__":
    demo.launch(share=True, server_name="0.0.0.0")