File size: 15,270 Bytes
3924e13
 
 
 
 
 
40aaca9
 
 
3924e13
b2048d6
3924e13
 
 
b2048d6
3924e13
 
b2048d6
40aaca9
 
b2048d6
 
40aaca9
b2048d6
 
40aaca9
 
b2048d6
40aaca9
3924e13
b2048d6
40aaca9
 
 
 
 
b2048d6
40aaca9
b2048d6
40aaca9
b2048d6
40aaca9
 
 
 
 
 
 
b2048d6
40aaca9
 
b2048d6
40aaca9
 
 
 
b2048d6
40aaca9
 
 
 
 
 
 
b2048d6
40aaca9
 
 
 
b2048d6
 
3924e13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2048d6
3924e13
b2048d6
40aaca9
b2048d6
3924e13
 
 
 
 
 
 
 
 
 
 
 
 
 
b2048d6
3924e13
 
b2048d6
3924e13
 
 
 
 
40aaca9
b2048d6
40aaca9
3924e13
b2048d6
3924e13
 
b2048d6
3924e13
 
b2048d6
3924e13
 
40aaca9
b2048d6
40aaca9
 
 
3924e13
 
 
 
 
 
 
 
 
 
b2048d6
3924e13
 
b2048d6
3924e13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2048d6
3924e13
 
 
 
 
b2048d6
3924e13
40aaca9
3924e13
 
b2048d6
3924e13
 
 
 
 
 
40aaca9
3924e13
 
40aaca9
b2048d6
40aaca9
3924e13
b2048d6
3924e13
 
b2048d6
 
40aaca9
 
 
b2048d6
 
3924e13
 
b2048d6
3924e13
 
 
b2048d6
3924e13
 
b2048d6
3924e13
 
 
 
 
b2048d6
3924e13
 
 
 
40aaca9
 
 
 
b2048d6
 
3924e13
 
 
b2048d6
 
40aaca9
 
 
b2048d6
 
3924e13
b2048d6
3924e13
 
 
 
 
 
 
b2048d6
 
 
 
 
3924e13
b2048d6
3924e13
b2048d6
3924e13
b2048d6
3924e13
b2048d6
 
 
3924e13
b2048d6
 
 
 
40aaca9
3924e13
b2048d6
3924e13
b2048d6
3924e13
b2048d6
 
3924e13
b2048d6
3924e13
b2048d6
 
3924e13
b2048d6
 
 
 
 
3924e13
 
 
 
b2048d6
3924e13
 
40aaca9
b2048d6
3924e13
 
b2048d6
40aaca9
b2048d6
40aaca9
 
 
 
 
 
 
b2048d6
40aaca9
b2048d6
3924e13
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
import cv2
import numpy as np
import torch
import gradio as gr
import segmentation_models_pytorch as smp
from PIL import Image
import boto3
import uuid
import io
from glob import glob
import os
from pipeline.ImgOutlier import detect_outliers
from pipeline.normalization import align_images

# True when the app is hosted on Hugging Face Spaces (the SPACE_ID environment
# variable is present). Used to force CPU inference and to launch without a
# Gradio share link.
HF_SPACE = 'SPACE_ID' in os.environ

# DigitalOcean Spaces upload function
def upload_mask(image, prefix="mask"):
    """
    Save a segmentation mask as a PNG object in a DigitalOcean Spaces bucket.

    Connection settings are read from the DO_SPACES_KEY, DO_SPACES_SECRET,
    DO_SPACES_REGION and DO_SPACES_BUCKET environment variables. The object is
    stored with a public-read ACL under the name "<prefix>_<random hex>.png".

    Args:
        image: PIL Image object to encode as PNG
        prefix: text placed in front of the random part of the object name

    Returns:
        The public URL of the uploaded object on success. This function never
        raises: on failure it returns a message instead, either
        "DigitalOcean credentials not set" or "Upload error: <details>".
    """
    try:
        env_names = ('DO_SPACES_KEY', 'DO_SPACES_SECRET', 'DO_SPACES_REGION', 'DO_SPACES_BUCKET')
        key, secret, region, bucket = (os.environ.get(name) for name in env_names)

        # All four settings are required to build the client and the URL
        if not (key and secret and region and bucket):
            return "DigitalOcean credentials not set"

        # Spaces is accessed through the S3 API at a region-specific endpoint
        s3 = boto3.session.Session().client(
            's3',
            region_name=region,
            endpoint_url=f'https://{region}.digitaloceanspaces.com',
            aws_access_key_id=key,
            aws_secret_access_key=secret,
        )

        object_name = f"{prefix}_{uuid.uuid4().hex}.png"

        # Encode the PNG in memory; nothing is written to the local disk
        buffer = io.BytesIO()
        image.save(buffer, format='PNG')
        buffer.seek(0)

        s3.upload_fileobj(
            buffer,
            bucket,
            object_name,
            ExtraArgs={'ACL': 'public-read', 'ContentType': 'image/png'},
        )
        return f'https://{bucket}.{region}.digitaloceanspaces.com/{object_name}'

    except Exception as err:
        print(f"Upload failed: {str(err)}")
        return f"Upload error: {str(err)}"

