import math
import os

import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPTokenizer, CLIPTextModel, AutoProcessor, T5EncoderModel, T5TokenizerFast
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler

from flux.transformer_flux import FluxTransformer2DModel
from flux.pipeline_flux_chameleon import FluxPipeline

# NOTE(review): Qwen2VLSimplifiedModel is used in FluxInterface.load_models but was
# never imported. This path presumes the project ships a ``qwen2_vl`` package next to
# ``flux`` — confirm it matches the deployed layout.
from qwen2_vl.modeling_qwen2_vl import Qwen2VLSimplifiedModel

MODEL_ID = "Djrango/Qwen2vl-Flux"


class Qwen2Connector(nn.Module):
    """Single linear projection applied to Qwen2-VL image-token hidden states.

    Its output is passed to the FLUX pipeline as ``prompt_embeds``.

    Args:
        input_dim: Width of the incoming Qwen2-VL hidden states.
        output_dim: Width of the projected embeddings.
    """

    def __init__(self, input_dim=3584, output_dim=4096):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.linear(x)


class FluxInterface:
    """Lazily loads the Qwen2vl-Flux components and generates image variations."""

    def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"):
        self.device = device
        self.dtype = torch.bfloat16
        self.models = None
        # Kept as an instance attribute for backward compatibility; mirrors the module constant.
        self.MODEL_ID = MODEL_ID

    def _load_qwen2vl_state_dict(self, filename):
        """Download ``qwen2-vl/<filename>`` from the model repo and return its state dict."""
        url = f"https://huggingface.co/{self.MODEL_ID}/resolve/main/qwen2-vl/{filename}"
        # NOTE(review): this unpickles a downloaded file; if the installed torch version
        # supports it, consider passing weights_only=True.
        return torch.hub.load_state_dict_from_url(url, map_location=self.device)

    def load_models(self):
        """Load every model component once; later calls are no-ops.

        ``self.models`` is assigned only after all components (including the processor
        and pipeline) were created, so a failed load is retried on the next call.
        """
        if self.models is not None:
            return

        # FLUX text encoders / tokenizers
        tokenizer = CLIPTokenizer.from_pretrained(self.MODEL_ID, subfolder="flux/tokenizer")
        text_encoder = CLIPTextModel.from_pretrained(self.MODEL_ID, subfolder="flux/text_encoder")
        text_encoder_two = T5EncoderModel.from_pretrained(self.MODEL_ID, subfolder="flux/text_encoder_2")
        tokenizer_two = T5TokenizerFast.from_pretrained(self.MODEL_ID, subfolder="flux/tokenizer_2")

        # VAE, transformer and scheduler from the flux folder
        vae = AutoencoderKL.from_pretrained(self.MODEL_ID, subfolder="flux")
        transformer = FluxTransformer2DModel.from_pretrained(self.MODEL_ID, subfolder="flux")
        scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
            self.MODEL_ID, subfolder="flux/scheduler", shift=1
        )

        # Qwen2-VL backbone from the qwen2-vl folder
        qwen2vl = Qwen2VLSimplifiedModel.from_pretrained(self.MODEL_ID, subfolder="qwen2-vl")

        # Connector and T5 embedder weights live as standalone .pt files in qwen2-vl/
        connector = Qwen2Connector()
        connector.load_state_dict(self._load_qwen2vl_state_dict("connector.pt"))

        self.t5_context_embedder = nn.Linear(4096, 3072)
        self.t5_context_embedder.load_state_dict(self._load_qwen2vl_state_dict("t5_embedder.pt"))

        # Move to the target device/dtype and switch to inference mode
        for model in (text_encoder, text_encoder_two, vae, transformer, qwen2vl, connector,
                      self.t5_context_embedder):
            model.to(self.device).to(self.dtype)
            model.eval()

        self.qwen2vl_processor = AutoProcessor.from_pretrained(
            self.MODEL_ID,
            subfolder="qwen2-vl",
            min_pixels=256 * 28 * 28,
            max_pixels=256 * 28 * 28,
        )
        self.pipeline = FluxPipeline(
            transformer=transformer,
            scheduler=scheduler,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
        )

        # Assigned last: marks loading as complete.
        self.models = {
            'tokenizer': tokenizer,
            'text_encoder': text_encoder,
            'text_encoder_two': text_encoder_two,
            'tokenizer_two': tokenizer_two,
            'vae': vae,
            'transformer': transformer,
            'scheduler': scheduler,
            'qwen2vl': qwen2vl,
            'connector': connector,
        }

    def resize_image(self, img, max_pixels=1050000):
        """Downscale ``img`` so it has at most ``max_pixels`` pixels.

        Accepts a PIL image or anything ``Image.fromarray`` accepts. Images already
        within budget are returned unchanged. Otherwise both sides are scaled by the
        same factor, floored to a multiple of 8 (minimum 8), and resampled with Lanczos.
        """
        if not isinstance(img, Image.Image):
            img = Image.fromarray(img)

        width, height = img.size
        num_pixels = width * height
        if num_pixels <= max_pixels:
            return img

        scale = math.sqrt(max_pixels / num_pixels)
        # Snap to multiples of 8; clamp so extreme aspect ratios never yield a 0-sized side.
        new_width = max(8, int(width * scale) // 8 * 8)
        new_height = max(8, int(height * scale) // 8 * 8)
        return img.resize((new_width, new_height), Image.LANCZOS)

    def process_image(self, image):
        """Encode ``image`` with Qwen2-VL and project its image tokens through the connector.

        Returns:
            ``(image_hidden_state, image_grid_thw)``. ``image_hidden_state`` is the
            connector output for the positions selected by the image-token mask,
            viewed with batch size 1.
        """
        message = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": "Describe this image."},
                ],
            }
        ]
        text = self.qwen2vl_processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)

        with torch.no_grad():
            inputs = self.qwen2vl_processor(
                text=[text], images=[image], padding=True, return_tensors="pt"
            ).to(self.device)
            output_hidden_state, image_token_mask, image_grid_thw = self.models['qwen2vl'](**inputs)
            image_hidden_state = output_hidden_state[image_token_mask].view(1, -1, output_hidden_state.size(-1))
            image_hidden_state = self.models['connector'](image_hidden_state)

        return image_hidden_state, image_grid_thw

    def compute_t5_text_embeddings(self, prompt):
        """Compute T5 embeddings for the text prompt, projected by ``t5_context_embedder``.

        Returns ``None`` for a ``None``, empty, or whitespace-only prompt; the caller
        then passes ``t5_prompt_embeds=None`` to the pipeline.
        """
        if prompt is None or not prompt.strip():
            return None

        # Pure inference: avoid building an autograd graph through the T5 encoder.
        with torch.no_grad():
            text_inputs = self.models['tokenizer_two'](
                prompt,
                padding="max_length",
                max_length=256,
                truncation=True,
                return_tensors="pt",
            ).to(self.device)
            prompt_embeds = self.models['text_encoder_two'](text_inputs.input_ids)[0]
            prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=self.device)
            return self.t5_context_embedder(prompt_embeds)

    def compute_text_embeddings(self, prompt=""):
        """Return the CLIP pooled embedding of ``prompt`` in ``self.dtype``."""
        with torch.no_grad():
            text_inputs = self.models['tokenizer'](
                prompt,
                padding="max_length",
                max_length=77,
                truncation=True,
                return_tensors="pt",
            ).to(self.device)
            prompt_embeds = self.models['text_encoder'](text_inputs.input_ids, output_hidden_states=False)
            return prompt_embeds.pooler_output.to(self.dtype)

    def generate(self, input_image, prompt="", guidance_scale=3.5, num_inference_steps=28, num_images=2, seed=None):
        """Generate ``num_images`` variations of ``input_image``.

        Args:
            input_image: PIL image or array from the Gradio image input.
            prompt: Optional text; when non-blank it is encoded with T5 as extra guidance.
            guidance_scale: Passed straight to the pipeline.
            num_inference_steps: Number of denoising steps (coerced to int).
            num_images: Number of variations (coerced to int).
            seed: Optional seed for torch's global RNG (coerced to int).

        Returns:
            The pipeline's ``.images``.

        Raises:
            gr.Error: if no image was provided or generation fails.
        """
        if input_image is None:
            raise gr.Error("Please upload an image first.")

        try:
            self.load_models()

            # Seed after loading so RNG draws made while building models (e.g. the
            # nn.Linear initialisers of the connector / T5 embedder) cannot shift the
            # seeded stream on the first call.
            if seed is not None:
                torch.manual_seed(int(seed))

            # Gradio sliders may deliver floats; tensor.repeat and step counts need ints.
            num_images = int(num_images)
            num_inference_steps = int(num_inference_steps)

            input_image = self.resize_image(input_image)
            qwen2_hidden_state, _image_grid_thw = self.process_image(input_image)
            pooled_prompt_embeds = self.compute_text_embeddings("")

            t5_prompt_embeds = self.compute_t5_text_embeddings(prompt)
            if t5_prompt_embeds is not None:
                t5_prompt_embeds = t5_prompt_embeds.repeat(num_images, 1, 1)

            return self.pipeline(
                prompt_embeds=qwen2_hidden_state.repeat(num_images, 1, 1),
                pooled_prompt_embeds=pooled_prompt_embeds,
                t5_prompt_embeds=t5_prompt_embeds,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
            ).images
        except gr.Error:
            # Already user-facing; don't wrap it a second time.
            raise
        except Exception as e:
            print(f"Error during generation: {str(e)}")
            raise gr.Error(f"Generation failed: {str(e)}") from e


