Spaces:
Sleeping
Sleeping
Update rag.py (#4)
Browse files
- Update rag.py (730471acfb95df5fc45e201bec1d5d29d7130a32)
Co-authored-by: Prerna Aneja <[email protected]>
rag.py
CHANGED
@@ -1,13 +1,12 @@
|
|
1 |
import faiss
|
|
|
2 |
import pickle
|
3 |
import threading
|
4 |
import time
|
5 |
import torch
|
6 |
-
|
7 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
8 |
from rank_bm25 import BM25Okapi
|
9 |
-
from sentence_transformers import SentenceTransformer
|
10 |
-
from sklearn.metrics.pairwise import cosine_similarity
|
11 |
|
12 |
class FinancialChatbot:
|
13 |
def __init__(self):
|
@@ -15,10 +14,12 @@ class FinancialChatbot:
|
|
15 |
self.faiss_index = faiss.read_index("financial_faiss.index")
|
16 |
with open("index_map.pkl", "rb") as f:
|
17 |
self.index_map = pickle.load(f)
|
18 |
-
|
19 |
-
#
|
20 |
-
|
21 |
-
|
|
|
|
|
22 |
self.bm25 = BM25Okapi(self.bm25_corpus)
|
23 |
|
24 |
# Load SentenceTransformer for embedding-based retrieval
|
@@ -66,7 +67,7 @@ class FinancialChatbot:
|
|
66 |
return results, confidence_scores
|
67 |
|
68 |
def query_bm25(self, query, top_k=5):
|
69 |
-
"""Retrieve relevant documents using BM25 keyword-based search."""
|
70 |
tokenized_query = query.lower().split()
|
71 |
scores = self.bm25.get_scores(tokenized_query)
|
72 |
top_indices = np.argsort(scores)[::-1][:top_k]
|
@@ -76,7 +77,7 @@ class FinancialChatbot:
|
|
76 |
|
77 |
for idx in top_indices:
|
78 |
if scores[idx] > 0: # Ignore zero-score matches
|
79 |
-
results.append(self.
|
80 |
confidence_scores.append(scores[idx])
|
81 |
|
82 |
return results, confidence_scores
|
@@ -98,7 +99,7 @@ class FinancialChatbot:
|
|
98 |
return
|
99 |
|
100 |
if not self.moderate_query(query):
|
101 |
-
result[:] = ["I'm unable to process your request due to inappropriate language.",
|
102 |
return
|
103 |
|
104 |
faiss_results, faiss_conf = self.query_faiss(query)
|
@@ -107,25 +108,28 @@ class FinancialChatbot:
|
|
107 |
all_results = faiss_results + bm25_results
|
108 |
all_conf = faiss_conf + bm25_conf
|
109 |
|
110 |
-
# Check
|
111 |
if not all_results or max(all_conf, default=0) < self.min_similarity_threshold:
|
112 |
-
result[:] = ["No relevant information found",
|
113 |
return
|
114 |
|
115 |
context = " ".join(all_results)
|
116 |
answer = self.generate_answer(context, query)
|
117 |
|
118 |
last_index = answer.rfind("Answer")
|
119 |
-
|
120 |
-
|
|
|
|
|
|
|
121 |
else:
|
122 |
-
result[:] = [
|
123 |
|
124 |
thread = threading.Thread(target=task)
|
125 |
thread.start()
|
126 |
thread.join(timeout)
|
127 |
|
128 |
if thread.is_alive():
|
129 |
-
return "No relevant information found",
|
130 |
|
131 |
-
return tuple(result)
|
|
|
1 |
import faiss
|
2 |
+
import numpy as np
|
3 |
import pickle
|
4 |
import threading
|
5 |
import time
|
6 |
import torch
|
7 |
+
from sentence_transformers import SentenceTransformer
|
8 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
9 |
from rank_bm25 import BM25Okapi
|
|
|
|
|
10 |
|
11 |
class FinancialChatbot:
|
12 |
def __init__(self):
|
|
|
14 |
self.faiss_index = faiss.read_index("financial_faiss.index")
|
15 |
with open("index_map.pkl", "rb") as f:
|
16 |
self.index_map = pickle.load(f)
|
17 |
+
|
18 |
+
# Extract document texts for BM25 dynamically
|
19 |
+
self.documents = list(self.index_map.values())
|
20 |
+
|
21 |
+
# Build BM25 index dynamically
|
22 |
+
self.bm25_corpus = [doc.lower().split() for doc in self.documents] # Tokenization
|
23 |
self.bm25 = BM25Okapi(self.bm25_corpus)
|
24 |
|
25 |
# Load SentenceTransformer for embedding-based retrieval
|
|
|
67 |
return results, confidence_scores
|
68 |
|
69 |
def query_bm25(self, query, top_k=5):
|
70 |
+
"""Retrieve relevant documents using BM25 keyword-based search dynamically."""
|
71 |
tokenized_query = query.lower().split()
|
72 |
scores = self.bm25.get_scores(tokenized_query)
|
73 |
top_indices = np.argsort(scores)[::-1][:top_k]
|
|
|
77 |
|
78 |
for idx in top_indices:
|
79 |
if scores[idx] > 0: # Ignore zero-score matches
|
80 |
+
results.append(self.documents[idx])
|
81 |
confidence_scores.append(scores[idx])
|
82 |
|
83 |
return results, confidence_scores
|
|
|
99 |
return
|
100 |
|
101 |
if not self.moderate_query(query):
|
102 |
+
result[:] = ["I'm unable to process your request due to inappropriate language.", 0.0]
|
103 |
return
|
104 |
|
105 |
faiss_results, faiss_conf = self.query_faiss(query)
|
|
|
108 |
all_results = faiss_results + bm25_results
|
109 |
all_conf = faiss_conf + bm25_conf
|
110 |
|
111 |
+
# Give up when nothing was retrieved or the best confidence score is below the minimum similarity threshold
|
112 |
if not all_results or max(all_conf, default=0) < self.min_similarity_threshold:
|
113 |
+
result[:] = ["No relevant information found", 0.0]
|
114 |
return
|
115 |
|
116 |
context = " ".join(all_results)
|
117 |
answer = self.generate_answer(context, query)
|
118 |
|
119 |
last_index = answer.rfind("Answer")
|
120 |
+
extracted_answer = answer[last_index:].strip() if last_index != -1 else ""
|
121 |
+
|
122 |
+
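# Accept the output only if it contains an "Answer" marker and is not empty or purely numeric (a format check; it does not verify grounding in the context)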
|
123 |
+
if not extracted_answer or "Answer" not in answer or extracted_answer.isnumeric():
|
124 |
+
result[:] = ["No relevant information found", 0.0]
|
125 |
else:
|
126 |
+
result[:] = [extracted_answer, max(all_conf, default=0.9)]
|
127 |
|
128 |
thread = threading.Thread(target=task)
|
129 |
thread.start()
|
130 |
thread.join(timeout)
|
131 |
|
132 |
if thread.is_alive():
|
133 |
+
return "No relevant information found", 0.0 # Timeout case
|
134 |
|
135 |
+
return tuple(result)
|