File size: 16,399 Bytes
a95debb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 |
import gradio as gr
import cv2
import numpy as np
import torch
import sys
import os
import pyvirtualcam
from pyvirtualcam import PixelFormat
from huggingface_hub import hf_hub_download
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
# Path configurations
# Depth-Anything-V2 is not installed as a package; its repository location
# must be supplied via an environment variable and added to sys.path before
# the import below can succeed.
depth_anything_path = os.getenv('DEPTH_ANYTHING_V2_PATH')
if depth_anything_path is None:
    raise ValueError("Environment variable DEPTH_ANYTHING_V2_PATH is not set. Please set it to the path of Depth-Anything-V2")
sys.path.append(depth_anything_path)
from depth_anything_v2.dpt import DepthAnythingV2
# Device selection with MPS support
# Preference order: CUDA GPU > Apple Silicon (MPS) > CPU fallback.
DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f"Using device: {DEVICE}")
###########################################
# CycleGAN Generator Architecture
###########################################
class ResidualBlock(nn.Module):
def __init__(self, channels):
super(ResidualBlock, self).__init__()
self.conv_block = nn.Sequential(
nn.ReflectionPad2d(1),
nn.Conv2d(channels, channels, 3),
nn.InstanceNorm2d(channels),
nn.ReLU(inplace=True),
nn.ReflectionPad2d(1),
nn.Conv2d(channels, channels, 3),
nn.InstanceNorm2d(channels)
)
def forward(self, x):
return x + self.conv_block(x)
class Generator(nn.Module):
    """CycleGAN ResNet generator.

    Architecture: a 7x7 stem (c7s1-64), two stride-2 downsampling convs
    (64 -> 128 -> 256), `n_residual_blocks` residual blocks at the
    bottleneck, two stride-2 transposed-conv upsampling stages
    (256 -> 128 -> 64), and a 7x7 output head with tanh mapping back to
    image space in [-1, 1].
    """

    def __init__(self, input_channels=3, output_channels=3, n_residual_blocks=9):
        super(Generator, self).__init__()
        # Initial 7x7 convolution with reflection padding
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_channels, 64, 7),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
        ]
        # Downsampling: double the channels while halving resolution, twice
        channels = 64
        for _ in range(2):
            layers += [
                nn.Conv2d(channels, channels * 2, 3, stride=2, padding=1),
                nn.InstanceNorm2d(channels * 2),
                nn.ReLU(inplace=True),
            ]
            channels *= 2
        # Transformation: residual blocks at the bottleneck resolution
        layers += [ResidualBlock(channels) for _ in range(n_residual_blocks)]
        # Upsampling: halve the channels while doubling resolution, twice
        for _ in range(2):
            layers += [
                nn.ConvTranspose2d(channels, channels // 2, 3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(channels // 2),
                nn.ReLU(inplace=True),
            ]
            channels //= 2
        # Output head: back to `output_channels`, squashed into [-1, 1]
        layers += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(channels, output_channels, 7),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
###########################################
# Depth Anything Model Functions
###########################################
# Model configurations
# Constructor kwargs for each DepthAnythingV2 encoder variant; unpacked as
# DepthAnythingV2(**model_configs[encoder]) in load_depth_model.
model_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}
}
# Display names shown in the Gradio model-size dropdown (inverted back to
# encoder keys in start_processing).
encoder2name = {
    'vits': 'Small',
    'vitb': 'Base',
    'vitl': 'Large'
}
# Model IDs and filenames for HuggingFace Hub
MODEL_INFO = {
    'vits': {
        'repo_id': 'depth-anything/Depth-Anything-V2-Small',
        'filename': 'depth_anything_v2_vits.pth'
    },
    'vitb': {
        'repo_id': 'depth-anything/Depth-Anything-V2-Base',
        'filename': 'depth_anything_v2_vitb.pth'
    },
    'vitl': {
        'repo_id': 'depth-anything/Depth-Anything-V2-Large',
        'filename': 'depth_anything_v2_vitl.pth'
    }
}
# Global variables for model management
# current_depth_model / current_encoder: cached depth model and its key so
# re-selecting the same size skips a reload.
# current_cyclegan_model: cached CycleGAN generator.
current_depth_model = None
current_encoder = None
current_cyclegan_model = None
def download_model(encoder):
    """Fetch the checkpoint for *encoder* from the HuggingFace Hub.

    Files land under ./checkpoints (hf_hub_download caches, so repeat
    calls are cheap). Returns the local checkpoint path.
    """
    info = MODEL_INFO[encoder]
    return hf_hub_download(
        repo_id=info['repo_id'],
        filename=info['filename'],
        local_dir='checkpoints'
    )
