|
import streamlit as st |
|
import pandas as pd |
|
import plotly.graph_objects as go |
|
from typing import List, Optional |
|
|
|
from ..core.glicko2_ranking import analyze_device_glicko2_matches |
|
from ..components.visualizations import clean_device_id |
|
|
|
|
|
def _battle_side_trace(display_name, categories, hover_labels, wins, pcts, sign, rgb):
    """Build one device's horizontal bars for the head-to-head chart.

    Args:
        display_name: Legend / hover name of the device.
        categories: Y-axis category labels, one per bar.
        hover_labels: Short metric names used in the hover text.
        wins: Win counts, aligned with ``categories``.
        pcts: Unsigned win percentages, aligned with ``categories``.
        sign: ``1`` to draw bars to the right of the centre line, ``-1`` to
            mirror them to the left.
        rgb: ``"r, g, b"`` colour triple; the fill uses alpha 0.8 and the
            outline alpha 1.0.
    """
    return go.Bar(
        y=categories,
        # Only the bar geometry is signed; labels always show the magnitude.
        x=[sign * pct for pct in pcts],
        name=display_name,
        orientation="h",
        marker=dict(
            color=f"rgba({rgb}, 0.8)",
            line=dict(color=f"rgba({rgb}, 1.0)", width=2),
        ),
        text=[f"{pct:.1f}%" for pct in pcts],
        textposition="inside",
        insidetextanchor="middle",
        hoverinfo="text",
        hovertext=[
            f"{display_name}<br>{label} Wins: {won} ({pct:.1f}%)"
            for label, won, pct in zip(hover_labels, wins, pcts)
        ],
        width=0.5,
    )


def create_head_to_head_battle_chart(
    device1: str,
    device2: str,
    device1_display: str,
    device2_display: str,
    token_wins_1: int,
    prompt_wins_1: int,
    combined_wins_1: int,
    total_matches: int,
):
    """Create a mirrored horizontal bar chart of head-to-head win rates.

    Device 1's win percentages are drawn to the right of a central zero line
    and device 2's to the left, for three categories: token generation,
    prompt processing and combined.

    Args:
        device1: ID of the first device (not used by the chart itself).
        device2: ID of the second device (not used by the chart itself).
        device1_display: Human-readable name of the first device.
        device2_display: Human-readable name of the second device.
        token_wins_1: Token-generation matches won by device 1.
        prompt_wins_1: Prompt-processing matches won by device 1.
        combined_wins_1: Combined matches won by device 1.
        total_matches: Total number of matches played.

    Returns:
        A ``plotly.graph_objects.Figure``.

    Note:
        Device 2's wins are derived as ``total_matches - wins_1``, i.e. every
        match is attributed to one of the two devices.
        NOTE(review): confirm the match data cannot contain ties.
    """
    categories = ["Token Gen", "Prompt Proc", "Combined"]
    hover_labels = ["Token", "Prompt", "Combined"]

    wins_1 = [token_wins_1, prompt_wins_1, combined_wins_1]
    wins_2 = [total_matches - won for won in wins_1]

    if total_matches > 0:
        pcts_1 = [won / total_matches * 100 for won in wins_1]
        pcts_2 = [100 - pct for pct in pcts_1]
    else:
        # No matches: avoid ZeroDivisionError and show empty bars for both
        # sides rather than crediting device 2 with a 100% win rate.
        pcts_1 = [0.0] * len(wins_1)
        pcts_2 = [0.0] * len(wins_1)

    fig = go.Figure()

    fig.add_trace(
        _battle_side_trace(
            device1_display, categories, hover_labels, wins_1, pcts_1, 1, "58, 71, 180"
        )
    )
    fig.add_trace(
        _battle_side_trace(
            device2_display, categories, hover_labels, wins_2, pcts_2, -1, "231, 99, 99"
        )
    )

    # Solid centre line separating the two devices' halves of the chart.
    fig.add_shape(
        type="line",
        x0=0,
        y0=-0.5,
        x1=0,
        y1=2.5,
        line=dict(color="black", width=2, dash="solid"),
    )

    fig.update_layout(
        title=dict(
            text=f"⚔️ {device1_display} vs {device2_display} ⚔️",
            font=dict(size=24, family="Arial Black"),
            x=0.5,
        ),
        barmode="overlay",
        bargap=0.15,
        bargroupgap=0.1,
        legend=dict(x=0.5, y=1.05, xanchor="center", orientation="h"),
        xaxis=dict(
            title="Win Rate (%)",
            range=[-100, 100],
            tickvals=[-100, -75, -50, -25, 0, 25, 50, 75, 100],
            # Tick labels are unsigned so the mirrored (negative) half still
            # reads as a positive percentage.
            ticktext=["100%", "75%", "50%", "25%", "0%", "25%", "50%", "75%", "100%"],
            zeroline=True,
            zerolinewidth=2,
            zerolinecolor="black",
        ),
        yaxis=dict(title="", autorange="reversed"),
        plot_bgcolor="rgba(240, 240, 240, 0.8)",
        height=400,
        margin=dict(l=20, r=20, t=80, b=20),
    )

    return fig
