import os from pathlib import Path from typing import Literal import numpy as np from tqdm import tqdm from trulens.apps.llamaindex import TruLlama from trulens.core import Feedback, TruSession from trulens.providers.openai import OpenAI from src.mythesis_chatbot.utils import get_config_hash def run_evals(eval_questions_path: Path, tru_recorder, query_engine): eval_questions = [] with open(eval_questions_path) as file: for line in file: item = line.strip() eval_questions.append(item) for question in tqdm(eval_questions): with tru_recorder as recording: # noqa: F841 response = query_engine.query(question) # noqa: F841 # Feedback function def f_answer_relevance(provider=OpenAI(), name="Answer Relevance") -> Feedback: return Feedback(provider.relevance_with_cot_reasons, name=name).on_input_output() # Feedback function def f_context_relevance( provider=OpenAI(), context=TruLlama.select_source_nodes().node.text, name="Context Relevance", ) -> Feedback: return ( Feedback(provider.relevance, name=name) .on_input() .on(context) .aggregate(np.mean) ) # Feedback function def f_groundedness( provider=OpenAI(), context=TruLlama.select_source_nodes().node.text, name="Groundedness", ) -> Feedback: return ( Feedback( provider.groundedness_measure_with_cot_reasons, name=name, ) .on(context) .on_output() ) def get_prebuilt_trulens_recorder( query_engine, query_engine_config: dict[str, str | int] ) -> TruLlama: app_name = query_engine_config["rag_mode"] app_version = get_config_hash(query_engine_config) tru_recorder = TruLlama( query_engine, app_name=app_name, app_version=app_version, metadata=query_engine_config, feedbacks=[f_answer_relevance(), f_context_relevance(), f_groundedness()], ) return tru_recorder def get_tru_session(database: Literal["prod", "dev"]) -> TruSession: print(f"Connecting to {database.lower()} database...") match database.lower(): case "prod": database_url = os.getenv("SUPABASE_PROD_CONNECTION_STRING_IPV4") if database_url is None: raise RuntimeError( "IPv4 connection string to production database is not available as" " an environment variable." ) else: print("Using IPv4 connection string...") tru = TruSession(database_url=database_url) return tru case "dev": database_url = os.getenv("SUPABASE_DEV_CONNECTION_STRING_IPV6") if database_url: try: print("Using IPv6 connection string...") tru = TruSession(database_url=database_url) return tru except Exception as e: print( "An error occurred while connecting to remote dev database with" f" IPv6 connection string: {e}" ) print("Reverting to IPv4") else: print( "IPv6 connection string to dev database is not available as an" " environment variable. Reverting to IPv4." ) database_url = os.getenv("SUPABASE_DEV_CONNECTION_STRING_IPV4") if database_url is None: raise RuntimeError( "IPv4 connection string to dev database is not available" " as an environment variable." ) else: tru = TruSession(database_url=database_url) return tru case _: raise ValueError( f"Invalid database: {database}. Choose betwen 'prod' and 'dev'" )