from huggingface_hub import HfApi, snapshot_download
from loguru import logger

api = HfApi()


def download_dataset_snapshot(repo_id, local_dir):
    """Download a snapshot of a dataset repository to a local directory.

    If the download fails, the Space is restarted so it can retry from a clean state.
    """
    try:
        logger.info(f"Downloading dataset snapshot from {repo_id} to {local_dir}")
        snapshot_download(
            repo_id=repo_id,
            local_dir=local_dir,
            repo_type="dataset",
            tqdm_class=None,
        )
    except Exception as e:
        logger.error(f"Error downloading dataset snapshot from {repo_id} to {local_dir}: {e}. Restarting space.")
        # Assumes the Space shares its repo_id with the dataset; pass the Space id
        # separately if the two differ.
        api.restart_space(repo_id=repo_id)


def remove_files_from_dataset_repo(repo_id: str, path_patterns: list[str], commit_message: str = "Remove files"):
    """
    Remove files or directories matching specified patterns from a Hugging Face dataset repository.

    Args:
        repo_id: The ID of the dataset repository (e.g., "username/dataset-name")
        path_patterns: List of file or directory path patterns to remove
        commit_message: Message for the commit that removes the files
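
    Example (repo id and patterns are illustrative):
        remove_files_from_dataset_repo("username/dataset-name", ["results/*.json", "logs/*"])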
    """
    import fnmatch

    # Get all files in the repository
    repo_files = api.list_repo_files(repo_id=repo_id, repo_type="dataset")

    # Find files matching the patterns, deduplicated in case patterns overlap
    files_to_remove = []
    for pattern in path_patterns:
        files_to_remove.extend(fnmatch.filter(repo_files, pattern))
    files_to_remove = list(dict.fromkeys(files_to_remove))

    # Delete each matching file
    for path in files_to_remove:
        try:
            api.delete_file(
                path_in_repo=path, repo_id=repo_id, repo_type="dataset", commit_message=f"{commit_message}: {path}"
            )
            print(f"Successfully removed {path} from {repo_id}")
        except Exception as e:
            print(f"Error removing {path}: {e}")


def update_dataset_info_readme(
    repo_id: str,
    dataset_info: dict,
    license_id: str = None,
    commit_message: str = "Update dataset_info in README.md",
):
    """
    Update the dataset_info section in the README.md file of a Hugging Face dataset repository.

    Args:
        repo_id: The ID of the dataset repository (e.g., "username/dataset-name")
        dataset_info: Dictionary containing dataset information to include in the README
        license_id: Optional license identifier (e.g., "mit", "cc-by-4.0")
        commit_message: Message for the commit

    Example dataset_info structure:
    {
        "features": [
            {"name": "text", "dtype": "string"},
            {"name": "label", "dtype": "int64"}
        ],
        "splits": [
            {"name": "train", "num_examples": 10000, "num_bytes": 1000000},
            {"name": "test", "num_examples": 1000, "num_bytes": 100000}
        ],
        "download_size": 1200000,
        "dataset_size": 1100000,
        "configs": [
            {
                "config_name": "default",
                "data_files": [
                    {"split": "train", "path": "data/train.csv"},
                    {"split": "test", "path": "data/test.csv"}
                ]
            }
        ]
    }
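
    Example call (values above are illustrative):
        update_dataset_info_readme("username/dataset-name", dataset_info, license_id="mit")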
    """
    import re

    import yaml

    # Check if README.md exists
    try:
        readme_path = api.hf_hub_download(repo_id=repo_id, repo_type="dataset", filename="README.md")
        with open(readme_path, "r", encoding="utf-8") as f:
            content = f.read()
    except Exception:
        # Create a new README.md if it doesn't exist
        content = ""

    # Parse existing YAML front matter, which must sit at the very start of the file
    yaml_match = re.match(r"---\s*\n(.*?)\n---\s*\n?", content, re.DOTALL)

    if yaml_match:
        yaml_text = yaml_match.group(1)
        try:
            yaml_block = yaml.safe_load(yaml_text) or {}  # empty front matter parses to None
        except Exception as e:
            logger.error(f"Error parsing existing YAML front matter: {e}")
            yaml_block = {}
    else:
        yaml_block = {}

    # Update or add dataset_info and license
    if dataset_info:
        yaml_block["dataset_info"] = dataset_info

    if license_id:
        yaml_block["license"] = license_id

    # Generate new YAML front matter
    new_yaml = yaml.dump(yaml_block, sort_keys=False, default_flow_style=False)
    new_yaml_block = f"---\n{new_yaml}---\n"

    # Replace existing YAML front matter or add it at the beginning
    if yaml_match:
        new_content = content[: yaml_match.start()] + new_yaml_block + content[yaml_match.end() :]
    else:
        new_content = new_yaml_block + content

    # Upload the updated README.md; upload_file accepts raw bytes, so no temp file is needed
    try:
        api.upload_file(
            path_or_fileobj=new_content.encode("utf-8"),
            path_in_repo="README.md",
            repo_id=repo_id,
            repo_type="dataset",
            commit_message=commit_message,
        )
        logger.info(f"Successfully updated README.md in {repo_id}")
    except Exception as e:
        logger.error(f"Error updating README.md: {e}")
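

if __name__ == "__main__":
    # Minimal usage sketch tying the helpers together. The repo id, paths, and
    # dataset_info values below are placeholders, not real repositories.
    demo_repo = "username/dataset-name"

    download_dataset_snapshot(demo_repo, local_dir="./data")
    remove_files_from_dataset_repo(demo_repo, ["tmp/*"], commit_message="Clean up temp files")
    update_dataset_info_readme(
        demo_repo,
        dataset_info={
            "features": [{"name": "text", "dtype": "string"}],
            "splits": [{"name": "train", "num_examples": 10000, "num_bytes": 1000000}],
        },
        license_id="mit",
    )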