Spaces:
Running
Running
File size: 4,878 Bytes
5da8629 caf926d 5da8629 caf926d 5da8629 caf926d 5da8629 caf926d 5da8629 c330bfa caf926d c330bfa 5da8629 caf926d 5da8629 35d8a2e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
import functools
import os

import gradio as gr
import numpy as np
import pandas as pd
import tensorflow as tf
from annoy import AnnoyIndex
from tensorflow import keras
def load_image(image_path):
    """Read an image file and return it as a float32 array scaled to [0, 1].

    The file is decoded with 3 channels and resized to 224x224.

    Args:
        image_path: Path to the image file on disk.

    Returns:
        numpy.ndarray of dtype float32 and shape (224, 224, 3), values in [0, 1].
    """
    raw_bytes = tf.io.read_file(image_path)
    image = tf.image.decode_jpeg(raw_bytes, channels=3)
    # resize() returns float32 but keeps the original 0-255 value range, so the
    # explicit division below is what maps pixels into [0, 1]. (A
    # float32 -> float32 convert_image_dtype call would be a no-op here.)
    image = tf.image.resize(image, (224, 224))
    return (image / 255.).numpy()
# Specify Database Path
# Root folder of the image database; result images are located at
# <database_path>/<class_name>/<file_name>.
database_path = './AnimalSubset'

# Class names are stored comma-separated in a text file. Split on commas,
# trim surrounding whitespace and skip empty entries. This tolerates a
# trailing comma/newline without discarding the final real class when the
# file has no trailing comma.
with open('./Animal-ClassNames.txt', mode='r', encoding='utf-8') as names:
    class_names = [
        name.strip() for name in names.read().split(',') if name.strip()
    ]
# Create Example Images
# One (class folder, file name) pair per example query. Paths are resolved
# against database_path so the examples follow the database if it moves.
_example_files = [
    ('Beetle', 'Beetle-Train (101).jpeg'),
    ('Butterfly', 'Butterfly-train (1042).jpeg'),
    ('Cat', 'Cat-Train (1004).jpeg'),
    ('Cow', 'Cow-Train (1022).jpeg'),
    ('Dog', 'Dog-Train (1144).jpeg'),
    ('Elephant', 'Elephant-Train (1043).jpeg'),
    ('Gorilla', 'Gorilla (1045).jpeg'),
    ('Hippo', 'Hippo - Train (1133).jpeg'),
    ('Lizard', 'Lizard-Train (161).jpeg'),
    ('Monkey', 'M (224).jpeg'),
    ('Mouse', 'Mouse-Train (1225).jpeg'),
    ('Panda', 'Panda (1992).jpeg'),
    ('Spider', 'Spider-Train (1191).jpeg'),
    ('Tiger', 'Tiger (1020).jpeg'),
    ('Zebra', 'Zebra (975).jpeg'),
]
example_image_paths = [
    os.path.join(database_path, folder, file_name)
    for folder, file_name in _example_files
]
# Pre-load every example as a float32 (224, 224, 3) array in [0, 1].
example_images = [load_image(path) for path in example_image_paths]
# --- Model artifacts loaded once at import time ---------------------------
# The feature extractor is only used for prediction in this app, so it is
# loaded without compiling.
feature_extractor_path = './Animal-FeatureExtractor.keras'
feature_extractor = keras.models.load_model(
    feature_extractor_path,
    compile=False,
)

# Approximate nearest-neighbour index over the database feature vectors,
# compared with the 'angular' metric.
# NOTE(review): 256 is the vector length handed to Annoy; presumably it must
# equal the feature extractor's output width and the width the .ann file was
# built with — confirm if the model is retrained.
index_path = './AnimalSubset.ann'
annoy_index = AnnoyIndex(256, 'angular')
annoy_index.load(index_path)
@functools.lru_cache(maxsize=8)
def _load_metadata(metadata_path):
    """Read the metadata CSV (first column as index) once per path and reuse it."""
    return pd.read_csv(metadata_path, index_col=0)


def _prepare_query(query_image):
    """Convert a query image to load_image's layout: float in [0, 1], 224x224.

    Integer-typed images, or any image with values above 1, are treated as
    0-255 and divided by 255. The image is then resized to 224x224.
    """
    image = np.asarray(query_image)
    # Testing dtype/range (instead of `max == 255`) also catches 0-255 images
    # whose brightest pixel is below 255, which would otherwise reach the
    # model unscaled.
    if np.issubdtype(image.dtype, np.integer) or image.max() > 1.0:
        image = image.astype(np.float32) / 255.
    # Match load_image's 224x224 size so uploads of any resolution are fed
    # to the model the same way the example/database images are.
    return tf.image.resize(image, (224, 224)).numpy()


