"""FastAPI service exposing a small REST interface over a DuckDB database file.

Endpoints cover listing tables, reading schemas and rows, creating tables,
inserting rows, running read-only ad-hoc SQL, and downloading a table (CSV)
or the raw database file.
"""

import asyncio
import io
import logging
import os
import re
from contextlib import contextmanager
from typing import Any, Dict, List, Optional

import duckdb
from fastapi import Body, FastAPI, HTTPException, Request, Path as FastPath
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, StreamingResponse
from pydantic import BaseModel, Field

# --- Configuration ---
DATABASE_PATH = os.environ.get("DUCKDB_PATH", "data/mydatabase.db")
DATA_DIR = "data"
# Relative path: os.path.exists / FileResponse resolve it against the
# process working directory.
HTML_FILE_PATH = "index.html"

# Ensure data directory exists
os.makedirs(DATA_DIR, exist_ok=True)

# --- Logging ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- FastAPI App ---
app = FastAPI(
    title="DuckDB API",
    description="An API to interact with a DuckDB database.",
    version="0.1.0",
)

# --- CORS Middleware ---
# Allows requests from any origin in this example.
# Restrict this in a production environment!
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)

# Statement keywords rejected by the /query endpoint.  Matched as whole words
# so identifiers such as "created_at" or "updated_at" are not false positives.
_FORBIDDEN_SQL_KEYWORDS = (
    "INSERT", "UPDATE", "DELETE", "DROP", "CREATE", "ALTER",
    "ATTACH", "DETACH", "COPY", "EXPORT", "IMPORT",
)
_FORBIDDEN_SQL_PATTERN = re.compile(
    r"\b(?:" + "|".join(_FORBIDDEN_SQL_KEYWORDS) + r")\b",
    re.IGNORECASE,
)
# PRAGMA and SHOW are allowed for exploration.
_ALLOWED_SQL_PREFIXES = ("SELECT", "WITH", "PRAGMA", "SHOW")


# --- Database Connection ---
@contextmanager
def get_db_context():
    """Yield a read/write DuckDB connection and close it on exit.

    Only a failure to *open* the database is translated into an
    ``HTTPException`` (500).  Errors raised by the caller inside the ``with``
    body propagate unchanged, so each endpoint can map
    ``duckdb.CatalogException`` / ``duckdb.Error`` to its own status code.

    Raises:
        HTTPException: 500 if ``duckdb.connect`` fails.
    """
    # Recorded before connecting, because connecting creates the file.
    initialize = (
        not os.path.exists(DATABASE_PATH)
        or os.path.getsize(DATABASE_PATH) == 0
    )
    try:
        conn = duckdb.connect(DATABASE_PATH, read_only=False)
    except duckdb.Error as e:
        logger.error(f"Database connection error: {e}")
        raise HTTPException(
            status_code=500, detail=f"Database connection error: {e}"
        ) from e

    try:
        if initialize:
            logger.info(
                f"Database file not found or empty at {DATABASE_PATH}. Initializing."
            )
            # Optionally create a default table if the DB is new, e.g.:
            # conn.execute("CREATE TABLE IF NOT EXISTS example (id INTEGER, name VARCHAR);")
        yield conn
    finally:
        conn.close()


# --- Pydantic Models ---
class ColumnDefinition(BaseModel):
    """One column of a table: its name and its DuckDB type string."""

    name: str
    type: str


class TableSchemaResponse(BaseModel):
    """Response body of the table-schema endpoint."""

    columns: List[ColumnDefinition]


class CreateTableRequest(BaseModel):
    """Request body for creating a table."""

    columns: List[ColumnDefinition]


class CreateRowRequest(BaseModel):
    """Request body for inserting rows; each row maps column name to value."""

    rows: List[Dict[str, Any]]


class UpdateRowRequest(BaseModel):
    """Request body for updating rows (no endpoint in this file uses it yet)."""

    updates: Dict[str, Any]
    condition: str


class DeleteRowRequest(BaseModel):
    """Request body for deleting rows (no endpoint in this file uses it yet)."""

    condition: str


class SQLQueryRequest(BaseModel):
    """Request body of the ad-hoc query endpoint."""

    sql: str


class ApiResponse(BaseModel):
    """Generic message envelope returned by mutating endpoints."""

    message: str
    details: Optional[Any] = None


