DeepNets's picture
Update app.py
076433e verified
raw
history blame
3.71 kB
import os
import numpy as np
import pandas as pd
import gradio as gr
from glob import glob
import tensorflow as tf
from annoy import AnnoyIndex
from tensorflow import keras
def load_image(image_path, image_size=(224, 224)):
    """Read an image file and return it as a float32 array scaled to [0, 1].

    Args:
        image_path: Path of the image file to read.
        image_size: ``(height, width)`` to resize to. Defaults to the
            224x224 size used throughout this app.

    Returns:
        ``numpy.ndarray`` of shape ``(height, width, 3)``, dtype float32,
        with pixel values in [0, 1].
    """
    raw_bytes = tf.io.read_file(image_path)
    # Force exactly three output channels so every image matches the RGB input.
    image = tf.image.decode_jpeg(raw_bytes, channels=3)
    # Bilinear `resize` already yields float32 values still in [0, 255], so the
    # former `convert_image_dtype(..., tf.float32)` call was a no-op (it only
    # rescales when the dtype actually changes). Scale to [0, 1] explicitly.
    image = tf.image.resize(image, image_size)
    return (image / 255.).numpy()
# Specify Database Path
# Root folder of the image database; images are read from
# <database_path>/<class_name>/<file>.
database_path = os.path.join('.', 'FastFood-DB')

# Create Example Images
# The class-names file is a comma-separated list of class folder names. The
# previous `split(',')[:-1]` relied on a trailing comma and silently dropped the
# last real class when it was absent; dropping blank entries after stripping
# handles both layouts and tolerates stray whitespace/newlines.
with open(os.path.join('.', 'Fast Food-ClassNames.txt'), mode='r', encoding='utf-8') as names:
    class_names = [entry.strip() for entry in names.read().split(',') if entry.strip()]

# One example image per class: the first file in sorted order (raw glob order
# depends on the file system), skipping classes whose folder is missing or
# empty instead of crashing at start-up with an IndexError.
example_image_paths = []
for name in class_names:
    class_images = sorted(glob(os.path.join(database_path, name, '*')))
    if class_images:
        example_image_paths.append(class_images[0])
example_images = [load_image(path) for path in example_image_paths]

# Load Feature Extractor
feature_extractor_path = os.path.join('.', 'Fast Food-FeatureExtractor.keras')
# compile=False: the model is only used for `predict` here, so no
# optimizer/loss needs to be restored.
feature_extractor = keras.models.load_model(
    feature_extractor_path, compile=False)

# Load Annoy index
index_path = os.path.join('.', 'Fast FoodSubset.ann')
# NOTE(review): 256 must equal the feature extractor's output dimensionality,
# and 'angular' must match the metric the index was built with — confirm.
annoy_index = AnnoyIndex(256, 'angular')
annoy_index.load(index_path)
def similarity_search(query_image, num_images=5, *_, feature_extractor=feature_extractor,
                      annoy_index=annoy_index, database_path=database_path):
    """Retrieve the database images most similar to ``query_image``.

    The query is prepared the same way ``load_image`` prepares database images
    (pixels scaled to [0, 1], resized to 224x224), embedded with
    ``feature_extractor`` and looked up in ``annoy_index``.

    Args:
        query_image: Image array from the ``gr.Image`` input, or ``None`` when
            the user has not provided one.
        num_images: Number of neighbours to retrieve. Coerced to ``int``
            because Annoy requires an integer count.
        *_: Extra positional arguments are accepted and ignored.
        feature_extractor: Model whose ``predict`` maps an image batch to
            feature vectors.
        annoy_index: Loaded Annoy index over the database feature vectors.
        database_path: Root folder holding ``<class_name>/<file_name>`` images.

    Returns:
        Tuple ``(closest_class, gallery)``: the class name of the nearest
        neighbour and a visible ``gr.Gallery`` showing the retrieved images.

    Raises:
        gr.Error: If no query image was provided.
    """
    if query_image is None:
        raise gr.Error('Please provide a query image first.')

    raw = np.asarray(query_image)
    query = raw.astype(np.float32)
    # Integer pixels (e.g. uint8 from gr.Image) are in [0, 255]. The old
    # `max == 255` test missed images whose brightest pixel is below 255,
    # feeding the model unscaled values.
    if np.issubdtype(raw.dtype, np.integer) or query.max() > 1.0:
        query = query / 255.
    # Match the 224x224 size load_image uses for database images; uploads
    # can be any size.
    query = tf.image.resize(query, (224, 224)).numpy()

    query_vector = feature_extractor.predict(query[np.newaxis, ...], verbose=0)[0]

    # Compute nearest neighbors (Annoy rejects a float count).
    nearest_neighbors = annoy_index.get_nns_by_vector(query_vector, int(num_images))

    # Load metadata: Annoy item ids are used as row positions in the CSV.
    metadata_path = os.path.join('.', 'Fast Foods.csv')
    metadata = pd.read_csv(metadata_path, index_col=0).iloc[nearest_neighbors]
    closest_class = metadata.class_name.values[0]

    # Similar Images
    similar_images = [
        load_image(os.path.join(database_path, class_name, file_name))
        for class_name, file_name in zip(metadata.class_name.values, metadata.file_name.values)
    ]
    image_gallery = gr.Gallery(
        value=similar_images,
        label='Similar Images',
        object_fit='fill',
        preview=True,
        visible=True,
    )
    return closest_class, image_gallery
# Gradio Application
def _format_class_list(names):
    """Join class names for display as ``'a, b and c'``.

    With two or more names the result matches the previous inline f-string.
    A single name is returned as-is (previously rendered as ``' and a'``), and
    an empty list yields ``''`` (previously raised ``IndexError``).
    """
    if len(names) < 2:
        return ''.join(names)
    return f"{', '.join(names[:-1])} and {names[-1]}"


with gr.Blocks(theme='soft') as app:
    gr.Markdown("# Fast Food - Content Based Image Retrieval (CBIR)")
    gr.Markdown(f"Model only supports: {_format_class_list(class_names)}")
    gr.Markdown("Disclaimer:- Model might suggest incorrect images, try using a different image.")

    with gr.Row(equal_height=True):
        # Image Input
        query_image = gr.Image(
            label='Query Image',
            sources=['upload', 'clipboard'],
            height='50vh'
        )
        # Output Gallery Display: hidden until similarity_search returns a
        # visible gallery in its place.
        output_gallery = gr.Gallery(visible=False)

    with gr.Row(equal_height=True):
        # Predicted Class
        pred_class = gr.Textbox(
            label='Predicted Class', placeholder='Let the model think!!...')
        # Number of images to search
        n_images = gr.Slider(
            value=10,
            label='Number of images to search',
            minimum=1,
            maximum=99,
            step=1
        )

    # Search Button
    search_btn = gr.Button('Search')

    # Example Images
    examples = gr.Examples(
        examples=example_images,
        inputs=query_image,
        label='Something similar to me??'
    )

    # Search - On Click
    search_btn.click(
        fn=similarity_search,
        inputs=[query_image, n_images],
        outputs=[pred_class, output_gallery]
    )
def _launch():
    """Start the Gradio server for the ``app`` UI built in this module."""
    app.launch()


# Serve the UI only when run as a script, not when imported.
if __name__ == "__main__":
    _launch()