# Hugging Face Space — status: Running
import os | |
import numpy as np | |
import pandas as pd | |
import gradio as gr | |
from glob import glob | |
import tensorflow as tf | |
from annoy import AnnoyIndex | |
from tensorflow import keras | |
def load_image(image_path, size=(224, 224)):
    """Read an image file and return it as a float32 array scaled to [0, 1].

    Args:
        image_path: Path of the image file to read. JPEG, PNG, BMP and GIF
            (first frame) are accepted.
        size: Target ``(height, width)``. The image is stretched to this size
            (aspect ratio is not preserved). Defaults to ``(224, 224)``.

    Returns:
        A NumPy array of shape ``(height, width, 3)`` with values in [0, 1].
    """
    raw = tf.io.read_file(image_path)
    # decode_image dispatches on the file's format (decode_jpeg only accepts
    # JPEG). expand_animations=False guarantees a single rank-3 frame, which
    # is what tf.image.resize expects.
    image = tf.io.decode_image(raw, channels=3, expand_animations=False)
    # With the default bilinear method, resize returns float32 values still
    # on the 0-255 scale. A float->float convert_image_dtype would not
    # rescale them, so the explicit division is what maps them into [0, 1].
    image = tf.image.resize(image, size)
    return (image / 255.0).numpy()
# Specify Database Path
# One sub-directory per class, each containing that class's images.
database_path = os.path.join('.', 'FastFood-DB')

# Create Example Images
# The class-names file holds comma-separated names (historically with a
# trailing comma). Stripping whitespace and dropping empty entries keeps the
# last class even when the trailing comma is missing, and stops a trailing
# newline from leaking into a name.
with open(os.path.join('.', 'Fast Food-ClassNames.txt'), mode='r', encoding='utf-8') as names:
    class_names = [name.strip() for name in names.read().split(',') if name.strip()]

# One example image per class: the first file of each class directory.
# glob() order depends on the filesystem, so sort the results to make the
# choice deterministic.
example_image_paths = [
    sorted(glob(os.path.join(database_path, name, '*')))[0]
    for name in class_names
]
example_images = [load_image(path) for path in example_image_paths]
# --- Model and search index -------------------------------------------------
# Both artefacts are read from the current working directory.
feature_extractor_path, index_path = (
    os.path.join('.', artefact)
    for artefact in ('Fast Food-FeatureExtractor.keras', 'Fast FoodSubset.ann')
)

# The extractor is only used for inference, so it is loaded uncompiled.
feature_extractor = keras.models.load_model(feature_extractor_path, compile=False)

# Annoy needs the vector length (256) and metric ('angular') up front; both
# must match the settings the index file was built with.
annoy_index = AnnoyIndex(256, 'angular')
annoy_index.load(index_path)
def _prepare_query(query_image, size=(224, 224)):
    """Convert a Gradio query image into a model-ready float32 array.

    Applies the same steps ``load_image`` applies to database images (resize
    to ``size``, scale to [0, 1]), so query and database embeddings come from
    comparably preprocessed inputs.

    Args:
        query_image: Array-like image of shape ``(H, W)``, ``(H, W, 3)`` or
            ``(H, W, 4)``. Values are either integers on the 0-255 scale or
            floats.
        size: Target ``(height, width)``. Keep in sync with ``load_image``.

    Returns:
        A float32 NumPy array of shape ``(height, width, 3)`` in [0, 1].
    """
    image = np.asarray(query_image)
    if image.ndim == 2:
        # Grayscale: replicate into three identical channels.
        image = np.stack([image] * 3, axis=-1)
    elif image.shape[-1] == 4:
        # RGBA: the model consumes RGB, so drop the alpha channel.
        image = image[..., :3]
    # Integer images (Gradio's numpy output is typically uint8) and floats
    # above 1 are on the 0-255 scale. A plain ``max == 255`` test would miss
    # any such image whose brightest pixel is below 255.
    on_255_scale = np.issubdtype(image.dtype, np.integer) or float(image.max()) > 1.0
    image = tf.image.resize(image.astype(np.float32), size).numpy()
    return image / 255.0 if on_255_scale else image


