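"""Gradio app for a chatbot over my master thesis.

Three RAG modes (classic, auto-merging, and sentence window retrieval) are
built from YAML configs, and every query is recorded with TruLens.
"""
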
import os
from pathlib import Path

import gradio as gr
import nest_asyncio
import yaml
from trulens.core import TruSession

from src.mythesis_chatbot.evaluation import get_prebuilt_trulens_recorder
from src.mythesis_chatbot.rag_setup import (
    SupportedRags,
    automerging_retrieval_setup,
    basic_rag_setup,
    sentence_window_retrieval_setup,
)
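
# Paths relative to the repository root.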
input_file_dir = Path(__file__).parents[1] / "data/"
save_dir = Path(__file__).parents[1] / "data/indices/"
config_dir = Path(__file__).parents[1] / "configs/"
welcome_message_path = Path(__file__).parents[1] / "spaces/welcome_message.md"

# Enables running async code inside an existing event loop without crashing.
nest_asyncio.apply()
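
# TruLens session for logging query records and evaluation results; the
# connection string of the production database is read from the environment.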
tru = TruSession(database_url=os.getenv("SUPABASE_PROD_CONNECTION_STRING_IPV4"))


class ChatBot:
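    """Wraps one query engine per RAG mode and records queries with TruLens."""
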
    def __init__(self, input_file_dir, save_dir, config_dir):
        self.previous_rag_mode = None
        self.recorder = None
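
        # One YAML config per RAG mode.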
        with open(os.path.join(config_dir, "basic.yaml")) as f:
            self.basic_config = yaml.safe_load(f)
        with open(os.path.join(config_dir, "auto_merging.yaml")) as f:
            self.automerging_config = yaml.safe_load(f)
        with open(os.path.join(config_dir, "sentence_window.yaml")) as f:
            self.sentence_window_config = yaml.safe_load(f)
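
        # One query engine per RAG mode; indices are persisted under save_dir.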
        self.basic_engine = basic_rag_setup(
            input_file=os.path.join(input_file_dir, self.basic_config["source_doc"]),
            save_dir=save_dir,
            **self.basic_config,
        )
        self.automerging_engine = automerging_retrieval_setup(
            input_file=os.path.join(
                input_file_dir, self.automerging_config["source_doc"]
            ),
            save_dir=save_dir,
            **self.automerging_config,
        )
        self.sentence_window_engine = sentence_window_retrieval_setup(
            input_file=os.path.join(
                input_file_dir, self.sentence_window_config["source_doc"]
            ),
            save_dir=save_dir,
            **self.sentence_window_config,
        )

    def __call__(self, query: str, rag_mode: SupportedRags):
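        """Answer a query with the engine selected by rag_mode.

        Each call is recorded with TruLens; the recorder is rebuilt whenever
        the RAG mode changes.
        """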
        match rag_mode:
            case "classic retrieval":
                if self.previous_rag_mode != rag_mode:
                    self.previous_rag_mode = rag_mode
                    self.recorder = get_prebuilt_trulens_recorder(
                        self.basic_engine, self.basic_config
                    )
                with self.recorder as recording:  # noqa: F841
                    response = self.basic_engine.query(query)
            case "auto-merging retrieval":
                if self.previous_rag_mode != rag_mode:
                    self.previous_rag_mode = rag_mode
                    self.recorder = get_prebuilt_trulens_recorder(
                        self.automerging_engine, self.automerging_config
                    )
                with self.recorder as recording:  # noqa: F841
                    response = self.automerging_engine.query(query)
            case "sentence window retrieval":
                if self.previous_rag_mode != rag_mode:
                    self.previous_rag_mode = rag_mode
                    self.recorder = get_prebuilt_trulens_recorder(
                        self.sentence_window_engine, self.sentence_window_config
                    )
                with self.recorder as recording:  # noqa: F841
                    response = self.sentence_window_engine.query(query)
            case _:
                # Guard against RAG modes that have no matching engine.
                raise ValueError(f"Unsupported RAG mode: {rag_mode}")

        return response.response


chat_bot = ChatBot(input_file_dir, save_dir, config_dir)

default_message = (
    "Ask about a topic that is discussed in my master thesis."
    " E.g., what is this master thesis about? Or what is epistemic uncertainty?"
)

with open(welcome_message_path, encoding="utf-8") as f:
    description = f.read()
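
# Gradio UI: a free-text query box and a dropdown over the supported RAG modes.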
gradio_app = gr.Interface(
    fn=chat_bot,
    inputs=[
        gr.Textbox(placeholder=default_message, label="Query", lines=2),
        gr.Dropdown(
            choices=SupportedRags.__args__,
            label="RAG mode",
            value=SupportedRags.__args__[0],
        ),
    ],
    outputs=[
        gr.Textbox(label="Answer"),
    ],
    title="RAG powered chatbot",
    description=description,
)

gradio_app.launch()