File size: 3,571 Bytes
6aa4ba8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import os
import numpy as np
import pandas as pd
import gradio as gr
from glob import glob
import tensorflow as tf
from annoy import AnnoyIndex
from tensorflow import keras


def load_image(image_path, target_size=(224, 224)):
    """Read an image file and return it as a float32 RGB array scaled to [0, 1].

    Args:
        image_path: Path to an image file. ``tf.io.decode_image`` detects the
            format (BMP, GIF, JPEG or PNG); for animated GIFs only the first
            frame is kept.
        target_size: ``(height, width)`` the image is resized to. Defaults to
            224x224, the resolution used for the database and example images.

    Returns:
        numpy.ndarray of shape ``(*target_size, 3)``, dtype float32, with
        values in [0, 1].
    """
    raw_bytes = tf.io.read_file(image_path)
    # decode_image sniffs the file format instead of assuming JPEG;
    # expand_animations=False guarantees a 3-D (H, W, C) tensor, which
    # tf.image.resize requires. channels=3 forces RGB output.
    image = tf.io.decode_image(raw_bytes, channels=3, expand_animations=False)
    # resize() always returns float32 but does NOT rescale the values, so they
    # are still in [0, 255] here; the explicit division below maps to [0, 1].
    image = tf.image.resize(image, target_size)
    return (image / 255.).numpy()


# Specify Database Path
# Images are expected under <database_path>/<class_name>/<file_name>.
database_path = os.path.join('.', 'FastFood-DB')

# Create Example Images
# The class-name file is a comma-separated list. Strip whitespace/newlines
# around each entry and drop empty entries, so the file parses correctly with
# or without a trailing comma (the previous `[:-1]` silently dropped the last
# class when there was no trailing comma).
with open(os.path.join('.', 'Fast Food-ClassNames.txt'), mode='r',
          encoding='utf-8') as names:
    class_names = [
        name.strip() for name in names.read().split(',') if name.strip()
    ]

example_image_paths = []
for name in class_names:
    # glob() order is filesystem-dependent; sort so the example picked for each
    # class is deterministic, and skip anything that is not a regular file.
    candidates = sorted(
        path for path in glob(os.path.join(database_path, name, '*'))
        if os.path.isfile(path)
    )
    if not candidates:
        raise FileNotFoundError(
            f'No example image found for class {name!r} in {database_path!r}')
    example_image_paths.append(candidates[0])

example_images = [load_image(path) for path in example_image_paths]

# Model / index artefacts loaded once at import time and shared by every
# search request.
feature_extractor_path = os.path.join('.', 'Fast Food-FeatureExtractor.keras')
index_path = os.path.join('.', 'Fast FoodSubset.ann')

# Annoy must be opened with the same vector length and distance metric that
# were used to build the index file. 256 presumably equals the feature
# extractor's output width — NOTE(review): confirm if the model changes.
_EMBEDDING_DIM, _INDEX_METRIC = 256, 'angular'

# The extractor is only used for predict(), so skip restoring the training
# configuration (optimizer / loss) with compile=False.
feature_extractor = keras.models.load_model(feature_extractor_path, compile=False)

annoy_index = AnnoyIndex(_EMBEDDING_DIM, _INDEX_METRIC)
annoy_index.load(index_path)


def _prepare_query(query_image, size=(224, 224)):
    """Convert a raw query image into the same format ``load_image`` produces.

    Args:
        query_image: Array-like image as delivered by Gradio (typically uint8
            RGB in [0, 255]) or an already-normalized float image in [0, 1].
        size: ``(height, width)`` to resize to; must match the resolution the
            indexed database images were embedded at (224x224 via load_image).

    Returns:
        float32 numpy array of shape ``(*size, 3)`` with values in [0, 1].
    """
    image = np.asarray(query_image)

    # Bring grayscale / RGBA inputs to 3-channel RGB, matching channels=3 in
    # load_image.
    if image.ndim == 2:
        image = np.stack([image] * 3, axis=-1)
    elif image.ndim == 3 and image.shape[-1] == 4:
        image = image[..., :3]

    # Integer images are in [0, 255]. Float images are rescaled only if they
    # are clearly not already in [0, 1]. The previous check (`max == 255`)
    # missed uploads whose brightest pixel was below 255.
    needs_scaling = np.issubdtype(image.dtype, np.integer) or image.max() > 1.0
    image = image.astype(np.float32)
    if needs_scaling:
        image = image / 255.

    # Uploaded images come in arbitrary sizes; resize to the database size.
    return tf.image.resize(image, size).numpy()


