import spaces
import os
import gradio as gr
import torch
import safetensors
from huggingface_hub import hf_hub_download
from diffusers.utils import load_image, check_min_version
from controlnet_flux import FluxControlNetModel
from transformer_flux import FluxTransformer2DModel
from pipeline_flux_cnet import FluxControlNetInpaintingPipeline
from PIL import Image, ImageDraw
import numpy as np
import subprocess
from transformers import T5EncoderModel
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig

subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)

HF_TOKEN = os.getenv("HF_TOKEN")

# Ensure that the minimal version of diffusers is installed
check_min_version("0.30.2")

quant_config = TransformersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
)

text_encoder_2_4bit = T5EncoderModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="text_encoder_2",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN,
)

# quant_config = DiffusersBitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_use_double_quant=True,
# )

transformerx = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN,
)

# text_encoder_8bit = T5EncoderModel.from_pretrained(
#     "black-forest-labs/FLUX.1-dev",
#     subfolder="text_encoder_2",
#     quantization_config=quant_config,
#     torch_dtype=torch.bfloat16,
#     use_safetensors=True,
#     token=HF_TOKEN,
# )

# Build pipeline
controlnet = FluxControlNetModel.from_pretrained(
    "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
    # subfolder="controlnet",
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN,
)
pipe = FluxControlNetInpaintingPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    controlnet=controlnet,
    # text_encoder_2=text_encoder_8bit,
    transformer=transformerx,
    torch_dtype=torch.bfloat16,
    # device_map="balanced",
    token=HF_TOKEN,
)
# pipe.text_encoder_2 = text_encoder_2_4bit
# pipe.transformer = transformer_4bit
pipe.transformer.to(torch.bfloat16)
pipe.controlnet.to(torch.bfloat16)
pipe.to("cuda")

pipe.load_lora_weights("alimama-creative/FLUX.1-Turbo-Alpha", adapter_name="turbo")
pipe.set_adapters(["turbo"], adapter_weights=[0.95])
pipe.fuse_lora(lora_scale=1)
pipe.unload_lora_weights()

# We can utilize the enable_group_offload method for Diffusers model implementations
# pipe.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True)
# For any other model implementations, the apply_group_offloading function can be used
# pipe.push_to_hub("FLUX.1-Inpainting-8step_uncensored", private=True, token=HF_TOKEN)
# pipe.enable_vae_tiling()
# pipe.enable_model_cpu_offload()

print(pipe.hf_device_map)


def create_mask_from_editor(editor_value):
    """
    Create a mask from the ImageEditor value.
    Args:
        editor_value: Dictionary from EditorValue with 'background', 'layers', and 'composite'

    Returns:
        PIL Image with white mask
    """
    # The 'composite' key contains the final image with all layers applied
    composite_image = editor_value['composite']

    # Convert to numpy array
    composite_array = np.array(composite_image)

    # Create mask where the composite image is white
    mask_array = np.all(composite_array == (255, 255, 255), axis=-1).astype(np.uint8) * 255
    mask_image = Image.fromarray(mask_array)

    return mask_image


def create_mask_on_image(image, xyxy):
    """
    Create a white mask on the image given xyxy coordinates.

    Args:
        image: PIL Image
        xyxy: List of [x1, y1, x2, y2] coordinates

    Returns:
        Tuple of (mask as PIL Image, image with the masked region painted white)
    """
    # Convert to numpy array
    img_array = np.array(image)

    # Create mask
    mask = Image.new('RGB', image.size, (0, 0, 0))
    draw = ImageDraw.Draw(mask)

    # Draw white rectangle
    draw.rectangle(xyxy, fill=(255, 255, 255))

    # Convert mask to array
    mask_array = np.array(mask)

    # Apply mask to image
    masked_array = np.where(mask_array == 255, 255, img_array)

    return Image.fromarray(mask_array), Image.fromarray(masked_array)


def create_diptych_image(image):
    # Create a diptych image with the original on the left and black on the right
    width, height = image.size
    diptych = Image.new('RGB', (width * 2, height), 'black')
    diptych.paste(image, (0, 0))
    return diptych


