import random

import torch
def _is_tensor_video_clip(clip):
    """Validate that clip is a 4D torch.Tensor, raising otherwise."""
    if not torch.is_tensor(clip):
        raise TypeError("clip should be Tensor. Got %s" % type(clip))
    if clip.ndimension() != 4:
        raise ValueError("clip should be 4D. Got %dD" % clip.dim())
    return True
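
# A minimal sanity-check sketch for the validator above; the _demo_* helpers
# in this file are illustrative additions, not part of the original module.
def _demo_is_tensor_video_clip():
    ok = torch.zeros(8, 3, 64, 64, dtype=torch.uint8)  # 4D (T, C, H, W)
    assert _is_tensor_video_clip(ok)
    try:
        _is_tensor_video_clip(torch.zeros(3, 64, 64))  # 3D: should raise
    except ValueError:
        pass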

def to_tensor(clip):
    """
    Convert clip from uint8 to float and scale its values to [0, 1]
    by dividing by 255.0.
    Args:
        clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
    Return:
        clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
    """
    _is_tensor_video_clip(clip)
    if clip.dtype != torch.uint8:
        raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
    return clip.float() / 255.0
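
# Usage sketch for to_tensor (illustrative, not in the original file):
# a dummy uint8 clip becomes a float tensor with values in [0, 1].
def _demo_to_tensor():
    clip = torch.randint(0, 256, (16, 3, 240, 320), dtype=torch.uint8)
    out = to_tensor(clip)
    assert out.dtype == torch.float32
    assert float(out.min()) >= 0.0 and float(out.max()) <= 1.0
    return out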

def resize(clip, target_size, interpolation_mode):
    if len(target_size) != 2:
        raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
    # align_corners is only accepted by the interpolating modes; passing it
    # with e.g. "nearest" raises, so fall back to None for those modes.
    align_corners = False if interpolation_mode in ("linear", "bilinear", "bicubic", "trilinear") else None
    return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=align_corners)
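
# Usage sketch for resize (illustrative, not in the original file). Because
# interpolate treats the leading T dimension of a (T, C, H, W) clip as the
# batch, every frame is resized spatially in a single call.
def _demo_resize():
    clip = torch.rand(16, 3, 240, 320)  # (T, C, H, W), float
    out = resize(clip, target_size=(128, 128), interpolation_mode="bilinear")
    assert out.shape == (16, 3, 128, 128)
    return out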

class ToTensorVideo:
    """
    Convert clip from uint8 to float and scale its values to [0, 1]
    by dividing by 255.0.
    """
    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
        Return:
            clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
        """
        return to_tensor(clip)

    def __repr__(self) -> str:
        return self.__class__.__name__

class ResizeVideo:
    """
    Resize the clip to the specified spatial size.
    """
    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        if isinstance(size, (tuple, list)):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = tuple(size)
        else:
            self.size = (size, size)
        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be resized. Size is (T, C, H, W)
        Returns:
            torch.tensor: resized video clip.
                Size is (T, C, h, w)
        """
        return resize(clip, target_size=self.size, interpolation_mode=self.interpolation_mode)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"

class TemporalRandomCrop(object):
    """Temporally crop the given frame indices at a random location.
    Args:
        size (int): Desired number of frames to be seen by the model.
    """
    def __init__(self, size):
        self.size = size

    def __call__(self, total_frames):
        # random.randint is inclusive on both ends, so the last valid start
        # index is total_frames - self.size (the window then ends exactly
        # at total_frames).
        rand_end = max(0, total_frames - self.size)
        begin_index = random.randint(0, rand_end)
        end_index = min(begin_index + self.size, total_frames)
        return begin_index, end_index
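
# Usage sketch for TemporalRandomCrop (illustrative, not in the original
# file): sample a window of frame indices, then slice the clip with it.
def _demo_temporal_random_crop():
    cropper = TemporalRandomCrop(size=16)
    clip = torch.randint(0, 256, (100, 3, 240, 320), dtype=torch.uint8)
    begin, end = cropper(clip.shape[0])  # indices into the T dimension
    window = clip[begin:end]
    assert window.shape[0] == 16
    return window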