File size: 16,399 Bytes
a95debb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
import gradio as gr
import cv2
import numpy as np
import torch
import sys
import os
import pyvirtualcam
from pyvirtualcam import PixelFormat
from huggingface_hub import hf_hub_download
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image

# Path configurations
# Depth-Anything-V2 is not pip-installable; locate the checkout via an
# environment variable and put it on sys.path so its package imports below.
depth_anything_path = os.getenv('DEPTH_ANYTHING_V2_PATH')
if depth_anything_path is None:
    raise ValueError("Environment variable DEPTH_ANYTHING_V2_PATH is not set. Please set it to the path of Depth-Anything-V2")
sys.path.append(depth_anything_path)

from depth_anything_v2.dpt import DepthAnythingV2

# Device selection with MPS support
# Preference order: CUDA, then Apple-Silicon MPS, then CPU fallback.
DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f"Using device: {DEVICE}")

###########################################
# CycleGAN Generator Architecture
###########################################

class ResidualBlock(nn.Module):
    """Residual block: two reflection-padded 3x3 conv + instance-norm stages
    with a ReLU in between, added back onto the input (skip connection)."""

    def __init__(self, channels):
        super().__init__()
        # Attribute must stay named `conv_block` so existing checkpoints load.
        layers = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, 3),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, 3),
            nn.InstanceNorm2d(channels),
        ]
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.conv_block(x)
        return residual + x

class Generator(nn.Module):
    """CycleGAN ResNet generator.

    Architecture: 7x7 stem to 64 channels, two stride-2 downsampling stages
    (64 -> 128 -> 256), `n_residual_blocks` residual blocks at 256 channels,
    two transposed-conv upsampling stages (256 -> 128 -> 64), then a 7x7
    projection to `output_channels` squashed to [-1, 1] by tanh.
    """

    def __init__(self, input_channels=3, output_channels=3, n_residual_blocks=9):
        super().__init__()

        # 7x7 stem with reflection padding.
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_channels, 64, 7),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
        ]

        # Two stride-2 downsampling stages, doubling channels each time.
        channels = 64
        for _ in range(2):
            layers += [
                nn.Conv2d(channels, channels * 2, 3, stride=2, padding=1),
                nn.InstanceNorm2d(channels * 2),
                nn.ReLU(inplace=True),
            ]
            channels *= 2

        # Residual bottleneck.
        layers += [ResidualBlock(channels) for _ in range(n_residual_blocks)]

        # Two upsampling stages mirroring the downsampling, halving channels.
        for _ in range(2):
            layers += [
                nn.ConvTranspose2d(channels, channels // 2, 3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(channels // 2),
                nn.ReLU(inplace=True),
            ]
            channels //= 2

        # Final projection back to image space.
        layers += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(channels, output_channels, 7),
            nn.Tanh(),
        ]

        # Attribute must stay named `model` so existing checkpoints load.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

###########################################
# Depth Anything Model Functions
###########################################

# Model configurations
# Encoder-specific hyperparameters forwarded to the DepthAnythingV2 constructor.
model_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}
}

# Maps internal encoder keys to the human-readable names shown in the UI dropdown.
encoder2name = {
    'vits': 'Small',
    'vitb': 'Base',
    'vitl': 'Large'
}

# Model IDs and filenames for HuggingFace Hub
MODEL_INFO = {
    'vits': {
        'repo_id': 'depth-anything/Depth-Anything-V2-Small',
        'filename': 'depth_anything_v2_vits.pth'
    },
    'vitb': {
        'repo_id': 'depth-anything/Depth-Anything-V2-Base',
        'filename': 'depth_anything_v2_vitb.pth'
    },
    'vitl': {
        'repo_id': 'depth-anything/Depth-Anything-V2-Large',
        'filename': 'depth_anything_v2_vitl.pth'
    }
}

# Global variables for model management
# Lazy-loaded caches: models are loaded on first use and reused across frames.
current_depth_model = None
current_encoder = None
current_cyclegan_model = None

def download_model(encoder):
    """Fetch the depth checkpoint for *encoder* ('vits'/'vitb'/'vitl') from
    HuggingFace Hub into ./checkpoints and return its local path."""
    info = MODEL_INFO[encoder]
    return hf_hub_download(
        repo_id=info['repo_id'],
        filename=info['filename'],
        local_dir='checkpoints',
    )

