import io
import os
import uuid
from glob import glob

import boto3
import cv2
import gradio as gr
import numpy as np
import segmentation_models_pytorch as smp
import torch
from PIL import Image

from pipeline.ImgOutlier import detect_outliers
from pipeline.normalization import align_images

# Detect if running inside Hugging Face Spaces
HF_SPACE = os.environ.get('SPACE_ID') is not None


# DigitalOcean Spaces upload function
def upload_mask(image, prefix="mask"):
    """
    Upload a segmentation mask image to DigitalOcean Spaces.

    Args:
        image: PIL Image object; it is encoded as PNG before upload.
        prefix: filename prefix; "_<random uuid hex>.png" is appended.

    Returns:
        The public URL of the uploaded file on success. This function never
        raises: on missing credentials it returns
        "DigitalOcean credentials not set", and on any other failure it
        returns "Upload error: <message>".
    """
    try:
        # Get credentials from environment variables
        do_key = os.environ.get('DO_SPACES_KEY')
        do_secret = os.environ.get('DO_SPACES_SECRET')
        do_region = os.environ.get('DO_SPACES_REGION')
        do_bucket = os.environ.get('DO_SPACES_BUCKET')

        # Check if credentials exist
        if not all([do_key, do_secret, do_region, do_bucket]):
            return "DigitalOcean credentials not set"

        # Spaces is S3-compatible, so a boto3 S3 client pointed at the
        # regional Spaces endpoint is used.
        session = boto3.session.Session()
        client = session.client(
            's3',
            region_name=do_region,
            endpoint_url=f'https://{do_region}.digitaloceanspaces.com',
            aws_access_key_id=do_key,
            aws_secret_access_key=do_secret,
        )

        # Generate unique filename
        filename = f"{prefix}_{uuid.uuid4().hex}.png"

        # Encode the image into an in-memory PNG buffer
        img_byte_arr = io.BytesIO()
        image.save(img_byte_arr, format='PNG')
        img_byte_arr.seek(0)

        # Upload with a public-read ACL so the returned URL is viewable
        client.upload_fileobj(
            img_byte_arr,
            do_bucket,
            filename,
            ExtraArgs={'ACL': 'public-read', 'ContentType': 'image/png'},
        )

        # Return public URL
        url = f'https://{do_bucket}.{do_region}.digitaloceanspaces.com/{filename}'
        return url
    except Exception as e:
        print(f"Upload failed: {str(e)}")
        return f"Upload error: {str(e)}"


# Global Configuration
MODEL_PATHS = {
    "Metal Marcy": "models/MM_best_model.pth",
    "Silhouette Jaenette": "models/SJ_best_model.pth",
}

REFERENCE_VECTOR_PATHS = {
    "Metal Marcy": "models/MM_mean.npy",
    "Silhouette Jaenette": "models/SJ_mean.npy",
}

REFERENCE_IMAGE_DIRS = {
    "Metal Marcy": "reference_images/MM",
    "Silhouette Jaenette": "reference_images/SJ",
}

# Category names and color mapping (index i in CLASSES uses COLORS[i], RGB)
CLASSES = ['background', 'cobbles', 'drysand', 'plant', 'sky', 'water', 'wetsand']
COLORS = [
    [0, 0, 0],          # background - black
    [139, 137, 137],    # cobbles - dark gray
    [255, 228, 181],    # drysand - light yellow
    [0, 128, 0],        # plant - green
    [135, 206, 235],    # sky - sky blue
    [0, 0, 255],        # water - blue
    [194, 178, 128],    # wetsand - sand brown
]

# File extensions (compared case-insensitively) accepted as reference images
_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp')

# Models already loaded, keyed by (model_path, device). Building the
# EfficientNet-B6 network and reading its checkpoint is expensive, so it is
# done once per model instead of on every button click.
_MODEL_CACHE = {}


