Update app.py
Browse files
app.py
CHANGED
@@ -7,182 +7,156 @@ from torchvision.transforms import ToTensor, Resize
|
|
7 |
import spaces
|
8 |
import tempfile
|
9 |
from scipy.ndimage import gaussian_filter
|
10 |
-
from
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
def __init__(self, n_filters):
|
16 |
-
super().__init__()
|
17 |
-
self.conv1 = torch.nn.Conv2d(n_filters, n_filters, 3, padding=1)
|
18 |
-
self.conv2 = torch.nn.Conv2d(n_filters, n_filters, 3, padding=1)
|
19 |
-
|
20 |
-
def forward(self, x):
|
21 |
-
residual = x
|
22 |
-
x = torch.relu(self.conv1(x))
|
23 |
-
x = self.conv2(x)
|
24 |
-
x += residual
|
25 |
-
return x
|
26 |
-
|
27 |
-
class AuraSR(torch.nn.Module):
|
28 |
-
def __init__(self, scale=4, n_filters=64, n_blocks=8):
|
29 |
-
super().__init__()
|
30 |
-
self.scale = scale
|
31 |
-
self.head = torch.nn.Conv2d(3, n_filters, 3, padding=1)
|
32 |
-
self.body = torch.nn.Sequential(*[ResBlock(n_filters) for _ in range(n_blocks)])
|
33 |
-
self.tail = torch.nn.Sequential(
|
34 |
-
torch.nn.Conv2d(n_filters, n_filters * (scale ** 2), 3, padding=1),
|
35 |
-
torch.nn.PixelShuffle(scale),
|
36 |
-
torch.nn.Conv2d(n_filters, 3, 3, padding=1)
|
37 |
-
)
|
38 |
-
|
39 |
-
def forward(self, x):
|
40 |
-
x = torch.nn.functional.interpolate(x, scale_factor=self.scale, mode='nearest')
|
41 |
-
x = self.head(x)
|
42 |
-
x = self.body(x)
|
43 |
-
x = self.tail(x)
|
44 |
-
return x
|
45 |
-
|
46 |
-
# Load AuraSR-v2 model
|
47 |
-
model_path = hf_hub_download(repo_id="fal/AuraSR-v2", filename="model.safetensors")
|
48 |
-
state_dict = load_file(model_path)
|
49 |
-
upscaler_model = AuraSR().eval().to('cuda')
|
50 |
-
upscaler_model.load_state_dict(state_dict)
|
51 |
-
|
52 |
-
# ------------------------- Core Parallax Function ------------------------- #
|
53 |
@spaces.GPU
|
54 |
-
def generate_parallax_video(image, depth_map, animation_style, amplitude, k, fps,
|
55 |
-
|
56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
if image.size != depth_map.size:
|
58 |
-
raise ValueError("Image and depth map
|
59 |
|
60 |
-
#
|
61 |
image_tensor = ToTensor()(image).to('cuda', dtype=torch.float32)
|
62 |
depth_tensor = ToTensor()(depth_map.convert('L')).to('cuda', dtype=torch.float32)
|
63 |
depth_tensor = (depth_tensor - depth_tensor.min()) / (depth_tensor.max() - depth_tensor.min() + 1e-6)
|
64 |
|
65 |
-
#
|
66 |
-
depth_np =
|
67 |
-
|
|
|
68 |
|
69 |
-
#
|
70 |
if ssaa_factor > 1:
|
71 |
upscale = Resize((int(image.height * ssaa_factor), int(image.width * ssaa_factor)), antialias=True)
|
72 |
image_tensor = upscale(image_tensor)
|
73 |
depth_tensor = upscale(depth_tensor)
|
74 |
|
75 |
H, W = image_tensor.shape[1], image_tensor.shape[2]
|
76 |
-
x, y = torch.meshgrid(torch.arange(W, device='cuda'), torch.arange(H, device='cuda'), indexing='xy')
|
77 |
-
pixel_grid = torch.stack((x, y), dim=-1)
|
78 |
|
79 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
num_frames = int(fps * duration)
|
81 |
frames = []
|
82 |
prev_frame = None
|
83 |
|
84 |
-
for
|
85 |
-
t =
|
86 |
-
|
87 |
-
|
88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
displacement_x = k * camera_x * depth_tensor.squeeze()
|
90 |
displacement_y = k * camera_y * depth_tensor.squeeze()
|
91 |
-
|
92 |
-
#
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
frames.append(frame_img)
|
103 |
prev_frame = frame_img.copy() if use_taa else None
|
104 |
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
"""Calculate camera movement based on animation style"""
|
110 |
-
if style == "horizontal":
|
111 |
-
return amplitude * np.sin(2*np.pi*t), 0
|
112 |
-
elif style == "vertical":
|
113 |
-
return 0, amplitude * np.sin(2*np.pi*t)
|
114 |
-
elif style == "circle":
|
115 |
-
return amplitude*np.sin(2*np.pi*t), amplitude*np.cos(2*np.pi*t)
|
116 |
-
elif style == "spiral":
|
117 |
-
radius = amplitude * (1 - t)
|
118 |
-
return radius*np.sin(4*np.pi*t), radius*np.cos(4*np.pi*t)
|
119 |
-
|
120 |
-
def warp_image(image_tensor, pixel_grid, dx, dy, W, H):
|
121 |
-
"""Warp image using computed displacements"""
|
122 |
-
source_x = pixel_grid[:, :, 0] + dx
|
123 |
-
source_y = pixel_grid[:, :, 1] + dy
|
124 |
-
grid = torch.stack((2*source_x/(W-1)-1, 2*source_y/(H-1)-1), dim=-1).unsqueeze(0)
|
125 |
-
return torch.nn.functional.grid_sample(image_tensor.unsqueeze(0), grid, mode='bicubic', align_corners=True)
|
126 |
-
|
127 |
-
def post_process_frame(warped, ssaa_factor, orig_size, use_taa, prev_frame):
|
128 |
-
"""Process frame with SSAA and TAA"""
|
129 |
-
if ssaa_factor > 1:
|
130 |
-
warped = Resize(orig_size[::-1], antialias=True)(warped.squeeze(0)).unsqueeze(0)
|
131 |
-
|
132 |
-
frame = (warped.squeeze().permute(1,2,0).cpu().numpy() * 255).astype(np.uint8)
|
133 |
-
|
134 |
-
if use_taa and prev_frame is not None:
|
135 |
-
frame = cv2.addWeighted(frame, 0.8, prev_frame, 0.2, 0)
|
136 |
-
|
137 |
-
return frame
|
138 |
-
|
139 |
-
def apply_upscaler(frame):
|
140 |
-
"""Apply 4x super-resolution using AuraSR-v2"""
|
141 |
-
tensor = torch.tensor(frame).permute(2,0,1).unsqueeze(0).float() / 255.0
|
142 |
-
with torch.no_grad():
|
143 |
-
upscaled = upscaler_model(tensor.to('cuda'))
|
144 |
-
return (upscaled[0].permute(1,2,0).clamp(0,1).cpu().numpy() * 255).astype(np.uint8)
|
145 |
-
|
146 |
-
def save_video(frames, fps):
|
147 |
-
"""Save frames to video file"""
|
148 |
-
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
|
149 |
-
writer = imageio.get_writer(f.name, fps=fps, codec='libx264', quality=9)
|
150 |
for frame in frames:
|
151 |
writer.append_data(frame)
|
152 |
writer.close()
|
153 |
-
return f.name
|
154 |
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
with
|
161 |
-
|
162 |
-
depth_input = gr.Image(type="pil", label="Depth Map")
|
163 |
-
|
164 |
-
with gr.Row():
|
165 |
-
with gr.Column():
|
166 |
-
animation_style = gr.Dropdown(["horizontal", "vertical", "circle", "spiral"],
|
167 |
-
value="horizontal", label="Animation Style")
|
168 |
-
amplitude = gr.Slider(0, 10, value=2, step=0.1, label="Movement Amplitude")
|
169 |
-
k = gr.Slider(0, 20, value=5, step=0.1, label="Depth Scaling Factor")
|
170 |
-
|
171 |
-
with gr.Column():
|
172 |
-
fps = gr.Slider(10, 60, value=30, step=1, label="FPS")
|
173 |
-
duration = gr.Slider(1, 10, value=5, step=0.1, label="Duration (seconds)")
|
174 |
-
ssaa_factor = gr.Dropdown([1, 2, 4], value=1, label="Anti-Aliasing Quality")
|
175 |
-
|
176 |
with gr.Row():
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
generate_btn = gr.Button("Generate Video", variant="primary")
|
181 |
-
video_output = gr.Video(label="Generated Video", format="mp4")
|
182 |
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
187 |
|
188 |
demo.launch()
|
|
|
7 |
import spaces
|
8 |
import tempfile
|
9 |
from scipy.ndimage import gaussian_filter
|
10 |
+
from aura_sr import AuraSR # Import AuraSR for upscaling
|
11 |
+
|
12 |
+
# Load AuraSR-v2 model once at startup
|
13 |
+
aura_sr = AuraSR.from_pretrained("fal/AuraSR-v2")
|
14 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
@spaces.GPU
|
16 |
+
def generate_parallax_video(image, depth_map, animation_style, amplitude, k, fps, duration, ssaa_factor, use_taa, use_upscale):
|
17 |
+
"""
|
18 |
+
Generate a 3D parallax video with enhanced quality features and optional upscaling.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
image (PIL.Image): Input RGB image.
|
22 |
+
depth_map (PIL.Image): Grayscale depth map.
|
23 |
+
animation_style (str): Animation type.
|
24 |
+
amplitude (float): Camera movement intensity.
|
25 |
+
k (float): Depth displacement scale.
|
26 |
+
fps (int): Frames per second.
|
27 |
+
duration (float): Video duration in seconds.
|
28 |
+
ssaa_factor (int): Super sampling factor (1, 2, 4).
|
29 |
+
use_taa (bool): Enable temporal anti-aliasing.
|
30 |
+
use_upscale (bool): Enable AuraSR-v2 upscaling for each frame.
|
31 |
+
|
32 |
+
Returns:
|
33 |
+
str: Path to the generated video file.
|
34 |
+
"""
|
35 |
+
# Validate input dimensions
|
36 |
if image.size != depth_map.size:
|
37 |
+
raise ValueError("Image and depth map must have the same dimensions")
|
38 |
|
39 |
+
# Convert to tensors with high precision
|
40 |
image_tensor = ToTensor()(image).to('cuda', dtype=torch.float32)
|
41 |
depth_tensor = ToTensor()(depth_map.convert('L')).to('cuda', dtype=torch.float32)
|
42 |
depth_tensor = (depth_tensor - depth_tensor.min()) / (depth_tensor.max() - depth_tensor.min() + 1e-6)
|
43 |
|
44 |
+
# Smooth depth map
|
45 |
+
depth_np = depth_tensor.squeeze().cpu().numpy()
|
46 |
+
depth_np = gaussian_filter(depth_np, sigma=1)
|
47 |
+
depth_tensor = torch.tensor(depth_np, device='cuda', dtype=torch.float32).unsqueeze(0)
|
48 |
|
49 |
+
# Apply SSAA
|
50 |
if ssaa_factor > 1:
|
51 |
upscale = Resize((int(image.height * ssaa_factor), int(image.width * ssaa_factor)), antialias=True)
|
52 |
image_tensor = upscale(image_tensor)
|
53 |
depth_tensor = upscale(depth_tensor)
|
54 |
|
55 |
H, W = image_tensor.shape[1], image_tensor.shape[2]
|
|
|
|
|
56 |
|
57 |
+
# Create coordinate grid
|
58 |
+
x = torch.arange(0, W).float().to('cuda')
|
59 |
+
y = torch.arange(0, H).float().to('cuda')
|
60 |
+
xx, yy = torch.meshgrid(x, y, indexing='xy')
|
61 |
+
pixel_grid = torch.stack((xx, yy), dim=-1)
|
62 |
+
|
63 |
+
# Generate frames
|
64 |
num_frames = int(fps * duration)
|
65 |
frames = []
|
66 |
prev_frame = None
|
67 |
|
68 |
+
for frame in range(num_frames):
|
69 |
+
t = frame / num_frames
|
70 |
+
if animation_style == "horizontal":
|
71 |
+
camera_x = amplitude * np.sin(2 * np.pi * t)
|
72 |
+
camera_y = 0
|
73 |
+
elif animation_style == "vertical":
|
74 |
+
camera_x = 0
|
75 |
+
camera_y = amplitude * np.sin(2 * np.pi * t)
|
76 |
+
elif animation_style == "circle":
|
77 |
+
camera_x = amplitude * np.sin(2 * np.pi * t)
|
78 |
+
camera_y = amplitude * np.cos(2 * np.pi * t)
|
79 |
+
elif animation_style == "spiral":
|
80 |
+
radius = amplitude * (1 - t)
|
81 |
+
camera_x = radius * np.sin(4 * np.pi * t)
|
82 |
+
camera_y = radius * np.cos(4 * np.pi * t)
|
83 |
+
else:
|
84 |
+
raise ValueError(f"Unsupported animation style: {animation_style}")
|
85 |
+
|
86 |
+
# Compute displacements
|
87 |
displacement_x = k * camera_x * depth_tensor.squeeze()
|
88 |
displacement_y = k * camera_y * depth_tensor.squeeze()
|
89 |
+
|
90 |
+
# Calculate source coordinates
|
91 |
+
source_pixel_x = pixel_grid[:, :, 0] + displacement_x
|
92 |
+
source_pixel_y = pixel_grid[:, :, 1] + displacement_y
|
93 |
+
|
94 |
+
# Normalize to [-1, 1]
|
95 |
+
grid_x = 2 * source_pixel_x / (W - 1) - 1
|
96 |
+
grid_y = 2 * source_pixel_y / (H - 1) - 1
|
97 |
+
grid = torch.stack((grid_x, grid_y), dim=-1).unsqueeze(0)
|
98 |
+
|
99 |
+
# Warp with bicubic interpolation
|
100 |
+
warped = torch.nn.functional.grid_sample(image_tensor.unsqueeze(0), grid, mode='bicubic', align_corners=True)
|
101 |
+
|
102 |
+
# Downsample if SSAA is enabled
|
103 |
+
if ssaa_factor > 1:
|
104 |
+
downscale = Resize((image.height, image.width), antialias=True)
|
105 |
+
warped = downscale(warped.squeeze(0)).unsqueeze(0)
|
106 |
+
|
107 |
+
# Convert to PIL image for upscaling or further processing
|
108 |
+
frame_img = warped.squeeze(0).permute(1, 2, 0).cpu().numpy()
|
109 |
+
frame_img = (frame_img * 255).astype(np.uint8)
|
110 |
+
frame_pil = Image.fromarray(frame_img)
|
111 |
+
|
112 |
+
# Apply AuraSR-v2 upscaling if enabled
|
113 |
+
if use_upscale:
|
114 |
+
frame_pil = aura_sr.upscale_4x_overlapped(frame_pil) # 4x upscaling
|
115 |
+
frame_img = np.array(frame_pil)
|
116 |
+
|
117 |
+
# Apply TAA if enabled
|
118 |
+
if use_taa and prev_frame is not None:
|
119 |
+
frame_img = (frame_img * 0.8 + prev_frame * 0.2).astype(np.uint8)
|
120 |
+
|
121 |
frames.append(frame_img)
|
122 |
prev_frame = frame_img.copy() if use_taa else None
|
123 |
|
124 |
+
# Save video
|
125 |
+
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
|
126 |
+
output_path = tmpfile.name
|
127 |
+
writer = imageio.get_writer(output_path, fps=fps, codec='libx264')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
for frame in frames:
|
129 |
writer.append_data(frame)
|
130 |
writer.close()
|
|
|
131 |
|
132 |
+
return output_path
|
133 |
+
|
134 |
+
# Gradio interface
|
135 |
+
with gr.Blocks(title="Enhanced 3D Parallax Video Generator with Upscaling") as demo:
|
136 |
+
gr.Markdown("# Enhanced 3D Parallax Video Generator with Upscaling")
|
137 |
+
gr.Markdown("Create high-quality 3D parallax videos with advanced features and optional AuraSR-v2 upscaling.")
|
138 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
with gr.Row():
|
140 |
+
image_input = gr.Image(type="pil", label="Upload Image")
|
141 |
+
depth_input = gr.Image(type="pil", label="Upload Depth Map")
|
|
|
|
|
|
|
142 |
|
143 |
+
with gr.Row():
|
144 |
+
animation_style = gr.Dropdown(["horizontal", "vertical", "circle", "spiral"], label="Animation Style", value="horizontal")
|
145 |
+
amplitude_slider = gr.Slider(0, 10, value=2, label="Amplitude", step=0.1)
|
146 |
+
k_slider = gr.Slider(1, 20, value=5, label="Depth Scale (k)", step=0.1)
|
147 |
+
fps_slider = gr.Slider(10, 60, value=30, label="FPS", step=1)
|
148 |
+
duration_slider = gr.Slider(1, 10, value=5, label="Duration (s)", step=0.1)
|
149 |
+
ssaa_factor = gr.Dropdown([1, 2, 4], label="SSAA Factor", value=1)
|
150 |
+
use_taa = gr.Checkbox(label="Enable TAA", value=False)
|
151 |
+
use_upscale = gr.Checkbox(label="Enable AuraSR-v2 Upscaling", value=False)
|
152 |
+
|
153 |
+
generate_btn = gr.Button("Generate Video")
|
154 |
+
video_output = gr.Video(label="Parallax Video")
|
155 |
+
|
156 |
+
generate_btn.click(
|
157 |
+
fn=generate_parallax_video,
|
158 |
+
inputs=[image_input, depth_input, animation_style, amplitude_slider, k_slider, fps_slider, duration_slider, ssaa_factor, use_taa, use_upscale],
|
159 |
+
outputs=video_output
|
160 |
+
)
|
161 |
|
162 |
demo.launch()
|