File size: 4,703 Bytes
64a2fb9
ac1b0e8
64a2fb9
 
ac1b0e8
 
 
64a2fb9
ac1b0e8
 
 
 
 
64a2fb9
 
ac1b0e8
64a2fb9
ac1b0e8
64a2fb9
 
ac1b0e8
 
 
 
64a2fb9
ac1b0e8
 
 
 
 
 
 
 
 
 
 
 
 
 
64a2fb9
ac1b0e8
64a2fb9
ac1b0e8
 
 
 
 
 
64a2fb9
ac1b0e8
 
 
 
 
 
 
 
64a2fb9
ac1b0e8
64a2fb9
ac1b0e8
64a2fb9
ac1b0e8
 
 
64a2fb9
ac1b0e8
64a2fb9
ac1b0e8
64a2fb9
 
ac1b0e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64a2fb9
 
ac1b0e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64a2fb9
ac1b0e8
64a2fb9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""
Simple RAG chain implementation (for debugging) - calls the DeepSeek API directly.
"""
import os
import logging
import time
from typing import Dict, Any, List

# ์ง์ ‘ DeepSeek ํด๋ผ์ด์–ธํŠธ ์‚ฌ์šฉ
from direct_deepseek import DirectDeepSeekClient

# Logging setup
logger = logging.getLogger("SimpleRAGChain")

class SimpleRAGChain:
    """Minimal retrieval-augmented generation chain on top of a DeepSeek client.

    The chain asks ``vector_store`` for the documents closest to a query,
    formats them into a context block, and sends query + context to the
    DeepSeek chat client. ``_retrieve`` and ``run`` catch exceptions and
    report problems as returned strings instead of raising.

    NOTE(review): the user-facing and log strings in this class look
    mis-encoded (mojibake) in this copy of the file; they are kept verbatim
    here and should be checked against the original source.
    """

    # Number of documents requested from the vector store per query.
    TOP_K = 3
    # A context longer than this many characters is shortened.
    MAX_CONTEXT_CHARS = 6000
    # Characters kept from each end of an over-long context.
    CONTEXT_EDGE_CHARS = 2500

    def __init__(self, vector_store, api_key=None, model="deepseek-chat", endpoint=None):
        """Store the vector store and try to create the DeepSeek client.

        Args:
            vector_store: Object exposing ``similarity_search(query, k=...)``
                whose results carry ``metadata`` and ``page_content``.
            api_key: DeepSeek API key; when falsy, the ``DEEPSEEK_API_KEY``
                environment variable is used instead.
            model: Model name; when falsy, ``DEEPSEEK_MODEL`` from the
                environment is used, then ``"deepseek-chat"``.
            endpoint: Accepted for call compatibility; not used by this class.

        ``self.client`` is left as ``None`` when no API key is available or
        when constructing the client raises.
        """
        logger.info("๊ฐ„๋‹จํ•œ RAG ์ฒด์ธ ์ดˆ๊ธฐํ™” ์ค‘...")
        self.vector_store = vector_store

        # Resolve credentials and model, falling back to the environment.
        self.api_key = api_key or os.environ.get("DEEPSEEK_API_KEY", "")
        self.model = model or os.environ.get("DEEPSEEK_MODEL", "deepseek-chat")
        # Only log whether a key is present, never the key itself.
        logger.info(f"API ํ‚ค ์„ค์ •๋จ: {bool(self.api_key)}")

        # Create the DeepSeek client; without one, run() degrades to
        # returning the retrieved context only.
        self.client = None
        if self.api_key:
            try:
                self.client = DirectDeepSeekClient(
                    api_key=self.api_key,
                    model_name=self.model
                )
                logger.info(f"DeepSeek ํด๋ผ์ด์–ธํŠธ ์ดˆ๊ธฐํ™” ์„ฑ๊ณต: {self.model}")
            except Exception as e:
                # logger.exception keeps the traceback for debugging.
                logger.exception(f"DeepSeek ํด๋ผ์ด์–ธํŠธ ์ดˆ๊ธฐํ™” ์‹คํŒจ: {e}")
                self.client = None
        else:
            logger.warning("API ํ‚ค๊ฐ€ ์„ค์ •๋˜์ง€ ์•Š์•„ ํด๋ผ์ด์–ธํŠธ๋ฅผ ์ดˆ๊ธฐํ™”ํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")

        logger.info("๊ฐ„๋‹จํ•œ RAG ์ฒด์ธ ์ดˆ๊ธฐํ™” ์™„๋ฃŒ")

    def _retrieve(self, query: str) -> str:
        """Search the vector store and build the context string.

        Args:
            query: Text passed to ``vector_store.similarity_search``.

        Returns:
            One numbered section per retrieved document (source, optional
            page, content), shortened to its head and tail when it exceeds
            ``MAX_CONTEXT_CHARS``. A fixed message string is returned when
            nothing is found or when the search raises.
        """
        try:
            docs = self.vector_store.similarity_search(query, k=self.TOP_K)
            if not docs:
                return "๊ด€๋ จ ๋ฌธ์„œ๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค."

            # Build one labelled section per retrieved document.
            context_parts = []
            for i, doc in enumerate(docs, 1):
                metadata = doc.metadata
                source = metadata.get("source", "์•Œ ์ˆ˜ ์—†๋Š” ์ถœ์ฒ˜")
                page = metadata.get("page")
                source_info = f"{source}"
                # Compare explicitly instead of by truthiness so that a
                # page number of 0 (0-based loaders) is still reported.
                if page is not None and page != "":
                    source_info += f" (ํŽ˜์ด์ง€: {page})"

                context_parts.append(f"[์ฐธ๊ณ ์ž๋ฃŒ {i}] - ์ถœ์ฒ˜: {source_info}\n{doc.page_content}\n")

            context = "\n".join(context_parts)

            # Length limit: keep the head and the tail, drop the middle.
            if len(context) > self.MAX_CONTEXT_CHARS:
                edge = self.CONTEXT_EDGE_CHARS
                # Explicit start index so edge == 0 yields an empty tail
                # (context[-0:] would be the whole string).
                tail = context[len(context) - edge:]
                context = context[:edge] + "\n...(์ค‘๋žต)...\n" + tail

            return context
        except Exception as e:
            logger.exception(f"๊ฒ€์ƒ‰ ์ค‘ ์˜ค๋ฅ˜: {e}")
            return "๋ฌธ์„œ ๊ฒ€์ƒ‰ ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค."

    def _generate_prompt(self, query: str, context: str) -> List[Dict[str, str]]:
        """Build the chat message list sent to the DeepSeek client.

        Args:
            query: The user's question.
            context: Retrieved reference text to ground the answer.

        Returns:
            Two messages: a ``system`` message with fixed instructions and a
            ``user`` message containing the question followed by the context.
        """
        # System prompt (fixed instructions).
        system_prompt = """๋‹ค์Œ ์ •๋ณด๋ฅผ ๊ธฐ๋ฐ˜์œผ๋กœ ์งˆ๋ฌธ์— ์ •ํ™•ํ•˜๊ฒŒ ๋‹ต๋ณ€ํ•ด์ฃผ์„ธ์š”.
