Spaces:
Paused
Paused
""" | |
ํด๋ฐฑ RAG ์ฒด์ธ ๊ตฌํ (๊ธฐ๋ณธ์ ์ธ ๊ธฐ๋ฅ๋ง ํฌํจ) - ์ง์ DeepSeek API ํธ์ถ ๋ฐฉ์ | |
""" | |
import os | |
import logging | |
import time | |
from typing import List, Dict, Any, Optional, Tuple | |
from langchain.schema import Document | |
# ์ง์ DeepSeek ํด๋ผ์ด์ธํธ ์ฌ์ฉ | |
from direct_deepseek import DirectDeepSeekClient | |
# ์ค์ ๊ฐ์ ธ์ค๊ธฐ | |
from config import ( | |
LLM_MODEL, USE_OPENAI, USE_DEEPSEEK, | |
DEEPSEEK_API_KEY, DEEPSEEK_ENDPOINT, DEEPSEEK_MODEL, | |
TOP_K_RETRIEVAL | |
) | |
# ๋ก๊น ์ค์ | |
logger = logging.getLogger("FallbackRAGChain") | |
class FallbackRAGChain: | |
""" | |
๊ธฐ๋ณธ์ ์ธ RAG ์ฒด์ธ ๊ตฌํ (๋จ์ํ๋ ๋ฒ์ , ๋ฌธ์ ํด๊ฒฐ์ฉ) | |
์ง์ DeepSeek API ํธ์ถ ๋ฐฉ์ ์ฌ์ฉ | |
""" | |
def __init__(self, vector_store): | |
""" | |
RAG ์ฒด์ธ ์ด๊ธฐํ | |
Args: | |
vector_store: ๋ฒกํฐ ์คํ ์ด ์ธ์คํด์ค | |
""" | |
logger.info("ํด๋ฐฑ RAG ์ฒด์ธ ์ด๊ธฐํ...") | |
self.vector_store = vector_store | |
# DeepSeek ๋ชจ๋ธ ์ง์ ์ด๊ธฐํ | |
if USE_DEEPSEEK and DEEPSEEK_API_KEY: | |
logger.info(f"DeepSeek ๋ชจ๋ธ ์ง์ ์ด๊ธฐํ: {DEEPSEEK_MODEL}") | |
try: | |
self.client = DirectDeepSeekClient( | |
api_key=DEEPSEEK_API_KEY, | |
model_name=DEEPSEEK_MODEL | |
) | |
logger.info("DeepSeek ๋ชจ๋ธ ์ง์ ์ด๊ธฐํ ์ฑ๊ณต") | |
except Exception as e: | |
logger.error(f"DeepSeek ๋ชจ๋ธ ์ด๊ธฐํ ์คํจ: {e}") | |
# ์คํ๋ผ์ธ ๋ชจ๋๋ก ํด๋ฐฑ | |
self.client = None | |
logger.warning("LLM์ด ์ด๊ธฐํ๋์ง ์์ ์คํ๋ผ์ธ ๋ชจ๋๋ก ๋์ํฉ๋๋ค.") | |
else: | |
# LLM์ด ์ค์ ๋์ง ์์ | |
logger.warning("LLM์ด ์ค์ ๋์ง ์์ ์คํ๋ผ์ธ ๋ชจ๋๋ก ๋์ํฉ๋๋ค.") | |
self.client = None | |
logger.info("ํด๋ฐฑ RAG ์ฒด์ธ ์ด๊ธฐํ ์๋ฃ") | |
def _retrieve(self, query: str) -> str: | |
""" | |
์ฟผ๋ฆฌ์ ๋ํ ๊ด๋ จ ๋ฌธ์ ๊ฒ์ ๋ฐ ์ปจํ ์คํธ ๊ตฌ์ฑ | |
Args: | |
query: ์ฌ์ฉ์ ์ง๋ฌธ | |
Returns: | |
๊ฒ์ ๊ฒฐ๊ณผ๋ฅผ ํฌํจํ ์ปจํ ์คํธ ๋ฌธ์์ด | |
""" | |
if not query or not query.strip(): | |
return "๊ฒ์ ์ฟผ๋ฆฌ๊ฐ ๋น์ด์์ต๋๋ค." | |
try: | |
# ๋ฒกํฐ ๊ฒ์ ์ํ | |
logger.info(f"๋ฒกํฐ ๊ฒ์: '{query[:30]}...'") | |
docs = self.vector_store.similarity_search(query, k=TOP_K_RETRIEVAL) | |
if not docs: | |
return "๊ด๋ จ ๋ฌธ์๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค." | |
# ๊ฒ์ ๊ฒฐ๊ณผ ์ปจํ ์คํธ ๊ตฌ์ฑ | |
context_parts = [] | |
for i, doc in enumerate(docs, 1): | |
source = doc.metadata.get("source", "์ ์ ์๋ ์ถ์ฒ") | |
page = doc.metadata.get("page", "") | |
source_info = f"{source}" | |
if page: | |
source_info += f" (ํ์ด์ง: {page})" | |
context_parts.append(f"[์ฐธ๊ณ ์๋ฃ {i}] - ์ถ์ฒ: {source_info}\n{doc.page_content}\n") | |
context = "\n".join(context_parts) | |
# ์ปจํ ์คํธ ๊ธธ์ด ์ ํ (ํ ํฐ ์ ์ ํ) | |
if len(context) > 6000: | |
logger.warning(f"์ปจํ ์คํธ๊ฐ ๋๋ฌด ๊น๋๋ค ({len(context)} ๋ฌธ์). ์ ํํฉ๋๋ค.") | |
context = context[:2500] + "\n...(์ค๋ต)...\n" + context[-2500:] | |
return context | |
except Exception as e: | |
logger.error(f"๊ฒ์ ์ค ์ค๋ฅ: {e}") | |
return f"๊ฒ์ ์ค ์ค๋ฅ ๋ฐ์: {str(e)}" | |
def _generate_prompt(self, query: str, context: str) -> List[Dict[str, str]]: | |
""" | |
ํ๋กฌํํธ ์์ฑ (DeepSeek API ํ์) | |
Args: | |
query: ์ฌ์ฉ์ ์ง๋ฌธ | |
context: ๊ฒ์ ๊ฒฐ๊ณผ ์ปจํ ์คํธ | |
Returns: | |
DeepSeek API์ฉ messages ํ์ | |
""" | |
# ์์คํ ํ๋กฌํํธ | |
system_prompt = """๋ค์ ์ ๋ณด๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ์ง๋ฌธ์ ์ ํํ๊ฒ ๋ต๋ณํด์ฃผ์ธ์. | |
์ฐธ๊ณ ์ ๋ณด์ ๋ต์ด ์์ผ๋ฉด ๋ฐ๋์ ๊ทธ ์ ๋ณด๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ๋ต๋ณํ์ธ์. | |
์ฐธ๊ณ ์ ๋ณด์ ๋ต์ด ์๋ ๊ฒฝ์ฐ์๋ ์ผ๋ฐ์ ์ธ ์ง์์ ํ์ฉํ์ฌ ๋ต๋ณํ ์ ์์ง๋ง, "์ ๊ณต๋ ๋ฌธ์์๋ ์ด ์ ๋ณด๊ฐ ์์ผ๋, ์ผ๋ฐ์ ์ผ๋ก๋..." ์์ผ๋ก ์์ํ์ธ์. | |
๋ต๋ณ์ ์ ํํ๊ณ ๊ฐ๊ฒฐํ๊ฒ ์ ๊ณตํ๋, ๊ฐ๋ฅํ ์ฐธ๊ณ ์ ๋ณด์์ ๊ทผ๊ฑฐ๋ฅผ ์ฐพ์ ์ค๋ช ํด์ฃผ์ธ์. | |
์ฐธ๊ณ ์ ๋ณด์ ์ถ์ฒ๋ ํจ๊ป ์๋ ค์ฃผ์ธ์.""" | |
# ์ฌ์ฉ์ ํ๋กฌํํธ (์ง๋ฌธ๊ณผ ์ปจํ ์คํธ ํฌํจ) | |
user_prompt = f"""์ง๋ฌธ: {query} | |
์ฐธ๊ณ ์ ๋ณด: | |
{context}""" | |
# DeepSeek API์ ๋ง๋ ๋ฉ์์ง ํฌ๋งท | |
messages = [ | |
{"role": "system", "content": system_prompt}, | |
{"role": "user", "content": user_prompt} | |
] | |
return messages | |
def _generate_simple_response(self, query: str, context: str) -> str: | |
""" | |
๊ฐ๋จํ ์คํ๋ผ์ธ ์๋ต ์์ฑ (LLM์ด ์์ ๋ ์ฌ์ฉ) | |
""" | |
# ๊ธฐ๋ณธ ์ ๊ณต ์๋ต (์ผ๋ฐ์ ์ธ ์ง๋ฌธ์ ๋ํ ์ ํด์ง ์๋ต) | |
predefined_answers = { | |
"๋ํ๋ฏผ๊ตญ์ ์๋": "๋ํ๋ฏผ๊ตญ์ ์๋๋ ์์ธ์ ๋๋ค.", | |
"์๋": "๋ํ๋ฏผ๊ตญ์ ์๋๋ ์์ธ์ ๋๋ค.", | |
"๋๊ตฌ์ผ": "์ ๋ RAG ๊ธฐ๋ฐ ์ง์์๋ต ์์คํ ์ ๋๋ค. ๋ฌธ์๋ฅผ ๊ฒ์ํ๊ณ ๊ด๋ จ ์ ๋ณด๋ฅผ ์ฐพ์๋๋ฆฝ๋๋ค.", | |
"์๋ ": "์๋ ํ์ธ์! ๋ฌด์์ ๋์๋๋ฆด๊น์?", | |
"๋ญํด": "์ฌ์ฉ์์ ์ง๋ฌธ์ ๋ต๋ณํ๊ธฐ ์ํด ๋ฌธ์๋ฅผ ๊ฒ์ํ๊ณ ์์ต๋๋ค. ๋ฌด์์ ์๋ ค๋๋ฆด๊น์?" | |
} | |
# ์ง๋ฌธ์ ๋ง๋ ๋ฏธ๋ฆฌ ์ ์๋ ์๋ต์ด ์๋์ง ํ์ธ | |
for key, answer in predefined_answers.items(): | |
if key in query.lower(): | |
return answer | |
# ๋ฏธ๋ฆฌ ์ ์๋ ์๋ต์ด ์์ผ๋ฉด ๊ฒ์ ๊ฒฐ๊ณผ๋ง ํ์ | |
return f""" | |
ํ์ฌ LLM API ์ฐ๊ฒฐ์ ๋ฌธ์ ๊ฐ ์์ด ๊ฒ์ ๊ฒฐ๊ณผ๋ง ํ์ํฉ๋๋ค. | |
์ง๋ฌธ: {query} | |
๊ฒ์๋ ๊ด๋ จ ๋ฌธ์: | |
{context} | |
[์ฐธ๊ณ ] ๊ด๋ จ ์ ๋ณด๋ฅผ ์ฐพ์ผ์ จ๋์? API ์ฐ๊ฒฐ ๋ฌธ์ ๋ก ์ธํด ์๋ ์์ฝ์ด ์ ๊ณต๋์ง ์์ต๋๋ค. ๋ค์ ์๋ํ๊ฑฐ๋ ๋ค๋ฅธ ์ง๋ฌธ์ ํด๋ณด์ธ์. | |
""" | |
def run(self, query: str) -> str: | |
""" | |
์ฌ์ฉ์ ์ฟผ๋ฆฌ์ ๋ํ RAG ํ์ดํ๋ผ์ธ ์คํ | |
Args: | |
query: ์ฌ์ฉ์ ์ง๋ฌธ | |
Returns: | |
๋ชจ๋ธ ์๋ต ๋ฌธ์์ด | |
""" | |
if not query or not query.strip(): | |
return "์ง๋ฌธ์ด ๋น์ด์์ต๋๋ค. ์ง๋ฌธ์ ์ ๋ ฅํด ์ฃผ์ธ์." | |
try: | |
logger.info(f"RAG ์ฒด์ธ ์คํ: '{query[:30]}...'") | |
# ๋ฌธ์ ๊ฒ์ | |
context = self._retrieve(query) | |
# LLM์ด ์ด๊ธฐํ๋์ง ์์ ๊ฒฝ์ฐ ์คํ๋ผ์ธ ์๋ต | |
if self.client is None: | |
logger.warning("LLM์ด ์ด๊ธฐํ๋์ง ์์ ์คํ๋ผ์ธ ์๋ต ์์ฑ") | |
return self._generate_simple_response(query, context) | |
# ํ๋กฌํํธ ๊ตฌ์ฑ | |
messages = self._generate_prompt(query, context) | |
# ์๋ต ์์ฑ (์ต๋ 3ํ ์๋) | |
max_retries = 3 | |
retry_delay = 1.0 | |
for attempt in range(max_retries): | |
try: | |
logger.info(f"์๋ต ์์ฑ ์๋ ({attempt+1}/{max_retries})") | |
# ์ง์ DeepSeek API ํธ์ถ | |
response = self.client.chat(messages) | |
if response["success"]: | |
logger.info(f"์๋ต ์์ฑ ์ฑ๊ณต (๊ธธ์ด: {len(response['response'])})") | |
return response["response"] | |
else: | |
logger.error(f"์๋ต ์์ฑ ์คํจ: {response['message']}") | |
if attempt < max_retries - 1: | |
logger.info(f"{retry_delay}์ด ํ ์ฌ์๋...") | |
time.sleep(retry_delay) | |
retry_delay *= 2 | |
else: | |
# ๋ชจ๋ ์๋ ์คํจ ์ ์คํ๋ผ์ธ ์๋ต | |
logger.warning("์ต๋ ์ฌ์๋ ํ์ ์ด๊ณผ, ์คํ๋ผ์ธ ์๋ต ์์ฑ") | |
return self._generate_simple_response(query, context) | |
except Exception as e: | |
logger.error(f"์๋ต ์์ฑ ์ค ์ค๋ฅ: {e}") | |
if attempt < max_retries - 1: | |
logger.info(f"{retry_delay}์ด ํ ์ฌ์๋...") | |
time.sleep(retry_delay) | |
retry_delay *= 2 | |
else: | |
# ๋ชจ๋ ์๋ ์คํจ ์ ์คํ๋ผ์ธ ์๋ต ์์ฑ | |
return self._generate_simple_response(query, context) | |
except Exception as e: | |
logger.error(f"RAG ์ฒด์ธ ์คํ ์ค ์ค๋ฅ: {e}") | |
return f"์ง๋ฌธ ์ฒ๋ฆฌ ์ค ์ค๋ฅ๊ฐ ๋ฐ์ํ์ต๋๋ค: {str(e)}" |