Update rag.py (#2)
Update rag.py (04920e9fbcf79cc0085ee080ad648d41d024dd7b)
Co-authored-by: Prerna Aneja <[email protected]>
rag.py CHANGED
@@ -1,163 +1,119 @@
--- a/rag.py

-# import time
-import threading
-import pandas as pd
 import faiss
 import numpy as np
-# import numpy as np
 import pickle
 from sentence_transformers import SentenceTransformer
 from transformers import AutoModelForCausalLM, AutoTokenizer
-
 
 class FinancialChatbot:
-    def __init__(self, data_path, qwen_model_name):  # signature truncated in the page extraction; parameters inferred from the method body
-
-        self.…  # assignment truncated in the page extraction
-        self.index_map = {}
-        self.faiss_index = None
-        # def get_device_map() -> str:
-        #     return 'cuda' if torch.cuda.is_available() else ''
-
-        # device = get_device_map()
-        self.qwen_model = AutoModelForCausalLM.from_pretrained(qwen_model_name, torch_dtype="auto", device_map="cpu", trust_remote_code=True)
-        self.qwen_tokenizer = AutoTokenizer.from_pretrained(qwen_model_name, trust_remote_code=True)
-        self.load_or_create_index()
-
-    def load_or_create_index(self):
-        try:
-            self.faiss_index = faiss.read_index("financial_faiss.index")
-            with open("index_map.pkl", "rb") as f:
-                self.index_map = pickle.load(f)
-            print("Index loaded successfully!")
-        except:
-            print("Creating new FAISS index...")
-            df = pd.read_excel(self.data_path)
-            sentences = []
-            for index, row in df.iterrows():
-                for col in df.columns[1:]:
-                    text = f"{row[df.columns[0]]} - year {col} is: {row[col]}"
-                    sentences.append(text)
-                    self.index_map[len(sentences) - 1] = text
-            embeddings = self.sbert_model.encode(sentences, convert_to_numpy=True)
-            dim = embeddings.shape[1]
-            self.faiss_index = faiss.IndexFlatL2(dim)
-            self.faiss_index.add(embeddings)
-            faiss.write_index(self.faiss_index, "financial_faiss.index")
-            with open("index_map.pkl", "wb") as f:
-                pickle.dump(self.index_map, f)
-            print("Indexing completed!")
-
-    # def query_faiss(self, query, top_k=5):
-    #     query_embedding = self.sbert_model.encode([query], convert_to_numpy=True)
-    #     distances, indices = self.faiss_index.search(query_embedding, top_k)
-    #     return [self.index_map[idx] for idx in indices[0] if idx in self.index_map]
-
-    def query_faiss(self, query, top_k=5):
-        """Retrieve top-k documents from FAISS and return confidence scores."""
-        …  # method body not recoverable from the page extraction
-        return results, confidences
-
     def moderate_query(self, query):
-
-        return not any(word in query.lower() for word in BLOCKED_WORDS)
-
-    def generate_answer(self, context, question):
-        prompt = f"""
-        You are a financial assistant. If the user greets you (e.g., "Hello," "Hi," "Good morning"), respond politely without requiring context.
-        …  # remainder of the prompt and input construction not recoverable from the page extraction
 
         inputs = self.qwen_tokenizer.encode(input_text, return_tensors="pt")
         outputs = self.qwen_model.generate(inputs, max_length=100)
         return self.qwen_tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-    # def get_answer(self, query, timeout=150):
-    #     result = ["", 0.0]  # Placeholder for answer and confidence
-
-    #     def task():
-    #         if self.moderate_query(query):
-    #             retrieved_docs = self.query_faiss(query)
-    #             context = " ".join(retrieved_docs)
-    #             answer = self.generate_answer(context, query)
-    #             last_index = answer.rfind("Answer")
-    #             if answer[last_index+9:11] == "--":
-    #                 result[:] = ["No relevant information found", 0.0]
-    #             else:
-    #                 result[:] = [answer[last_index:], 0.9]
-    #         else:
-    #             result[:] = ["I'm unable to process your request due to inappropriate language.", 0.0]
-
-    #     thread = threading.Thread(target=task)
-    #     thread.start()
-    #     thread.join(timeout)
-    #     if thread.is_alive():
-    #         return "Execution exceeded time limit. Stopping function.", 0.0
-    #     return tuple(result)
-
-    def get_answer(self, query, timeout=150):
-        """Retrieve the best-matched answer along with confidence score, with execution timeout."""
-
-        result = ["Execution exceeded time limit. Stopping function.", 0.0]  # Default timeout response
 
         def task():
-
-            if self.moderate_query(query):  # line truncated in the page extraction; condition inferred from the commented-out version above
-                …  # retrieval and context construction not recoverable from the page extraction
-                if not retrieved_docs:  # If no relevant docs found
-                    result[:] = ["No relevant information found", 0.0]
-                    return
 
-                avg_confidence = round(sum(confidences) / len(confidences), 2)  # Avg confidence
-
-                answer = self.generate_answer(context, query)
-                last_index = answer.rfind("Answer")
-
-                if answer[last_index + 9:11] == "--":
-                    result[:] = ["No relevant information found", 0.0]
-                else:
-                    result[:] = [answer[last_index:], avg_confidence]
-
-            else:
                 result[:] = ["I'm unable to process your request due to inappropriate language.", 0.0]
 
-        #
         thread = threading.Thread(target=task)
         thread.start()
-        thread.join(timeout)
-
-        # If thread is still running after timeout, return timeout message
         if thread.is_alive():
             return "Execution exceeded time limit. Stopping function.", 0.0
-
-        return tuple(result)
-
-
-# if __name__ == "__main__":
-#     chatbot = FinancialChatbot("C:\\Users\\Dell\\Downloads\\CAI_RAG\\DATA\\Nestle_Financtial_report_till2023.xlsx")
-#     query = "What is the Employees Cost in Dec'20?"
-#     print(chatbot.get_answer(query))
+++ b/rag.py

 import faiss
 import numpy as np
 import pickle
+import threading
+import time
+import torch
+import pandas as pd
+
 from sentence_transformers import SentenceTransformer
 from transformers import AutoModelForCausalLM, AutoTokenizer
+from rank_bm25 import BM25Okapi
 
 class FinancialChatbot:
+    def __init__(self):
+        # Load financial dataset
+        self.df = pd.read_excel("Nestle_Financtial_report_till2023.xlsx")
 
+        # Load embedding model
+        self.sbert_model = SentenceTransformer("all-MiniLM-L6-v2")
 
+        # Load FAISS index
+        self.faiss_index = faiss.read_index("financial_faiss.index")
+        with open("index_map.pkl", "rb") as f:
+            self.index_map = pickle.load(f)
 
+        # BM25 indexing
+        self.documents = [" ".join(row) for row in self.df.astype(str).values]
+        self.tokenized_docs = [doc.split() for doc in self.documents]
+        self.bm25 = BM25Okapi(self.tokenized_docs)
 
+        # Load Qwen model
+        self.qwen_model_name = "Qwen/Qwen2.5-1.5b"
+        self.qwen_model = AutoModelForCausalLM.from_pretrained(self.qwen_model_name, torch_dtype="auto", device_map="auto", trust_remote_code=True)
+        self.qwen_tokenizer = AutoTokenizer.from_pretrained(self.qwen_model_name, trust_remote_code=True)
+
+        # Guardrail: blocked words
+        self.BLOCKED_WORDS = ["hack", "bypass", "illegal", "scam", "terrorism", "attack", "suicide", "bomb"]
 
     def moderate_query(self, query):
+        """Check if the query contains blocked words."""
+        return not any(word in query.lower() for word in self.BLOCKED_WORDS)
 
+    def query_faiss(self, query, top_k=5):
+        """Retrieve top-k relevant documents using FAISS."""
+        query_embedding = self.sbert_model.encode([query], convert_to_numpy=True)
+        distances, indices = self.faiss_index.search(query_embedding, top_k)
+
+        results = []
+        confidences = []
+        for idx, dist in zip(indices[0], distances[0]):
+            if idx in self.index_map:
+                results.append(self.index_map[idx])
+                confidences.append(1 / (1 + dist))  # Convert distance to confidence
+
+        return results, confidences
 
+    def query_bm25(self, query, top_k=5):
+        """Retrieve top-k relevant documents using BM25."""
+        tokenized_query = query.split()
+        scores = self.bm25.get_scores(tokenized_query)
+        top_indices = np.argsort(scores)[-top_k:][::-1]
+
+        results = [self.documents[i] for i in top_indices]
+        confidences = [scores[i] / max(scores) for i in top_indices]  # Normalize confidence
+
+        return results, confidences
 
+    def generate_answer(self, context, question):
+        """Generate an answer using the Qwen model."""
+        input_text = f"Context: {context}\nQuestion: {question}\nAnswer:"
         inputs = self.qwen_tokenizer.encode(input_text, return_tensors="pt")
         outputs = self.qwen_model.generate(inputs, max_length=100)
         return self.qwen_tokenizer.decode(outputs[0], skip_special_tokens=True)
 
+    def get_answer(self, query, timeout=200):
+        """Fetch an answer using multi-step retrieval and the Qwen model, with timeout handling."""
+        result = ["No relevant information found", 0.0]  # Default response
+
         def task():
+            # Handle greetings
+            if query.lower() in ["hi", "hello", "hey"]:
+                result[:] = ["Hi, how can I help you?", 1.0]
+                return
 
+            # Guardrail check
+            if not self.moderate_query(query):
                 result[:] = ["I'm unable to process your request due to inappropriate language.", 0.0]
+                return
+
+            # Multi-step retrieval (BM25 + FAISS)
+            bm25_results, bm25_confidences = self.query_bm25(query, top_k=3)
+            faiss_results, faiss_confidences = self.query_faiss(query, top_k=3)
+
+            retrieved_docs = bm25_results + faiss_results
+            confidences = bm25_confidences + faiss_confidences
+
+            if not retrieved_docs:
+                return  # Default response already set
+
+            # Construct context
+            context = " ".join(retrieved_docs)
+            answer = self.generate_answer(context, query)
+            last_index = answer.rfind("Answer")
+
+            # Confidence calculation
+            final_confidence = max(confidences) if confidences else 0.0
+            if answer[last_index+9:11] == "--":
+                result[:] = ["No relevant information found", 0.0]
+            else:
+                result[:] = [answer[last_index:], final_confidence]
 
+        # Run task with timeout
         thread = threading.Thread(target=task)
         thread.start()
+        thread.join(timeout)
         if thread.is_alive():
             return "Execution exceeded time limit. Stopping function.", 0.0
+
+        return tuple(result)
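For reference, a minimal usage sketch of the updated class, appended to rag.py. This driver is hypothetical (the commit drops the old commented-out __main__ block); it assumes Nestle_Financtial_report_till2023.xlsx, financial_faiss.index, and index_map.pkl sit in the working directory, since the new __init__ loads all three by relative path.

# Hypothetical driver, mirroring the commented-out __main__ from the old rag.py;
# not part of this commit. Assumes the Excel report, FAISS index, and index map
# pickle are present in the working directory, as the new __init__ expects.
if __name__ == "__main__":
    chatbot = FinancialChatbot()
    # get_answer returns an (answer, confidence) tuple, or a timeout message after `timeout` seconds
    answer, confidence = chatbot.get_answer("What is the Employees Cost in Dec'20?")
    print(f"{answer} (confidence: {confidence})")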