# Global Configuration
# Each selectable location maps to a short code that names its checkpoint,
# reference-vector file and reference-image folder. Insertion order matters:
# the first location is the UI default.
_LOCATION_CODES = {
    "Metal Marcy": "MM",
    "Silhouette Jaenette": "SJ",
}

MODEL_PATHS = {
    location: f"models/{code}_best_model.pth" for location, code in _LOCATION_CODES.items()
}
REFERENCE_VECTOR_PATHS = {
    location: f"models/{code}_mean.npy" for location, code in _LOCATION_CODES.items()
}
REFERENCE_IMAGE_DIRS = {
    location: f"reference_images/{code}" for location, code in _LOCATION_CODES.items()
}

# Segmentation classes. A class's position is the label id produced by the
# model's argmax, and the same position selects its RGB color in COLORS.
CLASSES = ['background', 'cobbles', 'drysand', 'plant', 'sky', 'water', 'wetsand']
COLORS = [
    [0, 0, 0],        # background - black
    [139, 137, 137],  # cobbles - dark gray
    [255, 228, 181],  # drysand - light yellow
    [0, 128, 0],      # plant - green
    [135, 206, 235],  # sky - sky blue
    [0, 0, 255],      # water - blue
    [194, 178, 128],  # wetsand - sand brown
]

# Loaded models keyed by (checkpoint path, device). Each value is
# (checkpoint mtime, model) so that a replaced checkpoint file is reloaded.
_MODEL_CACHE = {}


# Load model function
def load_model(model_path, device="cuda"):
    """
    Build the DeepLabV3+ (EfficientNet-B6 encoder) network and load its weights.

    Successful loads are cached: repeated calls with the same checkpoint path
    and device return the same module instead of rebuilding the network and
    re-reading the file on every request. The cached entry is refreshed when
    the checkpoint's modification time changes. Failed loads are not cached.

    Args:
        model_path: path to a saved state_dict.
        device: requested device. Forced to "cpu" when running on Hugging Face
            Spaces or when CUDA is unavailable.

    Returns:
        The model in eval mode on the chosen device, or None if anything goes
        wrong (missing file, mismatched weights, ...). Errors are printed,
        never raised.
    """
    try:
        if HF_SPACE or not torch.cuda.is_available():
            device = "cpu"

        # Raises OSError for a missing file, which is reported below like any
        # other loading failure.
        mtime = os.path.getmtime(model_path)
        cache_key = (model_path, device)
        cached = _MODEL_CACHE.get(cache_key)
        if cached is not None and cached[0] == mtime:
            return cached[1]

        model = smp.create_model(
            "DeepLabV3Plus",
            encoder_name="efficientnet-b6",
            in_channels=3,
            classes=len(CLASSES),
            encoder_weights=None
        )
        state_dict = torch.load(model_path, map_location=device)
        # Strip a "model." prefix shared by every key (presumably the network
        # was saved as the `model` attribute of a wrapper module).
        prefix = 'model.'
        if all(k.startswith(prefix) for k in state_dict.keys()):
            state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()
        _MODEL_CACHE[cache_key] = (mtime, model)
        print(f"Model loaded successfully: {model_path}")
        return model
    except Exception as e:
        print(f"Model loading failed: {e}")
        return None

# Load reference vector
def load_reference_vector(vector_path):
    """
    Read a reference vector saved with ``np.save``.

    Args:
        vector_path: path to the ``.npy`` file.

    Returns:
        The loaded array. Returns an empty list when the file does not exist or
        cannot be read; the problem is printed, never raised.
    """
    try:
        if os.path.exists(vector_path):
            vector = np.load(vector_path)
            print(f"Reference vector loaded successfully: {vector_path}")
            return vector
        print(f"Reference vector file not found: {vector_path}")
        return []
    except Exception as exc:
        print(f"Reference vector loading failed {vector_path}: {exc}")
        return []

