import io
import os
import uuid
from glob import glob

import boto3
import cv2
import gradio as gr
import numpy as np
import segmentation_models_pytorch as smp
import torch
from PIL import Image

from pipeline.ImgOutlier import detect_outliers
from pipeline.normalization import align_images
# Detect whether we are running inside a Hugging Face Space
# (the Spaces runtime sets the SPACE_ID environment variable)
HF_SPACE = os.environ.get('SPACE_ID') is not None
# DigitalOcean Spaces upload function
def upload_mask(image, prefix="mask"):
    """
    Upload a segmentation mask image to DigitalOcean Spaces.

    Args:
        image: PIL Image object
        prefix: filename prefix

    Returns:
        Public URL of the uploaded file, or an error message string.
    """
    try:
        # Get credentials from environment variables
        do_key = os.environ.get('DO_SPACES_KEY')
        do_secret = os.environ.get('DO_SPACES_SECRET')
        do_region = os.environ.get('DO_SPACES_REGION')
        do_bucket = os.environ.get('DO_SPACES_BUCKET')

        # Check that all credentials exist
        if not all([do_key, do_secret, do_region, do_bucket]):
            return "DigitalOcean credentials not set"

        # Create an S3-compatible client pointed at the Spaces endpoint
        session = boto3.session.Session()
        client = session.client(
            's3',
            region_name=do_region,
            endpoint_url=f'https://{do_region}.digitaloceanspaces.com',
            aws_access_key_id=do_key,
            aws_secret_access_key=do_secret,
        )

        # Generate a unique filename
        filename = f"{prefix}_{uuid.uuid4().hex}.png"

        # Serialize the image to an in-memory PNG
        img_byte_arr = io.BytesIO()
        image.save(img_byte_arr, format='PNG')
        img_byte_arr.seek(0)

        # Upload to Spaces with public-read access
        client.upload_fileobj(
            img_byte_arr,
            do_bucket,
            filename,
            ExtraArgs={'ACL': 'public-read', 'ContentType': 'image/png'}
        )

        # Return the public URL
        return f'https://{do_bucket}.{do_region}.digitaloceanspaces.com/{filename}'
    except Exception as e:
        print(f"Upload failed: {e}")
        return f"Upload error: {e}"
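# The uploader expects the following environment variables (the names come
# from the code above; the values below are placeholders, not real settings):
#
#   export DO_SPACES_KEY=<access-key>
#   export DO_SPACES_SECRET=<secret-key>
#   export DO_SPACES_REGION=<region, e.g. nyc3>
#   export DO_SPACES_BUCKET=<bucket-name>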
# Global configuration
MODEL_PATHS = {
    "Metal Marcy": "models/MM_best_model.pth",
    "Silhouette Jaenette": "models/SJ_best_model.pth"
}

REFERENCE_VECTOR_PATHS = {
    "Metal Marcy": "models/MM_mean.npy",
    "Silhouette Jaenette": "models/SJ_mean.npy"
}

REFERENCE_IMAGE_DIRS = {
    "Metal Marcy": "reference_images/MM",
    "Silhouette Jaenette": "reference_images/SJ"
}

# Class names and color mapping; COLORS is index-aligned with CLASSES
CLASSES = ['background', 'cobbles', 'drysand', 'plant', 'sky', 'water', 'wetsand']
COLORS = [
    [0, 0, 0],        # background - black
    [139, 137, 137],  # cobbles - dark gray
    [255, 228, 181],  # drysand - light yellow
    [0, 128, 0],      # plant - green
    [135, 206, 235],  # sky - sky blue
    [0, 0, 255],      # water - blue
    [194, 178, 128]   # wetsand - sand brown
]
# Load a segmentation model from a checkpoint
def load_model(model_path, device="cuda"):
    try:
        # Inside a HF Space this app runs on CPU; also fall back to CPU
        # whenever CUDA is unavailable
        if HF_SPACE or not torch.cuda.is_available():
            device = "cpu"
        model = smp.create_model(
            "DeepLabV3Plus",
            encoder_name="efficientnet-b6",
            in_channels=3,
            classes=len(CLASSES),
            encoder_weights=None
        )
        state_dict = torch.load(model_path, map_location=device)
        # Strip a leading 'model.' prefix if present
        # (e.g. a checkpoint saved from a training wrapper module)
        if all(k.startswith('model.') for k in state_dict.keys()):
            state_dict = {k[len('model.'):]: v for k, v in state_dict.items()}
        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()
        print(f"Model loaded successfully: {model_path}")
        return model
    except Exception as e:
        print(f"Model loading failed: {e}")
        return None
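# Illustrative usage (assumes the checkpoint exists on disk):
#   model = load_model(MODEL_PATHS["Metal Marcy"])  # device falls back to CPU if needed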
# Load the reference vector for outlier detection
def load_reference_vector(vector_path):
    try:
        if not os.path.exists(vector_path):
            print(f"Reference vector file not found: {vector_path}")
            return []
        ref_vector = np.load(vector_path)
        print(f"Reference vector loaded successfully: {vector_path}")
        return ref_vector
    except Exception as e:
        print(f"Reference vector loading failed {vector_path}: {e}")
        return []
# Load reference images
def load_reference_images(ref_dir):
    try:
        if not os.path.exists(ref_dir):
            print(f"Reference image directory not found: {ref_dir}")
            os.makedirs(ref_dir, exist_ok=True)
            return []
        image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp']
        image_files = []
        for ext in image_extensions:
            image_files.extend(glob(os.path.join(ref_dir, ext)))
        image_files.sort()
        # Use at most the first four reference images
        reference_images = []
        for file in image_files[:4]:
            img = cv2.imread(file)  # BGR
            if img is not None:
                reference_images.append(img)
        print(f"Loaded {len(reference_images)} images from {ref_dir}")
        return reference_images
    except Exception as e:
        print(f"Image loading failed {ref_dir}: {e}")
        return []
# Preprocess an RGB image into a normalized tensor for the model
def preprocess_image(image):
    # Drop the alpha channel if present
    if image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    orig_h, orig_w = image.shape[:2]
    image_resized = cv2.resize(image, (1024, 1024))
    # Scale to [0, 1], then normalize with the ImageNet mean/std
    image_norm = image_resized.astype(np.float32) / 255.0
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image_norm = (image_norm - mean) / std
    # HWC -> CHW, then add a batch dimension
    image_tensor = torch.from_numpy(image_norm.transpose(2, 0, 1)).float().unsqueeze(0)
    return image_tensor, orig_h, orig_w
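# Illustrative shape check (hypothetical input):
#   t, h, w = preprocess_image(np.zeros((720, 1280, 3), dtype=np.uint8))
#   t.shape is torch.Size([1, 3, 1024, 1024]); (h, w) is (720, 1280)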
# Generate the segmentation map and visualization
def generate_segmentation_map(prediction, orig_h, orig_w):
    # Per-pixel class indices, resized back to the original resolution
    mask = prediction.argmax(1).squeeze().cpu().numpy().astype(np.uint8)
    mask_resized = cv2.resize(mask, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)
    # Dilate each foreground class into adjacent background pixels to
    # close small gaps in the prediction
    kernel = np.ones((5, 5), np.uint8)
    processed_mask = mask_resized.copy()
    for idx in range(1, len(CLASSES)):
        class_mask = (mask_resized == idx).astype(np.uint8)
        dilated_mask = cv2.dilate(class_mask, kernel, iterations=2)
        dilated_effect = dilated_mask & (mask_resized == 0)
        processed_mask[dilated_effect > 0] = idx
    # Paint each class with its RGB color
    segmentation_map = np.zeros((orig_h, orig_w, 3), dtype=np.uint8)
    for idx, color in enumerate(COLORS):
        segmentation_map[processed_mask == idx] = color
    return segmentation_map
# Build the analysis-result HTML (per-class coverage percentages)
def create_analysis_result(mask):
    total_pixels = mask.size
    percentages = {cls: round((np.sum(mask == i) / total_pixels) * 100, 1)
                   for i, cls in enumerate(CLASSES)}
    # Report every class except background, in a fixed display order
    ordered = ['sky', 'cobbles', 'plant', 'drysand', 'wetsand', 'water']
    result = "<div style='font-size:18px;font-weight:bold;'>"
    result += " | ".join(f"{cls}: {percentages.get(cls, 0)}%" for cls in ordered)
    result += "</div>"
    return result
# Alpha-blend the segmentation map over the input image
def create_overlay(image, segmentation_map, alpha=0.5):
    if image.shape[:2] != segmentation_map.shape[:2]:
        segmentation_map = cv2.resize(segmentation_map, (image.shape[1], image.shape[0]),
                                      interpolation=cv2.INTER_NEAREST)
    return cv2.addWeighted(image, 1 - alpha, segmentation_map, alpha, 0)
# Run segmentation on a BGR image; return the map, overlay, and analysis HTML
def perform_segmentation(model, image_bgr):
    device = "cuda" if torch.cuda.is_available() and not HF_SPACE else "cpu"
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    image_tensor, orig_h, orig_w = preprocess_image(image_rgb)
    with torch.no_grad():
        prediction = model(image_tensor.to(device))
    seg_map = generate_segmentation_map(prediction, orig_h, orig_w)  # RGB
    overlay = create_overlay(image_rgb, seg_map)
    # Class percentages are computed on the model-resolution (1024x1024) mask
    mask = prediction.argmax(1).squeeze().cpu().numpy()
    analysis = create_analysis_result(mask)
    return seg_map, overlay, analysis
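# Illustrative usage ("beach.jpg" is a hypothetical path; cv2.imread yields BGR):
#   model = load_model(MODEL_PATHS["Metal Marcy"])
#   seg_map, overlay, analysis = perform_segmentation(model, cv2.imread("beach.jpg"))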
# Single-image processing
def process_coastal_image(location, input_image):
    if input_image is None:
        return None, None, "Please upload an image", "Not detected", None
    device = "cuda" if torch.cuda.is_available() and not HF_SPACE else "cpu"
    model = load_model(MODEL_PATHS[location], device)
    if model is None:
        return None, None, "Error: Failed to load model", "Not detected", None
    ref_vector = load_reference_vector(REFERENCE_VECTOR_PATHS[location])
    ref_images = load_reference_images(REFERENCE_IMAGE_DIRS[location])

    # Outlier detection: prefer the precomputed reference vector,
    # otherwise fall back to the raw reference images
    is_outlier = False
    image_bgr = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
    if len(ref_vector) > 0:
        filtered, _ = detect_outliers(ref_images, [image_bgr], ref_vector)
        is_outlier = len(filtered) == 0
    elif len(ref_images) > 0:
        filtered, _ = detect_outliers(ref_images, [image_bgr])
        is_outlier = len(filtered) == 0
    else:
        print("Warning: no reference images or reference vector available for outlier detection")
    outlier_status = ("Outlier Detection: <span style='color:red;font-weight:bold'>Failed</span>"
                      if is_outlier else
                      "Outlier Detection: <span style='color:green;font-weight:bold'>Passed</span>")

    seg_map, overlay, analysis = perform_segmentation(model, image_bgr)

    # Try uploading to DigitalOcean Spaces
    url = "Local Storage"
    try:
        url = upload_mask(Image.fromarray(seg_map), prefix=location.replace(' ', '_'))
    except Exception as e:
        print(f"Upload failed: {e}")
        url = f"Upload error: {e}"

    if is_outlier:
        analysis = ("<div style='color:red;font-weight:bold;margin-bottom:10px'>"
                    "Warning: the image failed outlier detection; the result may be inaccurate!"
                    "</div>") + analysis
    return seg_map, overlay, analysis, outlier_status, url
# Spatial alignment followed by segmentation
def process_with_alignment(location, reference_image, input_image):
    if reference_image is None or input_image is None:
        return None, None, None, None, "Please upload both reference and target images", "Not processed", None
    device = "cuda" if torch.cuda.is_available() and not HF_SPACE else "cpu"
    model = load_model(MODEL_PATHS[location], device)
    if model is None:
        return None, None, None, None, "Error: Failed to load model", "Not processed", None
    ref_bgr = cv2.cvtColor(np.array(reference_image), cv2.COLOR_RGB2BGR)
    tgt_bgr = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
    try:
        # Pass placeholder zero masks alongside the images, since no
        # segmentation masks exist yet at this point
        aligned, _ = align_images([ref_bgr, tgt_bgr],
                                  [np.zeros_like(ref_bgr), np.zeros_like(tgt_bgr)])
        aligned_tgt_bgr = aligned[1]
    except Exception as e:
        print(f"Spatial alignment failed: {e}")
        return None, None, None, None, f"Spatial alignment failed: {e}", "Processing failed", None

    seg_map, overlay, analysis = perform_segmentation(model, aligned_tgt_bgr)

    # Try uploading to DigitalOcean Spaces
    url = "Local Storage"
    try:
        url = upload_mask(Image.fromarray(seg_map), prefix="aligned_" + location.replace(' ', '_'))
    except Exception as e:
        print(f"Upload failed: {e}")
        url = f"Upload error: {e}"

    status = "Spatial Alignment: <span style='color:green;font-weight:bold'>Completed</span>"
    ref_rgb = cv2.cvtColor(ref_bgr, cv2.COLOR_BGR2RGB)
    aligned_tgt_rgb = cv2.cvtColor(aligned_tgt_bgr, cv2.COLOR_BGR2RGB)
    return ref_rgb, aligned_tgt_rgb, seg_map, overlay, analysis, status, url
# Create the Gradio interface
def create_interface():
    # Unified display size (roughly 4:3)
    disp_w, disp_h = 683, 512

    with gr.Blocks(title="Coastal Erosion Analysis System") as demo:
        gr.Markdown("""# Coastal Erosion Analysis System
Upload coastal images for analysis, including segmentation and spatial alignment.""")

        with gr.Tabs():
            with gr.TabItem("Single Image Segmentation"):
                with gr.Row():
                    loc1 = gr.Radio(list(MODEL_PATHS.keys()), label="Select Model",
                                    value=list(MODEL_PATHS.keys())[0])
                with gr.Row():
                    inp = gr.Image(label="Input Image", type="numpy", image_mode="RGB",
                                   height=disp_h, width=disp_w)
                    seg = gr.Image(label="Segmentation Map", type="numpy", height=disp_h, width=disp_w)
                    ovl = gr.Image(label="Overlay Image", type="numpy", height=disp_h, width=disp_w)
                with gr.Row():
                    btn1 = gr.Button("Run Segmentation")
                url1 = gr.Text(label="Segmentation Image URL")
                status1 = gr.HTML(label="Outlier Detection Status")
                res1 = gr.HTML(label="Analysis Result")
                btn1.click(fn=process_coastal_image, inputs=[loc1, inp],
                           outputs=[seg, ovl, res1, status1, url1])

            with gr.TabItem("Spatial Alignment Segmentation"):
                with gr.Row():
                    loc2 = gr.Radio(list(MODEL_PATHS.keys()), label="Select Model",
                                    value=list(MODEL_PATHS.keys())[0])
                with gr.Row():
                    ref_img = gr.Image(label="Reference Image", type="numpy", image_mode="RGB",
                                       height=disp_h, width=disp_w)
                    tgt_img = gr.Image(label="Target Image", type="numpy", image_mode="RGB",
                                       height=disp_h, width=disp_w)
                with gr.Row():
                    btn2 = gr.Button("Run Spatial Alignment and Segmentation")
                with gr.Row():
                    orig = gr.Image(label="Original Image", type="numpy", height=disp_h, width=disp_w)
                    aligned = gr.Image(label="Aligned Image", type="numpy", height=disp_h, width=disp_w)
                with gr.Row():
                    seg2 = gr.Image(label="Segmentation Map", type="numpy", height=disp_h, width=disp_w)
                    ovl2 = gr.Image(label="Overlay Image", type="numpy", height=disp_h, width=disp_w)
                url2 = gr.Text(label="Segmentation Image URL")
                status2 = gr.HTML(label="Alignment Status")
                res2 = gr.HTML(label="Analysis Result")
                btn2.click(fn=process_with_alignment, inputs=[loc2, ref_img, tgt_img],
                           outputs=[orig, aligned, seg2, ovl2, res2, status2, url2])
    return demo
if __name__ == "__main__":
    # Create the necessary directories
    for path in ["models", "reference_images/MM", "reference_images/SJ"]:
        os.makedirs(path, exist_ok=True)

    # Check that the model files exist
    for p in MODEL_PATHS.values():
        if not os.path.exists(p):
            print(f"Warning: model file {p} does not exist!")

    # Check that the DigitalOcean credentials are set
    do_creds = [
        os.environ.get('DO_SPACES_KEY'),
        os.environ.get('DO_SPACES_SECRET'),
        os.environ.get('DO_SPACES_REGION'),
        os.environ.get('DO_SPACES_BUCKET')
    ]
    if not all(do_creds):
        print("Warning: incomplete DigitalOcean Spaces credentials; upload functionality may not work")

    # Create and launch the interface
    demo = create_interface()
    if HF_SPACE:
        demo.launch()
    else:
        demo.launch(share=True)
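# To run locally (a minimal sketch; assumes the model checkpoints and the
# pipeline/ package are present, and that dependencies such as torch, gradio,
# opencv-python, boto3, pillow, and segmentation-models-pytorch are installed):
#
#   python app.py
#
# "app.py" is an assumed filename here, following the usual convention for
# Gradio apps on Hugging Face Spaces.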