Added files (initial commit)
This view is limited to 50 files because the commit contains too many changes.
- .gitignore +1 -0
- Dockerfile +41 -0
- api/__pycache__/endpoints.cpython-310.pyc +0 -0
- api/endpoints.py +151 -0
- app.py +157 -0
- data/financial data sp500 companies.csv +0 -0
- data/financial_data.csv +0 -0
- important.txt +11 -0
- main.py +147 -0
- modules/__pycache__/financial_query.cpython-310.pyc +0 -0
- modules/__pycache__/get_balance_sheet.cpython-310.pyc +0 -0
- modules/__pycache__/get_cash_flow.cpython-310.pyc +0 -0
- modules/__pycache__/get_company_profile.cpython-310.pyc +0 -0
- modules/__pycache__/get_financial_ratios.cpython-310.pyc +0 -0
- modules/__pycache__/get_income_statement.cpython-310.pyc +0 -0
- modules/__pycache__/get_income_tax.cpython-310.pyc +0 -0
- modules/__pycache__/get_interest.cpython-310.pyc +0 -0
- modules/__pycache__/get_market_cap.cpython-310.pyc +0 -0
- modules/__pycache__/get_net_income.cpython-310.pyc +0 -0
- modules/__pycache__/get_profit_margin.cpython-310.pyc +0 -0
- modules/__pycache__/get_research_info.cpython-310.pyc +0 -0
- modules/__pycache__/get_revenue.cpython-310.pyc +0 -0
- modules/__pycache__/get_stock_price.cpython-310.pyc +0 -0
- modules/get_balance_sheet.py +15 -0
- modules/get_cash_flow.py +14 -0
- modules/get_company_profile.py +15 -0
- modules/get_cost_info.py +11 -0
- modules/get_divident_info.py +14 -0
- modules/get_earnings_per_share.py +14 -0
- modules/get_financial_ratios.py +14 -0
- modules/get_historical_stock_price.py +14 -0
- modules/get_income_tax.py +11 -0
- modules/get_interest.py +11 -0
- modules/get_market_cap.py +14 -0
- modules/get_net_income.py +14 -0
- modules/get_profit_margin.py +14 -0
- modules/get_research_info.py +11 -0
- modules/get_revenue.py +14 -0
- modules/get_stock_price.py +14 -0
- rag/__pycache__/embedder.cpython-310.pyc +0 -0
- rag/__pycache__/retriever.cpython-310.pyc +0 -0
- rag/__pycache__/sql_db.cpython-310.pyc +0 -0
- rag/__pycache__/web_search.cpython-310.pyc +0 -0
- rag/embedder.py +20 -0
- rag/graphrag.py +72 -0
- rag/retriever.py +202 -0
- rag/sql_db.py +171 -0
- rag/web_search.py +12 -0
- repo.jpg +0 -0
- requirements.txt +69 -0
.gitignore
ADDED
@@ -0,0 +1 @@
.env
Dockerfile
ADDED
@@ -0,0 +1,41 @@
# Use official Python image
FROM python:3.10-slim

# Install system dependencies
RUN apt-get update && \
    apt-get install -y wget unzip curl gcc portaudio19-dev && \
    rm -rf /var/lib/apt/lists/*

# Set working directory
WORKDIR /app

# Copy your code into the container
COPY . .

# Install Python dependencies
RUN pip install --upgrade pip && pip install -r requirements.txt

# Install spaCy model
# RUN python -m spacy download en_core_web_lg
RUN python -m spacy download en_core_web_sm

# Download and unzip the Vosk model
RUN wget https://alphacephei.com/vosk/models/vosk-model-small-en-us-0.15.zip && \
    unzip vosk-model-small-en-us-0.15.zip && \
    rm vosk-model-small-en-us-0.15.zip

# Install Ollama
# RUN curl -fsSL https://ollama.com/install.sh | sh

# Pull the Ollama model
# RUN ollama serve & sleep 5 && ollama pull gemma:2b

# Expose the port FastAPI will run on
EXPOSE 7860

# Start the API
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]

# CMD uvicorn app:app --host 0.0.0.0 --port $PORT
api/__pycache__/endpoints.cpython-310.pyc
ADDED
Binary file (4.22 kB)
api/endpoints.py
ADDED
@@ -0,0 +1,151 @@
# api/endpoints.py
import httpx
import os

class FMPEndpoints:
    def __init__(self):
        # self.db = FinancialDB()
        self.fmp_api_key = os.getenv("FMP_API_KEY")
        # print(self.fmp_api_key)
        self.base_url = "https://financialmodelingprep.com/api/v3"

    async def get_income_statement(self, ticker, year=None, period="annual", limit=1):
        """
        Fetch income statement data for a given ticker.
        """
        endpoint = f"{self.base_url}/income-statement/{ticker}"
        params = {"apikey": self.fmp_api_key, "period": period, "limit": limit}
        if year:
            params["year"] = year
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(endpoint, params=params)
                response.raise_for_status()
                return response.json()
        except httpx.HTTPStatusError as e:
            raise Exception(f"API error: {e.response.status_code} - {e.response.text}")
        except Exception as e:
            raise Exception(f"Error fetching income statement: {e}")

    async def get_quote_short(self, ticker):
        """
        Fetch the current stock price (short quote) for a given ticker.
        """
        endpoint = f"{self.base_url}/quote-short/{ticker}"
        params = {"apikey": self.fmp_api_key}
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(endpoint, params=params)
                response.raise_for_status()
                return response.json()
        except httpx.HTTPStatusError as e:
            raise Exception(f"API error: {e.response.status_code} - {e.response.text}")
        except Exception as e:
            raise Exception(f"Error fetching quote: {e}")

    async def get_ratios(self, ticker, year=None, limit=1):
        """
        Fetch financial ratios for a given ticker.
        """
        endpoint = f"{self.base_url}/ratios/{ticker}"
        params = {"apikey": self.fmp_api_key, "limit": limit}
        if year:
            params["year"] = year
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(endpoint, params=params)
                response.raise_for_status()
                return response.json()
        except httpx.HTTPStatusError as e:
            raise Exception(f"API error: {e.response.status_code} - {e.response.text}")
        except Exception as e:
            raise Exception(f"Error fetching ratios: {e}")

    async def get_profile(self, ticker):
        """
        Fetch company profile data for a given ticker.
        """
        endpoint = f"{self.base_url}/profile/{ticker}"
        params = {"apikey": self.fmp_api_key}
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(endpoint, params=params)
                response.raise_for_status()
                return response.json()
        except httpx.HTTPStatusError as e:
            raise Exception(f"API error: {e.response.status_code} - {e.response.text}")
        except Exception as e:
            raise Exception(f"Error fetching profile: {e}")

    async def get_historical_price(self, ticker, date=None):
        """
        Fetch historical stock price for a given ticker on a specific date.
        """
        endpoint = f"{self.base_url}/historical-price-full/{ticker}"
        params = {"apikey": self.fmp_api_key}
        if date:
            params["from"] = date
            params["to"] = date
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(endpoint, params=params)
                response.raise_for_status()
                return response.json()
        except httpx.HTTPStatusError as e:
            raise Exception(f"API error: {e.response.status_code} - {e.response.text}")
        except Exception as e:
            raise Exception(f"Error fetching historical price: {e}")

    async def get_balance_sheet(self, ticker, year=None, period="annual", limit=1):
        """
        Fetch balance sheet data for a given ticker.
        """
        endpoint = f"{self.base_url}/balance-sheet-statement/{ticker}"
        params = {"apikey": self.fmp_api_key, "period": period, "limit": limit}
        if year:
            params["year"] = year
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(endpoint, params=params)
                response.raise_for_status()
                return response.json()
        except httpx.HTTPStatusError as e:
            raise Exception(f"API error: {e.response.status_code} - {e.response.text}")
        except Exception as e:
            raise Exception(f"Error fetching balance sheet: {e}")

    async def get_cash_flow(self, ticker, year=None, period="annual", limit=1):
        """
        Fetch cash flow statement data for a given ticker.
        """
        endpoint = f"{self.base_url}/cash-flow-statement/{ticker}"
        params = {"apikey": self.fmp_api_key, "period": period, "limit": limit}
        if year:
            params["year"] = year
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(endpoint, params=params)
                response.raise_for_status()
                return response.json()
        except httpx.HTTPStatusError as e:
            raise Exception(f"API error: {e.response.status_code} - {e.response.text}")
        except Exception as e:
            raise Exception(f"Error fetching cash flow: {e}")

    async def get_key_metrics(self, ticker, year=None, limit=1):
        """
        Fetch key metrics (e.g., EPS) for a given ticker.
        """
        endpoint = f"{self.base_url}/key-metrics/{ticker}"
        params = {"apikey": self.fmp_api_key, "limit": limit}
        if year:
            params["year"] = year
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(endpoint, params=params)
                response.raise_for_status()
                return response.json()
        except httpx.HTTPStatusError as e:
            raise Exception(f"API error: {e.response.status_code} - {e.response.text}")
        except Exception as e:
            raise Exception(f"Error fetching key metrics: {e}")
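For orientation, a minimal driver for these wrappers (a sketch, not part of this commit; it assumes FMP_API_KEY is exported and uses AAPL purely as an example ticker):

import asyncio
from api.endpoints import FMPEndpoints

async def main():
    endpoints = FMPEndpoints()
    quote = await endpoints.get_quote_short("AAPL")              # e.g. [{"symbol": "AAPL", "price": ...}]
    income = await endpoints.get_income_statement("AAPL", limit=1)
    print(quote[0].get("price"), income[0].get("netIncome"))

asyncio.run(main())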
app.py
ADDED
@@ -0,0 +1,157 @@
# app.py
from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from main import process_query
from voice.speech_to_text import SpeechToText
import os
import asyncio
import pyaudio
import wave
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Mount static files (for CSS, JS, etc.)
app.mount("/static", StaticFiles(directory="static"), name="static")

# Set up templates
templates = Jinja2Templates(directory="templates")

# Vosk model path and audio file path
vosk_model_path = "./vosk-model-small-en-us-0.15"
audio_file_path = "voice/temp_audio.wav"

# Ensure the voice directory exists
os.makedirs("voice", exist_ok=True)

# Initialize SpeechToText
stt = SpeechToText(model_path=vosk_model_path)

# Global variables for recording state
recording = False
audio_frames = []
recording_task = None

def save_audio_to_wav(frames, sample_rate=16000):
    """Save audio frames to a WAV file."""
    try:
        logger.info(f"Saving audio to {audio_file_path} with {len(frames)} frames")
        wf = wave.open(audio_file_path, 'wb')
        wf.setnchannels(1)
        wf.setsampwidth(2)  # 16-bit
        wf.setframerate(sample_rate)
        wf.writeframes(b''.join(frames))
        wf.close()
        if os.path.exists(audio_file_path):
            logger.info(f"WAV file saved successfully: {os.path.getsize(audio_file_path)} bytes")
        else:
            logger.error("WAV file was not created")
    except Exception as e:
        logger.error(f"Error saving WAV file: {str(e)}")
        raise

async def record_audio():
    """Background task to record audio."""
    global audio_frames
    p = pyaudio.PyAudio()
    try:
        stream = p.open(format=pyaudio.paInt16, channels=1, rate=16000, input=True, frames_per_buffer=1024)
        stream.start_stream()
        logger.info("Recording started...")

        while recording:
            data = stream.read(1024, exception_on_overflow=False)
            audio_frames.append(data)
            await asyncio.sleep(0.01)  # Small sleep to prevent blocking

        stream.stop_stream()
        stream.close()
        logger.info(f"Recording stopped, captured {len(audio_frames)} frames")
    except Exception as e:
        logger.error(f"Error during recording: {str(e)}")
    finally:
        p.terminate()

@app.get("/", response_class=HTMLResponse)
async def get_index(request: Request):
    return templates.TemplateResponse("index.html", {"request": request})

@app.post("/start_recording", response_class=JSONResponse)
async def start_recording():
    global recording, audio_frames, recording_task
    if not recording:
        recording = True
        audio_frames = []
        recording_task = asyncio.create_task(record_audio())
        logger.info("Started recording task")
        return {"status": "Recording started"}
    logger.warning("Recording already in progress")
    return {"status": "Already recording"}

@app.post("/stop_recording", response_class=HTMLResponse)
async def stop_recording(request: Request):
    global recording, recording_task
    if recording:
        recording = False
        if recording_task:
            await recording_task  # Wait for the recording task to complete
            recording_task = None

        # Save the audio to WAV
        try:
            save_audio_to_wav(audio_frames)
        except Exception as e:
            logger.error(f"Failed to save audio: {str(e)}")
            return templates.TemplateResponse("index.html", {
                "request": request,
                "error": f"Failed to save audio: {str(e)}"
            })

        # Transcribe the saved audio
        try:
            text = stt.transcribe_audio(audio_file_path)
            logger.info(f"Transcription result: '{text}'")
            if not text:
                logger.warning("Transcription returned no text")
                return templates.TemplateResponse("index.html", {
                    "request": request,
                    "error": "Could not understand the audio."
                })
            return templates.TemplateResponse("index.html", {
                "request": request,
                "transcribed_text": text
            })
        except Exception as e:
            logger.error(f"Transcription error: {str(e)}")
            return templates.TemplateResponse("index.html", {
                "request": request,
                "error": f"Transcription error: {str(e)}"
            })
    logger.warning("No recording in progress")
    return templates.TemplateResponse("index.html", {
        "request": request,
        "error": "No recording in progress."
    })

@app.post("/query", response_class=HTMLResponse)
async def handle_query(request: Request, query_text: str = Form(...), use_retriever: str = Form("no")):
    use_retriever = use_retriever.lower() in ["yes", "y"]
    result = await process_query(vosk_model_path, query_text=query_text, use_retriever=use_retriever)

    return templates.TemplateResponse("index.html", {
        "request": request,
        "User_Query": query_text,
        "Intent": result["intent"],
        "Entities": result["entities"],
        "API_Response": result["base_response"],
        "RAG_Response": result["retriever_response"],
        "Web_Search_Response": result["web_search_response"],
        "Final_Response": result["final_response"],
        "Error": result["error"]
    })
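Since /query takes form fields rather than JSON, a client call looks like the following sketch (not part of this commit; assumes the app is running locally on port 7860, and the question is an example):

import httpx

resp = httpx.post(
    "http://localhost:7860/query",
    data={"query_text": "What is the net income of AAPL?", "use_retriever": "yes"},
)
print(resp.status_code)  # the endpoint renders index.html populated with the query results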
data/financial data sp500 companies.csv
ADDED
The diff for this file is too large to render.
data/financial_data.csv
ADDED
The diff for this file is too large to render.
important.txt
ADDED
@@ -0,0 +1,11 @@
1. The full project is inside the codebase folder. To run the project:
   Create a conda env using "conda create -n env_name python=3.10"
   Activate the env using "conda activate env_name"
   Install the requirements using "pip install -r requirements.txt"
   Run the bash scripts using "source env_variable.sh" and "source setup.sh"
   Run "uvicorn app:app --reload"

2. The "Documentation of AI Finance Accountant Agent" has a full overview of the project, explaining key features, required tools, data sources, etc. Some testing examples are attached at the end of the documentation file.

3. The "demo_video.mkv" file is a demo of how the API agent works. Carefully watch the video.
main.py
ADDED
@@ -0,0 +1,147 @@
# main.py
import asyncio
import importlib
from voice.speech_to_text import SpeechToText
from voice.intent_classifier import IntentClassifier
from api.endpoints import FMPEndpoints
from rag.retriever import Retriever
from rag.sql_db import SQL_Key_Pair
from rag.web_search import duckduckgo_web_search

async def process_query(vosk_model_path, audio_data=None, query_text=None, use_retriever=False):
    # Step 1: Initialize components
    stt = SpeechToText(model_path=vosk_model_path)
    classifier = IntentClassifier()
    endpoints = FMPEndpoints()
    # initialize rag tools
    retriever = Retriever(file_path="./data/financial_data.csv")
    sql_db = SQL_Key_Pair(file_path="./data/financial_data.csv")

    # Output format
    output = {
        "User asked": "",
        "intent": "",
        "entities": "",
        "base_response": "",
        "retriever_response": "",
        "web_search_response": "",
        "final_response": "",
        "error": ""
    }

    try:
        # Step 2: Process input (text or audio)
        if audio_data:
            text = stt.transcribe_audio(audio_data)
            if not text:
                output["error"] = "Could not understand the audio."
                return output
        elif query_text:
            text = query_text
        else:
            output["error"] = "No audio or text query provided."
            return output

        output["User asked"] = text

        # Step 3: Classify intent (zero-shot) and extract entities
        intent = classifier.classify_with_llm(text)
        output["intent"] = intent if intent else "Could not classify intent."

        entities = classifier.extract_entities(text)
        output["entities"] = str(entities)

        if intent:
            intent_to_module = {
                "get_net_income": ("modules.get_net_income", "GetNetIncome"),
                "get_revenue": ("modules.get_revenue", "GetRevenue"),
                "get_stock_price": ("modules.get_stock_price", "GetStockPrice"),
                "get_profit_margin": ("modules.get_profit_margin", "GetProfitMargin"),
                "get_company_profile": ("modules.get_company_profile", "GetCompanyProfile"),
                "get_market_cap": ("modules.get_market_cap", "GetMarketCap"),
                "get_historical_stock_price": ("modules.get_historical_stock_price", "GetHistoricalStockPrice"),
                "get_dividend_info": ("modules.get_dividend_info", "GetDividendInfo"),
                "get_balance_sheet": ("modules.get_balance_sheet", "GetBalanceSheet"),
                "get_cash_flow": ("modules.get_cash_flow", "GetCashFlow"),
                "get_financial_ratios": ("modules.get_financial_ratios", "GetFinancialRatios"),
                "get_earnings_per_share": ("modules.get_earnings_per_share", "GetEarningsPerShare"),
                "get_interest": ("modules.get_interest", "GetInterest"),
                "get_income_tax": ("modules.get_income_tax", "GetIncomeTax"),
                "get_cost_info": ("modules.get_cost_info", "GetCostInfo"),
                "get_research_info": ("modules.get_research_info", "GetResearchInfo")
            }

            # Identify module for API calling
            module_info = intent_to_module.get(intent)
            if module_info:
                module_path, class_name = module_info
                try:
                    module = importlib.import_module(module_path)
                    class_instance = getattr(module, class_name)()
                    ticker = entities["ticker"]

                    # Step 4: Get the base response from the module
                    base_response = None
                    try:
                        base_response = await class_instance.get_data(
                            ticker=ticker,
                            year=entities["year"],
                            date=entities["date"],
                        )
                    except Exception as e:
                        base_response = f"Error fetching base response: {e}"

                    # Step 5: Handle the response based on requirements
                    final_response = None
                    if base_response and "Error" not in str(base_response) and "None" not in str(base_response):
                        # Base response succeeded
                        final_response = base_response
                        output["base_response"] = f"{final_response}"

                        # Use retriever if specified (optional)
                        if use_retriever:
                            # retriever_response = retriever.retrieve(text, entities)
                            # retriever_response = sql_db.entity_based_query(entities)
                            retriever_response = sql_db.query_db(entities["ticker"], entities["metric"])
                            final_response = f"{final_response} Additional Info found in the CSV: {retriever_response}"
                            output["retriever_response"] = retriever_response
                    else:
                        # Base response failed, use the retriever
                        output["base_response"] = f"{base_response} Using retriever to query CSV file..."
                        # retriever_response = retriever.retrieve(text, entities)
                        # retriever_response = sql_db.keyword_match_search(entities)
                        retriever_response = sql_db.query_db(entities["ticker"], entities["metric"])
                        output["retriever_response"] = retriever_response

                        if "No relevant data found" in retriever_response:
                            # If both API and rag failed to extract information, search on the web
                            search_results = duckduckgo_web_search(text)
                            if search_results:
                                output["web_search_response"] = search_results[0]['snippet']
                                final_response = search_results[0]['snippet']
                            else:
                                output["web_search_response"] = "No relevant data found on the web."
                                final_response = "No relevant data found on the web."
                        else:
                            final_response = retriever_response

                    output["final_response"] = final_response
                except ImportError as e:
                    output["error"] = f"Module import error: {e}"
                except AttributeError as e:
                    output["error"] = f"Class not found in module: {e}"
                except Exception as e:
                    output["error"] = f"Error processing intent {intent}: {e}"
            else:
                output["error"] = f"Unsupported intent: {intent}"
        else:
            output["error"] = "Could not classify intent."

    except Exception as e:
        output["error"] = f"Unexpected error: {e}"

    # print(output)
    # Return output to the User Interface
    return output
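The pipeline can also be smoke-tested without FastAPI by calling process_query directly, as in this sketch (not part of this commit; assumes the Vosk model directory exists locally, and the question is an example):

import asyncio
from main import process_query

result = asyncio.run(process_query(
    "./vosk-model-small-en-us-0.15",
    query_text="What is the revenue of MSFT in 2021?",
    use_retriever=True,
))
print(result["final_response"] or result["error"])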
modules/__pycache__/financial_query.cpython-310.pyc
ADDED
Binary file (1.15 kB)

modules/__pycache__/get_balance_sheet.cpython-310.pyc
ADDED
Binary file (1 kB)

modules/__pycache__/get_cash_flow.cpython-310.pyc
ADDED
Binary file (924 Bytes)

modules/__pycache__/get_company_profile.cpython-310.pyc
ADDED
Binary file (903 Bytes)

modules/__pycache__/get_financial_ratios.cpython-310.pyc
ADDED
Binary file (917 Bytes)

modules/__pycache__/get_income_statement.cpython-310.pyc
ADDED
Binary file (1.37 kB)

modules/__pycache__/get_income_tax.cpython-310.pyc
ADDED
Binary file (586 Bytes)

modules/__pycache__/get_interest.cpython-310.pyc
ADDED
Binary file (582 Bytes)

modules/__pycache__/get_market_cap.cpython-310.pyc
ADDED
Binary file (847 Bytes)

modules/__pycache__/get_net_income.cpython-310.pyc
ADDED
Binary file (904 Bytes)

modules/__pycache__/get_profit_margin.cpython-310.pyc
ADDED
Binary file (910 Bytes)

modules/__pycache__/get_research_info.cpython-310.pyc
ADDED
Binary file (595 Bytes)

modules/__pycache__/get_revenue.cpython-310.pyc
ADDED
Binary file (900 Bytes)

modules/__pycache__/get_stock_price.cpython-310.pyc
ADDED
Binary file (844 Bytes)
modules/get_balance_sheet.py
ADDED
@@ -0,0 +1,15 @@
# modules/get_balance_sheet.py
from api.endpoints import FMPEndpoints

class GetBalanceSheet:
    async def get_data(self, ticker, year=None, date=None):
        endpoints = FMPEndpoints()
        try:
            data = await endpoints.get_balance_sheet(ticker, year=year)
            if not data:
                return f"Error: No balance sheet data available for {ticker}."
            assets = data[0].get("totalAssets", 0)
            liabilities = data[0].get("totalLiabilities", 0)
            return f"{ticker}'s assets for {year or 'the latest year'} are ${assets / 1_000_000_000:.2f} billion, and liabilities are ${liabilities / 1_000_000_000:.2f} billion."
        except Exception as e:
            return f"Error fetching balance sheet: {e}"
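Each of the modules below exposes the same get_data(ticker, year=None, date=None) coroutine, which is what lets main.py dispatch to them generically via importlib. A standalone sketch (not part of this commit; example ticker and year, and a valid FMP_API_KEY is required):

import asyncio
from modules.get_balance_sheet import GetBalanceSheet

print(asyncio.run(GetBalanceSheet().get_data("AAPL", year=2023)))
# -> "AAPL's assets for 2023 are $X.XX billion, and liabilities are $Y.YY billion."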
modules/get_cash_flow.py
ADDED
@@ -0,0 +1,14 @@
# modules/get_cash_flow.py
from api.endpoints import FMPEndpoints

class GetCashFlow:
    async def get_data(self, ticker, year=None, date=None):
        endpoints = FMPEndpoints()
        try:
            data = await endpoints.get_cash_flow(ticker, year=year)
            if not data:
                return f"Error: No cash flow data available for {ticker}."
            value = data[0].get("cashFlowFromOperatingActivities", 0)
            return f"{ticker}'s cash from operations for {year or 'the latest year'} is ${value / 1_000_000_000:.2f} billion."
        except Exception as e:
            return f"Error fetching cash flow: {e}"
modules/get_company_profile.py
ADDED
@@ -0,0 +1,15 @@
# modules/get_company_profile.py
from api.endpoints import FMPEndpoints

class GetCompanyProfile:
    async def get_data(self, ticker, year=None, date=None):
        endpoints = FMPEndpoints()
        try:
            data = await endpoints.get_profile(ticker)
            if not data:
                return f"Error: No company profile data available for {ticker}."
            ceo = data[0].get("ceo", "N/A")
            sector = data[0].get("sector", "N/A")
            return f"{ticker}'s CEO is {ceo} and it operates in the {sector} sector."
        except Exception as e:
            return f"Error fetching company profile: {e}"
modules/get_cost_info.py
ADDED
@@ -0,0 +1,11 @@
# modules/get_cost_info.py
from api.endpoints import FMPEndpoints

class GetCostInfo:
    # year/date accepted so main.py's generic dispatch can call this like the other modules
    async def get_data(self, ticker, year=None, date=None):
        endpoints = FMPEndpoints()
        try:
            # Placeholder for cost of goods sold / operation cost logic
            return None
        except Exception as e:
            return None
modules/get_divident_info.py
ADDED
@@ -0,0 +1,14 @@
# modules/get_dividend_info.py
from api.endpoints import FMPEndpoints

class GetDividendInfo:
    async def get_data(self, ticker, year=None, date=None):
        endpoints = FMPEndpoints()
        try:
            data = await endpoints.get_ratios(ticker, year=year)
            if not data:
                return f"Error: No dividend info available for {ticker}."
            value = data[0].get("payoutRatio", 0) * 100
            return f"{ticker}'s dividend payout ratio for {year or 'the latest year'} is {value:.2f}%."
        except Exception as e:
            return f"Error fetching dividend info: {e}"
modules/get_earnings_per_share.py
ADDED
@@ -0,0 +1,14 @@
# modules/get_earnings_per_share.py
from api.endpoints import FMPEndpoints

class GetEarningsPerShare:
    async def get_data(self, ticker, year=None, date=None):
        endpoints = FMPEndpoints()
        try:
            data = await endpoints.get_key_metrics(ticker, year=year)
            if not data:
                return f"Error: No earnings per share data available for {ticker}."
            value = data[0].get("eps", 0)
            return f"{ticker}'s earnings per share for {year or 'the latest year'} is ${value:.2f}."
        except Exception as e:
            return f"Error fetching earnings per share: {e}"
modules/get_financial_ratios.py
ADDED
@@ -0,0 +1,14 @@
# modules/get_financial_ratios.py
from api.endpoints import FMPEndpoints

class GetFinancialRatios:
    async def get_data(self, ticker, year=None, date=None):
        endpoints = FMPEndpoints()
        try:
            data = await endpoints.get_ratios(ticker, year=year)
            if not data:
                return f"Error: No financial ratios data available for {ticker}."
            current_ratio = data[0].get("currentRatio", 0)
            return f"{ticker}'s current ratio for {year or 'the latest year'} is {current_ratio:.2f}."
        except Exception as e:
            return f"Error fetching financial ratios: {e}"
modules/get_historical_stock_price.py
ADDED
@@ -0,0 +1,14 @@
# modules/get_historical_stock_price.py
from api.endpoints import FMPEndpoints

class GetHistoricalStockPrice:
    async def get_data(self, ticker, year=None, date=None):
        endpoints = FMPEndpoints()
        try:
            data = await endpoints.get_historical_price(ticker, date=date)
            if not data.get("historical"):
                return f"Error: No historical stock price data available for {ticker} on {date}."
            value = data["historical"][0].get("close", 0)
            return f"{ticker}'s stock price on {date} was ${value:.2f}."
        except Exception as e:
            return f"Error fetching historical stock price: {e}"
modules/get_income_tax.py
ADDED
@@ -0,0 +1,11 @@
# modules/get_income_tax.py
from api.endpoints import FMPEndpoints

class GetIncomeTax:
    # year/date accepted so main.py's generic dispatch can call this like the other modules
    async def get_data(self, ticker, year=None, date=None):
        endpoints = FMPEndpoints()
        try:
            # Placeholder for income tax data logic
            return None
        except Exception as e:
            return None
modules/get_interest.py
ADDED
@@ -0,0 +1,11 @@
# modules/get_interest.py
from api.endpoints import FMPEndpoints

class GetInterest:
    # year/date accepted so main.py's generic dispatch can call this like the other modules
    async def get_data(self, ticker, year=None, date=None):
        endpoints = FMPEndpoints()
        try:
            # Placeholder for interest expense/income logic
            return None
        except Exception as e:
            return None
modules/get_market_cap.py
ADDED
@@ -0,0 +1,14 @@
# modules/get_market_cap.py
from api.endpoints import FMPEndpoints

class GetMarketCap:
    async def get_data(self, ticker, year=None, date=None):
        endpoints = FMPEndpoints()
        try:
            data = await endpoints.get_profile(ticker)
            if not data:
                return f"Error: No market cap data available for {ticker}."
            value = data[0].get("mktCap", 0) / 1_000_000_000
            return f"{ticker}'s market cap is ${value:.2f} billion."
        except Exception as e:
            return f"Error fetching market cap: {e}"
modules/get_net_income.py
ADDED
@@ -0,0 +1,14 @@
# modules/get_net_income.py
from api.endpoints import FMPEndpoints

class GetNetIncome:
    async def get_data(self, ticker, year=None, date=None):
        endpoints = FMPEndpoints()
        try:
            data = await endpoints.get_income_statement(ticker, year=year)
            if not data:
                return f"Error: No net income data available for {ticker}."
            value = data[0].get("netIncome", 0)
            return f"{ticker}'s net income for {year or 'the latest year'} is ${value / 1_000_000_000:.2f} billion."
        except Exception as e:
            return f"Error fetching net income: {e}"
modules/get_profit_margin.py
ADDED
@@ -0,0 +1,14 @@
# modules/get_profit_margin.py
from api.endpoints import FMPEndpoints

class GetProfitMargin:
    async def get_data(self, ticker, year=None, date=None):
        endpoints = FMPEndpoints()
        try:
            data = await endpoints.get_ratios(ticker, year=year)
            if not data:
                return f"Error: No profit margin data available for {ticker}."
            value = data[0].get("netProfitMargin", 0) * 100
            return f"{ticker}'s profit margin for {year or 'the latest year'} is {value:.2f}%."
        except Exception as e:
            return f"Error fetching profit margin: {e}"
modules/get_research_info.py
ADDED
@@ -0,0 +1,11 @@
# modules/get_research_info.py
from api.endpoints import FMPEndpoints

class GetResearchInfo:
    # year/date accepted so main.py's generic dispatch can call this like the other modules
    async def get_data(self, ticker, year=None, date=None):
        endpoints = FMPEndpoints()
        try:
            # Placeholder for R&D info logic
            return None
        except Exception as e:
            return None
modules/get_revenue.py
ADDED
@@ -0,0 +1,14 @@
# modules/get_revenue.py
from api.endpoints import FMPEndpoints

class GetRevenue:
    async def get_data(self, ticker, year=None, date=None):
        endpoints = FMPEndpoints()
        try:
            data = await endpoints.get_income_statement(ticker, year=year)
            if not data:
                return f"Error: No revenue data available for {ticker}."
            value = data[0].get("revenue", 0)
            return f"{ticker}'s revenue for {year or 'the latest year'} is ${value / 1_000_000_000:.2f} billion."
        except Exception as e:
            return f"Error fetching revenue: {e}"
modules/get_stock_price.py
ADDED
@@ -0,0 +1,14 @@
# modules/get_stock_price.py
from api.endpoints import FMPEndpoints

class GetStockPrice:
    async def get_data(self, ticker, year=None, date=None):
        endpoints = FMPEndpoints()
        try:
            data = await endpoints.get_quote_short(ticker)
            if not data:
                return f"Error: No stock price data available for {ticker}."
            value = data[0].get("price", 0)
            return f"{ticker}'s current stock price is ${value:.2f}."
        except Exception as e:
            return f"Error fetching stock price: {e}"
rag/__pycache__/embedder.cpython-310.pyc
ADDED
Binary file (968 Bytes)

rag/__pycache__/retriever.cpython-310.pyc
ADDED
Binary file (4.61 kB)

rag/__pycache__/sql_db.cpython-310.pyc
ADDED
Binary file (5.51 kB)

rag/__pycache__/web_search.cpython-310.pyc
ADDED
Binary file (588 Bytes)
rag/embedder.py
ADDED
@@ -0,0 +1,20 @@
# rag/embedder.py
from sentence_transformers import SentenceTransformer
import numpy as np

class Embedder:
    def __init__(self, model_name="all-MiniLM-L6-v2"):  # "all-mpnet-base-v2"
        self.model = SentenceTransformer(model_name)

    def embed(self, texts):
        """
        Embed a list of texts into vectors.
        Args:
            texts (list of str): Texts to embed.
        Returns:
            numpy.ndarray: Embeddings.
        """
        if isinstance(texts, str):
            texts = [texts]
        embeddings = self.model.encode(texts, convert_to_numpy=True)
        return embeddings
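A quick sanity check of the embedder (a sketch, not part of this commit): the model downloads on first use, and all-MiniLM-L6-v2 produces 384-dimensional vectors, which is why graphrag.py sets embedding_dimension=384.

from rag.embedder import Embedder

emb = Embedder()
vectors = emb.embed(["AAPL revenue 2023", "MSFT net income"])
print(vectors.shape)  # (2, 384) for all-MiniLM-L6-v2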
rag/graphrag.py
ADDED
@@ -0,0 +1,72 @@
import os
from neo4j import GraphDatabase
from llama_index.core import (
    SimpleDirectoryReader,
    VectorStoreIndex,
    StorageContext,
    load_index_from_storage,
    Settings,
)
from llama_index.llms.ollama import Ollama
from llama_index.vector_stores.neo4jvector import Neo4jVectorStore
from llama_index.embeddings.huggingface import HuggingFaceEmbedding


class GraphRAGRetriever:
    def __init__(self, neo4j_url, neo4j_username, neo4j_password):
        # Set up the embedding model
        self.embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")

        # Set up the Ollama LLM
        self.llm = Ollama(model="gemma:2b", request_timeout=None)

        # Configure Settings
        Settings.llm = self.llm
        Settings.embed_model = self.embed_model

        # Set up the Neo4j driver
        self.driver = GraphDatabase.driver(neo4j_url, auth=(neo4j_username, neo4j_password))

        # Set up the Neo4j vector store
        self.vector_store = Neo4jVectorStore(
            url=neo4j_url,
            username=neo4j_username,
            password=neo4j_password,
            embedding_dimension=384,  # Matches MiniLM model
            driver=self.driver
        )

    def ingest_documents(self, directory_path):
        # Load documents from the specified directory
        documents = SimpleDirectoryReader(directory_path).load_data()

        # Create the vector index
        index = VectorStoreIndex.from_documents(
            documents,
            vector_store=self.vector_store,
        )

        # Persist the index to disk
        index.storage_context.persist()

    def query(self, question):
        # Load the index from storage
        storage_context = StorageContext.from_defaults(persist_dir="./storage")
        index = load_index_from_storage(storage_context)

        # Create a query engine and execute the query
        query_engine = index.as_query_engine()
        response = query_engine.query(question)

        return str(response)


if __name__ == "__main__":
    retriever = GraphRAGRetriever(
        neo4j_url="bolt://localhost:7687/",
        neo4j_username="neo4j",
        neo4j_password=os.getenv("NEO4J_PASSWORD")
    )
    retriever.ingest_documents("/home/bapary/Music/AI Finance Agent/rag/data")
    answer = retriever.query("What is the revenue of Company Microsoft in 2021?")
    print(answer)
rag/retriever.py
ADDED
@@ -0,0 +1,202 @@
import pandas as pd
import faiss
import numpy as np
from .embedder import Embedder
from fuzzywuzzy import fuzz
from langchain_community.llms import Ollama
from langchain.prompts import PromptTemplate

class Retriever:
    def __init__(self, file_path):
        self.embedder = Embedder(model_name="all-MiniLM-L6-v2")
        self.index = None
        self.documents = []
        self.data = None
        self.embeddings = None
        self.load_file(file_path)
        self.build_index()

    def load_file(self, file_path):
        try:
            if file_path.endswith('.csv'):
                self.data = pd.read_csv(file_path)
            elif file_path.endswith('.xlsx') or file_path.endswith('.xls'):
                self.data = pd.read_excel(file_path)
            else:
                raise ValueError("Unsupported file format. Use .csv, .xlsx, or .xls")
            self.documents = self.data["Ticker"].astype(str).tolist()
        except Exception as e:
            print(f"Error loading file: {e}")
            self.documents = []
            self.data = pd.DataFrame()

    def build_index(self):
        if not self.documents:
            return
        self.embeddings = self.embedder.embed(self.documents)
        dim = self.embeddings.shape[1]
        self.index = faiss.IndexFlatL2(dim)
        self.index.add(self.embeddings)

    def retrieve(self, query, entities, k=3, threshold=0.7):
        query_prompt = f"{entities['ticker']} {entities['metric']} {entities['year']}"
        # print(query_prompt)
        if not self.index or not self.documents or self.data.empty:
            return []

        query_parts = query_prompt.split()
        if len(query_parts) != 3:
            print("Query must follow 'ticker metric year' pattern")
            return []

        query_ticker, query_metric, query_year = query_parts

        # Ticker similarity
        query_ticker_embedding = self.embedder.embed([query_ticker])
        distances, indices = self.index.search(query_ticker_embedding, k)
        ticker_matches = []
        for i, idx in enumerate(indices[0]):
            if idx < len(self.documents):
                ticker = self.data.iloc[idx]["Ticker"]
                similarity_score = 1 - distances[0][i] / 2
                ticker_matches.append((ticker, similarity_score, idx))

        # Metric similarity
        metric_embeddings = self.embedder.embed(self.data.columns.tolist())
        query_metric_embedding = self.embedder.embed([query_metric])[0]
        metric_scores = []
        for col, col_embedding in zip(self.data.columns, metric_embeddings):
            if col.lower() in ["ticker", "year"]:
                continue
            cos_sim = np.dot(query_metric_embedding, col_embedding) / (
                np.linalg.norm(query_metric_embedding) * np.linalg.norm(col_embedding)
            )
            metric_scores.append((col, cos_sim))

        # Year similarity
        if "Year" not in self.data.columns:
            print("No 'Year' column found in data")
            return []
        year_scores = []
        for year in self.data["Year"].astype(str).unique():
            similarity = fuzz.ratio(query_year, year) / 100.0
            year_scores.append((year, similarity))

        # Combine matches
        retrieved_data = []
        seen = set()
        for ticker, ticker_score, idx in ticker_matches:
            if ticker_score < threshold:
                continue
            for metric, metric_score in metric_scores:
                if metric_score < threshold:
                    continue
                for year, year_score in year_scores:
                    if year_score < 0.5:
                        continue
                    combined_score = (ticker_score + metric_score + year_score) / 3
                    match = self.data[
                        (self.data["Ticker"].str.lower() == ticker.lower()) &
                        (self.data["Year"].astype(str) == year) &
                        (self.data[metric].notnull())
                    ]
                    if not match.empty:
                        value = match[metric].iloc[0]
                        key = (ticker, metric, year)
                        if key not in seen:
                            seen.add(key)
                            retrieved_data.append({
                                "ticker": ticker,
                                "metric": metric,
                                "value": value,
                                "year": year,
                                "combined_score": combined_score
                            })

        if retrieved_data:
            # print(retrieved_data)
            retrieved_data.sort(key=lambda x: x["combined_score"], reverse=True)
            best_match = retrieved_data[0]
            answer = answer_question(query, best_match)
            return answer

        return "No relevant data found."

def answer_question(question, retrieved_data):
    """
    Use a lightweight LLM to generate a natural-language answer on CPU.

    Args:
        question (str): The question to answer
        retrieved_data (dict): Dictionary with ticker, metric, value, year

    Returns:
        str: Natural-language answer
    """
    # print(question)
    # print(retrieved_data)
    try:
        # Initialize lightweight LLM (gemma:2b, CPU-friendly)
        llm = Ollama(model="gemma:2b", num_gpu=0)  # Explicitly disable GPU

        # Minimal prompt for CPU efficiency
        prompt_template = PromptTemplate(
            input_variables=["question", "ticker", "metric", "value", "year"],
            template=(
                "Question: {question}\n"
                "Data: Ticker={ticker}, Metric={metric}, Value={value}, Year={year}\n"
                "Answer concisely, formatting the value with commas."
            )
        )
        # print(prompt_template)

        # Format data
        if not retrieved_data:
            return "No relevant data found."

        prompt = prompt_template.format(
            question=question,
            ticker=retrieved_data['ticker'],
            metric=retrieved_data['metric'],
            value=retrieved_data['value'],
            year=retrieved_data['year']
        )

        # Generate response
        response = llm.invoke(prompt)
        return response.strip()

    except Exception as e:
        print(f"Error generating answer: {e}")
        return "Unable to generate answer."

# def main(file_path, query, question):
#     """
#     Main function to process a query, retrieve results, and answer a question.
#
#     Args:
#         file_path (str): Path to the CSV or Excel file
#         query (str): Query string in 'ticker metric year' format
#         question (str): Natural-language question to answer
#
#     Returns:
#         tuple: (retrieved data, answer)
#     """
#     try:
#         retriever = Retriever(file_path)
#         results = retriever.retrieve(query)
#         answer = answer_question(question, results)
#         return results, answer
#     except Exception as e:
#         print(f"Error processing query: {e}")
#         return [], "Unable to process query."

# if __name__ == "__main__":
#     file_path = "./financial_data.csv"
#     query = "AAPL InterestExpense 2024"
#     question = "What is the InterestExpense of AAPL 2024?"
#     results, answer = main(file_path, query, question)
#     for result in results:
#         print(f"Ticker: {result['ticker']}, Metric: {result['metric']}, Value: {result['value']}, Year: {result['year']}")
#     print(f"Answer: {answer}")
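The ticker score above, similarity_score = 1 - distance / 2, relies on the identity ||a - b||^2 = 2 - 2*cos(a, b) for unit-length vectors: IndexFlatL2 reports squared L2 distances, and all-MiniLM-L6-v2 embeddings come out unit-normalized, so 1 - d/2 recovers cosine similarity. A quick numeric check (not part of this commit):

import numpy as np

a = np.array([1.0, 0.0])
b = np.array([0.6, 0.8])            # both unit length
d2 = np.sum((a - b) ** 2)           # squared L2 distance, as IndexFlatL2 reports
print(1 - d2 / 2, np.dot(a, b))     # both print 0.6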
rag/sql_db.py
ADDED
@@ -0,0 +1,171 @@
# rag/sql_db.py
import os
import pandas as pd
import faiss
import numpy as np
import sqlite3
from .embedder import Embedder
from datetime import datetime

class SQL_Key_Pair:
    def __init__(self, file_path="financial data sp500 companies.csv", model_name="all-MiniLM-L6-v2", db_path="financial_data.db"):
        self.embedder = Embedder(model_name)
        self.index = None
        self.documents = []
        self.data = None
        self.embeddings = None
        self.db_conn = sqlite3.connect(db_path)
        self.create_db_table()
        self.load_data(file_path)

    def create_db_table(self):
        """
        Create the custom_financials table in the database if it doesn't exist.
        """
        cursor = self.db_conn.cursor()
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS custom_financials (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                source_file TEXT,
                firm TEXT,
                ticker TEXT,
                date TEXT,
                metric TEXT,
                value REAL,
                last_updated TEXT
            )
        """)
        self.db_conn.commit()

    def load_data(self, file_path):
        """
        Load financial data from a CSV or Excel file and store it in the database.
        """
        try:
            if file_path.endswith('.csv'):
                df = pd.read_csv(file_path)
            elif file_path.endswith('.xlsx'):
                df = pd.read_excel(file_path)
            else:
                raise ValueError("Unsupported file format. Use .csv or .xlsx.")

            self.data = df
            self.documents = self.data["Ticker"].astype(str).tolist()

            cursor = self.db_conn.cursor()
            for _, row in self.data.iterrows():
                firm = row.get("firm", "")
                ticker = row.get("Ticker", "")
                date = row.get("date", "")
                for column in self.data.columns:
                    if pd.notna(row[column]):
                        try:
                            value = float(row[column])
                        except (ValueError, TypeError):
                            value = 0.0
                        cursor.execute("""
                            INSERT INTO custom_financials (source_file, firm, ticker, date, metric, value, last_updated)
                            VALUES (?, ?, ?, ?, ?, ?, ?)
                        """, (os.path.basename(file_path), firm, ticker, str(date), column, value, datetime.now().isoformat()))
            self.db_conn.commit()
            print(f"Loaded {len(self.data)} rows from {file_path} into custom_financials.")
            self.build_index()  # Rebuild FAISS index after loading
        except Exception as e:
            print(f"Error loading data: {e}")
            self.documents = []
            self.data = pd.DataFrame()

    def build_index(self):
        """
        Build a FAISS index from the embedded descriptions.
        """
        if not self.documents:
            return
        self.embeddings = self.embedder.embed(self.documents)
        dim = self.embeddings.shape[1]
        self.index = faiss.IndexFlatL2(dim)
        self.index.add(self.embeddings)

    def keyword_match_search(self, entities):
        """
        Perform strict keyword-match-based search from the CSV.
        """
        if self.data is None or self.data.empty:
            return "No data loaded."

        ticker = entities.get("ticker", "")
        metric = entities.get("metric", "")

        if not ticker or not metric:
            return "No relevant data found."

        ticker = ticker.lower()
        metric = metric.lower()

        retrieved_text = ""
        for _, row in self.data.iterrows():
            if str(row.get("Ticker", "")).lower() == ticker:
                for col in self.data.columns:
                    if col.lower() == metric:
                        if pd.isna(row[col]) or row[col] == "":
                            continue
                        value_in_billions = row[col] / 1_000_000_000
                        retrieved_text = f"Retrieved {metric} for {ticker} is: ${value_in_billions:.2f} billion."
                        break
                break

        if not retrieved_text:
            return "No relevant data found."

        return retrieved_text

    def query_csv(self, query, k=3):
        """
        Query the CSV data with a user query.
        """
        retrieved_data = self.retrieve(query, k=k)
        if not retrieved_data:
            return "No relevant data found."

        responses = []
        for entry in retrieved_data:
            try:
                value = float(entry["value"])
                value_in_billions = value / 1_000_000_000
                response = f"{entry['ticker']}'s {entry['metric']} for {entry['year']} was ${value_in_billions:.2f} billion."
            except (ValueError, TypeError):
                response = f"{entry['ticker']}'s {entry['metric']} for {entry['year']} was {entry['value']}."
            responses.append(response)

        return "\n".join(responses)

    def entity_based_query(self, entities):
        return self.keyword_match_search(entities)

    def query_db(self, ticker, metric):
        """
        Query the custom_financials table based on ticker and metric, ignoring date and year.
        """
        try:
            cursor = self.db_conn.cursor()
            query = """
                SELECT value FROM custom_financials
                WHERE ticker = ? AND metric = ?
                LIMIT 1
            """
            params = [ticker, metric]
            cursor.execute(query, params)
            result = cursor.fetchone()
            if result:
                value = result[0]
                value_in_billions = value / 1_000_000_000
                return f"{metric} for {ticker}: ${value_in_billions:.2f} billion."
            return f"No {metric} data found for {ticker}."
        except Exception as e:
            print(f"Error querying database: {e}")
            return f"Error querying database: {str(e)}"

    def __del__(self):
        self.db_conn.close()
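A minimal sketch of the SQLite fallback in isolation (not part of this commit; the CSV path is the one main.py uses, and the metric string must match a column name in that file exactly):

from rag.sql_db import SQL_Key_Pair

db = SQL_Key_Pair(file_path="./data/financial_data.csv")
print(db.query_db("AAPL", "Revenue"))
# -> "Revenue for AAPL: $X.XX billion." or "No Revenue data found for AAPL."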
rag/web_search.py
ADDED
@@ -0,0 +1,12 @@
from duckduckgo_search import DDGS

def duckduckgo_web_search(query, max_results=1):
    results = []
    with DDGS() as ddgs:
        for r in ddgs.text(query, region='wt-wt', safesearch='Off', max_results=max_results):
            results.append({
                "title": r["title"],
                "href": r["href"],
                "snippet": r["body"]
            })
    return results
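The search helper returns a list of {title, href, snippet} dicts, and main.py surfaces only the first snippet. For example (a sketch, not part of this commit; results depend on the live index):

from rag.web_search import duckduckgo_web_search

hits = duckduckgo_web_search("Apple net income 2023", max_results=1)
if hits:
    print(hits[0]["title"], hits[0]["href"])
    print(hits[0]["snippet"])  # what process_query uses as web_search_response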
repo.jpg
ADDED
Binary file (image)
requirements.txt
ADDED
@@ -0,0 +1,69 @@
aiohttp==3.11.16
banks==2.1.1
blis==1.2.1
catalogue==2.0.10
certifi==2025.1.31
click==8.1.8
cloudpathlib==0.21.0
colorama==0.4.6
confection==0.1.5
cymem==2.0.11
dirtyjson==1.0.8
distro==1.9.0
duckduckgo-search
fastapi
faiss-cpu==1.10.0
filetype==1.2.0
fuzzywuzzy==0.18.0
griffe==1.7.2
httpx-sse==0.4.0
iniconfig==2.1.0
marisa-trie==1.2.1
ml-dtypes==0.5.1
murmurhash==1.0.12
neo4j==5.28.1
nest-asyncio==1.6.0
nltk==3.9.1
numpy==1.26.4
packaging==23.2
platformdirs==4.3.7
pluggy==1.5.0
primp==0.14.0
pyaudio==0.2.14
pydantic==2.11.2
pydantic-core==2.33.1
pydantic-settings==2.8.1
pypdf==4.3.1
pytest==8.3.5
pyyaml==6.0.2
requests==2.32.3
rich==14.0.0
scikit-learn==1.6.1
sentence-transformers==2.6.1
shellingham==1.5.4
six==1.17.0
smart-open==7.1.0
spacy==3.8.5
spacy-legacy==3.0.12
spacy-loggers==1.0.5
spacy-lookups-data==1.0.5
sqlalchemy==2.0.40
srsly==2.5.1
srt==3.5.3
striprtf==0.0.26
tensorboard==2.19.0
tensorflow==2.19.0
tf-keras==2.19.0
thinc==8.3.4
threadpoolctl==3.6.0
tomli==2.2.1
tqdm==4.67.1
typer==0.15.2
typing-inspection==0.4.0
uvicorn==0.34.0
vosk==0.3.45
wasabi==1.1.3
weasel==0.4.1
yarl==1.19.0
langchain
langchain_community