Spaces:
Sleeping
Sleeping
File size: 5,571 Bytes
044c7d2 0f4db48 044c7d2 0f4db48 044c7d2 0f4db48 044c7d2 0f4db48 044c7d2 0f4db48 044c7d2 0f4db48 044c7d2 0f4db48 044c7d2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
from typing import Dict
import streamlit as st
import pandas as pd
import numpy as np
import plotly.express as px
import hypernetx as hnx
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from io import BytesIO
import time
from utils.data_processor import load_data, process_data, build_hyperedges
from utils.visualizer import visualize_gmm, visualize_ratings
from utils.streamlit_hypergraph import hypergraph_visualization_component
def main():
    """Render the NeurIPS 2024 benchmark-paper Gaussian-mixture clustering page.

    One call renders one Streamlit script run:

    * a play/pause button and an iteration slider (1..``slider_max``) whose
      value is passed to ``process_data`` and ``visualize_gmm``;
    * sidebar controls for playback speed, figure size, number of sampled
      papers, the paper attribute to display (forwarded to
      ``build_hyperedges``), and a top-K / top-P cluster selection rule;
    * an optional hypergraph view, an optional GMM plot, and a table of the
      sampled papers.

    Autoplay state lives in ``st.session_state`` (``play_state`` and
    ``iteration``). While playing, each run advances ``iteration`` by one,
    sleeps ``1 / speed`` seconds and calls ``st.rerun()``; playback stops
    once the last iteration has been reached.
    """
    slider_max = 10  # last iteration reachable via the slider / autoplay

    # --- Autoplay state (persists across reruns) ---------------------------
    # The iteration slider's minimum is 1, so the counter must start at 1:
    # seeding it with 0 makes st.slider reject its default value.
    if "play_state" not in st.session_state:
        st.session_state.play_state = False
    if "iteration" not in st.session_state:
        st.session_state.iteration = 1

    # If the previous run advanced to the final iteration, stop playback now,
    # *before* the button is drawn. Otherwise the button keeps its "pause"
    # label while nothing is playing, and clicking it restarts playback.
    if st.session_state.play_state and st.session_state.iteration >= slider_max:
        st.session_state.play_state = False

    def toggle_play():
        """Flip autoplay on/off; starting again from the end rewinds to iteration 1."""
        if not st.session_state.play_state and st.session_state.iteration >= slider_max:
            st.session_state.iteration = 1
        st.session_state.play_state = not st.session_state.play_state

    button_label = "暂停" if st.session_state.play_state else "开始拟合"
    st.button(button_label, on_click=toggle_play, key="play_button")

    # The slider's default follows the autoplay counter; the intent is that
    # each autoplay step moves the slider to the new iteration.
    iteration = st.slider(
        "迭代步骤",
        min_value=1,
        max_value=slider_max,
        value=st.session_state.iteration,
        step=1,
        key="iteration_slider",
    )

    df = load_data()
    max_samples = len(df)
    if max_samples == 0:
        # The sample-count slider below needs max_value >= min_value (1).
        st.warning("未加载到论文数据。")
        return

    # --- Sidebar controls -----------------------------------------------------
    with st.sidebar:
        st.header("控制面板")
        speed = st.slider("拟合速度", min_value=0.1, max_value=2.0, value=1.0, step=0.1, key="speed_slider")
        draw_width = st.slider("绘图宽度", min_value=3, max_value=20, value=6, step=1, key="draw_width")
        draw_height = st.slider("绘图高度", min_value=3, max_value=20, value=6, step=1, key="draw_height")
        num_samples = st.slider(
            "选择采样论文数量",
            min_value=1,
            max_value=min(100, max_samples),
            value=min(10, max_samples),
            step=1,
        )
        display_attribute = st.selectbox(
            "选择显示 paper 的属性",
            ["order", "index", "id", "title", "keywords", "author"],
        )
        display_option = st.selectbox(
            "选择显示的选项",
            ["Top K Clusters", "Clusters Up To Probability P"],
        )
        # Exactly one of top_k / top_p is set; the other is passed as None.
        if display_option == "Top K Clusters":
            max_k = 5
            top_k = st.slider("选择 K 值", min_value=1, max_value=max_k, value=1, step=1)
            top_p = None
        else:
            top_k = None
            top_p = st.slider("选择 P 值", min_value=0.0, max_value=1.0, value=0.5, step=0.01)

    # --- Data processing ------------------------------------------------------
    sampled_df, probabilities, paper_attributes = process_data(df, iteration, num_samples)
    hyperedges = build_hyperedges(
        probabilities, paper_attributes, display_attribute, top_k=top_k, top_p=top_p
    )
    hypergraph = hnx.Hypergraph(hyperedges)

    # --- Visualisations -------------------------------------------------------
    show_hypergraph = st.checkbox("显示超图", value=True, key="show_hyperedges")
    show_gaussian = st.checkbox("显示高斯分布", value=False, key="show_gaussian")
    if show_hypergraph:
        hypergraph_visualization_component(hypergraph, draw_width, draw_height)
    if show_gaussian:
        st.header("高斯混合分布聚类结果")
        fig_gmm = visualize_gmm(sampled_df, iteration)
        st.plotly_chart(fig_gmm, use_container_width=True)

    # Details of the sampled papers.
    st.header("采样论文详细信息")
    st.dataframe(sampled_df[["index", "title", "keywords", "rating_avg", "confidence_avg", "site"]])

    # --- Autoplay step --------------------------------------------------------
    # Advance one iteration, wait according to the chosen speed, then rerun the
    # script so the page is redrawn at the new iteration. Stopping at the end
    # of the range is handled at the top of the next run.
    if st.session_state.play_state and st.session_state.iteration < slider_max:
        with st.spinner("正在播放..."):
            st.session_state.iteration += 1
            st.write(f"当前迭代次数: {st.session_state.iteration}")
            time.sleep(1 / speed)  # speed slider minimum is 0.1, so at most 10 s
        st.rerun()
if __name__ == "__main__":
    # `streamlit run` executes this script as the `__main__` module, so the
    # app still starts there, while a plain import of this module no longer
    # renders the page as a side effect.
    # If a wide layout is wanted, st.set_page_config(layout="wide") must be
    # called here, before main(), because it has to precede every other st.* call.
    main()