#!/usr/bin/env python3
"""
ChromaDB Auth Proxy (robust passthrough)

- Bearer auth at the edge
- Streams/buffers appropriately
- Preserves Content-Type and the upstream body bytes (no JSON re-serialization)
- Reuses a single AsyncClient (HTTP/2, pooled)
- Filters hop-by-hop headers in both directions
- Maps network errors to 502/504
"""

import asyncio
import logging
import os
import secrets
import time
from contextlib import asynccontextmanager
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Tuple

import httpx
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Request
from fastapi.responses import Response, StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
import uvicorn

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# -------------------------
# Configuration
# -------------------------
CHROMA_HOST = os.getenv("CHROMA_HOST", "localhost")
CHROMA_PORT = int(os.getenv("CHROMA_PORT", "8001"))
PROXY_PORT = int(os.getenv("PROXY_PORT", "7860"))

# Fallback token intended for local testing only; a warning is logged at
# startup while it is in effect. Set CHROMA_AUTH_TOKEN for real deployments.
_DEFAULT_AUTH_TOKEN = "test_token_123"
AUTH_TOKEN = os.getenv("CHROMA_AUTH_TOKEN", _DEFAULT_AUTH_TOKEN)

# Timeout configuration (in seconds)
TIMEOUT_CONNECT = 10.0
TIMEOUT_READ = 60.0 * 8
TIMEOUT_WRITE = 60.0 * 2
TIMEOUT_POOL = None  # no limit on waiting for a free pooled connection

# -------------------------
# Security
# -------------------------
# HTTPBearer (auto_error defaults to True) rejects requests that lack a Bearer
# Authorization header before verify_token is reached.
security = HTTPBearer()


