Spaces:
Running
Running
import html
import random
from datetime import datetime

import streamlit as st
from dotenv import load_dotenv

from db.schema import Response, ModelRatings
from views.nav_buttons import navigation_buttons
load_dotenv() | |
def display_completion_message():
    """Show the survey-completed banner and mark the session as finished.

    Renders a static HTML notice through ``st.markdown`` and then updates
    three ``st.session_state`` flags: ``show_questions`` is cleared, while
    ``completed`` and ``start_new_survey`` are set.
    """
    completion_html = """
        <div class='exit-container'>
        <h1>You have already completed the survey! Thank you for participating!</h1>
        <p>Your responses have been saved successfully.</p>
        <p>You can safely close this window or start a new survey.</p>
        </div>
        """
    st.markdown(completion_html, unsafe_allow_html=True)

    session = st.session_state
    session.show_questions = False
    session.completed = True
    session.start_new_survey = True
def _read_field(container, name, default=None):
    """Return *name* from *container*, trying attribute access, then key lookup.

    Stored ratings may be objects (attribute access) or plain mappings (key
    access); *default* is returned when neither form yields a value.
    """
    try:
        return getattr(container, name)
    except AttributeError:
        pass
    try:
        return container[name]
    except (KeyError, IndexError, TypeError):
        return default


def get_previous_ratings(model_name, query_key, current_index):
    """Retrieve the ratings previously stored for one model/query variant.

    Args:
        model_name: Key of the model inside the stored ratings
            (the callers in this module pass "gemini" or "llama").
        query_key: Query variant identifier; it is matched by substring
            against "query_v", "query_p0" and "query_p1", in that order.
        current_index: Position of the question whose ratings are wanted.

    Returns:
        The stored ratings for the matching query variant, or an empty dict
        when nothing has been stored (or the stored value is falsy).
    """
    session = st.session_state
    previous_ratings = {}

    if current_index < session.current_index and len(session.responses) > current_index:
        # Revisiting an earlier question: ratings are looked up by config id,
        # then by model name.
        if session.previous_ratings:
            config_id = session.data.iloc[current_index]["config_id"]
            ratings_for_config = session.previous_ratings.get(config_id, {})
            previous_ratings = ratings_for_config.get(model_name, None)
    elif len(session.responses) > current_index:
        # A response for this index exists in the current session.
        saved_response = session.responses[current_index]
        try:
            previous_ratings = saved_response.model_ratings.get(model_name, {})
        except AttributeError:
            previous_ratings = saved_response["model_ratings"].get(model_name, {})
    # Otherwise no response has been recorded for this index yet.

    stored_query_ratings = {}
    if previous_ratings:
        for variant in ("query_v", "query_p0", "query_p1"):
            if variant in query_key:
                # A missing field yields {} instead of raising KeyError.
                stored_query_ratings = _read_field(
                    previous_ratings, f"{variant}_ratings", {}
                )
                break
    return stored_query_ratings if stored_query_ratings else {}
def render_single_rating(
    label,
    options,
    format_func,
    key_prefix,
    stored_rating,
    col,
):
    """Render one radio-button rating inside *col* and return the selection.

    Args:
        label: Text shown with the radio group.
        options: Selectable values handed to ``st.radio``.
        format_func: Callable mapping an option to its display text.
        key_prefix: Used, as a string, for the widget's Streamlit ``key``.
        stored_rating: Position in *options* to preselect. ``None``, a
            non-integer or an out-of-range position selects the first option.
        col: Context manager (the callers pass a Streamlit column) in which
            the widget is created.

    Returns:
        The value returned by ``st.radio``.
    """
    try:
        option_count = len(options)
    except TypeError:
        # Unsized iterable: the position cannot be validated here.
        option_count = None

    if stored_rating is None:
        initial_index = 0
    elif option_count is None:
        initial_index = stored_rating
    elif isinstance(stored_rating, int) and 0 <= stored_rating < option_count:
        initial_index = stored_rating
    else:
        # A stale or corrupt stored value must not crash the page.
        initial_index = 0

    with col:
        return st.radio(
            label,
            options=options,
            format_func=format_func,
            key=str(key_prefix),
            index=initial_index,
        )
def clean_query_text(query_text):
    """Clean a query string for display.

    If the text starts or ends with a single or double quote, every quote
    character in it is removed. A full stop is appended unless the text
    already ends with ".", "?", "!" or a newline, and the result is passed
    through ``str.capitalize`` (first character upper-cased, the rest
    lower-cased).

    Args:
        query_text: The raw query string.

    Returns:
        The cleaned string; an empty string when nothing is left to display.
    """
    quote_chars = ('"', "'")
    if query_text.startswith(quote_chars) or query_text.endswith(quote_chars):
        query_text = query_text.replace('"', "").replace("'", "")
    if not query_text:
        # Nothing to punctuate; indexing [-1] would raise IndexError.
        return query_text
    if query_text[-1] not in (".", "?", "!", "\n"):
        query_text += "."
    return query_text.capitalize()
