import faiss
import pickle
import threading
import time
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
class FinancialChatbot:
    def __init__(self):
        # Load FAISS index
        self.faiss_index = faiss.read_index("financial_faiss.index")
        with open("index_map.pkl", "rb") as f:
            self.index_map = pickle.load(f)

        # Load BM25 keyword-based search
        with open("bm25_corpus.pkl", "rb") as f:
            self.bm25_corpus = pickle.load(f)
        self.bm25 = BM25Okapi(self.bm25_corpus)

        # Load SentenceTransformer for embedding-based retrieval
        self.sbert_model = SentenceTransformer("all-MiniLM-L6-v2")

        # Load Qwen model
        model_name = "Qwen/Qwen2.5-1.5b"
        self.qwen_model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype="auto", device_map="auto", trust_remote_code=True
        )
        self.qwen_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

        # Guardrail: blocked words
        self.BLOCKED_WORDS = [
            "hack", "bypass", "illegal", "exploit", "scam", "kill", "laundering",
            "murder", "suicide", "self-harm", "assault", "bomb", "terrorism",
            "attack", "genocide", "mass shooting", "credit card number"
        ]

        # Relevance threshold for retrieved context
        self.min_similarity_threshold = 0.2

    def moderate_query(self, query):
        """Check if the query contains inappropriate words."""
        query_lower = query.lower()
        for word in self.BLOCKED_WORDS:
            if word in query_lower:
                return False  # Block query
        return True  # Allow query

    def query_faiss(self, query, top_k=5):
        """Retrieve relevant documents using FAISS and compute confidence scores."""
        query_embedding = self.sbert_model.encode([query], convert_to_numpy=True)
        distances, indices = self.faiss_index.search(query_embedding, top_k)
        results = []
        confidence_scores = []
        for idx, dist in zip(indices[0], distances[0]):
            if idx in self.index_map:
                similarity = 1 / (1 + dist)  # Convert L2 distance to similarity
                results.append(self.index_map[idx])
                confidence_scores.append(similarity)
        return results, confidence_scores

    def query_bm25(self, query, top_k=5):
        """Retrieve relevant documents using BM25 keyword-based search."""
        tokenized_query = query.lower().split()
        scores = self.bm25.get_scores(tokenized_query)
        top_indices = np.argsort(scores)[::-1][:top_k]
        results = []
        confidence_scores = []
        for idx in top_indices:
            if scores[idx] > 0:  # Ignore zero-score matches
                results.append(self.bm25_corpus[idx])
                confidence_scores.append(scores[idx])
        return results, confidence_scores

    def generate_answer(self, context, question):
        """Generate an answer using the Qwen model."""
        input_text = f"Context: {context}\nQuestion: {question}\nAnswer:"
        inputs = self.qwen_tokenizer.encode(input_text, return_tensors="pt")
        # Move inputs to the model's device (device_map="auto" may place the model on GPU)
        # and cap new tokens rather than total length, so a long retrieved context
        # does not exhaust the generation budget.
        inputs = inputs.to(self.qwen_model.device)
        outputs = self.qwen_model.generate(inputs, max_new_tokens=100)
        return self.qwen_tokenizer.decode(outputs[0], skip_special_tokens=True)

    def get_answer(self, query, timeout=200):
        """Fetch an answer from FAISS/BM25 retrieval and the Qwen model, with a timeout."""
        result = ["No relevant information found", 0.0]  # Default response

        def task():
            if query.lower() in ["hi", "hello", "hey"]:
                result[:] = ["Hi, how can I help you?", 1.0]
                return

            if not self.moderate_query(query):
                result[:] = ["I'm unable to process your request due to inappropriate language.", 1.0]
                return

            faiss_results, faiss_conf = self.query_faiss(query)
            bm25_results, bm25_conf = self.query_bm25(query)
            all_results = faiss_results + bm25_results
            all_conf = faiss_conf + bm25_conf

            # Check relevance
            if not all_results or max(all_conf, default=0) < self.min_similarity_threshold:
                result[:] = ["No relevant information found", 1.0]
                return

            context = " ".join(all_results)
            answer = self.generate_answer(context, query)

            # The decoded output includes the prompt, so keep only the text from the
            # final "Answer" marker onward; treat a "--" placeholder as no answer.
            last_index = answer.rfind("Answer")
            if answer[last_index + 9 : last_index + 11] == "--":
                result[:] = ["No relevant information found", 1.0]
            else:
                result[:] = [answer[last_index:], max(all_conf, default=0.9)]

        thread = threading.Thread(target=task)
        thread.start()
        thread.join(timeout)

        if thread.is_alive():
            return "No relevant information found", 1.0  # Timeout case
        return tuple(result)
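

# Example usage: a minimal sketch, assuming financial_faiss.index, index_map.pkl,
# and bm25_corpus.pkl are present in the working directory. The query below is
# illustrative only.
if __name__ == "__main__":
    chatbot = FinancialChatbot()
    answer, confidence = chatbot.get_answer("What was the total revenue last year?")
    print(f"Answer: {answer}")
    print(f"Confidence: {confidence:.2f}")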