import gradio as gr import numpy as np import random import os import tempfile from PIL import Image, ImageOps import pillow_heif # For HEIF/AVIF support import io # --- Constants --- MAX_SEED = np.iinfo(np.int32).max def load_client(): """Initialize the Inference Client""" # Register HEIF opener with PIL for AVIF/HEIF support pillow_heif.register_heif_opener() # Get token from environment variable hf_token = os.getenv("HF_TOKEN") if not hf_token: raise gr.Error("HF_TOKEN environment variable not found. Please add your Hugging Face token to the Space settings.") return hf_token def query_api(image_bytes, prompt, seed, guidance_scale, steps, progress_callback=None): """Send request to the API using HF Router for fal.ai provider""" import requests import json import base64 hf_token = load_client() if progress_callback: progress_callback(0.1, "Submitting request...") # Use the HF router to access fal.ai provider url = "https://router.huggingface.co/fal-ai/fal-ai/flux-kontext/dev" headers = { "Authorization": f"Bearer {hf_token}", "X-HF-Bill-To": "huggingface", "Content-Type": "application/json" } # Convert image to base64 image_base64 = base64.b64encode(image_bytes).decode('utf-8') # Fixed payload structure - prompt should be at the top level payload = { "prompt": prompt, "inputs": image_base64, "seed": seed, "guidance_scale": guidance_scale, "num_inference_steps": steps } if progress_callback: progress_callback(0.3, "Processing request...") try: response = requests.post(url, headers=headers, json=payload, timeout=300) if response.status_code != 200: raise gr.Error(f"API request failed with status {response.status_code}: {response.text}") # Check if response is image bytes or JSON content_type = response.headers.get('content-type', '').lower() print(f"Response content type: {content_type}") print(f"Response length: {len(response.content)}") if 'image/' in content_type: # Direct image response if progress_callback: progress_callback(1.0, "Complete!") return response.content elif 'application/json' in content_type: # JSON response - might be queue status or result try: json_response = response.json() print(f"JSON response: {json_response}") # Check if it's a queue response if json_response.get("status") == "IN_QUEUE": if progress_callback: progress_callback(0.4, "Request queued, please wait...") raise gr.Error("Request is being processed. Please try again in a few moments.") # Handle immediate completion or result if 'images' in json_response and len(json_response['images']) > 0: image_info = json_response['images'][0] if isinstance(image_info, dict) and 'url' in image_info: # Download image from URL if progress_callback: progress_callback(0.9, "Downloading result...") img_response = requests.get(image_info['url']) if img_response.status_code == 200: if progress_callback: progress_callback(1.0, "Complete!") return img_response.content else: raise gr.Error(f"Failed to download image: {img_response.status_code}") elif isinstance(image_info, str): # Base64 encoded image if progress_callback: progress_callback(1.0, "Complete!") return base64.b64decode(image_info) elif 'image' in json_response: # Single image field if progress_callback: progress_callback(1.0, "Complete!") return base64.b64decode(json_response['image']) else: raise gr.Error(f"Unexpected JSON response format: {json_response}") except json.JSONDecodeError as e: raise gr.Error(f"Failed to parse JSON response: {str(e)}") else: # Try to treat as image bytes if len(response.content) > 1000: # Likely an image if progress_callback: progress_callback(1.0, "Complete!") return response.content else: # Small response, probably an error try: error_text = response.content.decode('utf-8') raise gr.Error(f"Unexpected response: {error_text[:500]}") except: raise gr.Error(f"Unexpected response format. Content length: {len(response.content)}") except requests.exceptions.Timeout: raise gr.Error("Request timed out. Please try again.") except requests.exceptions.RequestException as e: raise gr.Error(f"Request failed: {str(e)}") except gr.Error: # Re-raise Gradio errors as-is raise except Exception as e: raise gr.Error(f"Unexpected error: {str(e)}") # --- Core Inference Function for ChatInterface --- def chat_fn(message, chat_history, seed, randomize_seed, guidance_scale, steps, progress=gr.Progress()): """ Performs image generation or editing based on user input from the chat interface. """ prompt = message["text"] files = message["files"] if not prompt and not files: raise gr.Error("Please provide a prompt and/or upload an image.") if randomize_seed: seed = random.randint(0, MAX_SEED) if files: print(f"Received image: {files[0]}") try: # Try to open and convert the image input_image = Image.open(files[0]) # Convert to RGB if needed (handles RGBA, P, etc.) if input_image.mode != "RGB": input_image = input_image.convert("RGB") # Auto-orient the image based on EXIF data input_image = ImageOps.exif_transpose(input_image) # Convert PIL image to bytes img_byte_arr = io.BytesIO() input_image.save(img_byte_arr, format='PNG') img_byte_arr.seek(0) image_bytes = img_byte_arr.getvalue() except Exception as e: raise gr.Error(f"Could not process the uploaded image: {str(e)}. Please try uploading a different image format (JPEG, PNG, WebP).") progress(0.1, desc="Processing image...") else: # For text-to-image, we need a placeholder image or handle differently # FLUX.1 Kontext is primarily an image-to-image model raise gr.Error("This model (FLUX.1 Kontext) requires an input image. Please upload an image to edit.") try: # Make API request result_bytes = query_api(image_bytes, prompt, seed, guidance_scale, steps, progress_callback=progress) # Try to convert response bytes to PIL Image try: image = Image.open(io.BytesIO(result_bytes)) except Exception as img_error: print(f"Failed to open image: {img_error}") print(f"Image bytes type: {type(result_bytes)}, length: {len(result_bytes) if hasattr(result_bytes, '__len__') else 'unknown'}") # Try to decode as base64 if direct opening failed try: import base64 decoded_bytes = base64.b64decode(result_bytes) image = Image.open(io.BytesIO(decoded_bytes)) except: raise gr.Error(f"Could not process API response as image. Response length: {len(result_bytes) if hasattr(result_bytes, '__len__') else 'unknown'}") progress(1.0, desc="Complete!") return gr.Image(value=image) except gr.Error: # Re-raise gradio errors as-is raise except Exception as e: raise gr.Error(f"Failed to generate image: {str(e)}") # --- UI Definition using gr.ChatInterface --- seed_slider = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42) randomize_checkbox = gr.Checkbox(label="Randomize seed", value=False) guidance_slider = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=2.5) steps_slider = gr.Slider(label="Steps", minimum=1, maximum=30, value=28, step=1) demo = gr.ChatInterface( fn=chat_fn, title="FLUX.1 Kontext [dev] - HF Inference Client", description="""
A simple chat UI for the FLUX.1 Kontext [dev] model using Hugging Face Inference Client approach.
Upload an image and type your editing instructions (e.g., "Turn the cat into a tiger", "Add a hat").
This model specializes in understanding context and making precise edits to your images.
Find the model on Hugging Face.