yuyutsu07 committed
Commit c1d8b1e · verified · 1 Parent(s): 34ee49d

Update app.py

Files changed (1):
  1. app.py (+127 -42)
app.py CHANGED
@@ -7,36 +7,78 @@ from torchvision.transforms import ToTensor, Resize
 import spaces
 import tempfile
 from scipy.ndimage import gaussian_filter
-from aura_sr import AuraSR  # Import AuraSR for upscaling
+from aura_sr import AuraSR
+import cv2
+import torch.nn.functional as F

 # Load AuraSR-v2 model once at startup
 aura_sr = AuraSR.from_pretrained("fal/AuraSR-v2")

+# Post-processing functions
+def apply_lens_distortion(image, k1=0.2):
+    """Apply lens distortion using OpenCV."""
+    h, w = image.shape[:2]
+    camera_matrix = np.array([[w, 0, w / 2], [0, w, h / 2], [0, 0, 1]], dtype=np.float32)
+    dist_coeffs = np.array([k1, 0, 0, 0], dtype=np.float32)
+    distorted = cv2.undistort(image, camera_matrix, dist_coeffs)
+    return distorted
+
+def apply_depth_of_field(image_tensor, depth_tensor, focus_depth=0.5, blur_size=5):
+    """Apply depth of field blur using PyTorch."""
+    depth_diff = torch.abs(depth_tensor - focus_depth)
+    blur_kernel = blur_size * depth_diff.clamp(0, 1)
+    blur_kernel = blur_kernel.unsqueeze(0).unsqueeze(0)
+    padded_image = F.pad(image_tensor, (blur_size // 2, blur_size // 2, blur_size // 2, blur_size // 2), mode='reflect')
+    blurred = F.conv2d(padded_image, torch.ones(1, 1, blur_size, blur_size, device='cuda') / (blur_size ** 2), groups=3)
+    mask = (depth_diff < 0.1).float()
+    return image_tensor * mask + blurred * (1 - mask)
+
+def apply_vignette(image):
+    """Apply vignette effect."""
+    h, w = image.shape[:2]
+    x, y = np.meshgrid(np.arange(w), np.arange(h))
+    center_x, center_y = w / 2, h / 2
+    radius = np.sqrt((x - center_x) ** 2 + (y - center_y) ** 2)
+    max_radius = np.sqrt(center_x ** 2 + center_y ** 2)
+    vignette = 1 - (radius / max_radius) ** 2
+    vignette = np.clip(vignette, 0, 1)
+    return (image * vignette[..., np.newaxis]).astype(np.uint8)
+
+def parse_keyframes(keyframe_text):
+    """Parse keyframe text into time-position pairs."""
+    keyframes = []
+    try:
+        for entry in keyframe_text.split():
+            time, pos = entry.split(':')
+            x, y = map(float, pos.split(','))
+            keyframes.append((float(time), x, y))
+        keyframes.sort()  # Sort by time
+        return keyframes
+    except:
+        return [(0, 0, 0), (1, 0, 0)]  # Default fallback
+
+def interpolate_keyframes(t, keyframes):
+    """Interpolate camera position between keyframes."""
+    if t <= keyframes[0][0]:
+        return keyframes[0][1], keyframes[0][2]
+    if t >= keyframes[-1][0]:
+        return keyframes[-1][1], keyframes[-1][2]
+    for i in range(len(keyframes) - 1):
+        t1, x1, y1 = keyframes[i]
+        t2, x2, y2 = keyframes[i + 1]
+        if t1 <= t <= t2:
+            alpha = (t - t1) / (t2 - t1)
+            return x1 + alpha * (x2 - x1), y1 + alpha * (y2 - y1)
+    return 0, 0  # Fallback
+
 @spaces.GPU
-def generate_parallax_video(image, depth_map, animation_style, amplitude, k, fps, duration, ssaa_factor, use_taa, use_upscale):
-    """
-    Generate a 3D parallax video with enhanced quality features and optional upscaling.
-
-    Args:
-        image (PIL.Image): Input RGB image.
-        depth_map (PIL.Image): Grayscale depth map.
-        animation_style (str): Animation type.
-        amplitude (float): Camera movement intensity.
-        k (float): Depth displacement scale.
-        fps (int): Frames per second.
-        duration (float): Video duration in seconds.
-        ssaa_factor (int): Super sampling factor (1, 2, 4).
-        use_taa (bool): Enable temporal anti-aliasing.
-        use_upscale (bool): Enable AuraSR-v2 upscaling for each frame.
-
-    Returns:
-        str: Path to the generated video file.
-    """
+def generate_parallax_video(image, depth_map, animation_style, amplitude, k, fps, duration, ssaa_factor, use_taa, use_upscale, apply_lens, apply_dof, apply_vig, keyframe_text):
+    """Generate a 3D parallax video with advanced features."""
     # Validate input dimensions
     if image.size != depth_map.size:
         raise ValueError("Image and depth map must have the same dimensions")

-    # Convert to tensors with high precision
+    # Convert to tensors
     image_tensor = ToTensor()(image).to('cuda', dtype=torch.float32)
     depth_tensor = ToTensor()(depth_map.convert('L')).to('cuda', dtype=torch.float32)
     depth_tensor = (depth_tensor - depth_tensor.min()) / (depth_tensor.max() - depth_tensor.min() + 1e-6)
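
A note on the apply_depth_of_field helper added above: F.conv2d is called with a weight of shape (1, 1, blur_size, blur_size) and groups=3, but grouped convolution requires the weight's first dimension to be divisible by groups, so this call raises a RuntimeError on a 3-channel frame (the computed blur_kernel is also never used). A minimal per-channel box blur under the same assumptions (a (C, H, W) float tensor, any device) would be:

    import torch
    import torch.nn.functional as F

    def box_blur(image_tensor, blur_size=5):
        # One uniform kernel per channel: weight (C, 1, k, k) with groups=C.
        c = image_tensor.shape[0]
        weight = torch.ones(c, 1, blur_size, blur_size, device=image_tensor.device) / (blur_size ** 2)
        pad = blur_size // 2
        padded = F.pad(image_tensor.unsqueeze(0), (pad, pad, pad, pad), mode='reflect')
        return F.conv2d(padded, weight, groups=c).squeeze(0)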
@@ -53,13 +95,14 @@ def generate_parallax_video(image, depth_map, animation_style, amplitude, k, fps
     depth_tensor = upscale(depth_tensor)

     H, W = image_tensor.shape[1], image_tensor.shape[2]
-
-    # Create coordinate grid
     x = torch.arange(0, W).float().to('cuda')
     y = torch.arange(0, H).float().to('cuda')
     xx, yy = torch.meshgrid(x, y, indexing='xy')
     pixel_grid = torch.stack((xx, yy), dim=-1)

+    # Parse keyframes for custom path
+    keyframes = parse_keyframes(keyframe_text) if animation_style == "custom" else None
+
     # Generate frames
     num_frames = int(fps * duration)
     frames = []
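
For reference, parse_keyframes expects space-separated time:x,y entries (time on a 0-1 scale), and interpolate_keyframes blends linearly between neighbouring keys. A short sketch using the UI's default string:

    kf = parse_keyframes("0:0,0 0.5:5,0 1:0,0")
    # -> [(0.0, 0.0, 0.0), (0.5, 5.0, 0.0), (1.0, 0.0, 0.0)]
    interpolate_keyframes(0.25, kf)
    # -> (2.5, 0.0), since alpha = (0.25 - 0) / (0.5 - 0) = 0.5 and x = 0 + 0.5 * 5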
@@ -67,27 +110,36 @@ def generate_parallax_video(image, depth_map, animation_style, amplitude, k, fps

     for frame in range(num_frames):
         t = frame / num_frames
-        if animation_style == "horizontal":
+        if animation_style == "zoom":
+            zoom_factor = 1 + amplitude * np.sin(2 * np.pi * t)
+            displacement_x = (pixel_grid[:, :, 0] - W / 2) * (1 - zoom_factor) * depth_tensor.squeeze()
+            displacement_y = (pixel_grid[:, :, 1] - H / 2) * (1 - zoom_factor) * depth_tensor.squeeze()
+        elif animation_style == "horizontal":
             camera_x = amplitude * np.sin(2 * np.pi * t)
-            camera_y = 0
+            displacement_x = k * camera_x * depth_tensor.squeeze()
+            displacement_y = 0
         elif animation_style == "vertical":
-            camera_x = 0
             camera_y = amplitude * np.sin(2 * np.pi * t)
+            displacement_x = 0
+            displacement_y = k * camera_y * depth_tensor.squeeze()
         elif animation_style == "circle":
             camera_x = amplitude * np.sin(2 * np.pi * t)
             camera_y = amplitude * np.cos(2 * np.pi * t)
+            displacement_x = k * camera_x * depth_tensor.squeeze()
+            displacement_y = k * camera_y * depth_tensor.squeeze()
         elif animation_style == "spiral":
             radius = amplitude * (1 - t)
             camera_x = radius * np.sin(4 * np.pi * t)
             camera_y = radius * np.cos(4 * np.pi * t)
+            displacement_x = k * camera_x * depth_tensor.squeeze()
+            displacement_y = k * camera_y * depth_tensor.squeeze()
+        elif animation_style == "custom":
+            camera_x, camera_y = interpolate_keyframes(t, keyframes)
+            displacement_x = k * camera_x * depth_tensor.squeeze()
+            displacement_y = k * camera_y * depth_tensor.squeeze()
         else:
             raise ValueError(f"Unsupported animation style: {animation_style}")

-        # Compute displacements
-        displacement_x = k * camera_x * depth_tensor.squeeze()
-        displacement_y = k * camera_y * depth_tensor.squeeze()
-
-        # Calculate source coordinates
         source_pixel_x = pixel_grid[:, :, 0] + displacement_x
         source_pixel_y = pixel_grid[:, :, 1] + displacement_y
 
@@ -96,22 +148,34 @@ def generate_parallax_video(image, depth_map, animation_style, amplitude, k, fps
         grid_y = 2 * source_pixel_y / (H - 1) - 1
         grid = torch.stack((grid_x, grid_y), dim=-1).unsqueeze(0)

-        # Warp with bicubic interpolation
+        # Warp image
         warped = torch.nn.functional.grid_sample(image_tensor.unsqueeze(0), grid, mode='bicubic', align_corners=True)

-        # Downsample if SSAA is enabled
+        # Downsample if SSAA
         if ssaa_factor > 1:
             downscale = Resize((image.height, image.width), antialias=True)
             warped = downscale(warped.squeeze(0)).unsqueeze(0)

-        # Convert to PIL image for upscaling or further processing
+        # Apply depth of field if enabled
+        if apply_dof:
+            warped = apply_depth_of_field(warped.squeeze(0), depth_tensor.squeeze(0))
+
+        # Convert to numpy
         frame_img = warped.squeeze(0).permute(1, 2, 0).cpu().numpy()
         frame_img = (frame_img * 255).astype(np.uint8)
-        frame_pil = Image.fromarray(frame_img)

-        # Apply AuraSR-v2 upscaling if enabled
+        # Apply lens distortion if enabled
+        if apply_lens:
+            frame_img = apply_lens_distortion(frame_img)
+
+        # Apply vignette if enabled
+        if apply_vig:
+            frame_img = apply_vignette(frame_img)
+
+        # Apply upscaling if enabled
         if use_upscale:
-            frame_pil = aura_sr.upscale_4x_overlapped(frame_pil)  # 4x upscaling
+            frame_pil = Image.fromarray(frame_img)
+            frame_pil = aura_sr.upscale_4x_overlapped(frame_pil)
             frame_img = np.array(frame_pil)

         # Apply TAA if enabled
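
One caveat in this hunk: bicubic warping can overshoot the [0, 1] range, and (frame_img * 255).astype(np.uint8) wraps out-of-range values rather than saturating them. A clamped variant of the conversion (same variables as in the loop) would be:

    frame_img = warped.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().numpy()
    frame_img = (frame_img * 255).astype(np.uint8)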
@@ -132,30 +196,51 @@ def generate_parallax_video(image, depth_map, animation_style, amplitude, k, fps
     return output_path

 # Gradio interface
-with gr.Blocks(title="Enhanced 3D Parallax Video Generator with Upscaling") as demo:
-    gr.Markdown("# Enhanced 3D Parallax Video Generator with Upscaling")
-    gr.Markdown("Create high-quality 3D parallax videos with advanced features and optional AuraSR-v2 upscaling.")
+with gr.Blocks(title="Ultimate 3D Parallax Video Generator") as demo:
+    gr.Markdown("# Ultimate 3D Parallax Video Generator")
+    gr.Markdown("Generate high-quality 3D parallax videos with advanced features, post-processing, and custom paths.")

     with gr.Row():
         image_input = gr.Image(type="pil", label="Upload Image")
         depth_input = gr.Image(type="pil", label="Upload Depth Map")

     with gr.Row():
-        animation_style = gr.Dropdown(["horizontal", "vertical", "circle", "spiral"], label="Animation Style", value="horizontal")
+        animation_style = gr.Dropdown(
+            ["zoom", "horizontal", "vertical", "circle", "spiral", "custom"],
+            label="Animation Style",
+            value="horizontal"
+        )
         amplitude_slider = gr.Slider(0, 10, value=2, label="Amplitude", step=0.1)
         k_slider = gr.Slider(1, 20, value=5, label="Depth Scale (k)", step=0.1)
         fps_slider = gr.Slider(10, 60, value=30, label="FPS", step=1)
         duration_slider = gr.Slider(1, 10, value=5, label="Duration (s)", step=0.1)
+
+    with gr.Row():
         ssaa_factor = gr.Dropdown([1, 2, 4], label="SSAA Factor", value=1)
         use_taa = gr.Checkbox(label="Enable TAA", value=False)
         use_upscale = gr.Checkbox(label="Enable AuraSR-v2 Upscaling", value=False)
+        apply_lens = gr.Checkbox(label="Apply Lens Distortion", value=False)
+        apply_dof = gr.Checkbox(label="Apply Depth of Field", value=False)
+        apply_vig = gr.Checkbox(label="Apply Vignette", value=False)
+
+    with gr.Row():
+        keyframe_text = gr.Textbox(
+            label="Custom Keyframes (time:x,y)",
+            value="0:0,0 0.5:5,0 1:0,0",
+            placeholder="e.g., 0:0,0 0.5:5,0 1:0,0",
+            visible=True
+        )

     generate_btn = gr.Button("Generate Video")
     video_output = gr.Video(label="Parallax Video")

     generate_btn.click(
         fn=generate_parallax_video,
-        inputs=[image_input, depth_input, animation_style, amplitude_slider, k_slider, fps_slider, duration_slider, ssaa_factor, use_taa, use_upscale],
+        inputs=[
+            image_input, depth_input, animation_style, amplitude_slider, k_slider,
+            fps_slider, duration_slider, ssaa_factor, use_taa, use_upscale,
+            apply_lens, apply_dof, apply_vig, keyframe_text
+        ],
         outputs=video_output
     )
 
 