from torch.utils.data import DataLoader
import torch
from model.base.geometry import Geometry
from common.evaluation import Evaluator
from common.logger import AverageMeter
from common.logger import Logger
from data import download
from model import chmnet
from itertools import product
import matplotlib
import matplotlib.patches as patches
from matplotlib.patches import ConnectionPatch
from matplotlib import pyplot as plt
from PIL import Image
import numpy as np
import os
import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F
import random
import gradio as gr
# Downloading the Model
torchvision.datasets.utils.download_file_from_google_drive('1zsJRlAsoOn5F0GTCprSFYwDDfV85xDy6', '.', 'pas_psi.pt')
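# The checkpoint lands in the working directory as 'pas_psi.pt' and is loaded below via args['load'].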
# Model Initialization
args = {
    'alpha': [0.05, 0.1],
    'benchmark': 'pfpascal',
    'bsz': 90,
    'datapath': '../Datasets_CHM',
    'img_size': 240,
    'ktype': 'psi',
    'load': 'pas_psi.pt',
    'thres': 'img'
}
model = chmnet.CHMNet(args['ktype'])
model.load_state_dict(torch.load(args['load'], map_location=torch.device('cpu')))
Evaluator.initialize(args['alpha'])
Geometry.initialize(img_size=args['img_size'])
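# Evaluator is configured with the alpha thresholds (PCK tolerances); Geometry with the
# 240px crop size it uses when transferring keypoints.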
model.eval()
# Transforms | |
chm_transform = transforms.Compose( | |
[transforms.Resize(args['img_size']), | |
transforms.CenterCrop((args['img_size'], args['img_size'])), | |
transforms.ToTensor(), | |
transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])]) | |
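# The Normalize mean/std are the standard ImageNet statistics used by pretrained backbones.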
chm_transform_plot = transforms.Compose([
    transforms.Resize(args['img_size']),
    transforms.CenterCrop((args['img_size'], args['img_size']))
])
# A Helper Function
to_np = lambda x: x.data.to('cpu').numpy()
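# (Moves a tensor to CPU and converts it to a NumPy array; it appears unused in this demo.)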
# Colors for Plotting
cmap = matplotlib.cm.get_cmap('Spectral')
rgba = cmap(0.5)
colors = []
for k in range(49):
    colors.append(cmap(k / 49.0))
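# One distinct Spectral color per keypoint, so the 49 grid points can be matched
# between the source and target panels by color.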
# CHM MODEL
def run_chm(source_image, target_image, selected_points, number_src_points, chm_transform, display_transform):
    # Convert to tensors
    src_img_tnsr = chm_transform(source_image).unsqueeze(0)
    tgt_img_tnsr = chm_transform(target_image).unsqueeze(0)
    keypoints = torch.tensor(selected_points).unsqueeze(0)
    n_pts = torch.tensor(np.asarray([number_src_points]))
    # RUN CHM ------------------------------------------------------------------------
    with torch.no_grad():
        corr_matrix = model(src_img_tnsr, tgt_img_tnsr)
        prd_kps = Geometry.transfer_kps(corr_matrix, keypoints, n_pts, normalized=False)
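        # prd_kps holds the keypoints transferred to the target image in normalized
        # [-1, 1] coordinates; they are converted to pixel coordinates below.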
    # VISUALIZATION
    src_points = keypoints[0].squeeze(0).squeeze(0).numpy()
    tgt_points = prd_kps[0].squeeze(0).squeeze(0).cpu().numpy()
    src_points_converted = []
    w, h = display_transform(source_image).size
    for x, y in zip(src_points[0], src_points[1]):
        src_points_converted.append([int(x * w / args['img_size']), int(y * h / args['img_size'])])
    src_points_converted = np.asarray(src_points_converted[:number_src_points])
    tgt_points_converted = []
    w, h = display_transform(target_image).size
    for x, y in zip(tgt_points[0], tgt_points[1]):
        tgt_points_converted.append([int(((x + 1) / 2.0) * w), int(((y + 1) / 2.0) * h)])
    tgt_points_converted = np.asarray(tgt_points_converted[:number_src_points])
    tgt_grid = []
    for x, y in zip(tgt_points[0], tgt_points[1]):
        tgt_grid.append([int(((x + 1) / 2.0) * 7), int(((y + 1) / 2.0) * 7)])
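    # tgt_grid snaps the predictions onto the model's 7x7 correlation grid; it is
    # computed here but not used in the plot below.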
    # PLOT
    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 8))
    ax[0].imshow(display_transform(source_image))
    ax[0].scatter(src_points_converted[:, 0], src_points_converted[:, 1], c=colors[:number_src_points])
    ax[0].set_title('Source')
    ax[0].set_xticks([])
    ax[0].set_yticks([])
    ax[1].imshow(display_transform(target_image))
    ax[1].scatter(tgt_points_converted[:, 0], tgt_points_converted[:, 1], c=colors[:number_src_points])
    ax[1].set_title('Target')
    ax[1].set_xticks([])
    ax[1].set_yticks([])
    # Label each keypoint with its index so correspondences can be matched by eye
    for TL in range(number_src_points):
        ax[0].text(x=src_points_converted[TL][0], y=src_points_converted[TL][1], s=str(TL), fontdict=dict(color='red', size=11))
        ax[1].text(x=tgt_points_converted[TL][0], y=tgt_points_converted[TL][1], s=str(TL), fontdict=dict(color='orange', size=11))
    plt.tight_layout()
    fig.suptitle('CHM Correspondences\nUsing ' + r'$\it{pas\_psi.pt}$' + ' Weights', fontsize=16)
    return fig
# Wrapper
def generate_correspondences(source_image, target_image, min_x=1, max_x=100, min_y=1, max_y=100):
    A = np.linspace(min_x, max_x, 7)
    B = np.linspace(min_y, max_y, 7)
    point_list = list(product(A, B))
    new_points = np.asarray(point_list, dtype=np.float64).T
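    # product(A, B) enumerates the 49 points of a 7x7 grid; the transpose gives
    # shape (2, 49) -- x coordinates in row 0, y coordinates in row 1 -- as run_chm expects.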
    return run_chm(source_image, target_image, selected_points=new_points, number_src_points=49, chm_transform=chm_transform, display_transform=chm_transform_plot)
# GRADIO APP
title = 'Correspondence Matching with Convolutional Hough Matching Networks'
description = 'Transfers keypoints from a 7x7 grid on the source image to the target image. Use the sliders to adjust the grid extent.'
article = "<p style='text-align: center'><a href='https://github.com/juhongm999/chm' target='_blank'>Original Github Repo</a></p>"
iface = gr.Interface(fn=generate_correspondences,
                     inputs=[gr.inputs.Image(shape=(240, 240), type='pil'),
                             gr.inputs.Image(shape=(240, 240), type='pil'),
                             gr.inputs.Slider(minimum=1, maximum=240, step=1, default=15, label='Min X'),
                             gr.inputs.Slider(minimum=1, maximum=240, step=1, default=215, label='Max X'),
                             gr.inputs.Slider(minimum=1, maximum=240, step=1, default=15, label='Min Y'),
                             gr.inputs.Slider(minimum=1, maximum=240, step=1, default=215, label='Max Y')],
                     outputs='plot',
                     enable_queue=True,
                     title=title,
                     description=description,
                     article=article,
                     examples=[['sample1.jpeg', 'sample2.jpeg', 15, 215, 15, 215]])
iface.launch()
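# Quick local sanity check with the bundled example images, bypassing the Gradio UI:
#   from PIL import Image
#   fig = generate_correspondences(Image.open('sample1.jpeg'), Image.open('sample2.jpeg'), 15, 215, 15, 215)
#   fig.savefig('correspondences.png')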