# Alexvatti's picture
# Update app.py
# 1334359 verified
# raw
# history blame contribute delete
# 14 kB
import streamlit as st
import os
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
import pathlib
import natsort
import datetime
import shutil
import rasterio
import cv2
import tensorflow as tf
import tempfile
from rasterio import features
from shapely.geometry import shape
from rasterio.features import shapes
import geopandas as gpd
import zipfile
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Spatial size every image is resized to before it is handed to the model.
HEIGHT = 256
WIDTH = 256

# Array shapes (height, width, channels) for each kind of input.
SAR_SHAPE = (HEIGHT, WIDTH, 1)
OPTIC_SHAPE = (HEIGHT, WIDTH, 3)
# NOTE(review): declared with 4 channels, but CLASS_COLORS below lists only
# three colours -- confirm the intended number of classes.
MASK_SHAPE = (HEIGHT, WIDTH, 4)

# RGB colour (components in [0, 1]) used to draw each class, indexed by class id.
CLASS_COLORS = [
    [115 / 255, 178 / 255, 115 / 255],  # 0: non-mining (green)
    [1, 0, 0],                          # 1: mining (red)
    [0, 0, 0],                          # 2: beach (black)
]
@tf.keras.saving.register_keras_serializable()
def dice_score(y_true, y_pred, threshold=0.5, smooth=1.0):
    """Compute the smoothed Dice coefficient between a mask and a prediction.

    Predictions are binarised with ``y_pred >= threshold`` before they are
    compared with ``y_true``. When the last axis of ``y_true`` holds more than
    one channel, a score is computed for every channel separately and the mean
    of those scores is returned; otherwise a single score over all elements
    is returned.

    Args:
        y_true: Ground-truth tensor; its last axis is the class axis.
        y_pred: Prediction tensor with the same layout as ``y_true``.
        threshold: Cut-off applied to ``y_pred`` to obtain a binary mask.
        smooth: Constant added to numerator and denominator so that the
            score stays defined when both masks are empty.

    Returns:
        The Dice score (mean over channels in the multi-channel case).
    """

    def _dice(true_part, pred_part):
        # Dice = (2 * |A n B| + smooth) / (|A| + |B| + smooth) on flattened masks.
        true_flat = tf.cast(tf.reshape(true_part, [-1]), dtype=tf.float32)
        pred_flat = tf.cast(tf.reshape(pred_part >= threshold, [-1]), dtype=tf.float32)
        intersection = tf.reduce_sum(true_flat * pred_flat)
        return (2. * intersection + smooth) / (
            tf.reduce_sum(true_flat) + tf.reduce_sum(pred_flat) + smooth
        )

    # Determine binary or multiclass segmentation from the class axis.
    num_classes = y_true.shape[-1]
    is_multiclass = num_classes > 1
    if not is_multiclass:
        return _dice(y_true, y_pred)

    # Slice out channel ``i`` on every iteration. The previous implementation
    # flattened the whole tensor each time, so every "per-class" score was the
    # same global score and the mean carried no per-class information.
    score_per_class = [
        _dice(y_true[..., i], y_pred[..., i]) for i in range(num_classes)
    ]
    return tf.reduce_mean(score_per_class)
@tf.keras.saving.register_keras_serializable()
def dice_loss(y_true, y_pred):
    """Return ``1 - dice_score(y_true, y_pred)`` cast to float32.

    Args:
        y_true: Ground-truth tensor forwarded unchanged to ``dice_score``.
        y_pred: Prediction tensor forwarded unchanged to ``dice_score``.

    Returns:
        The Dice loss as a float32 tensor.
    """
    complement = 1. - dice_score(y_true, y_pred)
    return tf.cast(complement, dtype=tf.float32)
@tf.keras.saving.register_keras_serializable()
def cce_dice_loss(y_true, y_pred):
    """Return categorical cross-entropy plus Dice loss.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Prediction tensor.

    Returns:
        The sum of the float32 cross-entropy term and ``dice_loss``.
    """
    cross_entropy = tf.keras.losses.CategoricalCrossentropy()
    cce_term = tf.cast(cross_entropy(y_true, y_pred), dtype=tf.float32)
    return cce_term + dice_loss(y_true, y_pred)
