PhoenixDecim's picture
Initial Commit
92ae1b2
raw
history blame
19.7 kB
"""SLM with RAG for financial statements"""
# Importing the dependencies
import functools
import logging
import os
import pickle
import re
import subprocess
import sys
import time

import numpy as np
import pandas as pd
import torch
import spacy
import pdfplumber
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
import faiss
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, CrossEncoder

from data_filters import (
    restricted_patterns,
    restricted_topics,
    FINANCIAL_DATA_PATTERNS,
    sensitive_terms,
    FINANCIAL_TERMS,
)
# Configure root logging for the whole app (INFO and above, timestamped).
# Uncomment `filename` to write to app.log instead of the console.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    # filename="app.log",
)
logger = logging.getLogger()
os.makedirs("data", exist_ok=True)

# SLM: Microsoft Phi-2 is loaded.
# It has higher memory and compute requirements than TinyLlama and Falcon,
# but it gave the best results among the three.
DEVICE = "cpu"  # or "cuda"
# MODEL_NAME = "TinyLlama/TinyLlama_v1.1"
# MODEL_NAME = "tiiuae/falcon-rw-1b"
MODEL_NAME = "microsoft/phi-2"
# MODEL_NAME = "google/gemma-3-1b-pt"

# Load the tokenizer for the selected model
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

# Tokenizers shipped without a length limit report transformers'
# VERY_LARGE_INTEGER sentinel (int(1e30)) as model_max_length. In that case
# fall back to a conservative window so MAX_CONTEXT_TOKENS still caps the
# retrieved context when MODEL_NAME is switched to another model.
FALLBACK_MAX_TOKENS = 2048
MAX_TOKENS = tokenizer.model_max_length
if not MAX_TOKENS or MAX_TOKENS > 1_000_000:
    MAX_TOKENS = FALLBACK_MAX_TOKENS
CONTEXT_MULTIPLIER = 0.7
# MAX_CONTEXT_TOKENS limits the retrieved chunks packed into the prompt during
# querying, leaving headroom for the instructions and the query itself.
MAX_CONTEXT_TOKENS = int(MAX_TOKENS * CONTEXT_MULTIPLIER)
# Reuse EOS as the pad token when the tokenizer defines none
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# The model is hosted on a CPU instance, so float32 is used.
# On a GPU, float16 or bfloat16 can be used instead.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, torch_dtype=torch.float32, trust_remote_code=True
).to(DEVICE)
model.eval()
# Optional dynamic int8 quantization for CPU inference:
# model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
logger.info("Model loaded successfully.")
# Load the Sentence Transformer for embeddings and the Cross Encoder for re-ranking
embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L6-v2")

# Load the spaCy English model for Named Entity Recognition (used by the input
# guardrail), downloading it first if it is not installed.
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    # Use the running interpreter so the model is installed into the same
    # environment that imports it; a bare "python" on PATH may be a different
    # interpreter. check=True reports a failed download here instead of as a
    # second, less clear OSError from spacy.load below.
    subprocess.run(
        [sys.executable, "-m", "spacy", "download", "en_core_web_sm"], check=True
    )
    nlp = spacy.load("en_core_web_sm")
# Extract the report year from the uploaded file's name, if any
def extract_year_from_filename(filename):
    """Extract the report year from a file's name.

    Only the final path component is searched, so digits in upload/temp
    directory names (e.g. hash-named folders) are not mistaken for a year.
    For a range such as ``2022-2023`` the first year is returned; otherwise
    the first four consecutive digits in the name are used.

    Args:
        filename: File name or path (str or path-like).

    Returns:
        The four-digit year as a string, or ``"Unknown"`` if none is found.
    """
    name = os.path.basename(filename)
    match = re.search(r"(\d{4})-(\d{4})", name)
    if match:
        return match.group(1)
    match = re.search(r"(\d{4})", name)
    return match.group(1) if match else "Unknown"