def render_query_ratings(
    model_name,
    config,
    query_key,
    current_index,
    has_persona_alignment=False,
):
    """Render one generated query with its rating widgets and return the ratings.

    The query text is read from ``config[f"{model_name}_{query_key}"]``,
    cleaned with ``clean_query_text`` and shown in a coloured box. Radio
    widgets for relevance and clarity follow, plus persona alignment when
    requested. Widgets are preselected from ``get_previous_ratings``.

    Args:
        model_name: Model whose query is shown; "gemini" gets its own
            background colour, any other name gets the alternative one.
        config: Row-like object supporting ``config[column]`` and
            ``config.index.get_loc(column)``.
        query_key: Query variant, e.g. "query_v", "query_p0" or "query_p1".
        current_index: Index of the question being displayed.
        has_persona_alignment: Whether to show the persona-alignment rating.

    Returns:
        A dict with keys "clarity", "relevance" and "persona_alignment";
        the last is ``None`` when *has_persona_alignment* is false.
    """
    stored_query_ratings = get_previous_ratings(model_name, query_key, current_index)
    stored_relevance = stored_query_ratings.get("relevance", 0)
    stored_clarity = stored_query_ratings.get("clarity", 0)
    stored_persona_alignment = (
        stored_query_ratings.get("persona_alignment", 0) if has_persona_alignment else 0
    )

    bg_color = "#e0f7fa" if model_name == "gemini" else "#f0f4c3"
    column_name = model_name + "_" + query_key
    # The text is placed in raw HTML below, so markup characters are escaped.
    query_text = html.escape(clean_query_text(config[column_name]), quote=False)
    # NOTE(review): the offset of 5 presumably maps the query columns onto the
    # displayed numbers 1-6 -- confirm against the dataset's column order.
    query_number = config.index.get_loc(column_name) - 5

    persona_labels = ("N/A", "Not Aligned", "Partially Aligned", "Aligned", "Unclear")
    relevance_labels = ("N/A", "Not Relevant", "Somewhat Relevant", "Relevant", "Unclear")
    clarity_labels = ("N/A", "Not Clear", "Somewhat Clear", "Very Clear")

    with st.container():
        st.markdown(
            f"""
            <div style="background-color:{bg_color}; padding:1rem;">
            <h3 style="text-align:left;">
            {query_number}
            </h3>
            <p style="text-align:left;">
            {query_text}</p>
            </div>
            """,
            unsafe_allow_html=True,
        )
        cols = st.columns(3)
        options = [0, 1, 2, 3, 4]

        persona_alignment_rating = None
        if has_persona_alignment:
            persona_alignment_rating = render_single_rating(
                "Persona Alignment:",
                options,
                persona_labels.__getitem__,
                f"rating_{model_name}{query_key}_persona_alignment_",
                stored_persona_alignment,
                cols[0],
            )
        relevance_rating = render_single_rating(
            "Relevance:",
            options,
            relevance_labels.__getitem__,
            f"rating_{model_name}{query_key}_relevance_",
            stored_relevance,
            cols[1],
        )
        clarity_rating = render_single_rating(
            "Clarity:",
            [0, 1, 2, 3],
            clarity_labels.__getitem__,
            f"rating_{model_name}{query_key}_clarity_",
            stored_clarity,
            cols[2],
        )

    return {
        "clarity": clarity_rating,
        "relevance": relevance_rating,
        "persona_alignment": persona_alignment_rating,
    }
def display_ratings_row(model_name, config, current_index):
    """Render the three query variants of one model side by side.

    Each variant is drawn in its own column via ``render_query_ratings``.
    Only the "query_p0" and "query_p1" variants show a persona-alignment
    rating; the "persona_alignment" entry is removed from the "query_v"
    result.

    Returns:
        A dict with keys "query_v_ratings", "query_p0_ratings" and
        "query_p1_ratings", each holding the ratings for that variant.
    """
    # (query key, result key, whether persona alignment is rated)
    variants = (
        ("query_v", "query_v_ratings", False),
        ("query_p0", "query_p0_ratings", True),
        ("query_p1", "query_p1_ratings", True),
    )

    collected = {}
    for column, (query_key, result_key, rates_persona) in zip(st.columns(3), variants):
        with column:
            collected[result_key] = render_query_ratings(
                model_name,
                config,
                query_key,
                current_index,
                has_persona_alignment=rates_persona,
            )

    # The persona-free variant carries no persona rating in its result.
    collected["query_v_ratings"].pop("persona_alignment", None)
    return collected
