amaye15 committed
Commit 5dc86cf · 1 Parent(s): b138bfa
Files changed (7)
  1. .dockerignore +6 -36
  2. Dockerfile +22 -17
  3. README.md +51 -9
  4. database_api.py +0 -426
  5. main.py +350 -357
  6. requirements.txt +5 -10
  7. test_api.py +0 -246
.dockerignore CHANGED
@@ -1,45 +1,15 @@
-# .dockerignore
 __pycache__/
 *.pyc
 *.pyo
 *.pyd
 .Python
 env/
-.env
-.venv/
 venv/
-ENV/
-env.bak/
-venv.bak/
-
-.pytest_cache/
-.mypy_cache/
-.nox/
-.tox/
-.coverage
-.coverage.*
-coverage.xml
-htmlcov/
-.hypothesis/
-
+.env
+.git
+.gitignore
 *.db
 *.db.wal
-*.log
-*.sqlite
-*.sqlite3
-
-# Ignore specific generated files if needed
-api_database.db
-api_database.db.wal
-my_duckdb_api_db.db
-my_duckdb_api_db.db.wal
-exported_db/
-duckdb_api_exports/ # Don't copy local temp exports
-
-# OS-specific files
-.DS_Store
-Thumbs.db
-
-# IDE files
-.idea/
-.vscode/
+data/*.db
+data/*.db.wal
+# Add other files/directories to ignore if needed
Dockerfile CHANGED
@@ -1,33 +1,38 @@
-# Dockerfile
-
+# Use an official Python runtime as a parent image
 FROM python:3.10-slim
 
+# Set environment variables
 ENV PYTHONDONTWRITEBYTECODE 1
 ENV PYTHONUNBUFFERED 1
+# Set the DuckDB path inside the container
+ENV DUCKDB_PATH /app/data/mydatabase.db
 
 # Create a non-root user and group
-ARG UID=1000
-ARG GID=1000
-RUN groupadd -g ${GID} --system appgroup && useradd -u ${UID} -g appgroup --system appuser
+RUN adduser --disabled-password --gecos "" appuser
 
+# Set the working directory in the container
 WORKDIR /app
 
-# Create data directory and set permissions
-RUN mkdir /app/data && chown appuser:appgroup /app/data
+# Copy the requirements file into the container at /app
+COPY requirements.txt /app/
 
-# Copy requirements and install as root first (some packages might need it)
-COPY requirements.txt .
-RUN pip install --no-cache-dir --upgrade pip && \
-    pip install --no-cache-dir -r requirements.txt
+# Install any needed packages specified in requirements.txt
+# Use --no-cache-dir to reduce image size
+RUN pip install --no-cache-dir -r requirements.txt
 
-# Copy application code and set permissions
-COPY . .
-RUN chown -R appuser:appgroup /app
+# Copy the current directory contents into the container at /app
+COPY . /app/
+
+# Create the data directory and set permissions
+# Run these steps as root before switching user
+RUN mkdir -p /app/data && chown -R appuser:appuser /app
 
 # Switch to the non-root user
 USER appuser
 
-EXPOSE 8000
+# Make port 7860 available to the world outside this container (Hugging Face default)
+EXPOSE 7860
 
-# Run uvicorn as the non-root user
-CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
+# Run main.py when the container launches using Uvicorn
+# Use 0.0.0.0 to make it accessible externally
+CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
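
The revised Dockerfile and the new `main.py` coordinate through the `DUCKDB_PATH` environment variable. A minimal sketch of that contract, mirroring the lookup the new `main.py` performs (the relative fallback path is the one the app uses outside the container):

```python
import os

# Inside the container the Dockerfile sets DUCKDB_PATH=/app/data/mydatabase.db;
# outside it, fall back to a relative default, as main.py does.
DATABASE_PATH = os.environ.get("DUCKDB_PATH", "data/mydatabase.db")

# Make sure the parent directory exists before DuckDB opens the file.
os.makedirs(os.path.dirname(DATABASE_PATH), exist_ok=True)
print(f"DuckDB database file: {DATABASE_PATH}")
```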
README.md CHANGED
@@ -1,13 +1,55 @@
 ---
-title: DuckDB UI
-emoji:
-colorFrom: indigo
-colorTo: yellow
+title: DuckDB FastAPI API
+emoji: 🦆
+colorFrom: blue
+colorTo: green
 sdk: docker
-pinned: false
-license: mit
-short_description: DuckDB Hosting with UI & FastAPI 4 SQL Calls & DB Downloads
-port: 8000
+app_port: 7860
+# Optional: specify Python version for clarity, though the Dockerfile defines it
+# python_version: 3.10
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# DuckDB FastAPI API
+
+This Space provides a simple API built with FastAPI to interact with a DuckDB database.
+
+**Features:**
+
+* Create tables
+* Read table data (with limit/offset)
+* Insert rows into tables
+* Update rows based on a condition
+* Delete rows based on a condition
+* Download a table as CSV
+* Download the entire database file
+* Health check endpoint
+
+**API Documentation:**
+
+The API documentation (powered by Swagger UI) is available at the `/docs` endpoint of your Space URL.
+
+**Example Usage (using curl):**
+
+```bash
+# Health Check
+curl https://[your-space-subdomain].hf.space/health
+
+# Create a table
+curl -X POST "https://[your-space-subdomain].hf.space/tables/my_data" \
+     -H "Content-Type: application/json" \
+     -d '{"columns": [{"name": "id", "type": "INTEGER"}, {"name": "value", "type": "VARCHAR"}]}'
+
+# Insert rows
+curl -X POST "https://[your-space-subdomain].hf.space/tables/my_data/rows" \
+     -H "Content-Type: application/json" \
+     -d '{"rows": [{"id": 1, "value": "apple"}, {"id": 2, "value": "banana"}]}'
+
+# Read table data
+curl https://[your-space-subdomain].hf.space/tables/my_data
+
+# Download table as CSV
+curl -o my_data.csv https://[your-space-subdomain].hf.space/download/table/my_data
+
+# Download database file
+curl -o downloaded_db.db https://[your-space-subdomain].hf.space/download/database
+```
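
The same calls work from Python; a minimal sketch of the documented endpoints using the third-party `requests` package, with a placeholder Space URL:

```python
import requests

# Placeholder: substitute your own Space subdomain.
BASE = "https://your-space-subdomain.hf.space"

# Create a table, insert rows, then read them back (mirrors the curl examples).
requests.post(f"{BASE}/tables/my_data", json={
    "columns": [{"name": "id", "type": "INTEGER"},
                {"name": "value", "type": "VARCHAR"}],
}).raise_for_status()

requests.post(f"{BASE}/tables/my_data/rows", json={
    "rows": [{"id": 1, "value": "apple"}, {"id": 2, "value": "banana"}],
}).raise_for_status()

print(requests.get(f"{BASE}/tables/my_data", params={"limit": 10}).json())
```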
database_api.py DELETED
@@ -1,426 +0,0 @@
-# database_api.py
-import duckdb
-import pandas as pd
-import pyarrow as pa
-import pyarrow.ipc
-from pathlib import Path
-import tempfile
-import os
-import shutil
-from typing import Optional, List, Dict, Any, Union, Iterator, Generator, Tuple
-# No need for pybind11 import here anymore
-
-# --- Custom Exceptions ---
-class DatabaseAPIError(Exception):
-    """Base exception for our custom API."""
-    pass
-
-class QueryError(DatabaseAPIError):
-    """Exception raised for errors during query execution."""
-    pass
-
-# --- Helper function to format COPY options ---
-def _format_copy_options(options: Optional[Dict[str, Any]]) -> str:
-    if not options:
-        return ""
-    opts_parts = []
-    for k, v in options.items():
-        key_upper = k.upper()
-        if isinstance(v, bool):
-            value_repr = str(v).upper()
-        elif isinstance(v, (int, float)):
-            value_repr = str(v)
-        elif isinstance(v, str):
-            escaped_v = v.replace("'", "''")
-            value_repr = f"'{escaped_v}'"
-        else:
-            value_repr = repr(v)
-        opts_parts.append(f"{key_upper} {value_repr}")
-
-    opts_str = ", ".join(opts_parts)
-    return f"WITH ({opts_str})"
-
-# --- Main DatabaseAPI Class ---
-class DatabaseAPI:
-    def __init__(self,
-                 db_path: Union[str, Path] = ":memory:",
-                 read_only: bool = False,
-                 config: Optional[Dict[str, str]] = None):
-        self._db_path = str(db_path)
-        self._config = config or {}
-        self._read_only = read_only
-        self._conn: Optional[duckdb.DuckDBPyConnection] = None
-        try:
-            self._conn = duckdb.connect(
-                database=self._db_path,
-                read_only=self._read_only,
-                config=self._config
-            )
-            print(f"Connected to DuckDB database at '{self._db_path}'")
-        except duckdb.Error as e:
-            print(f"Failed to connect to DuckDB: {e}")
-            raise DatabaseAPIError(f"Failed to connect to DuckDB: {e}") from e
-
-    def _ensure_connection(self):
-        if self._conn is None:
-            raise DatabaseAPIError("Database connection is not established or has been closed.")
-        try:
-            self._conn.execute("SELECT 1", [])
-        except (duckdb.ConnectionException, RuntimeError) as e:
-            if "Connection has already been closed" in str(e) or "connection closed" in str(e).lower():
-                self._conn = None
-                raise DatabaseAPIError("Database connection is closed.") from e
-            else:
-                raise DatabaseAPIError(f"Database connection error: {e}") from e
-
-    # --- Basic Query Methods --- (Keep as before)
-    def execute_sql(self, sql: str, parameters: Optional[List[Any]] = None) -> None:
-        self._ensure_connection()
-        print(f"Executing SQL: {sql}")
-        try:
-            self._conn.execute(sql, parameters)
-        except duckdb.Error as e:
-            print(f"Error executing SQL: {e}")
-            raise QueryError(f"Error executing SQL: {e}") from e
-
-    def query_sql(self, sql: str, parameters: Optional[List[Any]] = None) -> duckdb.DuckDBPyRelation:
-        self._ensure_connection()
-        print(f"Querying SQL: {sql}")
-        try:
-            return self._conn.sql(sql, params=parameters)
-        except duckdb.Error as e:
-            print(f"Error querying SQL: {e}")
-            raise QueryError(f"Error querying SQL: {e}") from e
-
-    def query_df(self, sql: str, parameters: Optional[List[Any]] = None) -> pd.DataFrame:
-        self._ensure_connection()
-        print(f"Querying SQL to DataFrame: {sql}")
-        try:
-            return self._conn.execute(sql, parameters).df()
-        except ImportError:
-            print("Pandas library is required for DataFrame operations.")
-            raise
-        except duckdb.Error as e:
-            print(f"Error querying SQL to DataFrame: {e}")
-            raise QueryError(f"Error querying SQL to DataFrame: {e}") from e
-
-    def query_arrow(self, sql: str, parameters: Optional[List[Any]] = None) -> pa.Table:
-        self._ensure_connection()
-        print(f"Querying SQL to Arrow Table: {sql}")
-        try:
-            return self._conn.execute(sql, parameters).arrow()
-        except ImportError:
-            print("PyArrow library is required for Arrow operations.")
-            raise
-        except duckdb.Error as e:
-            print(f"Error querying SQL to Arrow Table: {e}")
-            raise QueryError(f"Error querying SQL to Arrow Table: {e}") from e
-
-    def query_fetchall(self, sql: str, parameters: Optional[List[Any]] = None) -> List[Tuple[Any, ...]]:
-        self._ensure_connection()
-        print(f"Querying SQL and fetching all: {sql}")
-        try:
-            return self._conn.execute(sql, parameters).fetchall()
-        except duckdb.Error as e:
-            print(f"Error querying SQL: {e}")
-            raise QueryError(f"Error querying SQL: {e}") from e
-
-    def query_fetchone(self, sql: str, parameters: Optional[List[Any]] = None) -> Optional[Tuple[Any, ...]]:
-        self._ensure_connection()
-        print(f"Querying SQL and fetching one: {sql}")
-        try:
-            return self._conn.execute(sql, parameters).fetchone()
-        except duckdb.Error as e:
-            print(f"Error querying SQL: {e}")
-            raise QueryError(f"Error querying SQL: {e}") from e
-
-    # --- Registration Methods --- (Keep as before)
-    def register_df(self, name: str, df: pd.DataFrame):
-        self._ensure_connection()
-        print(f"Registering DataFrame as '{name}'")
-        try:
-            self._conn.register(name, df)
-        except duckdb.Error as e:
-            print(f"Error registering DataFrame: {e}")
-            raise QueryError(f"Error registering DataFrame: {e}") from e
-
-    def unregister_df(self, name: str):
-        self._ensure_connection()
-        print(f"Unregistering virtual table '{name}'")
-        try:
-            self._conn.unregister(name)
-        except duckdb.Error as e:
-            if "not found" in str(e).lower():
-                print(f"Warning: Virtual table '{name}' not found for unregistering.")
-            else:
-                print(f"Error unregistering virtual table: {e}")
-                raise QueryError(f"Error unregistering virtual table: {e}") from e
-
-    # --- Extension Methods --- (Keep as before)
-    def install_extension(self, extension_name: str, force_install: bool = False):
-        self._ensure_connection()
-        print(f"Installing extension: {extension_name}")
-        try:
-            self._conn.install_extension(extension_name, force_install=force_install)
-        except duckdb.Error as e:
-            print(f"Error installing extension '{extension_name}': {e}")
-            raise DatabaseAPIError(f"Error installing extension '{extension_name}': {e}") from e
-
-    def load_extension(self, extension_name: str):
-        self._ensure_connection()
-        print(f"Loading extension: {extension_name}")
-        try:
-            self._conn.load_extension(extension_name)
-        # Catch specific DuckDB errors that indicate failure but aren't API errors
-        except (duckdb.IOException, duckdb.CatalogException) as load_err:
-            print(f"Error loading extension '{extension_name}': {load_err}")
-            raise QueryError(f"Error loading extension '{extension_name}': {load_err}") from load_err
-        except duckdb.Error as e:  # Catch other DuckDB errors
-            print(f"Unexpected DuckDB error loading extension '{extension_name}': {e}")
-            raise DatabaseAPIError(f"Unexpected DuckDB error loading extension '{extension_name}': {e}") from e
-
-    # --- Export Methods ---
-    def export_database(self, directory_path: Union[str, Path]):
-        self._ensure_connection()
-        path_str = str(directory_path)
-        if not os.path.isdir(path_str):
-            try:
-                os.makedirs(path_str)
-                print(f"Created export directory: {path_str}")
-            except OSError as e:
-                raise DatabaseAPIError(f"Could not create export directory '{path_str}': {e}") from e
-        print(f"Exporting database to directory: {path_str}")
-        sql = f"EXPORT DATABASE '{path_str}' (FORMAT CSV)"
-        try:
-            self._conn.execute(sql)
-            print("Database export completed successfully.")
-        except duckdb.Error as e:
-            print(f"Error exporting database: {e}")
-            raise DatabaseAPIError(f"Error exporting database: {e}") from e
-
-    def _export_data(self,
-                     source: str,
-                     output_path: Union[str, Path],
-                     file_format: str,
-                     options: Optional[Dict[str, Any]] = None):
-        self._ensure_connection()
-        path_str = str(output_path)
-        options_str = _format_copy_options(options)
-        source_safe = source.strip()
-        # --- MODIFIED: Use f-string quoting instead of quote_identifier ---
-        if ' ' in source_safe or source_safe.upper().startswith(('SELECT', 'WITH', 'VALUES')):
-            copy_source = f"({source})"
-        else:
-            # Simple quoting, might need refinement for complex identifiers
-            copy_source = f'"{source_safe}"'
-        # --- END MODIFICATION ---
-
-        sql = f"COPY {copy_source} TO '{path_str}' {options_str}"
-        print(f"Exporting data to {path_str} (Format: {file_format}) with options: {options or {}}")
-        try:
-            self._conn.execute(sql)
-            print("Data export completed successfully.")
-        except duckdb.Error as e:
-            print(f"Error exporting data: {e}")
-            raise QueryError(f"Error exporting data to {file_format}: {e}") from e
-
-    # --- Keep export_data_to_csv, parquet, json, jsonl as before ---
-    def export_data_to_csv(self,
-                           source: str,
-                           output_path: Union[str, Path],
-                           options: Optional[Dict[str, Any]] = None):
-        csv_options = options.copy() if options else {}
-        csv_options['FORMAT'] = 'CSV'
-        if 'HEADER' not in {k.upper() for k in csv_options}:
-            csv_options['HEADER'] = True
-        self._export_data(source, output_path, "CSV", csv_options)
-
-    def export_data_to_parquet(self,
-                               source: str,
-                               output_path: Union[str, Path],
-                               options: Optional[Dict[str, Any]] = None):
-        parquet_options = options.copy() if options else {}
-        parquet_options['FORMAT'] = 'PARQUET'
-        self._export_data(source, output_path, "Parquet", parquet_options)
-
-    def export_data_to_json(self,
-                            source: str,
-                            output_path: Union[str, Path],
-                            array_format: bool = True,
-                            options: Optional[Dict[str, Any]] = None):
-        json_options = options.copy() if options else {}
-        json_options['FORMAT'] = 'JSON'
-        if 'ARRAY' not in {k.upper() for k in json_options}:
-            json_options['ARRAY'] = array_format
-        self._export_data(source, output_path, "JSON", json_options)
-
-    def export_data_to_jsonl(self,
-                             source: str,
-                             output_path: Union[str, Path],
-                             options: Optional[Dict[str, Any]] = None):
-        self.export_data_to_json(source, output_path, array_format=False, options=options)
-
-
-    # # --- Streaming Read Methods --- (Keep as before)
-    # def stream_query_arrow(self,
-    #                        sql: str,
-    #                        parameters: Optional[List[Any]] = None,
-    #                        batch_size: int = 1000000
-    #                        ) -> Iterator[pa.RecordBatch]:
-    #     self._ensure_connection()
-    #     print(f"Streaming Arrow query (batch size {batch_size}): {sql}")
-    #     try:
-    #         result_set = self._conn.execute(sql, parameters)
-    #         while True:
-    #             batch = result_set.fetch_record_batch(batch_size)
-    #             if not batch:
-    #                 break
-    #             yield batch
-    #     except ImportError:
-    #         print("PyArrow library is required for Arrow streaming.")
-    #         raise
-    #     except duckdb.Error as e:
-    #         print(f"Error streaming Arrow query: {e}")
-    #         raise QueryError(f"Error streaming Arrow query: {e}") from e
-
-    def stream_query_df(self,
-                        sql: str,
-                        parameters: Optional[List[Any]] = None,
-                        vectors_per_chunk: int = 1
-                        ) -> Iterator[pd.DataFrame]:
-        self._ensure_connection()
-        print(f"Streaming DataFrame query (vectors per chunk {vectors_per_chunk}): {sql}")
-        try:
-            result_set = self._conn.execute(sql, parameters)
-            while True:
-                chunk_df = result_set.fetch_df_chunk(vectors_per_chunk)
-                if chunk_df.empty:
-                    break
-                yield chunk_df
-        except ImportError:
-            print("Pandas library is required for DataFrame streaming.")
-            raise
-        except duckdb.Error as e:
-            print(f"Error streaming DataFrame query: {e}")
-            raise QueryError(f"Error streaming DataFrame query: {e}") from e
-
-    def stream_query_arrow(self,
-                           sql: str,
-                           parameters: Optional[List[Any]] = None,
-                           batch_size: int = 1000000
-                           ) -> Iterator[pa.RecordBatch]:
-        """
-        Executes a SQL query and streams the results as Arrow RecordBatches.
-        Useful for processing large results iteratively in Python without
-        loading the entire result set into memory.
-
-        Args:
-            sql: The SQL query to execute.
-            parameters: Optional list of parameters for prepared statements.
-            batch_size: The approximate number of rows per Arrow RecordBatch.
-
-        Yields:
-            pyarrow.RecordBatch: Chunks of the result set.
-
-        Raises:
-            QueryError: If the query execution or fetching fails.
-            ImportError: If pyarrow is not installed.
-        """
-        self._ensure_connection()
-        print(f"Streaming Arrow query (batch size {batch_size}): {sql}")
-        record_batch_reader = None
-        try:
-            # Use execute() to get a result object that supports streaming fetch
-            result_set = self._conn.execute(sql, parameters)
-            # --- MODIFICATION: Get the reader first ---
-            record_batch_reader = result_set.fetch_record_batch(batch_size)
-            # --- Iterate through the reader ---
-            for batch in record_batch_reader:
-                yield batch
-            # --- END MODIFICATION ---
-        except ImportError:
-            print("PyArrow library is required for Arrow streaming.")
-            raise
-        except duckdb.Error as e:
-            print(f"Error streaming Arrow query: {e}")
-            raise QueryError(f"Error streaming Arrow query: {e}") from e
-        finally:
-            # Clean up the reader if it was created
-            if record_batch_reader is not None:
-                # PyArrow readers don't have an explicit close, relying on GC.
-                # Forcing cleanup might involve ensuring references are dropped.
-                del record_batch_reader  # Help GC potentially
-            # The original result_set from execute() might also hold resources,
-            # although fetch_record_batch typically consumes it.
-            # Explicitly closing it if possible, or letting it go out of scope.
-            if 'result_set' in locals() and result_set:
-                try:
-                    # DuckDBPyResult doesn't have an explicit close, relies on __del__
-                    del result_set
-                except Exception:
-                    pass  # Best effort
-
-    # --- Resource Management Methods --- (Keep as before)
-    def close(self):
-        if self._conn:
-            conn_id = id(self._conn)
-            print(f"Closing connection to '{self._db_path}' (ID: {conn_id})")
-            try:
-                self._conn.close()
-            except duckdb.Error as e:
-                print(f"Error closing DuckDB connection (ID: {conn_id}): {e}")
-            finally:
-                self._conn = None
-        else:
-            print("Connection already closed or never opened.")
-
-    def __enter__(self):
-        self._ensure_connection()
-        return self
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        self.close()
-
-    def __del__(self):
-        if self._conn:
-            print(f"ResourceWarning: DatabaseAPI for '{self._db_path}' was not explicitly closed. Closing now in __del__.")
-            try:
-                self.close()
-            except Exception as e:
-                print(f"Exception during implicit close in __del__: {e}")
-            self._conn = None
-
-
-# --- Example Usage --- (Keep as before)
-if __name__ == "__main__":
-    # ... (rest of the example usage code from previous response) ...
-    temp_dir_obj = tempfile.TemporaryDirectory()
-    temp_dir = temp_dir_obj.name
-    print(f"\n--- Using temporary directory: {temp_dir} ---")
-    db_file = Path(temp_dir) / "export_test.db"
-    try:
-        with DatabaseAPI(db_path=db_file) as db_api:
-            db_api.execute_sql("CREATE OR REPLACE TABLE products(id INTEGER, name VARCHAR, price DECIMAL(8,2))")
-            db_api.execute_sql("INSERT INTO products VALUES (101, 'Gadget', 19.99), (102, 'Widget', 35.00), (103, 'Thing''amajig', 9.50)")
-            db_api.execute_sql("CREATE OR REPLACE TABLE sales(product_id INTEGER, sale_date DATE, quantity INTEGER)")
-            db_api.execute_sql("INSERT INTO sales VALUES (101, '2023-10-26', 5), (102, '2023-10-26', 2), (101, '2023-10-27', 3)")
-            export_dir = Path(temp_dir) / "exported_db"
-            db_api.export_database(export_dir)
-            csv_path = Path(temp_dir) / "products_export.csv"
-            db_api.export_data_to_csv('products', csv_path, options={'HEADER': True})
-            parquet_path = Path(temp_dir) / "high_value_products.parquet"
-            db_api.export_data_to_parquet("SELECT * FROM products WHERE price > 20", parquet_path, options={'COMPRESSION': 'SNAPPY'})
-            json_path = Path(temp_dir) / "sales.json"
-            db_api.export_data_to_json("SELECT * FROM sales", json_path, array_format=True)
-            jsonl_path = Path(temp_dir) / "sales.jsonl"
-            db_api.export_data_to_jsonl("SELECT * FROM sales ORDER BY sale_date", jsonl_path)
-
-        with DatabaseAPI() as db_api:
-            db_api.execute_sql("CREATE TABLE large_range AS SELECT range AS id, range % 100 AS category FROM range(1000)")
-            for batch in db_api.stream_query_arrow("SELECT * FROM large_range", batch_size=200):
-                pass
-            for df_chunk in db_api.stream_query_df("SELECT * FROM large_range", vectors_per_chunk=1):
-                pass
-    finally:
-        temp_dir_obj.cleanup()
-        print(f"\n--- Cleaned up temporary directory: {temp_dir} ---")
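
For reference, the streaming reader in the deleted module was consumed batch by batch, per its docstring and `__main__` example; a minimal sketch of that pattern (only meaningful while `database_api.py` is still on the import path):

```python
from database_api import DatabaseAPI

# Process Arrow RecordBatches without materializing the whole result set.
with DatabaseAPI() as db:
    db.execute_sql("CREATE TABLE t AS SELECT range AS id FROM range(1000000)")
    total = 0
    for batch in db.stream_query_arrow("SELECT * FROM t", batch_size=100000):
        total += batch.num_rows  # each item is a pyarrow.RecordBatch
    print(total)  # 1000000
```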
main.py CHANGED
@@ -1,389 +1,382 @@
1
- # main.py
2
  import duckdb
3
- import pandas as pd
4
- import pyarrow as pa
5
- import pyarrow.ipc
6
- from pathlib import Path
7
- import tempfile
8
  import os
9
- import shutil
10
- from typing import Optional, List, Dict, Any, Union, Iterator, Generator, Tuple
11
-
12
- from fastapi import FastAPI, HTTPException, Body, Query, BackgroundTasks, Depends
13
- from fastapi.responses import StreamingResponse, FileResponse
14
  from pydantic import BaseModel, Field
 
 
 
 
15
 
16
- from database_api import DatabaseAPI, DatabaseAPIError, QueryError
17
-
18
- # --- Configuration --- (Keep as before)
19
- DUCKDB_API_DB_PATH = os.getenv("DUCKDB_API_DB_PATH", "api_database.db")
20
- DUCKDB_API_READ_ONLY = os.getenv("DUCKDB_API_READ_ONLY", False)
21
- DUCKDB_API_CONFIG = {}
22
- TEMP_EXPORT_DIR = Path(tempfile.gettempdir()) / "duckdb_api_exports"
23
- TEMP_EXPORT_DIR.mkdir(exist_ok=True)
24
- print(f"Using temporary directory for exports: {TEMP_EXPORT_DIR}")
25
-
26
- # --- Pydantic Models --- (Keep as before)
27
- class StatusResponse(BaseModel):
28
- status: str
29
- message: Optional[str] = None
30
-
31
- class ExecuteRequest(BaseModel):
32
- sql: str
33
- parameters: Optional[List[Any]] = None
34
-
35
- class QueryRequest(BaseModel):
36
- sql: str
37
- parameters: Optional[List[Any]] = None
38
 
39
- class DataFrameResponse(BaseModel):
40
- columns: List[str]
41
- records: List[Dict[str, Any]]
42
 
43
- class InstallRequest(BaseModel):
44
- extension_name: str
45
- force_install: bool = False
46
 
47
- class LoadRequest(BaseModel):
48
- extension_name: str
49
-
50
- class ExportDataRequest(BaseModel):
51
- source: str = Field(..., description="Table name or SQL SELECT query to export")
52
- options: Optional[Dict[str, Any]] = Field(None, description="Format-specific export options")
53
-
54
- # --- FastAPI Application --- (Keep as before)
55
  app = FastAPI(
56
- title="DuckDB API Wrapper",
57
- description="Exposes DuckDB functionalities via a RESTful API.",
58
- version="0.2.1" # Incremented version
59
  )
60
 
61
- # --- Global DatabaseAPI Instance & Lifecycle --- (Keep as before)
62
- db_api_instance: Optional[DatabaseAPI] = None
63
-
64
- @app.on_event("startup")
65
- async def startup_event():
66
- global db_api_instance
67
- print("Starting up DuckDB API...")
68
- try:
69
- db_api_instance = DatabaseAPI(db_path=DUCKDB_API_DB_PATH, read_only=DUCKDB_API_READ_ONLY, config=DUCKDB_API_CONFIG)
70
- except DatabaseAPIError as e:
71
- print(f"FATAL: Could not initialize DatabaseAPI on startup: {e}")
72
- db_api_instance = None
73
-
74
- @app.on_event("shutdown")
75
- def shutdown_event():
76
- print("Shutting down DuckDB API...")
77
- if db_api_instance:
78
- db_api_instance.close()
79
-
80
- # --- Dependency to get the DB API instance --- (Keep as before)
81
- def get_db_api() -> DatabaseAPI:
82
- if db_api_instance is None:
83
- raise HTTPException(status_code=503, detail="Database service is unavailable (failed to initialize).")
84
  try:
85
- db_api_instance._ensure_connection()
86
- return db_api_instance
87
- except DatabaseAPIError as e:
88
- raise HTTPException(status_code=503, detail=f"Database service error: {e}")
89
-
90
- # --- API Endpoints ---
91
-
92
- # --- CRUD and Querying Endpoints (Keep as before) ---
93
- @app.post("/execute", response_model=StatusResponse, tags=["CRUD"])
94
- async def execute_statement(request: ExecuteRequest, api: DatabaseAPI = Depends(get_db_api)):
95
- try:
96
- api.execute_sql(request.sql, request.parameters)
97
- return {"status": "success", "message": None} # Explicitly return None for message
98
- except QueryError as e:
99
- raise HTTPException(status_code=400, detail=str(e))
100
- except DatabaseAPIError as e:
101
- raise HTTPException(status_code=500, detail=str(e))
102
-
103
- @app.post("/query/fetchall", response_model=List[tuple], tags=["Querying"])
104
- async def query_fetchall_endpoint(request: QueryRequest, api: DatabaseAPI = Depends(get_db_api)):
105
- try:
106
- return api.query_fetchall(request.sql, request.parameters)
107
- except QueryError as e:
108
- raise HTTPException(status_code=400, detail=str(e))
109
- except DatabaseAPIError as e:
110
- raise HTTPException(status_code=500, detail=str(e))
111
-
112
- @app.post("/query/dataframe", response_model=DataFrameResponse, tags=["Querying"])
113
- async def query_dataframe_endpoint(request: QueryRequest, api: DatabaseAPI = Depends(get_db_api)):
114
- try:
115
- df = api.query_df(request.sql, request.parameters)
116
- df_serializable = df.replace({pd.NA: None, pd.NaT: None, float('nan'): None})
117
- return {"columns": df_serializable.columns.tolist(), "records": df_serializable.to_dict(orient='records')}
118
- except (QueryError, ImportError) as e:
119
- raise HTTPException(status_code=400, detail=str(e))
120
- except DatabaseAPIError as e:
121
- raise HTTPException(status_code=500, detail=str(e))
122
-
123
- # --- Streaming Endpoints ---
124
-
125
- # --- CORRECTED _stream_arrow_ipc ---
126
- async def _stream_arrow_ipc(record_batch_iterator: Iterator[pa.RecordBatch]) -> Generator[bytes, None, None]:
127
- """Helper generator to stream Arrow IPC Stream format."""
128
- writer = None
129
- sink = pa.BufferOutputStream() # Create sink once
130
- try:
131
- first_batch = next(record_batch_iterator)
132
- writer = pa.ipc.new_stream(sink, first_batch.schema)
133
- writer.write_batch(first_batch)
134
- # Do NOT yield yet, wait for potential subsequent batches or closure
135
-
136
- for batch in record_batch_iterator:
137
- # Write subsequent batches to the SAME writer
138
- writer.write_batch(batch)
139
-
140
- except StopIteration:
141
- # Handles the case where the iterator was empty initially
142
- if writer is None: # No batches were ever processed
143
- print("Warning: Arrow stream iterator was empty.")
144
- # Yield empty bytes or handle as needed, depends on client expectation
145
- # yield b'' # Option 1: empty bytes
146
- return # Option 2: Just finish generator
147
-
148
- except Exception as e:
149
- print(f"Error during Arrow streaming generator: {e}")
150
- # Consider how to signal error downstream if possible
151
  finally:
152
- if writer:
153
- try:
154
- print("Closing Arrow IPC Stream Writer...")
155
- writer.close() # Close the writer to finalize the stream in the sink
156
- print("Writer closed.")
157
- except Exception as close_e:
158
- print(f"Error closing Arrow writer: {close_e}")
159
- if sink:
160
- try:
161
- buffer = sink.getvalue()
162
- if buffer:
163
- print(f"Yielding final Arrow buffer (size: {len(buffer.to_pybytes())})...")
164
- yield buffer.to_pybytes() # Yield the complete stream buffer
165
- else:
166
- print("Arrow sink buffer was empty after closing writer.")
167
- sink.close()
168
- except Exception as close_e:
169
- print(f"Error closing or getting value from Arrow sink: {close_e}")
170
- # --- END CORRECTION ---
171
-
172
-
173
- @app.post("/query/stream/arrow", tags=["Streaming"])
174
- async def query_stream_arrow_endpoint(request: QueryRequest, api: DatabaseAPI = Depends(get_db_api)):
175
- """Executes a SQL query and streams results as Arrow IPC Stream format."""
176
- try:
177
- iterator = api.stream_query_arrow(request.sql, request.parameters)
178
- return StreamingResponse(
179
- _stream_arrow_ipc(iterator),
180
- media_type="application/vnd.apache.arrow.stream"
181
- )
182
- except (QueryError, ImportError) as e:
183
- raise HTTPException(status_code=400, detail=str(e))
184
- except DatabaseAPIError as e:
185
- raise HTTPException(status_code=500, detail=str(e))
186
-
187
- # --- _stream_jsonl (Keep as before) ---
188
- async def _stream_jsonl(dataframe_iterator: Iterator[pd.DataFrame]) -> Generator[bytes, None, None]:
189
- try:
190
- for df_chunk in dataframe_iterator:
191
- df_serializable = df_chunk.replace({pd.NA: None, pd.NaT: None, float('nan'): None})
192
- jsonl_string = df_serializable.to_json(orient='records', lines=True, date_format='iso')
193
- if jsonl_string:
194
- # pandas>=1.5.0 adds newline by default
195
- if not jsonl_string.endswith('\n'):
196
- jsonl_string += '\n'
197
- yield jsonl_string.encode('utf-8')
198
- except Exception as e:
199
- print(f"Error during JSONL streaming generator: {e}")
200
-
201
- @app.post("/query/stream/jsonl", tags=["Streaming"])
202
- async def query_stream_jsonl_endpoint(request: QueryRequest, api: DatabaseAPI = Depends(get_db_api)):
203
- """Executes a SQL query and streams results as JSON Lines (JSONL)."""
204
- try:
205
- iterator = api.stream_query_df(request.sql, request.parameters)
206
- return StreamingResponse(_stream_jsonl(iterator), media_type="application/jsonl")
207
- except (QueryError, ImportError) as e:
208
- raise HTTPException(status_code=400, detail=str(e))
209
- except DatabaseAPIError as e:
210
- raise HTTPException(status_code=500, detail=str(e))
211
 
 
212
 
213
- # --- Download / Export Endpoints (Keep as before, uses corrected _export_data) ---
214
- def _cleanup_temp_file(path: Union[str, Path]):
215
- try:
216
- if Path(path).is_file():
217
- os.remove(path)
218
- print(f"Cleaned up temporary file: {path}")
219
- except OSError as e:
220
- print(f"Error cleaning up temporary file {path}: {e}")
221
-
222
- async def _create_temp_export(
223
- api: DatabaseAPI,
224
- source: str,
225
- export_format: str,
226
- options: Optional[Dict[str, Any]] = None,
227
- suffix: str = ".tmp"
228
- ) -> Path:
229
- fd, temp_path_str = tempfile.mkstemp(suffix=suffix, dir=TEMP_EXPORT_DIR)
230
- os.close(fd)
231
- temp_file_path = Path(temp_path_str)
232
 
233
  try:
234
- print(f"Exporting to temporary file: {temp_file_path}")
235
- if export_format == 'csv':
236
- api.export_data_to_csv(source, temp_file_path, options)
237
- elif export_format == 'parquet':
238
- api.export_data_to_parquet(source, temp_file_path, options)
239
- elif export_format == 'json':
240
- api.export_data_to_json(source, temp_file_path, array_format=True, options=options)
241
- elif export_format == 'jsonl':
242
- api.export_data_to_jsonl(source, temp_file_path, options=options)
243
- else:
244
- raise ValueError(f"Unsupported export format: {export_format}")
245
- return temp_file_path
246
- except Exception as e:
247
- _cleanup_temp_file(temp_file_path)
248
  raise e
249
-
250
- @app.post("/export/data/csv", response_class=FileResponse, tags=["Export / Download"])
251
- async def export_csv_endpoint(request: ExportDataRequest, background_tasks: BackgroundTasks, api: DatabaseAPI = Depends(get_db_api)):
252
- try:
253
- temp_file_path = await _create_temp_export(api, request.source, 'csv', request.options, suffix=".csv")
254
- background_tasks.add_task(_cleanup_temp_file, temp_file_path)
255
- filename = f"export_{Path(request.source).stem if '.' not in request.source else 'query'}.csv"
256
- return FileResponse(temp_file_path, media_type='text/csv', filename=filename)
257
- except (QueryError, ValueError) as e:
258
- raise HTTPException(status_code=400, detail=str(e))
259
- except DatabaseAPIError as e:
260
- raise HTTPException(status_code=500, detail=str(e))
261
  except Exception as e:
262
- raise HTTPException(status_code=500, detail=f"Unexpected error during CSV export: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
 
264
- @app.post("/export/data/parquet", response_class=FileResponse, tags=["Export / Download"])
265
- async def export_parquet_endpoint(request: ExportDataRequest, background_tasks: BackgroundTasks, api: DatabaseAPI = Depends(get_db_api)):
266
  try:
267
- temp_file_path = await _create_temp_export(api, request.source, 'parquet', request.options, suffix=".parquet")
268
- background_tasks.add_task(_cleanup_temp_file, temp_file_path)
269
- filename = f"export_{Path(request.source).stem if '.' not in request.source else 'query'}.parquet"
270
- return FileResponse(temp_file_path, media_type='application/vnd.apache.parquet', filename=filename)
271
- except (QueryError, ValueError) as e:
272
- raise HTTPException(status_code=400, detail=str(e))
273
- except DatabaseAPIError as e:
274
- raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
275
  except Exception as e:
276
- raise HTTPException(status_code=500, detail=f"Unexpected error during Parquet export: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
 
278
- @app.post("/export/data/json", response_class=FileResponse, tags=["Export / Download"])
279
- async def export_json_endpoint(request: ExportDataRequest, background_tasks: BackgroundTasks, api: DatabaseAPI = Depends(get_db_api)):
280
  try:
281
- temp_file_path = await _create_temp_export(api, request.source, 'json', request.options, suffix=".json")
282
- background_tasks.add_task(_cleanup_temp_file, temp_file_path)
283
- filename = f"export_{Path(request.source).stem if '.' not in request.source else 'query'}.json"
284
- return FileResponse(temp_file_path, media_type='application/json', filename=filename)
285
- except (QueryError, ValueError) as e:
286
- raise HTTPException(status_code=400, detail=str(e))
287
- except DatabaseAPIError as e:
288
- raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
289
  except Exception as e:
290
- raise HTTPException(status_code=500, detail=f"Unexpected error during JSON export: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
 
292
- @app.post("/export/data/jsonl", response_class=FileResponse, tags=["Export / Download"])
293
- async def export_jsonl_endpoint(request: ExportDataRequest, background_tasks: BackgroundTasks, api: DatabaseAPI = Depends(get_db_api)):
294
- try:
295
- temp_file_path = await _create_temp_export(api, request.source, 'jsonl', request.options, suffix=".jsonl")
296
- background_tasks.add_task(_cleanup_temp_file, temp_file_path)
297
- filename = f"export_{Path(request.source).stem if '.' not in request.source else 'query'}.jsonl"
298
- return FileResponse(temp_file_path, media_type='application/jsonl', filename=filename)
299
- except (QueryError, ValueError) as e:
300
- raise HTTPException(status_code=400, detail=str(e))
301
- except DatabaseAPIError as e:
302
- raise HTTPException(status_code=500, detail=str(e))
303
- except Exception as e:
304
- raise HTTPException(status_code=500, detail=f"Unexpected error during JSONL export: {e}")
305
-
306
- @app.post("/export/database", response_class=FileResponse, tags=["Export / Download"])
307
- async def export_database_endpoint(background_tasks: BackgroundTasks, api: DatabaseAPI = Depends(get_db_api)):
308
- export_target_dir = Path(tempfile.mkdtemp(dir=TEMP_EXPORT_DIR))
309
- fd, zip_path_str = tempfile.mkstemp(suffix=".zip", dir=TEMP_EXPORT_DIR)
310
- os.close(fd)
311
- zip_file_path = Path(zip_path_str)
312
  try:
313
- print(f"Exporting database to temporary directory: {export_target_dir}")
314
- api.export_database(export_target_dir)
315
- print(f"Creating zip archive at: {zip_file_path}")
316
- shutil.make_archive(str(zip_file_path.with_suffix('')), 'zip', str(export_target_dir))
317
- print(f"Zip archive created: {zip_file_path}")
318
- background_tasks.add_task(shutil.rmtree, export_target_dir, ignore_errors=True)
319
- background_tasks.add_task(_cleanup_temp_file, zip_file_path)
320
- db_name = Path(api._db_path).stem if api._db_path != ':memory:' else 'in_memory_db'
321
- return FileResponse(zip_file_path, media_type='application/zip', filename=f"{db_name}_export.zip")
322
- except (QueryError, ValueError, OSError, DatabaseAPIError) as e:
323
- print(f"Error during database export: {e}")
324
- shutil.rmtree(export_target_dir, ignore_errors=True)
325
- _cleanup_temp_file(zip_file_path)
326
- if isinstance(e, DatabaseAPIError):
327
- raise HTTPException(status_code=500, detail=str(e))
328
- else:
329
- raise HTTPException(status_code=400, detail=str(e))
330
  except Exception as e:
331
- print(f"Unexpected error during database export: {e}")
332
- shutil.rmtree(export_target_dir, ignore_errors=True)
333
- _cleanup_temp_file(zip_file_path)
334
- raise HTTPException(status_code=500, detail=f"Unexpected error during database export: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
335
 
336
- # --- Extension Management Endpoints ---
337
-
338
- @app.post("/extensions/install", response_model=StatusResponse, tags=["Extensions"])
339
- async def install_extension_endpoint(request: InstallRequest, api: DatabaseAPI = Depends(get_db_api)):
340
- try:
341
- api.install_extension(request.extension_name, request.force_install)
342
- return {"status": "success", "message": f"Extension '{request.extension_name}' installed."}
343
- except DatabaseAPIError as e:
344
- raise HTTPException(status_code=500, detail=str(e))
345
- # Catch specific DuckDB errors that should be client errors (400)
346
- except (duckdb.IOException, duckdb.CatalogException, duckdb.InvalidInputException) as e:
347
- raise HTTPException(status_code=400, detail=f"DuckDB Error during install: {e}")
348
- except duckdb.Error as e: # Catch other potential DuckDB errors as 500
349
- raise HTTPException(status_code=500, detail=f"Unexpected DuckDB Error during install: {e}")
350
-
351
-
352
- @app.post("/extensions/load", response_model=StatusResponse, tags=["Extensions"])
353
- async def load_extension_endpoint(request: LoadRequest, api: DatabaseAPI = Depends(get_db_api)):
354
- """Loads an installed DuckDB extension."""
355
  try:
356
- api.load_extension(request.extension_name)
357
- return {"status": "success", "message": f"Extension '{request.extension_name}' loaded."}
358
- # --- MODIFIED Exception Handling ---
359
- except QueryError as e: # If api.load_extension raised QueryError (e.g., IO/Catalog)
360
- raise HTTPException(status_code=400, detail=str(e))
361
- except DatabaseAPIError as e: # For other API-level issues
362
- raise HTTPException(status_code=500, detail=str(e))
363
- # Catch specific DuckDB errors that should be client errors (400)
364
- except (duckdb.IOException, duckdb.CatalogException) as e:
365
- raise HTTPException(status_code=400, detail=f"DuckDB Error during load: {e}")
366
- except duckdb.Error as e: # Catch other potential DuckDB errors as 500
367
- raise HTTPException(status_code=500, detail=f"Unexpected DuckDB Error during load: {e}")
368
- # --- END MODIFICATION ---
369
-
370
- # --- Health Check --- (Keep as before)
371
- @app.get("/health", response_model=StatusResponse, tags=["Health"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
372
  async def health_check():
373
- """Basic health check."""
374
  try:
375
- _ = get_db_api()
376
- return {"status": "ok", "message": None} # Explicitly return None for message
377
- except HTTPException as e:
378
- raise e
379
  except Exception as e:
380
- raise HTTPException(status_code=500, detail=f"Health check failed unexpectedly: {e}")
381
-
382
- # --- Run the app --- (Keep as before)
383
- if __name__ == "__main__":
384
- import uvicorn
385
- print(f"Starting DuckDB API server...")
386
- print(f"Database file configured at: {DUCKDB_API_DB_PATH}")
387
- print(f"Read-only mode: {DUCKDB_API_READ_ONLY}")
388
- print(f"Temporary export directory: {TEMP_EXPORT_DIR}")
389
- uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
 
 
 
 
 
1
  import duckdb
 
 
 
 
 
2
  import os
3
+ from fastapi import FastAPI, HTTPException, Request, Path as FastPath
4
+ from fastapi.responses import FileResponse, StreamingResponse
 
 
 
5
  from pydantic import BaseModel, Field
6
+ from typing import List, Dict, Any, Optional
7
+ import logging
8
+ import io
9
+ import asyncio
10
 
11
+ # --- Configuration ---
12
+ DATABASE_PATH = os.environ.get("DUCKDB_PATH", "data/mydatabase.db")
13
+ DATA_DIR = "data"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
+ # Ensure data directory exists
16
+ os.makedirs(DATA_DIR, exist_ok=True)
 
17
 
18
+ # --- Logging ---
19
+ logging.basicConfig(level=logging.INFO)
20
+ logger = logging.getLogger(__name__)
21
 
22
+ # --- FastAPI App ---
 
 
 
 
 
 
 
23
  app = FastAPI(
24
+ title="DuckDB API",
25
+ description="An API to interact with a DuckDB database.",
26
+ version="0.1.0"
27
  )
28
 
29
+ # --- Database Connection ---
30
+ # For simplicity in this example, we connect within each request.
31
+ # For production, consider dependency injection or connection pooling.
32
+ def get_db():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  try:
34
+ # Check if the database file needs initialization
35
+ initialize = not os.path.exists(DATABASE_PATH) or os.path.getsize(DATABASE_PATH) == 0
36
+ conn = duckdb.connect(DATABASE_PATH, read_only=False)
37
+ if initialize:
38
+ logger.info(f"Database file not found or empty at {DATABASE_PATH}. Initializing.")
39
+ # You could add initial schema setup here if needed
40
+ # conn.execute("CREATE TABLE IF NOT EXISTS initial_table (id INTEGER, name VARCHAR);")
41
+ yield conn
42
+ except duckdb.Error as e:
43
+ logger.error(f"Database connection error: {e}")
44
+ raise HTTPException(status_code=500, detail=f"Database connection error: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  finally:
46
+ if 'conn' in locals() and conn:
47
+ conn.close()
48
+
49
+ # --- Pydantic Models ---
50
+ class ColumnDefinition(BaseModel):
51
+ name: str
52
+ type: str
53
+
54
+ class CreateTableRequest(BaseModel):
55
+ columns: List[ColumnDefinition]
56
+
57
+ class CreateRowRequest(BaseModel):
58
+ # List of rows, where each row is a dict of column_name: value
59
+ rows: List[Dict[str, Any]]
60
+
61
+ class UpdateRowRequest(BaseModel):
62
+ updates: Dict[str, Any] # Column value pairs to set
63
+ condition: str # SQL WHERE clause string to identify rows
64
+
65
+ class DeleteRowRequest(BaseModel):
66
+ condition: str # SQL WHERE clause string to identify rows
67
+
68
+ class ApiResponse(BaseModel):
69
+ message: str
70
+ details: Optional[Any] = None
71
+
72
+ # --- Helper Functions ---
73
+ def safe_identifier(name: str) -> str:
74
+ """Quotes an identifier safely."""
75
+ if not name.isidentifier():
76
+ # Basic check, consider more robust validation/sanitization if needed
77
+ # Use DuckDB's quoting
78
+ try:
79
+ conn = duckdb.connect(':memory:')
80
+ quoted = conn.execute(f"SELECT '{name}'::IDENTIFIER").fetchone()[0]
81
+ conn.close()
82
+ return quoted
83
+ except duckdb.Error:
84
+ raise HTTPException(status_code=400, detail=f"Invalid identifier: {name}")
85
+ # Also quote standard identifiers to be safe
86
+ return f'"{name}"'
87
+
88
+ def generate_column_sql(columns: List[ColumnDefinition]) -> str:
89
+ """Generates the column definition part of a CREATE TABLE statement."""
90
+ defs = []
91
+ for col in columns:
92
+ col_name_safe = safe_identifier(col.name)
93
+ # Basic type validation (can be expanded)
94
+ allowed_types = ['INTEGER', 'VARCHAR', 'TEXT', 'BOOLEAN', 'FLOAT', 'DOUBLE', 'DATE', 'TIMESTAMP', 'BLOB', 'BIGINT', 'DECIMAL']
95
+ type_upper = col.type.strip().upper()
96
+ # Allow DECIMAL(p,s) syntax
97
+ if not (type_upper.startswith('DECIMAL(') and type_upper.endswith(')')) and \
98
+ not any(base_type in type_upper for base_type in allowed_types):
99
+ raise HTTPException(status_code=400, detail=f"Unsupported or invalid data type: {col.type}")
100
+ defs.append(f"{col_name_safe} {col.type}")
101
+ return ", ".join(defs)
 
 
 
102
 
103
+ # --- API Endpoints ---
104
 
105
+ @app.get("/", summary="API Root", response_model=ApiResponse)
106
+ async def read_root():
107
+ """Provides a welcome message for the API."""
108
+ return {"message": "Welcome to the DuckDB API!"}
109
+
110
+ @app.post("/tables/{table_name}", summary="Create Table", response_model=ApiResponse, status_code=201)
111
+ async def create_table(
112
+ table_name: str = FastPath(..., description="Name of the table to create"),
113
+ schema: CreateTableRequest = ...,
114
+ ):
115
+ """Creates a new table with the specified schema."""
116
+ table_name_safe = safe_identifier(table_name)
117
+ if not schema.columns:
118
+ raise HTTPException(status_code=400, detail="Table must have at least one column.")
 
 
 
 
 
119
 
120
  try:
121
+ columns_sql = generate_column_sql(schema.columns)
122
+ sql = f"CREATE TABLE {table_name_safe} ({columns_sql});"
123
+ logger.info(f"Executing SQL: {sql}")
124
+ for conn in get_db():
125
+ conn.execute(sql)
126
+ return {"message": f"Table '{table_name}' created successfully."}
127
+ except HTTPException as e: # Re-raise validation errors
 
 
 
 
 
 
 
128
  raise e
129
+ except duckdb.Error as e:
130
+ logger.error(f"Error creating table '{table_name}': {e}")
131
+ raise HTTPException(status_code=400, detail=f"Error creating table: {e}")
 
 
 
 
 
 
 
 
 
132
  except Exception as e:
133
+ logger.error(f"Unexpected error creating table '{table_name}': {e}")
134
+ raise HTTPException(status_code=500, detail="An unexpected error occurred.")
135
+
136
+ @app.get("/tables/{table_name}", summary="Read Table Data")
137
+ async def read_table(
138
+ table_name: str = FastPath(..., description="Name of the table to read from"),
139
+ limit: Optional[int] = None,
140
+ offset: Optional[int] = None
141
+ ):
142
+ """Reads and returns all rows from a specified table. Supports limit and offset."""
143
+ table_name_safe = safe_identifier(table_name)
144
+ sql = f"SELECT * FROM {table_name_safe}"
145
+ params = []
146
+ if limit is not None:
147
+ sql += " LIMIT ?"
148
+ params.append(limit)
149
+ if offset is not None:
150
+ sql += " OFFSET ?"
151
+ params.append(offset)
152
+ sql += ";"
153
 
 
 
154
  try:
155
+ logger.info(f"Executing SQL: {sql} with params: {params}")
156
+ for conn in get_db():
157
+ result = conn.execute(sql, params).fetchall()
158
+ # Convert rows to dictionaries for JSON serialization
159
+ column_names = [desc[0] for desc in conn.description]
160
+ data = [dict(zip(column_names, row)) for row in result]
161
+ return data
162
+ except duckdb.CatalogException as e:
163
+ raise HTTPException(status_code=404, detail=f"Table '{table_name}' not found.")
164
+ except duckdb.Error as e:
165
+ logger.error(f"Error reading table '{table_name}': {e}")
166
+ raise HTTPException(status_code=400, detail=f"Error reading table: {e}")
167
  except Exception as e:
168
+ logger.error(f"Unexpected error reading table '{table_name}': {e}")
169
+ raise HTTPException(status_code=500, detail="An unexpected error occurred.")
170
+
171
+
172
+ @app.post("/tables/{table_name}/rows", summary="Create Rows", response_model=ApiResponse, status_code=201)
173
+ async def create_rows(
174
+ table_name: str = FastPath(..., description="Name of the table to insert into"),
175
+ request: CreateRowRequest = ...,
176
+ ):
177
+ """Inserts one or more rows into the specified table."""
178
+ table_name_safe = safe_identifier(table_name)
179
+ if not request.rows:
180
+ raise HTTPException(status_code=400, detail="No rows provided to insert.")
181
+
182
+ # Assume all rows have the same columns based on the first row
183
+ columns = list(request.rows[0].keys())
184
+ columns_safe = [safe_identifier(col) for col in columns]
185
+ placeholders = ", ".join(["?"] * len(columns))
186
+ columns_sql = ", ".join(columns_safe)
187
+
188
+ sql = f"INSERT INTO {table_name_safe} ({columns_sql}) VALUES ({placeholders});"
189
+
190
+ # Convert list of dicts to list of lists/tuples for executemany
191
+ params_list = []
192
+ for row_dict in request.rows:
193
+ if list(row_dict.keys()) != columns:
194
+ raise HTTPException(status_code=400, detail="All rows must have the same columns in the same order.")
195
+ params_list.append(list(row_dict.values()))
196
 
 
 
197
  try:
198
+ logger.info(f"Executing SQL: {sql} for {len(params_list)} rows")
199
+ for conn in get_db():
200
+ conn.executemany(sql, params_list)
201
+ conn.commit() # Explicit commit after potential bulk insert
202
+ return {"message": f"Successfully inserted {len(params_list)} rows into '{table_name}'."}
203
+ except duckdb.CatalogException as e:
204
+ raise HTTPException(status_code=404, detail=f"Table '{table_name}' not found.")
205
+ except duckdb.Error as e:
206
+ logger.error(f"Error inserting rows into '{table_name}': {e}")
207
+ # Rollback on error might be needed depending on transaction behavior
208
+ # For get_db creating connection per request, this is less critical
209
+ raise HTTPException(status_code=400, detail=f"Error inserting rows: {e}")
210
  except Exception as e:
211
+ logger.error(f"Unexpected error inserting rows into '{table_name}': {e}")
212
+ raise HTTPException(status_code=500, detail="An unexpected error occurred.")
213
+
214
+
215
+ @app.put("/tables/{table_name}/rows", summary="Update Rows", response_model=ApiResponse)
216
+ async def update_rows(
217
+ table_name: str = FastPath(..., description="Name of the table to update"),
218
+ request: UpdateRowRequest = ...,
219
+ ):
220
+ """Updates rows in the table based on a condition."""
221
+ table_name_safe = safe_identifier(table_name)
222
+ if not request.updates:
223
+ raise HTTPException(status_code=400, detail="No updates provided.")
224
+ if not request.condition:
225
+ raise HTTPException(status_code=400, detail="Update condition (WHERE clause) is required.")
226
+
227
+ set_clauses = []
228
+ params = []
229
+ for col, value in request.updates.items():
230
+ set_clauses.append(f"{safe_identifier(col)} = ?")
231
+ params.append(value)
232
+
233
+ set_sql = ", ".join(set_clauses)
234
+ # WARNING: Injecting request.condition directly is a security risk.
235
+ # In a real app, use query parameters or a safer way to build the WHERE clause.
236
+ sql = f"UPDATE {table_name_safe} SET {set_sql} WHERE {request.condition};"
237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  try:
239
+ logger.info(f"Executing SQL: {sql} with params: {params}")
240
+ for conn in get_db():
241
+ # Use execute for safety with parameters
242
+ conn.execute(sql, params)
243
+ conn.commit()
244
+ return {"message": f"Rows in '{table_name}' updated successfully based on condition."}
245
+ except duckdb.CatalogException as e:
246
+ raise HTTPException(status_code=404, detail=f"Table '{table_name}' not found.")
247
+ except duckdb.Error as e:
248
+ logger.error(f"Error updating rows in '{table_name}': {e}")
249
+ raise HTTPException(status_code=400, detail=f"Error updating rows: {e}")
 
 
 
 
 
 
250
  except Exception as e:
251
+ logger.error(f"Unexpected error updating rows in '{table_name}': {e}")
252
+ raise HTTPException(status_code=500, detail="An unexpected error occurred.")
253
+
254
+ @app.delete("/tables/{table_name}/rows", summary="Delete Rows", response_model=ApiResponse)
255
+ async def delete_rows(
256
+ table_name: str = FastPath(..., description="Name of the table to delete from"),
257
+ request: DeleteRowRequest = ...,
258
+ ):
259
+ """Deletes rows from the table based on a condition."""
260
+ table_name_safe = safe_identifier(table_name)
261
+ if not request.condition:
262
+ raise HTTPException(status_code=400, detail="Delete condition (WHERE clause) is required.")
263
+
264
+ # WARNING: Injecting request.condition directly is a security risk.
265
+ # In a real app, use query parameters or a safer way to build the WHERE clause.
266
+ sql = f"DELETE FROM {table_name_safe} WHERE {request.condition};"
267
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
  try:
269
+ logger.info(f"Executing SQL: {sql}")
270
+ for conn in get_db():
271
+ # Execute does not directly support parameters for WHERE in DELETE like this easily
272
+ conn.execute(sql)
273
+ conn.commit()
274
+ return {"message": f"Rows from '{table_name}' deleted successfully based on condition."}
275
+ except duckdb.CatalogException as e:
276
+ raise HTTPException(status_code=404, detail=f"Table '{table_name}' not found.")
277
+ except duckdb.Error as e:
278
+ logger.error(f"Error deleting rows from '{table_name}': {e}")
279
+ raise HTTPException(status_code=400, detail=f"Error deleting rows: {e}")
280
+ except Exception as e:
281
+ logger.error(f"Unexpected error deleting rows from '{table_name}': {e}")
282
+ raise HTTPException(status_code=500, detail="An unexpected error occurred.")
+ 
+ # --- Download Endpoints ---
+ 
+ @app.get("/download/table/{table_name}", summary="Download Table as CSV")
+ async def download_table_csv(
+     table_name: str = FastPath(..., description="Name of the table to download")
+ ):
+     """Downloads the entire content of a table as a CSV file."""
+     table_name_safe = safe_identifier(table_name)
+     # COPY (SELECT * FROM ...) TO STDOUT (FORMAT CSV, HEADER) would be the most
+     # efficient export, but the Python API does not expose STDOUT as a stream,
+     # so we fetch the data via pandas instead.
+ 
+     async def stream_csv_data():
+         # DuckDB's Python API is blocking, so rows cannot be streamed as they are
+         # produced. A simple approach for this demo is to fetch all data first,
+         # then stream it. A more advanced approach would run the DuckDB query in
+         # a separate thread or process pool managed by asyncio.
+         try:
+             all_data_io = io.StringIO()
+             for conn in get_db():
+                 df = conn.execute(f"SELECT * FROM {table_name_safe}").df()  # use pandas for CSV conversion
+ 
+             # Write into an in-memory text buffer
+             df.to_csv(all_data_io, index=False)
+             all_data_io.seek(0)
+ 
+             # Stream the content chunk by chunk
+             chunk_size = 8192
+             while True:
+                 chunk = all_data_io.read(chunk_size)
+                 if not chunk:
+                     break
+                 yield chunk
+                 # Allow other tasks to run
+                 await asyncio.sleep(0)
+             all_data_io.close()
+ 
+         except duckdb.CatalogException as e:
+             # Stream an error message if the table doesn't exist
+             yield f"Error: Table '{table_name}' not found."
+             logger.error(f"Error downloading table '{table_name}': {e}")
+         except duckdb.Error as e:
+             yield f"Error: Could not export table '{table_name}'. {e}"
+             logger.error(f"Error downloading table '{table_name}': {e}")
+         except Exception as e:
+             yield "Error: An unexpected error occurred."
+             logger.error(f"Unexpected error downloading table '{table_name}': {e}")
+ 
+     return StreamingResponse(
+         stream_csv_data(),
+         media_type="text/csv",
+         headers={"Content-Disposition": f"attachment; filename={table_name}.csv"},
+     )
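
The comments above mention running the blocking query off the event loop; a minimal sketch of that idea using asyncio.to_thread (Python 3.9+). It reuses this app's get_db generator; the fetch_table_csv helper itself is hypothetical, not part of this commit:

import asyncio

async def fetch_table_csv(table_name_safe: str) -> str:
    """Runs the blocking DuckDB fetch in a worker thread so the event loop stays responsive."""
    def _fetch() -> str:
        for conn in get_db():
            # .df() blocks; asyncio.to_thread keeps it off the event loop
            return conn.execute(f"SELECT * FROM {table_name_safe}").df().to_csv(index=False)
    return await asyncio.to_thread(_fetch)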
+ 
+ 
+ @app.get("/download/database", summary="Download Database File")
+ async def download_database_file():
+     """Downloads the entire DuckDB database file."""
+     if not os.path.exists(DATABASE_PATH):
+         raise HTTPException(status_code=404, detail="Database file not found.")
+ 
+     # Ensure connections are closed before downloading to avoid partial writes/locking issues.
+     # This is tricky with the current get_db pattern. A proper app stop/start or
+     # dedicated maintenance mode would be better. For this demo, we hope for the best.
+     logger.warning("Attempting to download database file. Ensure no active writes are occurring.")
+ 
+     return FileResponse(
+         path=DATABASE_PATH,
+         filename=os.path.basename(DATABASE_PATH),
+         media_type="application/octet-stream"  # generic binary file type
+     )
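
One cheap mitigation for the locking concern noted above is to checkpoint before serving the file, so DuckDB's write-ahead log is merged into the main database file. A minimal sketch reusing this app's get_db generator; the helper itself is hypothetical, not part of this commit:

def checkpoint_database() -> None:
    """Flushes DuckDB's WAL into the main database file before it is copied."""
    for conn in get_db():
        # CHECKPOINT is a standard DuckDB statement; it reduces (but does not
        # eliminate) the risk of serving a file that is mid-write.
        conn.execute("CHECKPOINT;")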
+ 
+ 
+ # --- Health Check ---
+ @app.get("/health", summary="Health Check", response_model=ApiResponse)
  async def health_check():
+     """Checks if the API and database connection are working."""
      try:
+         for conn in get_db():
+             conn.execute("SELECT 1")
+             return {"message": "API is healthy and database connection is successful."}
      except Exception as e:
+         logger.error(f"Health check failed: {e}")
+         raise HTTPException(status_code=503, detail=f"Health check failed: {e}")
+ 
+ # --- Optional: Add Startup/Shutdown events if needed ---
+ # @app.on_event("startup")
+ # async def startup_event():
+ #     # Initialize database connection pool, etc.
+ #     logger.info("Application startup.")
+ 
+ # @app.on_event("shutdown")
+ # async def shutdown_event():
+ #     # Clean up resources, close connections, etc.
+ #     logger.info("Application shutdown.")
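
Note that @app.on_event is deprecated in recent FastAPI releases in favor of a lifespan context manager. A minimal sketch of the equivalent, assuming the commented-out hooks above were revived this way (logger comes from this module):

from contextlib import asynccontextmanager
from fastapi import FastAPI

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: initialize connection pools, warm caches, etc.
    logger.info("Application startup.")
    yield
    # Shutdown: close connections, flush state, etc.
    logger.info("Application shutdown.")

# app = FastAPI(lifespan=lifespan)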
requirements.txt CHANGED
@@ -1,10 +1,6 @@
- fastapi[all]>=0.95.0
- uvicorn[standard]>=0.18.0
- duckdb>=1.2.1
- pydantic
- python-multipart>=0.0.5
- httpx
- requests>=2.20.0
- aiofiles>=0.8.0
- pandas>=1.5.0
- pyarrow>=10.0.0
+ fastapi
+ uvicorn[standard]
+ duckdb>=0.9.0
+ python-multipart
+ aiofiles
+ pandas
test_api.py DELETED
@@ -1,246 +0,0 @@
- import pytest
- import os
- import shutil
- import tempfile
- import zipfile
- import json
- from pathlib import Path
- from typing import List, Dict, Any
- from unittest.mock import patch
- 
- pd = pytest.importorskip("pandas")
- pa = pytest.importorskip("pyarrow")
- pa_ipc = pytest.importorskip("pyarrow.ipc")
- 
- from fastapi.testclient import TestClient
- import main  # Import main to reload and access config
- 
- # --- Test Fixtures --- (Keep client fixture as before)
- @pytest.fixture(scope="module")
- def client():
-     with patch.dict(os.environ, {"DUCKDB_API_DB_PATH": ":memory:"}):
-         import importlib
-         importlib.reload(main)
-         main.TEMP_EXPORT_DIR.mkdir(exist_ok=True)
-         print(f"TestClient using temp export dir: {main.TEMP_EXPORT_DIR}")
-         with TestClient(main.app) as c:
-             yield c
-         print(f"Cleaning up test export dir: {main.TEMP_EXPORT_DIR}")
-         for item in main.TEMP_EXPORT_DIR.iterdir():
-             try:
-                 if item.is_file():
-                     os.remove(item)
-                 elif item.is_dir():
-                     shutil.rmtree(item)
-             except Exception as e:
-                 print(f"Error cleaning up {item}: {e}")
- 
- # --- Test Classes ---
- 
- class TestHealth:  # (Keep as before)
-     def test_health_check(self, client: TestClient):
-         response = client.get("/health")
-         assert response.status_code == 200
-         assert response.json() == {"status": "ok", "message": None}
- 
- class TestExecution:  # (Keep as before)
-     def test_execute_create(self, client: TestClient):
-         response = client.post("/execute", json={"sql": "CREATE TABLE test_table(id INTEGER, name VARCHAR);"})
-         assert response.status_code == 200
-         assert response.json() == {"status": "success", "message": None}
-         response_fail = client.post("/execute", json={"sql": "CREATE TABLE test_table(id INTEGER);"})
-         assert response_fail.status_code == 400
- 
-     def test_execute_insert(self, client: TestClient):
-         client.post("/execute", json={"sql": "CREATE OR REPLACE TABLE test_table(id INTEGER, name VARCHAR);"})
-         response = client.post("/execute", json={"sql": "INSERT INTO test_table VALUES (1, 'Alice')"})
-         assert response.status_code == 200
-         query_response = client.post("/query/fetchall", json={"sql": "SELECT COUNT(*) FROM test_table"})
-         assert query_response.status_code == 200
-         assert query_response.json() == [[1]]
- 
-     def test_execute_insert_params(self, client: TestClient):
-         client.post("/execute", json={"sql": "CREATE OR REPLACE TABLE test_table(id INTEGER, name VARCHAR);"})
-         response = client.post("/execute", json={"sql": "INSERT INTO test_table VALUES (?, ?)", "parameters": [2, "Bob"]})
-         assert response.status_code == 200
-         query_response = client.post("/query/fetchall", json={"sql": "SELECT * FROM test_table WHERE id = 2"})
-         assert query_response.status_code == 200
-         assert query_response.json() == [[2, "Bob"]]
- 
-     def test_execute_invalid_sql(self, client: TestClient):
-         response = client.post("/execute", json={"sql": "INVALID SQL STATEMENT"})
-         assert response.status_code == 400
-         assert "Parser Error" in response.json()["detail"]
- 
- class TestQuerying:  # (Keep as before)
-     @pytest.fixture(scope="class", autouse=True)
-     def setup_data(self, client: TestClient):
-         client.post("/execute", json={"sql": "CREATE OR REPLACE TABLE query_test(id INTEGER, val VARCHAR)"})
-         client.post("/execute", json={"sql": "INSERT INTO query_test VALUES (1, 'one'), (2, 'two'), (3, 'three')"})
- 
-     def test_query_fetchall(self, client: TestClient):
-         response = client.post("/query/fetchall", json={"sql": "SELECT * FROM query_test ORDER BY id"})
-         assert response.status_code == 200
-         assert response.json() == [[1, 'one'], [2, 'two'], [3, 'three']]
- 
-     def test_query_fetchall_params(self, client: TestClient):
-         response = client.post("/query/fetchall", json={"sql": "SELECT * FROM query_test WHERE id > ? ORDER BY id", "parameters": [1]})
-         assert response.status_code == 200
-         assert response.json() == [[2, 'two'], [3, 'three']]
- 
-     def test_query_fetchall_empty(self, client: TestClient):
-         response = client.post("/query/fetchall", json={"sql": "SELECT * FROM query_test WHERE id > 100"})
-         assert response.status_code == 200
-         assert response.json() == []
- 
-     def test_query_dataframe(self, client: TestClient):
-         response = client.post("/query/dataframe", json={"sql": "SELECT * FROM query_test ORDER BY id"})
-         assert response.status_code == 200
-         data = response.json()
-         assert data["columns"] == ["id", "val"]
-         assert data["records"] == [
-             {"id": 1, "val": "one"},
-             {"id": 2, "val": "two"},
-             {"id": 3, "val": "three"}
-         ]
- 
-     def test_query_dataframe_invalid_sql(self, client: TestClient):
-         response = client.post("/query/dataframe", json={"sql": "SELECT non_existent FROM query_test"})
-         assert response.status_code == 400
-         assert "Binder Error" in response.json()["detail"]
- 
- class TestStreaming:  # (Keep as before)
-     @pytest.fixture(scope="class", autouse=True)
-     def setup_data(self, client: TestClient):
-         client.post("/execute", json={"sql": "CREATE OR REPLACE TABLE stream_test AS SELECT range AS id, range % 5 AS category FROM range(10)"})
- 
-     def test_stream_arrow(self, client: TestClient):
-         response = client.post("/query/stream/arrow", json={"sql": "SELECT * FROM stream_test"})
-         assert response.status_code == 200
-         assert response.headers["content-type"] == "application/vnd.apache.arrow.stream"
-         if not response.content:
-             pytest.fail("Arrow stream response content is empty")
-         try:
-             reader = pa_ipc.open_stream(response.content)
-             table = reader.read_all()
-         except pa.ArrowInvalid as e:
-             pytest.fail(f"Failed to read Arrow stream: {e}")
-         assert table.num_rows == 10
-         assert table.column_names == ["id", "category"]
-         assert table.column('id').to_pylist() == list(range(10))
- 
-     def test_stream_arrow_empty(self, client: TestClient):
-         response = client.post("/query/stream/arrow", json={"sql": "SELECT * FROM stream_test WHERE id < 0"})
-         assert response.status_code == 200
-         assert response.headers["content-type"] == "application/vnd.apache.arrow.stream"
-         try:
-             reader = pa_ipc.open_stream(response.content)
-             table = reader.read_all()
-             assert table.num_rows == 0
-         except pa.ArrowInvalid as e:
-             print(f"Received ArrowInvalid for empty stream, which is acceptable: {e}")
-             assert response.content == b''
- 
-     def test_stream_jsonl(self, client: TestClient):
-         response = client.post("/query/stream/jsonl", json={"sql": "SELECT * FROM stream_test ORDER BY id"})
-         assert response.status_code == 200
-         assert response.headers["content-type"] == "application/jsonl"
-         lines = response.text.strip().split('\n')
-         records = [json.loads(line) for line in lines if line]
-         assert len(records) == 10
-         assert records[0] == {"id": 0, "category": 0}
-         assert records[9] == {"id": 9, "category": 4}
- 
-     def test_stream_jsonl_empty(self, client: TestClient):
-         response = client.post("/query/stream/jsonl", json={"sql": "SELECT * FROM stream_test WHERE id < 0"})
-         assert response.status_code == 200
-         assert response.headers["content-type"] == "application/jsonl"
-         assert response.text.strip() == ""
- 
- class TestExportDownload:  # (Keep setup_data as before)
-     @pytest.fixture(scope="class", autouse=True)
-     def setup_data(self, client: TestClient):
-         client.post("/execute", json={"sql": "CREATE OR REPLACE TABLE export_table(id INTEGER, name VARCHAR, price DECIMAL(5,2))"})
-         client.post("/execute", json={"sql": "INSERT INTO export_table VALUES (1, 'Apple', 0.50), (2, 'Banana', 0.30), (3, 'Orange', 0.75)"})
- 
-     @pytest.mark.parametrize(
-         "endpoint_suffix, expected_content_type, expected_filename_ext, validation_fn",
-         [
-             ("csv", "text/csv", ".csv", lambda c: b"id,name,price\n1,Apple,0.50\n" in c),
-             ("parquet", "application/vnd.apache.parquet", ".parquet", lambda c: c.startswith(b"PAR1")),
-             # --- MODIFIED JSON/JSONL Lambdas ---
-             ("json", "application/json", ".json", lambda c: c.strip().startswith(b'[') and c.strip().endswith(b']')),
-             ("jsonl", "application/jsonl", ".jsonl", lambda c: b'"id":1' in c and b'"name":"Apple"' in c and b'\n' in c),
-             # --- END MODIFICATION ---
-         ]
-     )
-     def test_export_data(self, client: TestClient, endpoint_suffix, expected_content_type, expected_filename_ext, validation_fn, tmp_path):
-         endpoint = f"/export/data/{endpoint_suffix}"
-         payload = {"source": "export_table"}
-         if endpoint_suffix == 'csv':
-             payload['options'] = {'HEADER': True}
- 
-         response = client.post(endpoint, json=payload)
- 
-         assert response.status_code == 200, f"Request to {endpoint} failed: {response.text}"
-         assert response.headers["content-type"].startswith(expected_content_type)
-         assert "content-disposition" in response.headers
-         assert f'filename="export_export_table{expected_filename_ext}"' in response.headers["content-disposition"]
- 
-         downloaded_path = tmp_path / f"downloaded{expected_filename_ext}"
-         with open(downloaded_path, "wb") as f:
-             f.write(response.content)
-         assert downloaded_path.exists()
-         assert validation_fn(response.content), f"Validation failed for {endpoint_suffix}"
- 
-         # Test with a query source
-         payload = {"source": "SELECT id, name FROM export_table WHERE price > 0.40 ORDER BY id"}
-         response = client.post(endpoint, json=payload)
-         assert response.status_code == 200
-         assert f'filename="export_query{expected_filename_ext}"' in response.headers["content-disposition"]
-         assert len(response.content) > 0
- 
-     # --- Keep test_export_database as before ---
-     def test_export_database(self, client: TestClient, tmp_path):
-         client.post("/execute", json={"sql": "CREATE TABLE IF NOT EXISTS another_table(x int)"})
-         response = client.post("/export/database")
-         assert response.status_code == 200
-         assert response.headers["content-type"] == "application/zip"
-         assert "content-disposition" in response.headers
-         assert response.headers["content-disposition"].startswith("attachment; filename=")
-         assert 'filename="in_memory_db_export.zip"' in response.headers["content-disposition"]
-         zip_path = tmp_path / "db_export.zip"
-         with open(zip_path, "wb") as f:
-             f.write(response.content)
-         assert zip_path.exists()
-         with zipfile.ZipFile(zip_path, 'r') as z:
-             print(f"Zip contents: {z.namelist()}")
-             assert "schema.sql" in z.namelist()
-             assert "load.sql" in z.namelist()
-             assert any(name.startswith("export_table") for name in z.namelist())
-             assert any(name.startswith("another_table") for name in z.namelist())
- 
- 
- class TestExtensions:  # (Keep as before)
-     def test_install_extension_fail(self, client: TestClient):
-         response = client.post("/extensions/install", json={"extension_name": "nonexistent_dummy_ext"})
-         assert response.status_code >= 400
-         assert "Error during install" in response.json()["detail"] or "Failed to download" in response.json()["detail"]
- 
-     def test_load_extension_fail(self, client: TestClient):
-         response = client.post("/extensions/load", json={"extension_name": "nonexistent_dummy_ext"})
-         assert response.status_code == 400
-         # --- MODIFIED Assertion ---
-         assert "Error loading extension" in response.json()["detail"]
-         # --- END MODIFICATION ---
-         assert "not found" in response.json()["detail"].lower()
- 
-     @pytest.mark.skip(reason="Requires httpfs extension to be available for install/load")
-     def test_install_and_load_httpfs(self, client: TestClient):
-         install_response = client.post("/extensions/install", json={"extension_name": "httpfs"})
-         assert install_response.status_code == 200
-         assert install_response.json()["status"] == "success"
- 
-         load_response = client.post("/extensions/load", json={"extension_name": "httpfs"})
-         assert load_response.status_code == 200
-         assert load_response.json()["status"] == "success"