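"""RAG-based financial chatbot.

Embeds rows from a financial Excel sheet with SBERT, indexes them in FAISS,
and answers user queries by retrieving the closest sentences and passing
them as context to a Qwen instruct model.
"""
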
import torch
import pandas as pd
import faiss
import numpy as np
import re
import os
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

class FinancialChatbot:
    def __init__(self, data_path, model_name="all-MiniLM-L6-v2", qwen_model_name="Qwen/Qwen2-0.5B-Instruct"):
        self.device = "cpu"
        self.data_path = data_path  # Store data path

        # Load SBERT for embeddings (half precision to reduce memory use)
        self.sbert_model = SentenceTransformer(model_name, device=self.device)
        self.sbert_model = self.sbert_model.half()

        # Load Qwen model for text generation
        self.qwen_model = AutoModelForCausalLM.from_pretrained(
            qwen_model_name, torch_dtype=torch.float16, trust_remote_code=True
        ).to(self.device)
        self.qwen_tokenizer = AutoTokenizer.from_pretrained(qwen_model_name, trust_remote_code=True)

        # Load or create FAISS index
        self.load_or_create_index()

    def load_or_create_index(self):
        """Loads the FAISS index and index_map if they exist, otherwise creates new ones."""
        if os.path.exists("financial_faiss.index") and os.path.exists("index_map.txt"):
            try:
                self.faiss_index = faiss.read_index("financial_faiss.index")
                with open("index_map.txt", "r", encoding="utf-8") as f:
                    self.index_map = {i: line.strip() for i, line in enumerate(f)}
                print("FAISS index and index_map loaded successfully.")
            except Exception as e:
                print(f"Error loading FAISS index: {e}. Recreating index...")
                self.create_faiss_index()
        else:
            print("FAISS index or index_map not found. Creating a new one...")
            self.create_faiss_index()

    def create_faiss_index(self):
        """Creates a FAISS index from the provided Excel file."""
        df = pd.read_excel(self.data_path)
        sentences = []
        self.index_map = {}  # Maps FAISS row ids back to source sentences

        # Flatten each (label, year) cell into a natural-language sentence
        for _, row in df.iterrows():
            for col in df.columns[1:]:  # Ignore the first column (assumed to be labels)
                sentence = f"{row[df.columns[0]]} - year {col} is: {row[col]}"
                sentences.append(sentence)
                self.index_map[len(self.index_map)] = sentence  # Store mapping

        # Encode the sentences into embeddings
        embeddings = self.sbert_model.encode(sentences, convert_to_numpy=True)

        # Create FAISS index (FlatL2 for simplicity)
        self.faiss_index = faiss.IndexFlatL2(embeddings.shape[1])
        self.faiss_index.add(embeddings)

        # Save index and index map
        faiss.write_index(self.faiss_index, "financial_faiss.index")
        with open("index_map.txt", "w", encoding="utf-8") as f:
            for sentence in self.index_map.values():
                f.write(sentence + "\n")

    def query_faiss(self, query, top_k=3):
        """Retrieves the top_k closest sentences from the FAISS index."""
        query_embedding = self.sbert_model.encode([query], convert_to_numpy=True)
        distances, indices = self.faiss_index.search(query_embedding, top_k)
        results = [self.index_map[idx] for idx in indices[0] if idx in self.index_map]
        # Map L2 distances to a rough 0-10 confidence: the closest hit scores
        # highest, the farthest of the top_k scores 0. The `or 1` guards
        # against division by zero when all distances are 0.
        confidences = [(1 - (dist / (np.max(distances[0]) or 1))) * 10 for dist in distances[0]]
        return results, confidences
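
    # Worked example of the confidence mapping above (hypothetical distances):
    #   distances   = [0.2, 0.5, 1.0]  ->  max = 1.0
    #   confidences = [(1 - 0.2/1.0) * 10, (1 - 0.5/1.0) * 10, (1 - 1.0/1.0) * 10]
    #               = [8.0, 5.0, 0.0]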

    def moderate_query(self, query):
        """Blocks inappropriate queries containing restricted words."""
        BLOCKED_WORDS = re.compile(
            r"\b(hack|bypass|illegal|exploit|scam|kill|laundering|murder|suicide|self-harm)\b",
            re.IGNORECASE,
        )
        return not bool(BLOCKED_WORDS.search(query))
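
    # Example behavior (hypothetical inputs):
    #   moderate_query("How do I hack a bank account?") -> False (blocked)
    #   moderate_query("What was net profit in 2021?")  -> True  (allowed)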

    # Alternative implementation using Qwen's chat template, kept for reference:
    # def generate_answer(self, context, question):
    #     messages = [
    #         {"role": "system", "content": "You are a financial assistant. Answer only finance-related questions. If the question is not related to finance, reply: 'I'm sorry, but I can only answer financial-related questions.' If the user greets you (e.g., 'Hello', 'Hi', 'Good morning'), respond politely with 'Hello! How can I assist you today?'."},
    #         {"role": "user", "content": f"{question} - related context extracted from the DB: {context}"}
    #     ]
    #     # Use Qwen's chat template
    #     input_text = self.qwen_tokenizer.apply_chat_template(
    #         messages, tokenize=False, add_generation_prompt=True
    #     )
    #     # Tokenize and move input to device
    #     inputs = self.qwen_tokenizer([input_text], return_tensors="pt").to(self.device)
    #     self.qwen_model.config.pad_token_id = self.qwen_tokenizer.eos_token_id
    #     # Generate response
    #     outputs = self.qwen_model.generate(
    #         inputs.input_ids,
    #         max_new_tokens=50,
    #         pad_token_id=self.qwen_tokenizer.eos_token_id,
    #     )
    #     # Extract only the newly generated part
    #     generated_ids = outputs[:, inputs.input_ids.shape[1]:]  # Remove prompt part
    #     response = self.qwen_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    #     return response

    def generate_answer(self, context, question):
        prompt = f"""
You are a financial assistant. If the user greets you (e.g., "Hello," "Hi," "Good morning"), respond politely with 'Hello! How can I assist you today?' without requiring context.
For financial-related questions, answer based on the context provided. If the context lacks information, say "I don't know."

Context: {context}
User Query: {question}
Answer:
"""
        inputs = self.qwen_tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        outputs = self.qwen_model.generate(
            inputs,
            max_new_tokens=100,
            pad_token_id=self.qwen_tokenizer.eos_token_id,
        )
        generated_ids = outputs[:, inputs.shape[1]:]  # Keep only the newly generated tokens
        response = self.qwen_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return response

    def get_answer(self, query):
        """Main entry point: processes a user query and returns (answer, confidence)."""
        # Reject inappropriate queries up front
        if not self.moderate_query(query):
            return "Inappropriate request.", 0.0

        # Retrieve relevant documents and their confidence scores
        retrieved_docs, confidences = self.query_faiss(query)
        if not retrieved_docs:
            return "No relevant information found.", 0.0

        # Combine retrieved documents as context
        context = " ".join(retrieved_docs)
        avg_confidence = round(sum(confidences) / len(confidences), 2)

        # Generate and clean up the model response
        model_response = self.generate_answer(context, query).strip()

        # Normalize "don't know" style answers
        if model_response.lower() in ["i don't know", "no relevant information found"]:
            return "I don't know.", avg_confidence
        if avg_confidence == 0.0:
            return "Not relevant.", avg_confidence
        return model_response, avg_confidence
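

# Minimal usage sketch. Assumption: "financial_data.xlsx" is a placeholder
# path to a sheet with row labels in the first column and years in the rest.
if __name__ == "__main__":
    bot = FinancialChatbot("financial_data.xlsx")
    answer, confidence = bot.get_answer("What was the total revenue in 2022?")
    print(f"Answer: {answer} (confidence: {confidence})")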