# DuckDB-UI / main.py
# Author: amaye15 | Commit: 850182e ("API")
import asyncio
import csv
import io
import logging
import os
import re
from contextlib import contextmanager  # Ensure this is imported
from typing import List, Dict, Any, Optional
from urllib.parse import quote

import duckdb
from fastapi import FastAPI, HTTPException, Request, Path as FastPath, Body
# --- Add CORS ---
from fastapi.middleware.cors import CORSMiddleware
# --- Add FileResponse ---
from fastapi.responses import FileResponse, StreamingResponse
from pydantic import BaseModel, Field
# --- Configuration ---
# Location of the DuckDB database file; override with the DUCKDB_PATH environment variable.
DATABASE_PATH = os.environ.get("DUCKDB_PATH", "data/mydatabase.db")
DATA_DIR = "data"
# index.html is resolved next to this module rather than against the process's
# current working directory, so the root endpoint works wherever the server is launched.
HTML_FILE_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "index.html")
# Ensure the data directory exists. DUCKDB_PATH may point outside DATA_DIR, so
# also create the database file's parent directory when it has one.
# NOTE(review): assumes duckdb.connect() does not create missing parent directories itself.
os.makedirs(DATA_DIR, exist_ok=True)
_db_parent_dir = os.path.dirname(DATABASE_PATH)
if _db_parent_dir:
    os.makedirs(_db_parent_dir, exist_ok=True)
# --- Logging ---
# Configure root logging once at import time (stderr, INFO and above); every
# endpoint below logs through the module-level `logger`.
_LOG_LEVEL = logging.INFO
logging.basicConfig(level=_LOG_LEVEL)
logger: logging.Logger = logging.getLogger(__name__)
# --- FastAPI App ---
# The single application instance; title/description/version feed the generated API docs.
app = FastAPI(
    version="0.1.0",
    title="DuckDB API",
    description="An API to interact with a DuckDB database.",
)
# --- Add CORS Middleware ---
# Allowed origins come from CORS_ALLOW_ORIGINS, a comma-separated list such as
# "https://app.example.com,http://localhost:3000". When it is unset or blank,
# every origin is allowed (the previous hard-coded behaviour).
# Restrict this in a production environment!
# NOTE(review): with allow_origins=["*"] plus allow_credentials=True, Starlette is
# understood to echo the caller's Origin back for credentialed requests, so any
# site could make credentialed calls -- confirm and set CORS_ALLOW_ORIGINS when deploying.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)
# --- Database Connection (using context manager for safety) ---
@contextmanager
def get_db_context():
    """Open a read-write DuckDB connection to DATABASE_PATH for one ``with`` block.

    The connection is always closed when the block exits, whether normally or
    by an exception.

    Yields:
        The open ``duckdb`` connection.

    Raises:
        HTTPException: 500 if the connection itself cannot be opened.

    Errors raised *inside* the caller's ``with`` block (including
    ``duckdb.CatalogException`` and other ``duckdb.Error`` subclasses) are
    deliberately NOT caught here. They propagate unchanged so each endpoint can
    map them to the right status (404, 400, ...). Wrapping the ``yield`` in the
    ``except`` would turn every query error into a 500 "connection error".
    """
    # A missing or zero-byte file means DuckDB will create a fresh database.
    initialize = not os.path.exists(DATABASE_PATH) or os.path.getsize(DATABASE_PATH) == 0
    try:
        conn = duckdb.connect(DATABASE_PATH, read_only=False)  # Allow writes for setup
    except duckdb.Error as e:
        logger.error(f"Database connection error: {e}")
        raise HTTPException(status_code=500, detail=f"Database connection error: {e}") from e
    try:
        if initialize:
            logger.info(f"Database file not found or empty at {DATABASE_PATH}. Initializing.")
            # Optionally create a default table if the DB is new
            # conn.execute("CREATE TABLE IF NOT EXISTS example (id INTEGER, name VARCHAR);")
        yield conn
    finally:
        conn.close()
