File size: 1,854 Bytes
14586a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
"""
์›๊ฒฉ ์ฝ”๋“œ ์‹คํ–‰ ์˜ต์…˜์ด ์ถ”๊ฐ€๋œ ๋ฆฌ๋žญ์ปค ๋ชจ๋“ˆ
"""
from typing import List, Dict, Tuple
import numpy as np
from sentence_transformers import CrossEncoder
from langchain.schema import Document
from config import RERANKER_MODEL

class Reranker:
    """Re-order retrieved documents with a Cross-Encoder relevance model."""

    def __init__(self, model_name: str = RERANKER_MODEL, *, trust_remote_code: bool = True):
        """
        Load the Cross-Encoder used for reranking.

        Progress messages are printed before and after the model is loaded.

        Args:
            model_name: Name of the Cross-Encoder model to load.
            trust_remote_code: Forwarded to ``CrossEncoder``. Defaults to True,
                which matches the previous hard-coded behaviour (the original
                comment marks it as required for the configured model).
        """
        print(f"๋ฆฌ๋žญ์ปค ๋ชจ๋ธ ๋กœ๋“œ ์ค‘: {model_name}")

        # NOTE(security): trust_remote_code=True permits remote code execution
        # for the model being loaded. Only use it with model names you trust;
        # pass trust_remote_code=False for models that do not need it.
        self.model = CrossEncoder(
            model_name,
            trust_remote_code=trust_remote_code,
        )

        print(f"๋ฆฌ๋žญ์ปค ๋ชจ๋ธ ๋กœ๋“œ ์™„๋ฃŒ: {model_name}")

    def rerank(self, query: str, documents: List[Document], top_k: int = 3) -> List[Document]:
        """
        Rerank search results by Cross-Encoder score.

        Args:
            query: Search query.
            documents: Documents returned by the vector search.
            top_k: Number of top results to return. A value of zero or less
                yields an empty list without running the model.

        Returns:
            The highest-scoring documents, best first (at most ``top_k``).
        """
        if not documents:
            return []

        # A negative top_k would otherwise reach the slice below and drop
        # documents from the *end* of the ranking instead of returning none;
        # zero would run a full inference pass for an empty result.
        # None is left alone so that `[:None]` keeps returning everything.
        if top_k is not None and top_k <= 0:
            return []

        # Build the (query, passage) pairs expected by the Cross-Encoder.
        query_doc_pairs = [(query, doc.page_content) for doc in documents]

        # Score every pair.
        print(f"๋ฆฌ๋žญํ‚น ์ˆ˜ํ–‰ ์ค‘: {len(documents)}๊ฐœ ๋ฌธ์„œ")
        scores = self.model.predict(query_doc_pairs)

        # Highest score first; sorted() is stable, so ties keep input order.
        ranked = sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)

        print(f"๋ฆฌ๋žญํ‚น ์™„๋ฃŒ: ์ƒ์œ„ {top_k}๊ฐœ ๋ฌธ์„œ ์„ ํƒ")

        # Keep only the top-k documents and drop the scores.
        return [doc for doc, _ in ranked[:top_k]]