Spaces:
Running
Running
import os | |
import numpy as np | |
import pandas as pd | |
import gradio as gr | |
import tensorflow as tf | |
from annoy import AnnoyIndex | |
from tensorflow import keras | |
def load_image(image_path, size=(224, 224)):
    """Read an image file and return it as a float32 array scaled to [0, 1].

    Args:
        image_path: Path to an image file (JPEG, PNG, BMP or GIF; for an
            animated GIF only the first frame is used).
        size: ``(height, width)`` the image is resized to. Defaults to 224x224.

    Returns:
        ``numpy.ndarray`` of shape ``(height, width, 3)``, dtype float32, with
        values in ``[0, 1]``.
    """
    raw_bytes = tf.io.read_file(image_path)
    # decode_image handles every common format; expand_animations=False keeps
    # the result 3-D so that tf.image.resize can operate on it.
    image = tf.io.decode_image(raw_bytes, channels=3, expand_animations=False)
    # tf.image.resize returns float32 but keeps the 0..255 value range, so a
    # single division is the only scaling step needed.
    image = tf.image.resize(image, size)
    return (image / 255.).numpy()
# Root folder of the image database; images live at
# <database_path>/<class_name>/<file_name>.
database_path = './AnimalSubset'

# Class names shown in the UI. The file is comma-separated; entries are
# stripped and empty ones dropped, so a trailing comma or newline is tolerated
# without losing the last real class name.
with open('./Animal-ClassNames.txt', mode='r', encoding='utf-8') as names:
    class_names = [name.strip() for name in names.read().split(',') if name.strip()]

# One representative (class folder, file name) pair per class, used as
# clickable examples in the UI.
_example_files = [
    ('Beetle', 'Beetle-Train (101).jpeg'),
    ('Butterfly', 'Butterfly-train (1042).jpeg'),
    ('Cat', 'Cat-Train (1004).jpeg'),
    ('Cow', 'Cow-Train (1022).jpeg'),
    ('Dog', 'Dog-Train (1144).jpeg'),
    ('Elephant', 'Elephant-Train (1043).jpeg'),
    ('Gorilla', 'Gorilla (1045).jpeg'),
    ('Hippo', 'Hippo - Train (1133).jpeg'),
    ('Lizard', 'Lizard-Train (161).jpeg'),
    ('Monkey', 'M (224).jpeg'),
    ('Mouse', 'Mouse-Train (1225).jpeg'),
    ('Panda', 'Panda (1992).jpeg'),
    ('Spider', 'Spider-Train (1191).jpeg'),
    ('Tiger', 'Tiger (1020).jpeg'),
    ('Zebra', 'Zebra (975).jpeg'),
]
# Built the same way similarity_search builds result paths.
example_image_paths = [
    os.path.join(database_path, class_name, file_name)
    for class_name, file_name in _example_files
]
example_images = [load_image(path) for path in example_image_paths]
# ---- Model and nearest-neighbour index --------------------------------------
# The Keras model that turns an image into the embedding vector used for the
# search. compile=False: it is only used for inference, so no training
# configuration is restored.
feature_extractor_path = './Animal-FeatureExtractor.keras'
feature_extractor = keras.models.load_model(filepath=feature_extractor_path, compile=False)

