shimu0215 commited on
Commit
8d9bcdc
·
verified ·
1 Parent(s): 77ffbb1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -5
app.py CHANGED
@@ -28,11 +28,17 @@ class Config:
28
 
29
  class ModelManager:
30
  @staticmethod
31
- def load_model(checkpoint_name: str):
32
- model_path = hf_hub_download(
33
- repo_id="shimu0215/seg",
34
- filename="sapiens_2b_render_people_epoch_25_torchscript.pt2",
 
35
  )
 
 
 
 
 
36
  model = torch.jit.load(model_path)
37
  model.eval()
38
  model.to("cuda")
@@ -60,7 +66,7 @@ class ImageProcessor:
60
  depth_map = depth_output.squeeze().cpu().numpy()
61
 
62
  if seg_model_name != "no-bg-removal":
63
- seg_model = ModelManager.load_model(Config.SEG_CHECKPOINTS[seg_model_name])
64
  seg_output = ModelManager.run_model(seg_model, input_tensor, image.height, image.width)
65
  seg_mask = (seg_output.argmax(dim=1) > 0).float().cpu().numpy()[0]
66
  depth_map[seg_mask == 0] = np.nan
 
28
 
29
  class ModelManager:
30
  @staticmethod
31
+ def load_model(checkpoint_name: str, type='depth'):
32
+ if checkpoint_name == 'seg':
33
+ model_path = hf_hub_download(
34
+ repo_id="shimu0215/seg", # 你的模型仓库
35
+ filename="sapiens_1b_seg_foreground_epoch_8_torchscript.pt2", # 你的模型文件
36
  )
37
+ else:
38
+ model_path = hf_hub_download(
39
+ repo_id="shimu0215/seg",
40
+ filename="sapiens_2b_render_people_epoch_25_torchscript.pt2",
41
+ )
42
  model = torch.jit.load(model_path)
43
  model.eval()
44
  model.to("cuda")
 
66
  depth_map = depth_output.squeeze().cpu().numpy()
67
 
68
  if seg_model_name != "no-bg-removal":
69
+ seg_model = ModelManager.load_model(Config.SEG_CHECKPOINTS[seg_model_name],type='seg')
70
  seg_output = ModelManager.run_model(seg_model, input_tensor, image.height, image.width)
71
  seg_mask = (seg_output.argmax(dim=1) > 0).float().cpu().numpy()[0]
72
  depth_map[seg_mask == 0] = np.nan