# Page header text captured along with the file (not part of the program):
# kanneboinakumar's picture
# Update app.py
# 93dbc65 verified
# raw
# history blame
# 3.78 kB
import streamlit as st
import torch
import torch.nn as nn
from torchvision import transforms, datasets, models
from PIL import Image
# Global stylesheet for the app: page background image, the gradient title
# banner, the centred green button, and the pink ".output-container" box used
# for the prediction result. Injected as raw HTML.
_APP_CSS = """
<style>
/* Set background image for the entire app */
.stApp {
background: url('https://cdn.i-scmp.com/sites/default/files/images/methode/2016/05/06/078562fe-1286-11e6-95eb-aaf30b46b489_image_hires.JPG') no-repeat center center fixed;
background-size: cover;
}
.stApp h1 {
background: linear-gradient(to right, #4b6cb7, #182848);
color: white;
padding: 15px 25px;
border-radius: 15px;
font-size: 2.5em;
text-align: center;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
max-width: 90%;
margin: 30px auto;
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
/* Style for the button */
.stButton>button {
background-color: #4CAF50; /* Green */
color: white;
font-size: 1.2em;
border-radius: 10px;
padding: 10px 24px;
border: none;
}
/* Center the button */
.stButton {
display: flex;
justify-content: center;
}
/* Style for the output container */
.output-container {
background-color: lightpink;
color: black;
font-size: 1.5em;
padding: 15px;
border-radius: 10px;
margin-top: 20px;
box-shadow: 0 4px 8px rgba(0,0,0,0.1);
width: 100%;
margin-left: auto;
margin-right: auto;
text-align: center;
}
</style>
"""

# unsafe_allow_html is required so the <style> element is emitted as HTML.
st.markdown(_APP_CSS, unsafe_allow_html=True)
# Page heading (rendered as the <h1> that the ".stApp h1" CSS rule styles),
# followed by an empty write — presumably used as a vertical spacer.
_APP_TITLE = "Brain Tumor Classification"
st.title(_APP_TITLE)
st.write("")
# Labels for the four output logits, in the index order of the classifier
# head (presumably the training dataset's class order — confirm against the
# training code).
class_names = ['glioma_tumor', 'meningioma_tumor', 'no_tumor', 'pituitary_tumor']
num_of_classes = len(class_names)


@st.cache_resource
def _load_model(weights_path='resnet18_model (1).pth'):
    """Build the ResNet18 classifier and load the fine-tuned weights on CPU.

    Cached with ``st.cache_resource`` so the network is constructed and the
    checkpoint read from disk once per server process instead of on every
    Streamlit rerun (each widget interaction re-executes the whole script).

    Args:
        weights_path: Path to the saved ``state_dict`` checkpoint.

    Returns:
        The ``torch.nn.Module`` in evaluation mode.
    """
    # No ImageNet weights are requested: the strict load_state_dict() below
    # replaces every parameter and buffer, so downloading them is wasted work.
    net = models.resnet18(weights=None)
    net.fc = nn.Linear(net.fc.in_features, num_of_classes)

    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint cannot execute arbitrary code while loading.
    state_dict = torch.load(
        weights_path,
        map_location=torch.device('cpu'),
        weights_only=True,
    )
    net.load_state_dict(state_dict)
    net.eval()
    return net


# Load trained model (shared, cached instance; used for inference only).
model = _load_model()
num_of_features = model.fc.in_features
# Section heading shown above the file uploader.
# The background-color declaration needs a valid value and a terminating ';':
# without them the browser discards everything up to the next ';', which
# silently drops the font-weight declaration as well.
st.markdown(
    """
<div style='
background-color: white;
font-weight: bold;
font-size: 25px;
color: black;
margin-bottom: 10px;
'>
📤 Upload a Scan Image
</div>
""",
    unsafe_allow_html=True,
)
# HTML snippets rendered through st.markdown(..., unsafe_allow_html=True).
# The only interpolated value is a label taken from the fixed class_names
# list, so no user-controlled text reaches the HTML.
_RESULT_HTML = """
<div class="output-container">
🧠 <strong>Predicted Class:</strong> {label}
</div>
"""
_MISSING_UPLOAD_HTML = """
<div style='background-color: #f8d7da; padding: 10px; border-radius: 5px;'>
<h4 style='color: #721c24;'> ⚠️ Please upload an image </h4>
</div>
"""

# Preprocessing applied to every uploaded scan: resize to 224x224, convert to
# a tensor, then normalise each channel. The mean/std values are presumably
# the training-set statistics — confirm against the training code.
sample_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.1776, 0.1776, 0.1776], std=[0.1735, 0.1735, 0.1735]),
])

# Image upload. A real label is supplied (and visually collapsed, since the
# styled heading above already serves as the caption) instead of an empty
# string, which Streamlit discourages for accessibility reasons.
uploaded_img = st.file_uploader(
    "Upload a scan image",
    type=["jpg", "jpeg", "png"],
    label_visibility="collapsed",
)

if st.button("Submit"):
    if uploaded_img is None:
        # Nothing to classify yet: show the warning banner.
        st.markdown(_MISSING_UPLOAD_HTML, unsafe_allow_html=True)
    else:
        # Display uploaded image in a smaller size
        image = Image.open(uploaded_img)
        st.image(image, caption="**Uploaded Image**", width=200)

        # Force three channels before preprocessing: grayscale ("L"), palette
        # ("P") or alpha ("RGBA") uploads would otherwise produce a 1- or
        # 4-channel tensor that the 3-channel Normalize and the ResNet input
        # layer reject.
        transformed_img = sample_transform(image.convert("RGB")).unsqueeze(0)

        # Model inference (no gradients needed); argmax picks the class index.
        with torch.no_grad():
            pred = model(transformed_img).argmax(dim=1).item()

        # Stylish output box
        st.markdown(
            _RESULT_HTML.format(label=class_names[pred]),
            unsafe_allow_html=True,
        )