# --- Helper Functions ---
def safe_identifier(name: str) -> str:
    """Quote an identifier for use in DuckDB SQL.

    The name is always wrapped in double quotes, with embedded double quotes
    doubled, so keywords and special characters are safe.

    Raises:
        HTTPException: 400 if ``name`` is empty or not a string.
    """
    if not name or not isinstance(name, str):
        raise HTTPException(
            status_code=400, detail=f"Invalid identifier provided: {name}"
        )
    escaped_name = name.replace('"', '""')
    return f'"{escaped_name}"'


def generate_column_sql(columns: List[ColumnDefinition]) -> str:
    """Build the column-definition list of a CREATE TABLE statement.

    Column names are quoted with ``safe_identifier``.  The type string is
    passed through so DuckDB can validate parametrised types and constraints
    (``DECIMAL(10,2)``, ``VARCHAR NOT NULL`` ...), but it comes from the
    request body, so an empty type or one containing ``;`` is rejected to
    prevent stacked statements.

    Raises:
        HTTPException: 400 for an invalid column name or type.
    """
    defs = []
    for col in columns:
        col_name_safe = safe_identifier(col.name)
        col_type = col.type.strip()
        if not col_type or ";" in col_type:
            raise HTTPException(
                status_code=400,
                detail=f"Unsupported or potentially invalid data type: {col.type}",
            )
        defs.append(f"{col_name_safe} {col_type}")
    return ", ".join(defs)


def result_to_dict(cursor_description, rows):
    """Convert cursor results (description + rows) to a list of dicts.

    Returns an empty list when ``cursor_description`` is falsy, which is the
    case for statements that produce no result set.
    """
    if not cursor_description:
        return []
    column_names = [desc[0] for desc in cursor_description]
    return [dict(zip(column_names, row)) for row in rows]


# --- Root Endpoint ---
@app.get("/", include_in_schema=False)  # hidden from the OpenAPI docs
async def read_index_html():
    """Serve the main index.html file, or 404 if it does not exist."""
    if not os.path.exists(HTML_FILE_PATH):
        logger.error(f"{HTML_FILE_PATH} not found!")
        raise HTTPException(status_code=404, detail="index.html not found")
    logger.info(f"Serving {HTML_FILE_PATH}")
    return FileResponse(HTML_FILE_PATH)


# --- API Endpoints ---
@app.get("/tables", summary="List Tables", response_model=List[str])
async def list_tables():
    """Lists all tables in the default schema."""
    try:
        with get_db_context() as conn:
            tables = conn.execute(
                "SELECT table_name FROM information_schema.tables "
                "WHERE table_schema = 'main' ORDER BY table_name"
            ).fetchall()
        return [table[0] for table in tables]
    except duckdb.Error as e:
        logger.error(f"Error listing tables: {e}")
        raise HTTPException(status_code=500, detail=f"Error listing tables: {e}")


@app.get(
    "/tables/{table_name}/schema",
    summary="Get Table Schema",
    response_model=TableSchemaResponse,
)
async def get_table_schema(
    table_name: str = FastPath(..., description="Name of the table"),
):
    """Gets the schema (column names and types) for a specific table."""
    table_name_safe = safe_identifier(table_name)
    sql = f"PRAGMA table_info({table_name_safe});"
    try:
        with get_db_context() as conn:
            result = conn.execute(sql).fetchall()
        if not result:
            raise HTTPException(
                status_code=404,
                detail=f"Table '{table_name}' not found or has no columns.",
            )
        # The code reads the column name from index 1 and the type from index 2.
        columns = [ColumnDefinition(name=row[1], type=row[2]) for row in result]
        return TableSchemaResponse(columns=columns)
    except duckdb.CatalogException:
        raise HTTPException(status_code=404, detail=f"Table '{table_name}' not found.")
    except duckdb.Error as e:
        logger.error(f"Error getting schema for table '{table_name}': {e}")
        raise HTTPException(status_code=400, detail=f"Error getting table schema: {e}")


