Spaces:
Sleeping
Sleeping
from typing import Dict, List | |
import streamlit as st | |
import pandas as pd | |
import numpy as np | |
import plotly.express as px | |
import hypernetx as hnx | |
import matplotlib.pyplot as plt | |
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas | |
from io import BytesIO | |
import time | |
import json | |
from utils.data_processor import process_data, build_hyperedges | |
from utils.visualizer import visualize_gmm, visualize_ratings | |
from utils.streamlit_hypergraph import hypergraph_visualization_component | |
def load_json_data(file_path: str):
    """Load and return the parsed contents of a JSON file.

    The file is always decoded as UTF-8 (the encoding JSON interchange
    mandates) instead of the platform's locale encoding, so non-ASCII
    content is read identically on every operating system.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The deserialized JSON value (typically a ``dict`` or ``list``).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)
def _load_dataset_into_session(dataset_selection: str) -> None:
    """Load the JSON files of ``dataset_selection`` into ``st.session_state``.

    Populates ``hyper_edges`` and ``labels_history`` and records which dataset
    they belong to under ``loaded_dataset_app2``.

    Raises:
        ValueError: If ``dataset_selection`` is not a known dataset name.
    """
    # dataset name -> (hyperedge file, label-history file)
    dataset_files = {
        "NeurIPS 2024 Bench": ("hyper_edges_nips.json", "labels_history_nips.json"),
        "Cora Co-Author": ("hyper_edges.json", "labels_history.json"),
    }
    if dataset_selection not in dataset_files:
        st.error("未知数据集")
        raise ValueError("Unknown dataset selected")

    edges_path, labels_path = dataset_files[dataset_selection]
    # Read both files before touching session state so a failed read does not
    # leave a half-updated dataset behind.
    json_hyper_edges = load_json_data(edges_path)
    json_labels_history: list = load_json_data(labels_path)

    st.session_state.hyper_edges = json_hyper_edges
    # Keys are normalised to str because lookups use ``str(iteration)``.
    st.session_state.labels_history = {
        str(item['epoch']): item["labels"]
        for item in json_labels_history
    }
    st.session_state.loaded_dataset_app2 = dataset_selection


def main():
    """Render the author-paper hypergraph page.

    Builds the sidebar controls, loads (or reloads) the selected dataset into
    the session state, samples a subset of authors, draws them as a hypergraph
    and lists them in a table. While the play state is on, the stored
    iteration counter is advanced by 10 per rerun until it reaches the slider
    maximum.

    NOTE(review): the original indentation was not available; widget placement
    relative to the sidebar is reconstructed and should be confirmed.
    """
    with st.sidebar:
        st.header("控制面板")
        dataset_selection = st.selectbox(
            "选择可视化的数据集",
            ["NeurIPS 2024 Bench", "Cora Co-Author"],
        )
    st.title(f"{dataset_selection} Author-Paper 超图可视化分析")

    # Auto-play state, kept across reruns.
    slider_max = 200
    if 'play_state' not in st.session_state:
        st.session_state.play_state = False
    if 'iteration_app2' not in st.session_state:
        st.session_state.iteration_app2 = 10

    def toggle_play():
        """Button callback: flip the play state, restarting from 10 at the end."""
        finished = st.session_state.iteration_app2 == slider_max
        if not st.session_state.play_state and finished:
            st.session_state.iteration_app2 = 10  # reset the iteration counter
        st.session_state.play_state = not st.session_state.play_state

    # Play / pause button.
    button_label = "暂停" if st.session_state.play_state else "开始拟合"
    st.button(button_label, on_click=toggle_play, key="play_button")

    # Iteration-step slider.
    iteration = st.slider("迭代步骤", min_value=10, max_value=slider_max,
                          value=st.session_state.iteration_app2, step=10,
                          key="iteration_app2_slider")

    # (Re)load the data when nothing is cached yet or when the cached data
    # belongs to a different dataset than the one currently selected.
    needs_reload = (
        "hyper_edges" not in st.session_state
        or "labels_history" not in st.session_state
        or st.session_state.get("loaded_dataset_app2") != dataset_selection
    )
    if needs_reload:
        _load_dataset_into_session(dataset_selection)

    max_samples = len(st.session_state.hyper_edges)
    if max_samples == 0:
        # The sample-size slider below needs at least one hyperedge.
        st.warning("数据集为空，无法采样作者")
        return

    # Parameter controls in the sidebar.
    with st.sidebar:
        speed = st.slider("拟合速度", min_value=0.1, max_value=2.0, value=1.0, step=0.1, key="speed_slider")
        draw_width = st.slider("绘图宽度", min_value=3, max_value=20, value=6, step=1, key="draw_width")
        draw_height = st.slider("绘图高度", min_value=3, max_value=20, value=6, step=1, key="draw_height")
        num_samples = st.slider("选择采样作者数量", min_value=1,
                                max_value=min(100, max_samples),
                                value=min(10, max_samples), step=1)

    # Sample a subset of authors without replacement.
    # NOTE(review): this draws a fresh sample on every rerun, including each
    # auto-play step — confirm that is intended.
    sampled_authors = np.random.choice(st.session_state.hyper_edges, size=num_samples, replace=False)
    show_labels = st.checkbox("展示分类结果", value=True, key="show_labels")

    def paper_list_to_types(iteration: int, paper_list: List[int]) -> List[str]:
        """Map each paper index to its label recorded for ``iteration``."""
        labels: List[int] = st.session_state.labels_history[str(iteration)]
        return [labels[paper] for paper in paper_list]

    # Each sampled entry is read as a one-item mapping {author: papers}; only
    # its first item is used. With labels enabled the papers are replaced by
    # their labels at the selected iteration.
    sampled_data: Dict[str, list] = {}
    for author_dict in sampled_authors:
        author, papers = next(iter(author_dict.items()))
        sampled_data[author] = (
            paper_list_to_types(iteration, papers) if show_labels else papers
        )

    # One hyperedge per sampled author.
    hypergraph = hnx.Hypergraph(sampled_data)
    show_hypergraph = st.checkbox("显示超图", value=True, key="show_hyperedges")
    if show_hypergraph:
        hypergraph_visualization_component(hypergraph, draw_width, draw_height)

    # Details of the sampled authors.
    st.header("采样作者详细信息")
    authors_df = pd.DataFrame(
        [(author, len(papers)) for author, papers in sampled_data.items()],
        columns=["作者", "论文数量"],
    )
    st.dataframe(authors_df)

    # Auto-play: advance one step, wait, then trigger a rerun.
    if st.session_state.play_state:
        with st.spinner("正在播放..."):
            if st.session_state.iteration_app2 < slider_max:
                st.session_state.iteration_app2 += 10
                st.write(f"当前迭代次数: {st.session_state.iteration_app2}")
                time.sleep(1 / speed)  # wait time scales inversely with speed
                st.rerun()
            else:
                # Reached the last step: stop playing.
                st.session_state.play_state = False
# Streamlit executes the script as ``__main__``, so the page still renders,
# while a plain ``import`` of this module no longer triggers the UI.
if __name__ == "__main__":
    main()