MNIST Pix2Pix Model

This is a Pix2Pix model trained on MNIST for image-to-image translation: it maps distorted MNIST digits back to clean ones.

📌 Model Description

  • Model type: Conditional GAN (Pix2Pix)
  • Dataset: MNIST
    • Distorted version: Gist
  • Training framework: PyTorch
  • License: Apache 2.0
  • Files included:
    • latest_net_G.pth: Generator model weights
    • latest_net_D.pth: Discriminator model weights
  • Original Repository: junyanz/pytorch-CycleGAN-and-pix2pix
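For context, the original Pix2Pix formulation trains the generator G against a conditional discriminator D using an adversarial loss plus an L1 reconstruction term. Here x is the input image, y the target image, and z the noise (supplied through dropout in Pix2Pix). For this model, x is presumably the distorted digit and y its clean counterpart. The weight λ used for this checkpoint is not stated; the upstream repository's default `--lambda_L1` is 100.

```latex
\mathcal{L}_{cGAN}(G, D) = \mathbb{E}_{x,y}\big[\log D(x, y)\big] + \mathbb{E}_{x,z}\big[\log\big(1 - D(x, G(x, z))\big)\big]

\mathcal{L}_{L1}(G) = \mathbb{E}_{x,y,z}\big[\lVert y - G(x, z) \rVert_1\big]

G^{*} = \arg\min_{G}\max_{D}\; \mathcal{L}_{cGAN}(G, D) + \lambda\,\mathcal{L}_{L1}(G)
```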

🚀 How to Use

Load the Model in PyTorch

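```python
# Run from the root of a clone of junyanz/pytorch-CycleGAN-and-pix2pix
# so that the `models` package is importable.
import argparse
import os
import shutil

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download
from torchvision.datasets import MNIST

from models.pix2pix_model import Pix2PixModel

# Download the generator weights from the Hugging Face Hub (cached locally)
hf_model_path = hf_hub_download(repo_id="egpivo/mnist-pix2pix", filename="latest_net_G.pth")

# The repo loads weights from <checkpoints_dir>/<name>/<epoch>_net_G.pth,
# so copy the file into that layout
expected_checkpoints_dir = os.path.join(os.path.dirname(hf_model_path), "mnist_pix2pix")
expected_model_path = os.path.join(expected_checkpoints_dir, "latest_net_G.pth")

os.makedirs(expected_checkpoints_dir, exist_ok=True)
shutil.copy(hf_model_path, expected_model_path)
print(f"Model copied to: {expected_model_path}")

# Test-time options (the repo's test.py builds these with TestOptions().parse())
opt = argparse.Namespace(
    dataroot="./dummy_data",  # placeholder; no dataset is loaded through the repo here
    isTrain=False,  # test mode: only the generator G is built
    name="mnist_pix2pix",  # subfolder of checkpoints_dir that holds the weights
    gpu_ids=[],  # empty list = run on CPU
    checkpoints_dir=os.path.dirname(hf_model_path),
    model="pix2pix",
    input_nc=3,  # grayscale digits are repeated to 3 channels before inference
    output_nc=3,
    ngf=64,
    ndf=64,
    netD="basic",
    netG="unet_256",  # U-Net with 8 downsampling stages; expects 256x256 input
    n_layers_D=3,
    norm="batch",
    init_type="normal",
    init_gain=0.02,
    no_dropout=True,
    dataset_mode="aligned",
    direction="AtoB",
    serial_batches=True,
    num_threads=0,
    batch_size=1,
    load_size=256,
    crop_size=256,
    max_dataset_size=float("inf"),
    preprocess="resize_and_crop",
    no_flip=True,
    display_winsize=256,
    epoch="latest",  # selects latest_net_G.pth
    load_iter=0,
    verbose=False,
    suffix="",
    use_wandb=False,
    wandb_project_name="",
    results_dir="./results",
    aspect_ratio=1.0,
    phase="test",
    eval=True,
    num_test=50,
)

model = Pix2PixModel(opt)
model.setup(opt)  # loads <checkpoints_dir>/mnist_pix2pix/latest_net_G.pth

# Explicitly reload the same weights (redundant after setup(), but harmless)
model.netG.load_state_dict(torch.load(expected_model_path, map_location="cpu"))
model.netG.eval()  # BatchNorm uses its running statistics at inference
```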
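The copy step is needed because, as far as we can tell from the upstream code, `BaseModel.load_networks` builds its path from the options as `<checkpoints_dir>/<name>/<epoch>_net_<net>.pth`. Below is a minimal standard-library sketch of that convention. `checkpoint_path` and `stage_checkpoint` are hypothetical helpers for illustration and are not functions from the repository.

```python
import os
import shutil


def checkpoint_path(checkpoints_dir: str, name: str, epoch: str = "latest", net: str = "G") -> str:
    """Hypothetical helper: the <checkpoints_dir>/<name>/<epoch>_net_<net>.pth
    path that the upstream loader is assumed to read from."""
    return os.path.join(checkpoints_dir, name, f"{epoch}_net_{net}.pth")


def stage_checkpoint(src_file: str, checkpoints_dir: str, name: str,
                     epoch: str = "latest", net: str = "G") -> str:
    """Hypothetical helper: copy a downloaded weight file into that layout
    (creating the folder if needed) and return the destination path."""
    dst = checkpoint_path(checkpoints_dir, name, epoch, net)
    os.makedirs(os.path.dirname(dst), exist_ok=True)
    shutil.copy(src_file, dst)
    return dst
```

For example, `stage_checkpoint(hf_model_path, os.path.dirname(hf_model_path), "mnist_pix2pix")` would reproduce the copy performed in the loading code. The same convention explains the discriminator file name, `latest_net_D.pth`.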

Example: Testing on a Distorted MNIST Image

```python
# Take the first digit from the MNIST test split (a PIL grayscale image)
mnist_dataset = MNIST(root="./data", train=False, download=True)
mnist_image, _ = mnist_dataset[0]

# Example distortion: upscale, random rotation of up to ±30 degrees, 3x3 Gaussian blur
distorted_transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.RandomRotation(30),
    transforms.GaussianBlur(3),
    transforms.ToTensor(),  # Shape: [1, 64, 64], values in [0, 1]
])

distorted_image = distorted_transform(mnist_image)

# Convert grayscale to RGB (the model was configured with input_nc=3)
input_tensor = distorted_image.unsqueeze(0).repeat(1, 3, 1, 1)  # Shape: [1, 3, 64, 64]

# Resize to 256x256 to match netG="unet_256" (and crop_size=256)
input_tensor = F.interpolate(input_tensor, size=(256, 256), mode="bilinear", align_corners=False)

# Normalize to [-1, 1] as required by Pix2Pix
input_tensor = (input_tensor - 0.5) * 2

# Inference only: no gradients needed
with torch.no_grad():
    output = model.netG(input_tensor)  # Shape: [1, 3, 256, 256], values in [-1, 1] (tanh)

output_image = output.squeeze(0).permute(1, 2, 0).cpu().numpy()
output_image = ((output_image + 1) / 2).clip(0, 1)  # Rescale to [0, 1] for display

plt.figure(figsize=(8, 4))

plt.subplot(1, 2, 1)
plt.imshow(distorted_image.squeeze().numpy(), cmap="gray")
plt.axis("off")
plt.title("Distorted Input")

plt.subplot(1, 2, 2)
plt.imshow(output_image)
plt.axis("off")
plt.title("Recovered Output")

plt.show()

print("Testing Completed")
```
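The generator is configured with `netG="unet_256"`. In the upstream repository, that is a U-Net with eight stride-2 downsampling stages, whereas `unet_128` uses seven. Each `Conv2d(kernel_size=4, stride=2, padding=1)` halves the spatial size, and each matching transposed convolution doubles it. As a result, the skip connections only line up when the input side is a multiple of 2^8 = 256, which is consistent with `crop_size=256`.

The pure-Python sketch below traces those sizes. `conv_out` and `unet_size_ok` are hypothetical helpers, and the kernel, stride, and padding values are assumed from the upstream `UnetSkipConnectionBlock`.

```python
def conv_out(n: int, k: int = 4, s: int = 2, p: int = 1) -> int:
    """Spatial size after a Conv2d with kernel k, stride s, padding p."""
    return (n + 2 * p - k) // s + 1


def unet_size_ok(size: int, num_downs: int, k: int = 4, p: int = 1) -> bool:
    """Return True if a square input of side `size` survives `num_downs`
    stride-2 downsamplings and every decoder stage (which doubles the size)
    matches the encoder size it is concatenated with."""
    sizes = [size]
    for _ in range(num_downs):
        if sizes[-1] + 2 * p < k:  # padded input smaller than the kernel: PyTorch errors
            return False
        sizes.append(conv_out(sizes[-1], k=k, p=p))
    # Each skip connection needs the encoder size to be exactly twice the next one
    return all(sizes[i] == 2 * sizes[i + 1] for i in range(num_downs))
```

For example, `unet_size_ok(256, 8)` is true, while `unet_size_ok(128, 8)` is false: a 128x128 input is already 1x1 after seven stages, which is too small for the eighth 4x4 convolution. Running this check before calling `model.netG` gives a clearer failure than the shape error PyTorch raises deep inside the network.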
