File size: 2,390 Bytes
cbbacc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
def create_gmm_plot(dataset, centers, K, samples=None,
                    grid_range=(-5, 5), grid_points=100):
    """Build a side-by-side 3D surface / 2D contour figure of a mixture density.

    The density is evaluated on a square ``grid_points`` x ``grid_points``
    grid spanning ``grid_range`` on both the x and y axes.

    Args:
        dataset: Object exposing ``pdf(points)``. It is called with an
            ``(grid_points**2, 2)`` array of (x, y) coordinates, and its
            result must be reshapeable to ``(grid_points, grid_points)``.
        centers: Array-like of component centers. Columns 0 and 1 are used
            as x and y; only the first ``K`` rows are plotted.
        K: Number of component centers to plot and label (``C1``, ``C2``, ...).
        samples: Optional array-like of sample points (columns 0 and 1 used
            as x and y), overlaid on the contour subplot and labelled
            ``S1``, ``S2``, ...
        grid_range: ``(low, high)`` bounds of the evaluation grid, applied to
            both axes. Defaults to ``(-5, 5)``.
        grid_points: Number of grid points per axis. Defaults to ``100``.

    Returns:
        A plotly figure with the density surface in column 1 and the contour
        plot (with component centers and optional samples) in column 2.
    """
    # Build the evaluation grid.
    low, high = grid_range
    x = np.linspace(low, high, grid_points)
    y = np.linspace(low, high, grid_points)
    X, Y = np.meshgrid(x, y)
    xy = np.column_stack((X.ravel(), Y.ravel()))

    # Evaluate the density on the flattened grid, then fold it back to grid
    # shape. asarray tolerates a pdf that returns a plain sequence.
    Z = np.asarray(dataset.pdf(xy)).reshape(X.shape)

    # One row, two columns: a 3D scene on the left, a 2D xy plot on the right.
    fig = make_subplots(
        rows=1, cols=2,
        specs=[[{'type': 'surface'}, {'type': 'contour'}]],
        subplot_titles=('3D概率密度曲面', '等高线图与分量中心')
    )

    # 3D surface. Its colorbar is moved to x=0.45 so it sits between the
    # two subplots instead of overlapping the contour's colorbar.
    surface = go.Surface(
        x=X, y=Y, z=Z,
        colorscale='viridis',
        showscale=True,
        colorbar=dict(x=0.45)
    )
    fig.add_trace(surface, row=1, col=2 - 1)

    # Labelled contour plot of the same density.
    contour = go.Contour(
        x=x, y=y, z=Z,
        colorscale='viridis',
        showscale=True,
        colorbar=dict(x=1.0),
        contours=dict(
            showlabels=True,
            labelfont=dict(size=12)
        )
    )
    fig.add_trace(contour, row=1, col=2)

    # Slice the centers once so the labels always match the number of points
    # actually drawn, even if K exceeds the number of available centers.
    shown_centers = np.asarray(centers)[:K]
    fig.add_trace(
        go.Scatter(
            x=shown_centers[:, 0], y=shown_centers[:, 1],
            mode='markers+text',
            marker=dict(size=10, color='red'),
            text=[f'C{i+1}' for i in range(len(shown_centers))],
            textposition="top center",
            name='分量中心'
        ),
        row=1, col=2
    )

    # Optional sample points, overlaid on the contour subplot.
    if samples is not None:
        samples = np.asarray(samples)
        fig.add_trace(
            go.Scatter(
                x=samples[:, 0], y=samples[:, 1],
                mode='markers+text',
                marker=dict(
                    size=8,
                    color='yellow',
                    line=dict(color='black', width=1)
                ),
                text=[f'S{i+1}' for i in range(len(samples))],
                textposition="bottom center",
                name='采样点'
            ),
            row=1, col=2
        )

    # Global layout. The `scene` settings configure the 3D subplot's axes.
    fig.update_layout(
        title='广义高斯混合分布',
        showlegend=True,
        width=1200,
        height=600,
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='密度'
        )
    )

    # Axis titles for the 2D contour subplot.
    fig.update_xaxes(title_text='X', row=1, col=2)
    fig.update_yaxes(title_text='Y', row=1, col=2)

    return fig