Update app.py
Browse files
app.py
CHANGED
@@ -7,144 +7,182 @@ from torchvision.transforms import ToTensor, Resize
|
|
7 |
import spaces
|
8 |
import tempfile
|
9 |
from scipy.ndimage import gaussian_filter
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
@spaces.GPU
|
12 |
-
def generate_parallax_video(image, depth_map, animation_style, amplitude, k, fps,
|
13 |
-
|
14 |
-
Generate
|
15 |
-
|
16 |
-
Args:
|
17 |
-
image (PIL.Image): Input RGB image.
|
18 |
-
depth_map (PIL.Image): Grayscale depth map.
|
19 |
-
animation_style (str): Animation type (e.g., horizontal, spiral).
|
20 |
-
amplitude (float): Camera movement intensity.
|
21 |
-
k (float): Depth displacement scale.
|
22 |
-
fps (int): Frames per second.
|
23 |
-
duration (float): Video duration in seconds.
|
24 |
-
ssaa_factor (int): Super sampling factor (1, 2, 4).
|
25 |
-
use_taa (bool): Enable temporal anti-aliasing.
|
26 |
-
|
27 |
-
Returns:
|
28 |
-
str: Path to the generated video file.
|
29 |
-
"""
|
30 |
-
# Validate input dimensions
|
31 |
if image.size != depth_map.size:
|
32 |
-
raise ValueError("Image and depth map must
|
33 |
|
34 |
-
#
|
35 |
image_tensor = ToTensor()(image).to('cuda', dtype=torch.float32)
|
36 |
depth_tensor = ToTensor()(depth_map.convert('L')).to('cuda', dtype=torch.float32)
|
37 |
depth_tensor = (depth_tensor - depth_tensor.min()) / (depth_tensor.max() - depth_tensor.min() + 1e-6)
|
38 |
|
39 |
-
#
|
40 |
-
depth_np = depth_tensor.squeeze().cpu().numpy()
|
41 |
-
|
42 |
-
depth_tensor = torch.tensor(depth_np, device='cuda', dtype=torch.float32).unsqueeze(0)
|
43 |
|
44 |
-
#
|
45 |
if ssaa_factor > 1:
|
46 |
upscale = Resize((int(image.height * ssaa_factor), int(image.width * ssaa_factor)), antialias=True)
|
47 |
image_tensor = upscale(image_tensor)
|
48 |
depth_tensor = upscale(depth_tensor)
|
49 |
|
50 |
H, W = image_tensor.shape[1], image_tensor.shape[2]
|
|
|
|
|
51 |
|
52 |
-
#
|
53 |
-
x = torch.arange(0, W).float().to('cuda')
|
54 |
-
y = torch.arange(0, H).float().to('cuda')
|
55 |
-
xx, yy = torch.meshgrid(x, y, indexing='xy')
|
56 |
-
pixel_grid = torch.stack((xx, yy), dim=-1)
|
57 |
-
|
58 |
-
# Generate frames
|
59 |
num_frames = int(fps * duration)
|
60 |
frames = []
|
61 |
prev_frame = None
|
62 |
|
63 |
-
for
|
64 |
-
t =
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
elif animation_style == "vertical":
|
69 |
-
camera_x = 0
|
70 |
-
camera_y = amplitude * np.sin(2 * np.pi * t)
|
71 |
-
elif animation_style == "circle":
|
72 |
-
camera_x = amplitude * np.sin(2 * np.pi * t)
|
73 |
-
camera_y = amplitude * np.cos(2 * np.pi * t)
|
74 |
-
elif animation_style == "spiral": # Inspired by DepthFlow
|
75 |
-
radius = amplitude * (1 - t)
|
76 |
-
camera_x = radius * np.sin(4 * np.pi * t)
|
77 |
-
camera_y = radius * np.cos(4 * np.pi * t)
|
78 |
-
else:
|
79 |
-
raise ValueError(f"Unsupported animation style: {animation_style}")
|
80 |
-
|
81 |
-
# Compute displacements
|
82 |
displacement_x = k * camera_x * depth_tensor.squeeze()
|
83 |
displacement_y = k * camera_y * depth_tensor.squeeze()
|
84 |
-
|
85 |
-
#
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
warped = torch.nn.functional.grid_sample(image_tensor.unsqueeze(0), grid, mode='bicubic', align_corners=True)
|
96 |
-
|
97 |
-
# Downsample if SSAA is enabled
|
98 |
-
if ssaa_factor > 1:
|
99 |
-
downscale = Resize((image.height, image.width), antialias=True)
|
100 |
-
warped = downscale(warped.squeeze(0)).unsqueeze(0)
|
101 |
-
|
102 |
-
# Convert to numpy
|
103 |
-
frame_img = warped.squeeze(0).permute(1, 2, 0).cpu().numpy()
|
104 |
-
frame_img = (frame_img * 255).astype(np.uint8)
|
105 |
-
|
106 |
-
# Apply TAA if enabled
|
107 |
-
if use_taa and prev_frame is not None:
|
108 |
-
frame_img = (frame_img * 0.8 + prev_frame * 0.2).astype(np.uint8)
|
109 |
-
|
110 |
frames.append(frame_img)
|
111 |
-
prev_frame = frame_img
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
for frame in frames:
|
118 |
writer.append_data(frame)
|
119 |
writer.close()
|
|
|
120 |
|
121 |
-
|
122 |
-
|
123 |
-
#
|
124 |
-
|
125 |
-
|
126 |
-
gr.Markdown("Create high-quality 3D parallax videos with advanced features.")
|
127 |
-
|
128 |
with gr.Row():
|
129 |
-
image_input = gr.Image(type="pil", label="
|
130 |
-
depth_input = gr.Image(type="pil", label="
|
131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
with gr.Row():
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
generate_btn.click(
|
145 |
-
fn=generate_parallax_video,
|
146 |
-
inputs=[image_input, depth_input, animation_style, amplitude_slider, k_slider, fps_slider, duration_slider, ssaa_factor, use_taa],
|
147 |
-
outputs=video_output
|
148 |
-
)
|
149 |
|
150 |
demo.launch()
|
|
|
7 |
import spaces
|
8 |
import tempfile
|
9 |
from scipy.ndimage import gaussian_filter
|
10 |
+
from huggingface_hub import hf_hub_download
|
11 |
+
from safetensors.torch import load_file
|
12 |
+
|
13 |
+
# ------------------------- AuraSR Model Definition ------------------------- #
|
14 |
+
class ResBlock(torch.nn.Module):
|
15 |
+
def __init__(self, n_filters):
|
16 |
+
super().__init__()
|
17 |
+
self.conv1 = torch.nn.Conv2d(n_filters, n_filters, 3, padding=1)
|
18 |
+
self.conv2 = torch.nn.Conv2d(n_filters, n_filters, 3, padding=1)
|
19 |
+
|
20 |
+
def forward(self, x):
|
21 |
+
residual = x
|
22 |
+
x = torch.relu(self.conv1(x))
|
23 |
+
x = self.conv2(x)
|
24 |
+
x += residual
|
25 |
+
return x
|
26 |
+
|
27 |
+
class AuraSR(torch.nn.Module):
|
28 |
+
def __init__(self, scale=4, n_filters=64, n_blocks=8):
|
29 |
+
super().__init__()
|
30 |
+
self.scale = scale
|
31 |
+
self.head = torch.nn.Conv2d(3, n_filters, 3, padding=1)
|
32 |
+
self.body = torch.nn.Sequential(*[ResBlock(n_filters) for _ in range(n_blocks)])
|
33 |
+
self.tail = torch.nn.Sequential(
|
34 |
+
torch.nn.Conv2d(n_filters, n_filters * (scale ** 2), 3, padding=1),
|
35 |
+
torch.nn.PixelShuffle(scale),
|
36 |
+
torch.nn.Conv2d(n_filters, 3, 3, padding=1)
|
37 |
+
)
|
38 |
+
|
39 |
+
def forward(self, x):
|
40 |
+
x = torch.nn.functional.interpolate(x, scale_factor=self.scale, mode='nearest')
|
41 |
+
x = self.head(x)
|
42 |
+
x = self.body(x)
|
43 |
+
x = self.tail(x)
|
44 |
+
return x
|
45 |
+
|
46 |
+
# Load AuraSR-v2 model
|
47 |
+
model_path = hf_hub_download(repo_id="fal/AuraSR-v2", filename="model.safetensors")
|
48 |
+
state_dict = load_file(model_path)
|
49 |
+
upscaler_model = AuraSR().eval().to('cuda')
|
50 |
+
upscaler_model.load_state_dict(state_dict)
|
51 |
+
|
52 |
+
# ------------------------- Core Parallax Function ------------------------- #
|
53 |
@spaces.GPU
|
54 |
+
def generate_parallax_video(image, depth_map, animation_style, amplitude, k, fps,
|
55 |
+
duration, ssaa_factor, use_taa, use_upscaler):
|
56 |
+
"""Generate parallax video with optional super-resolution upscaling"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
if image.size != depth_map.size:
|
58 |
+
raise ValueError("Image and depth map dimensions must match")
|
59 |
|
60 |
+
# Preprocess inputs
|
61 |
image_tensor = ToTensor()(image).to('cuda', dtype=torch.float32)
|
62 |
depth_tensor = ToTensor()(depth_map.convert('L')).to('cuda', dtype=torch.float32)
|
63 |
depth_tensor = (depth_tensor - depth_tensor.min()) / (depth_tensor.max() - depth_tensor.min() + 1e-6)
|
64 |
|
65 |
+
# Apply Gaussian smoothing to depth map
|
66 |
+
depth_np = gaussian_filter(depth_tensor.squeeze().cpu().numpy(), sigma=1)
|
67 |
+
depth_tensor = torch.tensor(depth_np, device='cuda').unsqueeze(0)
|
|
|
68 |
|
69 |
+
# Super Sampling Anti-Aliasing
|
70 |
if ssaa_factor > 1:
|
71 |
upscale = Resize((int(image.height * ssaa_factor), int(image.width * ssaa_factor)), antialias=True)
|
72 |
image_tensor = upscale(image_tensor)
|
73 |
depth_tensor = upscale(depth_tensor)
|
74 |
|
75 |
H, W = image_tensor.shape[1], image_tensor.shape[2]
|
76 |
+
x, y = torch.meshgrid(torch.arange(W, device='cuda'), torch.arange(H, device='cuda'), indexing='xy')
|
77 |
+
pixel_grid = torch.stack((x, y), dim=-1)
|
78 |
|
79 |
+
# Animation parameters
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
num_frames = int(fps * duration)
|
81 |
frames = []
|
82 |
prev_frame = None
|
83 |
|
84 |
+
for frame_idx in range(num_frames):
|
85 |
+
t = frame_idx / num_frames
|
86 |
+
camera_x, camera_y = calculate_movement(t, amplitude, animation_style)
|
87 |
+
|
88 |
+
# Calculate displacement
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
displacement_x = k * camera_x * depth_tensor.squeeze()
|
90 |
displacement_y = k * camera_y * depth_tensor.squeeze()
|
91 |
+
|
92 |
+
# Warp image
|
93 |
+
warped = warp_image(image_tensor, pixel_grid, displacement_x, displacement_y, W, H)
|
94 |
+
|
95 |
+
# Post-processing
|
96 |
+
frame_img = post_process_frame(warped, ssaa_factor, image.size, use_taa, prev_frame)
|
97 |
+
|
98 |
+
# Apply super-resolution
|
99 |
+
if use_upscaler:
|
100 |
+
frame_img = apply_upscaler(frame_img)
|
101 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
frames.append(frame_img)
|
103 |
+
prev_frame = frame_img.copy() if use_taa else None
|
104 |
+
|
105 |
+
return save_video(frames, fps)
|
106 |
+
|
107 |
+
# ------------------------- Helper Functions ------------------------- #
|
108 |
+
def calculate_movement(t, amplitude, style):
|
109 |
+
"""Calculate camera movement based on animation style"""
|
110 |
+
if style == "horizontal":
|
111 |
+
return amplitude * np.sin(2*np.pi*t), 0
|
112 |
+
elif style == "vertical":
|
113 |
+
return 0, amplitude * np.sin(2*np.pi*t)
|
114 |
+
elif style == "circle":
|
115 |
+
return amplitude*np.sin(2*np.pi*t), amplitude*np.cos(2*np.pi*t)
|
116 |
+
elif style == "spiral":
|
117 |
+
radius = amplitude * (1 - t)
|
118 |
+
return radius*np.sin(4*np.pi*t), radius*np.cos(4*np.pi*t)
|
119 |
+
|
120 |
+
def warp_image(image_tensor, pixel_grid, dx, dy, W, H):
|
121 |
+
"""Warp image using computed displacements"""
|
122 |
+
source_x = pixel_grid[:, :, 0] + dx
|
123 |
+
source_y = pixel_grid[:, :, 1] + dy
|
124 |
+
grid = torch.stack((2*source_x/(W-1)-1, 2*source_y/(H-1)-1), dim=-1).unsqueeze(0)
|
125 |
+
return torch.nn.functional.grid_sample(image_tensor.unsqueeze(0), grid, mode='bicubic', align_corners=True)
|
126 |
+
|
127 |
+
def post_process_frame(warped, ssaa_factor, orig_size, use_taa, prev_frame):
|
128 |
+
"""Process frame with SSAA and TAA"""
|
129 |
+
if ssaa_factor > 1:
|
130 |
+
warped = Resize(orig_size[::-1], antialias=True)(warped.squeeze(0)).unsqueeze(0)
|
131 |
+
|
132 |
+
frame = (warped.squeeze().permute(1,2,0).cpu().numpy() * 255).astype(np.uint8)
|
133 |
+
|
134 |
+
if use_taa and prev_frame is not None:
|
135 |
+
frame = cv2.addWeighted(frame, 0.8, prev_frame, 0.2, 0)
|
136 |
+
|
137 |
+
return frame
|
138 |
+
|
139 |
+
def apply_upscaler(frame):
|
140 |
+
"""Apply 4x super-resolution using AuraSR-v2"""
|
141 |
+
tensor = torch.tensor(frame).permute(2,0,1).unsqueeze(0).float() / 255.0
|
142 |
+
with torch.no_grad():
|
143 |
+
upscaled = upscaler_model(tensor.to('cuda'))
|
144 |
+
return (upscaled[0].permute(1,2,0).clamp(0,1).cpu().numpy() * 255).astype(np.uint8)
|
145 |
+
|
146 |
+
def save_video(frames, fps):
|
147 |
+
"""Save frames to video file"""
|
148 |
+
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
|
149 |
+
writer = imageio.get_writer(f.name, fps=fps, codec='libx264', quality=9)
|
150 |
for frame in frames:
|
151 |
writer.append_data(frame)
|
152 |
writer.close()
|
153 |
+
return f.name
|
154 |
|
155 |
+
# ------------------------- Gradio Interface ------------------------- #
|
156 |
+
with gr.Blocks(title="3D Parallax Video Generator with Super-Resolution") as demo:
|
157 |
+
gr.Markdown("# 🔥 3D Parallax Video Generator with 4x Super-Resolution")
|
158 |
+
gr.Markdown("Generate stunning 3D parallax videos from 2D images with optional AI upscaling")
|
159 |
+
|
|
|
|
|
160 |
with gr.Row():
|
161 |
+
image_input = gr.Image(type="pil", label="Input Image")
|
162 |
+
depth_input = gr.Image(type="pil", label="Depth Map")
|
163 |
+
|
164 |
+
with gr.Row():
|
165 |
+
with gr.Column():
|
166 |
+
animation_style = gr.Dropdown(["horizontal", "vertical", "circle", "spiral"],
|
167 |
+
value="horizontal", label="Animation Style")
|
168 |
+
amplitude = gr.Slider(0, 10, value=2, step=0.1, label="Movement Amplitude")
|
169 |
+
k = gr.Slider(0, 20, value=5, step=0.1, label="Depth Scaling Factor")
|
170 |
+
|
171 |
+
with gr.Column():
|
172 |
+
fps = gr.Slider(10, 60, value=30, step=1, label="FPS")
|
173 |
+
duration = gr.Slider(1, 10, value=5, step=0.1, label="Duration (seconds)")
|
174 |
+
ssaa_factor = gr.Dropdown([1, 2, 4], value=1, label="Anti-Aliasing Quality")
|
175 |
+
|
176 |
with gr.Row():
|
177 |
+
use_taa = gr.Checkbox(label="Enable Temporal Anti-Aliasing", value=False)
|
178 |
+
use_upscaler = gr.Checkbox(label="Enable 4x Super-Resolution (AuraSR-v2)", value=False)
|
179 |
+
|
180 |
+
generate_btn = gr.Button("Generate Video", variant="primary")
|
181 |
+
video_output = gr.Video(label="Generated Video", format="mp4")
|
182 |
+
|
183 |
+
generate_btn.click(fn=generate_parallax_video,
|
184 |
+
inputs=[image_input, depth_input, animation_style, amplitude, k,
|
185 |
+
fps, duration, ssaa_factor, use_taa, use_upscaler],
|
186 |
+
outputs=video_output)
|
|
|
|
|
|
|
|
|
|
|
|
|
187 |
|
188 |
demo.launch()
|