# Load reference images
def load_reference_images(ref_dir, max_images=4):
    """
    Load up to ``max_images`` reference images from a directory with OpenCV.

    Files are selected by extension (.jpg, .jpeg, .png, .bmp) matched
    case-insensitively, so names like ``IMG_0001.JPG`` are not skipped on
    case-sensitive filesystems. Files are read in sorted path order. Files
    that ``cv2.imread`` cannot decode are skipped and do not count toward the
    limit. Hidden files (leading dot) are ignored, as ``glob('*.ext')`` would.

    Args:
        ref_dir: directory holding the reference images. If it does not exist,
            it is created (empty) and an empty list is returned.
        max_images: maximum number of decoded images to return (default 4).

    Returns:
        List of images as returned by ``cv2.imread`` (possibly empty). Errors
        are printed and produce an empty list.
    """
    image_extensions = ('.jpg', '.jpeg', '.png', '.bmp')
    try:
        if not os.path.exists(ref_dir):
            print(f"Reference image directory not found: {ref_dir}")
            os.makedirs(ref_dir, exist_ok=True)
            return []

        # Filter by lower-cased extension instead of glob('*.jpg'), which is
        # case-sensitive on Linux and would miss e.g. '.JPG' files.
        image_files = sorted(
            os.path.join(ref_dir, name)
            for name in os.listdir(ref_dir)
            if not name.startswith('.')
            and name.lower().endswith(image_extensions)
            and os.path.isfile(os.path.join(ref_dir, name))
        )

        reference_images = []
        for file in image_files:
            if len(reference_images) >= max_images:
                break
            img = cv2.imread(file)
            if img is not None:
                reference_images.append(img)
        print(f"Loaded {len(reference_images)} images from {ref_dir}")
        return reference_images
    except Exception as e:
        print(f"Image loading failed {ref_dir}: {e}")
        return []

# Preprocess the image
def preprocess_image(image, size=1024):
    """
    Convert an RGB, RGBA or grayscale image into a normalized model input tensor.

    The image is resized to ``size`` x ``size`` (the aspect ratio is not
    kept), scaled to [0, 1], normalized with the standard ImageNet channel
    mean/std, and returned as a (1, 3, size, size) float32 tensor.

    Args:
        image: H x W x 3 (RGB), H x W x 4 (RGBA), or H x W / H x W x 1
            (grayscale) array. Values are assumed to be 0-255, since they are
            divided by 255 (TODO confirm for non-uint8 inputs).
        size: side length of the square network input (default 1024).

    Returns:
        (image_tensor, orig_h, orig_w): the input tensor plus the original
        height and width, so the prediction can be resized back.
    """
    # Previously a 2-D grayscale array crashed on image.shape[2]
    if image.ndim == 2 or image.shape[2] == 1:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    orig_h, orig_w = image.shape[:2]
    image_resized = cv2.resize(image, (size, size))
    # float32 constants keep the computation in float32 instead of promoting
    # the full-resolution image to a float64 temporary.
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    image_norm = (image_resized.astype(np.float32) / 255.0 - mean) / std
    chw = np.ascontiguousarray(image_norm.transpose(2, 0, 1))
    image_tensor = torch.from_numpy(chw).float().unsqueeze(0)
    return image_tensor, orig_h, orig_w

# Generate segmentation map and visualization
def generate_segmentation_map(prediction, orig_h, orig_w):
    """
    Render the model output as an RGB color-coded label map at the original size.

    Post-processing: background (label 0) pixels near a labelled region take
    that region's label. "Near" is defined by dilating each class mask with a
    5x5 kernel for two iterations. Only pixels that were background before any
    filling are eligible. If several classes reach the same pixel, the class
    with the highest index wins because it is written last.

    Args:
        prediction: network output tensor; labels are its argmax over dim 1
            (assumed shape (1, len(CLASSES), H, W) -- TODO confirm).
        orig_h: output height in pixels.
        orig_w: output width in pixels.

    Returns:
        uint8 array of shape (orig_h, orig_w, 3) colored with COLORS.
    """
    labels = prediction.argmax(1).squeeze().cpu().numpy().astype(np.uint8)
    labels = cv2.resize(labels, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)

    is_background = labels == 0
    structuring_element = np.ones((5, 5), np.uint8)
    filled = labels.copy()
    for class_id in range(1, len(CLASSES)):
        class_region = (labels == class_id).astype(np.uint8)
        grown = cv2.dilate(class_region, structuring_element, iterations=2)
        filled[(grown > 0) & is_background] = class_id

    colored = np.zeros((orig_h, orig_w, 3), dtype=np.uint8)
    for class_id, rgb in enumerate(COLORS):
        colored[filled == class_id] = rgb
    return colored

