File size: 3,659 Bytes
9f6b670 5ffb8a0 9f6b670 87a1e68 9f6b670 87a1e68 9f6b670 bef3075 b43dafd bef3075 5ffb8a0 bef3075 5ffb8a0 bef3075 5ffb8a0 b43dafd bef3075 b43dafd 9f6b670 ac157d0 9f6b670 b43dafd 9f6b670 87a1e68 9f6b670 b43dafd 9f6b670 b43dafd 9f6b670 87a1e68 9f6b670 87a1e68 9f6b670 b43dafd bef3075 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
import os
import streamlit as st
# Must be the first Streamlit command
st.set_page_config(page_title="Image Classifier", layout="wide")
# Now import other libraries
import tensorflow as tf
import numpy as np
import pickle
from PIL import Image
# Constants
# Relative paths below are opened relative to the current working directory.
MODEL_PATH: str = "image_classification.h5"  # Keras model file loaded by tf.keras.models.load_model
LABEL_ENCODER_PATH: str = "le.pkl"  # pickled encoder whose inverse_transform decodes class indices
# (width, height) handed to PIL's Image.resize; update to match the model's input shape.
EXPECTED_SIZE: tuple = (64, 64)
def _input_layer_compat(**kwargs):
    """Build a tf.keras ``InputLayer`` from a config that may use the ``batch_shape`` key.

    Newer Keras serializes the input shape as ``batch_shape`` while older
    tf.keras only understands ``batch_input_shape``. The key is renamed (not
    dropped) so the rebuilt layer keeps its declared input shape.
    """
    if "batch_shape" in kwargs:
        batch_shape = kwargs.pop("batch_shape")
        kwargs.setdefault("batch_input_shape", batch_shape)
    return tf.keras.layers.InputLayer(**kwargs)


@st.cache_resource(show_spinner="Loading model...")
def _load_model_cached():
    """Load the Keras model at MODEL_PATH once and reuse it across reruns.

    Exceptions propagate; Streamlit does not cache a call that raises, so a
    repaired model file is picked up on the next rerun without a restart.
    """
    # Custom objects smooth over serialization differences between Keras versions.
    custom_objects = {
        "InputLayer": _input_layer_compat,
        "DTypePolicy": tf.keras.mixed_precision.Policy,
    }
    return tf.keras.models.load_model(
        MODEL_PATH,
        compile=False,  # inference only; no optimizer/loss needed
        custom_objects=custom_objects,
    )


@st.cache_resource(show_spinner="Loading label encoder...")
def _load_label_encoder_cached():
    """Unpickle the label encoder at LABEL_ENCODER_PATH once and reuse it across reruns."""
    # NOTE(security): pickle.load can execute arbitrary code; only ship an
    # encoder file you produced yourself.
    with open(LABEL_ENCODER_PATH, "rb") as f:
        return pickle.load(f)


def load_resources():
    """Load the model and label encoder, reporting any failure in the UI.

    Streamlit re-executes the script on every interaction; the heavy loading is
    delegated to ``st.cache_resource`` helpers so it happens only once.

    Returns:
        ``(model, label_encoder)`` on success, or ``(None, None)`` if either
        resource fails to load (an ``st.error`` message explains why).
    """
    try:
        model = _load_model_cached()
    except Exception as e:  # UI boundary: surface any load failure to the user
        st.error(f"Error loading model: {str(e)}")
        st.error("Please ensure you're using TensorFlow 2.x and the model file is not corrupted.")
        return None, None
    try:
        label_encoder = _load_label_encoder_cached()
    except Exception as e:  # UI boundary: surface any load failure to the user
        st.error(f"Error loading label encoder: {str(e)}")
        return None, None
    return model, label_encoder
# Fetch the model and label encoder for this script run. Both names are None
# when loading failed; load_resources has already shown the error via st.error.
(model, label_encoder) = load_resources()
def preprocess_image(image):
    """Turn a PIL image into a normalized, single-image batch for the model.

    The image is converted to 3-channel RGB, resized to ``EXPECTED_SIZE``
    (PIL's ``(width, height)`` order) and scaled from ``[0, 255]`` to ``[0, 1]``.

    Args:
        image: A ``PIL.Image.Image`` in any mode.

    Returns:
        A ``float32`` array of shape ``(1, height, width, 3)``.
    """
    # convert("RGB") covers every PIL mode: grayscale is replicated, alpha is
    # dropped, and palette ("P"), "LA", "CMYK" and bilevel ("1") images map to
    # real RGB colours instead of raw indices, 2 channels, or wrong channels.
    rgb_image = image.convert("RGB").resize(EXPECTED_SIZE)
    image_array = np.asarray(rgb_image, dtype=np.float32) / 255.0
    return np.expand_dims(image_array, axis=0)  # add batch dimension
def predict(image):
    """Predict the class label of an uploaded PIL image.

    Returns:
        The label decoded by ``label_encoder.inverse_transform``, or a
        human-readable error string if the resources are not loaded or if
        preprocessing/inference fails.
    """
    if model is None or label_encoder is None:
        return "Model or label encoder not loaded properly"
    try:
        # Preprocessing lives inside the try so an unusual image yields the
        # error string instead of an uncaught exception on the page.
        image_array = preprocess_image(image)
        # verbose=0 keeps Keras' per-call progress bar out of the server log.
        scores = np.ravel(model.predict(image_array, verbose=0))
        if scores.size == 1:
            # Single-unit output: argmax would always be 0. Assumes a sigmoid
            # probability for the positive class — TODO confirm for this model.
            class_index = int(scores[0] > 0.5)
        else:
            class_index = int(np.argmax(scores))
        return label_encoder.inverse_transform([class_index])[0]
    except Exception as e:  # UI boundary: report instead of crashing the app
        return f"Error during prediction: {str(e)}"
# Streamlit UI
st.markdown("""
<h1 style='text-align: center; color: #4A90E2;'>🖼️ Image Classification App</h1>
<p style='text-align: center; font-size: 18px;'>Upload an image and let our model classify it for you!</p>
<hr>
""", unsafe_allow_html=True)
st.sidebar.header("Upload Your Image")
uploaded_file = st.sidebar.file_uploader("Choose an image", type=["jpg", "png", "jpeg"], help="Supported formats: JPG, PNG, JPEG")
if uploaded_file:
    try:
        image = Image.open(uploaded_file)
        # Image.open only parses the header; load() forces a full decode so a
        # corrupt, truncated or mislabelled upload fails here, not later.
        image.load()
    except OSError as e:  # PIL.UnidentifiedImageError is an OSError subclass
        st.error(f"Could not read the uploaded file as an image: {e}")
    else:
        st.image(image, caption="📸 Uploaded Image", use_column_width=True)
        if st.button("🔍 Classify Image", use_container_width=True):
            if model is None or label_encoder is None:
                st.error("Model or label encoder failed to load. Please check the files.")
            else:
                with st.spinner('Predicting...'):
                    prediction = predict(image)
                st.success(f"🎯 Predicted Class: {prediction}")