hysts HF Staff commited on
Commit
54102f1
·
1 Parent(s): 75a9695
Files changed (1) hide show
  1. app.py +17 -4
app.py CHANGED
@@ -2,13 +2,13 @@
2
 
3
  from __future__ import annotations
4
 
5
- import os
6
  import pathlib
7
  import sys
8
 
9
  import cv2
10
  import gradio as gr
11
  import numpy as np
 
12
  import torch
13
 
14
  sys.path.insert(0, "face_detection")
@@ -20,15 +20,29 @@ from ibug.face_detection import RetinaFacePredictor
20
  DESCRIPTION = "# [ibug-group/face_alignment](https://github.com/ibug-group/face_alignment)"
21
 
22
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
23
- detector = RetinaFacePredictor(threshold=0.8, device=device, model=RetinaFacePredictor.get_model("mobilenet0.25"))
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  model_names = [
25
  "2dfan2",
26
  "2dfan4",
27
  "2dfan2_alt",
28
  ]
29
- models = {name: FANPredictor(device=device, model=FANPredictor.get_model(name)) for name in model_names}
30
 
31
 
 
32
  def predict(image: np.ndarray, model_name: str, max_num_faces: int, landmark_score_threshold: int) -> np.ndarray:
33
  model = models[model_name]
34
 
@@ -72,7 +86,6 @@ with gr.Blocks(css="style.css") as demo:
72
  inputs=[image, model_name, max_num_faces, landmark_score_thrshold],
73
  outputs=result,
74
  fn=predict,
75
- cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
76
  )
77
  run_button.click(
78
  fn=predict,
 
2
 
3
  from __future__ import annotations
4
 
 
5
  import pathlib
6
  import sys
7
 
8
  import cv2
9
  import gradio as gr
10
  import numpy as np
11
+ import spaces
12
  import torch
13
 
14
  sys.path.insert(0, "face_detection")
 
20
  DESCRIPTION = "# [ibug-group/face_alignment](https://github.com/ibug-group/face_alignment)"
21
 
22
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
23
+ detector = RetinaFacePredictor(threshold=0.8, device="cpu", model=RetinaFacePredictor.get_model("mobilenet0.25"))
24
+ detector.device = device
25
+ detector.net.to(device)
26
+
27
+
28
+ def load_model(model_name: str, device: torch.device) -> FANPredictor:
29
+ model = FANPredictor(
30
+ device="cpu", model=FANPredictor.get_model(model_name), config=FANPredictor.create_config(use_jit=False)
31
+ )
32
+ model.device = device
33
+ model.net.to(device)
34
+ return model
35
+
36
+
37
  model_names = [
38
  "2dfan2",
39
  "2dfan4",
40
  "2dfan2_alt",
41
  ]
42
+ models = {name: load_model(name, device) for name in model_names}
43
 
44
 
45
+ @spaces.GPU
46
  def predict(image: np.ndarray, model_name: str, max_num_faces: int, landmark_score_threshold: int) -> np.ndarray:
47
  model = models[model_name]
48
 
 
86
  inputs=[image, model_name, max_num_faces, landmark_score_thrshold],
87
  outputs=result,
88
  fn=predict,
 
89
  )
90
  run_button.click(
91
  fn=predict,