def load_depth_model(encoder):
    """Load the specified depth model.

    Caches the model in module globals and only downloads/reloads when the
    requested encoder differs from the currently loaded one.
    """
    global current_depth_model, current_encoder
    if current_encoder != encoder:
        model_path = download_model(encoder)
        current_depth_model = DepthAnythingV2(**model_configs[encoder])
        # Load weights on CPU first, then move to the selected device.
        current_depth_model.load_state_dict(torch.load(model_path, map_location='cpu'))
        current_depth_model = current_depth_model.to(DEVICE).eval()
        current_encoder = encoder
    return current_depth_model

def load_cyclegan_model(model_path):
    """Load (and cache) the CycleGAN generator stored at *model_path*.

    Returns the generator on DEVICE in eval mode, or None if the checkpoint
    file does not exist.

    Bug fix: the original cached a single model behind an `is None` check,
    ignoring *model_path* — so after switching CycleGAN direction
    (latest_net_G_A.pth vs latest_net_G_B.pth) it kept returning the first
    generator loaded. The cache is now keyed on the path.
    """
    global current_cyclegan_model
    cache = getattr(load_cyclegan_model, '_cache', None)
    if cache is None:
        cache = load_cyclegan_model._cache = {}

    if model_path not in cache:
        if not os.path.exists(model_path):
            print(f"Error: CycleGAN model file not found at {model_path}")
            return None
        print(f"Loading CycleGAN model from {model_path}")
        model = Generator()
        state_dict = torch.load(model_path, map_location='cpu')
        try:
            model.load_state_dict(state_dict)
        except RuntimeError as e:
            # Key/shape mismatches (e.g. checkpoints saved with extra keys):
            # fall back to partial loading rather than aborting.
            print(f"Warning: {e}")
            model.load_state_dict(state_dict, strict=False)
            print("Loaded model with strict=False")
        model.eval()
        cache[model_path] = model.to(DEVICE)

    # Keep the legacy global in sync for any external readers.
    current_cyclegan_model = cache[model_path]
    return current_cyclegan_model

@torch.inference_mode()
def predict_depth(image, encoder):
    """Predict depth for an RGB frame and return it as a uint8 map in [0, 255].

    Bug fix: the original divided by (max - min) unconditionally, so a
    constant depth map (max == min) produced a divide-by-zero and NaNs that
    were then cast to uint8; a flat map now yields all zeros instead.
    """
    model = load_depth_model(encoder)
    depth = model.infer_image(image)
    d_min, d_max = depth.min(), depth.max()
    if d_max > d_min:
        depth = (depth - d_min) / (d_max - d_min) * 255.0
    else:
        depth = np.zeros_like(depth)
    return depth.astype(np.uint8)

def apply_winter_colormap(depth_map):
    """Colorize a single-channel uint8 depth map using OpenCV's winter
    palette (blue-to-teal); returns a BGR image."""
    return cv2.applyColorMap(depth_map, cv2.COLORMAP_WINTER)

def blend_images(original, depth_colored, alpha=0.1):
    """Overlay the original frame on top of the colored depth map.

    Parameters:
    - original: webcam frame (BGR)
    - depth_colored: colorized depth map (BGR), resized to match if needed
    - alpha: overlay weight of the webcam frame (0.0 = depth only,
      1.0 = full-strength webcam added on top of the opaque depth base)

    Returns the blended BGR image.
    """
    height, width = original.shape[:2]
    if depth_colored.shape != original.shape:
        depth_colored = cv2.resize(depth_colored, (width, height))

    # Depth map is the opaque base layer; webcam frame is added with weight alpha.
    return cv2.addWeighted(depth_colored, 1.0, original, alpha, 0)

