Spaces:
Running
Running
# main.py
import asyncio | |
import importlib | |
from datetime import datetime | |
import requests | |
from voice.speech_to_text import SpeechToText | |
from voice.intent_classifier import IntentClassifier | |
# from voice.classifier import TextClassifier | |
from api.endpoints import FMPEndpoints | |
from rag.retriever import Retriever | |
from rag.sql_db import SQL_Key_Pair | |
from rag.web_search import duckduckgo_web_search | |
# Maps each intent label produced by the classifier to the (module path, class name)
# of the handler that fetches the data for it. Built once instead of on every call.
_INTENT_TO_MODULE = {
    "net_income": ("modules.get_net_income", "GetNetIncome"),
    "revenue": ("modules.get_revenue", "GetRevenue"),
    "stock_price": ("modules.get_stock_price", "GetStockPrice"),
    "profit_margin": ("modules.get_profit_margin", "GetProfitMargin"),
    "company_info": ("modules.get_company_profile", "GetCompanyProfile"),
    "market_capitalization": ("modules.get_market_cap", "GetMarketCap"),
    "historical_stock_price": ("modules.get_historical_stock_price", "GetHistoricalStockPrice"),
    "dividend_info": ("modules.get_dividend_info", "GetDividendInfo"),
    "balance_sheet": ("modules.get_balance_sheet", "GetBalanceSheet"),
    "cash_flow": ("modules.get_cash_flow", "GetCashFlow"),
    "financial_ratios": ("modules.get_financial_ratios", "GetFinancialRatios"),
    "earnings_per_share": ("modules.get_earnings_per_share", "GetEarningsPerShare"),
    "interest_rate": ("modules.get_interest", "GetInterest"),
    "income_tax": ("modules.get_income_tax", "GetIncomeTax"),
    "cost_info": ("modules.get_cost_info", "GetCostInfo"),
    "research_info": ("modules.get_research_info", "GetResearchInfo"),
}


def _print_request_context():
    """Print the current local time and a best-effort IP-based location."""
    now = datetime.now()
    print("Current Time:", now.strftime("%Y-%m-%d %H:%M:%S"))
    try:
        # Without a timeout an unreachable ipinfo.io would stall the whole
        # request; any failure here is only reported, never raised.
        response = requests.get("https://ipinfo.io", timeout=5)
        data = response.json()
        print("Location:", data.get("city"), data.get("region"), data.get("country"))
    except Exception as e:
        print("Could not fetch location:", e)


async def _build_final_response(handler, text, entities, sql_db, use_retriever, output):
    """Return the final answer, filling the per-source fields of *output* on the way.

    Order of sources: the intent handler's ``get_data``; if that answer is
    falsy or its text contains "Error"/"None", the CSV store via
    ``sql_db.query_db``; if that finds nothing, the first web-search snippet.
    """
    ticker = entities["ticker"]
    try:
        base_response = await handler.get_data(
            ticker=ticker,
            year=entities["year"],
            date=entities["date"],
        )
    except Exception as e:
        # A failed API call is not fatal: the text below routes us to the fallbacks.
        base_response = f"Error fetching base response: {e}"

    if base_response and "Error" not in str(base_response) and "None" not in str(base_response):
        output["base_response"] = f"{base_response}"
        if not use_retriever:
            return base_response
        # Optional enrichment: append whatever the CSV store has for this ticker/metric.
        retriever_response = sql_db.query_db(entities["ticker"], entities["metric"])
        output["retriever_response"] = retriever_response
        return f"{base_response} Additional Info found in the CSV: {retriever_response}"

    # Handler failed or returned nothing usable: fall back to the CSV store.
    output["base_response"] = f"{base_response} Using retriever to query CSV file..."
    retriever_response = sql_db.query_db(entities["ticker"], entities["metric"])
    output["retriever_response"] = retriever_response

    # str() keeps the substring test from raising on a non-string result, and a
    # None result is treated as "nothing found" instead of crashing.
    csv_miss = retriever_response is None or "No relevant data found" in str(retriever_response)
    if not csv_miss:
        return retriever_response

    # Both the API and the CSV store came up empty: last resort is a web search.
    search_results = duckduckgo_web_search(text)
    if search_results:
        web_answer = search_results[0]['snippet']
    else:
        web_answer = "No relevant data found on the web."
    output["web_search_response"] = web_answer
    return web_answer


