WwYc commited on
Commit
a041d6e
Β·
verified Β·
1 Parent(s): 4604315

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -1,7 +1,8 @@
1
  import gradio as gr
2
  from PIL import Image
3
  from torchvision import transforms
4
- from visualization import generate_visualization
 
5
 
6
  normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
7
  TRANSFORM = transforms.Compose(
@@ -21,10 +22,10 @@ def generate_viz(image, class_index=None, use_threshold=False):
21
  print(f"Image: {image.size}")
22
  print(f"Class: {class_index}")
23
  print(f"use_threshold: {use_threshold}")
24
- image_trans = TRANSFORM(image)
25
- viz = generate_visualization(image_trans, class_index=class_index, use_threshold=use_threshold)
26
 
27
- return Image.open(viz).convert("RGB")
 
 
28
 
29
  title = "Explain ViT 😊"
30
 
 
1
  import gradio as gr
2
  from PIL import Image
3
  from torchvision import transforms
4
+
5
+ from explain import do_explain
6
 
7
  normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
8
  TRANSFORM = transforms.Compose(
 
22
  print(f"Image: {image.size}")
23
  print(f"Class: {class_index}")
24
  print(f"use_threshold: {use_threshold}")
 
 
25
 
26
+ viz = do_explain(TRANSFORM, image, class_index=class_index, use_threshold=use_threshold)
27
+ viz.savefig("visualization.png")
28
+ return Image.open("visualization.png").convert("RGB")
29
 
30
  title = "Explain ViT 😊"
31