# --- Pydantic Models ---
class ColumnDefinition(BaseModel):
    """A single table column: its name and its DuckDB type string."""

    # Column name; quoted with safe_identifier before it is placed in SQL.
    name: str
    # DuckDB type text (e.g. "INTEGER", "DECIMAL(10,2)"). generate_column_sql
    # interpolates it into CREATE TABLE SQL; get_table_schema fills it from PRAGMA table_info.
    type: str
class TableSchemaResponse(BaseModel):
    """Response body of GET /tables/{table_name}/schema."""

    # Columns in the order reported by PRAGMA table_info.
    columns: List[ColumnDefinition]
class CreateTableRequest(BaseModel):
    """Request body of POST /tables/{table_name}: the columns of the table to create or replace."""

    # Must be non-empty; create_table rejects an empty list with a 400.
    columns: List[ColumnDefinition]
class CreateRowRequest(BaseModel):
    """Request body of POST /tables/{table_name}/rows: the rows to insert."""

    # One mapping of column name -> value per row; the first row's keys define
    # the column list used for the INSERT.
    rows: List[Dict[str, Any]]
class UpdateRowRequest(BaseModel):
    """Presumed body of an update-rows endpoint (no consumer is visible in this file)."""

    # Column name -> new value.
    updates: Dict[str, Any]
    # NOTE(review): a raw SQL condition string from the client; if a consumer
    # interpolates it into a WHERE clause it is an injection vector -- confirm how it is sanitized.
    condition: str
class DeleteRowRequest(BaseModel):
    """Presumed body of a delete-rows endpoint (no consumer is visible in this file)."""

    # NOTE(review): a raw SQL condition string from the client; if a consumer
    # interpolates it into a WHERE clause it is an injection vector -- confirm how it is sanitized.
    condition: str
class SQLQueryRequest(BaseModel):
    """Request body of POST /query: the SQL text to execute."""

    # Checked for read-only usage by execute_query before it is run.
    sql: str
class ApiResponse(BaseModel):
    """Generic response envelope (response_model of create_table, create_rows and health_check)."""

    # Human-readable outcome message.
    message: str
    # Optional extra payload; the endpoints in this file do not set it.
    details: Optional[Any] = None
# --- Helper Functions ---
def safe_identifier(name: str) -> str:
    """Return *name* as a double-quoted DuckDB identifier.

    The name is always quoted, so reserved words and special characters are
    safe. Any embedded double quote is doubled, which is how SQL escapes a
    quote inside a quoted identifier.

    Raises:
        HTTPException: 400 if *name* is empty or not a string.
    """
    if name and isinstance(name, str):
        return '"' + name.replace('"', '""') + '"'
    raise HTTPException(status_code=400, detail=f"Invalid identifier provided: {name}")
def _validate_column_type(col_name: str, col_type: str) -> None:
    """Reject a column type string that could escape its column definition."""
    if not col_type or not col_type.strip():
        raise HTTPException(status_code=400, detail=f"Missing data type for column: {col_name}")
    # ';' would end the CREATE statement and start another one; '--' and '/*'
    # would comment out the remainder of the generated SQL.
    if any(token in col_type for token in (";", "--", "/*")):
        raise HTTPException(status_code=400, detail=f"Unsupported or potentially invalid data type: {col_type}")
    # A ')' without a matching '(' would close the column list early.
    depth = 0
    for ch in col_type:
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth < 0:
                break
    if depth != 0:
        raise HTTPException(status_code=400, detail=f"Unbalanced parentheses in data type: {col_type}")


def generate_column_sql(columns: List[ColumnDefinition]) -> str:
    """Build the column-definition list of a CREATE TABLE statement.

    Each column becomes ``<quoted name> <type>`` and the entries are joined
    with ", ". The type text is handed to DuckDB as-is, so parameterised types
    (``DECIMAL(10,2)``, ``STRUCT(a INTEGER)``) and constraints
    (``INTEGER PRIMARY KEY``, ``VARCHAR NOT NULL``) keep working, and DuckDB
    does the real type validation. Because that text is still raw SQL, it is
    first screened for the tokens that would let it break out of the column
    list. As a side effect, types containing ';', '--', '/*' or unbalanced
    parentheses inside string literals (e.g. ``DEFAULT ')'``) are refused.

    Raises:
        HTTPException: 400 for an invalid column name or a rejected type string.
    """
    defs = []
    for col in columns:
        col_name_safe = safe_identifier(col.name)
        _validate_column_type(col.name, col.type)
        defs.append(f"{col_name_safe} {col.type}")
    return ", ".join(defs)
