DeepNets commited on
Commit
ab114c7
·
verified ·
1 Parent(s): d9a1e46

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -123
app.py CHANGED
@@ -1,123 +1,127 @@
1
- import os
2
- import numpy as np
3
- import pandas as pd
4
- import gradio as gr
5
- from glob import glob
6
- import tensorflow as tf
7
- from annoy import AnnoyIndex
8
- from tensorflow import keras
9
-
10
-
11
- def load_image(image_path):
12
- image = tf.io.read_file(image_path)
13
- image = tf.image.decode_jpeg(image, channels=3)
14
- image = tf.image.resize(image, (224, 224))
15
- image = tf.image.convert_image_dtype(image, tf.float32)
16
- image = image/255.
17
- return image.numpy()
18
-
19
-
20
- # Specify Database Path
21
- database_path = os.path.join('.', 'FastFood-DB')
22
-
23
- # Create Example Images
24
- class_names = []
25
- with open(os.path.join('.', 'Fast Food-ClassNames.txt'), mode='r') as names:
26
- class_names = names.read().split(',')[:-1]
27
-
28
- example_image_paths = [
29
- glob(os.path.join(database_path, name, '*'))[0]
30
- for name in class_names
31
- ]
32
- example_images = [load_image(path) for path in example_image_paths]
33
-
34
- # Load Feature Extractor
35
- feature_extractor_path = os.path.join('.', 'Fast Food-FeatureExtractor.keras')
36
- feature_extractor = keras.models.load_model(
37
- feature_extractor_path, compile=False)
38
-
39
- # Load Annoy index
40
- index_path = os.path.join('.', 'Fast FoodSubset.ann')
41
- annoy_index = AnnoyIndex(256, 'angular')
42
- annoy_index.load(index_path)
43
-
44
-
45
- def similarity_search(query_image, num_images=5, *_, feature_extractor=feature_extractor, annoy_index=annoy_index, database_path=database_path):
46
-
47
- if np.max(query_image) == 255:
48
- query_image = query_image/255.
49
-
50
- query_vector = feature_extractor.predict(
51
- query_image[np.newaxis, ...], verbose=0)[0]
52
-
53
- # Compute nearest neighbors
54
- nearest_neighbors = annoy_index.get_nns_by_vector(query_vector, num_images)
55
-
56
- # Load metadata
57
- metadata_path = os.path.join('.', 'Fast Foods.csv')
58
- metadata = pd.read_csv(metadata_path, index_col=0).iloc[nearest_neighbors]
59
- closest_class = metadata.class_name.values[0]
60
-
61
- # Similar Images
62
- similar_images = [
63
- load_image(os.path.join(database_path, class_name, file_name))
64
- for class_name, file_name in zip(metadata.class_name.values, metadata.file_name.values)
65
- ]
66
- image_gallery = gr.Gallery(
67
- value=similar_images,
68
- label='Similar Images',
69
- object_fit='fill',
70
- preview=True,
71
- visible=True,
72
- )
73
- return closest_class, image_gallery
74
-
75
-
76
- # Gradio Application
77
- with gr.Blocks(theme='soft') as app:
78
-
79
- with gr.Row(equal_height=True):
80
- # Image Input
81
- query_image = gr.Image(
82
- label='Query Image',
83
- sources=['upload', 'clipboard'],
84
- height='70vh'
85
- )
86
-
87
- # Output Gallery Display
88
- output_gallery = gr.Gallery(visible=False)
89
-
90
- with gr.Row(equal_height=True):
91
-
92
- # Predicted Class
93
- pred_class = gr.Textbox(
94
- label='Predicted Class', placeholder='Let the model think!!...')
95
-
96
- # Number of images to search
97
- n_images = gr.Slider(
98
- value=10,
99
- label='Number of images to search',
100
- minimum=1,
101
- maximum=99,
102
- step=1
103
- )
104
-
105
- # Search Button
106
- search_btn = gr.Button('Search')
107
-
108
- # Example Images
109
- examples = gr.Examples(
110
- examples=example_images,
111
- inputs=query_image,
112
- label='Something similar to me??'
113
- )
114
-
115
- # Search - On Click
116
- search_btn.click(
117
- fn=similarity_search,
118
- inputs=[query_image, n_images],
119
- outputs=[pred_class, output_gallery]
120
- )
121
-
122
- if __name__ == '__main__':
123
- app.launch()
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import pandas as pd
4
+ import gradio as gr
5
+ from glob import glob
6
+ import tensorflow as tf
7
+ from annoy import AnnoyIndex
8
+ from tensorflow import keras
9
+
10
+
11
+ def load_image(image_path):
12
+ image = tf.io.read_file(image_path)
13
+ image = tf.image.decode_jpeg(image, channels=3)
14
+ image = tf.image.resize(image, (224, 224))
15
+ image = tf.image.convert_image_dtype(image, tf.float32)
16
+ image = image/255.
17
+ return image.numpy()
18
+
19
+
20
+ # Specify Database Path
21
+ database_path = os.path.join('.', 'FastFood-DB')
22
+
23
+ # Create Example Images
24
+ class_names = []
25
+ with open(os.path.join('.', 'Fast Food-ClassNames.txt'), mode='r') as names:
26
+ class_names = names.read().split(',')[:-1]
27
+
28
+ example_image_paths = [
29
+ glob(os.path.join(database_path, name, '*'))[0]
30
+ for name in class_names
31
+ ]
32
+ example_images = [load_image(path) for path in example_image_paths]
33
+
34
+ # Load Feature Extractor
35
+ feature_extractor_path = os.path.join('.', 'Fast Food-FeatureExtractor.keras')
36
+ feature_extractor = keras.models.load_model(
37
+ feature_extractor_path, compile=False)
38
+
39
+ # Load Annoy index
40
+ index_path = os.path.join('.', 'Fast FoodSubset.ann')
41
+ annoy_index = AnnoyIndex(256, 'angular')
42
+ annoy_index.load(index_path)
43
+
44
+
45
+ def similarity_search(query_image, num_images=5, *_, feature_extractor=feature_extractor, annoy_index=annoy_index, database_path=database_path):
46
+
47
+ if np.max(query_image) == 255:
48
+ query_image = query_image/255.
49
+
50
+ query_vector = feature_extractor.predict(
51
+ query_image[np.newaxis, ...], verbose=0)[0]
52
+
53
+ # Compute nearest neighbors
54
+ nearest_neighbors = annoy_index.get_nns_by_vector(query_vector, num_images)
55
+
56
+ # Load metadata
57
+ metadata_path = os.path.join('.', 'Fast Foods.csv')
58
+ metadata = pd.read_csv(metadata_path, index_col=0).iloc[nearest_neighbors]
59
+ closest_class = metadata.class_name.values[0]
60
+
61
+ # Similar Images
62
+ similar_images = [
63
+ load_image(os.path.join(database_path, class_name, file_name))
64
+ for class_name, file_name in zip(metadata.class_name.values, metadata.file_name.values)
65
+ ]
66
+ image_gallery = gr.Gallery(
67
+ value=similar_images,
68
+ label='Similar Images',
69
+ object_fit='fill',
70
+ preview=True,
71
+ visible=True,
72
+ )
73
+ return closest_class, image_gallery
74
+
75
+
76
+ # Gradio Application
77
+ with gr.Blocks(theme='soft') as app:
78
+
79
+ gr.Markdown("# Fast Food - Content Based Image Retrieval (CBIR)")
80
+ gr.Markdown(f"Model only supports: {', '.join(class_names[:-1])} and {class_names[-1]}")
81
+
82
+
83
+ with gr.Row(equal_height=True):
84
+ # Image Input
85
+ query_image = gr.Image(
86
+ label='Query Image',
87
+ sources=['upload', 'clipboard'],
88
+ height='70vh'
89
+ )
90
+
91
+ # Output Gallery Display
92
+ output_gallery = gr.Gallery(visible=False)
93
+
94
+ with gr.Row(equal_height=True):
95
+
96
+ # Predicted Class
97
+ pred_class = gr.Textbox(
98
+ label='Predicted Class', placeholder='Let the model think!!...')
99
+
100
+ # Number of images to search
101
+ n_images = gr.Slider(
102
+ value=10,
103
+ label='Number of images to search',
104
+ minimum=1,
105
+ maximum=99,
106
+ step=1
107
+ )
108
+
109
+ # Search Button
110
+ search_btn = gr.Button('Search')
111
+
112
+ # Example Images
113
+ examples = gr.Examples(
114
+ examples=example_images,
115
+ inputs=query_image,
116
+ label='Something similar to me??'
117
+ )
118
+
119
+ # Search - On Click
120
+ search_btn.click(
121
+ fn=similarity_search,
122
+ inputs=[query_image, n_images],
123
+ outputs=[pred_class, output_gallery]
124
+ )
125
+
126
+ if __name__ == '__main__':
127
+ app.launch()