akhaliq's picture
akhaliq HF Staff
Update app.py
618f8cb verified
import gradio as gr
import numpy as np
import random
import os
import tempfile
from PIL import Image, ImageOps
import pillow_heif # For HEIF/AVIF support
import io
# --- Constants ---
# Largest signed 32-bit integer: upper bound for the seed slider and for
# randomly drawn seeds.
MAX_SEED = int(np.iinfo(np.int32).max)
def load_client():
    """Return the Hugging Face API token and enable HEIF/AVIF decoding in PIL.

    Despite its name, no client object is built: the function registers the
    pillow_heif opener with PIL and hands back the value of the ``HF_TOKEN``
    environment variable.

    Returns:
        The non-empty ``HF_TOKEN`` string.

    Raises:
        gr.Error: If ``HF_TOKEN`` is unset or empty.
    """
    # Let PIL open HEIF/AVIF files.
    pillow_heif.register_heif_opener()

    token = os.getenv("HF_TOKEN")
    if token:
        return token
    raise gr.Error("HF_TOKEN environment variable not found. Please add your Hugging Face token to the Space settings.")
def query_api(image_bytes, prompt, seed, guidance_scale, steps, progress_callback=None):
"""Send request to the API using HF Router for fal.ai provider"""
import requests
import json
import base64
hf_token = load_client()
if progress_callback:
progress_callback(0.1, "Submitting request...")
# Use the HF router to access fal.ai provider
url = "https://router.huggingface.co/fal-ai/fal-ai/flux-kontext/dev"
headers = {
"Authorization": f"Bearer {hf_token}",
"X-HF-Bill-To": "huggingface",
"Content-Type": "application/json"
}
# Convert image to base64
image_base64 = base64.b64encode(image_bytes).decode('utf-8')
# Fixed payload structure - prompt should be at the top level
payload = {
"prompt": prompt,
"inputs": image_base64,
"seed": seed,
"guidance_scale": guidance_scale,
"num_inference_steps": steps
}
if progress_callback:
progress_callback(0.3, "Processing request...")
try:
response = requests.post(url, headers=headers, json=payload, timeout=300)
if response.status_code != 200:
raise gr.Error(f"API request failed with status {response.status_code}: {response.text}")
# Check if response is image bytes or JSON
content_type = response.headers.get('content-type', '').lower()
print(f"Response content type: {content_type}")
print(f"Response length: {len(response.content)}")
if 'image/' in content_type:
# Direct image response
if progress_callback:
progress_callback(1.0, "Complete!")
return response.content
elif 'application/json' in content_type:
# JSON response - might be queue status or result
try:
json_response = response.json()
print(f"JSON response: {json_response}")
# Check if it's a queue response
if json_response.get("status") == "IN_QUEUE":
if progress_callback:
progress_callback(0.4, "Request queued, please wait...")
raise gr.Error("Request is being processed. Please try again in a few moments.")
# Handle immediate completion or result
if 'images' in json_response and len(json_response['images']) > 0:
image_info = json_response['images'][0]
if isinstance(image_info, dict) and 'url' in image_info:
# Download image from URL
if progress_callback:
progress_callback(0.9, "Downloading result...")
img_response = requests.get(image_info['url'])
if img_response.status_code == 200:
if progress_callback:
progress_callback(1.0, "Complete!")
return img_response.content
else:
raise gr.Error(f"Failed to download image: {img_response.status_code}")
elif isinstance(image_info, str):
# Base64 encoded image
if progress_callback:
progress_callback(1.0, "Complete!")
return base64.b64decode(image_info)
elif 'image' in json_response:
# Single image field
if progress_callback:
progress_callback(1.0, "Complete!")
return base64.b64decode(json_response['image'])
else:
raise gr.Error(f"Unexpected JSON response format: {json_response}")
except json.JSONDecodeError as e:
raise gr.Error(f"Failed to parse JSON response: {str(e)}")
else:
# Try to treat as image bytes
if len(response.content) > 1000: # Likely an image
if progress_callback:
progress_callback(1.0, "Complete!")
return response.content
else:
# Small response, probably an error
try:
error_text = response.content.decode('utf-8')
raise gr.Error(f"Unexpected response: {error_text[:500]}")
except:
raise gr.Error(f"Unexpected response format. Content length: {len(response.content)}")
except requests.exceptions.Timeout:
raise gr.Error("Request timed out. Please try again.")
except requests.exceptions.RequestException as e:
raise gr.Error(f"Request failed: {str(e)}")
except gr.Error:
# Re-raise Gradio errors as-is
raise
except Exception as e:
raise gr.Error(f"Unexpected error: {str(e)}")
# --- Core Inference Function for ChatInterface ---
def _load_image_as_png_bytes(path):
    """Open the uploaded image at ``path`` and re-encode it as RGB PNG bytes.

    The image is rotated/flipped according to its EXIF orientation and converted
    to RGB (handles RGBA, P, etc.), so the API always receives an upright PNG.
    """
    # Must run before Image.open. load_client() also registers the opener, but
    # only later inside query_api, so a first HEIF/AVIF upload could fail.
    pillow_heif.register_heif_opener()
    # The context manager closes the file handle that Image.open keeps open.
    with Image.open(path) as img:
        # Orient first, while the source metadata is still attached, then
        # normalize the mode.
        img = ImageOps.exif_transpose(img)
        if img.mode != "RGB":
            img = img.convert("RGB")
        buffer = io.BytesIO()
        img.save(buffer, format="PNG")
    return buffer.getvalue()


