|
""" |
|
DeepSeek API๋ฅผ ํ์ฉํ ์ปค์คํ
RAG ์ฒด์ธ ๊ตฌํ |
|
""" |
|
import os |
|
import logging |
|
import time |
|
from typing import List, Dict, Any, Optional, Tuple |
|
|
|
from langchain.schema import Document |
|
from langchain.prompts import PromptTemplate |
|
from langchain_core.output_parsers import StrOutputParser |
|
from langchain_core.runnables import RunnablePassthrough |
|
|
|
|
|
from deepseek_llm import DeepSeekLLM, DeepSeekChat |
|
|
|
|
|
# Prefer the project-level ``config`` module; when it cannot be imported,
# read the same settings from environment variables with built-in defaults.
try:
    from config import (
        DEEPSEEK_API_KEY,
        DEEPSEEK_MODEL,
        DEEPSEEK_ENDPOINT,
        TOP_K_RETRIEVAL,
        TOP_K_RERANK,
    )
except ImportError:
    DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY", "")
    DEEPSEEK_MODEL = os.environ.get("DEEPSEEK_MODEL", "deepseek-chat")
    DEEPSEEK_ENDPOINT = os.environ.get(
        "DEEPSEEK_ENDPOINT", "https://api.deepseek.com/v1/chat/completions"
    )
    # Numeric settings arrive as strings from the environment.
    TOP_K_RETRIEVAL = int(os.environ.get("TOP_K_RETRIEVAL", "5"))
    TOP_K_RERANK = int(os.environ.get("TOP_K_RERANK", "3"))
