# Spaces: Running
# File size: 6,471 Bytes
import streamlit as st
import pymupdf
import re
import traceback
import faiss
import numpy as np
import requests
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_groq import ChatGroq
import torch
import os
# Called before any other Streamlit command in this script.
st.set_page_config(page_title="Financial Insights Chatbot", page_icon="π", layout="wide")

# Run the sentence-embedding model on the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# API keys come from the environment; either is None when the variable is unset.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
ALPHA_VANTAGE_API_KEY = os.getenv("ALPHA_VANTAGE_API_KEY")

try:
    llm = ChatGroq(temperature=0, model="llama3-70b-8192", api_key=GROQ_API_KEY)
    st.success("✅ LLM initialized successfully. Using llama3-70b-8192")
except Exception:
    # Bind the name even on failure so it always exists at module level;
    # code that calls llm.invoke() must cope with it being None.
    llm = None
    st.error("β Failed to initialize Groq LLM.")
    traceback.print_exc()

# Shared, module-level resources used by the PDF pipeline.
embedding_model = SentenceTransformer("baconnier/Finance2_embedding_small_en-V1.5", device=device)
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
def fetch_financial_data(company_ticker):
    """Fetch market capitalization and latest annual revenue from Alpha Vantage.

    Two requests are made: the OVERVIEW endpoint (for ``MarketCapitalization``)
    and, only if that succeeds, the INCOME_STATEMENT endpoint (for the first
    entry of ``annualReports``).

    Args:
        company_ticker: Ticker symbol such as "AAPL". None, empty, or
            whitespace-only input is rejected; surrounding whitespace is ignored.

    Returns:
        ``"Market Cap: $<value>\\nTotal Revenue: $<value>"`` with "N/A" for any
        field the API did not return, or a human-readable error message.
        HTTP/parsing failures are printed and reported through the returned
        string rather than raised.
    """
    ticker = str(company_ticker or "").strip()
    if not ticker:
        return "No ticker symbol provided. Please enter a valid company ticker."

    base_url = "https://www.alphavantage.co/query"
    timeout_s = 30  # Bound each request so a stalled connection cannot hang the app.

    def _query(function):
        # Passing `params` URL-encodes the ticker instead of splicing raw
        # user input into the query string.
        return requests.get(
            base_url,
            params={"function": function, "symbol": ticker, "apikey": ALPHA_VANTAGE_API_KEY},
            timeout=timeout_s,
        )

    try:
        overview_response = _query("OVERVIEW")
        if overview_response.status_code != 200:
            return "Error fetching company overview."
        market_cap = overview_response.json().get("MarketCapitalization", "N/A")

        income_response = _query("INCOME_STATEMENT")
        if income_response.status_code != 200:
            return "Error fetching income statement."
        annual_reports = income_response.json().get("annualReports", [])
        # NOTE(review): index 0 is presumably the most recent fiscal year — confirm against API docs.
        revenue = annual_reports[0].get("totalRevenue", "N/A") if annual_reports else "N/A"

        return f"Market Cap: ${market_cap}\nTotal Revenue: ${revenue}"
    except Exception:
        traceback.print_exc()
        return "Error fetching financial data."
def extract_and_embed_text(pdf_file):
    """Chunk a PDF and build dense (FAISS) and keyword (BM25) indexes over the chunks.

    Text is extracted page by page with pymupdf, split with the module-level
    ``text_splitter``, and embedded with ``embedding_model`` (which runs on
    ``device``, i.e. the GPU when available).

    Args:
        pdf_file: Binary file-like object exposing ``read()`` (e.g. a Streamlit upload).

    Returns:
        ``(docs, embeddings, index, bm25)``:
        ``docs`` is the list of text chunks; ``embeddings`` their L2-normalized
        float32 vectors; ``index`` a FAISS HNSW index over those vectors; ``bm25``
        a BM25Okapi index over the whitespace-tokenized chunks. All four are
        aligned by position. Returns ``([], [], None, None)`` on any error or
        when the PDF yields no extractable text.
    """
    try:
        with pymupdf.open(stream=pdf_file.read(), filetype="pdf") as doc:
            full_text = "\n".join(page.get_text("text") for page in doc)

        # Whitespace-only chunks carry no signal for either index.
        docs = [chunk for chunk in text_splitter.split_text(full_text) if chunk.strip()]
        if not docs:
            # e.g. a scanned PDF without a text layer. An empty corpus would make
            # the embedding-dimension lookup and BM25 construction below fail.
            return [], [], None, None

        embeddings = embedding_model.encode(docs, batch_size=64, convert_to_numpy=True, normalize_embeddings=True)
        # FAISS expects a C-contiguous float32 matrix.
        embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)

        # Second argument is the HNSW graph connectivity (M).
        index = faiss.IndexHNSWFlat(embeddings.shape[1], 32)
        index.add(embeddings)

        bm25 = BM25Okapi([chunk.split() for chunk in docs])
        return docs, embeddings, index, bm25
    except Exception:
        traceback.print_exc()
        return [], [], None, None
