import torch
import gradio as gr
import imageio
import numpy as np
from PIL import Image
from torchvision.transforms import ToTensor, Resize
import spaces
import tempfile
from scipy.ndimage import gaussian_filter
from aura_sr import AuraSR
import cv2
import torch.nn.functional as F

# Load AuraSR-v2 model once at startup
aura_sr = AuraSR.from_pretrained("fal/AuraSR-v2")


# Post-processing functions
def apply_lens_distortion(image, k1=0.2):
    """Apply lens distortion using OpenCV."""
    h, w = image.shape[:2]
    camera_matrix = np.array([[w, 0, w / 2], [0, w, h / 2], [0, 0, 1]], dtype=np.float32)
    dist_coeffs = np.array([k1, 0, 0, 0], dtype=np.float32)
    # cv2.undistort warps the frame as if correcting radial distortion k1,
    # which produces the lens-distortion look.
    distorted = cv2.undistort(image, camera_matrix, dist_coeffs)
    return distorted


def apply_depth_of_field(image_tensor, depth_tensor, focus_depth=0.5, blur_size=5):
    """Apply depth-of-field blur using PyTorch.

    Pixels whose depth is within 0.1 of `focus_depth` stay sharp; all other
    pixels are blended toward a box-blurred copy of the image.
    """
    depth_diff = torch.abs(depth_tensor - focus_depth)
    pad = blur_size // 2
    padded_image = F.pad(image_tensor.unsqueeze(0), (pad, pad, pad, pad), mode='reflect')
    # Depthwise box blur: one (1, 1, k, k) averaging kernel per RGB channel,
    # so the weight shape is (3, 1, k, k) to match groups=3.
    kernel = torch.ones(3, 1, blur_size, blur_size, device=image_tensor.device) / (blur_size ** 2)
    blurred = F.conv2d(padded_image, kernel, groups=3).squeeze(0)
    mask = (depth_diff < 0.1).float()
    return image_tensor * mask + blurred * (1 - mask)


def apply_vignette(image):
    """Apply a vignette (darkened corners) effect."""
    h, w = image.shape[:2]
    x, y = np.meshgrid(np.arange(w), np.arange(h))
    center_x, center_y = w / 2, h / 2
    radius = np.sqrt((x - center_x) ** 2 + (y - center_y) ** 2)
    max_radius = np.sqrt(center_x ** 2 + center_y ** 2)
    vignette = 1 - (radius / max_radius) ** 2
    vignette = np.clip(vignette, 0, 1)
    return (image * vignette[..., np.newaxis]).astype(np.uint8)


def parse_keyframes(keyframe_text):
    """Parse keyframe text into (time, x, y) tuples sorted by time."""
    keyframes = []
    try:
        for entry in keyframe_text.split():
            time, pos = entry.split(':')
            x, y = map(float, pos.split(','))
            keyframes.append((float(time), x, y))
        keyframes.sort()  # Sort by time
        return keyframes
    except Exception:
        return [(0, 0, 0), (1, 0, 0)]  # Default fallback for malformed input


def interpolate_keyframes(t, keyframes):
    """Linearly interpolate the camera position between keyframes."""
    if t <= keyframes[0][0]:
        return keyframes[0][1], keyframes[0][2]
    if t >= keyframes[-1][0]:
        return keyframes[-1][1], keyframes[-1][2]
    for i in range(len(keyframes) - 1):
        t1, x1, y1 = keyframes[i]
        t2, x2, y2 = keyframes[i + 1]
        if t1 <= t <= t2:
            alpha = (t - t1) / (t2 - t1)
            return x1 + alpha * (x2 - x1), y1 + alpha * (y2 - y1)
    return 0, 0  # Fallback
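# Illustrative note (not part of the app logic): the "custom" style expects
# space-separated "time:x,y" keyframes with time normalized to [0, 1].
# For example, the UI's default string
#
#     "0:0,0 0.5:5,0 1:0,0"
#
# parses to [(0.0, 0.0, 0.0), (0.5, 5.0, 0.0), (1.0, 0.0, 0.0)]: the camera
# pans to x = 5 at the midpoint and returns, with positions between
# keyframes blended linearly by interpolate_keyframes().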
@spaces.GPU
def generate_parallax_video(image, depth_map, animation_style, amplitude, k, fps, duration,
                            ssaa_factor, use_taa, use_upscale, apply_lens, apply_dof, apply_vig,
                            keyframe_text):
    """Generate a 3D parallax video with advanced features."""
    # Validate input dimensions
    if image.size != depth_map.size:
        raise ValueError("Image and depth map must have the same dimensions")

    # Convert to tensors and normalize depth to [0, 1]
    image_tensor = ToTensor()(image).to('cuda', dtype=torch.float32)
    depth_tensor = ToTensor()(depth_map.convert('L')).to('cuda', dtype=torch.float32)
    depth_tensor = (depth_tensor - depth_tensor.min()) / (depth_tensor.max() - depth_tensor.min() + 1e-6)

    # Smooth the depth map
    depth_np = depth_tensor.squeeze().cpu().numpy()
    depth_np = gaussian_filter(depth_np, sigma=1)
    depth_tensor = torch.tensor(depth_np, device='cuda', dtype=torch.float32).unsqueeze(0)

    # Keep an original-resolution copy of the depth for depth of field, since
    # the warped frame is downsampled back to this size before DOF is applied
    base_depth_tensor = depth_tensor

    # Apply SSAA (supersampling: render larger, then downsample)
    if ssaa_factor > 1:
        upscale = Resize((int(image.height * ssaa_factor), int(image.width * ssaa_factor)), antialias=True)
        image_tensor = upscale(image_tensor)
        depth_tensor = upscale(depth_tensor)

    H, W = image_tensor.shape[1], image_tensor.shape[2]
    x = torch.arange(W, device='cuda', dtype=torch.float32)
    y = torch.arange(H, device='cuda', dtype=torch.float32)
    xx, yy = torch.meshgrid(x, y, indexing='xy')
    pixel_grid = torch.stack((xx, yy), dim=-1)

    # Parse keyframes for the custom path
    keyframes = parse_keyframes(keyframe_text) if animation_style == "custom" else None

    # Generate frames: each pixel is shifted in proportion to its depth,
    # so near pixels move more than far ones as the virtual camera moves
    num_frames = int(fps * duration)
    frames = []
    prev_frame = None
    for frame in range(num_frames):
        t = frame / num_frames
        if animation_style == "zoom":
            zoom_factor = 1 + amplitude * np.sin(2 * np.pi * t)
            displacement_x = (pixel_grid[:, :, 0] - W / 2) * (1 - zoom_factor) * depth_tensor.squeeze()
            displacement_y = (pixel_grid[:, :, 1] - H / 2) * (1 - zoom_factor) * depth_tensor.squeeze()
        elif animation_style == "horizontal":
            camera_x = amplitude * np.sin(2 * np.pi * t)
            displacement_x = k * camera_x * depth_tensor.squeeze()
            displacement_y = 0
        elif animation_style == "vertical":
            camera_y = amplitude * np.sin(2 * np.pi * t)
            displacement_x = 0
            displacement_y = k * camera_y * depth_tensor.squeeze()
        elif animation_style == "circle":
            camera_x = amplitude * np.sin(2 * np.pi * t)
            camera_y = amplitude * np.cos(2 * np.pi * t)
            displacement_x = k * camera_x * depth_tensor.squeeze()
            displacement_y = k * camera_y * depth_tensor.squeeze()
        elif animation_style == "spiral":
            radius = amplitude * (1 - t)
            camera_x = radius * np.sin(4 * np.pi * t)
            camera_y = radius * np.cos(4 * np.pi * t)
            displacement_x = k * camera_x * depth_tensor.squeeze()
            displacement_y = k * camera_y * depth_tensor.squeeze()
        elif animation_style == "custom":
            camera_x, camera_y = interpolate_keyframes(t, keyframes)
            displacement_x = k * camera_x * depth_tensor.squeeze()
            displacement_y = k * camera_y * depth_tensor.squeeze()
        else:
            raise ValueError(f"Unsupported animation style: {animation_style}")

        source_pixel_x = pixel_grid[:, :, 0] + displacement_x
        source_pixel_y = pixel_grid[:, :, 1] + displacement_y

        # Normalize to [-1, 1] for grid_sample
        grid_x = 2 * source_pixel_x / (W - 1) - 1
        grid_y = 2 * source_pixel_y / (H - 1) - 1
        grid = torch.stack((grid_x, grid_y), dim=-1).unsqueeze(0)

        # Warp the image by resampling at the displaced coordinates
        warped = F.grid_sample(image_tensor.unsqueeze(0), grid, mode='bicubic', align_corners=True)

        # Downsample back to the original size if SSAA was used
        if ssaa_factor > 1:
            downscale = Resize((image.height, image.width), antialias=True)
            warped = downscale(warped.squeeze(0)).unsqueeze(0)

        # Apply depth of field if enabled
        if apply_dof:
            warped = apply_depth_of_field(warped.squeeze(0), base_depth_tensor.squeeze(0))

        # Convert to a uint8 numpy frame (clamp first: bicubic sampling can overshoot [0, 1])
        frame_img = warped.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().numpy()
        frame_img = (frame_img * 255).astype(np.uint8)

        # Apply lens distortion if enabled
        if apply_lens:
            frame_img = apply_lens_distortion(frame_img)

        # Apply vignette if enabled
        if apply_vig:
            frame_img = apply_vignette(frame_img)

        # Apply upscaling if enabled
        if use_upscale:
            frame_pil = Image.fromarray(frame_img)
            frame_pil = aura_sr.upscale_4x_overlapped(frame_pil)
            frame_img = np.array(frame_pil)

        # Apply TAA (blend with the previous frame) if enabled
        if use_taa and prev_frame is not None:
            frame_img = (frame_img * 0.8 + prev_frame * 0.2).astype(np.uint8)

        frames.append(frame_img)
        prev_frame = frame_img.copy() if use_taa else None

    # Save video
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
        output_path = tmpfile.name
    writer = imageio.get_writer(output_path, fps=fps, codec='libx264')
    for frame in frames:
        writer.append_data(frame)
    writer.close()

    return output_path
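# Minimal programmatic usage sketch, kept as a comment so the Space still
# launches straight into the UI below. The file names are placeholder
# assumptions, not files shipped with this app:
#
#     img = Image.open("input.jpg").convert("RGB")
#     depth = Image.open("depth.png")  # must match img's width and height
#     video_path = generate_parallax_video(
#         img, depth, "circle", amplitude=2, k=5, fps=30, duration=5,
#         ssaa_factor=1, use_taa=False, use_upscale=False,
#         apply_lens=False, apply_dof=False, apply_vig=False,
#         keyframe_text="0:0,0 1:0,0",
#     )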
# Gradio interface
with gr.Blocks(title="Ultimate 3D Parallax Video Generator") as demo:
    gr.Markdown("# Ultimate 3D Parallax Video Generator")
    gr.Markdown("Generate high-quality 3D parallax videos with advanced features, post-processing, and custom paths.")

    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Image")
        depth_input = gr.Image(type="pil", label="Upload Depth Map")

    with gr.Row():
        animation_style = gr.Dropdown(
            ["zoom", "horizontal", "vertical", "circle", "spiral", "custom"],
            label="Animation Style",
            value="horizontal"
        )
        amplitude_slider = gr.Slider(0, 10, value=2, label="Amplitude", step=0.1)
        k_slider = gr.Slider(1, 20, value=5, label="Depth Scale (k)", step=0.1)
        fps_slider = gr.Slider(10, 60, value=30, label="FPS", step=1)
        duration_slider = gr.Slider(1, 10, value=5, label="Duration (s)", step=0.1)

    with gr.Row():
        ssaa_factor = gr.Dropdown([1, 2, 4], label="SSAA Factor", value=1)
        use_taa = gr.Checkbox(label="Enable TAA", value=False)
        use_upscale = gr.Checkbox(label="Enable AuraSR-v2 Upscaling", value=False)
        apply_lens = gr.Checkbox(label="Apply Lens Distortion", value=False)
        apply_dof = gr.Checkbox(label="Apply Depth of Field", value=False)
        apply_vig = gr.Checkbox(label="Apply Vignette", value=False)

    with gr.Row():
        keyframe_text = gr.Textbox(
            label="Custom Keyframes (time:x,y)",
            value="0:0,0 0.5:5,0 1:0,0",
            placeholder="e.g., 0:0,0 0.5:5,0 1:0,0",
            visible=True
        )

    generate_btn = gr.Button("Generate Video")
    video_output = gr.Video(label="Parallax Video")

    generate_btn.click(
        fn=generate_parallax_video,
        inputs=[
            image_input, depth_input, animation_style, amplitude_slider, k_slider,
            fps_slider, duration_slider, ssaa_factor, use_taa, use_upscale,
            apply_lens, apply_dof, apply_vig, keyframe_text
        ],
        outputs=video_output
    )

demo.launch()