File size: 14,035 Bytes
b90fed9 b2d94fd 87b6bca 33a73d2 227bb08 469f05c 6695dcb 4487561 496e5c7 b90fed9 b6c25e3 b90fed9 c652f62 a52b712 662786f cd33515 662786f cd33515 a52b712 e4c18ca e8a87d7 38bdd0c e8a87d7 38bdd0c e8a87d7 38bdd0c e8a87d7 38bdd0c e8a87d7 f1e0408 e8a87d7 34a7c7d f1e0408 38bdd0c 34a7c7d 38bdd0c 34a7c7d 38bdd0c 34a7c7d 38bdd0c e4c18ca b90fed9 cd33515 6695dcb cd33515 6bc373a d98d462 6695dcb d98d462 6bc373a d98d462 3610efa 6695dcb d98d462 c3f9b72 d98d462 c3f9b72 d98d462 c3f9b72 d98d462 c3f9b72 d98d462 c3f9b72 d98d462 c3f9b72 d98d462 c3f9b72 d98d462 c3f9b72 d98d462 c3f9b72 d98d462 c3f9b72 d98d462 c3f9b72 d98d462 c3f9b72 d98d462 c3f9b72 d98d462 c3f9b72 d98d462 c3f9b72 d98d462 c3f9b72 d98d462 c3f9b72 d98d462 c3f9b72 d98d462 c3f9b72 d98d462 c3f9b72 d98d462 c3f9b72 d98d462 c3f9b72 d98d462 c3f9b72 d98d462 c3f9b72 d98d462 c3f9b72 6695dcb 7dffef2 e8a87d7 5a39277 e8a87d7 38bdd0c f1e0408 38bdd0c e8a87d7 f1e0408 e8a87d7 b6c25e3 1334359 e8a87d7 1334359 6695dcb d98d462 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 |
import streamlit as st
import os
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
import pathlib
import natsort
import datetime
import shutil
import rasterio
import cv2
import tensorflow as tf
import tempfile
from rasterio import features
from shapely.geometry import shape
from rasterio.features import shapes
import geopandas as gpd
import zipfile
# Configuration
HEIGHT = 256
WIDTH = HEIGHT  # square working resolution used for every resize below
SAR_SHAPE = (HEIGHT, WIDTH, 1)    # single-channel SAR input
OPTIC_SHAPE = (HEIGHT, WIDTH, 3)  # 3-channel (RGB) optical input
# NOTE(review): declares 4 one-hot channels, but CLASS_COLORS and
# convertColorToLabel only define 3 classes -- confirm what the model emits.
MASK_SHAPE = (HEIGHT, WIDTH, 4)
# Display colours as 0-1 floats, indexed by class id:
# 0 non-mining (green), 1 mining (red), 2 beach (black).
CLASS_COLORS = [
    [channel / 255 for channel in (115, 178, 115)],
    [1, 0, 0],
    [0, 0, 0],
]
@tf.keras.saving.register_keras_serializable()
def dice_score(y_true, y_pred, threshold=0.5, smooth=1.0):
    """Smoothed Dice coefficient between a ground-truth mask and thresholded predictions.

    Predictions are binarised with ``y_pred >= threshold``. When the last axis
    of ``y_true`` has more than one channel (one-hot multiclass masks), Dice is
    computed separately for each channel and the per-class scores are averaged;
    otherwise a single binary score is returned.

    Args:
        y_true: Ground-truth tensor; its last axis is treated as the class axis.
        y_pred: Predicted probabilities (presumably the same shape as ``y_true``).
        threshold: Cut-off used to binarise ``y_pred``.
        smooth: Additive smoothing that keeps the ratio defined for empty masks.

    Returns:
        Scalar float32 tensor.
    """
    def _binary_dice(true_part, pred_part):
        # (2 * |A & B| + smooth) / (|A| + |B| + smooth) on flattened, binarised inputs.
        true_flat = tf.cast(tf.reshape(true_part, [-1]), dtype=tf.float32)
        pred_flat = tf.cast(tf.reshape(pred_part >= threshold, [-1]), dtype=tf.float32)
        intersection = tf.reduce_sum(true_flat * pred_flat)
        return (2. * intersection + smooth) / (
            tf.reduce_sum(true_flat) + tf.reduce_sum(pred_flat) + smooth
        )

    num_classes = y_true.shape[-1]
    if num_classes > 1:
        # Slice channel i from both tensors. The previous loop reused the full
        # tensors on every iteration, so it averaged num_classes copies of the
        # pooled score instead of genuine per-class scores.
        per_class = [
            _binary_dice(y_true[..., i], y_pred[..., i]) for i in range(num_classes)
        ]
        return tf.reduce_mean(tf.stack(per_class))
    return _binary_dice(y_true, y_pred)
