Spaces:
Sleeping
Sleeping
import faiss | |
import numpy as np | |
import pickle | |
import threading | |
import time | |
import torch | |
import pandas as pd | |
from sentence_transformers import SentenceTransformer | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
from rank_bm25 import BM25Okapi | |
class FinancialChatbot:
    """Retrieval-augmented question answering over a financial spreadsheet.

    Rows of an Excel workbook are indexed with BM25. A prebuilt FAISS index
    (plus a pickled id -> document map) provides embedding-based retrieval.
    The retrieved text is passed as context to a Qwen causal LM, which
    generates the answer.
    """

    def __init__(
        self,
        data_path="Nestle_Financtial_report_till2023.xlsx",
        faiss_index_path="financial_faiss.index",
        index_map_path="index_map.pkl",
        embedding_model_name="all-MiniLM-L6-v2",
        qwen_model_name="Qwen/Qwen2.5-1.5b",
    ):
        """Load the dataset, retrieval indexes and generation model.

        Every argument defaults to the value that was previously hard-coded,
        so ``FinancialChatbot()`` loads exactly the same resources as before.

        Args:
            data_path: Excel workbook whose rows form the BM25 corpus.
            faiss_index_path: Serialized FAISS index file.
            index_map_path: Pickle file mapping FAISS ids to documents.
            embedding_model_name: SentenceTransformer model used to embed queries.
            qwen_model_name: Hugging Face model id used for answer generation.
        """
        # Load financial dataset.
        self.df = pd.read_excel(data_path)
        # Embedding model used to encode queries for the FAISS search.
        self.sbert_model = SentenceTransformer(embedding_model_name)
        # Load FAISS index and its id -> document mapping.
        self.faiss_index = faiss.read_index(faiss_index_path)
        # SECURITY: pickle.load can execute arbitrary code embedded in the file;
        # only load an index_map produced by a trusted pipeline.
        with open(index_map_path, "rb") as f:
            self.index_map = pickle.load(f)
        # BM25 corpus: each spreadsheet row flattened into one space-joined string.
        self.documents = [" ".join(row) for row in self.df.astype(str).values]
        self.tokenized_docs = [doc.split() for doc in self.documents]
        self.bm25 = BM25Okapi(self.tokenized_docs)
        # Load Qwen model.
        self.qwen_model_name = qwen_model_name
        self.qwen_model = AutoModelForCausalLM.from_pretrained(
            self.qwen_model_name,
            torch_dtype="auto",
            device_map="auto",
            trust_remote_code=True,
        )
        self.qwen_tokenizer = AutoTokenizer.from_pretrained(
            self.qwen_model_name, trust_remote_code=True
        )
        # Guardrail: a query containing any of these substrings is refused.
        self.BLOCKED_WORDS = ["hack", "bypass", "illegal", "scam", "terrorism", "attack", "suicide", "bomb"]

    def moderate_query(self, query):
        """Return True if *query* contains none of ``self.BLOCKED_WORDS``.

        Matching is a case-insensitive substring test, so a blocked word also
        matches inside longer words (e.g. "hack" matches "hacking").
        """
        lowered = query.lower()
        return not any(word in lowered for word in self.BLOCKED_WORDS)

    def query_faiss(self, query, top_k=5):
        """Retrieve up to *top_k* documents for *query* from the FAISS index.

        Returns:
            ``(results, confidences)``: documents looked up in
            ``self.index_map`` and a confidence ``1 / (1 + distance)`` for each.
            Ids that FAISS pads with ``-1``, or that have no entry in the map,
            are skipped, so fewer than *top_k* results may come back.
        """
        if top_k <= 0:
            return [], []
        query_embedding = self.sbert_model.encode([query], convert_to_numpy=True)
        # FAISS requires a float32 (n_queries, dim) matrix.
        query_embedding = np.asarray(query_embedding, dtype=np.float32)
        distances, indices = self.faiss_index.search(query_embedding, top_k)
        results = []
        confidences = []
        for raw_idx, dist in zip(indices[0], distances[0]):
            idx = int(raw_idx)
            # FAISS uses -1 for "no result" when fewer than top_k hits exist.
            if idx < 0 or idx not in self.index_map:
                continue
            results.append(self.index_map[idx])
            # NOTE(review): 1/(1+d) suits a distance metric (e.g. L2) where a
            # smaller value is better; confirm the index is not inner-product.
            confidences.append(1.0 / (1.0 + float(dist)))
        return results, confidences

    def query_bm25(self, query, top_k=5):
        """Retrieve the *top_k* highest-scoring documents using BM25.

        Returns:
            ``(results, confidences)``, best first. Each confidence is the
            document's score divided by the best score in the corpus. When no
            document scores above zero (no query term matched), all
            confidences are 0.0 instead of NaN from a division by zero.
        """
        if top_k <= 0 or not self.documents:
            # Guard: np.argsort(...)[-0:] would return the *whole* corpus.
            return [], []
        tokenized_query = query.split()
        scores = np.asarray(self.bm25.get_scores(tokenized_query), dtype=float)
        top_indices = np.argsort(scores)[-top_k:][::-1]
        results = [self.documents[i] for i in top_indices]
        best = scores.max()  # hoisted: computed once, not per result
        if best > 0:
            confidences = [float(scores[i] / best) for i in top_indices]
        else:
            confidences = [0.0] * len(top_indices)
        return results, confidences

    def generate_answer(self, context, question, max_new_tokens=100):
        """Generate an answer to *question* from *context* with the Qwen model.

        Args:
            context: Retrieved text placed in the prompt.
            question: The user's question.
            max_new_tokens: Cap on *generated* tokens. The previous
                ``max_length=100`` also counted the prompt, so a long context
                left no room to generate anything.

        Returns:
            The full decoded sequence: the prompt followed by the generated
            continuation.
        """
        input_text = f"Context: {context}\nQuestion: {question}\nAnswer:"
        # Tokenize with an attention mask and move the tensors to the model's
        # device (device_map="auto" may place the model on a GPU).
        encoded = self.qwen_tokenizer(input_text, return_tensors="pt").to(self.qwen_model.device)
        outputs = self.qwen_model.generate(**encoded, max_new_tokens=max_new_tokens)
        return self.qwen_tokenizer.decode(outputs[0], skip_special_tokens=True)

    @staticmethod
    def _extract_answer(raw_answer):
        """Split the model output into ``(display_text, has_content)``.

        *display_text* runs from the last "Answer" marker to the end (the
        whole text if there is no marker). *has_content* is False when the
        text after the marker is empty or starts with "--", which callers
        treat as "no relevant information".
        """
        marker = "Answer"
        start = raw_answer.rfind(marker)
        if start == -1:
            # Previously answer[-1:] returned only the last character here.
            display_text = raw_answer
            body = raw_answer
        else:
            display_text = raw_answer[start:]
            body = display_text[len(marker):].lstrip(":")
        body = body.strip()
        return display_text, bool(body) and not body.startswith("--")

    def get_answer(self, query, timeout=200):
        """Answer *query* via greeting/guardrail checks, hybrid retrieval and generation.

        Runs the work in a background thread and waits at most *timeout*
        seconds.

        Returns:
            ``(answer_text, confidence)``. If the work does not finish in
            time, returns a time-limit message with confidence 0.0. If the
            worker raises, ``threading.excepthook`` reports the error and the
            default "No relevant information found" result is returned.
        """
        result = ["No relevant information found", 0.0]  # Default response

        def task():
            # Handle greetings, tolerating surrounding whitespace and trailing punctuation.
            if query.strip().lower().rstrip("!.?,") in ("hi", "hello", "hey"):
                result[:] = ["Hi, how can I help you?", 1.0]
                return
            # Guardrail check.
            if not self.moderate_query(query):
                result[:] = ["I'm unable to process your request due to inappropriate language.", 0.0]
                return
            # Multi-step retrieval (BM25 + FAISS).
            bm25_results, bm25_confidences = self.query_bm25(query, top_k=3)
            faiss_results, faiss_confidences = self.query_faiss(query, top_k=3)
            retrieved_docs = bm25_results + faiss_results
            confidences = bm25_confidences + faiss_confidences
            if not retrieved_docs:
                return  # Default response already set
            context = " ".join(str(doc) for doc in retrieved_docs)
            raw_answer = self.generate_answer(context, query)
            answer_text, has_content = self._extract_answer(raw_answer)
            if not has_content:
                result[:] = ["No relevant information found", 0.0]
                return
            # NOTE(review): BM25 confidences are normalized by their own max,
            # so this is 1.0 whenever any BM25 score is positive.
            final_confidence = max(confidences) if confidences else 0.0
            result[:] = [answer_text, final_confidence]

        # Python threads cannot be force-stopped, so an overrunning task keeps
        # running in the background; daemon=True keeps it from blocking
        # interpreter shutdown.
        thread = threading.Thread(target=task, daemon=True)
        thread.start()
        thread.join(timeout)
        if thread.is_alive():
            return "Execution exceeded time limit. Stopping function.", 0.0
        return tuple(result)