|
|
|
|
|
def create_victory_badge(winner_device: str, loser_device: str, win_percentage: float):
    """Create a stylized victory badge as an HTML snippet.

    The badge colour and headline depend on ``win_percentage``:

    * ``>= 75``: gold, "DOMINANT VICTORY"
    * ``>= 50``: silver, "CLEAR WINNER"
    * otherwise: bronze, "NARROW VICTORY"

    Args:
        winner_device: Name of the winning device shown on the badge.
        loser_device: Name of the defeated device shown on the badge.
        win_percentage: Winner's win rate in percent, rendered with one decimal.

    Returns:
        An HTML string intended for rendering with raw HTML enabled.
    """
    # Function-scope import keeps this block self-contained.
    from html import escape

    if win_percentage >= 75:
        badge_color, badge_text = "#FFD700", "DOMINANT VICTORY"
    elif win_percentage >= 50:
        badge_color, badge_text = "#C0C0C0", "CLEAR WINNER"
    else:
        # NOTE(review): a winner chosen by majority always has >= 50%, so this
        # tier may be unreachable from such callers - confirm the thresholds.
        badge_color, badge_text = "#CD7F32", "NARROW VICTORY"

    # The snippet is rendered as raw HTML, so device names must not be able
    # to inject markup.
    winner_html = escape(str(winner_device))
    loser_html = escape(str(loser_device))

    return f"""
    <div style="display: flex; justify-content: center; margin: 20px 0;">
        <div style="
            background: linear-gradient(135deg, {badge_color} 0%, #FFFFFF 50%, {badge_color} 100%);
            border-radius: 16px;
            padding: 20px;
            box-shadow: 0 4px 8px rgba(0,0,0,0.2);
            text-align: center;
            border: 2px solid {badge_color};
            max-width: 90%;
        ">
            <div style="font-size: 24px; font-weight: bold; margin-bottom: 8px; font-family: 'Arial Black', sans-serif;">
                🏆 {badge_text} 🏆
            </div>
            <div style="font-size: 18px; font-weight: bold; color: #333;">
                {winner_html}
            </div>
            <div style="font-size: 14px; margin: 8px 0;">
                defeated
            </div>
            <div style="font-size: 16px; color: #555;">
                {loser_html}
            </div>
            <div style="font-size: 20px; font-weight: bold; margin-top: 8px; color: #333;">
                {win_percentage:.1f}% Win Rate
            </div>
        </div>
    </div>
    """
