File size: 7,557 Bytes
a2c10b6
 
 
ca4d8ec
 
a2c10b6
c29e353
 
a2c10b6
 
 
 
 
 
 
 
c29e353
 
a2c10b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ec34f15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2c10b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca4d8ec
a2c10b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca4d8ec
 
 
 
 
 
 
 
 
 
 
 
 
 
a2c10b6
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
# main.py
import asyncio
import importlib
from datetime import datetime
import requests
from voice.speech_to_text import SpeechToText
from voice.intent_classifier import IntentClassifier
# from voice.classifier import TextClassifier
from api.endpoints import FMPEndpoints
from rag.retriever import Retriever
from rag.sql_db import SQL_Key_Pair
from rag.web_search import duckduckgo_web_search

# Maps each classifier intent label to the (module path, class name) that serves it.
_INTENT_TO_MODULE = {
    "net_income": ("modules.get_net_income", "GetNetIncome"),
    "revenue": ("modules.get_revenue", "GetRevenue"),
    "stock_price": ("modules.get_stock_price", "GetStockPrice"),
    "profit_margin": ("modules.get_profit_margin", "GetProfitMargin"),
    "company_info": ("modules.get_company_profile", "GetCompanyProfile"),
    "market_capitalization": ("modules.get_market_cap", "GetMarketCap"),
    "historical_stock_price": ("modules.get_historical_stock_price", "GetHistoricalStockPrice"),
    "dividend_info": ("modules.get_dividend_info", "GetDividendInfo"),
    "balance_sheet": ("modules.get_balance_sheet", "GetBalanceSheet"),
    "cash_flow": ("modules.get_cash_flow", "GetCashFlow"),
    "financial_ratios": ("modules.get_financial_ratios", "GetFinancialRatios"),
    "earnings_per_share": ("modules.get_earnings_per_share", "GetEarningsPerShare"),
    "interest_rate": ("modules.get_interest", "GetInterest"),
    "income_tax": ("modules.get_income_tax", "GetIncomeTax"),
    "cost_info": ("modules.get_cost_info", "GetCostInfo"),
    "research_info": ("modules.get_research_info", "GetResearchInfo"),
}

# Substring looked for in the CSV/SQL lookup result to detect a miss.
_NO_DATA_MARKER = "No relevant data found"

# Upper bound, in seconds, for the diagnostic IP-geolocation request.
_LOCATION_TIMEOUT_SECONDS = 5


def _entity(entities, key):
    """Return ``entities[key]``, or None when the key or the mapping itself is missing."""
    try:
        return entities[key]
    except (KeyError, IndexError, TypeError):
        return None


def _print_request_context():
    """Print the current local time and the IP-derived location (best effort)."""
    now = datetime.now()
    print("Current Time:", now.strftime("%Y-%m-%d %H:%M:%S"))

    try:
        # The timeout keeps this purely diagnostic lookup from stalling the caller.
        response = requests.get("https://ipinfo.io", timeout=_LOCATION_TIMEOUT_SECONDS)
        data = response.json()
        print("Location:", data.get("city"), data.get("region"), data.get("country"))
    except Exception as e:
        print("Could not fetch location:", e)