@tf.keras.saving.register_keras_serializable()
def dice_loss(y_true, y_pred):
    """Return ``1 - dice_score(y_true, y_pred)`` (default threshold/smoothing) as float32."""
    return tf.cast(1. - dice_score(y_true, y_pred), dtype=tf.float32)
@tf.keras.saving.register_keras_serializable()
def cce_dice_loss(y_true, y_pred):
    """Combined loss: categorical cross-entropy plus Dice loss, returned as float32."""
    cce_term = tf.cast(
        tf.keras.losses.CategoricalCrossentropy()(y_true, y_pred), dtype=tf.float32
    )
    dice_term = dice_loss(y_true, y_pred)
    return cce_term + dice_term
def convertColorToLabel(img):
    """Convert an RGB colour-coded mask into a one-hot label volume.

    Pixels whose RGB value exactly equals a palette colour get that colour's
    class id; any other colour is left as class 0 (the zero-initialised default).

    Args:
        img: Array of shape (H, W, 3) holding RGB values.

    Returns:
        uint8 array of shape (H, W, 3); channel c is 1 where the pixel is class c.
    """
    palette = (
        ((115, 178, 115), 0),  # non_mining_land (green)
        ((255, 0, 0), 1),      # illegal_mining_land (red)
        ((0, 0, 0), 2),        # beach (black)
    )
    height, width = img.shape[:2]
    labels = np.zeros((height, width), dtype=np.uint8)
    for rgb, class_id in palette:
        labels[(img == rgb).all(axis=2)] = class_id
    # Broadcast (H, W, 1) against (num_classes,) to build all channels at once.
    class_ids = np.arange(len(palette))
    return (labels[:, :, np.newaxis] == class_ids).astype(np.uint8)
def readImages(data, typeData, width, height):
    """Read image files of one modality and return them as a resized batch.

    Args:
        data: Iterable of image paths (passed through ``str()`` before opening).
        typeData: Modality code:
            ``'s'`` - SAR raster read band-by-band with rasterio, contrast
            stretched to its 2nd-98th percentile range and scaled to uint8;
            ``'m'`` - colour-coded mask read with OpenCV, resized with
            nearest-neighbour (so the exact class colours survive) and one-hot
            encoded by ``convertColorToLabel``;
            ``'o'`` - optical image read with OpenCV and converted to RGB.
            Any other code is skipped without reading anything.
        width: Target width passed to ``cv2.resize``.
        height: Target height passed to ``cv2.resize``.

    Returns:
        ``np.array`` of the processed images, first axis = file index.

    Raises:
        FileNotFoundError: if OpenCV cannot read a mask/optical file
            (``cv2.imread`` returns None instead of raising).
    """
    images = []
    for path in data:
        if typeData == 's':  # SAR image
            with rasterio.open(str(path)) as src:
                sar_bands = [src.read(i) for i in range(1, src.count + 1)]
            sar_image = np.stack(sar_bands, axis=-1)
            # Contrast stretching; nanpercentile ignores NaN nodata pixels.
            p2, p98 = np.nanpercentile(sar_image, (2, 98))
            if p98 > p2:
                stretched = (np.clip(sar_image, p2, p98) - p2) / (p98 - p2) * 255
                # A NaN -> uint8 cast is undefined; map NaN pixels to 0 first.
                sar_image = np.nan_to_num(stretched, nan=0.0).astype(np.uint8)
            else:
                # Constant (or all-NaN) raster: the stretch would divide 0 by 0.
                sar_image = np.zeros(sar_image.shape, dtype=np.uint8)
            sar_image = cv2.resize(sar_image, (width, height), interpolation=cv2.INTER_AREA)
            images.append(np.expand_dims(sar_image, axis=-1))
        elif typeData in ('m', 'o'):  # Mask or optic image
            img = cv2.imread(str(path), cv2.IMREAD_COLOR)
            if img is None:
                raise FileNotFoundError(f"OpenCV could not read image file: {path}")
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            if typeData == 'm':
                img = cv2.resize(img, (width, height), interpolation=cv2.INTER_NEAREST)
                images.append(convertColorToLabel(img))
            else:
                img = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
                images.append(img)
    print(f"(INFO..) Read {len(images)} '{typeData}' image(s)")
    return np.array(images)
