"""Gradio demo for content-based image retrieval (CBIR) over an animal image subset.

A Keras feature extractor embeds the query image, an Annoy index returns the
nearest database items, and the matching image files are shown in a gallery.
"""
import os
from functools import lru_cache

import numpy as np
import pandas as pd
import gradio as gr
import tensorflow as tf
from annoy import AnnoyIndex
from tensorflow import keras

# Spatial size (height, width) every image is resized to.
IMAGE_SIZE = (224, 224)
# Dimensionality of the vectors stored in the Annoy index; must match the
# feature extractor's output size.
EMBEDDING_DIM = 256


def load_image(image_path, image_size=IMAGE_SIZE):
    """Read an image file and return it as a float32 RGB array scaled to [0, 1].

    Args:
        image_path: Path to an image file (JPEG, PNG, GIF or BMP).
        image_size: Target ``(height, width)``; defaults to ``IMAGE_SIZE``.

    Returns:
        ``numpy.ndarray`` of shape ``(*image_size, 3)`` and dtype float32.
    """
    raw = tf.io.read_file(image_path)
    # decode_image accepts everything decode_jpeg does plus PNG/GIF/BMP;
    # expand_animations=False keeps the result 3-D so it can be resized.
    image = tf.io.decode_image(raw, channels=3, expand_animations=False)
    # Bilinear resize already returns float32 (still in the 0-255 range), so no
    # dtype conversion is needed; the division maps pixels to [0, 1].
    image = tf.image.resize(image, image_size)
    return (image / 255.).numpy()


def _read_class_names(path):
    """Return the comma-separated class names listed in *path*, ignoring blanks."""
    with open(path, mode='r', encoding='utf-8') as handle:
        # Filtering empty entries tolerates a trailing comma/newline without
        # dropping the last real class when the file has no trailing comma.
        return [name.strip() for name in handle.read().split(',') if name.strip()]


# Specify Database Path
database_path = './AnimalSubset'

# Create Example Images
class_names = _read_class_names('./Animal-ClassNames.txt')

example_image_paths = [
    './AnimalSubset/Beetle/Beetle-Train (101).jpeg',
    './AnimalSubset/Butterfly/Butterfly-train (1042).jpeg',
    './AnimalSubset/Cat/Cat-Train (1004).jpeg',
    './AnimalSubset/Cow/Cow-Train (1022).jpeg',
    './AnimalSubset/Dog/Dog-Train (1144).jpeg',
    './AnimalSubset/Elephant/Elephant-Train (1043).jpeg',
    './AnimalSubset/Gorilla/Gorilla (1045).jpeg',
    './AnimalSubset/Hippo/Hippo - Train (1133).jpeg',
    './AnimalSubset/Lizard/Lizard-Train (161).jpeg',
    './AnimalSubset/Monkey/M (224).jpeg',
    './AnimalSubset/Mouse/Mouse-Train (1225).jpeg',
    './AnimalSubset/Panda/Panda (1992).jpeg',
    './AnimalSubset/Spider/Spider-Train (1191).jpeg',
    './AnimalSubset/Tiger/Tiger (1020).jpeg',
    './AnimalSubset/Zebra/Zebra (975).jpeg'
]
example_images = [load_image(path) for path in example_image_paths]

# Load Feature Extractor
feature_extractor_path = './Animal-FeatureExtractor.keras'
feature_extractor = keras.models.load_model(feature_extractor_path, compile=False)

# Load Annoy index
index_path = './AnimalSubset.ann'
annoy_index = AnnoyIndex(EMBEDDING_DIM, 'angular')
annoy_index.load(index_path)


@lru_cache(maxsize=8)
def _load_metadata(metadata_path):
    """Read the index metadata CSV once per path and reuse it across searches.

    The returned DataFrame is shared between calls and must not be mutated.
    """
    return pd.read_csv(metadata_path, index_col=0)


def _preprocess_query(query_image, image_size=IMAGE_SIZE):
    """Convert an image array from Gradio to float32 RGB in [0, 1] of *image_size*."""
    image = np.asarray(query_image)
    if image.ndim == 2:
        # Grayscale: replicate into three channels.
        image = np.repeat(image[..., np.newaxis], 3, axis=-1)
    elif image.shape[-1] == 4:
        # RGBA: drop the alpha channel.
        image = image[..., :3]
    # Integer images are 0-255 regardless of their actual max pixel; float
    # images are treated as already normalised unless they exceed 1.
    needs_scaling = np.issubdtype(image.dtype, np.integer) or image.max() > 1.0
    image = image.astype(np.float32)
    if needs_scaling:
        image = image / 255.
    # NOTE(review): assumes the feature extractor expects the same input size
    # that load_image produces for the database images — confirm against the
    # model's input shape.
    return tf.image.resize(image, image_size).numpy()


