File size: 4,185 Bytes
38158d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import os
import numpy as np
import pandas as pd
import gradio as gr
from glob import glob
import tensorflow as tf
from annoy import AnnoyIndex
from tensorflow import keras


def load_image(image_path, image_size=(224, 224)):
    """Read an image file and return it as a float32 RGB array scaled to [0, 1].

    Args:
        image_path: Path to an image file in a format TensorFlow can decode
            (JPEG, PNG, GIF, BMP).
        image_size: ``(height, width)`` the image is resized to. Defaults to
            the 224x224 size used throughout this app.

    Returns:
        ``numpy.ndarray`` of shape ``(*image_size, 3)``, dtype float32, with
        pixel values in ``[0, 1]``.
    """
    raw_bytes = tf.io.read_file(image_path)
    # decode_image dispatches on the actual file format; expand_animations=False
    # keeps GIFs as a single rank-3 frame so resize receives (H, W, C).
    image = tf.io.decode_image(raw_bytes, channels=3, expand_animations=False)
    # resize returns float32 but keeps the original 0-255 range, so the
    # scaling to [0, 1] has to be done explicitly afterwards.
    image = tf.image.resize(image, image_size)
    return (image / 255.).numpy()


# Specify Database Path
database_path = './FlowerSubset'

# Create Example Images
# Flower-ClassNames.txt holds a comma-separated list of class names. Strip
# surrounding whitespace and drop empty entries so a trailing comma or newline
# is tolerated without discarding the last real class name.
with open('./Flower-ClassNames.txt', mode='r', encoding='utf-8') as names:
    class_names = [
        name.strip() for name in names.read().split(',') if name.strip()
    ]

# Classes whose example is pinned to a specific file instead of the first
# file globbed from the class folder.
example_overrides = {'Sunflower': 'Sunflower-Train (320).jpeg'}


def _example_image_path(class_name):
    """Return the path of the example image shown for ``class_name``.

    Raises:
        FileNotFoundError: if the class folder contains no files and the
            class has no pinned override.
    """
    if class_name in example_overrides:
        return os.path.join(
            database_path, class_name, example_overrides[class_name])
    candidates = glob(os.path.join(database_path, class_name, '*'))
    if not candidates:
        raise FileNotFoundError(
            f"No images found for class {class_name!r} in {database_path!r}")
    return candidates[0]


example_image_paths = [_example_image_path(name) for name in class_names]
example_images = [load_image(path) for path in example_image_paths]

# On-disk artifacts, relative to the working directory.
feature_extractor_path = './Flower-FeatureExtractor.keras'
index_path = './FlowerSubset.ann'

# Length of the vectors stored in the Annoy index. Query vectors produced by
# the feature extractor are looked up in this index, so the sizes must agree.
EMBEDDING_DIM = 256

# Load the Keras feature extractor. compile=False skips restoring the
# training configuration.
feature_extractor = keras.models.load_model(
    feature_extractor_path,
    compile=False,
)

# Load the pre-built Annoy index that uses the angular metric.
annoy_index = AnnoyIndex(EMBEDDING_DIM, 'angular')
annoy_index.load(index_path)


