Spaces:
Running
Running
Delete retrieval
Browse files- retrieval/hybrid_search.py +0 -49
- retrieval/keyword_filter.py +0 -41
- retrieval/vector_store.py +0 -34
retrieval/hybrid_search.py
DELETED
@@ -1,49 +0,0 @@
|
|
1 |
-
# src/retrieval/hybrid_search.py
|
2 |
-
|
3 |
-
import numpy as np
|
4 |
-
from rank_bm25 import BM25Okapi
|
5 |
-
from typing import List, Dict, Any
|
6 |
-
import pandas as pd
|
7 |
-
from src.embeddings.embedder import Embedder
|
8 |
-
from src.retrieval.vector_store import VectorStore
|
9 |
-
|
10 |
-
class HybridRetriever:
    """Rank rows of a DataFrame by a weighted blend of dense and BM25 scores.

    The dense signal comes from ``vector_store.query`` (scored as
    ``1 - distance``); the sparse signal comes from a ``BM25Okapi`` index built
    over the lower-cased, whitespace-tokenised ``description`` column of ``df``.
    ``alpha`` weights the dense part and ``1 - alpha`` the BM25 part.
    """

    def __init__(self, df: pd.DataFrame, vector_store: "VectorStore", embedder: "Embedder", alpha: float = 0.5):
        """Store the collaborators and build the BM25 index over ``df['description']``.

        Args:
            df: Table with at least ``id`` and ``description`` columns.
            vector_store: Object exposing ``query(embedding, top_k=...)`` that
                returns ``{'ids': [[...]], 'distances': [[...]]}``.
            embedder: Object exposing ``embed(list_of_texts)``.
            alpha: Weight of the dense score in the final blend.
        """
        self.df = df
        self.vector_store = vector_store
        self.embedder = embedder
        self.alpha = alpha
        # Missing descriptions become empty documents instead of crashing on
        # ``float.lower()``; BM25 row i still corresponds to df row position i.
        descriptions = df['description'].fillna('').astype(str)
        tokenized_corpus = [text.lower().split() for text in descriptions]
        self.bm25 = BM25Okapi(tokenized_corpus)

    @staticmethod
    def _normalise(scores: List[float]) -> List[float]:
        """Scale scores so the best is 1.0; leave them untouched if the peak is not positive."""
        if not scores:
            return []
        peak = max(scores)
        # Dividing by a zero peak produces NaN (and a negative peak would
        # invert the ranking), so only rescale when the peak is > 0.
        if not peak > 0:
            return list(scores)
        return [score / peak for score in scores]

    def retrieve(self, query: str, filtered_df: pd.DataFrame, top_k: int = 3) -> List[Dict[str, Any]]:
        """Return up to ``top_k`` rows of ``self.df`` (as dicts) best matching ``query``.

        Only rows whose ``id`` appears in ``filtered_df`` are eligible.

        Args:
            query: Free-text query.
            filtered_df: Pre-filtered subset of ``self.df`` (same index labels).
                NOTE(review): assumes ``self.df`` has unique index labels.
            top_k: Maximum number of rows to return.

        Returns:
            Row dicts ordered from highest to lowest combined score; an empty
            list when ``filtered_df`` has no rows.
        """
        if len(filtered_df) == 0:
            return []

        filtered_ids = filtered_df['id'].astype(str).tolist()
        allowed_ids = set(filtered_ids)  # O(1) membership for dense hits

        # Dense side: keep only hits that survived the keyword filter.
        query_embedding = self.embedder.embed([query])[0]
        dense_results = self.vector_store.query(query_embedding, top_k=top_k * 2)
        dense_hits = [
            (str(doc_id), 1 - dist)
            for doc_id, dist in zip(dense_results['ids'][0], dense_results['distances'][0])
            if str(doc_id) in allowed_ids
        ]

        # Sparse side: BM25 scores are positional over self.df, so translate
        # the filtered frame's index labels into positions first.
        corpus_scores = self.bm25.get_scores(query.lower().split())
        positions = self.df.index.get_indexer(filtered_df.index)
        sparse_candidates = [
            (doc_id, float(corpus_scores[pos]))
            for doc_id, pos in zip(filtered_ids, positions)
            if pos >= 0  # -1 means the row is not part of the indexed corpus
        ]
        sparse_candidates.sort(key=lambda pair: pair[1], reverse=True)
        sparse_hits = sparse_candidates[:top_k * 2]

        dense_scaled = self._normalise([score for _, score in dense_hits])
        sparse_scaled = self._normalise([score for _, score in sparse_hits])

        combined_scores: Dict[str, float] = {}
        for (doc_id, _), score in zip(dense_hits, dense_scaled):
            combined_scores[doc_id] = combined_scores.get(doc_id, 0.0) + self.alpha * score
        for (doc_id, _), score in zip(sparse_hits, sparse_scaled):
            combined_scores[doc_id] = combined_scores.get(doc_id, 0.0) + (1 - self.alpha) * score

        ranked_ids = sorted(combined_scores, key=combined_scores.get, reverse=True)[:top_k]

        # Compare ids as strings so integer and string id columns both work.
        id_column = self.df['id'].astype(str)
        results = []
        for doc_id in ranked_ids:
            matches = self.df[id_column == doc_id]
            if len(matches):
                results.append(matches.iloc[0].to_dict())
        return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