def _default_device():
    """Return "cuda" when a GPU is usable and we are not in HF Spaces, else "cpu"."""
    return "cuda" if torch.cuda.is_available() and not HF_SPACE else "cpu"


# Load model function
def load_model(model_path, device="cuda"):
    """
    Build the DeepLabV3+ (EfficientNet-B6 encoder) network and load weights.

    The requested *device* is overridden with "cpu" when running inside
    Hugging Face Spaces or when CUDA is unavailable.

    Args:
        model_path: path to a saved state dict.
        device: preferred torch device string.

    Returns:
        The model in eval mode on the chosen device, or None if anything
        failed (the error is printed, not raised).
    """
    try:
        # If running inside HF Spaces, or without CUDA, fall back to CPU
        if HF_SPACE or not torch.cuda.is_available():
            device = "cpu"

        model = smp.create_model(
            "DeepLabV3Plus",
            encoder_name="efficientnet-b6",
            in_channels=3,
            classes=len(CLASSES),
            encoder_weights=None,
        )

        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(model_path, map_location=device)
        # When every key carries a "model." prefix (presumably weights saved
        # from a wrapper module), strip it so keys match the bare network.
        prefix = 'model.'
        if all(k.startswith(prefix) for k in state_dict.keys()):
            state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}

        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()
        print(f"Model loaded successfully: {model_path}")
        return model
    except Exception as e:
        print(f"Model loading failed: {e}")
        return None


def _get_model(location, device):
    """Return the cached model for *location*, loading it on first use (None on failure)."""
    model_path = MODEL_PATHS[location]
    key = (model_path, device)
    model = _MODEL_CACHE.get(key)
    if model is None:
        model = load_model(model_path, device)
        # Failures are not cached so a later request can retry the load.
        if model is not None:
            _MODEL_CACHE[key] = model
    return model


# Load reference vector
def load_reference_vector(vector_path):
    """
    Load a reference vector saved with numpy from *vector_path*.

    Returns:
        The loaded array, or an empty list if the file is missing or cannot
        be read (the problem is printed, not raised).
    """
    try:
        if not os.path.exists(vector_path):
            print(f"Reference vector file not found: {vector_path}")
            return []
        ref_vector = np.load(vector_path)
        print(f"Reference vector loaded successfully: {vector_path}")
        return ref_vector
    except Exception as e:
        print(f"Reference vector loading failed {vector_path}: {e}")
        return []


# Load reference images
def load_reference_images(ref_dir, max_images=4):
    """
    Load up to *max_images* reference images (BGR, via cv2.imread) from *ref_dir*.

    Files are taken in sorted path order and filtered by extension
    (.jpg/.jpeg/.png/.bmp, any letter case). Unreadable files are skipped.
    If *ref_dir* does not exist it is created and an empty list is returned.

    Returns:
        List of BGR numpy images; empty on any error (printed, not raised).
    """
    try:
        if not os.path.exists(ref_dir):
            print(f"Reference image directory not found: {ref_dir}")
            os.makedirs(ref_dir, exist_ok=True)
            return []

        image_files = sorted(
            path for path in glob(os.path.join(ref_dir, '*'))
            if os.path.splitext(path)[1].lower() in _IMAGE_EXTENSIONS
        )

        reference_images = []
        for file in image_files[:max_images]:
            img = cv2.imread(file)
            if img is not None:
                reference_images.append(img)
        print(f"Loaded {len(reference_images)} images from {ref_dir}")
        return reference_images
    except Exception as e:
        print(f"Image loading failed {ref_dir}: {e}")
        return []


# Preprocess the image
def preprocess_image(image):
    """
    Convert an RGB(A)/grayscale numpy image into a normalized model input tensor.

    The image is resized to 1024x1024, scaled to [0, 1] and normalized with
    the ImageNet mean/std.

    Returns:
        (tensor of shape (1, 3, 1024, 1024), original height, original width)
    """
    if image.ndim == 2:
        # Grayscale input: replicate to three channels
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    orig_h, orig_w = image.shape[:2]
    image_resized = cv2.resize(image, (1024, 1024))
    image_norm = image_resized.astype(np.float32) / 255.0
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image_norm = (image_norm - mean) / std
    image_tensor = torch.from_numpy(image_norm.transpose(2, 0, 1)).float().unsqueeze(0)
    return image_tensor, orig_h, orig_w


