Spaces:
Runtime error
Runtime error
File size: 2,140 Bytes
06696b5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
# from retriever.vectordb import search_documents
from retriever.vectordb_rerank import search_documents
from retriever.vectordb_rerank_law import search_documents as search_law
from retriever.vectordb_rerank_exam import search_documents as search_exam
from generator.prompt_builder import build_prompt
from generator.prompt_builder_all import build_prompt as build_prompt_all
from generator.llm_inference import generate_answer as generate_answer
from generator.llm_inference_all import generate_answer as generate_answer_all
# 2. Cache management
# Memoizes rag_pipeline results for the lifetime of the process.
# Keys are (query, top_k) tuples so a call with a different top_k is not served
# an answer that was generated from a different number of retrieved documents.
# NOTE(review): unbounded — every distinct (query, top_k) pair stays in memory.
search_cache = {}

# Separator printed after each intermediate debug dump.
_SEPARATOR = "==============================================\n\n"


def _debug_dump(label, value):
    """Print a labelled intermediate value followed by a separator line."""
    print(f"{label}: ", value)
    print(_SEPARATOR)


def rag_pipeline(query: str, top_k: int = 5) -> str:
    """Answer *query* with retrieval-augmented generation.

    1. Retrieve documents related to the user's question (exam and law indexes).
    2. Build a prompt from the question and the retrieved documents.
    3. Generate an answer from the prompt.

    Results are memoized in ``search_cache`` keyed by ``(query, top_k)``.

    Args:
        query: The user's question.
        top_k: Number of documents requested from each retriever.

    Returns:
        Whatever ``generate_answer`` returns for the assembled prompt
        (annotated as ``str``).
    """
    cache_key = (query, top_k)
    # Check the cache first; top_k is part of the key because it changes
    # which documents feed the prompt.
    if cache_key in search_cache:
        print(f"⚡ 캐시 사용: '{query}'")
        return search_cache[cache_key]

    # 1. Retrieval
    context_exam_docs = search_exam(query, top_k=top_k)
    _debug_dump("context_exam_docs", context_exam_docs)
    context_law_docs = search_law(query, top_k=top_k)
    _debug_dump("context_law_docs", context_law_docs)

    # 2. Prompt assembly
    prompt = build_prompt_all(query, context_law_docs, context_exam_docs)
    _debug_dump("prompt", prompt)

    # 3. Model inference
    # NOTE(review): the prompt is built with build_prompt_all but sent to
    # generate_answer, not the imported generate_answer_all — confirm intended.
    output = generate_answer(prompt)

    search_cache[cache_key] = output
    return output
# Example query
if __name__ == "__main__":
    # "Measures required when a real-estate broker relocates their office."
    query = "중개업자가 사무소를 옮겼을 때 필요한 조치"
    top_k = 5
    result = rag_pipeline(query, top_k)
    print(result)