File size: 2,153 Bytes
06696b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# rag_pipeline_graph/graph.py
from langgraph.graph import StateGraph, START, END
from nodes.check_cache import check_cache
from nodes.search_docs import search_documents
from nodes.rerank import rerank_documents
from nodes.build_prompt import build_prompt
from nodes.call_llm import call_llm
from nodes.cache_save import save_response
from typing import Optional, List, Dict, Any
from pydantic import BaseModel
from type.state_schema import RAGState

# NOTE: RAGState moved to type/state_schema.py. Expected fields, for reference:
#     query: str
#     top_k: int = 5
#     cached_response: Optional[str] = None
#     retrieved_docs: List[Dict[str, Any]] = []
#     reranked_docs: List[Dict[str, Any]] = []
#     prompt: str = ""
#     final_response: str = ""

def get_rag_pipeline_graph():
    """Build and compile the RAG pipeline graph.

    Flow:
        START -> CheckCache
        CheckCache -> END            (cache hit: cached_response is set)
        CheckCache -> SearchDocs -> Rerank -> BuildPrompt
                   -> CallLLM -> SaveResponse -> END

    Returns:
        A compiled LangGraph runnable operating on ``RAGState``.
    """
    workflow = StateGraph(RAGState)

    # Register nodes
    workflow.add_node("CheckCache", check_cache)
    workflow.add_node("SearchDocs", search_documents)
    workflow.add_node("Rerank", rerank_documents)
    workflow.add_node("BuildPrompt", build_prompt)
    workflow.add_node("CallLLM", call_llm)
    workflow.add_node("SaveResponse", save_response)

    # Wire the flow. Use the imported START constant (previously unused)
    # instead of set_entry_point — the two are equivalent in LangGraph.
    workflow.add_edge(START, "CheckCache")

    # Short-circuit straight to END on a cache hit; otherwise run retrieval.
    # NOTE(review): assumes the state passed to the condition exposes
    # attribute access (pydantic RAGState) — confirm against state_schema.
    workflow.add_conditional_edges(
        "CheckCache",
        lambda state: "Return" if state.cached_response else "SearchDocs",
        {"Return": END, "SearchDocs": "SearchDocs"},
    )

    workflow.add_edge("SearchDocs", "Rerank")
    workflow.add_edge("Rerank", "BuildPrompt")
    workflow.add_edge("BuildPrompt", "CallLLM")
    workflow.add_edge("CallLLM", "SaveResponse")
    workflow.add_edge("SaveResponse", END)

    return workflow.compile()

if __name__ == "__main__":
    graph = get_rag_pipeline_graph()

    # Example query: explanation of brokerage-office establishment
    # registration under the Licensed Real Estate Agents Act (Korean).
    input_data = RAGState(
        query="κ³΅μΈμ€‘κ°œμ‚¬λ²•λ Ήμƒ μ€‘κ°œμ‚¬λ¬΄μ†Œμ˜ κ°œμ„€λ“±λ‘μ— κ΄€ν•œ μ„€λͺ…",
        top_k=5
    )

    # graph.invoke returns a plain dict, so use .get with a fallback
    # in case the cache-hit path ended the run before CallLLM set
    # final_response. ("[응닡 μ—†μŒ]" means "no response".)
    final_state = graph.invoke(input_data)
    print("\n🧠 μ΅œμ’… 응닡:", final_state.get("final_response", "[응닡 μ—†μŒ]"))