"""Streamlit app: SAR + optical segmentation of mining areas and illegal-area estimate.

The user uploads a SAR GeoTIFF, an optical GeoTIFF, a colour-coded ground-truth
mask and a zipped WIUP boundary shapefile. The app runs a Keras model on the
resized images, plots predictions/overlays, polygonises the predicted mining
class and reports the mined area that falls outside the WIUP boundary.
"""
import streamlit as st
import os
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
import pathlib
import natsort
import datetime
import shutil
import rasterio
import cv2
import tensorflow as tf
import tempfile
from rasterio import features
from shapely.geometry import shape
from rasterio.features import shapes
import geopandas as gpd
import zipfile

# Configuration
HEIGHT = WIDTH = 256
SAR_SHAPE = (HEIGHT, WIDTH, 1)
OPTIC_SHAPE = (HEIGHT, WIDTH, 3)
# NOTE(review): comment originally said "4 classes", but CLASS_COLORS and
# convertColorToLabel below define only 3; MASK_SHAPE is not used in this
# file -- confirm the intended channel count against the trained model.
MASK_SHAPE = (HEIGHT, WIDTH, 4)  # One-hot encoded masks

# Class colors (RGB in [0, 1]), indexed by class id:
# 0 = non-mining (green), 1 = mining (red), 2 = beach (black)
CLASS_COLORS = [
    [115 / 255, 178 / 255, 115 / 255],
    [1, 0, 0],
    [0, 0, 0],
]


def _dice_coefficient(y_true, y_pred, threshold, smooth):
    """Return the smoothed Dice coefficient of *y_true* vs. ``y_pred >= threshold``."""
    y_true_flat = tf.cast(tf.reshape(y_true, [-1]), dtype=tf.float32)
    y_pred_flat = tf.cast(tf.reshape(y_pred >= threshold, [-1]), dtype=tf.float32)
    intersection = tf.reduce_sum(y_true_flat * y_pred_flat)
    return (2. * intersection + smooth) / (
        tf.reduce_sum(y_true_flat) + tf.reduce_sum(y_pred_flat) + smooth
    )


@tf.keras.saving.register_keras_serializable()
def dice_score(y_true, y_pred, threshold=0.5, smooth=1.0):
    """Dice score between ground truth and thresholded predictions.

    With a single trailing channel the score is computed over all values.
    With several channels (one-hot / per-class probabilities) a Dice score is
    computed independently for each channel and the mean is returned.

    Args:
        y_true: Ground-truth tensor; the last axis is the class axis.
        y_pred: Predicted probabilities, same shape as ``y_true``.
        threshold: Predictions ``>= threshold`` count as positive.
        smooth: Additive smoothing to avoid division by zero.
    """
    is_multiclass = y_true.shape[-1] > 1
    if not is_multiclass:
        return _dice_coefficient(y_true, y_pred, threshold, smooth)

    num_classes = y_true.shape[-1]
    # Slice each class channel; previously the whole tensor was reused for
    # every class, so all "per-class" scores were identical.
    score_per_class = [
        _dice_coefficient(y_true[..., i], y_pred[..., i], threshold, smooth)
        for i in range(num_classes)
    ]
    return tf.reduce_mean(tf.stack(score_per_class))


@tf.keras.saving.register_keras_serializable()
def dice_loss(y_true, y_pred):
    """Return ``1 - dice_score`` as a float32 tensor."""
    dice = dice_score(y_true, y_pred)
    loss = 1. - dice
    return tf.cast(loss, dtype=tf.float32)


@tf.keras.saving.register_keras_serializable()
def cce_dice_loss(y_true, y_pred):
    """Categorical cross-entropy plus Dice loss."""
    cce = tf.keras.losses.CategoricalCrossentropy()(y_true, y_pred)
    dice = dice_loss(y_true, y_pred)
    return tf.cast(cce, dtype=tf.float32) + dice


