File size: 9,993 Bytes
1bafe30
 
 
9231de3
d6ceac3
d1b130d
 
d6ceac3
a3c7c9b
 
1bafe30
920a718
1bafe30
 
d6ceac3
a3c7c9b
d6ceac3
 
 
 
f5f7379
 
 
 
a3c7c9b
 
 
d1b130d
d6ceac3
a3c7c9b
17cc4e0
a3c7c9b
17cc4e0
fcf74fc
 
 
a3c7c9b
d6ceac3
17cc4e0
a3c7c9b
 
 
 
fc5bd53
a3c7c9b
 
 
 
 
 
fc5bd53
d6ceac3
a3c7c9b
 
d6ceac3
a3c7c9b
 
 
 
 
 
 
 
 
 
 
 
 
d6ceac3
a3c7c9b
d6ceac3
a3c7c9b
 
 
 
 
 
 
 
 
 
 
 
 
fcf74fc
a3c7c9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
09f3aa3
 
 
 
a3c7c9b
 
 
 
 
 
 
 
 
 
 
 
 
5c6ea42
 
a3c7c9b
 
 
 
 
09f3aa3
a3c7c9b
 
 
 
 
 
 
1bafe30
920a718
d1b130d
1bafe30
 
 
 
 
 
 
 
 
 
 
 
 
 
943caab
 
d1b130d
 
 
 
 
 
 
e1f8042
d1b130d
 
f5f7379
e1f8042
f5f7379
943caab
 
d1b130d
 
1bafe30
d6ceac3
 
 
f5f7379
 
d6ceac3
 
f5f7379
90342ab
c847b55
d6ceac3
c847b55
90342ab
d6ceac3
c847b55
 
 
d6ceac3
 
c847b55
 
d6ceac3
f5f7379
d1b130d
f5f7379
 
c847b55
 
 
f5f7379
 
1bafe30
 
 
 
 
 
 
 
 
 
a3c7c9b
1bafe30
a3c7c9b
1bafe30
d6ceac3
1bafe30
d6ceac3
1bafe30
 
a3c7c9b
 
1bafe30
d1b130d
1bafe30
 
d6ceac3
9231de3
1bafe30
 
 
 
 
 
 
 
 
 
 
d1b130d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
import gradio as gr
import numpy as np
import random
import os
import tempfile
from PIL import Image, ImageOps
import pillow_heif  # For HEIF/AVIF support
import io
import fal_client
import base64

# --- Constants ---
# Largest signed 32-bit integer: upper bound for the seed slider and for
# randomly drawn seeds in chat_fn.
MAX_SEED = int(np.iinfo(np.int32).max)

def load_client():
    """Prepare the process environment for ``fal_client`` calls.

    Registers the pillow_heif opener with PIL (HEIF/AVIF support) and copies
    the ``HF_TOKEN`` environment variable into ``FAL_KEY``, the variable
    ``fal_client`` reads its credentials from.

    Returns:
        True once the environment is configured.

    Raises:
        gr.Error: if ``HF_TOKEN`` is unset or empty.
    """
    pillow_heif.register_heif_opener()

    token = os.getenv("HF_TOKEN")
    if token:
        # NOTE(review): this overwrites any FAL_KEY already in the environment
        # with the HF token; confirm the FAL endpoint accepts HF tokens
        # (HF routing) rather than requiring a native FAL key.
        os.environ["FAL_KEY"] = token
        return True

    raise gr.Error("HF_TOKEN environment variable not found. Please add your Hugging Face token to the Space settings.")

def _download_image(url, timeout=60):
    """Return the raw bytes of the image referenced by ``url``.

    ``data:`` URIs are decoded locally because ``requests`` cannot fetch that
    scheme. Any other URL is downloaded with an HTTP GET bounded by
    ``timeout`` seconds so a stalled transfer cannot hang the worker.

    Raises:
        gr.Error: if the HTTP download returns a status other than 200.
    """
    if url.startswith("data:"):
        # NOTE(review): assumes inline results are base64 data URIs
        # ("data:<mime>;base64,<payload>"); confirm against FAL's response.
        _, _, payload = url.partition(",")
        return base64.b64decode(payload)

    import requests  # imported lazily, as in the original implementation

    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        raise gr.Error(f"Failed to download result image: {response.status_code}")
    return response.content


