Spaces:
Running
Running
File size: 9,229 Bytes
29cc4c5 90cb4f4 29cc4c5 90cb4f4 0759822 876e6bb 0960577 876e6bb 46dae9a dd3763f 46dae9a 4f7c053 46dae9a 4f7c053 46dae9a 4f7c053 46dae9a 4f7c053 46dae9a 4f7c053 46dae9a 4f7c053 0960577 4f7c053 0960577 4f7c053 0960577 4f7c053 876e6bb 0759822 90cb4f4 0759822 876e6bb 0759822 0960577 0759822 0960577 0759822 29cc4c5 0960577 29cc4c5 0960577 29cc4c5 0759822 90cb4f4 46dae9a 0759822 29cc4c5 0759822 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 |
from db.schema import Response, ModelRatings
import streamlit as st
from datetime import datetime
from dotenv import load_dotenv
from views.nav_buttons import navigation_buttons
# Load variables from a local .env file into the process environment
# (python-dotenv). Runs once, when this module is first imported.
load_dotenv()
def survey_completed():
    """Show the "already completed" exit screen and mark the session as finished.

    Writes a raw-HTML thank-you message with ``st.markdown`` and then sets
    ``show_questions=False``, ``completed=True`` and
    ``start_new_survey=True`` on ``st.session_state``.
    """
    exit_html = """
<div class='exit-container'>
<h1>You have already completed the survey! Thank you for participating!</h1>
<p>Your responses have been saved successfully.</p>
<p>You can safely close this window or start a new survey.</p>
</div>
"""
    st.markdown(exit_html, unsafe_allow_html=True)

    state = st.session_state
    state.show_questions = False
    state.completed = True
    state.start_new_survey = True
def display_ratings_row(model_name, config, current_index):
    """Render one model's three query cards side by side and collect their ratings.

    Each card is drawn by ``render_query_ratings`` in its own column, in the
    order Query_v, Query_p0, Query_p1. Only the two persona queries ask for
    persona alignment.

    Returns:
        dict with keys "query_v_ratings", "query_p0_ratings" and
        "query_p1_ratings". The Query_v entry never contains a
        "persona_alignment" key.
    """
    st.markdown(f"## {model_name.capitalize()} Ratings")

    # (card heading, config/widget key suffix, ask persona alignment?)
    layout = (
        ("Query_v", "query_v", False),
        ("Query_p0", "query_p0", True),
        ("Query_p1", "query_p1", True),
    )
    collected = {}
    for column, (heading, suffix, wants_persona) in zip(st.columns(3), layout):
        with column:
            collected[f"{suffix}_ratings"] = render_query_ratings(
                model_name,
                heading,
                config,
                f"{model_name}_{suffix}",
                current_index,
                has_persona_alignment=wants_persona,
            )

    # Query_v has no persona question, so drop its placeholder entry.
    collected["query_v_ratings"].pop("persona_alignment", None)
    return collected
def _read_field(source, name, default=None):
    """Read ``name`` from a schema object (attribute) or a dict-like (key).

    Saved ratings may be model objects or plain dicts. A dict missing the key
    yields ``default`` instead of raising KeyError.
    """
    try:
        return getattr(source, name)
    except AttributeError:
        return source.get(name, default)


def _stored_index(stored_ratings, field, option_count):
    """Return a safe ``st.radio`` index for a previously stored rating.

    Falls back to 0 ("N/A") in three cases: nothing is stored, the value is
    None, or the value lies outside ``range(option_count)``. ``st.radio``
    raises on an out-of-range index.
    """
    if not stored_ratings:
        return 0
    value = stored_ratings.get(field, 0)
    if value is None or value not in range(option_count):
        return 0
    return int(value)


