import os
import httpx
import json
from fastapi import FastAPI, Request, HTTPException, Response, Depends
from fastapi.security import APIKeyHeader
from fastapi.responses import StreamingResponse, JSONResponse
import logging
from contextlib import asynccontextmanager
import typing
import itertools  # For key rotation
import asyncio  # For potential sleep during retry

# --- Configuration ---

# --- Logging ---
# Configure logging before loading config so the startup messages below are visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Client Authentication (Proxy Access) ---
# Load allowed client API keys (for clients talking to this proxy)
ALLOWED_API_KEYS_STR = os.getenv("ALLOWED_API_KEYS")
if not ALLOWED_API_KEYS_STR:
    raise ValueError("REQUIRED: ALLOWED_API_KEYS environment variable (comma-separated keys for clients) not set.")
ALLOWED_KEYS = set(key.strip() for key in ALLOWED_API_KEYS_STR.split(',') if key.strip())
if not ALLOWED_KEYS:
    raise ValueError("ALLOWED_API_KEYS must contain at least one non-empty key.")
logging.info(f"Loaded {len(ALLOWED_KEYS)} allowed client API keys.")

# --- Upstream API Configuration ---
# URL to fetch upstream API keys from (one key per line)
UPSTREAM_KEYS_URL = os.getenv("UPSTREAM_KEYS_URL")
# Optional: a single fallback/default upstream key, used if UPSTREAM_KEYS_URL is not set or the fetch fails.
DEFAULT_OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
# Upstream API base URL
OPENAI_API_BASE = os.getenv("OPENAI_API_BASE", "https://freeaichatplayground.com/api/v1")
OPENAI_CHAT_ENDPOINT = f"{OPENAI_API_BASE.rstrip('/')}/chat/completions"
if not UPSTREAM_KEYS_URL and not DEFAULT_OPENAI_API_KEY:
    raise ValueError("REQUIRED: Either UPSTREAM_KEYS_URL or OPENAI_API_KEY environment variable must be set for upstream authentication.")
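
# Illustrative environment setup (all values are placeholders):
#   export ALLOWED_API_KEYS="client-key-1,client-key-2"
#   export UPSTREAM_KEYS_URL="https://example.com/keys.txt"            # optional, one key per line
#   export OPENAI_API_KEY="fallback-upstream-key"                      # optional fallback
#   export OPENAI_API_BASE="https://freeaichatplayground.com/api/v1"   # default shown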

# --- Authentication Dependency (Client -> Proxy) ---
api_key_header_auth = APIKeyHeader(name="Authorization", auto_error=False)

async def verify_api_key(api_key_header: typing.Optional[str] = Depends(api_key_header_auth)):
    """Dependency to verify the client's API key provided to this proxy."""
    if not api_key_header:
        logger.warning("Missing Authorization header from client")
        raise HTTPException(status_code=401, detail="Missing Authorization header")
    parts = api_key_header.split()
    if len(parts) != 2 or parts[0].lower() != "bearer":
        logger.warning("Invalid Authorization header format from client.")
        raise HTTPException(status_code=401, detail="Invalid Authorization header format. Use 'Bearer YOUR_KEY'.")
    client_api_key = parts[1]
    if client_api_key not in ALLOWED_KEYS:
        truncated_key = client_api_key[:4] + "..." + client_api_key[-4:] if len(client_api_key) > 8 else client_api_key
        logger.warning(f"Invalid Client API Key received: {truncated_key}")
        raise HTTPException(status_code=403, detail="Invalid API Key provided")
    logger.info(f"Client authenticated successfully (Key ending: ...{client_api_key[-4:]})")
    return client_api_key

# --- Key Fetching and Rotation Logic ---
async def fetch_upstream_keys(url: str) -> list[str]:
    """Fetches keys from the given URL, one key per line."""
    keys = []
    try:
        async with httpx.AsyncClient(timeout=15.0) as client:  # Use a temporary client
            logger.info(f"Fetching upstream API keys from: {url}")
            response = await client.get(url)
            response.raise_for_status()  # Raise exception for 4xx/5xx status codes
            content = response.text
            keys = [line.strip() for line in content.splitlines() if line.strip()]
            logger.info(f"Successfully fetched {len(keys)} upstream API keys.")
            if not keys:
                logger.warning(f"No valid keys found at {url}. The response was empty or contained only whitespace.")
        return keys
    except httpx.RequestError as e:
        logger.error(f"Error fetching upstream keys from {url}: {e}")
        return []  # Return empty list on fetch error
    except httpx.HTTPStatusError as e:
        logger.error(f"Error fetching upstream keys from {url}: Status {e.response.status_code}")
        logger.error(f"Response body: {e.response.text}")
        return []  # Return empty list on bad status
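
# Illustrative contents of the file served at UPSTREAM_KEYS_URL (placeholder values, one key per line):
#   upstream-key-aaaa
#   upstream-key-bbbb
#   upstream-key-cccc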

# --- HTTP Client and Key Iterator Management (Lifespan) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    # --- Initialize Upstream Key Iterator ---
    upstream_keys = []
    if UPSTREAM_KEYS_URL:
        upstream_keys = await fetch_upstream_keys(UPSTREAM_KEYS_URL)
    if not upstream_keys:
        logger.warning("No upstream keys fetched from URL or URL not provided.")
        if DEFAULT_OPENAI_API_KEY:
            logger.info("Using fallback OPENAI_API_KEY for upstream authentication.")
            upstream_keys = [DEFAULT_OPENAI_API_KEY]
        else:
            # Critical failure: no keys available. Raise to prevent startup;
            # a retry or more graceful shutdown mechanism could be added here instead.
            logger.critical("FATAL: No upstream API keys available (URL fetch failed/empty and no fallback OPENAI_API_KEY). Exiting.")
            raise RuntimeError("Failed to load any upstream API keys. Cannot start service.")

    # Store keys and create the cycling iterator in app.state
    app.state.upstream_api_keys = upstream_keys
    app.state.key_iterator = itertools.cycle(upstream_keys)
    logger.info(f"Initialized key rotation with {len(upstream_keys)} keys.")

    # --- Initialize HTTPX Client ---
    logger.info("Initializing main HTTPX client...")
    timeout = httpx.Timeout(5.0, read=180.0, write=5.0, connect=5.0)
    client = httpx.AsyncClient(timeout=timeout)  # No base_url needed since full URLs are used
    app.state.http_client = client  # Store client in app.state
    logger.info("HTTPX client initialized.")

    yield  # Application runs here

    # --- Cleanup ---
    logger.info("Closing HTTPX client...")
    await app.state.http_client.aclose()
    logger.info("HTTPX client closed.")
    app.state.upstream_api_keys = []  # Clear keys
    app.state.key_iterator = None
    logger.info("Upstream keys cleared.")

# --- FastAPI App ---
app = FastAPI(lifespan=lifespan)

# --- Streaming Helper ---
async def yield_openai_chunks(response: httpx.Response):
    """Asynchronously yields OpenAI-style SSE chunks built from the upstream response stream."""
    logger.info("Starting to stream chunks from upstream...")
    try:
        buffer = ''
        async for chunk in response.aiter_bytes():
            buffer += chunk.decode()
            # Process complete lines in the buffer; the upstream emits lines like `0:"<escaped text>"`.
            while True:
                index = buffer.find("\n")
                if index != -1:
                    content = buffer[:index]
                    buffer = buffer[index + 1:]
                    if content.startswith("0:"):
                        # Strip the `0:` prefix and surrounding quotes, then unescape the JSON string fragment.
                        content = content[2:][1:-1].replace("\\n", "\n").replace('\\"', '"').replace("\\\\", "\\")
                        data = {"id": "123456-456789-123456", "object": "chat.completion.chunk", "choices": [{"delta": {"content": content}, "index": 0, "finish_reason": None}]}
                        yield "data: " + json.dumps(data) + "\n\n"
                else:
                    break  # No complete line left in the buffer; wait for the next chunk
        yield "data: [DONE]\n\n"
    except Exception as e:
        logger.error(f"Error during streaming upstream response: {e}")
    finally:
        await response.aclose()
        logger.info("Upstream streaming response closed.")

# --- Proxy Endpoint ---
# NOTE: the client-facing route path below is an assumption (mirroring the OpenAI API shape);
# adjust it if your clients expect a different path.
@app.post("/v1/chat/completions")
async def proxy_openai_chat(request: Request, _client_key: str = Depends(verify_api_key)):  # Use Depends for auth
    """
    Proxies requests to the configured Chat Completions endpoint AFTER verifying the client API key.
    Uses rotated keys for upstream authentication.
    """
    client: httpx.AsyncClient = request.app.state.http_client
    key_iterator = request.app.state.key_iterator
    if not client or not key_iterator:
        logger.error("HTTPX client or Key Iterator not available (app state issue).")
        raise HTTPException(status_code=503, detail="Service temporarily unavailable")

    # --- Get Next Upstream API Key ---
    try:
        current_upstream_key = next(key_iterator)
        # Log rotation (optional; consider the security implications of logging key info)
        # logger.info(f"Using upstream key ending: ...{current_upstream_key[-4:]}")
    except StopIteration:
        # Should not happen if the lifespan logic is correct and keys were loaded
        logger.error("Upstream key iterator exhausted unexpectedly.")
        raise HTTPException(status_code=500, detail="Internal Server Error: Key rotation failed")
    except Exception as e:
        logger.error(f"Unexpected error getting next key: {e}")
        raise HTTPException(status_code=500, detail="Internal Server Error: Key rotation failed")

    # --- Get Request Data ---
    try:
        request_body = await request.body()
        payload = json.loads(request_body)
        # Remap OpenAI-style sampling parameters into the upstream's nested "config" object
        config = {}
        if payload.get("temperature") is not None:
            config["temperature"] = payload.get("temperature")
            del payload["temperature"]
        if payload.get("max_tokens") is not None:
            config["maxTokens"] = payload.get("max_tokens")
            del payload["max_tokens"]
        if payload.get("top_p") is not None:
            config["topP"] = payload.get("top_p")
            del payload["top_p"]
        if len(config) > 0:
            payload["config"] = config
    except json.JSONDecodeError:
        raise HTTPException(status_code=400, detail="Invalid JSON body")
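    # For illustration (hypothetical values): a client body such as
    #   {"model": "...", "messages": [...], "temperature": 0.7, "max_tokens": 256}
    # is forwarded as
    #   {"model": "...", "messages": [...], "config": {"temperature": 0.7, "maxTokens": 256}}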
    is_streaming = payload.get("stream", False)

    # --- Prepare Upstream Request ---
    upstream_headers = {
        "Content-Type": request.headers.get("Content-Type", "application/json"),
        "Accept": request.headers.get("Accept", "application/json"),
    }

    # --- Upstream Authentication (Using Rotated Key) ---
    # Decide based on the target API (e.g., freeaichatplayground vs. standard OpenAI)
    if "freeaichatplayground.com" in OPENAI_API_BASE:
        logger.debug("Using payload apiKey for upstream authentication (freeaichatplayground specific).")
        payload["apiKey"] = current_upstream_key  # Inject ROTATED key into payload
    else:
        # Default to standard Bearer token authentication for upstream
        logger.debug("Using Authorization header for upstream authentication.")
        upstream_headers["Authorization"] = f"Bearer {current_upstream_key}"  # Use ROTATED key

    if is_streaming and "text/event-stream" not in upstream_headers["Accept"]:
        logger.info("Adding 'Accept: text/event-stream' for streaming request")
        upstream_headers["Accept"] = "text/event-stream, application/json"
    logger.info(f"Forwarding request to {OPENAI_CHAT_ENDPOINT} (Streaming: {is_streaming})")

    # --- Make Request to Upstream ---
    response = None  # Define response here so it is available in the error handlers below
    try:
        req = client.build_request(
            "POST",
            OPENAI_CHAT_ENDPOINT,  # Use the full URL
            json=payload,
            headers=upstream_headers,
        )
        response = await client.send(req, stream=True)

        # Check for immediate errors *before* processing body/stream
        if response.status_code >= 400:
            error_body = await response.aread()  # Read error fully
            await response.aclose()
            logger.error(f"Upstream API returned error: {response.status_code} Key ending: ...{current_upstream_key[-4:]} Body: {error_body.decode()}")
            try:
                detail = json.loads(error_body)
            except json.JSONDecodeError:
                detail = error_body.decode()
            raise HTTPException(status_code=response.status_code, detail=detail)

        # --- Handle Streaming Response ---
        if is_streaming:
            logger.info(f"Received OK streaming response from upstream (Status: {response.status_code}). Piping to client.")
            return StreamingResponse(
                yield_openai_chunks(response),  # Generator handles closing the response
                status_code=response.status_code,
                media_type=response.headers.get("content-type", "text/event-stream"),
            )
        # --- Handle Non-Streaming Response ---
        else:
            logger.info(f"Received OK non-streaming response from upstream (Status: {response.status_code}). Reading full body.")
            response_body = await response.aread()
            await response.aclose()  # Ensure closed
            content_type = response.headers.get("content-type", "application/json")
            return Response(  # Return raw body, preserving the upstream content type
                content=response_body,
                status_code=response.status_code,
                media_type=content_type,
            )
    except httpx.TimeoutException as e:
        logger.error(f"Request to upstream timed out: {e}")
        if response is not None:
            await response.aclose()
        raise HTTPException(status_code=504, detail="Request to upstream API timed out.")
    except httpx.RequestError as e:
        logger.error(f"Error requesting upstream API: {e}")
        if response is not None:
            await response.aclose()
        raise HTTPException(status_code=502, detail=f"Error contacting upstream API: {e}")
    except HTTPException:
        # Re-raise FastAPI HTTPExceptions (like the 4xx check above)
        if response is not None and not response.is_closed:
            await response.aclose()
        raise
    except Exception as e:
        logger.exception("An unexpected error occurred during response processing.")
        if response is not None and not response.is_closed:
            await response.aclose()
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

# --- Health Check Endpoint ---
# NOTE: the route path below is an assumption; adjust as needed.
@app.get("/health")
async def health_check():
    """Simple health check endpoint."""
    # Could add more checks here, e.g., whether the HTTP client is still open
    key_count = len(app.state.upstream_api_keys) if hasattr(app.state, 'upstream_api_keys') else 0
    return {"status": "ok", "upstream_keys_loaded": key_count > 0, "key_count": key_count}

# --- Main Execution Guard ---
if __name__ == "__main__":
    import uvicorn
    # Startup checks are implicitly handled by the config loading at the top of the module
    print("--- Starting FastAPI OpenAI Proxy with Custom Auth & Key Rotation ---")
    print(f"Proxying requests to: {OPENAI_CHAT_ENDPOINT}")
    if UPSTREAM_KEYS_URL:
        print(f"Fetching upstream keys from: {UPSTREAM_KEYS_URL}")
    elif DEFAULT_OPENAI_API_KEY:
        print("Using single OPENAI_API_KEY for upstream.")
    else:
        print("ERROR: No upstream key source configured!")  # Should have failed earlier
    print("Clients must provide a valid API key in the 'Authorization: Bearer <key>' header.")
    print(f"Number of allowed client keys configured: {len(ALLOWED_KEYS)}")
    print("---")
    uvicorn.run(app, host="0.0.0.0", port=7860)
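
# Illustrative client call (assumes the proxy runs locally on port 7860, the /v1/chat/completions
# route added above, and placeholder key/model values):
#   curl http://localhost:7860/v1/chat/completions \
#     -H "Authorization: Bearer client-key-1" \
#     -H "Content-Type: application/json" \
#     -d '{"model": "gpt-4o-mini", "messages": [{"role": "user", "content": "Hello"}], "stream": true}'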