Update app.py
app.py
CHANGED
@@ -7,36 +7,78 @@ from torchvision.transforms import ToTensor, Resize
 import spaces
 import tempfile
 from scipy.ndimage import gaussian_filter
-from aura_sr import AuraSR
+from aura_sr import AuraSR
+import cv2
+import torch.nn.functional as F
 
 # Load AuraSR-v2 model once at startup
 aura_sr = AuraSR.from_pretrained("fal/AuraSR-v2")
 
+# Post-processing functions
+def apply_lens_distortion(image, k1=0.2):
+    """Apply lens distortion using OpenCV."""
+    h, w = image.shape[:2]
+    camera_matrix = np.array([[w, 0, w / 2], [0, w, h / 2], [0, 0, 1]], dtype=np.float32)
+    dist_coeffs = np.array([k1, 0, 0, 0], dtype=np.float32)
+    distorted = cv2.undistort(image, camera_matrix, dist_coeffs)
+    return distorted
+
+def apply_depth_of_field(image_tensor, depth_tensor, focus_depth=0.5, blur_size=5):
+    """Apply depth of field blur using PyTorch."""
+    depth_diff = torch.abs(depth_tensor - focus_depth)
+    padded_image = F.pad(image_tensor, (blur_size // 2, blur_size // 2, blur_size // 2, blur_size // 2), mode='reflect')
+    # Uniform box blur per channel; blended with the sharp image by the focus mask below
+    weight = torch.ones(3, 1, blur_size, blur_size, device='cuda') / (blur_size ** 2)
+    blurred = F.conv2d(padded_image.unsqueeze(0), weight, groups=3).squeeze(0)
+    mask = (depth_diff < 0.1).float()
+    return image_tensor * mask + blurred * (1 - mask)
+
+def apply_vignette(image):
+    """Apply vignette effect."""
+    h, w = image.shape[:2]
+    x, y = np.meshgrid(np.arange(w), np.arange(h))
+    center_x, center_y = w / 2, h / 2
+    radius = np.sqrt((x - center_x) ** 2 + (y - center_y) ** 2)
+    max_radius = np.sqrt(center_x ** 2 + center_y ** 2)
+    vignette = 1 - (radius / max_radius) ** 2
+    vignette = np.clip(vignette, 0, 1)
+    return (image * vignette[..., np.newaxis]).astype(np.uint8)
+
+def parse_keyframes(keyframe_text):
+    """Parse keyframe text into time-position pairs."""
+    keyframes = []
+    try:
+        for entry in keyframe_text.split():
+            time, pos = entry.split(':')
+            x, y = map(float, pos.split(','))
+            keyframes.append((float(time), x, y))
+        keyframes.sort()  # Sort by time
+        return keyframes
+    except (ValueError, IndexError):
+        return [(0, 0, 0), (1, 0, 0)]  # Default fallback
+
+def interpolate_keyframes(t, keyframes):
+    """Interpolate camera position between keyframes."""
+    if t <= keyframes[0][0]:
+        return keyframes[0][1], keyframes[0][2]
+    if t >= keyframes[-1][0]:
+        return keyframes[-1][1], keyframes[-1][2]
+    for i in range(len(keyframes) - 1):
+        t1, x1, y1 = keyframes[i]
+        t2, x2, y2 = keyframes[i + 1]
+        if t1 <= t <= t2:
+            alpha = (t - t1) / (t2 - t1)
+            return x1 + alpha * (x2 - x1), y1 + alpha * (y2 - y1)
+    return 0, 0  # Fallback
+
 @spaces.GPU
-def generate_parallax_video(image, depth_map, animation_style, amplitude, k, fps, duration, ssaa_factor, use_taa, use_upscale):
-    """
-    Generate a 3D parallax video with enhanced quality features and optional upscaling.
-
-    Args:
-        image (PIL.Image): Input RGB image.
-        depth_map (PIL.Image): Grayscale depth map.
-        animation_style (str): Animation type.
-        amplitude (float): Camera movement intensity.
-        k (float): Depth displacement scale.
-        fps (int): Frames per second.
-        duration (float): Video duration in seconds.
-        ssaa_factor (int): Super sampling factor (1, 2, 4).
-        use_taa (bool): Enable temporal anti-aliasing.
-        use_upscale (bool): Enable AuraSR-v2 upscaling for each frame.
-
-    Returns:
-        str: Path to the generated video file.
-    """
+def generate_parallax_video(image, depth_map, animation_style, amplitude, k, fps, duration, ssaa_factor, use_taa, use_upscale, apply_lens, apply_dof, apply_vig, keyframe_text):
+    """Generate a 3D parallax video with advanced features."""
     # Validate input dimensions
     if image.size != depth_map.size:
         raise ValueError("Image and depth map must have the same dimensions")
 
-    # Convert to tensors
+    # Convert to tensors
     image_tensor = ToTensor()(image).to('cuda', dtype=torch.float32)
     depth_tensor = ToTensor()(depth_map.convert('L')).to('cuda', dtype=torch.float32)
     depth_tensor = (depth_tensor - depth_tensor.min()) / (depth_tensor.max() - depth_tensor.min() + 1e-6)
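A note on the post-processing helpers added above (a minimal usage sketch under assumed shapes, not part of the commit): apply_lens_distortion and apply_vignette operate on uint8 H x W x 3 NumPy frames, while apply_depth_of_field expects a float CHW CUDA tensor. Also worth knowing: cv2.undistort with a positive k1 applies the inverse of a barrel-distortion model, so the result reads as a lens-like warp rather than literal barrel distortion.

    import numpy as np

    # Hypothetical frame: the NumPy-side helpers expect uint8 H x W x 3 arrays.
    frame = np.full((480, 640, 3), 128, dtype=np.uint8)
    frame = apply_lens_distortion(frame, k1=0.2)  # lens-like warp via cv2.undistort
    frame = apply_vignette(frame)                 # darkens toward the corners
    print(frame.shape, frame.dtype)               # (480, 640, 3) uint8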
@@ -53,13 +95,14 @@ def generate_parallax_video(image, depth_map, animation_style, amplitude, k, fps
     depth_tensor = upscale(depth_tensor)
 
     H, W = image_tensor.shape[1], image_tensor.shape[2]
-
-    # Create coordinate grid
     x = torch.arange(0, W).float().to('cuda')
     y = torch.arange(0, H).float().to('cuda')
     xx, yy = torch.meshgrid(x, y, indexing='xy')
     pixel_grid = torch.stack((xx, yy), dim=-1)
 
+    # Parse keyframes for custom path
+    keyframes = parse_keyframes(keyframe_text) if animation_style == "custom" else None
+
     # Generate frames
     num_frames = int(fps * duration)
     frames = []
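To see what the keyframe hookup above produces (a worked sketch using the UI's default string): parse_keyframes turns "0:0,0 0.5:5,0 1:0,0" into sorted (time, x, y) tuples, and interpolate_keyframes linearly blends between the two keyframes surrounding t.

    keyframes = parse_keyframes("0:0,0 0.5:5,0 1:0,0")
    # -> [(0.0, 0.0, 0.0), (0.5, 5.0, 0.0), (1.0, 0.0, 0.0)]

    # Halfway toward the middle keyframe, the camera x is halfway to 5.
    camera_x, camera_y = interpolate_keyframes(0.25, keyframes)
    assert (camera_x, camera_y) == (2.5, 0.0)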
@@ -67,27 +110,36 @@ def generate_parallax_video(image, depth_map, animation_style, amplitude, k, fps
 
     for frame in range(num_frames):
         t = frame / num_frames
-        if animation_style == "horizontal":
+        if animation_style == "zoom":
+            zoom_factor = 1 + amplitude * np.sin(2 * np.pi * t)
+            displacement_x = (pixel_grid[:, :, 0] - W / 2) * (1 - zoom_factor) * depth_tensor.squeeze()
+            displacement_y = (pixel_grid[:, :, 1] - H / 2) * (1 - zoom_factor) * depth_tensor.squeeze()
+        elif animation_style == "horizontal":
             camera_x = amplitude * np.sin(2 * np.pi * t)
-            camera_y = 0
+            displacement_x = k * camera_x * depth_tensor.squeeze()
+            displacement_y = 0
         elif animation_style == "vertical":
-            camera_x = 0
             camera_y = amplitude * np.sin(2 * np.pi * t)
+            displacement_x = 0
+            displacement_y = k * camera_y * depth_tensor.squeeze()
         elif animation_style == "circle":
             camera_x = amplitude * np.sin(2 * np.pi * t)
             camera_y = amplitude * np.cos(2 * np.pi * t)
+            displacement_x = k * camera_x * depth_tensor.squeeze()
+            displacement_y = k * camera_y * depth_tensor.squeeze()
         elif animation_style == "spiral":
             radius = amplitude * (1 - t)
             camera_x = radius * np.sin(4 * np.pi * t)
             camera_y = radius * np.cos(4 * np.pi * t)
+            displacement_x = k * camera_x * depth_tensor.squeeze()
+            displacement_y = k * camera_y * depth_tensor.squeeze()
+        elif animation_style == "custom":
+            camera_x, camera_y = interpolate_keyframes(t, keyframes)
+            displacement_x = k * camera_x * depth_tensor.squeeze()
+            displacement_y = k * camera_y * depth_tensor.squeeze()
         else:
             raise ValueError(f"Unsupported animation style: {animation_style}")
 
-        # Compute displacements
-        displacement_x = k * camera_x * depth_tensor.squeeze()
-        displacement_y = k * camera_y * depth_tensor.squeeze()
-
-        # Calculate source coordinates
         source_pixel_x = pixel_grid[:, :, 0] + displacement_x
         source_pixel_y = pixel_grid[:, :, 1] + displacement_y
 
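For intuition about the displacement math above (a worked example with assumed slider values, not from the commit): with the horizontal style, amplitude=2, and k=5, at t=0.25 the camera offset peaks at 2*sin(pi/2) = 2, so a pixel at normalized depth 1.0 shifts by k * camera_x * depth = 10 px while a pixel at depth 0.0 stays put; that depth-dependent shift is what creates the parallax.

    import numpy as np

    t, amplitude, k = 0.25, 2.0, 5.0              # assumed slider values
    camera_x = amplitude * np.sin(2 * np.pi * t)  # = 2.0 at the quarter cycle
    print(k * camera_x * 1.0)  # pixel at normalized depth 1.0 shifts 10.0 px
    print(k * camera_x * 0.0)  # pixel at depth 0.0 does not move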
@@ -96,22 +148,34 @@ def generate_parallax_video(image, depth_map, animation_style, amplitude, k, fps
         grid_y = 2 * source_pixel_y / (H - 1) - 1
         grid = torch.stack((grid_x, grid_y), dim=-1).unsqueeze(0)
 
-        # Warp
+        # Warp image
         warped = torch.nn.functional.grid_sample(image_tensor.unsqueeze(0), grid, mode='bicubic', align_corners=True)
 
-        # Downsample if SSAA
+        # Downsample if SSAA
         if ssaa_factor > 1:
             downscale = Resize((image.height, image.width), antialias=True)
             warped = downscale(warped.squeeze(0)).unsqueeze(0)
 
+        # Apply depth of field if enabled
+        if apply_dof:
+            warped = apply_depth_of_field(warped.squeeze(0), depth_tensor.squeeze(0))
+
-        # Convert to numpy
+        # Convert to numpy
         frame_img = warped.squeeze(0).permute(1, 2, 0).cpu().numpy()
         frame_img = (frame_img * 255).astype(np.uint8)
-        frame_pil = Image.fromarray(frame_img)
 
-        # Apply upscaling if enabled
+        # Apply lens distortion if enabled
+        if apply_lens:
+            frame_img = apply_lens_distortion(frame_img)
+
+        # Apply vignette if enabled
+        if apply_vig:
+            frame_img = apply_vignette(frame_img)
+
+        # Apply upscaling if enabled
         if use_upscale:
-            frame_pil = aura_sr.upscale_4x_overlapped(frame_pil)
+            frame_pil = Image.fromarray(frame_img)
+            frame_pil = aura_sr.upscale_4x_overlapped(frame_pil)
             frame_img = np.array(frame_pil)
 
         # Apply TAA if enabled
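The grid normalization in this hunk follows torch.nn.functional.grid_sample's convention: with align_corners=True, pixel 0 maps to -1 and W-1 (or H-1) maps to +1, so a zero-displacement grid reproduces the input exactly. A minimal CPU sketch (bilinear for simplicity; the app uses bicubic on CUDA):

    import torch
    import torch.nn.functional as F

    H, W = 4, 5
    img = torch.rand(1, 3, H, W)
    ys, xs = torch.meshgrid(torch.arange(H).float(), torch.arange(W).float(), indexing='ij')
    grid = torch.stack((2 * xs / (W - 1) - 1, 2 * ys / (H - 1) - 1), dim=-1).unsqueeze(0)
    out = F.grid_sample(img, grid, mode='bilinear', align_corners=True)
    assert torch.allclose(out, img, atol=1e-6)  # identity warp reproduces the image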
@@ -132,30 +196,51 @@ def generate_parallax_video(image, depth_map, animation_style, amplitude, k, fps
     return output_path
 
 # Gradio interface
-with gr.Blocks(title="
-    gr.Markdown("#
-    gr.Markdown("
+with gr.Blocks(title="Ultimate 3D Parallax Video Generator") as demo:
+    gr.Markdown("# Ultimate 3D Parallax Video Generator")
+    gr.Markdown("Generate high-quality 3D parallax videos with advanced features, post-processing, and custom paths.")
 
     with gr.Row():
         image_input = gr.Image(type="pil", label="Upload Image")
         depth_input = gr.Image(type="pil", label="Upload Depth Map")
 
     with gr.Row():
-        animation_style = gr.Dropdown(["horizontal", "vertical", "circle", "spiral"], label="Animation Style", value="horizontal")
+        animation_style = gr.Dropdown(
+            ["zoom", "horizontal", "vertical", "circle", "spiral", "custom"],
+            label="Animation Style",
+            value="horizontal"
+        )
         amplitude_slider = gr.Slider(0, 10, value=2, label="Amplitude", step=0.1)
         k_slider = gr.Slider(1, 20, value=5, label="Depth Scale (k)", step=0.1)
         fps_slider = gr.Slider(10, 60, value=30, label="FPS", step=1)
         duration_slider = gr.Slider(1, 10, value=5, label="Duration (s)", step=0.1)
+
+    with gr.Row():
         ssaa_factor = gr.Dropdown([1, 2, 4], label="SSAA Factor", value=1)
         use_taa = gr.Checkbox(label="Enable TAA", value=False)
         use_upscale = gr.Checkbox(label="Enable AuraSR-v2 Upscaling", value=False)
+        apply_lens = gr.Checkbox(label="Apply Lens Distortion", value=False)
+        apply_dof = gr.Checkbox(label="Apply Depth of Field", value=False)
+        apply_vig = gr.Checkbox(label="Apply Vignette", value=False)
+
+    with gr.Row():
+        keyframe_text = gr.Textbox(
+            label="Custom Keyframes (time:x,y)",
+            value="0:0,0 0.5:5,0 1:0,0",
+            placeholder="e.g., 0:0,0 0.5:5,0 1:0,0",
+            visible=True
+        )
 
     generate_btn = gr.Button("Generate Video")
     video_output = gr.Video(label="Parallax Video")
 
     generate_btn.click(
         fn=generate_parallax_video,
-        inputs=[image_input, depth_input, animation_style, amplitude_slider, k_slider, fps_slider, duration_slider, ssaa_factor, use_taa, use_upscale],
+        inputs=[
+            image_input, depth_input, animation_style, amplitude_slider, k_slider,
+            fps_slider, duration_slider, ssaa_factor, use_taa, use_upscale,
+            apply_lens, apply_dof, apply_vig, keyframe_text
+        ],
         outputs=video_output
     )
 
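The diff ends at the click wiring, so the Space's entry point is presumably unchanged further down the file. For context, a typical Gradio Space launch looks like this (assumed; not shown in this diff):

    if __name__ == "__main__":
        demo.launch()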