# Generate segmentation map and visualization
def generate_segmentation_map(prediction, orig_h, orig_w):
    """
    Turn raw model logits into a colorized RGB segmentation map.

    The per-pixel argmax is resized back to (orig_h, orig_w) with nearest
    neighbour. Each non-background class is then dilated (5x5 kernel, two
    iterations), and the grown pixels are written into pixels that were
    background in the resized prediction. Classes are processed in index
    order, so a later class wins where two dilations overlap.

    Returns:
        uint8 array of shape (orig_h, orig_w, 3) colored with COLORS.
    """
    mask = prediction.argmax(1).squeeze().cpu().numpy().astype(np.uint8)
    mask_resized = cv2.resize(mask, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)

    kernel = np.ones((5, 5), np.uint8)
    processed_mask = mask_resized.copy()
    is_background = mask_resized == 0
    for idx in range(1, len(CLASSES)):
        class_mask = (mask_resized == idx).astype(np.uint8)
        dilated_mask = cv2.dilate(class_mask, kernel, iterations=2)
        # Only background pixels may be claimed by a neighbouring class
        dilated_effect = dilated_mask & is_background
        processed_mask[dilated_effect > 0] = idx

    # Color lookup table: every mask value is a valid class index
    palette = np.asarray(COLORS, dtype=np.uint8)
    return palette[processed_mask]


# Analysis result HTML
def create_analysis_result(mask):
    """
    Summarize per-class pixel percentages of *mask* as a one-line HTML snippet.

    Percentages are rounded to one decimal. The display order is fixed, and
    background is computed but not shown.
    """
    total_pixels = mask.size
    percentages = {
        cls: round((np.sum(mask == i) / total_pixels) * 100, 1) if total_pixels else 0.0
        for i, cls in enumerate(CLASSES)
    }
    ordered = ['sky', 'cobbles', 'plant', 'drysand', 'wetsand', 'water']
    summary = " | ".join(f"{cls}: {percentages.get(cls, 0)}%" for cls in ordered)
    return f"<div>{summary}</div>"


# Merge and overlay
def create_overlay(image, segmentation_map, alpha=0.5):
    """
    Alpha-blend *segmentation_map* over *image*.

    *alpha* is the weight of the map. The map is resized (nearest neighbour)
    if its height/width differ from the image's. Both inputs must share a
    dtype and channel count.
    """
    if image.shape[:2] != segmentation_map.shape[:2]:
        segmentation_map = cv2.resize(
            segmentation_map,
            (image.shape[1], image.shape[0]),
            interpolation=cv2.INTER_NEAREST,
        )
    return cv2.addWeighted(image, 1 - alpha, segmentation_map, alpha, 0)


# Perform segmentation
def perform_segmentation(model, image_bgr):
    """
    Run *model* on a BGR image.

    Returns:
        (RGB segmentation map, RGB overlay, analysis HTML string)
    """
    # Feed inputs to whatever device the model's weights actually live on
    device = next(model.parameters()).device
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    image_tensor, orig_h, orig_w = preprocess_image(image_rgb)
    with torch.no_grad():
        prediction = model(image_tensor.to(device))
    seg_map = generate_segmentation_map(prediction, orig_h, orig_w)  # RGB
    overlay = create_overlay(image_rgb, seg_map)
    # Percentages come from the raw (model-resolution) prediction
    mask = prediction.argmax(1).squeeze().cpu().numpy()
    analysis = create_analysis_result(mask)
    return seg_map, overlay, analysis


