File size: 5,229 Bytes
4a98f26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
"""
LangChain์„ ํ™œ์šฉํ•œ RAG ์ฒด์ธ ๊ตฌํ˜„
"""
from typing import List, Dict, Any
from langchain.schema import Document
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_community.chat_models import ChatOllama
from langchain_openai import ChatOpenAI

from config import (
    OLLAMA_HOST, LLM_MODEL, USE_OPENAI,
    OPENAI_API_KEY, TOP_K_RETRIEVAL, TOP_K_RERANK
)
from vector_store import VectorStore
from reranker import Reranker


class RAGChain:
    """RAG chain: retrieves context from a vector store (optionally reranked)
    and answers user questions with an environment-selected LLM via LCEL."""

    def __init__(self, vector_store: VectorStore, use_reranker: bool = True):
        """
        Initialize the RAG chain, picking the LLM backend from the environment.

        Args:
            vector_store: Vector store instance used for similarity search.
            use_reranker: Whether to rerank retrieved documents before answering.

        Raises:
            Exception: Re-raised after logging if LLM or chain setup fails.
        """
        try:
            print("RAGChain ์ดˆ๊ธฐํ™” ์‹œ์ž‘...")
            self.vector_store = vector_store
            self.use_reranker = use_reranker
            print(f"๋ฆฌ๋žญ์ปค ์‚ฌ์šฉ ์—ฌ๋ถ€: {use_reranker}")

            if use_reranker:
                try:
                    self.reranker = Reranker()
                    print("๋ฆฌ๋žญ์ปค ์ดˆ๊ธฐํ™” ์„ฑ๊ณต")
                except Exception as e:
                    # Reranker is optional: degrade gracefully instead of failing init.
                    print(f"๋ฆฌ๋žญ์ปค ์ดˆ๊ธฐํ™” ์‹คํŒจ: {str(e)}")
                    self.reranker = None
                    self.use_reranker = False
            else:
                self.reranker = None

            # BUG FIX: IS_HUGGINGFACE was referenced below but never imported,
            # which raised NameError on every construction. Import it from the
            # config module if it exists there; otherwise default to False so
            # behavior falls back to the USE_OPENAI flag alone.
            try:
                from config import IS_HUGGINGFACE
            except ImportError:
                IS_HUGGINGFACE = False

            # Select the LLM backend based on the environment.
            if USE_OPENAI or IS_HUGGINGFACE:
                print(f"OpenAI ๋ชจ๋ธ ์ดˆ๊ธฐํ™”: {LLM_MODEL}")
                print(f"API ํ‚ค ์กด์žฌ ์—ฌ๋ถ€: {'์žˆ์Œ' if OPENAI_API_KEY else '์—†์Œ'}")
                try:
                    # NOTE(review): `model_name` is a deprecated alias in
                    # langchain_openai (the current kwarg is `model`); it still
                    # works today — confirm against the installed version.
                    self.llm = ChatOpenAI(
                        model_name=LLM_MODEL,
                        temperature=0.2,
                        api_key=OPENAI_API_KEY,
                    )
                    print("OpenAI ๋ชจ๋ธ ์ดˆ๊ธฐํ™” ์„ฑ๊ณต")
                except Exception as e:
                    print(f"OpenAI ๋ชจ๋ธ ์ดˆ๊ธฐํ™” ์‹คํŒจ: {str(e)}")
                    raise
            else:
                try:
                    print(f"Ollama ๋ชจ๋ธ ์ดˆ๊ธฐํ™”: {LLM_MODEL}")
                    self.llm = ChatOllama(
                        model=LLM_MODEL,
                        temperature=0.2,
                        base_url=OLLAMA_HOST,
                    )
                    print("Ollama ๋ชจ๋ธ ์ดˆ๊ธฐํ™” ์„ฑ๊ณต")
                except Exception as e:
                    print(f"Ollama ๋ชจ๋ธ ์ดˆ๊ธฐํ™” ์‹คํŒจ: {str(e)}")
                    raise

            # Build the prompt and wire up the LCEL pipeline.
            print("RAG ์ฒด์ธ ์„ค์ • ์‹œ์ž‘...")
            self.setup_chain()
            print("RAG ์ฒด์ธ ์„ค์ • ์™„๋ฃŒ")
        except Exception as e:
            print(f"RAGChain ์ดˆ๊ธฐํ™” ์ค‘ ์ƒ์„ธ ์˜ค๋ฅ˜: {str(e)}")
            import traceback
            traceback.print_exc()
            raise

    def setup_chain(self) -> None:
        """
        Define the prompt template and assemble the LCEL chain:
        retrieval -> prompt -> LLM -> string output.
        """
        # Prompt template (Korean, user-facing — kept verbatim).
        template = """
        ๋‹ค์Œ ์ •๋ณด๋ฅผ ๊ธฐ๋ฐ˜์œผ๋กœ ์งˆ๋ฌธ์— ์ •ํ™•ํ•˜๊ฒŒ ๋‹ต๋ณ€ํ•ด์ฃผ์„ธ์š”.

        ์งˆ๋ฌธ: {question}

        ์ฐธ๊ณ  ์ •๋ณด:
        {context}

        ์ฐธ๊ณ  ์ •๋ณด์— ๋‹ต์ด ์—†๋Š” ๊ฒฝ์šฐ "์ œ๊ณต๋œ ๋ฌธ์„œ์—์„œ ํ•ด๋‹น ์ •๋ณด๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค."๋ผ๊ณ  ๋‹ต๋ณ€ํ•˜์„ธ์š”.
        ๋‹ต๋ณ€์€ ์ •ํ™•ํ•˜๊ณ  ๊ฐ„๊ฒฐํ•˜๊ฒŒ ์ œ๊ณตํ•˜๋˜, ์ฐธ๊ณ  ์ •๋ณด์—์„œ ๊ทผ๊ฑฐ๋ฅผ ์ฐพ์•„ ์„ค๋ช…ํ•ด์ฃผ์„ธ์š”.
        ์ฐธ๊ณ  ์ •๋ณด์˜ ์ถœ์ฒ˜๋„ ํ•จ๊ป˜ ์•Œ๋ ค์ฃผ์„ธ์š”.
        """

        self.prompt = PromptTemplate.from_template(template)

        # The retriever fills {context}; the raw query passes through as {question}.
        self.chain = (
                {"context": self._retrieve, "question": RunnablePassthrough()}
                | self.prompt
                | self.llm
                | StrOutputParser()
        )

    def _retrieve(self, query: str) -> str:
        """
        Retrieve documents relevant to *query* and format them as a context string.

        Args:
            query: The user's question.

        Returns:
            Newline-joined context entries, each tagged with its source
            (and page, when present in the document metadata).
        """
        # Vector similarity search over the store.
        docs = self.vector_store.similarity_search(query, k=TOP_K_RETRIEVAL)

        # Optional reranking pass to narrow the candidate set.
        if self.use_reranker and docs:
            docs = self.reranker.rerank(query, docs, top_k=TOP_K_RERANK)

        # Format each hit with a numbered header and source attribution.
        context_parts = []
        for i, doc in enumerate(docs, 1):
            source = doc.metadata.get("source", "์•Œ ์ˆ˜ ์—†๋Š” ์ถœ์ฒ˜")
            page = doc.metadata.get("page", "")
            source_info = f"{source}"
            # NOTE(review): a falsy page value (e.g. page 0) is silently
            # omitted here — confirm whether page numbering is 0-based.
            if page:
                source_info += f" (ํŽ˜์ด์ง€: {page})"

            context_parts.append(f"[์ฐธ๊ณ ์ž๋ฃŒ {i}] - ์ถœ์ฒ˜: {source_info}\n{doc.page_content}\n")

        return "\n".join(context_parts)

    def run(self, query: str) -> str:
        """
        Execute the full RAG pipeline for a user query.

        Args:
            query: The user's question.

        Returns:
            The model's answer as a plain string.
        """
        return self.chain.invoke(query)