DeepNets commited on
Commit
5da8629
·
verified ·
1 Parent(s): 35d8a2e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +149 -149
app.py CHANGED
@@ -1,150 +1,150 @@
1
- import os
2
- import numpy as np
3
- import pandas as pd
4
- import gradio as gr
5
- import tensorflow as tf
6
- from annoy import AnnoyIndex
7
- from tensorflow import keras
8
-
9
-
10
- def load_image(image_path):
11
- image = tf.io.read_file(image_path)
12
- image = tf.image.decode_jpeg(image, channels=3)
13
- image = tf.image.resize(image, (224, 224))
14
- image = tf.image.convert_image_dtype(image, tf.float32)
15
- image = image/255.
16
- return image.numpy()
17
-
18
-
19
- # Specify Database Path
20
- database_path = './AnimalSubset'
21
-
22
- # Create Example Images
23
- class_names = []
24
- with open('./Animal-ClassNames.txt', mode='r') as names:
25
- class_names = names.read().split(',')[:-1]
26
-
27
- example_image_paths = [
28
- './AnimalSubset/Beetle/Beetle-Train (101).jpeg',
29
- './AnimalSubset/Butterfly/Butterfly-train (1042).jpeg',
30
- './AnimalSubset/Cat/Cat-Train (1004).jpeg',
31
- './AnimalSubset/Cow/Cow-Train (1022).jpeg',
32
- './AnimalSubset/Dog/Dog-Train (1144).jpeg',
33
- './AnimalSubset/Elephant/Elephant-Train (1043).jpeg',
34
- './AnimalSubset/Gorilla/Gorilla (1045).jpeg',
35
- './AnimalSubset/Hippo/Hippo - Train (1133).jpeg',
36
- './AnimalSubset/Lizard/Lizard-Train (161).jpeg',
37
- './AnimalSubset/Monkey/M (224).jpeg',
38
- './AnimalSubset/Mouse/Mouse-Train (1225).jpeg',
39
- './AnimalSubset/Panda/Panda (1992).jpeg',
40
- './AnimalSubset/Spider/Spider-Train (1191).jpeg',
41
- './AnimalSubset/Tiger/Tiger (1020).jpeg',
42
- './AnimalSubset/Zebra/Zebra (975).jpeg'
43
- ]
44
- example_images = [load_image(path) for path in example_image_paths]
45
-
46
- # Load Feature Extractor
47
- feature_extractor_path = './Animal-FeatureExtractor.keras'
48
- feature_extractor = keras.models.load_model(
49
- feature_extractor_path, compile=False)
50
-
51
- # Load Annoy index
52
- index_path = './AnimalSubset.ann'
53
- annoy_index = AnnoyIndex(256, 'angular')
54
- annoy_index.load(index_path)
55
-
56
-
57
- def similarity_search(
58
- query_image, num_images=5, *_,
59
- feature_extractor=feature_extractor,
60
- annoy_index=annoy_index,
61
- database_path=database_path,
62
- metadata_path='./Animals.csv'
63
- ):
64
-
65
- if np.max(query_image) == 255:
66
- query_image = query_image/255.
67
-
68
- query_vector = feature_extractor.predict(
69
- query_image[np.newaxis, ...], verbose=0)[0]
70
-
71
- # Compute nearest neighbors
72
- nearest_neighbors = annoy_index.get_nns_by_vector(query_vector, num_images)
73
-
74
- # Load metadata
75
- metadata = pd.read_csv(metadata_path, index_col=0)
76
- metadata = metadata.iloc[nearest_neighbors]
77
- closest_class = metadata.class_name.values[0]
78
-
79
- # Similar Images
80
- similar_images = [
81
- load_image(os.path.join(database_path, class_name, file_name))
82
- for class_name, file_name in zip(metadata.class_name.values, metadata.file_name.values)
83
- ]
84
-
85
- # return closest_class, similar_images
86
-
87
- image_gallery = gr.Gallery(
88
- value=similar_images,
89
- label='Similar Images',
90
- object_fit='fill',
91
- preview=True,
92
- visible=True,
93
- )
94
- return closest_class, image_gallery
95
-
96
-
97
- # Gradio Application
98
- with gr.Blocks(theme='soft') as app:
99
-
100
- gr.Markdown("# Animal - Content Based Image Retrieval (CBIR)")
101
- gr.Markdown(f"Model only supports: {', '.join(class_names[:-1])} and {class_names[-1]}")
102
- gr.Markdown("Disclaimer:- Model might suggest incorrect images, try using a different image.")
103
-
104
- with gr.Row(equal_height=True):
105
- # Image Input
106
- query_image = gr.Image(
107
- label='Query Image',
108
- sources=['upload', 'clipboard'],
109
- height='70vh'
110
- )
111
-
112
- # Output Gallery Display
113
- output_gallery = gr.Gallery(visible=False)
114
-
115
- with gr.Row(equal_height=True):
116
-
117
- # Predicted Class
118
- pred_class = gr.Textbox(
119
- label='Predicted Class', placeholder='Let the model think!!...')
120
-
121
- # Number of images to search
122
- n_images = gr.Slider(
123
- value=10,
124
- label='Number of images to search',
125
- minimum=1,
126
- maximum=99,
127
- step=1
128
- )
129
-
130
- # Search Button
131
- search_btn = gr.Button('Search')
132
-
133
- # Example Images
134
- examples = gr.Examples(
135
- examples=example_images,
136
- inputs=query_image,
137
- label='Something similar to me??',
138
- )
139
-
140
- # Search - On Click
141
- search_btn.click(
142
- fn=similarity_search,
143
- inputs=[query_image, n_images],
144
- outputs=[pred_class, output_gallery]
145
- )
146
-
147
- if __name__ == '__main__':
148
- app.launch()
149
- # pred_class, sim_images = similarity_search(example_images[class_names.index('Spider')])
150
  # print(pred_class)
 
