import argparse
import logging
import os

import numpy as np
import torch
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    DDIMScheduler,
)

from inference.manganinjia_pipeline import MangaNinjiaPipeline
from src.models.mutual_self_attention_multi_scale import ReferenceAttentionControl
from src.models.unet_2d_condition import UNet2DConditionModel
from src.models.refunet_2d_condition import RefUNet2DConditionModel
from src.point_network import PointNet
from src.annotator.lineart import BatchLineartDetector
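
# Example invocation (a sketch: the script name and every path below are
# placeholders, not values shipped with this file):
#
#   python infer_manganinjia.py \
#       --pretrained_model_name_or_path /path/to/base-diffusion-model \
#       --image_encoder_path /path/to/clip-model \
#       --controlnet_model_name_or_path /path/to/controlnet \
#       --annotator_ckpts_path /path/to/annotator/ckpts \
#       --manga_reference_unet_path /path/to/reference_unet.pth \
#       --manga_main_model_path /path/to/denoising_unet.pth \
#       --manga_controlnet_model_path /path/to/manga_controlnet.pth \
#       --point_net_path /path/to/point_net.pth \
#       --input_reference_paths ref.png \
#       --input_lineart_paths lineart.png \
#       --output_dir ./output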
if "__main__" == __name__: | |
logging.basicConfig(level=logging.INFO) | |
# -------------------- Arguments -------------------- | |
parser = argparse.ArgumentParser( | |
description="Run single-image MangaNinjia" | |
) | |
parser.add_argument( | |
"--output_dir", type=str, required=True, help="Output directory." | |
) | |
# inference setting | |
parser.add_argument( | |
"--denoise_steps", | |
type=int, | |
default=50, # quantitative evaluation uses 50 steps | |
help="Diffusion denoising steps, more steps results in higher accuracy but slower inference speed.", | |
) | |
    # random seed
    parser.add_argument("--seed", type=int, default=None, help="Random seed.")
    # model paths
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        required=True,
        help="Path to the pretrained base diffusion model, or a model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--image_encoder_path",
        type=str,
        required=True,
        help="Path to the CLIP model providing the tokenizer, text encoder, and image encoder, or a model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--controlnet_model_name_or_path", type=str, required=True, help="Path to the original ControlNet weights."
    )
    parser.add_argument(
        "--annotator_ckpts_path", type=str, required=True, help="Path to the lineart-detector (annotator) checkpoints."
    )
    parser.add_argument(
        "--manga_reference_unet_path", type=str, required=True, help="Path to the MangaNinjia reference UNet weights."
    )
    parser.add_argument(
        "--manga_main_model_path", type=str, required=True, help="Path to the MangaNinjia denoising (main) UNet weights."
    )
    parser.add_argument(
        "--manga_controlnet_model_path", type=str, required=True, help="Path to the MangaNinjia ControlNet weights."
    )
    parser.add_argument(
        "--point_net_path", type=str, required=True, help="Path to the PointNet weights."
    )
    parser.add_argument(
        "--input_reference_paths",
        type=str,
        nargs="+",
        default=None,
        help="Paths to the reference (color) images.",
    )
    parser.add_argument(
        "--input_lineart_paths",
        type=str,
        nargs="+",
        default=None,
        help="Paths to the target line-art images, one per reference image.",
    )
    parser.add_argument(
        "--point_ref_paths",
        type=str,
        nargs="+",
        default=None,
        help="Paths to .npy point maps for the reference images (optional).",
    )
    parser.add_argument(
        "--point_lineart_paths",
        type=str,
        nargs="+",
        default=None,
        help="Paths to .npy point maps for the line-art images (optional).",
    )
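    # Point maps: each --point_*_paths entry is a .npy file holding a 512x512
    # array (the all-zero fallback later in this script mirrors that shape and
    # uint8 dtype). As a hedged sketch, assuming matched points are marked by
    # nonzero entries (an assumption this script does not pin down), a map
    # could be built like:
    #
    #   pts = np.zeros((512, 512), dtype=np.uint8)
    #   pts[y, x] = k  # hypothetical: mark the k-th matched point at (x, y)
    #   np.save("ref_points.npy", pts)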
    parser.add_argument(
        "--is_lineart",
        action="store_true",
        help="Treat the --input_lineart_paths images as line art directly instead of extracting line art from them.",
    )
    parser.add_argument(
        "--guidance_scale_ref",
        type=float,
        default=1e-4,
        help="Guidance scale for the reference image.",
    )
    parser.add_argument(
        "--guidance_scale_point",
        type=float,
        default=1e-4,
        help="Guidance scale for the point maps.",
    )
    args = parser.parse_args()

    output_dir = args.output_dir
    denoise_steps = args.denoise_steps
    seed = args.seed
    is_lineart = args.is_lineart
    os.makedirs(output_dir, exist_ok=True)
    logging.info(f"output dir = {output_dir}")

    # Bind these unconditionally so the code below never hits a NameError
    # when the corresponding arguments are omitted.
    input_reference_paths = args.input_reference_paths
    input_lineart_paths = args.input_lineart_paths
    if input_reference_paths is not None:
        assert len(input_reference_paths) == len(input_lineart_paths)
    point_ref_paths = args.point_ref_paths
    point_lineart_paths = args.point_lineart_paths
    if point_ref_paths is not None:
        assert len(point_ref_paths) == len(point_lineart_paths)
    logging.info(f"arguments: {args}")
    if seed is None:
        import time

        seed = int(time.time())

    # -------------------- Device --------------------
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
        logging.warning("CUDA is not available. Running on CPU will be slow.")
    logging.info(f"device = {device}")

    # torch.cuda.manual_seed() returns None rather than a Generator, so build
    # an explicit generator on the chosen device instead.
    generator = torch.Generator(device=device).manual_seed(seed)
    # -------------------- Model --------------------
    preprocessor = BatchLineartDetector(args.annotator_ckpts_path)
    preprocessor.to(device, dtype=torch.float32)

    in_channels_reference_unet = 4
    in_channels_denoising_unet = 4
    in_channels_controlnet = 4
    noise_scheduler = DDIMScheduler.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="scheduler"
    )
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="vae"
    )
    # low_cpu_mem_usage=False and ignore_mismatched_sizes=True allow loading
    # the base weights even when in_channels differs from the checkpoint.
    denoising_unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="unet",
        in_channels=in_channels_denoising_unet,
        low_cpu_mem_usage=False,
        ignore_mismatched_sizes=True,
    )
    reference_unet = RefUNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="unet",
        in_channels=in_channels_reference_unet,
        low_cpu_mem_usage=False,
        ignore_mismatched_sizes=True,
    )
    refnet_tokenizer = CLIPTokenizer.from_pretrained(args.image_encoder_path)
    refnet_text_encoder = CLIPTextModel.from_pretrained(args.image_encoder_path)
    refnet_image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)
    controlnet = ControlNetModel.from_pretrained(
        args.controlnet_model_name_or_path,
        in_channels=in_channels_controlnet,
        low_cpu_mem_usage=False,
        ignore_mismatched_sizes=True,
    )
    controlnet_tokenizer = CLIPTokenizer.from_pretrained(args.image_encoder_path)
    controlnet_text_encoder = CLIPTextModel.from_pretrained(args.image_encoder_path)
    controlnet_image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)
    point_net = PointNet()

    # strict=False tolerates missing/unexpected keys when overlaying the
    # fine-tuned MangaNinjia weights on the freshly constructed modules.
    controlnet.load_state_dict(
        torch.load(args.manga_controlnet_model_path, map_location="cpu"),
        strict=False,
    )
    point_net.load_state_dict(
        torch.load(args.point_net_path, map_location="cpu"),
        strict=False,
    )
    reference_unet.load_state_dict(
        torch.load(args.manga_reference_unet_path, map_location="cpu"),
        strict=False,
    )
    denoising_unet.load_state_dict(
        torch.load(args.manga_main_model_path, map_location="cpu"),
        strict=False,
    )
    pipe = MangaNinjiaPipeline(
        reference_unet=reference_unet,
        controlnet=controlnet,
        denoising_unet=denoising_unet,
        vae=vae,
        refnet_tokenizer=refnet_tokenizer,
        refnet_text_encoder=refnet_text_encoder,
        refnet_image_encoder=refnet_image_encoder,
        controlnet_tokenizer=controlnet_tokenizer,
        controlnet_text_encoder=controlnet_text_encoder,
        controlnet_image_encoder=controlnet_image_encoder,
        scheduler=noise_scheduler,
        point_net=point_net,
    )
    pipe = pipe.to(device)
    # -------------------- Inference and saving --------------------
    with torch.no_grad():
        for i in range(len(input_reference_paths)):
            input_reference_path = input_reference_paths[i]
            input_lineart_path = input_lineart_paths[i]

            # save paths
            rgb_name_base = os.path.splitext(os.path.basename(input_reference_path))[0]
            pred_name_base = rgb_name_base + "_colorized"
            lineart_name_base = rgb_name_base + "_lineart"
            colored_save_path = os.path.join(output_dir, f"{pred_name_base}.png")
            lineart_save_path = os.path.join(output_dir, f"{lineart_name_base}.png")

            if point_ref_paths is not None:
                point_ref_path = point_ref_paths[i]
                point_lineart_path = point_lineart_paths[i]
                point_ref = torch.from_numpy(np.load(point_ref_path)).unsqueeze(0).unsqueeze(0)
                point_main = torch.from_numpy(np.load(point_lineart_path)).unsqueeze(0).unsqueeze(0)
            else:
                # No point maps given: fall back to all-zero maps, i.e. no point guidance.
                matrix1 = np.zeros((512, 512), dtype=np.uint8)
                matrix2 = np.zeros((512, 512), dtype=np.uint8)
                point_ref = torch.from_numpy(matrix1).unsqueeze(0).unsqueeze(0)
                point_main = torch.from_numpy(matrix2).unsqueeze(0).unsqueeze(0)

            ref_image = Image.open(input_reference_path)
            ref_image = ref_image.resize((512, 512))
            target_image = Image.open(input_lineart_path)
            target_image = target_image.resize((512, 512))
            pipe_out = pipe(
                is_lineart,
                ref_image,
                target_image,
                target_image,
                denosing_steps=denoise_steps,  # keyword spelling matches the pipeline's signature
                processing_res=512,
                match_input_res=True,
                batch_size=1,
                show_progress_bar=True,
                guidance_scale_ref=args.guidance_scale_ref,
                guidance_scale_point=args.guidance_scale_point,
                preprocessor=preprocessor,
                generator=generator,
                point_ref=point_ref,
                point_main=point_main,
            )
            if os.path.exists(colored_save_path):
                logging.warning(f"Existing file: '{colored_save_path}' will be overwritten")
            image = pipe_out.img_pil
            lineart = pipe_out.to_save_dict["edge2_black"]
            image.save(colored_save_path)
            lineart.save(lineart_save_path)