RAG_voice / simple_rag_chain.py
jeongsoo's picture
Add application file
4a98f26
"""
๊ฐ„๋‹จํ•œ RAG ์ฒด์ธ ๊ตฌํ˜„ (๋””๋ฒ„๊น…์šฉ)
"""
import os
from langchain_openai import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
class SimpleRAGChain:
def __init__(self, vector_store):
"""๊ฐ„๋‹จํ•œ RAG ์ฒด์ธ ์ดˆ๊ธฐํ™”"""
print("๊ฐ„๋‹จํ•œ RAG ์ฒด์ธ ์ดˆ๊ธฐํ™” ์ค‘...")
self.vector_store = vector_store
# OpenAI API ํ‚ค ํ™•์ธ
openai_api_key = os.environ.get("OPENAI_API_KEY", "")
print(f"API ํ‚ค ์„ค์ •๋จ: {bool(openai_api_key)}")
# OpenAI ๋ชจ๋ธ ์ดˆ๊ธฐํ™”
self.llm = ChatOpenAI(
model_name="gpt-3.5-turbo",
temperature=0.2,
api_key=openai_api_key,
)
# ํ”„๋กฌํ”„ํŠธ ํ…œํ”Œ๋ฆฟ
template = """
๋‹ค์Œ ์ •๋ณด๋ฅผ ๊ธฐ๋ฐ˜์œผ๋กœ ์งˆ๋ฌธ์— ์ •ํ™•ํ•˜๊ฒŒ ๋‹ต๋ณ€ํ•ด์ฃผ์„ธ์š”.
์งˆ๋ฌธ: {question}
์ฐธ๊ณ  ์ •๋ณด:
{context}
์ฐธ๊ณ  ์ •๋ณด์— ๋‹ต์ด ์—†๋Š” ๊ฒฝ์šฐ "์ œ๊ณต๋œ ๋ฌธ์„œ์—์„œ ํ•ด๋‹น ์ •๋ณด๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค."๋ผ๊ณ  ๋‹ต๋ณ€ํ•˜์„ธ์š”.
"""
self.prompt = PromptTemplate.from_template(template)
# ์ฒด์ธ ๊ตฌ์„ฑ
self.chain = (
{"context": self._retrieve, "question": RunnablePassthrough()}
| self.prompt
| self.llm
| StrOutputParser()
)
print("๊ฐ„๋‹จํ•œ RAG ์ฒด์ธ ์ดˆ๊ธฐํ™” ์™„๋ฃŒ")
def _retrieve(self, query):
"""๋ฌธ์„œ ๊ฒ€์ƒ‰"""
try:
docs = self.vector_store.similarity_search(query, k=3)
return "\n\n".join(doc.page_content for doc in docs)
except Exception as e:
print(f"๊ฒ€์ƒ‰ ์ค‘ ์˜ค๋ฅ˜: {e}")
return "๋ฌธ์„œ ๊ฒ€์ƒ‰ ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค."
def run(self, query):
"""์ฟผ๋ฆฌ ์ฒ˜๋ฆฌ"""
try:
return self.chain.invoke(query)
except Exception as e:
print(f"์‹คํ–‰ ์ค‘ ์˜ค๋ฅ˜: {e}")
return f"์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}"