Alexvatti commited on
Commit
c652f62
·
verified ·
1 Parent(s): 33a73d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -1
app.py CHANGED
@@ -190,6 +190,45 @@ class readDataset:
190
  np.array(train_sar_aug), np.array(train_optic_aug), np.array(train_masks_aug),
191
  np.array(test_sar), np.array(test_optic), np.array(test_masks)
192
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  # Streamlit App Title
194
  st.title("Satellite Mining Segmentation: SAR + Optic Image Inference")
195
 
@@ -224,7 +263,8 @@ if st.button("Run Inference"):
224
  optic_images = dataset.normalizeImages(optic_images, 'i')
225
 
226
  # Load model
227
- model = load_model(model_path)
 
228
 
229
  # Predict
230
  pred_masks = model.predict([optic_images, sar_images], verbose=0)
 
190
  np.array(train_sar_aug), np.array(train_optic_aug), np.array(train_masks_aug),
191
  np.array(test_sar), np.array(test_optic), np.array(test_masks)
192
  )
193
+
194
+ @tf.keras.saving.register_keras_serializable()
195
+ def dice_score(y_true, y_pred, threshold=0.5, smooth=1.0):
196
+ #determine binary or multiclass segmentation
197
+ is_multiclass = y_true.shape[-1] > 1
198
+
199
+ if not is_multiclass:
200
+ # Binary segmentation
201
+ y_true_flat = tf.cast(tf.reshape(y_true, [-1]), dtype=tf.float32)
202
+ y_pred_flat = tf.cast(tf.reshape(y_pred >= threshold, [-1]), dtype=tf.float32)
203
+ intersection = tf.reduce_sum(y_true_flat * y_pred_flat)
204
+ score = (2. * intersection + smooth) / (tf.reduce_sum(y_true_flat) + tf.reduce_sum(y_pred_flat) + smooth)
205
+ return score
206
+ else:
207
+ # Multiclass segmentation
208
+ num_classes = y_true.shape[-1]
209
+ score_per_class = []
210
+
211
+ for i in range(num_classes):
212
+ y_true_flat = tf.cast(tf.reshape(y_true, [-1]), dtype=tf.float32)
213
+ y_pred_flat = tf.cast(tf.reshape(y_pred >= threshold, [-1]), dtype=tf.float32)
214
+ intersection = tf.reduce_sum(y_true_flat * y_pred_flat)
215
+ score = (2. * intersection + smooth) / (tf.reduce_sum(y_true_flat) + tf.reduce_sum(y_pred_flat) + smooth)
216
+ score_per_class.append(score)
217
+
218
+ return tf.reduce_mean(score_per_class)
219
+
220
+ @tf.keras.saving.register_keras_serializable()
221
+ def dice_loss(y_true, y_pred):
222
+ dice = dice_score(y_true, y_pred)
223
+ loss = 1. - dice
224
+ return tf.cast(loss, dtype=tf.float32)
225
+
226
+ @tf.keras.saving.register_keras_serializable()
227
+ def cce_dice_loss(y_true, y_pred):
228
+ cce = tf.keras.losses.CategoricalCrossentropy()(y_true, y_pred)
229
+ dice = dice_loss(y_true, y_pred)
230
+ return tf.cast(cce, dtype=tf.float32) + dice
231
+
232
  # Streamlit App Title
233
  st.title("Satellite Mining Segmentation: SAR + Optic Image Inference")
234
 
 
263
  optic_images = dataset.normalizeImages(optic_images, 'i')
264
 
265
  # Load model
266
+ model = tf.keras.models.load_model(model_path,
267
+ custom_objects={"cce_dice_loss": cce_dice_loss, "dice_score": dice_score})
268
 
269
  # Predict
270
  pred_masks = model.predict([optic_images, sar_images], verbose=0)