2catycm commited on
Commit
2835ddd
·
1 Parent(s): 1d0271a

feat: init 2

Browse files
Files changed (1) hide show
  1. app.py +68 -15
app.py CHANGED
@@ -2,36 +2,84 @@ import streamlit as st
2
  import pandas as pd
3
  import numpy as np
4
  import plotly.express as px
 
 
 
 
5
  import time
6
 
7
  # 读取数据
8
  df = pd.read_csv("gmm_point_tracking_with_centroids.csv")
 
9
 
10
  # Streamlit 应用
11
  st.title("高斯混合分布聚类可视化")
12
 
 
 
13
  # 使用 sidebar 控制参数
14
  with st.sidebar:
15
  st.header("控制面板")
16
- iteration = st.slider("选择迭代次数", min_value=1, max_value=10, value=1, step=1)
17
- max_samples = len(df)
18
- num_samples = st.slider("选择采样论文数量", min_value=1, max_value=min(100, max_samples), value=min(10, max_samples), step=1)
19
- autoplay = st.checkbox("自动播放", value=False)
20
  if autoplay:
21
  for i in range(1, 11):
22
- iteration = i
23
- st.session_state.iteration = i
24
- time.sleep(1)
25
- st.experimental_rerun()
 
 
26
 
27
  # 主页面布局
28
- st.header("高斯混合分布聚类结果")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  # 随机采样论文
31
  sampled_df = df.sample(n=num_samples, random_state=iteration)
32
 
33
- # Plotly 可视化
34
- fig = px.scatter(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  sampled_df,
36
  x="x",
37
  y="y",
@@ -44,7 +92,7 @@ fig = px.scatter(
44
  for cluster in sampled_df["cluster"].unique():
45
  centroid_x = sampled_df[sampled_df["cluster"] == cluster]["centroid_x"].iloc[0]
46
  centroid_y = sampled_df[sampled_df["cluster"] == cluster]["centroid_y"].iloc[0]
47
- fig.add_scatter(
48
  x=[centroid_x],
49
  y=[centroid_y],
50
  mode="markers",
@@ -52,11 +100,16 @@ for cluster in sampled_df["cluster"].unique():
52
  name=f"Cluster {cluster} Center",
53
  )
54
 
55
- # 让图占比更大
56
- st.plotly_chart(fig, use_container_width=True)
 
 
 
 
 
57
 
58
  # 显示采样论文的详细信息
59
- st.subheader("采样论文详细信息")
60
  st.dataframe(sampled_df[["title", "keywords", "rating_avg", "confidence_avg", "site"]])
61
 
62
  # 增加第二种可视化方式
 
2
  import pandas as pd
3
  import numpy as np
4
  import plotly.express as px
5
+ import hypernetx as hnx
6
+ import matplotlib.pyplot as plt
7
+ from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
8
+ from io import BytesIO
9
  import time
10
 
11
  # 读取数据
12
  df = pd.read_csv("gmm_point_tracking_with_centroids.csv")
13
+ st.set_page_config(layout="wide")
14
 
15
  # Streamlit 应用
16
  st.title("高斯混合分布聚类可视化")
17
 
18
+ # 设置页面宽度
19
+
20
  # 使用 sidebar 控制参数
21
  with st.sidebar:
22
  st.header("控制面板")
23
+ autoplay = st.button("自动播放")
 
 
 
24
  if autoplay:
25
  for i in range(1, 11):
26
+ with st.spinner(f"迭代 {i}"):
27
+ time.sleep(1)
28
+ st.session_state.iteration = i
29
+ st.rerun()
30
+ st.session_state.autoplay = False
31
+ # st.experimental_rerun()
32
 
33
  # 主页面布局
34
+ if 'autoplay' not in st.session_state:
35
+ st.session_state.autoplay = True
36
+
37
+ if 'iteration' not in st.session_state:
38
+ st.session_state.iteration = 1
39
+
40
+ if st.session_state.autoplay:
41
+ # 隐藏迭代次数滑条
42
+ iteration = st.session_state.iteration
43
+ else:
44
+ # 显示迭代次数滑条
45
+ iteration = st.slider("选择迭代次数", min_value=1, max_value=10, value=st.session_state.iteration, step=1)
46
+
47
+ # 动态限制采样数量的最大值
48
+ max_samples = len(df)
49
+ num_samples = st.slider("选择采样论文数量", min_value=1, max_value=min(100, max_samples), value=min(10, max_samples), step=1)
50
 
51
  # 随机采样论文
52
  sampled_df = df.sample(n=num_samples, random_state=iteration)
53
 
54
+ # 计算每个论文属于各个 cluster 的概率
55
+ probabilities = []
56
+ for idx, row in sampled_df.iterrows():
57
+ prob_str = row["probabilities"].strip("[]")
58
+ prob_list = list(map(float, prob_str.split(", ")))
59
+ probabilities.append(prob_list)
60
+
61
+ # 找到每个论文概率最高的 3 个 cluster
62
+ k = 3
63
+ hyperedges = {}
64
+ for idx, prob in enumerate(probabilities):
65
+ top_k = np.argsort(prob)[-k:][::-1]
66
+ hyperedges[idx] = [f"Cluster {c}" for c in top_k]
67
+
68
+ # 构建超图
69
+ H = hnx.Hypergraph(hyperedges)
70
+
71
+ # 绘制超图
72
+ fig_hnx, ax = plt.subplots(figsize=(12, 8))
73
+ hnx.draw(H, ax=ax)
74
+
75
+ # 将超图保存为图像
76
+ canvas = FigureCanvas(fig_hnx)
77
+ buffer = BytesIO()
78
+ canvas.print_png(buffer)
79
+ buffer.seek(0)
80
+
81
+ # 用 Plotly 可视化高斯混合分布
82
+ fig_gmm = px.scatter(
83
  sampled_df,
84
  x="x",
85
  y="y",
 
92
  for cluster in sampled_df["cluster"].unique():
93
  centroid_x = sampled_df[sampled_df["cluster"] == cluster]["centroid_x"].iloc[0]
94
  centroid_y = sampled_df[sampled_df["cluster"] == cluster]["centroid_y"].iloc[0]
95
+ fig_gmm.add_scatter(
96
  x=[centroid_x],
97
  y=[centroid_y],
98
  mode="markers",
 
100
  name=f"Cluster {cluster} Center",
101
  )
102
 
103
+ # 并排展示超图和高斯混合分布
104
+ col1, col2 = st.columns(2)
105
+ col1.header("超图可视化")
106
+ col1.image(buffer, caption="超图可视化", use_column_width=True)
107
+
108
+ col2.header("高斯混合分布聚类结果")
109
+ col2.plotly_chart(fig_gmm, use_container_width=True)
110
 
111
  # 显示采样论文的详细信息
112
+ st.header("采样论文详细信息")
113
  st.dataframe(sampled_df[["title", "keywords", "rating_avg", "confidence_avg", "site"]])
114
 
115
  # 增加第二种可视化方式