# Single image processing
def process_coastal_image(location, input_image):
    """
    Gradio handler for the single-image tab.

    Args:
        location: key of MODEL_PATHS selecting model and reference data.
        input_image: RGB numpy image from the UI, or None.

    Returns:
        (segmentation map, overlay, analysis HTML, outlier status, upload URL/message)
    """
    if input_image is None:
        return None, None, "Please upload an image", "Not detected", None

    model = _get_model(location, _default_device())
    if model is None:
        return None, None, "Error: Failed to load model", "Not detected", None

    ref_vector = load_reference_vector(REFERENCE_VECTOR_PATHS[location])
    ref_images = load_reference_images(REFERENCE_IMAGE_DIRS[location])

    image_bgr = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)

    # Prefer the precomputed reference vector; fall back to raw reference images.
    if np.size(ref_vector) > 0:
        filtered, _ = detect_outliers(ref_images, [image_bgr], ref_vector)
    elif ref_images:
        filtered, _ = detect_outliers(ref_images, [image_bgr])
    else:
        filtered = None

    if filtered is None:
        print("Warning: No reference images or reference vectors available for outlier detection")
        is_outlier = False
        outlier_status = "Outlier Detection: Skipped (no reference data)"
    else:
        # An empty filtered result is treated as "the input was rejected"
        is_outlier = len(filtered) == 0
        outlier_status = "Outlier Detection: Failed" if is_outlier else "Outlier Detection: Passed"

    seg_map, overlay, analysis = perform_segmentation(model, image_bgr)

    # Try uploading to DigitalOcean Spaces (upload_mask itself reports errors as strings)
    try:
        url = upload_mask(Image.fromarray(seg_map), prefix=location.replace(' ', '_'))
    except Exception as e:
        print(f"Upload failed: {e}")
        url = f"Upload error: {str(e)}"

    if is_outlier:
        analysis = (
            "<div style='color: red;'>"
            "Warning: The image failed outlier detection, the result may be inaccurate!"
            "</div>" + analysis
        )
    return seg_map, overlay, analysis, outlier_status, url


# Spatial Alignment
def process_with_alignment(location, reference_image, input_image):
    """
    Gradio handler for the alignment tab: align target to reference, then segment.

    Returns:
        (reference RGB, aligned target RGB, segmentation map, overlay,
         analysis HTML, status text, upload URL/message)
    """
    if reference_image is None or input_image is None:
        return None, None, None, None, "Please upload both reference and target images", "Not processed", None

    model = _get_model(location, _default_device())
    if model is None:
        return None, None, None, None, "Error: Failed to load model", "Not processed", None

    ref_bgr = cv2.cvtColor(np.array(reference_image), cv2.COLOR_RGB2BGR)
    tgt_bgr = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)

    try:
        # align_images takes images plus companion arrays; zero placeholders are
        # passed for the second list (its meaning is defined in pipeline.normalization).
        aligned, _ = align_images([ref_bgr, tgt_bgr], [np.zeros_like(ref_bgr), np.zeros_like(tgt_bgr)])
        aligned_tgt_bgr = aligned[1]
    except Exception as e:
        print(f"Spatial alignment failed: {e}")
        return None, None, None, None, f"Spatial alignment failed: {str(e)}", "Processing failed", None

    seg_map, overlay, analysis = perform_segmentation(model, aligned_tgt_bgr)

    # Try uploading to DigitalOcean Spaces (upload_mask itself reports errors as strings)
    try:
        url = upload_mask(Image.fromarray(seg_map), prefix="aligned_" + location.replace(' ', '_'))
    except Exception as e:
        print(f"Upload failed: {e}")
        url = f"Upload error: {str(e)}"

    status = "Spatial Alignment: Completed"

    ref_rgb = cv2.cvtColor(ref_bgr, cv2.COLOR_BGR2RGB)
    aligned_tgt_rgb = cv2.cvtColor(aligned_tgt_bgr, cv2.COLOR_BGR2RGB)

    return ref_rgb, aligned_tgt_rgb, seg_map, overlay, analysis, status, url


