|
import random |
|
import argparse |
|
import cv2 |
|
from tqdm import tqdm |
|
import numpy as np |
|
import numpy.typing as npt |
|
import torch |
|
from decord import VideoReader, cpu |
|
from torch.nn import functional as F |
|
from pytorchvideo.transforms import ShortSideScale |
|
from torchvision.transforms import Lambda, Compose |
|
from torchvision.transforms._transforms_video import CenterCropVideo |
|
import sys |
|
from torch.utils.data import Dataset, DataLoader, Subset |
|
import os |
|
|
|
sys.path.append(".") |
|
from opensora.models.ae.videobase import CausalVAEModel |
|
import torch.nn as nn |
|
|
|
|
|
def array_to_video(
    image_array: npt.NDArray, fps: float = 30.0, output_file: str = "output_video.mp4"
) -> None:
    """Encode a sequence of RGB frames into an mp4 file with OpenCV.

    Args:
        image_array: sequence of HxWxC uint8 RGB frames, all the same size.
        fps: frame rate of the written video.
        output_file: destination path of the mp4 file.
    """
    frame_height, frame_width, _ = image_array[0].shape
    writer = cv2.VideoWriter(
        output_file,
        cv2.VideoWriter_fourcc(*"mp4v"),
        float(fps),
        (frame_width, frame_height),
    )
    for frame in image_array:
        # OpenCV writes frames in BGR channel order.
        writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    writer.release()
|
|
|
|
|
def custom_to_video(
    x: torch.Tensor, fps: float = 2.0, output_file: str = "output_video.mp4"
) -> None:
    """Convert a (C, T, H, W) tensor with values in [-1, 1] to uint8 frames
    and write them out as an mp4 via :func:`array_to_video`.

    Args:
        x: video tensor; clamped to [-1, 1] before conversion.
        fps: frame rate of the written video.
        output_file: destination path of the mp4 file.
    """
    frames = x.detach().cpu().clamp(-1, 1)
    frames = (frames + 1) / 2  # [-1, 1] -> [0, 1]
    frames = frames.permute(1, 2, 3, 0).float().numpy()  # -> (T, H, W, C)
    frames = (frames * 255).astype(np.uint8)
    array_to_video(frames, fps=fps, output_file=output_file)
|
|
|
|
|
def read_video(video_path: str, num_frames: int, sample_rate: int) -> torch.Tensor:
    """Decode up to ``num_frames`` frames, strided by ``sample_rate``, from a video.

    Args:
        video_path: path of a video file readable by decord.
        num_frames: number of frames requested.
        sample_rate: temporal stride between sampled frames.

    Returns:
        A uint8 tensor permuted to channel-first layout, (C, T, H, W)
        (decord batches decode as (T, H, W, C) — TODO confirm against decord docs).
    """
    decord_vr = VideoReader(video_path, ctx=cpu(0), num_threads=8)
    total_frames = len(decord_vr)
    sample_frames_len = sample_rate * num_frames

    if total_frames > sample_frames_len:
        s = 0
        e = s + sample_frames_len
    else:
        # Video is shorter than the requested span: sample proportionally
        # fewer frames, but never zero — np.linspace with num=0 would return
        # an empty index list and yield an empty tensor downstream.
        s = 0
        e = total_frames
        num_frames = max(1, int(total_frames / sample_frames_len * num_frames))
        print(
            f"sample_frames_len {sample_frames_len}, only can sample {num_frames * sample_rate}",
            video_path,
            total_frames,
        )

    frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int)
    video_data = decord_vr.get_batch(frame_id_list).asnumpy()
    video_data = torch.from_numpy(video_data)
    return video_data.permute(3, 0, 1, 2)
