petergpt commited on
Commit
66a61d0
·
verified ·
1 Parent(s): efae294

multiple images

Browse files
Files changed (1) hide show
  1. app.py +48 -36
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import cv2
2
  import gradio as gr
3
  import os
@@ -7,11 +8,10 @@ import torch
7
  from torch.autograd import Variable
8
  from torchvision import transforms
9
  import torch.nn.functional as F
10
- import gdown
11
  import matplotlib.pyplot as plt
12
  import warnings
13
- warnings.filterwarnings("ignore")
14
  import time
 
15
 
16
  os.system("git clone https://github.com/xuebinqin/DIS")
17
  os.system("mv DIS/IS-Net/* .")
@@ -20,14 +20,13 @@ os.system("mv DIS/IS-Net/* .")
20
  from data_loader_cache import normalize, im_reader, im_preprocess
21
  from models import *
22
 
23
- #Helpers
24
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
25
 
26
  # Download official weights
27
  if not os.path.exists("saved_models"):
28
  os.mkdir("saved_models")
29
  os.system("mv isnet.pth saved_models/")
30
-
31
  class GOSNormalize(object):
32
  '''
33
  Normalize the Image using torch.transforms
@@ -45,9 +44,9 @@ transform = transforms.Compose([GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0])])
45
  def load_image(im_path, hypar):
46
  im = im_reader(im_path)
47
  im, im_shp = im_preprocess(im, hypar["cache_size"])
48
- im = torch.divide(im,255.0)
49
  shape = torch.from_numpy(np.array(im_shp))
50
- return transform(im).unsqueeze(0), shape.unsqueeze(0) # make a batch of image, shape
51
 
52
  def build_model(hypar, device):
53
  net = hypar["model"]
@@ -67,10 +66,7 @@ def build_model(hypar, device):
67
  net.eval()
68
  return net
69
 
70
- def predict(net, inputs_val, shapes_val, hypar, device):
71
- '''
72
- Given an Image, predict the mask
73
- '''
74
  net.eval()
75
 
76
  if(hypar["model_digit"]=="full"):
@@ -81,21 +77,21 @@ def predict(net, inputs_val, shapes_val, hypar, device):
81
  inputs_val_v = Variable(inputs_val, requires_grad=False).to(device)
82
  ds_val = net(inputs_val_v)[0]
83
  pred_val = ds_val[0][0,:,:,:]
84
- pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val,0),
85
- (shapes_val[0][0],shapes_val[0][1]),
86
  mode='bilinear'))
87
 
88
  ma = torch.max(pred_val)
89
  mi = torch.min(pred_val)
90
- pred_val = (pred_val-mi)/(ma-mi) # normalize to 0~1
91
 
92
  if device == 'cuda':
93
  torch.cuda.empty_cache()
94
- return (pred_val.detach().cpu().numpy()*255).astype(np.uint8)
95
 
96
- # Set Parameters
97
- hypar = {}
98
- hypar["model_path"] ="./saved_models"
99
  hypar["restore_model"] = "isnet.pth"
100
  hypar["interm_sup"] = False
101
  hypar["model_digit"] = "full"
@@ -108,32 +104,42 @@ hypar["model"] = ISNetDIS()
108
  # Build Model
109
  net = build_model(hypar, device)
110
 
111
- def inference(image, logs):
112
  start_time = time.time()
113
-
114
- image_tensor, orig_size = load_image(image, hypar)
115
- mask = predict(net, image_tensor, orig_size, hypar, device)
116
-
117
- pil_mask = Image.fromarray(mask).convert('L')
118
- im_rgb = Image.open(image).convert("RGB")
119
-
120
- im_rgba = im_rgb.copy()
121
- im_rgba.putalpha(pil_mask)
 
 
 
 
 
 
122
 
123
  end_time = time.time()
124
  elapsed = round(end_time - start_time, 2)
125
-
126
- # Update and return logs
 
 
 
 
 
127
  logs = logs or ""
128
- logs += f"Processed in {elapsed} seconds.\n"
129
 
130
- # Return (gallery output), the logs state, and the logs display
131
- return [im_rgba, pil_mask], logs, logs
132
 
133
  title = "Highly Accurate Dichotomous Image Segmentation"
134
  description = (
135
  "This is an unofficial demo for DIS, a model that can remove the background from a given image. "
136
- "To use it, simply upload your image, or click one of the examples to load them. "
137
  "Read more at the links below.<br>"
138
  "GitHub: https://github.com/xuebinqin/DIS<br>"
139
  "Telegram bot: https://t.me/restoration_photo_bot<br>"
@@ -146,13 +152,19 @@ article = (
146
 
147
  interface = gr.Interface(
148
  fn=inference,
149
- inputs=[gr.Image(type='filepath'), gr.State()],
 
 
 
 
 
 
150
  outputs=[
151
- gr.Gallery(format="png"),
152
  gr.State(),
153
  gr.Textbox(label="Logs", lines=6)
154
  ],
155
- examples=[['robot.png'], ['ship.png']],
156
  title=title,
157
  description=description,
158
  article=article,
 
1
+
2
  import cv2
3
  import gradio as gr
4
  import os
 
8
  from torch.autograd import Variable
9
  from torchvision import transforms
10
  import torch.nn.functional as F
 
11
  import matplotlib.pyplot as plt
12
  import warnings
 
13
  import time
14
+ warnings.filterwarnings("ignore")
15
 
16
  os.system("git clone https://github.com/xuebinqin/DIS")
17
  os.system("mv DIS/IS-Net/* .")
 
20
  from data_loader_cache import normalize, im_reader, im_preprocess
21
  from models import *
22
 
 
23
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
24
 
25
  # Download official weights
26
  if not os.path.exists("saved_models"):
27
  os.mkdir("saved_models")
28
  os.system("mv isnet.pth saved_models/")
29
+
30
  class GOSNormalize(object):
31
  '''
