# Source: SUAD_Park/src/data_models/openai_manager.py
# Last commit: 332cb73 "feat: add statistics" (leo-bourrel)
"""Manage OpenAI predictions stored in the `openai_predictions` table."""
from sqlalchemy import text
import json
from data_models.sql_connection import get_db_connection
class OpenAIManager:
    """Store and retrieve OpenAI image predictions in `openai_predictions`."""

    # Prediction categories handled by ``add_predictions``. Each one is a key
    # that the ``predictions`` dict must contain and a column that is written
    # in the ``openai_predictions`` INSERT below.
    _PREDICTION_FIELDS = (
        "built_elements",
        "fauna_identification",
        "human_activity",
        "human_detection",
        "vegetation_detection",
        "water_elements",
    )

    def __init__(self):
        """Initialise connection and session."""
        # get_db_connection() is unpacked into an (engine, session) pair.
        self.engine, self.session = get_db_connection()

    def add_predictions(
        self,
        img_id,
        predictions,
    ):
        """
        Add predictions to the `openai_predictions` table.

        Each prediction category is serialised with ``json.dumps`` before it
        is inserted, and the session is committed immediately.

        Args:
            img_id (int): ID of the image the predictions belong to.
            predictions (dict): Mapping that must contain the keys
                ``built_elements``, ``fauna_identification``,
                ``human_activity``, ``human_detection``,
                ``vegetation_detection`` and ``water_elements``. Every value
                must be JSON-serialisable.

        Returns:
            dict: ``img_id`` plus the six prediction values exactly as they
            were passed in (not the JSON strings that were stored).

        Raises:
            Exception: If a required key is missing, a value cannot be
                serialised, or the database write fails. The session is
                rolled back first and the original error is chained as
                ``__cause__``.
        """
        query = text(
            """
            INSERT INTO openai_predictions (
                img_id, built_elements, fauna_identification, human_activity,
                human_detection, vegetation_detection, water_elements
            )
            VALUES (
                :img_id, :built_elements, :fauna_identification, :human_activity,
                :human_detection, :vegetation_detection, :water_elements
            )
            """
        )
        try:
            params = {"img_id": img_id}
            # Serialise inside the try so a missing key or a non-serialisable
            # value also triggers the rollback and the wrapped exception.
            params.update(
                (field, json.dumps(predictions[field]))
                for field in self._PREDICTION_FIELDS
            )
            self.session.execute(query, params)
            self.session.commit()
        except Exception as e:
            self.session.rollback()
            raise Exception(
                f"An error occurred while adding the predictions: {e}"
            ) from e
        result = {"img_id": img_id}
        result.update(
            (field, predictions[field]) for field in self._PREDICTION_FIELDS
        )
        return result

    def get_predictions(self, img_id):
        """
        Get the predictions stored for one image.

        Args:
            img_id (int): ID of the image whose predictions are requested.

        Returns:
            dict | None: The first row returned for ``img_id``, as a
            column-name -> value dict, or ``None`` when no row matches.

        Raises:
            Exception: If the query fails. The session is rolled back first
                and the original error is chained as ``__cause__``.
        """
        query = text("SELECT * FROM openai_predictions WHERE img_id = :img_id")
        try:
            row = self.session.execute(query, {"img_id": img_id}).fetchone()
            return row._asdict() if row is not None else None
        except Exception as e:
            # On some backends (e.g. PostgreSQL) a failed statement aborts the
            # whole transaction. Roll back so later calls on this shared
            # session do not keep failing.
            self.session.rollback()
            raise Exception(
                f"An error occurred while getting the predictions: {e}"
            ) from e

    def get_all_predictions(self):
        """
        Get all predictions from the `openai_predictions` table.

        Returns:
            list[dict]: One column-name -> value dict per row (empty if the
            table has no rows).

        Raises:
            Exception: If the query fails. The session is rolled back first
                and the original error is chained as ``__cause__``.
        """
        query = text("SELECT * FROM openai_predictions")
        try:
            rows = self.session.execute(query).fetchall()
            return [row._asdict() for row in rows]
        except Exception as e:
            # Same reason as in get_predictions: keep the session usable.
            self.session.rollback()
            raise Exception(
                f"An error occurred while getting the predictions: {e}"
            ) from e

    def close_connection(self):
        """Close the session.

        NOTE(review): ``self.engine`` is not disposed here. Confirm whether the
        engine returned by ``get_db_connection`` is shared before adding
        ``self.engine.dispose()``.
        """
        self.session.close()