Spaces:
Sleeping
Sleeping
File size: 4,313 Bytes
3944997 45e69ef 3944997 45e69ef 3944997 e917d8a 45e69ef 3944997 45e69ef 3944997 45e69ef 4fe2243 3944997 45e69ef 3944997 45e69ef 3944997 45e69ef 3944997 45e69ef 3944997 45e69ef 3944997 45e69ef 3944997 45e69ef cdbb4e1 45e69ef 3944997 45e69ef 3944997 45e69ef 3944997 45e69ef 3944997 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import os
from pathlib import Path
import gradio as gr
import nest_asyncio
import yaml
from trulens.core import TruSession
from src.mythesis_chatbot.evaluation import get_prebuilt_trulens_recorder
from src.mythesis_chatbot.rag_setup import (
SupportedRags,
automerging_retrieval_setup,
basic_rag_setup,
sentence_window_retrieval_setup,
)
# Directory two levels above this file; every path below is resolved against it.
_REPO_ROOT = Path(__file__).parents[1]

input_file_dir = _REPO_ROOT / "data/"
save_dir = _REPO_ROOT / "data/indices/"
config_dir = _REPO_ROOT / "configs/"
welcome_message_path = _REPO_ROOT / "spaces/welcome_message.md"

# Enables running async code inside an existing event loop without crashing.
nest_asyncio.apply()

# The connection string is read from the environment; it is None when the
# variable is not set, and that value is handed to TruSession unchanged.
_database_url = os.environ.get("SUPABASE_PROD_CONNECTION_STRING_IPV4")
tru = TruSession(database_url=_database_url)
class ChatBot:
def __init__(
self,
input_file_dir,
save_dir,
config_dir,
):
self.recorder = None
self.previous_rag_mode = None
self.recorder = None
with open(os.path.join(config_dir, "basic.yaml")) as f:
self.basic_config = yaml.safe_load(f)
with open(os.path.join(config_dir, "auto_merging.yaml")) as f:
self.automerging_config = yaml.safe_load(f)
with open(os.path.join(config_dir, "sentence_window.yaml")) as f:
self.sentence_window_config = yaml.safe_load(f)
self.basic_engine = basic_rag_setup(
input_file=os.path.join(input_file_dir, self.basic_config["source_doc"]),
save_dir=save_dir,
**self.basic_config,
)
self.automerging_engine = automerging_retrieval_setup(
input_file=os.path.join(
input_file_dir, self.automerging_config["source_doc"]
),
save_dir=save_dir,
**self.automerging_config,
)
self.sentence_window_engine = sentence_window_retrieval_setup(
input_file=os.path.join(
input_file_dir, self.sentence_window_config["source_doc"]
),
save_dir=save_dir,
**self.sentence_window_config,
)
def __call__(self, query: str, rag_mode: SupportedRags):
match rag_mode:
case "classic retrieval":
if self.previous_rag_mode != rag_mode:
self.previous_rag_mode = rag_mode
self.recorder = get_prebuilt_trulens_recorder(
self.basic_engine, self.basic_config
)
with self.recorder as recording: # noqa: F841
response = self.basic_engine.query(query)
case "auto-merging retrieval":
if self.previous_rag_mode != rag_mode:
self.previous_rag_mode = rag_mode
self.recorder = get_prebuilt_trulens_recorder(
self.automerging_engine, self.automerging_config
)
with self.recorder as recording: # noqa: F841
response = self.automerging_engine.query(query)
case "sentence window retrieval":
if self.previous_rag_mode != rag_mode:
self.previous_rag_mode = rag_mode
self.recorder = get_prebuilt_trulens_recorder(
self.sentence_window_engine, self.sentence_window_config
)
with self.recorder as recording: # noqa: F841
response = self.sentence_window_engine.query(query)
return response.response
# The engines are built once here; the instance itself is the Gradio callback.
chat_bot = ChatBot(input_file_dir, save_dir, config_dir)

# Placeholder text for the query box.
default_message = (
    "Ask about a topic that is discussed in my master thesis."
    " E.g., what is this master thesis about? Or what is epistemic uncertainty?"
)

# Content of the welcome-message file, passed on as the interface description.
description = welcome_message_path.read_text(encoding="utf-8")

# Components are created in the same order as they appear in the interface:
# query box, mode selector, answer box.
_rag_modes = SupportedRags.__args__
_query_box = gr.Textbox(placeholder=default_message, label="Query", lines=2)
_mode_dropdown = gr.Dropdown(
    choices=_rag_modes,
    label="RAG mode",
    value=_rag_modes[0],  # the first listed mode is preselected
)
_answer_box = gr.Textbox(label="Answer")

gradio_app = gr.Interface(
    fn=chat_bot,
    inputs=[_query_box, _mode_dropdown],
    outputs=[_answer_box],
    title="RAG powered chatbot",
    description=description,
)

gradio_app.launch()
|