yuyutsu07 commited on
Commit
18e62f1
·
verified ·
1 Parent(s): 09c414e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +121 -147
app.py CHANGED
@@ -7,182 +7,156 @@ from torchvision.transforms import ToTensor, Resize
7
  import spaces
8
  import tempfile
9
  from scipy.ndimage import gaussian_filter
10
- from huggingface_hub import hf_hub_download
11
- from safetensors.torch import load_file
12
-
13
- # ------------------------- AuraSR Model Definition ------------------------- #
14
- class ResBlock(torch.nn.Module):
15
- def __init__(self, n_filters):
16
- super().__init__()
17
- self.conv1 = torch.nn.Conv2d(n_filters, n_filters, 3, padding=1)
18
- self.conv2 = torch.nn.Conv2d(n_filters, n_filters, 3, padding=1)
19
-
20
- def forward(self, x):
21
- residual = x
22
- x = torch.relu(self.conv1(x))
23
- x = self.conv2(x)
24
- x += residual
25
- return x
26
-
27
- class AuraSR(torch.nn.Module):
28
- def __init__(self, scale=4, n_filters=64, n_blocks=8):
29
- super().__init__()
30
- self.scale = scale
31
- self.head = torch.nn.Conv2d(3, n_filters, 3, padding=1)
32
- self.body = torch.nn.Sequential(*[ResBlock(n_filters) for _ in range(n_blocks)])
33
- self.tail = torch.nn.Sequential(
34
- torch.nn.Conv2d(n_filters, n_filters * (scale ** 2), 3, padding=1),
35
- torch.nn.PixelShuffle(scale),
36
- torch.nn.Conv2d(n_filters, 3, 3, padding=1)
37
- )
38
-
39
- def forward(self, x):
40
- x = torch.nn.functional.interpolate(x, scale_factor=self.scale, mode='nearest')
41
- x = self.head(x)
42
- x = self.body(x)
43
- x = self.tail(x)
44
- return x
45
-
46
- # Load AuraSR-v2 model
47
- model_path = hf_hub_download(repo_id="fal/AuraSR-v2", filename="model.safetensors")
48
- state_dict = load_file(model_path)
49
- upscaler_model = AuraSR().eval().to('cuda')
50
- upscaler_model.load_state_dict(state_dict)
51
-
52
- # ------------------------- Core Parallax Function ------------------------- #
53
  @spaces.GPU
