Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,403 Bytes
7febe9c d10f17e a7ce0cb 7febe9c 15d5ec0 7febe9c 6e533ca 7febe9c 2676e90 b96d57e 7febe9c b96d57e 7febe9c 6e533ca b96d57e 7febe9c 726f866 b96d57e 7febe9c b96d57e 7febe9c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
import os
from glob import glob
import cv2
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
import gradio as gr
import spaces
from models.GCoNet import GCoNet
import zipfile
# Compute device used for the model and its inputs.
# Index 0 selects 'cpu'; switch the index to 1 to run on 'cuda'.
device = ('cpu', 'cuda')[0]
class ImagePreprocessor:
    """Turn an input image into the normalized tensor fed to the network.

    The pipeline resizes to 256x256, converts to a tensor and normalizes each
    of the three channels with the fixed mean/std values below.
    """

    def __init__(self) -> None:
        channel_mean = [0.485, 0.456, 0.406]
        channel_std = [0.229, 0.224, 0.225]
        pipeline_steps = [
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ]
        self.transform_image = transforms.Compose(pipeline_steps)

    def proc(self, image):
        """Apply the preprocessing pipeline to ``image`` and return the result."""
        return self.transform_image(image)
def save_tensor_img(path, tenor_im):
    """Save the tensor ``tenor_im`` as an image file at ``path``.

    A CPU copy of the tensor is taken, dimension 0 is squeezed (a no-op unless
    that dimension has size 1), and the result is converted to a PIL image and
    written to disk.
    """
    cpu_copy = tenor_im.cpu().clone().squeeze(0)
    to_pil = transforms.ToPILImage()
    to_pil(cpu_copy).save(path)
# Build the network on the selected device and restore its weights if the
# checkpoint file is available next to the script.
model = GCoNet(bb_pretrained=False).to(device)
# NOTE: despite its name, ``state_dict`` holds the checkpoint *path*; the
# loaded weights live in ``gconet_dict``.
state_dict = './ultimate_duts_cocoseg (The best one).pth'
if os.path.exists(state_dict):
    # NOTE(review): torch.load unpickles the file -- only ship trusted checkpoints.
    gconet_dict = torch.load(state_dict, map_location=device)
    model.load_state_dict(gconet_dict)
else:
    # Stay best-effort (the demo still starts), but do not hide the problem.
    print('Warning: checkpoint "{}" not found; no weights were loaded.'.format(state_dict))
# Inference only: always switch to eval mode, whether or not weights were found.
model.eval()
@spaces.GPU
def pred_maps(images):
    """Predict a map for every image of a group, save them as PNGs and zip them.

    Args:
        images: Sequence of image file paths (what the ``gr.File`` input with
            ``type="filepath"`` delivers) or of already-decoded image arrays.

    Returns:
        A ``(save_paths, zip_file_path)`` tuple: the PNG paths written under
        ``preds-GCoNet_plus`` (one per input, in input order) and the path of
        the zip archive that contains those PNGs.

    Raises:
        AssertionError: If ``images`` is None.
    """
    # Raised explicitly (same exception type and message as before) so the
    # check is not stripped when Python runs with -O.
    if images is None:
        raise AssertionError('AssertionError: images cannot be None.')
    # For tab_batch
    save_dir = 'preds-GCoNet_plus'
    os.makedirs(save_dir, exist_ok=True)

    save_paths = []
    image_shapes = []
    images_proc = []
    image_preprocessor = ImagePreprocessor()
    for idx_image, image_src in enumerate(images):
        if isinstance(image_src, (str, os.PathLike)):
            stem = os.path.splitext(os.path.basename(image_src))[0]
            # Close the file handle as soon as the pixels are decoded.
            with Image.open(image_src) as image_file:
                image = image_file.convert('RGB')
        else:
            # In-memory inputs have no file name; name the output by position.
            stem = str(idx_image)
            image = Image.fromarray(np.asarray(image_src)).convert('RGB')
        # Forcing RGB keeps RGBA / grayscale / palette inputs compatible with
        # the 3-channel normalization done by ImagePreprocessor.
        save_paths.append(os.path.join(save_dir, "{}.png".format(stem)))
        # (height, width) of the source image, used to resize the prediction back.
        image_shapes.append((image.height, image.width))
        images_proc.append(image_preprocessor.proc(image))

    batch = torch.stack(images_proc)
    with torch.no_grad():
        # The last element of the model output is used as the prediction.
        # NOTE(review): assumed to be one map per input image, saved as-is --
        # confirm its value range suits ToPILImage.
        scaled_preds_tensor = model(batch.to(device))[-1]

    for image_shape, pred_tensor, save_path in zip(image_shapes, scaled_preds_tensor, save_paths):
        # .cpu() is a no-op for tensors that already live on the CPU.
        pred_tensor = torch.nn.functional.interpolate(
            pred_tensor.cpu().unsqueeze(0),
            size=image_shape,
            mode='bilinear',
            align_corners=True,
        ).squeeze()
        save_tensor_img(save_path, pred_tensor)

    zip_file_path = os.path.join(save_dir, "{}.zip".format(save_dir))
    with zipfile.ZipFile(zip_file_path, 'w') as zipf:
        for file in save_paths:
            # Store each PNG flat in the archive, without the directory prefix.
            zipf.write(file, os.path.basename(file))
    return save_paths, zip_file_path
# Batch tab: takes a group of uploaded image files and shows the predicted
# maps in a gallery together with a downloadable archive.
tab_batch = gr.Interface(
    description='Upload pictures, most of which contain salient objects of the same class. Our demo will give you the binary maps of these co-salient objects :)',
    api_name="batch",
    fn=pred_maps,
    inputs=gr.File(
        file_count="multiple",
        type="filepath",
        label="Upload multiple images in a group",
    ),
    outputs=[
        gr.Gallery(label="GCoNet+'s predictions"),
        gr.File(label="Download predicted maps."),
    ],
)

# Top-level app: a tabbed container holding the single batch tab.
demo = gr.TabbedInterface(
    interface_list=[tab_batch],
    tab_names=['batch'],
    title="Online demo for `GCoNet+: A Stronger Group Collaborative Co-Salient Object Detector (T-PAMI 2023)`",
)

# Start the web server as soon as the script runs.
demo.launch(debug=True)
|