github_search_visualizations / task_visualizations.py
lambdaofgod's picture
feat: Add separate sliders for all and selected repositories in the PapersWithCode tasks tab
15420a6
raw
history blame
2.99 kB
import pandas as pd
import ast
import json
import plotly.express as px
import plotly.graph_objects as go
class TaskVisualizations:
    """Build sunburst charts of PapersWithCode task counts grouped by area.

    Two merged tables are kept on the instance: one built from
    ``task_counts_path`` and one built from ``selected_task_counts_path``.
    Both are joined against the same task-to-area table.
    """

    def __init__(
        self, task_counts_path, selected_task_counts_path, tasks_with_areas_path
    ):
        """Load and merge both task-count tables with the task/area table.

        Args:
            task_counts_path: CSV with (at least) a ``task`` column holding
                the counts used for the "all" chart.
            selected_task_counts_path: CSV with the same layout holding the
                counts used for the "selected" chart.
            tasks_with_areas_path: CSV with (at least) a ``task`` column that
                is joined against both count tables.
        """
        self.tasks_with_areas_df = self.load_tasks_with_areas_df(
            task_counts_path, tasks_with_areas_path
        )
        self.selected_tasks_with_areas_df = self.load_tasks_with_areas_df(
            selected_task_counts_path, tasks_with_areas_path
        )

    @classmethod
    def load_tasks_with_areas_df(
        cls, task_counts_path, tasks_with_areas_path="data/paperswithcode_tasks.csv"
    ) -> pd.DataFrame:
        """Read both CSVs and join them on the ``task`` column.

        Args:
            task_counts_path: Anything ``pandas.read_csv`` accepts; must
                contain a ``task`` column.
            tasks_with_areas_path: Anything ``pandas.read_csv`` accepts; must
                contain a ``task`` column.

        Returns:
            The task/area table merged with the counts table. ``merge`` is
            used with its default (inner) join, so tasks present in only one
            of the two files are dropped.
        """
        task_counts_df = pd.read_csv(task_counts_path)
        raw_tasks_with_areas_df = pd.read_csv(tasks_with_areas_path)
        return raw_tasks_with_areas_df.merge(task_counts_df, on="task")

    @classmethod
    def get_topk_merge_others(
        cls,
        df: pd.DataFrame,
        by_col,
        val_col,
        k: int = 10,
        val_threshold: float = 1000,
    ) -> pd.DataFrame:
        """Keep the ``k`` largest labels and fold everything else into "other".

        Rows are ranked by ``val_col`` (descending). A label from ``by_col``
        is kept only if it is among the first ``k`` ranked rows *and* its
        value there is at least ``val_threshold``; every other row is
        relabelled ``"other"``. Values are then summed per label.

        Args:
            df: Input table; it is not modified.
            by_col: Column holding the labels.
            val_col: Column holding the numeric values.
            k: Number of top-ranked rows that are candidates for being kept.
            val_threshold: Minimum value a candidate needs to be kept.

        Returns:
            A DataFrame indexed by label (including ``"other"`` when
            anything was folded) with the single column ``val_col`` holding
            the per-label sums.
        """
        # sort_values returns a new frame, so the caller's data stays intact.
        ranked = df.sort_values(val_col, ascending=False)
        top_rows = ranked.iloc[:k]
        # When a label occurs more than once among the top rows, the last
        # (smallest) value wins - the same outcome as building the lookup via
        # set_index(...).to_dict().
        top_values = dict(zip(top_rows[by_col], top_rows[val_col]))
        kept_labels = [
            label for label, value in top_values.items() if value >= val_threshold
        ]
        ranked[by_col] = ranked[by_col].where(
            ranked[by_col].isin(kept_labels), "other"
        )
        # The string "sum" is the supported spelling; passing the builtin
        # `sum` to agg is deprecated in recent pandas releases.
        return ranked.groupby(by_col).agg({val_col: "sum"})

    @classmethod
    def get_displayed_tasks_with_areas_df(
        cls, tasks_with_areas_df: pd.DataFrame, min_task_count: float
    ) -> pd.DataFrame:
        """Prepare the per-area task table that feeds a sunburst chart.

        Rows containing missing values are dropped, tasks whose ``count`` is
        below ``min_task_count`` are renamed ``"other"``, each area is then
        reduced with :meth:`get_topk_merge_others`, and finally the summed
        count is appended to the task label (``"<task> <count>"``).

        NOTE(review): :meth:`get_topk_merge_others` is called with its default
        ``val_threshold`` of 1000, so tasks with a count below 1000 are folded
        into "other" no matter how low ``min_task_count`` is set - confirm
        that this is intended.

        Args:
            tasks_with_areas_df: Table with ``area``, ``task`` and ``count``
                columns; it is not modified.
            min_task_count: Tasks with a smaller count are merged into
                ``"other"`` before the per-area reduction.

        Returns:
            A DataFrame with the columns ``area``, ``task`` (label including
            the count) and ``count``, one row per displayed task per area.
        """
        displayed = tasks_with_areas_df.dropna().copy()
        if displayed.empty:
            # Nothing to group; return a well-formed empty table instead of
            # failing inside the group-wise step.
            return pd.DataFrame(columns=["area", "task", "count"])

        displayed["task"] = displayed["task"].where(
            displayed["count"] >= min_task_count, "other"
        )
        # Selecting the two needed columns keeps the grouping column out of
        # the frames handed to apply, which recent pandas versions require.
        displayed = (
            displayed.groupby("area")[["task", "count"]]
            .apply(lambda group: cls.get_topk_merge_others(group, "task", "count"))
            .reset_index()
        )
        displayed["task"] = displayed["task"] + " " + displayed["count"].astype(str)
        return displayed

    def get_tasks_sunbursts(self, min_task_count_all, min_task_count_selected):
        """Create the sunburst figures for both task tables.

        Args:
            min_task_count_all: Minimum count for a task to be shown by name
                in the chart built from ``self.tasks_with_areas_df``.
            min_task_count_selected: Same, for the chart built from
                ``self.selected_tasks_with_areas_df``.

        Returns:
            A tuple ``(all_sunburst, selected_sunburst)`` of the figures
            returned by ``plotly.express.sunburst``.
        """
        displayed_tasks_all_df = self.get_displayed_tasks_with_areas_df(
            self.tasks_with_areas_df, min_task_count_all
        )
        displayed_tasks_selected_df = self.get_displayed_tasks_with_areas_df(
            self.selected_tasks_with_areas_df, min_task_count_selected
        )
        all_sunburst = px.sunburst(
            displayed_tasks_all_df, path=["area", "task"], values="count"
        )
        selected_sunburst = px.sunburst(
            displayed_tasks_selected_df, path=["area", "task"], values="count"
        )
        return all_sunburst, selected_sunburst