chetanganatra commited on
Commit
5f12973
·
verified ·
1 Parent(s): d5a989f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +155 -0
app.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from fastapi.responses import Response
3
+ from pydantic import BaseModel, Field
4
+ import numpy as np
5
+ import random
6
+ import torch
7
+ from diffusers import DiffusionPipeline
8
+ import io
9
+ import base64
10
+ from PIL import Image
11
+ import uvicorn
12
+ import os
13
+
14
# Initialize FastAPI app
app = FastAPI(title="FLUX.1 Image Generation API", version="1.0.0")

# Configuration
# bfloat16 halves memory versus float32; works on modern CUDA GPUs and CPU.
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
# Largest user-suppliable seed (fits in a signed 32-bit int).
MAX_SEED = np.iinfo(np.int32).max
# Upper bound enforced on requested width/height (pixels).
MAX_IMAGE_SIZE = 2048

# Load the model once at import time so every request reuses the same pipeline.
# NOTE(review): from_pretrained downloads the weights on first run — assumes
# network access and enough RAM/VRAM for the schnell checkpoint.
print("Loading FLUX.1 model...")
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
print("Model loaded successfully!")
27
+
28
# Request models
class ImageGenerationRequest(BaseModel):
    """Request payload shared by the /generate and /generate/image endpoints."""

    prompt: str = Field(..., description="Text prompt for image generation")
    # Effective seed may differ from this value when randomize_seed is True.
    seed: int = Field(default=42, ge=0, le=MAX_SEED, description="Random seed for generation")
    randomize_seed: bool = Field(default=False, description="Whether to randomize the seed")
    # Width/height are floored to multiples of 32 inside generate_image.
    width: int = Field(default=1024, ge=256, le=MAX_IMAGE_SIZE, description="Image width")
    height: int = Field(default=1024, ge=256, le=MAX_IMAGE_SIZE, description="Image height")
    num_inference_steps: int = Field(default=4, ge=1, le=50, description="Number of inference steps")
    # NOTE(review): return_format is declared but never read by either endpoint;
    # the output format is actually selected by which endpoint the client calls.
    return_format: str = Field(default="base64", description="Return format: 'base64' or 'bytes'")
37
+
38
class ImageGenerationResponse(BaseModel):
    """JSON response returned by the /generate endpoint."""

    # PNG bytes, base64-encoded (see pil_to_base64).
    image: str = Field(..., description="Generated image in base64 format")
    # Echoes the seed actually used, so clients can reproduce the image.
    seed: int = Field(..., description="Seed used for generation")
    success: bool = Field(default=True, description="Whether generation was successful")
42
+
43
# Helper functions
def pil_to_base64(image: Image.Image) -> str:
    """Encode a PIL image as a base64 string of its PNG bytes."""
    with io.BytesIO() as buffer:
        image.save(buffer, format="PNG")
        raw_png = buffer.getvalue()
    return base64.b64encode(raw_png).decode()
50
+
51
def generate_image(
    prompt: str,
    seed: int = 42,
    randomize_seed: bool = False,
    width: int = 1024,
    height: int = 1024,
    num_inference_steps: int = 4
):
    """Run the FLUX.1 pipeline and return ``(pil_image, seed_used)``.

    Args:
        prompt: Text prompt describing the desired image.
        seed: RNG seed; ignored when ``randomize_seed`` is True.
        randomize_seed: Draw a fresh random seed in ``[0, MAX_SEED]`` instead.
        width: Requested width in pixels; floored to a multiple of 32.
        height: Requested height in pixels; floored to a multiple of 32.
        num_inference_steps: Denoising steps (schnell is distilled for ~4).

    Raises:
        HTTPException: 500 with the underlying error message if the pipeline fails.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Ensure width and height are multiples of 32 (model constraint).
    width = (width // 32) * 32
    height = (height // 32) * 32

    # CPU generator keeps seeded results reproducible regardless of device.
    generator = torch.Generator().manual_seed(seed)

    try:
        image = pipe(
            prompt=prompt,
            width=width,
            height=height,
            num_inference_steps=num_inference_steps,
            generator=generator,
            guidance_scale=0.0  # schnell is guidance-distilled; CFG stays off
        ).images[0]
        return image, seed
    except Exception as e:
        # Chain the original exception (`from e`) so server logs keep the
        # real traceback instead of only the flattened message string.
        raise HTTPException(status_code=500, detail=f"Image generation failed: {str(e)}") from e
82
+
83
# API endpoints
@app.get("/")
async def root():
    """Describe the API: name, version, backing model, and routes."""
    available_routes = {
        "generate": "/generate",
        "generate_image": "/generate/image",
        "health": "/health",
    }
    return {
        "message": "FLUX.1 Image Generation API",
        "version": "1.0.0",
        "model": "black-forest-labs/FLUX.1-schnell",
        "endpoints": available_routes,
    }
97
+
98
@app.get("/health")
async def health_check():
    """Report liveness plus the compute device the model runs on."""
    # model_loaded is always True here: the module fails to import if loading failed.
    status_report = {"status": "healthy", "device": device, "model_loaded": True}
    return status_report
102
+
103
@app.post("/generate", response_model=ImageGenerationResponse)
async def generate_image_endpoint(request: ImageGenerationRequest):
    """Generate an image and return it as base64 JSON.

    Returns an ImageGenerationResponse with the PNG base64-encoded and the
    seed actually used (relevant when randomize_seed was requested).

    Raises:
        HTTPException: 500 on any generation or encoding failure.
    """
    try:
        image, used_seed = generate_image(
            prompt=request.prompt,
            seed=request.seed,
            randomize_seed=request.randomize_seed,
            width=request.width,
            height=request.height,
            num_inference_steps=request.num_inference_steps
        )

        # Convert to base64
        image_base64 = pil_to_base64(image)

        return ImageGenerationResponse(
            image=image_base64,
            seed=used_seed,
            success=True
        )
    except HTTPException:
        # generate_image already raises a well-formed HTTPException; re-raise
        # it untouched instead of double-wrapping its detail string.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
126
+
127
@app.post("/generate/image")
async def generate_image_bytes(request: ImageGenerationRequest):
    """Generate an image and return raw PNG bytes.

    The seed actually used is exposed in the X-Generated-Seed response header
    since the body carries only the image.

    Raises:
        HTTPException: 500 on any generation or encoding failure.
    """
    try:
        image, used_seed = generate_image(
            prompt=request.prompt,
            seed=request.seed,
            randomize_seed=request.randomize_seed,
            width=request.width,
            height=request.height,
            num_inference_steps=request.num_inference_steps
        )

        # Convert to bytes
        img_byte_arr = io.BytesIO()
        image.save(img_byte_arr, format='PNG')
        img_byte_arr.seek(0)

        return Response(
            content=img_byte_arr.getvalue(),
            media_type="image/png",
            headers={"X-Generated-Seed": str(used_seed)}
        )
    except HTTPException:
        # generate_image already raises a well-formed HTTPException; re-raise
        # it untouched instead of double-wrapping its detail string.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
152
+
153
if __name__ == "__main__":
    # Honor the PORT environment variable; 7860 is the Hugging Face Spaces default.
    serve_port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=serve_port)