Intro

These are my efforts to train a real-world usable Cascaded Gaze image denoising network.

denoise_util.py includes all definitions required to use Cascaded Gaze networks with PyTorch.

Models

v1

  • ~132M params, trained on 256 × 256 RGB patches for intermediate JPEG & WebP compression artefact removal. It was trained on roughly 700k samples (photographs only) at bf16 precision, and is also capable of removing ISO-like noise and Gaussian noise.
  • I recommend inputting tensors of shape [B, 3, 256, 256], with float values scaled to 0–1 (see the sketch below).
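
As an example, a single RGB crop can be brought into that shape and range like this (a minimal sketch; the filename is a placeholder, and torchvision's to_tensor does the 0–1 scaling):

from PIL import Image
import torchvision.transforms.functional as TF

# "patch.png" is a placeholder; any 256 x 256 RGB crop will do
patch = Image.open("patch.png").convert("RGB")

# to_tensor returns a float tensor scaled to [0, 1] with shape [3, 256, 256]
x = TF.to_tensor(patch).unsqueeze(0)  # -> [1, 3, 256, 256]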

Loading v1

from denoise_util import CascadedGaze
from safetensors.torch import load_file

device = "cuda"

img_channel = 3
width = 60
enc_blks = [2, 2, 4, 6]
middle_blk_num = 12
dec_blks = [2, 2, 2, 2]
GCE_CONVS_nums = [3, 3, 2, 2]

model = CascadedGaze(img_channel=img_channel, width=width, middle_blk_num=middle_blk_num,
                     enc_blk_nums=enc_blks, dec_blk_nums=dec_blks, GCE_CONVS_nums=GCE_CONVS_nums)

state_dict = load_file("models/v1.safetensors")
model.load_state_dict(state_dict)
model = model.to(device)
model.requires_grad_(False)
model.eval()
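
A quick sanity check on random input confirms the weights load and the output shape matches the input (a sketch; since v1 was trained in bf16, the autocast context is shown as an optional assumption, not a requirement):

import torch

# dummy batch of four 256 x 256 RGB patches in [0, 1]
dummy = torch.rand(4, 3, 256, 256, device=device)

with torch.no_grad():
    # bf16 autocast is optional; the model also runs in fp32
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        out = model(dummy)

print(out.shape)  # expected: torch.Size([4, 3, 256, 256])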

jpg+webp denoising mini

Sample: 4× zoom demonstration of this model applied to an image of a toy car found online.

  • Only ~18M parameters, trained on 256 × 256 BGR patches for JPEG & WebP compression artefact removal only, using a PSNR loss (see the channel-order sketch below the loading settings).
  • Can handle artefacts that have been upscaled or downscaled.
# Loading as above, but with these settings changed:
enc_blks = [2, 2, 3]
middle_blk_num = 6
dec_blks = [2, 2, 2]
GCE_CONVS_nums = [3, 3, 2]
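
Since this variant was trained on BGR patches, an RGB-loaded tensor needs its channel order reversed before inference. A minimal sketch (the helper name rgb_to_bgr is just illustrative):

import torch

# the mini model expects BGR input, so reverse an RGB tensor's channel
# dimension before inference
def rgb_to_bgr(x: torch.Tensor) -> torch.Tensor:
    return x.flip(dims=[-3])  # works for both [3, H, W] and [B, 3, H, W]

The same flip applied to the model's BGR output converts it back to RGB for saving with PIL.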

Usage

  • This example uses https://github.com/ProGamerGov/blended-tiling to split images of arbitrary size into 256 × 256 tiles and blend them back together afterwards.
  • You'll need to adapt it if the full batch of tiles is too large for your device's memory (see the chunked sketch after the code).
  • It presumes the model has already been loaded with the code above.
  • Loading and saving of images is omitted; you could use PIL or cv2, etc. Note the RGB vs BGR channel order expected by each model.
import torch
from PIL import Image
import torchvision
from blended_tiling import TilingModule

# load an image however you want; pil_image below is assumed to be a PIL.Image

tiling_module = TilingModule(
    tile_size=[256, 256],
    tile_overlap=[0.1, 0.1], # you can configure this to taste
    base_size=pil_image.size, #nb: see .shape if you load with cv2
)

tensor = torchvision.transforms.functional.to_tensor(pil_image)  # also accepts numpy arrays from cv2
tensor = torch.unsqueeze(tensor, 0)  # add a batch dimension -> [1, 3, H, W]
tiles = tiling_module.split_into_tiles(tensor)
tiles = tiles.to(device)
result = model(tiles).cpu()  # process tiles in smaller batches if they don't fit in VRAM (see the sketch below)
result = tiling_module.rebuild_with_masks(result).squeeze().clamp(0, 1)

#save an image however you want
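
If the full set of tiles doesn't fit in VRAM at once, one option is to keep the tiles on the CPU and run them through the model in fixed-size chunks, replacing the `tiles = tiles.to(device)` and `model(tiles)` lines above. A minimal sketch (the chunk size of 8 is an arbitrary starting point):

chunk_size = 8  # tune to your device's VRAM
outputs = []
with torch.no_grad():
    for chunk in tiles.split(chunk_size, dim=0):
        # move one chunk to the GPU at a time, denoise it, and bring it back to the CPU
        outputs.append(model(chunk.to(device)).cpu())
result = torch.cat(outputs, dim=0)
result = tiling_module.rebuild_with_masks(result).squeeze().clamp(0, 1)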