import spaces  # Hugging Face ZeroGPU helper; imported first, before torch

import os
import re

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf
from PIL import Image, ImageDraw
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from huggingface_hub import snapshot_download, hf_hub_download
from diffusers import (
    ControlNetModel,
    DiffusionPipeline,
    DDIMScheduler,
    AutoencoderKL,
)

from inference.manganinjia_pipeline import MangaNinjiaPipeline
from src.models.mutual_self_attention_multi_scale import ReferenceAttentionControl
from src.models.unet_2d_condition import UNet2DConditionModel
from src.models.refunet_2d_condition import RefUNet2DConditionModel
from src.point_network import PointNet
from src.annotator.lineart import BatchLineartDetector

val_configs = OmegaConf.load('./configs/inference.yaml')

# === download the checkpoints ===
os.makedirs("checkpoints", exist_ok=True)

# Subdirectories to create inside "checkpoints"
subfolders = ["StableDiffusion", "models", "MangaNinjia"]
for subfolder in subfolders:
    os.makedirs(os.path.join("checkpoints", subfolder), exist_ok=True)

# Subdirectories to create inside "checkpoints/models"
models_subfolders = ["clip-vit-large-patch14", "control_v11p_sd15_lineart", "Annotators"]
for subfolder in models_subfolders:
    os.makedirs(os.path.join("checkpoints/models", subfolder), exist_ok=True)

snapshot_download(
    repo_id="stable-diffusion-v1-5/stable-diffusion-v1-5",
    local_dir="./checkpoints/StableDiffusion",
)
snapshot_download(
    repo_id="openai/clip-vit-large-patch14",
    local_dir="./checkpoints/models/clip-vit-large-patch14",
)
snapshot_download(
    repo_id="lllyasviel/control_v11p_sd15_lineart",
    local_dir="./checkpoints/models/control_v11p_sd15_lineart",
)
hf_hub_download(
    repo_id="lllyasviel/Annotators",
    filename="sk_model.pth",
    local_dir="./checkpoints/models/Annotators",
)
snapshot_download(
    repo_id="Johanan0528/MangaNinjia",
    local_dir="./checkpoints/MangaNinjia",
)

# === load the checkpoints ===
pretrained_model_name_or_path = val_configs.model_path.pretrained_model_name_or_path
refnet_clip_vision_encoder_path = val_configs.model_path.clip_vision_encoder_path
controlnet_clip_vision_encoder_path = val_configs.model_path.clip_vision_encoder_path
controlnet_model_name_or_path = val_configs.model_path.controlnet_model_name
annotator_ckpts_path = val_configs.model_path.annotator_ckpts_path
output_root = val_configs.inference_config.output_path
device = val_configs.inference_config.device

# Line-art extractor used when the target image is RGB rather than line art
preprocessor = BatchLineartDetector(annotator_ckpts_path)

in_channels_reference_unet = 4
in_channels_denoising_unet = 4
in_channels_controlnet = 4

noise_scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
denoising_unet = UNet2DConditionModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="unet",
    in_channels=in_channels_denoising_unet,
    low_cpu_mem_usage=False,
    ignore_mismatched_sizes=True,
)
reference_unet = RefUNet2DConditionModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="unet",
    in_channels=in_channels_reference_unet,
    low_cpu_mem_usage=False,
    ignore_mismatched_sizes=True,
)
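# The reference branch and the ControlNet branch are configured with the same CLIP path
# (clip_vision_encoder_path), so the tokenizer, text encoder, and image encoder below are
# loaded twice from the same checkpoint.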
refnet_tokenizer = CLIPTokenizer.from_pretrained(refnet_clip_vision_encoder_path)
refnet_text_encoder = CLIPTextModel.from_pretrained(refnet_clip_vision_encoder_path)
refnet_image_enc = CLIPVisionModelWithProjection.from_pretrained(refnet_clip_vision_encoder_path)
controlnet = ControlNetModel.from_pretrained(
    controlnet_model_name_or_path,
    in_channels=in_channels_controlnet,
    low_cpu_mem_usage=False,
    ignore_mismatched_sizes=True,
)
controlnet_tokenizer = CLIPTokenizer.from_pretrained(controlnet_clip_vision_encoder_path)
controlnet_text_encoder = CLIPTextModel.from_pretrained(controlnet_clip_vision_encoder_path)
controlnet_image_enc = CLIPVisionModelWithProjection.from_pretrained(controlnet_clip_vision_encoder_path)
point_net = PointNet()

# Reference attention: the reference UNet registers its attention features ("write" mode)
# and the denoising UNet consumes them ("read" mode).
reference_control_writer = ReferenceAttentionControl(
    reference_unet,
    do_classifier_free_guidance=False,
    mode="write",
    fusion_blocks="full",
)
reference_control_reader = ReferenceAttentionControl(
    denoising_unet,
    do_classifier_free_guidance=False,
    mode="read",
    fusion_blocks="full",
)

# Load the MangaNinjia weights on top of the base models
controlnet.load_state_dict(
    torch.load(val_configs.model_path.manga_control_model_path, map_location="cpu"),
    strict=False,
)
point_net.load_state_dict(
    torch.load(val_configs.model_path.point_net_path, map_location="cpu"),
    strict=False,
)
reference_unet.load_state_dict(
    torch.load(val_configs.model_path.manga_reference_model_path, map_location="cpu"),
    strict=False,
)
denoising_unet.load_state_dict(
    torch.load(val_configs.model_path.manga_main_model_path, map_location="cpu"),
    strict=False,
)

pipe = MangaNinjiaPipeline(
    reference_unet=reference_unet,
    controlnet=controlnet,
    denoising_unet=denoising_unet,
    vae=vae,
    refnet_tokenizer=refnet_tokenizer,
    refnet_text_encoder=refnet_text_encoder,
    refnet_image_encoder=refnet_image_enc,
    controlnet_tokenizer=controlnet_tokenizer,
    controlnet_text_encoder=controlnet_text_encoder,
    controlnet_image_encoder=controlnet_image_enc,
    scheduler=noise_scheduler,
    point_net=point_net,
)
pipe = pipe.to(torch.device(device))
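# The coordinate textboxes in the UI hold strings like "[[120, 88], [300, 256]]", one
# (row, col) pair per clicked point; point i on the reference is matched with point i on
# the target. For example, string_to_np_array("[[120, 88], [300, 256]]") returns
# np.array([[120, 88], [300, 256]]) with shape (2, 2).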
def string_to_np_array(coord_string):
    coord_string = coord_string.strip('[]')
    coords = re.findall(r'\d+', coord_string)
    coords = list(map(int, coords))
    coord_array = np.array(coords).reshape(-1, 2)
    return coord_array


def infer_single(is_lineart, ref_image, target_image, output_coords_ref, output_coords_base,
                 seed=-1, num_inference_steps=20, guidance_scale_ref=9, guidance_scale_point=15):
    """
    ref_image / target_image: PIL RGB images
    output_coords_ref / output_coords_base: coordinate strings of matched points from the UI
    """
    generator = torch.cuda.manual_seed(seed)  # assumes a CUDA device (generation runs under @spaces.GPU)
    # Encode the matched points as two 512x512 index maps: pixel i in the reference map
    # and pixel i in the target map carry the same label (index + 1).
    matrix1 = np.zeros((512, 512), dtype=np.uint8)
    matrix2 = np.zeros((512, 512), dtype=np.uint8)
    output_coords_ref = string_to_np_array(output_coords_ref)
    output_coords_base = string_to_np_array(output_coords_base)
    for index, (coords_ref, coords_base) in enumerate(zip(output_coords_ref, output_coords_base)):
        y1, x1 = coords_ref
        y2, x2 = coords_base
        matrix1[y1, x1] = index + 1
        matrix2[y2, x2] = index + 1
    point_ref = torch.from_numpy(matrix1).unsqueeze(0).unsqueeze(0)
    point_main = torch.from_numpy(matrix2).unsqueeze(0).unsqueeze(0)
    preprocessor.to(device, dtype=torch.float32)
    pipe_out = pipe(
        is_lineart,
        ref_image,
        target_image,
        target_image,
        denosing_steps=num_inference_steps,
        processing_res=512,
        match_input_res=True,
        batch_size=1,
        show_progress_bar=True,
        guidance_scale_ref=guidance_scale_ref,
        guidance_scale_point=guidance_scale_point,
        preprocessor=preprocessor,
        generator=generator,
        point_ref=point_ref,
        point_main=point_main,
    )
    return pipe_out


def inference_single_image(ref_image, tar_image, ddim_steps, scale_ref, scale_point, seed,
                           output_coords1, output_coords2, is_lineart):
    if seed == -1:
        seed = np.random.randint(10000)
    pipe_out = infer_single(
        is_lineart,
        ref_image,
        tar_image,
        output_coords_ref=output_coords1,
        output_coords_base=output_coords2,
        seed=seed,
        num_inference_steps=ddim_steps,
        guidance_scale_ref=scale_ref,
        guidance_scale_point=scale_point,
    )
    return pipe_out


# Click state shared by the two image canvases
clicked_points_img1 = []
clicked_points_img2 = []
current_img_idx = 0
max_clicks = 14
point_size = 8
colors = [(255, 0, 0), (0, 255, 0)]


# Process images: resize them to 512x512
def process_image(ref, base):
    ref_resized = cv2.resize(ref, (512, 512))  # Note OpenCV resize order is (width, height)
    base_resized = cv2.resize(base, (512, 512))
    return ref_resized, base_resized


# Handle click events: clicks are expected to alternate between the reference and the
# target canvas (see the tutorial), so current_img_idx is toggled after every click
# rather than derived from which image was clicked.
def get_select_coords(img1, img2, evt: gr.SelectData):
    global clicked_points_img1, clicked_points_img2, current_img_idx
    click_coords = (evt.index[1], evt.index[0])  # store as (row, col)
    if current_img_idx == 0:
        clicked_points_img1.append(click_coords)
        if len(clicked_points_img1) > max_clicks:
            clicked_points_img1 = []
        current_img = img1
        clicked_points = clicked_points_img1
    else:
        clicked_points_img2.append(click_coords)
        if len(clicked_points_img2) > max_clicks:
            clicked_points_img2 = []
        current_img = img2
        clicked_points = clicked_points_img2
    current_img_idx = 1 - current_img_idx

    img_pil = Image.fromarray(current_img.astype('uint8'))
    draw = ImageDraw.Draw(img_pil)
    for idx, point in enumerate(clicked_points):
        x, y = point
        color = colors[current_img_idx]
        for dx in range(-point_size, point_size + 1):
            for dy in range(-point_size, point_size + 1):
                if 0 <= y + dy < img_pil.size[0] and 0 <= x + dx < img_pil.size[1]:
                    draw.point((y + dy, x + dx), fill=color)  # PIL expects (col, row)
    img_out = np.array(img_pil)
    coord_array = np.array([(x, y) for x, y in clicked_points])
    return img_out, coord_array
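# Undo: get_select_coords has already advanced current_img_idx past the image that was
# clicked last, so undo_last_point toggles it back first, pops the most recent point from
# that image's list, and then redraws both canvases from the remaining points.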
def undo_last_point(ref, base):
    global clicked_points_img1, clicked_points_img2, current_img_idx
    current_img_idx = 1 - current_img_idx
    if current_img_idx == 0 and clicked_points_img1:
        clicked_points_img1.pop()  # Undo last point in ref
    elif current_img_idx == 1 and clicked_points_img2:
        clicked_points_img2.pop()  # Undo last point in base

    # After removing the last point, redraw both images without it
    if current_img_idx == 0:
        current_img = ref
        current_img_other = base
        clicked_points = clicked_points_img1
        clicked_points_other = clicked_points_img2
    else:
        current_img = base
        current_img_other = ref
        clicked_points = clicked_points_img2
        clicked_points_other = clicked_points_img1

    # Redraw the image that lost a point
    img_pil = Image.fromarray(current_img.astype('uint8'))
    draw = ImageDraw.Draw(img_pil)
    for idx, point in enumerate(clicked_points):
        x, y = point
        color = colors[current_img_idx]
        for dx in range(-point_size, point_size + 1):
            for dy in range(-point_size, point_size + 1):
                if 0 <= y + dy < img_pil.size[0] and 0 <= x + dx < img_pil.size[1]:
                    draw.point((y + dy, x + dx), fill=color)
    img_out = np.array(img_pil)

    # Redraw the other image unchanged
    img_pil_other = Image.fromarray(current_img_other.astype('uint8'))
    draw_other = ImageDraw.Draw(img_pil_other)
    for idx, point in enumerate(clicked_points_other):
        x, y = point
        color = colors[1 - current_img_idx]
        for dx in range(-point_size, point_size + 1):
            for dy in range(-point_size, point_size + 1):
                if 0 <= y + dy < img_pil_other.size[0] and 0 <= x + dx < img_pil_other.size[1]:
                    draw_other.point((y + dy, x + dx), fill=color)
    img_out_other = np.array(img_pil_other)

    coord_array = np.array([(x, y) for x, y in clicked_points])
    updated_coords = str(coord_array.tolist())

    # Return the updated images and coordinate strings in (ref, base) order
    if current_img_idx == 0:
        # The reference image lost a point
        coord_array2 = np.array([(x, y) for x, y in clicked_points_img2])
        updated_coords2 = str(coord_array2.tolist())
        return img_out, updated_coords, img_out_other, updated_coords2
    else:
        # The base image lost a point
        coord_array1 = np.array([(x, y) for x, y in clicked_points_img1])
        updated_coords1 = str(coord_array1.tolist())
        return img_out_other, updated_coords1, img_out, updated_coords


# Main function to run the image processing
@spaces.GPU
def run_local(ref, base, *args, progress=gr.Progress(track_tqdm=True)):
    image = Image.fromarray(base)
    ref_image = Image.fromarray(ref)
    pipe_out = inference_single_image(ref_image.copy(), image.copy(), *args)
    to_save_dict = pipe_out.to_save_dict
    to_save_dict['edit2'] = pipe_out.img_pil
    # Show the colorized result and the 'edge2_black' visualization in the output gallery
    return [to_save_dict['edit2'], to_save_dict['edge2_black']]
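# --- Gradio UI ---
# Layout: an output gallery with advanced options (steps, reference/point guidance, line-art
# toggle, seed), a tutorial, an upload row with a "Process Images" button that resizes both
# inputs to 512x512, two click canvases with an Undo button for matching points, example
# image pairs, and a "Generate" button wired to run_local.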
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# MangaNinja: Line Art Colorization with Precise Reference Following")
        with gr.Row():
            baseline_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", columns=1, height=768)
        with gr.Accordion("Advanced Options", open=True):
            num_samples = 1
            ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=50, step=1)
            scale_ref = gr.Slider(label="Reference guidance", minimum=0, maximum=30.0, value=9, step=0.1)
            scale_point = gr.Slider(label="Point guidance", minimum=0, maximum=30.0, value=15, step=0.1)
            is_lineart = gr.Checkbox(label="Input is line art", value=False)
            seed = gr.Slider(label="Seed", minimum=-1, maximum=999999999, step=1, value=-1)

        gr.Markdown("### Tutorial")
        gr.Markdown("1. Upload the reference image and the target image. The target image can be provided in two ways: upload an RGB image and the model will extract the line art automatically, or upload the line art directly and check the 'Input is line art' option.")
        gr.Markdown("2. Click 'Process Images' to resize both images to 512x512.")
        gr.Markdown("3. (Optional) **Starting from the reference image**, **alternately** click on the reference and target images to define matching points. Use 'Undo' to revert the last click.")
        gr.Markdown("4. Click 'Generate' to produce the result.")

        gr.Markdown("# Upload the reference image and target image")
        with gr.Row():
            ref = gr.Image(label="Reference Image")
            base = gr.Image(label="Target Image")
        gr.Button("Process Images").click(process_image, inputs=[ref, base], outputs=[ref, base])
        with gr.Row():
            output_img1 = gr.Image(label="Reference Output")
            output_coords1 = gr.Textbox(lines=2, label="Clicked Coordinates Image 1 (list of [row, col] pairs)")
            output_img2 = gr.Image(label="Base Output")
            output_coords2 = gr.Textbox(lines=2, label="Clicked Coordinates Image 2 (list of [row, col] pairs)")

        # Image click handlers
        ref.select(get_select_coords, [ref, base], [output_img1, output_coords1])
        base.select(get_select_coords, [ref, base], [output_img2, output_coords2])

        # Undo button
        undo_button = gr.Button("Undo")
        undo_button.click(
            undo_last_point,
            inputs=[ref, base],
            outputs=[output_img1, output_coords1, output_img2, output_coords2],
        )

        run_local_button = gr.Button("Generate")
        with gr.Row():
            gr.Examples(
                examples=[
                    ['test_cases/hz0.png', 'test_cases/manga_target_examples/target_1.jpg'],
                    ['test_cases/more_cases/az0.png', 'test_cases/manga_target_examples/target_2.jpg'],
                    ['test_cases/more_cases/hi0.png', 'test_cases/manga_target_examples/target_3.jpg'],
                    ['test_cases/more_cases/kn0.jpg', 'test_cases/manga_target_examples/target_4.jpg'],
                    ['test_cases/more_cases/rk0.jpg', 'test_cases/manga_target_examples/target_5.jpg'],
                ],
                inputs=[ref, base],
                cache_examples=False,
                examples_per_page=100,
            )

    run_local_button.click(
        fn=run_local,
        inputs=[ref, base, ddim_steps, scale_ref, scale_point, seed, output_coords1, output_coords2, is_lineart],
        outputs=[baseline_gallery],
    )

demo.launch(show_api=False, show_error=True)