rasulbrur commited on
Commit
a2c10b6
·
1 Parent(s): d5b663d

Added files initial commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +1 -0
  2. Dockerfile +41 -0
  3. api/__pycache__/endpoints.cpython-310.pyc +0 -0
  4. api/endpoints.py +151 -0
  5. app.py +157 -0
  6. data/financial data sp500 companies.csv +0 -0
  7. data/financial_data.csv +0 -0
  8. important.txt +11 -0
  9. main.py +147 -0
  10. modules/__pycache__/financial_query.cpython-310.pyc +0 -0
  11. modules/__pycache__/get_balance_sheet.cpython-310.pyc +0 -0
  12. modules/__pycache__/get_cash_flow.cpython-310.pyc +0 -0
  13. modules/__pycache__/get_company_profile.cpython-310.pyc +0 -0
  14. modules/__pycache__/get_financial_ratios.cpython-310.pyc +0 -0
  15. modules/__pycache__/get_income_statement.cpython-310.pyc +0 -0
  16. modules/__pycache__/get_income_tax.cpython-310.pyc +0 -0
  17. modules/__pycache__/get_interest.cpython-310.pyc +0 -0
  18. modules/__pycache__/get_market_cap.cpython-310.pyc +0 -0
  19. modules/__pycache__/get_net_income.cpython-310.pyc +0 -0
  20. modules/__pycache__/get_profit_margin.cpython-310.pyc +0 -0
  21. modules/__pycache__/get_research_info.cpython-310.pyc +0 -0
  22. modules/__pycache__/get_revenue.cpython-310.pyc +0 -0
  23. modules/__pycache__/get_stock_price.cpython-310.pyc +0 -0
  24. modules/get_balance_sheet.py +15 -0
  25. modules/get_cash_flow.py +14 -0
  26. modules/get_company_profile.py +15 -0
  27. modules/get_cost_info.py +11 -0
  28. modules/get_divident_info.py +14 -0
  29. modules/get_earnings_per_share.py +14 -0
  30. modules/get_financial_ratios.py +14 -0
  31. modules/get_historical_stock_price.py +14 -0
  32. modules/get_income_tax.py +11 -0
  33. modules/get_interest.py +11 -0
  34. modules/get_market_cap.py +14 -0
  35. modules/get_net_income.py +14 -0
  36. modules/get_profit_margin.py +14 -0
  37. modules/get_research_info.py +11 -0
  38. modules/get_revenue.py +14 -0
  39. modules/get_stock_price.py +14 -0
  40. rag/__pycache__/embedder.cpython-310.pyc +0 -0
  41. rag/__pycache__/retriever.cpython-310.pyc +0 -0
  42. rag/__pycache__/sql_db.cpython-310.pyc +0 -0
  43. rag/__pycache__/web_search.cpython-310.pyc +0 -0
  44. rag/embedder.py +20 -0
  45. rag/graphrag.py +72 -0
  46. rag/retriever.py +202 -0
  47. rag/sql_db.py +171 -0
  48. rag/web_search.py +12 -0
  49. repo.jpg +0 -0
  50. requirements.txt +69 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .env
Dockerfile ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Use the official slim Python base image
FROM python:3.10-slim

