Spaces:
Sleeping
Sleeping
"""Manage bounding boxes in the database.""" | |
from sqlalchemy import text | |
import json | |
from data_models.sql_connection import get_db_connection | |
class OpenAIManager:
    """Read and write rows of the `openai_predictions` table.

    Each row holds, for one image, six prediction categories that are
    serialised with ``json.dumps`` on insert. Database access goes through
    the engine/session pair returned by ``get_db_connection()``.
    """

    # Prediction categories stored per image. The order matches the column
    # order in ``_INSERT_SQL``; keep the two in sync.
    _PREDICTION_FIELDS = (
        "built_elements",
        "fauna_identification",
        "human_activity",
        "human_detection",
        "vegetation_detection",
        "water_elements",
    )

    _INSERT_SQL = """
        INSERT INTO openai_predictions (
            img_id, built_elements, fauna_identification, human_activity,
            human_detection, vegetation_detection, water_elements
        )
        VALUES (
            :img_id, :built_elements, :fauna_identification, :human_activity,
            :human_detection, :vegetation_detection, :water_elements
        )
    """

    def __init__(self):
        """Initialise connection and session."""
        self.engine, self.session = get_db_connection()

    def add_predictions(
        self,
        img_id,
        predictions,
    ):
        """
        Add predictions to the `openai_predictions` table.

        Each category value is serialised with ``json.dumps`` before the
        insert, and the transaction is committed on success.

        Args:
            img_id (int): ID of the image the predictions belong to.
            predictions (dict): Mapping that must contain the keys
                ``built_elements``, ``fauna_identification``,
                ``human_activity``, ``human_detection``,
                ``vegetation_detection`` and ``water_elements``. Each value
                must be JSON-serialisable; any other keys are ignored.

        Returns:
            dict: ``img_id`` plus the six prediction values as passed in
            (not JSON-encoded).

        Raises:
            Exception: If a required key is missing, a value cannot be
                serialised, or the insert/commit fails. The session is
                rolled back first and the original error is chained as
                ``__cause__``.
        """
        query = text(self._INSERT_SQL)
        try:
            # Read each required key exactly once; a missing key raises
            # KeyError here and is reported like any database failure.
            values = {field: predictions[field] for field in self._PREDICTION_FIELDS}
            params = {"img_id": img_id}
            params.update(
                (field, json.dumps(value)) for field, value in values.items()
            )
            self.session.execute(query, params)
            self.session.commit()
        except Exception as e:
            self.session.rollback()
            raise Exception(f"An error occurred while adding the predictions: {e}") from e
        return {"img_id": img_id, **values}

    def get_predictions(self, img_id):
        """
        Get predictions for one image from the `openai_predictions` table.

        Args:
            img_id (int): ID of the image whose predictions are requested.

        Returns:
            dict | None: The first matching row as a column-name -> value
            mapping, or ``None`` when no row matches. Values are returned
            as the database driver provides them; nothing is JSON-decoded
            here.

        Raises:
            Exception: If the query fails. The session is rolled back and
                the original error is chained as ``__cause__``.
        """
        query = text("SELECT * FROM openai_predictions WHERE img_id = :img_id")
        try:
            row = self.session.execute(query, {"img_id": img_id}).fetchone()
        except Exception as e:
            # Roll back so a failed statement does not leave the session in
            # an aborted transaction that breaks subsequent calls.
            self.session.rollback()
            raise Exception(f"An error occurred while getting the predictions: {e}") from e
        return row._asdict() if row is not None else None

    def get_all_predictions(self):
        """
        Get all predictions from the `openai_predictions` table.

        Returns:
            list: One column-name -> value dict per row (empty list when the
            table is empty). Values are not JSON-decoded here.

        Raises:
            Exception: If the query fails. The session is rolled back and
                the original error is chained as ``__cause__``.
        """
        query = text("SELECT * FROM openai_predictions")
        try:
            rows = self.session.execute(query).fetchall()
        except Exception as e:
            # Same recovery as get_predictions: keep the session usable.
            self.session.rollback()
            raise Exception(f"An error occurred while getting the predictions: {e}") from e
        return [row._asdict() for row in rows]

    def close_connection(self):
        """Close the session (the engine itself is not disposed here)."""
        self.session.close()