File size: 2,088 Bytes
4a98f26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""
๊ฐ„๋‹จํ•œ RAG ์ฒด์ธ ๊ตฌํ˜„ (๋””๋ฒ„๊น…์šฉ)
"""
import os
from langchain_openai import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough


class SimpleRAGChain:
    def __init__(self, vector_store):
        """๊ฐ„๋‹จํ•œ RAG ์ฒด์ธ ์ดˆ๊ธฐํ™”"""
        print("๊ฐ„๋‹จํ•œ RAG ์ฒด์ธ ์ดˆ๊ธฐํ™” ์ค‘...")
        self.vector_store = vector_store

        # OpenAI API ํ‚ค ํ™•์ธ
        openai_api_key = os.environ.get("OPENAI_API_KEY", "")
        print(f"API ํ‚ค ์„ค์ •๋จ: {bool(openai_api_key)}")

        # OpenAI ๋ชจ๋ธ ์ดˆ๊ธฐํ™”
        self.llm = ChatOpenAI(
            model_name="gpt-3.5-turbo",
            temperature=0.2,
            api_key=openai_api_key,
        )

        # ํ”„๋กฌํ”„ํŠธ ํ…œํ”Œ๋ฆฟ
        template = """
        ๋‹ค์Œ ์ •๋ณด๋ฅผ ๊ธฐ๋ฐ˜์œผ๋กœ ์งˆ๋ฌธ์— ์ •ํ™•ํ•˜๊ฒŒ ๋‹ต๋ณ€ํ•ด์ฃผ์„ธ์š”.

        ์งˆ๋ฌธ: {question}

        ์ฐธ๊ณ  ์ •๋ณด:
        {context}

        ์ฐธ๊ณ  ์ •๋ณด์— ๋‹ต์ด ์—†๋Š” ๊ฒฝ์šฐ "์ œ๊ณต๋œ ๋ฌธ์„œ์—์„œ ํ•ด๋‹น ์ •๋ณด๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค."๋ผ๊ณ  ๋‹ต๋ณ€ํ•˜์„ธ์š”.
        """

        self.prompt = PromptTemplate.from_template(template)

        # ์ฒด์ธ ๊ตฌ์„ฑ
        self.chain = (
                {"context": self._retrieve, "question": RunnablePassthrough()}
                | self.prompt
                | self.llm
                | StrOutputParser()
        )
        print("๊ฐ„๋‹จํ•œ RAG ์ฒด์ธ ์ดˆ๊ธฐํ™” ์™„๋ฃŒ")

    def _retrieve(self, query):
        """๋ฌธ์„œ ๊ฒ€์ƒ‰"""
        try:
            docs = self.vector_store.similarity_search(query, k=3)
            return "\n\n".join(doc.page_content for doc in docs)
        except Exception as e:
            print(f"๊ฒ€์ƒ‰ ์ค‘ ์˜ค๋ฅ˜: {e}")
            return "๋ฌธ์„œ ๊ฒ€์ƒ‰ ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค."

    def run(self, query):
        """์ฟผ๋ฆฌ ์ฒ˜๋ฆฌ"""
        try:
            return self.chain.invoke(query)
        except Exception as e:
            print(f"์‹คํ–‰ ์ค‘ ์˜ค๋ฅ˜: {e}")
            return f"์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}"