Update app.py
Browse files
app.py
CHANGED
@@ -190,6 +190,45 @@ class readDataset:
|
|
190 |
np.array(train_sar_aug), np.array(train_optic_aug), np.array(train_masks_aug),
|
191 |
np.array(test_sar), np.array(test_optic), np.array(test_masks)
|
192 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
193 |
# Streamlit App Title
|
194 |
st.title("Satellite Mining Segmentation: SAR + Optic Image Inference")
|
195 |
|
@@ -224,7 +263,8 @@ if st.button("Run Inference"):
|
|
224 |
optic_images = dataset.normalizeImages(optic_images, 'i')
|
225 |
|
226 |
# Load model
|
227 |
-
model = load_model(model_path
|
|
|
228 |
|
229 |
# Predict
|
230 |
pred_masks = model.predict([optic_images, sar_images], verbose=0)
|
|
|
190 |
np.array(train_sar_aug), np.array(train_optic_aug), np.array(train_masks_aug),
|
191 |
np.array(test_sar), np.array(test_optic), np.array(test_masks)
|
192 |
)
|
193 |
+
|
194 |
+
@tf.keras.saving.register_keras_serializable()
|
195 |
+
def dice_score(y_true, y_pred, threshold=0.5, smooth=1.0):
|
196 |
+
#determine binary or multiclass segmentation
|
197 |
+
is_multiclass = y_true.shape[-1] > 1
|
198 |
+
|
199 |
+
if not is_multiclass:
|
200 |
+
# Binary segmentation
|
201 |
+
y_true_flat = tf.cast(tf.reshape(y_true, [-1]), dtype=tf.float32)
|
202 |
+
y_pred_flat = tf.cast(tf.reshape(y_pred >= threshold, [-1]), dtype=tf.float32)
|
203 |
+
intersection = tf.reduce_sum(y_true_flat * y_pred_flat)
|
204 |
+
score = (2. * intersection + smooth) / (tf.reduce_sum(y_true_flat) + tf.reduce_sum(y_pred_flat) + smooth)
|
205 |
+
return score
|
206 |
+
else:
|
207 |
+
# Multiclass segmentation
|
208 |
+
num_classes = y_true.shape[-1]
|
209 |
+
score_per_class = []
|
210 |
+
|
211 |
+
for i in range(num_classes):
|
212 |
+
y_true_flat = tf.cast(tf.reshape(y_true, [-1]), dtype=tf.float32)
|
213 |
+
y_pred_flat = tf.cast(tf.reshape(y_pred >= threshold, [-1]), dtype=tf.float32)
|
214 |
+
intersection = tf.reduce_sum(y_true_flat * y_pred_flat)
|
215 |
+
score = (2. * intersection + smooth) / (tf.reduce_sum(y_true_flat) + tf.reduce_sum(y_pred_flat) + smooth)
|
216 |
+
score_per_class.append(score)
|
217 |
+
|
218 |
+
return tf.reduce_mean(score_per_class)
|
219 |
+
|
220 |
+
@tf.keras.saving.register_keras_serializable()
|
221 |
+
def dice_loss(y_true, y_pred):
|
222 |
+
dice = dice_score(y_true, y_pred)
|
223 |
+
loss = 1. - dice
|
224 |
+
return tf.cast(loss, dtype=tf.float32)
|
225 |
+
|
226 |
+
@tf.keras.saving.register_keras_serializable()
|
227 |
+
def cce_dice_loss(y_true, y_pred):
|
228 |
+
cce = tf.keras.losses.CategoricalCrossentropy()(y_true, y_pred)
|
229 |
+
dice = dice_loss(y_true, y_pred)
|
230 |
+
return tf.cast(cce, dtype=tf.float32) + dice
|
231 |
+
|
232 |
# Streamlit App Title
|
233 |
st.title("Satellite Mining Segmentation: SAR + Optic Image Inference")
|
234 |
|
|
|
263 |
optic_images = dataset.normalizeImages(optic_images, 'i')
|
264 |
|
265 |
# Load model
|
266 |
+
model = tf.keras.models.load_model(model_path,
|
267 |
+
custom_objects={"cce_dice_loss": cce_dice_loss, "dice_score": dice_score})
|
268 |
|
269 |
# Predict
|
270 |
pred_masks = model.predict([optic_images, sar_images], verbose=0)
|