2catycm commited on
Commit
cbbacc3
·
1 Parent(s): db9ca60

feat: top k top p

Browse files
app.py CHANGED
@@ -1,69 +1,22 @@
1
  import streamlit as st
2
- from data_processor import load_data, process_data
3
- from visualizer import visualize_gmm, visualize_ratings
4
- from hypergraph_drawer import draw_hypergraph
5
 
6
  # 设置页面配置
7
  st.set_page_config(layout="wide")
8
 
9
  # 主应用
10
  def main():
11
- st.title("高斯混合分布聚类可视化")
12
-
13
- # 使用 sidebar 控制参数
14
- with st.sidebar:
15
- st.header("控制面板")
16
- autoplay = st.button("自动播放")
17
- if autoplay:
18
- for i in range(1, 11):
19
- with st.spinner(f"迭代 {i}"):
20
- time.sleep(1)
21
- st.session_state.iteration = i
22
- st.session_state.autoplay = False
23
- st.experimental_rerun()
24
-
25
- # 主页面布局
26
- if 'autoplay' not in st.session_state:
27
- st.session_state.autoplay = True
28
-
29
- if 'iteration' not in st.session_state:
30
- st.session_state.iteration = 1
31
-
32
- if st.session_state.autoplay:
33
- # 隐藏迭代次数滑条
34
- iteration = st.session_state.iteration
35
- else:
36
- # 显示迭代次数滑条
37
- iteration = st.slider("选择迭代次数", min_value=1, max_value=10, value=st.session_state.iteration, step=1)
38
-
39
- # 动态限制采样数量的最大值
40
- df = load_data()
41
- max_samples = len(df)
42
- num_samples = st.slider("选择采样论文数量", min_value=1, max_value=min(100, max_samples), value=min(10, max_samples), step=1)
43
-
44
- # 处理数据
45
- sampled_df, probabilities, hyperedges = process_data(df, iteration, num_samples)
46
-
47
- # 并排展示超图和高斯混合分布
48
- col1, col2 = st.columns(2)
49
- with col1:
50
- st.header("超图可视化")
51
- hypergraph_image = draw_hypergraph(hyperedges)
52
- st.image(hypergraph_image, caption="超图可视化", use_container_width=True)
53
-
54
- with col2:
55
- st.header("高斯混合分布聚类结果")
56
- fig_gmm = visualize_gmm(sampled_df, iteration)
57
- st.plotly_chart(fig_gmm, use_container_width=True)
58
-
59
- # 显示采样论文的详细信息
60
- st.header("采样论文详细信息")
61
- st.dataframe(sampled_df[["title", "keywords", "rating_avg", "confidence_avg", "site"]])
62
-
63
- # 增加第二种可视化方式
64
- st.header("论文评分分布")
65
- fig_bar = visualize_ratings(sampled_df)
66
- st.plotly_chart(fig_bar, use_container_width=True)
67
 
68
  if __name__ == "__main__":
69
  main()
 
1
  import streamlit as st
2
+ from pages import page1, page2, page3
 
 
3
 
4
  # 设置页面配置
5
  st.set_page_config(layout="wide")
6
 
7
  # 主应用
8
  def main():
9
+ st.sidebar.title("导航")
10
+ pages = {
11
+ "NIPS 论文数据集高斯混合聚类分析": page1,
12
+ "第二个子应用": page2,
13
+ "第三个子应用": page3
14
+ }
15
+
16
+ page = st.sidebar.radio("选择子应用", tuple(pages.keys()))
17
+
18
+ # 根据选择的子应用加载相应的页面
19
+ pages[page].main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  if __name__ == "__main__":
22
  main()
