user-feedback / views /questions_screen.py
Ashmi Banerjee
hacky but works :D
46dae9a
raw
history blame
9.23 kB
from db.schema import Response, ModelRatings
import streamlit as st
from datetime import datetime
from dotenv import load_dotenv
from views.nav_buttons import navigation_buttons
# Load variables from a local .env file (python-dotenv) into the process
# environment at import time; presumably for settings read later via the
# environment — TODO confirm which modules depend on it.
load_dotenv()
def survey_completed():
    """Show the "already completed" exit message and mark the session finished.

    Renders a static HTML notice, then sets three session-state flags:
    ``show_questions`` to False, ``completed`` to True and
    ``start_new_survey`` to True (in that order).
    """
    exit_html = (
        "\n"
        "<div class='exit-container'>\n"
        "<h1>You have already completed the survey! Thank you for participating!</h1>\n"
        "<p>Your responses have been saved successfully.</p>\n"
        "<p>You can safely close this window or start a new survey.</p>\n"
        "</div>\n"
    )
    st.markdown(exit_html, unsafe_allow_html=True)

    session_flags = (
        ("show_questions", False),
        ("completed", True),
        ("start_new_survey", True),
    )
    for flag_name, flag_value in session_flags:
        setattr(st.session_state, flag_name, flag_value)
def display_ratings_row(model_name, config, current_index):
    """Render one model's row of rating widgets and collect what was chosen.

    The row is split into three columns, one per query variant
    (Query_v, Query_p0, Query_p1). Persona alignment is only asked for the
    two persona variants, and the (unused) key is removed from the
    Query_v result.

    Returns:
        dict with keys "query_v_ratings", "query_p0_ratings" and
        "query_p1_ratings", each the dict produced by render_query_ratings.
    """
    st.markdown(f"## {model_name.capitalize()} Ratings")

    # (heading label, config column suffix, ask for persona alignment?)
    variants = (
        ("Query_v", "query_v", False),
        ("Query_p0", "query_p0", True),
        ("Query_p1", "query_p1", True),
    )
    collected = {}
    for column, (label, suffix, with_persona) in zip(st.columns(3), variants):
        with column:
            collected[f"{suffix}_ratings"] = render_query_ratings(
                model_name, label, config, f"{model_name}_{suffix}", current_index,
                has_persona_alignment=with_persona,
            )

    # Query_v carries no persona, so drop the placeholder entry.
    collected["query_v_ratings"].pop("persona_alignment", None)
    return collected
def _get_field(record, name):
    """Return ``record.<name>``, falling back to ``record[name]``.

    Saved responses and ratings show up both as schema objects
    (``Response`` / ``ModelRatings``) and as plain dicts, so the lookups
    below try attribute access first and item access second.
    """
    try:
        return getattr(record, name)
    except AttributeError:
        return record[name]


def _previous_model_ratings(model_name, config, current_index):
    """Return ratings already recorded for ``model_name`` on this question.

    Three cases:

    * this page's index is behind ``st.session_state.current_index`` and a
      response exists for it: read the snapshot in
      ``st.session_state.previous_ratings`` keyed by ``config["config_id"]``;
    * no response stored for this index yet (first visit): nothing to restore;
    * otherwise: read the response saved in ``st.session_state.responses``.

    Any model other than "gemini" is looked up under "llama". Returns a falsy
    value (``{}`` or ``None``) when nothing is stored.
    """
    model_key = "gemini" if model_name == "gemini" else "llama"
    responses = st.session_state.responses
    if current_index < st.session_state.current_index and len(responses) > current_index:
        snapshot = st.session_state.previous_ratings
        if not snapshot:
            return {}
        return snapshot.get(config["config_id"], {}).get(model_key, None)
    if len(responses) <= current_index:
        return {}
    return _get_field(responses[current_index], "model_ratings").get(model_key, {})


def _stored_query_ratings(model_ratings, query_key):
    """Pick the query_v / query_p0 / query_p1 ratings matching ``query_key``."""
    if not model_ratings:
        return {}
    for variant in ("query_v", "query_p0", "query_p1"):
        if variant in query_key:
            return _get_field(model_ratings, f"{variant}_ratings")
    return {}


def _radio_index(value):
    """Turn a stored rating into a radio ``index``; ``None`` selects option 0."""
    return value if value is not None else 0