# Extract every table from the uploaded PDF with pdfplumber and tag the rows
# with the report year for context.
def extract_tables_from_pdf(pdf_path):
    """Extract all tables in a PDF into one DataFrame.

    Each table returned by pdfplumber's ``extract_tables`` becomes its own
    DataFrame, gets a ``year`` column derived from the file name, and all of
    them are concatenated with a fresh index.

    Args:
        pdf_path: Path to the PDF file.

    Returns:
        The concatenated DataFrame, or an empty DataFrame if no tables exist.
    """
    year = extract_year_from_filename(pdf_path)
    frames = []
    with pdfplumber.open(pdf_path) as pdf:
        for page in pdf.pages:
            for raw_table in page.extract_tables():
                frame = pd.DataFrame(raw_table)
                frame["year"] = year
                frames.append(frame)
    if not frames:
        return pd.DataFrame()
    return pd.concat(frames, ignore_index=True)
# Load CSV files directly into a DataFrame using pandas
def load_csv(file_path):
    """Load a CSV file into a DataFrame tagged with the report year.

    Args:
        file_path: Path to the CSV file.

    Returns:
        The DataFrame with an added ``year`` column (taken from the file
        name), or ``None`` if the file could not be read or parsed.
    """
    try:
        df = pd.read_csv(file_path)
    except Exception as e:  # pandas raises many IO/parser/decoding error types
        # Report through the app logger (like the rest of the file) rather
        # than print, so failures appear in the configured log output.
        logger.error(f"Error loading CSV: {e}")
        return None
    df["year"] = extract_year_from_filename(file_path)
    return df
# Preprocess the dataframe: treat null values as empty and build one text row
# per record, suitable for chunking.
def clean_dataframe_text(df):
    """Turn each row of extracted PDF/CSV data into one line of text.

    Null cells are treated as empty. Each row becomes a comma-separated string
    of its non-empty cell values, prefixed with ``Year: <year>`` when a
    ``year`` column exists (the year is not repeated among the values). Rows
    without any non-empty value besides the year are dropped.

    Args:
        df: DataFrame of extracted rows. It is not modified.

    Returns:
        A single-column (``text``) DataFrame indexed like the kept rows of ``df``.
    """
    filled = df.fillna("")
    columns = list(filled.columns)
    year_pos = columns.index("year") if "year" in columns else None
    texts = []
    for row in filled.itertuples(index=False, name=None):
        values = [
            str(value).strip() for pos, value in enumerate(row) if pos != year_pos
        ]
        values = [value for value in values if value]
        if not values:
            # Only the year (or nothing at all): no content worth indexing
            texts.append("")
            continue
        parts = [f"Year: {row[year_pos]}"] if year_pos is not None else []
        texts.append(", ".join(parts + values))
    keep = [pos for pos, text in enumerate(texts) if text]
    return pd.DataFrame(
        {"text": [texts[pos] for pos in keep]}, index=filled.index[keep]
    )
# Chunk the text for retrieval.
# Chunk sizes of 256, 512, 1024 and 2048 were tried; 512 worked best for financial RAG.
def chunk_text(text, chunk_size=512):
    """Split text into whitespace-delimited chunks of at most ``chunk_size`` characters.

    Words are never split: a single word longer than ``chunk_size`` becomes a
    chunk of its own. No empty chunks are produced.

    Args:
        text: Text to split.
        chunk_size: Maximum chunk length in characters (joined with spaces).

    Returns:
        List of chunk strings, in order.
    """
    chunks = []
    current = []
    current_len = 0  # length of " ".join(current)
    for word in text.split():
        added = len(word) + 1 if current else len(word)
        if current and current_len + added > chunk_size:
            chunks.append(" ".join(current))
            current, current_len = [word], len(word)
        else:
            # An oversized first word is kept whole rather than emitting an
            # empty chunk before it.
            current.append(word)
            current_len += added
    if current:
        chunks.append(" ".join(current))
    return chunks
# Regex check for financial data, used so such chunks are kept unmerged
def is_financial_text(text):
    """Return True when ``text`` matches FINANCIAL_DATA_PATTERNS, ignoring case.

    merge_similar_chunks uses this to keep chunks that carry financial data
    from being merged with their neighbours.
    """
    match = re.search(FINANCIAL_DATA_PATTERNS, text, flags=re.IGNORECASE)
    return match is not None
