"""Manage bounding boxes in the database.""" from sqlalchemy import text import json from data_models.sql_connection import get_db_connection class OpenAIManager: def __init__(self): """Initialise connection and session.""" self.engine, self.session = get_db_connection() def add_predictions( self, img_id, predictions, ): """ Add predictions to the `openai_predictions` table. Args: img_id (int): ID of the image where the bounding box was detected. built_elements (dict): Built elements detected in the image. fauna_identification (dict): Fauna identification detected in the image. human_activity (dict): Human activity detected in the image. human_detection (dict): Humans detected in the image. vegetation_detection (dict): Vegetation detected in the image. water_elements (dict): Water elements detected in the image. """ query = text( """ INSERT INTO openai_predictions (img_id, built_elements, fauna_identification, human_activity, human_detection, vegetation_detection, water_elements) VALUES (:img_id, :built_elements, :fauna_identification, :human_activity, :human_detection, :vegetation_detection, :water_elements) """ ) try: self.session.execute( query, { "img_id": img_id, "built_elements": json.dumps(predictions["built_elements"]), "fauna_identification": json.dumps(predictions["fauna_identification"]), "human_activity": json.dumps(predictions["human_activity"]), "human_detection": json.dumps(predictions["human_detection"]), "vegetation_detection": json.dumps(predictions["vegetation_detection"]), "water_elements": json.dumps(predictions["water_elements"]) }, ) self.session.commit() return { "img_id": img_id, "built_elements": predictions["built_elements"], "fauna_identification": predictions["fauna_identification"], "human_activity": predictions["human_activity"], "human_detection": predictions["human_detection"], "vegetation_detection": predictions["vegetation_detection"], "water_elements": predictions["water_elements"], } except Exception as e: self.session.rollback() raise Exception(f"An error occurred while adding the predictions: {e}") def get_predictions(self, img_id): """ Get predictions from the `openai_predictions` table. Args: img_id (int): ID of the image where the bounding box was detected. Returns: dict: Predictions of the image. """ query = text("SELECT * FROM openai_predictions WHERE img_id = :img_id") try: result = self.session.execute(query, {"img_id": img_id}).fetchone() return result._asdict() if result else None except Exception as e: raise Exception(f"An error occurred while getting the predictions: {e}") def get_all_predictions(self): """ Get all predictions from the `openai_predictions` table. Returns: list: List of predictions. """ query = text("SELECT * FROM openai_predictions") try: result = self.session.execute(query).fetchall() return [row._asdict() for row in result] except Exception as e: raise Exception(f"An error occurred while getting the predictions: {e}") def close_connection(self): """Close the connection.""" self.session.close()