FQiao commited on
Commit
0d91fab
·
verified ·
1 Parent(s): 9f58fa7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -4
app.py CHANGED
@@ -66,7 +66,7 @@ def download_models():
66
  download_models()
67
 
68
  # DepthAnythingV2
69
- if 'dam2' not in globals():
70
  model_configs = {
71
  'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
72
  'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
@@ -83,22 +83,25 @@ if 'dam2' not in globals():
83
  dam2_checkpoint = f'checkpoints/depth_anything_v2_{encoder}.pth'
84
  dam2.load_state_dict(torch.load(dam2_checkpoint, map_location='cpu'))
85
  dam2 = dam2.to(DEVICE).eval()
 
86
 
87
  # GenStereo
88
- if 'genstereo' not in globals():
89
  genwarp_cfg = dict(
90
  pretrained_model_path='checkpoints',
91
  checkpoint_name=CHECKPOINT_NAME,
92
  half_precision_weights=True
93
  )
94
  genstereo = GenStereo(cfg=genwarp_cfg, device=DEVICE)
 
95
 
96
  # Adaptive Fusion
97
- if 'fusion_model' not in globals():
98
  fusion_model = AdaptiveFusionLayer()
99
  fusion_checkpoint = join('checkpoints', CHECKPOINT_NAME, 'fusion_layer.pth')
100
  fusion_model.load_state_dict(torch.load(fusion_checkpoint, map_location='cpu'))
101
  fusion_model = fusion_model.to(DEVICE).eval()
 
102
 
103
  # Crop the image to the shorter side.
104
  def crop(img: Image) -> Image:
@@ -190,6 +193,7 @@ with tempfile.TemporaryDirectory() as tmpdir:
190
 
191
  image_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
192
 
 
193
  depth_dam2 = dam2.infer_image(image_bgr)
194
  depth = torch.tensor(depth_dam2).unsqueeze(0).unsqueeze(0).float().cuda()
195
 
@@ -202,6 +206,9 @@ with tempfile.TemporaryDirectory() as tmpdir:
202
  norm_disp = normalize_disp(depth)
203
  disp = norm_disp * scale_factor / 100 * IMAGE_SIZE
204
 
 
 
 
205
  renders = genstereo(
206
  src_image=image,
207
  src_disparity=disp,
@@ -231,4 +238,4 @@ with tempfile.TemporaryDirectory() as tmpdir:
231
  )
232
 
233
  if __name__ == '__main__':
234
- demo.launch()
 
66
  download_models()
67
 
68
  # DepthAnythingV2
69
+ def get_dam2_model():
70
  model_configs = {
71
  'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
72
  'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
 
83
  dam2_checkpoint = f'checkpoints/depth_anything_v2_{encoder}.pth'
84
  dam2.load_state_dict(torch.load(dam2_checkpoint, map_location='cpu'))
85
  dam2 = dam2.to(DEVICE).eval()
86
+ return dam2
87
 
88
  # GenStereo
89
+ def get_genstereo_model():
90
  genwarp_cfg = dict(
91
  pretrained_model_path='checkpoints',
92
  checkpoint_name=CHECKPOINT_NAME,
93
  half_precision_weights=True
94
  )
95
  genstereo = GenStereo(cfg=genwarp_cfg, device=DEVICE)
96
+ return genstereo
97
 
98
  # Adaptive Fusion
99
+ def get_fusion_model():
100
  fusion_model = AdaptiveFusionLayer()
101
  fusion_checkpoint = join('checkpoints', CHECKPOINT_NAME, 'fusion_layer.pth')
102
  fusion_model.load_state_dict(torch.load(fusion_checkpoint, map_location='cpu'))
103
  fusion_model = fusion_model.to(DEVICE).eval()
104
+ return fusion_model
105
 
106
  # Crop the image to the shorter side.
107
  def crop(img: Image) -> Image:
 
193
 
194
  image_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
195
 
196
+ dam2 = get_dam2_model()
197
  depth_dam2 = dam2.infer_image(image_bgr)
198
  depth = torch.tensor(depth_dam2).unsqueeze(0).unsqueeze(0).float().cuda()
199
 
 
206
  norm_disp = normalize_disp(depth)
207
  disp = norm_disp * scale_factor / 100 * IMAGE_SIZE
208
 
209
+ genstereo = get_genstereo_model()
210
+ fusion_model = get_fusion_model()
211
+
212
  renders = genstereo(
213
  src_image=image,
214
  src_disparity=disp,
 
238
  )
239
 
240
  if __name__ == '__main__':
241
+ demo.launch()