Spaces:
Sleeping
Sleeping
File size: 4,873 Bytes
e2b8671 58c74fe 04920e9 e2b8671 04920e9 e2b8671 04920e9 58c74fe 04920e9 e2b8671 04920e9 e2b8671 04920e9 e2b8671 04920e9 e2b8671 58c74fe 04920e9 e2b8671 04920e9 e2b8671 04920e9 58c74fe 04920e9 58c74fe e2b8671 04920e9 58c74fe 04920e9 f8d190b 04920e9 58c74fe 04920e9 e2b8671 04920e9 58c74fe 04920e9 58c74fe 04920e9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
import pickle
import re
import threading
import time

import faiss
import numpy as np
import pandas as pd
import torch
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer
class FinancialChatbot:
    """RAG-style chatbot over a Nestle financial dataset.

    Combines BM25 keyword retrieval with FAISS dense retrieval, then
    generates an answer with a Qwen causal language model, guarded by a
    simple blocked-word filter and a per-query timeout.
    """

    def __init__(
        self,
        data_path="Nestle_Financtial_report_till2023.xlsx",
        faiss_index_path="financial_faiss.index",
        index_map_path="index_map.pkl",
        embedding_model_name="all-MiniLM-L6-v2",
        qwen_model_name="Qwen/Qwen2.5-1.5b",
    ):
        """Load the dataset, retrieval indexes, and language models.

        The paths and model names were previously hard-coded; they are now
        parameters whose defaults are the original literals, so existing
        callers are unaffected.

        Args:
            data_path: Excel file with the financial table (the
                "Financtial" typo matches the actual file name on disk).
            faiss_index_path: Pre-built FAISS index file.
            index_map_path: Pickled mapping from FAISS vector id -> text.
            embedding_model_name: SentenceTransformer model for queries.
            qwen_model_name: Hugging Face id of the generator model.
        """
        # Load financial dataset
        self.df = pd.read_excel(data_path)
        # Sentence-embedding model used for dense (FAISS) retrieval
        self.sbert_model = SentenceTransformer(embedding_model_name)
        # Pre-built FAISS index plus vector-id -> document mapping
        self.faiss_index = faiss.read_index(faiss_index_path)
        # SECURITY NOTE(review): pickle.load executes arbitrary code if the
        # file is untrusted -- only load index maps produced by this project.
        with open(index_map_path, "rb") as f:
            self.index_map = pickle.load(f)
        # BM25 keyword index: one "document" per spreadsheet row
        self.documents = [" ".join(row) for row in self.df.astype(str).values]
        self.tokenized_docs = [doc.split() for doc in self.documents]
        self.bm25 = BM25Okapi(self.tokenized_docs)
        # Qwen causal LM used for answer generation
        self.qwen_model_name = qwen_model_name
        self.qwen_model = AutoModelForCausalLM.from_pretrained(
            self.qwen_model_name,
            torch_dtype="auto",
            device_map="auto",
            trust_remote_code=True,
        )
        self.qwen_tokenizer = AutoTokenizer.from_pretrained(
            self.qwen_model_name, trust_remote_code=True
        )
        # Guardrail: queries containing any of these words are refused
        self.BLOCKED_WORDS = ["hack", "bypass", "illegal", "scam", "terrorism", "attack", "suicide", "bomb"]
def moderate_query(self, query):
"""Check if the query contains blocked words."""
return not any(word in query.lower() for word in self.BLOCKED_WORDS)
def query_faiss(self, query, top_k=5):
"""Retrieve top K relevant documents using FAISS."""
query_embedding = self.sbert_model.encode([query], convert_to_numpy=True)
distances, indices = self.faiss_index.search(query_embedding, top_k)
results = []
confidences = []
for idx, dist in zip(indices[0], distances[0]):
if idx in self.index_map:
results.append(self.index_map[idx])
confidences.append(1 / (1 + dist)) # Convert distance to confidence
return results, confidences
def query_bm25(self, query, top_k=5):
"""Retrieve top K relevant documents using BM25."""
tokenized_query = query.split()
scores = self.bm25.get_scores(tokenized_query)
top_indices = np.argsort(scores)[-top_k:][::-1]
results = [self.documents[i] for i in top_indices]
confidences = [scores[i] / max(scores) for i in top_indices] # Normalize confidence
return results, confidences
def generate_answer(self, context, question):
"""Generate answer using Qwen model."""
input_text = f"Context: {context}\nQuestion: {question}\nAnswer:"
inputs = self.qwen_tokenizer.encode(input_text, return_tensors="pt")
outputs = self.qwen_model.generate(inputs, max_length=100)
return self.qwen_tokenizer.decode(outputs[0], skip_special_tokens=True)
def get_answer(self, query, timeout=200):
"""Fetch an answer using multi-step retrieval and Qwen model, with timeout handling."""
result = ["No relevant information found", 0.0] # Default response
def task():
# Handle greetings
if query.lower() in ["hi", "hello", "hey"]:
result[:] = ["Hi, how can I help you?", 1.0]
return
# Guardrail check
if not self.moderate_query(query):
result[:] = ["I'm unable to process your request due to inappropriate language.", 0.0]
return
# Multi-step retrieval (BM25 + FAISS)
bm25_results, bm25_confidences = self.query_bm25(query, top_k=3)
faiss_results, faiss_confidences = self.query_faiss(query, top_k=3)
retrieved_docs = bm25_results + faiss_results
confidences = bm25_confidences + faiss_confidences
if not retrieved_docs:
return # Default response already set
# Construct context
context = " ".join(retrieved_docs)
answer = self.generate_answer(context, query)
last_index = answer.rfind("Answer")
# Confidence calculation
final_confidence = max(confidences) if confidences else 0.0
if answer[last_index+9:11] == "--":
result[:] = ["No relevant information found", 0.0]
else:
result[:] = [answer[last_index:], final_confidence]
# Run task with timeout
thread = threading.Thread(target=task)
thread.start()
thread.join(timeout)
if thread.is_alive():
return "Execution exceeded time limit. Stopping function.", 0.0
return tuple(result) |