data_processor.py DELETED
@@ -1,25 +0,0 @@
1
- import pandas as pd
2
- import numpy as np
3
-
4
- def load_data():
5
- return pd.read_csv("gmm_point_tracking_with_centroids.csv")
6
-
7
- def process_data(df, iteration, num_samples):
8
- # 随机采样论文
9
- sampled_df = df.sample(n=num_samples, random_state=iteration)
10
-
11
- # 计算每个论文属于各个 cluster 的概率
12
- probabilities = []
13
- for idx, row in sampled_df.iterrows():
14
- prob_str = row["probabilities"].strip("[]")
15
- prob_list = list(map(float, prob_str.split(", ")))
16
- probabilities.append(prob_list)
17
-
18
- # 找到每个论文概率最高的 3 个 cluster
19
- k = 3
20
- hyperedges = {}
21
- for idx, prob in enumerate(probabilities):
22
- top_k = np.argsort(prob)[-k:][::-1]
23
- hyperedges[idx] = [f"Cluster {c}" for c in top_k]
24
-
25
- return sampled_df, probabilities, hyperedges
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gmm_point_tracking_with_centroids.csv CHANGED
The diff for this file is too large to render. See raw diff
 
pages/__pycache__/page1.cpython-311.pyc ADDED
Binary file (5.33 kB). View file
 
pages/__pycache__/page2.cpython-311.pyc ADDED
Binary file (602 Bytes). View file
 
pages/__pycache__/page3.cpython-311.pyc ADDED
Binary file (602 Bytes). View file
 
pages/page1.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+ import streamlit as st
3
+ import pandas as pd
4
+ import numpy as np
5
+ import plotly.express as px
6
+ import hypernetx as hnx
7
+ import matplotlib.pyplot as plt
8
+ from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
9
+ from io import BytesIO
10
+ import time
11
+ from utils.data_processor import load_data, process_data, build_hyperedges
12
+ from utils.visualizer import visualize_gmm, visualize_ratings
13
+ from utils.hypergraph_drawer import draw_hypergraph
14
+
15
+ def main():
16
+ st.title("NIPS 论文数据集高斯混合聚类分析")
17
+
18
+ # 使用 sidebar 控制参数
19
+ with st.sidebar:
20
+ st.header("控制面板")
21
+ autoplay = st.button("自动播放")
22
+ if autoplay:
23
+ for i in range(1, 11):
24
+ with st.spinner(f"迭代 {i}"):
25
+ time.sleep(1)
26
+ st.session_state.iteration = i
27
+ st.session_state.autoplay = False
28
+ st.experimental_rerun()
29
+
30
+ # 添加复选框选择显示 paper 的属性
31
+ display_attribute = st.selectbox(
32
+ "选择显示 paper 的属性",
33
+ ["index", "id", "title", "keywords", "author"]
34
+ )
35
+ # 选择是 top k 还是 top p
36
+ display_option = st.selectbox(
37
+ "选择显示的选项",
38
+ ["Top K Clusters", "Clusters Up To Probability P"]
39
+ )
40
+ # Top K Clusters
41
+ if display_option == "Top K Clusters":
42
+ top_k = st.slider("选择 K 值", min_value=1, max_value=10, value=3, step=1)
43
+ top_p = None
44
+ else:
45
+ top_k = None
46
+ top_p = st.slider("选择 P 值", min_value=0.0, max_value=1.0, value=0.5, step=0.01)
47
+
48
+ # 主页面布局
49
+ if 'autoplay' not in st.session_state:
50
+ st.session_state.autoplay = True
51
+
52
+ if 'iteration' not in st.session_state:
53
+ st.session_state.iteration = 1
54
+
55
+ if st.session_state.autoplay:
56
+ # 隐藏迭代次数滑条
57
+ iteration = st.session_state.iteration
58
+ else:
59
+ # 显示迭代次数滑条
60
+ iteration = st.slider("选择迭代次数", min_value=1, max_value=10, value=st.session_state.iteration, step=1)
61
+
62
+ # 动态限制采样数量的最大值
63
+ df = load_data()
64
+ max_samples = len(df)
65
+ num_samples = st.slider("选择采样论文数量", min_value=1, max_value=min(100, max_samples), value=min(10, max_samples), step=1)
66
+
67
+ # 处理数据
68
+ sampled_df, probabilities, paper_attributes = process_data(df, iteration, num_samples)
69
+ # print(display_attribute) # 字符串
70
+ hyperedges = build_hyperedges(probabilities, paper_attributes, display_attribute, top_k=top_k, top_p=top_p)
71
+ # print(hyperedges)
72
+
73
+ # 并排展示超图和高斯混合分布
74
+ col1, col2 = st.columns(2)
75
+ with col1:
76
+ st.header("超图可视化")
77
+ hypergraph_image = draw_hypergraph(hyperedges)
78
+ st.image(hypergraph_image, caption="超图可视化", use_container_width=True)
79
+
80
+ with col2:
81
+ st.header("高斯混合分布聚类结果")
82
+ fig_gmm = visualize_gmm(sampled_df, iteration)
83
+ st.plotly_chart(fig_gmm, use_container_width=True)
84
+
85
+ # 显示采样论文的详细信息
86
+ st.header("采样论文详细信息")
87
+ st.dataframe(sampled_df[["title", "keywords", "rating_avg", "confidence_avg", "site"]])
88
+
89
+ # 增加第二种可视化方式
90
+ st.header("论文评分分布")
91
+ fig_bar = visualize_ratings(sampled_df)
92
+ st.plotly_chart(fig_bar, use_container_width=True)
pages/page2.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ def main():
4
+ st.title("第二个子应用")
5
+ st.write("这里是第二个子应用的内容。")
pages/page3.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ def main():
4
+ st.title("第三个子应用")
5
+ st.write("这里是第三个子应用的内容。")
utils/__pycache__/data_processor.cpython-311.pyc ADDED
Binary file (2.95 kB). View file
 
