Spaces:
Sleeping
Sleeping
import random | |
import pandas as pd | |
import json | |
def get_current_stage(backend, dataset, stage_splits, threshold=3):
    """Return the annotation stage (1, 2 or 3) that is currently open.

    A stage counts as complete once every item listed for it in
    ``stage_splits`` has been annotated by at least ``threshold`` distinct
    ``user_id`` values.

    Args:
        backend: Object exposing ``get_all_rows()``, which returns a pandas
            DataFrame with ``interpretation_id`` and ``user_id`` columns.
        dataset: Indexable collection whose items expose
            ``["interpretation_id"]``.
        stage_splits: Mapping with ``"stage1"`` and ``"stage2"`` keys, each a
            collection of indices into ``dataset``.
        threshold: Minimum number of distinct users required per item.

    Returns:
        1 if stage 1 is incomplete, 2 if stage 1 is complete but stage 2 is
        not, otherwise 3.
    """
    df = backend.get_all_rows()

    # An empty frame may carry no columns at all, in which case groupby would
    # raise KeyError; treat it as "nothing annotated yet".
    if df.empty:
        counts = {}
    else:
        counts = df.groupby("interpretation_id")["user_id"].nunique().to_dict()

    def _stage_complete(stage_key):
        # True when every item of the stage has reached the threshold.
        return all(
            counts.get(dataset[i]["interpretation_id"], 0) >= threshold
            for i in stage_splits[stage_key]
        )

    if not _stage_complete("stage1"):
        return 1
    if not _stage_complete("stage2"):
        return 2
    return 3
def get_random_session_samples(
    backend, dataset, stage_splits, user_name, num_samples=30, threshold=3
):
    """Pick a random batch of dataset indices for one annotation session.

    Starting at the globally open stage, the first stage that still has
    eligible items for this user supplies the batch. An item is eligible when
    fewer than ``threshold`` distinct users have annotated it and the user has
    not annotated it already.

    Args:
        backend: Object exposing ``get_all_rows()``, which returns a pandas
            DataFrame with ``interpretation_id``, ``user_id`` and
            ``user_name`` columns.
        dataset: Indexable collection whose items expose
            ``["interpretation_id"]``.
        stage_splits: Mapping with ``"stage1"``, ``"stage2"`` and ``"stage3"``
            keys, each a sequence of indices into ``dataset``.
        user_name: Name used to look up what this user has already annotated.
        num_samples: Upper bound on the number of indices returned.
        threshold: Distinct-user count at which an item stops being offered;
            also forwarded to ``get_current_stage``.

    Returns:
        A ``(indices, stage)`` tuple. ``indices`` is a list of at most
        ``num_samples`` dataset indices and ``stage`` is the stage they were
        drawn from. ``([], 4)`` is returned when no stage from the current one
        up to stage 3 has an eligible item left for this user.
    """
    df = backend.get_all_rows()

    # Defensive fallback: with no annotations at all, everyone starts in
    # stage 1 and every stage-1 item is eligible.
    if df.empty:
        first_pool = stage_splits["stage1"]
        return random.sample(first_pool, min(num_samples, len(first_pool))), 1

    global_stage = get_current_stage(
        backend, dataset, stage_splits, threshold=threshold
    )
    counts = df.groupby("interpretation_id")["user_id"].nunique().to_dict()
    # NOTE(review): coverage is counted per user_id while the seen-set is
    # filtered by user_name — confirm the two columns identify the same user.
    seen_ids = set(df[df["user_name"] == user_name]["interpretation_id"])

    # A user who has exhausted the current stage may move on to later ones.
    for stage_num in range(global_stage, 4):  # stages 1 to 3
        eligible_indices = []
        for i in stage_splits[f"stage{stage_num}"]:
            iid = dataset[i]["interpretation_id"]
            if counts.get(iid, 0) < threshold and iid not in seen_ids:
                eligible_indices.append(i)

        if eligible_indices:
            batch = random.sample(
                eligible_indices, min(num_samples, len(eligible_indices))
            )
            return batch, stage_num

    # Nothing left for this user in any remaining stage.
    return [], 4
def generate_stage_splits(
    dataset, k_stage1=100, seed=42, output_path="stage_indices.json"
):
    """Partition dataset indices into three stages and save them as JSON.

    Stage 1 receives ``k_stage1`` randomly sampled indices; the remaining
    indices are shuffled and split in half between stage 2 and stage 3
    (stage 3 gets the extra index when the remainder is odd).

    Args:
        dataset: Any sized collection; only ``len(dataset)`` is used.
        k_stage1: Number of indices assigned to stage 1.
        seed: Seed for the random generator used for sampling and shuffling.
        output_path: Destination of the JSON file.

    Returns:
        The mapping written to disk: ``"stage1"``, ``"stage2"`` and
        ``"stage3"`` keys, each holding a sorted list of indices.

    Raises:
        ValueError: Raised by ``sample`` if ``k_stage1`` exceeds
            ``len(dataset)``.
    """
    total_indices = list(range(len(dataset)))

    # A private generator yields the same sequence as seeding the module-level
    # one, without reseeding the process-wide RNG that other code draws from.
    rng = random.Random(seed)

    # Stage 1
    stage1 = rng.sample(total_indices, k_stage1)

    # Kept as a set difference so the order fed to the shuffle is unchanged.
    remaining = list(set(total_indices) - set(stage1))

    # Shuffle remaining and split equally
    rng.shuffle(remaining)
    half = len(remaining) // 2
    stage2 = remaining[:half]
    stage3 = remaining[half:]

    # Validate: compare the list length (not a set's) so duplicates are
    # actually detected, then check full coverage.
    combined = stage1 + stage2 + stage3
    assert len(combined) == len(
        total_indices
    ), "❌ Some indices are missing or duplicated!"
    assert set(combined) == set(
        total_indices
    ), "❌ Index sets do not fully cover dataset!"

    # Save all stages
    stage_splits = {
        "stage1": sorted(stage1),
        "stage2": sorted(stage2),
        "stage3": sorted(stage3),
    }
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(stage_splits, f, indent=2)

    print(f"✅ Saved stage splits to {output_path}")
    print(f"Stage 1: {len(stage1)} samples")
    print(f"Stage 2: {len(stage2)} samples")
    print(f"Stage 3: {len(stage3)} samples")

    return stage_splits