import gradio as gr
import numpy as np
import random
import os
import io
import base64

import requests
import fal_client
import pillow_heif  # HEIF/AVIF support for PIL
from PIL import Image, ImageOps

# --- Constants ---
MAX_SEED = np.iinfo(np.int32).max


def load_client():
    """Initialize the FAL client, routing requests through Hugging Face."""
    # Register the HEIF opener with PIL for AVIF/HEIF support
    pillow_heif.register_heif_opener()

    # Get the token from the environment
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise gr.Error(
            "HF_TOKEN environment variable not found. "
            "Please add your Hugging Face token to the Space settings."
        )

    # fal_client reads its credentials from FAL_KEY; point it at the HF token
    # so requests are routed through Hugging Face inference providers.
    os.environ["FAL_KEY"] = hf_token
    return True


def _download_image(url):
    """Download an image from a URL and return its raw bytes."""
    response = requests.get(url)
    if response.status_code != 200:
        raise gr.Error(f"Failed to download result image: {response.status_code}")
    return response.content


def query_api(image_bytes, prompt, seed, guidance_scale, steps, progress_callback=None):
    """Send an editing request via fal_client and return the result image bytes."""
    load_client()

    if progress_callback:
        progress_callback(0.1, "Submitting request...")

    # The API accepts a data URI, so send the image inline as base64
    image_base64 = base64.b64encode(image_bytes).decode("utf-8")

    def on_queue_update(update):
        # Relay queue logs to stdout and to the Gradio progress bar
        if isinstance(update, fal_client.InProgress):
            for log in update.logs:
                print(f"FAL Log: {log['message']}")
                if progress_callback:
                    progress_callback(0.5, f"Processing: {log['message'][:50]}...")

    try:
        if progress_callback:
            progress_callback(0.3, "Connecting to FAL API...")

        result = fal_client.subscribe(
            "fal-ai/flux-kontext/dev",
            arguments={
                "prompt": prompt,
                "image_url": f"data:image/png;base64,{image_base64}",
                "seed": seed,
                "guidance_scale": guidance_scale,
                "num_inference_steps": steps,
            },
            with_logs=True,
            on_queue_update=on_queue_update,
        )
        print(f"FAL Result: {result}")

        if progress_callback:
            progress_callback(0.9, "Processing result...")

        # The result schema can vary; check the known shapes in order and
        # extract an image URL (or, for the single-image field, raw base64 data).
        url = None
        if isinstance(result, dict):
            if result.get("images"):
                # A list of images: entries are dicts with a URL, or plain URL strings
                image_info = result["images"][0]
                if isinstance(image_info, dict):
                    url = image_info.get("url")
                elif isinstance(image_info, str):
                    url = image_info
            elif "image" in result:
                # A single image field: a dict with a URL, a URL string, or base64 data
                image = result["image"]
                if isinstance(image, dict):
                    url = image.get("url")
                elif isinstance(image, str):
                    if image.startswith("http"):
                        url = image
                    else:
                        try:
                            decoded = base64.b64decode(image)
                            if progress_callback:
                                progress_callback(1.0, "Complete!")
                            return decoded
                        except Exception:
                            pass
            elif "url" in result:
                # A bare URL at the top level of the result
                url = result["url"]

        if url:
            data = _download_image(url)
            if progress_callback:
                progress_callback(1.0, "Complete!")
            return data

        # If we get here, the result format is unexpected
        raise gr.Error(f"Unexpected result format from FAL API: {result}")

    except gr.Error:
        # Re-raise Gradio errors as-is
        raise
    except Exception as e:
        raise gr.Error(f"FAL API error: {str(e)}")


# --- Core Inference Function for ChatInterface ---
def chat_fn(message, chat_history, seed, randomize_seed, guidance_scale, steps, progress=gr.Progress()):
    """Perform image editing based on user input from the chat interface."""
    prompt = message["text"]
    files = message["files"]

    if not prompt and not files:
        raise gr.Error("Please provide a prompt and/or upload an image.")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    if files:
        print(f"Received image: {files[0]}")
        try:
            input_image = Image.open(files[0])
            # Auto-orient based on EXIF data, then normalize to RGB
            # (handles RGBA, palette, and other modes)
            input_image = ImageOps.exif_transpose(input_image)
            if input_image.mode != "RGB":
                input_image = input_image.convert("RGB")
            # Re-encode as PNG bytes for the API
            img_byte_arr = io.BytesIO()
            input_image.save(img_byte_arr, format="PNG")
            image_bytes = img_byte_arr.getvalue()
        except Exception as e:
            raise gr.Error(
                f"Could not process the uploaded image: {str(e)}. "
                "Please try uploading a different image format (JPEG, PNG, WebP)."
            )
        progress(0.1, desc="Processing image...")
    else:
        # FLUX.1 Kontext is an image-to-image model, so an input image is required
        raise gr.Error("This model (FLUX.1 Kontext) requires an input image. Please upload an image to edit.")

    try:
        result_bytes = query_api(image_bytes, prompt, seed, guidance_scale, steps, progress_callback=progress)

        # Try to decode the response bytes as an image
        try:
            image = Image.open(io.BytesIO(result_bytes))
        except Exception as img_error:
            print(f"Failed to open image: {img_error}")
            # Fall back to treating the payload as base64-encoded image data
            try:
                image = Image.open(io.BytesIO(base64.b64decode(result_bytes)))
            except Exception:
                length = len(result_bytes) if hasattr(result_bytes, "__len__") else "unknown"
                raise gr.Error(f"Could not process API response as image. Response length: {length}")

        progress(1.0, desc="Complete!")
        return gr.Image(value=image)
    except gr.Error:
        # Re-raise Gradio errors as-is
        raise
    except Exception as e:
        raise gr.Error(f"Failed to generate image: {str(e)}")


# --- UI Definition using gr.ChatInterface ---
seed_slider = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
randomize_checkbox = gr.Checkbox(label="Randomize seed", value=False)
guidance_slider = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=2.5)
steps_slider = gr.Slider(label="Steps", minimum=1, maximum=30, step=1, value=28)

demo = gr.ChatInterface(
    fn=chat_fn,
    title="FLUX.1 Kontext [dev] - FAL Client",
    description="""
A simple chat UI for the FLUX.1 Kontext [dev] model, using the FAL AI client routed through Hugging Face.
Upload an image and type your editing instructions (e.g., "Turn the cat into a tiger", "Add a hat").
This model specializes in understanding context and making precise edits to your images.
Find the model on Hugging Face.
Note: authenticates with the HF_TOKEN environment variable via Hugging Face inference providers.
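    """,
    # multimodal=True lets chat_fn receive message["text"] and message["files"];
    # the additional_inputs order matches chat_fn's extra parameters.
    multimodal=True,
    additional_inputs=[seed_slider, randomize_checkbox, guidance_slider, steps_slider],
)

if __name__ == "__main__":
    demo.launch()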