FQiao commited on
Commit
eed95bc
·
verified ·
1 Parent(s): b466e69

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -47
app.py CHANGED
@@ -135,6 +135,53 @@ with tempfile.TemporaryDirectory() as tmpdir:
135
  src_image = gr.State()
136
  src_depth = gr.State()
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  # Blocks.
139
  gr.Markdown(
140
  """
@@ -191,53 +238,6 @@ with tempfile.TemporaryDirectory() as tmpdir:
191
  label='Generated Right', type='pil', interactive=False
192
  )
193
 
194
- def normalize_disp(disp):
195
- return (disp - disp.min()) / (disp.max() - disp.min())
196
-
197
- # Callbacks
198
- @spaces.GPU()
199
- def cb_mde(image_file: str):
200
- if not image_file:
201
- # Return None if no image is provided (e.g., when file is cleared).
202
- return None, None, None, None
203
-
204
- image = crop(Image.open(image_file).convert('RGB')) # Load image using PIL
205
- image = image.resize((IMAGE_SIZE, IMAGE_SIZE))
206
-
207
- image_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
208
-
209
- dam2 = get_dam2_model()
210
- depth_dam2 = dam2.infer_image(image_bgr)
211
- depth = torch.tensor(depth_dam2).unsqueeze(0).unsqueeze(0).float()
212
-
213
- depth_image = cv2.applyColorMap((normalize_disp(depth_dam2) * 255).astype(np.uint8), cv2.COLORMAP_JET)
214
-
215
- return image, depth_image, image, depth
216
-
217
- @spaces.GPU()
218
- def cb_generate(image, depth: Tensor, scale_factor):
219
- norm_disp = normalize_disp(depth.cuda())
220
- disp = norm_disp * scale_factor / 100 * IMAGE_SIZE
221
-
222
- genstereo = get_genstereo_model()
223
- fusion_model = get_fusion_model()
224
-
225
- renders = genstereo(
226
- src_image=image,
227
- src_disparity=disp,
228
- ratio=None,
229
- )
230
- warped = (renders['warped'] + 1) / 2
231
-
232
- synthesized = renders['synthesized']
233
- mask = renders['mask']
234
- fusion_image = fusion_model(synthesized.float(), warped.float(), mask.float())
235
-
236
- warped_pil = to_pil_image(warped[0])
237
- fusion_pil = to_pil_image(fusion_image[0])
238
-
239
- return warped_pil, fusion_pil
240
-
241
  # Events
242
  file.change(
243
  fn=cb_mde,
 
135
  src_image = gr.State()
136
  src_depth = gr.State()
137
 
138
+ def normalize_disp(disp):
139
+ return (disp - disp.min()) / (disp.max() - disp.min())
140
+
141
+ # Callbacks
142
+ @spaces.GPU()
143
+ def cb_mde(image_file: str):
144
+ if not image_file:
145
+ # Return None if no image is provided (e.g., when file is cleared).
146
+ return None, None, None, None
147
+
148
+ image = crop(Image.open(image_file).convert('RGB')) # Load image using PIL
149
+ image = image.resize((IMAGE_SIZE, IMAGE_SIZE))
150
+
151
+ image_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
152
+
153
+ dam2 = get_dam2_model()
154
+ depth_dam2 = dam2.infer_image(image_bgr)
155
+ depth = torch.tensor(depth_dam2).unsqueeze(0).unsqueeze(0).float()
156
+
157
+ depth_image = cv2.applyColorMap((normalize_disp(depth_dam2) * 255).astype(np.uint8), cv2.COLORMAP_JET)
158
+
159
+ return image, depth_image, image, depth
160
+
161
+ @spaces.GPU()
162
+ def cb_generate(image, depth: Tensor, scale_factor):
163
+ norm_disp = normalize_disp(depth.cuda())
164
+ disp = norm_disp * scale_factor / 100 * IMAGE_SIZE
165
+
166
+ genstereo = get_genstereo_model()
167
+ fusion_model = get_fusion_model()
168
+
169
+ renders = genstereo(
170
+ src_image=image,
171
+ src_disparity=disp,
172
+ ratio=None,
173
+ )
174
+ warped = (renders['warped'] + 1) / 2
175
+
176
+ synthesized = renders['synthesized']
177
+ mask = renders['mask']
178
+ fusion_image = fusion_model(synthesized.float(), warped.float(), mask.float())
179
+
180
+ warped_pil = to_pil_image(warped[0])
181
+ fusion_pil = to_pil_image(fusion_image[0])
182
+
183
+ return warped_pil, fusion_pil
184
+
185
  # Blocks.
186
  gr.Markdown(
187
  """
 
238
  label='Generated Right', type='pil', interactive=False
239
  )
240
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  # Events
242
  file.change(
243
  fn=cb_mde,