def _open_result_image(result_bytes):
    """Decode the API result into a PIL image.

    Tries the bytes as an encoded image first and, failing that, as base64 text.

    Raises:
        gr.Error: If neither interpretation yields a readable image.
    """
    import base64

    try:
        image = Image.open(io.BytesIO(result_bytes))
        # Force a full decode so corrupt data is caught here (and the base64
        # fallback gets a chance), not later inside Gradio.
        image.load()
        return image
    except Exception as img_error:
        print(f"Failed to open image: {img_error}")
        print(f"Image bytes type: {type(result_bytes)}, length: {len(result_bytes) if hasattr(result_bytes, '__len__') else 'unknown'}")

    try:
        image = Image.open(io.BytesIO(base64.b64decode(result_bytes)))
        image.load()
        return image
    except Exception as decode_error:
        raise gr.Error(f"Could not process API response as image. Response length: {len(result_bytes) if hasattr(result_bytes, '__len__') else 'unknown'}") from decode_error


def chat_fn(message, chat_history, seed, randomize_seed, guidance_scale, steps, progress=gr.Progress()):
    """
    Performs image generation or editing based on user input from the chat interface.

    Args:
        message: Multimodal chat message dict with ``"text"`` (the prompt) and
            ``"files"`` (uploaded files; only the first one is used).
        chat_history: Conversation history supplied by ``gr.ChatInterface`` (unused).
        seed: Seed forwarded to the API unless ``randomize_seed`` is set.
        randomize_seed: When true, draw a fresh seed in ``[0, MAX_SEED]``.
        guidance_scale: Guidance scale forwarded to the API.
        steps: Number of inference steps forwarded to the API.
        progress: Gradio progress tracker.

    Returns:
        A ``gr.Image`` holding the edited image.

    Raises:
        gr.Error: On missing input, unreadable uploads, or API/decoding failures.
    """
    prompt = message["text"]
    files = message["files"]

    if not prompt and not files:
        raise gr.Error("Please provide a prompt and/or upload an image.")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    if not files:
        # FLUX.1 Kontext is an image-to-image (editing) model; text-only
        # requests cannot be served.
        raise gr.Error("This model (FLUX.1 Kontext) requires an input image. Please upload an image to edit.")

    print(f"Received image: {files[0]}")
    try:
        image_bytes = _load_image_as_png_bytes(files[0])
    except Exception as e:
        raise gr.Error(f"Could not process the uploaded image: {str(e)}. Please try uploading a different image format (JPEG, PNG, WebP).") from e
    progress(0.1, desc="Processing image...")

    try:
        result_bytes = query_api(image_bytes, prompt, seed, guidance_scale, steps, progress_callback=progress)
        image = _open_result_image(result_bytes)
        progress(1.0, desc="Complete!")
        return gr.Image(value=image)
    except gr.Error:
        # Re-raise gradio errors as-is
        raise
    except Exception as e:
        raise gr.Error(f"Failed to generate image: {str(e)}") from e
# --- UI Definition using gr.ChatInterface ---
# HTML shown under the title of the chat page.
_DESCRIPTION_HTML = """<p style='text-align: center;'>
A simple chat UI for the <b>FLUX.1 Kontext [dev]</b> model using Hugging Face Inference Client approach.
<br>
<b>Upload an image</b> and type your editing instructions (e.g., "Turn the cat into a tiger", "Add a hat").
<br>
This model specializes in understanding context and making precise edits to your images.
<br>
Find the model on <a href='https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev' target='_blank'>Hugging Face</a>.
</p>"""

# Extra controls; ChatInterface passes their values to chat_fn after
# (message, chat_history), in the order listed in additional_inputs.
seed_slider = gr.Slider(
    label="Seed",
    minimum=0,
    maximum=MAX_SEED,
    step=1,
    value=42,
)
randomize_checkbox = gr.Checkbox(label="Randomize seed", value=False)
guidance_slider = gr.Slider(
    label="Guidance Scale",
    minimum=1.0,
    maximum=10.0,
    step=0.1,
    value=2.5,
)
steps_slider = gr.Slider(
    label="Steps",
    minimum=1,
    maximum=30,
    value=28,
    step=1,
)

demo = gr.ChatInterface(
    fn=chat_fn,
    title="FLUX.1 Kontext [dev] - HF Inference Client",
    description=_DESCRIPTION_HTML,
    multimodal=True,
    # Created unrendered; ChatInterface places it in its own layout.
    textbox=gr.MultimodalTextbox(
        file_types=["image"],
        placeholder="Upload an image and type your editing instructions...",
        render=False,
    ),
    additional_inputs=[seed_slider, randomize_checkbox, guidance_slider, steps_slider],
    theme="soft",
)

if __name__ == "__main__":
    demo.launch()