import os import streamlit as st # Must be the first Streamlit command st.set_page_config(page_title="Image Classifier", layout="wide") # Now import other libraries import tensorflow as tf import numpy as np import pickle from PIL import Image # Constants MODEL_PATH = "image_classification.h5" LABEL_ENCODER_PATH = "le.pkl" EXPECTED_SIZE = (64, 64) # Update this based on your model's input shape def load_resources(): """Load model and label encoder with custom object handling.""" try: # Define custom objects to handle compatibility issues custom_objects = { # Handle InputLayer batch_shape issue 'InputLayer': lambda **kwargs: tf.keras.layers.InputLayer(**{k: v for k, v in kwargs.items() if k != 'batch_shape'}), # Handle DTypePolicy issue 'DTypePolicy': tf.keras.mixed_precision.Policy } # Try loading with custom objects model = tf.keras.models.load_model( MODEL_PATH, compile=False, custom_objects=custom_objects ) except Exception as e: st.error(f"Error loading model: {str(e)}") st.error("Please ensure you're using TensorFlow 2.x and the model file is not corrupted.") return None, None try: with open(LABEL_ENCODER_PATH, "rb") as f: label_encoder = pickle.load(f) except Exception as e: st.error(f"Error loading label encoder: {str(e)}") return None, None return model, label_encoder # Load resources model, label_encoder = load_resources() def preprocess_image(image): """Resize image to match model input shape.""" image = image.resize(EXPECTED_SIZE) # Resize to match model input image_array = np.array(image) # Convert to numpy array # Ensure image has 3 channels (for RGB) if len(image_array.shape) == 2: # Grayscale image image_array = np.stack((image_array,)*3, axis=-1) elif image_array.shape[2] == 4: # RGBA image image_array = image_array[:, :, :3] image_array = image_array.astype('float32') / 255.0 # Normalize to [0,1] image_array = np.expand_dims(image_array, axis=0) # Add batch dimension return image_array def predict(image): """Predict the class of the uploaded image.""" if model is None or label_encoder is None: return "Model or label encoder not loaded properly" image_array = preprocess_image(image) try: preds = model.predict(image_array) class_index = np.argmax(preds) return label_encoder.inverse_transform([class_index])[0] except Exception as e: return f"Error during prediction: {str(e)}" # Streamlit UI st.markdown("""
Upload an image and let our model classify it for you!