# Uploaded by 2catycm (commit 4a7325a).
# feat: the plot is mostly complete; a marker-size scale (legend) is still missing.
import plotly.graph_objects as go
import pandas as pd
import numpy as np
# Synthetic benchmark data for the figure; the fixed seed keeps every draw reproducible.
np.random.seed(42)

methods = [
    'RSM', 'PSRN', 'NGGP', 'PySR', 'BMS', 'uDSR', 'AIF',
    'DGSR', 'E2E', 'SymINDy', 'physo', 'TPSR', 'SPL',
    'DEAP', 'SINDy', 'NSRS', 'gplearn', 'SNIP', 'KAN', 'EQL',
]
n_methods = len(methods)

# RMSE samples: |N(0.3, 0.15)| clipped into [0.05, 0.95].
raw_rmse = np.random.normal(0.3, 0.15, n_methods)
rmse_values = np.clip(np.abs(raw_rmse), 0.05, 0.95)
# Error-bar half-widths drawn uniformly from [0.02, 0.08).
uncertainty = np.random.uniform(0.02, 0.08, n_methods)
# Candidate parameter counts; each method is randomly assigned one of them.
param_sizes = np.array([1e3, 1e4, 5e4, 1e5, 5e5, 1e6, 2e6])

# Grouping of the methods into categories.
method_categories = {
    'Symbolic': ['RSM', 'PySR', 'SymINDy', 'gplearn', 'EQL'],
    'Neural': ['NGGP', 'DGSR', 'E2E', 'KAN'],
    'Evolutionary': ['PSRN', 'DEAP', 'NSRS'],
    'Physics-based': ['physo', 'TPSR', 'SPL'],
    'Hybrid': ['BMS', 'uDSR', 'AIF', 'SINDy', 'SNIP'],
}

# Drawn after `uncertainty` so the random stream is consumed in the same order.
assigned_sizes = np.random.choice(param_sizes, n_methods)

# One row per method, ordered from lowest to highest RMSE.
unsorted_table = pd.DataFrame({
    'Method': methods,
    'RMSE': rmse_values,
    'Uncertainty': uncertainty,
    'ParamSize': assigned_sizes,
})
df = unsorted_table.sort_values('RMSE', ascending=True)

# Reverse lookup method -> category; the first category listing a method wins,
# and methods listed nowhere fall back to 'Other'.
category_of = {}
for category_name, members in method_categories.items():
    for member in members:
        category_of.setdefault(member, category_name)
df['Category'] = df['Method'].map(lambda name: category_of.get(name, 'Other'))

# Colour assigned to each category.
category_colors = {
    'Symbolic': '#1f77b4',
    'Neural': '#ff7f0e',
    'Evolutionary': '#2ca02c',
    'Physics-based': '#d62728',
    'Hybrid': '#9467bd',
}
# Create the empty figure that every trace and annotation below is added to.
fig = go.Figure()

# Marker sizing: map log(ParamSize) linearly onto the range [15, 40].
# `sizes` stays a Series aligned to df.index so later code can look rows up by label.
log_param_sizes = np.log(df['ParamSize'])
size_min = log_param_sizes.min()
size_max = log_param_sizes.max()
log_span = size_max - size_min
if log_span > 0:
    sizes = 15 + 25 * (log_param_sizes - size_min) / log_span
else:
    # Every method has the same parameter count: the normalisation would divide
    # by zero and turn all sizes into NaN, so use the midpoint of the range instead.
    sizes = pd.Series(27.5, index=df.index)
# Main data: one scatter trace per category, so each category gets its own
# legend entry and colour. sort=False keeps categories in order of first
# appearance in df (the same order df['Category'].unique() yields).
for category, df_sub in df.groupby('Category', sort=False):
    # Hover text: method name, RMSE ± uncertainty, and the parameter count.
    hover_text = df_sub.apply(
        lambda r: f"{r['Method']}<br>RMSE: {r['RMSE']:.3f} ± {r['Uncertainty']:.3f}<br>Params: {r['ParamSize']:,.0f}",
        axis=1,
    )
    fig.add_trace(go.Scatter(
        x=df_sub['RMSE'],
        y=df_sub['Method'],
        mode='markers',
        name=category,
        marker=dict(
            # Explicit label-based lookup: df was re-sorted, so its integer
            # index labels no longer match row positions.
            size=sizes.loc[df_sub.index],
            # Methods without a known category are labelled 'Other', which has
            # no palette entry; fall back to neutral grey instead of raising.
            color=category_colors.get(category, '#7f7f7f'),
            opacity=0.9,
            line=dict(width=1, color='black'),
        ),
        # Horizontal error bars of ± Uncertainty around each RMSE value.
        error_x=dict(
            type='data',
            array=df_sub['Uncertainty'],
            color='rgba(40,40,40,0.6)',
            thickness=1.2,
            width=10,
        ),
        hoverinfo='text',
        hovertext=hover_text,
    ))
