# HyperPapers / utils/gmm_vis.py
# Last commit by 2catycm: cbbacc3 "feat: top k top p"
def create_gmm_plot(dataset, centers, K, samples=None, *,
                    x_range=(-5, 5), y_range=(-5, 5), resolution=100):
    """Build a side-by-side 3D surface / 2D contour figure of a mixture density.

    Args:
        dataset: Object exposing ``pdf(points)``. It is called once with an
            ``(resolution * resolution, 2)`` array of grid points, and its
            result is reshaped onto the ``(resolution, resolution)`` grid.
            (Presumably it returns one density value per point — confirm
            against the dataset class.)
        centers: Array-like of component centers. Columns 0 and 1 of the
            first ``K`` rows are drawn as labelled red markers.
        K: Number of component centers to draw and label (clamped to the
            number of rows actually present in ``centers``).
        samples: Optional array-like of sample points. Columns 0 and 1 are
            drawn as labelled yellow markers. A single point ``[x, y]`` is
            also accepted.
        x_range: ``(min, max)`` of the evaluation grid along X. The default
            matches the previously hard-coded window.
        y_range: ``(min, max)`` of the evaluation grid along Y.
        resolution: Number of grid points per axis.

    Returns:
        A plotly figure with the density surface in column 1, and the
        contour plot plus centers (and samples, if given) in column 2.
    """
    # Evaluate the density on a regular grid covering the plotting window.
    x = np.linspace(x_range[0], x_range[1], resolution)
    y = np.linspace(y_range[0], y_range[1], resolution)
    X, Y = np.meshgrid(x, y)
    xy = np.column_stack((X.ravel(), Y.ravel()))
    Z = dataset.pdf(xy).reshape(X.shape)

    # One 3D subplot (surface) next to one 2D subplot (contour).
    fig = make_subplots(
        rows=1, cols=2,
        specs=[[{'type': 'surface'}, {'type': 'contour'}]],
        subplot_titles=('3D概率密度曲面', '等高线图与分量中心')
    )

    # 3D density surface. Its colorbar is moved to x=0.45 so it sits
    # between the two subplots instead of on top of the contour plot.
    surface = go.Surface(
        x=X, y=Y, z=Z,
        colorscale='viridis',
        showscale=True,
        colorbar=dict(x=0.45)
    )
    fig.add_trace(surface, row=1, col=2 - 1)

    # 2D contour of the same density. It takes the 1D axis vectors,
    # while the surface above takes the 2D meshgrid arrays.
    contour = go.Contour(
        x=x, y=y, z=Z,
        colorscale='viridis',
        showscale=True,
        colorbar=dict(x=1.0),
        contours=dict(
            showlabels=True,
            labelfont=dict(size=12)
        )
    )
    fig.add_trace(contour, row=1, col=2)

    # Overlay the first K component centers. Generate labels from the rows
    # actually plotted, so the text count always matches the marker count.
    shown_centers = np.asarray(centers)[:K]
    fig.add_trace(
        go.Scatter(
            x=shown_centers[:, 0], y=shown_centers[:, 1],
            mode='markers+text',
            marker=dict(size=10, color='red'),
            text=[f'C{i+1}' for i in range(len(shown_centers))],
            textposition="top center",
            name='分量中心'
        ),
        row=1, col=2
    )

    # Overlay the sample points, if any. atleast_2d lets a single
    # [x, y] point be passed directly.
    if samples is not None:
        samples = np.atleast_2d(np.asarray(samples))
        fig.add_trace(
            go.Scatter(
                x=samples[:, 0], y=samples[:, 1],
                mode='markers+text',
                marker=dict(
                    size=8,
                    color='yellow',
                    line=dict(color='black', width=1)
                ),
                text=[f'S{i+1}' for i in range(len(samples))],
                textposition="bottom center",
                name='采样点'
            ),
            row=1, col=2
        )

    # Global layout. `scene` configures the axes of the 3D subplot.
    fig.update_layout(
        title='广义高斯混合分布',
        showlegend=True,
        width=1200,
        height=600,
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='密度'
        )
    )

    # Axis titles for the 2D contour subplot.
    fig.update_xaxes(title_text='X', row=1, col=2)
    fig.update_yaxes(title_text='Y', row=1, col=2)
    return fig