hysts HF Staff commited on
Commit
45bd6c9
·
1 Parent(s): 6be1911
Files changed (6) hide show
  1. .gitignore +1 -0
  2. .gitmodules +6 -0
  3. app.py +157 -0
  4. face_alignment +1 -0
  5. face_detection +1 -0
  6. requirements.txt +4 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ images
.gitmodules ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ [submodule "face_detection"]
2
+ path = face_detection
3
+ url = https://github.com/ibug-group/face_detection
4
+ [submodule "face_alignment"]
5
+ path = face_alignment
6
+ url = https://github.com/ibug-group/face_alignment
app.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import functools
7
+ import os
8
+ import pathlib
9
+ import sys
10
+ import tarfile
11
+
12
+ import cv2
13
+ import gradio as gr
14
+ import huggingface_hub
15
+ import numpy as np
16
+ import torch
17
+
18
+ sys.path.insert(0, 'face_detection')
19
+ sys.path.insert(0, 'face_alignment')
20
+
21
+ from ibug.face_alignment import FANPredictor
22
+ from ibug.face_detection import RetinaFacePredictor
23
+
24
+ REPO_URL = 'https://github.com/ibug-group/face_alignment'
25
+ TITLE = 'ibug-group/face_alignment'
26
+ DESCRIPTION = f'This is a demo for {REPO_URL}.'
27
+ ARTICLE = None
28
+
29
+ TOKEN = os.environ['TOKEN']
30
+
31
+
32
+ def parse_args() -> argparse.Namespace:
33
+ parser = argparse.ArgumentParser()
34
+ parser.add_argument('--device', type=str, default='cpu')
35
+ parser.add_argument('--theme', type=str)
36
+ parser.add_argument('--live', action='store_true')
37
+ parser.add_argument('--share', action='store_true')
38
+ parser.add_argument('--port', type=int)
39
+ parser.add_argument('--disable-queue',
40
+ dest='enable_queue',
41
+ action='store_false')
42
+ parser.add_argument('--allow-flagging', type=str, default='never')
43
+ parser.add_argument('--allow-screenshot', action='store_true')
44
+ return parser.parse_args()
45
+
46
+
47
+ def load_sample_images() -> list[pathlib.Path]:
48
+ image_dir = pathlib.Path('images')
49
+ if not image_dir.exists():
50
+ image_dir.mkdir()
51
+ dataset_repo = 'hysts/input-images'
52
+ filenames = ['001.tar']
53
+ for name in filenames:
54
+ path = huggingface_hub.hf_hub_download(dataset_repo,
55
+ name,
56
+ repo_type='dataset',
57
+ use_auth_token=TOKEN)
58
+ with tarfile.open(path) as f:
59
+ f.extractall(image_dir.as_posix())
60
+ return sorted(image_dir.rglob('*.jpg'))
61
+
62
+
63
+ def load_detector(device: torch.device) -> RetinaFacePredictor:
64
+ model = RetinaFacePredictor(
65
+ threshold=0.8,
66
+ device=device,
67
+ model=RetinaFacePredictor.get_model('mobilenet0.25'))
68
+ return model
69
+
70
+
71
+ def load_model(model_name: str, device: torch.device) -> FANPredictor:
72
+ model = FANPredictor(device=device,
73
+ model=FANPredictor.get_model(model_name))
74
+ return model
75
+
76
+
77
+ def predict(image: np.ndarray, model_name: str, max_num_faces: int,
78
+ landmark_score_threshold: int, detector: RetinaFacePredictor,
79
+ models: dict[str, FANPredictor]) -> np.ndarray:
80
+ model = models[model_name]
81
+
82
+ # RGB -> BGR
83
+ image = image[:, :, ::-1]
84
+
85
+ faces = detector(image, rgb=False)
86
+ if len(faces) == 0:
87
+ raise RuntimeError('No face was found.')
88
+ faces = sorted(list(faces), key=lambda x: -x[4])[:max_num_faces]
89
+ faces = np.asarray(faces)
90
+ landmarks, landmark_scores = model(image, faces, rgb=False)
91
+
92
+ res = image.copy()
93
+ for face, pts, scores in zip(faces, landmarks, landmark_scores):
94
+ box = np.round(face[:4]).astype(int)
95
+ cv2.rectangle(res, tuple(box[:2]), tuple(box[2:]), (0, 255, 0), 2)
96
+ for pt, score in zip(np.round(pts).astype(int), scores):
97
+ if score < landmark_score_threshold:
98
+ continue
99
+ cv2.circle(res, tuple(pt), 2, (0, 255, 0), cv2.FILLED)
100
+
101
+ return res[:, :, ::-1]
102
+
103
+
104
+ def main():
105
+ gr.close_all()
106
+
107
+ args = parse_args()
108
+ device = torch.device(args.device)
109
+
110
+ detector = load_detector(device)
111
+
112
+ model_names = [
113
+ '2dfan2',
114
+ '2dfan4',
115
+ '2dfan2_alt',
116
+ ]
117
+ models = {name: load_model(name, device=device) for name in model_names}
118
+
119
+ func = functools.partial(predict, detector=detector, models=models)
120
+ func = functools.update_wrapper(func, predict)
121
+
122
+ image_paths = load_sample_images()
123
+ examples = [[path.as_posix(), model_names[0], 10, 0.2]
124
+ for path in image_paths]
125
+
126
+ gr.Interface(
127
+ func,
128
+ [
129
+ gr.inputs.Image(type='numpy', label='Input'),
130
+ gr.inputs.Radio(model_names,
131
+ type='value',
132
+ default=model_names[1],
133
+ label='Model'),
134
+ gr.inputs.Slider(
135
+ 1, 20, step=1, default=10, label='Max Number of Faces'),
136
+ gr.inputs.Slider(
137
+ 0, 1, step=0.05, default=0.2,
138
+ label='Landmark Score Threshold'),
139
+ ],
140
+ gr.outputs.Image(type='numpy', label='Output'),
141
+ examples=examples,
142
+ title=TITLE,
143
+ description=DESCRIPTION,
144
+ article=ARTICLE,
145
+ theme=args.theme,
146
+ allow_screenshot=args.allow_screenshot,
147
+ allow_flagging=args.allow_flagging,
148
+ live=args.live,
149
+ ).launch(
150
+ enable_queue=args.enable_queue,
151
+ server_port=args.port,
152
+ share=args.share,
153
+ )
154
+
155
+
156
+ if __name__ == '__main__':
157
+ main()
face_alignment ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit aef843c05be718fbd87ee2cb25fa3a015b7e59b0
face_detection ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit bc1e392b11d731fa20b1397c8ff3faed5e7fc76e
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ numpy==1.22.3
2
+ opencv-python-headless==4.5.5.64
3
+ torch==1.11.0
4
+ torchvision==0.12.0