Image colorization

A UNet architecture for image colorization that applies transfer learning by using a pretrained ResNet-34 as its encoder.

Try the model on Google Colab or in the Hugging Face Space.

The model takes a 1x224x224 L-channel tensor as input and outputs 2x224x224 ab channels. The decoder was trained from scratch. The encoder (ResNet-34) was initially kept frozen so the decoder could adapt to the task, then it was progressively unfrozen layer by layer, following common fine-tuning practice from the research literature: the earliest layers stayed frozen and only the deeper layers were fine-tuned. Training took 20+ hours on Google Colab and Kaggle T4 GPUs.
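The freeze-then-progressively-unfreeze schedule can be sketched with a stand-in encoder (the stages and channel sizes below are illustrative placeholders, not the actual ResNet-34 layers):

```python
import torch.nn as nn

# Hypothetical stand-in encoder: a sequence of stages, analogous to
# ResNet-34's layer1..layer4 (stage 0 is the earliest, stage 3 the deepest)
encoder = nn.Sequential(
    nn.Sequential(nn.Conv2d(1, 16, 3, padding=1)),
    nn.Sequential(nn.Conv2d(16, 32, 3, padding=1)),
    nn.Sequential(nn.Conv2d(32, 64, 3, padding=1)),
    nn.Sequential(nn.Conv2d(64, 128, 3, padding=1)),
)

# Phase 1: freeze the whole encoder so the decoder can adapt first
for p in encoder.parameters():
    p.requires_grad = False

def unfreeze_stage(encoder: nn.Sequential, idx: int) -> None:
    """Make one encoder stage trainable again (deepest stages first)."""
    for p in encoder[idx].parameters():
        p.requires_grad = True

# Phase 2: progressively unfreeze deeper stages; the earliest stages stay frozen
unfreeze_stage(encoder, 3)
unfreeze_stage(encoder, 2)
```

Early layers learn generic features (edges, textures) that transfer well, so leaving them frozen preserves them while the deeper, more task-specific layers adapt.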

There are no dedicated datasets for image colorization, so I curated my own and used it to train the model. The COCO 2017 dataset was filtered to remove grayscale images, heavily filtered images, and other artifacts unsuitable for training a natural colorization model. The images were then center-cropped and resized to 224x224. The dataset can be found here. This repository contains the model weights and the UNet architecture needed to load them.
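The center-crop and resize step can be sketched with plain PIL (the function name and interpolation choice here are my own, not from the training code):

```python
from PIL import Image

def center_crop_resize(img: Image.Image, size: int = 224) -> Image.Image:
    """Center-crop an image to a square, then resize it to size x size."""
    w, h = img.size
    s = min(w, h)                       # side of the largest centered square
    left, top = (w - s) // 2, (h - s) // 2
    img = img.crop((left, top, left + s, top + s))
    return img.resize((size, size), Image.BICUBIC)

# Example on a synthetic image (stand-in for a COCO photo)
out = center_crop_resize(Image.new("RGB", (640, 480)))
```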

Usage

Download the architecture file and the model weights:

```python
from huggingface_hub import hf_hub_download

REPO_ID = "ayushshah/imagecolorization"

# Download the architecture definition (model.py) into the working directory
hf_hub_download(
    repo_id=REPO_ID,
    filename="model.py",
    local_dir=".",
    local_dir_use_symlinks=False
)

# Download the model weights
weights_path = hf_hub_download(
    repo_id=REPO_ID,
    filename="model.safetensors"
)
```

Make sure the input image(s) are of size 224x224. Convert them to the LAB color space (you can use kornia). Isolate the L channel and scale it to the range [0, 1]; the L channel is originally in the range [0, 100].
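Assuming you already have a LAB tensor (e.g. from `kornia.color.rgb_to_lab`), isolating and scaling the L channel looks like this (the tensor below is a synthetic stand-in for a real image):

```python
import torch

# Synthetic stand-in for a LAB tensor of shape (1, 3, 224, 224),
# e.g. from kornia.color.rgb_to_lab: L in [0, 100], ab roughly in [-128, 127]
lab = torch.zeros(1, 3, 224, 224)
lab[:, 0] = torch.rand(1, 224, 224) * 100.0            # L channel
lab[:, 1:] = torch.rand(1, 2, 224, 224) * 255.0 - 128  # ab channels

L = lab[:, :1, :, :]       # isolate L, keeping the channel dim: (1, 1, 224, 224)
L_normalized = L / 100.0   # scale [0, 100] -> [0, 1] as the model expects
```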

```python
import torch

from model import UNet
from safetensors.torch import load_file

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

model = UNet().to(DEVICE)
state_dict = load_file(weights_path)
model.load_state_dict(state_dict)
model.eval()

# L_normalized: (N, 1, 224, 224) L channel scaled to [0, 1]
with torch.no_grad():
    ab_pred = model(L_normalized.to(DEVICE))
```

The outputs are in the range [-1, 1]. Scale the ab channels back to their original range with a linear mapping, then concatenate the original (unnormalized) L channel with the ab channels to get the LAB image.

```python
# Map the model output from [-1, 1] to the original ab range [-128, 127]
ab = (ab_pred + 1) * 255.0 / 2 - 128.0
ab = torch.clamp(ab, -128, 127)

# L is the original (unnormalized) L channel in [0, 100]
lab = torch.cat((L, ab), dim=1)
```


Model size: 24.5M parameters (F32, safetensors format)
