# Stray UI/status text from the hosting page ("Spaces: / Running / Running") —
# not Python source; commented out so the module parses.
import faiss
import numpy as np
import pandas as pd
from fuzzywuzzy import fuzz
from langchain.prompts import PromptTemplate
from langchain_community.llms import Ollama

from .embedder import Embedder
class Retriever:
    """Semantic retriever over a tabular financial dataset.

    Loads a CSV/Excel file into a DataFrame, builds a FAISS flat-L2 index
    over the "Ticker" column, and resolves 'ticker metric year' queries by
    combining embedding similarity (ticker, metric column) with fuzzy
    string matching (year).
    """

    def __init__(self, file_path):
        """Load `file_path` and build the ticker index.

        Args:
            file_path (str): Path to a .csv, .xlsx, or .xls file that
                contains at least "Ticker" and "Year" columns.
        """
        # NOTE(review): Embedder is a project-local wrapper; assumed to
        # return a 2-D numpy array of embeddings — confirm against .embedder.
        self.embedder = Embedder(model_name="all-MiniLM-L6-v2")
        self.index = None        # faiss.IndexFlatL2 once build_index() runs
        self.documents = []      # ticker strings, row-aligned with self.data
        self.data = None         # loaded DataFrame
        self.embeddings = None   # float32 ticker embeddings
        self.load_file(file_path)
        self.build_index()

    def load_file(self, file_path):
        """Load a .csv/.xlsx/.xls file into self.data and extract tickers.

        On any failure (unreadable file, unsupported extension, missing
        "Ticker" column) the error is logged and the retriever is left with
        an empty DataFrame so retrieve() degrades gracefully.
        """
        try:
            if file_path.endswith('.csv'):
                self.data = pd.read_csv(file_path)
            elif file_path.endswith(('.xlsx', '.xls')):
                self.data = pd.read_excel(file_path)
            else:
                raise ValueError("Unsupported file format. Use .csv, .xlsx, or .xls")
            self.documents = self.data["Ticker"].astype(str).tolist()
        except Exception as e:
            print(f"Error loading file: {e}")
            self.documents = []
            self.data = pd.DataFrame()

    def build_index(self):
        """Embed the ticker strings and build a flat L2 FAISS index."""
        if not self.documents:
            return
        # FAISS requires a contiguous float32 matrix; coerce defensively in
        # case the embedder returns float64 or a non-contiguous array.
        self.embeddings = np.ascontiguousarray(
            self.embedder.embed(self.documents), dtype=np.float32
        )
        dim = self.embeddings.shape[1]
        self.index = faiss.IndexFlatL2(dim)
        self.index.add(self.embeddings)

    def retrieve(self, query, entities, k=3, threshold=0.7):
        """Find the best (ticker, metric, year) match and answer `query`.

        Args:
            query (str): Natural-language question, forwarded to the LLM.
            entities (dict): Must contain 'ticker', 'metric' and 'year'.
            k (int): Number of nearest ticker neighbours to consider.
            threshold (float): Minimum similarity for ticker/metric matches.

        Returns:
            str on the normal path ("No relevant data found." when nothing
            matched), or an empty list on malformed input / missing index.
            The mixed return type is kept for backward compatibility —
            callers must handle both.
        """
        query_prompt = f"{entities['ticker']} {entities['metric']} {entities['year']}"
        if not self.index or not self.documents or self.data.empty:
            return []
        query_parts = query_prompt.split()
        if len(query_parts) != 3:
            print("Query must follow 'ticker metric year' pattern")
            return []
        query_ticker, query_metric, query_year = query_parts

        # --- Ticker similarity (FAISS nearest neighbours) ---
        query_ticker_embedding = np.ascontiguousarray(
            self.embedder.embed([query_ticker]), dtype=np.float32
        )
        distances, indices = self.index.search(query_ticker_embedding, k)
        ticker_matches = []
        for i, idx in enumerate(indices[0]):
            # BUG FIX: FAISS pads missing neighbours with -1. The previous
            # guard (`idx < len(self.documents)`) let -1 through, which
            # silently matched the DataFrame's LAST row via iloc[-1].
            if 0 <= idx < len(self.documents):
                ticker = self.data.iloc[idx]["Ticker"]
                # For unit-norm embeddings FAISS squared-L2 is 2 - 2*cos,
                # so 1 - d/2 recovers cosine similarity.
                # NOTE(review): assumes the embedder outputs normalized
                # vectors — confirm against Embedder.
                similarity_score = 1 - distances[0][i] / 2
                ticker_matches.append((ticker, similarity_score, idx))

        # --- Metric similarity (cosine against column-name embeddings) ---
        metric_embeddings = self.embedder.embed(self.data.columns.tolist())
        query_metric_embedding = self.embedder.embed([query_metric])[0]
        metric_scores = []
        for col, col_embedding in zip(self.data.columns, metric_embeddings):
            if col.lower() in ["ticker", "year"]:
                continue  # structural columns are never metrics
            cos_sim = np.dot(query_metric_embedding, col_embedding) / (
                np.linalg.norm(query_metric_embedding) * np.linalg.norm(col_embedding)
            )
            metric_scores.append((col, cos_sim))

        # --- Year similarity (fuzzy string match, tolerant of typos) ---
        if "Year" not in self.data.columns:
            print("No 'Year' column found in data")
            return []
        year_scores = []
        for year in self.data["Year"].astype(str).unique():
            similarity = fuzz.ratio(query_year, year) / 100.0
            year_scores.append((year, similarity))

        # --- Combine: score every surviving (ticker, metric, year) triple ---
        retrieved_data = []
        seen = set()  # dedupe triples that reach threshold via multiple rows
        for ticker, ticker_score, idx in ticker_matches:
            if ticker_score < threshold:
                continue
            for metric, metric_score in metric_scores:
                if metric_score < threshold:
                    continue
                for year, year_score in year_scores:
                    # Years use a fixed, looser cutoff than `threshold`.
                    if year_score < 0.5:
                        continue
                    combined_score = (ticker_score + metric_score + year_score) / 3
                    match = self.data[
                        (self.data["Ticker"].str.lower() == ticker.lower()) &
                        (self.data["Year"].astype(str) == year) &
                        (self.data[metric].notnull())
                    ]
                    if not match.empty:
                        value = match[metric].iloc[0]
                        key = (ticker, metric, year)
                        if key not in seen:
                            seen.add(key)
                            retrieved_data.append({
                                "ticker": ticker,
                                "metric": metric,
                                "value": value,
                                "year": year,
                                "combined_score": combined_score
                            })
        if retrieved_data:
            retrieved_data.sort(key=lambda x: x["combined_score"], reverse=True)
            best_match = retrieved_data[0]
            # Module-level helper (defined later in this file) turns the best
            # match into a natural-language answer.
            answer = answer_question(query, best_match)
            return answer
        return "No relevant data found."