def normalizeImages(images, typeData):
    """Cast every image to uint8; SAR ('s') and optical ('o') images are also scaled to [0, 1].

    Args:
        images: Iterable of image arrays.
        typeData: Modality code. Only 's' and 'o' trigger division by 255;
            any other code (e.g. 'm' for one-hot masks) is only cast.

    Returns:
        ``np.array`` of the processed images.
    """
    scale_to_unit = typeData in ['s', 'o']
    processed = [
        frame.astype(np.uint8) / 255. if scale_to_unit else frame.astype(np.uint8)
        for frame in images
    ]
    print("(INFO..) Normalization Image Done")
    return np.array(processed)
def save_uploaded_file(uploaded_file, suffix=".tif"):
    """Copy an uploaded file-like object into a named temporary file on disk.

    The temporary file is created with ``delete=False``, so it outlives this
    call and the caller is responsible for removing it.

    Args:
        uploaded_file: Object with a ``read()`` method returning bytes
            (e.g. a Streamlit ``UploadedFile``).
        suffix: File-name suffix for the temporary file (e.g. ".tif", ".zip").

    Returns:
        Path (str) of the written temporary file.
    """
    # Rewind when possible: read() starts at the current position, so a stream
    # that had already been read would otherwise be saved as an empty file.
    seekable = getattr(uploaded_file, "seekable", None)
    if callable(seekable) and seekable():
        uploaded_file.seek(0)
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        tmp.write(uploaded_file.read())
    return tmp.name
def get_transform_from_tif(tif_path):
    """Read the georeferencing metadata of a raster.

    Args:
        tif_path (str): Path to the GeoTIFF file.

    Returns:
        tuple: ``(transform, crs)`` -- the dataset's affine pixel-to-map
        transform and its coordinate reference system, as reported by rasterio.
    """
    with rasterio.open(tif_path) as dataset:
        return dataset.transform, dataset.crs
def mask_to_polygons(mask, transform, mining_class_id=1, crs="EPSG:4326"):
    """Vectorise the pixels of one class in a label mask into polygons.

    Args:
        mask: 2-D array of integer class ids.
        transform: Affine transform mapping *mask* pixel coordinates to map
            coordinates. It must describe the mask's own grid: if the mask was
            resized from a source raster, rescale that raster's transform.
        mining_class_id: Class id whose pixels are vectorised.
        crs: CRS of the coordinates produced by *transform*. Defaults to
            EPSG:4326 for backward compatibility; pass the source raster's CRS
            (e.g. from ``get_transform_from_tif``) for projected rasters.

    Returns:
        GeoDataFrame with one row per connected region of the class, a
        ``class`` column set to 'mining_land' and the geometry; an empty frame
        with the same columns when no pixel matches.
    """
    mining_mask = (mask == mining_class_id).astype(np.uint8)
    # Passing the boolean mask makes rasterio skip background regions instead
    # of generating polygons that would be discarded below.
    regions = shapes(mining_mask, mask=mining_mask.astype(bool), transform=transform)
    geoms = [
        {'properties': {'class': 'mining_land'}, 'geometry': geom}
        for geom, value in regions
        if value == 1  # Only keep areas marked as mining
    ]
    if not geoms:
        return gpd.GeoDataFrame(columns=['class', 'geometry'], geometry='geometry', crs=crs)
    return gpd.GeoDataFrame.from_features(geoms, crs=crs)
def calculate_illegal_area(mining_gdf, wiup_gdf, metric_crs=None):
    """Measure the mining area lying outside the WIUP boundary.

    Args:
        mining_gdf: GeoDataFrame of detected mining polygons (must have a CRS).
        wiup_gdf: GeoDataFrame of WIUP boundary polygons (must have a CRS).
        metric_crs: Projected CRS in metres used to measure area. When None,
            the local UTM zone is estimated from the data with
            ``GeoDataFrame.estimate_utm_crs``.

    Returns:
        tuple: ``(total_illegal_area, illegal_area_gdf)`` -- the total area in
        m^2 as a float, and the mining-minus-WIUP polygons with an ``area_m2``
        column (reprojected to the metric CRS when non-empty).
    """
    # Ensure both layers share a CRS before the overlay.
    mining_gdf = mining_gdf.to_crs(wiup_gdf.crs)
    # Illegal area = mining polygons minus WIUP polygons.
    if mining_gdf.empty:
        illegal_area_gdf = mining_gdf.copy()
    else:
        illegal_area_gdf = gpd.overlay(mining_gdf, wiup_gdf, how='difference')
    if illegal_area_gdf.empty:
        # Nothing to measure (and estimate_utm_crs needs geometry bounds).
        return 0.0, illegal_area_gdf.assign(area_m2=0.0)
    # Area must be measured in a projected metric CRS. The previous value,
    # ' EPSG:4326', is geographic, so .area returned square degrees.
    if metric_crs is None:
        metric_crs = illegal_area_gdf.estimate_utm_crs()
    illegal_area_gdf = illegal_area_gdf.to_crs(metric_crs)
    illegal_area_gdf['area_m2'] = illegal_area_gdf.geometry.area
    total_illegal_area = float(illegal_area_gdf['area_m2'].sum())
    return total_illegal_area, illegal_area_gdf