@app.post("/query", summary="Execute Read-Only SQL Query")
async def execute_query(query_request: SQLQueryRequest):
    """Executes a provided SQL query (read-only enforced).

    The check is a best-effort keyword filter, not a SQL parser: a query is
    rejected with 403 if it contains a forbidden keyword as a whole word
    (even inside a string literal), and with 400 if it does not start with
    SELECT, WITH, PRAGMA or SHOW.
    """
    sql = query_request.sql.strip()
    if _FORBIDDEN_SQL_PATTERN.search(sql):
        raise HTTPException(status_code=403, detail="Only SELECT queries are allowed.")
    if not sql.upper().startswith(_ALLOWED_SQL_PREFIXES):
        raise HTTPException(
            status_code=400,
            detail="Query must start with SELECT, WITH, PRAGMA, or SHOW.",
        )

    try:
        logger.info(f"Executing user SQL: {sql}")
        with get_db_context() as conn:
            # sql() returns a relation, which carries a description even for
            # empty results; it returns None when there is no result set.
            rel = conn.sql(sql)
            if rel is None:
                return []
            description = rel.description
            result = rel.fetchall()
        return result_to_dict(description, result)
    except HTTPException:
        # Keep the connection-failure response from get_db_context intact.
        raise
    except duckdb.Error as e:
        logger.error(f"Error executing user query: {e}")
        raise HTTPException(status_code=400, detail=f"Error executing query: {e}")
    except Exception as e:
        logger.error(f"Unexpected error executing user query: {e}")
        raise HTTPException(
            status_code=500,
            detail="An unexpected error occurred during query execution.",
        )


@app.post(
    "/tables/{table_name}",
    summary="Create Table",
    response_model=ApiResponse,
    status_code=201,
)
async def create_table(
    table_name: str = FastPath(..., description="Name of the table to create"),
    schema: CreateTableRequest = ...,
):
    """Creates or replaces a table with the specified schema."""
    table_name_safe = safe_identifier(table_name)
    if not schema.columns:
        raise HTTPException(
            status_code=400, detail="Table must have at least one column."
        )
    try:
        columns_sql = generate_column_sql(schema.columns)
        sql = f"CREATE OR REPLACE TABLE {table_name_safe} ({columns_sql});"
        logger.info(f"Executing SQL: {sql}")
        with get_db_context() as conn:
            conn.execute(sql)
        return {"message": f"Table '{table_name}' created or replaced successfully."}
    except HTTPException:
        raise
    except duckdb.Error as e:
        logger.error(f"Error creating/replacing table '{table_name}': {e}")
        raise HTTPException(
            status_code=400, detail=f"Error creating/replacing table: {e}"
        )
    except Exception as e:
        logger.error(f"Unexpected error creating/replacing table '{table_name}': {e}")
        raise HTTPException(status_code=500, detail="An unexpected error occurred.")


@app.get("/tables/{table_name}", summary="Read Table Data")
async def read_table(
    table_name: str = FastPath(..., description="Name of the table to read from"),
    limit: Optional[int] = 100,
    offset: Optional[int] = 0,
):
    """Reads and returns rows from a specified table. Supports limit and offset.

    A ``None`` or negative ``limit`` / ``offset`` is ignored rather than
    rejected.
    """
    table_name_safe = safe_identifier(table_name)
    sql = f"SELECT * FROM {table_name_safe}"
    params = []
    if limit is not None and limit >= 0:
        sql += " LIMIT ?"
        params.append(limit)
    if offset is not None and offset >= 0:
        sql += " OFFSET ?"
        params.append(offset)
    sql += ";"

    try:
        logger.info(f"Executing SQL: {sql} with params: {params}")
        with get_db_context() as conn:
            rel = conn.sql(sql, params=params)
            description = rel.description
            result = rel.fetchall()
        return result_to_dict(description, result)
    except HTTPException:
        raise
    except duckdb.CatalogException:
        raise HTTPException(status_code=404, detail=f"Table '{table_name}' not found.")
    except duckdb.Error as e:
        logger.error(f"Error reading table '{table_name}': {e}")
        raise HTTPException(status_code=400, detail=f"Error reading table: {e}")
    except Exception as e:
        logger.error(f"Unexpected error reading table '{table_name}': {e}")
        raise HTTPException(status_code=500, detail="An unexpected error occurred.")


