ahsanMah commited on
Commit
b3c4783
·
1 Parent(s): 589207c

testing cuda

Browse files
Files changed (1) hide show
  1. hfapp.py +2 -6
hfapp.py CHANGED
@@ -14,6 +14,7 @@ from app import (
14
 
15
  @spaces.GPU
16
  def run_inference(model, img):
 
17
  img = torch.nn.functional.interpolate(img, size=64, mode="bilinear")
18
  score_norms = model.scorenet(img)
19
  score_norms = score_norms.square().sum(dim=(2, 3, 4)) ** 0.5
@@ -32,11 +33,7 @@ def localize_anomalies(input_img, preset="edm2-img64-s-fid", load_from_hub=False
32
  img = np.array(input_img)
33
  img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
34
  img = img.float().to(device)
35
- if load_from_hub:
36
- model = load_model_from_hub(preset=preset, device=device)
37
- else:
38
- model = load_model(modeldir="models", preset=preset, device=device)
39
-
40
  img_likelihood, score_norms = run_inference(model, img)
41
  nll, pct, ref_nll = compute_gmm_likelihood(
42
  score_norms, model_dir=f"models/{preset}"
@@ -49,7 +46,6 @@ def localize_anomalies(input_img, preset="edm2-img64-s-fid", load_from_hub=False
49
  return outstr, heatmapplot, histplot
50
 
51
 
52
-
53
  demo = build_demo(localize_anomalies)
54
  if __name__ == "__main__":
55
  demo.launch()
 
14
 
15
  @spaces.GPU
16
  def run_inference(model, img):
17
+ print("model on cuda:", next(model.scorenet.net.parameters()).is_cuda)
18
  img = torch.nn.functional.interpolate(img, size=64, mode="bilinear")
19
  score_norms = model.scorenet(img)
20
  score_norms = score_norms.square().sum(dim=(2, 3, 4)) ** 0.5
 
33
  img = np.array(input_img)
34
  img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
35
  img = img.float().to(device)
36
+ model = load_model_from_hub(preset=preset, device=device)
 
 
 
 
37
  img_likelihood, score_norms = run_inference(model, img)
38
  nll, pct, ref_nll = compute_gmm_likelihood(
39
  score_norms, model_dir=f"models/{preset}"
 
46
  return outstr, heatmapplot, histplot
47
 
48
 
 
49
  demo = build_demo(localize_anomalies)
50
  if __name__ == "__main__":
51
  demo.launch()