import fire

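# DuckDB config: not preserving insertion order lets DuckDB reorder rows while
# streaming, which can lower memory usage when copying large datasets
# (rationale inferred from the DuckDB docs).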
CONFIG = {
    "preserve_insertion_order": False
}

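# Ask the datasets-server "compatible-libraries" API for the source dataset's
# loading codes, keep the polars ("pl.") one, strip the "pl." prefix to get a
# DuckDB-callable reader name (e.g. read_parquet), and build the hf:// path
# for the requested config and split.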
CMD_SRC_KWARGS = """
SELECT ('hf://datasets/{src}/' || lo.arguments['splits']['{split}']) AS path, function
FROM (
    SELECT unnest(li.loading_codes) AS lo, li.function[4:] as function
    FROM (
        SELECT unnest(libraries) as li
        FROM read_json('https://datasets-server.huggingface.co/compatible-libraries?dataset={src}')
    ) WHERE li.function[:3] = 'pl.'
) WHERE lo.config_name='{config}';
""".strip()

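# Expose the source files as a view named "src" so the user query can simply
# SELECT ... FROM src.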
CMD_SRC = """
CREATE VIEW src AS SELECT * FROM {function}('{path}');
""".strip()


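# Write the query result as Parquet files into the local "tmp" directory,
# one file per thread with bounded row-group sizes.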
CMD_DST = """
COPY ({query}) to 'tmp' (FORMAT PARQUET, ROW_GROUP_SIZE_BYTES '100MB', ROW_GROUPS_PER_FILE 5, PER_THREAD_OUTPUT true);
""".strip()

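# Dry-run variants: only load a 5-row sample and run the user query directly
# instead of copying anything.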
CMD_SRC_DRY_RUN = CMD_SRC[:-1] + " LIMIT 5;"
CMD_DST_DRY_RUN = "{query};"

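# README template pushed to the destination dataset repo.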
DATA_CARD = "# Dataset Card for {dst}\n\nDataset prepared from [{src}](https://huggingface.co/datasets/{src}) using\n\n```\n{query}\n```\n"

def sql(src: str, dst: str, query: str, config: str = "default", split: str = "train", private: bool = False, dry_run: bool = False):
    import os
    import duckdb
    from huggingface_hub import CommitScheduler, DatasetCard

    class CommitAndCleanScheduler(CommitScheduler):
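        """CommitScheduler that, before each scheduled commit, moves only the
        Parquet files DuckDB has finished writing (those ending with the
        "PAR1" footer) from "tmp" into the upload folder, and removes local
        copies once they have been uploaded."""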

        def push_to_hub(self):
            for path in self.folder_path.with_name("tmp").glob("*.parquet"):
                with path.open("rb") as f:
                    footer = f.read(4) and f.seek(-4, os.SEEK_END) and f.read(4)
                if footer == b"PAR1":
                    path.rename(self.folder_path / path.name)
            super().push_to_hub()
            for path in self.last_uploaded:
                path.unlink(missing_ok=True)

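    # Resolve which DuckDB reader function and hf:// path to use for the
    # requested dataset config and split.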
    con = duckdb.connect(":memory:", config=CONFIG)
    src_kwargs = con.sql(CMD_SRC_KWARGS.format(src=src, config=config, split=split)).df().to_dict(orient="records")
    if not src_kwargs:
        raise ValueError(f'Invalid --config "{config}" for dataset "{src}", please select a valid dataset config/subset.')

    con.sql((CMD_SRC_DRY_RUN if dry_run else CMD_SRC).format(**src_kwargs[0]))

    if dry_run:
        print(f"Sample data from '{src}' that would be written to dataset '{dst}':\n")
        result = con.sql(CMD_DST_DRY_RUN.format(query=query.rstrip("\n ;")))
        print(result.df().to_markdown())
        return

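    # While DuckDB writes Parquet files into "tmp", the scheduler pushes the
    # finished ones to the Hub every 0.1 minutes (CommitScheduler's `every`
    # is expressed in minutes).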
    with CommitAndCleanScheduler(repo_id=dst, repo_type="dataset", folder_path="dst", path_in_repo="data", every=0.1, private=private):
        con.sql("PRAGMA enable_progress_bar;")
        result = con.sql(CMD_DST.format(query=query.rstrip("\n ;")))
    DatasetCard(DATA_CARD.format(src=src, dst=dst, query=query)).push_to_hub(repo_id=dst, repo_type="dataset")
    print("done")

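# Example invocation via python-fire (hypothetical dataset names, assuming this
# file is saved as sql.py):
#   python sql.py --src user/source-dataset --dst user/destination-dataset \
#       --query "SELECT * FROM src WHERE score > 0.5" --dry_run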
if __name__ == '__main__':
    fire.Fire(sql)