32
  Normalize the Image using torch.transforms
 
44
  def load_image(im_path, hypar):
45
  im = im_reader(im_path)
46
  im, im_shp = im_preprocess(im, hypar["cache_size"])
47
+ im = torch.divide(im, 255.0)
48
  shape = torch.from_numpy(np.array(im_shp))
49
+ return transform(im).unsqueeze(0), shape.unsqueeze(0)
50
 
51
  def build_model(hypar, device):
52
  net = hypar["model"]
 
66
  net.eval()
67
  return net
68
 
69
+ def predict(net, inputs_val, shapes_val, hypar, device):
 
 
 
70
  net.eval()
71
 
72
  if(hypar["model_digit"]=="full"):
 
77
  inputs_val_v = Variable(inputs_val, requires_grad=False).to(device)
78
  ds_val = net(inputs_val_v)[0]
79
  pred_val = ds_val[0][0,:,:,:]
80
+ pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val, 0),
81
+ (shapes_val[0][0], shapes_val[0][1]),
82
  mode='bilinear'))
83
 
84
  ma = torch.max(pred_val)
85
  mi = torch.min(pred_val)
86
+ pred_val = (pred_val - mi) / (ma - mi + 1e-8) # normalize to 0~1, +1e-8 to avoid div by zero
87
 
88
  if device == 'cuda':
89
  torch.cuda.empty_cache()
90
+ return (pred_val.detach().cpu().numpy() * 255).astype(np.uint8)
91
 
92
+ # Parameters
93
+ hypar = {}
94
+ hypar["model_path"] = "./saved_models"
95
  hypar["restore_model"] = "isnet.pth"
96
  hypar["interm_sup"] = False
97
  hypar["model_digit"] = "full"
 
104
  # Build Model
105
  net = build_model(hypar, device)
106
 
107
+ def inference(images, logs):
108
  start_time = time.time()
109
+
110
+ # If user didn't upload images, just return empty
111
+ if not images:
112
+ return [], logs, logs
113
+
114
+ processed_pairs = []
115
+ for img_path in images:
116
+ image_tensor, orig_size = load_image(img_path, hypar)
117
+ mask = predict(net, image_tensor, orig_size, hypar, device)
118
+
119
+ pil_mask = Image.fromarray(mask).convert('L')
120
+ im_rgb = Image.open(img_path).convert("RGB")
121
+ im_rgba = im_rgb.copy()
122
+ im_rgba.putalpha(pil_mask)
123
+ processed_pairs.append([im_rgba, pil_mask])
124
 
125
  end_time = time.time()
126
  elapsed = round(end_time - start_time, 2)
127
+
128
+ # Flatten the list so that we can display all images in a single Gallery
129
+ final_images = []
130
+ for pair in processed_pairs:
131
+ final_images.extend(pair)
132
+
133
+ # Update logs
134
  logs = logs or ""
135
+ logs += f"Processed {len(processed_pairs)} image(s) in {elapsed} seconds.\n"
136
 
137
+ return final_images, logs, logs
 
138
 
139
  title = "Highly Accurate Dichotomous Image Segmentation"
140
  description = (
141
  "This is an unofficial demo for DIS, a model that can remove the background from a given image. "
142
+ "To use it, simply upload up to 3 images, or click one of the examples to load them. "
143
  "Read more at the links below.<br>"
144
  "GitHub: https://github.com/xuebinqin/DIS<br>"
145
  "Telegram bot: https://t.me/restoration_photo_bot<br>"
 
152
 
153
  interface = gr.Interface(
154
  fn=inference,
155
+ inputs=[gr.Image(
156
+ type='filepath',
157
+ label='Images (up to 3)',
158
+ multiple=True,
159
+ max_count=3
160
+ ),
161
+ gr.State()],
162
  outputs=[
163
+ gr.Gallery(label="Output (rgba + mask)"),
164
  gr.State(),
165
  gr.Textbox(label="Logs", lines=6)
166
  ],
167
+ examples=[['robot.png'], ['ship.png']], # for multi-image examples, pass a list like ['robot.png','ship.png']
168
  title=title,
169
  description=description,
170
  article=article,