"""
Fallback RAG chain implementation (basic functionality only) that calls the
DeepSeek API directly.
"""
|
import logging
import os
import time
from typing import Any, Dict, List, Optional, Tuple

from langchain.schema import Document

from config import (
    DEEPSEEK_API_KEY,
    DEEPSEEK_ENDPOINT,
    DEEPSEEK_MODEL,
    LLM_MODEL,
    TOP_K_RETRIEVAL,
    USE_DEEPSEEK,
    USE_OPENAI,
)
from direct_deepseek import DirectDeepSeekClient
|
|
|
|
|
# Module-level logger shared by every FallbackRAGChain instance. The explicit
# name (rather than __name__) is kept so existing logging configuration that
# targets "FallbackRAGChain" continues to apply.
logger = logging.getLogger("FallbackRAGChain")
|
|
|
class FallbackRAGChain: |
|
""" |
|
๊ธฐ๋ณธ์ ์ธ RAG ์ฒด์ธ ๊ตฌํ (๋จ์ํ๋ ๋ฒ์ , ๋ฌธ์ ํด๊ฒฐ์ฉ) |
|
์ง์ DeepSeek API ํธ์ถ ๋ฐฉ์ ์ฌ์ฉ |
|
""" |
|
|
|
def __init__(self, vector_store): |
|
""" |
|
RAG ์ฒด์ธ ์ด๊ธฐํ |
|
|
|
Args: |
|
vector_store: ๋ฒกํฐ ์คํ ์ด ์ธ์คํด์ค |
|
""" |
|
logger.info("ํด๋ฐฑ RAG ์ฒด์ธ ์ด๊ธฐํ...") |
|
self.vector_store = vector_store |
|
|
|
|
|
if USE_DEEPSEEK and DEEPSEEK_API_KEY: |
|
logger.info(f"DeepSeek ๋ชจ๋ธ ์ง์ ์ด๊ธฐํ: {DEEPSEEK_MODEL}") |
|
try: |
|
self.client = DirectDeepSeekClient( |
|
api_key=DEEPSEEK_API_KEY, |
|
model_name=DEEPSEEK_MODEL |
|
) |
|
logger.info("DeepSeek ๋ชจ๋ธ ์ง์ ์ด๊ธฐํ ์ฑ๊ณต") |
|
except Exception as e: |
|
logger.error(f"DeepSeek ๋ชจ๋ธ ์ด๊ธฐํ ์คํจ: {e}") |
|
|
|
self.client = None |
|
logger.warning("LLM์ด ์ด๊ธฐํ๋์ง ์์ ์คํ๋ผ์ธ ๋ชจ๋๋ก ๋์ํฉ๋๋ค.") |
|
else: |
|
|
|
logger.warning("LLM์ด ์ค์ ๋์ง ์์ ์คํ๋ผ์ธ ๋ชจ๋๋ก ๋์ํฉ๋๋ค.") |
|
self.client = None |
|
|
|
logger.info("ํด๋ฐฑ RAG ์ฒด์ธ ์ด๊ธฐํ ์๋ฃ") |
|
|
|
def _retrieve(self, query: str) -> str: |
|
""" |
|
์ฟผ๋ฆฌ์ ๋ํ ๊ด๋ จ ๋ฌธ์ ๊ฒ์ ๋ฐ ์ปจํ
์คํธ ๊ตฌ์ฑ |
|
|
|
Args: |
|
query: ์ฌ์ฉ์ ์ง๋ฌธ |
|
|
|
Returns: |
|
๊ฒ์ ๊ฒฐ๊ณผ๋ฅผ ํฌํจํ ์ปจํ
์คํธ ๋ฌธ์์ด |
|
""" |
|
if not query or not query.strip(): |
|
return "๊ฒ์ ์ฟผ๋ฆฌ๊ฐ ๋น์ด์์ต๋๋ค." |
|
|
|
try: |
|
|
|
logger.info(f"๋ฒกํฐ ๊ฒ์: '{query[:30]}...'") |
|
docs = self.vector_store.similarity_search(query, k=TOP_K_RETRIEVAL) |
|
|
|
if not docs: |
|
return "๊ด๋ จ ๋ฌธ์๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค." |
|
|
|
|
|
context_parts = [] |
|
for i, doc in enumerate(docs, 1): |
|
source = doc.metadata.get("source", "์ ์ ์๋ ์ถ์ฒ") |
|
page = doc.metadata.get("page", "") |
|
source_info = f"{source}" |
|
if page: |
|
source_info += f" (ํ์ด์ง: {page})" |
|
|
|
context_parts.append(f"[์ฐธ๊ณ ์๋ฃ {i}] - ์ถ์ฒ: {source_info}\n{doc.page_content}\n") |
|
|
|
context = "\n".join(context_parts) |
|
|
|
|
|
if len(context) > 6000: |
|
logger.warning(f"์ปจํ
์คํธ๊ฐ ๋๋ฌด ๊น๋๋ค ({len(context)} ๋ฌธ์). ์ ํํฉ๋๋ค.") |
|
context = context[:2500] + "\n...(์ค๋ต)...\n" + context[-2500:] |
|
|
|
return context |
|
|
|
except Exception as e: |
|
logger.error(f"๊ฒ์ ์ค ์ค๋ฅ: {e}") |
|
return f"๊ฒ์ ์ค ์ค๋ฅ ๋ฐ์: {str(e)}" |
|
|
|
def _generate_prompt(self, query: str, context: str) -> List[Dict[str, str]]: |
|
""" |
|
ํ๋กฌํํธ ์์ฑ (DeepSeek API ํ์) |
|
|
|
Args: |
|
query: ์ฌ์ฉ์ ์ง๋ฌธ |
|
context: ๊ฒ์ ๊ฒฐ๊ณผ ์ปจํ
์คํธ |
|
|
|
Returns: |
|
DeepSeek API์ฉ messages ํ์ |
|
""" |
|
|
|
system_prompt = """๋ค์ ์ ๋ณด๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ์ง๋ฌธ์ ์ ํํ๊ฒ ๋ต๋ณํด์ฃผ์ธ์. |
|
์ฐธ๊ณ ์ ๋ณด์ ๋ต์ด ์์ผ๋ฉด ๋ฐ๋์ ๊ทธ ์ ๋ณด๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ๋ต๋ณํ์ธ์. |
|
์ฐธ๊ณ ์ ๋ณด์ ๋ต์ด ์๋ ๊ฒฝ์ฐ์๋ ์ผ๋ฐ์ ์ธ ์ง์์ ํ์ฉํ์ฌ ๋ต๋ณํ ์ ์์ง๋ง, "์ ๊ณต๋ ๋ฌธ์์๋ ์ด ์ ๋ณด๊ฐ ์์ผ๋, ์ผ๋ฐ์ ์ผ๋ก๋..." ์์ผ๋ก ์์ํ์ธ์. |
|
๋ต๋ณ์ ์ ํํ๊ณ ๊ฐ๊ฒฐํ๊ฒ ์ ๊ณตํ๋, ๊ฐ๋ฅํ ์ฐธ๊ณ ์ ๋ณด์์ ๊ทผ๊ฑฐ๋ฅผ ์ฐพ์ ์ค๋ช
ํด์ฃผ์ธ์. |
|
์ฐธ๊ณ ์ ๋ณด์ ์ถ์ฒ๋ ํจ๊ป ์๋ ค์ฃผ์ธ์.""" |
|
|
|
|
|
user_prompt = f"""์ง๋ฌธ: {query} |
|
|
|
์ฐธ๊ณ ์ ๋ณด: |
|
{context}""" |
|
|
|
|
|
messages = [ |
|
{"role": "system", "content": system_prompt}, |
|
{"role": "user", "content": user_prompt} |
|
] |
|
|
|
return messages |
|
|
|
def _generate_simple_response(self, query: str, context: str) -> str: |
|
""" |
|
๊ฐ๋จํ ์คํ๋ผ์ธ ์๋ต ์์ฑ (LLM์ด ์์ ๋ ์ฌ์ฉ) |
|
""" |
|
|
|
predefined_answers = { |
|
"๋ํ๋ฏผ๊ตญ์ ์๋": "๋ํ๋ฏผ๊ตญ์ ์๋๋ ์์ธ์
๋๋ค.", |
|
"์๋": "๋ํ๋ฏผ๊ตญ์ ์๋๋ ์์ธ์
๋๋ค.", |
|
"๋๊ตฌ์ผ": "์ ๋ RAG ๊ธฐ๋ฐ ์ง์์๋ต ์์คํ
์
๋๋ค. ๋ฌธ์๋ฅผ ๊ฒ์ํ๊ณ ๊ด๋ จ ์ ๋ณด๋ฅผ ์ฐพ์๋๋ฆฝ๋๋ค.", |
|
"์๋
": "์๋
ํ์ธ์! ๋ฌด์์ ๋์๋๋ฆด๊น์?", |
|
"๋ญํด": "์ฌ์ฉ์์ ์ง๋ฌธ์ ๋ต๋ณํ๊ธฐ ์ํด ๋ฌธ์๋ฅผ ๊ฒ์ํ๊ณ ์์ต๋๋ค. ๋ฌด์์ ์๋ ค๋๋ฆด๊น์?" |
|
} |
|
|
|
|
|
for key, answer in predefined_answers.items(): |
|
if key in query.lower(): |
|
return answer |
|
|
|
|
|
return f""" |
|
ํ์ฌ LLM API ์ฐ๊ฒฐ์ ๋ฌธ์ ๊ฐ ์์ด ๊ฒ์ ๊ฒฐ๊ณผ๋ง ํ์ํฉ๋๋ค. |
|
|
|
์ง๋ฌธ: {query} |
|
|
|
๊ฒ์๋ ๊ด๋ จ ๋ฌธ์: |
|
{context} |
|
|
|
[์ฐธ๊ณ ] ๊ด๋ จ ์ ๋ณด๋ฅผ ์ฐพ์ผ์
จ๋์? API ์ฐ๊ฒฐ ๋ฌธ์ ๋ก ์ธํด ์๋ ์์ฝ์ด ์ ๊ณต๋์ง ์์ต๋๋ค. ๋ค์ ์๋ํ๊ฑฐ๋ ๋ค๋ฅธ ์ง๋ฌธ์ ํด๋ณด์ธ์. |
|
""" |
|
|
|
def run(self, query: str) -> str: |
|
""" |
|
์ฌ์ฉ์ ์ฟผ๋ฆฌ์ ๋ํ RAG ํ์ดํ๋ผ์ธ ์คํ |
|
|
|
Args: |
|
query: ์ฌ์ฉ์ ์ง๋ฌธ |
|
|
|
Returns: |
|
๋ชจ๋ธ ์๋ต ๋ฌธ์์ด |
|
""" |
|
if not query or not query.strip(): |
|
return "์ง๋ฌธ์ด ๋น์ด์์ต๋๋ค. ์ง๋ฌธ์ ์
๋ ฅํด ์ฃผ์ธ์." |
|
|
|
try: |
|
logger.info(f"RAG ์ฒด์ธ ์คํ: '{query[:30]}...'") |
|
|
|
|
|
context = self._retrieve(query) |
|
|
|
|
|
if self.client is None: |
|
logger.warning("LLM์ด ์ด๊ธฐํ๋์ง ์์ ์คํ๋ผ์ธ ์๋ต ์์ฑ") |
|
return self._generate_simple_response(query, context) |
|
|
|
|
|
messages = self._generate_prompt(query, context) |
|
|
|
|
|
max_retries = 3 |
|
retry_delay = 1.0 |
|
|
|
for attempt in range(max_retries): |
|
try: |
|
logger.info(f"์๋ต ์์ฑ ์๋ ({attempt+1}/{max_retries})") |
|
|
|
|
|
response = self.client.chat(messages) |
|
|
|
if response["success"]: |
|
logger.info(f"์๋ต ์์ฑ ์ฑ๊ณต (๊ธธ์ด: {len(response['response'])})") |
|
return response["response"] |
|
else: |
|
logger.error(f"์๋ต ์์ฑ ์คํจ: {response['message']}") |
|
if attempt < max_retries - 1: |
|
logger.info(f"{retry_delay}์ด ํ ์ฌ์๋...") |
|
time.sleep(retry_delay) |
|
retry_delay *= 2 |
|
else: |
|
|
|
logger.warning("์ต๋ ์ฌ์๋ ํ์ ์ด๊ณผ, ์คํ๋ผ์ธ ์๋ต ์์ฑ") |
|
return self._generate_simple_response(query, context) |
|
except Exception as e: |
|
logger.error(f"์๋ต ์์ฑ ์ค ์ค๋ฅ: {e}") |
|
if attempt < max_retries - 1: |
|
logger.info(f"{retry_delay}์ด ํ ์ฌ์๋...") |
|
time.sleep(retry_delay) |
|
retry_delay *= 2 |
|
else: |
|
|
|
return self._generate_simple_response(query, context) |
|
|
|
except Exception as e: |
|
logger.error(f"RAG ์ฒด์ธ ์คํ ์ค ์ค๋ฅ: {e}") |
|
return f"์ง๋ฌธ ์ฒ๋ฆฌ ์ค ์ค๋ฅ๊ฐ ๋ฐ์ํ์ต๋๋ค: {str(e)}" |