pawlo2013 commited on
Commit
8ef1b85
·
1 Parent(s): 3c9f6b6
Files changed (1) hide show
  1. app.py +8 -11
app.py CHANGED
@@ -32,10 +32,9 @@ model.eval()
32
 
33
  # Define the classification function
34
  # Define the classification function
35
- def classify_and_visualize(
36
- img_path, device="cpu", discard_ratio=0.9, head_fusion="mean"
37
- ):
38
- img = Image.open(img_path).convert("RGB")
39
  processed_input = processor(images=img, return_tensors="pt").to(device)
40
 
41
  with torch.no_grad():
@@ -46,21 +45,20 @@ def classify_and_visualize(
46
  predicted_class = class_names[prediction]
47
 
48
  result = {class_name: prob for class_name, prob in zip(class_names, probabilities)}
49
- filename = os.path.basename(img_path).split(".")[0]
50
 
51
  # Generate attention heatmap
52
  heatmap_img = show_final_layer_attention_maps(
53
  model, processed_input, device, discard_ratio, head_fusion
54
  )
55
 
56
- return {"filename": filename, "probabilities": result, "heatmap": heatmap_img}
57
 
58
 
59
  def format_output(output):
60
  return (
61
- f"{output['filename']}",
62
  output["probabilities"],
63
- gr.Image(value=output["heatmap"]),
64
  )
65
 
66
 
@@ -69,7 +67,7 @@ def load_examples_from_folder(folder_path):
69
  examples = []
70
  for file in os.listdir(folder_path):
71
  if file.endswith((".png", ".jpg", ".jpeg")):
72
- examples.append(os.path.join(folder_path, file))
73
  return examples
74
 
75
 
@@ -156,9 +154,8 @@ examples = load_examples_from_folder(examples_folder)
156
  # Create the Gradio interface
157
  iface = gr.Interface(
158
  fn=lambda img: format_output(classify_and_visualize(img)),
159
- inputs=gr.Image(type="filepath"),
160
  outputs=[
161
- gr.Textbox(label="True Label (from filename)"),
162
  gr.Label(),
163
  gr.Image(label="Attention Heatmap"),
164
  ],
 
32
 
33
  # Define the classification function
34
  # Define the classification function
35
+ def classify_and_visualize(img, device="cpu", discard_ratio=0.9, head_fusion="mean"):
36
+ # filename = img.filename
37
+ img = img.convert("RGB")
 
38
  processed_input = processor(images=img, return_tensors="pt").to(device)
39
 
40
  with torch.no_grad():
 
45
  predicted_class = class_names[prediction]
46
 
47
  result = {class_name: prob for class_name, prob in zip(class_names, probabilities)}
48
+ # get the filename from the image object
49
 
50
  # Generate attention heatmap
51
  heatmap_img = show_final_layer_attention_maps(
52
  model, processed_input, device, discard_ratio, head_fusion
53
  )
54
 
55
+ return {"probabilities": result, "heatmap": heatmap_img}
56
 
57
 
58
  def format_output(output):
59
  return (
 
60
  output["probabilities"],
61
+ output["heatmap"] if output["heatmap"] is not None else None,
62
  )
63
 
64
 
 
67
  examples = []
68
  for file in os.listdir(folder_path):
69
  if file.endswith((".png", ".jpg", ".jpeg")):
70
+ examples.append(Image.open(os.path.join(folder_path, file)))
71
  return examples
72
 
73
 
 
154
  # Create the Gradio interface
155
  iface = gr.Interface(
156
  fn=lambda img: format_output(classify_and_visualize(img)),
157
+ inputs=gr.Image(type="pil", label="Upload X-Ray Image"),
158
  outputs=[
 
159
  gr.Label(),
160
  gr.Image(label="Attention Heatmap"),
161
  ],