Spaces:
Running
Running
import os | |
import pandas as pd | |
import faiss | |
import numpy as np | |
import sqlite3 | |
from .embedder import Embedder | |
from datetime import datetime | |
class SQL_Key_Pair: | |
def __init__(self, file_path="financial_data.csv", model_name="all-MiniLM-L6-v2", db_path="/app/db/financial_data.db"): | |
# Ensure the database directory exists | |
os.makedirs(os.path.dirname(db_path), exist_ok=True) | |
self.embedder = Embedder(model_name) | |
self.index = None | |
self.documents = [] | |
self.data = None | |
self.embeddings = None | |
try: | |
self.db_conn = sqlite3.connect(db_path) | |
print(f"Connected to SQLite database at {db_path}") | |
except sqlite3.OperationalError as e: | |
print(f"Failed to connect to database: {e}") | |
raise | |
self.create_db_table() | |
self.load_data(file_path) | |
def create_db_table(self): | |
""" | |
Create the custom_financials table in the database if it doesn’t exist. | |
""" | |
cursor = self.db_conn.cursor() | |
cursor.execute(""" | |
CREATE TABLE IF NOT EXISTS custom_financials ( | |
id INTEGER PRIMARY KEY AUTOINCREMENT, | |
source_file TEXT, | |
firm TEXT, | |
ticker TEXT, | |
date TEXT, | |
metric TEXT, | |
value REAL, | |
last_updated TEXT | |
) | |
""") | |
self.db_conn.commit() | |
print("Created custom_financials table") | |
def load_data(self, file_path): | |
""" | |
Load financial data from a CSV or Excel file and store it in the database. | |
""" | |
try: | |
if file_path.endswith('.csv'): | |
df = pd.read_csv(file_path) | |
elif file_path.endswith('.xlsx'): | |
df = pd.read_excel(file_path) | |
else: | |
raise ValueError("Unsupported file format. Use .csv or .xlsx.") | |
self.data = df | |
self.documents = self.data["Ticker"].astype(str).tolist() | |
cursor = self.db_conn.cursor() | |
for _, row in self.data.iterrows(): | |
firm = row.get("firm", "") | |
ticker = row.get("Ticker", "") | |
date = row.get("date", "") | |
for column in self.data.columns: | |
if pd.notna(row[column]): | |
try: | |
value = float(row[column]) | |
except (ValueError, TypeError): | |
value = 0.0 | |
cursor.execute(""" | |
INSERT INTO custom_financials (source_file, firm, ticker, date, metric, value, last_updated) | |
VALUES (?, ?, ?, ?, ?, ?, ?) | |
""", (os.path.basename(file_path), firm, ticker, str(date), column, value, datetime.now().isoformat())) | |
self.db_conn.commit() | |
print(f"Loaded {len(self.data)} rows from {file_path} into custom_financials.") | |
self.build_index() # Rebuild FAISS index after loading | |
except Exception as e: | |
print(f"Error loading data: {e}") | |
self.documents = [] | |
self.data = pd.DataFrame() | |
def build_index(self): | |
""" | |
Build a FAISS index from the embedded descriptions. | |
""" | |
if not self.documents: | |
return | |
self.embeddings = self.embedder.embed(self.documents) | |
dim = self.embeddings.shape[1] | |
self.index = faiss.IndexFlatL2(dim) | |
self.index.add(self.embeddings) | |
def keyword_match_search(self, entities): | |
""" | |
Perform strict keyword match based search from CSV. | |
""" | |
if self.data is None or self.data.empty: | |
return "No data loaded." | |
ticker = entities.get("ticker", "") | |
metric = entities.get("metric", "") | |
if not ticker or not metric: | |
return "No relevant data found." | |
ticker = ticker.lower() | |
metric = metric.lower() | |
retrieved_text = "" | |
for _, row in self.data.iterrows(): | |
if str(row.get("Ticker", "")).lower() == ticker: | |
for col in self.data.columns: | |
if col.lower() == metric: | |
if pd.isna(row[col]) or row[col] == "": | |
continue | |
value_in_billions = row[col] / 1_000_000_000 | |
retrieved_text = f"Retrieved {metric} for {ticker} is : ${value_in_billions:.2f} billion." | |
break | |
break | |
if not retrieved_text: | |
return "No relevant data found." | |
return retrieved_text | |
def query_csv(self, query, k=3): | |
""" | |
Query the CSV data with a user query. | |
""" | |
retrieved_data = self.retrieve(query, k=k) | |
if not retrieved_data: | |
return "No relevant data found." | |
responses = [] | |
for entry in retrieved_data: | |
try: | |
value = float(entry["value"]) | |
value_in_billions = value / 1_000_000_000 | |
response = f"{entry['ticker']}'s {entry['metric']} for {entry['year']} was ${value_in_billions:.2f} billion." | |
except: | |
response = f"{entry['ticker']}'s {entry['metric']} for {entry['year']} was {entry['value']}." | |
responses.append(response) | |
return "\n".join(responses) | |
def entity_based_query(self, entities): | |
return self.keyword_match_search(entities) | |
def query_db(self, ticker, metric): | |
""" | |
Query the custom_financials table based on ticker and metric, ignoring date and year. | |
""" | |
try: | |
cursor = self.db_conn.cursor() | |
query = """ | |
SELECT value FROM custom_financials | |
WHERE ticker = ? AND metric = ? | |
LIMIT 1 | |
""" | |
params = [ticker, metric] | |
cursor.execute(query, params) | |
result = cursor.fetchone() | |
if result: | |
value = result[0] | |
value_in_billions = value / 1_000_000_000 | |
return f"The {metric} for {ticker} was : ${value_in_billions:.2f} billion in 2024." | |
return f"No relevant data found for {ticker}." | |
except Exception as e: | |
print(f"Error querying database: {e}") | |
return f"Error querying database: {str(e)}" | |
def __del__(self): | |
try: | |
if hasattr(self, 'db_conn') and self.db_conn: | |
self.db_conn.close() | |
# print("Closed SQLite database connection") | |
except Exception as e: | |
print(f"Error closing database connection: {e}") |