| """ | |
| Copyright (c) 2022, salesforce.com, inc. | |
| All rights reserved. | |
| SPDX-License-Identifier: BSD-3-Clause | |
| For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
| """ | |

import torch


def _is_tensor_video_clip(clip):
    if not torch.is_tensor(clip):
        raise TypeError("clip should be Tensor. Got %s" % type(clip))
    if clip.ndimension() != 4:
        raise ValueError("clip should be 4D. Got %dD" % clip.ndimension())
    return True


def crop(clip, i, j, h, w):
    """
    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
        i (int): Row index of the upper-left corner of the crop.
        j (int): Column index of the upper-left corner of the crop.
        h (int): Height of the cropped region.
        w (int): Width of the cropped region.
    """
    if len(clip.size()) != 4:
        raise ValueError("clip should be a 4D tensor")
    return clip[..., i : i + h, j : j + w]


def resize(clip, target_size, interpolation_mode):
    if len(target_size) != 2:
        raise ValueError(
            f"target size should be tuple (height, width), instead got {target_size}"
        )
    # interpolate only accepts align_corners for the interpolating modes;
    # passing it alongside e.g. "nearest" raises a ValueError.
    align_corners = (
        False if interpolation_mode in ("linear", "bilinear", "bicubic", "trilinear") else None
    )
    # The (C, T, H, W) clip is treated as a batch of C tensors with T
    # channels, so only the spatial dimensions H and W are resampled.
    return torch.nn.functional.interpolate(
        clip, size=target_size, mode=interpolation_mode, align_corners=align_corners
    )


def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
    """
    Do spatial cropping and resizing to the video clip.
    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
        i (int): Row index of the upper-left corner of the crop.
        j (int): Column index of the upper-left corner of the crop.
        h (int): Height of the cropped region.
        w (int): Width of the cropped region.
        size (tuple(int, int)): height and width of the resized clip
    Returns:
        clip (torch.tensor): Resized and cropped clip. Size is (C, T, H, W)
    """
    # _is_tensor_video_clip raises on invalid input, so no branch is needed.
    _is_tensor_video_clip(clip)
    clip = crop(clip, i, j, h, w)
    clip = resize(clip, size, interpolation_mode)
    return clip


def center_crop(clip, crop_size):
    """
    Crop the spatial center of the clip.
    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
        crop_size (tuple(int, int)): height and width of the crop
    """
    _is_tensor_video_clip(clip)
    h, w = clip.size(-2), clip.size(-1)
    th, tw = crop_size
    if h < th or w < tw:
        raise ValueError("height and width must be no smaller than crop_size")
    i = int(round((h - th) / 2.0))
    j = int(round((w - tw) / 2.0))
    return crop(clip, i, j, th, tw)


def to_tensor(clip):
    """
    Convert a uint8 clip to float, scale values to [0, 1], and permute the
    dimensions from (T, H, W, C) to (C, T, H, W).
    Args:
        clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C)
    Returns:
        clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W)
    """
    _is_tensor_video_clip(clip)
    if clip.dtype != torch.uint8:
        raise TypeError(
            "clip tensor should have data type uint8. Got %s" % str(clip.dtype)
        )
    return clip.float().permute(3, 0, 1, 2) / 255.0
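

# A minimal usage sketch for ``to_tensor`` (shapes here are illustrative
# assumptions, not part of the API):
#
#     frames = torch.randint(0, 256, (16, 240, 320, 3), dtype=torch.uint8)
#     clip = to_tensor(frames)  # torch.float32, shape (3, 16, 240, 320)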


def normalize(clip, mean, std, inplace=False):
    """
    Args:
        clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)
        mean (tuple): per-channel (RGB) mean. Size is (3)
        std (tuple): per-channel (RGB) standard deviation. Size is (3)
    Returns:
        normalized clip (torch.tensor): Size is (C, T, H, W)
    """
    _is_tensor_video_clip(clip)
    if not inplace:
        clip = clip.clone()
    mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
    std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
    # Broadcast the per-channel statistics over the temporal and spatial dims.
    clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
    return clip


def hflip(clip):
    """
    Args:
        clip (torch.tensor): Video clip to be flipped horizontally. Size is (C, T, H, W)
    Returns:
        flipped clip (torch.tensor): Size is (C, T, H, W)
    """
    _is_tensor_video_clip(clip)
    # Flip along the last (width) dimension.
    return clip.flip(-1)
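

if __name__ == "__main__":
    # Smoke test on synthetic data: an illustrative sketch of how these
    # transforms chain together, not part of the library API. The shapes and
    # the mean/std values below are assumptions chosen for the demo.
    frames = torch.randint(0, 256, (8, 128, 160, 3), dtype=torch.uint8)  # (T, H, W, C)

    clip = to_tensor(frames)  # (3, 8, 128, 160), float32 in [0, 1]
    clip = resized_crop(clip, i=10, j=20, h=100, w=120, size=(112, 112))
    clip = center_crop(clip, (96, 96))  # (3, 8, 96, 96)
    clip = normalize(clip, mean=(0.45, 0.45, 0.45), std=(0.225, 0.225, 0.225))
    clip = hflip(clip)

    print(clip.shape, clip.dtype)  # torch.Size([3, 8, 96, 96]) torch.float32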