def convertColorToLabel(img):
    """Turn an RGB colour-coded mask into a one-hot encoded label array.

    Args:
        img: Array indexed as (row, column, channel) holding RGB colours.

    Returns:
        A ``uint8`` array of shape (rows, columns, 3) with a single 1 per
        pixel. Pixels whose colour is not in the palette end up in class 0,
        because the label image starts out filled with zeros.
    """
    # (RGB colour, class id) pairs
    palette = (
        ((115, 178, 115), 0),  # non_mining_land (green)
        ((255, 0, 0), 1),      # illegal_mining_land (red)
        ((0, 0, 0), 2),        # beach (black)
    )

    rows, cols = img.shape[0], img.shape[1]
    labels = np.zeros((rows, cols), dtype=np.uint8)
    for rgb, class_id in palette:
        # A pixel matches only when all of its channels equal the colour.
        labels[np.all(img == rgb, axis=2)] = class_id

    # Row ``k`` of the identity matrix is the one-hot vector of class ``k``.
    return np.eye(len(palette), dtype=np.uint8)[labels]
def readImages(data, typeData, width, height):
    """Load, preprocess and resize a list of image files.

    Args:
        data: Iterable of file paths (anything ``str()`` turns into a path).
        typeData: ``'s'`` for SAR rasters (read with rasterio, contrast
            stretched to 0-255), ``'m'`` for colour-coded masks (converted to
            one-hot labels) or ``'o'`` for optical images (RGB). Any other
            value reads nothing.
        width: Target width in pixels.
        height: Target height in pixels.

    Returns:
        ``np.array`` built from the list of processed images.

    Raises:
        ValueError: If OpenCV cannot decode a mask or optical image.
    """
    images = []
    for item in data:
        path = str(item)
        if typeData == 's':  # SAR image
            with rasterio.open(path) as src:
                sar_bands = [src.read(i) for i in range(1, src.count + 1)]
            sar_image = np.stack(sar_bands, axis=-1)

            # Contrast stretching between the 2nd and 98th percentile.
            p2, p98 = np.percentile(sar_image, (2, 98))
            value_range = p98 - p2
            if value_range > 0:
                sar_image = np.clip(sar_image, p2, p98)
                sar_image = ((sar_image - p2) / value_range * 255).astype(np.uint8)
            else:
                # Constant (or all-NaN) raster: there is nothing to stretch and
                # dividing by the zero range would only produce NaN values.
                sar_image = np.zeros(sar_image.shape, dtype=np.uint8)

            # Resize
            # NOTE(review): the expand_dims below looks designed for
            # single-band rasters; a multi-band input would gain an extra
            # trailing axis -- confirm the expected SAR band count.
            sar_image = cv2.resize(sar_image, (width, height), interpolation=cv2.INTER_AREA)
            images.append(np.expand_dims(sar_image, axis=-1))
        elif typeData == 'm':  # Mask image
            rgb = _read_rgb(path)
            # Nearest neighbour keeps the exact class colours intact.
            rgb = cv2.resize(rgb, (width, height), interpolation=cv2.INTER_NEAREST)
            images.append(convertColorToLabel(rgb))
        elif typeData == 'o':  # Optic image
            rgb = _read_rgb(path)
            rgb = cv2.resize(rgb, (width, height), interpolation=cv2.INTER_AREA)
            images.append(rgb)
    print(f"(INFO..) Read {len(images)} '{typeData}' image(s)")
    return np.array(images)


def _read_rgb(path):
    """Read an image with OpenCV and return it in RGB channel order."""
    bgr = cv2.imread(path, cv2.IMREAD_COLOR)
    if bgr is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise ValueError(f"Could not read image file: {path}")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
def normalizeImages(images, typeData):
    """Cast every image to ``uint8`` and optionally rescale it by 1/255.

    Args:
        images: Iterable of arrays.
        typeData: Images are divided by 255 only when this is ``'s'`` or
            ``'o'``; for any other value they are returned as ``uint8``.

    Returns:
        ``np.array`` built from the processed images.
    """
    rescale = typeData in ('s', 'o')
    processed = []
    for image in images:
        as_bytes = image.astype(np.uint8)
        processed.append(as_bytes / 255. if rescale else as_bytes)
    print("(INFO..) Normalization Image Done")
    return np.array(processed)
def save_uploaded_file(uploaded_file, suffix=".tif"):
    """Write the content of an uploaded file to a persistent temporary file.

    Args:
        uploaded_file: Object whose ``read()`` returns the bytes to store.
        suffix: File-name suffix given to the temporary file.

    Returns:
        Path of the temporary file. It is created with ``delete=False``, so
        removing it is up to the caller.
    """
    tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    with tmp_file:
        payload = uploaded_file.read()
        tmp_file.write(payload)
    return tmp_file.name
