import cv2
import numpy as np
import dlib
import torch
import torch.nn.functional as F

AGE_ESTIMATION_MARGIN = 0.4
AGE_ESTIMATION_INPUT_SIZE = 224

def predict_age(image, model, face_detector, device, margin=AGE_ESTIMATION_MARGIN, input_size=AGE_ESTIMATION_INPUT_SIZE):
    """
    Predicts the age of faces in an image.

    Args:
        image (numpy.ndarray): The image as a NumPy array (HWC, RGB).
        model (torch.nn.Module): The age estimation model.
        face_detector (dlib.fhog_object_detector): The dlib face detector.
        device (torch.device): The device to run the model on.
        margin (float): The margin to add around the detected face.
        input_size (int): The size of the input image for the model.

    Returns:
        list: A list of dictionaries containing the age and face coordinates for each detected face.
    """
    # preprocess_image returns an RGB array (PIL -> NumPy), while the OpenCV drawing
    # and resize calls below assume BGR, so convert once here. dlib's detector
    # accepts the array in either channel order.
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    image_h, image_w = image.shape[:2]

    # Detect faces with the dlib face detector (upsampling the image 3 times)
    detected = face_detector(image, 3)
    faces = np.empty((len(detected), input_size, input_size, 3))
    age_data = []
    # Process each detected face
    if len(detected) > 0:
        for i, d in enumerate(detected):
            # Get face coordinates and dimensions
            x1, y1, x2, y2, w, h = (d.left(), d.top(), d.right() + 1,
                                    d.bottom() + 1, d.width(), d.height())

            # Calculate the expanded face region with margin, clamped to the image
            xw1 = max(int(x1 - margin * w), 0)
            yw1 = max(int(y1 - margin * h), 0)
            xw2 = min(int(x2 + margin * w), image_w - 1)
            yw2 = min(int(y2 + margin * h), image_h - 1)

            # Resize the face crop to the model's input size
            faces[i] = cv2.resize(image[yw1:yw2 + 1, xw1:xw2 + 1],
                                  (input_size, input_size))

            # Draw rectangles around the detected face and the expanded region
            cv2.rectangle(image, (x1, y1), (x2, y2), (255, 255, 255), 2)
            cv2.rectangle(image, (xw1, yw1), (xw2, yw2), (255, 0, 0), 2)
        # Prepare face images for model input (NHWC -> NCHW float tensor)
        inputs = torch.from_numpy(
            np.transpose(faces.astype(np.float32), (0, 3, 1, 2))).to(device)

        # Predict ages as the expected value over the 0-100 class distribution
        outputs = F.softmax(model(inputs), dim=-1).cpu().numpy()
        ages = np.arange(0, 101)
        predicted_ages = (outputs * ages).sum(axis=-1)

        # Store the predicted age and face coordinates
        for age, d in zip(predicted_ages, detected):
            age_text = f'{int(age)}'
            age_data.append({'age': int(age), 'text': age_text,
                             'face_coordinates': (d.left(), d.top())})

    # Return the list of age data for each detected face
    return age_data
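

# --- Usage sketch (illustrative, not part of the original app) ---
# A minimal example of calling predict_age. The actual Space presumably loads its
# own pretrained age-estimation checkpoint; here a torchvision ResNet with 101
# output classes (ages 0-100) stands in so the snippet runs end to end, and the
# image path is a placeholder.
if __name__ == "__main__":
    from torchvision.models import resnet18

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    face_detector = dlib.get_frontal_face_detector()

    # Stand-in model: any nn.Module mapping a 224x224 face crop to 101 age logits
    model = resnet18(num_classes=101).to(device).eval()

    # predict_age expects an RGB array, so convert OpenCV's BGR output
    bgr = cv2.imread("example.jpg")  # placeholder path
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    with torch.no_grad():
        results = predict_age(rgb, model, face_detector, device)
    for r in results:
        print(f"Predicted age {r['age']} for face at {r['face_coordinates']}")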