# docker-api/services/rag_pipeline.py
# from retriever.vectordb import search_documents
from retriever.vectordb_rerank import search_documents
from retriever.vectordb_rerank_law import search_documents as search_law
from retriever.vectordb_rerank_exam import search_documents as search_exam
from generator.prompt_builder import build_prompt
from generator.prompt_builder_all import build_prompt as build_prompt_all
from generator.llm_inference import generate_answer
from generator.llm_inference_all import generate_answer as generate_answer_all
# In-memory cache of previously answered queries
search_cache = {}
def rag_pipeline(query: str, top_k: int = 5) -> str:
    """
    1. Retrieve documents relevant to the user's question
    2. Build a prompt from the retrieved documents
    3. Generate an answer from the prompt
    """
    # Return the cached answer if this query has been seen before
    if query in search_cache:
        print(f"โšก Cache hit: '{query}'")
        return search_cache[query]
    # 1. Retrieval
    # context_docs = search_documents(query, top_k=top_k)
    # print("context_docs: ", context_docs)
    # print("==============================================\n\n")
    context_exam_docs = search_exam(query, top_k=top_k)
    print("context_exam_docs: ", context_exam_docs)
    print("==============================================\n\n")
    context_law_docs = search_law(query, top_k=top_k)
    print("context_law_docs: ", context_law_docs)
    print("==============================================\n\n")
    # 2. Prompt assembly
    prompt = build_prompt_all(query, context_law_docs, context_exam_docs)
    print("prompt: ", prompt)
    print("==============================================\n\n")

    # 3. Model inference
    output = generate_answer(prompt)
    # Cache the answer for future identical queries
    search_cache[query] = output
    return output
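

# --- Optional sketch: bounding the cache -------------------------------------
# search_cache above grows without limit, which can matter for a long-running
# service. A minimal LRU-bounded alternative (illustrative only: this helper is
# not wired into rag_pipeline, and the 128-entry limit is an assumption):
from collections import OrderedDict

_MAX_CACHE_ENTRIES = 128
_lru_cache: OrderedDict = OrderedDict()

def cache_answer(query: str, answer: str) -> None:
    """Store an answer, evicting the least-recently-used entry when full."""
    _lru_cache[query] = answer
    _lru_cache.move_to_end(query)  # mark as most recently used
    if len(_lru_cache) > _MAX_CACHE_ENTRIES:
        _lru_cache.popitem(last=False)  # drop the oldest entry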
# Example query
if __name__ == "__main__":
    query = "์ค‘๊ฐœ์—…์ž๊ฐ€ ์‚ฌ๋ฌด์†Œ๋ฅผ ์˜ฎ๊ฒผ์„ ๋•Œ ํ•„์š”ํ•œ ์กฐ์น˜"  # "Steps required when a broker relocates their office"
    top_k = 5
    result = rag_pipeline(query, top_k)
    print(result)
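

# --- Usage sketch: serving rag_pipeline over HTTP ----------------------------
# Given the docker-api/services layout, this module presumably backs a web API.
# A minimal FastAPI wrapper is sketched below, commented out so it adds no hard
# dependency; the endpoint path and request schema are assumptions, not part of
# this repo:
#
# from fastapi import FastAPI
# from pydantic import BaseModel
#
# app = FastAPI()
#
# class QueryRequest(BaseModel):
#     query: str
#     top_k: int = 5
#
# @app.post("/rag")
# def rag_endpoint(req: QueryRequest) -> dict:
#     return {"answer": rag_pipeline(req.query, req.top_k)}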