def _extract_image_bytes(result):
    """Pull the first image's bytes out of a FAL response.

    Recognized shapes, checked in this order:
    ``{"images": [{"url": ...} | "<url>", ...]}``,
    ``{"image": {"url": ...} | "<url>" | "<base64>"}`` and ``{"url": ...}``.

    Returns:
        The image bytes, or None when the response shape is not recognized.
    """
    if not isinstance(result, dict):
        return None

    images = result.get("images")
    if images:
        first = images[0]
        if isinstance(first, dict) and "url" in first:
            return _download_image(first["url"])
        if isinstance(first, str):
            return _download_image(first)
        return None

    if "image" in result:
        image = result["image"]
        if isinstance(image, dict) and "url" in image:
            return _download_image(image["url"])
        if isinstance(image, str):
            if image.startswith(("http", "data:")):
                return _download_image(image)
            # Otherwise assume a bare base64 payload.
            try:
                return base64.b64decode(image)
            except ValueError:  # binascii.Error subclasses ValueError
                return None
        return None

    if "url" in result:
        return _download_image(result["url"])
    return None


def query_api(image_bytes, prompt, seed, guidance_scale, steps, progress_callback=None):
    """Run the FLUX.1 Kontext [dev] edit on FAL and return the result image bytes.

    The input image is sent inline as a ``data:image/png;base64`` URI, so
    callers are expected to pass PNG-encoded bytes.

    Args:
        image_bytes: encoded input image (PNG expected, see above).
        prompt: editing instruction text.
        seed: generation seed forwarded to the model.
        guidance_scale: guidance scale forwarded to the model.
        steps: number of inference steps forwarded to the model.
        progress_callback: optional ``callable(fraction, message)``, e.g. a
            ``gr.Progress``, used for coarse progress reporting.

    Returns:
        Raw bytes of the first generated image.

    Raises:
        gr.Error: when ``HF_TOKEN`` is missing, the call fails, the result
            image cannot be downloaded, or the response has an unknown shape.
    """
    load_client()

    def report(fraction, message):
        if progress_callback:
            progress_callback(fraction, message)

    report(0.1, "Submitting request...")

    image_base64 = base64.b64encode(image_bytes).decode("utf-8")

    def on_queue_update(update):
        # Relay in-progress log lines to stdout and to the progress bar.
        # NOTE(review): assumes each log entry is a dict with a "message" key.
        if isinstance(update, fal_client.InProgress):
            for log in update.logs:
                print(f"FAL Log: {log['message']}")
                report(0.5, f"Processing: {log['message'][:50]}...")

    try:
        report(0.3, "Connecting to FAL API...")
        result = fal_client.subscribe(
            "fal-ai/flux-kontext/dev",
            arguments={
                "prompt": prompt,
                "image_url": f"data:image/png;base64,{image_base64}",
                "seed": seed,
                "guidance_scale": guidance_scale,
                "num_inference_steps": steps,
            },
            with_logs=True,
            on_queue_update=on_queue_update,
        )

        print(f"FAL Result: {result}")
        report(0.9, "Processing result...")

        image_data = _extract_image_bytes(result)
    except gr.Error:
        # Already user-facing; do not wrap it in a second "FAL API error" prefix.
        raise
    except Exception as e:
        raise gr.Error(f"FAL API error: {str(e)}") from e

    if image_data is None:
        raise gr.Error(f"Unexpected result format from FAL API: {result}")

    report(1.0, "Complete!")
    return image_data

