import os
import math

import cv2
import numpy as np
import torch
from torch import nn
from PIL import Image
from transformers import CLIPTokenizer, CLIPTextModel, AutoProcessor, T5EncoderModel, T5TokenizerFast, BitsAndBytesConfig
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler

from flux.transformer_flux import FluxTransformer2DModel
from flux.pipeline_flux_chameleon import FluxPipeline
from flux.pipeline_flux_img2img import FluxImg2ImgPipeline
from flux.pipeline_flux_inpaint import FluxInpaintPipeline
from flux.pipeline_flux_controlnet import FluxControlNetPipeline, FluxControlNetModel
from flux.pipeline_flux_controlnet_img2img import FluxControlNetImg2ImgPipeline
from flux.controlnet_flux import FluxMultiControlNetModel
from flux.pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline
from qwen2_vl.modeling_qwen2_vl import Qwen2VLSimplifiedModel
# 4-bit NF4 quantization config (bitsandbytes), applied to the largest weights when enabled.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
def get_model_path(model_name):
    """Get the full path for a model based on the checkpoints directory."""
    base_dir = os.getenv('CHECKPOINT_DIR', 'checkpoints')  # Allow environment variable override
    return os.path.join(base_dir, model_name)
# Model paths configuration
MODEL_PATHS = {
    'flux': get_model_path('flux'),
    'qwen2vl': get_model_path('qwen2-vl'),
    'controlnet': get_model_path('controlnet'),
    'depth_anything': {
        'path': get_model_path('depth-anything-v2'),
        'weights': 'depth_anything_v2_vitl.pth'
    },
    'anyline': {
        'path': get_model_path('anyline'),
        'weights': 'MTEED.pth'
    },
    'sam2': {
        'path': get_model_path('segment-anything-2'),
        'weights': 'sam2_hiera_large.pt',
        'config': 'sam2_hiera_l.yaml'
    }
}
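# Expected layout under CHECKPOINT_DIR (default "checkpoints"), inferred from the
# paths used above and in _init_base_models:
#   checkpoints/
#       flux/                        # diffusers-format FLUX checkpoint (tokenizer, text_encoder, vae, transformer, ...)
#       qwen2-vl/                    # Qwen2-VL weights plus connector.pt and t5_embedder.pt
#       controlnet/                  # FLUX ControlNet union weights
#       depth-anything-v2/depth_anything_v2_vitl.pth
#       anyline/MTEED.pth
#       segment-anything-2/sam2_hiera_large.pt and sam2_hiera_l.yaml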
ASPECT_RATIOS = {
    "1:1": (1024, 1024),
    "16:9": (1344, 768),
    "9:16": (768, 1344),
    "2.4:1": (1536, 640),
    "3:4": (896, 1152),
    "4:3": (1152, 896),
}
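# Every preset is close to one megapixel, with both dimensions divisible by 8,
# matching the divisibility constraint enforced in resize_image below.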
class Qwen2Connector(nn.Module):
    """Linear projection from Qwen2-VL hidden states (3584-dim) into the Flux context space (4096-dim)."""

    def __init__(self, input_dim=3584, output_dim=4096):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.linear(x)
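# Shape sketch (illustrative, not executed at import; the token count 729 is arbitrary):
#   x = torch.randn(1, 729, 3584)   # (batch, num_vision_tokens, qwen2vl_hidden)
#   y = Qwen2Connector()(x)         # -> (1, 729, 4096)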
class FluxModel:
    def __init__(self, is_turbo=False, device="cuda", required_features=None, is_quantization=True):
        """
        Initialize FluxModel with the requested feature set.

        Args:
            is_turbo: Enable turbo mode (qfloat8 transformer quantization) for faster inference
            device: Device to run the model on
            required_features: List of optional features to load: ['controlnet', 'depth', 'line', 'sam']
            is_quantization: Load the large weights with the 4-bit NF4 quantization config
        """
        self.device = torch.device(device)
        self.qkwargs = {"quantization_config": nf4_config} if is_quantization else {}
        self.dtype = torch.bfloat16
        if required_features is None:
            required_features = []
        self._line_detector_imported = False
        self._depth_model_imported = False
        self._sam_imported = False
        self._turbo_imported = False
        # Initialize base models (always required)
        self._init_base_models()
        # Initialize optional models based on requirements
        if 'controlnet' in required_features or any(f in required_features for f in ['depth', 'line']):
            self._init_controlnet()
        if 'depth' in required_features:
            self._init_depth_model()
        if 'line' in required_features:
            self._init_line_detector()
        if 'sam' in required_features:
            self._init_sam()
        if is_turbo:
            self._enable_turbo()
    def _init_base_models(self):
        """Initialize the core models that are always needed"""
        # Qwen2VL and connector initialization
        self.qwen2vl = Qwen2VLSimplifiedModel.from_pretrained(
            MODEL_PATHS['qwen2vl'],
            torch_dtype=self.dtype,
            **self.qkwargs
        )
        self.qwen2vl.requires_grad_(False).to(self.device)
        self.connector = Qwen2Connector(input_dim=3584, output_dim=4096)
        connector_path = os.path.join(MODEL_PATHS['qwen2vl'], "connector.pt")
        if os.path.exists(connector_path):
            connector_state_dict = torch.load(connector_path, map_location=self.device, weights_only=True)
            connector_state_dict = {k.replace('module.', ''): v for k, v in connector_state_dict.items()}
            self.connector.load_state_dict(connector_state_dict)
        self.connector.to(self.dtype).to(self.device)
        # Text encoders initialization
        self.tokenizer = CLIPTokenizer.from_pretrained(MODEL_PATHS['flux'], subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(MODEL_PATHS['flux'], subfolder="text_encoder")
        self.text_encoder_two = T5EncoderModel.from_pretrained(MODEL_PATHS['flux'], subfolder="text_encoder_2", **self.qkwargs)
        self.tokenizer_two = T5TokenizerFast.from_pretrained(MODEL_PATHS['flux'], subfolder="tokenizer_2")
        self.text_encoder.requires_grad_(False).to(self.dtype).to(self.device)
        # Possibly-quantized modules keep their quantized dtype; only move them to the device
        #self.text_encoder_two.requires_grad_(False).to(self.dtype).to(self.device)
        self.text_encoder_two.requires_grad_(False).to(self.device)
        # T5 context embedder
        self.t5_context_embedder = nn.Linear(4096, 3072)
        t5_embedder_path = os.path.join(MODEL_PATHS['qwen2vl'], "t5_embedder.pt")
        t5_embedder_state_dict = torch.load(t5_embedder_path, map_location=self.device, weights_only=True)
        self.t5_context_embedder.load_state_dict(t5_embedder_state_dict)
        self.t5_context_embedder.to(self.dtype).to(self.device)
        # Basic components
        self.noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(MODEL_PATHS['flux'], subfolder="scheduler", shift=1)
        self.vae = AutoencoderKL.from_pretrained(MODEL_PATHS['flux'], subfolder="vae")
        self.transformer = FluxTransformer2DModel.from_pretrained(MODEL_PATHS['flux'], subfolder="transformer", **self.qkwargs)
        self.vae.requires_grad_(False).to(self.dtype).to(self.device)
        #self.transformer.requires_grad_(False).to(self.dtype).to(self.device)
        self.transformer.requires_grad_(False).to(self.device)
    def _init_controlnet(self):
        """Initialize ControlNet model"""
        self.controlnet_union = FluxControlNetModel.from_pretrained(
            MODEL_PATHS['controlnet'],
            torch_dtype=torch.bfloat16
        )
        self.controlnet_union.requires_grad_(False).to(self.device)
        self.controlnet = FluxMultiControlNetModel([self.controlnet_union])
    def _init_depth_model(self):
        """Initialize Depth Anything V2 model"""
        if not self._depth_model_imported:
            from depth_anything_v2.dpt import DepthAnythingV2
            self._depth_model_imported = True
        self.depth_model = DepthAnythingV2(
            encoder='vitl',
            features=256,
            out_channels=[256, 512, 1024, 1024]
        )
        depth_weights = os.path.join(MODEL_PATHS['depth_anything']['path'],
                                     MODEL_PATHS['depth_anything']['weights'])
        self.depth_model.load_state_dict(torch.load(depth_weights, map_location=self.device))
        self.depth_model.requires_grad_(False).to(self.device)
    def _init_line_detector(self):
        """Initialize line detection model"""
        if not self._line_detector_imported:
            from controlnet_aux import AnylineDetector
            self._line_detector_imported = True
        self.anyline = AnylineDetector.from_pretrained(
            MODEL_PATHS['anyline']['path'],
            filename=MODEL_PATHS['anyline']['weights']
        )
        self.anyline.to(self.device)
    def _init_sam(self):
        """Initialize SAM2 model"""
        if not self._sam_imported:
            from sam2.build_sam import build_sam2
            from sam2.sam2_image_predictor import SAM2ImagePredictor
            self._sam_imported = True
        sam2_checkpoint = os.path.join(MODEL_PATHS['sam2']['path'],
                                       MODEL_PATHS['sam2']['weights'])
        model_cfg = os.path.join(MODEL_PATHS['sam2']['path'],
                                 MODEL_PATHS['sam2']['config'])
        self.sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=self.device)
        self.sam2_predictor = SAM2ImagePredictor(self.sam2_model)
    def _enable_turbo(self):
        """Enable turbo mode for faster inference"""
        if not self._turbo_imported:
            from optimum.quanto import freeze, qfloat8, quantize
            self._turbo_imported = True
        quantize(
            self.transformer,
            weights=qfloat8,
            exclude=[
                "*.norm", "*.norm1", "*.norm2", "*.norm2_context",
                "proj_out", "x_embedder", "norm_out", "context_embedder",
            ],
        )
        freeze(self.transformer)
    def generate_mask(self, image, input_points, input_labels):
        """
        Generate a segmentation mask with SAM2.

        Args:
            image: PIL Image or numpy array
            input_points: numpy array of shape (N, 2) with point coordinates
            input_labels: numpy array of shape (N,); 1 marks a foreground point, 0 a background point

        Returns:
            PIL Image: the highest-scoring mask
        """
        try:
            # Make sure the image is a numpy array
            if isinstance(image, Image.Image):
                image_array = np.array(image)
            else:
                image_array = image
            # Set the image on the predictor
            self.sam2_predictor.set_image(image_array)
            # Run prediction
            with torch.inference_mode():
                masks, scores, logits = self.sam2_predictor.predict(
                    point_coords=input_points,
                    point_labels=input_labels,
                    multimask_output=True,
                )
            # Return the highest-scoring mask
            best_mask_idx = scores.argmax()
            mask = masks[best_mask_idx]
            mask_image = Image.fromarray((mask * 255).astype(np.uint8))
            return mask_image
        except Exception as e:
            print(f"Mask generation failed: {str(e)}")
            raise
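    # Usage sketch (hypothetical values): a single foreground click at pixel (400, 300)
    #   points = np.array([[400, 300]])
    #   labels = np.array([1])
    #   mask = model.generate_mask(image, points, labels)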
    def recover_2d_shape(self, image_hidden_state, grid_thw):
        batch_size, num_tokens, hidden_dim = image_hidden_state.shape
        _, h, w = grid_thw
        # Qwen2-VL merges 2x2 patches per token, so the token grid is half the patch grid
        h_out = h // 2
        w_out = w // 2
        # Reshape to (batch_size, height, width, hidden_dim)
        reshaped = image_hidden_state.view(batch_size, h_out, w_out, hidden_dim)
        return reshaped
    def generate_attention_matrix(self, center_x, center_y, radius, image_shape):
        height, width = image_shape
        y, x = np.ogrid[:height, :width]
        # center_x / center_y are given as fractions of the grid size
        center_y, center_x = center_y * height, center_x * width
        distances = np.sqrt((x - center_x)**2 + (y - center_y)**2)
        # Linear falloff: 1 at the center, 0 at radius * min(height, width) and beyond
        attention = np.clip(1 - distances / (radius * min(height, width)), 0, 1)
        return attention
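    # Worked example: for a 32x32 token grid with center_x=0.5, center_y=0.5 and
    # radius=0.25, the weight is 1.0 at grid cell (16, 16) and falls linearly to
    # 0.0 at a distance of 0.25 * min(32, 32) = 8 cells, staying 0 beyond that.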
    def apply_attention(self, image_hidden_state, image_grid_thw, center_x, center_y, radius):
        qwen2_2d_image_embedding = self.recover_2d_shape(image_hidden_state, tuple(image_grid_thw.tolist()[0]))
        attention_matrix = self.generate_attention_matrix(
            center_x, center_y, radius,
            (qwen2_2d_image_embedding.size(1), qwen2_2d_image_embedding.size(2))
        )
        attention_tensor = torch.from_numpy(attention_matrix).to(self.dtype).unsqueeze(0).unsqueeze(-1)
        qwen2_2d_image_embedding = qwen2_2d_image_embedding * attention_tensor.to(self.device)
        return qwen2_2d_image_embedding.view(1, -1, qwen2_2d_image_embedding.size(3))
    def compute_text_embeddings(self, prompt):
        with torch.no_grad():
            text_inputs = self.tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt")
            text_input_ids = text_inputs.input_ids.to(self.device)
            prompt_embeds = self.text_encoder(text_input_ids, output_hidden_states=False)
            pooled_prompt_embeds = prompt_embeds.pooler_output
        return pooled_prompt_embeds.to(self.dtype)
    def compute_t5_text_embeddings(
        self,
        max_sequence_length=256,
        prompt=None,
        num_images_per_prompt=1,
        device=None,
    ):
        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)
        text_inputs = self.tokenizer_two(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        prompt_embeds = self.text_encoder_two(text_input_ids.to(device))[0]
        dtype = self.text_encoder_two.dtype
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
        _, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
        return prompt_embeds
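    # The result has shape (batch_size * num_images_per_prompt, max_sequence_length, 4096);
    # generate() then projects it into the transformer's 3072-dim context space via
    # self.t5_context_embedder before passing it on as t5_prompt_embeds.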
    def process_image(self, image):
        message = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": "Describe this image."},
                ]
            }
        ]
        text = self.qwen2vl_processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
        with torch.no_grad():
            inputs = self.qwen2vl_processor(text=[text], images=[image], padding=True, return_tensors="pt").to(self.device)
            output_hidden_state, image_token_mask, image_grid_thw = self.qwen2vl(**inputs)
            image_hidden_state = output_hidden_state[image_token_mask].view(1, -1, output_hidden_state.size(-1))
        return image_hidden_state, image_grid_thw
    def resize_image(self, img, max_pixels=1050000):
        # Make sure the input is a PIL Image
        if not isinstance(img, Image.Image):
            img = Image.fromarray(img)
        width, height = img.size
        num_pixels = width * height
        if num_pixels > max_pixels:
            scale = math.sqrt(max_pixels / num_pixels)
            new_width = int(width * scale)
            new_height = int(height * scale)
            # Round width and height down to multiples of 8
            new_width = new_width - (new_width % 8)
            new_height = new_height - (new_height % 8)
            img = img.resize((new_width, new_height), Image.LANCZOS)
        else:
            # Even when no downscaling is needed, the dimensions must still be divisible by 8
            new_width = width - (width % 8)
            new_height = height - (height % 8)
            if new_width != width or new_height != height:
                img = img.resize((new_width, new_height), Image.LANCZOS)
        return img
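    # Worked example: a 4000x3000 input has 12,000,000 pixels, so
    # scale = sqrt(1050000 / 12000000) ~= 0.2958, giving 1183x887, which is then
    # rounded down to multiples of 8 -> 1176x880 (1,034,880 pixels <= max_pixels).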
    def generate_depth_map(self, image):
        """Generate depth map using Depth Anything V2"""
        # Convert PIL to numpy array
        image_np = np.array(image)
        # Convert RGB to BGR for cv2
        image_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
        # Generate depth map
        with torch.no_grad():
            depth = self.depth_model.infer_image(image_bgr)
        # Normalize depth to 0-1 range
        depth_norm = (depth - depth.min()) / (depth.max() - depth.min())
        # Convert to RGB image
        depth_rgb = (depth_norm * 255).astype(np.uint8)
        depth_rgb = cv2.cvtColor(depth_rgb, cv2.COLOR_GRAY2RGB)
        return Image.fromarray(depth_rgb)
    def generate(self, input_image_a, input_image_b=None, prompt="", guidance_scale=3.5, num_inference_steps=28,
                 aspect_ratio="1:1", center_x=None, center_y=None, radius=None, mode="variation",
                 denoise_strength=0.8, mask_image=None, imageCount=2,
                 line_mode=True, depth_mode=True, line_strength=0.4, depth_strength=0.2):
        batch_size = imageCount
        if aspect_ratio not in ASPECT_RATIOS:
            raise ValueError(f"Invalid aspect ratio. Choose from {list(ASPECT_RATIOS.keys())}")
        width, height = ASPECT_RATIOS[aspect_ratio]
        pooled_prompt_embeds = self.compute_text_embeddings(prompt="")
        t5_prompt_embeds = None
        if prompt != "":
            self.qwen2vl_processor = AutoProcessor.from_pretrained(MODEL_PATHS['qwen2vl'], min_pixels=256*28*28, max_pixels=256*28*28)
            t5_prompt_embeds = self.compute_t5_text_embeddings(prompt=prompt, device=self.device).to(self.dtype)
            t5_prompt_embeds = self.t5_context_embedder(t5_prompt_embeds)
        else:
            self.qwen2vl_processor = AutoProcessor.from_pretrained(MODEL_PATHS['qwen2vl'], min_pixels=512*28*28, max_pixels=512*28*28)
        qwen2_hidden_state_a, image_grid_thw_a = self.process_image(input_image_a)
        # Apply the attention mechanism only when all attention parameters are provided
        if mode == "variation":
            if center_x is not None and center_y is not None and radius is not None:
                qwen2_hidden_state_a = self.apply_attention(qwen2_hidden_state_a, image_grid_thw_a, center_x, center_y, radius)
            qwen2_hidden_state_a = self.connector(qwen2_hidden_state_a)
        if mode == "img2img" or mode == "inpaint":
            if input_image_b:
                qwen2_hidden_state_b, image_grid_thw_b = self.process_image(input_image_b)
                if center_x is not None and center_y is not None and radius is not None:
                    qwen2_hidden_state_b = self.apply_attention(qwen2_hidden_state_b, image_grid_thw_b, center_x, center_y, radius)
                qwen2_hidden_state_b = self.connector(qwen2_hidden_state_b)
            else:
                qwen2_hidden_state_a = self.connector(qwen2_hidden_state_a)
                qwen2_hidden_state_b = None
        if mode == "controlnet" or mode == "controlnet-inpaint":
            qwen2_hidden_state_b = None
            if input_image_b:
                qwen2_hidden_state_b, image_grid_thw_b = self.process_image(input_image_b)
                if center_x is not None and center_y is not None and radius is not None:
                    qwen2_hidden_state_b = self.apply_attention(qwen2_hidden_state_b, image_grid_thw_b, center_x, center_y, radius)
                qwen2_hidden_state_b = self.connector(qwen2_hidden_state_b)
            qwen2_hidden_state_a = self.connector(qwen2_hidden_state_a)
        #############################
        # IMAGE GENERATION
        #############################
        if mode == "variation":
            # Initialize different pipelines
            pipeline = FluxPipeline(
                transformer=self.transformer,
                scheduler=self.noise_scheduler,
                vae=self.vae,
                text_encoder=self.text_encoder,
                tokenizer=self.tokenizer,
            )
            gen_images = pipeline(
                prompt_embeds=qwen2_hidden_state_a.repeat(batch_size, 1, 1),
                t5_prompt_embeds=t5_prompt_embeds.repeat(batch_size, 1, 1) if t5_prompt_embeds is not None else None,
                pooled_prompt_embeds=pooled_prompt_embeds,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                height=height,
                width=width,
            ).images
        #############################
        # IMAGE-TO-IMAGE
        #############################
        elif mode == "img2img":
            input_image_a = self.resize_image(input_image_a)
            width, height = input_image_a.size
            img2img_pipeline = FluxImg2ImgPipeline(
                transformer=self.transformer,
                scheduler=self.noise_scheduler,
                vae=self.vae,
                text_encoder=self.text_encoder,
                tokenizer=self.tokenizer,
            )
            gen_images = img2img_pipeline(
                image=input_image_a,
                strength=denoise_strength,
                prompt_embeds=qwen2_hidden_state_b.repeat(batch_size, 1, 1) if qwen2_hidden_state_b is not None else qwen2_hidden_state_a.repeat(batch_size, 1, 1),
                t5_prompt_embeds=t5_prompt_embeds.repeat(batch_size, 1, 1) if t5_prompt_embeds is not None else None,
                pooled_prompt_embeds=pooled_prompt_embeds,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                height=height,
                width=width,
            ).images
        #############################
        # INPAINTING
        #############################
        elif mode == "inpaint":
            if mask_image is None:
                raise ValueError("Mask image is required for inpainting mode")
            input_image_a = self.resize_image(input_image_a)
            mask_image = self.resize_image(mask_image)
            width, height = input_image_a.size
            inpaint_pipeline = FluxInpaintPipeline(
                transformer=self.transformer,
                scheduler=self.noise_scheduler,
                vae=self.vae,
                text_encoder=self.text_encoder,
                tokenizer=self.tokenizer,
            )
            gen_images = inpaint_pipeline(
                image=input_image_a,
                mask_image=mask_image,
                strength=denoise_strength,
                prompt_embeds=qwen2_hidden_state_b.repeat(batch_size, 1, 1) if qwen2_hidden_state_b is not None else qwen2_hidden_state_a.repeat(batch_size, 1, 1),
                t5_prompt_embeds=t5_prompt_embeds.repeat(batch_size, 1, 1) if t5_prompt_embeds is not None else None,
                pooled_prompt_embeds=pooled_prompt_embeds,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                height=height,
                width=width,
            ).images
        #############################
        # CONTROLNET
        #############################
        elif mode == "controlnet":
            input_image_a = self.resize_image(input_image_a)
            width, height = input_image_a.size
            controlnet_pipeline = FluxControlNetImg2ImgPipeline(
                transformer=self.transformer,
                scheduler=self.noise_scheduler,
                vae=self.vae,
                text_encoder=self.text_encoder,
                tokenizer=self.tokenizer,
                controlnet=self.controlnet,
            )
            # Prepare lists of control images, modes, and conditioning scales
            control_images = []
            control_modes = []
            conditioning_scales = []
            # Add control modes according to the user's selection
            if depth_mode:
                control_image_depth = self.generate_depth_map(input_image_a)
                control_images.append(control_image_depth)
                control_modes.append(2)  # depth mode
                conditioning_scales.append(depth_strength)
            if line_mode:
                control_image_canny = self.anyline(input_image_a, detect_resolution=1280)
                control_images.append(control_image_canny)
                control_modes.append(0)  # line mode
                conditioning_scales.append(line_strength)
            # If neither mode is enabled, fall back to line + depth
            if not line_mode and not depth_mode:
                control_image_depth = self.generate_depth_map(input_image_a)
                control_image_canny = self.anyline(input_image_a, detect_resolution=1280)
                control_images = [control_image_depth, control_image_canny]
                control_modes = [2, 0]
                conditioning_scales = [0.2, 0.4]
            # Truncate both embeddings to the shorter sequence length
            if qwen2_hidden_state_b is not None:
                qwen2_hidden_state_b = qwen2_hidden_state_b[:, :qwen2_hidden_state_a.shape[1], :]
                qwen2_hidden_state_a = qwen2_hidden_state_a[:, :qwen2_hidden_state_b.shape[1], :]
            gen_images = controlnet_pipeline(
                image=input_image_a,
                strength=denoise_strength,
                control_image=control_images,
                control_mode=control_modes,
                controlnet_conditioning_scale=conditioning_scales,
                prompt_embeds=qwen2_hidden_state_b.repeat(batch_size, 1, 1) if qwen2_hidden_state_b is not None else qwen2_hidden_state_a.repeat(batch_size, 1, 1),
                t5_prompt_embeds=t5_prompt_embeds.repeat(batch_size, 1, 1) if t5_prompt_embeds is not None else None,
                prompt_embeds_control=qwen2_hidden_state_a.repeat(batch_size, 1, 1),
                pooled_prompt_embeds=pooled_prompt_embeds,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                height=height,
                width=width,
            ).images
        #############################
        # CONTROLNET INPAINT
        #############################
        elif mode == "controlnet-inpaint":
            input_image_a = self.resize_image(input_image_a)
            mask_image = self.resize_image(mask_image)
            width, height = input_image_a.size
            controlnet_pipeline = FluxControlNetInpaintPipeline(
                transformer=self.transformer,
                scheduler=self.noise_scheduler,
                vae=self.vae,
                text_encoder=self.text_encoder,
                tokenizer=self.tokenizer,
                controlnet=self.controlnet,
            )
            # Prepare lists of control images, modes, and conditioning scales
            control_images = []
            control_modes = []
            conditioning_scales = []
            # Add control modes according to the user's selection
            if depth_mode:
                control_image_depth = self.generate_depth_map(input_image_a)
                control_images.append(control_image_depth)
                control_modes.append(2)  # depth mode
                conditioning_scales.append(depth_strength)
            if line_mode:
                control_image_canny = self.anyline(input_image_a, detect_resolution=1280)
                control_images.append(control_image_canny)
                control_modes.append(0)  # line mode
                conditioning_scales.append(line_strength)
            # If neither mode is enabled, fall back to line + depth
            if not line_mode and not depth_mode:
                control_image_depth = self.generate_depth_map(input_image_a)
                control_image_canny = self.anyline(input_image_a, detect_resolution=1280)
                control_images = [control_image_depth, control_image_canny]
                control_modes = [2, 0]
                conditioning_scales = [0.2, 0.4]
            # Truncate both embeddings to the shorter sequence length
            if qwen2_hidden_state_b is not None:
                qwen2_hidden_state_b = qwen2_hidden_state_b[:, :qwen2_hidden_state_a.shape[1], :]
                qwen2_hidden_state_a = qwen2_hidden_state_a[:, :qwen2_hidden_state_b.shape[1], :]
            gen_images = controlnet_pipeline(
                image=input_image_a,
                mask_image=mask_image,
                control_image=control_images,
                control_mode=control_modes,
                controlnet_conditioning_scale=conditioning_scales,
                strength=denoise_strength,
                prompt_embeds=qwen2_hidden_state_b.repeat(batch_size, 1, 1) if qwen2_hidden_state_b is not None else qwen2_hidden_state_a.repeat(batch_size, 1, 1),
                t5_prompt_embeds=t5_prompt_embeds.repeat(batch_size, 1, 1) if t5_prompt_embeds is not None else None,
                prompt_embeds_control=qwen2_hidden_state_a.repeat(batch_size, 1, 1),
                pooled_prompt_embeds=pooled_prompt_embeds,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                height=height,
                width=width,
            ).images
        else:
            raise ValueError(f"Invalid mode: {mode}")
        return gen_images
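

if __name__ == "__main__":
    # Minimal usage sketch. "input.jpg", the prompt, and the feature list are
    # placeholders; checkpoints are expected under CHECKPOINT_DIR as described above.
    model = FluxModel(device="cuda", required_features=["controlnet", "depth", "line"])
    source = Image.open("input.jpg").convert("RGB")
    images = model.generate(
        input_image_a=source,
        prompt="a watercolor rendition of the scene",
        mode="controlnet",
        denoise_strength=0.8,
        imageCount=1,
    )
    images[0].save("output.png")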