Spaces:
Sleeping
Sleeping
File size: 2,390 Bytes
cbbacc3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
def create_gmm_plot(dataset, centers, K, samples=None, *,
                    x_range=(-5, 5), y_range=(-5, 5), resolution=100):
    """Build a side-by-side 3D surface / 2D contour view of a GMM density.

    The density is evaluated on a regular ``resolution x resolution`` grid
    spanning ``x_range`` by ``y_range``. The left subplot renders it as a 3D
    surface. The right subplot renders it as a labelled contour plot. The
    first ``K`` component centers, and optionally some sample points, are
    overlaid on the contour plot.

    Args:
        dataset: Object exposing ``pdf(points)``, where ``points`` is an
            ``(N, 2)`` array of (x, y) coordinates. It must return ``N``
            density values (array-like).
        centers: Array-like of component centers. Column 0 is x and
            column 1 is y. Only the first ``K`` rows are plotted.
        K: Number of component centers to mark. If ``centers`` has fewer
            rows, only the available rows are plotted and labelled.
        samples: Optional array-like of points (column 0 is x, column 1
            is y) to overlay on the contour plot. ``None`` skips this layer.
        x_range: ``(min, max)`` of the evaluation grid along x.
            Default ``(-5, 5)``.
        y_range: ``(min, max)`` of the evaluation grid along y.
            Default ``(-5, 5)``.
        resolution: Number of grid points per axis. Default 100.

    Returns:
        The plotly figure containing both subplots.
    """
    # Evaluate the density on a regular grid.
    x = np.linspace(x_range[0], x_range[1], resolution)
    y = np.linspace(y_range[0], y_range[1], resolution)
    X, Y = np.meshgrid(x, y)
    xy = np.column_stack((X.ravel(), Y.ravel()))
    # meshgrid's default 'xy' indexing makes row i of Z correspond to y[i]
    # and column j to x[j]. That is the layout go.Contour expects when it is
    # given the 1-D x / y vectors below.
    Z = np.asarray(dataset.pdf(xy)).reshape(X.shape)

    # One row with two subplots: a 3D scene on the left, a 2D contour on the right.
    fig = make_subplots(
        rows=1, cols=2,
        specs=[[{'type': 'surface'}, {'type': 'contour'}]],
        subplot_titles=('3D概率密度曲面', '等高线图与分量中心')
    )

    # 3D surface. The colorbar sits between the subplots (paper x=0.45) so it
    # does not collide with the contour plot's colorbar on the right edge.
    surface = go.Surface(
        x=X, y=Y, z=Z,
        colorscale='viridis',
        showscale=True,
        colorbar=dict(x=0.45)
    )
    fig.add_trace(surface, row=1, col=1)

    # Labelled contour plot of the same density.
    contour = go.Contour(
        x=x, y=y, z=Z,
        colorscale='viridis',
        showscale=True,
        colorbar=dict(x=1.0),
        contours=dict(
            showlabels=True,
            labelfont=dict(size=12)
        )
    )
    fig.add_trace(contour, row=1, col=2)

    # Component centers. Labels are derived from the rows actually selected,
    # so the label count always matches the point count even if K exceeds
    # len(centers).
    shown_centers = np.asarray(centers)[:K]
    fig.add_trace(
        go.Scatter(
            x=shown_centers[:, 0], y=shown_centers[:, 1],
            mode='markers+text',
            marker=dict(size=10, color='red'),
            text=[f'C{i+1}' for i in range(len(shown_centers))],
            textposition="top center",
            name='分量中心'
        ),
        row=1, col=2
    )

    # Optional sample points. asarray lets callers pass a plain list of pairs.
    if samples is not None:
        samples = np.asarray(samples)
        fig.add_trace(
            go.Scatter(
                x=samples[:, 0], y=samples[:, 1],
                mode='markers+text',
                marker=dict(
                    size=8,
                    color='yellow',
                    line=dict(color='black', width=1)
                ),
                text=[f'S{i+1}' for i in range(len(samples))],
                textposition="bottom center",
                name='采样点'
            ),
            row=1, col=2
        )

    # Figure-level layout. `scene` configures the axes of the 3D subplot.
    fig.update_layout(
        title='广义高斯混合分布',
        showlegend=True,
        width=1200,
        height=600,
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='密度'
        )
    )
    # Axis titles for the 2D contour subplot.
    fig.update_xaxes(title_text='X', row=1, col=2)
    fig.update_yaxes(title_text='Y', row=1, col=2)
    return fig