File size: 4,703 Bytes
64a2fb9 ac1b0e8 64a2fb9 ac1b0e8 64a2fb9 ac1b0e8 64a2fb9 ac1b0e8 64a2fb9 ac1b0e8 64a2fb9 ac1b0e8 64a2fb9 ac1b0e8 64a2fb9 ac1b0e8 64a2fb9 ac1b0e8 64a2fb9 ac1b0e8 64a2fb9 ac1b0e8 64a2fb9 ac1b0e8 64a2fb9 ac1b0e8 64a2fb9 ac1b0e8 64a2fb9 ac1b0e8 64a2fb9 ac1b0e8 64a2fb9 ac1b0e8 64a2fb9 ac1b0e8 64a2fb9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
"""
๊ฐ๋จํ RAG ์ฒด์ธ ๊ตฌํ (๋๋ฒ๊น
์ฉ) - ์ง์ DeepSeek API ํธ์ถ ๋ฐฉ์
"""
import os
import logging
import time
from typing import Dict, Any, List
# 직접 DeepSeek 클라이언트 사용 (use the DeepSeek client directly)
from direct_deepseek import DirectDeepSeekClient
# Logging setup (로깅 설정): the logger used throughout SimpleRAGChain.
# The explicit name (rather than __name__) is kept so existing logging
# configuration that targets "SimpleRAGChain" keeps working.
logger = logging.getLogger("SimpleRAGChain")
class SimpleRAGChain:
    """Minimal retrieval-augmented generation chain, intended for debugging.

    Retrieves matching documents from ``vector_store``, formats them into a
    numbered context block, and asks the DeepSeek chat API (through
    ``DirectDeepSeekClient``) to answer using that context. If no API key is
    available, or the client fails to initialise, ``run`` falls back to
    returning the raw retrieval results instead of a generated answer.
    """

    # Contexts longer than this many characters are trimmed...
    MAX_CONTEXT_CHARS = 6000
    # ...by keeping this many characters from the head and from the tail.
    CONTEXT_EDGE_CHARS = 2500

    def __init__(self, vector_store, api_key=None, model="deepseek-chat", endpoint=None):
        """Initialise the chain and, if possible, the DeepSeek client.

        Args:
            vector_store: Object exposing ``similarity_search(query, k=...)``
                that returns documents with ``page_content`` and ``metadata``
                attributes (presumably a LangChain-style store - confirm with
                the caller).
            api_key: DeepSeek API key; falls back to ``$DEEPSEEK_API_KEY``.
            model: Model name. ``$DEEPSEEK_MODEL`` is consulted only when this
                is falsy (e.g. ``None``), so the default ``"deepseek-chat"``
                normally wins over the environment.
            endpoint: Accepted for interface compatibility; it is stored on
                the instance but NOT forwarded to ``DirectDeepSeekClient``.
        """
        logger.info("간단한 RAG 체인 초기화 중...")
        self.vector_store = vector_store
        self.endpoint = endpoint

        self.api_key = api_key or os.environ.get("DEEPSEEK_API_KEY", "")
        self.model = model or os.environ.get("DEEPSEEK_MODEL", "deepseek-chat")
        logger.info("API 키 설정됨: %s", bool(self.api_key))

        # The client stays None when it cannot be created; run() checks this
        # and degrades to returning retrieval results only.
        self.client = None
        if not self.api_key:
            logger.warning("API 키가 설정되지 않아 클라이언트를 초기화할 수 없습니다.")
        else:
            try:
                self.client = DirectDeepSeekClient(
                    api_key=self.api_key,
                    model_name=self.model,
                )
            except Exception as e:
                # exception() keeps the traceback that error(f"...{e}") lost.
                logger.exception("DeepSeek 클라이언트 초기화 실패: %s", e)
            else:
                logger.info("DeepSeek 클라이언트 초기화 성공: %s", self.model)

        logger.info("간단한 RAG 체인 초기화 완료")

    def _retrieve(self, query: str, k: int = 3) -> str:
        """Search the vector store and format the hits as one context string.

        Each document becomes a section labelled with its index, its
        ``source`` metadata and, when present, its ``page`` metadata. Contexts
        longer than ``MAX_CONTEXT_CHARS`` keep only their first and last
        ``CONTEXT_EDGE_CHARS`` characters around an omission marker.

        Args:
            query: Text passed to ``similarity_search``.
            k: Number of documents to request (default 3).

        Returns:
            The formatted context, or a placeholder message when nothing is
            found or the search/formatting raises (the error is logged).
        """
        try:
            docs = self.vector_store.similarity_search(query, k=k)
            if not docs:
                return "관련 문서를 찾을 수 없습니다."

            context_parts = []
            for i, doc in enumerate(docs, 1):
                source_info = f"{doc.metadata.get('source', '알 수 없는 출처')}"
                page = doc.metadata.get("page")
                # Explicit check instead of truthiness: page 0 is a valid
                # (first) page for 0-based loaders and must not be dropped.
                if page is not None and page != "":
                    source_info += f" (페이지: {page})"
                context_parts.append(
                    f"[참고 자료 {i}] - 출처: {source_info}\n{doc.page_content}\n"
                )
            context = "\n".join(context_parts)

            # Length limit: keep the head and tail, drop the middle.
            if len(context) > self.MAX_CONTEXT_CHARS:
                edge = self.CONTEXT_EDGE_CHARS
                context = context[:edge] + "\n...(중략)...\n" + context[-edge:]
            return context
        except Exception as e:
            logger.exception("검색 중 오류: %s", e)
            return "문서 검색 중 오류가 발생했습니다."

    def _generate_prompt(self, query: str, context: str) -> List[Dict[str, str]]:
        """Build the system/user chat messages sent to the DeepSeek API.

        Args:
            query: The user's question.
            context: Retrieved context produced by ``_retrieve``.

        Returns:
            A two-element list: a system message with answering instructions
            and a user message containing the question and the context.
        """
        system_prompt = (
            "다음 정보를 기반으로 질문에 정확하게 답변해주세요.\n"
            '참고 정보에서 답을 찾을 수 없는 경우 "제공된 문서에서 해당 정보를 '
            '찾을 수 없습니다."라고 답변하세요.\n'
            "정보 출처를 포함해서 대답하세요."
        )
        user_prompt = f"질문: {query}\n참고 정보:\n{context}"
        return [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]

    def run(self, query: str) -> str:
        """Answer ``query`` using retrieved context and the DeepSeek API.

        Returns the generated answer on success. Otherwise returns a message
        that includes the retrieval results: when the client is unavailable,
        or when the API reports failure. Any exception raised while running is
        logged and turned into an error string rather than propagated.
        """
        try:
            logger.info("SimpleRAGChain 실행: %s...", query[:50])
            context = self._retrieve(query)

            if self.client is None:
                logger.warning("DeepSeek 클라이언트가 초기화되지 않음. 검색 결과만 반환.")
                return f"API 연결이 설정되지 않았습니다. 검색 결과:\n\n{context}"

            messages = self._generate_prompt(query, context)

            # perf_counter is monotonic, so the measured latency cannot be
            # skewed by wall-clock adjustments.
            start = time.perf_counter()
            response = self.client.chat(messages)
            logger.info("API 응답 시간: %.2f초", time.perf_counter() - start)

            if response.get("success"):
                logger.info("응답 생성 성공")
                return response["response"]

            # .get avoids a KeyError masking the real failure when the client
            # omits "message".
            message = response.get("message", "알 수 없는 오류")
            logger.error("응답 생성 실패: %s", message)
            return f"응답 생성 실패: {message}\n\n검색 결과:\n{context}"
        except Exception as e:
            logger.exception("실행 중 오류: %s", e)
            return f"오류 발생: {str(e)}"