Spaces:
Runtime error
Runtime error
| # Code adapted from https://github.com/JunyaoHu/common_metrics_on_video_quality | |
| import numpy as np | |
| import torch | |
| from tqdm import tqdm | |
| def trans(x): | |
| # if greyscale images add channel | |
| if x.shape[-3] == 1: | |
| x = x.repeat(1, 1, 3, 1, 1) | |
| # permute BTCHW -> BCTHW | |
| x = x.permute(0, 2, 1, 3, 4) | |
| return x | |
| def calculate_fvd(videos1, videos2, device="cuda", method='styleganv'): | |
| if method == 'styleganv': | |
| from .fvd.styleganv.fvd import get_fvd_feats, frechet_distance, load_i3d_pretrained | |
| elif method == 'videogpt': | |
| from .fvd.videogpt.fvd import load_i3d_pretrained | |
| from .fvd.videogpt.fvd import get_fvd_logits as get_fvd_feats | |
| from .fvd.videogpt.fvd import frechet_distance | |
| # videos [batch_size, timestamps, channel, h, w] | |
| assert videos1.shape == videos2.shape | |
| i3d = load_i3d_pretrained(device=device) | |
| fvd_results = [] | |
| # support grayscale input, if grayscale -> channel*3 | |
| # BTCHW -> BCTHW | |
| # videos -> [batch_size, channel, timestamps, h, w] | |
| videos1 = trans(videos1) | |
| videos2 = trans(videos2) | |
| # fvd_results = {} | |
| # for calculate FVD, each clip_timestamp must >= 10 | |
| for clip_timestamp in tqdm(range(10, videos1.shape[-3]+1)): | |
| # print("clip_timestamp", clip_timestamp) | |
| # get a video clip | |
| # videos_clip [batch_size, channel, timestamps[:clip], h, w] | |
| videos_clip1 = videos1[:, :, : clip_timestamp] | |
| videos_clip2 = videos2[:, :, : clip_timestamp] | |
| # get FVD features | |
| feats1 = get_fvd_feats(videos_clip1, i3d=i3d, device=device) | |
| feats2 = get_fvd_feats(videos_clip2, i3d=i3d, device=device) | |
| # calculate FVD when timestamps[:clip] | |
| fvd_results.append(frechet_distance(feats1, feats2)) | |
| return fvd_results[-1] # only the last step | |
| # test code / using example | |
| def main(): | |
| NUMBER_OF_VIDEOS = 8 | |
| VIDEO_LENGTH = 50 | |
| CHANNEL = 3 | |
| SIZE = 64 | |
| videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) | |
| videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) | |
| device = torch.device("cuda") | |
| # device = torch.device("cpu") | |
| import json | |
| result = calculate_fvd(videos1, videos2, device, method='videogpt') | |
| print(json.dumps(result, indent=4)) | |
| result = calculate_fvd(videos1, videos2, device, method='styleganv') | |
| print(json.dumps(result, indent=4)) | |
| if __name__ == "__main__": | |
| main() | |