rafaldembski commited on
Commit
f914a7f
·
verified ·
1 Parent(s): 3d1932a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -33
app.py CHANGED
@@ -8,7 +8,6 @@ from torch.autograd import Variable
8
  from torchvision import transforms
9
  import torch.nn.functional as F
10
  import gdown
11
- import matplotlib.pyplot as plt
12
  import warnings
13
  warnings.filterwarnings("ignore")
14
 
@@ -19,14 +18,14 @@ os.system("mv DIS/IS-Net/* .")
19
  from data_loader_cache import normalize, im_reader, im_preprocess
20
  from models import *
21
 
22
- #Helpers
23
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
24
 
25
  # Download official weights
26
  if not os.path.exists("saved_models"):
27
  os.mkdir("saved_models")
28
  os.system("mv isnet.pth saved_models/")
29
-
30
  class GOSNormalize(object):
31
  '''
32
  Normalize the Image using torch.transforms
@@ -39,8 +38,7 @@ class GOSNormalize(object):
39
  image = normalize(image,self.mean,self.std)
40
  return image
41
 
42
-
43
- transform = transforms.Compose([GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0])])
44
 
45
  def load_image(im_path, hypar):
46
  im = im_reader(im_path)
@@ -49,9 +47,8 @@ def load_image(im_path, hypar):
49
  shape = torch.from_numpy(np.array(im_shp))
50
  return transform(im).unsqueeze(0), shape.unsqueeze(0) # make a batch of image, shape
51
 
52
-
53
  def build_model(hypar,device):
54
- net = hypar["model"]#GOSNETINC(3,1)
55
 
56
  # convert to half precision
57
  if(hypar["model_digit"]=="half"):
@@ -68,8 +65,7 @@ def build_model(hypar,device):
68
  net.eval()
69
  return net
70
 
71
-
72
- def predict(net, inputs_val, shapes_val, hypar, device):
73
  '''
74
  Given an Image, predict the mask
75
  '''
@@ -80,14 +76,11 @@ def predict(net, inputs_val, shapes_val, hypar, device):
80
  else:
81
  inputs_val = inputs_val.type(torch.HalfTensor)
82
 
83
-
84
  inputs_val_v = Variable(inputs_val, requires_grad=False).to(device) # wrap inputs in Variable
85
-
86
  ds_val = net(inputs_val_v)[0] # list of 6 results
87
-
88
  pred_val = ds_val[0][0,:,:,:] # B x 1 x H x W # we want the first one which is the most accurate prediction
89
 
90
- ## recover the prediction spatial size to the orignal image size
91
  pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val,0),(shapes_val[0][0],shapes_val[0][1]),mode='bilinear'))
92
 
93
  ma = torch.max(pred_val)
@@ -98,8 +91,7 @@ def predict(net, inputs_val, shapes_val, hypar, device):
98
  return (pred_val.detach().cpu().numpy()*255).astype(np.uint8) # it is the mask we need
99
 
100
  # Set Parameters
101
- hypar = {} # paramters for inferencing
102
-
103
 
104
  hypar["model_path"] ="./saved_models" ## load trained weights from this path
105
  hypar["restore_model"] = "isnet.pth" ## name of the to-be-loaded weights
@@ -112,42 +104,38 @@ hypar["seed"] = 0
112
  hypar["cache_size"] = [1024, 1024] ## cached input spatial resolution, can be configured into different size
113
 
114
  ## data augmentation parameters ---
115
- hypar["input_size"] = [1024, 1024] ## mdoel input spatial size, usually use the same value hypar["cache_size"], which means we don't further resize the images
116
  hypar["crop_size"] = [1024, 1024] ## random crop size from the input, it is usually set as smaller than hypar["cache_size"], e.g., [920,920] for data augmentation
117
 
118
  hypar["model"] = ISNetDIS()
119
 
120
- # Build Model
121
  net = build_model(hypar, device)
122
 
123
-
124
  def inference(image):
125
- image_path = image
 
 
126
 
127
- image_tensor, orig_size = load_image(image_path, hypar)
128
- mask = predict(net, image_tensor, orig_size, hypar, device)
129
 
130
- pil_mask = Image.fromarray(mask).convert('L')
131
- im_rgb = Image.open(image).convert("RGB")
132
-
133
- im_rgba = im_rgb.copy()
134
- im_rgba.putalpha(pil_mask)
135
-
136
- return [im_rgba, pil_mask]
137
 
 
138
 
139
- title = "Highly Accurate Dichotomous Image Segmentation"
140
- description = "This is an unofficial demo for DIS, a model that can remove the background from a given image. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below.<br>GitHub: https://github.com/xuebinqin/DIS<br>Telegram bot: https://t.me/restoration_photo_bot<br>[![](https://img.shields.io/twitter/follow/DoEvent?label=@DoEvent&style=social)](https://twitter.com/DoEvent)"
141
- article = "<div><center><img src='https://visitor-badge.glitch.me/badge?page_id=max_skobeev_dis_cmp_public' alt='visitor badge'></center></div>"
142
 
143
  interface = gr.Interface(
144
  fn=inference,
145
  inputs=gr.Image(type='filepath'),
146
  outputs=["image", "image"],
147
- examples=[['robot.png'], ['ship.png']],
148
  title=title,
149
  description=description,
150
  article=article,
151
  allow_flagging='never',
152
  cache_examples=False,
153
- ).queue().launch(show_error=True)
 
8
  from torchvision import transforms
9
  import torch.nn.functional as F
10
  import gdown
 
11
  import warnings
12
  warnings.filterwarnings("ignore")
13
 
 
18
  from data_loader_cache import normalize, im_reader, im_preprocess
19
  from models import *
20
 
21
+ # Helpers
22
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
23
 
24
  # Download official weights
25
  if not os.path.exists("saved_models"):
26
  os.mkdir("saved_models")
27
  os.system("mv isnet.pth saved_models/")
28
+
29
  class GOSNormalize(object):
30
  '''
