File size: 1,416 Bytes
759dfe0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
from pathlib import Path
import click
import torch
from torch.utils.data import DataLoader, Dataset
def load_to_cpu(x):
    """Deserialize the torch-saved file ``x`` with all tensors mapped onto the CPU.

    ``weights_only=True`` restricts the unpickler to tensors, primitive types and
    plain containers, so arbitrary pickled objects in the file are not executed.
    """
    cpu = torch.device("cpu")
    return torch.load(x, weights_only=True, map_location=cpu)
class LatentEmbedDataset(Dataset):
    """Dataset of cached (latent, embedding) file pairs stored next to source files.

    For every input path ``p`` the pair is ``p.with_suffix(".latent.pt")`` and
    ``p.with_suffix(".embed.pt")``; inputs for which either file is missing are
    skipped.

    Args:
        file_paths: Iterable of source-file paths (``str`` or ``Path``); any
            iterable is accepted, including generators.
        repeat: How many times the list of valid pairs is repeated, so one pass
            over the dataset visits each pair ``repeat`` times.
    """

    def __init__(self, file_paths, repeat=1):
        # Materialise once: the paths are both iterated and counted below, which
        # would fail (or see an exhausted iterator) for a generator argument.
        file_paths = list(file_paths)
        candidates = (
            (Path(p).with_suffix(".latent.pt"), Path(p).with_suffix(".embed.pt"))
            for p in file_paths
        )
        pairs = [
            (latent, embed)
            for latent, embed in candidates
            if latent.is_file() and embed.is_file()
        ]
        # Report unique valid pairs against the input count; counting after
        # `repeat` could report more valid pairs than there were inputs.
        print(f"Loaded {len(pairs)}/{len(file_paths)} valid file pairs.")
        self.items = pairs * repeat

    def __len__(self):
        """Return the number of (possibly repeated) pairs."""
        return len(self.items)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return ``(latent, embed)`` as loaded by ``load_to_cpu``."""
        latent_path, embed_path = self.items[idx]
        return load_to_cpu(latent_path), load_to_cpu(embed_path)
@click.command()
@click.argument("directory", type=click.Path(exists=True, file_okay=False))
def process_videos(directory):
    """Print the latent shapes of every cached video under DIRECTORY.

    Recursively collects ``*.mp4`` files (skipping ``*.recon.mp4``), keeps those
    that have both ``.latent.pt`` and ``.embed.pt`` siblings, batches them four
    at a time in random order, and prints ``(key, shape)`` for each latent batch.
    """
    dir_path = Path(directory)
    mp4_files = [
        str(f)
        for f in dir_path.glob("**/*.mp4")
        if not f.name.endswith(".recon.mp4")
    ]
    # ClickException instead of `assert`: asserts are stripped under `python -O`
    # and click reports this as a clean "Error: ..." with a non-zero exit code.
    if not mp4_files:
        raise click.ClickException(f"No mp4 files found in {dir_path}")
    dataset = LatentEmbedDataset(mp4_files)
    # DataLoader(shuffle=True) builds a RandomSampler, which raises an opaque
    # ValueError (num_samples=0) on an empty dataset; fail with a clear message.
    if len(dataset) == 0:
        raise click.ClickException(
            f"No mp4 file in {dir_path} has both .latent.pt and .embed.pt files"
        )
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
    for latents, _embeds in dataloader:
        # NOTE(review): assumes each .latent.pt holds a dict of tensors (the
        # default collate keeps dict structure) — TODO confirm.
        print([(k, v.shape) for k, v in latents.items()])
if __name__ == "__main__":
    # Run the click command explicitly through Command.main(), which parses
    # sys.argv and exits with the command's status code.
    process_videos.main()
|