jruaechalar committed on
Commit
577ffe1
verified
1 Parent(s): b371b91

Upload example.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. example.py +80 -0
example.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ from diffusers import DDPMScheduler
4
+ import json
5
+ from PIL import Image
6
+ import numpy as np
7
+
8
class LetterConditionedUnet(torch.nn.Module):
    """UNet2DModel wrapper that conditions on a class label via extra input channels.

    A learned embedding of the class label is broadcast over the two trailing
    dimensions of the input and concatenated to it along the channel axis
    before being passed to the wrapped ``UNet2DModel``.
    """

    def __init__(self, num_classes=26, class_emb_size=8):
        """Create the label embedding and the underlying UNet.

        Args:
            num_classes: Number of distinct labels the embedding table holds.
            class_emb_size: Size of each label embedding; also the number of
                channels appended to the 1-channel input.
        """
        super().__init__()
        from diffusers import UNet2DModel

        down_blocks = ("DownBlock2D",) * 2 + ("AttnDownBlock2D",) * 3
        up_blocks = ("AttnUpBlock2D",) * 3 + ("UpBlock2D",) * 2

        # The embedding is created before the UNet, keeping the same module
        # registration (and parameter initialisation) order.
        self.class_emb = torch.nn.Embedding(num_classes, class_emb_size)

        self.model = UNet2DModel(
            sample_size=512,
            # One image channel plus one channel per embedding component.
            in_channels=1 + class_emb_size,
            out_channels=1,
            layers_per_block=2,
            block_out_channels=(64, 128, 256, 512, 512),
            down_block_types=down_blocks,
            up_block_types=up_blocks,
        )

    def forward(self, x, t, class_labels):
        """Run the UNet on ``x`` with the label embedding appended as channels.

        Args:
            x: 4-D input tensor; its first dimension is the batch size and its
                last two dimensions set the size the embedding is expanded to.
            t: Timestep, forwarded unchanged to the wrapped model.
            class_labels: Label indices looked up in the embedding table.

        Returns:
            The ``sample`` attribute of the wrapped model's output.
        """
        batch, _, dim2, dim3 = x.shape
        emb = self.class_emb(class_labels)
        emb_dim = emb.shape[1]
        # Reshape to (batch, emb_dim, 1, 1), then expand (no copy) so every
        # spatial position carries the same embedding vector.
        emb_map = emb.view(batch, emb_dim, 1, 1).expand(batch, emb_dim, dim2, dim3)
        stacked = torch.cat((x, emb_map), 1)
        prediction = self.model(stacked, t)
        return prediction.sample
43
+
44
def generate_letter(letter, model_path="./"):
    """Generate a grayscale image of the requested letter.

    Loads the ``LetterConditionedUnet`` weights and the scheduler
    configuration from ``model_path``, runs the reverse diffusion loop
    starting from Gaussian noise, and converts the result to an 8-bit image.

    Args:
        letter: A single character ``A``-``Z`` (case-insensitive).
        model_path: Directory containing ``pytorch_model.bin`` and
            ``scheduler_config.json``.

    Returns:
        A ``PIL.Image.Image`` in mode ``'L'`` built from the 512x512 sample.

    Raises:
        ValueError: If ``letter`` is not a single character in ``A``-``Z``.
    """
    # Validate first so bad input fails fast with a clear message, before any
    # weights are loaded, instead of surfacing later as an out-of-range
    # embedding lookup or an error from ord().
    upper = letter.upper() if isinstance(letter, str) else ""
    if len(upper) != 1 or not "A" <= upper <= "Z":
        raise ValueError(
            f"letter must be a single character A-Z (case-insensitive), got {letter!r}"
        )
    letter_label = ord(upper) - ord("A")  # 'A' -> 0, ..., 'Z' -> 25

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load the model.
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    model = LetterConditionedUnet()
    model.load_state_dict(torch.load(f"{model_path}/pytorch_model.bin", map_location=device))
    model = model.to(device)
    model.eval()

    # Load the scheduler from its saved JSON configuration.
    with open(f"{model_path}/scheduler_config.json", 'r') as f:
        scheduler_config = json.load(f)
    scheduler = DDPMScheduler(**scheduler_config)

    # Prepare the input: pure noise plus the class label for the letter.
    x = torch.randn(1, 1, 512, 512, device=device)
    labels = torch.tensor([letter_label], device=device)

    # Generate: iterate the scheduler's timesteps as provided by its
    # constructor, feeding each model prediction back through the scheduler.
    with torch.no_grad():
        for t in scheduler.timesteps:
            residual = model(x, t, labels)
            x = scheduler.step(residual, t, x).prev_sample

    # Convert to an image.
    image = x[0, 0].cpu().numpy()
    # Denormalize from [-1, 1] to [0, 1]; clip so values outside that range
    # saturate instead of wrapping around in the uint8 cast below.
    image = np.clip((image + 1) / 2, 0.0, 1.0)
    image = (image * 255).astype(np.uint8)

    return Image.fromarray(image, mode='L')
76
+
77
# Example usage: generate the letter 'A' and write it to a PNG file.
if __name__ == "__main__":
    generate_letter('A').save('generated_letter_A.png')