File size: 3,659 Bytes
9f6b670
5ffb8a0
 
 
 
 
 
9f6b670
 
 
 
87a1e68
9f6b670
 
 
 
87a1e68
9f6b670
bef3075
b43dafd
bef3075
5ffb8a0
bef3075
 
 
 
5ffb8a0
bef3075
 
5ffb8a0
 
 
 
 
b43dafd
 
bef3075
b43dafd
 
 
 
 
 
 
 
 
9f6b670
ac157d0
9f6b670
 
 
 
 
 
 
b43dafd
 
 
 
 
 
 
 
9f6b670
 
87a1e68
9f6b670
 
b43dafd
 
 
9f6b670
b43dafd
 
 
 
 
 
9f6b670
 
 
 
 
 
 
87a1e68
9f6b670
 
87a1e68
9f6b670
 
 
 
 
b43dafd
 
 
bef3075
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
import streamlit as st

# Must be the first Streamlit command executed by this script; it is placed
# ahead of the remaining imports, presumably so that nothing they trigger can
# issue a Streamlit call before the page is configured.
st.set_page_config(page_title="Image Classifier", layout="wide")

# Now import other libraries
import tensorflow as tf
import numpy as np
import pickle
from PIL import Image

# Constants
# Both paths are relative, so they resolve against the working directory the
# app is launched from.
# Model file handed to tf.keras.models.load_model() in load_resources().
MODEL_PATH = "image_classification.h5"
# Pickled label encoder; predict() calls its inverse_transform() to turn the
# predicted class index back into a label.
LABEL_ENCODER_PATH = "le.pkl"
# Target size passed to PIL's Image.resize() in preprocess_image().
EXPECTED_SIZE = (64, 64)  # Update this based on your model's input shape

def load_resources():
    """Load the Keras model and the pickled label encoder from disk.

    The model is read from ``MODEL_PATH`` with ``compile=False`` and a pair of
    ``custom_objects`` shims for the ``InputLayer`` and ``DTypePolicy`` entries
    found in the saved config. The label encoder is unpickled from
    ``LABEL_ENCODER_PATH``.

    Returns:
        tuple: ``(model, label_encoder)`` on success. If either file fails to
        load, the error is reported with ``st.error`` and ``(None, None)`` is
        returned instead of raising.
    """

    def _build_input_layer(**config):
        """Build an InputLayer from a saved config, keeping its input shape."""
        try:
            # Use the config untouched when the installed Keras accepts it.
            return tf.keras.layers.InputLayer(**config)
        except (TypeError, ValueError):
            if "batch_shape" not in config:
                raise
            # The installed Keras rejected `batch_shape`. Retry with the key
            # renamed to `batch_input_shape` instead of dropping it, so the
            # layer still knows the input shape.
            legacy_config = dict(config)
            legacy_config["batch_input_shape"] = legacy_config.pop("batch_shape")
            return tf.keras.layers.InputLayer(**legacy_config)

    try:
        # Custom objects that paper over config differences between the Keras
        # version that saved the model and the one installed here.
        custom_objects = {
            # Handle InputLayer batch_shape issue
            'InputLayer': _build_input_layer,
            # Handle DTypePolicy issue
            'DTypePolicy': tf.keras.mixed_precision.Policy,
        }

        # compile=False: the model is only used for inference, so the
        # optimizer/loss state does not need to be restored.
        model = tf.keras.models.load_model(
            MODEL_PATH,
            compile=False,
            custom_objects=custom_objects,
        )
    except Exception as e:  # UI boundary: report the failure, do not crash
        st.error(f"Error loading model: {e}")
        st.error("Please ensure you're using TensorFlow 2.x and the model file is not corrupted.")
        return None, None

    try:
        # SECURITY: pickle.load() can execute arbitrary code. This is only
        # acceptable because the file is a trusted asset shipped with the app;
        # never point LABEL_ENCODER_PATH at user-supplied data.
        with open(LABEL_ENCODER_PATH, "rb") as f:
            label_encoder = pickle.load(f)
    except Exception as e:  # UI boundary: report the failure, do not crash
        st.error(f"Error loading label encoder: {e}")
        return None, None

    return model, label_encoder

