"""Sample letter images from a letter-conditioned DDPM UNet."""

import json
import os
import string

import numpy as np
import torch
from diffusers import DDPMScheduler, UNet2DModel
from PIL import Image


class LetterConditionedUnet(torch.nn.Module):
    """``UNet2DModel`` conditioned on a letter class through a learned embedding.

    The class embedding is broadcast over every spatial position and
    concatenated to the single-channel input as extra channels.

    Args:
        num_classes: Number of distinct class labels (26 for A-Z).
        class_emb_size: Width of the class embedding; this is also the number
            of extra conditioning channels fed to the UNet.
    """

    def __init__(self, num_classes=26, class_emb_size=8):
        super().__init__()
        self.class_emb = torch.nn.Embedding(num_classes, class_emb_size)
        self.model = UNet2DModel(
            sample_size=512,
            # One image channel plus one channel per embedding dimension.
            in_channels=1 + class_emb_size,
            out_channels=1,
            layers_per_block=2,
            block_out_channels=(64, 128, 256, 512, 512),
            down_block_types=(
                "DownBlock2D",
                "DownBlock2D",
                "AttnDownBlock2D",
                "AttnDownBlock2D",
                "AttnDownBlock2D",
            ),
            up_block_types=(
                "AttnUpBlock2D",
                "AttnUpBlock2D",
                "AttnUpBlock2D",
                "UpBlock2D",
                "UpBlock2D",
            ),
        )

    def forward(self, x, t, class_labels):
        """Run the conditioned UNet on noisy images ``x`` at timestep ``t``.

        Args:
            x: Noisy images in NCHW layout with a single channel,
                i.e. ``(batch, 1, height, width)``.
            t: Diffusion timestep(s), passed unchanged to ``UNet2DModel``.
            class_labels: Integer class indices of shape ``(batch,)``.

        Returns:
            The UNet's ``.sample`` output (one output channel).
        """
        batch, _, height, width = x.shape
        # (batch, emb) -> (batch, emb, height, width): same embedding at every pixel.
        class_cond = self.class_emb(class_labels)
        class_cond = class_cond.view(batch, -1, 1, 1).expand(-1, -1, height, width)
        net_input = torch.cat((x, class_cond), dim=1)
        return self.model(net_input, t).sample


def _letter_to_label(letter):
    """Map a single ASCII letter (either case) to its class index 0-25."""
    if not (isinstance(letter, str) and len(letter) == 1 and letter in string.ascii_letters):
        raise ValueError(f"letter must be a single ASCII letter A-Z, got {letter!r}")
    return ord(letter.upper()) - ord("A")


def _tensor_to_pil(image):
    """Convert a 2-D tensor with values nominally in [-1, 1] to a grayscale PIL image."""
    array = image.cpu().numpy()
    # Clip before casting: out-of-range values would otherwise wrap around in uint8
    # (e.g. a slightly negative value would become a near-white pixel).
    array = np.clip((array + 1) / 2, 0.0, 1.0)
    array = np.round(array * 255).astype(np.uint8)
    return Image.fromarray(array)  # a 2-D uint8 array is read as mode "L"


def generate_letter(letter, model_path="./", seed=None):
    """Generate an image of the given letter with the trained model.

    Args:
        letter: A single ASCII letter, case-insensitive ('a'/'A' -> class 0).
        model_path: Directory containing ``pytorch_model.bin`` (a state dict for
            ``LetterConditionedUnet``) and ``scheduler_config.json``.
        seed: Optional integer seed for reproducible sampling. ``None`` (the
            default) samples without a fixed seed.

    Returns:
        A grayscale (mode ``"L"``) ``PIL.Image.Image``.

    Raises:
        ValueError: If ``letter`` is not a single ASCII letter.
    """
    label = _letter_to_label(letter)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load the model. Map the checkpoint to CPU first and move the model once,
    # avoiding a device -> CPU -> device round trip through load_state_dict.
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints
    # (consider weights_only=True on torch >= 1.13).
    model = LetterConditionedUnet()
    state_dict = torch.load(os.path.join(model_path, "pytorch_model.bin"), map_location="cpu")
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()

    # Load the scheduler. from_config strips metadata keys such as "_class_name"
    # and tolerates keys written by other diffusers versions, unlike passing the
    # raw dict to DDPMScheduler(**config).
    with open(os.path.join(model_path, "scheduler_config.json"), "r", encoding="utf-8") as f:
        scheduler_config = json.load(f)
    scheduler = DDPMScheduler.from_config(scheduler_config)

    generator = None
    if seed is not None:
        generator = torch.Generator(device=device).manual_seed(seed)

    # Use the model's configured resolution instead of repeating the constant.
    size = model.model.config.sample_size
    height, width = (size, size) if isinstance(size, int) else size
    x = torch.randn(1, 1, height, width, generator=generator, device=device)
    labels = torch.tensor([label], device=device)

    # Reverse diffusion over every timestep of the scheduler.
    with torch.no_grad():
        for t in scheduler.timesteps:
            model_output = model(x, t, labels)
            x = scheduler.step(model_output, t, x, generator=generator).prev_sample

    return _tensor_to_pil(x[0, 0])


# Example usage
if __name__ == "__main__":
    letter_image = generate_letter("A")
    letter_image.save("generated_letter_A.png")