"""Streamlit views for the query-rating survey: question screen and rating widgets."""

from datetime import datetime

import streamlit as st
from dotenv import load_dotenv

from db.schema import ModelRatings, Response
from views.nav_buttons import navigation_buttons

load_dotenv()

# Display labels for each rating scale; the radio option value is the label's index.
_PERSONA_ALIGNMENT_LABELS = ["N/A", "Not Aligned", "Partially Aligned", "Aligned", "Unclear"]
_RELEVANCE_LABELS = ["N/A", "Not Relevant", "Somewhat Relevant", "Relevant", "Unclear"]
_CLARITY_LABELS = ["N/A", "Not Clear", "Somewhat Clear", "Very Clear"]

# (column label, query-key suffix, whether the query is rated for persona alignment)
_QUERY_SPECS = (
    ("Query_v", "query_v", False),
    ("Query_p0", "query_p0", True),
    ("Query_p1", "query_p1", True),
)


def survey_completed():
    """Display the survey completion message and mark the session as finished.

    Sets ``show_questions`` to False and ``completed`` / ``start_new_survey``
    to True in ``st.session_state``.
    """
    st.markdown("""

You have already completed the survey! Thank you for participating!

Your responses have been saved successfully.

You can safely close this window or start a new survey.

""", unsafe_allow_html=True)
    st.session_state.show_questions = False
    st.session_state.completed = True
    st.session_state.start_new_survey = True


def _get_field(obj, name):
    """Read ``name`` from ``obj`` as an attribute, falling back to item access.

    Stored responses and ratings are handled both as objects exposing
    attributes and as plain dicts, so either form is accepted here.
    """
    try:
        return getattr(obj, name)
    except AttributeError:
        return obj[name]


def display_ratings_row(model_name, config, current_index):
    """Render the three query-rating columns for one model.

    Args:
        model_name: Model identifier ("gemini" or "llama" in this file). It is
            also the prefix of the ``config`` keys that hold this model's queries.
        config: The current survey row, indexed with e.g. ``"gemini_query_v"``.
        current_index: Index of the question being shown; used in widget keys.

    Returns:
        dict with ``query_v_ratings``, ``query_p0_ratings`` and
        ``query_p1_ratings``, each as returned by :func:`render_query_ratings`.
        ``query_v_ratings`` carries no ``persona_alignment`` entry.
    """
    st.markdown(f"## {model_name.capitalize()} Ratings")
    ratings = {}
    for column, (label, suffix, has_persona) in zip(st.columns(3), _QUERY_SPECS):
        with column:
            ratings[f"{suffix}_ratings"] = render_query_ratings(
                model_name,
                label,
                config,
                f"{model_name}_{suffix}",
                current_index,
                has_persona_alignment=has_persona,
            )
    # Query_v is not rated for persona alignment; drop the placeholder entry.
    ratings["query_v_ratings"].pop("persona_alignment", None)
    return ratings


def _load_previous_ratings(model_name, config, current_index):
    """Return previously saved ratings for ``model_name`` on this question.

    The result is whatever was stored for the model (object or dict), or an
    empty dict / None when nothing has been saved yet.
    """
    state = st.session_state
    responses = state.responses

    # Rendering a question behind the session's current index that already
    # has a response: read from the per-config ``previous_ratings`` store.
    if current_index < state.current_index and len(responses) > current_index:
        if not state.previous_ratings:
            return {}
        per_config = state.previous_ratings.get(config["config_id"], {})
        return per_config.get(model_name, None)

    # No response recorded for this question yet (first visit).
    if len(responses) <= current_index:
        return {}

    # The user has already interacted with this page: reuse the saved response.
    model_ratings = _get_field(responses[current_index], "model_ratings")
    return model_ratings.get(model_name, {})


def _stored_query_ratings(previous_ratings, query_key):
    """Pick the ratings for the query named by ``query_key`` out of ``previous_ratings``."""
    if not previous_ratings:
        return {}
    # Checked in this order; query_key looks like "<model>_query_v" etc.
    for suffix in ("query_v", "query_p0", "query_p1"):
        if suffix in query_key:
            return _get_field(previous_ratings, f"{suffix}_ratings")
    return {}


def _stored_option(stored_ratings, name):
    """Return the saved option index for ``name``, defaulting to 0 ("N/A")."""
    value = stored_ratings.get(name, 0) if stored_ratings else 0
    return 0 if value is None else value