# Embeds chunks with the "all-MiniLM-L6-v2" sentence transformer and uses a
# FAISS inner-product index (cosine similarity, since embeddings are
# normalized) to find each chunk's nearest neighbour. Only adjacent, highly
# similar (> threshold) chunks without financial data are merged.
def merge_similar_chunks(chunks, similarity_threshold=0.85):
    """Merge adjacent, highly similar chunks while keeping financial rows intact.

    A chunk is merged with the chunk immediately after it only when that next
    chunk is its nearest neighbour, their cosine similarity exceeds
    ``similarity_threshold``, and neither chunk matches is_financial_text.
    Every input chunk appears in the output, either alone or inside a merged
    pair; exact duplicate texts are collapsed to one entry.

    Args:
        chunks: List of chunk texts in document order.
        similarity_threshold: Minimum cosine similarity required to merge.

    Returns:
        List of chunk texts in document order.
    """
    if not chunks:
        return []
    if len(chunks) == 1:
        # A top-2 search over a single vector only returns padding (-1) as
        # the second neighbour; there is nothing to merge with.
        return list(chunks)
    embeddings = np.array(
        embed_model.encode(chunks, normalize_embeddings=True), dtype="float32"
    )
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)
    # Top-2 neighbours per chunk: usually the chunk itself plus its nearest
    # other chunk, though an identical duplicate may be listed before itself.
    _, indices = index.search(embeddings, 2)

    merged_chunks = {}  # chunk position -> text contributed by that position
    for i, neighbours in enumerate(indices):
        if i in merged_chunks:
            continue  # already merged into the previous chunk
        merged_chunks[i] = chunks[i]
        others = [int(j) for j in neighbours if j >= 0 and j != i]
        if not others:
            continue
        idx = others[0]
        # Only merge forward with the next chunk in document order
        if idx != i + 1:
            continue
        # Ensure financial data isn't merged with other rows
        if is_financial_text(chunks[i]) or is_financial_text(chunks[idx]):
            continue
        if np.dot(embeddings[i], embeddings[idx]) > similarity_threshold:
            merged = chunks[i] + " " + chunks[idx]
            merged_chunks[i] = merged
            merged_chunks[idx] = merged
    # dict.fromkeys removes duplicates (including both entries of a merged
    # pair) while keeping a deterministic document order.
    return list(dict.fromkeys(merged_chunks[i] for i in range(len(chunks))))