# Rater instructions shown inside the "Instructions" expander.
_INSTRUCTIONS_HTML = '''<p style='font-size:large;'>You will be <mark>given a user profile and a travel-related query</mark>. Your task is to <mark>evaluate the generated queries (numbered 1-6)</mark> based on the following criteria:</p>
<p><strong><mark>Relevance</mark>:</strong> Evaluate how well the query aligns with the given cities, filters, and displayed context. Consider whether the query description matches the cities and context provided.
<br> Select one of the following options:
<ol style="padding-left:2rem;">
<li><b>Not Relevant</b> - The query has no connection to the cities, filters, or displayed context.</li>
<li><b>Somewhat Relevant</b> - The query is partially related but does not fully match the cities or context.</li>
<li><b>Relevant</b> - The query clearly aligns with the cities, filters, and displayed context.</li>
<li><b>Unclear</b> - The relevance of the query is difficult to determine based on the given information.</li>
</ol>
</p>
<p><strong><mark>Clarity Assessment</mark>:</strong> Evaluate how clear and understandable the query is. Consider whether it is grammatically correct and easy to interpret.
<br>Your options are:
<ol style="padding-left:2rem;">
<li><b>Not Clear</b> - The query is difficult to understand or contains significant grammatical errors.</li>
<li><b>Somewhat Clear</b> - The query is understandable but may have minor grammatical issues or slight ambiguity.</li>
<li><b>Very Clear</b> - The query is well-formed, grammatically correct, and easy to understand.</li>
</ol>
</p>
<p>
<strong><mark>Persona Alignment</mark>:</strong> How likely is the query to match the persona and reflect a question they would
ask about travel? <br>Your options are: <ol style='padding-left:2rem;'> <li><b>Not Aligned</b> - The user
is not likely at all to ask this query.</li> <li><b>Partially Aligned</b> - The user is quite likely to
ask this query.</li> <li><b>Aligned</b> - The user is very likely to ask this query. </li>
<li><b>Unclear</b> - It is unclear whether the user will ask this query.</li> </ol> </p>'''


def questions_screen(data):
    """Display the questions screen for the current survey item.

    Shows progress, the rater instructions, the context (persona, filters,
    cities, full context) and the rating widgets for the "gemini" and
    "llama" queries. It then builds a ``Response`` from the current widget
    values, stores it in ``st.session_state.ratings`` and
    ``st.session_state.responses``, and finally renders the navigation
    buttons.

    Args:
        data: Table of survey items supporting ``len(data)`` and
            ``data.iloc[i]``; each row is read for "config_id", "persona",
            "filters", "city", "context" and the per-model query columns.
    """
    current_index = st.session_state.current_index
    # NOTE(review): this handler also swallows an IndexError raised anywhere
    # below, not only the out-of-range row lookup -- confirm that is intended.
    try:
        config = data.iloc[current_index]

        # Progress bar
        total_questions = len(data)
        st.progress((current_index + 1) / total_questions)
        st.write(f"Question {current_index + 1} of {total_questions}")

        st.markdown("### Instructions")
        with st.expander("Instructions", expanded=False):
            st.markdown(
                """
                """
            )
            st.html(_INSTRUCTIONS_HTML)

        # Context information
        st.markdown("### Context Information")
        with st.expander("Persona", expanded=True):
            st.write(config["persona"])
        with st.expander("Filters & Cities", expanded=True):
            st.write("**Filters:**", config["filters"])
            st.write("**Cities:**", config["city"])
        with st.expander("Full Context", expanded=False):
            # Widgets need a non-empty label; it is hidden because the
            # expander title already names the content.
            st.text_area(
                "Full Context",
                config["context"],
                height=300,
                disabled=False,
                label_visibility="collapsed",
            )

        g_ratings = display_ratings_row("gemini", config, current_index)
        l_ratings = display_ratings_row("llama", config, current_index)

        # Additional comments
        comment = st.text_area("Additional Comments (Optional):")

        # Collecting the response data
        response = Response(
            config_id=config["config_id"],
            model_ratings={
                "gemini": ModelRatings(
                    query_v_ratings=g_ratings["query_v_ratings"],
                    query_p0_ratings=g_ratings["query_p0_ratings"],
                    query_p1_ratings=g_ratings["query_p1_ratings"],
                ),
                "llama": ModelRatings(
                    query_v_ratings=l_ratings["query_v_ratings"],
                    query_p0_ratings=l_ratings["query_p0_ratings"],
                    query_p1_ratings=l_ratings["query_p1_ratings"],
                ),
            },
            comment=comment,
            timestamp=datetime.now().isoformat(),
        )

        # Response may be subscriptable or attribute-based; try both.
        try:
            model_ratings = response["model_ratings"]
        except TypeError:
            model_ratings = response.model_ratings
        st.session_state.ratings[current_index] = model_ratings

        # Overwrite the stored response for this index, or append a new one.
        if len(st.session_state.responses) > current_index:
            st.session_state.responses[current_index] = response
        else:
            st.session_state.responses.append(response)

        # Navigation buttons
        navigation_buttons(data, response)
    except IndexError:
        print("Survey completed!")