import os
from glob import glob
import zipfile

import cv2
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
import gradio as gr
import spaces

from models.GCoNet import GCoNet

# Select the inference device: index 0 -> 'cpu', index 1 -> 'cuda'.
device = ['cpu', 'cuda'][0]


class ImagePreprocessor():
    def __init__(self) -> None:
        # Resize to the model's input size and normalize with ImageNet statistics.
        self.transform_image = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def proc(self, image):
        image = self.transform_image(image)
        return image


def save_tensor_img(path, tensor_im):
    # Save a single-channel prediction tensor as a PNG image.
    im = tensor_im.cpu().clone()
    im = im.squeeze(0)
    tensor2pil = transforms.ToPILImage()
    im = tensor2pil(im)
    im.save(path)


# Build the model and load the GCoNet+ checkpoint if it is present.
model = GCoNet(bb_pretrained=False).to(device)
state_dict = './ultimate_duts_cocoseg (The best one).pth'
if os.path.exists(state_dict):
    gconet_dict = torch.load(state_dict, map_location=device)
    model.load_state_dict(gconet_dict)
model.eval()


@spaces.GPU
def pred_maps(images):
    # Predict co-saliency maps for a group of images and zip the results.
    assert images is not None, 'Input images cannot be None.'
    # For tab_batch
    save_paths = []
    save_dir = 'preds-GCoNet_plus'
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    image_array_lst = []
    for idx_image, image_src in enumerate(images):
        save_paths.append(os.path.join(save_dir, "{}.png".format(os.path.splitext(os.path.basename(image_src))[0])))
        if isinstance(image_src, str):
            image = np.array(Image.open(image_src))
        else:
            image = image_src
        image_array_lst.append(image)
    images = image_array_lst
    image_shapes = [image.shape[:2] for image in images]
    images = [Image.fromarray(image) for image in images]

    # Preprocess all images and stack them into a single batch.
    images_proc = []
    image_preprocessor = ImagePreprocessor()
    for image in images:
        images_proc.append(image_preprocessor.proc(image))
    images_proc = torch.cat([image_proc.unsqueeze(0) for image_proc in images_proc])

    with torch.no_grad():
        scaled_preds_tensor = model(images_proc.to(device))[-1]

    # Resize each prediction back to its original image resolution and save it.
    preds = []
    for image_shape, pred_tensor, save_path in zip(image_shapes, scaled_preds_tensor, save_paths):
        if device == 'cuda':
            pred_tensor = pred_tensor.cpu()
        pred_tensor = torch.nn.functional.interpolate(
            pred_tensor.unsqueeze(0), size=image_shape, mode='bilinear', align_corners=True
        ).squeeze()
        save_tensor_img(save_path, pred_tensor)

    # Bundle all predicted maps into a single zip archive for download.
    zip_file_path = os.path.join(save_dir, "{}.zip".format(save_dir))
    with zipfile.ZipFile(zip_file_path, 'w') as zipf:
        for file in save_paths:
            zipf.write(file, os.path.basename(file))
    return save_paths, zip_file_path


tab_batch = gr.Interface(
    fn=pred_maps,
    inputs=gr.File(label="Upload multiple images in a group", type="filepath", file_count="multiple"),
    outputs=[gr.Gallery(label="GCoNet+'s predictions"), gr.File(label="Download predicted maps.")],
    api_name="batch",
    description='Upload pictures, most of which contain salient objects of the same class. Our demo will give you the binary maps of these co-salient objects :)',
)

demo = gr.TabbedInterface(
    [tab_batch],
    ['batch'],
    title="Online demo for `GCoNet+: A Stronger Group Collaborative Co-Salient Object Detector (T-PAMI 2023)`",
)

demo.launch(debug=True)