File size: 5,198 Bytes
0e83c47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import random
from collections import Counter

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import streamlit as st
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from torchvision import datasets

# CIFAR-10 class names, indexed by the integer label stored in the dataset
# and produced by the model's output layer.
CLASS_NAMES = [
    "Airplane", "Automobile", "Bird", "Cat", "Deer",
    "Dog", "Frog", "Horse", "Ship", "Truck",
]

# Tensor conversion applied to the visualisation copy of the dataset.
transform = transforms.Compose([transforms.ToTensor()])


@st.cache_resource
def load_dataset(root: str = "./data"):
    """Return the CIFAR-10 training split, downloading it into *root* on first use.

    Streamlit re-executes this whole script on every widget interaction. With
    ``download=True``, torchvision re-checks the downloaded files and reloads
    every training batch each time the dataset is constructed. Caching the
    object once per process avoids repeating that work on every rerun.
    """
    return datasets.CIFAR10(root=root, train=True, download=True, transform=transform)


dataset = load_dataset()

# Load Trained Model
@st.cache_resource
def load_model(path: str = "model.pth"):
    """Build a ResNet-18 with a CIFAR-10 output layer and load trained weights.

    Args:
        path: Location of a saved ``state_dict`` for this architecture
            (defaults to ``"model.pth"``, the file the app always used).

    Returns:
        The model on CPU, switched to evaluation mode.

    The result is cached with ``st.cache_resource``, so the weights are read
    once per process rather than on every Streamlit rerun.
    """
    # weights=None -> randomly initialised backbone. This is the current
    # torchvision spelling of the deprecated ``pretrained=False``.
    model = models.resnet18(weights=None)
    # Replace the default 1000-way classifier with one output per CIFAR-10 class.
    model.fc = nn.Linear(model.fc.in_features, len(CLASS_NAMES))
    # The file is consumed by load_state_dict, so it only needs tensors.
    # weights_only=True refuses to unpickle arbitrary objects.
    state_dict = torch.load(path, map_location=torch.device("cpu"), weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model

model = load_model()

# Sidebar navigation: the selected label is stored in `page` and decides
# which page section below is rendered on this run.
with st.sidebar:
    st.title("Navigation")
    page = st.radio(
        "Go to",
        ["Dataset", "Visualizations", "Model Metrics", "Predictor"],
    )

# 📌 Page renderers. Streamlit re-executes this script on every interaction,
# so each helper draws one page from scratch. The dispatch at the bottom
# calls the helper that matches the sidebar selection in `page`.


def _render_dataset_page():
    """Draw the CIFAR-10 description and ten random training images."""
    st.title("📊 CIFAR-10 Dataset Preview")

    # Dataset Information
    st.markdown("""
    ## 📝 About CIFAR-10
    The **CIFAR-10 dataset** is widely used in image classification research.
    - 📝 **Created by**: Alex Krizhevsky, Vinod Nair, Geoffrey Hinton  
    - 🏛 **From**: University of Toronto  
    - 📸 **Images**: 60,000 color images (**32×32 pixels**)  
    - 🏷 **Classes (10)**:  
      - 🛫 Airplane  
      - 🚗 Automobile  
      - 🐦 Bird  
      - 🐱 Cat  
      - 🦌 Deer  
      - 🐶 Dog  
      - 🐸 Frog  
      - 🐴 Horse  
      - 🚢 Ship  
      - 🚛 Truck  
    - 🔗 **[Dataset Link](https://www.cs.toronto.edu/~kriz/cifar.html)**
    """)

    # Show 10 Random Images: five columns with two images stacked in each.
    st.subheader("🔍 Random CIFAR-10 Images")
    cols = st.columns(5)
    to_pil = transforms.ToPILImage()  # the dataset yields tensors; st.image wants an image
    # random.sample draws 10 distinct indices, so no image is shown twice.
    for i, index in enumerate(random.sample(range(len(dataset)), 10)):
        image, label = dataset[index]
        cols[i % 5].image(to_pil(image), caption=CLASS_NAMES[label], use_container_width=True)


def _render_visualizations_page():
    """Draw pie and bar charts of how many training images each class has."""
    st.title("📊 Dataset Visualizations")

    # Count labels from dataset.targets instead of iterating the dataset.
    # Iterating would decode and tensor-convert every image just to read its label.
    label_counts = Counter(dataset.targets)
    class_counts = {name: label_counts[idx] for idx, name in enumerate(CLASS_NAMES)}
    # Plain lists: matplotlib cannot convert dict views to numeric arrays.
    names = list(class_counts)
    counts = list(class_counts.values())
    colors = sns.color_palette("husl", len(CLASS_NAMES))

    # Pie Chart
    st.subheader("📌 Class Distribution (Pie Chart)")
    fig, ax = plt.subplots()
    ax.pie(counts, labels=names, autopct="%1.1f%%", colors=colors)
    st.pyplot(fig)
    plt.close(fig)  # st.pyplot leaves the figure open; close it so figures don't pile up across reruns

    # Bar Chart. Plain matplotlib bars with the same husl palette avoid
    # seaborn's deprecated "palette without hue" form.
    st.subheader("📊 Class Distribution (Bar Chart)")
    fig, ax = plt.subplots()
    ax.bar(names, counts, color=colors)
    ax.tick_params(axis="x", labelrotation=45)
    st.pyplot(fig)
    plt.close(fig)


def _render_metrics_page():
    """Show accuracy, a per-class report and a confusion matrix from saved predictions."""
    st.title("📈 Model Performance")

    # Only the file loads are guarded. A missing file gets the friendly
    # message; any other failure (unreadable contents, mismatched lengths)
    # now surfaces with its real traceback instead of being mislabelled.
    try:
        y_true = torch.load("y_true.pth")
        y_pred = torch.load("y_pred.pth")
    except FileNotFoundError:
        st.error("🚨 Model metrics files not found!")
        return

    # Pin the label set to every class index so the report and matrix keep
    # one row per class, aligned with CLASS_NAMES, even if some class is
    # absent from the saved data.
    # NOTE(review): assumes the saved labels are integer class indices 0-9,
    # as CLASS_NAMES indexing elsewhere in this file implies.
    labels = list(range(len(CLASS_NAMES)))

    # Display Accuracy
    st.write(f"### ✅ Accuracy: **{accuracy_score(y_true, y_pred):.2f}**")

    # Classification Report
    report = classification_report(
        y_true, y_pred, labels=labels, target_names=CLASS_NAMES, output_dict=True
    )
    st.write(pd.DataFrame(report).T)

    # Confusion Matrix (scikit-learn convention: rows are true classes,
    # columns are predicted classes)
    st.subheader("🔄 Confusion Matrix")
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(
        cm, annot=True, fmt="d", cmap="Blues",
        xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES, ax=ax,
    )
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    st.pyplot(fig)
    plt.close(fig)


def _render_predictor_page():
    """Classify an uploaded image with the trained model and show the predicted class."""
    st.title("🔍 CIFAR-10 Image Classifier")

    # About the Classifier
    st.markdown("""
    ## 📝 About This App
    This app is a **deep learning image classifier** trained on the **CIFAR-10 dataset**.  
    It can recognize **10 different objects/animals**:
    - 🛫 Airplane, 🚗 Automobile, 🐦 Bird, 🐱 Cat, 🦌 Deer
    - 🐶 Dog, 🐸 Frog, 🐴 Horse, 🚢 Ship, 🚛 Truck
    """)

    # Upload Image
    uploaded_file = st.file_uploader("📤 Upload an image", type=["jpg", "png", "jpeg"])
    if uploaded_file is None:
        return

    image = Image.open(uploaded_file)
    st.image(image, caption="🖼 Uploaded Image", use_container_width=True)

    # Preprocess for the model. A local name avoids clobbering the module's
    # dataset `transform`. Normalize([0.5], [0.5]) broadcasts one mean/std
    # across all channels, mapping ToTensor's [0, 1] output to [-1, 1].
    # NOTE(review): presumably mirrors the training preprocessing (224x224,
    # this normalisation); confirm against the training script.
    predict_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    # PNG uploads may be RGBA, palette or greyscale. Converting to RGB gives
    # exactly the three channels the ResNet input layer expects.
    image_tensor = predict_transform(image.convert("RGB")).unsqueeze(0)

    # Make prediction
    with torch.no_grad():
        output = model(image_tensor)
    predicted_class = torch.argmax(output, dim=1).item()

    # Display Prediction
    st.success(f"### ✅ Prediction: **{CLASS_NAMES[predicted_class]}**")


# Sidebar option -> renderer. An unrecognised value renders nothing, as the
# original if/elif chain did.
_PAGE_RENDERERS = {
    "Dataset": _render_dataset_page,
    "Visualizations": _render_visualizations_page,
    "Model Metrics": _render_metrics_page,
    "Predictor": _render_predictor_page,
}

_render_page = _PAGE_RENDERERS.get(page)
if _render_page is not None:
    _render_page()