File size: 1,716 Bytes
1e5834c
 
e1ef382
1e5834c
 
 
 
 
bf18264
448ae42
8674838
 
448ae42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e5834c
bf18264
e1ef382
 
 
 
 
 
 
 
 
 
1e5834c
 
448ae42
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import requests
import os
import random
# Send the Authorization header only when a token is configured; formatting a
# missing token into the f-string would yield the literal value "Bearer None".
_hf_token = os.getenv("HF_TOKEN")
headers = {"Authorization": f"Bearer {_hf_token}"} if _hf_token else {}

# Dataset and configuration used by every request built in this module.
dataset = "mozilla-foundation/common_voice_17_0"
config = "en"


def _search(paths: list[str]):
    """Fetch the dataset rows whose ``path`` column matches any of *paths*.

    The split to query is chosen from the first path only: a path starting
    with ``"en_train"`` selects the ``train`` split, anything else selects
    ``validation``. All paths in one call are therefore sent to a single
    split.

    Args:
        paths: File paths to look up in the ``path`` column.

    Returns:
        The ``rows`` list of the ``/filter`` response, or an empty list when
        *paths* is empty or the response carries no ``rows`` key.

    Raises:
        requests.HTTPError: If the server answers with a 4xx or 5xx status.
        requests.Timeout: If the server does not respond within the timeout.
    """
    if not paths:
        return []
    split = "train" if paths[0].startswith("en_train") else "validation"

    # SQL-style quoting: a single quote inside a value is escaped by doubling
    # it, so a path containing "'" cannot terminate the literal early.
    quoted_paths = ", ".join(
        "'{}'".format(path.replace("'", "''")) for path in paths
    )
    where_clause = f'"path" IN ({quoted_paths})'

    # Pass the query as ``params`` so requests percent-encodes every value;
    # interpolating the where clause into the URL leaves characters such as
    # "&", "+" or "#" unescaped.
    response = requests.get(
        "https://datasets-server.huggingface.co/filter",
        params={
            "dataset": dataset,
            "config": config,
            "split": split,
            "where": where_clause,
            "offset": 0,
        },
        headers=headers,
        timeout=30,  # without a timeout a stalled connection blocks forever
    )
    response.raise_for_status()  # Raise an exception for bad status codes (4xx or 5xx)
    return response.json().get("rows", [])


def get_prompt():
    """Get a random sentence from the Common Voice dataset.

    A random offset between 0 and 100,000 (inclusive) is drawn and the single
    row at that offset of the ``train`` split is requested.

    Returns:
        The ``sentence`` field of the fetched row.

    Raises:
        requests.HTTPError: If the server answers with a 4xx or 5xx status.
        requests.Timeout: If the server does not respond within the timeout.
        IndexError: If the response contains no rows.
    """
    offset = random.randint(0, 100_000)
    response = requests.get(
        "https://datasets-server.huggingface.co/rows",
        params={
            "dataset": dataset,
            "config": config,
            "split": "train",
            "offset": offset,
            "length": 1,
        },
        headers=headers,
        timeout=30,  # without a timeout a stalled connection blocks forever
    )
    response.raise_for_status()  # Raise an exception for bad status codes (4xx or 5xx)
    rows = response.json().get("rows", [])
    if not rows:
        # Same exception type as indexing an empty list, but with context.
        raise IndexError(f"no rows returned for {dataset}/{config} at offset {offset}")
    return rows[0]["row"]["sentence"]


def search(rows: list[dict]):
    """Look up the given rows in the dataset, grouped by split.

    Each row's ``path`` is read and the paths are divided into those starting
    with ``"en_train"`` and all others. Each group is resolved with its own
    ``_search`` call, and the train results are returned ahead of the
    remaining ones.
    """
    wanted = [entry["path"] for entry in rows]
    from_train = [p for p in wanted if p.startswith("en_train")]
    from_validation = [p for p in wanted if not p.startswith("en_train")]

    return _search(from_train) + _search(from_validation)