quizbowl-submission / src /hf_datasets_utils.py
Maharshi Gor
Update leaderboard download, refactored hf_datasets_utils
54e2d5b
from huggingface_hub import HfApi, snapshot_download
from loguru import logger
# Module-level Hugging Face Hub client, created with default settings and shared by the functions in this module.
api = HfApi()
def download_dataset_snapshot(repo_id: str, local_dir, space_id: "str | None" = None):
    """
    Download a snapshot of a Hugging Face dataset repository into a local directory.

    Args:
        repo_id: The ID of the dataset repository (e.g., "username/dataset-name")
        local_dir: Directory the snapshot files are written to
        space_id: Optional ID of the Space to restart if the download fails
            (e.g., "username/space-name")

    Returns:
        The local folder path returned by ``snapshot_download``, or ``None`` when the
        download failed and ``space_id`` was restarted instead.

    Raises:
        Exception: The error raised by ``snapshot_download`` when the download fails and
            no ``space_id`` was given.
    """
    logger.info(f"Downloading dataset snapshot from {repo_id} to {local_dir}")
    try:
        return snapshot_download(
            repo_id=repo_id,
            local_dir=local_dir,
            repo_type="dataset",
            tqdm_class=None,
        )
    except Exception as e:
        if space_id is None:
            # restart_space needs a Space id; the dataset repo_id cannot stand in for it,
            # so without a Space to restart the only honest option is to surface the error.
            logger.error(f"Error downloading dataset snapshot from {repo_id} to {local_dir}: {e}")
            raise
        logger.error(
            f"Error downloading dataset snapshot from {repo_id} to {local_dir}: {e}. Restarting space {space_id}."
        )
        api.restart_space(repo_id=space_id)
        return None
def remove_files_from_dataset_repo(repo_id: str, path_patterns: list[str], commit_message: str = "Remove files"):
    """
    Remove files matching glob patterns from a Hugging Face dataset repository.

    Every file path in the repository is matched case-sensitively against each pattern
    with ``fnmatch.fnmatchcase``. Its ``*`` also matches ``/``, so pass e.g. ``"data/*"``
    to remove everything under a directory. Each matching file is deleted in its own
    commit. Failures are logged per file and do not stop the remaining deletions.

    Args:
        repo_id: The ID of the dataset repository (e.g., "username/dataset-name")
        path_patterns: List of glob patterns matched against repository file paths
        commit_message: Message prefix for each delete commit (the path is appended)

    Returns:
        The paths that were successfully removed, in match order.
    """
    import fnmatch

    repo_files = api.list_repo_files(repo_id=repo_id, repo_type="dataset")

    # fnmatchcase: Hub paths are case-sensitive POSIX paths, whereas fnmatch.filter would
    # normalize case on Windows. dict.fromkeys dedupes while keeping order, so a file hit
    # by several patterns is deleted once instead of failing on the second attempt.
    files_to_remove = list(
        dict.fromkeys(
            path for pattern in path_patterns for path in repo_files if fnmatch.fnmatchcase(path, pattern)
        )
    )

    removed = []
    for path in files_to_remove:
        try:
            api.delete_file(
                path_in_repo=path, repo_id=repo_id, repo_type="dataset", commit_message=f"{commit_message}: {path}"
            )
        except Exception as e:
            # Best effort: report and continue with the remaining files.
            logger.error(f"Error removing {path}: {e}")
            continue
        logger.info(f"Successfully removed {path} from {repo_id}")
        removed.append(path)
    return removed
def _split_readme_front_matter(content: str):
    """
    Split README text into its YAML front-matter text and the remaining body.

    Front matter is recognized only when the text starts with a ``---`` line and a later
    line consisting solely of ``---`` closes it. This keeps ``---`` horizontal rules
    inside the card body from being mistaken for metadata.

    Returns:
        ``(yaml_text, body)``:
            - ``yaml_text`` is the text between the two delimiter lines, or ``None`` when
              there is no front matter.
            - ``body`` is everything after the closing delimiter line, or the whole text
              when there is no front matter.
    """
    import re

    match = re.match(r"\A---[ \t]*\r?\n(.*?)^---[ \t]*(?:\r?\n|\Z)", content, re.DOTALL | re.MULTILINE)
    if match is None:
        return None, content
    # match.end() is past the closing line's newline, so re-rendering does not add blank lines.
    return match.group(1), content[match.end() :]


