zerchen commited on
Commit
9757ba7
·
1 Parent(s): 7124e54

update app

Browse files
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
@@ -15,6 +15,7 @@ import json
15
  from torchvision import transforms
16
  from typing import Dict, Optional
17
  from PIL import Image, ImageDraw
 
18
  from lang_sam import LangSAM
19
 
20
  from wilor.models import load_wilor
@@ -27,15 +28,18 @@ from ultralytics import YOLO
27
  LIGHT_PURPLE=(0.25098039, 0.274117647, 0.65882353)
28
  STEEL_BLUE=(0.2745098, 0.5098039, 0.7058824)
29
 
 
 
 
30
  # Download and load checkpoints
31
- wilor_model, wilor_model_cfg = load_wilor(checkpoint_path = './pretrained_models/wilor_final.ckpt' , cfg_path= './pretrained_models/model_config.yaml')
32
  hand_detector = YOLO('./pretrained_models/detector.pt')
33
  # Setup the renderer
34
  renderer = Renderer(wilor_model_cfg, faces=wilor_model.mano.faces)
35
  # Setup the SAM model
36
  sam_model = LangSAM(sam_type="sam2.1_hiera_large")
37
  # Setup the HORT model
38
- hort_model = load_hort("./pretrained_models/hort_final.pth.tar")
39
 
40
  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
41
  wilor_model = wilor_model.to(device)
 
15
  from torchvision import transforms
16
  from typing import Dict, Optional
17
  from PIL import Image, ImageDraw
18
+ from huggingface_hub import hf_hub_download
19
  from lang_sam import LangSAM
20
 
21
  from wilor.models import load_wilor
 
28
  LIGHT_PURPLE=(0.25098039, 0.274117647, 0.65882353)
29
  STEEL_BLUE=(0.2745098, 0.5098039, 0.7058824)
30
 
31
+ wilor_checkpoint_path = hf_hub_download(repo_id="zerchen/hort_models", filename="wilor_final.ckpt")
32
+ hort_checkpoint_path = hf_hub_download(repo_id="zerchen/hort_models", filename="hort_final.pth.tar")
33
+
34
  # Download and load checkpoints
35
+ wilor_model, wilor_model_cfg = load_wilor(checkpoint_path = wilor_checkpoint_path, cfg_path= './pretrained_models/model_config.yaml')
36
  hand_detector = YOLO('./pretrained_models/detector.pt')
37
  # Setup the renderer
38
  renderer = Renderer(wilor_model_cfg, faces=wilor_model.mano.faces)
39
  # Setup the SAM model
40
  sam_model = LangSAM(sam_type="sam2.1_hiera_large")
41
  # Setup the HORT model
42
+ hort_model = load_hort(hort_checkpoint_path)
43
 
44
  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
45
  wilor_model = wilor_model.to(device)