# Scraped from a Hugging Face Space page (status: "Running on Zero").
# Revision e70400c; original file size 3,551 bytes.
import cv2
import numpy as np
import dlib
import torch
import torch.nn.functional as F
# Default for predict_age's `margin`: fraction of the detected face box's
# width/height added on each side before the face is cropped.
AGE_ESTIMATION_MARGIN = 0.4
# Default for predict_age's `input_size`: side length, in pixels, of the
# square face crop that is fed to the model.
AGE_ESTIMATION_INPUT_SIZE = 224
@torch.inference_mode()
def predict_age(image, model, face_detector, device, margin=AGE_ESTIMATION_MARGIN,
                input_size=AGE_ESTIMATION_INPUT_SIZE, upsample_num_times=3):
    """
    Predict the age of every face detected in an image.

    The image is converted from RGB to BGR, faces are located with
    ``face_detector``, each face box is enlarged by ``margin`` and cropped,
    and all crops are resized to ``input_size`` x ``input_size`` and run
    through ``model`` as one batch. Each face's age is the expected value of
    the model's softmax distribution over the integer ages ``0 .. C-1``,
    where ``C`` is the number of model outputs.

    Args:
        image (numpy.ndarray): HxWx3 image, assumed to be in RGB channel
            order. The caller's array is not modified.
        model (torch.nn.Module): Age estimation model. It receives an
            (N, 3, input_size, input_size) float32 batch of BGR pixel values
            with no normalization applied here, and its last output dimension
            is treated as age-bin logits.
        face_detector: dlib face detector, called as
            ``face_detector(image, upsample_num_times)``.
        device (torch.device): Device the model runs on.
        margin (float): Fraction of the face width/height added on each side
            of the detected box before cropping.
        input_size (int): Side length of the square crop fed to the model.
        upsample_num_times (int): Number of times dlib upsamples the image
            before detecting; higher values find smaller faces more slowly.

    Returns:
        list[dict]: One entry per detected face, in detection order, with keys
        ``'age'`` (int, the expected age truncated toward zero), ``'text'``
        (that age as a string) and ``'face_coordinates'`` (the ``(left, top)``
        corner of the detected face box).
    """
    # Input is RGB; convert to BGR for the OpenCV-based cropping below. The
    # model presumably expects BGR crops (cv2-style), so keep that order.
    # cvtColor allocates a new array, leaving the caller's image untouched.
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    detected = face_detector(image, upsample_num_times)
    if len(detected) == 0:
        return []

    # Crops are taken from the unmodified image. Nothing is drawn on it, so
    # overlapping margins of neighbouring faces cannot leak annotation pixels
    # into each other's crops.
    faces = np.empty((len(detected), input_size, input_size, 3), dtype=np.float32)
    for i, rect in enumerate(detected):
        faces[i] = _crop_face(image, rect, margin, input_size)

    # NHWC -> NCHW as expected by PyTorch convolutions.
    inputs = torch.from_numpy(np.transpose(faces, (0, 3, 1, 2))).to(device)
    probs = F.softmax(model(inputs), dim=-1).cpu().numpy()

    # Expected age under the predicted distribution. The bin count comes from
    # the model output instead of being hard-coded (101 bins -> ages 0..100).
    ages = np.arange(probs.shape[-1])
    predicted_ages = (probs * ages).sum(axis=-1)

    age_data = []
    for age, rect in zip(predicted_ages, detected):
        age_int = int(age)
        age_data.append({
            'age': age_int,
            'text': f'{age_int}',
            'face_coordinates': (rect.left(), rect.top()),
        })
    return age_data


def _crop_face(image, rect, margin, input_size):
    """Return the margin-expanded, border-clipped crop of ``rect`` resized to a square."""
    image_h, image_w = image.shape[:2]
    x1, y1 = rect.left(), rect.top()
    x2, y2 = rect.right() + 1, rect.bottom() + 1
    w, h = rect.width(), rect.height()
    # Expand the box by `margin` of its size on each side, clipped to the image.
    xw1 = max(int(x1 - margin * w), 0)
    yw1 = max(int(y1 - margin * h), 0)
    xw2 = min(int(x2 + margin * w), image_w - 1)
    yw2 = min(int(y2 + margin * h), image_h - 1)
    # xw2/yw2 are inclusive bounds, hence the +1 in the slices.
    return cv2.resize(image[yw1:yw2 + 1, xw1:xw2 + 1], (input_size, input_size))