async def process_query(vosk_model_path, audio_data=None, query_text=None, use_retriever=False):
    """Answer one query supplied as audio or as text.

    Steps: transcribe the audio (if any), classify the intent and extract
    entities, call the data module mapped to that intent, and, when that call
    yields nothing usable, fall back to the CSV-backed lookup and then to a
    DuckDuckGo web search.

    Args:
        vosk_model_path: Passed to ``SpeechToText`` as ``model_path``.
        audio_data: Audio handed to ``SpeechToText.transcribe_audio``. When
            truthy it takes precedence over ``query_text``.
        query_text: Text query, used when no audio is supplied.
        use_retriever: When true and the data module succeeded, the CSV lookup
            result is appended to the final response.

    Returns:
        A dict with the keys "User asked", "intent", "entities",
        "base_response", "retriever_response", "web_search_response",
        "final_response" and "error". Processing failures are reported through
        "error" instead of being raised; exceptions raised while constructing
        the components still propagate to the caller.
    """
    # Output format
    output = {
        "User asked": "",
        "intent": "",
        "entities": "",
        "base_response": "",
        "retriever_response": "",
        "web_search_response": "",
        "final_response": "",
        "error": ""
    }

    # Fail fast: with no input at all there is nothing to process, so skip
    # constructing the speech model, classifier and CSV-backed stores.
    if audio_data is None and not query_text:
        output["error"] = "No audio or text query provided."
        return output

    # Step 1: Initialize components
    stt = SpeechToText(model_path=vosk_model_path)
    classifier = IntentClassifier()
    # NOTE(review): `endpoints` and `retriever` are not referenced below. They
    # are still constructed in case their constructors have side effects -
    # confirm and drop if not.
    endpoints = FMPEndpoints()
    retriever = Retriever(file_path="./data/financial_data.csv")
    sql_db = SQL_Key_Pair(file_path="./data/financial_data.csv")

    try:
        # Step 2: Process input (text or audio)
        if audio_data:
            text = stt.transcribe_audio(audio_data)
            if not text:
                output["error"] = "Could not understand the audio."
                return output
        elif query_text:
            text = query_text
        else:
            output["error"] = "No audio or text query provided."
            return output

        output["User asked"] = text

        # Step 3: Classify intent (zero-shot) and extract entities
        intent = classifier.classify_with_llm(text)
        output["intent"] = intent if intent else "Could not classify intent."

        entities = classifier.extract_entities(text)
        output["entities"] = str(entities)

        # Identify module for API calling
        module_info = _INTENT_TO_MODULE.get(intent) if intent else None

        if not intent:
            output["error"] = "Could not classify intent."
        elif not module_info:
            output["error"] = f"Unsupported intent: {intent}"
        else:
            module_path, class_name = module_info
            try:
                module = importlib.import_module(module_path)
                class_instance = getattr(module, class_name)()

                # Missing entities become None so that an incomplete extraction
                # degrades to the fallbacks instead of aborting the request.
                ticker = _entity(entities, "ticker")
                metric = _entity(entities, "metric")

                # Step 4: Get the base response from the module
                try:
                    base_response = await class_instance.get_data(
                        ticker=ticker,
                        year=_entity(entities, "year"),
                        date=_entity(entities, "date"),
                    )
                except Exception as e:
                    base_response = f"Error fetching base response: {e}"

                # Step 5: Handle the response based on requirements
                final_response = None
                if base_response and "Error" not in str(base_response) and "None" not in str(base_response):
                    # Base response succeeded
                    final_response = base_response
                    output["base_response"] = f"{final_response}"

                    # Use retriever if specified (optional)
                    if use_retriever:
                        retriever_response = sql_db.query_db(ticker, metric)
                        final_response = f"{final_response} Additional Info found in the CSV: {retriever_response}"
                        output["retriever_response"] = retriever_response
                else:
                    # Base response failed, use the retriever
                    output["base_response"] = f"{base_response} Using retriever to query CSV file..."
                    retriever_response = sql_db.query_db(ticker, metric)
                    output["retriever_response"] = retriever_response

                    # A None result counts as a miss; str() keeps the marker
                    # check from raising on non-string results.
                    if retriever_response is None or _NO_DATA_MARKER in str(retriever_response):
                        # If both API and rag failed to extract information, search on the web
                        search_results = duckduckgo_web_search(text)
                        if search_results:
                            output["web_search_response"] = search_results[0]['snippet']
                            final_response = search_results[0]['snippet']
                        else:
                            output["web_search_response"] = "No relevant data found on the web."
                            final_response = "No relevant data found on the web."
                    else:
                        final_response = retriever_response

                output["final_response"] = final_response
            except ImportError as e:
                output["error"] = f"Module import error: {e}"
            except AttributeError as e:
                output["error"] = f"Class not found in module: {e}"
            except Exception as e:
                output["error"] = f"Error processing intent {intent}: {e}"

    except Exception as e:
        output["error"] = f"Unexpected error: {e}"

    # Diagnostic output: current time and IP-derived location
    _print_request_context()

    # Return output to the User Interface
    return output