File size: 4,703 Bytes
14586a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""
๊ฐ„๋‹จํ•œ RAG ์ฒด์ธ ๊ตฌํ˜„ (๋””๋ฒ„๊น…์šฉ) - ์ง์ ‘ DeepSeek API ํ˜ธ์ถœ ๋ฐฉ์‹
"""
import os
import logging
import time
from typing import Dict, Any, List

# Use the direct DeepSeek client
from direct_deepseek import DirectDeepSeekClient

# Logging setup
logger = logging.getLogger("SimpleRAGChain")

class SimpleRAGChain:
    """Minimal retrieval-augmented generation chain that calls DeepSeek directly.

    A query is handled in three steps: the vector store is searched for
    similar documents, the hits are formatted into one context string, and
    the query plus that context is sent to the DeepSeek client. When no
    client could be created, the formatted search results are returned
    instead of a model answer.

    NOTE(review): the non-ASCII string literals in this class look
    mis-encoded (mojibake). They are runtime strings, so they are kept
    exactly as found -- confirm the file's encoding before changing them.
    """

    # Number of documents requested from the vector store per query.
    RETRIEVAL_K = 3
    # A context longer than this many characters gets shortened.
    MAX_CONTEXT_CHARS = 6000
    # Characters kept from each end of an over-long context.
    CONTEXT_EDGE_CHARS = 2500

    def __init__(self, vector_store, api_key=None, model="deepseek-chat", endpoint=None):
        """Store the vector store and try to create the DeepSeek client.

        Args:
            vector_store: Object exposing ``similarity_search(query, k=...)``.
            api_key: DeepSeek API key. When falsy, the ``DEEPSEEK_API_KEY``
                environment variable is used instead.
            model: Model name. ``DEEPSEEK_MODEL`` is consulted only when a
                falsy value is passed, because the default is non-empty.
            endpoint: Accepted but currently not used by this class.

        ``self.client`` ends up ``None`` when no API key is available or
        when constructing the client raises; ``run`` then returns search
        results only.
        """
        logger.info("๊ฐ„๋‹จํ•œ RAG ์ฒด์ธ ์ดˆ๊ธฐํ™” ์ค‘...")
        self.vector_store = vector_store

        self.api_key = api_key or os.environ.get("DEEPSEEK_API_KEY", "")
        self.model = model or os.environ.get("DEEPSEEK_MODEL", "deepseek-chat")
        # Log only whether a key is present, never the key itself.
        logger.info(f"API ํ‚ค ์„ค์ •๋จ: {bool(self.api_key)}")

        # Default to "no client"; only a successful construction replaces it.
        self.client = None
        if self.api_key:
            try:
                self.client = DirectDeepSeekClient(
                    api_key=self.api_key,
                    model_name=self.model,
                )
            except Exception as e:
                # Best effort: a broken client must not prevent construction,
                # but keep the traceback in the log for diagnosis.
                logger.exception(f"DeepSeek ํด๋ผ์ด์–ธํŠธ ์ดˆ๊ธฐํ™” ์‹คํŒจ: {e}")
                self.client = None
            else:
                logger.info(f"DeepSeek ํด๋ผ์ด์–ธํŠธ ์ดˆ๊ธฐํ™” ์„ฑ๊ณต: {self.model}")
        else:
            logger.warning("API ํ‚ค๊ฐ€ ์„ค์ •๋˜์ง€ ์•Š์•„ ํด๋ผ์ด์–ธํŠธ๋ฅผ ์ดˆ๊ธฐํ™”ํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")

        logger.info("๊ฐ„๋‹จํ•œ RAG ์ฒด์ธ ์ดˆ๊ธฐํ™” ์™„๋ฃŒ")

    def _retrieve(self, query: str) -> str:
        """Search the vector store and format the hits as one context string.

        Each hit becomes a numbered section: a header line with its source
        (plus the page when one is recorded) followed by the document text.
        A page value of ``0`` is treated as a real page number. Documents
        whose ``metadata`` is missing or ``None`` are listed with the
        fallback source label instead of failing the whole search.

        Args:
            query: Text passed to ``similarity_search``.

        Returns:
            The formatted context; a fixed message when the search returned
            nothing; or a fixed error message when anything raised.
        """
        try:
            docs = self.vector_store.similarity_search(query, k=self.RETRIEVAL_K)
            if not docs:
                return "๊ด€๋ จ ๋ฌธ์„œ๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค."

            context_parts = []
            for i, doc in enumerate(docs, 1):
                # Tolerate documents without usable metadata.
                metadata = getattr(doc, "metadata", None) or {}
                source = metadata.get("source", "์•Œ ์ˆ˜ ์—†๋Š” ์ถœ์ฒ˜")
                page = metadata.get("page")
                source_info = f"{source}"
                # Explicit comparison: a truthiness test would drop page 0.
                if page is not None and page != "":
                    source_info += f" (ํŽ˜์ด์ง€: {page})"

                context_parts.append(f"[์ฐธ๊ณ ์ž๋ฃŒ {i}] - ์ถœ์ฒ˜: {source_info}\n{doc.page_content}\n")

            context = "\n".join(context_parts)

            # Keep the head and the tail of an over-long context and mark
            # the removed middle part.
            if len(context) > self.MAX_CONTEXT_CHARS:
                edge = self.CONTEXT_EDGE_CHARS
                context = context[:edge] + "\n...(์ค‘๋žต)...\n" + context[-edge:]

            return context
        except Exception as e:
            logger.exception(f"๊ฒ€์ƒ‰ ์ค‘ ์˜ค๋ฅ˜: {e}")
            return "๋ฌธ์„œ ๊ฒ€์ƒ‰ ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค."

    def _generate_prompt(self, query: str, context: str) -> List[Dict[str, str]]:
        """Build the chat messages sent to the DeepSeek client.

        Args:
            query: The user's question.
            context: Reference text produced by ``_retrieve``.

        Returns:
            A two-element list: a ``system`` message with the fixed
            instructions and a ``user`` message holding the question
            followed by the reference text.
        """
        system_prompt = """๋‹ค์Œ ์ •๋ณด๋ฅผ ๊ธฐ๋ฐ˜์œผ๋กœ ์งˆ๋ฌธ์— ์ •ํ™•ํ•˜๊ฒŒ ๋‹ต๋ณ€ํ•ด์ฃผ์„ธ์š”.
