Intro

These are my efforts to train a Cascaded Gaze image denoising network that is usable on real-world images.

denoise_util.py includes all definitions required to use Cascaded Gaze networks with PyTorch.

Models

v1

  • ~132M params, trained on 256 * 256 RGB patches to remove intermediate JPEG and WebP compression artefacts. It was trained on about 700k samples (photographs only) in bf16 precision. It can also remove ISO-like noise and Gaussian noise.
  • I recommend inputting tensors of shape [B,3,256,256], with float values scaled to 0 - 1.
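As a minimal sketch of the recommended input layout, the snippet below converts a batch of 8-bit RGB images into a float tensor of shape [B,3,256,256] scaled to 0 - 1. The helper name `to_model_input` is hypothetical and is not part of denoise_util.py; the random data only stands in for real patches.

```python
import torch

def to_model_input(images_uint8: torch.Tensor) -> torch.Tensor:
    """Convert uint8 images of shape [B, H, W, 3] to float32 [B, 3, H, W] in 0 - 1.

    Hypothetical helper for illustration; not part of denoise_util.py.
    """
    if images_uint8.dtype != torch.uint8:
        raise TypeError("expected a uint8 tensor")
    if images_uint8.ndim != 4 or images_uint8.shape[-1] != 3:
        raise ValueError("expected shape [B, H, W, 3]")
    # Move channels first, then scale 0..255 down to 0..1.
    return images_uint8.permute(0, 3, 1, 2).contiguous().float() / 255.0

# Stand-in data: two random 256 * 256 RGB patches.
dummy = torch.randint(0, 256, (2, 256, 256, 3), dtype=torch.uint8)
batch = to_model_input(dummy)
print(batch.shape, batch.dtype)
```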

Loading v1

from denoise_util import CascadedGaze
from safetensors.torch import load_file

device = "cuda"

img_channel = 3
width = 60
enc_blks = [2, 2, 4, 6]
middle_blk_num = 12
dec_blks = [2, 2, 2, 2]
GCE_CONVS_nums = [3,3,2,2]

model = CascadedGaze(img_channel=img_channel,width=width, middle_blk_num=middle_blk_num,
        enc_blk_nums=enc_blks, dec_blk_nums=dec_blks,GCE_CONVS_nums=GCE_CONVS_nums)

state_dict = load_file("models/v1.safetensors")
model.load_state_dict(state_dict)
model = model.to(device)
model.requires_grad_(False)
model.eval()
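One way to sanity-check that the configuration matches the checkpoint is to count the parameters and compare the total with the ~132M quoted for v1. This is only a sketch: `count_params` is a hypothetical helper, and a small `torch.nn.Linear` stands in for the real model so the snippet runs on its own.

```python
import torch

def count_params(module: torch.nn.Module) -> int:
    """Return the total number of parameters in a module (hypothetical helper)."""
    return sum(p.numel() for p in module.parameters())

# Stand-in module: 10 * 5 weights + 5 biases.
stand_in = torch.nn.Linear(10, 5)
print(count_params(stand_in))  # 55

# With the loaded model, something like:
# print(f"{count_params(model) / 1e6:.1f}M params")
```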

Usage

  • Uses https://github.com/ProGamerGov/blended-tiling to split images of arbitrary size into 256*256 tiles and reassemble them afterwards.
  • You'll need to make amendments to prevent the batches from being too large for your device.
  • Presumes the model was already loaded with the code above.
import torch
from PIL import Image
import torchvision
from blended_tiling import TilingModule

def toimg(tensor):
    tensor = torch.clamp(tensor, 0.0, 1.0)
    tensor = tensor * 255
    tensor = tensor.byte()
    return torchvision.transforms.functional.to_pil_image(tensor)

# nb: if rgba inputs are anticipated, this won't be sufficient.
pil_image = Image.open("input.jpg").convert("RGB")

tiling_module = TilingModule(
    tile_size=[256, 256],
    tile_overlap=[0.1, 0.1], # you can configure this to taste
    base_size=(pil_image.height, pil_image.width), # PIL's .size is (width, height); base_size is assumed to be [height, width]
)

tensor = torchvision.transforms.functional.to_tensor(pil_image)
tensor = torch.unsqueeze(tensor,0)
tiles = tiling_module.split_into_tiles(tensor)
tiles = tiles.to(device)
result = model(tiles).cpu()
result = tiling_module.rebuild_with_masks(result).squeeze()

pil_result = toimg(result)
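Large images produce many tiles, and passing them all through the model in one batch may exhaust device memory. One possible amendment, sketched here under assumptions, is to run the tiles through the model a few at a time and concatenate the results on the CPU. `run_in_chunks` and its `chunk_size` default are hypothetical, and `torch.nn.Identity` stands in for the denoiser so the snippet runs without the checkpoint.

```python
import torch

def run_in_chunks(model, tiles: torch.Tensor, chunk_size: int = 8,
                  device: str = "cpu") -> torch.Tensor:
    """Run `model` over `tiles` ([N, 3, H, W]) chunk_size tiles at a time.

    Hypothetical helper: each chunk is moved to `device`, processed, and
    moved back to the CPU before the next chunk is loaded.
    """
    outputs = []
    with torch.no_grad():
        for chunk in torch.split(tiles, chunk_size, dim=0):
            outputs.append(model(chunk.to(device)).cpu())
    return torch.cat(outputs, dim=0)

# Stand-in check: an identity module in place of the denoiser.
tiles = torch.rand(5, 3, 256, 256)
result = run_in_chunks(torch.nn.Identity(), tiles, chunk_size=2)
print(result.shape)
```

In the usage code above, this would take the place of `result = model(tiles).cpu()`, for example `result = run_in_chunks(model, tiles, chunk_size=8, device=device)`. The value 8 is an arbitrary starting point; lower it if you run out of memory.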