yuyutsu07 committed on
Commit
d494365
·
verified ·
1 Parent(s): 16f4e59

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +149 -111
app.py CHANGED
@@ -7,144 +7,182 @@ from torchvision.transforms import ToTensor, Resize
7
  import spaces
8
  import tempfile
9
  from scipy.ndimage import gaussian_filter
10
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  @spaces.GPU
12
- def generate_parallax_video(image, depth_map, animation_style, amplitude, k, fps, duration, ssaa_factor, use_taa):
13
- """
14
- Generate a 3D parallax video with enhanced quality features.
15
-
16
- Args:
17
- image (PIL.Image): Input RGB image.
18
- depth_map (PIL.Image): Grayscale depth map.
19
- animation_style (str): Animation type (e.g., horizontal, spiral).
20
- amplitude (float): Camera movement intensity.
21
- k (float): Depth displacement scale.
22
- fps (int): Frames per second.
23
- duration (float): Video duration in seconds.
24
- ssaa_factor (int): Super sampling factor (1, 2, 4).
25
- use_taa (bool): Enable temporal anti-aliasing.
26
-
27
- Returns:
28
- str: Path to the generated video file.
29
- """
30
- # Validate input dimensions
31
  if image.size != depth_map.size:
32
- raise ValueError("Image and depth map must have the same dimensions")
33
 
34
- # Convert to tensors with high precision
35
  image_tensor = ToTensor()(image).to('cuda', dtype=torch.float32)
36
  depth_tensor = ToTensor()(depth_map.convert('L')).to('cuda', dtype=torch.float32)
37
  depth_tensor = (depth_tensor - depth_tensor.min()) / (depth_tensor.max() - depth_tensor.min() + 1e-6)
38
 
39
- # Smooth depth map to improve intersections
40
- depth_np = depth_tensor.squeeze().cpu().numpy()
41
- depth_np = gaussian_filter(depth_np, sigma=1) # Basic smoothing
42
- depth_tensor = torch.tensor(depth_np, device='cuda', dtype=torch.float32).unsqueeze(0)
43
 
44
- # Apply SSAA: upscale image and depth map
45
  if ssaa_factor > 1:
46
  upscale = Resize((int(image.height * ssaa_factor), int(image.width * ssaa_factor)), antialias=True)
47
  image_tensor = upscale(image_tensor)
48
  depth_tensor = upscale(depth_tensor)
49
 
50
  H, W = image_tensor.shape[1], image_tensor.shape[2]
 
 
51
 
52
- # Create coordinate grid
53
- x = torch.arange(0, W).float().to('cuda')
54
- y = torch.arange(0, H).float().to('cuda')
55
- xx, yy = torch.meshgrid(x, y, indexing='xy')
56
- pixel_grid = torch.stack((xx, yy), dim=-1)
57
-
58
- # Generate frames
59
  num_frames = int(fps * duration)
60
  frames = []
61
  prev_frame = None
62
 
63
- for frame in range(num_frames):
64
- t = frame / num_frames
65
- if animation_style == "horizontal":
66
- camera_x = amplitude * np.sin(2 * np.pi * t)
67
- camera_y = 0
68
- elif animation_style == "vertical":
69
- camera_x = 0
70
- camera_y = amplitude * np.sin(2 * np.pi * t)
71
- elif animation_style == "circle":
72
- camera_x = amplitude * np.sin(2 * np.pi * t)
73
- camera_y = amplitude * np.cos(2 * np.pi * t)
74
- elif animation_style == "spiral": # Inspired by DepthFlow
75
- radius = amplitude * (1 - t)
76
- camera_x = radius * np.sin(4 * np.pi * t)
77
- camera_y = radius * np.cos(4 * np.pi * t)
78
- else:
79
- raise ValueError(f"Unsupported animation style: {animation_style}")
80
-
81
- # Compute displacements
82
  displacement_x = k * camera_x * depth_tensor.squeeze()
83
  displacement_y = k * camera_y * depth_tensor.squeeze()
