# npb_data_viz_demo_backup / gradio_function.py
# Page header: patrickramos · "Update app" · commit 26c325e · 10.3 kB
import pandas as pd
import numpy as np
from tqdm.auto import tqdm
import plotly.express as px
import plotly.graph_objects as go
import plotly.colors as pc
from scipy.stats import gaussian_kde
import numpy as np
import gradio as gr
from gradio_client import Client
from scipy.stats import gaussian_kde
import numpy as np
import os
import re
from translate import translate_pa_outcome, translate_pitch_outcome, jp_pitch_to_en_pitch, jp_pitch_to_pitch_code, translate_pitch_outcome, max_pitch_types
# load game data
# After dropping exact duplicate rows, each game_pk must appear once; a repeated
# game_pk would make the loop below read that game's PA file twice.
game_df = pd.read_csv('game.csv').drop_duplicates()
if not game_df['game_pk'].is_unique:
    raise ValueError('game.csv contains conflicting rows sharing the same game_pk')
# load pa data: one CSV per game under pa/, with pa_pk read as a string
# (pitch data is read the same way and the two are later joined on pa_pk).
# ignore_index gives the combined frame unique row labels instead of
# repeating each file's 0..n-1 index.
pa_df = pd.concat(
    [
        pd.read_csv(os.path.join('pa', f'{game_pk}.csv'), dtype={'pa_pk': str})
        for game_pk in tqdm(game_df['game_pk'])
    ],
    axis='rows',
    ignore_index=True,
)
# load pitch data: one CSV per game under pitch/, with pa_pk read as a string
# so it matches the PA data's pa_pk when the two are joined.
# ignore_index gives the combined frame unique row labels instead of
# repeating each file's 0..n-1 index.
pitch_df = pd.concat(
    [
        pd.read_csv(os.path.join('pitch', f'{game_pk}.csv'), dtype={'pa_pk': str})
        for game_pk in tqdm(game_df['game_pk'])
    ],
    axis='rows',
    ignore_index=True,
)
# load player data (the bare `player_df` expression that followed was notebook
# residue with no effect in a module and has been removed)
player_df = pd.read_csv('player.csv')
# translate pa data
# Strip surrounding whitespace; `_des` keeps the stripped original description
# before the rewrites below.
pa_df['_des'] = pa_df['des'].str.strip()
pa_df['des'] = pa_df['des'].str.strip()
pa_df['des_more'] = pa_df['des_more'].str.strip()
# Fall back to `des_more` wherever `des` is missing.
pa_df.loc[pa_df['des'].isna(), 'des'] = pa_df[pa_df['des'].isna()]['des_more']


def _drop_runs_scored_suffix(item):
    """Return the first token of a multi-token description containing a '+N点' marker.

    Other strings are returned unchanged, and non-string values (e.g. NaN) are
    passed through untouched instead of failing on ``.split()``.

    The '+' must be escaped: a pattern starting with a bare '+' is not a valid
    regular expression ("nothing to repeat"), so ``re.search`` used to raise on
    every multi-token description.
    """
    if not isinstance(item, str):
        return item
    tokens = item.split()
    if len(tokens) > 1 and re.search(r'\+\d+点', item):
        return tokens[0]
    return item


