import os
import gradio as gr
from PIL import Image
import torch
from transformers import ViTForImageClassification, ViTImageProcessor
from datasets import load_dataset
import matplotlib.pyplot as plt
import numpy as np
import cv2
# Model and processor configuration
model_name_or_path = "google/vit-base-patch16-224-in21k"
processor = ViTImageProcessor.from_pretrained(model_name_or_path)
# Load the dataset (used here only to recover the class-name list; adjust dataset_path as needed)
dataset_path = "pawlo2013/chest_xray"
train_dataset = load_dataset(dataset_path, split="train")
class_names = train_dataset.features["label"].names
# Load ViT model
model = ViTForImageClassification.from_pretrained(
"./models",
num_labels=len(class_names),
id2label={str(i): label for i, label in enumerate(class_names)},
label2id={label: i for i, label in enumerate(class_names)},
)
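# NOTE: the image processor above comes from the base ViT checkpoint, while the
# fine-tuned classification weights are loaded from the local ./models directory;
# both are assumed to share the same 16x16-patch, 224x224-input configuration.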
# Set model to evaluation mode
model.eval()
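# The model stays on the CPU; classify_and_visualize accepts a `device`
# argument, so running on "cuda" would also require moving the model with
# model.to(device) first.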
# Define the classification function
def classify_and_visualize(img, device="cpu", discard_ratio=0.9, head_fusion="mean"):
img = img.convert("RGB")
    processed_input = processor(images=img, return_tensors="pt")["pixel_values"].to(device)
with torch.no_grad():
outputs = model(processed_input, output_attentions=True)
logits = outputs.logits
probabilities = torch.softmax(logits, dim=1)[0].tolist()
prediction = torch.argmax(logits, dim=-1).item()
predicted_class = class_names[prediction]
result = {class_name: prob for class_name, prob in zip(class_names, probabilities)}
# Generate attention heatmap
heatmap_img = show_final_layer_attention_maps(
outputs, processed_input, device, discard_ratio, head_fusion
)
return {"probabilities": result, "heatmap": heatmap_img}
def format_output(output):
return (output["probabilities"], output["heatmap"])
# Function to load examples from a folder
def load_examples_from_folder(folder_path):
examples = []
    for file in sorted(os.listdir(folder_path)):
        if file.lower().endswith((".png", ".jpg", ".jpeg")):
examples.append(Image.open(os.path.join(folder_path, file)))
return examples
# Compute an attention-rollout heatmap and overlay it on the input image
def show_final_layer_attention_maps(
outputs,
processed_input,
device,
discard_ratio=0.6,
head_fusion="max",
only_last_layer=False,
):
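    """Overlay an attention-rollout heatmap on the preprocessed input image.

    Attention rollout fuses each layer's heads (mean/max/min), zeroes the lowest
    `discard_ratio` fraction of attention weights, adds the identity to model the
    residual connections, row-normalizes, and multiplies the per-layer matrices
    together. The CLS-token row of the product gives a per-patch relevance mask,
    which is upsampled, colorized, and blended over the input image.
    """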
with torch.no_grad():
image = processed_input.squeeze(0)
image = image - image.min()
image = image / image.max()
        # Start the rollout from the identity; each layer's fused attention is multiplied in
        result = torch.eye(outputs.attentions[0].size(-1)).to(device)
if only_last_layer:
attention_list = outputs.attentions[-1].unsqueeze(0).to(device)
else:
attention_list = outputs.attentions
        for attention in attention_list:
            # Fuse this layer's attention heads into a single attention map
            if head_fusion == "mean":
attention_heads_fused = attention.mean(axis=1)
elif head_fusion == "max":
attention_heads_fused = attention.max(axis=1)[0]
elif head_fusion == "min":
attention_heads_fused = attention.min(axis=1)[0]
            # Zero out the lowest `discard_ratio` fraction of attention weights,
            # keeping position 0 (the CLS token) intact
            flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
            _, indices = flat.topk(int(flat.size(-1) * discard_ratio), -1, False)
            indices = indices[indices != 0]
            flat[0, indices] = 0
            # Add the identity to account for the residual connection, then
            # row-normalize so each row stays a probability distribution
            I = torch.eye(attention_heads_fused.size(-1)).to(device)
            a = (attention_heads_fused + 1.0 * I) / 2
            a = a / a.sum(dim=-1, keepdim=True)
            result = torch.matmul(a, result)
        # Attention flowing from the CLS token to each image patch
        mask = result[0, 0, 1:]
        width = int(mask.size(-1) ** 0.5)
        mask = mask.reshape(width, width).cpu().numpy()
        # Upsample the patch grid to the 224x224 input resolution and rescale to [0, 1]
        mask = cv2.resize(mask, (224, 224))
        mask = (mask - np.min(mask)) / (np.max(mask) - np.min(mask))
        # Colorize the mask with the jet colormap (dropping the alpha channel),
        # min-max normalize the input image, and blend the two
        heatmap = plt.cm.jet(mask)[:, :, :3]
        showed_img = image.permute(1, 2, 0).detach().cpu().numpy()
        showed_img = (showed_img - np.min(showed_img)) / (
            np.max(showed_img) - np.min(showed_img)
        )
        superimposed_img = heatmap * 0.4 + showed_img * 0.6
superimposed_img_pil = Image.fromarray(
(superimposed_img * 255).astype(np.uint8)
)
return superimposed_img_pil
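# Note: classify_and_visualize calls the function above with discard_ratio=0.9
# and head_fusion="mean", overriding the defaults; a higher discard_ratio
# generally yields a sparser, more focused heatmap.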
# Define the path to the examples folder
examples_folder = "./examples"
examples = load_examples_from_folder(examples_folder)
# Create the Gradio interface
iface = gr.Interface(
fn=lambda img: format_output(classify_and_visualize(img)),
inputs=gr.Image(type="pil", label="Upload X-Ray Image"),
outputs=[
        gr.Label(label="Class Probabilities"),
gr.Image(label="Attention Heatmap"),
],
examples=examples,
cache_examples=False,
    allow_flagging="never",
concurrency_limit=1,
title="Pneumonia X-Ray 3-Class Classification with Vision Transformer (ViT) using data augmentation",
description="Upload an X-ray image to classify it as normal, viral or bacterial pneumonia. Checkout the model in more details [here](https://huggingface.co/pawlo2013/vit-pneumonia-x-ray_3_class). The examples presented are taken from the test set of [Kermany et al. (2018) dataset.](https://data.mendeley.com/datasets/rscbjbr9sj/2.) The attention heatmap over all layers of the transfomer done by the attention rollout techinique by the implementation of [jacobgil](https://github.com/jacobgil/vit-explain).",
)
# Launch the app
if __name__ == "__main__":
iface.launch(debug=True)