File size: 2,140 Bytes
06696b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# from retriever.vectordb import search_documents
from retriever.vectordb_rerank import search_documents
from retriever.vectordb_rerank_law import search_documents as search_law
from retriever.vectordb_rerank_exam import search_documents as search_exam

from generator.prompt_builder import build_prompt
from generator.prompt_builder_all import build_prompt as build_prompt_all

from generator.llm_inference import generate_answer as generate_answer
from generator.llm_inference_all import generate_answer as generate_answer_all

# 2. ์บ์‹œ ๊ด€๋ฆฌ
search_cache = {}

def rag_pipeline(query: str, top_k: int = 5) -> str:
    """

    1. ์‚ฌ์šฉ์ž ์งˆ๋ฌธ์œผ๋กœ ๊ด€๋ จ ๋ฌธ์„œ๋ฅผ ๊ฒ€์ƒ‰

    2. ๊ฒ€์ƒ‰๋œ ๋ฌธ์„œ์™€ ํ•จ๊ป˜ ํ”„๋กฌํ”„ํŠธ ๊ตฌ์„ฑ

    3. ํ”„๋กฌํ”„ํŠธ๋กœ๋ถ€ํ„ฐ ๋‹ต๋ณ€ ์ƒ์„ฑ

    """
    # ์บ์‹œ ํ™•์ธ
    if query in search_cache:
        print(f"โšก ์บ์‹œ ์‚ฌ์šฉ: '{query}'")
        return search_cache[query]
    
    # 1. ๊ฒ€์ƒ‰
    # context_docs = search_documents(query, top_k=top_k)
    # print("context_docs: ", context_docs)
    # print("==============================================\n\n")

    context_exam_docs = search_exam(query, top_k=top_k)
    print("context_exam_docs: ", context_exam_docs)
    print("==============================================\n\n")

    constext_law_docs = search_law(query, top_k=top_k)
    print("context_law_docs: ", constext_law_docs)
    print("==============================================\n\n")

    # 2. ํ”„๋กฌํ”„ํŠธ ์กฐ๋ฆฝ
    prompt = build_prompt_all(query, constext_law_docs, context_exam_docs)
    print("prompt: ", prompt)    
    print("==============================================\n\n")

    # 3. ๋ชจ๋ธ ์ถ”๋ก 
    output = generate_answer(prompt)
    # return output

    if isinstance(context_exam_docs, list):
        context_exam_docs = "\n\n".join(context_exam_docs)

    search_cache[query] = output
    return output

# ์˜ˆ์‹œ ์ฟผ๋ฆฌ
if __name__ == "__main__":
    query = "์ค‘๊ฐœ์—…์ž๊ฐ€ ์‚ฌ๋ฌด์†Œ๋ฅผ ์˜ฎ๊ฒผ์„ ๋•Œ ํ•„์š”ํ•œ ์กฐ์น˜"
    top_k = 5
    result = rag_pipeline(query, top_k)
    print(result)