Spaces:
Sleeping
Sleeping
add more options for GIS
Browse files
app.py
CHANGED
@@ -4,7 +4,7 @@ import time
|
|
4 |
import random
|
5 |
import torch
|
6 |
import torchvision.transforms as transforms
|
7 |
-
import gradio as gr
|
8 |
import matplotlib.pyplot as plt
|
9 |
|
10 |
from models import get_model
|
@@ -83,7 +83,9 @@ _search_params = {
|
|
83 |
|
84 |
|
85 |
# Gradio UI
|
86 |
-
def inference(query, labels, n_supp=10
|
|
|
|
|
87 |
'''
|
88 |
query: PIL image
|
89 |
labels: list of class names
|
@@ -91,6 +93,12 @@ def inference(query, labels, n_supp=10):
|
|
91 |
labels = labels.split(',')
|
92 |
n_supp = int(n_supp)
|
93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
fig, axs = plt.subplots(len(labels), n_supp, figsize=(n_supp*4, len(labels)*4))
|
95 |
|
96 |
with torch.no_grad():
|
@@ -104,9 +112,8 @@ def inference(query, labels, n_supp=10):
|
|
104 |
for idx, y in enumerate(labels):
|
105 |
gis = GoogleImagesSearch(args.api_key, args.cx)
|
106 |
_search_params['q'] = y
|
107 |
-
_search_params['num'] = n_supp
|
108 |
gis.search(search_params=_search_params, custom_image_name='my_image')
|
109 |
-
gis._custom_image_name = 'my_image'
|
110 |
|
111 |
for j, x in enumerate(gis.results()):
|
112 |
x.download('./')
|
@@ -135,9 +142,10 @@ def inference(query, labels, n_supp=10):
|
|
135 |
|
136 |
|
137 |
# DEBUG
|
138 |
-
|
|
|
139 |
##labels = 'dog, cat'
|
140 |
-
#labels = 'girl,
|
141 |
#output = inference(query, labels, n_supp=2)
|
142 |
#print(output)
|
143 |
|
@@ -146,7 +154,11 @@ gr.Interface(fn=inference,
|
|
146 |
inputs=[
|
147 |
gr.inputs.Image(label="Image to classify", type="pil"),
|
148 |
gr.inputs.Textbox(lines=1, label="Class hypotheses:", placeholder="Enter class names separated by ','",),
|
149 |
-
gr.inputs.Slider(minimum=2, maximum=10, step=1, label="Number of support examples
|
|
|
|
|
|
|
|
|
150 |
],
|
151 |
theme="grass",
|
152 |
outputs=[
|
|
|
4 |
import random
|
5 |
import torch
|
6 |
import torchvision.transforms as transforms
|
7 |
+
#import gradio as gr
|
8 |
import matplotlib.pyplot as plt
|
9 |
|
10 |
from models import get_model
|
|
|
83 |
|
84 |
|
85 |
# Gradio UI
|
86 |
+
def inference(query, labels, n_supp=10,
|
87 |
+
file_type='png', rights='cc_publicdomain',
|
88 |
+
image_type='photo', color_type='color'):
|
89 |
'''
|
90 |
query: PIL image
|
91 |
labels: list of class names
|
|
|
93 |
labels = labels.split(',')
|
94 |
n_supp = int(n_supp)
|
95 |
|
96 |
+
_search_params['num'] = n_supp
|
97 |
+
_search_params['fileType'] = file_type
|
98 |
+
_search_params['rights'] = rights
|
99 |
+
_search_params['imgType'] = image_type
|
100 |
+
_search_params['imgColorType'] = color_type
|
101 |
+
|
102 |
fig, axs = plt.subplots(len(labels), n_supp, figsize=(n_supp*4, len(labels)*4))
|
103 |
|
104 |
with torch.no_grad():
|
|
|
112 |
for idx, y in enumerate(labels):
|
113 |
gis = GoogleImagesSearch(args.api_key, args.cx)
|
114 |
_search_params['q'] = y
|
|
|
115 |
gis.search(search_params=_search_params, custom_image_name='my_image')
|
116 |
+
gis._custom_image_name = 'my_image' # fix: image name sometimes too long
|
117 |
|
118 |
for j, x in enumerate(gis.results()):
|
119 |
x.download('./')
|
|
|
142 |
|
143 |
|
144 |
# DEBUG
|
145 |
+
##query = Image.open('../labrador-puppy.jpg')
|
146 |
+
#query = Image.open('/Users/hushell/Documents/Dan_tr.png')
|
147 |
##labels = 'dog, cat'
|
148 |
+
#labels = 'girl, sussie'
|
149 |
#output = inference(query, labels, n_supp=2)
|
150 |
#print(output)
|
151 |
|
|
|
154 |
inputs=[
|
155 |
gr.inputs.Image(label="Image to classify", type="pil"),
|
156 |
gr.inputs.Textbox(lines=1, label="Class hypotheses:", placeholder="Enter class names separated by ','",),
|
157 |
+
gr.inputs.Slider(minimum=2, maximum=10, step=1, label="GIS: Number of support examples per class"),
|
158 |
+
gr.inputs.Dropdown(['png', 'jpg'], default='png', label='GIS: Image file type'),
|
159 |
+
gr.inputs.Dropdown(['cc_publicdomain', 'cc_attribute', 'cc_sharealike', 'cc_noncommercial', 'cc_nonderived'], default='cc_publicdomain', label='GIS: Copy rights'),
|
160 |
+
gr.inputs.Dropdown(['clipart', 'face', 'lineart', 'stock', 'photo', 'animated', 'imgTypeUndefined'], default='photo', label='GIS: Image type'),
|
161 |
+
gr.inputs.Dropdown(['color', 'gray', 'mono', 'trans', 'imgColorTypeUndefined'], default='color', label='GIS: Image color type'),
|
162 |
],
|
163 |
theme="grass",
|
164 |
outputs=[
|