"""Gradio demo: generate handwritten MNIST digits with a trained Conditional VAE."""

import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from PIL import Image

# Run on GPU when available; model weights and all tensors are moved here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class ConditionalVAE(nn.Module):
    """Conditional Variational Autoencoder over flattened 28x28 MNIST images.

    The one-hot class label is concatenated to both the encoder input and the
    latent code, so decoding can be conditioned on a chosen digit.

    Args:
        input_dim: Flattened image size (28 * 28 = 784).
        hidden_dim: Width of the single hidden layer in encoder and decoder.
        latent_dim: Dimensionality of the latent code z.
        num_classes: Number of condition classes (digits 0-9).
    """

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20, num_classes=10):
        super(ConditionalVAE, self).__init__()
        # Encoder: (image ++ one-hot label) -> hidden -> (mu, logvar)
        self.fc1 = nn.Linear(input_dim + num_classes, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # mu head
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # logvar head
        # Decoder: (latent ++ one-hot label) -> hidden -> image
        self.fc3 = nn.Linear(latent_dim + num_classes, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)
        self.latent_dim = latent_dim
        self.num_classes = num_classes

    def encode(self, x, y):
        """Return (mu, logvar) of q(z|x, y).

        Args:
            x: Flattened images, shape (B, input_dim).
            y: One-hot labels, shape (B, num_classes).
        """
        inputs = torch.cat([x, y], 1)
        h1 = F.relu(self.fc1(inputs))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + sigma * eps, eps ~ N(0, I) (reparameterization trick)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, y):
        """Decode latent z conditioned on one-hot label y; pixels in [0, 1]."""
        inputs = torch.cat([z, y], 1)
        h3 = F.relu(self.fc3(inputs))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x, y):
        """Full pass: encode, sample, decode.

        Returns:
            (reconstruction, mu, logvar) — reconstruction has shape (B, input_dim).
        """
        mu, logvar = self.encode(x.view(-1, 784), y)
        z = self.reparameterize(mu, logvar)
        return self.decode(z, y), mu, logvar


# Module-level cache so the checkpoint is read from disk only once per process,
# not on every button click.
_MODEL = None


def load_model():
    """Load the trained CVAE weights and return the model in eval mode.

    The model is built, loaded, and moved to `device` on first call; later
    calls return the cached instance.
    """
    global _MODEL
    if _MODEL is None:
        model = ConditionalVAE(input_dim=784, hidden_dim=400, latent_dim=20, num_classes=10)
        # NOTE(review): consider torch.load(..., weights_only=True) on torch >= 2.0
        # for safer deserialization of the state dict — confirm torch version first.
        model.load_state_dict(torch.load('mnist_cvae_model.pth', map_location=device))
        model = model.to(device)
        model.eval()
        _MODEL = model
    return _MODEL


def generate_digits(model, digit, num_samples=5):
    """Sample `num_samples` images of `digit` from the model's prior.

    Args:
        model: A loaded ConditionalVAE.
        digit: Integer class label in [0, 9].
        num_samples: Number of independent latent samples to decode.

    Returns:
        uint8 numpy array of shape (num_samples, 28, 28) with values in [0, 255].
    """
    model.eval()
    with torch.no_grad():
        # One-hot condition vector, repeated for each sample.
        label = torch.zeros(num_samples, 10).to(device)
        label[:, digit] = 1
        # Draw z ~ N(0, I); diversity across samples comes from this noise.
        z = torch.randn(num_samples, model.latent_dim).to(device)
        generated = model.decode(z, label)
        generated = generated.view(num_samples, 28, 28)
        generated = generated.cpu().numpy()
        generated = (generated * 255).astype(np.uint8)
    return generated


def generate_digit_images(digit):
    """Generate 5 PIL images of the requested digit, upscaled for display.

    On any failure (e.g. missing checkpoint file) returns 5 gray placeholder
    images instead of raising, so the UI stays responsive.
    """
    try:
        model = load_model()
        generated_images = generate_digits(model, int(digit), num_samples=5)
        pil_images = []
        for img in generated_images:
            pil_img = Image.fromarray(img, mode='L')
            # NEAREST keeps the blocky pixel look when scaling 28x28 -> 112x112.
            pil_img = pil_img.resize((112, 112), Image.NEAREST)
            pil_images.append(pil_img)
        return pil_images
    except Exception as e:
        print(f"Error: {e}")
        placeholder = Image.new('L', (112, 112), color=128)
        return [placeholder] * 5


def generate_and_display(digit):
    """Gradio callback: fan the 5 generated images out to 5 image outputs."""
    images = generate_digit_images(digit)
    return images[0], images[1], images[2], images[3], images[4]


# Create Gradio interface
with gr.Blocks(title="MNIST Digit Generator", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🔢 MNIST Handwritten Digit Generator")
    gr.Markdown("Select a digit (0-9) and generate 5 unique handwritten samples using a trained Conditional VAE model.")

    with gr.Row():
        digit_input = gr.Slider(
            minimum=0,
            maximum=9,
            step=1,
            value=0,
            label="Select Digit to Generate"
        )

    generate_btn = gr.Button("🎨 Generate 5 Digit Images", variant="primary", size="lg")

    gr.Markdown("## Generated Images")
    with gr.Row():
        img1 = gr.Image(label="Sample 1", width=112, height=112)
        img2 = gr.Image(label="Sample 2", width=112, height=112)
        img3 = gr.Image(label="Sample 3", width=112, height=112)
        img4 = gr.Image(label="Sample 4", width=112, height=112)
        img5 = gr.Image(label="Sample 5", width=112, height=112)

    generate_btn.click(
        fn=generate_and_display,
        inputs=[digit_input],
        outputs=[img1, img2, img3, img4, img5]
    )

    with gr.Accordion("📋 Model Information", open=False):
        gr.Markdown("""
        ### Technical Details
        - **Architecture**: Conditional Variational Autoencoder (CVAE)
        - **Dataset**: MNIST (28×28 grayscale images)
        - **Training**: From scratch on Google Colab T4 GPU
        - **Latent Dimension**: 20
        - **Training Epochs**: 15
        - **Loss Function**: BCE + KL Divergence

        The model generates diverse samples by sampling from the learned latent space conditioned on digit labels.
        """)

if __name__ == "__main__":
    demo.launch()