# evaluation/backend/helpers.py
# iyosha - "Update backend/helpers.py" (156bd1a, verified)
import random
import pandas as pd
import json
def get_current_stage(backend, dataset, stage_splits, threshold=3):
    """Return the global annotation stage (1, 2 or 3).

    A stage counts as complete once every item in it has been annotated by
    at least ``threshold`` distinct ``user_id`` values.

    Args:
        backend: Object exposing ``get_all_rows()``, which returns a pandas
            DataFrame with ``interpretation_id`` and ``user_id`` columns.
        dataset: Indexable collection whose items provide an
            ``"interpretation_id"`` key.
        stage_splits: Mapping with ``"stage1"`` and ``"stage2"`` entries,
            each an iterable of indices into ``dataset``.
        threshold: Minimum number of distinct users per item for the item
            to count as done.

    Returns:
        1 if stage 1 is incomplete, 2 if stage 1 is complete but stage 2 is
        not, 3 if both are complete.
    """
    df = backend.get_all_rows()

    # An empty frame may have no columns at all, in which case groupby would
    # raise KeyError. No rows simply means nothing has been annotated yet.
    if df.empty:
        counts = {}
    else:
        counts = df.groupby("interpretation_id")["user_id"].nunique().to_dict()

    def stage_complete(stage_key):
        # True when every item of the stage reached the annotator threshold.
        return all(
            counts.get(dataset[i]["interpretation_id"], 0) >= threshold
            for i in stage_splits[stage_key]
        )

    if not stage_complete("stage1"):
        return 1
    if not stage_complete("stage2"):
        return 2
    return 3
def get_random_session_samples(
    backend, dataset, stage_splits, user_name, num_samples=30, threshold=3
):
    """Pick a random batch of dataset indices for one user's session.

    Starting at the current global stage, the first stage (up to stage 3)
    that still has eligible items for this user supplies the batch. An item
    is eligible when fewer than ``threshold`` distinct ``user_id`` values
    have annotated it and this user has not annotated it yet.

    Args:
        backend: Object exposing ``get_all_rows()``, which returns a pandas
            DataFrame with ``interpretation_id``, ``user_id`` and
            ``user_name`` columns.
        dataset: Indexable collection whose items provide an
            ``"interpretation_id"`` key.
        stage_splits: Mapping with ``"stage1"``, ``"stage2"`` and
            ``"stage3"`` entries, each a list of indices into ``dataset``.
        user_name: Value matched against the ``user_name`` column to find
            the items this user has already annotated.
        num_samples: Maximum number of indices to return.
        threshold: Number of distinct annotators after which an item is no
            longer handed out. Also passed to ``get_current_stage`` so both
            use the same cut-off.

    Returns:
        A ``(indices, stage)`` tuple. ``indices`` is a list of at most
        ``num_samples`` dataset indices and ``stage`` is the stage they were
        drawn from. ``([], 4)`` means no stage has eligible items left for
        this user.
    """
    df = backend.get_all_rows()

    # Defensive fallback: with no annotations recorded, serve from stage 1.
    if df.empty:
        stage_pool = stage_splits["stage1"]
        return random.sample(stage_pool, min(num_samples, len(stage_pool))), 1

    global_stage = get_current_stage(
        backend, dataset, stage_splits, threshold=threshold
    )
    counts = df.groupby("interpretation_id")["user_id"].nunique().to_dict()
    # NOTE(review): annotator counts use the user_id column, while "already
    # seen" is matched on user_name - confirm both identify the same user.
    seen_ids = set(df.loc[df["user_name"] == user_name, "interpretation_id"])

    # A user who has exhausted the global stage may move on to later stages.
    for stage_num in range(global_stage, 4):
        eligible_indices = []
        for i in stage_splits[f"stage{stage_num}"]:
            interpretation_id = dataset[i]["interpretation_id"]
            if (
                counts.get(interpretation_id, 0) < threshold
                and interpretation_id not in seen_ids
            ):
                eligible_indices.append(i)

        if eligible_indices:
            batch_size = min(num_samples, len(eligible_indices))
            return random.sample(eligible_indices, batch_size), stage_num

    # Nothing left for this user in any remaining stage.
    return [], 4
def generate_stage_splits(
    dataset, k_stage1=100, seed=42, output_path="stage_indices.json"
):
    """Partition dataset indices into three stages and save them as JSON.

    Stage 1 is a random sample of ``k_stage1`` indices. The remaining
    indices are shuffled and cut in half: the first half becomes stage 2 and
    the rest (one more item when the remainder is odd) becomes stage 3.

    Args:
        dataset: Any sized collection; only ``len(dataset)`` is used.
        k_stage1: Number of indices assigned to stage 1.
        seed: Seed for the private random generator, so the same seed
            reproduces the same splits.
        output_path: Path of the JSON file to write.

    Returns:
        The dict written to ``output_path``, with keys ``"stage1"``,
        ``"stage2"`` and ``"stage3"`` mapping to sorted index lists.

    Raises:
        ValueError: If ``k_stage1`` is negative or exceeds ``len(dataset)``.
    """
    total_indices = list(range(len(dataset)))
    if not 0 <= k_stage1 <= len(total_indices):
        raise ValueError(
            f"k_stage1 must be between 0 and {len(total_indices)}, "
            f"got {k_stage1}"
        )

    # Use a private generator: seeding the module-level RNG would make every
    # later random.sample() call in the process deterministic as well.
    rng = random.Random(seed)

    # Stage 1
    stage1 = rng.sample(total_indices, k_stage1)

    # The set-difference order is kept exactly as before so that a given
    # seed keeps producing the same stage 2 / stage 3 membership.
    remaining = list(set(total_indices) - set(stage1))

    # Shuffle remaining and split equally
    rng.shuffle(remaining)
    half = len(remaining) // 2
    stage2 = remaining[:half]
    stage3 = remaining[half:]

    # Validate on the concatenated list: a set would hide duplicates.
    assigned = stage1 + stage2 + stage3
    assert len(assigned) == len(
        total_indices
    ), "❌ Some indices are missing or duplicated!"
    assert set(assigned) == set(
        total_indices
    ), "❌ Index sets do not fully cover dataset!"

    # Save all stages
    stage_splits = {
        "stage1": sorted(stage1),
        "stage2": sorted(stage2),
        "stage3": sorted(stage3),
    }
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(stage_splits, f, indent=2)

    print(f"✅ Saved stage splits to {output_path}")
    print(f"Stage 1: {len(stage1)} samples")
    print(f"Stage 2: {len(stage2)} samples")
    print(f"Stage 3: {len(stage3)} samples")

    return stage_splits