DiffuSynthV0.2 / model /DiffSynthSampler.py
WeixuanYuan's picture
Upload 66 files
ae1bdf7 verified
raw
history blame
21 kB
import numpy as np
import torch
from tqdm import tqdm
def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Look up per-sample schedule values and broadcast them to a tensor shape.

    :param arr: 1-D numpy array holding one value per diffusion step.
    :param timesteps: integer tensor of indices into ``arr`` (one per batch item).
    :param broadcast_shape: target shape of K dimensions whose leading (batch)
                            dimension equals ``len(timesteps)``.
    :return: float32 tensor of shape ``broadcast_shape`` living on
             ``timesteps.device``; every element of sample ``b`` equals
             ``arr[timesteps[b]]``.
    """
    schedule = torch.from_numpy(arr).to(device=timesteps.device)
    picked = schedule[timesteps].float()
    # Append singleton axes until the rank matches the target, then expand
    # (a view, no copy) to the full broadcast shape.
    missing_dims = len(broadcast_shape) - len(picked.shape)
    for _ in range(max(missing_dims, 0)):
        picked = picked.unsqueeze(-1)
    return picked.expand(broadcast_shape)
class DiffSynthSampler:
    """
    Gaussian-diffusion sampler with a linear beta schedule and DDIM/DDPM steps.

    All image-like tensors handled here are 4-D with layout
    ``(batch, channels, height, width)``; ``channels`` and ``height`` are fixed
    at construction time and checked by assertions, while ``width`` may vary.

    Noise is always obtained through ``get_deterministic_noise_tensor`` so that
    the underlying ``torch.randn`` draw has a fixed shape that does not depend
    on the requested batch size or width.
    """

    def __init__(self, timesteps, beta_start=0.0001, beta_end=0.02, device=None, mute=False,
                 height=128, max_batchsize=16, max_width=256, channels=4, train_width=64, noise_strategy="repeat"):
        """
        :param timesteps: number of diffusion steps of the (un-respaced) schedule.
        :param beta_start: first value of the linear beta schedule.
        :param beta_end: last value of the linear beta schedule.
        :param device: torch device; defaults to "cuda" when available, else "cpu".
        :param mute: if True, the progress bar of ``p_sample_loop`` is disabled.
        :param height: size of dimension 2 of every sample.
        :param max_batchsize: batch size of the internally drawn noise tensor.
        :param max_width: width of the noise drawn by the non-repeat strategy.
        :param channels: size of dimension 1 of every sample.
        :param train_width: width of the noise drawn by the repeat strategy.
        :param noise_strategy: "repeat" selects the repeat strategy, any other
                               value selects the non-repeat strategy.
        """
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        self.height = height
        self.train_width = train_width
        self.max_batchsize = max_batchsize
        self.max_width = max_width
        self.channels = channels

        self.num_timesteps = timesteps
        # Maps a (possibly respaced) step index to the step index the model was trained on.
        self.timestep_map = list(range(self.num_timesteps))
        self.betas = np.array(np.linspace(beta_start, beta_end, self.num_timesteps), dtype=np.float64)
        self.respaced = False
        self.define_beta_schedule()

        # CFG == 1.0 means classifier-free guidance is switched off.
        self.CFG = 1.0
        self.unconditional_condition = None
        self.mute = mute
        self.noise_strategy = noise_strategy

    @staticmethod
    def _concat_points(parts):
        """Return the start offset (along dim 3) of every tensor in ``parts``."""
        points = [0]
        for part in parts[:-1]:
            points.append(points[-1] + part.size(3))
        return points

    def get_deterministic_noise_tensor_non_repeat(self, batchsize, width, reference_noise=None):
        """
        Return noise cropped from one large ``(max_batchsize, C, H, max_width)`` tensor.

        :param batchsize: number of samples to return.
        :param width: number of columns to return.
        :param reference_noise: optional tensor of shape
                                ``(batchsize, channels, height, max_width)`` used
                                instead of freshly drawn noise.
        :return: ``(noise, None)``; the second item keeps the return signature
                 parallel to the repeat strategy.
        """
        if reference_noise is None:
            large_noise_tensor = torch.randn(
                (self.max_batchsize, self.channels, self.height, self.max_width), device=self.device)
        else:
            assert reference_noise.shape == (batchsize, self.channels, self.height, self.max_width), \
                "reference_noise shape mismatch"
            large_noise_tensor = reference_noise
        return large_noise_tensor[:batchsize, :, :, :width], None

    def get_deterministic_noise_tensor(self, batchsize, width, reference_noise=None):
        """
        Dispatch to the noise strategy selected by ``self.noise_strategy``.

        :return: ``(noise, concat_points)`` as produced by the chosen strategy.
        """
        if self.noise_strategy == "repeat":
            return self.get_deterministic_noise_tensor_repeat(batchsize, width, reference_noise=reference_noise)
        return self.get_deterministic_noise_tensor_non_repeat(batchsize, width, reference_noise=reference_noise)

    def get_deterministic_noise_tensor_repeat(self, batchsize, width, reference_noise=None):
        """
        Build noise of arbitrary ``width`` from one training-width noise tensor.

        The training-width tensor is split into a "first part" (the leading
        three quarters) and a "release part" (the trailing quarter).  The
        release part is always kept as-is at the end; the first part is either
        shortened by dropping its middle, or lengthened by repeating its two
        halves around a centred middle slice.

        :param batchsize: number of samples to return.
        :param width: requested width; must be at least the release width.
        :param reference_noise: optional tensor of shape
                                ``(batchsize, channels, height, train_width)``
                                used instead of freshly drawn noise.
        :return: ``(noise, concat_points)`` where ``concat_points`` lists the
                 start column of every concatenated piece (including the
                 release part).
        :raises ValueError: if ``width`` is smaller than the release width.
        """
        # Noise with the same width as the training data.
        if reference_noise is None:
            train_noise_tensor = torch.randn(
                (self.max_batchsize, self.channels, self.height, self.train_width), device=self.device)
        else:
            assert reference_noise.shape == (batchsize, self.channels, self.height, self.train_width), \
                "reference_noise shape mismatch"
            train_noise_tensor = reference_noise

        release_width = int(self.train_width * 1.0 / 4)
        first_part_width = self.train_width - release_width
        if width < release_width:
            raise ValueError(
                f"width ({width}) must be at least the release width ({release_width})")

        first_part = train_noise_tensor[:batchsize, :, :, :first_part_width]
        # Explicit start index instead of ``-release_width:``: a negative-zero
        # slice would select the whole tensor.
        release_part = train_noise_tensor[:batchsize, :, :, first_part_width:]

        if width <= self.train_width:
            # Requested width fits in the training width: drop the middle of first_part.
            keep_width = width - release_width
            head_width = keep_width // 2
            tail_width = keep_width - head_width
            all_parts = [
                first_part[:, :, :, :head_width],
                first_part[:, :, :, first_part_width - tail_width:],
                release_part,
            ]
        else:
            # Requested width exceeds the training width: repeat both halves of
            # first_part and fill the remainder with a slice from its centre.
            repeats, extra = divmod(width - release_width, first_part_width)
            head_width = first_part_width // 2
            head_half = first_part[:, :, :, :head_width]
            tail_half = first_part[:, :, :, head_width:]
            middle_start = (first_part_width - extra) // 2
            middle_part = first_part[:, :, :, middle_start:middle_start + extra]
            all_parts = [head_half] * repeats + [middle_part] + [tail_half] * repeats + [release_part]

        # Concatenate along the width dimension and record where each piece starts.
        noise_tensor = torch.cat(all_parts, dim=3)
        return noise_tensor, self._concat_points(all_parts)

    def define_beta_schedule(self):
        """
        Derive every schedule quantity (alphas, cumulative products, posterior
        variance, ...) from ``self.betas``.

        Must not be called on a schedule that has already been respaced.
        """
        assert self.respaced == False, "This schedule has already been respaced!"

        # define alphas
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
        self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recip_alphas = np.sqrt(1.0 / self.alphas)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = (self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod))

    def activate_classifier_free_guidance(self, CFG, unconditional_condition):
        """
        Enable classifier-free guidance for subsequent sampling steps.

        :param CFG: guidance scale; 1.0 disables guidance.
        :param unconditional_condition: condition tensor representing "no
                                        condition"; required when ``CFG != 1.0``.
        """
        assert (unconditional_condition is not None) or CFG == 1.0, \
            "For CFG != 1.0, unconditional_condition must be available"
        self.CFG = CFG
        self.unconditional_condition = unconditional_condition

    def respace(self, use_timesteps=None):
        """
        Restrict the schedule to a subset of the original steps.

        New betas are chosen so that the cumulative alpha products at the kept
        steps are unchanged.  ``self.timestep_map`` afterwards translates the
        new step indices back to the original ones.

        :param use_timesteps: iterable of original step indices to keep;
                              ``None`` leaves the schedule untouched.
        """
        if use_timesteps is None:
            return
        # Check before mutating anything so a second call cannot corrupt state.
        assert self.respaced == False, "This schedule has already been respaced!"

        selected = set(use_timesteps)
        last_alpha_cumprod = 1.0
        new_betas = []
        timestep_map = []
        for i, _alpha_cumprod in enumerate(self.alphas_cumprod):
            if i in selected:
                new_betas.append(1 - _alpha_cumprod / last_alpha_cumprod)
                last_alpha_cumprod = _alpha_cumprod
                timestep_map.append(i)

        self.timestep_map = timestep_map
        # Count the steps actually kept (ignores duplicates / out-of-range indices).
        self.num_timesteps = len(new_betas)
        self.betas = np.array(new_betas)
        self.define_beta_schedule()
        self.respaced = True

    def generate_linear_noise(self, shape, variance=1.0, first_endpoint=None, second_endpoint=None):
        """
        Create a batch of noise tensors that lie on a straight line.

        * Both endpoints given: linear interpolation from ``first_endpoint``
          (index 0) to ``second_endpoint`` (last index); no rescaling.
        * Otherwise: item 0 is ``first_endpoint`` (or fresh noise), item 1 is
          fresh noise, and every further item extrapolates linearly from the
          previous two.  The batch is then rescaled to the requested overall
          variance and, if ``first_endpoint`` was given, shifted so that item 0
          equals it again.  A lone ``second_endpoint`` is ignored.

        :param shape: ``(batch, channels, height, width)`` of the result.
        :param variance: target variance of the whole batch (single-endpoint case).
        :param first_endpoint: optional tensor assigned to item 0
                               (assumed broadcastable to ``shape[1:]``).
        :param second_endpoint: optional tensor for the last item.
        :return: tensor of shape ``shape`` on ``self.device``.
        """
        assert shape[1] == self.channels, "shape[1] != self.channels"
        assert shape[2] == self.height, "shape[2] != self.height"

        n_items = shape[0]
        noise = torch.empty(*shape, device=self.device)

        # Both endpoints given: plain linear interpolation.
        if first_endpoint is not None and second_endpoint is not None:
            for i in range(n_items):
                # A single-item batch has no span to interpolate over; use the first endpoint.
                alpha = i / (n_items - 1) if n_items > 1 else 0.0
                noise[i] = alpha * second_endpoint + (1 - alpha) * first_endpoint
            return noise  # interpolated result needs no mean/variance adjustment

        # get_deterministic_noise_tensor returns (noise, concat_points) with a
        # batch of one; ``[0][0]`` selects the single (C, H, W) sample.
        if first_endpoint is not None:
            noise[0] = first_endpoint
        else:
            noise[0] = self.get_deterministic_noise_tensor(1, shape[3])[0][0]
        if n_items > 1:
            noise[1] = self.get_deterministic_noise_tensor(1, shape[3])[0][0]

        # Extrapolate the remaining points along the line through items 0 and 1.
        for i in range(2, n_items):
            noise[i] = 2 * noise[i - 1] - noise[i - 2]

        # Rescale to the requested variance.
        current_var = noise.var()
        stddev_ratio = torch.sqrt(variance / current_var)
        noise = noise * stddev_ratio

        # Shift so the specified first endpoint is reproduced exactly.
        if first_endpoint is not None:
            shift = first_endpoint - noise[0]
            noise += shift

        return noise

    def q_sample(self, x_start, t, noise=None):
        """
        Diffuse the data for a given number of diffusion steps.

        In other words, sample from q(x_t | x_0).

        :param x_start: the initial data batch.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :param noise: if specified, the split-out normal noise.
        :return: A noisy version of x_start.
        """
        assert x_start.shape[1] == self.channels, "shape[1] != self.channels"
        assert x_start.shape[2] == self.height, "shape[2] != self.height"

        if noise is None:
            noise, _ = self.get_deterministic_noise_tensor(x_start.shape[0], x_start.shape[3])

        assert noise.shape == x_start.shape
        return (
            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
            * noise
        )

    @torch.no_grad()
    def ddim_sample(self, model, x, t, condition=None, ddim_eta=0.0):
        """
        Perform one reverse-diffusion step from ``x_t`` to ``x_{t-1}``.

        :param model: callable ``model(x, t, condition)`` returning the predicted noise.
        :param x: current noisy batch.
        :param t: long tensor of (respaced) step indices, one per batch item.
        :param condition: conditioning passed to the model; must be a tensor
                          when classifier-free guidance is active.
        :param ddim_eta: scales the per-step noise term ``sigma_t``; with 0.0
                         that term vanishes.
        :return: the batch at the previous step.
        """
        # The model is queried with the original (un-respaced) step index.
        map_tensor = torch.tensor(self.timestep_map, device=t.device, dtype=t.dtype)
        mapped_t = map_tensor[t]

        if self.CFG == 1.0:
            pred_noise = model(x, mapped_t, condition)
        else:
            # Classifier-free guidance: run the unconditional and conditional
            # branches in one doubled batch and extrapolate between them.
            unconditional_condition = self.unconditional_condition.unsqueeze(0).repeat(
                *([x.shape[0]] + [1] * len(self.unconditional_condition.shape)))
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([mapped_t] * 2)
            c_in = torch.cat([unconditional_condition, condition])
            noise_uncond, noise = model(x_in, t_in, c_in).chunk(2)
            pred_noise = noise_uncond + self.CFG * (noise - noise_uncond)

        alpha_cumprod_t = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
        alpha_cumprod_t_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)

        pred_x0 = (x - torch.sqrt((1. - alpha_cumprod_t)) * pred_noise) / torch.sqrt(alpha_cumprod_t)

        sigmas_t = (
            ddim_eta
            * torch.sqrt((1 - alpha_cumprod_t_prev) / (1 - alpha_cumprod_t))
            * torch.sqrt(1 - alpha_cumprod_t / alpha_cumprod_t_prev)
        )

        pred_dir_xt = torch.sqrt(1 - alpha_cumprod_t_prev - sigmas_t ** 2) * pred_noise

        # Drawn even when sigma_t is zero so the random-number stream is the
        # same for every value of ddim_eta.
        step_noise, _ = self.get_deterministic_noise_tensor(x.shape[0], x.shape[3])

        x_prev = torch.sqrt(alpha_cumprod_t_prev) * pred_x0 + pred_dir_xt + sigmas_t * step_noise
        return x_prev

    def p_sample(self, model, x, t, condition=None, sampler="ddim"):
        """
        Run one reverse step with the named sampler.

        :param sampler: "ddim" (eta = 0) or "ddpm" (eta = 1).
        :raises NotImplementedError: for any other sampler name.
        """
        if sampler == "ddim":
            return self.ddim_sample(model, x, t, condition=condition, ddim_eta=0.0)
        elif sampler == "ddpm":
            return self.ddim_sample(model, x, t, condition=condition, ddim_eta=1.0)
        else:
            raise NotImplementedError()

    def get_dynamic_masks(self, n_masks, shape, concat_points, mask_flexivity=0.8):
        """
        Build one inpainting mask per sampling step (1 = freeze, 0 = free).

        The release region (trailing ``train_width / 4`` columns) is frozen in
        every mask.  During the first ``int(n_masks * mask_flexivity)`` steps
        ("guidance steps") each piece delimited by ``concat_points`` is
        additionally frozen over a region that shrinks linearly from the whole
        piece to nothing: the first piece from its left edge, the last piece
        towards the release region, and inner pieces around their centre.  The
        remaining steps freeze the release region only.

        :param n_masks: total number of masks to create.
        :param shape: ``(batch, channels, height, width)`` of the samples.
        :param concat_points: start columns of the pieces, as returned by
                              ``get_deterministic_noise_tensor_repeat``.
        :param mask_flexivity: fraction of the steps that are guidance steps.
        :return: list of ``(batch, 1, height, width)`` float masks, ordered so
                 that ``list.pop()`` yields the fully frozen mask first.
        """
        release_length = int(self.train_width / 4)
        assert shape[3] == (concat_points[-1] + release_length), "shape[3] != (concat_points[-1] + release_length)"

        fraction_lengths = [concat_points[i + 1] - concat_points[i] for i in range(len(concat_points) - 1)]
        last_fraction_index = len(fraction_lengths) - 1
        release_start = shape[3] - release_length

        n_guidance_steps = int(n_masks * mask_flexivity)
        n_free_steps = n_masks - n_guidance_steps

        def release_only_mask():
            # Explicit start index: ``-0:`` would select every column.
            base = torch.zeros((shape[0], 1, shape[2], shape[3]), dtype=torch.float32).to(self.device)
            base[:, :, :, release_start:] = 1.0
            return base

        masks = []
        for i in range(n_guidance_steps):
            step_i_mask = release_only_mask()
            for fraction_index, fraction_length in enumerate(fraction_lengths):
                if n_guidance_steps > 1:
                    _fraction_mask_length = int(
                        (n_guidance_steps - 1 - i) / (n_guidance_steps - 1) * fraction_length)
                else:
                    # A single guidance step behaves like the first one: freeze everything.
                    _fraction_mask_length = int(fraction_length)

                if fraction_index == 0:
                    step_i_mask[:, :, :, :_fraction_mask_length] = 1.0
                elif fraction_index == last_fraction_index:
                    if not _fraction_mask_length == 0:
                        step_i_mask[:, :, :, release_start - _fraction_mask_length:] = 1.0
                else:
                    fraction_mask_start_position = int((fraction_length - _fraction_mask_length) / 2)
                    start = concat_points[fraction_index] + fraction_mask_start_position
                    step_i_mask[:, :, :, start:start + _fraction_mask_length] = 1.0
            masks.append(step_i_mask)

        for _ in range(n_free_steps):
            masks.append(release_only_mask())

        masks.reverse()
        return masks

    @torch.no_grad()
    def p_sample_loop(self, model, shape, initial_noise=None, start_noise_level_ratio=1.0, end_noise_level_ratio=0.0,
                      return_tensor=False, condition=None, guide_img=None,
                      mask=None, sampler="ddim", inpaint=False, use_dynamic_mask=False, mask_flexivity=0.8):
        """
        Run the reverse process from ``start_noise_level_ratio`` down to
        ``end_noise_level_ratio`` of the schedule.

        :param model: noise-prediction model, see ``ddim_sample``.
        :param shape: ``(batch, channels, height, width)`` of the samples.
        :param initial_noise: optional reference noise forwarded to
                              ``get_deterministic_noise_tensor``.
        :param start_noise_level_ratio: fraction of the schedule to start from
                                        (the resulting index is exclusive).
        :param end_noise_level_ratio: fraction of the schedule to stop at.
        :param return_tensor: if True, intermediate results are kept as tensors,
                              otherwise they are converted to numpy arrays.
        :param condition: conditioning passed to the model.
        :param guide_img: image to start from / inpaint against; required unless
                          sampling starts from pure noise.
        :param mask: static inpainting mask (1 = keep guide image).
        :param sampler: "ddim" or "ddpm".
        :param inpaint: if True, re-impose the (noised) guide image under the mask after every step.
        :param use_dynamic_mask: if True, masks come from ``get_dynamic_masks``.
        :param mask_flexivity: forwarded to ``get_dynamic_masks``.
        :return: ``(imgs, initial_noise)``.  ``imgs[0]`` is always the starting
                 tensor; later entries follow ``return_tensor``.
        """
        assert shape[1] == self.channels, "shape[1] != self.channels"
        assert shape[2] == self.height, "shape[2] != self.height"

        initial_noise, _ = self.get_deterministic_noise_tensor(shape[0], shape[3], reference_noise=initial_noise)
        # tuple(...) so that a list-typed ``shape`` compares correctly.
        assert tuple(initial_noise.shape) == tuple(shape), "initial_noise.shape != shape"

        start_noise_level_index = int(self.num_timesteps * start_noise_level_ratio)  # not included!!!
        end_noise_level_index = int(self.num_timesteps * end_noise_level_ratio)

        timesteps = reversed(range(end_noise_level_index, start_noise_level_index))

        # configure initial img
        assert (start_noise_level_ratio == 1.0) or (
            guide_img is not None), "A guide_img must be given to sample from a non-pure-noise."
        if inpaint:
            assert guide_img is not None, "A guide_img must be given for inpainting."
            assert use_dynamic_mask or (mask is not None), "A mask must be given for inpainting."

        concat_points = None
        if guide_img is None:
            img = initial_noise
        else:
            guide_img, concat_points = self.get_deterministic_noise_tensor_repeat(
                shape[0], shape[3], reference_noise=guide_img)
            assert tuple(guide_img.shape) == tuple(shape), "guide_img.shape != shape"

            if start_noise_level_index > 0:
                # -1 because start_noise_level_index itself is not included
                t = torch.full((shape[0],), start_noise_level_index - 1, device=self.device).long()
                img = self.q_sample(guide_img, t, noise=initial_noise)
            else:
                print("Zero noise added to the guidance latent representation.")
                img = guide_img

        # get masks
        n_masks = start_noise_level_index - end_noise_level_index
        if use_dynamic_mask:
            assert concat_points is not None, "use_dynamic_mask requires a guide_img."
            masks = self.get_dynamic_masks(n_masks, shape, concat_points, mask_flexivity)
        else:
            masks = [mask] * n_masks

        imgs = [img]
        current_mask = None

        for i in tqdm(timesteps, total=start_noise_level_index - end_noise_level_index, disable=self.mute):
            img = self.p_sample(model, img, torch.full((shape[0],), i, device=self.device, dtype=torch.long),
                                condition=condition,
                                sampler=sampler)

            if inpaint:
                if i > 0:
                    # Noise the guide image to the level the sample has now reached.
                    t = torch.full((shape[0],), int(i - 1), device=self.device).long()
                    img_noise_t = self.q_sample(guide_img, t, noise=initial_noise)
                    current_mask = masks.pop()
                    img = current_mask * img_noise_t + (1 - current_mask) * img
                else:
                    # Final step: blend with the clean guide image, reusing the
                    # previous step's mask.  If this is the only step, no mask
                    # has been taken yet, so fetch one.
                    if current_mask is None:
                        current_mask = masks.pop()
                    img = current_mask * guide_img + (1 - current_mask) * img

            if return_tensor:
                imgs.append(img)
            else:
                imgs.append(img.cpu().numpy())

        return imgs, initial_noise

    def sample(self, model, shape, return_tensor=False, condition=None, sampler="ddim", initial_noise=None, seed=None):
        """
        Sample from pure noise over the full schedule.

        :param seed: if given, passed to ``torch.manual_seed`` before sampling.
        :return: see ``p_sample_loop``.
        """
        if seed is not None:
            torch.manual_seed(seed)
        return self.p_sample_loop(model, shape, initial_noise=initial_noise, start_noise_level_ratio=1.0,
                                  end_noise_level_ratio=0.0,
                                  return_tensor=return_tensor, condition=condition, sampler=sampler)

    def interpolate(self, model, shape, variance, first_endpoint=None, second_endpoint=None, return_tensor=False,
                    condition=None, sampler="ddim", seed=None):
        """
        Sample from a batch of linearly related noise tensors.

        The initial noise comes from ``generate_linear_noise``; everything else
        matches ``sample``.

        :param seed: if given, passed to ``torch.manual_seed`` before sampling.
        :return: see ``p_sample_loop``.
        """
        if seed is not None:
            torch.manual_seed(seed)
        linear_noise = self.generate_linear_noise(shape, variance, first_endpoint=first_endpoint,
                                                  second_endpoint=second_endpoint)
        return self.p_sample_loop(model, shape, initial_noise=linear_noise, start_noise_level_ratio=1.0,
                                  end_noise_level_ratio=0.0,
                                  return_tensor=return_tensor, condition=condition, sampler=sampler)

    def img_guided_sample(self, model, shape, noising_strength, guide_img, return_tensor=False, condition=None,
                          sampler="ddim", initial_noise=None, seed=None):
        """
        Noise ``guide_img`` up to ``noising_strength`` of the schedule and denoise it again.

        :param noising_strength: forwarded as ``start_noise_level_ratio``.
        :param guide_img: image to start from; its width must equal ``shape[-1]``.
        :param seed: if given, passed to ``torch.manual_seed`` before sampling.
        :return: see ``p_sample_loop``.
        """
        if seed is not None:
            torch.manual_seed(seed)
        assert guide_img.shape[-1] == shape[-1], "guide_img.shape[-1] != shape[-1]"
        return self.p_sample_loop(model, shape, start_noise_level_ratio=noising_strength, end_noise_level_ratio=0.0,
                                  return_tensor=return_tensor, condition=condition, sampler=sampler,
                                  guide_img=guide_img, initial_noise=initial_noise)

    def inpaint_sample(self, model, shape, noising_strength, guide_img, mask, return_tensor=False, condition=None,
                       sampler="ddim", initial_noise=None, use_dynamic_mask=False, end_noise_level_ratio=0.0, seed=None,
                       mask_flexivity=0.8):
        """
        Inpaint ``guide_img``: regions where the mask is 1 are kept, the rest is resampled.

        :param noising_strength: forwarded as ``start_noise_level_ratio``.
        :param mask: static mask; unused when ``use_dynamic_mask`` is True.
        :param seed: if given, passed to ``torch.manual_seed`` before sampling.
        :return: see ``p_sample_loop``.
        """
        if seed is not None:
            torch.manual_seed(seed)
        return self.p_sample_loop(model, shape, start_noise_level_ratio=noising_strength,
                                  end_noise_level_ratio=end_noise_level_ratio,
                                  return_tensor=return_tensor, condition=condition, guide_img=guide_img, mask=mask,
                                  sampler=sampler, inpaint=True, initial_noise=initial_noise,
                                  use_dynamic_mask=use_dynamic_mask,
                                  mask_flexivity=mask_flexivity)