# --- Core Inference Function for ChatInterface ---
def chat_fn(message, chat_history, seed, randomize_seed, guidance_scale, steps, progress=gr.Progress()):
    """
    Edit the uploaded image according to the chat prompt and return the result.

    Args:
        message: multimodal chat message with ``"text"`` (the prompt) and
            ``"files"`` (uploaded files; the first one is opened with PIL,
            presumably a local file path supplied by Gradio).
        chat_history: previous turns; unused here.
        seed: seed to use unless ``randomize_seed`` is set.
        randomize_seed: when true, draw a fresh seed in ``[0, MAX_SEED]``.
        guidance_scale: guidance scale forwarded to the model.
        steps: number of inference steps forwarded to the model.
        progress: Gradio progress tracker.

    Returns:
        ``gr.Image`` wrapping the edited PIL image.

    Raises:
        gr.Error: on missing input, unreadable upload, API failure, or an
            API response that cannot be decoded as an image.
    """
    prompt = message["text"]
    files = message["files"]

    if not prompt and not files:
        raise gr.Error("Please provide a prompt and/or upload an image.")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    if not files:
        # FLUX.1 Kontext is an image-to-image model; text-only requests are unsupported.
        raise gr.Error("This model (FLUX.1 Kontext) requires an input image. Please upload an image to edit.")

    print(f"Received image: {files[0]}")

    # Register HEIF/AVIF support *before* opening the upload. load_client()
    # also registers it, but only runs later inside query_api, which is too
    # late for this Image.open.
    pillow_heif.register_heif_opener()

    try:
        with Image.open(files[0]) as uploaded:
            # Apply EXIF orientation before converting, while the original
            # image (and its metadata) is still intact.
            oriented = ImageOps.exif_transpose(uploaded)
            rgb_image = oriented if oriented.mode == "RGB" else oriented.convert("RGB")

            png_buffer = io.BytesIO()
            rgb_image.save(png_buffer, format="PNG")
            image_bytes = png_buffer.getvalue()
    except Exception as e:
        raise gr.Error(f"Could not process the uploaded image: {str(e)}. Please try uploading a different image format (JPEG, PNG, WebP).") from e

    progress(0.1, desc="Processing image...")

    try:
        result_bytes = query_api(image_bytes, prompt, seed, guidance_scale, steps, progress_callback=progress)

        try:
            image = Image.open(io.BytesIO(result_bytes))
            image.load()  # decode now so corrupt data is caught here, not later in Gradio
        except Exception as img_error:
            print(f"Failed to open image: {img_error}")
            print(f"Image bytes type: {type(result_bytes)}, length: {len(result_bytes) if hasattr(result_bytes, '__len__') else 'unknown'}")

            # Fall back to treating the payload as base64-encoded image data.
            try:
                image = Image.open(io.BytesIO(base64.b64decode(result_bytes)))
                image.load()
            except Exception:
                raise gr.Error(f"Could not process API response as image. Response length: {len(result_bytes) if hasattr(result_bytes, '__len__') else 'unknown'}") from img_error

        progress(1.0, desc="Complete!")
        return gr.Image(value=image)

    except gr.Error:
        # Re-raise gradio errors as-is
        raise
    except Exception as e:
        raise gr.Error(f"Failed to generate image: {str(e)}") from e

# --- UI Definition using gr.ChatInterface ---

# Extra controls rendered with the chat UI; ChatInterface passes their values
# to chat_fn positionally after (message, chat_history), in list order.
seed_slider = gr.Slider(label="Seed", value=42, minimum=0, maximum=MAX_SEED, step=1)
randomize_checkbox = gr.Checkbox(value=False, label="Randomize seed")
guidance_slider = gr.Slider(label="Guidance Scale", value=2.5, minimum=1.0, maximum=10.0, step=0.1)
steps_slider = gr.Slider(label="Steps", value=28, minimum=1, maximum=30, step=1)

# Image-only multimodal input. render=False defers rendering, presumably so
# ChatInterface can place it in its own layout.
_chat_textbox = gr.MultimodalTextbox(
    placeholder="Upload an image and type your editing instructions...",
    file_types=["image"],
    render=False,
)
_chat_extra_inputs = [seed_slider, randomize_checkbox, guidance_slider, steps_slider]

demo = gr.ChatInterface(
    fn=chat_fn,
    multimodal=True,
    textbox=_chat_textbox,
    additional_inputs=_chat_extra_inputs,
    title="FLUX.1 Kontext [dev] - FAL Client",
    description="""<p style='text-align: center;'>
    A simple chat UI for the <b>FLUX.1 Kontext [dev]</b> model using FAL AI client through Hugging Face.
    <br>
    <b>Upload an image</b> and type your editing instructions (e.g., "Turn the cat into a tiger", "Add a hat").
    <br>
    This model specializes in understanding context and making precise edits to your images.
    <br>
    Find the model on <a href='https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev' target='_blank'>Hugging Face</a>.
    <br>
    <b>Note:</b> Uses HF_TOKEN environment variable through HF inference providers.
    </p>""",
    theme="soft",
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()