Update rag.py (#1)
Commit: 43697b49baad1adf45e00601ce018871dc5f3971
Co-authored-by: Prerna Aneja <[email protected]>
rag.py CHANGED
@@ -1,174 +1,163 @@
-import
 import pandas as pd
 import faiss
 import numpy as np
-import
-import
 from sentence_transformers import SentenceTransformer
 from transformers import AutoModelForCausalLM, AutoTokenizer

 class FinancialChatbot:
-    def __init__(self, data_path, model_name="all-MiniLM-L6-v2", qwen_model_name="Qwen/Qwen2-
-        self.
-        self.
-
-
-
-
-
-
-        self.qwen_model = AutoModelForCausalLM.from_pretrained(
-            qwen_model_name, torch_dtype=torch.float16, trust_remote_code=True
-        ).to(self.device)
-
         self.qwen_tokenizer = AutoTokenizer.from_pretrained(qwen_model_name, trust_remote_code=True)
-
-        # Load or create FAISS index
         self.load_or_create_index()
-
-    import os  # Import os for file checks
-
     def load_or_create_index(self):
-
-
-
-        self.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        self.faiss_index = faiss.IndexFlatL2(embeddings.shape[1])
-        self.faiss_index.add(embeddings)
-
-        # Save index and index map
-        faiss.write_index(self.faiss_index, "financial_faiss.index")
-        with open("index_map.txt", "w", encoding="utf-8") as f:
-            for sentence in self.index_map.values():
-                f.write(sentence + "\n")
-
-    def query_faiss(self, query, top_k=3):
-        """Retrieves the top_k closest sentences from FAISS index."""
         query_embedding = self.sbert_model.encode([query], convert_to_numpy=True)
         distances, indices = self.faiss_index.search(query_embedding, top_k)

-        results = [
-        confidences = [

-

-
-
-
-

     def generate_answer(self, context, question):
-
-
-            # {"role": "user", "content": f"{question} - related contect extracted form db {context}"}
-            {"role": "user", "content": f"""You are a financial assistant. If the user greets you (e.g., "Hello," "Hi," "Good morning"), respond politely with 'Hello! How can I assist you today? without requiring context.

 For financial-related questions, answer based on the context provided. If the context lacks information, say "I don't know."

 Context: {context}
 User Query: {question}
-Answer:
-
-
-
-
-
-        )
-
-
-
-
-
-
-
-
-
-
-

-
-
-        response = self.qwen_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

-

-
-
-
-

-

-
-
-
-        # """

-
-
-        # inputs = self.qwen_tokenizer.encode(input_text, return_tensors="pt")
-        # # outputs = self.qwen_model.generate(inputs, max_length=100)
-        # outputs = self.qwen_model.generate(inputs, max_new_tokens=100)
-        # generated_ids = outputs[:, inputs.shape[1]:]  # Remove prompt part
-        # response = self.qwen_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
-        # return response


-
-
-
-        # Check if query is appropriate
-        if not self.moderate_query(query):
-            return "Inappropriate request.", 0.0

-        #
-
-
-

-        #
-
-

-
-        model_response = self.generate_answer(context, query)

-        # Extract only the relevant part of the response
-        model_response = model_response.strip()
-
-        # Ensure only the actual answer is returned
-        if model_response.lower() in ["i don't know", "no relevant information found"]:
-            return "I don't know.", avg_confidence
-        #print(avg_confidence)
-        if avg_confidence == 0.0:
-            return "Not relevant ", avg_confidence

-
-
+# import time
+import threading
 import pandas as pd
 import faiss
 import numpy as np
+# import numpy as np
+import pickle
 from sentence_transformers import SentenceTransformer
 from transformers import AutoModelForCausalLM, AutoTokenizer
+# import torch

 class FinancialChatbot:
+    def __init__(self, data_path, model_name="all-MiniLM-L6-v2", qwen_model_name="Qwen/Qwen2.5-1.5b"):
+        self.data_path = data_path
+        self.sbert_model = SentenceTransformer(model_name)
+        self.index_map = {}
+        self.faiss_index = None
+        # def get_device_map() -> str:
+        #     return 'cuda' if torch.cuda.is_available() else ''
+
+        # device = get_device_map()
+        self.qwen_model = AutoModelForCausalLM.from_pretrained(qwen_model_name, torch_dtype="auto", device_map="cpu", trust_remote_code=True)
         self.qwen_tokenizer = AutoTokenizer.from_pretrained(qwen_model_name, trust_remote_code=True)
         self.load_or_create_index()
+
     def load_or_create_index(self):
+        try:
+            self.faiss_index = faiss.read_index("financial_faiss.index")
+            with open("index_map.pkl", "rb") as f:
+                self.index_map = pickle.load(f)
+            print("Index loaded successfully!")
+        except:
+            print("Creating new FAISS index...")
+            df = pd.read_excel(self.data_path)
+            sentences = []
+            for index, row in df.iterrows():
+                for col in df.columns[1:]:
+                    text = f"{row[df.columns[0]]} - year {col} is: {row[col]}"
+                    sentences.append(text)
+                    self.index_map[len(sentences) - 1] = text
+            embeddings = self.sbert_model.encode(sentences, convert_to_numpy=True)
+            dim = embeddings.shape[1]
+            self.faiss_index = faiss.IndexFlatL2(dim)
+            self.faiss_index.add(embeddings)
+            faiss.write_index(self.faiss_index, "financial_faiss.index")
+            with open("index_map.pkl", "wb") as f:
+                pickle.dump(self.index_map, f)
+            print("Indexing completed!")
+
+    # def query_faiss(self, query, top_k=5):
+    #     query_embedding = self.sbert_model.encode([query], convert_to_numpy=True)
+    #     distances, indices = self.faiss_index.search(query_embedding, top_k)
+    #     return [self.index_map[idx] for idx in indices[0] if idx in self.index_map]
+
+    def query_faiss(self, query, top_k=5):
+        """Retrieve top-k documents from FAISS and return confidence scores."""
+
         query_embedding = self.sbert_model.encode([query], convert_to_numpy=True)
         distances, indices = self.faiss_index.search(query_embedding, top_k)

+        results = []
+        confidences = []

+        if len(distances[0]) > 0:
+            max_dist = np.max(distances[0]) if np.max(distances[0]) != 0 else 1  # Avoid division by zero

+            for idx, dist in zip(indices[0], distances[0]):
+                if idx in self.index_map:
+                    results.append(self.index_map[idx])
+                    confidence = 1 - (dist / max_dist)  # Normalize confidence (closer to 1 is better)
+                    confidences.append(round(confidence, 2))  # Round for clarity

+        return results, confidences
+
+    def moderate_query(self, query):
+        BLOCKED_WORDS = ["hack", "bypass", "illegal", "exploit", "scam", "kill", "laundering", "murder", "suicide", "self-harm"]
+        return not any(word in query.lower() for word in BLOCKED_WORDS)
+
     def generate_answer(self, context, question):
+        prompt = f"""
+You are a financial assistant. If the user greets you (e.g., "Hello," "Hi," "Good morning"), respond politely without requiring context.

 For financial-related questions, answer based on the context provided. If the context lacks information, say "I don't know."

 Context: {context}
 User Query: {question}
+Answer:
+        """
+
+        input_text = prompt
+        # f"Context: {context}\nQuestion: {question}\nAnswer:"
+        inputs = self.qwen_tokenizer.encode(input_text, return_tensors="pt")
+        outputs = self.qwen_model.generate(inputs, max_length=100)
+        return self.qwen_tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+    # def get_answer(self, query, timeout=150):
+    #     result = ["", 0.0]  # Placeholder for answer and confidence
+
+    #     def task():
+    #         if self.moderate_query(query):
+    #             retrieved_docs = self.query_faiss(query)
+    #             context = " ".join(retrieved_docs)
+    #             answer = self.generate_answer(context, query)
+    #             last_index = answer.rfind("Answer")
+    #             if answer[last_index+9:11] == "--":
+    #                 result[:] = ["No relevant information found", 0.0]
+    #             else:
+    #                 result[:] = [answer[last_index:], 0.9]
+    #         else:
+    #             result[:] = ["I'm unable to process your request due to inappropriate language.", 0.0]
+
+    #     thread = threading.Thread(target=task)
+    #     thread.start()
+    #     thread.join(timeout)
+    #     if thread.is_alive():
+    #         return "Execution exceeded time limit. Stopping function.", 0.0
+    #     return tuple(result)

+    def get_answer(self, query, timeout=150):
+        """Retrieve the best-matched answer along with confidence score, with execution timeout."""

+        result = ["Execution exceeded time limit. Stopping function.", 0.0]  # Default timeout response

+        def task():
+            """Processing function to retrieve and generate answer."""
+            if self.moderate_query(query):
+                retrieved_docs, confidences = self.query_faiss(query)  # Get results + confidence scores

+                if not retrieved_docs:  # If no relevant docs found
+                    result[:] = ["No relevant information found", 0.0]
+                    return

+                # Combine retrieved docs and calculate final confidence
+                context = " ".join(retrieved_docs)
+                avg_confidence = round(sum(confidences) / len(confidences), 2)  # Avg confidence

+                answer = self.generate_answer(context, query)
+                last_index = answer.rfind("Answer")

+                if answer[last_index + 9:11] == "--":
+                    result[:] = ["No relevant information found", 0.0]
+                else:
+                    result[:] = [answer[last_index:], avg_confidence]

+            else:
+                result[:] = ["I'm unable to process your request due to inappropriate language.", 0.0]

+        # Start execution in a separate thread
+        thread = threading.Thread(target=task)
+        thread.start()
+        thread.join(timeout)  # Wait for execution up to timeout

+        # If thread is still running after timeout, return timeout message
+        if thread.is_alive():
+            return "Execution exceeded time limit. Stopping function.", 0.0

+        return tuple(result)


+# if __name__ == "__main__":
+#     chatbot = FinancialChatbot("C:\\Users\\Dell\\Downloads\\CAI_RAG\\DATA\\Nestle_Financtial_report_till2023.xlsx")
+#     query = "What is the Employees Cost in Dec'20?"
+#     print(chatbot.get_answer(query))
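
One caveat worth flagging in the committed get_answer: `answer[last_index + 9:11]` mixes a relative start (`last_index + 9`) with an absolute end (`11`), so whenever `last_index > 2` the slice is empty and the `"--"` check can never match. Below is a minimal sketch of an index-safe extraction; the helper name `extract_answer` is hypothetical and not part of the commit, and it assumes (as the code above does) that the decoder-only model echoes the prompt, so only the text after the final "Answer:" marker is newly generated.

def extract_answer(answer):
    """Return the text after the last 'Answer:' marker, or None if missing or empty.

    Hypothetical helper, assuming the prompt template used in generate_answer above.
    """
    marker = "Answer:"
    last_index = answer.rfind(marker)
    if last_index == -1:
        return None  # Marker absent: the model did not follow the template
    tail = answer[last_index + len(marker):].strip()
    # Treat an empty tail or a "--" placeholder as "no relevant information found"
    if not tail or tail.startswith("--"):
        return None
    return tail

Two smaller observations in the same spirit: the bare `except:` in load_or_create_index silently rebuilds the index on any failure, so catching a narrower exception (e.g., the error raised when "financial_faiss.index" is missing) would keep genuine bugs visible; and `generate` is called with `max_length=100`, which counts prompt tokens as well, so the `max_new_tokens=100` used in the earlier commented-out version is likely what was intended, since the prompt plus five retrieved sentences will usually exceed 100 tokens on its own. Note also that `1 - (dist / max_dist)` always assigns the farthest retrieved hit a confidence of exactly 0.0, which systematically lowers avg_confidence.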