pa_df['des'] = pa_df['des'].apply(_drop_runs_scored_suffix)
# Rows whose `des` is a single-pitch result (ボール = ball, 見逃し = called
# strike, 空振り = swinging strike) or ends in 塁けん制 (pickoff throw to a base)
# take their text from `des_more` instead. na=False treats missing `des` as no match.
non_home_plate_outcome = (pa_df['des'].isin(['ボール', '見逃し', '空振り'])) | (pa_df['des'].str.endswith('塁けん制', na=False))
pa_df.loc[non_home_plate_outcome, 'des'] = pa_df.loc[non_home_plate_outcome, 'des_more']
pa_df['des'] = pa_df['des'].apply(translate_pa_outcome)
# translate pitch data
# Drop pitches with no recorded pitch name. `.copy()` makes pitch_df an
# independent frame, so the column assignments below do not write into a
# filtered slice of the old frame (which raised SettingWithCopyWarning).
pitch_df = pitch_df[pitch_df['pitch_name'].notna()].copy()
# Keep the Japanese pitch name; derive the English name and the pitch code.
pitch_df['jp_pitch_name'] = pitch_df['pitch_name']
pitch_df['pitch_name'] = pitch_df['jp_pitch_name'].apply(lambda pitch_name: jp_pitch_to_en_pitch[pitch_name])
pitch_df['pitch_type'] = pitch_df['jp_pitch_name'].apply(lambda pitch_name: jp_pitch_to_pitch_code[pitch_name])
# Keep only the first whitespace-separated token of multi-word descriptions, then translate.
pitch_df['description'] = pitch_df['description'].apply(lambda item: item.split()[0] if len(item.split()) > 1 else item)
pitch_df['description'] = pitch_df['description'].apply(translate_pitch_outcome)
# Release speed is text ending in 'km/h', or '-' when missing: map '-' to NaN
# and convert the remaining values to integer km/h.
pitch_df['release_speed'] = pitch_df['release_speed'].replace('-', np.nan)
pitch_df.loc[~pitch_df['release_speed'].isna(), 'release_speed'] = pitch_df.loc[~pitch_df['release_speed'].isna(), 'release_speed'].str.removesuffix('km/h').astype(int)
# Re-centre the plate coordinates: x is shifted by -80 and z is flipped
# (200 - ...) and shifted by -100. 80 and 100 equal half the 160 x 200 outer
# rectangle drawn on the maps. NOTE(review): the +13 offset's meaning is not
# visible here — TODO confirm.
pitch_df['plate_x'] = (pitch_df['plate_x'] + 13) - 80
pitch_df['plate_z'] = 200 - (pitch_df['plate_z'] + 13) - 100
# translate player data
# Send every Japanese player name, one per line, to the hosted
# npb_name_translator app and read back its newline-separated reply (assumed
# to be one English name per input name, in the same order — TODO confirm).
# A literal 'nan' line in the reply becomes a missing value.
client = Client("Ramos-Ramos/npb_name_translator")
en_names = client.predict(jp_names='\n'.join(player_df['name']), api_name="/predict")
player_df['jp_name'] = player_df['name']
player_df['name'] = [np.nan if translated == 'nan' else translated for translated in en_names.splitlines()]
# merge pitch and pa data
# Attach plate-appearance context to every pitch, then the pitcher's player
# record (player_df's `player_id` is renamed to match the `pitcher` key).
df = (
    pitch_df
    .merge(pa_df, how='inner', on=['game_pk', 'pa_pk'])
    .merge(player_df.rename(columns={'player_id': 'pitcher'}), how='inner', on='pitcher')
)
# Per-pitch boolean flags derived from the translated outcome codes.
# Which outcomes count as a "normal" pitch is a guess (original author's note).
df = df.assign(
    whiff=lambda frame: frame['description'].isin(['SS', 'K']),
    swing=lambda frame: ~frame['description'].isin(['B', 'BB', 'LS', 'inv_K', 'bunt_K', 'HBP', 'SH', 'SH E', 'SH FC', 'obstruction', 'illegal_pitch', 'defensive_interference']),
    csw=lambda frame: frame['description'].isin(['SS', 'K', 'LS', 'inv_K']),
    normal_pitch=lambda frame: ~frame['description'].isin(['obstruction', 'illegal_pitch', 'defensive_interference']),
)
# Per (pitcher name, pitch name) percentages, rounded to one decimal:
# Whiff% = whiff / swing, CSW% = csw / normal_pitch.
whiff_rate = df.groupby(['name', 'pitch_name'])[['whiff', 'swing']].sum()
whiff_rate = whiff_rate['whiff'].div(whiff_rate['swing']).mul(100).round(1).rename('Whiff%').reset_index()
csw_rate = df.groupby(['name', 'pitch_name'])[['csw', 'normal_pitch']].sum()
csw_rate = csw_rate['csw'].div(csw_rate['normal_pitch']).mul(100).round(1).rename('CSW%').reset_index()
pitch_stats = whiff_rate.merge(csw_rate, on=['name', 'pitch_name']).set_index(['name', 'pitch_name'])
# GRADIO FUNCTIONS
# location maps
def fit_pred_kde(data, X, Y):
    """Fit a Gaussian KDE to ``data`` and evaluate it on the grid ``(X, Y)``.

    Args:
        data: sample points in the layout ``scipy.stats.gaussian_kde`` expects
            (one row per dimension, one column per sample).
        X, Y: same-shaped coordinate arrays, e.g. from ``np.meshgrid``.

    Returns:
        An array shaped like ``X`` holding the estimated density at each grid point.
    """
    grid_points = np.vstack([X.ravel(), Y.ravel()])
    density = gaussian_kde(data)(grid_points)
    return density.reshape(X.shape)
