feat: top k top p
Files changed:
- app.py +12 -59
- data_processor.py +0 -25
- gmm_point_tracking_with_centroids.csv +0 -0
- pages/__pycache__/page1.cpython-311.pyc +0 -0
- pages/__pycache__/page2.cpython-311.pyc +0 -0
- pages/__pycache__/page3.cpython-311.pyc +0 -0
- pages/page1.py +92 -0
- pages/page2.py +5 -0
- pages/page3.py +5 -0
- utils/__pycache__/data_processor.cpython-311.pyc +0 -0
- utils/__pycache__/hypergraph_drawer.cpython-311.pyc +0 -0
- utils/__pycache__/visualizer.cpython-311.pyc +0 -0
- utils/data_processor.py +65 -0
- utils/gmm_dataset.py +190 -0
- utils/gmm_vis.py +89 -0
- hypergraph_drawer.py → utils/hypergraph_drawer.py +0 -0
- visualizer.py → utils/visualizer.py +0 -0
app.py
CHANGED
@@ -1,69 +1,22 @@
 import streamlit as st
-from data_processor import load_data, process_data
-from visualizer import visualize_gmm, visualize_ratings
-from hypergraph_drawer import draw_hypergraph
+from pages import page1, page2, page3
 
 # Set page configuration
 st.set_page_config(layout="wide")
 
 # Main application
 def main():
-    st.title("NIPS 论文数据集高斯混合聚类分析")
-    # ... (removed lines 12-21 not rendered in this view) ...
-    st.session_state.autoplay = False
-    st.experimental_rerun()
-
-    # Main page layout
-    if 'autoplay' not in st.session_state:
-        st.session_state.autoplay = True
-
-    if 'iteration' not in st.session_state:
-        st.session_state.iteration = 1
-
-    if st.session_state.autoplay:
-        # Hide the iteration slider
-        iteration = st.session_state.iteration
-    else:
-        # Show the iteration slider
-        iteration = st.slider("选择迭代次数", min_value=1, max_value=10, value=st.session_state.iteration, step=1)
-
-    # Dynamically cap the number of papers that can be sampled
-    df = load_data()
-    max_samples = len(df)
-    num_samples = st.slider("选择采样论文数量", min_value=1, max_value=min(100, max_samples), value=min(10, max_samples), step=1)
-
-    # Process the data
-    sampled_df, probabilities, hyperedges = process_data(df, iteration, num_samples)
-
-    # Show the hypergraph and the Gaussian mixture side by side
-    col1, col2 = st.columns(2)
-    with col1:
-        st.header("超图可视化")
-        hypergraph_image = draw_hypergraph(hyperedges)
-        st.image(hypergraph_image, caption="超图可视化", use_container_width=True)
-
-    with col2:
-        st.header("高斯混合分布聚类结果")
-        fig_gmm = visualize_gmm(sampled_df, iteration)
-        st.plotly_chart(fig_gmm, use_container_width=True)
-
-    # Show details of the sampled papers
-    st.header("采样论文详细信息")
-    st.dataframe(sampled_df[["title", "keywords", "rating_avg", "confidence_avg", "site"]])
-
-    # Add a second kind of visualization
-    st.header("论文评分分布")
-    fig_bar = visualize_ratings(sampled_df)
-    st.plotly_chart(fig_bar, use_container_width=True)
+    st.sidebar.title("导航")
+    pages = {
+        "NIPS 论文数据集高斯混合聚类分析": page1,
+        "第二个子应用": page2,
+        "第三个子应用": page3
+    }
+
+    page = st.sidebar.radio("选择子应用", tuple(pages.keys()))
+
+    # Load the page for the selected sub-app
+    pages[page].main()
 
 if __name__ == "__main__":
     main()
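The new app.py is a manual router: each module in pages/ exposes a main(), and a sidebar radio dispatches to the chosen one. One caveat worth noting: recent Streamlit releases auto-register scripts found in a pages/ directory as native multipage entries, which can duplicate the sidebar navigation alongside this hand-rolled router. A minimal sketch of the same dispatch pattern with the sub-apps housed in a hypothetical subapps/ package (an assumption, not what this commit does):

import streamlit as st
# Hypothetical: sub-apps live in subapps/ rather than pages/, which
# recent Streamlit versions would otherwise auto-register as native
# multipage entries.
from subapps import page1, page2, page3

PAGES = {
    "NIPS 论文数据集高斯混合聚类分析": page1,
    "第二个子应用": page2,
    "第三个子应用": page3,
}

def main():
    st.sidebar.title("导航")
    choice = st.sidebar.radio("选择子应用", tuple(PAGES))
    PAGES[choice].main()  # dispatch to the selected sub-app

if __name__ == "__main__":
    main()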
data_processor.py
DELETED
@@ -1,25 +0,0 @@
-import pandas as pd
-import numpy as np
-
-def load_data():
-    return pd.read_csv("gmm_point_tracking_with_centroids.csv")
-
-def process_data(df, iteration, num_samples):
-    # Randomly sample papers
-    sampled_df = df.sample(n=num_samples, random_state=iteration)
-
-    # Parse each paper's probability of belonging to each cluster
-    probabilities = []
-    for idx, row in sampled_df.iterrows():
-        prob_str = row["probabilities"].strip("[]")
-        prob_list = list(map(float, prob_str.split(", ")))
-        probabilities.append(prob_list)
-
-    # Take the 3 highest-probability clusters for each paper
-    k = 3
-    hyperedges = {}
-    for idx, prob in enumerate(probabilities):
-        top_k = np.argsort(prob)[-k:][::-1]
-        hyperedges[idx] = [f"Cluster {c}" for c in top_k]
-
-    return sampled_df, probabilities, hyperedges
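Both this deleted module and its replacement under utils/ parse the stringified probabilities column by stripping brackets and splitting on ", ", which silently fails if the CSV ever serializes the list without a space after the comma. A more forgiving parse, sketched here with ast.literal_eval (an alternative, not what the repo does):

import ast
import numpy as np

def parse_probabilities(prob_str: str) -> np.ndarray:
    """Parse a stringified list such as "[0.1, 0.6, 0.3]" into floats.

    literal_eval handles "[0.1,0.6]" and "[0.1, 0.6]" alike, unlike a
    fixed split(", ").
    """
    return np.asarray(ast.literal_eval(prob_str), dtype=float)

# parse_probabilities("[0.1,0.6, 0.3]") -> array([0.1, 0.6, 0.3])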
gmm_point_tracking_with_centroids.csv
CHANGED
The diff for this file is too large to render.

pages/__pycache__/page1.cpython-311.pyc
ADDED
Binary file (5.33 kB)

pages/__pycache__/page2.cpython-311.pyc
ADDED
Binary file (602 Bytes)

pages/__pycache__/page3.cpython-311.pyc
ADDED
Binary file (602 Bytes)
pages/page1.py
ADDED
@@ -0,0 +1,92 @@
+from typing import Dict
+import streamlit as st
+import pandas as pd
+import numpy as np
+import plotly.express as px
+import hypernetx as hnx
+import matplotlib.pyplot as plt
+from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
+from io import BytesIO
+import time
+from utils.data_processor import load_data, process_data, build_hyperedges
+from utils.visualizer import visualize_gmm, visualize_ratings
+from utils.hypergraph_drawer import draw_hypergraph
+
+def main():
+    st.title("NIPS 论文数据集高斯混合聚类分析")
+
+    # Control the parameters from the sidebar
+    with st.sidebar:
+        st.header("控制面板")
+        autoplay = st.button("自动播放")
+        if autoplay:
+            for i in range(1, 11):
+                with st.spinner(f"迭代 {i}"):
+                    time.sleep(1)
+                st.session_state.iteration = i
+            st.session_state.autoplay = False
+            st.experimental_rerun()
+
+        # Select box for which paper attribute to display
+        display_attribute = st.selectbox(
+            "选择显示 paper 的属性",
+            ["index", "id", "title", "keywords", "author"]
+        )
+        # Choose between top k and top p
+        display_option = st.selectbox(
+            "选择显示的选项",
+            ["Top K Clusters", "Clusters Up To Probability P"]
+        )
+        # Top K Clusters
+        if display_option == "Top K Clusters":
+            top_k = st.slider("选择 K 值", min_value=1, max_value=10, value=3, step=1)
+            top_p = None
+        else:
+            top_k = None
+            top_p = st.slider("选择 P 值", min_value=0.0, max_value=1.0, value=0.5, step=0.01)
+
+    # Main page layout
+    if 'autoplay' not in st.session_state:
+        st.session_state.autoplay = True
+
+    if 'iteration' not in st.session_state:
+        st.session_state.iteration = 1
+
+    if st.session_state.autoplay:
+        # Hide the iteration slider
+        iteration = st.session_state.iteration
+    else:
+        # Show the iteration slider
+        iteration = st.slider("选择迭代次数", min_value=1, max_value=10, value=st.session_state.iteration, step=1)
+
+    # Dynamically cap the number of papers that can be sampled
+    df = load_data()
+    max_samples = len(df)
+    num_samples = st.slider("选择采样论文数量", min_value=1, max_value=min(100, max_samples), value=min(10, max_samples), step=1)
+
+    # Process the data
+    sampled_df, probabilities, paper_attributes = process_data(df, iteration, num_samples)
+    # print(display_attribute)  # a string
+    hyperedges = build_hyperedges(probabilities, paper_attributes, display_attribute, top_k=top_k, top_p=top_p)
+    # print(hyperedges)
+
+    # Show the hypergraph and the Gaussian mixture side by side
+    col1, col2 = st.columns(2)
+    with col1:
+        st.header("超图可视化")
+        hypergraph_image = draw_hypergraph(hyperedges)
+        st.image(hypergraph_image, caption="超图可视化", use_container_width=True)
+
+    with col2:
+        st.header("高斯混合分布聚类结果")
+        fig_gmm = visualize_gmm(sampled_df, iteration)
+        st.plotly_chart(fig_gmm, use_container_width=True)
+
+    # Show details of the sampled papers
+    st.header("采样论文详细信息")
+    st.dataframe(sampled_df[["title", "keywords", "rating_avg", "confidence_avg", "site"]])
+
+    # Add a second kind of visualization
+    st.header("论文评分分布")
+    fig_bar = visualize_ratings(sampled_df)
+    st.plotly_chart(fig_bar, use_container_width=True)
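page1.py drives its autoplay loop with st.experimental_rerun(), which newer Streamlit releases have deprecated in favor of st.rerun(). If the Space is ever bumped to a current Streamlit, the loop would look like this (same logic, only the rerun call renamed):

import time
import streamlit as st

# Same autoplay loop on current Streamlit: st.experimental_rerun() was
# deprecated and renamed to st.rerun().
if st.sidebar.button("自动播放"):
    for i in range(1, 11):
        with st.spinner(f"迭代 {i}"):
            time.sleep(1)
        st.session_state.iteration = i
    st.session_state.autoplay = False
    st.rerun()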
pages/page2.py
ADDED
@@ -0,0 +1,5 @@
+import streamlit as st
+
+def main():
+    st.title("第二个子应用")
+    st.write("这里是第二个子应用的内容。")
pages/page3.py
ADDED
@@ -0,0 +1,5 @@
+import streamlit as st
+
+def main():
+    st.title("第三个子应用")
+    st.write("这里是第三个子应用的内容。")
utils/__pycache__/data_processor.cpython-311.pyc
ADDED
Binary file (2.95 kB)

utils/__pycache__/hypergraph_drawer.cpython-311.pyc
ADDED
Binary file (1.17 kB)

utils/__pycache__/visualizer.cpython-311.pyc
ADDED
Binary file (1.92 kB)
utils/data_processor.py
ADDED
@@ -0,0 +1,65 @@
+from typing import Dict, List
+import pandas as pd
+import numpy as np
+
+
+def load_data():
+    return pd.read_csv("gmm_point_tracking_with_centroids.csv").reset_index()
+
+
+def process_data(df, iteration, num_samples):
+    # Randomly sample papers
+    sampled_df = df.sample(n=num_samples, random_state=iteration)
+
+    # Parse each paper's probability of belonging to each cluster
+    probabilities = []
+    paper_attributes = []
+    for idx, row in sampled_df.iterrows():
+        prob_str = row["probabilities"].strip("[]")
+        prob_list = list(map(float, prob_str.split(", ")))
+        probabilities.append(prob_list)
+        paper_attributes.append(
+            {
+                "id": row["id"],
+                "title": row["title"],
+                "keywords": row["keywords"],
+                "author": row["author"],
+            }
+        )
+
+    return sampled_df, probabilities, paper_attributes
+
+
+def build_hyperedges(
+    probabilities,
+    paper_attributes: List[Dict[str, str]],
+    display_attribute_name: str,
+    top_k: int = None,
+    top_p: float = None,
+) -> Dict[str, List[str]]:
+    # Build the hyperedges: one edge per cluster, holding paper labels
+    hyperedges: Dict[str, List[str]] = {}
+    for idx, (prob, paper_attr) in enumerate(zip(probabilities, paper_attributes)):
+        if display_attribute_name == "index":
+            display_attribute = f"Paper {idx}"
+        else:
+            # otherwise look the attribute up on the paper record
+            # (without this else, "index" would raise a KeyError)
+            display_attribute: str = paper_attr[display_attribute_name]
+        if top_k is not None:
+            # Top-k: the k highest-probability clusters
+            selected_indices = np.argsort(prob)[-top_k:][::-1]
+        else:
+            # Top-p: accumulate (in cluster-index order, not sorted)
+            # until the cumulative probability first reaches p
+            selected_indices = []
+            cumulative_prob = 0.0
+            for i, p in enumerate(prob):
+                selected_indices.append(i)
+                cumulative_prob += p
+                if cumulative_prob >= top_p:
+                    break
+
+        for cluster in selected_indices:
+            cluster_name: str = f"Cluster {cluster}"
+            if cluster_name not in hyperedges:
+                hyperedges[cluster_name] = []
+            hyperedges[cluster_name].append(display_attribute)
+
+    return hyperedges
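build_hyperedges is the heart of this commit: the top-k branch keeps the k most probable clusters, while the top-p branch accumulates probabilities in cluster-index order and stops once the running sum reaches p. Nucleus-style top-p selection normally sorts the probabilities in descending order first, so the smallest sufficient set of clusters is kept. A standalone sketch of both rules with that sorting applied (a variant, not the committed behavior):

import numpy as np

def select_clusters(prob, top_k=None, top_p=None):
    """Return cluster indices by top-k or sorted (nucleus-style) top-p."""
    prob = np.asarray(prob, dtype=float)
    if top_k is not None:
        # Top-k: the k highest-probability clusters, best first.
        return list(np.argsort(prob)[-top_k:][::-1])
    # Top-p: sort descending, then take the smallest prefix whose
    # cumulative probability reaches p.
    order = np.argsort(prob)[::-1]
    cumulative = np.cumsum(prob[order])
    cutoff = int(np.searchsorted(cumulative, top_p)) + 1
    return list(order[:cutoff])

# e.g. prob = [0.1, 0.6, 0.3]:
# select_clusters(prob, top_k=2)   -> [1, 2]
# select_clusters(prob, top_p=0.8) -> [1, 2]  (0.6 + 0.3 >= 0.8)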
utils/gmm_dataset.py
ADDED
@@ -0,0 +1,190 @@
+import numpy as np
+from pathlib import Path
+from scipy.special import gamma
+from typing import Optional, Tuple, Dict, List, Union
+import torch
+import os
+
+class GeneralizedGaussianMixture:
+    r"""Generalized Gaussian mixture dataset generator.
+
+    P_{\theta_k}(x_i) = \eta_k \exp(-s_k d_k(x_i)) = \frac{p}{2\alpha_k \Gamma(1/p)}\exp\left(-\left|\frac{x_i-c_k}{\alpha_k}\right|^p\right)
+    """
+
+    def __init__(self,
+                 D: int = 2,       # dimensionality
+                 K: int = 3,       # number of clusters
+                 p: float = 2.0,   # exponent; p=2 is a standard Gaussian
+                 centers: Optional[np.ndarray] = None,  # cluster centers
+                 scales: Optional[np.ndarray] = None,   # scale parameters
+                 weights: Optional[np.ndarray] = None,  # mixture weights
+                 seed: int = 42):  # random seed
+        """Initialize the GMM dataset generator.
+
+        Args:
+            D: data dimensionality
+            K: number of clusters
+            p: exponent controlling the shape of the distribution
+            centers: cluster centers, shape (K, D)
+            scales: scale parameters, shape (K, D)
+            weights: mixture weights, shape (K,)
+            seed: random seed
+        """
+        self.D = D
+        self.K = K
+        self.p = p
+        self.seed = seed
+        np.random.seed(seed)
+
+        # Initialize the distribution parameters
+        if centers is None:
+            self.centers = np.random.randn(K, D) * 2
+        else:
+            self.centers = centers
+
+        if scales is None:
+            self.scales = np.random.uniform(0.1, 0.5, size=(K, D))
+        else:
+            self.scales = scales
+
+        if weights is None:
+            self.weights = np.random.dirichlet(np.ones(K))
+        else:
+            self.weights = weights / weights.sum()  # make sure the weights sum to 1
+
+    def component_pdf(self, x: np.ndarray, k: int) -> np.ndarray:
+        """Probability density of the k-th component.
+
+        Args:
+            x: input points, shape (N, D)
+            k: component index
+        Returns:
+            density values, shape (N,)
+        """
+        # Normalization constant
+        norm_const = self.p / (2 * self.scales[k] * gamma(1/self.p))
+
+        # Compute |x_i - c_k|^p / α_k^p
+        z = np.abs(x - self.centers[k]) / self.scales[k]
+        exp_term = np.exp(-np.sum(z**self.p, axis=1))
+
+        return np.prod(norm_const) * exp_term
+
+    def pdf(self, x: np.ndarray) -> np.ndarray:
+        """Probability density of the mixture.
+
+        Args:
+            x: input points, shape (N, D)
+        Returns:
+            density values, shape (N,)
+        """
+        density = np.zeros(len(x))
+        for k in range(self.K):
+            density += self.weights[k] * self.component_pdf(x, k)
+        return density
+
+    def generate_component_samples(self, n: int, k: int) -> np.ndarray:
+        """Draw samples from the k-th component.
+
+        Args:
+            n: number of samples
+            k: component index
+        Returns:
+            sample points, shape (n, D)
+        """
+        # Rough inverse-transform draw from the power exponential
+        # distribution (an approximation, not the exact inverse CDF)
+        u = np.random.uniform(-1, 1, size=(n, self.D))
+        r = np.abs(u) ** (1/self.p)
+        samples = self.centers[k] + self.scales[k] * np.sign(u) * r
+        return samples
+
+    def generate_samples(self, N: int) -> Tuple[np.ndarray, np.ndarray]:
+        """Generate samples from the mixture.
+
+        Args:
+            N: total number of samples
+        Returns:
+            X: generated points, shape (N, D)
+            y: corresponding density values, shape (N,)
+        """
+        # Decide each component's sample count from the mixture weights
+        n_samples = np.random.multinomial(N, self.weights)
+
+        # Draw samples from every component
+        samples = []
+        for k in range(self.K):
+            x = self.generate_component_samples(n_samples[k], k)
+            samples.append(x)
+
+        # Stack and shuffle the samples
+        X = np.vstack(samples)
+        idx = np.random.permutation(N)
+        X = X[idx]
+
+        # Evaluate the density
+        y = self.pdf(X)
+
+        return X, y
+
+    def save_dataset(self, save_dir: Union[str, Path], name: str = 'gmm_dataset') -> None:
+        """Save the dataset to disk.
+
+        Args:
+            save_dir: output directory
+            name: dataset name
+        """
+        save_path = Path(save_dir)
+        save_path.mkdir(parents=True, exist_ok=True)
+
+        # Generate and save the data
+        X, y = self.generate_samples(N=1000)
+        np.savez(str(save_path / f'{name}.npz'),
+                 X=X, y=y,
+                 centers=self.centers,
+                 scales=self.scales,
+                 weights=self.weights,
+                 D=self.D,
+                 K=self.K,
+                 p=self.p)
+
+    @classmethod
+    def load_dataset(cls, file_path: Union[str, Path]) -> "GeneralizedGaussianMixture":
+        """Load a dataset from disk.
+
+        Args:
+            file_path: path to the data file
+        Returns:
+            the loaded GMM object
+        """
+        data = np.load(str(file_path))
+        return cls(
+            D=int(data['D']),
+            K=int(data['K']),
+            p=float(data['p']),
+            centers=data['centers'],
+            scales=data['scales'],
+            weights=data['weights']
+        )
+
+def test_gmm_dataset():
+    """Test the GMM dataset generator."""
+    # Create a 2D GMM dataset
+    gmm = GeneralizedGaussianMixture(
+        D=2,
+        K=3,
+        p=2.0,
+        centers=np.array([[-2, -2], [0, 0], [2, 2]]),
+        scales=np.array([[0.3, 0.3], [0.2, 0.2], [0.4, 0.4]]),
+        weights=np.array([0.3, 0.4, 0.3])
+    )
+
+    # Generate samples
+    X, y = gmm.generate_samples(1000)
+
+    # Save the dataset
+    gmm.save_dataset('test_data')
+
+    # Load it back
+    loaded_gmm = GeneralizedGaussianMixture.load_dataset('test_data/gmm_dataset.npz')
+
+    # Check that saved and loaded parameters agree
+    assert np.allclose(gmm.centers, loaded_gmm.centers)
+    assert np.allclose(gmm.scales, loaded_gmm.scales)
+    assert np.allclose(gmm.weights, loaded_gmm.weights)
+
+    print("GMM dataset test passed!")

+if __name__ == '__main__':
+    test_gmm_dataset()
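generate_component_samples transforms uniform draws by |u|^(1/p), which is a rough approximation rather than the exact law of the density in the class docstring. For reference, a generalized Gaussian with density proportional to exp(-|z|^p) can be sampled exactly using the fact that |Z|^p is Gamma(1/p)-distributed; a sketch of that sampler (not part of this commit):

import numpy as np

def sample_generalized_gaussian(n, center, scale, p, seed=None):
    """Exact draws from p/(2*scale*Gamma(1/p)) * exp(-|x-center|^p / scale^p).

    If Z has density proportional to exp(-|z|^p), then |Z|^p follows a
    Gamma(1/p) distribution, so |Z| = G**(1/p) with G ~ Gamma(1/p, 1).
    """
    rng = np.random.default_rng(seed)
    g = rng.gamma(shape=1.0 / p, scale=1.0, size=n)
    sign = rng.choice([-1.0, 1.0], size=n)  # the density is symmetric
    return center + scale * sign * g ** (1.0 / p)

# Sanity check for p=2: the variance should be scale**2 / 2.
x = sample_generalized_gaussian(100_000, center=0.0, scale=1.0, p=2.0, seed=0)
print(x.var())  # ~0.5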
utils/gmm_vis.py
ADDED
@@ -0,0 +1,89 @@
+import numpy as np
+import plotly.graph_objects as go
+from plotly.subplots import make_subplots
+# ^ imports needed by create_gmm_plot (the committed file omits them)
+
+def create_gmm_plot(dataset, centers, K, samples=None):
+    """Create the GMM distribution visualization figure."""
+    # Build the evaluation grid
+    x = np.linspace(-5, 5, 100)
+    y = np.linspace(-5, 5, 100)
+    X, Y = np.meshgrid(x, y)
+    xy = np.column_stack((X.ravel(), Y.ravel()))
+
+    # Evaluate the probability density
+    Z = dataset.pdf(xy).reshape(X.shape)
+
+    # 2D and 3D visualizations side by side
+    fig = make_subplots(
+        rows=1, cols=2,
+        specs=[[{'type': 'surface'}, {'type': 'contour'}]],
+        subplot_titles=('3D概率密度曲面', '等高线图与分量中心')
+    )
+
+    # 3D surface
+    surface = go.Surface(
+        x=X, y=Y, z=Z,
+        colorscale='viridis',
+        showscale=True,
+        colorbar=dict(x=0.45)
+    )
+    fig.add_trace(surface, row=1, col=1)
+
+    # Contour plot
+    contour = go.Contour(
+        x=x, y=y, z=Z,
+        colorscale='viridis',
+        showscale=True,
+        colorbar=dict(x=1.0),
+        contours=dict(
+            showlabels=True,
+            labelfont=dict(size=12)
+        )
+    )
+    fig.add_trace(contour, row=1, col=2)
+
+    # Mark the component centers
+    fig.add_trace(
+        go.Scatter(
+            x=centers[:K, 0], y=centers[:K, 1],
+            mode='markers+text',
+            marker=dict(size=10, color='red'),
+            text=[f'C{i+1}' for i in range(K)],
+            textposition="top center",
+            name='分量中心'
+        ),
+        row=1, col=2
+    )
+
+    # Add the sample points, if any
+    if samples is not None:
+        fig.add_trace(
+            go.Scatter(
+                x=samples[:, 0], y=samples[:, 1],
+                mode='markers+text',
+                marker=dict(
+                    size=8,
+                    color='yellow',
+                    line=dict(color='black', width=1)
+                ),
+                text=[f'S{i+1}' for i in range(len(samples))],
+                textposition="bottom center",
+                name='采样点'
+            ),
+            row=1, col=2
+        )
+
+    # Update the layout
+    fig.update_layout(
+        title='广义高斯混合分布',
+        showlegend=True,
+        width=1200,
+        height=600,
+        scene=dict(
+            xaxis_title='X',
+            yaxis_title='Y',
+            zaxis_title='密度'
+        )
+    )
+
+    # Axis titles on the 2D subplot
+    fig.update_xaxes(title_text='X', row=1, col=2)
+    fig.update_yaxes(title_text='Y', row=1, col=2)
+
+    return fig
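Neither new utils module is imported anywhere else in this commit. A hypothetical wiring of the two, assuming they are meant to be used together as their signatures suggest (not code from the repo):

# Hypothetical wiring (not in this commit): generate a dataset with
# utils.gmm_dataset, then render it with utils.gmm_vis.
import streamlit as st
from utils.gmm_dataset import GeneralizedGaussianMixture
from utils.gmm_vis import create_gmm_plot

gmm = GeneralizedGaussianMixture(D=2, K=3, p=2.0)
X, _ = gmm.generate_samples(200)
fig = create_gmm_plot(gmm, gmm.centers, gmm.K, samples=X[:10])
st.plotly_chart(fig, use_container_width=True)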
hypergraph_drawer.py → utils/hypergraph_drawer.py
RENAMED
File without changes
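hypergraph_drawer.py moves without changes, so its body is not shown in this diff. Since page1.py passes the {cluster: [paper labels]} dict from build_hyperedges to draw_hypergraph and renders the result with st.image, a plausible minimal implementation with hypernetx (which page1.py imports) could look like the following; every detail here is an assumption about the unshown file:

import io
import hypernetx as hnx
import matplotlib.pyplot as plt

def draw_hypergraph(hyperedges):
    """Render an {edge_name: [node_labels]} dict to PNG bytes for st.image."""
    fig, ax = plt.subplots(figsize=(6, 6))
    hnx.draw(hnx.Hypergraph(hyperedges), ax=ax)
    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight")
    plt.close(fig)
    buf.seek(0)
    return buf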
visualizer.py → utils/visualizer.py
RENAMED
File without changes
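visualizer.py likewise moves unchanged. Judging only from the call sites — visualize_gmm(sampled_df, iteration) and visualize_ratings(sampled_df), both rendered via st.plotly_chart — visualize_ratings could be as simple as the Plotly Express bar chart below (purely illustrative; the real file is not shown):

import plotly.express as px

def visualize_ratings(sampled_df):
    """Bar chart of the average rating per sampled paper."""
    fig = px.bar(sampled_df, x="title", y="rating_avg",
                 hover_data=["confidence_avg"])
    fig.update_xaxes(tickangle=45)
    return fig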