31
  Normalize the Image using torch.transforms
 
38
  image = normalize(image,self.mean,self.std)
39
  return image
40
 
41
+ transform = transforms.Compose([GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0])])
 
42
 
43
  def load_image(im_path, hypar):
44
  im = im_reader(im_path)
 
47
  shape = torch.from_numpy(np.array(im_shp))
48
  return transform(im).unsqueeze(0), shape.unsqueeze(0) # make a batch of image, shape
49
 
 
50
  def build_model(hypar,device):
51
+ net = hypar["model"]
52
 
53
  # convert to half precision
54
  if(hypar["model_digit"]=="half"):
 
65
  net.eval()
66
  return net
67
 
68
+ def predict(net, inputs_val, shapes_val, hypar, device):
 
69
  '''
70
  Given an Image, predict the mask
71
  '''
 
76
  else:
77
  inputs_val = inputs_val.type(torch.HalfTensor)
78
 
 
79
  inputs_val_v = Variable(inputs_val, requires_grad=False).to(device) # wrap inputs in Variable
 
80
  ds_val = net(inputs_val_v)[0] # list of 6 results
 
81
  pred_val = ds_val[0][0,:,:,:] # B x 1 x H x W # we want the first one which is the most accurate prediction
82
 
83
+ # recover the prediction spatial size to the orignal image size
84
  pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val,0),(shapes_val[0][0],shapes_val[0][1]),mode='bilinear'))
85
 
86
  ma = torch.max(pred_val)
 
91
  return (pred_val.detach().cpu().numpy()*255).astype(np.uint8) # it is the mask we need
92
 
93
  # Set Parameters
94
+ hypar = {} # parameters for inferencing
 
95
 
96
  hypar["model_path"] ="./saved_models" ## load trained weights from this path
97
  hypar["restore_model"] = "isnet.pth" ## name of the to-be-loaded weights
 
104
  hypar["cache_size"] = [1024, 1024] ## cached input spatial resolution, can be configured into different size
105
 
106
  ## data augmentation parameters ---
107
+ hypar["input_size"] = [1024, 1024] ## model input spatial size, usually use the same value hypar["cache_size"], which means we don't further resize the images
108
  hypar["crop_size"] = [1024, 1024] ## random crop size from the input, it is usually set as smaller than hypar["cache_size"], e.g., [920,920] for data augmentation
109
 
110
  hypar["model"] = ISNetDIS()
111
 
112
+ # Build Model
113
  net = build_model(hypar, device)
114
 
 
115
  def inference(image):
116
+ image_path = image
117
+ image_tensor, orig_size = load_image(image_path, hypar)
118
+ mask = predict(net, image_tensor, orig_size, hypar, device)
119
 
120
+ pil_mask = Image.fromarray(mask).convert('L')
121
+ im_rgb = Image.open(image).convert("RGB")
122
 
123
+ im_rgba = im_rgb.copy()
124
+ im_rgba.putalpha(pil_mask)
 
 
 
 
 
125
 
126
+ return [im_rgba, pil_mask]
127
 
128
+ title = "Image Segmentation"
129
+ description = ""
130
+ article = ""
131
 
132
  interface = gr.Interface(
133
  fn=inference,
134
  inputs=gr.Image(type='filepath'),
135
  outputs=["image", "image"],
 
136
  title=title,
137
  description=description,
138
  article=article,
139
  allow_flagging='never',
140
  cache_examples=False,
141
+ ).queue().launch(show_error=True)