# Plot geometry, in the same units as the re-centred plate_x / plate_z.
plot_s = 256           # side length of the square plotting window
sz_h = 200             # height of the outer rectangle drawn on the map
sz_w = 160             # width of the outer rectangle
h_h = sz_h - 2 * 40    # inner rectangle: 40 trimmed from top and bottom
h_w = sz_w - 2 * 32    # inner rectangle: 32 trimmed from left and right
# KDE evaluation grid: unit steps covering [-plot_s/2, plot_s/2).
kde_range = np.arange(-plot_s / 2, plot_s / 2, 1)
X, Y = np.meshgrid(kde_range, kde_range)
def coordinatify(h, w):
    """Return plotly shape corners (x0, y0, x1, y1) for an h-by-w rectangle centred on the origin."""
    half_w, half_h = w / 2, h / 2
    return {'x0': -half_w, 'y0': -half_h, 'x1': half_w, 'y1': half_h}
# Contour colours: Plotly's sequential OrRd palette spread evenly over (0, 1],
# preceded by a fully transparent stop at 0 so the lowest level renders invisible.
colorscale = [[0, 'rgba(0, 0, 0, 0)']] + [
    [step / len(pc.sequential.OrRd), shade]
    for step, shade in enumerate(pc.sequential.OrRd, start=1)
]
def plot_pitch_map(player=None, loc=None, pitch_type=None, pitch_name=None):
    """Draw a pitch-location density map as a plotly figure.

    Exactly one of ``player`` or ``loc`` must be given:

    * ``player`` together with exactly one of ``pitch_type`` or ``pitch_name``
      selects that pitcher's pitches from the module-level ``df``, and their
      ``plate_x``/``plate_z`` columns are used;
    * ``loc`` supplies the locations directly as rows of ``(x, z)`` pairs
      (e.g. a two-column DataFrame).

    The locations are smoothed with ``fit_pred_kde`` on the module-level
    ``X``/``Y`` grid. The result is drawn as contours, with levels spaced
    ``(max - 1e-5) / 5`` apart, over two centred rectangles
    (``sz_h`` x ``sz_w`` and ``h_h`` x ``h_w``).

    Raises:
        ValueError: if the ``player``/``loc`` or ``pitch_type``/``pitch_name``
            combination is invalid. This check used to be an ``assert``, which
            disappears under ``python -O``.
    """
    if (player is None) == (loc is None):
        raise ValueError('exactly one of `player` or `loc` must be specified')
    if loc is None:
        if (pitch_type is None) == (pitch_name is None):
            raise ValueError('exactly one of `pitch_type` or `pitch_name` must be specified')
        pitch_val = pitch_type or pitch_name
        pitch_col = 'pitch_type' if pitch_type else 'pitch_name'
        loc = df.set_index(['name', pitch_col]).loc[(player, pitch_val), ['plate_x', 'plate_z']]
    # Coerce to an (N, 2) float array and drop rows with a missing coordinate:
    # gaussian_kde cannot fit data that contains NaN.
    points = np.asarray(loc, dtype=float).reshape(-1, 2)
    points = points[~np.isnan(points).any(axis=1)]
    Z = fit_pred_kde(points.T, X, Y)
    # Lowest contour level; with zmin at the same value it maps to the
    # transparent first stop of `colorscale`.
    z_floor = 1e-5
    z_max = Z.max()

    fig = go.Figure()
    fig.add_shape(type="rect", **coordinatify(sz_h, sz_w), line_color='gray')
    fig.add_shape(type="rect", **coordinatify(h_h, h_w), line_color='dimgray')
    fig.add_trace(go.Contour(
        z=Z,
        x=kde_range,
        y=kde_range,
        colorscale=colorscale,
        zmin=z_floor,
        zmax=z_max,
        contours={
            'start': z_floor,
            'end': z_max,
            'size': (z_max - z_floor) / 5,
        },
        showscale=False,
    ))
    fig.update_layout(
        xaxis=dict(range=[-plot_s/2, plot_s/2+1]),
        yaxis=dict(range=[-plot_s/2, plot_s/2+1], scaleanchor='x', scaleratio=1),
    )
    return fig