|
|
|
|
|
class RealVideoDataset(Dataset):
    """Dataset of raw videos read from a directory.

    Each item is a dict with:
        - ``'video'``: preprocessed frame tensor (see ``_preprocess``).
        - ``'file_name'``: basename of the source video file.
    """

    def __init__(
        self,
        real_video_dir,
        num_frames,
        sample_rate=1,
        crop_size=None,
        resolution=128,
    ) -> None:
        """
        Args:
            real_video_dir: directory containing the video files.
            num_frames: number of frames to sample per video.
            sample_rate: temporal stride between sampled frames.
            crop_size: optional center-crop size applied after resizing.
            resolution: target short-side size for resizing.
        """
        super().__init__()
        self.real_video_files = self._combine_without_prefix(real_video_dir)
        self.num_frames = num_frames
        self.sample_rate = sample_rate
        self.crop_size = crop_size
        self.short_size = resolution

    def __len__(self):
        return len(self.real_video_files)

    def __getitem__(self, index):
        if index >= len(self):
            raise IndexError
        real_video_file = self.real_video_files[index]
        real_video_tensor = self._load_video(real_video_file)
        video_name = os.path.basename(real_video_file)
        return {'video': real_video_tensor, 'file_name': video_name}

    def _load_video(self, video_path):
        """Decode sampled frames from ``video_path`` and preprocess them."""
        num_frames = self.num_frames
        sample_rate = self.sample_rate
        decord_vr = VideoReader(video_path, ctx=cpu(0))
        total_frames = len(decord_vr)
        sample_frames_len = sample_rate * num_frames

        if total_frames > sample_frames_len:
            s = 0
            e = s + sample_frames_len
        else:
            # Short video: sample proportionally fewer frames, but never zero
            # (np.linspace with num=0 would produce an empty batch).
            s = 0
            e = total_frames
            num_frames = max(1, int(total_frames / sample_frames_len * num_frames))
            print(
                f"sample_frames_len {sample_frames_len}, only can sample {num_frames * sample_rate}",
                video_path,
                total_frames,
            )

        frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int)
        video_data = decord_vr.get_batch(frame_id_list).asnumpy()
        video_data = torch.from_numpy(video_data)
        # Channel-first layout for the downstream transforms.
        video_data = video_data.permute(3, 0, 1, 2)
        return _preprocess(
            video_data, short_size=self.short_size, crop_size=self.crop_size
        )

    def _combine_without_prefix(self, folder_path, prefix="."):
        """Return sorted full paths of entries in ``folder_path`` whose names
        do not start with ``prefix`` (skips hidden files by default)."""
        files = [
            os.path.join(folder_path, name)
            for name in os.listdir(folder_path)
            if not name.startswith(prefix)
        ]
        files.sort()
        return files
|
|
|
def resize(x, resolution):
    """Bilinearly rescale ``x`` so its shorter spatial side equals ``resolution``.

    Aspect ratio is preserved. ``x`` is expected to have spatial dims last,
    in a layout accepted by ``F.interpolate`` (e.g. (N, C, H, W)).
    """
    h, w = x.shape[-2:]
    ratio = w / h
    if w <= h:
        # Width is the short side.
        target = (int(resolution / ratio), resolution)
    else:
        # Height is the short side.
        target = (resolution, int(resolution * ratio))
    return F.interpolate(
        x, size=target, mode='bilinear', align_corners=True, antialias=True
    )
|
|
|
def _preprocess(video_data, short_size=128, crop_size=None):
    """Normalize pixel values to [-1, 1], scale the short side to
    ``short_size``, optionally center-crop, then trim dimensions with
    ``_format_video_shape``."""
    steps = [
        Lambda(lambda x: x / 255.0 * 2 - 1),
        Lambda(lambda x: resize(x, short_size)),
    ]
    if crop_size is not None:
        steps.append(CenterCropVideo(crop_size=crop_size))
    transform = Compose(steps)
    video_outputs = transform(video_data)
    return _format_video_shape(video_outputs)
