amaye15 committed on
Commit
c26b6eb
·
1 Parent(s): f574b59
Files changed (8)
  1. .dockerignore +45 -0
  2. .gitignore +45 -0
  3. Dockerfile +23 -47
  4. README.md +1 -1
  5. database_api.py +426 -0
  6. main.py +372 -169
  7. requirements.txt +3 -2
  8. test_api.py +246 -0
.dockerignore ADDED
@@ -0,0 +1,45 @@
+ # .dockerignore
+ __pycache__/
+ *.pyc
+ *.pyo
+ *.pyd
+ .Python
+ env/
+ .env
+ .venv/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ .pytest_cache/
+ .mypy_cache/
+ .nox/
+ .tox/
+ .coverage
+ .coverage.*
+ coverage.xml
+ htmlcov/
+ .hypothesis/
+
+ *.db
+ *.db.wal
+ *.log
+ *.sqlite
+ *.sqlite3
+
+ # Ignore specific generated files if needed
+ api_database.db
+ api_database.db.wal
+ my_duckdb_api_db.db
+ my_duckdb_api_db.db.wal
+ exported_db/
+ duckdb_api_exports/ # Don't copy local temp exports
+
+ # OS-specific files
+ .DS_Store
+ Thumbs.db
+
+ # IDE files
+ .idea/
+ .vscode/
.gitignore ADDED
@@ -0,0 +1,45 @@
+ # .gitignore
+ __pycache__/
+ *.pyc
+ *.pyo
+ *.pyd
+ .Python
+ env/
+ .env
+ .venv/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ .pytest_cache/
+ .mypy_cache/
+ .nox/
+ .tox/
+ .coverage
+ .coverage.*
+ coverage.xml
+ htmlcov/
+ .hypothesis/
+
+ *.db
+ *.db.wal
+ *.log
+ *.sqlite
+ *.sqlite3
+
+ # Ignore specific generated files if needed
+ api_database.db
+ api_database.db.wal
+ my_duckdb_api_db.db
+ my_duckdb_api_db.db.wal
+ exported_db/
+ duckdb_api_exports/ # Don't track local temp exports
+
+ # OS-specific files
+ .DS_Store
+ Thumbs.db
+
+ # IDE files
+ .idea/
+ .vscode/
Dockerfile CHANGED
@@ -1,57 +1,33 @@
- # Use an official Python runtime as a parent image
- FROM python:3.10-slim

- # Define arguments for user/group IDs (optional, but good practice)
- ARG USER_ID=1001
- ARG GROUP_ID=1001

- # Create a non-root user and group
- # Use standard IDs > 1000. Don't use 'node' or common names if not applicable.
- RUN groupadd --system --gid ${GROUP_ID} appgroup && \
- useradd --system --uid ${USER_ID} --gid appgroup --shell /sbin/nologin --create-home appuser

- # Set the working directory
  WORKDIR /app

- # Create essential directories and set ownership *before* copying files
- # DuckDB UI often uses ~/.duckdb (which will be /home/appuser/.duckdb)
- # Ensure these are owned by the user *before* VOLUME instruction
- RUN mkdir -p /app/data /home/appuser/.duckdb && \
- chown -R ${USER_ID}:${GROUP_ID} /app /home/appuser/.duckdb
-
- # Switch context to the non-root user early for subsequent RUN/COPY commands
- USER appuser
-
- # Copy requirements file (as appuser)
  COPY requirements.txt .

- # Install dependencies (as appuser)
- # This also ensures packages are installed in a user-context if applicable
- RUN pip install --no-cache-dir --user --upgrade pip && \
- pip install --no-cache-dir --user -r requirements.txt

- # Copy application code (as appuser)
- COPY main.py .

- # --- Define Volumes ---
- # These paths MUST match the directories the 'appuser' process will write to.
- # Note: We created and chowned these earlier.
- VOLUME /app/data
- VOLUME /home/appuser/.duckdb
- # --- End Define Volumes ---
-
- # Expose ports
  EXPOSE 8000
- EXPOSE 8080
-
- # Define environment variables
- ENV PYTHONUNBUFFERED=1
- ENV UI_EXPECTED_PORT=8080
- # Ensure Python user packages are in the path
- ENV PATH="/home/appuser/.local/bin:${PATH}"
- # Set HOME so things like ~/.duckdb resolve correctly
- ENV HOME=/home/appuser
-
- # Command to run the application (now runs as appuser)
- # No chmod needed here. Ownership was handled during build.
- CMD ["python", "-m", "uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]

+ # Dockerfile
+
+ # 1. Choose a base Python image
+ # Using a specific version is recommended for reproducibility.
+ # The '-slim' variant is smaller.
+ FROM python:3.12-slim
+
+ # 2. Set environment variables (optional but good practice)
+ ENV PYTHONDONTWRITEBYTECODE 1
+ ENV PYTHONUNBUFFERED 1
+
+ # 3. Set the working directory inside the container
  WORKDIR /app
+
+ # 4. Copy only the requirements file first to leverage Docker cache
  COPY requirements.txt .
+
+ # 5. Install dependencies
+ # --no-cache-dir makes the image smaller
+ # --upgrade pip ensures we have the latest pip
+ RUN pip install --no-cache-dir --upgrade pip && \
+ pip install --no-cache-dir -r requirements.txt
+
+ # 6. Copy the rest of the application code into the working directory
+ COPY . .
+
+ # 7. Expose the port the app runs on (uvicorn default is 8000)
  EXPOSE 8000
+
+ # 8. Define the default command to run when the container starts
+ # Use exec form for proper signal handling.
+ # Do NOT use --reload in production.
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
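The simplified Dockerfile drops the non-root user, the UI extension port and the volume setup, and just runs uvicorn on port 8000. A quick way to sanity-check the rebuilt image is to call the `/health` endpoint that the new `main.py` (below) exposes; a minimal sketch using `httpx`, assuming the image has been built and started locally with port 8000 published (the image tag and `docker run` flags are assumptions, not part of this commit):

```python
# Smoke test against a locally running container, e.g. started with
#   docker build -t duckdb-api . && docker run -p 8000:8000 duckdb-api
# (hypothetical tag/flags; adjust to your setup)
import httpx

resp = httpx.get("http://localhost:8000/health", timeout=5.0)
resp.raise_for_status()
print(resp.json())  # expected: {'status': 'ok', 'message': None}
```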
README.md CHANGED
@@ -7,7 +7,7 @@ sdk: docker
  pinned: false
  license: mit
  short_description: DuckDB Hosting with UI & FastAPI 4 SQL Calls & DB Downloads
- port:
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

  pinned: false
  license: mit
  short_description: DuckDB Hosting with UI & FastAPI 4 SQL Calls & DB Downloads
+ port: 8000
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
database_api.py ADDED
@@ -0,0 +1,426 @@
1
+ # database_api.py
2
+ import duckdb
3
+ import pandas as pd
4
+ import pyarrow as pa
5
+ import pyarrow.ipc
6
+ from pathlib import Path
7
+ import tempfile
8
+ import os
9
+ import shutil
10
+ from typing import Optional, List, Dict, Any, Union, Iterator, Generator, Tuple
11
+ # No need for pybind11 import here anymore
12
+
13
+ # --- Custom Exceptions ---
14
+ class DatabaseAPIError(Exception):
15
+ """Base exception for our custom API."""
16
+ pass
17
+
18
+ class QueryError(DatabaseAPIError):
19
+ """Exception raised for errors during query execution."""
20
+ pass
21
+
22
+ # --- Helper function to format COPY options ---
23
+ def _format_copy_options(options: Optional[Dict[str, Any]]) -> str:
24
+ if not options:
25
+ return ""
26
+ opts_parts = []
27
+ for k, v in options.items():
28
+ key_upper = k.upper()
29
+ if isinstance(v, bool):
30
+ value_repr = str(v).upper()
31
+ elif isinstance(v, (int, float)):
32
+ value_repr = str(v)
33
+ elif isinstance(v, str):
34
+ escaped_v = v.replace("'", "''")
35
+ value_repr = f"'{escaped_v}'"
36
+ else:
37
+ value_repr = repr(v)
38
+ opts_parts.append(f"{key_upper} {value_repr}")
39
+
40
+ opts_str = ", ".join(opts_parts)
41
+ return f"WITH ({opts_str})"
42
+
43
+ # --- Main DatabaseAPI Class ---
44
+ class DatabaseAPI:
45
+ def __init__(self,
46
+ db_path: Union[str, Path] = ":memory:",
47
+ read_only: bool = False,
48
+ config: Optional[Dict[str, str]] = None):
49
+ self._db_path = str(db_path)
50
+ self._config = config or {}
51
+ self._read_only = read_only
52
+ self._conn: Optional[duckdb.DuckDBPyConnection] = None
53
+ try:
54
+ self._conn = duckdb.connect(
55
+ database=self._db_path,
56
+ read_only=self._read_only,
57
+ config=self._config
58
+ )
59
+ print(f"Connected to DuckDB database at '{self._db_path}'")
60
+ except duckdb.Error as e:
61
+ print(f"Failed to connect to DuckDB: {e}")
62
+ raise DatabaseAPIError(f"Failed to connect to DuckDB: {e}") from e
63
+
64
+ def _ensure_connection(self):
65
+ if self._conn is None:
66
+ raise DatabaseAPIError("Database connection is not established or has been closed.")
67
+ try:
68
+ self._conn.execute("SELECT 1", [])
69
+ except (duckdb.ConnectionException, RuntimeError) as e:
70
+ if "Connection has already been closed" in str(e) or "connection closed" in str(e).lower():
71
+ self._conn = None
72
+ raise DatabaseAPIError("Database connection is closed.") from e
73
+ else:
74
+ raise DatabaseAPIError(f"Database connection error: {e}") from e
75
+
76
+ # --- Basic Query Methods --- (Keep as before)
77
+ def execute_sql(self, sql: str, parameters: Optional[List[Any]] = None) -> None:
78
+ self._ensure_connection()
79
+ print(f"Executing SQL: {sql}")
80
+ try:
81
+ self._conn.execute(sql, parameters)
82
+ except duckdb.Error as e:
83
+ print(f"Error executing SQL: {e}")
84
+ raise QueryError(f"Error executing SQL: {e}") from e
85
+
86
+ def query_sql(self, sql: str, parameters: Optional[List[Any]] = None) -> duckdb.DuckDBPyRelation:
87
+ self._ensure_connection()
88
+ print(f"Querying SQL: {sql}")
89
+ try:
90
+ return self._conn.sql(sql, params=parameters)
91
+ except duckdb.Error as e:
92
+ print(f"Error querying SQL: {e}")
93
+ raise QueryError(f"Error querying SQL: {e}") from e
94
+
95
+ def query_df(self, sql: str, parameters: Optional[List[Any]] = None) -> pd.DataFrame:
96
+ self._ensure_connection()
97
+ print(f"Querying SQL to DataFrame: {sql}")
98
+ try:
99
+ return self._conn.execute(sql, parameters).df()
100
+ except ImportError:
101
+ print("Pandas library is required for DataFrame operations.")
102
+ raise
103
+ except duckdb.Error as e:
104
+ print(f"Error querying SQL to DataFrame: {e}")
105
+ raise QueryError(f"Error querying SQL to DataFrame: {e}") from e
106
+
107
+ def query_arrow(self, sql: str, parameters: Optional[List[Any]] = None) -> pa.Table:
108
+ self._ensure_connection()
109
+ print(f"Querying SQL to Arrow Table: {sql}")
110
+ try:
111
+ return self._conn.execute(sql, parameters).arrow()
112
+ except ImportError:
113
+ print("PyArrow library is required for Arrow operations.")
114
+ raise
115
+ except duckdb.Error as e:
116
+ print(f"Error querying SQL to Arrow Table: {e}")
117
+ raise QueryError(f"Error querying SQL to Arrow Table: {e}") from e
118
+
119
+ def query_fetchall(self, sql: str, parameters: Optional[List[Any]] = None) -> List[Tuple[Any, ...]]:
120
+ self._ensure_connection()
121
+ print(f"Querying SQL and fetching all: {sql}")
122
+ try:
123
+ return self._conn.execute(sql, parameters).fetchall()
124
+ except duckdb.Error as e:
125
+ print(f"Error querying SQL: {e}")
126
+ raise QueryError(f"Error querying SQL: {e}") from e
127
+
128
+ def query_fetchone(self, sql: str, parameters: Optional[List[Any]] = None) -> Optional[Tuple[Any, ...]]:
129
+ self._ensure_connection()
130
+ print(f"Querying SQL and fetching one: {sql}")
131
+ try:
132
+ return self._conn.execute(sql, parameters).fetchone()
133
+ except duckdb.Error as e:
134
+ print(f"Error querying SQL: {e}")
135
+ raise QueryError(f"Error querying SQL: {e}") from e
136
+
137
+ # --- Registration Methods --- (Keep as before)
138
+ def register_df(self, name: str, df: pd.DataFrame):
139
+ self._ensure_connection()
140
+ print(f"Registering DataFrame as '{name}'")
141
+ try:
142
+ self._conn.register(name, df)
143
+ except duckdb.Error as e:
144
+ print(f"Error registering DataFrame: {e}")
145
+ raise QueryError(f"Error registering DataFrame: {e}") from e
146
+
147
+ def unregister_df(self, name: str):
148
+ self._ensure_connection()
149
+ print(f"Unregistering virtual table '{name}'")
150
+ try:
151
+ self._conn.unregister(name)
152
+ except duckdb.Error as e:
153
+ if "not found" in str(e).lower():
154
+ print(f"Warning: Virtual table '{name}' not found for unregistering.")
155
+ else:
156
+ print(f"Error unregistering virtual table: {e}")
157
+ raise QueryError(f"Error unregistering virtual table: {e}") from e
158
+
159
+ # --- Extension Methods --- (Keep as before)
160
+ def install_extension(self, extension_name: str, force_install: bool = False):
161
+ self._ensure_connection()
162
+ print(f"Installing extension: {extension_name}")
163
+ try:
164
+ self._conn.install_extension(extension_name, force_install=force_install)
165
+ except duckdb.Error as e:
166
+ print(f"Error installing extension '{extension_name}': {e}")
167
+ raise DatabaseAPIError(f"Error installing extension '{extension_name}': {e}") from e
168
+
169
+ def load_extension(self, extension_name: str):
170
+ self._ensure_connection()
171
+ print(f"Loading extension: {extension_name}")
172
+ try:
173
+ self._conn.load_extension(extension_name)
174
+ # Catch specific DuckDB errors that indicate failure but aren't API errors
175
+ except (duckdb.IOException, duckdb.CatalogException) as load_err:
176
+ print(f"Error loading extension '{extension_name}': {load_err}")
177
+ raise QueryError(f"Error loading extension '{extension_name}': {load_err}") from load_err
178
+ except duckdb.Error as e: # Catch other DuckDB errors
179
+ print(f"Unexpected DuckDB error loading extension '{extension_name}': {e}")
180
+ raise DatabaseAPIError(f"Unexpected DuckDB error loading extension '{extension_name}': {e}") from e
181
+
182
+ # --- Export Methods ---
183
+ def export_database(self, directory_path: Union[str, Path]):
184
+ self._ensure_connection()
185
+ path_str = str(directory_path)
186
+ if not os.path.isdir(path_str):
187
+ try:
188
+ os.makedirs(path_str)
189
+ print(f"Created export directory: {path_str}")
190
+ except OSError as e:
191
+ raise DatabaseAPIError(f"Could not create export directory '{path_str}': {e}") from e
192
+ print(f"Exporting database to directory: {path_str}")
193
+ sql = f"EXPORT DATABASE '{path_str}' (FORMAT CSV)"
194
+ try:
195
+ self._conn.execute(sql)
196
+ print("Database export completed successfully.")
197
+ except duckdb.Error as e:
198
+ print(f"Error exporting database: {e}")
199
+ raise DatabaseAPIError(f"Error exporting database: {e}") from e
200
+
201
+ def _export_data(self,
202
+ source: str,
203
+ output_path: Union[str, Path],
204
+ file_format: str,
205
+ options: Optional[Dict[str, Any]] = None):
206
+ self._ensure_connection()
207
+ path_str = str(output_path)
208
+ options_str = _format_copy_options(options)
209
+ source_safe = source.strip()
210
+ # --- MODIFIED: Use f-string quoting instead of quote_identifier ---
211
+ if ' ' in source_safe or source_safe.upper().startswith(('SELECT', 'WITH', 'VALUES')):
212
+ copy_source = f"({source})"
213
+ else:
214
+ # Simple quoting, might need refinement for complex identifiers
215
+ copy_source = f'"{source_safe}"'
216
+ # --- END MODIFICATION ---
217
+
218
+ sql = f"COPY {copy_source} TO '{path_str}' {options_str}"
219
+ print(f"Exporting data to {path_str} (Format: {file_format}) with options: {options or {}}")
220
+ try:
221
+ self._conn.execute(sql)
222
+ print("Data export completed successfully.")
223
+ except duckdb.Error as e:
224
+ print(f"Error exporting data: {e}")
225
+ raise QueryError(f"Error exporting data to {file_format}: {e}") from e
226
+
227
+ # --- Keep export_data_to_csv, parquet, json, jsonl as before ---
228
+ def export_data_to_csv(self,
229
+ source: str,
230
+ output_path: Union[str, Path],
231
+ options: Optional[Dict[str, Any]] = None):
232
+ csv_options = options.copy() if options else {}
233
+ csv_options['FORMAT'] = 'CSV'
234
+ if 'HEADER' not in {k.upper() for k in csv_options}:
235
+ csv_options['HEADER'] = True
236
+ self._export_data(source, output_path, "CSV", csv_options)
237
+
238
+ def export_data_to_parquet(self,
239
+ source: str,
240
+ output_path: Union[str, Path],
241
+ options: Optional[Dict[str, Any]] = None):
242
+ parquet_options = options.copy() if options else {}
243
+ parquet_options['FORMAT'] = 'PARQUET'
244
+ self._export_data(source, output_path, "Parquet", parquet_options)
245
+
246
+ def export_data_to_json(self,
247
+ source: str,
248
+ output_path: Union[str, Path],
249
+ array_format: bool = True,
250
+ options: Optional[Dict[str, Any]] = None):
251
+ json_options = options.copy() if options else {}
252
+ json_options['FORMAT'] = 'JSON'
253
+ if 'ARRAY' not in {k.upper() for k in json_options}:
254
+ json_options['ARRAY'] = array_format
255
+ self._export_data(source, output_path, "JSON", json_options)
256
+
257
+ def export_data_to_jsonl(self,
258
+ source: str,
259
+ output_path: Union[str, Path],
260
+ options: Optional[Dict[str, Any]] = None):
261
+ self.export_data_to_json(source, output_path, array_format=False, options=options)
262
+
263
+
264
+ # # --- Streaming Read Methods --- (Keep as before)
265
+ # def stream_query_arrow(self,
266
+ # sql: str,
267
+ # parameters: Optional[List[Any]] = None,
268
+ # batch_size: int = 1000000
269
+ # ) -> Iterator[pa.RecordBatch]:
270
+ # self._ensure_connection()
271
+ # print(f"Streaming Arrow query (batch size {batch_size}): {sql}")
272
+ # try:
273
+ # result_set = self._conn.execute(sql, parameters)
274
+ # while True:
275
+ # batch = result_set.fetch_record_batch(batch_size)
276
+ # if not batch:
277
+ # break
278
+ # yield batch
279
+ # except ImportError:
280
+ # print("PyArrow library is required for Arrow streaming.")
281
+ # raise
282
+ # except duckdb.Error as e:
283
+ # print(f"Error streaming Arrow query: {e}")
284
+ # raise QueryError(f"Error streaming Arrow query: {e}") from e
285
+
286
+ def stream_query_df(self,
287
+ sql: str,
288
+ parameters: Optional[List[Any]] = None,
289
+ vectors_per_chunk: int = 1
290
+ ) -> Iterator[pd.DataFrame]:
291
+ self._ensure_connection()
292
+ print(f"Streaming DataFrame query (vectors per chunk {vectors_per_chunk}): {sql}")
293
+ try:
294
+ result_set = self._conn.execute(sql, parameters)
295
+ while True:
296
+ chunk_df = result_set.fetch_df_chunk(vectors_per_chunk)
297
+ if chunk_df.empty:
298
+ break
299
+ yield chunk_df
300
+ except ImportError:
301
+ print("Pandas library is required for DataFrame streaming.")
302
+ raise
303
+ except duckdb.Error as e:
304
+ print(f"Error streaming DataFrame query: {e}")
305
+ raise QueryError(f"Error streaming DataFrame query: {e}") from e
306
+
307
+ def stream_query_arrow(self,
308
+ sql: str,
309
+ parameters: Optional[List[Any]] = None,
310
+ batch_size: int = 1000000
311
+ ) -> Iterator[pa.RecordBatch]:
312
+ """
313
+ Executes a SQL query and streams the results as Arrow RecordBatches.
314
+ Useful for processing large results iteratively in Python without
315
+ loading the entire result set into memory.
316
+
317
+ Args:
318
+ sql: The SQL query to execute.
319
+ parameters: Optional list of parameters for prepared statements.
320
+ batch_size: The approximate number of rows per Arrow RecordBatch.
321
+
322
+ Yields:
323
+ pyarrow.RecordBatch: Chunks of the result set.
324
+
325
+ Raises:
326
+ QueryError: If the query execution or fetching fails.
327
+ ImportError: If pyarrow is not installed.
328
+ """
329
+ self._ensure_connection()
330
+ print(f"Streaming Arrow query (batch size {batch_size}): {sql}")
331
+ record_batch_reader = None
332
+ try:
333
+ # Use execute() to get a result object that supports streaming fetch
334
+ result_set = self._conn.execute(sql, parameters)
335
+ # --- MODIFICATION: Get the reader first ---
336
+ record_batch_reader = result_set.fetch_record_batch(batch_size)
337
+ # --- Iterate through the reader ---
338
+ for batch in record_batch_reader:
339
+ yield batch
340
+ # --- END MODIFICATION ---
341
+ except ImportError:
342
+ print("PyArrow library is required for Arrow streaming.")
343
+ raise
344
+ except duckdb.Error as e:
345
+ print(f"Error streaming Arrow query: {e}")
346
+ raise QueryError(f"Error streaming Arrow query: {e}") from e
347
+ finally:
348
+ # Clean up the reader if it was created
349
+ if record_batch_reader is not None:
350
+ # PyArrow readers don't have an explicit close, relying on GC.
351
+ # Forcing cleanup might involve ensuring references are dropped.
352
+ del record_batch_reader # Help GC potentially
353
+ # The original result_set from execute() might also hold resources,
354
+ # although fetch_record_batch typically consumes it.
355
+ # Explicitly closing it if possible, or letting it go out of scope.
356
+ if 'result_set' in locals() and result_set:
357
+ try:
358
+ # DuckDBPyResult doesn't have an explicit close, relies on __del__
359
+ del result_set
360
+ except Exception:
361
+ pass # Best effort
362
+
363
+ # --- Resource Management Methods --- (Keep as before)
364
+ def close(self):
365
+ if self._conn:
366
+ conn_id = id(self._conn)
367
+ print(f"Closing connection to '{self._db_path}' (ID: {conn_id})")
368
+ try:
369
+ self._conn.close()
370
+ except duckdb.Error as e:
371
+ print(f"Error closing DuckDB connection (ID: {conn_id}): {e}")
372
+ finally:
373
+ self._conn = None
374
+ else:
375
+ print("Connection already closed or never opened.")
376
+
377
+ def __enter__(self):
378
+ self._ensure_connection()
379
+ return self
380
+
381
+ def __exit__(self, exc_type, exc_value, traceback):
382
+ self.close()
383
+
384
+ def __del__(self):
385
+ if self._conn:
386
+ print(f"ResourceWarning: DatabaseAPI for '{self._db_path}' was not explicitly closed. Closing now in __del__.")
387
+ try:
388
+ self.close()
389
+ except Exception as e:
390
+ print(f"Exception during implicit close in __del__: {e}")
391
+ self._conn = None
392
+
393
+
394
+ # --- Example Usage --- (Keep as before)
395
+ if __name__ == "__main__":
396
+ # ... (rest of the example usage code from previous response) ...
397
+ temp_dir_obj = tempfile.TemporaryDirectory()
398
+ temp_dir = temp_dir_obj.name
399
+ print(f"\n--- Using temporary directory: {temp_dir} ---")
400
+ db_file = Path(temp_dir) / "export_test.db"
401
+ try:
402
+ with DatabaseAPI(db_path=db_file) as db_api:
403
+ db_api.execute_sql("CREATE OR REPLACE TABLE products(id INTEGER, name VARCHAR, price DECIMAL(8,2))")
404
+ db_api.execute_sql("INSERT INTO products VALUES (101, 'Gadget', 19.99), (102, 'Widget', 35.00), (103, 'Thing''amajig', 9.50)")
405
+ db_api.execute_sql("CREATE OR REPLACE TABLE sales(product_id INTEGER, sale_date DATE, quantity INTEGER)")
406
+ db_api.execute_sql("INSERT INTO sales VALUES (101, '2023-10-26', 5), (102, '2023-10-26', 2), (101, '2023-10-27', 3)")
407
+ export_dir = Path(temp_dir) / "exported_db"
408
+ db_api.export_database(export_dir)
409
+ csv_path = Path(temp_dir) / "products_export.csv"
410
+ db_api.export_data_to_csv('products', csv_path, options={'HEADER': True})
411
+ parquet_path = Path(temp_dir) / "high_value_products.parquet"
412
+ db_api.export_data_to_parquet("SELECT * FROM products WHERE price > 20", parquet_path, options={'COMPRESSION': 'SNAPPY'})
413
+ json_path = Path(temp_dir) / "sales.json"
414
+ db_api.export_data_to_json("SELECT * FROM sales", json_path, array_format=True)
415
+ jsonl_path = Path(temp_dir) / "sales.jsonl"
416
+ db_api.export_data_to_jsonl("SELECT * FROM sales ORDER BY sale_date", jsonl_path)
417
+
418
+ with DatabaseAPI() as db_api:
419
+ db_api.execute_sql("CREATE TABLE large_range AS SELECT range AS id, range % 100 AS category FROM range(1000)")
420
+ for batch in db_api.stream_query_arrow("SELECT * FROM large_range", batch_size=200):
421
+ pass
422
+ for df_chunk in db_api.stream_query_df("SELECT * FROM large_range", vectors_per_chunk=1):
423
+ pass
424
+ finally:
425
+ temp_dir_obj.cleanup()
426
+ print(f"\n--- Cleaned up temporary directory: {temp_dir} ---")
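The `WITH (...)` clause that `_export_data` builds from an options dict can be checked in isolation; a small sketch, assuming `database_api.py` above is importable from the working directory:

```python
# Spot-check of the COPY option formatting used by the export methods above.
from database_api import _format_copy_options

options = {"FORMAT": "CSV", "HEADER": True, "DELIMITER": ";"}
# Keys are upper-cased, booleans become TRUE/FALSE, strings are single-quoted
# (with '' escaping), so this prints: WITH (FORMAT 'CSV', HEADER TRUE, DELIMITER ';')
print(_format_copy_options(options))
```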
main.py CHANGED
@@ -1,186 +1,389 @@
1
- import os
2
  import duckdb
3
- from fastapi import FastAPI, HTTPException, Body
4
- from fastapi.responses import FileResponse, JSONResponse
5
- from pydantic import BaseModel, Field
6
  from pathlib import Path
7
- import logging
8
- import time # Import time for potential startup delays
9
- import asyncio
 
10
 
11
- # --- Configuration ---
12
- DB_DIR = Path("data")
13
- DB_FILENAME = "mydatabase.db"
14
- DB_FILE = DB_DIR / DB_FILENAME
15
- UI_EXPECTED_PORT = 8080 # Default port DuckDB UI often tries first
16
 
17
- # Ensure the data directory exists
18
- DB_DIR.mkdir(parents=True, exist_ok=True)
19
 
20
- # --- Logging Setup ---
21
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
22
- logger = logging.getLogger(__name__)
 
 
 
 
23
 
24
- # --- FastAPI App ---
25
- app = FastAPI(
26
- title="DuckDB API & UI Host",
27
- description="Interact with DuckDB via API (/query, /download) and access the official DuckDB Web UI.",
28
- version="1.0.0"
29
- )
 
 
30
 
31
- # --- Pydantic Models ---
32
  class QueryRequest(BaseModel):
33
- sql: str = Field(..., description="The SQL query to execute against DuckDB.")
34
-
35
- class QueryResponse(BaseModel):
36
- columns: list[str] | None = None
37
- rows: list[dict] | None = None
38
- message: str | None = None
39
- error: str | None = None
40
-
41
- # --- Helper Function ---
42
- def execute_duckdb_query(sql_query: str, db_path: str = str(DB_FILE)):
43
- """Connects to DuckDB, executes a query, and returns results or error."""
44
- con = None
45
- try:
46
- logger.info(f"Connecting to database: {db_path}")
47
- con = duckdb.connect(database=db_path, read_only=False)
48
- logger.info(f"Executing SQL: {sql_query[:200]}{'...' if len(sql_query) > 200 else ''}")
49
-
50
- con.begin()
51
- result_relation = con.execute(sql_query)
52
- response_data = {"columns": None, "rows": None, "message": None, "error": None}
53
-
54
- if result_relation.description:
55
- columns = [desc[0] for desc in result_relation.description]
56
- rows_raw = result_relation.fetchall()
57
- rows_dict = [dict(zip(columns, row)) for row in rows_raw]
58
- response_data["columns"] = columns
59
- response_data["rows"] = rows_dict
60
- response_data["message"] = f"Query executed successfully. Fetched {len(rows_dict)} row(s)."
61
- logger.info(f"Query successful, returned {len(rows_dict)} rows.")
62
- else:
63
- response_data["message"] = "Query executed successfully (no data returned)."
64
- logger.info("Query successful (no data returned).")
65
 
66
- con.commit()
67
- return response_data
 
68
 
69
- except duckdb.Error as e:
70
- logger.error(f"DuckDB Error: {e}")
71
- if con: con.rollback()
72
- return {"columns": None, "rows": None, "message": None, "error": str(e)}
73
- except Exception as e:
74
- logger.error(f"General Error: {e}")
75
- if con: con.rollback()
76
- return {"columns": None, "rows": None, "message": None, "error": f"An unexpected error occurred: {e}"}
77
- finally:
78
- if con:
79
- con.close()
80
- logger.info("Database connection closed.")
81
 
 
 
 
 
82
 
83
- # --- FastAPI Startup Event ---
84
  @app.on_event("startup")
85
  async def startup_event():
86
- logger.info("Application startup: Initializing DuckDB UI...")
87
- con = None
88
- try:
89
- # Connect to the main DB file to execute initialization commands
90
- # Use a temporary in-memory DB for UI start if main DB doesn't exist yet?
91
- # No, start_ui seems to need the target DB. Ensure DB file path exists.
92
- if not DB_FILE.parent.exists():
93
- DB_FILE.parent.mkdir(parents=True, exist_ok=True)
94
-
95
- # It's crucial the UI extension can write its state.
96
- # By default it uses ~/.duckdb/ which will be /root/.duckdb in the container.
97
- # Ensure this is writable or mount a volume there.
98
- logger.info(f"Attempting to connect to {DB_FILE} for UI setup.")
99
- con = duckdb.connect(database=str(DB_FILE), read_only=False)
100
-
101
- logger.info("Installing and loading 'ui' extension...")
102
- con.execute("INSTALL ui;")
103
- con.execute("LOAD ui;")
104
-
105
- logger.info("Calling start_ui()... This will start a separate web server.")
106
- # CALL start_ui() starts the server in the background (usually)
107
- # It might print the URL/port it's using to stderr/stdout of the main process
108
- con.execute("CALL start_ui();")
109
-
110
- # Give the UI server a moment to start up. This is a guess.
111
- # A more robust solution might involve checking if the port is listening.
112
- await asyncio.sleep(2)
113
-
114
- logger.info(f"DuckDB UI server startup initiated. It usually listens on port {UI_EXPECTED_PORT}.")
115
- logger.info("Check container logs for the exact URL if it differs.")
116
- logger.info("API server (FastAPI/Uvicorn) is running on port 8000.")
117
-
118
- except duckdb.Error as e:
119
- logger.error(f"CRITICAL: Failed to install/load/start DuckDB UI extension: {e}")
120
- logger.error("The DuckDB UI will likely not be available.")
 
 
 
 
 
 
121
  except Exception as e:
122
- logger.error(f"CRITICAL: An unexpected error occurred during UI startup: {e}")
123
- logger.error("The DuckDB UI will likely not be available.")
124
  finally:
125
- if con:
126
- con.close()
127
- logger.info("UI setup connection closed.")
 
 
 
 
128
 
129
- # --- API Endpoints ---
130
- @app.get("/", summary="Root Endpoint / Info", tags=["General"])
131
- async def read_root():
132
- """Provides links to the API docs and the DuckDB UI."""
133
- # Assumes UI is running on localhost from the container's perspective
134
- # User needs to map the port correctly
135
- return JSONResponse({
136
- "message": "DuckDB API and UI Host",
137
- "api_details": {
138
- "docs": "/docs",
139
- "query_endpoint": "/query (POST)",
140
- "download_endpoint": "/download (GET)"
141
- },
142
- "duckdb_ui": {
143
- "message": f"Access the official DuckDB Web UI. It should be running on port {UI_EXPECTED_PORT} inside the container.",
144
- "typical_access_url": f"http://localhost:{UI_EXPECTED_PORT}",
145
- "notes": f"Ensure you have mapped port {UI_EXPECTED_PORT} from the container when running `docker run` (e.g., -p {UI_EXPECTED_PORT}:{UI_EXPECTED_PORT})."
146
- },
147
- "database_file_container_path": str(DB_FILE)
148
- })
149
-
150
- @app.post("/query", response_model=QueryResponse, summary="Execute SQL Query", tags=["Database API"])
151
- async def execute_query_endpoint(query_request: QueryRequest):
152
- """
153
- Executes a given SQL query against the DuckDB database via the API.
154
- Handles SELECT, INSERT, UPDATE, DELETE, CREATE TABLE, etc.
155
- """
156
- result = execute_duckdb_query(query_request.sql)
157
- if result["error"]:
158
- raise HTTPException(status_code=400, detail=result["error"])
159
- return JSONResponse(content=result)
160
-
161
-
162
- @app.get("/download", summary="Download Database File", tags=["Database API"])
163
- async def download_database_file():
164
- """
165
- Allows downloading the current DuckDB database file via the API.
166
- """
167
- if not DB_FILE.is_file():
168
- logger.error(f"Download request failed: Database file not found at {DB_FILE}")
169
- raise HTTPException(status_code=404, detail="Database file not found.")
170
-
171
- logger.info(f"Serving database file for download: {DB_FILE}")
172
- return FileResponse(
173
- path=str(DB_FILE),
174
- filename=DB_FILENAME,
175
- media_type='application/octet-stream'
176
- )
177
-
178
- # Need asyncio for sleep in startup
179
- # import asyncio
180
-
181
- # --- Run with Uvicorn (for local testing - doesn't handle UI startup well here) ---
182
- # if __name__ == "__main__":
183
- # # Note: Running directly with python main.py won't trigger the startup
184
- # # event correctly in the same way uvicorn command does.
185
- # # Use `uvicorn main:app --reload --port 8000` for local dev testing.
186
- # print("Run using: uvicorn main:app --host 0.0.0.0 --port 8000")
 
 
 
 
1
+ # main.py
2
  import duckdb
3
+ import pandas as pd
4
+ import pyarrow as pa
5
+ import pyarrow.ipc
6
  from pathlib import Path
7
+ import tempfile
8
+ import os
9
+ import shutil
10
+ from typing import Optional, List, Dict, Any, Union, Iterator, Generator, Tuple
11
 
12
+ from fastapi import FastAPI, HTTPException, Body, Query, BackgroundTasks, Depends
13
+ from fastapi.responses import StreamingResponse, FileResponse
14
+ from pydantic import BaseModel, Field
 
 
15
 
16
+ from database_api import DatabaseAPI, DatabaseAPIError, QueryError
 
17
 
18
+ # --- Configuration --- (Keep as before)
19
+ DUCKDB_API_DB_PATH = os.getenv("DUCKDB_API_DB_PATH", "api_database.db")
20
+ DUCKDB_API_READ_ONLY = os.getenv("DUCKDB_API_READ_ONLY", False)
21
+ DUCKDB_API_CONFIG = {}
22
+ TEMP_EXPORT_DIR = Path(tempfile.gettempdir()) / "duckdb_api_exports"
23
+ TEMP_EXPORT_DIR.mkdir(exist_ok=True)
24
+ print(f"Using temporary directory for exports: {TEMP_EXPORT_DIR}")
25
 
26
+ # --- Pydantic Models --- (Keep as before)
27
+ class StatusResponse(BaseModel):
28
+ status: str
29
+ message: Optional[str] = None
30
+
31
+ class ExecuteRequest(BaseModel):
32
+ sql: str
33
+ parameters: Optional[List[Any]] = None
34
 
 
35
  class QueryRequest(BaseModel):
36
+ sql: str
37
+ parameters: Optional[List[Any]] = None
 
 
 
 
 
 
38
 
39
+ class DataFrameResponse(BaseModel):
40
+ columns: List[str]
41
+ records: List[Dict[str, Any]]
42
 
43
+ class InstallRequest(BaseModel):
44
+ extension_name: str
45
+ force_install: bool = False
 
 
 
 
 
 
 
 
 
46
 
47
+ class LoadRequest(BaseModel):
48
+ extension_name: str
49
+
50
+ class ExportDataRequest(BaseModel):
51
+ source: str = Field(..., description="Table name or SQL SELECT query to export")
52
+ options: Optional[Dict[str, Any]] = Field(None, description="Format-specific export options")
53
+
54
+ # --- FastAPI Application --- (Keep as before)
55
+ app = FastAPI(
56
+ title="DuckDB API Wrapper",
57
+ description="Exposes DuckDB functionalities via a RESTful API.",
58
+ version="0.2.1" # Incremented version
59
+ )
60
+
61
+ # --- Global DatabaseAPI Instance & Lifecycle --- (Keep as before)
62
+ db_api_instance: Optional[DatabaseAPI] = None
63
 
 
64
  @app.on_event("startup")
65
  async def startup_event():
66
+ global db_api_instance
67
+ print("Starting up DuckDB API...")
68
+ try:
69
+ db_api_instance = DatabaseAPI(db_path=DUCKDB_API_DB_PATH, read_only=DUCKDB_API_READ_ONLY, config=DUCKDB_API_CONFIG)
70
+ except DatabaseAPIError as e:
71
+ print(f"FATAL: Could not initialize DatabaseAPI on startup: {e}")
72
+ db_api_instance = None
73
+
74
+ @app.on_event("shutdown")
75
+ def shutdown_event():
76
+ print("Shutting down DuckDB API...")
77
+ if db_api_instance:
78
+ db_api_instance.close()
79
+
80
+ # --- Dependency to get the DB API instance --- (Keep as before)
81
+ def get_db_api() -> DatabaseAPI:
82
+ if db_api_instance is None:
83
+ raise HTTPException(status_code=503, detail="Database service is unavailable (failed to initialize).")
84
+ try:
85
+ db_api_instance._ensure_connection()
86
+ return db_api_instance
87
+ except DatabaseAPIError as e:
88
+ raise HTTPException(status_code=503, detail=f"Database service error: {e}")
89
+
90
+ # --- API Endpoints ---
91
+
92
+ # --- CRUD and Querying Endpoints (Keep as before) ---
93
+ @app.post("/execute", response_model=StatusResponse, tags=["CRUD"])
94
+ async def execute_statement(request: ExecuteRequest, api: DatabaseAPI = Depends(get_db_api)):
95
+ try:
96
+ api.execute_sql(request.sql, request.parameters)
97
+ return {"status": "success", "message": None} # Explicitly return None for message
98
+ except QueryError as e:
99
+ raise HTTPException(status_code=400, detail=str(e))
100
+ except DatabaseAPIError as e:
101
+ raise HTTPException(status_code=500, detail=str(e))
102
+
103
+ @app.post("/query/fetchall", response_model=List[tuple], tags=["Querying"])
104
+ async def query_fetchall_endpoint(request: QueryRequest, api: DatabaseAPI = Depends(get_db_api)):
105
+ try:
106
+ return api.query_fetchall(request.sql, request.parameters)
107
+ except QueryError as e:
108
+ raise HTTPException(status_code=400, detail=str(e))
109
+ except DatabaseAPIError as e:
110
+ raise HTTPException(status_code=500, detail=str(e))
111
+
112
+ @app.post("/query/dataframe", response_model=DataFrameResponse, tags=["Querying"])
113
+ async def query_dataframe_endpoint(request: QueryRequest, api: DatabaseAPI = Depends(get_db_api)):
114
+ try:
115
+ df = api.query_df(request.sql, request.parameters)
116
+ df_serializable = df.replace({pd.NA: None, pd.NaT: None, float('nan'): None})
117
+ return {"columns": df_serializable.columns.tolist(), "records": df_serializable.to_dict(orient='records')}
118
+ except (QueryError, ImportError) as e:
119
+ raise HTTPException(status_code=400, detail=str(e))
120
+ except DatabaseAPIError as e:
121
+ raise HTTPException(status_code=500, detail=str(e))
122
+
123
+ # --- Streaming Endpoints ---
124
+
125
+ # --- CORRECTED _stream_arrow_ipc ---
126
+ async def _stream_arrow_ipc(record_batch_iterator: Iterator[pa.RecordBatch]) -> Generator[bytes, None, None]:
127
+ """Helper generator to stream Arrow IPC Stream format."""
128
+ writer = None
129
+ sink = pa.BufferOutputStream() # Create sink once
130
+ try:
131
+ first_batch = next(record_batch_iterator)
132
+ writer = pa.ipc.new_stream(sink, first_batch.schema)
133
+ writer.write_batch(first_batch)
134
+ # Do NOT yield yet, wait for potential subsequent batches or closure
135
+
136
+ for batch in record_batch_iterator:
137
+ # Write subsequent batches to the SAME writer
138
+ writer.write_batch(batch)
139
+
140
+ except StopIteration:
141
+ # Handles the case where the iterator was empty initially
142
+ if writer is None: # No batches were ever processed
143
+ print("Warning: Arrow stream iterator was empty.")
144
+ # Yield empty bytes or handle as needed, depends on client expectation
145
+ # yield b'' # Option 1: empty bytes
146
+ return # Option 2: Just finish generator
147
+
148
  except Exception as e:
149
+ print(f"Error during Arrow streaming generator: {e}")
150
+ # Consider how to signal error downstream if possible
151
  finally:
152
+ if writer:
153
+ try:
154
+ print("Closing Arrow IPC Stream Writer...")
155
+ writer.close() # Close the writer to finalize the stream in the sink
156
+ print("Writer closed.")
157
+ except Exception as close_e:
158
+ print(f"Error closing Arrow writer: {close_e}")
159
+ if sink:
160
+ try:
161
+ buffer = sink.getvalue()
162
+ if buffer:
163
+ print(f"Yielding final Arrow buffer (size: {len(buffer.to_pybytes())})...")
164
+ yield buffer.to_pybytes() # Yield the complete stream buffer
165
+ else:
166
+ print("Arrow sink buffer was empty after closing writer.")
167
+ sink.close()
168
+ except Exception as close_e:
169
+ print(f"Error closing or getting value from Arrow sink: {close_e}")
170
+ # --- END CORRECTION ---
171
 
172
+
173
+ @app.post("/query/stream/arrow", tags=["Streaming"])
174
+ async def query_stream_arrow_endpoint(request: QueryRequest, api: DatabaseAPI = Depends(get_db_api)):
175
+ """Executes a SQL query and streams results as Arrow IPC Stream format."""
176
+ try:
177
+ iterator = api.stream_query_arrow(request.sql, request.parameters)
178
+ return StreamingResponse(
179
+ _stream_arrow_ipc(iterator),
180
+ media_type="application/vnd.apache.arrow.stream"
181
+ )
182
+ except (QueryError, ImportError) as e:
183
+ raise HTTPException(status_code=400, detail=str(e))
184
+ except DatabaseAPIError as e:
185
+ raise HTTPException(status_code=500, detail=str(e))
186
+
187
+ # --- _stream_jsonl (Keep as before) ---
188
+ async def _stream_jsonl(dataframe_iterator: Iterator[pd.DataFrame]) -> Generator[bytes, None, None]:
189
+ try:
190
+ for df_chunk in dataframe_iterator:
191
+ df_serializable = df_chunk.replace({pd.NA: None, pd.NaT: None, float('nan'): None})
192
+ jsonl_string = df_serializable.to_json(orient='records', lines=True, date_format='iso')
193
+ if jsonl_string:
194
+ # pandas>=1.5.0 adds newline by default
195
+ if not jsonl_string.endswith('\n'):
196
+ jsonl_string += '\n'
197
+ yield jsonl_string.encode('utf-8')
198
+ except Exception as e:
199
+ print(f"Error during JSONL streaming generator: {e}")
200
+
201
+ @app.post("/query/stream/jsonl", tags=["Streaming"])
202
+ async def query_stream_jsonl_endpoint(request: QueryRequest, api: DatabaseAPI = Depends(get_db_api)):
203
+ """Executes a SQL query and streams results as JSON Lines (JSONL)."""
204
+ try:
205
+ iterator = api.stream_query_df(request.sql, request.parameters)
206
+ return StreamingResponse(_stream_jsonl(iterator), media_type="application/jsonl")
207
+ except (QueryError, ImportError) as e:
208
+ raise HTTPException(status_code=400, detail=str(e))
209
+ except DatabaseAPIError as e:
210
+ raise HTTPException(status_code=500, detail=str(e))
211
+
212
+
213
+ # --- Download / Export Endpoints (Keep as before, uses corrected _export_data) ---
214
+ def _cleanup_temp_file(path: Union[str, Path]):
215
+ try:
216
+ if Path(path).is_file():
217
+ os.remove(path)
218
+ print(f"Cleaned up temporary file: {path}")
219
+ except OSError as e:
220
+ print(f"Error cleaning up temporary file {path}: {e}")
221
+
222
+ async def _create_temp_export(
223
+ api: DatabaseAPI,
224
+ source: str,
225
+ export_format: str,
226
+ options: Optional[Dict[str, Any]] = None,
227
+ suffix: str = ".tmp"
228
+ ) -> Path:
229
+ fd, temp_path_str = tempfile.mkstemp(suffix=suffix, dir=TEMP_EXPORT_DIR)
230
+ os.close(fd)
231
+ temp_file_path = Path(temp_path_str)
232
+
233
+ try:
234
+ print(f"Exporting to temporary file: {temp_file_path}")
235
+ if export_format == 'csv':
236
+ api.export_data_to_csv(source, temp_file_path, options)
237
+ elif export_format == 'parquet':
238
+ api.export_data_to_parquet(source, temp_file_path, options)
239
+ elif export_format == 'json':
240
+ api.export_data_to_json(source, temp_file_path, array_format=True, options=options)
241
+ elif export_format == 'jsonl':
242
+ api.export_data_to_jsonl(source, temp_file_path, options=options)
243
+ else:
244
+ raise ValueError(f"Unsupported export format: {export_format}")
245
+ return temp_file_path
246
+ except Exception as e:
247
+ _cleanup_temp_file(temp_file_path)
248
+ raise e
249
+
250
+ @app.post("/export/data/csv", response_class=FileResponse, tags=["Export / Download"])
251
+ async def export_csv_endpoint(request: ExportDataRequest, background_tasks: BackgroundTasks, api: DatabaseAPI = Depends(get_db_api)):
252
+ try:
253
+ temp_file_path = await _create_temp_export(api, request.source, 'csv', request.options, suffix=".csv")
254
+ background_tasks.add_task(_cleanup_temp_file, temp_file_path)
255
+ filename = f"export_{Path(request.source).stem if '.' not in request.source else 'query'}.csv"
256
+ return FileResponse(temp_file_path, media_type='text/csv', filename=filename)
257
+ except (QueryError, ValueError) as e:
258
+ raise HTTPException(status_code=400, detail=str(e))
259
+ except DatabaseAPIError as e:
260
+ raise HTTPException(status_code=500, detail=str(e))
261
+ except Exception as e:
262
+ raise HTTPException(status_code=500, detail=f"Unexpected error during CSV export: {e}")
263
+
264
+ @app.post("/export/data/parquet", response_class=FileResponse, tags=["Export / Download"])
265
+ async def export_parquet_endpoint(request: ExportDataRequest, background_tasks: BackgroundTasks, api: DatabaseAPI = Depends(get_db_api)):
266
+ try:
267
+ temp_file_path = await _create_temp_export(api, request.source, 'parquet', request.options, suffix=".parquet")
268
+ background_tasks.add_task(_cleanup_temp_file, temp_file_path)
269
+ filename = f"export_{Path(request.source).stem if '.' not in request.source else 'query'}.parquet"
270
+ return FileResponse(temp_file_path, media_type='application/vnd.apache.parquet', filename=filename)
271
+ except (QueryError, ValueError) as e:
272
+ raise HTTPException(status_code=400, detail=str(e))
273
+ except DatabaseAPIError as e:
274
+ raise HTTPException(status_code=500, detail=str(e))
275
+ except Exception as e:
276
+ raise HTTPException(status_code=500, detail=f"Unexpected error during Parquet export: {e}")
277
+
278
+ @app.post("/export/data/json", response_class=FileResponse, tags=["Export / Download"])
279
+ async def export_json_endpoint(request: ExportDataRequest, background_tasks: BackgroundTasks, api: DatabaseAPI = Depends(get_db_api)):
280
+ try:
281
+ temp_file_path = await _create_temp_export(api, request.source, 'json', request.options, suffix=".json")
282
+ background_tasks.add_task(_cleanup_temp_file, temp_file_path)
283
+ filename = f"export_{Path(request.source).stem if '.' not in request.source else 'query'}.json"
284
+ return FileResponse(temp_file_path, media_type='application/json', filename=filename)
285
+ except (QueryError, ValueError) as e:
286
+ raise HTTPException(status_code=400, detail=str(e))
287
+ except DatabaseAPIError as e:
288
+ raise HTTPException(status_code=500, detail=str(e))
289
+ except Exception as e:
290
+ raise HTTPException(status_code=500, detail=f"Unexpected error during JSON export: {e}")
291
+
292
+ @app.post("/export/data/jsonl", response_class=FileResponse, tags=["Export / Download"])
293
+ async def export_jsonl_endpoint(request: ExportDataRequest, background_tasks: BackgroundTasks, api: DatabaseAPI = Depends(get_db_api)):
294
+ try:
295
+ temp_file_path = await _create_temp_export(api, request.source, 'jsonl', request.options, suffix=".jsonl")
296
+ background_tasks.add_task(_cleanup_temp_file, temp_file_path)
297
+ filename = f"export_{Path(request.source).stem if '.' not in request.source else 'query'}.jsonl"
298
+ return FileResponse(temp_file_path, media_type='application/jsonl', filename=filename)
299
+ except (QueryError, ValueError) as e:
300
+ raise HTTPException(status_code=400, detail=str(e))
301
+ except DatabaseAPIError as e:
302
+ raise HTTPException(status_code=500, detail=str(e))
303
+ except Exception as e:
304
+ raise HTTPException(status_code=500, detail=f"Unexpected error during JSONL export: {e}")
305
+
306
+ @app.post("/export/database", response_class=FileResponse, tags=["Export / Download"])
307
+ async def export_database_endpoint(background_tasks: BackgroundTasks, api: DatabaseAPI = Depends(get_db_api)):
308
+ export_target_dir = Path(tempfile.mkdtemp(dir=TEMP_EXPORT_DIR))
309
+ fd, zip_path_str = tempfile.mkstemp(suffix=".zip", dir=TEMP_EXPORT_DIR)
310
+ os.close(fd)
311
+ zip_file_path = Path(zip_path_str)
312
+ try:
313
+ print(f"Exporting database to temporary directory: {export_target_dir}")
314
+ api.export_database(export_target_dir)
315
+ print(f"Creating zip archive at: {zip_file_path}")
316
+ shutil.make_archive(str(zip_file_path.with_suffix('')), 'zip', str(export_target_dir))
317
+ print(f"Zip archive created: {zip_file_path}")
318
+ background_tasks.add_task(shutil.rmtree, export_target_dir, ignore_errors=True)
319
+ background_tasks.add_task(_cleanup_temp_file, zip_file_path)
320
+ db_name = Path(api._db_path).stem if api._db_path != ':memory:' else 'in_memory_db'
321
+ return FileResponse(zip_file_path, media_type='application/zip', filename=f"{db_name}_export.zip")
322
+ except (QueryError, ValueError, OSError, DatabaseAPIError) as e:
323
+ print(f"Error during database export: {e}")
324
+ shutil.rmtree(export_target_dir, ignore_errors=True)
325
+ _cleanup_temp_file(zip_file_path)
326
+ if isinstance(e, DatabaseAPIError):
327
+ raise HTTPException(status_code=500, detail=str(e))
328
+ else:
329
+ raise HTTPException(status_code=400, detail=str(e))
330
+ except Exception as e:
331
+ print(f"Unexpected error during database export: {e}")
332
+ shutil.rmtree(export_target_dir, ignore_errors=True)
333
+ _cleanup_temp_file(zip_file_path)
334
+ raise HTTPException(status_code=500, detail=f"Unexpected error during database export: {e}")
335
+
336
+ # --- Extension Management Endpoints ---
337
+
338
+ @app.post("/extensions/install", response_model=StatusResponse, tags=["Extensions"])
339
+ async def install_extension_endpoint(request: InstallRequest, api: DatabaseAPI = Depends(get_db_api)):
340
+ try:
341
+ api.install_extension(request.extension_name, request.force_install)
342
+ return {"status": "success", "message": f"Extension '{request.extension_name}' installed."}
343
+ except DatabaseAPIError as e:
344
+ raise HTTPException(status_code=500, detail=str(e))
345
+ # Catch specific DuckDB errors that should be client errors (400)
346
+ except (duckdb.IOException, duckdb.CatalogException, duckdb.InvalidInputException) as e:
347
+ raise HTTPException(status_code=400, detail=f"DuckDB Error during install: {e}")
348
+ except duckdb.Error as e: # Catch other potential DuckDB errors as 500
349
+ raise HTTPException(status_code=500, detail=f"Unexpected DuckDB Error during install: {e}")
350
+
351
+
352
+ @app.post("/extensions/load", response_model=StatusResponse, tags=["Extensions"])
353
+ async def load_extension_endpoint(request: LoadRequest, api: DatabaseAPI = Depends(get_db_api)):
354
+ """Loads an installed DuckDB extension."""
355
+ try:
356
+ api.load_extension(request.extension_name)
357
+ return {"status": "success", "message": f"Extension '{request.extension_name}' loaded."}
358
+ # --- MODIFIED Exception Handling ---
359
+ except QueryError as e: # If api.load_extension raised QueryError (e.g., IO/Catalog)
360
+ raise HTTPException(status_code=400, detail=str(e))
361
+ except DatabaseAPIError as e: # For other API-level issues
362
+ raise HTTPException(status_code=500, detail=str(e))
363
+ # Catch specific DuckDB errors that should be client errors (400)
364
+ except (duckdb.IOException, duckdb.CatalogException) as e:
365
+ raise HTTPException(status_code=400, detail=f"DuckDB Error during load: {e}")
366
+ except duckdb.Error as e: # Catch other potential DuckDB errors as 500
367
+ raise HTTPException(status_code=500, detail=f"Unexpected DuckDB Error during load: {e}")
368
+ # --- END MODIFICATION ---
369
+
370
+ # --- Health Check --- (Keep as before)
371
+ @app.get("/health", response_model=StatusResponse, tags=["Health"])
372
+ async def health_check():
373
+ """Basic health check."""
374
+ try:
375
+ _ = get_db_api()
376
+ return {"status": "ok", "message": None} # Explicitly return None for message
377
+ except HTTPException as e:
378
+ raise e
379
+ except Exception as e:
380
+ raise HTTPException(status_code=500, detail=f"Health check failed unexpectedly: {e}")
381
+
382
+ # --- Run the app --- (Keep as before)
383
+ if __name__ == "__main__":
384
+ import uvicorn
385
+ print(f"Starting DuckDB API server...")
386
+ print(f"Database file configured at: {DUCKDB_API_DB_PATH}")
387
+ print(f"Read-only mode: {DUCKDB_API_READ_ONLY}")
388
+ print(f"Temporary export directory: {TEMP_EXPORT_DIR}")
389
+ uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
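For reference, the endpoints above can be exercised end to end from a small client; a hypothetical sketch (not part of the commit) using `httpx` and `pyarrow`, assuming the server is running on localhost:8000:

```python
# Hypothetical client for the API above; httpx and pyarrow assumed installed.
import httpx
import pyarrow.ipc as pa_ipc

with httpx.Client(base_url="http://localhost:8000") as client:
    client.post("/execute", json={"sql": "CREATE OR REPLACE TABLE t(id INTEGER, name VARCHAR)"})
    client.post("/execute", json={"sql": "INSERT INTO t VALUES (?, ?)", "parameters": [1, "Alice"]})

    # Plain rows via /query/fetchall
    rows = client.post("/query/fetchall", json={"sql": "SELECT * FROM t"}).json()
    print(rows)  # [[1, "Alice"]]

    # Same data as an Arrow IPC stream, rebuilt into a pyarrow Table
    resp = client.post("/query/stream/arrow", json={"sql": "SELECT * FROM t"})
    table = pa_ipc.open_stream(resp.content).read_all()
    print(table.num_rows, table.column_names)  # 1 ['id', 'name']
```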
requirements.txt CHANGED
@@ -1,5 +1,6 @@
  fastapi
  uvicorn[standard]
- duckdb>=1.0.0 # Ensure version compatibility with UI extension
  pydantic
- python-multipart

  fastapi
  uvicorn[standard]
+ duckdb>=1.2.1
  pydantic
+ python-multipart
+ httpx
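The new `httpx` entry is what `fastapi.testclient.TestClient` (used by `test_api.py` below) needs in recent Starlette releases, and it doubles as a convenient client for manual checks. A minimal in-process sketch, assuming `main.py` is importable:

```python
# TestClient runs the FastAPI app in-process (startup/shutdown events included).
from fastapi.testclient import TestClient
import main

with TestClient(main.app) as client:
    print(client.get("/health").json())  # {'status': 'ok', 'message': None}
```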
test_api.py ADDED
@@ -0,0 +1,246 @@
1
+ import pytest
2
+ import os
3
+ import shutil
4
+ import tempfile
5
+ import zipfile
6
+ import json
7
+ from pathlib import Path
8
+ from typing import List, Dict, Any
9
+ from unittest.mock import patch
10
+
11
+ pd = pytest.importorskip("pandas")
12
+ pa = pytest.importorskip("pyarrow")
13
+ pa_ipc = pytest.importorskip("pyarrow.ipc")
14
+
15
+ from fastapi.testclient import TestClient
16
+ import main # Import main to reload and access config
17
+
18
+ # --- Test Fixtures --- (Keep client fixture as before)
19
+ @pytest.fixture(scope="module")
20
+ def client():
21
+ with patch.dict(os.environ, {"DUCKDB_API_DB_PATH": ":memory:"}):
22
+ import importlib
23
+ importlib.reload(main)
24
+ main.TEMP_EXPORT_DIR.mkdir(exist_ok=True)
25
+ print(f"TestClient using temp export dir: {main.TEMP_EXPORT_DIR}")
26
+ with TestClient(main.app) as c:
27
+ yield c
28
+ print(f"Cleaning up test export dir: {main.TEMP_EXPORT_DIR}")
29
+ for item in main.TEMP_EXPORT_DIR.iterdir():
30
+ try:
31
+ if item.is_file():
32
+ os.remove(item)
33
+ elif item.is_dir():
34
+ shutil.rmtree(item)
35
+ except Exception as e:
36
+ print(f"Error cleaning up {item}: {e}")
37
+
38
+ # --- Test Classes ---
39
+
40
+ class TestHealth: # (Keep as before)
41
+ def test_health_check(self, client: TestClient):
42
+ response = client.get("/health")
43
+ assert response.status_code == 200
44
+ assert response.json() == {"status": "ok", "message": None}
45
+
46
+ class TestExecution: # (Keep as before)
47
+ def test_execute_create(self, client: TestClient):
48
+ response = client.post("/execute", json={"sql": "CREATE TABLE test_table(id INTEGER, name VARCHAR);"})
49
+ assert response.status_code == 200
50
+ assert response.json() == {"status": "success", "message": None}
51
+ response_fail = client.post("/execute", json={"sql": "CREATE TABLE test_table(id INTEGER);"})
52
+ assert response_fail.status_code == 400
53
+
54
+ def test_execute_insert(self, client: TestClient):
55
+ client.post("/execute", json={"sql": "CREATE OR REPLACE TABLE test_table(id INTEGER, name VARCHAR);"})
56
+ response = client.post("/execute", json={"sql": "INSERT INTO test_table VALUES (1, 'Alice')"})
57
+ assert response.status_code == 200
58
+ query_response = client.post("/query/fetchall", json={"sql": "SELECT COUNT(*) FROM test_table"})
59
+ assert query_response.status_code == 200
60
+ assert query_response.json() == [[1]]
61
+
62
+ def test_execute_insert_params(self, client: TestClient):
63
+ client.post("/execute", json={"sql": "CREATE OR REPLACE TABLE test_table(id INTEGER, name VARCHAR);"})
64
+ response = client.post("/execute", json={"sql": "INSERT INTO test_table VALUES (?, ?)", "parameters": [2, "Bob"]})
65
+ assert response.status_code == 200
66
+ query_response = client.post("/query/fetchall", json={"sql": "SELECT * FROM test_table WHERE id = 2"})
67
+ assert query_response.status_code == 200
68
+ assert query_response.json() == [[2, "Bob"]]
69
+
70
+ def test_execute_invalid_sql(self, client: TestClient):
71
+ response = client.post("/execute", json={"sql": "INVALID SQL STATEMENT"})
72
+ assert response.status_code == 400
73
+ assert "Parser Error" in response.json()["detail"]
74
+
75
+ class TestQuerying:  # (Keep as before)
+     @pytest.fixture(scope="class", autouse=True)
+     def setup_data(self, client: TestClient):
+         client.post("/execute", json={"sql": "CREATE OR REPLACE TABLE query_test(id INTEGER, val VARCHAR)"})
+         client.post("/execute", json={"sql": "INSERT INTO query_test VALUES (1, 'one'), (2, 'two'), (3, 'three')"})
+
+     def test_query_fetchall(self, client: TestClient):
+         response = client.post("/query/fetchall", json={"sql": "SELECT * FROM query_test ORDER BY id"})
+         assert response.status_code == 200
+         assert response.json() == [[1, 'one'], [2, 'two'], [3, 'three']]
+
+     def test_query_fetchall_params(self, client: TestClient):
+         response = client.post("/query/fetchall", json={"sql": "SELECT * FROM query_test WHERE id > ? ORDER BY id", "parameters": [1]})
+         assert response.status_code == 200
+         assert response.json() == [[2, 'two'], [3, 'three']]
+
+     def test_query_fetchall_empty(self, client: TestClient):
+         response = client.post("/query/fetchall", json={"sql": "SELECT * FROM query_test WHERE id > 100"})
+         assert response.status_code == 200
+         assert response.json() == []
+
+     def test_query_dataframe(self, client: TestClient):
+         response = client.post("/query/dataframe", json={"sql": "SELECT * FROM query_test ORDER BY id"})
+         assert response.status_code == 200
+         data = response.json()
+         assert data["columns"] == ["id", "val"]
+         assert data["records"] == [
+             {"id": 1, "val": "one"},
+             {"id": 2, "val": "two"},
+             {"id": 3, "val": "three"}
+         ]
+
+     def test_query_dataframe_invalid_sql(self, client: TestClient):
+         response = client.post("/query/dataframe", json={"sql": "SELECT non_existent FROM query_test"})
+         assert response.status_code == 400
+         assert "Binder Error" in response.json()["detail"]
+
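The /query/dataframe payload asserted above ({"columns": [...], "records": [...]}) maps directly onto pandas. A minimal, hypothetical consumer sketch follows; the base URL is an assumption and pandas is assumed to be installed on the client side.

# Hypothetical consumer of /query/dataframe; the base URL is an assumption.
import pandas as pd
import requests

resp = requests.post(
    "http://localhost:8000/query/dataframe",
    json={"sql": "SELECT * FROM query_test ORDER BY id"},
)
resp.raise_for_status()
payload = resp.json()

# Rebuild the frame from the column list and the list-of-dicts records.
df = pd.DataFrame(payload["records"], columns=payload["columns"])
print(df)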
+ class TestStreaming:  # (Keep as before)
+     @pytest.fixture(scope="class", autouse=True)
+     def setup_data(self, client: TestClient):
+         client.post("/execute", json={"sql": "CREATE OR REPLACE TABLE stream_test AS SELECT range AS id, range % 5 AS category FROM range(10)"})
+
+     def test_stream_arrow(self, client: TestClient):
+         response = client.post("/query/stream/arrow", json={"sql": "SELECT * FROM stream_test"})
+         assert response.status_code == 200
+         assert response.headers["content-type"] == "application/vnd.apache.arrow.stream"
+         if not response.content:
+             pytest.fail("Arrow stream response content is empty")
+         try:
+             reader = pa_ipc.open_stream(response.content)
+             table = reader.read_all()
+         except pa.ArrowInvalid as e:
+             pytest.fail(f"Failed to read Arrow stream: {e}")
+         assert table.num_rows == 10
+         assert table.column_names == ["id", "category"]
+         assert table.column('id').to_pylist() == list(range(10))
+
+     def test_stream_arrow_empty(self, client: TestClient):
+         response = client.post("/query/stream/arrow", json={"sql": "SELECT * FROM stream_test WHERE id < 0"})
+         assert response.status_code == 200
+         assert response.headers["content-type"] == "application/vnd.apache.arrow.stream"
+         try:
+             reader = pa_ipc.open_stream(response.content)
+             table = reader.read_all()
+             assert table.num_rows == 0
+         except pa.ArrowInvalid as e:
+             # A completely empty body that is not a valid Arrow stream is also acceptable
+             print(f"Received ArrowInvalid for empty stream, which is acceptable: {e}")
+             assert response.content == b''
+
+     def test_stream_jsonl(self, client: TestClient):
+         response = client.post("/query/stream/jsonl", json={"sql": "SELECT * FROM stream_test ORDER BY id"})
+         assert response.status_code == 200
+         assert response.headers["content-type"] == "application/jsonl"
+         lines = response.text.strip().split('\n')
+         records = [json.loads(line) for line in lines if line]
+         assert len(records) == 10
+         assert records[0] == {"id": 0, "category": 0}
+         assert records[9] == {"id": 9, "category": 4}
+
+     def test_stream_jsonl_empty(self, client: TestClient):
+         response = client.post("/query/stream/jsonl", json={"sql": "SELECT * FROM stream_test WHERE id < 0"})
+         assert response.status_code == 200
+         assert response.headers["content-type"] == "application/jsonl"
+         assert response.text.strip() == ""
+
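The test above buffers the whole JSONL body before splitting it, but a streaming endpoint is presumably meant to be consumed incrementally. A hedged client-side sketch, assuming the API is served at http://localhost:8000 (an assumption, not something the tests depend on):

# Hypothetical streaming consumer of /query/stream/jsonl; base URL assumed.
import json
import requests

with requests.post(
    "http://localhost:8000/query/stream/jsonl",
    json={"sql": "SELECT * FROM stream_test ORDER BY id"},
    stream=True,
) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():
        if line:  # skip any blank lines
            record = json.loads(line)
            print(record["id"], record["category"])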
+ class TestExportDownload:  # (Keep setup_data as before)
+     @pytest.fixture(scope="class", autouse=True)
+     def setup_data(self, client: TestClient):
+         client.post("/execute", json={"sql": "CREATE OR REPLACE TABLE export_table(id INTEGER, name VARCHAR, price DECIMAL(5,2))"})
+         client.post("/execute", json={"sql": "INSERT INTO export_table VALUES (1, 'Apple', 0.50), (2, 'Banana', 0.30), (3, 'Orange', 0.75)"})
+
+     @pytest.mark.parametrize(
+         "endpoint_suffix, expected_content_type, expected_filename_ext, validation_fn",
+         [
+             ("csv", "text/csv", ".csv", lambda c: b"id,name,price\n1,Apple,0.50\n" in c),
+             ("parquet", "application/vnd.apache.parquet", ".parquet", lambda c: c.startswith(b"PAR1")),
+             # --- MODIFIED JSON/JSONL Lambdas ---
+             ("json", "application/json", ".json", lambda c: c.strip().startswith(b'[') and c.strip().endswith(b']')),
+             ("jsonl", "application/jsonl", ".jsonl", lambda c: b'"id":1' in c and b'"name":"Apple"' in c and b'\n' in c),
+             # --- END MODIFICATION ---
+         ]
+     )
+     def test_export_data(self, client: TestClient, endpoint_suffix, expected_content_type, expected_filename_ext, validation_fn, tmp_path):
+         endpoint = f"/export/data/{endpoint_suffix}"
+         payload = {"source": "export_table"}
+         if endpoint_suffix == 'csv':
+             payload['options'] = {'HEADER': True}
+
+         response = client.post(endpoint, json=payload)
+
+         assert response.status_code == 200, f"Request to {endpoint} failed: {response.text}"
+         assert response.headers["content-type"].startswith(expected_content_type)
+         assert "content-disposition" in response.headers
+         assert f'filename="export_export_table{expected_filename_ext}"' in response.headers["content-disposition"]
+
+         downloaded_path = tmp_path / f"downloaded{expected_filename_ext}"
+         with open(downloaded_path, "wb") as f:
+             f.write(response.content)
+         assert downloaded_path.exists()
+         assert validation_fn(response.content), f"Validation failed for {endpoint_suffix}"
+
+         # Test with a query source
+         payload = {"source": "SELECT id, name FROM export_table WHERE price > 0.40 ORDER BY id"}
+         response = client.post(endpoint, json=payload)
+         assert response.status_code == 200
+         assert f'filename="export_query{expected_filename_ext}"' in response.headers["content-disposition"]
+         assert len(response.content) > 0
+
+     # --- Keep test_export_database as before ---
+     def test_export_database(self, client: TestClient, tmp_path):
+         client.post("/execute", json={"sql": "CREATE TABLE IF NOT EXISTS another_table(x int)"})
+         response = client.post("/export/database")
+         assert response.status_code == 200
+         assert response.headers["content-type"] == "application/zip"
+         assert "content-disposition" in response.headers
+         assert response.headers["content-disposition"].startswith("attachment; filename=")
+         assert 'filename="in_memory_db_export.zip"' in response.headers["content-disposition"]
+         zip_path = tmp_path / "db_export.zip"
+         with open(zip_path, "wb") as f:
+             f.write(response.content)
+         assert zip_path.exists()
+         with zipfile.ZipFile(zip_path, 'r') as z:
+             print(f"Zip contents: {z.namelist()}")
+             assert "schema.sql" in z.namelist()
+             assert "load.sql" in z.namelist()
+             assert any(name.startswith("export_table") for name in z.namelist())
+             assert any(name.startswith("another_table") for name in z.namelist())
+
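The export endpoints return complete files, so a client only needs to persist the response body. A hypothetical download-and-read-back sketch for the Parquet case; the base URL and output path are assumptions, and pandas with a Parquet engine (e.g. pyarrow) is assumed on the client side.

# Hypothetical client for /export/data/parquet; URL and output path assumed.
import pandas as pd
import requests

resp = requests.post(
    "http://localhost:8000/export/data/parquet",
    json={"source": "export_table"},
)
resp.raise_for_status()

with open("export_table.parquet", "wb") as f:
    f.write(resp.content)

# Requires a Parquet engine such as pyarrow on the client side.
df = pd.read_parquet("export_table.parquet")
print(df.head())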
+ class TestExtensions:  # (Keep as before)
+     def test_install_extension_fail(self, client: TestClient):
+         response = client.post("/extensions/install", json={"extension_name": "nonexistent_dummy_ext"})
+         assert response.status_code >= 400
+         assert "Error during install" in response.json()["detail"] or "Failed to download" in response.json()["detail"]
+
+     def test_load_extension_fail(self, client: TestClient):
+         response = client.post("/extensions/load", json={"extension_name": "nonexistent_dummy_ext"})
+         assert response.status_code == 400
+         # --- MODIFIED Assertion ---
+         assert "Error loading extension" in response.json()["detail"]
+         # --- END MODIFICATION ---
+         assert "not found" in response.json()["detail"].lower()
+
+     @pytest.mark.skip(reason="Requires httpfs extension to be available for install/load")
+     def test_install_and_load_httpfs(self, client: TestClient):
+         install_response = client.post("/extensions/install", json={"extension_name": "httpfs"})
+         assert install_response.status_code == 200
+         assert install_response.json()["status"] == "success"
+
+         load_response = client.post("/extensions/load", json={"extension_name": "httpfs"})
+         assert load_response.status_code == 200
+         assert load_response.json()["status"] == "success"