# RAG3 / custom_rag_chain.py
# jeongsoo's picture
# deepseek_done
# ac1b0e8
"""
Custom RAG chain implementation built on the DeepSeek API.
"""
import os
import logging
import time
from typing import List, Dict, Any, Optional, Tuple
from langchain.schema import Document
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
# Import the custom DeepSeek LLM wrappers
from deepseek_llm import DeepSeekLLM, DeepSeekChat
# Load settings
# Settings come from the project's ``config`` module when it is importable.
# If that import fails (module absent, or one of the names missing from it),
# every setting is read from the environment instead, with the defaults below.
try:
    from config import (
        DEEPSEEK_API_KEY,
        DEEPSEEK_MODEL,
        DEEPSEEK_ENDPOINT,
        TOP_K_RETRIEVAL,
        TOP_K_RERANK,
    )
except ImportError:
    # Fallback: environment variables, each with a built-in default.
    DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY", "")
    DEEPSEEK_MODEL = os.getenv("DEEPSEEK_MODEL", "deepseek-chat")
    DEEPSEEK_ENDPOINT = os.getenv(
        "DEEPSEEK_ENDPOINT",
        "https://api.deepseek.com/v1/chat/completions",
    )
    # The two top-k values arrive as text and are converted to integers.
    TOP_K_RETRIEVAL = int(os.getenv("TOP_K_RETRIEVAL", "5"))
    TOP_K_RERANK = int(os.getenv("TOP_K_RERANK", "3"))

