# datasets-ai / app.py
# Caleb Fahlgren
# make model parameters more dynamic w env variables
# e8c1c43
# raw
# history blame
# 5.77 kB
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from llama_cpp.llama_speculative import LlamaPromptLookupDecoding
from huggingface_hub import hf_hub_download
from huggingface_hub import HfApi
import matplotlib.pyplot as plt
from typing import Tuple, Optional
import pandas as pd
import gradio as gr
import duckdb
import requests
import llama_cpp
import instructor
import spaces
import enum
import os
from pydantic import BaseModel, Field
# Hugging Face datasets-server API, used to discover a dataset's parquet files.
BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
# Name of the DuckDB view that the generated SQL is expected to query.
view_name = "dataset_view"

hf_api = HfApi()
# One DuckDB connection (no database path given) shared by the functions below.
conn = duckdb.connect()

# Inference knobs, overridable through environment variables.
gpu_layers = int(os.getenv("GPU_LAYERS", 81))
draft_pred_tokens = int(os.getenv("DRAFT_PRED_TOKENS", 2))

# GGUF weights to fetch from the Hub at import time into ./models.
repo_id = os.getenv("MODEL_REPO_ID", "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF")
model_file_name = os.getenv("MODEL_FILE_NAME", "Hermes-2-Pro-Llama-3-8B-Q8_0.gguf")
hf_hub_download(repo_id=repo_id, filename=model_file_name, local_dir="./models")
# Visualisation kinds the LLM may request alongside its SQL.
# Intentionally no class docstring: pydantic uses a class docstring as the
# schema "description", which would alter the JSON schema sent to the LLM.
class OutputTypes(str, enum.Enum):
    TABLE = "table"  # no chart is drawn; only the result table is shown
    BARCHART = "barchart"  # rendered as a pandas bar plot
    LINECHART = "linechart"  # rendered as a pandas line plot
# Structured answer the LLM must return (enforced through instructor's
# JSON-schema mode). Field order and description texts make up that schema, so
# they are kept verbatim; no class docstring, because pydantic would add it to
# the schema as a description.
class SQLResponse(BaseModel):
    # SQL statement that query_dataset executes against DuckDB.
    sql: str
    # Optional chart hint; anything other than BARCHART/LINECHART yields no chart.
    visualization_type: Optional[OutputTypes] = Field(
        default=None,
        description="The type of visualization to display",
    )
    # Result column plotted as the y values.
    data_key: Optional[str] = Field(
        default=None,
        description="The column name from the sql query that contains the data for chart responses",
    )
    # Result column used for the x-axis labels.
    label_key: Optional[str] = Field(
        default=None,
        description="The column name from the sql query that contains the labels for chart responses",
    )
def get_dataset_ddl(dataset_id: str, timeout: float = 30.0) -> str:
    """Expose a Hub dataset as a DuckDB view and describe it as CREATE TABLE text.

    Asks the datasets-server for the dataset's parquet files, points the
    module-level ``view_name`` view at the first one, and renders the view's
    columns as pseudo-DDL for the LLM prompt.

    Args:
        dataset_id: Hub dataset id, e.g. ``"gretelai/synthetic_text_to_sql"``.
        timeout: Seconds to wait for the datasets-server before giving up.

    Returns:
        A ``CREATE TABLE <view_name> (...)`` string with one ``name TYPE`` line
        per column.

    Raises:
        requests.HTTPError: The datasets-server returned an error status.
        ValueError: No parquet files are listed, or the first has no URL.
    """
    response = requests.get(
        f"{BASE_DATASETS_SERVER_URL}/parquet",
        params={"dataset": dataset_id},  # let requests URL-encode the id
        timeout=timeout,
    )
    response.raise_for_status()  # Check if the request was successful

    # Guard the empty case explicitly; indexing [0] would raise IndexError.
    parquet_files = response.json().get("parquet_files") or []
    if not parquet_files:
        raise ValueError(f"No parquet files found for dataset {dataset_id!r}.")
    first_parquet_url = parquet_files[0].get("url")
    if not first_parquet_url:
        raise ValueError("No valid URL found for the first parquet file.")

    # The URL is placed inside a SQL string literal: double any single quotes.
    escaped_url = first_parquet_url.replace("'", "''")
    conn.execute(
        f"CREATE OR REPLACE VIEW {view_name} as SELECT * FROM read_parquet('{escaped_url}');"
    )

    # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk).
    dataset_ddl = conn.execute(f"PRAGMA table_info('{view_name}');").fetchall()
    column_data_types = ",\n\t".join(f"{column[1]} {column[2]}" for column in dataset_ddl)
    # Leading "\t" gives the first column the same indent the join gives the rest.
    sql_ddl = f"\nCREATE TABLE {view_name} (\n\t{column_data_types}\n);\n"
    return sql_ddl