def convertColorToLabel(img):
    """Convert an RGB colour mask to a one-hot uint8 array of shape (H, W, 3).

    Pixels whose colour matches none of the known classes keep label 0
    (non_mining_land).
    """
    color_to_label = {
        (115, 178, 115): 0,  # non_mining_land (green)
        (255, 0, 0): 1,  # illegal_mining_land (red)
        (0, 0, 0): 2,  # beach (black)
    }
    # Create empty label array
    label_img = np.zeros((img.shape[0], img.shape[1]), dtype=np.uint8)
    # Map each RGB color to its corresponding label
    for color, label in color_to_label.items():
        mask = np.all(img == color, axis=2)
        label_img[mask] = label
    # One-hot encode the label image: row `c` of the identity matrix is class c.
    num_classes = len(color_to_label)
    return np.eye(num_classes, dtype=np.uint8)[label_img]


def _read_rgb(path):
    """Read an image with OpenCV and return it as RGB; raise if it cannot be decoded."""
    bgr = cv2.imread(str(path), cv2.IMREAD_COLOR)
    if bgr is None:
        raise ValueError(f"OpenCV could not read image file: {path}")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


def readImages(data, typeData, width, height):
    """Read and resize a list of image files.

    Args:
        data: Iterable of file paths.
        typeData: ``'s'`` SAR raster (all bands read via rasterio, 2-98
            percentile contrast stretch to uint8, extra channel axis added);
            ``'m'`` RGB colour mask (nearest-neighbour resize, one-hot encoded
            via convertColorToLabel); ``'o'`` optical RGB image.
            Other values are ignored.
        width, height: Output size passed to ``cv2.resize``.

    Returns:
        np.ndarray stacking one entry per input file.

    Raises:
        ValueError: if OpenCV cannot decode a mask/optical file.
    """
    images = []
    for img in data:
        if typeData == 's':  # SAR image
            with rasterio.open(str(img)) as src:
                sar_bands = [src.read(i) for i in range(1, src.count + 1)]
            sar_image = np.stack(sar_bands, axis=-1)
            # Contrast stretching
            p2, p98 = np.percentile(sar_image, (2, 98))
            sar_image = np.clip(sar_image, p2, p98)
            value_range = p98 - p2
            if value_range > 0:
                sar_image = ((sar_image - p2) / value_range * 255).astype(np.uint8)
            else:
                # Constant image: avoid 0/0 -> NaN and map every pixel to 0.
                sar_image = np.zeros_like(sar_image, dtype=np.uint8)
            # Resize
            sar_image = cv2.resize(sar_image, (width, height), interpolation=cv2.INTER_AREA)
            images.append(np.expand_dims(sar_image, axis=-1))
        elif typeData == 'm':  # Mask image
            mask_rgb = _read_rgb(img)
            # Nearest-neighbour keeps exact class colours (no blended pixels).
            mask_rgb = cv2.resize(mask_rgb, (width, height), interpolation=cv2.INTER_NEAREST)
            images.append(convertColorToLabel(mask_rgb))
        elif typeData == 'o':  # Optic image
            optic_rgb = _read_rgb(img)
            optic_rgb = cv2.resize(optic_rgb, (width, height), interpolation=cv2.INTER_AREA)
            images.append(optic_rgb)
    print(f"(INFO..) Read {len(images)} '{typeData}' image(s)")
    return np.array(images)


def normalizeImages(images, typeData):
    """Cast images to uint8 and, for ``'s'``/``'o'`` data, scale them to [0, 1]."""
    normalized_images = []
    for img in images:
        img = img.astype(np.uint8)
        if typeData in ['s', 'o']:
            img = img / 255.
        normalized_images.append(img)
    print("(INFO..) Normalization Image Done")
    return np.array(normalized_images)