54
- def generate_parallax_video(image, depth_map, animation_style, amplitude, k, fps,
55
- duration, ssaa_factor, use_taa, use_upscaler):
56
- """Generate parallax video with optional super-resolution upscaling"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  if image.size != depth_map.size:
58
- raise ValueError("Image and depth map dimensions must match")
59
 
60
- # Preprocess inputs
61
  image_tensor = ToTensor()(image).to('cuda', dtype=torch.float32)
62
  depth_tensor = ToTensor()(depth_map.convert('L')).to('cuda', dtype=torch.float32)
63
  depth_tensor = (depth_tensor - depth_tensor.min()) / (depth_tensor.max() - depth_tensor.min() + 1e-6)
64
 
65
- # Apply Gaussian smoothing to depth map
66
- depth_np = gaussian_filter(depth_tensor.squeeze().cpu().numpy(), sigma=1)
67
- depth_tensor = torch.tensor(depth_np, device='cuda').unsqueeze(0)
 
68
 
69
- # Super Sampling Anti-Aliasing
70
  if ssaa_factor > 1:
71
  upscale = Resize((int(image.height * ssaa_factor), int(image.width * ssaa_factor)), antialias=True)
72
  image_tensor = upscale(image_tensor)
73
  depth_tensor = upscale(depth_tensor)
74
 
75
  H, W = image_tensor.shape[1], image_tensor.shape[2]
76
- x, y = torch.meshgrid(torch.arange(W, device='cuda'), torch.arange(H, device='cuda'), indexing='xy')
77
- pixel_grid = torch.stack((x, y), dim=-1)
78
 
79
- # Animation parameters
 
 
 
 
 
 
80
  num_frames = int(fps * duration)
81
  frames = []
82
  prev_frame = None
83
 
84
- for frame_idx in range(num_frames):
85
- t = frame_idx / num_frames
86
- camera_x, camera_y = calculate_movement(t, amplitude, animation_style)
87
-
88
- # Calculate displacement
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  displacement_x = k * camera_x * depth_tensor.squeeze()
90
  displacement_y = k * camera_y * depth_tensor.squeeze()
91
-
92
- # Warp image
93
- warped = warp_image(image_tensor, pixel_grid, displacement_x, displacement_y, W, H)
94
-
95
- # Post-processing
96
- frame_img = post_process_frame(warped, ssaa_factor, image.size, use_taa, prev_frame)
97
-
98
- # Apply super-resolution
99
- if use_upscaler:
100
- frame_img = apply_upscaler(frame_img)
101
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  frames.append(frame_img)
103
  prev_frame = frame_img.copy() if use_taa else None
104
 
105
- return save_video(frames, fps)
106
-
107
- # ------------------------- Helper Functions ------------------------- #
108
- def calculate_movement(t, amplitude, style):
109
- """Calculate camera movement based on animation style"""
110
- if style == "horizontal":
111
- return amplitude * np.sin(2*np.pi*t), 0
112
- elif style == "vertical":
113
- return 0, amplitude * np.sin(2*np.pi*t)
114
- elif style == "circle":
115
- return amplitude*np.sin(2*np.pi*t), amplitude*np.cos(2*np.pi*t)
116
- elif style == "spiral":
117
- radius = amplitude * (1 - t)
118
- return radius*np.sin(4*np.pi*t), radius*np.cos(4*np.pi*t)
119
-
120
- def warp_image(image_tensor, pixel_grid, dx, dy, W, H):
121
- """Warp image using computed displacements"""
122
- source_x = pixel_grid[:, :, 0] + dx
123
- source_y = pixel_grid[:, :, 1] + dy
124
- grid = torch.stack((2*source_x/(W-1)-1, 2*source_y/(H-1)-1), dim=-1).unsqueeze(0)
125
- return torch.nn.functional.grid_sample(image_tensor.unsqueeze(0), grid, mode='bicubic', align_corners=True)
126
-
127
- def post_process_frame(warped, ssaa_factor, orig_size, use_taa, prev_frame):
128
- """Process frame with SSAA and TAA"""
129
- if ssaa_factor > 1:
130
- warped = Resize(orig_size[::-1], antialias=True)(warped.squeeze(0)).unsqueeze(0)
131
-
132
- frame = (warped.squeeze().permute(1,2,0).cpu().numpy() * 255).astype(np.uint8)
133
-
134
- if use_taa and prev_frame is not None:
135
- frame = cv2.addWeighted(frame, 0.8, prev_frame, 0.2, 0)
136
-
137
- return frame
138
-
139
- def apply_upscaler(frame):
140
- """Apply 4x super-resolution using AuraSR-v2"""
141
- tensor = torch.tensor(frame).permute(2,0,1).unsqueeze(0).float() / 255.0
142
- with torch.no_grad():
143
- upscaled = upscaler_model(tensor.to('cuda'))
144
- return (upscaled[0].permute(1,2,0).clamp(0,1).cpu().numpy() * 255).astype(np.uint8)
145
-
146
- def save_video(frames, fps):
147
- """Save frames to video file"""
148
- with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
149
- writer = imageio.get_writer(f.name, fps=fps, codec='libx264', quality=9)
150
  for frame in frames:
151
  writer.append_data(frame)
152
  writer.close()
153
- return f.name
154
 
155
- # ------------------------- Gradio Interface ------------------------- #
156
- with gr.Blocks(title="3D Parallax Video Generator with Super-Resolution") as demo:
157
- gr.Markdown("# 🔥 3D Parallax Video Generator with 4x Super-Resolution")
158
- gr.Markdown("Generate stunning 3D parallax videos from 2D images with optional AI upscaling")
159
-
160
- with gr.Row():
161
- image_input = gr.Image(type="pil", label="Input Image")
162
- depth_input = gr.Image(type="pil", label="Depth Map")
163
-
164
- with gr.Row():
165
- with gr.Column():
166
- animation_style = gr.Dropdown(["horizontal", "vertical", "circle", "spiral"],
167
- value="horizontal", label="Animation Style")
168
- amplitude = gr.Slider(0, 10, value=2, step=0.1, label="Movement Amplitude")
169
- k = gr.Slider(0, 20, value=5, step=0.1, label="Depth Scaling Factor")
170
-
171
- with gr.Column():
172
- fps = gr.Slider(10, 60, value=30, step=1, label="FPS")
173
- duration = gr.Slider(1, 10, value=5, step=0.1, label="Duration (seconds)")
174
- ssaa_factor = gr.Dropdown([1, 2, 4], value=1, label="Anti-Aliasing Quality")
175
-
176
  with gr.Row():
177
- use_taa = gr.Checkbox(label="Enable Temporal Anti-Aliasing", value=False)
178
- use_upscaler = gr.Checkbox(label="Enable 4x Super-Resolution (AuraSR-v2)", value=False)
179
-
180
- generate_btn = gr.Button("Generate Video", variant="primary")
181
- video_output = gr.Video(label="Generated Video", format="mp4")
182
 
183
- generate_btn.click(fn=generate_parallax_video,
184
- inputs=[image_input, depth_input, animation_style, amplitude, k,
185
- fps, duration, ssaa_factor, use_taa, use_upscaler],
186
- outputs=video_output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
 
188
  demo.launch()
 
7
  import spaces
8
  import tempfile
9
  from scipy.ndimage import gaussian_filter
10
+ from aura_sr import AuraSR # Import AuraSR for upscaling
11
+
12
+ # Load AuraSR-v2 model once at startup
13
+ aura_sr = AuraSR.from_pretrained("fal/AuraSR-v2")
14
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  @spaces.GPU
16
+ def generate_parallax_video(image, depth_map, animation_style, amplitude, k, fps, duration, ssaa_factor, use_taa, use_upscale):
17
+ """
18
+ Generate a 3D parallax video with enhanced quality features and optional upscaling.
19
+
20
+ Args:
21
+ image (PIL.Image): Input RGB image.
22
+ depth_map (PIL.Image): Grayscale depth map.
23
+ animation_style (str): Animation type.
24
+ amplitude (float): Camera movement intensity.
25
+ k (float): Depth displacement scale.
26
+ fps (int): Frames per second.
27
+ duration (float): Video duration in seconds.
28
+ ssaa_factor (int): Super sampling factor (1, 2, 4).
29
+ use_taa (bool): Enable temporal anti-aliasing.
30
+ use_upscale (bool): Enable AuraSR-v2 upscaling for each frame.
31
+
32
+ Returns:
33
+ str: Path to the generated video file.
34
+ """
35
+ # Validate input dimensions
36
  if image.size != depth_map.size:
37
+ raise ValueError("Image and depth map must have the same dimensions")
38
 
39
+ # Convert to tensors with high precision
40
  image_tensor = ToTensor()(image).to('cuda', dtype=torch.float32)
41
  depth_tensor = ToTensor()(depth_map.convert('L')).to('cuda', dtype=torch.float32)
42
  depth_tensor = (depth_tensor - depth_tensor.min()) / (depth_tensor.max() - depth_tensor.min() + 1e-6)
43
 
44
+ # Smooth depth map
45
+ depth_np = depth_tensor.squeeze().cpu().numpy()
46
+ depth_np = gaussian_filter(depth_np, sigma=1)
47
+ depth_tensor = torch.tensor(depth_np, device='cuda', dtype=torch.float32).unsqueeze(0)
48
 
49
+ # Apply SSAA
50
  if ssaa_factor > 1:
51
  upscale = Resize((int(image.height * ssaa_factor), int(image.width * ssaa_factor)), antialias=True)
52
  image_tensor = upscale(image_tensor)
53
  depth_tensor = upscale(depth_tensor)
54
 
55
  H, W = image_tensor.shape[1], image_tensor.shape[2]
 
 
56
 
57
+ # Create coordinate grid
58
+ x = torch.arange(0, W).float().to('cuda')
59
+ y = torch.arange(0, H).float().to('cuda')
60
+ xx, yy = torch.meshgrid(x, y, indexing='xy')
61
+ pixel_grid = torch.stack((xx, yy), dim=-1)
62
+
63
+ # Generate frames
64
  num_frames = int(fps * duration)
65
  frames = []
66
  prev_frame = None
67
 
68
+ for frame in range(num_frames):
69
+ t = frame / num_frames
70
+ if animation_style == "horizontal":
71
+ camera_x = amplitude * np.sin(2 * np.pi * t)
72
+ camera_y = 0
73
+ elif animation_style == "vertical":
74
+ camera_x = 0
75
+ camera_y = amplitude * np.sin(2 * np.pi * t)
76
+ elif animation_style == "circle":
77
+ camera_x = amplitude * np.sin(2 * np.pi * t)
78
+ camera_y = amplitude * np.cos(2 * np.pi * t)
79
+ elif animation_style == "spiral":
80
+ radius = amplitude * (1 - t)
81
+ camera_x = radius * np.sin(4 * np.pi * t)
82
+ camera_y = radius * np.cos(4 * np.pi * t)
83
+ else:
84
+ raise ValueError(f"Unsupported animation style: {animation_style}")
85
+
86
+ # Compute displacements
87
  displacement_x = k * camera_x * depth_tensor.squeeze()
88
  displacement_y = k * camera_y * depth_tensor.squeeze()
89
+
90
+ # Calculate source coordinates
91
+ source_pixel_x = pixel_grid[:, :, 0] + displacement_x
92
+ source_pixel_y = pixel_grid[:, :, 1] + displacement_y
93
+
94
+ # Normalize to [-1, 1]
95
+ grid_x = 2 * source_pixel_x / (W - 1) - 1
96
+ grid_y = 2 * source_pixel_y / (H - 1) - 1
97
+ grid = torch.stack((grid_x, grid_y), dim=-1).unsqueeze(0)
98
+
99
+ # Warp with bicubic interpolation
100
+ warped = torch.nn.functional.grid_sample(image_tensor.unsqueeze(0), grid, mode='bicubic', align_corners=True)
101
+
102
+ # Downsample if SSAA is enabled
103
+ if ssaa_factor > 1:
104
+ downscale = Resize((image.height, image.width), antialias=True)
105
+ warped = downscale(warped.squeeze(0)).unsqueeze(0)
106
+
107
+ # Convert to PIL image for upscaling or further processing
108
+ frame_img = warped.squeeze(0).permute(1, 2, 0).cpu().numpy()
109
+ frame_img = (frame_img * 255).astype(np.uint8)
110
+ frame_pil = Image.fromarray(frame_img)
111
+
112
+ # Apply AuraSR-v2 upscaling if enabled
113
+ if use_upscale:
114
+ frame_pil = aura_sr.upscale_4x_overlapped(frame_pil) # 4x upscaling
115
+ frame_img = np.array(frame_pil)
116
+
117
+ # Apply TAA if enabled
118
+ if use_taa and prev_frame is not None:
119
+ frame_img = (frame_img * 0.8 + prev_frame * 0.2).astype(np.uint8)
120
+
121
  frames.append(frame_img)
122
  prev_frame = frame_img.copy() if use_taa else None
123
 
124
+ # Save video
125
+ with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
126
+ output_path = tmpfile.name
127
+ writer = imageio.get_writer(output_path, fps=fps, codec='libx264')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  for frame in frames:
129
  writer.append_data(frame)
130
  writer.close()
 
131
 
132
+ return output_path
133
+
134
+ # Gradio interface
135
+ with gr.Blocks(title="Enhanced 3D Parallax Video Generator with Upscaling") as demo:
136
+ gr.Markdown("# Enhanced 3D Parallax Video Generator with Upscaling")
137
+ gr.Markdown("Create high-quality 3D parallax videos with advanced features and optional AuraSR-v2 upscaling.")
138
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  with gr.Row():
140
+ image_input = gr.Image(type="pil", label="Upload Image")
141
+ depth_input = gr.Image(type="pil", label="Upload Depth Map")
 
 
 
142
 
143
+ with gr.Row():
144
+ animation_style = gr.Dropdown(["horizontal", "vertical", "circle", "spiral"], label="Animation Style", value="horizontal")
145
+ amplitude_slider = gr.Slider(0, 10, value=2, label="Amplitude", step=0.1)
146
+ k_slider = gr.Slider(1, 20, value=5, label="Depth Scale (k)", step=0.1)
147
+ fps_slider = gr.Slider(10, 60, value=30, label="FPS", step=1)
148
+ duration_slider = gr.Slider(1, 10, value=5, label="Duration (s)", step=0.1)
149
+ ssaa_factor = gr.Dropdown([1, 2, 4], label="SSAA Factor", value=1)
150
+ use_taa = gr.Checkbox(label="Enable TAA", value=False)
151
+ use_upscale = gr.Checkbox(label="Enable AuraSR-v2 Upscaling", value=False)
152
+
153
+ generate_btn = gr.Button("Generate Video")
154
+ video_output = gr.Video(label="Parallax Video")
155
+
156
+ generate_btn.click(
157
+ fn=generate_parallax_video,
158
+ inputs=[image_input, depth_input, animation_style, amplitude_slider, k_slider, fps_slider, duration_slider, ssaa_factor, use_taa, use_upscale],
159
+ outputs=video_output
160
+ )
161
 
162
  demo.launch()