File size: 6,754 Bytes
a2c10b6
 
 
 
 
 
 
 
 
cec4377
 
 
a2c10b6
 
 
 
 
cec4377
 
 
 
 
 
a2c10b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cec4377
a2c10b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca4d8ec
698c6c9
a2c10b6
 
 
 
 
cec4377
 
 
ca4d8ec
cec4377
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import os
import pandas as pd
import faiss
import numpy as np
import sqlite3
from .embedder import Embedder
from datetime import datetime

class SQL_Key_Pair:
    """Load tabular financial data into SQLite and answer simple lookups.

    On construction the file at ``file_path`` (CSV or XLSX) is read, every
    numeric cell is written to the ``custom_financials`` table as one
    (source_file, firm, ticker, date, metric, value) record, and a FAISS L2
    index is built over the embeddings of the ``Ticker`` column.
    """

    def __init__(self, file_path="financial_data.csv", model_name="all-MiniLM-L6-v2", db_path="/app/db/financial_data.db"):
        """Open the database, ensure the table exists and load ``file_path``.

        Args:
            file_path: CSV or XLSX file to load; rows are keyed by its
                ``Ticker`` column.
            model_name: model name passed to ``Embedder`` for the index.
            db_path: SQLite database path. Its parent directory is created
                if missing; a bare filename or ``":memory:"`` also works.

        Raises:
            sqlite3.OperationalError: if the database cannot be opened.
        """
        # os.makedirs("") raises FileNotFoundError, so only create a
        # directory when db_path actually names one.
        db_dir = os.path.dirname(db_path)
        if db_dir:
            os.makedirs(db_dir, exist_ok=True)
        self.embedder = Embedder(model_name)
        self.index = None
        self.documents = []
        self.data = None
        self.embeddings = None
        try:
            self.db_conn = sqlite3.connect(db_path)
            print(f"Connected to SQLite database at {db_path}")
        except sqlite3.OperationalError as e:
            print(f"Failed to connect to database: {e}")
            raise
        self.create_db_table()
        self.load_data(file_path)

    @staticmethod
    def _to_float(value):
        """Return ``value`` as a float, or None if it is missing or non-numeric."""
        try:
            number = float(value)
        except (TypeError, ValueError):
            return None
        # float() happily accepts NaN (e.g. pandas' missing marker); treat it as missing.
        return None if np.isnan(number) else number

    @staticmethod
    def _read_table(file_path):
        """Read a .csv or .xlsx file (extension matched case-insensitively) into a DataFrame.

        Raises:
            ValueError: for any other extension.
        """
        lowered = str(file_path).lower()
        if lowered.endswith('.csv'):
            return pd.read_csv(file_path)
        if lowered.endswith('.xlsx'):
            return pd.read_excel(file_path)
        raise ValueError("Unsupported file format. Use .csv or .xlsx.")

    def create_db_table(self):
        """
        Create the custom_financials table in the database if it doesn't exist.
        """
        cursor = self.db_conn.cursor()
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS custom_financials (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                source_file TEXT,
                firm TEXT,
                ticker TEXT,
                date TEXT,
                metric TEXT,
                value REAL,
                last_updated TEXT
            )
        """)
        self.db_conn.commit()
        print("Created custom_financials table")

    def load_data(self, file_path):
        """Load financial data from a CSV or Excel file and store it in the database.

        Every numeric, non-missing cell becomes one ``custom_financials`` row
        whose ``metric`` is the column name. Non-numeric cells are skipped
        instead of being stored as a fabricated 0.0. Rows are appended; earlier
        loads are not replaced. On any error the message is printed and
        ``self.data`` / ``self.documents`` are reset to empty.
        """
        try:
            self.data = self._read_table(file_path)
            self.documents = self.data["Ticker"].astype(str).tolist()

            # Loop-invariant values, computed once per load rather than per cell.
            source_file = os.path.basename(file_path)
            loaded_at = datetime.now().isoformat()
            records = []
            for _, row in self.data.iterrows():
                firm = row.get("firm", "")
                ticker = row.get("Ticker", "")
                date = str(row.get("date", ""))
                for column in self.data.columns:
                    value = self._to_float(row[column])
                    if value is None:
                        continue
                    records.append((source_file, firm, ticker, date, column, value, loaded_at))

            # The connection context manager commits on success and rolls
            # back on error, so a failed load leaves no partial rows pending.
            with self.db_conn:
                self.db_conn.executemany("""
                    INSERT INTO custom_financials (source_file, firm, ticker, date, metric, value, last_updated)
                    VALUES (?, ?, ?, ?, ?, ?, ?)
                """, records)
            print(f"Loaded {len(self.data)} rows from {file_path} into custom_financials.")
            self.build_index()  # Rebuild FAISS index after loading
        except Exception as e:
            print(f"Error loading data: {e}")
            self.documents = []
            self.data = pd.DataFrame()

    def build_index(self):
        """Build a FAISS L2 index over the embeddings of ``self.documents``.

        Does nothing when there are no documents. FAISS only accepts
        C-contiguous float32 matrices, so the embedder output is converted
        before it is added.
        """
        if not self.documents:
            return
        self.embeddings = np.ascontiguousarray(
            self.embedder.embed(self.documents), dtype=np.float32
        )
        dim = self.embeddings.shape[1]
        self.index = faiss.IndexFlatL2(dim)
        self.index.add(self.embeddings)

    def keyword_match_search(self, entities):
        """Look up one metric for one ticker in the loaded DataFrame.

        ``entities["ticker"]`` is matched against the ``Ticker`` column and
        ``entities["metric"]`` against the column names, both
        case-insensitively. The first row for that ticker that has a numeric
        value for the metric wins.

        Returns:
            A sentence with the value divided by 1e9, ``"No data loaded."``
            when no data is loaded, or ``"No relevant data found."``.
        """
        if self.data is None or self.data.empty:
            return "No data loaded."

        ticker = entities.get("ticker", "")
        metric = entities.get("metric", "")

        if not ticker or not metric:
            return "No relevant data found."

        ticker = str(ticker).lower()
        metric = str(metric).lower()

        # Resolve matching column(s) once instead of rescanning per row.
        metric_cols = [col for col in self.data.columns if str(col).lower() == metric]
        if not metric_cols:
            return "No relevant data found."

        for _, row in self.data.iterrows():
            if str(row.get("Ticker", "")).lower() != ticker:
                continue
            for col in metric_cols:
                value = self._to_float(row[col])
                if value is None:
                    # Missing/non-numeric here; keep looking in later rows
                    # for this ticker instead of giving up.
                    continue
                value_in_billions = value / 1_000_000_000
                return f"Retrieved {metric} for {ticker} is : ${value_in_billions:.2f} billion."

        return "No relevant data found."

    def query_csv(self, query, k=3):
        """Answer a free-text query using the top-``k`` entries from ``self.retrieve``.

        Each entry is expected to be a mapping with ``ticker``, ``metric``,
        ``year`` and ``value`` keys. Numeric values are reported in billions;
        anything else is reported verbatim.
        """
        # NOTE(review): `retrieve` is not defined in this class, so this call
        # raises AttributeError unless a subclass/mixin supplies it — confirm.
        retrieved_data = self.retrieve(query, k=k)
        if not retrieved_data:
            return "No relevant data found."

        responses = []
        for entry in retrieved_data:
            try:
                value_in_billions = float(entry["value"]) / 1_000_000_000
            except (KeyError, TypeError, ValueError):
                response = f"{entry['ticker']}'s {entry['metric']} for {entry['year']} was {entry['value']}."
            else:
                response = f"{entry['ticker']}'s {entry['metric']} for {entry['year']} was ${value_in_billions:.2f} billion."
            responses.append(response)

        return "\n".join(responses)

    def entity_based_query(self, entities):
        """Alias for :meth:`keyword_match_search`."""
        return self.keyword_match_search(entities)

    def query_db(self, ticker, metric):
        """Return the stored value of ``metric`` for ``ticker`` from custom_financials.

        Matching is case-insensitive (consistent with keyword_match_search)
        and ignores date/year. NULL values are skipped. The earliest inserted
        matching row is used, so the result is deterministic.

        Returns:
            A sentence with the value in billions, a "not found" sentence, or
            an error sentence if the query fails (errors are not raised).
        """
        try:
            cursor = self.db_conn.cursor()
            cursor.execute("""
                SELECT value FROM custom_financials
                WHERE LOWER(ticker) = LOWER(?) AND LOWER(metric) = LOWER(?)
                  AND value IS NOT NULL
                ORDER BY id
                LIMIT 1
            """, (ticker, metric))
            result = cursor.fetchone()
            if result:
                value_in_billions = result[0] / 1_000_000_000
                # NOTE(review): the year is hard-coded although the stored
                # `date` is ignored by this query — confirm this is intended.
                return f"The {metric} for {ticker} was : ${value_in_billions:.2f} billion in 2024."
            return f"No relevant data found for {ticker}."
        except Exception as e:
            print(f"Error querying database: {e}")
            return f"Error querying database: {str(e)}"

    def close(self):
        """Close the SQLite connection if one is open; safe to call more than once."""
        conn = getattr(self, "db_conn", None)
        if conn is not None:
            conn.close()
            self.db_conn = None

    def __del__(self):
        # Best-effort cleanup at garbage collection; never raise from __del__.
        try:
            self.close()
        except Exception as e:
            print(f"Error closing database connection: {e}")