def similarity_search(
    query_image, num_images=5, *_,
    feature_extractor=feature_extractor,
    annoy_index=annoy_index,
    database_path=database_path,
    metadata_path='./Flowers.csv',
    image_size=(224, 224),
):
    """Retrieve the database images most similar to ``query_image``.

    The query is preprocessed the same way ``load_image`` prepares database
    images (resized to ``image_size``, scaled to [0, 1]). It is then embedded
    with ``feature_extractor`` and looked up in ``annoy_index``.

    Args:
        query_image: RGB image as a numpy array, as delivered by the Gradio
            ``Image`` component, or ``None`` when the input was cleared.
            Integer arrays, or float arrays with values above 1, are treated
            as 0-255 pixels and rescaled.
        num_images: Number of neighbours to retrieve; coerced to ``int``.
        *_: Extra positional arguments are accepted and ignored.
        feature_extractor: Model whose ``predict`` maps an image batch to
            embedding vectors.
        annoy_index: Annoy index over the database embeddings.
        database_path: Root folder with one sub-folder per class.
        metadata_path: CSV whose rows are positionally aligned with the Annoy
            item ids and provide ``class_name`` and ``file_name`` columns.
        image_size: ``(height, width)`` the query image is resized to.

    Returns:
        ``(closest_class, gallery, paths)``: the class of the nearest
        neighbour, a visible ``gr.Gallery`` with the neighbour images, and
        the list of their file paths. When there is no query or no match,
        returns ``('', <hidden gr.Gallery>, [])``.
    """
    # Clearing the image input fires `.change` with None.
    if query_image is None:
        return '', gr.Gallery(visible=False), []

    image = np.asarray(query_image)
    # Decide on rescaling from dtype/range rather than `max == 255`. An
    # equality test misses images whose brightest pixel is below 255.
    needs_scaling = (
        np.issubdtype(image.dtype, np.integer) or float(image.max()) > 1.0
    )
    image = image.astype(np.float32)
    # Uploads arrive at arbitrary resolutions. Match the size used for the
    # database images, resizing before scaling as load_image does.
    if tuple(image.shape[:2]) != tuple(image_size):
        image = tf.image.resize(image, image_size).numpy()
    if needs_scaling:
        image = image / 255.

    query_vector = feature_extractor.predict(
        image[np.newaxis, ...], verbose=0)[0]

    # Compute nearest neighbors (Annoy requires an integer count; Gradio
    # sliders may deliver floats).
    nearest_neighbors = annoy_index.get_nns_by_vector(
        query_vector, int(num_images))
    if not nearest_neighbors:
        return '', gr.Gallery(visible=False), []

    # Load metadata; Annoy item ids index the CSV rows positionally.
    metadata = pd.read_csv(metadata_path, index_col=0).iloc[nearest_neighbors]
    closest_class = metadata.class_name.values[0]

    # Similar Images
    similar_images_paths = [
        os.path.join(database_path, class_name, file_name)
        for class_name, file_name in zip(
            metadata.class_name.values, metadata.file_name.values)
    ]
    similar_images = [load_image(path) for path in similar_images_paths]

    image_gallery = gr.Gallery(
        value=similar_images,
        label='Similar Images',
        object_fit='fill',
        preview=True,
        visible=True,
        height='50vh'
    )
    return closest_class, image_gallery, similar_images_paths


# Gradio Application
# Sentence listing every class: "A, B and C" built from class_names.
supported_classes_text = f"Model only supports: {', '.join(class_names[:-1])} and {class_names[-1]}"

with gr.Blocks(theme='soft') as app:

    # Header: title, supported classes and a disclaimer, in display order.
    for header_markdown in (
        "# Flower - Content Based Image Retrieval (CBIR)",
        supported_classes_text,
        "Disclaimer:- Model might suggest incorrect images, try using a different image.",
    ):
        gr.Markdown(header_markdown)

    # Top row: the query image next to the (initially hidden) result gallery.
    with gr.Row(equal_height=True):
        query_image = gr.Image(
            label='Query Image',
            sources=['upload', 'clipboard'],
            height='50vh',
        )
        output_gallery = gr.Gallery(visible=False)

    # Invisible sink for the list of matched file paths.
    similar_paths_output = gr.Textbox(visible=False)

    # Controls row: prediction read-out, result count, and search trigger.
    with gr.Row(equal_height=True):
        pred_class = gr.Textbox(
            label='Predicted Class',
            placeholder='Let the model think!!...',
        )
        n_images = gr.Slider(
            minimum=1,
            maximum=99,
            step=1,
            value=10,
            label='Number of images to search',
        )
        search_btn = gr.Button('Search')

    # Clickable examples that populate the query image.
    examples = gr.Examples(
        examples=example_images,
        inputs=query_image,
        label='Something similar to me??',
    )

    # Run the same search whenever the image changes or Search is clicked.
    for search_trigger in (query_image.change, search_btn.click):
        search_trigger(
            fn=similarity_search,
            inputs=[query_image, n_images],
            outputs=[pred_class, output_gallery, similar_paths_output],
        )

if __name__ == '__main__':
    app.launch()