Commit 34e2c3f · 1 Parent(s): 098cfe5 · wip
Files changed:
- age_estimation/age_estimation.py  +13 -6
- age_estimation/model.py  +1 -0
- age_estimation/predict.py  +43 -13
- app.py  +8 -7
- detection/face_detection.py  +29 -14
- detection/object_detection.py  +8 -7

age_estimation/age_estimation.py CHANGED

@@ -22,14 +22,18 @@ def age_estimation(input_type, uploaded_image, image_url, base64_string):
         base64_string (str): The image base64 string (if input_type is "Enter Base64").
 
     Returns:
-        …
+        tuple: A tuple containing:
+            - str: A summary string of the estimated ages, or an error message.
+            - list: A list of dictionaries, where each dictionary represents the age
+              estimation data for a detected face, or an empty list if no faces
+              were detected or an error occurred.
     """
     # Use the centralized function to get the image
     image = get_image_from_input(input_type, uploaded_image, image_url, base64_string)
 
     if image is None:
         print("Image is None after loading/selection for age estimation.")
-        return "Error: Image processing failed or no valid input provided."
+        return "Error: Image processing failed or no valid input provided.", []
 
     try:
         face_detector = load_face_detector()
@@ -44,10 +48,13 @@ def age_estimation(input_type, uploaded_image, image_url, base64_string):
         age_data = predict_age(processed_image, model, face_detector, device)
 
         if age_data:
-            # …
-            …
+            # Create a summary string of all estimated ages
+            age_summary = "Estimated Ages: " + ", ".join(
+                [str(face["age"]) for face in age_data]
+            )
+            return age_summary, age_data
        else:
-            return "No faces detected"
+            return "No faces detected", []
    except Exception as e:
        print(f"Error in age estimation: {e}")
-        return f"Error in age estimation: {e}"
+        return f"Error in age estimation: {e}", []
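
Note: age_estimation now returns a (summary, data) pair instead of a bare string. A minimal sketch of how a caller would unpack the new return value; the input values and printed output below are made up:

# Hypothetical call; "Enter URL" and the sample URL are illustrative inputs.
summary, faces = age_estimation(
    "Enter URL", None, "https://example.com/people.jpg", None
)
print(summary)  # e.g. "Estimated Ages: 34, 29", or an error message
for face in faces:  # empty list when no faces were detected
    box = face["face_coordinates"]
    print(face["age"], box["x"], box["y"], box["w"], box["h"])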

age_estimation/model.py CHANGED

@@ -3,6 +3,7 @@ import pretrainedmodels
 import torch
 import torch.nn as nn
 
+
 def get_model(model_name="se_resnext50_32x4d", num_classes=101, pretrained="imagenet"):
     """
     Loads a pre-trained model.

age_estimation/predict.py CHANGED

@@ -7,8 +7,16 @@ import torch.nn.functional as F
 AGE_ESTIMATION_MARGIN = 0.4
 AGE_ESTIMATION_INPUT_SIZE = 224
 
+
 @torch.inference_mode()
-def predict_age(image, model, face_detector, device, margin=AGE_ESTIMATION_MARGIN, input_size=AGE_ESTIMATION_INPUT_SIZE):
+def predict_age(
+    image,
+    model,
+    face_detector,
+    device,
+    margin=AGE_ESTIMATION_MARGIN,
+    input_size=AGE_ESTIMATION_INPUT_SIZE,
+):
     """
     Predicts the age of faces in an image.
 
@@ -22,6 +30,8 @@ def predict_age(image, model, face_detector, device, margin=AGE_ESTIMATION_MARGIN, input_size=AGE_ESTIMATION_INPUT_SIZE):
 
     Returns:
         list: A list of dictionaries containing the age and face coordinates for each detected face.
+            The 'face_coordinates' key contains a dictionary with 'x', 'y', 'w', and 'h' keys
+            representing the bounding box of the detected face.
     """
     # Read the image using OpenCV
     # The image is already a NumPy array (HWC, BGR)
@@ -44,9 +54,15 @@ def predict_age(image, model, face_detector, device, margin=AGE_ESTIMATION_MARGIN, input_size=AGE_ESTIMATION_INPUT_SIZE):
     if len(detected) > 0:
         for i, d in enumerate(detected):
             # Get face coordinates and dimensions
-            x1, y1, x2, y2, w, h = …
-            …
-            …
+            x1, y1, x2, y2, w, h = (
+                d.left(),
+                d.top(),
+                d.right() + 1,
+                d.bottom() + 1,
+                d.width(),
+                d.height(),
+            )
+
             # Calculate expanded face region with margin
             xw1 = max(int(x1 - margin * w), 0)
             yw1 = max(int(y1 - margin * h), 0)
@@ -54,8 +70,9 @@ def predict_age(image, model, face_detector, device, margin=AGE_ESTIMATION_MARGIN, input_size=AGE_ESTIMATION_INPUT_SIZE):
             yw2 = min(int(y2 + margin * h), image_h - 1)
 
             # Resize face image to the required input size for the model
-            faces[i] = cv2.resize(…
-            …
+            faces[i] = cv2.resize(
+                image[yw1 : yw2 + 1, xw1 : xw2 + 1], (input_size, input_size)
+            )
 
             # Draw rectangles around the detected face and the expanded region
             cv2.rectangle(image, (x1, y1), (x2, y2), (255, 255, 255), 2)
@@ -63,17 +80,30 @@ def predict_age(image, model, face_detector, device, margin=AGE_ESTIMATION_MARGIN, input_size=AGE_ESTIMATION_INPUT_SIZE):
 
         # Prepare face images for model input
         inputs = torch.from_numpy(
-            np.transpose(faces.astype(np.float32), (0, 3, 1, 2))
-        …
+            np.transpose(faces.astype(np.float32), (0, 3, 1, 2))
+        ).to(device)
+
         # Perform age prediction using the model
         outputs = F.softmax(model(inputs), dim=-1).cpu().numpy()
         ages = np.arange(0, 101)
         predicted_ages = (outputs * ages).sum(axis=-1)
 
-        # Store the predicted age and face coordinates
+        # Store the predicted age and face coordinates in [x, y, w, h] format
         for age, d in zip(predicted_ages, detected):
-            …
-            …
-            …
+            x, y, w, h = d.left(), d.top(), d.width(), d.height()
+            age_text = f"{int(age)}"
+            age_data.append(
+                {
+                    "age": int(age),
+                    "text": age_text,
+                    "face_coordinates": {
+                        "x": int(x),
+                        "y": int(y),
+                        "w": int(w),
+                        "h": int(h),
+                    },
+                }
+            )
+
     # Return the list of age data for each detected face
-    return age_data
+    return age_data
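
Note: the readout itself is unchanged by this diff: the model emits a softmax over the 101 age classes 0..100, and the predicted age is the probability-weighted mean of those classes. A self-contained sketch of that computation with a made-up distribution:

import numpy as np

# Toy 101-class distribution with its mass centred on age 30 (values made up)
probs = np.zeros(101)
probs[28:33] = [0.1, 0.2, 0.4, 0.2, 0.1]
ages = np.arange(0, 101)
print((probs * ages).sum())  # ~30.0: the expected age under the distribution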

app.py CHANGED

@@ -54,7 +54,7 @@ with gr.Blocks() as demo:
            inputs=[face_input_type],
            outputs=[face_img_upload, face_url_input, face_base64_input],
            queue=False,
-            api_name=False
+            api_name=False,
        )
 
        # Link process button to the face detection function
@@ -94,8 +94,9 @@ with gr.Blocks() as demo:
        # Process Button
        age_process_btn = gr.Button("Estimate Age")
 
-        # Output
-        age_text_output = gr.Textbox(label="Estimated Age")
+        # Output Components
+        age_text_output = gr.Textbox(label="Estimated Age Summary")
+        age_raw_output = gr.JSON(label="Raw Age Estimation Data")
 
        # Link radio button change to visibility update function
        age_input_type.change(
@@ -103,15 +104,15 @@ with gr.Blocks() as demo:
            inputs=[age_input_type],
            outputs=[age_img_upload, age_url_input, age_base64_input],
            queue=False,
-            api_name=False
+            api_name=False,
        )
 
        # Link process button to the age estimation function
-        # The age_estimation function will
+        # The age_estimation function will now return a tuple
        age_process_btn.click(
            fn=age_estimation,
            inputs=[age_input_type, age_img_upload, age_url_input, age_base64_input],
-            outputs=age_text_output,
+            outputs=[age_text_output, age_raw_output],
        )
    # Create a tab for object detection
    with gr.Tab("Object Detection"):
@@ -146,7 +147,7 @@ with gr.Blocks() as demo:
            inputs=[obj_input_type],
            outputs=[obj_img_upload, obj_url_input, obj_base64_input],
            queue=False,
-            api_name=False
+            api_name=False,
        )
 
        # Link process button to the object detection function
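
Note: the click handler now drives two components because age_estimation returns a tuple. A minimal, self-contained sketch of the same wiring pattern, with a stub standing in for age_estimation (assumes gradio is installed; names below are illustrative):

import gradio as gr

def fake_age_estimation(url):  # stub; return shape mirrors the diff above
    return "Estimated Ages: 31", [
        {"age": 31, "face_coordinates": {"x": 10, "y": 12, "w": 80, "h": 80}}
    ]

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Image URL")
    btn = gr.Button("Estimate Age")
    text_out = gr.Textbox(label="Estimated Age Summary")
    json_out = gr.JSON(label="Raw Age Estimation Data")
    # One function, two outputs: Gradio maps the returned tuple positionally
    btn.click(fn=fake_age_estimation, inputs=[inp], outputs=[text_out, json_out])

# demo.launch()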

detection/face_detection.py CHANGED

@@ -8,12 +8,17 @@ from PIL import Image
 
 # Local imports
 from utils.image_utils import load_image, preprocess_image, get_image_from_input
-from utils.face_detector import load_face_detector  # Assuming this is the dlib detector loader
+from utils.face_detector import (
+    load_face_detector,
+)  # Assuming this is the dlib detector loader
 
 # Define constants
 HAAR_CASCADE_FILENAME = "haarcascade_frontalface_default.xml"
 
-def face_detection(input_type, uploaded_image, image_url, base64_string, face_detection_method):
+
+def face_detection(
+    input_type, uploaded_image, image_url, base64_string, face_detection_method
+):
     """
     Performs face detection on the image from various input types using the selected method.
 
@@ -36,7 +41,7 @@ def face_detection(input_type, uploaded_image, image_url, base64_string, face_detection_method):
 
     if image is None:
         print("Image is None after loading/selection.")
-        return None, []
+        return None, []  # Return None for image and empty list for bboxes
 
     processed_image = None
     bounding_boxes = []
@@ -54,20 +59,26 @@ def face_detection(input_type, uploaded_image, image_url, base64_string, face_detection_method):
            # Ensure the haarcascade file is accessible.
            # This path might need adjustment depending on the environment.
            # Construct the full path to the Haar cascade file
-            cascade_path = os.path.join(…
+            cascade_path = os.path.join(
+                cv2.data.haarcascades, HAAR_CASCADE_FILENAME
+            )
 
            # Check if the cascade file exists
            if not os.path.exists(cascade_path):
-                …
-                …
-                …
+                error_message = f"Error: Haar cascade file not found at {cascade_path}. Please ensure OpenCV is installed correctly and the file exists."
+                print(error_message)
+                return None, []  # Return None for image and empty list for bboxes
 
            face_cascade = cv2.CascadeClassifier(cascade_path)
 
            faces = face_cascade.detectMultiScale(gray, 1.1, 4)
            for x, y, w, h in faces:
-                cv2.rectangle(…
-                …
+                cv2.rectangle(
+                    processed_image, (x, y), (x + w, y + h), (255, 0, 0), 2
+                )
+                bounding_boxes.append(
+                    {"x": int(x), "y": int(y), "w": int(w), "h": int(h)}
+                )
 
        elif face_detection_method == "dlib":
            print("Using dlib for face detection.")
@@ -75,15 +86,19 @@ def face_detection(input_type, uploaded_image, image_url, base64_string, face_detection_method):
            # dlib works on RGB images, but the detector can take grayscale
            # However, the rectangles are relative to the original image size
            # Let's use the original processed_image (RGB numpy array) for drawing
-            faces = face_detector(processed_image, 1)
+            faces = face_detector(processed_image, 1)  # 1 is the upsample level
            for face in faces:
                x, y, w, h = face.left(), face.top(), face.width(), face.height()
-                cv2.rectangle(…
-                …
+                cv2.rectangle(
+                    processed_image, (x, y), (x + w, y + h), (255, 0, 0), 2
+                )
+                bounding_boxes.append(
+                    {"x": int(x), "y": int(y), "w": int(w), "h": int(h)}
+                )
 
            return processed_image, bounding_boxes
        else:
-            return None, []
+            return None, []  # Return None for image and empty list for bboxes
    except Exception as e:
        print(f"Error in face detection processing: {e}")
-        return None, []
+        return None, []  # Return None for image and empty list for bboxes
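
Note: cv2.data.haarcascades is the directory in which opencv-python ships its bundled cascade XML files, so joining it with the filename yields an absolute path that can be verified before loading. A quick sketch, assuming opencv-python is installed:

import os
import cv2

cascade_path = os.path.join(cv2.data.haarcascades, "haarcascade_frontalface_default.xml")
print(cascade_path, os.path.exists(cascade_path))

face_cascade = cv2.CascadeClassifier(cascade_path)
print(face_cascade.empty())  # False once the cascade has loaded successfully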

detection/object_detection.py CHANGED

@@ -8,6 +8,7 @@ import numpy as np
 # Local imports
 from utils.image_utils import load_image, preprocess_image
 
+
 def object_detection(input_type, uploaded_image, image_url, base64_string):
     """
     Performs object detection on the image from various input types.
@@ -25,26 +26,26 @@ def object_detection(input_type, uploaded_image, image_url, base64_string):
     input_value = None
 
     if input_type == "Upload File" and uploaded_image is not None:
-        image = uploaded_image
-        print("Using uploaded image (PIL) for object detection")
+        image = uploaded_image  # This is a PIL Image
+        print("Using uploaded image (PIL) for object detection")  # Debug print
 
     elif input_type == "Enter URL" and image_url and image_url.strip():
         input_value = image_url
-        print(f"Using URL for object detection: {input_value}")
+        print(f"Using URL for object detection: {input_value}")  # Debug print
 
     elif input_type == "Enter Base64" and base64_string and base64_string.strip():
         input_value = base64_string
-        print(f"Using Base64 string for object detection")
+        print(f"Using Base64 string for object detection")  # Debug print
 
     else:
         print("No valid input provided for object detection based on selected type.")
-        return None
+        return None  # No valid input
 
     # If input_value is set (URL or Base64), use load_image
     if input_value:
         image = load_image(input_value)
         if image is None:
-            return None
+            return None  # load_image failed
 
     # Now 'image' should be a PIL Image or None
     if image is None:
@@ -62,4 +63,4 @@ def object_detection(input_type, uploaded_image, image_url, base64_string):
         return processed_image
     except Exception as e:
         print(f"Error in object detection processing: {e}")
-        return None
+        return None
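
Note: the real load_image helper lives in utils/image_utils.py and is not part of this diff; the sketch below is only a guess at the URL-vs-Base64 dispatch the code above implies, assuming requests and Pillow are available:

import base64
import io

import requests
from PIL import Image

def load_image(input_value):  # hypothetical reimplementation, not the repo's code
    try:
        if input_value.startswith(("http://", "https://")):
            resp = requests.get(input_value, timeout=10)
            resp.raise_for_status()
            return Image.open(io.BytesIO(resp.content))
        # Otherwise treat the string as Base64-encoded image bytes
        return Image.open(io.BytesIO(base64.b64decode(input_value)))
    except Exception as e:
        print(f"load_image failed: {e}")
        return None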
|