Spaces:
Running
Running
File size: 6,754 Bytes
a2c10b6 cec4377 a2c10b6 cec4377 a2c10b6 cec4377 a2c10b6 ca4d8ec 698c6c9 a2c10b6 cec4377 ca4d8ec cec4377 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 |
import os
import pandas as pd
import faiss
import numpy as np
import sqlite3
from .embedder import Embedder
from datetime import datetime
class SQL_Key_Pair:
    """Load tabular financial data (CSV/XLSX) and serve lookups over it.

    The class keeps three synchronized views of the ingested file:

    * ``self.data`` -- the raw :class:`pandas.DataFrame`;
    * the ``custom_financials`` SQLite table -- one row per (ticker, metric)
      cell, written by :meth:`load_data`;
    * a FAISS ``IndexFlatL2`` over embedded ticker strings, built by
      :meth:`build_index` for similarity search.
    """

    def __init__(self, file_path="financial_data.csv", model_name="all-MiniLM-L6-v2", db_path="/app/db/financial_data.db"):
        """Connect to SQLite, create the schema, and ingest *file_path*.

        Args:
            file_path: CSV or XLSX file with a ``Ticker`` column.
            model_name: sentence-transformer model name passed to ``Embedder``.
            db_path: SQLite file; its parent directory is created if missing.

        Raises:
            sqlite3.OperationalError: if the database cannot be opened.
        """
        # Create the parent directory up front so sqlite3.connect cannot
        # fail with "unable to open database file" on a fresh container.
        os.makedirs(os.path.dirname(db_path), exist_ok=True)
        self.embedder = Embedder(model_name)
        self.index = None        # FAISS index, built in build_index()
        self.documents = []      # list[str] of ticker symbols fed to the embedder
        self.data = None         # DataFrame of the loaded file (None until load_data)
        self.embeddings = None   # embedding matrix aligned with self.documents
        try:
            self.db_conn = sqlite3.connect(db_path)
            print(f"Connected to SQLite database at {db_path}")
        except sqlite3.OperationalError as e:
            print(f"Failed to connect to database: {e}")
            raise
        self.create_db_table()
        self.load_data(file_path)

    def create_db_table(self):
        """Create the ``custom_financials`` table if it doesn't exist.

        Idempotent: uses ``CREATE TABLE IF NOT EXISTS`` so repeated calls
        (and repeated construction against the same db file) are safe.
        """
        cursor = self.db_conn.cursor()
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS custom_financials (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                source_file TEXT,
                firm TEXT,
                ticker TEXT,
                date TEXT,
                metric TEXT,
                value REAL,
                last_updated TEXT
            )
        """)
        self.db_conn.commit()
        print("Created custom_financials table")

    def load_data(self, file_path):
        """Load financial data from a CSV or Excel file into the database.

        Every non-null cell of each row is inserted as one
        ``custom_financials`` record (metric = column name); cells that do
        not parse as float are stored with value 0.0.

        NOTE(review): identifier columns (``firm``/``Ticker``/``date``) are
        also inserted as "metrics" with value 0.0 — presumably unintended
        noise rows; confirm before filtering them out.

        On any failure the error is logged and the instance is left with an
        empty DataFrame rather than raising (best-effort boundary).
        """
        try:
            if file_path.endswith('.csv'):
                df = pd.read_csv(file_path)
            elif file_path.endswith('.xlsx'):
                df = pd.read_excel(file_path)
            else:
                raise ValueError("Unsupported file format. Use .csv or .xlsx.")
            self.data = df
            # The embedded "documents" are just the ticker symbols.
            self.documents = self.data["Ticker"].astype(str).tolist()
            cursor = self.db_conn.cursor()
            for _, row in self.data.iterrows():
                firm = row.get("firm", "")
                ticker = row.get("Ticker", "")
                date = row.get("date", "")
                for column in self.data.columns:
                    if pd.notna(row[column]):
                        try:
                            value = float(row[column])
                        except (ValueError, TypeError):
                            # Non-numeric cell (e.g. the ticker itself):
                            # keep the row but store a neutral value.
                            value = 0.0
                        cursor.execute("""
                            INSERT INTO custom_financials (source_file, firm, ticker, date, metric, value, last_updated)
                            VALUES (?, ?, ?, ?, ?, ?, ?)
                        """, (os.path.basename(file_path), firm, ticker, str(date), column, value, datetime.now().isoformat()))
            self.db_conn.commit()
            print(f"Loaded {len(self.data)} rows from {file_path} into custom_financials.")
            self.build_index()  # Rebuild FAISS index after loading
        except Exception as e:
            print(f"Error loading data: {e}")
            self.documents = []
            self.data = pd.DataFrame()

    def build_index(self):
        """Build a FAISS L2 index from the embedded ticker strings.

        No-op when nothing has been loaded yet.
        """
        if not self.documents:
            return
        self.embeddings = self.embedder.embed(self.documents)
        dim = self.embeddings.shape[1]
        self.index = faiss.IndexFlatL2(dim)
        self.index.add(self.embeddings)

    def keyword_match_search(self, entities):
        """Strict case-insensitive (ticker, metric) lookup in the DataFrame.

        Args:
            entities: dict with ``"ticker"`` and ``"metric"`` keys.

        Returns:
            A formatted sentence with the value scaled to billions, or a
            "No relevant data found." / "No data loaded." message.
        """
        if self.data is None or self.data.empty:
            return "No data loaded."
        ticker = entities.get("ticker", "")
        metric = entities.get("metric", "")
        if not ticker or not metric:
            return "No relevant data found."
        ticker = ticker.lower()
        metric = metric.lower()
        retrieved_text = ""
        for _, row in self.data.iterrows():
            if str(row.get("Ticker", "")).lower() == ticker:
                for col in self.data.columns:
                    if col.lower() == metric:
                        if pd.isna(row[col]) or row[col] == "":
                            continue
                        try:
                            # float() guards against non-numeric cells
                            # (e.g. string values), which previously raised
                            # TypeError and crashed the whole search.
                            value_in_billions = float(row[col]) / 1_000_000_000
                        except (TypeError, ValueError):
                            continue
                        retrieved_text = f"Retrieved {metric} for {ticker} is : ${value_in_billions:.2f} billion."
                        break
                break  # Only the first matching ticker row is considered.
        if not retrieved_text:
            return "No relevant data found."
        return retrieved_text

    def query_csv(self, query, k=3):
        """Answer a free-text query via similarity retrieval over the data.

        NOTE(review): relies on ``self.retrieve``, which is not defined in
        this class — it must be supplied by a subclass or attached
        elsewhere; as written, calling this on a bare instance raises
        AttributeError. Confirm where ``retrieve`` comes from.

        Args:
            query: free-text user query.
            k: number of nearest entries to retrieve.

        Returns:
            One formatted line per retrieved entry, newline-joined.
        """
        retrieved_data = self.retrieve(query, k=k)
        if not retrieved_data:
            return "No relevant data found."
        responses = []
        for entry in retrieved_data:
            try:
                value = float(entry["value"])
                value_in_billions = value / 1_000_000_000
                response = f"{entry['ticker']}'s {entry['metric']} for {entry['year']} was ${value_in_billions:.2f} billion."
            except (KeyError, TypeError, ValueError):
                # Fall back to the raw value when it isn't numeric or a
                # key is missing (narrowed from a bare `except:`).
                response = f"{entry['ticker']}'s {entry['metric']} for {entry['year']} was {entry['value']}."
            responses.append(response)
        return "\n".join(responses)

    def entity_based_query(self, entities):
        """Thin alias for :meth:`keyword_match_search`."""
        return self.keyword_match_search(entities)

    def query_db(self, ticker, metric):
        """Query ``custom_financials`` by exact ticker and metric.

        Date/year are ignored; only the first matching row is returned.

        NOTE(review): the response hard-codes "in 2024" regardless of the
        stored date — confirm whether the year should come from the row.

        Returns:
            A formatted sentence, a "No relevant data found" message, or an
            error string if the query fails.
        """
        try:
            cursor = self.db_conn.cursor()
            query = """
                SELECT value FROM custom_financials
                WHERE ticker = ? AND metric = ?
                LIMIT 1
            """
            params = [ticker, metric]
            cursor.execute(query, params)
            result = cursor.fetchone()
            if result:
                value = result[0]
                value_in_billions = value / 1_000_000_000
                return f"The {metric} for {ticker} was : ${value_in_billions:.2f} billion in 2024."
            return f"No relevant data found for {ticker}."
        except Exception as e:
            print(f"Error querying database: {e}")
            return f"Error querying database: {str(e)}"

    def __del__(self):
        """Best-effort close of the SQLite connection at finalization."""
        try:
            # hasattr guard: __init__ may have raised before db_conn existed.
            if hasattr(self, 'db_conn') and self.db_conn:
                self.db_conn.close()
        except Exception as e:
            print(f"Error closing database connection: {e}")