def _previous_model_ratings(model_key, config, current_index):
    """Look up the ratings previously recorded for one model on this question.

    Returns a falsy value ({} or None) when nothing has been recorded yet.
    """
    state = st.session_state
    responses = state.responses
    # Rendering a question before the session's current one: use the
    # previous_ratings cache keyed by config_id.
    # NOTE(review): questions_screen passes st.session_state.current_index as
    # current_index, so this branch looks unreachable from that caller.
    if current_index < state.current_index and len(responses) > current_index:
        if not state.previous_ratings:
            return {}
        return state.previous_ratings.get(config["config_id"], {}).get(model_key, None)
    # Nothing saved for this question yet (first visit, or just pressed continue).
    if len(responses) <= current_index:
        return {}
    # The user already answered on this page: reuse the response saved this session.
    saved_model_ratings = _read_field(responses[current_index], "model_ratings", {})
    return saved_model_ratings.get(model_key, {})


def render_query_ratings(model_name, query_label, config, query_key, current_index, has_persona_alignment=False):
    """Render one query card with its rating radios and return the chosen values.

    Args:
        model_name: "gemini" selects the Gemini ratings and card colour. Any
            other value is treated as "llama".
        query_label: heading shown on the card (e.g. "Query_v").
        config: the current survey row. It must contain "config_id" and
            ``query_key``.
        query_key: the ``config`` column holding the query text. It is also
            part of every widget key. The substring it contains
            ("query_v" / "query_p0" / "query_p1") selects which stored
            ratings are restored.
        current_index: question index, included in every widget key so each
            question gets its own widgets.
        has_persona_alignment: whether to show the persona-alignment radio.

    Returns:
        dict with "clarity", "relevance" and "persona_alignment". The
        "persona_alignment" value is None when ``has_persona_alignment`` is
        False.
    """
    persona_labels = ("N/A", "Not Aligned", "Partially Aligned", "Aligned", "Unclear")
    relevance_labels = ("N/A", "Not Relevant", "Somewhat Relevant", "Relevant", "Unclear")
    clarity_labels = ("N/A", "Not Clear", "Somewhat Clear", "Very Clear")
    options = [0, 1, 2, 3, 4]
    clarity_options = [0, 1, 2, 3]

    model_key = "gemini" if model_name == "gemini" else "llama"
    previous_ratings = _previous_model_ratings(model_key, config, current_index)

    # Pick the stored sub-ratings for this query ("query_v" is checked first).
    stored_query_ratings = {}
    if previous_ratings:
        for suffix in ("query_v", "query_p0", "query_p1"):
            if suffix in query_key:
                stored_query_ratings = _read_field(previous_ratings, f"{suffix}_ratings", {})
                break

    stored_relevance = _stored_index(stored_query_ratings, "relevance", len(options))
    stored_clarity = _stored_index(stored_query_ratings, "clarity", len(clarity_options))
    stored_persona_alignment = (
        _stored_index(stored_query_ratings, "persona_alignment", len(options))
        if has_persona_alignment
        else 0
    )

    bg_color = "#e0f7fa" if model_name == "gemini" else "#f0f4c3"
    with st.container():
        # NOTE(review): config[query_key] is inserted into raw HTML unescaped;
        # consider html.escape() if query text can contain '<' or '&'.
        st.markdown(f"""
<div style="background-color:{bg_color}; padding:1rem;">
<h3 style="color:blue;"> {query_label} </h3>
<p style="text-align:left;">{config[query_key]}</p>
</div>
""", unsafe_allow_html=True)
        columns = st.columns(3)
        persona_alignment_rating = None
        if has_persona_alignment:
            with columns[0]:
                persona_alignment_rating = st.radio(
                    "Persona Alignment:",
                    options=options,
                    format_func=lambda x: persona_labels[x],
                    key=f"rating_{query_key}_persona_alignment_{current_index}",
                    index=stored_persona_alignment,
                )
        with columns[1]:
            relevance_rating = st.radio(
                "Relevance:",
                options,
                key=f"rating_{query_key}_relevance_{current_index}",
                format_func=lambda x: relevance_labels[x],
                index=stored_relevance,
            )
        with columns[2]:
            clarity_rating = st.radio(
                "Clarity:",
                options=clarity_options,
                key=f"rating_{query_key}_clarity_{current_index}",
                format_func=lambda x: clarity_labels[x],
                index=stored_clarity,
            )
    return {
        "clarity": clarity_rating,
        "relevance": relevance_rating,
        "persona_alignment": persona_alignment_rating,
    }