# Analysis result HTML
def create_analysis_result(mask):
    """
    Summarize a label mask as one bold HTML line of per-class area percentages.

    Args:
        mask: array of label ids that index into CLASSES.

    Returns:
        A ``<div>`` listing sky, cobbles, plant, drysand, wetsand and water (in
        that order) as "name: pct%", separated by " | ". Each pct is rounded to
        one decimal. Background pixels count toward the total but are not listed.
    """
    pixel_count = mask.size
    share = {}
    for label, name in enumerate(CLASSES):
        share[name] = round((np.sum(mask == label) / pixel_count) * 100, 1)

    display_order = ('sky', 'cobbles', 'plant', 'drysand', 'wetsand', 'water')
    entries = [f"{name}: {share.get(name,0)}%" for name in display_order]
    return "<div style='font-size:18px;font-weight:bold;'>" + " | ".join(entries) + "</div>"

# Merge and overlay
def create_overlay(image, segmentation_map, alpha=0.5):
    """
    Alpha-blend a color segmentation map onto an image.

    Args:
        image: H x W x C base image.
        segmentation_map: color map. It is resized with nearest-neighbour
            interpolation when its height/width differ from the image's.
        alpha: weight of the segmentation map; the image gets ``1 - alpha``.

    Returns:
        The blended image produced by ``cv2.addWeighted``.
    """
    img_h, img_w = image.shape[:2]
    if segmentation_map.shape[:2] != (img_h, img_w):
        segmentation_map = cv2.resize(
            segmentation_map, (img_w, img_h), interpolation=cv2.INTER_NEAREST
        )
    image_weight = 1 - alpha
    return cv2.addWeighted(image, image_weight, segmentation_map, alpha, 0)
# Perform segmentation
def perform_segmentation(model, image_bgr):
    """
    Run the segmentation model on one BGR image.

    Args:
        model: network returned by ``load_model``.
        image_bgr: H x W x 3 image in OpenCV's BGR channel order.

    Returns:
        (seg_map, overlay, analysis):
          seg_map  - RGB color-coded label map at the input resolution;
          overlay  - blend of the RGB input and seg_map (create_overlay's
                     default alpha of 0.5);
          analysis - HTML class-percentage summary. It is computed from the
                     label mask taken directly from the network output, which
                     is not resized and has no background fill applied.
    """
    # Send the input to wherever the model's weights actually live. Re-deriving
    # the device independently could mismatch a model loaded on another device.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else "cpu"

    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    image_tensor, orig_h, orig_w = preprocess_image(image_rgb)
    with torch.no_grad():
        prediction = model(image_tensor.to(device))
    seg_map = generate_segmentation_map(prediction, orig_h, orig_w)  # RGB
    overlay = create_overlay(image_rgb, seg_map)
    mask = prediction.argmax(1).squeeze().cpu().numpy()
    analysis = create_analysis_result(mask)
    return seg_map, overlay, analysis

