Spaces:
Running
Running
import numpy as np | |
import torch | |
from tqdm import tqdm | |
def _extract_into_tensor(arr, timesteps, broadcast_shape): | |
""" | |
Extract values from a 1-D numpy array for a batch of indices. | |
:param arr: the 1-D numpy array. | |
:param timesteps: a tensor of indices into the array to extract. | |
:param broadcast_shape: a larger shape of K dimensions with the batch | |
dimension equal to the length of timesteps. | |
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. | |
""" | |
res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float() | |
while len(res.shape) < len(broadcast_shape): | |
res = res[..., None] | |
return res.expand(broadcast_shape) | |
class DiffSynthSampler:
    """Diffusion sampler with a linear beta schedule and DDIM-style updates.

    The sampler owns the noise schedule (``betas`` and the quantities derived
    from them), produces the Gaussian noise used during sampling, and runs the
    reverse-diffusion loop, optionally with classifier-free guidance, a guide
    image (img2img) and mask-based inpainting.

    Tensors handled here are indexed as ``(batch, channels, height, width)``;
    the width axis is the one that gets stretched or shrunk relative to
    ``train_width``.
    """

    def __init__(self, timesteps, beta_start=0.0001, beta_end=0.02, device=None, mute=False,
                 height=128, max_batchsize=16, max_width=256, channels=4, train_width=64,
                 noise_strategy="repeat"):
        """Build the sampler and its initial (un-respaced) schedule.

        :param timesteps: number of diffusion steps of the linear schedule.
        :param beta_start: first beta of the linear schedule.
        :param beta_end: last beta of the linear schedule.
        :param device: torch device; defaults to "cuda" when available, else "cpu".
        :param mute: when True the progress bar of the sampling loop is disabled.
        :param height: expected size of axis 2 of every sample.
        :param max_batchsize: batch size of freshly drawn noise tensors.
        :param max_width: width of freshly drawn noise in the non-repeat strategy.
        :param channels: expected size of axis 1 of every sample.
        :param train_width: width of freshly drawn noise in the repeat strategy.
        :param noise_strategy: "repeat" selects the repeat strategy, any other
                               value selects the non-repeat strategy.
        """
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        self.height = height
        self.train_width = train_width
        self.max_batchsize = max_batchsize
        self.max_width = max_width
        self.channels = channels
        self.num_timesteps = timesteps
        # Maps an index of the (possibly respaced) schedule to the original step.
        self.timestep_map = list(range(self.num_timesteps))
        self.betas = np.array(np.linspace(beta_start, beta_end, self.num_timesteps), dtype=np.float64)
        self.respaced = False
        self.define_beta_schedule()
        self.CFG = 1.0
        # Initialised so ddim_sample can give a clear error if CFG is changed
        # without going through activate_classifier_free_guidance.
        self.unconditional_condition = None
        self.mute = mute
        self.noise_strategy = noise_strategy

    @staticmethod
    def _cumulative_starts(parts):
        """Return the start offset (along axis 3) of every tensor in ``parts``."""
        starts = [0]
        for part in parts[:-1]:
            starts.append(starts[-1] + part.size(3))
        return starts

    def get_deterministic_noise_tensor_non_repeat(self, batchsize, width, reference_noise=None):
        """Return noise cropped from one large tensor.

        A tensor of shape ``(max_batchsize, channels, height, max_width)`` is
        drawn (or ``reference_noise`` is used) and its leading ``batchsize``
        rows and leading ``width`` columns are returned.

        :return: ``(noise, None)``; there are no concatenation points.
        """
        if reference_noise is None:
            large_noise_tensor = torch.randn(
                (self.max_batchsize, self.channels, self.height, self.max_width), device=self.device)
        else:
            assert reference_noise.shape == (batchsize, self.channels, self.height, self.max_width), \
                "reference_noise shape mismatch"
            large_noise_tensor = reference_noise
        return large_noise_tensor[:batchsize, :, :, :width], None

    def get_deterministic_noise_tensor(self, batchsize, width, reference_noise=None):
        """Dispatch to the noise generator selected by ``self.noise_strategy``.

        :return: ``(noise, concat_points)`` as returned by the chosen strategy.
        """
        if self.noise_strategy == "repeat":
            return self.get_deterministic_noise_tensor_repeat(batchsize, width, reference_noise=reference_noise)
        return self.get_deterministic_noise_tensor_non_repeat(batchsize, width, reference_noise=reference_noise)

    def get_deterministic_noise_tensor_repeat(self, batchsize, width, reference_noise=None):
        """Build noise of the requested width from a ``train_width``-wide tensor.

        The source tensor is split into a "first part" (leading three quarters)
        and a "release part" (trailing quarter).  The release part is always
        kept intact at the end.  For ``width <= train_width`` the centre of the
        first part is dropped; for larger widths the two halves of the first
        part are repeated and a centred slice of it fills the remainder.

        :param batchsize: number of batch rows to return.
        :param width: requested size of axis 3.
        :param reference_noise: optional tensor of shape
            ``(batchsize, channels, height, train_width)`` used instead of
            freshly drawn noise.
        :return: ``(noise, concat_points)`` where ``concat_points`` holds the
            start offset of every concatenated piece (the last entry is the
            start of the release part).
        :raises ValueError: if ``width`` is smaller than the release part.
        """
        # Noise with the same length as the training data.
        if reference_noise is None:
            train_noise_tensor = torch.randn(
                (self.max_batchsize, self.channels, self.height, self.train_width), device=self.device)
        else:
            assert reference_noise.shape == (batchsize, self.channels, self.height, self.train_width), \
                "reference_noise shape mismatch"
            train_noise_tensor = reference_noise

        release_width = int(self.train_width * 1.0 / 4)
        first_part_width = self.train_width - release_width
        if width < release_width:
            raise ValueError(
                f"width ({width}) must be at least the release width ({release_width})")

        # Explicit start indices are used instead of negative slicing: a
        # negative slice of length 0 (``[-0:]``) would select the whole axis.
        first_part = train_noise_tensor[:batchsize, :, :, :first_part_width]
        release_part = train_noise_tensor[:batchsize, :, :, first_part_width:self.train_width]

        body_width = width - release_width
        if width <= self.train_width:
            # Requested length fits: drop the middle of the first part.
            head_width = body_width // 2
            tail_width = body_width - head_width
            all_parts = [
                first_part[:, :, :, :head_width],
                first_part[:, :, :, first_part_width - tail_width:],
                release_part,
            ]
        else:
            # Requested length is longer: repeat both halves of the first part
            # and insert a centred slice of it to cover the remainder.
            repeats = body_width // first_part_width
            extra = body_width % first_part_width
            head_width = first_part_width // 2
            tail_width = first_part_width - head_width
            head = first_part[:, :, :, :head_width]
            tail = first_part[:, :, :, first_part_width - tail_width:]
            middle_start = (first_part_width - extra) // 2
            middle_part = first_part[:, :, :, middle_start:middle_start + extra]
            all_parts = [head] * repeats + [middle_part] + [tail] * repeats + [release_part]

        # Concatenate along the width axis and record where each piece starts.
        noise_tensor = torch.cat(all_parts, dim=3)
        return noise_tensor, self._cumulative_starts(all_parts)

    def define_beta_schedule(self):
        """(Re)compute every schedule quantity derived from ``self.betas``."""
        assert not self.respaced, "This schedule has already been respaced!"
        # define alphas
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
        self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recip_alphas = np.sqrt(1.0 / self.alphas)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod))

    def activate_classifier_free_guidance(self, CFG, unconditional_condition):
        """Store the guidance scale and the unconditional condition.

        :param CFG: guidance scale; 1.0 disables guidance.
        :param unconditional_condition: condition used for the unconditional
            pass; required whenever ``CFG != 1.0``.
        """
        assert (unconditional_condition is not None) or CFG == 1.0, \
            "For CFG != 1.0, unconditional_condition must be available"
        self.CFG = CFG
        self.unconditional_condition = unconditional_condition

    def respace(self, use_timesteps=None):
        """Restrict the schedule to a subset of the current timesteps.

        New betas are chosen so that the cumulative alpha product at each kept
        step equals the one of the current schedule.  ``timestep_map`` then
        maps the new indices to the kept original steps.  The schedule is
        marked as respaced afterwards, so a second call fails the assertion
        in ``define_beta_schedule``.

        :param use_timesteps: iterable of step indices to keep; when None the
            betas are left unchanged (the schedule is still marked respaced).
        """
        if use_timesteps is not None:
            # A set gives O(1) membership tests; int() also normalises numpy /
            # torch integer scalars.
            wanted = {int(step) for step in use_timesteps}
            last_alpha_cumprod = 1.0
            new_betas = []
            self.timestep_map = []
            for i, _alpha_cumprod in enumerate(self.alphas_cumprod):
                if i in wanted:
                    new_betas.append(1 - _alpha_cumprod / last_alpha_cumprod)
                    last_alpha_cumprod = _alpha_cumprod
                    self.timestep_map.append(i)
            # Count the steps actually kept so the value always matches
            # len(self.betas), even with duplicate or out-of-range requests.
            self.num_timesteps = len(new_betas)
            self.betas = np.array(new_betas)
        self.define_beta_schedule()
        self.respaced = True

    def generate_linear_noise(self, shape, variance=1.0, first_endpoint=None, second_endpoint=None):
        """Create a batch of noise tensors that lie on a straight line.

        * Both endpoints given: the batch linearly interpolates from
          ``first_endpoint`` (row 0) to ``second_endpoint`` (last row) and is
          returned without variance adjustment.
        * Otherwise rows 0 and 1 are the given first endpoint and/or fresh
          noise, the remaining rows are a linear extrapolation, the batch is
          rescaled to the requested variance and, when ``first_endpoint`` is
          given, shifted so that row 0 equals it.  ``second_endpoint`` alone
          is not used.

        :param shape: ``(batch, channels, height, width)`` of the result.
        :param variance: target variance of the rescaled batch.
        :return: tensor of shape ``shape`` on ``self.device``.
        """
        assert shape[1] == self.channels, "shape[1] != self.channels"
        assert shape[2] == self.height, "shape[2] != self.height"
        noise = torch.empty(*shape, device=self.device)

        # Both endpoints are given: plain linear interpolation.
        if first_endpoint is not None and second_endpoint is not None:
            # max(..., 1) keeps a batch of one from dividing by zero; it then
            # simply receives the first endpoint.
            last_index = max(shape[0] - 1, 1)
            for i in range(shape[0]):
                alpha = i / last_index  # interpolation coefficient
                noise[i] = alpha * second_endpoint + (1 - alpha) * first_endpoint
            return noise  # interpolated result needs no mean/variance adjustment

        def draw_single():
            # The generator returns (noise, concat_points) with a batch of one;
            # take the noise and drop its batch axis.
            return self.get_deterministic_noise_tensor(1, shape[3])[0][0]

        noise[0] = first_endpoint if first_endpoint is not None else draw_single()
        if shape[0] > 1:
            noise[1] = draw_single()
        # Remaining points continue the line through the first two.
        for i in range(2, shape[0]):
            noise[i] = 2 * noise[i - 1] - noise[i - 2]

        # Rescale the whole batch to the requested variance.
        current_var = noise.var()
        stddev_ratio = torch.sqrt(variance / current_var)
        noise = noise * stddev_ratio

        # Shift so that the first row is exactly the given first endpoint.
        if first_endpoint is not None:
            shift = first_endpoint - noise[0]
            noise += shift
        return noise

    def q_sample(self, x_start, t, noise=None):
        """
        Diffuse the data for a given number of diffusion steps.
        In other words, sample from q(x_t | x_0).

        :param x_start: the initial data batch.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :param noise: if specified, the split-out normal noise; otherwise noise
                      is obtained from ``get_deterministic_noise_tensor``.
        :return: A noisy version of x_start.
        """
        assert x_start.shape[1] == self.channels, "shape[1] != self.channels"
        assert x_start.shape[2] == self.height, "shape[2] != self.height"
        if noise is None:
            noise, _ = self.get_deterministic_noise_tensor(x_start.shape[0], x_start.shape[3])
        assert noise.shape == x_start.shape
        signal_scale = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape)
        noise_scale = _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
        return signal_scale * x_start + noise_scale * noise

    def ddim_sample(self, model, x, t, condition=None, ddim_eta=0.0):
        """Perform one reverse step from index ``t`` to the previous index.

        :param model: callable ``model(x, timesteps, condition)`` returning the
                      predicted noise.
        :param x: current sample batch.
        :param t: long tensor of indices into the current schedule.
        :param condition: conditioning passed to the model.
        :param ddim_eta: scale of the per-step noise (0.0 is deterministic).
        :return: the sample at the previous index.
        :raises ValueError: if guidance is active but no unconditional
                            condition has been stored.
        """
        # The model is queried with the original step numbers.
        map_tensor = torch.tensor(self.timestep_map, device=t.device, dtype=t.dtype)
        mapped_t = map_tensor[t]

        if self.CFG == 1.0:
            pred_noise = model(x, mapped_t, condition)
        else:
            if self.unconditional_condition is None:
                raise ValueError("For CFG != 1.0, unconditional_condition must be available")
            # Classifier-free guidance: run the unconditional and conditional
            # passes in a single doubled batch, then extrapolate.
            unconditional_condition = self.unconditional_condition.unsqueeze(0).repeat(
                *([x.shape[0]] + [1] * len(self.unconditional_condition.shape)))
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([mapped_t] * 2)
            c_in = torch.cat([unconditional_condition, condition])
            noise_uncond, noise = model(x_in, t_in, c_in).chunk(2)
            pred_noise = noise_uncond + self.CFG * (noise - noise_uncond)

        alpha_cumprod_t = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
        alpha_cumprod_t_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
        pred_x0 = (x - torch.sqrt((1. - alpha_cumprod_t)) * pred_noise) / torch.sqrt(alpha_cumprod_t)
        sigmas_t = (
            ddim_eta
            * torch.sqrt((1 - alpha_cumprod_t_prev) / (1 - alpha_cumprod_t))
            * torch.sqrt(1 - alpha_cumprod_t / alpha_cumprod_t_prev)
        )
        pred_dir_xt = torch.sqrt(1 - alpha_cumprod_t_prev - sigmas_t ** 2) * pred_noise
        # Drawn even when ddim_eta == 0 so that the random-number stream is
        # consumed identically for every sampler setting.
        step_noise, _ = self.get_deterministic_noise_tensor(x.shape[0], x.shape[3])
        x_prev = torch.sqrt(alpha_cumprod_t_prev) * pred_x0 + pred_dir_xt + sigmas_t * step_noise
        return x_prev

    def p_sample(self, model, x, t, condition=None, sampler="ddim"):
        """Run one reverse step with the named sampler.

        "ddim" uses ``ddim_eta=0.0`` and "ddpm" uses ``ddim_eta=1.0``; any
        other name raises ``NotImplementedError``.
        """
        if sampler == "ddim":
            return self.ddim_sample(model, x, t, condition=condition, ddim_eta=0.0)
        elif sampler == "ddpm":
            return self.ddim_sample(model, x, t, condition=condition, ddim_eta=1.0)
        else:
            raise NotImplementedError()

    def get_dynamic_masks(self, n_masks, shape, concat_points, mask_flexivity=0.8):
        """Build one inpainting mask per sampling step (1 = frozen region).

        The first ``int(n_masks * mask_flexivity)`` masks ("guidance steps")
        freeze a portion of every segment delimited by ``concat_points``; that
        portion shrinks linearly from the whole segment to nothing.  The
        remaining masks ("free steps") freeze only the trailing release
        region.  The release region is frozen in every mask.

        :param n_masks: number of masks to build.
        :param shape: ``(batch, channels, height, width)`` of the samples.
        :param concat_points: segment start offsets as returned by
            ``get_deterministic_noise_tensor_repeat``.
        :param mask_flexivity: fraction of the masks that are guidance masks.
        :return: list of ``(batch, 1, height, width)`` float masks, stored in
            reverse order so that ``list.pop()`` yields them in step order.
        """
        release_length = int(self.train_width / 4)
        assert shape[3] == (concat_points[-1] + release_length), \
            "shape[3] != (concat_points[-1] + release_length)"
        fraction_lengths = [end - start for start, end in zip(concat_points[:-1], concat_points[1:])]

        n_guidance_steps = int(n_masks * mask_flexivity)
        n_free_steps = n_masks - n_guidance_steps
        width = shape[3]
        # Explicit start index: ``-release_length:`` would select the whole
        # axis when release_length is 0.
        release_start = width - release_length
        last_fraction = len(fraction_lengths) - 1

        def release_only_mask():
            mask = torch.zeros((shape[0], 1, shape[2], width), dtype=torch.float32, device=self.device)
            mask[:, :, :, release_start:] = 1.0
            return mask

        masks = []
        for i in range(n_guidance_steps):
            # Share of each segment that stays frozen at this step: 1 at the
            # first step, 0 at the last.  A single guidance step is treated
            # as a first step (avoids dividing by zero).
            if n_guidance_steps > 1:
                keep_ratio = (n_guidance_steps - 1 - i) / (n_guidance_steps - 1)
            else:
                keep_ratio = 1.0
            step_i_mask = release_only_mask()
            for fraction_index, fraction_length in enumerate(fraction_lengths):
                frozen_length = int(keep_ratio * fraction_length)
                if fraction_index == 0:
                    # First segment: freeze from its left edge.
                    step_i_mask[:, :, :, :frozen_length] = 1.0
                elif fraction_index == last_fraction:
                    # Last segment: freeze the part adjacent to the release region.
                    if frozen_length != 0:
                        step_i_mask[:, :, :, release_start - frozen_length:] = 1.0
                else:
                    # Inner segments: freeze a centred window.
                    start = concat_points[fraction_index] + int((fraction_length - frozen_length) / 2)
                    step_i_mask[:, :, :, start:start + frozen_length] = 1.0
            masks.append(step_i_mask)

        for _ in range(n_free_steps):
            masks.append(release_only_mask())

        masks.reverse()
        return masks

    def p_sample_loop(self, model, shape, initial_noise=None, start_noise_level_ratio=1.0, end_noise_level_ratio=0.0,
                      return_tensor=False, condition=None, guide_img=None,
                      mask=None, sampler="ddim", inpaint=False, use_dynamic_mask=False, mask_flexivity=0.8):
        """Run the reverse-diffusion loop.

        :param model: noise-prediction callable, see ``ddim_sample``.
        :param shape: ``(batch, channels, height, width)`` of the samples.
        :param initial_noise: optional reference noise handed to the noise
            generator; fresh noise is drawn when None.
        :param start_noise_level_ratio: fraction of the schedule at which
            sampling starts (the resulting index itself is excluded).
        :param end_noise_level_ratio: fraction of the schedule at which
            sampling stops.
        :param return_tensor: keep intermediate results as tensors instead of
            converting them to numpy arrays.
        :param condition: conditioning passed to the model.
        :param guide_img: guide sample of width ``train_width``; required when
            not starting from pure noise, for inpainting and for dynamic masks.
        :param mask: static inpainting mask (1 = keep the guide).
        :param sampler: "ddim" or "ddpm".
        :param inpaint: blend the guide back in after every step.
        :param use_dynamic_mask: use ``get_dynamic_masks`` instead of ``mask``.
        :param mask_flexivity: forwarded to ``get_dynamic_masks``.
        :return: ``(imgs, initial_noise)``; ``imgs[0]`` is the starting tensor
            and one entry is appended per step (tensor or numpy array
            according to ``return_tensor``).
        :raises ValueError: if inpainting or dynamic masks are requested
            without the inputs they need.
        """
        assert shape[1] == self.channels, "shape[1] != self.channels"
        assert shape[2] == self.height, "shape[2] != self.height"
        if (inpaint or use_dynamic_mask) and guide_img is None:
            raise ValueError("inpaint / use_dynamic_mask require a guide_img")
        if inpaint and not use_dynamic_mask and mask is None:
            raise ValueError("inpaint requires a mask unless use_dynamic_mask is set")

        initial_noise, _ = self.get_deterministic_noise_tensor(shape[0], shape[3], reference_noise=initial_noise)
        assert initial_noise.shape == shape, "initial_noise.shape != shape"

        start_noise_level_index = int(self.num_timesteps * start_noise_level_ratio)  # not included!!!
        end_noise_level_index = int(self.num_timesteps * end_noise_level_ratio)
        n_steps = start_noise_level_index - end_noise_level_index
        timesteps = reversed(range(end_noise_level_index, start_noise_level_index))

        # configure initial img
        assert (start_noise_level_ratio == 1.0) or (guide_img is not None), \
            "A guide_img must be given to sample from a non-pure-noise."
        concat_points = None
        if guide_img is None:
            img = initial_noise
        else:
            # Stretch / shrink the guide to the requested width the same way
            # the repeat-strategy noise is built.
            guide_img, concat_points = self.get_deterministic_noise_tensor_repeat(
                shape[0], shape[3], reference_noise=guide_img)
            assert guide_img.shape == shape, "guide_img.shape != shape"
            if start_noise_level_index > 0:
                # -1 because start_noise_level_index itself is not included
                t = torch.full((shape[0],), start_noise_level_index - 1, device=self.device).long()
                img = self.q_sample(guide_img, t, noise=initial_noise)
            else:
                print("Zero noise added to the guidance latent representation.")
                img = guide_img

        # get masks
        if use_dynamic_mask:
            masks = self.get_dynamic_masks(n_steps, shape, concat_points, mask_flexivity)
        else:
            masks = [mask for _ in range(n_steps)]

        imgs = [img]
        current_mask = None
        for i in tqdm(timesteps, total=n_steps, disable=self.mute):
            step_t = torch.full((shape[0],), i, device=self.device, dtype=torch.long)
            img = self.p_sample(model, img, step_t, condition=condition, sampler=sampler)

            if inpaint:
                if i > 0:
                    # Re-noise the guide to the level the sample has now and
                    # paste it into the frozen region.
                    t = torch.full((shape[0],), int(i - 1), device=self.device).long()
                    img_noise_t = self.q_sample(guide_img, t, noise=initial_noise)
                    current_mask = masks.pop()
                    img = current_mask * img_noise_t + (1 - current_mask) * img
                else:
                    # Final step: paste the clean guide, reusing the mask of
                    # the previous step.  If this is the only step no mask has
                    # been taken yet, so fetch one.
                    if current_mask is None:
                        current_mask = masks.pop()
                    img = current_mask * guide_img + (1 - current_mask) * img

            if return_tensor:
                imgs.append(img)
            else:
                imgs.append(img.cpu().numpy())
        return imgs, initial_noise

    def sample(self, model, shape, return_tensor=False, condition=None, sampler="ddim", initial_noise=None, seed=None):
        """Sample from pure noise over the whole schedule.

        :param seed: when given, ``torch.manual_seed(seed)`` is called first.
        :return: see ``p_sample_loop``.
        """
        if seed is not None:
            torch.manual_seed(seed)
        return self.p_sample_loop(model, shape, initial_noise=initial_noise,
                                  start_noise_level_ratio=1.0, end_noise_level_ratio=0.0,
                                  return_tensor=return_tensor, condition=condition, sampler=sampler)

    def interpolate(self, model, shape, variance, first_endpoint=None, second_endpoint=None, return_tensor=False,
                    condition=None, sampler="ddim", seed=None):
        """Sample a batch whose initial noise comes from ``generate_linear_noise``.

        :param seed: when given, ``torch.manual_seed(seed)`` is called first.
        :return: see ``p_sample_loop``.
        """
        if seed is not None:
            torch.manual_seed(seed)
        linear_noise = self.generate_linear_noise(shape, variance, first_endpoint=first_endpoint,
                                                  second_endpoint=second_endpoint)
        return self.p_sample_loop(model, shape, initial_noise=linear_noise,
                                  start_noise_level_ratio=1.0, end_noise_level_ratio=0.0,
                                  return_tensor=return_tensor, condition=condition, sampler=sampler)

    def img_guided_sample(self, model, shape, noising_strength, guide_img, return_tensor=False, condition=None,
                          sampler="ddim", initial_noise=None, seed=None):
        """Noise ``guide_img`` up to ``noising_strength`` and denoise it again.

        :param noising_strength: passed as ``start_noise_level_ratio``.
        :param seed: when given, ``torch.manual_seed(seed)`` is called first.
        :return: see ``p_sample_loop``.
        """
        if seed is not None:
            torch.manual_seed(seed)
        # Only the last axis (width) of the guide is compared with the shape.
        assert guide_img.shape[-1] == shape[-1], "guide_img.shape[:-1] != shape[:-1]"
        return self.p_sample_loop(model, shape, start_noise_level_ratio=noising_strength,
                                  end_noise_level_ratio=0.0,
                                  return_tensor=return_tensor, condition=condition, sampler=sampler,
                                  guide_img=guide_img, initial_noise=initial_noise)

    def inpaint_sample(self, model, shape, noising_strength, guide_img, mask, return_tensor=False, condition=None,
                       sampler="ddim", initial_noise=None, use_dynamic_mask=False, end_noise_level_ratio=0.0,
                       seed=None, mask_flexivity=0.8):
        """Guided sampling that re-imposes the guide inside the mask each step.

        :param noising_strength: passed as ``start_noise_level_ratio``.
        :param mask: static mask (1 = keep the guide); not used when
            ``use_dynamic_mask`` is True.
        :param seed: when given, ``torch.manual_seed(seed)`` is called first.
        :return: see ``p_sample_loop``.
        """
        if seed is not None:
            torch.manual_seed(seed)
        return self.p_sample_loop(model, shape, start_noise_level_ratio=noising_strength,
                                  end_noise_level_ratio=end_noise_level_ratio,
                                  return_tensor=return_tensor, condition=condition, guide_img=guide_img,
                                  mask=mask, sampler=sampler, inpaint=True, initial_noise=initial_noise,
                                  use_dynamic_mask=use_dynamic_mask, mask_flexivity=mask_flexivity)