def plot_empty_pitch_map():
    """Return a blank location-map figure with a centred 'too few pitches' notice.

    The axes span ``[-plot_s/2, plot_s/2 + 1]`` on both x and y, with y locked
    to x at a 1:1 scale.
    """
    half = plot_s / 2
    placeholder = go.Figure()
    placeholder.add_annotation(
        x=0,
        y=0,
        text='No visualization<br>as less than 10 pitches thrown',
        showarrow=False,
    )
    placeholder.update_layout(
        xaxis={'range': [-half, half + 1]},
        yaxis={'range': [-half, half + 1], 'scaleanchor': 'x', 'scaleratio': 1},
    )
    return placeholder
# velo distribution
def plot_pitch_velo(player=None, velos=None, pitch_type=None, pitch_name=None):
    """Draw a one-sided violin plot of release speeds (km/h) as a plotly figure.

    Exactly one of ``player`` or ``velos`` must be given:

    * ``player`` together with exactly one of ``pitch_type`` or ``pitch_name``
      selects that pitcher's pitches from the module-level ``df`` and plots
      their ``release_speed`` values;
    * ``velos`` is plotted directly.

    The axes are fixed to 125-170 on x (velocity) and 0-0.3 on y (frequency).

    Raises:
        ValueError: if the ``player``/``velos`` or ``pitch_type``/``pitch_name``
            combination is invalid. The first message previously named a
            non-existent ``loc`` parameter, and the checks were ``assert``s that
            vanish under ``python -O``.
    """
    if (player is None) == (velos is None):
        raise ValueError('exactly one of `player` or `velos` must be specified')
    if velos is None:
        if (pitch_type is None) == (pitch_name is None):
            raise ValueError('exactly one of `pitch_type` or `pitch_name` must be specified')
        pitch_val = pitch_type or pitch_name
        pitch_col = 'pitch_type' if pitch_type else 'pitch_name'
        velos = df.set_index(['name', pitch_col]).loc[(player, pitch_val), 'release_speed']
    fig = go.Figure(data=go.Violin(x=velos, side='positive', hoveron='points', name='Velocity Distribution'))
    fig.update_layout(
        xaxis=dict(
            title='Velocity',
            range=[125, 170],
            scaleratio=2
        ),
        yaxis=dict(
            title='Frequency',
            range=[0, 0.3],
            scaleanchor='x',
            scaleratio=1,
            tickvals=np.linspace(0, 0.3, 3),
            ticktext=np.linspace(0, 0.3, 3),
        ),
        autosize=True,
        modebar_remove=['zoom', 'autoScale', 'resetScale'],
    )
    return fig