def answer_question(question, retrieved_data):
    """Use a lightweight LLM to generate a natural-language answer on CPU.

    Args:
        question (str): The question to answer.
        retrieved_data (dict): A single match with 'ticker', 'metric',
            'value' and 'year' keys (as produced by Retriever.retrieve).

    Returns:
        str: Natural-language answer, or a fallback message on failure.
    """
    # Check input before paying for LLM construction.
    if not retrieved_data:
        return "No relevant data found."
    try:
        # Lightweight CPU-friendly model; num_gpu=0 explicitly disables GPU.
        llm = Ollama(model="gemma:2b", num_gpu=0)
        # Minimal prompt for CPU efficiency.
        prompt_template = PromptTemplate(
            input_variables=["question", "ticker", "metric", "value", "year"],
            template=(
                "Question: {question}\n"
                "Data: Ticker={ticker}, Metric={metric}, Value={value}, Year={year}\n"
                "Answer concisely, formatting the value with commas."
            )
        )
        prompt = prompt_template.format(
            question=question,
            ticker=retrieved_data['ticker'],
            metric=retrieved_data['metric'],
            # BUG FIX: the whole dict was previously passed here, producing
            # prompts like "Value={'ticker': ...}" instead of the number.
            value=retrieved_data['value'],
            year=retrieved_data['year']
        )
        response = llm.invoke(prompt)
        return response.strip()
    except Exception as e:
        print(f"Error generating answer: {e}")
        return "Unable to generate answer."
# def main(file_path, query, question): | |
# """ | |
# Main function to process a query, retrieve results, and answer a question. | |
# Args: | |
# file_path (str): Path to the CSV or Excel file | |
# query (str): Query string in 'ticker metric year' format | |
# question (str): Natural-language question to answer | |
# Returns: | |
# tuple: (retrieved data, answer) | |
# """ | |
#     try:
#         retriever = Retriever(file_path)
#         ticker, metric, year = query.split()
#         entities = {"ticker": ticker, "metric": metric, "year": year}
#         # retrieve() requires the parsed entities and already returns the
#         # final natural-language answer (it calls answer_question itself).
#         answer = retriever.retrieve(query, entities)
#         return [], answer
#     except Exception as e:
#         print(f"Error processing query: {e}")
#         return [], "Unable to process query."
# if __name__ == "__main__": | |
# file_path = "./financial_data.csv" | |
# query = "AAPL InterestExpense 2024" | |
# question = "What is the InterestExpense of AAPL 2024?" | |
# results, answer = main(file_path, query, question) | |
# for result in results: | |
# print(f"Ticker: {result['ticker']}, Metric: {result['metric']}, Value: {result['value']}, Year: {result['year']}") | |
# print(f"Answer: {answer}") |