|
import asyncio
import io
import logging
import os
import re
import urllib.parse
from contextlib import contextmanager
from typing import List, Dict, Any, Optional

import duckdb
from fastapi import FastAPI, HTTPException, Request, Path as FastPath, Body
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, StreamingResponse
from pydantic import BaseModel, Field
|
|
|
|
|
# Location of the DuckDB database file; override with the DUCKDB_PATH environment variable.
DATABASE_PATH = os.environ.get("DUCKDB_PATH", "data/mydatabase.db")
# Default data directory; created unconditionally at import time.
DATA_DIR = "data"
# Static front-end page served at "/".
HTML_FILE_PATH = "index.html"

os.makedirs(DATA_DIR, exist_ok=True)
# DUCKDB_PATH may point outside DATA_DIR, and duckdb.connect() does not create
# missing parent directories, so make sure the database's own directory exists too.
_db_parent_dir = os.path.dirname(DATABASE_PATH)
if _db_parent_dir:
    os.makedirs(_db_parent_dir, exist_ok=True)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
|
app = FastAPI(
    title="DuckDB API",
    description="An API to interact with a DuckDB database.",
    version="0.1.0",
)

# Cross-origin policy: every origin, method and header is accepted, credentials included.
# NOTE(review): a wildcard origin combined with allow_credentials=True may let any
# website make credentialed calls to this API (depending on the Starlette version,
# the caller's Origin is echoed back). Confirm this is intended; otherwise list the
# trusted origins explicitly.
_cors_options = dict(
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
app.add_middleware(CORSMiddleware, **_cors_options)
|
|
|
|
|
|
|
@contextmanager
def get_db_context():
    """Open a read-write DuckDB connection to DATABASE_PATH for one ``with`` block.

    Yields:
        A DuckDB connection, which is always closed when the block exits.

    Raises:
        HTTPException: status 500 if the connection itself cannot be opened.
        Any exception raised inside the ``with`` body (including ``duckdb.Error``
        subclasses such as ``CatalogException``) propagates unchanged, so each
        endpoint can map it to its own status code.
    """
    # Checked before connecting, because connecting creates the file.
    is_new_database = not os.path.exists(DATABASE_PATH) or os.path.getsize(DATABASE_PATH) == 0
    try:
        conn = duckdb.connect(DATABASE_PATH, read_only=False)
    except duckdb.Error as e:
        logger.error("Database connection error: %s", e)
        raise HTTPException(status_code=500, detail=f"Database connection error: {e}") from e

    # Only the connect step is translated into an HTTPException. Errors thrown back
    # into this generator from the caller's body must not be caught here: doing so
    # turned every query error into a 500 "connection error".
    try:
        if is_new_database:
            logger.info("Database file not found or empty at %s. Initializing.", DATABASE_PATH)
        yield conn
    finally:
        conn.close()
|
|
|
|
|
class ColumnDefinition(BaseModel):
    # One table column: its name and its type as SQL type text. Used both for
    # schema responses (built from PRAGMA table_info) and for CREATE TABLE input.
    name: str  # column name; quoted with safe_identifier before it reaches SQL
    type: str  # SQL type text; generate_column_sql places it after the quoted name in the DDL
|
|
|
class TableSchemaResponse(BaseModel):
    # Response body of GET /tables/{table_name}/schema.
    columns: List[ColumnDefinition]  # columns in the order PRAGMA table_info returns them
|
|
|
class CreateTableRequest(BaseModel):
    # Request body of POST /tables/{table_name} (create or replace a table).
    columns: List[ColumnDefinition]  # must be non-empty; create_table rejects an empty list
|
|
|
class CreateRowRequest(BaseModel):
    # Request body of POST /tables/{table_name}/rows.
    # Each dict maps column name -> value. create_rows takes the column list from the
    # first row and requires every row to have the same keys in the same order.
    rows: List[Dict[str, Any]]
|
|
|
class UpdateRowRequest(BaseModel):
    # Update payload: column -> new value, plus a row-selection condition.
    # NOTE(review): no endpoint in the visible code uses this model. `condition`
    # presumably holds a SQL WHERE expression; if it is ever interpolated into SQL it
    # is untrusted input and must be validated or parameterised.
    updates: Dict[str, Any]
    condition: str
|
|
|
class DeleteRowRequest(BaseModel):
    # Delete payload: the condition selecting the rows to remove.
    # NOTE(review): no endpoint in the visible code uses this model; `condition`
    # presumably holds a SQL WHERE expression and is untrusted if ever spliced into SQL.
    condition: str
|
|
|
class SQLQueryRequest(BaseModel):
    # Request body of POST /query.
    sql: str  # raw query text; execute_query screens it before running it
|
|
|
class ApiResponse(BaseModel):
    # Generic response envelope, used as response_model by the create-table,
    # insert-rows and health endpoints.
    message: str  # human-readable outcome
    details: Optional[Any] = None  # optional extra payload; omitted by the endpoints shown here
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def safe_identifier(name: str) -> str:
    """Return *name* wrapped in double quotes for use as a DuckDB identifier.

    Every embedded double quote is doubled (``a"b`` becomes ``"a""b"``), the
    standard SQL escape, so the result is always a single quoted identifier.

    Raises:
        HTTPException: status 400 when *name* is empty or not a ``str``.
    """
    if isinstance(name, str) and name:
        return '"' + name.replace('"', '""') + '"'
    raise HTTPException(status_code=400, detail=f"Invalid identifier provided: {name}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_column_sql(columns: List[ColumnDefinition]) -> str:
    """Build the column-definition list of a CREATE TABLE statement.

    Each column becomes ``"<quoted name>" <type>``, and the entries are joined with
    ``", "``. Names are quoted with safe_identifier. The type text is spliced in
    verbatim, so parameterised types such as ``DECIMAL(10, 2)`` keep working. Because
    that text comes from the request body, it is screened first.

    Raises:
        HTTPException: status 400 for an invalid column name, or for a type that is
            empty or contains ``;``, ``--``, ``/*`` or a NUL character. A ``;`` would
            let a caller append arbitrary extra statements to the DDL (COPY ... TO a
            file, ATTACH, ...).
    """
    forbidden_fragments = (";", "--", "/*", "\x00")
    defs = []
    for col in columns:
        col_name_safe = safe_identifier(col.name)
        col_type = col.type
        if not col_type or not col_type.strip() or any(frag in col_type for frag in forbidden_fragments):
            raise HTTPException(
                status_code=400,
                detail=f"Invalid column type for column '{col.name}': {col_type!r}",
            )
        defs.append(f"{col_name_safe} {col_type}")
    return ", ".join(defs)
|
|
|
def result_to_dict(cursor_description, rows):
    """Pair each result row with its column names.

    Args:
        cursor_description: sequence whose entries each start with a column name
            (e.g. a relation's ``description``). A falsy value means "no result set".
        rows: iterable of row sequences, with values in column order.

    Returns:
        One dict per row mapping column name -> value, or ``[]`` when
        *cursor_description* is falsy.
    """
    if cursor_description:
        keys = tuple(entry[0] for entry in cursor_description)
        return [dict(zip(keys, values)) for values in rows]
    return []
|
|
|
|
|
|
|
@app.get("/", include_in_schema=False)
async def read_index_html():
    """Return the front-end page at HTML_FILE_PATH, or a 404 if it is missing."""
    if os.path.exists(HTML_FILE_PATH):
        logger.info(f"Serving {HTML_FILE_PATH}")
        return FileResponse(HTML_FILE_PATH)
    logger.error(f"{HTML_FILE_PATH} not found!")
    raise HTTPException(status_code=404, detail="index.html not found")
|
|
|
|
|
|
|
@app.get("/tables", summary="List Tables", response_model=List[str])
async def list_tables():
    """Lists all tables in the default schema."""
    # Only tables in the 'main' schema are returned, sorted by name.
    query = (
        "SELECT table_name FROM information_schema.tables "
        "WHERE table_schema = 'main' ORDER BY table_name"
    )
    try:
        with get_db_context() as conn:
            names = [name for (name,) in conn.execute(query).fetchall()]
    except duckdb.Error as e:
        logger.error(f"Error listing tables: {e}")
        raise HTTPException(status_code=500, detail=f"Error listing tables: {e}")
    return names
|
|
|
@app.get("/tables/{table_name}/schema", summary="Get Table Schema", response_model=TableSchemaResponse)
async def get_table_schema(
    table_name: str = FastPath(..., description="Name of the table")
):
    """Gets the schema (column names and types) for a specific table."""
    # An invalid identifier is rejected with 400 before any database work happens.
    pragma = f"PRAGMA table_info({safe_identifier(table_name)});"
    try:
        with get_db_context() as conn:
            info_rows = conn.execute(pragma).fetchall()
        # An empty result is treated as "no such table".
        if not info_rows:
            raise HTTPException(status_code=404, detail=f"Table '{table_name}' not found or has no columns.")
        # In each table_info row, index 1 is the column name and index 2 its type.
        return TableSchemaResponse(
            columns=[ColumnDefinition(name=info[1], type=info[2]) for info in info_rows]
        )
    except duckdb.CatalogException:
        raise HTTPException(status_code=404, detail=f"Table '{table_name}' not found.")
    except duckdb.Error as e:
        logger.error(f"Error getting schema for table '{table_name}': {e}")
        raise HTTPException(status_code=400, detail=f"Error getting table schema: {e}")
|
|
|
# Statement keywords that can modify the database, the file system or the server
# configuration. They are matched as whole words, so column names such as
# created_at, updated_at or deleted no longer cause false rejections.
_FORBIDDEN_QUERY_KEYWORDS = (
    'INSERT', 'UPDATE', 'DELETE', 'DROP', 'CREATE', 'ALTER', 'ATTACH', 'DETACH',
    'COPY', 'EXPORT', 'IMPORT', 'INSTALL', 'LOAD', 'SET', 'RESET', 'CHECKPOINT',
    'VACUUM', 'TRUNCATE',
)
_FORBIDDEN_QUERY_RE = re.compile(r"\b(?:" + "|".join(_FORBIDDEN_QUERY_KEYWORDS) + r")\b")
_ALLOWED_QUERY_PREFIXES = ('SELECT', 'WITH', 'PRAGMA', 'SHOW')


@app.post("/query", summary="Execute Read-Only SQL Query")
async def execute_query(query_request: SQLQueryRequest):
    """Executes a provided SQL query (read-only enforced)."""
    # Returns a list of {column: value} dicts.
    # Status codes: 403 if a forbidden keyword appears anywhere (including inside string
    # literals; this is deliberately conservative), 400 if the query does not start with
    # SELECT/WITH/PRAGMA/SHOW or DuckDB rejects it, 500 for unexpected failures.
    # NOTE(review): keyword screening is a heuristic, not a sandbox. Read-only table
    # functions (e.g. ones that read files on the server) remain reachable.
    sql = query_request.sql.strip()
    sql_upper = sql.upper()
    if _FORBIDDEN_QUERY_RE.search(sql_upper):
        raise HTTPException(status_code=403, detail="Only SELECT queries are allowed.")
    if not sql_upper.startswith(_ALLOWED_QUERY_PREFIXES):
        raise HTTPException(status_code=400, detail="Query must start with SELECT, WITH, PRAGMA, or SHOW.")

    try:
        logger.info("Executing user SQL: %s", sql)
        with get_db_context() as conn:
            rel = conn.sql(sql)
            # Statements that produce no result set (e.g. some PRAGMAs) return None.
            if rel is None:
                return []
            description = rel.description
            rows = rel.fetchall()
        return result_to_dict(description, rows)
    except HTTPException:
        # Keep the status/detail set by get_db_context (e.g. a connection failure).
        raise
    except duckdb.Error as e:
        logger.error("Error executing user query: %s", e)
        raise HTTPException(status_code=400, detail=f"Error executing query: {e}") from e
    except Exception as e:
        logger.exception("Unexpected error executing user query")
        raise HTTPException(status_code=500, detail="An unexpected error occurred during query execution.") from e
|
|
|
|
|
@app.post("/tables/{table_name}", summary="Create Table", response_model=ApiResponse, status_code=201)
async def create_table(
    table_name: str = FastPath(..., description="Name of the table to create"),
    schema: CreateTableRequest = ...,
):
    """Creates or replaces a table with the specified schema."""
    # Status codes: 201 on success; 400 for an invalid name, empty column list or a
    # DuckDB error; 500 for anything unexpected.
    quoted_table = safe_identifier(table_name)
    if not schema.columns:
        raise HTTPException(status_code=400, detail="Table must have at least one column.")
    try:
        ddl = "CREATE OR REPLACE TABLE {} ({});".format(quoted_table, generate_column_sql(schema.columns))
        logger.info(f"Executing SQL: {ddl}")
        with get_db_context() as conn:
            conn.execute(ddl)
    except HTTPException:
        # Validation errors from generate_column_sql / get_db_context keep their status.
        raise
    except duckdb.Error as e:
        logger.error(f"Error creating/replacing table '{table_name}': {e}")
        raise HTTPException(status_code=400, detail=f"Error creating/replacing table: {e}")
    except Exception as e:
        logger.error(f"Unexpected error creating/replacing table '{table_name}': {e}")
        raise HTTPException(status_code=500, detail="An unexpected error occurred.")
    return {"message": f"Table '{table_name}' created or replaced successfully."}
|
|
|
|
|
@app.get("/tables/{table_name}", summary="Read Table Data")
async def read_table(
    table_name: str = FastPath(..., description="Name of the table to read from"),
    limit: Optional[int] = 100,
    offset: Optional[int] = 0
):
    """Reads and returns rows from a specified table. Supports limit and offset."""
    # Rows are returned as a list of {column: value} dicts. A negative limit or offset
    # omits that clause, so e.g. limit=-1 returns every row.
    # NOTE(review): there is no ORDER BY, so page boundaries rely on DuckDB's scan order.
    table_name_safe = safe_identifier(table_name)
    sql = f"SELECT * FROM {table_name_safe}"
    params = []
    if limit is not None and limit >= 0:
        sql += " LIMIT ?"
        params.append(limit)
    if offset is not None and offset >= 0:
        sql += " OFFSET ?"
        params.append(offset)
    sql += ";"
    try:
        logger.info("Executing SQL: %s with params: %s", sql, params)
        with get_db_context() as conn:
            rel = conn.sql(sql, params=params)
            description = rel.description
            rows = rel.fetchall()
    except HTTPException:
        # Keep the status/detail set by get_db_context instead of masking it as a generic 500.
        raise
    except duckdb.CatalogException as e:
        raise HTTPException(status_code=404, detail=f"Table '{table_name}' not found.") from e
    except duckdb.Error as e:
        logger.error("Error reading table '%s': %s", table_name, e)
        raise HTTPException(status_code=400, detail=f"Error reading table: {e}") from e
    except Exception as e:
        logger.exception("Unexpected error reading table '%s'", table_name)
        raise HTTPException(status_code=500, detail="An unexpected error occurred.") from e
    return result_to_dict(description, rows)
|
|
|
|
|
|
|
|
|
|
|
@app.post("/tables/{table_name}/rows", summary="Create Rows", response_model=ApiResponse, status_code=201)
async def create_rows(
    table_name: str = FastPath(..., description="Name of the table to insert into"),
    request: CreateRowRequest = ...,
):
    """Insert a batch of rows into an existing table.

    The first row's keys define the column list, and every other row must have
    exactly the same keys in the same order. All rows are inserted in a single
    transaction, so either all of them are stored or none are.
    """
    # Status codes: 201 on success; 400 for invalid input or a DuckDB error;
    # 404 if the table does not exist; 500 for anything unexpected.
    table_name_safe = safe_identifier(table_name)
    if not request.rows:
        raise HTTPException(status_code=400, detail="No rows provided to insert.")

    columns = list(request.rows[0].keys())
    if not columns:
        # Otherwise this would build "INSERT INTO t () VALUES ()", which is invalid SQL.
        raise HTTPException(status_code=400, detail="Rows must contain at least one column.")
    columns_sql = ", ".join(safe_identifier(col) for col in columns)
    placeholders = ", ".join(["?"] * len(columns))
    sql = f"INSERT INTO {table_name_safe} ({columns_sql}) VALUES ({placeholders});"

    params_list = []
    for row_dict in request.rows:
        if list(row_dict.keys()) != columns:
            raise HTTPException(status_code=400, detail="All rows must have the same columns in the same order.")
        params_list.append(list(row_dict.values()))

    try:
        logger.info("Executing SQL: %s for %d rows", sql, len(params_list))
        with get_db_context() as conn:
            # An explicit transaction makes the batch all-or-nothing: a bad row
            # halfway through must not leave the earlier rows committed.
            conn.begin()
            try:
                conn.executemany(sql, params_list)
            except Exception:
                conn.rollback()
                raise
            conn.commit()
    except HTTPException:
        raise
    except duckdb.CatalogException as e:
        raise HTTPException(status_code=404, detail=f"Table '{table_name}' not found.") from e
    except duckdb.Error as e:
        logger.error("Error inserting rows into '%s': %s", table_name, e)
        raise HTTPException(status_code=400, detail=f"Error inserting rows: {e}") from e
    except Exception as e:
        logger.exception("Unexpected error inserting rows into '%s'", table_name)
        raise HTTPException(status_code=500, detail="An unexpected error occurred.") from e
    return {"message": f"Successfully inserted {len(params_list)} rows into '{table_name}'."}
|
|
|
|
|
|
|
|
|
@app.get("/download/table/{table_name}", summary="Download Table as CSV")
async def download_table_csv(
    table_name: str = FastPath(..., description="Name of the table to download")
):
    """Download every row of a table as a CSV attachment that includes a header row.

    The table is read and rendered to CSV before the response starts. A missing table
    therefore yields a real 404 (other database errors 400, anything else 500), rather
    than an "Error: ..." line inside a 200 response that looks like a CSV file.
    """
    table_name_safe = safe_identifier(table_name)
    try:
        with get_db_context() as conn:
            df = conn.execute(f"SELECT * FROM {table_name_safe}").df()
        # The whole CSV is built in memory; the chunking below only keeps individual writes small.
        payload = df.to_csv(index=False).encode("utf-8")
    except HTTPException:
        raise
    except duckdb.CatalogException as e:
        logger.error("Error downloading table '%s': Table not found.", table_name)
        raise HTTPException(status_code=404, detail=f"Table '{table_name}' not found.") from e
    except duckdb.Error as e:
        logger.error("Error downloading table '%s': %s", table_name, e)
        raise HTTPException(status_code=400, detail=f"Could not export table '{table_name}'. {e}") from e
    except Exception as e:
        logger.exception("Unexpected error downloading table '%s'", table_name)
        raise HTTPException(status_code=500, detail="An unexpected error occurred.") from e

    chunk_size = 8192

    async def stream_csv_data():
        for start in range(0, len(payload), chunk_size):
            yield payload[start:start + chunk_size]
            await asyncio.sleep(0)  # yield control to the event loop between chunks

    # HTTP header values must be latin-1. An ASCII fallback plus an RFC 5987 UTF-8
    # filename* keeps table names containing spaces, quotes or non-ASCII characters
    # from producing a broken or unencodable header.
    ascii_name = "".join(
        ch if ch.isascii() and (ch.isalnum() or ch in "-_.") else "_" for ch in table_name
    )
    disposition = (
        f'attachment; filename="{ascii_name}.csv"; '
        f"filename*=UTF-8''{urllib.parse.quote(table_name + '.csv', safe='')}"
    )
    return StreamingResponse(
        stream_csv_data(),
        media_type="text/csv",
        headers={"Content-Disposition": disposition},
    )
|
|
|
|
|
@app.get("/download/database", summary="Download Database File")
async def download_database_file():
    """Download the raw DuckDB database file.

    A best-effort CHECKPOINT runs first, so committed changes still held in the
    write-ahead log are written into the main database file, which is the only
    file served. The download is still taken from a live file: writes made while
    it is streaming can produce an inconsistent copy.
    """
    if not os.path.exists(DATABASE_PATH):
        raise HTTPException(status_code=404, detail="Database file not found.")
    try:
        with get_db_context() as conn:
            conn.execute("CHECKPOINT")
    except (HTTPException, duckdb.Error) as e:
        # Serve the file anyway, as before, but record that it may be missing recent changes.
        logger.warning("CHECKPOINT before database download failed; file may be missing recent changes: %s", e)
    logger.warning("Attempting to download database file. Ensure no active writes are occurring.")
    return FileResponse(
        path=DATABASE_PATH,
        filename=os.path.basename(DATABASE_PATH),
        media_type="application/vnd.duckdb.database",
    )
|
|
|
|
|
@app.get("/health", summary="Health Check", response_model=ApiResponse)
async def health_check():
    """Report whether the API can open the database and run a trivial query.

    Returns 503 with the failure reason if either step fails.
    """
    try:
        with get_db_context() as conn:
            conn.execute("SELECT 1")
    except Exception as e:
        logger.error(f"Health check failed: {e}")
        raise HTTPException(status_code=503, detail=f"Health check failed: {e}")
    return {"message": "API is healthy and database connection is successful."}