# Streamlit dashboard for CIFAR-10: dataset preview, class-distribution charts,
# saved evaluation metrics, and a single-image predictor backed by a ResNet-18
# whose weights are read from "model.pth".

import random
from collections import Counter

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import streamlit as st
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from torchvision import datasets

# CIFAR-10 class names, indexed by integer label.
CLASS_NAMES = [
    "Airplane", "Automobile", "Bird", "Cat", "Deer",
    "Dog", "Frog", "Horse", "Ship", "Truck",
]

# Number of random preview images and the column count of the preview grid.
_PREVIEW_IMAGES = 10
_PREVIEW_COLUMNS = 5

# Markdown bodies are kept at column 0 so no line is indented far enough to be
# rendered as a code block.
_ABOUT_DATASET_MD = """
## 📝 About CIFAR-10
The **CIFAR-10 dataset** is widely used in image classification research.

- 📍 **Created by**: Alex Krizhevsky, Vinod Nair, Geoffrey Hinton
- 🏛 **From**: University of Toronto
- 📸 **Images**: 60,000 color images (**32×32 pixels**)
- 🏷 **Classes (10)**:
    - 🛫 Airplane
    - 🚗 Automobile
    - 🐦 Bird
    - 🐱 Cat
    - 🦌 Deer
    - 🐶 Dog
    - 🐸 Frog
    - 🐴 Horse
    - 🚢 Ship
    - 🚛 Truck
- 🔗 **[Dataset Link](https://www.cs.toronto.edu/~kriz/cifar.html)**
"""

_ABOUT_APP_MD = """
## 📝 About This App
This app is a **deep learning image classifier** trained on the **CIFAR-10 dataset**.
It can recognize **10 different objects/animals**:

- 🛫 Airplane, 🚗 Automobile, 🐦 Bird, 🐱 Cat, 🦌 Deer
- 🐶 Dog, 🐸 Frog, 🐴 Horse, 🚢 Ship, 🚛 Truck
"""

# Transform applied to dataset images used for visualization.
transform = transforms.Compose([transforms.ToTensor()])


@st.cache_resource
def load_dataset():
    """Return the CIFAR-10 training split, downloading it to ./data if needed.

    Cached so the dataset object is built once instead of on every script
    rerun that Streamlit triggers.
    """
    return datasets.CIFAR10(
        root="./data", train=True, download=True, transform=transform
    )


dataset = load_dataset()


@st.cache_resource
def load_model() -> nn.Module:
    """Build a ResNet-18 with a ``len(CLASS_NAMES)``-way head and load "model.pth".

    Returns:
        The network on CPU, switched to eval mode.
    """
    # weights=None is the non-deprecated spelling of pretrained=False.
    net = models.resnet18(weights=None)
    net.fc = nn.Linear(net.fc.in_features, len(CLASS_NAMES))
    # NOTE: torch.load unpickles its input; only point it at trusted files.
    state_dict = torch.load("model.pth", map_location=torch.device("cpu"))
    net.load_state_dict(state_dict)
    net.eval()
    return net


model = load_model()


def _show_and_close(fig) -> None:
    """Render a matplotlib figure in the app, then release it."""
    try:
        st.pyplot(fig)
    finally:
        # Without this, figures pile up in pyplot's registry across reruns.
        plt.close(fig)


def _render_dataset_page() -> None:
    """Show the dataset description and a grid of random sample images."""
    st.title("📊 CIFAR-10 Dataset Preview")
    st.markdown(_ABOUT_DATASET_MD)

    st.subheader("🔍 Random CIFAR-10 Images")
    cols = st.columns(_PREVIEW_COLUMNS)
    to_pil = transforms.ToPILImage()
    # sample() draws distinct indices, so the grid never repeats an image.
    how_many = min(_PREVIEW_IMAGES, len(dataset))
    picks = random.sample(range(len(dataset)), k=how_many)
    for slot, index in enumerate(picks):
        tensor, label = dataset[index]
        cols[slot % _PREVIEW_COLUMNS].image(
            to_pil(tensor), caption=CLASS_NAMES[label], use_container_width=True
        )