# System dependencies: gcc + PortAudio headers for pyaudio,
# wget/unzip/curl for fetching the Vosk model.
RUN apt-get update && \
    apt-get install -y --no-install-recommends wget unzip curl gcc portaudio19-dev && \
    rm -rf /var/lib/apt/lists/*

# Set working directory
WORKDIR /app

# Copy only the dependency manifest first so the pip layer is cached
# unless requirements.txt itself changes (copying all code first busts the cache).
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir --upgrade pip && pip install --no-cache-dir -r requirements.txt

# Install spaCy model
# RUN python -m spacy download en_core_web_lg
RUN python -m spacy download en_core_web_sm

# Download and unzip the Vosk model
RUN wget https://alphacephei.com/vosk/models/vosk-model-small-en-us-0.15.zip && \
    unzip vosk-model-small-en-us-0.15.zip && \
    rm vosk-model-small-en-us-0.15.zip

# Copy the application code last so code edits don't invalidate the layers above
COPY . .

# Install Ollama
# RUN curl -fsSL https://ollama.com/install.sh | sh

# Pull the Ollama model
# RUN ollama serve & sleep 5 && ollama pull gemma:2b

# Expose the port FastAPI will run on
EXPOSE 7860

# Start the API
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]

# CMD uvicorn app:app --host 0.0.0.0 --port $PORT
api/__pycache__/endpoints.cpython-310.pyc ADDED
Binary file (4.22 kB). View file
 
api/endpoints.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# api/endpoints.py
import httpx
import os

class FMPEndpoints:
    """Async client for the Financial Modeling Prep (FMP) v3 REST API.

    Each public method wraps one FMP endpoint and returns the decoded JSON
    payload. HTTP and transport failures are normalized into plain
    ``Exception``s with readable messages so callers can surface them directly.
    """

    def __init__(self):
        # The key is read from the environment; FMP rejects requests without it.
        self.fmp_api_key = os.getenv("FMP_API_KEY")
        self.base_url = "https://financialmodelingprep.com/api/v3"

    def _base_params(self, **extra):
        """Return query parameters pre-populated with the API key."""
        params = {"apikey": self.fmp_api_key}
        params.update(extra)
        return params

    async def _get(self, path, params, label):
        """GET ``{base_url}/{path}`` with ``params`` and return the JSON body.

        Args:
            path: endpoint path relative to the v3 base URL.
            params: query parameters (already including the API key).
            label: human-readable endpoint name used in error messages.

        Raises:
            Exception: "API error: ..." for HTTP error responses, or
                "Error fetching {label}: ..." for any other failure.
        """
        endpoint = f"{self.base_url}/{path}"
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(endpoint, params=params)
                response.raise_for_status()
                return response.json()
        except httpx.HTTPStatusError as e:
            # Raised from within the handler, so it is NOT re-caught below.
            raise Exception(f"API error: {e.response.status_code} - {e.response.text}")
        except Exception as e:
            raise Exception(f"Error fetching {label}: {e}")

    async def get_income_statement(self, ticker, year=None, period="annual", limit=1):
        """Fetch income statement data for a given ticker."""
        params = self._base_params(period=period, limit=limit)
        if year:
            params["year"] = year
        return await self._get(f"income-statement/{ticker}", params, "income statement")

    async def get_quote_short(self, ticker):
        """Fetch the current stock price (short quote) for a given ticker."""
        return await self._get(f"quote-short/{ticker}", self._base_params(), "quote")

    async def get_ratios(self, ticker, year=None, limit=1):
        """Fetch financial ratios for a given ticker."""
        params = self._base_params(limit=limit)
        if year:
            params["year"] = year
        return await self._get(f"ratios/{ticker}", params, "ratios")

    async def get_profile(self, ticker):
        """Fetch company profile data for a given ticker."""
        return await self._get(f"profile/{ticker}", self._base_params(), "profile")

    async def get_historical_price(self, ticker, date=None):
        """Fetch historical stock price for a given ticker on a specific date."""
        params = self._base_params()
        if date:
            # FMP expects a from/to window; a single day uses the same bound twice.
            params["from"] = date
            params["to"] = date
        return await self._get(f"historical-price-full/{ticker}", params, "historical price")

    async def get_balance_sheet(self, ticker, year=None, period="annual", limit=1):
        """Fetch balance sheet data for a given ticker."""
        params = self._base_params(period=period, limit=limit)
        if year:
            params["year"] = year
        return await self._get(f"balance-sheet-statement/{ticker}", params, "balance sheet")

    async def get_cash_flow(self, ticker, year=None, period="annual", limit=1):
        """Fetch cash flow statement data for a given ticker."""
        params = self._base_params(period=period, limit=limit)
        if year:
            params["year"] = year
        return await self._get(f"cash-flow-statement/{ticker}", params, "cash flow")

    async def get_key_metrics(self, ticker, year=None, limit=1):
        """Fetch key metrics (e.g., EPS) for a given ticker."""
        params = self._base_params(limit=limit)
        if year:
            params["year"] = year
        return await self._get(f"key-metrics/{ticker}", params, "key metrics")
app.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# app.py
from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from main import process_query
from voice.speech_to_text import SpeechToText
import os
import asyncio
import pyaudio
import wave
import logging

# Module-wide logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Serve CSS/JS assets from ./static and render pages from ./templates
app.mount("/static", StaticFiles(directory="static"), name="static")

templates = Jinja2Templates(directory="templates")

# Location of the Vosk speech model and of the scratch WAV recording
vosk_model_path = "./vosk-model-small-en-us-0.15"
audio_file_path = "voice/temp_audio.wav"

# Make sure the directory holding the scratch WAV exists before first use
os.makedirs("voice", exist_ok=True)

# Shared speech-to-text engine, loaded once at startup
stt = SpeechToText(model_path=vosk_model_path)

# Mutable recording state shared by the /start_recording and /stop_recording routes
recording = False
audio_frames = []
recording_task = None
def save_audio_to_wav(frames, sample_rate=16000):
    """Write raw 16-bit mono PCM frames to ``audio_file_path`` as a WAV file.

    Args:
        frames: sequence of bytes chunks captured from the microphone.
        sample_rate: sampling rate in Hz (the recorder captures at 16 kHz).

    Raises:
        Exception: re-raised after logging if the file cannot be written.
    """
    try:
        logger.info(f"Saving audio to {audio_file_path} with {len(frames)} frames")
        # Context manager guarantees the file handle is closed even if a
        # setter or write raises (the original leaked the handle on error).
        with wave.open(audio_file_path, 'wb') as wf:
            wf.setnchannels(1)        # mono
            wf.setsampwidth(2)        # 16-bit samples
            wf.setframerate(sample_rate)
            wf.writeframes(b''.join(frames))
        if os.path.exists(audio_file_path):
            logger.info(f"WAV file saved successfully: {os.path.getsize(audio_file_path)} bytes")
        else:
            logger.error("WAV file was not created")
    except Exception as e:
        logger.error(f"Error saving WAV file: {str(e)}")
        raise
58
+
async def record_audio():
    """Background task: pull microphone audio into the global ``audio_frames``.

    Runs until the global ``recording`` flag is cleared by /stop_recording.
    NOTE(review): stream.read() is a blocking call; the small asyncio.sleep
    keeps the event loop responsive, but an executor-based read would be
    cleaner — confirm latency is acceptable.
    """
    global audio_frames
    p = pyaudio.PyAudio()
    stream = None
    try:
        stream = p.open(format=pyaudio.paInt16, channels=1, rate=16000,
                        input=True, frames_per_buffer=1024)
        stream.start_stream()
        logger.info("Recording started...")

        while recording:
            data = stream.read(1024, exception_on_overflow=False)
            audio_frames.append(data)
            await asyncio.sleep(0.01)  # Small sleep to prevent blocking

        logger.info(f"Recording stopped, captured {len(audio_frames)} frames")
    except Exception as e:
        logger.error(f"Error during recording: {str(e)}")
    finally:
        # Close the stream (if it opened) before terminating PortAudio so the
        # device is released even when an exception interrupts the loop —
        # the original only closed it on the success path.
        if stream is not None:
            stream.stop_stream()
            stream.close()
        p.terminate()
80
+
@app.get("/", response_class=HTMLResponse)
async def get_index(request: Request):
    """Render the landing page."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
84
+
@app.post("/start_recording", response_class=JSONResponse)
async def start_recording():
    """Begin capturing microphone audio in a background asyncio task."""
    global recording, audio_frames, recording_task
    # Guard clause: refuse to start a second capture task.
    if recording:
        logger.warning("Recording already in progress")
        return {"status": "Already recording"}
    recording = True
    audio_frames = []
    recording_task = asyncio.create_task(record_audio())
    logger.info("Started recording task")
    return {"status": "Recording started"}
96
+
@app.post("/stop_recording", response_class=HTMLResponse)
async def stop_recording(request: Request):
    """Stop the capture task, persist the audio, and render the transcription."""
    global recording, recording_task

    # Guard clause: nothing to stop.
    if not recording:
        logger.warning("No recording in progress")
        return templates.TemplateResponse("index.html", {
            "request": request,
            "error": "No recording in progress."
        })

    recording = False
    if recording_task:
        await recording_task  # let the capture loop drain and exit
        recording_task = None

    # Persist the captured frames as a WAV file.
    try:
        save_audio_to_wav(audio_frames)
    except Exception as e:
        logger.error(f"Failed to save audio: {str(e)}")
        return templates.TemplateResponse("index.html", {
            "request": request,
            "error": f"Failed to save audio: {str(e)}"
        })

    # Run speech-to-text over the saved file.
    try:
        text = stt.transcribe_audio(audio_file_path)
        logger.info(f"Transcription result: '{text}'")
    except Exception as e:
        logger.error(f"Transcription error: {str(e)}")
        return templates.TemplateResponse("index.html", {
            "request": request,
            "error": f"Transcription error: {str(e)}"
        })

    if not text:
        logger.warning("Transcription returned no text")
        return templates.TemplateResponse("index.html", {
            "request": request,
            "error": "Could not understand the audio."
        })
    return templates.TemplateResponse("index.html", {
        "request": request,
        "transcribed_text": text
    })
141
+
@app.post("/query", response_class=HTMLResponse)
async def handle_query(request: Request, query_text: str = Form(...), use_retriever: str = Form("no")):
    """Run the full query pipeline and render the results page."""
    wants_retriever = use_retriever.lower() in ["yes", "y"]
    result = await process_query(vosk_model_path, query_text=query_text, use_retriever=wants_retriever)

    # Template context mirrors the pipeline's output dict field by field.
    context = {
        "request": request,
        "User_Query": query_text,
        "Intent": result["intent"],
        "Entities": result["entities"],
        "API_Response": result["base_response"],
        "RAG_Response": result["retriever_response"],
        "Web_Search_Response": result["web_search_response"],
        "Final_Response": result["final_response"],
        "Error": result["error"],
    }
    return templates.TemplateResponse("index.html", context)
data/financial data sp500 companies.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/financial_data.csv ADDED
The diff for this file is too large to render. See raw diff
 
important.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 1. The full project is inside the codebase folder. To run the project,
2
+ Create conda env using “conda create -n env_name python=3.10”
3
+ Activate the env using “conda activate env_name”
4
+ Install requirements.txt using, “pip install -r requirements.txt”
5
+ Install the bash scripts using “source env_variable.sh” and “source setup.sh”
6
+ Run "uvicorn app:app --reload"
7
+
8
+ 2. The "Documentation of AI Finance Accountant Agent" has a full overview of the project, explaining key features, required tools, data sources, etc. I attached some testing examples at the end of the documentation file.
9
+
10
+ 3. The “demo_video.mkv” file is a demo of how my api agent is working. Carefully watch the video.
11
+
main.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# main.py
import asyncio
import importlib
from voice.speech_to_text import SpeechToText
from voice.intent_classifier import IntentClassifier
from api.endpoints import FMPEndpoints
from rag.retriever import Retriever
from rag.sql_db import SQL_Key_Pair
from rag.web_search import duckduckgo_web_search

# Maps a classified intent to the (module path, class name) that handles it.
# Hoisted to module scope: it is static, so there is no need to rebuild it per call.
# NOTE(review): "get_dividend_info" points at modules.get_dividend_info, but the
# repository ships modules/get_divident_info.py (typo) — confirm the filename,
# otherwise this intent always raises ImportError.
INTENT_TO_MODULE = {
    "get_net_income": ("modules.get_net_income", "GetNetIncome"),
    "get_revenue": ("modules.get_revenue", "GetRevenue"),
    "get_stock_price": ("modules.get_stock_price", "GetStockPrice"),
    "get_profit_margin": ("modules.get_profit_margin", "GetProfitMargin"),
    "get_company_profile": ("modules.get_company_profile", "GetCompanyProfile"),
    "get_market_cap": ("modules.get_market_cap", "GetMarketCap"),
    "get_historical_stock_price": ("modules.get_historical_stock_price", "GetHistoricalStockPrice"),
    "get_dividend_info": ("modules.get_dividend_info", "GetDividendInfo"),
    "get_balance_sheet": ("modules.get_balance_sheet", "GetBalanceSheet"),
    "get_cash_flow": ("modules.get_cash_flow", "GetCashFlow"),
    "get_financial_ratios": ("modules.get_financial_ratios", "GetFinancialRatios"),
    "get_earnings_per_share": ("modules.get_earnings_per_share", "GetEarningsPerShare"),
    "get_interest": ("modules.get_interest", "GetInterest"),
    "get_income_tax": ("modules.get_income_tax", "GetIncomeTax"),
    "get_cost_info": ("modules.get_cost_info", "GetCostInfo"),
    "get_research_info": ("modules.get_research_info", "GetResearchInfo"),
}

async def process_query(vosk_model_path, audio_data=None, query_text=None, use_retriever=False):
    """Run the full voice/text query pipeline: STT -> intent -> API -> RAG -> web.

    Args:
        vosk_model_path: path to the Vosk model used for speech-to-text.
        audio_data: optional audio input; transcribed when given.
        query_text: optional text query, used when no audio is supplied.
        use_retriever: when True, also augment a successful API answer with
            data retrieved from the local CSV database.

    Returns:
        dict with the user query, classified intent, extracted entities, the
        responses from each stage (API, retriever, web search), the chosen
        final response, and an ``error`` field describing any failure.
    """
    # Step 1: Initialize components
    stt = SpeechToText(model_path=vosk_model_path)
    classifier = IntentClassifier()
    endpoints = FMPEndpoints()
    # RAG tools backed by the local CSV snapshot
    retriever = Retriever(file_path="./data/financial_data.csv")
    sql_db = SQL_Key_Pair(file_path="./data/financial_data.csv")

    # Output format returned to the UI
    output = {
        "User asked": "",
        "intent": "",
        "entities": "",
        "base_response": "",
        "retriever_response": "",
        "web_search_response": "",
        "final_response": "",
        "error": ""
    }

    try:
        # Step 2: Process input (audio wins when both are supplied)
        if audio_data:
            text = stt.transcribe_audio(audio_data)
            if not text:
                output["error"] = "Could not understand the audio."
                return output
        elif query_text:
            text = query_text
        else:
            output["error"] = "No audio or text query provided."
            return output

        output["User asked"] = text

        # Step 3: Classify intent (zero-shot) and extract entities
        intent = classifier.classify_with_llm(text)
        output["intent"] = intent if intent else "Could not classify intent."

        entities = classifier.extract_entities(text)
        output["entities"] = str(entities)

        if not intent:
            output["error"] = "Could not classify intent."
            return output

        # Identify module for API calling
        module_info = INTENT_TO_MODULE.get(intent)
        if not module_info:
            output["error"] = f"Unsupported intent: {intent}"
            return output

        module_path, class_name = module_info
        try:
            module = importlib.import_module(module_path)
            class_instance = getattr(module, class_name)()
            # .get() avoids a KeyError crash when the classifier omits a field.
            ticker = entities.get("ticker")

            # Step 4: Get the base response from the intent-specific module
            try:
                base_response = await class_instance.get_data(
                    ticker=ticker,
                    year=entities.get("year"),
                    date=entities.get("date"),
                )
            except Exception as e:
                base_response = f"Error fetching base response: {e}"

            # Step 5: choose the final answer: API -> CSV retriever -> web search
            final_response = None
            if base_response and "Error" not in str(base_response) and "None" not in str(base_response):
                # Base response succeeded
                final_response = base_response
                output["base_response"] = f"{final_response}"

                # Optionally enrich the answer with local CSV data
                if use_retriever:
                    retriever_response = sql_db.query_db(entities.get("ticker"), entities.get("metric"))
                    final_response = f"{final_response} Additional Info found in the CSV: {retriever_response}"
                    output["retriever_response"] = retriever_response
            else:
                # Base response failed: fall back to the CSV retriever
                output["base_response"] = f"{base_response} Using retriever to query CSV file..."
                retriever_response = sql_db.query_db(entities.get("ticker"), entities.get("metric"))
                output["retriever_response"] = retriever_response

                if "No relevant data found" in retriever_response:
                    # Both API and RAG failed: last resort is a web search
                    search_results = duckduckgo_web_search(text)
                    if search_results:
                        output["web_search_response"] = search_results[0]['snippet']
                        final_response = search_results[0]['snippet']
                    else:
                        output["web_search_response"] = "No relevant data found on the web."
                        final_response = "No relevant data found on the web."
                else:
                    final_response = retriever_response

            output["final_response"] = final_response
        except ImportError as e:
            output["error"] = f"Module import error: {e}"
        except AttributeError as e:
            output["error"] = f"Class not found in module: {e}"
        except Exception as e:
            output["error"] = f"Error processing intent {intent}: {e}"

    except Exception as e:
        output["error"] = f"Unexpected error: {e}"

    # Return output to the User Interface
    return output
modules/__pycache__/financial_query.cpython-310.pyc ADDED
Binary file (1.15 kB). View file
 
modules/__pycache__/get_balance_sheet.cpython-310.pyc ADDED
Binary file (1 kB). View file
 
modules/__pycache__/get_cash_flow.cpython-310.pyc ADDED
Binary file (924 Bytes). View file
 
modules/__pycache__/get_company_profile.cpython-310.pyc ADDED
Binary file (903 Bytes). View file
 
modules/__pycache__/get_financial_ratios.cpython-310.pyc ADDED
Binary file (917 Bytes). View file
 
modules/__pycache__/get_income_statement.cpython-310.pyc ADDED
Binary file (1.37 kB). View file
 
modules/__pycache__/get_income_tax.cpython-310.pyc ADDED
Binary file (586 Bytes). View file
 
modules/__pycache__/get_interest.cpython-310.pyc ADDED
Binary file (582 Bytes). View file
 
modules/__pycache__/get_market_cap.cpython-310.pyc ADDED
Binary file (847 Bytes). View file
 
modules/__pycache__/get_net_income.cpython-310.pyc ADDED
Binary file (904 Bytes). View file
 
modules/__pycache__/get_profit_margin.cpython-310.pyc ADDED
Binary file (910 Bytes). View file
 
modules/__pycache__/get_research_info.cpython-310.pyc ADDED
Binary file (595 Bytes). View file
 
modules/__pycache__/get_revenue.cpython-310.pyc ADDED
Binary file (900 Bytes). View file
 
modules/__pycache__/get_stock_price.cpython-310.pyc ADDED
Binary file (844 Bytes). View file
 
modules/get_balance_sheet.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# modules/get_balance_sheet.py
from api.endpoints import FMPEndpoints

class GetBalanceSheet:
    """Answers balance-sheet queries (total assets/liabilities) via the FMP API."""

    async def get_data(self, ticker, year=None, date=None):
        # `date` is accepted for interface parity with the other modules; unused here.
        endpoints = FMPEndpoints()
        try:
            data = await endpoints.get_balance_sheet(ticker, year=year)
            if not data:
                return f"Error: No balance sheet data available for {ticker}."
            latest = data[0]
            # `or 0` also guards against explicit nulls in the payload,
            # which would otherwise break the division below.
            assets = latest.get("totalAssets") or 0
            liabilities = latest.get("totalLiabilities") or 0
            return f"{ticker}'s assets for {year or 'the latest year'} are ${assets / 1_000_000_000:.2f} billion, and liabilities are ${liabilities / 1_000_000_000:.2f} billion."
        except Exception as e:
            return f"Error fetching balance sheet: {e}"
modules/get_cash_flow.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# modules/get_cash_flow.py
from api.endpoints import FMPEndpoints

class GetCashFlow:
    """Answers operating-cash-flow queries via the FMP /cash-flow-statement endpoint."""

    async def get_data(self, ticker, year=None, date=None):
        endpoints = FMPEndpoints()
        try:
            data = await endpoints.get_cash_flow(ticker, year=year)
            if not data:
                return f"Error: No cash flow data available for {ticker}."
            # FMP's cash-flow-statement payload names this field
            # "netCashProvidedByOperatingActivities"; the old key
            # "cashFlowFromOperatingActivities" (which always yielded 0)
            # is kept as a fallback. `or 0` guards explicit nulls.
            value = data[0].get(
                "netCashProvidedByOperatingActivities",
                data[0].get("cashFlowFromOperatingActivities", 0),
            ) or 0
            return f"{ticker}'s cash from operations for {year or 'the latest year'} is ${value / 1_000_000_000:.2f} billion."
        except Exception as e:
            return f"Error fetching cash flow: {e}"
modules/get_company_profile.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# modules/get_company_profile.py
from api.endpoints import FMPEndpoints

class GetCompanyProfile:
    """Answers company-profile queries (CEO, sector) via the FMP /profile endpoint."""

    async def get_data(self, ticker, year=None, date=None):
        # year/date accepted for interface parity with the other modules; unused here.
        endpoints = FMPEndpoints()
        try:
            data = await endpoints.get_profile(ticker)
            if not data:
                return f"Error: No company profile data available for {ticker}."
            profile = data[0]
            # `or "N/A"` also covers explicit nulls, which .get() would pass through.
            ceo = profile.get("ceo") or "N/A"
            sector = profile.get("sector") or "N/A"
            return f"{ticker}'s CEO is {ceo} and it operates in the {sector} sector."
        except Exception as e:
            return f"Error fetching company profile: {e}"
modules/get_cost_info.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
# modules/cost_info.py
from api.endpoints import FMPEndpoints

class GetCostInfo:
    """Placeholder handler for cost-of-goods / operating-cost queries."""

    async def get_data(self, ticker):
        endpoints = FMPEndpoints()
        try:
            # Not implemented yet: returning None signals the caller to fall
            # back to the CSV retriever / web search.
            return None
        except Exception:
            return None
modules/get_divident_info.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# modules/get_dividend_info.py
# NOTE(review): this file is shipped as modules/get_divident_info.py (typo),
# while main.py imports "modules.get_dividend_info" — confirm/rename on disk.
from api.endpoints import FMPEndpoints

class GetDividendInfo:
    """Answers dividend queries using the payout ratio from FMP /ratios."""

    async def get_data(self, ticker, year=None, date=None):
        endpoints = FMPEndpoints()
        try:
            data = await endpoints.get_ratios(ticker, year=year)
            if not data:
                return f"Error: No dividend info available for {ticker}."
            # `or 0` guards against an explicit null payoutRatio, which would
            # otherwise raise TypeError on the multiplication.
            value = (data[0].get("payoutRatio") or 0) * 100
            return f"{ticker}'s dividend payout ratio for {year or 'the latest year'} is {value:.2f}%."
        except Exception as e:
            return f"Error fetching dividend info: {e}"
modules/get_earnings_per_share.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# modules/get_earnings_per_share.py
from api.endpoints import FMPEndpoints

class GetEarningsPerShare:
    """Answers EPS queries via the FMP /key-metrics endpoint."""

    async def get_data(self, ticker, year=None, date=None):
        endpoints = FMPEndpoints()
        try:
            data = await endpoints.get_key_metrics(ticker, year=year)
            if not data:
                return f"Error: No earnings per share data available for {ticker}."
            # key-metrics reports per-share earnings as "netIncomePerShare";
            # "eps" is tried first for backward compatibility with the old code.
            # `or 0` guards explicit nulls, which would break the :.2f format.
            value = data[0].get("eps", data[0].get("netIncomePerShare", 0)) or 0
            return f"{ticker}'s earnings per share for {year or 'the latest year'} is ${value:.2f}."
        except Exception as e:
            return f"Error fetching earnings per share: {e}"
modules/get_financial_ratios.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# modules/get_financial_ratios.py
from api.endpoints import FMPEndpoints

class GetFinancialRatios:
    """Answers financial-ratio queries (current ratio) via the FMP /ratios endpoint."""

    async def get_data(self, ticker, year=None, date=None):
        endpoints = FMPEndpoints()
        try:
            data = await endpoints.get_ratios(ticker, year=year)
            if not data:
                return f"Error: No financial ratios data available for {ticker}."
            # `or 0` guards against an explicit null currentRatio,
            # which would otherwise break the :.2f format below.
            current_ratio = data[0].get("currentRatio") or 0
            return f"{ticker}'s current ratio for {year or 'the latest year'} is {current_ratio:.2f}."
        except Exception as e:
            return f"Error fetching financial ratios: {e}"
modules/get_historical_stock_price.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# modules/get_historical_stock_price.py
from api.endpoints import FMPEndpoints

class GetHistoricalStockPrice:
    """Answers "price on <date>" queries via FMP historical-price-full data."""

    async def get_data(self, ticker, year=None, date=None):
        endpoints = FMPEndpoints()
        try:
            data = await endpoints.get_historical_price(ticker, date=date)
            # The endpoint may return an empty payload when the date is out of
            # range, so check both the payload and its "historical" key.
            if not data or not data.get("historical"):
                return f"Error: No historical stock price data available for {ticker} on {date}."
            # `or 0` guards an explicit null close value.
            value = data["historical"][0].get("close") or 0
            return f"{ticker}'s stock price on {date} was ${value:.2f}."
        except Exception as e:
            return f"Error fetching historical stock price: {e}"
modules/get_income_tax.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
# modules/get_income_tax.py
from api.endpoints import FMPEndpoints

class GetIncomeTax:
    """Placeholder handler for income-tax queries."""

    async def get_data(self, ticker):
        endpoints = FMPEndpoints()
        try:
            # Not implemented yet: returning None signals the caller to fall
            # back to the CSV retriever / web search.
            return None
        except Exception:
            return None
modules/get_interest.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
# modules/get_interest.py
from api.endpoints import FMPEndpoints

class GetInterest:
    """Placeholder handler for interest expense/income queries."""

    async def get_data(self, ticker):
        endpoints = FMPEndpoints()
        try:
            # Not implemented yet: returning None signals the caller to fall
            # back to the CSV retriever / web search.
            return None
        except Exception:
            return None
modules/get_market_cap.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# modules/get_market_cap.py
from api.endpoints import FMPEndpoints

class GetMarketCap:
    """Answers market-cap queries via the FMP /profile endpoint."""

    async def get_data(self, ticker, year=None, date=None):
        # year/date accepted for interface parity with the other modules; unused here.
        endpoints = FMPEndpoints()
        try:
            data = await endpoints.get_profile(ticker)
            if not data:
                return f"Error: No market cap data available for {ticker}."
            # `or 0` guards an explicit null mktCap, which would break the division.
            value = (data[0].get("mktCap") or 0) / 1_000_000_000
            return f"{ticker}'s market cap is ${value:.2f} billion."
        except Exception as e:
            return f"Error fetching market cap: {e}"
modules/get_net_income.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# modules/get_net_income.py
from api.endpoints import FMPEndpoints

class GetNetIncome:
    """Answers net-income queries via the FMP /income-statement endpoint."""

    async def get_data(self, ticker, year=None, date=None):
        endpoints = FMPEndpoints()
        try:
            data = await endpoints.get_income_statement(ticker, year=year)
            if not data:
                return f"Error: No net income data available for {ticker}."
            # `or 0` guards an explicit null netIncome, which would break the division.
            value = data[0].get("netIncome") or 0
            return f"{ticker}'s net income for {year or 'the latest year'} is ${value / 1_000_000_000:.2f} billion."
        except Exception as e:
            return f"Error fetching net income: {e}"
modules/get_profit_margin.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# modules/get_profit_margin.py
from api.endpoints import FMPEndpoints

class GetProfitMargin:
    """Answers profit-margin queries via the FMP /ratios endpoint."""

    async def get_data(self, ticker, year=None, date=None):
        endpoints = FMPEndpoints()
        try:
            data = await endpoints.get_ratios(ticker, year=year)
            if not data:
                return f"Error: No profit margin data available for {ticker}."
            # `or 0` guards an explicit null netProfitMargin, which would
            # otherwise raise TypeError on the multiplication.
            value = (data[0].get("netProfitMargin") or 0) * 100
            return f"{ticker}'s profit margin for {year or 'the latest year'} is {value:.2f}%."
        except Exception as e:
            return f"Error fetching profit margin: {e}"
modules/get_research_info.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
# modules/get_research_info.py
from api.endpoints import FMPEndpoints


class GetResearchInfo:
    """Placeholder module for research & development (R&D) information."""

    async def get_data(self, ticker):
        """Return R&D info for *ticker*.

        Not implemented yet: always returns None, matching the sibling
        modules' interface so the dispatcher can call it uniformly.
        """
        client = FMPEndpoints()  # instantiated for parity with sibling modules
        try:
            # R&D retrieval logic has not been written yet.
            return None
        except Exception:
            return None
modules/get_revenue.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# modules/get_revenue.py
from api.endpoints import FMPEndpoints


class GetRevenue:
    """Fetch a company's revenue from the FMP income-statement endpoint."""

    async def get_data(self, ticker, year=None, date=None):
        """Return a human-readable revenue string for *ticker*.

        `date` is accepted for interface parity with the sibling metric
        modules but is not used here.
        """
        client = FMPEndpoints()
        try:
            statements = await client.get_income_statement(ticker, year=year)
            if not statements:
                return f"Error: No revenue data available for {ticker}."
            # revenue is reported in dollars; convert to billions for display.
            revenue = statements[0].get("revenue", 0)
            return f"{ticker}'s revenue for {year or 'the latest year'} is ${revenue / 1_000_000_000:.2f} billion."
        except Exception as exc:
            return f"Error fetching revenue: {exc}"
modules/get_stock_price.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# modules/get_stock_price.py
from api.endpoints import FMPEndpoints


class GetStockPrice:
    """Fetch a company's latest stock price from the FMP short-quote endpoint."""

    async def get_data(self, ticker, year=None, date=None):
        """Return a human-readable current-price string for *ticker*.

        `year` and `date` are accepted for interface parity with the sibling
        metric modules but are not used by the real-time quote endpoint.
        """
        client = FMPEndpoints()
        try:
            quote = await client.get_quote_short(ticker)
            if not quote:
                return f"Error: No stock price data available for {ticker}."
            price = quote[0].get("price", 0)
            return f"{ticker}'s current stock price is ${price:.2f}."
        except Exception as exc:
            return f"Error fetching stock price: {exc}"
rag/__pycache__/embedder.cpython-310.pyc ADDED
Binary file (968 Bytes). View file
 
rag/__pycache__/retriever.cpython-310.pyc ADDED
Binary file (4.61 kB). View file
 
rag/__pycache__/sql_db.cpython-310.pyc ADDED
Binary file (5.51 kB). View file
 
rag/__pycache__/web_search.cpython-310.pyc ADDED
Binary file (588 Bytes). View file
 
rag/embedder.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# rag/embedder.py
from sentence_transformers import SentenceTransformer
import numpy as np


class Embedder:
    """Thin wrapper around a SentenceTransformer model for text embedding."""

    def __init__(self, model_name="all-MiniLM-L6-v2"):  # alt: "all-mpnet-base-v2"
        self.model = SentenceTransformer(model_name)

    def embed(self, texts):
        """Embed one string or a list of strings into vectors.

        Args:
            texts: A single string or a list of strings.

        Returns:
            numpy.ndarray: One embedding row per input text.
        """
        batch = [texts] if isinstance(texts, str) else texts
        return self.model.encode(batch, convert_to_numpy=True)
rag/graphrag.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
from neo4j import GraphDatabase
from llama_index.core import (
    SimpleDirectoryReader,
    VectorStoreIndex,
    StorageContext,
    load_index_from_storage,
    Settings,
)
from llama_index.llms.ollama import Ollama
from llama_index.vector_stores.neo4jvector import Neo4jVectorStore
from llama_index.embeddings.huggingface import HuggingFaceEmbedding


class GraphRAGRetriever:
    """RAG retriever that stores document embeddings in a Neo4j vector store
    and answers queries with a local Ollama-served LLM via LlamaIndex.
    """

    def __init__(self, neo4j_url, neo4j_username, neo4j_password):
        # Set up the embedding model (MiniLM sentence embeddings, 384-dim).
        self.embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")

        # Set up the Ollama LLM; request_timeout=None disables the client timeout.
        self.llm = Ollama(model="gemma:2b", request_timeout=None)

        # Register both on the global LlamaIndex Settings so index and query
        # engines pick them up implicitly.
        Settings.llm = self.llm
        Settings.embed_model = self.embed_model

        # Set up the Neo4j driver
        self.driver = GraphDatabase.driver(neo4j_url, auth=(neo4j_username, neo4j_password))

        # Set up the Neo4j vector store
        self.vector_store = Neo4jVectorStore(
            url=neo4j_url,
            username=neo4j_username,
            password=neo4j_password,
            embedding_dimension=384,  # Matches MiniLM model
            driver=self.driver
        )

    def ingest_documents(self, directory_path):
        """Load every file under *directory_path*, index it, and persist.

        NOTE(review): the vector store is passed directly to
        ``from_documents(vector_store=...)``; LlamaIndex conventionally
        expects it wrapped in a ``StorageContext`` — confirm embeddings
        actually land in Neo4j rather than the default in-memory store.
        """
        # Load documents from the specified directory
        documents = SimpleDirectoryReader(directory_path).load_data()

        # Create the vector index
        index = VectorStoreIndex.from_documents(
            documents,
            vector_store=self.vector_store,
        )

        # Persist the index to disk
        index.storage_context.persist()

    def query(self, question):
        """Answer *question* against the previously persisted index.

        Returns:
            str: The LLM-generated response text.
        """
        # Load the index from storage; the "./storage" dir must match where
        # ingest_documents() persisted it — TODO confirm the default persist
        # dir used above is the same path.
        storage_context = StorageContext.from_defaults(persist_dir="./storage")
        index = load_index_from_storage(storage_context)

        # Create a query engine and execute the query
        query_engine = index.as_query_engine()
        response = query_engine.query(question)

        return str(response)


if __name__ == "__main__":
    # Manual smoke test; requires a reachable Neo4j and NEO4J_PASSWORD set.
    retriever = GraphRAGRetriever(
        neo4j_url="bolt://localhost:7687/",
        neo4j_username="neo4j",
        neo4j_password=os.getenv("NEO4J_PASSWORD")
    )
    retriever.ingest_documents("/home/bapary/Music/AI Finance Agent/rag/data")
    answer = retriever.query("What is the revenue of Company Microsoft in 2021?")
    print(answer)
rag/retriever.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import pandas as pd
import faiss
import numpy as np
from .embedder import Embedder
from fuzzywuzzy import fuzz
from langchain_community.llms import Ollama
from langchain.prompts import PromptTemplate

class Retriever:
    """FAISS-backed retriever over a tabular financial dataset.

    Rows are indexed by their "Ticker" column; a query is matched on three
    axes (ticker embedding similarity, column-name embedding similarity,
    fuzzy year-string match) and the best-scoring cell is handed to the
    module-level ``answer_question`` helper for LLM phrasing.
    """

    def __init__(self, file_path):
        # Embedder wraps a MiniLM sentence-transformer model.
        self.embedder = Embedder(model_name="all-MiniLM-L6-v2")
        self.index = None        # FAISS index over ticker embeddings
        self.documents = []      # ticker strings, row-aligned with self.data
        self.data = None         # the loaded DataFrame
        self.embeddings = None   # ticker embedding matrix
        self.load_file(file_path)
        self.build_index()

    def load_file(self, file_path):
        """Load a CSV/Excel file into ``self.data`` and extract tickers.

        On any failure (bad format, missing "Ticker" column, read error)
        the retriever degrades to an empty dataset instead of raising.
        """
        try:
            if file_path.endswith('.csv'):
                self.data = pd.read_csv(file_path)
            elif file_path.endswith('.xlsx') or file_path.endswith('.xls'):
                self.data = pd.read_excel(file_path)
            else:
                raise ValueError("Unsupported file format. Use .csv, .xlsx, or .xls")
            # Requires a "Ticker" column; a KeyError here is caught below.
            self.documents = self.data["Ticker"].astype(str).tolist()
        except Exception as e:
            print(f"Error loading file: {e}")
            self.documents = []
            self.data = pd.DataFrame()

    def build_index(self):
        """Embed every ticker and build a flat L2 FAISS index over them."""
        if not self.documents:
            return
        self.embeddings = self.embedder.embed(self.documents)
        dim = self.embeddings.shape[1]
        self.index = faiss.IndexFlatL2(dim)
        self.index.add(self.embeddings)

    def retrieve(self, query, entities, k=3, threshold=0.7):
        """Find the best (ticker, metric, year) cell for *entities* and
        answer *query* with it.

        Args:
            query (str): The user's natural-language question.
            entities (dict): Must contain 'ticker', 'metric' and 'year'.
            k (int): Number of ticker candidates pulled from FAISS.
            threshold (float): Minimum ticker/metric similarity to keep.

        Returns:
            list | str: ``[]`` when nothing can be searched, an LLM answer
            string on success, or "No relevant data found.".
            NOTE(review): the mixed list/str return type forces callers to
            type-check the result.
        """

        query_prompt = f"{entities['ticker']} {entities['metric']} {entities['year']}"
        # print(query_prompt)
        if not self.index or not self.documents or self.data.empty:
            return []

        # The synthetic prompt must split into exactly ticker/metric/year;
        # a multi-word metric would break this assumption.
        query_parts = query_prompt.split()
        if len(query_parts) != 3:
            print("Query must follow 'ticker metric year' pattern")
            return []

        query_ticker, query_metric, query_year = query_parts

        # Ticker similarity: k nearest neighbours in the FAISS index.
        query_ticker_embedding = self.embedder.embed([query_ticker])
        distances, indices = self.index.search(query_ticker_embedding, k)
        ticker_matches = []
        for i, idx in enumerate(indices[0]):
            if idx < len(self.documents):
                ticker = self.data.iloc[idx]["Ticker"]
                # Maps squared-L2 distance into a similarity score, assuming
                # distances fall in [0, 2] — TODO confirm the embeddings are
                # normalized; MiniLM's encode() does not normalize by default.
                similarity_score = 1 - distances[0][i] / 2
                ticker_matches.append((ticker, similarity_score, idx))

        # Metric similarity: cosine similarity between the query metric and
        # every non-key column name.
        metric_embeddings = self.embedder.embed(self.data.columns.tolist())
        query_metric_embedding = self.embedder.embed([query_metric])[0]
        metric_scores = []
        for col, col_embedding in zip(self.data.columns, metric_embeddings):
            if col.lower() in ["ticker", "year"]:
                continue
            cos_sim = np.dot(query_metric_embedding, col_embedding) / (
                np.linalg.norm(query_metric_embedding) * np.linalg.norm(col_embedding)
            )
            metric_scores.append((col, cos_sim))

        # Year similarity: fuzzy string ratio against each distinct year.
        if "Year" not in self.data.columns:
            print("No 'Year' column found in data")
            return []
        year_scores = []
        for year in self.data["Year"].astype(str).unique():
            similarity = fuzz.ratio(query_year, year) / 100.0
            year_scores.append((year, similarity))

        # Combine matches: every surviving (ticker, metric, year) triple is
        # looked up in the DataFrame and scored by the mean of its parts.
        retrieved_data = []
        seen = set()
        for ticker, ticker_score, idx in ticker_matches:
            if ticker_score < threshold:
                continue
            for metric, metric_score in metric_scores:
                if metric_score < threshold:
                    continue
                for year, year_score in year_scores:
                    if year_score < 0.5:  # looser cutoff for years than tickers/metrics
                        continue
                    combined_score = (ticker_score + metric_score + year_score) / 3
                    match = self.data[
                        (self.data["Ticker"].str.lower() == ticker.lower()) &
                        (self.data["Year"].astype(str) == year) &
                        (self.data[metric].notnull())
                    ]
                    if not match.empty:
                        value = match[metric].iloc[0]
                        key = (ticker, metric, year)
                        if key not in seen:
                            seen.add(key)
                            retrieved_data.append({
                                "ticker": ticker,
                                "metric": metric,
                                "value": value,
                                "year": year,
                                "combined_score": combined_score
                            })

        if retrieved_data:
            # print(retrieved_data)
            # Only the single best match is phrased by the module-level
            # answer_question helper defined later in this file.
            retrieved_data.sort(key=lambda x: x["combined_score"], reverse=True)
            best_match = retrieved_data[0]
            answer = answer_question(query, best_match)
            return answer

        return "No relevant data found."
125
+
126
def answer_question(question, retrieved_data):
    """
    Use a lightweight LLM to generate a natural-language answer on CPU.

    Args:
        question (str): The question to answer.
        retrieved_data (dict): A single best match with keys 'ticker',
            'metric', 'value', 'year' (as produced by Retriever.retrieve).

    Returns:
        str: Natural-language answer, or a fallback/error message.
    """
    try:
        # Nothing retrieved — no point invoking the model.
        if not retrieved_data:
            return "No relevant data found."

        # Lightweight CPU-only model; num_gpu=0 explicitly disables GPU.
        llm = Ollama(model="gemma:2b", num_gpu=0)

        # Minimal prompt to keep CPU inference fast.
        prompt_template = PromptTemplate(
            input_variables=["question", "ticker", "metric", "value", "year"],
            template=(
                "Question: {question}\n"
                "Data: Ticker={ticker}, Metric={metric}, Value={value}, Year={year}\n"
                "Answer concisely, formatting the value with commas."
            )
        )

        # BUG FIX: previously the entire match dict was interpolated as the
        # value; pass only the numeric value itself.
        prompt = prompt_template.format(
            question=question,
            ticker=retrieved_data['ticker'],
            metric=retrieved_data['metric'],
            value=retrieved_data['value'],
            year=retrieved_data['year']
        )

        # Generate response
        response = llm.invoke(prompt)
        return response.strip()

    except Exception as e:
        print(f"Error generating answer: {e}")
        return "Unable to generate answer."
173
+
174
+ # def main(file_path, query, question):
175
+ # """
176
+ # Main function to process a query, retrieve results, and answer a question.
177
+
178
+ # Args:
179
+ # file_path (str): Path to the CSV or Excel file
180
+ # query (str): Query string in 'ticker metric year' format
181
+ # question (str): Natural-language question to answer
182
+
183
+ # Returns:
184
+ # tuple: (retrieved data, answer)
185
+ # """
186
+ # try:
187
+ # retriever = Retriever(file_path)
188
+ # results = retriever.retrieve(query)
189
+ # answer = answer_question(question, results)
190
+ # return results, answer
191
+ # except Exception as e:
192
+ # print(f"Error processing query: {e}")
193
+ # return [], "Unable to process query."
194
+
195
+ # if __name__ == "__main__":
196
+ # file_path = "./financial_data.csv"
197
+ # query = "AAPL InterestExpense 2024"
198
+ # question = "What is the InterestExpense of AAPL 2024?"
199
+ # results, answer = main(file_path, query, question)
200
+ # for result in results:
201
+ # print(f"Ticker: {result['ticker']}, Metric: {result['metric']}, Value: {result['value']}, Year: {result['year']}")
202
+ # print(f"Answer: {answer}")
rag/sql_db.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# rag/sql_db.py
import os
import pandas as pd
import faiss
import numpy as np
import sqlite3
from .embedder import Embedder
from datetime import datetime

class SQL_Key_Pair:
    """Loads a financial CSV/Excel file into a SQLite table and answers
    simple (ticker, metric) lookups; also keeps a FAISS index over tickers.
    """

    def __init__(self, file_path="financial data sp500 companies.csv", model_name="all-MiniLM-L6-v2", db_path="financial_data.db"):
        self.embedder = Embedder(model_name)
        self.index = None        # FAISS index over ticker embeddings
        self.documents = []      # ticker strings from the source file
        self.data = None         # the loaded DataFrame
        self.embeddings = None   # ticker embedding matrix
        self.db_conn = sqlite3.connect(db_path)
        self.create_db_table()
        self.load_data(file_path)

    def create_db_table(self):
        """
        Create the custom_financials table in the database if it doesn't exist.
        """
        cursor = self.db_conn.cursor()
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS custom_financials (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                source_file TEXT,
                firm TEXT,
                ticker TEXT,
                date TEXT,
                metric TEXT,
                value REAL,
                last_updated TEXT
            )
        """)
        self.db_conn.commit()

    def load_data(self, file_path):
        """
        Load financial data from a CSV or Excel file and store it in the database.

        NOTE(review): every column of every row — including identifier
        columns like Ticker/firm/date — is inserted as a "metric" row, and
        non-numeric values are silently coerced to 0.0; repeated loads will
        also duplicate rows since there is no uniqueness constraint.
        """
        try:
            if file_path.endswith('.csv'):
                df = pd.read_csv(file_path)
            elif file_path.endswith('.xlsx'):
                df = pd.read_excel(file_path)
            else:
                raise ValueError("Unsupported file format. Use .csv or .xlsx.")

            self.data = df
            # Requires a "Ticker" column; a KeyError here is caught below.
            self.documents = self.data["Ticker"].astype(str).tolist()

            cursor = self.db_conn.cursor()
            for _, row in self.data.iterrows():
                firm = row.get("firm", "")
                ticker = row.get("Ticker", "")
                date = row.get("date", "")
                for column in self.data.columns:
                    if pd.notna(row[column]):
                        try:
                            value = float(row[column])
                        except (ValueError, TypeError):
                            value = 0.0  # non-numeric cells degrade to 0.0
                        cursor.execute("""
                            INSERT INTO custom_financials (source_file, firm, ticker, date, metric, value, last_updated)
                            VALUES (?, ?, ?, ?, ?, ?, ?)
                        """, (os.path.basename(file_path), firm, ticker, str(date), column, value, datetime.now().isoformat()))
            self.db_conn.commit()
            print(f"Loaded {len(self.data)} rows from {file_path} into custom_financials.")
            self.build_index()  # Rebuild FAISS index after loading
        except Exception as e:
            print(f"Error loading data: {e}")
            self.documents = []
            self.data = pd.DataFrame()

    def build_index(self):
        """
        Build a FAISS index from the embedded descriptions.
        """
        if not self.documents:
            return
        self.embeddings = self.embedder.embed(self.documents)
        dim = self.embeddings.shape[1]
        self.index = faiss.IndexFlatL2(dim)
        self.index.add(self.embeddings)

    def keyword_match_search(self, entities):
        """
        Perform strict keyword match based search from CSV.

        Args:
            entities (dict): Expects 'ticker' and 'metric' keys.

        Returns:
            str: A formatted sentence, or a "No relevant data found." /
            "No data loaded." message.
        """
        if self.data is None or self.data.empty:
            return "No data loaded."

        ticker = entities.get("ticker", "")
        metric = entities.get("metric", "")

        if not ticker or not metric:
            return "No relevant data found."

        # Case-insensitive comparison on both axes.
        ticker = ticker.lower()
        metric = metric.lower()

        retrieved_text = ""
        for _, row in self.data.iterrows():
            if str(row.get("Ticker", "")).lower() == ticker:
                for col in self.data.columns:
                    if col.lower() == metric:
                        if pd.isna(row[col]) or row[col] == "":
                            continue
                        # Assumes the cell is numeric dollars — TODO confirm;
                        # a string cell would raise TypeError here.
                        value_in_billions = row[col] / 1_000_000_000
                        retrieved_text = f"Retrieved {metric} for {ticker} is : ${value_in_billions:.2f} billion."
                        break
                break  # only the first matching ticker row is considered

        if not retrieved_text:
            return "No relevant data found."

        return retrieved_text


    def query_csv(self, query, k=3):
        """
        Query the CSV data with a user query.

        NOTE(review): ``self.retrieve`` is not defined on this class, so this
        method raises AttributeError at runtime — it looks copied from
        rag/retriever.py's Retriever; confirm intent before using.
        """
        retrieved_data = self.retrieve(query, k=k)
        if not retrieved_data:
            return "No relevant data found."

        responses = []
        for entry in retrieved_data:
            try:
                value = float(entry["value"])
                value_in_billions = value / 1_000_000_000
                response = f"{entry['ticker']}'s {entry['metric']} for {entry['year']} was ${value_in_billions:.2f} billion."
            except:
                # Fall back to the raw value when it is not numeric.
                response = f"{entry['ticker']}'s {entry['metric']} for {entry['year']} was {entry['value']}."
            responses.append(response)

        return "\n".join(responses)


    def entity_based_query(self, entities):
        # Thin alias over the strict keyword search.
        return self.keyword_match_search(entities)

    def query_db(self, ticker, metric):
        """
        Query the custom_financials table based on ticker and metric, ignoring date and year.
        """
        try:
            cursor = self.db_conn.cursor()
            # Exact (case-sensitive) match on both columns; first row wins.
            query = """
                SELECT value FROM custom_financials
                WHERE ticker = ? AND metric = ?
                LIMIT 1
            """
            params = [ticker, metric]
            cursor.execute(query, params)
            result = cursor.fetchone()
            if result:
                value = result[0]
                value_in_billions = value / 1_000_000_000
                return f"{metric} for {ticker}: ${value_in_billions:.2f} billion."
            return f"No {metric} data found for {ticker}."
        except Exception as e:
            print(f"Error querying database: {e}")
            return f"Error querying database: {str(e)}"

    def __del__(self):
        # NOTE(review): if __init__ failed before db_conn was assigned this
        # raises AttributeError during interpreter teardown — confirm.
        self.db_conn.close()
rag/web_search.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from duckduckgo_search import DDGS


def duckduckgo_web_search(query, max_results=1):
    """Run a DuckDuckGo text search and return simplified result dicts.

    Args:
        query (str): Search query.
        max_results (int): Maximum number of results to return.

    Returns:
        list[dict]: Each with 'title', 'href' and 'snippet' keys.
    """
    hits = []
    with DDGS() as ddgs:
        for item in ddgs.text(query, region='wt-wt', safesearch='Off', max_results=max_results):
            hits.append({
                "title": item["title"],
                "href": item["href"],
                "snippet": item["body"],
            })
    return hits
repo.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiohttp==3.11.16
2
+ banks==2.1.1
3
+ blis==1.2.1
4
+ catalogue==2.0.10
5
+ certifi==2025.1.31
6
+ click==8.1.8
7
+ cloudpathlib==0.21.0
8
+ colorama==0.4.6
9
+ confection==0.1.5
10
+ cymem==2.0.11
11
+ dirtyjson==1.0.8
12
+ distro==1.9.0
13
+ duckduckgo-search
14
+ fastapi
15
+ faiss-cpu==1.10.0
16
+ filetype==1.2.0
17
+ fuzzywuzzy==0.18.0
18
+ griffe==1.7.2
19
+ httpx-sse==0.4.0
20
+ iniconfig==2.1.0
21
+ marisa-trie==1.2.1
22
+ ml-dtypes==0.5.1
23
+ murmurhash==1.0.12
24
+ neo4j==5.28.1
25
+ nest-asyncio==1.6.0
26
+ nltk==3.9.1
27
+ numpy==1.26.4
28
+ packaging==23.2
29
+ platformdirs==4.3.7
30
+ pluggy==1.5.0
31
+ primp==0.14.0
32
+ pyaudio==0.2.14
33
+ pydantic==2.11.2
34
+ pydantic-core==2.33.1
35
+ pydantic-settings==2.8.1
36
+ pypdf==4.3.1
37
+ pytest==8.3.5
38
+ pyyaml==6.0.2
39
+ requests==2.32.3
40
+ rich==14.0.0
41
+ scikit-learn==1.6.1
42
+ sentence-transformers==2.6.1
43
+ shellingham==1.5.4
44
+ six==1.17.0
45
+ smart-open==7.1.0
46
+ spacy==3.8.5
47
+ spacy-legacy==3.0.12
48
+ spacy-loggers==1.0.5
49
+ spacy-lookups-data==1.0.5
50
+ sqlalchemy==2.0.40
51
+ srsly==2.5.1
52
+ srt==3.5.3
53
+ striprtf==0.0.26
54
+ tensorboard==2.19.0
55
+ tensorflow==2.19.0
56
+ tf-keras==2.19.0
57
+ thinc==8.3.4
58
+ threadpoolctl==3.6.0
59
+ tomli==2.2.1
60
+ tqdm==4.67.1
61
+ typer==0.15.2
62
+ typing-inspection==0.4.0
63
+ uvicorn==0.34.0
64
+ vosk==0.3.45
65
+ wasabi==1.1.3
66
+ weasel==0.4.1
67
+ yarl==1.19.0
68
+ langchain
69
+ langchain_community