Spaces:
Running
Running
import pandas as pd | |
import plotly.graph_objects as go | |
import streamlit as st | |
from PIL import Image | |
from joypy import joyplot | |
import seaborn as sns | |
import matplotlib.pyplot as plt | |
def load_dataset():
    """Load the MVTec meta-features CSV and tag every row with a subclass.

    Reads ``Data/mvtec_meta_features_dataset.csv``, verifies that the
    ``category``, ``set_type`` and ``anomaly_status`` columns are present,
    derives a ``subclass`` column from ``category`` and places it directly
    after ``category``.

    Returns:
        pd.DataFrame | None: The augmented DataFrame, or ``None`` when loading
        or validation fails (the failure is reported through ``st.error``).
    """
    file_path = "Data/mvtec_meta_features_dataset.csv"
    required_columns = ["category", "set_type", "anomaly_status"]

    # Subclass -> member categories; anything not listed becomes 'Unknown'.
    subclasses = {
        'Texture-Based': ['carpet', 'wood', 'tile', 'leather', 'zipper'],
        'Industrial Components': ['cable', 'transistor', 'screw', 'grid', 'metal_nut'],
        'Consumer Products': ['bottle', 'capsule', 'toothbrush'],
        'Edible': ['hazelnut', 'pill'],
    }
    # Invert the table once so each row needs a single dict lookup instead of
    # a scan over every member list. ``setdefault`` keeps the first matching
    # subclass should a category ever be listed twice.
    category_to_subclass = {}
    for subclass_name, members in subclasses.items():
        for member in members:
            category_to_subclass.setdefault(member, subclass_name)

    # Broad catch on purpose: this is the UI boundary, and callers rely on a
    # ``None`` return (not an exception) to detect failure.
    try:
        complete_df = pd.read_csv(file_path)

        # Show available column names for debugging
        print("Available columns:", complete_df.columns)

        # Report every missing column at once rather than only the first.
        missing = [col for col in required_columns if col not in complete_df.columns]
        if missing:
            raise KeyError(f"Missing required column(s): {', '.join(missing)}")

        complete_df['subclass'] = complete_df['category'].map(
            lambda category: category_to_subclass.get(category, 'Unknown')
        )

        # Place 'subclass' right after 'category'. Removing it from the list
        # before computing the target index keeps the position correct even
        # when the CSV already had a 'subclass' column ahead of 'category'.
        ordered = [col for col in complete_df.columns if col != 'subclass']
        ordered.insert(ordered.index('category') + 1, 'subclass')
        return complete_df[ordered]
    except Exception as e:
        st.error(f"Error loading dataset: {e}")
        return None
def dataset_statistics():
    """Summarise image counts per category by split and anomaly status.

    Loads the dataset through ``load_dataset`` and counts, for every
    category, the train/normal, test/normal and test/anomalous rows.

    Returns:
        pd.DataFrame | None: One row per category with the columns
        ``category``, ``Train Normal Images``, ``Test Normal Images`` and
        ``Test Anomalous Images`` (integer counts), or ``None`` when the
        dataset could not be loaded.
    """
    df = load_dataset()
    if df is None:
        return None

    print("Loaded dataset preview:\n", df.head())  # Debugging step

    # Build each boolean mask once and combine them per output column.
    is_train = df['set_type'] == 'train'
    is_test = df['set_type'] == 'test'
    is_normal = df['anomaly_status'] == 'normal'
    is_anomalous = df['anomaly_status'] == 'anomalous'
    selections = {
        'Train Normal Images': is_train & is_normal,
        'Test Normal Images': is_test & is_normal,
        'Test Anomalous Images': is_test & is_anomalous,
    }
    counts = {
        label: df[mask].groupby('category').size()
        for label, mask in selections.items()
    }

    # Aligning the three Series introduces NaN for categories missing from a
    # split, which silently turns the counts into floats; fill the gaps and
    # cast back so the counts stay integers.
    final_summary = (
        pd.DataFrame(counts)
        .fillna(0)
        .astype(int)
        .rename_axis('category')  # guarantee the column name after reset_index
        .reset_index()
    )
    return final_summary