def plot_empty_pitch_velo():
    """Return a blank velocity figure with a centred 'too few pitches' notice.

    Fixed axes: 125-170 on x ('Velocity') and 0-0.3 on y ('Frequency'), with
    the same modebar buttons removed as on a real velocity plot.
    """
    velo_lo, velo_hi = 125, 170
    freq_hi = 0.3
    placeholder = go.Figure()
    placeholder.add_annotation(
        x=(velo_hi + velo_lo) / 2,
        y=freq_hi / 2,
        text='No visualization<br>as less than 10 pitches thrown',
        showarrow=False,
    )
    placeholder.update_layout(
        xaxis={'title': 'Velocity', 'range': [velo_lo, velo_hi], 'scaleratio': 2},
        yaxis={
            'title': 'Frequency',
            'range': [0, freq_hi],
            'scaleanchor': 'x',
            'scaleratio': 1,
            'tickvals': np.linspace(0, freq_hi, 3),
            'ticktext': np.linspace(0, freq_hi, 3),
        },
        autosize=True,
        modebar_remove=['zoom', 'autoScale', 'resetScale'],
    )
    return placeholder
def get_data(player):
    """Compute every Gradio output update for one pitcher's page.

    Args:
        player: pitcher name as it appears in ``df['name']``.

    Returns:
        A flat tuple, in this order:
        * the markdown header;
        * the pitch-usage pie chart;
        * five groups of updates: pitch groups, pitch-name headers, Whiff%/CSW%
          tables, velocity plots and location maps.
        Each group lists pitch types most-used first and is padded with hidden
        updates up to ``max_pitch_types`` entries. Velocity and location plots
        are drawn only for pitch types thrown at least 10 times; otherwise the
        "less than 10 pitches thrown" placeholders are shown.

    Raises:
        KeyError: if ``player`` does not appear in ``df``.
    """
    # Must agree with the placeholder text "less than 10 pitches thrown". The old
    # `count > 10` test showed that text for exactly 10 pitches.
    min_pitches = 10
    player_name = f'# {player}'
    # Selecting with a list always yields a DataFrame. `.loc[player]` collapsed
    # to a Series for a pitcher with a single recorded pitch, which broke the
    # column lookups below.
    _df = df.set_index('name').loc[[player]]
    _df_by_pitch_name = _df.set_index('pitch_name')

    usage_fig = px.pie(_df['pitch_name'], names='pitch_name')
    usage_fig.update_traces(texttemplate='%{percent:.1%}', hovertemplate=f'<b>{player}</b><br>' + 'threw a <b>%{label}</b><br><b>%{percent:.1%}</b> of the time (<b>%{value}</b> pitches)')

    pitch_counts = _df['pitch_name'].value_counts()  # most frequent first
    pitch_groups = []
    pitch_names = []
    pitch_infos = []
    pitch_velos = []
    pitch_maps = []
    for pitch_name, count in pitch_counts.items():
        pitch_groups.append(gr.update(visible=True))
        pitch_names.append(gr.update(value=f'### {pitch_name}', visible=True))
        pitch_infos.append(gr.update(
            value=pd.DataFrame([{
                'Whiff%': pitch_stats.loc[(player, pitch_name), 'Whiff%'].item(),
                'CSW%': pitch_stats.loc[(player, pitch_name), 'CSW%'].item()
            }]),
            visible=True
        ))
        if count >= min_pitches:
            pitch_velos.append(gr.update(
                value=plot_pitch_velo(velos=_df_by_pitch_name.loc[pitch_name, 'release_speed']),
                visible=True
            ))
            pitch_maps.append(gr.update(value=plot_pitch_map(player, pitch_name=pitch_name), label='Pitch location', visible=True))
        else:
            pitch_velos.append(gr.update(value=plot_empty_pitch_velo(), visible=True))
            pitch_maps.append(gr.update(value=plot_empty_pitch_map(), label=pitch_name, visible=True))

    # Hide the unused output slots so every call returns the same number of updates.
    for _ in range(max_pitch_types - len(pitch_counts)):
        pitch_groups.append(gr.update(visible=False))
        pitch_names.append(gr.update(value=None, visible=False))
        pitch_infos.append(gr.update(value=None, visible=False))
        pitch_velos.append(gr.update(value=None, visible=False))
        pitch_maps.append(gr.update(value=None, visible=False))
    return player_name, usage_fig, *pitch_groups, *pitch_names, *pitch_infos, *pitch_velos, *pitch_maps