# Single image processing
def process_coastal_image(location, input_image):
    """
    Gradio handler for the "Single Image Segmentation" tab.

    The handler:
    1. Loads the model and reference data for ``location``.
    2. Runs the outlier check on the uploaded image.
    3. Segments the image.
    4. Tries to upload the segmentation map to DigitalOcean Spaces.

    Args:
        location: key of MODEL_PATHS selected in the UI.
        input_image: RGB image array from the Gradio component, or None.

    Returns:
        (seg_map, overlay, analysis_html, outlier_status_html, url). ``url`` is
        whatever ``upload_mask`` returned: a public URL or an error text. When
        processing cannot start, the image slots and ``url`` are None and the
        third slot carries the reason.
    """
    if input_image is None:
        return None, None, "Please upload an image", "Not detected", None

    # Previously an unknown selection raised KeyError inside the handler
    if location not in MODEL_PATHS:
        return None, None, "Error: Unknown model selection", "Not detected", None

    device = "cuda" if torch.cuda.is_available() and not HF_SPACE else "cpu"
    model = load_model(MODEL_PATHS[location], device)
    if model is None:
        return None, None, "Error: Failed to load model", "Not detected", None

    ref_vector = load_reference_vector(REFERENCE_VECTOR_PATHS[location])
    ref_images = load_reference_images(REFERENCE_IMAGE_DIRS[location])

    # Gradio supplies RGB; the rest of the pipeline works in OpenCV's BGR order
    image_bgr = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)

    # Use the saved reference vector when available, otherwise the reference
    # images. An empty `filtered` result means the image was rejected.
    if len(ref_vector) > 0:
        filtered, _ = detect_outliers(ref_images, [image_bgr], ref_vector)
        is_outlier = len(filtered) == 0
    elif len(ref_images) > 0:
        filtered, _ = detect_outliers(ref_images, [image_bgr])
        is_outlier = len(filtered) == 0
    else:
        print("Warning: No reference images or reference vectors available for outlier detection")
        is_outlier = False

    if is_outlier:
        outlier_status = "Outlier Detection: <span style='color:red;font-weight:bold'>Failed</span>"
    else:
        outlier_status = "Outlier Detection: <span style='color:green;font-weight:bold'>Passed</span>"

    seg_map, overlay, analysis = perform_segmentation(model, image_bgr)

    # upload_mask reports its own failures as text. This try also covers
    # errors raised while converting seg_map to a PIL image.
    try:
        url = upload_mask(Image.fromarray(seg_map), prefix=location.replace(' ', '_'))
    except Exception as e:
        print(f"Upload failed: {e}")
        url = f"Upload error: {str(e)}"

    if is_outlier:
        analysis = "<div style='color:red;font-weight:bold;margin-bottom:10px'>Warning: The image failed outlier detection, the result may be inaccurate!</div>" + analysis

    return seg_map, overlay, analysis, outlier_status, url

# Spatial Alignment
def process_with_alignment(location, reference_image, input_image):
    """
    Gradio handler for the "Spatial Alignment Segmentation" tab.

    Aligns the target image to the reference with ``align_images``, segments
    the aligned target, and tries to upload the segmentation map.

    Args:
        location: key of MODEL_PATHS selected in the UI.
        reference_image: RGB reference image array, or None.
        input_image: RGB target image array, or None.

    Returns:
        (reference_rgb, aligned_target_rgb, seg_map, overlay, analysis_html,
        status_html, url). On failure the image slots and ``url`` are None and
        the fifth slot carries the reason.
    """
    if reference_image is None or input_image is None:
        return None, None, None, None, "Please upload both reference and target images", "Not processed", None

    # Previously an unknown selection raised KeyError inside the handler
    if location not in MODEL_PATHS:
        return None, None, None, None, "Error: Unknown model selection", "Not processed", None

    device = "cuda" if torch.cuda.is_available() and not HF_SPACE else "cpu"
    model = load_model(MODEL_PATHS[location], device)
    if model is None:
        return None, None, None, None, "Error: Failed to load model", "Not processed", None

    # Gradio supplies RGB; alignment and segmentation work on BGR arrays
    ref_bgr = cv2.cvtColor(np.array(reference_image), cv2.COLOR_RGB2BGR)
    tgt_bgr = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)

    # The second argument looks like per-image masks (assumption). Zero arrays
    # are passed and the second return value is ignored; only the aligned
    # target image is used.
    try:
        aligned, _ = align_images([ref_bgr, tgt_bgr], [np.zeros_like(ref_bgr), np.zeros_like(tgt_bgr)])
        aligned_tgt_bgr = aligned[1]
    except Exception as e:
        print(f"Spatial alignment failed: {e}")
        return None, None, None, None, f"Spatial alignment failed: {str(e)}", "Processing failed", None

    seg_map, overlay, analysis = perform_segmentation(model, aligned_tgt_bgr)

    # upload_mask reports its own failures as text. This try also covers
    # errors raised while converting seg_map to a PIL image.
    try:
        url = upload_mask(Image.fromarray(seg_map), prefix="aligned_" + location.replace(' ', '_'))
    except Exception as e:
        print(f"Upload failed: {e}")
        url = f"Upload error: {str(e)}"

    status = "Spatial Alignment: <span style='color:green;font-weight:bold'>Completed</span>"
    ref_rgb = cv2.cvtColor(ref_bgr, cv2.COLOR_BGR2RGB)
    aligned_tgt_rgb = cv2.cvtColor(aligned_tgt_bgr, cv2.COLOR_BGR2RGB)

    return ref_rgb, aligned_tgt_rgb, seg_map, overlay, analysis, status, url