@spaces.GPU(duration=120)
def inpaint_image(image, prompt, subject, editor_value):
    # Load and preprocess the image, then build the diptych control image
    size = (1536, 768)
    image = load_image(image).convert("RGB").resize((768, 768))
    diptych_image = create_diptych_image(image)
    # mask = load_image(mask_path).convert("RGB").resize(size)
    # mask, mask_image = create_mask_on_image(image, [250, 275, 500, 400])
    # Mask the right half of the diptych so only that panel is inpainted
    mask, mask_image = create_mask_on_image(diptych_image, [768, 0, 1536, 768])
    generator = torch.Generator(device="cuda").manual_seed(24)

    # Calculate attention scale mask
    attn_scale_factor = 1.5
    # Create a tensor of zeros with the same size as the diptych image
    H, W = size[1] // 16, size[0] // 16
    attn_scale_mask = torch.zeros(size[1], size[0])  # height, width
    attn_scale_mask[:, 768:] = 1.0
    attn_scale_mask = torch.nn.functional.interpolate(attn_scale_mask[None, None, :, :], (H, W), mode='nearest-exact').flatten()
    attn_scale_mask = attn_scale_mask[None, None, :, None].repeat(1, 24, 1, H * W)
    # Get inverted attention mask by subtracting from 1.0
    transposed_inverted_attn_scale_mask = (1.0 - attn_scale_mask).transpose(-1, -2)
    # Cross-panel region: right-half queries attending to left-half keys
    cross_attn_region = torch.logical_and(attn_scale_mask, transposed_inverted_attn_scale_mask)
    cross_attn_region = cross_attn_region * attn_scale_factor
    cross_attn_region[cross_attn_region < 1.0] = 1.0

    # Full mask covers the 512 text tokens plus the H*W image tokens
    full_attn_scale_mask = torch.ones(1, 24, 512 + H * W, 512 + H * W)
    full_attn_scale_mask[:, :, 512:, 512:] = cross_attn_region
    # Convert to bfloat16 to match model dtype
    full_attn_scale_mask = full_attn_scale_mask.to(device=pipe.transformer.device, dtype=torch.bfloat16)

    subject_name = subject
    target_text_prompt = prompt
    prompt_final = f'A two side-by-side image of {subject_name}. LEFT: a photo of {subject_name}; RIGHT: a photo of {subject_name} {target_text_prompt}.'
    # Convert attention mask to PIL image format for visualization
    # Take the first head's full mask (512 text tokens + H*W image tokens per side)
    # Clone so thresholding does not modify the mask passed to the pipeline
    attn_vis = full_attn_scale_mask[0, 0].clone()
    attn_vis[attn_vis <= 1.0] = 0
    attn_vis[attn_vis > 1.0] = 255
    attn_vis = attn_vis.cpu().float().numpy().astype(np.uint8)

    # Convert to PIL Image
    attn_vis_img = Image.fromarray(attn_vis)
    attn_vis_img.save('attention_mask_vis.png')

    with torch.inference_mode():
        result = pipe(
            prompt=prompt_final,
            height=size[1],
            width=size[0],
            control_image=diptych_image,
            control_mask=mask,
            num_inference_steps=12,
            generator=generator,
            controlnet_conditioning_scale=0.7,
            guidance_scale=1,
            negative_prompt="",
            true_guidance_scale=1.0,
            attn_scale_mask=full_attn_scale_mask,
        ).images[0]

    return result, attn_vis_img


# Create Gradio interface with structured layout
with gr.Blocks() as iface:
    gr.Markdown("## FLUX Inpainting with Diptych Prompting")
    gr.Markdown("Upload an image and enter a subject and prompt. The app builds a diptych, masks the right panel automatically, and generates the inpainted image.")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                with gr.Accordion():
                    input_image = gr.Image(type="filepath", label="Upload Image")
            with gr.Row():
                prompt_preview = gr.Textbox(
                    value="A two side-by-side image of 'subject_name'. LEFT: a photo of 'subject_name'; RIGHT: a photo of 'subject_name' 'target_text_prompt'",
                    interactive=False,
                )
                subject = gr.Textbox(lines=1, placeholder="Enter your subject", label="Subject")
                prompt = gr.Textbox(lines=2, placeholder="Enter your prompt here (e.g., 'wearing a christmas hat, in a busy street')", label="Prompt")
        with gr.Column():
            editor_value = gr.ImageEditor(type="pil", label="Image with Mask", sources="upload", visible=False)
            inpainted_image = gr.Image(type="pil", label="Inpainted Image")
            attn_vis_img = gr.Image(type="pil", label="Attn Vis Image")
    with gr.Row():
        inpaint_button = gr.Button("Inpaint")

    inpaint_button.click(
        fn=inpaint_image,
        inputs=[input_image, prompt, subject, editor_value],
        outputs=[inpainted_image, attn_vis_img],
    )

# Launch the app
iface.launch()