1
+ import os
2
+ import numpy as np
3
+ import pandas as pd
4
+ import gradio as gr
5
+ import tensorflow as tf
6
+ from annoy import AnnoyIndex
7
+ from tensorflow import keras
8
+
9
+
10
+ def load_image(image_path):
11
+ image = tf.io.read_file(image_path)
12
+ image = tf.image.decode_jpeg(image, channels=3)
13
+ image = tf.image.resize(image, (224, 224))
14
+ image = tf.image.convert_image_dtype(image, tf.float32)
15
+ image = image/255.
16
+ return image.numpy()
17
+
18
+
19
+ # Specify Database Path
20
+ database_path = './AnimalSubset'
21
+
22
+ # Create Example Images
23
+ class_names = []
24
+ with open('./Animal-ClassNames.txt', mode='r') as names:
25
+ class_names = names.read().split(',')[:-1]
26
+
27
+ example_image_paths = [
28
+ './AnimalSubset/Beetle/Beetle-Train (101).jpeg',
29
+ './AnimalSubset/Butterfly/Butterfly-train (1042).jpeg',
30
+ './AnimalSubset/Cat/Cat-Train (1004).jpeg',
31
+ './AnimalSubset/Cow/Cow-Train (1022).jpeg',
32
+ './AnimalSubset/Dog/Dog-Train (1144).jpeg',
33
+ './AnimalSubset/Elephant/Elephant-Train (1043).jpeg',
34
+ './AnimalSubset/Gorilla/Gorilla (1045).jpeg',
35
+ './AnimalSubset/Hippo/Hippo - Train (1133).jpeg',
36
+ './AnimalSubset/Lizard/Lizard-Train (161).jpeg',
37
+ './AnimalSubset/Monkey/M (224).jpeg',
38
+ './AnimalSubset/Mouse/Mouse-Train (1225).jpeg',
39
+ './AnimalSubset/Panda/Panda (1992).jpeg',
40
+ './AnimalSubset/Spider/Spider-Train (1191).jpeg',
41
+ './AnimalSubset/Tiger/Tiger (1020).jpeg',
42
+ './AnimalSubset/Zebra/Zebra (975).jpeg'
43
+ ]
44
+ example_images = [load_image(path) for path in example_image_paths]
45
+
46
+ # Load Feature Extractor
47
+ feature_extractor_path = './Animal-FeatureExtractor.keras'
48
+ feature_extractor = keras.models.load_model(
49
+ feature_extractor_path, compile=False)
50
+
51
+ # Load Annoy index
52
+ index_path = './AnimalSubset.ann'
53
+ annoy_index = AnnoyIndex(256, 'angular')
54
+ annoy_index.load(index_path)
55
+
56
+
57
+ def similarity_search(
58
+ query_image, num_images=5, *_,
59
+ feature_extractor=feature_extractor,
60
+ annoy_index=annoy_index,
61
+ database_path=database_path,
62
+ metadata_path='./Animals.csv'
63
+ ):
64
+
65
+ if np.max(query_image) == 255:
66
+ query_image = query_image/255.
67
+
68
+ query_vector = feature_extractor.predict(
69
+ query_image[np.newaxis, ...], verbose=0)[0]
70
+
71
+ # Compute nearest neighbors
72
+ nearest_neighbors = annoy_index.get_nns_by_vector(query_vector, num_images)
73
+
74
+ # Load metadata
75
+ metadata = pd.read_csv(metadata_path, index_col=0)
76
+ metadata = metadata.iloc[nearest_neighbors]
77
+ closest_class = metadata.class_name.values[0]
78
+
79
+ # Similar Images
80
+ similar_images = [
81
+ load_image(os.path.join(database_path, class_name, file_name))
82
+ for class_name, file_name in zip(metadata.class_name.values, metadata.file_name.values)
83
+ ]
84
+
85
+ # return closest_class, similar_images
86
+
87
+ image_gallery = gr.Gallery(
88
+ value=similar_images,
89
+ label='Similar Images',
90
+ object_fit='fill',
91
+ preview=True,
92
+ visible=True,
93
+ )
94
+ return closest_class, image_gallery
95
+
96
+
97
+ # Gradio Application
98
+ with gr.Blocks(theme='soft') as app:
99
+
100
+ gr.Markdown("# Animal - Content Based Image Retrieval (CBIR)")
101
+ gr.Markdown(f"Model only supports: {', '.join(class_names[:-1])} and {class_names[-1]}")
102
+ gr.Markdown("Disclaimer:- Model might suggest incorrect images, try using a different image.")
103
+
104
+ with gr.Row(equal_height=True):
105
+ # Image Input
106
+ query_image = gr.Image(
107
+ label='Query Image',
108
+ sources=['upload', 'clipboard'],
109
+ height='50vh'
110
+ )
111
+
112
+ # Output Gallery Display
113
+ output_gallery = gr.Gallery(visible=False)
114
+
115
+ with gr.Row(equal_height=True):
116
+
117
+ # Predicted Class
118
+ pred_class = gr.Textbox(
119
+ label='Predicted Class', placeholder='Let the model think!!...')
120
+
121
+ # Number of images to search
122
+ n_images = gr.Slider(
123
+ value=10,
124
+ label='Number of images to search',
125
+ minimum=1,
126
+ maximum=99,
127
+ step=1
128
+ )
129
+
130
+ # Search Button
131
+ search_btn = gr.Button('Search')
132
+
133
+ # Example Images
134
+ examples = gr.Examples(
135
+ examples=example_images,
136
+ inputs=query_image,
137
+ label='Something similar to me??',
138
+ )
139
+
140
+ # Search - On Click
141
+ search_btn.click(
142
+ fn=similarity_search,
143
+ inputs=[query_image, n_images],
144
+ outputs=[pred_class, output_gallery]
145
+ )
146
+
147
+ if __name__ == '__main__':
148
+ app.launch()
149
+ # pred_class, sim_images = similarity_search(example_images[class_names.index('Spider')])
150
  # print(pred_class)