DeepNets commited on
Commit
cd5d082
·
verified ·
1 Parent(s): 7f037d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +148 -144
app.py CHANGED
@@ -1,144 +1,148 @@
1
- import os
2
- import numpy as np
3
- import pandas as pd
4
- import gradio as gr
5
- from glob import glob
6
- import tensorflow as tf
7
- from annoy import AnnoyIndex
8
- from tensorflow import keras
9
-
10
-
11
- def load_image(image_path):
12
- image = tf.io.read_file(image_path)
13
- image = tf.image.decode_jpeg(image, channels=3)
14
- image = tf.image.resize(image, (224, 224))
15
- image = tf.image.convert_image_dtype(image, tf.float32)
16
- image = image/255.
17
- return image.numpy()
18
-
19
-
20
- # Specify Database Path
21
- database_path = './LandscapeSubset'
22
-
23
- # Create Example Images
24
- class_names = []
25
- with open('./Landscape-ClassNames.txt', mode='r') as names:
26
- class_names = names.read().split(',')[:-1]
27
-
28
- example_image_paths = [
29
- glob(os.path.join(database_path, name, '*'))[0]
30
- for name in class_names
31
- ]
32
- example_images = [load_image(path) for path in example_image_paths]
33
-
34
- # Load Feature Extractor
35
- feature_extractor_path = './Landscape-FeatureExtractor.keras'
36
- feature_extractor = keras.models.load_model(
37
- feature_extractor_path, compile=False)
38
-
39
- # Load Annoy index
40
- index_path = './LandscapeSubset.ann'
41
- annoy_index = AnnoyIndex(256, 'angular')
42
- annoy_index.load(index_path)
43
-
44
-
45
- def similarity_search(
46
- query_image, num_images=5, *_,
47
- feature_extractor=feature_extractor,
48
- annoy_index=annoy_index,
49
- database_path=database_path,
50
- metadata_path='./Landscapes.csv'
51
- ):
52
-
53
- if np.max(query_image) == 255:
54
- query_image = query_image/255.
55
-
56
- query_vector = feature_extractor.predict(
57
- query_image[np.newaxis, ...], verbose=0)[0]
58
-
59
- # Compute nearest neighbors
60
- nearest_neighbors = annoy_index.get_nns_by_vector(query_vector, num_images)
61
-
62
- # Load metadata
63
- metadata = pd.read_csv(metadata_path, index_col=0)
64
- metadata = metadata.iloc[nearest_neighbors]
65
- closest_class = metadata.class_name.values[0]
66
-
67
- # Similar Images
68
- similar_images = [
69
- load_image(os.path.join(database_path, class_name, file_name))
70
- for class_name, file_name in zip(metadata.class_name.values, metadata.file_name.values)
71
- ]
72
-
73
- image_gallery = gr.Gallery(
74
- value=similar_images,
75
- label='Similar Images',
76
- object_fit='fill',
77
- preview=True,
78
- visible=True,
79
- height='50vh'
80
- )
81
- return closest_class, image_gallery
82
-
83
-
84
- # Gradio Application
85
- with gr.Blocks(theme='soft') as app:
86
-
87
- gr.Markdown("# Landscape - Content Based Image Retrieval (CBIR)")
88
- gr.Markdown(
89
- f"Model only supports: {', '.join(class_names[:-1])} and {class_names[-1]}")
90
- gr.Markdown(
91
- "Disclaimer:- Model might suggest incorrect images, try using a different image.")
92
-
93
- with gr.Row(equal_height=True):
94
- # Image Input
95
- query_image = gr.Image(
96
- label='Query Image',
97
- sources=['upload', 'clipboard'],
98
- height='50vh'
99
- )
100
-
101
- # Output Gallery Display
102
- output_gallery = gr.Gallery(visible=False)
103
-
104
- with gr.Row(equal_height=True):
105
-
106
- # Predicted Class
107
- pred_class = gr.Textbox(
108
- label='Predicted Class', placeholder='Let the model think!!...')
109
-
110
- # Number of images to search
111
- n_images = gr.Slider(
112
- value=10,
113
- label='Number of images to search',
114
- minimum=1,
115
- maximum=99,
116
- step=1
117
- )
118
-
119
- # Search Button
120
- search_btn = gr.Button('Search')
121
-
122
- # Example Images
123
- examples = gr.Examples(
124
- examples=example_images,
125
- inputs=query_image,
126
- label='Something similar to me??',
127
- )
128
-
129
- # Input - On Change
130
- query_image.change(
131
- fn=similarity_search,
132
- inputs=[query_image, n_images],
133
- outputs=[pred_class, output_gallery]
134
- )
135
-
136
- # Search - On Click
137
- search_btn.click(
138
- fn=similarity_search,
139
- inputs=[query_image, n_images],
140
- outputs=[pred_class, output_gallery]
141
- )
142
-
143
- if __name__ == '__main__':
144
- app.launch()
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import pandas as pd
4
+ import gradio as gr
5
+ from glob import glob
6
+ import tensorflow as tf
7
+ from annoy import AnnoyIndex
8
+ from tensorflow import keras
9
+
10
+
11
+ def load_image(image_path):
12
+ image = tf.io.read_file(image_path)
13
+ image = tf.image.decode_jpeg(image, channels=3)
14
+ image = tf.image.resize(image, (224, 224))
15
+ image = tf.image.convert_image_dtype(image, tf.float32)
16
+ image = image/255.
17
+ return image.numpy()
18
+
19
+
20
+ # Specify Database Path
21
+ database_path = './LandscapeSubset'
22
+
23
+ # Create Example Images
24
+ class_names = []
25
+ with open('./Landscape-ClassNames.txt', mode='r') as names:
26
+ class_names = names.read().split(',')[:-1]
27
+
28
+ example_image_paths = [
29
+ glob(os.path.join(database_path, name, '*'))[0]
30
+ for name in class_names
31
+ ]
32
+ example_images = [load_image(path) for path in example_image_paths]
33
+
34
+ # Load Feature Extractor
35
+ feature_extractor_path = './Landscape-FeatureExtractor.keras'
36
+ feature_extractor = keras.models.load_model(
37
+ feature_extractor_path, compile=False)
38
+
39
+ # Load Annoy index
40
+ index_path = './LandscapeSubset.ann'
41
+ annoy_index = AnnoyIndex(256, 'angular')
42
+ annoy_index.load(index_path)
43
+
44
+
45
+ def similarity_search(
46
+ query_image, num_images=5, *_,
47
+ feature_extractor=feature_extractor,
48
+ annoy_index=annoy_index,
49
+ database_path=database_path,
50
+ metadata_path='./Landscapes.csv'
51
+ ):
52
+
53
+ if np.max(query_image) == 255:
54
+ query_image = query_image/255.
55
+
56
+ query_vector = feature_extractor.predict(
57
+ query_image[np.newaxis, ...], verbose=0)[0]
58
+
59
+ # Compute nearest neighbors
60
+ nearest_neighbors = annoy_index.get_nns_by_vector(query_vector, num_images)
61
+
62
+ # Load metadata
63
+ metadata = pd.read_csv(metadata_path, index_col=0)
64
+ metadata = metadata.iloc[nearest_neighbors]
65
+ closest_class = metadata.class_name.values[0]
66
+
67
+ # Similar Images
68
+ similar_images_paths = [
69
+ os.path.join(database_path, class_name, file_name)
70
+ for class_name, file_name in zip(metadata.class_name.values, metadata.file_name.values)
71
+ ]
72
+ similar_images = [load_image(img) for img in similar_images_paths]
73
+
74
+ image_gallery = gr.Gallery(
75
+ value=similar_images,
76
+ label='Similar Images',
77
+ object_fit='fill',
78
+ preview=True,
79
+ visible=True,
80
+ height='50vh'
81
+ )
82
+ return closest_class, image_gallery, similar_images_paths
83
+
84
+
85
+ # Gradio Application
86
+ with gr.Blocks(theme='soft') as app:
87
+
88
+ gr.Markdown("# Landscape - Content Based Image Retrieval (CBIR)")
89
+ gr.Markdown(
90
+ f"Model only supports: {', '.join(class_names[:-1])} and {class_names[-1]}")
91
+ gr.Markdown(
92
+ "Disclaimer:- Model might suggest incorrect images, try using a different image.")
93
+
94
+ with gr.Row(equal_height=True):
95
+ # Image Input
96
+ query_image = gr.Image(
97
+ label='Query Image',
98
+ sources=['upload', 'clipboard'],
99
+ height='50vh'
100
+ )
101
+
102
+ # Output Gallery Display
103
+ output_gallery = gr.Gallery(visible=False)
104
+
105
+ # Hidden output for similar images paths
106
+ similar_paths_output = gr.Textbox(visible=False)
107
+
108
+ with gr.Row(equal_height=True):
109
+
110
+ # Predicted Class
111
+ pred_class = gr.Textbox(
112
+ label='Predicted Class', placeholder='Let the model think!!...')
113
+
114
+ # Number of images to search
115
+ n_images = gr.Slider(
116
+ value=10,
117
+ label='Number of images to search',
118
+ minimum=1,
119
+ maximum=99,
120
+ step=1
121
+ )
122
+
123
+ # Search Button
124
+ search_btn = gr.Button('Search')
125
+
126
+ # Example Images
127
+ examples = gr.Examples(
128
+ examples=example_images,
129
+ inputs=query_image,
130
+ label='Something similar to me??',
131
+ )
132
+
133
+ # Input - On Change
134
+ query_image.change(
135
+ fn=similarity_search,
136
+ inputs=[query_image, n_images],
137
+ outputs=[pred_class, output_gallery, similar_paths_output]
138
+ )
139
+
140
+ # Search - On Click
141
+ search_btn.click(
142
+ fn=similarity_search,
143
+ inputs=[query_image, n_images],
144
+ outputs=[pred_class, output_gallery, similar_paths_output]
145
+ )
146
+
147
+ if __name__ == '__main__':
148
+ app.launch()