def save_uploaded_file(uploaded_file, suffix=".tif"):
    """Write an uploaded file to a persistent temp file and return its path.

    Uses ``getvalue()`` when available so the whole payload is written even if
    the stream was already read on a previous Streamlit rerun.
    """
    if hasattr(uploaded_file, "getvalue"):
        payload = uploaded_file.getvalue()
    else:
        payload = uploaded_file.read()
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        tmp.write(payload)
        return tmp.name


def get_transform_from_tif(tif_path, out_shape=None):
    """
    Extracts the affine transform and CRS from a GeoTIFF image.

    Args:
        tif_path (str): Path to the GeoTIFF file.
        out_shape (tuple, optional): ``(height, width)`` of a resized version of
            the raster. When given, the transform is rescaled so pixel indices
            of the resized array map to the same ground coordinates.

    Returns:
        transform (Affine): Affine transformation for pixel to coordinate mapping.
        crs (CRS): Coordinate Reference System of the image (may be None).
    """
    with rasterio.open(tif_path) as src:
        transform = src.transform
        crs = src.crs
        if out_shape is not None:
            out_height, out_width = out_shape[:2]
            transform = transform * transform.scale(
                src.width / out_width, src.height / out_height
            )
    return transform, crs


def mask_to_polygons(mask, transform, mining_class_id=1, crs=None):
    """Polygonise the pixels of *mask* equal to *mining_class_id*.

    Args:
        mask: 2-D integer label array.
        transform: Affine transform mapping *mask* pixel indices to coordinates.
        mining_class_id: Label value to extract.
        crs: CRS of the coordinates produced by *transform*. Defaults to
            EPSG:4326 when not given (previous behaviour).

    Returns:
        GeoDataFrame with one row per connected mining region and a
        ``class`` column set to ``'mining_land'``.
    """
    if crs is None:
        crs = 'EPSG:4326'
    # Create a binary mask for mining areas
    mining_mask = (mask == mining_class_id).astype(np.uint8)
    # Extract shapes (polygons) from the mask, keeping only mining regions
    geoms = [
        {'properties': {'class': 'mining_land'}, 'geometry': s}
        for s, v in shapes(mining_mask, mask=None, transform=transform)
        if v == 1
    ]
    if not geoms:
        return gpd.GeoDataFrame(columns=['geometry'], geometry='geometry', crs=crs)
    return gpd.GeoDataFrame.from_features(geoms, crs=crs)


def calculate_illegal_area(mining_gdf, wiup_gdf):
    """Compute the mining area lying outside the WIUP boundary.

    Args:
        mining_gdf: Mining polygons (must have a CRS).
        wiup_gdf: Permitted-area (WIUP) polygons (must have a CRS).

    Returns:
        (total_illegal_area_m2, illegal_area_gdf) where the GeoDataFrame is in
        a locally estimated UTM CRS and carries an ``area_m2`` column.
    """
    # Ensure both are in the same CRS
    mining_gdf = mining_gdf.to_crs(wiup_gdf.crs)
    if mining_gdf.empty:
        empty = mining_gdf.copy()
        empty['area_m2'] = 0.0
        return 0.0, empty
    # Perform overlay operation: difference (mining area outside WIUP)
    illegal_area_gdf = gpd.overlay(mining_gdf, wiup_gdf, how='difference')
    if illegal_area_gdf.empty:
        illegal_area_gdf['area_m2'] = 0.0
        return 0.0, illegal_area_gdf
    # Areas must be computed in a projected (metric) CRS. EPSG:4326 is
    # geographic and would yield square degrees, so pick the local UTM zone.
    metric_crs = illegal_area_gdf.estimate_utm_crs()
    illegal_area_gdf = illegal_area_gdf.to_crs(metric_crs)
    # Calculate the area in square meters
    illegal_area_gdf['area_m2'] = illegal_area_gdf.geometry.area
    total_illegal_area = float(illegal_area_gdf['area_m2'].sum())
    return total_illegal_area, illegal_area_gdf


