File size: 5,374 Bytes
0f4db48
cbbacc3
0f4db48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cbbacc3
 
0f4db48
 
 
 
 
 
 
 
 
 
 
 
81f3976
 
0f4db48
 
 
81f3976
 
0f4db48
 
 
 
 
 
 
 
 
 
 
81f3976
 
0f4db48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81f3976
 
0f4db48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81f3976
 
 
0f4db48
 
 
 
044c7d2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
from typing import Dict, List
import streamlit as st
import pandas as pd
import numpy as np
import plotly.express as px
import hypernetx as hnx
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from io import BytesIO
import time
import json
from utils.data_processor import process_data, build_hyperedges
from utils.visualizer import visualize_gmm, visualize_ratings
from utils.streamlit_hypergraph import hypergraph_visualization_component

def load_json_data(file_path: str):
    """Load and return the parsed contents of a JSON file.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The Python object produced by ``json.load`` (typically a dict or list).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Pin the encoding: without it open() falls back to the platform locale,
    # which breaks on non-ASCII content wherever that locale is not UTF-8.
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)




def main():
    """Render the author-paper hypergraph page and drive the playback loop.

    The page lets the user pick a dataset, choose an iteration step, sample
    a number of authors and view the resulting hypergraph. While playback is
    active the stored iteration is advanced by 10 and the script is rerun
    until ``slider_max`` is reached.

    Raises:
        ValueError: If the selected dataset has no known data files.
    """
    # Dataset name -> (hyperedge JSON file, labels-history JSON file).
    dataset_files = {
        "NeurIPS 2024 Bench": ("hyper_edges_nips.json", "labels_history_nips.json"),
        "Cora Co-Author": ("hyper_edges.json", "labels_history.json"),
    }

    with st.sidebar:
        st.header("控制面板")
        dataset_selection = st.selectbox(
            "选择可视化的数据集",
            list(dataset_files),
        )
    st.title(f"{dataset_selection} Author-Paper 超图可视化分析")

    # Playback state, kept in session state so it survives reruns.
    slider_max = 200
    if 'play_state' not in st.session_state:
        st.session_state.play_state = False
    if 'iteration_app2' not in st.session_state:
        st.session_state.iteration_app2 = 10

    def toggle_play():
        """Flip the play/pause flag, restarting from 10 if playback had finished."""
        if not st.session_state.play_state and st.session_state.iteration_app2 == slider_max:
            st.session_state.iteration_app2 = 10  # start over from the first step
        st.session_state.play_state = not st.session_state.play_state

    # Play / pause button; the label reflects the current state.
    if st.session_state.play_state:
        button_label = "暂停"
    else:
        button_label = "开始拟合"
    st.button(button_label, on_click=toggle_play, key="play_button")

    # Iteration slider, seeded from the stored playback position.
    iteration = st.slider("迭代步骤", min_value=10, max_value=slider_max,
                          value=st.session_state.iteration_app2, step=10,
                          key="iteration_app2_slider")

    # Load the JSON data for the selected dataset. Reload whenever the
    # selection changes: checking only for key presence would keep showing
    # whichever dataset was loaded first in this session.
    needs_reload = (
        st.session_state.get("loaded_dataset_app2") != dataset_selection
        or "hyper_edges" not in st.session_state
        or "labels_history" not in st.session_state
    )
    if needs_reload:
        files = dataset_files.get(dataset_selection)
        if files is None:
            st.error("未知数据集")
            raise ValueError("Unknown dataset selected")
        edges_file, labels_file = files

        json_hyper_edges = load_json_data(edges_file)
        json_labels_history: list = load_json_data(labels_file)

        st.session_state.hyper_edges = json_hyper_edges
        # Keys are normalised to str because the lookup below uses
        # str(iteration); integer epochs in the JSON would otherwise miss.
        st.session_state.labels_history = {
            str(item['epoch']): item["labels"]
            for item in json_labels_history
        }
        st.session_state.loaded_dataset_app2 = dataset_selection

    # Sidebar controls.
    with st.sidebar:
        speed = st.slider("拟合速度", min_value=0.1, max_value=2.0, value=1.0, step=0.1, key="speed_slider")
        draw_width = st.slider("绘图宽度", min_value=3, max_value=20, value=6, step=1, key="draw_width")
        draw_height = st.slider("绘图高度", min_value=3, max_value=20, value=6, step=1, key="draw_height")

        max_samples = len(st.session_state.hyper_edges)
        num_samples = st.slider("选择采样作者数量", min_value=1,
                                max_value=min(100, max_samples),
                                value=min(10, max_samples), step=1)

        show_labels = st.checkbox("展示分类结果", value=True, key="show_labels")

    # Sample a subset of authors without replacement. This runs on every
    # script execution, so each rerun (including every playback step) draws
    # a fresh sample.
    sampled_authors = np.random.choice(st.session_state.hyper_edges, size=num_samples, replace=False)

    # Per-paper labels recorded for the selected iteration; only needed when
    # labels are shown.
    epoch_labels = st.session_state.labels_history[str(iteration)] if show_labels else None

    # Each sampled entry is read as a single-item mapping {author: papers};
    # only its first item is used. With labels enabled, each paper index is
    # replaced by its label for the selected iteration.
    sampled_data = {}
    for author_dict in sampled_authors:
        author, papers = next(iter(author_dict.items()))
        if show_labels:
            sampled_data[author] = [epoch_labels[paper] for paper in papers]
        else:
            sampled_data[author] = papers

    # One hyperedge per sampled author.
    hyperedges = sampled_data
    hypergraph = hnx.Hypergraph(hyperedges)

    show_hypergraph = st.checkbox("显示超图", value=True, key="show_hyperedges")

    if show_hypergraph:
        hypergraph_visualization_component(hypergraph, draw_width, draw_height)

    # Table of the sampled authors and how many entries each has.
    st.header("采样作者详细信息")
    authors_df = pd.DataFrame([(author, len(papers)) for author, papers in sampled_data.items()],
                             columns=["作者", "论文数量"])
    st.dataframe(authors_df)

    # Auto-play: advance one step, wait, then rerun the script.
    if st.session_state.play_state:
        with st.spinner("正在播放..."):
            if st.session_state.iteration_app2 < slider_max:
                st.session_state.iteration_app2 += 10
                st.write(f"当前迭代次数: {st.session_state.iteration_app2}")
                time.sleep(1/speed)  # higher speed -> shorter pause between steps
                st.rerun()
            else:
                # Reached the last step: stop playback.
                st.session_state.play_state = False

# Script entry point: build the page whenever this file is executed.
main()