# Handler for the "Process Files" button in the UI.
# Processes the uploaded files and generates the embeddings; the FAISS index
# and the tokenized chunks are saved for retrieval.
def process_files(files, chunk_size=512):
    """Process uploaded files and build the retrieval data on disk.

    PDFs are parsed with extract_tables_from_pdf and CSVs with load_csv. The
    rows are turned into text, chunked, and adjacent similar chunks merged.
    The resulting chunks are then:
      * embedded and stored in a FAISS index at ``data/faiss_index.bin``;
      * tokenized (lower-cased, whitespace split) and pickled together with
        the chunk texts at ``data/bm25_data.pkl`` for BM25 retrieval.

    Args:
        files: Uploads from the Gradio ``File`` component, either path
            strings/path-likes or file wrappers exposing the path as ``.name``.
        chunk_size: Maximum chunk length in characters.

    Returns:
        A status message for the UI.
    """
    if not files:
        logger.warning("No files uploaded!")
        return "Please upload at least one PDF or CSV file."
    # gr.File(type="filepath") passes path strings; file-like wrappers carry
    # the path in ``.name``. Extensions are matched case-insensitively.
    paths = [
        os.fspath(f) if isinstance(f, (str, os.PathLike)) else f.name for f in files
    ]
    pdf_paths = [p for p in paths if p.lower().endswith(".pdf")]
    csv_paths = [p for p in paths if p.lower().endswith(".csv")]
    logger.info(f"Processing {len(pdf_paths)} PDFs and {len(csv_paths)} CSVs")
    df_list = [extract_tables_from_pdf(pdf) for pdf in pdf_paths]
    for csv in csv_paths:
        df = load_csv(csv)
        if df is not None:  # load_csv returns None for unreadable files
            df_list.append(df)
    # PDFs without tables and empty CSVs contribute nothing
    df_list = [df for df in df_list if not df.empty]
    if not df_list:
        logger.warning("No valid data found in the uploaded files")
        return "No valid data found in the uploaded files"
    df = pd.concat(df_list, ignore_index=True)
    df.dropna(how="all", inplace=True)
    logger.info("Data extracted from the files")
    df_cleaned = clean_dataframe_text(df)
    df_cleaned["chunks"] = df_cleaned["text"].apply(lambda x: chunk_text(x, chunk_size))
    # explode() turns an empty chunk list into NaN; such rows carry no text
    df_chunks = (
        df_cleaned.explode("chunks").dropna(subset=["chunks"]).reset_index(drop=True)
    )
    chunk_texts = merge_similar_chunks(df_chunks["chunks"].tolist())
    if not chunk_texts:
        logger.warning("No valid data found in the uploaded files")
        return "No valid data found in the uploaded files"
    embeddings = np.array(
        embed_model.encode(chunk_texts, normalize_embeddings=True), dtype="float32"
    )
    # Save the FAISS index (L2 distance on normalized vectors ranks the same
    # way as cosine similarity)
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    faiss.write_index(index, "data/faiss_index.bin")
    logger.info("FAISS index created and saved.")
    # Save the BM25 inputs; the BM25 model itself is rebuilt at query time
    tokenized_chunks = [text.lower().split() for text in chunk_texts]
    bm25_data = {"tokenized_chunks": tokenized_chunks, "chunk_texts": chunk_texts}
    with open("data/bm25_data.pkl", "wb") as f:
        pickle.dump(bm25_data, f)
    logger.info("BM25 index created and saved.")
    return "Files processed successfully! You can now query."
# Input guardrail implementation:
# * regex patterns filter queries about sensitive topics;
# * spaCy Named Entity Recognition flags queries pairing a person with
#   sensitive terms (personal details);
# * cosine similarity between the embedded query and the restricted topic
#   vectors catches paraphrases of confidential/security subjects.
@functools.lru_cache(maxsize=1)
def _restricted_topic_embeddings():
    """Embed ``restricted_topics`` once and reuse the result for every query.

    Caching assumes the topics imported from data_filters are not modified
    at runtime.
    """
    return embed_model.encode(list(restricted_topics), normalize_embeddings=True)


def is_query_allowed(query):
    """Check whether a query violates security or confidentiality rules.

    Args:
        query: User query text.

    Returns:
        ``(True, None)`` if the query is allowed, otherwise ``(False, reason)``.
    """
    for pattern in restricted_patterns:
        if re.search(pattern, query, re.IGNORECASE):
            return False, "This query requests sensitive or confidential information."
    doc = nlp(query)
    for ent in doc.ents:
        if ent.label_ != "PERSON":
            continue
        # Check the entity's own subtree *and* the words it attaches to: in
        # "John's salary" the possessive "John" is typically parsed as a
        # dependent of "salary", so the sensitive word is an ancestor of the
        # entity rather than part of its subtree.
        related_tokens = list(ent.subtree) + list(ent.root.ancestors)
        if any(token.text.lower() in sensitive_terms for token in related_tokens):
            return (
                False,
                "Query contains personal salary information, which is restricted.",
            )
    topic_embeddings = _restricted_topic_embeddings()
    if len(topic_embeddings) == 0:
        # No restricted topics configured: np.max would fail on an empty array
        return True, None
    query_embedding = embed_model.encode(query, normalize_embeddings=True)
    similarities = np.dot(topic_embeddings, query_embedding)
    if np.max(similarities) > 0.85:
        return False, "This query requests sensitive or confidential information."
    return True, None