|
|
|
|
|
def create_model_performance_chart(
    matches_df, device1, device2, device1_display, device2_display, top_n=8
):
    """Create a per-model performance comparison chart for two devices.

    Models are listed on the y-axis. The left panel shows mean token
    generation speed and the right panel mean prompt processing speed, with
    the two devices drawn as side-by-side horizontal bars.

    Args:
        matches_df: Match data. Must contain "Model", "Token Generation 1/2"
            and "Prompt Processing 1/2" columns. "Model File Size" is
            optional and, when present, orders models largest-first.
        device1: ID of the first device (not used by the chart itself).
        device2: ID of the second device (not used by the chart itself).
        device1_display: Human-readable name of the first device.
        device2_display: Human-readable name of the second device.
        top_n: Maximum number of models to show.

    Returns:
        A ``plotly.graph_objects.Figure``, or ``None`` when a required column
        is missing or there is no data to plot.
    """
    metric_cols = [
        "Token Generation 1",
        "Token Generation 2",
        "Prompt Processing 1",
        "Prompt Processing 2",
    ]
    size_col = "Model File Size"

    if any(col not in matches_df.columns for col in ["Model", *metric_cols]):
        return None
    if matches_df.empty:
        return None

    # File size is only used for ordering, so its absence should not prevent
    # the chart from being drawn.
    has_size = size_col in matches_df.columns
    agg_spec = {col: "mean" for col in metric_cols}
    if has_size:
        agg_spec[size_col] = "first"

    grouped = matches_df.groupby("Model").agg(agg_spec).reset_index()
    if grouped.empty:
        return None

    if has_size:
        grouped = grouped.sort_values(size_col, ascending=False)
    grouped = grouped.head(top_n)

    models = grouped["Model"].tolist()

    device1_rgb = "58, 71, 180"
    device2_rgb = "231, 99, 99"

    # (value column, trace name, colour, legend group, offset group, x-axis,
    #  show in legend). Prompt-processing traces reuse the device legend
    #  groups and are hidden from the legend to avoid duplicate entries.
    trace_specs = [
        (
            "Token Generation 1",
            f"{device1_display} Token Gen",
            f"rgba({device1_rgb}, 0.8)",
            "device1",
            1,
            "x",
            True,
        ),
        (
            "Token Generation 2",
            f"{device2_display} Token Gen",
            f"rgba({device2_rgb}, 0.8)",
            "device2",
            2,
            "x",
            True,
        ),
        (
            "Prompt Processing 1",
            f"{device1_display} Prompt Proc",
            f"rgba({device1_rgb}, 0.4)",
            "device1",
            1,
            "x2",
            False,
        ),
        (
            "Prompt Processing 2",
            f"{device2_display} Prompt Proc",
            f"rgba({device2_rgb}, 0.4)",
            "device2",
            2,
            "x2",
            False,
        ),
    ]

    fig = go.Figure()
    for column, name, color, legend_group, offset_group, axis, in_legend in trace_specs:
        fig.add_trace(
            go.Bar(
                x=grouped[column].tolist(),
                y=models,
                name=name,
                orientation="h",
                marker=dict(color=color),
                hovertemplate="%{y}<br>%{x:.2f} tokens/sec<extra></extra>",
                legendgroup=legend_group,
                offsetgroup=offset_group,
                xaxis=axis,
                showlegend=in_legend,
            )
        )

    fig.update_layout(
        title_text="📊 Performance Breakdown by Model",
        grid=dict(rows=1, columns=2, pattern="independent"),
        legend=dict(orientation="h", yanchor="bottom", y=1.12, xanchor="right", x=1),
        # Grow the figure with the number of models so bars stay readable.
        height=max(350, 50 * len(models) + 120),
        margin=dict(l=20, r=20, t=80, b=50),
        xaxis=dict(
            title="Token Generation (tokens/sec)", side="bottom", domain=[0, 0.48]
        ),
        xaxis2=dict(
            title="Prompt Processing (tokens/sec)", side="bottom", domain=[0.52, 1]
        ),
        yaxis=dict(title="", autorange="reversed"),
    )

    # Dashed divider between the two panels (paper coordinates).
    fig.add_shape(
        type="line",
        x0=0.5,
        y0=0,
        x1=0.5,
        y1=1,
        xref="paper",
        yref="paper",
        line=dict(color="rgba(0,0,0,0.2)", width=1, dash="dash"),
    )

    # Panel headings.
    # NOTE(review): these reuse the device 1 / device 2 bar colours although
    # they label metrics, not devices - confirm this is intentional.
    panel_headings = [
        (0.4, "right", "Token Generation", f"rgba({device1_rgb}, 1.0)"),
        (0.6, "left", "Prompt Processing", f"rgba({device2_rgb}, 1.0)"),
    ]
    for x_pos, anchor, label, color in panel_headings:
        fig.add_annotation(
            x=x_pos,
            y=1.08,
            xanchor=anchor,
            xref="paper",
            yref="paper",
            text=label,
            showarrow=False,
            font=dict(
                size=14,
                color=color,
                family="Arial, sans-serif",
                weight="bold",
            ),
        )

    fig.update_yaxes(
        tickfont=dict(size=12, family="Arial, sans-serif"), gridcolor="rgba(0,0,0,0.05)"
    )

    return fig
|
|
|
|
|
# Columns that must all be present before any battle result can be computed.
_WINNER_COLUMNS = ("Token Winner", "Prompt Winner", "Combined Winner")

# Columns shown in the "detailed results" table, in display order. Columns
# missing from the match data are skipped.
_MATCH_DETAIL_COLUMNS = [
    "Model",
    "Token Generation 1",
    "Token Generation 2",
    "Token Winner",
    "Token Win Prob",
    "Prompt Processing 1",
    "Prompt Processing 2",
    "Prompt Winner",
    "Prompt Win Prob",
    "Combined Winner",
    "Combined Win Prob",
    "Platform 1",
    "Platform 2",
]


