IncreasingLoss commited on
Commit
7e8fb73
·
verified ·
1 Parent(s): 6afc4b9

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -1
app.py CHANGED
@@ -166,6 +166,8 @@ classes = ['antelope',
166
  'tiger',
167
  'zebra']
168
 
 
 
169
  def predict(img):
170
  # Add preprocessing and fix inference mode
171
  model.eval()
@@ -179,15 +181,18 @@ def predict(img):
179
  ])
180
 
181
  img_tensor = preprocess(img).unsqueeze(0) # Fix spelling of unsqueeze
 
182
  with torch.inference_mode(): # Add parentheses
183
  logits = model(img_tensor)
184
  preds = logits.argmax(dim=1)
185
  return classes[preds.item()]
186
 
 
 
187
  """gradio interface"""
188
  demo = gr.Interface(
189
  fn=predict,
190
- inputs=gr.Image(type="pil", shape=(224, 224)), # Input component
191
  outputs="label", # Output component
192
  title="Animal Classifier",
193
  description="Classify images into 30 animal categories"
 
166
  'tiger',
167
  'zebra']
168
 
169
+
170
+
171
  def predict(img):
172
  # Add preprocessing and fix inference mode
173
  model.eval()
 
181
  ])
182
 
183
  img_tensor = preprocess(img).unsqueeze(0) # Fix spelling of unsqueeze
184
+
185
  with torch.inference_mode(): # Add parentheses
186
  logits = model(img_tensor)
187
  preds = logits.argmax(dim=1)
188
  return classes[preds.item()]
189
 
190
+
191
+
192
  """gradio interface"""
193
  demo = gr.Interface(
194
  fn=predict,
195
+ inputs=gr.Image(type="pil", width=244, height=244), # Input component
196
  outputs="label", # Output component
197
  title="Animal Classifier",
198
  description="Classify images into 30 animal categories"