SatyamD31 and anejaprerna committed
Commit 58c74fe · verified · 1 Parent(s): f8d190b

Update rag.py (#1)


- Update rag.py (43697b49baad1adf45e00601ce018871dc5f3971)


Co-authored-by: Prerna Aneja <[email protected]>

Files changed (1)
  1. rag.py +128 -139
rag.py CHANGED
@@ -1,174 +1,163 @@
- import torch
  import pandas as pd
  import faiss
  import numpy as np
- import re
- import os
  from sentence_transformers import SentenceTransformer
  from transformers import AutoModelForCausalLM, AutoTokenizer

  class FinancialChatbot:
-     def __init__(self, data_path, model_name="all-MiniLM-L6-v2", qwen_model_name="Qwen/Qwen2-0.5B-Instruct"):
-         self.device = "cpu"
-         self.data_path = data_path  # Store data path
-
-         # Load SBERT for embeddings
-         self.sbert_model = SentenceTransformer(model_name, device=self.device)
-         self.sbert_model = self.sbert_model.half()
-
-         # Load Qwen model for text generation
-         self.qwen_model = AutoModelForCausalLM.from_pretrained(
-             qwen_model_name, torch_dtype=torch.float16, trust_remote_code=True
-         ).to(self.device)
-
          self.qwen_tokenizer = AutoTokenizer.from_pretrained(qwen_model_name, trust_remote_code=True)
-
-         # Load or create FAISS index
          self.load_or_create_index()
-
-     import os  # Import os for file checks
-
      def load_or_create_index(self):
-         """Loads FAISS index and index_map if they exist, otherwise creates new ones."""
-         if os.path.exists("financial_faiss.index") and os.path.exists("index_map.txt"):
-             try:
-                 self.faiss_index = faiss.read_index("financial_faiss.index")
-                 with open("index_map.txt", "r", encoding="utf-8") as f:
-                     self.index_map = {i: line.strip() for i, line in enumerate(f)}
-                 print("FAISS index and index_map loaded successfully.")
-             except Exception as e:
-                 print(f"Error loading FAISS index: {e}. Recreating index...")
-                 self.create_faiss_index()
-         else:
-             print("FAISS index or index_map not found. Creating a new one...")
-             self.create_faiss_index()
-
-
-     def create_faiss_index(self):
-         """Creates a FAISS index from the provided Excel file."""
-         df = pd.read_excel(self.data_path)
-         sentences = []
-         self.index_map = {}  # Initialize index_map
-
-         for row_idx, row in df.iterrows():
-             for col in df.columns[1:]:  # Ignore the first column (assumed to be labels)
-                 sentence = f"{row[df.columns[0]]} - year {col} is: {row[col]}"
-                 sentences.append(sentence)
-                 self.index_map[len(self.index_map)] = sentence  # Store mapping
-
-         # Encode the sentences into embeddings
-         embeddings = self.sbert_model.encode(sentences, convert_to_numpy=True)
-
-         # Create FAISS index (FlatL2 for simplicity)
-         self.faiss_index = faiss.IndexFlatL2(embeddings.shape[1])
-         self.faiss_index.add(embeddings)
-
-         # Save index and index map
-         faiss.write_index(self.faiss_index, "financial_faiss.index")
-         with open("index_map.txt", "w", encoding="utf-8") as f:
-             for sentence in self.index_map.values():
-                 f.write(sentence + "\n")
-
-     def query_faiss(self, query, top_k=3):
-         """Retrieves the top_k closest sentences from FAISS index."""
          query_embedding = self.sbert_model.encode([query], convert_to_numpy=True)
          distances, indices = self.faiss_index.search(query_embedding, top_k)

-         results = [self.index_map[idx] for idx in indices[0] if idx in self.index_map]
-         confidences = [(1 - (dist / (np.max(distances[0]) or 1))) * 10 for dist in distances[0]]

-         return results, confidences

-     def moderate_query(self, query):
-         """Blocks inappropriate queries containing restricted words."""
-         BLOCKED_WORDS = re.compile(r"\b(hack|bypass|illegal|exploit|scam|kill|laundering|murder|suicide|self-harm)\b", re.IGNORECASE)
-         return not bool(BLOCKED_WORDS.search(query))

      def generate_answer(self, context, question):
-         messages = [
-             # {"role": "system", "content": "You are a financial assistant. Answer only finance-related questions. If the question is not related to finance, reply: 'I'm sorry, but I can only answer financial-related questions.' If the user greets you (e.g., 'Hello', 'Hi', 'Good morning'), respond politely with 'Hello! How can I assist you today?'."},
-             # {"role": "user", "content": f"{question} - related contect extracted form db {context}"}
-             {"role": "user", "content": f"""You are a financial assistant. If the user greets you (e.g., "Hello," "Hi," "Good morning"), respond politely with 'Hello! How can I assist you today? without requiring context.

  For financial-related questions, answer based on the context provided. If the context lacks information, say "I don't know."

  Context: {context}
  User Query: {question}
- Answer:"""}
-         ]
-
-         # Use Qwen's chat template
-         input_text = self.qwen_tokenizer.apply_chat_template(
-             messages, tokenize=False, add_generation_prompt=True
-         )
-
-         # Tokenize and move input to device
-         inputs = self.qwen_tokenizer([input_text], return_tensors="pt").to(self.device)
-         self.qwen_model.config.pad_token_id = self.qwen_tokenizer.eos_token_id
-
-         # Generate response
-         outputs = self.qwen_model.generate(
-             inputs.input_ids,
-             max_new_tokens=50,
-             pad_token_id=self.qwen_tokenizer.eos_token_id,
-         )

-         # Extract only the newly generated part
-         generated_ids = outputs[:, inputs.input_ids.shape[1]:]  # Remove prompt part
-         response = self.qwen_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

-         return response

-
-     # def generate_answer(self, context, question):
-     #     prompt = f"""
-     #     You are a financial assistant. If the user greets you (e.g., "Hello," "Hi," "Good morning"), respond politely with 'Hello! How can I assist you today? without requiring context.

-     #     For financial-related questions, answer based on the context provided. If the context lacks information, say "I don't know."

-     #     Context: {context}
-     #     User Query: {question}
-     #     Answer:
-     #     """

-     #     input_text = prompt
-     #     # f"Context: {context}\nQuestion: {question}\nAnswer:"
-     #     inputs = self.qwen_tokenizer.encode(input_text, return_tensors="pt")
-     #     # outputs = self.qwen_model.generate(inputs, max_length=100)
-     #     outputs = self.qwen_model.generate(inputs, max_new_tokens=100)
-     #     generated_ids = outputs[:, inputs.shape[1]:]  # Remove prompt part
-     #     response = self.qwen_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
-     #     return response

-     def get_answer(self, query):
-         """Main function to process a user query and return an answer."""
-
-         # Check if query is appropriate
-         if not self.moderate_query(query):
-             return "Inappropriate request.", 0.0

-         # Retrieve relevant documents and their confidence scores
-         retrieved_docs, confidences = self.query_faiss(query)
-         if not retrieved_docs:
-             return "No relevant information found.", 0.0

-         # Combine retrieved documents as context
-         context = " ".join(retrieved_docs)
-         avg_confidence = round(sum(confidences) / len(confidences), 2)

-         # Generate model response
-         model_response = self.generate_answer(context, query)

-         # Extract only the relevant part of the response
-         model_response = model_response.strip()
-
-         # Ensure only the actual answer is returned
-         if model_response.lower() in ["i don't know", "no relevant information found"]:
-             return "I don't know.", avg_confidence
-         # print(avg_confidence)
-         if avg_confidence == 0.0:
-             return "Not relevant ", avg_confidence

-
-         return model_response, avg_confidence

+ # import time
+ import threading
  import pandas as pd
  import faiss
  import numpy as np
+ # import numpy as np
+ import pickle
  from sentence_transformers import SentenceTransformer
  from transformers import AutoModelForCausalLM, AutoTokenizer
+ # import torch

  class FinancialChatbot:
+     def __init__(self, data_path, model_name="all-MiniLM-L6-v2", qwen_model_name="Qwen/Qwen2.5-1.5b"):
+         self.data_path = data_path
+         self.sbert_model = SentenceTransformer(model_name)
+         self.index_map = {}
+         self.faiss_index = None
+         # def get_device_map() -> str:
+         #     return 'cuda' if torch.cuda.is_available() else ''
+
+         # device = get_device_map()
+         self.qwen_model = AutoModelForCausalLM.from_pretrained(qwen_model_name, torch_dtype="auto", device_map="cpu", trust_remote_code=True)
          self.qwen_tokenizer = AutoTokenizer.from_pretrained(qwen_model_name, trust_remote_code=True)
          self.load_or_create_index()
+
      def load_or_create_index(self):
+         try:
+             self.faiss_index = faiss.read_index("financial_faiss.index")
+             with open("index_map.pkl", "rb") as f:
+                 self.index_map = pickle.load(f)
+             print("Index loaded successfully!")
+         except Exception:  # Index missing or unreadable: rebuild from the source spreadsheet
+             print("Creating new FAISS index...")
+             df = pd.read_excel(self.data_path)
+             sentences = []
+             for index, row in df.iterrows():
+                 for col in df.columns[1:]:
+                     text = f"{row[df.columns[0]]} - year {col} is: {row[col]}"
+                     sentences.append(text)
+                     self.index_map[len(sentences) - 1] = text
+             embeddings = self.sbert_model.encode(sentences, convert_to_numpy=True)
+             dim = embeddings.shape[1]
+             self.faiss_index = faiss.IndexFlatL2(dim)
+             self.faiss_index.add(embeddings)
+             faiss.write_index(self.faiss_index, "financial_faiss.index")
+             with open("index_map.pkl", "wb") as f:
+                 pickle.dump(self.index_map, f)
+             print("Indexing completed!")
+
+     # def query_faiss(self, query, top_k=5):
+     #     query_embedding = self.sbert_model.encode([query], convert_to_numpy=True)
+     #     distances, indices = self.faiss_index.search(query_embedding, top_k)
+     #     return [self.index_map[idx] for idx in indices[0] if idx in self.index_map]
+
+     def query_faiss(self, query, top_k=5):
+         """Retrieve top-k documents from FAISS and return confidence scores."""
          query_embedding = self.sbert_model.encode([query], convert_to_numpy=True)
          distances, indices = self.faiss_index.search(query_embedding, top_k)

+         results = []
+         confidences = []

+         if len(distances[0]) > 0:
+             max_dist = np.max(distances[0]) if np.max(distances[0]) != 0 else 1  # Avoid division by zero

+             for idx, dist in zip(indices[0], distances[0]):
+                 if idx in self.index_map:
+                     results.append(self.index_map[idx])
+                     confidence = 1 - (dist / max_dist)  # Normalize confidence (closer to 1 is better)
+                     confidences.append(round(confidence, 2))  # Round for clarity

+         return results, confidences
+
+     def moderate_query(self, query):
+         BLOCKED_WORDS = ["hack", "bypass", "illegal", "exploit", "scam", "kill", "laundering", "murder", "suicide", "self-harm"]
+         return not any(word in query.lower() for word in BLOCKED_WORDS)
+
      def generate_answer(self, context, question):
+         prompt = f"""
+ You are a financial assistant. If the user greets you (e.g., "Hello," "Hi," "Good morning"), respond politely without requiring context.

  For financial-related questions, answer based on the context provided. If the context lacks information, say "I don't know."

  Context: {context}
  User Query: {question}
+ Answer:
+ """
+
+         input_text = prompt
+         # f"Context: {context}\nQuestion: {question}\nAnswer:"
+         inputs = self.qwen_tokenizer.encode(input_text, return_tensors="pt")
+         outputs = self.qwen_model.generate(inputs, max_new_tokens=100)  # Cap generated tokens; max_length would also count the (longer) prompt
+         return self.qwen_tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+     # def get_answer(self, query, timeout=150):
+     #     result = ["", 0.0]  # Placeholder for answer and confidence
+
+     #     def task():
+     #         if self.moderate_query(query):
+     #             retrieved_docs = self.query_faiss(query)
+     #             context = " ".join(retrieved_docs)
+     #             answer = self.generate_answer(context, query)
+     #             last_index = answer.rfind("Answer")
+     #             if answer[last_index+9:11] == "--":
+     #                 result[:] = ["No relevant information found", 0.0]
+     #             else:
+     #                 result[:] = [answer[last_index:], 0.9]
+     #         else:
+     #             result[:] = ["I'm unable to process your request due to inappropriate language.", 0.0]
+
+     #     thread = threading.Thread(target=task)
+     #     thread.start()
+     #     thread.join(timeout)
+     #     if thread.is_alive():
+     #         return "Execution exceeded time limit. Stopping function.", 0.0
+     #     return tuple(result)

+     def get_answer(self, query, timeout=150):
+         """Retrieve the best-matched answer along with a confidence score, with an execution timeout."""

+         result = ["Execution exceeded time limit. Stopping function.", 0.0]  # Default timeout response

+         def task():
+             """Processing function to retrieve and generate the answer."""
+             if self.moderate_query(query):
+                 retrieved_docs, confidences = self.query_faiss(query)  # Get results + confidence scores

+                 if not retrieved_docs:  # If no relevant docs found
+                     result[:] = ["No relevant information found", 0.0]
+                     return

+                 # Combine retrieved docs and calculate final confidence
+                 context = " ".join(retrieved_docs)
+                 avg_confidence = round(sum(confidences) / len(confidences), 2)  # Avg confidence

+                 answer = self.generate_answer(context, query)
+                 last_index = answer.rfind("Answer")

+                 if answer[last_index + 9 : last_index + 11] == "--":  # "--" right after "Answer:" signals no information
+                     result[:] = ["No relevant information found", 0.0]
+                 else:
+                     result[:] = [answer[last_index:], avg_confidence]

+             else:
+                 result[:] = ["I'm unable to process your request due to inappropriate language.", 0.0]

+         # Start execution in a separate thread
+         thread = threading.Thread(target=task)
+         thread.start()
+         thread.join(timeout)  # Wait for execution up to timeout

+         # If thread is still running after timeout, return timeout message
+         if thread.is_alive():
+             return "Execution exceeded time limit. Stopping function.", 0.0

+         return tuple(result)

+ # if __name__ == "__main__":
+ #     chatbot = FinancialChatbot("C:\\Users\\Dell\\Downloads\\CAI_RAG\\DATA\\Nestle_Financtial_report_till2023.xlsx")
+ #     query = "What is the Employees Cost in Dec'20?"
+ #     print(chatbot.get_answer(query))
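
For reviewers, a minimal usage sketch of the class as committed above. This is an illustration under stated assumptions, not part of the commit: the module is assumed importable as rag, the spreadsheet path is a placeholder (first column holds labels, remaining columns hold years, as the indexing loop expects), and the SBERT and Qwen weights are downloaded from the Hub on first use.

    # Hypothetical driver script; the path and queries are placeholders.
    from rag import FinancialChatbot

    # First run builds financial_faiss.index and index_map.pkl next to the
    # script; later runs reload them via load_or_create_index().
    bot = FinancialChatbot("data/financial_report.xlsx")

    # get_answer returns an (answer_text, confidence) tuple. The confidence is
    # the mean of the per-hit scores 1 - dist/max_dist from query_faiss, so
    # top-k L2 distances [0.2, 0.5, 1.0] yield confidences [0.8, 0.5, 0.0].
    answer, confidence = bot.get_answer("What is the Employees Cost in Dec'20?")
    print(answer, confidence)

    # Queries containing blocked words short-circuit before retrieval:
    print(bot.get_answer("how to hack a bank"))
    # -> ("I'm unable to process your request due to inappropriate language.", 0.0)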