"""
간단한 RAG 체인 구현 (디버깅용) - 직접 DeepSeek API 호출 방식
(Simple RAG chain implementation for debugging; calls the DeepSeek API directly.)
"""
import logging
import os
import time
from typing import Any, Dict, List, Optional

# Direct DeepSeek API client (직접 DeepSeek 클라이언트 사용).
from direct_deepseek import DirectDeepSeekClient
# Module-level logger used throughout SimpleRAGChain. The fixed name lets an
# application configure it via logging.getLogger("SimpleRAGChain").
logger: logging.Logger = logging.getLogger("SimpleRAGChain")
class SimpleRAGChain:
    """Minimal retrieval-augmented generation (RAG) chain, intended for debugging.

    ``run`` retrieves the most similar documents from ``vector_store``, packs
    them into one context string, and asks a ``DirectDeepSeekClient`` to answer
    the query from that context. If no client could be created, ``run``
    returns the retrieved context instead of a generated answer.

    NOTE(review): the runtime strings in this class (log messages, prompts,
    fallback replies) appear to be mis-encoded Korean (UTF-8 decoded through a
    Thai code page). They are reproduced byte-for-byte here. Restore them from
    the original source file if it is available.
    """

    # Number of documents requested from the vector store per query.
    top_k: int = 3
    # Contexts longer than this many characters are trimmed...
    max_context_chars: int = 6000
    # ...down to this many characters kept from the start and from the end.
    context_edge_chars: int = 2500

    def __init__(self, vector_store, api_key: Optional[str] = None,
                 model: Optional[str] = "deepseek-chat", endpoint=None):
        """Configure the chain and try to build the DeepSeek client.

        Args:
            vector_store: Object exposing ``similarity_search(query, k=...)``.
                It must return documents with a ``metadata`` dict and a
                ``page_content`` attribute.
            api_key: DeepSeek API key. Falls back to the ``DEEPSEEK_API_KEY``
                environment variable. Without a key, no client is created.
            model: Model name passed to the client. ``DEEPSEEK_MODEL`` is only
                consulted when this is falsy (e.g. ``None``), so with the
                default value the environment variable is ignored.
            endpoint: Accepted for interface compatibility but currently
                unused. It is not forwarded to the client.

        Client construction errors do not raise. They are logged with their
        traceback, and ``self.client`` is left as ``None``.
        """
        logger.info("๊ฐ๋จํ RAG ์ฒด์ธ ์ด๊ธฐํ ์ค...")
        self.vector_store = vector_store

        self.api_key = api_key or os.environ.get("DEEPSEEK_API_KEY", "")
        self.model = model or os.environ.get("DEEPSEEK_MODEL", "deepseek-chat")
        # Log only whether a key is present, never the key itself.
        logger.info(f"API ํค ์ค์ ๋จ: {bool(self.api_key)}")

        self.client = None
        if self.api_key:
            try:
                self.client = DirectDeepSeekClient(
                    api_key=self.api_key,
                    model_name=self.model,
                )
                logger.info(f"DeepSeek ํด๋ผ์ด์ธํธ ์ด๊ธฐํ ์ฑ๊ณต: {self.model}")
            except Exception as e:
                # Best-effort: keep the chain usable in retrieval-only mode.
                logger.exception(f"DeepSeek ํด๋ผ์ด์ธํธ ์ด๊ธฐํ ์คํจ: {e}")
                self.client = None
        else:
            logger.warning("API ํค๊ฐ ์ค์ ๋์ง ์์ ํด๋ผ์ด์ธํธ๋ฅผ ์ด๊ธฐํํ ์ ์์ต๋๋ค.")
        logger.info("๊ฐ๋จํ RAG ์ฒด์ธ ์ด๊ธฐํ ์๋ฃ")

    def _retrieve(self, query: str) -> str:
        """Search the vector store and format the hits as one context string.

        Each hit is rendered as a numbered block. The block is headed by the
        hit's ``source`` metadata, plus its ``page`` metadata when present.
        A context longer than ``max_context_chars`` keeps only
        ``context_edge_chars`` characters from each end, joined by an
        elision marker.

        Returns a fixed placeholder string when nothing is found or when the
        search raises. Errors are logged, not propagated.
        """
        try:
            docs = self.vector_store.similarity_search(query, k=self.top_k)
            if not docs:
                return "๊ด๋ จ ๋ฌธ์๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค."

            context_parts = []
            for i, doc in enumerate(docs, 1):
                source = doc.metadata.get("source", "์ ์ ์๋ ์ถ์ฒ")
                page = doc.metadata.get("page", "")
                source_info = f"{source}"
                # Compare explicitly instead of relying on truthiness: page 0
                # is falsy but may be a valid (0-based) page number.
                if page is not None and page != "":
                    source_info += f" (ํ์ด์ง: {page})"
                context_parts.append(f"[์ฐธ๊ณ ์๋ฃ {i}] - ์ถ์ฒ: {source_info}\n{doc.page_content}\n")
            context = "\n".join(context_parts)

            # Keep the head and tail of an over-long context to bound prompt size.
            if len(context) > self.max_context_chars:
                edge = self.context_edge_chars
                context = context[:edge] + "\n...(์ค๋ต)...\n" + context[-edge:]
            return context
        except Exception as e:
            logger.exception(f"๊ฒ์ ์ค ์ค๋ฅ: {e}")
            return "๋ฌธ์ ๊ฒ์ ์ค ์ค๋ฅ๊ฐ ๋ฐ์ํ์ต๋๋ค."

    def _generate_prompt(self, query: str, context: str) -> List[Dict[str, str]]:
        """Build the chat messages sent to the DeepSeek API.

        Returns a two-element list. The first is a fixed system instruction.
        The second is a user message containing ``query`` followed by
        ``context``.
        """
        system_prompt = """๋ค์ ์ ๋ณด๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ์ง๋ฌธ์ ์ ํํ๊ฒ ๋ต๋ณํด์ฃผ์ธ์.
์ฐธ๊ณ ์ ๋ณด์์ ๋ต์ ์ฐพ์ ์ ์๋ ๊ฒฝ์ฐ "์ ๊ณต๋ ๋ฌธ์์์ ํด๋น ์ ๋ณด๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค."๋ผ๊ณ ๋ต๋ณํ์ธ์.
์ ๋ณด ์ถ์ฒ๋ฅผ ํฌํจํด์ ๋๋ตํ์ธ์."""

        user_prompt = f"""์ง๋ฌธ: {query}
์ฐธ๊ณ ์ ๋ณด:
{context}"""

        return [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]

    def run(self, query: str) -> str:
        """Answer ``query`` using retrieved context and the DeepSeek client.

        Returns one of the following:
        - The model's reply on success.
        - Without a client, a notice followed by the raw retrieval context.
        - On an unsuccessful API response, the response's ``message``
          followed by the context.

        Exceptions are logged with their traceback and turned into an error
        string instead of being propagated.
        """
        try:
            logger.info(f"SimpleRAGChain ์คํ: {query[:50]}...")
            context = self._retrieve(query)

            if self.client is None:
                logger.warning("DeepSeek ํด๋ผ์ด์ธํธ๊ฐ ์ด๊ธฐํ๋์ง ์์. ๊ฒ์ ๊ฒฐ๊ณผ๋ง ๋ฐํ.")
                return f"API ์ฐ๊ฒฐ์ด ์ค์ ๋์ง ์์์ต๋๋ค. ๊ฒ์ ๊ฒฐ๊ณผ:\n\n{context}"

            messages = self._generate_prompt(query, context)

            # perf_counter is monotonic, so the measured latency is unaffected
            # by wall-clock adjustments (unlike time.time()).
            start_time = time.perf_counter()
            response = self.client.chat(messages)
            logger.info(f"API ์๋ต ์๊ฐ: {time.perf_counter() - start_time:.2f}์ด")

            # NOTE(review): assumes chat() returns a dict with "success" plus
            # "response" (on success) or "message" (on failure). Confirm this
            # against DirectDeepSeekClient.
            if response["success"]:
                logger.info("์๋ต ์์ฑ ์ฑ๊ณต")
                return response["response"]
            logger.error(f"์๋ต ์์ฑ ์คํจ: {response['message']}")
            return f"์๋ต ์์ฑ ์คํจ: {response['message']}\n\n๊ฒ์ ๊ฒฐ๊ณผ:\n{context}"
        except Exception as e:
            logger.exception(f"์คํ ์ค ์ค๋ฅ: {e}")
            return f"์ค๋ฅ ๋ฐ์: {str(e)}"