File size: 3,977 Bytes
cdbb4e1
5de4570
cdbb4e1
5de4570
 
 
 
cdbb4e1
5de4570
 
e47aaa6
5de4570
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cdbb4e1
5de4570
 
 
 
 
 
 
 
cdbb4e1
5de4570
 
 
 
 
 
 
 
 
 
 
 
 
cdbb4e1
5de4570
 
 
 
 
 
 
 
 
 
 
 
cdbb4e1
5de4570
 
 
 
 
 
 
 
 
 
 
cdbb4e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import os
from pathlib import Path
from typing import Literal

import numpy as np
from tqdm import tqdm
from trulens.apps.llamaindex import TruLlama
from trulens.core import Feedback, TruSession
from trulens.providers.openai import OpenAI

from src.mythesis_chatbot.utils import get_config_hash


def run_evals(eval_questions_path: Path, tru_recorder, query_engine):
    """Query the engine once per question in a text file, inside the recorder context.

    Args:
        eval_questions_path: Text file (read as UTF-8) with one question per line.
            Surrounding whitespace is stripped and blank lines are skipped, so no
            empty-string query is ever sent to the engine.
        tru_recorder: Context manager entered around every query; presumably the
            TruLlama recorder from ``get_prebuilt_trulens_recorder`` — confirm with
            callers.
        query_engine: Object exposing ``query(str)``. Responses are discarded; the
            function is run for the side effect of recording each call.
    """
    with open(eval_questions_path, encoding="utf-8") as file:
        eval_questions = [question for line in file if (question := line.strip())]

    for question in tqdm(eval_questions):
        # One recorder context per query so each question is captured separately.
        with tru_recorder:
            query_engine.query(question)


def f_answer_relevance(
    provider: OpenAI | None = None, name: str = "Answer Relevance"
) -> Feedback:
    """Build a feedback that scores the app output's relevance to its input.

    Args:
        provider: Feedback provider whose ``relevance_with_cot_reasons`` is used.
            Defaults to a fresh ``OpenAI()`` created at call time rather than at
            import time, so importing this module does not construct an API client.
        name: Display name of the feedback.

    Returns:
        A ``Feedback`` wired to the app's input and output.
    """
    if provider is None:
        provider = OpenAI()
    return Feedback(provider.relevance_with_cot_reasons, name=name).on_input_output()


def f_context_relevance(
    provider: OpenAI | None = None,
    context=None,
    name: str = "Context Relevance",
) -> Feedback:
    """Build a feedback that scores each retrieved context against the app input.

    Each selected context is scored with ``provider.relevance`` against the input.
    The per-context scores are combined with ``np.mean``.

    Args:
        provider: Feedback provider. Defaults to a fresh ``OpenAI()`` created at
            call time rather than at import time.
        context: Selector for the context text. Defaults to
            ``TruLlama.select_source_nodes().node.text`` (the source nodes' text).
        name: Display name of the feedback.

    Returns:
        The configured ``Feedback``.
    """
    if provider is None:
        provider = OpenAI()
    if context is None:
        context = TruLlama.select_source_nodes().node.text
    return (
        Feedback(provider.relevance, name=name)
        .on_input()
        .on(context)
        .aggregate(np.mean)
    )


def f_groundedness(
    provider: OpenAI | None = None,
    context=None,
    name: str = "Groundedness",
) -> Feedback:
    """Build a feedback that checks the app output against the retrieved context.

    Uses ``provider.groundedness_measure_with_cot_reasons``. The selected context
    is the first argument and the app output is the second.

    Args:
        provider: Feedback provider. Defaults to a fresh ``OpenAI()`` created at
            call time rather than at import time.
        context: Selector for the context text. Defaults to
            ``TruLlama.select_source_nodes().node.text`` (the source nodes' text).
        name: Display name of the feedback.

    Returns:
        The configured ``Feedback``.
    """
    if provider is None:
        provider = OpenAI()
    if context is None:
        context = TruLlama.select_source_nodes().node.text
    return (
        Feedback(
            provider.groundedness_measure_with_cot_reasons,
            name=name,
        )
        .on(context)
        .on_output()
    )


def get_prebuilt_trulens_recorder(
    query_engine, query_engine_config: dict[str, str | int]
) -> TruLlama:
    """Wrap ``query_engine`` in a TruLlama recorder carrying the three standard feedbacks.

    The recorder is configured from the config as follows:
    - App name: ``query_engine_config["rag_mode"]``.
    - App version: ``get_config_hash(query_engine_config)``.
    - Metadata: the whole config.
    - Feedbacks: answer relevance, context relevance and groundedness.
    """
    # Keyword arguments are evaluated left to right: name lookup, hash, then feedbacks.
    return TruLlama(
        query_engine,
        app_name=query_engine_config["rag_mode"],
        app_version=get_config_hash(query_engine_config),
        metadata=query_engine_config,
        feedbacks=[f_answer_relevance(), f_context_relevance(), f_groundedness()],
    )


def get_tru_session(database: Literal["prod", "dev"]) -> TruSession:

    print(f"Connecting to {database.lower()} database...")

    match database.lower():
        case "prod":
            database_url = os.getenv("SUPABASE_PROD_CONNECTION_STRING_IPV4")
            if database_url is None:
                raise RuntimeError(
                    "IPv4 connection string to production database is not available as"
                    " an environment variable."
                )
            else:
                print("Using IPv4 connection string...")
                tru = TruSession(database_url=database_url)
                return tru

        case "dev":
            database_url = os.getenv("SUPABASE_DEV_CONNECTION_STRING_IPV6")
            if database_url:
                try:
                    print("Using IPv6 connection string...")
                    tru = TruSession(database_url=database_url)
                    return tru
                except Exception as e:
                    print(
                        "An error occurred while connecting to remote dev database with"
                        f" IPv6 connection string: {e}"
                    )
                    print("Reverting to IPv4")
            else:
                print(
                    "IPv6 connection string to dev database is not available as an"
                    " environment variable. Reverting to IPv4."
                )

            database_url = os.getenv("SUPABASE_DEV_CONNECTION_STRING_IPV4")
            if database_url is None:
                raise RuntimeError(
                    "IPv4 connection string to dev database is not available"
                    " as an environment variable."
                )
            else:
                tru = TruSession(database_url=database_url)
                return tru
        case _:
            raise ValueError(
                f"Invalid database: {database}. Choose betwen 'prod' and 'dev'"
            )