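# NOTE(review): the three lines originally here ("Spaces:", "Paused", "Paused")
# look like web-page residue captured together with this file rather than
# Python code; kept as a comment so the module can be parsed.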
"""
Custom RAG (retrieval-augmented generation) chain built on the DeepSeek API.
"""
import logging
import os
import time
from typing import Any, Dict, List, Optional, Tuple

from langchain.prompts import PromptTemplate
from langchain.schema import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

# Project-local DeepSeek wrappers: DeepSeekLLM is piped into the LangChain
# chain, DeepSeekChat is the direct chat client used as a fallback.
from deepseek_llm import DeepSeekLLM, DeepSeekChat
# Settings: prefer the project's `config` module; when it cannot be imported,
# read each value from the environment and fall back to the defaults below.
try:
    from config import (
        DEEPSEEK_API_KEY,
        DEEPSEEK_ENDPOINT,
        DEEPSEEK_MODEL,
        TOP_K_RERANK,
        TOP_K_RETRIEVAL,
    )
except ImportError:
    DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY", "")
    DEEPSEEK_MODEL = os.getenv("DEEPSEEK_MODEL", "deepseek-chat")
    DEEPSEEK_ENDPOINT = os.getenv(
        "DEEPSEEK_ENDPOINT",
        "https://api.deepseek.com/v1/chat/completions",
    )
    # Integer settings are parsed at import time; a non-numeric value raises
    # ValueError here.
    TOP_K_RETRIEVAL = int(os.getenv("TOP_K_RETRIEVAL", "5"))
    TOP_K_RERANK = int(os.getenv("TOP_K_RERANK", "3"))

# Logger shared by this module (named after the class, not the module).
logger = logging.getLogger("CustomRAGChain")
class CustomRAGChain:
    """RAG chain that answers questions with the DeepSeek API.

    Passages are retrieved from a vector store, inserted into a Korean prompt
    and sent to DeepSeek. ``run`` degrades step by step when the API fails:
    LangChain pipeline -> direct chat call -> canned answer -> raw passages.
    """

    # Contexts longer than this many characters are truncated before
    # prompting, to limit the prompt's token count.
    MAX_CONTEXT_CHARS = 6000
    # Number of characters kept from each end of a truncated context.
    CONTEXT_EDGE_CHARS = 2500
    # Number of leading query characters shown in log messages.
    _LOG_PREVIEW_CHARS = 50

    # Canned answers used only after every LLM call has failed. Checked in
    # insertion order; the first keyword contained in the query wins.
    # NOTE(review): the bare "수도" key makes any question mentioning a
    # capital return the Seoul answer.
    _PREDEFINED_ANSWERS = {
        "대한민국의 수도": "대한민국의 수도는 서울입니다.",
        "수도": "대한민국의 수도는 서울입니다.",
        "누구야": "저는 RAG 기반 질의응답 시스템입니다. 문서를 검색하고 관련 정보를 찾아드립니다.",
        "안녕": "안녕하세요! 무엇을 도와드릴까요?",
        "뭐해": "사용자의 질문에 답변하기 위해 문서를 검색하고 있습니다. 무엇을 알려드릴까요?",
    }

    # Prompt with {question} and {context} placeholders (PromptTemplate syntax).
    _PROMPT_TEMPLATE = (
        "\n"
        "다음 정보를 기반으로 질문에 정확하게 답변해주세요.\n"
        "질문: {question}\n"
        "참고 정보:\n"
        "{context}\n"
        "참고 정보에 답이 있으면 반드시 그 정보를 기반으로 답변하세요.\n"
        "참고 정보에 답이 없는 경우에는 일반적인 지식을 활용하여 답변할 수 있지만, "
        "\"제공된 문서에는 이 정보가 없으나, 일반적으로는...\" 식으로 시작하세요.\n"
        "답변은 정확하고 간결하게 제공하되, 가능한 참고 정보에서 근거를 찾아 설명해주세요.\n"
        "참고 정보의 출처도 함께 알려주세요.\n"
    )

    def __init__(self, vector_store, use_reranker=False):
        """Create the DeepSeek clients, the prompt and the LangChain pipelines.

        Args:
            vector_store: Store queried via ``similarity_search(query, k=...)``;
                returned documents must expose ``metadata`` (a mapping) and
                ``page_content``.
            use_reranker: Stored as ``self.use_reranker``; reranking is not
                implemented yet.

        Raises:
            ValueError: If ``DEEPSEEK_API_KEY`` is empty, or if ``DeepSeekLLM``
                cannot be constructed (the original error is chained).
        """
        logger.info("커스텀 RAG 체인 초기화...")
        self.vector_store = vector_store
        self.use_reranker = use_reranker

        if not DEEPSEEK_API_KEY:
            logger.error("DeepSeek API 키가 설정되지 않았습니다.")
            raise ValueError("DeepSeek API 키가 설정되지 않았습니다.")

        try:
            self.llm = DeepSeekLLM(
                api_key=DEEPSEEK_API_KEY,
                model=DEEPSEEK_MODEL,
                endpoint=DEEPSEEK_ENDPOINT,
                temperature=0.3,
                max_tokens=1000,
                request_timeout=120,
                max_retries=5,
            )
        except Exception as e:
            logger.error(f"DeepSeek LLM 초기화 실패: {e}")
            raise ValueError(f"DeepSeek LLM 초기화 실패: {str(e)}") from e
        logger.info(f"DeepSeek LLM 초기화 성공: {DEEPSEEK_MODEL}")

        # Plain chat client; run() falls back to it when the pipeline fails.
        self.chat = DeepSeekChat(
            api_key=DEEPSEEK_API_KEY,
            model=DEEPSEEK_MODEL,
            endpoint=DEEPSEEK_ENDPOINT,
        )

        self.prompt = PromptTemplate.from_template(self._PROMPT_TEMPLATE)

        # Public end-to-end chain (question string in, answer string out). It
        # performs its own retrieval; kept for callers that invoke it directly.
        self.chain = (
            {"context": self._retrieve, "question": RunnablePassthrough()}
            | self.prompt
            | self.llm
            | StrOutputParser()
        )
        # Generation-only chain used by run(), which retrieves the context once
        # and passes {"question": ..., "context": ...} itself.
        self._answer_chain = self.prompt | self.llm | StrOutputParser()

        logger.info("커스텀 RAG 체인 초기화 완료")

    @classmethod
    def _preview(cls, text: str) -> str:
        """Return ``text`` cut to the log-preview length, with "..." if it was cut."""
        limit = cls._LOG_PREVIEW_CHARS
        return text[:limit] + ("..." if len(text) > limit else "")

    @staticmethod
    def _format_source(metadata) -> str:
        """Return the source label for a document's ``metadata`` mapping.

        Uses the ``source`` key (or a placeholder) and appends the ``page``
        key when it is present and not empty.
        """
        source = metadata.get("source", "알 수 없는 출처")
        page = metadata.get("page", "")
        # Page 0 is a real page number (some loaders count from 0), so only
        # missing/empty values are treated as "no page".
        if page is None or page == "":
            return f"{source}"
        return f"{source} (페이지: {page})"

    @classmethod
    def _truncate_context(cls, context: str) -> str:
        """Keep the first and last CONTEXT_EDGE_CHARS characters, marking the cut."""
        edge = cls.CONTEXT_EDGE_CHARS
        head = context[:edge]
        # max(..., 0) keeps the tail empty (not the whole string) when edge is 0.
        tail = context[max(len(context) - edge, 0):] if edge > 0 else ""
        return head + "\n...(중략)...\n" + tail

    def _retrieve(self, query: str) -> str:
        """Search the vector store for ``query`` and build the prompt context.

        Args:
            query: User question.

        Returns:
            Numbered passages with their source labels, truncated to
            MAX_CONTEXT_CHARS. If the query is empty, nothing is found or the
            search raises, a short Korean status message is returned instead
            (it is then used as the context).
        """
        if not query or not query.strip():
            logger.warning("빈 쿼리로 검색 시도")
            return "검색 쿼리가 비어있습니다."

        try:
            logger.info(f"벡터 검색 수행: '{self._preview(query)}'")
            docs = self.vector_store.similarity_search(query, k=TOP_K_RETRIEVAL)
            if not docs:
                logger.warning("검색 결과가 없습니다")
                return "관련 문서를 찾을 수 없습니다."

            context_parts = [
                f"[참고자료 {i}] - 출처: {self._format_source(doc.metadata)}\n"
                f"{doc.page_content}\n"
                for i, doc in enumerate(docs, 1)
            ]
            context = "\n".join(context_parts)

            if len(context) > self.MAX_CONTEXT_CHARS:
                logger.warning(f"컨텍스트가 너무 깁니다 ({len(context)} 문자). 제한합니다.")
                context = self._truncate_context(context)

            logger.info(f"컨텍스트 생성 완료: {len(context_parts)}개 문서, {len(context)} 문자")
            return context
        except Exception as e:
            logger.exception(f"검색 중 오류: {e}")
            return f"검색 중 오류 발생: {str(e)}"

    @classmethod
    def _find_predefined_answer(cls, query: str):
        """Return the first ``(keyword, answer)`` whose keyword occurs in ``query``, else None."""
        lowered = query.lower()
        for key, answer in cls._PREDEFINED_ANSWERS.items():
            if key in lowered:
                return key, answer
        return None

    @staticmethod
    def _retrieval_only_response(query: str, context: str) -> str:
        """Build the last-resort reply that shows the retrieved passages unsummarised."""
        return (
            "\n"
            "API 연결 오류로 인해 검색 결과만 표시합니다.\n"
            f"질문: {query}\n"
            "검색된 관련 문서:\n"
            f"{context}\n"
            "[참고] API 연결 문제로 인해 자동 요약이 제공되지 않습니다. "
            "다시 시도하거나 다른 질문을 해보세요.\n"
        )

    def _generate_answer(self, query: str, context: str) -> str:
        """Answer ``query`` from an already-retrieved ``context``.

        Tries, in order: the LangChain generation chain, the direct chat API,
        a canned answer, and finally a message listing the retrieved passages.
        """
        try:
            response = self._answer_chain.invoke({"question": query, "context": context})
        except Exception as chain_error:
            logger.error(f"체인 호출 실패: {chain_error}, 대체 방식 시도")
        else:
            logger.info("LangChain 체인 호출 성공")
            return response

        try:
            prompt = self.prompt.format(question=query, context=context)
            response = self.chat.generate([{"role": "user", "content": prompt}])
        except Exception as chat_error:
            logger.error(f"대체 채팅 API 호출 실패: {chat_error}")
        else:
            logger.info("대체 채팅 API 호출 성공")
            return response

        canned = self._find_predefined_answer(query)
        if canned is not None:
            key, answer = canned
            logger.info(f"미리 정의된 응답 제공: {key}")
            return answer

        logger.info("검색 결과만 표시")
        return self._retrieval_only_response(query, context)

    def run(self, query: str) -> str:
        """Answer ``query`` with the RAG pipeline.

        The context is retrieved exactly once and reused by every answer path
        (see ``_generate_answer``).

        Args:
            query: User question.

        Returns:
            The answer text. If the query is empty, or an unexpected
            ``Exception`` occurs, a Korean user-facing message is returned
            instead of raising.
        """
        if not query or not query.strip():
            logger.warning("빈 쿼리로 실행 시도")
            return "질문이 비어있습니다. 질문을 입력해 주세요."

        try:
            logger.info(f"RAG 체인 실행: '{self._preview(query)}'")
            start_time = time.perf_counter()
            context = self._retrieve(query)
            response = self._generate_answer(query, context)
            logger.info(f"RAG 체인 실행 완료: {time.perf_counter() - start_time:.2f}초")
            return response
        except Exception as e:
            logger.exception(f"RAG 체인 실행 중 오류: {e}")
            return f"질문 처리 중 오류가 발생했습니다: {str(e)}"