Update app.py
app.py
CHANGED
@@ -66,7 +66,7 @@ def download_models():
 download_models()
 
 # DepthAnythingV2
-if 'dam2' not in globals():
+def get_dam2_model():
     model_configs = {
         'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
         'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
@@ -83,22 +83,25 @@ if 'dam2' not in globals():
     dam2_checkpoint = f'checkpoints/depth_anything_v2_{encoder}.pth'
     dam2.load_state_dict(torch.load(dam2_checkpoint, map_location='cpu'))
     dam2 = dam2.to(DEVICE).eval()
+    return dam2
 
 # GenStereo
-if 'genstereo' not in globals():
+def get_genstereo_model():
     genwarp_cfg = dict(
         pretrained_model_path='checkpoints',
         checkpoint_name=CHECKPOINT_NAME,
         half_precision_weights=True
     )
     genstereo = GenStereo(cfg=genwarp_cfg, device=DEVICE)
+    return genstereo
 
 # Adaptive Fusion
-if 'fusion_model' not in globals():
+def get_fusion_model():
     fusion_model = AdaptiveFusionLayer()
     fusion_checkpoint = join('checkpoints', CHECKPOINT_NAME, 'fusion_layer.pth')
     fusion_model.load_state_dict(torch.load(fusion_checkpoint, map_location='cpu'))
     fusion_model = fusion_model.to(DEVICE).eval()
+    return fusion_model
 
 # Crop the image to the shorter side.
 def crop(img: Image) -> Image:
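The two hunks above replace the module-level `if ... not in globals():` guards with plain getter functions, so no checkpoint is read and no model is built at import time. The trade-off is that, as written, every call to a getter reloads its weights from disk. If the process stays alive between requests, a cached variant restores load-once behaviour while keeping construction lazy. A minimal sketch, not part of this commit, assuming the getters stay zero-argument:

from functools import lru_cache

# Cache each getter's first result so subsequent requests reuse the
# already-loaded models instead of re-reading the .pth checkpoints.
get_dam2_model = lru_cache(maxsize=1)(get_dam2_model)
get_genstereo_model = lru_cache(maxsize=1)(get_genstereo_model)
get_fusion_model = lru_cache(maxsize=1)(get_fusion_model)

The remaining hunks call the getters at the point of use inside the request path: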
@@ -190,6 +193,7 @@ with tempfile.TemporaryDirectory() as tmpdir:
 
     image_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
 
+    dam2 = get_dam2_model()
     depth_dam2 = dam2.infer_image(image_bgr)
     depth = torch.tensor(depth_dam2).unsqueeze(0).unsqueeze(0).float().cuda()
 
@@ -202,6 +206,9 @@ with tempfile.TemporaryDirectory() as tmpdir:
     norm_disp = normalize_disp(depth)
     disp = norm_disp * scale_factor / 100 * IMAGE_SIZE
 
+    genstereo = get_genstereo_model()
+    fusion_model = get_fusion_model()
+
     renders = genstereo(
         src_image=image,
         src_disparity=disp,
@@ -231,4 +238,4 @@ with tempfile.TemporaryDirectory() as tmpdir:
     )
 
 if __name__ == '__main__':
-        demo.launch()
+    demo.launch()
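The getter calls land immediately before the models are used: `dam2 = get_dam2_model()` ahead of `infer_image`, and `genstereo` / `fusion_model` just before rendering. The final hunk appears to be a whitespace-only fix, since the visible text of `demo.launch()` is identical on both sides; only its indentation under the `if __name__ == '__main__':` guard changes. Deferring model construction into the request path like this is a common pattern when the GPU is attached only inside the request handler (for example, ZeroGPU Spaces), though the commit itself does not state its motivation.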