def _empty_result():
    """Outputs that clear the prediction box and hide the gallery."""
    return '', gr.Gallery(value=None, visible=False), []


def similarity_search(
    query_image,
    num_images=5,
    *_,
    feature_extractor=feature_extractor,
    annoy_index=annoy_index,
    database_path=database_path,
    metadata_path='./Animals.csv'
):
    """Find the database images most similar to *query_image*.

    Args:
        query_image: Image array from the Gradio ``Image`` component, or
            ``None`` when the input has been cleared.
        num_images: Number of neighbours to retrieve (coerced to ``int``).
        *_: Extra positional arguments are accepted and ignored.
        feature_extractor: Keras model mapping an image batch to embeddings.
        annoy_index: Loaded Annoy index over the database embeddings.
        database_path: Root folder holding ``<class_name>/<file_name>`` images.
        metadata_path: CSV with ``class_name`` and ``file_name`` columns.

    Returns:
        Tuple ``(closest_class, gallery_update, similar_image_paths)``:
        the class of the nearest match, a ``gr.Gallery`` update showing the
        matches, and their file paths. Returns empty values when there is no
        query image or no neighbours were found.
    """
    if query_image is None:
        # ``.change`` also fires when the user clears the image input.
        return _empty_result()

    query = _preprocess_query(query_image)
    query_vector = feature_extractor.predict(query[np.newaxis, ...], verbose=0)[0]

    # Compute nearest neighbors; Annoy requires an int count, but slider
    # values may arrive as floats.
    nearest_neighbors = annoy_index.get_nns_by_vector(query_vector, int(num_images))
    if not nearest_neighbors:
        return _empty_result()

    # NOTE(review): assumes Annoy item ids equal row positions in the CSV.
    metadata = _load_metadata(metadata_path).iloc[nearest_neighbors]
    # The class of the single closest match is reported as the prediction.
    closest_class = metadata.class_name.values[0]

    # Similar Images
    similar_images_paths = [
        os.path.join(database_path, class_name, file_name)
        for class_name, file_name in zip(
            metadata.class_name.values, metadata.file_name.values
        )
    ]
    similar_images = [load_image(path) for path in similar_images_paths]

    image_gallery = gr.Gallery(
        value=similar_images,
        label='Similar Images',
        object_fit='fill',
        preview=True,
        visible=True,
    )
    return closest_class, image_gallery, similar_images_paths


def _supported_classes_text(names):
    """Render class names as 'A, B and C' (no dangling 'and' for one class)."""
    if len(names) > 1:
        return f"{', '.join(names[:-1])} and {names[-1]}"
    return ''.join(names)


# Gradio Application
with gr.Blocks(theme='soft') as app:
    gr.Markdown("# Animal - Content Based Image Retrieval (CBIR)")
    gr.Markdown(f"Model only supports: {_supported_classes_text(class_names)}")
    gr.Markdown("Disclaimer:- Model might suggest incorrect images, try using a different image.")

    with gr.Row(equal_height=True):
        # Image Input
        query_image = gr.Image(
            label='Query Image',
            sources=['upload', 'clipboard'],
            height='50vh'
        )

        # Output Gallery Display (hidden until a search returns results)
        output_gallery = gr.Gallery(visible=False)

        # Hidden output for similar images paths
        similar_paths_output = gr.Textbox(visible=False)

    with gr.Row(equal_height=True):
        # Predicted Class
        pred_class = gr.Textbox(
            label='Predicted Class',
            placeholder='Let the model think!!...'
        )

        # Number of images to search
        n_images = gr.Slider(
            value=10,
            label='Number of images to search',
            minimum=1,
            maximum=99,
            step=1
        )

    # Search Button
    search_btn = gr.Button('Search')

    # Example Images
    examples = gr.Examples(
        examples=example_images,
        inputs=query_image,
        label='Something similar to me??',
    )

    # Input - On Change
    query_image.change(
        fn=similarity_search,
        inputs=[query_image, n_images],
        outputs=[pred_class, output_gallery, similar_paths_output]
    )

    # Search - On Click
    search_btn.click(
        fn=similarity_search,
        inputs=[query_image, n_images],
        outputs=[pred_class, output_gallery, similar_paths_output]
    )


if __name__ == '__main__':
    app.launch()