|
|
|
|
|
# Module-wide logger; the fixed name lets applications configure it by name.
logger = logging.getLogger("CustomRAGChain")
|
|
|
|
|
class CustomRAGChain:
    """
    Custom RAG chain built on the DeepSeek API.

    Retrieves documents from a vector store, formats them into a context
    block, and asks a DeepSeek model to answer the question from that
    context. When the model cannot be reached, ``run`` degrades step by
    step: LCEL chain -> chat API -> canned answers -> raw search results.
    """

    # A context longer than MAX_CONTEXT_CHARS is reduced to its first and
    # last CONTEXT_EDGE_CHARS characters with an ellipsis marker in between.
    MAX_CONTEXT_CHARS = 6000
    CONTEXT_EDGE_CHARS = 2500

    # Prompt sent to the model; ``{question}`` and ``{context}`` are filled in
    # by the PromptTemplate.
    _PROMPT_TEMPLATE = (
        "\n"
        "다음 정보를 기반으로 질문에 정확하게 답변해주세요.\n"
        "\n"
        "질문: {question}\n"
        "\n"
        "참고 정보:\n"
        "{context}\n"
        "\n"
        "참고 정보에 답이 있으면 반드시 그 정보를 기반으로 답변하세요.\n"
        "참고 정보에 답이 없는 경우에는 일반적인 지식을 활용하여 답변할 수 있지만, "
        '"제공된 문서에는 이 정보가 없으나, 일반적으로는..." 식으로 시작하세요.\n'
        "답변은 정확하고 간결하게 제공하되, 가능한 참고 정보에서 근거를 찾아 설명해주세요.\n"
        "참고 정보의 출처도 함께 알려주세요.\n"
    )

    # Last-resort reply when every model call failed: show the raw search hits.
    _SEARCH_ONLY_TEMPLATE = (
        "\n"
        "API 연결 오류로 인해 검색 결과만 표시합니다.\n"
        "\n"
        "질문: {query}\n"
        "\n"
        "검색된 관련 문서:\n"
        "{context}\n"
        "\n"
        "[참고] API 연결 문제로 인해 자동 요약이 제공되지 않습니다. "
        "다시 시도하거나 다른 질문을 해보세요.\n"
    )

    # Canned replies used only when both model calls failed. Matching is by
    # substring against the lower-cased query, in insertion order.
    _PREDEFINED_ANSWERS = {
        "대한민국의 수도": "대한민국의 수도는 서울입니다.",
        "수도": "대한민국의 수도는 서울입니다.",
        "누구야": "저는 RAG 기반 질의응답 시스템입니다. 문서를 검색하고 관련 정보를 찾아드립니다.",
        "안녕": "안녕하세요! 무엇을 도와드릴까요?",
        "뭐해": "사용자의 질문에 답변하기 위해 문서를 검색하고 있습니다. 무엇을 알려드릴까요?",
    }

    def __init__(self, vector_store: Any, use_reranker: bool = False) -> None:
        """
        Initialize the RAG chain.

        Args:
            vector_store: Vector store instance; must provide
                ``similarity_search(query, k=...)``.
            use_reranker: Whether to use a reranker (currently unsupported;
                the flag is only stored).

        Raises:
            ValueError: If the DeepSeek API key is not configured or the
                DeepSeek LLM cannot be constructed.
        """
        logger.info("커스텀 RAG 체인 초기화...")
        self.vector_store = vector_store
        self.use_reranker = use_reranker

        if not DEEPSEEK_API_KEY:
            logger.error("DeepSeek API 키가 설정되지 않았습니다.")
            raise ValueError("DeepSeek API 키가 설정되지 않았습니다.")

        try:
            self.llm = DeepSeekLLM(
                api_key=DEEPSEEK_API_KEY,
                model=DEEPSEEK_MODEL,
                endpoint=DEEPSEEK_ENDPOINT,
                temperature=0.3,
                max_tokens=1000,
                request_timeout=120,
                max_retries=5,
            )
        except Exception as e:
            logger.error("DeepSeek LLM 초기화 실패: %s", e)
            # Keep the ValueError type callers expect, but preserve the cause.
            raise ValueError(f"DeepSeek LLM 초기화 실패: {str(e)}") from e
        logger.info("DeepSeek LLM 초기화 성공: %s", DEEPSEEK_MODEL)

        # Chat client used as a fallback when the LCEL chain call fails.
        self.chat = DeepSeekChat(
            api_key=DEEPSEEK_API_KEY,
            model=DEEPSEEK_MODEL,
            endpoint=DEEPSEEK_ENDPOINT,
        )

        self.prompt = PromptTemplate.from_template(self._PROMPT_TEMPLATE)

        # Full pipeline (retrieval + generation), kept for external callers
        # that invoke ``chain`` directly with a question string.
        self.chain = (
            {"context": self._retrieve, "question": RunnablePassthrough()}
            | self.prompt
            | self.llm
            | StrOutputParser()
        )

        # Generation-only pipeline: ``run`` retrieves the context once and
        # feeds it here, so a query does not trigger two vector searches.
        self._answer_chain = self.prompt | self.llm | StrOutputParser()

        logger.info("커스텀 RAG 체인 초기화 완료")

    @staticmethod
    def _preview(text: str, limit: int = 50) -> str:
        """Return at most ``limit`` characters of ``text`` for log output."""
        return text[:limit] + ("..." if len(text) > limit else "")

    @staticmethod
    def _format_document(index: int, doc: Any) -> str:
        """Format one retrieved document as a numbered, source-labelled entry."""
        source_info = f"{doc.metadata.get('source', '알 수 없는 출처')}"
        page = doc.metadata.get("page")
        # Compare explicitly so that page number 0 is not treated as missing.
        if page is not None and page != "":
            source_info += f" (페이지: {page})"
        return f"[참고 자료 {index}] - 출처: {source_info}\n{doc.page_content}\n"

    def _retrieve(self, query: str) -> str:
        """
        Search the vector store for the query and build the context string.

        Args:
            query: User question.

        Returns:
            The formatted search results, or a short message when the query
            is empty, nothing was found, or the search raised an error.
        """
        if not query or not query.strip():
            logger.warning("빈 쿼리로 검색 시도")
            return "검색 쿼리가 비어있습니다."

        try:
            logger.info("벡터 검색 수행: '%s'", self._preview(query))
            docs = self.vector_store.similarity_search(query, k=TOP_K_RETRIEVAL)

            if not docs:
                logger.warning("검색 결과가 없습니다")
                return "관련 문서를 찾을 수 없습니다."

            context_parts = [
                self._format_document(i, doc) for i, doc in enumerate(docs, 1)
            ]
            context = "\n".join(context_parts)

            if len(context) > self.MAX_CONTEXT_CHARS:
                logger.warning(
                    "컨텍스트가 너무 깁니다 (%d 문자). 제한합니다.", len(context)
                )
                edge = self.CONTEXT_EDGE_CHARS
                # Keep the head and the tail; drop the middle.
                context = context[:edge] + "\n...(중략)...\n" + context[-edge:]

            logger.info(
                "컨텍스트 생성 완료: %d개 문서, %d 문자",
                len(context_parts),
                len(context),
            )
            return context

        except Exception as e:
            # Best effort by design: the error text is returned as the context
            # so the caller can still produce a reply.
            logger.error("검색 중 오류: %s", e)
            return f"검색 중 오류 발생: {str(e)}"

    def _offline_answer(self, query: str, context: str) -> str:
        """Build a reply without the model: a canned answer or the raw hits."""
        lowered = query.lower()
        for key, answer in self._PREDEFINED_ANSWERS.items():
            if key in lowered:
                logger.info("미리 정의된 응답 제공: %s", key)
                return answer

        logger.info("검색 결과만 표시")
        return self._SEARCH_ONLY_TEMPLATE.format(query=query, context=context)

    def _generate(self, query: str, context: str) -> str:
        """Produce the answer for an already-retrieved context, with fallbacks."""
        try:
            response = self._answer_chain.invoke(
                {"context": context, "question": query}
            )
            logger.info("LangChain 체인 호출 성공")
            return response
        except Exception as chain_error:
            logger.error("체인 호출 실패: %s, 대체 방식 시도", chain_error)

        try:
            prompt = self.prompt.format(question=query, context=context)
            # NOTE(review): assumes DeepSeekChat.generate returns the reply
            # text; the value is passed through unchanged. Confirm.
            response = self.chat.generate([{"role": "user", "content": prompt}])
            logger.info("대체 채팅 API 호출 성공")
            return response
        except Exception as chat_error:
            logger.error("대체 채팅 API 호출 실패: %s", chat_error)

        return self._offline_answer(query, context)

    def run(self, query: str) -> str:
        """
        Run the RAG pipeline for a user query.

        Args:
            query: User question.

        Returns:
            The model response, or a fallback/error message when the model
            could not be reached or processing failed.
        """
        if not query or not query.strip():
            logger.warning("빈 쿼리로 실행 시도")
            return "질문이 비어있습니다. 질문을 입력해 주세요."

        try:
            logger.info("RAG 체인 실행: '%s'", self._preview(query))
            started = time.perf_counter()

            # Retrieve once; the same context serves every generation path.
            context = self._retrieve(query)
            response = self._generate(query, context)

            logger.info("RAG 체인 실행 완료: %.2f초", time.perf_counter() - started)
            return response

        except Exception as e:
            logger.error("RAG 체인 실행 중 오류: %s", e)
            return f"질문 처리 중 오류가 발생했습니다: {str(e)}"