|
import streamlit as st |
|
import os |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
from tensorflow.keras.models import load_model |
|
import pathlib |
|
import natsort |
|
import datetime |
|
import shutil |
|
import rasterio |
|
import cv2 |
|
import tensorflow as tf |
|
import tempfile |
|
from rasterio import features |
|
from shapely.geometry import shape |
|
from rasterio.features import shapes |
|
import geopandas as gpd |
|
import zipfile |
|
|
|
|
|
# Every input is resampled to this square size (pixels) before inference.
HEIGHT, WIDTH = 256, 256

# Per-sample tensor shapes: SAR has 1 channel, optic has 3 (RGB), mask has 4.
# NOTE(review): only 3 mask classes are defined (CLASS_COLORS / convertColorToLabel),
# yet MASK_SHAPE declares 4 channels — confirm which one the model actually uses.
SAR_SHAPE, OPTIC_SHAPE, MASK_SHAPE = ((HEIGHT, WIDTH, channels) for channels in (1, 3, 4))

# Display colours (RGB in [0, 1]) indexed by class id; index j colours mask channel j.
CLASS_COLORS = [
    [115 / 255, 178 / 255, 115 / 255],  # class 0: (115, 178, 115) in the mask palette
    [1, 0, 0],                          # class 1: red, (255, 0, 0)
    [0, 0, 0],                          # class 2: black, (0, 0, 0)
]
|
|
|
def _thresholded_dice(y_true, y_pred, threshold, smooth):
    """Smoothed Dice coefficient between `y_true` and `y_pred >= threshold`, over all elements."""
    y_true_flat = tf.cast(tf.reshape(y_true, [-1]), dtype=tf.float32)
    y_pred_flat = tf.cast(tf.reshape(y_pred >= threshold, [-1]), dtype=tf.float32)
    intersection = tf.reduce_sum(y_true_flat * y_pred_flat)
    return (2. * intersection + smooth) / (
        tf.reduce_sum(y_true_flat) + tf.reduce_sum(y_pred_flat) + smooth
    )


@tf.keras.saving.register_keras_serializable()
def dice_score(y_true, y_pred, threshold=0.5, smooth=1.0):
    """Dice score of thresholded predictions against targets.

    Predictions are binarised with ``y_pred >= threshold`` before scoring, and
    ``smooth`` is added to numerator and denominator to avoid 0/0.

    Args:
        y_true: target tensor; its last axis is treated as the class axis.
        y_pred: prediction tensor with the same shape as ``y_true``.
        threshold: binarisation cut-off applied to ``y_pred``.
        smooth: additive smoothing constant.

    Returns:
        Scalar float32 tensor. If the last axis has a single channel, the Dice
        over all elements; otherwise the mean of the per-class Dice scores
        (one score per last-axis channel).
    """
    is_multiclass = y_true.shape[-1] > 1

    if not is_multiclass:
        return _thresholded_dice(y_true, y_pred, threshold, smooth)

    num_classes = y_true.shape[-1]
    # Score each class channel on its own, then average (macro Dice). Slicing the
    # channel is essential: scoring the full tensors would repeat the same global
    # score for every class.
    score_per_class = [
        _thresholded_dice(y_true[..., i], y_pred[..., i], threshold, smooth)
        for i in range(num_classes)
    ]
    return tf.reduce_mean(score_per_class)
|
|
|
@tf.keras.saving.register_keras_serializable()
def dice_loss(y_true, y_pred):
    """Return ``1 - dice_score(y_true, y_pred)`` as a float32 tensor.

    dice_score is called with its defaults (threshold=0.5, smooth=1.0).

    NOTE(review): dice_score binarises y_pred with a ``>=`` comparison, which
    TensorFlow does not differentiate through, so this term likely contributes
    no gradient during training — confirm before reusing it for retraining.
    """
    return tf.cast(1. - dice_score(y_true, y_pred), dtype=tf.float32)
|
|
|
@tf.keras.saving.register_keras_serializable()
def cce_dice_loss(y_true, y_pred):
    """Combined loss: Keras categorical cross-entropy plus dice_loss.

    The cross-entropy uses ``CategoricalCrossentropy`` with default settings and
    is cast to float32 before the dice term is added.
    """
    crossentropy = tf.keras.losses.CategoricalCrossentropy()
    cce_term = tf.cast(crossentropy(y_true, y_pred), dtype=tf.float32)
    dice_term = dice_loss(y_true, y_pred)
    return cce_term + dice_term
