Deploy
- .dockerignore +45 -0
- .gitignore +45 -0
- Dockerfile +23 -47
- README.md +1 -1
- database_api.py +426 -0
- main.py +372 -169
- requirements.txt +3 -2
- test_api.py +246 -0
.dockerignore
ADDED
@@ -0,0 +1,45 @@
+# .dockerignore
+__pycache__/
+*.pyc
+*.pyo
+*.pyd
+.Python
+env/
+.env
+.venv/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+.pytest_cache/
+.mypy_cache/
+.nox/
+.tox/
+.coverage
+.coverage.*
+coverage.xml
+htmlcov/
+.hypothesis/
+
+*.db
+*.db.wal
+*.log
+*.sqlite
+*.sqlite3
+
+# Ignore specific generated files if needed
+api_database.db
+api_database.db.wal
+my_duckdb_api_db.db
+my_duckdb_api_db.db.wal
+exported_db/
+duckdb_api_exports/ # Don't copy local temp exports
+
+# OS-specific files
+.DS_Store
+Thumbs.db
+
+# IDE files
+.idea/
+.vscode/
.gitignore
ADDED
@@ -0,0 +1,45 @@
+# .dockerignore
+__pycache__/
+*.pyc
+*.pyo
+*.pyd
+.Python
+env/
+.env
+.venv/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+.pytest_cache/
+.mypy_cache/
+.nox/
+.tox/
+.coverage
+.coverage.*
+coverage.xml
+htmlcov/
+.hypothesis/
+
+*.db
+*.db.wal
+*.log
+*.sqlite
+*.sqlite3
+
+# Ignore specific generated files if needed
+api_database.db
+api_database.db.wal
+my_duckdb_api_db.db
+my_duckdb_api_db.db.wal
+exported_db/
+duckdb_api_exports/ # Don't copy local temp exports
+
+# OS-specific files
+.DS_Store
+Thumbs.db
+
+# IDE files
+.idea/
+.vscode/
Dockerfile
CHANGED
@@ -1,57 +1,33 @@
-#
-FROM python:3.10-slim
-#
-#
-useradd --system --uid ${USER_ID} --gid appgroup --shell /sbin/nologin --create-home appuser
-# Set the working directory
 WORKDIR /app
-#
-# DuckDB UI often uses ~/.duckdb (which will be /home/appuser/.duckdb)
-# Ensure these are owned by the user *before* VOLUME instruction
-RUN mkdir -p /app/data /home/appuser/.duckdb && \
-    chown -R ${USER_ID}:${GROUP_ID} /app /home/appuser/.duckdb
-# Switch context to the non-root user early for subsequent RUN/COPY commands
-USER appuser
-# Copy requirements file (as appuser)
 COPY requirements.txt .
-# Install dependencies
-#
-# Copy application code
-COPY
-#
-# These paths MUST match the directories the 'appuser' process will write to.
-# Note: We created and chowned these earlier.
-VOLUME /app/data
-VOLUME /home/appuser/.duckdb
-# --- End Define Volumes ---
-# Expose ports
 EXPOSE 8000
-#
-# Ensure Python user packages are in the path
-ENV PATH="/home/appuser/.local/bin:${PATH}"
-# Set HOME so things like ~/.duckdb resolve correctly
-ENV HOME=/home/appuser
-# Command to run the application (now runs as appuser)
-# No chmod needed here. Ownership was handled during build.
-CMD ["python", "-m", "uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
+# Dockerfile
+
+# 1. Choose a base Python image
+# Using a specific version is recommended for reproducibility.
+# The '-slim' variant is smaller.
+FROM python:3.12-slim
+
+# 2. Set environment variables (optional but good practice)
+ENV PYTHONDONTWRITEBYTECODE 1
+ENV PYTHONUNBUFFERED 1
+
+# 3. Set the working directory inside the container
 WORKDIR /app
+
+# 4. Copy only the requirements file first to leverage Docker cache
 COPY requirements.txt .
+
+# 5. Install dependencies
+# --no-cache-dir makes the image smaller
+# --upgrade pip ensures we have the latest pip
+RUN pip install --no-cache-dir --upgrade pip && \
+    pip install --no-cache-dir -r requirements.txt
+
+# 6. Copy the rest of the application code into the working directory
+COPY . .
+
+# 7. Expose the port the app runs on (uvicorn default is 8000)
 EXPOSE 8000
+
+# 8. Define the default command to run when the container starts
+# Use exec form for proper signal handling.
+# Do NOT use --reload in production.
+CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
README.md
CHANGED
@@ -7,7 +7,7 @@ sdk: docker
 pinned: false
 license: mit
 short_description: DuckDB Hosting with UI & FastAPI 4 SQL Calls & DB Downloads
-port:
+port: 8000
 ---

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
database_api.py
ADDED
@@ -0,0 +1,426 @@
+# database_api.py
+import duckdb
+import pandas as pd
+import pyarrow as pa
+import pyarrow.ipc
+from pathlib import Path
+import tempfile
+import os
+import shutil
+from typing import Optional, List, Dict, Any, Union, Iterator, Generator, Tuple
+# No need for pybind11 import here anymore
+
+# --- Custom Exceptions ---
+class DatabaseAPIError(Exception):
+    """Base exception for our custom API."""
+    pass
+
+class QueryError(DatabaseAPIError):
+    """Exception raised for errors during query execution."""
+    pass
+
+# --- Helper function to format COPY options ---
+def _format_copy_options(options: Optional[Dict[str, Any]]) -> str:
+    if not options:
+        return ""
+    opts_parts = []
+    for k, v in options.items():
+        key_upper = k.upper()
+        if isinstance(v, bool):
+            value_repr = str(v).upper()
+        elif isinstance(v, (int, float)):
+            value_repr = str(v)
+        elif isinstance(v, str):
+            escaped_v = v.replace("'", "''")
+            value_repr = f"'{escaped_v}'"
+        else:
+            value_repr = repr(v)
+        opts_parts.append(f"{key_upper} {value_repr}")
+
+    opts_str = ", ".join(opts_parts)
+    return f"WITH ({opts_str})"
+
+# --- Main DatabaseAPI Class ---
+class DatabaseAPI:
+    def __init__(self,
+                 db_path: Union[str, Path] = ":memory:",
+                 read_only: bool = False,
+                 config: Optional[Dict[str, str]] = None):
+        self._db_path = str(db_path)
+        self._config = config or {}
+        self._read_only = read_only
+        self._conn: Optional[duckdb.DuckDBPyConnection] = None
+        try:
+            self._conn = duckdb.connect(
+                database=self._db_path,
+                read_only=self._read_only,
+                config=self._config
+            )
+            print(f"Connected to DuckDB database at '{self._db_path}'")
+        except duckdb.Error as e:
+            print(f"Failed to connect to DuckDB: {e}")
+            raise DatabaseAPIError(f"Failed to connect to DuckDB: {e}") from e
+
+    def _ensure_connection(self):
+        if self._conn is None:
+            raise DatabaseAPIError("Database connection is not established or has been closed.")
+        try:
+            self._conn.execute("SELECT 1", [])
+        except (duckdb.ConnectionException, RuntimeError) as e:
+            if "Connection has already been closed" in str(e) or "connection closed" in str(e).lower():
+                self._conn = None
+                raise DatabaseAPIError("Database connection is closed.") from e
+            else:
+                raise DatabaseAPIError(f"Database connection error: {e}") from e
+
+    # --- Basic Query Methods --- (Keep as before)
+    def execute_sql(self, sql: str, parameters: Optional[List[Any]] = None) -> None:
+        self._ensure_connection()
+        print(f"Executing SQL: {sql}")
+        try:
+            self._conn.execute(sql, parameters)
+        except duckdb.Error as e:
+            print(f"Error executing SQL: {e}")
+            raise QueryError(f"Error executing SQL: {e}") from e
+
+    def query_sql(self, sql: str, parameters: Optional[List[Any]] = None) -> duckdb.DuckDBPyRelation:
+        self._ensure_connection()
+        print(f"Querying SQL: {sql}")
+        try:
+            return self._conn.sql(sql, params=parameters)
+        except duckdb.Error as e:
+            print(f"Error querying SQL: {e}")
+            raise QueryError(f"Error querying SQL: {e}") from e
+
+    def query_df(self, sql: str, parameters: Optional[List[Any]] = None) -> pd.DataFrame:
+        self._ensure_connection()
+        print(f"Querying SQL to DataFrame: {sql}")
+        try:
+            return self._conn.execute(sql, parameters).df()
+        except ImportError:
+            print("Pandas library is required for DataFrame operations.")
+            raise
+        except duckdb.Error as e:
+            print(f"Error querying SQL to DataFrame: {e}")
+            raise QueryError(f"Error querying SQL to DataFrame: {e}") from e
+
+    def query_arrow(self, sql: str, parameters: Optional[List[Any]] = None) -> pa.Table:
+        self._ensure_connection()
+        print(f"Querying SQL to Arrow Table: {sql}")
+        try:
+            return self._conn.execute(sql, parameters).arrow()
+        except ImportError:
+            print("PyArrow library is required for Arrow operations.")
+            raise
+        except duckdb.Error as e:
+            print(f"Error querying SQL to Arrow Table: {e}")
+            raise QueryError(f"Error querying SQL to Arrow Table: {e}") from e
+
+    def query_fetchall(self, sql: str, parameters: Optional[List[Any]] = None) -> List[Tuple[Any, ...]]:
+        self._ensure_connection()
+        print(f"Querying SQL and fetching all: {sql}")
+        try:
+            return self._conn.execute(sql, parameters).fetchall()
+        except duckdb.Error as e:
+            print(f"Error querying SQL: {e}")
+            raise QueryError(f"Error querying SQL: {e}") from e
+
+    def query_fetchone(self, sql: str, parameters: Optional[List[Any]] = None) -> Optional[Tuple[Any, ...]]:
+        self._ensure_connection()
+        print(f"Querying SQL and fetching one: {sql}")
+        try:
+            return self._conn.execute(sql, parameters).fetchone()
+        except duckdb.Error as e:
+            print(f"Error querying SQL: {e}")
+            raise QueryError(f"Error querying SQL: {e}") from e
+
+    # --- Registration Methods --- (Keep as before)
+    def register_df(self, name: str, df: pd.DataFrame):
+        self._ensure_connection()
+        print(f"Registering DataFrame as '{name}'")
+        try:
+            self._conn.register(name, df)
+        except duckdb.Error as e:
+            print(f"Error registering DataFrame: {e}")
+            raise QueryError(f"Error registering DataFrame: {e}") from e
+
+    def unregister_df(self, name: str):
+        self._ensure_connection()
+        print(f"Unregistering virtual table '{name}'")
+        try:
+            self._conn.unregister(name)
+        except duckdb.Error as e:
+            if "not found" in str(e).lower():
+                print(f"Warning: Virtual table '{name}' not found for unregistering.")
+            else:
+                print(f"Error unregistering virtual table: {e}")
+                raise QueryError(f"Error unregistering virtual table: {e}") from e
+
+    # --- Extension Methods --- (Keep as before)
+    def install_extension(self, extension_name: str, force_install: bool = False):
+        self._ensure_connection()
+        print(f"Installing extension: {extension_name}")
+        try:
+            self._conn.install_extension(extension_name, force_install=force_install)
+        except duckdb.Error as e:
+            print(f"Error installing extension '{extension_name}': {e}")
+            raise DatabaseAPIError(f"Error installing extension '{extension_name}': {e}") from e
+
+    def load_extension(self, extension_name: str):
+        self._ensure_connection()
+        print(f"Loading extension: {extension_name}")
+        try:
+            self._conn.load_extension(extension_name)
+        # Catch specific DuckDB errors that indicate failure but aren't API errors
+        except (duckdb.IOException, duckdb.CatalogException) as load_err:
+            print(f"Error loading extension '{extension_name}': {load_err}")
+            raise QueryError(f"Error loading extension '{extension_name}': {load_err}") from load_err
+        except duckdb.Error as e:  # Catch other DuckDB errors
+            print(f"Unexpected DuckDB error loading extension '{extension_name}': {e}")
+            raise DatabaseAPIError(f"Unexpected DuckDB error loading extension '{extension_name}': {e}") from e
+
+    # --- Export Methods ---
+    def export_database(self, directory_path: Union[str, Path]):
+        self._ensure_connection()
+        path_str = str(directory_path)
+        if not os.path.isdir(path_str):
+            try:
+                os.makedirs(path_str)
+                print(f"Created export directory: {path_str}")
+            except OSError as e:
+                raise DatabaseAPIError(f"Could not create export directory '{path_str}': {e}") from e
+        print(f"Exporting database to directory: {path_str}")
+        sql = f"EXPORT DATABASE '{path_str}' (FORMAT CSV)"
+        try:
+            self._conn.execute(sql)
+            print("Database export completed successfully.")
+        except duckdb.Error as e:
+            print(f"Error exporting database: {e}")
+            raise DatabaseAPIError(f"Error exporting database: {e}") from e
+
+    def _export_data(self,
+                     source: str,
+                     output_path: Union[str, Path],
+                     file_format: str,
+                     options: Optional[Dict[str, Any]] = None):
+        self._ensure_connection()
+        path_str = str(output_path)
+        options_str = _format_copy_options(options)
+        source_safe = source.strip()
+        # --- MODIFIED: Use f-string quoting instead of quote_identifier ---
+        if ' ' in source_safe or source_safe.upper().startswith(('SELECT', 'WITH', 'VALUES')):
+            copy_source = f"({source})"
+        else:
+            # Simple quoting, might need refinement for complex identifiers
+            copy_source = f'"{source_safe}"'
+        # --- END MODIFICATION ---
+
+        sql = f"COPY {copy_source} TO '{path_str}' {options_str}"
+        print(f"Exporting data to {path_str} (Format: {file_format}) with options: {options or {}}")
+        try:
+            self._conn.execute(sql)
+            print("Data export completed successfully.")
+        except duckdb.Error as e:
+            print(f"Error exporting data: {e}")
+            raise QueryError(f"Error exporting data to {file_format}: {e}") from e
+
+    # --- Keep export_data_to_csv, parquet, json, jsonl as before ---
+    def export_data_to_csv(self,
+                           source: str,
+                           output_path: Union[str, Path],
+                           options: Optional[Dict[str, Any]] = None):
+        csv_options = options.copy() if options else {}
+        csv_options['FORMAT'] = 'CSV'
+        if 'HEADER' not in {k.upper() for k in csv_options}:
+            csv_options['HEADER'] = True
+        self._export_data(source, output_path, "CSV", csv_options)
+
+    def export_data_to_parquet(self,
+                               source: str,
+                               output_path: Union[str, Path],
+                               options: Optional[Dict[str, Any]] = None):
+        parquet_options = options.copy() if options else {}
+        parquet_options['FORMAT'] = 'PARQUET'
+        self._export_data(source, output_path, "Parquet", parquet_options)
+
+    def export_data_to_json(self,
+                            source: str,
+                            output_path: Union[str, Path],
+                            array_format: bool = True,
+                            options: Optional[Dict[str, Any]] = None):
+        json_options = options.copy() if options else {}
+        json_options['FORMAT'] = 'JSON'
+        if 'ARRAY' not in {k.upper() for k in json_options}:
+            json_options['ARRAY'] = array_format
+        self._export_data(source, output_path, "JSON", json_options)
+
+    def export_data_to_jsonl(self,
+                             source: str,
+                             output_path: Union[str, Path],
+                             options: Optional[Dict[str, Any]] = None):
+        self.export_data_to_json(source, output_path, array_format=False, options=options)
+
+
+    # # --- Streaming Read Methods --- (Keep as before)
+    # def stream_query_arrow(self,
+    #                        sql: str,
+    #                        parameters: Optional[List[Any]] = None,
+    #                        batch_size: int = 1000000
+    #                        ) -> Iterator[pa.RecordBatch]:
+    #     self._ensure_connection()
+    #     print(f"Streaming Arrow query (batch size {batch_size}): {sql}")
+    #     try:
+    #         result_set = self._conn.execute(sql, parameters)
+    #         while True:
+    #             batch = result_set.fetch_record_batch(batch_size)
+    #             if not batch:
+    #                 break
+    #             yield batch
+    #     except ImportError:
+    #         print("PyArrow library is required for Arrow streaming.")
+    #         raise
+    #     except duckdb.Error as e:
+    #         print(f"Error streaming Arrow query: {e}")
+    #         raise QueryError(f"Error streaming Arrow query: {e}") from e
+
+    def stream_query_df(self,
+                        sql: str,
+                        parameters: Optional[List[Any]] = None,
+                        vectors_per_chunk: int = 1
+                        ) -> Iterator[pd.DataFrame]:
+        self._ensure_connection()
+        print(f"Streaming DataFrame query (vectors per chunk {vectors_per_chunk}): {sql}")
+        try:
+            result_set = self._conn.execute(sql, parameters)
+            while True:
+                chunk_df = result_set.fetch_df_chunk(vectors_per_chunk)
+                if chunk_df.empty:
+                    break
+                yield chunk_df
+        except ImportError:
+            print("Pandas library is required for DataFrame streaming.")
+            raise
+        except duckdb.Error as e:
+            print(f"Error streaming DataFrame query: {e}")
+            raise QueryError(f"Error streaming DataFrame query: {e}") from e
+
+    def stream_query_arrow(self,
+                           sql: str,
+                           parameters: Optional[List[Any]] = None,
+                           batch_size: int = 1000000
+                           ) -> Iterator[pa.RecordBatch]:
+        """
+        Executes a SQL query and streams the results as Arrow RecordBatches.
+        Useful for processing large results iteratively in Python without
+        loading the entire result set into memory.
+
+        Args:
+            sql: The SQL query to execute.
+            parameters: Optional list of parameters for prepared statements.
+            batch_size: The approximate number of rows per Arrow RecordBatch.
+
+        Yields:
+            pyarrow.RecordBatch: Chunks of the result set.
+
+        Raises:
+            QueryError: If the query execution or fetching fails.
+            ImportError: If pyarrow is not installed.
+        """
+        self._ensure_connection()
+        print(f"Streaming Arrow query (batch size {batch_size}): {sql}")
+        record_batch_reader = None
+        try:
+            # Use execute() to get a result object that supports streaming fetch
+            result_set = self._conn.execute(sql, parameters)
+            # --- MODIFICATION: Get the reader first ---
+            record_batch_reader = result_set.fetch_record_batch(batch_size)
+            # --- Iterate through the reader ---
+            for batch in record_batch_reader:
+                yield batch
+            # --- END MODIFICATION ---
+        except ImportError:
+            print("PyArrow library is required for Arrow streaming.")
+            raise
+        except duckdb.Error as e:
+            print(f"Error streaming Arrow query: {e}")
+            raise QueryError(f"Error streaming Arrow query: {e}") from e
+        finally:
+            # Clean up the reader if it was created
+            if record_batch_reader is not None:
+                # PyArrow readers don't have an explicit close, relying on GC.
+                # Forcing cleanup might involve ensuring references are dropped.
+                del record_batch_reader  # Help GC potentially
+            # The original result_set from execute() might also hold resources,
+            # although fetch_record_batch typically consumes it.
+            # Explicitly closing it if possible, or letting it go out of scope.
+            if 'result_set' in locals() and result_set:
+                try:
+                    # DuckDBPyResult doesn't have an explicit close, relies on __del__
+                    del result_set
+                except Exception:
+                    pass  # Best effort
+
+    # --- Resource Management Methods --- (Keep as before)
+    def close(self):
+        if self._conn:
+            conn_id = id(self._conn)
+            print(f"Closing connection to '{self._db_path}' (ID: {conn_id})")
+            try:
+                self._conn.close()
+            except duckdb.Error as e:
+                print(f"Error closing DuckDB connection (ID: {conn_id}): {e}")
+            finally:
+                self._conn = None
+        else:
+            print("Connection already closed or never opened.")
+
+    def __enter__(self):
+        self._ensure_connection()
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.close()
+
+    def __del__(self):
+        if self._conn:
+            print(f"ResourceWarning: DatabaseAPI for '{self._db_path}' was not explicitly closed. Closing now in __del__.")
+            try:
+                self.close()
+            except Exception as e:
+                print(f"Exception during implicit close in __del__: {e}")
+            self._conn = None
+
+
+# --- Example Usage --- (Keep as before)
+if __name__ == "__main__":
+    # ... (rest of the example usage code from previous response) ...
+    temp_dir_obj = tempfile.TemporaryDirectory()
+    temp_dir = temp_dir_obj.name
+    print(f"\n--- Using temporary directory: {temp_dir} ---")
+    db_file = Path(temp_dir) / "export_test.db"
+    try:
+        with DatabaseAPI(db_path=db_file) as db_api:
+            db_api.execute_sql("CREATE OR REPLACE TABLE products(id INTEGER, name VARCHAR, price DECIMAL(8,2))")
+            db_api.execute_sql("INSERT INTO products VALUES (101, 'Gadget', 19.99), (102, 'Widget', 35.00), (103, 'Thing''amajig', 9.50)")
+            db_api.execute_sql("CREATE OR REPLACE TABLE sales(product_id INTEGER, sale_date DATE, quantity INTEGER)")
+            db_api.execute_sql("INSERT INTO sales VALUES (101, '2023-10-26', 5), (102, '2023-10-26', 2), (101, '2023-10-27', 3)")
+            export_dir = Path(temp_dir) / "exported_db"
+            db_api.export_database(export_dir)
+            csv_path = Path(temp_dir) / "products_export.csv"
+            db_api.export_data_to_csv('products', csv_path, options={'HEADER': True})
+            parquet_path = Path(temp_dir) / "high_value_products.parquet"
+            db_api.export_data_to_parquet("SELECT * FROM products WHERE price > 20", parquet_path, options={'COMPRESSION': 'SNAPPY'})
+            json_path = Path(temp_dir) / "sales.json"
+            db_api.export_data_to_json("SELECT * FROM sales", json_path, array_format=True)
+            jsonl_path = Path(temp_dir) / "sales.jsonl"
+            db_api.export_data_to_jsonl("SELECT * FROM sales ORDER BY sale_date", jsonl_path)
+
+        with DatabaseAPI() as db_api:
+            db_api.execute_sql("CREATE TABLE large_range AS SELECT range AS id, range % 100 AS category FROM range(1000)")
+            for batch in db_api.stream_query_arrow("SELECT * FROM large_range", batch_size=200):
+                pass
+            for df_chunk in db_api.stream_query_df("SELECT * FROM large_range", vectors_per_chunk=1):
+                pass
+    finally:
+        temp_dir_obj.cleanup()
+        print(f"\n--- Cleaned up temporary directory: {temp_dir} ---")
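A minimal usage sketch of the streaming reader above, assuming the in-memory default database and a throwaway table t (hypothetical names, not part of the commit): the RecordBatches yielded by stream_query_arrow can be reassembled into a single pyarrow Table on the caller's side.

# Hypothetical sketch: rebuild a Table from streamed RecordBatches.
import pyarrow as pa
from database_api import DatabaseAPI

with DatabaseAPI() as db:  # defaults to an in-memory DuckDB
    db.execute_sql("CREATE TABLE t AS SELECT range AS id FROM range(1000)")
    batches = list(db.stream_query_arrow("SELECT * FROM t", batch_size=250))
    # from_batches needs at least one batch (or an explicit schema)
    table = pa.Table.from_batches(batches) if batches else pa.table({})
    print(table.num_rows)  # 1000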
main.py
CHANGED
@@ -1,186 +1,389 @@
 import duckdb
 from pathlib import Path
-import
-import
-import
-DB_FILE = DB_DIR / DB_FILENAME
-UI_EXPECTED_PORT = 8080 # Default port DuckDB UI often tries first
-DB_DIR.mkdir(parents=True, exist_ok=True)
-# ---
-# ---
-)
-# --- Pydantic Models ---
 class QueryRequest(BaseModel):
-    sql: str
-class QueryResponse(BaseModel):
-    columns: list[str] | None = None
-    rows: list[dict] | None = None
-    message: str | None = None
-    error: str | None = None
-# --- Helper Function ---
-def execute_duckdb_query(sql_query: str, db_path: str = str(DB_FILE)):
-    """Connects to DuckDB, executes a query, and returns results or error."""
-    con = None
-    try:
-        logger.info(f"Connecting to database: {db_path}")
-        con = duckdb.connect(database=db_path, read_only=False)
-        logger.info(f"Executing SQL: {sql_query[:200]}{'...' if len(sql_query) > 200 else ''}")
-        con.begin()
-        result_relation = con.execute(sql_query)
-        response_data = {"columns": None, "rows": None, "message": None, "error": None}
-        if result_relation.description:
-            columns = [desc[0] for desc in result_relation.description]
-            rows_raw = result_relation.fetchall()
-            rows_dict = [dict(zip(columns, row)) for row in rows_raw]
-            response_data["columns"] = columns
-            response_data["rows"] = rows_dict
-            response_data["message"] = f"Query executed successfully. Fetched {len(rows_dict)} row(s)."
-            logger.info(f"Query successful, returned {len(rows_dict)} rows.")
-        else:
-            response_data["message"] = "Query executed successfully (no data returned)."
-            logger.info("Query successful (no data returned).")
-        return {"columns": None, "rows": None, "message": None, "error": str(e)}
-    except Exception as e:
-        logger.error(f"General Error: {e}")
-        if con: con.rollback()
-        return {"columns": None, "rows": None, "message": None, "error": f"An unexpected error occurred: {e}"}
-    finally:
-        if con:
-            con.close()
-            logger.info("Database connection closed.")
-# --- FastAPI Startup Event ---
 @app.on_event("startup")
 async def startup_event():
-    try:
-    except
     except Exception as e:
     finally:
-    if
-@app.
-async def
-    """
-    path
+# main.py
 import duckdb
+import pandas as pd
+import pyarrow as pa
+import pyarrow.ipc
 from pathlib import Path
+import tempfile
+import os
+import shutil
+from typing import Optional, List, Dict, Any, Union, Iterator, Generator, Tuple
+
+from fastapi import FastAPI, HTTPException, Body, Query, BackgroundTasks, Depends
+from fastapi.responses import StreamingResponse, FileResponse
+from pydantic import BaseModel, Field
+
+from database_api import DatabaseAPI, DatabaseAPIError, QueryError
+
+# --- Configuration --- (Keep as before)
+DUCKDB_API_DB_PATH = os.getenv("DUCKDB_API_DB_PATH", "api_database.db")
+DUCKDB_API_READ_ONLY = os.getenv("DUCKDB_API_READ_ONLY", False)
+DUCKDB_API_CONFIG = {}
+TEMP_EXPORT_DIR = Path(tempfile.gettempdir()) / "duckdb_api_exports"
+TEMP_EXPORT_DIR.mkdir(exist_ok=True)
+print(f"Using temporary directory for exports: {TEMP_EXPORT_DIR}")
+
+# --- Pydantic Models --- (Keep as before)
+class StatusResponse(BaseModel):
+    status: str
+    message: Optional[str] = None
+
+class ExecuteRequest(BaseModel):
+    sql: str
+    parameters: Optional[List[Any]] = None
+
 class QueryRequest(BaseModel):
+    sql: str
+    parameters: Optional[List[Any]] = None
+
+class DataFrameResponse(BaseModel):
+    columns: List[str]
+    records: List[Dict[str, Any]]
+
+class InstallRequest(BaseModel):
+    extension_name: str
+    force_install: bool = False
+
+class LoadRequest(BaseModel):
+    extension_name: str
+
+class ExportDataRequest(BaseModel):
+    source: str = Field(..., description="Table name or SQL SELECT query to export")
+    options: Optional[Dict[str, Any]] = Field(None, description="Format-specific export options")
+
+# --- FastAPI Application --- (Keep as before)
+app = FastAPI(
+    title="DuckDB API Wrapper",
+    description="Exposes DuckDB functionalities via a RESTful API.",
+    version="0.2.1"  # Incremented version
+)
+
+# --- Global DatabaseAPI Instance & Lifecycle --- (Keep as before)
+db_api_instance: Optional[DatabaseAPI] = None
+
 @app.on_event("startup")
 async def startup_event():
+    global db_api_instance
+    print("Starting up DuckDB API...")
+    try:
+        db_api_instance = DatabaseAPI(db_path=DUCKDB_API_DB_PATH, read_only=DUCKDB_API_READ_ONLY, config=DUCKDB_API_CONFIG)
+    except DatabaseAPIError as e:
+        print(f"FATAL: Could not initialize DatabaseAPI on startup: {e}")
+        db_api_instance = None
+
+@app.on_event("shutdown")
+def shutdown_event():
+    print("Shutting down DuckDB API...")
+    if db_api_instance:
+        db_api_instance.close()
+
+# --- Dependency to get the DB API instance --- (Keep as before)
+def get_db_api() -> DatabaseAPI:
+    if db_api_instance is None:
+        raise HTTPException(status_code=503, detail="Database service is unavailable (failed to initialize).")
+    try:
+        db_api_instance._ensure_connection()
+        return db_api_instance
+    except DatabaseAPIError as e:
+        raise HTTPException(status_code=503, detail=f"Database service error: {e}")
+
+# --- API Endpoints ---
+
+# --- CRUD and Querying Endpoints (Keep as before) ---
+@app.post("/execute", response_model=StatusResponse, tags=["CRUD"])
+async def execute_statement(request: ExecuteRequest, api: DatabaseAPI = Depends(get_db_api)):
+    try:
+        api.execute_sql(request.sql, request.parameters)
+        return {"status": "success", "message": None}  # Explicitly return None for message
+    except QueryError as e:
+        raise HTTPException(status_code=400, detail=str(e))
+    except DatabaseAPIError as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+@app.post("/query/fetchall", response_model=List[tuple], tags=["Querying"])
+async def query_fetchall_endpoint(request: QueryRequest, api: DatabaseAPI = Depends(get_db_api)):
+    try:
+        return api.query_fetchall(request.sql, request.parameters)
+    except QueryError as e:
+        raise HTTPException(status_code=400, detail=str(e))
+    except DatabaseAPIError as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+@app.post("/query/dataframe", response_model=DataFrameResponse, tags=["Querying"])
+async def query_dataframe_endpoint(request: QueryRequest, api: DatabaseAPI = Depends(get_db_api)):
+    try:
+        df = api.query_df(request.sql, request.parameters)
+        df_serializable = df.replace({pd.NA: None, pd.NaT: None, float('nan'): None})
+        return {"columns": df_serializable.columns.tolist(), "records": df_serializable.to_dict(orient='records')}
+    except (QueryError, ImportError) as e:
+        raise HTTPException(status_code=400, detail=str(e))
+    except DatabaseAPIError as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+# --- Streaming Endpoints ---
+
+# --- CORRECTED _stream_arrow_ipc ---
+async def _stream_arrow_ipc(record_batch_iterator: Iterator[pa.RecordBatch]) -> Generator[bytes, None, None]:
+    """Helper generator to stream Arrow IPC Stream format."""
+    writer = None
+    sink = pa.BufferOutputStream()  # Create sink once
+    try:
+        first_batch = next(record_batch_iterator)
+        writer = pa.ipc.new_stream(sink, first_batch.schema)
+        writer.write_batch(first_batch)
+        # Do NOT yield yet, wait for potential subsequent batches or closure
+
+        for batch in record_batch_iterator:
+            # Write subsequent batches to the SAME writer
+            writer.write_batch(batch)
+
+    except StopIteration:
+        # Handles the case where the iterator was empty initially
+        if writer is None:  # No batches were ever processed
+            print("Warning: Arrow stream iterator was empty.")
+            # Yield empty bytes or handle as needed, depends on client expectation
+            # yield b''  # Option 1: empty bytes
+            return  # Option 2: Just finish generator
+
+    except Exception as e:
+        print(f"Error during Arrow streaming generator: {e}")
+        # Consider how to signal error downstream if possible
+    finally:
+        if writer:
+            try:
+                print("Closing Arrow IPC Stream Writer...")
+                writer.close()  # Close the writer to finalize the stream in the sink
+                print("Writer closed.")
+            except Exception as close_e:
+                print(f"Error closing Arrow writer: {close_e}")
+        if sink:
+            try:
+                buffer = sink.getvalue()
+                if buffer:
+                    print(f"Yielding final Arrow buffer (size: {len(buffer.to_pybytes())})...")
+                    yield buffer.to_pybytes()  # Yield the complete stream buffer
+                else:
+                    print("Arrow sink buffer was empty after closing writer.")
+                sink.close()
+            except Exception as close_e:
+                print(f"Error closing or getting value from Arrow sink: {close_e}")
+# --- END CORRECTION ---
+
+
+@app.post("/query/stream/arrow", tags=["Streaming"])
+async def query_stream_arrow_endpoint(request: QueryRequest, api: DatabaseAPI = Depends(get_db_api)):
+    """Executes a SQL query and streams results as Arrow IPC Stream format."""
+    try:
+        iterator = api.stream_query_arrow(request.sql, request.parameters)
+        return StreamingResponse(
+            _stream_arrow_ipc(iterator),
+            media_type="application/vnd.apache.arrow.stream"
+        )
+    except (QueryError, ImportError) as e:
+        raise HTTPException(status_code=400, detail=str(e))
+    except DatabaseAPIError as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+# --- _stream_jsonl (Keep as before) ---
+async def _stream_jsonl(dataframe_iterator: Iterator[pd.DataFrame]) -> Generator[bytes, None, None]:
+    try:
+        for df_chunk in dataframe_iterator:
+            df_serializable = df_chunk.replace({pd.NA: None, pd.NaT: None, float('nan'): None})
+            jsonl_string = df_serializable.to_json(orient='records', lines=True, date_format='iso')
+            if jsonl_string:
+                # pandas>=1.5.0 adds newline by default
+                if not jsonl_string.endswith('\n'):
+                    jsonl_string += '\n'
+                yield jsonl_string.encode('utf-8')
+    except Exception as e:
+        print(f"Error during JSONL streaming generator: {e}")
+
+@app.post("/query/stream/jsonl", tags=["Streaming"])
+async def query_stream_jsonl_endpoint(request: QueryRequest, api: DatabaseAPI = Depends(get_db_api)):
+    """Executes a SQL query and streams results as JSON Lines (JSONL)."""
+    try:
+        iterator = api.stream_query_df(request.sql, request.parameters)
+        return StreamingResponse(_stream_jsonl(iterator), media_type="application/jsonl")
+    except (QueryError, ImportError) as e:
+        raise HTTPException(status_code=400, detail=str(e))
+    except DatabaseAPIError as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+# --- Download / Export Endpoints (Keep as before, uses corrected _export_data) ---
+def _cleanup_temp_file(path: Union[str, Path]):
+    try:
+        if Path(path).is_file():
+            os.remove(path)
+            print(f"Cleaned up temporary file: {path}")
+    except OSError as e:
+        print(f"Error cleaning up temporary file {path}: {e}")
+
+async def _create_temp_export(
+        api: DatabaseAPI,
+        source: str,
+        export_format: str,
+        options: Optional[Dict[str, Any]] = None,
+        suffix: str = ".tmp"
+) -> Path:
+    fd, temp_path_str = tempfile.mkstemp(suffix=suffix, dir=TEMP_EXPORT_DIR)
+    os.close(fd)
+    temp_file_path = Path(temp_path_str)
+
+    try:
+        print(f"Exporting to temporary file: {temp_file_path}")
+        if export_format == 'csv':
+            api.export_data_to_csv(source, temp_file_path, options)
+        elif export_format == 'parquet':
+            api.export_data_to_parquet(source, temp_file_path, options)
+        elif export_format == 'json':
+            api.export_data_to_json(source, temp_file_path, array_format=True, options=options)
+        elif export_format == 'jsonl':
+            api.export_data_to_jsonl(source, temp_file_path, options=options)
+        else:
+            raise ValueError(f"Unsupported export format: {export_format}")
+        return temp_file_path
+    except Exception as e:
+        _cleanup_temp_file(temp_file_path)
+        raise e
+
+@app.post("/export/data/csv", response_class=FileResponse, tags=["Export / Download"])
+async def export_csv_endpoint(request: ExportDataRequest, background_tasks: BackgroundTasks, api: DatabaseAPI = Depends(get_db_api)):
+    try:
+        temp_file_path = await _create_temp_export(api, request.source, 'csv', request.options, suffix=".csv")
+        background_tasks.add_task(_cleanup_temp_file, temp_file_path)
+        filename = f"export_{Path(request.source).stem if '.' not in request.source else 'query'}.csv"
+        return FileResponse(temp_file_path, media_type='text/csv', filename=filename)
+    except (QueryError, ValueError) as e:
+        raise HTTPException(status_code=400, detail=str(e))
+    except DatabaseAPIError as e:
+        raise HTTPException(status_code=500, detail=str(e))
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Unexpected error during CSV export: {e}")
+
+@app.post("/export/data/parquet", response_class=FileResponse, tags=["Export / Download"])
+async def export_parquet_endpoint(request: ExportDataRequest, background_tasks: BackgroundTasks, api: DatabaseAPI = Depends(get_db_api)):
+    try:
+        temp_file_path = await _create_temp_export(api, request.source, 'parquet', request.options, suffix=".parquet")
+        background_tasks.add_task(_cleanup_temp_file, temp_file_path)
+        filename = f"export_{Path(request.source).stem if '.' not in request.source else 'query'}.parquet"
+        return FileResponse(temp_file_path, media_type='application/vnd.apache.parquet', filename=filename)
+    except (QueryError, ValueError) as e:
+        raise HTTPException(status_code=400, detail=str(e))
+    except DatabaseAPIError as e:
+        raise HTTPException(status_code=500, detail=str(e))
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Unexpected error during Parquet export: {e}")
+
+@app.post("/export/data/json", response_class=FileResponse, tags=["Export / Download"])
+async def export_json_endpoint(request: ExportDataRequest, background_tasks: BackgroundTasks, api: DatabaseAPI = Depends(get_db_api)):
+    try:
+        temp_file_path = await _create_temp_export(api, request.source, 'json', request.options, suffix=".json")
+        background_tasks.add_task(_cleanup_temp_file, temp_file_path)
+        filename = f"export_{Path(request.source).stem if '.' not in request.source else 'query'}.json"
+        return FileResponse(temp_file_path, media_type='application/json', filename=filename)
+    except (QueryError, ValueError) as e:
+        raise HTTPException(status_code=400, detail=str(e))
+    except DatabaseAPIError as e:
+        raise HTTPException(status_code=500, detail=str(e))
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Unexpected error during JSON export: {e}")
+
+@app.post("/export/data/jsonl", response_class=FileResponse, tags=["Export / Download"])
+async def export_jsonl_endpoint(request: ExportDataRequest, background_tasks: BackgroundTasks, api: DatabaseAPI = Depends(get_db_api)):
+    try:
+        temp_file_path = await _create_temp_export(api, request.source, 'jsonl', request.options, suffix=".jsonl")
+        background_tasks.add_task(_cleanup_temp_file, temp_file_path)
+        filename = f"export_{Path(request.source).stem if '.' not in request.source else 'query'}.jsonl"
+        return FileResponse(temp_file_path, media_type='application/jsonl', filename=filename)
+    except (QueryError, ValueError) as e:
+        raise HTTPException(status_code=400, detail=str(e))
+    except DatabaseAPIError as e:
+        raise HTTPException(status_code=500, detail=str(e))
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Unexpected error during JSONL export: {e}")
+
+@app.post("/export/database", response_class=FileResponse, tags=["Export / Download"])
+async def export_database_endpoint(background_tasks: BackgroundTasks, api: DatabaseAPI = Depends(get_db_api)):
+    export_target_dir = Path(tempfile.mkdtemp(dir=TEMP_EXPORT_DIR))
+    fd, zip_path_str = tempfile.mkstemp(suffix=".zip", dir=TEMP_EXPORT_DIR)
+    os.close(fd)
+    zip_file_path = Path(zip_path_str)
+    try:
+        print(f"Exporting database to temporary directory: {export_target_dir}")
+        api.export_database(export_target_dir)
+        print(f"Creating zip archive at: {zip_file_path}")
+        shutil.make_archive(str(zip_file_path.with_suffix('')), 'zip', str(export_target_dir))
+        print(f"Zip archive created: {zip_file_path}")
+        background_tasks.add_task(shutil.rmtree, export_target_dir, ignore_errors=True)
+        background_tasks.add_task(_cleanup_temp_file, zip_file_path)
+        db_name = Path(api._db_path).stem if api._db_path != ':memory:' else 'in_memory_db'
+        return FileResponse(zip_file_path, media_type='application/zip', filename=f"{db_name}_export.zip")
+    except (QueryError, ValueError, OSError, DatabaseAPIError) as e:
+        print(f"Error during database export: {e}")
+        shutil.rmtree(export_target_dir, ignore_errors=True)
+        _cleanup_temp_file(zip_file_path)
+        if isinstance(e, DatabaseAPIError):
+            raise HTTPException(status_code=500, detail=str(e))
+        else:
+            raise HTTPException(status_code=400, detail=str(e))
+    except Exception as e:
+        print(f"Unexpected error during database export: {e}")
+        shutil.rmtree(export_target_dir, ignore_errors=True)
+        _cleanup_temp_file(zip_file_path)
+        raise HTTPException(status_code=500, detail=f"Unexpected error during database export: {e}")
+
+# --- Extension Management Endpoints ---
+
+@app.post("/extensions/install", response_model=StatusResponse, tags=["Extensions"])
+async def install_extension_endpoint(request: InstallRequest, api: DatabaseAPI = Depends(get_db_api)):
+    try:
+        api.install_extension(request.extension_name, request.force_install)
+        return {"status": "success", "message": f"Extension '{request.extension_name}' installed."}
+    except DatabaseAPIError as e:
+        raise HTTPException(status_code=500, detail=str(e))
+    # Catch specific DuckDB errors that should be client errors (400)
+    except (duckdb.IOException, duckdb.CatalogException, duckdb.InvalidInputException) as e:
+        raise HTTPException(status_code=400, detail=f"DuckDB Error during install: {e}")
+    except duckdb.Error as e:  # Catch other potential DuckDB errors as 500
+        raise HTTPException(status_code=500, detail=f"Unexpected DuckDB Error during install: {e}")
+
+
+@app.post("/extensions/load", response_model=StatusResponse, tags=["Extensions"])
+async def load_extension_endpoint(request: LoadRequest, api: DatabaseAPI = Depends(get_db_api)):
+    """Loads an installed DuckDB extension."""
+    try:
+        api.load_extension(request.extension_name)
+        return {"status": "success", "message": f"Extension '{request.extension_name}' loaded."}
+    # --- MODIFIED Exception Handling ---
+    except QueryError as e:  # If api.load_extension raised QueryError (e.g., IO/Catalog)
+        raise HTTPException(status_code=400, detail=str(e))
+    except DatabaseAPIError as e:  # For other API-level issues
+        raise HTTPException(status_code=500, detail=str(e))
+    # Catch specific DuckDB errors that should be client errors (400)
+    except (duckdb.IOException, duckdb.CatalogException) as e:
+        raise HTTPException(status_code=400, detail=f"DuckDB Error during load: {e}")
+    except duckdb.Error as e:  # Catch other potential DuckDB errors as 500
+        raise HTTPException(status_code=500, detail=f"Unexpected DuckDB Error during load: {e}")
+    # --- END MODIFICATION ---
+
+# --- Health Check --- (Keep as before)
+@app.get("/health", response_model=StatusResponse, tags=["Health"])
+async def health_check():
+    """Basic health check."""
+    try:
+        _ = get_db_api()
+        return {"status": "ok", "message": None}  # Explicitly return None for message
+    except HTTPException as e:
+        raise e
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Health check failed unexpectedly: {e}")
+
+# --- Run the app --- (Keep as before)
+if __name__ == "__main__":
+    import uvicorn
+    print(f"Starting DuckDB API server...")
+    print(f"Database file configured at: {DUCKDB_API_DB_PATH}")
+    print(f"Read-only mode: {DUCKDB_API_READ_ONLY}")
+    print(f"Temporary export directory: {TEMP_EXPORT_DIR}")
+    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
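A minimal client-side sketch for the /execute and /query/fetchall endpoints defined above, assuming the server is running locally on port 8000 and using the httpx dependency listed in requirements.txt below; the demo table and values are hypothetical, not part of the commit.

# Hypothetical client sketch, not part of the commit.
import httpx

with httpx.Client(base_url="http://localhost:8000") as client:
    client.post("/execute", json={"sql": "CREATE OR REPLACE TABLE demo(id INTEGER, name VARCHAR)"})
    client.post("/execute", json={"sql": "INSERT INTO demo VALUES (?, ?)", "parameters": [1, "Ada"]})
    rows = client.post("/query/fetchall", json={"sql": "SELECT * FROM demo"}).json()
    print(rows)  # expected: [[1, "Ada"]]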
requirements.txt
CHANGED
@@ -1,5 +1,6 @@
 fastapi
 uvicorn[standard]
-duckdb>=1.
+duckdb>=1.2.1
 pydantic
-python-multipart
+python-multipart
+httpx
test_api.py
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pytest
|
2 |
+
import os
|
3 |
+
import shutil
|
4 |
+
import tempfile
|
5 |
+
import zipfile
|
6 |
+
import json
|
7 |
+
from pathlib import Path
|
8 |
+
from typing import List, Dict, Any
|
9 |
+
from unittest.mock import patch
|
10 |
+
|
11 |
+
pd = pytest.importorskip("pandas")
|
12 |
+
pa = pytest.importorskip("pyarrow")
|
13 |
+
pa_ipc = pytest.importorskip("pyarrow.ipc")
|
14 |
+
|
15 |
+
from fastapi.testclient import TestClient
|
16 |
+
import main # Import main to reload and access config
|
17 |
+
|
18 |
+
# --- Test Fixtures --- (Keep client fixture as before)
|
19 |
+
@pytest.fixture(scope="module")
|
20 |
+
def client():
|
21 |
+
with patch.dict(os.environ, {"DUCKDB_API_DB_PATH": ":memory:"}):
|
22 |
+
import importlib
|
23 |
+
importlib.reload(main)
|
24 |
+
main.TEMP_EXPORT_DIR.mkdir(exist_ok=True)
|
25 |
+
print(f"TestClient using temp export dir: {main.TEMP_EXPORT_DIR}")
|
26 |
+
with TestClient(main.app) as c:
|
27 |
+
yield c
|
28 |
+
print(f"Cleaning up test export dir: {main.TEMP_EXPORT_DIR}")
|
29 |
+
for item in main.TEMP_EXPORT_DIR.iterdir():
|
30 |
+
try:
|
31 |
+
if item.is_file():
|
32 |
+
os.remove(item)
|
33 |
+
elif item.is_dir():
|
34 |
+
shutil.rmtree(item)
|
35 |
+
except Exception as e:
|
36 |
+
print(f"Error cleaning up {item}: {e}")
|
37 |
+
|
38 |
+
# --- Test Classes ---
|
39 |
+
|
40 |
+
class TestHealth: # (Keep as before)
|
41 |
+
def test_health_check(self, client: TestClient):
|
42 |
+
response = client.get("/health")
|
43 |
+
assert response.status_code == 200
|
44 |
+
assert response.json() == {"status": "ok", "message": None}
|
45 |
+
|
class TestExecution:
    def test_execute_create(self, client: TestClient):
        response = client.post("/execute", json={"sql": "CREATE TABLE test_table(id INTEGER, name VARCHAR);"})
        assert response.status_code == 200
        assert response.json() == {"status": "success", "message": None}
        response_fail = client.post("/execute", json={"sql": "CREATE TABLE test_table(id INTEGER);"})
        assert response_fail.status_code == 400

    def test_execute_insert(self, client: TestClient):
        client.post("/execute", json={"sql": "CREATE OR REPLACE TABLE test_table(id INTEGER, name VARCHAR);"})
        response = client.post("/execute", json={"sql": "INSERT INTO test_table VALUES (1, 'Alice')"})
        assert response.status_code == 200
        query_response = client.post("/query/fetchall", json={"sql": "SELECT COUNT(*) FROM test_table"})
        assert query_response.status_code == 200
        assert query_response.json() == [[1]]

    def test_execute_insert_params(self, client: TestClient):
        client.post("/execute", json={"sql": "CREATE OR REPLACE TABLE test_table(id INTEGER, name VARCHAR);"})
        response = client.post("/execute", json={"sql": "INSERT INTO test_table VALUES (?, ?)", "parameters": [2, "Bob"]})
        assert response.status_code == 200
        query_response = client.post("/query/fetchall", json={"sql": "SELECT * FROM test_table WHERE id = 2"})
        assert query_response.status_code == 200
        assert query_response.json() == [[2, "Bob"]]

    def test_execute_invalid_sql(self, client: TestClient):
        response = client.post("/execute", json={"sql": "INVALID SQL STATEMENT"})
        assert response.status_code == 400
        assert "Parser Error" in response.json()["detail"]

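# /query/fetchall returns rows as JSON arrays; /query/dataframe returns a
# {"columns": [...], "records": [...]} payload. Binder errors map to HTTP 400.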
class TestQuerying:
    @pytest.fixture(scope="class", autouse=True)
    def setup_data(self, client: TestClient):
        client.post("/execute", json={"sql": "CREATE OR REPLACE TABLE query_test(id INTEGER, val VARCHAR)"})
        client.post("/execute", json={"sql": "INSERT INTO query_test VALUES (1, 'one'), (2, 'two'), (3, 'three')"})

    def test_query_fetchall(self, client: TestClient):
        response = client.post("/query/fetchall", json={"sql": "SELECT * FROM query_test ORDER BY id"})
        assert response.status_code == 200
        assert response.json() == [[1, 'one'], [2, 'two'], [3, 'three']]

    def test_query_fetchall_params(self, client: TestClient):
        response = client.post("/query/fetchall", json={"sql": "SELECT * FROM query_test WHERE id > ? ORDER BY id", "parameters": [1]})
        assert response.status_code == 200
        assert response.json() == [[2, 'two'], [3, 'three']]

    def test_query_fetchall_empty(self, client: TestClient):
        response = client.post("/query/fetchall", json={"sql": "SELECT * FROM query_test WHERE id > 100"})
        assert response.status_code == 200
        assert response.json() == []

    def test_query_dataframe(self, client: TestClient):
        response = client.post("/query/dataframe", json={"sql": "SELECT * FROM query_test ORDER BY id"})
        assert response.status_code == 200
        data = response.json()
        assert data["columns"] == ["id", "val"]
        assert data["records"] == [
            {"id": 1, "val": "one"},
            {"id": 2, "val": "two"},
            {"id": 3, "val": "three"}
        ]

    def test_query_dataframe_invalid_sql(self, client: TestClient):
        response = client.post("/query/dataframe", json={"sql": "SELECT non_existent FROM query_test"})
        assert response.status_code == 400
        assert "Binder Error" in response.json()["detail"]

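# Streaming endpoints: /query/stream/arrow yields an Arrow IPC stream (readable
# with pyarrow.ipc.open_stream), /query/stream/jsonl yields one JSON object per
# line. Empty result sets may produce an empty body, which the tests tolerate.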
class TestStreaming:
    @pytest.fixture(scope="class", autouse=True)
    def setup_data(self, client: TestClient):
        client.post("/execute", json={"sql": "CREATE OR REPLACE TABLE stream_test AS SELECT range AS id, range % 5 AS category FROM range(10)"})

    def test_stream_arrow(self, client: TestClient):
        response = client.post("/query/stream/arrow", json={"sql": "SELECT * FROM stream_test"})
        assert response.status_code == 200
        assert response.headers["content-type"] == "application/vnd.apache.arrow.stream"
        if not response.content:
            pytest.fail("Arrow stream response content is empty")
        try:
            reader = pa_ipc.open_stream(response.content)
            table = reader.read_all()
        except pa.ArrowInvalid as e:
            pytest.fail(f"Failed to read Arrow stream: {e}")
        assert table.num_rows == 10
        assert table.column_names == ["id", "category"]
        assert table.column('id').to_pylist() == list(range(10))

    def test_stream_arrow_empty(self, client: TestClient):
        response = client.post("/query/stream/arrow", json={"sql": "SELECT * FROM stream_test WHERE id < 0"})
        assert response.status_code == 200
        assert response.headers["content-type"] == "application/vnd.apache.arrow.stream"
        try:
            reader = pa_ipc.open_stream(response.content)
            table = reader.read_all()
            assert table.num_rows == 0
        except pa.ArrowInvalid as e:
            print(f"Received ArrowInvalid for empty stream, which is acceptable: {e}")
            assert response.content == b''

    def test_stream_jsonl(self, client: TestClient):
        response = client.post("/query/stream/jsonl", json={"sql": "SELECT * FROM stream_test ORDER BY id"})
        assert response.status_code == 200
        assert response.headers["content-type"] == "application/jsonl"
        lines = response.text.strip().split('\n')
        records = [json.loads(line) for line in lines if line]
        assert len(records) == 10
        assert records[0] == {"id": 0, "category": 0}
        assert records[9] == {"id": 9, "category": 4}

    def test_stream_jsonl_empty(self, client: TestClient):
        response = client.post("/query/stream/jsonl", json={"sql": "SELECT * FROM stream_test WHERE id < 0"})
        assert response.status_code == 200
        assert response.headers["content-type"] == "application/jsonl"
        assert response.text.strip() == ""

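# /export/data/{format} downloads a table or query result as CSV, Parquet, JSON,
# or JSONL with a Content-Disposition filename; /export/database returns a zip
# archive containing schema.sql, load.sql, and per-table data files.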
class TestExportDownload:
    @pytest.fixture(scope="class", autouse=True)
    def setup_data(self, client: TestClient):
        client.post("/execute", json={"sql": "CREATE OR REPLACE TABLE export_table(id INTEGER, name VARCHAR, price DECIMAL(5,2))"})
        client.post("/execute", json={"sql": "INSERT INTO export_table VALUES (1, 'Apple', 0.50), (2, 'Banana', 0.30), (3, 'Orange', 0.75)"})

    @pytest.mark.parametrize(
        "endpoint_suffix, expected_content_type, expected_filename_ext, validation_fn",
        [
            ("csv", "text/csv", ".csv", lambda c: b"id,name,price\n1,Apple,0.50\n" in c),
            ("parquet", "application/vnd.apache.parquet", ".parquet", lambda c: c.startswith(b"PAR1")),
            ("json", "application/json", ".json", lambda c: c.strip().startswith(b'[') and c.strip().endswith(b']')),
            ("jsonl", "application/jsonl", ".jsonl", lambda c: b'"id":1' in c and b'"name":"Apple"' in c and b'\n' in c),
        ]
    )
    def test_export_data(self, client: TestClient, endpoint_suffix, expected_content_type, expected_filename_ext, validation_fn, tmp_path):
        endpoint = f"/export/data/{endpoint_suffix}"
        payload = {"source": "export_table"}
        if endpoint_suffix == 'csv':
            payload['options'] = {'HEADER': True}

        response = client.post(endpoint, json=payload)

        assert response.status_code == 200, f"Request to {endpoint} failed: {response.text}"
        assert response.headers["content-type"].startswith(expected_content_type)
        assert "content-disposition" in response.headers
        assert f'filename="export_export_table{expected_filename_ext}"' in response.headers["content-disposition"]

        downloaded_path = tmp_path / f"downloaded{expected_filename_ext}"
        with open(downloaded_path, "wb") as f:
            f.write(response.content)
        assert downloaded_path.exists()
        assert validation_fn(response.content), f"Validation failed for {endpoint_suffix}"

        # Test with a query source
        payload = {"source": "SELECT id, name FROM export_table WHERE price > 0.40 ORDER BY id"}
        response = client.post(endpoint, json=payload)
        assert response.status_code == 200
        assert f'filename="export_query{expected_filename_ext}"' in response.headers["content-disposition"]
        assert len(response.content) > 0

    def test_export_database(self, client: TestClient, tmp_path):
        client.post("/execute", json={"sql": "CREATE TABLE IF NOT EXISTS another_table(x int)"})
        response = client.post("/export/database")
        assert response.status_code == 200
        assert response.headers["content-type"] == "application/zip"
        assert "content-disposition" in response.headers
        assert response.headers["content-disposition"].startswith("attachment; filename=")
        assert 'filename="in_memory_db_export.zip"' in response.headers["content-disposition"]
        zip_path = tmp_path / "db_export.zip"
        with open(zip_path, "wb") as f:
            f.write(response.content)
        assert zip_path.exists()
        with zipfile.ZipFile(zip_path, 'r') as z:
            print(f"Zip contents: {z.namelist()}")
            assert "schema.sql" in z.namelist()
            assert "load.sql" in z.namelist()
            assert any(name.startswith("export_table") for name in z.namelist())
            assert any(name.startswith("another_table") for name in z.namelist())

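# /extensions/install and /extensions/load: failures for an unknown extension
# come back as HTTP errors with a descriptive "detail" message. The httpfs
# round-trip test is skipped by default, per its skip reason, because it needs
# the httpfs extension to actually be installable in the test environment.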
class TestExtensions:
    def test_install_extension_fail(self, client: TestClient):
        response = client.post("/extensions/install", json={"extension_name": "nonexistent_dummy_ext"})
        assert response.status_code >= 400
        assert "Error during install" in response.json()["detail"] or "Failed to download" in response.json()["detail"]

    def test_load_extension_fail(self, client: TestClient):
        response = client.post("/extensions/load", json={"extension_name": "nonexistent_dummy_ext"})
        assert response.status_code == 400
        assert "Error loading extension" in response.json()["detail"]
        assert "not found" in response.json()["detail"].lower()

    @pytest.mark.skip(reason="Requires httpfs extension to be available for install/load")
    def test_install_and_load_httpfs(self, client: TestClient):
        install_response = client.post("/extensions/install", json={"extension_name": "httpfs"})
        assert install_response.status_code == 200
        assert install_response.json()["status"] == "success"

        load_response = client.post("/extensions/load", json={"extension_name": "httpfs"})
        assert load_response.status_code == 200
        assert load_response.json()["status"] == "success"