def result_to_dict(cursor_description, rows):
    """Zip each result row with the column names from a DB-API style description.

    The first element of every description entry is taken as the column name.
    A missing or empty description (as for statements that return no result
    set) yields an empty list, and the rows are left untouched.
    """
    if cursor_description:
        names = tuple(entry[0] for entry in cursor_description)
        return [dict(zip(names, values)) for values in rows]
    return []
# --- Root Endpoint ---
@app.get("/", include_in_schema=False)  # hidden from the OpenAPI docs
async def read_index_html():
    """Serve the index.html page at HTML_FILE_PATH, or answer 404 if it is missing."""
    if os.path.exists(HTML_FILE_PATH):
        logger.info(f"Serving {HTML_FILE_PATH}")
        return FileResponse(HTML_FILE_PATH)
    logger.error(f"{HTML_FILE_PATH} not found!")
    raise HTTPException(status_code=404, detail="index.html not found")
# --- API Endpoints ---
@app.get("/tables", summary="List Tables", response_model=List[str])
async def list_tables():
    """Return the names in information_schema.tables for schema ``main``, sorted by name.

    NOTE(review): information_schema.tables normally lists views as well;
    filter on table_type if only base tables are wanted.

    Raises:
        HTTPException: 500 if DuckDB reports an error.
    """
    catalog_query = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'main' ORDER BY table_name"
    try:
        with get_db_context() as conn:
            rows = conn.execute(catalog_query).fetchall()
    except duckdb.Error as e:
        logger.error(f"Error listing tables: {e}")
        raise HTTPException(status_code=500, detail=f"Error listing tables: {e}")
    return [name for (name,) in rows]
@app.get("/tables/{table_name}/schema", summary="Get Table Schema", response_model=TableSchemaResponse)
async def get_table_schema(
    table_name: str = FastPath(..., description="Name of the table")
):
    """Describe a table's columns (name and DuckDB type) using PRAGMA table_info.

    Raises:
        HTTPException: 400 for an invalid name or other DuckDB error; 404 when
            the table does not exist or reports no columns.
    """
    pragma_sql = f"PRAGMA table_info({safe_identifier(table_name)});"
    try:
        with get_db_context() as conn:
            info_rows = conn.execute(pragma_sql).fetchall()
        if not info_rows:
            raise HTTPException(status_code=404, detail=f"Table '{table_name}' not found or has no columns.")
        # In each table_info row, position 1 is the column name and position 2 its type.
        return TableSchemaResponse(
            columns=[ColumnDefinition(name=info[1], type=info[2]) for info in info_rows]
        )
    except duckdb.CatalogException:
        raise HTTPException(status_code=404, detail=f"Table '{table_name}' not found.")
    except duckdb.Error as e:
        logger.error(f"Error getting schema for table '{table_name}': {e}")
        raise HTTPException(status_code=400, detail=f"Error getting table schema: {e}")
