"""FastAPI service exposing FLUX.1-schnell text-to-image generation.

Endpoints:
    GET  /                API information.
    GET  /health          Liveness check.
    POST /generate        Generate an image; returns JSON with a base64 PNG.
    POST /generate/image  Generate an image; returns raw PNG bytes with the
                          seed in the ``X-Generated-Seed`` response header.
"""
import base64
import io
import logging
import os
import random
import threading
from typing import Tuple

import numpy as np
import torch
import uvicorn
from diffusers import DiffusionPipeline
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import Response
from PIL import Image
from pydantic import BaseModel, Field

logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(title="FLUX.1 Image Generation API", version="1.0.0")

# Configuration
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048

# Load the model once at import time; every request shares this pipeline.
print("Loading FLUX.1 model...")
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
print("Model loaded successfully!")

# Generation runs in worker threads (see _generate_from_request), but there is
# only one shared `pipe`. This lock lets only one generation use it at a time.
# That matches the serialization the handlers previously got by blocking the
# event loop.
_GENERATION_LOCK = threading.Lock()


# Request models
class ImageGenerationRequest(BaseModel):
    """Parameters accepted by both POST generation endpoints."""

    prompt: str = Field(..., description="Text prompt for image generation")
    seed: int = Field(default=42, ge=0, le=MAX_SEED, description="Random seed for generation")
    randomize_seed: bool = Field(default=False, description="Whether to randomize the seed")
    width: int = Field(default=1024, ge=256, le=MAX_IMAGE_SIZE, description="Image width")
    height: int = Field(default=1024, ge=256, le=MAX_IMAGE_SIZE, description="Image height")
    num_inference_steps: int = Field(default=4, ge=1, le=50, description="Number of inference steps")
    # NOTE(review): no endpoint reads return_format; the output format is chosen
    # by the route (/generate -> base64 JSON, /generate/image -> PNG bytes).
    return_format: str = Field(default="base64", description="Return format: 'base64' or 'bytes'")


class ImageGenerationResponse(BaseModel):
    """JSON body returned by POST /generate."""

    image: str = Field(..., description="Generated image in base64 format")
    seed: int = Field(..., description="Seed used for generation")
    success: bool = Field(default=True, description="Whether generation was successful")


# Helper functions
def _encode_png(image: Image.Image) -> bytes:
    """Return *image* encoded as PNG bytes."""
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return buffer.getvalue()


def pil_to_base64(image: Image.Image) -> str:
    """Convert a PIL image to a base64-encoded PNG string."""
    return base64.b64encode(_encode_png(image)).decode()


def generate_image(
    prompt: str,
    seed: int = 42,
    randomize_seed: bool = False,
    width: int = 1024,
    height: int = 1024,
    num_inference_steps: int = 4,
) -> Tuple[Image.Image, int]:
    """Generate one image with the shared FLUX.1 pipeline.

    This call blocks until the image is ready. Call it from a worker thread,
    not directly on the event loop.

    Args:
        prompt: Text prompt.
        seed: Seed for the torch generator. Ignored when *randomize_seed* is set.
        randomize_seed: Draw a fresh seed in ``[0, MAX_SEED]`` instead of using *seed*.
        width: Requested width. Rounded down to a multiple of 32.
        height: Requested height. Rounded down to a multiple of 32.
        num_inference_steps: Number of denoising steps passed to the pipeline.

    Returns:
        ``(image, seed)``, where *seed* is the value actually used.

    Raises:
        HTTPException: Status 500 if the pipeline call fails.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Ensure width and height are multiples of 32
    width = (width // 32) * 32
    height = (height // 32) * 32

    generator = torch.Generator().manual_seed(seed)
    try:
        with _GENERATION_LOCK:
            image = pipe(
                prompt=prompt,
                width=width,
                height=height,
                num_inference_steps=num_inference_steps,
                generator=generator,
                guidance_scale=0.0,
            ).images[0]
    except Exception as e:
        logger.exception("Image generation failed")
        raise HTTPException(status_code=500, detail=f"Image generation failed: {str(e)}") from e
    return image, seed


async def _generate_from_request(request: ImageGenerationRequest) -> Tuple[Image.Image, int]:
    """Run generate_image for *request* in a worker thread so the event loop stays responsive."""
    return await run_in_threadpool(
        generate_image,
        prompt=request.prompt,
        seed=request.seed,
        randomize_seed=request.randomize_seed,
        width=request.width,
        height=request.height,
        num_inference_steps=request.num_inference_steps,
    )


# API endpoints
@app.get("/")
async def root():
    """Root endpoint with API information"""
    return {
        "message": "FLUX.1 Image Generation API",
        "version": "1.0.0",
        "model": "black-forest-labs/FLUX.1-schnell",
        "endpoints": {
            "generate": "/generate",
            "generate_image": "/generate/image",
            "health": "/health",
        },
    }


@app.get("/health")
async def health_check():
    """Health check endpoint.

    model_loaded is always True here: the pipeline is loaded at import time,
    so the app could not be serving requests if that load had failed.
    """
    return {"status": "healthy", "device": device, "model_loaded": True}


@app.post("/generate", response_model=ImageGenerationResponse)
async def generate_image_endpoint(request: ImageGenerationRequest):
    """Generate an image and return it as a base64 PNG together with the seed used."""
    try:
        image, used_seed = await _generate_from_request(request)
        # PNG encoding is CPU work too; keep it off the event loop.
        image_base64 = await run_in_threadpool(pil_to_base64, image)
    except HTTPException:
        # Already carries the intended status and detail. Re-wrapping it would
        # turn the detail into "500: <detail>".
        raise
    except Exception as e:
        logger.exception("Unexpected error in /generate")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return ImageGenerationResponse(
        image=image_base64,
        seed=used_seed,
        success=True,
    )


@app.post("/generate/image")
async def generate_image_bytes(request: ImageGenerationRequest):
    """Generate an image and return raw PNG bytes; the seed used is sent in X-Generated-Seed."""
    try:
        image, used_seed = await _generate_from_request(request)
        png_bytes = await run_in_threadpool(_encode_png, image)
    except HTTPException:
        # Already carries the intended status and detail; propagate unchanged.
        raise
    except Exception as e:
        logger.exception("Unexpected error in /generate/image")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return Response(
        content=png_bytes,
        media_type="image/png",
        headers={"X-Generated-Seed": str(used_seed)},
    )


if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)