import random
import sys
from multiprocessing import Value

import braceexpand
import webdataset as wds
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, IterableDataset, get_worker_info
def pytorch_worker_seed():
    """Get the dataloader worker seed from PyTorch."""
    worker_info = get_worker_info()
    if worker_info is not None:
        # favour the seed already created for pytorch dataloader workers if it exists
        return worker_info.seed
    # fallback to wds rank-based seed
    return wds.utils.pytorch_worker_seed()
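
# Note (not in the original source): inside a DataLoader with num_workers > 0,
# PyTorch seeds each worker with base_seed + worker_id, so worker_info.seed
# differs per worker but is reproducible when the DataLoader (or torch) is
# given a fixed seed; in the main process the wds fallback above is used instead.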
class SharedEpoch:
    """Epoch counter shared across dataloader worker processes via a multiprocessing Value."""

    def __init__(self, epoch: int = 0):
        self.shared_epoch = Value('i', epoch)

    def set_value(self, epoch):
        self.shared_epoch.value = epoch

    def get_value(self):
        return self.shared_epoch.value
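
# Usage sketch (illustrative; the names below are assumptions, not from this
# file): the training loop holds the SharedEpoch and bumps it once per epoch.
# Because the multiprocessing Value is created before the DataLoader forks its
# workers, each worker's copy of the dataset sees the updated value in __iter__.
#
#     shared_epoch = SharedEpoch(epoch=0)
#     dataset = ResampledShards2(urls, deterministic=True, epoch=shared_epoch)
#     for epoch in range(num_epochs):
#         shared_epoch.set_value(epoch)
#         for shard in dataset:
#             ...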
class ResampledShards2(IterableDataset):
    """An iterable dataset yielding shard URLs sampled with replacement."""

    def __init__(
        self,
        urls,
        nshards=sys.maxsize,
        worker_seed=None,
        deterministic=False,
        epoch=-1,
    ):
        """Sample shards from the shard list with replacement.

        :param urls: a list of URLs as a Python list or brace notation string
        """
        super().__init__()
        if not isinstance(urls, list):
            urls = list(braceexpand.braceexpand(urls))
        self.urls = urls
        assert isinstance(self.urls[0], str)
        self.nshards = nshards
        self.rng = random.Random()
        self.worker_seed = pytorch_worker_seed if worker_seed is None else worker_seed
        self.deterministic = deterministic
        self.epoch = epoch
    def __iter__(self):
        """Return an iterator over the shards."""
        if isinstance(self.epoch, SharedEpoch):
            epoch = self.epoch.get_value()
        else:
            # NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
            # situation as different workers may wrap at different times (or not at all).
            self.epoch += 1
            epoch = self.epoch
        if self.deterministic:
            # reset seed w/ epoch if deterministic; the worker seed should be deterministic due to args.seed
            self.rng.seed(self.worker_seed() + epoch)
        for _ in range(self.nshards):
            yield dict(url=self.rng.choice(self.urls))
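
# Minimal end-to-end sketch, assuming hypothetical local shards named
# "shards/data-{000000..000009}.tar" (any brace-notation string or URL list
# works). Each yielded dict holds one resampled shard URL; a webdataset
# pipeline would normally consume these downstream.
if __name__ == "__main__":
    epoch = SharedEpoch(0)
    shards = ResampledShards2(
        "shards/data-{000000..000009}.tar",  # hypothetical shard paths
        nshards=4,
        deterministic=True,
        epoch=epoch,
    )
    epoch.set_value(0)
    for sample in shards:
        print(sample["url"])  # prints 4 shard URLs sampled with replacement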