2catycm commited on
Commit
db9ca60
·
1 Parent(s): 2835ddd
__pycache__/data_processor.cpython-311.pyc ADDED
Binary file (1.8 kB). View file
 
__pycache__/hypergraph_drawer.cpython-311.pyc ADDED
Binary file (1.17 kB). View file
 
__pycache__/visualizer.cpython-311.pyc ADDED
Binary file (1.92 kB). View file
 
app.py CHANGED
@@ -1,136 +1,69 @@
1
  import streamlit as st
2
- import pandas as pd
3
- import numpy as np
4
- import plotly.express as px
5
- import hypernetx as hnx
6
- import matplotlib.pyplot as plt
7
- from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
8
- from io import BytesIO
9
- import time
10
 
11
- # 读取数据
12
- df = pd.read_csv("gmm_point_tracking_with_centroids.csv")
13
  st.set_page_config(layout="wide")
14
 
15
- # Streamlit 应用
16
- st.title("高斯混合分布聚类可视化")
17
-
18
- # 设置页面宽度
19
-
20
- # 使用 sidebar 控制参数
21
- with st.sidebar:
22
- st.header("控制面板")
23
- autoplay = st.button("自动播放")
24
- if autoplay:
25
- for i in range(1, 11):
26
- with st.spinner(f"迭代 {i}"):
27
- time.sleep(1)
28
- st.session_state.iteration = i
29
- st.rerun()
30
- st.session_state.autoplay = False
31
- # st.experimental_rerun()
32
-
33
- # 主页面布局
34
- if 'autoplay' not in st.session_state:
35
- st.session_state.autoplay = True
36
-
37
- if 'iteration' not in st.session_state:
38
- st.session_state.iteration = 1
39
-
40
- if st.session_state.autoplay:
41
- # 隐藏迭代次数滑条
42
- iteration = st.session_state.iteration
43
- else:
44
- # 显示迭代次数滑条
45
- iteration = st.slider("选择迭代次数", min_value=1, max_value=10, value=st.session_state.iteration, step=1)
46
-
47
- # 动态限制采样数量的最大值
48
- max_samples = len(df)
49
- num_samples = st.slider("选择采样论文数量", min_value=1, max_value=min(100, max_samples), value=min(10, max_samples), step=1)
50
-
51
- # 随机采样论文
52
- sampled_df = df.sample(n=num_samples, random_state=iteration)
53
-
54
- # 计算每个论文属于各个 cluster 的概率
55
- probabilities = []
56
- for idx, row in sampled_df.iterrows():
57
- prob_str = row["probabilities"].strip("[]")
58
- prob_list = list(map(float, prob_str.split(", ")))
59
- probabilities.append(prob_list)
60
-
61
- # 找到每个论文概率最高的 3 个 cluster
62
- k = 3
63
- hyperedges = {}
64
- for idx, prob in enumerate(probabilities):
65
- top_k = np.argsort(prob)[-k:][::-1]
66
- hyperedges[idx] = [f"Cluster {c}" for c in top_k]
67
-
68
- # 构建超图
69
- H = hnx.Hypergraph(hyperedges)
70
-
71
- # 绘制超图
72
- fig_hnx, ax = plt.subplots(figsize=(12, 8))
73
- hnx.draw(H, ax=ax)
74
-
75
- # 将超图保存为图像
76
- canvas = FigureCanvas(fig_hnx)
77
- buffer = BytesIO()
78
- canvas.print_png(buffer)
79
- buffer.seek(0)
80
-
81
- # 用 Plotly 可视化高斯混合分布
82
- fig_gmm = px.scatter(
83
- sampled_df,
84
- x="x",
85
- y="y",
86
- color="cluster",
87
- hover_data=["title", "keywords", "rating_avg", "confidence_avg", "author", "site"],
88
- title=f"高斯混合分布聚类(迭代 {iteration})",
89
- )
90
-
91
- # 添加聚类中心点
92
- for cluster in sampled_df["cluster"].unique():
93
- centroid_x = sampled_df[sampled_df["cluster"] == cluster]["centroid_x"].iloc[0]
94
- centroid_y = sampled_df[sampled_df["cluster"] == cluster]["centroid_y"].iloc[0]
95
- fig_gmm.add_scatter(
96
- x=[centroid_x],
97
- y=[centroid_y],
98
- mode="markers",
99
- marker=dict(size=15, color="black", symbol="x"),
100
- name=f"Cluster {cluster} Center",
101
- )
102
-
103
- # 并排展示超图和高斯混合分布
104
- col1, col2 = st.columns(2)
105
- col1.header("超图可视化")
106
- col1.image(buffer, caption="超图可视化", use_column_width=True)
107
-
108
- col2.header("高斯混合分布聚类结果")
109
- col2.plotly_chart(fig_gmm, use_container_width=True)
110
-
111
- # 显示采样论文的详细信息
112
- st.header("采样论文详细信息")
113
- st.dataframe(sampled_df[["title", "keywords", "rating_avg", "confidence_avg", "site"]])
114
-
115
- # 增加第二种可视化方式
116
- st.header("论文评分分布")
117
-
118
- # 创建柱状图
119
- fig_bar = px.bar(
120
- sampled_df,
121
- x="title",
122
- y="rating_avg",
123
- color="cluster",
124
- title="论文评分分布",
125
- hover_data=["keywords", "confidence_avg", "author"],
126
- )
127
-
128
- # 调整布局
129
- fig_bar.update_layout(
130
- xaxis_title="论文标题",
131
- yaxis_title="平均评分",
132
- xaxis_tickangle=-45,
133
- )
134
-
135
- # 显示柱状图
136
- st.plotly_chart(fig_bar, use_container_width=True)
 
