File size: 3,813 Bytes
319d3b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1011fc1
319d3b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1011fc1
 
319d3b5
 
 
 
 
 
 
1011fc1
319d3b5
 
1011fc1
319d3b5
1011fc1
 
 
 
319d3b5
 
 
e314b9a
319d3b5
 
1011fc1
319d3b5
 
1011fc1
319d3b5
1011fc1
 
 
 
319d3b5
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os
import cv2
import imghdr
import shutil
import warnings
import numpy as np
import gradio as gr
from dataclasses import dataclass
from mivolo.predictor import Predictor
from utils import is_url, download_file, get_jpg_files, MODEL_DIR

# Scratch directory for images downloaded from URLs; wiped and re-created on
# every URL request. NOTE(review): reusing "__pycache__" is unusual —
# presumably chosen because tooling already ignores it; confirm intent.
TMP_DIR = "./__pycache__"

@dataclass
class Cfg:
    """Configuration object handed to ``mivolo.predictor.Predictor``."""

    detector_weights: str  # path to the YOLOv8 person/face detector weights
    checkpoint: str  # path to the MiVOLO age/gender model checkpoint
    device: str = "cpu"  # inference device string
    with_persons: bool = True  # use person boxes in addition to face boxes
    disable_faces: bool = False  # if True, ignore face detections
    draw: bool = True  # draw detection results onto the output image

class ValidImgDetector:
    """Detect persons/faces in an image and report age/gender attributes.

    Wraps a MiVOLO ``Predictor`` loaded from weights under ``MODEL_DIR``.
    """

    predictor = None

    def __init__(self):
        # Weight files are expected to already exist under MODEL_DIR.
        detector_path = f"{MODEL_DIR}/yolov8x_person_face.pt"
        age_gender_path = f"{MODEL_DIR}/model_imdb_cross_person_4.22_99.46.pth.tar"
        predictor_cfg = Cfg(detector_path, age_gender_path)
        self.predictor = Predictor(predictor_cfg)

    def _detect(
        self,
        image: np.ndarray,
        score_threshold: float,
        iou_threshold: float,
        mode: str,
        predictor: Predictor,
    ) -> np.ndarray:
        """Run detection on ``image`` and summarize the recognized people.

        Returns a 4-tuple: (annotated image with channels reversed,
        has_child, has_female, has_male).

        Raises:
            ValueError: if ``mode`` is not one of the three known mode strings.
        """
        # input is RGB image, output must be RGB too
        predictor.detector.detector_kwargs["conf"] = score_threshold
        predictor.detector.detector_kwargs["iou"] = iou_threshold
        if mode == "Use persons and faces":
            use_persons, disable_faces = True, False
        elif mode == "Use persons only":
            use_persons, disable_faces = True, True
        elif mode == "Use faces only":
            use_persons, disable_faces = False, False
        else:
            # BUG FIX: the original fell through with use_persons/disable_faces
            # unbound, crashing later with a NameError. Fail fast instead.
            raise ValueError(f"Unknown detection mode: {mode!r}")

        predictor.age_gender_model.meta.use_persons = use_persons
        predictor.age_gender_model.meta.disable_faces = disable_faces
        # image = image[:, :, ::-1]  # RGB -> BGR
        detected_objects, out_im = predictor.recognize(image)
        has_child, has_female, has_male = False, False, False
        if len(detected_objects.ages) > 0:
            has_child = min(detected_objects.ages) < 18
            has_female = "female" in detected_objects.genders
            has_male = "male" in detected_objects.genders

        # Reverse channel order of the annotated image before returning.
        return out_im[:, :, ::-1], has_child, has_female, has_male

    def valid_img(self, img_path):
        """Detect on the image at ``img_path`` with default thresholds."""
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread returns None for missing/unreadable files; guard so
            # the predictor never receives a None frame.
            raise ValueError(f"Cannot read image file: {img_path!r}")
        return self._detect(image, 0.4, 0.7, "Use persons and faces", self.predictor)


def infer(photo: str):
    """Run person/face detection on a local image path or an image URL.

    Args:
        photo: Filesystem path to an image, or a URL (downloaded to TMP_DIR).

    Returns:
        (annotated_image, has_child, has_female, has_male) on success, or
        (None, None, None, error_message) when the input is not a readable
        image file.
    """
    if is_url(photo):
        # Re-create the scratch dir so a stale download is never reused.
        if os.path.exists(TMP_DIR):
            shutil.rmtree(TMP_DIR)

        photo = download_file(photo, f"{TMP_DIR}/download.jpg")

    # Validate the file BEFORE constructing the detector: the original built
    # ValidImgDetector (loading two heavy model checkpoints) even when the
    # input was going to be rejected.
    # NOTE(review): imghdr is deprecated and removed in Python 3.13 (PEP 594);
    # consider replacing this check with Pillow or a magic-number sniff.
    if not photo or not os.path.exists(photo) or imghdr.what(photo) is None:
        return None, None, None, "Please input the image correctly"

    detector = ValidImgDetector()
    return detector.valid_img(photo)


if __name__ == "__main__":
    with gr.Blocks() as iface:
        warnings.filterwarnings("ignore")

        def _result_outputs():
            # Build a fresh set of output components for each tab.
            return [
                gr.Image(label="Detection Result", type="numpy"),
                gr.Textbox(label="Has Child"),
                gr.Textbox(label="Has Female"),
                gr.Textbox(label="Has Male"),
            ]

        # Tab 1: detect on an uploaded local image file.
        with gr.Tab("Upload Mode"):
            gr.Interface(
                fn=infer,
                inputs=gr.Image(label="Upload Photo", type="filepath"),
                outputs=_result_outputs(),
                examples=get_jpg_files(f"{MODEL_DIR}/examples"),
                allow_flagging="never",
                cache_examples=False,
            )

        # Tab 2: detect on an image fetched from a URL.
        with gr.Tab("Online Mode"):
            gr.Interface(
                fn=infer,
                inputs=gr.Textbox(label="Online Picture URL"),
                outputs=_result_outputs(),
                allow_flagging="never",
            )

    iface.launch()