melk2025 committed on
Commit
8c3f7aa
·
verified ·
1 Parent(s): 6262323

Added BM25 reranking + history-aware retrieval

Browse files
Files changed (1) hide show
  1. app.py +23 -5
app.py CHANGED
@@ -68,7 +68,7 @@ for idx, row in df.iterrows():
68
  )
69
 
70
  # ---------------------- Config ----------------------
71
- SIMILARITY_THRESHOLD = 0.80
72
  client1 = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=API_KEY) # Replace with your OpenRouter API key
73
 
74
  # ---------------------- Models ----------------------
@@ -81,6 +81,19 @@ with open("qa.json", "r", encoding="utf-8") as f:
81
  qa_questions = list(qa_data.keys())
82
  qa_answers = list(qa_data.values())
83
  qa_embeddings = semantic_model.encode(qa_questions, convert_to_tensor=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  # ---------------------- History-Aware CAG ----------------------
86
  def retrieve_from_cag(user_query, chat_history):
@@ -97,18 +110,19 @@ def retrieve_from_cag(user_query, chat_history):
97
 
98
  # ---------------------- History-Aware RAG ----------------------
99
  def retrieve_from_rag(user_query, chat_history):
100
- # Combine the previous chat history with the current query for context
101
  history_context = " ".join([f"User: {msg[0]} Bot: {msg[1]}" for msg in chat_history]) + " "
102
  full_query = history_context + user_query
103
 
104
  print("Searching in RAG with history context...")
105
 
106
  query_embedding = embedding_model.encode(full_query)
107
- results = collection.query(query_embeddings=[query_embedding], n_results=3)
108
 
109
  if not results or not results.get('documents'):
110
  return None
111
 
 
112
  documents = []
113
  for i, content in enumerate(results['documents'][0]):
114
  metadata = results['metadatas'][0][i]
@@ -116,8 +130,12 @@ def retrieve_from_rag(user_query, chat_history):
116
  "content": content.strip(),
117
  "metadata": metadata
118
  })
119
- print("Documents retrieved:", documents)
120
- return documents
 
 
 
 
121
 
122
  # ---------------------- Generation function (OpenRouter) ----------------------
123
  def generate_via_openrouter(context, query, chat_history=None):
 
68
  )
69
 
70
  # ---------------------- Config ----------------------
71
+ SIMILARITY_THRESHOLD = 0.75
72
  client1 = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=API_KEY) # Replace with your OpenRouter API key
73
 
74
  # ---------------------- Models ----------------------
 
81
  qa_questions = list(qa_data.keys())
82
  qa_answers = list(qa_data.values())
83
  qa_embeddings = semantic_model.encode(qa_questions, convert_to_tensor=True)
84
# ------------------------- BM25 reranking -------------------------
from rank_bm25 import BM25Okapi
from nltk.tokenize import word_tokenize

def rerank_with_bm25(docs, query, top_n=3):
    """Rerank retrieved documents by BM25 relevance to the raw user query.

    Parameters
    ----------
    docs : list[dict]
        Retrieved documents; each dict must have a 'content' key (str).
    query : str
        The raw user query (reranking uses the query alone, without the
        chat-history context used for the embedding search).
    top_n : int, optional
        Number of top-scoring documents to return. Defaults to 3,
        preserving the original hard-coded behavior.

    Returns
    -------
    list[dict]
        The ``top_n`` documents ordered by descending BM25 score.
        Returns [] for empty input — BM25Okapi would otherwise raise
        ZeroDivisionError on an empty corpus (average doc length is 0).
    """
    if not docs:
        return []

    # Tokenize corpus and query the same way (lowercased word tokens)
    # so BM25 term matching is case-insensitive.
    tokenized_docs = [word_tokenize(doc['content'].lower()) for doc in docs]
    bm25 = BM25Okapi(tokenized_docs)
    tokenized_query = word_tokenize(query.lower())

    scores = bm25.get_scores(tokenized_query)
    # Indices of the highest-scoring docs, best first.
    top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_n]
    return [docs[i] for i in top_indices]
96
+
97
 
98
  # ---------------------- History-Aware CAG ----------------------
99
  def retrieve_from_cag(user_query, chat_history):
 
110
 
111
  # ---------------------- History-Aware RAG ----------------------
112
  def retrieve_from_rag(user_query, chat_history):
113
+ # Combine history with current query
114
  history_context = " ".join([f"User: {msg[0]} Bot: {msg[1]}" for msg in chat_history]) + " "
115
  full_query = history_context + user_query
116
 
117
  print("Searching in RAG with history context...")
118
 
119
  query_embedding = embedding_model.encode(full_query)
120
+ results = collection.query(query_embeddings=[query_embedding], n_results=5) # Get top 5 first
121
 
122
  if not results or not results.get('documents'):
123
  return None
124
 
125
+ # Build docs list
126
  documents = []
127
  for i, content in enumerate(results['documents'][0]):
128
  metadata = results['metadatas'][0][i]
 
130
  "content": content.strip(),
131
  "metadata": metadata
132
  })
133
+
134
+ # Rerank with BM25
135
+ top_docs = rerank_with_bm25(documents, user_query)
136
+
137
+ print("BM25-selected top 3 documents:", top_docs)
138
+ return top_docs
139
 
140
  # ---------------------- Generation function (OpenRouter) ----------------------
141
  def generate_via_openrouter(context, query, chat_history=None):