yuyutsu07 commited on
Commit
2552c4b
·
verified ·
1 Parent(s): a984096

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -0
app.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ import imageio
4
+ import numpy as np
5
+ from PIL import Image
6
+ from torchvision.transforms import ToTensor
7
+ import spaces
8
+ import tempfile
9
+
10
+ @spaces.GPU
11
+ def generate_parallax_video(image, depth_map, T_max=10, k=50, fps=30, duration=5):
12
+ """
13
+ Generate a 5-second 3D parallax video from an image and depth map.
14
+
15
+ Parameters:
16
+ - image (PIL.Image): Input image.
17
+ - depth_map (PIL.Image): Depth map (grayscale).
18
+ - T_max (float): Maximum camera translation amplitude.
19
+ - k (float): Depth displacement scale factor.
20
+ - fps (int): Frames per second.
21
+ - duration (int): Video duration in seconds.
22
+
23
+ Returns:
24
+ - str: Path to the generated video file.
25
+ """
26
+ # Validate input sizes
27
+ if image.size != depth_map.size:
28
+ raise ValueError("Image and depth map must be the same size")
29
+
30
+ # Convert to PyTorch tensors and move to GPU
31
+ image_tensor = ToTensor()(image).unsqueeze(0).to('cuda') # Shape: (1, 3, H, W)
32
+ depth_tensor = ToTensor()(depth_map.convert('L')).to('cuda') # Shape: (1, 1, H, W)
33
+ depth_tensor = (depth_tensor - depth_tensor.min()) / (depth_tensor.max() - depth_tensor.min() + 1e-6)
34
+ depth_tensor = depth_tensor.squeeze(0).squeeze(0) # Shape: (H, W)
35
+
36
+ H, W = image_tensor.shape[2], image_tensor.shape[3]
37
+
38
+ # Create base pixel grid
39
+ x = torch.arange(0, W).float().to('cuda')
40
+ y = torch.arange(0, H).float().to('cuda')
41
+ xx, yy = torch.meshgrid(x, y, indexing='ij')
42
+ pixel_grid = torch.stack((xx, yy), dim=-1) # Shape: (H, W, 2)
43
+
44
+ # Generate frames
45
+ num_frames = int(fps * duration)
46
+ frames = []
47
+
48
+ for frame in range(num_frames):
49
+ # Simulate horizontal camera movement
50
+ T = T_max * np.sin(2 * np.pi * frame / num_frames)
51
+ displacement = k * T * depth_tensor # Shape: (H, W)
52
+
53
+ # Compute source coordinates
54
+ source_pixel_x = pixel_grid[:, :, 0] - displacement
55
+ source_pixel_y = pixel_grid[:, :, 1]
56
+
57
+ # Normalize to [-1, 1] for grid_sample
58
+ grid_x = 2 * source_pixel_x / (W - 1) - 1
59
+ grid_y = 2 * source_pixel_y / (H - 1) - 1
60
+ grid = torch.stack((grid_x, grid_y), dim=-1).unsqueeze(0) # Shape: (1, H, W, 2)
61
+
62
+ # Warp the image
63
+ warped = torch.nn.functional.grid_sample(image_tensor, grid, align_corners=True)
64
+
65
+ # Convert to numpy for video
66
+ warped_np = warped.squeeze(0).permute(1, 2, 0).cpu().numpy() # Shape: (H, W, 3)
67
+ frame_img = (warped_np * 255).astype(np.uint8)
68
+ frames.append(frame_img)
69
+
70
+ # Save video to a temporary file
71
+ with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
72
+ output_path = tmpfile.name
73
+ imageio.mimwrite(output_path, frames, fps=fps, quality=8)
74
+
75
+ return output_path
76
+
77
+ # Gradio interface
78
+ with gr.Blocks(title="3D Parallax Video Generator") as demo:
79
+ gr.Markdown("# 3D Parallax Video Generator")
80
+ gr.Markdown("Upload an image and its depth map to create a 5-second 3D parallax video.")
81
+
82
+ with gr.Row():
83
+ image_input = gr.Image(type="pil", label="Upload Image")
84
+ depth_input = gr.Image(type="pil", label="Upload Depth Map")
85
+
86
+ with gr.Row():
87
+ T_max_slider = gr.Slider(minimum=1, maximum=50, value=10, label="Camera Amplitude (T_max)")
88
+ k_slider = gr.Slider(minimum=1, maximum=100, value=50, label="Depth Scale (k)")
89
+ fps_slider = gr.Slider(minimum=10, maximum=60, value=30, label="Frames Per Second")
90
+
91
+ generate_btn = gr.Button("Generate Video")
92
+ video_output = gr.Video(label="Parallax Video")
93
+
94
+ generate_btn.click(
95
+ fn=generate_parallax_video,
96
+ inputs=[image_input, depth_input, T_max_slider, k_slider, fps_slider],
97
+ outputs=video_output
98
+ )
99
+
100
+ demo.launch()