Spaces:
Sleeping
Sleeping
File size: 5,374 Bytes
0f4db48 cbbacc3 0f4db48 cbbacc3 0f4db48 81f3976 0f4db48 81f3976 0f4db48 81f3976 0f4db48 81f3976 0f4db48 81f3976 0f4db48 044c7d2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
from typing import Dict, List
import streamlit as st
import pandas as pd
import numpy as np
import plotly.express as px
import hypernetx as hnx
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from io import BytesIO
import time
import json
from utils.data_processor import process_data, build_hyperedges
from utils.visualizer import visualize_gmm, visualize_ratings
from utils.streamlit_hypergraph import hypergraph_visualization_component
def load_json_data(file_path: str):
    """Load a JSON file and return its decoded contents.

    Args:
        file_path: Path to the JSON file to read.

    Returns:
        The parsed JSON value (its type depends on the file's contents).

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON text is UTF-8 by specification (RFC 8259); relying on the locale's
    # default encoding (e.g. cp936/cp1252 on Windows) can break on non-ASCII
    # content.
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)
def _load_dataset(dataset_selection: str) -> None:
    """Load the hyper-edges and label history of ``dataset_selection`` into session state.

    Raises:
        ValueError: If ``dataset_selection`` is not a known dataset name.
    """
    if dataset_selection == "NeurIPS 2024 Bench":
        json_hyper_edges = load_json_data("hyper_edges_nips.json")
        json_labels_history = load_json_data("labels_history_nips.json")
    elif dataset_selection == "Cora Co-Author":
        json_hyper_edges = load_json_data("hyper_edges.json")
        json_labels_history = load_json_data("labels_history.json")
    else:
        st.error("未知数据集")
        raise ValueError("Unknown dataset selected")
    st.session_state.hyper_edges = json_hyper_edges
    # Keys are normalised to str so the str(iteration) lookup works whether
    # the JSON stores epochs as numbers or as strings.
    st.session_state.labels_history = {
        str(item["epoch"]): item["labels"] for item in json_labels_history
    }
    # Remember which dataset is cached so switching datasets triggers a reload.
    st.session_state.loaded_dataset = dataset_selection


def _sample_authors(dataset_selection: str, num_samples: int, force: bool = False) -> list:
    """Return ``num_samples`` distinct author records from the loaded dataset.

    The drawn indices are cached in session state, so every rerun (in
    particular each autoplay frame) shows the same authors. A new sample is
    drawn only when the dataset or sample size changes, or when ``force`` is set.
    """
    hyper_edges = st.session_state.hyper_edges
    sample_key = (dataset_selection, num_samples)
    if force or st.session_state.get("sample_key") != sample_key:
        st.session_state.sample_indices = np.random.choice(
            len(hyper_edges), size=num_samples, replace=False
        )
        st.session_state.sample_key = sample_key
    return [hyper_edges[i] for i in st.session_state.sample_indices]


def _paper_list_to_types(iteration: int, paper_list: List[int]) -> list:
    """Map each paper index in ``paper_list`` to its label at epoch ``iteration``."""
    # NOTE(review): the element type of a label depends on the labels-history file.
    labels = st.session_state.labels_history[str(iteration)]
    return [labels[paper] for paper in paper_list]


def main():
    """Render the author-paper hypergraph visualisation page.

    The sidebar picks the dataset, drawing size and sample size. A play
    button steps through the stored per-epoch labels (10 to ``slider_max``
    in steps of 10) by rerunning the script once per step.
    """
    with st.sidebar:
        st.header("控制面板")
        dataset_selection = st.selectbox(
            "选择可视化的数据集",
            ["NeurIPS 2024 Bench", "Cora Co-Author"]
        )
    st.title(f"{dataset_selection} Author-Paper 超图可视化分析")

    # Autoplay state.
    slider_max = 200
    if 'play_state' not in st.session_state:
        st.session_state.play_state = False
    if 'iteration_app2' not in st.session_state:
        st.session_state.iteration_app2 = 10

    def toggle_play():
        """Flip play/pause; starting again after a finished run rewinds to step 10."""
        if not st.session_state.play_state and st.session_state.iteration_app2 == slider_max:
            st.session_state.iteration_app2 = 10
        st.session_state.play_state = not st.session_state.play_state

    button_label = "暂停" if st.session_state.play_state else "开始拟合"
    st.button(button_label, on_click=toggle_play, key="play_button")

    slider_iteration = st.slider("迭代步骤", min_value=10, max_value=slider_max,
                                 value=st.session_state.iteration_app2, step=10,
                                 key="iteration_app2_slider")
    if st.session_state.play_state:
        # While playing, the counter advanced by the autoplay loop is authoritative.
        iteration = st.session_state.iteration_app2
    else:
        # While paused, the slider is authoritative. Write it back so playback
        # resumes from where the user left it.
        iteration = slider_iteration
        st.session_state.iteration_app2 = slider_iteration

    # Reload whenever the cached data belongs to a different dataset than the
    # one currently selected.
    if st.session_state.get("loaded_dataset") != dataset_selection:
        _load_dataset(dataset_selection)

    max_samples = len(st.session_state.hyper_edges)
    if max_samples == 0:
        # The sample-size slider below cannot be built with max_value < 1.
        st.warning("当前数据集中没有作者记录")
        return

    with st.sidebar:
        speed = st.slider("拟合速度", min_value=0.1, max_value=2.0, value=1.0, step=0.1, key="speed_slider")
        draw_width = st.slider("绘图宽度", min_value=3, max_value=20, value=6, step=1, key="draw_width")
        draw_height = st.slider("绘图高度", min_value=3, max_value=20, value=6, step=1, key="draw_height")
        num_samples = st.slider("选择采样作者数量", min_value=1,
                                max_value=min(100, max_samples),
                                value=min(10, max_samples), step=1)
        resample = st.button("重新采样", key="resample_button")
        sampled_authors = _sample_authors(dataset_selection, num_samples, force=resample)
        show_labels = st.checkbox("展示分类结果", value=True, key="show_labels")

    # Build the hyper-edges: author -> paper indices, or author -> paper labels
    # at the current epoch when labels are shown. Only the first
    # (author, papers) entry of each record is used.
    sampled_data = {}
    for author_dict in sampled_authors:
        author, papers = next(iter(author_dict.items()))
        sampled_data[author] = _paper_list_to_types(iteration, papers) if show_labels else papers

    show_hypergraph = st.checkbox("显示超图", value=True, key="show_hyperedges")
    if show_hypergraph:
        # Only build the hypergraph when it is actually going to be drawn.
        hypergraph = hnx.Hypergraph(sampled_data)
        hypergraph_visualization_component(hypergraph, draw_width, draw_height)

    # Details of the sampled authors.
    st.header("采样作者详细信息")
    authors_df = pd.DataFrame([(author, len(papers)) for author, papers in sampled_data.items()],
                              columns=["作者", "论文数量"])
    st.dataframe(authors_df)

    # Autoplay: advance one step per rerun until slider_max is reached.
    if st.session_state.play_state:
        with st.spinner("正在播放..."):
            if st.session_state.iteration_app2 < slider_max:
                st.session_state.iteration_app2 += 10
                st.write(f"当前迭代次数: {st.session_state.iteration_app2}")
                time.sleep(1 / speed)  # wait time scales inversely with speed
            else:
                st.session_state.play_state = False
        # Rerun in both cases. After the final step this refreshes the button
        # label, which was already rendered as "暂停" in this run.
        st.rerun()
# Streamlit executes the script with __name__ == "__main__", so the page still
# renders under `streamlit run`, while importing this module has no side effects.
if __name__ == "__main__":
    main()