Zebra / backend /puzzle_dataset.py
guoj5's picture
Add application file
8e80adf
raw
history blame contribute delete
789 Bytes
# puzzle_dataset.py
import pandas as pd
import os
from dotenv import load_dotenv
from huggingface_hub import login
from deep_convert import deep_convert
# Module-level cache of the dataset; None until init_dataset() populates it.
df = None

# Parquet file for the ZebraLogic grid-mode test split, read via the hf:// filesystem.
_DATASET_URL = "hf://datasets/WildEval/ZebraLogic/grid_mode/test-00000-of-00001.parquet"


def init_dataset():
    """Load the ZebraLogic parquet file into the module-level ``df`` once.

    Loads a ``.env`` file (via ``load_dotenv``), reads ``HF_TOKEN`` from the
    environment and, when a token is set, logs in to the Hugging Face Hub
    before reading the parquet file with ``pandas.read_parquet``.

    Calls after ``df`` has been populated return immediately without doing
    any work.
    """
    global df
    if df is not None:
        return
    load_dotenv()
    hf_token = os.getenv("HF_TOKEN")
    # login(token=None) falls back to an interactive prompt, which would block
    # or fail in a non-interactive backend. Only authenticate when a token is
    # actually configured; otherwise the read proceeds without an explicit login.
    if hf_token:
        login(token=hf_token)
    # Load the parquet data.
    df = pd.read_parquet(_DATASET_URL)
def get_puzzle_by_index(idx: int):
    """Fetch a single puzzle by its positional row index.

    The dataset is loaded on first use via ``init_dataset()``.

    Returns a ``(puzzle_text, expected_solution)`` tuple, where the solution
    is the row's ``solution`` value passed through ``deep_convert``. Returns
    ``(None, None)`` when ``idx`` is negative or not less than the row count.
    """
    if df is None:
        init_dataset()
    row_count = len(df)
    # Reject out-of-range indices explicitly; iloc would otherwise wrap
    # negative values around to the end of the frame.
    if idx < 0 or idx >= row_count:
        return None, None
    record = df.iloc[idx]
    puzzle_text = record["puzzle"]
    expected_solution = deep_convert(record["solution"])
    return puzzle_text, expected_solution