Spaces:
Runtime error
Runtime error
# from retriever.vectordb import search_documents | |
from retriever.vectordb_rerank import search_documents | |
from retriever.vectordb_rerank_law import search_documents as search_law | |
from retriever.vectordb_rerank_exam import search_documents as search_exam | |
from generator.prompt_builder import build_prompt | |
from generator.prompt_builder_all import build_prompt as build_prompt_all | |
from generator.llm_inference import generate_answer as generate_answer | |
from generator.llm_inference_all import generate_answer as generate_answer_all | |
# 2. Cache management: module-level memo of pipeline answers; rag_pipeline
# reads from and writes to this dict.
search_cache = dict()
def rag_pipeline(query: str, top_k: int = 5) -> str:
    """Answer ``query`` with retrieval-augmented generation.

    Steps:
      1. Retrieve documents related to the user's question from the exam
         index (``search_exam``) and the law index (``search_law``).
      2. Assemble a prompt from the question and the retrieved documents.
      3. Generate an answer from that prompt.

    Answers are memoized in the module-level ``search_cache`` dict, keyed by
    ``(query, top_k)``. Including ``top_k`` matters: the same question
    retrieved with a different ``top_k`` produces a different prompt, so it
    must not be served a cached answer built from another document count.

    Args:
        query: The user's question.
        top_k: Number of documents to retrieve from each index.

    Returns:
        Whatever ``generate_answer`` returns for the assembled prompt
        (annotated as ``str``).
    """
    # Cache lookup; the key includes top_k (see docstring).
    cache_key = (query, top_k)
    if cache_key in search_cache:
        print(f"⚡ 캐시 사용: '{query}'")
        return search_cache[cache_key]

    # 1. Retrieval from both indexes.
    context_exam_docs = search_exam(query, top_k=top_k)
    print("context_exam_docs: ", context_exam_docs)
    print("==============================================\n\n")
    context_law_docs = search_law(query, top_k=top_k)
    print("context_law_docs: ", context_law_docs)
    print("==============================================\n\n")

    # 2. Prompt assembly: law docs are passed before exam docs.
    prompt = build_prompt_all(query, context_law_docs, context_exam_docs)
    print("prompt: ", prompt)
    print("==============================================\n\n")

    # 3. Model inference.
    # NOTE(review): generate_answer_all is imported next to build_prompt_all
    # but unused here; confirm generate_answer is the intended generator for
    # prompts built by build_prompt_all.
    output = generate_answer(prompt)

    # NOTE(review): the cache is unbounded and grows with every distinct
    # (query, top_k); consider a size limit for long-running deployments.
    search_cache[cache_key] = output
    return output
# Example query, run only when this file is executed directly.
if __name__ == "__main__":
    # "Measures required when a broker (중개업자) relocates their office."
    query = "중개업자가 사무소를 옮겼을 때 필요한 조치"
    top_k = 5
    result = rag_pipeline(query, top_k)
    print(result)