# rasulbrur's picture
# Added updated intent
# ca4d8ec
# main.py
import asyncio
import importlib
from datetime import datetime
import requests
from voice.speech_to_text import SpeechToText
from voice.intent_classifier import IntentClassifier
# from voice.classifier import TextClassifier
from api.endpoints import FMPEndpoints
from rag.retriever import Retriever
from rag.sql_db import SQL_Key_Pair
from rag.web_search import duckduckgo_web_search
# Maps each supported intent label to the (module path, class name) of the
# handler that is imported dynamically to fetch the base response.
_INTENT_TO_MODULE = {
    "net_income": ("modules.get_net_income", "GetNetIncome"),
    "revenue": ("modules.get_revenue", "GetRevenue"),
    "stock_price": ("modules.get_stock_price", "GetStockPrice"),
    "profit_margin": ("modules.get_profit_margin", "GetProfitMargin"),
    "company_info": ("modules.get_company_profile", "GetCompanyProfile"),
    "market_capitalization": ("modules.get_market_cap", "GetMarketCap"),
    "historical_stock_price": ("modules.get_historical_stock_price", "GetHistoricalStockPrice"),
    "dividend_info": ("modules.get_dividend_info", "GetDividendInfo"),
    "balance_sheet": ("modules.get_balance_sheet", "GetBalanceSheet"),
    "cash_flow": ("modules.get_cash_flow", "GetCashFlow"),
    "financial_ratios": ("modules.get_financial_ratios", "GetFinancialRatios"),
    "earnings_per_share": ("modules.get_earnings_per_share", "GetEarningsPerShare"),
    "interest_rate": ("modules.get_interest", "GetInterest"),
    "income_tax": ("modules.get_income_tax", "GetIncomeTax"),
    "cost_info": ("modules.get_cost_info", "GetCostInfo"),
    "research_info": ("modules.get_research_info", "GetResearchInfo"),
}

# Upper bound (seconds) for the best-effort location lookup, so an unreachable
# or stalled service cannot hang the query indefinitely.
_LOCATION_LOOKUP_TIMEOUT = 5


def _is_usable_base_response(base_response):
    """Return a truthy value when the base response can be used as the answer.

    A response is rejected when it is falsy or when its string form contains
    "Error" or "None".
    NOTE(review): this is substring sniffing, so a legitimate answer that
    happens to contain either word is also treated as a failure.
    """
    return (
        base_response
        and "Error" not in str(base_response)
        and "None" not in str(base_response)
    )


async def _answer_intent(intent, entities, text, sql_db, use_retriever, output):
    """Fill *output* with the answer for an already classified intent.

    The handler class registered for *intent* is imported and asked for the
    base response. If that response is unusable, the CSV-backed database is
    queried, and if that reports "No relevant data found" a web search is
    used as the last resort. Failures are written to ``output["error"]``;
    nothing is raised to the caller.
    """
    module_info = _INTENT_TO_MODULE.get(intent)
    if not module_info:
        output["error"] = f"Unsupported intent: {intent}"
        return

    module_path, class_name = module_info
    try:
        module = importlib.import_module(module_path)
        handler = getattr(module, class_name)()
        ticker = entities["ticker"]

        # Step 4: Get the base response from the module. Any failure is
        # folded into the response text so the fallback branch below runs.
        try:
            base_response = await handler.get_data(
                ticker=ticker,
                year=entities["year"],
                date=entities["date"],
            )
        except Exception as e:
            base_response = f"Error fetching base response: {e}"

        # Step 5: Handle the response based on requirements
        if _is_usable_base_response(base_response):
            # Base response succeeded
            final_response = base_response
            output["base_response"] = f"{final_response}"
            # Use retriever if specified (optional)
            if use_retriever:
                retriever_response = sql_db.query_db(ticker, entities["metric"])
                final_response = f"{final_response} Additional Info found in the CSV: {retriever_response}"
                output["retriever_response"] = retriever_response
        else:
            # Base response failed, use the retriever
            output["base_response"] = f"{base_response} Using retriever to query CSV file..."
            retriever_response = sql_db.query_db(ticker, entities["metric"])
            output["retriever_response"] = retriever_response
            if "No relevant data found" in retriever_response:
                # If both API and rag failed to extract information, search on the web
                search_results = duckduckgo_web_search(text)
                if search_results:
                    final_response = search_results[0]['snippet']
                else:
                    final_response = "No relevant data found on the web."
                output["web_search_response"] = final_response
            else:
                final_response = retriever_response
        output["final_response"] = final_response
    except ImportError as e:
        output["error"] = f"Module import error: {e}"
    except AttributeError as e:
        output["error"] = f"Class not found in module: {e}"
    except Exception as e:
        output["error"] = f"Error processing intent {intent}: {e}"


