CyberWaifu commited on
Commit
677217d
·
verified ·
1 Parent(s): e1b761d

Fix missing batch dimension.

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -75,8 +75,10 @@ def run_inference(pil_image: Image.Image) -> np.ndarray:
75
  """
76
  input_tensor = preprocess_image(pil_image)
77
  input_name = session.get_inputs()[0].name
 
 
78
  # Only refined_logits are used (initial_logits is ignored)
79
- _, refined_logits = session.run(None, {input_name: input_tensor})
80
  return refined_logits[0]
81
 
82
  def get_tags(refined_logits: np.ndarray, metadata: dict, default_threshold: float):
 
75
  """
76
  input_tensor = preprocess_image(pil_image)
77
  input_name = session.get_inputs()[0].name
78
+ # Expand dimensions to make it (1, C, H, W)
79
+ input_tensor_expanded = np.expand_dims(input_tensor, axis=0)
80
  # Only refined_logits are used (initial_logits is ignored)
81
+ _, refined_logits = session.run(None, {input_name: input_tensor_expanded})
82
  return refined_logits[0]
83
 
84
  def get_tags(refined_logits: np.ndarray, metadata: dict, default_threshold: float):