# Hugging Face Space page header (status: Running)
import logging

import gradio as gr
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Hub repository that holds both the fine-tuned classifier weights and the
# matching image-preprocessing configuration.
model_name = "Aya-Ch/brain-tumor-classifier"

# Display names for the classifier outputs: position i names output i.
# NOTE(review): this order must match the label ids the checkpoint was
# trained with (model.config.id2label) — confirm against the model card.
tumor_classes = [
    "meningioma",
    "glioma",
    "pituitary tumor",
]

# Load preprocessing and model once, at import time, so every request
# reuses the same objects instead of reloading them.
processor = AutoImageProcessor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)
def predict(image):
    """Classify an MRI image and return per-class tumor probabilities.

    Args:
        image: The uploaded scan as a ``PIL.Image`` (what
            ``gr.Image(type="pil")`` delivers), or ``None`` when the user
            submits without uploading anything.

    Returns:
        dict[str, float]: Maps each name in ``tumor_classes`` to its softmax
        probability, i.e. the ``{label: confidence}`` mapping ``gr.Label``
        displays.

    Raises:
        gr.Error: If no image was given, preprocessing or inference fails, or
            the model's number of output scores differs from
            ``len(tumor_classes)``. Gradio shows the message to the user
            instead of trying to render it as a label confidence.
    """
    if image is None:
        raise gr.Error("Please upload an MRI scan first.")

    try:
        # Normalise the colour mode for direct (non-UI) calls; through the UI,
        # gr.Image's default image_mode already yields RGB.
        # NOTE(review): assumes the checkpoint expects 3-channel input.
        if image.mode != "RGB":
            image = image.convert("RGB")
        pixel_values = processor(images=image, return_tensors="pt")["pixel_values"]
        with torch.no_grad():  # inference only; skip autograd bookkeeping
            logits = model(pixel_values=pixel_values).logits
    except Exception as e:
        # UI boundary: keep the full traceback in the server log and show a
        # short message in the browser.
        logging.getLogger(__name__).exception("Brain tumor inference failed")
        raise gr.Error(f"Failed to process image: {str(e)}") from e

    # One image was processed, so take row 0 of the (1, num_outputs) scores.
    probs = logits.softmax(dim=-1)[0]
    if probs.shape[-1] != len(tumor_classes):
        # Names are paired with scores by position; a size mismatch would
        # silently drop or mislabel classes.
        raise gr.Error(
            f"Model returned {probs.shape[-1]} scores but "
            f"{len(tumor_classes)} class names are configured."
        )
    return dict(zip(tumor_classes, probs.tolist()))
# Sample scans (one per tumor class) offered below the input so users can
# try the classifier without uploading their own image. Each inner list
# holds one value per input component; this interface has a single input.
examples = [
    [example_path]
    for example_path in (
        "examples/meningioma.jpg",
        "examples/glioma.jpg",
        "examples/pituitary_tumor.jpg",
    )
]
# Web UI: one image upload is passed to predict(), and the returned
# {class name: probability} dict is displayed by a gr.Label component.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),  # predict() receives a PIL image
    outputs=gr.Label(label="Tumor Classification"),
    title="Brain Tumor Classifier",
    description=(
        "Upload an MRI scan to classify the type of brain tumor "
        "(meningioma, glioma or pituitary tumor)"
    ),
    # Hide the flag button. NOTE(review): newer Gradio releases rename this
    # option to `flagging_mode`; confirm against the pinned gradio version.
    allow_flagging="never",
    examples=examples,  # clickable sample scans listed above
)

# Start the web server only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    demo.launch()