# Create the Gradio interface
def create_interface():
    """Build and return the two-tab Gradio Blocks UI (not launched)."""
    # Set unified display size
    disp_w, disp_h = 683, 512  # Maintain aspect ratio

    with gr.Blocks(title="Coastal Erosion Analysis System") as demo:
        gr.Markdown("""# Coastal Erosion Analysis System

Upload coastal images for analysis, including segmentation and spatial alignment.""")

        with gr.Tabs():
            with gr.TabItem("Single Image Segmentation"):
                with gr.Row():
                    loc1 = gr.Radio(list(MODEL_PATHS.keys()), label="Select Model",
                                    value=list(MODEL_PATHS.keys())[0])
                with gr.Row():
                    inp = gr.Image(label="Input Image", type="numpy", image_mode="RGB",
                                   height=disp_h, width=disp_w)
                    seg = gr.Image(label="Segmentation Map", type="numpy",
                                   height=disp_h, width=disp_w)
                    ovl = gr.Image(label="Overlay Image", type="numpy",
                                   height=disp_h, width=disp_w)
                with gr.Row():
                    btn1 = gr.Button("Run Segmentation")
                url1 = gr.Text(label="Segmentation Image URL")
                status1 = gr.HTML(label="Outlier Detection Status")
                res1 = gr.HTML(label="Analysis Result")
                btn1.click(fn=process_coastal_image, inputs=[loc1, inp],
                           outputs=[seg, ovl, res1, status1, url1])

            with gr.TabItem("Spatial Alignment Segmentation"):
                with gr.Row():
                    loc2 = gr.Radio(list(MODEL_PATHS.keys()), label="Select Model",
                                    value=list(MODEL_PATHS.keys())[0])
                with gr.Row():
                    ref_img = gr.Image(label="Reference Image", type="numpy", image_mode="RGB",
                                       height=disp_h, width=disp_w)
                    tgt_img = gr.Image(label="Target Image", type="numpy", image_mode="RGB",
                                       height=disp_h, width=disp_w)
                with gr.Row():
                    btn2 = gr.Button("Run Spatial Alignment and Segmentation")
                with gr.Row():
                    orig = gr.Image(label="Original Image", type="numpy",
                                    height=disp_h, width=disp_w)
                    aligned = gr.Image(label="Aligned Image", type="numpy",
                                       height=disp_h, width=disp_w)
                with gr.Row():
                    seg2 = gr.Image(label="Segmentation Map", type="numpy",
                                    height=disp_h, width=disp_w)
                    ovl2 = gr.Image(label="Overlay Image", type="numpy",
                                    height=disp_h, width=disp_w)
                url2 = gr.Text(label="Segmentation Image URL")
                status2 = gr.HTML(label="Alignment Status")
                res2 = gr.HTML(label="Analysis Result")
                btn2.click(fn=process_with_alignment, inputs=[loc2, ref_img, tgt_img],
                           outputs=[orig, aligned, seg2, ovl2, res2, status2, url2])

    return demo


if __name__ == "__main__":
    # Create necessary directories
    for path in ["models", "reference_images/MM", "reference_images/SJ"]:
        os.makedirs(path, exist_ok=True)

    # Check if model files exist
    for p in MODEL_PATHS.values():
        if not os.path.exists(p):
            print(f"Warning: Model file {p} does not exist!")

    # Check if DigitalOcean credentials exist
    do_creds = [
        os.environ.get('DO_SPACES_KEY'),
        os.environ.get('DO_SPACES_SECRET'),
        os.environ.get('DO_SPACES_REGION'),
        os.environ.get('DO_SPACES_BUCKET'),
    ]
    if not all(do_creds):
        print("Warning: Incomplete DigitalOcean Spaces credentials, upload functionality may not work")

    # Create and launch the interface
    demo = create_interface()
    if HF_SPACE:
        demo.launch()
    else:
        # NOTE: share=True publishes a public Gradio link to this local app.
        demo.launch(share=True)