def _render_match_details(matches_df: pd.DataFrame, name1: str, name2: str) -> None:
    """Render the collapsible table of per-match results."""
    with st.expander("View Detailed Match Results", expanded=False):
        st.markdown("#### All Match Data")

        valid_cols = [col for col in _MATCH_DETAIL_COLUMNS if col in matches_df.columns]
        if not valid_cols:
            st.warning("No valid columns found for display in the match data.")
            return

        # Round floats for readability; non-numeric columns are left untouched.
        matches_display = matches_df[valid_cols].round(2)

        # Replace the generic "1"/"2" suffixes with the device display names.
        rename_mapping = {
            "Token Generation 1": f"{name1} Token Gen",
            "Token Generation 2": f"{name2} Token Gen",
            "Prompt Processing 1": f"{name1} Prompt Proc",
            "Prompt Processing 2": f"{name2} Prompt Proc",
            "Platform 1": f"{name1} Platform",
            "Platform 2": f"{name2} Platform",
            "Token Win Prob": "Device 1 Token Win Prob",
            "Prompt Win Prob": "Device 1 Prompt Win Prob",
            "Combined Win Prob": "Device 1 Combined Win Prob",
        }
        # rename() ignores mapping keys that are not present in the frame.
        matches_display = matches_display.rename(columns=rename_mapping)

        st.dataframe(
            matches_display,
            use_container_width=True,
            height=400,
        )


def _render_battle_results(
    matches_df: pd.DataFrame, device1: str, device2: str, device_display_names: dict
) -> None:
    """Render the summary, verdict, charts and detail table for one duel.

    Expects a non-empty ``matches_df`` that contains every column listed in
    ``_WINNER_COLUMNS``; the winner columns are compared against device IDs.
    """
    name1 = device_display_names[device1]
    name2 = device_display_names[device2]
    total_matches = len(matches_df)

    token_wins_1 = int((matches_df["Token Winner"] == device1).sum())
    prompt_wins_1 = int((matches_df["Prompt Winner"] == device1).sum())
    combined_wins_1 = int((matches_df["Combined Winner"] == device1).sum())
    # Every match not won by device 1 is credited to device 2, matching how
    # the battle chart derives device 2's share.
    # NOTE(review): confirm the winner columns never hold tie/other values.
    combined_wins_2 = total_matches - combined_wins_1

    st.markdown(
        f"""
        <div style="text-align: center; padding: 10px; background-color: #f8f9fa;
                    border-radius: 5px; margin: 10px 0; border: 1px solid #dee2e6;">
            <span style="font-size: 16px; font-weight: bold;">Total Matches: {total_matches}</span>
        </div>
        """,
        unsafe_allow_html=True,
    )

    if combined_wins_1 == combined_wins_2:
        # An even split has no winner; do not hand the victory to device 2.
        st.info(
            f"🤝 It's a draw! {name1} and {name2} each won "
            f"{combined_wins_1} of {total_matches} combined matchups."
        )
    else:
        if combined_wins_1 > combined_wins_2:
            winner_display, loser_display = name1, name2
            winner_wins = combined_wins_1
        else:
            winner_display, loser_display = name2, name1
            winner_wins = combined_wins_2

        win_percentage = winner_wins / total_matches * 100
        st.markdown(
            create_victory_badge(winner_display, loser_display, win_percentage),
            unsafe_allow_html=True,
        )

    battle_fig = create_head_to_head_battle_chart(
        device1,
        device2,
        name1,
        name2,
        token_wins_1,
        prompt_wins_1,
        combined_wins_1,
        total_matches,
    )
    st.plotly_chart(battle_fig, use_container_width=True)

    model_performance_chart = create_model_performance_chart(
        matches_df,
        device1,
        device2,
        name1,
        name2,
    )
    # The chart builder returns None when the required columns are missing.
    if model_performance_chart is not None:
        st.plotly_chart(model_performance_chart, use_container_width=True)

    _render_match_details(matches_df, name1, name2)


