Bai-YT committed · Commit 66982e9

Gradio App for ConsistencyTTA V1
Files changed:

- .gitignore +6 -0
- README.md +7 -0
- audioldm/hifigan/__init__.py +7 -0
- audioldm/hifigan/models.py +127 -0
- audioldm/hifigan/utilities.py +88 -0
- audioldm/latent_diffusion/attention.py +469 -0
- audioldm/latent_diffusion/util.py +293 -0
- audioldm/stft.py +257 -0
- audioldm/utils.py +177 -0
- audioldm/variational_autoencoder/__init__.py +1 -0
- audioldm/variational_autoencoder/autoencoder.py +131 -0
- audioldm/variational_autoencoder/distributions.py +102 -0
- audioldm/variational_autoencoder/modules.py +1067 -0
- consistencytta.py +200 -0
- consistencytta_clapft_ckpt/.DS_Store +0 -0
- diffusers/__init__.py +2 -0
- diffusers/models/__init__.py +23 -0
- diffusers/models/activations.py +12 -0
- diffusers/models/attention.py +523 -0
- diffusers/models/attention_processor.py +1646 -0
- diffusers/models/dual_transformer_2d.py +151 -0
- diffusers/models/embeddings.py +480 -0
- diffusers/models/loaders.py +1481 -0
- diffusers/models/modeling_utils.py +978 -0
- diffusers/models/prior_transformer.py +194 -0
- diffusers/models/resnet.py +839 -0
- diffusers/models/transformer_2d.py +333 -0
- diffusers/models/unet_2d.py +315 -0
- diffusers/models/unet_2d_blocks.py +0 -0
- diffusers/models/unet_2d_condition.py +907 -0
- diffusers/models/unet_2d_condition_guided.py +945 -0
- diffusers/scheduling_heun_discrete.py +387 -0
- diffusers/utils/configuration_utils.py +647 -0
- diffusers/utils/constants.py +34 -0
- diffusers/utils/deprecation_utils.py +49 -0
- diffusers/utils/hub_utils.py +357 -0
- diffusers/utils/import_utils.py +649 -0
- diffusers/utils/logging.py +342 -0
- diffusers/utils/outputs.py +108 -0
- diffusers/utils/scheduling_utils.py +176 -0
- diffusers/utils/torch_utils.py +83 -0
- run_gradio.py +87 -0
- tango_diffusion_light.json +46 -0
    	
.gitignore ADDED
@@ -0,0 +1,6 @@
+__pycache__
+*/__pycache__
+flagged
+*.wav
+*.pt
+*.DS_Store
    	
README.md ADDED
@@ -0,0 +1,7 @@
+## Gradio App for ConsistencyTTA
+
+Required packages:
+`numpy scipy torch torchaudio einops soundfile librosa transformers gradio`
+
+To run:
+`python run_gradio.py`
    	
audioldm/hifigan/__init__.py ADDED
@@ -0,0 +1,7 @@
+from .models import Generator
+
+
+class AttrDict(dict):
+    def __init__(self, *args, **kwargs):
+        super(AttrDict, self).__init__(*args, **kwargs)
+        self.__dict__ = self
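
The `AttrDict` helper simply aliases attribute access to dictionary keys, which is what lets `Generator` read the HiFi-GAN config as `h.num_mels` rather than `h["num_mels"]`. A minimal sketch of that behavior (the toy keys below are illustrative only, not taken from this commit):

from audioldm.hifigan import AttrDict

# Illustrative keys; the real config is HIFIGAN_16K_64 in audioldm/hifigan/utilities.py.
h = AttrDict({"num_mels": 64, "sampling_rate": 16000})

assert h.num_mels == h["num_mels"] == 64   # attribute and key access share the same storage
h.hop_size = 160                           # attribute writes land in the dict as well
assert h["hop_size"] == 160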
    	
audioldm/hifigan/models.py ADDED
@@ -0,0 +1,127 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import Conv1d, ConvTranspose1d
+from torch.nn.utils.parametrizations import weight_norm
+from torch.nn.utils.parametrize import remove_parametrizations
+
+
+LRELU_SLOPE = 0.1
+
+
+def init_weights(m, mean=0.0, std=0.01):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+    return int((kernel_size * dilation - dilation) / 2)
+
+
+class ResBlock(torch.nn.Module):
+    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
+        super(ResBlock, self).__init__()
+        self.h = h
+        self.convs1 = nn.ModuleList([
+            weight_norm(Conv1d(
+                channels, channels, kernel_size, 1, dilation=dilation[0],
+                padding=get_padding(kernel_size, dilation[0]),
+            )),
+            weight_norm(Conv1d(
+                channels, channels, kernel_size, 1, dilation=dilation[1],
+                padding=get_padding(kernel_size, dilation[1]),
+            )),
+            weight_norm(Conv1d(
+                channels, channels, kernel_size, 1, dilation=dilation[2],
+                padding=get_padding(kernel_size, dilation[2]),
+            )),
+        ])
+        self.convs1.apply(init_weights)
+
+        self.convs2 = nn.ModuleList([
+            weight_norm(Conv1d(
+                channels, channels, kernel_size, 1, dilation=1,
+                padding=get_padding(kernel_size, 1),
+            )),
+            weight_norm(Conv1d(
+                channels, channels, kernel_size, 1, dilation=1,
+                padding=get_padding(kernel_size, 1),
+            )),
+            weight_norm(Conv1d(
+                channels, channels, kernel_size, 1, dilation=1,
+                padding=get_padding(kernel_size, 1),
+            )),
+        ])
+        self.convs2.apply(init_weights)
+
+    def forward(self, x):
+        for c1, c2 in zip(self.convs1, self.convs2):
+            xt = F.leaky_relu(x, LRELU_SLOPE)
+            xt = c1(xt)
+            xt = F.leaky_relu(xt, LRELU_SLOPE)
+            xt = c2(xt)
+            x = xt + x
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs1:
+            remove_parametrizations(l, 'weight')
+        for l in self.convs2:
+            remove_parametrizations(l, 'weight')
+
+
+class Generator(torch.nn.Module):
+    def __init__(self, h):
+        super(Generator, self).__init__()
+        self.h = h
+        self.num_kernels = len(h.resblock_kernel_sizes)
+        self.num_upsamples = len(h.upsample_rates)
+        self.conv_pre = weight_norm(
+            Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
+        )
+        resblock = ResBlock
+
+        self.ups = nn.ModuleList()
+        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
+            self.ups.append(weight_norm(ConvTranspose1d(
+                h.upsample_initial_channel // (2**i),
+                h.upsample_initial_channel // (2 ** (i + 1)),
+                k, u, padding=(k - u) // 2,
+            )))
+
+        self.resblocks = nn.ModuleList()
+        for i in range(len(self.ups)):
+            ch = h.upsample_initial_channel // (2 ** (i + 1))
+            for k, d in zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes):
+                self.resblocks.append(resblock(h, ch, k, d))
+
+        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
+        self.ups.apply(init_weights)
+        self.conv_post.apply(init_weights)
+
+    def forward(self, x):
+        x = self.conv_pre(x)
+        for i in range(self.num_upsamples):
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            x = self.ups[i](x)
+            xs = None
+            for j in range(self.num_kernels):
+                if xs is None:
+                    xs = self.resblocks[i * self.num_kernels + j](x)
+                else:
+                    xs += self.resblocks[i * self.num_kernels + j](x)
+            x = xs / self.num_kernels
+        x = F.leaky_relu(x)
+        x = self.conv_post(x)
+        x = torch.tanh(x)
+
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.ups:
+            remove_parametrizations(l, 'weight')
+        for l in self.resblocks:
+            l.remove_weight_norm()
+        remove_parametrizations(self.conv_pre, 'weight')
+        remove_parametrizations(self.conv_post, 'weight')
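
As a rough sanity check on the architecture above: every Conv1d uses "same"-style padding via get_padding, so only the ConvTranspose1d stages change the time axis, and the overall frames-to-samples ratio is the product of h.upsample_rates. A minimal sketch with a deliberately tiny, made-up config (not the real HIFIGAN_16K_64 from utilities.py):

import math
import torch
from audioldm.hifigan import AttrDict, Generator

# Tiny illustrative config; the production values live in HIFIGAN_16K_64.
h = AttrDict({
    "num_mels": 8,
    "upsample_initial_channel": 32,
    "upsample_rates": [4, 2],
    "upsample_kernel_sizes": [8, 4],
    "resblock_kernel_sizes": [3],
    "resblock_dilation_sizes": [[1, 3, 5]],
})

g = Generator(h).eval()
mel = torch.randn(1, 8, 25)        # [batch, num_mels, frames]
with torch.no_grad():
    wav = g(mel)                   # [batch, 1, frames * prod(upsample_rates)]
assert wav.shape == (1, 1, 25 * math.prod(h.upsample_rates))   # (1, 1, 200)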
    	
audioldm/hifigan/utilities.py ADDED
@@ -0,0 +1,88 @@
+import torch
+
+import audioldm.hifigan as hifigan
+
+
+HIFIGAN_16K_64 = {
+    "resblock": "1",
+    "num_gpus": 6,
+    "batch_size": 16,
+    "learning_rate": 0.0002,
+    "adam_b1": 0.8,
+    "adam_b2": 0.99,
+    "lr_decay": 0.999,
+    "seed": 1234,
+    "upsample_rates": [5, 4, 2, 2, 2],
+    "upsample_kernel_sizes": [16, 16, 8, 4, 4],
+    "upsample_initial_channel": 1024,
+    "resblock_kernel_sizes": [3, 7, 11],
+    "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+    "segment_size": 8192,
+    "num_mels": 64,
+    "num_freq": 1025,
+    "n_fft": 1024,
+    "hop_size": 160,
+    "win_size": 1024,
+    "sampling_rate": 16000,
+    "fmin": 0,
+    "fmax": 8000,
+    "fmax_for_loss": None,
+    "num_workers": 4,
+    "dist_config": {
+        "dist_backend": "nccl",
+        "dist_url": "tcp://localhost:54321",
+        "world_size": 1,
+    },
+}
+
+
+def get_available_checkpoint_keys(model, ckpt):
+    print("==> Attempt to reload from %s" % ckpt)
+    state_dict = torch.load(ckpt)["state_dict"]
+    current_state_dict = model.state_dict()
+    new_state_dict = {}
+    for k in state_dict.keys():
+        if (
+            k in current_state_dict.keys()
+            and current_state_dict[k].size() == state_dict[k].size()
+        ):
+            new_state_dict[k] = state_dict[k]
+        else:
+            print("==> WARNING: Skipping %s" % k)
+    print(
+        "%s out of %s keys are matched"
+        % (len(new_state_dict.keys()), len(state_dict.keys()))
+    )
+    return new_state_dict
+
+
+def get_param_num(model):
+    num_param = sum(param.numel() for param in model.parameters())
+    return num_param
+
+
+def get_vocoder(config, device):
+    config = hifigan.AttrDict(HIFIGAN_16K_64)
+    vocoder = hifigan.Generator(config)
+    vocoder.eval()
+    vocoder.remove_weight_norm()
+    vocoder.to(device)
+    return vocoder
+
+
+def vocoder_infer(mels, vocoder, allow_grad=False, lengths=None):
+    vocoder.eval()
+
+    if allow_grad:
+        wavs = vocoder(mels).squeeze(1).float()
+        wavs = wavs - (wavs.max() + wavs.min()) / 2
+    else:
+        with torch.no_grad():
+            wavs = vocoder(mels).squeeze(1).float()
+            wavs = wavs - (wavs.max() + wavs.min()) / 2
+            wavs = (wavs.cpu().numpy() * 32768).astype("int16")
+
+    if lengths is not None:
+        wavs = wavs[:, :lengths]
+
+    return wavs
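
Putting the two HiFi-GAN files together, a hedged sketch of the vocoder path follows. The shapes are implied by HIFIGAN_16K_64; note that get_vocoder as written builds the network with freshly initialized weights (no checkpoint is loaded there), so this only illustrates shapes and dtypes, not audio quality:

import torch
from audioldm.hifigan.utilities import get_vocoder, vocoder_infer

device = "cuda" if torch.cuda.is_available() else "cpu"
vocoder = get_vocoder(None, device)    # the config argument is ignored; HIFIGAN_16K_64 is used

# Assumed input: a batch of 64-bin mel spectrograms, [batch, num_mels, frames].
mels = torch.randn(2, 64, 500, device=device)

# Each mel frame becomes prod(upsample_rates) = 5*4*2*2*2 = 160 samples at 16 kHz.
wavs = vocoder_infer(mels, vocoder)    # int16 numpy array of shape (2, 500 * 160)
print(wavs.shape, wavs.dtype)          # (2, 80000) int16, i.e. 5 seconds per item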
    	
audioldm/latent_diffusion/attention.py ADDED
@@ -0,0 +1,469 @@
+from inspect import isfunction
+import math
+import torch
+import torch.nn.functional as F
+from torch import nn
+from einops import rearrange
+
+from audioldm.latent_diffusion.util import checkpoint
+
+
+def exists(val):
+    return val is not None
+
+
+def uniq(arr):
+    return {el: True for el in arr}.keys()
+
+
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if isfunction(d) else d
+
+
+def max_neg_value(t):
+    return -torch.finfo(t.dtype).max
+
+
+def init_(tensor):
+    dim = tensor.shape[-1]
+    std = 1 / math.sqrt(dim)
+    tensor.uniform_(-std, std)
+    return tensor
+
+
+# feedforward
+class GEGLU(nn.Module):
+    def __init__(self, dim_in, dim_out):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out * 2)
+
+    def forward(self, x):
+        x, gate = self.proj(x).chunk(2, dim=-1)
+        return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
+        super().__init__()
+        inner_dim = int(dim * mult)
+        dim_out = default(dim_out, dim)
+        project_in = (
+            nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
+            if not glu
+            else GEGLU(dim, inner_dim)
+        )
+
+        self.net = nn.Sequential(
+            project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+
+def zero_module(module):
+    """
+    Zero out the parameters of a module and return it.
+    """
+    for p in module.parameters():
+        p.detach().zero_()
+    return module
+
+
+def Normalize(in_channels):
+    return torch.nn.GroupNorm(
+        num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
+    )
+
+
+class LinearAttention(nn.Module):
+    def __init__(self, dim, heads=4, dim_head=32):
+        super().__init__()
+        self.heads = heads
+        hidden_dim = dim_head * heads
+        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
+        self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+    def forward(self, x):
+        b, c, h, w = x.shape
+        qkv = self.to_qkv(x)
+        q, k, v = rearrange(
+            qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
+        )
+        k = k.softmax(dim=-1)
+        context = torch.einsum("bhdn,bhen->bhde", k, v)
+        out = torch.einsum("bhde,bhdn->bhen", context, q)
+        out = rearrange(
+            out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
+        )
+        return self.to_out(out)
+
+
+class SpatialSelfAttention(nn.Module):
+    def __init__(self, in_channels):
+        super().__init__()
+        self.in_channels = in_channels
+
+        self.norm = Normalize(in_channels)
+        self.q = torch.nn.Conv2d(
+            in_channels, in_channels, kernel_size=1, stride=1, padding=0
+        )
+        self.k = torch.nn.Conv2d(
+            in_channels, in_channels, kernel_size=1, stride=1, padding=0
+        )
+        self.v = torch.nn.Conv2d(
+            in_channels, in_channels, kernel_size=1, stride=1, padding=0
+        )
+        self.proj_out = torch.nn.Conv2d(
+            in_channels, in_channels, kernel_size=1, stride=1, padding=0
+        )
+
+    def forward(self, x):
+        h_ = x
+        h_ = self.norm(h_)
+        q = self.q(h_)
+        k = self.k(h_)
+        v = self.v(h_)
+
+        # compute attention
+        b, c, h, w = q.shape
+        q = rearrange(q, "b c h w -> b (h w) c")
+        k = rearrange(k, "b c h w -> b c (h w)")
+        w_ = torch.einsum("bij,bjk->bik", q, k)
+
+        w_ = w_ * (int(c) ** (-0.5))
+        w_ = torch.nn.functional.softmax(w_, dim=2)
+
+        # attend to values
+        v = rearrange(v, "b c h w -> b c (h w)")
+        w_ = rearrange(w_, "b i j -> b j i")
+        h_ = torch.einsum("bij,bjk->bik", v, w_)
+        h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
+        h_ = self.proj_out(h_)
+
+        return x + h_
+
+
+class CrossAttention(nn.Module):
+    """
+    ### Cross Attention Layer
+    This falls back to self-attention when conditional embeddings are not specified.
+    """
+
+    # use_flash_attention: bool = True
+    use_flash_attention: bool = False
+
+    def __init__(
+        self,
+        query_dim,
+        context_dim=None,
+        heads=8,
+        dim_head=64,
+        dropout=0.0,
+        is_inplace: bool = True,
+    ):
+        # def __init__(self, d_model: int, d_cond: int, n_heads: int, d_head: int, is_inplace: bool = True):
+        """
+        :param d_model: is the input embedding size
+        :param n_heads: is the number of attention heads
+        :param d_head: is the size of an attention head
+        :param d_cond: is the size of the conditional embeddings
+        :param is_inplace: specifies whether to perform the attention softmax computation inplace to
+            save memory
+        """
+        super().__init__()
+
+        self.is_inplace = is_inplace
+        self.n_heads = heads
+        self.d_head = dim_head
+
+        # Attention scaling factor
+        self.scale = dim_head**-0.5
+
+        # The normal self-attention layer
+        if context_dim is None:
+            context_dim = query_dim
+
+        # Query, key and value mappings
+        d_attn = dim_head * heads
+        self.to_q = nn.Linear(query_dim, d_attn, bias=False)
+        self.to_k = nn.Linear(context_dim, d_attn, bias=False)
+        self.to_v = nn.Linear(context_dim, d_attn, bias=False)
+
+        # Final linear layer
+        self.to_out = nn.Sequential(nn.Linear(d_attn, query_dim), nn.Dropout(dropout))
+
+        # Setup [flash attention](https://github.com/HazyResearch/flash-attention).
+        # Flash attention is only used if it's installed
+        # and `CrossAttention.use_flash_attention` is set to `True`.
+        try:
+            # You can install flash attention by cloning their Github repo,
+            # [https://github.com/HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention)
+            # and then running `python setup.py install`
+            from flash_attn.flash_attention import FlashAttention
+
+            self.flash = FlashAttention()
+            # Set the scale for scaled dot-product attention.
+            self.flash.softmax_scale = self.scale
+        # Set to `None` if it's not installed
+        except ImportError:
+            self.flash = None
+
+    def forward(self, x, context=None, mask=None):
+        """
+        :param x: are the input embeddings of shape `[batch_size, height * width, d_model]`
+        :param context: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
+        """
+
+        # If `context` is `None` we perform self attention
+        has_cond = context is not None
+        if not has_cond:
+            context = x
+
+        # Get query, key and value vectors
+        q = self.to_q(x)
+        k = self.to_k(context)
+        v = self.to_v(context)
+
+        # Use flash attention if it's available and the head size is less than or equal to `128`
+        if (
+            CrossAttention.use_flash_attention
+            and self.flash is not None
+            and not has_cond
+            and self.d_head <= 128
+        ):
+            return self.flash_attention(q, k, v)
+        # Otherwise, fall back to normal attention
+        else:
+            return self.normal_attention(q, k, v)
+
+    def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
+        """
+        #### Flash Attention
+        :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+        :param k: are the key vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+        :param v: are the value vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+        """
+
+        # Get batch size and number of elements along sequence axis (`width * height`)
+        batch_size, seq_len, _ = q.shape
+
+        # Stack `q`, `k`, `v` vectors for flash attention, to get a single tensor of
+        # shape `[batch_size, seq_len, 3, n_heads * d_head]`
+        qkv = torch.stack((q, k, v), dim=2)
+        # Split the heads
+        qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head)
+
+        # Flash attention works for head sizes `32`, `64` and `128`, so we have to pad the heads to
+        # fit this size.
+        if self.d_head <= 32:
+            pad = 32 - self.d_head
+        elif self.d_head <= 64:
+            pad = 64 - self.d_head
+        elif self.d_head <= 128:
+            pad = 128 - self.d_head
+        else:
+            raise ValueError(f"Head size {self.d_head} too large for Flash Attention")
+
+        # Pad the heads
+        if pad:
+            qkv = torch.cat(
+                (qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1
+            )
+
+        # Compute attention
+        # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
+        # This gives a tensor of shape `[batch_size, seq_len, n_heads, d_padded]`
+        # TODO here I add the dtype changing
+        out, _ = self.flash(qkv.type(torch.float16))
+        # Truncate the extra head size
+        out = out[:, :, :, : self.d_head].float()
+        # Reshape to `[batch_size, seq_len, n_heads * d_head]`
+        out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head)
+
+        # Map to `[batch_size, height * width, d_model]` with a linear layer
+        return self.to_out(out)
+
+    def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
+        """
+        #### Normal Attention
+
+        :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+        :param k: are the key vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+        :param v: are the value vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+        """
+
+        # Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]`
+        q = q.view(*q.shape[:2], self.n_heads, -1)  # [bs, 64, 20, 32]
+        k = k.view(*k.shape[:2], self.n_heads, -1)  # [bs, 1, 20, 32]
+        v = v.view(*v.shape[:2], self.n_heads, -1)
+
+        # Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$
+        attn = torch.einsum("bihd,bjhd->bhij", q, k) * self.scale
+
+        # Compute softmax
+        # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$
+        if self.is_inplace:
+            half = attn.shape[0] // 2
+            attn[half:] = attn[half:].softmax(dim=-1)
+            attn[:half] = attn[:half].softmax(dim=-1)
+        else:
+            attn = attn.softmax(dim=-1)
+
+        # Compute attention output
+        # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
+        # attn: [bs, 20, 64, 1]
+        # v: [bs, 1, 20, 32]
+        out = torch.einsum("bhij,bjhd->bihd", attn, v)
+        # Reshape to `[batch_size, height * width, n_heads * d_head]`
+        out = out.reshape(*out.shape[:2], -1)
+        # Map to `[batch_size, height * width, d_model]` with a linear layer
+        return self.to_out(out)
+
+
+# class CrossAttention(nn.Module):
+# def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+#     super().__init__()
+#     inner_dim = dim_head * heads
+#     context_dim = default(context_dim, query_dim)
+
+#     self.scale = dim_head ** -0.5
+#     self.heads = heads
+
+#     self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+#     self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+#     self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+#     self.to_out = nn.Sequential(
+#         nn.Linear(inner_dim, query_dim),
+#         nn.Dropout(dropout)
+#     )
+
+# def forward(self, x, context=None, mask=None):
+#     h = self.heads
+
+#     q = self.to_q(x)
+#     context = default(context, x)
+#     k = self.to_k(context)
+#     v = self.to_v(context)
+
+#     q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+#     sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+#     if exists(mask):
+#         mask = rearrange(mask, 'b ... -> b (...)')
+#         max_neg_value = -torch.finfo(sim.dtype).max
+#         mask = repeat(mask, 'b j -> (b h) () j', h=h)
+#         sim.masked_fill_(~mask, max_neg_value)
+
+#     # attention, what we cannot get enough of
+#     attn = sim.softmax(dim=-1)
+
+#     out = einsum('b i j, b j d -> b i d', attn, v)
+#     out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+#     return self.to_out(out)
+
+
+class BasicTransformerBlock(nn.Module):
+    def __init__(
+        self,
+        dim,
+        n_heads,
+        d_head,
+        dropout=0.0,
+        context_dim=None,
+        gated_ff=True,
+        checkpoint=True,
+    ):
+        super().__init__()
+        self.attn1 = CrossAttention(
+            query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
+        )  # is a self-attention
+        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+        self.attn2 = CrossAttention(
+            query_dim=dim,
+            context_dim=context_dim,
+            heads=n_heads,
+            dim_head=d_head,
+            dropout=dropout,
+        )  # is self-attn if context is none
+        self.norm1 = nn.LayerNorm(dim)
+        self.norm2 = nn.LayerNorm(dim)
+        self.norm3 = nn.LayerNorm(dim)
+        self.checkpoint = checkpoint
+
+    def forward(self, x, context=None):
+        if context is None:
+            return checkpoint(self._forward, (x,), self.parameters(), self.checkpoint)
+        else:
+            return checkpoint(
+                self._forward, (x, context), self.parameters(), self.checkpoint
+            )
+
+    def _forward(self, x, context=None):
+        x = self.attn1(self.norm1(x)) + x
+        x = self.attn2(self.norm2(x), context=context) + x
+        x = self.ff(self.norm3(x)) + x
+        return x
+
+
+class SpatialTransformer(nn.Module):
+    """
+    Transformer block for image-like data.
+    First, project the input (aka embedding)
+    and reshape to b, t, d.
+    Then apply standard transformer action.
+    Finally, reshape to image
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        n_heads,
+        d_head,
+        depth=1,
+        dropout=0.0,
+        context_dim=None,
+        no_context=False,
+    ):
+        super().__init__()
+
+        if no_context:
+            context_dim = None
+
+        self.in_channels = in_channels
+        inner_dim = n_heads * d_head
+        self.norm = Normalize(in_channels)
+
+        self.proj_in = nn.Conv2d(
+            in_channels, inner_dim, kernel_size=1, stride=1, padding=0
+        )
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock(
+                    inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim
+                )
+                for d in range(depth)
+            ]
+        )
+
+        self.proj_out = zero_module(
+            nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
                    )
         | 
| 457 | 
            +
             | 
| 458 | 
            +
                def forward(self, x, context=None):
         | 
| 459 | 
            +
                    # note: if no context is given, cross-attention defaults to self-attention
         | 
| 460 | 
            +
                    b, c, h, w = x.shape
         | 
| 461 | 
            +
                    x_in = x
         | 
| 462 | 
            +
                    x = self.norm(x)
         | 
| 463 | 
            +
                    x = self.proj_in(x)
         | 
| 464 | 
            +
                    x = rearrange(x, "b c h w -> b (h w) c")
         | 
| 465 | 
            +
                    for block in self.transformer_blocks:
         | 
| 466 | 
            +
                        x = block(x, context=context)
         | 
| 467 | 
            +
                    x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
         | 
| 468 | 
            +
                    x = self.proj_out(x)
         | 
| 469 | 
            +
                    return x + x_in
         | 
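
A minimal usage sketch of `SpatialTransformer` (not part of the commit), assuming the module path `audioldm.latent_diffusion.attention` as laid out in this repository; the tensor shapes and the 512-dimensional context are illustrative placeholders only.

import torch
from audioldm.latent_diffusion.attention import SpatialTransformer  # assumed import path

# Project a (B, C, H, W) feature map, run cross-attention against a text-like
# context, and reshape back; passing context=None falls back to self-attention.
block = SpatialTransformer(in_channels=64, n_heads=4, d_head=16, depth=1, context_dim=512)
feats = torch.randn(2, 64, 16, 16)     # (batch, channels, height, width)
context = torch.randn(2, 77, 512)      # (batch, tokens, context_dim) -- illustrative sizes
out = block(feats, context=context)    # same shape as feats, residual added at the end
print(out.shape)                       # torch.Size([2, 64, 16, 16])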
    	
        audioldm/latent_diffusion/util.py
    ADDED
    
@@ -0,0 +1,293 @@
# adapted from
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
# and
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
# and
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
#
# thanks!

import math
import torch
import torch.nn as nn
import numpy as np
from einops import repeat

from audioldm.utils import instantiate_from_config


def make_beta_schedule(
    schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
):
    if schedule == "linear":
        betas = (
            torch.linspace(
                linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
            )
            ** 2
        )

    elif schedule == "cosine":
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        betas = np.clip(betas, a_min=0, a_max=0.999)

    elif schedule == "sqrt_linear":
        betas = torch.linspace(
            linear_start, linear_end, n_timestep, dtype=torch.float64
        )
    elif schedule == "sqrt":
        betas = (
            torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
            ** 0.5
        )
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()


def make_ddim_timesteps(
    ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
):
    if ddim_discr_method == "uniform":
        c = num_ddpm_timesteps // num_ddim_timesteps
        ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
    elif ddim_discr_method == "quad":
        ddim_timesteps = (
            (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2
        ).astype(int)
    else:
        raise NotImplementedError(
            f'There is no ddim discretization method called "{ddim_discr_method}"'
        )

    # assert ddim_timesteps.shape[0] == num_ddim_timesteps
    # add one to get the final alpha values right (the ones from first scale to data during sampling)
    steps_out = ddim_timesteps + 1
    if verbose:
        print(f"Selected timesteps for ddim sampler: {steps_out}")
    return steps_out


def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
    # select alphas for computing the variance schedule
    alphas = alphacums[ddim_timesteps]
    alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())

    # according to the formula provided in https://arxiv.org/abs/2010.02502
    sigmas = eta * np.sqrt(
        (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
    )
    if verbose:
        print(
            f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}"
        )
        print(
            f"For the chosen value of eta, which is {eta}, "
            f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
        )
    return sigmas, alphas, alphas_prev


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].
    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)


def extract_into_tensor(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t).contiguous()
    return out.reshape(b, *((1,) * (len(x_shape) - 1))).contiguous()


def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.
    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if flag:
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)


class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])

        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + input_grads


def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32)
            / half
        ).to(device=timesteps.device)
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
    else:
        embedding = repeat(timesteps, "b -> b d", d=dim)
    return embedding


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def scale_module(module, scale):
    """
    Scale the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().mul_(scale)
    return module


def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


def normalization(channels):
    """
    Make a standard normalization layer.
    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(32, channels)


# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


class GroupNorm32(nn.GroupNorm):
    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)


def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def linear(*args, **kwargs):
    """
    Create a linear module.
    """
    return nn.Linear(*args, **kwargs)


def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    if dims == 1:
        return nn.AvgPool1d(*args, **kwargs)
    elif dims == 2:
        return nn.AvgPool2d(*args, **kwargs)
    elif dims == 3:
        return nn.AvgPool3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


class HybridConditioner(nn.Module):
    def __init__(self, c_concat_config, c_crossattn_config):
        super().__init__()
        self.concat_conditioner = instantiate_from_config(c_concat_config)
        self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)

    def forward(self, c_concat, c_crossattn):
        c_concat = self.concat_conditioner(c_concat)
        c_crossattn = self.crossattn_conditioner(c_crossattn)
        return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]}


def noise_like(shape, device, repeat=False):
    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
        shape[0], *((1,) * (len(shape) - 1))
    )
    noise = lambda: torch.randn(shape, device=device)
    return repeat_noise() if repeat else noise()
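
The helpers above are the building blocks the diffusion UNet relies on. A minimal sketch (not part of the commit) of the two most frequently used ones, assuming the module path `audioldm.latent_diffusion.util` from this repository and purely illustrative sizes:

import torch
from audioldm.latent_diffusion.util import make_beta_schedule, timestep_embedding

# Linear beta schedule and its cumulative alpha product, as consumed by DDPM/DDIM samplers.
betas = make_beta_schedule("linear", n_timestep=1000)            # numpy array, shape (1000,)
alphas_cumprod = torch.cumprod(1.0 - torch.from_numpy(betas), dim=0)

# Sinusoidal embeddings for a batch of timestep indices, fed to the UNet's time MLP.
t = torch.randint(0, 1000, (4,))
emb = timestep_embedding(t, dim=128)                             # shape (4, 128)
print(betas.shape, alphas_cumprod.shape, emb.shape)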
    	
        audioldm/stft.py
    ADDED
    
@@ -0,0 +1,257 @@
import torch
import torch.nn.functional as F
import numpy as np
from scipy.signal import get_window
from librosa.util import pad_center, tiny, normalize
from librosa.filters import mel as librosa_mel_fn


def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):
    """
    Parameters
    ----------
    C: compression factor
    """
    return normalize_fun(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression(x, C=1):
    """
    Parameters
    ----------
    C: compression factor used to compress
    """
    return torch.exp(x) / C


def window_sumsquare(
    window,
    n_frames,
    hop_length,
    win_length,
    n_fft,
    dtype=np.float32,
    norm=None,
):
    """
    # from librosa 0.6
    Compute the sum-square envelope of a window function at a given hop length.

    This is used to estimate modulation effects induced by windowing
    observations in short-time Fourier transforms.

    Parameters
    ----------
    window : string, tuple, number, callable, or list-like
        Window specification, as in `get_window`

    n_frames : int > 0
        The number of analysis frames

    hop_length : int > 0
        The number of samples to advance between frames

    win_length : [optional]
        The length of the window function.  By default, this matches `n_fft`.

    n_fft : int > 0
        The length of each analysis frame.

    dtype : np.dtype
        The data type of the output

    Returns
    -------
    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
        The sum-squared envelope of the window function
    """
    if win_length is None:
        win_length = n_fft

    n = n_fft + hop_length * (n_frames - 1)
    x = np.zeros(n, dtype=dtype)

    # Compute the squared window at the desired length
    win_sq = get_window(window, win_length, fftbins=True)
    win_sq = normalize(win_sq, norm=norm) ** 2
    win_sq = pad_center(win_sq, size=n_fft)

    # Fill the envelope
    for i in range(n_frames):
        sample = i * hop_length
        x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
    return x


class STFT(torch.nn.Module):
    """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""

    def __init__(self, filter_length, hop_length, win_length, window="hann"):
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.forward_transform = None
        scale = self.filter_length / self.hop_length
        fourier_basis = np.fft.fft(np.eye(self.filter_length))

        cutoff = int((self.filter_length / 2 + 1))
        fourier_basis = np.vstack(
            [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
        )

        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
        inverse_basis = torch.FloatTensor(
            np.linalg.pinv(scale * fourier_basis).T[:, None, :]
        )

        if window is not None:
            assert filter_length >= win_length
            # get window and zero center pad it to filter_length
            fft_window = get_window(window, win_length, fftbins=True)
            fft_window = pad_center(fft_window, size=filter_length)
            fft_window = torch.from_numpy(fft_window).float()

            # window the bases
            forward_basis *= fft_window
            inverse_basis *= fft_window

        self.register_buffer("forward_basis", forward_basis.float())
        self.register_buffer("inverse_basis", inverse_basis.float())

    def transform(self, input_data):
        device = self.forward_basis.device
        input_data = input_data.to(device)

        num_batches = input_data.size(0)
        num_samples = input_data.size(1)

        self.num_samples = num_samples

        # similar to librosa, reflect-pad the input
        input_data = input_data.view(num_batches, 1, num_samples)
        input_data = F.pad(
            input_data.unsqueeze(1),
            (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
            mode="reflect",
        )
        input_data = input_data.squeeze(1)

        forward_transform = F.conv1d(
            input_data,
            torch.autograd.Variable(self.forward_basis, requires_grad=False),
            stride=self.hop_length,
            padding=0,
        )

        cutoff = int((self.filter_length / 2) + 1)
        real_part = forward_transform[:, :cutoff, :]
        imag_part = forward_transform[:, cutoff:, :]

        magnitude = torch.sqrt(real_part**2 + imag_part**2)
        phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))

        return magnitude, phase

    def inverse(self, magnitude, phase):
        device = self.forward_basis.device
        magnitude, phase = magnitude.to(device), phase.to(device)

        recombine_magnitude_phase = torch.cat(
            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
        )

        inverse_transform = F.conv_transpose1d(
            recombine_magnitude_phase,
            torch.autograd.Variable(self.inverse_basis, requires_grad=False),
            stride=self.hop_length,
            padding=0,
        )

        if self.window is not None:
            window_sum = window_sumsquare(
                self.window,
                magnitude.size(-1),
                hop_length=self.hop_length,
                win_length=self.win_length,
                n_fft=self.filter_length,
                dtype=np.float32,
            )
            # remove modulation effects
            approx_nonzero_indices = torch.from_numpy(
                np.where(window_sum > tiny(window_sum))[0]
            )
            window_sum = torch.autograd.Variable(
                torch.from_numpy(window_sum), requires_grad=False
            )
            # move the envelope onto the same device as the inverse transform
            window_sum = window_sum.to(device)
            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
                approx_nonzero_indices
            ]

            # scale by hop ratio
            inverse_transform *= float(self.filter_length) / self.hop_length

        inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
        inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2)]

        return inverse_transform

    def forward(self, input_data):
        self.magnitude, self.phase = self.transform(input_data)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction


class TacotronSTFT(torch.nn.Module):
    def __init__(
        self,
        filter_length,
        hop_length,
        win_length,
        n_mel_channels,
        sampling_rate,
        mel_fmin,
        mel_fmax,
    ):
        super(TacotronSTFT, self).__init__()
        self.n_mel_channels = n_mel_channels
        self.sampling_rate = sampling_rate
        self.stft_fn = STFT(filter_length, hop_length, win_length)
        mel_basis = librosa_mel_fn(
            sr=sampling_rate, n_fft=filter_length, n_mels=n_mel_channels,
            fmin=mel_fmin, fmax=mel_fmax
        )
        mel_basis = torch.from_numpy(mel_basis).float()
        self.register_buffer("mel_basis", mel_basis)

    def spectral_normalize(self, magnitudes, normalize_fun):
        output = dynamic_range_compression(magnitudes, normalize_fun)
        return output

    def spectral_de_normalize(self, magnitudes):
        output = dynamic_range_decompression(magnitudes)
        return output

    def mel_spectrogram(self, y, normalize_fun=torch.log):
        """Computes mel-spectrograms from a batch of waves
        PARAMS
        ------
        y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]

        RETURNS
        -------
        mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
        """
        assert torch.min(y.data) >= -1, torch.min(y.data)
        assert torch.max(y.data) <= 1, torch.max(y.data)

        magnitudes, phases = self.stft_fn.transform(y)
        magnitudes = magnitudes.data
        mel_output = torch.matmul(self.mel_basis, magnitudes)
        mel_output = self.spectral_normalize(mel_output, normalize_fun)
        energy = torch.norm(magnitudes, dim=1)

        log_magnitudes = self.spectral_normalize(magnitudes, normalize_fun)

        return mel_output, log_magnitudes, energy
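
A minimal sketch (not part of the commit) of extracting a mel-spectrogram with `TacotronSTFT`; the STFT and mel parameters mirror the defaults in `audioldm/utils.py` below, and the random waveform is purely illustrative.

import torch
from audioldm.stft import TacotronSTFT

fn_STFT = TacotronSTFT(
    filter_length=1024, hop_length=160, win_length=1024,
    n_mel_channels=64, sampling_rate=16000, mel_fmin=0, mel_fmax=8000,
)
wav = torch.rand(1, 16000) * 2 - 1                   # one second of audio in [-1, 1)
mel, log_mag, energy = fn_STFT.mel_spectrogram(wav)  # mel: (1, 64, T) log-mel spectrogram
print(mel.shape, log_mag.shape, energy.shape)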
    	
        audioldm/utils.py
    ADDED
    
@@ -0,0 +1,177 @@
import os
import importlib


CACHE_DIR = os.getenv(
    "AUDIOLDM_CACHE_DIR",
    os.path.join(os.path.expanduser("~"), ".cache/audioldm"))


def default_audioldm_config(model_name="audioldm-s-full"):
    basic_config = {
        "wave_file_save_path": "./output",
        "id": {
            "version": "v1",
            "name": "default",
            "root": "/mnt/fast/nobackup/users/hl01486/projects/general_audio_generation/AudioLDM-python/config/default/latent_diffusion.yaml",
        },
        "preprocessing": {
            "audio": {"sampling_rate": 16000, "max_wav_value": 32768},
            "stft": {"filter_length": 1024, "hop_length": 160, "win_length": 1024},
            "mel": {
                "n_mel_channels": 64,
                "mel_fmin": 0,
                "mel_fmax": 8000,
                "freqm": 0,
                "timem": 0,
                "blur": False,
                "mean": -4.63,
                "std": 2.74,
                "target_length": 1024,
            },
        },
        "model": {
            "device": "cuda",
            "target": "audioldm.pipline.LatentDiffusion",
            "params": {
                "base_learning_rate": 5e-06,
                "linear_start": 0.0015,
                "linear_end": 0.0195,
                "num_timesteps_cond": 1,
                "log_every_t": 200,
                "timesteps": 1000,
                "first_stage_key": "fbank",
                "cond_stage_key": "waveform",
                "latent_t_size": 256,
                "latent_f_size": 16,
                "channels": 8,
                "cond_stage_trainable": True,
                "conditioning_key": "film",
                "monitor": "val/loss_simple_ema",
                "scale_by_std": True,
                "unet_config": {
                    "target": "audioldm.latent_diffusion.openaimodel.UNetModel",
                    "params": {
                        "image_size": 64,
                        "extra_film_condition_dim": 512,
                        "extra_film_use_concat": True,
                        "in_channels": 8,
                        "out_channels": 8,
                        "model_channels": 128,
                        "attention_resolutions": [8, 4, 2],
                        "num_res_blocks": 2,
                        "channel_mult": [1, 2, 3, 5],
                        "num_head_channels": 32,
                        "use_spatial_transformer": True,
                    },
                },
                "first_stage_config": {
                    "base_learning_rate": 4.5e-05,
                    "target": "audioldm.variational_autoencoder.autoencoder.AutoencoderKL",
                    "params": {
                        "monitor": "val/rec_loss",
                        "image_key": "fbank",
                        "subband": 1,
                        "embed_dim": 8,
                        "time_shuffle": 1,
                        "ddconfig": {
                            "double_z": True,
                            "z_channels": 8,
                            "resolution": 256,
                            "downsample_time": False,
                            "in_channels": 1,
                            "out_ch": 1,
                            "ch": 128,
                            "ch_mult": [1, 2, 4],
                            "num_res_blocks": 2,
                            "attn_resolutions": [],
                            "dropout": 0.0,
         | 
| 89 | 
            +
                                    },
         | 
| 90 | 
            +
                                },
         | 
| 91 | 
            +
                            },
         | 
| 92 | 
            +
                            "cond_stage_config": {
         | 
| 93 | 
            +
                                "target": "audioldm.clap.encoders.CLAPAudioEmbeddingClassifierFreev2",
         | 
| 94 | 
            +
                                "params": {
         | 
| 95 | 
            +
                                    "key": "waveform",
         | 
| 96 | 
            +
                                    "sampling_rate": 16000,
         | 
| 97 | 
            +
                                    "embed_mode": "audio",
         | 
| 98 | 
            +
                                    "unconditional_prob": 0.1,
         | 
| 99 | 
            +
                                },
         | 
| 100 | 
            +
                            },
         | 
| 101 | 
            +
                        },
         | 
| 102 | 
            +
                    },
         | 
| 103 | 
            +
                }
         | 
| 104 | 
            +
                
         | 
| 105 | 
            +
                if("-l-" in model_name):
         | 
| 106 | 
            +
                    basic_config["model"]["params"]["unet_config"]["params"]["model_channels"] = 256
         | 
| 107 | 
            +
                    basic_config["model"]["params"]["unet_config"]["params"]["num_head_channels"] = 64
         | 
| 108 | 
            +
                elif("-m-" in model_name):
         | 
| 109 | 
            +
                    basic_config["model"]["params"]["unet_config"]["params"]["model_channels"] = 192
         | 
| 110 | 
            +
                    basic_config["model"]["params"]["cond_stage_config"]["params"]["amodel"] = "HTSAT-base" # This model use a larger HTAST
         | 
| 111 | 
            +
                    
         | 
| 112 | 
            +
                return basic_config
         | 
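
Example (not part of the commit): a minimal sketch of how the name-based overrides above select the UNet width, assuming this repository's audioldm/utils.py is importable as audioldm.utils.

# Illustrative only; the model names come from get_metadata() below.
from audioldm.utils import default_audioldm_config

for name in ("audioldm-s-full", "audioldm-m-full", "audioldm-l-full"):
    unet = default_audioldm_config(name)["model"]["params"]["unet_config"]["params"]
    print(name, "->", unet["model_channels"], "base channels")
# Expected: 128 for the small, 192 for the medium, 256 for the large variants.
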
| 113 | 
            +
             | 
| 114 | 
            +
             | 
| 115 | 
            +
            def get_metadata():
         | 
| 116 | 
            +
                return {
         | 
| 117 | 
            +
                    "audioldm-s-full": {
         | 
| 118 | 
            +
                        "path": os.path.join(
         | 
| 119 | 
            +
                            CACHE_DIR,
         | 
| 120 | 
            +
                            "audioldm-s-full.ckpt",
         | 
| 121 | 
            +
                        ),
         | 
| 122 | 
            +
                        "url": "https://zenodo.org/record/7600541/files/audioldm-s-full?download=1",
         | 
| 123 | 
            +
                    },
         | 
| 124 | 
            +
                    "audioldm-l-full": {
         | 
| 125 | 
            +
                        "path": os.path.join(
         | 
| 126 | 
            +
                            CACHE_DIR,
         | 
| 127 | 
            +
                            "audioldm-l-full.ckpt",
         | 
| 128 | 
            +
                        ),
         | 
| 129 | 
            +
                        "url": "https://zenodo.org/record/7698295/files/audioldm-full-l.ckpt?download=1",
         | 
| 130 | 
            +
                    },
         | 
| 131 | 
            +
                    "audioldm-s-full-v2": {
         | 
| 132 | 
            +
                        "path": os.path.join(
         | 
| 133 | 
            +
                            CACHE_DIR,
         | 
| 134 | 
            +
                            "audioldm-s-full-v2.ckpt",
         | 
| 135 | 
            +
                        ),
         | 
| 136 | 
            +
                        "url": "https://zenodo.org/record/7698295/files/audioldm-full-s-v2.ckpt?download=1",
         | 
| 137 | 
            +
                    },
         | 
| 138 | 
            +
                    "audioldm-m-text-ft": {
         | 
| 139 | 
            +
                        "path": os.path.join(
         | 
| 140 | 
            +
                            CACHE_DIR,
         | 
| 141 | 
            +
                            "audioldm-m-text-ft.ckpt",
         | 
| 142 | 
            +
                        ),
         | 
| 143 | 
            +
                        "url": "https://zenodo.org/record/7813012/files/audioldm-m-text-ft.ckpt?download=1",
         | 
| 144 | 
            +
                    },
         | 
| 145 | 
            +
                    "audioldm-s-text-ft": {
         | 
| 146 | 
            +
                        "path": os.path.join(
         | 
| 147 | 
            +
                            CACHE_DIR,
         | 
| 148 | 
            +
                            "audioldm-s-text-ft.ckpt",
         | 
| 149 | 
            +
                        ),
         | 
| 150 | 
            +
                        "url": "https://zenodo.org/record/7813012/files/audioldm-s-text-ft.ckpt?download=1",
         | 
| 151 | 
            +
                    },
         | 
| 152 | 
            +
                    "audioldm-m-full": {
         | 
| 153 | 
            +
                        "path": os.path.join(
         | 
| 154 | 
            +
                            CACHE_DIR,
         | 
| 155 | 
            +
                            "audioldm-m-full.ckpt",
         | 
| 156 | 
            +
                        ),
         | 
| 157 | 
            +
                        "url": "https://zenodo.org/record/7813012/files/audioldm-m-full.ckpt?download=1",
         | 
| 158 | 
            +
                    },
         | 
| 159 | 
            +
                }
         | 
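
Example (not part of the commit): one way the metadata table might be used to locate a checkpoint and fetch it on a cache miss; resolve_checkpoint is a hypothetical helper and the download step assumes network access.

import os
import urllib.request
from audioldm.utils import get_metadata

def resolve_checkpoint(model_name="audioldm-s-full"):
    # Look up the cache path and download URL recorded in get_metadata().
    entry = get_metadata()[model_name]
    path, url = entry["path"], entry["url"]
    if not os.path.exists(path):
        os.makedirs(os.path.dirname(path), exist_ok=True)
        urllib.request.urlretrieve(url, path)  # blocking download into the cache dir
    return path
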
| 160 | 
            +
             | 
| 161 | 
            +
             | 
| 162 | 
            +
            def get_obj_from_str(string, reload=False):
         | 
| 163 | 
            +
                module, cls = string.rsplit(".", 1)
         | 
| 164 | 
            +
                if reload:
         | 
| 165 | 
            +
                    module_imp = importlib.import_module(module)
         | 
| 166 | 
            +
                    importlib.reload(module_imp)
         | 
| 167 | 
            +
                return getattr(importlib.import_module(module, package=None), cls)
         | 
| 168 | 
            +
             | 
| 169 | 
            +
             | 
| 170 | 
            +
            def instantiate_from_config(config):
         | 
| 171 | 
            +
                if not "target" in config:
         | 
| 172 | 
            +
                    if config == "__is_first_stage__":
         | 
| 173 | 
            +
                        return None
         | 
| 174 | 
            +
                    elif config == "__is_unconditional__":
         | 
| 175 | 
            +
                        return None
         | 
| 176 | 
            +
                    raise KeyError("Expected key `target` to instantiate.")
         | 
| 177 | 
            +
                return get_obj_from_str(config["target"])(**config.get("params", dict()))
         | 
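
Example (not part of the commit): instantiate_from_config resolves a dotted import path and forwards "params" as keyword arguments, so any importable class works as the target; a plain torch layer is used here purely for illustration.

from audioldm.utils import instantiate_from_config

cfg = {"target": "torch.nn.Linear", "params": {"in_features": 16, "out_features": 4}}
layer = instantiate_from_config(cfg)  # equivalent to torch.nn.Linear(16, 4)
print(layer)
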
    	
        audioldm/variational_autoencoder/__init__.py
    ADDED
    
    | @@ -0,0 +1 @@ | |
| 1 | 
            +
            from .autoencoder import AutoencoderKL
         | 
    	
        audioldm/variational_autoencoder/autoencoder.py
    ADDED
    
    | @@ -0,0 +1,131 @@ | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            from torch import nn
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            from audioldm.variational_autoencoder.modules import Encoder, Decoder
         | 
| 5 | 
            +
            from audioldm.variational_autoencoder.distributions import DiagonalGaussianDistribution
         | 
| 6 | 
            +
            from audioldm.hifigan.utilities import get_vocoder, vocoder_infer
         | 
| 7 | 
            +
             | 
| 8 | 
            +
             | 
| 9 | 
            +
            class AutoencoderKL(nn.Module):
         | 
| 10 | 
            +
                def __init__(
         | 
| 11 | 
            +
                    self,
         | 
| 12 | 
            +
                    ddconfig=None,
         | 
| 13 | 
            +
                    lossconfig=None,
         | 
| 14 | 
            +
                    image_key="fbank",
         | 
| 15 | 
            +
                    embed_dim=None,
         | 
| 16 | 
            +
                    time_shuffle=1,
         | 
| 17 | 
            +
                    subband=1,
         | 
| 18 | 
            +
                    ckpt_path=None,
         | 
| 19 | 
            +
                    reload_from_ckpt=None,
         | 
| 20 | 
            +
                    ignore_keys=[],
         | 
| 21 | 
            +
                    colorize_nlabels=None,
         | 
| 22 | 
            +
                    monitor=None,
         | 
| 23 | 
            +
                    base_learning_rate=1e-5,
         | 
| 24 | 
            +
                    scale_factor=1
         | 
| 25 | 
            +
                ):
         | 
| 26 | 
            +
                    super().__init__()
         | 
| 27 | 
            +
             | 
| 28 | 
            +
                    self.encoder = Encoder(**ddconfig)
         | 
| 29 | 
            +
                    self.decoder = Decoder(**ddconfig)
         | 
| 30 | 
            +
                    self.ema_decoder = None
         | 
| 31 | 
            +
             | 
| 32 | 
            +
                    self.subband = int(subband)
         | 
| 33 | 
            +
                    if self.subband > 1:
         | 
| 34 | 
            +
                        print("Use subband decomposition %s" % self.subband)
         | 
| 35 | 
            +
             | 
| 36 | 
            +
                    self.quant_conv = nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
         | 
| 37 | 
            +
                    self.post_quant_conv = nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
         | 
| 38 | 
            +
                    self.ema_post_quant_conv = None
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                    self.vocoder = get_vocoder(None, "cpu")
         | 
| 41 | 
            +
        self.embed_dim = embed_dim
        self.image_key = image_key  # referenced by freq_split_subband / freq_merge_subband
         | 
| 42 | 
            +
             | 
| 43 | 
            +
                    if monitor is not None:
         | 
| 44 | 
            +
                        self.monitor = monitor
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                    self.time_shuffle = time_shuffle
         | 
| 47 | 
            +
                    self.reload_from_ckpt = reload_from_ckpt
         | 
| 48 | 
            +
                    self.reloaded = False
         | 
| 49 | 
            +
        self.mean, self.std = None, None
        self.flag_first_run = True  # lets forward() print the latent size once
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                    self.scale_factor = scale_factor
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                @property
         | 
| 54 | 
            +
                def device(self):
         | 
| 55 | 
            +
                    return next(self.parameters()).device
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                def freq_split_subband(self, fbank):
         | 
| 58 | 
            +
                    if self.subband == 1 or self.image_key != "stft":
         | 
| 59 | 
            +
                        return fbank
         | 
| 60 | 
            +
             | 
| 61 | 
            +
                    bs, ch, tstep, fbins = fbank.size()
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                    assert fbank.size(-1) % self.subband == 0
         | 
| 64 | 
            +
                    assert ch == 1
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                    return (
         | 
| 67 | 
            +
                        fbank.squeeze(1)
         | 
| 68 | 
            +
                        .reshape(bs, tstep, self.subband, fbins // self.subband)
         | 
| 69 | 
            +
                        .permute(0, 2, 1, 3)
         | 
| 70 | 
            +
                    )
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                def freq_merge_subband(self, subband_fbank):
         | 
| 73 | 
            +
                    if self.subband == 1 or self.image_key != "stft":
         | 
| 74 | 
            +
                        return subband_fbank
         | 
| 75 | 
            +
                    assert subband_fbank.size(1) == self.subband  # Channel dimension
         | 
| 76 | 
            +
                    bs, sub_ch, tstep, fbins = subband_fbank.size()
         | 
| 77 | 
            +
                    return subband_fbank.permute(0, 2, 1, 3).reshape(bs, tstep, -1).unsqueeze(1)
         | 
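
Example (not part of the commit): with the default subband=1 (as in the configs above) the two helpers are no-ops, so the sketch below reproduces the same permute/reshape round trip on a dummy mel tensor for a hypothetical subband=4 to show the intended shapes.

import torch

bs, tstep, fbins, subband = 2, 1024, 64, 4     # subband=4 is hypothetical
fbank = torch.randn(bs, 1, tstep, fbins)       # (B, 1, T, F)

split = (fbank.squeeze(1)
              .reshape(bs, tstep, subband, fbins // subband)
              .permute(0, 2, 1, 3))            # (B, subband, T, F // subband)
merged = split.permute(0, 2, 1, 3).reshape(bs, tstep, -1).unsqueeze(1)

assert merged.shape == fbank.shape and torch.equal(merged, fbank)
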
| 78 | 
            +
             | 
| 79 | 
            +
                def encode(self, x):
         | 
| 80 | 
            +
                    x = self.freq_split_subband(x)
         | 
| 81 | 
            +
                    h = self.encoder(x)
         | 
| 82 | 
            +
                    moments = self.quant_conv(h)
         | 
| 83 | 
            +
                    posterior = DiagonalGaussianDistribution(moments)
         | 
| 84 | 
            +
                    return posterior
         | 
| 85 | 
            +
             | 
| 86 | 
            +
                @torch.no_grad()
         | 
| 87 | 
            +
                def encode_first_stage(self, x):
         | 
| 88 | 
            +
                    return self.encode(x)
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                def decode(self, z, use_ema=False):
         | 
| 91 | 
            +
                    if use_ema and (not hasattr(self, 'ema_decoder') or self.ema_decoder is None):
         | 
| 92 | 
            +
                        print("VAE does not have EMA modules, but specified use_ema. "
         | 
| 93 | 
            +
                              "Using the none-EMA modules instead.")
         | 
| 94 | 
            +
                    if use_ema and hasattr(self, 'ema_decoder') and self.ema_decoder is not None:
         | 
| 95 | 
            +
                        z = self.ema_post_quant_conv(z)
         | 
| 96 | 
            +
                        dec = self.ema_decoder(z)
         | 
| 97 | 
            +
                    else:
         | 
| 98 | 
            +
                        z = self.post_quant_conv(z)
         | 
| 99 | 
            +
                        dec = self.decoder(z)
         | 
| 100 | 
            +
                    return self.freq_merge_subband(dec)
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                def decode_first_stage(self, z, allow_grad=False, use_ema=False):
         | 
| 103 | 
            +
                    with torch.set_grad_enabled(allow_grad):
         | 
| 104 | 
            +
                        z = z / self.scale_factor
         | 
| 105 | 
            +
                        return self.decode(z, use_ema)
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                def decode_to_waveform(self, dec, allow_grad=False):
         | 
| 108 | 
            +
                    dec = dec.squeeze(1).permute(0, 2, 1)
         | 
| 109 | 
            +
                    wav_reconstruction = vocoder_infer(dec, self.vocoder, allow_grad=allow_grad)
         | 
| 110 | 
            +
                    return wav_reconstruction
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                def forward(self, input, sample_posterior=True):
         | 
| 113 | 
            +
                    posterior = self.encode(input)
         | 
| 114 | 
            +
                    z = posterior.sample() if sample_posterior else posterior.mode()
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                    if self.flag_first_run:
         | 
| 117 | 
            +
                        print("Latent size: ", z.size())
         | 
| 118 | 
            +
                        self.flag_first_run = False
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                    return self.decode(z), posterior
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                def get_first_stage_encoding(self, encoder_posterior):
         | 
| 123 | 
            +
                    if isinstance(encoder_posterior, DiagonalGaussianDistribution):
         | 
| 124 | 
            +
                        z = encoder_posterior.sample()
         | 
| 125 | 
            +
                    elif isinstance(encoder_posterior, torch.Tensor):
         | 
| 126 | 
            +
                        z = encoder_posterior
         | 
| 127 | 
            +
                    else:
         | 
| 128 | 
            +
                        raise NotImplementedError(
         | 
| 129 | 
            +
                            f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
         | 
| 130 | 
            +
                        )
         | 
| 131 | 
            +
                    return self.scale_factor * z
         | 
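
Example (not part of the commit): get_first_stage_encoding multiplies the sampled latent by scale_factor, and decode_first_stage divides it back out, so the two calls are symmetric. The sketch below illustrates that convention with a dummy posterior instead of a real encoder output; the scale_factor value is made up.

import torch
from audioldm.variational_autoencoder.distributions import DiagonalGaussianDistribution

scale_factor = 0.5                            # illustrative only; the real value comes from the model config
moments = torch.randn(1, 2 * 8, 256, 16)      # fake encoder output: mean and logvar stacked
posterior = DiagonalGaussianDistribution(moments)

z = scale_factor * posterior.sample()         # what get_first_stage_encoding returns
z_for_decoder = z / scale_factor              # what decode_first_stage hands to decode()
print(z.shape, z_for_decoder.shape)           # both torch.Size([1, 8, 256, 16])
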
    	
        audioldm/variational_autoencoder/distributions.py
    ADDED
    
    | @@ -0,0 +1,102 @@ | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import numpy as np
         | 
| 3 | 
            +
             | 
| 4 | 
            +
             | 
| 5 | 
            +
            class AbstractDistribution:
         | 
| 6 | 
            +
                def sample(self):
         | 
| 7 | 
            +
                    raise NotImplementedError()
         | 
| 8 | 
            +
             | 
| 9 | 
            +
                def mode(self):
         | 
| 10 | 
            +
                    raise NotImplementedError()
         | 
| 11 | 
            +
             | 
| 12 | 
            +
             | 
| 13 | 
            +
            class DiracDistribution(AbstractDistribution):
         | 
| 14 | 
            +
                def __init__(self, value):
         | 
| 15 | 
            +
                    self.value = value
         | 
| 16 | 
            +
             | 
| 17 | 
            +
                def sample(self):
         | 
| 18 | 
            +
                    return self.value
         | 
| 19 | 
            +
             | 
| 20 | 
            +
                def mode(self):
         | 
| 21 | 
            +
                    return self.value
         | 
| 22 | 
            +
             | 
| 23 | 
            +
             | 
| 24 | 
            +
            class DiagonalGaussianDistribution(object):
         | 
| 25 | 
            +
                def __init__(self, parameters, deterministic=False):
         | 
| 26 | 
            +
                    self.parameters = parameters
         | 
| 27 | 
            +
                    self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
         | 
| 28 | 
            +
                    self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
         | 
| 29 | 
            +
                    self.deterministic = deterministic
         | 
| 30 | 
            +
                    self.std = torch.exp(0.5 * self.logvar)
         | 
| 31 | 
            +
                    self.var = torch.exp(self.logvar)
         | 
| 32 | 
            +
                    if self.deterministic:
         | 
| 33 | 
            +
                        self.var = self.std = torch.zeros_like(self.mean).to(
         | 
| 34 | 
            +
                            device=self.parameters.device
         | 
| 35 | 
            +
                        )
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                def sample(self):
         | 
| 38 | 
            +
                    x = self.mean + self.std * torch.randn(self.mean.shape).to(
         | 
| 39 | 
            +
                        device=self.parameters.device
         | 
| 40 | 
            +
                    )
         | 
| 41 | 
            +
                    return x
         | 
| 42 | 
            +
             | 
| 43 | 
            +
                def kl(self, other=None):
         | 
| 44 | 
            +
                    if self.deterministic:
         | 
| 45 | 
            +
                        return torch.Tensor([0.0])
         | 
| 46 | 
            +
                    else:
         | 
| 47 | 
            +
                        if other is None:
         | 
| 48 | 
            +
                            return 0.5 * torch.mean(
         | 
| 49 | 
            +
                                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
         | 
| 50 | 
            +
                                dim=[1, 2, 3],
         | 
| 51 | 
            +
                            )
         | 
| 52 | 
            +
                        else:
         | 
| 53 | 
            +
                            return 0.5 * torch.mean(
         | 
| 54 | 
            +
                                torch.pow(self.mean - other.mean, 2) / other.var
         | 
| 55 | 
            +
                                + self.var / other.var
         | 
| 56 | 
            +
                                - 1.0
         | 
| 57 | 
            +
                                - self.logvar
         | 
| 58 | 
            +
                                + other.logvar,
         | 
| 59 | 
            +
                                dim=[1, 2, 3],
         | 
| 60 | 
            +
                            )
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                def nll(self, sample, dims=[1, 2, 3]):
         | 
| 63 | 
            +
                    if self.deterministic:
         | 
| 64 | 
            +
                        return torch.Tensor([0.0])
         | 
| 65 | 
            +
                    logtwopi = np.log(2.0 * np.pi)
         | 
| 66 | 
            +
                    return 0.5 * torch.sum(
         | 
| 67 | 
            +
                        logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
         | 
| 68 | 
            +
                        dim=dims,
         | 
| 69 | 
            +
                    )
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                def mode(self):
         | 
| 72 | 
            +
                    return self.mean
         | 
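
Example (not part of the commit): a quick self-contained check of the class above. Parameters are (mean, logvar) stacked along the channel axis, and a zero-mean, zero-log-variance posterior has zero KL against the standard normal prior.

import torch
from audioldm.variational_autoencoder.distributions import DiagonalGaussianDistribution

params = torch.zeros(2, 8, 4, 4)     # first 4 channels = mean, last 4 = logvar
post = DiagonalGaussianDistribution(params)

print(post.kl())                     # tensor([0., 0.])
print(post.sample().shape)           # torch.Size([2, 4, 4, 4])
print(post.nll(post.mode()).shape)   # per-example NLL, torch.Size([2])
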
| 73 | 
            +
             | 
| 74 | 
            +
             | 
| 75 | 
            +
            def normal_kl(mean1, logvar1, mean2, logvar2):
         | 
| 76 | 
            +
                """
         | 
| 77 | 
            +
                source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
         | 
| 78 | 
            +
                Compute the KL divergence between two gaussians.
         | 
| 79 | 
            +
                Shapes are automatically broadcasted, so batches can be compared to
         | 
| 80 | 
            +
                scalars, among other use cases.
         | 
| 81 | 
            +
                """
         | 
| 82 | 
            +
                tensor = None
         | 
| 83 | 
            +
                for obj in (mean1, logvar1, mean2, logvar2):
         | 
| 84 | 
            +
                    if isinstance(obj, torch.Tensor):
         | 
| 85 | 
            +
                        tensor = obj
         | 
| 86 | 
            +
                        break
         | 
| 87 | 
            +
                assert tensor is not None, "at least one argument must be a Tensor"
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                # Force variances to be Tensors. Broadcasting helps convert scalars to
         | 
| 90 | 
            +
                # Tensors, but it does not work for torch.exp().
         | 
| 91 | 
            +
                logvar1, logvar2 = [
         | 
| 92 | 
            +
                    x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
         | 
| 93 | 
            +
                    for x in (logvar1, logvar2)
         | 
| 94 | 
            +
                ]
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                return 0.5 * (
         | 
| 97 | 
            +
                    -1.0
         | 
| 98 | 
            +
                    + logvar2
         | 
| 99 | 
            +
                    - logvar1
         | 
| 100 | 
            +
                    + torch.exp(logvar1 - logvar2)
         | 
| 101 | 
            +
                    + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
         | 
| 102 | 
            +
                )
         | 
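
Example (not part of the commit): normal_kl broadcasts Python scalars against tensors, which makes quick sanity checks easy; two identical Gaussians have zero divergence everywhere.

import torch
from audioldm.variational_autoencoder.distributions import normal_kl

mean, logvar = torch.randn(3, 5), torch.randn(3, 5)

kl_same = normal_kl(mean, logvar, mean, logvar)      # elementwise zeros
kl_to_std = normal_kl(mean, logvar, 0.0, 0.0)        # scalar mean/logvar broadcast fine
print(torch.allclose(kl_same, torch.zeros_like(kl_same)))  # True
print(kl_to_std.shape)                                     # torch.Size([3, 5])
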
    	
        audioldm/variational_autoencoder/modules.py
    ADDED
    
    | @@ -0,0 +1,1067 @@ | |
| 1 | 
            +
            # pytorch_diffusion + derived encoder decoder
         | 
| 2 | 
            +
            import math
         | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            import torch.nn as nn
         | 
| 5 | 
            +
            import numpy as np
         | 
| 6 | 
            +
            from einops import rearrange
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            from audioldm.utils import instantiate_from_config
         | 
| 9 | 
            +
            from audioldm.latent_diffusion.attention import LinearAttention
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            def get_timestep_embedding(timesteps, embedding_dim):
         | 
| 13 | 
            +
                """
         | 
| 14 | 
            +
    Build sinusoidal timestep embeddings.
         | 
| 15 | 
            +
    This matches the implementation in Denoising Diffusion Probabilistic Models
         | 
| 16 | 
            +
    (originally ported from Fairseq).
         | 
| 17 | 
            +
                This matches the implementation in tensor2tensor, but differs slightly
         | 
| 18 | 
            +
                from the description in Section 3.5 of "Attention Is All You Need".
         | 
| 19 | 
            +
                """
         | 
| 20 | 
            +
                assert len(timesteps.shape) == 1
         | 
| 21 | 
            +
             | 
| 22 | 
            +
                half_dim = embedding_dim // 2
         | 
| 23 | 
            +
                emb = math.log(10000) / (half_dim - 1)
         | 
| 24 | 
            +
                emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
         | 
| 25 | 
            +
                emb = emb.to(device=timesteps.device)
         | 
| 26 | 
            +
                emb = timesteps.float()[:, None] * emb[None, :]
         | 
| 27 | 
            +
                emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
         | 
| 28 | 
            +
                if embedding_dim % 2 == 1:  # zero pad
         | 
| 29 | 
            +
                    emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
         | 
| 30 | 
            +
                return emb
         | 
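
Example (not part of the commit): get_timestep_embedding maps a 1-D batch of integer timesteps to sinusoidal vectors of length embedding_dim. The import path assumes this file sits at audioldm/variational_autoencoder/modules.py, as in this repository.

import torch
from audioldm.variational_autoencoder.modules import get_timestep_embedding

t = torch.tensor([0, 10, 500, 999])          # one timestep per batch element
emb = get_timestep_embedding(t, embedding_dim=128)
print(emb.shape)                             # torch.Size([4, 128])
print(float(emb.abs().max()) <= 1.0)         # True: entries are sin/cos values
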
| 31 | 
            +
             | 
| 32 | 
            +
             | 
| 33 | 
            +
            def nonlinearity(x):
         | 
| 34 | 
            +
                # swish
         | 
| 35 | 
            +
                return x * torch.sigmoid(x)
         | 
| 36 | 
            +
             | 
| 37 | 
            +
             | 
| 38 | 
            +
            def Normalize(in_channels, num_groups=32):
         | 
| 39 | 
            +
                return torch.nn.GroupNorm(
         | 
| 40 | 
            +
                    num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
         | 
| 41 | 
            +
                )
         | 
| 42 | 
            +
             | 
| 43 | 
            +
             | 
| 44 | 
            +
            class Upsample(nn.Module):
         | 
| 45 | 
            +
                def __init__(self, in_channels, with_conv):
         | 
| 46 | 
            +
                    super().__init__()
         | 
| 47 | 
            +
                    self.with_conv = with_conv
         | 
| 48 | 
            +
                    if self.with_conv:
         | 
| 49 | 
            +
                        self.conv = torch.nn.Conv2d(
         | 
| 50 | 
            +
                            in_channels, in_channels, kernel_size=3, stride=1, padding=1
         | 
| 51 | 
            +
                        )
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                def forward(self, x):
         | 
| 54 | 
            +
                    x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
         | 
| 55 | 
            +
                    if self.with_conv:
         | 
| 56 | 
            +
                        x = self.conv(x)
         | 
| 57 | 
            +
                    return x
         | 
| 58 | 
            +
             | 
| 59 | 
            +
             | 
| 60 | 
            +
            class UpsampleTimeStride4(nn.Module):
         | 
| 61 | 
            +
                def __init__(self, in_channels, with_conv):
         | 
| 62 | 
            +
                    super().__init__()
         | 
| 63 | 
            +
                    self.with_conv = with_conv
         | 
| 64 | 
            +
                    if self.with_conv:
         | 
| 65 | 
            +
                        self.conv = torch.nn.Conv2d(
         | 
| 66 | 
            +
                            in_channels, in_channels, kernel_size=5, stride=1, padding=2
         | 
| 67 | 
            +
                        )
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                def forward(self, x):
         | 
| 70 | 
            +
                    x = torch.nn.functional.interpolate(x, scale_factor=(4.0, 2.0), mode="nearest")
         | 
| 71 | 
            +
                    if self.with_conv:
         | 
| 72 | 
            +
                        x = self.conv(x)
         | 
| 73 | 
            +
                    return x
         | 
| 74 | 
            +
             | 
| 75 | 
            +
             | 
| 76 | 
            +
            class Downsample(nn.Module):
         | 
| 77 | 
            +
                def __init__(self, in_channels, with_conv):
         | 
| 78 | 
            +
                    super().__init__()
         | 
| 79 | 
            +
                    self.with_conv = with_conv
         | 
| 80 | 
            +
                    if self.with_conv:
         | 
| 81 | 
            +
                        # Do time downsampling here
         | 
| 82 | 
            +
                        # no asymmetric padding in torch conv, must do it ourselves
         | 
| 83 | 
            +
                        self.conv = torch.nn.Conv2d(
         | 
| 84 | 
            +
                            in_channels, in_channels, kernel_size=3, stride=2, padding=0
         | 
| 85 | 
            +
                        )
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                def forward(self, x):
         | 
| 88 | 
            +
                    if self.with_conv:
         | 
| 89 | 
            +
                        pad = (0, 1, 0, 1)
         | 
| 90 | 
            +
                        x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
         | 
| 91 | 
            +
                        x = self.conv(x)
         | 
| 92 | 
            +
                    else:
         | 
| 93 | 
            +
                        x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
         | 
| 94 | 
            +
                    return x
         | 
| 95 | 
            +
             | 
| 96 | 
            +
             | 
| 97 | 
            +
            class DownsampleTimeStride4(nn.Module):
         | 
| 98 | 
            +
                def __init__(self, in_channels, with_conv):
         | 
| 99 | 
            +
                    super().__init__()
         | 
| 100 | 
            +
                    self.with_conv = with_conv
         | 
| 101 | 
            +
                    if self.with_conv:
         | 
| 102 | 
            +
                        # Do time downsampling here
         | 
| 103 | 
            +
                        # no asymmetric padding in torch conv, must do it ourselves
         | 
| 104 | 
            +
                        self.conv = torch.nn.Conv2d(
         | 
| 105 | 
            +
                            in_channels, in_channels, kernel_size=5, stride=(4, 2), padding=1
         | 
| 106 | 
            +
                        )
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                def forward(self, x):
         | 
| 109 | 
            +
                    if self.with_conv:
         | 
| 110 | 
            +
                        pad = (0, 1, 0, 1)
         | 
| 111 | 
            +
                        x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
         | 
| 112 | 
            +
                        x = self.conv(x)
         | 
| 113 | 
            +
                    else:
         | 
| 114 | 
            +
                        x = torch.nn.functional.avg_pool2d(x, kernel_size=(4, 2), stride=(4, 2))
         | 
| 115 | 
            +
                    return x
         | 
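
Example (not part of the commit): the *TimeStride4 pair above shrinks time by 4x and frequency by 2x, and the matching upsample undoes those factors. A shape sketch with random input:

import torch
from audioldm.variational_autoencoder.modules import (
    DownsampleTimeStride4, UpsampleTimeStride4)

x = torch.randn(1, 16, 256, 32)              # (B, C, T, F) mel-like feature map
down = DownsampleTimeStride4(16, with_conv=True)
up = UpsampleTimeStride4(16, with_conv=True)

y = down(x)
print(y.shape)                               # torch.Size([1, 16, 64, 16])  -> T/4, F/2
print(up(y).shape)                           # torch.Size([1, 16, 256, 32]) -> restored
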
| 116 | 
            +
             | 
| 117 | 
            +
             | 
| 118 | 
            +
            class ResnetBlock(nn.Module):
         | 
| 119 | 
            +
                def __init__(
         | 
| 120 | 
            +
                    self,
         | 
| 121 | 
            +
                    *,
         | 
| 122 | 
            +
                    in_channels,
         | 
| 123 | 
            +
                    out_channels=None,
         | 
| 124 | 
            +
                    conv_shortcut=False,
         | 
| 125 | 
            +
                    dropout,
         | 
| 126 | 
            +
                    temb_channels=512,
         | 
| 127 | 
            +
                ):
         | 
| 128 | 
            +
                    super().__init__()
         | 
| 129 | 
            +
                    self.in_channels = in_channels
         | 
| 130 | 
            +
                    out_channels = in_channels if out_channels is None else out_channels
         | 
| 131 | 
            +
                    self.out_channels = out_channels
         | 
| 132 | 
            +
                    self.use_conv_shortcut = conv_shortcut
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                    self.norm1 = Normalize(in_channels)
         | 
| 135 | 
            +
                    self.conv1 = torch.nn.Conv2d(
         | 
| 136 | 
            +
                        in_channels, out_channels, kernel_size=3, stride=1, padding=1
         | 
| 137 | 
            +
                    )
         | 
| 138 | 
            +
                    if temb_channels > 0:
         | 
| 139 | 
            +
                        self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
         | 
| 140 | 
            +
                    self.norm2 = Normalize(out_channels)
         | 
| 141 | 
            +
                    self.dropout = torch.nn.Dropout(dropout)
         | 
| 142 | 
            +
                    self.conv2 = torch.nn.Conv2d(
         | 
| 143 | 
            +
                        out_channels, out_channels, kernel_size=3, stride=1, padding=1
         | 
| 144 | 
            +
                    )
         | 
| 145 | 
            +
                    if self.in_channels != self.out_channels:
         | 
| 146 | 
            +
                        if self.use_conv_shortcut:
         | 
| 147 | 
            +
                            self.conv_shortcut = torch.nn.Conv2d(
         | 
| 148 | 
            +
                                in_channels, out_channels, kernel_size=3, stride=1, padding=1
         | 
| 149 | 
            +
                            )
         | 
| 150 | 
            +
                        else:
         | 
| 151 | 
            +
                            self.nin_shortcut = torch.nn.Conv2d(
         | 
| 152 | 
            +
                                in_channels, out_channels, kernel_size=1, stride=1, padding=0
         | 
| 153 | 
            +
                            )
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                def forward(self, x, temb):
         | 
| 156 | 
            +
                    h = x
         | 
| 157 | 
            +
                    h = self.norm1(h)
         | 
| 158 | 
            +
                    h = nonlinearity(h)
         | 
| 159 | 
            +
                    h = self.conv1(h)
         | 
| 160 | 
            +
             | 
| 161 | 
            +
                    if temb is not None:
         | 
| 162 | 
            +
                        h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
         | 
| 163 | 
            +
             | 
| 164 | 
            +
                    h = self.norm2(h)
         | 
| 165 | 
            +
                    h = nonlinearity(h)
         | 
| 166 | 
            +
                    h = self.dropout(h)
         | 
| 167 | 
            +
                    h = self.conv2(h)
         | 
| 168 | 
            +
             | 
| 169 | 
            +
                    if self.in_channels != self.out_channels:
         | 
| 170 | 
            +
                        if self.use_conv_shortcut:
         | 
| 171 | 
            +
                            x = self.conv_shortcut(x)
         | 
| 172 | 
            +
                        else:
         | 
| 173 | 
            +
                            x = self.nin_shortcut(x)
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                    return x + h
         | 
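
Example (not part of the commit): ResnetBlock preserves the spatial size and only changes the channel count (via a 1x1 shortcut when in/out channels differ); passing temb_channels=0 and temb=None disables the timestep projection.

import torch
from audioldm.variational_autoencoder.modules import ResnetBlock

block = ResnetBlock(in_channels=64, out_channels=128, dropout=0.0, temb_channels=0)
x = torch.randn(2, 64, 32, 8)
print(block(x, temb=None).shape)             # torch.Size([2, 128, 32, 8])
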
| 176 | 
            +
             | 
| 177 | 
            +
             | 
| 178 | 
            +
            class LinAttnBlock(LinearAttention):
         | 
| 179 | 
            +
                """to match AttnBlock usage"""
         | 
| 180 | 
            +
             | 
| 181 | 
            +
                def __init__(self, in_channels):
         | 
| 182 | 
            +
                    super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
         | 
| 183 | 
            +
             | 
| 184 | 
            +
             | 
| 185 | 
            +
            class AttnBlock(nn.Module):
         | 
| 186 | 
            +
                def __init__(self, in_channels):
         | 
| 187 | 
            +
                    super().__init__()
         | 
| 188 | 
            +
                    self.in_channels = in_channels
         | 
| 189 | 
            +
             | 
| 190 | 
            +
                    self.norm = Normalize(in_channels)
         | 
| 191 | 
            +
                    self.q = torch.nn.Conv2d(
         | 
| 192 | 
            +
                        in_channels, in_channels, kernel_size=1, stride=1, padding=0
         | 
| 193 | 
            +
                    )
         | 
| 194 | 
            +
                    self.k = torch.nn.Conv2d(
         | 
| 195 | 
            +
                        in_channels, in_channels, kernel_size=1, stride=1, padding=0
         | 
| 196 | 
        )
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w).contiguous()
        q = q.permute(0, 2, 1).contiguous()  # b,hw,c
        k = k.reshape(b, c, h * w).contiguous()  # b,c,hw
        w_ = torch.bmm(q, k).contiguous()  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w).contiguous()
        w_ = w_.permute(0, 2, 1).contiguous()  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(
            v, w_
        ).contiguous()  # b,c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w).contiguous()

        h_ = self.proj_out(h_)

        return x + h_


def make_attn(in_channels, attn_type="vanilla"):
    assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"
    # print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
        return AttnBlock(in_channels)
    elif attn_type == "none":
        return nn.Identity(in_channels)
    else:
        return LinAttnBlock(in_channels)

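# Usage sketch for AttnBlock / make_attn (illustrative only; the channel count and
# feature-map size below are assumptions, not values used elsewhere in this repo):
#
#     attn = make_attn(64, attn_type="vanilla")   # -> AttnBlock(64)
#     x = torch.randn(2, 64, 8, 8)
#     y = attn(x)                                 # residual output, shape (2, 64, 8, 8)
#
# The attention matrix has shape (b, h*w, h*w), so memory grows quadratically with the
# spatial size; "none" returns nn.Identity and "linear" returns LinAttnBlock.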
class Model(nn.Module):
    # UNet-style backbone: a downsampling path with skip connections, a middle block
    # with attention, and an upsampling path, optionally conditioned on a timestep
    # embedding (use_timestep) and on a channel-concatenated context.
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        use_timestep=True,
        use_linear_attn=False,
        attn_type="vanilla",
    ):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = self.ch * 4
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        self.use_timestep = use_timestep
        if self.use_timestep:
            # timestep embedding
            self.temb = nn.Module()
            self.temb.dense = nn.ModuleList(
                [
                    torch.nn.Linear(self.ch, self.temb_ch),
                    torch.nn.Linear(self.temb_ch, self.temb_ch),
                ]
            )

        # downsampling
        self.conv_in = torch.nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1
        )

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            skip_in = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                if i_block == self.num_res_blocks:
                    skip_in = ch * in_ch_mult[i_level]
                block.append(
                    ResnetBlock(
                        in_channels=block_in + skip_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_ch, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x, t=None, context=None):
        # assert x.shape[2] == x.shape[3] == self.resolution
        if context is not None:
            # assume aligned context, cat along channel axis
            x = torch.cat((x, context), dim=1)
        if self.use_timestep:
            # timestep embedding
            assert t is not None
            temb = get_timestep_embedding(t, self.ch)
            temb = self.temb.dense[0](temb)
            temb = nonlinearity(temb)
            temb = self.temb.dense[1](temb)
        else:
            temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](
                    torch.cat([h, hs.pop()], dim=1), temb
                )
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h

    def get_last_layer(self):
        return self.conv_out.weight

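# Usage sketch for Model (illustrative only; all hyper-parameters below are assumed
# for the example and are not taken from any config in this repo):
#
#     model = Model(ch=64, out_ch=3, ch_mult=(1, 2), num_res_blocks=1,
#                   attn_resolutions=[], in_channels=3, resolution=32)
#     x = torch.randn(1, 3, 32, 32)
#     t = torch.tensor([10])      # one integer timestep per batch item
#     y = model(x, t)             # same spatial size, out_ch channels: (1, 3, 32, 32)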
class Encoder(nn.Module):
    # Convolutional encoder: conv_in, then ResNet (and optional attention) stages with
    # Downsample (or DownsampleTimeStride4) between levels, a middle block, and a
    # conv_out head that emits 2 * z_channels channels when double_z=True.
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        double_z=True,
        use_linear_attn=False,
        attn_type="vanilla",
        downsample_time_stride4_levels=[],
        **ignore_kwargs,
    ):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.downsample_time_stride4_levels = downsample_time_stride4_levels

        if len(self.downsample_time_stride4_levels) > 0:
            assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
                "The levels at which the stride-4 downsampling is performed must be "
                "smaller than the total number of resolutions (%s)"
                % str(self.num_resolutions)
            )

        # downsampling
        self.conv_in = torch.nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1
        )

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                if i_level in self.downsample_time_stride4_levels:
                    down.downsample = DownsampleTimeStride4(block_in, resamp_with_conv)
                else:
                    down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        # timestep embedding
        temb = None
        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h

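# Usage sketch for Encoder (illustrative only; the arguments are assumptions chosen to
# keep the example small, not the settings used by the checkpoints in this Space):
#
#     enc = Encoder(ch=64, out_ch=None, ch_mult=(1, 2, 4), num_res_blocks=1,
#                   attn_resolutions=[], in_channels=1, resolution=256, z_channels=4)
#     h = enc(torch.randn(2, 1, 256, 16))   # -> (2, 8, 64, 4)
#
# With double_z=True (the default), the 2 * z_channels output channels are intended to
# be split into a posterior mean and log-variance downstream.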
class Decoder(nn.Module):
    # Convolutional decoder: conv_in on the latent, a middle block, upsampling ResNet
    # (and optional attention) stages, and a conv_out head (optionally followed by tanh).
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        give_pre_end=False,
        tanh_out=False,
        use_linear_attn=False,
        downsample_time_stride4_levels=[],
        attn_type="vanilla",
        **ignorekwargs,
    ):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out
        self.downsample_time_stride4_levels = downsample_time_stride4_levels

        if len(self.downsample_time_stride4_levels) > 0:
            assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
                "The levels at which the stride-4 downsampling is performed must be "
                "smaller than the total number of resolutions (%s)"
                % str(self.num_resolutions)
            )

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,) + tuple(ch_mult)
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        # print("Working with z of shape {} = {} dimensions.".format(
        # self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1
        )

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                if i_level - 1 in self.downsample_time_stride4_levels:
                    up.upsample = UpsampleTimeStride4(block_in, resamp_with_conv)
                else:
                    up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_ch, kernel_size=3, stride=1, padding=1
        )

    def forward(self, z):
        # assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h.float(), temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h.float())
            if i_level != 0:
                h = self.up[i_level].upsample(h.float())

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h

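# Usage sketch for Decoder (illustrative only; arguments mirror the Encoder sketch
# above and are assumptions, not this project's configuration):
#
#     dec = Decoder(ch=64, out_ch=1, ch_mult=(1, 2, 4), num_res_blocks=1,
#                   attn_resolutions=[], in_channels=None, resolution=256, z_channels=4)
#     x_rec = dec(torch.randn(2, 4, 64, 4))   # -> (2, 1, 256, 16)
#
# in_channels is only stored, not used to build layers; with the default Upsample
# blocks the spatial size grows by 2 ** (len(ch_mult) - 1) along both axes.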
class SimpleDecoder(nn.Module):
    def __init__(self, in_channels, out_channels, *args, **kwargs):
        super().__init__()
        self.model = nn.ModuleList(
            [
                nn.Conv2d(in_channels, in_channels, 1),
                ResnetBlock(
                    in_channels=in_channels,
                    out_channels=2 * in_channels,
                    temb_channels=0,
                    dropout=0.0,
                ),
                ResnetBlock(
                    in_channels=2 * in_channels,
                    out_channels=4 * in_channels,
                    temb_channels=0,
                    dropout=0.0,
                ),
                ResnetBlock(
                    in_channels=4 * in_channels,
                    out_channels=2 * in_channels,
                    temb_channels=0,
                    dropout=0.0,
                ),
                nn.Conv2d(2 * in_channels, in_channels, 1),
                Upsample(in_channels, with_conv=True),
            ]
        )
        # end
        self.norm_out = Normalize(in_channels)
        self.conv_out = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x):
        for i, layer in enumerate(self.model):
            if i in [1, 2, 3]:
                x = layer(x, None)
            else:
                x = layer(x)

        h = self.norm_out(x)
        h = nonlinearity(h)
        x = self.conv_out(h)
        return x


class UpsampleDecoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        ch,
        num_res_blocks,
        resolution,
        ch_mult=(2, 2),
        dropout=0.0,
    ):
        super().__init__()
        # upsampling
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        block_in = in_channels
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.res_blocks = nn.ModuleList()
        self.upsample_blocks = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            res_block = []
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                res_block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
            self.res_blocks.append(nn.ModuleList(res_block))
            if i_level != self.num_resolutions - 1:
                self.upsample_blocks.append(Upsample(block_in, True))
                curr_res = curr_res * 2

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x):
        # upsampling
        h = x
        for k, i_level in enumerate(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.res_blocks[i_level][i_block](h, None)
            if i_level != self.num_resolutions - 1:
                h = self.upsample_blocks[k](h)
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class LatentRescaler(nn.Module):
    def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
        super().__init__()
        # residual block, interpolate, residual block
        self.factor = factor
        self.conv_in = nn.Conv2d(
            in_channels, mid_channels, kernel_size=3, stride=1, padding=1
        )
        self.res_block1 = nn.ModuleList(
            [
                ResnetBlock(
                    in_channels=mid_channels,
                    out_channels=mid_channels,
                    temb_channels=0,
                    dropout=0.0,
                )
                for _ in range(depth)
            ]
        )
        self.attn = AttnBlock(mid_channels)
        self.res_block2 = nn.ModuleList(
            [
                ResnetBlock(
                    in_channels=mid_channels,
                    out_channels=mid_channels,
                    temb_channels=0,
                    dropout=0.0,
                )
                for _ in range(depth)
            ]
        )

        self.conv_out = nn.Conv2d(
            mid_channels,
            out_channels,
            kernel_size=1,
        )

    def forward(self, x):
        x = self.conv_in(x)
        for block in self.res_block1:
            x = block(x, None)
        x = torch.nn.functional.interpolate(
            x,
            size=(
                int(round(x.shape[2] * self.factor)),
                int(round(x.shape[3] * self.factor)),
            ),
        )
        x = self.attn(x).contiguous()
        for block in self.res_block2:
            x = block(x, None)
        x = self.conv_out(x)
        return x

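# Usage sketch for LatentRescaler (illustrative only; sizes and channel counts are
# assumptions for the example):
#
#     rescaler = LatentRescaler(factor=1.5, in_channels=4, mid_channels=64, out_channels=8)
#     z = torch.randn(2, 4, 16, 16)
#     z2 = rescaler(z)   # -> (2, 8, 24, 24)
#
# The spatial size is multiplied by `factor` (nearest-neighbor interpolation, the
# F.interpolate default) and the channel count is remapped to out_channels.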
| 847 | 
            +
            class MergedRescaleEncoder(nn.Module):
         | 
| 848 | 
            +
                def __init__(
         | 
| 849 | 
            +
                    self,
         | 
| 850 | 
            +
                    in_channels,
         | 
| 851 | 
            +
                    ch,
         | 
| 852 | 
            +
                    resolution,
         | 
| 853 | 
            +
                    out_ch,
         | 
| 854 | 
            +
                    num_res_blocks,
         | 
| 855 | 
            +
                    attn_resolutions,
         | 
| 856 | 
            +
                    dropout=0.0,
         | 
| 857 | 
            +
                    resamp_with_conv=True,
         | 
| 858 | 
            +
                    ch_mult=(1, 2, 4, 8),
         | 
| 859 | 
            +
                    rescale_factor=1.0,
         | 
| 860 | 
            +
                    rescale_module_depth=1,
         | 
| 861 | 
            +
                ):
         | 
| 862 | 
            +
                    super().__init__()
         | 
| 863 | 
            +
                    intermediate_chn = ch * ch_mult[-1]
         | 
| 864 | 
            +
                    self.encoder = Encoder(
         | 
| 865 | 
            +
                        in_channels=in_channels,
         | 
| 866 | 
            +
                        num_res_blocks=num_res_blocks,
         | 
| 867 | 
            +
                        ch=ch,
         | 
| 868 | 
            +
                        ch_mult=ch_mult,
         | 
| 869 | 
            +
                        z_channels=intermediate_chn,
         | 
| 870 | 
            +
                        double_z=False,
         | 
| 871 | 
            +
                        resolution=resolution,
         | 
| 872 | 
            +
                        attn_resolutions=attn_resolutions,
         | 
| 873 | 
            +
                        dropout=dropout,
         | 
| 874 | 
            +
                        resamp_with_conv=resamp_with_conv,
         | 
| 875 | 
            +
                        out_ch=None,
         | 
| 876 | 
            +
                    )
         | 
| 877 | 
            +
                    self.rescaler = LatentRescaler(
         | 
| 878 | 
            +
                        factor=rescale_factor,
         | 
| 879 | 
            +
                        in_channels=intermediate_chn,
         | 
| 880 | 
            +
                        mid_channels=intermediate_chn,
         | 
| 881 | 
            +
                        out_channels=out_ch,
         | 
| 882 | 
            +
                        depth=rescale_module_depth,
         | 
| 883 | 
            +
                    )
         | 
| 884 | 
            +
             | 
| 885 | 
            +
                def forward(self, x):
         | 
| 886 | 
            +
                    x = self.encoder(x)
         | 
| 887 | 
            +
                    x = self.rescaler(x)
         | 
| 888 | 
            +
                    return x
         | 
| 889 | 
            +
             | 
| 890 | 
            +
             | 
| 891 | 
            +
            class MergedRescaleDecoder(nn.Module):
         | 
| 892 | 
            +
                def __init__(
         | 
| 893 | 
            +
                    self,
         | 
| 894 | 
            +
                    z_channels,
         | 
| 895 | 
            +
                    out_ch,
         | 
| 896 | 
            +
                    resolution,
         | 
| 897 | 
            +
                    num_res_blocks,
         | 
| 898 | 
            +
                    attn_resolutions,
         | 
| 899 | 
            +
                    ch,
         | 
| 900 | 
            +
                    ch_mult=(1, 2, 4, 8),
         | 
| 901 | 
            +
                    dropout=0.0,
         | 
| 902 | 
            +
                    resamp_with_conv=True,
         | 
| 903 | 
            +
                    rescale_factor=1.0,
         | 
| 904 | 
            +
                    rescale_module_depth=1,
         | 
| 905 | 
            +
                ):
         | 
| 906 | 
            +
                    super().__init__()
         | 
| 907 | 
            +
                    tmp_chn = z_channels * ch_mult[-1]
         | 
| 908 | 
            +
                    self.decoder = Decoder(
         | 
| 909 | 
            +
                        out_ch=out_ch,
         | 
| 910 | 
            +
                        z_channels=tmp_chn,
         | 
| 911 | 
            +
                        attn_resolutions=attn_resolutions,
         | 
| 912 | 
            +
                        dropout=dropout,
         | 
| 913 | 
            +
                        resamp_with_conv=resamp_with_conv,
         | 
| 914 | 
            +
                        in_channels=None,
         | 
| 915 | 
            +
                        num_res_blocks=num_res_blocks,
         | 
| 916 | 
            +
                        ch_mult=ch_mult,
         | 
| 917 | 
            +
                        resolution=resolution,
         | 
| 918 | 
            +
                        ch=ch,
         | 
| 919 | 
            +
                    )
         | 
| 920 | 
            +
                    self.rescaler = LatentRescaler(
         | 
| 921 | 
            +
                        factor=rescale_factor,
         | 
| 922 | 
            +
                        in_channels=z_channels,
         | 
| 923 | 
            +
                        mid_channels=tmp_chn,
         | 
| 924 | 
            +
                        out_channels=tmp_chn,
         | 
| 925 | 
            +
                        depth=rescale_module_depth,
         | 
| 926 | 
            +
                    )
         | 
| 927 | 
            +
             | 
| 928 | 
            +
                def forward(self, x):
         | 
| 929 | 
            +
                    x = self.rescaler(x)
         | 
| 930 | 
            +
                    x = self.decoder(x)
         | 
| 931 | 
            +
                    return x
         | 
| 932 | 
            +
             | 
| 933 | 
            +
             | 
| 934 | 
            +
            class Upsampler(nn.Module):
         | 
| 935 | 
            +
                def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
         | 
| 936 | 
            +
                    super().__init__()
         | 
| 937 | 
            +
                    assert out_size >= in_size
         | 
| 938 | 
            +
                    num_blocks = int(np.log2(out_size // in_size)) + 1
         | 
| 939 | 
            +
                    factor_up = 1.0 + (out_size % in_size)
         | 
| 940 | 
            +
                    print(
         | 
| 941 | 
            +
                        f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}"
         | 
| 942 | 
            +
                    )
         | 
| 943 | 
            +
                    self.rescaler = LatentRescaler(
         | 
| 944 | 
            +
                        factor=factor_up,
         | 
| 945 | 
            +
                        in_channels=in_channels,
         | 
| 946 | 
            +
                        mid_channels=2 * in_channels,
         | 
| 947 | 
            +
                        out_channels=in_channels,
         | 
| 948 | 
            +
                    )
         | 
| 949 | 
            +
                    self.decoder = Decoder(
         | 
| 950 | 
            +
                        out_ch=out_channels,
         | 
| 951 | 
            +
                        resolution=out_size,
         | 
| 952 | 
            +
                        z_channels=in_channels,
         | 
| 953 | 
            +
                        num_res_blocks=2,
         | 
| 954 | 
            +
                        attn_resolutions=[],
         | 
| 955 | 
            +
                        in_channels=None,
         | 
| 956 | 
            +
                        ch=in_channels,
         | 
| 957 | 
            +
                        ch_mult=[ch_mult for _ in range(num_blocks)],
         | 
| 958 | 
            +
                    )
         | 
| 959 | 
            +
             | 
| 960 | 
            +
                def forward(self, x):
         | 
| 961 | 
            +
                    x = self.rescaler(x)
         | 
| 962 | 
            +
                    x = self.decoder(x)
         | 
| 963 | 
            +
                    return x
         | 
| 964 | 
            +
             | 
| 965 | 
            +
             | 
| 966 | 
            +
            class Resize(nn.Module):
         | 
| 967 | 
            +
                def __init__(self, in_channels=None, learned=False, mode="bilinear"):
         | 
| 968 | 
            +
                    super().__init__()
         | 
| 969 | 
            +
                    self.with_conv = learned
         | 
| 970 | 
            +
                    self.mode = mode
         | 
| 971 | 
            +
                    if self.with_conv:
         | 
| 972 | 
            +
                        print(
         | 
| 973 | 
            +
                            f"Note: {self.__class__.__name} uses learned downsampling "
         | 
| 974 | 
            +
                            f"and will ignore the fixed {mode} mode"
         | 
| 975 | 
            +
                        )
         | 
| 976 | 
            +
                        raise NotImplementedError()
         | 
| 977 | 
            +
                        assert in_channels is not None
         | 
| 978 | 
            +
                        # no asymmetric padding in torch conv, must do it ourselves
         | 
| 979 | 
            +
                        self.conv = torch.nn.Conv2d(
         | 
| 980 | 
            +
                            in_channels, in_channels, kernel_size=4, stride=2, padding=1
         | 
| 981 | 
            +
                        )
         | 
| 982 | 
            +
             | 
| 983 | 
            +
                def forward(self, x, scale_factor=1.0):
         | 
| 984 | 
            +
                    if scale_factor == 1.0:
         | 
| 985 | 
            +
                        return x
         | 
| 986 | 
            +
                    else:
         | 
| 987 | 
            +
                        x = torch.nn.functional.interpolate(
         | 
| 988 | 
            +
                            x, mode=self.mode, align_corners=False, scale_factor=scale_factor
         | 
| 989 | 
            +
                        )
         | 
| 990 | 
            +
                    return x
         | 
| 991 | 
            +
             | 
| 992 | 
            +
             | 
| 993 | 
            +
            class FirstStagePostProcessor(nn.Module):
         | 
| 994 | 
            +
                def __init__(
         | 
| 995 | 
            +
                    self,
         | 
| 996 | 
            +
                    ch_mult: list,
         | 
| 997 | 
            +
                    in_channels,
         | 
| 998 | 
            +
                    pretrained_model: nn.Module = None,
         | 
| 999 | 
            +
                    reshape=False,
         | 
| 1000 | 
            +
                    n_channels=None,
         | 
| 1001 | 
            +
                    dropout=0.0,
         | 
| 1002 | 
            +
                    pretrained_config=None,
         | 
| 1003 | 
            +
                ):
         | 
| 1004 | 
            +
                    super().__init__()
         | 
| 1005 | 
            +
                    if pretrained_config is None:
         | 
| 1006 | 
            +
                        assert (
         | 
| 1007 | 
            +
                            pretrained_model is not None
         | 
| 1008 | 
            +
                        ), 'Either "pretrained_model" or "pretrained_config" must not be None'
         | 
| 1009 | 
            +
                        self.pretrained_model = pretrained_model
         | 
| 1010 | 
            +
                    else:
         | 
| 1011 | 
            +
                        assert (
         | 
| 1012 | 
            +
                            pretrained_config is not None
         | 
| 1013 | 
            +
                        ), 'Either "pretrained_model" or "pretrained_config" must not be None'
         | 
| 1014 | 
            +
                        self.instantiate_pretrained(pretrained_config)
         | 
| 1015 | 
            +
             | 
| 1016 | 
            +
                    self.do_reshape = reshape
         | 
| 1017 | 
            +
             | 
| 1018 | 
            +
                    if n_channels is None:
         | 
| 1019 | 
            +
                        n_channels = self.pretrained_model.encoder.ch
         | 
| 1020 | 
            +
             | 
| 1021 | 
            +
                    self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2)
         | 
| 1022 | 
            +
                    self.proj = nn.Conv2d(
         | 
| 1023 | 
            +
                        in_channels, n_channels, kernel_size=3, stride=1, padding=1
         | 
| 1024 | 
            +
                    )
         | 
| 1025 | 
            +
             | 
| 1026 | 
            +
                    blocks = []
         | 
| 1027 | 
            +
                    downs = []
         | 
| 1028 | 
            +
                    ch_in = n_channels
         | 
| 1029 | 
            +
                    for m in ch_mult:
         | 
| 1030 | 
            +
                        blocks.append(
         | 
| 1031 | 
            +
                            ResnetBlock(
         | 
| 1032 | 
            +
                                in_channels=ch_in, out_channels=m * n_channels, dropout=dropout
         | 
| 1033 | 
            +
                            )
         | 
| 1034 | 
            +
                        )
         | 
| 1035 | 
            +
                        ch_in = m * n_channels
         | 
| 1036 | 
            +
                        downs.append(Downsample(ch_in, with_conv=False))
         | 
| 1037 | 
            +
             | 
| 1038 | 
            +
                    self.model = nn.ModuleList(blocks)
         | 
| 1039 | 
            +
                    self.downsampler = nn.ModuleList(downs)
         | 
| 1040 | 
            +
             | 
| 1041 | 
            +
                def instantiate_pretrained(self, config):
         | 
| 1042 | 
            +
                    model = instantiate_from_config(config)
         | 
| 1043 | 
            +
                    self.pretrained_model = model.eval()
         | 
| 1044 | 
            +
                    # self.pretrained_model.train = False
         | 
| 1045 | 
            +
                    for param in self.pretrained_model.parameters():
         | 
| 1046 | 
            +
                        param.requires_grad = False
         | 
| 1047 | 
            +
             | 
| 1048 | 
            +
                @torch.no_grad()
         | 
| 1049 | 
            +
                def encode_with_pretrained(self, x):
         | 
| 1050 | 
            +
                    c = self.pretrained_model.encode(x)
         | 
| 1051 | 
            +
                    if isinstance(c, DiagonalGaussianDistribution):
         | 
| 1052 | 
            +
                        c = c.mode()
         | 
| 1053 | 
            +
                    return c
         | 
| 1054 | 
            +
             | 
| 1055 | 
            +
                def forward(self, x):
         | 
| 1056 | 
            +
                    z_fs = self.encode_with_pretrained(x)
         | 
| 1057 | 
            +
                    z = self.proj_norm(z_fs)
         | 
| 1058 | 
            +
                    z = self.proj(z)
         | 
| 1059 | 
            +
                    z = nonlinearity(z)
         | 
| 1060 | 
            +
             | 
| 1061 | 
            +
                    for submodel, downmodel in zip(self.model, self.downsampler):
         | 
| 1062 | 
            +
                        z = submodel(z, temb=None)
         | 
| 1063 | 
            +
                        z = downmodel(z)
         | 
| 1064 | 
            +
             | 
| 1065 | 
            +
                    if self.do_reshape:
         | 
| 1066 | 
            +
                        z = rearrange(z, "b c h w -> b (h w) c")
         | 
| 1067 | 
            +
                    return z
         | 
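For context, a minimal sanity-check sketch (not part of the commit) of the `Resize` module defined above. It exercises only the fixed-interpolation path, since the learned branch raises `NotImplementedError`; the import path simply follows this commit's file layout.

```python
import torch
from audioldm.variational_autoencoder.modules import Resize

resize = Resize(learned=False, mode="bilinear")
latent = torch.randn(1, 8, 64, 64)  # (batch, channels, height, width)

print(resize(latent, scale_factor=1.0).shape)  # passthrough: torch.Size([1, 8, 64, 64])
print(resize(latent, scale_factor=0.5).shape)  # bilinear resize: torch.Size([1, 8, 32, 32])
```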
    	
        consistencytta.py
    ADDED
    
    | @@ -0,0 +1,200 @@ | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            from torch import nn, Tensor
         | 
| 3 | 
            +
            from transformers import AutoTokenizer, T5EncoderModel
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            from diffusers.utils.torch_utils import randn_tensor
         | 
| 6 | 
            +
            from diffusers import UNet2DConditionGuidedModel, HeunDiscreteScheduler
         | 
| 7 | 
            +
            from audioldm.stft import TacotronSTFT
         | 
| 8 | 
            +
            from audioldm.variational_autoencoder import AutoencoderKL
         | 
| 9 | 
            +
            from audioldm.utils import default_audioldm_config
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            class ConsistencyTTA(nn.Module):
         | 
| 13 | 
            +
             | 
| 14 | 
            +
                def __init__(self):
         | 
| 15 | 
            +
                    super().__init__()
         | 
| 16 | 
            +
             | 
| 17 | 
            +
                    # Initialize the consistency U-Net
         | 
| 18 | 
            +
                    unet_model_config_path='tango_diffusion_light.json'
         | 
| 19 | 
            +
                    unet_config = UNet2DConditionGuidedModel.load_config(unet_model_config_path)
         | 
| 20 | 
            +
                    self.unet = UNet2DConditionGuidedModel.from_config(unet_config, subfolder="unet")
         | 
| 21 | 
            +
             | 
| 22 | 
            +
                    unet_weight_path = "consistencytta_clapft_ckpt/unet_state_dict.pt"
         | 
| 23 | 
            +
                    unet_weight_sd = torch.load(unet_weight_path, map_location='cpu')
         | 
| 24 | 
            +
                    self.unet.load_state_dict(unet_weight_sd)
         | 
| 25 | 
            +
             | 
| 26 | 
            +
                    # Initialize FLAN-T5 tokenizer and text encoder
         | 
| 27 | 
            +
                    text_encoder_name = 'google/flan-t5-large'
         | 
| 28 | 
            +
                    self.tokenizer = AutoTokenizer.from_pretrained(text_encoder_name)
         | 
| 29 | 
            +
                    self.text_encoder = T5EncoderModel.from_pretrained(text_encoder_name)
         | 
| 30 | 
            +
                    self.text_encoder.eval(); self.text_encoder.requires_grad_(False)
         | 
| 31 | 
            +
             | 
| 32 | 
            +
                    # Initialize the VAE
         | 
| 33 | 
            +
                    raw_vae_path = "consistencytta_clapft_ckpt/vae_state_dict.pt"
         | 
| 34 | 
            +
                    raw_vae_sd = torch.load(raw_vae_path, map_location="cpu")
         | 
| 35 | 
            +
                    vae_state_dict, scale_factor = raw_vae_sd["state_dict"], raw_vae_sd["scale_factor"]
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                    config = default_audioldm_config('audioldm-s-full')
         | 
| 38 | 
            +
                    vae_config = config["model"]["params"]["first_stage_config"]["params"]
         | 
| 39 | 
            +
                    vae_config["scale_factor"] = scale_factor
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                    self.vae = AutoencoderKL(**vae_config)
         | 
| 42 | 
            +
                    self.vae.load_state_dict(vae_state_dict)
         | 
| 43 | 
            +
                    self.vae.eval(); self.vae.requires_grad_(False)
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                    # Initialize the STFT
         | 
| 46 | 
            +
                    self.fn_STFT = TacotronSTFT(
         | 
| 47 | 
            +
                        config["preprocessing"]["stft"]["filter_length"],  # default 1024
         | 
| 48 | 
            +
                        config["preprocessing"]["stft"]["hop_length"],  # default 160
         | 
| 49 | 
            +
                        config["preprocessing"]["stft"]["win_length"],  # default 1024
         | 
| 50 | 
            +
                        config["preprocessing"]["mel"]["n_mel_channels"],  # default 64
         | 
| 51 | 
            +
                        config["preprocessing"]["audio"]["sampling_rate"],  # default 16000
         | 
| 52 | 
            +
                        config["preprocessing"]["mel"]["mel_fmin"],  # default 0
         | 
| 53 | 
            +
                        config["preprocessing"]["mel"]["mel_fmax"],  # default 8000
         | 
| 54 | 
            +
                    )
         | 
| 55 | 
            +
                    self.fn_STFT.eval(); self.fn_STFT.requires_grad_(False)
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                    self.scheduler = HeunDiscreteScheduler.from_pretrained(
         | 
| 58 | 
            +
                        pretrained_model_name_or_path='stabilityai/stable-diffusion-2-1', subfolder="scheduler"
         | 
| 59 | 
            +
                    )
         | 
| 60 | 
            +
             | 
| 61 | 
            +
             | 
| 62 | 
            +
                def train(self, mode: bool = True):
         | 
| 63 | 
            +
                    self.unet.train(mode)
         | 
| 64 | 
            +
                    for model in [self.text_encoder, self.vae, self.fn_STFT]:
         | 
| 65 | 
            +
                        model.eval()
         | 
| 66 | 
            +
                    return self
         | 
| 67 | 
            +
             | 
| 68 | 
            +
             | 
| 69 | 
            +
                def eval(self):
         | 
| 70 | 
            +
                    return self.train(mode=False)
         | 
| 71 | 
            +
             | 
| 72 | 
            +
             | 
| 73 | 
            +
                def check_eval_mode(self):
         | 
| 74 | 
            +
                    for model, name in zip(
         | 
| 75 | 
            +
                        [self.text_encoder, self.vae, self.fn_STFT, self.unet],
         | 
| 76 | 
            +
                        ['text_encoder', 'vae', 'fn_STFT', 'unet']
         | 
| 77 | 
            +
                    ):
         | 
| 78 | 
            +
                        assert model.training == False, f"The {name} is not in eval mode."
         | 
| 79 | 
            +
                        for param in model.parameters():
         | 
| 80 | 
            +
                            assert param.requires_grad == False, f"The {name} is not frozen."
         | 
| 81 | 
            +
             | 
| 82 | 
            +
             | 
| 83 | 
            +
                @torch.no_grad()
         | 
| 84 | 
            +
                def encode_text(self, prompt, max_length=None, padding=True):
         | 
| 85 | 
            +
                    device = self.text_encoder.device
         | 
| 86 | 
            +
                    if max_length is None:
         | 
| 87 | 
            +
                        max_length = self.tokenizer.model_max_length
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                    batch = self.tokenizer(
         | 
| 90 | 
            +
                        prompt, max_length=max_length, padding=padding,
         | 
| 91 | 
            +
                        truncation=True, return_tensors="pt"
         | 
| 92 | 
            +
                    )
         | 
| 93 | 
            +
                    input_ids = batch.input_ids.to(device)
         | 
| 94 | 
            +
                    attention_mask = batch.attention_mask.to(device)
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                    prompt_embeds = self.text_encoder(
         | 
| 97 | 
            +
                        input_ids=input_ids, attention_mask=attention_mask
         | 
| 98 | 
            +
                    )[0]
         | 
| 99 | 
            +
                    bool_prompt_mask = (attention_mask == 1).to(device)  # Convert to boolean
         | 
| 100 | 
            +
                    return prompt_embeds, bool_prompt_mask
         | 
| 101 | 
            +
             | 
| 102 | 
            +
             | 
| 103 | 
            +
                @torch.no_grad()
         | 
| 104 | 
            +
                def encode_text_classifier_free(self, prompt: str, num_samples_per_prompt: int):
         | 
| 105 | 
            +
                    # get conditional embeddings
         | 
| 106 | 
            +
                    cond_prompt_embeds, cond_prompt_mask = self.encode_text(prompt)
         | 
| 107 | 
            +
                    cond_prompt_embeds = cond_prompt_embeds.repeat_interleave(
         | 
| 108 | 
            +
                        num_samples_per_prompt, 0
         | 
| 109 | 
            +
                    )
         | 
| 110 | 
            +
                    cond_prompt_mask = cond_prompt_mask.repeat_interleave(
         | 
| 111 | 
            +
                        num_samples_per_prompt, 0
         | 
| 112 | 
            +
                    )
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                    # get unconditional embeddings for classifier free guidance
         | 
| 115 | 
            +
                    uncond_tokens = [""] * len(prompt)
         | 
| 116 | 
            +
                    negative_prompt_embeds, uncond_prompt_mask = self.encode_text(
         | 
| 117 | 
            +
                        uncond_tokens, max_length=cond_prompt_embeds.shape[1], padding="max_length"
         | 
| 118 | 
            +
                    )
         | 
| 119 | 
            +
                    negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(
         | 
| 120 | 
            +
                        num_samples_per_prompt, 0
         | 
| 121 | 
            +
                    )
         | 
| 122 | 
            +
                    uncond_prompt_mask = uncond_prompt_mask.repeat_interleave(
         | 
| 123 | 
            +
                        num_samples_per_prompt, 0
         | 
| 124 | 
            +
                    )
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                    """ For classifier-free guidance, we need to do two forward passes.
         | 
| 127 | 
            +
                        Here we concatenate the unconditional and text embeddings into a single batch to avoid doing two separate forward passes.
         | 
| 128 | 
            +
                    """
         | 
| 129 | 
            +
                    prompt_embeds = torch.cat([negative_prompt_embeds, cond_prompt_embeds])
         | 
| 130 | 
            +
                    prompt_mask = torch.cat([uncond_prompt_mask, cond_prompt_mask])
         | 
| 131 | 
            +
             | 
| 132 | 
            +
                    return prompt_embeds, prompt_mask, cond_prompt_embeds, cond_prompt_mask
         | 
| 133 | 
            +
             | 
| 134 | 
            +
             | 
| 135 | 
            +
                def forward(
         | 
| 136 | 
            +
                    self, prompt: str, cfg_scale_input: float = 3., cfg_scale_post: float = 1.,
         | 
| 137 | 
            +
                    num_steps: int = 1, num_samples: int = 1, sr: int = 16000
         | 
| 138 | 
            +
                ):
         | 
| 139 | 
            +
                    self.check_eval_mode()
         | 
| 140 | 
            +
                    device = self.text_encoder.device
         | 
| 141 | 
            +
                    use_cf_guidance = cfg_scale_post > 1.
         | 
| 142 | 
            +
             | 
| 143 | 
            +
                    # Get prompt embeddings
         | 
| 144 | 
            +
                    prompt_embeds_cf, prompt_mask_cf, prompt_embeds, prompt_mask = \
         | 
| 145 | 
            +
                        self.encode_text_classifier_free(prompt, num_samples)
         | 
| 146 | 
            +
                    encoder_states, encoder_att_mask = \
         | 
| 147 | 
            +
                        (prompt_embeds_cf, prompt_mask_cf) if use_cf_guidance \
         | 
| 148 | 
            +
                            else (prompt_embeds, prompt_mask)
         | 
| 149 | 
            +
             | 
| 150 | 
            +
                    # Prepare noise
         | 
| 151 | 
            +
                    num_channels_latents = self.unet.config.in_channels
         | 
| 152 | 
            +
                    latent_shape = (len(prompt) * num_samples, num_channels_latents, 256, 16)
         | 
| 153 | 
            +
                    noise = randn_tensor(
         | 
| 154 | 
            +
                        latent_shape, generator=None, device=device, dtype=prompt_embeds.dtype
         | 
| 155 | 
            +
                    )
         | 
| 156 | 
            +
             | 
| 157 | 
            +
                    # Query the inference scheduler to obtain the time steps.
         | 
| 158 | 
            +
                    # The time steps are spread between 0 and the number of training time steps
         | 
| 159 | 
            +
                    self.scheduler.set_timesteps(18, device=device)  # Set this to training steps first
         | 
| 160 | 
            +
                    z_N = noise * self.scheduler.init_noise_sigma
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                    def calc_zhat_0(z_n: Tensor, t: int):
         | 
| 163 | 
            +
                        """ Query the consistency model to get zhat_0, which is the denoised embedding.
         | 
| 164 | 
            +
                        Args:
         | 
| 165 | 
            +
                            z_n (Tensor):   The noisy embedding.
         | 
| 166 | 
            +
                            t (int):        The time step.
         | 
| 167 | 
            +
                        Returns:
         | 
| 168 | 
            +
                            Tensor:         The denoised embedding.
         | 
| 169 | 
            +
                        """
         | 
| 170 | 
            +
                        # expand the latents if we are doing classifier free guidance
         | 
| 171 | 
            +
                        z_n_input = torch.cat([z_n] * 2) if use_cf_guidance else z_n
         | 
| 172 | 
            +
                        # Scale model input as required for some schedules.
         | 
| 173 | 
            +
                        z_n_input = self.scheduler.scale_model_input(z_n_input, t)
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                        # Get zhat_0 from the model
         | 
| 176 | 
            +
                        zhat_0 = self.unet(
         | 
| 177 | 
            +
                            z_n_input, t, guidance=cfg_scale_input,
         | 
| 178 | 
            +
                            encoder_hidden_states=encoder_states, encoder_attention_mask=encoder_att_mask
         | 
| 179 | 
            +
                        ).sample
         | 
| 180 | 
            +
             | 
| 181 | 
            +
                        # Perform external classifier-free guidance
         | 
| 182 | 
            +
                        if use_cf_guidance:
         | 
| 183 | 
            +
                            zhat_0_uncond, zhat_0_cond = zhat_0.chunk(2)
         | 
| 184 | 
            +
                            zhat_0 = (1 - cfg_scale_post) * zhat_0_uncond + cfg_scale_post * zhat_0_cond
         | 
| 185 | 
            +
             | 
| 186 | 
            +
                        return zhat_0
         | 
| 187 | 
            +
             | 
| 188 | 
            +
                    # Query the consistency model
         | 
| 189 | 
            +
                    zhat_0 = calc_zhat_0(z_N, self.scheduler.timesteps[0])
         | 
| 190 | 
            +
             | 
| 191 | 
            +
                    # Iteratively query the consistency model if requested
         | 
| 192 | 
            +
                    self.scheduler.set_timesteps(num_steps, device=device)
         | 
| 193 | 
            +
             | 
| 194 | 
            +
                    for t in self.scheduler.timesteps[1::2]:  # 2 is the order of the scheduler
         | 
| 195 | 
            +
                        zhat_n = self.scheduler.add_noise(zhat_0, torch.randn_like(zhat_0), t)
         | 
| 196 | 
            +
                        # Calculate new zhat_0
         | 
| 197 | 
            +
                        zhat_0 = calc_zhat_0(zhat_n, t)
         | 
| 198 | 
            +
             | 
| 199 | 
            +
                    mel = self.vae.decode_first_stage(zhat_0.float())
         | 
| 200 | 
            +
                return self.vae.decode_to_waveform(mel)[:, :int(sr * 9.5)]  # Truncate to 9.5 seconds
         | 
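For reference, a hedged inference sketch (not part of the commit) showing how the `ConsistencyTTA` module above could be driven outside the Gradio app; `run_gradio.py` presumably wraps a similar call. It assumes the checkpoints under `consistencytta_clapft_ckpt/` referenced in `__init__` are available. Note that `forward` expects a list of prompts and asserts that every submodule is frozen, and `decode_to_waveform` may return either a tensor or a NumPy array depending on the vocoder wrapper, so the sketch handles both.

```python
import soundfile as sf
import torch

from consistencytta import ConsistencyTTA

device = "cuda" if torch.cuda.is_available() else "cpu"
model = ConsistencyTTA().to(device)
model.eval(); model.requires_grad_(False)  # check_eval_mode() asserts everything is frozen

with torch.no_grad():
    # Single-step consistency generation; cfg_scale_post > 1 would additionally
    # apply external classifier-free guidance on top of the distilled guidance.
    wav = model(
        ["a dog barking in the distance"],
        cfg_scale_input=3., cfg_scale_post=1., num_steps=1, num_samples=1
    )

# Handle both tensor and NumPy outputs from the vocoder (assumption)
wav_np = wav[0].cpu().numpy() if torch.is_tensor(wav) else wav[0]
sf.write("sample.wav", wav_np, samplerate=16000)
```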
    	
        consistencytta_clapft_ckpt/.DS_Store
    ADDED
    
    | Binary file (6.15 kB) | 
    	
        diffusers/__init__.py
    ADDED
    
    | @@ -0,0 +1,2 @@ | |
| 1 | 
            +
            from .scheduling_heun_discrete import HeunDiscreteScheduler
         | 
| 2 | 
            +
            from .models.unet_2d_condition_guided import UNet2DConditionGuidedModel
         | 
    	
        diffusers/models/__init__.py
    ADDED
    
    | @@ -0,0 +1,23 @@ | |
| 1 | 
            +
            # Copyright 2023 The HuggingFace Team. All rights reserved.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 4 | 
            +
            # you may not use this file except in compliance with the License.
         | 
| 5 | 
            +
            # You may obtain a copy of the License at
         | 
| 6 | 
            +
            #
         | 
| 7 | 
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 8 | 
            +
            #
         | 
| 9 | 
            +
            # Unless required by applicable law or agreed to in writing, software
         | 
| 10 | 
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 11 | 
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 12 | 
            +
            # See the License for the specific language governing permissions and
         | 
| 13 | 
            +
            # limitations under the License.
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            from ..utils.import_utils import is_torch_available
         | 
| 16 | 
            +
             | 
| 17 | 
            +
             | 
| 18 | 
            +
            if is_torch_available():
         | 
| 19 | 
            +
                from .modeling_utils import ModelMixin
         | 
| 20 | 
            +
                from .prior_transformer import PriorTransformer
         | 
| 21 | 
            +
                from .unet_2d import UNet2DModel
         | 
| 22 | 
            +
                from .unet_2d_condition import UNet2DConditionModel
         | 
| 23 | 
            +
                from .unet_2d_condition_guided import UNet2DConditionGuidedModel
         | 
    	
        diffusers/models/activations.py
    ADDED
    
    | @@ -0,0 +1,12 @@ | |
| 1 | 
            +
            from torch import nn
         | 
| 2 | 
            +
             | 
| 3 | 
            +
             | 
| 4 | 
            +
            def get_activation(act_fn):
         | 
| 5 | 
            +
                if act_fn in ["swish", "silu"]:
         | 
| 6 | 
            +
                    return nn.SiLU()
         | 
| 7 | 
            +
                elif act_fn == "mish":
         | 
| 8 | 
            +
                    return nn.Mish()
         | 
| 9 | 
            +
                elif act_fn == "gelu":
         | 
| 10 | 
            +
                    return nn.GELU()
         | 
| 11 | 
            +
                else:
         | 
| 12 | 
            +
                    raise ValueError(f"Unsupported activation function: {act_fn}")
         | 
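A quick illustration (not part of the commit) of the helper above: it maps the recognized activation names to `torch.nn` modules and raises `ValueError` for anything else.

```python
import torch
from diffusers.models.activations import get_activation  # vendored module above

act = get_activation("silu")                # returns nn.SiLU()
print(act(torch.tensor([-1.0, 0.0, 1.0])))  # tensor([-0.2689, 0.0000, 0.7311])

try:
    get_activation("tanh")                  # not in the supported list
except ValueError as err:
    print(err)                              # Unsupported activation function: tanh
```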
    	
        diffusers/models/attention.py
    ADDED
    
    | @@ -0,0 +1,523 @@ | |
| 1 | 
            +
            # Copyright 2023 The HuggingFace Team. All rights reserved.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 4 | 
            +
            # you may not use this file except in compliance with the License.
         | 
| 5 | 
            +
            # You may obtain a copy of the License at
         | 
| 6 | 
            +
            #
         | 
| 7 | 
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 8 | 
            +
            #
         | 
| 9 | 
            +
            # Unless required by applicable law or agreed to in writing, software
         | 
| 10 | 
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 11 | 
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 12 | 
            +
            # See the License for the specific language governing permissions and
         | 
| 13 | 
            +
            # limitations under the License.
         | 
| 14 | 
            +
            import math
         | 
| 15 | 
            +
            from typing import Any, Callable, Dict, Optional
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            import torch
         | 
| 18 | 
            +
            import torch.nn.functional as F
         | 
| 19 | 
            +
            from torch import nn
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            from ..utils.import_utils import is_xformers_available
         | 
| 22 | 
            +
            from .attention_processor import Attention
         | 
| 23 | 
            +
            from .embeddings import CombinedTimestepLabelEmbeddings
         | 
| 24 | 
            +
             | 
| 25 | 
            +
             | 
| 26 | 
            +
            if is_xformers_available():
         | 
| 27 | 
            +
                import xformers
         | 
| 28 | 
            +
                import xformers.ops
         | 
| 29 | 
            +
            else:
         | 
| 30 | 
            +
                xformers = None
         | 
| 31 | 
            +
             | 
| 32 | 
            +
             | 
| 33 | 
            +
            class AttentionBlock(nn.Module):
         | 
| 34 | 
            +
                """
         | 
| 35 | 
            +
                An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
         | 
| 36 | 
            +
                to the N-d case.
         | 
| 37 | 
            +
                https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
         | 
| 38 | 
            +
                Uses three q, k, v linear layers to compute attention.
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                Parameters:
         | 
| 41 | 
            +
                    channels (`int`): The number of channels in the input and output.
         | 
| 42 | 
            +
                    num_head_channels (`int`, *optional*):
         | 
| 43 | 
            +
                        The number of channels in each head. If None, then `num_heads` = 1.
         | 
| 44 | 
            +
                    norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
         | 
| 45 | 
            +
                    rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
         | 
| 46 | 
            +
                    eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
         | 
| 47 | 
            +
                """
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                # IMPORTANT;TODO(Patrick, William) - this class will be deprecated soon. Do not use it anymore
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                def __init__(
         | 
| 52 | 
            +
                    self,
         | 
| 53 | 
            +
                    channels: int,
         | 
| 54 | 
            +
                    num_head_channels: Optional[int] = None,
         | 
| 55 | 
            +
                    norm_num_groups: int = 32,
         | 
| 56 | 
            +
                    rescale_output_factor: float = 1.0,
         | 
| 57 | 
            +
                    eps: float = 1e-5,
         | 
| 58 | 
            +
                ):
         | 
| 59 | 
            +
                    super().__init__()
         | 
| 60 | 
            +
                    self.channels = channels
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                    self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
         | 
| 63 | 
            +
                    self.num_head_size = num_head_channels
         | 
| 64 | 
            +
                    self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                    # define q,k,v as linear layers
         | 
| 67 | 
            +
                    self.query = nn.Linear(channels, channels)
         | 
| 68 | 
            +
                    self.key = nn.Linear(channels, channels)
         | 
| 69 | 
            +
                    self.value = nn.Linear(channels, channels)
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                    self.rescale_output_factor = rescale_output_factor
         | 
| 72 | 
            +
                    self.proj_attn = nn.Linear(channels, channels, bias=True)
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                    self._use_memory_efficient_attention_xformers = False
         | 
| 75 | 
            +
                    self._attention_op = None
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                def reshape_heads_to_batch_dim(self, tensor):
         | 
| 78 | 
            +
                    batch_size, seq_len, dim = tensor.shape
         | 
| 79 | 
            +
                    head_size = self.num_heads
         | 
| 80 | 
            +
                    tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
         | 
| 81 | 
            +
                    tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
         | 
| 82 | 
            +
                    return tensor
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                def reshape_batch_dim_to_heads(self, tensor):
         | 
| 85 | 
            +
                    batch_size, seq_len, dim = tensor.shape
         | 
| 86 | 
            +
                    head_size = self.num_heads
         | 
| 87 | 
            +
                    tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
         | 
| 88 | 
            +
                    tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
         | 
| 89 | 
            +
                    return tensor
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                def set_use_memory_efficient_attention_xformers(
         | 
| 92 | 
            +
                    self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
         | 
| 93 | 
            +
                ):
         | 
| 94 | 
            +
                    if use_memory_efficient_attention_xformers:
         | 
| 95 | 
            +
                        if not is_xformers_available():
         | 
| 96 | 
            +
                            raise ModuleNotFoundError(
         | 
| 97 | 
            +
                                (
         | 
| 98 | 
            +
                                    "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
         | 
| 99 | 
            +
                                    " xformers"
         | 
| 100 | 
            +
                                ),
         | 
| 101 | 
            +
                                name="xformers",
         | 
| 102 | 
            +
                            )
         | 
| 103 | 
            +
                        elif not torch.cuda.is_available():
         | 
| 104 | 
            +
                            raise ValueError(
         | 
| 105 | 
            +
                                "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
         | 
| 106 | 
            +
                                " only available for GPU "
         | 
| 107 | 
            +
                            )
         | 
| 108 | 
            +
                        else:
         | 
| 109 | 
            +
                            try:
         | 
| 110 | 
            +
                                # Make sure we can run the memory efficient attention
         | 
| 111 | 
            +
                                _ = xformers.ops.memory_efficient_attention(
         | 
| 112 | 
            +
                                    torch.randn((1, 2, 40), device="cuda"),
         | 
| 113 | 
            +
                                    torch.randn((1, 2, 40), device="cuda"),
         | 
| 114 | 
            +
                                    torch.randn((1, 2, 40), device="cuda"),
         | 
| 115 | 
            +
                                )
         | 
| 116 | 
            +
                            except Exception as e:
         | 
| 117 | 
            +
                                raise e
         | 
| 118 | 
            +
                    self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
         | 
| 119 | 
            +
                    self._attention_op = attention_op
         | 
| 120 | 
            +
             | 
| 121 | 
            +
                def forward(self, hidden_states):
         | 
| 122 | 
            +
                    residual = hidden_states
         | 
| 123 | 
            +
                    batch, channel, height, width = hidden_states.shape
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                    # norm
         | 
| 126 | 
            +
                    hidden_states = self.group_norm(hidden_states)
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                    hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                    # proj to q, k, v
         | 
| 131 | 
            +
                    query_proj = self.query(hidden_states)
         | 
| 132 | 
            +
                    key_proj = self.key(hidden_states)
         | 
| 133 | 
            +
                    value_proj = self.value(hidden_states)
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                    scale = 1 / math.sqrt(self.channels / self.num_heads)
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                    query_proj = self.reshape_heads_to_batch_dim(query_proj)
         | 
| 138 | 
            +
                    key_proj = self.reshape_heads_to_batch_dim(key_proj)
         | 
| 139 | 
            +
                    value_proj = self.reshape_heads_to_batch_dim(value_proj)
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                    if self._use_memory_efficient_attention_xformers:
         | 
| 142 | 
            +
                        # Memory efficient attention
         | 
| 143 | 
            +
                        hidden_states = xformers.ops.memory_efficient_attention(
         | 
| 144 | 
            +
                            query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op
         | 
| 145 | 
            +
                        )
         | 
| 146 | 
            +
                        hidden_states = hidden_states.to(query_proj.dtype)
         | 
| 147 | 
            +
                    else:
         | 
| 148 | 
            +
                        attention_scores = torch.baddbmm(
         | 
| 149 | 
            +
                            torch.empty(
         | 
| 150 | 
            +
                                query_proj.shape[0],
         | 
| 151 | 
            +
                                query_proj.shape[1],
         | 
| 152 | 
            +
                                key_proj.shape[1],
         | 
| 153 | 
            +
                                dtype=query_proj.dtype,
         | 
| 154 | 
            +
                                device=query_proj.device,
         | 
| 155 | 
            +
                            ),
         | 
| 156 | 
            +
                            query_proj,
         | 
| 157 | 
            +
                            key_proj.transpose(-1, -2),
         | 
| 158 | 
            +
                            beta=0,
         | 
| 159 | 
            +
                            alpha=scale,
         | 
| 160 | 
            +
                        )
         | 
| 161 | 
            +
                        attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
         | 
| 162 | 
            +
                        hidden_states = torch.bmm(attention_probs, value_proj)
         | 
| 163 | 
            +
             | 
| 164 | 
            +
                    # reshape hidden_states
         | 
| 165 | 
            +
                    hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         | 
| 166 | 
            +
             | 
| 167 | 
            +
                    # compute next hidden_states
         | 
| 168 | 
            +
                    hidden_states = self.proj_attn(hidden_states)
         | 
| 169 | 
            +
             | 
| 170 | 
            +
                    hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
         | 
| 171 | 
            +
             | 
| 172 | 
            +
                    # res connect and rescale
         | 
| 173 | 
            +
                    hidden_states = (hidden_states + residual) / self.rescale_output_factor
         | 
| 174 | 
            +
                    return hidden_states
         | 
| 175 | 
            +
             | 
| 176 | 
            +
             | 
| 177 | 
            +
            class BasicTransformerBlock(nn.Module):
         | 
| 178 | 
            +
                r"""
         | 
| 179 | 
            +
                A basic Transformer block.
         | 
| 180 | 
            +
             | 
| 181 | 
            +
                Parameters:
         | 
| 182 | 
            +
                    dim (`int`): The number of channels in the input and output.
         | 
| 183 | 
            +
                    num_attention_heads (`int`): The number of heads to use for multi-head attention.
         | 
| 184 | 
            +
                    attention_head_dim (`int`): The number of channels in each head.
         | 
| 185 | 
            +
                    dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
         | 
| 186 | 
            +
                    cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
         | 
| 187 | 
            +
                    only_cross_attention (`bool`, *optional*):
         | 
| 188 | 
            +
                        Whether to use only cross-attention layers. In this case two cross attention layers are used.
         | 
| 189 | 
            +
                    double_self_attention (`bool`, *optional*):
         | 
| 190 | 
            +
                        Whether to use two self-attention layers. In this case no cross attention layers are used.
         | 
| 191 | 
            +
                    activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
         | 
| 192 | 
            +
                    num_embeds_ada_norm (:
         | 
| 193 | 
            +
                        obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
         | 
| 194 | 
            +
                    attention_bias (:
         | 
| 195 | 
            +
                        obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
         | 
| 196 | 
            +
                """
         | 
| 197 | 
            +
             | 
| 198 | 
            +
                def __init__(
         | 
| 199 | 
            +
                    self,
         | 
| 200 | 
            +
                    dim: int,
         | 
| 201 | 
            +
                    num_attention_heads: int,
         | 
| 202 | 
            +
                    attention_head_dim: int,
         | 
| 203 | 
            +
                    dropout=0.0,
         | 
| 204 | 
            +
                    cross_attention_dim: Optional[int] = None,
         | 
| 205 | 
            +
                    activation_fn: str = "geglu",
         | 
| 206 | 
            +
                    num_embeds_ada_norm: Optional[int] = None,
         | 
| 207 | 
            +
                    attention_bias: bool = False,
         | 
| 208 | 
            +
                    only_cross_attention: bool = False,
         | 
| 209 | 
            +
                    double_self_attention: bool = False,
         | 
| 210 | 
            +
                    upcast_attention: bool = False,
         | 
| 211 | 
            +
                    norm_elementwise_affine: bool = True,
         | 
| 212 | 
            +
                    norm_type: str = "layer_norm",
         | 
| 213 | 
            +
                    final_dropout: bool = False,
         | 
| 214 | 
            +
                ):
         | 
| 215 | 
            +
                    super().__init__()
         | 
| 216 | 
            +
                    self.only_cross_attention = only_cross_attention
         | 
| 217 | 
            +
             | 
| 218 | 
            +
                    self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
         | 
| 219 | 
            +
                    self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
         | 
| 220 | 
            +
             | 
| 221 | 
            +
                    if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
         | 
| 222 | 
            +
                        raise ValueError(
         | 
| 223 | 
            +
                            f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
         | 
| 224 | 
            +
                            f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
         | 
| 225 | 
            +
                        )
         | 
| 226 | 
            +
             | 
| 227 | 
            +
                    # 1. Self-Attn
         | 
| 228 | 
            +
                    self.attn1 = Attention(
         | 
| 229 | 
            +
                        query_dim=dim,
         | 
| 230 | 
            +
                        heads=num_attention_heads,
         | 
| 231 | 
            +
                        dim_head=attention_head_dim,
         | 
| 232 | 
            +
                        dropout=dropout,
         | 
| 233 | 
            +
                        bias=attention_bias,
         | 
| 234 | 
            +
                        cross_attention_dim=cross_attention_dim if only_cross_attention else None,
         | 
| 235 | 
            +
                        upcast_attention=upcast_attention,
         | 
| 236 | 
            +
                    )
         | 
| 237 | 
            +
             | 
| 238 | 
            +
                    self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
         | 
| 239 | 
            +
             | 
| 240 | 
            +
                    # 2. Cross-Attn
         | 
| 241 | 
            +
                    if cross_attention_dim is not None or double_self_attention:
         | 
| 242 | 
            +
                        self.attn2 = Attention(
         | 
| 243 | 
            +
                            query_dim=dim,
         | 
| 244 | 
            +
                            cross_attention_dim=cross_attention_dim if not double_self_attention else None,
         | 
| 245 | 
            +
                            heads=num_attention_heads,
         | 
| 246 | 
            +
                            dim_head=attention_head_dim,
         | 
| 247 | 
            +
                            dropout=dropout,
         | 
| 248 | 
            +
                            bias=attention_bias,
         | 
| 249 | 
            +
                            upcast_attention=upcast_attention,
         | 
| 250 | 
            +
                        )  # is self-attn if encoder_hidden_states is none
         | 
| 251 | 
            +
                    else:
         | 
| 252 | 
            +
                        self.attn2 = None
         | 
| 253 | 
            +
             | 
| 254 | 
            +
                    if self.use_ada_layer_norm:
         | 
| 255 | 
            +
                        self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
         | 
| 256 | 
            +
                    elif self.use_ada_layer_norm_zero:
         | 
| 257 | 
            +
                        self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
         | 
| 258 | 
            +
                    else:
         | 
| 259 | 
            +
                        self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
         | 
| 260 | 
            +
             | 
| 261 | 
            +
                    if cross_attention_dim is not None or double_self_attention:
         | 
| 262 | 
            +
                        # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
         | 
| 263 | 
            +
                        # I.e. the number of returned modulation chunks from AdaLayerNormZero would not make sense if returned during
         | 
| 264 | 
            +
                        # the second cross attention block.
         | 
| 265 | 
            +
                        self.norm2 = (
         | 
| 266 | 
            +
                            AdaLayerNorm(dim, num_embeds_ada_norm)
         | 
| 267 | 
            +
                            if self.use_ada_layer_norm
         | 
| 268 | 
            +
                            else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
         | 
| 269 | 
            +
                        )
         | 
| 270 | 
            +
                    else:
         | 
| 271 | 
            +
                        self.norm2 = None
         | 
| 272 | 
            +
             | 
| 273 | 
            +
                    # 3. Feed-forward
         | 
| 274 | 
            +
                    self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
         | 
| 275 | 
            +
             | 
| 276 | 
            +
                def forward(
         | 
| 277 | 
            +
                    self,
         | 
| 278 | 
            +
                    hidden_states: torch.FloatTensor,
         | 
| 279 | 
            +
                    attention_mask: Optional[torch.FloatTensor] = None,
         | 
| 280 | 
            +
                    encoder_hidden_states: Optional[torch.FloatTensor] = None,
         | 
| 281 | 
            +
                    encoder_attention_mask: Optional[torch.FloatTensor] = None,
         | 
| 282 | 
            +
                    timestep: Optional[torch.LongTensor] = None,
         | 
| 283 | 
            +
                    cross_attention_kwargs: Dict[str, Any] = None,
         | 
| 284 | 
            +
                    class_labels: Optional[torch.LongTensor] = None,
         | 
| 285 | 
            +
                ):
         | 
| 286 | 
            +
                    if self.use_ada_layer_norm:
         | 
| 287 | 
            +
                        norm_hidden_states = self.norm1(hidden_states, timestep)
         | 
| 288 | 
            +
                    elif self.use_ada_layer_norm_zero:
         | 
| 289 | 
            +
                        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
         | 
| 290 | 
            +
                            hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
         | 
| 291 | 
            +
                        )
         | 
| 292 | 
            +
                    else:
         | 
| 293 | 
            +
                        norm_hidden_states = self.norm1(hidden_states)
         | 
| 294 | 
            +
             | 
| 295 | 
            +
                    # 1. Self-Attention
         | 
| 296 | 
            +
                    cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
         | 
| 297 | 
            +
                    attn_output = self.attn1(
         | 
| 298 | 
            +
                        norm_hidden_states,
         | 
| 299 | 
            +
                        encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
         | 
| 300 | 
            +
                        attention_mask=attention_mask,
         | 
| 301 | 
            +
                        **cross_attention_kwargs,
         | 
| 302 | 
            +
                    )
         | 
| 303 | 
            +
                    if self.use_ada_layer_norm_zero:
         | 
| 304 | 
            +
                        attn_output = gate_msa.unsqueeze(1) * attn_output
         | 
| 305 | 
            +
                    hidden_states = attn_output + hidden_states
         | 
| 306 | 
            +
             | 
| 307 | 
            +
                    if self.attn2 is not None:
         | 
| 308 | 
            +
                        norm_hidden_states = (
         | 
| 309 | 
            +
                            self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
         | 
| 310 | 
            +
                        )
         | 
| 311 | 
            +
             | 
| 312 | 
            +
                        # 2. Cross-Attention
         | 
| 313 | 
            +
                        attn_output = self.attn2(
         | 
| 314 | 
            +
                            norm_hidden_states,
         | 
| 315 | 
            +
                            encoder_hidden_states=encoder_hidden_states,
         | 
| 316 | 
            +
                            attention_mask=encoder_attention_mask,
         | 
| 317 | 
            +
                            **cross_attention_kwargs,
         | 
| 318 | 
            +
                        )
         | 
| 319 | 
            +
                        hidden_states = attn_output + hidden_states
         | 
| 320 | 
            +
             | 
| 321 | 
            +
                    # 3. Feed-forward
         | 
| 322 | 
            +
                    norm_hidden_states = self.norm3(hidden_states)
         | 
| 323 | 
            +
             | 
| 324 | 
            +
                    if self.use_ada_layer_norm_zero:
         | 
| 325 | 
            +
                        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
         | 
| 326 | 
            +
             | 
| 327 | 
            +
                    ff_output = self.ff(norm_hidden_states)
         | 
| 328 | 
            +
             | 
| 329 | 
            +
                    if self.use_ada_layer_norm_zero:
         | 
| 330 | 
            +
                        ff_output = gate_mlp.unsqueeze(1) * ff_output
         | 
| 331 | 
            +
             | 
| 332 | 
            +
                    hidden_states = ff_output + hidden_states
         | 
| 333 | 
            +
             | 
| 334 | 
            +
                    return hidden_states
         | 
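To make the adaLN-Zero path in the forward pass above easier to follow, here is a small self-contained sketch (plain PyTorch, not part of this file; the shapes and the stand-in `attn`/`ff` callables are hypothetical) of the same modulation pattern: the conditioning embedding is projected to six chunks, the first shift/scale/gate triple wraps the attention sub-layer, and the second triple wraps the feed-forward sub-layer.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

batch, seq, dim = 2, 77, 320
x = torch.randn(batch, seq, dim)   # hidden states
cond = torch.randn(batch, dim)     # combined timestep/class embedding (hypothetical)

to_mod = nn.Linear(dim, 6 * dim)   # mirrors AdaLayerNormZero.linear
norm = nn.LayerNorm(dim, elementwise_affine=False)
attn = lambda h: h                 # stand-in for the self-attention sub-layer
ff = lambda h: h                   # stand-in for the feed-forward sub-layer

shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = to_mod(F.silu(cond)).chunk(6, dim=1)

h = norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]  # modulated pre-norm (AdaLayerNormZero)
x = x + gate_msa.unsqueeze(1) * attn(h)                      # gated residual around attention

h = norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]  # second modulation before the feed-forward
x = x + gate_mlp.unsqueeze(1) * ff(h)                        # gated residual around the feed-forward
```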
| 335 | 
            +
             | 
| 336 | 
            +
             | 
| 337 | 
            +
            class FeedForward(nn.Module):
         | 
| 338 | 
            +
                r"""
         | 
| 339 | 
            +
                A feed-forward layer.
         | 
| 340 | 
            +
             | 
| 341 | 
            +
                Parameters:
         | 
| 342 | 
            +
                    dim (`int`): The number of channels in the input.
         | 
| 343 | 
            +
                    dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
         | 
| 344 | 
            +
                    mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
         | 
| 345 | 
            +
                    dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
         | 
| 346 | 
            +
                    activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
         | 
| 347 | 
            +
                    final_dropout (`bool`, *optional*, defaults to `False`): Whether to apply a final dropout.
         | 
| 348 | 
            +
                """
         | 
| 349 | 
            +
             | 
| 350 | 
            +
                def __init__(
         | 
| 351 | 
            +
                    self,
         | 
| 352 | 
            +
                    dim: int,
         | 
| 353 | 
            +
                    dim_out: Optional[int] = None,
         | 
| 354 | 
            +
                    mult: int = 4,
         | 
| 355 | 
            +
                    dropout: float = 0.0,
         | 
| 356 | 
            +
                    activation_fn: str = "geglu",
         | 
| 357 | 
            +
                    final_dropout: bool = False,
         | 
| 358 | 
            +
                ):
         | 
| 359 | 
            +
                    super().__init__()
         | 
| 360 | 
            +
                    inner_dim = int(dim * mult)
         | 
| 361 | 
            +
                    dim_out = dim_out if dim_out is not None else dim
         | 
| 362 | 
            +
             | 
| 363 | 
            +
                    if activation_fn == "gelu":
         | 
| 364 | 
            +
                        act_fn = GELU(dim, inner_dim)
         | 
| 365 | 
            +
                    elif activation_fn == "gelu-approximate":
         | 
| 366 | 
            +
                        act_fn = GELU(dim, inner_dim, approximate="tanh")
         | 
| 367 | 
            +
                    elif activation_fn == "geglu":
         | 
| 368 | 
            +
                        act_fn = GEGLU(dim, inner_dim)
         | 
| 369 | 
            +
                    elif activation_fn == "geglu-approximate":
         | 
| 370 | 
            +
                        act_fn = ApproximateGELU(dim, inner_dim)
         | 
| 371 | 
            +
             | 
| 372 | 
            +
                    self.net = nn.ModuleList([])
         | 
| 373 | 
            +
                    # project in
         | 
| 374 | 
            +
                    self.net.append(act_fn)
         | 
| 375 | 
            +
                    # project dropout
         | 
| 376 | 
            +
                    self.net.append(nn.Dropout(dropout))
         | 
| 377 | 
            +
                    # project out
         | 
| 378 | 
            +
                    self.net.append(nn.Linear(inner_dim, dim_out))
         | 
| 379 | 
            +
                    # FF as used in Vision Transformer, MLP-Mixer, etc. has a final dropout
         | 
| 380 | 
            +
                    if final_dropout:
         | 
| 381 | 
            +
                        self.net.append(nn.Dropout(dropout))
         | 
| 382 | 
            +
             | 
| 383 | 
            +
                def forward(self, hidden_states):
         | 
| 384 | 
            +
                    for module in self.net:
         | 
| 385 | 
            +
                        hidden_states = module(hidden_states)
         | 
| 386 | 
            +
                    return hidden_states
         | 
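As a quick orientation for `FeedForward`, the sketch below shows hypothetical usage (it assumes the vendored `diffusers` package in this Space imports cleanly; the shapes are arbitrary). With the default `"geglu"` activation the block expands to `dim * mult` gated features internally and projects back to `dim`, so the output shape matches the input.

```python
import torch
from diffusers.models.attention import FeedForward  # vendored module shown here

ff = FeedForward(dim=320, mult=4, activation_fn="geglu", dropout=0.0)

hidden_states = torch.randn(2, 77, 320)  # (batch, tokens, channels); example shape only
out = ff(hidden_states)
print(out.shape)                         # torch.Size([2, 77, 320])
```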
| 387 | 
            +
             | 
| 388 | 
            +
             | 
| 389 | 
            +
            class GELU(nn.Module):
         | 
| 390 | 
            +
                r"""
         | 
| 391 | 
            +
                GELU activation function, with optional tanh approximation via `approximate="tanh"`.
         | 
| 392 | 
            +
                """
         | 
| 393 | 
            +
             | 
| 394 | 
            +
                def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
         | 
| 395 | 
            +
                    super().__init__()
         | 
| 396 | 
            +
                    self.proj = nn.Linear(dim_in, dim_out)
         | 
| 397 | 
            +
                    self.approximate = approximate
         | 
| 398 | 
            +
             | 
| 399 | 
            +
                def gelu(self, gate):
         | 
| 400 | 
            +
                    if gate.device.type != "mps":
         | 
| 401 | 
            +
                        return F.gelu(gate, approximate=self.approximate)
         | 
| 402 | 
            +
                    # mps: gelu is not implemented for float16
         | 
| 403 | 
            +
                    return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
         | 
| 404 | 
            +
             | 
| 405 | 
            +
                def forward(self, hidden_states):
         | 
| 406 | 
            +
                    hidden_states = self.proj(hidden_states)
         | 
| 407 | 
            +
                    hidden_states = self.gelu(hidden_states)
         | 
| 408 | 
            +
                    return hidden_states
         | 
| 409 | 
            +
             | 
| 410 | 
            +
             | 
| 411 | 
            +
            class GEGLU(nn.Module):
         | 
| 412 | 
            +
                r"""
         | 
| 413 | 
            +
                A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
         | 
| 414 | 
            +
             | 
| 415 | 
            +
                Parameters:
         | 
| 416 | 
            +
                    dim_in (`int`): The number of channels in the input.
         | 
| 417 | 
            +
                    dim_out (`int`): The number of channels in the output.
         | 
| 418 | 
            +
                """
         | 
| 419 | 
            +
             | 
| 420 | 
            +
                def __init__(self, dim_in: int, dim_out: int):
         | 
| 421 | 
            +
                    super().__init__()
         | 
| 422 | 
            +
                    self.proj = nn.Linear(dim_in, dim_out * 2)
         | 
| 423 | 
            +
             | 
| 424 | 
            +
                def gelu(self, gate):
         | 
| 425 | 
            +
                    if gate.device.type != "mps":
         | 
| 426 | 
            +
                        return F.gelu(gate)
         | 
| 427 | 
            +
                    # mps: gelu is not implemented for float16
         | 
| 428 | 
            +
                    return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
         | 
| 429 | 
            +
             | 
| 430 | 
            +
                def forward(self, hidden_states):
         | 
| 431 | 
            +
                    hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
         | 
| 432 | 
            +
                    return hidden_states * self.gelu(gate)
         | 
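The gating in `GEGLU.forward` can be written out with plain tensor ops; the sketch below (illustrative only, random weights) makes the split into a value half and a gate half explicit.

```python
import torch
import torch.nn.functional as F

dim_in, dim_out = 320, 1280
proj = torch.nn.Linear(dim_in, dim_out * 2)  # same layout as GEGLU.proj

x = torch.randn(2, 77, dim_in)
value, gate = proj(x).chunk(2, dim=-1)       # split the doubled projection
out = value * F.gelu(gate)                   # (2, 77, 1280): value scaled by GELU of the gate
```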
| 433 | 
            +
             | 
| 434 | 
            +
             | 
| 435 | 
            +
            class ApproximateGELU(nn.Module):
         | 
| 436 | 
            +
                """
         | 
| 437 | 
            +
                The approximate form of Gaussian Error Linear Unit (GELU)
         | 
| 438 | 
            +
             | 
| 439 | 
            +
                For more details, see section 2: https://arxiv.org/abs/1606.08415
         | 
| 440 | 
            +
                """
         | 
| 441 | 
            +
             | 
| 442 | 
            +
                def __init__(self, dim_in: int, dim_out: int):
         | 
| 443 | 
            +
                    super().__init__()
         | 
| 444 | 
            +
                    self.proj = nn.Linear(dim_in, dim_out)
         | 
| 445 | 
            +
             | 
| 446 | 
            +
                def forward(self, x):
         | 
| 447 | 
            +
                    x = self.proj(x)
         | 
| 448 | 
            +
                    return x * torch.sigmoid(1.702 * x)
         | 
| 449 | 
            +
             | 
| 450 | 
            +
             | 
| 451 | 
            +
            class AdaLayerNorm(nn.Module):
         | 
| 452 | 
            +
                """
         | 
| 453 | 
            +
                Norm layer modified to incorporate timestep embeddings.
         | 
| 454 | 
            +
                """
         | 
| 455 | 
            +
             | 
| 456 | 
            +
                def __init__(self, embedding_dim, num_embeddings):
         | 
| 457 | 
            +
                    super().__init__()
         | 
| 458 | 
            +
                    self.emb = nn.Embedding(num_embeddings, embedding_dim)
         | 
| 459 | 
            +
                    self.silu = nn.SiLU()
         | 
| 460 | 
            +
                    self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
         | 
| 461 | 
            +
                    self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
         | 
| 462 | 
            +
             | 
| 463 | 
            +
                def forward(self, x, timestep):
         | 
| 464 | 
            +
                    emb = self.linear(self.silu(self.emb(timestep)))
         | 
| 465 | 
            +
                    scale, shift = torch.chunk(emb, 2)
         | 
| 466 | 
            +
                    x = self.norm(x) * (1 + scale) + shift
         | 
| 467 | 
            +
                    return x
         | 
| 468 | 
            +
             | 
| 469 | 
            +
             | 
| 470 | 
            +
            class AdaLayerNormZero(nn.Module):
         | 
| 471 | 
            +
                """
         | 
| 472 | 
            +
                Norm layer modified to incorporate timestep and class embeddings (adaLN-Zero).
         | 
| 473 | 
            +
                """
         | 
| 474 | 
            +
             | 
| 475 | 
            +
                def __init__(self, embedding_dim, num_embeddings):
         | 
| 476 | 
            +
                    super().__init__()
         | 
| 477 | 
            +
             | 
| 478 | 
            +
                    self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
         | 
| 479 | 
            +
             | 
| 480 | 
            +
                    self.silu = nn.SiLU()
         | 
| 481 | 
            +
                    self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
         | 
| 482 | 
            +
                    self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
         | 
| 483 | 
            +
             | 
| 484 | 
            +
                def forward(self, x, timestep, class_labels, hidden_dtype=None):
         | 
| 485 | 
            +
                    emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
         | 
| 486 | 
            +
                    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
         | 
| 487 | 
            +
                    x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
         | 
| 488 | 
            +
                    return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
         | 
| 489 | 
            +
             | 
| 490 | 
            +
             | 
| 491 | 
            +
            class AdaGroupNorm(nn.Module):
         | 
| 492 | 
            +
                """
         | 
| 493 | 
            +
                GroupNorm layer modified to incorporate timestep embeddings.
         | 
| 494 | 
            +
                """
         | 
| 495 | 
            +
             | 
| 496 | 
            +
                def __init__(
         | 
| 497 | 
            +
                    self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
         | 
| 498 | 
            +
                ):
         | 
| 499 | 
            +
                    super().__init__()
         | 
| 500 | 
            +
                    self.num_groups = num_groups
         | 
| 501 | 
            +
                    self.eps = eps
         | 
| 502 | 
            +
                    self.act = None
         | 
| 503 | 
            +
                    if act_fn == "swish":
         | 
| 504 | 
            +
                        self.act = lambda x: F.silu(x)
         | 
| 505 | 
            +
                    elif act_fn == "mish":
         | 
| 506 | 
            +
                        self.act = nn.Mish()
         | 
| 507 | 
            +
                    elif act_fn == "silu":
         | 
| 508 | 
            +
                        self.act = nn.SiLU()
         | 
| 509 | 
            +
                    elif act_fn == "gelu":
         | 
| 510 | 
            +
                        self.act = nn.GELU()
         | 
| 511 | 
            +
             | 
| 512 | 
            +
                    self.linear = nn.Linear(embedding_dim, out_dim * 2)
         | 
| 513 | 
            +
             | 
| 514 | 
            +
                def forward(self, x, emb):
         | 
| 515 | 
            +
                    if self.act:
         | 
| 516 | 
            +
                        emb = self.act(emb)
         | 
| 517 | 
            +
                    emb = self.linear(emb)
         | 
| 518 | 
            +
                    emb = emb[:, :, None, None]
         | 
| 519 | 
            +
                    scale, shift = emb.chunk(2, dim=1)
         | 
| 520 | 
            +
             | 
| 521 | 
            +
                    x = F.group_norm(x, self.num_groups, eps=self.eps)
         | 
| 522 | 
            +
                    x = x * (1 + scale) + shift
         | 
| 523 | 
            +
                    return x
         | 
    	
        diffusers/models/attention_processor.py
    ADDED
    
    | @@ -0,0 +1,1646 @@ | |
| 1 | 
            +
            # Copyright 2023 The HuggingFace Team. All rights reserved.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 4 | 
            +
            # you may not use this file except in compliance with the License.
         | 
| 5 | 
            +
            # You may obtain a copy of the License at
         | 
| 6 | 
            +
            #
         | 
| 7 | 
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 8 | 
            +
            #
         | 
| 9 | 
            +
            # Unless required by applicable law or agreed to in writing, software
         | 
| 10 | 
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 11 | 
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 12 | 
            +
            # See the License for the specific language governing permissions and
         | 
| 13 | 
            +
            # limitations under the License.
         | 
| 14 | 
            +
            from typing import Callable, Optional, Union
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            import torch
         | 
| 17 | 
            +
            import torch.nn.functional as F
         | 
| 18 | 
            +
            from torch import nn
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            from ..utils.deprecation_utils import deprecate
         | 
| 21 | 
            +
            from ..utils.torch_utils import maybe_allow_in_graph
         | 
| 22 | 
            +
            from ..utils.import_utils import is_xformers_available
         | 
| 23 | 
            +
            from ..utils.logging import get_logger
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            logger = get_logger(__name__)  # pylint: disable=invalid-name
         | 
| 26 | 
            +
             | 
| 27 | 
            +
             | 
| 28 | 
            +
            if is_xformers_available():
         | 
| 29 | 
            +
                import xformers
         | 
| 30 | 
            +
                import xformers.ops
         | 
| 31 | 
            +
            else:
         | 
| 32 | 
            +
                xformers = None
         | 
| 33 | 
            +
             | 
| 34 | 
            +
             | 
| 35 | 
            +
            @maybe_allow_in_graph
         | 
| 36 | 
            +
            class Attention(nn.Module):
         | 
| 37 | 
            +
                r"""
         | 
| 38 | 
            +
                A cross attention layer.
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                Parameters:
         | 
| 41 | 
            +
                    query_dim (`int`): The number of channels in the query.
         | 
| 42 | 
            +
                    cross_attention_dim (`int`, *optional*):
         | 
| 43 | 
            +
                        The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
         | 
| 44 | 
            +
                    heads (`int`,  *optional*, defaults to 8): The number of heads to use for multi-head attention.
         | 
| 45 | 
            +
                    dim_head (`int`,  *optional*, defaults to 64): The number of channels in each head.
         | 
| 46 | 
            +
                    dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
         | 
| 47 | 
            +
                    bias (`bool`, *optional*, defaults to False):
         | 
| 48 | 
            +
                        Set to `True` for the query, key, and value linear layers to contain a bias parameter.
         | 
| 49 | 
            +
                """
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                def __init__(
         | 
| 52 | 
            +
                    self,
         | 
| 53 | 
            +
                    query_dim: int,
         | 
| 54 | 
            +
                    cross_attention_dim: Optional[int] = None,
         | 
| 55 | 
            +
                    heads: int = 8,
         | 
| 56 | 
            +
                    dim_head: int = 64,
         | 
| 57 | 
            +
                    dropout: float = 0.0,
         | 
| 58 | 
            +
                    bias=False,
         | 
| 59 | 
            +
                    upcast_attention: bool = False,
         | 
| 60 | 
            +
                    upcast_softmax: bool = False,
         | 
| 61 | 
            +
                    cross_attention_norm: Optional[str] = None,
         | 
| 62 | 
            +
                    cross_attention_norm_num_groups: int = 32,
         | 
| 63 | 
            +
                    added_kv_proj_dim: Optional[int] = None,
         | 
| 64 | 
            +
                    norm_num_groups: Optional[int] = None,
         | 
| 65 | 
            +
                    spatial_norm_dim: Optional[int] = None,
         | 
| 66 | 
            +
                    out_bias: bool = True,
         | 
| 67 | 
            +
                    scale_qk: bool = True,
         | 
| 68 | 
            +
                    only_cross_attention: bool = False,
         | 
| 69 | 
            +
                    eps: float = 1e-5,
         | 
| 70 | 
            +
                    rescale_output_factor: float = 1.0,
         | 
| 71 | 
            +
                    residual_connection: bool = False,
         | 
| 72 | 
            +
                    _from_deprecated_attn_block=False,
         | 
| 73 | 
            +
                    processor: Optional["AttnProcessor"] = None,
         | 
| 74 | 
            +
                ):
         | 
| 75 | 
            +
                    super().__init__()
         | 
| 76 | 
            +
                    inner_dim = dim_head * heads
         | 
| 77 | 
            +
                    cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
         | 
| 78 | 
            +
                    self.upcast_attention = upcast_attention
         | 
| 79 | 
            +
                    self.upcast_softmax = upcast_softmax
         | 
| 80 | 
            +
                    self.rescale_output_factor = rescale_output_factor
         | 
| 81 | 
            +
                    self.residual_connection = residual_connection
         | 
| 82 | 
            +
                    self.dropout = dropout
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                    # we make use of this private variable to know whether this class is loaded
         | 
| 85 | 
            +
                # with a deprecated state dict so that we can convert it on the fly
         | 
| 86 | 
            +
                    self._from_deprecated_attn_block = _from_deprecated_attn_block
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                    self.scale_qk = scale_qk
         | 
| 89 | 
            +
                    self.scale = dim_head**-0.5 if self.scale_qk else 1.0
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                    self.heads = heads
         | 
| 92 | 
            +
                    # for slice_size > 0 the attention score computation
         | 
| 93 | 
            +
                    # is split across the batch axis to save memory
         | 
| 94 | 
            +
                    # You can set slice_size with `set_attention_slice`
         | 
| 95 | 
            +
                    self.sliceable_head_dim = heads
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                    self.added_kv_proj_dim = added_kv_proj_dim
         | 
| 98 | 
            +
                    self.only_cross_attention = only_cross_attention
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                    if self.added_kv_proj_dim is None and self.only_cross_attention:
         | 
| 101 | 
            +
                        raise ValueError(
         | 
| 102 | 
            +
                            "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
         | 
| 103 | 
            +
                        )
         | 
| 104 | 
            +
             | 
| 105 | 
            +
                    if norm_num_groups is not None:
         | 
| 106 | 
            +
                        self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
         | 
| 107 | 
            +
                    else:
         | 
| 108 | 
            +
                        self.group_norm = None
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                    if spatial_norm_dim is not None:
         | 
| 111 | 
            +
                        self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
         | 
| 112 | 
            +
                    else:
         | 
| 113 | 
            +
                        self.spatial_norm = None
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                    if cross_attention_norm is None:
         | 
| 116 | 
            +
                        self.norm_cross = None
         | 
| 117 | 
            +
                    elif cross_attention_norm == "layer_norm":
         | 
| 118 | 
            +
                        self.norm_cross = nn.LayerNorm(cross_attention_dim)
         | 
| 119 | 
            +
                    elif cross_attention_norm == "group_norm":
         | 
| 120 | 
            +
                        if self.added_kv_proj_dim is not None:
         | 
| 121 | 
            +
                            # The given `encoder_hidden_states` are initially of shape
         | 
| 122 | 
            +
                            # (batch_size, seq_len, added_kv_proj_dim) before being projected
         | 
| 123 | 
            +
                            # to (batch_size, seq_len, cross_attention_dim). The norm is applied
         | 
| 124 | 
            +
                            # before the projection, so we need to use `added_kv_proj_dim` as
         | 
| 125 | 
            +
                            # the number of channels for the group norm.
         | 
| 126 | 
            +
                            norm_cross_num_channels = added_kv_proj_dim
         | 
| 127 | 
            +
                        else:
         | 
| 128 | 
            +
                            norm_cross_num_channels = cross_attention_dim
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                        self.norm_cross = nn.GroupNorm(
         | 
| 131 | 
            +
                            num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
         | 
| 132 | 
            +
                        )
         | 
| 133 | 
            +
                    else:
         | 
| 134 | 
            +
                        raise ValueError(
         | 
| 135 | 
            +
                            f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
         | 
| 136 | 
            +
                        )
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                    self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                    if not self.only_cross_attention:
         | 
| 141 | 
            +
                        # only relevant for the `AddedKVProcessor` classes
         | 
| 142 | 
            +
                        self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
         | 
| 143 | 
            +
                        self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
         | 
| 144 | 
            +
                    else:
         | 
| 145 | 
            +
                        self.to_k = None
         | 
| 146 | 
            +
                        self.to_v = None
         | 
| 147 | 
            +
             | 
| 148 | 
            +
                    if self.added_kv_proj_dim is not None:
         | 
| 149 | 
            +
                        self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim)
         | 
| 150 | 
            +
                        self.add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim)
         | 
| 151 | 
            +
             | 
| 152 | 
            +
                    self.to_out = nn.ModuleList([])
         | 
| 153 | 
            +
                    self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias))
         | 
| 154 | 
            +
                    self.to_out.append(nn.Dropout(dropout))
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                    # set attention processor
         | 
| 157 | 
            +
                    # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
         | 
| 158 | 
            +
                    # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
         | 
| 159 | 
            +
                    # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
         | 
| 160 | 
            +
                    if processor is None:
         | 
| 161 | 
            +
                        processor = (
         | 
| 162 | 
            +
                            AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
         | 
| 163 | 
            +
                        )
         | 
| 164 | 
            +
                    self.set_processor(processor)
         | 
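A hypothetical usage sketch of the `Attention` module just constructed (arbitrary shapes, assuming the vendored package imports cleanly). It mirrors how `attn1` and `attn2` are built and called in `attention.py` above: omit `cross_attention_dim` for self-attention; pass it, plus `encoder_hidden_states` at call time, for cross-attention.

```python
import torch
from diffusers.models.attention_processor import Attention  # the class defined in this file

hidden_states = torch.randn(2, 77, 320)  # (batch, tokens, query_dim); example shape only
context = torch.randn(2, 64, 768)        # e.g. text-encoder states

self_attn = Attention(query_dim=320, heads=8, dim_head=40)                            # like attn1
cross_attn = Attention(query_dim=320, cross_attention_dim=768, heads=8, dim_head=40)  # like attn2

out1 = self_attn(hidden_states)                                  # keys/values from hidden_states
out2 = cross_attn(hidden_states, encoder_hidden_states=context)  # keys/values from `context`
print(out1.shape, out2.shape)                                    # both torch.Size([2, 77, 320])
```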
| 165 | 
            +
             | 
| 166 | 
            +
                def set_use_memory_efficient_attention_xformers(
         | 
| 167 | 
            +
                    self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
         | 
| 168 | 
            +
                ):
         | 
| 169 | 
            +
                    is_lora = hasattr(self, "processor") and isinstance(
         | 
| 170 | 
            +
                        self.processor,
         | 
| 171 | 
            +
                        (LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor),
         | 
| 172 | 
            +
                    )
         | 
| 173 | 
            +
                    is_custom_diffusion = hasattr(self, "processor") and isinstance(
         | 
| 174 | 
            +
                        self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
         | 
| 175 | 
            +
                    )
         | 
| 176 | 
            +
                    is_added_kv_processor = hasattr(self, "processor") and isinstance(
         | 
| 177 | 
            +
                        self.processor,
         | 
| 178 | 
            +
                        (
         | 
| 179 | 
            +
                            AttnAddedKVProcessor,
         | 
| 180 | 
            +
                            AttnAddedKVProcessor2_0,
         | 
| 181 | 
            +
                            SlicedAttnAddedKVProcessor,
         | 
| 182 | 
            +
                            XFormersAttnAddedKVProcessor,
         | 
| 183 | 
            +
                            LoRAAttnAddedKVProcessor,
         | 
| 184 | 
            +
                        ),
         | 
| 185 | 
            +
                    )
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                    if use_memory_efficient_attention_xformers:
         | 
| 188 | 
            +
                        if is_added_kv_processor and (is_lora or is_custom_diffusion):
         | 
| 189 | 
            +
                            raise NotImplementedError(
         | 
| 190 | 
            +
                                f"Memory efficient attention is currently not supported for LoRA or custom diffuson for attention processor type {self.processor}"
         | 
| 191 | 
            +
                            )
         | 
| 192 | 
            +
                        if not is_xformers_available():
         | 
| 193 | 
            +
                            raise ModuleNotFoundError(
         | 
| 194 | 
            +
                                (
         | 
| 195 | 
            +
                                    "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
         | 
| 196 | 
            +
                                    " xformers"
         | 
| 197 | 
            +
                                ),
         | 
| 198 | 
            +
                                name="xformers",
         | 
| 199 | 
            +
                            )
         | 
| 200 | 
            +
                        elif not torch.cuda.is_available():
         | 
| 201 | 
            +
                            raise ValueError(
         | 
| 202 | 
            +
                                "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
         | 
| 203 | 
            +
                                " only available for GPU "
         | 
| 204 | 
            +
                            )
         | 
| 205 | 
            +
                        else:
         | 
| 206 | 
            +
                            try:
         | 
| 207 | 
            +
                                # Make sure we can run the memory efficient attention
         | 
| 208 | 
            +
                                _ = xformers.ops.memory_efficient_attention(
         | 
| 209 | 
            +
                                    torch.randn((1, 2, 40), device="cuda"),
         | 
| 210 | 
            +
                                    torch.randn((1, 2, 40), device="cuda"),
         | 
| 211 | 
            +
                                    torch.randn((1, 2, 40), device="cuda"),
         | 
| 212 | 
            +
                                )
         | 
| 213 | 
            +
                            except Exception as e:
         | 
| 214 | 
            +
                                raise e
         | 
| 215 | 
            +
             | 
| 216 | 
            +
                        if is_lora:
         | 
| 217 | 
            +
                            # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
         | 
| 218 | 
            +
                            # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
         | 
| 219 | 
            +
                            processor = LoRAXFormersAttnProcessor(
         | 
| 220 | 
            +
                                hidden_size=self.processor.hidden_size,
         | 
| 221 | 
            +
                                cross_attention_dim=self.processor.cross_attention_dim,
         | 
| 222 | 
            +
                                rank=self.processor.rank,
         | 
| 223 | 
            +
                                attention_op=attention_op,
         | 
| 224 | 
            +
                            )
         | 
| 225 | 
            +
                            processor.load_state_dict(self.processor.state_dict())
         | 
| 226 | 
            +
                            processor.to(self.processor.to_q_lora.up.weight.device)
            elif is_custom_diffusion:
                processor = CustomDiffusionXFormersAttnProcessor(
                    train_kv=self.processor.train_kv,
                    train_q_out=self.processor.train_q_out,
                    hidden_size=self.processor.hidden_size,
                    cross_attention_dim=self.processor.cross_attention_dim,
                    attention_op=attention_op,
                )
                processor.load_state_dict(self.processor.state_dict())
                if hasattr(self.processor, "to_k_custom_diffusion"):
                    processor.to(self.processor.to_k_custom_diffusion.weight.device)
            elif is_added_kv_processor:
                # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
                # which uses this type of cross attention ONLY because the attention mask of format
                # [0, ..., -10.000, ..., 0, ...,] is not supported
                # throw warning
                logger.info(
                    "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
                )
                processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
            else:
                processor = XFormersAttnProcessor(attention_op=attention_op)
        else:
            if is_lora:
                attn_processor_class = (
                    LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
                )
                processor = attn_processor_class(
                    hidden_size=self.processor.hidden_size,
                    cross_attention_dim=self.processor.cross_attention_dim,
                    rank=self.processor.rank,
                )
                processor.load_state_dict(self.processor.state_dict())
                processor.to(self.processor.to_q_lora.up.weight.device)
            elif is_custom_diffusion:
                processor = CustomDiffusionAttnProcessor(
                    train_kv=self.processor.train_kv,
                    train_q_out=self.processor.train_q_out,
                    hidden_size=self.processor.hidden_size,
                    cross_attention_dim=self.processor.cross_attention_dim,
                )
                processor.load_state_dict(self.processor.state_dict())
                if hasattr(self.processor, "to_k_custom_diffusion"):
                    processor.to(self.processor.to_k_custom_diffusion.weight.device)
            else:
                # set attention processor
                # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
                # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
                # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
                processor = (
                    AttnProcessor2_0()
                    if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
                    else AttnProcessor()
                )

        self.set_processor(processor)

    def set_attention_slice(self, slice_size):
        if slice_size is not None and slice_size > self.sliceable_head_dim:
            raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")

        if slice_size is not None and self.added_kv_proj_dim is not None:
            processor = SlicedAttnAddedKVProcessor(slice_size)
        elif slice_size is not None:
            processor = SlicedAttnProcessor(slice_size)
        elif self.added_kv_proj_dim is not None:
            processor = AttnAddedKVProcessor()
        else:
            # set attention processor
            # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
            # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
            # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
            processor = (
                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
            )

        self.set_processor(processor)

    def set_processor(self, processor: "AttnProcessor"):
        # if current processor is in `self._modules` and if passed `processor` is not, we need to
        # pop `processor` from `self._modules`
        if (
            hasattr(self, "processor")
            and isinstance(self.processor, torch.nn.Module)
            and not isinstance(processor, torch.nn.Module)
        ):
            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
            self._modules.pop("processor")

        self.processor = processor
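
    # [Editor note; not in the upstream diffusers file] `set_processor` is the
    # single switch point for every attention variant defined below. The
    # `_modules.pop("processor")` above matters because nn.Module-based
    # processors (e.g. the LoRA ones) register parameters; replacing them with a
    # plain-object processor must also drop those possibly trained weights.
    # A runnable usage sketch is added after this class definition.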

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
        # The `Attention` class can call different attention processors / attention functions
        # here we simply pass along all tensors to the selected processor class
        # For standard processors that are defined here, `**cross_attention_kwargs` is empty
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

    def batch_to_head_dim(self, tensor):
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def head_to_batch_dim(self, tensor, out_dim=3):
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3)

        if out_dim == 3:
            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)

        return tensor
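
    # [Editor note; not in the upstream diffusers file] Shape contract of the two
    # helpers above: `head_to_batch_dim` folds heads into the batch axis,
    # (batch, seq, heads * dim_head) -> (batch * heads, seq, dim_head), so that
    # `torch.bmm` runs one matmul per head; `batch_to_head_dim` is its exact
    # inverse. See the sketch after this class for a concrete shape check.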

    def get_attention_scores(self, query, key, attention_mask=None):
        dtype = query.dtype
        if self.upcast_attention:
            query = query.float()
            key = key.float()

        if attention_mask is None:
            baddbmm_input = torch.empty(
                query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
            )
            beta = 0
        else:
            baddbmm_input = attention_mask
            beta = 1

        attention_scores = torch.baddbmm(
            baddbmm_input,
            query,
            key.transpose(-1, -2),
            beta=beta,
            alpha=self.scale,
        )
        del baddbmm_input

        if self.upcast_softmax:
            attention_scores = attention_scores.float()

        attention_probs = attention_scores.softmax(dim=-1)
        del attention_scores

        attention_probs = attention_probs.to(dtype)

        return attention_probs
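
    # [Editor note; not in the upstream diffusers file] `get_attention_scores`
    # computes softmax(scale * Q @ K^T + mask) in a single fused `baddbmm`: the
    # additive mask is passed as the `input` matrix with `beta=1`, and an empty
    # buffer with `beta=0` when no mask is given, avoiding a separate add.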

    def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3):
        if batch_size is None:
            deprecate(
                "batch_size=None",
                "0.0.15",
                (
                    "Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect"
                    " attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to"
                    " `prepare_attention_mask` when preparing the attention_mask."
                ),
            )
            batch_size = 1

        head_size = self.heads
        if attention_mask is None:
            return attention_mask

        current_length: int = attention_mask.shape[-1]
        if current_length != target_length:
            if attention_mask.device.type == "mps":
                # HACK: MPS: Does not support padding by greater than dimension of input tensor.
                # Instead, we can manually construct the padding tensor.
                padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
                padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
                attention_mask = torch.cat([attention_mask, padding], dim=2)
            else:
                # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
                #       we want to instead pad by (0, remaining_length), where remaining_length is:
                #       remaining_length: int = target_length - current_length
                # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)

        if out_dim == 3:
            if attention_mask.shape[0] < batch_size * head_size:
                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
        elif out_dim == 4:
            attention_mask = attention_mask.unsqueeze(1)
            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)

        return attention_mask
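
    # [Editor note; not in the upstream diffusers file] `prepare_attention_mask`
    # right-pads the mask along the key axis (currently by `target_length`, see
    # the TODO above) and then repeats it per head, so a (batch, 1, key_len)
    # mask becomes (batch * heads, 1, padded_len) for `out_dim=3`, or gains a
    # head dimension for `out_dim=4` (the layout used by the 2_0 processors).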

    def norm_encoder_hidden_states(self, encoder_hidden_states):
        assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"

        if isinstance(self.norm_cross, nn.LayerNorm):
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
        elif isinstance(self.norm_cross, nn.GroupNorm):
            # Group norm norms along the channels dimension and expects
            # input to be in the shape of (N, C, *). In this case, we want
            # to norm along the hidden dimension, so we need to move
            # (batch_size, sequence_length, hidden_size) ->
            # (batch_size, hidden_size, sequence_length)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
        else:
            assert False

        return encoder_hidden_states
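

# --- Editor's illustrative sketch (not part of the upstream diffusers file) ---
# A minimal, hedged usage sketch of the `Attention` helpers above. It assumes the
# constructor accepts `query_dim` / `heads` / `dim_head` as in upstream diffusers
# (the constructor sits earlier in this file, outside this excerpt). The helper
# is hypothetical and is never called anywhere in the app.
def _example_attention_helpers():
    import torch

    attn = Attention(query_dim=64, heads=4, dim_head=16)
    attn.set_processor(AttnProcessor())                # plain bmm/softmax path

    x = torch.randn(2, 10, 64)                         # (batch, seq, heads * dim_head)
    q = attn.head_to_batch_dim(attn.to_q(x))           # (batch * heads, seq, dim_head)
    k = attn.head_to_batch_dim(attn.to_k(x))
    assert q.shape == (2 * 4, 10, 16)

    # Without a mask, get_attention_scores is softmax(scale * Q @ K^T).
    probs = attn.get_attention_scores(q, k)
    ref = (attn.scale * q @ k.transpose(-1, -2)).softmax(dim=-1)
    assert torch.allclose(probs, ref, atol=1e-5)

    # prepare_attention_mask pads along the key axis and repeats once per head.
    mask = torch.zeros(2, 1, 5)
    prepared = attn.prepare_attention_mask(mask, target_length=7, batch_size=2)
    assert prepared.shape[0] == 2 * attn.heads

    return attn(x)                                     # full forward through the processor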


class AttnProcessor:
    r"""
    Default processor for performing attention-related computations.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class LoRALinearLayer(nn.Module):
    def __init__(self, in_features, out_features, rank=4, network_alpha=None):
        super().__init__()

        if rank > min(in_features, out_features):
            raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")

        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
        self.network_alpha = network_alpha
        self.rank = rank

        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        dtype = self.down.weight.dtype

        down_hidden_states = self.down(hidden_states.to(dtype))
        up_hidden_states = self.up(down_hidden_states)

        if self.network_alpha is not None:
            up_hidden_states *= self.network_alpha / self.rank

        return up_hidden_states.to(orig_dtype)
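

# --- Editor's illustrative sketch (not part of the upstream diffusers file) ---
# `LoRALinearLayer` implements the rank-`r` update up(down(x)), optionally
# rescaled by network_alpha / rank. Because `up` is zero-initialised, a freshly
# constructed layer is an exact no-op, which is what allows it to be added on
# top of frozen projections. This hypothetical helper is never called.
def _example_lora_linear():
    import torch

    lora = LoRALinearLayer(in_features=64, out_features=64, rank=4)
    x = torch.randn(2, 10, 64)
    assert torch.equal(lora(x), torch.zeros(2, 10, 64))  # zero until trained

    # The effective weight delta is up.weight @ down.weight, of rank at most 4.
    delta = lora.up.weight @ lora.down.weight
    assert delta.shape == (64, 64)
    return delta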


class LoRAAttnProcessor(nn.Module):
    r"""
    Processor for implementing the LoRA attention mechanism.

    Args:
        hidden_size (`int`, *optional*):
            The hidden size of the attention layer.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the `encoder_hidden_states`.
        rank (`int`, defaults to 4):
            The dimension of the LoRA update matrices.
        network_alpha (`int`, *optional*):
            Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.rank = rank

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

    def __call__(
        self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
        query = attn.head_to_batch_dim(query)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)

        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
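

# --- Editor's illustrative sketch (not part of the upstream diffusers file) ---
# `LoRAAttnProcessor` runs the same attention math as `AttnProcessor`, but adds
# `scale * lora(x)` on top of each frozen q/k/v/out projection; with `scale=0.0`
# (or untrained, zero-initialised LoRA weights) it reproduces the base module.
# Same constructor assumption as above; this hypothetical helper is never called.
def _example_lora_attention():
    import torch

    attn = Attention(query_dim=64, heads=4, dim_head=16)
    attn.set_processor(AttnProcessor())
    x = torch.randn(2, 10, 64)
    base = attn(x)

    attn.set_processor(LoRAAttnProcessor(hidden_size=64, rank=4))
    with_lora = attn(x, scale=0.0)   # `scale` is forwarded via **cross_attention_kwargs
    assert torch.allclose(base, with_lora, atol=1e-6)
    return with_lora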


class CustomDiffusionAttnProcessor(nn.Module):
    r"""
    Processor for implementing attention for the Custom Diffusion method.

    Args:
        train_kv (`bool`, defaults to `True`):
            Whether to newly train the key and value matrices corresponding to the text features.
        train_q_out (`bool`, defaults to `True`):
            Whether to newly train query matrices corresponding to the latent image features.
        hidden_size (`int`, *optional*, defaults to `None`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`, *optional*, defaults to `None`):
            The number of channels in the `encoder_hidden_states`.
        out_bias (`bool`, defaults to `True`):
            Whether to include the bias parameter in `train_q_out`.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability to use.
    """

    def __init__(
        self,
        train_kv=True,
        train_q_out=True,
        hidden_size=None,
        cross_attention_dim=None,
        out_bias=True,
        dropout=0.0,
    ):
        super().__init__()
        self.train_kv = train_kv
        self.train_q_out = train_q_out

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim

        # `_custom_diffusion` id for easy serialization and loading.
        if self.train_kv:
            self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
            self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        if self.train_q_out:
            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
            self.to_out_custom_diffusion = nn.ModuleList([])
            self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
            self.to_out_custom_diffusion.append(nn.Dropout(dropout))

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        if self.train_q_out:
            query = self.to_q_custom_diffusion(hidden_states)
        else:
            query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            crossattn = False
            encoder_hidden_states = hidden_states
        else:
            crossattn = True
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        if self.train_kv:
            key = self.to_k_custom_diffusion(encoder_hidden_states)
            value = self.to_v_custom_diffusion(encoder_hidden_states)
        else:
            key = attn.to_k(encoder_hidden_states)
            value = attn.to_v(encoder_hidden_states)

        if crossattn:
            detach = torch.ones_like(key)
            detach[:, :1, :] = detach[:, :1, :] * 0.0
            key = detach * key + (1 - detach) * key.detach()
            value = detach * value + (1 - detach) * value.detach()

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        if self.train_q_out:
            # linear proj
            hidden_states = self.to_out_custom_diffusion[0](hidden_states)
            # dropout
            hidden_states = self.to_out_custom_diffusion[1](hidden_states)
        else:
            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)

        return hidden_states


class AttnAddedKVProcessor:
    r"""
    Processor for performing attention-related computations with extra learnable key and value matrices for the text
    encoder.
    """

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        residual = hidden_states
        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

        if not attn.only_cross_attention:
            key = attn.to_k(hidden_states)
            value = attn.to_v(hidden_states)
            key = attn.head_to_batch_dim(key)
            value = attn.head_to_batch_dim(value)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
        else:
            key = encoder_hidden_states_key_proj
            value = encoder_hidden_states_value_proj

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
        hidden_states = hidden_states + residual

        return hidden_states


class AttnAddedKVProcessor2_0:
    r"""
    Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra
    learnable key and value matrices for the text encoder.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        residual = hidden_states
        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query, out_dim=4)

        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)

        if not attn.only_cross_attention:
            key = attn.to_k(hidden_states)
            value = attn.to_v(hidden_states)
            key = attn.head_to_batch_dim(key, out_dim=4)
            value = attn.head_to_batch_dim(value, out_dim=4)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
        else:
            key = encoder_hidden_states_key_proj
            value = encoder_hidden_states_value_proj

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
        hidden_states = hidden_states + residual

        return hidden_states
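

# --- Editor's illustrative sketch (not part of the upstream diffusers file) ---
# `AttnAddedKVProcessor2_0` differs from `AttnAddedKVProcessor` only in the
# kernel: tensors stay in (batch, heads, seq, head_dim) layout and PyTorch 2.x's
# fused `F.scaled_dot_product_attention` replaces the explicit baddbmm + softmax
# + bmm sequence. A standalone numerical check of that equivalence (plain torch,
# requires PyTorch >= 2.0); this hypothetical helper is never called.
def _example_sdpa_equivalence():
    import math

    import torch
    import torch.nn.functional as F

    q = torch.randn(2, 4, 10, 16)   # (batch, heads, q_len, head_dim)
    k = torch.randn(2, 4, 7, 16)
    v = torch.randn(2, 4, 7, 16)
    fused = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    manual = ((q @ k.transpose(-1, -2)) / math.sqrt(16)).softmax(dim=-1) @ v
    assert torch.allclose(fused, manual, atol=1e-4)
    return fused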
         | 
| 833 | 
            +
             | 
| 834 | 
            +
             | 
| 835 | 
            +
            class LoRAAttnAddedKVProcessor(nn.Module):
         | 
| 836 | 
            +
                r"""
         | 
| 837 | 
            +
                Processor for implementing the LoRA attention mechanism with extra learnable key and value matrices for the text
         | 
| 838 | 
            +
                encoder.
         | 
| 839 | 
            +
             | 
| 840 | 
            +
                Args:
         | 
| 841 | 
            +
                    hidden_size (`int`, *optional*):
         | 
| 842 | 
            +
                        The hidden size of the attention layer.
         | 
| 843 | 
            +
                    cross_attention_dim (`int`, *optional*, defaults to `None`):
         | 
| 844 | 
            +
                        The number of channels in the `encoder_hidden_states`.
         | 
| 845 | 
            +
                    rank (`int`, defaults to 4):
         | 
| 846 | 
            +
                        The dimension of the LoRA update matrices.
         | 
| 847 | 
            +
             | 
| 848 | 
            +
                """
         | 
| 849 | 
            +
             | 
| 850 | 
            +
                def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
         | 
| 851 | 
            +
                    super().__init__()
         | 
| 852 | 
            +
             | 
| 853 | 
            +
                    self.hidden_size = hidden_size
         | 
| 854 | 
            +
                    self.cross_attention_dim = cross_attention_dim
         | 
| 855 | 
            +
                    self.rank = rank
         | 
| 856 | 
            +
             | 
| 857 | 
            +
                    self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
         | 
| 858 | 
            +
                    self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
         | 
| 859 | 
            +
                    self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
         | 
| 860 | 
            +
                    self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
         | 
| 861 | 
            +
                    self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
         | 
| 862 | 
            +
                    self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
         | 
| 863 | 
            +
             | 
| 864 | 
            +
                def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
         | 
| 865 | 
            +
                    residual = hidden_states
         | 
| 866 | 
            +
                    hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
         | 
| 867 | 
            +
                    batch_size, sequence_length, _ = hidden_states.shape
         | 
| 868 | 
            +
             | 
| 869 | 
            +
                    attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
         | 
| 870 | 
            +
             | 
| 871 | 
            +
                    if encoder_hidden_states is None:
         | 
| 872 | 
            +
                        encoder_hidden_states = hidden_states
         | 
| 873 | 
            +
                    elif attn.norm_cross:
         | 
| 874 | 
            +
                        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
         | 
| 875 | 
            +
             | 
| 876 | 
            +
                    hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
         | 
| 877 | 
            +
             | 
| 878 | 
            +
                    query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
         | 
| 879 | 
            +
                    query = attn.head_to_batch_dim(query)
         | 
| 880 | 
            +
             | 
| 881 | 
            +
                    encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + scale * self.add_k_proj_lora(
         | 
| 882 | 
            +
                        encoder_hidden_states
         | 
| 883 | 
            +
                    )
         | 
| 884 | 
            +
                    encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + scale * self.add_v_proj_lora(
         | 
| 885 | 
            +
                        encoder_hidden_states
         | 
| 886 | 
            +
                    )
         | 
| 887 | 
            +
                    encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
         | 
| 888 | 
            +
                    encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
         | 
| 889 | 
            +
             | 
| 890 | 
            +
                    if not attn.only_cross_attention:
         | 
| 891 | 
            +
                        key = attn.to_k(hidden_states) + scale * self.to_k_lora(hidden_states)
         | 
| 892 | 
            +
                        value = attn.to_v(hidden_states) + scale * self.to_v_lora(hidden_states)
         | 
| 893 | 
            +
                        key = attn.head_to_batch_dim(key)
         | 
| 894 | 
            +
                        value = attn.head_to_batch_dim(value)
         | 
| 895 | 
            +
                        key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
         | 
| 896 | 
            +
                        value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
         | 
| 897 | 
            +
                    else:
         | 
| 898 | 
            +
                        key = encoder_hidden_states_key_proj
         | 
| 899 | 
            +
                        value = encoder_hidden_states_value_proj
         | 
| 900 | 
            +
             | 
| 901 | 
            +
                    attention_probs = attn.get_attention_scores(query, key, attention_mask)
         | 
| 902 | 
            +
                    hidden_states = torch.bmm(attention_probs, value)
         | 
| 903 | 
            +
                    hidden_states = attn.batch_to_head_dim(hidden_states)
         | 
| 904 | 
            +
             | 
| 905 | 
            +
                    # linear proj
         | 
| 906 | 
            +
                    hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
         | 
| 907 | 
            +
                    # dropout
         | 
| 908 | 
            +
                    hidden_states = attn.to_out[1](hidden_states)
         | 
| 909 | 
            +
             | 
| 910 | 
            +
                    hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
         | 
| 911 | 
            +
                    hidden_states = hidden_states + residual
         | 
| 912 | 
            +
             | 
| 913 | 
            +
                    return hidden_states
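
The `*_lora` terms above follow the standard LoRA recipe: each frozen projection is augmented by a small low-rank layer whose output is blended in with a runtime `scale`. A minimal sketch of that pattern, assuming `LoRALinearLayer` is the low-rank module defined earlier in this file and using purely illustrative sizes:

import torch
import torch.nn as nn
from diffusers.models.attention_processor import LoRALinearLayer

hidden = 320
frozen = nn.Linear(hidden, hidden)              # stands in for a pretrained attn.to_q
lora = LoRALinearLayer(hidden, hidden, rank=4)  # low-rank update, rank << hidden

x = torch.randn(2, 64, hidden)                  # (batch, tokens, channels)
scale = 1.0                                     # scale=0.0 disables the LoRA term entirely
out = frozen(x) + scale * lora(x)
print(out.shape)                                # torch.Size([2, 64, 320])
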
         | 
| 914 | 
            +
             | 
| 915 | 
            +
             | 
| 916 | 
            +
            class XFormersAttnAddedKVProcessor:
         | 
| 917 | 
            +
                r"""
         | 
| 918 | 
            +
                Processor for implementing memory efficient attention using xFormers.
         | 
| 919 | 
            +
             | 
| 920 | 
            +
                Args:
         | 
| 921 | 
            +
                    attention_op (`Callable`, *optional*, defaults to `None`):
         | 
| 922 | 
            +
                        The base
         | 
| 923 | 
            +
                        [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
         | 
| 924 | 
            +
                        use as the attention operator. It is recommended to set it to `None` and allow xFormers to choose the best
         | 
| 925 | 
            +
                        operator.
         | 
| 926 | 
            +
                """
         | 
| 927 | 
            +
             | 
| 928 | 
            +
                def __init__(self, attention_op: Optional[Callable] = None):
         | 
| 929 | 
            +
                    self.attention_op = attention_op
         | 
| 930 | 
            +
             | 
| 931 | 
            +
                def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
         | 
| 932 | 
            +
                    residual = hidden_states
         | 
| 933 | 
            +
                    hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
         | 
| 934 | 
            +
                    batch_size, sequence_length, _ = hidden_states.shape
         | 
| 935 | 
            +
             | 
| 936 | 
            +
                    attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
         | 
| 937 | 
            +
             | 
| 938 | 
            +
                    if encoder_hidden_states is None:
         | 
| 939 | 
            +
                        encoder_hidden_states = hidden_states
         | 
| 940 | 
            +
                    elif attn.norm_cross:
         | 
| 941 | 
            +
                        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
         | 
| 942 | 
            +
             | 
| 943 | 
            +
                    hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
         | 
| 944 | 
            +
             | 
| 945 | 
            +
                    query = attn.to_q(hidden_states)
         | 
| 946 | 
            +
                    query = attn.head_to_batch_dim(query)
         | 
| 947 | 
            +
             | 
| 948 | 
            +
                    encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
         | 
| 949 | 
            +
                    encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
         | 
| 950 | 
            +
                    encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
         | 
| 951 | 
            +
                    encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
         | 
| 952 | 
            +
             | 
| 953 | 
            +
                    if not attn.only_cross_attention:
         | 
| 954 | 
            +
                        key = attn.to_k(hidden_states)
         | 
| 955 | 
            +
                        value = attn.to_v(hidden_states)
         | 
| 956 | 
            +
                        key = attn.head_to_batch_dim(key)
         | 
| 957 | 
            +
                        value = attn.head_to_batch_dim(value)
         | 
| 958 | 
            +
                        key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
         | 
| 959 | 
            +
                        value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
         | 
| 960 | 
            +
                    else:
         | 
| 961 | 
            +
                        key = encoder_hidden_states_key_proj
         | 
| 962 | 
            +
                        value = encoder_hidden_states_value_proj
         | 
| 963 | 
            +
             | 
| 964 | 
            +
                    hidden_states = xformers.ops.memory_efficient_attention(
         | 
| 965 | 
            +
                        query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
         | 
| 966 | 
            +
                    )
         | 
| 967 | 
            +
                    hidden_states = hidden_states.to(query.dtype)
         | 
| 968 | 
            +
                    hidden_states = attn.batch_to_head_dim(hidden_states)
         | 
| 969 | 
            +
             | 
| 970 | 
            +
                    # linear proj
         | 
| 971 | 
            +
                    hidden_states = attn.to_out[0](hidden_states)
         | 
| 972 | 
            +
                    # dropout
         | 
| 973 | 
            +
                    hidden_states = attn.to_out[1](hidden_states)
         | 
| 974 | 
            +
             | 
| 975 | 
            +
                    hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
         | 
| 976 | 
            +
                    hidden_states = hidden_states + residual
         | 
| 977 | 
            +
             | 
| 978 | 
            +
                    return hidden_states
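
Because this processor calls into `xformers.ops.memory_efficient_attention`, it should only be selected when xformers is actually importable. A hedged selection sketch, assuming `AttnAddedKVProcessor` is the plain added-KV processor defined earlier in this file:

import importlib.util

from diffusers.models.attention_processor import (
    AttnAddedKVProcessor,
    XFormersAttnAddedKVProcessor,
)

# Prefer the memory-efficient variant only when xformers is available.
if importlib.util.find_spec("xformers") is not None:
    processor = XFormersAttnAddedKVProcessor()
else:
    processor = AttnAddedKVProcessor()
print(type(processor).__name__)
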
         | 
| 979 | 
            +
             | 
| 980 | 
            +
             | 
| 981 | 
            +
            class XFormersAttnProcessor:
         | 
| 982 | 
            +
                r"""
         | 
| 983 | 
            +
                Processor for implementing memory efficient attention using xFormers.
         | 
| 984 | 
            +
             | 
| 985 | 
            +
                Args:
         | 
| 986 | 
            +
                    attention_op (`Callable`, *optional*, defaults to `None`):
         | 
| 987 | 
            +
                        The base
         | 
| 988 | 
            +
                        [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
         | 
| 989 | 
            +
                        use as the attention operator. It is recommended to set it to `None` and allow xFormers to choose the best
         | 
| 990 | 
            +
                        operator.
         | 
| 991 | 
            +
                """
         | 
| 992 | 
            +
             | 
| 993 | 
            +
                def __init__(self, attention_op: Optional[Callable] = None):
         | 
| 994 | 
            +
                    self.attention_op = attention_op
         | 
| 995 | 
            +
             | 
| 996 | 
            +
                def __call__(
         | 
| 997 | 
            +
                    self,
         | 
| 998 | 
            +
                    attn: Attention,
         | 
| 999 | 
            +
                    hidden_states: torch.FloatTensor,
         | 
| 1000 | 
            +
                    encoder_hidden_states: Optional[torch.FloatTensor] = None,
         | 
| 1001 | 
            +
                    attention_mask: Optional[torch.FloatTensor] = None,
         | 
| 1002 | 
            +
                    temb: Optional[torch.FloatTensor] = None,
         | 
| 1003 | 
            +
                ):
         | 
| 1004 | 
            +
                    residual = hidden_states
         | 
| 1005 | 
            +
             | 
| 1006 | 
            +
                    if attn.spatial_norm is not None:
         | 
| 1007 | 
            +
                        hidden_states = attn.spatial_norm(hidden_states, temb)
         | 
| 1008 | 
            +
             | 
| 1009 | 
            +
                    input_ndim = hidden_states.ndim
         | 
| 1010 | 
            +
             | 
| 1011 | 
            +
                    if input_ndim == 4:
         | 
| 1012 | 
            +
                        batch_size, channel, height, width = hidden_states.shape
         | 
| 1013 | 
            +
                        hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
         | 
| 1014 | 
            +
             | 
| 1015 | 
            +
                    batch_size, key_tokens, _ = (
         | 
| 1016 | 
            +
                        hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
         | 
| 1017 | 
            +
                    )
         | 
| 1018 | 
            +
             | 
| 1019 | 
            +
                    attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
         | 
| 1020 | 
            +
                    if attention_mask is not None:
         | 
| 1021 | 
            +
                        # expand our mask's singleton query_tokens dimension:
         | 
| 1022 | 
            +
                        #   [batch*heads,            1, key_tokens] ->
         | 
| 1023 | 
            +
                        #   [batch*heads, query_tokens, key_tokens]
         | 
| 1024 | 
            +
                        # so that it can be added as a bias onto the attention scores that xformers computes:
         | 
| 1025 | 
            +
                        #   [batch*heads, query_tokens, key_tokens]
         | 
| 1026 | 
            +
                        # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
         | 
| 1027 | 
            +
                        _, query_tokens, _ = hidden_states.shape
         | 
| 1028 | 
            +
                        attention_mask = attention_mask.expand(-1, query_tokens, -1)
         | 
| 1029 | 
            +
             | 
| 1030 | 
            +
                    if attn.group_norm is not None:
         | 
| 1031 | 
            +
                        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
         | 
| 1032 | 
            +
             | 
| 1033 | 
            +
                    query = attn.to_q(hidden_states)
         | 
| 1034 | 
            +
             | 
| 1035 | 
            +
                    if encoder_hidden_states is None:
         | 
| 1036 | 
            +
                        encoder_hidden_states = hidden_states
         | 
| 1037 | 
            +
                    elif attn.norm_cross:
         | 
| 1038 | 
            +
                        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
         | 
| 1039 | 
            +
             | 
| 1040 | 
            +
                    key = attn.to_k(encoder_hidden_states)
         | 
| 1041 | 
            +
                    value = attn.to_v(encoder_hidden_states)
         | 
| 1042 | 
            +
             | 
| 1043 | 
            +
                    query = attn.head_to_batch_dim(query).contiguous()
         | 
| 1044 | 
            +
                    key = attn.head_to_batch_dim(key).contiguous()
         | 
| 1045 | 
            +
                    value = attn.head_to_batch_dim(value).contiguous()
         | 
| 1046 | 
            +
             | 
| 1047 | 
            +
                    hidden_states = xformers.ops.memory_efficient_attention(
         | 
| 1048 | 
            +
                        query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
         | 
| 1049 | 
            +
                    )
         | 
| 1050 | 
            +
                    hidden_states = hidden_states.to(query.dtype)
         | 
| 1051 | 
            +
                    hidden_states = attn.batch_to_head_dim(hidden_states)
         | 
| 1052 | 
            +
             | 
| 1053 | 
            +
                    # linear proj
         | 
| 1054 | 
            +
                    hidden_states = attn.to_out[0](hidden_states)
         | 
| 1055 | 
            +
                    # dropout
         | 
| 1056 | 
            +
                    hidden_states = attn.to_out[1](hidden_states)
         | 
| 1057 | 
            +
             | 
| 1058 | 
            +
                    if input_ndim == 4:
         | 
| 1059 | 
            +
                        hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
         | 
| 1060 | 
            +
             | 
| 1061 | 
            +
                    if attn.residual_connection:
         | 
| 1062 | 
            +
                        hidden_states = hidden_states + residual
         | 
| 1063 | 
            +
             | 
| 1064 | 
            +
                    hidden_states = hidden_states / attn.rescale_output_factor
         | 
| 1065 | 
            +
             | 
| 1066 | 
            +
                    return hidden_states
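
A usage sketch for swapping a standard `Attention` layer over to this processor; the layer sizes are illustrative, and the assignment is guarded so the snippet also works when xformers is not installed (the forward pass itself typically requires a CUDA device as well):

from diffusers.models.attention_processor import Attention, XFormersAttnProcessor

attn = Attention(query_dim=320, heads=8, dim_head=40, cross_attention_dim=1024)
try:
    import xformers  # noqa: F401
    attn.set_processor(XFormersAttnProcessor(attention_op=None))  # let xFormers pick the op
    print("using xformers memory-efficient attention")
except ImportError:
    print("xformers not installed; keeping the default processor")
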
         | 
| 1067 | 
            +
             | 
| 1068 | 
            +
             | 
| 1069 | 
            +
            class AttnProcessor2_0:
         | 
| 1070 | 
            +
                r"""
         | 
| 1071 | 
            +
                Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
         | 
| 1072 | 
            +
                """
         | 
| 1073 | 
            +
             | 
| 1074 | 
            +
                def __init__(self):
         | 
| 1075 | 
            +
                    if not hasattr(F, "scaled_dot_product_attention"):
         | 
| 1076 | 
            +
                        raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
         | 
| 1077 | 
            +
             | 
| 1078 | 
            +
                def __call__(
         | 
| 1079 | 
            +
                    self,
         | 
| 1080 | 
            +
                    attn: Attention,
         | 
| 1081 | 
            +
                    hidden_states,
         | 
| 1082 | 
            +
                    encoder_hidden_states=None,
         | 
| 1083 | 
            +
                    attention_mask=None,
         | 
| 1084 | 
            +
                    temb=None,
         | 
| 1085 | 
            +
                ):
         | 
| 1086 | 
            +
                    residual = hidden_states
         | 
| 1087 | 
            +
             | 
| 1088 | 
            +
                    if attn.spatial_norm is not None:
         | 
| 1089 | 
            +
                        hidden_states = attn.spatial_norm(hidden_states, temb)
         | 
| 1090 | 
            +
             | 
| 1091 | 
            +
                    input_ndim = hidden_states.ndim
         | 
| 1092 | 
            +
             | 
| 1093 | 
            +
                    if input_ndim == 4:
         | 
| 1094 | 
            +
                        batch_size, channel, height, width = hidden_states.shape
         | 
| 1095 | 
            +
                        hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
         | 
| 1096 | 
            +
             | 
| 1097 | 
            +
                    batch_size, sequence_length, _ = (
         | 
| 1098 | 
            +
                        hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
         | 
| 1099 | 
            +
                    )
         | 
| 1100 | 
            +
                    inner_dim = hidden_states.shape[-1]
         | 
| 1101 | 
            +
             | 
| 1102 | 
            +
                    if attention_mask is not None:
         | 
| 1103 | 
            +
                        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
         | 
| 1104 | 
            +
                        # scaled_dot_product_attention expects attention_mask shape to be
         | 
| 1105 | 
            +
                        # (batch, heads, source_length, target_length)
         | 
| 1106 | 
            +
                        attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
         | 
| 1107 | 
            +
             | 
| 1108 | 
            +
                    if attn.group_norm is not None:
         | 
| 1109 | 
            +
                        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
         | 
| 1110 | 
            +
             | 
| 1111 | 
            +
                    query = attn.to_q(hidden_states)
         | 
| 1112 | 
            +
             | 
| 1113 | 
            +
                    if encoder_hidden_states is None:
         | 
| 1114 | 
            +
                        encoder_hidden_states = hidden_states
         | 
| 1115 | 
            +
                    elif attn.norm_cross:
         | 
| 1116 | 
            +
                        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
         | 
| 1117 | 
            +
             | 
| 1118 | 
            +
                    key = attn.to_k(encoder_hidden_states)
         | 
| 1119 | 
            +
                    value = attn.to_v(encoder_hidden_states)
         | 
| 1120 | 
            +
             | 
| 1121 | 
            +
                    head_dim = inner_dim // attn.heads
         | 
| 1122 | 
            +
                    query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         | 
| 1123 | 
            +
                    key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         | 
| 1124 | 
            +
                    value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         | 
| 1125 | 
            +
             | 
| 1126 | 
            +
                    # the output of sdp = (batch, num_heads, seq_len, head_dim)
         | 
| 1127 | 
            +
                    # TODO: add support for attn.scale when we move to Torch 2.1
         | 
| 1128 | 
            +
                    hidden_states = F.scaled_dot_product_attention(
         | 
| 1129 | 
            +
                        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         | 
| 1130 | 
            +
                    )
         | 
| 1131 | 
            +
             | 
| 1132 | 
            +
                    hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         | 
| 1133 | 
            +
                    hidden_states = hidden_states.to(query.dtype)
         | 
| 1134 | 
            +
             | 
| 1135 | 
            +
                    # linear proj
         | 
| 1136 | 
            +
                    hidden_states = attn.to_out[0](hidden_states)
         | 
| 1137 | 
            +
                    # dropout
         | 
| 1138 | 
            +
                    hidden_states = attn.to_out[1](hidden_states)
         | 
| 1139 | 
            +
             | 
| 1140 | 
            +
                    if input_ndim == 4:
         | 
| 1141 | 
            +
                        hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
         | 
| 1142 | 
            +
             | 
| 1143 | 
            +
                    if attn.residual_connection:
         | 
| 1144 | 
            +
                        hidden_states = hidden_states + residual
         | 
| 1145 | 
            +
             | 
| 1146 | 
            +
                    hidden_states = hidden_states / attn.rescale_output_factor
         | 
| 1147 | 
            +
             | 
| 1148 | 
            +
                    return hidden_states
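
A minimal end-to-end sketch: build a cross-attention layer, pick `AttnProcessor2_0` when PyTorch 2.0's SDPA is available, and fall back to the vanilla `AttnProcessor` (assumed to be defined earlier in this file) otherwise. Shapes are illustrative only.

import torch
import torch.nn.functional as F
from diffusers.models.attention_processor import Attention, AttnProcessor, AttnProcessor2_0

attn = Attention(query_dim=320, heads=8, dim_head=40, cross_attention_dim=1024)
use_sdpa = hasattr(F, "scaled_dot_product_attention")  # PyTorch >= 2.0
attn.set_processor(AttnProcessor2_0() if use_sdpa else AttnProcessor())

hidden_states = torch.randn(2, 64, 320)           # (batch, query tokens, channels)
encoder_hidden_states = torch.randn(2, 77, 1024)  # e.g. frozen text-encoder features
out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
print(out.shape)                                  # torch.Size([2, 64, 320])
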
         | 
| 1149 | 
            +
             | 
| 1150 | 
            +
             | 
| 1151 | 
            +
            class LoRAXFormersAttnProcessor(nn.Module):
         | 
| 1152 | 
            +
                r"""
         | 
| 1153 | 
            +
                Processor for implementing the LoRA attention mechanism with memory efficient attention using xFormers.
         | 
| 1154 | 
            +
             | 
| 1155 | 
            +
                Args:
         | 
| 1156 | 
            +
                    hidden_size (`int`):
         | 
| 1157 | 
            +
                        The hidden size of the attention layer.
         | 
| 1158 | 
            +
                    cross_attention_dim (`int`):
         | 
| 1159 | 
            +
                        The number of channels in the `encoder_hidden_states`.
         | 
| 1160 | 
            +
                    rank (`int`, defaults to 4):
         | 
| 1161 | 
            +
                        The dimension of the LoRA update matrices.
         | 
| 1162 | 
            +
                    attention_op (`Callable`, *optional*, defaults to `None`):
         | 
| 1163 | 
            +
                        The base
         | 
| 1164 | 
            +
                        [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
         | 
| 1165 | 
            +
                        use as the attention operator. It is recommended to set it to `None` and allow xFormers to choose the best
         | 
| 1166 | 
            +
                        operator.
         | 
| 1167 | 
            +
                    network_alpha (`int`, *optional*):
         | 
| 1168 | 
            +
                        Equivalent to `alpha`, but its usage is specific to Kohya (A1111) style LoRAs.
         | 
| 1169 | 
            +
             | 
| 1170 | 
            +
                """
         | 
| 1171 | 
            +
             | 
| 1172 | 
            +
                def __init__(
         | 
| 1173 | 
            +
                    self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None, network_alpha=None
         | 
| 1174 | 
            +
                ):
         | 
| 1175 | 
            +
                    super().__init__()
         | 
| 1176 | 
            +
             | 
| 1177 | 
            +
                    self.hidden_size = hidden_size
         | 
| 1178 | 
            +
                    self.cross_attention_dim = cross_attention_dim
         | 
| 1179 | 
            +
                    self.rank = rank
         | 
| 1180 | 
            +
                    self.attention_op = attention_op
         | 
| 1181 | 
            +
             | 
| 1182 | 
            +
                    self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
         | 
| 1183 | 
            +
                    self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
         | 
| 1184 | 
            +
                    self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
         | 
| 1185 | 
            +
                    self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
         | 
| 1186 | 
            +
             | 
| 1187 | 
            +
                def __call__(
         | 
| 1188 | 
            +
                    self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
         | 
| 1189 | 
            +
                ):
         | 
| 1190 | 
            +
                    residual = hidden_states
         | 
| 1191 | 
            +
             | 
| 1192 | 
            +
                    if attn.spatial_norm is not None:
         | 
| 1193 | 
            +
                        hidden_states = attn.spatial_norm(hidden_states, temb)
         | 
| 1194 | 
            +
             | 
| 1195 | 
            +
                    input_ndim = hidden_states.ndim
         | 
| 1196 | 
            +
             | 
| 1197 | 
            +
                    if input_ndim == 4:
         | 
| 1198 | 
            +
                        batch_size, channel, height, width = hidden_states.shape
         | 
| 1199 | 
            +
                        hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
         | 
| 1200 | 
            +
             | 
| 1201 | 
            +
                    batch_size, sequence_length, _ = (
         | 
| 1202 | 
            +
                        hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
         | 
| 1203 | 
            +
                    )
         | 
| 1204 | 
            +
                    attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
         | 
| 1205 | 
            +
             | 
| 1206 | 
            +
                    if attn.group_norm is not None:
         | 
| 1207 | 
            +
                        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
         | 
| 1208 | 
            +
             | 
| 1209 | 
            +
                    query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
         | 
| 1210 | 
            +
                    query = attn.head_to_batch_dim(query).contiguous()
         | 
| 1211 | 
            +
             | 
| 1212 | 
            +
                    if encoder_hidden_states is None:
         | 
| 1213 | 
            +
                        encoder_hidden_states = hidden_states
         | 
| 1214 | 
            +
                    elif attn.norm_cross:
         | 
| 1215 | 
            +
                        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
         | 
| 1216 | 
            +
             | 
| 1217 | 
            +
                    key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
         | 
| 1218 | 
            +
                    value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
         | 
| 1219 | 
            +
             | 
| 1220 | 
            +
                    key = attn.head_to_batch_dim(key).contiguous()
         | 
| 1221 | 
            +
                    value = attn.head_to_batch_dim(value).contiguous()
         | 
| 1222 | 
            +
             | 
| 1223 | 
            +
                    hidden_states = xformers.ops.memory_efficient_attention(
         | 
| 1224 | 
            +
                        query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
         | 
| 1225 | 
            +
                    )
         | 
| 1226 | 
            +
                    hidden_states = attn.batch_to_head_dim(hidden_states)
         | 
| 1227 | 
            +
             | 
| 1228 | 
            +
                    # linear proj
         | 
| 1229 | 
            +
                    hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
         | 
| 1230 | 
            +
                    # dropout
         | 
| 1231 | 
            +
                    hidden_states = attn.to_out[1](hidden_states)
         | 
| 1232 | 
            +
             | 
| 1233 | 
            +
                    if input_ndim == 4:
         | 
| 1234 | 
            +
                        hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
         | 
| 1235 | 
            +
             | 
| 1236 | 
            +
                    if attn.residual_connection:
         | 
| 1237 | 
            +
                        hidden_states = hidden_states + residual
         | 
| 1238 | 
            +
             | 
| 1239 | 
            +
                    hidden_states = hidden_states / attn.rescale_output_factor
         | 
| 1240 | 
            +
             | 
| 1241 | 
            +
                    return hidden_states
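
This class and `LoRAAttnProcessor2_0` (defined just below) accept the same core arguments (`hidden_size`, `cross_attention_dim`, `rank`, `network_alpha`), so the backend can be chosen per environment. A small sketch with illustrative sizes:

import torch.nn.functional as F
from diffusers.models.attention_processor import LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor

lora_kwargs = dict(hidden_size=320, cross_attention_dim=1024, rank=4)
if hasattr(F, "scaled_dot_product_attention"):      # PyTorch >= 2.0: use SDPA
    lora_proc = LoRAAttnProcessor2_0(**lora_kwargs)
else:                                               # otherwise assume xformers is installed
    lora_proc = LoRAXFormersAttnProcessor(**lora_kwargs)
print(sum(p.numel() for p in lora_proc.parameters()))  # number of LoRA parameters
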
         | 
| 1242 | 
            +
             | 
| 1243 | 
            +
             | 
| 1244 | 
            +
            class LoRAAttnProcessor2_0(nn.Module):
         | 
| 1245 | 
            +
                r"""
         | 
| 1246 | 
            +
                Processor for implementing the LoRA attention mechanism using PyTorch 2.0's memory-efficient scaled dot-product
         | 
| 1247 | 
            +
                attention.
         | 
| 1248 | 
            +
             | 
| 1249 | 
            +
                Args:
         | 
| 1250 | 
            +
                    hidden_size (`int`):
         | 
| 1251 | 
            +
                        The hidden size of the attention layer.
         | 
| 1252 | 
            +
                    cross_attention_dim (`int`, *optional*):
         | 
| 1253 | 
            +
                        The number of channels in the `encoder_hidden_states`.
         | 
| 1254 | 
            +
                    rank (`int`, defaults to 4):
         | 
| 1255 | 
            +
                        The dimension of the LoRA update matrices.
         | 
| 1256 | 
            +
                    network_alpha (`int`, *optional*):
         | 
| 1257 | 
            +
                        Equivalent to `alpha`, but its usage is specific to Kohya (A1111) style LoRAs.
         | 
| 1258 | 
            +
                """
         | 
| 1259 | 
            +
             | 
| 1260 | 
            +
                def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
         | 
| 1261 | 
            +
                    super().__init__()
         | 
| 1262 | 
            +
                    if not hasattr(F, "scaled_dot_product_attention"):
         | 
| 1263 | 
            +
                        raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
         | 
| 1264 | 
            +
             | 
| 1265 | 
            +
                    self.hidden_size = hidden_size
         | 
| 1266 | 
            +
                    self.cross_attention_dim = cross_attention_dim
         | 
| 1267 | 
            +
                    self.rank = rank
         | 
| 1268 | 
            +
             | 
| 1269 | 
            +
                    self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
         | 
| 1270 | 
            +
                    self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
         | 
| 1271 | 
            +
                    self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
         | 
| 1272 | 
            +
                    self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
         | 
| 1273 | 
            +
             | 
| 1274 | 
            +
                def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
         | 
| 1275 | 
            +
                    residual = hidden_states
         | 
| 1276 | 
            +
             | 
| 1277 | 
            +
                    input_ndim = hidden_states.ndim
         | 
| 1278 | 
            +
             | 
| 1279 | 
            +
                    if input_ndim == 4:
         | 
| 1280 | 
            +
                        batch_size, channel, height, width = hidden_states.shape
         | 
| 1281 | 
            +
                        hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
         | 
| 1282 | 
            +
             | 
| 1283 | 
            +
                    batch_size, sequence_length, _ = (
         | 
| 1284 | 
            +
                        hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
         | 
| 1285 | 
            +
                    )
         | 
| 1286 | 
            +
                    inner_dim = hidden_states.shape[-1]
         | 
| 1287 | 
            +
             | 
| 1288 | 
            +
                    if attention_mask is not None:
         | 
| 1289 | 
            +
                        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
         | 
| 1290 | 
            +
                        # scaled_dot_product_attention expects attention_mask shape to be
         | 
| 1291 | 
            +
                        # (batch, heads, source_length, target_length)
         | 
| 1292 | 
            +
                        attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
         | 
| 1293 | 
            +
             | 
| 1294 | 
            +
                    if attn.group_norm is not None:
         | 
| 1295 | 
            +
                        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
         | 
| 1296 | 
            +
             | 
| 1297 | 
            +
                    query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
         | 
| 1298 | 
            +
             | 
| 1299 | 
            +
                    if encoder_hidden_states is None:
         | 
| 1300 | 
            +
                        encoder_hidden_states = hidden_states
         | 
| 1301 | 
            +
                    elif attn.norm_cross:
         | 
| 1302 | 
            +
                        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
         | 
| 1303 | 
            +
             | 
| 1304 | 
            +
                    key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
         | 
| 1305 | 
            +
                    value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
         | 
| 1306 | 
            +
             | 
| 1307 | 
            +
                    head_dim = inner_dim // attn.heads
         | 
| 1308 | 
            +
                    query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         | 
| 1309 | 
            +
                    key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         | 
| 1310 | 
            +
                    value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         | 
| 1311 | 
            +
             | 
| 1312 | 
            +
                    # TODO: add support for attn.scale when we move to Torch 2.1
         | 
| 1313 | 
            +
                    hidden_states = F.scaled_dot_product_attention(
         | 
| 1314 | 
            +
                        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         | 
| 1315 | 
            +
                    )
         | 
| 1316 | 
            +
                    hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         | 
| 1317 | 
            +
                    hidden_states = hidden_states.to(query.dtype)
         | 
| 1318 | 
            +
             | 
| 1319 | 
            +
                    # linear proj
         | 
| 1320 | 
            +
                    hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
         | 
| 1321 | 
            +
                    # dropout
         | 
| 1322 | 
            +
                    hidden_states = attn.to_out[1](hidden_states)
         | 
| 1323 | 
            +
             | 
| 1324 | 
            +
                    if input_ndim == 4:
         | 
| 1325 | 
            +
                        hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
         | 
| 1326 | 
            +
             | 
| 1327 | 
            +
                    if attn.residual_connection:
         | 
| 1328 | 
            +
                        hidden_states = hidden_states + residual
         | 
| 1329 | 
            +
             | 
| 1330 | 
            +
                    hidden_states = hidden_states / attn.rescale_output_factor
         | 
| 1331 | 
            +
             | 
| 1332 | 
            +
                    return hidden_states
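
A sketch of attaching this processor to an `Attention` layer and controlling the LoRA strength at call time through the `scale` keyword (forwarded via `**cross_attention_kwargs`); sizes are illustrative and PyTorch >= 2.0 is assumed:

import torch
import torch.nn.functional as F
from diffusers.models.attention_processor import Attention, LoRAAttnProcessor2_0

if hasattr(F, "scaled_dot_product_attention"):  # PyTorch >= 2.0 only
    attn = Attention(query_dim=320, heads=8, dim_head=40, cross_attention_dim=1024)
    attn.set_processor(LoRAAttnProcessor2_0(hidden_size=320, cross_attention_dim=1024, rank=4))

    hidden_states = torch.randn(2, 64, 320)
    encoder_hidden_states = torch.randn(2, 77, 1024)
    # scale=0.0 reproduces the frozen attention layer exactly; larger values mix in more LoRA.
    out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, scale=0.5)
    print(out.shape)  # torch.Size([2, 64, 320])
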
         | 
| 1333 | 
            +
             | 
| 1334 | 
            +
             | 
| 1335 | 
            +
            class CustomDiffusionXFormersAttnProcessor(nn.Module):
         | 
| 1336 | 
            +
                r"""
         | 
| 1337 | 
            +
                Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
         | 
| 1338 | 
            +
             | 
| 1339 | 
            +
                Args:
         | 
| 1340 | 
            +
                train_kv (`bool`, defaults to `True`):
         | 
| 1341 | 
            +
                    Whether to train new key and value projection matrices for the text features.
         | 
| 1342 | 
            +
                train_q_out (`bool`, defaults to `True`):
         | 
| 1343 | 
            +
                    Whether to train a new query projection matrix for the latent image features.
         | 
| 1344 | 
            +
                hidden_size (`int`, *optional*, defaults to `None`):
         | 
| 1345 | 
            +
                    The hidden size of the attention layer.
         | 
| 1346 | 
            +
                cross_attention_dim (`int`, *optional*, defaults to `None`):
         | 
| 1347 | 
            +
                    The number of channels in the `encoder_hidden_states`.
         | 
| 1348 | 
            +
                out_bias (`bool`, defaults to `True`):
         | 
| 1349 | 
            +
                    Whether to include a bias term in the newly trained output projection used when `train_q_out` is enabled.
         | 
| 1350 | 
            +
                dropout (`float`, *optional*, defaults to 0.0):
         | 
| 1351 | 
            +
                    The dropout probability to use.
         | 
| 1352 | 
            +
                attention_op (`Callable`, *optional*, defaults to `None`):
         | 
| 1353 | 
            +
                    The base
         | 
| 1354 | 
            +
                    [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use
         | 
| 1355 | 
            +
                    as the attention operator. It is recommended to set it to `None` and allow xFormers to choose the best operator.
         | 
| 1356 | 
            +
                """
         | 
| 1357 | 
            +
             | 
| 1358 | 
            +
                def __init__(
         | 
| 1359 | 
            +
                    self,
         | 
| 1360 | 
            +
                    train_kv=True,
         | 
| 1361 | 
            +
                    train_q_out=False,
         | 
| 1362 | 
            +
                    hidden_size=None,
         | 
| 1363 | 
            +
                    cross_attention_dim=None,
         | 
| 1364 | 
            +
                    out_bias=True,
         | 
| 1365 | 
            +
                    dropout=0.0,
         | 
| 1366 | 
            +
                    attention_op: Optional[Callable] = None,
         | 
| 1367 | 
            +
                ):
         | 
| 1368 | 
            +
                    super().__init__()
         | 
| 1369 | 
            +
                    self.train_kv = train_kv
         | 
| 1370 | 
            +
                    self.train_q_out = train_q_out
         | 
| 1371 | 
            +
             | 
| 1372 | 
            +
                    self.hidden_size = hidden_size
         | 
| 1373 | 
            +
                    self.cross_attention_dim = cross_attention_dim
         | 
| 1374 | 
            +
                    self.attention_op = attention_op
         | 
| 1375 | 
            +
             | 
| 1376 | 
            +
                    # The `_custom_diffusion` suffix in these layer names makes serialization and loading straightforward.
         | 
| 1377 | 
            +
                    if self.train_kv:
         | 
| 1378 | 
            +
                        self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
         | 
| 1379 | 
            +
                        self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
         | 
| 1380 | 
            +
                    if self.train_q_out:
         | 
| 1381 | 
            +
                        self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
         | 
| 1382 | 
            +
                        self.to_out_custom_diffusion = nn.ModuleList([])
         | 
| 1383 | 
            +
                        self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
         | 
| 1384 | 
            +
                        self.to_out_custom_diffusion.append(nn.Dropout(dropout))
         | 
| 1385 | 
            +
             | 
| 1386 | 
            +
                def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
         | 
| 1387 | 
            +
                    batch_size, sequence_length, _ = (
         | 
| 1388 | 
            +
                        hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
         | 
| 1389 | 
            +
                    )
         | 
| 1390 | 
            +
             | 
| 1391 | 
            +
                    attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
         | 
| 1392 | 
            +
             | 
| 1393 | 
            +
                    if self.train_q_out:
         | 
| 1394 | 
            +
                        query = self.to_q_custom_diffusion(hidden_states)
         | 
| 1395 | 
            +
                    else:
         | 
| 1396 | 
            +
                        query = attn.to_q(hidden_states)
         | 
| 1397 | 
            +
             | 
| 1398 | 
            +
                    if encoder_hidden_states is None:
         | 
| 1399 | 
            +
                        crossattn = False
         | 
| 1400 | 
            +
                        encoder_hidden_states = hidden_states
         | 
| 1401 | 
            +
                    else:
         | 
| 1402 | 
            +
                        crossattn = True
         | 
| 1403 | 
            +
                        if attn.norm_cross:
         | 
| 1404 | 
            +
                            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
         | 
| 1405 | 
            +
             | 
| 1406 | 
            +
                    if self.train_kv:
         | 
| 1407 | 
            +
                        key = self.to_k_custom_diffusion(encoder_hidden_states)
         | 
| 1408 | 
            +
                        value = self.to_v_custom_diffusion(encoder_hidden_states)
         | 
| 1409 | 
            +
                    else:
         | 
| 1410 | 
            +
                        key = attn.to_k(encoder_hidden_states)
         | 
| 1411 | 
            +
                        value = attn.to_v(encoder_hidden_states)
         | 
| 1412 | 
            +
             | 
| 1413 | 
            +
                    if crossattn:
         | 
| 1414 | 
            +
                        detach = torch.ones_like(key)
         | 
| 1415 | 
            +
                        detach[:, :1, :] = detach[:, :1, :] * 0.0
         | 
| 1416 | 
            +
                        key = detach * key + (1 - detach) * key.detach()
         | 
| 1417 | 
            +
                        value = detach * value + (1 - detach) * value.detach()
         | 
| 1418 | 
            +
             | 
| 1419 | 
            +
                    query = attn.head_to_batch_dim(query).contiguous()
         | 
| 1420 | 
            +
                    key = attn.head_to_batch_dim(key).contiguous()
         | 
| 1421 | 
            +
                    value = attn.head_to_batch_dim(value).contiguous()
         | 
| 1422 | 
            +
             | 
| 1423 | 
            +
                    hidden_states = xformers.ops.memory_efficient_attention(
         | 
| 1424 | 
            +
                        query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
         | 
| 1425 | 
            +
                    )
         | 
| 1426 | 
            +
                    hidden_states = hidden_states.to(query.dtype)
         | 
| 1427 | 
            +
                    hidden_states = attn.batch_to_head_dim(hidden_states)
         | 
| 1428 | 
            +
             | 
| 1429 | 
            +
                    if self.train_q_out:
         | 
| 1430 | 
            +
                        # linear proj
         | 
| 1431 | 
            +
                        hidden_states = self.to_out_custom_diffusion[0](hidden_states)
         | 
| 1432 | 
            +
                        # dropout
         | 
| 1433 | 
            +
                        hidden_states = self.to_out_custom_diffusion[1](hidden_states)
         | 
| 1434 | 
            +
                    else:
         | 
| 1435 | 
            +
                        # linear proj
         | 
| 1436 | 
            +
                        hidden_states = attn.to_out[0](hidden_states)
         | 
| 1437 | 
            +
                        # dropout
         | 
| 1438 | 
            +
                        hidden_states = attn.to_out[1](hidden_states)
         | 
| 1439 | 
            +
                    return hidden_states
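
With the default `train_kv=True, train_q_out=False` configuration, the only new parameters are the `to_k_custom_diffusion` / `to_v_custom_diffusion` projections, which is what Custom Diffusion fine-tunes. A small sketch with illustrative sizes (constructing the processor does not require xformers; only its `__call__` does):

from diffusers.models.attention_processor import CustomDiffusionXFormersAttnProcessor

proc = CustomDiffusionXFormersAttnProcessor(
    train_kv=True, train_q_out=False, hidden_size=320, cross_attention_dim=1024
)
trainable = sum(p.numel() for p in proc.parameters() if p.requires_grad)
print(trainable)  # 2 * 1024 * 320 = 655,360 parameters from the new k/v projections
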
         | 
| 1440 | 
            +
             | 
| 1441 | 
            +
             | 
| 1442 | 
            +
            class SlicedAttnProcessor:
         | 
| 1443 | 
            +
                r"""
         | 
| 1444 | 
            +
                Processor for implementing sliced attention.
         | 
| 1445 | 
            +
             | 
| 1446 | 
            +
                Args:
         | 
| 1447 | 
            +
                    slice_size (`int`, *optional*):
         | 
| 1448 | 
            +
                        The size of each attention slice along the flattened `batch * heads` dimension. Attention is computed in
         | 
| 1449 | 
            +
                        `batch_size * num_heads // slice_size` sequential slices, so `batch_size * num_heads` should be a multiple of `slice_size`.
         | 
| 1450 | 
            +
                """
         | 
| 1451 | 
            +
             | 
| 1452 | 
            +
                def __init__(self, slice_size):
         | 
| 1453 | 
            +
                    self.slice_size = slice_size
         | 
| 1454 | 
            +
             | 
| 1455 | 
            +
                def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
         | 
| 1456 | 
            +
                    residual = hidden_states
         | 
| 1457 | 
            +
             | 
| 1458 | 
            +
                    input_ndim = hidden_states.ndim
         | 
| 1459 | 
            +
             | 
| 1460 | 
            +
                    if input_ndim == 4:
         | 
| 1461 | 
            +
                        batch_size, channel, height, width = hidden_states.shape
         | 
| 1462 | 
            +
                        hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
         | 
| 1463 | 
            +
             | 
| 1464 | 
            +
                    batch_size, sequence_length, _ = (
         | 
| 1465 | 
            +
                        hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
         | 
| 1466 | 
            +
                    )
         | 
| 1467 | 
            +
                    attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
         | 
| 1468 | 
            +
             | 
| 1469 | 
            +
                    if attn.group_norm is not None:
         | 
| 1470 | 
            +
                        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
         | 
| 1471 | 
            +
             | 
| 1472 | 
            +
                    query = attn.to_q(hidden_states)
         | 
| 1473 | 
            +
                    dim = query.shape[-1]
         | 
| 1474 | 
            +
                    query = attn.head_to_batch_dim(query)
         | 
| 1475 | 
            +
             | 
| 1476 | 
            +
                    if encoder_hidden_states is None:
         | 
| 1477 | 
            +
                        encoder_hidden_states = hidden_states
         | 
| 1478 | 
            +
                    elif attn.norm_cross:
         | 
| 1479 | 
            +
                        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
         | 
| 1480 | 
            +
             | 
| 1481 | 
            +
                    key = attn.to_k(encoder_hidden_states)
         | 
| 1482 | 
            +
                    value = attn.to_v(encoder_hidden_states)
         | 
| 1483 | 
            +
                    key = attn.head_to_batch_dim(key)
         | 
| 1484 | 
            +
                    value = attn.head_to_batch_dim(value)
         | 
| 1485 | 
            +
             | 
| 1486 | 
            +
                    batch_size_attention, query_tokens, _ = query.shape
         | 
| 1487 | 
            +
                    hidden_states = torch.zeros(
         | 
| 1488 | 
            +
                        (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
         | 
| 1489 | 
            +
                    )
         | 
| 1490 | 
            +
             | 
| 1491 | 
            +
                    for i in range(batch_size_attention // self.slice_size):
         | 
| 1492 | 
            +
                        start_idx = i * self.slice_size
         | 
| 1493 | 
            +
                        end_idx = (i + 1) * self.slice_size
         | 
| 1494 | 
            +
             | 
| 1495 | 
            +
                        query_slice = query[start_idx:end_idx]
         | 
| 1496 | 
            +
                        key_slice = key[start_idx:end_idx]
         | 
| 1497 | 
            +
                        attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
         | 
| 1498 | 
            +
             | 
| 1499 | 
            +
                        attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
         | 
| 1500 | 
            +
             | 
| 1501 | 
            +
                        attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
         | 
| 1502 | 
            +
             | 
| 1503 | 
            +
                        hidden_states[start_idx:end_idx] = attn_slice
         | 
| 1504 | 
            +
             | 
| 1505 | 
            +
                    hidden_states = attn.batch_to_head_dim(hidden_states)
         | 
| 1506 | 
            +
             | 
| 1507 | 
            +
                    # linear proj
         | 
| 1508 | 
            +
                    hidden_states = attn.to_out[0](hidden_states)
         | 
| 1509 | 
            +
                    # dropout
         | 
| 1510 | 
            +
                    hidden_states = attn.to_out[1](hidden_states)
         | 
| 1511 | 
            +
             | 
| 1512 | 
            +
                    if input_ndim == 4:
         | 
| 1513 | 
            +
                        hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
         | 
| 1514 | 
            +
             | 
| 1515 | 
            +
                    if attn.residual_connection:
         | 
| 1516 | 
            +
                        hidden_states = hidden_states + residual
         | 
| 1517 | 
            +
             | 
| 1518 | 
            +
                    hidden_states = hidden_states / attn.rescale_output_factor
         | 
| 1519 | 
            +
             | 
| 1520 | 
            +
                    return hidden_states
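
A usage sketch with illustrative sizes: with a batch of 2 and 8 heads, the flattened batch-of-heads dimension has 16 rows, which `slice_size=4` processes in 4 sequential chunks, trading speed for lower peak memory:

import torch
from diffusers.models.attention_processor import Attention, SlicedAttnProcessor

attn = Attention(query_dim=320, heads=8, dim_head=40, cross_attention_dim=1024)
attn.set_processor(SlicedAttnProcessor(slice_size=4))  # 2 * 8 // 4 = 4 slices per forward

hidden_states = torch.randn(2, 64, 320)           # (batch, query tokens, channels)
encoder_hidden_states = torch.randn(2, 77, 1024)  # e.g. text-encoder features
out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
print(out.shape)                                  # torch.Size([2, 64, 320])
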
         | 
| 1521 | 
            +
             | 
| 1522 | 
            +
             | 
| 1523 | 
            +
            class SlicedAttnAddedKVProcessor:
         | 
| 1524 | 
            +
                r"""
         | 
| 1525 | 
            +
                Processor for implementing sliced attention with extra learnable key and value matrices for the text encoder.
         | 
| 1526 | 
            +
             | 
| 1527 | 
            +
                Args:
         | 
| 1528 | 
            +
                    slice_size (`int`, *optional*):
         | 
| 1529 | 
            +
                        The size of each attention slice along the flattened `batch * heads` dimension. Attention is computed in
         | 
| 1530 | 
            +
                        `batch_size * num_heads // slice_size` sequential slices, so `batch_size * num_heads` should be a multiple of `slice_size`.
         | 
| 1531 | 
            +
                """
         | 
| 1532 | 
            +
             | 
| 1533 | 
            +
                def __init__(self, slice_size):
         | 
| 1534 | 
            +
                    self.slice_size = slice_size
         | 
| 1535 | 
            +
             | 
| 1536 | 
            +
                def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
         | 
| 1537 | 
            +
                    residual = hidden_states
         | 
| 1538 | 
            +
             | 
| 1539 | 
            +
                    if attn.spatial_norm is not None:
         | 
| 1540 | 
            +
                        hidden_states = attn.spatial_norm(hidden_states, temb)
         | 
| 1541 | 
            +
             | 
| 1542 | 
            +
                    hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
         | 
| 1543 | 
            +
             | 
| 1544 | 
            +
                    batch_size, sequence_length, _ = hidden_states.shape
         | 
| 1545 | 
            +
             | 
| 1546 | 
            +
                    attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
         | 
| 1547 | 
            +
             | 
| 1548 | 
            +
                    if encoder_hidden_states is None:
         | 
| 1549 | 
            +
                        encoder_hidden_states = hidden_states
         | 
| 1550 | 
            +
                    elif attn.norm_cross:
         | 
| 1551 | 
            +
                        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
         | 
| 1552 | 
            +
             | 
| 1553 | 
            +
                    hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
         | 
| 1554 | 
            +
             | 
| 1555 | 
            +
                    query = attn.to_q(hidden_states)
         | 
| 1556 | 
            +
                    dim = query.shape[-1]
         | 
| 1557 | 
            +
                    query = attn.head_to_batch_dim(query)
         | 
| 1558 | 
            +
             | 
| 1559 | 
            +
                    encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
         | 
| 1560 | 
            +
                    encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
         | 
| 1561 | 
            +
             | 
| 1562 | 
            +
                    encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
         | 
| 1563 | 
            +
                    encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
         | 
| 1564 | 
            +
             | 
| 1565 | 
            +
                    if not attn.only_cross_attention:
         | 
| 1566 | 
            +
                        key = attn.to_k(hidden_states)
         | 
| 1567 | 
            +
                        value = attn.to_v(hidden_states)
         | 
| 1568 | 
            +
                        key = attn.head_to_batch_dim(key)
         | 
| 1569 | 
            +
                        value = attn.head_to_batch_dim(value)
         | 
| 1570 | 
            +
                        key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
         | 
| 1571 | 
            +
                        value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
         | 
| 1572 | 
            +
                    else:
         | 
| 1573 | 
            +
                        key = encoder_hidden_states_key_proj
         | 
| 1574 | 
            +
                        value = encoder_hidden_states_value_proj
         | 
| 1575 | 
            +
             | 
| 1576 | 
            +
                    batch_size_attention, query_tokens, _ = query.shape
         | 
| 1577 | 
            +
                    hidden_states = torch.zeros(
         | 
| 1578 | 
            +
                        (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
         | 
| 1579 | 
            +
                    )
         | 
| 1580 | 
            +
             | 
| 1581 | 
            +
                    for i in range(batch_size_attention // self.slice_size):
         | 
| 1582 | 
            +
                        start_idx = i * self.slice_size
         | 
| 1583 | 
            +
                        end_idx = (i + 1) * self.slice_size
         | 
| 1584 | 
            +
             | 
| 1585 | 
            +
                        query_slice = query[start_idx:end_idx]
         | 
| 1586 | 
            +
                        key_slice = key[start_idx:end_idx]
         | 
| 1587 | 
            +
                        attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
         | 
| 1588 | 
            +
             | 
| 1589 | 
            +
                        attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
         | 
| 1590 | 
            +
             | 
| 1591 | 
            +
                        attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
         | 
| 1592 | 
            +
             | 
| 1593 | 
            +
                        hidden_states[start_idx:end_idx] = attn_slice
         | 
| 1594 | 
            +
             | 
| 1595 | 
            +
                    hidden_states = attn.batch_to_head_dim(hidden_states)
         | 
| 1596 | 
            +
             | 
| 1597 | 
            +
                    # linear proj
         | 
| 1598 | 
            +
                    hidden_states = attn.to_out[0](hidden_states)
         | 
| 1599 | 
            +
                    # dropout
         | 
| 1600 | 
            +
                    hidden_states = attn.to_out[1](hidden_states)
         | 
| 1601 | 
            +
             | 
| 1602 | 
            +
                    hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
         | 
| 1603 | 
            +
                    hidden_states = hidden_states + residual
         | 
| 1604 | 
            +
             | 
| 1605 | 
            +
                    return hidden_states
         | 
| 1606 | 
            +
             | 
| 1607 | 
            +
             | 
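# Note (added for clarity; not in the upstream diffusers source): the slicing loop above
# computes attention in batch_size_attention // slice_size chunks along the (batch * heads)
# dimension, so the full attention-score tensor is never materialized at once. It trades
# peak memory for extra kernel launches and assumes slice_size evenly divides that dimension.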
AttentionProcessor = Union[
    AttnProcessor,
    AttnProcessor2_0,
    XFormersAttnProcessor,
    SlicedAttnProcessor,
    AttnAddedKVProcessor,
    SlicedAttnAddedKVProcessor,
    AttnAddedKVProcessor2_0,
    XFormersAttnAddedKVProcessor,
    LoRAAttnProcessor,
    LoRAXFormersAttnProcessor,
    LoRAAttnProcessor2_0,
    LoRAAttnAddedKVProcessor,
    CustomDiffusionAttnProcessor,
    CustomDiffusionXFormersAttnProcessor,
]


class SpatialNorm(nn.Module):
    """
    Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002
    """

    def __init__(
        self,
        f_channels,
        zq_channels,
    ):
        super().__init__()
        self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
        self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
        self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, f, zq):
        f_size = f.shape[-2:]
        zq = F.interpolate(zq, size=f_size, mode="nearest")
        norm_f = self.norm_layer(f)
        new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
        return new_f
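For orientation (an added sketch, not part of the committed file), the `SpatialNorm` forward pass above can be shape-checked with a self-contained re-creation using illustrative sizes:

```python
import torch
import torch.nn.functional as F
from torch import nn

# illustrative sizes: 64 feature channels conditioned on a 4-channel latent
f = torch.randn(2, 64, 32, 32)    # feature map to normalize
zq = torch.randn(2, 4, 8, 8)      # conditioning latent, upsampled to match f below

norm_layer = nn.GroupNorm(num_channels=64, num_groups=32, eps=1e-6, affine=True)
conv_y = nn.Conv2d(4, 64, kernel_size=1)
conv_b = nn.Conv2d(4, 64, kernel_size=1)

zq_up = F.interpolate(zq, size=f.shape[-2:], mode="nearest")  # (2, 4, 32, 32)
out = norm_layer(f) * conv_y(zq_up) + conv_b(zq_up)           # (2, 64, 32, 32)
print(out.shape)
```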
    	
diffusers/models/dual_transformer_2d.py
ADDED
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

from torch import nn

from .transformer_2d import Transformer2DModel, Transformer2DModelOutput


class DualTransformer2DModel(nn.Module):
    """
    Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            Pass if the input is continuous. The number of channels in the input and output.
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
        sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
            Note that this is fixed at training time as it is used for learning a number of position embeddings. See
            `ImagePositionalEmbeddings`.
        num_vector_embeds (`int`, *optional*):
            Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
            Includes the class for the masked latent pixel.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
            The number of diffusion steps used during training. Note that this is fixed at training time as it is used
            to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
            up to but not more than `num_embeds_ada_norm` steps.
        attention_bias (`bool`, *optional*):
            Configure if the TransformerBlocks' attention should contain a bias parameter.
    """

    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        num_vector_embeds: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
    ):
        super().__init__()
        self.transformers = nn.ModuleList(
            [
                Transformer2DModel(
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    in_channels=in_channels,
                    num_layers=num_layers,
                    dropout=dropout,
                    norm_num_groups=norm_num_groups,
                    cross_attention_dim=cross_attention_dim,
                    attention_bias=attention_bias,
                    sample_size=sample_size,
                    num_vector_embeds=num_vector_embeds,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                )
                for _ in range(2)
            ]
        )

        # Variables that can be set by a pipeline:

        # The ratio of transformer1 to transformer2's output states to be combined during inference
        self.mix_ratio = 0.5

        # The shape of `encoder_hidden_states` is expected to be
        # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
        self.condition_lengths = [77, 257]

        # Which transformer to use to encode which condition.
        # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
        self.transformer_index_for_condition = [1, 0]

    def forward(
        self,
        hidden_states,
        encoder_hidden_states,
        timestep=None,
        attention_mask=None,
        cross_attention_kwargs=None,
        return_dict: bool = True,
    ):
        """
        Args:
            hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
                When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
                hidden_states
            encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep ( `torch.long`, *optional*):
                Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
            attention_mask (`torch.FloatTensor`, *optional*):
                Optional attention mask to be applied in Attention
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.

        Returns:
            [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
            [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.
        """
        input_states = hidden_states

        encoded_states = []
        tokens_start = 0
        # attention_mask is not used yet
        for i in range(2):
            # for each of the two transformers, pass the corresponding condition tokens
            condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
            transformer_index = self.transformer_index_for_condition[i]
            encoded_state = self.transformers[transformer_index](
                input_states,
                encoder_hidden_states=condition_state,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]
            encoded_states.append(encoded_state - input_states)
            tokens_start += self.condition_lengths[i]

        output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
        output_states = output_states + input_states

        if not return_dict:
            return (output_states,)

        return Transformer2DModelOutput(sample=output_states)
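For orientation (an added sketch, not part of the committed file), the residual mixing in `DualTransformer2DModel.forward` reduces to the following arithmetic, with `t1_out`/`t2_out` standing in for the two transformers' outputs:

```python
import torch

x = torch.randn(2, 4, 8)        # input hidden states
t1_out = torch.randn(2, 4, 8)   # transformers[transformer_index_for_condition[0]](x, cond_0)
t2_out = torch.randn(2, 4, 8)   # transformers[transformer_index_for_condition[1]](x, cond_1)
mix_ratio = 0.5

# each transformer contributes only its residual (output minus input),
# the residuals are blended, and the input is added back at the end
out = (t1_out - x) * mix_ratio + (t2_out - x) * (1 - mix_ratio) + x
```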
    	
diffusers/models/embeddings.py
ADDED
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Optional

import numpy as np
import torch
from torch import nn

from .activations import get_activation


def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
    :param embedding_dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


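# Worked example (added note): with embedding_dim=4 and the default arguments, half_dim=2 and
# the per-dimension frequencies are exp(-ln(10000) * [0, 1] / (2 - 1)) = [1, 1e-4], so a
# timestep t maps to [sin(t), sin(1e-4 * t), cos(t), cos(1e-4 * t)]; t=0 gives [0, 0, 1, 1].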
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    """
    grid_size: int of the grid height and width
    return: pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


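# Shape note (added for clarity): get_2d_sincos_pos_embed(embed_dim=8, grid_size=2) returns a
# (4, 8) array; one 8-dim embedding per grid position, with half of the dimensions encoding
# one grid axis and the other half encoding the other axis.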
class PatchEmbed(nn.Module):
    """2D Image to Patch Embedding"""

    def __init__(
        self,
        height=224,
        width=224,
        patch_size=16,
        in_channels=3,
        embed_dim=768,
        layer_norm=False,
        flatten=True,
        bias=True,
    ):
        super().__init__()

        num_patches = (height // patch_size) * (width // patch_size)
        self.flatten = flatten
        self.layer_norm = layer_norm

        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
        )
        if layer_norm:
            self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        else:
            self.norm = None

        pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5))
        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)

    def forward(self, latent):
        latent = self.proj(latent)
        if self.flatten:
            latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC
        if self.layer_norm:
            latent = self.norm(latent)
        return latent + self.pos_embed


class TimestepEmbedding(nn.Module):
    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim)

        if cond_proj_dim is not None:
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
        else:
            self.cond_proj = None

        self.act = get_activation(act_fn)

        if out_dim is not None:
            time_embed_dim_out = out_dim
        else:
            time_embed_dim_out = time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)

        if post_act_fn is None:
            self.post_act = None
        else:
            self.post_act = get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample


class Timesteps(nn.Module):
    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift

    def forward(self, timesteps):
        t_emb = get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
        )
        return t_emb


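# Usage note (added for clarity): the two modules above are typically chained, e.g.
# Timesteps(num_channels=320, flip_sin_to_cos=True, downscale_freq_shift=0) maps an (N,)
# tensor of timesteps to (N, 320) sinusoidal features, which
# TimestepEmbedding(in_channels=320, time_embed_dim=1280) then projects to (N, 1280).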
class GaussianFourierProjection(nn.Module):
    """Gaussian Fourier embeddings for noise levels."""

    def __init__(
        self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
        self.log = log
        self.flip_sin_to_cos = flip_sin_to_cos

        if set_W_to_weight:
            # to delete later
            self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)

            self.weight = self.W

    def forward(self, x):
        if self.log:
            x = torch.log(x)

        x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi

        if self.flip_sin_to_cos:
            out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
        else:
            out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
        return out


class ImagePositionalEmbeddings(nn.Module):
    """
    Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
    height and width of the latent space.

    For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092

    For VQ-diffusion:

    Output vector embeddings are used as input for the transformer.

    Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.

    Args:
        num_embed (`int`):
            Number of embeddings for the latent pixels embeddings.
        height (`int`):
            Height of the latent image i.e. the number of height embeddings.
        width (`int`):
            Width of the latent image i.e. the number of width embeddings.
        embed_dim (`int`):
            Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
    """

    def __init__(
        self,
        num_embed: int,
        height: int,
        width: int,
        embed_dim: int,
    ):
        super().__init__()

        self.height = height
        self.width = width
        self.num_embed = num_embed
        self.embed_dim = embed_dim

        self.emb = nn.Embedding(self.num_embed, embed_dim)
        self.height_emb = nn.Embedding(self.height, embed_dim)
        self.width_emb = nn.Embedding(self.width, embed_dim)

    def forward(self, index):
        emb = self.emb(index)

        height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))

        # 1 x H x D -> 1 x H x 1 x D
        height_emb = height_emb.unsqueeze(2)

        width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))

        # 1 x W x D -> 1 x 1 x W x D
        width_emb = width_emb.unsqueeze(1)

        pos_emb = height_emb + width_emb

        # 1 x H x W x D -> 1 x L x D
        pos_emb = pos_emb.view(1, self.height * self.width, -1)

        emb = emb + pos_emb[:, : emb.shape[1], :]

        return emb


class LabelEmbedding(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.

    Args:
        num_classes (`int`): The number of classes.
        hidden_size (`int`): The size of the vector embeddings.
        dropout_prob (`float`): The probability of dropping a label.
    """

    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        else:
            drop_ids = torch.tensor(force_drop_ids == 1)
        labels = torch.where(drop_ids, self.num_classes, labels)
        return labels

    def forward(self, labels: torch.LongTensor, force_drop_ids=None):
        use_dropout = self.dropout_prob > 0
        if (self.training and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        embeddings = self.embedding_table(labels)
        return embeddings


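# Note on LabelEmbedding above (added for clarity): when dropout_prob > 0, the embedding table
# holds one extra row at index num_classes, and token_drop re-points dropped labels to that
# row; this is the "null" class embedding used for classifier-free guidance.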
class TextImageProjection(nn.Module):
    def __init__(
        self,
        text_embed_dim: int = 1024,
        image_embed_dim: int = 768,
        cross_attention_dim: int = 768,
        num_image_text_embeds: int = 10,
    ):
        super().__init__()

        self.num_image_text_embeds = num_image_text_embeds
        self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
        self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)

    def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
        batch_size = text_embeds.shape[0]

        # image
        image_text_embeds = self.image_embeds(image_embeds)
        image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1)

        # text
        text_embeds = self.text_proj(text_embeds)

        return torch.cat([image_text_embeds, text_embeds], dim=1)


class CombinedTimestepLabelEmbeddings(nn.Module):
    def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
        super().__init__()

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)

    def forward(self, timestep, class_labels, hidden_dtype=None):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)

        class_labels = self.class_embedder(class_labels)  # (N, D)

        conditioning = timesteps_emb + class_labels  # (N, D)

        return conditioning


class TextTimeEmbedding(nn.Module):
    def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
        super().__init__()
        self.norm1 = nn.LayerNorm(encoder_dim)
        self.pool = AttentionPooling(num_heads, encoder_dim)
        self.proj = nn.Linear(encoder_dim, time_embed_dim)
        self.norm2 = nn.LayerNorm(time_embed_dim)

    def forward(self, hidden_states):
        hidden_states = self.norm1(hidden_states)
        hidden_states = self.pool(hidden_states)
        hidden_states = self.proj(hidden_states)
        hidden_states = self.norm2(hidden_states)
        return hidden_states


class TextImageTimeEmbedding(nn.Module):
    def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536):
        super().__init__()
        self.text_proj = nn.Linear(text_embed_dim, time_embed_dim)
        self.text_norm = nn.LayerNorm(time_embed_dim)
        self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)

    def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
        # text
        time_text_embeds = self.text_proj(text_embeds)
        time_text_embeds = self.text_norm(time_text_embeds)

        # image
        time_image_embeds = self.image_proj(image_embeds)

        return time_image_embeds + time_text_embeds


class AttentionPooling(nn.Module):
    # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54

    def __init__(self, num_heads, embed_dim, dtype=None):
         | 
| 437 | 
            +
                    super().__init__()
         | 
| 438 | 
            +
                    self.dtype = dtype
         | 
| 439 | 
            +
                    self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5)
         | 
| 440 | 
            +
                    self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
         | 
| 441 | 
            +
                    self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
         | 
| 442 | 
            +
                    self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
         | 
| 443 | 
            +
                    self.num_heads = num_heads
         | 
| 444 | 
            +
                    self.dim_per_head = embed_dim // self.num_heads
         | 
| 445 | 
            +
             | 
| 446 | 
            +
                def forward(self, x):
         | 
| 447 | 
            +
                    bs, length, width = x.size()
         | 
| 448 | 
            +
             | 
| 449 | 
            +
                    def shape(x):
         | 
| 450 | 
            +
                        # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
         | 
| 451 | 
            +
                        x = x.view(bs, -1, self.num_heads, self.dim_per_head)
         | 
| 452 | 
            +
                        # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
         | 
| 453 | 
            +
                        x = x.transpose(1, 2)
         | 
| 454 | 
            +
                        # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
         | 
| 455 | 
            +
                        x = x.reshape(bs * self.num_heads, -1, self.dim_per_head)
         | 
| 456 | 
            +
                        # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length)
         | 
| 457 | 
            +
                        x = x.transpose(1, 2)
         | 
| 458 | 
            +
                        return x
         | 
| 459 | 
            +
             | 
| 460 | 
            +
                    class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype)
         | 
| 461 | 
            +
                    x = torch.cat([class_token, x], dim=1)  # (bs, length+1, width)
         | 
| 462 | 
            +
             | 
| 463 | 
            +
                    # (bs*n_heads, class_token_length, dim_per_head)
         | 
| 464 | 
            +
                    q = shape(self.q_proj(class_token))
         | 
| 465 | 
            +
                    # (bs*n_heads, length+class_token_length, dim_per_head)
         | 
| 466 | 
            +
                    k = shape(self.k_proj(x))
         | 
| 467 | 
            +
                    v = shape(self.v_proj(x))
         | 
| 468 | 
            +
             | 
| 469 | 
            +
                    # (bs*n_heads, class_token_length, length+class_token_length):
         | 
| 470 | 
            +
                    scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))
         | 
| 471 | 
            +
                    weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
         | 
| 472 | 
            +
                    weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
         | 
| 473 | 
            +
             | 
| 474 | 
            +
                    # (bs*n_heads, dim_per_head, class_token_length)
         | 
| 475 | 
            +
                    a = torch.einsum("bts,bcs->bct", weight, v)
         | 
| 476 | 
            +
             | 
        # (bs, 1, width)
        a = a.reshape(bs, -1, 1).transpose(1, 2)

        return a[:, 0, :]  # cls_token
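The classes above each reduce a conditioning signal to fixed-size vectors; `TextTimeEmbedding`, for instance, attention-pools a variable-length encoder sequence into a single time-embedding vector. Below is a minimal shape sketch; the import path and the 1024/1280 dimensions are illustrative assumptions, not taken from this commit.

import torch

# Assumed import path for the vendored module shown above.
from diffusers.models.embeddings import TextTimeEmbedding

# Two encoder sequences of 77 tokens, each token a 1024-dim vector.
encoder_states = torch.randn(2, 77, 1024)

# AttentionPooling collapses the sequence into one vector per example, then a
# linear layer maps it to the time-embedding width (1280 here).
embed = TextTimeEmbedding(encoder_dim=1024, time_embed_dim=1280, num_heads=64)
print(embed(encoder_states).shape)  # torch.Size([2, 1280])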
    	
diffusers/models/loaders.py
ADDED
@@ -0,0 +1,1481 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union

import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download

from .attention_processor import (
    AttnAddedKVProcessor,
    AttnAddedKVProcessor2_0,
    CustomDiffusionAttnProcessor,
    CustomDiffusionXFormersAttnProcessor,
    LoRAAttnAddedKVProcessor,
    LoRAAttnProcessor,
    LoRAAttnProcessor2_0,
    LoRAXFormersAttnProcessor,
    SlicedAttnAddedKVProcessor,
    XFormersAttnProcessor,
)
from ..utils.constants import DIFFUSERS_CACHE, TEXT_ENCODER_ATTN_MODULE
from ..utils.hub_utils import HF_HUB_OFFLINE, _get_model_file
from ..utils.deprecation_utils import deprecate
from ..utils.import_utils import is_safetensors_available, is_transformers_available
from ..utils.logging import get_logger

if is_safetensors_available():
    import safetensors

if is_transformers_available():
    from transformers import PreTrainedModel, PreTrainedTokenizer


logger = get_logger(__name__)

TEXT_ENCODER_NAME = "text_encoder"
UNET_NAME = "unet"

LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"

TEXT_INVERSION_NAME = "learned_embeds.bin"
TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"

CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"


class AttnProcsLayers(torch.nn.Module):
    def __init__(self, state_dict: Dict[str, torch.Tensor]):
        super().__init__()
        self.layers = torch.nn.ModuleList(state_dict.values())
        self.mapping = dict(enumerate(state_dict.keys()))
        self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}

        # .processor for unet, .self_attn for text encoder
        self.split_keys = [".processor", ".self_attn"]

        # we add a hook to state_dict() and load_state_dict() so that the
        # naming fits with `unet.attn_processors`
        def map_to(module, state_dict, *args, **kwargs):
            new_state_dict = {}
            for key, value in state_dict.items():
                num = int(key.split(".")[1])  # 0 is always "layers"
                new_key = key.replace(f"layers.{num}", module.mapping[num])
                new_state_dict[new_key] = value

            return new_state_dict

        def remap_key(key, state_dict):
            for k in self.split_keys:
                if k in key:
                    return key.split(k)[0] + k

            raise ValueError(
                f"There seems to be a problem with the state_dict: {set(state_dict.keys())}. {key} has to have one of {self.split_keys}."
            )

        def map_from(module, state_dict, *args, **kwargs):
            all_keys = list(state_dict.keys())
            for key in all_keys:
                replace_key = remap_key(key, state_dict)
                new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}")
                state_dict[new_key] = state_dict[key]
                del state_dict[key]

        self._register_state_dict_hook(map_to)
        self._register_load_state_dict_pre_hook(map_from, with_module=True)

class UNet2DConditionLoadersMixin:
    text_encoder_name = TEXT_ENCODER_NAME
    unet_name = UNET_NAME

    def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
        r"""
        Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be
        defined in
        [`cross_attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py)
        and be a `torch.nn.Module` class.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                Can be either:

                    - A string, the model id (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a directory (for example `./my_model_directory`) containing the model weights saved
                      with [`ModelMixin.save_pretrained`].
                    - A [torch state
                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
                incompletely downloaded files are deleted.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            use_auth_token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            mirror (`str`, *optional*):
                Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
                information.

        """

        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)
        use_safetensors = kwargs.pop("use_safetensors", None)
        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
        network_alpha = kwargs.pop("network_alpha", None)

        if use_safetensors and not is_safetensors_available():
            raise ValueError(
                "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors`."
            )

        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = is_safetensors_available()
            allow_pickle = True

        user_agent = {
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        }

        model_file = None
        if not isinstance(pretrained_model_name_or_path_or_dict, dict):
            # Let's first try to load .safetensors weights
            if (use_safetensors and weight_name is None) or (
                weight_name is not None and weight_name.endswith(".safetensors")
            ):
                try:
                    model_file = _get_model_file(
                        pretrained_model_name_or_path_or_dict,
                        weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
                        cache_dir=cache_dir,
                        force_download=force_download,
                        resume_download=resume_download,
                        proxies=proxies,
                        local_files_only=local_files_only,
                        use_auth_token=use_auth_token,
                        revision=revision,
                        subfolder=subfolder,
                        user_agent=user_agent,
                    )
                    state_dict = safetensors.torch.load_file(model_file, device="cpu")
                except IOError as e:
                    if not allow_pickle:
                        raise e
                    # try loading non-safetensors weights
                    pass
            if model_file is None:
                model_file = _get_model_file(
                    pretrained_model_name_or_path_or_dict,
                    weights_name=weight_name or LORA_WEIGHT_NAME,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    resume_download=resume_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                    revision=revision,
                    subfolder=subfolder,
                    user_agent=user_agent,
                )
                state_dict = torch.load(model_file, map_location="cpu")
        else:
            state_dict = pretrained_model_name_or_path_or_dict

        # fill attn processors
        attn_processors = {}

        is_lora = all("lora" in k for k in state_dict.keys())
        is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())

        if is_lora:
            is_new_lora_format = all(
                key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
            )
            if is_new_lora_format:
                # Strip the `"unet"` prefix.
                is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys())
                if is_text_encoder_present:
                    warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)."
                    warnings.warn(warn_message)
                unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)]
                state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}

            lora_grouped_dict = defaultdict(dict)
            for key, value in state_dict.items():
                attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
                lora_grouped_dict[attn_processor_key][sub_key] = value

            for key, value_dict in lora_grouped_dict.items():
                rank = value_dict["to_k_lora.down.weight"].shape[0]
                hidden_size = value_dict["to_k_lora.up.weight"].shape[0]

                attn_processor = self
                for sub_key in key.split("."):
                    attn_processor = getattr(attn_processor, sub_key)

                if isinstance(
                    attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)
                ):
                    cross_attention_dim = value_dict["add_k_proj_lora.down.weight"].shape[1]
                    attn_processor_class = LoRAAttnAddedKVProcessor
                else:
                    cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
                    if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)):
                        attn_processor_class = LoRAXFormersAttnProcessor
                    else:
                        attn_processor_class = (
                            LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
                        )

                attn_processors[key] = attn_processor_class(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    rank=rank,
                    network_alpha=network_alpha,
                )
                attn_processors[key].load_state_dict(value_dict)
        elif is_custom_diffusion:
            custom_diffusion_grouped_dict = defaultdict(dict)
            for key, value in state_dict.items():
                if len(value) == 0:
                    custom_diffusion_grouped_dict[key] = {}
                else:
                    if "to_out" in key:
                        attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
                    else:
                        attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
                    custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value

            for key, value_dict in custom_diffusion_grouped_dict.items():
                if len(value_dict) == 0:
                    attn_processors[key] = CustomDiffusionAttnProcessor(
                        train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None
                    )
                else:
                    cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1]
                    hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0]
                    train_q_out = True if "to_q_custom_diffusion.weight" in value_dict else False
                    attn_processors[key] = CustomDiffusionAttnProcessor(
                        train_kv=True,
                        train_q_out=train_q_out,
                        hidden_size=hidden_size,
                        cross_attention_dim=cross_attention_dim,
                    )
                    attn_processors[key].load_state_dict(value_dict)
        else:
            raise ValueError(
                f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
            )

        # set correct dtype & device
        attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}

        # set layers
        self.set_attn_processor(attn_processors)

             | 
| 324 | 
            +
                def save_attn_procs(
         | 
| 325 | 
            +
                    self,
         | 
| 326 | 
            +
                    save_directory: Union[str, os.PathLike],
         | 
| 327 | 
            +
                    is_main_process: bool = True,
         | 
| 328 | 
            +
                    weight_name: str = None,
         | 
| 329 | 
            +
                    save_function: Callable = None,
         | 
| 330 | 
            +
                    safe_serialization: bool = False,
         | 
| 331 | 
            +
                    **kwargs,
         | 
| 332 | 
            +
                ):
         | 
| 333 | 
            +
                    r"""
         | 
| 334 | 
            +
                    Save an attention processor to a directory so that it can be reloaded using the
         | 
| 335 | 
            +
                    [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`] method.
         | 
| 336 | 
            +
             | 
| 337 | 
            +
                    Arguments:
         | 
| 338 | 
            +
                        save_directory (`str` or `os.PathLike`):
         | 
| 339 | 
            +
                            Directory to save an attention processor to. Will be created if it doesn't exist.
         | 
| 340 | 
            +
                        is_main_process (`bool`, *optional*, defaults to `True`):
         | 
| 341 | 
            +
                            Whether the process calling this is the main process or not. Useful during distributed training and you
         | 
| 342 | 
            +
                            need to call this function on all processes. In this case, set `is_main_process=True` only on the main
         | 
| 343 | 
            +
                            process to avoid race conditions.
         | 
| 344 | 
            +
                        save_function (`Callable`):
         | 
| 345 | 
            +
                            The function to use to save the state dictionary. Useful during distributed training when you need to
         | 
| 346 | 
            +
                            replace `torch.save` with another method. Can be configured with the environment variable
         | 
| 347 | 
            +
                            `DIFFUSERS_SAVE_MODE`.
         | 
| 348 | 
            +
             | 
| 349 | 
            +
                    """
         | 
| 350 | 
            +
                    weight_name = weight_name or deprecate(
         | 
| 351 | 
            +
                        "weights_name",
         | 
| 352 | 
            +
                        "0.20.0",
         | 
| 353 | 
            +
                        "`weights_name` is deprecated, please use `weight_name` instead.",
         | 
| 354 | 
            +
                        take_from=kwargs,
         | 
| 355 | 
            +
                    )
         | 
| 356 | 
            +
                    if os.path.isfile(save_directory):
         | 
| 357 | 
            +
                        logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
         | 
| 358 | 
            +
                        return
         | 
| 359 | 
            +
             | 
| 360 | 
            +
                    if save_function is None:
         | 
| 361 | 
            +
                        if safe_serialization:
         | 
| 362 | 
            +
             | 
| 363 | 
            +
                            def save_function(weights, filename):
         | 
| 364 | 
            +
                                return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
         | 
| 365 | 
            +
             | 
| 366 | 
            +
                        else:
         | 
| 367 | 
            +
                            save_function = torch.save
         | 
| 368 | 
            +
             | 
| 369 | 
            +
                    os.makedirs(save_directory, exist_ok=True)
         | 
| 370 | 
            +
             | 
| 371 | 
            +
                    is_custom_diffusion = any(
         | 
| 372 | 
            +
                        isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor))
         | 
| 373 | 
            +
                        for (_, x) in self.attn_processors.items()
         | 
| 374 | 
            +
                    )
         | 
| 375 | 
            +
                    if is_custom_diffusion:
         | 
| 376 | 
            +
                        model_to_save = AttnProcsLayers(
         | 
| 377 | 
            +
                            {
         | 
| 378 | 
            +
                                y: x
         | 
| 379 | 
            +
                                for (y, x) in self.attn_processors.items()
         | 
| 380 | 
            +
                                if isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor))
         | 
| 381 | 
            +
                            }
         | 
| 382 | 
            +
                        )
         | 
| 383 | 
            +
                        state_dict = model_to_save.state_dict()
         | 
| 384 | 
            +
                        for name, attn in self.attn_processors.items():
         | 
| 385 | 
            +
                            if len(attn.state_dict()) == 0:
         | 
| 386 | 
            +
                                state_dict[name] = {}
         | 
| 387 | 
            +
                    else:
         | 
| 388 | 
            +
                        model_to_save = AttnProcsLayers(self.attn_processors)
         | 
| 389 | 
            +
                        state_dict = model_to_save.state_dict()
         | 
| 390 | 
            +
             | 
| 391 | 
            +
                    if weight_name is None:
         | 
| 392 | 
            +
                        if safe_serialization:
         | 
| 393 | 
            +
                            weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if is_custom_diffusion else LORA_WEIGHT_NAME_SAFE
         | 
| 394 | 
            +
                        else:
         | 
| 395 | 
            +
                            weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME if is_custom_diffusion else LORA_WEIGHT_NAME
         | 
| 396 | 
            +
             | 
| 397 | 
            +
                    # Save the model
         | 
| 398 | 
            +
                    save_function(state_dict, os.path.join(save_directory, weight_name))
         | 
| 399 | 
            +
                    logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
         | 
| 400 | 
            +
             | 
| 401 | 
            +
             | 
class TextualInversionLoaderMixin:
    r"""
    Load textual inversion tokens and embeddings to the tokenizer and text encoder.
    """

    def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"):
        r"""
        Processes prompts that include a special token corresponding to a multi-vector textual inversion embedding to
        be replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
        inversion token or if the textual inversion token is a single vector, the input prompt is returned.

        Parameters:
            prompt (`str` or list of `str`):
                The prompt or prompts to guide the image generation.
            tokenizer (`PreTrainedTokenizer`):
                The tokenizer responsible for encoding the prompt into input tokens.

        Returns:
            `str` or list of `str`: The converted prompt
        """
        if not isinstance(prompt, List):
            prompts = [prompt]
        else:
            prompts = prompt

        prompts = [self._maybe_convert_prompt(p, tokenizer) for p in prompts]

        if not isinstance(prompt, List):
            return prompts[0]

        return prompts

            +
                def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"):
         | 
| 435 | 
            +
                    r"""
         | 
| 436 | 
            +
                    Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds
         | 
| 437 | 
            +
                    to a multi-vector textual inversion embedding, this function will process the prompt so that the special token
         | 
| 438 | 
            +
                    is replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
         | 
| 439 | 
            +
                    inversion token or a textual inversion token that is a single vector, the input prompt is simply returned.
         | 
| 440 | 
            +
             | 
| 441 | 
            +
                    Parameters:
         | 
| 442 | 
            +
                        prompt (`str`):
         | 
| 443 | 
            +
                            The prompt to guide the image generation.
         | 
| 444 | 
            +
                        tokenizer (`PreTrainedTokenizer`):
         | 
| 445 | 
            +
                            The tokenizer responsible for encoding the prompt into input tokens.
         | 
| 446 | 
            +
             | 
| 447 | 
            +
                    Returns:
         | 
| 448 | 
            +
                        `str`: The converted prompt
         | 
| 449 | 
            +
                    """
         | 
| 450 | 
            +
                    tokens = tokenizer.tokenize(prompt)
         | 
| 451 | 
            +
                    unique_tokens = set(tokens)
         | 
| 452 | 
            +
                    for token in unique_tokens:
         | 
| 453 | 
            +
                        if token in tokenizer.added_tokens_encoder:
         | 
| 454 | 
            +
                            replacement = token
         | 
| 455 | 
            +
                            i = 1
         | 
| 456 | 
            +
                            while f"{token}_{i}" in tokenizer.added_tokens_encoder:
         | 
| 457 | 
            +
                                replacement += f" {token}_{i}"
         | 
| 458 | 
            +
                                i += 1
         | 
| 459 | 
            +
             | 
| 460 | 
            +
                            prompt = prompt.replace(token, replacement)
         | 
| 461 | 
            +
             | 
| 462 | 
            +
                    return prompt
         | 
| 463 | 
            +
             | 
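For intuition, here is a small self-contained sketch of the expansion `_maybe_convert_prompt` performs for a multi-vector embedding. The dictionary stands in for `tokenizer.added_tokens_encoder`, the token names are invented, and the alias-skipping check is a simplification of the real method:

```py
# Toy stand-in for tokenizer.added_tokens_encoder (token -> id); names are made up.
added_tokens = {"<cat-toy>": 49408, "<cat-toy>_1": 49409, "<cat-toy>_2": 49410}

def expand_multi_vector(prompt: str, added_tokens_encoder: dict) -> str:
    # Simplified version of _maybe_convert_prompt: expand each base token that has
    # numbered aliases into "token token_1 token_2 ...".
    for token in list(added_tokens_encoder):
        if token.rsplit("_", 1)[-1].isdigit():
            continue  # skip the per-vector aliases themselves
        replacement, i = token, 1
        while f"{token}_{i}" in added_tokens_encoder:
            replacement += f" {token}_{i}"
            i += 1
        prompt = prompt.replace(token, replacement)
    return prompt

print(expand_multi_vector("A <cat-toy> backpack", added_tokens))
# -> "A <cat-toy> <cat-toy>_1 <cat-toy>_2 backpack"
```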
| 464 | 
            +
                def load_textual_inversion(
         | 
| 465 | 
            +
                    self,
         | 
| 466 | 
            +
                    pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
         | 
| 467 | 
            +
                    token: Optional[Union[str, List[str]]] = None,
         | 
| 468 | 
            +
                    **kwargs,
         | 
| 469 | 
            +
                ):
         | 
| 470 | 
            +
                    r"""
         | 
| 471 | 
            +
                    Load textual inversion embeddings into the text encoder of [`StableDiffusionPipeline`] (both 🤗 Diffusers and
         | 
| 472 | 
            +
                    Automatic1111 formats are supported).
         | 
| 473 | 
            +
             | 
| 474 | 
            +
                    Parameters:
         | 
| 475 | 
            +
                        pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`):
         | 
| 476 | 
            +
                            Can be either one of the following or a list of them:
         | 
| 477 | 
            +
             | 
| 478 | 
            +
                                - A string, the *model id* (for example `sd-concepts-library/low-poly-hd-logos-icons`) of a
         | 
| 479 | 
            +
                                  pretrained model hosted on the Hub.
         | 
| 480 | 
            +
                                - A path to a *directory* (for example `./my_text_inversion_directory/`) containing the textual
         | 
| 481 | 
            +
                                  inversion weights.
         | 
| 482 | 
            +
                                - A path to a *file* (for example `./my_text_inversions.pt`) containing textual inversion weights.
         | 
| 483 | 
            +
                                - A [torch state
         | 
| 484 | 
            +
                                  dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
         | 
| 485 | 
            +
             | 
| 486 | 
            +
                        token (`str` or `List[str]`, *optional*):
         | 
| 487 | 
            +
                            Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a
         | 
| 488 | 
            +
                            list, then `token` must also be a list of equal length.
         | 
| 489 | 
            +
                        weight_name (`str`, *optional*):
         | 
| 490 | 
            +
                            Name of a custom weight file. This should be used when:
         | 
| 491 | 
            +
             | 
| 492 | 
            +
                                - The saved textual inversion file is in 🤗 Diffusers format, but was saved under a specific weight
         | 
| 493 | 
            +
                                  name such as `text_inv.bin`.
         | 
| 494 | 
            +
                                - The saved textual inversion file is in the Automatic1111 format.
         | 
| 495 | 
            +
                        cache_dir (`Union[str, os.PathLike]`, *optional*):
         | 
| 496 | 
            +
                            Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
         | 
| 497 | 
            +
                            is not used.
         | 
| 498 | 
            +
                        force_download (`bool`, *optional*, defaults to `False`):
         | 
| 499 | 
            +
                            Whether or not to force the (re-)download of the model weights and configuration files, overriding the
         | 
| 500 | 
            +
                            cached versions if they exist.
         | 
| 501 | 
            +
                        resume_download (`bool`, *optional*, defaults to `False`):
         | 
| 502 | 
            +
                            Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
         | 
| 503 | 
            +
                            incompletely downloaded files are deleted.
         | 
| 504 | 
            +
                        proxies (`Dict[str, str]`, *optional*):
         | 
| 505 | 
            +
                            A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
         | 
| 506 | 
            +
                            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
         | 
| 507 | 
            +
                        local_files_only (`bool`, *optional*, defaults to `False`):
         | 
| 508 | 
            +
                            Whether to only load local model weights and configuration files or not. If set to `True`, the model
         | 
| 509 | 
            +
                            won't be downloaded from the Hub.
         | 
| 510 | 
            +
                        use_auth_token (`str` or *bool*, *optional*):
         | 
| 511 | 
            +
                            The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
         | 
| 512 | 
            +
                            `diffusers-cli login` (stored in `~/.huggingface`) is used.
         | 
| 513 | 
            +
                        revision (`str`, *optional*, defaults to `"main"`):
         | 
| 514 | 
            +
                            The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
         | 
| 515 | 
            +
                            allowed by Git.
         | 
| 516 | 
            +
                        subfolder (`str`, *optional*, defaults to `""`):
         | 
| 517 | 
            +
                            The subfolder location of a model file within a larger model repository on the Hub or locally.
         | 
| 518 | 
            +
                        mirror (`str`, *optional*):
         | 
| 519 | 
            +
                            Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
         | 
| 520 | 
            +
                            guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
         | 
| 521 | 
            +
                            information.
         | 
| 522 | 
            +
             | 
| 523 | 
            +
                    Example:
         | 
| 524 | 
            +
             | 
| 525 | 
            +
                    To load a textual inversion embedding vector in 🤗 Diffusers format:
         | 
| 526 | 
            +
             | 
| 527 | 
            +
                    ```py
         | 
| 528 | 
            +
                    from diffusers import StableDiffusionPipeline
         | 
| 529 | 
            +
                    import torch
         | 
| 530 | 
            +
             | 
| 531 | 
            +
                    model_id = "runwayml/stable-diffusion-v1-5"
         | 
| 532 | 
            +
                    pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
         | 
| 533 | 
            +
             | 
| 534 | 
            +
                    pipe.load_textual_inversion("sd-concepts-library/cat-toy")
         | 
| 535 | 
            +
             | 
| 536 | 
            +
                    prompt = "A <cat-toy> backpack"
         | 
| 537 | 
            +
             | 
| 538 | 
            +
                    image = pipe(prompt, num_inference_steps=50).images[0]
         | 
| 539 | 
            +
                    image.save("cat-backpack.png")
         | 
| 540 | 
            +
                    ```
         | 
| 541 | 
            +
             | 
| 542 | 
            +
                    To load a textual inversion embedding vector in Automatic1111 format, make sure to download the vector first
         | 
| 543 | 
            +
                    (for example from [civitAI](https://civitai.com/models/3036?modelVersionId=9857)) and then load the vector
         | 
| 544 | 
            +
                    locally:
         | 
| 545 | 
            +
             | 
| 546 | 
            +
                    ```py
         | 
| 547 | 
            +
                    from diffusers import StableDiffusionPipeline
         | 
| 548 | 
            +
                    import torch
         | 
| 549 | 
            +
             | 
| 550 | 
            +
                    model_id = "runwayml/stable-diffusion-v1-5"
         | 
| 551 | 
            +
                    pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
         | 
| 552 | 
            +
             | 
| 553 | 
            +
                    pipe.load_textual_inversion("./charturnerv2.pt", token="charturnerv2")
         | 
| 554 | 
            +
             | 
| 555 | 
            +
                    prompt = "charturnerv2, multiple views of the same character in the same outfit, a character turnaround of a woman wearing a black jacket and red shirt, best quality, intricate details."
         | 
| 556 | 
            +
             | 
| 557 | 
            +
                    image = pipe(prompt, num_inference_steps=50).images[0]
         | 
| 558 | 
            +
                    image.save("character.png")
         | 
| 559 | 
            +
                    ```
         | 
| 560 | 
            +
             | 
| 561 | 
            +
                    """
         | 
| 562 | 
            +
                    if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer):
         | 
| 563 | 
            +
                        raise ValueError(
         | 
| 564 | 
            +
                            f"{self.__class__.__name__} requires `self.tokenizer` of type `PreTrainedTokenizer` for calling"
         | 
| 565 | 
            +
                            f" `{self.load_textual_inversion.__name__}`"
         | 
| 566 | 
            +
                        )
         | 
| 567 | 
            +
             | 
| 568 | 
            +
                    if not hasattr(self, "text_encoder") or not isinstance(self.text_encoder, PreTrainedModel):
         | 
| 569 | 
            +
                        raise ValueError(
         | 
| 570 | 
            +
                            f"{self.__class__.__name__} requires `self.text_encoder` of type `PreTrainedModel` for calling"
         | 
| 571 | 
            +
                            f" `{self.load_textual_inversion.__name__}`"
         | 
| 572 | 
            +
                        )
         | 
| 573 | 
            +
             | 
| 574 | 
            +
                    cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
         | 
| 575 | 
            +
                    force_download = kwargs.pop("force_download", False)
         | 
| 576 | 
            +
                    resume_download = kwargs.pop("resume_download", False)
         | 
| 577 | 
            +
                    proxies = kwargs.pop("proxies", None)
         | 
| 578 | 
            +
                    local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
         | 
| 579 | 
            +
                    use_auth_token = kwargs.pop("use_auth_token", None)
         | 
| 580 | 
            +
                    revision = kwargs.pop("revision", None)
         | 
| 581 | 
            +
                    subfolder = kwargs.pop("subfolder", None)
         | 
| 582 | 
            +
                    weight_name = kwargs.pop("weight_name", None)
         | 
| 583 | 
            +
                    use_safetensors = kwargs.pop("use_safetensors", None)
         | 
| 584 | 
            +
             | 
| 585 | 
            +
                    if use_safetensors and not is_safetensors_available():
         | 
| 586 | 
            +
                        raise ValueError(
         | 
| 587 | 
            +
                        "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors`."
         | 
| 588 | 
            +
                        )
         | 
| 589 | 
            +
             | 
| 590 | 
            +
                    allow_pickle = False
         | 
| 591 | 
            +
                    if use_safetensors is None:
         | 
| 592 | 
            +
                        use_safetensors = is_safetensors_available()
         | 
| 593 | 
            +
                        allow_pickle = True
         | 
| 594 | 
            +
             | 
| 595 | 
            +
                    user_agent = {
         | 
| 596 | 
            +
                        "file_type": "text_inversion",
         | 
| 597 | 
            +
                        "framework": "pytorch",
         | 
| 598 | 
            +
                    }
         | 
| 599 | 
            +
             | 
| 600 | 
            +
                    if not isinstance(pretrained_model_name_or_path, list):
         | 
| 601 | 
            +
                        pretrained_model_name_or_paths = [pretrained_model_name_or_path]
         | 
| 602 | 
            +
                    else:
         | 
| 603 | 
            +
                        pretrained_model_name_or_paths = pretrained_model_name_or_path
         | 
| 604 | 
            +
             | 
| 605 | 
            +
                    if isinstance(token, str):
         | 
| 606 | 
            +
                        tokens = [token]
         | 
| 607 | 
            +
                    elif token is None:
         | 
| 608 | 
            +
                        tokens = [None] * len(pretrained_model_name_or_paths)
         | 
| 609 | 
            +
                    else:
         | 
| 610 | 
            +
                        tokens = token
         | 
| 611 | 
            +
             | 
| 612 | 
            +
                    if len(pretrained_model_name_or_paths) != len(tokens):
         | 
| 613 | 
            +
                        raise ValueError(
         | 
| 614 | 
            +
                        f"You have passed a list of models of length {len(pretrained_model_name_or_paths)}, and a list of tokens of length {len(tokens)}. "
         | 
| 615 | 
            +
                            f"Make sure both lists have the same length."
         | 
| 616 | 
            +
                        )
         | 
| 617 | 
            +
             | 
| 618 | 
            +
                    valid_tokens = [t for t in tokens if t is not None]
         | 
| 619 | 
            +
                    if len(set(valid_tokens)) < len(valid_tokens):
         | 
| 620 | 
            +
                        raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}")
         | 
| 621 | 
            +
             | 
| 622 | 
            +
                    token_ids_and_embeddings = []
         | 
| 623 | 
            +
             | 
| 624 | 
            +
                    for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens):
         | 
| 625 | 
            +
                        if not isinstance(pretrained_model_name_or_path, dict):
         | 
| 626 | 
            +
                            # 1. Load textual inversion file
         | 
| 627 | 
            +
                            model_file = None
         | 
| 628 | 
            +
                            # Let's first try to load .safetensors weights
         | 
| 629 | 
            +
                            if (use_safetensors and weight_name is None) or (
         | 
| 630 | 
            +
                                weight_name is not None and weight_name.endswith(".safetensors")
         | 
| 631 | 
            +
                            ):
         | 
| 632 | 
            +
                                try:
         | 
| 633 | 
            +
                                    model_file = _get_model_file(
         | 
| 634 | 
            +
                                        pretrained_model_name_or_path,
         | 
| 635 | 
            +
                                        weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
         | 
| 636 | 
            +
                                        cache_dir=cache_dir,
         | 
| 637 | 
            +
                                        force_download=force_download,
         | 
| 638 | 
            +
                                        resume_download=resume_download,
         | 
| 639 | 
            +
                                        proxies=proxies,
         | 
| 640 | 
            +
                                        local_files_only=local_files_only,
         | 
| 641 | 
            +
                                        use_auth_token=use_auth_token,
         | 
| 642 | 
            +
                                        revision=revision,
         | 
| 643 | 
            +
                                        subfolder=subfolder,
         | 
| 644 | 
            +
                                        user_agent=user_agent,
         | 
| 645 | 
            +
                                    )
         | 
| 646 | 
            +
                                    state_dict = safetensors.torch.load_file(model_file, device="cpu")
         | 
| 647 | 
            +
                                except Exception as e:
         | 
| 648 | 
            +
                                    if not allow_pickle:
         | 
| 649 | 
            +
                                        raise e
         | 
| 650 | 
            +
             | 
| 651 | 
            +
                                    model_file = None
         | 
| 652 | 
            +
             | 
| 653 | 
            +
                            if model_file is None:
         | 
| 654 | 
            +
                                model_file = _get_model_file(
         | 
| 655 | 
            +
                                    pretrained_model_name_or_path,
         | 
| 656 | 
            +
                                    weights_name=weight_name or TEXT_INVERSION_NAME,
         | 
| 657 | 
            +
                                    cache_dir=cache_dir,
         | 
| 658 | 
            +
                                    force_download=force_download,
         | 
| 659 | 
            +
                                    resume_download=resume_download,
         | 
| 660 | 
            +
                                    proxies=proxies,
         | 
| 661 | 
            +
                                    local_files_only=local_files_only,
         | 
| 662 | 
            +
                                    use_auth_token=use_auth_token,
         | 
| 663 | 
            +
                                    revision=revision,
         | 
| 664 | 
            +
                                    subfolder=subfolder,
         | 
| 665 | 
            +
                                    user_agent=user_agent,
         | 
| 666 | 
            +
                                )
         | 
| 667 | 
            +
                                state_dict = torch.load(model_file, map_location="cpu")
         | 
| 668 | 
            +
                        else:
         | 
| 669 | 
            +
                            state_dict = pretrained_model_name_or_path
         | 
| 670 | 
            +
             | 
| 671 | 
            +
                    # 2. Load token and embedding correctly from file
         | 
| 672 | 
            +
                        loaded_token = None
         | 
| 673 | 
            +
                        if isinstance(state_dict, torch.Tensor):
         | 
| 674 | 
            +
                            if token is None:
         | 
| 675 | 
            +
                                raise ValueError(
         | 
| 676 | 
            +
                                    "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
         | 
| 677 | 
            +
                                )
         | 
| 678 | 
            +
                            embedding = state_dict
         | 
| 679 | 
            +
                        elif len(state_dict) == 1:
         | 
| 680 | 
            +
                            # diffusers
         | 
| 681 | 
            +
                            loaded_token, embedding = next(iter(state_dict.items()))
         | 
| 682 | 
            +
                        elif "string_to_param" in state_dict:
         | 
| 683 | 
            +
                            # A1111
         | 
| 684 | 
            +
                            loaded_token = state_dict["name"]
         | 
| 685 | 
            +
                            embedding = state_dict["string_to_param"]["*"]
         | 
| 686 | 
            +
             | 
| 687 | 
            +
                        if token is not None and loaded_token != token:
         | 
| 688 | 
            +
                            logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
         | 
| 689 | 
            +
                        else:
         | 
| 690 | 
            +
                            token = loaded_token
         | 
| 691 | 
            +
             | 
| 692 | 
            +
                        embedding = embedding.to(dtype=self.text_encoder.dtype, device=self.text_encoder.device)
         | 
| 693 | 
            +
             | 
| 694 | 
            +
                        # 3. Make sure we don't mess up the tokenizer or text encoder
         | 
| 695 | 
            +
                        vocab = self.tokenizer.get_vocab()
         | 
| 696 | 
            +
                        if token in vocab:
         | 
| 697 | 
            +
                            raise ValueError(
         | 
| 698 | 
            +
                                f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
         | 
| 699 | 
            +
                            )
         | 
| 700 | 
            +
                        elif f"{token}_1" in vocab:
         | 
| 701 | 
            +
                            multi_vector_tokens = [token]
         | 
| 702 | 
            +
                            i = 1
         | 
| 703 | 
            +
                            while f"{token}_{i}" in self.tokenizer.added_tokens_encoder:
         | 
| 704 | 
            +
                                multi_vector_tokens.append(f"{token}_{i}")
         | 
| 705 | 
            +
                                i += 1
         | 
| 706 | 
            +
             | 
| 707 | 
            +
                            raise ValueError(
         | 
| 708 | 
            +
                                f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
         | 
| 709 | 
            +
                            )
         | 
| 710 | 
            +
             | 
| 711 | 
            +
                        is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
         | 
| 712 | 
            +
             | 
| 713 | 
            +
                        if is_multi_vector:
         | 
| 714 | 
            +
                            tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
         | 
| 715 | 
            +
                            embeddings = [e for e in embedding]  # noqa: C416
         | 
| 716 | 
            +
                        else:
         | 
| 717 | 
            +
                            tokens = [token]
         | 
| 718 | 
            +
                            embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding]
         | 
| 719 | 
            +
             | 
| 720 | 
            +
                        # add tokens and get ids
         | 
| 721 | 
            +
                        self.tokenizer.add_tokens(tokens)
         | 
| 722 | 
            +
                        token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
         | 
| 723 | 
            +
                        token_ids_and_embeddings += zip(token_ids, embeddings)
         | 
| 724 | 
            +
             | 
| 725 | 
            +
                        logger.info(f"Loaded textual inversion embedding for {token}.")
         | 
| 726 | 
            +
             | 
| 727 | 
            +
                    # resize token embeddings and set all new embeddings
         | 
| 728 | 
            +
                    self.text_encoder.resize_token_embeddings(len(self.tokenizer))
         | 
| 729 | 
            +
                    for token_id, embedding in token_ids_and_embeddings:
         | 
| 730 | 
            +
                        self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
         | 
| 731 | 
            +
             | 
| 732 | 
            +
             | 
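For orientation, the loader above accepts three on-disk layouts: a bare tensor (which requires an explicit `token=`), the single-entry 🤗 Diffusers dict, and the Automatic1111 dict carrying `string_to_param`. A small illustrative sketch of how each shape is recognized; the token names, shapes, and the 768-wide embedding dimension are invented examples:

```py
import torch

dim = 768
raw_tensor = torch.randn(4, dim)                        # needs token="..." at load time
diffusers_style = {"<my-token>": torch.randn(1, dim)}   # single token -> embedding
a1111_style = {                                         # Automatic1111 .pt layout
    "name": "<my-token>",
    "string_to_param": {"*": torch.randn(4, dim)},      # 4 rows => multi-vector token
}

def describe(state_dict):
    # Mirrors the format detection in load_textual_inversion above.
    if isinstance(state_dict, torch.Tensor):
        return "raw tensor: token must be passed explicitly"
    if "string_to_param" in state_dict:
        emb = state_dict["string_to_param"]["*"]
        return f"A1111 format: token {state_dict['name']!r}, {emb.shape[0]} vector(s)"
    token, emb = next(iter(state_dict.items()))
    return f"diffusers format: token {token!r}, {emb.shape[0]} vector(s)"

for sd in (raw_tensor, diffusers_style, a1111_style):
    print(describe(sd))
```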
| 733 | 
            +
            class LoraLoaderMixin:
         | 
| 734 | 
            +
                r"""
         | 
| 735 | 
            +
                Load LoRA layers into [`UNet2DConditionModel`] and
         | 
| 736 | 
            +
                [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
         | 
| 737 | 
            +
                """
         | 
| 738 | 
            +
                text_encoder_name = TEXT_ENCODER_NAME
         | 
| 739 | 
            +
                unet_name = UNET_NAME
         | 
| 740 | 
            +
             | 
| 741 | 
            +
                def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
         | 
| 742 | 
            +
                    r"""
         | 
| 743 | 
            +
                    Load pretrained LoRA attention processor layers into [`UNet2DConditionModel`] and
         | 
| 744 | 
            +
                    [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
         | 
| 745 | 
            +
             | 
| 746 | 
            +
                    Parameters:
         | 
| 747 | 
            +
                        pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
         | 
| 748 | 
            +
                            Can be either:
         | 
| 749 | 
            +
             | 
| 750 | 
            +
                                - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
         | 
| 751 | 
            +
                                  the Hub.
         | 
| 752 | 
            +
                                - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
         | 
| 753 | 
            +
                                  with [`ModelMixin.save_pretrained`].
         | 
| 754 | 
            +
                                - A [torch state
         | 
| 755 | 
            +
                                  dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
         | 
| 756 | 
            +
             | 
| 757 | 
            +
                        cache_dir (`Union[str, os.PathLike]`, *optional*):
         | 
| 758 | 
            +
                            Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
         | 
| 759 | 
            +
                            is not used.
         | 
| 760 | 
            +
                        force_download (`bool`, *optional*, defaults to `False`):
         | 
| 761 | 
            +
                            Whether or not to force the (re-)download of the model weights and configuration files, overriding the
         | 
| 762 | 
            +
                            cached versions if they exist.
         | 
| 763 | 
            +
                        resume_download (`bool`, *optional*, defaults to `False`):
         | 
| 764 | 
            +
                            Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
         | 
| 765 | 
            +
                            incompletely downloaded files are deleted.
         | 
| 766 | 
            +
                        proxies (`Dict[str, str]`, *optional*):
         | 
| 767 | 
            +
                            A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
         | 
| 768 | 
            +
                            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
         | 
| 769 | 
            +
                        local_files_only (`bool`, *optional*, defaults to `False`):
         | 
| 770 | 
            +
                            Whether to only load local model weights and configuration files or not. If set to `True`, the model
         | 
| 771 | 
            +
                            won't be downloaded from the Hub.
         | 
| 772 | 
            +
                        use_auth_token (`str` or *bool*, *optional*):
         | 
| 773 | 
            +
                            The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
         | 
| 774 | 
            +
                            `diffusers-cli login` (stored in `~/.huggingface`) is used.
         | 
| 775 | 
            +
                        revision (`str`, *optional*, defaults to `"main"`):
         | 
| 776 | 
            +
                            The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
         | 
| 777 | 
            +
                            allowed by Git.
         | 
| 778 | 
            +
                        subfolder (`str`, *optional*, defaults to `""`):
         | 
| 779 | 
            +
                            The subfolder location of a model file within a larger model repository on the Hub or locally.
         | 
| 780 | 
            +
                        mirror (`str`, *optional*):
         | 
| 781 | 
            +
                            Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
         | 
| 782 | 
            +
                            guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
         | 
| 783 | 
            +
                            information.
         | 
| 784 | 
            +
             | 
| 785 | 
            +
                    """
         | 
| 786 | 
            +
                    # Load the main state dict first which has the LoRA layers for either of
         | 
| 787 | 
            +
                    # UNet and text encoder or both.
         | 
| 788 | 
            +
                    cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
         | 
| 789 | 
            +
                    force_download = kwargs.pop("force_download", False)
         | 
| 790 | 
            +
                    resume_download = kwargs.pop("resume_download", False)
         | 
| 791 | 
            +
                    proxies = kwargs.pop("proxies", None)
         | 
| 792 | 
            +
                    local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
         | 
| 793 | 
            +
                    use_auth_token = kwargs.pop("use_auth_token", None)
         | 
| 794 | 
            +
                    revision = kwargs.pop("revision", None)
         | 
| 795 | 
            +
                    subfolder = kwargs.pop("subfolder", None)
         | 
| 796 | 
            +
                    weight_name = kwargs.pop("weight_name", None)
         | 
| 797 | 
            +
                    use_safetensors = kwargs.pop("use_safetensors", None)
         | 
| 798 | 
            +
             | 
| 799 | 
            +
                    # set lora scale to a reasonable default
         | 
| 800 | 
            +
                    self._lora_scale = 1.0
         | 
| 801 | 
            +
             | 
| 802 | 
            +
                    if use_safetensors and not is_safetensors_available():
         | 
| 803 | 
            +
                        raise ValueError(
         | 
| 804 | 
            +
                        "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors`."
         | 
| 805 | 
            +
                        )
         | 
| 806 | 
            +
             | 
| 807 | 
            +
                    allow_pickle = False
         | 
| 808 | 
            +
                    if use_safetensors is None:
         | 
| 809 | 
            +
                        use_safetensors = is_safetensors_available()
         | 
| 810 | 
            +
                        allow_pickle = True
         | 
| 811 | 
            +
             | 
| 812 | 
            +
                    user_agent = {
         | 
| 813 | 
            +
                        "file_type": "attn_procs_weights",
         | 
| 814 | 
            +
                        "framework": "pytorch",
         | 
| 815 | 
            +
                    }
         | 
| 816 | 
            +
             | 
| 817 | 
            +
                    model_file = None
         | 
| 818 | 
            +
                    if not isinstance(pretrained_model_name_or_path_or_dict, dict):
         | 
| 819 | 
            +
                        # Let's first try to load .safetensors weights
         | 
| 820 | 
            +
                        if (use_safetensors and weight_name is None) or (
         | 
| 821 | 
            +
                            weight_name is not None and weight_name.endswith(".safetensors")
         | 
| 822 | 
            +
                        ):
         | 
| 823 | 
            +
                            try:
         | 
| 824 | 
            +
                                model_file = _get_model_file(
         | 
| 825 | 
            +
                                    pretrained_model_name_or_path_or_dict,
         | 
| 826 | 
            +
                                    weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
         | 
| 827 | 
            +
                                    cache_dir=cache_dir,
         | 
| 828 | 
            +
                                    force_download=force_download,
         | 
| 829 | 
            +
                                    resume_download=resume_download,
         | 
| 830 | 
            +
                                    proxies=proxies,
         | 
| 831 | 
            +
                                    local_files_only=local_files_only,
         | 
| 832 | 
            +
                                    use_auth_token=use_auth_token,
         | 
| 833 | 
            +
                                    revision=revision,
         | 
| 834 | 
            +
                                    subfolder=subfolder,
         | 
| 835 | 
            +
                                    user_agent=user_agent,
         | 
| 836 | 
            +
                                )
         | 
| 837 | 
            +
                                state_dict = safetensors.torch.load_file(model_file, device="cpu")
         | 
| 838 | 
            +
                            except IOError as e:
         | 
| 839 | 
            +
                                if not allow_pickle:
         | 
| 840 | 
            +
                                    raise e
         | 
| 841 | 
            +
                                # try loading non-safetensors weights
         | 
| 842 | 
            +
                                pass
         | 
| 843 | 
            +
                        if model_file is None:
         | 
| 844 | 
            +
                            model_file = _get_model_file(
         | 
| 845 | 
            +
                                pretrained_model_name_or_path_or_dict,
         | 
| 846 | 
            +
                                weights_name=weight_name or LORA_WEIGHT_NAME,
         | 
| 847 | 
            +
                                cache_dir=cache_dir,
         | 
| 848 | 
            +
                                force_download=force_download,
         | 
| 849 | 
            +
                                resume_download=resume_download,
         | 
| 850 | 
            +
                                proxies=proxies,
         | 
| 851 | 
            +
                                local_files_only=local_files_only,
         | 
| 852 | 
            +
                                use_auth_token=use_auth_token,
         | 
| 853 | 
            +
                                revision=revision,
         | 
| 854 | 
            +
                                subfolder=subfolder,
         | 
| 855 | 
            +
                                user_agent=user_agent,
         | 
| 856 | 
            +
                            )
         | 
| 857 | 
            +
                            state_dict = torch.load(model_file, map_location="cpu")
         | 
| 858 | 
            +
                    else:
         | 
| 859 | 
            +
                        state_dict = pretrained_model_name_or_path_or_dict
         | 
| 860 | 
            +
             | 
| 861 | 
            +
                    # Convert kohya-ss Style LoRA attn procs to diffusers attn procs
         | 
| 862 | 
            +
                    network_alpha = None
         | 
| 863 | 
            +
                    if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()):
         | 
| 864 | 
            +
                        state_dict, network_alpha = self._convert_kohya_lora_to_diffusers(state_dict)
         | 
| 865 | 
            +
             | 
| 866 | 
            +
                    # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
         | 
| 867 | 
            +
                    # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
         | 
| 868 | 
            +
                    # their prefixes.
         | 
| 869 | 
            +
                    keys = list(state_dict.keys())
         | 
| 870 | 
            +
                    if all(key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in keys):
         | 
| 871 | 
            +
                        # Load the layers corresponding to UNet.
         | 
| 872 | 
            +
                        unet_keys = [k for k in keys if k.startswith(self.unet_name)]
         | 
| 873 | 
            +
                        logger.info(f"Loading {self.unet_name}.")
         | 
| 874 | 
            +
                        unet_lora_state_dict = {
         | 
| 875 | 
            +
                            k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys
         | 
| 876 | 
            +
                        }
         | 
| 877 | 
            +
                        self.unet.load_attn_procs(unet_lora_state_dict, network_alpha=network_alpha)
         | 
| 878 | 
            +
             | 
| 879 | 
            +
                        # Load the layers corresponding to text encoder and make necessary adjustments.
         | 
| 880 | 
            +
                        text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)]
         | 
| 881 | 
            +
                        text_encoder_lora_state_dict = {
         | 
| 882 | 
            +
                            k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
         | 
| 883 | 
            +
                        }
         | 
| 884 | 
            +
                        if len(text_encoder_lora_state_dict) > 0:
         | 
| 885 | 
            +
                            logger.info(f"Loading {self.text_encoder_name}.")
         | 
| 886 | 
            +
                            attn_procs_text_encoder = self._load_text_encoder_attn_procs(
         | 
| 887 | 
            +
                                text_encoder_lora_state_dict, network_alpha=network_alpha
         | 
| 888 | 
            +
                            )
         | 
| 889 | 
            +
                            self._modify_text_encoder(attn_procs_text_encoder)
         | 
| 890 | 
            +
             | 
| 891 | 
            +
                        # save the text encoder's lora attn procs so that they can be easily retrieved later
         | 
| 892 | 
            +
                            self._text_encoder_lora_attn_procs = attn_procs_text_encoder
         | 
| 893 | 
            +
             | 
| 894 | 
            +
                    # Otherwise, we're dealing with the old format. This means the `state_dict` should only
         | 
| 895 | 
            +
                    # contain the module names of the `unet` as its keys WITHOUT any prefix.
         | 
| 896 | 
            +
                    elif not all(
         | 
| 897 | 
            +
                        key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
         | 
| 898 | 
            +
                    ):
         | 
| 899 | 
            +
                        self.unet.load_attn_procs(state_dict)
         | 
| 900 | 
            +
                    warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them into a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
         | 
| 901 | 
            +
                        warnings.warn(warn_message)
         | 
| 902 | 
            +
             | 
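`load_lora_weights` handles two serializations: the newer one, whose keys are prefixed with `unet.` / `text_encoder.`, and the legacy one, whose keys are bare UNet module names. A minimal sketch of that dispatch with invented key names:

```py
UNET_NAME, TEXT_ENCODER_NAME = "unet", "text_encoder"

new_format = {
    "unet.down_blocks.0.attentions.0.processor.to_q_lora.down.weight": None,
    "text_encoder.text_model.encoder.layers.0.self_attn.to_q_lora.down.weight": None,
}
legacy_format = {
    "down_blocks.0.attentions.0.processor.to_q_lora.down.weight": None,
}

def layout(state_dict):
    # Mirrors the prefix check in load_lora_weights above.
    keys = list(state_dict)
    if all(k.startswith(UNET_NAME) or k.startswith(TEXT_ENCODER_NAME) for k in keys):
        return "new format: split into unet / text_encoder sub-dicts"
    return "legacy format: passed straight to unet.load_attn_procs"

print(layout(new_format))     # new format: split into unet / text_encoder sub-dicts
print(layout(legacy_format))  # legacy format: passed straight to unet.load_attn_procs
```

In typical use the call is simply `pipe.load_lora_weights("some/repo-or-dir", weight_name="my_lora.safetensors")` on a pipeline that mixes in this class; the repo id and file name here are placeholders.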
| 903 | 
            +
                @property
         | 
| 904 | 
            +
                def lora_scale(self) -> float:
         | 
| 905 | 
            +
                # Property that returns the LoRA scale, which can be set at runtime by the pipeline.
         | 
| 906 | 
            +
                    # if _lora_scale has not been set, return 1
         | 
| 907 | 
            +
                    return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
         | 
| 908 | 
            +
             | 
| 909 | 
            +
                @property
         | 
| 910 | 
            +
                def text_encoder_lora_attn_procs(self):
         | 
| 911 | 
            +
                    if hasattr(self, "_text_encoder_lora_attn_procs"):
         | 
| 912 | 
            +
                        return self._text_encoder_lora_attn_procs
         | 
| 913 | 
            +
                    return
         | 
| 914 | 
            +
             | 
| 915 | 
            +
                def _remove_text_encoder_monkey_patch(self):
         | 
| 916 | 
            +
                # Loop over the CLIPAttention modules of text_encoder
         | 
| 917 | 
            +
                    for name, attn_module in self.text_encoder.named_modules():
         | 
| 918 | 
            +
                        if name.endswith(TEXT_ENCODER_ATTN_MODULE):
         | 
| 919 | 
            +
                            # Loop over the LoRA layers
         | 
| 920 | 
            +
                            for _, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
         | 
| 921 | 
            +
                                # Retrieve the q/k/v/out projection of CLIPAttention
         | 
| 922 | 
            +
                                module = attn_module.get_submodule(text_encoder_attr)
         | 
| 923 | 
            +
                                if hasattr(module, "old_forward"):
         | 
| 924 | 
            +
                                    # restore original `forward` to remove monkey-patch
         | 
| 925 | 
            +
                                    module.forward = module.old_forward
         | 
| 926 | 
            +
                                    delattr(module, "old_forward")
         | 
| 927 | 
            +
             | 
| 928 | 
            +
                def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
         | 
| 929 | 
            +
                    r"""
         | 
| 930 | 
            +
                    Monkey-patches the forward passes of attention modules of the text encoder.
         | 
| 931 | 
            +
             | 
| 932 | 
            +
                    Parameters:
         | 
| 933 | 
            +
                        attn_processors: Dict[str, `LoRAAttnProcessor`]:
         | 
| 934 | 
            +
                            A dictionary mapping the module names and their corresponding [`~LoRAAttnProcessor`].
         | 
| 935 | 
            +
                    """
         | 
| 936 | 
            +
             | 
| 937 | 
            +
                    # First, remove any monkey-patch that might have been applied before
         | 
| 938 | 
            +
                    self._remove_text_encoder_monkey_patch()
         | 
| 939 | 
            +
             | 
| 940 | 
            +
                # Loop over the CLIPAttention modules of text_encoder
         | 
| 941 | 
            +
                    for name, attn_module in self.text_encoder.named_modules():
         | 
| 942 | 
            +
                        if name.endswith(TEXT_ENCODER_ATTN_MODULE):
         | 
| 943 | 
            +
                            # Loop over the LoRA layers
         | 
| 944 | 
            +
                            for attn_proc_attr, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
         | 
| 945 | 
            +
                                # Retrieve the q/k/v/out projection of CLIPAttention and its corresponding LoRA layer.
         | 
| 946 | 
            +
                                module = attn_module.get_submodule(text_encoder_attr)
         | 
| 947 | 
            +
                                lora_layer = attn_processors[name].get_submodule(attn_proc_attr)
         | 
| 948 | 
            +
             | 
| 949 | 
            +
                            # save old_forward on the module so it can be used later to remove the monkey-patch
         | 
| 950 | 
            +
                                old_forward = module.old_forward = module.forward
         | 
| 951 | 
            +
             | 
| 952 | 
            +
                                # create a new scope that locks in the old_forward, lora_layer value for each new_forward function
         | 
| 953 | 
            +
                                # for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
         | 
| 954 | 
            +
                                def make_new_forward(old_forward, lora_layer):
         | 
| 955 | 
            +
                                    def new_forward(x):
         | 
| 956 | 
            +
                                        result = old_forward(x) + self.lora_scale * lora_layer(x)
         | 
| 957 | 
            +
                                        return result
         | 
| 958 | 
            +
             | 
| 959 | 
            +
                                    return new_forward
         | 
| 960 | 
            +
             | 
| 961 | 
            +
                                # Monkey-patch.
         | 
| 962 | 
            +
                                module.forward = make_new_forward(old_forward, lora_layer)
         | 
| 963 | 
            +
             | 
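The `make_new_forward` factory above matters because Python closures bind loop variables late; without it, every patched `forward` would call the last `lora_layer` seen in the loop. A tiny standalone illustration of the difference:

```py
layers = [lambda x: x + 1, lambda x: x + 10]

# Late binding: every closure sees the final value of `layer` after the loop.
bad = [lambda x: layer(x) for layer in layers]
print([f(0) for f in bad])    # [10, 10]

# Factory function (same idea as make_new_forward): each closure captures its own layer.
def make(layer):
    return lambda x: layer(x)

good = [make(layer) for layer in layers]
print([f(0) for f in good])   # [1, 10]
```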
| 964 | 
            +
                @property
         | 
| 965 | 
            +
                def _lora_attn_processor_attr_to_text_encoder_attr(self):
         | 
| 966 | 
            +
                    return {
         | 
| 967 | 
            +
                        "to_q_lora": "q_proj",
         | 
| 968 | 
            +
                        "to_k_lora": "k_proj",
         | 
| 969 | 
            +
                        "to_v_lora": "v_proj",
         | 
| 970 | 
            +
                        "to_out_lora": "out_proj",
         | 
| 971 | 
            +
                    }
         | 
| 972 | 
            +
             | 
| 973 | 
            +
                def _load_text_encoder_attn_procs(
         | 
| 974 | 
            +
                    self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs
         | 
| 975 | 
            +
                ):
         | 
| 976 | 
            +
                    r"""
         | 
| 977 | 
            +
                    Load pretrained attention processor layers for
         | 
| 978 | 
            +
                    [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
         | 
| 979 | 
            +
             | 
| 980 | 
            +
                    <Tip warning={true}>
         | 
| 981 | 
            +
             | 
| 982 | 
            +
                    This function is experimental and might change in the future.
         | 
| 983 | 
            +
             | 
| 984 | 
            +
                    </Tip>
         | 
| 985 | 
            +
             | 
| 986 | 
            +
                    Parameters:
         | 
| 987 | 
            +
                        pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
         | 
| 988 | 
            +
                            Can be either:
         | 
| 989 | 
            +
             | 
| 990 | 
            +
                                - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
         | 
| 991 | 
            +
                                  Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
         | 
| 992 | 
            +
                                - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
         | 
| 993 | 
            +
                                  `./my_model_directory/`.
         | 
| 994 | 
            +
                                - A [torch state
         | 
| 995 | 
            +
                                  dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
         | 
| 996 | 
            +
             | 
| 997 | 
            +
                        cache_dir (`Union[str, os.PathLike]`, *optional*):
         | 
| 998 | 
            +
                            Path to a directory in which a downloaded pretrained model configuration should be cached if the
         | 
| 999 | 
            +
                            standard cache should not be used.
         | 
| 1000 | 
            +
                        force_download (`bool`, *optional*, defaults to `False`):
         | 
| 1001 | 
            +
                            Whether or not to force the (re-)download of the model weights and configuration files, overriding the
         | 
| 1002 | 
            +
                            cached versions if they exist.
         | 
| 1003 | 
            +
                        resume_download (`bool`, *optional*, defaults to `False`):
         | 
| 1004 | 
            +
                        Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
         | 
| 1005 | 
            +
                        incompletely downloaded files are deleted.
         | 
| 1006 | 
            +
                        proxies (`Dict[str, str]`, *optional*):
         | 
| 1007 | 
            +
                            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
         | 
| 1008 | 
            +
                            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
         | 
| 1009 | 
            +
                        local_files_only (`bool`, *optional*, defaults to `False`):
         | 
| 1010 | 
            +
                            Whether or not to only look at local files (i.e., do not try to download the model).
         | 
| 1011 | 
            +
                        use_auth_token (`str` or *bool*, *optional*):
         | 
| 1012 | 
            +
                            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
         | 
| 1013 | 
            +
                            when running `diffusers-cli login` (stored in `~/.huggingface`).
         | 
| 1014 | 
            +
                        revision (`str`, *optional*, defaults to `"main"`):
         | 
| 1015 | 
            +
                            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
         | 
| 1016 | 
            +
                            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
         | 
| 1017 | 
            +
                            identifier allowed by git.
         | 
| 1018 | 
            +
                        subfolder (`str`, *optional*, defaults to `""`):
         | 
| 1019 | 
            +
                            In case the relevant files are located inside a subfolder of the model repo (either remote in
         | 
| 1020 | 
            +
                            huggingface.co or downloaded locally), you can specify the folder name here.
         | 
| 1021 | 
            +
                        mirror (`str`, *optional*):
         | 
| 1022 | 
            +
                            Mirror source to accelerate downloads in China. If you are from China and have an accessibility
         | 
| 1023 | 
            +
                            problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
         | 
| 1024 | 
            +
                            Please refer to the mirror site for more information.
         | 
| 1025 | 
            +
             | 
| 1026 | 
            +
                    Returns:
         | 
| 1027 | 
            +
                        `Dict[name, LoRAAttnProcessor]`: Mapping between the module names and their corresponding
         | 
| 1028 | 
            +
                        [`LoRAAttnProcessor`].
         | 
| 1029 | 
            +
             | 
| 1030 | 
            +
                    <Tip>
         | 
| 1031 | 
            +
             | 
| 1032 | 
            +
                You need to be logged in (`huggingface-cli login`) to use private or [gated
         | 
| 1033 | 
            +
                    models](https://huggingface.co/docs/hub/models-gated#gated-models).
         | 
| 1034 | 
            +
             | 
| 1035 | 
            +
                    </Tip>
         | 
| 1036 | 
            +
                    """
         | 
| 1037 | 
            +
             | 
| 1038 | 
            +
                    cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
         | 
| 1039 | 
            +
                    force_download = kwargs.pop("force_download", False)
         | 
| 1040 | 
            +
                    resume_download = kwargs.pop("resume_download", False)
         | 
| 1041 | 
            +
                    proxies = kwargs.pop("proxies", None)
         | 
| 1042 | 
            +
                    local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
         | 
| 1043 | 
            +
                    use_auth_token = kwargs.pop("use_auth_token", None)
         | 
| 1044 | 
            +
                    revision = kwargs.pop("revision", None)
         | 
| 1045 | 
            +
                    subfolder = kwargs.pop("subfolder", None)
         | 
| 1046 | 
            +
                    weight_name = kwargs.pop("weight_name", None)
         | 
| 1047 | 
            +
                    use_safetensors = kwargs.pop("use_safetensors", None)
         | 
| 1048 | 
            +
                    network_alpha = kwargs.pop("network_alpha", None)
         | 
| 1049 | 
            +
             | 
| 1050 | 
            +
                    if use_safetensors and not is_safetensors_available():
         | 
| 1051 | 
            +
                        raise ValueError(
         | 
| 1052 | 
            +
                            "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
         | 
| 1053 | 
            +
                        )
         | 
| 1054 | 
            +
             | 
| 1055 | 
            +
                    allow_pickle = False
         | 
| 1056 | 
            +
                    if use_safetensors is None:
         | 
| 1057 | 
            +
                        use_safetensors = is_safetensors_available()
         | 
| 1058 | 
            +
                        allow_pickle = True
         | 
| 1059 | 
            +
             | 
| 1060 | 
            +
                    user_agent = {
         | 
| 1061 | 
            +
                        "file_type": "attn_procs_weights",
         | 
| 1062 | 
            +
                        "framework": "pytorch",
         | 
| 1063 | 
            +
                    }
         | 
| 1064 | 
            +
             | 
| 1065 | 
            +
                    model_file = None
         | 
| 1066 | 
            +
                    if not isinstance(pretrained_model_name_or_path_or_dict, dict):
         | 
| 1067 | 
            +
                        # Let's first try to load .safetensors weights
         | 
| 1068 | 
            +
                        if (use_safetensors and weight_name is None) or (
         | 
| 1069 | 
            +
                            weight_name is not None and weight_name.endswith(".safetensors")
         | 
| 1070 | 
            +
                        ):
         | 
| 1071 | 
            +
                            try:
         | 
| 1072 | 
            +
                                model_file = _get_model_file(
         | 
| 1073 | 
            +
                                    pretrained_model_name_or_path_or_dict,
         | 
| 1074 | 
            +
                                    weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
         | 
| 1075 | 
            +
                                    cache_dir=cache_dir,
         | 
| 1076 | 
            +
                                    force_download=force_download,
         | 
| 1077 | 
            +
                                    resume_download=resume_download,
         | 
| 1078 | 
            +
                                    proxies=proxies,
         | 
| 1079 | 
            +
                                    local_files_only=local_files_only,
         | 
| 1080 | 
            +
                                    use_auth_token=use_auth_token,
         | 
| 1081 | 
            +
                                    revision=revision,
         | 
| 1082 | 
            +
                                    subfolder=subfolder,
         | 
| 1083 | 
            +
                                    user_agent=user_agent,
         | 
| 1084 | 
            +
                                )
         | 
| 1085 | 
            +
                                state_dict = safetensors.torch.load_file(model_file, device="cpu")
         | 
| 1086 | 
            +
                            except IOError as e:
         | 
| 1087 | 
            +
                                if not allow_pickle:
         | 
| 1088 | 
            +
                                    raise e
         | 
| 1089 | 
            +
                                # try loading non-safetensors weights
         | 
| 1090 | 
            +
                                pass
         | 
| 1091 | 
            +
                        if model_file is None:
         | 
| 1092 | 
            +
                            model_file = _get_model_file(
         | 
| 1093 | 
            +
                                pretrained_model_name_or_path_or_dict,
         | 
| 1094 | 
            +
                                weights_name=weight_name or LORA_WEIGHT_NAME,
         | 
| 1095 | 
            +
                                cache_dir=cache_dir,
         | 
| 1096 | 
            +
                                force_download=force_download,
         | 
| 1097 | 
            +
                                resume_download=resume_download,
         | 
| 1098 | 
            +
                                proxies=proxies,
         | 
| 1099 | 
            +
                                local_files_only=local_files_only,
         | 
| 1100 | 
            +
                                use_auth_token=use_auth_token,
         | 
| 1101 | 
            +
                                revision=revision,
         | 
| 1102 | 
            +
                                subfolder=subfolder,
         | 
| 1103 | 
            +
                                user_agent=user_agent,
         | 
| 1104 | 
            +
                            )
         | 
| 1105 | 
            +
                            state_dict = torch.load(model_file, map_location="cpu")
         | 
| 1106 | 
            +
                    else:
         | 
| 1107 | 
            +
                        state_dict = pretrained_model_name_or_path_or_dict
         | 
| 1108 | 
            +
             | 
| 1109 | 
            +
                    # fill attn processors
         | 
| 1110 | 
            +
                    attn_processors = {}
         | 
| 1111 | 
            +
             | 
| 1112 | 
            +
                    is_lora = all("lora" in k for k in state_dict.keys())
         | 
| 1113 | 
            +
             | 
| 1114 | 
            +
                    if is_lora:
         | 
| 1115 | 
            +
                        lora_grouped_dict = defaultdict(dict)
         | 
| 1116 | 
            +
                        for key, value in state_dict.items():
         | 
| 1117 | 
            +
                            attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
         | 
| 1118 | 
            +
                            lora_grouped_dict[attn_processor_key][sub_key] = value
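                    # Illustrative only (hypothetical key; not part of the original diff): the
                    # grouping above splits each flat LoRA key on its last three dot components,
                    # e.g. "text_model.encoder.layers.0.self_attn.to_q_lora.down.weight" becomes
                    # attn_processor_key = "text_model.encoder.layers.0.self_attn" and
                    # sub_key = "to_q_lora.down.weight".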
         | 
| 1119 | 
            +
             | 
| 1120 | 
            +
                        for key, value_dict in lora_grouped_dict.items():
         | 
| 1121 | 
            +
                            rank = value_dict["to_k_lora.down.weight"].shape[0]
         | 
| 1122 | 
            +
                            cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
         | 
| 1123 | 
            +
                            hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
         | 
| 1124 | 
            +
             | 
| 1125 | 
            +
                            attn_processor_class = (
         | 
| 1126 | 
            +
                                LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
         | 
| 1127 | 
            +
                            )
         | 
| 1128 | 
            +
                            attn_processors[key] = attn_processor_class(
         | 
| 1129 | 
            +
                                hidden_size=hidden_size,
         | 
| 1130 | 
            +
                                cross_attention_dim=cross_attention_dim,
         | 
| 1131 | 
            +
                                rank=rank,
         | 
| 1132 | 
            +
                                network_alpha=network_alpha,
         | 
| 1133 | 
            +
                            )
         | 
| 1134 | 
            +
                            attn_processors[key].load_state_dict(value_dict)
         | 
| 1135 | 
            +
             | 
| 1136 | 
            +
                    else:
         | 
| 1137 | 
            +
                        raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.")
         | 
| 1138 | 
            +
             | 
| 1139 | 
            +
                    # set correct dtype & device
         | 
| 1140 | 
            +
                    attn_processors = {
         | 
| 1141 | 
            +
                        k: v.to(device=self.device, dtype=self.text_encoder.dtype) for k, v in attn_processors.items()
         | 
| 1142 | 
            +
                    }
         | 
| 1143 | 
            +
                    return attn_processors
         | 
| 1144 | 
            +
             | 
| 1145 | 
            +
                @classmethod
         | 
| 1146 | 
            +
                def save_lora_weights(
         | 
| 1147 | 
            +
                    self,
         | 
| 1148 | 
            +
                    save_directory: Union[str, os.PathLike],
         | 
| 1149 | 
            +
                    unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
         | 
| 1150 | 
            +
                    text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
         | 
| 1151 | 
            +
                    is_main_process: bool = True,
         | 
| 1152 | 
            +
                    weight_name: str = None,
         | 
| 1153 | 
            +
                    save_function: Callable = None,
         | 
| 1154 | 
            +
                    safe_serialization: bool = False,
         | 
| 1155 | 
            +
                ):
         | 
| 1156 | 
            +
                    r"""
         | 
| 1157 | 
            +
                    Save the LoRA parameters corresponding to the UNet and text encoder.
         | 
| 1158 | 
            +
             | 
| 1159 | 
            +
                    Arguments:
         | 
| 1160 | 
            +
                        save_directory (`str` or `os.PathLike`):
         | 
| 1161 | 
            +
                            Directory to save LoRA parameters to. Will be created if it doesn't exist.
         | 
| 1162 | 
            +
                        unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
         | 
| 1163 | 
            +
                            State dict of the LoRA layers corresponding to the UNet.
         | 
| 1164 | 
            +
                    text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
         | 
| 1165 | 
            +
                            State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
         | 
| 1166 | 
            +
                        encoder LoRA state dict because it comes from 🤗 Transformers.
         | 
| 1167 | 
            +
                        is_main_process (`bool`, *optional*, defaults to `True`):
         | 
| 1168 | 
            +
                        Whether the process calling this is the main process or not. Useful during distributed training when you
         | 
| 1169 | 
            +
                            need to call this function on all processes. In this case, set `is_main_process=True` only on the main
         | 
| 1170 | 
            +
                            process to avoid race conditions.
         | 
| 1171 | 
            +
                        save_function (`Callable`):
         | 
| 1172 | 
            +
                            The function to use to save the state dictionary. Useful during distributed training when you need to
         | 
| 1173 | 
            +
                            replace `torch.save` with another method. Can be configured with the environment variable
         | 
| 1174 | 
            +
                            `DIFFUSERS_SAVE_MODE`.
         | 
| 1175 | 
            +
                    """
         | 
| 1176 | 
            +
                    if os.path.isfile(save_directory):
         | 
| 1177 | 
            +
                        logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
         | 
| 1178 | 
            +
                        return
         | 
| 1179 | 
            +
             | 
| 1180 | 
            +
                    if save_function is None:
         | 
| 1181 | 
            +
                        if safe_serialization:
         | 
| 1182 | 
            +
             | 
| 1183 | 
            +
                            def save_function(weights, filename):
         | 
| 1184 | 
            +
                                return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
         | 
| 1185 | 
            +
             | 
| 1186 | 
            +
                        else:
         | 
| 1187 | 
            +
                            save_function = torch.save
         | 
| 1188 | 
            +
             | 
| 1189 | 
            +
                    os.makedirs(save_directory, exist_ok=True)
         | 
| 1190 | 
            +
             | 
| 1191 | 
            +
                    # Create a flat dictionary.
         | 
| 1192 | 
            +
                    state_dict = {}
         | 
| 1193 | 
            +
                    if unet_lora_layers is not None:
         | 
| 1194 | 
            +
                        weights = (
         | 
| 1195 | 
            +
                            unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers
         | 
| 1196 | 
            +
                        )
         | 
| 1197 | 
            +
             | 
| 1198 | 
            +
                        unet_lora_state_dict = {f"{self.unet_name}.{module_name}": param for module_name, param in weights.items()}
         | 
| 1199 | 
            +
                        state_dict.update(unet_lora_state_dict)
         | 
| 1200 | 
            +
             | 
| 1201 | 
            +
                    if text_encoder_lora_layers is not None:
         | 
| 1202 | 
            +
                        weights = (
         | 
| 1203 | 
            +
                            text_encoder_lora_layers.state_dict()
         | 
| 1204 | 
            +
                            if isinstance(text_encoder_lora_layers, torch.nn.Module)
         | 
| 1205 | 
            +
                            else text_encoder_lora_layers
         | 
| 1206 | 
            +
                        )
         | 
| 1207 | 
            +
             | 
| 1208 | 
            +
                        text_encoder_lora_state_dict = {
         | 
| 1209 | 
            +
                            f"{self.text_encoder_name}.{module_name}": param for module_name, param in weights.items()
         | 
| 1210 | 
            +
                        }
         | 
| 1211 | 
            +
                        state_dict.update(text_encoder_lora_state_dict)
         | 
| 1212 | 
            +
             | 
| 1213 | 
            +
                    # Save the model
         | 
| 1214 | 
            +
                    if weight_name is None:
         | 
| 1215 | 
            +
                        if safe_serialization:
         | 
| 1216 | 
            +
                            weight_name = LORA_WEIGHT_NAME_SAFE
         | 
| 1217 | 
            +
                        else:
         | 
| 1218 | 
            +
                            weight_name = LORA_WEIGHT_NAME
         | 
| 1219 | 
            +
             | 
| 1220 | 
            +
                    save_function(state_dict, os.path.join(save_directory, weight_name))
         | 
| 1221 | 
            +
                    logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
         | 
| 1222 | 
            +
             | 
| 1223 | 
            +
                def _convert_kohya_lora_to_diffusers(self, state_dict):
         | 
| 1224 | 
            +
                    unet_state_dict = {}
         | 
| 1225 | 
            +
                    te_state_dict = {}
         | 
| 1226 | 
            +
                    network_alpha = None
         | 
| 1227 | 
            +
             | 
| 1228 | 
            +
                    for key, value in state_dict.items():
         | 
| 1229 | 
            +
                        if "lora_down" in key:
         | 
| 1230 | 
            +
                            lora_name = key.split(".")[0]
         | 
| 1231 | 
            +
                            lora_name_up = lora_name + ".lora_up.weight"
         | 
| 1232 | 
            +
                            lora_name_alpha = lora_name + ".alpha"
         | 
| 1233 | 
            +
                            if lora_name_alpha in state_dict:
         | 
| 1234 | 
            +
                                alpha = state_dict[lora_name_alpha].item()
         | 
| 1235 | 
            +
                                if network_alpha is None:
         | 
| 1236 | 
            +
                                    network_alpha = alpha
         | 
| 1237 | 
            +
                                elif network_alpha != alpha:
         | 
| 1238 | 
            +
                                    raise ValueError("Network alpha is not consistent")
         | 
| 1239 | 
            +
             | 
| 1240 | 
            +
                            if lora_name.startswith("lora_unet_"):
         | 
| 1241 | 
            +
                                diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
         | 
| 1242 | 
            +
                                diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
         | 
| 1243 | 
            +
                                diffusers_name = diffusers_name.replace("mid.block", "mid_block")
         | 
| 1244 | 
            +
                                diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
         | 
| 1245 | 
            +
                                diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
         | 
| 1246 | 
            +
                                diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
         | 
| 1247 | 
            +
                                diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
         | 
| 1248 | 
            +
                                diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
         | 
| 1249 | 
            +
                                diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
         | 
| 1250 | 
            +
                                if "transformer_blocks" in diffusers_name:
         | 
| 1251 | 
            +
                                    if "attn1" in diffusers_name or "attn2" in diffusers_name:
         | 
| 1252 | 
            +
                                        diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
         | 
| 1253 | 
            +
                                        diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
         | 
| 1254 | 
            +
                                        unet_state_dict[diffusers_name] = value
         | 
| 1255 | 
            +
                                        unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
         | 
| 1256 | 
            +
                            elif lora_name.startswith("lora_te_"):
         | 
| 1257 | 
            +
                                diffusers_name = key.replace("lora_te_", "").replace("_", ".")
         | 
| 1258 | 
            +
                                diffusers_name = diffusers_name.replace("text.model", "text_model")
         | 
| 1259 | 
            +
                                diffusers_name = diffusers_name.replace("self.attn", "self_attn")
         | 
| 1260 | 
            +
                                diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
         | 
| 1261 | 
            +
                                diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
         | 
| 1262 | 
            +
                                diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
         | 
| 1263 | 
            +
                                diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
         | 
| 1264 | 
            +
                                if "self_attn" in diffusers_name:
         | 
| 1265 | 
            +
                                    te_state_dict[diffusers_name] = value
         | 
| 1266 | 
            +
                                    te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
         | 
| 1267 | 
            +
             | 
| 1268 | 
            +
                    unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()}
         | 
| 1269 | 
            +
                    te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()}
         | 
| 1270 | 
            +
                    new_state_dict = {**unet_state_dict, **te_state_dict}
         | 
| 1271 | 
            +
                    return new_state_dict, network_alpha
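                # Illustrative trace of the renaming above on a hypothetical kohya-ss key
                # (not part of the original diff):
                #   "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight"
                #   -> strip "lora_unet_", map "_" to "." (this also hits "to_q" and "lora_down")
                #   -> undo the over-splitting: "down.blocks" -> "down_blocks", "transformer.blocks" ->
                #      "transformer_blocks", "to.q.lora" -> "to_q_lora", "attn1" -> "attn1.processor"
                #   = "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.down.weight"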
         | 
| 1272 | 
            +
             | 
| 1273 | 
            +
             | 
| 1274 | 
            +
            class FromCkptMixin:
         | 
| 1275 | 
            +
                """
         | 
| 1276 | 
            +
                Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`].
         | 
| 1277 | 
            +
                """
         | 
| 1278 | 
            +
             | 
| 1279 | 
            +
                @classmethod
         | 
| 1280 | 
            +
                def from_ckpt(cls, pretrained_model_link_or_path, **kwargs):
         | 
| 1281 | 
            +
                    r"""
         | 
| 1282 | 
            +
                    Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` format. The pipeline
         | 
| 1283 | 
            +
                    is set in evaluation mode (`model.eval()`) by default.
         | 
| 1284 | 
            +
             | 
| 1285 | 
            +
                    Parameters:
         | 
| 1286 | 
            +
                        pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
         | 
| 1287 | 
            +
                            Can be either:
         | 
| 1288 | 
            +
                                - A link to the `.ckpt` file (for example
         | 
| 1289 | 
            +
                                  `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
         | 
| 1290 | 
            +
                                - A path to a *file* containing all pipeline weights.
         | 
| 1291 | 
            +
                        torch_dtype (`str` or `torch.dtype`, *optional*):
         | 
| 1292 | 
            +
                            Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
         | 
| 1293 | 
            +
                            dtype is automatically derived from the model's weights.
         | 
| 1294 | 
            +
                        force_download (`bool`, *optional*, defaults to `False`):
         | 
| 1295 | 
            +
                            Whether or not to force the (re-)download of the model weights and configuration files, overriding the
         | 
| 1296 | 
            +
                            cached versions if they exist.
         | 
| 1297 | 
            +
                        cache_dir (`Union[str, os.PathLike]`, *optional*):
         | 
| 1298 | 
            +
                            Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
         | 
| 1299 | 
            +
                            is not used.
         | 
| 1300 | 
            +
                        resume_download (`bool`, *optional*, defaults to `False`):
         | 
| 1301 | 
            +
                            Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
         | 
| 1302 | 
            +
                            incompletely downloaded files are deleted.
         | 
| 1303 | 
            +
                        proxies (`Dict[str, str]`, *optional*):
         | 
| 1304 | 
            +
                            A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
         | 
| 1305 | 
            +
                            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
         | 
| 1306 | 
            +
                        local_files_only (`bool`, *optional*, defaults to `False`):
         | 
| 1307 | 
            +
                            Whether to only load local model weights and configuration files or not. If set to True, the model
         | 
| 1308 | 
            +
                            won't be downloaded from the Hub.
         | 
| 1309 | 
            +
                        use_auth_token (`str` or *bool*, *optional*):
         | 
| 1310 | 
            +
                            The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
         | 
| 1311 | 
            +
                            `diffusers-cli login` (stored in `~/.huggingface`) is used.
         | 
| 1312 | 
            +
                        revision (`str`, *optional*, defaults to `"main"`):
         | 
| 1313 | 
            +
                            The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
         | 
| 1314 | 
            +
                            allowed by Git.
         | 
| 1315 | 
            +
                        use_safetensors (`bool`, *optional*, defaults to `None`):
         | 
| 1316 | 
            +
                            If set to `None`, the safetensors weights are downloaded if they're available **and** if the
         | 
| 1317 | 
            +
                            safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
         | 
| 1318 | 
            +
                            weights. If set to `False`, safetensors weights are not loaded.
         | 
| 1319 | 
            +
                        extract_ema (`bool`, *optional*, defaults to `False`):
         | 
| 1320 | 
            +
                            Whether to extract the EMA weights or not. Pass `True` to extract the EMA weights which usually yield
         | 
| 1321 | 
            +
                            higher quality images for inference. Non-EMA weights are usually better to continue finetuning.
         | 
| 1322 | 
            +
                        upcast_attention (`bool`, *optional*, defaults to `None`):
         | 
| 1323 | 
            +
                            Whether the attention computation should always be upcasted.
         | 
| 1324 | 
            +
                        image_size (`int`, *optional*, defaults to 512):
         | 
| 1325 | 
            +
                            The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
         | 
| 1326 | 
            +
                            Diffusion v2 base model. Use 768 for Stable Diffusion v2.
         | 
| 1327 | 
            +
                        prediction_type (`str`, *optional*):
         | 
| 1328 | 
            +
                            The prediction type the model was trained on. Use `'epsilon'` for all Stable Diffusion v1 models and
         | 
| 1329 | 
            +
                            the Stable Diffusion v2 base model. Use `'v_prediction'` for Stable Diffusion v2.
         | 
| 1330 | 
            +
                        num_in_channels (`int`, *optional*, defaults to `None`):
         | 
| 1331 | 
            +
                            The number of input channels. If `None`, it will be automatically inferred.
         | 
| 1332 | 
            +
                        scheduler_type (`str`, *optional*, defaults to `"pndm"`):
         | 
| 1333 | 
            +
                            Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
         | 
| 1334 | 
            +
                            "ddim"]`.
         | 
| 1335 | 
            +
                        load_safety_checker (`bool`, *optional*, defaults to `True`):
         | 
| 1336 | 
            +
                            Whether to load the safety checker or not.
         | 
| 1337 | 
            +
                        text_encoder (`CLIPTextModel`, *optional*, defaults to `None`):
         | 
| 1338 | 
            +
                            An instance of
         | 
| 1339 | 
            +
                            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) to use,
         | 
| 1340 | 
            +
                            specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)
         | 
| 1341 | 
            +
                            variant. If this parameter is `None`, the function will load a new instance of [CLIP] by itself, if
         | 
| 1342 | 
            +
                            needed.
         | 
| 1343 | 
            +
                        tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`):
         | 
| 1344 | 
            +
                            An instance of
         | 
| 1345 | 
            +
                            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer)
         | 
| 1346 | 
            +
                            to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by
         | 
| 1347 | 
            +
                            itself, if needed.
         | 
| 1348 | 
            +
                        kwargs (remaining dictionary of keyword arguments, *optional*):
         | 
| 1349 | 
            +
                            Can be used to overwrite load and saveable variables (for example the pipeline components of the
         | 
| 1350 | 
            +
                        specific pipeline class). The overwritten components are directly passed to the pipeline's `__init__`
         | 
| 1351 | 
            +
                            method. See example below for more information.
         | 
| 1352 | 
            +
             | 
| 1353 | 
            +
                    Examples:
         | 
| 1354 | 
            +
             | 
| 1355 | 
            +
                    ```py
         | 
| 1356 | 
            +
                    >>> from diffusers import StableDiffusionPipeline
         | 
| 1357 | 
            +
             | 
| 1358 | 
            +
                    >>> # Download pipeline from huggingface.co and cache.
         | 
| 1359 | 
            +
                    >>> pipeline = StableDiffusionPipeline.from_ckpt(
         | 
| 1360 | 
            +
                    ...     "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
         | 
| 1361 | 
            +
                    ... )
         | 
| 1362 | 
            +
             | 
| 1363 | 
            +
                >>> # Load the pipeline from a local file
         | 
| 1364 | 
            +
                    >>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt
         | 
| 1365 | 
            +
                    >>> pipeline = StableDiffusionPipeline.from_ckpt("./v1-5-pruned-emaonly")
         | 
| 1366 | 
            +
             | 
| 1367 | 
            +
                    >>> # Enable float16 and move to GPU
         | 
| 1368 | 
            +
                    >>> pipeline = StableDiffusionPipeline.from_ckpt(
         | 
| 1369 | 
            +
                    ...     "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
         | 
| 1370 | 
            +
                    ...     torch_dtype=torch.float16,
         | 
| 1371 | 
            +
                    ... )
         | 
| 1372 | 
            +
                    >>> pipeline.to("cuda")
         | 
| 1373 | 
            +
                    ```
         | 
| 1374 | 
            +
                    """
         | 
| 1375 | 
            +
                    # import here to avoid circular dependency
         | 
| 1376 | 
            +
                    from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
         | 
| 1377 | 
            +
             | 
| 1378 | 
            +
                    cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
         | 
| 1379 | 
            +
                    resume_download = kwargs.pop("resume_download", False)
         | 
| 1380 | 
            +
                    force_download = kwargs.pop("force_download", False)
         | 
| 1381 | 
            +
                    proxies = kwargs.pop("proxies", None)
         | 
| 1382 | 
            +
                    local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
         | 
| 1383 | 
            +
                    use_auth_token = kwargs.pop("use_auth_token", None)
         | 
| 1384 | 
            +
                    revision = kwargs.pop("revision", None)
         | 
| 1385 | 
            +
                    extract_ema = kwargs.pop("extract_ema", False)
         | 
| 1386 | 
            +
                    image_size = kwargs.pop("image_size", 512)
         | 
| 1387 | 
            +
                    scheduler_type = kwargs.pop("scheduler_type", "pndm")
         | 
| 1388 | 
            +
                    num_in_channels = kwargs.pop("num_in_channels", None)
         | 
| 1389 | 
            +
                    upcast_attention = kwargs.pop("upcast_attention", None)
         | 
| 1390 | 
            +
                    load_safety_checker = kwargs.pop("load_safety_checker", True)
         | 
| 1391 | 
            +
                    prediction_type = kwargs.pop("prediction_type", None)
         | 
| 1392 | 
            +
                    text_encoder = kwargs.pop("text_encoder", None)
         | 
| 1393 | 
            +
                    tokenizer = kwargs.pop("tokenizer", None)
         | 
| 1394 | 
            +
             | 
| 1395 | 
            +
                    torch_dtype = kwargs.pop("torch_dtype", None)
         | 
| 1396 | 
            +
             | 
| 1397 | 
            +
                    use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
         | 
| 1398 | 
            +
             | 
| 1399 | 
            +
                    pipeline_name = cls.__name__
         | 
| 1400 | 
            +
                    file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
         | 
| 1401 | 
            +
                    from_safetensors = file_extension == "safetensors"
         | 
| 1402 | 
            +
             | 
| 1403 | 
            +
                    if from_safetensors and use_safetensors is False:
         | 
| 1404 | 
            +
                        raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
         | 
| 1405 | 
            +
             | 
| 1406 | 
            +
                    # TODO: For now we only support stable diffusion
         | 
| 1407 | 
            +
                    stable_unclip = None
         | 
| 1408 | 
            +
                    model_type = None
         | 
| 1409 | 
            +
                    controlnet = False
         | 
| 1410 | 
            +
             | 
| 1411 | 
            +
                    if pipeline_name == "StableDiffusionControlNetPipeline":
         | 
| 1412 | 
            +
                        # Model type will be inferred from the checkpoint.
         | 
| 1413 | 
            +
                        controlnet = True
         | 
| 1414 | 
            +
                    elif "StableDiffusion" in pipeline_name:
         | 
| 1415 | 
            +
                        # Model type will be inferred from the checkpoint.
         | 
| 1416 | 
            +
                        pass
         | 
| 1417 | 
            +
                    elif pipeline_name == "StableUnCLIPPipeline":
         | 
| 1418 | 
            +
                        model_type = "FrozenOpenCLIPEmbedder"
         | 
| 1419 | 
            +
                        stable_unclip = "txt2img"
         | 
| 1420 | 
            +
                    elif pipeline_name == "StableUnCLIPImg2ImgPipeline":
         | 
| 1421 | 
            +
                        model_type = "FrozenOpenCLIPEmbedder"
         | 
| 1422 | 
            +
                        stable_unclip = "img2img"
         | 
| 1423 | 
            +
                    elif pipeline_name == "PaintByExamplePipeline":
         | 
| 1424 | 
            +
                        model_type = "PaintByExample"
         | 
| 1425 | 
            +
                    elif pipeline_name == "LDMTextToImagePipeline":
         | 
| 1426 | 
            +
                        model_type = "LDMTextToImage"
         | 
| 1427 | 
            +
                    else:
         | 
| 1428 | 
            +
                        raise ValueError(f"Unhandled pipeline class: {pipeline_name}")
         | 
| 1429 | 
            +
             | 
| 1430 | 
            +
                    # remove huggingface url
         | 
| 1431 | 
            +
                    for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
         | 
| 1432 | 
            +
                        if pretrained_model_link_or_path.startswith(prefix):
         | 
| 1433 | 
            +
                            pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
         | 
| 1434 | 
            +
             | 
| 1435 | 
            +
                    # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
         | 
| 1436 | 
            +
                    ckpt_path = Path(pretrained_model_link_or_path)
         | 
| 1437 | 
            +
                    if not ckpt_path.is_file():
         | 
| 1438 | 
            +
                        # get repo_id and (potentially nested) file path of ckpt in repo
         | 
| 1439 | 
            +
                        repo_id = "/".join(ckpt_path.parts[:2])
         | 
| 1440 | 
            +
                        file_path = "/".join(ckpt_path.parts[2:])
         | 
| 1441 | 
            +
             | 
| 1442 | 
            +
                        if file_path.startswith("blob/"):
         | 
| 1443 | 
            +
                            file_path = file_path[len("blob/") :]
         | 
| 1444 | 
            +
             | 
| 1445 | 
            +
                        if file_path.startswith("main/"):
         | 
| 1446 | 
            +
                            file_path = file_path[len("main/") :]
         | 
| 1447 | 
            +
             | 
| 1448 | 
            +
                        pretrained_model_link_or_path = hf_hub_download(
         | 
| 1449 | 
            +
                            repo_id,
         | 
| 1450 | 
            +
                            filename=file_path,
         | 
| 1451 | 
            +
                            cache_dir=cache_dir,
         | 
| 1452 | 
            +
                            resume_download=resume_download,
         | 
| 1453 | 
            +
                            proxies=proxies,
         | 
| 1454 | 
            +
                            local_files_only=local_files_only,
         | 
| 1455 | 
            +
                            use_auth_token=use_auth_token,
         | 
| 1456 | 
            +
                            revision=revision,
         | 
| 1457 | 
            +
                            force_download=force_download,
         | 
| 1458 | 
            +
                        )
         | 
| 1459 | 
            +
             | 
| 1460 | 
            +
                    pipe = download_from_original_stable_diffusion_ckpt(
         | 
| 1461 | 
            +
                        pretrained_model_link_or_path,
         | 
| 1462 | 
            +
                        pipeline_class=cls,
         | 
| 1463 | 
            +
                        model_type=model_type,
         | 
| 1464 | 
            +
                        stable_unclip=stable_unclip,
         | 
| 1465 | 
            +
                        controlnet=controlnet,
         | 
| 1466 | 
            +
                        from_safetensors=from_safetensors,
         | 
| 1467 | 
            +
                        extract_ema=extract_ema,
         | 
| 1468 | 
            +
                        image_size=image_size,
         | 
| 1469 | 
            +
                        scheduler_type=scheduler_type,
         | 
| 1470 | 
            +
                        num_in_channels=num_in_channels,
         | 
| 1471 | 
            +
                        upcast_attention=upcast_attention,
         | 
| 1472 | 
            +
                        load_safety_checker=load_safety_checker,
         | 
| 1473 | 
            +
                        prediction_type=prediction_type,
         | 
| 1474 | 
            +
                        text_encoder=text_encoder,
         | 
| 1475 | 
            +
                        tokenizer=tokenizer,
         | 
| 1476 | 
            +
                    )
         | 
| 1477 | 
            +
             | 
| 1478 | 
            +
                    if torch_dtype is not None:
         | 
| 1479 | 
            +
                        pipe.to(torch_dtype=torch_dtype)
         | 
| 1480 | 
            +
             | 
| 1481 | 
            +
                    return pipe
         | 
    	
        diffusers/models/modeling_utils.py
    ADDED
    
    | @@ -0,0 +1,978 @@ | |
| 1 | 
            +
            # coding=utf-8
         | 
| 2 | 
            +
            # Copyright 2023 The HuggingFace Inc. team.
         | 
| 3 | 
            +
            # Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
         | 
| 4 | 
            +
            #
         | 
| 5 | 
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 6 | 
            +
            # you may not use this file except in compliance with the License.
         | 
| 7 | 
            +
            # You may obtain a copy of the License at
         | 
| 8 | 
            +
            #
         | 
| 9 | 
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 10 | 
            +
            #
         | 
| 11 | 
            +
            # Unless required by applicable law or agreed to in writing, software
         | 
| 12 | 
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 13 | 
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 14 | 
            +
            # See the License for the specific language governing permissions and
         | 
| 15 | 
            +
            # limitations under the License.
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            import inspect
         | 
| 18 | 
            +
            import itertools
         | 
| 19 | 
            +
            import os
         | 
| 20 | 
            +
            import re
         | 
| 21 | 
            +
            from functools import partial
         | 
| 22 | 
            +
            from typing import Any, Callable, List, Optional, Tuple, Union
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            import torch
         | 
| 25 | 
            +
            from torch import Tensor, device, nn
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            from ..utils.constants import (
         | 
| 28 | 
            +
                CONFIG_NAME,
         | 
| 29 | 
            +
                DIFFUSERS_CACHE,
         | 
| 30 | 
            +
                FLAX_WEIGHTS_NAME,
         | 
| 31 | 
            +
                SAFETENSORS_WEIGHTS_NAME,
         | 
| 32 | 
            +
                WEIGHTS_NAME
         | 
| 33 | 
            +
            )
         | 
| 34 | 
            +
            from ..utils.hub_utils import (
         | 
| 35 | 
            +
                HF_HUB_OFFLINE,
         | 
| 36 | 
            +
                _add_variant,
         | 
| 37 | 
            +
                _get_model_file
         | 
| 38 | 
            +
            )
         | 
| 39 | 
            +
            from ..utils.deprecation_utils import deprecate
         | 
| 40 | 
            +
            from ..utils.import_utils import (
         | 
| 41 | 
            +
                is_accelerate_available,
         | 
| 42 | 
            +
                is_safetensors_available,
         | 
| 43 | 
            +
                is_torch_version
         | 
| 44 | 
            +
            )
         | 
| 45 | 
            +
            from ..utils.logging import get_logger
         | 
| 46 | 
            +
             | 
| 47 | 
            +
            logger = get_logger(__name__)
         | 
| 48 | 
            +
             | 
| 49 | 
            +
             | 
| 50 | 
            +
            if is_torch_version(">=", "1.9.0"):
         | 
| 51 | 
            +
                _LOW_CPU_MEM_USAGE_DEFAULT = True
         | 
| 52 | 
            +
            else:
         | 
| 53 | 
            +
                _LOW_CPU_MEM_USAGE_DEFAULT = False
         | 
| 54 | 
            +
             | 
| 55 | 
            +
             | 
| 56 | 
            +
            if is_accelerate_available():
         | 
| 57 | 
            +
                import accelerate
         | 
| 58 | 
            +
                from accelerate.utils import set_module_tensor_to_device
         | 
| 59 | 
            +
                from accelerate.utils.versions import is_torch_version
         | 
| 60 | 
            +
             | 
| 61 | 
            +
            if is_safetensors_available():
         | 
| 62 | 
            +
                import safetensors
         | 
| 63 | 
            +
             | 
| 64 | 
            +
             | 
| 65 | 
            +
            def get_parameter_device(parameter: torch.nn.Module):
         | 
| 66 | 
            +
                try:
         | 
| 67 | 
            +
                    parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
         | 
| 68 | 
            +
                    return next(parameters_and_buffers).device
         | 
| 69 | 
            +
                except StopIteration:
         | 
| 70 | 
            +
                    # For torch.nn.DataParallel compatibility in PyTorch 1.5
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                    def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
         | 
| 73 | 
            +
                        tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
         | 
| 74 | 
            +
                        return tuples
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                    gen = parameter._named_members(get_members_fn=find_tensor_attributes)
         | 
| 77 | 
            +
                    first_tuple = next(gen)
         | 
| 78 | 
            +
                    return first_tuple[1].device
         | 
| 79 | 
            +
             | 
| 80 | 
            +
             | 
| 81 | 
            +
            def get_parameter_dtype(parameter: torch.nn.Module):
         | 
| 82 | 
            +
                try:
         | 
| 83 | 
            +
                    params = tuple(parameter.parameters())
         | 
| 84 | 
            +
                    if len(params) > 0:
         | 
| 85 | 
            +
                        return params[0].dtype
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                    buffers = tuple(parameter.buffers())
         | 
| 88 | 
            +
                    if len(buffers) > 0:
         | 
| 89 | 
            +
                        return buffers[0].dtype
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                except StopIteration:
         | 
| 92 | 
            +
                    # For torch.nn.DataParallel compatibility in PyTorch 1.5
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                    def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
         | 
| 95 | 
            +
                        tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
         | 
| 96 | 
            +
                        return tuples
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                    gen = parameter._named_members(get_members_fn=find_tensor_attributes)
         | 
| 99 | 
            +
                    first_tuple = next(gen)
         | 
| 100 | 
            +
                    return first_tuple[1].dtype
         | 
| 101 | 
            +
             | 
| 102 | 
            +
             | 
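# --- Editor's sketch (not part of the original commit): how the two helpers above are used. ---
# get_parameter_device / get_parameter_dtype walk a module's parameters (falling back to its
# buffers, and finally to raw tensor attributes for the torch.nn.DataParallel case) to report
# where and in which precision the weights live. A minimal illustration with a hypothetical
# helper name; torch is already imported at the top of this file:
def _demo_parameter_introspection():  # hypothetical, for illustration only
    layer = torch.nn.Linear(4, 4).to(dtype=torch.float16)
    # Expect something like (device(type='cpu'), torch.float16) on a CPU-only machine.
    return get_parameter_device(layer), get_parameter_dtype(layer)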
| 103 | 
            +
            def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
         | 
| 104 | 
            +
                """
         | 
| 105 | 
            +
                Reads a checkpoint file, returning properly formatted errors if they arise.
         | 
| 106 | 
            +
                """
         | 
| 107 | 
            +
                try:
         | 
| 108 | 
            +
                    if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant):
         | 
| 109 | 
            +
                        return torch.load(checkpoint_file, map_location="cpu")
         | 
| 110 | 
            +
                    else:
         | 
| 111 | 
            +
                        return safetensors.torch.load_file(checkpoint_file, device="cpu")
         | 
| 112 | 
            +
                except Exception as e:
         | 
| 113 | 
            +
                    try:
         | 
| 114 | 
            +
                        with open(checkpoint_file) as f:
         | 
| 115 | 
            +
                            if f.read().startswith("version"):
         | 
| 116 | 
            +
                                raise OSError(
         | 
| 117 | 
            +
                                    "You seem to have cloned a repository without having git-lfs installed. Please install "
         | 
| 118 | 
            +
                                    "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
         | 
| 119 | 
            +
                                    "you cloned."
         | 
| 120 | 
            +
                                )
         | 
| 121 | 
            +
                            else:
         | 
| 122 | 
            +
                                raise ValueError(
         | 
| 123 | 
            +
                                    f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
         | 
| 124 | 
            +
                                    "model. Make sure you have saved the model properly."
         | 
| 125 | 
            +
                                ) from e
         | 
| 126 | 
            +
                    except (UnicodeDecodeError, ValueError):
         | 
| 127 | 
            +
                        raise OSError(
         | 
| 128 | 
            +
                            f"Unable to load weights from checkpoint file for '{checkpoint_file}' "
         | 
| 129 | 
            +
                            f"at '{checkpoint_file}'. "
         | 
| 130 | 
            +
                            "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
         | 
| 131 | 
            +
                        )
         | 
| 132 | 
            +
             | 
| 133 | 
            +
             | 
| 134 | 
            +
            def _load_state_dict_into_model(model_to_load, state_dict):
         | 
| 135 | 
            +
                # Convert old format to new format if needed from a PyTorch state_dict
         | 
| 136 | 
            +
                # copy state_dict so _load_from_state_dict can modify it
         | 
| 137 | 
            +
                state_dict = state_dict.copy()
         | 
| 138 | 
            +
                error_msgs = []
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
         | 
| 141 | 
            +
                # so we need to apply the function recursively.
         | 
| 142 | 
            +
                def load(module: torch.nn.Module, prefix=""):
         | 
| 143 | 
            +
                    args = (state_dict, prefix, {}, True, [], [], error_msgs)
         | 
| 144 | 
            +
                    module._load_from_state_dict(*args)
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                    for name, child in module._modules.items():
         | 
| 147 | 
            +
                        if child is not None:
         | 
| 148 | 
            +
                            load(child, prefix + name + ".")
         | 
| 149 | 
            +
             | 
| 150 | 
            +
                load(model_to_load)
         | 
| 151 | 
            +
             | 
| 152 | 
            +
                return error_msgs
         | 
| 153 | 
            +
             | 
| 154 | 
            +
             | 
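# --- Editor's sketch (not part of the original commit): how checkpoints are read. ---
# load_state_dict dispatches on the file name: a file named exactly like the standard PyTorch
# weights file (WEIGHTS_NAME, optionally with a variant suffix) goes through torch.load, while
# any other file is assumed to be a safetensors archive and is read with
# safetensors.torch.load_file (so safetensors must be installed for that branch). A hedged
# illustration with a made-up directory and a hypothetical helper name:
def _demo_load_checkpoint(ckpt_dir: str):  # hypothetical, for illustration only
    bin_state = load_state_dict(os.path.join(ckpt_dir, WEIGHTS_NAME))               # torch.load path
    safe_state = load_state_dict(os.path.join(ckpt_dir, SAFETENSORS_WEIGHTS_NAME))  # safetensors path
    return bin_state, safe_state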
| 155 | 
            +
            class ModelMixin(torch.nn.Module):
         | 
| 156 | 
            +
                r"""
         | 
| 157 | 
            +
                Base class for all models.
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
         | 
| 160 | 
            +
                and saving models.
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                    - **config_name** ([`str`]) -- A filename under which the model should be stored when calling
         | 
| 163 | 
            +
                      [`~models.ModelMixin.save_pretrained`].
         | 
| 164 | 
            +
                """
         | 
| 165 | 
            +
                config_name = CONFIG_NAME
         | 
| 166 | 
            +
                _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
         | 
| 167 | 
            +
                _supports_gradient_checkpointing = False
         | 
| 168 | 
            +
                _keys_to_ignore_on_load_unexpected = None
         | 
| 169 | 
            +
             | 
| 170 | 
            +
                def __init__(self):
         | 
| 171 | 
            +
                    super().__init__()
         | 
| 172 | 
            +
             | 
| 173 | 
            +
                def __getattr__(self, name: str) -> Any:
         | 
| 174 | 
            +
                    """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
         | 
| 175 | 
            +
                config attributes directly. See https://github.com/huggingface/diffusers/pull/3129. We need to overwrite
         | 
| 176 | 
            +
                `__getattr__` here as well so that we don't trigger `torch.nn.Module`'s `__getattr__`:
         | 
| 177 | 
            +
                    https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
         | 
| 178 | 
            +
                    """
         | 
| 179 | 
            +
             | 
| 180 | 
            +
                    is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
         | 
| 181 | 
            +
                    is_attribute = name in self.__dict__
         | 
| 182 | 
            +
             | 
| 183 | 
            +
                    if is_in_config and not is_attribute:
         | 
| 184 | 
            +
                        deprecation_message = (
         | 
| 185 | 
            +
                            f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. "
         | 
| 186 | 
            +
                            f"Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'."
         | 
| 187 | 
            +
                        )
         | 
| 188 | 
            +
                        deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3)
         | 
| 189 | 
            +
                        return self._internal_dict[name]
         | 
| 190 | 
            +
             | 
| 191 | 
            +
                    # call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
         | 
| 192 | 
            +
                    return super().__getattr__(name)
         | 
| 193 | 
            +
             | 
| 194 | 
            +
                @property
         | 
| 195 | 
            +
                def is_gradient_checkpointing(self) -> bool:
         | 
| 196 | 
            +
                    """
         | 
| 197 | 
            +
                    Whether gradient checkpointing is activated for this model or not.
         | 
| 198 | 
            +
             | 
| 199 | 
            +
                    Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
         | 
| 200 | 
            +
                    activations".
         | 
| 201 | 
            +
                    """
         | 
| 202 | 
            +
                    return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
         | 
| 203 | 
            +
             | 
| 204 | 
            +
                def enable_gradient_checkpointing(self):
         | 
| 205 | 
            +
                    """
         | 
| 206 | 
            +
                    Activates gradient checkpointing for the current model.
         | 
| 207 | 
            +
             | 
| 208 | 
            +
                    Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
         | 
| 209 | 
            +
                    activations".
         | 
| 210 | 
            +
                    """
         | 
| 211 | 
            +
                    if not self._supports_gradient_checkpointing:
         | 
| 212 | 
            +
                        raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
         | 
| 213 | 
            +
                    self.apply(partial(self._set_gradient_checkpointing, value=True))
         | 
| 214 | 
            +
             | 
| 215 | 
            +
                def disable_gradient_checkpointing(self):
         | 
| 216 | 
            +
                    """
         | 
| 217 | 
            +
                    Deactivates gradient checkpointing for the current model.
         | 
| 218 | 
            +
             | 
| 219 | 
            +
                    Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
         | 
| 220 | 
            +
                    activations".
         | 
| 221 | 
            +
                    """
         | 
| 222 | 
            +
                    if self._supports_gradient_checkpointing:
         | 
| 223 | 
            +
                        self.apply(partial(self._set_gradient_checkpointing, value=False))
         | 
| 224 | 
            +
             | 
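    # --- Editor's note (not part of the original commit): a hedged usage sketch. ---
    # Gradient checkpointing is toggled by applying `_set_gradient_checkpointing` to every
    # submodule, so it only takes effect on subclasses that define that hook and set
    # `_supports_gradient_checkpointing = True`. Typical usage, assuming `unet` is such a model:
    #
    #     unet.enable_gradient_checkpointing()
    #     assert unet.is_gradient_checkpointing
    #     unet.disable_gradient_checkpointing()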
| 225 | 
            +
                def set_use_memory_efficient_attention_xformers(
         | 
| 226 | 
            +
                    self, valid: bool, attention_op: Optional[Callable] = None
         | 
| 227 | 
            +
                ) -> None:
         | 
| 228 | 
            +
                    # Recursively walk through all the children.
         | 
| 229 | 
            +
                    # Any children which exposes the set_use_memory_efficient_attention_xformers method
         | 
| 230 | 
            +
                    # gets the message
         | 
| 231 | 
            +
                    def fn_recursive_set_mem_eff(module: torch.nn.Module):
         | 
| 232 | 
            +
                        if hasattr(module, "set_use_memory_efficient_attention_xformers"):
         | 
| 233 | 
            +
                            module.set_use_memory_efficient_attention_xformers(valid, attention_op)
         | 
| 234 | 
            +
             | 
| 235 | 
            +
                        for child in module.children():
         | 
| 236 | 
            +
                            fn_recursive_set_mem_eff(child)
         | 
| 237 | 
            +
             | 
| 238 | 
            +
                    for module in self.children():
         | 
| 239 | 
            +
                        if isinstance(module, torch.nn.Module):
         | 
| 240 | 
            +
                            fn_recursive_set_mem_eff(module)
         | 
| 241 | 
            +
             | 
| 242 | 
            +
                def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
         | 
| 243 | 
            +
                    r"""
         | 
| 244 | 
            +
                    Enable memory efficient attention as implemented in xformers.
         | 
| 245 | 
            +
             | 
| 246 | 
            +
                    When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
         | 
| 247 | 
            +
                    time. Speed up at training time is not guaranteed.
         | 
| 248 | 
            +
             | 
| 249 | 
            +
                Warning: when memory-efficient attention and sliced attention are both enabled, memory-efficient attention
         | 
| 250 | 
            +
                    is used.
         | 
| 251 | 
            +
             | 
| 252 | 
            +
                    Parameters:
         | 
| 253 | 
            +
                        attention_op (`Callable`, *optional*):
         | 
| 254 | 
            +
                            Override the default `None` operator for use as `op` argument to the
         | 
| 255 | 
            +
                            [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
         | 
| 256 | 
            +
                            function of xFormers.
         | 
| 257 | 
            +
             | 
| 258 | 
            +
                    Examples:
         | 
| 259 | 
            +
             | 
| 260 | 
            +
                    ```py
         | 
| 261 | 
            +
                    >>> import torch
         | 
| 262 | 
            +
                    >>> from diffusers import UNet2DConditionModel
         | 
| 263 | 
            +
                    >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
         | 
| 264 | 
            +
             | 
| 265 | 
            +
                    >>> model = UNet2DConditionModel.from_pretrained(
         | 
| 266 | 
            +
                    ...     "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
         | 
| 267 | 
            +
                    ... )
         | 
| 268 | 
            +
                    >>> model = model.to("cuda")
         | 
| 269 | 
            +
                    >>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
         | 
| 270 | 
            +
                    ```
         | 
| 271 | 
            +
                    """
         | 
| 272 | 
            +
                    self.set_use_memory_efficient_attention_xformers(True, attention_op)
         | 
| 273 | 
            +
             | 
| 274 | 
            +
                def disable_xformers_memory_efficient_attention(self):
         | 
| 275 | 
            +
                    r"""
         | 
| 276 | 
            +
                    Disable memory efficient attention as implemented in xformers.
         | 
| 277 | 
            +
                    """
         | 
| 278 | 
            +
                    self.set_use_memory_efficient_attention_xformers(False)
         | 
| 279 | 
            +
             | 
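    # --- Editor's note (not part of the original commit): a hedged usage sketch. ---
    # Both enable_/disable_xformers_memory_efficient_attention funnel into
    # set_use_memory_efficient_attention_xformers, which recursively visits every child module
    # and flips the flag on any block that exposes that method (e.g. attention processors).
    # Typical usage, assuming xformers is installed:
    #
    #     model.enable_xformers_memory_efficient_attention()
    #     ...  # run memory-hungry inference
    #     model.disable_xformers_memory_efficient_attention()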
| 280 | 
            +
                def save_pretrained(
         | 
| 281 | 
            +
                    self,
         | 
| 282 | 
            +
                    save_directory: Union[str, os.PathLike],
         | 
| 283 | 
            +
                    is_main_process: bool = True,
         | 
| 284 | 
            +
                    save_function: Callable = None,
         | 
| 285 | 
            +
                    safe_serialization: bool = False,
         | 
| 286 | 
            +
                    variant: Optional[str] = None,
         | 
| 287 | 
            +
                ):
         | 
| 288 | 
            +
                    """
         | 
| 289 | 
            +
                    Save a model and its configuration file to a directory, so that it can be re-loaded using the
         | 
| 290 | 
            +
                [`~models.ModelMixin.from_pretrained`] class method.
         | 
| 291 | 
            +
             | 
| 292 | 
            +
                    Arguments:
         | 
| 293 | 
            +
                        save_directory (`str` or `os.PathLike`):
         | 
| 294 | 
            +
                            Directory to which to save. Will be created if it doesn't exist.
         | 
| 295 | 
            +
                        is_main_process (`bool`, *optional*, defaults to `True`):
         | 
| 296 | 
            +
                    Whether the process calling this is the main process or not. Useful during distributed training (e.g.
         | 
| 297 | 
            +
                    on TPUs), when you need to call this function on all processes. In that case, set `is_main_process=True` only on
         | 
| 298 | 
            +
                            the main process to avoid race conditions.
         | 
| 299 | 
            +
                        save_function (`Callable`):
         | 
| 300 | 
            +
                    The function to use to save the state dictionary. Useful during distributed training (e.g. on TPUs) when one
         | 
| 301 | 
            +
                    needs to replace `torch.save` with another method. Can be configured with the environment variable
         | 
| 302 | 
            +
                            `DIFFUSERS_SAVE_MODE`.
         | 
| 303 | 
            +
                        safe_serialization (`bool`, *optional*, defaults to `False`):
         | 
| 304 | 
            +
                            Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
         | 
| 305 | 
            +
                        variant (`str`, *optional*):
         | 
| 306 | 
            +
                            If specified, weights are saved in the format pytorch_model.<variant>.bin.
         | 
| 307 | 
            +
                    """
         | 
| 308 | 
            +
                    if safe_serialization and not is_safetensors_available():
         | 
| 309 | 
            +
                        raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
         | 
| 310 | 
            +
             | 
| 311 | 
            +
                    if os.path.isfile(save_directory):
         | 
| 312 | 
            +
                        logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
         | 
| 313 | 
            +
                        return
         | 
| 314 | 
            +
             | 
| 315 | 
            +
                    os.makedirs(save_directory, exist_ok=True)
         | 
| 316 | 
            +
             | 
| 317 | 
            +
                    model_to_save = self
         | 
| 318 | 
            +
             | 
| 319 | 
            +
                    # Attach architecture to the config
         | 
| 320 | 
            +
                    # Save the config
         | 
| 321 | 
            +
                    if is_main_process:
         | 
| 322 | 
            +
                        model_to_save.save_config(save_directory)
         | 
| 323 | 
            +
             | 
| 324 | 
            +
                    # Save the model
         | 
| 325 | 
            +
                    state_dict = model_to_save.state_dict()
         | 
| 326 | 
            +
             | 
| 327 | 
            +
                    weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
         | 
| 328 | 
            +
                    weights_name = _add_variant(weights_name, variant)
         | 
| 329 | 
            +
             | 
| 330 | 
            +
                    # Save the model
         | 
| 331 | 
            +
                    if safe_serialization:
         | 
| 332 | 
            +
                        safetensors.torch.save_file(
         | 
| 333 | 
            +
                            state_dict, os.path.join(save_directory, weights_name), metadata={"format": "pt"}
         | 
| 334 | 
            +
                        )
         | 
| 335 | 
            +
                    else:
         | 
| 336 | 
            +
                        torch.save(state_dict, os.path.join(save_directory, weights_name))
         | 
| 337 | 
            +
             | 
| 338 | 
            +
                    logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
         | 
| 339 | 
            +
             | 
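    # --- Editor's note (not part of the original commit): a hedged usage sketch. ---
    # save_pretrained writes the config (via save_config) plus a single weights file whose name
    # depends on `safe_serialization` and `variant` (resolved through _add_variant and the
    # constants imported above). A typical round trip, with a hypothetical model class:
    #
    #     model.save_pretrained("./my_model", safe_serialization=True, variant="fp16")
    #     reloaded = MyModel.from_pretrained("./my_model", variant="fp16")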
| 340 | 
            +
                @classmethod
         | 
| 341 | 
            +
                def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
         | 
| 342 | 
            +
                    r"""
         | 
| 343 | 
            +
                    Instantiate a pretrained pytorch model from a pre-trained model configuration.
         | 
| 344 | 
            +
             | 
| 345 | 
            +
                    The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
         | 
| 346 | 
            +
                    the model, you should first set it back in training mode with `model.train()`.
         | 
| 347 | 
            +
             | 
| 348 | 
            +
                    The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
         | 
| 349 | 
            +
                    pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
         | 
| 350 | 
            +
                    task.
         | 
| 351 | 
            +
             | 
| 352 | 
            +
                    The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
         | 
| 353 | 
            +
                    weights are discarded.
         | 
| 354 | 
            +
             | 
| 355 | 
            +
                    Parameters:
         | 
| 356 | 
            +
                        pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
         | 
| 357 | 
            +
                            Can be either:
         | 
| 358 | 
            +
             | 
| 359 | 
            +
                                - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
         | 
| 360 | 
            +
                                  Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
         | 
| 361 | 
            +
                                - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
         | 
| 362 | 
            +
                                  `./my_model_directory/`.
         | 
| 363 | 
            +
             | 
| 364 | 
            +
                        cache_dir (`Union[str, os.PathLike]`, *optional*):
         | 
| 365 | 
            +
                            Path to a directory in which a downloaded pretrained model configuration should be cached if the
         | 
| 366 | 
            +
                            standard cache should not be used.
         | 
| 367 | 
            +
                        torch_dtype (`str` or `torch.dtype`, *optional*):
         | 
| 368 | 
            +
                            Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
         | 
| 369 | 
            +
                            will be automatically derived from the model's weights.
         | 
| 370 | 
            +
                        force_download (`bool`, *optional*, defaults to `False`):
         | 
| 371 | 
            +
                            Whether or not to force the (re-)download of the model weights and configuration files, overriding the
         | 
| 372 | 
            +
                            cached versions if they exist.
         | 
| 373 | 
            +
                        resume_download (`bool`, *optional*, defaults to `False`):
         | 
| 374 | 
            +
                            Whether or not to delete incompletely received files. Will attempt to resume the download if such a
         | 
| 375 | 
            +
                            file exists.
         | 
| 376 | 
            +
                        proxies (`Dict[str, str]`, *optional*):
         | 
| 377 | 
            +
                            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
         | 
| 378 | 
            +
                            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
         | 
| 379 | 
            +
                        output_loading_info(`bool`, *optional*, defaults to `False`):
         | 
| 380 | 
            +
                            Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
         | 
| 381 | 
            +
                        local_files_only(`bool`, *optional*, defaults to `False`):
         | 
| 382 | 
            +
                            Whether or not to only look at local files (i.e., do not try to download the model).
         | 
| 383 | 
            +
                        use_auth_token (`str` or *bool*, *optional*):
         | 
| 384 | 
            +
                            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
         | 
| 385 | 
            +
                            when running `diffusers-cli login` (stored in `~/.huggingface`).
         | 
| 386 | 
            +
                        revision (`str`, *optional*, defaults to `"main"`):
         | 
| 387 | 
            +
                            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
         | 
| 388 | 
            +
                            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
         | 
| 389 | 
            +
                            identifier allowed by git.
         | 
| 390 | 
            +
                        from_flax (`bool`, *optional*, defaults to `False`):
         | 
| 391 | 
            +
                            Load the model weights from a Flax checkpoint save file.
         | 
| 392 | 
            +
                        subfolder (`str`, *optional*, defaults to `""`):
         | 
| 393 | 
            +
                            In case the relevant files are located inside a subfolder of the model repo (either remote in
         | 
| 394 | 
            +
                            huggingface.co or downloaded locally), you can specify the folder name here.
         | 
| 395 | 
            +
             | 
| 396 | 
            +
                        mirror (`str`, *optional*):
         | 
| 397 | 
            +
                            Mirror source to accelerate downloads in China. If you are from China and have an accessibility
         | 
| 398 | 
            +
                            problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
         | 
| 399 | 
            +
                            Please refer to the mirror site for more information.
         | 
| 400 | 
            +
                        device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
         | 
| 401 | 
            +
                            A map that specifies where each submodule should go. It doesn't need to be refined to each
         | 
| 402 | 
            +
                    parameter/buffer name; once a given module name is included, every submodule of it will be sent to the
         | 
| 403 | 
            +
                            same device.
         | 
| 404 | 
            +
             | 
| 405 | 
            +
                            To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
         | 
| 406 | 
            +
                            more information about each option see [designing a device
         | 
| 407 | 
            +
                            map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
         | 
| 408 | 
            +
                        max_memory (`Dict`, *optional*):
         | 
| 409 | 
            +
                            A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
         | 
| 410 | 
            +
                            GPU and the available CPU RAM if unset.
         | 
| 411 | 
            +
                        offload_folder (`str` or `os.PathLike`, *optional*):
         | 
| 412 | 
            +
                            If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
         | 
| 413 | 
            +
                        offload_state_dict (`bool`, *optional*):
         | 
| 414 | 
            +
                            If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU
         | 
| 415 | 
            +
                            RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to
         | 
| 416 | 
            +
                            `True` when there is some disk offload.
         | 
| 417 | 
            +
                        low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
         | 
| 418 | 
            +
                            Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
         | 
| 419 | 
            +
                            also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
         | 
| 420 | 
            +
                            model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
         | 
| 421 | 
            +
                            setting this argument to `True` will raise an error.
         | 
| 422 | 
            +
                        variant (`str`, *optional*):
         | 
| 423 | 
            +
                    If specified, load weights from the `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
         | 
| 424 | 
            +
                            ignored when using `from_flax`.
         | 
| 425 | 
            +
                        use_safetensors (`bool`, *optional*, defaults to `None`):
         | 
| 426 | 
            +
                            If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
         | 
| 427 | 
            +
                            `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
         | 
| 428 | 
            +
                            `safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
         | 
| 429 | 
            +
             | 
| 430 | 
            +
                    <Tip>
         | 
| 431 | 
            +
             | 
| 432 | 
            +
             You need to be logged in (`huggingface-cli login`) to use private or [gated
         | 
| 433 | 
            +
                     models](https://huggingface.co/docs/hub/models-gated#gated-models).
         | 
| 434 | 
            +
             | 
| 435 | 
            +
                    </Tip>
         | 
| 436 | 
            +
             | 
| 437 | 
            +
                    <Tip>
         | 
| 438 | 
            +
             | 
| 439 | 
            +
                    Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
         | 
| 440 | 
            +
                    this method in a firewalled environment.
         | 
| 441 | 
            +
             | 
| 442 | 
            +
                    </Tip>
         | 
| 443 | 
            +
             | 
| 444 | 
            +
                    """
         | 
| 445 | 
            +
                    cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
         | 
| 446 | 
            +
                    ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
         | 
| 447 | 
            +
                    force_download = kwargs.pop("force_download", False)
         | 
| 448 | 
            +
                    from_flax = kwargs.pop("from_flax", False)
         | 
| 449 | 
            +
                    resume_download = kwargs.pop("resume_download", False)
         | 
| 450 | 
            +
                    proxies = kwargs.pop("proxies", None)
         | 
| 451 | 
            +
                    output_loading_info = kwargs.pop("output_loading_info", False)
         | 
| 452 | 
            +
                    local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
         | 
| 453 | 
            +
                    use_auth_token = kwargs.pop("use_auth_token", None)
         | 
| 454 | 
            +
                    revision = kwargs.pop("revision", None)
         | 
| 455 | 
            +
                    torch_dtype = kwargs.pop("torch_dtype", None)
         | 
| 456 | 
            +
                    subfolder = kwargs.pop("subfolder", None)
         | 
| 457 | 
            +
                    device_map = kwargs.pop("device_map", None)
         | 
| 458 | 
            +
                    max_memory = kwargs.pop("max_memory", None)
         | 
| 459 | 
            +
                    offload_folder = kwargs.pop("offload_folder", None)
         | 
| 460 | 
            +
                    offload_state_dict = kwargs.pop("offload_state_dict", False)
         | 
| 461 | 
            +
                    low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
         | 
| 462 | 
            +
                    variant = kwargs.pop("variant", None)
         | 
| 463 | 
            +
                    use_safetensors = kwargs.pop("use_safetensors", None)
         | 
| 464 | 
            +
             | 
| 465 | 
            +
                    if use_safetensors and not is_safetensors_available():
         | 
| 466 | 
            +
                        raise ValueError(
         | 
| 467 | 
            +
                            "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
         | 
| 468 | 
            +
                        )
         | 
| 469 | 
            +
             | 
| 470 | 
            +
                    allow_pickle = False
         | 
| 471 | 
            +
                    if use_safetensors is None:
         | 
| 472 | 
            +
                        use_safetensors = is_safetensors_available()
         | 
| 473 | 
            +
                        allow_pickle = True
         | 
| 474 | 
            +
             | 
| 475 | 
            +
                    if low_cpu_mem_usage and not is_accelerate_available():
         | 
| 476 | 
            +
                        low_cpu_mem_usage = False
         | 
| 477 | 
            +
                        logger.warning(
         | 
| 478 | 
            +
                            "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
         | 
| 479 | 
            +
                            " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
         | 
| 480 | 
            +
                            " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
         | 
| 481 | 
            +
                            " install accelerate\n```\n."
         | 
| 482 | 
            +
                        )
         | 
| 483 | 
            +
             | 
| 484 | 
            +
                    if device_map is not None and not is_accelerate_available():
         | 
| 485 | 
            +
                        raise NotImplementedError(
         | 
| 486 | 
            +
                            "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
         | 
| 487 | 
            +
                            " `device_map=None`. You can install accelerate with `pip install accelerate`."
         | 
| 488 | 
            +
                        )
         | 
| 489 | 
            +
             | 
| 490 | 
            +
                    # Check if we can handle device_map and dispatching the weights
         | 
| 491 | 
            +
                    if device_map is not None and not is_torch_version(">=", "1.9.0"):
         | 
| 492 | 
            +
                        raise NotImplementedError(
         | 
| 493 | 
            +
                            "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
         | 
| 494 | 
            +
                            " `device_map=None`."
         | 
| 495 | 
            +
                        )
         | 
| 496 | 
            +
             | 
| 497 | 
            +
                    if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
         | 
| 498 | 
            +
                        raise NotImplementedError(
         | 
| 499 | 
            +
                            "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
         | 
| 500 | 
            +
                            " `low_cpu_mem_usage=False`."
         | 
| 501 | 
            +
                        )
         | 
| 502 | 
            +
             | 
| 503 | 
            +
                    if low_cpu_mem_usage is False and device_map is not None:
         | 
| 504 | 
            +
                        raise ValueError(
         | 
| 505 | 
            +
                            f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
         | 
| 506 | 
            +
                            " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
         | 
| 507 | 
            +
                        )
         | 
| 508 | 
            +
             | 
| 509 | 
            +
                    # Load config if we don't provide a configuration
         | 
| 510 | 
            +
                    config_path = pretrained_model_name_or_path
         | 
| 511 | 
            +
             | 
| 512 | 
            +
                    user_agent = {
         | 
| 513 | 
            +
                        "diffusers": __version__,
         | 
| 514 | 
            +
                        "file_type": "model",
         | 
| 515 | 
            +
                        "framework": "pytorch",
         | 
| 516 | 
            +
                    }
         | 
| 517 | 
            +
             | 
| 518 | 
            +
                    # load config
         | 
| 519 | 
            +
                    config, unused_kwargs, commit_hash = cls.load_config(
         | 
| 520 | 
            +
                        config_path,
         | 
| 521 | 
            +
                        cache_dir=cache_dir,
         | 
| 522 | 
            +
                        return_unused_kwargs=True,
         | 
| 523 | 
            +
                        return_commit_hash=True,
         | 
| 524 | 
            +
                        force_download=force_download,
         | 
| 525 | 
            +
                        resume_download=resume_download,
         | 
| 526 | 
            +
                        proxies=proxies,
         | 
| 527 | 
            +
                        local_files_only=local_files_only,
         | 
| 528 | 
            +
                        use_auth_token=use_auth_token,
         | 
| 529 | 
            +
                        revision=revision,
         | 
| 530 | 
            +
                        subfolder=subfolder,
         | 
| 531 | 
            +
                        device_map=device_map,
         | 
| 532 | 
            +
                        max_memory=max_memory,
         | 
| 533 | 
            +
                        offload_folder=offload_folder,
         | 
| 534 | 
            +
                        offload_state_dict=offload_state_dict,
         | 
| 535 | 
            +
                        user_agent=user_agent,
         | 
| 536 | 
            +
                        **kwargs,
         | 
| 537 | 
            +
                    )
         | 
| 538 | 
            +
             | 
| 539 | 
            +
                    # load model
         | 
| 540 | 
            +
                    model_file = None
         | 
| 541 | 
            +
                    if from_flax:
         | 
| 542 | 
            +
                        model_file = _get_model_file(
         | 
| 543 | 
            +
                            pretrained_model_name_or_path,
         | 
| 544 | 
            +
                            weights_name=FLAX_WEIGHTS_NAME,
         | 
| 545 | 
            +
                            cache_dir=cache_dir,
         | 
| 546 | 
            +
                            force_download=force_download,
         | 
| 547 | 
            +
                            resume_download=resume_download,
         | 
| 548 | 
            +
                            proxies=proxies,
         | 
| 549 | 
            +
                            local_files_only=local_files_only,
         | 
| 550 | 
            +
                            use_auth_token=use_auth_token,
         | 
| 551 | 
            +
                            revision=revision,
         | 
| 552 | 
            +
                            subfolder=subfolder,
         | 
| 553 | 
            +
                            user_agent=user_agent,
         | 
| 554 | 
            +
                            commit_hash=commit_hash,
         | 
| 555 | 
            +
                        )
         | 
| 556 | 
            +
                        model = cls.from_config(config, **unused_kwargs)
         | 
| 557 | 
            +
             | 
| 558 | 
            +
                        # Convert the weights
         | 
| 559 | 
            +
                        from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
         | 
| 560 | 
            +
             | 
| 561 | 
            +
                        model = load_flax_checkpoint_in_pytorch_model(model, model_file)
         | 
| 562 | 
            +
                    else:
         | 
| 563 | 
            +
                        if use_safetensors:
         | 
| 564 | 
            +
                            try:
         | 
| 565 | 
            +
                                model_file = _get_model_file(
         | 
| 566 | 
            +
                                    pretrained_model_name_or_path,
         | 
| 567 | 
            +
                                    weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
         | 
| 568 | 
            +
                                    cache_dir=cache_dir,
         | 
| 569 | 
            +
                                    force_download=force_download,
         | 
| 570 | 
            +
                                    resume_download=resume_download,
         | 
| 571 | 
            +
                                    proxies=proxies,
         | 
| 572 | 
            +
                                    local_files_only=local_files_only,
         | 
| 573 | 
            +
                                    use_auth_token=use_auth_token,
         | 
| 574 | 
            +
                                    revision=revision,
         | 
| 575 | 
            +
                                    subfolder=subfolder,
         | 
| 576 | 
            +
                                    user_agent=user_agent,
         | 
| 577 | 
            +
                                    commit_hash=commit_hash,
         | 
| 578 | 
            +
                                )
         | 
| 579 | 
            +
                            except IOError as e:
         | 
| 580 | 
            +
                                if not allow_pickle:
         | 
| 581 | 
            +
                                    raise e
         | 
| 582 | 
            +
                                pass
         | 
| 583 | 
            +
                        if model_file is None:
         | 
| 584 | 
            +
                            model_file = _get_model_file(
         | 
| 585 | 
            +
                                pretrained_model_name_or_path,
         | 
| 586 | 
            +
                                weights_name=_add_variant(WEIGHTS_NAME, variant),
         | 
| 587 | 
            +
                                cache_dir=cache_dir,
         | 
| 588 | 
            +
                                force_download=force_download,
         | 
| 589 | 
            +
                                resume_download=resume_download,
         | 
| 590 | 
            +
                                proxies=proxies,
         | 
| 591 | 
            +
                                local_files_only=local_files_only,
         | 
| 592 | 
            +
                                use_auth_token=use_auth_token,
         | 
| 593 | 
            +
                                revision=revision,
         | 
| 594 | 
            +
                                subfolder=subfolder,
         | 
| 595 | 
            +
                                user_agent=user_agent,
         | 
| 596 | 
            +
                                commit_hash=commit_hash,
         | 
| 597 | 
            +
                            )
         | 
| 598 | 
            +
             | 
| 599 | 
            +
                        if low_cpu_mem_usage:
         | 
| 600 | 
            +
                            # Instantiate model with empty weights
         | 
| 601 | 
            +
                            with accelerate.init_empty_weights():
         | 
| 602 | 
            +
                                model = cls.from_config(config, **unused_kwargs)
         | 
| 603 | 
            +
             | 
| 604 | 
            +
                            # if device_map is None, load the state dict and move the params from meta device to the cpu
         | 
| 605 | 
            +
                            if device_map is None:
         | 
| 606 | 
            +
                                param_device = "cpu"
         | 
| 607 | 
            +
                                state_dict = load_state_dict(model_file, variant=variant)
         | 
| 608 | 
            +
                                model._convert_deprecated_attention_blocks(state_dict)
         | 
| 609 | 
            +
                                # move the params from meta device to cpu
         | 
| 610 | 
            +
                                missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
         | 
| 611 | 
            +
                                if len(missing_keys) > 0:
         | 
| 612 | 
            +
                                    raise ValueError(
         | 
| 613 | 
            +
                                        f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
         | 
| 614 | 
            +
                                        f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
         | 
| 615 | 
            +
                                        " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
         | 
| 616 | 
            +
                                        " those weights or else make sure your checkpoint file is correct."
         | 
| 617 | 
            +
                                    )
         | 
| 618 | 
            +
                                unexpected_keys = []
         | 
| 619 | 
            +
             | 
| 620 | 
            +
                                empty_state_dict = model.state_dict()
         | 
| 621 | 
            +
                                for param_name, param in state_dict.items():
         | 
| 622 | 
            +
                                    accepts_dtype = "dtype" in set(
         | 
| 623 | 
            +
                                        inspect.signature(set_module_tensor_to_device).parameters.keys()
         | 
| 624 | 
            +
                                    )
         | 
| 625 | 
            +
             | 
| 626 | 
            +
                                    if param_name not in empty_state_dict:
         | 
| 627 | 
            +
                                        unexpected_keys.append(param_name)
         | 
| 628 | 
            +
                                        continue
         | 
| 629 | 
            +
             | 
| 630 | 
            +
                                    if empty_state_dict[param_name].shape != param.shape:
         | 
| 631 | 
            +
                                        raise ValueError(
         | 
| 632 | 
            +
                                            f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
         | 
| 633 | 
            +
                                        )
         | 
| 634 | 
            +
             | 
| 635 | 
            +
                                    if accepts_dtype:
         | 
| 636 | 
            +
                                        set_module_tensor_to_device(
         | 
| 637 | 
            +
                                            model, param_name, param_device, value=param, dtype=torch_dtype
         | 
| 638 | 
            +
                                        )
         | 
| 639 | 
            +
                                    else:
         | 
| 640 | 
            +
                                        set_module_tensor_to_device(model, param_name, param_device, value=param)
         | 
| 641 | 
            +
             | 
| 642 | 
            +
                                if cls._keys_to_ignore_on_load_unexpected is not None:
         | 
| 643 | 
            +
                                    for pat in cls._keys_to_ignore_on_load_unexpected:
         | 
| 644 | 
            +
                                        unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
         | 
| 645 | 
            +
             | 
| 646 | 
            +
                                if len(unexpected_keys) > 0:
         | 
| 647 | 
            +
                                    logger.warn(
         | 
| 648 | 
            +
                                        f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
         | 
| 649 | 
            +
                                    )
         | 
| 650 | 
            +
             | 
| 651 | 
            +
                            else:  # else let accelerate handle loading and dispatching.
         | 
| 652 | 
            +
                                # Load weights and dispatch according to the device_map
         | 
| 653 | 
            +
                                # by default the device_map is None and the weights are loaded on the CPU
         | 
| 654 | 
            +
                                try:
         | 
| 655 | 
            +
                                    accelerate.load_checkpoint_and_dispatch(
         | 
| 656 | 
            +
                                        model,
         | 
| 657 | 
            +
                                        model_file,
         | 
| 658 | 
            +
                                        device_map,
         | 
| 659 | 
            +
                                        max_memory=max_memory,
         | 
| 660 | 
            +
                                        offload_folder=offload_folder,
         | 
| 661 | 
            +
                                        offload_state_dict=offload_state_dict,
         | 
| 662 | 
            +
                                        dtype=torch_dtype,
         | 
| 663 | 
            +
                                    )
         | 
| 664 | 
            +
                                except AttributeError as e:
         | 
| 665 | 
            +
                                    # When using accelerate loading, we do not have the ability to load the state
         | 
| 666 | 
            +
                                    # dict and rename the weight names manually. Additionally, accelerate skips
         | 
| 667 | 
            +
                                    # torch loading conventions and directly writes into `module.{_buffers, _parameters}`
         | 
| 668 | 
            +
                                    # (which look like they should be private variables?), so we can't use the standard hooks
         | 
| 669 | 
            +
                                    # to rename parameters on load. We need to mimic the original weight names so the correct
         | 
| 670 | 
            +
                                    # attributes are available. After we have loaded the weights, we convert the deprecated
         | 
| 671 | 
            +
                                    # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
         | 
| 672 | 
            +
                                    # the weights so we don't have to do this again.
         | 
| 673 | 
            +
             | 
| 674 | 
            +
                                    if "'Attention' object has no attribute" in str(e):
         | 
| 675 | 
            +
                                        logger.warn(
         | 
| 676 | 
            +
                                            f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
         | 
| 677 | 
            +
                                            " was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
         | 
| 678 | 
            +
                                            " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
         | 
| 679 | 
            +
                                            " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
         | 
| 680 | 
            +
                                            " please also re-upload it or open a PR on the original repository."
         | 
| 681 | 
            +
                                        )
         | 
| 682 | 
            +
                                        model._temp_convert_self_to_deprecated_attention_blocks()
         | 
| 683 | 
            +
                                        accelerate.load_checkpoint_and_dispatch(
         | 
| 684 | 
            +
                                            model,
         | 
| 685 | 
            +
                                            model_file,
         | 
| 686 | 
            +
                                            device_map,
         | 
| 687 | 
            +
                                            max_memory=max_memory,
         | 
| 688 | 
            +
                                            offload_folder=offload_folder,
         | 
| 689 | 
            +
                                            offload_state_dict=offload_state_dict,
         | 
| 690 | 
            +
                                            dtype=torch_dtype,
         | 
| 691 | 
            +
                                        )
         | 
| 692 | 
            +
                                        model._undo_temp_convert_self_to_deprecated_attention_blocks()
         | 
| 693 | 
            +
                                    else:
         | 
| 694 | 
            +
                                        raise e
         | 
| 695 | 
            +
             | 
| 696 | 
            +
                            loading_info = {
         | 
| 697 | 
            +
                                "missing_keys": [],
         | 
| 698 | 
            +
                                "unexpected_keys": [],
         | 
| 699 | 
            +
                                "mismatched_keys": [],
         | 
| 700 | 
            +
                                "error_msgs": [],
         | 
| 701 | 
            +
                            }
         | 
| 702 | 
            +
                        else:
         | 
| 703 | 
            +
                            model = cls.from_config(config, **unused_kwargs)
         | 
| 704 | 
            +
             | 
| 705 | 
            +
                            state_dict = load_state_dict(model_file, variant=variant)
         | 
| 706 | 
            +
                            model._convert_deprecated_attention_blocks(state_dict)
         | 
| 707 | 
            +
             | 
| 708 | 
            +
                            model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
         | 
| 709 | 
            +
                                model,
         | 
| 710 | 
            +
                                state_dict,
         | 
| 711 | 
            +
                                model_file,
         | 
| 712 | 
            +
                                pretrained_model_name_or_path,
         | 
| 713 | 
            +
                                ignore_mismatched_sizes=ignore_mismatched_sizes,
         | 
| 714 | 
            +
                            )
         | 
| 715 | 
            +
             | 
| 716 | 
            +
                            loading_info = {
         | 
| 717 | 
            +
                                "missing_keys": missing_keys,
         | 
| 718 | 
            +
                                "unexpected_keys": unexpected_keys,
         | 
| 719 | 
            +
                                "mismatched_keys": mismatched_keys,
         | 
| 720 | 
            +
                                "error_msgs": error_msgs,
         | 
| 721 | 
            +
                            }
         | 
| 722 | 
            +
             | 
| 723 | 
            +
                    if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
         | 
| 724 | 
            +
                        raise ValueError(
         | 
| 725 | 
            +
                            f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
         | 
| 726 | 
            +
                        )
         | 
| 727 | 
            +
                    elif torch_dtype is not None:
         | 
| 728 | 
            +
                        model = model.to(torch_dtype)
         | 
| 729 | 
            +
             | 
| 730 | 
            +
                    model.register_to_config(_name_or_path=pretrained_model_name_or_path)
         | 
| 731 | 
            +
             | 
| 732 | 
            +
                    # Set model in evaluation mode to deactivate DropOut modules by default
         | 
| 733 | 
            +
                    model.eval()
         | 
| 734 | 
            +
                    if output_loading_info:
         | 
| 735 | 
            +
                        return model, loading_info
         | 
| 736 | 
            +
             | 
| 737 | 
            +
                    return model
         | 
| 738 | 
            +
             | 
| 739 | 
            +
                @classmethod
         | 
| 740 | 
            +
                def _load_pretrained_model(
         | 
| 741 | 
            +
                    cls,
         | 
| 742 | 
            +
                    model,
         | 
| 743 | 
            +
                    state_dict,
         | 
| 744 | 
            +
                    resolved_archive_file,
         | 
| 745 | 
            +
                    pretrained_model_name_or_path,
         | 
| 746 | 
            +
                    ignore_mismatched_sizes=False,
         | 
| 747 | 
            +
                ):
         | 
| 748 | 
            +
                    # Retrieve missing & unexpected_keys
         | 
| 749 | 
            +
                    model_state_dict = model.state_dict()
         | 
| 750 | 
            +
                    loaded_keys = list(state_dict.keys())
         | 
| 751 | 
            +
             | 
| 752 | 
            +
                    expected_keys = list(model_state_dict.keys())
         | 
| 753 | 
            +
             | 
| 754 | 
            +
                    original_loaded_keys = loaded_keys
         | 
| 755 | 
            +
             | 
| 756 | 
            +
                    missing_keys = list(set(expected_keys) - set(loaded_keys))
         | 
| 757 | 
            +
                    unexpected_keys = list(set(loaded_keys) - set(expected_keys))
         | 
| 758 | 
            +
             | 
| 759 | 
            +
                    # Make sure we are able to load base models as well as derived models (with heads)
         | 
| 760 | 
            +
                    model_to_load = model
         | 
| 761 | 
            +
             | 
| 762 | 
            +
                    def _find_mismatched_keys(
         | 
| 763 | 
            +
                        state_dict,
         | 
| 764 | 
            +
                        model_state_dict,
         | 
| 765 | 
            +
                        loaded_keys,
         | 
| 766 | 
            +
                        ignore_mismatched_sizes,
         | 
| 767 | 
            +
                    ):
         | 
| 768 | 
            +
                        mismatched_keys = []
         | 
| 769 | 
            +
                        if ignore_mismatched_sizes:
         | 
| 770 | 
            +
                            for checkpoint_key in loaded_keys:
         | 
| 771 | 
            +
                                model_key = checkpoint_key
         | 
| 772 | 
            +
             | 
| 773 | 
            +
                                if (
         | 
| 774 | 
            +
                                    model_key in model_state_dict
         | 
| 775 | 
            +
                                    and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
         | 
| 776 | 
            +
                                ):
         | 
| 777 | 
            +
                                    mismatched_keys.append(
         | 
| 778 | 
            +
                                        (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
         | 
| 779 | 
            +
                                    )
         | 
| 780 | 
            +
                                    del state_dict[checkpoint_key]
         | 
| 781 | 
            +
                        return mismatched_keys
         | 
| 782 | 
            +
             | 
| 783 | 
            +
                    if state_dict is not None:
         | 
| 784 | 
            +
                        # Whole checkpoint
         | 
| 785 | 
            +
                        mismatched_keys = _find_mismatched_keys(
         | 
| 786 | 
            +
                            state_dict,
         | 
| 787 | 
            +
                            model_state_dict,
         | 
| 788 | 
            +
                            original_loaded_keys,
         | 
| 789 | 
            +
                            ignore_mismatched_sizes,
         | 
| 790 | 
            +
                        )
         | 
| 791 | 
            +
                        error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
         | 
| 792 | 
            +
             | 
| 793 | 
            +
                    if len(error_msgs) > 0:
         | 
| 794 | 
            +
                        error_msg = "\n\t".join(error_msgs)
         | 
| 795 | 
            +
                        if "size mismatch" in error_msg:
         | 
| 796 | 
            +
                            error_msg += (
         | 
| 797 | 
            +
                                "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
         | 
| 798 | 
            +
                            )
         | 
| 799 | 
            +
                        raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
         | 
| 800 | 
            +
             | 
| 801 | 
            +
                    if len(unexpected_keys) > 0:
         | 
| 802 | 
            +
                        logger.warning(
         | 
| 803 | 
            +
                            f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
         | 
| 804 | 
            +
                            f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
         | 
| 805 | 
            +
                            f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
         | 
| 806 | 
            +
                            " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
         | 
| 807 | 
            +
                            " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
         | 
| 808 | 
            +
                            f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
         | 
| 809 | 
            +
                            " identical (initializing a BertForSequenceClassification model from a"
         | 
| 810 | 
            +
                            " BertForSequenceClassification model)."
         | 
| 811 | 
            +
                        )
         | 
| 812 | 
            +
                    else:
         | 
| 813 | 
            +
                        logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
         | 
| 814 | 
            +
                    if len(missing_keys) > 0:
         | 
| 815 | 
            +
                        logger.warning(
         | 
| 816 | 
            +
                            f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
         | 
| 817 | 
            +
                            f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
         | 
| 818 | 
            +
                            " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
         | 
| 819 | 
            +
                        )
         | 
| 820 | 
            +
                    elif len(mismatched_keys) == 0:
         | 
| 821 | 
            +
                        logger.info(
         | 
| 822 | 
            +
                            f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
         | 
| 823 | 
            +
                            f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
         | 
| 824 | 
            +
                            f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
         | 
| 825 | 
            +
                            " without further training."
         | 
| 826 | 
            +
                        )
         | 
| 827 | 
            +
                    if len(mismatched_keys) > 0:
         | 
| 828 | 
            +
                        mismatched_warning = "\n".join(
         | 
| 829 | 
            +
                            [
         | 
| 830 | 
            +
                                f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
         | 
| 831 | 
            +
                                for key, shape1, shape2 in mismatched_keys
         | 
| 832 | 
            +
                            ]
         | 
| 833 | 
            +
                        )
         | 
| 834 | 
            +
                        logger.warning(
         | 
| 835 | 
            +
                            f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
         | 
| 836 | 
            +
                            f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
         | 
| 837 | 
            +
                            f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
         | 
| 838 | 
            +
                            " able to use it for predictions and inference."
         | 
| 839 | 
            +
                        )
         | 
| 840 | 
            +
             | 
| 841 | 
            +
                    return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
         | 
| 842 | 
            +
             | 
| 843 | 
            +
                @property
         | 
| 844 | 
            +
                def device(self) -> device:
         | 
| 845 | 
            +
                    """
         | 
| 846 | 
            +
                    `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
         | 
| 847 | 
            +
                    device).
         | 
| 848 | 
            +
                    """
         | 
| 849 | 
            +
                    return get_parameter_device(self)
         | 
| 850 | 
            +
             | 
| 851 | 
            +
                @property
         | 
| 852 | 
            +
                def dtype(self) -> torch.dtype:
         | 
| 853 | 
            +
                    """
         | 
| 854 | 
            +
                    `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
         | 
| 855 | 
            +
                    """
         | 
| 856 | 
            +
                    return get_parameter_dtype(self)
         | 
| 857 | 
            +
             | 
| 858 | 
            +
                def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
         | 
| 859 | 
            +
                    """
         | 
| 860 | 
            +
                    Get number of (optionally, trainable or non-embeddings) parameters in the module.
         | 
| 861 | 
            +
             | 
| 862 | 
            +
                    Args:
         | 
| 863 | 
            +
                        only_trainable (`bool`, *optional*, defaults to `False`):
         | 
| 864 | 
            +
                            Whether or not to return only the number of trainable parameters
         | 
| 865 | 
            +
             | 
| 866 | 
            +
                        exclude_embeddings (`bool`, *optional*, defaults to `False`):
         | 
| 867 | 
            +
                            Whether or not to return only the number of non-embeddings parameters
         | 
| 868 | 
            +
             | 
| 869 | 
            +
                    Returns:
         | 
| 870 | 
            +
                        `int`: The number of parameters.
         | 
| 871 | 
            +
                    """
         | 
| 872 | 
            +
             | 
| 873 | 
            +
                    if exclude_embeddings:
         | 
| 874 | 
            +
                        embedding_param_names = [
         | 
| 875 | 
            +
                            f"{name}.weight"
         | 
| 876 | 
            +
                            for name, module_type in self.named_modules()
         | 
| 877 | 
            +
                            if isinstance(module_type, torch.nn.Embedding)
         | 
| 878 | 
            +
                        ]
         | 
| 879 | 
            +
                        non_embedding_parameters = [
         | 
| 880 | 
            +
                            parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
         | 
| 881 | 
            +
                        ]
         | 
| 882 | 
            +
                        return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
         | 
| 883 | 
            +
                    else:
         | 
| 884 | 
            +
                        return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
         | 
| 885 | 
            +
             | 
| 886 | 
            +
                def _convert_deprecated_attention_blocks(self, state_dict):
         | 
| 887 | 
            +
                    deprecated_attention_block_paths = []
         | 
| 888 | 
            +
             | 
| 889 | 
            +
                    def recursive_find_attn_block(name, module):
         | 
| 890 | 
            +
                        if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
         | 
| 891 | 
            +
                            deprecated_attention_block_paths.append(name)
         | 
| 892 | 
            +
             | 
| 893 | 
            +
                        for sub_name, sub_module in module.named_children():
         | 
| 894 | 
            +
                            sub_name = sub_name if name == "" else f"{name}.{sub_name}"
         | 
| 895 | 
            +
                            recursive_find_attn_block(sub_name, sub_module)
         | 
| 896 | 
            +
             | 
| 897 | 
            +
                    recursive_find_attn_block("", self)
         | 
| 898 | 
            +
             | 
| 899 | 
            +
                    # NOTE: we have to check if the deprecated parameters are in the state dict
         | 
| 900 | 
            +
                    # because it is possible we are loading from a state dict that was already
         | 
| 901 | 
            +
                    # converted
         | 
| 902 | 
            +
             | 
| 903 | 
            +
                    for path in deprecated_attention_block_paths:
         | 
| 904 | 
            +
                        # group_norm path stays the same
         | 
| 905 | 
            +
             | 
| 906 | 
            +
                        # query -> to_q
         | 
| 907 | 
            +
                        if f"{path}.query.weight" in state_dict:
         | 
| 908 | 
            +
                            state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight")
         | 
| 909 | 
            +
                        if f"{path}.query.bias" in state_dict:
         | 
| 910 | 
            +
                            state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias")
         | 
| 911 | 
            +
             | 
| 912 | 
            +
                        # key -> to_k
         | 
| 913 | 
            +
                        if f"{path}.key.weight" in state_dict:
         | 
| 914 | 
            +
                            state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight")
         | 
| 915 | 
            +
                        if f"{path}.key.bias" in state_dict:
         | 
| 916 | 
            +
                            state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias")
         | 
| 917 | 
            +
             | 
| 918 | 
            +
                        # value -> to_v
         | 
| 919 | 
            +
                        if f"{path}.value.weight" in state_dict:
         | 
| 920 | 
            +
                            state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight")
         | 
| 921 | 
            +
                        if f"{path}.value.bias" in state_dict:
         | 
| 922 | 
            +
                            state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias")
         | 
| 923 | 
            +
             | 
| 924 | 
            +
                        # proj_attn -> to_out.0
         | 
| 925 | 
            +
                        if f"{path}.proj_attn.weight" in state_dict:
         | 
| 926 | 
            +
                            state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
         | 
| 927 | 
            +
                        if f"{path}.proj_attn.bias" in state_dict:
         | 
| 928 | 
            +
                            state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
         | 
| 929 | 
            +
             | 
| 930 | 
            +
                def _temp_convert_self_to_deprecated_attention_blocks(self):
         | 
| 931 | 
            +
                    deprecated_attention_block_modules = []
         | 
| 932 | 
            +
             | 
| 933 | 
            +
                    def recursive_find_attn_block(module):
         | 
| 934 | 
            +
                        if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
         | 
| 935 | 
            +
                            deprecated_attention_block_modules.append(module)
         | 
| 936 | 
            +
             | 
| 937 | 
            +
                        for sub_module in module.children():
         | 
| 938 | 
            +
                            recursive_find_attn_block(sub_module)
         | 
| 939 | 
            +
             | 
| 940 | 
            +
                    recursive_find_attn_block(self)
         | 
| 941 | 
            +
             | 
| 942 | 
            +
                    for module in deprecated_attention_block_modules:
         | 
| 943 | 
            +
                        module.query = module.to_q
         | 
| 944 | 
            +
                        module.key = module.to_k
         | 
| 945 | 
            +
                        module.value = module.to_v
         | 
| 946 | 
            +
                        module.proj_attn = module.to_out[0]
         | 
| 947 | 
            +
             | 
| 948 | 
            +
                        # We don't _have_ to delete the old attributes, but it's helpful to ensure
         | 
| 949 | 
            +
                        # that _all_ the weights are loaded into the new attributes and we're not
         | 
| 950 | 
            +
                        # making an incorrect assumption that this model should be converted when
         | 
| 951 | 
            +
                        # it really shouldn't be.
         | 
| 952 | 
            +
                        del module.to_q
         | 
| 953 | 
            +
                        del module.to_k
         | 
| 954 | 
            +
                        del module.to_v
         | 
| 955 | 
            +
                        del module.to_out
         | 
| 956 | 
            +
             | 
| 957 | 
            +
                def _undo_temp_convert_self_to_deprecated_attention_blocks(self):
         | 
| 958 | 
            +
                    deprecated_attention_block_modules = []
         | 
| 959 | 
            +
             | 
| 960 | 
            +
                    def recursive_find_attn_block(module):
         | 
| 961 | 
            +
                        if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
         | 
| 962 | 
            +
                            deprecated_attention_block_modules.append(module)
         | 
| 963 | 
            +
             | 
| 964 | 
            +
                        for sub_module in module.children():
         | 
| 965 | 
            +
                            recursive_find_attn_block(sub_module)
         | 
| 966 | 
            +
             | 
| 967 | 
            +
                    recursive_find_attn_block(self)
         | 
| 968 | 
            +
             | 
| 969 | 
            +
                    for module in deprecated_attention_block_modules:
         | 
| 970 | 
            +
                        module.to_q = module.query
         | 
| 971 | 
            +
                        module.to_k = module.key
         | 
| 972 | 
            +
                        module.to_v = module.value
         | 
| 973 | 
            +
                        module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])
         | 
| 974 | 
            +
             | 
| 975 | 
            +
                        del module.query
         | 
| 976 | 
            +
                        del module.key
         | 
| 977 | 
            +
                        del module.value
         | 
| 978 | 
            +
                        del module.proj_attn
         | 
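A minimal usage sketch of the loading path that ends above (illustrative only, not part of the commit): it assumes the `save_pretrained` counterpart defined earlier in this file and that the vendored `diffusers` package imports cleanly, and it borrows the `PriorTransformer` added below simply as a small `ModelMixin` subclass.

import torch
from diffusers.models.prior_transformer import PriorTransformer

# Build and save a tiny model, then reload it through ModelMixin.from_pretrained.
model = PriorTransformer(
    num_attention_heads=2, attention_head_dim=4, num_layers=1,
    embedding_dim=8, num_embeddings=3,
)
model.save_pretrained("tmp_prior")  # writes config.json plus the weight file

# `output_loading_info=True` returns the loading_info dict assembled above.
reloaded, info = PriorTransformer.from_pretrained("tmp_prior", output_loading_info=True)
print(info["missing_keys"], info["unexpected_keys"])

# Convenience accessors defined above on ModelMixin.
print(reloaded.device, reloaded.dtype, reloaded.num_parameters(only_trainable=True))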
    	
        diffusers/models/prior_transformer.py
    ADDED
    
@@ -0,0 +1,194 @@
from dataclasses import dataclass
from typing import Optional, Union

import torch
import torch.nn.functional as F
from torch import nn

from ..utils.configuration_utils import ConfigMixin, register_to_config
from ..utils.outputs import BaseOutput
from .attention import BasicTransformerBlock
from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin


@dataclass
class PriorTransformerOutput(BaseOutput):
    """
    Args:
        predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
            The predicted CLIP image embedding conditioned on the CLIP text embedding input.
    """

    predicted_image_embedding: torch.FloatTensor


class PriorTransformer(ModelMixin, ConfigMixin):
    """
    The prior transformer from unCLIP is used to predict CLIP image embeddings from CLIP text embeddings. Note that the
    transformer predicts the image embeddings through a denoising diffusion process.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
    implements for all the models (such as downloading or saving, etc.)

    For more details, see the original paper: https://arxiv.org/abs/2204.06125

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
        num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
        embedding_dim (`int`, *optional*, defaults to 768): The dimension of the CLIP embeddings. Note that CLIP
            image embeddings and text embeddings are both the same dimension.
        num_embeddings (`int`, *optional*, defaults to 77): The max number of clip embeddings allowed. I.e. the
            length of the prompt after it has been tokenized.
        additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
            projected hidden_states. The actual length of the used hidden_states is `num_embeddings +
            additional_embeddings`.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.

    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 32,
        attention_head_dim: int = 64,
        num_layers: int = 20,
        embedding_dim: int = 768,
        num_embeddings=77,
        additional_embeddings=4,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim
        self.additional_embeddings = additional_embeddings

        self.time_proj = Timesteps(inner_dim, True, 0)
        self.time_embedding = TimestepEmbedding(inner_dim, inner_dim)

        self.proj_in = nn.Linear(embedding_dim, inner_dim)

        self.embedding_proj = nn.Linear(embedding_dim, inner_dim)
        self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)

        self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))

        self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    activation_fn="gelu",
                    attention_bias=True,
                )
                for d in range(num_layers)
            ]
        )

        self.norm_out = nn.LayerNorm(inner_dim)
        self.proj_to_clip_embeddings = nn.Linear(inner_dim, embedding_dim)

        causal_attention_mask = torch.full(
            [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
        )
        causal_attention_mask.triu_(1)
        causal_attention_mask = causal_attention_mask[None, ...]
        self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)

        self.clip_mean = nn.Parameter(torch.zeros(1, embedding_dim))
        self.clip_std = nn.Parameter(torch.zeros(1, embedding_dim))

    def forward(
        self,
        hidden_states,
        timestep: Union[torch.Tensor, float, int],
        proj_embedding: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.BoolTensor] = None,
        return_dict: bool = True,
    ):
        """
        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
                x_t, the currently predicted image embeddings.
            timestep (`torch.long`):
                Current denoising step.
            proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
                Projected embedding vector the denoising process is conditioned on.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
                Hidden states of the text embeddings the denoising process is conditioned on.
            attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
                Text mask for the text embeddings.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`models.prior_transformer.PriorTransformerOutput`] instead of a plain
                tuple.

        Returns:
            [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
            [`~models.prior_transformer.PriorTransformerOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.
        """
        batch_size = hidden_states.shape[0]

        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(hidden_states.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device)

        timesteps_projected = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might be fp16, so we need to cast here.
        timesteps_projected = timesteps_projected.to(dtype=self.dtype)
        time_embeddings = self.time_embedding(timesteps_projected)

        proj_embeddings = self.embedding_proj(proj_embedding)
        encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
        hidden_states = self.proj_in(hidden_states)
        prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
        positional_embeddings = self.positional_embedding.to(hidden_states.dtype)

        hidden_states = torch.cat(
            [
                encoder_hidden_states,
                proj_embeddings[:, None, :],
                time_embeddings[:, None, :],
                hidden_states[:, None, :],
                prd_embedding,
            ],
            dim=1,
        )

        hidden_states = hidden_states + positional_embeddings

        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
            attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
            attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)

        for block in self.transformer_blocks:
            hidden_states = block(hidden_states, attention_mask=attention_mask)

        hidden_states = self.norm_out(hidden_states)
        hidden_states = hidden_states[:, -1]
        predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)

        if not return_dict:
            return (predicted_image_embedding,)

        return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)

    def post_process_latents(self, prior_latents):
        prior_latents = (prior_latents * self.clip_std) + self.clip_mean
        return prior_latents
        diffusers/models/resnet.py
    ADDED
    
    | @@ -0,0 +1,839 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
# Copyright 2023 The HuggingFace Team. All rights reserved.
# `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .attention import AdaGroupNorm


class Upsample1D(nn.Module):
    """
    An upsampling layer with an optional convolution.

    Parameters:
        channels: channels in the inputs and outputs.
        use_conv: a bool determining if a convolution is applied.
        use_conv_transpose: a bool determining if a transposed convolution performs the upsampling.
        out_channels: the number of output channels; defaults to `channels`.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        self.conv = None
        if use_conv_transpose:
            self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.use_conv_transpose:
            return self.conv(x)

        x = F.interpolate(x, scale_factor=2.0, mode="nearest")

        if self.use_conv:
            x = self.conv(x)

        return x


class Downsample1D(nn.Module):
    """
    A downsampling layer with an optional convolution.

    Parameters:
        channels: channels in the inputs and outputs.
        use_conv: a bool determining if a convolution is applied.
        out_channels: the number of output channels; defaults to `channels`.
        padding: padding used by the downsampling convolution.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.conv(x)

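# --- Editor's note (not part of the upstream diffusers module) ---------------
# A minimal, hypothetical usage sketch for the 1D resampling layers above, added
# for illustration only. It shows the expected (batch, channels, length) shape
# changes; the helper name `_example_resample_1d` and the concrete sizes are
# assumptions, not upstream API.
def _example_resample_1d():
    x = torch.randn(2, 8, 16)              # (N, C, L)
    up = Upsample1D(8, use_conv=True)      # nearest 2x upsample followed by a 3-tap conv
    down = Downsample1D(8, use_conv=True)  # strided 3-tap conv halves the length
    y = up(x)                              # -> (2, 8, 32)
    z = down(y)                            # -> (2, 8, 16)
    return y.shape, z.shape
# -----------------------------------------------------------------------------
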
class Upsample2D(nn.Module):
    """
    An upsampling layer with an optional convolution.

    Parameters:
        channels: channels in the inputs and outputs.
        use_conv: a bool determining if a convolution is applied.
        use_conv_transpose: a bool determining if a transposed convolution performs the upsampling.
        out_channels: the number of output channels; defaults to `channels`.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        conv = None
        if use_conv_transpose:
            conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def forward(self, hidden_states, output_size=None):
        assert hidden_states.shape[1] == self.channels

        if self.use_conv_transpose:
            return self.conv(hidden_states)

        # Cast to float32 because the 'upsample_nearest2d_out_frame' op does not support bfloat16
        # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
        # https://github.com/pytorch/pytorch/issues/86679
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=2`
        if output_size is None:
            hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        else:
            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

        # If the input is bfloat16, we cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if self.use_conv:
            if self.name == "conv":
                hidden_states = self.conv(hidden_states)
            else:
                hidden_states = self.Conv2d_0(hidden_states)

        return hidden_states


class Downsample2D(nn.Module):
    """
    A downsampling layer with an optional convolution.

    Parameters:
        channels: channels in the inputs and outputs.
        use_conv: a bool determining if a convolution is applied.
        out_channels: the number of output channels; defaults to `channels`.
        padding: padding used by the downsampling convolution.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            conv = nn.AvgPool2d(kernel_size=stride, stride=stride)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.Conv2d_0 = conv
            self.conv = conv
        elif name == "Conv2d_0":
            self.conv = conv
        else:
            self.conv = conv

    def forward(self, hidden_states):
        assert hidden_states.shape[1] == self.channels
        if self.use_conv and self.padding == 0:
            pad = (0, 1, 0, 1)
            hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)

        assert hidden_states.shape[1] == self.channels
        hidden_states = self.conv(hidden_states)

        return hidden_states

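# --- Editor's note (not part of the upstream diffusers module) ---------------
# A minimal, hypothetical sketch of the 2D counterparts above, showing the
# default 2x resampling and the `output_size` override of Upsample2D. The
# helper name and sizes are illustrative assumptions.
def _example_resample_2d():
    x = torch.randn(1, 4, 8, 8)                     # (N, C, H, W)
    up = Upsample2D(4, use_conv=True)
    down = Downsample2D(4, use_conv=True, padding=1)
    y = up(x)                                       # -> (1, 4, 16, 16)
    y_odd = up(x, output_size=(17, 17))             # forced interpolation size -> (1, 4, 17, 17)
    z = down(x)                                     # strided 3x3 conv -> (1, 4, 4, 4)
    return y.shape, y_odd.shape, z.shape
# -----------------------------------------------------------------------------
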
class FirUpsample2D(nn.Module):
    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.use_conv = use_conv
        self.fir_kernel = fir_kernel
        self.out_channels = out_channels

    def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
        """Fused `upsample_2d()` followed by `Conv2d()`.

        Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
        efficient than performing the same calculation using standard PyTorch ops. It supports gradients of
        arbitrary order.

        Args:
            hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
            weight: Weight tensor of the shape `[filterH, filterW, inChannels,
                outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
            kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
                (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
            factor: Integer upsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
            datatype as `hidden_states`.
        """

        assert isinstance(factor, int) and factor >= 1

        # Setup filter kernel.
        if kernel is None:
            kernel = [1] * factor

        # setup kernel
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * (gain * (factor**2))

        if self.use_conv:
            convH = weight.shape[2]
            convW = weight.shape[3]
            inC = weight.shape[1]

            pad_value = (kernel.shape[0] - factor) - (convW - 1)

            stride = (factor, factor)
            # Determine data dimensions.
            output_shape = (
                (hidden_states.shape[2] - 1) * factor + convH,
                (hidden_states.shape[3] - 1) * factor + convW,
            )
            output_padding = (
                output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
                output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
            )
            assert output_padding[0] >= 0 and output_padding[1] >= 0
            num_groups = hidden_states.shape[1] // inC

            # Transpose weights.
            weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
            weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
            weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))

            inverse_conv = F.conv_transpose2d(
                hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
            )

            output = upfirdn2d_native(
                inverse_conv,
                torch.tensor(kernel, device=inverse_conv.device),
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
            )
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                torch.tensor(kernel, device=hidden_states.device),
                up=factor,
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
            )

        return output

    def forward(self, hidden_states):
        if self.use_conv:
            height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
            height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return height

class FirDownsample2D(nn.Module):
    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.fir_kernel = fir_kernel
        self.use_conv = use_conv
        self.out_channels = out_channels

    def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
        """Fused `Conv2d()` followed by `downsample_2d()`.

        Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
        efficient than performing the same calculation using standard PyTorch ops. It supports gradients of
        arbitrary order.

        Args:
            hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
            weight:
                Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
                performed by `inChannels = x.shape[0] // numGroups`.
            kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`,
                which corresponds to average pooling.
            factor: Integer downsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
            same datatype as `hidden_states`.
        """

        assert isinstance(factor, int) and factor >= 1
        if kernel is None:
            kernel = [1] * factor

        # setup kernel
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * gain

        if self.use_conv:
            _, _, convH, convW = weight.shape
            pad_value = (kernel.shape[0] - factor) + (convW - 1)
            stride_value = [factor, factor]
            upfirdn_input = upfirdn2d_native(
                hidden_states,
                torch.tensor(kernel, device=hidden_states.device),
                pad=((pad_value + 1) // 2, pad_value // 2),
            )
            output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                torch.tensor(kernel, device=hidden_states.device),
                down=factor,
                pad=((pad_value + 1) // 2, pad_value // 2),
            )

        return output

    def forward(self, hidden_states):
        if self.use_conv:
            downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
            hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return hidden_states

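# --- Editor's note (not part of the upstream diffusers module) ---------------
# A minimal, hypothetical sketch of the FIR resampling layers above in their
# filter-only configuration (use_conv=False), which avoids the fused learned
# convolution path. It assumes the `upfirdn2d_native` helper that these layers
# call is defined further down in this file; names and sizes are illustrative.
def _example_fir_resample_2d():
    x = torch.randn(1, 4, 8, 8)                 # (N, C, H, W)
    fir_up = FirUpsample2D(use_conv=False)      # FIR-filtered 2x upsampling
    fir_down = FirDownsample2D(use_conv=False)  # FIR-filtered 2x downsampling
    y = fir_up(x)                               # -> (1, 4, 16, 16)
    z = fir_down(x)                             # -> (1, 4, 4, 4)
    return y.shape, z.shape
# -----------------------------------------------------------------------------
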
# downsample/upsample layers used in the k-upscaler; might be able to use FirDownsample2D/FirUpsample2D instead
class KDownsample2D(nn.Module):
    def __init__(self, pad_mode="reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
        self.pad = kernel_1d.shape[1] // 2 - 1
        self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)

    def forward(self, x):
        x = F.pad(x, (self.pad,) * 4, self.pad_mode)
        weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
        indices = torch.arange(x.shape[1], device=x.device)
        weight[indices, indices] = self.kernel.to(weight)
        return F.conv2d(x, weight, stride=2)


class KUpsample2D(nn.Module):
    def __init__(self, pad_mode="reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
        self.pad = kernel_1d.shape[1] // 2 - 1
        self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)

    def forward(self, x):
        x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode)
        weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
        indices = torch.arange(x.shape[1], device=x.device)
        weight[indices, indices] = self.kernel.to(weight)
        return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1)

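# --- Editor's note (not part of the upstream diffusers module) ---------------
# A minimal, hypothetical sketch of the fixed-kernel k-upscaler layers above:
# both build a per-channel (block-diagonal) weight from the same 4x4 binomial
# filter, so no parameters are learned. Sizes are illustrative assumptions.
def _example_k_resample_2d():
    x = torch.randn(1, 4, 8, 8)
    y = KDownsample2D()(x)   # strided conv with the fixed filter -> (1, 4, 4, 4)
    z = KUpsample2D()(x)     # transposed conv with the (scaled) filter -> (1, 4, 16, 16)
    return y.shape, z.shape
# -----------------------------------------------------------------------------
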
class ResnetBlock2D(nn.Module):
    r"""
    A Resnet block.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `None`):
            The number of output channels for the first conv2d layer. If None, same as `in_channels`.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        temb_channels (`int`, *optional*, defaults to `512`): the number of channels in the timestep embedding.
        groups (`int`, *optional*, defaults to `32`): The number of groups to use for the first normalization layer.
        groups_out (`int`, *optional*, defaults to `None`):
            The number of groups to use for the second normalization layer. If set to None, same as `groups`.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
        non_linearity (`str`, *optional*, defaults to `"swish"`): the activation function to use.
        time_embedding_norm (`str`, *optional*, defaults to `"default"`): Time scale shift config.
            By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
            "ada_group" for a stronger conditioning with scale and shift.
        kernel (`torch.FloatTensor`, *optional*, defaults to `None`): FIR filter, see
            [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
        output_scale_factor (`float`, *optional*, defaults to `1.0`): the scale factor to use for the output.
        use_in_shortcut (`bool`, *optional*, defaults to `True`):
            If `True`, add a 1x1 nn.Conv2d layer for the skip connection.
        up (`bool`, *optional*, defaults to `False`): If `True`, add an upsample layer.
        down (`bool`, *optional*, defaults to `False`): If `True`, add a downsample layer.
        conv_shortcut_bias (`bool`, *optional*, defaults to `True`): If `True`, adds a learnable bias to the
            `conv_shortcut` output.
        conv_2d_out_channels (`int`, *optional*, defaults to `None`): the number of channels in the output.
            If None, same as `out_channels`.
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        time_embedding_norm="default",  # default, scale_shift, ada_group
        kernel=None,
        output_scale_factor=1.0,
        use_in_shortcut=None,
        up=False,
        down=False,
        conv_shortcut_bias: bool = True,
        conv_2d_out_channels: Optional[int] = None,
    ):
        super().__init__()
        self.pre_norm = pre_norm
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.up = up
        self.down = down
        self.output_scale_factor = output_scale_factor
        self.time_embedding_norm = time_embedding_norm

        if groups_out is None:
            groups_out = groups

        if self.time_embedding_norm == "ada_group":
            self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
        else:
            self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            if self.time_embedding_norm == "default":
                self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
            elif self.time_embedding_norm == "scale_shift":
                self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
            elif self.time_embedding_norm == "ada_group":
                self.time_emb_proj = None
            else:
                raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
        else:
            self.time_emb_proj = None

        if self.time_embedding_norm == "ada_group":
            self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
        else:
            self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)

        self.dropout = torch.nn.Dropout(dropout)
        conv_2d_out_channels = conv_2d_out_channels or out_channels
        self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)

        if non_linearity == "swish":
            self.nonlinearity = lambda x: F.silu(x)
        elif non_linearity == "mish":
            self.nonlinearity = nn.Mish()
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()
        elif non_linearity == "gelu":
            self.nonlinearity = nn.GELU()

        self.upsample = self.downsample = None
        if self.up:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
            else:
                self.upsample = Upsample2D(in_channels, use_conv=False)
        elif self.down:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
            else:
                self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")

        self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = torch.nn.Conv2d(
                in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
            )

    def forward(self, input_tensor, temb):
        hidden_states = input_tensor

        if self.time_embedding_norm == "ada_group":
            hidden_states = self.norm1(hidden_states, temb)
        else:
            hidden_states = self.norm1(hidden_states)

        hidden_states = self.nonlinearity(hidden_states)

        if self.upsample is not None:
            # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
            if hidden_states.shape[0] >= 64:
                input_tensor = input_tensor.contiguous()
                hidden_states = hidden_states.contiguous()
            input_tensor = self.upsample(input_tensor)
            hidden_states = self.upsample(hidden_states)
        elif self.downsample is not None:
            input_tensor = self.downsample(input_tensor)
            hidden_states = self.downsample(hidden_states)

        hidden_states = self.conv1(hidden_states)

        if self.time_emb_proj is not None:
            temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        if self.time_embedding_norm == "ada_group":
            hidden_states = self.norm2(hidden_states, temb)
        else:
            hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor

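# --- Editor's note (not part of the upstream diffusers module) ---------------
# A minimal, hypothetical sketch of ResnetBlock2D with the default "shift"
# timestep conditioning: the embedding is projected to the feature channels and
# broadcast-added after the first convolution, while a 1x1 shortcut matches the
# changed channel count. The helper name and sizes are illustrative assumptions.
def _example_resnet_block_2d():
    block = ResnetBlock2D(in_channels=32, out_channels=64, temb_channels=128, groups=8)
    hidden = torch.randn(2, 32, 16, 16)   # feature map (N, C, H, W)
    temb = torch.randn(2, 128)            # timestep embedding (N, temb_channels)
    out = block(hidden, temb)             # -> (2, 64, 16, 16)
    return out.shape
# -----------------------------------------------------------------------------
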
class Mish(torch.nn.Module):
    def forward(self, hidden_states):
        return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))


# unet_rl.py
def rearrange_dims(tensor):
    if len(tensor.shape) == 2:
        return tensor[:, :, None]
    if len(tensor.shape) == 3:
        return tensor[:, :, None, :]
    elif len(tensor.shape) == 4:
        return tensor[:, :, 0, :]
    else:
        raise ValueError(f"`len(tensor.shape)`: {len(tensor.shape)} has to be 2, 3 or 4.")


class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.group_norm = nn.GroupNorm(n_groups, out_channels)
        self.mish = nn.Mish()

    def forward(self, x):
        x = self.conv1d(x)
        x = rearrange_dims(x)
        x = self.group_norm(x)
        x = rearrange_dims(x)
        x = self.mish(x)
        return x

# unet_rl.py
class ResidualTemporalBlock1D(nn.Module):
    def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
        super().__init__()
        self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
        self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)

        self.time_emb_act = nn.Mish()
        self.time_emb = nn.Linear(embed_dim, out_channels)

        self.residual_conv = (
            nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
        )

    def forward(self, x, t):
        """
        Args:
            x : [ batch_size x inp_channels x horizon ]
            t : [ batch_size x embed_dim ]

        returns:
            out : [ batch_size x out_channels x horizon ]
        """
        t = self.time_emb_act(t)
        t = self.time_emb(t)
        out = self.conv_in(x) + rearrange_dims(t)
        out = self.conv_out(out)
        return out + self.residual_conv(x)

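# --- Editor's note (not part of the upstream diffusers module) ---------------
# A minimal, hypothetical sketch of the 1D temporal blocks above: Conv1dBlock
# keeps the horizon length, and ResidualTemporalBlock1D projects the time
# embedding to the output channels (via `rearrange_dims`) and broadcasts it
# over the horizon. The helper name and sizes are illustrative assumptions.
def _example_residual_temporal_block_1d():
    block = ResidualTemporalBlock1D(inp_channels=8, out_channels=16, embed_dim=32)
    x = torch.randn(2, 8, 24)   # (batch, channels, horizon)
    t = torch.randn(2, 32)      # (batch, embed_dim)
    out = block(x, t)           # -> (2, 16, 24)
    return out.shape
# -----------------------------------------------------------------------------
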
| 668 | 
            +
            def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
         | 
| 669 | 
            +
                r"""Upsample2D a batch of 2D images with the given filter.
         | 
| 670 | 
            +
                Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
         | 
| 671 | 
            +
                filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
         | 
| 672 | 
            +
                `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
         | 
| 673 | 
            +
                a: multiple of the upsampling factor.
         | 
| 674 | 
            +
             | 
| 675 | 
            +
                Args:
         | 
| 676 | 
            +
                    hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
         | 
| 677 | 
            +
                    kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
         | 
| 678 | 
            +
                      (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
         | 
| 679 | 
            +
                    factor: Integer upsampling factor (default: 2).
         | 
| 680 | 
            +
                    gain: Scaling factor for signal magnitude (default: 1.0).
         | 
| 681 | 
            +
             | 
| 682 | 
            +
                Returns:
         | 
| 683 | 
            +
                    output: Tensor of the shape `[N, C, H * factor, W * factor]`
         | 
| 684 | 
            +
                """
         | 
| 685 | 
            +
                assert isinstance(factor, int) and factor >= 1
         | 
| 686 | 
            +
                if kernel is None:
         | 
| 687 | 
            +
                    kernel = [1] * factor
         | 
| 688 | 
            +
             | 
| 689 | 
            +
                kernel = torch.tensor(kernel, dtype=torch.float32)
         | 
| 690 | 
            +
                if kernel.ndim == 1:
         | 
| 691 | 
            +
                    kernel = torch.outer(kernel, kernel)
         | 
| 692 | 
            +
                kernel /= torch.sum(kernel)
         | 
| 693 | 
            +
             | 
| 694 | 
            +
                kernel = kernel * (gain * (factor**2))
         | 
| 695 | 
            +
                pad_value = kernel.shape[0] - factor
         | 
| 696 | 
            +
                output = upfirdn2d_native(
         | 
| 697 | 
            +
                    hidden_states,
         | 
| 698 | 
            +
                    kernel.to(device=hidden_states.device),
         | 
| 699 | 
            +
                    up=factor,
         | 
| 700 | 
            +
                    pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
         | 
| 701 | 
            +
                )
         | 
| 702 | 
            +
                return output
         | 
| 703 | 
            +
             | 
| 704 | 
            +
             | 
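A shape-only sketch of `upsample_2d` with the default kernel, which the docstring above describes as nearest-neighbor upsampling (the sizes are illustrative):

```python
import torch

x = torch.randn(1, 4, 8, 8)      # [N, C, H, W]
y = upsample_2d(x, factor=2)     # default kernel [1] * factor -> nearest-neighbor upsampling
print(y.shape)                   # torch.Size([1, 4, 16, 16])
```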
| 705 | 
            +
            def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
         | 
| 706 | 
            +
                r"""Downsample2D a batch of 2D images with the given filter.
         | 
| 707 | 
            +
                Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
         | 
| 708 | 
            +
                given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
         | 
| 709 | 
            +
                specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
         | 
| 710 | 
            +
                shape is a multiple of the downsampling factor.
         | 
| 711 | 
            +
             | 
| 712 | 
            +
                Args:
         | 
| 713 | 
            +
                    hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
         | 
| 714 | 
            +
                    kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
         | 
| 715 | 
            +
                      (separable). The default is `[1] * factor`, which corresponds to average pooling.
         | 
| 716 | 
            +
                    factor: Integer downsampling factor (default: 2).
         | 
| 717 | 
            +
                    gain: Scaling factor for signal magnitude (default: 1.0).
         | 
| 718 | 
            +
             | 
| 719 | 
            +
                Returns:
         | 
| 720 | 
            +
                    output: Tensor of the shape `[N, C, H // factor, W // factor]`
         | 
| 721 | 
            +
                """
         | 
| 722 | 
            +
             | 
| 723 | 
            +
                assert isinstance(factor, int) and factor >= 1
         | 
| 724 | 
            +
                if kernel is None:
         | 
| 725 | 
            +
                    kernel = [1] * factor
         | 
| 726 | 
            +
             | 
| 727 | 
            +
                kernel = torch.tensor(kernel, dtype=torch.float32)
         | 
| 728 | 
            +
                if kernel.ndim == 1:
         | 
| 729 | 
            +
                    kernel = torch.outer(kernel, kernel)
         | 
| 730 | 
            +
                kernel /= torch.sum(kernel)
         | 
| 731 | 
            +
             | 
| 732 | 
            +
                kernel = kernel * gain
         | 
| 733 | 
            +
                pad_value = kernel.shape[0] - factor
         | 
| 734 | 
            +
                output = upfirdn2d_native(
         | 
| 735 | 
            +
                    hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
         | 
| 736 | 
            +
                )
         | 
| 737 | 
            +
                return output
         | 
| 738 | 
            +
             | 
| 739 | 
            +
             | 
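The mirror-image sketch for `downsample_2d`: with the default kernel it behaves like average pooling with stride `factor` (again, sizes are illustrative):

```python
import torch

x = torch.randn(1, 4, 16, 16)     # [N, C, H, W]
y = downsample_2d(x, factor=2)    # default kernel -> 2x2 average pooling with stride 2
print(y.shape)                    # torch.Size([1, 4, 8, 8])
```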
| 740 | 
            +
            def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
         | 
| 741 | 
            +
                up_x = up_y = up
         | 
| 742 | 
            +
                down_x = down_y = down
         | 
| 743 | 
            +
                pad_x0 = pad_y0 = pad[0]
         | 
| 744 | 
            +
                pad_x1 = pad_y1 = pad[1]
         | 
| 745 | 
            +
             | 
| 746 | 
            +
                _, channel, in_h, in_w = tensor.shape
         | 
| 747 | 
            +
                tensor = tensor.reshape(-1, in_h, in_w, 1)
         | 
| 748 | 
            +
             | 
| 749 | 
            +
                _, in_h, in_w, minor = tensor.shape
         | 
| 750 | 
            +
                kernel_h, kernel_w = kernel.shape
         | 
| 751 | 
            +
             | 
| 752 | 
            +
                out = tensor.view(-1, in_h, 1, in_w, 1, minor)
         | 
| 753 | 
            +
                out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
         | 
| 754 | 
            +
                out = out.view(-1, in_h * up_y, in_w * up_x, minor)
         | 
| 755 | 
            +
             | 
| 756 | 
            +
                out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
         | 
| 757 | 
            +
                out = out.to(tensor.device)  # Move back to mps if necessary
         | 
| 758 | 
            +
                out = out[
         | 
| 759 | 
            +
                    :,
         | 
| 760 | 
            +
                    max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
         | 
| 761 | 
            +
                    max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
         | 
| 762 | 
            +
                    :,
         | 
| 763 | 
            +
                ]
         | 
| 764 | 
            +
             | 
| 765 | 
            +
                out = out.permute(0, 3, 1, 2)
         | 
| 766 | 
            +
                out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
         | 
| 767 | 
            +
                w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
         | 
| 768 | 
            +
                out = F.conv2d(out, w)
         | 
| 769 | 
            +
                out = out.reshape(
         | 
| 770 | 
            +
                    -1,
         | 
| 771 | 
            +
                    minor,
         | 
| 772 | 
            +
                    in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
         | 
| 773 | 
            +
                    in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
         | 
| 774 | 
            +
                )
         | 
| 775 | 
            +
                out = out.permute(0, 2, 3, 1)
         | 
| 776 | 
            +
                out = out[:, ::down_y, ::down_x, :]
         | 
| 777 | 
            +
             | 
| 778 | 
            +
                out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
         | 
| 779 | 
            +
                out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
         | 
| 780 | 
            +
             | 
| 781 | 
            +
                return out.view(-1, channel, out_h, out_w)
         | 
| 782 | 
            +
             | 
| 783 | 
            +
             | 
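`upfirdn2d_native` follows the usual upsample-by-zero-insertion, pad, FIR-filter, then downsample recipe; its output spatial size obeys `out_h = (in_h * up + pad[0] + pad[1] - kernel_h) // down + 1` (and analogously for the width), which is exactly what the final `view` assumes. A shape-only sketch with an arbitrary 2x2 box filter:

```python
import torch

x = torch.randn(1, 3, 8, 8)
k = torch.ones(2, 2) / 4.0                     # arbitrary 2x2 box filter
y = upfirdn2d_native(x, k, up=2, pad=(1, 0))   # (8*2 + 1 + 0 - 2) // 1 + 1 = 16
print(y.shape)                                 # torch.Size([1, 3, 16, 16])
```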
| 784 | 
            +
            class TemporalConvLayer(nn.Module):
         | 
| 785 | 
            +
                """
         | 
| 786 | 
            +
                Temporal convolutional layer that can be used for video (sequence of images) input. Code mostly copied from:
         | 
| 787 | 
            +
                https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
         | 
| 788 | 
            +
                """
         | 
| 789 | 
            +
             | 
| 790 | 
            +
                def __init__(self, in_dim, out_dim=None, dropout=0.0):
         | 
| 791 | 
            +
                    super().__init__()
         | 
| 792 | 
            +
                    out_dim = out_dim or in_dim
         | 
| 793 | 
            +
                    self.in_dim = in_dim
         | 
| 794 | 
            +
                    self.out_dim = out_dim
         | 
| 795 | 
            +
             | 
| 796 | 
            +
                    # conv layers
         | 
| 797 | 
            +
                    self.conv1 = nn.Sequential(
         | 
| 798 | 
            +
                        nn.GroupNorm(32, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0))
         | 
| 799 | 
            +
                    )
         | 
| 800 | 
            +
                    self.conv2 = nn.Sequential(
         | 
| 801 | 
            +
                        nn.GroupNorm(32, out_dim),
         | 
| 802 | 
            +
                        nn.SiLU(),
         | 
| 803 | 
            +
                        nn.Dropout(dropout),
         | 
| 804 | 
            +
                        nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
         | 
| 805 | 
            +
                    )
         | 
| 806 | 
            +
                    self.conv3 = nn.Sequential(
         | 
| 807 | 
            +
                        nn.GroupNorm(32, out_dim),
         | 
| 808 | 
            +
                        nn.SiLU(),
         | 
| 809 | 
            +
                        nn.Dropout(dropout),
         | 
| 810 | 
            +
                        nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
         | 
| 811 | 
            +
                    )
         | 
| 812 | 
            +
                    self.conv4 = nn.Sequential(
         | 
| 813 | 
            +
                        nn.GroupNorm(32, out_dim),
         | 
| 814 | 
            +
                        nn.SiLU(),
         | 
| 815 | 
            +
                        nn.Dropout(dropout),
         | 
| 816 | 
            +
                        nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
         | 
| 817 | 
            +
                    )
         | 
| 818 | 
            +
             | 
| 819 | 
            +
                # zero out the last layer params, so the conv block is identity
         | 
| 820 | 
            +
                    nn.init.zeros_(self.conv4[-1].weight)
         | 
| 821 | 
            +
                    nn.init.zeros_(self.conv4[-1].bias)
         | 
| 822 | 
            +
             | 
| 823 | 
            +
                def forward(self, hidden_states, num_frames=1):
         | 
| 824 | 
            +
                    hidden_states = (
         | 
| 825 | 
            +
                        hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4)
         | 
| 826 | 
            +
                    )
         | 
| 827 | 
            +
             | 
| 828 | 
            +
                    identity = hidden_states
         | 
| 829 | 
            +
                    hidden_states = self.conv1(hidden_states)
         | 
| 830 | 
            +
                    hidden_states = self.conv2(hidden_states)
         | 
| 831 | 
            +
                    hidden_states = self.conv3(hidden_states)
         | 
| 832 | 
            +
                    hidden_states = self.conv4(hidden_states)
         | 
| 833 | 
            +
             | 
| 834 | 
            +
                    hidden_states = identity + hidden_states
         | 
| 835 | 
            +
             | 
| 836 | 
            +
                    hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(
         | 
| 837 | 
            +
                        (hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
         | 
| 838 | 
            +
                    )
         | 
| 839 | 
            +
                    return hidden_states
         | 
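`TemporalConvLayer.forward` expects the frames of a clip flattened into the batch dimension and returns a tensor of the same shape after mixing information across frames. A minimal sketch (batch size, frame count, and spatial size are arbitrary; `in_dim` must be divisible by the GroupNorm group count of 32):

```python
import torch

layer = TemporalConvLayer(in_dim=32)
frames = torch.randn(2 * 8, 32, 16, 16)   # (batch * num_frames, C, H, W)
out = layer(frames, num_frames=8)
print(out.shape)                          # torch.Size([16, 32, 16, 16]), same shape as the input
```

Because `conv4` is zero-initialized, the layer starts out as an identity mapping and only learns temporal mixing during training.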
    	
        diffusers/models/transformer_2d.py
    ADDED
    
    | @@ -0,0 +1,333 @@ | |
| 1 | 
            +
            # Copyright 2023 The HuggingFace Team. All rights reserved.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 4 | 
            +
            # you may not use this file except in compliance with the License.
         | 
| 5 | 
            +
            # You may obtain a copy of the License at
         | 
| 6 | 
            +
            #
         | 
| 7 | 
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 8 | 
            +
            #
         | 
| 9 | 
            +
            # Unless required by applicable law or agreed to in writing, software
         | 
| 10 | 
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 11 | 
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 12 | 
            +
            # See the License for the specific language governing permissions and
         | 
| 13 | 
            +
            # limitations under the License.
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            from dataclasses import dataclass
         | 
| 16 | 
            +
            from typing import Any, Dict, Optional
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            import torch
         | 
| 19 | 
            +
            import torch.nn.functional as F
         | 
| 20 | 
            +
            from torch import nn
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            from ..utils.configuration_utils import ConfigMixin, register_to_config
         | 
| 23 | 
            +
            from ..utils.outputs import BaseOutput
         | 
| 24 | 
            +
            from ..utils.deprecation_utils import deprecate
         | 
| 25 | 
            +
            from ..models.embeddings import ImagePositionalEmbeddings
         | 
| 26 | 
            +
            from .attention import BasicTransformerBlock
         | 
| 27 | 
            +
            from .embeddings import PatchEmbed
         | 
| 28 | 
            +
            from .modeling_utils import ModelMixin
         | 
| 29 | 
            +
             | 
| 30 | 
            +
             | 
| 31 | 
            +
            @dataclass
         | 
| 32 | 
            +
            class Transformer2DModelOutput(BaseOutput):
         | 
| 33 | 
            +
                """
         | 
| 34 | 
            +
                Args:
         | 
| 35 | 
            +
                    sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or 
         | 
| 36 | 
            +
                        `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
         | 
| 37 | 
            +
                        Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions
         | 
| 38 | 
            +
                        for the unnoised latent pixels.
         | 
| 39 | 
            +
                """
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                sample: torch.FloatTensor
         | 
| 42 | 
            +
             | 
| 43 | 
            +
             | 
| 44 | 
            +
            class Transformer2DModel(ModelMixin, ConfigMixin):
         | 
| 45 | 
            +
                """
         | 
| 46 | 
            +
                Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual
         | 
| 47 | 
            +
                embeddings) inputs.
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
         | 
| 50 | 
            +
                transformer action. Finally, reshape to image.
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
         | 
| 53 | 
            +
                embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
         | 
| 54 | 
            +
                classes of unnoised image.
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
         | 
| 57 | 
            +
                image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                Parameters:
         | 
| 60 | 
            +
                    num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
         | 
| 61 | 
            +
                    attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
         | 
| 62 | 
            +
                    in_channels (`int`, *optional*):
         | 
| 63 | 
            +
                        Pass if the input is continuous. The number of channels in the input and output.
         | 
| 64 | 
            +
                    num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
         | 
| 65 | 
            +
                    dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
         | 
| 66 | 
            +
                    cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
         | 
| 67 | 
            +
                    sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
         | 
| 68 | 
            +
                        Note that this is fixed at training time as it is used for learning a number of position embeddings.
         | 
| 69 | 
            +
                        See `ImagePositionalEmbeddings`.
         | 
| 70 | 
            +
                    num_vector_embeds (`int`, *optional*):
         | 
| 71 | 
            +
                        Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
         | 
| 72 | 
            +
                        Includes the class for the masked latent pixel.
         | 
| 73 | 
            +
                    activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
         | 
| 74 | 
            +
                    num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
         | 
| 75 | 
            +
                        The number of diffusion steps used during training. Note that this is fixed at training time as it is
         | 
| 76 | 
            +
                        used to learn a number of embeddings that are added to the hidden states. During inference, you can
         | 
| 77 | 
            +
                        denoise up to, but not more than, `num_embeds_ada_norm` steps.
         | 
| 78 | 
            +
                    attention_bias (`bool`, *optional*):
         | 
| 79 | 
            +
                        Configure if the TransformerBlocks' attention should contain a bias parameter.
         | 
| 80 | 
            +
                """
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                @register_to_config
         | 
| 83 | 
            +
                def __init__(
         | 
| 84 | 
            +
                    self,
         | 
| 85 | 
            +
                    num_attention_heads: int = 16,
         | 
| 86 | 
            +
                    attention_head_dim: int = 88,
         | 
| 87 | 
            +
                    in_channels: Optional[int] = None,
         | 
| 88 | 
            +
                    out_channels: Optional[int] = None,
         | 
| 89 | 
            +
                    num_layers: int = 1,
         | 
| 90 | 
            +
                    dropout: float = 0.0,
         | 
| 91 | 
            +
                    norm_num_groups: int = 32,
         | 
| 92 | 
            +
                    cross_attention_dim: Optional[int] = None,
         | 
| 93 | 
            +
                    attention_bias: bool = False,
         | 
| 94 | 
            +
                    sample_size: Optional[int] = None,
         | 
| 95 | 
            +
                    num_vector_embeds: Optional[int] = None,
         | 
| 96 | 
            +
                    patch_size: Optional[int] = None,
         | 
| 97 | 
            +
                    activation_fn: str = "geglu",
         | 
| 98 | 
            +
                    num_embeds_ada_norm: Optional[int] = None,
         | 
| 99 | 
            +
                    use_linear_projection: bool = False,
         | 
| 100 | 
            +
                    only_cross_attention: bool = False,
         | 
| 101 | 
            +
                    upcast_attention: bool = False,
         | 
| 102 | 
            +
                    norm_type: str = "layer_norm",
         | 
| 103 | 
            +
                    norm_elementwise_affine: bool = True,
         | 
| 104 | 
            +
                ):
         | 
| 105 | 
            +
                    super().__init__()
         | 
| 106 | 
            +
                    self.use_linear_projection = use_linear_projection
         | 
| 107 | 
            +
                    self.num_attention_heads = num_attention_heads
         | 
| 108 | 
            +
                    self.attention_head_dim = attention_head_dim
         | 
| 109 | 
            +
                    inner_dim = num_attention_heads * attention_head_dim
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                    # 1. Transformer2DModel can process both standard continuous images of
         | 
| 112 | 
            +
                    #   shape `(batch_size, num_channels, width, height)` as well as
         | 
| 113 | 
            +
                    #   quantized image embeddings of shape `(batch_size, num_image_vectors)`
         | 
| 114 | 
            +
                    # Define whether input is continuous or discrete depending on configuration
         | 
| 115 | 
            +
                    self.is_input_continuous = (in_channels is not None) and (patch_size is None)
         | 
| 116 | 
            +
                    self.is_input_vectorized = num_vector_embeds is not None
         | 
| 117 | 
            +
                    self.is_input_patches = in_channels is not None and patch_size is not None
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                    if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
         | 
| 120 | 
            +
                        deprecation_message = (
         | 
| 121 | 
            +
                            f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
         | 
| 122 | 
            +
                            " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
         | 
| 123 | 
            +
                            " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
         | 
| 124 | 
            +
                            " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
         | 
| 125 | 
            +
                            " would be very nice if you could open a Pull request for the `transformer/config.json` file"
         | 
| 126 | 
            +
                        )
         | 
| 127 | 
            +
                        deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
         | 
| 128 | 
            +
                        norm_type = "ada_norm"
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                    if self.is_input_continuous and self.is_input_vectorized:
         | 
| 131 | 
            +
                        raise ValueError(
         | 
| 132 | 
            +
                            f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
         | 
| 133 | 
            +
                            " sure that either `in_channels` or `num_vector_embeds` is None."
         | 
| 134 | 
            +
                        )
         | 
| 135 | 
            +
                    elif self.is_input_vectorized and self.is_input_patches:
         | 
| 136 | 
            +
                        raise ValueError(
         | 
| 137 | 
            +
                            f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
         | 
| 138 | 
            +
                            " sure that either `num_vector_embeds` or `num_patches` is None."
         | 
| 139 | 
            +
                        )
         | 
| 140 | 
            +
                    elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
         | 
| 141 | 
            +
                        raise ValueError(
         | 
| 142 | 
            +
                            f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
         | 
| 143 | 
            +
                            f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
         | 
| 144 | 
            +
                        )
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                    # 2. Define input layers
         | 
| 147 | 
            +
                    if self.is_input_continuous:
         | 
| 148 | 
            +
                        self.in_channels = in_channels
         | 
| 149 | 
            +
             | 
| 150 | 
            +
                        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
         | 
| 151 | 
            +
                        if use_linear_projection:
         | 
| 152 | 
            +
                            self.proj_in = nn.Linear(in_channels, inner_dim)
         | 
| 153 | 
            +
                        else:
         | 
| 154 | 
            +
                            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
         | 
| 155 | 
            +
                    elif self.is_input_vectorized:
         | 
| 156 | 
            +
                        assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
         | 
| 157 | 
            +
                        assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                        self.height = sample_size
         | 
| 160 | 
            +
                        self.width = sample_size
         | 
| 161 | 
            +
                        self.num_vector_embeds = num_vector_embeds
         | 
| 162 | 
            +
                        self.num_latent_pixels = self.height * self.width
         | 
| 163 | 
            +
             | 
| 164 | 
            +
                        self.latent_image_embedding = ImagePositionalEmbeddings(
         | 
| 165 | 
            +
                            num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
         | 
| 166 | 
            +
                        )
         | 
| 167 | 
            +
                    elif self.is_input_patches:
         | 
| 168 | 
            +
                        assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
         | 
| 169 | 
            +
             | 
| 170 | 
            +
                        self.height = sample_size
         | 
| 171 | 
            +
                        self.width = sample_size
         | 
| 172 | 
            +
             | 
| 173 | 
            +
                        self.patch_size = patch_size
         | 
| 174 | 
            +
                        self.pos_embed = PatchEmbed(
         | 
| 175 | 
            +
                            height=sample_size,
         | 
| 176 | 
            +
                            width=sample_size,
         | 
| 177 | 
            +
                            patch_size=patch_size,
         | 
| 178 | 
            +
                            in_channels=in_channels,
         | 
| 179 | 
            +
                            embed_dim=inner_dim,
         | 
| 180 | 
            +
                        )
         | 
| 181 | 
            +
             | 
| 182 | 
            +
                    # 3. Define transformers blocks
         | 
| 183 | 
            +
                    self.transformer_blocks = nn.ModuleList(
         | 
| 184 | 
            +
                        [
         | 
| 185 | 
            +
                            BasicTransformerBlock(
         | 
| 186 | 
            +
                                inner_dim,
         | 
| 187 | 
            +
                                num_attention_heads,
         | 
| 188 | 
            +
                                attention_head_dim,
         | 
| 189 | 
            +
                                dropout=dropout,
         | 
| 190 | 
            +
                                cross_attention_dim=cross_attention_dim,
         | 
| 191 | 
            +
                                activation_fn=activation_fn,
         | 
| 192 | 
            +
                                num_embeds_ada_norm=num_embeds_ada_norm,
         | 
| 193 | 
            +
                                attention_bias=attention_bias,
         | 
| 194 | 
            +
                                only_cross_attention=only_cross_attention,
         | 
| 195 | 
            +
                                upcast_attention=upcast_attention,
         | 
| 196 | 
            +
                                norm_type=norm_type,
         | 
| 197 | 
            +
                                norm_elementwise_affine=norm_elementwise_affine,
         | 
| 198 | 
            +
                            )
         | 
| 199 | 
            +
                            for d in range(num_layers)
         | 
| 200 | 
            +
                        ]
         | 
| 201 | 
            +
                    )
         | 
| 202 | 
            +
             | 
| 203 | 
            +
                    # 4. Define output layers
         | 
| 204 | 
            +
                    self.out_channels = in_channels if out_channels is None else out_channels
         | 
| 205 | 
            +
                    if self.is_input_continuous:
         | 
| 206 | 
            +
                        # TODO: should use out_channels for continuous projections
         | 
| 207 | 
            +
                        if use_linear_projection:
         | 
| 208 | 
            +
                            self.proj_out = nn.Linear(inner_dim, in_channels)
         | 
| 209 | 
            +
                        else:
         | 
| 210 | 
            +
                            self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
         | 
| 211 | 
            +
                    elif self.is_input_vectorized:
         | 
| 212 | 
            +
                        self.norm_out = nn.LayerNorm(inner_dim)
         | 
| 213 | 
            +
                        self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
         | 
| 214 | 
            +
                    elif self.is_input_patches:
         | 
| 215 | 
            +
                        self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
         | 
| 216 | 
            +
                        self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
         | 
| 217 | 
            +
                        self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
         | 
| 218 | 
            +
             | 
| 219 | 
            +
                def forward(
         | 
| 220 | 
            +
                    self,
         | 
| 221 | 
            +
                    hidden_states: torch.Tensor,
         | 
| 222 | 
            +
                    encoder_hidden_states: Optional[torch.Tensor] = None,
         | 
| 223 | 
            +
                    timestep: Optional[torch.LongTensor] = None,
         | 
| 224 | 
            +
                    class_labels: Optional[torch.LongTensor] = None,
         | 
| 225 | 
            +
                    cross_attention_kwargs: Dict[str, Any] = None,
         | 
| 226 | 
            +
                    attention_mask: Optional[torch.Tensor] = None,
         | 
| 227 | 
            +
                    encoder_attention_mask: Optional[torch.Tensor] = None,
         | 
| 228 | 
            +
                    return_dict: bool = True,
         | 
| 229 | 
            +
                ):
         | 
| 230 | 
            +
                    """
         | 
| 231 | 
            +
                    Args:
         | 
| 232 | 
            +
                        hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
         | 
| 233 | 
            +
                            When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
         | 
| 234 | 
            +
                            hidden_states
         | 
| 235 | 
            +
                        encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
         | 
| 236 | 
            +
                            Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
         | 
| 237 | 
            +
                            self-attention.
         | 
| 238 | 
            +
                        timestep ( `torch.LongTensor`, *optional*):
         | 
| 239 | 
            +
                            Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
         | 
| 240 | 
            +
                        class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
         | 
| 241 | 
            +
                        Optional class labels to be applied as an embedding in AdaLayerNormZero. Used to indicate class 
         | 
| 242 | 
            +
                            labels conditioning.
         | 
| 243 | 
            +
                        attention_mask ( `torch.Tensor` of shape (batch size, num latent pixels), *optional* ).
         | 
| 244 | 
            +
                            Bias to add to attention scores.
         | 
| 245 | 
            +
                        encoder_attention_mask ( `torch.Tensor` of shape (batch size, num encoder tokens), *optional* ).
         | 
| 246 | 
            +
                            Bias to add to cross-attention scores.
         | 
| 247 | 
            +
                        return_dict (`bool`, *optional*, defaults to `True`):
         | 
| 248 | 
            +
                            Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
         | 
| 249 | 
            +
             | 
| 250 | 
            +
                    Returns:
         | 
| 251 | 
            +
                        [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
         | 
| 252 | 
            +
                        [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. 
         | 
| 253 | 
            +
                        When returning a tuple, the first element is the sample tensor.
         | 
| 254 | 
            +
                    """
         | 
| 255 | 
            +
                    # 1. Input
         | 
| 256 | 
            +
                    if self.is_input_continuous:
         | 
| 257 | 
            +
                        batch, _, height, width = hidden_states.shape
         | 
| 258 | 
            +
                        residual = hidden_states
         | 
| 259 | 
            +
             | 
| 260 | 
            +
                        hidden_states = self.norm(hidden_states)
         | 
| 261 | 
            +
                        if not self.use_linear_projection:
         | 
| 262 | 
            +
                            hidden_states = self.proj_in(hidden_states)
         | 
| 263 | 
            +
                            inner_dim = hidden_states.shape[1]
         | 
| 264 | 
            +
                            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
         | 
| 265 | 
            +
                        else:
         | 
| 266 | 
            +
                            inner_dim = hidden_states.shape[1]
         | 
| 267 | 
            +
                            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
         | 
| 268 | 
            +
                            hidden_states = self.proj_in(hidden_states)
         | 
| 269 | 
            +
             | 
| 270 | 
            +
                    elif self.is_input_vectorized:
         | 
| 271 | 
            +
                        hidden_states = self.latent_image_embedding(hidden_states)
         | 
| 272 | 
            +
             | 
| 273 | 
            +
                    elif self.is_input_patches:
         | 
| 274 | 
            +
                        hidden_states = self.pos_embed(hidden_states)
         | 
| 275 | 
            +
             | 
| 276 | 
            +
                    # 2. Blocks
         | 
| 277 | 
            +
                    for block in self.transformer_blocks:
         | 
| 278 | 
            +
                        hidden_states = block(
         | 
| 279 | 
            +
                            hidden_states,
         | 
| 280 | 
            +
                            attention_mask=attention_mask,
         | 
| 281 | 
            +
                            encoder_hidden_states=encoder_hidden_states,
         | 
| 282 | 
            +
                            encoder_attention_mask=encoder_attention_mask,
         | 
| 283 | 
            +
                            timestep=timestep,
         | 
| 284 | 
            +
                            cross_attention_kwargs=cross_attention_kwargs,
         | 
| 285 | 
            +
                            class_labels=class_labels,
         | 
| 286 | 
            +
                        )
         | 
| 287 | 
            +
             | 
| 288 | 
            +
                    # 3. Output
         | 
| 289 | 
            +
                    if self.is_input_continuous:
         | 
| 290 | 
            +
                        if not self.use_linear_projection:
         | 
| 291 | 
            +
                            hidden_states = hidden_states.reshape(
         | 
| 292 | 
            +
                                batch, height, width, inner_dim
         | 
| 293 | 
            +
                            ).permute(0, 3, 1, 2).contiguous()
         | 
| 294 | 
            +
                            hidden_states = self.proj_out(hidden_states)
         | 
| 295 | 
            +
                        else:
         | 
| 296 | 
            +
                            hidden_states = self.proj_out(hidden_states)
         | 
| 297 | 
            +
                            hidden_states = hidden_states.reshape(
         | 
| 298 | 
            +
                                batch, height, width, inner_dim
         | 
| 299 | 
            +
                            ).permute(0, 3, 1, 2).contiguous()
         | 
| 300 | 
            +
                        output = hidden_states + residual
         | 
| 301 | 
            +
             | 
| 302 | 
            +
                    elif self.is_input_vectorized:
         | 
| 303 | 
            +
                        hidden_states = self.norm_out(hidden_states)
         | 
| 304 | 
            +
                        logits = self.out(hidden_states)
         | 
| 305 | 
            +
                        # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
         | 
| 306 | 
            +
                        logits = logits.permute(0, 2, 1)
         | 
| 307 | 
            +
             | 
| 308 | 
            +
                        # log(p(x_0))
         | 
| 309 | 
            +
                        output = F.log_softmax(logits.double(), dim=1).float()
         | 
| 310 | 
            +
             | 
| 311 | 
            +
                    elif self.is_input_patches:
         | 
| 312 | 
            +
                        # TODO: cleanup!
         | 
| 313 | 
            +
                        conditioning = self.transformer_blocks[0].norm1.emb(
         | 
| 314 | 
            +
                            timestep, class_labels, hidden_dtype=hidden_states.dtype
         | 
| 315 | 
            +
                        )
         | 
| 316 | 
            +
                        shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
         | 
| 317 | 
            +
                        hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
         | 
| 318 | 
            +
                        hidden_states = self.proj_out_2(hidden_states)
         | 
| 319 | 
            +
             | 
| 320 | 
            +
                        # unpatchify
         | 
| 321 | 
            +
                        height = width = int(hidden_states.shape[1] ** 0.5)
         | 
| 322 | 
            +
                        hidden_states = hidden_states.reshape(
         | 
| 323 | 
            +
                            shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
         | 
| 324 | 
            +
                        )
         | 
| 325 | 
            +
                        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
         | 
| 326 | 
            +
                        output = hidden_states.reshape(
         | 
| 327 | 
            +
                            shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
         | 
| 328 | 
            +
                        )
         | 
| 329 | 
            +
             | 
| 330 | 
            +
                    if not return_dict:
         | 
| 331 | 
            +
                        return (output,)
         | 
| 332 | 
            +
             | 
| 333 | 
            +
                    return Transformer2DModelOutput(sample=output)
         | 
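A minimal sketch of the continuous-input path, the mode selected when `in_channels` is given and both `patch_size` and `num_vector_embeds` are left as `None`. All sizes below are illustrative, and the example assumes the bundled `BasicTransformerBlock` behaves like its upstream diffusers counterpart:

```python
import torch

model = Transformer2DModel(
    num_attention_heads=8,
    attention_head_dim=32,       # inner_dim = 8 * 32 = 256
    in_channels=256,             # continuous mode: in_channels set, patch_size left as None
    num_layers=1,
    cross_attention_dim=768,     # width of the conditioning hidden states (illustrative)
)
latents = torch.randn(2, 256, 16, 16)   # (batch, channel, height, width)
cond = torch.randn(2, 77, 768)          # (batch, tokens, cross_attention_dim)
out = model(latents, encoder_hidden_states=cond).sample
print(out.shape)                        # torch.Size([2, 256, 16, 16]), residual added back
```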
    	
        diffusers/models/unet_2d.py
    ADDED
    
    | @@ -0,0 +1,315 @@ | |
| 1 | 
            +
            # Copyright 2023 The HuggingFace Team. All rights reserved.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 4 | 
            +
            # you may not use this file except in compliance with the License.
         | 
| 5 | 
            +
            # You may obtain a copy of the License at
         | 
| 6 | 
            +
            #
         | 
| 7 | 
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 8 | 
            +
            #
         | 
| 9 | 
            +
            # Unless required by applicable law or agreed to in writing, software
         | 
| 10 | 
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 11 | 
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 12 | 
            +
            # See the License for the specific language governing permissions and
         | 
| 13 | 
            +
            # limitations under the License.
         | 
| 14 | 
            +
            from dataclasses import dataclass
         | 
| 15 | 
            +
            from typing import Optional, Tuple, Union
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            import torch
         | 
| 18 | 
            +
            import torch.nn as nn
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            from ..utils.configuration_utils import ConfigMixin, register_to_config
         | 
| 21 | 
            +
            from ..utils.outputs import BaseOutput
         | 
| 22 | 
            +
            from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
         | 
| 23 | 
            +
            from .modeling_utils import ModelMixin
         | 
| 24 | 
            +
            from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
         | 
| 25 | 
            +
             | 
| 26 | 
            +
             | 
| 27 | 
            +
            @dataclass
         | 
| 28 | 
            +
            class UNet2DOutput(BaseOutput):
         | 
| 29 | 
            +
                """
         | 
| 30 | 
            +
                Args:
         | 
| 31 | 
            +
                    sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
         | 
| 32 | 
            +
                        Hidden states output. Output of last layer of model.
         | 
| 33 | 
            +
                """
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                sample: torch.FloatTensor
         | 
| 36 | 
            +
             | 
| 37 | 
            +
             | 
| 38 | 
            +
            class UNet2DModel(ModelMixin, ConfigMixin):
         | 
| 39 | 
            +
                r"""
         | 
| 40 | 
            +
                UNet2DModel is a 2D UNet model that takes in a noisy sample and a timestep and returns a sample-shaped output.
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
         | 
| 43 | 
            +
                implements for all models (such as downloading or saving, etc.)
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                Parameters:
         | 
| 46 | 
            +
                    sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
         | 
| 47 | 
            +
                        Height and width of input/output sample.
         | 
| 48 | 
            +
                    in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
         | 
| 49 | 
            +
                    out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
         | 
| 50 | 
            +
                    center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
         | 
| 51 | 
            +
                    time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
         | 
| 52 | 
            +
                    freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding.
         | 
| 53 | 
            +
                    flip_sin_to_cos (`bool`, *optional*, defaults to :
         | 
| 54 | 
            +
                        obj:`True`): Whether to flip sin to cos for fourier time embedding.
         | 
| 55 | 
            +
                    down_block_types (`Tuple[str]`, *optional*, defaults to :
         | 
| 56 | 
            +
                        obj:`("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): Tuple of downsample block
         | 
| 57 | 
            +
                        types.
         | 
| 58 | 
            +
                    mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
         | 
| 59 | 
            +
                        The mid block type. Choose from `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
         | 
| 60 | 
            +
                    up_block_types (`Tuple[str]`, *optional*, defaults to :
         | 
| 61 | 
            +
                        obj:`("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`): Tuple of upsample block types.
         | 
| 62 | 
            +
                    block_out_channels (`Tuple[int]`, *optional*, defaults to :
         | 
| 63 | 
            +
                        obj:`(224, 448, 672, 896)`): Tuple of block output channels.
         | 
| 64 | 
            +
                    layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
         | 
| 65 | 
            +
                    mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
         | 
| 66 | 
            +
                    downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
         | 
| 67 | 
            +
                    act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
         | 
| 68 | 
            +
                    attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
         | 
| 69 | 
            +
                    norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for the normalization.
         | 
| 70 | 
            +
                    norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization.
         | 
| 71 | 
            +
                    resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
         | 
| 72 | 
            +
                        for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
         | 
| 73 | 
            +
                    class_embed_type (`str`, *optional*, defaults to None):
         | 
| 74 | 
            +
                        The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
         | 
| 75 | 
            +
                        `"timestep"`, or `"identity"`.
         | 
| 76 | 
            +
                    num_class_embeds (`int`, *optional*, defaults to None):
         | 
| 77 | 
            +
                        Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
         | 
| 78 | 
            +
                        class conditioning with `class_embed_type` equal to `None`.
         | 
| 79 | 
            +
                """
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                @register_to_config
         | 
| 82 | 
            +
                def __init__(
         | 
| 83 | 
            +
                    self,
         | 
| 84 | 
            +
                    sample_size: Optional[Union[int, Tuple[int, int]]] = None,
         | 
| 85 | 
            +
                    in_channels: int = 3,
         | 
| 86 | 
            +
                    out_channels: int = 3,
         | 
| 87 | 
            +
                    center_input_sample: bool = False,
         | 
| 88 | 
            +
                    time_embedding_type: str = "positional",
         | 
| 89 | 
            +
                    freq_shift: int = 0,
         | 
| 90 | 
            +
                    flip_sin_to_cos: bool = True,
         | 
| 91 | 
            +
                    down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
         | 
| 92 | 
            +
                    up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
         | 
| 93 | 
            +
                    block_out_channels: Tuple[int] = (224, 448, 672, 896),
         | 
| 94 | 
            +
                    layers_per_block: int = 2,
         | 
| 95 | 
            +
                    mid_block_scale_factor: float = 1,
         | 
| 96 | 
            +
                    downsample_padding: int = 1,
         | 
| 97 | 
            +
                    act_fn: str = "silu",
         | 
| 98 | 
            +
                    attention_head_dim: Optional[int] = 8,
         | 
| 99 | 
            +
                    norm_num_groups: int = 32,
         | 
| 100 | 
            +
                    norm_eps: float = 1e-5,
         | 
| 101 | 
            +
                    resnet_time_scale_shift: str = "default",
         | 
| 102 | 
            +
                    add_attention: bool = True,
         | 
| 103 | 
            +
                    class_embed_type: Optional[str] = None,
         | 
| 104 | 
            +
                    num_class_embeds: Optional[int] = None,
         | 
| 105 | 
            +
                ):
         | 
| 106 | 
            +
                    super().__init__()
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                    self.sample_size = sample_size
         | 
| 109 | 
            +
                    time_embed_dim = block_out_channels[0] * 4
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                    # Check inputs
         | 
| 112 | 
            +
                    if len(down_block_types) != len(up_block_types):
         | 
| 113 | 
            +
                        raise ValueError(
         | 
| 114 | 
            +
                            f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
         | 
| 115 | 
            +
                        )
         | 
| 116 | 
            +
             | 
| 117 | 
            +
                    if len(block_out_channels) != len(down_block_types):
         | 
| 118 | 
            +
                        raise ValueError(
         | 
| 119 | 
            +
                            f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
         | 
| 120 | 
            +
                        )
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                    # input
         | 
| 123 | 
            +
                    self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                    # time
         | 
| 126 | 
            +
                    if time_embedding_type == "fourier":
         | 
| 127 | 
            +
                        self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
         | 
| 128 | 
            +
                        timestep_input_dim = 2 * block_out_channels[0]
         | 
| 129 | 
            +
                    elif time_embedding_type == "positional":
         | 
| 130 | 
            +
                        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
         | 
| 131 | 
            +
                        timestep_input_dim = block_out_channels[0]
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                    self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                    # class embedding
         | 
| 136 | 
            +
                    if class_embed_type is None and num_class_embeds is not None:
         | 
| 137 | 
            +
                        self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
         | 
| 138 | 
            +
                    elif class_embed_type == "timestep":
         | 
| 139 | 
            +
                        self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
         | 
| 140 | 
            +
                    elif class_embed_type == "identity":
         | 
| 141 | 
            +
                        self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
         | 
| 142 | 
            +
                    else:
         | 
| 143 | 
            +
                        self.class_embedding = None
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                    self.down_blocks = nn.ModuleList([])
         | 
| 146 | 
            +
                    self.mid_block = None
         | 
| 147 | 
            +
                    self.up_blocks = nn.ModuleList([])
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                    # down
         | 
| 150 | 
            +
                    output_channel = block_out_channels[0]
         | 
| 151 | 
            +
                    for i, down_block_type in enumerate(down_block_types):
         | 
| 152 | 
            +
                        input_channel = output_channel
         | 
| 153 | 
            +
                        output_channel = block_out_channels[i]
         | 
| 154 | 
            +
                        is_final_block = i == len(block_out_channels) - 1
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                        down_block = get_down_block(
         | 
| 157 | 
            +
                            down_block_type,
         | 
| 158 | 
            +
                            num_layers=layers_per_block,
         | 
| 159 | 
            +
                            in_channels=input_channel,
         | 
| 160 | 
            +
                            out_channels=output_channel,
         | 
| 161 | 
            +
                            temb_channels=time_embed_dim,
         | 
| 162 | 
            +
                            add_downsample=not is_final_block,
         | 
| 163 | 
            +
                            resnet_eps=norm_eps,
         | 
| 164 | 
            +
                            resnet_act_fn=act_fn,
         | 
| 165 | 
            +
                            resnet_groups=norm_num_groups,
         | 
| 166 | 
            +
                            attn_num_head_channels=attention_head_dim,
         | 
| 167 | 
            +
                            downsample_padding=downsample_padding,
         | 
| 168 | 
            +
                            resnet_time_scale_shift=resnet_time_scale_shift,
         | 
| 169 | 
            +
                        )
         | 
| 170 | 
            +
                        self.down_blocks.append(down_block)
         | 
| 171 | 
            +
             | 
| 172 | 
            +
                    # mid
         | 
| 173 | 
            +
                    self.mid_block = UNetMidBlock2D(
         | 
| 174 | 
            +
                        in_channels=block_out_channels[-1],
         | 
| 175 | 
            +
                        temb_channels=time_embed_dim,
         | 
| 176 | 
            +
                        resnet_eps=norm_eps,
         | 
| 177 | 
            +
                        resnet_act_fn=act_fn,
         | 
| 178 | 
            +
                        output_scale_factor=mid_block_scale_factor,
         | 
| 179 | 
            +
                        resnet_time_scale_shift=resnet_time_scale_shift,
         | 
| 180 | 
            +
                        attn_num_head_channels=attention_head_dim,
         | 
| 181 | 
            +
                        resnet_groups=norm_num_groups,
         | 
| 182 | 
            +
                        add_attention=add_attention,
         | 
| 183 | 
            +
                    )
         | 
| 184 | 
            +
             | 
| 185 | 
            +
                    # up
         | 
| 186 | 
            +
                    reversed_block_out_channels = list(reversed(block_out_channels))
         | 
| 187 | 
            +
                    output_channel = reversed_block_out_channels[0]
         | 
| 188 | 
            +
                    for i, up_block_type in enumerate(up_block_types):
         | 
| 189 | 
            +
                        prev_output_channel = output_channel
         | 
| 190 | 
            +
                        output_channel = reversed_block_out_channels[i]
         | 
| 191 | 
            +
                        input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
         | 
| 192 | 
            +
             | 
| 193 | 
            +
                        is_final_block = i == len(block_out_channels) - 1
         | 
| 194 | 
            +
             | 
| 195 | 
            +
                        up_block = get_up_block(
         | 
| 196 | 
            +
                            up_block_type,
         | 
| 197 | 
            +
                            num_layers=layers_per_block + 1,
         | 
| 198 | 
            +
                            in_channels=input_channel,
         | 
| 199 | 
            +
                            out_channels=output_channel,
         | 
| 200 | 
            +
                            prev_output_channel=prev_output_channel,
         | 
| 201 | 
            +
                            temb_channels=time_embed_dim,
         | 
| 202 | 
            +
                            add_upsample=not is_final_block,
         | 
| 203 | 
            +
                            resnet_eps=norm_eps,
         | 
| 204 | 
            +
                            resnet_act_fn=act_fn,
         | 
| 205 | 
            +
                            resnet_groups=norm_num_groups,
         | 
| 206 | 
            +
                            attn_num_head_channels=attention_head_dim,
         | 
| 207 | 
            +
                            resnet_time_scale_shift=resnet_time_scale_shift,
         | 
| 208 | 
            +
                        )
         | 
| 209 | 
            +
                        self.up_blocks.append(up_block)
         | 
| 210 | 
            +
                        prev_output_channel = output_channel
         | 
| 211 | 
            +
             | 
| 212 | 
            +
                    # out
         | 
| 213 | 
            +
                    num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
         | 
| 214 | 
            +
                    self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
         | 
| 215 | 
            +
                    self.conv_act = nn.SiLU()
         | 
| 216 | 
            +
                    self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
         | 
| 217 | 
            +
             | 
| 218 | 
            +
                def forward(
         | 
| 219 | 
            +
                    self,
         | 
| 220 | 
            +
                    sample: torch.FloatTensor,
         | 
| 221 | 
            +
                    timestep: Union[torch.Tensor, float, int],
         | 
| 222 | 
            +
                    class_labels: Optional[torch.Tensor] = None,
         | 
| 223 | 
            +
                    return_dict: bool = True,
         | 
| 224 | 
            +
                ) -> Union[UNet2DOutput, Tuple]:
         | 
| 225 | 
            +
                    r"""
         | 
| 226 | 
            +
                    Args:
         | 
| 227 | 
            +
                        sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
         | 
| 228 | 
            +
                        timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps
         | 
| 229 | 
            +
                        class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
         | 
| 230 | 
            +
                            Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
         | 
| 231 | 
            +
                        return_dict (`bool`, *optional*, defaults to `True`):
         | 
| 232 | 
            +
                            Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
         | 
| 233 | 
            +
             | 
| 234 | 
            +
                    Returns:
         | 
| 235 | 
            +
                        [`~models.unet_2d.UNet2DOutput`] or `tuple`: [`~models.unet_2d.UNet2DOutput`] if `return_dict` is True,
         | 
| 236 | 
            +
                        otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
         | 
| 237 | 
            +
                    """
         | 
| 238 | 
            +
                    # 0. center input if necessary
         | 
| 239 | 
            +
                    if self.config.center_input_sample:
         | 
| 240 | 
            +
                        sample = 2 * sample - 1.0
         | 
| 241 | 
            +
             | 
| 242 | 
            +
                    # 1. time
         | 
| 243 | 
            +
                    timesteps = timestep
         | 
| 244 | 
            +
                    if not torch.is_tensor(timesteps):
         | 
| 245 | 
            +
                        timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
         | 
| 246 | 
            +
                    elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
         | 
| 247 | 
            +
                        timesteps = timesteps[None].to(sample.device)
         | 
| 248 | 
            +
             | 
| 249 | 
            +
                    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
         | 
| 250 | 
            +
                    timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
         | 
| 251 | 
            +
             | 
| 252 | 
            +
                    t_emb = self.time_proj(timesteps)
         | 
| 253 | 
            +
             | 
| 254 | 
            +
                    # timesteps does not contain any weights and will always return f32 tensors
         | 
| 255 | 
            +
                    # but time_embedding might actually be running in fp16. so we need to cast here.
         | 
| 256 | 
            +
                    # there might be better ways to encapsulate this.
         | 
| 257 | 
            +
                    t_emb = t_emb.to(dtype=self.dtype)
         | 
| 258 | 
            +
                    emb = self.time_embedding(t_emb)
         | 
| 259 | 
            +
             | 
| 260 | 
            +
                    if self.class_embedding is not None:
         | 
| 261 | 
            +
                        if class_labels is None:
         | 
| 262 | 
            +
                            raise ValueError("class_labels should be provided when doing class conditioning")
         | 
| 263 | 
            +
             | 
| 264 | 
            +
                        if self.config.class_embed_type == "timestep":
         | 
| 265 | 
            +
                            class_labels = self.time_proj(class_labels)
         | 
| 266 | 
            +
             | 
| 267 | 
            +
                        class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
         | 
| 268 | 
            +
                        emb = emb + class_emb
         | 
| 269 | 
            +
             | 
| 270 | 
            +
                    # 2. pre-process
         | 
| 271 | 
            +
                    skip_sample = sample
         | 
| 272 | 
            +
                    sample = self.conv_in(sample)
         | 
| 273 | 
            +
             | 
| 274 | 
            +
                    # 3. down
         | 
| 275 | 
            +
                    down_block_res_samples = (sample,)
         | 
| 276 | 
            +
                    for downsample_block in self.down_blocks:
         | 
| 277 | 
            +
                        if hasattr(downsample_block, "skip_conv"):
         | 
| 278 | 
            +
                            sample, res_samples, skip_sample = downsample_block(
         | 
| 279 | 
            +
                                hidden_states=sample, temb=emb, skip_sample=skip_sample
         | 
| 280 | 
            +
                            )
         | 
| 281 | 
            +
                        else:
         | 
| 282 | 
            +
                            sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
         | 
| 283 | 
            +
             | 
| 284 | 
            +
                        down_block_res_samples += res_samples
         | 
| 285 | 
            +
             | 
| 286 | 
            +
                    # 4. mid
         | 
| 287 | 
            +
                    sample = self.mid_block(sample, emb)
         | 
| 288 | 
            +
             | 
| 289 | 
            +
                    # 5. up
         | 
| 290 | 
            +
                    skip_sample = None
         | 
| 291 | 
            +
                    for upsample_block in self.up_blocks:
         | 
| 292 | 
            +
                        res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
         | 
| 293 | 
            +
                        down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
         | 
| 294 | 
            +
             | 
| 295 | 
            +
                        if hasattr(upsample_block, "skip_conv"):
         | 
| 296 | 
            +
                            sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
         | 
| 297 | 
            +
                        else:
         | 
| 298 | 
            +
                            sample = upsample_block(sample, res_samples, emb)
         | 
| 299 | 
            +
             | 
| 300 | 
            +
                    # 6. post-process
         | 
| 301 | 
            +
                    sample = self.conv_norm_out(sample)
         | 
| 302 | 
            +
                    sample = self.conv_act(sample)
         | 
| 303 | 
            +
                    sample = self.conv_out(sample)
         | 
| 304 | 
            +
             | 
| 305 | 
            +
                    if skip_sample is not None:
         | 
| 306 | 
            +
                        sample += skip_sample
         | 
| 307 | 
            +
             | 
| 308 | 
            +
                    if self.config.time_embedding_type == "fourier":
         | 
| 309 | 
            +
                        timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
         | 
| 310 | 
            +
                        sample = sample / timesteps
         | 
| 311 | 
            +
             | 
| 312 | 
            +
                    if not return_dict:
         | 
| 313 | 
            +
                        return (sample,)
         | 
| 314 | 
            +
             | 
| 315 | 
            +
                    return UNet2DOutput(sample=sample)
         | 
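
The listing above completes the unconditional `UNet2DModel`. As a quick illustration of the down/mid/up pipeline it implements, here is a minimal usage sketch. This is not part of the commit: the import path assumes the vendored package exposes the class as `diffusers.models.unet_2d.UNet2DModel`, and the tiny block configuration is chosen only so the example runs quickly on CPU.

```python
import torch

# Hypothetical smoke test; the configuration values below are illustrative only.
from diffusers.models.unet_2d import UNet2DModel

model = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
    block_out_channels=(32, 64),
    layers_per_block=1,
)

noisy_images = torch.randn(2, 3, 32, 32)                     # (batch, channel, height, width)
timesteps = torch.randint(0, 1000, (2,), dtype=torch.long)   # one diffusion timestep per sample

with torch.no_grad():
    pred = model(noisy_images, timesteps).sample             # UNet2DOutput.sample
print(pred.shape)  # torch.Size([2, 3, 32, 32]) -- the output keeps the input shape
```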
    	
        diffusers/models/unet_2d_blocks.py
    ADDED
    
    The diff for this file is too large to render.
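
Since `unet_2d_blocks.py` is not rendered, the sketch below (illustration only, with made-up counts and string stand-ins for tensors) shows the skip-connection bookkeeping that `UNet2DModel.forward` above relies on: the down path appends residual features to a tuple, and each up block then consumes `len(upsample_block.resnets)` entries from the end of that tuple.

```python
# Stand-ins for the residual feature maps collected on the way down.
down_block_res_samples = tuple(f"skip_{i}" for i in range(7))

# Hypothetical resnet counts per up block, mirroring len(upsample_block.resnets).
resnets_per_up_block = (3, 2, 2)

for num_resnets in resnets_per_up_block:
    res_samples = down_block_res_samples[-num_resnets:]         # skips handed to this up block
    down_block_res_samples = down_block_res_samples[:-num_resnets]
    print(res_samples)

# ('skip_4', 'skip_5', 'skip_6')
# ('skip_2', 'skip_3')
# ('skip_0', 'skip_1')
```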
    	
        diffusers/models/unet_2d_condition.py
    ADDED
    
    @@ -0,0 +1,907 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint

from ..utils.configuration_utils import ConfigMixin, register_to_config
from ..utils.outputs import BaseOutput
from .loaders import UNet2DConditionLoadersMixin
from .activations import get_activation
from .attention_processor import AttentionProcessor, AttnProcessor
from .embeddings import (
    GaussianFourierProjection,
    TextImageProjection,
    TextImageTimeEmbedding,
    TextTimeEmbedding,
    TimestepEmbedding,
    Timesteps,
)
from .modeling_utils import ModelMixin
from .unet_2d_blocks import (
    CrossAttnDownBlock2D,
    CrossAttnUpBlock2D,
    DownBlock2D,
    UNetMidBlock2DCrossAttn,
    UNetMidBlock2DSimpleCrossAttn,
    UpBlock2D,
    get_down_block,
    get_up_block,
)
from ..utils.logging import get_logger

logger = get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class UNet2DConditionOutput(BaseOutput):
    """
    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
    """

    sample: torch.FloatTensor


class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
    r"""
    UNet2DConditionModel is a conditional 2D UNet model that takes a noisy sample, a conditional state, and a timestep
    and returns a sample-shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
    implements for all the models (such as downloading or saving, etc.)

    Parameters:
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample.
        in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
        flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
        down_block_types (`Tuple[str]`, *optional*,
            defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
            The tuple of downsample blocks to use.
        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
            The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`; the mid
            block layer is skipped if `None`.
        up_block_types (`Tuple[str]`, *optional*,
            defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
            The tuple of upsample blocks to use.
        only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
            Whether to include self-attention in the basic transformer blocks, see
            [`~models.attention.BasicTransformerBlock`].
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
            If `None`, the normalization and activation layers are skipped in post-processing.
        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
            The dimension of the cross attention features.
        encoder_hid_dim (`int`, *optional*, defaults to None):
            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
            dimension to `cross_attention_dim`.
        encoder_hid_dim_type (`str`, *optional*, defaults to None):
            If given, the `encoder_hidden_states` and potentially other embeddings will be down-projected to text
            embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
        num_attention_heads (`int`, *optional*):
            The number of attention heads. If not defined, defaults to `attention_head_dim`.
        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
            for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
        class_embed_type (`str`, *optional*, defaults to None):
            The type of class embedding to use, which is ultimately summed with the time embeddings. Choose from
            `None`, `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
        addition_embed_type (`str`, *optional*, defaults to None):
            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
            `"text"`. `"text"` will use the `TextTimeEmbedding` layer.
        num_class_embeds (`int`, *optional*, defaults to None):
            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
            class conditioning with `class_embed_type` equal to `None`.
        time_embedding_type (`str`, *optional*, defaults to `positional`):
            The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
        time_embedding_dim (`int`, *optional*, defaults to `None`):
            An optional override for the dimension of the projected time embedding.
        time_embedding_act_fn (`str`, *optional*, defaults to `None`):
            Optional activation function to apply to the time embeddings, once, before they are passed to the rest
            of the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
        timestep_post_act (`str`, *optional*, defaults to `None`):
            The second activation function to use in the timestep embedding. Choose from `silu`, `mish` and `gelu`.
        time_cond_proj_dim (`int`, *optional*, defaults to `None`):
            The dimension of the `cond_proj` layer in the timestep embedding.
        conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_in` layer.
        conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_out` layer.
        projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
            using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`.
        class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
            embeddings with the class embeddings.
        mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
            Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is None, the
            `only_cross_attention` value will be used as the value for `mid_block_only_cross_attention`. Else, it will
            default to `False`.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        center_input_sample: bool = False,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
        up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: Union[int, Tuple[int]] = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: Union[int, Tuple[int]] = 1280,
        encoder_hid_dim: Optional[int] = None,
        encoder_hid_dim_type: Optional[str] = None,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        addition_embed_type: Optional[str] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        resnet_skip_time_act: bool = False,
        resnet_out_scale_factor: int = 1.0,
        time_embedding_type: str = "positional",
        time_embedding_dim: Optional[int] = None,
        time_embedding_act_fn: Optional[str] = None,
        timestep_post_act: Optional[str] = None,
        time_cond_proj_dim: Optional[int] = None,
        conv_in_kernel: int = 3,
        conv_out_kernel: int = 3,
        projection_class_embeddings_input_dim: Optional[int] = None,
        class_embeddings_concat: bool = False,
        mid_block_only_cross_attention: Optional[bool] = None,
        cross_attention_norm: Optional[str] = None,
        addition_embed_type_num_heads=64,
    ):
        super().__init__()

        self.sample_size = sample_size

        # If `num_attention_heads` is not defined (which is the case for most models)
        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
        # The reason for this behavior is to correct for incorrectly named variables that were introduced
        # when this library was created. The incorrect naming was only discovered much later in
        # https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
        # which is why we correct for the naming here.
        num_attention_heads = num_attention_heads or attention_head_dim

        # Check inputs
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
                "Must provide the same number of `down_block_types` as `up_block_types`. "
                f"`down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
            )

        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                "Must provide the same number of `block_out_channels` as `down_block_types`. "
                f"`block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
            raise ValueError(
                "Must provide the same number of `only_cross_attention` as `down_block_types`. "
                f"`only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
            raise ValueError(
                "Must provide the same number of `num_attention_heads` as `down_block_types`. "
                f"`num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
            raise ValueError(
                "Must provide the same number of `attention_head_dim` as `down_block_types`. "
                f"`attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
            )

        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
            raise ValueError(
                "Must provide the same number of `cross_attention_dim` as `down_block_types`. "
                f"`cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
            raise ValueError(
                "Must provide the same number of `layers_per_block` as `down_block_types`. "
                f"`layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
            )

        # input
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in = nn.Conv2d(
            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
        )

        # time
        if time_embedding_type == "fourier":
            time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
            if time_embed_dim % 2 != 0:
                raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
            self.time_proj = GaussianFourierProjection(
                time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
            )
            timestep_input_dim = time_embed_dim
        elif time_embedding_type == "positional":
            time_embed_dim = time_embedding_dim or block_out_channels[0] * 4

            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
            timestep_input_dim = block_out_channels[0]
        else:
            raise ValueError(
                f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
            )

        self.time_embedding = TimestepEmbedding(
            timestep_input_dim,
            time_embed_dim,
            act_fn=act_fn,
            post_act_fn=timestep_post_act,
            cond_proj_dim=time_cond_proj_dim,
        )

        if encoder_hid_dim_type is None and encoder_hid_dim is not None:
            encoder_hid_dim_type = "text_proj"
            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
            logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")

        if encoder_hid_dim is None and encoder_hid_dim_type is not None:
            raise ValueError(
                f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
            )

        if encoder_hid_dim_type == "text_proj":
            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
        elif encoder_hid_dim_type == "text_image_proj":
            # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much,
            # it is set to `cross_attention_dim` here, as this is exactly the required dimension for the
            # currently only use case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1).
            self.encoder_hid_proj = TextImageProjection(
                text_embed_dim=encoder_hid_dim,
                image_embed_dim=cross_attention_dim,
                cross_attention_dim=cross_attention_dim,
            )

        elif encoder_hid_dim_type is not None:
            raise ValueError(
                f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
            )
        else:
            self.encoder_hid_proj = None

        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
         | 
| 322 | 
            +
                        self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
         | 
| 323 | 
            +
                    elif class_embed_type == "identity":
         | 
| 324 | 
            +
                        self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
         | 
| 325 | 
            +
                    elif class_embed_type == "projection":
         | 
| 326 | 
            +
                        if projection_class_embeddings_input_dim is None:
         | 
| 327 | 
            +
                            raise ValueError(
         | 
| 328 | 
            +
                                "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
         | 
| 329 | 
            +
                            )
         | 
| 330 | 
            +
                        # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
         | 
| 331 | 
            +
                        # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
         | 
| 332 | 
            +
                        # 2. it projects from an arbitrary input dimension.
         | 
| 333 | 
            +
                        #
         | 
| 334 | 
            +
                        # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
         | 
| 335 | 
            +
                        # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
         | 
| 336 | 
            +
                        # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
         | 
| 337 | 
            +
                        self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
         | 
| 338 | 
            +
                    elif class_embed_type == "simple_projection":
         | 
| 339 | 
            +
                        if projection_class_embeddings_input_dim is None:
         | 
| 340 | 
            +
                            raise ValueError(
         | 
| 341 | 
            +
                                "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
         | 
| 342 | 
            +
                            )
         | 
| 343 | 
            +
                        self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
         | 
| 344 | 
            +
                    else:
         | 
| 345 | 
            +
                        self.class_embedding = None
         | 
| 346 | 
            +
             | 
| 347 | 
            +
                    if addition_embed_type == "text":
         | 
| 348 | 
            +
                        if encoder_hid_dim is not None:
         | 
| 349 | 
            +
                            text_time_embedding_from_dim = encoder_hid_dim
         | 
| 350 | 
            +
                        else:
         | 
| 351 | 
            +
                            text_time_embedding_from_dim = cross_attention_dim
         | 
| 352 | 
            +
             | 
| 353 | 
            +
                        self.add_embedding = TextTimeEmbedding(
         | 
| 354 | 
            +
                            text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
         | 
| 355 | 
            +
                        )
         | 
| 356 | 
            +
                    elif addition_embed_type == "text_image":
         | 
| 357 | 
            +
                        # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__
         | 
| 358 | 
            +
                        # too much, they are set to `cross_attention_dim` here as this is exactly the required dimension for the
         | 
| 359 | 
            +
                        # currently only use case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
         | 
| 360 | 
            +
                        self.add_embedding = TextImageTimeEmbedding(
         | 
| 361 | 
            +
                            text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
         | 
| 362 | 
            +
                        )
         | 
| 363 | 
            +
                    elif addition_embed_type is not None:
         | 
| 364 | 
            +
                        raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
         | 
| 365 | 
            +
             | 
| 366 | 
            +
                    if time_embedding_act_fn is None:
         | 
| 367 | 
            +
                        self.time_embed_act = None
         | 
| 368 | 
            +
                    else:
         | 
| 369 | 
            +
                        self.time_embed_act = get_activation(time_embedding_act_fn)
         | 
| 370 | 
            +
             | 
| 371 | 
            +
                    self.down_blocks = nn.ModuleList([])
         | 
| 372 | 
            +
                    self.up_blocks = nn.ModuleList([])
         | 
| 373 | 
            +
             | 
| 374 | 
            +
                    if isinstance(only_cross_attention, bool):
         | 
| 375 | 
            +
                        if mid_block_only_cross_attention is None:
         | 
| 376 | 
            +
                            mid_block_only_cross_attention = only_cross_attention
         | 
| 377 | 
            +
             | 
| 378 | 
            +
                        only_cross_attention = [only_cross_attention] * len(down_block_types)
         | 
| 379 | 
            +
             | 
| 380 | 
            +
                    if mid_block_only_cross_attention is None:
         | 
| 381 | 
            +
                        mid_block_only_cross_attention = False
         | 
| 382 | 
            +
             | 
| 383 | 
            +
                    if isinstance(num_attention_heads, int):
         | 
| 384 | 
            +
                        num_attention_heads = (num_attention_heads,) * len(down_block_types)
         | 
| 385 | 
            +
             | 
| 386 | 
            +
                    if isinstance(attention_head_dim, int):
         | 
| 387 | 
            +
                        attention_head_dim = (attention_head_dim,) * len(down_block_types)
         | 
| 388 | 
            +
             | 
| 389 | 
            +
                    if isinstance(cross_attention_dim, int):
         | 
| 390 | 
            +
                        cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
         | 
| 391 | 
            +
             | 
| 392 | 
            +
                    if isinstance(layers_per_block, int):
         | 
| 393 | 
            +
                        layers_per_block = [layers_per_block] * len(down_block_types)
         | 
| 394 | 
            +
             | 
| 395 | 
            +
                    if class_embeddings_concat:
         | 
| 396 | 
            +
                        # The time embeddings are concatenated with the class embeddings. The dimension of the
         | 
| 397 | 
            +
                        # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
         | 
| 398 | 
            +
                        # regular time embeddings
         | 
| 399 | 
            +
                        blocks_time_embed_dim = time_embed_dim * 2
         | 
| 400 | 
            +
                    else:
         | 
| 401 | 
            +
                        blocks_time_embed_dim = time_embed_dim
         | 
| 402 | 
            +
             | 
| 403 | 
            +
                    # down
         | 
| 404 | 
            +
                    output_channel = block_out_channels[0]
         | 
| 405 | 
            +
                    for i, down_block_type in enumerate(down_block_types):
         | 
| 406 | 
            +
                        input_channel = output_channel
         | 
| 407 | 
            +
                        output_channel = block_out_channels[i]
         | 
| 408 | 
            +
                        is_final_block = i == len(block_out_channels) - 1
         | 
| 409 | 
            +
             | 
| 410 | 
            +
                        down_block = get_down_block(
         | 
| 411 | 
            +
                            down_block_type,
         | 
| 412 | 
            +
                            num_layers=layers_per_block[i],
         | 
| 413 | 
            +
                            in_channels=input_channel,
         | 
| 414 | 
            +
                            out_channels=output_channel,
         | 
| 415 | 
            +
                            temb_channels=blocks_time_embed_dim,
         | 
| 416 | 
            +
                            add_downsample=not is_final_block,
         | 
| 417 | 
            +
                            resnet_eps=norm_eps,
         | 
| 418 | 
            +
                            resnet_act_fn=act_fn,
         | 
| 419 | 
            +
                            resnet_groups=norm_num_groups,
         | 
| 420 | 
            +
                            cross_attention_dim=cross_attention_dim[i],
         | 
| 421 | 
            +
                            num_attention_heads=num_attention_heads[i],
         | 
| 422 | 
            +
                            downsample_padding=downsample_padding,
         | 
| 423 | 
            +
                            dual_cross_attention=dual_cross_attention,
         | 
| 424 | 
            +
                            use_linear_projection=use_linear_projection,
         | 
| 425 | 
            +
                            only_cross_attention=only_cross_attention[i],
         | 
| 426 | 
            +
                            upcast_attention=upcast_attention,
         | 
| 427 | 
            +
                            resnet_time_scale_shift=resnet_time_scale_shift,
         | 
| 428 | 
            +
                            resnet_skip_time_act=resnet_skip_time_act,
         | 
| 429 | 
            +
                            resnet_out_scale_factor=resnet_out_scale_factor,
         | 
| 430 | 
            +
                            cross_attention_norm=cross_attention_norm,
         | 
| 431 | 
            +
                            attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
         | 
| 432 | 
            +
                        )
         | 
| 433 | 
            +
                        self.down_blocks.append(down_block)
         | 
| 434 | 
            +
             | 
| 435 | 
            +
                    # mid
         | 
| 436 | 
            +
                    if mid_block_type == "UNetMidBlock2DCrossAttn":
         | 
| 437 | 
            +
                        self.mid_block = UNetMidBlock2DCrossAttn(
         | 
| 438 | 
            +
                            in_channels=block_out_channels[-1],
         | 
| 439 | 
            +
                            temb_channels=blocks_time_embed_dim,
         | 
| 440 | 
            +
                            resnet_eps=norm_eps,
         | 
| 441 | 
            +
                            resnet_act_fn=act_fn,
         | 
| 442 | 
            +
                            output_scale_factor=mid_block_scale_factor,
         | 
| 443 | 
            +
                            resnet_time_scale_shift=resnet_time_scale_shift,
         | 
| 444 | 
            +
                            cross_attention_dim=cross_attention_dim[-1],
         | 
| 445 | 
            +
                            num_attention_heads=num_attention_heads[-1],
         | 
| 446 | 
            +
                            resnet_groups=norm_num_groups,
         | 
| 447 | 
            +
                            dual_cross_attention=dual_cross_attention,
         | 
| 448 | 
            +
                            use_linear_projection=use_linear_projection,
         | 
| 449 | 
            +
                            upcast_attention=upcast_attention,
         | 
| 450 | 
            +
                        )
         | 
| 451 | 
            +
                    elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
         | 
| 452 | 
            +
                        self.mid_block = UNetMidBlock2DSimpleCrossAttn(
         | 
| 453 | 
            +
                            in_channels=block_out_channels[-1],
         | 
| 454 | 
            +
                            temb_channels=blocks_time_embed_dim,
         | 
| 455 | 
            +
                            resnet_eps=norm_eps,
         | 
| 456 | 
            +
                            resnet_act_fn=act_fn,
         | 
| 457 | 
            +
                            output_scale_factor=mid_block_scale_factor,
         | 
| 458 | 
            +
                            cross_attention_dim=cross_attention_dim[-1],
         | 
| 459 | 
            +
                            attention_head_dim=attention_head_dim[-1],
         | 
| 460 | 
            +
                            resnet_groups=norm_num_groups,
         | 
| 461 | 
            +
                            resnet_time_scale_shift=resnet_time_scale_shift,
         | 
| 462 | 
            +
                            skip_time_act=resnet_skip_time_act,
         | 
| 463 | 
            +
                            only_cross_attention=mid_block_only_cross_attention,
         | 
| 464 | 
            +
                            cross_attention_norm=cross_attention_norm,
         | 
| 465 | 
            +
                        )
         | 
| 466 | 
            +
                    elif mid_block_type is None:
         | 
| 467 | 
            +
                        self.mid_block = None
         | 
| 468 | 
            +
                    else:
         | 
| 469 | 
            +
                        raise ValueError(f"unknown mid_block_type : {mid_block_type}")
         | 
| 470 | 
            +
             | 
| 471 | 
            +
                    # count how many layers upsample the images
         | 
| 472 | 
            +
                    self.num_upsamplers = 0
         | 
| 473 | 
            +
             | 
| 474 | 
            +
                    # up
         | 
| 475 | 
            +
                    reversed_block_out_channels = list(reversed(block_out_channels))
         | 
| 476 | 
            +
                    reversed_num_attention_heads = list(reversed(num_attention_heads))
         | 
| 477 | 
            +
                    reversed_layers_per_block = list(reversed(layers_per_block))
         | 
| 478 | 
            +
                    reversed_cross_attention_dim = list(reversed(cross_attention_dim))
         | 
| 479 | 
            +
                    only_cross_attention = list(reversed(only_cross_attention))
         | 
| 480 | 
            +
             | 
| 481 | 
            +
                    output_channel = reversed_block_out_channels[0]
         | 
| 482 | 
            +
                    for i, up_block_type in enumerate(up_block_types):
         | 
| 483 | 
            +
                        is_final_block = i == len(block_out_channels) - 1
         | 
| 484 | 
            +
             | 
| 485 | 
            +
                        prev_output_channel = output_channel
         | 
| 486 | 
            +
                        output_channel = reversed_block_out_channels[i]
         | 
| 487 | 
            +
                        input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
         | 
| 488 | 
            +
             | 
| 489 | 
            +
                        # add upsample block for all BUT final layer
         | 
| 490 | 
            +
                        if not is_final_block:
         | 
| 491 | 
            +
                            add_upsample = True
         | 
| 492 | 
            +
                            self.num_upsamplers += 1
         | 
| 493 | 
            +
                        else:
         | 
| 494 | 
            +
                            add_upsample = False
         | 
| 495 | 
            +
             | 
| 496 | 
            +
                        up_block = get_up_block(
         | 
| 497 | 
            +
                            up_block_type,
         | 
| 498 | 
            +
                            num_layers=reversed_layers_per_block[i] + 1,
         | 
| 499 | 
            +
                            in_channels=input_channel,
         | 
| 500 | 
            +
                            out_channels=output_channel,
         | 
| 501 | 
            +
                            prev_output_channel=prev_output_channel,
         | 
| 502 | 
            +
                            temb_channels=blocks_time_embed_dim,
         | 
| 503 | 
            +
                            add_upsample=add_upsample,
         | 
| 504 | 
            +
                            resnet_eps=norm_eps,
         | 
| 505 | 
            +
                            resnet_act_fn=act_fn,
         | 
| 506 | 
            +
                            resnet_groups=norm_num_groups,
         | 
| 507 | 
            +
                            cross_attention_dim=reversed_cross_attention_dim[i],
         | 
| 508 | 
            +
                            num_attention_heads=reversed_num_attention_heads[i],
         | 
| 509 | 
            +
                            dual_cross_attention=dual_cross_attention,
         | 
| 510 | 
            +
                            use_linear_projection=use_linear_projection,
         | 
| 511 | 
            +
                            only_cross_attention=only_cross_attention[i],
         | 
| 512 | 
            +
                            upcast_attention=upcast_attention,
         | 
| 513 | 
            +
                            resnet_time_scale_shift=resnet_time_scale_shift,
         | 
| 514 | 
            +
                            resnet_skip_time_act=resnet_skip_time_act,
         | 
| 515 | 
            +
                            resnet_out_scale_factor=resnet_out_scale_factor,
         | 
| 516 | 
            +
                            cross_attention_norm=cross_attention_norm,
         | 
| 517 | 
            +
                            attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
         | 
| 518 | 
            +
                        )
         | 
| 519 | 
            +
                        self.up_blocks.append(up_block)
         | 
| 520 | 
            +
                        prev_output_channel = output_channel
         | 
| 521 | 
            +
             | 
| 522 | 
            +
                    # out
         | 
| 523 | 
            +
                    if norm_num_groups is not None:
         | 
| 524 | 
            +
                        self.conv_norm_out = nn.GroupNorm(
         | 
| 525 | 
            +
                            num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
         | 
| 526 | 
            +
                        )
         | 
| 527 | 
            +
             | 
| 528 | 
            +
                        self.conv_act = get_activation(act_fn)
         | 
| 529 | 
            +
             | 
| 530 | 
            +
                    else:
         | 
| 531 | 
            +
                        self.conv_norm_out = None
         | 
| 532 | 
            +
                        self.conv_act = None
         | 
| 533 | 
            +
             | 
| 534 | 
            +
                    conv_out_padding = (conv_out_kernel - 1) // 2
         | 
| 535 | 
            +
                    self.conv_out = nn.Conv2d(
         | 
| 536 | 
            +
                        block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
         | 
| 537 | 
            +
                    )
         | 
| 538 | 
            +
             | 
| 539 | 
            +
                @property
         | 
| 540 | 
            +
                def attn_processors(self) -> Dict[str, AttentionProcessor]:
         | 
| 541 | 
            +
                    r"""
         | 
| 542 | 
            +
                    Returns:
         | 
| 543 | 
            +
                        `dict` of attention processors: A dictionary containing all attention processors used in the model with
         | 
| 544 | 
            +
                        indexed by its weight name.
         | 
| 545 | 
            +
                    """
         | 
| 546 | 
            +
                    # set recursively
         | 
| 547 | 
            +
                    processors = {}
         | 
| 548 | 
            +
             | 
| 549 | 
            +
                    def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
         | 
| 550 | 
            +
                        if hasattr(module, "set_processor"):
         | 
| 551 | 
            +
                            processors[f"{name}.processor"] = module.processor
         | 
| 552 | 
            +
             | 
| 553 | 
            +
                        for sub_name, child in module.named_children():
         | 
| 554 | 
            +
                            fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
         | 
| 555 | 
            +
             | 
| 556 | 
            +
                        return processors
         | 
| 557 | 
            +
             | 
| 558 | 
            +
                    for name, module in self.named_children():
         | 
| 559 | 
            +
                        fn_recursive_add_processors(name, module, processors)
         | 
| 560 | 
            +
             | 
| 561 | 
            +
                    return processors
         | 
| 562 | 
            +
             | 
| 563 | 
            +
                def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
         | 
| 564 | 
            +
                    r"""
         | 
| 565 | 
            +
                    Parameters:
         | 
| 566 | 
            +
                        `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
         | 
| 567 | 
            +
                            The instantiated processor class or a dictionary of processor classes that will be set as the processor
         | 
| 568 | 
            +
                            of **all** `Attention` layers.
         | 
| 569 | 
            +
                        In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. 
         | 
| 570 | 
            +
                        This is strongly recommended when setting trainable attention processors.
         | 
| 571 | 
            +
                    """
         | 
| 572 | 
            +
                    count = len(self.attn_processors.keys())
         | 
| 573 | 
            +
             | 
| 574 | 
            +
                    if isinstance(processor, dict) and len(processor) != count:
         | 
| 575 | 
            +
                        raise ValueError(
         | 
| 576 | 
            +
                            f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
         | 
| 577 | 
            +
                            f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
         | 
| 578 | 
            +
                        )
         | 
| 579 | 
            +
             | 
| 580 | 
            +
                    def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
         | 
| 581 | 
            +
                        if hasattr(module, "set_processor"):
         | 
| 582 | 
            +
                            if not isinstance(processor, dict):
         | 
| 583 | 
            +
                                module.set_processor(processor)
         | 
| 584 | 
            +
                            else:
         | 
| 585 | 
            +
                                module.set_processor(processor.pop(f"{name}.processor"))
         | 
| 586 | 
            +
             | 
| 587 | 
            +
                        for sub_name, child in module.named_children():
         | 
| 588 | 
            +
                            fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
         | 
| 589 | 
            +
             | 
| 590 | 
            +
                    for name, module in self.named_children():
         | 
| 591 | 
            +
                        fn_recursive_attn_processor(name, module, processor)
         | 
| 592 | 
            +
             | 
| 593 | 
            +
                def set_default_attn_processor(self):
         | 
| 594 | 
            +
                    """
         | 
| 595 | 
            +
                    Disables custom attention processors and sets the default attention implementation.
         | 
| 596 | 
            +
                    """
         | 
| 597 | 
            +
                    self.set_attn_processor(AttnProcessor())
         | 
| 598 | 
            +
             | 
| 599 | 
            +
                def set_attention_slice(self, slice_size):
         | 
| 600 | 
            +
                    r"""
         | 
| 601 | 
            +
                    Enable sliced attention computation.
         | 
| 602 | 
            +
             | 
| 603 | 
            +
                    When this option is enabled, the attention module will split the input tensor in slices, to compute attention
         | 
| 604 | 
            +
                    in several steps. This is useful to save some memory in exchange for a small speed decrease.
         | 
| 605 | 
            +
             | 
| 606 | 
            +
                    Args:
         | 
| 607 | 
            +
                        slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
         | 
| 608 | 
            +
                            When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
         | 
| 609 | 
            +
                            `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
         | 
| 610 | 
            +
                            provided, uses as many slices as `num_attention_heads // slice_size`. In this case,
         | 
| 611 | 
            +
                            `num_attention_heads` must be a multiple of `slice_size`.
         | 
| 612 | 
            +
                    """
         | 
| 613 | 
            +
                    sliceable_head_dims = []
         | 
| 614 | 
            +
             | 
| 615 | 
            +
                    def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
         | 
| 616 | 
            +
                        if hasattr(module, "set_attention_slice"):
         | 
| 617 | 
            +
                            sliceable_head_dims.append(module.sliceable_head_dim)
         | 
| 618 | 
            +
             | 
| 619 | 
            +
                        for child in module.children():
         | 
| 620 | 
            +
                            fn_recursive_retrieve_sliceable_dims(child)
         | 
| 621 | 
            +
             | 
| 622 | 
            +
                    # retrieve number of attention layers
         | 
| 623 | 
            +
                    for module in self.children():
         | 
| 624 | 
            +
                        fn_recursive_retrieve_sliceable_dims(module)
         | 
| 625 | 
            +
             | 
| 626 | 
            +
                    num_sliceable_layers = len(sliceable_head_dims)
         | 
| 627 | 
            +
             | 
| 628 | 
            +
                    if slice_size == "auto":
         | 
| 629 | 
            +
                        # half the attention head size is usually a good trade-off between
         | 
| 630 | 
            +
                        # speed and memory
         | 
| 631 | 
            +
                        slice_size = [dim // 2 for dim in sliceable_head_dims]
         | 
| 632 | 
            +
                    elif slice_size == "max":
         | 
| 633 | 
            +
                        # make smallest slice possible
         | 
| 634 | 
            +
                        slice_size = num_sliceable_layers * [1]
         | 
| 635 | 
            +
             | 
| 636 | 
            +
                    slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
         | 
| 637 | 
            +
             | 
| 638 | 
            +
                    if len(slice_size) != len(sliceable_head_dims):
         | 
| 639 | 
            +
                        raise ValueError(
         | 
| 640 | 
            +
                            f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
         | 
| 641 | 
            +
                            f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
         | 
| 642 | 
            +
                        )
         | 
| 643 | 
            +
             | 
| 644 | 
            +
                    for i in range(len(slice_size)):
         | 
| 645 | 
            +
                        size = slice_size[i]
         | 
| 646 | 
            +
                        dim = sliceable_head_dims[i]
         | 
| 647 | 
            +
                        if size is not None and size > dim:
         | 
| 648 | 
            +
                            raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
         | 
| 649 | 
            +
             | 
| 650 | 
            +
                    # Recursively walk through all the children.
         | 
| 651 | 
            +
                    # Any children which exposes the set_attention_slice method
         | 
| 652 | 
            +
                    # gets the message
         | 
| 653 | 
            +
                    def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
         | 
| 654 | 
            +
                        if hasattr(module, "set_attention_slice"):
         | 
| 655 | 
            +
                            module.set_attention_slice(slice_size.pop())
         | 
| 656 | 
            +
             | 
| 657 | 
            +
                        for child in module.children():
         | 
| 658 | 
            +
                            fn_recursive_set_attention_slice(child, slice_size)
         | 
| 659 | 
            +
             | 
| 660 | 
            +
                    reversed_slice_size = list(reversed(slice_size))
         | 
| 661 | 
            +
                    for module in self.children():
         | 
| 662 | 
            +
                        fn_recursive_set_attention_slice(module, reversed_slice_size)
         | 
| 663 | 
            +
             | 
| 664 | 
            +
                def _set_gradient_checkpointing(self, module, value=False):
         | 
| 665 | 
            +
                    if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
         | 
| 666 | 
            +
                        module.gradient_checkpointing = value
         | 
| 667 | 
            +
             | 
| 668 | 
            +
                def forward(
         | 
| 669 | 
            +
                    self,
         | 
| 670 | 
            +
                    sample: torch.FloatTensor,
         | 
| 671 | 
            +
                    timestep: Union[torch.Tensor, float, int],
         | 
| 672 | 
            +
                    encoder_hidden_states: torch.Tensor,
         | 
| 673 | 
            +
                    class_labels: Optional[torch.Tensor] = None,
         | 
| 674 | 
            +
                    timestep_cond: Optional[torch.Tensor] = None,
         | 
| 675 | 
            +
                    attention_mask: Optional[torch.Tensor] = None,
         | 
| 676 | 
            +
                    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         | 
| 677 | 
            +
                    added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
         | 
| 678 | 
            +
                    down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
         | 
| 679 | 
            +
                    mid_block_additional_residual: Optional[torch.Tensor] = None,
         | 
| 680 | 
            +
                    encoder_attention_mask: Optional[torch.Tensor] = None,
         | 
| 681 | 
            +
                    return_dict: bool = True,
         | 
| 682 | 
            +
                    **kwargs
         | 
| 683 | 
            +
                ) -> Union[UNet2DConditionOutput, Tuple]:
         | 
| 684 | 
            +
                    r"""
         | 
| 685 | 
            +
                    Args:
         | 
| 686 | 
            +
                        sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
         | 
| 687 | 
            +
                        timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
         | 
| 688 | 
            +
                        encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
         | 
| 689 | 
            +
                        encoder_attention_mask (`torch.Tensor`):
         | 
| 690 | 
            +
                            (batch, sequence_length) cross-attention mask, applied to encoder_hidden_states. True = keep, False =
         | 
| 691 | 
            +
                            discard. Mask will be converted into a bias, which adds large negative values to attention scores
         | 
| 692 | 
            +
                            corresponding to "discard" tokens.
         | 
| 693 | 
            +
                        return_dict (`bool`, *optional*, defaults to `True`):
         | 
| 694 | 
            +
                            Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
         | 
| 695 | 
            +
                        cross_attention_kwargs (`dict`, *optional*):
         | 
| 696 | 
            +
                            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
         | 
| 697 | 
            +
                            `self.processor` in [diffusers.cross_attention]
         | 
| 698 | 
            +
                            (https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
         | 
| 699 | 
            +
                        added_cond_kwargs (`dict`, *optional*):
         | 
| 700 | 
            +
                            A kwargs dictionary that if specified includes additonal conditions that can be used for additonal time
         | 
| 701 | 
            +
                            embeddings or encoder hidden states projections. See the configurations `encoder_hid_dim_type` and
         | 
| 702 | 
            +
                            `addition_embed_type` for more information.
         | 
| 703 | 
            +
             | 
| 704 | 
            +
                    Returns:
         | 
| 705 | 
            +
                        [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
         | 
| 706 | 
            +
                        [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
         | 
| 707 | 
            +
                        returning a tuple, the first element is the sample tensor.
         | 
| 708 | 
            +
                    """
         | 
| 709 | 
            +
                    # By default samples have to be AT least a multiple of the overall upsampling factor.
         | 
| 710 | 
            +
                    # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
         | 
| 711 | 
            +
                    # However, the upsampling interpolation output size can be forced to fit any upsampling size
         | 
| 712 | 
            +
                    # on the fly if necessary.
         | 
| 713 | 
            +
                    default_overall_up_factor = 2**self.num_upsamplers
         | 
| 714 | 
            +
             | 
| 715 | 
            +
                    # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
         | 
| 716 | 
            +
                    forward_upsample_size = False
         | 
| 717 | 
            +
                    upsample_size = None
         | 
| 718 | 
            +
             | 
| 719 | 
            +
                    if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
         | 
| 720 | 
            +
                        logger.info("Forward upsample size to force interpolation output size.")
         | 
| 721 | 
            +
                        forward_upsample_size = True
         | 
| 722 | 
            +
             | 
| 723 | 
            +
                    # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
         | 
| 724 | 
            +
                    # expects mask of shape:
         | 
| 725 | 
            +
                    #   [batch, key_tokens]
         | 
| 726 | 
            +
                    # adds singleton query_tokens dimension:
         | 
| 727 | 
            +
                    #   [batch,                    1, key_tokens]
         | 
| 728 | 
            +
                    # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
         | 
| 729 | 
            +
                    #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
         | 
| 730 | 
            +
                    #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
         | 
| 731 | 
            +
                    if attention_mask is not None:
         | 
| 732 | 
            +
                        # assume that mask is expressed as:
         | 
| 733 | 
            +
                        #   (1 = keep,      0 = discard)
         | 
| 734 | 
            +
                        # convert mask into a bias that can be added to attention scores:
         | 
| 735 | 
            +
                        #       (keep = +0,     discard = -10000.0)
         | 
| 736 | 
            +
                        attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
         | 
| 737 | 
            +
                        attention_mask = attention_mask.unsqueeze(1)
         | 
| 738 | 
            +
             | 
| 739 | 
            +
                    # convert encoder_attention_mask to a bias the same way we do for attention_mask
         | 
| 740 | 
            +
                    if encoder_attention_mask is not None:
         | 
| 741 | 
            +
                        encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
         | 
| 742 | 
            +
                        encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
         | 
| 743 | 
            +
             | 
| 744 | 
            +
                    # 0. center input if necessary
         | 
| 745 | 
            +
                    if self.config.center_input_sample:
         | 
| 746 | 
            +
                        sample = 2 * sample - 1.0
         | 
| 747 | 
            +
             | 
| 748 | 
            +
                    # 1. time
         | 
| 749 | 
            +
                    timesteps = timestep
         | 
| 750 | 
            +
                    if not torch.is_tensor(timesteps):
         | 
| 751 | 
            +
                        # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
         | 
| 752 | 
            +
                        # This would be a good case for the `match` statement (Python 3.10+)
         | 
| 753 | 
            +
                        is_mps = sample.device.type == "mps"
         | 
| 754 | 
            +
                        if isinstance(timestep, float):
         | 
| 755 | 
            +
                            dtype = torch.float32 if is_mps else torch.float64
         | 
| 756 | 
            +
                        else:
         | 
| 757 | 
            +
                            dtype = torch.int32 if is_mps else torch.int64
         | 
| 758 | 
            +
                        timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         | 
| 759 | 
            +
                    elif len(timesteps.shape) == 0:
         | 
| 760 | 
            +
                        timesteps = timesteps[None].to(sample.device)
         | 
| 761 | 
            +
             | 
| 762 | 
            +
                    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
         | 
| 763 | 
            +
                    timesteps = timesteps.expand(sample.shape[0])
         | 
| 764 | 
            +
             | 
| 765 | 
            +
                    t_emb = self.time_proj(timesteps)
         | 
| 766 | 
            +
             | 
| 767 | 
            +
                    # `Timesteps` does not contain any weights and will always return f32 tensors
         | 
| 768 | 
            +
                    # but time_embedding might actually be running in fp16. so we need to cast here.
         | 
| 769 | 
            +
                    # there might be better ways to encapsulate this.
         | 
| 770 | 
            +
                    t_emb = t_emb.to(dtype=sample.dtype)
         | 
| 771 | 
            +
             | 
| 772 | 
            +
                    emb = self.time_embedding(t_emb, timestep_cond)
         | 
| 773 | 
            +
             | 
| 774 | 
            +
                    if self.class_embedding is not None:
         | 
| 775 | 
            +
                        if class_labels is None:
         | 
| 776 | 
            +
                            raise ValueError("class_labels should be provided when num_class_embeds > 0")
         | 
| 777 | 
            +
             | 
| 778 | 
            +
                        if self.config.class_embed_type == "timestep":
         | 
| 779 | 
            +
                            class_labels = self.time_proj(class_labels)
         | 
| 780 | 
            +
             | 
| 781 | 
            +
                            # `Timesteps` does not contain any weights and will always return f32 tensors
         | 
| 782 | 
            +
                            # there might be better ways to encapsulate this.
         | 
| 783 | 
            +
                            class_labels = class_labels.to(dtype=sample.dtype)
         | 
| 784 | 
            +
             | 
| 785 | 
            +
                        class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
         | 
| 786 | 
            +
             | 
| 787 | 
            +
                        if self.config.class_embeddings_concat:
         | 
| 788 | 
            +
                            emb = torch.cat([emb, class_emb], dim=-1)
         | 
| 789 | 
            +
                        else:
         | 
| 790 | 
            +
                            emb = emb + class_emb
         | 
| 791 | 
            +
             | 
| 792 | 
            +
                    if self.config.addition_embed_type == "text":
         | 
| 793 | 
            +
                        aug_emb = self.add_embedding(encoder_hidden_states)
         | 
| 794 | 
            +
                        emb = emb + aug_emb
         | 
| 795 | 
            +
                    elif self.config.addition_embed_type == "text_image":
         | 
| 796 | 
            +
                        # Kadinsky 2.1 - style
         | 
| 797 | 
            +
                        if "image_embeds" not in added_cond_kwargs:
         | 
| 798 | 
            +
                            raise ValueError(
         | 
| 799 | 
            +
                                f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' "
         | 
| 800 | 
            +
                                "which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
         | 
| 801 | 
            +
                            )
         | 
| 802 | 
            +
             | 
| 803 | 
            +
                        image_embs = added_cond_kwargs.get("image_embeds")
         | 
| 804 | 
            +
                        text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
         | 
| 805 | 
            +
             | 
| 806 | 
            +
                        aug_emb = self.add_embedding(text_embs, image_embs)
         | 
| 807 | 
            +
                        emb = emb + aug_emb
         | 
| 808 | 
            +
             | 
| 809 | 
            +
                    if self.time_embed_act is not None:
         | 
| 810 | 
            +
                        emb = self.time_embed_act(emb)
         | 
| 811 | 
            +
             | 
| 812 | 
            +
                    if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
         | 
| 813 | 
            +
                        encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
         | 
| 814 | 
            +
                    elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
         | 
| 815 | 
            +
                        # Kadinsky 2.1 - style
         | 
| 816 | 
            +
                        if "image_embeds" not in added_cond_kwargs:
         | 
| 817 | 
            +
                            raise ValueError(
         | 
| 818 | 
            +
                                f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' "
         | 
| 819 | 
            +
                                "which requires the keyword argument `image_embeds` to be passed in  `added_conditions`"
         | 
| 820 | 
            +
                            )
         | 
| 821 | 
            +
             | 
| 822 | 
            +
                        image_embeds = added_cond_kwargs.get("image_embeds")
         | 
| 823 | 
            +
                        encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
         | 
| 824 | 
            +
             | 
| 825 | 
            +
                    # 2. pre-process
         | 
| 826 | 
            +
                    sample = self.conv_in(sample)
         | 
| 827 | 
            +
             | 
| 828 | 
            +
                    # 3. down
         | 
| 829 | 
            +
                    down_block_res_samples = (sample,)
         | 
| 830 | 
            +
                    for downsample_block in self.down_blocks:
         | 
| 831 | 
            +
                        if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
         | 
| 832 | 
            +
                            sample, res_samples = downsample_block(
         | 
| 833 | 
            +
                                hidden_states=sample,
         | 
| 834 | 
            +
                                temb=emb,
         | 
| 835 | 
            +
                                encoder_hidden_states=encoder_hidden_states,
         | 
| 836 | 
            +
                                attention_mask=attention_mask,
         | 
| 837 | 
            +
                                cross_attention_kwargs=cross_attention_kwargs,
         | 
| 838 | 
            +
                                encoder_attention_mask=encoder_attention_mask,
         | 
| 839 | 
            +
                            )
         | 
| 840 | 
            +
                        else:
         | 
| 841 | 
            +
                            sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
         | 
| 842 | 
            +
             | 
| 843 | 
            +
                        down_block_res_samples += res_samples
         | 
| 844 | 
            +
             | 
| 845 | 
            +
                    if down_block_additional_residuals is not None:
         | 
| 846 | 
            +
            new_down_block_res_samples = ()

            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            ):
                down_block_res_sample = down_block_res_sample + down_block_additional_residual
                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)

            down_block_res_samples = new_down_block_res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
            )

        if mid_block_additional_residual is not None:
            sample = sample + mid_block_additional_residual

        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
                )

        # 6. post-process
        if self.conv_norm_out:
            sample = self.conv_norm_out(sample)
            sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            return (sample,)

        return UNet2DConditionOutput(sample=sample)
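For orientation, here is a minimal sketch of how the `forward` output above is typically consumed. The tiny configuration, tensor shapes, and the assumption that this vendored `diffusers` copy keeps the upstream constructor and forward signature are illustrative additions, not taken from this commit:

```python
import torch
from diffusers.models.unet_2d_condition import UNet2DConditionModel  # vendored module added in this commit

# Small illustrative configuration (assumed values, not the ConsistencyTTA config).
unet = UNet2DConditionModel(
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=64,
)

latents = torch.randn(1, 4, 32, 32)   # noisy input sample
text_emb = torch.randn(1, 8, 64)      # encoder hidden states (cross-attention width 64)

# The forward pass shown above returns UNet2DConditionOutput(sample=...) by default,
# or a plain 1-tuple when return_dict=False.
noise_pred = unet(latents, 10, encoder_hidden_states=text_emb).sample
noise_pred = unet(latents, 10, encoder_hidden_states=text_emb, return_dict=False)[0]
```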
    	
diffusers/models/unet_2d_condition_guided.py
ADDED
@@ -0,0 +1,945 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint

from ..utils.configuration_utils import ConfigMixin, register_to_config
from ..utils import logging
from .loaders import UNet2DConditionLoadersMixin
from .activations import get_activation
from .attention_processor import AttentionProcessor, AttnProcessor
from .embeddings import (
    GaussianFourierProjection,
    TextImageProjection,
    TextImageTimeEmbedding,
    TextTimeEmbedding,
    TimestepEmbedding,
    Timesteps,
)
from .modeling_utils import ModelMixin
from .unet_2d_blocks import (
    CrossAttnDownBlock2D,
    CrossAttnUpBlock2D,
    DownBlock2D,
    UNetMidBlock2DCrossAttn,
    UNetMidBlock2DSimpleCrossAttn,
    UpBlock2D,
    get_down_block,
    get_up_block,
)
from .unet_2d_condition import UNet2DConditionOutput

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

class UNet2DConditionGuidedModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
    r"""
    UNet2DConditionGuidedModel is a conditional 2D UNet model that takes in a noisy sample,
    conditional state, and a timestep, and returns a sample-shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic
    methods the library implements for all the models (such as downloading or saving, etc.)

    Parameters:
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample.
        in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
        flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
        down_block_types (`Tuple[str]`, *optional*, defaults to
            `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
            The tuple of downsample blocks to use.
        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
            The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`;
            the mid block layer is skipped if `None`.
        up_block_types (`Tuple[str]`, *optional*, defaults to
            `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
            The tuple of upsample blocks to use.
        only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
            Whether to include self-attention in the basic transformer blocks, see
            [`~models.attention.BasicTransformerBlock`].
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
            If `None`, the normalization and activation layers in post-processing are skipped.
        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
            The dimension of the cross attention features.
        encoder_hid_dim (`int`, *optional*, defaults to None):
            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
            dimension to `cross_attention_dim`.
        encoder_hid_dim_type (`str`, *optional*, defaults to None):
            If given, the `encoder_hidden_states` and potentially other embeddings will be down-projected to text
            embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
        num_attention_heads (`int`, *optional*):
            The number of attention heads. If not defined, defaults to `attention_head_dim`.
        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
            for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
        class_embed_type (`str`, *optional*, defaults to None):
            The type of class embedding to use, which is ultimately summed with the time embeddings.
            Choose from `None`, `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
        addition_embed_type (`str`, *optional*, defaults to None):
            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
            `"text"`. `"text"` will use the `TextTimeEmbedding` layer.
        num_class_embeds (`int`, *optional*, defaults to None):
            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
            class conditioning with `class_embed_type` equal to `None`.
        time_embedding_type (`str`, *optional*, defaults to `positional`):
            The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
        time_embedding_dim (`int`, *optional*, defaults to `None`):
            An optional override for the dimension of the projected time embedding.
        time_embedding_act_fn (`str`, *optional*, defaults to `None`):
            Optional activation function applied once to the time embeddings before they are passed
            to the rest of the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
        timestep_post_act (`str`, *optional*, defaults to `None`):
            The second activation function to use in the timestep embedding. Choose from `silu`, `mish` and `gelu`.
        time_cond_proj_dim (`int`, *optional*, defaults to `None`):
            The dimension of the `cond_proj` layer in the timestep embedding.
        conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_in` layer.
        conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_out` layer.
        projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
            using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`.
        class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
            embeddings with the class embeddings.
        mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
            Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is None, the
            `only_cross_attention` value will be used as the value for `mid_block_only_cross_attention`.
            Else, it will default to `False`.
    """

    _supports_gradient_checkpointing = True

                @register_to_config
         | 
| 138 | 
            +
                def __init__(
         | 
| 139 | 
            +
                    self,
         | 
| 140 | 
            +
                    sample_size: Optional[int] = None,
         | 
| 141 | 
            +
                    in_channels: int = 4,
         | 
| 142 | 
            +
                    out_channels: int = 4,
         | 
| 143 | 
            +
                    center_input_sample: bool = False,
         | 
| 144 | 
            +
                    flip_sin_to_cos: bool = True,
         | 
| 145 | 
            +
                    freq_shift: int = 0,
         | 
| 146 | 
            +
                    down_block_types: Tuple[str] = (
         | 
| 147 | 
            +
                        "CrossAttnDownBlock2D",
         | 
| 148 | 
            +
                        "CrossAttnDownBlock2D",
         | 
| 149 | 
            +
                        "CrossAttnDownBlock2D",
         | 
| 150 | 
            +
                        "DownBlock2D",
         | 
| 151 | 
            +
                    ),
         | 
| 152 | 
            +
                    mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
         | 
| 153 | 
            +
                    up_block_types: Tuple[str] = (
         | 
| 154 | 
            +
                        "UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"
         | 
| 155 | 
            +
                    ),
         | 
| 156 | 
            +
                    only_cross_attention: Union[bool, Tuple[bool]] = False,
         | 
| 157 | 
            +
                    block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
         | 
| 158 | 
            +
                    layers_per_block: Union[int, Tuple[int]] = 2,
         | 
| 159 | 
            +
                    downsample_padding: int = 1,
         | 
| 160 | 
            +
                    mid_block_scale_factor: float = 1,
         | 
| 161 | 
            +
                    act_fn: str = "silu",
         | 
| 162 | 
            +
                    norm_num_groups: Optional[int] = 32,
         | 
| 163 | 
            +
                    norm_eps: float = 1e-5,
         | 
| 164 | 
            +
                    cross_attention_dim: Union[int, Tuple[int]] = 1280,
         | 
| 165 | 
            +
                    encoder_hid_dim: Optional[int] = None,
         | 
| 166 | 
            +
                    encoder_hid_dim_type: Optional[str] = None,
         | 
| 167 | 
            +
                    attention_head_dim: Union[int, Tuple[int]] = 8,
         | 
| 168 | 
            +
                    num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
         | 
| 169 | 
            +
                    dual_cross_attention: bool = False,
         | 
| 170 | 
            +
                    use_linear_projection: bool = False,
         | 
| 171 | 
            +
                    class_embed_type: Optional[str] = None,
         | 
| 172 | 
            +
                    addition_embed_type: Optional[str] = None,
         | 
| 173 | 
            +
                    num_class_embeds: Optional[int] = None,
         | 
| 174 | 
            +
                    upcast_attention: bool = False,
         | 
| 175 | 
            +
                    resnet_time_scale_shift: str = "default",
         | 
| 176 | 
            +
                    resnet_skip_time_act: bool = False,
         | 
| 177 | 
            +
                    resnet_out_scale_factor: int = 1.0,
         | 
| 178 | 
            +
                    time_embedding_type: str = "positional",
         | 
| 179 | 
            +
                    time_embedding_dim: Optional[int] = None,
         | 
| 180 | 
            +
                    time_embedding_act_fn: Optional[str] = None,
         | 
| 181 | 
            +
                    timestep_post_act: Optional[str] = None,
         | 
| 182 | 
            +
                    time_cond_proj_dim: Optional[int] = None,
         | 
| 183 | 
            +
                    guidance_embedding_type: str = "fourier",
         | 
| 184 | 
            +
                    guidance_embedding_dim: Optional[int] = None,
         | 
| 185 | 
            +
                    guidance_post_act: Optional[str] = None,
         | 
| 186 | 
            +
                    guidance_cond_proj_dim: Optional[int] = None,
         | 
| 187 | 
            +
                    conv_in_kernel: int = 3,
         | 
| 188 | 
            +
                    conv_out_kernel: int = 3,
         | 
| 189 | 
            +
                    projection_class_embeddings_input_dim: Optional[int] = None,
         | 
| 190 | 
            +
                    class_embeddings_concat: bool = False,
         | 
| 191 | 
            +
                    mid_block_only_cross_attention: Optional[bool] = None,
         | 
| 192 | 
            +
                    cross_attention_norm: Optional[str] = None,
         | 
| 193 | 
            +
                    addition_embed_type_num_heads=64,
         | 
| 194 | 
            +
                ):
         | 
| 195 | 
            +
                    super().__init__()
         | 
| 196 | 
            +
             | 
        self.sample_size = sample_size

        # If `num_attention_heads` is not defined (which is the case for most models)
        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
        # The reason for this behavior is to correct for incorrectly named variables that were introduced
        # when this library was created. The incorrect naming was only discovered much later in
        # https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too
        # backwards breaking which is why we correct for the naming here.
        num_attention_heads = num_attention_heads or attention_head_dim

        # Check inputs
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
                "Must provide the same number of `down_block_types` as `up_block_types`. "
                f"`down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
            )

        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                "Must provide the same number of `block_out_channels` as `down_block_types`. "
                f"`block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
            raise ValueError(
                "Must provide the same number of `only_cross_attention` as `down_block_types`. "
                f"`only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
            raise ValueError(
                "Must provide the same number of `num_attention_heads` as `down_block_types`. "
                f"`num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
            raise ValueError(
                "Must provide the same number of `attention_head_dim` as `down_block_types`. "
                f"`attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
            )

        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
            raise ValueError(
                "Must provide the same number of `cross_attention_dim` as `down_block_types`. "
                f"`cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `layers_per_block` as `down_block_types`. "
                f"`layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
            )

        # input
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in = nn.Conv2d(
            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
        )

        # time and guidance embeddings
        embedding_types = {'time': time_embedding_type, 'guidance': guidance_embedding_type}
        embedding_dims = {'time': time_embedding_dim, 'guidance': guidance_embedding_dim}
        embed_dims, embed_input_dims, embed_projs = {}, {}, {}

        for key in ['time', 'guidance']:
            logger.info(f"Using {embedding_types[key]} embedding for {key}.")

            if embedding_types[key] == "fourier":
                embed_dims[key] = embedding_dims[key] or block_out_channels[0] * 4
                embed_input_dims[key] = embed_dims[key]
                if embed_dims[key] % 2 != 0:
                    raise ValueError(
                        f"`{key}_embed_dim` should be divisible by 2, but is {embed_dims[key]}."
                    )
                embed_projs[key] = GaussianFourierProjection(
                    embed_dims[key] // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
                )

            elif embedding_types[key] == "positional":
                embed_dims[key] = embedding_dims[key] or block_out_channels[0] * 4
                embed_input_dims[key] = block_out_channels[0]
                embed_projs[key] = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)

            else:
                raise ValueError(
                    f"{embedding_types[key]} does not exist for {key} embedding. "
                    f"Please make sure to use one of `fourier` or `positional`."
                )

        self.time_proj, self.guidance_proj = embed_projs['time'], embed_projs['guidance']

        self.time_embedding = TimestepEmbedding(
            embed_input_dims['time'],
            embed_dims['time'],
            act_fn=act_fn,
            post_act_fn=timestep_post_act,
            cond_proj_dim=time_cond_proj_dim,
        )
        self.guidance_embedding = TimestepEmbedding(
            embed_input_dims['guidance'],
            embed_dims['guidance'],
            act_fn=act_fn,
            post_act_fn=guidance_post_act,
            cond_proj_dim=guidance_cond_proj_dim,
        )

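        # The guidance branch above mirrors the timestep branch: `self.guidance_proj` maps a
        # scalar guidance value to a Fourier/sinusoidal feature and `self.guidance_embedding`
        # projects it to the block embedding width, so the UNet can be conditioned on a
        # guidance weight (e.g., a classifier-free guidance scale) in the same way it is
        # conditioned on the diffusion timestep.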
        if encoder_hid_dim_type is None and encoder_hid_dim is not None:
            encoder_hid_dim_type = "text_proj"
            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
            logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")

        if encoder_hid_dim is None and encoder_hid_dim_type is not None:
            raise ValueError(
                "`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` "
                f"is set to {encoder_hid_dim_type}."
            )

        if encoder_hid_dim_type == "text_proj":
            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
        elif encoder_hid_dim_type == "text_image_proj":
            # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much,
            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the
            # currently only use case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
            self.encoder_hid_proj = TextImageProjection(
                text_embed_dim=encoder_hid_dim,
                image_embed_dim=cross_attention_dim,
                cross_attention_dim=cross_attention_dim,
            )

        elif encoder_hid_dim_type is not None:
            raise ValueError(
                f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
            )
        else:
            self.encoder_hid_proj = None

        # class embedding
        # print(f"class_embed_type: {class_embed_type}, num_class_embeds: {num_class_embeds}")
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, embedding_dims['time'])
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(
                embed_input_dims['time'], embed_dims['time'], act_fn=act_fn
            )
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(embed_dims['time'], embed_dims['time'])
        elif class_embed_type == "projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
                )
            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
            # 2. it projects from an arbitrary input dimension.
            #
            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal
            # embeddings. As a result, `TimestepEmbedding` can be passed arbitrary vectors.
            self.class_embedding = TimestepEmbedding(
                projection_class_embeddings_input_dim, embed_dims['time']
            )
        elif class_embed_type == "simple_projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'simple_projection' requires "
                    "`projection_class_embeddings_input_dim` be set"
                )
            self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, embed_dims['time'])
        else:
            self.class_embedding = None

        # Addition embedding
        if addition_embed_type == "text":
            if encoder_hid_dim is not None:
                text_time_embedding_from_dim = encoder_hid_dim
            else:
                text_time_embedding_from_dim = cross_attention_dim

            self.add_embedding = TextTimeEmbedding(
                text_time_embedding_from_dim, embed_dims['time'], num_heads=addition_embed_type_num_heads
            )
        elif addition_embed_type == "text_image":
            # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`.
            # To not clutter the __init__ too much, they are set to `cross_attention_dim`
            # here as this is exactly the required dimension for the currently only use case
            # when `addition_embed_type == "text_image"` (Kandinsky 2.1)
            self.add_embedding = TextImageTimeEmbedding(
                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim,
                time_embed_dim=embed_dims['time']
            )
        elif addition_embed_type is not None:
            raise ValueError(
                f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'."
            )

        # Embedding activation function
        if time_embedding_act_fn is None:
            self.time_embed_act = None
        else:
            self.time_embed_act = get_activation(time_embedding_act_fn)

        self.down_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        if isinstance(only_cross_attention, bool):
            if mid_block_only_cross_attention is None:
                mid_block_only_cross_attention = only_cross_attention
            only_cross_attention = [only_cross_attention] * len(down_block_types)

        if mid_block_only_cross_attention is None:
            mid_block_only_cross_attention = False

        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        if isinstance(cross_attention_dim, int):
            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)

        if isinstance(layers_per_block, int):
            layers_per_block = [layers_per_block] * len(down_block_types)

        if class_embeddings_concat:
            # The time embeddings are concatenated with the class embeddings. The dimension of the
            # time embeddings passed to the down, middle, and up blocks is twice the dimension of
            # the regular time embeddings
            # Now we have time emb, guidance emb, and class emb
            blocks_time_embed_dim = embed_dims['time'] * 3
        else:
            blocks_time_embed_dim = embed_dims['time']

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=blocks_time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim[i],
                num_attention_heads=num_attention_heads[i],
                downsample_padding=downsample_padding,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
                resnet_skip_time_act=resnet_skip_time_act,
                resnet_out_scale_factor=resnet_out_scale_factor,
                cross_attention_norm=cross_attention_norm,
                attention_head_dim=\
                    attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
            )
            self.down_blocks.append(down_block)

        # mid
        if mid_block_type == "UNetMidBlock2DCrossAttn":
            self.mid_block = UNetMidBlock2DCrossAttn(
                in_channels=block_out_channels[-1],
                temb_channels=blocks_time_embed_dim,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                resnet_time_scale_shift=resnet_time_scale_shift,
                cross_attention_dim=cross_attention_dim[-1],
                num_attention_heads=num_attention_heads[-1],
                resnet_groups=norm_num_groups,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                upcast_attention=upcast_attention,
            )
        elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
            self.mid_block = UNetMidBlock2DSimpleCrossAttn(
                in_channels=block_out_channels[-1],
                temb_channels=blocks_time_embed_dim,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                cross_attention_dim=cross_attention_dim[-1],
                attention_head_dim=attention_head_dim[-1],
                resnet_groups=norm_num_groups,
                resnet_time_scale_shift=resnet_time_scale_shift,
                skip_time_act=resnet_skip_time_act,
                only_cross_attention=mid_block_only_cross_attention,
                cross_attention_norm=cross_attention_norm,
            )
        elif mid_block_type is None:
            self.mid_block = None
        else:
            raise ValueError(f"unknown mid_block_type : {mid_block_type}")

        # count how many layers upsample the images
        self.num_upsamplers = 0

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_num_attention_heads = list(reversed(num_attention_heads))
        reversed_layers_per_block = list(reversed(layers_per_block))
        reversed_cross_attention_dim = list(reversed(cross_attention_dim))
        only_cross_attention = list(reversed(only_cross_attention))

        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            is_final_block = i == len(block_out_channels) - 1

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            # add upsample block for all BUT final layer
            if not is_final_block:
                add_upsample = True
                self.num_upsamplers += 1
            else:
                add_upsample = False

            up_block = get_up_block(
                up_block_type,
                num_layers=reversed_layers_per_block[i] + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=blocks_time_embed_dim,
         | 
| 532 | 
            +
                            add_upsample=add_upsample,
         | 
| 533 | 
            +
                            resnet_eps=norm_eps,
         | 
| 534 | 
            +
                            resnet_act_fn=act_fn,
         | 
| 535 | 
            +
                            resnet_groups=norm_num_groups,
         | 
| 536 | 
            +
                            cross_attention_dim=reversed_cross_attention_dim[i],
         | 
| 537 | 
            +
                            num_attention_heads=reversed_num_attention_heads[i],
         | 
| 538 | 
            +
                            dual_cross_attention=dual_cross_attention,
         | 
| 539 | 
            +
                            use_linear_projection=use_linear_projection,
         | 
| 540 | 
            +
                            only_cross_attention=only_cross_attention[i],
         | 
| 541 | 
            +
                            upcast_attention=upcast_attention,
         | 
| 542 | 
            +
                            resnet_time_scale_shift=resnet_time_scale_shift,
         | 
| 543 | 
            +
                            resnet_skip_time_act=resnet_skip_time_act,
         | 
| 544 | 
            +
                            resnet_out_scale_factor=resnet_out_scale_factor,
         | 
| 545 | 
            +
                            cross_attention_norm=cross_attention_norm,
         | 
| 546 | 
            +
                            attention_head_dim=\
         | 
| 547 | 
            +
                                attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
         | 
| 548 | 
            +
                        )
         | 
| 549 | 
            +
                        self.up_blocks.append(up_block)
         | 
| 550 | 
            +
                        prev_output_channel = output_channel
         | 
| 551 | 
            +
             | 
| 552 | 
            +
                    # out
         | 
| 553 | 
            +
                    if norm_num_groups is not None:
         | 
| 554 | 
            +
                        self.conv_norm_out = nn.GroupNorm(
         | 
| 555 | 
            +
                            num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
         | 
| 556 | 
            +
                        )
         | 
| 557 | 
            +
                        self.conv_act = get_activation(act_fn)
         | 
| 558 | 
            +
             | 
| 559 | 
            +
                    else:
         | 
| 560 | 
            +
                        self.conv_norm_out = None
         | 
| 561 | 
            +
                        self.conv_act = None
         | 
| 562 | 
            +
             | 
| 563 | 
            +
                    conv_out_padding = (conv_out_kernel - 1) // 2
         | 
| 564 | 
            +
                    self.conv_out = nn.Conv2d(
         | 
| 565 | 
            +
                        block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
         | 
| 566 | 
            +
                    )
         | 
| 567 | 
            +
             | 
| 568 | 
            +
                @property
         | 
| 569 | 
            +
                def attn_processors(self) -> Dict[str, AttentionProcessor]:
         | 
| 570 | 
            +
                    r"""
         | 
| 571 | 
            +
                    Returns:
         | 
| 572 | 
            +
                        `dict` of attention processors: A dictionary containing all attention processors used in
         | 
| 573 | 
            +
                        the model with indexed by its weight name.
         | 
| 574 | 
            +
                    """
         | 
| 575 | 
            +
                    # set recursively
         | 
| 576 | 
            +
                    processors = {}
         | 
| 577 | 
            +
             | 
| 578 | 
            +
                    def fn_recursive_add_processors(
         | 
| 579 | 
            +
                        name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]
         | 
| 580 | 
            +
                    ):
         | 
| 581 | 
            +
                        if hasattr(module, "set_processor"):
         | 
| 582 | 
            +
                            processors[f"{name}.processor"] = module.processor
         | 
| 583 | 
            +
             | 
| 584 | 
            +
                        for sub_name, child in module.named_children():
         | 
| 585 | 
            +
                            fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
         | 
| 586 | 
            +
             | 
| 587 | 
            +
                        return processors
         | 
| 588 | 
            +
             | 
| 589 | 
            +
                    for name, module in self.named_children():
         | 
| 590 | 
            +
                        fn_recursive_add_processors(name, module, processors)
         | 
| 591 | 
            +
             | 
| 592 | 
            +
                    return processors
         | 
| 593 | 
            +
             | 
| 594 | 
            +
                def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
         | 
| 595 | 
            +
                    r"""
         | 
| 596 | 
            +
                    Parameters:
         | 
| 597 | 
            +
                        `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
         | 
| 598 | 
            +
                            The instantiated processor class or a dictionary of processor classes that will be set as the
         | 
| 599 | 
            +
                            processor of **all** `Attention` layers.
         | 
| 600 | 
            +
                        In case `processor` is a dict, the key needs to define the path to the corresponding cross 
         | 
| 601 | 
            +
                        attention processor. This is strongly recommended when setting trainable attention processors.
         | 
| 602 | 
            +
                    """
         | 
| 603 | 
            +
                    count = len(self.attn_processors.keys())
         | 
| 604 | 
            +
             | 
| 605 | 
            +
                    if isinstance(processor, dict) and len(processor) != count:
         | 
| 606 | 
            +
                        raise ValueError(
         | 
| 607 | 
            +
                            f"A dict of processors was passed, but the number of processors {len(processor)} does not match"
         | 
| 608 | 
            +
                            f" the number of attention layers: {count}. Please make sure to pass {count} processor classes."
         | 
| 609 | 
            +
                        )
         | 
| 610 | 
            +
             | 
| 611 | 
            +
                    def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
         | 
| 612 | 
            +
                        if hasattr(module, "set_processor"):
         | 
| 613 | 
            +
                            if not isinstance(processor, dict):
         | 
| 614 | 
            +
                                module.set_processor(processor)
         | 
| 615 | 
            +
                            else:
         | 
| 616 | 
            +
                                module.set_processor(processor.pop(f"{name}.processor"))
         | 
| 617 | 
            +
             | 
| 618 | 
            +
                        for sub_name, child in module.named_children():
         | 
| 619 | 
            +
                            fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
         | 
| 620 | 
            +
             | 
| 621 | 
            +
                    for name, module in self.named_children():
         | 
| 622 | 
            +
                        fn_recursive_attn_processor(name, module, processor)
         | 
| 623 | 
            +
             | 
| 624 | 
            +
                def set_default_attn_processor(self):
         | 
| 625 | 
            +
                    """
         | 
| 626 | 
            +
                    Disables custom attention processors and sets the default attention implementation.
         | 
| 627 | 
            +
                    """
         | 
| 628 | 
            +
                    self.set_attn_processor(AttnProcessor())
         | 
| 629 | 
            +
             | 
| 630 | 
            +
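    # Illustrative usage, assuming `unet` is an already-constructed instance of this class:
    #     unet.set_attn_processor(AttnProcessor())  # equivalent to unet.set_default_attn_processor()
    #     unet.set_attn_processor({name: AttnProcessor() for name in unet.attn_processors})
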
    def set_attention_slice(self, slice_size):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor in slices, to compute
        attention in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two
                steps. When `"max"`, the maximum amount of memory is saved by running only one slice at a time.
                If a number is provided, uses as many slices as `num_attention_heads // slice_size`.
                In this case, `num_attention_heads` must be a multiple of `slice_size`.
        """
        sliceable_head_dims = []

        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
            if hasattr(module, "set_attention_slice"):
                sliceable_head_dims.append(module.sliceable_head_dim)

            for child in module.children():
                fn_recursive_retrieve_sliceable_dims(child)

        # retrieve number of attention layers
        for module in self.children():
            fn_recursive_retrieve_sliceable_dims(module)

        num_sliceable_layers = len(sliceable_head_dims)

        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = [dim // 2 for dim in sliceable_head_dims]
        elif slice_size == "max":
            # make smallest slice possible
            slice_size = num_sliceable_layers * [1]

        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size

        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
                f"You have provided {len(slice_size)}, but {self.config} has "
                f"{len(sliceable_head_dims)} different attention layers. "
                f"Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
            )

        for i in range(len(slice_size)):
            size = slice_size[i]
            dim = sliceable_head_dims[i]
            if size is not None and size > dim:
                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

        # Recursively walk through all the children; any child that exposes
        # the set_attention_slice method receives its slice size.
        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
            if hasattr(module, "set_attention_slice"):
                module.set_attention_slice(slice_size.pop())

            for child in module.children():
                fn_recursive_set_attention_slice(child, slice_size)

        reversed_slice_size = list(reversed(slice_size))
        for module in self.children():
            fn_recursive_set_attention_slice(module, reversed_slice_size)

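    # Illustrative usage, assuming `unet` is an already-constructed instance of this class:
    #     unet.set_attention_slice("auto")  # halve every sliceable head dimension
    #     unet.set_attention_slice("max")   # slice size 1 everywhere, minimal memory
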
    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
            module.gradient_checkpointing = value

    def _prepare_tensor(self, value, device):
        if not torch.is_tensor(value):
            # Requires sync between CPU and GPU. So try to pass timesteps as tensors if you can.
            # This would be a good case for the `match` statement (Python 3.10+)
            if isinstance(value, float):
                dtype = torch.float32 if device.type == "mps" else torch.float64
            else:
                dtype = torch.int32 if device.type == "mps" else torch.int64

            return torch.tensor([value], dtype=dtype, device=device)

        elif len(value.shape) == 0:
            return value[None].to(device)

        else:
            return value

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        guidance: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        guidance_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        **kwargs
    ) -> Union[UNet2DConditionOutput, Tuple]:
        r"""
        Args:
            sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
            encoder_hidden_states (`torch.FloatTensor`):
                (batch, sequence_length, feature_dim) encoder hidden states
            encoder_attention_mask (`torch.Tensor`):
                (batch, sequence_length) cross-attention mask, applied to encoder_hidden_states. True = keep,
                False = discard. Mask will be converted into a bias, which adds large negative values to
                attention scores corresponding to "discard" tokens.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`]
                instead of a plain tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined
                under `self.processor` in [diffusers.cross_attention]
                (https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
            added_cond_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified includes additional conditions that can be used for
                additional time embeddings or encoder hidden states projections. See the configurations
                `encoder_hid_dim_type` and `addition_embed_type` for more information.

        Returns:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.
                When returning a tuple, the first element is the sample tensor.
        """
        # By default samples have to be at least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2 ** self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch,                    1, key_tokens]
        # this helps to broadcast it as a bias over attention scores,
        # which will be in one of the following shapes:
        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None:
            # assume that mask is expressed as:
            #   (1 = keep,      0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #       (keep = +0,     discard = -10000.0)
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None:
            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * (-10000.0)
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time and guidance
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timestep = self._prepare_tensor(timestep, sample.device).expand(sample.shape[0])
        # Project to get embedding
        # `Timestep` does not contain any weights and will always return fp32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        t_emb = self.time_proj(timestep).to(dtype=sample.dtype)
        t_emb = self.time_embedding(t_emb, timestep_cond)

        guidance = self._prepare_tensor(guidance, sample.device).expand(sample.shape[0])
        g_emb = self.guidance_proj(guidance).to(dtype=sample.dtype)
        g_emb = self.guidance_embedding(g_emb, guidance_cond)

        # 1.5. prepare other embeddings
        if self.class_embedding is None:
            emb = t_emb + g_emb
        else:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")
            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels).to(dtype=sample.dtype)
            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)

            if self.config.class_embeddings_concat:
                emb = torch.cat([t_emb, g_emb, class_emb], dim=-1)
            else:
                emb = t_emb + g_emb + class_emb

        if self.config.addition_embed_type == "text":
            aug_emb = self.add_embedding(encoder_hidden_states)
            emb = emb + aug_emb
        elif self.config.addition_embed_type == "text_image":
            # Kandinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' "
                    "which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )

            image_embs = added_cond_kwargs.get("image_embeds")
            text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)

            aug_emb = self.add_embedding(text_embs, image_embs)
            emb = emb + aug_emb

        if self.time_embed_act is not None:
            emb = self.time_embed_act(emb)

        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
            # Kandinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' "
                    "which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
                )

            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)

        # 2. pre-process
        sample = self.conv_in(sample)

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        if down_block_additional_residuals is not None:
            new_down_block_res_samples = ()

            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            ):
                down_block_res_sample = down_block_res_sample + down_block_additional_residual
                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)

            down_block_res_samples = new_down_block_res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
            )

        if mid_block_additional_residual is not None:
            sample = sample + mid_block_additional_residual

        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample, temb=emb,
                    res_hidden_states_tuple=res_samples, upsample_size=upsample_size
                )

        # 6. post-process
        if self.conv_norm_out:
            sample = self.conv_norm_out(sample)
            sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            return (sample,)

        return UNet2DConditionOutput(sample=sample)
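For reference, the following is a minimal sketch of how this guidance-conditioned forward pass can be invoked. The `unet` instance, the tensor shapes, and the scalar values below are illustrative assumptions only; the real latent and text-embedding dimensions are fixed by the model configuration used by the app.

import torch

# Hypothetical shapes, for illustration only.
batch, latent_channels, height, width = 2, 8, 256, 16
seq_len, text_dim = 64, 1024

# `unet` is assumed to be an already-constructed instance of the model defined above.
sample = torch.randn(batch, latent_channels, height, width)            # noisy latents
encoder_hidden_states = torch.randn(batch, seq_len, text_dim)          # text-encoder states
encoder_attention_mask = torch.ones(batch, seq_len, dtype=torch.long)  # 1 = keep, 0 = discard

# `timestep` and `guidance` may be Python scalars or tensors; forward() broadcasts
# them over the batch via _prepare_tensor and embeds each with its own projection.
out = unet(
    sample=sample,
    timestep=400,
    guidance=3.0,
    encoder_hidden_states=encoder_hidden_states,
    encoder_attention_mask=encoder_attention_mask,
)
pred = out.sample  # UNet2DConditionOutput when return_dict=True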
    	
diffusers/scheduling_heun_discrete.py
ADDED
# Copyright 2023 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

### This file has been modified for the purposes of the ConsistencyTTA generation. ###

import math
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from .utils.configuration_utils import ConfigMixin, register_to_config
from .utils.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines
        the cumulative product of (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the
        cumulative product of (1-beta) up to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`):
            the maximum beta to use; use values lower than 1 to prevent singularities.

    Returns:
        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
    """

    def alpha_bar(time_step):
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)


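# For reference: betas_for_alpha_bar() above discretizes the squared-cosine schedule
#     alpha_bar(t) = cos((t + 0.008) / 1.008 * pi / 2) ** 2
# into per-step betas clipped at max_beta:
#     beta_i = min(1 - alpha_bar((i + 1) / N) / alpha_bar(i / N), max_beta)
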
            class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         | 
| 58 | 
            +
                """
         | 
| 59 | 
            +
                Implements Algorithm 2 (Heun steps) from Karras et al. (2022). for discrete beta schedules. 
         | 
| 60 | 
            +
                Based on the original k-diffusion implementation by Katherine Crowson:
         | 
| 61 | 
            +
                https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/
         | 
| 62 | 
            +
                k_diffusion/sampling.py#L90
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                [`~ConfigMixin`] takes care of storing all config attributes that are passed
         | 
| 65 | 
            +
                in the scheduler's `__init__` function, such as `num_train_timesteps`. 
         | 
| 66 | 
            +
                They can be accessed via `scheduler.config.num_train_timesteps`.
         | 
| 67 | 
            +
                [`SchedulerMixin`] provides general loading and saving functionality via the
         | 
| 68 | 
            +
                [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions.
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                Args:
         | 
| 71 | 
            +
                    num_train_timesteps (`int`):
         | 
| 72 | 
            +
                        number of diffusion steps used to train the model. 
         | 
| 73 | 
            +
                    beta_start (`float`):
         | 
| 74 | 
            +
                        the starting `beta` value of inference. 
         | 
| 75 | 
            +
                    beta_end (`float`):
         | 
| 76 | 
            +
                        the final `beta` value. 
         | 
| 77 | 
            +
                    beta_schedule (`str`):
         | 
| 78 | 
            +
                        the beta schedule, a mapping from a beta range to a sequence of betas for stepping
         | 
| 79 | 
            +
                        the model. Choose from `linear` or `scaled_linear`.
         | 
| 80 | 
            +
                    trained_betas (`np.ndarray`, optional):
         | 
| 81 | 
            +
                        option to pass an array of betas directly to the constructor to bypass 
         | 
| 82 | 
            +
                        `beta_start`, `beta_end` etc.
         | 
| 83 | 
            +
                        options to clip the variance used when adding noise to the denoised sample.
         | 
| 84 | 
            +
                        Choose from `fixed_small`, `fixed_small_log`, `fixed_large`, 
         | 
| 85 | 
            +
                        `fixed_large_log`, `learned` or `learned_range`.
         | 
| 86 | 
            +
                    prediction_type (`str`, default `epsilon`, optional):
         | 
| 87 | 
            +
                        prediction type of the scheduler function, one of 
         | 
| 88 | 
            +
                        `epsilon` (predicting the noise of the diffusion process), 
         | 
| 89 | 
            +
                        `sample` (directly predicting the noisy sample`), or 
         | 
| 90 | 
            +
                        `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf)
         | 
| 91 | 
            +
                """
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                _compatibles = [e.name for e in KarrasDiffusionSchedulers]
         | 
| 94 | 
            +
                order = 2
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                @register_to_config
         | 
| 97 | 
            +
                def __init__(
         | 
| 98 | 
            +
                    self,
         | 
| 99 | 
            +
                    num_train_timesteps: int = 1000,
         | 
| 100 | 
            +
                    beta_start: float = 0.00085,  # sensible defaults
         | 
| 101 | 
            +
                    beta_end: float = 0.012,
         | 
| 102 | 
            +
                    beta_schedule: str = "linear",
         | 
| 103 | 
            +
                    trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         | 
| 104 | 
            +
                    prediction_type: str = "epsilon",
         | 
| 105 | 
            +
                    use_karras_sigmas: Optional[bool] = False,
         | 
| 106 | 
            +
                ):
         | 
| 107 | 
            +
                    if trained_betas is not None:
         | 
| 108 | 
            +
                        self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         | 
| 109 | 
            +
                    elif beta_schedule == "linear":
         | 
| 110 | 
            +
                        self.betas = torch.linspace(
         | 
| 111 | 
            +
                            beta_start, beta_end, num_train_timesteps, dtype=torch.float32
         | 
| 112 | 
            +
                        )
         | 
| 113 | 
            +
                    elif beta_schedule == "scaled_linear":
         | 
| 114 | 
            +
                        # this schedule is very specific to the latent diffusion model.
         | 
| 115 | 
            +
                        self.betas = (
         | 
| 116 | 
            +
                            torch.linspace(
         | 
| 117 | 
            +
                                beta_start ** 0.5, beta_end ** 0.5,
         | 
| 118 | 
            +
                                num_train_timesteps, dtype=torch.float32
         | 
| 119 | 
            +
                            ) ** 2
         | 
| 120 | 
            +
                        )
         | 
| 121 | 
            +
                    elif beta_schedule == "squaredcos_cap_v2":
         | 
| 122 | 
            +
                        # Glide cosine schedule
         | 
| 123 | 
            +
                        self.betas = betas_for_alpha_bar(num_train_timesteps)
         | 
| 124 | 
            +
                    else:
         | 
| 125 | 
            +
                        raise NotImplementedError(
         | 
| 126 | 
            +
                            f"{beta_schedule} does is not implemented for {self.__class__}"
         | 
| 127 | 
            +
                        )
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                    self.alphas = 1.0 - self.betas
         | 
| 130 | 
            +
                    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
         | 
| 131 | 
            +
             | 
| 132 | 
            +
                    # set all values
         | 
| 133 | 
            +
                    self.use_karras_sigmas = use_karras_sigmas
         | 
| 134 | 
            +
                    self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                def index_for_timestep(self, timestep):
         | 
| 137 | 
            +
                    """Get the first / last index at which self.timesteps == timestep
         | 
| 138 | 
            +
                    """
         | 
| 139 | 
            +
                    assert len(timestep.shape) < 2
         | 
| 140 | 
            +
                    avail_timesteps = self.timesteps.reshape(1, -1).to(timestep.device)
         | 
| 141 | 
            +
                    mask = (avail_timesteps == timestep.reshape(-1, 1))
         | 
| 142 | 
            +
                    assert (mask.sum(dim=1) != 0).all(), f"timestep: {timestep.tolist()}"
         | 
| 143 | 
            +
                    mask = mask.cpu() * torch.arange(mask.shape[1]).reshape(1, -1)
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                    if self.state_in_first_order:
         | 
| 146 | 
            +
                        return mask.argmax(dim=1).numpy()
         | 
| 147 | 
            +
                    else:
         | 
| 148 | 
            +
                        return mask.argmax(dim=1).numpy() - 1
         | 
| 149 | 
            +
             | 
| 150 | 
            +
                def scale_model_input(
         | 
| 151 | 
            +
                    self,
         | 
| 152 | 
            +
                    sample: torch.FloatTensor,
         | 
| 153 | 
            +
                    timestep: Union[float, torch.FloatTensor],
         | 
| 154 | 
            +
                ) -> torch.FloatTensor:
         | 
| 155 | 
            +
                    """
         | 
| 156 | 
            +
                    Ensures interchangeability with schedulers that need to scale the 
         | 
| 157 | 
            +
                    denoising model input depending on the current timestep.
         | 
| 158 | 
            +
                    Args:
         | 
| 159 | 
            +
                        sample (`torch.FloatTensor`): input sample 
         | 
| 160 | 
            +
                        timestep (`int`, optional): current timestep
         | 
| 161 | 
            +
                    Returns:
         | 
| 162 | 
            +
                        `torch.FloatTensor`: scaled input sample
         | 
| 163 | 
            +
                    """
         | 
| 164 | 
            +
                    if not torch.is_tensor(timestep):
         | 
| 165 | 
            +
                        timestep = torch.tensor(timestep)
         | 
| 166 | 
            +
                    timestep = timestep.to(sample.device).reshape(-1)
         | 
| 167 | 
            +
                    step_index = self.index_for_timestep(timestep)
         | 
| 168 | 
            +
             | 
| 169 | 
            +
                    sigma = self.sigmas[step_index].reshape(-1, 1, 1, 1).to(sample.device)
         | 
| 170 | 
            +
                    sample = sample / ((sigma ** 2 + 1) ** 0.5)  # equivalent to sample * sqrt(alpha_prod)
         | 
| 171 | 
            +
                    return sample
         | 
| 172 | 
            +
             | 
| 173 | 
            +
                def set_timesteps(
         | 
| 174 | 
            +
                    self,
         | 
| 175 | 
            +
                    num_inference_steps: int,
         | 
| 176 | 
            +
                    device: Union[str, torch.device] = None,
         | 
| 177 | 
            +
                    num_train_timesteps: Optional[int] = None,
         | 
| 178 | 
            +
                ):
         | 
| 179 | 
            +
                    """
         | 
| 180 | 
            +
                    Sets the timesteps used for the diffusion chain. 
         | 
| 181 | 
            +
                    Supporting function to be run before inference.
         | 
| 182 | 
            +
             | 
| 183 | 
            +
                    Args:
         | 
| 184 | 
            +
                        num_inference_steps (`int`):
         | 
| 185 | 
            +
                            the number of diffusion steps used when generating samples
         | 
| 186 | 
            +
                            with a pre-trained model.
         | 
| 187 | 
            +
                        device (`str` or `torch.device`, optional):
         | 
| 188 | 
            +
                            the device to which the timesteps should be moved.
         | 
| 189 | 
            +
                            If `None`, the timesteps are not moved.
         | 
| 190 | 
            +
                    """
         | 
| 191 | 
            +
                    self.num_inference_steps = num_inference_steps
         | 
| 192 | 
            +
                    num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps
         | 
| 193 | 
            +
             | 
| 194 | 
            +
                    timesteps = np.linspace(
         | 
| 195 | 
            +
                        0, num_train_timesteps - 1, num_inference_steps, dtype=float
         | 
| 196 | 
            +
                    )[::-1].copy()
         | 
| 197 | 
            +
             | 
| 198 | 
            +
                    # sigma^2 = (1 - alpha_cumprod) / alpha_cumprod
         | 
| 199 | 
            +
                    sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
         | 
| 200 | 
            +
                    log_sigmas = np.log(sigmas)
         | 
| 201 | 
            +
                    sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
         | 
| 202 | 
            +
             | 
| 203 | 
            +
                    if self.use_karras_sigmas:
         | 
| 204 | 
            +
                        sigmas = self._convert_to_karras(
         | 
| 205 | 
            +
                            in_sigmas=sigmas, num_inference_steps=self.num_inference_steps
         | 
| 206 | 
            +
                        )
         | 
| 207 | 
            +
                        timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
         | 
| 208 | 
            +
             | 
| 209 | 
            +
                    sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
         | 
| 210 | 
            +
                    sigmas = torch.from_numpy(sigmas).to(device=device)
         | 
| 211 | 
            +
                    self.sigmas = torch.cat(
         | 
| 212 | 
            +
                        [sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]]
         | 
| 213 | 
            +
                    )
         | 
| 214 | 
            +
             | 
| 215 | 
            +
                    # standard deviation of the initial noise distribution
         | 
| 216 | 
            +
                    self.init_noise_sigma = self.sigmas.max()
         | 
| 217 | 
            +
             | 
| 218 | 
            +
                    timesteps = torch.from_numpy(timesteps)
         | 
| 219 | 
            +
                    timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
         | 
| 220 | 
            +
                    if 'mps' in str(device):
         | 
| 221 | 
            +
                        timesteps = timesteps.float()
         | 
| 222 | 
            +
                    self.timesteps = timesteps.to(device)
         | 
| 223 | 
            +
             | 
| 224 | 
            +
                    # empty dt and derivative
         | 
| 225 | 
            +
                    self.prev_derivative = None
         | 
| 226 | 
            +
                    self.dt = None
         | 
| 227 | 
            +
             | 
| 228 | 
            +
                def _sigma_to_t(self, sigma, log_sigmas):
         | 
| 229 | 
            +
                    # get log sigma
         | 
| 230 | 
            +
                    log_sigma = np.log(sigma)
         | 
| 231 | 
            +
             | 
| 232 | 
            +
                    # get distribution
         | 
| 233 | 
            +
                    dists = log_sigma - log_sigmas[:, np.newaxis]
         | 
| 234 | 
            +
             | 
| 235 | 
            +
                    # get sigmas range
         | 
| 236 | 
            +
                    low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(
         | 
| 237 | 
            +
                        max=log_sigmas.shape[0] - 2
         | 
| 238 | 
            +
                    )
         | 
| 239 | 
            +
                    high_idx = low_idx + 1
         | 
| 240 | 
            +
             | 
| 241 | 
            +
                    low = log_sigmas[low_idx]
         | 
| 242 | 
            +
                    high = log_sigmas[high_idx]
         | 
| 243 | 
            +
             | 
| 244 | 
            +
                    # interpolate sigmas
         | 
| 245 | 
            +
                    w = (low - log_sigma) / (low - high)
         | 
| 246 | 
            +
                    w = np.clip(w, 0, 1)
         | 
| 247 | 
            +
             | 
| 248 | 
            +
                    # transform interpolation to time range
         | 
| 249 | 
            +
                    t = (1 - w) * low_idx + w * high_idx
         | 
| 250 | 
            +
                    t = t.reshape(sigma.shape)
         | 
| 251 | 
            +
                    return t
         | 
| 252 | 
            +
             | 
| 253 | 
            +
                def _convert_to_karras(
         | 
| 254 | 
            +
                    self, in_sigmas: torch.FloatTensor, num_inference_steps
         | 
| 255 | 
            +
                ) -> torch.FloatTensor:
         | 
| 256 | 
            +
                    """Constructs the noise schedule of Karras et al. (2022)."""
         | 
| 257 | 
            +
             | 
| 258 | 
            +
                    sigma_min: float = in_sigmas[-1].item()
         | 
| 259 | 
            +
                    sigma_max: float = in_sigmas[0].item()
         | 
| 260 | 
            +
             | 
| 261 | 
            +
                    rho = 7.0  # 7.0 is the value used in the paper
         | 
| 262 | 
            +
                    ramp = np.linspace(0, 1, num_inference_steps)
         | 
| 263 | 
            +
                    min_inv_rho = sigma_min ** (1 / rho)
         | 
| 264 | 
            +
                    max_inv_rho = sigma_max ** (1 / rho)
         | 
| 265 | 
            +
                    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         | 
| 266 | 
            +
                    return sigmas
         | 
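For reference, the snippet below recomputes the same rho-spaced schedule that `_convert_to_karras` implements, outside of the scheduler class. The sigma endpoints are made up for illustration; `rho = 7.0` matches the value hard-coded above.

```python
# Standalone recomputation of the rho-spaced Karras schedule from _convert_to_karras.
# The sigma range here is illustrative only.
import numpy as np

sigma_min, sigma_max, rho, num_steps = 0.03, 14.6, 7.0, 10
ramp = np.linspace(0, 1, num_steps)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

print(sigmas[0], sigmas[-1])              # endpoints recover sigma_max and sigma_min
print(bool(np.all(np.diff(sigmas) < 0)))  # True: sigmas decrease monotonically
```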
| 267 | 
            +
             | 
| 268 | 
            +
                @property
         | 
| 269 | 
            +
                def state_in_first_order(self):
         | 
| 270 | 
            +
                    return self.dt is None
         | 
| 271 | 
            +
             | 
| 272 | 
            +
                def step(
         | 
| 273 | 
            +
                    self,
         | 
| 274 | 
            +
                    model_output: Union[torch.FloatTensor, np.ndarray],
         | 
| 275 | 
            +
                    timestep: Union[float, torch.FloatTensor],
         | 
| 276 | 
            +
                    sample: Union[torch.FloatTensor, np.ndarray],
         | 
| 277 | 
            +
                    return_dict: bool = True,
         | 
| 278 | 
            +
                ) -> Union[SchedulerOutput, Tuple]:
         | 
| 279 | 
            +
                    """
         | 
| 280 | 
            +
                    Predict the sample at the previous timestep by reversing the SDE. 
         | 
| 281 | 
            +
                    Core function to propagate the diffusion process from the learned 
         | 
| 282 | 
            +
                    model outputs (most often the predicted noise).
         | 
| 283 | 
            +
                    Args:
         | 
| 284 | 
            +
                        model_output (`torch.FloatTensor` or `np.ndarray`): 
         | 
| 285 | 
            +
                            direct output from learned diffusion model. 
         | 
| 286 | 
            +
                        timestep (`int`): 
         | 
| 287 | 
            +
                            current discrete timestep in the diffusion chain. 
         | 
| 288 | 
            +
                        sample (`torch.FloatTensor` or `np.ndarray`):
         | 
| 289 | 
            +
                            current instance of sample being created by diffusion process.
         | 
| 290 | 
            +
                        return_dict (`bool`): 
         | 
| 291 | 
            +
                            option for returning tuple rather than SchedulerOutput class
         | 
| 292 | 
            +
                    Returns:
         | 
| 293 | 
            +
                        [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
         | 
| 294 | 
            +
                        [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` 
         | 
| 295 | 
            +
                            is True, otherwise a `tuple`. When returning a tuple,
         | 
| 296 | 
            +
                            the first element is the sample tensor.
         | 
| 297 | 
            +
                    """
         | 
| 298 | 
            +
                    if not torch.is_tensor(timestep):
         | 
| 299 | 
            +
                        timestep = torch.tensor(timestep)
         | 
| 300 | 
            +
                    timestep = timestep.reshape(-1).to(sample.device)
         | 
| 301 | 
            +
                    step_index = self.index_for_timestep(timestep)
         | 
| 302 | 
            +
             | 
| 303 | 
            +
                    if self.state_in_first_order:
         | 
| 304 | 
            +
                        sigma = self.sigmas[step_index]
         | 
| 305 | 
            +
                        sigma_next = self.sigmas[step_index + 1]
         | 
| 306 | 
            +
                    else:
         | 
| 307 | 
            +
                        # 2nd order / Heun's method
         | 
| 308 | 
            +
                        sigma = self.sigmas[step_index - 1]
         | 
| 309 | 
            +
                        sigma_next = self.sigmas[step_index]
         | 
| 310 | 
            +
             | 
| 311 | 
            +
                    sigma = sigma.reshape(-1, 1, 1, 1).to(sample.device)
         | 
| 312 | 
            +
                    sigma_next = sigma_next.reshape(-1, 1, 1, 1).to(sample.device)
         | 
| 313 | 
            +
                    sigma_input = sigma if self.state_in_first_order else sigma_next
         | 
| 314 | 
            +
             | 
| 315 | 
            +
                    # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
         | 
| 316 | 
            +
                    if self.config.prediction_type == "epsilon":
         | 
| 317 | 
            +
                        pred_original_sample = sample - sigma_input * model_output
         | 
| 318 | 
            +
                    elif self.config.prediction_type == "v_prediction":
         | 
| 319 | 
            +
                        alpha_prod = 1 / (sigma_input ** 2 + 1)
         | 
| 320 | 
            +
                        pred_original_sample = (
         | 
| 321 | 
            +
                            sample * alpha_prod - model_output * (sigma_input * alpha_prod ** .5)
         | 
| 322 | 
            +
                        )
         | 
| 323 | 
            +
                    elif self.config.prediction_type == "sample":
         | 
| 324 | 
            +
                        raise NotImplementedError("prediction_type not implemented yet: sample")
         | 
| 325 | 
            +
                    else:
         | 
| 326 | 
            +
                        raise ValueError(
         | 
| 327 | 
            +
                            f"prediction_type given as {self.config.prediction_type} "
         | 
| 328 | 
            +
                            "must be one of `epsilon`, or `v_prediction`"
         | 
| 329 | 
            +
                        )
         | 
| 330 | 
            +
             | 
| 331 | 
            +
                    if self.state_in_first_order:
         | 
| 332 | 
            +
                        # 2. Convert to an ODE derivative for 1st order
         | 
| 333 | 
            +
                        derivative = (sample - pred_original_sample) / sigma
         | 
| 334 | 
            +
                        # 3. delta timestep
         | 
| 335 | 
            +
                        dt = sigma_next - sigma
         | 
| 336 | 
            +
             | 
| 337 | 
            +
                        # store for 2nd order step
         | 
| 338 | 
            +
                        self.prev_derivative = derivative
         | 
| 339 | 
            +
                        self.dt = dt
         | 
| 340 | 
            +
                        self.sample = sample
         | 
| 341 | 
            +
                    else:
         | 
| 342 | 
            +
                        # 2. 2nd order / Heun's method
         | 
| 343 | 
            +
                        derivative = (sample - pred_original_sample) / sigma_next
         | 
| 344 | 
            +
                        derivative = (self.prev_derivative + derivative) / 2
         | 
| 345 | 
            +
             | 
| 346 | 
            +
                        # 3. take prev timestep & sample
         | 
| 347 | 
            +
                        dt = self.dt
         | 
| 348 | 
            +
                        sample = self.sample
         | 
| 349 | 
            +
             | 
| 350 | 
            +
                        # free dt and derivative
         | 
| 351 | 
            +
                        # Note, this puts the scheduler in "first order mode"
         | 
| 352 | 
            +
                        self.prev_derivative = None
         | 
| 353 | 
            +
                        self.dt = None
         | 
| 354 | 
            +
                        self.sample = None
         | 
| 355 | 
            +
             | 
| 356 | 
            +
                    prev_sample = sample + derivative * dt
         | 
| 357 | 
            +
             | 
| 358 | 
            +
                    if not return_dict:
         | 
| 359 | 
            +
                        return (prev_sample,)
         | 
| 360 | 
            +
             | 
| 361 | 
            +
                    return SchedulerOutput(prev_sample=prev_sample)
         | 
| 362 | 
            +
             | 
| 363 | 
            +
                def add_noise(
         | 
| 364 | 
            +
                    self,
         | 
| 365 | 
            +
                    original_samples: torch.FloatTensor,
         | 
| 366 | 
            +
                    noise: torch.FloatTensor,
         | 
| 367 | 
            +
                    timesteps: torch.FloatTensor,
         | 
| 368 | 
            +
                ) -> torch.FloatTensor:
         | 
| 369 | 
            +
             | 
| 370 | 
            +
                    # Make sure sigmas and timesteps have the same device and dtype as original_samples
         | 
| 371 | 
            +
                    self.sigmas = self.sigmas.to(
         | 
| 372 | 
            +
                        device=original_samples.device, dtype=original_samples.dtype
         | 
| 373 | 
            +
                    )
         | 
| 374 | 
            +
                    self.timesteps = self.timesteps.to(original_samples.device)
         | 
| 375 | 
            +
                    timesteps = timesteps.to(original_samples.device)
         | 
| 376 | 
            +
             | 
| 377 | 
            +
                    step_indices = self.index_for_timestep(timesteps)
         | 
| 378 | 
            +
             | 
| 379 | 
            +
                    sigma = self.sigmas[step_indices].flatten()
         | 
| 380 | 
            +
                    while len(sigma.shape) < len(original_samples.shape):
         | 
| 381 | 
            +
                        sigma = sigma.unsqueeze(-1)
         | 
| 382 | 
            +
             | 
| 383 | 
            +
                    noisy_samples = original_samples + noise * sigma
         | 
| 384 | 
            +
                    return noisy_samples
         | 
| 385 | 
            +
             | 
| 386 | 
            +
                def __len__(self):
         | 
| 387 | 
            +
                    return self.config.num_train_timesteps
         | 
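The scheduler above is driven through the usual `set_timesteps` / `scale_model_input` / `step` loop. Below is a minimal sketch of that loop; it assumes the class is exported as `HeunDiscreteScheduler` (as in upstream diffusers) from `diffusers/scheduling_heun_discrete.py`, uses an arbitrary latent shape, and stands in a zero tensor for the denoiser's prediction.

```python
# Minimal sampling-loop sketch. HeunDiscreteScheduler is assumed to be the class
# defined in diffusers/scheduling_heun_discrete.py; the "model" is a placeholder
# that always predicts zero noise, and the latent shape is arbitrary.
import torch
from diffusers.scheduling_heun_discrete import HeunDiscreteScheduler

scheduler = HeunDiscreteScheduler(num_train_timesteps=1000, prediction_type="epsilon")
scheduler.set_timesteps(num_inference_steps=18, device="cpu")

# Start from pure noise scaled to the schedule's initial sigma.
sample = torch.randn(1, 8, 256, 16) * scheduler.init_noise_sigma

# Interior timesteps appear twice: an Euler step followed by the Heun correction.
for t in scheduler.timesteps:
    model_input = scheduler.scale_model_input(sample, t)
    noise_pred = torch.zeros_like(model_input)  # placeholder for the denoiser's epsilon prediction
    sample = scheduler.step(noise_pred, t, sample).prev_sample
```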
    	
        diffusers/utils/configuration_utils.py
    ADDED
    
    | @@ -0,0 +1,647 @@ | |
| 1 | 
            +
            # coding=utf-8
         | 
| 2 | 
            +
            # Copyright 2023 The HuggingFace Inc. team.
         | 
| 3 | 
            +
            # Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
         | 
| 4 | 
            +
            #
         | 
| 5 | 
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 6 | 
            +
            # you may not use this file except in compliance with the License.
         | 
| 7 | 
            +
            # You may obtain a copy of the License at
         | 
| 8 | 
            +
            #
         | 
| 9 | 
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 10 | 
            +
            #
         | 
| 11 | 
            +
            # Unless required by applicable law or agreed to in writing, software
         | 
| 12 | 
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 13 | 
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 14 | 
            +
            # See the License for the specific language governing permissions and
         | 
| 15 | 
            +
            # limitations under the License.
         | 
| 16 | 
            +
            """ ConfigMixin base class and utilities."""
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            import dataclasses
         | 
| 19 | 
            +
            import functools
         | 
| 20 | 
            +
            import importlib
         | 
| 21 | 
            +
            import inspect
         | 
| 22 | 
            +
            import json
         | 
| 23 | 
            +
            import os
         | 
| 24 | 
            +
            import re
         | 
| 25 | 
            +
            from collections import OrderedDict
         | 
| 26 | 
            +
            from pathlib import PosixPath
         | 
| 27 | 
            +
            from typing import Any, Dict, Tuple, Union
         | 
| 28 | 
            +
             | 
| 29 | 
            +
            import numpy as np
         | 
| 30 | 
            +
            from huggingface_hub import hf_hub_download
         | 
| 31 | 
            +
            from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
         | 
| 32 | 
            +
            from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
         | 
| 33 | 
            +
            from requests import HTTPError
         | 
| 34 | 
            +
             | 
| 35 | 
            +
            from .import_utils import DummyObject
         | 
| 36 | 
            +
            from .deprecation_utils import deprecate
         | 
| 37 | 
            +
            from .hub_utils import extract_commit_hash, http_user_agent
         | 
| 38 | 
            +
            from .logging import get_logger
         | 
| 39 | 
            +
            from .constants import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT
         | 
| 40 | 
            +
             | 
| 41 | 
            +
             | 
| 42 | 
            +
            logger = get_logger(__name__)
         | 
| 43 | 
            +
             | 
| 44 | 
            +
            _re_configuration_file = re.compile(r"config\.(.*)\.json")
         | 
| 45 | 
            +
             | 
| 46 | 
            +
             | 
| 47 | 
            +
            class FrozenDict(OrderedDict):
         | 
| 48 | 
            +
                def __init__(self, *args, **kwargs):
         | 
| 49 | 
            +
                    super().__init__(*args, **kwargs)
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                    for key, value in self.items():
         | 
| 52 | 
            +
                        setattr(self, key, value)
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                    self.__frozen = True
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                def __delitem__(self, *args, **kwargs):
         | 
| 57 | 
            +
                    raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                def setdefault(self, *args, **kwargs):
         | 
| 60 | 
            +
                    raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                def pop(self, *args, **kwargs):
         | 
| 63 | 
            +
                    raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                def update(self, *args, **kwargs):
         | 
| 66 | 
            +
                    raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                def __setattr__(self, name, value):
         | 
| 69 | 
            +
                    if hasattr(self, "__frozen") and self.__frozen:
         | 
| 70 | 
            +
                        raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
         | 
| 71 | 
            +
                    super().__setattr__(name, value)
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                def __setitem__(self, name, value):
         | 
| 74 | 
            +
                    if hasattr(self, "__frozen") and self.__frozen:
         | 
| 75 | 
            +
                        raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
         | 
| 76 | 
            +
                    super().__setitem__(name, value)
         | 
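`FrozenDict` is the container that `ConfigMixin.register_to_config` stores its arguments in: keys are mirrored as attributes at construction time, and `pop`, `update`, `setdefault`, and `__delitem__` are disabled. A short sketch of that behaviour, assuming the import path of this repo's vendored copy:

```python
# FrozenDict mirrors its keys as attributes and rejects the usual mutators.
from diffusers.utils.configuration_utils import FrozenDict

cfg = FrozenDict({"beta_start": 0.00085, "beta_end": 0.012})
print(cfg["beta_start"], cfg.beta_start)  # both read paths work: 0.00085 0.00085

try:
    cfg.pop("beta_end")  # pop / update / setdefault / __delitem__ all raise
except Exception as err:
    print(err)
```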
| 77 | 
            +
             | 
| 78 | 
            +
             | 
| 79 | 
            +
            class ConfigMixin:
         | 
| 80 | 
            +
                r"""
         | 
| 81 | 
            +
                Base class for all configuration classes. All configuration parameters are stored under `self.config`. Also
         | 
| 82 | 
            +
                provides the [`~ConfigMixin.from_config`] and [`~ConfigMixin.save_config`] methods for loading, downloading, and
         | 
| 83 | 
            +
                saving classes that inherit from [`ConfigMixin`].
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                Class attributes:
         | 
| 86 | 
            +
                    - **config_name** (`str`) -- A filename under which the config should be stored when calling
         | 
| 87 | 
            +
                      [`~ConfigMixin.save_config`] (should be overridden by parent class).
         | 
| 88 | 
            +
                    - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
         | 
| 89 | 
            +
                      overridden by subclass).
         | 
| 90 | 
            +
                    - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
         | 
| 91 | 
            +
                    - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the `init` function
         | 
| 92 | 
            +
                      should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
         | 
| 93 | 
            +
                      subclass).
         | 
| 94 | 
            +
                """
         | 
| 95 | 
            +
                config_name = None
         | 
| 96 | 
            +
                ignore_for_config = []
         | 
| 97 | 
            +
                has_compatibles = False
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                _deprecated_kwargs = []
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                def register_to_config(self, **kwargs):
         | 
| 102 | 
            +
                    if self.config_name is None:
         | 
| 103 | 
            +
                        raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
         | 
| 104 | 
            +
                    # Special case for `kwargs` used in deprecation warning added to schedulers
         | 
| 105 | 
            +
                    # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
         | 
| 106 | 
            +
                    # or solve in a more general way.
         | 
| 107 | 
            +
                    kwargs.pop("kwargs", None)
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                    if not hasattr(self, "_internal_dict"):
         | 
| 110 | 
            +
                        internal_dict = kwargs
         | 
| 111 | 
            +
                    else:
         | 
| 112 | 
            +
                        previous_dict = dict(self._internal_dict)
         | 
| 113 | 
            +
                        internal_dict = {**self._internal_dict, **kwargs}
         | 
| 114 | 
            +
                        logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                    self._internal_dict = FrozenDict(internal_dict)
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                def __getattr__(self, name: str) -> Any:
         | 
| 119 | 
            +
                    """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
         | 
| 120 | 
            +
                    config attributes directly. See https://github.com/huggingface/diffusers/pull/3129
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                    This function is mostly copied from PyTorch's __getattr__ override:
         | 
| 123 | 
            +
                    https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
         | 
| 124 | 
            +
                    """
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                    is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
         | 
| 127 | 
            +
                    is_attribute = name in self.__dict__
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                    if is_in_config and not is_attribute:
         | 
| 130 | 
            +
                        deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'."
         | 
| 131 | 
            +
                        deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
         | 
| 132 | 
            +
                        return self._internal_dict[name]
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
         | 
| 137 | 
            +
                    """
         | 
| 138 | 
            +
                    Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the
         | 
| 139 | 
            +
                    [`~ConfigMixin.from_config`] class method.
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                    Args:
         | 
| 142 | 
            +
                        save_directory (`str` or `os.PathLike`):
         | 
| 143 | 
            +
                            Directory where the configuration JSON file is saved (will be created if it does not exist).
         | 
| 144 | 
            +
                    """
         | 
| 145 | 
            +
                    if os.path.isfile(save_directory):
         | 
| 146 | 
            +
                        raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
         | 
| 147 | 
            +
             | 
| 148 | 
            +
                    os.makedirs(save_directory, exist_ok=True)
         | 
| 149 | 
            +
             | 
| 150 | 
            +
                    # If we save using the predefined names, we can load using `from_config`
         | 
| 151 | 
            +
                    output_config_file = os.path.join(save_directory, self.config_name)
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                    self.to_json_file(output_config_file)
         | 
| 154 | 
            +
                    logger.info(f"Configuration saved in {output_config_file}")
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                @classmethod
         | 
| 157 | 
            +
                def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
         | 
| 158 | 
            +
                    r"""
         | 
| 159 | 
            +
                    Instantiate a Python class from a config dictionary.
         | 
| 160 | 
            +
             | 
| 161 | 
            +
                    Parameters:
         | 
| 162 | 
            +
                        config (`Dict[str, Any]`):
         | 
| 163 | 
            +
                            A config dictionary from which the Python class is instantiated. Make sure to only load configuration
         | 
| 164 | 
            +
                            files of compatible classes.
         | 
| 165 | 
            +
                        return_unused_kwargs (`bool`, *optional*, defaults to `False`):
         | 
| 166 | 
            +
                            Whether kwargs that are not consumed by the Python class should be returned or not.
         | 
| 167 | 
            +
                        kwargs (remaining dictionary of keyword arguments, *optional*):
         | 
| 168 | 
            +
                            Can be used to update the configuration object (after it is loaded) and initialize the Python class.
         | 
| 169 | 
            +
                            `**kwargs` are passed directly to the underlying scheduler/model's `__init__` method and eventually
         | 
| 170 | 
            +
                            overwrite arguments of the same name in `config`.
         | 
| 171 | 
            +
             | 
| 172 | 
            +
                    Returns:
         | 
| 173 | 
            +
                        [`ModelMixin`] or [`SchedulerMixin`]:
         | 
| 174 | 
            +
                            A model or scheduler object instantiated from a config dictionary.
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                    Examples:
         | 
| 177 | 
            +
             | 
| 178 | 
            +
                    ```python
         | 
| 179 | 
            +
                    >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler
         | 
| 180 | 
            +
             | 
| 181 | 
            +
                    >>> # Download scheduler from huggingface.co and cache.
         | 
| 182 | 
            +
                    >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")
         | 
| 183 | 
            +
             | 
| 184 | 
            +
                    >>> # Instantiate DDIM scheduler class with same config as DDPM
         | 
| 185 | 
            +
                    >>> scheduler = DDIMScheduler.from_config(scheduler.config)
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                    >>> # Instantiate PNDM scheduler class with same config as DDPM
         | 
| 188 | 
            +
                    >>> scheduler = PNDMScheduler.from_config(scheduler.config)
         | 
| 189 | 
            +
                    ```
         | 
| 190 | 
            +
                    """
         | 
| 191 | 
            +
                    # <===== TO BE REMOVED WITH DEPRECATION
         | 
| 192 | 
            +
                    # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
         | 
| 193 | 
            +
                    if "pretrained_model_name_or_path" in kwargs:
         | 
| 194 | 
            +
                        config = kwargs.pop("pretrained_model_name_or_path")
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                    if config is None:
         | 
| 197 | 
            +
                        raise ValueError("Please make sure to provide a config as the first positional argument.")
         | 
| 198 | 
            +
                    # ======>
         | 
| 199 | 
            +
             | 
| 200 | 
            +
                    if not isinstance(config, dict):
         | 
| 201 | 
            +
                        deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`."
         | 
| 202 | 
            +
                        if "Scheduler" in cls.__name__:
         | 
| 203 | 
            +
                            deprecation_message += (
         | 
| 204 | 
            +
                                f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
         | 
| 205 | 
            +
                                " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
         | 
| 206 | 
            +
                                " be removed in v1.0.0."
         | 
| 207 | 
            +
                            )
         | 
| 208 | 
            +
                        elif "Model" in cls.__name__:
         | 
| 209 | 
            +
                            deprecation_message += (
         | 
| 210 | 
            +
                                f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
         | 
| 211 | 
            +
                                f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
         | 
| 212 | 
            +
                                " instead. This functionality will be removed in v1.0.0."
         | 
| 213 | 
            +
                            )
         | 
| 214 | 
            +
                        deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
         | 
| 215 | 
            +
                        config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)
         | 
| 216 | 
            +
             | 
| 217 | 
            +
                    init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)
         | 
| 218 | 
            +
             | 
| 219 | 
            +
                    # Allow dtype to be specified on initialization
         | 
| 220 | 
            +
                    if "dtype" in unused_kwargs:
         | 
| 221 | 
            +
                        init_dict["dtype"] = unused_kwargs.pop("dtype")
         | 
| 222 | 
            +
             | 
| 223 | 
            +
                    # add possible deprecated kwargs
         | 
| 224 | 
            +
                    for deprecated_kwarg in cls._deprecated_kwargs:
         | 
| 225 | 
            +
                        if deprecated_kwarg in unused_kwargs:
         | 
| 226 | 
            +
                            init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)
         | 
| 227 | 
            +
             | 
| 228 | 
            +
                    # Return model and optionally state and/or unused_kwargs
         | 
| 229 | 
            +
                    model = cls(**init_dict)
         | 
| 230 | 
            +
             | 
| 231 | 
            +
                    # make sure to also save config parameters that might be used for compatible classes
         | 
| 232 | 
            +
                    model.register_to_config(**hidden_dict)
         | 
| 233 | 
            +
             | 
| 234 | 
            +
                    # add hidden kwargs of compatible classes to unused_kwargs
         | 
| 235 | 
            +
                    unused_kwargs = {**unused_kwargs, **hidden_dict}
         | 
| 236 | 
            +
             | 
| 237 | 
            +
                    if return_unused_kwargs:
         | 
| 238 | 
            +
                        return (model, unused_kwargs)
         | 
| 239 | 
            +
                    else:
         | 
| 240 | 
            +
                        return model
         | 
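To make the `register_to_config` / `save_config` / `from_config` interplay concrete, here is a toy round trip. It assumes the remainder of this file matches upstream diffusers: the module-level `register_to_config` decorator (used by the scheduler above), the `config` property, `extract_init_dict`, and `to_json_file` are defined further down and are not shown in this excerpt.

```python
# Toy ConfigMixin round trip; the decorator and helpers are assumed from the rest of this file.
import tempfile
from diffusers.utils.configuration_utils import ConfigMixin, register_to_config

class ToyScheduler(ConfigMixin):
    config_name = "scheduler_config.json"  # filename used by save_config / load_config

    @register_to_config
    def __init__(self, num_train_timesteps: int = 1000, beta_start: float = 1e-4):
        pass  # the decorator records the init arguments into the frozen config

sched = ToyScheduler(num_train_timesteps=50)
print(sched.config.num_train_timesteps)  # 50, read back from the registered config

with tempfile.TemporaryDirectory() as tmp:
    sched.save_config(tmp)  # writes <tmp>/scheduler_config.json
    clone = ToyScheduler.from_config(ToyScheduler.load_config(tmp))
print(clone.config.beta_start)  # 0.0001, restored from the saved JSON
```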
| 241 | 
            +
             | 
| 242 | 
            +
                @classmethod
         | 
| 243 | 
            +
                def get_config_dict(cls, *args, **kwargs):
         | 
| 244 | 
            +
                    deprecation_message = (
         | 
| 245 | 
            +
                        f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
         | 
| 246 | 
            +
                        " removed in version v1.0.0"
         | 
| 247 | 
            +
                    )
         | 
| 248 | 
            +
                    deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
         | 
| 249 | 
            +
                    return cls.load_config(*args, **kwargs)
         | 
| 250 | 
            +
             | 
| 251 | 
            +
                @classmethod
         | 
| 252 | 
            +
                def load_config(
         | 
| 253 | 
            +
                    cls,
         | 
| 254 | 
            +
                    pretrained_model_name_or_path: Union[str, os.PathLike],
         | 
| 255 | 
            +
                    return_unused_kwargs=False,
         | 
| 256 | 
            +
                    return_commit_hash=False,
         | 
| 257 | 
            +
                    **kwargs,
         | 
| 258 | 
            +
                ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
         | 
| 259 | 
            +
                    r"""
         | 
| 260 | 
            +
                    Load a model or scheduler configuration.
         | 
| 261 | 
            +
             | 
| 262 | 
            +
                    Parameters:
         | 
| 263 | 
            +
                        pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
         | 
| 264 | 
            +
                            Can be either:
         | 
| 265 | 
            +
             | 
| 266 | 
            +
                                - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
         | 
| 267 | 
            +
                                  the Hub.
         | 
| 268 | 
            +
                                - A path to a *directory* (for example `./my_model_directory`) containing model weights saved with
         | 
| 269 | 
            +
                                  [`~ConfigMixin.save_config`].
         | 
| 270 | 
            +
             | 
| 271 | 
            +
                        cache_dir (`Union[str, os.PathLike]`, *optional*):
         | 
| 272 | 
            +
                            Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
         | 
| 273 | 
            +
                            is not used.
         | 
| 274 | 
            +
                        force_download (`bool`, *optional*, defaults to `False`):
         | 
| 275 | 
            +
                            Whether or not to force the (re-)download of the model weights and configuration files, overriding the
         | 
| 276 | 
            +
                            cached versions if they exist.
         | 
| 277 | 
            +
                        resume_download (`bool`, *optional*, defaults to `False`):
         | 
| 278 | 
            +
                            Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
         | 
| 279 | 
            +
                            incompletely downloaded files are deleted.
         | 
| 280 | 
            +
                        proxies (`Dict[str, str]`, *optional*):
         | 
| 281 | 
            +
                            A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
         | 
| 282 | 
            +
                            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
         | 
| 283 | 
            +
                        output_loading_info(`bool`, *optional*, defaults to `False`):
         | 
| 284 | 
            +
                            Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
         | 
| 285 | 
            +
                        local_files_only (`bool`, *optional*, defaults to `False`):
         | 
| 286 | 
            +
                            Whether to only load local model weights and configuration files or not. If set to `True`, the model
         | 
| 287 | 
            +
                            won't be downloaded from the Hub.
         | 
| 288 | 
            +
                        use_auth_token (`str` or *bool*, *optional*):
         | 
| 289 | 
            +
                            The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
         | 
| 290 | 
            +
                            `diffusers-cli login` (stored in `~/.huggingface`) is used.
         | 
| 291 | 
            +
                        revision (`str`, *optional*, defaults to `"main"`):
         | 
| 292 | 
            +
                            The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
         | 
| 293 | 
            +
                            allowed by Git.
         | 
| 294 | 
            +
                        subfolder (`str`, *optional*, defaults to `""`):
         | 
| 295 | 
            +
                            The subfolder location of a model file within a larger model repository on the Hub or locally.
         | 
| 296 | 
            +
                        return_unused_kwargs (`bool`, *optional*, defaults to `False`):
         | 
| 297 | 
            +
                            Whether unused keyword arguments of the config are returned.
         | 
| 298 | 
            +
                        return_commit_hash (`bool`, *optional*, defaults to `False`):
         | 
| 299 | 
            +
                            Whether the `commit_hash` of the loaded configuration is returned.
         | 
| 300 | 
            +
             | 
| 301 | 
            +
                    Returns:
         | 
| 302 | 
            +
                        `dict`:
         | 
| 303 | 
            +
                            A dictionary of all the parameters stored in a JSON configuration file.
         | 
| 304 | 
            +
             | 
| 305 | 
            +
                    """
         | 
| 306 | 
            +
                    cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
         | 
| 307 | 
            +
                    force_download = kwargs.pop("force_download", False)
         | 
| 308 | 
            +
                    resume_download = kwargs.pop("resume_download", False)
         | 
| 309 | 
            +
                    proxies = kwargs.pop("proxies", None)
         | 
| 310 | 
            +
                    use_auth_token = kwargs.pop("use_auth_token", None)
         | 
| 311 | 
            +
                    local_files_only = kwargs.pop("local_files_only", False)
         | 
| 312 | 
            +
                    revision = kwargs.pop("revision", None)
         | 
| 313 | 
            +
                    _ = kwargs.pop("mirror", None)
         | 
| 314 | 
            +
                    subfolder = kwargs.pop("subfolder", None)
         | 
| 315 | 
            +
                    user_agent = kwargs.pop("user_agent", {})
         | 
| 316 | 
            +
             | 
| 317 | 
            +
                    user_agent = {**user_agent, "file_type": "config"}
         | 
| 318 | 
            +
                    user_agent = http_user_agent(user_agent)
         | 
| 319 | 
            +
             | 
| 320 | 
            +
                    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
         | 
| 321 | 
            +
             | 
| 322 | 
            +
                    if cls.config_name is None:
         | 
| 323 | 
            +
                        raise ValueError(
         | 
| 324 | 
            +
                            "`self.config_name` is not defined. Note that one should not load a config from "
         | 
| 325 | 
            +
                            "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
         | 
| 326 | 
            +
                        )
         | 
| 327 | 
            +
             | 
| 328 | 
            +
                    if os.path.isfile(pretrained_model_name_or_path):
         | 
| 329 | 
            +
                        config_file = pretrained_model_name_or_path
         | 
| 330 | 
            +
                    elif os.path.isdir(pretrained_model_name_or_path):
         | 
| 331 | 
            +
                        if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
         | 
| 332 | 
            +
                            # Load from a PyTorch checkpoint
         | 
| 333 | 
            +
                            config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
         | 
| 334 | 
            +
                        elif subfolder is not None and os.path.isfile(
         | 
| 335 | 
            +
                            os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
         | 
| 336 | 
            +
                        ):
         | 
| 337 | 
            +
                            config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
         | 
| 338 | 
            +
                        else:
         | 
| 339 | 
            +
                            raise EnvironmentError(
         | 
| 340 | 
            +
                                f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
         | 
| 341 | 
            +
                            )
         | 
| 342 | 
            +
                    else:
         | 
| 343 | 
            +
                        try:
         | 
| 344 | 
            +
                            # Load from URL or cache if already cached
         | 
| 345 | 
            +
                            config_file = hf_hub_download(
         | 
| 346 | 
            +
                                pretrained_model_name_or_path,
         | 
| 347 | 
            +
                                filename=cls.config_name,
         | 
| 348 | 
            +
                                cache_dir=cache_dir,
         | 
| 349 | 
            +
                                force_download=force_download,
         | 
| 350 | 
            +
                                proxies=proxies,
         | 
| 351 | 
            +
                                resume_download=resume_download,
         | 
| 352 | 
            +
                                local_files_only=local_files_only,
         | 
| 353 | 
            +
                                use_auth_token=use_auth_token,
         | 
| 354 | 
            +
                                user_agent=user_agent,
         | 
| 355 | 
            +
                                subfolder=subfolder,
         | 
| 356 | 
            +
                                revision=revision,
         | 
| 357 | 
            +
                            )
         | 
| 358 | 
            +
                        except RepositoryNotFoundError:
         | 
| 359 | 
            +
                            raise EnvironmentError(
         | 
| 360 | 
            +
                                f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
         | 
| 361 | 
            +
                                " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
         | 
| 362 | 
            +
                                " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
         | 
| 363 | 
            +
                                " login`."
         | 
| 364 | 
            +
                            )
         | 
| 365 | 
            +
                        except RevisionNotFoundError:
         | 
| 366 | 
            +
                            raise EnvironmentError(
         | 
| 367 | 
            +
                                f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
         | 
| 368 | 
            +
                                " this model name. Check the model page at"
         | 
| 369 | 
            +
                                f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
         | 
| 370 | 
            +
                            )
         | 
| 371 | 
            +
                        except EntryNotFoundError:
         | 
| 372 | 
            +
                            raise EnvironmentError(
         | 
| 373 | 
            +
                                f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
         | 
| 374 | 
            +
                            )
         | 
| 375 | 
            +
                        except HTTPError as err:
         | 
| 376 | 
            +
                            raise EnvironmentError(
         | 
| 377 | 
            +
                                "There was a specific connection error when trying to load"
         | 
| 378 | 
            +
                                f" {pretrained_model_name_or_path}:\n{err}"
         | 
| 379 | 
            +
                            )
         | 
| 380 | 
            +
                        except ValueError:
         | 
| 381 | 
            +
                            raise EnvironmentError(
         | 
| 382 | 
            +
                                f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
         | 
| 383 | 
            +
                                f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
         | 
| 384 | 
            +
                                f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
         | 
| 385 | 
            +
                                " run the library in offline mode at"
         | 
| 386 | 
            +
                                " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
         | 
| 387 | 
            +
                            )
         | 
| 388 | 
            +
                        except EnvironmentError:
         | 
| 389 | 
            +
                            raise EnvironmentError(
         | 
| 390 | 
            +
                                f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
         | 
| 391 | 
            +
                                "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
         | 
| 392 | 
            +
                                f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
         | 
| 393 | 
            +
                                f"containing a {cls.config_name} file"
         | 
| 394 | 
            +
                            )
         | 
| 395 | 
            +
             | 
| 396 | 
            +
                    try:
         | 
| 397 | 
            +
                        # Load config dict
         | 
| 398 | 
            +
                        config_dict = cls._dict_from_json_file(config_file)
         | 
| 399 | 
            +
             | 
| 400 | 
            +
                        commit_hash = extract_commit_hash(config_file)
         | 
| 401 | 
            +
                    except (json.JSONDecodeError, UnicodeDecodeError):
         | 
| 402 | 
            +
                        raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
         | 
| 403 | 
            +
             | 
| 404 | 
            +
                    if not (return_unused_kwargs or return_commit_hash):
         | 
| 405 | 
            +
                        return config_dict
         | 
| 406 | 
            +
             | 
| 407 | 
            +
                    outputs = (config_dict,)
         | 
| 408 | 
            +
             | 
| 409 | 
            +
                    if return_unused_kwargs:
         | 
| 410 | 
            +
                        outputs += (kwargs,)
         | 
| 411 | 
            +
             | 
| 412 | 
            +
                    if return_commit_hash:
         | 
| 413 | 
            +
                        outputs += (commit_hash,)
         | 
| 414 | 
            +
             | 
| 415 | 
            +
                    return outputs
         | 
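
The `load_config` path above resolves the JSON config either from a local file/directory or via `hf_hub_download`, and returns the raw dictionary (plus, optionally, the unused kwargs and the commit hash). Below is a minimal sketch of calling it on the JSON file shipped with this Space; the import and the assumption that this bundled copy exports `UNet2DConditionModel` are mine, not part of the diff:

```python
# Sketch only: load the raw config dict from the JSON file bundled with this
# Space. `UNet2DConditionModel` and the module layout are assumed from the
# file listing of this commit, not verified here.
from diffusers.models import UNet2DConditionModel

config_dict, commit_hash = UNet2DConditionModel.load_config(
    "tango_diffusion_light.json",  # local file, so nothing is downloaded
    return_commit_hash=True,       # local files are not cache snapshots -> None
)
print(type(config_dict), commit_hash)  # <class 'dict'> None
```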
| 416 | 
            +
             | 
| 417 | 
            +
                @staticmethod
         | 
| 418 | 
            +
                def _get_init_keys(cls):
         | 
| 419 | 
            +
                    return set(dict(inspect.signature(cls.__init__).parameters).keys())
         | 
| 420 | 
            +
             | 
| 421 | 
            +
                @classmethod
         | 
| 422 | 
            +
                def extract_init_dict(cls, config_dict, **kwargs):
         | 
| 423 | 
            +
                    # 0. Copy origin config dict
         | 
| 424 | 
            +
                    original_dict = dict(config_dict.items())
         | 
| 425 | 
            +
             | 
| 426 | 
            +
                    # 1. Retrieve expected config attributes from __init__ signature
         | 
| 427 | 
            +
                    expected_keys = cls._get_init_keys(cls)
         | 
| 428 | 
            +
                    expected_keys.remove("self")
         | 
| 429 | 
            +
                    # remove general kwargs if present in dict
         | 
| 430 | 
            +
                    if "kwargs" in expected_keys:
         | 
| 431 | 
            +
                        expected_keys.remove("kwargs")
         | 
| 432 | 
            +
                    # remove flax internal keys
         | 
| 433 | 
            +
                    if hasattr(cls, "_flax_internal_args"):
         | 
| 434 | 
            +
                        for arg in cls._flax_internal_args:
         | 
| 435 | 
            +
                            expected_keys.remove(arg)
         | 
| 436 | 
            +
             | 
| 437 | 
            +
                    # 2. Remove attributes that cannot be expected from expected config attributes
         | 
| 438 | 
            +
                    # remove keys to be ignored
         | 
| 439 | 
            +
                    if len(cls.ignore_for_config) > 0:
         | 
| 440 | 
            +
                        expected_keys = expected_keys - set(cls.ignore_for_config)
         | 
| 441 | 
            +
             | 
| 442 | 
            +
                    # load diffusers library to import compatible and original scheduler
         | 
| 443 | 
            +
                    diffusers_library = importlib.import_module(__name__.split(".")[0])
         | 
| 444 | 
            +
             | 
| 445 | 
            +
                    if cls.has_compatibles:
         | 
| 446 | 
            +
                        compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]
         | 
| 447 | 
            +
                    else:
         | 
| 448 | 
            +
                        compatible_classes = []
         | 
| 449 | 
            +
             | 
| 450 | 
            +
                    expected_keys_comp_cls = set()
         | 
| 451 | 
            +
                    for c in compatible_classes:
         | 
| 452 | 
            +
                        expected_keys_c = cls._get_init_keys(c)
         | 
| 453 | 
            +
                        expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
         | 
| 454 | 
            +
                    expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
         | 
| 455 | 
            +
                    config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}
         | 
| 456 | 
            +
             | 
| 457 | 
            +
                    # remove attributes from orig class that cannot be expected
         | 
| 458 | 
            +
                    orig_cls_name = config_dict.pop("_class_name", cls.__name__)
         | 
| 459 | 
            +
                    if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name):
         | 
| 460 | 
            +
                        orig_cls = getattr(diffusers_library, orig_cls_name)
         | 
| 461 | 
            +
                        unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
         | 
| 462 | 
            +
                        config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
         | 
| 463 | 
            +
             | 
| 464 | 
            +
                    # remove private attributes
         | 
| 465 | 
            +
                    config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
         | 
| 466 | 
            +
             | 
| 467 | 
            +
                    # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
         | 
| 468 | 
            +
                    init_dict = {}
         | 
| 469 | 
            +
                    for key in expected_keys:
         | 
| 470 | 
            +
                        # if config param is passed to kwarg and is present in config dict
         | 
| 471 | 
            +
                        # it should overwrite existing config dict key
         | 
| 472 | 
            +
                        if key in kwargs and key in config_dict:
         | 
| 473 | 
            +
                            config_dict[key] = kwargs.pop(key)
         | 
| 474 | 
            +
             | 
| 475 | 
            +
                        if key in kwargs:
         | 
| 476 | 
            +
                            # overwrite key
         | 
| 477 | 
            +
                            init_dict[key] = kwargs.pop(key)
         | 
| 478 | 
            +
                        elif key in config_dict:
         | 
| 479 | 
            +
                            # use value from config dict
         | 
| 480 | 
            +
                            init_dict[key] = config_dict.pop(key)
         | 
| 481 | 
            +
             | 
| 482 | 
            +
                    # 4. Give nice warning if unexpected values have been passed
         | 
| 483 | 
            +
                    if len(config_dict) > 0:
         | 
| 484 | 
            +
                        logger.warning(
         | 
| 485 | 
            +
                            f"The config attributes {config_dict} were passed to {cls.__name__}, "
         | 
| 486 | 
            +
                            "but are not expected and will be ignored. Please verify your "
         | 
| 487 | 
            +
                            f"{cls.config_name} configuration file."
         | 
| 488 | 
            +
                        )
         | 
| 489 | 
            +
             | 
| 490 | 
            +
                    # 5. Give nice info if config attributes are initialized to default because they have not been passed
         | 
| 491 | 
            +
                    passed_keys = set(init_dict.keys())
         | 
| 492 | 
            +
                    if len(expected_keys - passed_keys) > 0:
         | 
| 493 | 
            +
                        logger.info(
         | 
| 494 | 
            +
                            f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
         | 
| 495 | 
            +
                        )
         | 
| 496 | 
            +
             | 
| 497 | 
            +
                    # 6. Define unused keyword arguments
         | 
| 498 | 
            +
                    unused_kwargs = {**config_dict, **kwargs}
         | 
| 499 | 
            +
             | 
| 500 | 
            +
                    # 7. Define "hidden" config parameters that were saved for compatible classes
         | 
| 501 | 
            +
                    hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}
         | 
| 502 | 
            +
             | 
| 503 | 
            +
                    return init_dict, unused_kwargs, hidden_config_dict
         | 
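
`extract_init_dict` above splits a loaded config into the keyword arguments accepted by the class's `__init__` and the leftover, "unused" entries. The standalone toy below (plain Python, not this file's API) mirrors the core idea:

```python
# Toy illustration of the split performed by `extract_init_dict`.
import inspect


class ToyModel:
    def __init__(self, hidden_size=32, num_layers=2):
        self.hidden_size, self.num_layers = hidden_size, num_layers


config = {"hidden_size": 64, "num_layers": 4, "_class_name": "ToyModel", "dropout": 0.1}
expected = set(inspect.signature(ToyModel.__init__).parameters) - {"self"}
init_dict = {k: v for k, v in config.items() if k in expected}
unused = {k: v for k, v in config.items() if k not in expected and not k.startswith("_")}
print(init_dict)  # {'hidden_size': 64, 'num_layers': 4}
print(unused)     # {'dropout': 0.1}
```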
| 504 | 
            +
             | 
| 505 | 
            +
                @classmethod
         | 
| 506 | 
            +
                def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
         | 
| 507 | 
            +
                    with open(json_file, "r", encoding="utf-8") as reader:
         | 
| 508 | 
            +
                        text = reader.read()
         | 
| 509 | 
            +
                    return json.loads(text)
         | 
| 510 | 
            +
             | 
| 511 | 
            +
                def __repr__(self):
         | 
| 512 | 
            +
                    return f"{self.__class__.__name__} {self.to_json_string()}"
         | 
| 513 | 
            +
             | 
| 514 | 
            +
                @property
         | 
| 515 | 
            +
                def config(self) -> Dict[str, Any]:
         | 
| 516 | 
            +
                    """
         | 
| 517 | 
            +
                    Returns the config of the class as a frozen dictionary
         | 
| 518 | 
            +
             | 
| 519 | 
            +
                    Returns:
         | 
| 520 | 
            +
                        `Dict[str, Any]`: Config of the class.
         | 
| 521 | 
            +
                    """
         | 
| 522 | 
            +
                    return self._internal_dict
         | 
| 523 | 
            +
             | 
| 524 | 
            +
                def to_json_string(self) -> str:
         | 
| 525 | 
            +
                    """
         | 
| 526 | 
            +
                    Serializes the configuration instance to a JSON string.
         | 
| 527 | 
            +
             | 
| 528 | 
            +
                    Returns:
         | 
| 529 | 
            +
                        `str`:
         | 
| 530 | 
            +
                            String containing all the attributes that make up the configuration instance in JSON format.
         | 
| 531 | 
            +
                    """
         | 
| 532 | 
            +
                    config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
         | 
| 533 | 
            +
                    config_dict["_class_name"] = self.__class__.__name__
         | 
| 534 | 
            +
                    config_dict["_diffusers_version"] = __version__
         | 
| 535 | 
            +
             | 
| 536 | 
            +
                    def to_json_saveable(value):
         | 
| 537 | 
            +
                        if isinstance(value, np.ndarray):
         | 
| 538 | 
            +
                            value = value.tolist()
         | 
| 539 | 
            +
                        elif isinstance(value, PosixPath):
         | 
| 540 | 
            +
                            value = str(value)
         | 
| 541 | 
            +
                        return value
         | 
| 542 | 
            +
             | 
| 543 | 
            +
                    config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
         | 
| 544 | 
            +
                    # Don't save "_ignore_files"
         | 
| 545 | 
            +
                    config_dict.pop("_ignore_files", None)
         | 
| 546 | 
            +
             | 
| 547 | 
            +
                    return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
         | 
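
`to_json_string` has to sanitize values that the standard `json` encoder cannot handle. The standalone snippet below replays that conversion step (names are illustrative, not the class API):

```python
# Illustrative, standalone version of the `to_json_saveable` conversion:
# numpy arrays become lists, PosixPath objects become strings.
import json
from pathlib import Path, PosixPath

import numpy as np


def to_json_saveable(value):
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, PosixPath):  # Path() yields PosixPath on POSIX systems
        return str(value)
    return value


config_dict = {"betas": np.array([1e-4, 2e-4]), "cache_dir": Path("~/.cache")}
print(json.dumps({k: to_json_saveable(v) for k, v in config_dict.items()},
                 indent=2, sort_keys=True))
```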
| 548 | 
            +
             | 
| 549 | 
            +
                def to_json_file(self, json_file_path: Union[str, os.PathLike]):
         | 
| 550 | 
            +
                    """
         | 
| 551 | 
            +
                    Save the configuration instance's parameters to a JSON file.
         | 
| 552 | 
            +
             | 
| 553 | 
            +
                    Args:
         | 
| 554 | 
            +
                        json_file_path (`str` or `os.PathLike`):
         | 
| 555 | 
            +
                            Path to the JSON file to save a configuration instance's parameters.
         | 
| 556 | 
            +
                    """
         | 
| 557 | 
            +
                    with open(json_file_path, "w", encoding="utf-8") as writer:
         | 
| 558 | 
            +
                        writer.write(self.to_json_string())
         | 
| 559 | 
            +
             | 
| 560 | 
            +
             | 
| 561 | 
            +
            def register_to_config(init):
         | 
| 562 | 
            +
                r"""
         | 
| 563 | 
            +
                Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
         | 
| 564 | 
            +
                automatically sent to `self.register_to_config`. To ignore a specific argument accepted by the init but that
         | 
| 565 | 
            +
                shouldn't be registered in the config, use the `ignore_for_config` class variable
         | 
| 566 | 
            +
             | 
| 567 | 
            +
                Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
         | 
| 568 | 
            +
                """
         | 
| 569 | 
            +
             | 
| 570 | 
            +
                @functools.wraps(init)
         | 
| 571 | 
            +
                def inner_init(self, *args, **kwargs):
         | 
| 572 | 
            +
                    # Ignore private kwargs in the init.
         | 
| 573 | 
            +
                    init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
         | 
| 574 | 
            +
                    config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
         | 
| 575 | 
            +
                    if not isinstance(self, ConfigMixin):
         | 
| 576 | 
            +
                        raise RuntimeError(
         | 
| 577 | 
            +
                            f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
         | 
| 578 | 
            +
                            "not inherit from `ConfigMixin`."
         | 
| 579 | 
            +
                        )
         | 
| 580 | 
            +
             | 
| 581 | 
            +
                    ignore = getattr(self, "ignore_for_config", [])
         | 
| 582 | 
            +
                    # Get positional arguments aligned with kwargs
         | 
| 583 | 
            +
                    new_kwargs = {}
         | 
| 584 | 
            +
                    signature = inspect.signature(init)
         | 
| 585 | 
            +
                    parameters = {
         | 
| 586 | 
            +
                        name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
         | 
| 587 | 
            +
                    }
         | 
| 588 | 
            +
                    for arg, name in zip(args, parameters.keys()):
         | 
| 589 | 
            +
                        new_kwargs[name] = arg
         | 
| 590 | 
            +
             | 
| 591 | 
            +
                    # Then add all kwargs
         | 
| 592 | 
            +
                    new_kwargs.update(
         | 
| 593 | 
            +
                        {
         | 
| 594 | 
            +
                            k: init_kwargs.get(k, default)
         | 
| 595 | 
            +
                            for k, default in parameters.items()
         | 
| 596 | 
            +
                            if k not in ignore and k not in new_kwargs
         | 
| 597 | 
            +
                        }
         | 
| 598 | 
            +
                    )
         | 
| 599 | 
            +
                    new_kwargs = {**config_init_kwargs, **new_kwargs}
         | 
| 600 | 
            +
                    getattr(self, "register_to_config")(**new_kwargs)
         | 
| 601 | 
            +
                    init(self, *args, **init_kwargs)
         | 
| 602 | 
            +
             | 
| 603 | 
            +
                return inner_init
         | 
| 604 | 
            +
             | 
| 605 | 
            +
             | 
| 606 | 
            +
            def flax_register_to_config(cls):
         | 
| 607 | 
            +
                original_init = cls.__init__
         | 
| 608 | 
            +
             | 
| 609 | 
            +
                @functools.wraps(original_init)
         | 
| 610 | 
            +
                def init(self, *args, **kwargs):
         | 
| 611 | 
            +
                    if not isinstance(self, ConfigMixin):
         | 
| 612 | 
            +
                        raise RuntimeError(
         | 
| 613 | 
            +
                            f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
         | 
| 614 | 
            +
                            "not inherit from `ConfigMixin`."
         | 
| 615 | 
            +
                        )
         | 
| 616 | 
            +
             | 
| 617 | 
            +
                    # Ignore private kwargs in the init. Retrieve all passed attributes
         | 
| 618 | 
            +
                    init_kwargs = dict(kwargs.items())
         | 
| 619 | 
            +
             | 
| 620 | 
            +
                    # Retrieve default values
         | 
| 621 | 
            +
                    fields = dataclasses.fields(self)
         | 
| 622 | 
            +
                    default_kwargs = {}
         | 
| 623 | 
            +
                    for field in fields:
         | 
| 624 | 
            +
                        # ignore flax specific attributes
         | 
| 625 | 
            +
                        if field.name in self._flax_internal_args:
         | 
| 626 | 
            +
                            continue
         | 
| 627 | 
            +
                        if type(field.default) == dataclasses._MISSING_TYPE:
         | 
| 628 | 
            +
                            default_kwargs[field.name] = None
         | 
| 629 | 
            +
                        else:
         | 
| 630 | 
            +
                            default_kwargs[field.name] = getattr(self, field.name)
         | 
| 631 | 
            +
             | 
| 632 | 
            +
                    # Make sure init_kwargs override default kwargs
         | 
| 633 | 
            +
                    new_kwargs = {**default_kwargs, **init_kwargs}
         | 
| 634 | 
            +
                    # dtype should be part of `init_kwargs`, but not `new_kwargs`
         | 
| 635 | 
            +
                    if "dtype" in new_kwargs:
         | 
| 636 | 
            +
                        new_kwargs.pop("dtype")
         | 
| 637 | 
            +
             | 
| 638 | 
            +
                    # Get positional arguments aligned with kwargs
         | 
| 639 | 
            +
                    for i, arg in enumerate(args):
         | 
| 640 | 
            +
                        name = fields[i].name
         | 
| 641 | 
            +
                        new_kwargs[name] = arg
         | 
| 642 | 
            +
             | 
| 643 | 
            +
                    getattr(self, "register_to_config")(**new_kwargs)
         | 
| 644 | 
            +
                    original_init(self, *args, **kwargs)
         | 
| 645 | 
            +
             | 
| 646 | 
            +
                cls.__init__ = init
         | 
| 647 | 
            +
                return cls
         | 
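
Taken together, `ConfigMixin` and the `register_to_config` decorator let a class record its constructor arguments as a frozen config. A hedged sketch, assuming this bundled copy exposes both names under `diffusers.utils.configuration_utils`:

```python
# Sketch only: the class below is invented to show how the decorator records
# __init__ arguments; the import path is assumed from this commit's layout.
from diffusers.utils.configuration_utils import ConfigMixin, register_to_config


class ToyScheduler(ConfigMixin):
    config_name = "scheduler_config.json"  # ConfigMixin requires this attribute

    @register_to_config
    def __init__(self, num_train_timesteps: int = 1000, beta_start: float = 1e-4):
        pass  # the decorator has already stored both values in the config


sched = ToyScheduler(num_train_timesteps=50)
print(sched.config["num_train_timesteps"], sched.config["beta_start"])  # 50 0.0001
```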
    	
        diffusers/utils/constants.py
    ADDED
    
    | @@ -0,0 +1,34 @@ | |
| 1 | 
            +
            # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 4 | 
            +
            # you may not use this file except in compliance with the License.
         | 
| 5 | 
            +
            # You may obtain a copy of the License at
         | 
| 6 | 
            +
            #
         | 
| 7 | 
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 8 | 
            +
            #
         | 
| 9 | 
            +
            # Unless required by applicable law or agreed to in writing, software
         | 
| 10 | 
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 11 | 
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 12 | 
            +
            # See the License for the specific language governing permissions and
         | 
| 13 | 
            +
            # limitations under the License.
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            import os
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE, hf_cache_home
         | 
| 18 | 
            +
             | 
| 19 | 
            +
             | 
| 20 | 
            +
            default_cache_path = HUGGINGFACE_HUB_CACHE
         | 
| 21 | 
            +
             | 
| 22 | 
            +
             | 
| 23 | 
            +
            CONFIG_NAME = "config.json"
         | 
| 24 | 
            +
            WEIGHTS_NAME = "diffusion_pytorch_model.bin"
         | 
| 25 | 
            +
            FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
         | 
| 26 | 
            +
            ONNX_WEIGHTS_NAME = "model.onnx"
         | 
| 27 | 
            +
            SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
         | 
| 28 | 
            +
            ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
         | 
| 29 | 
            +
            HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
         | 
| 30 | 
            +
            DIFFUSERS_CACHE = default_cache_path
         | 
| 31 | 
            +
            DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
         | 
| 32 | 
            +
            HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
         | 
| 33 | 
            +
            DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
         | 
| 34 | 
            +
            TEXT_ENCODER_ATTN_MODULE = ".self_attn"
         | 
    	
        diffusers/utils/deprecation_utils.py
    ADDED
    
    | @@ -0,0 +1,49 @@ | |
| 1 | 
            +
            import inspect
         | 
| 2 | 
            +
            import warnings
         | 
| 3 | 
            +
            from typing import Any, Dict, Optional, Union
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            from packaging import version
         | 
| 6 | 
            +
             | 
| 7 | 
            +
             | 
| 8 | 
            +
            def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True, stacklevel=2):
         | 
| 9 | 
            +
                from .. import __version__
         | 
| 10 | 
            +
             | 
| 11 | 
            +
                deprecated_kwargs = take_from
         | 
| 12 | 
            +
                values = ()
         | 
| 13 | 
            +
                if not isinstance(args[0], tuple):
         | 
| 14 | 
            +
                    args = (args,)
         | 
| 15 | 
            +
             | 
| 16 | 
            +
                for attribute, version_name, message in args:
         | 
| 17 | 
            +
                    if version.parse(version.parse(__version__).base_version) >= version.parse(version_name):
         | 
| 18 | 
            +
                        raise ValueError(
         | 
| 19 | 
            +
                            f"The deprecation tuple {(attribute, version_name, message)} should be removed since diffusers'"
         | 
| 20 | 
            +
                            f" version {__version__} is >= {version_name}"
         | 
| 21 | 
            +
                        )
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                    warning = None
         | 
| 24 | 
            +
                    if isinstance(deprecated_kwargs, dict) and attribute in deprecated_kwargs:
         | 
| 25 | 
            +
                        values += (deprecated_kwargs.pop(attribute),)
         | 
| 26 | 
            +
                        warning = f"The `{attribute}` argument is deprecated and will be removed in version {version_name}."
         | 
| 27 | 
            +
                    elif hasattr(deprecated_kwargs, attribute):
         | 
| 28 | 
            +
                        values += (getattr(deprecated_kwargs, attribute),)
         | 
| 29 | 
            +
                        warning = f"The `{attribute}` attribute is deprecated and will be removed in version {version_name}."
         | 
| 30 | 
            +
                    elif deprecated_kwargs is None:
         | 
| 31 | 
            +
                        warning = f"`{attribute}` is deprecated and will be removed in version {version_name}."
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                    if warning is not None:
         | 
| 34 | 
            +
                        warning = warning + " " if standard_warn else ""
         | 
| 35 | 
            +
                        warnings.warn(warning + message, FutureWarning, stacklevel=stacklevel)
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0:
         | 
| 38 | 
            +
                    call_frame = inspect.getouterframes(inspect.currentframe())[1]
         | 
| 39 | 
            +
                    filename = call_frame.filename
         | 
| 40 | 
            +
                    line_number = call_frame.lineno
         | 
| 41 | 
            +
                    function = call_frame.function
         | 
| 42 | 
            +
                    key, value = next(iter(deprecated_kwargs.items()))
         | 
| 43 | 
            +
                    raise TypeError(f"{function} in {filename} line {line_number-1} got an unexpected keyword argument `{key}`")
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                if len(values) == 0:
         | 
| 46 | 
            +
                    return
         | 
| 47 | 
            +
                elif len(values) == 1:
         | 
| 48 | 
            +
                    return values[0]
         | 
| 49 | 
            +
                return values
         | 
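
`deprecate` warns about (and pops) a deprecated keyword, raises once the library version reaches the removal target, and turns any other leftover kwarg into a `TypeError`. A hedged usage sketch; the function and argument names are invented, and it assumes the bundled package defines a `__version__` below the stated removal version:

```python
# Invented example of the usual calling pattern for `deprecate`.
from diffusers.utils.deprecation_utils import deprecate  # import path assumed


def generate(prompt, guidance_scale=7.5, **kwargs):
    # If a caller still passes the old `guidance` kwarg, emit a FutureWarning
    # and fall back to its value; any other unknown kwarg raises a TypeError.
    old_guidance = deprecate(
        "guidance", "1.0.0", "Use `guidance_scale` instead.", take_from=kwargs
    )
    if old_guidance is not None:
        guidance_scale = old_guidance
    return prompt, guidance_scale


print(generate("a dog barking", guidance=3.0))  # warns, then ('a dog barking', 3.0)
```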
    	
        diffusers/utils/hub_utils.py
    ADDED
    
    | @@ -0,0 +1,357 @@ | |
| 1 | 
            +
            # coding=utf-8
         | 
| 2 | 
            +
            # Copyright 2023 The HuggingFace Inc. team.
         | 
| 3 | 
            +
            #
         | 
| 4 | 
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 5 | 
            +
            # you may not use this file except in compliance with the License.
         | 
| 6 | 
            +
            # You may obtain a copy of the License at
         | 
| 7 | 
            +
            #
         | 
| 8 | 
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 9 | 
            +
            #
         | 
| 10 | 
            +
            # Unless required by applicable law or agreed to in writing, software
         | 
| 11 | 
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 12 | 
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 13 | 
            +
            # See the License for the specific language governing permissions and
         | 
| 14 | 
            +
            # limitations under the License.
         | 
| 15 | 
            +
             | 
| 16 | 
            +
             | 
| 17 | 
            +
            import os
         | 
| 18 | 
            +
            import re
         | 
| 19 | 
            +
            import sys
         | 
| 20 | 
            +
            import traceback
         | 
| 21 | 
            +
            import warnings
         | 
| 22 | 
            +
            from pathlib import Path
         | 
| 23 | 
            +
            from typing import Dict, Optional, Union
         | 
| 24 | 
            +
            from uuid import uuid4
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            from huggingface_hub import HfFolder, ModelCard, ModelCardData, hf_hub_download, whoami
         | 
| 27 | 
            +
            from huggingface_hub.file_download import REGEX_COMMIT_HASH
         | 
| 28 | 
            +
            from huggingface_hub.utils import (
         | 
| 29 | 
            +
                EntryNotFoundError,
         | 
| 30 | 
            +
                RepositoryNotFoundError,
         | 
| 31 | 
            +
                RevisionNotFoundError,
         | 
| 32 | 
            +
                is_jinja_available,
         | 
| 33 | 
            +
            )
         | 
| 34 | 
            +
            from packaging import version
         | 
| 35 | 
            +
            from requests import HTTPError
         | 
| 36 | 
            +
             | 
| 37 | 
            +
            from .constants import (
         | 
| 38 | 
            +
                DEPRECATED_REVISION_ARGS,
         | 
| 39 | 
            +
                DIFFUSERS_CACHE,
         | 
| 40 | 
            +
                HUGGINGFACE_CO_RESOLVE_ENDPOINT,
         | 
| 41 | 
            +
                SAFETENSORS_WEIGHTS_NAME,
         | 
| 42 | 
            +
                WEIGHTS_NAME,
         | 
| 43 | 
            +
            )
         | 
| 44 | 
            +
            from .import_utils import (
         | 
| 45 | 
            +
                ENV_VARS_TRUE_VALUES,
         | 
| 46 | 
            +
                _flax_version,
         | 
| 47 | 
            +
                _jax_version,
         | 
| 48 | 
            +
                _onnxruntime_version,
         | 
| 49 | 
            +
                _torch_version,
         | 
| 50 | 
            +
                is_flax_available,
         | 
| 51 | 
            +
                is_onnx_available,
         | 
| 52 | 
            +
                is_torch_available,
         | 
| 53 | 
            +
            )
         | 
| 54 | 
            +
            from .logging import get_logger
         | 
| 55 | 
            +
             | 
| 56 | 
            +
             | 
| 57 | 
            +
            logger = get_logger(__name__)
         | 
| 58 | 
            +
             | 
| 59 | 
            +
             | 
| 60 | 
            +
            MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "model_card_template.md"
         | 
| 61 | 
            +
            SESSION_ID = uuid4().hex
         | 
| 62 | 
            +
            HF_HUB_OFFLINE = os.getenv("HF_HUB_OFFLINE", "").upper() in ENV_VARS_TRUE_VALUES
         | 
| 63 | 
            +
            DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", "").upper() in ENV_VARS_TRUE_VALUES
         | 
| 64 | 
            +
            HUGGINGFACE_CO_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/"
         | 
| 65 | 
            +
             | 
| 66 | 
            +
             | 
| 67 | 
            +
            def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
         | 
| 68 | 
            +
                """
         | 
| 69 | 
            +
                Formats a user-agent string with basic info about a request.
         | 
| 70 | 
            +
                """
         | 
| 71 | 
            +
                ua = f"diffusers; python/{sys.version.split()[0]}; session_id/{SESSION_ID}"
         | 
| 72 | 
            +
                if DISABLE_TELEMETRY or HF_HUB_OFFLINE:
         | 
| 73 | 
            +
                    return ua + "; telemetry/off"
         | 
| 74 | 
            +
                if is_torch_available():
         | 
| 75 | 
            +
                    ua += f"; torch/{_torch_version}"
         | 
| 76 | 
            +
                if is_flax_available():
         | 
| 77 | 
            +
                    ua += f"; jax/{_jax_version}"
         | 
| 78 | 
            +
                    ua += f"; flax/{_flax_version}"
         | 
| 79 | 
            +
                if is_onnx_available():
         | 
| 80 | 
            +
                    ua += f"; onnxruntime/{_onnxruntime_version}"
         | 
| 81 | 
            +
                # CI will set this value to True
         | 
| 82 | 
            +
                if os.environ.get("DIFFUSERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES:
         | 
| 83 | 
            +
                    ua += "; is_ci/true"
         | 
| 84 | 
            +
                if isinstance(user_agent, dict):
         | 
| 85 | 
            +
                    ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items())
         | 
| 86 | 
            +
                elif isinstance(user_agent, str):
         | 
| 87 | 
            +
                    ua += "; " + user_agent
         | 
| 88 | 
            +
                return ua
         | 
| 89 | 
            +
             | 
| 90 | 
            +
             | 
| 91 | 
            +
            def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
         | 
| 92 | 
            +
                if token is None:
         | 
| 93 | 
            +
                    token = HfFolder.get_token()
         | 
| 94 | 
            +
                if organization is None:
         | 
| 95 | 
            +
                    username = whoami(token)["name"]
         | 
| 96 | 
            +
                    return f"{username}/{model_id}"
         | 
| 97 | 
            +
                else:
         | 
| 98 | 
            +
                    return f"{organization}/{model_id}"
         | 
| 99 | 
            +
             | 
| 100 | 
            +
             | 
| 101 | 
            +
            def create_model_card(args, model_name):
         | 
| 102 | 
            +
                if not is_jinja_available():
         | 
| 103 | 
            +
                    raise ValueError(
         | 
| 104 | 
            +
                        "Modelcard rendering is based on Jinja templates."
         | 
| 105 | 
            +
                        " Please make sure to have `jinja` installed before using `create_model_card`."
         | 
| 106 | 
            +
                        " To install it, please run `pip install Jinja2`."
         | 
| 107 | 
            +
                    )
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
         | 
| 110 | 
            +
                    return
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                hub_token = args.hub_token if hasattr(args, "hub_token") else None
         | 
| 113 | 
            +
                repo_name = get_full_repo_name(model_name, token=hub_token)
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                model_card = ModelCard.from_template(
         | 
| 116 | 
            +
                    card_data=ModelCardData(  # Card metadata object that will be converted to YAML block
         | 
| 117 | 
            +
                        language="en",
         | 
| 118 | 
            +
                        license="apache-2.0",
         | 
| 119 | 
            +
                        library_name="diffusers",
         | 
| 120 | 
            +
                        tags=[],
         | 
| 121 | 
            +
                        datasets=args.dataset_name,
         | 
| 122 | 
            +
                        metrics=[],
         | 
| 123 | 
            +
                    ),
         | 
| 124 | 
            +
                    template_path=MODEL_CARD_TEMPLATE_PATH,
         | 
| 125 | 
            +
                    model_name=model_name,
         | 
| 126 | 
            +
                    repo_name=repo_name,
         | 
| 127 | 
            +
                    dataset_name=args.dataset_name if hasattr(args, "dataset_name") else None,
         | 
| 128 | 
            +
                    learning_rate=args.learning_rate,
         | 
| 129 | 
            +
                    train_batch_size=args.train_batch_size,
         | 
| 130 | 
            +
                    eval_batch_size=args.eval_batch_size,
         | 
| 131 | 
            +
                    gradient_accumulation_steps=(
         | 
| 132 | 
            +
                        args.gradient_accumulation_steps if hasattr(args, "gradient_accumulation_steps") else None
         | 
| 133 | 
            +
                    ),
         | 
| 134 | 
            +
                    adam_beta1=args.adam_beta1 if hasattr(args, "adam_beta1") else None,
         | 
| 135 | 
            +
                    adam_beta2=args.adam_beta2 if hasattr(args, "adam_beta2") else None,
         | 
| 136 | 
            +
                    adam_weight_decay=args.adam_weight_decay if hasattr(args, "adam_weight_decay") else None,
         | 
| 137 | 
            +
                    adam_epsilon=args.adam_epsilon if hasattr(args, "adam_epsilon") else None,
         | 
| 138 | 
            +
                    lr_scheduler=args.lr_scheduler if hasattr(args, "lr_scheduler") else None,
         | 
| 139 | 
            +
                    lr_warmup_steps=args.lr_warmup_steps if hasattr(args, "lr_warmup_steps") else None,
         | 
| 140 | 
            +
                    ema_inv_gamma=args.ema_inv_gamma if hasattr(args, "ema_inv_gamma") else None,
         | 
| 141 | 
            +
                    ema_power=args.ema_power if hasattr(args, "ema_power") else None,
         | 
| 142 | 
            +
                    ema_max_decay=args.ema_max_decay if hasattr(args, "ema_max_decay") else None,
         | 
| 143 | 
            +
                    mixed_precision=args.mixed_precision,
         | 
| 144 | 
            +
                )
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                card_path = os.path.join(args.output_dir, "README.md")
         | 
| 147 | 
            +
                model_card.save(card_path)
         | 
| 148 | 
            +
             | 
| 149 | 
            +
             | 
| 150 | 
            +
            def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str] = None):
         | 
| 151 | 
            +
                """
         | 
| 152 | 
            +
                Extracts the commit hash from the resolved filename of a cached file.
         | 
| 153 | 
            +
                """
         | 
| 154 | 
            +
                if resolved_file is None or commit_hash is not None:
         | 
| 155 | 
            +
                    return commit_hash
         | 
| 156 | 
            +
                resolved_file = str(Path(resolved_file).as_posix())
         | 
| 157 | 
            +
                search = re.search(r"snapshots/([^/]+)/", resolved_file)
         | 
| 158 | 
            +
                if search is None:
         | 
| 159 | 
            +
                    return None
         | 
| 160 | 
            +
                commit_hash = search.groups()[0]
         | 
| 161 | 
            +
                return commit_hash if REGEX_COMMIT_HASH.match(commit_hash) else None
         | 
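
`extract_commit_hash` just pattern-matches the `snapshots/<hash>/` component that the Hub cache layout embeds in resolved paths. For example (the path below is made up):

```python
from diffusers.utils.hub_utils import extract_commit_hash  # import path assumed

resolved = (
    "/home/user/.cache/huggingface/hub/models--some-org--some-model/"
    "snapshots/0123456789abcdef0123456789abcdef01234567/config.json"
)
print(extract_commit_hash(resolved))
# 0123456789abcdef0123456789abcdef01234567
```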
| 162 | 
            +
             | 
| 163 | 
            +
             | 
| 164 | 
            +
            # Old default cache path, potentially to be migrated.
         | 
| 165 | 
            +
            # This logic was more or less taken from `transformers`, with the following differences:
         | 
| 166 | 
            +
            # - Diffusers doesn't use custom environment variables to specify the cache path.
         | 
| 167 | 
            +
            # - There is no need to migrate the cache format, just move the files to the new location.
         | 
| 168 | 
            +
            hf_cache_home = os.path.expanduser(
         | 
| 169 | 
            +
                os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
         | 
| 170 | 
            +
            )
         | 
| 171 | 
            +
            old_diffusers_cache = os.path.join(hf_cache_home, "diffusers")
         | 
| 172 | 
            +
             | 
| 173 | 
            +
             | 
| 174 | 
            +
            def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str] = None) -> None:
         | 
| 175 | 
            +
                if new_cache_dir is None:
         | 
| 176 | 
            +
                    new_cache_dir = DIFFUSERS_CACHE
         | 
| 177 | 
            +
                if old_cache_dir is None:
         | 
| 178 | 
            +
                    old_cache_dir = old_diffusers_cache
         | 
| 179 | 
            +
             | 
| 180 | 
            +
                old_cache_dir = Path(old_cache_dir).expanduser()
         | 
| 181 | 
            +
                new_cache_dir = Path(new_cache_dir).expanduser()
         | 
| 182 | 
            +
                for old_blob_path in old_cache_dir.glob("**/blobs/*"):
         | 
| 183 | 
            +
                    if old_blob_path.is_file() and not old_blob_path.is_symlink():
         | 
| 184 | 
            +
                        new_blob_path = new_cache_dir / old_blob_path.relative_to(old_cache_dir)
         | 
| 185 | 
            +
                        new_blob_path.parent.mkdir(parents=True, exist_ok=True)
         | 
| 186 | 
            +
                        os.replace(old_blob_path, new_blob_path)
         | 
| 187 | 
            +
                        try:
         | 
| 188 | 
            +
                            os.symlink(new_blob_path, old_blob_path)
         | 
| 189 | 
            +
                        except OSError:
         | 
| 190 | 
            +
                            logger.warning(
         | 
| 191 | 
            +
                                "Could not create symlink between old cache and new cache. If you use an older version of diffusers again, files will be re-downloaded."
         | 
| 192 | 
            +
                            )
         | 
| 193 | 
            +
                # At this point, old_cache_dir contains symlinks to the new cache (it can still be used).
         | 
| 194 | 
            +
             | 
| 195 | 
            +
             | 
| 196 | 
            +
            cache_version_file = os.path.join(DIFFUSERS_CACHE, "version_diffusers_cache.txt")
         | 
| 197 | 
            +
            if not os.path.isfile(cache_version_file):
         | 
| 198 | 
            +
                cache_version = 0
         | 
| 199 | 
            +
            else:
         | 
| 200 | 
            +
                with open(cache_version_file) as f:
         | 
| 201 | 
            +
                    cache_version = int(f.read())
         | 
| 202 | 
            +
             | 
| 203 | 
            +
            if cache_version < 1:
         | 
| 204 | 
            +
                old_cache_is_not_empty = os.path.isdir(old_diffusers_cache) and len(os.listdir(old_diffusers_cache)) > 0
         | 
| 205 | 
            +
                if old_cache_is_not_empty:
         | 
| 206 | 
            +
                    logger.warning(
         | 
| 207 | 
            +
                        "The cache for model files in Diffusers v0.14.0 has moved to a new location. Moving your "
         | 
| 208 | 
            +
                        "existing cached models. This is a one-time operation, you can interrupt it or run it "
         | 
| 209 | 
            +
                        "later by calling `diffusers.utils.hub_utils.move_cache()`."
         | 
| 210 | 
            +
                    )
         | 
| 211 | 
            +
                    try:
         | 
| 212 | 
            +
                        move_cache()
         | 
| 213 | 
            +
                    except Exception as e:
         | 
| 214 | 
            +
                        trace = "\n".join(traceback.format_tb(e.__traceback__))
         | 
| 215 | 
            +
                        logger.error(
         | 
| 216 | 
            +
                            f"There was a problem when trying to move your cache:\n\n{trace}\n{e.__class__.__name__}: {e}\n\nPlease "
         | 
| 217 | 
            +
                            "file an issue at https://github.com/huggingface/diffusers/issues/new/choose, copy paste this whole "
         | 
| 218 | 
            +
                            "message and we will do our best to help."
         | 
| 219 | 
            +
                        )
         | 
| 220 | 
            +
             | 
| 221 | 
            +
            if cache_version < 1:
         | 
| 222 | 
            +
                try:
         | 
| 223 | 
            +
                    os.makedirs(DIFFUSERS_CACHE, exist_ok=True)
         | 
| 224 | 
            +
                    with open(cache_version_file, "w") as f:
         | 
| 225 | 
            +
                        f.write("1")
         | 
| 226 | 
            +
                except Exception:
         | 
| 227 | 
            +
                    logger.warning(
         | 
| 228 | 
            +
                        f"There was a problem when trying to write in your cache folder ({DIFFUSERS_CACHE}). Please, ensure "
         | 
| 229 | 
            +
                        "the directory exists and can be written to."
         | 
| 230 | 
            +
                    )
         | 
| 231 | 
            +
             | 
| 232 | 
            +
             | 
| 233 | 
            +
            def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
         | 
| 234 | 
            +
                if variant is not None:
         | 
| 235 | 
            +
                    splits = weights_name.split(".")
         | 
| 236 | 
            +
                    splits = splits[:-1] + [variant] + splits[-1:]
         | 
| 237 | 
            +
                    weights_name = ".".join(splits)
         | 
| 238 | 
            +
             | 
| 239 | 
            +
                return weights_name
         | 
| 240 | 
            +
             | 
| 241 | 
            +
             | 
| 242 | 
            +
            def _get_model_file(
         | 
| 243 | 
            +
                pretrained_model_name_or_path,
         | 
| 244 | 
            +
                *,
         | 
| 245 | 
            +
                weights_name,
         | 
| 246 | 
            +
                subfolder,
         | 
| 247 | 
            +
                cache_dir,
         | 
| 248 | 
            +
                force_download,
         | 
| 249 | 
            +
                proxies,
         | 
| 250 | 
            +
                resume_download,
         | 
| 251 | 
            +
                local_files_only,
         | 
| 252 | 
            +
                use_auth_token,
         | 
| 253 | 
            +
                user_agent,
         | 
| 254 | 
            +
                revision,
         | 
| 255 | 
            +
                commit_hash=None,
         | 
| 256 | 
            +
            ):
         | 
| 257 | 
            +
                pretrained_model_name_or_path = str(pretrained_model_name_or_path)
         | 
| 258 | 
            +
                if os.path.isfile(pretrained_model_name_or_path):
         | 
| 259 | 
            +
                    return pretrained_model_name_or_path
         | 
| 260 | 
            +
                elif os.path.isdir(pretrained_model_name_or_path):
         | 
| 261 | 
            +
                    if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
         | 
| 262 | 
            +
                        # Load from a PyTorch checkpoint
         | 
| 263 | 
            +
                        model_file = os.path.join(pretrained_model_name_or_path, weights_name)
         | 
| 264 | 
            +
                        return model_file
         | 
| 265 | 
            +
                    elif subfolder is not None and os.path.isfile(
         | 
| 266 | 
            +
                        os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
         | 
| 267 | 
            +
                    ):
         | 
| 268 | 
            +
                        model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
         | 
| 269 | 
            +
                        return model_file
         | 
| 270 | 
            +
                    else:
         | 
| 271 | 
            +
                        raise EnvironmentError(
         | 
| 272 | 
            +
                            f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
         | 
| 273 | 
            +
                        )
         | 
| 274 | 
            +
                else:
         | 
| 275 | 
            +
                    # 1. First check if deprecated way of loading from branches is used
         | 
| 276 | 
            +
                    if (
         | 
| 277 | 
            +
                        revision in DEPRECATED_REVISION_ARGS
         | 
| 278 | 
            +
                        and (weights_name == WEIGHTS_NAME or weights_name == SAFETENSORS_WEIGHTS_NAME)
         | 
| 279 | 
            +
                        and version.parse(version.parse(__version__).base_version) >= version.parse("0.17.0")
         | 
| 280 | 
            +
                    ):
         | 
| 281 | 
            +
                        try:
         | 
| 282 | 
            +
                            model_file = hf_hub_download(
         | 
| 283 | 
            +
                                pretrained_model_name_or_path,
         | 
| 284 | 
            +
                                filename=_add_variant(weights_name, revision),
         | 
| 285 | 
            +
                                cache_dir=cache_dir,
         | 
| 286 | 
            +
                                force_download=force_download,
         | 
| 287 | 
            +
                                proxies=proxies,
         | 
| 288 | 
            +
                                resume_download=resume_download,
         | 
| 289 | 
            +
                                local_files_only=local_files_only,
         | 
| 290 | 
            +
                                use_auth_token=use_auth_token,
         | 
| 291 | 
            +
                                user_agent=user_agent,
         | 
| 292 | 
            +
                                subfolder=subfolder,
         | 
| 293 | 
            +
                                revision=revision or commit_hash,
         | 
| 294 | 
            +
                            )
         | 
| 295 | 
            +
                            warnings.warn(
         | 
| 296 | 
            +
                                f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
         | 
| 297 | 
            +
                                FutureWarning,
         | 
| 298 | 
            +
                            )
         | 
| 299 | 
            +
                            return model_file
         | 
| 300 | 
            +
                        except:  # noqa: E722
         | 
| 301 | 
            +
                            warnings.warn(
         | 
| 302 | 
            +
                                f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have a {_add_variant(weights_name, revision)} file in the 'main' branch of {pretrained_model_name_or_path}. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {_add_variant(weights_name, revision)}' so that the correct variant file can be added.",
         | 
| 303 | 
            +
                                FutureWarning,
         | 
| 304 | 
            +
                            )
         | 
| 305 | 
            +
                    try:
         | 
| 306 | 
            +
                        # 2. Load model file as usual
         | 
| 307 | 
            +
                        model_file = hf_hub_download(
         | 
| 308 | 
            +
                            pretrained_model_name_or_path,
         | 
| 309 | 
            +
                            filename=weights_name,
         | 
| 310 | 
            +
                            cache_dir=cache_dir,
         | 
| 311 | 
            +
                            force_download=force_download,
         | 
| 312 | 
            +
                            proxies=proxies,
         | 
| 313 | 
            +
                            resume_download=resume_download,
         | 
| 314 | 
            +
                            local_files_only=local_files_only,
         | 
| 315 | 
            +
                            use_auth_token=use_auth_token,
         | 
| 316 | 
            +
                            user_agent=user_agent,
         | 
| 317 | 
            +
                            subfolder=subfolder,
         | 
| 318 | 
            +
                            revision=revision or commit_hash,
         | 
| 319 | 
            +
                        )
         | 
| 320 | 
            +
                        return model_file
         | 
| 321 | 
            +
             | 
| 322 | 
            +
                    except RepositoryNotFoundError:
         | 
| 323 | 
            +
                        raise EnvironmentError(
         | 
| 324 | 
            +
                            f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
         | 
| 325 | 
            +
                            "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
         | 
| 326 | 
            +
                            "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
         | 
| 327 | 
            +
                            "login`."
         | 
| 328 | 
            +
                        )
         | 
| 329 | 
            +
                    except RevisionNotFoundError:
         | 
| 330 | 
            +
                        raise EnvironmentError(
         | 
| 331 | 
            +
                            f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
         | 
| 332 | 
            +
                            "this model name. Check the model page at "
         | 
| 333 | 
            +
                            f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
         | 
| 334 | 
            +
                        )
         | 
| 335 | 
            +
                    except EntryNotFoundError:
         | 
| 336 | 
            +
                        raise EnvironmentError(
         | 
| 337 | 
            +
                            f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
         | 
| 338 | 
            +
                        )
         | 
| 339 | 
            +
                    except HTTPError as err:
         | 
| 340 | 
            +
                        raise EnvironmentError(
         | 
| 341 | 
            +
                            f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
         | 
| 342 | 
            +
                        )
         | 
| 343 | 
            +
                    except ValueError:
         | 
| 344 | 
            +
                        raise EnvironmentError(
         | 
| 345 | 
            +
                            f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
         | 
| 346 | 
            +
                            f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
         | 
| 347 | 
            +
                            f" directory containing a file named {weights_name} or"
         | 
| 348 | 
            +
                            " \nCheckout your internet connection or see how to run the library in"
         | 
| 349 | 
            +
                            " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
         | 
| 350 | 
            +
                        )
         | 
| 351 | 
            +
                    except EnvironmentError:
         | 
| 352 | 
            +
                        raise EnvironmentError(
         | 
| 353 | 
            +
                            f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
         | 
| 354 | 
            +
                            "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
         | 
| 355 | 
            +
                            f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
         | 
| 356 | 
            +
                            f"containing a file named {weights_name}"
         | 
| 357 | 
            +
                        )
         | 
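
For reference, the variant-file naming that `_get_model_file` relies on above reduces to the small `_add_variant` helper, which can be exercised on its own. Below is a minimal, illustrative sketch: the function body mirrors the code in this file, while the example filenames are hypothetical and only demonstrate the behavior.

```python
# Standalone sketch of the variant-naming logic shown above (illustrative only).
from typing import Optional


def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
    if variant is not None:
        splits = weights_name.split(".")
        splits = splits[:-1] + [variant] + splits[-1:]  # insert the variant before the extension
        weights_name = ".".join(splits)
    return weights_name


# Hypothetical filenames, used only to show how a variant such as "fp16" is spliced in.
assert _add_variant("diffusion_pytorch_model.bin", "fp16") == "diffusion_pytorch_model.fp16.bin"
assert _add_variant("model.safetensors") == "model.safetensors"  # no variant: name unchanged
```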
    	
    	
        diffusers/utils/import_utils.py
    ADDED
    
    @@ -0,0 +1,649 @@
| 1 | 
            +
            # Copyright 2023 The HuggingFace Team. All rights reserved.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 4 | 
            +
            # you may not use this file except in compliance with the License.
         | 
| 5 | 
            +
            # You may obtain a copy of the License at
         | 
| 6 | 
            +
            #
         | 
| 7 | 
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 8 | 
            +
            #
         | 
| 9 | 
            +
            # Unless required by applicable law or agreed to in writing, software
         | 
| 10 | 
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 11 | 
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 12 | 
            +
            # See the License for the specific language governing permissions and
         | 
| 13 | 
            +
            # limitations under the License.
         | 
| 14 | 
            +
            """
         | 
| 15 | 
            +
            Import utilities: Utilities related to imports and our lazy inits.
         | 
| 16 | 
            +
            """
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            import importlib.util
         | 
| 19 | 
            +
            import operator as op
         | 
| 20 | 
            +
            import os
         | 
| 21 | 
            +
            import sys
         | 
| 22 | 
            +
            from collections import OrderedDict
         | 
| 23 | 
            +
            from typing import Union
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            from packaging import version
         | 
| 26 | 
            +
            from packaging.version import Version, parse
         | 
| 27 | 
            +
             | 
| 28 | 
            +
            from . import logging
         | 
| 29 | 
            +
             | 
| 30 | 
            +
             | 
| 31 | 
            +
            # The package importlib_metadata is in a different place, depending on the python version.
         | 
| 32 | 
            +
            if sys.version_info < (3, 8):
         | 
| 33 | 
            +
                import importlib_metadata
         | 
| 34 | 
            +
            else:
         | 
| 35 | 
            +
                import importlib.metadata as importlib_metadata
         | 
| 36 | 
            +
             | 
| 37 | 
            +
             | 
| 38 | 
            +
            logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
         | 
| 39 | 
            +
             | 
| 40 | 
            +
            ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
         | 
| 41 | 
            +
            ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
         | 
| 42 | 
            +
             | 
| 43 | 
            +
            USE_TF = os.environ.get("USE_TF", "AUTO").upper()
         | 
| 44 | 
            +
            USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
         | 
| 45 | 
            +
            USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
         | 
| 46 | 
            +
            USE_SAFETENSORS = os.environ.get("USE_SAFETENSORS", "AUTO").upper()
         | 
| 47 | 
            +
             | 
| 48 | 
            +
            STR_OPERATION_TO_FUNC = {
         | 
| 49 | 
            +
                ">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt
         | 
| 50 | 
            +
            }
         | 
| 51 | 
            +
             | 
| 52 | 
            +
            _torch_version = "N/A"
         | 
| 53 | 
            +
            if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
         | 
| 54 | 
            +
                _torch_available = importlib.util.find_spec("torch") is not None
         | 
| 55 | 
            +
                if _torch_available:
         | 
| 56 | 
            +
                    try:
         | 
| 57 | 
            +
                        _torch_version = importlib_metadata.version("torch")
         | 
| 58 | 
            +
                        logger.info(f"PyTorch version {_torch_version} available.")
         | 
| 59 | 
            +
                    except importlib_metadata.PackageNotFoundError:
         | 
| 60 | 
            +
                        _torch_available = False
         | 
| 61 | 
            +
            else:
         | 
| 62 | 
            +
                logger.info("Disabling PyTorch because USE_TORCH is set")
         | 
| 63 | 
            +
                _torch_available = False
         | 
| 64 | 
            +
             | 
| 65 | 
            +
             | 
| 66 | 
            +
            _tf_version = "N/A"
         | 
| 67 | 
            +
            if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
         | 
| 68 | 
            +
                _tf_available = importlib.util.find_spec("tensorflow") is not None
         | 
| 69 | 
            +
                if _tf_available:
         | 
| 70 | 
            +
                    candidates = (
         | 
| 71 | 
            +
                        "tensorflow",
         | 
| 72 | 
            +
                        "tensorflow-cpu",
         | 
| 73 | 
            +
                        "tensorflow-gpu",
         | 
| 74 | 
            +
                        "tf-nightly",
         | 
| 75 | 
            +
                        "tf-nightly-cpu",
         | 
| 76 | 
            +
                        "tf-nightly-gpu",
         | 
| 77 | 
            +
                        "intel-tensorflow",
         | 
| 78 | 
            +
                        "intel-tensorflow-avx512",
         | 
| 79 | 
            +
                        "tensorflow-rocm",
         | 
| 80 | 
            +
                        "tensorflow-macos",
         | 
| 81 | 
            +
                        "tensorflow-aarch64",
         | 
| 82 | 
            +
                    )
         | 
| 83 | 
            +
                    _tf_version = None
         | 
| 84 | 
            +
                    # For the metadata, we have to look for both tensorflow and tensorflow-cpu
         | 
| 85 | 
            +
                    for pkg in candidates:
         | 
| 86 | 
            +
                        try:
         | 
| 87 | 
            +
                            _tf_version = importlib_metadata.version(pkg)
         | 
| 88 | 
            +
                            break
         | 
| 89 | 
            +
                        except importlib_metadata.PackageNotFoundError:
         | 
| 90 | 
            +
                            pass
         | 
| 91 | 
            +
                    _tf_available = _tf_version is not None
         | 
| 92 | 
            +
                if _tf_available:
         | 
| 93 | 
            +
                    if version.parse(_tf_version) < version.parse("2"):
         | 
| 94 | 
            +
                        logger.info(f"TensorFlow found but with version {_tf_version}. "
         | 
| 95 | 
            +
                                    "Diffusers requires version 2 minimum.")
         | 
| 96 | 
            +
                        _tf_available = False
         | 
| 97 | 
            +
                    else:
         | 
| 98 | 
            +
                        logger.info(f"TensorFlow version {_tf_version} available.")
         | 
| 99 | 
            +
            else:
         | 
| 100 | 
            +
                logger.info("Disabling Tensorflow because USE_TORCH is set")
         | 
| 101 | 
            +
                _tf_available = False
         | 
| 102 | 
            +
             | 
| 103 | 
            +
            _jax_version = "N/A"
         | 
| 104 | 
            +
            _flax_version = "N/A"
         | 
| 105 | 
            +
            if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
         | 
| 106 | 
            +
                _flax_available = importlib.util.find_spec("jax") is not None \
         | 
| 107 | 
            +
                    and importlib.util.find_spec("flax") is not None
         | 
| 108 | 
            +
                if _flax_available:
         | 
| 109 | 
            +
                    try:
         | 
| 110 | 
            +
                        _jax_version = importlib_metadata.version("jax")
         | 
| 111 | 
            +
                        _flax_version = importlib_metadata.version("flax")
         | 
| 112 | 
            +
                        logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
         | 
| 113 | 
            +
                    except importlib_metadata.PackageNotFoundError:
         | 
| 114 | 
            +
                        _flax_available = False
         | 
| 115 | 
            +
            else:
         | 
| 116 | 
            +
                _flax_available = False
         | 
| 117 | 
            +
             | 
| 118 | 
            +
            if USE_SAFETENSORS in ENV_VARS_TRUE_AND_AUTO_VALUES:
         | 
| 119 | 
            +
                _safetensors_available = importlib.util.find_spec("safetensors") is not None
         | 
| 120 | 
            +
                if _safetensors_available:
         | 
| 121 | 
            +
                    try:
         | 
| 122 | 
            +
                        _safetensors_version = importlib_metadata.version("safetensors")
         | 
| 123 | 
            +
                        logger.info(f"Safetensors version {_safetensors_version} available.")
         | 
| 124 | 
            +
                    except importlib_metadata.PackageNotFoundError:
         | 
| 125 | 
            +
                        _safetensors_available = False
         | 
| 126 | 
            +
            else:
         | 
| 127 | 
            +
                logger.info("Disabling Safetensors because USE_SAFETENSORS is set to a non-true value")
         | 
| 128 | 
            +
                _safetensors_available = False
         | 
| 129 | 
            +
             | 
| 130 | 
            +
            _transformers_available = importlib.util.find_spec("transformers") is not None
         | 
| 131 | 
            +
            try:
         | 
| 132 | 
            +
                _transformers_version = importlib_metadata.version("transformers")
         | 
| 133 | 
            +
                logger.debug(f"Successfully imported transformers version {_transformers_version}")
         | 
| 134 | 
            +
            except importlib_metadata.PackageNotFoundError:
         | 
| 135 | 
            +
                _transformers_available = False
         | 
| 136 | 
            +
             | 
| 137 | 
            +
             | 
| 138 | 
            +
            _inflect_available = importlib.util.find_spec("inflect") is not None
         | 
| 139 | 
            +
            try:
         | 
| 140 | 
            +
                _inflect_version = importlib_metadata.version("inflect")
         | 
| 141 | 
            +
                logger.debug(f"Successfully imported inflect version {_inflect_version}")
         | 
| 142 | 
            +
            except importlib_metadata.PackageNotFoundError:
         | 
| 143 | 
            +
                _inflect_available = False
         | 
| 144 | 
            +
             | 
| 145 | 
            +
             | 
| 146 | 
            +
            _unidecode_available = importlib.util.find_spec("unidecode") is not None
         | 
| 147 | 
            +
            try:
         | 
| 148 | 
            +
                _unidecode_version = importlib_metadata.version("unidecode")
         | 
| 149 | 
            +
                logger.debug(f"Successfully imported unidecode version {_unidecode_version}")
         | 
| 150 | 
            +
            except importlib_metadata.PackageNotFoundError:
         | 
| 151 | 
            +
                _unidecode_available = False
         | 
| 152 | 
            +
             | 
| 153 | 
            +
             | 
| 154 | 
            +
            _onnxruntime_version = "N/A"
         | 
| 155 | 
            +
            _onnx_available = importlib.util.find_spec("onnxruntime") is not None
         | 
| 156 | 
            +
            if _onnx_available:
         | 
| 157 | 
            +
                candidates = (
         | 
| 158 | 
            +
                    "onnxruntime",
         | 
| 159 | 
            +
                    "onnxruntime-gpu",
         | 
| 160 | 
            +
                    "ort_nightly_gpu",
         | 
| 161 | 
            +
                    "onnxruntime-directml",
         | 
| 162 | 
            +
                    "onnxruntime-openvino",
         | 
| 163 | 
            +
                    "ort_nightly_directml",
         | 
| 164 | 
            +
                    "onnxruntime-rocm",
         | 
| 165 | 
            +
                    "onnxruntime-training",
         | 
| 166 | 
            +
                )
         | 
| 167 | 
            +
                _onnxruntime_version = None
         | 
| 168 | 
            +
                # For the metadata, we have to look for both onnxruntime and onnxruntime-gpu
         | 
| 169 | 
            +
                for pkg in candidates:
         | 
| 170 | 
            +
                    try:
         | 
| 171 | 
            +
                        _onnxruntime_version = importlib_metadata.version(pkg)
         | 
| 172 | 
            +
                        break
         | 
| 173 | 
            +
                    except importlib_metadata.PackageNotFoundError:
         | 
| 174 | 
            +
                        pass
         | 
| 175 | 
            +
                _onnx_available = _onnxruntime_version is not None
         | 
| 176 | 
            +
                if _onnx_available:
         | 
| 177 | 
            +
                    logger.debug(f"Successfully imported onnxruntime version {_onnxruntime_version}")
         | 
| 178 | 
            +
             | 
| 179 | 
            +
            # (sayakpaul): importlib.util.find_spec("opencv-python") returns None even when it's installed.
         | 
| 180 | 
            +
            # _opencv_available = importlib.util.find_spec("opencv-python") is not None
         | 
| 181 | 
            +
            try:
         | 
| 182 | 
            +
                candidates = (
         | 
| 183 | 
            +
                    "opencv-python",
         | 
| 184 | 
            +
                    "opencv-contrib-python",
         | 
| 185 | 
            +
                    "opencv-python-headless",
         | 
| 186 | 
            +
                    "opencv-contrib-python-headless",
         | 
| 187 | 
            +
                )
         | 
| 188 | 
            +
                _opencv_version = None
         | 
| 189 | 
            +
                for pkg in candidates:
         | 
| 190 | 
            +
                    try:
         | 
| 191 | 
            +
                        _opencv_version = importlib_metadata.version(pkg)
         | 
| 192 | 
            +
                        break
         | 
| 193 | 
            +
                    except importlib_metadata.PackageNotFoundError:
         | 
| 194 | 
            +
                        pass
         | 
| 195 | 
            +
                _opencv_available = _opencv_version is not None
         | 
| 196 | 
            +
                if _opencv_available:
         | 
| 197 | 
            +
                    logger.debug(f"Successfully imported cv2 version {_opencv_version}")
         | 
| 198 | 
            +
            except importlib_metadata.PackageNotFoundError:
         | 
| 199 | 
            +
                _opencv_available = False
         | 
| 200 | 
            +
             | 
| 201 | 
            +
            _scipy_available = importlib.util.find_spec("scipy") is not None
         | 
| 202 | 
            +
            try:
         | 
| 203 | 
            +
                _scipy_version = importlib_metadata.version("scipy")
         | 
| 204 | 
            +
                logger.debug(f"Successfully imported scipy version {_scipy_version}")
         | 
| 205 | 
            +
            except importlib_metadata.PackageNotFoundError:
         | 
| 206 | 
            +
                _scipy_available = False
         | 
| 207 | 
            +
             | 
| 208 | 
            +
            _librosa_available = importlib.util.find_spec("librosa") is not None
         | 
| 209 | 
            +
            try:
         | 
| 210 | 
            +
                _librosa_version = importlib_metadata.version("librosa")
         | 
| 211 | 
            +
                logger.debug(f"Successfully imported librosa version {_librosa_version}")
         | 
| 212 | 
            +
            except importlib_metadata.PackageNotFoundError:
         | 
| 213 | 
            +
                _librosa_available = False
         | 
| 214 | 
            +
             | 
| 215 | 
            +
            _accelerate_available = importlib.util.find_spec("accelerate") is not None
         | 
| 216 | 
            +
            try:
         | 
| 217 | 
            +
                _accelerate_version = importlib_metadata.version("accelerate")
         | 
| 218 | 
            +
                logger.debug(f"Successfully imported accelerate version {_accelerate_version}")
         | 
| 219 | 
            +
            except importlib_metadata.PackageNotFoundError:
         | 
| 220 | 
            +
                _accelerate_available = False
         | 
| 221 | 
            +
             | 
| 222 | 
            +
            _xformers_available = importlib.util.find_spec("xformers") is not None
         | 
| 223 | 
            +
            try:
         | 
| 224 | 
            +
                _xformers_version = importlib_metadata.version("xformers")
         | 
| 225 | 
            +
                if _torch_available:
         | 
| 226 | 
            +
                    import torch
         | 
| 227 | 
            +
             | 
| 228 | 
            +
                    if version.Version(torch.__version__) < version.Version("1.12"):
         | 
| 229 | 
            +
                        raise ValueError("PyTorch should be >= 1.12")
         | 
| 230 | 
            +
                logger.debug(f"Successfully imported xformers version {_xformers_version}")
         | 
| 231 | 
            +
            except importlib_metadata.PackageNotFoundError:
         | 
| 232 | 
            +
                _xformers_available = False
         | 
| 233 | 
            +
             | 
| 234 | 
            +
            _k_diffusion_available = importlib.util.find_spec("k_diffusion") is not None
         | 
| 235 | 
            +
            try:
         | 
| 236 | 
            +
                _k_diffusion_version = importlib_metadata.version("k_diffusion")
         | 
| 237 | 
            +
                logger.debug(f"Successfully imported k-diffusion version {_k_diffusion_version}")
         | 
| 238 | 
            +
            except importlib_metadata.PackageNotFoundError:
         | 
| 239 | 
            +
                _k_diffusion_available = False
         | 
| 240 | 
            +
             | 
| 241 | 
            +
            _note_seq_available = importlib.util.find_spec("note_seq") is not None
         | 
| 242 | 
            +
            try:
         | 
| 243 | 
            +
                _note_seq_version = importlib_metadata.version("note_seq")
         | 
| 244 | 
            +
                logger.debug(f"Successfully imported note-seq version {_note_seq_version}")
         | 
| 245 | 
            +
            except importlib_metadata.PackageNotFoundError:
         | 
| 246 | 
            +
                _note_seq_available = False
         | 
| 247 | 
            +
             | 
| 248 | 
            +
            _wandb_available = importlib.util.find_spec("wandb") is not None
         | 
| 249 | 
            +
            try:
         | 
| 250 | 
            +
                _wandb_version = importlib_metadata.version("wandb")
         | 
| 251 | 
            +
                logger.debug(f"Successfully imported wandb version {_wandb_version}")
         | 
| 252 | 
            +
            except importlib_metadata.PackageNotFoundError:
         | 
| 253 | 
            +
                _wandb_available = False
         | 
| 254 | 
            +
             | 
| 255 | 
            +
            _omegaconf_available = importlib.util.find_spec("omegaconf") is not None
         | 
| 256 | 
            +
            try:
         | 
| 257 | 
            +
                _omegaconf_version = importlib_metadata.version("omegaconf")
         | 
| 258 | 
            +
                logger.debug(f"Successfully imported omegaconf version {_omegaconf_version}")
         | 
| 259 | 
            +
            except importlib_metadata.PackageNotFoundError:
         | 
| 260 | 
            +
                _omegaconf_available = False
         | 
| 261 | 
            +
             | 
| 262 | 
            +
            _tensorboard_available = importlib.util.find_spec("tensorboard") is not None
         | 
| 263 | 
            +
            try:
         | 
| 264 | 
            +
                _tensorboard_version = importlib_metadata.version("tensorboard")
         | 
| 265 | 
            +
                logger.debug(f"Successfully imported tensorboard version {_tensorboard_version}")
         | 
| 266 | 
            +
            except importlib_metadata.PackageNotFoundError:
         | 
| 267 | 
            +
                _tensorboard_available = False
         | 
| 268 | 
            +
             | 
| 269 | 
            +
             | 
| 270 | 
            +
            _compel_available = importlib.util.find_spec("compel") is not None
         | 
| 271 | 
            +
            try:
         | 
| 272 | 
            +
                _compel_version = importlib_metadata.version("compel")
         | 
| 273 | 
            +
                logger.debug(f"Successfully imported compel version {_compel_version}")
         | 
| 274 | 
            +
            except importlib_metadata.PackageNotFoundError:
         | 
| 275 | 
            +
                _compel_available = False
         | 
| 276 | 
            +
             | 
| 277 | 
            +
             | 
| 278 | 
            +
            _ftfy_available = importlib.util.find_spec("ftfy") is not None
         | 
| 279 | 
            +
            try:
         | 
| 280 | 
            +
                _ftfy_version = importlib_metadata.version("ftfy")
         | 
| 281 | 
            +
                logger.debug(f"Successfully imported ftfy version {_ftfy_version}")
         | 
| 282 | 
            +
            except importlib_metadata.PackageNotFoundError:
         | 
| 283 | 
            +
                _ftfy_available = False
         | 
| 284 | 
            +
             | 
| 285 | 
            +
             | 
| 286 | 
            +
            _bs4_available = importlib.util.find_spec("bs4") is not None
         | 
| 287 | 
            +
            try:
         | 
| 288 | 
            +
                # importlib metadata under different name
         | 
| 289 | 
            +
                _bs4_version = importlib_metadata.version("beautifulsoup4")
         | 
| 290 | 
            +
                logger.debug(f"Successfully imported beautifulsoup4 version {_bs4_version}")
         | 
| 291 | 
            +
            except importlib_metadata.PackageNotFoundError:
         | 
| 292 | 
            +
                _bs4_available = False
         | 
| 293 | 
            +
             | 
| 294 | 
            +
            _torchsde_available = importlib.util.find_spec("torchsde") is not None
         | 
| 295 | 
            +
            try:
         | 
| 296 | 
            +
                _torchsde_version = importlib_metadata.version("torchsde")
         | 
| 297 | 
            +
                logger.debug(f"Successfully imported torchsde version {_torchsde_version}")
         | 
| 298 | 
            +
            except importlib_metadata.PackageNotFoundError:
         | 
| 299 | 
            +
                _torchsde_available = False
         | 
| 300 | 
            +
             | 
| 301 | 
            +
             | 
| 302 | 
            +
            def is_torch_available():
         | 
| 303 | 
            +
                return _torch_available
         | 
| 304 | 
            +
             | 
| 305 | 
            +
             | 
| 306 | 
            +
            def is_safetensors_available():
         | 
| 307 | 
            +
                return _safetensors_available
         | 
| 308 | 
            +
             | 
| 309 | 
            +
             | 
| 310 | 
            +
            def is_tf_available():
         | 
| 311 | 
            +
                return _tf_available
         | 
| 312 | 
            +
             | 
| 313 | 
            +
             | 
| 314 | 
            +
            def is_flax_available():
         | 
| 315 | 
            +
                return _flax_available
         | 
| 316 | 
            +
             | 
| 317 | 
            +
             | 
| 318 | 
            +
            def is_transformers_available():
         | 
| 319 | 
            +
                return _transformers_available
         | 
| 320 | 
            +
             | 
| 321 | 
            +
             | 
| 322 | 
            +
            def is_inflect_available():
         | 
| 323 | 
            +
                return _inflect_available
         | 
| 324 | 
            +
             | 
| 325 | 
            +
             | 
| 326 | 
            +
            def is_unidecode_available():
         | 
| 327 | 
            +
                return _unidecode_available
         | 
| 328 | 
            +
             | 
| 329 | 
            +
             | 
| 330 | 
            +
            def is_onnx_available():
         | 
| 331 | 
            +
                return _onnx_available
         | 
| 332 | 
            +
             | 
| 333 | 
            +
             | 
| 334 | 
            +
            def is_opencv_available():
         | 
| 335 | 
            +
                return _opencv_available
         | 
| 336 | 
            +
             | 
| 337 | 
            +
             | 
| 338 | 
            +
            def is_scipy_available():
         | 
| 339 | 
            +
                return _scipy_available
         | 
| 340 | 
            +
             | 
| 341 | 
            +
             | 
| 342 | 
            +
            def is_librosa_available():
         | 
| 343 | 
            +
                return _librosa_available
         | 
| 344 | 
            +
             | 
| 345 | 
            +
             | 
| 346 | 
            +
            def is_xformers_available():
         | 
| 347 | 
            +
                return _xformers_available
         | 
| 348 | 
            +
             | 
| 349 | 
            +
             | 
| 350 | 
            +
            def is_accelerate_available():
         | 
| 351 | 
            +
                return _accelerate_available
         | 
| 352 | 
            +
             | 
| 353 | 
            +
             | 
| 354 | 
            +
            def is_k_diffusion_available():
         | 
| 355 | 
            +
                return _k_diffusion_available
         | 
| 356 | 
            +
             | 
| 357 | 
            +
             | 
| 358 | 
            +
            def is_note_seq_available():
         | 
| 359 | 
            +
                return _note_seq_available
         | 
| 360 | 
            +
             | 
| 361 | 
            +
             | 
| 362 | 
            +
            def is_wandb_available():
         | 
| 363 | 
            +
                return _wandb_available
         | 
| 364 | 
            +
             | 
| 365 | 
            +
             | 
| 366 | 
            +
            def is_omegaconf_available():
         | 
| 367 | 
            +
                return _omegaconf_available
         | 
| 368 | 
            +
             | 
| 369 | 
            +
             | 
| 370 | 
            +
            def is_tensorboard_available():
         | 
| 371 | 
            +
                return _tensorboard_available
         | 
| 372 | 
            +
             | 
| 373 | 
            +
             | 
| 374 | 
            +
            def is_compel_available():
         | 
| 375 | 
            +
                return _compel_available
         | 
| 376 | 
            +
             | 
| 377 | 
            +
             | 
| 378 | 
            +
            def is_ftfy_available():
         | 
| 379 | 
            +
                return _ftfy_available
         | 
| 380 | 
            +
             | 
| 381 | 
            +
             | 
| 382 | 
            +
            def is_bs4_available():
         | 
| 383 | 
            +
                return _bs4_available
         | 
| 384 | 
            +
             | 
| 385 | 
            +
             | 
| 386 | 
            +
            def is_torchsde_available():
         | 
| 387 | 
            +
                return _torchsde_available
         | 
| 388 | 
            +
             | 
| 389 | 
            +
             | 
| 390 | 
            +
            # docstyle-ignore
         | 
| 391 | 
            +
            FLAX_IMPORT_ERROR = """
         | 
| 392 | 
            +
            {0} requires the FLAX library but it was not found in your environment. 
         | 
| 393 | 
            +
            Checkout the instructions on the installation page: https://github.com/google/flax 
         | 
| 394 | 
            +
            and follow the ones that match your environment.
         | 
| 395 | 
            +
            """
         | 
| 396 | 
            +
             | 
| 397 | 
            +
            # docstyle-ignore
         | 
| 398 | 
            +
            INFLECT_IMPORT_ERROR = """
         | 
| 399 | 
            +
            {0} requires the inflect library but it was not found in your environment. 
         | 
| 400 | 
            +
            You can install it with pip: `pip install inflect`
         | 
| 401 | 
            +
            """
         | 
| 402 | 
            +
             | 
| 403 | 
            +
            # docstyle-ignore
         | 
| 404 | 
            +
            PYTORCH_IMPORT_ERROR = """
         | 
| 405 | 
            +
            {0} requires the PyTorch library but it was not found in your environment. 
         | 
| 406 | 
            +
            Checkout the instructions on the installation page: https://pytorch.org/get-started/locally/ 
         | 
| 407 | 
            +
            and follow the ones that match your environment.
         | 
| 408 | 
            +
            """
         | 
| 409 | 
            +
             | 
| 410 | 
            +
            # docstyle-ignore
         | 
| 411 | 
            +
            ONNX_IMPORT_ERROR = """
         | 
| 412 | 
            +
            {0} requires the onnxruntime library but it was not found in your environment. 
         | 
| 413 | 
            +
            You can install it with pip: `pip install onnxruntime`
         | 
| 414 | 
            +
            """
         | 
| 415 | 
            +
             | 
| 416 | 
            +
            # docstyle-ignore
         | 
| 417 | 
            +
            OPENCV_IMPORT_ERROR = """
         | 
| 418 | 
            +
            {0} requires the OpenCV library but it was not found in your environment. 
         | 
| 419 | 
            +
            You can install it with pip: `pip install opencv-python`
         | 
| 420 | 
            +
            """
         | 
| 421 | 
            +
             | 
| 422 | 
            +
            # docstyle-ignore
         | 
| 423 | 
            +
            SCIPY_IMPORT_ERROR = """
         | 
| 424 | 
            +
            {0} requires the scipy library but it was not found in your environment. 
         | 
| 425 | 
            +
            You can install it with pip: `pip install scipy`
         | 
| 426 | 
            +
            """
         | 
| 427 | 
            +
             | 
| 428 | 
            +
            # docstyle-ignore
         | 
| 429 | 
            +
            LIBROSA_IMPORT_ERROR = """
         | 
| 430 | 
            +
            {0} requires the librosa library but it was not found in your environment. 
         | 
| 431 | 
            +
            Checkout the instructions on the installation page: https://librosa.org/doc/latest/install.html 
         | 
| 432 | 
            +
            and follow the ones that match your environment.
         | 
| 433 | 
            +
            """
         | 
| 434 | 
            +
             | 
| 435 | 
            +
            # docstyle-ignore
         | 
| 436 | 
            +
            TRANSFORMERS_IMPORT_ERROR = """
         | 
| 437 | 
            +
            {0} requires the transformers library but it was not found in your environment. 
         | 
| 438 | 
            +
            You can install it with pip: `pip install transformers`
         | 
| 439 | 
            +
            """
         | 
| 440 | 
            +
             | 
| 441 | 
            +
            # docstyle-ignore
         | 
| 442 | 
            +
            UNIDECODE_IMPORT_ERROR = """
         | 
| 443 | 
            +
            {0} requires the unidecode library but it was not found in your environment. 
         | 
| 444 | 
            +
            You can install it with pip: `pip install Unidecode`
         | 
| 445 | 
            +
            """
         | 
| 446 | 
            +
             | 
| 447 | 
            +
            # docstyle-ignore
         | 
| 448 | 
            +
            K_DIFFUSION_IMPORT_ERROR = """
         | 
| 449 | 
            +
            {0} requires the k-diffusion library but it was not found in your environment. 
         | 
| 450 | 
            +
            You can install it with pip: `pip install k-diffusion`
         | 
| 451 | 
            +
            """
         | 
| 452 | 
            +
             | 
| 453 | 
            +
            # docstyle-ignore
         | 
| 454 | 
            +
            NOTE_SEQ_IMPORT_ERROR = """
         | 
| 455 | 
            +
            {0} requires the note-seq library but it was not found in your environment. 
         | 
| 456 | 
            +
            You can install it with pip: `pip install note-seq`
         | 
| 457 | 
            +
            """
         | 
| 458 | 
            +
             | 
| 459 | 
            +
            # docstyle-ignore
         | 
| 460 | 
            +
            WANDB_IMPORT_ERROR = """
         | 
| 461 | 
            +
            {0} requires the wandb library but it was not found in your environment. 
         | 
| 462 | 
            +
            You can install it with pip: `pip install wandb`
         | 
| 463 | 
            +
            """
         | 
| 464 | 
            +
             | 
| 465 | 
            +
            # docstyle-ignore
         | 
| 466 | 
            +
            OMEGACONF_IMPORT_ERROR = """
         | 
| 467 | 
            +
            {0} requires the omegaconf library but it was not found in your environment. 
         | 
| 468 | 
            +
            You can install it with pip: `pip install omegaconf`
         | 
| 469 | 
            +
            """
         | 
| 470 | 
            +
             | 
| 471 | 
            +
            # docstyle-ignore
         | 
| 472 | 
            +
            TENSORBOARD_IMPORT_ERROR = """
         | 
| 473 | 
            +
            {0} requires the tensorboard library but it was not found in your environment. 
         | 
| 474 | 
            +
            You can install it with pip: `pip install tensorboard`
         | 
| 475 | 
            +
            """
         | 
| 476 | 
            +
             | 
| 477 | 
            +
             | 
| 478 | 
            +
            # docstyle-ignore
         | 
| 479 | 
            +
            COMPEL_IMPORT_ERROR = """
         | 
| 480 | 
            +
            {0} requires the compel library but it was not found in your environment. 
         | 
| 481 | 
            +
            You can install it with pip: `pip install compel`
         | 
| 482 | 
            +
            """
         | 
| 483 | 
            +
             | 
| 484 | 
            +
            # docstyle-ignore
         | 
| 485 | 
            +
            BS4_IMPORT_ERROR = """
         | 
| 486 | 
            +
            {0} requires the Beautiful Soup library but it was not found in your environment. You can install it with pip:
         | 
| 487 | 
            +
            `pip install beautifulsoup4`. Please note that you may need to restart your runtime after installation.
         | 
| 488 | 
            +
            """
         | 
| 489 | 
            +
             | 
| 490 | 
            +
            # docstyle-ignore
         | 
| 491 | 
            +
            FTFY_IMPORT_ERROR = """
         | 
| 492 | 
            +
            {0} requires the ftfy library but it was not found in your environment. Checkout the instructions on the
         | 
| 493 | 
            +
            installation section: https://github.com/rspeer/python-ftfy/tree/master#installing and follow the ones
         | 
| 494 | 
            +
            that match your environment. Please note that you may need to restart your runtime after installation.
         | 
| 495 | 
            +
            """
         | 
| 496 | 
            +
             | 
| 497 | 
            +
            # docstyle-ignore
         | 
| 498 | 
            +
            TORCHSDE_IMPORT_ERROR = """
         | 
| 499 | 
            +
            {0} requires the torchsde library but it was not found in your environment. 
         | 
| 500 | 
            +
            You can install it with pip: `pip install torchsde`
         | 
| 501 | 
            +
            """
         | 
| 502 | 
            +
             | 
| 503 | 
            +
             | 
| 504 | 
            +
            BACKENDS_MAPPING = OrderedDict(
         | 
| 505 | 
            +
                [
         | 
| 506 | 
            +
                    ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
         | 
| 507 | 
            +
                    ("flax", (is_flax_available, FLAX_IMPORT_ERROR)),
         | 
| 508 | 
            +
                    ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)),
         | 
| 509 | 
            +
                    ("onnx", (is_onnx_available, ONNX_IMPORT_ERROR)),
         | 
| 510 | 
            +
                    ("opencv", (is_opencv_available, OPENCV_IMPORT_ERROR)),
         | 
| 511 | 
            +
                    ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
         | 
| 512 | 
            +
                    ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
         | 
| 513 | 
            +
                    ("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)),
         | 
| 514 | 
            +
                    ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)),
         | 
| 515 | 
            +
                    ("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)),
         | 
| 516 | 
            +
                    ("k_diffusion", (is_k_diffusion_available, K_DIFFUSION_IMPORT_ERROR)),
         | 
| 517 | 
            +
                    ("note_seq", (is_note_seq_available, NOTE_SEQ_IMPORT_ERROR)),
         | 
| 518 | 
            +
                    ("wandb", (is_wandb_available, WANDB_IMPORT_ERROR)),
         | 
| 519 | 
            +
                    ("omegaconf", (is_omegaconf_available, OMEGACONF_IMPORT_ERROR)),
         | 
| 520 | 
            +
                    ("tensorboard", (_tensorboard_available, TENSORBOARD_IMPORT_ERROR)),
         | 
| 521 | 
            +
                    ("compel", (_compel_available, COMPEL_IMPORT_ERROR)),
         | 
| 522 | 
            +
                    ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)),
         | 
| 523 | 
            +
                    ("torchsde", (_torchsde_available, TORCHSDE_IMPORT_ERROR)),
         | 
| 524 | 
            +
                ]
         | 
| 525 | 
            +
            )
         | 
| 526 | 
            +
             | 
| 527 | 
            +
             | 
| 528 | 
            +
            def requires_backends(obj, backends):
         | 
| 529 | 
            +
                if not isinstance(backends, (list, tuple)):
         | 
| 530 | 
            +
                    backends = [backends]
         | 
| 531 | 
            +
             | 
| 532 | 
            +
                name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
         | 
| 533 | 
            +
                checks = (BACKENDS_MAPPING[backend] for backend in backends)
         | 
| 534 | 
            +
                failed = [msg.format(name) for available, msg in checks if not available()]
         | 
| 535 | 
            +
                if failed:
         | 
| 536 | 
            +
                    raise ImportError("".join(failed))
         | 
| 537 | 
            +
             | 
| 538 | 
            +
                if name in [
         | 
| 539 | 
            +
                    "VersatileDiffusionTextToImagePipeline",
         | 
| 540 | 
            +
                    "VersatileDiffusionPipeline",
         | 
| 541 | 
            +
                    "VersatileDiffusionDualGuidedPipeline",
         | 
| 542 | 
            +
                    "StableDiffusionImageVariationPipeline",
         | 
| 543 | 
            +
                    "UnCLIPPipeline",
         | 
| 544 | 
            +
                ] and is_transformers_version("<", "4.25.0"):
         | 
| 545 | 
            +
                    raise ImportError(
         | 
| 546 | 
            +
                        f"You need to install `transformers>=4.25` in order to use {name}: \n```\n pip install"
         | 
| 547 | 
            +
                        " --upgrade transformers \n```"
         | 
| 548 | 
            +
                    )
         | 
| 549 | 
            +
             | 
| 550 | 
            +
                if name in ["StableDiffusionDepth2ImgPipeline", "StableDiffusionPix2PixZeroPipeline"] and is_transformers_version(
         | 
| 551 | 
            +
                    "<", "4.26.0"
         | 
| 552 | 
            +
                ):
         | 
| 553 | 
            +
                    raise ImportError(
         | 
| 554 | 
            +
                        f"You need to install `transformers>=4.26` in order to use {name}: \n```\n pip install"
         | 
| 555 | 
            +
                        " --upgrade transformers \n```"
         | 
| 556 | 
            +
                    )
         | 
| 557 | 
            +
             | 
| 558 | 
            +
             | 
| 559 | 
            +
            class DummyObject(type):
         | 
| 560 | 
            +
                """
         | 
| 561 | 
            +
                Metaclass for the dummy objects. Any class using this metaclass will return the ImportError generated by
         | 
| 562 | 
            +
                `requires_backends` each time a user tries to access any method of that class.
         | 
| 563 | 
            +
                """
         | 
| 564 | 
            +
             | 
| 565 | 
            +
                def __getattr__(cls, key):
         | 
| 566 | 
            +
                    if key.startswith("_"):
         | 
| 567 | 
            +
                        return super().__getattr__(cls, key)
         | 
| 568 | 
            +
                    requires_backends(cls, cls._backends)
         | 
| 569 | 
            +
             | 
| 570 | 
            +
             | 
| 571 | 
            +
            # This function was copied from: 
         | 
| 572 | 
            +
            # https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319
         | 
| 573 | 
            +
            def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str):
         | 
| 574 | 
            +
                """
         | 
| 575 | 
            +
                Compares a library version to some requirement using a given operation.
         | 
| 576 | 
            +
                Args:
         | 
| 577 | 
            +
                    library_or_version (`str` or `packaging.version.Version`):
         | 
| 578 | 
            +
                        A library name or a version to check.
         | 
| 579 | 
            +
                    operation (`str`):
         | 
| 580 | 
            +
                        A string representation of an operator, such as `">"` or `"<="`.
         | 
| 581 | 
            +
                    requirement_version (`str`):
         | 
| 582 | 
            +
                        The version to compare the library version against
         | 
| 583 | 
            +
                """
         | 
| 584 | 
            +
                if operation not in STR_OPERATION_TO_FUNC.keys():
         | 
| 585 | 
            +
                    raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}")
         | 
| 586 | 
            +
                operation = STR_OPERATION_TO_FUNC[operation]
         | 
| 587 | 
            +
                if isinstance(library_or_version, str):
         | 
| 588 | 
            +
                    library_or_version = parse(importlib_metadata.version(library_or_version))
         | 
| 589 | 
            +
                return operation(library_or_version, parse(requirement_version))
         | 
| 590 | 
            +
             | 
| 591 | 
            +
             | 
| 592 | 
            +
            # This function was copied from: 
         | 
| 593 | 
            +
            # https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338
         | 
| 594 | 
            +
            def is_torch_version(operation: str, version: str):
         | 
| 595 | 
            +
                """
         | 
| 596 | 
            +
                Args:
         | 
| 597 | 
            +
                Compares the current PyTorch version to a given reference with an operation.
         | 
| 598 | 
            +
                    operation (`str`):
         | 
| 599 | 
            +
                        A string representation of an operator, such as `">"` or `"<="`
         | 
| 600 | 
            +
                    version (`str`):
         | 
| 601 | 
            +
                        A string version of PyTorch
         | 
| 602 | 
            +
                """
         | 
| 603 | 
            +
                return compare_versions(parse(_torch_version), operation, version)
         | 
| 604 | 
            +
             | 
| 605 | 
            +
             | 
| 606 | 
            +
            def is_transformers_version(operation: str, version: str):
         | 
| 607 | 
            +
                """
         | 
| 608 | 
            +
                Compares the current Transformers version to a given reference with an operation.
         | 
| 609 | 
            +
                Args:
         | 
| 610 | 
            +
                    operation (`str`):
         | 
| 611 | 
            +
                        A string representation of an operator, such as `">"` or `"<="`
         | 
| 612 | 
            +
                    version (`str`):
         | 
| 613 | 
            +
                        A version string
         | 
| 614 | 
            +
                """
         | 
| 615 | 
            +
                if not _transformers_available:
         | 
| 616 | 
            +
                    return False
         | 
| 617 | 
            +
                return compare_versions(parse(_transformers_version), operation, version)
         | 
| 618 | 
            +
             | 
| 619 | 
            +
             | 
| 620 | 
            +
            def is_accelerate_version(operation: str, version: str):
         | 
| 621 | 
            +
                """
         | 
| 622 | 
            +
                Compares the current Accelerate version to a given reference with an operation.
         | 
| 623 | 
            +
                Args:
         | 
| 624 | 
            +
                    operation (`str`):
         | 
| 625 | 
            +
                        A string representation of an operator, such as `">"` or `"<="`
         | 
| 626 | 
            +
                    version (`str`):
         | 
| 627 | 
            +
                        A version string
         | 
| 628 | 
            +
                """
         | 
| 629 | 
            +
                if not _accelerate_available:
         | 
| 630 | 
            +
                    return False
         | 
| 631 | 
            +
                return compare_versions(parse(_accelerate_version), operation, version)
         | 
| 632 | 
            +
             | 
| 633 | 
            +
             | 
| 634 | 
            +
            def is_k_diffusion_version(operation: str, version: str):
         | 
| 635 | 
            +
                """
         | 
| 636 | 
            +
                Compares the current k-diffusion version to a given reference with an operation.
         | 
| 637 | 
            +
                Args:
         | 
| 638 | 
            +
                    operation (`str`):
         | 
| 639 | 
            +
                        A string representation of an operator, such as `">"` or `"<="`
         | 
| 640 | 
            +
                    version (`str`):
         | 
| 641 | 
            +
                        A version string
         | 
| 642 | 
            +
                """
         | 
| 643 | 
            +
                if not _k_diffusion_available:
         | 
| 644 | 
            +
                    return False
         | 
| 645 | 
            +
                return compare_versions(parse(_k_diffusion_version), operation, version)
         | 
| 646 | 
            +
             | 
| 647 | 
            +
             | 
| 648 | 
            +
            class OptionalDependencyNotAvailable(BaseException):
         | 
| 649 | 
            +
                """An error indicating that an optional dependency of Diffusers was not found in the environment."""
         | 
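
For context, a minimal sketch of how the helpers defined in this file are meant to be consumed (the function and class names below are illustrative, not part of this commit): `requires_backends` is called at the top of a function or method to fail fast with the matching `*_IMPORT_ERROR` message, and `DummyObject` serves as a metaclass for placeholder classes whose attribute access should raise that same error.

```python
# Illustrative sketch only; `load_waveform` and `LibrosaOnlyFeature` are hypothetical names.
# Assumes it lives alongside the definitions above (requires_backends, DummyObject).

def load_waveform(path):
    # Fails immediately with LIBROSA_IMPORT_ERROR if librosa is not installed.
    requires_backends(load_waveform, ["librosa"])
    import librosa
    return librosa.load(path, sr=16000)


class LibrosaOnlyFeature(metaclass=DummyObject):
    # Any attribute access (e.g. LibrosaOnlyFeature.from_pretrained) triggers
    # requires_backends(cls, cls._backends) and therefore a helpful ImportError.
    _backends = ["librosa"]
```
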
    	
        diffusers/utils/logging.py
    ADDED
    
    | @@ -0,0 +1,342 @@ | |
| 1 | 
            +
            # coding=utf-8
         | 
| 2 | 
            +
            # Copyright 2023 Optuna, Hugging Face
         | 
| 3 | 
            +
            #
         | 
| 4 | 
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 5 | 
            +
            # you may not use this file except in compliance with the License.
         | 
| 6 | 
            +
            # You may obtain a copy of the License at
         | 
| 7 | 
            +
            #
         | 
| 8 | 
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 9 | 
            +
            #
         | 
| 10 | 
            +
            # Unless required by applicable law or agreed to in writing, software
         | 
| 11 | 
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 12 | 
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 13 | 
            +
            # See the License for the specific language governing permissions and
         | 
| 14 | 
            +
            # limitations under the License.
         | 
| 15 | 
            +
            """ Logging utilities."""
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            import logging
         | 
| 18 | 
            +
            import os
         | 
| 19 | 
            +
            import sys
         | 
| 20 | 
            +
            import threading
         | 
| 21 | 
            +
            from logging import (
         | 
| 22 | 
            +
                CRITICAL,  # NOQA
         | 
| 23 | 
            +
                DEBUG,  # NOQA
         | 
| 24 | 
            +
                ERROR,  # NOQA
         | 
| 25 | 
            +
                FATAL,  # NOQA
         | 
| 26 | 
            +
                INFO,  # NOQA
         | 
| 27 | 
            +
                NOTSET,  # NOQA
         | 
| 28 | 
            +
                WARN,  # NOQA
         | 
| 29 | 
            +
                WARNING,  # NOQA
         | 
| 30 | 
            +
            )
         | 
| 31 | 
            +
            from typing import Optional
         | 
| 32 | 
            +
             | 
| 33 | 
            +
            from tqdm import auto as tqdm_lib
         | 
| 34 | 
            +
             | 
| 35 | 
            +
             | 
| 36 | 
            +
            _lock = threading.Lock()
         | 
| 37 | 
            +
            _default_handler: Optional[logging.Handler] = None
         | 
| 38 | 
            +
             | 
| 39 | 
            +
            log_levels = {
         | 
| 40 | 
            +
                "debug": logging.DEBUG,
         | 
| 41 | 
            +
                "info": logging.INFO,
         | 
| 42 | 
            +
                "warning": logging.WARNING,
         | 
| 43 | 
            +
                "error": logging.ERROR,
         | 
| 44 | 
            +
                "critical": logging.CRITICAL,
         | 
| 45 | 
            +
            }
         | 
| 46 | 
            +
             | 
| 47 | 
            +
            _default_log_level = logging.WARNING
         | 
| 48 | 
            +
             | 
| 49 | 
            +
            _tqdm_active = True
         | 
| 50 | 
            +
             | 
| 51 | 
            +
             | 
| 52 | 
            +
            def _get_default_logging_level():
         | 
| 53 | 
            +
                """
         | 
| 54 | 
            +
                If the DIFFUSERS_VERBOSITY env var is set to one of the valid choices, return that as the new default level. If it is
         | 
| 55 | 
            +
                not, fall back to `_default_log_level`.
         | 
| 56 | 
            +
                """
         | 
| 57 | 
            +
                env_level_str = os.getenv("DIFFUSERS_VERBOSITY", None)
         | 
| 58 | 
            +
                if env_level_str:
         | 
| 59 | 
            +
                    if env_level_str in log_levels:
         | 
| 60 | 
            +
                        return log_levels[env_level_str]
         | 
| 61 | 
            +
                    else:
         | 
| 62 | 
            +
                        logging.getLogger().warning(
         | 
| 63 | 
            +
                            f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, "
         | 
| 64 | 
            +
                            f"has to be one of: { ', '.join(log_levels.keys()) }"
         | 
| 65 | 
            +
                        )
         | 
| 66 | 
            +
                return _default_log_level
         | 
| 67 | 
            +
             | 
| 68 | 
            +
             | 
| 69 | 
            +
            def _get_library_name() -> str:
         | 
| 70 | 
            +
                return __name__.split(".")[0]
         | 
| 71 | 
            +
             | 
| 72 | 
            +
             | 
| 73 | 
            +
            def _get_library_root_logger() -> logging.Logger:
         | 
| 74 | 
            +
                return logging.getLogger(_get_library_name())
         | 
| 75 | 
            +
             | 
| 76 | 
            +
             | 
| 77 | 
            +
            def _configure_library_root_logger() -> None:
         | 
| 78 | 
            +
                global _default_handler
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                with _lock:
         | 
| 81 | 
            +
                    if _default_handler:
         | 
| 82 | 
            +
                        # This library has already configured the library root logger.
         | 
| 83 | 
            +
                        return
         | 
| 84 | 
            +
                    _default_handler = logging.StreamHandler()  # Set sys.stderr as stream.
         | 
| 85 | 
            +
                    _default_handler.flush = sys.stderr.flush
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                    # Apply our default configuration to the library root logger.
         | 
| 88 | 
            +
                    library_root_logger = _get_library_root_logger()
         | 
| 89 | 
            +
                    library_root_logger.addHandler(_default_handler)
         | 
| 90 | 
            +
                    library_root_logger.setLevel(_get_default_logging_level())
         | 
| 91 | 
            +
                    library_root_logger.propagate = False
         | 
| 92 | 
            +
             | 
| 93 | 
            +
             | 
| 94 | 
            +
            def _reset_library_root_logger() -> None:
         | 
| 95 | 
            +
                global _default_handler
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                with _lock:
         | 
| 98 | 
            +
                    if not _default_handler:
         | 
| 99 | 
            +
                        return
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                    library_root_logger = _get_library_root_logger()
         | 
| 102 | 
            +
                    library_root_logger.removeHandler(_default_handler)
         | 
| 103 | 
            +
                    library_root_logger.setLevel(logging.NOTSET)
         | 
| 104 | 
            +
                    _default_handler = None
         | 
| 105 | 
            +
             | 
| 106 | 
            +
             | 
| 107 | 
            +
            def get_log_levels_dict():
         | 
| 108 | 
            +
                return log_levels
         | 
| 109 | 
            +
             | 
| 110 | 
            +
             | 
| 111 | 
            +
            def get_logger(name: Optional[str] = None) -> logging.Logger:
         | 
| 112 | 
            +
                """
         | 
| 113 | 
            +
                Return a logger with the specified name.
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                This function is not supposed to be directly accessed unless you are writing a custom diffusers module.
         | 
| 116 | 
            +
                """
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                if name is None:
         | 
| 119 | 
            +
                    name = _get_library_name()
         | 
| 120 | 
            +
             | 
| 121 | 
            +
                _configure_library_root_logger()
         | 
| 122 | 
            +
                return logging.getLogger(name)
         | 
| 123 | 
            +
             | 
| 124 | 
            +
             | 
| 125 | 
            +
            def get_verbosity() -> int:
         | 
| 126 | 
            +
                """
         | 
| 127 | 
            +
                Return the current level for the 🤗 Diffusers' root logger as an int.
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                Returns:
         | 
| 130 | 
            +
                    `int`: The logging level.
         | 
| 131 | 
            +
             | 
| 132 | 
            +
                <Tip>
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                🤗 Diffusers has the following logging levels:
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                - 50: `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
         | 
| 137 | 
            +
                - 40: `diffusers.logging.ERROR`
         | 
| 138 | 
            +
                - 30: `diffusers.logging.WARNING` or `diffusers.logging.WARN`
         | 
| 139 | 
            +
                - 20: `diffusers.logging.INFO`
         | 
| 140 | 
            +
                - 10: `diffusers.logging.DEBUG`
         | 
| 141 | 
            +
             | 
| 142 | 
            +
                </Tip>"""
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                _configure_library_root_logger()
         | 
| 145 | 
            +
                return _get_library_root_logger().getEffectiveLevel()
         | 
| 146 | 
            +
             | 
| 147 | 
            +
             | 
| 148 | 
            +
            def set_verbosity(verbosity: int) -> None:
         | 
| 149 | 
            +
                """
         | 
| 150 | 
            +
                Set the verbosity level for the 🤗 Diffusers' root logger.
         | 
| 151 | 
            +
             | 
| 152 | 
            +
                Args:
         | 
| 153 | 
            +
                    verbosity (`int`):
         | 
| 154 | 
            +
                        Logging level, e.g., one of:
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                        - `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
         | 
| 157 | 
            +
                        - `diffusers.logging.ERROR`
         | 
| 158 | 
            +
                        - `diffusers.logging.WARNING` or `diffusers.logging.WARN`
         | 
| 159 | 
            +
                        - `diffusers.logging.INFO`
         | 
| 160 | 
            +
                        - `diffusers.logging.DEBUG`
         | 
| 161 | 
            +
                """
         | 
| 162 | 
            +
             | 
| 163 | 
            +
                _configure_library_root_logger()
         | 
| 164 | 
            +
                _get_library_root_logger().setLevel(verbosity)
         | 
| 165 | 
            +
             | 
| 166 | 
            +
             | 
| 167 | 
            +
            def set_verbosity_info():
         | 
| 168 | 
            +
                """Set the verbosity to the `INFO` level."""
         | 
| 169 | 
            +
                return set_verbosity(INFO)
         | 
| 170 | 
            +
             | 
| 171 | 
            +
             | 
| 172 | 
            +
            def set_verbosity_warning():
         | 
| 173 | 
            +
                """Set the verbosity to the `WARNING` level."""
         | 
| 174 | 
            +
                return set_verbosity(WARNING)
         | 
| 175 | 
            +
             | 
| 176 | 
            +
             | 
| 177 | 
            +
            def set_verbosity_debug():
         | 
| 178 | 
            +
                """Set the verbosity to the `DEBUG` level."""
         | 
| 179 | 
            +
                return set_verbosity(DEBUG)
         | 
| 180 | 
            +
             | 
| 181 | 
            +
             | 
| 182 | 
            +
            def set_verbosity_error():
         | 
| 183 | 
            +
                """Set the verbosity to the `ERROR` level."""
         | 
| 184 | 
            +
                return set_verbosity(ERROR)
         | 
| 185 | 
            +
             | 
| 186 | 
            +
             | 
| 187 | 
            +
            def disable_default_handler() -> None:
         | 
| 188 | 
            +
                """Disable the default handler of the HuggingFace Diffusers' root logger."""
         | 
| 189 | 
            +
             | 
| 190 | 
            +
                _configure_library_root_logger()
         | 
| 191 | 
            +
             | 
| 192 | 
            +
                assert _default_handler is not None
         | 
| 193 | 
            +
                _get_library_root_logger().removeHandler(_default_handler)
         | 
| 194 | 
            +
             | 
| 195 | 
            +
             | 
| 196 | 
            +
            def enable_default_handler() -> None:
         | 
| 197 | 
            +
                """Enable the default handler of the HuggingFace Diffusers' root logger."""
         | 
| 198 | 
            +
             | 
| 199 | 
            +
                _configure_library_root_logger()
         | 
| 200 | 
            +
             | 
| 201 | 
            +
                assert _default_handler is not None
         | 
| 202 | 
            +
                _get_library_root_logger().addHandler(_default_handler)
         | 
| 203 | 
            +
             | 
| 204 | 
            +
             | 
| 205 | 
            +
            def add_handler(handler: logging.Handler) -> None:
         | 
| 206 | 
            +
                """adds a handler to the HuggingFace Diffusers' root logger."""
         | 
| 207 | 
            +
             | 
| 208 | 
            +
                _configure_library_root_logger()
         | 
| 209 | 
            +
             | 
| 210 | 
            +
                assert handler is not None
         | 
| 211 | 
            +
                _get_library_root_logger().addHandler(handler)
         | 
| 212 | 
            +
             | 
| 213 | 
            +
             | 
| 214 | 
            +
            def remove_handler(handler: logging.Handler) -> None:
         | 
| 215 | 
            +
                """removes given handler from the HuggingFace Diffusers' root logger."""
         | 
| 216 | 
            +
             | 
| 217 | 
            +
                _configure_library_root_logger()
         | 
| 218 | 
            +
             | 
| 219 | 
            +
                assert handler is not None and handler not in _get_library_root_logger().handlers
         | 
| 220 | 
            +
                _get_library_root_logger().removeHandler(handler)
         | 
| 221 | 
            +
             | 
| 222 | 
            +
             | 
| 223 | 
            +
            def disable_propagation() -> None:
         | 
| 224 | 
            +
                """
         | 
| 225 | 
            +
                Disable propagation of the library log outputs. Note that log propagation is disabled by default.
         | 
| 226 | 
            +
                """
         | 
| 227 | 
            +
             | 
| 228 | 
            +
                _configure_library_root_logger()
         | 
| 229 | 
            +
                _get_library_root_logger().propagate = False
         | 
| 230 | 
            +
             | 
| 231 | 
            +
             | 
| 232 | 
            +
            def enable_propagation() -> None:
         | 
| 233 | 
            +
                """
         | 
| 234 | 
            +
                Enable propagation of the library log outputs. Please disable the HuggingFace Diffusers' default handler to prevent
         | 
| 235 | 
            +
                double logging if the root logger has been configured.
         | 
| 236 | 
            +
                """
         | 
| 237 | 
            +
             | 
| 238 | 
            +
                _configure_library_root_logger()
         | 
| 239 | 
            +
                _get_library_root_logger().propagate = True
         | 
| 240 | 
            +
             | 
| 241 | 
            +
             | 
| 242 | 
            +
            def enable_explicit_format() -> None:
         | 
| 243 | 
            +
                """
         | 
| 244 | 
            +
                Enable explicit formatting for every HuggingFace Diffusers' logger. The explicit formatter is as follows:
         | 
| 245 | 
            +
                ```
         | 
| 246 | 
            +
                    [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
         | 
| 247 | 
            +
                ```
         | 
| 248 | 
            +
                All handlers currently bound to the root logger are affected by this method.
         | 
| 249 | 
            +
                """
         | 
| 250 | 
            +
                handlers = _get_library_root_logger().handlers
         | 
| 251 | 
            +
             | 
| 252 | 
            +
                for handler in handlers:
         | 
| 253 | 
            +
                    formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s")
         | 
| 254 | 
            +
                    handler.setFormatter(formatter)
         | 
| 255 | 
            +
             | 
| 256 | 
            +
             | 
| 257 | 
            +
            def reset_format() -> None:
         | 
| 258 | 
            +
                """
         | 
| 259 | 
            +
                Resets the formatting for HuggingFace Diffusers' loggers.
         | 
| 260 | 
            +
             | 
| 261 | 
            +
                All handlers currently bound to the root logger are affected by this method.
         | 
| 262 | 
            +
                """
         | 
| 263 | 
            +
                handlers = _get_library_root_logger().handlers
         | 
| 264 | 
            +
             | 
| 265 | 
            +
                for handler in handlers:
         | 
| 266 | 
            +
                    handler.setFormatter(None)
         | 
| 267 | 
            +
             | 
| 268 | 
            +
             | 
| 269 | 
            +
            def warning_advice(self, *args, **kwargs):
         | 
| 270 | 
            +
                """
         | 
| 271 | 
            +
                This method is identical to `logger.warning()`, but if env var DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set, this
         | 
| 272 | 
            +
                warning will not be printed
         | 
| 273 | 
            +
                """
         | 
| 274 | 
            +
                no_advisory_warnings = os.getenv("DIFFUSERS_NO_ADVISORY_WARNINGS", False)
         | 
| 275 | 
            +
                if no_advisory_warnings:
         | 
| 276 | 
            +
                    return
         | 
| 277 | 
            +
                self.warning(*args, **kwargs)
         | 
| 278 | 
            +
             | 
| 279 | 
            +
             | 
| 280 | 
            +
            logging.Logger.warning_advice = warning_advice
         | 
| 281 | 
            +
             | 
| 282 | 
            +
             | 
| 283 | 
            +
            class EmptyTqdm:
         | 
| 284 | 
            +
                """Dummy tqdm which doesn't do anything."""
         | 
| 285 | 
            +
             | 
| 286 | 
            +
                def __init__(self, *args, **kwargs):  # pylint: disable=unused-argument
         | 
| 287 | 
            +
                    self._iterator = args[0] if args else None
         | 
| 288 | 
            +
             | 
| 289 | 
            +
                def __iter__(self):
         | 
| 290 | 
            +
                    return iter(self._iterator)
         | 
| 291 | 
            +
             | 
| 292 | 
            +
                def __getattr__(self, _):
         | 
| 293 | 
            +
                    """Return empty function."""
         | 
| 294 | 
            +
             | 
| 295 | 
            +
                    def empty_fn(*args, **kwargs):  # pylint: disable=unused-argument
         | 
| 296 | 
            +
                        return
         | 
| 297 | 
            +
             | 
| 298 | 
            +
                    return empty_fn
         | 
| 299 | 
            +
             | 
| 300 | 
            +
                def __enter__(self):
         | 
| 301 | 
            +
                    return self
         | 
| 302 | 
            +
             | 
| 303 | 
            +
                def __exit__(self, type_, value, traceback):
         | 
| 304 | 
            +
                    return
         | 
| 305 | 
            +
             | 
| 306 | 
            +
             | 
| 307 | 
            +
            class _tqdm_cls:
         | 
| 308 | 
            +
                def __call__(self, *args, **kwargs):
         | 
| 309 | 
            +
                    if _tqdm_active:
         | 
| 310 | 
            +
                        return tqdm_lib.tqdm(*args, **kwargs)
         | 
| 311 | 
            +
                    else:
         | 
| 312 | 
            +
                        return EmptyTqdm(*args, **kwargs)
         | 
| 313 | 
            +
             | 
| 314 | 
            +
                def set_lock(self, *args, **kwargs):
         | 
| 315 | 
            +
                    self._lock = None
         | 
| 316 | 
            +
                    if _tqdm_active:
         | 
| 317 | 
            +
                        return tqdm_lib.tqdm.set_lock(*args, **kwargs)
         | 
| 318 | 
            +
             | 
| 319 | 
            +
                def get_lock(self):
         | 
| 320 | 
            +
                    if _tqdm_active:
         | 
| 321 | 
            +
                        return tqdm_lib.tqdm.get_lock()
         | 
| 322 | 
            +
             | 
| 323 | 
            +
             | 
| 324 | 
            +
            tqdm = _tqdm_cls()
         | 
| 325 | 
            +
             | 
| 326 | 
            +
             | 
| 327 | 
            +
            def is_progress_bar_enabled() -> bool:
         | 
| 328 | 
            +
                """Return a boolean indicating whether tqdm progress bars are enabled."""
         | 
| 329 | 
            +
                global _tqdm_active
         | 
| 330 | 
            +
                return bool(_tqdm_active)
         | 
| 331 | 
            +
             | 
| 332 | 
            +
             | 
| 333 | 
            +
            def enable_progress_bar():
         | 
| 334 | 
            +
                """Enable tqdm progress bar."""
         | 
| 335 | 
            +
                global _tqdm_active
         | 
| 336 | 
            +
                _tqdm_active = True
         | 
| 337 | 
            +
             | 
| 338 | 
            +
             | 
| 339 | 
            +
            def disable_progress_bar():
         | 
| 340 | 
            +
                """Disable tqdm progress bar."""
         | 
| 341 | 
            +
                global _tqdm_active
         | 
| 342 | 
            +
                _tqdm_active = False
         | 
    	
        diffusers/utils/outputs.py
    ADDED
    
    | @@ -0,0 +1,108 @@ | |
| 1 | 
            +
            # Copyright 2023 The HuggingFace Team. All rights reserved.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 4 | 
            +
            # you may not use this file except in compliance with the License.
         | 
| 5 | 
            +
            # You may obtain a copy of the License at
         | 
| 6 | 
            +
            #
         | 
| 7 | 
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 8 | 
            +
            #
         | 
| 9 | 
            +
            # Unless required by applicable law or agreed to in writing, software
         | 
| 10 | 
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 11 | 
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 12 | 
            +
            # See the License for the specific language governing permissions and
         | 
| 13 | 
            +
            # limitations under the License.
         | 
| 14 | 
            +
            """
         | 
| 15 | 
            +
            Generic utilities
         | 
| 16 | 
            +
            """
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            from collections import OrderedDict
         | 
| 19 | 
            +
            from dataclasses import fields
         | 
| 20 | 
            +
            from typing import Any, Tuple
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            import numpy as np
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            from .import_utils import is_torch_available
         | 
| 25 | 
            +
             | 
| 26 | 
            +
             | 
| 27 | 
            +
            def is_tensor(x):
         | 
| 28 | 
            +
                """
         | 
| 29 | 
            +
                Tests if `x` is a `torch.Tensor` or `np.ndarray`.
         | 
| 30 | 
            +
                """
         | 
| 31 | 
            +
                if is_torch_available():
         | 
| 32 | 
            +
                    import torch
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                    if isinstance(x, torch.Tensor):
         | 
| 35 | 
            +
                        return True
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                return isinstance(x, np.ndarray)
         | 
| 38 | 
            +
             | 
| 39 | 
            +
             | 
| 40 | 
            +
            class BaseOutput(OrderedDict):
         | 
| 41 | 
            +
                """
         | 
| 42 | 
            +
                Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a
         | 
| 43 | 
            +
                tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular
         | 
| 44 | 
            +
                python dictionary.
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                <Tip warning={true}>
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                You can't unpack a `BaseOutput` directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple
         | 
| 49 | 
            +
                before.
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                </Tip>
         | 
| 52 | 
            +
                """
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                def __post_init__(self):
         | 
| 55 | 
            +
                    class_fields = fields(self)
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                    # Safety and consistency checks
         | 
| 58 | 
            +
                    if not len(class_fields):
         | 
| 59 | 
            +
                        raise ValueError(f"{self.__class__.__name__} has no fields.")
         | 
| 60 | 
            +
             | 
| 61 | 
            +
                    first_field = getattr(self, class_fields[0].name)
         | 
| 62 | 
            +
                    other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                    if other_fields_are_none and isinstance(first_field, dict):
         | 
| 65 | 
            +
                        for key, value in first_field.items():
         | 
| 66 | 
            +
                            self[key] = value
         | 
| 67 | 
            +
                    else:
         | 
| 68 | 
            +
                        for field in class_fields:
         | 
| 69 | 
            +
                            v = getattr(self, field.name)
         | 
| 70 | 
            +
                            if v is not None:
         | 
| 71 | 
            +
                                self[field.name] = v
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                def __delitem__(self, *args, **kwargs):
         | 
| 74 | 
            +
                    raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                def setdefault(self, *args, **kwargs):
         | 
| 77 | 
            +
                    raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                def pop(self, *args, **kwargs):
         | 
| 80 | 
            +
                    raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                def update(self, *args, **kwargs):
         | 
| 83 | 
            +
                    raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                def __getitem__(self, k):
         | 
| 86 | 
            +
                    if isinstance(k, str):
         | 
| 87 | 
            +
                        inner_dict = dict(self.items())
         | 
| 88 | 
            +
                        return inner_dict[k]
         | 
| 89 | 
            +
                    else:
         | 
| 90 | 
            +
                        return self.to_tuple()[k]
         | 
| 91 | 
            +
             | 
| 92 | 
            +
                def __setattr__(self, name, value):
         | 
| 93 | 
            +
                    if name in self.keys() and value is not None:
         | 
| 94 | 
            +
                        # Don't call self.__setitem__ to avoid recursion errors
         | 
| 95 | 
            +
                        super().__setitem__(name, value)
         | 
| 96 | 
            +
                    super().__setattr__(name, value)
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                def __setitem__(self, key, value):
         | 
| 99 | 
            +
                    # Will raise a KeyError if needed
         | 
| 100 | 
            +
                    super().__setitem__(key, value)
         | 
| 101 | 
            +
                    # Don't call self.__setattr__ to avoid recursion errors
         | 
| 102 | 
            +
                    super().__setattr__(key, value)
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                def to_tuple(self) -> Tuple[Any]:
         | 
| 105 | 
            +
                    """
         | 
| 106 | 
            +
                    Convert self to a tuple containing all the attributes/keys that are not `None`.
         | 
| 107 | 
            +
                    """
         | 
| 108 | 
            +
                    return tuple(self[k] for k in self.keys())
         | 
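
A minimal sketch of how `BaseOutput` is typically subclassed and accessed; `AudioPipelineOutput` and its fields are hypothetical names used only for illustration:

```python
# Illustrative sketch; assumes BaseOutput is importable from the module above.
from dataclasses import dataclass
from typing import Optional

import numpy as np

from diffusers.utils.outputs import BaseOutput  # adjust to this repo's layout if needed


@dataclass
class AudioPipelineOutput(BaseOutput):
    audios: np.ndarray
    sample_rate: Optional[int] = None


out = AudioPipelineOutput(audios=np.zeros((1, 16000)))
print(out["audios"].shape)   # dict-style access
print(out.audios.shape)      # attribute access
print(out[0].shape)          # tuple-style access; `None` fields (sample_rate) are skipped
(audios,) = out.to_tuple()   # explicit conversion before unpacking
```
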
    	
        diffusers/utils/scheduling_utils.py
    ADDED
    
    | @@ -0,0 +1,176 @@ | |
| 1 | 
            +
            # Copyright 2023 The HuggingFace Team. All rights reserved.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 4 | 
            +
            # you may not use this file except in compliance with the License.
         | 
| 5 | 
            +
            # You may obtain a copy of the License at
         | 
| 6 | 
            +
            #
         | 
| 7 | 
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 8 | 
            +
            #
         | 
| 9 | 
            +
            # Unless required by applicable law or agreed to in writing, software
         | 
| 10 | 
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 11 | 
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 12 | 
            +
            # See the License for the specific language governing permissions and
         | 
| 13 | 
            +
            # limitations under the License.
         | 
| 14 | 
            +
            import importlib
         | 
| 15 | 
            +
            import os
         | 
| 16 | 
            +
            from dataclasses import dataclass
         | 
| 17 | 
            +
            from enum import Enum
         | 
| 18 | 
            +
            from typing import Any, Dict, Optional, Union
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            import torch
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            from .outputs import BaseOutput
         | 
| 23 | 
            +
             | 
| 24 | 
            +
             | 
| 25 | 
            +
            SCHEDULER_CONFIG_NAME = "scheduler_config.json"
         | 
| 26 | 
            +
             | 
| 27 | 
            +
             | 
| 28 | 
            +
            # NOTE: We make this type an enum because it simplifies usage in docs and prevents
         | 
| 29 | 
            +
            # circular imports when used for `_compatibles` within the schedulers module.
         | 
| 30 | 
            +
            # When it's used as a type in pipelines, it really is a Union because the actual
         | 
| 31 | 
            +
            # scheduler instance is passed in.
         | 
| 32 | 
            +
            class KarrasDiffusionSchedulers(Enum):
         | 
| 33 | 
            +
                DDIMScheduler = 1
         | 
| 34 | 
            +
                DDPMScheduler = 2
         | 
| 35 | 
            +
                PNDMScheduler = 3
         | 
| 36 | 
            +
                LMSDiscreteScheduler = 4
         | 
| 37 | 
            +
                EulerDiscreteScheduler = 5
         | 
| 38 | 
            +
                HeunDiscreteScheduler = 6
         | 
| 39 | 
            +
                EulerAncestralDiscreteScheduler = 7
         | 
| 40 | 
    DPMSolverMultistepScheduler = 8
    DPMSolverSinglestepScheduler = 9
    KDPM2DiscreteScheduler = 10
    KDPM2AncestralDiscreteScheduler = 11
    DEISMultistepScheduler = 12
    UniPCMultistepScheduler = 13


@dataclass
class SchedulerOutput(BaseOutput):
    """
    Base class for the scheduler's step function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """

    prev_sample: torch.FloatTensor


class SchedulerMixin:
    """
    Mixin containing common functions for the schedulers.

    Class attributes:
        - **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that
          `from_config` can be used from a class different than the one used to save the config (should be overridden
          by parent class).
    """

    config_name = SCHEDULER_CONFIG_NAME
    _compatibles = []
    has_compatibles = True

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Dict[str, Any] = None,
        subfolder: Optional[str] = None,
        return_unused_kwargs=False,
        **kwargs,
    ):
        r"""
        Instantiate a Scheduler class from a pre-defined JSON configuration file inside a directory or Hub repo.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:

                    - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
                      organization name, like `google/ddpm-celebahq-256`.
                    - A path to a *directory* containing the scheduler configurations saved using
                      [`~SchedulerMixin.save_pretrained`], e.g., `./my_model_directory/`.
            subfolder (`str`, *optional*):
                In case the relevant files are located inside a subfolder of the model repo (either remote in
                huggingface.co or downloaded locally), you can specify the folder name here.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                Whether kwargs that are not consumed by the Python class should be returned or not.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info (`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether or not to only look at local files (i.e., do not try to download the model).
            use_auth_token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
                when running `transformers-cli login` (stored in `~/.huggingface`).
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.

        <Tip>

        It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
        models](https://huggingface.co/docs/hub/models-gated#gated-models).

        </Tip>

        <Tip>

        Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
        use this method in a firewalled environment.

        </Tip>

        """
        config, kwargs, commit_hash = cls.load_config(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            subfolder=subfolder,
            return_unused_kwargs=True,
            return_commit_hash=True,
            **kwargs,
        )
        return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)

    def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
        """
        Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the
        [`~SchedulerMixin.from_pretrained`] class method.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the configuration JSON file will be saved (will be created if it does not exist).
        """
        self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)

    @property
    def compatibles(self):
        """
        Returns all schedulers that are compatible with this scheduler.

        Returns:
            `List[SchedulerMixin]`: List of compatible schedulers
        """
        return self._get_compatibles()

    @classmethod
    def _get_compatibles(cls):
        compatible_classes_str = list(set([cls.__name__] + cls._compatibles))
        diffusers_library = importlib.import_module(__name__.split(".")[0])
        compatible_classes = [
            getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
        ]
        return compatible_classes
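
As a quick illustration of how this mixin's API fits together (a minimal sketch, not part of the committed files): assuming the vendored package exposes a concrete `SchedulerMixin` subclass such as `HeunDiscreteScheduler`, and that `./my_model_directory/` contains a saved `scheduler_config.json`, loading, inspecting compatibles, and re-saving looks roughly like this.

# Illustrative sketch only; the module path and directory are assumptions.
from diffusers.scheduling_heun_discrete import HeunDiscreteScheduler

# Instantiate the scheduler from its JSON configuration (local directory or Hub repo id).
scheduler = HeunDiscreteScheduler.from_pretrained("./my_model_directory/")

# List the scheduler classes that could be swapped in via `from_config`.
print([cls.__name__ for cls in scheduler.compatibles])

# Persist the configuration so it can be reloaded with `from_pretrained`.
scheduler.save_pretrained("./my_model_directory/")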
    	
        diffusers/utils/torch_utils.py
    ADDED
    
@@ -0,0 +1,83 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
PyTorch utilities: Utilities related to PyTorch
"""
from typing import List, Optional, Tuple, Union

from . import logging
from .import_utils import is_torch_available, is_torch_version

if is_torch_available():
    import torch

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

try:
    from torch._dynamo import allow_in_graph as maybe_allow_in_graph
except (ImportError, ModuleNotFoundError):

    def maybe_allow_in_graph(cls):
        return cls


def randn_tensor(
    shape: Union[Tuple, List],
    generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
    device: Optional["torch.device"] = None,
    dtype: Optional["torch.dtype"] = None,
    layout: Optional["torch.layout"] = None,
):
    """A helper function for creating random tensors on the desired `device` with the desired `dtype`. When passing
    a list of generators, each batch entry can be seeded individually. If CPU generators are passed, the tensor is
    always created on the CPU first and then moved to `device`.
    """
    # The device on which the tensor is created defaults to the requested device.
    rand_device = device
    batch_size = shape[0]

    layout = layout or torch.strided
    device = device or torch.device("cpu")

    if generator is not None:
        gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
        if gen_device_type != device.type and gen_device_type == "cpu":
            rand_device = "cpu"
            if device != "mps":
                logger.info(
                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
                    f" slightly speed up this function by passing a generator that was created on the {device} device."
                )
        elif gen_device_type != device.type and gen_device_type == "cuda":
            raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")

    if isinstance(generator, list):
        shape = (1,) + shape[1:]
        latents = [
            torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
            for i in range(batch_size)
        ]
        latents = torch.cat(latents, dim=0).to(device)
    else:
        latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)

    return latents


def is_compiled_module(module):
    """Check whether the module was compiled with torch.compile()"""
    if is_torch_version("<", "2.0.0") or not hasattr(torch, "_dynamo"):
        return False
    return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
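
For reference, a small usage sketch (not part of the committed files) of the per-sample seeding behaviour described in the `randn_tensor` docstring: with one CPU generator per batch entry, every sample is drawn reproducibly and the concatenated result is moved to the target device. The shape below is illustrative.

import torch
from diffusers.utils.torch_utils import randn_tensor  # path as vendored in this repo

shape = (4, 8, 32, 2)  # illustrative (batch, channels, height, width)

# One generator per batch entry seeds each sample independently.
generators = [torch.Generator("cpu").manual_seed(i) for i in range(shape[0])]
latents = randn_tensor(shape, generator=generators, device=torch.device("cpu"), dtype=torch.float32)
print(latents.shape)  # torch.Size([4, 8, 32, 2])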
    	
        run_gradio.py
    ADDED
    
@@ -0,0 +1,87 @@
import torch
import gradio as gr
import soundfile as sf
import numpy as np
import random, os

from consistencytta import ConsistencyTTA


def seed_all(seed):
    """ Seed all random number generators. """
    seed = int(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.cuda.random.manual_seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


device = torch.device(
    "cuda:0" if torch.cuda.is_available() else
    "mps" if torch.backends.mps.is_available() else "cpu"
)
sr = 16000

# Build ConsistencyTTA model
consistencytta = ConsistencyTTA().to(device)
consistencytta.eval()
consistencytta.requires_grad_(False)


def generate(prompt: str, seed: str = '', cfg_weight: float = 4.):
    """ Generate audio from a given prompt.
    Args:
        prompt (str): Text prompt to generate audio from.
        seed (str, optional): Random seed. Defaults to '', which means no seed.
        cfg_weight (float, optional): Classifier-free guidance strength. Defaults to 4.
    """
    if seed != '':
        try:
            seed_all(int(seed))
        except ValueError:
            pass

    with torch.no_grad():
        with torch.autocast(
            device_type="cuda", dtype=torch.bfloat16, enabled=torch.cuda.is_available()
        ):
            wav = consistencytta(
                [prompt], num_steps=1, cfg_scale_input=cfg_weight, cfg_scale_post=1., sr=sr
            )
        sf.write("output.wav", wav.T, samplerate=sr, subtype='PCM_16')

    return "output.wav"


# Generate test audio
print("Generating test audio...")
generate("A dog barks as a train passes by.", seed=1)
print("Test audio generated successfully! Starting Gradio interface...")

# Launch Gradio interface
iface = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(
            label="Text", value="Several people cheer and scream and speak as water flows hard."
        ),
        gr.Textbox(label="Random Seed (Optional)", value=''),
        gr.Slider(
            minimum=0., maximum=8., value=3.5, label="Classifier-Free Guidance Strength"
        )],
    outputs="audio",
    title="ConsistencyTTA: Accelerating Diffusion-Based Text-to-Audio " \
          "Generation with Consistency Distillation",
    description="This is the official demo page for <a href='https://consistency-tta.github." \
                "io' target='_blank'>ConsistencyTTA</a>, a model that accelerates " \
                "diffusion-based text-to-audio generation hundreds of times with consistency " \
                "models. <br> Here, the audio is generated within a single non-autoregressive " \
                "forward pass from the CLAP-finetuned ConsistencyTTA checkpoint. <br> Since " \
                "the training dataset does not include speech, the model is not expected to " \
                "generate coherent speech. <br> Have fun!"
)
iface.launch(share=True)
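
For programmatic use without the Gradio UI, the core call inside `generate` can be reproduced directly. The following is a minimal sketch mirroring the app code above, not part of the committed files; it assumes the ConsistencyTTA checkpoint expected by `consistencytta.py` is available locally, and the prompt is illustrative.

import torch
import soundfile as sf
from consistencytta import ConsistencyTTA

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = ConsistencyTTA().to(device)
model.eval()
model.requires_grad_(False)

with torch.no_grad(), torch.autocast(
    device_type="cuda", dtype=torch.bfloat16, enabled=torch.cuda.is_available()
):
    # Single-step (num_steps=1) generation, as in the Gradio callback above.
    wav = model(
        ["Rain falls on a tin roof while thunder rumbles in the distance."],
        num_steps=1, cfg_scale_input=4.0, cfg_scale_post=1.0, sr=16000,
    )
sf.write("example.wav", wav.T, samplerate=16000, subtype="PCM_16")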
    	
        tango_diffusion_light.json
    ADDED
    
@@ -0,0 +1,46 @@
{
  "_class_name": "UNet2DConditionModel",
  "_diffusers_version": "0.10.0.dev0",
  "act_fn": "silu",
  "attention_head_dim": [
    5,
    10,
    20,
    20
  ],
  "block_out_channels": [
    256,
    512,
    1024,
    1024
  ],
  "center_input_sample": false,
  "cross_attention_dim": 1024,
  "down_block_types": [
    "CrossAttnDownBlock2D",
    "CrossAttnDownBlock2D",
    "CrossAttnDownBlock2D",
    "DownBlock2D"
  ],
  "downsample_padding": 1,
  "dual_cross_attention": false,
  "flip_sin_to_cos": true,
  "freq_shift": 0,
  "in_channels": 8,
  "layers_per_block": 2,
  "mid_block_scale_factor": 1,
  "norm_eps": 1e-05,
  "norm_num_groups": 32,
  "num_class_embeds": null,
  "only_cross_attention": false,
  "out_channels": 8,
  "sample_size": [32, 2],
  "up_block_types": [
    "UpBlock2D",
    "CrossAttnUpBlock2D",
    "CrossAttnUpBlock2D",
    "CrossAttnUpBlock2D"
  ],
  "use_linear_projection": true,
  "upcast_attention": true
}
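
This JSON stores constructor arguments for the lightweight TANGO UNet. Below is a minimal sketch, not part of the committed files, of how such a config could be turned into a model; it assumes the vendored package exposes `UNet2DConditionModel` at the module path shown and the standard `from_config` API, and the actual loading performed in `consistencytta.py` may differ.

import json
from diffusers.models.unet_2d_condition import UNet2DConditionModel  # assumed vendored module path

with open("tango_diffusion_light.json") as f:
    config = json.load(f)

# `from_config` consumes the constructor arguments; private keys such as
# "_class_name" and "_diffusers_version" are metadata rather than init kwargs.
unet = UNet2DConditionModel.from_config(config)
print(f"{sum(p.numel() for p in unet.parameters()) / 1e6:.1f}M parameters")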
