# HD-Painter: lib/models/ds_inp.py
# Last commit: bfd34e9 by Andranik Sargsyan ("add demo code")
import importlib
from omegaconf import OmegaConf
import torch
import safetensors
import safetensors.torch
from lib.smplfusion import DDIM, share, scheduler
from .common import *
# Where the Dreamshaper 8 inpainting checkpoint is stored locally, and the
# Civitai URL it is fetched from when missing.
MODEL_PATH = '{}/dreamshaper/dreamshaper_8Inpainting.safetensors'.format(MODEL_FOLDER)
DOWNLOAD_URL = 'https://civitai.com/api/download/models/131004'

# Fetch the checkpoint at import time so the download does not have to happen
# inside load_model() (which calls download_file again as a safety net).
# NOTE(review): assumes download_file is a no-op when MODEL_PATH already
# exists; confirm in .common.
download_file(DOWNLOAD_URL, MODEL_PATH)
def _extract_submodule_state(state_dict: dict, prefix: str) -> dict:
    """Return the entries of ``state_dict`` that belong to sub-module ``prefix``.

    Only keys starting with ``prefix + '.'`` are kept. That leading prefix and
    its dot are stripped, so the result can be passed straight to the
    sub-module's ``load_state_dict``.
    """
    # Match a real key prefix rather than a substring anywhere in the key.
    # Otherwise the slice below could cut a key at the wrong position.
    head = prefix + '.'
    return {key[len(head):]: value for key, value in state_dict.items() if key.startswith(head)}


def load_model():
    """Build the Dreamshaper Inpainting V8 DDIM pipeline from its checkpoint.

    Steps:
    1. Ensure the checkpoint exists at ``MODEL_PATH`` (via ``download_file``)
       and load it with safetensors.
    2. Instantiate the UNet (``unet/inpainting/v1.yaml``), VAE (``vae.yaml``)
       and CLIP encoder (``encoders/clip.yaml``) in eval mode on CUDA.
    3. Load each sub-network's weights from its key prefix in the checkpoint.
    4. Disable gradients on all three.
    5. Set the module-global ``share.schedule`` to ``scheduler.linear(...)``,
       built from the DDPM config (``ddpm/v1.yaml``).

    Returns:
        The ``DDIM`` object wrapping the config, VAE, encoder and UNet.
    """
    print ("Loading model: Dreamshaper Inpainting V8")
    # Re-check here in case the file was removed after the import-time download.
    download_file(DOWNLOAD_URL, MODEL_PATH)
    state_dict = safetensors.torch.load_file(MODEL_PATH)

    config = OmegaConf.load(f'{CONFIG_FOLDER}/ddpm/v1.yaml')
    unet = load_obj(f'{CONFIG_FOLDER}/unet/inpainting/v1.yaml').eval().cuda()
    vae = load_obj(f'{CONFIG_FOLDER}/vae.yaml').eval().cuda()
    encoder = load_obj(f'{CONFIG_FOLDER}/encoders/clip.yaml').eval().cuda()

    # The checkpoint is one flat dict; each sub-network's weights live under
    # their own key prefix.
    unet_state = _extract_submodule_state(state_dict, 'model.diffusion_model')
    encoder_state = _extract_submodule_state(state_dict, 'cond_stage_model')
    vae_state = _extract_submodule_state(state_dict, 'first_stage_model')

    unet.load_state_dict(unet_state)
    encoder.load_state_dict(encoder_state)
    vae.load_state_dict(vae_state)

    # Weights are not trained here: turn off gradient tracking on every parameter.
    unet = unet.requires_grad_(False)
    encoder = encoder.requires_grad_(False)
    vae = vae.requires_grad_(False)

    ddim = DDIM(config, vae, encoder, unet)
    share.schedule = scheduler.linear(config.timesteps, config.linear_start, config.linear_end)
    return ddim