def get_transform_from_tif(tif_path):
    """Read the georeferencing information of a raster file.

    Args:
        tif_path (str): Path of the file to open with rasterio.

    Returns:
        tuple: ``(transform, crs)`` -- the dataset's ``transform`` and ``crs``
        attributes as reported by rasterio.
    """
    with rasterio.open(tif_path) as dataset:
        return dataset.transform, dataset.crs
def mask_to_polygons(mask, transform, mining_class_id=1, crs='EPSG:4326'):
    """Vectorise the pixels of one class of a label mask into polygons.

    Args:
        mask: 2-D array of class ids.
        transform: Affine transform handed to ``rasterio.features.shapes`` to
            map pixel coordinates to map coordinates.
        mining_class_id: Class id whose pixels are turned into polygons.
        crs: Coordinate reference system assigned to the result. It should be
            the CRS the ``transform`` coordinates are expressed in; the
            default keeps the previous hard-coded ``EPSG:4326``.

    Returns:
        GeoDataFrame with one row per polygon (``class`` = ``'mining_land'``),
        or an empty GeoDataFrame with only a ``geometry`` column when the
        mask contains no pixel of the requested class.
    """
    # Binary mask: 1 where the pixel belongs to the requested class.
    mining_mask = (mask == mining_class_id).astype(np.uint8)

    # shapes() yields (geometry, value) for every connected region, background
    # included, so keep only the regions whose value is 1.
    polygon_features = [
        {'properties': {'class': 'mining_land'}, 'geometry': geometry}
        for geometry, value in shapes(mining_mask, mask=None, transform=transform)
        if value == 1
    ]

    if not polygon_features:
        return gpd.GeoDataFrame(columns=['geometry'], geometry='geometry', crs=crs)
    return gpd.GeoDataFrame.from_features(polygon_features, crs=crs)
# Function to calculate illegal mining area (mining outside WIUP)
def calculate_illegal_area(mining_gdf, wiup_gdf):
    """Measure the mining area that lies outside the WIUP boundary.

    Args:
        mining_gdf: GeoDataFrame of detected mining polygons.
        wiup_gdf: GeoDataFrame of the WIUP boundary; its CRS is the one the
            overlay is computed in.

    Returns:
        tuple: ``(total_illegal_area, illegal_area_gdf)`` where the total is
        the summed ``area_m2`` column and the GeoDataFrame holds the parts of
        the mining polygons outside the boundary, reprojected to the UTM zone
        estimated from their extent. When there is nothing to measure the
        total is ``0.0`` and the GeoDataFrame is empty.
    """
    # No detections: skip overlay and reprojection, which have nothing to
    # work on and cannot estimate a UTM zone from an empty extent.
    if mining_gdf.empty:
        illegal_area_gdf = mining_gdf.copy()
        illegal_area_gdf['area_m2'] = illegal_area_gdf.geometry.area
        return 0.0, illegal_area_gdf

    # Ensure both are in the same CRS
    mining_gdf = mining_gdf.to_crs(wiup_gdf.crs)

    # Perform overlay operation: difference (mining area outside WIUP)
    illegal_area_gdf = gpd.overlay(mining_gdf, wiup_gdf, how='difference')
    if illegal_area_gdf.empty:
        illegal_area_gdf['area_m2'] = illegal_area_gdf.geometry.area
        return 0.0, illegal_area_gdf

    # Areas are only meaningful in a projected CRS. The previous value
    # ' EPSG:4326' was a geographic CRS (and had a stray leading space), so
    # the "square metres" were actually square degrees. Pick the UTM zone
    # matching the data instead.
    metric_crs = illegal_area_gdf.estimate_utm_crs()
    illegal_area_gdf = illegal_area_gdf.to_crs(metric_crs)

    # Calculate the area in square meters
    illegal_area_gdf['area_m2'] = illegal_area_gdf.geometry.area
    total_illegal_area = illegal_area_gdf['area_m2'].sum()
    return total_illegal_area, illegal_area_gdf
# Streamlit App Title
st.title("Satellite Mining Segmentation: SAR + Optic Image Inference")

