# Captured from a Hugging Face Spaces file view (Space status: "Sleeping").
# File size: 5,198 bytes; commit 0e83c47.
import random
from collections import Counter

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import streamlit as st
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from torchvision import datasets
# CIFAR-10 class names, indexed by the integer label that the dataset yields
# and that the model's output layer predicts (label 0 -> "Airplane", ...).
CLASS_NAMES = [
    "Airplane", "Automobile", "Bird", "Cat", "Deer",
    "Dog", "Frog", "Horse", "Ship", "Truck",
]

# Tensor conversion applied to dataset samples used for visualization.
transform = transforms.Compose([transforms.ToTensor()])


@st.cache_resource
def load_dataset():
    """Return the CIFAR-10 training split, downloading it on first use.

    Streamlit re-executes this whole script on every widget interaction, so
    without caching the dataset object would be rebuilt (and, because of
    ``download=True``, re-checked on disk) on every rerun. ``st.cache_resource``
    builds it once and hands the same object back on later reruns.
    """
    return datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)


dataset = load_dataset()
# Load Trained Model
@st.cache_resource
def load_model(path="model.pth"):
    """Build the CIFAR-10 ResNet-18 classifier and load its trained weights.

    Decorated with ``st.cache_resource`` so the checkpoint is read once and the
    same model object is reused across Streamlit reruns.

    Args:
        path: Checkpoint file holding the model's ``state_dict``; defaults to
            ``"model.pth"`` relative to the working directory.

    Returns:
        The ResNet-18 in evaluation mode, with its weights on the CPU.
    """
    # ``weights=None`` replaces the deprecated ``pretrained=False`` (torchvision
    # >= 0.13): the architecture starts untrained and is filled from ``path``.
    model = models.resnet18(weights=None)
    # Swap the default classifier head for one output per CIFAR-10 class.
    model.fc = nn.Linear(model.fc.in_features, len(CLASS_NAMES))
    # map_location forces every tensor onto the CPU, so a checkpoint saved
    # from a GPU still loads on a CPU-only host.
    state_dict = torch.load(path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()  # inference mode (e.g. BatchNorm uses its running statistics)
    return model
# Trained classifier, cached across reruns by load_model's decorator.
model = load_model()

# Sidebar page selector; the chosen label decides which view renders below.
sidebar = st.sidebar
sidebar.title("Navigation")
page = sidebar.radio(
    "Go to",
    ["Dataset", "Visualizations", "Model Metrics", "Predictor"],
)
# Markdown shown on the "Dataset" page. Top-level lines are kept at column 0
# so none of them picks up a 4-space indent (an indented code block in Markdown).
_ABOUT_CIFAR10_MD = """
## 📖 About CIFAR-10
The **CIFAR-10 dataset** is widely used in image classification research.
- 📝 **Created by**: Alex Krizhevsky, Vinod Nair, Geoffrey Hinton
- 🎓 **From**: University of Toronto
- 📸 **Images**: 60,000 color images (**32×32 pixels**)
- 🏷 **Classes (10)**:
  - 🛫 Airplane
  - 🚗 Automobile
  - 🐦 Bird
  - 🐱 Cat
  - 🦌 Deer
  - 🐶 Dog
  - 🐸 Frog
  - 🐴 Horse
  - 🚢 Ship
  - 🚚 Truck
- 🔗 **[Dataset Link](https://www.cs.toronto.edu/~kriz/cifar.html)**
"""

# Markdown shown on the "Predictor" page.
_ABOUT_APP_MD = """
## 📖 About This App
This app is a **deep learning image classifier** trained on the **CIFAR-10 dataset**.
It can recognize **10 different objects/animals**:
- 🛫 Airplane, 🚗 Automobile, 🐦 Bird, 🐱 Cat, 🦌 Deer
- 🐶 Dog, 🐸 Frog, 🐴 Horse, 🚢 Ship, 🚚 Truck
"""

# Preprocessing for uploaded images. Normalize with one-element mean/std
# broadcasts to every channel, i.e. (x - 0.5) / 0.5.
# NOTE(review): assumed to match the training pipeline (224x224 resize, this
# normalization); the training script is not visible here, so confirm there.
_PREDICT_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])


