Spaces:
Paused
Paused
File size: 1,854 Bytes
1f59ca4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
"""
Cross-Encoder reranker module with an added remote-code-execution option
(``trust_remote_code``).

원격 코드 실행 옵션이 추가된 리랭커 모듈
"""
from typing import List, Dict, Tuple
import numpy as np
from sentence_transformers import CrossEncoder
from langchain.schema import Document
from config import RERANKER_MODEL
class Reranker:
    """Rerank retrieved documents with a sentence-transformers Cross-Encoder.

    Each (query, document text) pair is scored by the Cross-Encoder and the
    documents are returned in descending score order.
    """

    def __init__(
        self,
        model_name: str = RERANKER_MODEL,
        *,
        trust_remote_code: bool = True,
    ):
        """
        Initialize the Cross-Encoder reranker.

        Args:
            model_name: Name of the Cross-Encoder model to load
                (defaults to ``RERANKER_MODEL`` from ``config``).
            trust_remote_code: Forwarded to ``CrossEncoder``. When True, custom
                code shipped with the model repository is executed locally,
                so only enable it for models you trust. Defaults to True to
                preserve the previous behavior (the original code marked it
                as required for the configured model).
        """
        print(f"๋ฆฌ๋ญ์ปค ๋ชจ๋ธ ๋ก๋ ์ค: {model_name}")
        self.model = CrossEncoder(
            model_name,
            trust_remote_code=trust_remote_code,
        )
        print(f"๋ฆฌ๋ญ์ปค ๋ชจ๋ธ ๋ก๋ ์๋ฃ: {model_name}")

    def rerank(self, query: str, documents: List[Document], top_k: int = 3) -> List[Document]:
        """
        Reorder search results by Cross-Encoder relevance score.

        Args:
            query: Search query.
            documents: Documents returned by the vector search.
            top_k: Maximum number of top-scoring documents to return.

        Returns:
            Up to ``top_k`` documents, highest score first. Documents with
            equal scores keep their input order (the sort is stable).

        Raises:
            ValueError: If ``top_k`` is negative.
        """
        # A negative slice bound would silently drop documents from the end
        # of the ranking instead of selecting a top-k; reject it explicitly.
        if top_k < 0:
            raise ValueError(f"top_k must be non-negative, got {top_k}")
        if not documents:
            return []

        # Build (query, document text) input pairs for the Cross-Encoder.
        query_doc_pairs = [(query, doc.page_content) for doc in documents]

        print(f"๋ฆฌ๋ญํน ์ํ ์ค: {len(documents)}๊ฐ ๋ฌธ์")
        scores = self.model.predict(query_doc_pairs)

        # Highest score first; sorted() is stable, so ties keep input order.
        ranked = sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)
        selected = [doc for doc, _score in ranked[:top_k]]

        # Report the number actually returned, which is smaller than top_k
        # when fewer documents were supplied.
        print(f"๋ฆฌ๋ญํน ์๋ฃ: ์์ {len(selected)}๊ฐ ๋ฌธ์ ์ ํ")
        return selected