utils/__pycache__/hypergraph_drawer.cpython-311.pyc ADDED
Binary file (1.17 kB). View file
 
utils/__pycache__/visualizer.cpython-311.pyc ADDED
Binary file (1.92 kB). View file
 
utils/data_processor.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List
2
+ import pandas as pd
3
+ import numpy as np
4
+
5
+
6
+ def load_data():
7
+ return pd.read_csv("gmm_point_tracking_with_centroids.csv").reset_index()
8
+
9
+
10
+ def process_data(df, iteration, num_samples):
11
+ # 随机采样论文
12
+ sampled_df = df.sample(n=num_samples, random_state=iteration)
13
+
14
+ # 计算每个论文属于各个 cluster 的概率
15
+ probabilities = []
16
+ paper_attributes = []
17
+ for idx, row in sampled_df.iterrows():
18
+ prob_str = row["probabilities"].strip("[]")
19
+ prob_list = list(map(float, prob_str.split(", ")))
20
+ probabilities.append(prob_list)
21
+ paper_attributes.append(
22
+ {
23
+ "id": row["id"],
24
+ "title": row["title"],
25
+ "keywords": row["keywords"],
26
+ "author": row["author"],
27
+ }
28
+ )
29
+
30
+ return sampled_df, probabilities, paper_attributes
31
+
32
+
33
+ def build_hyperedges(
34
+ probabilities,
35
+ paper_attributes: List[Dict[str, str]],
36
+ display_attribute_name: str,
37
+ top_k: int = None,
38
+ top_p: float = None,
39
+ ) -> Dict[str, List[str]]:
40
+ # 构建超图边
41
+ hyperedges: Dict[str, List[str]] = {}
42
+ for idx, (prob, paper_attr) in enumerate(zip(probabilities, paper_attributes)):
43
+ if display_attribute_name == "index":
44
+ display_attribute = f"Paper {idx}"
45
+ display_attribute: str = paper_attr[display_attribute_name]
46
+ if top_k is not None:
47
+ selected_indices = np.argsort(prob)[-top_k:][::-1]
48
+ else:
49
+ # 累加起来,直到第一次大于等于 p
50
+ selected_indices = []
51
+ cumulative_prob = 0.0
52
+ for i, p in enumerate(prob):
53
+ selected_indices.append(i)
54
+ cumulative_prob += p
55
+ if cumulative_prob >= top_p:
56
+ break
57
+
58
+
59
+ for cluster in selected_indices:
60
+ cluster_name: str = f"Cluster {cluster}"
61
+ if cluster_name not in hyperedges:
62
+ hyperedges[cluster_name] = []
63
+ hyperedges[cluster_name].append(display_attribute)
64
+
65
+ return hyperedges
utils/gmm_dataset.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from pathlib import Path
3
+ from scipy.special import gamma
4
+ from typing import Optional, Tuple, Dict, List, Union
5
+ import torch
6
+ import os
7
+
8
+ class GeneralizedGaussianMixture:
9
+ r"""广义高斯混合分布数据集生成器
10
+ P_{\theta_k}(x_i) = \eta_k \exp(-s_k d_k(x_i)) = \frac{p}{2\alpha_k \Gamma(1/p)}\exp(-|\frac{x_i-c_k}{\alpha_k}|^p)
11
+ """
12
+
13
+ def __init__(self,
14
+ D: int = 2, # 维度
15
+ K: int = 3, # 聚类数量
16
+ p: float = 2.0, # 幂次,p=2为标准高斯分布
17
+ centers: Optional[np.ndarray] = None, # 聚类中心
18
+ scales: Optional[np.ndarray] = None, # 尺度参数
19
+ weights: Optional[np.ndarray] = None, # 混合权重
20
+ seed: int = 42): # 随机种子
21
+ """初始化GMM数据集生成器
22
+ Args:
23
+ D: 数据维度
24
+ K: 聚类数量
25
+ p: 幂次参数,控制分布的形状
26
+ centers: 聚类中心,形状为(K, D)
27
+ scales: 尺度参数,形状为(K, D)
28
+ weights: 混合权重,形状为(K,)
29
+ seed: 随机种子
30
+ """
31
+ self.D = D
32
+ self.K = K
33
+ self.p = p
34
+ self.seed = seed
35
+ np.random.seed(seed)
36
+
37
+ # 初始化分布参数
38
+ if centers is None:
39
+ self.centers = np.random.randn(K, D) * 2
40
+ else:
41
+ self.centers = centers
42
+
43
+ if scales is None:
44
+ self.scales = np.random.uniform(0.1, 0.5, size=(K, D))
45
+ else:
46
+ self.scales = scales
47
+
48
+ if weights is None:
49
+ self.weights = np.random.dirichlet(np.ones(K))
50
+ else:
51
+ self.weights = weights / weights.sum() # 确保权重和为1
52
+
53
+ def component_pdf(self, x: np.ndarray, k: int) -> np.ndarray:
54
+ """计算第k个分量的概率密度
55
+ Args:
56
+ x: 输入数据点,形状为(N, D)
57
+ k: 分量索引
58
+ Returns:
59
+ 概率密度值,形状为(N,)
60
+ """
61
+ # 计算归一化常数
62
+ norm_const = self.p / (2 * self.scales[k] * gamma(1/self.p))
63
+
64
+ # 计算|x_i - c_k|^p / α_k^p
65
+ z = np.abs(x - self.centers[k]) / self.scales[k]
66
+ exp_term = np.exp(-np.sum(z**self.p, axis=1))
67
+
68
+ return np.prod(norm_const) * exp_term
69
+
70
+ def pdf(self, x: np.ndarray) -> np.ndarray:
71
+ """计算混合分布的概率密度
72
+ Args:
73
+ x: 输入数据点,形状为(N, D)
74
+ Returns:
75
+ 概率密度值,形状为(N,)
76
+ """
77
+ density = np.zeros(len(x))
78
+ for k in range(self.K):
79
+ density += self.weights[k] * self.component_pdf(x, k)
80
+ return density
81
+
82
+ def generate_component_samples(self, n: int, k: int) -> np.ndarray:
83
+ """从第k个分量生成样本
84
+ Args:
85
+ n: 样本数量
86
+ k: 分量索引
87
+ Returns:
88
+ 样本点,形状为(n, D)
89
+ """
90
+ # 使用幂指数分布的反变换采样
91
+ u = np.random.uniform(-1, 1, size=(n, self.D))
92
+ r = np.abs(u) ** (1/self.p)
93
+ samples = self.centers[k] + self.scales[k] * np.sign(u) * r
94
+ return samples
95
+
96
+ def generate_samples(self, N: int) -> Tuple[np.ndarray, np.ndarray]:
97
+ """生成混合分布的样本
98
+ Args:
99
+ N: 总样本数量
100
+ Returns:
101
+ X: 生成的数据点,形状为(N, D)
102
+ y: 对应的概率密度值,形状为(N,)
103
+ """
104
+ # 根据混合权重确定每个分量的样本数量
105
+ n_samples = np.random.multinomial(N, self.weights)
106
+
107
+ # 从每个分量生成样本
108
+ samples = []
109
+ for k in range(self.K):
110
+ x = self.generate_component_samples(n_samples[k], k)
111
+ samples.append(x)
112
+
113
+ # 合并并打乱样本
114
+ X = np.vstack(samples)
115
+ idx = np.random.permutation(N)
116
+ X = X[idx]
117
+
118
+ # 计算概率密度
119
+ y = self.pdf(X)
120
+
121
+ return X, y
122
+
123
+ def save_dataset(self, save_dir: Union[str, Path], name: str = 'gmm_dataset') -> None:
124
+ """保存数据集到文件
125
+ Args:
126
+ save_dir: 保存目录
127
+ name: 数据集名称
128
+ """
129
+ save_path = Path(save_dir)
130
+ save_path.mkdir(parents=True, exist_ok=True)
131
+
132
+ # 生成并保存数据
133
+ X, y = self.generate_samples(N=1000)
134
+ np.savez(str(save_path / f'{name}.npz'),
135
+ X=X, y=y,
136
+ centers=self.centers,
137
+ scales=self.scales,
138
+ weights=self.weights,
139
+ D=self.D,
140
+ K=self.K,
141
+ p=self.p)
142
+
143
+ @classmethod
144
+ def load_dataset(cls, file_path: Union[str, Path]) -> "GeneralizedGaussianMixture":
145
+ """从文件加载数据集
146
+ Args:
147
+ file_path: 数据文件路径
148
+ Returns:
149
+ 加载的GMM对象
150
+ """
151
+ data = np.load(str(file_path))
152
+ return cls(
153
+ D=int(data['D']),
154
+ K=int(data['K']),
155
+ p=float(data['p']),
156
+ centers=data['centers'],
157
+ scales=data['scales'],
158
+ weights=data['weights']
159
+ )
160
+
161
+ def test_gmm_dataset():
162
+ """测试GMM数据集生成器"""
163
+ # 创建2D的GMM数据集
164
+ gmm = GeneralizedGaussianMixture(
165
+ D=2,
166
+ K=3,
167
+ p=2.0,
168
+ centers=np.array([[-2, -2], [0, 0], [2, 2]]),
169
+ scales=np.array([[0.3, 0.3], [0.2, 0.2], [0.4, 0.4]]),
170
+ weights=np.array([0.3, 0.4, 0.3])
171
+ )
172
+
173
+ # 生成样本
174
+ X, y = gmm.generate_samples(1000)
175
+
176
+ # 保存数据集
177
+ gmm.save_dataset('test_data')
178
+
179
+ # 加载数据集
180
+ loaded_gmm = GeneralizedGaussianMixture.load_dataset('test_data/gmm_dataset.npz')
181
+
182
+ # 验证保存和加载的参数是否一致
183
+ assert np.allclose(gmm.centers, loaded_gmm.centers)
184
+ assert np.allclose(gmm.scales, loaded_gmm.scales)
185
+ assert np.allclose(gmm.weights, loaded_gmm.weights)
186
+
187
+ print("GMM数据集测试通过!")
188
+
189
+ if __name__ == '__main__':
190
+ test_gmm_dataset()
utils/gmm_vis.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def create_gmm_plot(dataset, centers, K, samples=None):
2
+ """创建GMM分布的可视化图形"""
3
+ # 生成网格数据
4
+ x = np.linspace(-5, 5, 100)
5
+ y = np.linspace(-5, 5, 100)
6
+ X, Y = np.meshgrid(x, y)
7
+ xy = np.column_stack((X.ravel(), Y.ravel()))
8
+
9
+ # 计算概率密度
10
+ Z = dataset.pdf(xy).reshape(X.shape)
11
+
12
+ # 创建2D和3D可视化
13
+ fig = make_subplots(
14
+ rows=1, cols=2,
15
+ specs=[[{'type': 'surface'}, {'type': 'contour'}]],
16
+ subplot_titles=('3D概率密度曲面', '等高线图与分量中心')
17
+ )
18
+
19
+ # 3D Surface
20
+ surface = go.Surface(
21
+ x=X, y=Y, z=Z,
22
+ colorscale='viridis',
23
+ showscale=True,
24
+ colorbar=dict(x=0.45)
25
+ )
26
+ fig.add_trace(surface, row=1, col=1)
27
+
28
+ # Contour Plot
29
+ contour = go.Contour(
30
+ x=x, y=y, z=Z,
31
+ colorscale='viridis',
32
+ showscale=True,
33
+ colorbar=dict(x=1.0),
34
+ contours=dict(
35
+ showlabels=True,
36
+ labelfont=dict(size=12)
37
+ )
38
+ )
39
+ fig.add_trace(contour, row=1, col=2)
40
+
41
+ # 添加分量中心点
42
+ fig.add_trace(
43
+ go.Scatter(
44
+ x=centers[:K, 0], y=centers[:K, 1],
45
+ mode='markers+text',
46
+ marker=dict(size=10, color='red'),
47
+ text=[f'C{i+1}' for i in range(K)],
48
+ textposition="top center",
49
+ name='分量中心'
50
+ ),
51
+ row=1, col=2
52
+ )
53
+
54
+ # 添加采样点(如果有)
55
+ if samples is not None:
56
+ fig.add_trace(
57
+ go.Scatter(
58
+ x=samples[:, 0], y=samples[:, 1],
59
+ mode='markers+text',
60
+ marker=dict(
61
+ size=8,
62
+ color='yellow',
63
+ line=dict(color='black', width=1)
64
+ ),
65
+ text=[f'S{i+1}' for i in range(len(samples))],
66
+ textposition="bottom center",
67
+ name='采样点'
68
+ ),
69
+ row=1, col=2
70
+ )
71
+
72
+ # 更新布局
73
+ fig.update_layout(
74
+ title='广义高斯混合分布',
75
+ showlegend=True,
76
+ width=1200,
77
+ height=600,
78
+ scene=dict(
79
+ xaxis_title='X',
80
+ yaxis_title='Y',
81
+ zaxis_title='密度'
82
+ )
83
+ )
84
+
85
+ # 更新2D图的坐标轴
86
+ fig.update_xaxes(title_text='X', row=1, col=2)
87
+ fig.update_yaxes(title_text='Y', row=1, col=2)
88
+
89
+ return fig
hypergraph_drawer.py → utils/hypergraph_drawer.py RENAMED
File without changes
visualizer.py → utils/visualizer.py RENAMED
File without changes