"""Plot a per-method RMSE comparison with uncertainty bars and parameter scale.

The data is synthetic (seeded, so reproducible): each method gets an RMSE,
an uncertainty used as a horizontal error bar, and a parameter count that
drives marker size on a log scale. Markers are coloured by method category,
and the y-axis labels are drawn as annotations so they can share those colours.
"""
import numpy as np
import pandas as pd
import plotly.graph_objects as go

# --- Synthetic, reproducible data --------------------------------------------
np.random.seed(42)  # Fixed seed so every run produces the same figure.

methods = ['RSM', 'PSRN', 'NGGP', 'PySR', 'BMS', 'uDSR', 'AIF', 'DGSR', 'E2E',
           'SymINDy', 'physo', 'TPSR', 'SPL', 'DEAP', 'SINDy', 'NSRS',
           'gplearn', 'SNIP', 'KAN', 'EQL']

# RMSE values: |N(0.3, 0.15)| clipped into [0.05, 0.95].
rmse_values = np.clip(np.abs(np.random.normal(0.3, 0.15, len(methods))), 0.05, 0.95)
uncertainty = np.random.uniform(0.02, 0.08, len(methods))
param_sizes = np.array([1e3, 1e4, 5e4, 1e5, 5e5, 1e6, 2e6])  # Candidate parameter counts.

# Category of each method (used only for colouring).
method_categories = {
    'Symbolic': ['RSM', 'PySR', 'SymINDy', 'gplearn', 'EQL'],
    'Neural': ['NGGP', 'DGSR', 'E2E', 'KAN'],
    'Evolutionary': ['PSRN', 'DEAP', 'NSRS'],
    'Physics-based': ['physo', 'TPSR', 'SPL'],
    'Hybrid': ['BMS', 'uDSR', 'AIF', 'SINDy', 'SNIP'],
}

# Rows are sorted by ascending RMSE. The original row index is kept, so
# per-row arrays computed later can be aligned with `.loc`.
df = pd.DataFrame({
    'Method': methods,
    'RMSE': rmse_values,
    'Uncertainty': uncertainty,
    'ParamSize': np.random.choice(param_sizes, len(methods)),
}).sort_values('RMSE', ascending=True)

# First category listing the method wins; unlisted methods become 'Other'.
df['Category'] = df['Method'].apply(
    lambda m: next((cat for cat, members in method_categories.items() if m in members), 'Other')
)

category_colors = {
    'Symbolic': '#1f77b4',
    'Neural': '#ff7f0e',
    'Evolutionary': '#2ca02c',
    'Physics-based': '#d62728',
    'Hybrid': '#9467bd',
}
# Fallback for 'Other' (or any category without an explicit colour),
# so a new unlisted method does not raise KeyError.
OTHER_COLOR = '#7f7f7f'

fig = go.Figure()

# --- Marker sizes: linear in log(parameter count), mapped to [15, 40] px -----
log_params = np.log(df['ParamSize'])
size_min = log_params.min()
size_max = log_params.max()
size_span = size_max - size_min
if size_span > 0:
    sizes = 15 + 25 * (log_params - size_min) / size_span
else:
    # Every method has the same parameter count: avoid 0/0 (NaN sizes)
    # and draw all markers at the middle of the size range.
    sizes = pd.Series(15 + 25 / 2, index=df.index)

# --- One scatter trace per category (gives one legend entry per category) ----
for category in df['Category'].unique():
    df_sub = df[df['Category'] == category]
    fig.add_trace(go.Scatter(
        x=df_sub['RMSE'],
        y=df_sub['Method'],
        mode='markers',
        name=category,
        marker=dict(
            size=sizes.loc[df_sub.index].to_numpy(),  # Align sizes by row label.
            color=category_colors.get(category, OTHER_COLOR),
            opacity=0.9,
            line=dict(width=1, color='black'),
        ),
        error_x=dict(
            type='data',
            array=df_sub['Uncertainty'],
            color='rgba(40,40,40,0.6)',
            thickness=1.2,
            width=10,
        ),
        hoverinfo='text',
        # Plotly renders hover text as HTML, so line breaks must be <br>.
        # A raw newline inside a single-quoted f-string is also a SyntaxError.
        hovertext=df_sub.apply(
            lambda r: (f"{r['Method']}<br>"
                       f"RMSE: {r['RMSE']:.3f} ± {r['Uncertainty']:.3f}<br>"
                       f"Params: {r['ParamSize']:,.0f}"),
            axis=1,
        ),
    ))

# --- X-axis range: data extent plus symmetric padding, clamped to [0, 1] -----
AXIS_PAD = 0.05  # Padding so edge markers and error-bar caps are not clipped.
data_min = (df['RMSE'] - df['Uncertainty']).min()
data_max = (df['RMSE'] + df['Uncertainty']).max()
x_min = max(data_min - AXIS_PAD, 0)
x_max = min(data_max + AXIS_PAD, 1)

fig.update_layout(
    title='Methods RMSE Comparison with Parameter Scale',
    xaxis=dict(
        title='Root Mean Square Error (RMSE) → Lower is better',
        range=[x_min, x_max],
        # Rounded so tick values are exact tenths (arange yields 0.30000000000000004 etc.).
        tickvals=np.round(np.linspace(0, 1, 11), 1),
        gridcolor='#F0F0F0',
        zeroline=False,
        showspikes=True,
    ),
    yaxis=dict(
        categoryorder='array',
        categoryarray=df['Method'].tolist(),  # Lowest RMSE at the bottom.
        tickfont=dict(size=12),
        showticklabels=False,  # Replaced by the coloured annotations below.
    ),
    plot_bgcolor='white',
    width=1100,
    height=700,
    margin=dict(l=180, r=50, t=80, b=40),
    legend=dict(
        title='Method Categories',
        orientation='v',
        yanchor="top",
        y=0.98,
        xanchor="left",
        x=1.02,
    ),
)

# --- Category-coloured y-axis labels ----------------------------------------
# Anchor each label to its own category row (yref='y', y=<category name>).
# Estimating row positions in paper coordinates drifts away from the rows.
for method, category in zip(df['Method'], df['Category']):
    fig.add_annotation(
        x=0.01,
        y=method,
        xref='paper',
        yref='y',
        text=method,
        showarrow=False,
        font=dict(
            size=12,
            color=category_colors.get(category, OTHER_COLOR),
        ),
        xanchor='right',
    )

# TODO: marker size encodes log(parameter count) but no size key is drawn yet.

# --- Attribution watermark (bottom-right corner) -----------------------------
fig.add_annotation(
    x=0.98,
    y=0.02,
    xref='paper',
    yref='paper',
    text='叶璨铭绘制',
    showarrow=False,
    font=dict(size=10, color='#666666'),
    bgcolor='white',
)

fig.show()