# Score boost for texts that mention financial terms (used in hybrid retrieval)
def boost_score(text, base_score, boost_factor=1.2):
    """Return ``base_score * boost_factor`` if ``text`` mentions a financial term.

    Each entry of FINANCIAL_TERMS is tested as a substring of the lower-cased
    text; otherwise ``base_score`` is returned unchanged.
    """
    lowered = text.lower()
    mentions_finance = any(term in lowered for term in FINANCIAL_TERMS)
    return base_score * boost_factor if mentions_finance else base_score
# Hybrid retrieval:
# * FAISS embeddings retrieve semantically similar chunks;
# * BM25 retrieves chunks by keyword relevance (TF-IDF style);
# together they cover similar matches and important exact matches.
# Candidates are weighted by lambda_faiss (e.g. 0.6 for FAISS hits, 0.4 for
# BM25 hits) plus a financial-term boost, then re-ranked with the
# ms-marco-MiniLM-L6-v2 cross-encoder.
def hybrid_retrieve(query, chunk_texts, index, bm25, top_k=5, lambda_faiss=0.7):
    """Hybrid retrieval with FAISS, BM25, cross-encoder re-ranking and term boosting.

    Args:
        query: User query text.
        chunk_texts: Chunk texts, aligned with the FAISS index rows and the
            BM25 corpus order.
        index: FAISS index over the chunk embeddings.
        bm25: BM25 model built from the tokenized chunks.
        top_k: Candidates taken from each retriever, and results returned.
        lambda_faiss: Weight of FAISS hits; BM25 hits get ``1 - lambda_faiss``.

    Returns:
        Up to ``top_k`` chunk texts, best cross-encoder score first.
    """
    top_k = int(top_k)  # UI number inputs may arrive as floats
    if top_k <= 0 or not chunk_texts:
        return []
    # FAISS retrieval
    query_embedding = np.array(
        [embed_model.encode(query, normalize_embeddings=True)], dtype="float32"
    )
    _, faiss_indices = index.search(query_embedding, top_k)
    # FAISS pads with -1 when the index holds fewer than top_k vectors; a
    # negative list index would silently select the last chunk.
    faiss_results = [chunk_texts[idx] for idx in faiss_indices[0] if idx >= 0]
    # BM25 retrieval
    bm25_scores = bm25.get_scores(query.lower().split())
    bm25_top_indices = np.argsort(bm25_scores)[::-1][:top_k]
    bm25_results = [chunk_texts[idx] for idx in bm25_top_indices]
    # Merge FAISS & BM25 scores (texts found by both receive both weights)
    results = {}
    for entry in faiss_results:
        results[entry] = boost_score(entry, lambda_faiss)
    for entry in bm25_results:
        results[entry] = results.get(entry, 0) + boost_score(entry, 1 - lambda_faiss)
    # Rank initial results
    retrieved_texts = [
        text
        for text, _ in sorted(results.items(), key=lambda item: item[1], reverse=True)
    ]
    if not retrieved_texts:
        return []
    # Cross-encoder re-ranking.
    # NOTE(review): the final order depends only on the cross-encoder scores;
    # the weighted scores above only decide the candidate order, so
    # lambda_faiss and the boost matter only for ties.
    scores = cross_encoder.predict([[query, text] for text in retrieved_texts])
    ranked_indices = np.argsort(scores)[::-1]
    return [retrieved_texts[i] for i in ranked_indices[:top_k]]
