---
license: mit
tags:
- Denoise
---

# Intro

These are my efforts to train a real-world usable [Cascaded Gaze](https://github.com/Ascend-Research/CascadedGaze) image denoising network.

`denoise_util.py` includes all the definitions required to use Cascaded Gaze networks with PyTorch.

# Models

**v1**

- ~132M params, trained on 256×256 RGB patches to remove intermediate-level JPEG and WebP compression artefacts. It was trained on about 700k samples (photographs only) at bf16 precision. It can also remove ISO-like noise and Gaussian noise.
- I recommend inputting float tensors of shape [B, 3, 256, 256] with values scaled to the range 0 to 1.

**Loading v1**

```python
from denoise_util import CascadedGaze
from safetensors.torch import load_file

device = "cuda"

# v1 architecture hyperparameters; these must match the checkpoint.
img_channel = 3
width = 60
enc_blks = [2, 2, 4, 6]
middle_blk_num = 12
dec_blks = [2, 2, 2, 2]
GCE_CONVS_nums = [3, 3, 2, 2]

model = CascadedGaze(
    img_channel=img_channel,
    width=width,
    middle_blk_num=middle_blk_num,
    enc_blk_nums=enc_blks,
    dec_blk_nums=dec_blks,
    GCE_CONVS_nums=GCE_CONVS_nums,
)

state_dict = load_file("models/v1.safetensors")
model.load_state_dict(state_dict)
model = model.to(device)
model.requires_grad_(False)  # inference only: no gradients needed
model.eval()
```

**Usage**

- Uses [blended-tiling](https://github.com/ProGamerGov/blended-tiling) to split images of arbitrary size into 256×256 tiles and then reassemble them.
- You'll need to make amendments to prevent the batches from being too large for your device.
- Presumes the model has already been loaded with the code above.

```python
import torch
from PIL import Image
import torchvision
from blended_tiling import TilingModule


def toimg(tensor):
    # Convert a [3, H, W] float tensor in [0, 1] to a PIL image.
    tensor = torch.clamp(tensor, 0.0, 1.0)
    tensor = (tensor * 255).round()  # round rather than truncate
    tensor = tensor.byte()
    return torchvision.transforms.functional.to_pil_image(tensor)


# nb: if RGBA inputs are anticipated, this won't be sufficient.
pil_image = Image.open("input.jpg").convert("RGB")

tensor = torchvision.transforms.functional.to_tensor(pil_image)  # [3, H, W], floats in [0, 1]
tensor = torch.unsqueeze(tensor, 0)  # add batch dimension: [1, 3, H, W]

tiling_module = TilingModule(
    tile_size=[256, 256],
    tile_overlap=[0.1, 0.1],  # you can configure this to taste
    base_size=list(tensor.shape[-2:]),  # [height, width]; PIL's .size is (width, height)
)

tiles = tiling_module.split_into_tiles(tensor)  # [num_tiles, 3, 256, 256]
tiles = tiles.to(device)

# All tiles go through the model as a single batch (see the batch-size note above).
result = model(tiles).cpu()
result = tiling_module.rebuild_with_masks(result).squeeze(0)

pil_result = toimg(result)
```
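The usage notes point out that sending every tile through the model at once can be too large for your device. One way to handle that is sketched below, under the assumption that the model accepts any batch size: it runs the model over the tiles in fixed-size chunks and concatenates the outputs on the CPU. `run_in_chunks` and `chunk_size` are hypothetical names, not part of `denoise_util.py` or blended-tiling, and the default of 8 tiles per chunk is an arbitrary starting point to tune for your GPU.

```python
import torch


def run_in_chunks(model, tiles, chunk_size=8, device="cuda"):
    """Hypothetical helper: run `model` over `tiles` ([N, C, H, W]) in chunks
    of at most `chunk_size` tiles and return all outputs concatenated on the CPU.
    Only one chunk is moved to `device` and processed at a time."""
    outputs = []
    with torch.no_grad():  # inference only; no autograd graph is built
        for start in range(0, tiles.shape[0], chunk_size):
            chunk = tiles[start:start + chunk_size].to(device)
            outputs.append(model(chunk).cpu())
    return torch.cat(outputs, dim=0)
```

With the model and `tiling_module` set up as above, `result = run_in_chunks(model, tiles, chunk_size=4, device=device)` could stand in for `model(tiles).cpu()`. The separate `tiles.to(device)` line then becomes unnecessary, because each chunk is moved to the device individually.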