File size: 4,313 Bytes
3944997
 
 
45e69ef
3944997
 
 
45e69ef
3944997
e917d8a
45e69ef
 
 
 
 
 
3944997
 
 
 
45e69ef
3944997
 
45e69ef
4fe2243
3944997
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45e69ef
3944997
 
45e69ef
3944997
 
 
 
 
45e69ef
3944997
 
45e69ef
3944997
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45e69ef
3944997
 
45e69ef
 
3944997
 
 
45e69ef
 
 
cdbb4e1
45e69ef
3944997
45e69ef
3944997
45e69ef
 
3944997
 
 
 
 
45e69ef
 
3944997
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os
from pathlib import Path
from typing import get_args

import gradio as gr
import nest_asyncio
import yaml
from trulens.core import TruSession

from src.mythesis_chatbot.evaluation import get_prebuilt_trulens_recorder
from src.mythesis_chatbot.rag_setup import (
    SupportedRags,
    automerging_retrieval_setup,
    basic_rag_setup,
    sentence_window_retrieval_setup,
)

# Every resource location below is resolved against the parent of the directory
# that contains this script (treated here as the repository root).
_project_root = Path(__file__).parents[1]
input_file_dir = _project_root / "data/"
save_dir = _project_root / "data/indices/"
config_dir = _project_root / "configs/"
welcome_message_path = _project_root / "spaces/welcome_message.md"

# Patch asyncio so that async code can run inside an already-running event loop
# instead of failing.
nest_asyncio.apply()

# TruLens session; the database URL is read from the environment
# (os.getenv yields None when the variable is not set).
tru = TruSession(database_url=os.getenv("SUPABASE_PROD_CONNECTION_STRING_IPV4"))


class ChatBot:
    def __init__(
        self,
        input_file_dir,
        save_dir,
        config_dir,
    ):
        self.recorder = None
        self.previous_rag_mode = None
        self.recorder = None

        with open(os.path.join(config_dir, "basic.yaml")) as f:
            self.basic_config = yaml.safe_load(f)
        with open(os.path.join(config_dir, "auto_merging.yaml")) as f:
            self.automerging_config = yaml.safe_load(f)
        with open(os.path.join(config_dir, "sentence_window.yaml")) as f:
            self.sentence_window_config = yaml.safe_load(f)

        self.basic_engine = basic_rag_setup(
            input_file=os.path.join(input_file_dir, self.basic_config["source_doc"]),
            save_dir=save_dir,
            **self.basic_config,
        )
        self.automerging_engine = automerging_retrieval_setup(
            input_file=os.path.join(
                input_file_dir, self.automerging_config["source_doc"]
            ),
            save_dir=save_dir,
            **self.automerging_config,
        )
        self.sentence_window_engine = sentence_window_retrieval_setup(
            input_file=os.path.join(
                input_file_dir, self.sentence_window_config["source_doc"]
            ),
            save_dir=save_dir,
            **self.sentence_window_config,
        )

    def __call__(self, query: str, rag_mode: SupportedRags):

        match rag_mode:
            case "classic retrieval":

                if self.previous_rag_mode != rag_mode:
                    self.previous_rag_mode = rag_mode
                    self.recorder = get_prebuilt_trulens_recorder(
                        self.basic_engine, self.basic_config
                    )

                with self.recorder as recording:  # noqa: F841
                    response = self.basic_engine.query(query)

            case "auto-merging retrieval":
                if self.previous_rag_mode != rag_mode:
                    self.previous_rag_mode = rag_mode
                    self.recorder = get_prebuilt_trulens_recorder(
                        self.automerging_engine, self.automerging_config
                    )

                with self.recorder as recording:  # noqa: F841
                    response = self.automerging_engine.query(query)

            case "sentence window retrieval":
                if self.previous_rag_mode != rag_mode:
                    self.previous_rag_mode = rag_mode
                    self.recorder = get_prebuilt_trulens_recorder(
                        self.sentence_window_engine, self.sentence_window_config
                    )

                with self.recorder as recording:  # noqa: F841
                    response = self.sentence_window_engine.query(query)

        return response.response


# Build all RAG engines once at start-up; the instance itself is the Gradio callback.
chat_bot = ChatBot(input_file_dir, save_dir, config_dir)
default_message = (
    "Ask about a topic that is discussed in my master thesis."
    " E.g., what is this master thesis about? Or what is epistemic uncertainty?"
)

# Markdown shown as the interface description.
description = welcome_message_path.read_text(encoding="utf-8")

# Mode names offered in the dropdown. typing.get_args is the public way to read
# the members of a typing construct (assumes SupportedRags is a Literal[...]).
rag_modes = get_args(SupportedRags)

gradio_app = gr.Interface(
    fn=chat_bot,
    inputs=[
        gr.Textbox(placeholder=default_message, label="Query", lines=2),
        gr.Dropdown(
            choices=rag_modes,
            label="RAG mode",
            value=rag_modes[0],
        ),
    ],
    outputs=[
        gr.Textbox(label="Answer"),
    ],
    title="RAG powered chatbot",
    description=description,
)

gradio_app.launch()