import json
import os

import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
from PIL import Image
from torch.utils.data.dataset import Dataset

# from torchvision.io import read_image


def tokenize_captions(caption, tokenizer):
    """Tokenize a single caption as a batch of one.

    The caption is padded to, and truncated at, ``tokenizer.model_max_length``.

    Args:
        caption: The caption string to tokenize.
        tokenizer: Callable tokenizer exposing ``model_max_length`` whose
            result has an ``input_ids`` attribute (presumably a Hugging Face
            tokenizer -- confirm against the caller).

    Returns:
        ``input_ids`` of the tokenizer output, requested with
        ``return_tensors="pt"`` (presumably shape ``(1, model_max_length)``).
    """
    inputs = tokenizer(
        [caption],
        max_length=tokenizer.model_max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    return inputs.input_ids


def _open_image(path, mode):
    """Load the image at ``path`` converted to PIL ``mode`` and close the file."""
    with Image.open(path) as img:
        # convert() returns a new, fully loaded image, so closing the
        # source file afterwards is safe.
        return img.convert(mode)


class SquarePad:
    """Pad a PIL image with a constant fill so that it becomes square.

    The shorter side is padded on both ends. When the size difference is odd,
    the extra pixel goes to the right/bottom, so the output is exactly
    ``max(w, h)`` pixels on each side.
    """

    def __init__(self, fill=(255, 255, 255)):
        # Default keeps the original hard-coded white fill (3-channel).
        self.fill = fill

    def __call__(self, image):
        w, h = image.size
        max_wh = max(w, h)
        left = (max_wh - w) // 2
        top = (max_wh - h) // 2
        # (left, top, right, bottom). The right/bottom remainder absorbs odd
        # differences that a symmetric int(diff / 2) pad would drop.
        padding = (left, top, max_wh - w - left, max_wh - h - top)
        return F.pad(image, padding, self.fill, "constant")


class NormalSegDataset(Dataset):
    """Dataset pairing an RGB image with a normal-map + segmentation condition.

    Reads ``meta_train_seg.json`` from ``path``. The file holds a ``keys``
    list and a ``data`` mapping from each key to a record with ``rgb``,
    ``normal`` and ``seg`` file paths and a ``caption`` list, of which only
    the first entry is used.

    Each item is ``(image, conditioning_image, prompt_ids, text_prompt)``.
    Two independent dropouts apply, each with probability ``cfg_prob``:

    * the caption is replaced by ``""``;
    * the conditioning inputs are replaced by a white normal map and a
      black (0) segmentation map.

    Presumably this is for classifier-free-guidance training -- confirm.
    """

    def __init__(self, args, path, tokenizer, cfg_prob):
        """
        Args:
            args: Config object; only ``args.resolution`` (crop output size)
                is read.
            path: Directory containing ``meta_train_seg.json``.
            tokenizer: Tokenizer handed to ``tokenize_captions``.
            cfg_prob: Probability of dropping the caption and, independently,
                of blanking the conditioning images.
        """
        # One pipeline is shared by the RGB, normal and seg images.
        # __getitem__ replays the torch RNG state so all three get the same
        # random crop.
        self.image_transforms = transforms.Compose(
            [
                # transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
                # SquarePad(),
                # transforms.Pad((200, 100, 200, 300), fill=(255, 255, 255), padding_mode='constant'),
                # transforms.RandomRotation(degrees=30, fill=(255, 255, 255)),
                transforms.RandomResizedCrop(
                    args.resolution,
                    scale=(0.9, 1.0),
                    # Same mode as the former bare int 2 (PIL BILINEAR), but
                    # spelled with the non-deprecated enum.
                    # NOTE(review): this also interpolates the seg map; if it
                    # stores discrete labels, NEAREST may be intended.
                    interpolation=transforms.InterpolationMode.BILINEAR,
                ),
                transforms.ToTensor(),
            ]
        )
        # Applied to the RGB target only: (x - 0.5) / 0.5 per channel.
        self.additional_image_transforms = transforms.Compose(
            [transforms.Normalize([0.5], [0.5])]
        )

        meta_path = os.path.join(path, "meta_train_seg.json")
        # JSON is UTF-8 by spec. Don't depend on the platform locale encoding.
        with open(meta_path, "r", encoding="utf-8") as f:
            meta = json.load(f)
        self.keys = meta["keys"]
        self.meta = meta["data"]
        self.tokenizer = tokenizer
        self.cfg_prob = cfg_prob

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, index):
        meta_data = self.meta[self.keys[index]]
        text_prompt = meta_data["caption"][0]

        # Draw both dropout decisions BEFORE snapshotting the RNG state.
        # If a draw happens after the snapshot, the later set_rng_state()
        # rewinds past it. The next item's caption draw would then reuse the
        # same random value, coupling consecutive items' dropout decisions.
        drop_text = torch.rand(1).item() < self.cfg_prob
        drop_cond = torch.rand(1).item() < self.cfg_prob
        if drop_text:
            text_prompt = ""

        image = _open_image(meta_data["rgb"], "RGB")
        state = torch.get_rng_state()
        image = self.image_transforms(image)

        if drop_cond:
            # PIL sizes are (width, height); the transformed tensor is
            # (C, H, W), hence shape[2] before shape[1].
            size = (image.shape[2], image.shape[1])
            normal_image = Image.new("RGB", size, (255, 255, 255))
            seg_image = Image.new("L", size, 0)
        else:
            normal_image = _open_image(meta_data["normal"], "RGB")
            seg_image = _open_image(meta_data["seg"], "L")

        # Replay the RNG state so the condition images receive the same
        # random crop parameters as the RGB image.
        torch.set_rng_state(state)
        normal_image = self.image_transforms(normal_image)
        torch.set_rng_state(state)
        seg_image = self.image_transforms(seg_image)

        # Normal and seg channels stacked along dim 0 (presumably 3 + 1).
        conditioning_image = torch.cat([normal_image, seg_image], dim=0)
        image = self.additional_image_transforms(image)
        prompt_ids = tokenize_captions(text_prompt, self.tokenizer)
        return image, conditioning_image, prompt_ids, text_prompt