def _colorize_one_hot(mask, colors=None):
    """Blend the class channels of *mask* into an RGB float image.

    Works for hard one-hot masks and soft probability maps; only channels that
    have a colour (and colours that have a channel) are used.
    """
    colors = CLASS_COLORS if colors is None else colors
    rgb = np.zeros((*mask.shape[:2], 3))
    for j, color in enumerate(colors[: mask.shape[-1]]):
        rgb += mask[:, :, j][:, :, np.newaxis] * np.array(color)
    return rgb


def _predicted_mining_mask(pred, is_multiclass):
    """Boolean mask of pixels predicted as mining (class 1, or the single channel) > 0.5."""
    if is_multiclass:
        return pred[:, :, 1] > 0.5
    return pred[..., 0] > 0.5


def _plot_predictions(sar_images, optic_images, masks, pred_masks, is_multiclass, num_samples):
    """Return a figure with SAR, optical, ground-truth and predicted masks per sample."""
    fig, axes = plt.subplots(num_samples, 4, figsize=(21, 6 * num_samples), squeeze=False)
    for i in range(num_samples):
        ax = axes[i]
        ax[0].imshow(sar_images[i].squeeze(), cmap='gray')
        ax[0].set_title(f"SAR Image {i+1}")
        ax[1].imshow(optic_images[i])
        ax[1].set_title(f"Optic Image {i+1}")
        # Ground-truth masks are always one-hot (see convertColorToLabel).
        ax[2].imshow(_colorize_one_hot(masks[i]))
        ax[2].set_title(f"Ground Truth Mask {i+1}")
        if is_multiclass:
            ax[3].imshow(_colorize_one_hot(pred_masks[i]))
        else:
            # imshow rejects (H, W, 1); drop the channel axis.
            ax[3].imshow(pred_masks[i][..., 0], cmap='gray')
        ax[3].set_title(f"Predicted Mask {i+1}")
        for a in ax:
            a.axis('off')
    return fig


def _plot_overlays(sar_images, optic_images, masks, pred_masks, is_multiclass, num_samples):
    """Return a figure with ground-truth and predicted mining overlaid in red on the optical image."""
    # Define color for class 1: illegal mining
    red_color = [255, 0, 0]
    if optic_images.dtype != np.uint8:
        optic_images = (optic_images * 255).astype(np.uint8)
    fig, axes = plt.subplots(num_samples, 4, figsize=(21, 6 * num_samples), squeeze=False)
    for i in range(num_samples):
        ax = axes[i]
        ax[0].imshow(sar_images[i].squeeze(), cmap='gray')
        ax[0].set_title(f"SAR Image {i+1}")
        ax[1].imshow(optic_images[i])
        ax[1].set_title(f"Optic Image {i+1}")

        gt_overlay = optic_images[i].copy()
        gt_overlay[masks[i][:, :, 1] == 1] = red_color
        ax[2].imshow(optic_images[i])
        ax[2].imshow(gt_overlay, alpha=0.4)
        ax[2].set_title(f"Ground Truth Overlay {i+1}")

        pred_overlay = optic_images[i].copy()
        pred_overlay[_predicted_mining_mask(pred_masks[i], is_multiclass)] = red_color
        ax[3].imshow(optic_images[i])
        ax[3].imshow(pred_overlay, alpha=0.4)
        ax[3].set_title(f"Predicted Overlay {i+1}")
        for a in ax:
            a.axis('off')
    fig.tight_layout()
    return fig


# Streamlit App Title
st.title("Satellite Mining Segmentation: SAR + Optic Image Inference")

# Accept both common GeoTIFF extensions.
TIFF_TYPES = ["tif", "tiff"]
sar_file = st.file_uploader("Upload SAR Image", type=TIFF_TYPES)
optic_file = st.file_uploader("Upload Optical Image", type=TIFF_TYPES)
mask_file = st.file_uploader("Upload Mask Image", type=TIFF_TYPES)
wiup_file = st.file_uploader("Upload WIUP Boundary (Shapefile ZIP)", type=["zip"])
num_samples = 1

