File size: 2,669 Bytes
6216400
5822bba
 
 
9b95e7f
aef303c
5822bba
aef303c
5822bba
554bcd2
5822bba
554bcd2
 
6216400
aef303c
66f4448
aef303c
 
 
67fdc30
 
aef303c
 
 
 
 
 
9b95e7f
554bcd2
 
aef303c
 
 
50bd73b
aef303c
 
554bcd2
 
 
aef303c
9b95e7f
 
 
 
 
 
518f556
 
 
 
aef303c
50bd73b
aef303c
 
 
 
9b95e7f
 
 
aef303c
9b95e7f
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import gradio as gr
import duckdb
from huggingface_hub import HfFileSystem
from huggingface_hub.hf_file_system import safe_quote
import pandas as pd
import requests

# Base URL of the datasets-server HTTP API queried by get_parquet_files.
DATASETS_SERVER_ENDPOINT = "https://datasets-server.huggingface.co"
# Revision placed after "@" in the hf:// location that run_command builds.
PARQUET_REVISION="refs/convert/parquet"
# Placeholder that run_command replaces in the user's SQL with the quoted
# parquet file location.
TABLE_WILDCARD="{table}"

# Register an HfFileSystem instance with DuckDB at import time; presumably this
# is what lets DuckDB open the hf:// locations queried in run_command.
fs = HfFileSystem()
duckdb.register_filesystem(fs)

def get_parquet_files(dataset, config, split):
    """Return the parquet file names published for one split of a dataset.

    Calls the datasets-server ``/parquet`` endpoint for ``dataset`` and
    ``config`` and keeps the entries whose ``split`` field equals ``split``.

    Args:
        dataset: Dataset repository id, e.g. ``"mstz/iris"``.
        config: Configuration name of the dataset.
        split: Split name used to filter the returned entries.

    Returns:
        A non-empty list with the ``filename`` of every matching entry, in
        the order the server returned them.

    Raises:
        Exception: If the server does not answer with HTTP 200, or if no
            entry matches ``split``.
        KeyError: If the JSON payload lacks the expected keys.
    """
    # Hand the query values to ``params`` so requests URL-encodes them;
    # interpolating them into the URL let characters such as "&", "#" or
    # spaces corrupt the query string.
    response = requests.get(
        f"{DATASETS_SERVER_ENDPOINT}/parquet",
        params={"dataset": dataset, "config": config},
        timeout=60,
    )
    if response.status_code != 200:
        # Include the status and a body excerpt: str(response) alone only
        # renders as "<Response [404]>", which is useless to the end user.
        raise Exception(
            f"Datasets server returned HTTP {response.status_code} for "
            f"dataset={dataset!r}, config={config!r}: {response.text[:200]}"
        )

    parquet_files = response.json()["parquet_files"]
    file_names = [
        content["filename"]
        for content in parquet_files
        if content["split"] == split
    ]
    if not file_names:
        raise Exception("No parquet files found for dataset")
    return file_names

def run_command(dataset, config, split, sql):
    """Run a SQL query with DuckDB against a parquet file of a dataset split.

    Every occurrence of ``TABLE_WILDCARD`` in ``sql`` is replaced with the
    quoted ``hf://`` location of the first parquet file found for the split,
    then the query is executed with DuckDB.

    Args:
        dataset: Dataset repository id.
        config: Configuration name of the dataset.
        split: Split name whose parquet files are looked up.
        sql: Query text; it must contain ``TABLE_WILDCARD``.

    Returns:
        A pandas DataFrame with the query result. On any failure (missing
        wildcard, lookup error, DuckDB error) no exception propagates;
        instead a one-row DataFrame with a single ``Error`` column holding
        the message is returned so the UI can display it.
    """
    try:
        if TABLE_WILDCARD not in sql:
            raise Exception(f"Query must contain {TABLE_WILDCARD} wildcard.")

        parquet_files = get_parquet_files(dataset, config, split)
        print(f"File names found: {','.join(parquet_files)}")
        # TODO: Send pattern to duck db to read all split parquets
        parquet_first_file = parquet_files[0]
        print(f"Trying with the first one {parquet_first_file}")
        location = (
            f"hf://datasets/{dataset}@{safe_quote(PARQUET_REVISION)}"
            f"/{config}/{parquet_first_file}"
        )
        print(location)
        # Double any embedded single quote so the location always forms one
        # valid SQL string literal, whatever the dataset/config names contain.
        quoted_location = "'" + location.replace("'", "''") + "'"
        sql = sql.replace(TABLE_WILDCARD, quoted_location)
        result = duckdb.query(sql).to_df()
        print("Ok")
    except Exception as error:
        # Deliberate catch-all: this is the UI boundary, so failures are
        # reported in the result table instead of crashing the callback.
        print(f"Error: {error}")
        return pd.DataFrame({"Error": [f"❌ {error}"]})
    return result

# Build the Gradio UI: three text inputs identifying the parquet data, a SQL
# box, a Run button, and a table that shows whatever run_command returns.
with gr.Blocks() as demo:
    gr.Markdown(" ## SQL Query using DuckDB for datasets server parquet files")
    # Inputs are pre-filled with the mstz/iris example so the default query
    # below can be run as-is.
    dataset = gr.Textbox(label="dataset", placeholder="mstz/iris", value="mstz/iris")
    config = gr.Textbox(label="config", placeholder="iris", value="iris")
    split = gr.Textbox(label="split", placeholder="train", value="train")
    # The query must keep the TABLE_WILDCARD placeholder; run_command rejects
    # queries without it and substitutes the parquet location for it.
    sql = gr.Textbox(
            label="Query in SQL format - It should have {table} wildcard",
            placeholder=f"SELECT sepal_length FROM {TABLE_WILDCARD} LIMIT 3",
            value=f"SELECT sepal_length FROM {TABLE_WILDCARD} LIMIT 3",
            lines=3,
    )
    run_button = gr.Button("Run")
    gr.Markdown("### Result")
    # Output table: receives either the query result or the one-column
    # "Error" DataFrame produced by run_command.
    cached_responses_table = gr.DataFrame()
    # Clicking Run passes the four textbox values, in this order, to run_command.
    run_button.click(run_command, inputs=[dataset, config, split, sql], outputs=cached_responses_table)


# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()