@spaces.GPU(duration=120)
def generate_query(ddl: str, query: str) -> dict:
    """Translate a natural-language question into SQL with the local GGUF model.

    Args:
        ddl: CREATE TABLE text describing the queryable view; embedded in the
            system prompt.
        query: The user's question.

    Returns:
        ``SQLResponse.model_dump()``: a dict with ``sql``, ``visualization_type``,
        ``data_key`` and ``label_key``.
    """
    # NOTE(review): the model is loaded on every call, presumably because the GPU
    # is only attached inside the @spaces.GPU-decorated function; confirm before
    # hoisting this to module level.
    llama = llama_cpp.Llama(
        # hf_hub_download stores the weights under local_dir="./models".
        model_path=f"models/{model_file_name}",
        n_gpu_layers=gpu_layers,
        chat_format="chatml",
        draft_model=LlamaPromptLookupDecoding(num_pred_tokens=draft_pred_tokens),
        logits_all=True,
        n_ctx=2048,
        verbose=True,
    )
    # Wrap the OpenAI-compatible completion so `response_model` is enforced
    # through a JSON schema.
    create = instructor.patch(
        create=llama.create_chat_completion_openai_v1,
        mode=instructor.Mode.JSON_SCHEMA,
    )
    system_prompt = (
        "\nYou are an expert SQL assistant with access to the following PostgreSQL Table:\n"
        "```sql\n"
        f"{ddl.strip()}\n"
        "```\n"
        "Please assist the user by writing a SQL query that answers the user's question.\n"
    )
    print("Calling LLM with system prompt: ", system_prompt, query)
    resp: SQLResponse = create(
        model="Hermes-2-Pro-Llama-3-8B",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": query},
        ],
        response_model=SQLResponse,
        # Temperature is a sampling option of the completion call; the Llama
        # constructor ignores unknown keyword arguments, so it had no effect there.
        temperature=0.1,
    )
    print("Received Response: ", resp)
    return resp.model_dump()
def query_dataset(
    dataset_id: str, query: str
) -> Tuple[pd.DataFrame, str, Optional[plt.Figure]]:
    """Answer a natural-language question about a Hub dataset.

    Builds the DuckDB view for the dataset, asks the LLM for SQL plus an
    optional chart hint, runs the SQL and, if requested, plots the result.

    NOTE(review): the view name and DuckDB connection are module-level, so two
    overlapping requests could replace each other's view; confirm the app's
    concurrency settings.

    Args:
        dataset_id: Hub dataset id chosen in the UI.
        query: The user's natural-language question.

    Returns:
        ``(df, markdown, figure)``: the query result, the generated SQL as a
        fenced Markdown block, and a matplotlib figure, or ``None`` when no chart
        was requested, the result is empty, or plotting failed.
    """
    ddl = get_dataset_ddl(dataset_id)
    response = generate_query(ddl, query)
    sql = response.get("sql")
    markdown_output = f"```sql\n{sql}\n```"

    print("Querying Parquet...")
    # NOTE(review): this runs model-generated SQL verbatim; DuckDB SQL can read
    # arbitrary files/URLs, so a crafted prompt could reach beyond the view.
    df = conn.execute(sql).fetchdf()
    if df.empty:
        return df, markdown_output, None

    viz_type = response.get("visualization_type")
    if viz_type == OutputTypes.LINECHART:
        kind = "line"
    elif viz_type == OutputTypes.BARCHART:
        kind = "bar"
    else:
        return df, markdown_output, None

    # The LLM may name columns the query did not return; fall back to pandas'
    # defaults (index for x, numeric columns for y) in that case.
    label_key = response.get("label_key")
    data_key = response.get("data_key")
    if label_key not in df.columns:
        label_key = None
    if data_key not in df.columns:
        data_key = None

    # A standalone Figure is not registered with pyplot's global figure manager,
    # so it is not leaked across requests and needs no "current figure" state.
    fig = plt.Figure()
    ax = fig.subplots()
    try:
        df.plot(kind=kind, x=label_key, y=data_key, ax=ax)
    except (TypeError, ValueError) as err:
        # e.g. pandas raises TypeError when there is no numeric data to plot;
        # still return the table instead of failing the whole request.
        print("Could not plot query result: ", err)
        return df, markdown_output, None
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
    fig.tight_layout()
    return df, markdown_output, fig
# Gradio UI: choose a Hub dataset, ask a question, and show the generated SQL,
# the result table and an optional chart.
with gr.Blocks() as demo:
    # Emoji restored from mis-decoded UTF-8 bytes (previously "πŸ“ˆπŸ“Š").
    gr.Markdown("# Query your HF Datasets with Natural Language 📈📊")
    dataset_id = HuggingfaceHubSearch(
        label="Hub Dataset ID",
        placeholder="Find your favorite dataset...",
        search_type="dataset",
        value="gretelai/synthetic_text_to_sql",
    )
    user_query = gr.Textbox("", label="Ask anything...")
    examples = [
        ["Show me a preview of the data"],
        ["Show me something interesting"],
        ["Which row has longest description length?"],
        ["find the average length of sql query context"],
    ]
    gr.Examples(examples=examples, inputs=[user_query], outputs=[])
    # Emoji restored from mis-decoded UTF-8 bytes (previously "πŸͺ„").
    btn = gr.Button("Ask 🪄")
    sql_query = gr.Markdown(label="Output SQL Query")
    df = gr.DataFrame()
    plot = gr.Plot()
    # Output order matches query_dataset's (df, markdown, figure) return tuple.
    btn.click(
        query_dataset,
        inputs=[dataset_id, user_query],
        outputs=[df, sql_query, plot],
    )

if __name__ == "__main__":
    demo.launch()