def render_device_comparison(df: pd.DataFrame, normalized_device_ids: List[str]):
    """
    Render a component for comparing two devices and analyzing their matches.

    The user picks two devices and presses the start button; the matches
    between them are then analyzed and shown as a verdict, a head-to-head
    win-rate chart, a per-model performance chart and a detail table.

    Args:
        df: DataFrame containing benchmark data
        normalized_device_ids: List of normalized device IDs to select from
    """
    st.title("⚔️ Device Duel Arena")

    st.markdown(
        """
        <div style="text-align: center; padding: 10px; margin-bottom: 20px;
                    background: linear-gradient(135deg, #f6f8fa 0%, #e9ecef 100%);
                    border-radius: 10px; border: 1px solid #dee2e6;">
            <p style="font-size: 16px; font-style: italic; color: #495057;">
                Welcome to the arena where devices face off in direct comparison!
                Choose any two and see how they stack up.
            </p>
        </div>
        """,
        unsafe_allow_html=True,
    )

    # Map raw IDs to display names once; options are sorted by display name.
    device_display_names = {
        device_id: clean_device_id(device_id) for device_id in normalized_device_ids
    }
    sorted_device_ids = sorted(
        normalized_device_ids, key=lambda x: device_display_names[x].lower()
    )

    # CSS for the coloured headers above the two device pickers.
    st.markdown(
        """
        <style>
        .device-select-header {
            font-weight: bold;
            font-size: 18px;
            margin-bottom: 10px;
            text-align: center;
            padding: 5px;
            border-radius: 5px;
        }
        .device1-header {
            background-color: rgba(58, 71, 180, 0.2);
            border-left: 4px solid rgba(58, 71, 180, 1.0);
        }
        .device2-header {
            background-color: rgba(231, 99, 99, 0.2);
            border-left: 4px solid rgba(231, 99, 99, 1.0);
        }
        </style>
        """,
        unsafe_allow_html=True,
    )

    col1, vs_col, col2 = st.columns([0.45, 0.1, 0.45])

    with vs_col:
        st.markdown(
            """
            <div style="display: flex; height: 100%; align-items: center; justify-content: center;">
                <div style="font-size: 24px; font-weight: bold; color: #555;">VS</div>
            </div>
            """,
            unsafe_allow_html=True,
        )

    with col1:
        st.markdown(
            '<div class="device-select-header device1-header">CHALLENGER</div>',
            unsafe_allow_html=True,
        )
        device1 = st.selectbox(
            "First Device",
            options=sorted_device_ids,
            format_func=lambda x: device_display_names[x],
            key="device_compare_1",
            index=None,
            placeholder="Select a device ...",
        )

    with col2:
        st.markdown(
            '<div class="device-select-header device2-header">OPPONENT</div>',
            unsafe_allow_html=True,
        )
        device2 = st.selectbox(
            "Second Device",
            options=sorted_device_ids,
            format_func=lambda x: device_display_names[x],
            key="device_compare_2",
            index=None,
            placeholder="Select a device ...",
        )

    # Centre the start button using two spacer columns.
    _, button_col, _ = st.columns([0.3, 0.4, 0.3])
    with button_col:
        duel_button = st.button(
            "Start",
            key="analyze_matches_btn",
            use_container_width=True,
        )

    if not duel_button:
        return

    if not device1 or not device2:
        st.error("Please select two devices to battle!")
        return
    if device1 == device2:
        st.error("Please select two different devices to compare.")
        return

    st.markdown(
        """
        <div style="text-align: center; margin: 20px 0;">
            <div style="font-size: 24px; font-weight: bold; color: #333;">⚔️ BATTLE RESULTS ⚔️</div>
            <div style="height: 4px; background: linear-gradient(90deg, rgba(58,71,180,1) 0%, rgba(231,99,99,1) 100%); margin: 10px 0;"></div>
        </div>
        """,
        unsafe_allow_html=True,
    )

    name1 = device_display_names[device1]
    name2 = device_display_names[device2]

    with st.spinner(f"⚔️ Battle in progress between {name1} and {name2}..."):
        # UI boundary: any failure in analysis or rendering is reported to the
        # user instead of aborting the page.
        try:
            matches_df = analyze_device_glicko2_matches(df, device1, device2)

            if matches_df.empty:
                st.error(f"No matches found between {name1} and {name2}.")
                st.info(
                    "Try selecting different devices or checking if they both have benchmark data for the same models."
                )
            elif not all(col in matches_df.columns for col in _WINNER_COLUMNS):
                st.warning("Winner information is missing from the match data.")
            else:
                _render_battle_results(
                    matches_df, device1, device2, device_display_names
                )
        except Exception as e:
            st.error(f"An error occurred during match analysis: {str(e)}")
            st.info("Please try with different devices.")
|
|