arittrabag commited on
Commit
c7bd69d
·
verified ·
1 Parent(s): 53a36c6

Added app.py

Browse files
Files changed (1) hide show
  1. app.py +124 -0
app.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import tensorflow as tf
3
+ import numpy as np
4
+ from PIL import Image
5
+ import torch
6
+ from torchvision import transforms
7
+ from transformers import SegformerForImageClassification
8
+ import google.generativeai as genai
9
+ import io
10
+
11
+ # Initialize Gemini API
12
+ genai.configure(api_key="AIzaSyDD8QW1BggDVVMLteDygHCHrD6Ff9Dy0e8")
13
+ gemini_model = genai.GenerativeModel('gemini-2.0-flash')
14
+
15
+ # Load the MRI vs Non-MRI model
16
+ mri_classifier = tf.keras.models.load_model("alzheimers_detection_model.h5")
17
+
18
+ # Load Alzheimer's and Brain Tumor models
19
+ alzheimers_model = SegformerForImageClassification.from_pretrained('nvidia/mit-b1')
20
+ alzheimers_model.classifier = torch.nn.Linear(alzheimers_model.classifier.in_features, 4)
21
+ alzheimers_model.load_state_dict(torch.load('alzheimers_model.pth', map_location=torch.device('cpu')))
22
+ alzheimers_model.eval()
23
+
24
+ brain_tumor_model = SegformerForImageClassification.from_pretrained('nvidia/mit-b1')
25
+ brain_tumor_model.classifier = torch.nn.Linear(brain_tumor_model.classifier.in_features, 4)
26
+ brain_tumor_model.load_state_dict(torch.load('brain_tumor_model.pth', map_location=torch.device('cpu')))
27
+ brain_tumor_model.eval()
28
+
29
+ # Define class labels
30
+ mri_classes = ["Brain MRI", "Not a Brain MRI"]
31
+ alzheimers_classes = ['Mild Dementia', 'Moderate Dementia', 'Non Demented', 'Very mild Dementia']
32
+ brain_tumor_classes = ['glioma', 'meningioma', 'notumor', 'pituitary']
33
+
34
+ # Define transformations
35
+ transform = transforms.Compose([
36
+ transforms.Resize((224, 224)),
37
+ transforms.ToTensor(),
38
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
39
+ ])
40
+
41
+ def generate_medical_report(diagnosis):
42
+ prompt = f"""
43
+ Generate a detailed medical report for a patient diagnosed with {diagnosis}.
44
+ Include possible causes, symptoms, treatment options, and prognosis.
45
+ Conclude the report with the signature: Team BrainTech.ai.
46
+ """
47
+ response = gemini_model.generate_content(prompt)
48
+ return response.text.strip()
49
+
50
+ def predict_pipeline(image, model_type):
51
+ # Step 1: Check if it's an MRI
52
+ image_resized = image.resize((224, 224))
53
+ image_array = np.array(image_resized) / 255.0
54
+ image_array = np.expand_dims(image_array, axis=0)
55
+ mri_prediction = mri_classifier.predict(image_array)
56
+ mri_class = mri_classes[np.argmax(mri_prediction)]
57
+ mri_confidence = np.max(mri_prediction) * 100 # Confidence score in %
58
+
59
+ if mri_class == "Not a Brain MRI":
60
+ return "Not a Brain MRI", None, None
61
+
62
+ # Step 2: Classify MRI
63
+ image_tensor = transform(image).unsqueeze(0)
64
+ if model_type == "Alzheimer's":
65
+ with torch.no_grad():
66
+ outputs = alzheimers_model(image_tensor).logits
67
+ probabilities = torch.nn.functional.softmax(outputs, dim=1)
68
+ confidence = torch.max(probabilities).item() * 100 # Confidence in %
69
+ predicted_class = alzheimers_classes[torch.argmax(outputs).item()]
70
+ elif model_type == "Brain Tumor":
71
+ with torch.no_grad():
72
+ outputs = brain_tumor_model(image_tensor).logits
73
+ probabilities = torch.nn.functional.softmax(outputs, dim=1)
74
+ confidence = torch.max(probabilities).item() * 100 # Confidence in %
75
+ predicted_class = brain_tumor_classes[torch.argmax(outputs).item()]
76
+
77
+ # Step 3: Generate medical report
78
+ report = generate_medical_report(predicted_class)
79
+
80
+ return predicted_class, confidence, report
81
+
82
+ def download_report(report_text):
83
+ """Convert report text into a downloadable format."""
84
+ buffer = io.BytesIO()
85
+ buffer.write(report_text.encode())
86
+ buffer.seek(0)
87
+ return buffer
88
+
89
+ # Streamlit UI
90
+ st.title("MRI Scan Classification Pipeline with Gemini AI")
91
+ st.write("Upload an image to check if it's an MRI, classify it, view confidence scores, and get an AI-generated medical report.")
92
+
93
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
94
+ model_type = st.selectbox("Select Model Type", ["Alzheimer's", "Brain Tumor"])
95
+
96
+ if st.button("Predict") and uploaded_file is not None:
97
+ image = Image.open(uploaded_file)
98
+ st.image(image, caption='Uploaded Image', use_column_width=True)
99
+ st.write("Classifying...")
100
+
101
+ # Run the prediction pipeline
102
+ result, confidence, report = predict_pipeline(image, model_type)
103
+
104
+ # Display results
105
+ st.write(f"**Prediction:** {result}")
106
+ if confidence is not None:
107
+ st.write(f"**Confidence Score:** {confidence:.2f}%")
108
+
109
+ # Display AI-Generated Report
110
+ if report:
111
+ st.subheader("AI-Generated Medical Report")
112
+ st.write(report)
113
+
114
+ # Download Report Button
115
+ report_buffer = download_report(report)
116
+ st.download_button(
117
+ label="Download Medical Report",
118
+ data=report_buffer,
119
+ file_name=f"medical_report_{result.replace(' ', '_')}.txt",
120
+ mime="text/plain"
121
+ )
122
+
123
+ # Warning Banner
124
+ st.warning("⚠️ Please consult a doctor before taking any medical decisions based on this report.")