if st.button("Run Inference"):
    with st.spinner("Loading data and model..."):
        uploads = (sar_file, optic_file, mask_file, wiup_file)
        if all(upload is not None for upload in uploads):
            st.success("All files uploaded successfully!")

            # rasterio, OpenCV and geopandas need real file paths.
            sar_path = save_uploaded_file(sar_file, suffix=".tif")
            optic_path = save_uploaded_file(optic_file, suffix=".tif")
            mask_path = save_uploaded_file(mask_file, suffix=".tif")
            wiup_zip_path = save_uploaded_file(wiup_file, ".zip")

            extract_folder = os.path.splitext(wiup_zip_path)[0]
            with zipfile.ZipFile(wiup_zip_path, "r") as zip_ref:
                zip_ref.extractall(extract_folder)

            # Load WIUP shapefile
            wiup_gdf = gpd.read_file(extract_folder)

            # Model path
            model_path = "Residual_UNET_Bilinear.keras"

            # Read and normalize images
            sar_images = readImages([sar_path], typeData='s', width=WIDTH, height=HEIGHT)
            optic_images = readImages([optic_path], typeData='o', width=WIDTH, height=HEIGHT)
            masks = readImages([mask_path], typeData='m', width=WIDTH, height=HEIGHT)

            sar_images = normalizeImages(sar_images, 's')
            # 'o' so optical pixels are scaled to [0, 1] like the SAR input;
            # any other type code leaves them as raw 0-255 values.
            optic_images = normalizeImages(optic_images, 'o')

            # Load model
            model = tf.keras.models.load_model(
                model_path,
                custom_objects={"cce_dice_loss": cce_dice_loss, "dice_score": dice_score},
            )

            # Predict masks
            pred_masks = model.predict([optic_images, sar_images], verbose=0)
            is_multiclass = pred_masks.shape[-1] > 1
            num_samples = min(num_samples, len(sar_images))

            fig = _plot_predictions(
                sar_images, optic_images, masks, pred_masks, is_multiclass, num_samples
            )
            st.pyplot(fig)
            plt.close(fig)

            fig = _plot_overlays(
                sar_images, optic_images, masks, pred_masks, is_multiclass, num_samples
            )
            st.pyplot(fig)
            plt.close(fig)

            # Label map for the first sample; a single-channel model is
            # thresholded (argmax over one channel would always give 0).
            if is_multiclass:
                pred_label_mask = np.argmax(pred_masks[0], axis=-1)  # shape: (H, W)
            else:
                pred_label_mask = (pred_masks[0][..., 0] > 0.5).astype(np.uint8)
            st.success(f" Unique classes in mask: {np.unique(pred_label_mask)}")

            # The mask was resized to (HEIGHT, WIDTH): rescale the GeoTIFF
            # transform so polygons land on the correct ground coordinates.
            transform, crs = get_transform_from_tif(optic_path, out_shape=pred_label_mask.shape)

            # Convert mask to mining polygons in the raster's own CRS
            mining_gdf = mask_to_polygons(pred_label_mask, transform, mining_class_id=1, crs=crs)
            st.success(f"Mining polygons: {len(mining_gdf)}")

            # Make sure WIUP and prediction are in same CRS
            wiup_gdf = wiup_gdf.to_crs(mining_gdf.crs)
            st.success(f"WIUP CRS: {wiup_gdf.crs}")

            # Find area
            illegal_mining_area, illegal_mining_area_gdf = calculate_illegal_area(mining_gdf, wiup_gdf)

            # Display in Streamlit
            st.success(f" Illegal Mining Area : {illegal_mining_area/1e6:.2f} sq. km")
        else:
            st.warning(
                "Please upload all four files (SAR, optical and mask .tif/.tiff "
                "images plus the WIUP shapefile ZIP) to proceed."
            )