svsaurav95's picture
Update app.py
c66a153 verified
raw
history blame contribute delete
6.68 kB
import streamlit as st
import pymupdf
import re
import traceback
import faiss
import numpy as np
import requests
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_groq import ChatGroq
import torch
import os
st.set_page_config(page_title="Financial Insights Chatbot", page_icon="πŸ“Š", layout="wide")
device = "cuda" if torch.cuda.is_available() else "cpu"
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
ALPHA_VANTAGE_API_KEY = os.getenv("ALPHA_VANTAGE_API_KEY")
try:
llm = ChatGroq(temperature=0, model="llama3-70b-8192", api_key=GROQ_API_KEY)
st.success("βœ… LLM initialized successfully. Using llama3-70b-8192")
except Exception as e:
st.error("❌ Failed to initialize Groq LLM.")
traceback.print_exc()
embedding_model = SentenceTransformer("baconnier/Finance2_embedding_small_en-V1.5", device=device)
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
def fetch_financial_data(company_ticker):
if not company_ticker:
return "No ticker symbol provided. Please enter a valid company ticker."
try:
overview_url = f"https://www.alphavantage.co/query?function=OVERVIEW&symbol={company_ticker}&apikey={ALPHA_VANTAGE_API_KEY}"
overview_response = requests.get(overview_url)
if overview_response.status_code == 200:
overview_data = overview_response.json()
market_cap = overview_data.get("MarketCapitalization", "N/A")
else:
return "Error fetching company overview."
income_url = f"https://www.alphavantage.co/query?function=INCOME_STATEMENT&symbol={company_ticker}&apikey={ALPHA_VANTAGE_API_KEY}"
income_response = requests.get(income_url)
if income_response.status_code == 200:
income_data = income_response.json()
annual_reports = income_data.get("annualReports", [])
revenue = annual_reports[0].get("totalRevenue", "N/A") if annual_reports else "N/A"
else:
return "Error fetching income statement."
return f"Market Cap: ${market_cap}\nTotal Revenue: ${revenue}"
except Exception as e:
traceback.print_exc()
return "Error fetching financial data."
def extract_and_embed_text(pdf_file):
"""Processes PDFs and generates embeddings with GPU acceleration using pymupdf."""
try:
docs, tokenized_texts = [], []
with pymupdf.open(stream=pdf_file.read(), filetype="pdf") as doc:
full_text = "\n".join(page.get_text("text") for page in doc)
chunks = text_splitter.split_text(full_text)
for chunk in chunks:
docs.append(chunk)
tokenized_texts.append(chunk.split())
embeddings = embedding_model.encode(docs, batch_size=64, convert_to_numpy=True, normalize_embeddings=True)
embedding_dim = embeddings.shape[1]
index = faiss.IndexHNSWFlat(embedding_dim, 32)
index.add(embeddings)
bm25 = BM25Okapi(tokenized_texts)
return docs, embeddings, index, bm25
except Exception as e:
traceback.print_exc()
return [], [], None, None
def retrieve_relevant_docs(user_query, docs, index, bm25):
"""Hybrid search using FAISS cosine similarity & BM25 keyword retrieval."""
query_embedding = embedding_model.encode(user_query, convert_to_numpy=True, normalize_embeddings=True)
_, faiss_indices = index.search(np.array([query_embedding]), 8)
bm25_scores = bm25.get_scores(user_query.split())
bm25_indices = np.argsort(bm25_scores)[::-1][:8]
combined_indices = list(set(faiss_indices[0]) | set(bm25_indices))
return [docs[i] for i in combined_indices[:3]]
def generate_response(user_query, pdf_ticker, ai_ticker, mode, uploaded_file):
try:
if mode == "πŸ“„ PDF Upload Mode":
docs, embeddings, index, bm25 = extract_and_embed_text(uploaded_file)
if not docs:
return "❌ Error extracting text from PDF."
retrieved_docs = retrieve_relevant_docs(user_query, docs, index, bm25)
context = "\n\n".join(retrieved_docs)
prompt = f"Summarize the key financial insights for {pdf_ticker} from this document:\n\n{context}"
elif mode == "🌍 Live Data Mode":
financial_info = fetch_financial_data(ai_ticker)
prompt = f"Analyze the financial status of {ai_ticker} based on:\n{financial_info}\n\nUser Query: {user_query}"
else:
return "Invalid mode selected."
response = llm.invoke(prompt)
return response.content
except Exception as e:
traceback.print_exc()
return "Error generating response."
st.markdown(
"<h1 style='text-align: center; color: #4CAF50;'>πŸ“„ FinQuery RAG Chatbot</h1>",
unsafe_allow_html=True
)
st.markdown(
"<h5 style='text-align: center; color: #666;'>Analyze financial reports or fetch live financial data effortlessly!</h5>",
unsafe_allow_html=True
)
col1, col2 = st.columns(2)
with col1:
st.markdown("### 🏒 **Choose Your Analysis Mode**")
mode = st.radio("", ["πŸ“„ PDF Upload Mode", "🌍 Live Data Mode"], horizontal=True)
with col2:
st.markdown("### πŸ”Ž **Enter Your Query**")
user_query = st.text_input("πŸ’¬ What financial insights are you looking for?")
st.markdown("---")
if mode == "πŸ“„ PDF Upload Mode":
st.markdown("### πŸ“‚ Upload Your Financial Report")
uploaded_file = st.file_uploader("πŸ”Ό Upload PDF (Only for PDF Mode)", type=["pdf"])
pdf_ticker = st.text_input("🏒 Enter Company Ticker for PDF Insights", placeholder="e.g., INFY, TCS")
ai_ticker = None
else:
st.markdown("### 🌍 Live Market Data")
ai_ticker = st.text_input("🏒 Enter Company Ticker for AI Insights", placeholder="e.g., AAPL, MSFT")
uploaded_file = None
pdf_ticker = None
if st.button("Analyze Now"):
if mode == "πŸ“„ PDF Upload Mode" and (not uploaded_file or not pdf_ticker):
st.error("❌ Please upload a PDF and enter a company ticker for insights.")
elif mode == "🌍 Live Data Mode" and not ai_ticker:
st.error("❌ Please enter a valid company ticker for AI insights.")
else:
with st.spinner("πŸ” Your Query is Processing, this can take up to 5 - 7 minutes ⏳"):
response = generate_response(user_query, pdf_ticker, ai_ticker, mode, uploaded_file)
st.markdown("---")
st.markdown("<h3 style='color: #4CAF50;'>πŸ’‘ AI Response</h3>", unsafe_allow_html=True)
st.write(response)
st.markdown("---")