Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import gradio as gr
|
3 |
+
import imageio
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
from torchvision.transforms import ToTensor
|
7 |
+
import spaces
|
8 |
+
import tempfile
|
9 |
+
|
10 |
+
@spaces.GPU
def generate_parallax_video(image, depth_map, T_max=10, k=50, fps=30, duration=5):
    """
    Generate a 3D parallax video from an image and its depth map.

    Every frame warps the image horizontally: each output pixel samples the
    source at ``x - k * T * depth``, where ``depth`` is the depth map
    min-max normalised to [0, 1] and ``T`` follows one full sine cycle of
    amplitude ``T_max`` over the clip. Samples that fall outside the source
    image use grid_sample's default padding.

    Parameters:
    - image (PIL.Image): Input image (converted to RGB).
    - depth_map (PIL.Image): Depth map (converted to grayscale).
    - T_max (float): Maximum camera translation amplitude.
    - k (float): Depth displacement scale factor.
    - fps (int): Frames per second.
    - duration (int): Video duration in seconds.

    Returns:
    - str: Path to the generated video file (a temporary .mp4 that is not
      deleted automatically).

    Raises:
    - ValueError: If an input is missing, the two sizes differ, or
      ``fps * duration`` yields no frames.
    """
    # Validate inputs (the UI passes None when nothing was uploaded)
    if image is None or depth_map is None:
        raise ValueError("Both an image and a depth map are required")
    if image.size != depth_map.size:
        raise ValueError("Image and depth map must be the same size")

    num_frames = int(fps * duration)
    if num_frames <= 0:
        raise ValueError("fps * duration must yield at least one frame")

    # Prefer the GPU, but keep working on CPU-only hosts
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Convert to PyTorch tensors; forcing RGB keeps the output at 3 channels
    # even for RGBA, palette or grayscale uploads
    image_tensor = ToTensor()(image.convert('RGB')).unsqueeze(0).to(device)  # Shape: (1, 3, H, W)
    depth_tensor = ToTensor()(depth_map.convert('L')).to(device)  # Shape: (1, H, W)
    depth_tensor = (depth_tensor - depth_tensor.min()) / (depth_tensor.max() - depth_tensor.min() + 1e-6)
    depth_tensor = depth_tensor[0]  # Shape: (H, W)

    H, W = image_tensor.shape[2], image_tensor.shape[3]

    # Create base pixel grid. With indexing='ij' the first argument spans
    # dim 0, so y must come first for both outputs to be (H, W).
    x = torch.arange(W, dtype=torch.float32, device=device)
    y = torch.arange(H, dtype=torch.float32, device=device)
    yy, xx = torch.meshgrid(y, x, indexing='ij')

    # Normalisation to [-1, 1] for grid_sample; max(..., 1) avoids a zero
    # divisor when a dimension is a single pixel. The y coordinates never
    # change, so they are computed once.
    x_scale = 2.0 / max(W - 1, 1)
    grid_y = 2 * yy / max(H - 1, 1) - 1

    # Generate frames
    frames = []
    for frame in range(num_frames):
        # Simulate horizontal camera movement (plain float, not a NumPy scalar)
        T = float(T_max * np.sin(2 * np.pi * frame / num_frames))
        displacement = k * T * depth_tensor  # Shape: (H, W)

        # Source x coordinates, normalised for grid_sample
        grid_x = (xx - displacement) * x_scale - 1
        grid = torch.stack((grid_x, grid_y), dim=-1).unsqueeze(0)  # Shape: (1, H, W, 2)

        # Warp the image
        warped = torch.nn.functional.grid_sample(image_tensor, grid, align_corners=True)

        # Convert to numpy for video
        warped_np = warped.squeeze(0).permute(1, 2, 0).cpu().numpy()  # Shape: (H, W, 3)
        frames.append((warped_np * 255).astype(np.uint8))

    # Save video to a temporary file; delete=False so the path outlives the
    # handle and can be returned to the caller
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
        output_path = tmpfile.name
    imageio.mimwrite(output_path, frames, fps=fps, quality=8)

    return output_path
|
76 |
+
|
77 |
+
# Gradio UI: two image uploads, three tuning sliders, and a button that
# renders the parallax clip into the video player.
with gr.Blocks(title="3D Parallax Video Generator") as demo:
    gr.Markdown("# 3D Parallax Video Generator")
    gr.Markdown("Upload an image and its depth map to create a 5-second 3D parallax video.")

    # Source picture and its depth map, side by side
    with gr.Row():
        image_input = gr.Image(label="Upload Image", type="pil")
        depth_input = gr.Image(label="Upload Depth Map", type="pil")

    # Motion and playback controls
    with gr.Row():
        T_max_slider = gr.Slider(
            label="Camera Amplitude (T_max)", minimum=1, maximum=50, value=10
        )
        k_slider = gr.Slider(
            label="Depth Scale (k)", minimum=1, maximum=100, value=50
        )
        fps_slider = gr.Slider(
            label="Frames Per Second", minimum=10, maximum=60, value=30
        )

    generate_btn = gr.Button("Generate Video")
    video_output = gr.Video(label="Parallax Video")

    # Order matches the leading positional parameters of
    # generate_parallax_video; duration is not exposed and keeps its default.
    click_inputs = [image_input, depth_input, T_max_slider, k_slider, fps_slider]
    generate_btn.click(
        fn=generate_parallax_video,
        inputs=click_inputs,
        outputs=video_output,
    )

demo.launch()
|