"""src/inference/discordDepth2AnythingGAN.py

Streams a webcam feed through Depth Anything V2, blends the winter-colormapped
depth map with the original frame, optionally applies a CycleGAN generator,
and publishes the result to a virtual camera, controlled via a Gradio UI.
"""
import gradio as gr
import cv2
import numpy as np
import torch
import sys
import os
import pyvirtualcam
from pyvirtualcam import PixelFormat
from huggingface_hub import hf_hub_download
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
# Path configuration: the Depth-Anything-V2 repository must be importable
depth_anything_path = os.getenv('DEPTH_ANYTHING_V2_PATH')
if depth_anything_path is None:
    raise ValueError(
        "Environment variable DEPTH_ANYTHING_V2_PATH is not set. "
        "Set it to the root of your local Depth-Anything-V2 repository."
    )
sys.path.append(depth_anything_path)
from depth_anything_v2.dpt import DepthAnythingV2
# Device selection with MPS support
DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f"Using device: {DEVICE}")
###########################################
# CycleGAN Generator Architecture
###########################################
class ResidualBlock(nn.Module):
def __init__(self, channels):
super(ResidualBlock, self).__init__()
self.conv_block = nn.Sequential(
nn.ReflectionPad2d(1),
nn.Conv2d(channels, channels, 3),
nn.InstanceNorm2d(channels),
nn.ReLU(inplace=True),
nn.ReflectionPad2d(1),
nn.Conv2d(channels, channels, 3),
nn.InstanceNorm2d(channels)
)
def forward(self, x):
return x + self.conv_block(x)
class Generator(nn.Module):
def __init__(self, input_channels=3, output_channels=3, n_residual_blocks=9):
super(Generator, self).__init__()
# Initial convolution
model = [
nn.ReflectionPad2d(3),
nn.Conv2d(input_channels, 64, 7),
nn.InstanceNorm2d(64),
nn.ReLU(inplace=True)
]
# Downsampling
in_features = 64
out_features = in_features * 2
for _ in range(2):
model += [
nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
nn.InstanceNorm2d(out_features),
nn.ReLU(inplace=True)
]
in_features = out_features
out_features = in_features * 2
# Residual blocks
for _ in range(n_residual_blocks):
model += [ResidualBlock(in_features)]
# Upsampling
out_features = in_features // 2
for _ in range(2):
model += [
nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
nn.InstanceNorm2d(out_features),
nn.ReLU(inplace=True)
]
in_features = out_features
out_features = in_features // 2
# Output layer
model += [
nn.ReflectionPad2d(3),
nn.Conv2d(64, output_channels, 7),
nn.Tanh()
]
self.model = nn.Sequential(*model)
def forward(self, x):
return self.model(x)
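# Shape sketch for a 3x256x256 input with the defaults above (illustrative,
# derived directly from the layers in Generator.__init__):
#   stem     (1,  64, 256, 256)   ReflectionPad2d(3) + 7x7 conv keeps H/W
#   down x2  (1, 128, 128, 128) -> (1, 256, 64, 64)
#   res  x9  (1, 256, 64, 64)     identity-preserving residual blocks
#   up   x2  (1, 128, 128, 128) -> (1, 64, 256, 256)
#   output   (1,   3, 256, 256)   Tanh -> values in [-1, 1]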
###########################################
# Depth Anything Model Functions
###########################################
# Model configurations
model_configs = {
'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}
}
encoder2name = {
'vits': 'Small',
'vitb': 'Base',
'vitl': 'Large'
}
# Model IDs and filenames for HuggingFace Hub
MODEL_INFO = {
'vits': {
'repo_id': 'depth-anything/Depth-Anything-V2-Small',
'filename': 'depth_anything_v2_vits.pth'
},
'vitb': {
'repo_id': 'depth-anything/Depth-Anything-V2-Base',
'filename': 'depth_anything_v2_vitb.pth'
},
'vitl': {
'repo_id': 'depth-anything/Depth-Anything-V2-Large',
'filename': 'depth_anything_v2_vitl.pth'
}
}
# Global caches so models are loaded once and reused across frames
current_depth_model = None
current_encoder = None
current_cyclegan_model = None
current_cyclegan_path = None
def download_model(encoder):
"""Download the specified model from HuggingFace Hub"""
model_info = MODEL_INFO[encoder]
model_path = hf_hub_download(
repo_id=model_info['repo_id'],
filename=model_info['filename'],
local_dir='checkpoints'
)
return model_path
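# Note: with local_dir set, recent huggingface_hub versions place the file
# directly at checkpoints/<filename>, e.g. checkpoints/depth_anything_v2_vits.pth;
# older versions may use a snapshot/symlink layout instead.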
def load_depth_model(encoder):
"""Load the specified depth model"""
global current_depth_model, current_encoder
if current_encoder != encoder:
model_path = download_model(encoder)
current_depth_model = DepthAnythingV2(**model_configs[encoder])
current_depth_model.load_state_dict(torch.load(model_path, map_location='cpu'))
current_depth_model = current_depth_model.to(DEVICE).eval()
current_encoder = encoder
return current_depth_model
def load_cyclegan_model(model_path):
    """Load (and cache) the CycleGAN generator for the given checkpoint path"""
    global current_cyclegan_model, current_cyclegan_path
    # Cache per checkpoint path so switching CycleGAN direction reloads the
    # matching generator instead of reusing the first one loaded
    if current_cyclegan_model is None or current_cyclegan_path != model_path:
        model = Generator()
        if os.path.exists(model_path):
            print(f"Loading CycleGAN model from {model_path}")
            state_dict = torch.load(model_path, map_location='cpu')
            try:
                model.load_state_dict(state_dict)
            except Exception as e:
                print(f"Warning: {e}")
                # Fall back to a partial load for checkpoints with extra or missing keys
                model.load_state_dict(state_dict, strict=False)
                print("Loaded model with strict=False")
        else:
            print(f"Error: CycleGAN model file not found at {model_path}")
            return None
        model.eval()
        current_cyclegan_model = model.to(DEVICE)
        current_cyclegan_path = model_path
    return current_cyclegan_model
@torch.inference_mode()
def predict_depth(image, encoder):
    """Predict depth using the selected model.

    `image` is a BGR numpy array; DepthAnythingV2.infer_image performs its
    own BGR -> RGB conversion internally.
    """
    model = load_depth_model(encoder)
    depth = model.infer_image(image)
    # Min-max normalize to 8-bit; the epsilon guards against a constant depth map
    depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8) * 255.0
    depth = depth.astype(np.uint8)
    return depth
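# Example: raw model outputs spanning [0.3, 7.5] are rescaled linearly so
# 0.3 -> 0 and 7.5 -> 255, giving an 8-bit map for cv2.applyColorMap below.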
def apply_winter_colormap(depth_map):
"""Apply a winter-themed colormap to the depth map"""
# Use COLORMAP_WINTER for blue to teal colors
depth_colored = cv2.applyColorMap(depth_map, cv2.COLORMAP_WINTER)
return depth_colored
def blend_images(original, depth_colored, alpha=0.1):
"""
Blend the original image on top of the colored depth map
Parameters:
- original: Original webcam frame (BGR format)
- depth_colored: Colorized depth map (BGR format)
- alpha: Blend strength of original webcam (0.0 = depth only, 1.0 = original only)
Returns:
- Blended image where depth map is the base layer and original is overlaid with transparency
"""
# Make sure both images have the same dimensions
if original.shape != depth_colored.shape:
depth_colored = cv2.resize(depth_colored, (original.shape[1], original.shape[0]))
# Start with depth map at 100% opacity as base
# Then add original image on top with specified alpha transparency
result = cv2.addWeighted(depth_colored, 1.0, original, alpha, 0)
return result
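# Worked example: with alpha=0.1, each output pixel is
#   result = 1.0 * depth_colored + 0.1 * original
# cv2.addWeighted saturates the uint8 sum at 255, so bright regions clip
# rather than wrap around.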
def preprocess_for_cyclegan(image, original_size=None):
"""Preprocess image for CycleGAN input"""
# Convert numpy array to PIL Image
image_pil = Image.fromarray(image)
    # Fall back to the image's own dimensions if no size was provided
    if original_size is None:
        original_size = (image.shape[1], image.shape[0])  # (width, height)
# Create transforms
transform = transforms.Compose([
transforms.Resize(256),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
# Process image
input_tensor = transform(image_pil).unsqueeze(0).to(DEVICE)
return input_tensor, original_size
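# Note: transforms.Resize(256) with an int scales the *shorter* edge to 256
# and preserves aspect ratio (e.g. a 640x480 frame becomes 341x256); the
# fully convolutional generator accepts the resulting non-square tensor.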
def postprocess_from_cyclegan(tensor, original_size):
"""Convert CycleGAN output tensor to numpy image with original dimensions"""
tensor = tensor.squeeze(0).cpu()
tensor = (tensor + 1) / 2
tensor = tensor.clamp(0, 1)
tensor = tensor.permute(1, 2, 0).numpy()
# Convert to uint8
image = (tensor * 255).astype(np.uint8)
# Resize back to original dimensions
if image.shape[0] != original_size[1] or image.shape[1] != original_size[0]:
image = cv2.resize(image, original_size)
return image
@torch.inference_mode()
def apply_cyclegan(image, direction):
"""Apply CycleGAN transformation to the image"""
if direction == "Depth to Image":
model_path = "./checkpoints/depth2image/latest_net_G_A.pth"
else:
model_path = "./checkpoints/depth2image/latest_net_G_B.pth"
model = load_cyclegan_model(model_path)
if model is None:
return None
# Save original dimensions
original_size = (image.shape[1], image.shape[0]) # (width, height)
# Preprocess
input_tensor, _ = preprocess_for_cyclegan(image, original_size)
# Generate output
output_tensor = model(input_tensor)
# Postprocess with original size
output_image = postprocess_from_cyclegan(output_tensor, original_size)
return output_image
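# Single-image usage sketch (hypothetical file path; expects an RGB array):
#   rgb = cv2.cvtColor(cv2.imread("depth.png"), cv2.COLOR_BGR2RGB)
#   styled = apply_cyclegan(rgb, "Depth to Image")  # RGB uint8, input-sized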
def process_webcam_with_depth_and_cyclegan(encoder, blend_alpha, cyclegan_direction, enable_cyclegan=True):
"""Process webcam with depth, blend, and optionally apply CycleGAN"""
# Open the webcam
cap = cv2.VideoCapture(0)
if not cap.isOpened():
print("Error: Could not open webcam")
return
# Read a test frame to get the actual dimensions
ret, test_frame = cap.read()
if not ret:
print("Error: Could not read from webcam")
return
# Get the actual frame dimensions
frame_height, frame_width = test_frame.shape[:2]
print(f"Webcam frame dimensions: {frame_width}x{frame_height}")
# Ensure checkpoints directory exists
os.makedirs("checkpoints/depth2image", exist_ok=True)
# Create a preview window
preview_window = "Depth Winter + CycleGAN Preview"
cv2.namedWindow(preview_window, cv2.WINDOW_NORMAL)
try:
# Initialize virtual camera with exact frame dimensions
with pyvirtualcam.Camera(width=frame_width, height=frame_height, fps=30, fmt=PixelFormat.BGR, backend='obs') as cam:
print(f'Using virtual camera: {cam.device}')
print(f'Virtual camera dimensions: {cam.width}x{cam.height}')
frame_count = 0
while True:
# Capture frame
ret, frame = cap.read()
if not ret:
break
# Print dimensions occasionally for debugging
if frame_count % 100 == 0:
print(f"Frame {frame_count} dimensions: {frame.shape}")
frame_count += 1
                # DepthAnythingV2.infer_image expects a BGR frame (it converts
                # to RGB internally), so pass the raw OpenCV frame directly
                depth_map = predict_depth(frame, encoder)
# Apply winter colormap
depth_colored = apply_winter_colormap(depth_map)
# Blend with original
blended = blend_images(frame, depth_colored, alpha=blend_alpha)
# Apply CycleGAN if enabled
if enable_cyclegan:
if cyclegan_direction == "Image to Depth":
# For Image to Depth, use raw webcam feed (not blended)
input_for_gan = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
else:
# For Depth to Image, use the blended result
input_for_gan = cv2.cvtColor(blended, cv2.COLOR_BGR2RGB)
cyclegan_output = apply_cyclegan(input_for_gan, cyclegan_direction)
if cyclegan_output is not None:
# Convert RGB back to BGR for virtual cam
output = cv2.cvtColor(cyclegan_output, cv2.COLOR_RGB2BGR)
else:
output = blended
else:
output = blended
# Ensure output has the exact dimensions expected by the virtual camera
if output.shape[0] != frame_height or output.shape[1] != frame_width:
print(f"Resizing output from {output.shape[1]}x{output.shape[0]} to {frame_width}x{frame_height}")
output = cv2.resize(output, (frame_width, frame_height))
# Show preview
cv2.imshow(preview_window, output)
# Send to virtual camera
try:
cam.send(output)
cam.sleep_until_next_frame()
except Exception as e:
print(f"Error sending to virtual camera: {e}")
print(f"Output shape: {output.shape}, Expected: {frame_height}x{frame_width}x3")
# Press 'q' to exit
if cv2.waitKey(1) & 0xFF == ord('q'):
break
except Exception as e:
print(f"Error in webcam processing: {e}")
import traceback
traceback.print_exc()
finally:
# Clean up
cap.release()
cv2.destroyAllWindows()
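# Headless usage sketch (bypasses the Gradio UI; assumes a virtual camera
# device such as OBS Virtual Camera is installed):
#   process_webcam_with_depth_and_cyclegan('vits', 0.1, "Depth to Image", True)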
###########################################
# Gradio Interface
###########################################
with gr.Blocks(title="Depth Anything with CycleGAN") as demo:
gr.Markdown("# Depth Anything V2 with Winter Colormap + CycleGAN")
with gr.Row():
with gr.Column():
model_dropdown = gr.Dropdown(
choices=list(encoder2name.values()),
value="Small",
label="Select Depth Model Size"
)
blend_slider = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.1, # Set default to 0.1 (10% webcam opacity)
step=0.1,
label="Webcam Overlay Opacity (0 = depth only, 1 = full webcam overlay)"
)
cyclegan_toggle = gr.Checkbox(
value=True,
label="Enable CycleGAN Transformation"
)
cyclegan_direction = gr.Radio(
choices=["Depth to Image", "Image to Depth"],
value="Depth to Image",
label="CycleGAN Direction"
)
start_button = gr.Button("Start Processing", variant="primary")
with gr.Column():
output_status = gr.Textbox(
label="Status",
value="Ready to start...",
interactive=False
)
# Instructions
gr.Markdown("""
### Instructions:
1. Select the depth model size (smaller models are faster but less accurate)
2. Adjust the blend strength between the original webcam feed and the winter-colored depth map
3. Enable/disable CycleGAN transformation
4. Select the CycleGAN conversion direction
5. Click "Start Processing" to begin the virtual camera feed
6. A preview window will open - press 'q' in that window to stop processing
**Note:** You'll need to have pyvirtualcam installed and a virtual camera device
(like OBS Virtual Camera) configured on your system.
""")
def start_processing(model_name, blend_alpha, enable_cyclegan, cyclegan_dir):
encoder = {v: k for k, v in encoder2name.items()}[model_name]
try:
process_webcam_with_depth_and_cyclegan(
encoder,
blend_alpha,
cyclegan_dir,
enable_cyclegan
)
return "Processing completed. (If this message appears immediately, check for errors in the console)"
except Exception as e:
import traceback
traceback.print_exc()
return f"Error: {str(e)}"
start_button.click(
fn=start_processing,
inputs=[model_dropdown, blend_slider, cyclegan_toggle, cyclegan_direction],
outputs=output_status
)
if __name__ == "__main__":
demo.launch()