async def process_query(vosk_model_path, audio_data=None, query_text=None, use_retriever=False):
    """Answer a query given as audio or text and report every stage in a dict.

    The query is transcribed (audio) or taken verbatim (text), classified into
    an intent, and routed to the matching handler class listed in
    ``_INTENT_TO_MODULE``. If the handler's answer is unusable, the CSV store
    and then a web search are tried.

    Args:
        vosk_model_path: Passed to ``SpeechToText`` as ``model_path``; the
            speech model is only constructed when ``audio_data`` is given.
        audio_data: Audio to transcribe. Takes precedence over ``query_text``
            when truthy.
        query_text: The query as text, used when no audio is supplied.
        use_retriever: When true, a successful handler answer is extended with
            the CSV lookup result.

    Returns:
        A dict with the keys "User asked", "intent", "entities",
        "base_response", "retriever_response", "web_search_response",
        "final_response" and "error". Failures, including component
        initialisation failures, are reported in "error" rather than raised.
    """
    output = {
        "User asked": "",
        "intent": "",
        "entities": "",
        "base_response": "",
        "retriever_response": "",
        "web_search_response": "",
        "final_response": "",
        "error": ""
    }

    try:
        # Step 1: Resolve the query text before building any heavy component,
        # so an empty request (or a text-only one) never loads the speech model.
        if audio_data:
            stt = SpeechToText(model_path=vosk_model_path)
            text = stt.transcribe_audio(audio_data)
            if not text:
                output["error"] = "Could not understand the audio."
                return output
        elif query_text:
            text = query_text
        else:
            output["error"] = "No audio or text query provided."
            return output
        output["User asked"] = text

        # Step 2: Initialize the remaining components.
        classifier = IntentClassifier()
        # NOTE(review): `endpoints` and `retriever` are never used below; they
        # are kept only in case their constructors have side effects - confirm
        # and remove.
        endpoints = FMPEndpoints()
        retriever = Retriever(file_path="./data/financial_data.csv")
        sql_db = SQL_Key_Pair(file_path="./data/financial_data.csv")

        # Step 3: Classify intent and extract entities.
        intent = classifier.classify_with_llm(text)
        output["intent"] = intent if intent else "Could not classify intent."
        entities = classifier.extract_entities(text)
        output["entities"] = str(entities)

        module_info = _INTENT_TO_MODULE.get(intent) if intent else None
        if not intent:
            output["error"] = "Could not classify intent."
        elif not module_info:
            output["error"] = f"Unsupported intent: {intent}"
        else:
            module_path, class_name = module_info
            # Step 4: Load the handler. The try covers only the import and
            # instantiation so that "Class not found" is not reported for an
            # unrelated AttributeError raised later in the pipeline.
            try:
                module = importlib.import_module(module_path)
                handler = getattr(module, class_name)()
            except ImportError as e:
                output["error"] = f"Module import error: {e}"
            except AttributeError as e:
                output["error"] = f"Class not found in module: {e}"
            except Exception as e:
                output["error"] = f"Error processing intent {intent}: {e}"
            else:
                # Step 5: Produce the answer (API -> CSV store -> web search).
                try:
                    output["final_response"] = await _build_final_response(
                        handler, text, entities, sql_db, use_retriever, output
                    )
                except Exception as e:
                    output["error"] = f"Error processing intent {intent}: {e}"
    except Exception as e:
        output["error"] = f"Unexpected error: {e}"

    # Diagnostics are printed for every request that got past input validation
    # (the early returns above skip them).
    _print_request_context()

    # Return output to the User Interface
    return output