SatyamD31 anejaprerna committed on
Commit
b381b95
·
verified ·
1 Parent(s): 08c0aaf

Update rag.py (#4)

Browse files

- Update rag.py (730471acfb95df5fc45e201bec1d5d29d7130a32)


Co-authored-by: Prerna Aneja <[email protected]>

Files changed (1) hide show
  1. rag.py +21 -17
rag.py CHANGED
@@ -1,13 +1,12 @@
1
  import faiss
 
2
  import pickle
3
  import threading
4
  import time
5
  import torch
6
- import numpy as np
7
  from transformers import AutoModelForCausalLM, AutoTokenizer
8
  from rank_bm25 import BM25Okapi
9
- from sentence_transformers import SentenceTransformer
10
- from sklearn.metrics.pairwise import cosine_similarity
11
 
12
  class FinancialChatbot:
13
  def __init__(self):
@@ -15,10 +14,12 @@ class FinancialChatbot:
15
  self.faiss_index = faiss.read_index("financial_faiss.index")
16
  with open("index_map.pkl", "rb") as f:
17
  self.index_map = pickle.load(f)
18
-
19
- # Load BM25 keyword-based search
20
- with open("bm25_corpus.pkl", "rb") as f:
21
- self.bm25_corpus = pickle.load(f)
 
 
22
  self.bm25 = BM25Okapi(self.bm25_corpus)
23
 
24
  # Load SentenceTransformer for embedding-based retrieval
@@ -66,7 +67,7 @@ class FinancialChatbot:
66
  return results, confidence_scores
67
 
68
  def query_bm25(self, query, top_k=5):
69
- """Retrieve relevant documents using BM25 keyword-based search."""
70
  tokenized_query = query.lower().split()
71
  scores = self.bm25.get_scores(tokenized_query)
72
  top_indices = np.argsort(scores)[::-1][:top_k]
@@ -76,7 +77,7 @@ class FinancialChatbot:
76
 
77
  for idx in top_indices:
78
  if scores[idx] > 0: # Ignore zero-score matches
79
- results.append(self.bm25_corpus[idx])
80
  confidence_scores.append(scores[idx])
81
 
82
  return results, confidence_scores
@@ -98,7 +99,7 @@ class FinancialChatbot:
98
  return
99
 
100
  if not self.moderate_query(query):
101
- result[:] = ["I'm unable to process your request due to inappropriate language.", 1.0]
102
  return
103
 
104
  faiss_results, faiss_conf = self.query_faiss(query)
@@ -107,25 +108,28 @@ class FinancialChatbot:
107
  all_results = faiss_results + bm25_results
108
  all_conf = faiss_conf + bm25_conf
109
 
110
- # Check relevance
111
  if not all_results or max(all_conf, default=0) < self.min_similarity_threshold:
112
- result[:] = ["No relevant information found", 1.0]
113
  return
114
 
115
  context = " ".join(all_results)
116
  answer = self.generate_answer(context, query)
117
 
118
  last_index = answer.rfind("Answer")
119
- if answer[last_index+9:11] == "--":
120
- result[:] = ["No relevant information found", 1.0]
 
 
 
121
  else:
122
- result[:] = [answer[last_index:], max(all_conf, default=0.9)]
123
 
124
  thread = threading.Thread(target=task)
125
  thread.start()
126
  thread.join(timeout)
127
 
128
  if thread.is_alive():
129
- return "No relevant information found", 1.0 # Timeout case
130
 
131
- return tuple(result)
 
1
  import faiss
2
+ import numpy as np
3
  import pickle
4
  import threading
5
  import time
6
  import torch
7
+ from sentence_transformers import SentenceTransformer
8
  from transformers import AutoModelForCausalLM, AutoTokenizer
9
  from rank_bm25 import BM25Okapi
 
 
10
 
11
  class FinancialChatbot:
12
  def __init__(self):
 
14
  self.faiss_index = faiss.read_index("financial_faiss.index")
15
  with open("index_map.pkl", "rb") as f:
16
  self.index_map = pickle.load(f)
17
+
18
+ # Extract document texts for BM25 dynamically
19
+ self.documents = list(self.index_map.values())
20
+
21
+ # Build BM25 index dynamically
22
+ self.bm25_corpus = [doc.lower().split() for doc in self.documents] # Tokenization
23
  self.bm25 = BM25Okapi(self.bm25_corpus)
24
 
25
  # Load SentenceTransformer for embedding-based retrieval
 
67
  return results, confidence_scores
68
 
69
  def query_bm25(self, query, top_k=5):
70
+ """Retrieve relevant documents using BM25 keyword-based search dynamically."""
71
  tokenized_query = query.lower().split()
72
  scores = self.bm25.get_scores(tokenized_query)
73
  top_indices = np.argsort(scores)[::-1][:top_k]
 
77
 
78
  for idx in top_indices:
79
  if scores[idx] > 0: # Ignore zero-score matches
80
+ results.append(self.documents[idx])
81
  confidence_scores.append(scores[idx])
82
 
83
  return results, confidence_scores
 
99
  return
100
 
101
  if not self.moderate_query(query):
102
+ result[:] = ["I'm unable to process your request due to inappropriate language.", 0.0]
103
  return
104
 
105
  faiss_results, faiss_conf = self.query_faiss(query)
 
108
  all_results = faiss_results + bm25_results
109
  all_conf = faiss_conf + bm25_conf
110
 
111
+ # Check if results are relevant
112
  if not all_results or max(all_conf, default=0) < self.min_similarity_threshold:
113
+ result[:] = ["No relevant information found", 0.0]
114
  return
115
 
116
  context = " ".join(all_results)
117
  answer = self.generate_answer(context, query)
118
 
119
  last_index = answer.rfind("Answer")
120
+ extracted_answer = answer[last_index:].strip() if last_index != -1 else ""
121
+
122
+ # Ensure the answer is grounded in the context
123
+ if not extracted_answer or "Answer" not in answer or extracted_answer.isnumeric():
124
+ result[:] = ["No relevant information found", 0.0]
125
  else:
126
+ result[:] = [extracted_answer, max(all_conf, default=0.9)]
127
 
128
  thread = threading.Thread(target=task)
129
  thread.start()
130
  thread.join(timeout)
131
 
132
  if thread.is_alive():
133
+ return "No relevant information found", 0.0 # Timeout case
134
 
135
+ return tuple(result)