retrieval/keyword_filter.py
DELETED
@@ -1,41 +0,0 @@
|
|
1 |
-
# src/retrieval/keyword_filter.py
|
2 |
-
|
3 |
-
import pandas as pd
|
4 |
-
from typing import Dict, Any
|
5 |
-
|
6 |
-
def filter_restaurants(
    df: pd.DataFrame,
    parsed_query: Dict[str, Any],
    nearby_max_distance: float = 2.0,
    far_max_distance: float = 10.0,
) -> pd.DataFrame:
    """
    Filter restaurants based on extracted features from the query.

    Each filter is applied only when its key is present and truthy in
    ``parsed_query``; filters are combined with AND.

    Args:
        df (pd.DataFrame): Restaurant data. Columns are read only for the
            filters in use: ``cuisine``, ``dishes``, ``price_range``,
            ``distance``, ``rating``.
        parsed_query (Dict[str, Any]): Parsed query with features:
            ``cuisine`` / ``price_range`` - case-insensitive exact match;
            ``menu`` - one dish name or an iterable of names, a row matches
            when it offers at least one of them (case-insensitive);
            ``distance`` - a number (upper bound) or one of the keywords
            ``"nearby"``, ``"close"``, ``"far"``;
            ``rating`` - minimum rating.
        nearby_max_distance (float): Upper bound used for "nearby"/"close".
        far_max_distance (float): Upper bound used for "far".

    Returns:
        pd.DataFrame: Filtered DataFrame (a new object; ``df`` is not modified).
    """

    def _lowered_items(value: Any) -> set:
        """Return the lower-cased items of a dish collection as a set."""
        if isinstance(value, str):
            # A cell loaded from CSV may be one comma-separated string;
            # iterating it directly would compare single characters.
            return {part.strip().lower() for part in value.split(",") if part.strip()}
        try:
            return {str(item).lower() for item in value}
        except TypeError:
            # Non-iterable cell (e.g. NaN): the row offers no known dishes.
            return set()

    filtered_df = df.copy()

    cuisine = parsed_query.get("cuisine")
    if cuisine:
        filtered_df = filtered_df[filtered_df["cuisine"].str.lower() == str(cuisine).lower()]

    menu = parsed_query.get("menu")
    if menu:
        # A bare string is a single dish name, not a sequence of characters.
        wanted = {str(item).lower() for item in ([menu] if isinstance(menu, str) else menu)}
        # Force a boolean mask: on an already-empty frame ``apply`` may return
        # a non-boolean empty Series, which must not be used as a column key.
        has_wanted_dish = filtered_df["dishes"].apply(
            lambda dishes: not wanted.isdisjoint(_lowered_items(dishes))
        ).astype(bool)
        filtered_df = filtered_df[has_wanted_dish]

    price_range = parsed_query.get("price_range")
    if price_range:
        filtered_df = filtered_df[filtered_df["price_range"].str.lower() == str(price_range).lower()]

    distance = parsed_query.get("distance")
    if isinstance(distance, str):
        distance = distance.strip().lower()
    max_distance = None
    # bool is a subclass of int; True/False is not a meaningful distance.
    if isinstance(distance, (int, float)) and not isinstance(distance, bool):
        max_distance = distance
    elif distance in ("nearby", "close"):
        max_distance = nearby_max_distance
    elif distance == "far":
        max_distance = far_max_distance
    if max_distance is not None:
        filtered_df = filtered_df[filtered_df["distance"] <= max_distance]

    rating = parsed_query.get("rating")
    if rating:
        filtered_df = filtered_df[filtered_df["rating"] >= rating]

    return filtered_df
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
retrieval/vector_store.py
DELETED
@@ -1,34 +0,0 @@
|
|
1 |
-
# src/retrieval/vector_store.py
|
2 |
-
|
3 |
-
from langchain_community.vectorstores import Chroma
|
4 |
-
from langchain_core.documents import Document
|
5 |
-
import numpy as np
|
6 |
-
from typing import List, Dict, Any
|
7 |
-
|
8 |
-
class VectorStore:
    """Thin wrapper around a LangChain ``Chroma`` collection persisted in ``./chroma_db``."""

    def __init__(self, embedding_function):
        """Remember the embedding function; the collection is created by ``add_documents``."""
        self.embedding_function = embedding_function
        self.collection = None

    def add_documents(self, documents: List[str], embeddings: List[np.ndarray], ids: List[str]):
        """Create the Chroma collection from ``documents`` and persist it.

        Args:
            documents: Texts to index.
            embeddings: Unused here - Chroma embeds the texts itself through
                ``self.embedding_function``. Kept for interface compatibility.
            ids: One id per document; stored both as the Chroma id and in each
                document's ``metadata['id']``.

        Raises:
            ValueError: If ``documents`` and ``ids`` differ in length.
        """
        if len(documents) != len(ids):
            # zip() would silently drop the surplus and desynchronise the ids
            # passed to Chroma from the documents.
            raise ValueError(
                f"documents and ids must have the same length "
                f"({len(documents)} != {len(ids)})"
            )
        langchain_docs = [
            Document(page_content=text, metadata={"id": doc_id})
            for text, doc_id in zip(documents, ids)
        ]
        self.collection = Chroma.from_documents(
            documents=langchain_docs,
            embedding=self.embedding_function,
            ids=ids,
            persist_directory="./chroma_db"
        )
        self.collection.persist()

    def query(self, query_embedding: np.ndarray, top_k: int = 5) -> Dict[str, Any]:
        """Return the ids and distances of the ``top_k`` nearest documents.

        Args:
            query_embedding: Embedding vector of the query.
            top_k: Number of neighbours to request.

        Returns:
            ``{"ids": [[...]], "distances": [[...]]}`` - one inner list per
            query, in the order returned by the store.
            NOTE(review): the distance metric is whatever the Chroma collection
            was created with; confirm it is the one callers expect before
            interpreting ``1 - distance`` as a similarity.

        Raises:
            RuntimeError: If called before ``add_documents``.
        """
        if self.collection is None:
            raise RuntimeError("VectorStore.query() called before add_documents()")

        # Ask the store for the scores: the previous implementation looked for
        # a ``vector`` attribute on each result and reported a distance of 1.0
        # whenever it was missing, which carries no ranking information.
        scored_docs = self.collection.similarity_search_by_vector_with_relevance_scores(
            embedding=np.asarray(query_embedding, dtype=float).tolist(),
            k=top_k
        )
        ids = [doc.metadata["id"] for doc, _ in scored_docs]
        distances = [float(score) for _, score in scored_docs]
        return {
            "ids": [ids],
            "distances": [distances]
        }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|