HyperPapers / utils /data_processor.py
2catycm's picture
初步结果
044c7d2
import re
from typing import Dict, List

import numpy as np
import pandas as pd
def load_data(path="gmm_point_tracking_with_centroids.csv"):
    """Read the paper table from CSV and expose row positions as a column.

    ``reset_index`` inserts an ``index`` column holding each row's position
    in the file; ``process_data`` reads it back as the paper's ``"index"``.

    Args:
        path: CSV file to read. Defaults to the file name this module has
            always used; a relative path is resolved against the current
            working directory.

    Returns:
        The loaded ``DataFrame`` with the extra leading ``index`` column.
    """
    return pd.read_csv(path).reset_index()
def process_data(df, iteration, num_samples):
# 随机采样论文
sampled_df = df.sample(n=num_samples, random_state=iteration).reset_index()
# 计算每个论文属于各个 cluster 的概率
probabilities = []
paper_attributes = []
for idx, row in sampled_df.iterrows():
prob_str = row["probabilities"].strip("[]")
prob_list = list(map(float, prob_str.split(", ")))
probabilities.append(prob_list)
paper_attributes.append(
{
"order": idx,
"index": row['index'],
"id": row["id"],
"title": row["title"],
"keywords": row["keywords"],
"author": row["author"],
}
)
return sampled_df, probabilities, paper_attributes
def build_hyperedges(
    probabilities,
    paper_attributes: List[Dict[str, str]],
    display_attribute_name: str,
    top_k: int = None,
    top_p: float = None,
) -> Dict[str, List[str]]:
    """Group papers into hyperedges, one per cluster.

    Each paper is attached to the clusters it most likely belongs to, chosen
    either by top-k or by cumulative probability (top-p).

    Args:
        probabilities: one probability sequence per paper; position ``j`` of
            a sequence is the probability for cluster ``j``.
        paper_attributes: one attribute dict per paper, aligned with
            ``probabilities``.
        display_attribute_name: key of each attribute dict used as the node
            label. For ``"index"`` and ``"order"`` the label is
            ``"Paper <value>"``; otherwise the raw value is used.
        top_k: if given, attach each paper to its ``top_k`` most probable
            clusters. Takes precedence over ``top_p``.
        top_p: used when ``top_k`` is None. Clusters are taken in descending
            probability until the running sum exceeds ``top_p + 1e-4``.

    Returns:
        Mapping ``"Cluster <j>"`` -> labels of the papers attached to cluster
        ``j``, in input order.

    Raises:
        ValueError: if both ``top_k`` and ``top_p`` are None.
    """
    if top_k is None and top_p is None:
        raise ValueError("either top_k or top_p must be given")

    # Slack on the top_p threshold (kept from the original implementation).
    top_p_tolerance = 1e-4

    hyperedges: Dict[str, List[str]] = {}
    for prob, paper_attr in zip(probabilities, paper_attributes):
        if display_attribute_name in ("index", "order"):
            display_attribute = f"Paper {paper_attr[display_attribute_name]}"
        else:
            display_attribute = paper_attr[display_attribute_name]

        prob_arr = np.asarray(prob, dtype=float)
        # Cluster ids ordered from most to least probable.
        ranked = np.argsort(prob_arr)[::-1]

        if top_k is not None:
            # Slicing the descending order (rather than argsort()[-top_k:])
            # keeps top_k == 0 from selecting every cluster.
            selected = ranked[:top_k]
        else:
            selected = []
            cumulative_prob = 0.0
            for cluster in ranked:
                # Record the cluster id itself, not its rank in sorted order.
                selected.append(cluster)
                cumulative_prob += prob_arr[cluster]
                if cumulative_prob > top_p + top_p_tolerance:
                    break

        for cluster in selected:
            hyperedges.setdefault(f"Cluster {cluster}", []).append(display_attribute)
    return hyperedges