"""Gradio app for content-based image retrieval (CBIR) over a fruit image set.

At import time the module loads the class names, one example image per class,
a Keras feature extractor and a pre-built Annoy index. `similarity_search`
embeds a query image, looks up its nearest neighbours in the index and returns
the matching database images for display.
"""
import os
from functools import lru_cache
from glob import glob

import gradio as gr
import numpy as np
import pandas as pd
import tensorflow as tf
from annoy import AnnoyIndex
from tensorflow import keras


def load_image(image_path):
    """Read an image file and return it as a 224x224 RGB array scaled by 1/255.

    Args:
        image_path: Path of the image file to read.

    Returns:
        A NumPy array holding the decoded 3-channel image, resized to
        224x224 and divided by 255.
    """
    image = tf.io.read_file(image_path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, (224, 224))
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = image / 255.
    return image.numpy()


# Specify Database Path
database_path = './FruitSubset'

# Create Example Images
with open('./Fruit-ClassNames.txt', mode='r') as names:
    # The last comma-separated field is discarded; presumably the file ends
    # with a trailing comma -- verify against the data file.
    class_names = names.read().split(',')[:-1]

# Every class uses the first file globbed from its folder, except Mango,
# which is pinned to one specific file.
example_image_paths = [
    glob(os.path.join(database_path, name, '*'))[0]
    if name != 'Mango'
    else os.path.join(database_path, 'Mango', 'Mango (1018).jpeg')
    for name in class_names
]
example_images = [load_image(path) for path in example_image_paths]

# Load Feature Extractor
feature_extractor_path = './Fruit-FeatureExtractor.keras'
feature_extractor = keras.models.load_model(
    feature_extractor_path, compile=False)

# Load Annoy index
index_path = './FruitSubset.ann'
annoy_index = AnnoyIndex(256, 'angular')
annoy_index.load(index_path)


@lru_cache(maxsize=4)
def _load_metadata(metadata_path):
    """Read the metadata CSV once per path and reuse it across searches."""
    return pd.read_csv(metadata_path, index_col=0)


def _normalize_query_image(image):
    """Return `image` as an array scaled to [0, 1] when it holds 0-255 data."""
    image = np.asarray(image)
    # Decide from the dtype / value range rather than from `max == 255`:
    # an 8-bit image whose brightest pixel is below 255 must be scaled too.
    if np.issubdtype(image.dtype, np.integer) or np.max(image) > 1:
        image = image / 255.
    return image


def similarity_search(
    query_image, num_images=5, *_,
    feature_extractor=feature_extractor,
    annoy_index=annoy_index,
    database_path=database_path,
    metadata_path='./Fruits.csv'
):
    """Find the database images most similar to `query_image`.

    Args:
        query_image: Image array (0-255 or already scaled to [0, 1]), or
            None when the input component has been cleared.
        num_images: Number of nearest neighbours to retrieve.
        *_: Extra positional arguments are accepted and ignored.
        feature_extractor: Model whose `predict` yields the query embedding.
        annoy_index: Loaded Annoy index queried with that embedding.
        database_path: Root folder containing one sub-folder per class.
        metadata_path: CSV whose rows provide `class_name` and `file_name`
            for each index item.

    Returns:
        A tuple `(closest_class, image_gallery, similar_images_paths)`: the
        class of the nearest neighbour, a `gr.Gallery` holding the retrieved
        images, and the list of their file paths. For a missing query the
        tuple is `('', <hidden gallery>, [])`.

    NOTE(review): unlike `load_image`, the query is not resized to 224x224
    here -- confirm the feature extractor accepts arbitrary input sizes.
    """
    # The change event also fires when the image is cleared; there is
    # nothing to search for, so reset the outputs instead of raising.
    if query_image is None:
        return '', gr.Gallery(value=None, visible=False), []

    query_image = _normalize_query_image(query_image)
    query_vector = feature_extractor.predict(
        query_image[np.newaxis, ...], verbose=0)[0]

    # Compute nearest neighbors (cast guards against a float slider value)
    nearest_neighbors = annoy_index.get_nns_by_vector(
        query_vector, int(num_images))

    # Load metadata
    metadata = _load_metadata(metadata_path).iloc[nearest_neighbors]
    closest_class = metadata.class_name.values[0]

    # Similar Images
    similar_images_paths = [
        os.path.join(database_path, class_name, file_name)
        for class_name, file_name in zip(
            metadata.class_name.values, metadata.file_name.values)
    ]
    similar_images = [load_image(img) for img in similar_images_paths]

    image_gallery = gr.Gallery(
        value=similar_images,
        label='Similar Images',
        object_fit='fill',
        preview=True,
        visible=True,
        height='50vh'
    )
    return closest_class, image_gallery, similar_images_paths


# Gradio Application
with gr.Blocks(theme='soft') as app:

    gr.Markdown("# Fruit - Content Based Image Retrieval (CBIR)")
    gr.Markdown(
        f"Model only supports: {', '.join(class_names[:-1])} and {class_names[-1]}")
    gr.Markdown(
        "Disclaimer:- Model might suggest incorrect images, try using a different image.")

    with gr.Row(equal_height=True):
        # Image Input
        query_image = gr.Image(
            label='Query Image',
            sources=['upload', 'clipboard'],
            height='50vh'
        )

        # Output Gallery Display
        output_gallery = gr.Gallery(visible=False)

    # Hidden output for similar images paths
    similar_paths_output = gr.Textbox(visible=False)

    with gr.Row(equal_height=True):
        # Predicted Class
        pred_class = gr.Textbox(
            label='Predicted Class', placeholder='Let the model think!!...')

        # Number of images to search
        n_images = gr.Slider(
            value=10,
            label='Number of images to search',
            minimum=1,
            maximum=99,
            step=1
        )

        # Search Button
        search_btn = gr.Button('Search')

    # Example Images
    examples = gr.Examples(
        examples=example_images,
        inputs=query_image,
        label='Something similar to me??',
    )

    # Input - On Change
    query_image.change(
        fn=similarity_search,
        inputs=[query_image, n_images],
        outputs=[pred_class, output_gallery, similar_paths_output]
    )

    # Search - On Click
    search_btn.click(
        fn=similarity_search,
        inputs=[query_image, n_images],
        outputs=[pred_class, output_gallery, similar_paths_output]
    )

if __name__ == '__main__':
    app.launch()