def _render_dataset_page():
    """Render the "Dataset" page: dataset background plus ten random samples."""
    st.title("📂 CIFAR-10 Dataset Preview")
    st.markdown(_ABOUT_CIFAR10_MD)

    st.subheader("🔀 Random CIFAR-10 Images")
    num_samples, num_cols = 10, 5
    cols = st.columns(num_cols)
    to_pil = transforms.ToPILImage()  # dataset yields tensors; st.image wants an image
    # random.sample draws distinct indices, so the grid never repeats an image
    # (ten independent randint calls could).
    indices = random.sample(range(len(dataset)), num_samples)
    for i, index in enumerate(indices):
        image, label = dataset[index]
        cols[i % num_cols].image(to_pil(image), caption=CLASS_NAMES[label], use_container_width=True)


def _render_visualizations_page():
    """Render the "Visualizations" page: per-class image counts as pie and bar charts."""
    st.title("📊 Dataset Visualizations")

    # Count labels straight from ``dataset.targets``; iterating the dataset
    # itself would run the ToTensor transform on every training image.
    label_counts = Counter(dataset.targets)
    counts = [label_counts[label] for label in range(len(CLASS_NAMES))]
    colors = sns.color_palette("husl", len(CLASS_NAMES))

    st.subheader("📊 Class Distribution (Pie Chart)")
    fig, ax = plt.subplots()
    # Plain lists: matplotlib converts ``x`` with NumPy, which does not treat
    # dict views as sequences.
    ax.pie(counts, labels=CLASS_NAMES, autopct="%1.1f%%", colors=colors)
    st.pyplot(fig)
    plt.close(fig)  # pyplot keeps figures alive until closed; reruns would pile them up

    st.subheader("📊 Class Distribution (Bar Chart)")
    fig, ax = plt.subplots()
    # Same husl colors as the pie; drawn directly because seaborn deprecates
    # ``palette`` without ``hue``.
    ax.bar(CLASS_NAMES, counts, color=colors)
    ax.tick_params(axis="x", labelrotation=45)
    st.pyplot(fig)
    plt.close(fig)


def _render_metrics_page():
    """Render the "Model Metrics" page from saved ``y_true.pth``/``y_pred.pth`` labels."""
    st.title("📈 Model Performance")
    try:
        y_true = torch.load("y_true.pth")
        y_pred = torch.load("y_pred.pth")
    except FileNotFoundError:
        # Only a missing file gets this message; other failures (corrupt file,
        # errors inside sklearn) surface normally instead of being misreported.
        st.error("🚨 Model metrics files not found!")
        return

    # Fixed label set so the report and matrix stay 10-class (and aligned with
    # CLASS_NAMES) even if some class never occurs in y_true/y_pred.
    labels = list(range(len(CLASS_NAMES)))

    st.write(f"### ✅ Accuracy: **{accuracy_score(y_true, y_pred):.2f}**")

    report = classification_report(
        y_true, y_pred, labels=labels, target_names=CLASS_NAMES, output_dict=True
    )
    st.write(pd.DataFrame(report).T)

    st.subheader("📋 Confusion Matrix")
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(
        cm, annot=True, fmt="d", cmap="Blues",
        xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES, ax=ax,
    )
    # sklearn's convention: rows are true labels, columns are predictions.
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    st.pyplot(fig)
    plt.close(fig)


def _render_predictor_page():
    """Render the "Predictor" page: classify an uploaded image with ``model``."""
    st.title("🔍 CIFAR-10 Image Classifier")
    st.markdown(_ABOUT_APP_MD)

    uploaded_file = st.file_uploader("📤 Upload an image", type=["jpg", "png", "jpeg"])
    if uploaded_file is None:
        return

    image = Image.open(uploaded_file)
    st.image(image, caption="🖼 Uploaded Image", use_container_width=True)

    # PNG uploads may be RGBA, grayscale or palette images; the ResNet-18
    # input layer expects exactly 3 channels, so convert before transforming.
    image_tensor = _PREDICT_TRANSFORM(image.convert("RGB")).unsqueeze(0)  # add batch dim
    with torch.no_grad():
        output = model(image_tensor)
    predicted_class = torch.argmax(output, dim=1).item()

    st.success(f"### ✅ Prediction: **{CLASS_NAMES[predicted_class]}**")


# Sidebar label -> page renderer.
_PAGE_RENDERERS = {
    "Dataset": _render_dataset_page,
    "Visualizations": _render_visualizations_page,
    "Model Metrics": _render_metrics_page,
    "Predictor": _render_predictor_page,
}

# Render the page chosen in the sidebar; an unknown value renders nothing.
_render_page = _PAGE_RENDERERS.get(page)
if _render_page is not None:
    _render_page()
|