def similarity_search(
    query_image, num_images=5, *_,
    feature_extractor=feature_extractor,
    annoy_index=annoy_index,
    database_path=database_path,
    metadata_path='./Animals.csv'
):
    """Find the database images closest to ``query_image``.

    Args:
        query_image: (H, W, 3) image array, either on a 0-255 scale (assumed
            to be what the Gradio image input delivers) or float in [0, 1]
            (as returned by load_image). ``None`` — e.g. when the input is
            cleared and the ``change`` event fires — yields an empty result.
        num_images: Number of neighbours to retrieve; coerced to ``int``.
        *_: Absorbs extra positional arguments, making the remaining
            parameters keyword-only.
        feature_extractor: Model whose ``predict`` maps an image batch to
            feature vectors.
        annoy_index: Loaded Annoy index searched with the query vector.
        database_path: Root folder holding ``<class_name>/<file_name>`` images.
        metadata_path: CSV with ``class_name`` and ``file_name`` columns.

    Returns:
        Tuple ``(closest_class, gallery, paths)``: the class of the nearest
        neighbour ("" if none), a ``gr.Gallery`` showing the neighbours, and
        the list of their file paths.
    """
    if query_image is None:
        return "", gr.Gallery(value=None, visible=False), []

    model_input = _prepare_query(query_image)
    query_vector = feature_extractor.predict(
        model_input[np.newaxis, ...], verbose=0)[0]

    # Compute nearest neighbors; Annoy requires an integer count.
    nearest_neighbors = annoy_index.get_nns_by_vector(
        query_vector, int(num_images))

    # NOTE(review): assumes Annoy item ids equal row positions in the CSV.
    metadata = _load_metadata(metadata_path).iloc[nearest_neighbors]
    closest_class = metadata.class_name.values[0] if len(metadata) else ""

    # Similar Images
    similar_images_paths = [
        os.path.join(database_path, class_name, file_name)
        for class_name, file_name in zip(
            metadata.class_name.values, metadata.file_name.values)
    ]
    similar_images = [load_image(path) for path in similar_images_paths]

    image_gallery = gr.Gallery(
        value=similar_images,
        label='Similar Images',
        object_fit='fill',
        preview=True,
        visible=True,
    )
    return closest_class, image_gallery, similar_images_paths
# Gradio Application: query image + results gallery, predicted class,
# neighbour-count slider, search button and clickable examples.
with gr.Blocks(theme='soft') as app:
    # Header text.
    gr.Markdown("# Animal - Content Based Image Retrieval (CBIR)")
    gr.Markdown(f"Model only supports: {', '.join(class_names[:-1])} and {class_names[-1]}")
    gr.Markdown("Disclaimer:- Model might suggest incorrect images, try using a different image.")

    # First row: the query input next to the (initially hidden) gallery,
    # plus an invisible textbox that receives the result file paths.
    with gr.Row(equal_height=True):
        query_image = gr.Image(
            label='Query Image',
            sources=['upload', 'clipboard'],
            height='50vh',
        )
        output_gallery = gr.Gallery(visible=False)
        similar_paths_output = gr.Textbox(visible=False)

    # Second row: predicted class and how many neighbours to fetch.
    with gr.Row(equal_height=True):
        pred_class = gr.Textbox(
            label='Predicted Class',
            placeholder='Let the model think!!...',
        )
        n_images = gr.Slider(
            minimum=1,
            maximum=99,
            step=1,
            value=10,
            label='Number of images to search',
        )

    search_btn = gr.Button('Search')

    examples = gr.Examples(
        examples=example_images,
        inputs=query_image,
        label='Something similar to me??',
    )

    # Both a new query image and a button press run the same search.
    _search_wiring = dict(
        fn=similarity_search,
        inputs=[query_image, n_images],
        outputs=[pred_class, output_gallery, similar_paths_output],
    )
    query_image.change(**_search_wiring)
    search_btn.click(**_search_wiring)
if __name__ == '__main__':
    # Start the Gradio server only when run as a script, not on import.
    app.launch()