def render_query_ratings(model_name, query_label, config, query_key, current_index, has_persona_alignment=False):
    """Render one query card with its rating radios and return the chosen ratings.

    Previously saved ratings (see ``_previous_model_ratings``) pre-select the
    radios so answers survive reruns and page navigation.

    Args:
        model_name: "gemini" or another model name; it picks the card colour,
            and any non-"gemini" name restores ratings stored under "llama".
        query_label: heading shown on the card, e.g. "Query_v".
        config: current survey row; ``config[query_key]`` is the query text
            and ``config["config_id"]`` identifies the question.
        query_key: config column holding the query, e.g. "gemini_query_p0".
        current_index: index of the question shown; part of every widget key.
        has_persona_alignment: whether to ask for a persona-alignment rating.

    Returns:
        dict with "clarity" and "relevance" (selected option values) and
        "persona_alignment" (selected value, or None when
        ``has_persona_alignment`` is False).
    """
    previous = _previous_model_ratings(model_name, config, current_index)
    stored = _stored_query_ratings(previous, query_key)
    stored_relevance = stored.get("relevance", 0) if stored else 0
    stored_clarity = stored.get("clarity", 0) if stored else 0
    stored_persona_alignment = stored.get("persona_alignment", 0) if has_persona_alignment and stored else 0

    bg_color = "#e0f7fa" if model_name == "gemini" else "#f0f4c3"
    card_html = (
        "\n"
        f'<div style="background-color:{bg_color}; padding:1rem;">\n'
        f'<h3 style="color:blue;"> {query_label} </h3>\n'
        f'<p style="text-align:left;">{config[query_key]}</p>\n'
        "</div>\n"
    )

    # Each scale's option value doubles as the position in its label list.
    persona_labels = ["N/A", "Not Aligned", "Partially Aligned", "Aligned", "Unclear"]
    relevance_labels = ["N/A", "Not Relevant", "Somewhat Relevant", "Relevant", "Unclear"]
    clarity_labels = ["N/A", "Not Clear", "Somewhat Clear", "Very Clear"]
    options = [0, 1, 2, 3, 4]

    persona_alignment_rating = None
    with st.container():
        st.markdown(card_html, unsafe_allow_html=True)
        columns = st.columns(3)
        if has_persona_alignment:
            with columns[0]:
                persona_alignment_rating = st.radio(
                    "Persona Alignment:", options=options,
                    format_func=lambda x: persona_labels[x],
                    key=f"rating_{query_key}_persona_alignment_{current_index}",
                    index=_radio_index(stored_persona_alignment),
                )
        with columns[1]:
            relevance_rating = st.radio(
                "Relevance:",
                options,
                key=f"rating_{query_key}_relevance_{current_index}",
                format_func=lambda x: relevance_labels[x],
                index=_radio_index(stored_relevance),
            )
        with columns[2]:
            clarity_rating = st.radio(
                "Clarity:",
                options=[0, 1, 2, 3],
                key=f"rating_{query_key}_clarity_{current_index}",
                format_func=lambda x: clarity_labels[x],
                index=_radio_index(stored_clarity),
            )

    return {
        "clarity": clarity_rating,
        "relevance": relevance_rating,
        "persona_alignment": persona_alignment_rating,
    }
def _render_context(config):
    """Show persona, filters/cities and the full context for one survey row."""
    st.markdown("### Context Information")
    with st.expander("Persona", expanded=True):
        st.write(config['persona'])
    with st.expander("Filters & Cities", expanded=True):
        st.write("**Filters:**", config['filters'])
        st.write("**Cities:**", config['city'])
    with st.expander("Full Context", expanded=False):
        st.text_area("", config['context'], height=300, disabled=False)


def _saved_comment(current_index):
    """Return the comment already saved for question ``current_index`` ("" if none)."""
    responses = st.session_state.responses
    if len(responses) <= current_index:
        return ""
    saved = responses[current_index]
    try:
        comment = saved.comment
    except AttributeError:
        # Saved responses may also be plain dicts.
        comment = saved.get("comment", "")
    return comment or ""


def questions_screen(data):
    """Display the questions screen with split layout.

    Renders the row ``data.iloc[st.session_state.current_index]``:
    - progress;
    - context;
    - a ratings row for "gemini" and for "llama";
    - an optional comment.

    The answers are then stored as a ``Response`` in
    ``st.session_state.responses`` (overwriting any existing entry at this
    index) and in ``st.session_state.ratings[current_index]``. Finally the
    navigation buttons are drawn. When the index is past the last row, only
    "Survey completed!" is printed.
    """
    current_index = st.session_state.current_index
    # Only the row lookup signals "past the end"; other IndexErrors are real
    # bugs and must not be reported as survey completion.
    try:
        config = data.iloc[current_index]
    except IndexError:
        print("Survey completed!")
        return

    st.progress((current_index + 1) / len(data))
    st.write(f"Question {current_index + 1} of {len(data)}")
    st.subheader(f"Config ID: {config['config_id']}")

    _render_context(config)

    g_ratings = display_ratings_row("gemini", config, current_index)
    l_ratings = display_ratings_row("llama", config, current_index)

    # Keyed per question (like the rating radios) so text typed for one
    # question does not carry over to the next; pre-filled when revisiting.
    comment = st.text_area(
        "Additional Comments (Optional):",
        value=_saved_comment(current_index),
        key=f"comment_{current_index}",
    )

    response = Response(
        config_id=config["config_id"],
        model_ratings={
            "gemini": ModelRatings(
                query_v_ratings=g_ratings["query_v_ratings"],
                query_p0_ratings=g_ratings["query_p0_ratings"],
                query_p1_ratings=g_ratings["query_p1_ratings"],
            ),
            "llama": ModelRatings(
                query_v_ratings=l_ratings["query_v_ratings"],
                query_p0_ratings=l_ratings["query_p0_ratings"],
                query_p1_ratings=l_ratings["query_p1_ratings"],
            ),
        },
        comment=comment,
        timestamp=datetime.now().isoformat(),
    )
    print(response)

    # Response may or may not support item access depending on its schema type.
    try:
        st.session_state.ratings[current_index] = response["model_ratings"]
    except TypeError:
        st.session_state.ratings[current_index] = response.model_ratings

    if len(st.session_state.responses) > current_index:
        st.session_state.responses[current_index] = response
    else:
        st.session_state.responses.append(response)

    # Presumably handles Previous/Next/Submit for this response — TODO confirm.
    navigation_buttons(data, response)