Update app.py
Browse files
app.py
CHANGED
@@ -15,9 +15,20 @@ EXPECTED_SIZE = (64, 64) # Update this based on your model's input shape
|
|
15 |
|
16 |
def load_resources():
|
17 |
"""Load model and label encoder."""
|
18 |
-
|
19 |
-
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
return model, label_encoder
|
22 |
|
23 |
# Load resources
|
@@ -27,15 +38,29 @@ def preprocess_image(image):
|
|
27 |
"""Resize image to match model input shape."""
|
28 |
image = image.resize(EXPECTED_SIZE) # Resize to match model input
|
29 |
image_array = np.array(image) # Convert to numpy array
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
image_array = np.expand_dims(image_array, axis=0) # Add batch dimension
|
31 |
return image_array
|
32 |
|
33 |
def predict(image):
|
34 |
"""Predict the class of the uploaded image."""
|
|
|
|
|
|
|
35 |
image_array = preprocess_image(image)
|
36 |
-
|
37 |
-
|
38 |
-
|
|
|
|
|
|
|
39 |
|
40 |
# Streamlit UI
|
41 |
st.set_page_config(page_title="Image Classifier", layout="wide")
|
@@ -53,5 +78,8 @@ if uploaded_file:
|
|
53 |
st.image(image, caption="πΈ Uploaded Image", use_column_width=True)
|
54 |
|
55 |
if st.button("π Classify Image", use_container_width=True):
|
56 |
-
|
57 |
-
|
|
|
|
|
|
|
|
15 |
|
16 |
def load_resources():
|
17 |
"""Load model and label encoder."""
|
18 |
+
try:
|
19 |
+
# Try loading with the newer Keras API first
|
20 |
+
model = tf.keras.models.load_model(MODEL_PATH, compile=False)
|
21 |
+
except Exception as e:
|
22 |
+
st.error(f"Error loading model: {str(e)}")
|
23 |
+
return None, None
|
24 |
+
|
25 |
+
try:
|
26 |
+
with open(LABEL_ENCODER_PATH, "rb") as f:
|
27 |
+
label_encoder = pickle.load(f)
|
28 |
+
except Exception as e:
|
29 |
+
st.error(f"Error loading label encoder: {str(e)}")
|
30 |
+
return None, None
|
31 |
+
|
32 |
return model, label_encoder
|
33 |
|
34 |
# Load resources
|
|
|
38 |
"""Resize image to match model input shape."""
|
39 |
image = image.resize(EXPECTED_SIZE) # Resize to match model input
|
40 |
image_array = np.array(image) # Convert to numpy array
|
41 |
+
|
42 |
+
# Ensure image has 3 channels (for RGB)
|
43 |
+
if len(image_array.shape) == 2: # Grayscale image
|
44 |
+
image_array = np.stack((image_array,)*3, axis=-1)
|
45 |
+
elif image_array.shape[2] == 4: # RGBA image
|
46 |
+
image_array = image_array[:, :, :3]
|
47 |
+
|
48 |
+
image_array = image_array.astype('float32') / 255.0 # Normalize to [0,1]
|
49 |
image_array = np.expand_dims(image_array, axis=0) # Add batch dimension
|
50 |
return image_array
|
51 |
|
52 |
def predict(image):
|
53 |
"""Predict the class of the uploaded image."""
|
54 |
+
if model is None or label_encoder is None:
|
55 |
+
return "Model or label encoder not loaded properly"
|
56 |
+
|
57 |
image_array = preprocess_image(image)
|
58 |
+
try:
|
59 |
+
preds = model.predict(image_array)
|
60 |
+
class_index = np.argmax(preds)
|
61 |
+
return label_encoder.inverse_transform([class_index])[0]
|
62 |
+
except Exception as e:
|
63 |
+
return f"Error during prediction: {str(e)}"
|
64 |
|
65 |
# Streamlit UI
|
66 |
st.set_page_config(page_title="Image Classifier", layout="wide")
|
|
|
78 |
st.image(image, caption="πΈ Uploaded Image", use_column_width=True)
|
79 |
|
80 |
if st.button("π Classify Image", use_container_width=True):
|
81 |
+
if model is None or label_encoder is None:
|
82 |
+
st.error("Model or label encoder failed to load. Please check the files.")
|
83 |
+
else:
|
84 |
+
prediction = predict(image)
|
85 |
+
st.success(f"π― Predicted Class: {prediction}")
|