# NOTE: UpdateRowRequest / DeleteRowRequest are defined above, but this file
# contains no update or delete endpoints.  Any added later should acquire
# their connection with `with get_db_context() as conn:`.
@app.post(
    "/tables/{table_name}/rows",
    summary="Create Rows",
    response_model=ApiResponse,
    status_code=201,
)
async def create_rows(
    table_name: str = FastPath(..., description="Name of the table to insert into"),
    request: CreateRowRequest = ...,
):
    """Insert one or more rows into a table.

    The column list is taken from the first row; every other row must have
    exactly the same keys in the same order.  Values are bound as parameters.
    """
    table_name_safe = safe_identifier(table_name)
    if not request.rows:
        raise HTTPException(status_code=400, detail="No rows provided to insert.")

    columns = list(request.rows[0].keys())
    columns_sql = ", ".join(safe_identifier(col) for col in columns)
    placeholders = ", ".join(["?"] * len(columns))
    sql = f"INSERT INTO {table_name_safe} ({columns_sql}) VALUES ({placeholders});"

    params_list = []
    for row_dict in request.rows:
        if list(row_dict.keys()) != columns:
            raise HTTPException(
                status_code=400,
                detail="All rows must have the same columns in the same order.",
            )
        params_list.append(list(row_dict.values()))

    try:
        logger.info(f"Executing SQL: {sql} for {len(params_list)} rows")
        with get_db_context() as conn:
            conn.executemany(sql, params_list)
        return {
            "message": f"Successfully inserted {len(params_list)} rows into '{table_name}'."
        }
    except HTTPException:
        raise
    except duckdb.CatalogException:
        raise HTTPException(status_code=404, detail=f"Table '{table_name}' not found.")
    except duckdb.Error as e:
        logger.error(f"Error inserting rows into '{table_name}': {e}")
        raise HTTPException(status_code=400, detail=f"Error inserting rows: {e}")
    except Exception as e:
        logger.error(f"Unexpected error inserting rows into '{table_name}': {e}")
        raise HTTPException(status_code=500, detail="An unexpected error occurred.")


# --- Download Endpoints ---
@app.get("/download/table/{table_name}", summary="Download Table as CSV")
async def download_table_csv(
    table_name: str = FastPath(..., description="Name of the table to download"),
):
    """Return the full contents of a table as a CSV attachment.

    The table is read and converted to CSV *before* the streaming response is
    created, so a missing table or a query failure produces a proper HTTP
    error (404 / 400 / 500) instead of error text inside a 200 CSV body.
    """
    table_name_safe = safe_identifier(table_name)
    try:
        with get_db_context() as conn:
            # .df() materialises the result as a pandas DataFrame.
            df = conn.execute(f"SELECT * FROM {table_name_safe}").df()
        csv_text = df.to_csv(index=False)
    except HTTPException:
        raise
    except duckdb.CatalogException:
        logger.error(f"Error downloading table '{table_name}': Table not found.")
        raise HTTPException(status_code=404, detail=f"Table '{table_name}' not found.")
    except duckdb.Error as e:
        logger.error(f"Error downloading table '{table_name}': {e}")
        raise HTTPException(
            status_code=400,
            detail=f"Error: Could not export table '{table_name}'. {e}",
        )
    except Exception as e:
        logger.error(f"Unexpected error downloading table '{table_name}': {e}")
        raise HTTPException(status_code=500, detail="An unexpected error occurred.")

    async def stream_csv_data():
        """Yield the CSV text in fixed-size UTF-8 encoded chunks."""
        chunk_size = 8192
        for start in range(0, len(csv_text), chunk_size):
            yield csv_text[start:start + chunk_size].encode("utf-8")
            # Hand control back to the event loop between chunks.
            await asyncio.sleep(0)

    # The table name comes from the URL; keep only characters that are safe
    # inside a header value / file name.
    safe_filename = re.sub(r"[^A-Za-z0-9_.-]", "_", table_name) or "table"
    return StreamingResponse(
        stream_csv_data(),
        media_type="text/csv",
        headers={"Content-Disposition": f"attachment; filename={safe_filename}.csv"},
    )


@app.get("/download/database", summary="Download Database File")
async def download_database_file():
    """Return the raw DuckDB database file, or 404 if it does not exist."""
    if not os.path.exists(DATABASE_PATH):
        raise HTTPException(status_code=404, detail="Database file not found.")
    logger.warning(
        "Attempting to download database file. Ensure no active writes are occurring."
    )
    return FileResponse(
        path=DATABASE_PATH,
        filename=os.path.basename(DATABASE_PATH),
        media_type="application/vnd.duckdb.database",
    )


# --- Health Check ---
@app.get("/health", summary="Health Check", response_model=ApiResponse)
async def health_check():
    """Report 200 if a trivial query succeeds, otherwise 503."""
    try:
        with get_db_context() as conn:
            conn.execute("SELECT 1")
        return {"message": "API is healthy and database connection is successful."}
    except Exception as e:
        logger.error(f"Health check failed: {e}")
        raise HTTPException(status_code=503, detail=f"Health check failed: {e}")