84
-
85
- # Calculate source coordinates
86
- source_pixel_x = pixel_grid[:, :, 0] + displacement_x
87
- source_pixel_y = pixel_grid[:, :, 1] + displacement_y
88
-
89
- # Normalize to [-1, 1]
90
- grid_x = 2 * source_pixel_x / (W - 1) - 1
91
- grid_y = 2 * source_pixel_y / (H - 1) - 1
92
- grid = torch.stack((grid_x, grid_y), dim=-1).unsqueeze(0)
93
-
94
- # Warp with high-quality interpolation
95
- warped = torch.nn.functional.grid_sample(image_tensor.unsqueeze(0), grid, mode='bicubic', align_corners=True)
96
-
97
- # Downsample if SSAA is enabled
98
- if ssaa_factor > 1:
99
- downscale = Resize((image.height, image.width), antialias=True)
100
- warped = downscale(warped.squeeze(0)).unsqueeze(0)
101
-
102
- # Convert to numpy
103
- frame_img = warped.squeeze(0).permute(1, 2, 0).cpu().numpy()
104
- frame_img = (frame_img * 255).astype(np.uint8)
105
-
106
- # Apply TAA if enabled
107
- if use_taa and prev_frame is not None:
108
- frame_img = (frame_img * 0.8 + prev_frame * 0.2).astype(np.uint8)
109
-
110
  frames.append(frame_img)
111
- prev_frame = frame_img
112
-
113
- # Save video
114
- with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
115
- output_path = tmpfile.name
116
- writer = imageio.get_writer(output_path, fps=fps, codec='libx264')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  for frame in frames:
118
  writer.append_data(frame)
119
  writer.close()
 
120
 
121
- return output_path
122
-
123
- # Gradio interface
124
- with gr.Blocks(title="Enhanced 3D Parallax Video Generator") as demo:
125
- gr.Markdown("# Enhanced 3D Parallax Video Generator")
126
- gr.Markdown("Create high-quality 3D parallax videos with advanced features.")
127
-
128
  with gr.Row():
129
- image_input = gr.Image(type="pil", label="Upload Image")
130
- depth_input = gr.Image(type="pil", label="Upload Depth Map")
131
-
 
 
 
 
 
 
 
 
 
 
 
 
132
  with gr.Row():
133
- animation_style = gr.Dropdown(["horizontal", "vertical", "circle", "spiral"], label="Animation Style", value="horizontal")
134
- amplitude_slider = gr.Slider(0, 10, value=2, label="Amplitude", step=0.1)
135
- k_slider = gr.Slider(1, 20, value=5, label="Depth Scale (k)", step=0.1)
136
- fps_slider = gr.Slider(10, 60, value=30, label="FPS", step=1)
137
- duration_slider = gr.Slider(1, 10, value=5, label="Duration (s)", step=0.1)
138
- ssaa_factor = gr.Dropdown([1, 2, 4], label="SSAA Factor", value=1)
139
- use_taa = gr.Checkbox(label="Enable TAA", value=False)
140
-
141
- generate_btn = gr.Button("Generate Video")
142
- video_output = gr.Video(label="Parallax Video")
143
-
144
- generate_btn.click(
145
- fn=generate_parallax_video,
146
- inputs=[image_input, depth_input, animation_style, amplitude_slider, k_slider, fps_slider, duration_slider, ssaa_factor, use_taa],
147
- outputs=video_output
148
- )
149
 
150
  demo.launch()
 
7
  import spaces
8
  import tempfile
9
  from scipy.ndimage import gaussian_filter
10
+ from huggingface_hub import hf_hub_download
11
+ from safetensors.torch import load_file
12
+
13
+ # ------------------------- AuraSR Model Definition ------------------------- #
14
class ResBlock(torch.nn.Module):
    """Residual unit: two 3x3 convolutions with an identity skip connection."""

    def __init__(self, n_filters):
        super().__init__()
        # A 3x3 kernel with padding=1 keeps the spatial size unchanged, so the
        # skip connection in forward() adds tensors of identical shape.
        conv_kwargs = {"kernel_size": 3, "padding": 1}
        self.conv1 = torch.nn.Conv2d(n_filters, n_filters, **conv_kwargs)
        self.conv2 = torch.nn.Conv2d(n_filters, n_filters, **conv_kwargs)

    def forward(self, x):
        """Return ``conv2(relu(conv1(x))) + x``."""
        out = torch.nn.functional.relu(self.conv1(x))
        out = self.conv2(out)
        # In-place add on the freshly computed conv output; the caller's
        # input tensor is never modified.
        out.add_(x)
        return out