# A confidence score is computed from an embedding similarity and BM25 scores,
# each normalized to [0, 1] and combined with the lambda FAISS weight.
def compute_confidence_score(query, retrieved_chunks, bm25, lambda_faiss):
    """Calculate a 0-1 confidence score from embedding similarity and BM25.

    * Embedding part: cosine similarity between ``query`` (query_model passes
      the query plus the generated answer) and the concatenated retrieved
      chunks, mapped from [-1, 1] to [0, 1].
    * BM25 part: mean BM25 score of the ``len(retrieved_chunks)``
      best-scoring corpus chunks for ``query``, min-max normalized against all
      corpus scores and clipped to [0, 1].

    Returns:
        ``lambda_faiss * embedding_part + (1 - lambda_faiss) * bm25_part``
        rounded to two decimals, or 0 when nothing was retrieved.
    """
    if not retrieved_chunks:
        return 0
    query_embedding = embed_model.encode(query, normalize_embeddings=True)
    context_embedding = embed_model.encode(
        " ".join(retrieved_chunks), normalize_embeddings=True
    )
    # Embedding similarity, mapped from [-1, 1] to [0, 1]
    faiss_score = np.dot(query_embedding, context_embedding)
    normalized_faiss = (faiss_score + 1) / 2
    # BM25 relevance
    bm25_scores = np.asarray(bm25.get_scores(query.lower().split()))
    if bm25_scores.size > 0:
        min_bm25, max_bm25 = bm25_scores.min(), bm25_scores.max()
        if min_bm25 == max_bm25:
            # Degenerate range: use the raw score, clipped below
            min_bm25, max_bm25 = 0, 1
        # The BM25 model cannot map retrieved texts back to corpus rows, so
        # the best-matching rows for this query serve as the proxy (indexing
        # the first N corpus rows would score chunks unrelated to the query).
        top_n = min(len(retrieved_chunks), bm25_scores.size)
        bm25_score = np.mean(np.sort(bm25_scores)[::-1][:top_n])
        normalized_bm25 = (bm25_score - min_bm25) / (max_bm25 - min_bm25)
        normalized_bm25 = max(0, min(1, normalized_bm25))
    else:
        normalized_bm25 = 0
    # Final confidence score: weighted sum using the lambda FAISS value
    confidence_score = round(
        (normalized_faiss * lambda_faiss + normalized_bm25 * (1 - lambda_faiss)), 2
    )
    return confidence_score
# UI handler for the "Submit Query" button: loads the saved FAISS index and
# BM25 data, checks the query against the guardrail, retrieves relevant chunks,
# prompts the SLM with them and computes a confidence score for the answer.
def query_model(
    query,
    top_k=10,
    lambda_faiss=0.5,
    repetition_penalty=1.5,
    max_new_tokens=100,
    use_extraction=False,
):
    """Answer a query with retrieval-augmented generation.

    Args:
        query: User question.
        top_k: Number of chunks to retrieve (cast to int).
        lambda_faiss: FAISS weight for hybrid retrieval and the confidence score.
        repetition_penalty: Passed to ``model.generate``.
        max_new_tokens: Generation budget (cast to int).
        use_extraction: If True, return only the generated continuation;
            otherwise the decoded output, which still contains the prompt.

    Returns:
        ``(response_text, status_text)`` for the two UI text boxes.
    """
    start_time = time.perf_counter()
    # Check that process_files has produced the FAISS and BM25 data
    if not os.path.exists("data/faiss_index.bin") or not os.path.exists(
        "data/bm25_data.pkl"
    ):
        logger.error("No index found! Prompting user to upload PDFs.")
        return (
            "Index files not found! Please upload PDFs first to generate embeddings.",
            "Error",
        )
    if not query or not query.strip():
        return "Please enter a query.", "Warning"
    allowed, reason = is_query_allowed(query)
    if not allowed:
        logger.error(f"Query Rejected: {reason}")
        return f"Query Rejected: {reason}", "Warning"
    # UI number inputs may arrive as floats; FAISS, slicing and generate need ints
    top_k = int(top_k)
    max_new_tokens = int(max_new_tokens)
    logger.info(
        f"Received query: {query} | Top-K: {top_k}, "
        f"Lambda: {lambda_faiss}, Tokens: {max_new_tokens}"
    )
    # Load FAISS & BM25 data. pickle is acceptable only because this file is
    # written locally by process_files; never point it at untrusted input.
    index = faiss.read_index("data/faiss_index.bin")
    with open("data/bm25_data.pkl", "rb") as f:
        bm25_data = pickle.load(f)
    tokenized_chunks = bm25_data["tokenized_chunks"]
    chunk_texts = bm25_data["chunk_texts"]
    bm25 = BM25Okapi(tokenized_chunks)
    retrieved_chunks = hybrid_retrieve(
        query, chunk_texts, index, bm25, top_k=top_k, lambda_faiss=lambda_faiss
    )
    logger.info("Retrieved chunks")
    # Pack retrieved chunks (best first) until the context token budget is hit
    context_parts = []
    token_count = 0
    for chunk in retrieved_chunks:
        chunk_tokens = len(tokenizer(chunk)["input_ids"])
        if token_count + chunk_tokens >= MAX_CONTEXT_TOKENS:
            break
        context_parts.append(chunk)
        token_count += chunk_tokens
    context = "".join(part + "\n" for part in context_parts)
    prompt = (
        f"Based on the following information:\n\n{context}\n\n"
        "Answer the query in one or two sentences. "
        "Do not provide follow-ups. "
        f"Answer the query: {query}"
    )
    inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(DEVICE)
    # token_type_ids is not needed for generation; drop it if present
    inputs.pop("token_type_ids", None)
    logger.info("Generating output")
    input_len = inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            repetition_penalty=repetition_penalty,
            pad_token_id=tokenizer.eos_token_id,
        )
    sequence = output[0]
    execution_time = time.perf_counter() - start_time
    logger.info(f"Query processed in {execution_time:.2f} seconds.")
    # The output holds prompt tokens followed by the generated continuation
    answer = tokenizer.decode(sequence[input_len:], skip_special_tokens=True)
    if use_extraction:
        response = answer
    else:
        response = tokenizer.decode(sequence, skip_special_tokens=True)
    # Score only the generated answer: the full output repeats the prompt,
    # including the retrieved chunks, which would inflate the similarity.
    confidence_score = compute_confidence_score(
        query + " " + answer, retrieved_chunks, bm25, lambda_faiss
    )
    logger.info(f"Confidence: {confidence_score * 100:.1f}%")
    if confidence_score <= 0.3:
        logger.error("The system is unsure about this response.")
        response += "\nThe system is unsure about this response."
    return (
        response,
        f"Confidence: {confidence_score * 100:.1f}%\n"
        f"Time taken: {execution_time:.2f} seconds",
    )
