|
|
|
import torch |
|
from diffusers import DDPMScheduler |
|
import json |
|
from PIL import Image |
|
import numpy as np |
|
|
|
class LetterConditionedUnet(torch.nn.Module):
    """Class-conditioned UNet denoiser for single-channel letter images.

    A learned embedding of the class label is broadcast over every pixel and
    concatenated to the 1-channel input image along the channel axis; the
    result is fed to a diffusers ``UNet2DModel`` that predicts one channel.
    """

    def __init__(self, num_classes=26, class_emb_size=8):
        super().__init__()
        from diffusers import UNet2DModel

        # The attribute names below become the state_dict prefixes
        # ("class_emb.*", "model.*"), so they must match saved checkpoints.
        # Registration order (embedding first, then UNet) is kept as-is.
        self.class_emb = torch.nn.Embedding(num_classes, class_emb_size)

        down_blocks = ("DownBlock2D",) * 2 + ("AttnDownBlock2D",) * 3
        up_blocks = ("AttnUpBlock2D",) * 3 + ("UpBlock2D",) * 2
        self.model = UNet2DModel(
            sample_size=512,
            # 1 image channel plus the broadcast label-embedding channels.
            in_channels=1 + class_emb_size,
            out_channels=1,
            layers_per_block=2,
            block_out_channels=(64, 128, 256, 512, 512),
            down_block_types=down_blocks,
            up_block_types=up_blocks,
        )

    def forward(self, x, t, class_labels):
        """Run the conditioned UNet on a batch.

        Args:
            x: 4-D image batch, (batch, channels, height, width) as expected
                by ``UNet2DModel``.
            t: Diffusion timestep(s), passed through to the UNet unchanged.
            class_labels: Integer class indices, one per batch element (the
                embedding output is reshaped to ``(batch, emb_size, 1, 1)``).

        Returns:
            The ``.sample`` tensor produced by the wrapped UNet.
        """
        batch, _, height, width = x.shape
        emb = self.class_emb(class_labels)
        emb_size = emb.shape[1]
        # Tile each label embedding across the full spatial extent of x.
        emb_map = emb.view(batch, emb_size, 1, 1).expand(batch, emb_size, height, width)
        unet_in = torch.cat((x, emb_map), 1)
        return self.model(unet_in, t).sample
|
|
|
def generate_letter(letter, model_path="./"):
    """Generate a 512x512 grayscale image of the given letter.

    Args:
        letter: A single letter, case-insensitive, whose upper-case form is
            in ``"A"``-``"Z"``; it is mapped to class index 0-25.
        model_path: Directory containing ``pytorch_model.bin`` (a
            ``LetterConditionedUnet`` state dict) and ``scheduler_config.json``.

    Returns:
        A ``PIL.Image.Image`` built from the denoised 512x512 sample.

    Raises:
        TypeError: If ``letter`` is not a ``str``.
        ValueError: If ``letter`` does not upper-case to a single ``A``-``Z``.
    """
    from pathlib import Path

    # Validate before loading anything: an out-of-range label would otherwise
    # surface as an opaque embedding IndexError (or a device-side assert on CUDA).
    if not isinstance(letter, str):
        raise TypeError(f"letter must be a str, got {type(letter).__name__}")
    label_char = letter.upper()
    if len(label_char) != 1 or not "A" <= label_char <= "Z":
        raise ValueError(f"letter must be a single letter A-Z, got {letter!r}")
    letter_label = ord(label_char) - ord("A")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_dir = Path(model_path)

    # Load the trained denoiser.
    # NOTE(security): torch.load unpickles the file and can execute arbitrary
    # code; only load checkpoints from trusted sources (consider
    # weights_only=True on PyTorch versions that support it).
    model = LetterConditionedUnet()
    state_dict = torch.load(model_dir / "pytorch_model.bin", map_location=device)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()

    # Build the noise scheduler. from_config tolerates metadata keys and keys
    # unknown to the installed diffusers version (they are ignored with a
    # warning), whereas DDPMScheduler(**config) would raise TypeError.
    with open(model_dir / "scheduler_config.json", "r", encoding="utf-8") as f:
        scheduler_config = json.load(f)
    scheduler = DDPMScheduler.from_config(scheduler_config)

    # Reverse diffusion from pure Gaussian noise, conditioned on the label.
    x = torch.randn(1, 1, 512, 512, device=device)
    labels = torch.tensor([letter_label], device=device)
    with torch.no_grad():
        for t in scheduler.timesteps:
            residual = model(x, t, labels)
            x = scheduler.step(residual, t, x).prev_sample

    # Map the sample from the assumed [-1, 1] range to [0, 255]. Clip first:
    # casting out-of-range floats to uint8 does not saturate (values wrap or
    # are undefined), which would show up as speckle noise in the image.
    image = x[0, 0].cpu().numpy()
    image = np.clip((image + 1) / 2, 0.0, 1.0)
    image = (image * 255).astype(np.uint8)

    # A 2-D uint8 array is interpreted as an 8-bit grayscale ("L") image.
    return Image.fromarray(image)
|
|
|
|
|
if __name__ == "__main__":
    # Script entry: sample one "A" using the checkpoint in the current
    # directory and write the result to the current working directory.
    generate_letter('A').save('generated_letter_A.png')
|
|