|
|
|
def convertColorToLabel(img):
    """Convert a colour-coded mask image into a one-hot label volume.

    Pixels are matched exactly against three RGB colours:
    (115, 178, 115) -> class 0, (255, 0, 0) -> class 1, (0, 0, 0) -> class 2.
    Pixels that match none of them fall back to class 0.

    Args:
        img: array indexed as (row, col, channel) whose channels are compared,
            in order, against the RGB triples above.

    Returns:
        uint8 array of shape (img.shape[0], img.shape[1], 3), one-hot over classes.
    """
    palette = (
        ((115, 178, 115), 0),
        ((255, 0, 0), 1),
        ((0, 0, 0), 2),
    )

    labels = np.zeros(img.shape[:2], dtype=np.uint8)
    for rgb, class_id in palette:
        labels[np.all(img == rgb, axis=2)] = class_id

    # Row `k` of the identity matrix is the one-hot vector for class k.
    return np.eye(len(palette), dtype=np.uint8)[labels]
|
|
|
def _read_rgb(path, width, height, interpolation):
    """Read `path` with OpenCV as 3-channel RGB and resize it to (height, width)."""
    image = cv2.imread(str(path), cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread reports an unreadable or missing file by returning None
        # rather than raising; fail here with a clear message instead of inside cvtColor.
        raise FileNotFoundError(f"Could not read image file: {path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return cv2.resize(image, (width, height), interpolation=interpolation)


def _read_sar(path, width, height):
    """Read all bands of a SAR raster, percentile-stretch to uint8, resize, add a channel axis."""
    with rasterio.open(str(path)) as src:
        sar_bands = [src.read(band) for band in range(1, src.count + 1)]
    sar_image = np.stack(sar_bands, axis=-1)

    # Clip to the 2nd-98th percentile range and stretch that range to 0-255.
    p2, p98 = np.percentile(sar_image, (2, 98))
    sar_image = np.clip(sar_image, p2, p98)
    if p98 > p2:
        sar_image = (sar_image - p2) / (p98 - p2) * 255
    else:
        # Flat (or NaN-percentile) image: the stretch would divide by zero and
        # produce NaN, whose uint8 cast is undefined. Emit black instead.
        sar_image = np.zeros_like(sar_image)
    sar_image = sar_image.astype(np.uint8)

    sar_image = cv2.resize(sar_image, (width, height), interpolation=cv2.INTER_AREA)
    # NOTE(review): presumably single-band input; a multi-band raster would keep its
    # band axis through cv2.resize and gain an extra axis here — confirm.
    return np.expand_dims(sar_image, axis=-1)


def readImages(data, typeData, width, height):
    """Load images from disk and resize them to (height, width).

    Args:
        data: iterable of file paths (anything ``str()`` turns into a path).
        typeData: 's' for a SAR raster (all bands read with rasterio, 2-98
            percentile stretch to uint8, trailing channel axis added), 'o' for an
            optic image (RGB uint8), or 'm' for a colour-coded mask (one-hot via
            convertColorToLabel).
        width: target width in pixels.
        height: target height in pixels.

    Returns:
        np.ndarray stacking the loaded images along a new leading axis.

    Raises:
        ValueError: if typeData is not 's', 'o' or 'm'.
        FileNotFoundError: if OpenCV cannot read an optic or mask file.
    """
    if typeData not in ('s', 'm', 'o'):
        raise ValueError(f"Unknown typeData {typeData!r}; expected 's', 'm' or 'o'")

    images = []
    for path in data:
        if typeData == 's':
            images.append(_read_sar(path, width, height))
        elif typeData == 'm':
            # Nearest-neighbour never blends pixel values, so mask colours stay
            # exact and still match the palette in convertColorToLabel.
            rgb = _read_rgb(path, width, height, cv2.INTER_NEAREST)
            images.append(convertColorToLabel(rgb))
        else:
            images.append(_read_rgb(path, width, height, cv2.INTER_AREA))

    print(f"(INFO..) Read {len(images)} '{typeData}' image(s)")
    return np.array(images)
|
|
|
|
|
def normalizeImages(images, typeData):
    """Cast each image to uint8 and, for SAR or optic data, scale to [0, 1].

    Args:
        images: iterable of numpy arrays.
        typeData: 's' (SAR) or 'o' (optic) images are divided by 255 after the
            uint8 cast; any other value leaves them as uint8.

    Returns:
        np.ndarray built from the processed images.
    """
    rescale = typeData in ['s', 'o']
    as_uint8 = (image.astype(np.uint8) for image in images)
    normalized = [pixels / 255. if rescale else pixels for pixels in as_uint8]

    print("(INFO..) Normalization Image Done")
    return np.array(normalized)
|
|
|
|
|
|
|
def save_uploaded_file(uploaded_file, suffix=".tif"):
    """Copy an uploaded file's contents into a new named temporary file.

    The temporary file is created with ``delete=False``, so it persists after
    being closed; removing it is the caller's responsibility.

    Args:
        uploaded_file: file-like object; whatever its ``read()`` returns is written.
        suffix (str): extension appended to the temporary file name.

    Returns:
        str: path of the temporary file.
    """
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as handle:
        handle.write(uploaded_file.read())
        saved_path = handle.name
    return saved_path
|
|
|
def get_transform_from_tif(tif_path):
    """Read the georeferencing information of a GeoTIFF.

    Args:
        tif_path (str): Path to the GeoTIFF file.

    Returns:
        tuple: ``(transform, crs)``. ``transform`` is the dataset's affine
        pixel-to-coordinate mapping and ``crs`` its coordinate reference system,
        both as reported by rasterio.
    """
    with rasterio.open(tif_path) as dataset:
        georeference = dataset.transform, dataset.crs
    return georeference
|
|
|
def mask_to_polygons(mask, transform, mining_class_id=1, crs="EPSG:4326"):
    """Vectorise the pixels of one class in a label mask into polygons.

    Args:
        mask: 2-D array of class labels (as required by rasterio.features.shapes).
        transform: affine transform mapping pixel (col, row) positions of `mask`
            to world coordinates.
        mining_class_id: label value whose pixels are vectorised.
        crs: CRS of the coordinates produced by `transform`; assigned to the
            result. Defaults to "EPSG:4326" for backward compatibility. Pass the
            source raster's CRS whenever `transform` is not in WGS84 lon/lat,
            otherwise the polygons are mislabelled.

    Returns:
        GeoDataFrame with a 'class' column set to 'mining_land', in `crs`.
        If no pixel matches, an empty GeoDataFrame with only a geometry column.
    """
    mining_mask = (mask == mining_class_id).astype(np.uint8)

    # shapes() yields (geometry, value) pairs for both the 0 and 1 regions;
    # keep only the regions of the selected class.
    records = [
        {'properties': {'class': 'mining_land'}, 'geometry': geometry}
        for geometry, value in shapes(mining_mask, mask=None, transform=transform)
        if value == 1
    ]

    if not records:
        return gpd.GeoDataFrame(columns=['geometry'], geometry='geometry', crs=crs)

    return gpd.GeoDataFrame.from_features(records, crs=crs)
|
|
|
|
|
|
|
def calculate_illegal_area(mining_gdf, wiup_gdf, area_crs="EPSG:6933"):
    """Measure the mining footprint that lies outside the WIUP boundaries.

    Args:
        mining_gdf: GeoDataFrame of detected mining polygons (must have a CRS).
        wiup_gdf: GeoDataFrame of WIUP boundary polygons (must have a CRS).
        area_crs: projected, equal-area CRS (metre units) in which areas are
            measured. Defaults to EPSG:6933 (WGS 84 / EASE-Grid 2.0 Global).

    Returns:
        tuple: ``(total_illegal_area, illegal_area_gdf)``. The first item is the
        summed area in square units of `area_crs` (m² for the default). The
        second is the mining-minus-WIUP geometry, reprojected to `area_crs`,
        with an 'area_m2' column.
    """
    # overlay() needs both layers in the same CRS.
    mining_gdf = mining_gdf.to_crs(wiup_gdf.crs)

    # Keep only the parts of the mining polygons not covered by any WIUP polygon.
    illegal_area_gdf = gpd.overlay(mining_gdf, wiup_gdf, how='difference')

    # `.area` is expressed in the CRS's own units. A geographic CRS such as
    # EPSG:4326 would give square degrees, so project to an equal-area metric CRS.
    illegal_area_gdf = illegal_area_gdf.to_crs(area_crs)

    illegal_area_gdf['area_m2'] = illegal_area_gdf.geometry.area
    total_illegal_area = illegal_area_gdf['area_m2'].sum()

    return total_illegal_area, illegal_area_gdf
|
|
|
|
|
def _colorize_mask(mask):
    """Blend the per-class channels of an (H, W, C) mask into an RGB float image.

    Channel j is weighted by CLASS_COLORS[j]; only the first len(CLASS_COLORS)
    channels are used.
    """
    rgb = np.zeros((*mask.shape[:2], 3))
    for j, color in enumerate(CLASS_COLORS):
        rgb += mask[:, :, j][:, :, np.newaxis] * np.array(color)
    return rgb


def _show_figure(fig):
    """Lay out `fig`, render it on the Streamlit page, then release it."""
    fig.tight_layout()
    st.pyplot(fig)
    # pyplot keeps every figure created by plt.subplots alive until it is closed.
    # Without this, each "Run Inference" click leaks two figures.
    plt.close(fig)


st.title("Satellite Mining Segmentation: SAR + Optic Image Inference")

# The uploader filters on file extension, and GeoTIFFs are most often saved as
# *.tif, so accept both spellings.
sar_file = st.file_uploader("Upload SAR Image", type=["tif", "tiff"])
optic_file = st.file_uploader("Upload Optical Image", type=["tif", "tiff"])
mask_file = st.file_uploader("Upload Mask Image", type=["tif", "tiff"])
wiup_file = st.file_uploader("Upload WIUP Boundary (Shapefile ZIP)", type=["zip"])

num_samples = 1

if st.button("Run Inference"):
    with st.spinner("Loading data and model..."):
        if sar_file is not None and optic_file is not None and mask_file is not None and wiup_file is not None:
            st.success("All files uploaded successfully!")

            # Persist the uploads to disk: rasterio, cv2 and geopandas read from paths.
            sar_path = save_uploaded_file(sar_file, suffix=".tif")
            optic_path = save_uploaded_file(optic_file, suffix=".tif")
            mask_path = save_uploaded_file(mask_file, suffix=".tif")
            wiup_zip_path = save_uploaded_file(wiup_file, ".zip")
            # splitext strips only the final extension (str.replace would also
            # rewrite a ".zip" appearing elsewhere in the path).
            extract_folder = os.path.splitext(wiup_zip_path)[0]

            try:
                with zipfile.ZipFile(wiup_zip_path, "r") as zip_ref:
                    zip_ref.extractall(extract_folder)
                wiup_gdf = gpd.read_file(extract_folder)

                sar_images = readImages([sar_path], typeData='s', width=WIDTH, height=HEIGHT)
                optic_images = readImages([optic_path], typeData='o', width=WIDTH, height=HEIGHT)
                masks = readImages([mask_path], typeData='m', width=WIDTH, height=HEIGHT)

                # Read the georeferencing while the temporary GeoTIFF still exists.
                with rasterio.open(optic_path) as src:
                    crs = src.crs
                    # The model predicts on a WIDTH x HEIGHT resample of the optic image
                    # (assumes the model's output grid matches its input size). Rescale the
                    # pixel->world transform from the file's grid to that grid so the
                    # vectorised polygons land in the right place.
                    transform = src.transform * rasterio.transform.Affine.scale(
                        src.width / WIDTH, src.height / HEIGHT
                    )
            finally:
                # save_uploaded_file creates its files with delete=False, so remove
                # them (and the extracted shapefile) once their contents are loaded.
                for temp_path in (sar_path, optic_path, mask_path, wiup_zip_path):
                    if os.path.exists(temp_path):
                        os.remove(temp_path)
                shutil.rmtree(extract_folder, ignore_errors=True)

            sar_images = normalizeImages(sar_images, 's')
            # NOTE(review): normalizeImages only rescales 's' and 'o'. With 'i', the
            # optic images reach the model as raw 0-255 uint8. Confirm this matches the
            # training pipeline before switching to 'o'.
            optic_images = normalizeImages(optic_images, 'i')

            model_path = "Residual_UNET_Bilinear.keras"
            model = tf.keras.models.load_model(
                model_path,
                custom_objects={"cce_dice_loss": cce_dice_loss, "dice_score": dice_score}
            )

            pred_masks = model.predict([optic_images, sar_images], verbose=0)
            is_multiclass = pred_masks.shape[-1] > 1

            num_samples = min(num_samples, len(sar_images))

            # Figure 1: inputs, ground-truth mask and predicted mask.
            fig, axes = plt.subplots(num_samples, 4, figsize=(21, 6 * num_samples))
            for i in range(num_samples):
                ax = axes[i] if num_samples > 1 else axes

                ax[0].imshow(sar_images[i].squeeze(), cmap='gray')
                ax[0].set_title(f"SAR Image {i+1}")
                ax[0].axis('off')

                ax[1].imshow(optic_images[i])
                ax[1].set_title(f"Optic Image {i+1}")
                ax[1].axis('off')

                if is_multiclass:
                    ax[2].imshow(_colorize_mask(masks[i]))
                else:
                    ax[2].imshow(masks[i], cmap='gray')
                ax[2].set_title(f"Ground Truth Mask {i+1}")
                ax[2].axis('off')

                if is_multiclass:
                    ax[3].imshow(_colorize_mask(pred_masks[i]))
                else:
                    ax[3].imshow(pred_masks[i], cmap='gray')
                ax[3].set_title(f"Predicted Mask {i+1}")
                ax[3].axis('off')
            _show_figure(fig)

            # Figure 2: class-1 pixels painted red over the optic image.
            red_color = [255, 0, 0]

            if optic_images.dtype != np.uint8:
                optic_images = (optic_images * 255).astype(np.uint8)

            fig, axes = plt.subplots(num_samples, 4, figsize=(21, 6 * num_samples))
            for i in range(num_samples):
                ax = axes[i] if num_samples > 1 else axes

                ax[0].imshow(sar_images[i].squeeze(), cmap='gray')
                ax[0].set_title(f"SAR Image {i+1}")
                ax[0].axis('off')

                ax[1].imshow(optic_images[i])
                ax[1].set_title(f"Optic Image {i+1}")
                ax[1].axis('off')

                gt_overlay = optic_images[i].copy()
                if is_multiclass:
                    gt_overlay[masks[i][:, :, 1] == 1] = red_color
                else:
                    gt_overlay[masks[i].squeeze() == 1] = red_color
                ax[2].imshow(optic_images[i])
                ax[2].imshow(gt_overlay, alpha=0.4)
                ax[2].set_title(f"Ground Truth Overlay {i+1}")
                ax[2].axis('off')

                pred_overlay = optic_images[i].copy()
                if is_multiclass:
                    pred_overlay[pred_masks[i][:, :, 1] > 0.5] = red_color
                else:
                    pred_overlay[pred_masks[i].squeeze() > 0.5] = red_color
                ax[3].imshow(optic_images[i])
                ax[3].imshow(pred_overlay, alpha=0.4)
                ax[3].set_title(f"Predicted Overlay {i+1}")
                ax[3].axis('off')
            _show_figure(fig)

            pred_label_mask = np.argmax(pred_masks[0], axis=-1)
            st.success(f" Unique classes in mask: {np.unique(pred_label_mask)}")

            mining_gdf = mask_to_polygons(pred_label_mask, transform, mining_class_id=1)
            if crs is not None:
                # mask_to_polygons tags its output EPSG:4326 by default, but `transform`
                # produces coordinates in the optic raster's own CRS. Relabel the
                # polygons so the reprojections below are correct.
                mining_gdf = mining_gdf.set_crs(crs, allow_override=True)
            st.success(f"Mining polygons: {len(mining_gdf)}")

            wiup_gdf = wiup_gdf.to_crs(mining_gdf.crs)
            st.success(f"WIUP CRS: {wiup_gdf.crs}")

            illegal_mining_area, illegal_mining_area_gdf = calculate_illegal_area(mining_gdf, wiup_gdf)

            st.success(f" Illegal Mining Area : {illegal_mining_area/1e6:.2f} sq. km")

        else:
            st.warning("Please upload the SAR, optical and mask GeoTIFFs and the WIUP shapefile ZIP to proceed.")
|
|
|
|