import json
import math
import os
import random
from pathlib import Path

import numpy as np
import torch
from einops import rearrange
from torch.utils.data import Dataset as TorchDataset

from datasets.encode_openx_dataset import DATA_FREQ_TABLE
from genie.config import GenieConfig
from genie.st_mask_git import cosine_schedule

SVD_SCALE = 0.18215


def normalize_actions(actions):
    """
    Compute the per-dimension mean and std of the actions.
    The actions themselves are returned unchanged; normalization is applied inside the network.
    """
    mean = np.mean(actions, axis=0).tolist()
    std = np.std(actions, axis=0).tolist()
    return actions, [mean, std]
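
# Illustrative sketch (not part of the training pipeline): for a hypothetical
# (N, action_dim) array, `normalize_actions` returns the actions untouched plus
# per-dimension statistics that the network can later use for whitening, e.g.:
#
#   dummy_actions = np.zeros((100, 7), dtype=np.float32)   # hypothetical trajectory
#   _, (act_mean, act_std) = normalize_actions(dummy_actions)
#   # len(act_mean) == len(act_std) == 7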


class RawFeatureDataset(TorchDataset):
    """Loads raw float32 latent tokens as a memmap-backed array."""

    def __init__(
        self,
        data_dir,
        window_size,
        stride=1,
        filter_interrupts=True,
        filter_overlaps=False,
        use_actions=False,
        max_traj_num=1000000,
        compute_stride_from_freq_table=True,
        natural_hz=2,
        datio_noise_ratio=0.0,
        use_raw_image_as_latent=False,
        domain=None,
    ):
        """
        Args:
            data_dir: directory with the same format as `data/train_v0` and `data/val_v0`.
                Notably, contains `video.bin` and `metadata.json`.
            window_size: number of frames per "video" sequence.
            stride: frame skip.
            filter_interrupts: Under 3% of training frame sequences are the concatenation of two
                different clips. If True, these sequences are filtered out using the segment ids.
            filter_overlaps: If False (default), one frame may appear in multiple examples;
                e.g. frame 0 might appear as the first frame in example 0 and also as the second
                frame in example 15. If True, examples are filtered so that each frame appears
                at most once in the dataset.
            use_actions: If True, loads the actions from the `actions` folder for the models.
        """
        data_dir = Path(data_dir)
        with open(data_dir / "metadata.json") as f:
            self.metadata = json.load(f)

        # TODO: assert not quantized in metadata
        shape = (
            self.metadata["num_images"],
            self.metadata.get("latent_channels", 4),
            self.metadata["h"],
            self.metadata["w"],
        )
        print("token shape:", shape)
        self.use_raw_image_as_latent = use_raw_image_as_latent
        if use_raw_image_as_latent:
            # raw RGB frames instead of VAE latents; resized to 32x32 in __getitem__
            shape = (shape[0], 3, shape[2], shape[3])

        video_tokens_path, segment_ids_path, action_tokens_path = [
            data_dir / f"{name}.bin" for name in ["video", "segment_ids", "actions"]
        ]
        token_dtype = np.dtype(self.metadata.get("token_dtype", "float16"))
        self.data = np.memmap(video_tokens_path, mode="r", shape=shape, dtype=token_dtype)
        print("data nan:", torch.isnan(torch.from_numpy(self.data[:100].copy())).sum())
        # import IPython; IPython.embed()

        if use_raw_image_as_latent:
            # debug path for the robomimic dataset: treat 32x32 raw images as latents
            self.metadata["h"] = 32
            self.metadata["w"] = 32
            self.metadata["latent_channels"] = 3

        self.window_size, self.stride = window_size, stride
        self.datio_noise_ratio = datio_noise_ratio

        if domain is not None:  # TODO: remove
            self.name = domain
        else:
            self.name = self.metadata["name"]
        self.name = self.name.replace("_noquant", "")

        self.stride = stride
        if compute_stride_from_freq_table:
            self.stride = max(DATA_FREQ_TABLE.get(self.name, 1) // natural_hz, 1)
        self.n_action = self.metadata.get("action_dim", 1) * self.stride
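        # Worked example (hypothetical numbers): if DATA_FREQ_TABLE lists this domain at 30 Hz
        # and natural_hz=2, then stride = max(30 // 2, 1) = 15; with action_dim=7 the flattened
        # per-step action vector has n_action = 7 * 15 = 105 entries.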

        if use_actions:
            actions = []
            # hack here for the separations in the 1x datasets
            for action_file in sorted((data_dir / "actions").iterdir()):
                actions.append(np.memmap(action_file, dtype=np.float32, mode="r").reshape(len(self.data), -1))

            self.actions = np.concatenate(actions, axis=-1)
            self.actions, self.action_stat = normalize_actions(self.actions)

        if os.path.isfile(segment_ids_path):
            self.segment_ids = np.memmap(
                segment_ids_path,
                dtype=np.int32,
                mode="r",
                shape=(self.metadata["num_images"],),
            )
        else:
            self.segment_ids = None
            if filter_interrupts:
                raise NotImplementedError("Cannot filter interrupted sequences without segment ids.")

        # Number of frames between the first and last frames of a video sequence (excluding one endpoint frame)
        self.video_len = (self.window_size - 1) * self.stride
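        # e.g. window_size=8, stride=15 -> video_len = 7 * 15 = 105, so one example covers
        # frames [start_ind, start_ind + 105], sampled every 15 frames.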

        self.valid_start_inds = []
        for start_ind in range(len(self.data) - self.video_len - self.stride):
            # Assuming `segment_ids` is monotonically increasing, a sequence is interrupted (or too short)
            # if the first and last frames have different segment ids.
            if not (filter_interrupts and self.segment_ids[start_ind] != self.segment_ids[start_ind + self.video_len]):
                self.valid_start_inds.append(start_ind)

            if len(self.valid_start_inds) >= max_traj_num:
                break

        if filter_overlaps:
            # Instead of using a sliding window, use each frame at most once
            filtered_start_inds = []
            for start_ind in self.valid_start_inds:
                overlapping_start_inds = {start_ind - i * self.stride for i in range(1, self.window_size)}
                # all sequences from `overlapping_start_inds` will also contain `start_ind`,
                # so exclude sequence starting from `start_ind` if any of `overlapping_start_inds` is already being used
                for existing_start_ind in filtered_start_inds[-self.window_size * self.stride:]:
                    # Bound could be improved
                    if existing_start_ind in overlapping_start_inds:
                        break
                else:
                    filtered_start_inds.append(start_ind)

            self.valid_start_inds = filtered_start_inds
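            # e.g. with stride=1 and window_size=8, examples starting at frames 0 and 3 share
            # frames 3-7, so only the earlier one is kept; starts 0 and 8 do not overlap.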

        num_videos = len(np.unique(self.segment_ids)) if self.segment_ids is not None else 1
        print(f"Loaded {len(self)} sequences from {data_dir} {self.stride=} {self.window_size=} {self.n_action=} {num_videos=}")

    def __len__(self):
        return len(self.valid_start_inds)

    def __getitem__(self, idx):
        """
        Returns a flattened sequence of tokens representing `self.window_size` frames,
        spaced `self.stride` frames apart.
        """
        start_ind = self.valid_start_inds[idx]
        x = self.data[start_ind : start_ind + self.video_len + 1 : self.stride].copy()
        x = torch.from_numpy(x).float()

        if self.use_raw_image_as_latent:
            x = torch.nn.functional.interpolate(x, size=(self.metadata["h"], self.metadata["w"]))
            # normalize raw pixels to roughly [-0.5, 0.5]
            x = x / 255 - 0.5
        else:
            # scale latents; divide by SVD_SCALE again when decoding
            x = x * SVD_SCALE

        x = rearrange(x, "t c h w -> (t h w) c")
        # labels equal input_ids since this is a reconstruction objective
        attention_mask = torch.ones_like(x)
        data_dict = {
            "input_ids": x,
            "labels": x,
            "attention_mask": attention_mask,
            "h": self.metadata["h"],
            "w": self.metadata["w"],
            "c": self.metadata["latent_channels"],
        }

        if hasattr(self, "actions"):
            # we want all actions within the stride to predict the next frame at the end of the stride,
            # so the actions are concatenated from [window_size, d_action] to [window_size, d_action * stride]
            data_dict["action_ids"] = self.actions[start_ind:start_ind + self.video_len + self.stride].reshape(self.window_size, -1)
            data_dict["action_ids"] = torch.from_numpy(data_dict["action_ids"].astype(np.float32))

        data_dict["domain"] = self.name.replace("_noquant", "")
        return data_dict


def get_maskgit_collator_feature(config: GenieConfig):
    # mask_token_id = config.image_vocab_size

    def collate_fn(features) -> dict[str, torch.Tensor]:
        # during training, map (z_0, z_1', z_2') -> (null, z_1, z_2)
        # (z_0, z_1') -> (null, z_1) is the diffusion operator on z_1' -> z_1
        h = features[0]["h"]
        w = features[0]["w"]
        input_ids = torch.stack([ex["input_ids"] for ex in features])
        device = input_ids.device
        x_THWC = rearrange(input_ids, "b (t h w) c -> b t h w c", b=len(features), t=config.T, h=h, w=w)
        labels = x_THWC.clone()
        first_masked_frame = config.T

        mask = torch.zeros(1).long()
        mask_token_indicator = torch.zeros((len(features), config.T, h, w)).long()

        if config.dataloader_apply_mask:
            if random.random() < config.non_mlm_ratio:  # Closer to autoregressive inference
                # Leave frames [0, first_masked_frame) unmasked.
                first_masked_frame = random.randint(config.num_prompt_frames, config.T - 1)
            else:  # Typical MLM masking
                first_masked_frame = 1

            c = 0
            while mask.max() == 0:  # We could get unlucky and mask no tokens?
                # per-minibatch, per-frame masking probability (could try variable masking rate from MUSE)
                rand = torch.rand(len(features), config.T - first_masked_frame, 1, 1)
                # add a minimum mask ratio
                rand_mask = rand * (1 - config.dataloader_mask_ratio_min) + config.dataloader_mask_ratio_min
                mask_prob_T = cosine_schedule(rand_mask)

                r = torch.rand_like(x_THWC[:, first_masked_frame:, ..., 0], dtype=torch.float)
                mask = r < mask_prob_T
                c += 1

            if c > 1:
                print(f"Generated mask {c} > 1 times.")

            mask_token_indicator = torch.cat([
                torch.zeros((len(features), first_masked_frame, h, w), dtype=mask.dtype),
                mask,
            ], dim=1)
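            # mask_token_indicator has shape (batch, T, h, w): zeros over the unmasked prefix
            # frames [0, first_masked_frame) and the sampled Bernoulli mask for the rest.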

        data_dict = {
            "input_ids": rearrange(x_THWC, "b t h w c -> b (t h w) c"),
            "labels": rearrange(labels, "b t h w c -> b (t h w) c"),
            "masked_tokens_indicator": mask_token_indicator,
        }

        if "action_ids" in features[0]:
            data_dict["action_ids"] = torch.stack([ex["action_ids"] for ex in features])
        data_dict["domain"] = [ex["domain"] for ex in features]
        data_dict["h"] = [ex["h"] for ex in features]
        data_dict["w"] = [ex["w"] for ex in features]
        return data_dict

    return collate_fn
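

# Minimal usage sketch, not part of the original pipeline. Assumptions: a local
# `data/val_v0` directory in the layout described in `RawFeatureDataset.__init__`, and a
# `GenieConfig` whose defaults provide T, dataloader_apply_mask, non_mlm_ratio,
# num_prompt_frames and dataloader_mask_ratio_min; adapt both to your setup.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    config = GenieConfig()  # hypothetical: construct or load your actual model config here
    dataset = RawFeatureDataset(
        "data/val_v0",                         # hypothetical path to an encoded dataset
        window_size=config.T,                  # must match the collator's config.T
        filter_interrupts=False,               # set True only if segment_ids.bin exists
        compute_stride_from_freq_table=False,  # keep the explicit stride for this sketch
    )
    loader = DataLoader(
        dataset,
        batch_size=2,
        shuffle=True,
        collate_fn=get_maskgit_collator_feature(config),
    )
    batch = next(iter(loader))
    print({k: tuple(v.shape) if torch.is_tensor(v) else v for k, v in batch.items()})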