def similarity_search(query_image, num_images=5, *_, feature_extractor=feature_extractor,
                      annoy_index=annoy_index, database_path=database_path):
    """Find the database images most similar to ``query_image``.

    Args:
        query_image: Image array from the Gradio ``Image`` component, or
            ``None`` when the input has been cleared.
        num_images: Number of nearest neighbours to retrieve.
        *_: Extra positional arguments are accepted and ignored.
        feature_extractor: Keras model mapping a batch of images to embedding
            vectors.
        annoy_index: Annoy index over the database embeddings. Its item ids
            are used as positional row numbers into ``Fast Foods.csv``.
        database_path: Root directory with one sub-directory per class.

    Returns:
        A tuple ``(closest_class, gallery, paths)``:

        - ``closest_class``: class name of the nearest neighbour.
        - ``gallery``: a visible ``gr.Gallery`` of the matches.
        - ``paths``: list of the matches' file paths.

        When there is no query image or no match, the result is ``''``, a
        hidden gallery and an empty list.
    """
    if query_image is None:
        # ``Image.change`` also fires when the input image is cleared.
        return '', gr.Gallery(value=None, visible=False), []

    query_vector = feature_extractor.predict(
        _prepare_query(query_image)[np.newaxis, ...], verbose=0)[0]

    # Compute nearest neighbors. Annoy requires an int count; cast in case
    # the slider delivers a float.
    nearest_neighbors = annoy_index.get_nns_by_vector(query_vector, int(num_images))
    if not nearest_neighbors:
        return '', gr.Gallery(value=None, visible=False), []

    # Load metadata (Annoy item ids select positional rows).
    metadata_path = os.path.join('.', 'Fast Foods.csv')
    metadata = pd.read_csv(metadata_path, index_col=0).iloc[nearest_neighbors]
    # Rows follow Annoy's ordering, so the first row is the closest match.
    closest_class = metadata.class_name.values[0]

    # Similar Images
    similar_images_paths = [
        os.path.join(database_path, class_name, file_name)
        for class_name, file_name in zip(metadata.class_name.values, metadata.file_name.values)
    ]
    similar_images = [load_image(path) for path in similar_images_paths]
    image_gallery = gr.Gallery(
        value=similar_images,
        label='Similar Images',
        object_fit='fill',
        preview=True,
        visible=True,
    )
    return closest_class, image_gallery, similar_images_paths
# Gradio Application
def _supported_classes_text(names):
    """Join class names for display, e.g. ``['a', 'b', 'c']`` -> ``'a, b and c'``.

    A bare ``', '.join(names[:-1]) + ' and ' + names[-1]`` raises IndexError
    for an empty list and yields a dangling ``' and x'`` for a single name.
    Both cases are handled here.
    """
    if not names:
        return 'none'
    if len(names) == 1:
        return names[0]
    return f"{', '.join(names[:-1])} and {names[-1]}"


with gr.Blocks(theme='soft') as app:
    gr.Markdown("# Fast Food - Content Based Image Retrieval (CBIR)")
    gr.Markdown(f"Model only supports: {_supported_classes_text(class_names)}")
    gr.Markdown("Disclaimer:- Model might suggest incorrect images, try using a different image.")

    with gr.Row(equal_height=True):
        # Image Input
        query_image = gr.Image(
            label='Query Image',
            sources=['upload', 'clipboard'],
            height='50vh'
        )
        # Output Gallery Display (hidden until a search returns results)
        output_gallery = gr.Gallery(visible=False)
        # Hidden output for similar images paths
        similar_paths_output = gr.Textbox(visible=False)

    with gr.Row(equal_height=True):
        # Predicted Class
        pred_class = gr.Textbox(
            label='Predicted Class', placeholder='Let the model think!!...')
        # Number of images to search
        n_images = gr.Slider(
            value=10,
            label='Number of images to search',
            minimum=1,
            maximum=99,
            step=1
        )

    # Search Button
    search_btn = gr.Button('Search')

    # Example Images
    examples = gr.Examples(
        examples=example_images,
        inputs=query_image,
        label='Something similar to me??'
    )

    # A new query image and the Search button trigger the same search.
    search_event = dict(
        fn=similarity_search,
        inputs=[query_image, n_images],
        outputs=[pred_class, output_gallery, similar_paths_output],
    )
    # Input - On Change
    query_image.change(**search_event)
    # Search - On Click
    search_btn.click(**search_event)

if __name__ == '__main__':
    app.launch()