# NOTE: web-page header captured together with the file (not part of the program):
# anejaprerna's picture | Update rag.py | 04920e9 verified | raw | history blame | 4.87 kB
import faiss
import numpy as np
import pickle
import threading
import time
import torch
import pandas as pd
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer
from rank_bm25 import BM25Okapi
class FinancialChatbot:
    """Question-answering chatbot over a financial report spreadsheet.

    Answers are produced in three steps: a keyword guardrail, a hybrid
    retrieval step (BM25 over the spreadsheet rows plus a FAISS vector
    search), and text generation with a Qwen causal language model that is
    prompted with the retrieved rows as context.
    """

    def __init__(self):
        """Load the dataset, both retrieval indexes and the language model.

        All resources are read from fixed paths relative to the working
        directory; the models are fetched by name through their libraries.
        """
        # Load financial dataset
        self.df = pd.read_excel("Nestle_Financtial_report_till2023.xlsx")

        # Load embedding model
        self.sbert_model = SentenceTransformer("all-MiniLM-L6-v2")

        # Load FAISS index
        self.faiss_index = faiss.read_index("financial_faiss.index")
        # NOTE(security): pickle.load can execute arbitrary code; this file
        # must only ever come from a trusted source.
        with open("index_map.pkl", "rb") as f:
            self.index_map = pickle.load(f)

        # BM25 indexing: every spreadsheet row becomes one document made of
        # its stringified cells, tokenized on whitespace.
        self.documents = [" ".join(row) for row in self.df.astype(str).values]
        self.tokenized_docs = [doc.split() for doc in self.documents]
        self.bm25 = BM25Okapi(self.tokenized_docs)

        # Load Qwen model
        # NOTE(review): trust_remote_code=True runs code shipped with the
        # model repository; confirm it is actually required for this model.
        self.qwen_model_name = "Qwen/Qwen2.5-1.5b"
        self.qwen_model = AutoModelForCausalLM.from_pretrained(
            self.qwen_model_name,
            torch_dtype="auto",
            device_map="auto",
            trust_remote_code=True,
        )
        self.qwen_tokenizer = AutoTokenizer.from_pretrained(
            self.qwen_model_name, trust_remote_code=True
        )

        # Guardrail: blocked words (matched as lowercase substrings)
        self.BLOCKED_WORDS = [
            "hack", "bypass", "illegal", "scam",
            "terrorism", "attack", "suicide", "bomb",
        ]

    def moderate_query(self, query):
        """Return True when the query contains none of the blocked words.

        Matching is a case-insensitive substring test, so a blocked word
        embedded in a longer word also rejects the query.
        """
        lowered = query.lower()
        return not any(word in lowered for word in self.BLOCKED_WORDS)

    def query_faiss(self, query, top_k=5):
        """Retrieve up to ``top_k`` documents using the FAISS index.

        Returns:
            ``(results, confidences)``: the ``index_map`` entries of the hits
            and, in the same order, ``1 / (1 + distance)`` as a plain float.
        """
        query_embedding = self.sbert_model.encode([query], convert_to_numpy=True)
        distances, indices = self.faiss_index.search(query_embedding, top_k)

        results = []
        confidences = []
        for idx, dist in zip(indices[0], distances[0]):
            # FAISS pads the result with label -1 when it finds fewer than
            # top_k neighbours; such slots never refer to a document.
            if idx < 0:
                continue
            if idx in self.index_map:
                results.append(self.index_map[idx])
                # Convert distance to confidence (smaller distance -> closer to 1).
                confidences.append(float(1 / (1 + dist)))
        return results, confidences

    def query_bm25(self, query, top_k=5):
        """Retrieve up to ``top_k`` documents using BM25.

        Only documents with a positive BM25 score are returned, best first.
        Confidences are the scores divided by the best score, so the top hit
        has confidence 1.0.

        Returns:
            ``(results, confidences)``; both lists are empty when there are
            no documents, ``top_k`` is not positive, or no document scores
            above zero for the query.
        """
        if top_k <= 0 or not self.documents:
            return [], []

        scores = np.asarray(self.bm25.get_scores(query.split()), dtype=float)
        if scores.size == 0:
            return [], []

        best = float(scores.max())
        if best <= 0:
            # Nothing matched: dividing by the best score would give NaN
            # confidences and arbitrary rows would be reported as relevant.
            return [], []

        ranked = np.argsort(scores)[-top_k:][::-1]
        hits = [int(i) for i in ranked if scores[i] > 0]
        results = [self.documents[i] for i in hits]
        confidences = [float(scores[i] / best) for i in hits]  # Normalize confidence
        return results, confidences

    def generate_answer(self, context, question, max_new_tokens=100):
        """Generate an answer with the Qwen model.

        Args:
            context: Retrieved text placed before the question in the prompt.
            question: The user question.
            max_new_tokens: Upper bound on the number of generated tokens.
                The budget applies to new tokens only, so a long prompt does
                not use it up.

        Returns:
            The decoded model output (special tokens removed).
        """
        input_text = f"Context: {context}\nQuestion: {question}\nAnswer:"
        # Tokenize to input ids plus attention mask and place them on the
        # model's device (the model is loaded with device_map="auto").
        inputs = self.qwen_tokenizer(input_text, return_tensors="pt").to(
            self.qwen_model.device
        )
        outputs = self.qwen_model.generate(**inputs, max_new_tokens=max_new_tokens)
        return self.qwen_tokenizer.decode(outputs[0], skip_special_tokens=True)

    @staticmethod
    def _extract_answer(generated):
        """Return the text from the last "Answer" marker on, or None.

        None is returned when nothing follows the marker or when what
        follows starts with "--", which is treated as "no answer".
        """
        marker = "Answer"
        last_index = generated.rfind(marker)
        if last_index == -1:
            # No marker at all: fall back to the whole generated text.
            tail = generated.strip()
            body = tail
        else:
            tail = generated[last_index:]
            body = tail[len(marker):]

        body = body.lstrip(": \t\r\n")
        if not body or body.startswith("--"):
            return None
        return tail

    def get_answer(self, query, timeout=200):
        """Answer a query via guardrail, hybrid retrieval and generation.

        Args:
            query: The user's question.
            timeout: Seconds to wait for the worker thread.

        Returns:
            ``(answer_text, confidence)``. The confidence is the highest
            retrieval confidence, 1.0 for greetings, and 0.0 when the query
            is blocked, nothing relevant is found, or the timeout elapses.
        """
        result = ["No relevant information found", 0.0]  # Default response

        def task():
            # Handle greetings
            if query.strip().lower() in ("hi", "hello", "hey"):
                result[:] = ["Hi, how can I help you?", 1.0]
                return

            # Guardrail check
            if not self.moderate_query(query):
                result[:] = [
                    "I'm unable to process your request due to inappropriate language.",
                    0.0,
                ]
                return

            # Multi-step retrieval (BM25 + FAISS)
            bm25_results, bm25_confidences = self.query_bm25(query, top_k=3)
            faiss_results, faiss_confidences = self.query_faiss(query, top_k=3)
            retrieved_docs = bm25_results + faiss_results
            confidences = bm25_confidences + faiss_confidences
            if not retrieved_docs:
                return  # Default response already set

            # Construct context and generate
            context = " ".join(retrieved_docs)
            generated = self.generate_answer(context, query)

            extracted = self._extract_answer(generated)
            if extracted is None:
                result[:] = ["No relevant information found", 0.0]
                return

            # Confidence calculation
            final_confidence = float(max(confidences)) if confidences else 0.0
            result[:] = [extracted, final_confidence]

        # Run task with timeout. A Python thread cannot be stopped from the
        # outside, so the worker is a daemon: after a timeout it may keep
        # running in the background but will not block interpreter exit.
        thread = threading.Thread(target=task, daemon=True)
        thread.start()
        thread.join(timeout)
        if thread.is_alive():
            return "Execution exceeded time limit. Stopping function.", 0.0
        return tuple(result)