def preprocess_for_cyclegan(image, original_size=None):
    """Convert an RGB numpy frame into a normalized CycleGAN input tensor.

    Returns (tensor, original_size) where original_size is (width, height),
    derived from *image* when not supplied. The frame is resized so its
    shorter side is 256 and normalized to [-1, 1] (mean/std 0.5 per channel).
    """
    if original_size is None:
        original_size = (image.shape[1], image.shape[0])  # (width, height)

    pipeline = transforms.Compose([
        transforms.Resize(256),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    pil_image = Image.fromarray(image)
    input_tensor = pipeline(pil_image).unsqueeze(0).to(DEVICE)
    return input_tensor, original_size

def postprocess_from_cyclegan(tensor, original_size):
    """Convert a CycleGAN output tensor into a uint8 RGB image.

    *tensor* is a (1, C, H, W) tensor in [-1, 1]; *original_size* is
    (width, height). The output is rescaled to [0, 255] and resized back
    to the original dimensions when they differ.
    """
    out = tensor.squeeze(0).cpu()
    out = ((out + 1) / 2).clamp(0, 1)
    image = (out.permute(1, 2, 0).numpy() * 255).astype(np.uint8)

    target_width, target_height = original_size
    if image.shape[0] != target_height or image.shape[1] != target_width:
        image = cv2.resize(image, original_size)
    return image

@torch.inference_mode()
def apply_cyclegan(image, direction):
    """Run the CycleGAN generator for *direction* on an RGB frame.

    direction is "Depth to Image" (G_A) or anything else (G_B). Returns the
    transformed RGB image at the input's resolution, or None when the
    required checkpoint is missing.
    """
    if direction == "Depth to Image":
        checkpoint = "./checkpoints/depth2image/latest_net_G_A.pth"
    else:
        checkpoint = "./checkpoints/depth2image/latest_net_G_B.pth"

    model = load_cyclegan_model(checkpoint)
    if model is None:
        return None

    # Remember the input resolution so the output can be resized back.
    original_size = (image.shape[1], image.shape[0])  # (width, height)
    input_tensor, _ = preprocess_for_cyclegan(image, original_size)
    output_tensor = model(input_tensor)
    return postprocess_from_cyclegan(output_tensor, original_size)

def process_webcam_with_depth_and_cyclegan(encoder, blend_alpha, cyclegan_direction, enable_cyclegan=True):
    """Process webcam with depth, blend, and optionally apply CycleGAN.

    Captures frames from the default webcam, colorizes their predicted depth,
    blends the raw frame on top, optionally runs a CycleGAN pass, shows a
    preview window, and streams the result to a virtual camera (OBS backend).

    Parameters:
    - encoder: depth model key ('vits' | 'vitb' | 'vitl')
    - blend_alpha: overlay weight of the raw webcam frame (see blend_images)
    - cyclegan_direction: "Depth to Image" or "Image to Depth"
    - enable_cyclegan: when False, only the depth/webcam blend is streamed

    Blocks until the capture fails or 'q' is pressed in the preview window.
    """
    # Open the webcam
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        print("Error: Could not open webcam")
        return
    
    # Read a test frame to get the actual dimensions
    ret, test_frame = cap.read()
    if not ret:
        print("Error: Could not read from webcam")
        return
    
    # Get the actual frame dimensions
    frame_height, frame_width = test_frame.shape[:2]
    print(f"Webcam frame dimensions: {frame_width}x{frame_height}")
    
    # Ensure checkpoints directory exists
    os.makedirs("checkpoints/depth2image", exist_ok=True)
    
    # Create a preview window
    preview_window = "Depth Winter + CycleGAN Preview"
    cv2.namedWindow(preview_window, cv2.WINDOW_NORMAL)
    
    try:
        # Initialize virtual camera with exact frame dimensions
        with pyvirtualcam.Camera(width=frame_width, height=frame_height, fps=30, fmt=PixelFormat.BGR, backend='obs') as cam:
            print(f'Using virtual camera: {cam.device}')
            print(f'Virtual camera dimensions: {cam.width}x{cam.height}')
            
            frame_count = 0
            while True:
                # Capture frame
                ret, frame = cap.read()
                if not ret:
                    break
                
                # Print dimensions occasionally for debugging
                if frame_count % 100 == 0:
                    print(f"Frame {frame_count} dimensions: {frame.shape}")
                frame_count += 1
                
                # Convert BGR to RGB for depth prediction
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                
                # Predict depth
                depth_map = predict_depth(frame_rgb, encoder)
                
                # Apply winter colormap
                depth_colored = apply_winter_colormap(depth_map)
                
                # Blend with original
                blended = blend_images(frame, depth_colored, alpha=blend_alpha)
                
                # Apply CycleGAN if enabled
                if enable_cyclegan:
                    if cyclegan_direction == "Image to Depth":
                        # For Image to Depth, use raw webcam feed (not blended)
                        input_for_gan = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    else:
                        # For Depth to Image, use the blended result
                        input_for_gan = cv2.cvtColor(blended, cv2.COLOR_BGR2RGB)
                    
                    cyclegan_output = apply_cyclegan(input_for_gan, cyclegan_direction)
                    
                    if cyclegan_output is not None:
                        # Convert RGB back to BGR for virtual cam
                        output = cv2.cvtColor(cyclegan_output, cv2.COLOR_RGB2BGR)
                    else:
                        # CycleGAN checkpoint missing: fall back to the blend.
                        output = blended
                else:
                    output = blended
                
                # Ensure output has the exact dimensions expected by the virtual camera
                if output.shape[0] != frame_height or output.shape[1] != frame_width:
                    print(f"Resizing output from {output.shape[1]}x{output.shape[0]} to {frame_width}x{frame_height}")
                    output = cv2.resize(output, (frame_width, frame_height))
                
                # Show preview
                cv2.imshow(preview_window, output)
                
                # Send to virtual camera
                # A send failure is logged but does not stop the loop, so the
                # preview window keeps running even if the virtual cam drops.
                try:
                    cam.send(output)
                    cam.sleep_until_next_frame()
                except Exception as e:
                    print(f"Error sending to virtual camera: {e}")
                    print(f"Output shape: {output.shape}, Expected: {frame_height}x{frame_width}x3")
                
                # Press 'q' to exit
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
    
    except Exception as e:
        print(f"Error in webcam processing: {e}")
        import traceback
        traceback.print_exc()
    
    finally:
        # Clean up
        cap.release()
        cv2.destroyAllWindows()

###########################################
# Gradio Interface
###########################################

# Build the Gradio UI: model/blend/CycleGAN controls on the left, a status
# textbox on the right, and a start button wired to the blocking webcam loop.
with gr.Blocks(title="Depth Anything with CycleGAN") as demo:
    gr.Markdown("# Depth Anything V2 with Winter Colormap + CycleGAN")
    
    with gr.Row():
        with gr.Column():
            model_dropdown = gr.Dropdown(
                choices=list(encoder2name.values()),
                value="Small",
                label="Select Depth Model Size"
            )
            
            blend_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.1,  # Set default to 0.1 (10% webcam opacity)
                step=0.1,
                label="Webcam Overlay Opacity (0 = depth only, 1 = full webcam overlay)"
            )
            
            cyclegan_toggle = gr.Checkbox(
                value=True,
                label="Enable CycleGAN Transformation"
            )
            
            cyclegan_direction = gr.Radio(
                choices=["Depth to Image", "Image to Depth"],
                value="Depth to Image",
                label="CycleGAN Direction"
            )
            
            start_button = gr.Button("Start Processing", variant="primary")
        
        with gr.Column():
            output_status = gr.Textbox(
                label="Status",
                value="Ready to start...",
                interactive=False
            )
    
    # Instructions
    gr.Markdown("""
    ### Instructions:
    1. Select the depth model size (smaller models are faster but less accurate)
    2. Adjust the blend strength between the original webcam feed and the winter-colored depth map
    3. Enable/disable CycleGAN transformation
    4. Select the CycleGAN conversion direction
    5. Click "Start Processing" to begin the virtual camera feed
    6. A preview window will open - press 'q' in that window to stop processing
    
    **Note:** You'll need to have pyvirtualcam installed and a virtual camera device 
    (like OBS Virtual Camera) configured on your system.
    """)
    
    def start_processing(model_name, blend_alpha, enable_cyclegan, cyclegan_dir):
        """Map the UI model name back to its encoder key and run the
        (blocking) webcam loop; returns a status string for the textbox."""
        encoder = {v: k for k, v in encoder2name.items()}[model_name]
        try:
            process_webcam_with_depth_and_cyclegan(
                encoder, 
                blend_alpha, 
                cyclegan_dir,
                enable_cyclegan
            )
            return "Processing completed. (If this message appears immediately, check for errors in the console)"
        except Exception as e:
            import traceback
            traceback.print_exc()
            return f"Error: {str(e)}"
    
    start_button.click(
        fn=start_processing,
        inputs=[model_dropdown, blend_slider, cyclegan_toggle, cyclegan_direction],
        outputs=output_status
    )

# Launch the Gradio app when run as a script (no-op when imported).
if __name__ == "__main__":
    demo.launch()