HD-Painter / lib /models /ds_inp.py
Andranik Sargsyan
enable fp16, move SR to cuda:1
da1e12f
raw
history blame contribute delete
No virus
1.64 kB
import importlib
from omegaconf import OmegaConf
import torch
import safetensors
import safetensors.torch
from lib.smplfusion import DDIM, share, scheduler
from .common import *
# Where the DreamShaper v8 inpainting checkpoint is stored locally
# (MODEL_FOLDER comes from the .common star import).
MODEL_PATH = '{}/dreamshaper/dreamshaper_8Inpainting.safetensors'.format(MODEL_FOLDER)

# Civitai download endpoint the checkpoint is fetched from.
DOWNLOAD_URL = 'https://civitai.com/api/download/models/131004'

# Eagerly fetch the weights when this module is imported. load_model() calls
# download_file again; presumably download_file skips files that already
# exist — NOTE(review): verify in .common.
download_file(DOWNLOAD_URL, MODEL_PATH)
def load_model(dtype=torch.float16):
print ("Loading model: Dreamshaper Inpainting V8")
download_file(DOWNLOAD_URL, MODEL_PATH)
state_dict = safetensors.torch.load_file(MODEL_PATH)
config = OmegaConf.load(f'{CONFIG_FOLDER}/ddpm/v1.yaml')
unet = load_obj(f'{CONFIG_FOLDER}/unet/inpainting/v1.yaml').eval().cuda()
vae = load_obj(f'{CONFIG_FOLDER}/vae.yaml').eval().cuda()
encoder = load_obj(f'{CONFIG_FOLDER}/encoders/clip.yaml').eval().cuda()
extract = lambda state_dict, model: {x[len(model)+1:]:y for x,y in state_dict.items() if model in x}
unet_state = extract(state_dict, 'model.diffusion_model')
encoder_state = extract(state_dict, 'cond_stage_model')
vae_state = extract(state_dict, 'first_stage_model')
unet.load_state_dict(unet_state)
encoder.load_state_dict(encoder_state)
vae.load_state_dict(vae_state)
if dtype == torch.float16:
unet.convert_to_fp16()
vae.to(dtype)
encoder.to(dtype)
unet = unet.requires_grad_(False)
encoder = encoder.requires_grad_(False)
vae = vae.requires_grad_(False)
ddim = DDIM(config, vae, encoder, unet)
share.schedule = scheduler.linear(config.timesteps, config.linear_start, config.linear_end)
return ddim