devadethanr commited on
Commit
9865f77
·
verified ·
1 Parent(s): 4388b4e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -1
app.py CHANGED
@@ -1,3 +1,57 @@
 
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- gr.load("models/devadethanr/alz_model").launch()
 
1
+ # import gradio as gr
2
+
3
+ # gr.load("models/devadethanr/alz_model").launch()
4
+
5
+
6
  import gradio as gr
7
+ from transformers import AutoModelForImageClassification
8
+ import torch
9
+ import numpy as np
10
+
11
+
12
+ # Load the model and image processor from the Hub
13
+ model_name = "devadethanr/alz_model"
14
+ model = AutoModelForImageClassification.from_pretrained(model_name)
15
+
16
+ # Get the label names from the model's configuration
17
+ labels = model.config.id2label
18
+
19
+ # Define the prediction function (with preprocessing)
20
+ def predict_image(image):
21
+ """
22
+ Predicts the Alzheimer's disease stage from an uploaded MRI image.
23
+
24
+ Args:
25
+ image: The uploaded MRI image (PIL Image).
26
+
27
+ Returns:
28
+ The predicted label with its corresponding probability.
29
+ """
30
+
31
+ image = model.preprocess_image(image, return_tensors="pt").to(model.device)
32
+ with torch.no_grad():
33
+ logits = model(**image).logits
34
+
35
+ predicted_label_id = logits.argmax(-1).item()
36
+ predicted_label = labels[predicted_label_id]
37
+
38
+ # Calculate probabilities using softmax
39
+ probabilities = torch.nn.functional.softmax(logits, dim=-1)
40
+ confidences = {label: float(probabilities[0][i]) for i, label in enumerate(labels)}
41
+
42
+ return predicted_label, confidences
43
+
44
+
45
+ # Create the Gradio interface (same as before)
46
+ iface = gr.Interface(
47
+ fn=predict_image,
48
+ inputs=gr.inputs.Image(type="pil", label="Upload MRI Image"),
49
+ outputs=[
50
+ gr.outputs.Label(label="Prediction"),
51
+ gr.outputs.JSON(label="Confidence Scores")
52
+ ],
53
+ title="Alzheimer's Disease MRI Image Classifier",
54
+ description="Upload an MRI image to predict the stage of Alzheimer's disease."
55
+ )
56
 
57
+ iface.launch()