import os

import streamlit as st

# Configure the page first: this is the first Streamlit call in the script and
# it is deliberately issued ahead of the remaining imports.
st.set_page_config(page_title="Image Classifier", layout="wide")

import tensorflow as tf
import numpy as np
import pickle
from PIL import Image

# Artifact locations; both are relative paths, so they resolve against the
# directory the app is launched from.
MODEL_PATH = "image_classification.h5"
LABEL_ENCODER_PATH = "le.pkl"

# Size passed to PIL's Image.resize(), which takes (width, height).
EXPECTED_SIZE = (64, 64)

@st.cache_resource(show_spinner=False)
def _load_model():
    """Load the Keras model from MODEL_PATH, cached across Streamlit reruns.

    Any exception raised while loading propagates to the caller; a failed
    call returns nothing, so there is no result for Streamlit to cache.
    """
    custom_objects = {
        # Rebuild InputLayer without the ``batch_shape`` config entry.
        # NOTE(review): presumably the model was saved by a newer Keras whose
        # config this TensorFlow build rejects. Dropping the key also discards
        # the declared input shape -- if loading fails on an undefined input
        # rank, consider mapping it to ``batch_input_shape`` instead.
        "InputLayer": lambda **kwargs: tf.keras.layers.InputLayer(
            **{k: v for k, v in kwargs.items() if k != "batch_shape"}
        ),
        "DTypePolicy": tf.keras.mixed_precision.Policy,
    }
    return tf.keras.models.load_model(
        MODEL_PATH,
        compile=False,  # inference only; no optimizer/loss state is needed
        custom_objects=custom_objects,
    )


@st.cache_resource(show_spinner=False)
def _load_label_encoder():
    """Unpickle the label encoder from LABEL_ENCODER_PATH, cached across reruns."""
    # NOTE: pickle.load() can execute arbitrary code contained in the file;
    # only deploy a label-encoder pickle from a trusted source.
    with open(LABEL_ENCODER_PATH, "rb") as f:
        return pickle.load(f)


def load_resources():
    """Load model and label encoder with custom object handling.

    Streamlit re-executes the whole script on every widget interaction, so the
    actual loading is delegated to ``st.cache_resource`` helpers and happens
    once per process instead of once per rerun.

    Returns:
        A ``(model, label_encoder)`` tuple, or ``(None, None)`` if either
        artifact could not be loaded. Failures are reported in the UI with
        ``st.error`` rather than raised.
    """
    try:
        model = _load_model()
    except Exception as e:  # UI boundary: report any load failure, don't crash
        st.error(f"Error loading model: {e}")
        st.error("Please ensure you're using TensorFlow 2.x and the model file is not corrupted.")
        return None, None

    try:
        label_encoder = _load_label_encoder()
    except Exception as e:  # UI boundary: report any load failure, don't crash
        st.error(f"Error loading label encoder: {e}")
        return None, None

    return model, label_encoder


# Runs at module level, i.e. on every execution of the script. Both names are
# None when loading failed; predict() and the UI below check for that.
model, label_encoder = load_resources()


def preprocess_image(image, size=None):
    """Turn a PIL image into a normalized single-image batch for the model.

    The image is resized, converted to 3-channel RGB, scaled to [0, 1] and
    given a leading batch axis. Converting through PIL (rather than patching
    the array shape by hand) also handles modes such as palette ("P"),
    grayscale with alpha ("LA") and CMYK.

    Args:
        image: A ``PIL.Image.Image``.
        size: Optional ``(width, height)`` to resize to. Defaults to
            ``EXPECTED_SIZE``.

    Returns:
        A ``float32`` array of shape ``(1, height, width, 3)``.
    """
    target_size = EXPECTED_SIZE if size is None else size

    # Resize first, then convert: this keeps the results for grayscale, RGB
    # and RGBA inputs the same as slicing/stacking the resized array.
    rgb_image = image.resize(target_size).convert("RGB")

    pixels = np.asarray(rgb_image, dtype=np.float32) / 255.0
    return np.expand_dims(pixels, axis=0)


def predict(image):
    """Predict the class of the uploaded image.

    Args:
        image: A ``PIL.Image.Image`` to classify.

    Returns:
        The decoded class label on success. Failures are not raised: a
        human-readable message string is returned instead, either
        "Model or label encoder not loaded properly" or a string starting
        with "Error during prediction: ".
    """
    if model is None or label_encoder is None:
        return "Model or label encoder not loaded properly"

    try:
        # Preprocessing sits inside the try so that an image PIL cannot
        # resize/convert is reported like any other prediction failure.
        batch = preprocess_image(image)
        predictions = model.predict(batch, verbose=0)  # no progress-bar output
        # The batch holds exactly one image, so score only its row.
        class_index = np.argmax(predictions[0])
        return label_encoder.inverse_transform([class_index])[0]
    except Exception as e:  # UI boundary: surface the failure as text
        return f"Error during prediction: {e}"


# Page header. The HTML lines stay at column 0 so Markdown does not treat
# them as an indented code block.
st.markdown(
    """
<h1 style='text-align: center; color: #4A90E2;'>🖼️ Image Classification App</h1>
<p style='text-align: center; font-size: 18px;'>Upload an image and let our model classify it for you!</p>
<hr>
""",
    unsafe_allow_html=True,
)

st.sidebar.header("Upload Your Image")
uploaded_file = st.sidebar.file_uploader(
    "Choose an image",
    type=["jpg", "png", "jpeg"],
    help="Supported formats: JPG, PNG, JPEG",
)

if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file)
        # Image.open() is lazy; force decoding so a truncated or corrupt
        # upload fails here instead of deep inside rendering or prediction.
        image.load()
    except OSError as e:
        st.error(f"Could not read the uploaded image: {e}")
    else:
        # NOTE(review): newer Streamlit releases deprecate use_column_width
        # on st.image -- confirm against the Streamlit version in use.
        st.image(image, caption="📸 Uploaded Image", use_column_width=True)

        if st.button("🔍 Classify Image", use_container_width=True):
            if model is None or label_encoder is None:
                st.error("Model or label encoder failed to load. Please check the files.")
            else:
                with st.spinner('Predicting...'):
                    prediction = predict(image)

                # predict() reports failures by returning a message with this
                # prefix instead of raising; don't show those as a success.
                if isinstance(prediction, str) and prediction.startswith("Error during prediction:"):
                    st.error(prediction)
                else:
                    st.success(f"🎯 Predicted Class: {prediction}")