# Pre-built Annoy index over the database embeddings, using angular distance.
# NOTE(review): 256 must equal the feature extractor's output size — confirm.
index_path = './AnimalSubset.ann'
annoy_index = AnnoyIndex(256, 'angular')
annoy_index.load(index_path)
def similarity_search(
    query_image, num_images=5, *_,
    feature_extractor=feature_extractor,
    annoy_index=annoy_index,
    database_path=database_path,
    metadata_path='./Animals.csv',
    image_size=(224, 224),
):
    """Find the database images most similar to ``query_image``.

    Args:
        query_image: RGB image array of shape (H, W, 3). Either integer pixels
            in 0..255 (presumably what the Gradio image input delivers) or
            floats already in [0, 1] (what ``load_image`` returns). ``None``
            (e.g. the input was cleared) resets the outputs instead of searching.
        num_images: Number of nearest neighbours to return. Cast to ``int``
            because Annoy requires an integer count.
        feature_extractor: Model mapping a batch of images to embedding vectors.
        annoy_index: Annoy index built over the database embeddings.
        database_path: Root folder laid out as ``<class_name>/<file_name>``.
        metadata_path: CSV with ``class_name`` and ``file_name`` columns. Rows
            are selected by position, so row i is assumed to describe Annoy
            item i.
        image_size: ``(height, width)`` the query is resized to before
            embedding. It should match the size used when building the index
            (``load_image`` uses 224x224).

    Returns:
        ``(closest_class, gallery, similar_images_paths)``: the class of the
        nearest neighbour, a visible ``gr.Gallery`` of the matches, and the
        list of their file paths.
    """
    if query_image is None:
        # Nothing to search (e.g. the image was cleared): blank the outputs
        # and hide the gallery again instead of crashing.
        return '', gr.Gallery(value=None, visible=False), ''

    image = np.asarray(query_image)
    # Integer images are 0..255 and must be scaled. Testing `max == 255`
    # would miss such images whose brightest pixel is below 255. Floats above
    # 1 are also treated as 0..255; floats within [0, 1] are left as-is.
    if np.issubdtype(image.dtype, np.integer) or image.max() > 1.0:
        image = image.astype(np.float32) / 255.
    else:
        image = image.astype(np.float32)
    # Match the fixed-size preprocessing applied to the database images.
    image = tf.image.resize(image, image_size).numpy()

    query_vector = feature_extractor.predict(image[np.newaxis, ...], verbose=0)[0]

    # Compute nearest neighbours (Annoy rejects non-integer counts).
    nearest_neighbors = annoy_index.get_nns_by_vector(query_vector, int(num_images))

    # Look up class / file name for each neighbour, nearest first.
    metadata = pd.read_csv(metadata_path, index_col=0).iloc[nearest_neighbors]
    closest_class = metadata.class_name.values[0]

    similar_images_paths = [
        os.path.join(database_path, class_name, file_name)
        for class_name, file_name in zip(metadata.class_name.values, metadata.file_name.values)
    ]
    similar_images = [load_image(path) for path in similar_images_paths]

    image_gallery = gr.Gallery(
        value=similar_images,
        label='Similar Images',
        object_fit='fill',
        preview=True,
        visible=True,
    )
    return closest_class, image_gallery, similar_images_paths
# ---- Gradio application -----------------------------------------------------
with gr.Blocks(theme='soft') as app:
    # Header: title, the classes the model knows about, and a disclaimer.
    gr.Markdown("# Animal - Content Based Image Retrieval (CBIR)")
    gr.Markdown(f"Model only supports: {', '.join(class_names[:-1])} and {class_names[-1]}")
    gr.Markdown("Disclaimer:- Model might suggest incorrect images, try using a different image.")

    # Query input next to the results gallery (hidden until a search runs).
    with gr.Row(equal_height=True):
        query_image = gr.Image(
            label='Query Image',
            sources=['upload', 'clipboard'],
            height='50vh',
        )
        output_gallery = gr.Gallery(visible=False)

    # Invisible sink for the list of result file paths.
    similar_paths_output = gr.Textbox(visible=False)

    # Predicted class, result count and a manual trigger, side by side.
    with gr.Row(equal_height=True):
        pred_class = gr.Textbox(
            label='Predicted Class',
            placeholder='Let the model think!!...',
        )
        n_images = gr.Slider(
            minimum=1,
            maximum=99,
            step=1,
            value=10,
            label='Number of images to search',
        )
        search_btn = gr.Button('Search')

    # Clickable sample images; picking one fills the query input.
    examples = gr.Examples(
        examples=example_images,
        inputs=query_image,
        label='Something similar to me??',
    )

    # Changing the query image and pressing Search run the identical search
    # and fill the same three outputs, so both triggers share one wiring.
    search_inputs = [query_image, n_images]
    search_outputs = [pred_class, output_gallery, similar_paths_output]
    for trigger in (query_image.change, search_btn.click):
        trigger(fn=similarity_search, inputs=search_inputs, outputs=search_outputs)
def main():
    """Start the Gradio server for ``app``."""
    app.launch()


# Programmatic use without the UI, e.g.:
#     pred_class, gallery, paths = similarity_search(example_images[class_names.index('Spider')])
if __name__ == '__main__':
    main()