|
|
|
|
|
def _format_video_shape(video, time_compress=4, spatial_compress=8): |
|
time = video.shape[1] |
|
height = video.shape[2] |
|
width = video.shape[3] |
|
new_time = ( |
|
(time - (time - 1) % time_compress) |
|
if (time - 1) % time_compress != 0 |
|
else time |
|
) |
|
new_height = ( |
|
(height - (height) % spatial_compress) |
|
if height % spatial_compress != 0 |
|
else height |
|
) |
|
new_width = ( |
|
(width - (width) % spatial_compress) if width % spatial_compress != 0 else width |
|
) |
|
return video[:, :new_time, :new_height, :new_width] |
|
|
|
|
|
@torch.no_grad()
def main(args: argparse.Namespace):
    """Reconstruct every video in ``args.real_video_dir`` through a CausalVAE
    checkpoint and write the reconstructions (and, with ``--output_origin``,
    the preprocessed inputs) into ``args.generated_video_dir``.
    """
    real_video_dir = args.real_video_dir
    generated_video_dir = args.generated_video_dir
    sample_rate = args.sample_rate
    resolution = args.resolution
    crop_size = args.crop_size
    num_frames = args.num_frames
    device = args.device
    sample_fps = args.sample_fps
    batch_size = args.batch_size
    num_workers = args.num_workers
    subset_size = args.subset_size

    # exist_ok=True already handles a pre-existing directory; a separate
    # os.path.exists check is redundant.
    os.makedirs(generated_video_dir, exist_ok=True)

    data_type = torch.bfloat16

    vqvae = CausalVAEModel.from_pretrained(args.ckpt)
    vqvae = vqvae.to(device).to(data_type)
    if args.enable_tiling:
        # Tiled decoding trades speed for lower peak memory.
        vqvae.enable_tiling()
        vqvae.tile_overlap_factor = args.tile_overlap_factor

    dataset = RealVideoDataset(
        real_video_dir=real_video_dir,
        num_frames=num_frames,
        sample_rate=sample_rate,
        crop_size=crop_size,
        resolution=resolution,
    )

    if subset_size:
        dataset = Subset(dataset, indices=range(subset_size))

    dataloader = DataLoader(
        dataset, batch_size=batch_size, pin_memory=True, num_workers=num_workers
    )

    for batch in tqdm(dataloader):
        x, file_names = batch['video'], batch['file_name']
        x = x.to(device=device, dtype=data_type)
        latents = vqvae.encode(x).sample().to(data_type)
        video_recon = vqvae.decode(latents)
        for idx, video in enumerate(video_recon):
            output_path = os.path.join(generated_video_dir, file_names[idx])
            if args.output_origin:
                origin_dir = os.path.join(generated_video_dir, "origin/")
                os.makedirs(origin_dir, exist_ok=True)
                origin_output_path = os.path.join(origin_dir, file_names[idx])
                custom_to_video(
                    x[idx], fps=sample_fps / sample_rate, output_file=origin_output_path
                )
            custom_to_video(
                video, fps=sample_fps / sample_rate, output_file=output_path
            )
|
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: reconstruct videos from a directory through a
    # pretrained CausalVAE checkpoint (see main()).
    parser = argparse.ArgumentParser()
    # I/O locations and model checkpoint.
    parser.add_argument("--real_video_dir", type=str, default="")
    parser.add_argument("--generated_video_dir", type=str, default="")
    parser.add_argument("--ckpt", type=str, default="")
    # Frame sampling and spatial preprocessing.
    parser.add_argument("--sample_fps", type=int, default=30)
    parser.add_argument("--resolution", type=int, default=336)
    parser.add_argument("--crop_size", type=int, default=None)
    parser.add_argument("--num_frames", type=int, default=17)
    parser.add_argument("--sample_rate", type=int, default=1)
    # DataLoader settings; subset_size limits the run to the first N videos.
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument("--subset_size", type=int, default=None)
    # Tiled decoding (memory saving) and optional dump of preprocessed inputs.
    parser.add_argument("--tile_overlap_factor", type=float, default=0.25)
    parser.add_argument('--enable_tiling', action='store_true')
    parser.add_argument('--output_origin', action='store_true')
    parser.add_argument("--device", type=str, default="cuda")

    args = parser.parse_args()
    main(args)
|
|
|
|