SatyamD31 anejaprerna committed on
Commit
49add26
·
verified ·
1 Parent(s): 58c74fe

Update rag.py (#2)

Browse files

- Update rag.py (04920e9fbcf79cc0085ee080ad648d41d024dd7b)


Co-authored-by: Prerna Aneja <[email protected]>

Files changed (1) hide show
  1. rag.py +91 -135
rag.py CHANGED
@@ -1,163 +1,119 @@
1
- # import time
2
- import threading
3
- import pandas as pd
4
  import faiss
5
  import numpy as np
6
- # import numpy as np
7
  import pickle
 
 
 
 
 
8
  from sentence_transformers import SentenceTransformer
9
  from transformers import AutoModelForCausalLM, AutoTokenizer
10
- # import torch
11
 
12
  class FinancialChatbot:
13
- def __init__(self, data_path, model_name="all-MiniLM-L6-v2", qwen_model_name="Qwen/Qwen2.5-1.5b"):
14
- self.data_path = data_path
15
- self.sbert_model = SentenceTransformer(model_name)
16
- self.index_map = {}
17
- self.faiss_index = None
18
- # def get_device_map() -> str:
19
- # return 'cuda' if torch.cuda.is_available() else ''
20
-
21
- # device = get_device_map()
22
- self.qwen_model = AutoModelForCausalLM.from_pretrained(qwen_model_name, torch_dtype="auto", device_map="cpu", trust_remote_code=True)
23
- self.qwen_tokenizer = AutoTokenizer.from_pretrained(qwen_model_name, trust_remote_code=True)
24
- self.load_or_create_index()
25
-
26
- def load_or_create_index(self):
27
- try:
28
- self.faiss_index = faiss.read_index("financial_faiss.index")
29
- with open("index_map.pkl", "rb") as f:
30
- self.index_map = pickle.load(f)
31
- print("Index loaded successfully!")
32
- except:
33
- print("Creating new FAISS index...")
34
- df = pd.read_excel(self.data_path)
35
- sentences = []
36
- for index, row in df.iterrows():
37
- for col in df.columns[1:]:
38
- text = f"{row[df.columns[0]]} - year {col} is: {row[col]}"
39
- sentences.append(text)
40
- self.index_map[len(sentences) - 1] = text
41
- embeddings = self.sbert_model.encode(sentences, convert_to_numpy=True)
42
- dim = embeddings.shape[1]
43
- self.faiss_index = faiss.IndexFlatL2(dim)
44
- self.faiss_index.add(embeddings)
45
- faiss.write_index(self.faiss_index, "financial_faiss.index")
46
- with open("index_map.pkl", "wb") as f:
47
- pickle.dump(self.index_map, f)
48
- print("Indexing completed!")
49
-
50
- # def query_faiss(self, query, top_k=5):
51
- # query_embedding = self.sbert_model.encode([query], convert_to_numpy=True)
52
- # distances, indices = self.faiss_index.search(query_embedding, top_k)
53
- # return [self.index_map[idx] for idx in indices[0] if idx in self.index_map]
54
-
55
- def query_faiss(self, query, top_k=5):
56
- """Retrieve top-k documents from FAISS and return confidence scores."""
57
 
58
- query_embedding = self.sbert_model.encode([query], convert_to_numpy=True)
59
- distances, indices = self.faiss_index.search(query_embedding, top_k)
60
 
61
- results = []
62
- confidences = []
 
 
63
 
64
- if len(distances[0]) > 0:
65
- max_dist = np.max(distances[0]) if np.max(distances[0]) != 0 else 1 # Avoid division by zero
 
 
66
 
67
- for idx, dist in zip(indices[0], distances[0]):
68
- if idx in self.index_map:
69
- results.append(self.index_map[idx])
70
- confidence = 1 - (dist / max_dist) # Normalize confidence (closer to 1 is better)
71
- confidences.append(round(confidence, 2)) # Round for clarity
 
 
72
 
73
- return results, confidences
74
-
75
  def moderate_query(self, query):
76
- BLOCKED_WORDS = ["hack", "bypass", "illegal", "exploit", "scam", "kill", "laundering", "murder", "suicide", "self-harm"]
77
- return not any(word in query.lower() for word in BLOCKED_WORDS)
78
-
79
- def generate_answer(self, context, question):
80
- prompt = f"""
81
- You are a financial assistant. If the user greets you (e.g., "Hello," "Hi," "Good morning"), respond politely without requiring context.
82
 
83
- For financial-related questions, answer based on the context provided. If the context lacks information, say "I don't know."
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
- Context: {context}
86
- User Query: {question}
87
- Answer:
88
- """
 
 
 
 
 
 
89
 
90
- input_text = prompt
91
- # f"Context: {context}\nQuestion: {question}\nAnswer:"
 
92
  inputs = self.qwen_tokenizer.encode(input_text, return_tensors="pt")
93
  outputs = self.qwen_model.generate(inputs, max_length=100)
94
  return self.qwen_tokenizer.decode(outputs[0], skip_special_tokens=True)
95
-
96
- # def get_answer(self, query, timeout=150):
97
- # result = ["", 0.0] # Placeholder for answer and confidence
98
-
99
- # def task():
100
- # if self.moderate_query(query):
101
- # retrieved_docs = self.query_faiss(query)
102
- # context = " ".join(retrieved_docs)
103
- # answer = self.generate_answer(context, query)
104
- # last_index = answer.rfind("Answer")
105
- # if answer[last_index+9:11] == "--":
106
- # result[:] = ["No relevant information found", 0.0]
107
- # else:
108
- # result[:] = [answer[last_index:], 0.9]
109
- # else:
110
- # result[:] = ["I'm unable to process your request due to inappropriate language.", 0.0]
111
-
112
- # thread = threading.Thread(target=task)
113
- # thread.start()
114
- # thread.join(timeout)
115
- # if thread.is_alive():
116
- # return "Execution exceeded time limit. Stopping function.", 0.0
117
- # return tuple(result)
118
-
119
- def get_answer(self, query, timeout=150):
120
- """Retrieve the best-matched answer along with confidence score, with execution timeout."""
121
-
122
- result = ["Execution exceeded time limit. Stopping function.", 0.0] # Default timeout response
123
 
 
 
 
 
124
  def task():
125
- """Processing function to retrieve and generate answer."""
126
- if self.moderate_query(query):
127
- retrieved_docs, confidences = self.query_faiss(query) # Get results + confidence scores
128
-
129
- if not retrieved_docs: # If no relevant docs found
130
- result[:] = ["No relevant information found", 0.0]
131
- return
132
 
133
- # Combine retrieved docs and calculate final confidence
134
- context = " ".join(retrieved_docs)
135
- avg_confidence = round(sum(confidences) / len(confidences), 2) # Avg confidence
136
-
137
- answer = self.generate_answer(context, query)
138
- last_index = answer.rfind("Answer")
139
-
140
- if answer[last_index + 9:11] == "--":
141
- result[:] = ["No relevant information found", 0.0]
142
- else:
143
- result[:] = [answer[last_index:], avg_confidence]
144
-
145
- else:
146
  result[:] = ["I'm unable to process your request due to inappropriate language.", 0.0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
- # Start execution in a separate thread
149
  thread = threading.Thread(target=task)
150
  thread.start()
151
- thread.join(timeout) # Wait for execution up to timeout
152
-
153
- # If thread is still running after timeout, return timeout message
154
  if thread.is_alive():
155
  return "Execution exceeded time limit. Stopping function.", 0.0
156
-
157
- return tuple(result)
158
-
159
-
160
- # if __name__ == "__main__":
161
- # chatbot = FinancialChatbot("C:\\Users\\Dell\\Downloads\\CAI_RAG\\DATA\\Nestle_Financtial_report_till2023.xlsx")
162
- # query = "What is the Employees Cost in Dec'20?"
163
- # print(chatbot.get_answer(query))
 
 
 
 
1
  import faiss
2
  import numpy as np
 
3
  import pickle
4
+ import threading
5
+ import time
6
+ import torch
7
+ import pandas as pd
8
+
9
  from sentence_transformers import SentenceTransformer
10
  from transformers import AutoModelForCausalLM, AutoTokenizer
11
+ from rank_bm25 import BM25Okapi
12
 
13
  class FinancialChatbot:
14
+ def __init__(self):
15
+ # Load financial dataset
16
+ self.df = pd.read_excel("Nestle_Financtial_report_till2023.xlsx")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
+ # Load embedding model
19
+ self.sbert_model = SentenceTransformer("all-MiniLM-L6-v2")
20
 
21
+ # Load FAISS index
22
+ self.faiss_index = faiss.read_index("financial_faiss.index")
23
+ with open("index_map.pkl", "rb") as f:
24
+ self.index_map = pickle.load(f)
25
 
26
+ # BM25 Indexing
27
+ self.documents = [" ".join(row) for row in self.df.astype(str).values]
28
+ self.tokenized_docs = [doc.split() for doc in self.documents]
29
+ self.bm25 = BM25Okapi(self.tokenized_docs)
30
 
31
+ # Load Qwen model
32
+ self.qwen_model_name = "Qwen/Qwen2.5-1.5b"
33
+ self.qwen_model = AutoModelForCausalLM.from_pretrained(self.qwen_model_name, torch_dtype="auto", device_map="auto", trust_remote_code=True)
34
+ self.qwen_tokenizer = AutoTokenizer.from_pretrained(self.qwen_model_name, trust_remote_code=True)
35
+
36
+ # Guardrail: Blocked words
37
+ self.BLOCKED_WORDS = ["hack", "bypass", "illegal", "scam", "terrorism", "attack", "suicide", "bomb"]
38
 
 
 
39
  def moderate_query(self, query):
40
+ """Check if the query contains blocked words."""
41
+ return not any(word in query.lower() for word in self.BLOCKED_WORDS)
 
 
 
 
42
 
43
+ def query_faiss(self, query, top_k=5):
44
+ """Retrieve top K relevant documents using FAISS."""
45
+ query_embedding = self.sbert_model.encode([query], convert_to_numpy=True)
46
+ distances, indices = self.faiss_index.search(query_embedding, top_k)
47
+
48
+ results = []
49
+ confidences = []
50
+ for idx, dist in zip(indices[0], distances[0]):
51
+ if idx in self.index_map:
52
+ results.append(self.index_map[idx])
53
+ confidences.append(1 / (1 + dist)) # Convert distance to confidence
54
+
55
+ return results, confidences
56
 
57
+ def query_bm25(self, query, top_k=5):
58
+ """Retrieve top K relevant documents using BM25."""
59
+ tokenized_query = query.split()
60
+ scores = self.bm25.get_scores(tokenized_query)
61
+ top_indices = np.argsort(scores)[-top_k:][::-1]
62
+
63
+ results = [self.documents[i] for i in top_indices]
64
+ confidences = [scores[i] / max(scores) for i in top_indices] # Normalize confidence
65
+
66
+ return results, confidences
67
 
68
+ def generate_answer(self, context, question):
69
+ """Generate answer using Qwen model."""
70
+ input_text = f"Context: {context}\nQuestion: {question}\nAnswer:"
71
  inputs = self.qwen_tokenizer.encode(input_text, return_tensors="pt")
72
  outputs = self.qwen_model.generate(inputs, max_length=100)
73
  return self.qwen_tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
+ def get_answer(self, query, timeout=200):
76
+ """Fetch an answer using multi-step retrieval and Qwen model, with timeout handling."""
77
+ result = ["No relevant information found", 0.0] # Default response
78
+
79
  def task():
80
+ # Handle greetings
81
+ if query.lower() in ["hi", "hello", "hey"]:
82
+ result[:] = ["Hi, how can I help you?", 1.0]
83
+ return
 
 
 
84
 
85
+ # Guardrail check
86
+ if not self.moderate_query(query):
 
 
 
 
 
 
 
 
 
 
 
87
  result[:] = ["I'm unable to process your request due to inappropriate language.", 0.0]
88
+ return
89
+
90
+ # Multi-step retrieval (BM25 + FAISS)
91
+ bm25_results, bm25_confidences = self.query_bm25(query, top_k=3)
92
+ faiss_results, faiss_confidences = self.query_faiss(query, top_k=3)
93
+
94
+ retrieved_docs = bm25_results + faiss_results
95
+ confidences = bm25_confidences + faiss_confidences
96
+
97
+ if not retrieved_docs:
98
+ return # Default response already set
99
+
100
+ # Construct context
101
+ context = " ".join(retrieved_docs)
102
+ answer = self.generate_answer(context, query)
103
+ last_index = answer.rfind("Answer")
104
+
105
+ # Confidence calculation
106
+ final_confidence = max(confidences) if confidences else 0.0
107
+ if answer[last_index+9:11] == "--":
108
+ result[:] = ["No relevant information found", 0.0]
109
+ else:
110
+ result[:] = [answer[last_index:], final_confidence]
111
 
112
+ # Run task with timeout
113
  thread = threading.Thread(target=task)
114
  thread.start()
115
+ thread.join(timeout)
 
 
116
  if thread.is_alive():
117
  return "Execution exceeded time limit. Stopping function.", 0.0
118
+
119
+ return tuple(result)