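# NOTE: the original first lines here were page-scrape residue ("Spaces: Build error"),
# apparently a status banner from the hosting page rather than part of the program.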
import csv
import os
import random
import sys
from itertools import product

import gdown
import gradio as gr
import matplotlib
import matplotlib.patches as patches
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from matplotlib import pyplot as plt
from matplotlib.figure import Figure
from matplotlib.patches import ConnectionPatch
from PIL import Image
from torch.utils.data import DataLoader

from common.evaluation import Evaluator
from common.logger import AverageMeter, Logger
from data import download
from model import chmnet
from model.base.geometry import Geometry
# Lift the csv module's per-field size cap to the largest value Python reports.
csv.field_size_limit(sys.maxsize)

# Downloading the Model
# The checkpoint is fetched to the working directory and checked against `md5`.
# NOTE(review): `cached_download` and its `md5=` keyword depend on the installed
# gdown version -- confirm against the pinned requirement before upgrading.
_CHECKPOINT_URL = "https://drive.google.com/u/0/uc?id=1zsJRlAsoOn5F0GTCprSFYwDDfV85xDy6&export=download"
_CHECKPOINT_PATH = "pas_psi.pt"
md5 = "6b7b4d7bad7f89600fac340d6aa7708b"

gdown.cached_download(
    url=_CHECKPOINT_URL,
    path=_CHECKPOINT_PATH,
    md5=md5,
    quiet=False,
)
# Model Initialization
# Run configuration. This file reads "alpha", "img_size", "ktype" and "load";
# the remaining entries are carried along unchanged.
args = {
    "alpha": [0.05, 0.1],
    "benchmark": "pfpascal",
    "bsz": 90,
    "datapath": "../Datasets_CHM",
    "img_size": 240,
    "ktype": "psi",
    "load": "pas_psi.pt",
    "thres": "img",
}

# Build the network, restore the checkpoint onto the CPU and switch the model
# to evaluation mode.
# NOTE(review): torch.load unpickles the checkpoint file; it is only as
# trustworthy as the download that produced it.
_CPU = torch.device("cpu")
model = chmnet.CHMNet(args["ktype"])
model.load_state_dict(torch.load(args["load"], map_location=_CPU))
model.eval()

# Project-level helpers are configured once, from the same settings.
Evaluator.initialize(args["alpha"])
Geometry.initialize(img_size=args["img_size"])
# Transforms
_IMG_SIZE = args["img_size"]
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]


def _resize_and_center_crop():
    """Return fresh resize + square centre-crop steps shared by both pipelines."""
    return [
        transforms.Resize(_IMG_SIZE),
        transforms.CenterCrop((_IMG_SIZE, _IMG_SIZE)),
    ]


# Model input: resized/cropped image converted to a normalized tensor.
chm_transform = transforms.Compose(
    _resize_and_center_crop()
    + [
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)

# Display input: the same geometry, but the image is left as an image for plotting.
chm_transform_plot = transforms.Compose(_resize_and_center_crop())
# A Helper Function
def to_np(x):
    """Return the data of tensor ``x`` as a NumPy array, moved to the CPU first."""
    return x.data.to("cpu").numpy()


# Colors for Plotting
# `matplotlib.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
# `pyplot.get_cmap` performs the same lookup by name and is still supported.
cmap = plt.get_cmap("Spectral")
rgba = cmap(0.5)
# 49 evenly spaced samples of the colormap, one per point of the 7x7 grid.
colors = [cmap(k / 49.0) for k in range(49)]
# CHM MODEL
def run_chm(
    source_image,
    target_image,
    selected_points,
    number_src_points,
    chm_transform,
    display_transform,
):
    """Transfer keypoints from the source image to the target image and plot both.

    Args:
        source_image: Image the keypoints are defined on. It is passed to both
            ``chm_transform`` and ``display_transform``.
        target_image: Image the keypoints are transferred to.
        selected_points: Array-like read as two rows (x coordinates, then y
            coordinates) expressed in the ``args["img_size"]`` pixel frame.
        number_src_points: Number of leading points to transfer and draw.
        chm_transform: Callable producing the model input tensor for an image.
        display_transform: Callable producing the image drawn in the plot; its
            result must expose ``.size`` as ``(width, height)``.

    Returns:
        A Matplotlib figure with the source points on the left axes and the
        predicted target points on the right axes.
    """
    # Convert to Tensor (a leading batch dimension of one is added to each).
    src_img_tnsr = chm_transform(source_image).unsqueeze(0)
    tgt_img_tnsr = chm_transform(target_image).unsqueeze(0)
    keypoints = torch.tensor(selected_points).unsqueeze(0)
    n_pts = torch.tensor(np.asarray([number_src_points]))

    # RUN CHM
    with torch.no_grad():
        corr_matrix = model(src_img_tnsr, tgt_img_tnsr)
        prd_kps = Geometry.transfer_kps(corr_matrix, keypoints, n_pts, normalized=False)

    src_points = keypoints[0].squeeze(0).squeeze(0).numpy()
    tgt_points = prd_kps[0].squeeze(0).squeeze(0).cpu().numpy()

    # Each display image is computed once and reused for sizing and drawing.
    src_display = display_transform(source_image)
    tgt_display = display_transform(target_image)

    n = number_src_points
    img_size = args["img_size"]

    # Source points are given in the img_size frame; rescale them to display pixels.
    # reshape(-1, 2) keeps the array two-dimensional even when no point is requested.
    w, h = src_display.size
    src_points_converted = np.asarray(
        [
            [int(x * w / img_size), int(y * h / img_size)]
            for x, y in zip(src_points[0][:n], src_points[1][:n])
        ],
        dtype=int,
    ).reshape(-1, 2)

    # Predicted points are mapped from the [-1, 1] range onto display pixels.
    w, h = tgt_display.size
    tgt_points_converted = np.asarray(
        [
            [int(((x + 1) / 2.0) * w), int(((y + 1) / 2.0) * h)]
            for x, y in zip(tgt_points[0][:n], tgt_points[1][:n])
        ],
        dtype=int,
    ).reshape(-1, 2)

    # PLOT
    # The figure is built directly instead of through pyplot: pyplot keeps every
    # figure it creates alive until it is closed, which leaks one figure per
    # request in a long-running app.
    fig = Figure(figsize=(12, 8))
    ax = fig.subplots(nrows=1, ncols=2)

    # Source image plot
    ax[0].imshow(src_display)
    ax[0].scatter(
        src_points_converted[:, 0],
        src_points_converted[:, 1],
        c="blue",
        edgecolors="white",
        s=50,
        label="Source points",
    )
    ax[0].set_title("Source Image with Selected Points")
    ax[0].set_xticks([])
    ax[0].set_yticks([])

    # Target image plot
    ax[1].imshow(tgt_display)
    ax[1].scatter(
        tgt_points_converted[:, 0],
        tgt_points_converted[:, 1],
        c="red",
        edgecolors="white",
        s=50,
        label="Target points",
    )
    ax[1].set_title("Target Image with Corresponding Points")
    ax[1].set_xticks([])
    ax[1].set_yticks([])

    # Adding labels to points: matching indices mark corresponding points.
    for i, (src, tgt) in enumerate(zip(src_points_converted, tgt_points_converted)):
        ax[0].text(*src, str(i), color="white", bbox=dict(facecolor="black", alpha=0.5))
        ax[1].text(*tgt, str(i), color="black", bbox=dict(facecolor="white", alpha=0.7))

    # Adding legend
    ax[0].legend(loc="lower right", bbox_to_anchor=(1, -0.075))
    ax[1].legend(loc="lower right", bbox_to_anchor=(1, -0.075))

    fig.tight_layout()
    fig.subplots_adjust(wspace=0.1, hspace=0.1)
    # Backslashes are escaped explicitly: "\i" and "\_" are invalid string
    # escapes, while "\n" must remain a real newline.
    fig.suptitle("CHM Correspondences\nUsing $\\it{pas\\_psi.pt}$ Weights ", fontsize=16)
    return fig
# Wrapper
def generate_correspondences(
    sousrce_image, target_image, min_x=1, max_x=100, min_y=1, max_y=100, grid_size=7
):
    """Run CHM on a regular grid of source keypoints and return the plot.

    Args:
        sousrce_image: Source image. The misspelled parameter name is kept so
            existing keyword callers keep working.
        target_image: Target image.
        min_x, max_x: Horizontal extent of the grid.
        min_y, max_y: Vertical extent of the grid.
        grid_size: Number of grid points per axis; the grid holds
            ``grid_size ** 2`` keypoints. Defaults to the original 7.

    Returns:
        The figure produced by ``run_chm``.

    Raises:
        gr.Error: If either image is missing.
        ValueError: If ``grid_size`` is smaller than 1.
    """
    if sousrce_image is None or target_image is None:
        # Without this check the failure surfaces deep inside the image
        # transform with a message that means nothing to the user.
        raise gr.Error("Please provide both a source and a target image.")

    grid_size = int(grid_size)
    if grid_size < 1:
        raise ValueError(f"grid_size must be at least 1, got {grid_size}")

    xs = np.linspace(min_x, max_x, grid_size)
    ys = np.linspace(min_y, max_y, grid_size)
    # Row 0 holds the x coordinates and row 1 the y coordinates, x varying slowest.
    new_points = np.asarray(list(product(xs, ys)), dtype=np.float64).T

    return run_chm(
        sousrce_image,
        target_image,
        selected_points=new_points,
        number_src_points=grid_size * grid_size,
        chm_transform=chm_transform,
        display_transform=chm_transform_plot,
    )
def _coordinate_slider(label, value):
    """Create one grid-bound slider: whole steps from 1 to 240, starting at ``value``."""
    return gr.Slider(minimum=1, maximum=240, step=1, value=value, label=label)


# NOTE(review): the original indentation was lost; the examples and the Run
# button are placed directly under the Blocks container, below the plot row.
# Confirm this matches the intended layout.
with gr.Blocks() as demo:
    # The Markdown text is kept flush-left on purpose: leading spaces would
    # turn the lines into a code block.
    gr.Markdown(
        """
# Correspondence Matching with Convolutional Hough Matching Networks
Performs keypoint transform from a 7x7 grid on the source image to the target image. Use the sliders to adjust the grid.
[Original Paper](https://arxiv.org/abs/2103.16831) - [Github Page](https://github.com/juhongm999/chm)
"""
    )

    with gr.Row():
        # Source and target images, handed to the callback as PIL images.
        image1 = gr.Image(height=240, width=240, type="pil", label="Source Image")
        image2 = gr.Image(height=240, width=240, type="pil", label="Target Image")

    with gr.Row():
        # Bounds of the keypoint grid on the source image.
        min_x = _coordinate_slider("Min X", 15)
        max_x = _coordinate_slider("Max X", 215)
        min_y = _coordinate_slider("Min Y", 15)
        max_y = _coordinate_slider("Max Y", 215)

    with gr.Row():
        output_plot = gr.Plot()

    # The same six components feed both the examples and the Run callback,
    # in the positional order expected by generate_correspondences.
    grid_inputs = [image1, image2, min_x, max_x, min_y, max_y]

    # Every example pair uses the same grid bounds.
    example_pairs = [
        ("./examples/sample1.jpeg", "./examples/sample2.jpeg"),
        (
            "./examples/Red_Winged_Blackbird_0012_6015.jpg",
            "./examples/Red_Winged_Blackbird_0025_5342.jpg",
        ),
        (
            "./examples/Yellow_Headed_Blackbird_0026_8545.jpg",
            "./examples/Yellow_Headed_Blackbird_0020_8549.jpg",
        ),
    ]
    gr.Examples(
        [[src, tgt, 17, 223, 17, 223] for src, tgt in example_pairs],
        inputs=grid_inputs,
    )

    run_btn = gr.Button("Run")
    run_btn.click(
        generate_correspondences,
        inputs=grid_inputs,
        outputs=output_plot,
    )

demo.launch()