File size: 5,571 Bytes
044c7d2
 
 
 
 
 
 
 
 
 
 
 
0f4db48
 
 
044c7d2
 
0f4db48
044c7d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f4db48
044c7d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f4db48
 
044c7d2
 
 
 
 
 
 
 
 
 
 
 
 
0f4db48
 
044c7d2
0f4db48
 
044c7d2
0f4db48
 
 
 
044c7d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
from typing import Dict
import streamlit as st
import pandas as pd
import numpy as np
import plotly.express as px
import hypernetx as hnx
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from io import BytesIO
import time
from utils.data_processor import load_data, process_data, build_hyperedges
from utils.visualizer import visualize_gmm, visualize_ratings

from utils.streamlit_hypergraph import hypergraph_visualization_component


def main():
    """Render the Streamlit page for the Gaussian-mixture clustering demo.

    The page consists of:

    * a play/pause button and an iteration slider that can be advanced
      automatically ("auto-play") up to ``slider_max`` steps;
    * a sidebar with playback speed, drawing size, sample count, the paper
      attribute used for labelling, and the cluster selection mode
      (top-K clusters or clusters up to cumulative probability P);
    * an optional hypergraph view and an optional GMM plot;
    * a table with details of the sampled papers.

    Playback state lives in ``st.session_state`` (``play_state`` and
    ``iteration``) so that it survives Streamlit reruns.
    """
    st.title("NeurIPS 2024 Bench Paper 高斯混合聚类分析")

    # Auto-play state, kept in session state so it survives reruns.
    slider_max = 10
    if 'play_state' not in st.session_state:
        st.session_state.play_state = False
    if 'iteration' not in st.session_state:
        st.session_state.iteration = 0

    def toggle_play():
        """Flip between playing and paused; restart from 0 if already at the end."""
        if not st.session_state.play_state and st.session_state.iteration == slider_max:
            st.session_state.iteration = 0
        st.session_state.play_state = not st.session_state.play_state

    # Play / pause button; the label reflects the state at the start of this run.
    button_label = "暂停" if st.session_state.play_state else "开始拟合"
    st.button(button_label, on_click=toggle_play, key="play_button")

    # Iteration slider. The session-state value that drives it starts at 0 and
    # is reset to 0 by toggle_play, so the lower bound has to include 0.
    iteration = st.slider("迭代步骤", min_value=0, max_value=slider_max,
                          value=st.session_state.iteration, step=1,
                          key="iteration_slider")

    df = load_data()

    # The sample-count slider below requires at least one row (min_value=1),
    # so bail out with a message instead of building an invalid widget.
    max_samples = len(df)
    if max_samples == 0:
        st.warning("没有可用的论文数据。")
        return

    # Sidebar: all tunable parameters.
    with st.sidebar:
        st.header("控制面板")
        speed = st.slider("拟合速度", min_value=0.1, max_value=2.0, value=1.0, step=0.1, key="speed_slider")
        draw_width = st.slider("绘图宽度", min_value=3, max_value=20, value=6, step=1, key="draw_width")
        draw_height = st.slider("绘图高度", min_value=3, max_value=20, value=6, step=1, key="draw_height")

        # Cap the number of sampled papers by the dataset size (at most 100).
        num_samples = st.slider("选择采样论文数量", min_value=1,
                                max_value=min(100, max_samples), value=min(10, max_samples), step=1)

        # Which paper attribute is passed to build_hyperedges for display.
        display_attribute = st.selectbox(
            "选择显示 paper 的属性",
            ["order", "index", "id", "title", "keywords", "author"]
        )
        # Cluster selection mode: top-K or cumulative probability P.
        display_option = st.selectbox(
            "选择显示的选项",
            ["Top K Clusters", "Clusters Up To Probability P"]
        )
        if display_option == "Top K Clusters":
            max_k = 5
            top_k = st.slider("选择 K 值", min_value=1, max_value=max_k,
                              value=1, step=1)
            top_p = None
        else:
            top_k = None
            top_p = st.slider("选择 P 值", min_value=0.0, max_value=1.0, value=0.5, step=0.01)

        # Sample the data for the current iteration and build the hypergraph.
        # Exactly one of top_k / top_p is None, depending on the mode above.
        sampled_df, probabilities, paper_attributes = process_data(df, iteration, num_samples)
        hyperedges = build_hyperedges(probabilities, paper_attributes, display_attribute, top_k=top_k, top_p=top_p)
        hypergraph = hnx.Hypergraph(hyperedges)

    show_hypergraph = st.checkbox("显示超图", value=True, key="show_hyperedges")
    show_gaussian = st.checkbox("显示高斯分布", value=False, key="show_gaussian")

    if show_hypergraph:
        hypergraph_visualization_component(hypergraph, draw_width, draw_height)

    if show_gaussian:
        st.header("高斯混合分布聚类结果")
        fig_gmm = visualize_gmm(sampled_df, iteration)
        st.plotly_chart(fig_gmm, use_container_width=True)

    # Details of the sampled papers.
    st.header("采样论文详细信息")
    st.dataframe(sampled_df[["index", "title", "keywords", "rating_avg", "confidence_avg", "site"]])

    # Optional second visualization (rating distribution), currently disabled:
    # st.header("论文评分分布")
    # fig_bar = visualize_ratings(sampled_df)
    # st.plotly_chart(fig_bar, use_container_width=True)

    # Auto-play: advance one iteration per rerun until slider_max is reached.
    if st.session_state.play_state:
        with st.spinner("正在播放..."):
            if st.session_state.iteration < slider_max:
                st.session_state.iteration += 1
                st.write(f"当前迭代次数: {st.session_state.iteration}")
                # Pause between steps; speed is bounded to [0.1, 2.0] by its
                # slider, so the division is safe.
                time.sleep(1 / speed)
                st.rerun()
            else:
                # End of playback. Rerun once more so the button, whose label
                # was computed at the top of this run, switches back from
                # "暂停" to "开始拟合"; otherwise clicking the stale "暂停"
                # button would restart playback from 0.
                st.session_state.play_state = False
                st.rerun()


    
# NOTE: the entry-point guard and the wide page layout below are commented
# out, so main() is called unconditionally whenever this module is executed.
# if __name__ == "__main__":
# # Configure the page layout
#     st.set_page_config(layout="wide")
#     # Run the main function
main()