def dataset_distribution_chart(df):
    """Render a stacked bar chart of image counts per category in Streamlit.

    Parameters:
        df (pd.DataFrame): Summary table with a ``category`` column and the
            count columns ``Train Normal Images``, ``Test Normal Images`` and
            ``Test Anomalous Images``.

    Returns:
        None
    """
    # (count column, bar colour) for each stacked series, in drawing order.
    series_styles = (
        ('Train Normal Images', 'blue'),
        ('Test Normal Images', 'red'),
        ('Test Anomalous Images', 'green'),
    )

    fig = go.Figure()
    for column_name, bar_colour in series_styles:
        bar = go.Bar(
            x=df['category'],
            y=df[column_name],
            name=column_name,
            marker_color=bar_colour,
        )
        fig.add_trace(bar)

    layout_options = {
        'title': "Distribution of Normal and Anomalous Images per Category",
        'xaxis_title': "Categories",
        'yaxis_title': "Number of Images",
        'barmode': 'stack',
        'legend_title': "Image Types",
    }
    fig.update_layout(**layout_options)

    # Hand the finished figure to Streamlit for display.
    st.plotly_chart(fig, use_container_width=True)
def display_dataframe():
    """Show the full dataset inside a collapsible Streamlit expander.

    Nothing is rendered when ``load_dataset`` returns ``None``.
    """
    dataset = load_dataset()
    if dataset is None:
        return

    with st.expander("Show Complete DataFrame"):
        st.dataframe(dataset)
def plot_bgr_pixel_densities(df, pixel_columns=None):
    """Render a JoyPy density plot of the BGR pixel-count columns in Streamlit.

    Parameters:
        df (pd.DataFrame): Data to plot; expected to be filtered to a single
            category (the title uses only the first ``category`` value).
        pixel_columns (list | None): Column names for the B, G and R pixel
            counts. Defaults to
            ``['num_pixels_b', 'num_pixels_g', 'num_pixels_r']``.

    Returns:
        None
    """
    # ``None`` sentinel instead of a list literal in the signature, so the
    # default is never a single list object shared between calls.
    if pixel_columns is None:
        pixel_columns = ['num_pixels_b', 'num_pixels_g', 'num_pixels_r']

    if df.empty:
        st.warning("⚠️ No data available for the selected category.")
        return

    # Plot JoyPy density plot
    fig, _axes = joyplot(
        data=df,
        by="category",  # Group by category
        column=pixel_columns,
        color=['blue', 'green', 'red'],  # Colors for BGR channels
        alpha=0.5,
        fade=True,
        legend=True,
        linewidth=1.0,
        overlap=3,
        figsize=(8, 6),  # Adjust the figure size here
    )

    # Add title and labels
    plt.title(f'Density Plots for {df["category"].unique()[0]}', fontsize=14)
    plt.xlabel('Number of Pixels Density', fontsize=12)
    plt.ylabel('Categories', fontsize=12)

    # Show the plot in Streamlit
    st.pyplot(fig)
    # pyplot keeps a reference to every figure it creates; close this one
    # once rendered so figures do not pile up over repeated calls.
    plt.close(fig)
def plot_pair_plots(complete_df):
    """Render one seaborn pair plot per category in Streamlit.

    For every distinct ``category`` value a ``PairGrid`` over the three pixel
    count columns and ``perceived_brightness`` is drawn, coloured by
    ``anomaly_status``, and displayed with ``st.pyplot``.

    Parameters:
        complete_df (pd.DataFrame): The input DataFrame containing image
            features, ``category`` and ``anomaly_status``.

    Returns:
        None
    """
    # Define the features to be included in the pairplot
    features = ['num_pixels_b', 'num_pixels_g', 'num_pixels_r', 'perceived_brightness']
    status_palette = {'normal': 'blue', 'anomalous': 'red'}

    # Create a separate pairplot for each category
    for category in complete_df['category'].unique():
        category_df = complete_df[complete_df['category'] == category]
        # An equality filter can come back empty (e.g. for a missing/NaN
        # category value); there is nothing to plot in that case.
        if category_df.empty:
            continue

        # Create PairGrid with hue and palette
        g = sns.PairGrid(
            category_df,
            vars=features,
            hue='anomaly_status',
            palette=status_palette,
        )

        # Map the plots to the grid
        g.map_upper(sns.scatterplot, alpha=0.6)
        g.map_diag(sns.histplot, kde=True)
        g.map_lower(sns.scatterplot, alpha=0.6)
        g.add_legend()

        g.figure.suptitle(
            f'Feature Relationships for {category.title()}', y=1.02, fontsize=14
        )

        # Improve label readability: "num_pixels_b" -> "Num Pixels B".
        for ax in g.axes.flat:
            if ax is not None:
                ax.set_xlabel(ax.get_xlabel().replace('_', ' ').title())
                ax.set_ylabel(ax.get_ylabel().replace('_', ' ').title())

        # Adjust legend position to the right without overlapping the plots
        g._legend.set_bbox_to_anchor((1.05, 0.5))
        g._legend.set_loc('center left')

        plt.tight_layout()
        st.pyplot(g.figure)
        # One figure is created per category; close each after rendering so
        # they do not accumulate in pyplot's figure registry.
        plt.close(g.figure)