# -------------------------
# Lifespan management
# -------------------------
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Manage application lifespan - startup and shutdown.

    Warns when the built-in default token is in use, and always closes the
    shared upstream client on shutdown so pooled connections are released.
    """
    logger.info("🚀 Starting ChromaDB Auth Proxy lifespan")
    if AUTH_TOKEN == _DEFAULT_AUTH_TOKEN:
        logger.warning(
            "CHROMA_AUTH_TOKEN is not set; using the built-in default token. "
            "Set CHROMA_AUTH_TOKEN before exposing this proxy."
        )
    try:
        yield
    finally:
        logger.info("🛑 Shutting down ChromaDB Auth Proxy")
        # `_client` is defined further down the module; it exists by the time
        # the lifespan runs.
        await _client.aclose()


app = FastAPI(title="ChromaDB Auth Proxy", lifespan=lifespan)


# Registered before the catch-all route, so these GET endpoints are served
# without authentication.
@app.get("/")
async def root():
    return {"status": "ok", "service": "chromadb-auth-proxy"}


@app.get("/health")
async def health():
    return {"status": "healthy", "service": "chromadb-auth-proxy"}


async def verify_token(
    credentials: HTTPAuthorizationCredentials = Depends(security),
) -> HTTPAuthorizationCredentials:
    """Reject requests whose bearer token does not equal AUTH_TOKEN.

    The comparison uses ``secrets.compare_digest`` on UTF-8 bytes so response
    timing does not reveal how much of the token matched (and non-ASCII
    tokens do not raise ``TypeError``).

    Raises:
        HTTPException: 401 with ``WWW-Authenticate: Bearer`` on mismatch.
    """
    if not secrets.compare_digest(
        credentials.credentials.encode("utf-8"), AUTH_TOKEN.encode("utf-8")
    ):
        raise HTTPException(
            status_code=401,
            detail="Invalid authentication token",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return credentials


# -------------------------
# HTTP client (shared)
# -------------------------
# Increased timeouts for large operations (collection deletion with 200k+ docs)
_client = httpx.AsyncClient(
    http2=True,
    timeout=httpx.Timeout(
        connect=TIMEOUT_CONNECT,
        read=TIMEOUT_READ,
        write=TIMEOUT_WRITE,
        pool=TIMEOUT_POOL,
    ),
    limits=httpx.Limits(max_keepalive_connections=20, max_connections=100),
)

# Hop-by-hop headers we should not forward (in either direction)
HOP_BY_HOP = {
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
}

# Request headers never forwarded upstream:
# - host: httpx sets it from the upstream URL
# - authorization: carries the proxy's own token, not a Chroma credential
# - content-length: httpx recomputes it from the body actually sent; the
#   client's value is wrong whenever the body is not forwarded and would make
#   upstream wait for bytes that never arrive
SKIP_REQUEST_HEADERS = {"host", "authorization", "content-length"}

# Response headers to pass through (allow-list)
PASS_HEADERS = {
    "content-type",
    "cache-control",
    "etag",
    "last-modified",
    "expires",
    "vary",
    "location",
    "content-disposition",
    "content-encoding",
    "x-chroma-trace-id",
}

# Statuses that never carry a response body.
NO_BODY_STATUSES = {204, 304}


def _filter_req_headers(request: Request) -> List[Tuple[str, str]]:
    """Build the header list forwarded to Chroma from the client request.

    Drops hop-by-hop headers, any extra headers the client named in its
    ``Connection`` header, and SKIP_REQUEST_HEADERS. Returns a list of pairs
    (not a dict) so repeated headers are all preserved.
    """
    connection_tokens = {
        token.strip().lower()
        for token in request.headers.get("connection", "").split(",")
        if token.strip()
    }
    out: List[Tuple[str, str]] = []
    for k, v in request.headers.items():
        kl = k.lower()
        if kl in HOP_BY_HOP or kl in SKIP_REQUEST_HEADERS or kl in connection_tokens:
            continue
        out.append((k, v))
    return out


def _filter_resp_headers(upstream: httpx.Response) -> Dict[str, str]:
    """Drop hop-by-hop and computed headers; keep the allow-listed ones."""
    out: Dict[str, str] = {}
    for k, v in upstream.headers.items():
        kl = k.lower()
        if kl in HOP_BY_HOP:
            continue
        if kl in PASS_HEADERS:
            out[k] = v
    return out


async def _relay_raw(upstream: httpx.Response) -> AsyncIterator[bytes]:
    """Yield the upstream body exactly as received, then close the response.

    The generator owns ``upstream``: the ``finally`` releases the pooled
    connection whether streaming completes, fails, or is abandoned.
    """
    chunk_count = 0
    total_bytes = 0
    try:
        async for chunk in upstream.aiter_raw():
            if chunk:
                chunk_count += 1
                total_bytes += len(chunk)
                yield chunk
                # be nice to the event loop
                await asyncio.sleep(0)
        logger.info(f"   📤 Streamed {chunk_count} chunks, {total_bytes} bytes")
    finally:
        await upstream.aclose()


@app.api_route(
    "/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"]
)
async def proxy_request(request: Request, path: str, _=Depends(verify_token)):
    """Forward an authenticated request to Chroma and relay its response.

    Behaviour by response type:
    - HEAD requests and 204/304 responses are returned without a body.
    - ``application/json`` responses are buffered and returned byte-for-byte.
      Raw bytes are used, so the body matches any forwarded Content-Encoding.
    - Anything else is streamed chunk by chunk. The StreamingResponse then
      owns closing the upstream response.

    Errors raised while setting up the upstream exchange are mapped as follows:
    - timeouts become 504;
    - connect and transport errors become 502;
    - anything else becomes 500.

    Errors during an already-started stream surface from the stream itself.
    """
    start_time = time.perf_counter()
    target_url = f"http://{CHROMA_HOST}:{CHROMA_PORT}/{path}"
    logger.info(f"➡️  {request.method} /{path}")

    # Special logging for DELETE operations
    if request.method == "DELETE":
        logger.warning(
            f"⚠️  DELETE operation - may take up to {int(TIMEOUT_READ)}s for large collections"
        )

    # Keep repeated keys (e.g. ?ids=a&ids=b); dict() would keep only the last.
    params = request.query_params.multi_items()
    fwd_headers = _filter_req_headers(request)

    # Only read body for write-ish methods
    body: Optional[bytes] = None
    if request.method in {"POST", "PUT", "PATCH"}:
        body = await request.body()
        logger.info(f"   Request body size: {len(body)} bytes")

    upstream: Optional[httpx.Response] = None
    # Set once a StreamingResponse has taken over closing `upstream`.
    handed_off = False
    try:
        upstream_request = _client.build_request(
            method=request.method,
            url=target_url,
            params=params,
            headers=fwd_headers,
            content=body,
        )
        upstream_start = time.perf_counter()
        # stream=True returns after the response headers arrive; the body
        # stays unread until we consume it, and we must aclose() it ourselves.
        upstream = await _client.send(upstream_request, stream=True)
        upstream_time = time.perf_counter() - upstream_start

        status = upstream.status_code
        resp_headers = _filter_resp_headers(upstream)
        logger.info(f"   ✅ Upstream response: {status} (took {upstream_time:.2f}s)")

        # HEAD / 204 / 304: no body
        if request.method == "HEAD" or status in NO_BODY_STATUSES:
            total_time = time.perf_counter() - start_time
            logger.info(
                f"   📤 Returning HEAD/204/304 response (total: {total_time:.2f}s)"
            )
            return Response(status_code=status, headers=resp_headers)

        ctype = upstream.headers.get("content-type", "")

        # If JSON, buffer and pass the raw bytes through unchanged. aread()
        # would decompress them while Content-Encoding is still forwarded.
        if ctype.startswith("application/json"):
            read_start = time.perf_counter()
            data = b"".join([chunk async for chunk in upstream.aiter_raw()])
            read_time = time.perf_counter() - read_start
            total_time = time.perf_counter() - start_time
            logger.info(
                f"   📤 Returning JSON response: {len(data)} bytes (json: {read_time:.2f}s, total: {total_time:.2f}s)"
            )
            return Response(
                content=data,
                status_code=status,
                headers=resp_headers,
                media_type=ctype,
            )

        # Otherwise stream raw chunks. Closing happens in _relay_raw's finally
        # and, as a fallback (e.g. the stream was cancelled on client
        # disconnect), in a background task that runs after the response.
        cleanup = BackgroundTasks()
        cleanup.add_task(upstream.aclose)
        response = StreamingResponse(
            _relay_raw(upstream),
            status_code=status,
            headers=resp_headers,
            media_type=ctype or None,
            background=cleanup,
        )
        handed_off = True
        return response

    except httpx.ConnectTimeout as e:
        total_time = time.perf_counter() - start_time
        logger.error(f"   ❌ Connect timeout after {total_time:.2f}s")
        raise HTTPException(
            status_code=504, detail="Chroma upstream connect timeout"
        ) from e
    except httpx.ReadTimeout as e:
        total_time = time.perf_counter() - start_time
        logger.error(f"   ❌ Read timeout after {total_time:.2f}s")
        raise HTTPException(
            status_code=504, detail="Chroma upstream read timeout"
        ) from e
    except httpx.TimeoutException as e:
        # Write/pool timeouts: upstream was too slow, so 504 rather than 502.
        total_time = time.perf_counter() - start_time
        logger.error(
            f"   ❌ Upstream timeout ({type(e).__name__}) after {total_time:.2f}s"
        )
        raise HTTPException(
            status_code=504, detail=f"Chroma upstream timeout: {type(e).__name__}"
        ) from e
    except httpx.ConnectError as e:
        total_time = time.perf_counter() - start_time
        logger.error(f"   ❌ Connect error after {total_time:.2f}s: {e}")
        raise HTTPException(
            status_code=502, detail=f"Chroma upstream connect error: {e}"
        ) from e
    except httpx.TransportError as e:
        total_time = time.perf_counter() - start_time
        logger.error(f"   ❌ Transport error after {total_time:.2f}s: {e}")
        raise HTTPException(
            status_code=502, detail=f"Chroma upstream transport error: {e}"
        ) from e
    except Exception as e:
        total_time = time.perf_counter() - start_time
        logger.exception(f"   ❌ Unexpected error after {total_time:.2f}s: {e}")
        raise HTTPException(status_code=500, detail=f"Internal proxy error: {e}") from e
    finally:
        # Release the pooled connection on every path except a successful
        # hand-off to the StreamingResponse.
        if upstream is not None and not handed_off:
            await upstream.aclose()


if __name__ == "__main__":
    print("🚀 Starting ChromaDB Auth Proxy")
    print(f"   Proxy URL: http://0.0.0.0:{PROXY_PORT}")
    print(f"   ChromaDB URL: http://{CHROMA_HOST}:{CHROMA_PORT}")
    print(
        f"   Timeouts: connect={int(TIMEOUT_CONNECT)}s, read={int(TIMEOUT_READ)}s, write={int(TIMEOUT_WRITE)}s"
    )
    print("   Logging: INFO level")
    logger.info("ChromaDB Auth Proxy starting up")
    uvicorn.run(app, host="0.0.0.0", port=PROXY_PORT)