@app.post("/query", summary="Execute Read-Only SQL Query")
async def execute_query(query_request: SQLQueryRequest):
    """Run one read-only SQL statement and return its rows.

    Accepted statements must:
      * not contain a data-modifying / DDL keyword as a whole word (so columns
        such as ``created_at`` or ``last_update`` are fine);
      * start with SELECT, WITH, PRAGMA or SHOW (PRAGMA/SHOW allow exploration);
      * be a single statement. Trailing semicolons are allowed, but any other ';'
        is refused, including one inside a string literal, because a second
        statement would bypass the leading-keyword check.

    These checks are a best-effort guard, not a sandbox: the connection itself
    is opened read-write.

    Returns:
        A list of ``{column_name: value}`` dicts, one per result row; empty when
        the statement produces no result set.

    Raises:
        HTTPException: 403 for a forbidden keyword; 400 for a disallowed leading
            keyword, multiple statements, or a DuckDB error; 500 for connection
            failures or anything unexpected.
    """
    sql = query_request.sql.strip()
    forbidden_keywords = ['INSERT', 'UPDATE', 'DELETE', 'DROP', 'CREATE', 'ALTER', 'ATTACH', 'DETACH', 'COPY', 'EXPORT', 'IMPORT']
    sql_upper = sql.upper()
    forbidden_pattern = r"\b(?:" + "|".join(forbidden_keywords) + r")\b"
    if re.search(forbidden_pattern, sql_upper):
        raise HTTPException(status_code=403, detail="Only SELECT queries are allowed.")
    if not sql_upper.startswith(('SELECT', 'WITH', 'PRAGMA', 'SHOW')):
        raise HTTPException(status_code=400, detail="Query must start with SELECT, WITH, PRAGMA, or SHOW.")
    # Drop trailing semicolons/whitespace; any ';' still present separates a
    # second statement (e.g. "SELECT 1; INSTALL httpfs").
    statement = re.sub(r"[;\s]+$", "", sql)
    if ";" in statement:
        raise HTTPException(status_code=400, detail="Only a single SQL statement is allowed.")
    try:
        logger.info(f"Executing user SQL: {statement}")
        with get_db_context() as conn:
            # sql() returns a relation whose description is available even for empty results.
            rel = conn.sql(statement)
            if rel is None:
                # Statement produced no result set (e.g. a PRAGMA without output).
                return []
            description = rel.description
            result = rel.fetchall()
        return result_to_dict(description, result)
    except HTTPException:
        # Connection failures from get_db_context already carry a status and detail.
        raise
    except duckdb.Error as e:
        logger.error(f"Error executing user query: {e}")
        raise HTTPException(status_code=400, detail=f"Error executing query: {e}") from e
    except Exception as e:
        logger.error(f"Unexpected error executing user query: {e}")
        raise HTTPException(status_code=500, detail="An unexpected error occurred during query execution.") from e
@app.post("/tables/{table_name}", summary="Create Table", response_model=ApiResponse, status_code=201)
async def create_table(
    table_name: str = FastPath(..., description="Name of the table to create"),
    schema: CreateTableRequest = ...,
):
    """Create the table, replacing any existing table of the same name (CREATE OR REPLACE).

    Raises:
        HTTPException: 400 for an invalid name, an empty column list, a rejected
            column definition or a DuckDB error; 500 for anything unexpected.
    """
    quoted_table = safe_identifier(table_name)
    if not schema.columns:
        raise HTTPException(status_code=400, detail="Table must have at least one column.")
    try:
        ddl = f"CREATE OR REPLACE TABLE {quoted_table} ({generate_column_sql(schema.columns)});"
        logger.info(f"Executing SQL: {ddl}")
        with get_db_context() as conn:
            conn.execute(ddl)
    except HTTPException:
        # Already mapped to a status (validation or connection failure).
        raise
    except duckdb.Error as e:
        logger.error(f"Error creating/replacing table '{table_name}': {e}")
        raise HTTPException(status_code=400, detail=f"Error creating/replacing table: {e}")
    except Exception as e:
        logger.error(f"Unexpected error creating/replacing table '{table_name}': {e}")
        raise HTTPException(status_code=500, detail="An unexpected error occurred.")
    return {"message": f"Table '{table_name}' created or replaced successfully."}
