File size: 4,949 Bytes
a76f77b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
"""
์ฑ์ ๋ด์ฅ๋ ๊ฐ๋จํ RAG ์ฒด์ธ ๊ตฌํ
"""
from typing import List, Dict, Any, Optional
import os
from config import OPENAI_API_KEY, LLM_MODEL, USE_OPENAI, TOP_K_RETRIEVAL
# ์์ ํ ์ํฌํธ
try:
from langchain_openai import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
LANGCHAIN_IMPORTS_AVAILABLE = True
except ImportError:
print("[APP_RAG] langchain ๊ด๋ จ ํจํค์ง๋ฅผ ๋ก๋ํ ์ ์์ต๋๋ค.")
LANGCHAIN_IMPORTS_AVAILABLE = False
class SimpleRAGChain:
    """
    Simple RAG chain embedded in the app.

    Retrieves the top-k most similar documents from a vector store and
    answers a question with an LLM that is instructed to stay within the
    retrieved context.
    """

    def __init__(self, vector_store):
        """Initialize the RAG chain.

        Args:
            vector_store: object exposing ``similarity_search(query, k)``;
                returned documents must carry ``page_content`` and a
                ``metadata`` dict (``source`` / ``page`` keys are used).

        Raises:
            ImportError: langchain packages are not installed.
            ValueError: OpenAI is selected but no API key is configured.
        """
        print("[APP_RAG] ๊ฐ„๋‹จํ•œ RAG ์ฒด์ธ ์ดˆ๊ธฐํ™” ์ค‘...")
        self.vector_store = vector_store

        if not LANGCHAIN_IMPORTS_AVAILABLE:
            print("[APP_RAG] langchain ํŒจํ‚ค์ง€๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์–ด RAG ์ฒด์ธ์„ ์ดˆ๊ธฐํ™”ํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")
            raise ImportError("RAG ์ฒด์ธ ์ดˆ๊ธฐํ™”์— ํ•„์š”ํ•œ ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ๊ฐ€ ์„ค์น˜๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค.")

        # Fail fast when OpenAI is requested without credentials.
        if not OPENAI_API_KEY and USE_OPENAI:
            print("[APP_RAG] ๊ฒฝ๊ณ : OpenAI API ํ‚ค๊ฐ€ ์„ค์ •๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค.")
            raise ValueError("OpenAI API ํ‚ค๊ฐ€ ์„ค์ •๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค.")

        try:
            # LLM selection is isolated in a helper; any failure is still
            # caught here so it is logged before being re-raised.
            self.llm = self._create_llm()

            # Prompt template: answer strictly from the retrieved context,
            # cite sources, and refuse when the answer is not present.
            template = """
๋‹ค์Œ ์ •๋ณด๋ฅผ ๊ธฐ๋ฐ˜์œผ๋กœ ์งˆ๋ฌธ์— ์ •ํ™•ํ•˜๊ฒŒ ๋‹ต๋ณ€ํ•ด์ฃผ์„ธ์š”.
์งˆ๋ฌธ: {question}
์ฐธ๊ณ  ์ •๋ณด:
{context}
์ฐธ๊ณ  ์ •๋ณด์— ๋‹ต์ด ์—†๋Š” ๊ฒฝ์šฐ "์ œ๊ณต๋œ ๋ฌธ์„œ์—์„œ ํ•ด๋‹น ์ •๋ณด๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค."๋ผ๊ณ  ๋‹ต๋ณ€ํ•˜์„ธ์š”.
๋‹ต๋ณ€์€ ์ •ํ™•ํ•˜๊ณ  ๊ฐ„๊ฒฐํ•˜๊ฒŒ ์ œ๊ณตํ•˜๋˜, ์ฐธ๊ณ  ์ •๋ณด์—์„œ ๊ทผ๊ฑฐ๋ฅผ ์ฐพ์•„ ์„ค๋ช…ํ•ด์ฃผ์„ธ์š”.
์ฐธ๊ณ  ์ •๋ณด์˜ ์ถœ์ฒ˜๋„ ํ•จ๊ป˜ ์•Œ๋ ค์ฃผ์„ธ์š”.
"""
            self.prompt = PromptTemplate.from_template(template)

            # LCEL pipeline: retrieve -> fill prompt -> LLM -> plain string.
            self.chain = (
                {"context": self._retrieve, "question": RunnablePassthrough()}
                | self.prompt
                | self.llm
                | StrOutputParser()
            )
            print("[APP_RAG] RAG ์ฒด์ธ ์ดˆ๊ธฐํ™” ์™„๋ฃŒ")
        except Exception as e:
            print(f"[APP_RAG] RAG ์ฒด์ธ ์ดˆ๊ธฐํ™” ์‹คํŒจ: {e}")
            import traceback
            traceback.print_exc()
            raise

    def _create_llm(self):
        """Build the chat model: OpenAI when ``USE_OPENAI`` is set,
        otherwise Ollama, falling back to a fixed OpenAI model when the
        Ollama integration is not importable."""
        if USE_OPENAI:
            llm = ChatOpenAI(
                model_name=LLM_MODEL,
                temperature=0.2,
                api_key=OPENAI_API_KEY,
            )
            print(f"[APP_RAG] OpenAI ๋ชจ๋ธ ์ดˆ๊ธฐํ™”: {LLM_MODEL}")
            return llm
        try:
            # Optional dependency: only needed on the Ollama path.
            from langchain_community.chat_models import ChatOllama
            from config import OLLAMA_HOST
        except ImportError:
            # Ollama stack unavailable -> hard-coded OpenAI fallback model
            # (kept as-is for backward compatibility).
            llm = ChatOpenAI(
                model_name="gpt-3.5-turbo",
                temperature=0.2,
                api_key=OPENAI_API_KEY,
            )
            print("[APP_RAG] Ollama๋ฅผ ์‚ฌ์šฉํ•  ์ˆ˜ ์—†์–ด OpenAI๋กœ ๋Œ€์ฒดํ•ฉ๋‹ˆ๋‹ค.")
            return llm
        llm = ChatOllama(
            model=LLM_MODEL,
            temperature=0.2,
            base_url=OLLAMA_HOST,
        )
        print(f"[APP_RAG] Ollama ๋ชจ๋ธ ์ดˆ๊ธฐํ™”: {LLM_MODEL}")
        return llm

    def _retrieve(self, query):
        """Search the vector store and format hits into a context string.

        Returns a numbered list of excerpts, each tagged with its source
        (and page when present). On failure, returns an error message
        instead of raising so the chain can still produce a response.
        """
        try:
            docs = self.vector_store.similarity_search(query, k=TOP_K_RETRIEVAL)
            # Build the context block from the retrieved documents.
            context_parts = []
            for i, doc in enumerate(docs, 1):
                source = doc.metadata.get("source", "์•Œ ์ˆ˜ ์—†๋Š” ์ถœ์ฒ˜")
                page = doc.metadata.get("page", "")
                source_info = f"{source}"
                # Compare against the "" sentinel instead of truthiness:
                # page 0 (0-indexed PDF loaders) must not be dropped.
                if page is not None and page != "":
                    source_info += f" (ํŽ˜์ด์ง€: {page})"
                context_parts.append(f"[์ฐธ๊ณ ์ž๋ฃŒ {i}] - ์ถœ์ฒ˜: {source_info}\n{doc.page_content}\n")
            return "\n".join(context_parts)
        except Exception as e:
            print(f"[APP_RAG] ๊ฒ€์ƒ‰ ์ค‘ ์˜ค๋ฅ˜: {e}")
            import traceback
            traceback.print_exc()
            return "๋ฌธ์„œ ๊ฒ€์ƒ‰ ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค."

    def run(self, query):
        """Run the full chain on *query* and return the answer string.

        Errors are logged and reported as a formatted string rather than
        raised, matching the original best-effort contract.
        """
        try:
            return self.chain.invoke(query)
        except Exception as e:
            print(f"[APP_RAG] ์‹คํ–‰ ์ค‘ ์˜ค๋ฅ˜: {e}")
            import traceback
            traceback.print_exc()
            return f"์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}"