File size: 3,551 Bytes
e70400c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import cv2
import numpy as np
import dlib
import torch
import torch.nn.functional as F

# Fraction of the detected face's width (horizontally) and height (vertically)
# added on each side of the face box before cropping it for the age model.
AGE_ESTIMATION_MARGIN: float = 0.4
# Side length, in pixels, of the square face crop fed to the age model.
AGE_ESTIMATION_INPUT_SIZE: int = 224

def _expanded_face_box(rect, margin, image_w, image_h):
    """Grow a dlib face rectangle by ``margin`` on every side and clamp it to the image.

    Args:
        rect: dlib rectangle exposing ``left()``, ``top()``, ``right()``,
            ``bottom()``, ``width()`` and ``height()``.
        margin (float): Fraction of the face width (horizontally) and height
            (vertically) added on each side.
        image_w (int): Image width in pixels.
        image_h (int): Image height in pixels.

    Returns:
        tuple: ``(x1, y1, x2, y2)`` with inclusive bounds, all inside the image.
    """
    x1, y1 = rect.left(), rect.top()
    # dlib's right()/bottom() are inclusive; +1 makes them exclusive before growing.
    x2, y2 = rect.right() + 1, rect.bottom() + 1
    w, h = rect.width(), rect.height()
    xw1 = max(int(x1 - margin * w), 0)
    yw1 = max(int(y1 - margin * h), 0)
    xw2 = min(int(x2 + margin * w), image_w - 1)
    yw2 = min(int(y2 + margin * h), image_h - 1)
    return xw1, yw1, xw2, yw2


@torch.inference_mode()
def predict_age(image, model, face_detector, device, margin=AGE_ESTIMATION_MARGIN,
                input_size=AGE_ESTIMATION_INPUT_SIZE, upsample_num_times=3):
    """Detect faces in an RGB image and estimate the age of each one.

    Each detected face box is grown by ``margin``, cropped from a BGR copy of
    the image, and resized to ``input_size`` x ``input_size``. The crops are
    fed to ``model`` as one batch. The predicted age is the expectation of the
    softmax distribution over the model's output bins, where bin ``k`` means
    age ``k``.

    Runs under ``torch.inference_mode()``, which only disables autograd. It
    does not switch ``model`` to eval mode; the caller is presumably expected
    to have called ``model.eval()`` (not checked here).

    Args:
        image (numpy.ndarray): RGB image of shape (H, W, 3), presumably uint8.
            The caller's array is not modified.
        model (torch.nn.Module): Age estimation model. It receives a float32
            tensor of shape (N, 3, input_size, input_size) holding un-normalised
            BGR pixel values. It must return logits of shape (N, num_bins).
        face_detector: dlib face detector, called as
            ``face_detector(image, upsample_num_times)``.
        device (torch.device): Device on which to run the model.
        margin (float): Fraction of the face width/height added on each side
            of the detected box before cropping.
        input_size (int): Side length in pixels of the square model input.
        upsample_num_times (int): Upsampling count passed to the dlib detector.
            Defaults to 3, the previously hard-coded value.

    Returns:
        list[dict]: One entry per detected face, in detector order, with keys:
            ``'age'`` (int, the expected age truncated toward zero),
            ``'text'`` (str form of ``'age'``), and ``'face_coordinates'``
            (``(left, top)`` of the un-expanded detected box).
            The list is empty when no face is found.
    """
    # dlib documents its detectors as taking grayscale or RGB images, so run
    # detection on the RGB input as given rather than on a BGR copy.
    detected = face_detector(image, upsample_num_times)
    if len(detected) == 0:
        return []

    # Crops for the model come from a BGR copy, matching OpenCV's channel order.
    # NOTE(review): assumes the model was trained on BGR (cv2-loaded) images — confirm.
    bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    image_h, image_w = bgr.shape[:2]

    # No rectangles are drawn on ``bgr``. It is a private copy that is never
    # returned, and drawn boxes would leak into overlapping crops of later faces.
    faces = np.empty((len(detected), input_size, input_size, 3), dtype=np.float32)
    for i, rect in enumerate(detected):
        xw1, yw1, xw2, yw2 = _expanded_face_box(rect, margin, image_w, image_h)
        faces[i] = cv2.resize(bgr[yw1:yw2 + 1, xw1:xw2 + 1], (input_size, input_size))

    # NHWC -> NCHW, as expected by the model.
    inputs = torch.from_numpy(np.transpose(faces, (0, 3, 1, 2))).to(device)
    probs = F.softmax(model(inputs), dim=-1).cpu().numpy()

    # Expected age: bin k stands for age k. The bin count is derived from the
    # output width instead of assuming exactly 101 bins.
    ages = np.arange(probs.shape[-1])
    predicted_ages = (probs * ages).sum(axis=-1)

    return [
        {
            'age': int(age),
            'text': f'{int(age)}',
            'face_coordinates': (rect.left(), rect.top()),
        }
        for age, rect in zip(predicted_ages, detected)
    ]