import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
import rasterio
import cv2
import tensorflow as tf
import tempfile
from rasterio.features import shapes
import geopandas as gpd
import zipfile

# Configuration
HEIGHT = WIDTH = 256
SAR_SHAPE = (HEIGHT, WIDTH, 1)
OPTIC_SHAPE = (HEIGHT, WIDTH, 3)
MASK_SHAPE = (HEIGHT, WIDTH, 3)  # One-hot encoded masks, one channel per class

# Class colors: non-mining (green), mining (red), beach (black)
CLASS_COLORS = [
    [115/255, 178/255, 115/255],
    [1, 0, 0],
    [0, 0, 0]
]

@tf.keras.saving.register_keras_serializable()
def dice_score(y_true, y_pred, threshold=0.5, smooth=1.0):
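    """Dice coefficient: binary Dice, or the mean of per-class Dice for one-hot multiclass masks."""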
    #determine binary or multiclass segmentation
    is_multiclass = y_true.shape[-1] > 1
    
    if not is_multiclass:
        # Binary segmentation
        y_true_flat = tf.cast(tf.reshape(y_true, [-1]), dtype=tf.float32)
        y_pred_flat = tf.cast(tf.reshape(y_pred >= threshold, [-1]), dtype=tf.float32)
        intersection = tf.reduce_sum(y_true_flat * y_pred_flat)
        score = (2. * intersection + smooth) / (tf.reduce_sum(y_true_flat) + tf.reduce_sum(y_pred_flat) + smooth) 
        return score
    else:
        # Multiclass segmentation
        num_classes = y_true.shape[-1]
        score_per_class = []
        
        for i in range(num_classes):
            # Compute Dice per class by slicing out the i-th one-hot channel
            y_true_flat = tf.cast(tf.reshape(y_true[..., i], [-1]), dtype=tf.float32)
            y_pred_flat = tf.cast(tf.reshape(y_pred[..., i] >= threshold, [-1]), dtype=tf.float32)
            intersection = tf.reduce_sum(y_true_flat * y_pred_flat)
            score = (2. * intersection + smooth) / (tf.reduce_sum(y_true_flat) + tf.reduce_sum(y_pred_flat) + smooth)
            score_per_class.append(score)
        
        return tf.reduce_mean(score_per_class)
    
@tf.keras.saving.register_keras_serializable() 
def dice_loss(y_true, y_pred):
    dice = dice_score(y_true, y_pred)
    loss = 1. - dice
    return tf.cast(loss, dtype=tf.float32)

@tf.keras.saving.register_keras_serializable()
def cce_dice_loss(y_true, y_pred):
    cce = tf.keras.losses.CategoricalCrossentropy()(y_true, y_pred)
    dice = dice_loss(y_true, y_pred)
    return tf.cast(cce, dtype=tf.float32) + dice

def convertColorToLabel(img):
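    """Convert an RGB mask image into a one-hot label map using the fixed 3-class palette."""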
    color_to_label = {
        (115, 178, 115): 0,  # non_mining_land (green)
        (255, 0, 0): 1,      # illegal_mining_land (red)
        (0, 0, 0): 2,        # beach (black)
    }

    # Create empty label array
    label_img = np.zeros((img.shape[0], img.shape[1]), dtype=np.uint8)

    # Map each RGB color to its corresponding label
    for color, label in color_to_label.items():
        mask = np.all(img == color, axis=2)
        label_img[mask] = label

    # One-hot encode the label image
    num_classes = len(color_to_label)
    one_hot = np.zeros((img.shape[0], img.shape[1], num_classes), dtype=np.uint8)
    for c in range(num_classes):
        one_hot[:, :, c] = (label_img == c).astype(np.uint8)

    return one_hot

def readImages(data, typeData, width, height):
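    """Read and resize images; typeData: 's' = SAR GeoTIFF, 'o' = optic RGB, 'm' = RGB class mask."""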
    images = []
    for img in data:
        if typeData == 's':  # SAR image
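            # Assumes single-band SAR; multi-band rasters would need per-band handling here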
            with rasterio.open(str(img)) as src:
                sar_bands = [src.read(i) for i in range(1, src.count + 1)]
                sar_image = np.stack(sar_bands, axis=-1)

            # Contrast stretching
            p2, p98 = np.percentile(sar_image, (2, 98))
            sar_image = np.clip(sar_image, p2, p98)
            sar_image = ((sar_image - p2) / (p98 - p2) * 255).astype(np.uint8)

            # Resize
            sar_image = cv2.resize(sar_image, (width, height), interpolation=cv2.INTER_AREA)
            images.append(np.expand_dims(sar_image, axis=-1))

        elif typeData == 'm':  # Mask image
            img = cv2.imread(str(img), cv2.IMREAD_COLOR)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = cv2.resize(img, (width, height), interpolation=cv2.INTER_NEAREST)
            images.append(convertColorToLabel(img))

        elif typeData == 'o':  # Optic image
            img = cv2.imread(str(img), cv2.IMREAD_COLOR)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
            images.append(img)

    print(f"(INFO..) Read {len(images)} '{typeData}' image(s)")
    return np.array(images)


def normalizeImages(images, typeData):
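    """Scale SAR ('s') and optic ('o') images to [0, 1]; other types pass through unchanged."""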
    normalized_images = []
    for img in images:
        img = img.astype(np.uint8)
        if typeData in ['s', 'o']:
            img = img / 255.
        normalized_images.append(img)

    print("(INFO..) Normalization Image Done")
    return np.array(normalized_images)



def save_uploaded_file(uploaded_file, suffix=".tif"):
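    """Write a Streamlit UploadedFile to a temporary file on disk and return its path."""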
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        tmp.write(uploaded_file.read())
        return tmp.name  

def get_transform_from_tif(tif_path):
    """
    Extracts the affine transform and CRS from a GeoTIFF image.

    Args:
        tif_path (str): Path to the GeoTIFF file.

    Returns:
        transform (Affine): Affine transformation for pixel to coordinate mapping.
        crs (CRS): Coordinate Reference System of the image.
    """
    with rasterio.open(tif_path) as src:
        transform = src.transform
        crs = src.crs
    return transform, crs
    
def mask_to_polygons(mask, transform, crs='EPSG:4326', mining_class_id=1):
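    """Vectorize the pixels of mining_class_id in a label mask into a GeoDataFrame of polygons."""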
    # Create a binary mask for mining areas
    mining_mask = (mask == mining_class_id).astype(np.uint8)

    # Extract shapes (polygons) from the mask
    results = (
        {'properties': {'class': 'mining_land'}, 'geometry': s}
        for s, v in shapes(mining_mask, mask=None, transform=transform)
        if v == 1  # Only keep areas marked as mining
    )

    # Create a GeoDataFrame from the polygon results
    geoms = list(results)
    if not geoms:
        return gpd.GeoDataFrame(columns=['geometry'], geometry='geometry', crs=crs)

    gdf = gpd.GeoDataFrame.from_features(geoms)
    gdf.set_crs(crs, inplace=True)
    return gdf


# Function to calculate illegal mining area (mining outside WIUP)
def calculate_illegal_area(mining_gdf, wiup_gdf):
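    """Return the total area (m²) and geometry of predicted mining lying outside the WIUP boundary."""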
    # Ensure both are in the same CRS
    mining_gdf = mining_gdf.to_crs(wiup_gdf.crs)

    # Perform overlay operation: difference (mining area outside WIUP)
    illegal_area_gdf = gpd.overlay(mining_gdf, wiup_gdf, how='difference')

    # Convert to a metric (projected) CRS so areas come out in m².
    # EPSG:32748 = WGS 84 / UTM zone 48S (SE Asia); adjust to the UTM zone of the study area.
    metric_crs = 'EPSG:32748'
    illegal_area_gdf = illegal_area_gdf.to_crs(metric_crs)

    # Calculate the area in square meters
    illegal_area_gdf['area_m2'] = illegal_area_gdf.geometry.area
    total_illegal_area = illegal_area_gdf['area_m2'].sum()

    return total_illegal_area, illegal_area_gdf

# Streamlit App Title
st.title("Satellite Mining Segmentation: SAR + Optic Image Inference")


sar_file = st.file_uploader("Upload SAR Image", type=["tif", "tiff"])
optic_file = st.file_uploader("Upload Optical Image", type=["tif", "tiff"])
mask_file = st.file_uploader("Upload Mask Image", type=["tif", "tiff"])
wiup_file = st.file_uploader("Upload WIUP Boundary (Shapefile ZIP)", type=["zip"])

num_samples = 1
if st.button("Run Inference"):
    with st.spinner("Loading data and model..."):
        if sar_file is not None and optic_file is not None and mask_file is not None and wiup_file is not None:
            st.success("All files uploaded successfully!")
            
            # Save uploaded files
            sar_path = save_uploaded_file(sar_file, suffix=".tif")
            optic_path = save_uploaded_file(optic_file, suffix=".tif")
            mask_path = save_uploaded_file(mask_file, suffix=".tif")

            wiup_zip_path = save_uploaded_file(wiup_file, ".zip")

            extract_folder = wiup_zip_path.replace(".zip", "")
            with zipfile.ZipFile(wiup_zip_path, "r") as zip_ref:
                zip_ref.extractall(extract_folder)

            # Load WIUP shapefile
            wiup_gdf = gpd.read_file(extract_folder)
            
            # Create image lists
            sarImages = [sar_path]
            opticImages = [optic_path]
            masks = [mask_path]

            # Model path
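            # (the trained .keras file is assumed to be available in the app's working directory)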
            model_path = "Residual_UNET_Bilinear.keras"

            # Read and normalize images
            sar_images = readImages(sarImages, typeData='s', width=WIDTH, height=HEIGHT)
            optic_images = readImages(opticImages, typeData='o', width=WIDTH, height=HEIGHT)
            masks = readImages(masks, typeData='m', width=WIDTH, height=HEIGHT)

            sar_images = normalizeImages(sar_images, 's')
            optic_images = normalizeImages(optic_images, 'o')

            # Load model
            model = tf.keras.models.load_model(
                model_path,
                custom_objects={"cce_dice_loss": cce_dice_loss, "dice_score": dice_score}
            )

            # Predict masks
            pred_masks = model.predict([optic_images, sar_images], verbose=0)
            is_multiclass = pred_masks.shape[-1] > 1

            num_samples = min(num_samples, len(sar_images))

            # Plotting results
            fig, axes = plt.subplots(num_samples, 4, figsize=(21, 6 * num_samples))

            for i in range(num_samples):
                ax = axes[i] if num_samples > 1 else axes

                # Plot SAR image
                ax[0].imshow(sar_images[i].squeeze(), cmap='gray')
                ax[0].set_title(f"SAR Image {i+1}")
                ax[0].axis('off')

                # Plot Optic image
                ax[1].imshow(optic_images[i])
                ax[1].set_title(f"Optic Image {i+1}")
                ax[1].axis('off')

                # Plot Ground Truth Mask
                if is_multiclass:
                    gt_color_mask = np.zeros((*masks[i].shape[:2], 3))
                    for j, color in enumerate(CLASS_COLORS):
                        gt_color_mask += masks[i][:, :, j][:, :, np.newaxis] * np.array(color)
                    ax[2].imshow(gt_color_mask)
                else:
                    ax[2].imshow(masks[i], cmap='gray')
                ax[2].set_title(f"Ground Truth Mask {i+1}")
                ax[2].axis('off')

                # Plot Predicted Mask
                if is_multiclass:
                    pred_color_mask = np.zeros((*pred_masks[i].shape[:2], 3))
                    for j, color in enumerate(CLASS_COLORS):
                        pred_color_mask += pred_masks[i][:, :, j][:, :, np.newaxis] * np.array(color)
                    ax[3].imshow(pred_color_mask)
                else:
                    ax[3].imshow(pred_masks[i], cmap='gray')
                ax[3].set_title(f"Predicted Mask {i+1}")
                ax[3].axis('off')

            st.pyplot(fig)

            # Define color for class 1: illegal mining
            red_color = [255, 0, 0]

            # Convert optic_images to uint8 if needed
            if optic_images.dtype != np.uint8:
                optic_images = (optic_images * 255).astype(np.uint8)

            # Plot overlays
            fig, axes = plt.subplots(num_samples, 4, figsize=(21, 6 * num_samples))

            for i in range(num_samples):
                ax = axes[i] if num_samples > 1 else axes

                # SAR image
                ax[0].imshow(sar_images[i].squeeze(), cmap='gray')
                ax[0].set_title(f"SAR Image {i+1}")
                ax[0].axis('off')

                # Optic image
                ax[1].imshow(optic_images[i])
                ax[1].set_title(f"Optic Image {i+1}")
                ax[1].axis('off')

                # Ground truth overlay
                gt_overlay = optic_images[i].copy()
                if is_multiclass:
                    gt_overlay[masks[i][:, :, 1] == 1] = red_color
                else:
                    gt_overlay[masks[i].squeeze() == 1] = red_color

                ax[2].imshow(optic_images[i])
                ax[2].imshow(gt_overlay, alpha=0.4)
                ax[2].set_title(f"Ground Truth Overlay {i+1}")
                ax[2].axis('off')

                # Predicted mask overlay
                pred_overlay = optic_images[i].copy()
                if is_multiclass:
                    pred_overlay[pred_masks[i][:, :, 1] > 0.5] = red_color
                else:
                    pred_overlay[pred_masks[i].squeeze() > 0.5] = red_color

                ax[3].imshow(optic_images[i])
                ax[3].imshow(pred_overlay, alpha=0.4)
                ax[3].set_title(f"Predicted Overlay {i+1}")
                ax[3].axis('off')

            plt.tight_layout()
            st.pyplot(fig)

            
            # After getting prediction from model
            pred_label_mask = np.argmax(pred_masks[0], axis=-1)  # shape: (H, W)
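            # NOTE (assumption): the GeoTIFF transform describes the original raster grid;
            # the 256x256 prediction only lines up with it if the source image is that size,
            # otherwise the transform should be rescaled before polygonizing.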
            st.success(f"Unique classes in predicted mask: {np.unique(pred_label_mask)}")
            
            # Get georeference transform from TIF image
            transform, crs = get_transform_from_tif(optic_path)
            
            # Convert mask to mining polygons
            mining_gdf = mask_to_polygons(pred_label_mask, transform, crs=crs, mining_class_id=1)
            st.success(f"Mining polygons: {len(mining_gdf)}")

            
            # Make sure WIUP and prediction are in same CRS
            wiup_gdf = wiup_gdf.to_crs(mining_gdf.crs)
            st.success(f"WIUP CRS: {wiup_gdf.crs}")
            
            # Find area
            illegal_mining_area, illegal_mining_area_gdf = calculate_illegal_area(mining_gdf, wiup_gdf)
            
            # Display in Streamlit
            st.success(f"Illegal Mining Area: {illegal_mining_area / 1e6:.2f} sq. km")
            
        else:
            st.warning("Please upload all three .tiff images and the WIUP shapefile ZIP to proceed.")