import pandas as pd
import numpy as np


def load_data(path="gmm_point_tracking_with_centroids.csv"):
    """Read the point-tracking CSV into a DataFrame.

    Args:
        path: Location of the CSV file. Defaults to the file name this
            module has always read, so existing ``load_data()`` calls
            behave as before.

    Returns:
        The DataFrame produced by ``pandas.read_csv``.
    """
    return pd.read_csv(path)


def _parse_probabilities(raw):
    """Convert one ``probabilities`` cell into a list of floats.

    Accepts the bracketed string form ``"[0.1, 0.2, 0.3]"`` as well as
    variants without a space after the comma or with whitespace-only
    separators (``"[0.1 0.2 0.3]"``). Non-string values are treated as an
    iterable of numbers.
    """
    if not isinstance(raw, str):
        return [float(p) for p in raw]
    # Normalise commas to whitespace so a single split() covers every
    # separator style instead of relying on the exact ", " delimiter.
    body = raw.strip().strip("[]")
    return [float(token) for token in body.replace(",", " ").split()]


def process_data(df, iteration, num_samples, k=3):
    """Sample rows and derive each row's top-``k`` clusters.

    Args:
        df: DataFrame with a ``"probabilities"`` column whose cells hold a
            bracketed list of per-cluster probabilities.
        iteration: Used as ``random_state`` for the sampling, so the same
            iteration number always yields the same sample.
        num_samples: Number of rows to draw (without replacement, the
            ``DataFrame.sample`` default).
        k: How many of the highest-probability clusters to keep per row.
            Defaults to 3.

    Returns:
        A tuple ``(sampled_df, probabilities, hyperedges)`` where
        ``sampled_df`` is the sampled DataFrame, ``probabilities`` is a list
        of float lists in sampled-row order, and ``hyperedges`` maps the
        row's position in the sample (0-based, not the DataFrame index) to
        a list of ``"Cluster <i>"`` labels ordered from most to least
        probable.

    Raises:
        ValueError: If ``k`` is negative, or (from pandas) if
            ``num_samples`` exceeds the number of rows in ``df``.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")

    # Randomly sample rows; seeding with the iteration keeps runs repeatable.
    sampled_df = df.sample(n=num_samples, random_state=iteration)

    # Parse the per-cluster membership probabilities of every sampled row.
    probabilities = [
        _parse_probabilities(cell) for cell in sampled_df["probabilities"]
    ]

    # Keep the k most probable clusters for each row. Reversing before
    # slicing matches ``[-k:][::-1]`` for k >= 1 and, unlike ``[-k:]``,
    # correctly yields nothing when k == 0.
    hyperedges = {}
    for position, prob in enumerate(probabilities):
        top_k = np.argsort(prob)[::-1][:k]
        hyperedges[position] = [f"Cluster {c}" for c in top_k]

    return sampled_df, probabilities, hyperedges