|
""" |
|
๊ฐ๋จํ RAG ์ฒด์ธ ๊ตฌํ (๋๋ฒ๊น
์ฉ) - ์ง์ DeepSeek API ํธ์ถ ๋ฐฉ์ |
|
""" |
|
import os |
|
import logging |
|
import time |
|
from typing import Dict, Any, List |
|
|
|
|
|
from direct_deepseek import DirectDeepSeekClient |
|
|
|
|
|
# Module-level logger. It uses a fixed name rather than __name__, so records
# are emitted under "SimpleRAGChain" regardless of this module's import path.
logger: logging.Logger = logging.getLogger(name="SimpleRAGChain")
|
|
|
class SimpleRAGChain:
    """Minimal retrieval-augmented generation chain backed by DeepSeek.

    Retrieves documents for a query from ``vector_store``, formats them into a
    context string, and sends that context together with the query to
    ``DirectDeepSeekClient.chat``. When no client could be created, ``run``
    falls back to returning the retrieved context only.
    """

    # Number of documents requested from the vector store per query.
    top_k = 3
    # A joined context longer than this many characters is trimmed.
    max_context_chars = 6000
    # Characters kept from both the start and the end of a trimmed context.
    context_edge_chars = 2500

    def __init__(self, vector_store, api_key=None, model="deepseek-chat", endpoint=None):
        """Store the vector store and, if an API key is available, build the client.

        Args:
            vector_store: Object exposing ``similarity_search(query, k=...)``
                that returns documents with a dict-like ``metadata`` and a
                ``page_content`` attribute.
            api_key: DeepSeek API key; falls back to the ``DEEPSEEK_API_KEY``
                environment variable (empty string when unset).
            model: Model name. ``DEEPSEEK_MODEL`` is consulted only when this
                is falsy, so with the default ``"deepseek-chat"`` the
                environment variable has no effect.
            endpoint: Accepted for interface compatibility but currently not
                used; it is not forwarded to ``DirectDeepSeekClient``.
        """
        logger.info("๊ฐ๋จํ RAG ์ฒด์ธ ์ด๊ธฐํ ์ค...")

        self.vector_store = vector_store

        self.api_key = api_key or os.environ.get("DEEPSEEK_API_KEY", "")
        # NOTE(review): pass model=None to let DEEPSEEK_MODEL take effect.
        self.model = model or os.environ.get("DEEPSEEK_MODEL", "deepseek-chat")

        logger.info(f"API ํค ์ค์ ๋จ: {bool(self.api_key)}")

        # Stays None (retrieval-only mode) unless the client builds successfully.
        self.client = None
        if self.api_key:
            try:
                self.client = DirectDeepSeekClient(
                    api_key=self.api_key,
                    model_name=self.model,
                )
                logger.info(f"DeepSeek ํด๋ผ์ด์ธํธ ์ด๊ธฐํ ์ฑ๊ณต: {self.model}")
            except Exception as e:
                # Best effort: a failed client must not break construction;
                # run() degrades to returning search results only.
                logger.exception(f"DeepSeek ํด๋ผ์ด์ธํธ ์ด๊ธฐํ ์คํจ: {e}")
                self.client = None
        else:
            logger.warning("API ํค๊ฐ ์ค์ ๋์ง ์์ ํด๋ผ์ด์ธํธ๋ฅผ ์ด๊ธฐํํ ์ ์์ต๋๋ค.")

        logger.info("๊ฐ๋จํ RAG ์ฒด์ธ ์ด๊ธฐํ ์๋ฃ")

    def _retrieve(self, query: str) -> str:
        """Search the vector store and build a numbered context string.

        Each document is rendered as a numbered header naming its source (and
        page, when the metadata has one) followed by the document text. A
        context longer than ``max_context_chars`` keeps only its first and
        last ``context_edge_chars`` characters around an elision marker.

        Returns:
            The context string, or a fixed message when no documents are
            found or the search raises (the error is logged, not propagated).
        """
        try:
            docs = self.vector_store.similarity_search(query, k=self.top_k)
            if not docs:
                return "๊ด๋ จ ๋ฌธ์๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค."

            context_parts = []
            for i, doc in enumerate(docs, 1):
                source = doc.metadata.get("source", "์ ์ ์๋ ์ถ์ฒ")
                page = doc.metadata.get("page", "")
                source_info = f"{source}"
                # Explicit check instead of truthiness: page 0 is a valid
                # value (0-based page metadata uses it for the first page).
                if page is not None and page != "":
                    source_info += f" (ํ์ด์ง: {page})"

                context_parts.append(f"[์ฐธ๊ณ ์๋ฃ {i}] - ์ถ์ฒ: {source_info}\n{doc.page_content}\n")

            context = "\n".join(context_parts)

            if len(context) > self.max_context_chars:
                edge = self.context_edge_chars
                # Guard edge == 0: context[-0:] would be the whole string.
                tail = context[-edge:] if edge > 0 else ""
                context = context[:edge] + "\n...(์ค๋ต)...\n" + tail

            return context
        except Exception as e:
            logger.exception(f"๊ฒ์ ์ค ์ค๋ฅ: {e}")
            return "๋ฌธ์ ๊ฒ์ ์ค ์ค๋ฅ๊ฐ ๋ฐ์ํ์ต๋๋ค."

    def _generate_prompt(self, query: str, context: str) -> List[Dict[str, str]]:
        """Build the chat messages sent to the DeepSeek API.

        Returns:
            A two-element list: a system message with the answering rules and
            a user message holding the query followed by the context.
        """
        system_prompt = (
            "๋ค์ ์ ๋ณด๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ์ง๋ฌธ์ ์ ํํ๊ฒ ๋ต๋ณํด์ฃผ์ธ์.\n"
            '์ฐธ๊ณ ์ ๋ณด์์ ๋ต์ ์ฐพ์ ์ ์๋ ๊ฒฝ์ฐ "์ ๊ณต๋ ๋ฌธ์์์ ํด๋น ์ ๋ณด๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค."๋ผ๊ณ ๋ต๋ณํ์ธ์.\n'
            "์ ๋ณด ์ถ์ฒ๋ฅผ ํฌํจํด์ ๋๋ตํ์ธ์."
        )

        user_prompt = f"์ง๋ฌธ: {query}\n\n์ฐธ๊ณ ์ ๋ณด:\n{context}"

        return [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]

    def run(self, query: str) -> str:
        """Answer ``query`` using retrieved context and the DeepSeek client.

        Ordinary exceptions are caught, logged with a traceback, and reported
        in the returned text rather than propagated.

        Returns:
            The model's answer on success. Otherwise, a message that includes
            the retrieved context (no client, or an unsuccessful API call) or
            the error text (unexpected exception).
        """
        try:
            logger.info(f"SimpleRAGChain ์คํ: {query[:50]}...")

            context = self._retrieve(query)

            if self.client is None:
                logger.warning("DeepSeek ํด๋ผ์ด์ธํธ๊ฐ ์ด๊ธฐํ๋์ง ์์. ๊ฒ์ ๊ฒฐ๊ณผ๋ง ๋ฐํ.")
                return f"API ์ฐ๊ฒฐ์ด ์ค์ ๋์ง ์์์ต๋๋ค. ๊ฒ์ ๊ฒฐ๊ณผ:\n\n{context}"

            messages = self._generate_prompt(query, context)

            # perf_counter is monotonic, so the measured latency cannot be
            # skewed by system clock adjustments the way time.time() can.
            start_time = time.perf_counter()
            response = self.client.chat(messages)
            logger.info(f"API ์๋ต ์๊ฐ: {time.perf_counter() - start_time:.2f}์ด")

            # NOTE(review): assumes chat() returns a dict with "success" plus
            # "response" (on success) or "message" (on failure); a missing
            # key raises KeyError, which the handler below reports.
            if response["success"]:
                logger.info("์๋ต ์์ฑ ์ฑ๊ณต")
                return response["response"]

            logger.error(f"์๋ต ์์ฑ ์คํจ: {response['message']}")
            return f"์๋ต ์์ฑ ์คํจ: {response['message']}\n\n๊ฒ์ ๊ฒฐ๊ณผ:\n{context}"

        except Exception as e:
            logger.exception(f"์คํ ์ค ์ค๋ฅ: {e}")
            return f"์ค๋ฅ ๋ฐ์: {str(e)}"