import pickle
import re
import threading
import time

import faiss
import numpy as np
import torch
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModelForCausalLM, AutoTokenizer


class FinancialChatbot:
    """Retrieval-augmented financial Q&A bot.

    Candidate documents are retrieved with a FAISS dense index and a BM25
    keyword index, concatenated into a context string, and passed to a Qwen
    causal language model that generates the answer.
    """

    # Fallback reply when retrieval or generation yields nothing usable.
    NO_INFO_MESSAGE = "No relevant information found"
    # Queries (after stripping and lower-casing) answered with a canned greeting.
    GREETINGS = ("hi", "hello", "hey")

    def __init__(self):
        """Load the retrieval indexes, the query encoder and the generator model.

        Reads ``financial_faiss.index``, ``index_map.pkl`` and
        ``bm25_corpus.pkl`` from the current working directory.
        """
        # Dense retrieval: FAISS index plus a lookup from FAISS row id to document.
        self.faiss_index = faiss.read_index("financial_faiss.index")
        # NOTE(review): pickle.load can execute arbitrary code; these files
        # must come from a trusted source.
        with open("index_map.pkl", "rb") as f:
            self.index_map = pickle.load(f)

        # Keyword retrieval. BM25Okapi expects a tokenized corpus (presumably a
        # list of token lists -- confirm against the pickle contents).
        with open("bm25_corpus.pkl", "rb") as f:
            self.bm25_corpus = pickle.load(f)
        self.bm25 = BM25Okapi(self.bm25_corpus)

        # Sentence encoder used to embed queries for the FAISS search.
        self.sbert_model = SentenceTransformer("all-MiniLM-L6-v2")

        # Answer generator.
        model_name = "Qwen/Qwen2.5-1.5b"
        # NOTE(review): trust_remote_code=True executes code shipped with the
        # model repository; confirm it is actually required for this model.
        self.qwen_model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype="auto", device_map="auto", trust_remote_code=True
        )
        self.qwen_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

        # Guardrail: a query is rejected if any of these terms starts a word in it.
        self.BLOCKED_WORDS = [
            "hack", "bypass", "illegal", "exploit", "scam", "kill", "laundering",
            "murder", "suicide", "self-harm", "assault", "bomb", "terrorism",
            "attack", "genocide", "mass shooting", "credit card number",
        ]

        # Minimum best retrieval score required before an answer is generated.
        self.min_similarity_threshold = 0.2

    def moderate_query(self, query):
        """Return True if *query* passes the blocked-word guardrail, else False.

        Each blocked term must begin at a word boundary, so "hacking" is still
        blocked while "skills" no longer trips on "kill".
        """
        query_lower = query.lower()
        return not any(
            re.search(r"\b" + re.escape(word), query_lower) for word in self.BLOCKED_WORDS
        )

    def query_faiss(self, query, top_k=5):
        """Retrieve up to *top_k* documents from the FAISS index.

        Returns:
            ``(documents, scores)`` where each score is ``1 / (1 + distance)``.
        """
        query_embedding = self.sbert_model.encode([query], convert_to_numpy=True)
        distances, indices = self.faiss_index.search(query_embedding, top_k)

        results = []
        confidence_scores = []
        for idx, dist in zip(indices[0], distances[0]):
            idx = int(idx)  # numpy integer -> plain int for the mapping lookup
            # FAISS fills unused result slots with id -1.
            if idx < 0 or idx not in self.index_map:
                continue
            results.append(self.index_map[idx])
            # Assumes an L2 index (smaller distance = closer) -- confirm the metric.
            confidence_scores.append(float(1 / (1 + dist)))
        return results, confidence_scores

    def query_bm25(self, query, top_k=5):
        """Retrieve up to *top_k* corpus entries with a positive BM25 score.

        Returns:
            ``(entries, scores)``; entries are returned exactly as stored in
            ``bm25_corpus`` and scores are raw (unnormalized) BM25 scores.
        """
        tokenized_query = query.lower().split()
        scores = self.bm25.get_scores(tokenized_query)
        top_indices = np.argsort(scores)[::-1][:top_k]

        # Zero-score entries share no terms with the query; skip them.
        hits = [(self.bm25_corpus[i], float(scores[i])) for i in top_indices if scores[i] > 0]
        results = [doc for doc, _ in hits]
        confidence_scores = [score for _, score in hits]
        return results, confidence_scores

    def generate_answer(self, context, question, max_new_tokens=100):
        """Generate an answer with the Qwen model.

        Args:
            context: Retrieved text placed in the prompt.
            question: The user's question.
            max_new_tokens: Generation budget, not counting the prompt.

        Returns:
            The decoded prompt followed by the model's continuation.
        """
        input_text = f"Context: {context}\nQuestion: {question}\nAnswer:"
        # Keep the attention mask and place the tensors on the model's device
        # (device_map="auto" may have put the model on a GPU).
        inputs = self.qwen_tokenizer(input_text, return_tensors="pt").to(self.qwen_model.device)
        # max_new_tokens, not max_length: max_length counts the prompt too, and a
        # long retrieved context would leave no room for any generated tokens.
        outputs = self.qwen_model.generate(**inputs, max_new_tokens=max_new_tokens)
        return self.qwen_tokenizer.decode(outputs[0], skip_special_tokens=True)

    @staticmethod
    def _as_text(doc):
        """Return *doc* as a string, space-joining it if it is a token sequence."""
        return doc if isinstance(doc, str) else " ".join(map(str, doc))

    @staticmethod
    def _extract_answer(generated):
        """Return the reply to show the user, or None if the model gave no answer.

        The reply is *generated* from its last "Answer" label onward, keeping the
        label. It counts as empty when only ":"/whitespace follows the label, or
        when the answer text starts with "--".
        """
        label_at = generated.rfind("Answer")
        if label_at == -1:
            reply = body = generated.strip()
        else:
            reply = generated[label_at:]
            body = generated[label_at + len("Answer"):].lstrip(":").strip()
        if not body or body.startswith("--"):
            return None
        return reply

    def get_answer(self, query, timeout=200):
        """Answer *query* via retrieval + generation, waiting at most *timeout* seconds.

        Returns:
            A ``(reply, confidence)`` tuple. For generated answers, confidence is
            the best retrieval score. Canned replies (greeting, moderation,
            no match, timeout) report 1.0. If the worker raises, the default
            ``(NO_INFO_MESSAGE, 0.0)`` is returned.
        """
        result = [self.NO_INFO_MESSAGE, 0.0]  # Default response

        def task():
            if query.strip().lower() in self.GREETINGS:
                result[:] = ["Hi, how can I help you?", 1.0]
                return
            if not self.moderate_query(query):
                result[:] = [
                    "I'm unable to process your request due to inappropriate language.",
                    1.0,
                ]
                return

            faiss_results, faiss_conf = self.query_faiss(query)
            bm25_results, bm25_conf = self.query_bm25(query)
            all_results = faiss_results + bm25_results
            all_conf = faiss_conf + bm25_conf

            # NOTE(review): the FAISS score 1/(1+dist) is at most 1 for
            # non-negative distances, but raw BM25 scores are unbounded, so this
            # shared threshold (and the reported confidence) mixes two scales.
            if not all_results or max(all_conf, default=0) < self.min_similarity_threshold:
                result[:] = [self.NO_INFO_MESSAGE, 1.0]
                return

            # BM25 entries may be token lists; normalize to text and drop
            # documents that both retrievers returned.
            context = " ".join(dict.fromkeys(self._as_text(doc) for doc in all_results))
            reply = self._extract_answer(self.generate_answer(context, query))
            if reply is None:
                result[:] = [self.NO_INFO_MESSAGE, 1.0]
            else:
                result[:] = [reply, max(all_conf)]

        # Daemon thread: a hung generation cannot keep the interpreter alive.
        thread = threading.Thread(target=task, daemon=True)
        thread.start()
        thread.join(timeout)
        if thread.is_alive():
            # The worker cannot be cancelled; it keeps running in the background.
            return self.NO_INFO_MESSAGE, 1.0  # Timeout case
        return tuple(result)