def create_gmm_plot(dataset, centers, K, samples=None, *,
                    x_range=(-5, 5), y_range=(-5, 5), resolution=100):
    """Build a Plotly figure visualising a (generalized) Gaussian mixture density.

    The left subplot is a 3D surface of the density; the right subplot is a
    contour plot of the same density with the component centers (and,
    optionally, sample points) overlaid as labelled markers.

    Args:
        dataset: Object exposing ``pdf(points)``. It is called once with an
            ``(resolution * resolution, 2)`` array of grid points and the
            result is reshaped onto the grid, so it presumably returns one
            density value per point (TODO confirm against the mixture class).
        centers: Array-like of component centers. Column 0 is plotted as x and
            column 1 as y, so it is presumably shaped ``(n_components, 2)``.
            A single 1-D point is also accepted.
        K: Number of leading centers to mark. It is clamped to
            ``[0, len(centers)]`` so the ``C1..Ck`` labels always match the
            plotted points.
        samples: Optional array-like of sample points, indexed like
            ``centers``. They are drawn as labelled markers on the contour plot.
        x_range: ``(min, max)`` extent of the evaluation grid along x.
            Keyword-only; the default reproduces the original fixed window.
        y_range: ``(min, max)`` extent of the evaluation grid along y.
            Keyword-only; the default reproduces the original fixed window.
        resolution: Number of grid points along each axis (keyword-only).

    Returns:
        A ``plotly.graph_objects.Figure`` with two side-by-side subplots.
    """
    # Accept plain lists and a single 1-D point as well as (n, 2) arrays.
    centers = np.atleast_2d(np.asarray(centers))
    # Clamp so that centers[:k] and the C1..Ck labels have the same length.
    k = max(0, min(int(K), len(centers)))

    # Evaluate the density on a regular grid covering the requested window.
    x = np.linspace(x_range[0], x_range[1], resolution)
    y = np.linspace(y_range[0], y_range[1], resolution)
    X, Y = np.meshgrid(x, y)
    xy = np.column_stack((X.ravel(), Y.ravel()))
    Z = np.asarray(dataset.pdf(xy)).reshape(X.shape)

    # Create the 3D surface (left) and 2D contour (right) subplots.
    fig = make_subplots(
        rows=1, cols=2,
        specs=[[{'type': 'surface'}, {'type': 'contour'}]],
        subplot_titles=('3D概率密度曲面', '等高线图与分量中心')
    )

    # 3D surface. Its colorbar is placed at x=0.45, between the subplots,
    # so it does not collide with the contour's colorbar at x=1.0.
    surface = go.Surface(
        x=X, y=Y, z=Z,
        colorscale='viridis',
        showscale=True,
        colorbar=dict(x=0.45)
    )
    fig.add_trace(surface, row=1, col=1)

    # Labelled contour plot of the same density.
    contour = go.Contour(
        x=x, y=y, z=Z,
        colorscale='viridis',
        showscale=True,
        colorbar=dict(x=1.0),
        contours=dict(
            showlabels=True,
            labelfont=dict(size=12)
        )
    )
    fig.add_trace(contour, row=1, col=2)

    # Mark the first k component centers on the contour plot.
    fig.add_trace(
        go.Scatter(
            x=centers[:k, 0],
            y=centers[:k, 1],
            mode='markers+text',
            marker=dict(size=10, color='red'),
            text=[f'C{i+1}' for i in range(k)],
            textposition="top center",
            name='分量中心'
        ),
        row=1, col=2
    )

    # Optionally overlay sample points on the contour plot.
    if samples is not None:
        samples = np.atleast_2d(np.asarray(samples))
        fig.add_trace(
            go.Scatter(
                x=samples[:, 0],
                y=samples[:, 1],
                mode='markers+text',
                marker=dict(
                    size=8,
                    color='yellow',
                    line=dict(color='black', width=1)
                ),
                text=[f'S{i+1}' for i in range(len(samples))],
                textposition="bottom center",
                name='采样点'
            ),
            row=1, col=2
        )

    # Figure-wide layout; `scene` configures the 3D subplot's axes.
    fig.update_layout(
        title='广义高斯混合分布',
        showlegend=True,
        width=1200,
        height=600,
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='密度'
        )
    )

    # Axis titles for the 2D contour subplot.
    fig.update_xaxes(title_text='X', row=1, col=2)
    fig.update_yaxes(title_text='Y', row=1, col=2)

    return fig