# Create the Gradio interface
def create_interface():
    """
    Build the two-tab Gradio UI and connect its buttons to the handlers.

    The "Single Image Segmentation" tab calls process_coastal_image. The
    "Spatial Alignment Segmentation" tab calls process_with_alignment. Inside
    a Blocks layout, components appear in creation order, so statement order
    here defines the page layout.

    Returns:
        The gr.Blocks app (not yet launched).
    """
    # Common display size for every image widget (683 x 512, about 4:3)
    image_size = {"height": 512, "width": 683}

    with gr.Blocks(title="Coastal Erosion Analysis System") as demo:
        gr.Markdown("""# Coastal Erosion Analysis System

Upload coastal images for analysis, including segmentation and spatial alignment.""")
        with gr.Tabs():
            # Tab 1: segment a single uploaded image
            with gr.TabItem("Single Image Segmentation"):
                with gr.Row():
                    single_location = gr.Radio(list(MODEL_PATHS), label="Select Model", value=next(iter(MODEL_PATHS)))
                with gr.Row():
                    single_input = gr.Image(label="Input Image", type="numpy", image_mode="RGB", **image_size)
                    single_seg = gr.Image(label="Segmentation Map", type="numpy", **image_size)
                    single_overlay = gr.Image(label="Overlay Image", type="numpy", **image_size)
                with gr.Row():
                    single_button = gr.Button("Run Segmentation")
                single_url = gr.Text(label="Segmentation Image URL")
                single_status = gr.HTML(label="Outlier Detection Status")
                single_result = gr.HTML(label="Analysis Result")
                single_button.click(
                    fn=process_coastal_image,
                    inputs=[single_location, single_input],
                    outputs=[single_seg, single_overlay, single_result, single_status, single_url],
                )

            # Tab 2: align a target image to a reference, then segment it
            with gr.TabItem("Spatial Alignment Segmentation"):
                with gr.Row():
                    align_location = gr.Radio(list(MODEL_PATHS), label="Select Model", value=next(iter(MODEL_PATHS)))
                with gr.Row():
                    align_reference = gr.Image(label="Reference Image", type="numpy", image_mode="RGB", **image_size)
                    align_target = gr.Image(label="Target Image", type="numpy", image_mode="RGB", **image_size)
                with gr.Row():
                    align_button = gr.Button("Run Spatial Alignment and Segmentation")
                with gr.Row():
                    align_original = gr.Image(label="Original Image", type="numpy", **image_size)
                    align_aligned = gr.Image(label="Aligned Image", type="numpy", **image_size)
                with gr.Row():
                    align_seg = gr.Image(label="Segmentation Map", type="numpy", **image_size)
                    align_overlay = gr.Image(label="Overlay Image", type="numpy", **image_size)
                align_url = gr.Text(label="Segmentation Image URL")
                align_status = gr.HTML(label="Alignment Status")
                align_result = gr.HTML(label="Analysis Result")
                align_button.click(
                    fn=process_with_alignment,
                    inputs=[align_location, align_reference, align_target],
                    outputs=[align_original, align_aligned, align_seg, align_overlay,
                             align_result, align_status, align_url],
                )
    return demo

if __name__ == "__main__":
    # Create the folders the app reads models and reference images from
    for required_dir in ("models", "reference_images/MM", "reference_images/SJ"):
        os.makedirs(required_dir, exist_ok=True)

    # Missing checkpoints are only reported here; the app still starts
    for checkpoint in MODEL_PATHS.values():
        if os.path.exists(checkpoint):
            continue
        print(f"Warning: Model file {checkpoint} does not exist!")

    # Uploads need all four DigitalOcean Spaces settings
    spaces_settings = ('DO_SPACES_KEY', 'DO_SPACES_SECRET', 'DO_SPACES_REGION', 'DO_SPACES_BUCKET')
    if not all(os.environ.get(name) for name in spaces_settings):
        print("Warning: Incomplete DigitalOcean Spaces credentials, upload functionality may not work")

    demo = create_interface()
    # Outside Hugging Face Spaces, ask Gradio for a public share link
    launch_options = {} if HF_SPACE else {"share": True}
    demo.launch(**launch_options)