1
  import streamlit as st
2
+ from data_processor import load_data, process_data
3
+ from visualizer import visualize_gmm, visualize_ratings
4
+ from hypergraph_drawer import draw_hypergraph
 
 
 
 
 
5
 
6
+ # 设置页面配置
 
7
  st.set_page_config(layout="wide")
8
 
9
+ # 主应用
10
+ def main():
11
+ st.title("高斯混合分布聚类可视化")
12
+
13
+ # 使用 sidebar 控制参数
14
+ with st.sidebar:
15
+ st.header("控制面板")
16
+ autoplay = st.button("自动播放")
17
+ if autoplay:
18
+ for i in range(1, 11):
19
+ with st.spinner(f"迭代 {i}"):
20
+ time.sleep(1)
21
+ st.session_state.iteration = i
22
+ st.session_state.autoplay = False
23
+ st.experimental_rerun()
24
+
25
+ # 主页面布局
26
+ if 'autoplay' not in st.session_state:
27
+ st.session_state.autoplay = True
28
+
29
+ if 'iteration' not in st.session_state:
30
+ st.session_state.iteration = 1
31
+
32
+ if st.session_state.autoplay:
33
+ # 隐藏迭代次数滑条
34
+ iteration = st.session_state.iteration
35
+ else:
36
+ # 显示迭代次数滑条
37
+ iteration = st.slider("选择迭代次数", min_value=1, max_value=10, value=st.session_state.iteration, step=1)
38
+
39
+ # 动态限制采样数量的最大值
40
+ df = load_data()
41
+ max_samples = len(df)
42
+ num_samples = st.slider("选择采样论文数量", min_value=1, max_value=min(100, max_samples), value=min(10, max_samples), step=1)
43
+
44
+ # 处理数据
45
+ sampled_df, probabilities, hyperedges = process_data(df, iteration, num_samples)
46
+
47
+ # 并排展示超图和高斯混合分布
48
+ col1, col2 = st.columns(2)
49
+ with col1:
50
+ st.header("超图可视化")
51
+ hypergraph_image = draw_hypergraph(hyperedges)
52
+ st.image(hypergraph_image, caption="超图可视化", use_container_width=True)
53
+
54
+ with col2:
55
+ st.header("高斯混合分布聚类结果")
56
+ fig_gmm = visualize_gmm(sampled_df, iteration)
57
+ st.plotly_chart(fig_gmm, use_container_width=True)
58
+
59
+ # 显示采样论文的详细信息
60
+ st.header("采样论文详细信息")
61
+ st.dataframe(sampled_df[["title", "keywords", "rating_avg", "confidence_avg", "site"]])
62
+
63
+ # 增加第二种可视化方式
64
+ st.header("论文评分分布")
65
+ fig_bar = visualize_ratings(sampled_df)
66
+ st.plotly_chart(fig_bar, use_container_width=True)
67
+
68
+ if __name__ == "__main__":
69
+ main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
data_processor.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+
4
+ def load_data():
5
+ return pd.read_csv("gmm_point_tracking_with_centroids.csv")
6
+
7
+ def process_data(df, iteration, num_samples):
8
+ # 随机采样论文
9
+ sampled_df = df.sample(n=num_samples, random_state=iteration)
10
+
11
+ # 计算每个论文属于各个 cluster 的概率
12
+ probabilities = []
13
+ for idx, row in sampled_df.iterrows():
14
+ prob_str = row["probabilities"].strip("[]")
15
+ prob_list = list(map(float, prob_str.split(", ")))
16
+ probabilities.append(prob_list)
17
+
18
+ # 找到每个论文概率最高的 3 个 cluster
19
+ k = 3
20
+ hyperedges = {}
21
+ for idx, prob in enumerate(probabilities):
22
+ top_k = np.argsort(prob)[-k:][::-1]
23
+ hyperedges[idx] = [f"Cluster {c}" for c in top_k]
24
+
25
+ return sampled_df, probabilities, hyperedges
hypergraph_drawer.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hypernetx as hnx
2
+ import matplotlib.pyplot as plt
3
+ from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
4
+ from io import BytesIO
5
+
6
+ def draw_hypergraph(hyperedges):
7
+ # 构建超图
8
+ H = hnx.Hypergraph(hyperedges)
9
+
10
+ # 绘制超图
11
+ fig, ax = plt.subplots(figsize=(12, 8))
12
+ hnx.draw(H, ax=ax)
13
+
14
+ # 将超图保存为图像
15
+ canvas = FigureCanvas(fig)
16
+ buffer = BytesIO()
17
+ canvas.print_png(buffer)
18
+ buffer.seek(0)
19
+
20
+ return buffer
visualizer.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import plotly.express as px
2
+
3
+ def visualize_gmm(sampled_df, iteration):
4
+ fig = px.scatter(
5
+ sampled_df,
6
+ x="x",
7
+ y="y",
8
+ color="cluster",
9
+ hover_data=["title", "keywords", "rating_avg", "confidence_avg", "author", "site"],
10
+ title=f"高斯混合分布聚类(迭代 {iteration})",
11
+ )
12
+
13
+ # 添加聚类中心点
14
+ for cluster in sampled_df["cluster"].unique():
15
+ centroid_x = sampled_df[sampled_df["cluster"] == cluster]["centroid_x"].iloc[0]
16
+ centroid_y = sampled_df[sampled_df["cluster"] == cluster]["centroid_y"].iloc[0]
17
+ fig.add_scatter(
18
+ x=[centroid_x],
19
+ y=[centroid_y],
20
+ mode="markers",
21
+ marker=dict(size=15, color="black", symbol="x"),
22
+ name=f"Cluster {cluster} Center",
23
+ )
24
+
25
+ return fig
26
+
27
+ def visualize_ratings(sampled_df):
28
+ fig = px.bar(
29
+ sampled_df,
30
+ x="title",
31
+ y="rating_avg",
32
+ color="cluster",
33
+ title="论文评分分布",
34
+ hover_data=["keywords", "confidence_avg", "author"],
35
+ )
36
+
37
+ fig.update_layout(
38
+ xaxis_title="论文标题",
39
+ yaxis_title="平均评分",
40
+ xaxis_tickangle=-45,
41
+ )
42
+
43
+ return fig