from typing import Any, Dict, List, Optional

import numpy as np
import pandas as pd

# CSV read by ``load_data`` when no explicit path is given (relative to the CWD).
DEFAULT_DATA_PATH = "gmm_point_tracking_with_centroids.csv"

# Slack used when comparing the running probability sum against ``top_p`` so
# that float rounding (e.g. 0.5 + 0.3 vs 0.8) does not pull in an extra cluster.
_TOP_P_TOLERANCE = 1e-9


def load_data(path: str = DEFAULT_DATA_PATH) -> pd.DataFrame:
    """Read the clustering results CSV and expose the row position as a column.

    Args:
        path: CSV file to read; defaults to ``DEFAULT_DATA_PATH``.

    Returns:
        The loaded frame after ``reset_index()``, which (for a CSV without an
        existing ``index`` column) adds the original row position as a column
        named ``index``.
    """
    return pd.read_csv(path).reset_index()


def _parse_probabilities(text: str) -> List[float]:
    """Parse a stringified probability vector such as ``"[0.1, 0.9]"``.

    Commas are optional, so NumPy-style reprs like ``"[0.1 0.9]"`` (including
    ones wrapped over several lines) are accepted too; ``"[]"`` yields ``[]``.
    """
    inner = text.strip().strip("[]")
    return [float(token) for token in inner.replace(",", " ").split()]


def process_data(df: pd.DataFrame, iteration: int, num_samples: int):
    """Randomly sample papers and extract their cluster probabilities and metadata.

    Args:
        df: Frame as returned by ``load_data``; it must contain the columns
            ``index``, ``id``, ``title``, ``keywords``, ``author`` and
            ``probabilities`` (the latter a stringified list of floats).
        iteration: Used as the sampling seed, so a given iteration always
            draws the same sample.
        num_samples: Number of rows to draw (``DataFrame.sample`` draws
            without replacement by default).

    Returns:
        ``(sampled_df, probabilities, paper_attributes)``: the sampled frame
        with a fresh 0..n-1 index, one parsed probability vector per sampled
        paper, and one metadata dict per sampled paper. ``order`` in each dict
        is the paper's position within the sample and ``index`` its value of
        the input's ``index`` column.
    """
    # Randomly sample papers; ``iteration`` doubles as the random seed.
    sampled_df = df.sample(n=num_samples, random_state=iteration).reset_index()

    # Per-paper probability of belonging to each cluster, plus its metadata.
    probabilities: List[List[float]] = []
    paper_attributes: List[Dict[str, Any]] = []
    for idx, row in sampled_df.iterrows():
        probabilities.append(_parse_probabilities(row["probabilities"]))
        paper_attributes.append(
            {
                "order": idx,
                "index": row["index"],
                "id": row["id"],
                "title": row["title"],
                "keywords": row["keywords"],
                "author": row["author"],
            }
        )
    return sampled_df, probabilities, paper_attributes


def _select_clusters(
    prob, top_k: Optional[int], top_p: Optional[float]
) -> List[int]:
    """Return the cluster ids selected for one paper, most probable first.

    With ``top_k`` the ``top_k`` most probable clusters are returned. Otherwise
    clusters are taken in decreasing probability until their cumulative
    probability first reaches ``top_p`` (at least one cluster is returned when
    ``prob`` is non-empty).
    """
    prob_arr = np.asarray(prob, dtype=float)
    # Cluster ids ordered by decreasing probability.
    ranked = np.argsort(prob_arr)[::-1]

    if top_k is not None:
        # Slice *after* reversing: ``argsort(...)[-0:]`` would select every
        # cluster for top_k=0, whereas ``ranked[:0]`` correctly selects none.
        return [int(cluster) for cluster in ranked[:top_k]]

    # Accumulate until the running sum first becomes >= top_p.
    selected: List[int] = []
    cumulative_prob = 0.0
    for cluster in ranked:
        selected.append(int(cluster))
        cumulative_prob += prob_arr[cluster]
        if cumulative_prob >= top_p - _TOP_P_TOLERANCE:
            break
    return selected


def build_hyperedges(
    probabilities,
    paper_attributes: List[Dict[str, Any]],
    display_attribute_name: str,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
) -> Dict[str, List[Any]]:
    """Build hypergraph edges that group papers by their most likely clusters.

    Args:
        probabilities: One per-cluster probability vector per paper, in the
            same order as ``paper_attributes``.
        paper_attributes: Metadata dicts, e.g. as produced by ``process_data``.
        display_attribute_name: Key used to label each paper. ``"index"`` and
            ``"order"`` are rendered as ``"Paper <value>"``; any other key's
            value is used as-is.
        top_k: If given, attach each paper to its ``top_k`` most probable
            clusters. Takes precedence over ``top_p``.
        top_p: Otherwise, attach each paper to the smallest prefix of its
            clusters (by decreasing probability) whose cumulative probability
            reaches ``top_p``.

    Returns:
        Mapping ``"Cluster <id>"`` -> list of paper labels, with clusters in
        first-seen order.

    Raises:
        ValueError: If neither ``top_k`` nor ``top_p`` is given, or ``top_k``
            is negative.
    """
    if top_k is None and top_p is None:
        raise ValueError("either top_k or top_p must be provided")
    if top_k is not None and top_k < 0:
        raise ValueError(f"top_k must be non-negative, got {top_k}")

    hyperedges: Dict[str, List[Any]] = {}
    for prob, paper_attr in zip(probabilities, paper_attributes):
        if display_attribute_name in ("index", "order"):
            display_attribute = f"Paper {paper_attr[display_attribute_name]}"
        else:
            display_attribute = paper_attr[display_attribute_name]

        for cluster in _select_clusters(prob, top_k, top_p):
            hyperedges.setdefault(f"Cluster {cluster}", []).append(display_attribute)
    return hyperedges