# Spaces: Sleeping
def create_gmm_plot(dataset, centers, K, samples=None):
    """Build a side-by-side 3D surface / 2D contour figure of a mixture density.

    The density is evaluated on a fixed 100 x 100 grid spanning [-5, 5] on
    both axes by calling ``dataset.pdf`` on an ``(N, 2)`` array of points.

    Args:
        dataset: Object exposing ``pdf(points)``; the result is reshaped to
            the grid shape.
        centers: Array indexed as ``centers[:K, 0]`` / ``centers[:K, 1]``;
            the first ``K`` rows are drawn as labelled red markers.
        K: Number of component centers to draw and label (``C1`` .. ``CK``).
        samples: Optional array indexed as ``samples[:, 0]`` /
            ``samples[:, 1]``; when given, each row is drawn as a labelled
            yellow marker (``S1``, ``S2``, ...).

    Returns:
        The plotly figure holding the surface (left) and the contour plot
        with overlays (right).
    """
    # Evaluation grid and the density sampled on it.
    grid_x = np.linspace(-5, 5, 100)
    grid_y = np.linspace(-5, 5, 100)
    mesh_x, mesh_y = np.meshgrid(grid_x, grid_y)
    grid_points = np.stack([mesh_x.ravel(), mesh_y.ravel()], axis=1)
    density = dataset.pdf(grid_points).reshape(mesh_x.shape)

    # One row, two panels: 3D surface on the left, contour on the right.
    fig = make_subplots(
        rows=1,
        cols=2,
        specs=[[{'type': 'surface'}, {'type': 'contour'}]],
        subplot_titles=('3D概率密度曲面', '等高线图与分量中心'),
    )

    surface_trace = go.Surface(
        x=mesh_x,
        y=mesh_y,
        z=density,
        colorscale='viridis',
        showscale=True,
        # Park this colorbar between the two panels.
        colorbar={'x': 0.45},
    )

    contour_trace = go.Contour(
        x=grid_x,
        y=grid_y,
        z=density,
        colorscale='viridis',
        showscale=True,
        colorbar={'x': 1.0},
        contours={'showlabels': True, 'labelfont': {'size': 12}},
    )

    # Component centers, labelled C1..CK.
    center_labels = [f'C{idx}' for idx in range(1, K + 1)]
    centers_trace = go.Scatter(
        x=centers[:K, 0],
        y=centers[:K, 1],
        mode='markers+text',
        marker={'size': 10, 'color': 'red'},
        text=center_labels,
        textposition="top center",
        name='分量中心',
    )

    # (trace, column) pairs in the order they are added to the figure.
    panel_traces = [
        (surface_trace, 1),
        (contour_trace, 2),
        (centers_trace, 2),
    ]

    # Sampled points are optional and are overlaid on the contour panel.
    if samples is not None:
        sample_labels = [f'S{idx}' for idx in range(1, len(samples) + 1)]
        samples_trace = go.Scatter(
            x=samples[:, 0],
            y=samples[:, 1],
            mode='markers+text',
            marker={
                'size': 8,
                'color': 'yellow',
                'line': {'color': 'black', 'width': 1},
            },
            text=sample_labels,
            textposition="bottom center",
            name='采样点',
        )
        panel_traces.append((samples_trace, 2))

    for trace, column in panel_traces:
        fig.add_trace(trace, row=1, col=column)

    # Overall layout plus axis titles for the 3D scene.
    scene_axes = {
        'xaxis_title': 'X',
        'yaxis_title': 'Y',
        'zaxis_title': '密度',
    }
    fig.update_layout(
        title='广义高斯混合分布',
        showlegend=True,
        width=1200,
        height=600,
        scene=scene_axes,
    )

    # Axis titles for the 2D contour panel.
    fig.update_xaxes(title_text='X', row=1, col=2)
    fig.update_yaxes(title_text='Y', row=1, col=2)

    return fig