File size: 4,949 Bytes
a76f77b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
"""
์•ฑ์— ๋‚ด์žฅ๋œ ๊ฐ„๋‹จํ•œ RAG ์ฒด์ธ ๊ตฌํ˜„
"""
import os
import traceback
from typing import Any, Dict, List, Optional

from config import OPENAI_API_KEY, LLM_MODEL, USE_OPENAI, TOP_K_RETRIEVAL

# ์•ˆ์ „ํ•œ ์ž„ํฌํŠธ
try:
    from langchain_openai import ChatOpenAI
    from langchain.prompts import PromptTemplate
    from langchain_core.output_parsers import StrOutputParser
    from langchain_core.runnables import RunnablePassthrough
    LANGCHAIN_IMPORTS_AVAILABLE = True
except ImportError:
    print("[APP_RAG] langchain ๊ด€๋ จ ํŒจํ‚ค์ง€๋ฅผ ๋กœ๋“œํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")
    LANGCHAIN_IMPORTS_AVAILABLE = False

class SimpleRAGChain:
    """Simple RAG chain embedded in the app.

    Wires a vector-store retriever, a prompt template, and a chat LLM
    (OpenAI, or Ollama when configured) into one LangChain runnable:
    retrieval -> prompt -> LLM -> string output.
    """

    def __init__(self, vector_store):
        """Initialize the RAG chain.

        Args:
            vector_store: any object exposing
                ``similarity_search(query, k=...)`` returning documents with
                ``page_content`` and ``metadata`` (e.g. a LangChain store).

        Raises:
            ImportError: required langchain packages are not installed.
            ValueError: OpenAI is selected but no API key is configured.
        """
        print("[APP_RAG] 간단한 RAG 체인 초기화 중...")
        self.vector_store = vector_store

        if not LANGCHAIN_IMPORTS_AVAILABLE:
            print("[APP_RAG] langchain 패키지를 찾을 수 없어 RAG 체인을 초기화할 수 없습니다.")
            raise ImportError("RAG 체인 초기화에 필요한 라이브러리가 설치되지 않았습니다.")

        # Fail fast when OpenAI is required but no credential is present.
        if not OPENAI_API_KEY and USE_OPENAI:
            print("[APP_RAG] 경고: OpenAI API 키가 설정되지 않았습니다.")
            raise ValueError("OpenAI API 키가 설정되지 않았습니다.")

        try:
            self.llm = self._build_llm()

            # Prompt template (verbatim: its text is runtime behavior).
            template = """
            다음 정보를 기반으로 질문에 정확하게 답변해주세요.

            질문: {question}

            참고 정보:
            {context}

            참고 정보에 답이 없는 경우 "제공된 문서에서 해당 정보를 찾을 수 없습니다."라고 답변하세요.
            답변은 정확하고 간결하게 제공하되, 참고 정보에서 근거를 찾아 설명해주세요.
            참고 정보의 출처도 함께 알려주세요.
            """

            self.prompt = PromptTemplate.from_template(template)

            # retrieval -> prompt -> LLM -> plain string
            self.chain = (
                    {"context": self._retrieve, "question": RunnablePassthrough()}
                    | self.prompt
                    | self.llm
                    | StrOutputParser()
            )
            print("[APP_RAG] RAG 체인 초기화 완료")
        except Exception as e:
            print(f"[APP_RAG] RAG 체인 초기화 실패: {e}")
            traceback.print_exc()
            raise

    def _build_llm(self):
        """Create the chat model.

        Returns:
            ``ChatOpenAI`` when ``USE_OPENAI`` is set; otherwise ``ChatOllama``,
            falling back to OpenAI's gpt-3.5-turbo if the Ollama integration
            cannot be imported.
        """
        if USE_OPENAI:
            # `model_name` is a deprecated alias in langchain-openai; use `model`.
            llm = ChatOpenAI(
                model=LLM_MODEL,
                temperature=0.2,
                api_key=OPENAI_API_KEY,
            )
            print(f"[APP_RAG] OpenAI 모델 초기화: {LLM_MODEL}")
            return llm

        try:
            # Prefer a locally hosted Ollama model when configured.
            from langchain_community.chat_models import ChatOllama
            from config import OLLAMA_HOST

            llm = ChatOllama(
                model=LLM_MODEL,
                temperature=0.2,
                base_url=OLLAMA_HOST,
            )
            print(f"[APP_RAG] Ollama 모델 초기화: {LLM_MODEL}")
            return llm
        except ImportError:
            # Ollama integration unavailable -> fall back to OpenAI.
            llm = ChatOpenAI(
                model="gpt-3.5-turbo",
                temperature=0.2,
                api_key=OPENAI_API_KEY,
            )
            print("[APP_RAG] Ollama를 사용할 수 없어 OpenAI로 대체합니다.")
            return llm

    def _retrieve(self, query):
        """Search the vector store and format hits as a numbered context string.

        Args:
            query: user question forwarded by the chain.

        Returns:
            Context text with per-document source (and page) annotations, or a
            fixed error-message string when the search itself fails.
        """
        try:
            docs = self.vector_store.similarity_search(query, k=TOP_K_RETRIEVAL)

            context_parts = []
            for i, doc in enumerate(docs, 1):
                source = doc.metadata.get("source", "알 수 없는 출처")
                page = doc.metadata.get("page", "")
                source_info = f"{source}"
                # Compare against the "" default instead of truthiness so that
                # page 0 (0-indexed PDF loaders) is not silently dropped.
                if page != "":
                    source_info += f" (페이지: {page})"

                context_parts.append(f"[참고자료 {i}] - 출처: {source_info}\n{doc.page_content}\n")

            return "\n".join(context_parts)
        except Exception as e:
            print(f"[APP_RAG] 검색 중 오류: {e}")
            traceback.print_exc()
            return "문서 검색 중 오류가 발생했습니다."

    def run(self, query):
        """Answer a query through the chain.

        Args:
            query: question string.

        Returns:
            The LLM's answer string, or an error-description string when
            invocation fails (errors are logged, never raised to the caller).
        """
        try:
            return self.chain.invoke(query)
        except Exception as e:
            print(f"[APP_RAG] 실행 중 오류: {e}")
            traceback.print_exc()
            return f"오류 발생: {str(e)}"