File size: 815 Bytes
db9ca60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import pandas as pd
import numpy as np

def load_data(path="gmm_point_tracking_with_centroids.csv"):
    """Load the point-tracking table from a CSV file.

    Args:
        path: Location of the CSV file. Defaults to the original hard-coded
            ``gmm_point_tracking_with_centroids.csv`` (resolved relative to
            the current working directory), so existing callers are unaffected.

    Returns:
        The ``pandas.DataFrame`` produced by ``pd.read_csv(path)``.
    """
    return pd.read_csv(path)

def _parse_probabilities(value):
    """Convert one ``probabilities`` cell into a list of floats.

    Accepts an already-parsed sequence (list, tuple, ndarray) or its string
    form such as ``"[0.1, 0.7, 0.2]"``. Square brackets, commas, and any
    whitespace (including newlines) are treated as separators. So
    ``"[0.1,0.7]"`` and NumPy-style ``"[0.1 0.7]"`` parse as well, and
    ``"[]"`` yields ``[]``.
    """
    if not isinstance(value, str):
        return [float(p) for p in value]
    cleaned = value.replace("[", " ").replace("]", " ").replace(",", " ")
    return [float(token) for token in cleaned.split()]


def process_data(df, iteration, num_samples, k=3):
    """Sample papers and build top-k cluster hyperedges for each one.

    Args:
        df: DataFrame with a ``probabilities`` column. Each cell holds one
            paper's per-cluster probabilities, either as a sequence or as a
            bracketed string.
        iteration: Used as ``random_state`` for the sampling, so each
            iteration draws a reproducible (and generally different) sample.
        num_samples: Number of rows to draw without replacement. pandas
            raises ``ValueError`` if this exceeds ``len(df)``.
        k: How many highest-probability clusters to keep per paper
            (default 3). If a paper has fewer than ``k`` clusters, all of
            them are kept.

    Returns:
        A tuple ``(sampled_df, probabilities, hyperedges)``:

        - ``probabilities[i]`` is the parsed float list for the i-th sampled
          row.
        - ``hyperedges[i]`` lists ``"Cluster <c>"`` labels for that row's
          top-k clusters, highest probability first.
        - Keys of ``hyperedges`` are positions within the sample (0..n-1),
          not labels from ``df.index``.

    Raises:
        ValueError: If ``k`` is negative or a probability cell cannot be
            parsed as floats.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")

    # Randomly sample papers; the iteration number doubles as the seed.
    sampled_df = df.sample(n=num_samples, random_state=iteration)

    # Parse each paper's probability of belonging to each cluster.
    probabilities = [
        _parse_probabilities(cell) for cell in sampled_df["probabilities"]
    ]

    # Keep each paper's k most probable clusters. Sorting descending and then
    # slicing (rather than [-k:]) makes k == 0 correctly yield no clusters.
    hyperedges = {}
    for idx, prob in enumerate(probabilities):
        top_k = np.argsort(prob)[::-1][:k]
        hyperedges[idx] = [f"Cluster {c}" for c in top_k]

    return sampled_df, probabilities, hyperedges