Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -28,11 +28,17 @@ class Config:
|
|
28 |
|
29 |
class ModelManager:
|
30 |
@staticmethod
|
31 |
-
def load_model(checkpoint_name: str):
|
32 |
-
|
33 |
-
|
34 |
-
|
|
|
35 |
)
|
|
|
|
|
|
|
|
|
|
|
36 |
model = torch.jit.load(model_path)
|
37 |
model.eval()
|
38 |
model.to("cuda")
|
@@ -60,7 +66,7 @@ class ImageProcessor:
|
|
60 |
depth_map = depth_output.squeeze().cpu().numpy()
|
61 |
|
62 |
if seg_model_name != "no-bg-removal":
|
63 |
-
seg_model = ModelManager.load_model(Config.SEG_CHECKPOINTS[seg_model_name])
|
64 |
seg_output = ModelManager.run_model(seg_model, input_tensor, image.height, image.width)
|
65 |
seg_mask = (seg_output.argmax(dim=1) > 0).float().cpu().numpy()[0]
|
66 |
depth_map[seg_mask == 0] = np.nan
|
|
|
28 |
|
29 |
class ModelManager:
|
30 |
@staticmethod
|
31 |
+
def load_model(checkpoint_name: str, type='depth'):
|
32 |
+
if checkpoint_name == 'seg':
|
33 |
+
model_path = hf_hub_download(
|
34 |
+
repo_id="shimu0215/seg", # 你的模型仓库
|
35 |
+
filename="sapiens_1b_seg_foreground_epoch_8_torchscript.pt2", # 你的模型文件
|
36 |
)
|
37 |
+
else:
|
38 |
+
model_path = hf_hub_download(
|
39 |
+
repo_id="shimu0215/seg",
|
40 |
+
filename="sapiens_2b_render_people_epoch_25_torchscript.pt2",
|
41 |
+
)
|
42 |
model = torch.jit.load(model_path)
|
43 |
model.eval()
|
44 |
model.to("cuda")
|
|
|
66 |
depth_map = depth_output.squeeze().cpu().numpy()
|
67 |
|
68 |
if seg_model_name != "no-bg-removal":
|
69 |
+
seg_model = ModelManager.load_model(Config.SEG_CHECKPOINTS[seg_model_name],type='seg')
|
70 |
seg_output = ModelManager.run_model(seg_model, input_tensor, image.height, image.width)
|
71 |
seg_mask = (seg_output.argmax(dim=1) > 0).float().cpu().numpy()[0]
|
72 |
depth_map[seg_mask == 0] = np.nan
|