Upload app.py
app.py
CHANGED
@@ -166,6 +166,8 @@ classes = ['antelope',
            'tiger',
            'zebra']

+
+
 def predict(img):
     # Add preprocessing and fix inference mode
     model.eval()
@@ -179,15 +181,18 @@ def predict(img):
     ])

     img_tensor = preprocess(img).unsqueeze(0)  # Fix spelling of unsqueeze
+
     with torch.inference_mode():  # Add parentheses
         logits = model(img_tensor)
         preds = logits.argmax(dim=1)
     return classes[preds.item()]

+
+
 """gradio interface"""
 demo = gr.Interface(
     fn=predict,
-    inputs=gr.Image(type="pil",
+    inputs=gr.Image(type="pil", width=244, height=244),  # Input component
     outputs="label",  # Output component
     title="Animal Classifier",
     description="Classify images into 30 animal categories"
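For context, here is a minimal sketch of how the updated predict function and Gradio interface fit together after this commit. The model, checkpoint, full class list, and the body of the transforms.Compose([...]) pipeline are not shown in the diff, so the stand-ins below (the dummy linear model, the placeholder class names, and the Resize/ToTensor transforms) are illustrative assumptions rather than the actual contents of app.py.

import gradio as gr
import torch
from torchvision import transforms

# Stand-ins for parts of app.py the diff does not show (assumptions):
classes = ['antelope'] + [f'animal_{i}' for i in range(27)] + ['tiger', 'zebra']  # app.py lists the real 30 names
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 224 * 224, 30))  # app.py loads a trained classifier

def predict(img):
    # Add preprocessing and fix inference mode
    model.eval()
    preprocess = transforms.Compose([      # assumed pipeline; the hunk only shows its closing "])"
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    img_tensor = preprocess(img).unsqueeze(0)  # add a batch dimension
    with torch.inference_mode():               # disable autograd during inference
        logits = model(img_tensor)
        preds = logits.argmax(dim=1)
    return classes[preds.item()]

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", width=244, height=244),
    outputs="label",
    title="Animal Classifier",
    description="Classify images into 30 animal categories",
)

if __name__ == "__main__":
    demo.launch()

Note that in current Gradio versions the width and height arguments of gr.Image control how large the component renders in the page; the tensor size the model actually sees is determined by the preprocessing transform, not by these arguments.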