def load_depth_model(encoder):
    """Return a DepthAnythingV2 model for *encoder*, loading on demand.

    The loaded model is cached in module globals; only a change of
    encoder triggers a download and a fresh load onto DEVICE.
    """
    global current_depth_model, current_encoder
    if encoder == current_encoder:
        return current_depth_model
    checkpoint_path = download_model(encoder)
    model = DepthAnythingV2(**model_configs[encoder])
    model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    current_depth_model = model.to(DEVICE).eval()
    current_encoder = encoder
    return current_depth_model
def load_cyclegan_model(model_path):
    """Load (and cache) the CycleGAN generator stored at *model_path*.

    Bug fix: the original cached a single model globally and ignored
    ``model_path`` on subsequent calls, so after the user switched the
    CycleGAN direction the first-loaded generator (G_A vs G_B are
    different files) kept being returned. The cache is now keyed by
    checkpoint path.

    Returns the Generator on DEVICE in eval mode, or None when the
    checkpoint file does not exist.
    """
    global current_cyclegan_model
    # Lazily upgrade the legacy single-model slot (initialized to None at
    # module level) into a path-keyed cache dict.
    if not isinstance(current_cyclegan_model, dict):
        current_cyclegan_model = {}
    if model_path in current_cyclegan_model:
        return current_cyclegan_model[model_path]
    if not os.path.exists(model_path):
        print(f"Error: CycleGAN model file not found at {model_path}")
        return None
    model = Generator()
    print(f"Loading CycleGAN model from {model_path}")
    state_dict = torch.load(model_path, map_location='cpu')
    try:
        model.load_state_dict(state_dict)
    except Exception as e:
        print(f"Warning: {e}")
        # Try loading with strict=False so partially matching checkpoints
        # (e.g. different training-repo key names) still load.
        model.load_state_dict(state_dict, strict=False)
        print("Loaded model with strict=False")
    model.eval()
    current_cyclegan_model[model_path] = model.to(DEVICE)
    return current_cyclegan_model[model_path]
@torch.inference_mode()
def predict_depth(image, encoder):
    """Predict a normalized uint8 depth map for an RGB frame.

    Parameters:
    - image: RGB frame (numpy array) passed to the model's infer_image
    - encoder: depth model key ('vits'/'vitb'/'vitl')

    Returns a uint8 array scaled to [0, 255].

    Bug fix: the original divided by (max - min) with no guard, so a
    perfectly flat depth map (e.g. a covered lens) divided by zero.
    """
    model = load_depth_model(encoder)
    depth = model.infer_image(image)
    depth_min = depth.min()
    span = depth.max() - depth_min
    if span > 0:
        depth = (depth - depth_min) / span * 255.0
    else:
        # Flat map: emit all zeros instead of NaNs from 0/0
        depth = np.zeros_like(depth)
    return depth.astype(np.uint8)
def apply_winter_colormap(depth_map):
    """Colorize a single-channel uint8 depth map.

    Uses OpenCV's WINTER colormap (blue-to-teal gradient) and returns a
    3-channel BGR image.
    """
    return cv2.applyColorMap(depth_map, cv2.COLORMAP_WINTER)
def blend_images(original, depth_colored, alpha=0.1):
    """Overlay the original frame on top of the colored depth map.

    Parameters:
    - original: original webcam frame (BGR format)
    - depth_colored: colorized depth map (BGR format)
    - alpha: overlay strength of the webcam frame (0.0 = depth only,
      1.0 = full-strength webcam on top)

    Returns the blend: depth map as the opaque base layer plus the
    original frame at the requested transparency.
    """
    height, width = original.shape[:2]
    # Match dimensions before blending
    if depth_colored.shape != original.shape:
        depth_colored = cv2.resize(depth_colored, (width, height))
    # Depth map at full opacity, webcam frame scaled by alpha
    return cv2.addWeighted(depth_colored, 1.0, original, alpha, 0)
