---
license: mit
---

A custom, hand-made 3-scale VQVAE trained on a private dataset of about 4k pixel-art images. The source code for the model can be found [here](https://github.com/Kemsekov/kemsekov_torch/tree/main/vqvae).

It achieved a 0.987 R² metric on image reconstruction after 500 epochs of training on 256x256 image crops. Because it was trained on crops, the model works fine with larger and smaller images as well.

The model has three codebooks:
* bottom: 512 codes
* mid: 512 codes
* top: 256 codes

This provides enough capacity for the model to achieve good metrics.

Here is a code example showing how to use it.

```py
import PIL.Image
from matplotlib import pyplot as plt
import torch
import torchvision.transforms as T

sample = PIL.Image.open("sample_images/cat.png").convert("RGB") # your sample image, forced to 3 channels
sample = T.ToTensor()(sample)[None,:]  # add batch dimension
sample = T.Resize((512,512))(sample)   # optional; this VQVAE works fine with any input image size

vqvae = torch.jit.load("model.pt")

# rec is the reconstruction
# z is a list of latent-space tensors
# z_q is a list of quantized latent-space tensors
# ind is a list of codebook indices of the quantized elements in latent space
rec, z, z_q, ind = vqvae.eval().cpu()(sample)
rec_ind = vqvae.decode_from_ind(ind)

rec = rec.sigmoid()
rec_ind = rec_ind.sigmoid()

print("Original image shape", list(sample.shape[1:]))
print("ind shapes", [list(v.shape[1:]) for v in ind])

plt.figure(figsize=(18,6))
plt.subplot(1,3,1)
plt.imshow(T.ToPILImage()(sample[0]).resize((256,256)))
plt.title("original")
plt.axis('off')

plt.subplot(1,3,2)
plt.imshow(T.ToPILImage()(rec[0]).resize((256,256)))
plt.title("reconstruction")
plt.axis('off')

plt.subplot(1,3,3)
plt.imshow(T.ToPILImage()(rec_ind[0]).resize((256,256)))
plt.title("reconstruction from ind")
plt.axis('off')
plt.show()

# render the index maps, normalized by each codebook's size
plt.figure(figsize=(18,6))
plt.subplot(1,3,1)
plt.imshow(T.ToPILImage()(ind[0]/512).resize((256,256)))
plt.title("ind0")
plt.axis('off')

plt.subplot(1,3,2)
plt.imshow(T.ToPILImage()(ind[1]/512).resize((256,256)))
plt.title("ind1")
plt.axis('off')

plt.subplot(1,3,3)
plt.imshow(T.ToPILImage()(ind[2]/256).resize((256,256)))
plt.title("ind2")
plt.axis('off')
plt.show()

# render each latent-space tensor channel by channel in a square grid
print("latent space render")
for z_ in z:
    dims = len(z_[0])
    dims_sqrt = int(dims**0.5)
    plt.figure(figsize=(10,10))
    plt.axis('off')
    for i in range(dims_sqrt):
        for j in range(dims_sqrt):
            slice_ind = i*dims_sqrt+j
            plt.subplot(dims_sqrt, dims_sqrt, slice_ind+1)
            plt.imshow(T.ToPILImage()(z_[0][slice_ind:slice_ind+1]))
            plt.axis('off')
    plt.show()
```

```
Original image shape [3, 512, 512]
ind shapes [[128, 128], [64, 64], [32, 32]]
```

![image/png](https://cdn-uploads.huggingface.co/production/uploads/633b160acbdbadd99c094172/d3PSfPu9tkKZkdMv8UJSV.png)

![image/png](https://cdn-uploads.huggingface.co/production/uploads/633b160acbdbadd99c094172/pDOPnZtAh05UXfkFaklkq.png)

And it has the following latent space:

Bottom

![image/png](https://cdn-uploads.huggingface.co/production/uploads/633b160acbdbadd99c094172/RkRVxY6uly59c8yumMTpv.png)

Mid

![image/png](https://cdn-uploads.huggingface.co/production/uploads/633b160acbdbadd99c094172/CwR8o--prVLmR6TdL4Jt7.png)

Top

![image/png](https://cdn-uploads.huggingface.co/production/uploads/633b160acbdbadd99c094172/uF95lUigW-NOYIV2EhD8h.png)

As you can see, the model properly captures different aspects of the image at different scales.
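
For reference, here is a minimal sketch of how an R² reconstruction score like the 0.987 reported above can be computed between an image and its reconstruction. The `r2_score` helper is hypothetical and assumes the standard coefficient-of-determination definition over all pixels; the actual training code may compute the metric differently.

```py
import torch

def r2_score(original: torch.Tensor, reconstruction: torch.Tensor) -> float:
    # assumed definition: R^2 = 1 - SS_res / SS_tot, pooled over all pixels
    ss_res = ((original - reconstruction) ** 2).sum()
    ss_tot = ((original - original.mean()) ** 2).sum()
    return (1 - ss_res / ss_tot).item()

# usage with the tensors from the example above:
# print("R2:", r2_score(sample, rec))
```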
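
The printed output above shows that for a 512x512 input the three index grids are 128x128, 64x64, and 32x32, i.e. the scales downsample the input by factors of 4, 8, and 16. A quick sketch to check that other (even non-square) input sizes produce proportionally sized grids; the choice of dimensions divisible by 16 is an assumption based on those observed factors:

```py
import torch

vqvae = torch.jit.load("model.pt").eval().cpu()

# a random non-square input; 320 and 480 are both divisible by 16
x = torch.rand(1, 3, 320, 480)
rec, z, z_q, ind = vqvae(x)

print([list(v.shape[1:]) for v in ind])
# expected, if the 4x/8x/16x factors hold: [[80, 120], [40, 60], [20, 30]]
```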