File size: 3,911 Bytes
0508b3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ecbbafe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332cb73
 
 
 
 
 
 
 
 
 
 
 
 
 
0508b3e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
"""Manage bounding boxes in the database."""

from sqlalchemy import text
import json
from data_models.sql_connection import get_db_connection


class OpenAIManager:
    """Store and fetch rows of the `openai_predictions` table.

    A single session, obtained from `get_db_connection()`, is reused for
    every call until `close_connection()` is invoked.
    """

    # Keys read from the `predictions` mapping. Each one is bound to the
    # identically named parameter of the INSERT statement in
    # `add_predictions`, so the two lists must be kept in sync.
    _PREDICTION_FIELDS = (
        "built_elements",
        "fauna_identification",
        "human_activity",
        "human_detection",
        "vegetation_detection",
        "water_elements",
    )

    def __init__(self):
        """Initialise connection and session via `get_db_connection()`."""
        self.engine, self.session = get_db_connection()

    def add_predictions(
        self,
        img_id,
        predictions,
    ):
        """
        Add predictions to the `openai_predictions` table.

        Args:
            img_id: ID of the image the predictions belong to.
            predictions (dict): Mapping that must contain the keys
                `built_elements`, `fauna_identification`, `human_activity`,
                `human_detection`, `vegetation_detection` and
                `water_elements`. Each value is serialised with
                `json.dumps` before being stored, so it must be
                JSON-serialisable.

        Returns:
            dict: `img_id` plus the six prediction values, unserialised, as
            they were passed in.

        Raises:
            Exception: If a key is missing, a value cannot be serialised, or
                the insert/commit fails. The session is rolled back first and
                the underlying error is attached as `__cause__`.
        """
        query = text(
            """
            INSERT INTO openai_predictions (img_id, built_elements, fauna_identification, human_activity, human_detection, vegetation_detection, water_elements)
            VALUES (:img_id, :built_elements, :fauna_identification, :human_activity, :human_detection, :vegetation_detection, :water_elements)
        """
        )
        fields = self._PREDICTION_FIELDS
        try:
            # Key lookups and serialisation stay inside the `try` so that a
            # malformed `predictions` mapping is reported through the same
            # wrapped exception as a database failure.
            params = {name: json.dumps(predictions[name]) for name in fields}
            params["img_id"] = img_id
            self.session.execute(query, params)
            self.session.commit()
            return {
                "img_id": img_id,
                **{name: predictions[name] for name in fields},
            }
        except Exception as e:
            self.session.rollback()
            raise Exception(
                f"An error occurred while adding the predictions: {e}"
            ) from e

    def get_predictions(self, img_id):
        """
        Get predictions from the `openai_predictions` table.

        Args:
            img_id: ID of the image whose predictions are requested.

        Returns:
            dict | None: The first matching row as a column-name -> value
            dict, or None when no row matches `img_id`.

        Raises:
            Exception: If the query fails. The session is rolled back first
                and the underlying error is attached as `__cause__`.
        """
        query = text("SELECT * FROM openai_predictions WHERE img_id = :img_id")
        try:
            row = self.session.execute(query, {"img_id": img_id}).fetchone()
            return row._asdict() if row else None
        except Exception as e:
            # Roll back so a failed statement does not leave the shared
            # session stuck in an aborted transaction for later calls.
            self.session.rollback()
            raise Exception(
                f"An error occurred while getting the predictions: {e}"
            ) from e

    def get_all_predictions(self):
        """
        Get all predictions from the `openai_predictions` table.

        Returns:
            list: One column-name -> value dict per row; empty when the
            table has no rows.

        Raises:
            Exception: If the query fails. The session is rolled back first
                and the underlying error is attached as `__cause__`.
        """
        query = text("SELECT * FROM openai_predictions")
        try:
            rows = self.session.execute(query).fetchall()
            return [row._asdict() for row in rows]
        except Exception as e:
            # Same reasoning as in `get_predictions`: keep the shared
            # session usable after a failed read.
            self.session.rollback()
            raise Exception(
                f"An error occurred while getting the predictions: {e}"
            ) from e

    def close_connection(self):
        """Close the session held by this manager."""
        self.session.close()