# ImageClassifier / app.py
# Regino
# sjsbfjsd
# 0e83c47
import streamlit as st
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import random
from PIL import Image
from torchvision import datasets
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
# CIFAR-10 class labels; position i is the display name for integer label i
# (the pages below index this list with dataset/model label ids).
CLASS_NAMES = (
    "Airplane Automobile Bird Cat Deer "
    "Dog Frog Horse Ship Truck"
).split()
# Load CIFAR-10 Dataset for Visualization
transform = transforms.Compose([transforms.ToTensor()])


@st.cache_resource
def load_dataset():
    """Return the CIFAR-10 training split with images converted to tensors.

    Streamlit re-executes this whole script on every widget interaction, so
    building the dataset at module level re-read the files under ./data on
    every rerun. Caching the object means it is constructed once per server
    process and shared by all reruns.
    """
    return datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)


dataset = load_dataset()
# Load Trained Model
@st.cache_resource
def load_model():
    """Build a ResNet-18 with one output per CIFAR-10 class and load model.pth.

    Cached with ``st.cache_resource`` so the checkpoint is read once per
    server process instead of on every Streamlit rerun.

    Returns:
        The model with weights mapped to CPU, switched to eval mode.
    """
    # weights=None: no pretrained weights; all parameters come from model.pth.
    # (Replaces the deprecated ``pretrained=False`` keyword.)
    model = models.resnet18(weights=None)
    # Swap the default classification head for a len(CLASS_NAMES)-way one so
    # the architecture matches the saved state dict.
    model.fc = nn.Linear(model.fc.in_features, len(CLASS_NAMES))
    # weights_only=True limits unpickling to tensors and plain containers, so a
    # tampered checkpoint file cannot execute arbitrary code on load.
    state_dict = torch.load("model.pth", map_location=torch.device("cpu"), weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model


model = load_model()
# Sidebar navigation: the selected page name drives the page dispatch below.
with st.sidebar:
    st.title("Navigation")
    page = st.radio("Go to", ["Dataset", "Visualizations", "Model Metrics", "Predictor"])
# 📌 Dataset Preview Page
if page == "Dataset":
    st.title("📊 CIFAR-10 Dataset Preview")

    # Dataset Information
    st.markdown("""
    ## 📝 About CIFAR-10
    The **CIFAR-10 dataset** is widely used in image classification research.
    - 📝 **Created by**: Alex Krizhevsky, Vinod Nair, Geoffrey Hinton
    - 🏛 **From**: University of Toronto
    - 📸 **Images**: 60,000 color images (**32×32 pixels**)
    - 🏷 **Classes (10)**:
        - 🛫 Airplane
        - 🚗 Automobile
        - 🐦 Bird
        - 🐱 Cat
        - 🦌 Deer
        - 🐶 Dog
        - 🐸 Frog
        - 🐴 Horse
        - 🚢 Ship
        - 🚛 Truck
    - 🔗 **[Dataset Link](https://www.cs.toronto.edu/~kriz/cifar.html)**
    """)

    # Show 10 Random Images, laid out as two rows of 5 columns.
    st.subheader("🔍 Random CIFAR-10 Images")
    cols = st.columns(5)
    to_pil = transforms.ToPILImage()  # dataset yields tensors; st.image wants a PIL image
    # random.sample draws 10 *distinct* indices (repeated randint calls could
    # show the same picture twice).
    for i, index in enumerate(random.sample(range(len(dataset)), 10)):
        img_tensor, label = dataset[index]
        cols[i % 5].image(to_pil(img_tensor), caption=CLASS_NAMES[label], use_container_width=True)
# πŸ“ˆ Visualization Page
elif page == "Visualizations":
st.title("πŸ“Š Dataset Visualizations")
# Count class occurrences
class_counts = {cls: 0 for cls in CLASS_NAMES}
for _, label in dataset:
class_counts[CLASS_NAMES[label]] += 1
# Pie Chart
st.subheader("πŸ“Œ Class Distribution (Pie Chart)")
fig, ax = plt.subplots()
colors = sns.color_palette("husl", len(CLASS_NAMES))
ax.pie(class_counts.values(), labels=class_counts.keys(), autopct='%1.1f%%', colors=colors)
st.pyplot(fig)
# Bar Chart
st.subheader("πŸ“Š Class Distribution (Bar Chart)")
fig, ax = plt.subplots()
sns.barplot(x=list(class_counts.keys()), y=list(class_counts.values()), palette="husl")
plt.xticks(rotation=45)
st.pyplot(fig)
# πŸ“Š Model Metrics Page
elif page == "Model Metrics":
st.title("πŸ“ˆ Model Performance")
try:
y_true = torch.load("y_true.pth")
y_pred = torch.load("y_pred.pth")
# Display Accuracy
st.write(f"### βœ… Accuracy: **{accuracy_score(y_true, y_pred):.2f}**")
# Classification Report
report = classification_report(y_true, y_pred, target_names=CLASS_NAMES, output_dict=True)
st.write(pd.DataFrame(report).T)
# Confusion Matrix
st.subheader("πŸ”„ Confusion Matrix")
cm = confusion_matrix(y_true, y_pred)
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES)
st.pyplot(fig)
except:
st.error("🚨 Model metrics files not found!")
# πŸ” Prediction Page
elif page == "Predictor":
st.title("πŸ” CIFAR-10 Image Classifier")
# About the Classifier
st.markdown("""
## πŸ“ About This App
This app is a **deep learning image classifier** trained on the **CIFAR-10 dataset**.
It can recognize **10 different objects/animals**:
- πŸ›« Airplane, πŸš— Automobile, 🐦 Bird, 🐱 Cat, 🦌 Deer
- 🐢 Dog, 🐸 Frog, 🐴 Horse, 🚒 Ship, πŸš› Truck
""")
# Upload Image
uploaded_file = st.file_uploader("πŸ“€ Upload an image", type=["jpg", "png", "jpeg"])
if uploaded_file is not None:
image = Image.open(uploaded_file)
st.image(image, caption="πŸ–Ό Uploaded Image", use_container_width=True)
# Transform image for model
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
])
image_tensor = transform(image).unsqueeze(0)
# Make prediction
with torch.no_grad():
output = model(image_tensor)
predicted_class = torch.argmax(output, dim=1).item()
# Display Prediction
st.success(f"### βœ… Prediction: **{CLASS_NAMES[predicted_class]}**")