Commit 34e2c3f by rifatramadhani (parent: 098cfe5)
age_estimation/age_estimation.py CHANGED
@@ -22,14 +22,18 @@ def age_estimation(input_type, uploaded_image, image_url, base64_string):
         base64_string (str): The image base64 string (if input_type is "Enter Base64").
 
     Returns:
-        str: The estimated age, or an error message.
+        tuple: A tuple containing:
+            - str: A summary string of the estimated ages, or an error message.
+            - list: A list of dictionaries, where each dictionary represents the age
+              estimation data for a detected face, or an empty list if no faces
+              were detected or an error occurred.
     """
     # Use the centralized function to get the image
     image = get_image_from_input(input_type, uploaded_image, image_url, base64_string)
 
     if image is None:
         print("Image is None after loading/selection for age estimation.")
-        return "Error: Image processing failed or no valid input provided."
+        return "Error: Image processing failed or no valid input provided.", []
 
     try:
         face_detector = load_face_detector()
@@ -44,10 +48,13 @@ def age_estimation(input_type, uploaded_image, image_url, base64_string):
         age_data = predict_age(processed_image, model, face_detector, device)
 
         if age_data:
-            # Assuming age_data is a list of dictionaries, and we take the first face's age
-            return f"Estimated Age: {age_data[0]['age']}"
+            # Create a summary string of all estimated ages
+            age_summary = "Estimated Ages: " + ", ".join(
+                [str(face["age"]) for face in age_data]
+            )
+            return age_summary, age_data
         else:
-            return "No faces detected"
+            return "No faces detected", []
     except Exception as e:
         print(f"Error in age estimation: {e}")
-        return f"Error in age estimation: {e}"
+        return f"Error in age estimation: {e}", []
age_estimation/model.py CHANGED
@@ -3,6 +3,7 @@ import pretrainedmodels
 import torch
 import torch.nn as nn
 
+
 def get_model(model_name="se_resnext50_32x4d", num_classes=101, pretrained="imagenet"):
     """
     Loads a pre-trained model.
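For context, `get_model` presumably follows the standard `pretrainedmodels` pattern of loading a named backbone and swapping the classifier head for 101 age bins. The body is not shown in this diff, so the following is a sketch of that common pattern, not the repo's code:

```python
import pretrainedmodels
import torch.nn as nn

def get_model_sketch(model_name="se_resnext50_32x4d", num_classes=101, pretrained="imagenet"):
    # Look up the architecture by name and load ImageNet weights
    model = pretrainedmodels.__dict__[model_name](pretrained=pretrained)
    # Replace the final layer: one logit per age bin (0..100)
    dim_feats = model.last_linear.in_features
    model.last_linear = nn.Linear(dim_feats, num_classes)
    return model
```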
age_estimation/predict.py CHANGED
@@ -7,8 +7,16 @@ import torch.nn.functional as F
 AGE_ESTIMATION_MARGIN = 0.4
 AGE_ESTIMATION_INPUT_SIZE = 224
 
+
 @torch.inference_mode()
-def predict_age(image, model, face_detector, device, margin=AGE_ESTIMATION_MARGIN, input_size=AGE_ESTIMATION_INPUT_SIZE):
+def predict_age(
+    image,
+    model,
+    face_detector,
+    device,
+    margin=AGE_ESTIMATION_MARGIN,
+    input_size=AGE_ESTIMATION_INPUT_SIZE,
+):
     """
     Predicts the age of faces in an image.
 
@@ -22,6 +30,8 @@ def predict_age(image, model, face_detector, device, margin=AGE_ESTIMATION_MARGI
 
     Returns:
         list: A list of dictionaries containing the age and face coordinates for each detected face.
+            The 'face_coordinates' key contains a dictionary with 'x', 'y', 'w', and 'h' keys
+            representing the bounding box of the detected face.
     """
     # Read the image using OpenCV
     # The image is already a NumPy array (HWC, BGR)
@@ -44,9 +54,15 @@ def predict_age(image, model, face_detector, device, margin=AGE_ESTIMATION_MARGI
     if len(detected) > 0:
         for i, d in enumerate(detected):
             # Get face coordinates and dimensions
-            x1, y1, x2, y2, w, h = d.left(), d.top(
-            ), d.right() + 1, d.bottom() + 1, d.width(), d.height()
-
+            x1, y1, x2, y2, w, h = (
+                d.left(),
+                d.top(),
+                d.right() + 1,
+                d.bottom() + 1,
+                d.width(),
+                d.height(),
+            )
+
             # Calculate expanded face region with margin
             xw1 = max(int(x1 - margin * w), 0)
             yw1 = max(int(y1 - margin * h), 0)
@@ -54,8 +70,9 @@ def predict_age(image, model, face_detector, device, margin=AGE_ESTIMATION_MARGI
             yw2 = min(int(y2 + margin * h), image_h - 1)
 
             # Resize face image to the required input size for the model
-            faces[i] = cv2.resize(image[yw1:yw2 + 1, xw1:xw2 + 1],
-                                  (input_size, input_size))
+            faces[i] = cv2.resize(
+                image[yw1 : yw2 + 1, xw1 : xw2 + 1], (input_size, input_size)
+            )
 
             # Draw rectangles around the detected face and the expanded region
             cv2.rectangle(image, (x1, y1), (x2, y2), (255, 255, 255), 2)
@@ -63,17 +80,30 @@ def predict_age(image, model, face_detector, device, margin=AGE_ESTIMATION_MARGI
 
     # Prepare face images for model input
     inputs = torch.from_numpy(
-        np.transpose(faces.astype(np.float32), (0, 3, 1, 2))).to(device)
-
+        np.transpose(faces.astype(np.float32), (0, 3, 1, 2))
+    ).to(device)
+
     # Perform age prediction using the model
     outputs = F.softmax(model(inputs), dim=-1).cpu().numpy()
     ages = np.arange(0, 101)
     predicted_ages = (outputs * ages).sum(axis=-1)
 
-    # Store the predicted age and face coordinates
+    # Store the predicted age and face coordinates in [x, y, w, h] format
     for age, d in zip(predicted_ages, detected):
-        age_text = f'{int(age)}'
-        age_data.append({'age': int(age), 'text': age_text, 'face_coordinates': (d.left(), d.top())})
-
+        x, y, w, h = d.left(), d.top(), d.width(), d.height()
+        age_text = f"{int(age)}"
+        age_data.append(
+            {
+                "age": int(age),
+                "text": age_text,
+                "face_coordinates": {
+                    "x": int(x),
+                    "y": int(y),
+                    "w": int(w),
+                    "h": int(h),
+                },
+            }
+        )
+
     # Return the list of age data for each detected face
-    return age_data
+    return age_data
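One step worth spelling out: `predicted_ages = (outputs * ages).sum(axis=-1)` treats the 101-way softmax as a probability distribution over ages 0 to 100 and takes its expected value, giving a continuous estimate rather than an argmax bin. A toy example with assumed probabilities:

```python
import numpy as np

probs = np.zeros(101)
probs[[29, 30, 31]] = [0.2, 0.5, 0.3]  # toy softmax output, not real model data
ages = np.arange(0, 101)
print((probs * ages).sum())  # 30.1 -- the expected age; int() truncates to 30
```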
app.py CHANGED
@@ -54,7 +54,7 @@ with gr.Blocks() as demo:
         inputs=[face_input_type],
         outputs=[face_img_upload, face_url_input, face_base64_input],
         queue=False,
-        api_name=False
+        api_name=False,
     )
 
     # Link process button to the face detection function
@@ -94,8 +94,9 @@ with gr.Blocks() as demo:
     # Process Button
     age_process_btn = gr.Button("Estimate Age")
 
-    # Output Component
-    age_text_output = gr.Textbox(label="Estimated Age")
+    # Output Components
+    age_text_output = gr.Textbox(label="Estimated Age Summary")
+    age_raw_output = gr.JSON(label="Raw Age Estimation Data")
 
     # Link radio button change to visibility update function
     age_input_type.change(
@@ -103,15 +104,15 @@ with gr.Blocks() as demo:
         inputs=[age_input_type],
         outputs=[age_img_upload, age_url_input, age_base64_input],
         queue=False,
-        api_name=False
+        api_name=False,
     )
 
     # Link process button to the age estimation function
-    # The age_estimation function will need to be updated to handle these inputs
+    # The age_estimation function will now return a tuple
     age_process_btn.click(
         fn=age_estimation,
         inputs=[age_input_type, age_img_upload, age_url_input, age_base64_input],
-        outputs=age_text_output,
+        outputs=[age_text_output, age_raw_output],
     )
     # Create a tab for object detection
     with gr.Tab("Object Detection"):
@@ -146,7 +147,7 @@ with gr.Blocks() as demo:
         inputs=[obj_input_type],
         outputs=[obj_img_upload, obj_url_input, obj_base64_input],
         queue=False,
-        api_name=False
+        api_name=False,
     )
 
     # Link process button to the object detection function
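The `outputs=[age_text_output, age_raw_output]` change works because Gradio unpacks a returned tuple positionally across the listed output components. A self-contained sketch of that mechanism (component names here are placeholders, not the app's):

```python
import gradio as gr

with gr.Blocks() as demo:
    btn = gr.Button("Run")
    text_out = gr.Textbox(label="Summary")
    json_out = gr.JSON(label="Raw Data")
    # A 2-tuple return maps element-by-element onto the two outputs
    btn.click(
        fn=lambda: ("Estimated Ages: 30", [{"age": 30}]),
        inputs=[],
        outputs=[text_out, json_out],
    )
```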
detection/face_detection.py CHANGED
@@ -8,12 +8,17 @@ from PIL import Image
 
 # Local imports
 from utils.image_utils import load_image, preprocess_image, get_image_from_input
-from utils.face_detector import load_face_detector # Assuming this is the dlib detector loader
+from utils.face_detector import (
+    load_face_detector,
+)  # Assuming this is the dlib detector loader
 
 # Define constants
 HAAR_CASCADE_FILENAME = "haarcascade_frontalface_default.xml"
 
-def face_detection(input_type, uploaded_image, image_url, base64_string, face_detection_method):
+
+def face_detection(
+    input_type, uploaded_image, image_url, base64_string, face_detection_method
+):
     """
     Performs face detection on the image from various input types using the selected method.
 
@@ -36,7 +41,7 @@ def face_detection(input_type, uploaded_image, image_url, base64_string, face_de
 
     if image is None:
         print("Image is None after loading/selection.")
-        return None, [] # Return None for image and empty list for bboxes
+        return None, []  # Return None for image and empty list for bboxes
 
     processed_image = None
     bounding_boxes = []
@@ -54,20 +59,26 @@ def face_detection(input_type, uploaded_image, image_url, base64_string, face_de
             # Ensure the haarcascade file is accessible.
             # This path might need adjustment depending on the environment.
             # Construct the full path to the Haar cascade file
-            cascade_path = os.path.join(cv2.data.haarcascades, HAAR_CASCADE_FILENAME)
+            cascade_path = os.path.join(
+                cv2.data.haarcascades, HAAR_CASCADE_FILENAME
+            )
 
             # Check if the cascade file exists
             if not os.path.exists(cascade_path):
-                error_message = f"Error: Haar cascade file not found at {cascade_path}. Please ensure OpenCV is installed correctly and the file exists."
-                print(error_message)
-                return None, [] # Return None for image and empty list for bboxes
+                error_message = f"Error: Haar cascade file not found at {cascade_path}. Please ensure OpenCV is installed correctly and the file exists."
+                print(error_message)
+                return None, []  # Return None for image and empty list for bboxes
 
             face_cascade = cv2.CascadeClassifier(cascade_path)
 
             faces = face_cascade.detectMultiScale(gray, 1.1, 4)
             for x, y, w, h in faces:
-                cv2.rectangle(processed_image, (x, y), (x + w, y + h), (255, 0, 0), 2)
-                bounding_boxes.append({'x': int(x), 'y': int(y), 'w': int(w), 'h': int(h)})
+                cv2.rectangle(
+                    processed_image, (x, y), (x + w, y + h), (255, 0, 0), 2
+                )
+                bounding_boxes.append(
+                    {"x": int(x), "y": int(y), "w": int(w), "h": int(h)}
+                )
 
         elif face_detection_method == "dlib":
             print("Using dlib for face detection.")
@@ -75,15 +86,19 @@ def face_detection(input_type, uploaded_image, image_url, base64_string, face_de
             # dlib works on RGB images, but the detector can take grayscale
             # However, the rectangles are relative to the original image size
             # Let's use the original processed_image (RGB numpy array) for drawing
-            faces = face_detector(processed_image, 1) # 1 is the upsample level
+            faces = face_detector(processed_image, 1)  # 1 is the upsample level
             for face in faces:
                 x, y, w, h = face.left(), face.top(), face.width(), face.height()
-                cv2.rectangle(processed_image, (x, y), (x + w, y + h), (255, 0, 0), 2)
-                bounding_boxes.append({'x': int(x), 'y': int(y), 'w': int(w), 'h': int(h)})
+                cv2.rectangle(
+                    processed_image, (x, y), (x + w, y + h), (255, 0, 0), 2
+                )
+                bounding_boxes.append(
+                    {"x": int(x), "y": int(y), "w": int(w), "h": int(h)}
+                )
 
             return processed_image, bounding_boxes
         else:
-            return None, [] # Return None for image and empty list for bboxes
+            return None, []  # Return None for image and empty list for bboxes
     except Exception as e:
         print(f"Error in face detection processing: {e}")
-        return None, [] # Return None for image and empty list for bboxes
+        return None, []  # Return None for image and empty list for bboxes
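Both backends end up in the same `{"x", "y", "w", "h"}` dict, but they report coordinates differently: Haar's `detectMultiScale` yields plain `(x, y, w, h)` tuples, while dlib yields rectangle objects queried via `.left()`, `.top()`, `.width()`, and `.height()`. A quick check of the cascade-path lookup this diff reformats, runnable on a stock `opencv-python` install:

```python
import os
import cv2

# cv2.data.haarcascades points at the cascade XMLs bundled with opencv-python
cascade_path = os.path.join(cv2.data.haarcascades, "haarcascade_frontalface_default.xml")
print(os.path.exists(cascade_path))  # True on a standard install

# dlib rectangles are normalized the same way as Haar tuples:
def to_bbox(x, y, w, h):
    return {"x": int(x), "y": int(y), "w": int(w), "h": int(h)}
```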
detection/object_detection.py CHANGED
@@ -8,6 +8,7 @@ import numpy as np
 # Local imports
 from utils.image_utils import load_image, preprocess_image
 
+
 def object_detection(input_type, uploaded_image, image_url, base64_string):
     """
     Performs object detection on the image from various input types.
@@ -25,26 +26,26 @@ def object_detection(input_type, uploaded_image, image_url, base64_string):
     input_value = None
 
     if input_type == "Upload File" and uploaded_image is not None:
-        image = uploaded_image # This is a PIL Image
-        print("Using uploaded image (PIL) for object detection") # Debug print
+        image = uploaded_image  # This is a PIL Image
+        print("Using uploaded image (PIL) for object detection")  # Debug print
 
     elif input_type == "Enter URL" and image_url and image_url.strip():
         input_value = image_url
-        print(f"Using URL for object detection: {input_value}") # Debug print
+        print(f"Using URL for object detection: {input_value}")  # Debug print
 
     elif input_type == "Enter Base64" and base64_string and base64_string.strip():
         input_value = base64_string
-        print(f"Using Base64 string for object detection") # Debug print
+        print(f"Using Base64 string for object detection")  # Debug print
 
     else:
         print("No valid input provided for object detection based on selected type.")
-        return None # No valid input
+        return None  # No valid input
 
     # If input_value is set (URL or Base64), use load_image
     if input_value:
         image = load_image(input_value)
         if image is None:
-            return None # load_image failed
+            return None  # load_image failed
 
     # Now 'image' should be a PIL Image or None
     if image is None:
@@ -62,4 +63,4 @@ def object_detection(input_type, uploaded_image, image_url, base64_string):
         return processed_image
     except Exception as e:
         print(f"Error in object detection processing: {e}")
-        return None
+        return None
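The dispatch guards mean whitespace-only URL or Base64 inputs fall through to the `else` branch and return `None`; for example (a hypothetical call, no network access needed):

```python
# A blank URL fails the `image_url.strip()` guard, so the function
# prints a diagnostic and returns None without attempting a download.
result = object_detection("Enter URL", None, "   ", None)
assert result is None
```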