# generadorLetras / example.py
# Uploaded by jruaechalar with huggingface_hub (commit 577ffe1).
import torch
from diffusers import DDPMScheduler
import json
from PIL import Image
import numpy as np
class LetterConditionedUnet(torch.nn.Module):
    """UNet denoiser conditioned on a letter class through a learned embedding.

    Each class label is looked up in an embedding table. The resulting vector
    is tiled over every pixel and concatenated to the single-channel noisy
    image as extra input channels. The combined tensor is then passed to a
    diffusers ``UNet2DModel`` that predicts a single output channel.

    Args:
        num_classes: Number of distinct class labels (rows of the embedding).
        class_emb_size: Width of each class embedding. This is also the number
            of conditioning channels appended to the image.
    """

    def __init__(self, num_classes=26, class_emb_size=8):
        super().__init__()
        from diffusers import UNet2DModel

        # NOTE: the attribute names `class_emb` and `model` are state_dict keys;
        # renaming them would break loading of saved checkpoints.
        self.class_emb = torch.nn.Embedding(num_classes, class_emb_size)

        # Two plain ResNet stages, then three stages with self-attention
        # (mirrored on the way up).
        down_blocks = ("DownBlock2D",) * 2 + ("AttnDownBlock2D",) * 3
        up_blocks = ("AttnUpBlock2D",) * 3 + ("UpBlock2D",) * 2
        self.model = UNet2DModel(
            sample_size=512,
            in_channels=1 + class_emb_size,  # image channel + conditioning maps
            out_channels=1,
            layers_per_block=2,
            block_out_channels=(64, 128, 256, 512, 512),
            down_block_types=down_blocks,
            up_block_types=up_blocks,
        )

    def forward(self, x, t, class_labels):
        """Predict the UNet output for noisy images ``x`` at timestep ``t``.

        Args:
            x: Noisy images in NCHW layout (dim 1 is the channel axis the
                conditioning maps are concatenated onto).
            t: Diffusion timestep, forwarded unchanged to the UNet.
            class_labels: Integer class indices, one per batch item.

        Returns:
            The ``.sample`` tensor returned by the wrapped UNet.
        """
        batch, _, height, width = x.shape
        emb = self.class_emb(class_labels)
        emb_dim = emb.shape[1]
        # Broadcast each embedding vector across every spatial position.
        cond_maps = emb.view(batch, emb_dim, 1, 1).expand(batch, emb_dim, height, width)
        unet_in = torch.cat((x, cond_maps), dim=1)
        return self.model(unet_in, t).sample
def generate_letter(letter, model_path="./"):
    """Generate a grayscale image of a single Latin letter with the trained model.

    Loads the model weights and DDPM scheduler config from ``model_path``.
    It then runs the scheduler's full reverse-diffusion loop starting from
    Gaussian noise and converts the final sample to an 8-bit PIL image.

    Args:
        letter: A single ASCII letter ``A``-``Z`` (case-insensitive).
        model_path: Directory containing ``pytorch_model.bin`` and
            ``scheduler_config.json``. Any ``str`` or ``os.PathLike`` works.

    Returns:
        A 512x512 grayscale ``PIL.Image.Image`` (mode ``"L"``).

    Raises:
        ValueError: If ``letter`` is not a single ASCII letter.
    """
    from pathlib import Path

    # Validate before the expensive model load. Previously a bad value failed
    # late, inside ord() or the embedding lookup, with a cryptic error.
    if not (
        isinstance(letter, str)
        and len(letter) == 1
        and letter.isascii()
        and letter.isalpha()
    ):
        raise ValueError(f"letter must be a single ASCII letter A-Z, got {letter!r}")
    letter_label = ord(letter.upper()) - ord("A")  # 'A' -> 0 ... 'Z' -> 25

    model_dir = Path(model_path)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load model weights.
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # trusted sources (torch >= 2.6 defaults to weights_only=True).
    model = LetterConditionedUnet()
    state_dict = torch.load(model_dir / "pytorch_model.bin", map_location=device)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()

    # Rebuild the scheduler from its saved config.
    with open(model_dir / "scheduler_config.json", "r", encoding="utf-8") as f:
        scheduler_config = json.load(f)
    scheduler = DDPMScheduler(**scheduler_config)

    # Start from pure Gaussian noise and denoise step by step.
    x = torch.randn(1, 1, 512, 512, device=device)
    labels = torch.tensor([letter_label], device=device)
    with torch.no_grad():
        for t in scheduler.timesteps:
            residual = model(x, t, labels)
            x = scheduler.step(residual, t, x).prev_sample

    # Map [-1, 1] -> [0, 255]. Clip first: the final sample is not guaranteed
    # to lie within [-1, 1] (e.g. if the scheduler config disables
    # clip_sample), and out-of-range floats overflow when cast to uint8.
    image = x[0, 0].cpu().numpy()
    image = np.clip((image + 1) / 2, 0.0, 1.0)
    image = np.round(image * 255).astype(np.uint8)
    # A 2-D uint8 array is converted to mode "L" automatically; the explicit
    # `mode` argument of Image.fromarray is deprecated in recent Pillow.
    return Image.fromarray(image)
if __name__ == "__main__":
    # Example usage: sample one image of the letter "A" and save it as a PNG.
    generate_letter('A').save('generated_letter_A.png')