DeepNets commited on
Commit
55d409d
·
verified ·
1 Parent(s): 7cc3f7f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +150 -146
app.py CHANGED
@@ -1,146 +1,150 @@
1
- import os
2
- import numpy as np
3
- import pandas as pd
4
- import gradio as gr
5
- from glob import glob
6
- import tensorflow as tf
7
- from annoy import AnnoyIndex
8
- from tensorflow import keras
9
-
10
-
11
- def load_image(image_path):
12
- image = tf.io.read_file(image_path)
13
- image = tf.image.decode_jpeg(image, channels=3)
14
- image = tf.image.resize(image, (224, 224))
15
- image = tf.image.convert_image_dtype(image, tf.float32)
16
- image = image/255.
17
- return image.numpy()
18
-
19
-
20
- # Specify Database Path
21
- database_path = './FruitSubset'
22
-
23
- # Create Example Images
24
- class_names = []
25
- with open('./Fruit-ClassNames.txt', mode='r') as names:
26
- class_names = names.read().split(',')[:-1]
27
-
28
- example_image_paths = [
29
- glob(os.path.join(database_path, name, '*'))[0]
30
- if name != 'Mango' else
31
- './FruitSubset/Mango/Mango (1018).jpeg'
32
- for name in class_names
33
- ]
34
- example_images = [load_image(path) for path in example_image_paths]
35
-
36
- # Load Feature Extractor
37
- feature_extractor_path = './Fruit-FeatureExtractor.keras'
38
- feature_extractor = keras.models.load_model(
39
- feature_extractor_path, compile=False)
40
-
41
- # Load Annoy index
42
- index_path = './FruitSubset.ann'
43
- annoy_index = AnnoyIndex(256, 'angular')
44
- annoy_index.load(index_path)
45
-
46
-
47
- def similarity_search(
48
- query_image, num_images=5, *_,
49
- feature_extractor=feature_extractor,
50
- annoy_index=annoy_index,
51
- database_path=database_path,
52
- metadata_path='./Fruits.csv'
53
- ):
54
-
55
- if np.max(query_image) == 255:
56
- query_image = query_image/255.
57
-
58
- query_vector = feature_extractor.predict(
59
- query_image[np.newaxis, ...], verbose=0)[0]
60
-
61
- # Compute nearest neighbors
62
- nearest_neighbors = annoy_index.get_nns_by_vector(query_vector, num_images)
63
-
64
- # Load metadata
65
- metadata = pd.read_csv(metadata_path, index_col=0)
66
- metadata = metadata.iloc[nearest_neighbors]
67
- closest_class = metadata.class_name.values[0]
68
-
69
- # Similar Images
70
- similar_images = [
71
- load_image(os.path.join(database_path, class_name, file_name))
72
- for class_name, file_name in zip(metadata.class_name.values, metadata.file_name.values)
73
- ]
74
-
75
- image_gallery = gr.Gallery(
76
- value=similar_images,
77
- label='Similar Images',
78
- object_fit='fill',
79
- preview=True,
80
- visible=True,
81
- height='50vh'
82
- )
83
- return closest_class, image_gallery
84
-
85
-
86
- # Gradio Application
87
- with gr.Blocks(theme='soft') as app:
88
-
89
- gr.Markdown("# Fruit - Content Based Image Retrieval (CBIR)")
90
- gr.Markdown(
91
- f"Model only supports: {', '.join(class_names[:-1])} and {class_names[-1]}")
92
- gr.Markdown(
93
- "Disclaimer:- Model might suggest incorrect images, try using a different image.")
94
-
95
- with gr.Row(equal_height=True):
96
- # Image Input
97
- query_image = gr.Image(
98
- label='Query Image',
99
- sources=['upload', 'clipboard'],
100
- height='50vh'
101
- )
102
-
103
- # Output Gallery Display
104
- output_gallery = gr.Gallery(visible=False)
105
-
106
- with gr.Row(equal_height=True):
107
-
108
- # Predicted Class
109
- pred_class = gr.Textbox(
110
- label='Predicted Class', placeholder='Let the model think!!...')
111
-
112
- # Number of images to search
113
- n_images = gr.Slider(
114
- value=10,
115
- label='Number of images to search',
116
- minimum=1,
117
- maximum=99,
118
- step=1
119
- )
120
-
121
- # Search Button
122
- search_btn = gr.Button('Search')
123
-
124
- # Example Images
125
- examples = gr.Examples(
126
- examples=example_images,
127
- inputs=query_image,
128
- label='Something similar to me??',
129
- )
130
-
131
- # Input - On Change
132
- query_image.change(
133
- fn=similarity_search,
134
- inputs=[query_image, n_images],
135
- outputs=[pred_class, output_gallery]
136
- )
137
-
138
- # Search - On Click
139
- search_btn.click(
140
- fn=similarity_search,
141
- inputs=[query_image, n_images],
142
- outputs=[pred_class, output_gallery]
143
- )
144
-
145
- if __name__ == '__main__':
146
- app.launch()
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import pandas as pd
4
+ import gradio as gr
5
+ from glob import glob
6
+ import tensorflow as tf
7
+ from annoy import AnnoyIndex
8
+ from tensorflow import keras
9
+
10
+
11
+ def load_image(image_path):
12
+ image = tf.io.read_file(image_path)
13
+ image = tf.image.decode_jpeg(image, channels=3)
14
+ image = tf.image.resize(image, (224, 224))
15
+ image = tf.image.convert_image_dtype(image, tf.float32)
16
+ image = image/255.
17
+ return image.numpy()
18
+
19
+
20
+ # Specify Database Path
21
+ database_path = './FruitSubset'
22
+
23
+ # Create Example Images
24
+ class_names = []
25
+ with open('./Fruit-ClassNames.txt', mode='r') as names:
26
+ class_names = names.read().split(',')[:-1]
27
+
28
+ example_image_paths = [
29
+ glob(os.path.join(database_path, name, '*'))[0]
30
+ if name != 'Mango' else
31
+ './FruitSubset/Mango/Mango (1018).jpeg'
32
+ for name in class_names
33
+ ]
34
+ example_images = [load_image(path) for path in example_image_paths]
35
+
36
+ # Load Feature Extractor
37
+ feature_extractor_path = './Fruit-FeatureExtractor.keras'
38
+ feature_extractor = keras.models.load_model(
39
+ feature_extractor_path, compile=False)
40
+
41
+ # Load Annoy index
42
+ index_path = './FruitSubset.ann'
43
+ annoy_index = AnnoyIndex(256, 'angular')
44
+ annoy_index.load(index_path)
45
+
46
+
47
+ def similarity_search(
48
+ query_image, num_images=5, *_,
49
+ feature_extractor=feature_extractor,
50
+ annoy_index=annoy_index,
51
+ database_path=database_path,
52
+ metadata_path='./Fruits.csv'
53
+ ):
54
+
55
+ if np.max(query_image) == 255:
56
+ query_image = query_image/255.
57
+
58
+ query_vector = feature_extractor.predict(
59
+ query_image[np.newaxis, ...], verbose=0)[0]
60
+
61
+ # Compute nearest neighbors
62
+ nearest_neighbors = annoy_index.get_nns_by_vector(query_vector, num_images)
63
+
64
+ # Load metadata
65
+ metadata = pd.read_csv(metadata_path, index_col=0)
66
+ metadata = metadata.iloc[nearest_neighbors]
67
+ closest_class = metadata.class_name.values[0]
68
+
69
+ # Similar Images
70
+ similar_images_paths = [
71
+ os.path.join(database_path, class_name, file_name)
72
+ for class_name, file_name in zip(metadata.class_name.values, metadata.file_name.values)
73
+ ]
74
+ similar_images = [load_image(img) for img in similar_images_paths]
75
+
76
+ image_gallery = gr.Gallery(
77
+ value=similar_images,
78
+ label='Similar Images',
79
+ object_fit='fill',
80
+ preview=True,
81
+ visible=True,
82
+ height='50vh'
83
+ )
84
+ return closest_class, image_gallery, similar_images_paths
85
+
86
+
87
+ # Gradio Application
88
+ with gr.Blocks(theme='soft') as app:
89
+
90
+ gr.Markdown("# Fruit - Content Based Image Retrieval (CBIR)")
91
+ gr.Markdown(
92
+ f"Model only supports: {', '.join(class_names[:-1])} and {class_names[-1]}")
93
+ gr.Markdown(
94
+ "Disclaimer:- Model might suggest incorrect images, try using a different image.")
95
+
96
+ with gr.Row(equal_height=True):
97
+ # Image Input
98
+ query_image = gr.Image(
99
+ label='Query Image',
100
+ sources=['upload', 'clipboard'],
101
+ height='50vh'
102
+ )
103
+
104
+ # Output Gallery Display
105
+ output_gallery = gr.Gallery(visible=False)
106
+
107
+ # Hidden output for similar images paths
108
+ similar_paths_output = gr.Textbox(visible=False)
109
+
110
+ with gr.Row(equal_height=True):
111
+
112
+ # Predicted Class
113
+ pred_class = gr.Textbox(
114
+ label='Predicted Class', placeholder='Let the model think!!...')
115
+
116
+ # Number of images to search
117
+ n_images = gr.Slider(
118
+ value=10,
119
+ label='Number of images to search',
120
+ minimum=1,
121
+ maximum=99,
122
+ step=1
123
+ )
124
+
125
+ # Search Button
126
+ search_btn = gr.Button('Search')
127
+
128
+ # Example Images
129
+ examples = gr.Examples(
130
+ examples=example_images,
131
+ inputs=query_image,
132
+ label='Something similar to me??',
133
+ )
134
+
135
+ # Input - On Change
136
+ query_image.change(
137
+ fn=similarity_search,
138
+ inputs=[query_image, n_images],
139
+ outputs=[pred_class, output_gallery, similar_paths_output]
140
+ )
141
+
142
+ # Search - On Click
143
+ search_btn.click(
144
+ fn=similarity_search,
145
+ inputs=[query_image, n_images],
146
+ outputs=[pred_class, output_gallery, similar_paths_output]
147
+ )
148
+
149
+ if __name__ == '__main__':
150
+ app.launch()