# Streamlit App Title
st.title("Satellite Mining Segmentation: SAR + Optic Image Inference")
# List both GeoTIFF spellings: the uploader only accepts the extensions given
# in `type`, so ["tiff"] alone rejects ".tif" files.
sar_file = st.file_uploader("Upload SAR Image", type=["tif", "tiff"])
optic_file = st.file_uploader("Upload Optical Image", type=["tif", "tiff"])
mask_file = st.file_uploader("Upload Mask Image", type=["tif", "tiff"])
wiup_file = st.file_uploader("Upload WIUP Boundary (Shapefile ZIP)", type=["zip"])
num_samples = 1


def _colorize_mask(class_maps):
    """Blend the first len(CLASS_COLORS) channels of an (H, W, C) array into an RGB image."""
    rgb = np.zeros((*class_maps.shape[:2], 3))
    for j, color in enumerate(CLASS_COLORS):
        rgb += class_maps[:, :, j][:, :, np.newaxis] * np.array(color)
    return rgb


def _show_mask(ax, class_maps):
    """Draw an (H, W, C) mask: colour-coded if multi-channel, grayscale if single-channel."""
    if class_maps.shape[-1] > 1:
        ax.imshow(_colorize_mask(class_maps))
    else:
        # matplotlib rejects (H, W, 1) arrays, so drop the channel axis.
        ax.imshow(class_maps.squeeze(), cmap='gray')


def _mining_pixels(class_maps, threshold=0.5):
    """Boolean (H, W) map of class-1 (mining) pixels of an (H, W, C) array."""
    if class_maps.shape[-1] > 1:
        return class_maps[:, :, 1] > threshold
    return class_maps.squeeze() > threshold


