Abdu07 committed
Commit 220eb39 · verified · 1 Parent(s): fcab065

Update app.py

Files changed (1)
  1. app.py +55 -60
app.py CHANGED
@@ -1,83 +1,83 @@
 import gradio as gr
 import torch
 import torch.nn as nn
-from torchvision import transforms
+import torchvision.transforms as T
 from PIL import Image
-import requests
 from huggingface_hub import hf_hub_download
 
-########################
-# 1) Download & Load Model
-########################
-
-# Replace with your actual model repo on HF
-repo_id = "Abdu07/multitask-model"
-filename = "multitask_model.pth"
-
-# Download the model file from the Hub
+#####################################
+# 1) Define the same custom class
+#####################################
+class MultiTaskModel(nn.Module):
+    def __init__(self, backbone, feature_dim, num_obj_classes):
+        super(MultiTaskModel, self).__init__()
+        self.backbone = backbone
+        self.obj_head = nn.Linear(feature_dim, num_obj_classes)
+        self.bin_head = nn.Linear(feature_dim, 2)
+
+    def forward(self, x):
+        feats = self.backbone(x)
+        obj_logits = self.obj_head(feats)
+        bin_logits = self.bin_head(feats)
+        return obj_logits, bin_logits
+
+#####################################
+# 2) Allowlist the class
+#####################################
+import torch.serialization
+torch.serialization.add_safe_globals([MultiTaskModel])
+
+#####################################
+# 3) Download & Load the full model
+#####################################
+repo_id = "Abdu07/multitask-model"  # or your actual repo
+filename = "multitask_model.pth"    # the file you uploaded
 model_path = hf_hub_download(repo_id=repo_id, filename=filename)
-model = torch.load(model_path, map_location="cpu")  # or map_location="cuda" if you prefer
-model.eval()
 
-########################
-# 2) Define Label Mappings
-########################
+# Force PyTorch to load the full model object
+model = torch.load(model_path, map_location="cpu")  # default weights_only=True, but we added safe_globals
+model.eval()
 
-# For example, if your object labels are saved in code:
+#####################################
+# 4) Label Mappings
+#####################################
 idx_to_obj_label = {
+    # Fill in with your actual object label indices
     0: "cat",
     1: "dog",
     2: "car",
-    # ... fill in all your categories ...
+    # ...
 }
-
-bin_label_names = ["AI-Generated", "Real"]  # Adjust if 0=AI, 1=Real
-
-########################
-# 3) Define Transforms
-########################
-
-# Match the transforms you used during validation
-val_transforms = transforms.Compose([
-    transforms.Resize(256),
-    transforms.CenterCrop(224),
-    transforms.ToTensor(),
-    transforms.Normalize(mean=[0.485, 0.456, 0.406],
-                         std=[0.229, 0.224, 0.225])
+bin_label_names = ["AI-Generated", "Real"]
+
+#####################################
+# 5) Validation Transforms
+#####################################
+val_transforms = T.Compose([
+    T.Resize(256),
+    T.CenterCrop(224),
+    T.ToTensor(),
+    T.Normalize(mean=[0.485, 0.456, 0.406],
+                std=[0.229, 0.224, 0.225])
 ])
 
-########################
-# 4) Define the Inference Function
-########################
-
+#####################################
+# 6) Inference Function
+#####################################
 def predict_image(img: Image.Image) -> str:
-    """
-    Takes a PIL image, applies transforms, passes through the model,
-    and returns the combined prediction (object + AI/Real).
-    """
-    # Convert to RGB just in case
     img = img.convert("RGB")
-
-    # Apply transforms
-    img_t = val_transforms(img)
-    # Add batch dimension
-    img_t = img_t.unsqueeze(0)
-
+    img_t = val_transforms(img).unsqueeze(0)
     with torch.no_grad():
        obj_logits, bin_logits = model(img_t)
        obj_pred = torch.argmax(obj_logits, dim=1).item()
        bin_pred = torch.argmax(bin_logits, dim=1).item()
-
-    # Map predictions to labels
     obj_name = idx_to_obj_label.get(obj_pred, "Unknown")
     bin_name = bin_label_names[bin_pred]
-
     return f"Object: {obj_name} | Authenticity: {bin_name}"
 
-########################
-# 5) Build Gradio UI
-########################
-
+#####################################
+# 7) Gradio UI
+#####################################
 demo = gr.Interface(
     fn=predict_image,
     inputs=gr.Image(type="pil"),
@@ -85,15 +85,10 @@ demo = gr.Interface(
     title="Multi-Task Image Classifier",
     description=(
         "Upload an image to get two predictions: "
-        "1) The primary object (from pseudo-labeling), "
-        "2) Whether the image is AI-generated or real."
+        "1) The primary object, 2) Whether the image is AI-generated or real."
     )
 )
 
-########################
-# 6) Launch the App
-########################
-
 def main():
     demo.launch(server_name="0.0.0.0", enable_queue=True)
 
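
Background on the loading change (sections 2–3 of the new app.py): starting with PyTorch 2.6, torch.load defaults to weights_only=True and refuses to unpickle arbitrary Python classes unless they are allowlisted. Redefining MultiTaskModel and calling torch.serialization.add_safe_globals, as this commit does, is one way to satisfy that check; note that a fully pickled model can also reference other classes (the backbone, for example), which would need to be importable and allowlisted too. A minimal sketch of the two usual options for a trusted checkpoint; the file name is the one used above, and MultiTaskModel stands for the class defined in app.py:

import torch
import torch.serialization

# Assumes MultiTaskModel (and any other custom classes stored in the pickle)
# is defined or imported in this module, exactly as in app.py above.

# Option A: allowlist the custom class, then load (PyTorch >= 2.4 API).
torch.serialization.add_safe_globals([MultiTaskModel])
model = torch.load("multitask_model.pth", map_location="cpu")

# Option B: the checkpoint is trusted, so opt out of the weights-only check.
model = torch.load("multitask_model.pth", map_location="cpu", weights_only=False)
model.eval()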
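
If unpickling the full module keeps tripping over non-allowlisted classes, the more portable pattern is to save and load only the state_dict. The sketch below is an assumption-heavy illustration: the backbone choice (a torchvision ResNet-18 with its fc head replaced, 512-dim features), the three object classes, and the weights file name are hypothetical, since the training setup is not part of this commit; only the MultiTaskModel definition mirrors the diff.

import torch
import torch.nn as nn
import torchvision.models as models

class MultiTaskModel(nn.Module):
    def __init__(self, backbone, feature_dim, num_obj_classes):
        super().__init__()
        self.backbone = backbone
        self.obj_head = nn.Linear(feature_dim, num_obj_classes)
        self.bin_head = nn.Linear(feature_dim, 2)

    def forward(self, x):
        feats = self.backbone(x)
        return self.obj_head(feats), self.bin_head(feats)

# Assumed architecture: ResNet-18 trunk with the final fc removed (512-d features).
backbone = models.resnet18(weights=None)
backbone.fc = nn.Identity()
model = MultiTaskModel(backbone, feature_dim=512, num_obj_classes=3)

# Training side would save: torch.save(model.state_dict(), "multitask_model_weights.pth")
state = torch.load("multitask_model_weights.pth", map_location="cpu", weights_only=True)
model.load_state_dict(state)
model.eval()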
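
One unchanged line worth flagging is demo.launch(server_name="0.0.0.0", enable_queue=True): in Gradio 4.x the enable_queue argument was removed from launch(), so on a current Spaces image this call may fail with a TypeError. A minimal sketch of the equivalent, assuming a recent Gradio release (the placeholder predict function only makes the snippet self-contained):

import gradio as gr

def predict_image(img):
    # Placeholder standing in for the real predict_image defined in app.py.
    return "Object: ? | Authenticity: ?"

demo = gr.Interface(fn=predict_image, inputs=gr.Image(type="pil"), outputs="text")

def main():
    demo.queue()                        # replaces launch(enable_queue=True)
    demo.launch(server_name="0.0.0.0")

if __name__ == "__main__":
    main()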