Dan Bochman commited on
Commit
5310d0a
·
unverified ·
1 Parent(s): d46329a

explictly cast to bfloat16

Browse files
Files changed (1) hide show
  1. app.py +1 -0
app.py CHANGED
@@ -114,6 +114,7 @@ model.to("cuda")
114
  @torch.inference_mode()
115
  @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
116
  def run_model(input_tensor, height, width):
 
117
  output = model(input_tensor)
118
  output = torch.nn.functional.interpolate(output, size=(height, width), mode="bilinear", align_corners=False)
119
  _, preds = torch.max(output, 1)
 
114
  @torch.inference_mode()
115
  @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
116
  def run_model(input_tensor, height, width):
117
+ input_tensor = input_tensor.to(device="cuda", dtype=torch.bfloat16) # explicit cast to bfloat16
118
  output = model(input_tensor)
119
  output = torch.nn.functional.interpolate(output, size=(height, width), mode="bilinear", align_corners=False)
120
  _, preds = torch.max(output, 1)