# Initialize the interface (models are loaded lazily on the first request)
interface = FluxInterface()

# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🎨 Qwen2vl-Flux Image Variation Demo
Upload an image and get AI-generated variations. You can optionally add a text prompt to guide the generation.
""")

    with gr.Row():
        with gr.Column(scale=1):
            # `tool=` was dropped: gr.Image no longer accepts it in Gradio 4+.
            input_image = gr.Image(
                label="Upload Image",
                type="pil",
                height=384,
                width=384,
            )
            prompt = gr.Textbox(
                label="Optional Text Prompt",
                placeholder="Enter text prompt here (optional)",
                lines=2,
            )

            with gr.Group():
                with gr.Row(equal_height=True):
                    with gr.Column(scale=1):
                        guidance = gr.Slider(
                            minimum=1,
                            maximum=10,
                            value=3.5,
                            step=0.5,
                            label="Guidance Scale",
                            info="Higher values follow prompt more closely",
                        )
                    with gr.Column(scale=1):
                        steps = gr.Slider(
                            minimum=1,
                            maximum=50,
                            value=28,
                            step=1,
                            label="Steps",
                            info="More steps = better quality but slower",
                        )

                with gr.Row(equal_height=True):
                    with gr.Column(scale=1):
                        num_images = gr.Slider(
                            minimum=1,
                            maximum=4,
                            value=2,
                            step=1,
                            label="Number of Images",
                            info="Generate multiple variations",
                        )
                    with gr.Column(scale=1):
                        seed = gr.Number(
                            label="Random Seed",
                            value=None,
                            precision=0,
                            info="Optional, for reproducibility",
                        )

            submit_btn = gr.Button(
                "Generate Variations",
                variant="primary",
                scale=1,
            )

        with gr.Column(scale=1):
            output_gallery = gr.Gallery(
                label="Generated Variations",
                columns=2,
                rows=2,
                height=768,
                object_fit="contain",
                show_label=True,
            )

    gr.Markdown("""
### Tips:
- Upload any image to get started
- Add a text prompt to guide the generation in a specific direction
- Adjust guidance scale to control how closely the output follows the prompt
- Increase steps for higher quality (but slower) generation
- Use the same seed to reproduce results
""")

    # Set up the generation function
    submit_btn.click(
        fn=interface.generate,
        inputs=[
            input_image,
            prompt,
            guidance,
            steps,
            num_images,
            seed,
        ],
        outputs=output_gallery,
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()