MNIST Pix2Pix Model
This is a Pix2Pix conditional GAN trained on MNIST for image-to-image translation: it recovers clean digit images from distorted inputs.
Model Description
- Model type: Conditional GAN (Pix2Pix)
- Dataset: MNIST
- Distorted version: Gist
- Training framework: PyTorch
- License: Apache 2.0
- Files included (see the download sketch after this list):
  - latest_net_G.pth: Generator model weights
  - latest_net_D.pth: Discriminator model weights
- Original Repository: junyanz/pytorch-CycleGAN-and-pix2pix
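Both checkpoints can be pulled directly from the Hub. The following is a minimal sketch, reusing only the egpivo/mnist-pix2pix repo id and the filenames listed above; only the generator is required for inference.

from huggingface_hub import hf_hub_download

# Fetch the generator and (optionally) the discriminator weights from the Hub
generator_path = hf_hub_download(repo_id="egpivo/mnist-pix2pix", filename="latest_net_G.pth")
discriminator_path = hf_hub_download(repo_id="egpivo/mnist-pix2pix", filename="latest_net_D.pth")
print(generator_path)
print(discriminator_path)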
How to Use
Load the Model in PyTorch
import argparse
import os
import shutil

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download
from torchvision.datasets import MNIST

# This import comes from the junyanz/pytorch-CycleGAN-and-pix2pix repository;
# run this script from the root of a clone of that repo (or add it to PYTHONPATH).
from models.pix2pix_model import Pix2PixModel
# Download the generator weights from the Hub
hf_model_path = hf_hub_download(repo_id="egpivo/mnist-pix2pix", filename="latest_net_G.pth")

# Pix2PixModel looks for weights under <checkpoints_dir>/<name>/latest_net_G.pth,
# so copy the downloaded file into that layout.
expected_checkpoints_dir = os.path.join(os.path.dirname(hf_model_path), "mnist_pix2pix")
expected_model_path = os.path.join(expected_checkpoints_dir, "latest_net_G.pth")
os.makedirs(expected_checkpoints_dir, exist_ok=True)
shutil.copy(hf_model_path, expected_model_path)
print(f"Model copied to: {expected_model_path}")
# Test-time options expected by Pix2PixModel (no TestOptions parsing needed)
opt = argparse.Namespace(
    dataroot="./dummy_data",
    isTrain=False,
    name="mnist_pix2pix",
    gpu_ids=[],
    checkpoints_dir=os.path.dirname(hf_model_path),
    model="pix2pix",
    input_nc=3,
    output_nc=3,
    ngf=64,
    ndf=64,
    netD="basic",
    netG="unet_256",
    n_layers_D=3,
    norm="batch",
    init_type="normal",
    init_gain=0.02,
    no_dropout=True,
    dataset_mode="aligned",
    direction="AtoB",
    serial_batches=True,
    num_threads=0,
    batch_size=1,
    load_size=256,
    crop_size=256,
    max_dataset_size=float("inf"),
    preprocess="resize_and_crop",
    no_flip=True,
    display_winsize=256,
    epoch="latest",
    load_iter=0,
    verbose=False,
    suffix="",
    use_wandb=False,
    wandb_project_name="",
    results_dir="./results",
    aspect_ratio=1.0,
    phase="test",
    eval=True,
    num_test=50,
)
# Build the pix2pix wrapper; setup() restores <checkpoints_dir>/<name>/latest_net_G.pth
model = Pix2PixModel(opt)
model.setup(opt)

# Load the generator weights explicitly and switch to inference mode
model.netG.load_state_dict(torch.load(expected_model_path, map_location="cpu"))
model.netG.eval()
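Before running on real data, a quick shape check confirms the generator restored correctly. This is a small sketch assuming the unet_256 / 3-channel configuration defined in opt above:

# Sanity check: a random 3-channel 256x256 batch should map to the same shape
with torch.no_grad():
    dummy = torch.randn(1, 3, 256, 256)
    dummy_out = model.netG(dummy)
print(dummy_out.shape)  # expected: torch.Size([1, 3, 256, 256])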
Example: Testing on a Distorted MNIST Image
# Take one MNIST test digit and simulate a distorted input
mnist_dataset = MNIST(root="./data", train=False, download=True)
mnist_image, _ = mnist_dataset[0]  # First test digit (a PIL image)

# Distortion: downscale, randomly rotate, and blur the digit
distorted_transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.RandomRotation(30),
    transforms.GaussianBlur(3),
    transforms.ToTensor(),
])
distorted_image = distorted_transform(mnist_image)
# Convert grayscale to RGB (the generator was configured with 3 input channels)
input_tensor = distorted_image.unsqueeze(0).repeat(1, 3, 1, 1)  # Shape: [1, 3, 64, 64]
# Resize to 256x256, the input size the unet_256 generator expects
input_tensor = F.interpolate(input_tensor, size=(256, 256), mode="bilinear", align_corners=False)
# Normalize to [-1, 1] as required by Pix2Pix
input_tensor = (input_tensor - 0.5) * 2
# Run the generator without tracking gradients
with torch.no_grad():
    output = model.netG(input_tensor)
output_image = output.squeeze(0).permute(1, 2, 0).cpu().numpy()
output_image = (output_image + 1) / 2  # Rescale to [0, 1] for display
plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.imshow(distorted_image.squeeze(), cmap="gray")
plt.axis("off")
plt.title("Distorted Input")
plt.subplot(1, 2, 2)
plt.imshow(output_image)
plt.axis("off")
plt.title("Recovered Output")
plt.show()
print("Testing Completed")