# RAG3 / fallback_rag_chain.py
# Author: jeongsoo
# Commit: ac1b0e8 ("deepseek_done")
"""
Fallback RAG chain implementation (basic features only) that calls the
DeepSeek API directly through DirectDeepSeekClient.
"""
import os
import logging
import time
from typing import List, Dict, Any, Optional, Tuple
from langchain.schema import Document
# Use the DeepSeek client directly
from direct_deepseek import DirectDeepSeekClient
# Load settings
from config import (
LLM_MODEL, USE_OPENAI, USE_DEEPSEEK,
DEEPSEEK_API_KEY, DEEPSEEK_ENDPOINT, DEEPSEEK_MODEL,
TOP_K_RETRIEVAL
)
# Logging setup
# Module-level logger for FallbackRAGChain. It is registered under a fixed
# name (not __name__), so it is the same logger regardless of import path.
_LOGGER_NAME = "FallbackRAGChain"
logger = logging.getLogger(_LOGGER_NAME)
class FallbackRAGChain:
"""
๊ธฐ๋ณธ์ ์ธ RAG ์ฒด์ธ ๊ตฌํ˜„ (๋‹จ์ˆœํ™”๋œ ๋ฒ„์ „, ๋ฌธ์ œ ํ•ด๊ฒฐ์šฉ)
์ง์ ‘ DeepSeek API ํ˜ธ์ถœ ๋ฐฉ์‹ ์‚ฌ์šฉ
"""
def __init__(self, vector_store):
"""
RAG ์ฒด์ธ ์ดˆ๊ธฐํ™”
Args:
vector_store: ๋ฒกํ„ฐ ์Šคํ† ์–ด ์ธ์Šคํ„ด์Šค
"""
logger.info("ํด๋ฐฑ RAG ์ฒด์ธ ์ดˆ๊ธฐํ™”...")
self.vector_store = vector_store
# DeepSeek ๋ชจ๋ธ ์ง์ ‘ ์ดˆ๊ธฐํ™”
if USE_DEEPSEEK and DEEPSEEK_API_KEY:
logger.info(f"DeepSeek ๋ชจ๋ธ ์ง์ ‘ ์ดˆ๊ธฐํ™”: {DEEPSEEK_MODEL}")
try:
self.client = DirectDeepSeekClient(
api_key=DEEPSEEK_API_KEY,
model_name=DEEPSEEK_MODEL
)
logger.info("DeepSeek ๋ชจ๋ธ ์ง์ ‘ ์ดˆ๊ธฐํ™” ์„ฑ๊ณต")
except Exception as e:
logger.error(f"DeepSeek ๋ชจ๋ธ ์ดˆ๊ธฐํ™” ์‹คํŒจ: {e}")
# ์˜คํ”„๋ผ์ธ ๋ชจ๋“œ๋กœ ํด๋ฐฑ
self.client = None
logger.warning("LLM์ด ์ดˆ๊ธฐํ™”๋˜์ง€ ์•Š์•„ ์˜คํ”„๋ผ์ธ ๋ชจ๋“œ๋กœ ๋™์ž‘ํ•ฉ๋‹ˆ๋‹ค.")
else:
# LLM์ด ์„ค์ •๋˜์ง€ ์•Š์Œ
logger.warning("LLM์ด ์„ค์ •๋˜์ง€ ์•Š์•„ ์˜คํ”„๋ผ์ธ ๋ชจ๋“œ๋กœ ๋™์ž‘ํ•ฉ๋‹ˆ๋‹ค.")
self.client = None
logger.info("ํด๋ฐฑ RAG ์ฒด์ธ ์ดˆ๊ธฐํ™” ์™„๋ฃŒ")
def _retrieve(self, query: str) -> str:
"""
์ฟผ๋ฆฌ์— ๋Œ€ํ•œ ๊ด€๋ จ ๋ฌธ์„œ ๊ฒ€์ƒ‰ ๋ฐ ์ปจํ…์ŠคํŠธ ๊ตฌ์„ฑ
Args:
query: ์‚ฌ์šฉ์ž ์งˆ๋ฌธ
Returns:
๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ๋ฅผ ํฌํ•จํ•œ ์ปจํ…์ŠคํŠธ ๋ฌธ์ž์—ด
"""
if not query or not query.strip():
return "๊ฒ€์ƒ‰ ์ฟผ๋ฆฌ๊ฐ€ ๋น„์–ด์žˆ์Šต๋‹ˆ๋‹ค."
try:
# ๋ฒกํ„ฐ ๊ฒ€์ƒ‰ ์ˆ˜ํ–‰
logger.info(f"๋ฒกํ„ฐ ๊ฒ€์ƒ‰: '{query[:30]}...'")
docs = self.vector_store.similarity_search(query, k=TOP_K_RETRIEVAL)
if not docs:
return "๊ด€๋ จ ๋ฌธ์„œ๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค."
# ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ ์ปจํ…์ŠคํŠธ ๊ตฌ์„ฑ
context_parts = []
for i, doc in enumerate(docs, 1):
source = doc.metadata.get("source", "์•Œ ์ˆ˜ ์—†๋Š” ์ถœ์ฒ˜")
page = doc.metadata.get("page", "")
source_info = f"{source}"
if page:
source_info += f" (ํŽ˜์ด์ง€: {page})"
context_parts.append(f"[์ฐธ๊ณ ์ž๋ฃŒ {i}] - ์ถœ์ฒ˜: {source_info}\n{doc.page_content}\n")
context = "\n".join(context_parts)
# ์ปจํ…์ŠคํŠธ ๊ธธ์ด ์ œํ•œ (ํ† ํฐ ์ˆ˜ ์ œํ•œ)
if len(context) > 6000:
logger.warning(f"์ปจํ…์ŠคํŠธ๊ฐ€ ๋„ˆ๋ฌด ๊น๋‹ˆ๋‹ค ({len(context)} ๋ฌธ์ž). ์ œํ•œํ•ฉ๋‹ˆ๋‹ค.")
context = context[:2500] + "\n...(์ค‘๋žต)...\n" + context[-2500:]
return context
except Exception as e:
logger.error(f"๊ฒ€์ƒ‰ ์ค‘ ์˜ค๋ฅ˜: {e}")
return f"๊ฒ€์ƒ‰ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}"
def _generate_prompt(self, query: str, context: str) -> List[Dict[str, str]]:
"""
ํ”„๋กฌํ”„ํŠธ ์ƒ์„ฑ (DeepSeek API ํ˜•์‹)
Args:
query: ์‚ฌ์šฉ์ž ์งˆ๋ฌธ
context: ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ ์ปจํ…์ŠคํŠธ
Returns:
DeepSeek API์šฉ messages ํ˜•์‹
"""
# ์‹œ์Šคํ…œ ํ”„๋กฌํ”„ํŠธ
system_prompt = """๋‹ค์Œ ์ •๋ณด๋ฅผ ๊ธฐ๋ฐ˜์œผ๋กœ ์งˆ๋ฌธ์— ์ •ํ™•ํ•˜๊ฒŒ ๋‹ต๋ณ€ํ•ด์ฃผ์„ธ์š”.
์ฐธ๊ณ  ์ •๋ณด์— ๋‹ต์ด ์žˆ์œผ๋ฉด ๋ฐ˜๋“œ์‹œ ๊ทธ ์ •๋ณด๋ฅผ ๊ธฐ๋ฐ˜์œผ๋กœ ๋‹ต๋ณ€ํ•˜์„ธ์š”.
์ฐธ๊ณ  ์ •๋ณด์— ๋‹ต์ด ์—†๋Š” ๊ฒฝ์šฐ์—๋Š” ์ผ๋ฐ˜์ ์ธ ์ง€์‹์„ ํ™œ์šฉํ•˜์—ฌ ๋‹ต๋ณ€ํ•  ์ˆ˜ ์žˆ์ง€๋งŒ, "์ œ๊ณต๋œ ๋ฌธ์„œ์—๋Š” ์ด ์ •๋ณด๊ฐ€ ์—†์œผ๋‚˜, ์ผ๋ฐ˜์ ์œผ๋กœ๋Š”..." ์‹์œผ๋กœ ์‹œ์ž‘ํ•˜์„ธ์š”.
๋‹ต๋ณ€์€ ์ •ํ™•ํ•˜๊ณ  ๊ฐ„๊ฒฐํ•˜๊ฒŒ ์ œ๊ณตํ•˜๋˜, ๊ฐ€๋Šฅํ•œ ์ฐธ๊ณ  ์ •๋ณด์—์„œ ๊ทผ๊ฑฐ๋ฅผ ์ฐพ์•„ ์„ค๋ช…ํ•ด์ฃผ์„ธ์š”.
์ฐธ๊ณ  ์ •๋ณด์˜ ์ถœ์ฒ˜๋„ ํ•จ๊ป˜ ์•Œ๋ ค์ฃผ์„ธ์š”."""
# ์‚ฌ์šฉ์ž ํ”„๋กฌํ”„ํŠธ (์งˆ๋ฌธ๊ณผ ์ปจํ…์ŠคํŠธ ํฌํ•จ)
user_prompt = f"""์งˆ๋ฌธ: {query}
์ฐธ๊ณ  ์ •๋ณด:
{context}"""
# DeepSeek API์— ๋งž๋Š” ๋ฉ”์‹œ์ง€ ํฌ๋งท
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
]
return messages
def _generate_simple_response(self, query: str, context: str) -> str:
"""
๊ฐ„๋‹จํ•œ ์˜คํ”„๋ผ์ธ ์‘๋‹ต ์ƒ์„ฑ (LLM์ด ์—†์„ ๋•Œ ์‚ฌ์šฉ)
"""
# ๊ธฐ๋ณธ ์ œ๊ณต ์‘๋‹ต (์ผ๋ฐ˜์ ์ธ ์งˆ๋ฌธ์— ๋Œ€ํ•œ ์ •ํ•ด์ง„ ์‘๋‹ต)
predefined_answers = {
"๋Œ€ํ•œ๋ฏผ๊ตญ์˜ ์ˆ˜๋„": "๋Œ€ํ•œ๋ฏผ๊ตญ์˜ ์ˆ˜๋„๋Š” ์„œ์šธ์ž…๋‹ˆ๋‹ค.",
"์ˆ˜๋„": "๋Œ€ํ•œ๋ฏผ๊ตญ์˜ ์ˆ˜๋„๋Š” ์„œ์šธ์ž…๋‹ˆ๋‹ค.",
"๋ˆ„๊ตฌ์•ผ": "์ €๋Š” RAG ๊ธฐ๋ฐ˜ ์งˆ์˜์‘๋‹ต ์‹œ์Šคํ…œ์ž…๋‹ˆ๋‹ค. ๋ฌธ์„œ๋ฅผ ๊ฒ€์ƒ‰ํ•˜๊ณ  ๊ด€๋ จ ์ •๋ณด๋ฅผ ์ฐพ์•„๋“œ๋ฆฝ๋‹ˆ๋‹ค.",
"์•ˆ๋…•": "์•ˆ๋…•ํ•˜์„ธ์š”! ๋ฌด์—‡์„ ๋„์™€๋“œ๋ฆด๊นŒ์š”?",
"๋ญํ•ด": "์‚ฌ์šฉ์ž์˜ ์งˆ๋ฌธ์— ๋‹ต๋ณ€ํ•˜๊ธฐ ์œ„ํ•ด ๋ฌธ์„œ๋ฅผ ๊ฒ€์ƒ‰ํ•˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค. ๋ฌด์—‡์„ ์•Œ๋ ค๋“œ๋ฆด๊นŒ์š”?"
}
# ์งˆ๋ฌธ์— ๋งž๋Š” ๋ฏธ๋ฆฌ ์ •์˜๋œ ์‘๋‹ต์ด ์žˆ๋Š”์ง€ ํ™•์ธ
for key, answer in predefined_answers.items():
if key in query.lower():
return answer
# ๋ฏธ๋ฆฌ ์ •์˜๋œ ์‘๋‹ต์ด ์—†์œผ๋ฉด ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ๋งŒ ํ‘œ์‹œ
return f"""
ํ˜„์žฌ LLM API ์—ฐ๊ฒฐ์— ๋ฌธ์ œ๊ฐ€ ์žˆ์–ด ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ๋งŒ ํ‘œ์‹œํ•ฉ๋‹ˆ๋‹ค.
์งˆ๋ฌธ: {query}
๊ฒ€์ƒ‰๋œ ๊ด€๋ จ ๋ฌธ์„œ:
{context}
[์ฐธ๊ณ ] ๊ด€๋ จ ์ •๋ณด๋ฅผ ์ฐพ์œผ์…จ๋‚˜์š”? API ์—ฐ๊ฒฐ ๋ฌธ์ œ๋กœ ์ธํ•ด ์ž๋™ ์š”์•ฝ์ด ์ œ๊ณต๋˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค. ๋‹ค์‹œ ์‹œ๋„ํ•˜๊ฑฐ๋‚˜ ๋‹ค๋ฅธ ์งˆ๋ฌธ์„ ํ•ด๋ณด์„ธ์š”.
"""
def run(self, query: str) -> str:
"""
์‚ฌ์šฉ์ž ์ฟผ๋ฆฌ์— ๋Œ€ํ•œ RAG ํŒŒ์ดํ”„๋ผ์ธ ์‹คํ–‰
Args:
query: ์‚ฌ์šฉ์ž ์งˆ๋ฌธ
Returns:
๋ชจ๋ธ ์‘๋‹ต ๋ฌธ์ž์—ด
"""
if not query or not query.strip():
return "์งˆ๋ฌธ์ด ๋น„์–ด์žˆ์Šต๋‹ˆ๋‹ค. ์งˆ๋ฌธ์„ ์ž…๋ ฅํ•ด ์ฃผ์„ธ์š”."
try:
logger.info(f"RAG ์ฒด์ธ ์‹คํ–‰: '{query[:30]}...'")
# ๋ฌธ์„œ ๊ฒ€์ƒ‰
context = self._retrieve(query)
# LLM์ด ์ดˆ๊ธฐํ™”๋˜์ง€ ์•Š์€ ๊ฒฝ์šฐ ์˜คํ”„๋ผ์ธ ์‘๋‹ต
if self.client is None:
logger.warning("LLM์ด ์ดˆ๊ธฐํ™”๋˜์ง€ ์•Š์•„ ์˜คํ”„๋ผ์ธ ์‘๋‹ต ์ƒ์„ฑ")
return self._generate_simple_response(query, context)
# ํ”„๋กฌํ”„ํŠธ ๊ตฌ์„ฑ
messages = self._generate_prompt(query, context)
# ์‘๋‹ต ์ƒ์„ฑ (์ตœ๋Œ€ 3ํšŒ ์‹œ๋„)
max_retries = 3
retry_delay = 1.0
for attempt in range(max_retries):
try:
logger.info(f"์‘๋‹ต ์ƒ์„ฑ ์‹œ๋„ ({attempt+1}/{max_retries})")
# ์ง์ ‘ DeepSeek API ํ˜ธ์ถœ
response = self.client.chat(messages)
if response["success"]:
logger.info(f"์‘๋‹ต ์ƒ์„ฑ ์„ฑ๊ณต (๊ธธ์ด: {len(response['response'])})")
return response["response"]
else:
logger.error(f"์‘๋‹ต ์ƒ์„ฑ ์‹คํŒจ: {response['message']}")
if attempt < max_retries - 1:
logger.info(f"{retry_delay}์ดˆ ํ›„ ์žฌ์‹œ๋„...")
time.sleep(retry_delay)
retry_delay *= 2
else:
# ๋ชจ๋“  ์‹œ๋„ ์‹คํŒจ ์‹œ ์˜คํ”„๋ผ์ธ ์‘๋‹ต
logger.warning("์ตœ๋Œ€ ์žฌ์‹œ๋„ ํšŸ์ˆ˜ ์ดˆ๊ณผ, ์˜คํ”„๋ผ์ธ ์‘๋‹ต ์ƒ์„ฑ")
return self._generate_simple_response(query, context)
except Exception as e:
logger.error(f"์‘๋‹ต ์ƒ์„ฑ ์ค‘ ์˜ค๋ฅ˜: {e}")
if attempt < max_retries - 1:
logger.info(f"{retry_delay}์ดˆ ํ›„ ์žฌ์‹œ๋„...")
time.sleep(retry_delay)
retry_delay *= 2
else:
# ๋ชจ๋“  ์‹œ๋„ ์‹คํŒจ ์‹œ ์˜คํ”„๋ผ์ธ ์‘๋‹ต ์ƒ์„ฑ
return self._generate_simple_response(query, context)
except Exception as e:
logger.error(f"RAG ์ฒด์ธ ์‹คํ–‰ ์ค‘ ์˜ค๋ฅ˜: {e}")
return f"์งˆ๋ฌธ ์ฒ˜๋ฆฌ ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {str(e)}"