from torch.utils.data import DataLoader
from dataset import AudioVideoDataset, LatentDataset
import torch as th
import numpy as np
import einops
from moviepy.audio.io.AudioFileClip import AudioFileClip
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
from diffusers.models import AutoencoderKL
from converter import denormalize, denormalize_spectrogram
import soundfile as sf
import os
import json
import torch
from tqdm import tqdm


#################################################################################
#                                  Video Utils                                  #
#################################################################################

def preprocess_video(video):
    # Map model output in [-1, 1] to uint8 frames in [0, 255] (see out2img),
    # then move to (T, H, W, C) numpy layout for the video writer.
    video = out2img(video)
    video = einops.rearrange(video, 't c h w -> t h w c').cpu().numpy()
    return video


def preprocess_video_batch(videos):
    # Preprocess each clip in the batch and stack the results back together.
    return np.stack([preprocess_video(video) for video in videos], axis=0)


def save_latents(video, audio, y, output_path, name_prefix, ext=".pt"):
    os.makedirs(output_path, exist_ok=True)
    th.save(
        {"video": video, "audio": audio, "y": y},
        os.path.join(output_path, name_prefix + ext),
    )


def save_multimodal(video, audio, output_path, name_prefix, video_fps=10,
                    audio_fps=16000, audio_dir=None):
    if not audio_dir:
        audio_dir = output_path

    # prepare folders
    audio_dir = os.path.join(audio_dir, "audio")
    os.makedirs(audio_dir, exist_ok=True)
    audio_path = os.path.join(audio_dir, name_prefix + "_audio.wav")

    video_dir = os.path.join(output_path, "video")
    os.makedirs(video_dir, exist_ok=True)
    video_path = os.path.join(video_dir, name_prefix + "_video.mp4")

    # save audio
    sf.write(audio_path, audio, samplerate=audio_fps)

    # save video with the audio track muxed in
    video = preprocess_video(video)
    imgs = list(video)
    video_clip = ImageSequenceClip(imgs, fps=video_fps)
    audio_clip = AudioFileClip(audio_path)
    video_clip = video_clip.with_audio(audio_clip)
    video_clip.write_videofile(video_path, video_fps, audio=True, audio_fps=audio_fps)


def get_dataloader(args, logger, sequence_length, train, latents=False):
    if latents:
        train_set = LatentDataset(args.data_path, train=train)
    else:
        train_set = AudioVideoDataset(
            args.data_path,
            train=train,
            sample_every_n_frames=1,
            resolution=args.image_size,
            sequence_length=sequence_length,
            audio_channels=1,
            sample_rate=16000,
            min_length=1,
            ignore_cache=args.ignore_cache,
            labeled=args.num_classes > 0,
            target_video_fps=args.target_video_fps,
        )
    loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True,
    )
    msg = f'{"Train" if train else "Test"} dataset contains {len(train_set)} samples ({args.data_path})'
    if logger is not None:
        logger.info(msg)
    else:
        print(msg)
    return loader


@torch.no_grad()
def encode_video(video, vae, use_sd_vae=False):
    b, t, c, h, w = video.shape
    video = einops.rearrange(video, "b t c h w -> (b t) c h w")
    if use_sd_vae:
        # Stable Diffusion VAE: sample the posterior and apply the SD scaling factor.
        video = vae.encode(video).latent_dist.sample().mul_(0.18215)
    else:
        video = vae.encode(video) * vae.cfg.scaling_factor
    video = einops.rearrange(video, "(b t) c h w -> b t c h w", t=t)
    return video
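# Usage sketch (an illustration, not part of the original module): encode_video turns
# a pixel-space clip into per-frame VAE latents, scaled so that decode_video below can
# divide the factor back out. The shapes are assumptions for a 256x256 SD VAE:
#
#     clip = torch.randn(2, 16, 3, 256, 256, device="cuda")  # (B, T, C, H, W) in [-1, 1]
#     z = encode_video(clip, vae, use_sd_vae=True)            # -> (B, T, 4, 32, 32)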
@torch.no_grad()
def decode_video(video, vae):
    b = video.shape[0]
    video_decoded = []
    video = einops.rearrange(video, "b t c h w -> (b t) c h w")
    # decode in minibatches to avoid running out of GPU memory
    for i in range(0, video.shape[0], b):
        if isinstance(vae, AutoencoderKL):
            video_decoded.append(vae.decode(video[i:i + b] / 0.18215).sample.detach().cpu())
        else:
            video_decoded.append(vae.decode(video[i:i + b] / vae.cfg.scaling_factor).detach().cpu())
    video = torch.cat(video_decoded, dim=0)
    video = einops.rearrange(video, "(b t) c h w -> b t c h w", b=b)
    return video


def generate_sample(vae, rectified_flow, forward_fn, video_length, video_latent_size,
                    audio_latent_size, y, cfg_scale, device):
    with torch.no_grad():
        v_z = torch.randn(video_latent_size, device=device) * rectified_flow.noise_scale
        a_z = torch.randn(audio_latent_size, device=device) * rectified_flow.noise_scale
        model_kwargs = dict(y=y, cfg_scale=cfg_scale) if cfg_scale else dict(y=y)
        sample_fn = rectified_flow.sample(
            forward_fn, v_z, a_z, model_kwargs=model_kwargs, progress=True)()

        video = []
        audio = []
        for _ in tqdm(range(video_length), desc="Generating frames"):
            video_samples, audio_samples = next(sample_fn)
            video.append(video_samples)
            audio.append(audio_samples)
        video = torch.stack(video, dim=1)
        audio = torch.stack(audio, dim=1)
        video = decode_video(video, vae)
        audio = einops.rearrange(audio, "B T C N F -> B C N (T F)")
    return video, audio


def generate_sample_a2v(vae, rectified_flow, forward_fn, video_length, video_latent_size,
                        audio, y, device, cfg_scale=1, scale=1):
    # audio-to-video: keep the given audio fixed and sample only the video latents
    v_z = torch.randn(video_latent_size, device=device) * rectified_flow.noise_scale
    model_kwargs = dict(y=y, cfg_scale=cfg_scale) if cfg_scale else dict(y=y)
    sample_fn = rectified_flow.sample_a2v(
        forward_fn, v_z, audio, model_kwargs=model_kwargs, scale=scale, progress=True)()
    video = []
    for _ in tqdm(range(video_length), desc="Generating frames"):
        video_samples = next(sample_fn)
        video.append(video_samples)
    video = torch.stack(video, dim=1)
    video = decode_video(video, vae)
    audio = einops.rearrange(audio, "B T C N F -> B C N (T F)")
    return video, audio


def generate_sample_v2a(vae, rectified_flow, forward_fn, video_length, video,
                        audio_latent_size, y, device, cfg_scale=1, scale=1):
    # video-to-audio: keep the given video fixed and sample only the audio latents
    a_z = torch.randn(audio_latent_size, device=device) * rectified_flow.noise_scale
    model_kwargs = dict(y=y, cfg_scale=cfg_scale) if cfg_scale else dict(y=y)
    sample_fn = rectified_flow.sample_v2a(
        forward_fn, video, a_z, model_kwargs=model_kwargs, scale=scale, progress=True)()
    audio = []
    for _ in tqdm(range(video_length), desc="Generating frames"):
        audio_samples = next(sample_fn)
        audio.append(audio_samples)
    audio = torch.stack(audio, dim=1)
    video = decode_video(video, vae)
    audio = einops.rearrange(audio, "B T C N F -> B C N (T F)")
    return video, audio


def dict_to_json(path, args):
    with open(path, 'w') as f:
        json.dump(args.__dict__, f, indent=2)


def json_to_dict(path, args):
    with open(path, 'r') as f:
        args.__dict__ = json.load(f)
    return args


def log_args(args, logger):
    text = ""
    for k, v in vars(args).items():
        text += f'{k}={v}\n'
    logger.info(f"##### ARGS #####\n{text}")


def out2img(samples):
    # [-1, 1] floats -> [0, 255] uint8
    return th.clamp(127.5 * samples + 128.0, 0, 255).to(dtype=th.uint8).cuda()


def get_gpu_usage():
    device = th.device('cuda:0')
    free, total = th.cuda.mem_get_info(device)
    mem_used_MB = (total - free) / 1024 ** 2
    return mem_used_MB
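# Hedged sketch of how get_wavs below is typically driven (names and shapes are
# assumptions inferred from the squeeze(1) and the "B C N (T F)" rearranges above):
#
#     spec = audio_out                                    # (B, 1, N_mels, T), normalized
#     wavs = get_wavs(spec, vocoder, audio_scale, device)
#     sf.write("sample.wav", wavs[0].cpu().numpy(), samplerate=16000)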
def get_wavs(norm_spec, vocoder, audio_scale, device):
    # undo the training-time scaling and normalization, then vocode to waveforms
    norm_spec = norm_spec.squeeze(1)
    norm_spec = norm_spec / audio_scale
    post_norm_spec = denormalize(norm_spec).to(device)
    raw_chunk_spec = denormalize_spectrogram(post_norm_spec)
    wavs = vocoder.inference(raw_chunk_spec)
    return wavs
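
if __name__ == "__main__":
    # Minimal smoke test for the latent round trip. The checkpoint name and the CUDA
    # device are assumptions for illustration, not requirements of this module.
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").cuda().eval()
    frames = torch.randn(1, 4, 3, 64, 64, device="cuda")  # (B, T, C, H, W) in [-1, 1]
    latents = encode_video(frames, vae, use_sd_vae=True)   # -> (1, 4, 4, 8, 8)
    recon = decode_video(latents, vae)                     # -> (1, 4, 3, 64, 64)
    print(latents.shape, recon.shape, f"GPU used: {get_gpu_usage():.0f} MB")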