def update_dataset_info_readme(
    repo_id: str,
    dataset_info: dict,
    license_id: str = None,
    commit_message: str = "Update dataset_info in README.md",
):
    """
    Update the dataset_info section in the README.md file of a Hugging Face dataset repository.

    Only the ``dataset_info`` and ``license`` keys of the YAML front matter are replaced.
    All other metadata keys and the card body are preserved. If README.md does not exist,
    a new one containing only the front matter is created.

    The README is left untouched, and an error is logged, when:
        - it cannot be downloaded, or
        - its front matter is not a valid YAML mapping.

    Upload failures are logged, not raised.

    Args:
        repo_id: The ID of the dataset repository (e.g., "username/dataset-name")
        dataset_info: Dictionary containing dataset information to include in the README
        license_id: Optional license identifier (e.g., "mit", "cc-by-4.0")
        commit_message: Message for the commit

    Example dataset_info structure:
        {
            "features": [
                {"name": "text", "dtype": "string"},
                {"name": "label", "dtype": "int64"}
            ],
            "splits": [
                {"name": "train", "num_examples": 10000, "num_bytes": 1000000},
                {"name": "test", "num_examples": 1000, "num_bytes": 100000}
            ],
            "download_size": 1200000,
            "dataset_size": 1100000,
            "configs": [
                {
                    "config_name": "default",
                    "data_files": [
                        {"split": "train", "path": "data/train.csv"},
                        {"split": "test", "path": "data/test.csv"}
                    ]
                }
            ]
        }
    """
    import yaml

    # Fetch the current card. A missing README is the only case where starting from an
    # empty card is safe; any other failure (network, auth) must abort, or the upload
    # below would replace the real README with a metadata-only one.
    try:
        if api.file_exists(repo_id=repo_id, filename="README.md", repo_type="dataset"):
            readme_path = api.hf_hub_download(repo_id=repo_id, repo_type="dataset", filename="README.md")
            with open(readme_path, "r", encoding="utf-8") as f:
                content = f.read()
        else:
            content = ""
    except Exception as e:
        logger.error(f"Error downloading README.md from {repo_id}: {e}. README.md left unchanged.")
        return

    yaml_text, body = _split_readme_front_matter(content)

    metadata = {}
    if yaml_text is not None:
        try:
            # Empty front matter parses to None; treat it as an empty mapping.
            metadata = yaml.safe_load(yaml_text) or {}
        except yaml.YAMLError as e:
            # Rewriting from scratch would silently drop every other metadata key.
            logger.error(f"Error parsing existing YAML front matter: {e}. README.md left unchanged.")
            return
        if not isinstance(metadata, dict):
            logger.error(f"Existing YAML front matter in {repo_id} is not a mapping. README.md left unchanged.")
            return

    # Update or add dataset_info and license
    if dataset_info:
        metadata["dataset_info"] = dataset_info
    if license_id:
        metadata["license"] = license_id

    new_yaml = yaml.dump(metadata, sort_keys=False, default_flow_style=False)
    new_content = f"---\n{new_yaml}---\n{body}"

    # Upload the bytes directly: no temporary file to leak, and the encoding is always UTF-8.
    try:
        api.upload_file(
            path_or_fileobj=new_content.encode("utf-8"),
            path_in_repo="README.md",
            repo_id=repo_id,
            repo_type="dataset",
            commit_message=commit_message,
        )
        logger.info(f"Successfully updated README.md in {repo_id}")
    except Exception as e:
        logger.error(f"Error updating README.md: {e}")