import csv
import sys
from itertools import product

import gdown
import gradio as gr
import matplotlib
import numpy as np
import torch
import torchvision.transforms as transforms
from matplotlib import pyplot as plt

from common.evaluation import Evaluator
from model import chmnet
from model.base.geometry import Geometry
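# Lift the CSV field size limit to the platform maximum.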
csv.field_size_limit(sys.maxsize)
# Downloading the Model
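# gdown.cached_download verifies the md5 checksum and skips the download
# when pas_psi.pt is already present locally.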
md5 = "6b7b4d7bad7f89600fac340d6aa7708b"
gdown.cached_download(
    url="https://drive.google.com/u/0/uc?id=1zsJRlAsoOn5F0GTCprSFYwDDfV85xDy6&export=download",
    path="pas_psi.pt",
    quiet=False,
    md5=md5,
)
# Model Initialization
args = {
    "alpha": [0.05, 0.1],
    "benchmark": "pfpascal",
    "bsz": 90,
    "datapath": "../Datasets_CHM",
    "img_size": 240,
    "ktype": "psi",
    "load": "pas_psi.pt",
    "thres": "img",
}
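# "alpha" holds the PCK evaluation thresholds, "ktype" selects the CHM kernel
# variant matching the downloaded weights, and "img_size" is the square input
# resolution the network expects.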
model = chmnet.CHMNet(args["ktype"])
model.load_state_dict(torch.load(args["load"], map_location=torch.device("cpu")))
Evaluator.initialize(args["alpha"])
Geometry.initialize(img_size=args["img_size"])
model.eval()
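# Inference only: eval() disables training-time behaviors such as dropout
# and batch-norm updates.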
# Transforms
chm_transform = transforms.Compose(
    [
        transforms.Resize(args["img_size"]),
        transforms.CenterCrop((args["img_size"], args["img_size"])),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
chm_transform_plot = transforms.Compose(
    [
        transforms.Resize(args["img_size"]),
        transforms.CenterCrop((args["img_size"], args["img_size"])),
    ]
)
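# chm_transform_plot mirrors the resize/crop of chm_transform but skips
# ToTensor/Normalize, so displayed images stay aligned with model coordinates.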
# A Helper Function
def to_np(x):
    return x.detach().cpu().numpy()
# Colors for Plotting: one color per point of the 7x7 keypoint grid
cmap = matplotlib.colormaps["Spectral"]  # colormap registry lookup (matplotlib >= 3.5)
colors = [cmap(k / 49.0) for k in range(49)]
# CHM MODEL
def run_chm(
    source_image,
    target_image,
    selected_points,
    number_src_points,
    chm_transform,
    display_transform,
):
    # Convert images to normalized tensors with a batch dimension
    src_img_tnsr = chm_transform(source_image).unsqueeze(0)
    tgt_img_tnsr = chm_transform(target_image).unsqueeze(0)

    keypoints = torch.tensor(selected_points).unsqueeze(0)
    n_pts = torch.tensor(np.asarray([number_src_points]))

    # Run CHM
    with torch.no_grad():
        corr_matrix = model(src_img_tnsr, tgt_img_tnsr)
        prd_kps = Geometry.transfer_kps(corr_matrix, keypoints, n_pts, normalized=False)
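    # transfer_kps returns the predicted keypoints in normalized [-1, 1]
    # coordinates, hence the (x + 1) / 2 rescaling below.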
    # VISUALIZATION
    src_points = keypoints[0].squeeze(0).squeeze(0).numpy()
    tgt_points = prd_kps[0].squeeze(0).squeeze(0).cpu().numpy()

    # Map source keypoints from model coordinates to display-image pixels
    src_points_converted = []
    w, h = display_transform(source_image).size
    for x, y in zip(src_points[0], src_points[1]):
        src_points_converted.append(
            [int(x * w / args["img_size"]), int(y * h / args["img_size"])]
        )
    src_points_converted = np.asarray(src_points_converted[:number_src_points])

    # Map predicted target keypoints from [-1, 1] to display-image pixels
    tgt_points_converted = []
    w, h = display_transform(target_image).size
    for x, y in zip(tgt_points[0], tgt_points[1]):
        tgt_points_converted.append([int((x + 1) / 2.0 * w), int((y + 1) / 2.0 * h)])
    tgt_points_converted = np.asarray(tgt_points_converted[:number_src_points])

    # 7x7 grid-cell indices of the predictions (computed but unused in the plot)
    tgt_grid = []
    for x, y in zip(tgt_points[0], tgt_points[1]):
        tgt_grid.append([int((x + 1) / 2.0 * 7), int((y + 1) / 2.0 * 7)])
    # PLOT
    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 8))

    ax[0].imshow(display_transform(source_image))
    ax[0].scatter(
        src_points_converted[:, 0],
        src_points_converted[:, 1],
        c=colors[:number_src_points],
    )
    ax[0].set_title("Source")
    ax[0].set_xticks([])
    ax[0].set_yticks([])

    ax[1].imshow(display_transform(target_image))
    ax[1].scatter(
        tgt_points_converted[:, 0],
        tgt_points_converted[:, 1],
        c=colors[:number_src_points],
    )
    ax[1].set_title("Target")
    ax[1].set_xticks([])
    ax[1].set_yticks([])

    # Label each keypoint with its index; bounded by number_src_points so the
    # converted arrays are never indexed past their length
    for TL in range(number_src_points):
        ax[0].text(
            x=src_points_converted[TL][0],
            y=src_points_converted[TL][1],
            s=str(TL),
            fontdict=dict(color="red", size=11),
        )
        ax[1].text(
            x=tgt_points_converted[TL][0],
            y=tgt_points_converted[TL][1],
            s=str(TL),
            fontdict=dict(color="orange", size=11),
        )

    fig.suptitle("CHM Correspondences\nUsing $\\it{pas\\_psi.pt}$ Weights", fontsize=16)
    plt.tight_layout()
    return fig
# Wrapper
def generate_correspondences(
    source_image, target_image, min_x=1, max_x=100, min_y=1, max_y=100
):
    # Build a 7x7 lattice of source keypoints inside the given bounds
    A = np.linspace(min_x, max_x, 7)
    B = np.linspace(min_y, max_y, 7)
    point_list = list(product(A, B))
    new_points = np.asarray(point_list, dtype=np.float64).T
    return run_chm(
        source_image,
        target_image,
        selected_points=new_points,
        number_src_points=49,
        chm_transform=chm_transform,
        display_transform=chm_transform_plot,
    )
# Gradio App
main = gr.Interface(
    fn=generate_correspondences,
    inputs=[
        gr.Image(type="pil"),  # resizing is handled by chm_transform, so no fixed shape is needed
        gr.Image(type="pil"),
        gr.Slider(minimum=1, maximum=240, step=1, value=15, label="Min X"),
        gr.Slider(minimum=1, maximum=240, step=1, value=215, label="Max X"),
        gr.Slider(minimum=1, maximum=240, step=1, value=15, label="Min Y"),
        gr.Slider(minimum=1, maximum=240, step=1, value=215, label="Max Y"),
    ],
    outputs="plot",
    allow_flagging="never",
    examples=[
        ["./examples/sample1.jpeg", "./examples/sample2.jpeg", 17, 223, 17, 223],
        [
            "./examples/Red_Winged_Blackbird_0012_6015.jpg",
            "./examples/Red_Winged_Blackbird_0025_5342.jpg",
            17,
            223,
            17,
            223,
        ],
        [
            "./examples/Yellow_Headed_Blackbird_0026_8545.jpg",
            "./examples/Yellow_Headed_Blackbird_0020_8549.jpg",
            17,
            223,
            17,
            223,
        ],
    ],
)
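# Blocks wraps the Interface so the Markdown description renders above the
# tabbed UI.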
blocks = gr.Blocks()

with blocks:
    gr.Markdown(
        """
# Correspondence Matching with Convolutional Hough Matching Networks
Transfers keypoints from a 7x7 grid on the source image to the target image. Use the sliders to adjust the grid.
[Original Paper](https://arxiv.org/abs/2103.16831) - [GitHub Page](https://github.com/juhongm999/chm)
"""
    )
    gr.TabbedInterface([main], ["Main"])

# enable_queue was dropped from launch() in newer Gradio; use blocks.queue() if queuing is needed.
blocks.launch(debug=True)