def retrieve_relevant_docs(user_query, docs, index, bm25, top_k=3, candidates=8):
    """Hybrid search: fuse FAISS dense hits and BM25 keyword hits into one ranking.

    Each retriever proposes up to ``candidates`` chunk ids. The two ranked lists
    are merged with Reciprocal Rank Fusion (RRF): a chunk scores
    ``sum(1 / (rrf_k + rank))`` over the lists it appears in. Chunks ranked
    highly by either retriever, and especially by both, come first.

    Args:
        user_query: The user's question.
        docs: Text chunks, positionally aligned with the vectors in ``index``
            and the corpus of ``bm25``.
        index: FAISS index over L2-normalized chunk embeddings. On normalized
            vectors, L2-distance order matches cosine-similarity order.
        bm25: BM25Okapi index over whitespace-tokenized chunks.
        top_k: Maximum number of chunks to return (default 3).
        candidates: Hits requested from each retriever before fusion (default 8).

    Returns:
        Up to ``top_k`` chunks from ``docs``, best first.
    """
    if not docs:
        return []
    rrf_k = 60  # Conventional RRF constant (Cormack et al., 2009).
    n_candidates = min(candidates, len(docs))

    query_embedding = embedding_model.encode(user_query, convert_to_numpy=True, normalize_embeddings=True)
    _, faiss_indices = index.search(np.asarray([query_embedding], dtype=np.float32), n_candidates)
    # FAISS pads results with -1 when it finds fewer neighbours than requested;
    # drop those so they are not misread as docs[-1].
    dense_ranking = [int(i) for i in faiss_indices[0] if i >= 0]

    bm25_scores = bm25.get_scores(user_query.split())
    # Highest score first; chunks sharing no term with the query score 0 and are skipped.
    sparse_ranking = [
        int(i) for i in np.argsort(bm25_scores)[::-1][:n_candidates] if bm25_scores[i] > 0
    ]

    # A plain set union (as before) loses all ranking; fuse by rank instead.
    fused_scores = {}
    for ranking in (dense_ranking, sparse_ranking):
        for rank, doc_id in enumerate(ranking, start=1):
            fused_scores[doc_id] = fused_scores.get(doc_id, 0.0) + 1.0 / (rrf_k + rank)

    best_ids = sorted(fused_scores, key=fused_scores.get, reverse=True)[:top_k]
    return [docs[i] for i in best_ids]
def generate_response(user_query, company_ticker, mode, uploaded_file):
    """Build a mode-specific prompt, send it to the LLM, and return the answer text.

    Args:
        user_query: Free-text question from the user; may be empty.
        company_ticker: Ticker symbol used in Live Data Mode (ignored otherwise).
        mode: One of the two mode labels offered by the UI radio button.
        uploaded_file: PDF file-like object used in PDF Upload Mode (ignored otherwise).

    Returns:
        The LLM response text, or a short error message string. Any exception is
        printed (traceback to stderr) and turned into "Error generating response."
    """
    try:
        if mode == "π PDF Upload Mode":
            docs, _embeddings, index, bm25 = extract_and_embed_text(uploaded_file)
            if not docs:
                return "β Error extracting text from PDF."
            context = "\n\n".join(retrieve_relevant_docs(user_query, docs, index, bm25))
            question = (user_query or "").strip()
            if question:
                # The query drives retrieval, so it must also reach the LLM;
                # otherwise every question gets the same generic summary.
                prompt = (
                    "Answer the user's query using these excerpts from a financial document:\n\n"
                    f"{context}\n\nUser Query: {question}"
                )
            else:
                prompt = f"Summarize the key financial insights from this document:\n\n{context}"
        elif mode == "π Live Data Mode":
            financial_info = fetch_financial_data(company_ticker)
            prompt = f"Analyze the financial status of {company_ticker} based on:\n{financial_info}\n\nUser Query: {user_query}"
        else:
            return "Invalid mode selected."
        response = llm.invoke(prompt)
        return response.content
    except Exception:
        traceback.print_exc()
        return "Error generating response."
# ---- Page header ----
st.markdown(
    "<h1 style='text-align: center; color: #4CAF50;'>π AI-Powered Financial Insights Chatbot</h1>",
    unsafe_allow_html=True,
)
st.markdown(
    "<h5 style='text-align: center; color: #666;'>Analyze financial reports or fetch live financial data effortlessly!</h5>",
    unsafe_allow_html=True,
)

# ---- Mode selector (left) and free-text query (right) ----
col1, col2 = st.columns(2)
with col1:
    st.markdown("### π’ **Choose Your Analysis Mode**")
    # Streamlit warns on empty widget labels (accessibility). Give a real label
    # and hide it, since the markdown heading above is the visible label.
    mode = st.radio(
        "Analysis mode",
        ["π PDF Upload Mode", "π Live Data Mode"],
        horizontal=True,
        label_visibility="collapsed",
    )
with col2:
    st.markdown("### π **Enter Your Query**")
    user_query = st.text_input("π¬ What financial insights are you looking for?")

# ---- Mode-specific input; the unused one is set to None for generate_response ----
if mode == "π PDF Upload Mode":
    st.markdown("### π Upload Your Financial Report")
    uploaded_file = st.file_uploader("πΌ Upload PDF (Only for PDF Mode)", type=["pdf"])
    company_ticker = None
else:
    st.markdown("### π Live Market Data")
    company_ticker = st.text_input("π’ Enter Company Ticker Symbol", placeholder="e.g., AAPL, MSFT")
    uploaded_file = None

# ---- Run the analysis on demand ----
if st.button("π Analyze Now"):
    if mode == "π PDF Upload Mode" and not uploaded_file:
        st.error("β Please upload a PDF file.")
    elif mode == "π Live Data Mode" and not company_ticker:
        st.error("β Please enter a valid company ticker symbol.")
    else:
        with st.spinner("π Your Query is Processing, this can take up to 5 - 7 minutes β³"):
            response = generate_response(user_query, company_ticker, mode, uploaded_file)
        st.markdown("---")
        st.markdown("<h3 style='color: #4CAF50;'>π‘ AI Response</h3>", unsafe_allow_html=True)
        st.write(response)
        st.markdown("---")
|