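"""Gradio app: Flower Content-Based Image Retrieval (CBIR).

A query image is embedded with a Keras feature extractor and matched against a
pre-built Annoy index of flower images; the nearest neighbours are shown in a
gallery together with the predicted class.
"""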
import os
from glob import glob

import numpy as np
import pandas as pd
import gradio as gr
import tensorflow as tf
from tensorflow import keras
from annoy import AnnoyIndex

def load_image(image_path):
    """Read an image from disk, resize it to 224x224 and scale pixels to [0, 1]."""
    image = tf.io.read_file(image_path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, (224, 224))
    # resize already returns float32, so the cast is a no-op; dividing by 255 maps pixels to [0, 1]
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = image / 255.
    return image.numpy()

# Specify Database Path
database_path = './FlowerSubset'

# Create Example Images (one per class)
class_names = []
with open('./Flower-ClassNames.txt', mode='r') as names:
    # class names are stored comma-separated; the final split element is discarded
    class_names = names.read().split(',')[:-1]

example_image_paths = [
    glob(os.path.join(database_path, name, '*'))[0]
    if name != 'Sunflower' else
    './FlowerSubset/Sunflower/Sunflower-Train (320).jpeg'
    for name in class_names
]
example_images = [load_image(path) for path in example_image_paths]

# Load Feature Extractor (maps a 224x224x3 image to a 256-dim embedding)
feature_extractor_path = './Flower-FeatureExtractor.keras'
feature_extractor = keras.models.load_model(
    feature_extractor_path, compile=False)

# Load the pre-built Annoy index (256-dim vectors, angular/cosine distance)
index_path = './FlowerSubset.ann'
annoy_index = AnnoyIndex(256, 'angular')
annoy_index.load(index_path)
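
# Illustrative sketch (not used at runtime): one way an index like
# './FlowerSubset.ann' could have been built offline with the same feature
# extractor. The actual build script is not part of this Space, and
# n_trees=10 is an arbitrary assumption. Item ids must follow the row order
# of './Flowers.csv' so the metadata lookup in similarity_search lines up.
def build_annoy_index(image_paths, output_path, n_trees=10):
    """Hypothetical helper: embed each image and add it to an angular Annoy index."""
    index = AnnoyIndex(256, 'angular')
    for item_id, path in enumerate(image_paths):
        # 256-dim embedding from the same extractor used for query images
        embedding = feature_extractor.predict(
            load_image(path)[np.newaxis, ...], verbose=0)[0]
        index.add_item(item_id, embedding)
    index.build(n_trees)  # more trees -> better recall at the cost of index size
    index.save(output_path)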

def similarity_search(
        query_image, num_images=5, *_,  # *_ absorbs extra positional inputs; the rest are keyword-only
        feature_extractor=feature_extractor,
        annoy_index=annoy_index,
        database_path=database_path,
        metadata_path='./Flowers.csv'
):
    """Embed the query image, fetch its `num_images` nearest neighbours from the
    Annoy index and return the predicted class, a gallery of matches and their paths."""
    # Gradio delivers uint8 images in [0, 255]; rescale to [0, 1] to match the
    # preprocessing in load_image
    if np.max(query_image) > 1:
        query_image = query_image / 255.

    query_vector = feature_extractor.predict(
        query_image[np.newaxis, ...], verbose=0)[0]

    # Compute nearest neighbours in the Annoy index
    nearest_neighbors = annoy_index.get_nns_by_vector(query_vector, num_images)

    # Load metadata, keep only the retrieved rows; the closest match's class
    # is reported as the predicted class
    metadata = pd.read_csv(metadata_path, index_col=0)
    metadata = metadata.iloc[nearest_neighbors]
    closest_class = metadata.class_name.values[0]

    # Resolve and load the similar images from the database folder
    similar_images_paths = [
        os.path.join(database_path, class_name, file_name)
        for class_name, file_name in zip(metadata.class_name.values, metadata.file_name.values)
    ]
    similar_images = [load_image(img) for img in similar_images_paths]

    image_gallery = gr.Gallery(
        value=similar_images,
        label='Similar Images',
        object_fit='fill',
        preview=True,
        visible=True,
        height='50vh'
    )
    return closest_class, image_gallery, similar_images_paths
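
# Illustrative usage outside the Gradio UI, e.g. from a Python shell, using the
# Sunflower example image referenced above (not executed by this app):
#   img = load_image('./FlowerSubset/Sunflower/Sunflower-Train (320).jpeg')
#   label, gallery, paths = similarity_search(img, num_images=5)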

# Gradio Application
with gr.Blocks(theme='soft') as app:
    gr.Markdown("# Flower - Content Based Image Retrieval (CBIR)")
    gr.Markdown(
        f"The model only supports: {', '.join(class_names[:-1])} and {class_names[-1]}")
    gr.Markdown(
        "Disclaimer: The model might suggest incorrect images; try using a different image.")

    with gr.Row(equal_height=True):
        # Image Input
        query_image = gr.Image(
            label='Query Image',
            sources=['upload', 'clipboard'],
            height='50vh'
        )
        # Output Gallery Display (hidden until the first search)
        output_gallery = gr.Gallery(visible=False)

    # Hidden output for similar image paths
    similar_paths_output = gr.Textbox(visible=False)

    with gr.Row(equal_height=True):
        # Predicted Class
        pred_class = gr.Textbox(
            label='Predicted Class', placeholder='Let the model think!!...')
        # Number of images to search
        n_images = gr.Slider(
            value=10,
            label='Number of images to search',
            minimum=1,
            maximum=99,
            step=1
        )

    # Search Button
    search_btn = gr.Button('Search')

    # Example Images
    examples = gr.Examples(
        examples=example_images,
        inputs=query_image,
        label='Something similar to me??',
    )

    # Input - On Change
    query_image.change(
        fn=similarity_search,
        inputs=[query_image, n_images],
        outputs=[pred_class, output_gallery, similar_paths_output]
    )

    # Search - On Click
    search_btn.click(
        fn=similarity_search,
        inputs=[query_image, n_images],
        outputs=[pred_class, output_gallery, similar_paths_output]
    )

if __name__ == '__main__':
    app.launch()