from typing import Dict, List
import streamlit as st
import pandas as pd
import numpy as np
import plotly.express as px
import hypernetx as hnx
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from io import BytesIO
import time
import json
from utils.data_processor import process_data, build_hyperedges
from utils.visualizer import visualize_gmm, visualize_ratings
from utils.streamlit_hypergraph import hypergraph_visualization_component

# Dataset display name -> (hyperedge file, label-history file).
# Insertion order is the order shown in the dataset selectbox.
DATASET_FILES = {
    "NeurIPS 2024 Bench": ("hyper_edges_nips.json", "labels_history_nips.json"),
    "Cora Co-Author": ("hyper_edges.json", "labels_history.json"),
}

# Bounds and step of the iteration slider / auto-play counter.
ITERATION_MIN = 10
ITERATION_MAX = 200
ITERATION_STEP = 10


def load_json_data(file_path: str):
    """Load and return the parsed content of a JSON file.

    The file is read as UTF-8 so the result does not depend on the
    platform's default text encoding.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def main():
    """Render the author-paper hypergraph page.

    Flow: pick a dataset in the sidebar, load its hyperedges and per-epoch
    label history into ``st.session_state`` (reloading when the selected
    dataset changes), sample some authors, build a hypergraph from them and
    display it.  While the play state is on, the iteration counter is
    advanced by one step per rerun until it reaches the slider maximum.
    """
    with st.sidebar:
        st.header("控制面板")
        dataset_selection = st.selectbox(
            "选择可视化的数据集",
            list(DATASET_FILES)
        )

    st.title(f"{dataset_selection} Author-Paper 超图可视化分析")

    # Auto-play state.
    slider_max = ITERATION_MAX
    if 'play_state' not in st.session_state:
        st.session_state.play_state = False
    if 'iteration_app2' not in st.session_state:
        st.session_state.iteration_app2 = ITERATION_MIN

    def toggle_play():
        """Button callback: flip the play state, restarting from the first
        step when playback is started at the final iteration."""
        if not st.session_state.play_state and st.session_state.iteration_app2 == slider_max:
            st.session_state.iteration_app2 = ITERATION_MIN  # reset the iteration counter
        st.session_state.play_state = not st.session_state.play_state

    def sync_iteration():
        """Slider callback: copy a manual slider move into the auto-play
        counter so that playback resumes from the step the user picked."""
        st.session_state.iteration_app2 = st.session_state.iteration_app2_slider

    # Play / pause button.
    if st.session_state.play_state:
        button_label = "暂停"
    else:
        button_label = "开始拟合"
    st.button(button_label, on_click=toggle_play, key="play_button")

    # Iteration slider; its default follows the auto-play counter.
    iteration = st.slider(
        "迭代步骤",
        min_value=ITERATION_MIN,
        max_value=slider_max,
        value=st.session_state.iteration_app2,
        step=ITERATION_STEP,
        key="iteration_app2_slider",
        on_change=sync_iteration,
    )

    # Load the data from JSON files.  The cache is also refreshed when the
    # selected dataset differs from the one the cached data was loaded for;
    # otherwise switching datasets would keep showing the old data.
    needs_load = (
        "hyper_edges" not in st.session_state
        or "labels_history" not in st.session_state
        or st.session_state.get("loaded_dataset_app2") != dataset_selection
    )
    if needs_load:
        if dataset_selection not in DATASET_FILES:
            st.error("未知数据集")
            raise ValueError("Unknown dataset selected")
        edges_path, labels_path = DATASET_FILES[dataset_selection]
        json_hyper_edges = load_json_data(edges_path)
        json_labels_history: list = load_json_data(labels_path)

        st.session_state.hyper_edges = json_hyper_edges
        # Keys are normalised to str because lookups below use
        # str(iteration); this works whether the JSON stores epochs as
        # numbers or as strings.
        st.session_state.labels_history = {
            str(item['epoch']): item["labels"] for item in json_labels_history
        }
        st.session_state.loaded_dataset_app2 = dataset_selection

    # Sidebar controls.
    # NOTE(review): the original indentation was lost, so the exact extent of
    # this sidebar block is reconstructed — confirm which widgets belong here.
    with st.sidebar:
        speed = st.slider("拟合速度", min_value=0.1, max_value=2.0, value=1.0, step=0.1, key="speed_slider")
        draw_width = st.slider("绘图宽度", min_value=3, max_value=20, value=6, step=1, key="draw_width")
        draw_height = st.slider("绘图高度", min_value=3, max_value=20, value=6, step=1, key="draw_height")
        max_samples = len(st.session_state.hyper_edges)
        num_samples = st.slider(
            "选择采样作者数量",
            min_value=1,
            max_value=min(100, max_samples),
            value=min(10, max_samples),
            step=1,
        )

    # Sample a subset of authors (drawn afresh on every script rerun).
    sampled_authors = np.random.choice(st.session_state.hyper_edges, size=num_samples, replace=False)
    show_labels = st.checkbox("展示分类结果", value=True, key="show_labels")

    def paper_list_to_types(iteration: int, paper_list: List[int]) -> List[str]:
        """Map each paper index to its label at the given iteration."""
        labels: List[int] = st.session_state.labels_history[str(iteration)]
        return [labels[paper] for paper in paper_list]

    # Each sampled entry is used as a one-item mapping {author: papers}.
    # With labels shown, the paper list is replaced by the papers' labels.
    sampled_data = {}
    for author_dict in sampled_authors:
        author, papers = next(iter(author_dict.items()))
        if show_labels:
            sampled_data[author] = paper_list_to_types(iteration, papers)
        else:
            sampled_data[author] = papers

    # Build the hyperedges: one hyperedge per sampled author.
    hyperedges = sampled_data
    hypergraph = hnx.Hypergraph(hyperedges)

    show_hypergraph = st.checkbox("显示超图", value=True, key="show_hyperedges")
    if show_hypergraph:
        hypergraph_visualization_component(hypergraph, draw_width, draw_height)

    # Details of the sampled authors.
    st.header("采样作者详细信息")
    authors_df = pd.DataFrame(
        [(author, len(papers)) for author, papers in sampled_data.items()],
        columns=["作者", "论文数量"],
    )
    st.dataframe(authors_df)

    # Auto-play: advance one step, wait, then rerun the script.
    if st.session_state.play_state:
        with st.spinner("正在播放..."):
            if st.session_state.iteration_app2 < slider_max:
                st.session_state.iteration_app2 += ITERATION_STEP
                st.write(f"当前迭代次数: {st.session_state.iteration_app2}")
                time.sleep(1 / speed)  # wait time is inversely proportional to the speed
                st.rerun()
            else:
                # Reached the last iteration: stop playing.
                st.session_state.play_state = False


main()