26
+
27
class AuraSR(torch.nn.Module):
    """Residual CNN that enlarges a batch of RGB images by ``scale``.

    Layout: a 3x3 head convolution, ``n_blocks`` residual blocks, then a tail
    that expands channels, rearranges them into space with ``PixelShuffle``
    and projects back to 3 channels.

    Args:
        scale: Spatial upscaling factor of the output relative to the input.
        n_filters: Number of feature channels used by the head and body.
        n_blocks: Number of ``ResBlock`` units in the body.
    """

    def __init__(self, scale=4, n_filters=64, n_blocks=8):
        super().__init__()
        self.scale = scale
        self.head = torch.nn.Conv2d(3, n_filters, 3, padding=1)
        self.body = torch.nn.Sequential(*[ResBlock(n_filters) for _ in range(n_blocks)])
        self.tail = torch.nn.Sequential(
            # PixelShuffle(scale) consumes scale**2 channels per output channel.
            torch.nn.Conv2d(n_filters, n_filters * (scale ** 2), 3, padding=1),
            torch.nn.PixelShuffle(scale),
            torch.nn.Conv2d(n_filters, 3, 3, padding=1),
        )

    def forward(self, x):
        """Map an ``(N, 3, H, W)`` batch to ``(N, 3, H*scale, W*scale)``."""
        # The PixelShuffle in the tail already enlarges by `scale`. The
        # previous version also nearest-upsampled the input by `scale` first,
        # which produced scale**2 (16x by default) instead of the advertised
        # 4x and ran every convolution on a 16x larger feature map.
        x = self.head(x)
        x = self.body(x)
        return self.tail(x)
45
+
46
# Fetch the AuraSR-v2 checkpoint from the Hugging Face Hub and build the
# module-level upscaler used by apply_upscaler().
# NOTE(review): load_state_dict() is strict by default, so the parameter names
# stored in this checkpoint must match the AuraSR class defined in this file
# (head / body / tail). That cannot be confirmed from here; verify.
model_path = hf_hub_download(filename="model.safetensors", repo_id="fal/AuraSR-v2")
state_dict = load_file(model_path)
upscaler_model = AuraSR()
upscaler_model.eval()  # inference only
upscaler_model = upscaler_model.to('cuda')
upscaler_model.load_state_dict(state_dict)
51
+
52
+ # ------------------------- Core Parallax Function ------------------------- #
53
  @spaces.GPU
54
def generate_parallax_video(image, depth_map, animation_style, amplitude, k, fps,
                            duration, ssaa_factor, use_taa, use_upscaler):
    """Render a depth-driven parallax animation and return the MP4 path.

    Each frame shifts every pixel by ``k * camera_offset * depth``, where the
    camera offset comes from ``calculate_movement`` and ``depth`` is the
    min-max normalised, Gaussian-smoothed depth map.

    Args:
        image: PIL image to animate.
        depth_map: PIL depth map with the same ``size`` as ``image``; it is
            converted to single-channel ``'L'`` before use.
        animation_style: Style name forwarded to ``calculate_movement``.
        amplitude: Camera movement amplitude forwarded to
            ``calculate_movement``.
        k: Multiplier applied to the depth-weighted displacement.
        fps: Frames per second of the output video.
        duration: Video length in seconds; ``int(fps * duration)`` frames are
            rendered.
        ssaa_factor: Supersampling factor. Values above 1 upscale the image
            and depth map before warping; frames are resized back afterwards.
        use_taa: Blend each frame with the previous one.
        use_upscaler: Pass every finished frame through ``apply_upscaler``.

    Returns:
        Path of the written video file, as returned by ``save_video``.

    Raises:
        ValueError: If an input is missing or the two sizes differ.
    """
    if image is None or depth_map is None:
        raise ValueError("Both an image and a depth map are required")
    if image.size != depth_map.size:
        raise ValueError("Image and depth map dimensions must match")

    # Preprocess inputs
    image_tensor = ToTensor()(image).to('cuda', dtype=torch.float32)
    depth_tensor = ToTensor()(depth_map.convert('L')).to('cuda', dtype=torch.float32)
    depth_tensor = (depth_tensor - depth_tensor.min()) / (depth_tensor.max() - depth_tensor.min() + 1e-6)

    # Apply Gaussian smoothing to the depth map. Index the channel axis
    # explicitly instead of squeeze() so 1-pixel-wide inputs keep two dims.
    depth_np = gaussian_filter(depth_tensor[0].cpu().numpy(), sigma=1)
    depth_tensor = torch.tensor(depth_np, device='cuda').unsqueeze(0)

    # Super Sampling Anti-Aliasing: warp at a higher resolution
    if ssaa_factor > 1:
        upscale = Resize((int(image.height * ssaa_factor), int(image.width * ssaa_factor)), antialias=True)
        image_tensor = upscale(image_tensor)
        depth_tensor = upscale(depth_tensor)

    H, W = image_tensor.shape[1], image_tensor.shape[2]
    x, y = torch.meshgrid(torch.arange(W, device='cuda'), torch.arange(H, device='cuda'), indexing='xy')
    pixel_grid = torch.stack((x, y), dim=-1)

    # The (H, W) depth plane does not change between frames.
    depth_plane = depth_tensor[0]

    # Animation parameters
    num_frames = int(fps * duration)
    frames = []
    prev_frame = None

    for frame_idx in range(num_frames):
        t = frame_idx / num_frames
        camera_x, camera_y = calculate_movement(t, amplitude, animation_style)

        # Calculate displacement
        displacement_x = k * camera_x * depth_plane
        displacement_y = k * camera_y * depth_plane

        # Warp image
        warped = warp_image(image_tensor, pixel_grid, displacement_x, displacement_y, W, H)

        # Post-processing
        frame_img = post_process_frame(warped, ssaa_factor, image.size, use_taa, prev_frame)

        # Remember the frame at native resolution, BEFORE super-resolution.
        # Storing the upscaled frame made the next TAA blend mix two arrays
        # of different sizes whenever both options were enabled.
        prev_frame = frame_img.copy() if use_taa else None

        # Apply super-resolution
        if use_upscaler:
            frame_img = apply_upscaler(frame_img)

        frames.append(frame_img)

    return save_video(frames, fps)