def _render_visualizations_page() -> None:
    """Show the per-class image counts as a pie chart and a bar chart."""
    st.title("📊 Dataset Visualizations")

    # Read labels directly instead of iterating the dataset, which would
    # decode and transform every image just to discard it.
    label_counts = Counter(dataset.targets)
    class_counts = {
        name: label_counts[label] for label, name in enumerate(CLASS_NAMES)
    }
    names = list(class_counts)
    counts = list(class_counts.values())

    st.subheader("📌 Class Distribution (Pie Chart)")
    fig, ax = plt.subplots()
    colors = sns.color_palette("husl", len(CLASS_NAMES))
    ax.pie(counts, labels=names, autopct='%1.1f%%', colors=colors)
    _show_and_close(fig)

    st.subheader("📊 Class Distribution (Bar Chart)")
    fig, ax = plt.subplots()
    # Draw on an explicit Axes rather than pyplot's global "current axes".
    sns.barplot(x=names, y=counts, palette="husl", ax=ax)
    ax.tick_params(axis="x", labelrotation=45)
    _show_and_close(fig)


def _render_metrics_page() -> None:
    """Show accuracy, classification report and confusion matrix.

    Reads labels and predictions from "y_true.pth" and "y_pred.pth".
    """
    st.title("📈 Model Performance")
    try:
        y_true = torch.load("y_true.pth")
        y_pred = torch.load("y_pred.pth")

        st.write(f"### ✅ Accuracy: **{accuracy_score(y_true, y_pred):.2f}**")

        report = classification_report(
            y_true, y_pred, target_names=CLASS_NAMES, output_dict=True
        )
        st.write(pd.DataFrame(report).T)

        st.subheader("🔄 Confusion Matrix")
        cm = confusion_matrix(y_true, y_pred)
        fig, ax = plt.subplots(figsize=(8, 6))
        try:
            sns.heatmap(
                cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES, ax=ax,
            )
            st.pyplot(fig)
        finally:
            plt.close(fig)
    except FileNotFoundError:
        st.error("🚨 Model metrics files not found!")
    except Exception as exc:
        # Page-level boundary: report the real cause instead of claiming the
        # files are missing when loading or scoring failed for another reason.
        st.error(f"🚨 Could not display model metrics: {exc}")


def _render_predictor_page() -> None:
    """Classify an uploaded image with the loaded model and show the result."""
    st.title("🔍 CIFAR-10 Image Classifier")
    st.markdown(_ABOUT_APP_MD)

    uploaded_file = st.file_uploader("📤 Upload an image", type=["jpg", "png", "jpeg"])
    if uploaded_file is None:
        return

    try:
        image = Image.open(uploaded_file)
        # Uploads may be RGBA, palette or grayscale; the network's first
        # convolution needs exactly three channels.
        rgb_image = image.convert("RGB")
    except OSError as exc:
        st.error(f"🚨 Could not read the uploaded image: {exc}")
        return

    st.image(image, caption="🖼 Uploaded Image", use_container_width=True)

    # NOTE(review): resize and normalization should mirror whatever was used
    # when model.pth was trained — confirm against the training script.
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    image_tensor = preprocess(rgb_image).unsqueeze(0)  # add batch dimension

    with torch.no_grad():
        output = model(image_tensor)
    predicted_class = torch.argmax(output, dim=1).item()

    st.success(f"### ✅ Prediction: **{CLASS_NAMES[predicted_class]}**")


# Sidebar label -> page renderer, in display order.
_PAGES = {
    "Dataset": _render_dataset_page,
    "Visualizations": _render_visualizations_page,
    "Model Metrics": _render_metrics_page,
    "Predictor": _render_predictor_page,
}

# Sidebar navigation
st.sidebar.title("Navigation")
page = st.sidebar.radio("Go to", list(_PAGES))
_PAGES[page]()