def render_query_ratings(model_name, query_label, config, query_key, current_index, has_persona_alignment=False):
    """Render one query and its rating radios, pre-selecting any saved ratings.

    Args:
        model_name: Model identifier used to look up saved ratings.
        query_label: Heading shown above the query text (e.g. "Query_v").
        config: The current survey row; ``config[query_key]`` is the query text.
        query_key: Column name of the query, also used in widget keys.
        current_index: Index of the question being shown; used in widget keys.
        has_persona_alignment: Whether to show the persona-alignment radio.

    Returns:
        dict with ``clarity``, ``relevance`` and ``persona_alignment`` option
        indices. ``persona_alignment`` is None when ``has_persona_alignment``
        is False.
    """
    previous_ratings = _load_previous_ratings(model_name, config, current_index)
    stored = _stored_query_ratings(previous_ratings, query_key)
    stored_relevance = _stored_option(stored, "relevance")
    stored_clarity = _stored_option(stored, "clarity")
    stored_persona_alignment = _stored_option(stored, "persona_alignment") if has_persona_alignment else 0

    # NOTE(review): bg_color is not referenced by the markup below, even though
    # unsafe_allow_html is enabled; per-model card styling may have been lost.
    bg_color = "#e0f7fa" if model_name == "gemini" else "#f0f4c3"

    persona_alignment_rating = None
    with st.container():
        st.markdown(f"""

{query_label}

{config[query_key]}

""", unsafe_allow_html=True)
        columns = st.columns(3)

        if has_persona_alignment:
            with columns[0]:
                persona_alignment_rating = st.radio(
                    "Persona Alignment:",
                    options=list(range(len(_PERSONA_ALIGNMENT_LABELS))),
                    format_func=lambda x: _PERSONA_ALIGNMENT_LABELS[x],
                    key=f"rating_{query_key}_persona_alignment_{current_index}",
                    index=stored_persona_alignment,
                )

        with columns[1]:
            relevance_rating = st.radio(
                "Relevance:",
                list(range(len(_RELEVANCE_LABELS))),
                key=f"rating_{query_key}_relevance_{current_index}",
                format_func=lambda x: _RELEVANCE_LABELS[x],
                index=stored_relevance,
            )

        with columns[2]:
            clarity_rating = st.radio(
                "Clarity:",
                options=list(range(len(_CLARITY_LABELS))),
                key=f"rating_{query_key}_clarity_{current_index}",
                format_func=lambda x: _CLARITY_LABELS[x],
                index=stored_clarity,
            )

    return {
        "clarity": clarity_rating,
        "relevance": relevance_rating,
        "persona_alignment": persona_alignment_rating if has_persona_alignment else None,
    }


def questions_screen(data):
    """Display the question at ``st.session_state.current_index`` with its ratings.

    Renders progress, context, rating widgets for both models and a comment
    box. It then stores the resulting ``Response`` in session state and draws
    the navigation buttons. When the index is past the end of ``data``, it
    only prints a completion notice.
    """
    current_index = st.session_state.current_index
    # Only the row lookup signals "no more questions"; keeping the try this
    # narrow stops unrelated IndexErrors from being silently swallowed.
    try:
        config = data.iloc[current_index]
    except IndexError:
        print("Survey completed!")
        return

    # Progress bar
    st.progress((current_index + 1) / len(data))
    st.write(f"Question {current_index + 1} of {len(data)}")
    st.subheader(f"Config ID: {config['config_id']}")

    # Context information
    st.markdown("### Context Information")
    with st.expander("Persona", expanded=True):
        st.write(config['persona'])
    with st.expander("Filters & Cities", expanded=True):
        st.write("**Filters:**", config['filters'])
        st.write("**Cities:**", config['city'])
    with st.expander("Full Context", expanded=False):
        st.text_area("", config['context'], height=300, disabled=False)

    g_ratings = display_ratings_row("gemini", config, current_index)
    l_ratings = display_ratings_row("llama", config, current_index)

    comment = st.text_area("Additional Comments (Optional):")

    response = Response(
        config_id=config["config_id"],
        model_ratings={
            "gemini": ModelRatings(
                query_v_ratings=g_ratings["query_v_ratings"],
                query_p0_ratings=g_ratings["query_p0_ratings"],
                query_p1_ratings=g_ratings["query_p1_ratings"],
            ),
            "llama": ModelRatings(
                query_v_ratings=l_ratings["query_v_ratings"],
                query_p0_ratings=l_ratings["query_p0_ratings"],
                query_p1_ratings=l_ratings["query_p1_ratings"],
            ),
        },
        comment=comment,
        timestamp=datetime.now().isoformat(),
    )

    st.session_state.ratings[current_index] = _get_field(response, "model_ratings")

    # Overwrite the saved response for this question, or record a new one.
    responses = st.session_state.responses
    if len(responses) > current_index:
        responses[current_index] = response
    else:
        responses.append(response)

    # Navigation buttons
    navigation_buttons(data, response)