import os
import sys
import argparse

import torch
import torchvision
from diffusers.schedulers import (DDIMScheduler, DDPMScheduler, PNDMScheduler,
                                  EulerDiscreteScheduler, DPMSolverMultistepScheduler,
                                  HeunDiscreteScheduler, EulerAncestralDiscreteScheduler,
                                  DEISMultistepScheduler, KDPM2AncestralDiscreteScheduler)
from diffusers.schedulers.scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler
from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
from omegaconf import OmegaConf
from transformers import T5EncoderModel, T5Tokenizer

# Make the repository root importable so the local Latte modules below resolve.
sys.path.append(os.path.split(sys.path[0])[0])
from pipeline_latte import LattePipeline
from models import get_models
from utils import save_video_grid

import imageio
from torchvision.utils import save_image
def main(args):
    # torch.manual_seed(args.seed)
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Latte transformer backbone in fp16.
    transformer_model = get_models(args).to(device, dtype=torch.float16)

    # VAE: optionally use the temporal-decoder variant.
    if args.enable_vae_temporal_decoder:
        vae = AutoencoderKLTemporalDecoder.from_pretrained(args.pretrained_model_path, subfolder="vae_temporal_decoder", torch_dtype=torch.float16).to(device)
    else:
        vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae", torch_dtype=torch.float16).to(device)

    # T5 tokenizer and text encoder for the prompts.
    tokenizer = T5Tokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
    text_encoder = T5EncoderModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder", torch_dtype=torch.float16).to(device)

    # set eval mode
    transformer_model.eval()
    vae.eval()
    text_encoder.eval()
    # Build the sampling scheduler requested by the config; every variant shares the
    # same beta schedule and variance settings, only DDIM/DDPM additionally disable sample clipping.
    scheduler_kwargs = dict(subfolder="scheduler",
                            beta_start=args.beta_start,
                            beta_end=args.beta_end,
                            beta_schedule=args.beta_schedule,
                            variance_type=args.variance_type)
    if args.sample_method == 'DDIM':
        scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_path, clip_sample=False, **scheduler_kwargs)
    elif args.sample_method == 'EulerDiscrete':
        scheduler = EulerDiscreteScheduler.from_pretrained(args.pretrained_model_path, **scheduler_kwargs)
    elif args.sample_method == 'DDPM':
        scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_path, clip_sample=False, **scheduler_kwargs)
    elif args.sample_method == 'DPMSolverMultistep':
        scheduler = DPMSolverMultistepScheduler.from_pretrained(args.pretrained_model_path, **scheduler_kwargs)
    elif args.sample_method == 'DPMSolverSinglestep':
        scheduler = DPMSolverSinglestepScheduler.from_pretrained(args.pretrained_model_path, **scheduler_kwargs)
    elif args.sample_method == 'PNDM':
        scheduler = PNDMScheduler.from_pretrained(args.pretrained_model_path, **scheduler_kwargs)
    elif args.sample_method == 'HeunDiscrete':
        scheduler = HeunDiscreteScheduler.from_pretrained(args.pretrained_model_path, **scheduler_kwargs)
    elif args.sample_method == 'EulerAncestralDiscrete':
        scheduler = EulerAncestralDiscreteScheduler.from_pretrained(args.pretrained_model_path, **scheduler_kwargs)
    elif args.sample_method == 'DEISMultistep':
        scheduler = DEISMultistepScheduler.from_pretrained(args.pretrained_model_path, **scheduler_kwargs)
    elif args.sample_method == 'KDPM2AncestralDiscrete':
        scheduler = KDPM2AncestralDiscreteScheduler.from_pretrained(args.pretrained_model_path, **scheduler_kwargs)
    else:
        raise ValueError('Unknown sample_method: {}'.format(args.sample_method))
    videogen_pipeline = LattePipeline(vae=vae,
                                      text_encoder=text_encoder,
                                      tokenizer=tokenizer,
                                      scheduler=scheduler,
                                      transformer=transformer_model).to(device)
    # videogen_pipeline.enable_xformers_memory_efficient_attention()

    os.makedirs(args.save_img_path, exist_ok=True)
    # video_grids = []
    for num_prompt, prompt in enumerate(args.text_prompt):
        print('Processing prompt: {}'.format(prompt))
        videos = videogen_pipeline(prompt,
                                   video_length=args.video_length,
                                   height=args.image_size[0],
                                   width=args.image_size[1],
                                   num_inference_steps=args.num_sampling_steps,
                                   guidance_scale=args.guidance_scale,
                                   enable_temporal_attentions=args.enable_temporal_attentions,
                                   num_images_per_prompt=1,
                                   mask_feature=True,
                                   enable_vae_temporal_decoder=args.enable_vae_temporal_decoder
                                   ).video

        if videos.shape[1] == 1:
            # Single frame: save as a still image, falling back to the prompt index
            # if the prompt cannot be used as a filename.
            try:
                save_image(videos[0][0], args.save_img_path + prompt.replace(' ', '_') + '.png')
            except Exception:
                save_image(videos[0][0], args.save_img_path + str(num_prompt) + '.png')
                print('Error when saving {}'.format(prompt))
        else:
            try:
                # quality ranges from 0 (lowest) to 10 (highest)
                imageio.mimwrite(args.save_img_path + prompt.replace(' ', '_') + '_%04d' % args.run_time + '.mp4',
                                 videos[0], fps=8, quality=9)
            except Exception:
                print('Error when saving {}'.format(prompt))

    # save video grid
    # video_grids.append(videos)
    # video_grids = torch.cat(video_grids, dim=0)
    # video_grids = save_video_grid(video_grids)
    # # torchvision.io.write_video(args.save_img_path + '_%04d' % args.run_time + '-.mp4', video_grids, fps=6)
    # imageio.mimwrite(args.save_img_path + '_%04d' % args.run_time + '-.mp4', video_grids, fps=8, quality=6)
    # print('save path {}'.format(args.save_img_path))
    # save_videos_grid(video, f"./{prompt}.gif")
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="./configs/wbv10m_train.yaml")
    args = parser.parse_args()

    main(OmegaConf.load(args.config))
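
# Usage sketch (illustrative, not part of the original script): the --config argument
# points to an OmegaConf YAML whose keys mirror the attributes read above. The path
# below is a placeholder; the repository's configs/ directory holds the real settings.
#
#   python sample_t2v.py --config ./configs/t2v_sample.yaml
#
# Keys the script reads from the config: pretrained_model_path, enable_vae_temporal_decoder,
# sample_method, beta_start, beta_end, beta_schedule, variance_type, text_prompt (list of
# strings), video_length, image_size ([height, width]), num_sampling_steps, guidance_scale,
# enable_temporal_attentions, save_img_path, and run_time.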