# Module-wide logger used by everything in this file.
logger = logging.getLogger("CustomRAGChain")
class CustomRAGChain:
    """Retrieval-augmented question answering on top of the DeepSeek API.

    A question is answered in three steps: the vector store is searched for
    related documents, the hits are rendered into one source-annotated
    context string, and the context plus the question are sent to the
    DeepSeek LLM through a prompt template.

    When the LLM call fails, ``run`` degrades step by step: a direct call to
    the chat API, then a small table of canned answers, and finally the raw
    search results.

    Attributes:
        vector_store: Object providing ``similarity_search(query, k=...)``.
        use_reranker: Stored as given; nothing in this class reads it yet.
        llm: DeepSeek LLM used by the LangChain pipeline.
        chat: DeepSeek chat client used when the pipeline call fails.
        prompt: Prompt template with ``{question}`` and ``{context}`` slots.
        chain: Pipeline that takes a bare question and retrieves on its own.
            Kept for external callers; ``run`` does not use it.
    """

    # Contexts longer than this many characters are shortened to the head and
    # tail slices below, joined by an omission marker.
    _MAX_CONTEXT_CHARS = 6000
    _CONTEXT_HEAD_CHARS = 2500
    _CONTEXT_TAIL_CHARS = 2500

    # Number of question characters echoed into log lines.
    _QUERY_PREVIEW_CHARS = 50

    # Prompt sent to the LLM (Korean). It asks the model to answer from the
    # reference information, to flag answers that come from general knowledge
    # instead, and to cite the sources.
    _PROMPT_TEMPLATE = (
        "\n"
        "다음 정보를 기반으로 질문에 정확하게 답변해주세요.\n"
        "질문: {question}\n"
        "참고 정보:\n"
        "{context}\n"
        "참고 정보에 답이 있으면 반드시 그 정보를 기반으로 답변하세요.\n"
        "참고 정보에 답이 없는 경우에는 일반적인 지식을 활용하여 답변할 수 있지만, "
        '"제공된 문서에는 이 정보가 없으나, 일반적으로는..." 식으로 시작하세요.\n'
        "답변은 정확하고 간결하게 제공하되, 가능한 참고 정보에서 근거를 찾아 설명해주세요.\n"
        "참고 정보의 출처도 함께 알려주세요.\n"
    )

    # Canned answers used only when both API paths fail. Keys are matched as
    # substrings of the lower-cased question, in this order; first hit wins.
    _PREDEFINED_ANSWERS = {
        "대한민국의 수도": "대한민국의 수도는 서울입니다.",
        "수도": "대한민국의 수도는 서울입니다.",
        "누구야": "저는 RAG 기반 질의응답 시스템입니다. 문서를 검색하고 관련 정보를 찾아드립니다.",
        "안녕": "안녕하세요! 무엇을 도와드릴까요?",
        "뭐해": "사용자의 질문에 답변하기 위해 문서를 검색하고 있습니다. 무엇을 알려드릴까요?",
    }

    def __init__(self, vector_store: Any, use_reranker: bool = False):
        """Set up the LLM, the fallback chat client, and the pipelines.

        Args:
            vector_store: Vector store instance; must offer
                ``similarity_search(query, k=...)``.
            use_reranker: Whether a reranker should be used. Only stored;
                reranking is not implemented in this class.

        Raises:
            ValueError: If ``DEEPSEEK_API_KEY`` is empty, or if constructing
                ``DeepSeekLLM`` fails (the original error is chained).
        """
        logger.info("커스텀 RAG 체인 초기화...")
        self.vector_store = vector_store
        self.use_reranker = use_reranker

        # Fail early: without a key no API call below can succeed.
        if not DEEPSEEK_API_KEY:
            logger.error("DeepSeek API 키가 설정되지 않았습니다.")
            raise ValueError("DeepSeek API 키가 설정되지 않았습니다.")

        try:
            self.llm = DeepSeekLLM(
                api_key=DEEPSEEK_API_KEY,
                model=DEEPSEEK_MODEL,
                endpoint=DEEPSEEK_ENDPOINT,
                temperature=0.3,
                max_tokens=1000,
                request_timeout=120,
                max_retries=5,
            )
            logger.info("DeepSeek LLM 초기화 성공: %s", DEEPSEEK_MODEL)
        except Exception as e:
            logger.error("DeepSeek LLM 초기화 실패: %s", e)
            # Chain the cause so the original traceback is not lost.
            raise ValueError(f"DeepSeek LLM 초기화 실패: {str(e)}") from e

        # Chat client used as a fallback when the pipeline call fails.
        self.chat = DeepSeekChat(
            api_key=DEEPSEEK_API_KEY,
            model=DEEPSEEK_MODEL,
            endpoint=DEEPSEEK_ENDPOINT,
        )

        self.prompt = PromptTemplate.from_template(self._PROMPT_TEMPLATE)

        # Pipeline used by run(): it receives a context that was already
        # retrieved, so the vector store is searched once per question.
        self._answer_chain = self.prompt | self.llm | StrOutputParser()

        # Self-retrieving pipeline: takes the bare question and looks up the
        # context itself. Kept so existing users of ``chain`` keep working.
        self.chain = (
            {"context": self._retrieve, "question": RunnablePassthrough()}
            | self.prompt
            | self.llm
            | StrOutputParser()
        )
        logger.info("커스텀 RAG 체인 초기화 완료")

    @classmethod
    def _preview(cls, query: str) -> str:
        """Return the start of ``query`` for logging, with "..." if cut."""
        limit = cls._QUERY_PREVIEW_CHARS
        return query[:limit] + ("..." if len(query) > limit else "")

    @staticmethod
    def _format_document(index: int, doc: Any) -> str:
        """Render one search hit as a numbered, source-annotated text block."""
        source = doc.metadata.get("source", "알 수 없는 출처")
        page = doc.metadata.get("page", "")
        source_info = f"{source}"
        # Only a missing/empty page is skipped. A plain truthiness test would
        # also hide page 0, which zero-based loaders use for the first page.
        if page is not None and page != "":
            source_info += f" (페이지: {page})"
        return f"[참고자료 {index}] - 출처: {source_info}\n{doc.page_content}\n"

    def _retrieve(self, query: str) -> str:
        """Search the vector store and build the context string for ``query``.

        Args:
            query: The user's question.

        Returns:
            The formatted search hits joined into one string, shortened to a
            head and a tail slice when longer than ``_MAX_CONTEXT_CHARS``.
            For an empty query, no hits, or a search error, a short Korean
            notice is returned instead; this method does not raise for
            errors inside the search.
        """
        if not query or not query.strip():
            logger.warning("빈 쿼리로 검색 시도")
            return "검색 쿼리가 비어있습니다."

        try:
            logger.info("벡터 검색 수행: '%s'", self._preview(query))
            docs = self.vector_store.similarity_search(query, k=TOP_K_RETRIEVAL)
            if not docs:
                logger.warning("검색 결과가 없습니다")
                return "관련 문서를 찾을 수 없습니다."

            context_parts = [
                self._format_document(i, doc) for i, doc in enumerate(docs, 1)
            ]
            context = "\n".join(context_parts)

            # Length is capped in characters as a stand-in for a token limit.
            if len(context) > self._MAX_CONTEXT_CHARS:
                logger.warning(
                    "컨텍스트가 너무 깁니다 (%d 문자). 제한합니다.", len(context)
                )
                context = (
                    context[: self._CONTEXT_HEAD_CHARS]
                    + "\n...(중략)...\n"
                    + context[-self._CONTEXT_TAIL_CHARS:]
                )

            logger.info(
                "컨텍스트 생성 완료: %d개 문서, %d 문자",
                len(context_parts),
                len(context),
            )
            return context
        except Exception as e:
            # Best effort by design: the failure is reported in the returned
            # text, and the traceback goes to the log.
            logger.exception("검색 중 오류: %s", e)
            return f"검색 중 오류 발생: {str(e)}"

    def _generate_answer(self, query: str, context: str) -> str:
        """Produce an answer for ``query`` from an already retrieved context.

        Tries the LangChain pipeline first, then the direct chat API, and
        falls back to ``_offline_answer`` when both fail.
        """
        try:
            response = self._answer_chain.invoke(
                {"context": context, "question": query}
            )
            logger.info("LangChain 체인 호출 성공")
            return response
        except Exception as chain_error:
            logger.error("체인 호출 실패: %s, 대체 방식 시도", chain_error)

        # Fallback 1: send the rendered prompt straight to the chat API.
        try:
            prompt = self.prompt.format(question=query, context=context)
            response = self.chat.generate([{"role": "user", "content": prompt}])
            logger.info("대체 채팅 API 호출 성공")
            return response
        except Exception as chat_error:
            logger.error("대체 채팅 API 호출 실패: %s", chat_error)

        # Fallback 2: no API is reachable.
        return self._offline_answer(query, context)

    def _offline_answer(self, query: str, context: str) -> str:
        """Answer without any API: a canned reply, else the raw search hits."""
        lowered = query.lower()
        for key, answer in self._PREDEFINED_ANSWERS.items():
            if key in lowered:
                logger.info("미리 정의된 응답 제공: %s", key)
                return answer

        logger.info("검색 결과만 표시")
        return (
            "\n"
            "API 연결 오류로 인해 검색 결과만 표시합니다.\n"
            f"질문: {query}\n"
            "검색된 관련 문서:\n"
            f"{context}\n"
            "[참고] API 연결 문제로 인해 자동 요약이 제공되지 않습니다. "
            "다시 시도하거나 다른 질문을 해보세요.\n"
        )

    def run(self, query: str) -> str:
        """Run the full RAG pipeline for a user question.

        The vector store is searched exactly once; the resulting context is
        shared by the LLM call and by every fallback.

        Args:
            query: The user's question.

        Returns:
            The model's answer, a fallback answer, or a Korean notice when
            the question is empty or an unexpected error occurs. Errors
            raised while processing are caught and reported in the text.
        """
        if not query or not query.strip():
            logger.warning("빈 쿼리로 실행 시도")
            return "질문이 비어있습니다. 질문을 입력해 주세요."

        try:
            logger.info("RAG 체인 실행: '%s'", self._preview(query))
            start_time = time.perf_counter()

            context = self._retrieve(query)
            response = self._generate_answer(query, context)

            logger.info(
                "RAG 체인 실행 완료: %.2f초", time.perf_counter() - start_time
            )
            return response
        except Exception as e:
            logger.exception("RAG 체인 실행 중 오류: %s", e)
            return f"질문 처리 중 오류가 발생했습니다: {str(e)}"