SatyamD31 anejaprerna commited on
Commit
08c0aaf
·
verified ·
1 Parent(s): 8427e43

Update rag.py (#3)

Browse files

- Update rag.py (a073b2501d9a8c18427d0a55cdbccbad30540ff5)


Co-authored-by: Prerna Aneja <[email protected]>

Files changed (1) hide show
  1. rag.py +72 -60
rag.py CHANGED
@@ -1,119 +1,131 @@
1
  import faiss
2
- import numpy as np
3
  import pickle
4
  import threading
5
  import time
6
  import torch
7
- import pandas as pd
8
-
9
- from sentence_transformers import SentenceTransformer
10
  from transformers import AutoModelForCausalLM, AutoTokenizer
11
  from rank_bm25 import BM25Okapi
 
 
12
 
13
  class FinancialChatbot:
14
  def __init__(self):
15
- # Load financial dataset
16
- self.df = pd.read_excel("Nestle_Financtial_report_till2023.xlsx")
17
-
18
- # Load embedding model
19
- self.sbert_model = SentenceTransformer("all-MiniLM-L6-v2")
20
-
21
  # Load FAISS index
22
  self.faiss_index = faiss.read_index("financial_faiss.index")
23
  with open("index_map.pkl", "rb") as f:
24
  self.index_map = pickle.load(f)
 
 
 
 
 
 
 
 
25
 
26
- # BM25 Indexing
27
- self.documents = [" ".join(row) for row in self.df.astype(str).values]
28
- self.tokenized_docs = [doc.split() for doc in self.documents]
29
- self.bm25 = BM25Okapi(self.tokenized_docs)
 
 
30
 
31
- # Load Qwen model
32
- self.qwen_model_name = "Qwen/Qwen2.5-1.5b"
33
- self.qwen_model = AutoModelForCausalLM.from_pretrained(self.qwen_model_name, torch_dtype="auto", device_map="auto", trust_remote_code=True)
34
- self.qwen_tokenizer = AutoTokenizer.from_pretrained(self.qwen_model_name, trust_remote_code=True)
 
 
35
 
36
- # Guardrail: Blocked words
37
- self.BLOCKED_WORDS = ["hack", "bypass", "illegal", "scam", "terrorism", "attack", "suicide", "bomb"]
38
 
39
  def moderate_query(self, query):
40
- """Check if the query contains blocked words."""
41
- return not any(word in query.lower() for word in self.BLOCKED_WORDS)
 
 
 
 
42
 
43
  def query_faiss(self, query, top_k=5):
44
- """Retrieve top K relevant documents using FAISS."""
45
  query_embedding = self.sbert_model.encode([query], convert_to_numpy=True)
46
  distances, indices = self.faiss_index.search(query_embedding, top_k)
47
-
48
  results = []
49
- confidences = []
 
50
  for idx, dist in zip(indices[0], distances[0]):
51
  if idx in self.index_map:
 
52
  results.append(self.index_map[idx])
53
- confidences.append(1 / (1 + dist)) # Convert distance to confidence
54
-
55
- return results, confidences
56
 
57
  def query_bm25(self, query, top_k=5):
58
- """Retrieve top K relevant documents using BM25."""
59
- tokenized_query = query.split()
60
  scores = self.bm25.get_scores(tokenized_query)
61
- top_indices = np.argsort(scores)[-top_k:][::-1]
62
-
63
- results = [self.documents[i] for i in top_indices]
64
- confidences = [scores[i] / max(scores) for i in top_indices] # Normalize confidence
65
 
66
- return results, confidences
 
 
 
 
 
 
 
 
67
 
68
  def generate_answer(self, context, question):
69
- """Generate answer using Qwen model."""
70
  input_text = f"Context: {context}\nQuestion: {question}\nAnswer:"
71
  inputs = self.qwen_tokenizer.encode(input_text, return_tensors="pt")
72
  outputs = self.qwen_model.generate(inputs, max_length=100)
73
  return self.qwen_tokenizer.decode(outputs[0], skip_special_tokens=True)
74
 
75
  def get_answer(self, query, timeout=200):
76
- """Fetch an answer using multi-step retrieval and Qwen model, with timeout handling."""
77
  result = ["No relevant information found", 0.0] # Default response
78
-
79
  def task():
80
- # Handle greetings
81
  if query.lower() in ["hi", "hello", "hey"]:
82
  result[:] = ["Hi, how can I help you?", 1.0]
83
  return
84
 
85
- # Guardrail check
86
  if not self.moderate_query(query):
87
- result[:] = ["I'm unable to process your request due to inappropriate language.", 0.0]
88
  return
89
-
90
- # Multi-step retrieval (BM25 + FAISS)
91
- bm25_results, bm25_confidences = self.query_bm25(query, top_k=3)
92
- faiss_results, faiss_confidences = self.query_faiss(query, top_k=3)
93
 
94
- retrieved_docs = bm25_results + faiss_results
95
- confidences = bm25_confidences + faiss_confidences
96
 
97
- if not retrieved_docs:
98
- return # Default response already set
99
 
100
- # Construct context
101
- context = " ".join(retrieved_docs)
 
 
 
 
102
  answer = self.generate_answer(context, query)
103
- last_index = answer.rfind("Answer")
104
 
105
- # Confidence calculation
106
- final_confidence = max(confidences) if confidences else 0.0
107
  if answer[last_index+9:11] == "--":
108
- result[:] = ["No relevant information found", 0.0]
109
  else:
110
- result[:] = [answer[last_index:], final_confidence]
111
 
112
- # Run task with timeout
113
  thread = threading.Thread(target=task)
114
  thread.start()
115
  thread.join(timeout)
 
116
  if thread.is_alive():
117
- return "Execution exceeded time limit. Stopping function.", 0.0
118
-
119
- return tuple(result)
 
1
  import faiss
 
2
  import pickle
3
  import threading
4
  import time
5
  import torch
6
+ import numpy as np
 
 
7
  from transformers import AutoModelForCausalLM, AutoTokenizer
8
  from rank_bm25 import BM25Okapi
9
+ from sentence_transformers import SentenceTransformer
10
+ from sklearn.metrics.pairwise import cosine_similarity
11
 
12
  class FinancialChatbot:
13
  def __init__(self):
 
 
 
 
 
 
14
  # Load FAISS index
15
  self.faiss_index = faiss.read_index("financial_faiss.index")
16
  with open("index_map.pkl", "rb") as f:
17
  self.index_map = pickle.load(f)
18
+
19
+ # Load BM25 keyword-based search
20
+ with open("bm25_corpus.pkl", "rb") as f:
21
+ self.bm25_corpus = pickle.load(f)
22
+ self.bm25 = BM25Okapi(self.bm25_corpus)
23
+
24
+ # Load SentenceTransformer for embedding-based retrieval
25
+ self.sbert_model = SentenceTransformer("all-MiniLM-L6-v2")
26
 
27
+ # Load Qwen Model
28
+ model_name = "Qwen/Qwen2.5-1.5b"
29
+ self.qwen_model = AutoModelForCausalLM.from_pretrained(
30
+ model_name, torch_dtype="auto", device_map="auto", trust_remote_code=True
31
+ )
32
+ self.qwen_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
33
 
34
+ # Guardrail: Blocked Words
35
+ self.BLOCKED_WORDS = [
36
+ "hack", "bypass", "illegal", "exploit", "scam", "kill", "laundering",
37
+ "murder", "suicide", "self-harm", "assault", "bomb", "terrorism",
38
+ "attack", "genocide", "mass shooting", "credit card number"
39
+ ]
40
 
41
+ # Relevance threshold
42
+ self.min_similarity_threshold = 0.2
43
 
44
  def moderate_query(self, query):
45
+ """Check if the query contains inappropriate words."""
46
+ query_lower = query.lower()
47
+ for word in self.BLOCKED_WORDS:
48
+ if word in query_lower:
49
+ return False # Block query
50
+ return True # Allow query
51
 
52
  def query_faiss(self, query, top_k=5):
53
+ """Retrieve relevant documents using FAISS and compute confidence scores."""
54
  query_embedding = self.sbert_model.encode([query], convert_to_numpy=True)
55
  distances, indices = self.faiss_index.search(query_embedding, top_k)
56
+
57
  results = []
58
+ confidence_scores = []
59
+
60
  for idx, dist in zip(indices[0], distances[0]):
61
  if idx in self.index_map:
62
+ similarity = 1 / (1 + dist) # Convert L2 distance to similarity
63
  results.append(self.index_map[idx])
64
+ confidence_scores.append(similarity)
65
+
66
+ return results, confidence_scores
67
 
68
  def query_bm25(self, query, top_k=5):
69
+ """Retrieve relevant documents using BM25 keyword-based search."""
70
+ tokenized_query = query.lower().split()
71
  scores = self.bm25.get_scores(tokenized_query)
72
+ top_indices = np.argsort(scores)[::-1][:top_k]
 
 
 
73
 
74
+ results = []
75
+ confidence_scores = []
76
+
77
+ for idx in top_indices:
78
+ if scores[idx] > 0: # Ignore zero-score matches
79
+ results.append(self.bm25_corpus[idx])
80
+ confidence_scores.append(scores[idx])
81
+
82
+ return results, confidence_scores
83
 
84
  def generate_answer(self, context, question):
85
+ """Generate answer using the Qwen model."""
86
  input_text = f"Context: {context}\nQuestion: {question}\nAnswer:"
87
  inputs = self.qwen_tokenizer.encode(input_text, return_tensors="pt")
88
  outputs = self.qwen_model.generate(inputs, max_length=100)
89
  return self.qwen_tokenizer.decode(outputs[0], skip_special_tokens=True)
90
 
91
  def get_answer(self, query, timeout=200):
92
+ """Fetch an answer from FAISS and Qwen model while handling timeouts."""
93
  result = ["No relevant information found", 0.0] # Default response
94
+
95
  def task():
 
96
  if query.lower() in ["hi", "hello", "hey"]:
97
  result[:] = ["Hi, how can I help you?", 1.0]
98
  return
99
 
 
100
  if not self.moderate_query(query):
101
+ result[:] = ["I'm unable to process your request due to inappropriate language.", 1.0]
102
  return
 
 
 
 
103
 
104
+ faiss_results, faiss_conf = self.query_faiss(query)
105
+ bm25_results, bm25_conf = self.query_bm25(query)
106
 
107
+ all_results = faiss_results + bm25_results
108
+ all_conf = faiss_conf + bm25_conf
109
 
110
+ # Check relevance
111
+ if not all_results or max(all_conf, default=0) < self.min_similarity_threshold:
112
+ result[:] = ["No relevant information found", 1.0]
113
+ return
114
+
115
+ context = " ".join(all_results)
116
  answer = self.generate_answer(context, query)
 
117
 
118
+ last_index = answer.rfind("Answer")
 
119
  if answer[last_index+9:11] == "--":
120
+ result[:] = ["No relevant information found", 1.0]
121
  else:
122
+ result[:] = [answer[last_index:], max(all_conf, default=0.9)]
123
 
 
124
  thread = threading.Thread(target=task)
125
  thread.start()
126
  thread.join(timeout)
127
+
128
  if thread.is_alive():
129
+ return "No relevant information found", 1.0 # Timeout case
130
+
131
+ return tuple(result)