# RAG4_Voice_Fast / simple_rag_chain.py
# jeongsoo's picture
# Add greeting function to app.py
# 1f59ca4
"""
Simple RAG chain implementation (for debugging) - calls the DeepSeek API directly.
"""
import os
import logging
import time
from typing import Dict, Any, List
# ์ง์ ‘ DeepSeek ํด๋ผ์ด์–ธํŠธ ์‚ฌ์šฉ
from direct_deepseek import DirectDeepSeekClient
# Logging setup: module-level logger registered under the name "SimpleRAGChain".
logger = logging.getLogger("SimpleRAGChain")
class SimpleRAGChain:
    """Minimal retrieval-augmented generation chain that calls DeepSeek directly.

    The chain looks up documents in ``vector_store``, formats them into a
    text context, and sends the question plus that context to a
    ``DirectDeepSeekClient``. Every public failure mode is reported as a
    returned string rather than a raised exception.

    ``vector_store`` must expose ``similarity_search(query, k=...)`` returning
    objects that have ``page_content`` and (optionally) a ``metadata`` mapping.
    """

    # Number of documents requested from the vector store per query.
    TOP_K = 3
    # A context longer than this many characters is shortened.
    MAX_CONTEXT_CHARS = 6000
    # Characters kept from the head and from the tail of a shortened context.
    CONTEXT_EDGE_CHARS = 2500

    def __init__(self, vector_store, api_key=None, model="deepseek-chat", endpoint=None):
        """Store the vector store and try to build the DeepSeek client.

        Args:
            vector_store: Object providing ``similarity_search(query, k=...)``.
            api_key: DeepSeek API key; falls back to the ``DEEPSEEK_API_KEY``
                environment variable when falsy.
            model: Model name; the ``DEEPSEEK_MODEL`` environment variable is
                consulted only when a falsy value is passed explicitly.
            endpoint: Accepted for interface compatibility.
                NOTE(review): this value is not forwarded to the client --
                confirm whether DirectDeepSeekClient should receive it.

        On a missing key or a client construction failure ``self.client`` is
        set to ``None`` and ``run`` degrades to returning search results only.
        """
        logger.info("๊ฐ„๋‹จํ•œ RAG ์ฒด์ธ ์ดˆ๊ธฐํ™” ์ค‘...")
        self.vector_store = vector_store

        # Resolve the API key and model name (argument first, then environment).
        self.api_key = api_key or os.environ.get("DEEPSEEK_API_KEY", "")
        self.model = model or os.environ.get("DEEPSEEK_MODEL", "deepseek-chat")
        logger.info(f"API ํ‚ค ์„ค์ •๋จ: {bool(self.api_key)}")

        # Build the DeepSeek client; any failure leaves the chain client-less.
        self.client = None
        if self.api_key:
            try:
                self.client = DirectDeepSeekClient(
                    api_key=self.api_key,
                    model_name=self.model
                )
                logger.info(f"DeepSeek ํด๋ผ์ด์–ธํŠธ ์ดˆ๊ธฐํ™” ์„ฑ๊ณต: {self.model}")
            except Exception as e:
                # logger.exception keeps the traceback for debugging.
                logger.exception(f"DeepSeek ํด๋ผ์ด์–ธํŠธ ์ดˆ๊ธฐํ™” ์‹คํŒจ: {e}")
                self.client = None
        else:
            logger.warning("API ํ‚ค๊ฐ€ ์„ค์ •๋˜์ง€ ์•Š์•„ ํด๋ผ์ด์–ธํŠธ๋ฅผ ์ดˆ๊ธฐํ™”ํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")

        logger.info("๊ฐ„๋‹จํ•œ RAG ์ฒด์ธ ์ดˆ๊ธฐํ™” ์™„๋ฃŒ")

    def _retrieve(self, query: str) -> str:
        """Search the vector store and format the hits as one context string.

        Args:
            query: Text passed to ``vector_store.similarity_search``.

        Returns:
            The formatted context, a fixed "no documents" message when the
            search returns nothing, or a fixed error message when the search
            or formatting raises.
        """
        try:
            docs = self.vector_store.similarity_search(query, k=self.TOP_K)
            if not docs:
                return "๊ด€๋ จ ๋ฌธ์„œ๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค."

            # Build one labelled section per retrieved document.
            context_parts = []
            for i, doc in enumerate(docs, 1):
                # A document without metadata must not abort the whole search.
                metadata = getattr(doc, "metadata", None) or {}
                source = metadata.get("source", "์•Œ ์ˆ˜ ์—†๋Š” ์ถœ์ฒ˜")
                page = metadata.get("page", "")
                source_info = f"{source}"
                # Compare against None/"" instead of truthiness so that a
                # page number of 0 is still shown.
                if page is not None and page != "":
                    source_info += f" (ํŽ˜์ด์ง€: {page})"
                context_parts.append(f"[์ฐธ๊ณ ์ž๋ฃŒ {i}] - ์ถœ์ฒ˜: {source_info}\n{doc.page_content}\n")

            context = "\n".join(context_parts)

            # Length limit: keep the head and the tail, drop the middle.
            if len(context) > self.MAX_CONTEXT_CHARS:
                edge = self.CONTEXT_EDGE_CHARS
                context = context[:edge] + "\n...(์ค‘๋žต)...\n" + context[-edge:]
            return context
        except Exception as e:
            logger.exception(f"๊ฒ€์ƒ‰ ์ค‘ ์˜ค๋ฅ˜: {e}")
            return "๋ฌธ์„œ ๊ฒ€์ƒ‰ ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค."

    def _generate_prompt(self, query: str, context: str) -> List[Dict[str, str]]:
        """Build the chat message list (system + user) for the DeepSeek client.

        Args:
            query: The user's question.
            context: Reference text produced by ``_retrieve``.

        Returns:
            Two ``{"role", "content"}`` dicts: the system instruction followed
            by a user message containing the question and the context.
        """
        # System prompt
        system_prompt = """๋‹ค์Œ ์ •๋ณด๋ฅผ ๊ธฐ๋ฐ˜์œผ๋กœ ์งˆ๋ฌธ์— ์ •ํ™•ํ•˜๊ฒŒ ๋‹ต๋ณ€ํ•ด์ฃผ์„ธ์š”.
์ฐธ๊ณ  ์ •๋ณด์—์„œ ๋‹ต์„ ์ฐพ์„ ์ˆ˜ ์—†๋Š” ๊ฒฝ์šฐ "์ œ๊ณต๋œ ๋ฌธ์„œ์—์„œ ํ•ด๋‹น ์ •๋ณด๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค."๋ผ๊ณ  ๋‹ต๋ณ€ํ•˜์„ธ์š”.
์ •๋ณด ์ถœ์ฒ˜๋ฅผ ํฌํ•จํ•ด์„œ ๋Œ€๋‹ตํ•˜์„ธ์š”."""
        # User prompt
        user_prompt = f"""์งˆ๋ฌธ: {query}
์ฐธ๊ณ  ์ •๋ณด:
{context}"""
        # Message format expected by the chat call
        return [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]

    def run(self, query: str) -> str:
        """Answer ``query`` using retrieved context and the DeepSeek client.

        Args:
            query: The user's question.

        Returns:
            The model's answer on success. When no client is configured, or
            the client reports failure, a message that includes the retrieved
            context. On any unexpected exception, a short error string.
        """
        try:
            logger.info(f"SimpleRAGChain ์‹คํ–‰: {query[:50]}...")

            # Document retrieval
            context = self._retrieve(query)

            # Without a client only the search results can be returned.
            if self.client is None:
                logger.warning("DeepSeek ํด๋ผ์ด์–ธํŠธ๊ฐ€ ์ดˆ๊ธฐํ™”๋˜์ง€ ์•Š์Œ. ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ๋งŒ ๋ฐ˜ํ™˜.")
                return f"API ์—ฐ๊ฒฐ์ด ์„ค์ •๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค. ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ:\n\n{context}"

            # Prompt construction
            messages = self._generate_prompt(query, context)

            # API call (timed for the log)
            start_time = time.time()
            response = self.client.chat(messages)
            logger.info(f"API ์‘๋‹ต ์‹œ๊ฐ„: {time.time() - start_time:.2f}์ดˆ")

            # Use .get so a reply lacking "success"/"message" is treated as a
            # failure that still hands the search results back to the caller.
            if response.get("success"):
                logger.info("์‘๋‹ต ์ƒ์„ฑ ์„ฑ๊ณต")
                return response["response"]

            message = response.get("message", "")
            logger.error(f"์‘๋‹ต ์ƒ์„ฑ ์‹คํŒจ: {message}")
            return f"์‘๋‹ต ์ƒ์„ฑ ์‹คํŒจ: {message}\n\n๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ:\n{context}"
        except Exception as e:
            # Top-level boundary: log with traceback, report as text.
            logger.exception(f"์‹คํ–‰ ์ค‘ ์˜ค๋ฅ˜: {e}")
            return f"์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}"