106
+
107
+ # ------------------------- Helper Functions ------------------------- #
108
def calculate_movement(t, amplitude, style):
    """Return the ``(camera_x, camera_y)`` offset for one animation step.

    Args:
        t: Animation progress; one full sine period is covered as ``t`` goes
            from 0 to 1.
        amplitude: Peak camera offset.
        style: One of ``"horizontal"``, ``"vertical"``, ``"circle"`` or
            ``"spiral"``.

    Returns:
        Tuple ``(camera_x, camera_y)``.

    Raises:
        ValueError: If ``style`` is not a supported name. Previously the
            function fell through and returned ``None``, which surfaced as an
            opaque unpacking ``TypeError`` in the caller.
    """
    if style == "horizontal":
        return amplitude * np.sin(2 * np.pi * t), 0
    if style == "vertical":
        return 0, amplitude * np.sin(2 * np.pi * t)
    if style == "circle":
        return amplitude * np.sin(2 * np.pi * t), amplitude * np.cos(2 * np.pi * t)
    if style == "spiral":
        # Two revolutions while the radius shrinks linearly to zero.
        radius = amplitude * (1 - t)
        return radius * np.sin(4 * np.pi * t), radius * np.cos(4 * np.pi * t)
    raise ValueError(f"Unsupported animation style: {style}")
119
+
120
def warp_image(image_tensor, pixel_grid, dx, dy, W, H):
    """Resample ``image_tensor`` at pixel positions shifted by ``(dx, dy)``.

    Args:
        image_tensor: Image without a batch axis; one is added here.
        pixel_grid: Grid whose last axis holds the ``(x, y)`` pixel index.
        dx: Per-pixel horizontal shift, in pixels.
        dy: Per-pixel vertical shift, in pixels.
        W: Image width used to normalise x coordinates.
        H: Image height used to normalise y coordinates.

    Returns:
        The bicubically resampled image with a leading batch axis of size 1.
    """
    # Absolute source coordinates, still in pixel units.
    sample_x = pixel_grid[..., 0] + dx
    sample_y = pixel_grid[..., 1] + dy

    # grid_sample expects coordinates in [-1, 1]; with align_corners=True the
    # extremes map to the centres of the first and last pixel.
    norm_x = 2 * sample_x / (W - 1) - 1
    norm_y = 2 * sample_y / (H - 1) - 1
    sampling_grid = torch.stack((norm_x, norm_y), dim=-1).unsqueeze(0)

    batch = image_tensor.unsqueeze(0)
    return torch.nn.functional.grid_sample(
        batch, sampling_grid, mode='bicubic', align_corners=True
    )