def preprocess_for_cyclegan(image, original_size=None):
    """Turn an RGB numpy frame into a normalized CycleGAN input tensor.

    Resizes the short side to 256, scales values to [-1, 1], and moves
    the batched tensor to DEVICE. Returns (tensor, original_size) where
    original_size is (width, height) for restoring the resolution later.
    """
    if original_size is None:
        original_size = (image.shape[1], image.shape[0])  # (width, height)
    pipeline = transforms.Compose([
        transforms.Resize(256),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    batched = pipeline(Image.fromarray(image)).unsqueeze(0).to(DEVICE)
    return batched, original_size
def postprocess_from_cyclegan(tensor, original_size):
    """Convert a CycleGAN output tensor back into a uint8 HxWx3 image.

    Maps values from [-1, 1] to [0, 255] and resizes to
    original_size = (width, height) when the generated resolution
    differs from the caller's.
    """
    squeezed = tensor.squeeze(0).cpu()
    scaled = ((squeezed + 1) / 2).clamp(0, 1)
    image = (scaled.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
    width, height = original_size
    # Restore the caller's resolution if the generator changed it
    if image.shape[0] != height or image.shape[1] != width:
        image = cv2.resize(image, original_size)
    return image
@torch.inference_mode()
def apply_cyclegan(image, direction):
    """Run the CycleGAN generator for *direction* on an RGB frame.

    Returns the translated image at the input's resolution, or None when
    the required checkpoint is missing.
    """
    # G_A translates depth -> image; G_B translates image -> depth.
    if direction == "Depth to Image":
        checkpoint = "./checkpoints/depth2image/latest_net_G_A.pth"
    else:
        checkpoint = "./checkpoints/depth2image/latest_net_G_B.pth"
    model = load_cyclegan_model(checkpoint)
    if model is None:
        return None
    # Remember the caller's resolution so the output can be restored
    original_size = (image.shape[1], image.shape[0])  # (width, height)
    input_tensor, _ = preprocess_for_cyclegan(image, original_size)
    output_tensor = model(input_tensor)
    return postprocess_from_cyclegan(output_tensor, original_size)
def process_webcam_with_depth_and_cyclegan(encoder, blend_alpha, cyclegan_direction, enable_cyclegan=True):
    """Stream depth/CycleGAN-processed webcam frames to a virtual camera.

    Runs until the capture ends or the user presses 'q' in the preview
    window. Requires a pyvirtualcam-compatible device (the 'obs' backend
    is requested explicitly below).

    Parameters:
    - encoder: depth model key ('vits'/'vitb'/'vitl')
    - blend_alpha: opacity of the webcam overlay on the depth map
    - cyclegan_direction: "Depth to Image" or "Image to Depth"
    - enable_cyclegan: when False, only the blended depth view is streamed
    """
    # Open the webcam
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        print("Error: Could not open webcam")
        return
    # Read a test frame to get the actual dimensions
    # NOTE(review): this first frame is consumed for sizing only and is
    # never processed or displayed.
    ret, test_frame = cap.read()
    if not ret:
        print("Error: Could not read from webcam")
        return
    # Get the actual frame dimensions
    frame_height, frame_width = test_frame.shape[:2]
    print(f"Webcam frame dimensions: {frame_width}x{frame_height}")
    # Ensure checkpoints directory exists (where CycleGAN weights live)
    os.makedirs("checkpoints/depth2image", exist_ok=True)
    # Create a preview window
    preview_window = "Depth Winter + CycleGAN Preview"
    cv2.namedWindow(preview_window, cv2.WINDOW_NORMAL)
    try:
        # Initialize virtual camera with exact frame dimensions
        with pyvirtualcam.Camera(width=frame_width, height=frame_height, fps=30, fmt=PixelFormat.BGR, backend='obs') as cam:
            print(f'Using virtual camera: {cam.device}')
            print(f'Virtual camera dimensions: {cam.width}x{cam.height}')
            frame_count = 0
            while True:
                # Capture frame
                ret, frame = cap.read()
                if not ret:
                    break
                # Print dimensions occasionally for debugging
                if frame_count % 100 == 0:
                    print(f"Frame {frame_count} dimensions: {frame.shape}")
                frame_count += 1
                # Convert BGR to RGB for depth prediction
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                # Predict depth
                depth_map = predict_depth(frame_rgb, encoder)
                # Apply winter colormap
                depth_colored = apply_winter_colormap(depth_map)
                # Blend with original
                blended = blend_images(frame, depth_colored, alpha=blend_alpha)
                # Apply CycleGAN if enabled
                if enable_cyclegan:
                    if cyclegan_direction == "Image to Depth":
                        # For Image to Depth, use raw webcam feed (not blended)
                        input_for_gan = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    else:
                        # For Depth to Image, use the blended result
                        input_for_gan = cv2.cvtColor(blended, cv2.COLOR_BGR2RGB)
                    cyclegan_output = apply_cyclegan(input_for_gan, cyclegan_direction)
                    if cyclegan_output is not None:
                        # Convert RGB back to BGR for virtual cam
                        output = cv2.cvtColor(cyclegan_output, cv2.COLOR_RGB2BGR)
                    else:
                        # CycleGAN checkpoint missing: fall back to the blend
                        output = blended
                else:
                    output = blended
                # Ensure output has the exact dimensions expected by the virtual camera
                if output.shape[0] != frame_height or output.shape[1] != frame_width:
                    print(f"Resizing output from {output.shape[1]}x{output.shape[0]} to {frame_width}x{frame_height}")
                    output = cv2.resize(output, (frame_width, frame_height))
                # Show preview
                cv2.imshow(preview_window, output)
                # Send to virtual camera; failures are logged but do not
                # stop the loop (the preview keeps running)
                try:
                    cam.send(output)
                    cam.sleep_until_next_frame()
                except Exception as e:
                    print(f"Error sending to virtual camera: {e}")
                    print(f"Output shape: {output.shape}, Expected: {frame_height}x{frame_width}x3")
                # Press 'q' to exit
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
    except Exception as e:
        print(f"Error in webcam processing: {e}")
        import traceback
        traceback.print_exc()
    finally:
        # Clean up capture and windows even after errors
        cap.release()
        cv2.destroyAllWindows()
###########################################
# Gradio Interface
###########################################
with gr.Blocks(title="Depth Anything with CycleGAN") as demo:
    gr.Markdown("# Depth Anything V2 with Winter Colormap + CycleGAN")
    with gr.Row():
        with gr.Column():
            # Shows display names; inverted back to encoder keys in
            # start_processing below.
            model_dropdown = gr.Dropdown(
                choices=list(encoder2name.values()),
                value="Small",
                label="Select Depth Model Size"
            )
            blend_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.1,  # Set default to 0.1 (10% webcam opacity)
                step=0.1,
                label="Webcam Overlay Opacity (0 = depth only, 1 = full webcam overlay)"
            )
            cyclegan_toggle = gr.Checkbox(
                value=True,
                label="Enable CycleGAN Transformation"
            )
            cyclegan_direction = gr.Radio(
                choices=["Depth to Image", "Image to Depth"],
                value="Depth to Image",
                label="CycleGAN Direction"
            )
            start_button = gr.Button("Start Processing", variant="primary")
        with gr.Column():
            output_status = gr.Textbox(
                label="Status",
                value="Ready to start...",
                interactive=False
            )
    # Instructions
    gr.Markdown("""
    ### Instructions:
    1. Select the depth model size (smaller models are faster but less accurate)
    2. Adjust the blend strength between the original webcam feed and the winter-colored depth map
    3. Enable/disable CycleGAN transformation
    4. Select the CycleGAN conversion direction
    5. Click "Start Processing" to begin the virtual camera feed
    6. A preview window will open - press 'q' in that window to stop processing
    **Note:** You'll need to have pyvirtualcam installed and a virtual camera device
    (like OBS Virtual Camera) configured on your system.
    """)

    def start_processing(model_name, blend_alpha, enable_cyclegan, cyclegan_dir):
        # Click handler: blocks for the whole processing session, so the
        # status text only updates after the preview loop exits.
        # Invert display-name -> encoder-key mapping
        encoder = {v: k for k, v in encoder2name.items()}[model_name]
        try:
            process_webcam_with_depth_and_cyclegan(
                encoder,
                blend_alpha,
                cyclegan_dir,
                enable_cyclegan
            )
            return "Processing completed. (If this message appears immediately, check for errors in the console)"
        except Exception as e:
            import traceback
            traceback.print_exc()
            return f"Error: {str(e)}"

    start_button.click(
        fn=start_processing,
        inputs=[model_dropdown, blend_slider, cyclegan_toggle, cyclegan_direction],
        outputs=output_status
    )
if __name__ == "__main__":
    # Launch the Gradio UI; blocks until the server is stopped.
    demo.launch()