def _print_request_context():
    """Print the current local time and the IP-based location (best effort)."""
    # Current Time
    now = datetime.now()
    print("Current Time:", now.strftime("%Y-%m-%d %H:%M:%S"))
    # Location Info
    try:
        # The timeout keeps a stalled connection from blocking the caller forever.
        response = requests.get("https://ipinfo.io", timeout=_LOCATION_LOOKUP_TIMEOUT)
        data = response.json()
        print("Location:", data.get("city"), data.get("region"), data.get("country"))
    except Exception as e:
        print("Could not fetch location:", e)


async def process_query(vosk_model_path, audio_data=None, query_text=None, use_retriever=False):
    """Answer a spoken or typed query and report every stage in a dict.

    Args:
        vosk_model_path: Passed as ``model_path`` to ``SpeechToText``; only
            used when ``audio_data`` is provided.
        audio_data: Audio to transcribe. Takes precedence over ``query_text``
            when truthy.
        query_text: Query text used when no audio is supplied.
        use_retriever: When true, the CSV lookup result is appended to a
            successful base response.

    Returns:
        A dict with the keys "User asked", "intent", "entities",
        "base_response", "retriever_response", "web_search_response",
        "final_response" and "error". Problems are reported through "error"
        rather than raised.

    When the request is empty or the audio cannot be transcribed, the
    function returns immediately, without building the remaining components
    and without printing the time/location diagnostics.
    """
    # Output format
    output = {
        "User asked": "",
        "intent": "",
        "entities": "",
        "base_response": "",
        "retriever_response": "",
        "web_search_response": "",
        "final_response": "",
        "error": "",
    }

    try:
        # Step 1: Process input (text or audio). This happens before any other
        # component is built so an empty request is rejected cheaply, and the
        # speech model is only loaded when there is audio to transcribe.
        if audio_data:
            stt = SpeechToText(model_path=vosk_model_path)
            text = stt.transcribe_audio(audio_data)
            if not text:
                output["error"] = "Could not understand the audio."
                return output
        elif query_text:
            text = query_text
        else:
            output["error"] = "No audio or text query provided."
            return output
        output["User asked"] = text

        # Step 2: Initialize components. Built inside the try block so that a
        # constructor failure is reported through output["error"].
        classifier = IntentClassifier()
        # NOTE(review): `endpoints` and `retriever` are not used below; they are
        # still constructed in case their constructors have required side
        # effects. Remove once that is confirmed not to be the case.
        endpoints = FMPEndpoints()
        retriever = Retriever(file_path="./data/financial_data.csv")
        sql_db = SQL_Key_Pair(file_path="./data/financial_data.csv")

        # Step 3: Classify intent (zero-shot) and extract entities
        intent = classifier.classify_with_llm(text)
        output["intent"] = intent if intent else "Could not classify intent."
        entities = classifier.extract_entities(text)
        output["entities"] = str(entities)

        if intent:
            await _answer_intent(intent, entities, text, sql_db, use_retriever, output)
        else:
            output["error"] = "Could not classify intent."
    except Exception as e:
        output["error"] = f"Unexpected error: {e}"

    _print_request_context()
    # Return output to the User Interface
    return output