126
+
127
def post_process_frame(warped, ssaa_factor, orig_size, use_taa, prev_frame):
    """Turn a warped image batch into a displayable uint8 frame.

    Args:
        warped: Tensor with a leading batch axis of size 1 followed by
            ``(C, H, W)``, as returned by ``warp_image``.
        ssaa_factor: When greater than 1 the frame is resized back to
            ``orig_size``.
        orig_size: ``(width, height)`` of the original image; reversed here
            because ``Resize`` takes ``(height, width)``.
        use_taa: Blend with ``prev_frame`` (80% current, 20% previous).
        prev_frame: Previous uint8 frame, or ``None`` for the first frame.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8`` indexed ``(H, W, C)``.
    """
    # Drop only the batch axis. A bare squeeze() would also remove a
    # singleton height or width and break the permute below.
    frame_chw = warped.squeeze(0)

    if ssaa_factor > 1:
        frame_chw = Resize(orig_size[::-1], antialias=True)(frame_chw)

    # Bicubic resampling can overshoot the [0, 1] range. Without clamping,
    # the uint8 cast wraps around and produces speckled pixels.
    frame_chw = frame_chw.clamp(0.0, 1.0)
    frame = (frame_chw.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)

    if use_taa and prev_frame is not None:
        frame = cv2.addWeighted(frame, 0.8, prev_frame, 0.2, 0)

    return frame
138
+
139
def apply_upscaler(frame):
    """Run one frame through the module-level ``upscaler_model``.

    Args:
        frame: Array indexed ``(H, W, C)`` with values in 0..255
            (presumably uint8, as produced by ``post_process_frame``).

    Returns:
        The model output as a ``uint8`` array indexed ``(H, W, C)``.
    """
    # HWC in 0..255 -> 1xCxHxW float in 0..1 on the GPU.
    batch = torch.tensor(frame).permute(2, 0, 1).unsqueeze(0).float() / 255.0
    batch = batch.to('cuda')

    with torch.no_grad():
        upscaled = upscaler_model(batch)

    # Clamp before the cast so out-of-range predictions cannot wrap around.
    image = upscaled[0].permute(1, 2, 0).clamp(0, 1)
    return (image.cpu().numpy() * 255).astype(np.uint8)
145
+
146
def save_video(frames, fps):
    """Encode ``frames`` into a temporary H.264 MP4 and return its path.

    Args:
        frames: Iterable of image arrays, appended to the video in order.
        fps: Frame rate passed to the video writer.

    Returns:
        Path of the written ``.mp4`` file. It is created with
        ``delete=False``, so the caller owns its clean-up.
    """
    # Only the unique file name is needed. Release the handle before the
    # encoder opens the same path, so the file is not held open twice.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        output_path = tmp.name

    writer = imageio.get_writer(output_path, fps=fps, codec='libx264', quality=9)
    try:
        for frame in frames:
            writer.append_data(frame)
    finally:
        # Always close the writer, even when a frame fails to encode.
        writer.close()
    return output_path
154
 
155
+ # ------------------------- Gradio Interface ------------------------- #
156
with gr.Blocks(title="3D Parallax Video Generator with Super-Resolution") as demo:
    # Page header.
    gr.Markdown("# 🔥 3D Parallax Video Generator with 4x Super-Resolution")
    gr.Markdown("Generate stunning 3D parallax videos from 2D images with optional AI upscaling")

    # Source image and its depth map, side by side.
    with gr.Row():
        image_input = gr.Image(label="Input Image", type="pil")
        depth_input = gr.Image(label="Depth Map", type="pil")

    # Left column: motion controls. Right column: timing and quality.
    with gr.Row():
        with gr.Column():
            animation_style = gr.Dropdown(
                choices=["horizontal", "vertical", "circle", "spiral"],
                label="Animation Style",
                value="horizontal",
            )
            amplitude = gr.Slider(minimum=0, maximum=10, step=0.1, value=2,
                                  label="Movement Amplitude")
            k = gr.Slider(minimum=0, maximum=20, step=0.1, value=5,
                          label="Depth Scaling Factor")

        with gr.Column():
            fps = gr.Slider(minimum=10, maximum=60, step=1, value=30, label="FPS")
            duration = gr.Slider(minimum=1, maximum=10, step=0.1, value=5,
                                 label="Duration (seconds)")
            ssaa_factor = gr.Dropdown(choices=[1, 2, 4], label="Anti-Aliasing Quality", value=1)

    # Optional processing toggles.
    with gr.Row():
        use_taa = gr.Checkbox(value=False, label="Enable Temporal Anti-Aliasing")
        use_upscaler = gr.Checkbox(value=False, label="Enable 4x Super-Resolution (AuraSR-v2)")

    generate_btn = gr.Button(value="Generate Video", variant="primary")
    video_output = gr.Video(format="mp4", label="Generated Video")

    # The input order must match generate_parallax_video's parameter order.
    generate_btn.click(
        fn=generate_parallax_video,
        inputs=[
            image_input,
            depth_input,
            animation_style,
            amplitude,
            k,
            fps,
            duration,
            ssaa_factor,
            use_taa,
            use_upscaler,
        ],
        outputs=video_output,
    )

demo.launch()