Update app.py
Browse files
app.py
CHANGED
@@ -96,7 +96,7 @@ if 'genstereo' not in globals():
|
|
96 |
if 'fusion_model' not in globals():
|
97 |
fusion_model = AdaptiveFusionLayer()
|
98 |
fusion_checkpoint = join('checkpoints', CHECKPOINT_NAME, 'fusion_layer.pth')
|
99 |
-
fusion_model.load_state_dict(torch.load(fusion_checkpoint))
|
100 |
fusion_model = fusion_model.to(DEVICE).eval()
|
101 |
|
102 |
# Crop the image to the shorter side.
|
|
|
96 |
if 'fusion_model' not in globals():
|
97 |
fusion_model = AdaptiveFusionLayer()
|
98 |
fusion_checkpoint = join('checkpoints', CHECKPOINT_NAME, 'fusion_layer.pth')
|
99 |
+
fusion_model.load_state_dict(torch.load(fusion_checkpoint, map_location='cpu'))
|
100 |
fusion_model = fusion_model.to(DEVICE).eval()
|
101 |
|
102 |
# Crop the image to the shorter side.
|