import numpy as np
import torch

try:
    from tqdm import tqdm
except ImportError:  # tqdm only draws the progress bar; sampling itself does not need it
    def tqdm(iterable, *args, **kwargs):
        """Fallback when tqdm is not installed: iterate without a progress bar."""
        return iterable


def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Extract values from a 1-D numpy array for a batch of indices.

    :param arr: the 1-D numpy array.
    :param timesteps: a tensor of indices into the array to extract.
    :param broadcast_shape: a larger shape of K dimensions with the batch
                            dimension equal to the length of timesteps.
    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
    """
    res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    # Append trailing singleton axes until the rank matches the target shape.
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res.expand(broadcast_shape)


class DiffSynthSampler:
    """
    DDIM/DDPM sampler over a linear beta schedule.

    Besides plain sampling it supports timestep respacing, classifier-free
    guidance, image-guided sampling and (optionally dynamically masked)
    inpainting. Noise for widths other than ``train_width`` can be built by
    cutting or repeating a ``train_width``-wide noise tensor
    (``noise_strategy="repeat"``) or by cropping a ``max_width``-wide tensor
    (any other strategy value).
    """

    def __init__(self, timesteps, beta_start=0.0001, beta_end=0.02, device=None, mute=False,
                 height=128, max_batchsize=16, max_width=256, channels=4, train_width=64,
                 noise_strategy="repeat"):
        """
        :param timesteps: number of diffusion steps of the (un-respaced) schedule.
        :param beta_start: first beta of the linear schedule.
        :param beta_end: last beta of the linear schedule.
        :param device: torch device; defaults to "cuda" when available, else "cpu".
        :param mute: if True, the sampling progress bar is disabled.
        :param height: size of dimension 2 of the tensors being sampled.
        :param max_batchsize: batch size of freshly drawn noise tensors.
        :param max_width: width of freshly drawn noise in the non-repeat strategy.
        :param channels: size of dimension 1 of the tensors being sampled.
        :param train_width: width of the base noise tensor in the repeat strategy.
        :param noise_strategy: "repeat" selects the cut/repeat construction;
                               anything else selects the cropping construction.
        """
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        self.height = height
        self.train_width = train_width
        self.max_batchsize = max_batchsize
        self.max_width = max_width
        self.channels = channels

        self.num_timesteps = timesteps
        # Maps the (possibly respaced) step index to the original step index.
        self.timestep_map = list(range(self.num_timesteps))
        self.betas = np.array(np.linspace(beta_start, beta_end, self.num_timesteps), dtype=np.float64)
        self.respaced = False
        self.define_beta_schedule()

        self.CFG = 1.0
        # Set by activate_classifier_free_guidance(); only read when CFG != 1.0.
        self.unconditional_condition = None
        self.mute = mute
        self.noise_strategy = noise_strategy

    def get_deterministic_noise_tensor_non_repeat(self, batchsize, width, reference_noise=None):
        """
        Return noise cropped from a ``max_width``-wide tensor.

        :param batchsize: number of batch entries to return.
        :param width: number of columns (dimension 3) to return.
        :param reference_noise: optional tensor of shape
                                (batchsize, channels, height, max_width) to crop
                                instead of drawing fresh noise.
        :return: ``(noise, None)``; there are no concatenation points here.
        """
        if reference_noise is None:
            large_noise_tensor = torch.randn(
                (self.max_batchsize, self.channels, self.height, self.max_width), device=self.device)
        else:
            assert reference_noise.shape == (batchsize, self.channels, self.height, self.max_width), \
                "reference_noise shape mismatch"
            large_noise_tensor = reference_noise
        return large_noise_tensor[:batchsize, :, :, :width], None

    def get_deterministic_noise_tensor(self, batchsize, width, reference_noise=None):
        """
        Build a noise tensor using the configured ``noise_strategy``.

        :return: ``(noise, concat_points)`` as produced by the selected strategy.
        """
        if self.noise_strategy == "repeat":
            return self.get_deterministic_noise_tensor_repeat(batchsize, width, reference_noise=reference_noise)
        return self.get_deterministic_noise_tensor_non_repeat(batchsize, width, reference_noise=reference_noise)

    @staticmethod
    def _concat_points(parts):
        """Return the start column of every part in ``parts`` once concatenated along dim 3."""
        points = [0]
        for part in parts[:-1]:
            points.append(points[-1] + part.size(3))
        return points

    def get_deterministic_noise_tensor_repeat(self, batchsize, width, reference_noise=None):
        """
        Build noise of an arbitrary width from a ``train_width``-wide tensor.

        The base tensor is split into a "first part" (the leading 3/4) and a
        "release part" (the trailing ``int(train_width / 4)`` columns). The
        release part is always kept at the end; the first part is shortened by
        dropping its middle, or lengthened by repeating its halves around a
        middle slice.

        :param batchsize: number of batch entries to return.
        :param width: requested number of columns; must be at least the release width.
        :param reference_noise: optional tensor of shape
                                (batchsize, channels, height, train_width) used
                                instead of fresh noise.
        :return: ``(noise, concat_points)`` where ``concat_points`` lists the
                 start column of each concatenated part.
        """
        # Noise with the same width as the training data.
        if reference_noise is None:
            train_noise_tensor = torch.randn(
                (self.max_batchsize, self.channels, self.height, self.train_width), device=self.device)
        else:
            assert reference_noise.shape == (batchsize, self.channels, self.height, self.train_width), \
                "reference_noise shape mismatch"
            train_noise_tensor = reference_noise

        release_width = int(self.train_width * 1.0 / 4)
        first_part_width = self.train_width - release_width
        assert width >= release_width, "width must not be smaller than the release width"

        first_part = train_noise_tensor[:batchsize, :, :, :first_part_width]
        # Explicit start index: a negative-index slice with a zero width would select everything.
        release_part = train_noise_tensor[:batchsize, :, :, first_part_width:]

        if width <= self.train_width:
            # Requested width does not exceed the base width: drop the middle of first_part.
            head_width = int((width - release_width) / 2)
            tail_width = width - release_width - head_width
            all_parts = [
                first_part[:, :, :, :head_width],
                first_part[:, :, :, first_part_width - tail_width:],
                release_part,
            ]
        else:
            # Requested width exceeds the base width: repeat the halves of
            # first_part around a slice taken from its middle.
            repeats = (width - release_width) // first_part_width
            extra = (width - release_width) % first_part_width

            head_width = int(first_part_width / 2)
            tail_width = first_part_width - head_width
            repeated_heads = [first_part[:, :, :, :head_width] for _ in range(repeats)]
            repeated_tails = [first_part[:, :, :, first_part_width - tail_width:] for _ in range(repeats)]

            # Centered slice of first_part that supplies the remaining columns.
            middle_start = (first_part_width - extra) // 2
            middle_part = first_part[:, :, :, middle_start: middle_start + extra]

            all_parts = repeated_heads + [middle_part] + repeated_tails + [release_part]

        # Concatenate along the width dimension and record where each part starts.
        noise_tensor = torch.cat(all_parts, dim=3)
        return noise_tensor, self._concat_points(all_parts)

    def define_beta_schedule(self):
        """Derive all alpha/beta dependent coefficient arrays from ``self.betas``."""
        assert not self.respaced, "This schedule has already been respaced!"
        # define alphas
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
        self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recip_alphas = np.sqrt(1.0 / self.alphas)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )

    def activate_classifier_free_guidance(self, CFG, unconditional_condition):
        """
        Enable classifier-free guidance.

        :param CFG: guidance scale; 1.0 disables guidance.
        :param unconditional_condition: condition used for the unconditional
                                        branch; required when ``CFG != 1.0``.
        """
        assert unconditional_condition is not None or CFG == 1.0, \
            "For CFG != 1.0, unconditional_condition must be available"
        self.CFG = CFG
        self.unconditional_condition = unconditional_condition

    def respace(self, use_timesteps=None):
        """
        Restrict the schedule to a subset of the original timesteps.

        Betas are recomputed so that the cumulative alphas at the retained
        steps are unchanged. Does nothing when ``use_timesteps`` is None.

        :param use_timesteps: iterable of original step indices to keep.
        """
        if use_timesteps is None:
            return

        # A set gives O(1) membership; int() normalizes numpy/torch scalars.
        wanted = {int(step) for step in use_timesteps}
        last_alpha_cumprod = 1.0
        new_betas = []
        self.timestep_map = []
        for i, _alpha_cumprod in enumerate(self.alphas_cumprod):
            if i in wanted:
                new_betas.append(1 - _alpha_cumprod / last_alpha_cumprod)
                last_alpha_cumprod = _alpha_cumprod
                self.timestep_map.append(i)

        # Count the steps actually retained so that num_timesteps always
        # matches the length of the new schedule (duplicates / out-of-range
        # entries in use_timesteps would otherwise desynchronize them).
        self.num_timesteps = len(new_betas)
        self.betas = np.array(new_betas)
        self.define_beta_schedule()
        self.respaced = True

    def generate_linear_noise(self, shape, variance=1.0, first_endpoint=None, second_endpoint=None):
        """
        Generate a batch of noise tensors that lie on a straight line.

        With both endpoints given, the batch linearly interpolates from
        ``first_endpoint`` to ``second_endpoint``. Otherwise the first two
        entries are taken from ``first_endpoint`` (if given) and freshly
        generated noise, the rest is extrapolated linearly, and the batch is
        rescaled to the requested variance. ``second_endpoint`` alone is ignored.

        :param shape: (batch, channels, height, width) of the result.
        :param variance: target variance of the batch when extrapolating.
        :param first_endpoint: optional tensor for the first batch entry.
        :param second_endpoint: optional tensor for the last batch entry.
        :return: tensor of the given shape on ``self.device``.
        """
        assert shape[1] == self.channels, "shape[1] != self.channels"
        assert shape[2] == self.height, "shape[2] != self.height"

        noise = torch.empty(*shape, device=self.device)

        # Both endpoints given: plain linear interpolation, no rescaling afterwards.
        if first_endpoint is not None and second_endpoint is not None:
            # A batch of one has no interval to interpolate over; avoid 0/0.
            denominator = max(shape[0] - 1, 1)
            for i in range(shape[0]):
                alpha = i / denominator  # interpolation coefficient
                noise[i] = alpha * second_endpoint + (1 - alpha) * first_endpoint
            return noise

        # get_deterministic_noise_tensor returns (noise, concat_points) with a
        # leading batch dimension of 1, hence the [0][0] to get one sample.
        if first_endpoint is not None:
            noise[0] = first_endpoint
        else:
            noise[0] = self.get_deterministic_noise_tensor(1, shape[3])[0][0]
        if shape[0] > 1:
            noise[1] = self.get_deterministic_noise_tensor(1, shape[3])[0][0]

        # Extrapolate the remaining points along the line through the first two.
        for i in range(2, shape[0]):
            noise[i] = 2 * noise[i - 1] - noise[i - 2]

        # Rescale the whole batch to the requested variance.
        current_var = noise.var()
        stddev_ratio = torch.sqrt(variance / current_var)
        noise = noise * stddev_ratio

        # Rescaling moved the first entry; shift the batch back onto the endpoint.
        if first_endpoint is not None:
            shift = first_endpoint - noise[0]
            noise += shift

        return noise

    def q_sample(self, x_start, t, noise=None):
        """
        Diffuse the data for a given number of diffusion steps.

        In other words, sample from q(x_t | x_0).

        :param x_start: the initial data batch.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :param noise: if specified, the split-out normal noise.
        :return: A noisy version of x_start.
        """
        assert x_start.shape[1] == self.channels, "shape[1] != self.channels"
        assert x_start.shape[2] == self.height, "shape[2] != self.height"

        if noise is None:
            noise, _ = self.get_deterministic_noise_tensor(x_start.shape[0], x_start.shape[3])

        assert noise.shape == x_start.shape
        return (
            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    @torch.no_grad()
    def ddim_sample(self, model, x, t, condition=None, ddim_eta=0.0):
        """
        Perform one reverse step x_t -> x_{t-1}.

        :param model: callable ``model(x, t, condition)`` returning the predicted noise.
        :param x: current sample.
        :param t: tensor of (respaced) step indices, one per batch entry.
        :param condition: conditioning passed through to the model.
        :param ddim_eta: scale of the stochastic term in the update
                         (``p_sample`` uses 0.0 for "ddim" and 1.0 for "ddpm").
        :return: the sample for the previous step.
        """
        # The model is queried with the original (un-respaced) step indices.
        map_tensor = torch.tensor(self.timestep_map, device=t.device, dtype=t.dtype)
        mapped_t = map_tensor[t]

        if self.CFG == 1.0:
            pred_noise = model(x, mapped_t, condition)
        else:
            # Classifier-free guidance: run the unconditional and conditional
            # branches in a single doubled batch.
            unconditional_condition = self.unconditional_condition.unsqueeze(0).repeat(
                *([x.shape[0]] + [1] * len(self.unconditional_condition.shape)))
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([mapped_t] * 2)
            c_in = torch.cat([unconditional_condition, condition])
            noise_uncond, noise = model(x_in, t_in, c_in).chunk(2)
            pred_noise = noise_uncond + self.CFG * (noise - noise_uncond)

        alpha_cumprod_t = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
        alpha_cumprod_t_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)

        pred_x0 = (x - torch.sqrt((1. - alpha_cumprod_t)) * pred_noise) / torch.sqrt(alpha_cumprod_t)

        sigmas_t = (
            ddim_eta
            * torch.sqrt((1 - alpha_cumprod_t_prev) / (1 - alpha_cumprod_t))
            * torch.sqrt(1 - alpha_cumprod_t / alpha_cumprod_t_prev)
        )

        pred_dir_xt = torch.sqrt(1 - alpha_cumprod_t_prev - sigmas_t ** 2) * pred_noise

        step_noise, _ = self.get_deterministic_noise_tensor(x.shape[0], x.shape[3])

        x_prev = torch.sqrt(alpha_cumprod_t_prev) * pred_x0 + pred_dir_xt + sigmas_t * step_noise
        return x_prev

    def p_sample(self, model, x, t, condition=None, sampler="ddim"):
        """
        Perform one reverse step with the named sampler.

        :param sampler: "ddim" (eta = 0) or "ddpm" (eta = 1).
        :raises NotImplementedError: for any other sampler name.
        """
        if sampler == "ddim":
            return self.ddim_sample(model, x, t, condition=condition, ddim_eta=0.0)
        elif sampler == "ddpm":
            return self.ddim_sample(model, x, t, condition=condition, ddim_eta=1.0)
        else:
            raise NotImplementedError()

    def get_dynamic_masks(self, n_masks, shape, concat_points, mask_flexivity=0.8):
        """
        Build one inpainting mask per step, shrinking over the guided steps.

        A mask value of 1 marks a frozen (guided) region. The release region
        (the last ``int(train_width / 4)`` columns) is frozen in every mask.
        During the first ``int(n_masks * mask_flexivity)`` steps the frozen
        portion of every fraction between concat points shrinks linearly from
        the whole fraction to nothing; the remaining steps freeze only the
        release region.

        :param n_masks: number of masks to create.
        :param shape: (batch, channels, height, width) of the sampled tensor.
        :param concat_points: start columns of the concatenated noise parts.
        :param mask_flexivity: share of the steps that use shrinking masks.
        :return: list of masks of shape (batch, 1, height, width), ordered so
                 that ``list.pop()`` yields them in sampling order.
        """
        release_length = int(self.train_width / 4)
        assert shape[3] == (concat_points[-1] + release_length), \
            "shape[3] != (concat_points[-1] + release_length)"

        fraction_lengths = [concat_points[i + 1] - concat_points[i] for i in range(len(concat_points) - 1)]
        last_fraction_index = len(fraction_lengths) - 1
        release_start = shape[3] - release_length

        # Todo: remove hard-coding
        n_guidance_steps = int(n_masks * mask_flexivity)
        n_free_steps = n_masks - n_guidance_steps

        masks = []
        # Todo: shrink the mask within part of the steps. In other words, in the
        # later steps regions outside the release are not inpainted but img2img-ed.
        for i in range(n_guidance_steps):
            # mask = 1, freeze
            step_i_mask = torch.zeros((shape[0], 1, shape[2], shape[3]), dtype=torch.float32).to(self.device)
            step_i_mask[:, :, :, release_start:] = 1.0

            # Share of each fraction that stays frozen at this step. With a
            # single guidance step the linear ramp is undefined (0/0); use the
            # ramp's starting value.
            if n_guidance_steps > 1:
                keep_ratio = (n_guidance_steps - 1 - i) / (n_guidance_steps - 1)
            else:
                keep_ratio = 1.0

            for fraction_index, fraction_length in enumerate(fraction_lengths):
                _fraction_mask_length = int(keep_ratio * fraction_length)

                if fraction_index == 0:
                    # First fraction: frozen region is anchored at the left edge.
                    step_i_mask[:, :, :, :_fraction_mask_length] = 1.0
                elif fraction_index == last_fraction_index:
                    # Last fraction: frozen region is anchored to the release region.
                    if _fraction_mask_length != 0:
                        step_i_mask[:, :, :, release_start - _fraction_mask_length:] = 1.0
                else:
                    # Inner fractions: frozen region is centered in the fraction.
                    fraction_mask_start_position = int((fraction_length - _fraction_mask_length) / 2)
                    start = concat_points[fraction_index] + fraction_mask_start_position
                    step_i_mask[:, :, :, start:start + _fraction_mask_length] = 1.0
            masks.append(step_i_mask)

        for i in range(n_free_steps):
            step_i_mask = torch.zeros((shape[0], 1, shape[2], shape[3]), dtype=torch.float32).to(self.device)
            step_i_mask[:, :, :, release_start:] = 1.0
            masks.append(step_i_mask)

        # The sampling loop consumes masks with pop(), i.e. from the end.
        masks.reverse()
        return masks

    @torch.no_grad()
    def p_sample_loop(self, model, shape, initial_noise=None, start_noise_level_ratio=1.0,
                      end_noise_level_ratio=0.0, return_tensor=False, condition=None, guide_img=None,
                      mask=None, sampler="ddim", inpaint=False, use_dynamic_mask=False, mask_flexivity=0.8):
        """
        Run the reverse diffusion loop.

        :param model: callable ``model(x, t, condition)`` returning predicted noise.
        :param shape: (batch, channels, height, width) of the result.
        :param initial_noise: optional reference noise handed to the noise builder.
        :param start_noise_level_ratio: fraction of the schedule to start from
                                        (the start index itself is excluded).
        :param end_noise_level_ratio: fraction of the schedule to stop at.
        :param return_tensor: if True collect tensors, otherwise numpy arrays
                              (the very first entry is always a tensor).
        :param condition: conditioning passed to the model.
        :param guide_img: image to start from / inpaint around; required when
                          not starting from pure noise or when inpainting.
        :param mask: static inpainting mask (1 = keep guide_img).
        :param sampler: "ddim" or "ddpm".
        :param inpaint: if True, blend the guide image back in after every step.
        :param use_dynamic_mask: if True, use masks from ``get_dynamic_masks``.
        :param mask_flexivity: forwarded to ``get_dynamic_masks``.
        :return: ``(imgs, initial_noise)`` where ``imgs`` holds the initial
                 state followed by the result of every step.
        """
        assert shape[1] == self.channels, "shape[1] != self.channels"
        assert shape[2] == self.height, "shape[2] != self.height"

        initial_noise, _ = self.get_deterministic_noise_tensor(shape[0], shape[3], reference_noise=initial_noise)
        assert initial_noise.shape == shape, "initial_noise.shape != shape"

        start_noise_level_index = int(self.num_timesteps * start_noise_level_ratio)  # not included!!!
        end_noise_level_index = int(self.num_timesteps * end_noise_level_ratio)

        timesteps = reversed(range(end_noise_level_index, start_noise_level_index))

        # configure initial img
        assert (start_noise_level_ratio == 1.0) or (guide_img is not None), \
            "A guide_img must be given to sample from a non-pure-noise."
        assert (not inpaint) or (guide_img is not None), "A guide_img must be given for inpainting."
        assert (not use_dynamic_mask) or (guide_img is not None), \
            "A guide_img must be given to use dynamic masks."
        assert (not inpaint) or use_dynamic_mask or (mask is not None), \
            "A mask must be given for inpainting without dynamic masks."

        concat_points = None
        if guide_img is None:
            img = initial_noise
        else:
            guide_img, concat_points = self.get_deterministic_noise_tensor_repeat(
                shape[0], shape[3], reference_noise=guide_img)
            assert guide_img.shape == shape, "guide_img.shape != shape"

            if start_noise_level_index > 0:
                # -1 because start_noise_level_index itself is not included
                t = torch.full((shape[0],), start_noise_level_index - 1, device=self.device).long()
                img = self.q_sample(guide_img, t, noise=initial_noise)
            else:
                print("Zero noise added to the guidance latent representation.")
                img = guide_img

        # get masks
        n_masks = start_noise_level_index - end_noise_level_index
        if use_dynamic_mask:
            masks = self.get_dynamic_masks(n_masks, shape, concat_points, mask_flexivity)
        else:
            masks = [mask for _ in range(n_masks)]

        imgs = [img]
        current_mask = None

        for i in tqdm(timesteps, total=n_masks, disable=self.mute):
            step = torch.full((shape[0],), i, device=self.device, dtype=torch.long)
            img = self.p_sample(model, img, step, condition=condition, sampler=sampler)

            if inpaint:
                if i > 0:
                    # Re-noise the guide image to the level the sample is now at
                    # and paste it into the frozen region.
                    t = torch.full((shape[0],), int(i - 1), device=self.device).long()
                    img_noise_t = self.q_sample(guide_img, t, noise=initial_noise)
                    current_mask = masks.pop()
                    img = current_mask * img_noise_t + (1 - current_mask) * img
                else:
                    # Final step: paste the clean guide image, reusing the mask
                    # of the previous step. If this is the only step of the
                    # loop there is no previous mask yet, so take one now.
                    if current_mask is None:
                        current_mask = masks.pop()
                    img = current_mask * guide_img + (1 - current_mask) * img

            if return_tensor:
                imgs.append(img)
            else:
                imgs.append(img.cpu().numpy())

        return imgs, initial_noise

    def sample(self, model, shape, return_tensor=False, condition=None, sampler="ddim", initial_noise=None,
               seed=None):
        """
        Sample from pure noise over the full schedule.

        :param seed: if given, passed to ``torch.manual_seed`` before sampling.
        :return: ``(imgs, initial_noise)`` from ``p_sample_loop``.
        """
        if seed is not None:
            torch.manual_seed(seed)
        return self.p_sample_loop(model, shape, initial_noise=initial_noise, start_noise_level_ratio=1.0,
                                  end_noise_level_ratio=0.0, return_tensor=return_tensor, condition=condition,
                                  sampler=sampler)

    def interpolate(self, model, shape, variance, first_endpoint=None, second_endpoint=None,
                    return_tensor=False, condition=None, sampler="ddim", seed=None):
        """
        Sample from a batch of linearly related noise tensors.

        The initial noise comes from ``generate_linear_noise``; see there for
        the meaning of ``variance`` and the endpoints.

        :return: ``(imgs, initial_noise)`` from ``p_sample_loop``.
        """
        if seed is not None:
            torch.manual_seed(seed)
        linear_noise = self.generate_linear_noise(shape, variance, first_endpoint=first_endpoint,
                                                  second_endpoint=second_endpoint)
        return self.p_sample_loop(model, shape, initial_noise=linear_noise, start_noise_level_ratio=1.0,
                                  end_noise_level_ratio=0.0, return_tensor=return_tensor, condition=condition,
                                  sampler=sampler)

    def img_guided_sample(self, model, shape, noising_strength, guide_img, return_tensor=False,
                          condition=None, sampler="ddim", initial_noise=None, seed=None):
        """
        Noise ``guide_img`` up to ``noising_strength`` and denoise it again.

        :param noising_strength: fraction of the schedule to start from.
        :param guide_img: image whose last dimension must equal ``shape[-1]``.
        :return: ``(imgs, initial_noise)`` from ``p_sample_loop``.
        """
        if seed is not None:
            torch.manual_seed(seed)
        assert guide_img.shape[-1] == shape[-1], "guide_img.shape[-1] != shape[-1]"
        return self.p_sample_loop(model, shape, start_noise_level_ratio=noising_strength,
                                  end_noise_level_ratio=0.0, return_tensor=return_tensor, condition=condition,
                                  sampler=sampler, guide_img=guide_img, initial_noise=initial_noise)

    def inpaint_sample(self, model, shape, noising_strength, guide_img, mask, return_tensor=False,
                       condition=None, sampler="ddim", initial_noise=None, use_dynamic_mask=False,
                       end_noise_level_ratio=0.0, seed=None, mask_flexivity=0.8):
        """
        Inpaint around ``guide_img``: regions where the mask is 1 follow the guide.

        :param noising_strength: fraction of the schedule to start from.
        :param mask: static mask, used when ``use_dynamic_mask`` is False.
        :param use_dynamic_mask: if True, masks come from ``get_dynamic_masks``.
        :return: ``(imgs, initial_noise)`` from ``p_sample_loop``.
        """
        if seed is not None:
            torch.manual_seed(seed)
        return self.p_sample_loop(model, shape, start_noise_level_ratio=noising_strength,
                                  end_noise_level_ratio=end_noise_level_ratio, return_tensor=return_tensor,
                                  condition=condition, guide_img=guide_img, mask=mask, sampler=sampler,
                                  inpaint=True, initial_noise=initial_noise,
                                  use_dynamic_mask=use_dynamic_mask, mask_flexivity=mask_flexivity)