if st.button("Run Inference"):
    with st.spinner("Loading data and model..."):
        if all(f is not None for f in (sar_file, optic_file, mask_file, wiup_file)):
            st.success("All files uploaded successfully!")
            # Save uploads to disk: rasterio, OpenCV and GeoPandas open paths.
            sar_path = save_uploaded_file(sar_file, suffix=".tif")
            optic_path = save_uploaded_file(optic_file, suffix=".tif")
            mask_path = save_uploaded_file(mask_file, suffix=".tif")
            wiup_zip_path = save_uploaded_file(wiup_file, ".zip")
            # Fresh directory for the shapefile parts (the old
            # path.replace(".zip", "") rewrote every ".zip" in the path).
            extract_folder = tempfile.mkdtemp()
            try:
                with zipfile.ZipFile(wiup_zip_path, "r") as zip_ref:
                    zip_ref.extractall(extract_folder)
                # Load WIUP shapefile; search recursively so ZIPs that wrap the
                # shapefile in a sub-folder also work.
                shp_files = sorted(pathlib.Path(extract_folder).rglob("*.shp"))
                if not shp_files:
                    st.error("No .shp file found in the WIUP ZIP archive.")
                    st.stop()
                wiup_gdf = gpd.read_file(shp_files[0])

                # Model path
                model_path = "Residual_UNET_Bilinear.keras"
                # Read and normalize images (batches of one image each)
                sar_images = readImages([sar_path], typeData='s', width=WIDTH, height=HEIGHT)
                optic_images = readImages([optic_path], typeData='o', width=WIDTH, height=HEIGHT)
                masks = readImages([mask_path], typeData='m', width=WIDTH, height=HEIGHT)
                sar_images = normalizeImages(sar_images, 's')
                # 'o' is normalizeImages' optical code; the former 'i' skipped the
                # /255 scaling, so optics reached the model as 0-255 beside 0-1 SAR.
                # NOTE(review): confirm this matches the training preprocessing.
                optic_images = normalizeImages(optic_images, 'o')

                # Load model
                model = tf.keras.models.load_model(
                    model_path,
                    custom_objects={"cce_dice_loss": cce_dice_loss, "dice_score": dice_score},
                )
                # Predict masks
                pred_masks = model.predict([optic_images, sar_images], verbose=0)
                is_multiclass = pred_masks.shape[-1] > 1
                num_samples = min(num_samples, len(sar_images))

                # Figure 1: inputs, ground truth and prediction.
                # squeeze=False keeps `axes` 2-D even for a single row.
                fig, axes = plt.subplots(
                    num_samples, 4, figsize=(21, 6 * num_samples), squeeze=False
                )
                for i, ax in enumerate(axes):
                    ax[0].imshow(sar_images[i].squeeze(), cmap='gray')
                    ax[0].set_title(f"SAR Image {i+1}")
                    ax[1].imshow(optic_images[i])
                    ax[1].set_title(f"Optic Image {i+1}")
                    _show_mask(ax[2], masks[i])
                    ax[2].set_title(f"Ground Truth Mask {i+1}")
                    _show_mask(ax[3], pred_masks[i])
                    ax[3].set_title(f"Predicted Mask {i+1}")
                    for panel in ax:
                        panel.axis('off')
                st.pyplot(fig)
                plt.close(fig)  # pyplot keeps figures alive until closed

                # Figure 2: illegal-mining (class 1) pixels overlaid in red.
                red_color = [255, 0, 0]
                # The overlay is painted with 0-255 values, so work on uint8 optics.
                if optic_images.dtype != np.uint8:
                    optic_images = (optic_images * 255).astype(np.uint8)
                fig, axes = plt.subplots(
                    num_samples, 4, figsize=(21, 6 * num_samples), squeeze=False
                )
                for i, ax in enumerate(axes):
                    gt_overlay = optic_images[i].copy()
                    gt_overlay[_mining_pixels(masks[i])] = red_color
                    pred_overlay = optic_images[i].copy()
                    pred_overlay[_mining_pixels(pred_masks[i])] = red_color
                    ax[0].imshow(sar_images[i].squeeze(), cmap='gray')
                    ax[0].set_title(f"SAR Image {i+1}")
                    ax[1].imshow(optic_images[i])
                    ax[1].set_title(f"Optic Image {i+1}")
                    ax[2].imshow(optic_images[i])
                    ax[2].imshow(gt_overlay, alpha=0.4)
                    ax[2].set_title(f"Ground Truth Overlay {i+1}")
                    ax[3].imshow(optic_images[i])
                    ax[3].imshow(pred_overlay, alpha=0.4)
                    ax[3].set_title(f"Predicted Overlay {i+1}")
                    for panel in ax:
                        panel.axis('off')
                plt.tight_layout()
                st.pyplot(fig)
                plt.close(fig)

                # Class-id map of the first prediction (class 1 = mining).
                if is_multiclass:
                    pred_label_mask = np.argmax(pred_masks[0], axis=-1)  # shape: (H, W)
                else:
                    # argmax over a single channel is always 0; threshold instead.
                    pred_label_mask = (pred_masks[0][..., 0] > 0.5).astype(np.uint8)
                st.success(f" Unique classes in mask: {np.unique(pred_label_mask)}")

                # Georeference: the mask lives on a WIDTH x HEIGHT resize of the
                # optical raster, so rescale the raster's transform to that grid.
                # Unscaled, it mapped the mask onto only a WIDTH x HEIGHT pixel
                # window of the source raster.
                transform, crs = get_transform_from_tif(optic_path)
                with rasterio.open(optic_path) as src:
                    src_width, src_height = src.width, src.height
                mask_height, mask_width = pred_label_mask.shape
                transform = transform * transform.scale(
                    src_width / mask_width, src_height / mask_height
                )
                # Convert mask to mining polygons
                mining_gdf = mask_to_polygons(pred_label_mask, transform, mining_class_id=1)
                # Polygon coordinates are in the raster's CRS, not necessarily EPSG:4326.
                if crs:
                    mining_gdf = mining_gdf.set_crs(crs, allow_override=True)
                st.success(f"Mining polygons: {len(mining_gdf)}")
                # Make sure WIUP and prediction are in same CRS
                wiup_gdf = wiup_gdf.to_crs(mining_gdf.crs)
                st.success(f"WIUP CRS: {wiup_gdf.crs}")
                # Find area
                illegal_mining_area, illegal_mining_area_gdf = calculate_illegal_area(
                    mining_gdf, wiup_gdf
                )
                # Display in Streamlit
                st.success(f" Illegal Mining Area : {illegal_mining_area/1e6:.2f} sq. km")
            finally:
                # Remove the temporary copies of the uploads and the extracted shapefile.
                for tmp_path in (sar_path, optic_path, mask_path, wiup_zip_path):
                    if os.path.exists(tmp_path):
                        os.remove(tmp_path)
                shutil.rmtree(extract_folder, ignore_errors=True)
        else:
            st.warning(
                "Please upload all four files (SAR, optical and mask GeoTIFFs "
                "plus the WIUP shapefile ZIP) to proceed."
            )
|