# Input widgets. Both common TIFF extensions are accepted: the uploads are
# stored with a ".tif" suffix anyway, and many GeoTIFF files use ".tif".
TIFF_EXTENSIONS = ["tif", "tiff"]
sar_file = st.file_uploader("Upload SAR Image", type=TIFF_EXTENSIONS)
optic_file = st.file_uploader("Upload Optical Image", type=TIFF_EXTENSIONS)
mask_file = st.file_uploader("Upload Mask Image", type=TIFF_EXTENSIONS)
wiup_file = st.file_uploader("Upload WIUP Boundary (Shapefile ZIP)", type=["zip"])

# Upper bound on the number of samples drawn in the result figures.
num_samples = 1
def _remove_temp_artifacts(file_paths, folder):
    """Best-effort removal of the temporary files and folder of one run."""
    for path in file_paths:
        try:
            os.remove(path)
        except OSError:
            # Already gone or not removable; cleanup must never mask the
            # outcome of the inference itself.
            pass
    if folder is not None:
        shutil.rmtree(folder, ignore_errors=True)


def _colorize_mask(mask):
    """Blend the class channels of ``mask`` into an RGB image via CLASS_COLORS."""
    rgb = np.zeros((*mask.shape[:2], 3))
    for class_idx, color in enumerate(CLASS_COLORS):
        rgb += mask[:, :, class_idx][:, :, np.newaxis] * np.array(color)
    return rgb


