# HD-Painter / lib/models/sd2_inp.py
# Author: Andranik Sargsyan
# Commit: bfd34e9 ("add demo code")
# Snapshot taken from the HuggingFace file view (1.66 kB).
import safetensors
import safetensors.torch
import torch
from omegaconf import OmegaConf
from lib.smplfusion import DDIM, share, scheduler
from .common import *
# Local path of the SD 2.0 inpainting checkpoint (`512-inpainting-ema`, safetensors
# format), placed under MODEL_FOLDER (provided by `.common`).
MODEL_PATH = f'{MODEL_FOLDER}/sd-2-0-inpainting/512-inpainting-ema.safetensors'
# Remote location of the same checkpoint on the HuggingFace Hub.
DOWNLOAD_URL = 'https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/resolve/main/512-inpainting-ema.safetensors?download=true'
# Pre-download the checkpoint when this module is imported, so the later call
# inside load_model() can presumably skip the transfer.
# NOTE(review): this assumes download_file() is a no-op when MODEL_PATH already
# exists; confirm in `.common`. It is also a network side effect on import.
download_file(DOWNLOAD_URL, MODEL_PATH)
def load_model():
    """Build the Stable-Inpainting 2.0 DDIM sampler with pretrained weights.

    Steps:
    - Make sure the checkpoint at ``MODEL_PATH`` is available (via ``download_file``).
    - Instantiate the UNet, VAE and encoder from their YAML configs, in eval
      mode and on CUDA.
    - Load each component's weights from the single safetensors checkpoint.
    - Freeze all parameters.
    - Install the linear schedule described by ``ddpm/v1.yaml`` into
      ``share.schedule``.

    Returns:
        DDIM: sampler wrapping the loaded, frozen UNet, VAE and encoder.
    """
    print("Loading model: Stable-Inpainting 2.0")
    download_file(DOWNLOAD_URL, MODEL_PATH)
    state_dict = safetensors.torch.load_file(MODEL_PATH)

    config = OmegaConf.load(f'{CONFIG_FOLDER}/ddpm/v1.yaml')
    unet = load_obj(f'{CONFIG_FOLDER}/unet/inpainting/v2.yaml').eval().cuda()
    vae = load_obj(f'{CONFIG_FOLDER}/vae.yaml').eval().cuda()
    encoder = load_obj(f'{CONFIG_FOLDER}/encoders/openclip.yaml').eval().cuda()

    def extract(sd, prefix):
        # Keep only keys under "<prefix>." and strip that prefix, so the names
        # match the sub-module's own parameter names. Using a prefix test
        # (instead of a substring test) guarantees the slice removes exactly
        # the prefix and never mangles a key that merely contains the name.
        head = prefix + '.'
        return {k[len(head):]: v for k, v in sd.items() if k.startswith(head)}

    # load_state_dict is strict by default, so missing/unexpected keys raise.
    unet.load_state_dict(extract(state_dict, 'model.diffusion_model'))
    encoder.load_state_dict(extract(state_dict, 'cond_stage_model'))
    vae.load_state_dict(extract(state_dict, 'first_stage_model'))
    # load_state_dict copies the tensors into the modules' own parameters, so
    # the checkpoint dict is no longer needed; release it early.
    del state_dict

    # Inference only: freeze every component (requires_grad_ acts in place).
    unet.requires_grad_(False)
    encoder.requires_grad_(False)
    vae.requires_grad_(False)

    # Build the sampler once, after the weights are loaded (the original
    # also built an unused instance before loading).
    ddim = DDIM(config, vae, encoder, unet)
    share.schedule = scheduler.linear(config.timesteps, config.linear_start, config.linear_end)
    print('Stable-Inpainting 2.0 loaded')
    return ddim