์ฐธ๊ณ  ์ •๋ณด์—์„œ ๋‹ต์„ ์ฐพ์„ ์ˆ˜ ์—†๋Š” ๊ฒฝ์šฐ "์ œ๊ณต๋œ ๋ฌธ์„œ์—์„œ ํ•ด๋‹น ์ •๋ณด๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค."๋ผ๊ณ  ๋‹ต๋ณ€ํ•˜์„ธ์š”.
์ •๋ณด ์ถœ์ฒ˜๋ฅผ ํฌํ•จํ•ด์„œ ๋Œ€๋‹ตํ•˜์„ธ์š”."""

        user_prompt = f"""์งˆ๋ฌธ: {query}

์ฐธ๊ณ  ์ •๋ณด:
{context}"""

        return [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]

    def run(self, query: str) -> str:
        """Answer ``query`` using retrieved context and the DeepSeek client.

        Args:
            query: The user's question.

        Returns:
            The client's answer on success. When no client is configured,
            or the client reports failure, the returned text contains the
            retrieved context so the caller still sees the search results.
            Any ``Exception`` raised along the way is logged and returned
            as an error string instead of propagating.
        """
        try:
            logger.info(f"SimpleRAGChain ์‹คํ–‰: {query[:50]}...")

            context = self._retrieve(query)

            # Without a client there is nothing to generate with; hand back
            # the raw search results.
            if self.client is None:
                logger.warning("DeepSeek ํด๋ผ์ด์–ธํŠธ๊ฐ€ ์ดˆ๊ธฐํ™”๋˜์ง€ ์•Š์Œ. ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ๋งŒ ๋ฐ˜ํ™˜.")
                return f"API ์—ฐ๊ฒฐ์ด ์„ค์ •๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค. ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ:\n\n{context}"

            messages = self._generate_prompt(query, context)

            # perf_counter is monotonic, so the elapsed time cannot be
            # skewed by system clock adjustments.
            start_time = time.perf_counter()
            response = self.client.chat(messages)
            logger.info(f"API ์‘๋‹ต ์‹œ๊ฐ„: {time.perf_counter() - start_time:.2f}์ดˆ")

            if response["success"]:
                logger.info("์‘๋‹ต ์ƒ์„ฑ ์„ฑ๊ณต")
                return response["response"]

            logger.error(f"์‘๋‹ต ์ƒ์„ฑ ์‹คํŒจ: {response['message']}")
            return f"์‘๋‹ต ์ƒ์„ฑ ์‹คํŒจ: {response['message']}\n\n๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ:\n{context}"

        except Exception as e:
            logger.exception(f"์‹คํ–‰ ์ค‘ ์˜ค๋ฅ˜: {e}")
            return f"์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}"