Hem345 commited on
Commit
c89ca4b
·
verified ·
1 Parent(s): 1541b0a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -6
app.py CHANGED
@@ -5,6 +5,7 @@ from tensorflow import keras
5
  from tensorflow.keras import layers
6
  from tensorflow.keras.datasets import mnist
7
  import streamlit as st
 
8
 
9
  # Load the MNIST dataset
10
  (train_images, train_labels), (test_images, test_labels) = mnist.load_data()
@@ -36,11 +37,17 @@ def create_model():
36
  # Streamlit UI
37
  st.title("CNN for MNIST Classification")
38
 
 
 
 
39
  if st.button("Train Model"):
40
  model = create_model()
41
  with st.spinner("Training..."):
42
  history = model.fit(train_images, train_labels, validation_data=(test_images, test_labels), epochs=10, batch_size=64)
43
-
 
 
 
44
  # Plot training loss and accuracy
45
  fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
46
 
@@ -76,13 +83,16 @@ def test_index_prediction(index):
76
  image = test_images[index].reshape(28, 28)
77
  st.image(image, caption=f"True Label: {true_labels[index]}", use_column_width=True)
78
 
 
 
 
 
 
 
 
79
  prediction = model.predict(test_images[index].reshape(1, 28, 28, 1))
80
  predicted_class = np.argmax(prediction)
81
  st.write(f"Predicted Class: {predicted_class}")
82
 
83
  if st.button("Test Index"):
84
- if 'model' in locals() and model:
85
- test_index_prediction(index)
86
- else:
87
- st.error("Train the model first.")
88
-
 
5
  from tensorflow.keras import layers
6
  from tensorflow.keras.datasets import mnist
7
  import streamlit as st
8
+ import os
9
 
10
  # Load the MNIST dataset
11
  (train_images, train_labels), (test_images, test_labels) = mnist.load_data()
 
37
  # Streamlit UI
38
  st.title("CNN for MNIST Classification")
39
 
40
+ # Check if model is saved
41
+ model_path = "mnist_cnn_model.h5"
42
+
43
  if st.button("Train Model"):
44
  model = create_model()
45
  with st.spinner("Training..."):
46
  history = model.fit(train_images, train_labels, validation_data=(test_images, test_labels), epochs=10, batch_size=64)
47
+
48
+ # Save the model
49
+ model.save(model_path)
50
+
51
  # Plot training loss and accuracy
52
  fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
53
 
 
83
  image = test_images[index].reshape(28, 28)
84
  st.image(image, caption=f"True Label: {true_labels[index]}", use_column_width=True)
85
 
86
+ # Reload the model if needed
87
+ if not os.path.exists(model_path):
88
+ st.error("Train the model first.")
89
+ return
90
+
91
+ model = keras.models.load_model(model_path)
92
+
93
  prediction = model.predict(test_images[index].reshape(1, 28, 28, 1))
94
  predicted_class = np.argmax(prediction)
95
  st.write(f"Predicted Class: {predicted_class}")
96
 
97
  if st.button("Test Index"):
98
+ test_index_prediction(index)