์ฐธ๊ณ  ์ •๋ณด์—์„œ ๋‹ต์„ ์ฐพ์„ ์ˆ˜ ์—†๋Š” ๊ฒฝ์šฐ "์ œ๊ณต๋œ ๋ฌธ์„œ์—์„œ ํ•ด๋‹น ์ •๋ณด๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค."๋ผ๊ณ  ๋‹ต๋ณ€ํ•˜์„ธ์š”.
์ •๋ณด ์ถœ์ฒ˜๋ฅผ ํฌํ•จํ•ด์„œ ๋Œ€๋‹ตํ•˜์„ธ์š”."""

        # User prompt: question first, then the retrieved context.
        user_prompt = f"""์งˆ๋ฌธ: {query}

์ฐธ๊ณ  ์ •๋ณด:
{context}"""

        return [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]

    def run(self, query: str) -> str:
        """Answer ``query`` using retrieved context and the DeepSeek client.

        Args:
            query: The user's question.

        Returns:
            The client's response text on success. When no client is
            available, or the client reports failure, a message that
            includes the retrieved context. When an exception is raised,
            a message containing the exception text.
        """
        try:
            logger.info(f"SimpleRAGChain ์‹คํ–‰: {query[:50]}...")

            # Retrieve supporting documents.
            context = self._retrieve(query)

            # Without a client, fall back to returning the search result.
            if self.client is None:
                logger.warning("DeepSeek ํด๋ผ์ด์–ธํŠธ๊ฐ€ ์ดˆ๊ธฐํ™”๋˜์ง€ ์•Š์Œ. ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ๋งŒ ๋ฐ˜ํ™˜.")
                return f"API ์—ฐ๊ฒฐ์ด ์„ค์ •๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค. ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ:\n\n{context}"

            messages = self._generate_prompt(query, context)

            # Call the API; perf_counter is monotonic, so the measured
            # latency is not affected by system clock adjustments.
            start_time = time.perf_counter()
            response = self.client.chat(messages)
            elapsed = time.perf_counter() - start_time
            logger.info(f"API ์‘๋‹ต ์‹œ๊ฐ„: {elapsed:.2f}์ดˆ")

            # assumes chat() returns a mapping with "success" plus either
            # "response" or "message" — TODO confirm against the client.
            if response["success"]:
                logger.info("์‘๋‹ต ์ƒ์„ฑ ์„ฑ๊ณต")
                return response["response"]

            logger.error(f"์‘๋‹ต ์ƒ์„ฑ ์‹คํŒจ: {response['message']}")
            return f"์‘๋‹ต ์ƒ์„ฑ ์‹คํŒจ: {response['message']}\n\n๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ:\n{context}"

        except Exception as e:
            logger.exception(f"์‹คํ–‰ ์ค‘ ์˜ค๋ฅ˜: {e}")
            return f"์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}"