import csv
import sys

import gradio as gr
import numpy as np
import skimage.transform
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from PIL import Image

csv.field_size_limit(sys.maxsize)
def compute_spatial_similarity(conv1, conv2):
    """
    Takes the last convolutional layer from two images, computes the pooled
    output feature, and then generates the spatial similarity map for both
    images.
    """
    # Flatten the 7x7 spatial grid: (C, 7, 7) -> (49, C).
    conv1 = conv1.reshape(-1, 7 * 7).T
    conv2 = conv2.reshape(-1, 7 * 7).T

    # Pooled (spatially averaged) descriptor for each image.
    pool1 = np.mean(conv1, axis=0)
    pool2 = np.mean(conv2, axis=0)
    out_sz = (int(np.sqrt(conv1.shape[0])), int(np.sqrt(conv1.shape[0])))

    # Normalize each spatial feature by the pooled feature's norm and the
    # number of spatial locations.
    conv1_normed = conv1 / np.linalg.norm(pool1) / conv1.shape[0]
    conv2_normed = conv2 / np.linalg.norm(pool2) / conv2.shape[0]

    # Pairwise dot products between every location in image 1 and image 2.
    im_similarity = np.zeros((conv1_normed.shape[0], conv1_normed.shape[0]))
    for zz in range(conv1_normed.shape[0]):
        repPx = np.tile(conv1_normed[zz, :], (conv1_normed.shape[0], 1))
        im_similarity[zz, :] = np.multiply(repPx, conv2_normed).sum(axis=1)

    # Summing over the other image's locations gives each similarity map.
    similarity1 = np.reshape(np.sum(im_similarity, axis=1), out_sz)
    similarity2 = np.reshape(np.sum(im_similarity, axis=0), out_sz)
    return similarity1, similarity2
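
# Example (sketch): compute_spatial_similarity expects two (C, 7, 7) activation
# maps; the random arrays below are placeholder inputs, not real features.
# a = np.random.rand(2048, 7, 7)
# b = np.random.rand(2048, 7, 7)
# s1, s2 = compute_spatial_similarity(a, b)  # s1.shape == s2.shape == (7, 7)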

# Preprocessing transforms: one for display, one for the network input
display_transform = transforms.Compose(
    [transforms.Resize(256), transforms.CenterCrop((224, 224))]
)
imagenet_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)
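
# Example (sketch): imagenet_transform maps a PIL image to a normalized
# (3, 224, 224) float tensor; display_transform keeps a PIL image for plotting.
# img = Image.open("./examples/Q.jpg")
# x = imagenet_transform(img)  # torch.Size([3, 224, 224])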

# Get Layer 4
class Wrapper(torch.nn.Module):
    """Wraps a ResNet and captures its layer4 activations via a forward hook."""

    def __init__(self, model):
        super(Wrapper, self).__init__()
        self.model = model
        self.layer4_outputs = None

        def fw_hook(module, input, output):
            # Save the layer4 activations on every forward pass.
            self.layer4_outputs = output

        self.model.layer4.register_forward_hook(fw_hook)

    def forward(self, input):
        _ = self.model(input)  # run the full model; the hook stores layer4
        return self.layer4_outputs

    def __repr__(self):
        return "Wrapper"

def get_layer4(input_image):
    # Note: the model is re-created on every call; simple, but slow.
    l4_model = models.resnet50(pretrained=True)
    # l4_model = l4_model.cuda()
    l4_model.eval()
    wrapped_model = Wrapper(l4_model)
    with torch.no_grad():
        data = imagenet_transform(input_image).unsqueeze(0)  # (1, 3, 224, 224)
        # data = data.cuda()
        reference_layer4 = wrapped_model(data)
    return reference_layer4.data.to("cpu").numpy()
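
# Example (sketch): for a 224x224 input, ResNet-50's layer4 output has shape
# (1, 2048, 7, 7), so get_layer4(...).squeeze() yields the (2048, 7, 7)
# activations that compute_spatial_similarity expects.
# feats = get_layer4(Image.open("./examples/Q.jpg"))  # (1, 2048, 7, 7)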

def NormalizeData(data):
    # Min-max normalize to the range [0, 1].
    return (data - np.min(data)) / (np.max(data) - np.min(data))
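
# Example (sketch): NormalizeData rescales any array to [0, 1], e.g.
# NormalizeData(np.array([2.0, 4.0, 6.0])) -> array([0. , 0.5, 1. ])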

# Visualization
def visualize_similarities(q, n):
    image1 = Image.fromarray(q)
    image2 = Image.fromarray(n)

    a = get_layer4(image1).squeeze()
    b = get_layer4(image2).squeeze()
    sim1, sim2 = compute_spatial_similarity(a, b)
    sim1 = NormalizeData(sim1)
    sim2 = NormalizeData(sim2)

    # Side-by-side heatmap overlays for the query and its nearest neighbor.
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    axes[0].imshow(display_transform(image1))
    im1 = axes[0].imshow(
        skimage.transform.resize(sim1, (224, 224)),
        alpha=0.5,
        cmap="jet",
        vmin=0,
        vmax=1,
    )
    axes[1].imshow(display_transform(image2))
    im2 = axes[1].imshow(
        skimage.transform.resize(sim2, (224, 224)),
        alpha=0.5,
        cmap="jet",
        vmin=0,
        vmax=1,
    )
    axes[0].set_axis_off()
    axes[1].set_axis_off()
    fig.colorbar(im1, ax=axes[0])
    fig.colorbar(im2, ax=axes[1])
    plt.tight_layout()

    q_image = display_transform(image1)
    nearest_image = display_transform(image2)

    # Make a binarized version of the query overlay: values above 0.5 become 1,
    # the rest 0.
    fig2, ax = plt.subplots(1, figsize=(5, 5))
    ax.imshow(display_transform(image1))
    sim1_bin = np.where(sim1 > 0.5, 1, 0)
    sim2_bin = np.where(sim2 > 0.5, 1, 0)  # computed but not currently displayed
    ax.imshow(
        skimage.transform.resize(sim1_bin, (224, 224)),
        alpha=1,
        cmap="binary",
        vmin=0,
        vmax=1,
    )
    return fig, q_image, nearest_image, fig2
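
# Example (sketch): visualize_similarities takes two HxWx3 uint8 arrays, as
# supplied by the Gradio "image" inputs below.
# q = np.array(Image.open("./examples/Q.jpg"))
# n = np.array(Image.open("./examples/1.jpg"))
# fig, q_img, n_img, fig2 = visualize_similarities(q, n)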

# GRADIO APP
main = gr.Interface(
    fn=visualize_similarities,
    inputs=["image", "image"],
    allow_flagging="never",
    outputs=["plot", "image", "image", "plot"],
    cache_examples=True,
    enable_queue=False,
    examples=[
        [
            "./examples/Red_Winged_Blackbird_0012_6015.jpg",
            "./examples/Red_Winged_Blackbird_0025_5342.jpg",
        ],
        ["./examples/Q.jpg", "./examples/1.jpg"],
    ],
)
# main.launch()

blocks = gr.Blocks()
with blocks:
    gr.Markdown(
        """
        # Visualizing Deep Similarity Networks
        A quick demo to visualize the similarity between two images.
        [Original Paper](https://arxiv.org/pdf/1901.00536.pdf) - [Github Page](https://github.com/GWUvision/Similarity-Visualization)
        """
    )
    gr.TabbedInterface([main], ["Main"])

blocks.launch(debug=True)