Spaces:
Runtime error
Runtime error
| import os | |
| import pandas as pd | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| from torch.utils.data import Dataset, DataLoader | |
| import json | |
| import random | |
| import cv2 | |
def canny_processor(image, low_threshold=100, high_threshold=200):
    """Return a 3-channel Canny edge map of *image* as a PIL image.

    Args:
        image: Anything ``np.array`` can convert into an array that
            ``cv2.Canny`` accepts (e.g. a PIL image).
        low_threshold: Lower hysteresis threshold handed to ``cv2.Canny``.
        high_threshold: Upper hysteresis threshold handed to ``cv2.Canny``.

    Returns:
        A ``PIL.Image.Image`` whose three channels are identical copies of the
        single-channel edge map produced by ``cv2.Canny``.
    """
    edges = cv2.Canny(np.array(image), low_threshold, high_threshold)
    # Replicate the single-channel edge map along a new last axis, giving an
    # (H, W, 3) array that Image.fromarray treats as a 3-channel image.
    rgb_edges = np.stack((edges, edges, edges), axis=-1)
    return Image.fromarray(rgb_edges)
def c_crop(image):
    """Center-crop *image* to the largest square that fits inside it.

    Args:
        image: A PIL-style image exposing ``size`` as ``(width, height)`` and a
            ``crop(box)`` method taking ``(left, upper, right, lower)``.

    Returns:
        The result of ``image.crop`` for a ``new_size x new_size`` box centred
        on the image, where ``new_size = min(width, height)``. When the size
        difference is odd, the extra pixel is left on the right/bottom.
    """
    width, height = image.size
    new_size = min(width, height)
    # Integer offsets, with right/bottom derived from left/top. Pillow rounds
    # float box coordinates half-to-even one by one, so offsets like
    # (1.5, 4.5) used to produce non-square crops (e.g. 6x3 -> 2x3).
    left = (width - new_size) // 2
    top = (height - new_size) // 2
    return image.crop((left, top, left + new_size, top + new_size))
class CustomImageDataset(Dataset):
    """Map-style dataset yielding ``(image, canny_hint, caption)`` triples.

    Each image file in ``img_dir`` must have a caption file beside it with the
    same stem and a ``.json`` extension. That file must hold a JSON object with
    a ``"caption"`` key (e.g. ``cat.jpg`` -> ``cat.json``). A sample that fails
    to load is reported and replaced by a randomly chosen one.
    """

    def __init__(self, img_dir, img_size=512, extensions=(".jpg", ".png"),
                 max_attempts=100):
        """Index the images in *img_dir*.

        Args:
            img_dir: Directory scanned (non-recursively) for image files.
            img_size: Side length, in pixels, of the square output images.
            extensions: File suffixes treated as images. They are matched
                case-insensitively against the end of each file name.
            max_attempts: Maximum number of samples ``__getitem__`` tries
                (the requested one plus random replacements) before raising.
        """
        suffixes = tuple(ext.lower() for ext in extensions)
        # endswith, not substring: sidecars such as "x.jpg.json" must not be
        # mistaken for images.
        self.images = sorted(
            os.path.join(img_dir, name)
            for name in os.listdir(img_dir)
            if name.lower().endswith(suffixes)
        )
        self.img_size = img_size
        self.max_attempts = max_attempts

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _caption_path(image_path):
        """Return the ``.json`` caption path that shares *image_path*'s stem."""
        # splitext strips only the final extension. Splitting on the first '.'
        # broke any path with a dot in a directory ("./data/a.jpg" -> ".json").
        return os.path.splitext(image_path)[0] + ".json"

    def _load(self, path):
        """Load one sample. Raises on any I/O, decoding or caption-schema problem."""
        with Image.open(path) as raw:
            # Force 3 channels. Grayscale would break permute(2, 0, 1), and
            # RGBA would give 4-channel tensors that cannot be batched with RGB.
            img = c_crop(raw.convert("RGB"))
        img = img.resize((self.img_size, self.img_size))
        hint = canny_processor(img)
        # Map uint8 [0, 255] to [-1, 1] and move channels first: (3, H, W).
        img = torch.from_numpy(np.array(img) / 127.5 - 1).permute(2, 0, 1)
        hint = torch.from_numpy(np.array(hint) / 127.5 - 1).permute(2, 0, 1)
        with open(self._caption_path(path), encoding="utf-8") as f:
            prompt = json.load(f)["caption"]
        return img, hint, prompt

    def __getitem__(self, idx):
        """Return ``(img, hint, prompt)`` for *idx*, or for a random substitute.

        Raises:
            IndexError: if *idx* is out of range.
            RuntimeError: if ``max_attempts`` consecutive samples fail to load.
        """
        attempts = max(1, self.max_attempts)
        last_error = None
        for _ in range(attempts):
            path = self.images[idx]  # out-of-range idx raises IndexError
            try:
                return self._load(path)
            except Exception as e:  # deliberate best-effort: skip bad samples
                print(f"{path}: {e}")
                last_error = e
                idx = random.randint(0, len(self.images) - 1)
        raise RuntimeError(
            f"failed to load a sample after {attempts} attempts"
        ) from last_error
def loader(train_batch_size, num_workers, *, shuffle=True, **args):
    """Build a training ``DataLoader`` over a ``CustomImageDataset``.

    Args:
        train_batch_size: Number of samples per batch.
        num_workers: Number of DataLoader worker processes.
        shuffle: Reshuffle the sample order every epoch. Defaults to ``True``
            because this is a training loader. Without it, every epoch walks
            the files in the same sorted-filename order. Pass ``False`` for
            that deterministic order.
        **args: Forwarded verbatim to ``CustomImageDataset`` (e.g. ``img_dir``,
            ``img_size``).

    Returns:
        A ``torch.utils.data.DataLoader`` over the dataset.
    """
    dataset = CustomImageDataset(**args)
    return DataLoader(
        dataset,
        batch_size=train_batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
    )