Spaces:
Sleeping
Sleeping
from typing import Dict | |
import streamlit as st | |
import pandas as pd | |
import numpy as np | |
import plotly.express as px | |
import hypernetx as hnx | |
import matplotlib.pyplot as plt | |
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas | |
from io import BytesIO | |
import time | |
from utils.data_processor import load_data, process_data, build_hyperedges | |
from utils.visualizer import visualize_gmm, visualize_ratings | |
from utils.streamlit_hypergraph import hypergraph_visualization_component | |
def main():
    """Render the Streamlit page of the NeurIPS 2024 paper clustering demo.

    The page is made of a play/pause button, an iteration slider, a sidebar
    with display controls, an optional hypergraph view, an optional
    Gaussian-mixture plot and a table of the sampled papers.

    While playback is active, every script run advances
    ``st.session_state.iteration`` by one, sleeps for ``1 / speed`` seconds
    and triggers a rerun, until ``slider_max`` is reached.
    """
    st.title("NeurIPS 2024 Bench Paper 高斯混合聚类分析")

    # Playback state is kept in session_state so that it survives reruns.
    slider_max = 10
    if 'play_state' not in st.session_state:
        st.session_state.play_state = False
    if 'iteration' not in st.session_state:
        st.session_state.iteration = 0

    def toggle_play():
        """Button callback: flip the playback flag, restarting from 0 at the end."""
        if not st.session_state.play_state and st.session_state.iteration == slider_max:
            # Playback previously ran to the end: start over.
            st.session_state.iteration = 0
        st.session_state.play_state = not st.session_state.play_state

    # Play / pause button; the label reflects the current playback state.
    if st.session_state.play_state:
        button_label = "暂停"
    else:
        button_label = "开始拟合"
    st.button(button_label, on_click=toggle_play, key="play_button")

    # Iteration slider, seeded from the playback counter.
    # NOTE(review): the initial counter value (0) is below min_value (1);
    # this is left exactly as in the original — confirm how st.slider treats
    # an out-of-range default before tightening it.
    iteration = st.slider("迭代步骤", min_value=1, max_value=slider_max,
                          value=st.session_state.iteration, step=1,
                          key="iteration_slider")

    df = load_data()

    # Sidebar: playback speed, figure size, sample count and cluster display.
    with st.sidebar:
        st.header("控制面板")
        speed = st.slider("拟合速度", min_value=0.1, max_value=2.0, value=1.0, step=0.1, key="speed_slider")
        draw_width = st.slider("绘图宽度", min_value=3, max_value=20, value=6, step=1, key="draw_width")
        draw_height = st.slider("绘图高度", min_value=3, max_value=20, value=6, step=1, key="draw_height")

        # The number of sampled papers is capped by the dataset size (at most 100).
        max_samples = len(df)
        num_samples = st.slider("选择采样论文数量", min_value=1,
                                max_value=min(100, max_samples), value=min(10, max_samples), step=1)

        # Which paper attribute is displayed.
        display_attribute = st.selectbox(
            "选择显示 paper 的属性",
            ["order", "index", "id", "title", "keywords", "author"]
        )

        # Cluster selection mode: exactly one of top_k / top_p is set,
        # the other one is passed on as None.
        display_option = st.selectbox(
            "选择显示的选项",
            ["Top K Clusters", "Clusters Up To Probability P"]
        )
        if display_option == "Top K Clusters":
            max_k = 5
            top_k = st.slider("选择 K 值", min_value=1, max_value=max_k,
                              value=1, step=1)
            top_p = None
        else:
            top_k = None
            top_p = st.slider("选择 P 值", min_value=0.0, max_value=1.0, value=0.5, step=0.01)

    # Process the data for the selected iteration and build the hypergraph.
    sampled_df, probabilities, paper_attributes = process_data(df, iteration, num_samples)
    hyperedges = build_hyperedges(probabilities, paper_attributes, display_attribute, top_k=top_k, top_p=top_p)
    hypergraph = hnx.Hypergraph(hyperedges)

    show_hypergraph = st.checkbox("显示超图", value=True, key="show_hyperedges")
    show_gaussian = st.checkbox("显示高斯分布", value=False, key="show_gaussian")

    if show_hypergraph:
        hypergraph_visualization_component(hypergraph, draw_width, draw_height)

    if show_gaussian:
        st.header("高斯混合分布聚类结果")
        fig_gmm = visualize_gmm(sampled_df, iteration)
        st.plotly_chart(fig_gmm, use_container_width=True)

    # Details of the sampled papers.
    st.header("采样论文详细信息")
    st.dataframe(sampled_df[["index", "title", "keywords", "rating_avg", "confidence_avg", "site"]])

    # Autoplay: advance one step per script run while playback is active.
    if st.session_state.play_state:
        with st.spinner("正在播放..."):
            if st.session_state.iteration < slider_max:
                st.session_state.iteration += 1
                if st.session_state.iteration >= slider_max:
                    # Final step: stop playback *before* the rerun so the
                    # last frame already shows the idle button label instead
                    # of a stale "pause" label on a stopped animation.
                    st.session_state.play_state = False
                st.write(f"当前迭代次数: {st.session_state.iteration}")
                # Wait between steps; a higher speed means a shorter pause.
                time.sleep(1 / speed)
                # Rerun the script to redraw the page with the new iteration.
                st.rerun()
            else:
                # Defensive: playing although already at the maximum.
                # Stop and redraw so the button label matches the state.
                st.session_state.play_state = False
                st.rerun()
# NOTE(review): the ``__main__`` guard and the wide-layout page config below
# are commented out in the original and are kept disabled here, so main()
# runs unconditionally whenever this module is executed or imported —
# confirm this is intended before re-enabling them.
# if __name__ == "__main__":
#     # Configure the page layout.
#     st.set_page_config(layout="wide")
#     # Run the main function.
main()