import os
import tempfile

import cv2
import decord
import ffmpeg
import numpy as np
import scipy.io.wavfile as wav
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from PIL import Image
from torchvision.transforms import Compose, GaussianBlur, Grayscale, Resize

decord.bridge.set_bridge('torch')
torchaudio.set_audio_backend("sox_io")
class AudioEncoder(nn.Module):
    """
    A PyTorch module that encodes audio into fixed-size vectors (embeddings),
    useful for downstream tasks such as classification or similarity matching.
    """
    def __init__(self, path):
        """
        Initialize the AudioEncoder.

        Args:
            path (str): File path of the pre-trained (TorchScript) model.
        """
        super().__init__()
        self.model = torch.jit.load(path)
        self.register_buffer('hidden', torch.zeros(2, 1, 256))

    def forward(self, audio):
        """
        Encode an audio sample. The audio is split into overlapping windows
        and each window is encoded by the underlying model, carrying its
        hidden state across windows.

        Args:
            audio (Tensor): A tensor containing the audio data.

        Returns:
            Tensor: The per-window embeddings, stacked along dim 0.
        """
        self.reset()
        x = create_windowed_sequence(audio, 3200, cutting_stride=640, pad_samples=3200 - 640, cut_dim=1)
        embs = []
        for i in range(x.shape[1]):
            emb, _, self.hidden = self.model(x[:, i], torch.LongTensor([3200]), init_state=self.hidden)
            embs.append(emb)
        return torch.vstack(embs)

    def reset(self):
        """
        Reset the hidden state. Call this before processing a new audio
        sample so that no state is carried over from the previous sample.
        """
        self.hidden = torch.zeros(2, 1, 256).to(self.hidden.device)
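# Usage sketch for AudioEncoder (illustrative only: the checkpoint name
# "audio_encoder.pt" and the (1, 1, T) input layout are assumptions drawn from
# how get_audio_emb below calls it, not guarantees of this module):
#
#   encoder = AudioEncoder("audio_encoder.pt")
#   waveform = torch.randn(1, 1, 16000)   # 1 s of 16 kHz audio
#   embeddings = encoder(waveform)        # one embedding per 640-sample hop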
def get_audio_emb(audio_path, checkpoint, device):
    """
    Load an audio file into a PyTorch tensor and return it together with its
    embedding.

    Args:
        audio_path (str): The file path of the audio to be loaded.
        checkpoint (str): The file path of the pre-trained model.
        device (str): The computing device ('cpu' or 'cuda').

    Returns:
        Tensor, Tensor: The original audio as a tensor and its corresponding embedding.
    """
    audio, audio_rate = torchaudio.load(audio_path, channels_first=False)
    assert audio_rate == 16000, 'Only 16 kHz audio is supported.'
    audio = audio[None, None, :, 0].to(device)

    audio_encoder = AudioEncoder(checkpoint).to(device)
    emb = audio_encoder(audio)
    return audio, emb
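# Example (sketch; the file names are placeholders):
#
#   audio, emb = get_audio_emb("speech_16khz.wav", "audio_encoder.pt", "cpu")
#   # audio has shape (1, 1, T); emb stacks one row per 640-sample window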
def get_id_frame(path, random=False, resize=128):
    """
    Retrieves a frame from either a video or image file. This frame can
    serve as an identifier or reference for the video or image.

    Args:
        path (str): File path to the video or image.
        random (bool): Whether to randomly select a frame from the video.
        resize (int): The dimensions to which the frame should be resized.

    Returns:
        Tensor: The image frame as a tensor.
    """
    if path.endswith('.mp4'):
        vr = decord.VideoReader(path)
        if random:
            idx = [np.random.randint(len(vr))]
        else:
            idx = [0]
        frame = vr.get_batch(idx).permute(0, 3, 1, 2)
    else:
        frame = load_image_to_torch(path).unsqueeze(0)

    frame = (frame / 255) * 2 - 1
    frame = Resize((resize, resize), antialias=True)(frame).float()
    return frame
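# Example (sketch; "identity.jpg" is a placeholder path):
#
#   id_frame = get_id_frame("identity.jpg", resize=128)   # (1, 3, 128, 128), values in [-1, 1]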
def get_motion_transforms(args):
    """
    Build a series of transformations (Gaussian blur and/or grayscale
    conversion) based on the provided arguments, typically used for data
    augmentation or preprocessing.

    Args:
        args (Namespace): Arguments containing options for motion transformations.

    Returns:
        Compose: A composed function of transforms.
    """
    motion_transforms = []
    if args.motion_blur:
        motion_transforms.append(GaussianBlur(5, sigma=2.0))
    if args.grayscale_motion:
        motion_transforms.append(Grayscale(1))
    return Compose(motion_transforms)
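# Example (sketch; argparse.Namespace stands in for the parsed CLI arguments):
#
#   from argparse import Namespace
#   motion_tf = get_motion_transforms(Namespace(motion_blur=True, grayscale_motion=False))
#   blurred = motion_tf(torch.rand(1, 3, 128, 128))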
def save_audio(path, audio, audio_rate=16000):
    """
    Save audio data as a 16-bit PCM WAV file.

    Args:
        path (str): The file path where the audio will be saved.
        audio (Tensor or np.array): The audio data, with values in [-1, 1].
        audio_rate (int): The sampling rate of the audio, defaults to 16000 Hz.
    """
    if torch.is_tensor(audio):
        aud = audio.squeeze().detach().cpu().numpy()
    else:
        aud = audio.copy()  # Make a copy so that we don't alter the object

    aud = ((2 ** 15) * aud).astype(np.int16)
    wav.write(path, audio_rate, aud)
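# Example (sketch): write a 1-second, 440 Hz test tone.
#
#   t = np.linspace(0, 1, 16000, endpoint=False)
#   save_audio("tone.wav", 0.5 * np.sin(2 * np.pi * 440 * t))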
def save_video(path, video, fps=25, scale=2, audio=None, audio_rate=16000, overlay_pts=None, ffmpeg_experimental=False):
    """
    Save video data as an MP4 file, optionally muxing in audio and drawing
    overlay points on the frames.

    Args:
        path (str): The file path where the video will be saved.
        video (Tensor or np.array): The video data, (frames, channels, height, width) for color
            or (frames, height, width) for grayscale.
        fps (int): Frames per second of the video.
        scale (int): Scaling factor for the video dimensions.
        audio (Tensor or np.array, optional): Audio data.
        audio_rate (int, optional): The sampling rate for the audio.
        overlay_pts (list of points, optional): Points to overlay on the video frames.
        ffmpeg_experimental (bool): Whether to use experimental ffmpeg options.

    Returns:
        bool: Success status.
    """
    out_dir = os.path.dirname(path)
    if out_dir and not os.path.exists(out_dir):
        os.makedirs(out_dir)
    success = True

    out_size = (scale * video.shape[-1], scale * video.shape[-2])
    video_path = get_temp_path(os.path.split(path)[0], ext=".mp4")
    if torch.is_tensor(video):
        vid = video.squeeze().detach().cpu().numpy()
    else:
        vid = video.copy()  # Make a copy so that we don't alter the object

    # Map [-1, 1] or [0, 1] inputs to the [0, 255] range expected by OpenCV
    if np.min(vid) < 0:
        vid = 127 * vid + 127
    elif np.max(vid) <= 1:
        vid = 255 * vid

    is_color = True
    if vid.ndim == 3:  # (frames, height, width) -> grayscale
        is_color = False

    writer = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*"mp4v"), float(fps), out_size, isColor=is_color)
    for i, frame in enumerate(vid):
        if is_color:
            frame = cv2.cvtColor(np.rollaxis(frame, 0, 3), cv2.COLOR_RGB2BGR)
        if scale != 1:
            frame = cv2.resize(frame, out_size)

        write_frame = frame.astype('uint8')
        if overlay_pts is not None:
            for pt in overlay_pts[i]:
                cv2.circle(write_frame, (int(scale * pt[0]), int(scale * pt[1])), 2, (0, 0, 0), -1)
        writer.write(write_frame)
    writer.release()

    inputs = [ffmpeg.input(video_path)['v']]
    if audio is not None:  # Save the audio file
        audio_path = swp_extension(video_path, ".wav")
        save_audio(audio_path, audio, audio_rate)
        inputs += [ffmpeg.input(audio_path)['a']]

    try:
        if ffmpeg_experimental:
            out = ffmpeg.output(*inputs, path, strict='-2', loglevel="panic", vcodec='h264').overwrite_output()
        else:
            out = ffmpeg.output(*inputs, path, loglevel="panic", vcodec='h264').overwrite_output()
        out.run(quiet=True)
    except Exception:
        success = False

    if audio is not None and os.path.isfile(audio_path):
        os.remove(audio_path)
    if os.path.isfile(video_path):
        os.remove(video_path)

    return success
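# Example (sketch; values in [-1, 1], requires the ffmpeg binary on PATH):
#
#   frames = torch.rand(25, 3, 128, 128) * 2 - 1   # one second at 25 fps
#   save_video("out/result.mp4", frames, fps=25, audio=torch.zeros(1, 16000), audio_rate=16000)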
def load_image_to_torch(dir):
    """
    Load an image from disk and convert it to a PyTorch tensor.

    Args:
        dir (str): The file path of the image.

    Returns:
        torch.Tensor: A (channels, height, width) uint8 tensor representation of the image.
    """
    img = Image.open(dir).convert('RGB')
    img = np.array(img)
    return torch.from_numpy(img).permute(2, 0, 1)
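# Example (sketch; "identity.jpg" is a placeholder):
#
#   img = load_image_to_torch("identity.jpg")   # uint8 tensor, shape (3, H, W)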
def get_temp_path(tmp_dir, mode="", ext=""):
    """
    Generate a temporary file path for storing data.

    Args:
        tmp_dir (str): The directory where the temporary file will be created.
        mode (str, optional): A string to append to the file name.
        ext (str, optional): The file extension.

    Returns:
        str: The full path to the temporary file.
    """
    # Note: _get_candidate_names() is a private tempfile helper; it only
    # supplies a random basename, the file itself is not created here.
    file_path = next(tempfile._get_candidate_names()) + mode + ext
    if not os.path.exists(tmp_dir):
        os.makedirs(tmp_dir)
    file_path = os.path.join(tmp_dir, file_path)
    return file_path
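# Example (sketch): get_temp_path("tmp", ext=".mp4") -> e.g. "tmp/k2b7cdq8.mp4"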
def swp_extension(file, ext):
    """
    Swap the extension of a given file name.

    Args:
        file (str): The original file name.
        ext (str): The new extension.

    Returns:
        str: The file name with the new extension.
    """
    return os.path.splitext(file)[0] + ext
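# Example: swp_extension("out/clip.mp4", ".wav") -> "out/clip.wav"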
def pad_both_ends(tensor, left, right, dim=0):
    """
    Zero-pad a tensor on both ends along a specific dimension.

    Args:
        tensor (torch.Tensor): The tensor to be padded.
        left (int): The padding size for the left side.
        right (int): The padding size for the right side.
        dim (int, optional): The dimension along which to pad.

    Returns:
        torch.Tensor: The padded tensor.
    """
    no_dims = len(tensor.size())
    if dim == -1:
        dim = no_dims - 1

    # F.pad expects (left, right) pairs starting from the last dimension,
    # so the pair for `dim` sits at offset 2 * (no_dims - dim - 1).
    padding = [0] * 2 * no_dims
    padding[2 * (no_dims - dim - 1)] = left
    padding[2 * (no_dims - dim - 1) + 1] = right
    return F.pad(tensor, padding, "constant", 0)
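# Example (sketch): pad_both_ends(torch.ones(1, 10), 2, 3, dim=1).shape -> (1, 15)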
def cut_n_stack(seq, snip_length, cut_dim=0, cutting_stride=None, pad_samples=0):
    """
    Divide a sequence tensor into overlapping snips and stack them along a new
    leading dimension.

    Args:
        seq (torch.Tensor): The original sequence tensor.
        snip_length (int): The length of each snip.
        cut_dim (int, optional): The dimension along which to cut.
        cutting_stride (int, optional): The stride length for cutting. Defaults to snip_length.
        pad_samples (int, optional): Total number of samples to pad, split between both ends.

    Returns:
        torch.Tensor: A tensor containing the stacked snips.
    """
    if cutting_stride is None:
        cutting_stride = snip_length

    pad_left = pad_samples // 2
    pad_right = pad_samples - pad_samples // 2
    seq = pad_both_ends(seq, pad_left, pad_right, dim=cut_dim)

    stacked = seq.narrow(cut_dim, 0, snip_length).unsqueeze(0)
    # Number of snips that fit in the padded sequence given the stride
    iterations = (seq.size()[cut_dim] - snip_length) // cutting_stride + 1
    for i in range(1, iterations):
        stacked = torch.cat((stacked, seq.narrow(cut_dim, i * cutting_stride, snip_length).unsqueeze(0)))
    return stacked
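# Example (sketch): 5760 samples cut into 3200-sample snips with a 640-sample
# stride gives (5760 - 3200) // 640 + 1 = 5 snips:
#
#   cut_n_stack(torch.zeros(1, 5760), 3200, cut_dim=1, cutting_stride=640).shape  # (5, 1, 3200)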
def create_windowed_sequence(seqs, snip_length, cut_dim=0, cutting_stride=None, pad_samples=0):
    """
    Create windowed sequences from a batch of sequences by cutting and
    stacking each one.

    Args:
        seqs (iterable of torch.Tensor): The sequence tensors (e.g. the batch dimension of a tensor).
        snip_length (int): The length of each snip.
        cut_dim (int, optional): The dimension along which to cut.
        cutting_stride (int, optional): The stride length for cutting. Defaults to snip_length.
        pad_samples (int, optional): Total number of samples to pad, split between both ends.

    Returns:
        torch.Tensor: A tensor containing the windowed sequences, one entry per input sequence.
    """
    windowed_seqs = []
    for seq in seqs:
        windowed_seqs.append(cut_n_stack(seq, snip_length, cut_dim, cutting_stride, pad_samples).unsqueeze(0))

    return torch.cat(windowed_seqs)
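# Minimal end-to-end sketch (the checkpoint "audio_encoder.pt" and the input
# files are placeholder names, not part of this module). With a 3200-sample
# window and a 640-sample hop at 16 kHz there are 16000 / 640 = 25 windows per
# second, i.e. one audio embedding per frame of 25 fps video.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    audio, emb = get_audio_emb("speech_16khz.wav", "audio_encoder.pt", device)
    id_frame = get_id_frame("identity.jpg", resize=128).to(device)
    print(f"audio {tuple(audio.shape)}, embeddings {tuple(emb.shape)}, id frame {tuple(id_frame.shape)}")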