def questions_screen(data):
    """Display the questions screen with split layout.

    Renders the survey row at ``st.session_state.current_index``, looked up
    positionally with ``data.iloc``. Collects Gemini and Llama ratings plus an
    optional comment into a ``Response`` and stores it in
    ``st.session_state.ratings`` and ``st.session_state.responses``. Then
    renders the navigation buttons.

    If the index is past the last row, nothing is rendered and
    "Survey completed!" is printed to stdout.
    """
    current_index = st.session_state.current_index
    # Only the row lookup means "past the last question". Keeping the try this
    # narrow stops IndexErrors raised while rendering (e.g. by the rating
    # widgets) from being silently reported as survey completion.
    try:
        config = data.iloc[current_index]
    except IndexError:
        print("Survey completed!")
        return

    total_questions = len(data)
    # Progress bar
    st.progress((current_index + 1) / total_questions)
    st.write(f"Question {current_index + 1} of {total_questions}")
    st.subheader(f"Config ID: {config['config_id']}")

    # Context information
    st.markdown("### Context Information")
    with st.expander("Persona", expanded=True):
        st.write(config['persona'])
    with st.expander("Filters & Cities", expanded=True):
        st.write("**Filters:**", config['filters'])
        st.write("**Cities:**", config['city'])
    with st.expander("Full Context", expanded=False):
        # NOTE(review): empty label and disabled=False (editable, edits are
        # never read) are kept as in the original.
        st.text_area("", config['context'], height=300, disabled=False)

    g_ratings = display_ratings_row("gemini", config, current_index)
    l_ratings = display_ratings_row("llama", config, current_index)

    # Additional comments. Keyed per question: without a key this widget has
    # the same identity on every question, so a comment typed on one question
    # carried over into the next. The saved comment is restored on revisit,
    # mirroring how the rating radios restore their stored values.
    saved_comment = ""
    if len(st.session_state.responses) > current_index:
        saved_response = st.session_state.responses[current_index]
        try:
            saved_comment = saved_response.comment or ""
        except AttributeError:
            saved_comment = saved_response.get("comment") or ""
    comment = st.text_area(
        "Additional Comments (Optional):",
        value=saved_comment,
        key=f"comment_{current_index}",
    )

    # Collecting the response data
    response = Response(
        config_id=config["config_id"],
        model_ratings={
            "gemini": ModelRatings(
                query_v_ratings=g_ratings["query_v_ratings"],
                query_p0_ratings=g_ratings["query_p0_ratings"],
                query_p1_ratings=g_ratings["query_p1_ratings"],
            ),
            "llama": ModelRatings(
                query_v_ratings=l_ratings["query_v_ratings"],
                query_p0_ratings=l_ratings["query_p0_ratings"],
                query_p1_ratings=l_ratings["query_p1_ratings"],
            ),
        },
        comment=comment,
        timestamp=datetime.now().isoformat(),
    )
    print(response)

    # Response may be subscriptable (dict-like) or a plain schema object.
    try:
        st.session_state.ratings[current_index] = response["model_ratings"]
    except TypeError:
        st.session_state.ratings[current_index] = response.model_ratings

    # Overwrite the saved answer for this question, or record it for the first time.
    if len(st.session_state.responses) > current_index:
        st.session_state.responses[current_index] = response
    else:
        st.session_state.responses.append(response)

    # Navigation buttons
    navigation_buttons(data, response)
|