@app.get("/tables/{table_name}", summary="Read Table Data")
async def read_table(
    table_name: str = FastPath(..., description="Name of the table to read from"),
    limit: Optional[int] = 100,
    offset: Optional[int] = 0
):
    """Return rows of a table as a list of ``{column: value}`` dicts.

    Args:
        table_name: Table to read; quoted with safe_identifier.
        limit: Maximum rows to return (default 100). ``None`` or a negative
            value omits the LIMIT clause entirely.
        offset: Rows to skip (default 0). ``None`` or a negative value omits
            the OFFSET clause.

    Raises:
        HTTPException: 404 if the table does not exist; 400 for an invalid name
            or other DuckDB error; 500 for connection failures or anything unexpected.
    """
    table_name_safe = safe_identifier(table_name)
    sql = f"SELECT * FROM {table_name_safe}"
    params = []
    # LIMIT/OFFSET values are bound as parameters, never interpolated.
    if limit is not None and limit >= 0:
        sql += " LIMIT ?"
        params.append(limit)
    if offset is not None and offset >= 0:
        sql += " OFFSET ?"
        params.append(offset)
    sql += ";"
    try:
        logger.info(f"Executing SQL: {sql} with params: {params}")
        with get_db_context() as conn:
            rel = conn.sql(sql, params=params)
            description = rel.description
            result = rel.fetchall()
        return result_to_dict(description, result)
    except HTTPException:
        # Connection failures from get_db_context already carry a status and detail;
        # don't let the generic handler below mask them (matches create_table).
        raise
    except duckdb.CatalogException:
        raise HTTPException(status_code=404, detail=f"Table '{table_name}' not found.")
    except duckdb.Error as e:
        logger.error(f"Error reading table '{table_name}': {e}")
        raise HTTPException(status_code=400, detail=f"Error reading table: {e}")
    except Exception as e:
        logger.error(f"Unexpected error reading table '{table_name}': {e}")
        raise HTTPException(status_code=500, detail="An unexpected error occurred.")
# --- Row Endpoints ---
# Every endpoint that touches the database (create_rows, update_rows, delete_rows,
# the downloads, the health check) should open its connection with
# `with get_db_context() as conn:` instead of the older `for conn in get_db():` loop.
# create_rows below follows this pattern:
@app.post("/tables/{table_name}/rows", summary="Create Rows", response_model=ApiResponse, status_code=201)
async def create_rows(
    table_name: str = FastPath(..., description="Name of the table to insert into"),
    request: CreateRowRequest = ...,
):
    """Insert one or more rows into an existing table.

    The first row's keys define the column list. Every other row must have
    exactly the same set of keys, in any order, and its values are bound in
    the first row's column order. Values are passed as parameters and are
    never interpolated into the SQL text.

    NOTE(review): rows are inserted with executemany and no explicit
    transaction; confirm whether a mid-batch failure can leave earlier rows inserted.

    Raises:
        HTTPException: 400 for no rows, rows without columns, mismatched keys,
            invalid identifiers or DuckDB errors; 404 if the table does not
            exist; 500 for connection failures or anything unexpected.
    """
    table_name_safe = safe_identifier(table_name)
    if not request.rows:
        raise HTTPException(status_code=400, detail="No rows provided to insert.")
    columns = list(request.rows[0].keys())
    if not columns:
        raise HTTPException(status_code=400, detail="Rows must contain at least one column.")
    columns_sql = ", ".join(safe_identifier(col) for col in columns)
    placeholders = ", ".join(["?"] * len(columns))
    sql = f"INSERT INTO {table_name_safe} ({columns_sql}) VALUES ({placeholders});"
    expected_keys = set(columns)
    params_list = []
    for row_dict in request.rows:
        if row_dict.keys() != expected_keys:
            raise HTTPException(status_code=400, detail="All rows must have the same columns.")
        # Bind in the first row's column order, whatever order this row's keys arrived in.
        params_list.append([row_dict[col] for col in columns])
    try:
        logger.info(f"Executing SQL: {sql} for {len(params_list)} rows")
        with get_db_context() as conn:
            conn.executemany(sql, params_list)
        return {"message": f"Successfully inserted {len(params_list)} rows into '{table_name}'."}
    except HTTPException:
        # Connection failures from get_db_context already carry a status and detail.
        raise
    except duckdb.CatalogException:
        raise HTTPException(status_code=404, detail=f"Table '{table_name}' not found.")
    except duckdb.Error as e:
        logger.error(f"Error inserting rows into '{table_name}': {e}")
        raise HTTPException(status_code=400, detail=f"Error inserting rows: {e}")
    except Exception as e:
        logger.error(f"Unexpected error inserting rows into '{table_name}': {e}")
        raise HTTPException(status_code=500, detail="An unexpected error occurred.")