# Run the full pipeline on demand: store the uploads, predict the mask, plot
# the results and measure the mining area outside the WIUP boundary.
if st.button("Run Inference"):
    with st.spinner("Loading data and model..."):
        uploads = (sar_file, optic_file, mask_file, wiup_file)
        if any(upload is None for upload in uploads):
            # Four uploads are required (three rasters and the boundary ZIP).
            st.warning(
                "Please upload the SAR, optical and mask .tiff files and the "
                "WIUP shapefile ZIP to proceed."
            )
        else:
            st.success("All files uploaded successfully!")
            temp_paths = []
            extract_folder = None
            try:
                # Save uploaded files
                sar_path = save_uploaded_file(sar_file, suffix=".tif")
                temp_paths.append(sar_path)
                optic_path = save_uploaded_file(optic_file, suffix=".tif")
                temp_paths.append(optic_path)
                mask_path = save_uploaded_file(mask_file, suffix=".tif")
                temp_paths.append(mask_path)
                wiup_zip_path = save_uploaded_file(wiup_file, ".zip")
                temp_paths.append(wiup_zip_path)

                # Extract next to the archive, using its name without ".zip".
                extract_folder = os.path.splitext(wiup_zip_path)[0]
                with zipfile.ZipFile(wiup_zip_path, "r") as zip_ref:
                    zip_ref.extractall(extract_folder)

                # Load WIUP shapefile
                wiup_gdf = gpd.read_file(extract_folder)

                # Model path
                model_path = "Residual_UNET_Bilinear.keras"

                # Read and normalize images
                sar_images = readImages([sar_path], typeData='s', width=WIDTH, height=HEIGHT)
                optic_images = readImages([optic_path], typeData='o', width=WIDTH, height=HEIGHT)
                masks = readImages([mask_path], typeData='m', width=WIDTH, height=HEIGHT)
                sar_images = normalizeImages(sar_images, 's')
                # NOTE(review): 'i' is not a type normalizeImages rescales, so
                # the optical input stays uint8 in 0-255. Left unchanged since
                # the model may have been trained this way -- confirm whether
                # 'o' was intended.
                optic_images = normalizeImages(optic_images, 'i')

                # Load model
                model = tf.keras.models.load_model(
                    model_path,
                    custom_objects={"cce_dice_loss": cce_dice_loss, "dice_score": dice_score},
                )

                # Predict masks
                pred_masks = model.predict([optic_images, sar_images], verbose=0)
                is_multiclass = pred_masks.shape[-1] > 1
                num_samples = min(num_samples, len(sar_images))

                # Plotting results; squeeze=False keeps ``axes`` two-dimensional
                # even for a single row, so no special case is needed.
                fig, axes = plt.subplots(
                    num_samples, 4, figsize=(21, 6 * num_samples), squeeze=False
                )
                for i in range(num_samples):
                    ax = axes[i]

                    # Plot SAR image
                    ax[0].imshow(sar_images[i].squeeze(), cmap='gray')
                    ax[0].set_title(f"SAR Image {i+1}")
                    ax[0].axis('off')

                    # Plot Optic image
                    ax[1].imshow(optic_images[i])
                    ax[1].set_title(f"Optic Image {i+1}")
                    ax[1].axis('off')

                    # Plot Ground Truth Mask
                    if is_multiclass:
                        ax[2].imshow(_colorize_mask(masks[i]))
                    else:
                        ax[2].imshow(masks[i], cmap='gray')
                    ax[2].set_title(f"Ground Truth Mask {i+1}")
                    ax[2].axis('off')

                    # Plot Predicted Mask
                    if is_multiclass:
                        ax[3].imshow(_colorize_mask(pred_masks[i]))
                    else:
                        ax[3].imshow(pred_masks[i], cmap='gray')
                    ax[3].set_title(f"Predicted Mask {i+1}")
                    ax[3].axis('off')
                st.pyplot(fig)
                # Release the figure; pyplot otherwise keeps it alive across reruns.
                plt.close(fig)

                # Define color for class 1: illegal mining
                red_color = [255, 0, 0]

                # Convert optic_images to uint8 if needed
                if optic_images.dtype != np.uint8:
                    optic_images = (optic_images * 255).astype(np.uint8)

                # Plot overlays
                fig, axes = plt.subplots(
                    num_samples, 4, figsize=(21, 6 * num_samples), squeeze=False
                )
                for i in range(num_samples):
                    ax = axes[i]

                    # SAR image
                    ax[0].imshow(sar_images[i].squeeze(), cmap='gray')
                    ax[0].set_title(f"SAR Image {i+1}")
                    ax[0].axis('off')

                    # Optic image
                    ax[1].imshow(optic_images[i])
                    ax[1].set_title(f"Optic Image {i+1}")
                    ax[1].axis('off')

                    # Ground truth overlay
                    gt_overlay = optic_images[i].copy()
                    if is_multiclass:
                        gt_overlay[masks[i][:, :, 1] == 1] = red_color
                    else:
                        gt_overlay[masks[i].squeeze() == 1] = red_color
                    ax[2].imshow(optic_images[i])
                    ax[2].imshow(gt_overlay, alpha=0.4)
                    ax[2].set_title(f"Ground Truth Overlay {i+1}")
                    ax[2].axis('off')

                    # Predicted mask overlay
                    pred_overlay = optic_images[i].copy()
                    if is_multiclass:
                        pred_overlay[pred_masks[i][:, :, 1] > 0.5] = red_color
                    else:
                        pred_overlay[pred_masks[i].squeeze() > 0.5] = red_color
                    ax[3].imshow(optic_images[i])
                    ax[3].imshow(pred_overlay, alpha=0.4)
                    ax[3].set_title(f"Predicted Overlay {i+1}")
                    ax[3].axis('off')
                plt.tight_layout()
                st.pyplot(fig)
                plt.close(fig)

                # Class id per pixel from the model output, shape: (H, W)
                pred_label_mask = np.argmax(pred_masks[0], axis=-1)
                st.success(f" Unique classes in mask: {np.unique(pred_label_mask)}")

                # Get georeference transform from TIF image
                transform, crs = get_transform_from_tif(optic_path)

                # The mask has the resized resolution, not that of the source
                # raster, so the pixel-to-map transform must be scaled by the
                # resize factor. Otherwise the polygons would cover only the
                # top-left corner of the scene.
                with rasterio.open(optic_path) as src:
                    src_height, src_width = src.height, src.width
                mask_rows, mask_cols = pred_label_mask.shape
                transform = transform * rasterio.Affine.scale(
                    src_width / mask_cols, src_height / mask_rows
                )

                # Convert mask to mining polygons
                mining_gdf = mask_to_polygons(pred_label_mask, transform, mining_class_id=1)
                if crs is not None:
                    # Polygon coordinates come from the raster transform, so
                    # they are expressed in the raster's CRS; label them so.
                    mining_gdf = mining_gdf.set_crs(crs, allow_override=True)
                st.success(f"Mining polygons: {len(mining_gdf)}")

                # Make sure WIUP and prediction are in same CRS
                wiup_gdf = wiup_gdf.to_crs(mining_gdf.crs)
                st.success(f"WIUP CRS: {wiup_gdf.crs}")

                # Find area
                illegal_mining_area, illegal_mining_area_gdf = calculate_illegal_area(
                    mining_gdf, wiup_gdf
                )

                # Display in Streamlit
                st.success(f" Illegal Mining Area : {illegal_mining_area/1e6:.2f} sq. km")
            finally:
                # The uploads were stored with delete=False; remove them and
                # the extracted shapefile so repeated runs do not fill the
                # temporary directory.
                _remove_temp_artifacts(temp_paths, extract_folder)