# Load resources
# Streamlit re-executes this whole script on every widget interaction, so the
# loader is wrapped in st.cache_resource: the model and label encoder are read
# from disk once and the same objects are reused on later reruns.
_cached_load_resources = st.cache_resource(load_resources)
model, label_encoder = _cached_load_resources()
if model is None or label_encoder is None:
    # Do not keep a failed load cached; try again on the next rerun.
    _cached_load_resources.clear()

def preprocess_image(image):
    """Turn a PIL image into a normalized single-image batch for the model.

    Args:
        image: A ``PIL.Image.Image`` in any colour mode.

    Returns:
        numpy.ndarray: ``float32`` array with a leading batch axis of size 1,
        three colour channels last, and values scaled to ``[0, 1]``. The
        spatial size is ``EXPECTED_SIZE``.
    """
    # Convert to RGB up front instead of patching channels on the array.
    # This gives every mode exactly three 8-bit channels: palette ("P"),
    # bilevel ("1"), grayscale+alpha ("LA") and CMYK images as well as the
    # plain grayscale and RGBA cases handled before.
    rgb_image = image.convert("RGB")
    rgb_image = rgb_image.resize(EXPECTED_SIZE)  # Resize to match model input

    image_array = np.array(rgb_image).astype('float32') / 255.0  # Normalize to [0,1]
    return np.expand_dims(image_array, axis=0)  # Add batch dimension

def predict(image):
    """Classify an uploaded image and return its decoded label.

    Args:
        image: The image to classify; it is passed to ``preprocess_image``.

    Returns:
        The label produced by ``label_encoder.inverse_transform`` for the
        highest-scoring class index. If the model or label encoder is missing,
        or inference raises, a human-readable message string is returned
        instead.
    """
    resources_ready = model is not None and label_encoder is not None
    if not resources_ready:
        return "Model or label encoder not loaded properly"

    # Preprocessing happens outside the try block, so its errors propagate.
    batch = preprocess_image(image)
    try:
        scores = model.predict(batch)
        # argmax over the flattened scores: the batch holds a single image.
        best_index = np.argmax(scores)
        decoded = label_encoder.inverse_transform([best_index])
        return decoded[0]
    except Exception as exc:
        return f"Error during prediction: {exc}"

# Streamlit UI
# Page header rendered as raw HTML (hence unsafe_allow_html=True).
st.markdown("""
    <h1 style='text-align: center; color: #4A90E2;'>🖼️ Image Classification App</h1>
    <p style='text-align: center; font-size: 18px;'>Upload an image and let our model classify it for you!</p>
    <hr>
""", unsafe_allow_html=True)

st.sidebar.header("Upload Your Image")
uploaded_file = st.sidebar.file_uploader(
    "Choose an image",
    type=["jpg", "png", "jpeg"],
    help="Supported formats: JPG, PNG, JPEG",
)

image = None
if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file)
        # Image.open() only reads the header. Decode the pixel data now so a
        # truncated or corrupt upload is reported here rather than surfacing
        # later as a traceback from st.image() or the preprocessing step.
        image.load()
    except (OSError, ValueError, Image.DecompressionBombError) as exc:
        image = None
        st.error(f"Could not read the uploaded image: {exc}")

if image is not None:
    st.image(image, caption="📸 Uploaded Image", use_column_width=True)

    if st.button("🔍 Classify Image", use_container_width=True):
        if model is None or label_encoder is None:
            st.error("Model or label encoder failed to load. Please check the files.")
        else:
            # Keep the spinner around the inference call only.
            with st.spinner('Predicting...'):
                prediction = predict(image)
            st.success(f"🎯 Predicted Class: {prediction}")