# TODO: update_rows and delete_rows must also use the `with get_db_context() as conn:` pattern (download_table_csv below already does).
# --- Download Endpoints ---
@app.get("/download/table/{table_name}", summary="Download Table as CSV")
async def download_table_csv(
    table_name: str = FastPath(..., description="Name of the table to download")
):
    """Stream a table to the client as a CSV attachment (header row, then data).

    The table is probed before the response is created. A missing table
    therefore produces a real 404, rather than a 200 response whose body is an
    error message. Rows are then fetched in batches and encoded with the
    standard ``csv`` module, so the table is never held in memory as a whole.

    Raises:
        HTTPException: 400 for an invalid name or a DuckDB error while probing;
            404 if the table does not exist.
    """
    table_name_safe = safe_identifier(table_name)
    select_sql = f"SELECT * FROM {table_name_safe}"
    batch_size = 1000  # rows fetched (and emitted) per chunk

    # Fail fast, while an HTTP status can still be chosen.
    try:
        with get_db_context() as conn:
            conn.execute(f"{select_sql} LIMIT 0")
    except duckdb.CatalogException as e:
        logger.error(f"Error downloading table '{table_name}': Table not found.")
        raise HTTPException(status_code=404, detail=f"Table '{table_name}' not found.") from e
    except duckdb.Error as e:
        logger.error(f"Error downloading table '{table_name}': {e}")
        raise HTTPException(status_code=400, detail=f"Could not export table '{table_name}': {e}") from e

    def stream_csv_data():
        # A plain (synchronous) generator: StreamingResponse iterates it in a
        # worker thread, so the blocking DuckDB fetches do not stall the event loop.
        buffer = io.StringIO()
        writer = csv.writer(buffer, lineterminator="\n")  # csv's own default is "\r\n"

        def take_buffer() -> bytes:
            # Return everything written so far as UTF-8 and empty the buffer.
            data = buffer.getvalue().encode("utf-8")
            buffer.seek(0)
            buffer.truncate(0)
            return data

        try:
            with get_db_context() as conn:
                cursor = conn.execute(select_sql)
                writer.writerow([desc[0] for desc in cursor.description])
                yield take_buffer()  # header row, emitted even for an empty table
                while True:
                    rows = cursor.fetchmany(batch_size)
                    if not rows:
                        break
                    writer.writerows(rows)
                    yield take_buffer()
        except Exception:
            # Status and headers are already sent; re-raise so the transfer is
            # aborted instead of appending error text to the CSV body.
            logger.exception(f"Error streaming table '{table_name}' as CSV")
            raise

    # Header values must be latin-1 encodable, so send an ASCII-only fallback
    # filename plus the exact name percent-encoded in filename* (RFC 6266 / 5987).
    ascii_name = "".join(
        ch if ch.isascii() and ch.isprintable() and ch not in '"\\' else "_"
        for ch in table_name
    )
    content_disposition = (
        f'attachment; filename="{ascii_name}.csv"; '
        f"filename*=UTF-8''{quote(table_name + '.csv', safe='')}"
    )
    return StreamingResponse(
        stream_csv_data(),
        media_type="text/csv",
        headers={"Content-Disposition": content_disposition},
    )
@app.get("/download/database", summary="Download Database File")
async def download_database_file():
    """Send the raw DuckDB database file at DATABASE_PATH as an attachment.

    NOTE(review): no access control is applied here, so any caller can fetch
    the whole database. The copy may also be inconsistent if a write is in
    progress; consider gating or snapshotting it.

    Raises:
        HTTPException: 404 if the database file does not exist.
    """
    db_exists = os.path.exists(DATABASE_PATH)
    if db_exists:
        logger.warning("Attempting to download database file. Ensure no active writes are occurring.")
        return FileResponse(
            media_type="application/vnd.duckdb.database",
            filename=os.path.basename(DATABASE_PATH),
            path=DATABASE_PATH,
        )
    raise HTTPException(status_code=404, detail="Database file not found.")
# --- Health Check ---
@app.get("/health", summary="Health Check", response_model=ApiResponse)
async def health_check():
    """Report whether the database can be opened and answer a trivial query.

    Raises:
        HTTPException: 503, carrying the underlying error text, if anything fails.
    """
    try:
        with get_db_context() as probe:
            probe.execute("SELECT 1")
    except Exception as exc:
        logger.error(f"Health check failed: {exc}")
        raise HTTPException(status_code=503, detail=f"Health check failed: {exc}")
    return {"message": "API is healthy and database connection is successful."}