Hem345 commited on
Commit
3bd9c33
·
verified ·
1 Parent(s): 6f3b8fe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -19
app.py CHANGED
@@ -40,25 +40,10 @@ st.title("CNN for MNIST Classification")
40
  # Check if model is saved
41
  model_path = "mnist_cnn_model.h5"
42
 
43
- # Custom callback for logging
44
- class StreamlitLogger(keras.callbacks.Callback):
45
- def on_epoch_end(self, epoch, logs=None):
46
- if logs is none:
47
- logs = {}
48
-
49
- st.write(f"Epoch {epoch + 1}:")
50
- st.write(f" Train Loss: {logs.get('loss'):.4f}")
51
- st.write(f" Train Accuracy: {logs.get('accuracy'):.4f}")
52
- st.write(f" Val Loss: {logs.get('val_loss'):.4f}")
53
- st.write(f" Val Accuracy: {logs.get('val_accuracy'):.4f}")
54
-
55
  if st.button("Train Model"):
56
  model = create_model()
57
-
58
- logger = StreamlitLogger()
59
-
60
  with st.spinner("Training..."):
61
- history = model.fit(train_images, train_labels, validation_data=(test_images, test_labels), epochs=10, batch_size=64, callbacks=[logger])
62
 
63
  # Save the model
64
  model.save(model_path)
@@ -71,14 +56,13 @@ if st.button("Train Model"):
71
  ax1.set_title("Training and Validation Loss")
72
  ax1.set_xlabel("Epoch")
73
  ax1.set_ylabel("Loss")
 
74
 
75
  ax2.plot(history.history["accuracy"], label="Train Accuracy")
76
  ax2.plot(history.history["val_accuracy"], label="Val Accuracy")
77
  ax2.set_title("Training and Validation Accuracy")
78
  ax2.set_xlabel("Epoch")
79
  ax2.set_ylabel("Accuracy")
80
-
81
- ax1.legend()
82
  ax2.legend()
83
 
84
  st.pyplot(fig)
@@ -87,19 +71,22 @@ if st.button("Train Model"):
87
  test_preds = np.argmax(model.predict(test_images), axis=1)
88
  true_labels = np.argmax(test_labels, axis=1)
89
 
 
90
  st.session_state['true_labels'] = true_labels
91
 
 
92
  report = classification_report(true_labels, test_preds, digits=4)
93
  st.text("Classification Report:")
94
  st.text(report)
95
 
 
96
  index = st.number_input("Enter an index (0-9999) to test:", min_value=0, max_value=9999, step=1)
97
 
98
  def test_index_prediction(index):
99
  image = test_images[index].reshape(28, 28)
100
  st.image(image, caption=f"True Label: {st.session_state['true_labels'][index]}", use_column_width=True)
101
 
102
- # Reload the model
103
  if not os.path.exists(model_path):
104
  st.error("Train the model first.")
105
  return
 
40
  # Check if model is saved
41
  model_path = "mnist_cnn_model.h5"
42
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  if st.button("Train Model"):
44
  model = create_model()
 
 
 
45
  with st.spinner("Training..."):
46
+ history = model.fit(train_images, train_labels, validation_data=(test_images, test_labels), epochs=10, batch_size=64)
47
 
48
  # Save the model
49
  model.save(model_path)
 
56
  ax1.set_title("Training and Validation Loss")
57
  ax1.set_xlabel("Epoch")
58
  ax1.set_ylabel("Loss")
59
+ ax1.legend()
60
 
61
  ax2.plot(history.history["accuracy"], label="Train Accuracy")
62
  ax2.plot(history.history["val_accuracy"], label="Val Accuracy")
63
  ax2.set_title("Training and Validation Accuracy")
64
  ax2.set_xlabel("Epoch")
65
  ax2.set_ylabel("Accuracy")
 
 
66
  ax2.legend()
67
 
68
  st.pyplot(fig)
 
71
  test_preds = np.argmax(model.predict(test_images), axis=1)
72
  true_labels = np.argmax(test_labels, axis=1)
73
 
74
+ # Store the test labels globally for later use
75
  st.session_state['true_labels'] = true_labels
76
 
77
+ # Classification report
78
  report = classification_report(true_labels, test_preds, digits=4)
79
  st.text("Classification Report:")
80
  st.text(report)
81
 
82
+ # Testing with a specific index
83
  index = st.number_input("Enter an index (0-9999) to test:", min_value=0, max_value=9999, step=1)
84
 
85
  def test_index_prediction(index):
86
  image = test_images[index].reshape(28, 28)
87
  st.image(image, caption=f"True Label: {st.session_state['true_labels'][index]}", use_column_width=True)
88
 
89
+ # Reload the model if needed
90
  if not os.path.exists(model_path):
91
  st.error("Train the model first.")
92
  return