Added BM25 reranking and chat-history-aware retrieval
Browse files
app.py
CHANGED
@@ -68,7 +68,7 @@ for idx, row in df.iterrows():
|
|
68 |
)
|
69 |
|
70 |
# ---------------------- Config ----------------------
|
71 |
-
SIMILARITY_THRESHOLD = 0.
|
72 |
client1 = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=API_KEY) # Replace with your OpenRouter API key
|
73 |
|
74 |
# ---------------------- Models ----------------------
|
@@ -81,6 +81,19 @@ with open("qa.json", "r", encoding="utf-8") as f:
|
|
81 |
qa_questions = list(qa_data.keys())
|
82 |
qa_answers = list(qa_data.values())
|
83 |
qa_embeddings = semantic_model.encode(qa_questions, convert_to_tensor=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
|
85 |
# ---------------------- History-Aware CAG ----------------------
|
86 |
def retrieve_from_cag(user_query, chat_history):
|
@@ -97,18 +110,19 @@ def retrieve_from_cag(user_query, chat_history):
|
|
97 |
|
98 |
# ---------------------- History-Aware RAG ----------------------
|
99 |
def retrieve_from_rag(user_query, chat_history):
|
100 |
-
# Combine
|
101 |
history_context = " ".join([f"User: {msg[0]} Bot: {msg[1]}" for msg in chat_history]) + " "
|
102 |
full_query = history_context + user_query
|
103 |
|
104 |
print("Searching in RAG with history context...")
|
105 |
|
106 |
query_embedding = embedding_model.encode(full_query)
|
107 |
-
results = collection.query(query_embeddings=[query_embedding], n_results=
|
108 |
|
109 |
if not results or not results.get('documents'):
|
110 |
return None
|
111 |
|
|
|
112 |
documents = []
|
113 |
for i, content in enumerate(results['documents'][0]):
|
114 |
metadata = results['metadatas'][0][i]
|
@@ -116,8 +130,12 @@ def retrieve_from_rag(user_query, chat_history):
|
|
116 |
"content": content.strip(),
|
117 |
"metadata": metadata
|
118 |
})
|
119 |
-
|
120 |
-
|
|
|
|
|
|
|
|
|
121 |
|
122 |
# ---------------------- Generation function (OpenRouter) ----------------------
|
123 |
def generate_via_openrouter(context, query, chat_history=None):
|
|
|
68 |
)
|
69 |
|
70 |
# ---------------------- Config ----------------------
|
71 |
+
SIMILARITY_THRESHOLD = 0.75
|
72 |
client1 = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=API_KEY) # Replace with your OpenRouter API key
|
73 |
|
74 |
# ---------------------- Models ----------------------
|
|
|
81 |
qa_questions = list(qa_data.keys())
|
82 |
qa_answers = list(qa_data.values())
|
83 |
qa_embeddings = semantic_model.encode(qa_questions, convert_to_tensor=True)
|
84 |
+
#-------------------------bm25---------------------------------
|
85 |
+
from rank_bm25 import BM25Okapi
|
86 |
+
from nltk.tokenize import word_tokenize
|
87 |
+
|
88 |
+
def rerank_with_bm25(docs, query, top_n=3):
    """Rerank retrieved documents by BM25 relevance to the raw user query.

    Args:
        docs: list of dicts, each carrying a 'content' string (plus any
            other keys such as 'metadata') as built by the RAG retrieval
            step that feeds this reranker.
        query: the user's query string (reranking uses the raw query,
            not the history-augmented one).
        top_n: number of top-scoring documents to keep. Defaults to 3,
            preserving the original hard-coded behavior.

    Returns:
        Up to ``top_n`` documents ordered by descending BM25 score.
        Returns ``[]`` for an empty ``docs`` list — ``BM25Okapi`` would
        otherwise raise ``ZeroDivisionError`` on an empty corpus.
    """
    # Guard: rank_bm25 divides by the corpus size while computing the
    # average document length, so an empty corpus crashes it.
    if not docs:
        return []
    tokenized_docs = [word_tokenize(doc['content'].lower()) for doc in docs]
    bm25 = BM25Okapi(tokenized_docs)
    tokenized_query = word_tokenize(query.lower())

    scores = bm25.get_scores(tokenized_query)
    # Indices of the highest-scoring documents, best first.
    top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_n]
    return [docs[i] for i in top_indices]
|
96 |
+
|
97 |
|
98 |
# ---------------------- History-Aware CAG ----------------------
|
99 |
def retrieve_from_cag(user_query, chat_history):
|
|
|
110 |
|
111 |
# ---------------------- History-Aware RAG ----------------------
|
112 |
def retrieve_from_rag(user_query, chat_history):
|
113 |
+
# Combine history with current query
|
114 |
history_context = " ".join([f"User: {msg[0]} Bot: {msg[1]}" for msg in chat_history]) + " "
|
115 |
full_query = history_context + user_query
|
116 |
|
117 |
print("Searching in RAG with history context...")
|
118 |
|
119 |
query_embedding = embedding_model.encode(full_query)
|
120 |
+
results = collection.query(query_embeddings=[query_embedding], n_results=5) # Get top 5 first
|
121 |
|
122 |
if not results or not results.get('documents'):
|
123 |
return None
|
124 |
|
125 |
+
# Build docs list
|
126 |
documents = []
|
127 |
for i, content in enumerate(results['documents'][0]):
|
128 |
metadata = results['metadatas'][0][i]
|
|
|
130 |
"content": content.strip(),
|
131 |
"metadata": metadata
|
132 |
})
|
133 |
+
|
134 |
+
# Rerank with BM25
|
135 |
+
top_docs = rerank_with_bm25(documents, user_query)
|
136 |
+
|
137 |
+
print("BM25-selected top 3 documents:", top_docs)
|
138 |
+
return top_docs
|
139 |
|
140 |
# ---------------------- Generation function (OpenRouter) ----------------------
|
141 |
def generate_via_openrouter(context, query, chat_history=None):
|