Spaces:
Paused
Paused
""" | |
Reranker module with an added option to allow remote code execution (trust_remote_code) when loading the Cross-Encoder model.
""" | |
from typing import List, Dict, Tuple | |
import numpy as np | |
from sentence_transformers import CrossEncoder | |
from langchain.schema import Document | |
from config import RERANKER_MODEL | |
class Reranker:
    """Re-order retrieved documents by Cross-Encoder relevance to a query."""

    def __init__(self, model_name: str = RERANKER_MODEL, *, trust_remote_code: bool = True):
        """
        Initialize the Cross-Encoder reranker.

        Args:
            model_name: Name (or path) of the Cross-Encoder model to load.
            trust_remote_code: Forwarded to ``CrossEncoder``. When True, custom
                model code shipped with the checkpoint is downloaded and executed
                locally, so only enable it for model repositories you trust.
                Defaults to True because the original author marked it as
                required for the configured model.
        """
        print(f"리랭커 모델 로드 중: {model_name}")
        # Remote code execution is opt-out via the keyword argument instead of
        # being hard-coded, so callers loading a plain model can disable it.
        self.model = CrossEncoder(
            model_name,
            trust_remote_code=trust_remote_code,
        )
        print(f"리랭커 모델 로드 완료: {model_name}")

    def rerank(self, query: str, documents: List[Document], top_k: int = 3) -> List[Document]:
        """
        Re-sort vector-search results by Cross-Encoder score and keep the best.

        Args:
            query: The search query.
            documents: Documents returned by the vector search; each one's
                ``page_content`` is scored against ``query``.
            top_k: Maximum number of documents to return. Values <= 0 yield an
                empty list without invoking the model; ``None`` returns every
                document, ranked.

        Returns:
            Up to ``top_k`` documents ordered from highest to lowest score.
            Documents with equal scores keep their original relative order.
        """
        # Guard before scoring: a negative slice bound would otherwise drop
        # documents from the end of the ranking instead of limiting the count.
        if not documents or (top_k is not None and top_k <= 0):
            return []

        # Build (query, document text) pairs for the Cross-Encoder.
        query_doc_pairs = [(query, doc.page_content) for doc in documents]

        print(f"리랭킹 수행 중: {len(documents)}개 문서")
        # NOTE(review): assumes predict() returns one scalar score per pair
        # (single-label model) — TODO confirm for the configured model.
        scores = self.model.predict(query_doc_pairs)

        # sorted() is stable even with reverse=True, so ties keep their
        # vector-search order.
        ranked = sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)
        top_docs = [doc for doc, _score in ranked[:top_k]]

        # Report how many documents were actually selected (may be < top_k).
        print(f"리랭킹 완료: 상위 {len(top_docs)}개 문서 선택")
        return top_docs