# SUAD_Park / src/data_models/bbox_manager.py
# Last commit: 911c0ac by leo-bourrel — "feat: split prediction into YOLO predictions"
"""Manage bounding boxes in the database."""
from sqlalchemy import text
from data_models.sql_connection import get_db_connection
class BoundingBoxManager:
    """Write bounding-box rows to the ``bboxes`` table through a DB session."""

    def __init__(self):
        """Initialise connection and session."""
        self.engine, self.session = get_db_connection()

    def add_bbox(self, confidence, class_id, img_id, x_min, y_min, x_max, y_max):
        """
        Add a bounding box to the `bboxes` table and commit immediately.

        Args:
            confidence (float): Confidence of the detection.
            class_id (int): Class ID of the detection (``add_bboxes_to_db``
                passes an ``int``).
            img_id (int): ID of the image where the bounding box was detected.
            x_min (float): Minimum X coordinate.
            y_min (float): Minimum Y coordinate.
            x_max (float): Maximum X coordinate.
            y_max (float): Maximum Y coordinate.

        Returns:
            dict: Information of the added bounding box (the inserted values).

        Raises:
            Exception: If the insert or commit fails. The session is rolled
                back first and the underlying error is chained as ``__cause__``.
        """
        params = {
            "confidence": confidence,
            "class_id": class_id,
            "img_id": img_id,
            "x_min": x_min,
            "y_min": y_min,
            "x_max": x_max,
            "y_max": y_max,
        }
        query = text(
            """
            INSERT INTO bboxes (confidence, class_id, img_id, x_min, y_min, x_max, y_max)
            VALUES (:confidence, :class_id, :img_id, :x_min, :y_min, :x_max, :y_max)
            """
        )
        try:
            self.session.execute(query, params)
            self.session.commit()
        except Exception as e:
            # Leave the session usable for subsequent inserts.
            self.session.rollback()
            raise Exception(
                f"An error occurred while adding the bounding box: {e}"
            ) from e
        # Return a copy so callers cannot mutate the bound parameters.
        return dict(params)

    def close_connection(self):
        """Close the session (the engine itself is left untouched)."""
        self.session.close()
def add_bboxes_to_db(result, bbox_manager, image_id):
    """
    Adds bounding boxes from a YOLO result to the database.

    Each box in ``result.boxes`` is written with one ``bbox_manager.add_bbox``
    call. Insertion is best-effort: an error on one box is printed and the
    remaining boxes are still processed.

    Args:
        result: YOLO result object whose ``boxes`` attribute is iterated. Each
            box must expose ``conf``, ``cls`` and ``xyxy``; the first row of
            each is used. They may be torch tensors (CPU or CUDA) or numpy arrays.
        bbox_manager (BoundingBoxManager): Instance of BoundingBoxManager to interact with the database.
        image_id (int): The ID of the associated image.

    Returns:
        None
    """
    boxes = result.boxes
    # ``boxes`` can be None (e.g. for results without detection output);
    # iterating None would raise TypeError.
    if boxes is None:
        return
    for box in boxes:
        try:
            print(f"Adding bounding box for image ID: {image_id}...")
            # ``.item()`` / ``.tolist()`` work on CPU and CUDA tensors and on
            # numpy arrays alike. ``.numpy()`` fails for CUDA tensors and does
            # not exist on numpy scalars.
            x_min, y_min, x_max, y_max = box.xyxy[0].tolist()
            bbox_manager.add_bbox(
                confidence=box.conf[0].item(),
                class_id=int(box.cls[0].item()),
                img_id=image_id,
                x_min=x_min,
                y_min=y_min,
                x_max=x_max,
                y_max=y_max,
            )
            print(f"Bounding box added to DB for image ID: {image_id}")
        except Exception as e:
            print(f"Error inserting bounding box into DB: {e}")