def similarity_search(query_image, num_images=5, *_,
                      feature_extractor=feature_extractor,
                      annoy_index=annoy_index,
                      database_path=database_path):
    """Find the database images most similar to ``query_image``.

    Args:
        query_image: Image from the Gradio input (``None`` if nothing was
            uploaded).
        num_images: Number of nearest neighbours to request from the index.
            Gradio sliders may deliver a float, so it is coerced to ``int``.
        feature_extractor: Keras model mapping a batch of images to vectors.
        annoy_index: Annoy index built over the database embeddings; item ids
            are assumed to be row positions in ``Fast Foods.csv``.
        database_path: Root directory of the ``<class>/<file>`` image tree.

    Returns:
        ``(closest_class, gallery)``: the class name of the nearest neighbour
        and a visible ``gr.Gallery`` holding the retrieved images.

    Raises:
        gr.Error: if no query image was provided or the index returned no
            neighbours (shown to the user in the UI).
    """
    if query_image is None:
        raise gr.Error('Please upload a query image before searching.')

    query = _prepare_query(query_image)
    query_vector = feature_extractor.predict(
        query[np.newaxis, ...], verbose=0)[0]

    # Compute nearest neighbors (Annoy requires an int count).
    nearest_neighbors = annoy_index.get_nns_by_vector(
        query_vector, int(num_images))
    if not nearest_neighbors:
        raise gr.Error('The search index returned no results.')

    # Load metadata; .iloc keeps the neighbour order, so row 0 is the nearest.
    metadata_path = os.path.join('.', 'Fast Foods.csv')
    metadata = pd.read_csv(metadata_path, index_col=0).iloc[nearest_neighbors]
    closest_class = metadata.class_name.values[0]

    # Similar Images
    similar_images = [
        load_image(os.path.join(database_path, class_name, file_name))
        for class_name, file_name in zip(metadata.class_name.values,
                                         metadata.file_name.values)
    ]
    # Returning a visible Gallery reveals the initially hidden output gallery.
    image_gallery = gr.Gallery(
        value=similar_images,
        label='Similar Images',
        object_fit='fill',
        preview=True,
        visible=True,
    )
    return closest_class, image_gallery


# Gradio Application
# Layout: one row with the query image and a results gallery that starts
# hidden, one row with the predicted-class textbox, the result-count slider
# and the Search button, then clickable example images. Components appear on
# screen in the order they are created inside the `with` contexts.
with gr.Blocks(theme='soft') as app:

    with gr.Row(equal_height=True):
        # Image Input: accepts file uploads or images pasted from the clipboard.
        query_image = gr.Image(
            label='Query Image',
            sources=['upload', 'clipboard'],
            height='70vh'
        )

        # Output Gallery Display: hidden at start; similarity_search returns a
        # gr.Gallery with visible=True as its second output to populate it.
        output_gallery = gr.Gallery(visible=False)

    with gr.Row(equal_height=True):

        # Predicted Class: receives the class name of the nearest neighbour.
        pred_class = gr.Textbox(
            label='Predicted Class', placeholder='Let the model think!!...')

        # Number of images to search: passed to similarity_search as
        # num_images, i.e. how many nearest neighbours are requested.
        n_images = gr.Slider(
            value=10,
            label='Number of images to search',
            minimum=1,
            maximum=99,
            step=1
        )

        # Search Button
        search_btn = gr.Button('Search')

    # Example Images: one image per class (built at import time); selecting
    # an example fills the query_image input.
    examples = gr.Examples(
        examples=example_images,
        inputs=query_image,
        label='Something similar to me??'
    )

    # Search - On Click: run similarity_search(query_image, n_images) and
    # route its (class name, gallery) results to the two output components.
    search_btn.click(
        fn=similarity_search,
        inputs=[query_image, n_images],
        outputs=[pred_class, output_gallery]
    )

# Start the Gradio server only when this file is run as a script.
if __name__ == '__main__':
    app.launch()