# Derive the x-axis window from the data, kept inside [0, 1].
lower_edges = df['RMSE'] - df['Uncertainty']
data_min = lower_edges.min()
# Left edge: 0.05 of padding below the lowest error bar, but never below 0.
padded_min = data_min - 0.05
x_min = max(padded_min, 0)
# Right edge: largest RMSE plus largest uncertainty, but never above 1.
widest_reach = df['RMSE'].max() + df['Uncertainty'].max()
x_max = min(widest_reach, 1)
# Figure layout: axes, canvas size, margins and legend placement.
x_axis_style = {
    'title': 'Root Mean Square Error (RMSE) → Lower is better',
    'range': [x_min, x_max],
    'tickvals': np.arange(0, 1.1, 0.1),
    'gridcolor': '#F0F0F0',
    'zeroline': False,
    'showspikes': True,
}
y_axis_style = {
    # Rows follow the order of df['Method'].
    'categoryorder': 'array',
    'categoryarray': df['Method'].tolist(),
    'tickfont': {'size': 12},
    # Built-in tick labels are switched off.
    'showticklabels': False,
}
# Vertical legend anchored by its top-left corner at paper position (1.02, 0.98).
legend_style = {
    'title': 'Method Categories',
    'orientation': 'v',
    'yanchor': "top",
    'y': 0.98,
    'xanchor': "left",
    'x': 1.02,
}
fig.update_layout(
    title='Methods RMSE Comparison with Parameter Scale',
    xaxis=x_axis_style,
    yaxis=y_axis_style,
    plot_bgcolor='white',
    width=1100,
    height=700,
    margin={'l': 180, 'r': 50, 't': 80, 'b': 40},
    legend=legend_style,
)
# Custom y-axis labels, coloured by category.
# Each label is anchored to the y axis itself (yref='y') rather than to a
# guessed paper fraction, so it stays level with its marker row. On a category
# axis the position is the category's serial number; because the layout sets
# categoryarray to df['Method'], that number is the row's position in df.
for row_number, (method, category) in enumerate(zip(df['Method'], df['Category'])):
    fig.add_annotation(
        x=0.01,
        xref='paper',
        y=row_number,
        yref='y',
        text=method,
        showarrow=False,
        font=dict(
            size=12,
            # 'Other' has no palette entry; fall back to neutral grey.
            color=category_colors.get(category, '#7f7f7f'),
        ),
        xanchor='right',
    )
# TODO: marker-size legend (the draft below is disabled)
# size_legend_values = [1e3, 1e4, 1e5, 1e6] # reference parameter counts shown in the legend
# size_legend_sizes = 15 + 25 * (np.log(size_legend_values) - size_min) / (size_max - size_min)
# fig.add_trace(go.Scatter(
# x=[0.52, 0.55, 0.58, 0.61],
# y=np.array([0.00, 0.05, 0.10, 0.15]),
# mode='markers',
# marker=dict(
# size=size_legend_sizes,
# color='#444444',
# opacity=0.8
# ),
# showlegend=False,
# text=[f'{size:.2e}' for size in size_legend_values],
# ))
# # Add text labels for the size legend
# size_labels = ['1K', '10K', '100K', '1M']
# for i, (x, y, label) in enumerate(zip([0.95]*4, [0.15,0.20,0.25,0.30], size_labels)):
# fig.add_annotation(
# x=x,
# y=y,
# xref="paper",
# yref="paper",
# text=label,
# showarrow=False,
# font=dict(size=10),
# xanchor='left'
# )
# | Parameters (log scale)
# Credit line placed near the bottom-right corner of the plotting area
# (paper coordinates), on a white background.
credit_annotation = {
    'x': 0.98,
    'y': 0.02,
    'xref': 'paper',
    'yref': 'paper',
    'text': '叶璨铭绘制',
    'showarrow': False,
    'font': {'size': 10, 'color': '#666666'},
    'bgcolor': 'white',
}
fig.add_annotation(**credit_annotation)

# Render the finished figure.
fig.show()