# Gradio UI
with gr.Blocks(title="Financial Statement RAG with LLM") as ui:
    gr.Markdown("## Financial Statement RAG with LLM")
    # File upload section
    with gr.Group():
        gr.Markdown("### Upload & Process Annual Reports")
        file_input = gr.File(
            file_count="multiple",
            file_types=[".pdf", ".csv"],
            type="filepath",
            label="Upload Annual Reports (PDFs/CSVs)",
        )
        process_button = gr.Button("Process Files")
        process_output = gr.Textbox(label="Processing Status", interactive=False)
    # Query model section
    with gr.Group():
        gr.Markdown("### Ask a Financial Query")
        query_input = gr.Textbox(label="Enter Query")
        with gr.Row():
            # precision=0 makes Gradio deliver ints: top_k sizes the FAISS
            # search and slices, max_new_tokens goes to model.generate.
            top_k_input = gr.Number(value=15, precision=0, label="Top K (Default: 15)")
            lambda_faiss_input = gr.Slider(0, 1, value=0.5, label="Lambda FAISS (0-1)")
            repetition_penalty = gr.Slider(
                1, 2, value=1.0, label="Repetition Penalty (1-2)"
            )
            max_tokens_input = gr.Number(value=100, precision=0, label="Max New Tokens")
        use_extraction = gr.Checkbox(label="Retrieve only the answer", value=False)
        query_button = gr.Button("Submit Query")
        query_output = gr.Textbox(label="Query Response", interactive=False)
        time_output = gr.Textbox(label="Time Taken", interactive=False)
    # Button actions
    process_button.click(process_files, inputs=[file_input], outputs=process_output)
    query_button.click(
        query_model,
        inputs=[
            query_input,
            top_k_input,
            lambda_faiss_input,
            repetition_penalty,
            max_tokens_input,
            use_extraction,
        ],
        outputs=[query_output, time_output],
    )
# Entry point: serve the UI on all network interfaces at port 7860
if __name__ == "__main__":
    logger.info("Starting Gradio server...")
    ui.launch(server_port=7860, server_name="0.0.0.0", pwa=True)