from dataclasses import dataclass
import torch
import torch.nn as nn
import math
import importlib
import craftsman
import re
from typing import Optional
from craftsman.utils.base import BaseModule
from craftsman.models.denoisers.utils import *
@craftsman.register("pixart-denoiser")
class PixArtDinoDenoiser(BaseModule):
@dataclass
class Config(BaseModule.Config):
pretrained_model_name_or_path: Optional[str] = None
input_channels: int = 32
output_channels: int = 32
n_ctx: int = 512
width: int = 768
layers: int = 28
heads: int = 16
context_dim: int = 1024
n_views: int = 1
context_ln: bool = True
init_scale: float = 0.25
use_checkpoint: bool = False
drop_path: float = 0.
clip_weight: float = 1.0
dino_weight: float = 1.0
cfg: Config
def configure(self) -> None:
super().configure()
# timestep embedding
self.time_embed = TimestepEmbedder(self.cfg.width)
# x embedding
self.x_embed = nn.Linear(self.cfg.input_channels, self.cfg.width, bias=True)
# context embedding
if self.cfg.context_ln:
self.clip_embed = nn.Sequential(
nn.LayerNorm(self.cfg.context_dim),
nn.Linear(self.cfg.context_dim, self.cfg.width),
)
self.dino_embed = nn.Sequential(
nn.LayerNorm(self.cfg.context_dim),
nn.Linear(self.cfg.context_dim, self.cfg.width),
)
else:
self.clip_embed = nn.Linear(self.cfg.context_dim, self.cfg.width)
self.dino_embed = nn.Linear(self.cfg.context_dim, self.cfg.width)
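        # shrink residual-branch initialization with model width: init_scale * sqrt(1 / width)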
init_scale = self.cfg.init_scale * math.sqrt(1.0 / self.cfg.width)
drop_path = [x.item() for x in torch.linspace(0, self.cfg.drop_path, self.cfg.layers)]
self.blocks = nn.ModuleList([
DiTBlock(
width=self.cfg.width,
heads=self.cfg.heads,
init_scale=init_scale,
qkv_bias=self.cfg.drop_path,
use_flash=True,
drop_path=drop_path[i]
)
for i in range(self.cfg.layers)
])
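        # project the timestep embedding to 6 * width values, i.e. the shared shift/scale/gate
        # modulation parameters consumed by the DiT blocks (PixArt-style adaLN-single)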
self.t_block = nn.Sequential(
nn.SiLU(),
nn.Linear(self.cfg.width, 6 * self.cfg.width, bias=True)
)
# final layer
self.final_layer = T2IFinalLayer(self.cfg.width, self.cfg.output_channels)
self.identity_initialize()
if self.cfg.pretrained_model_name_or_path:
print(f"Loading pretrained model from {self.cfg.pretrained_model_name_or_path}")
ckpt = torch.load(self.cfg.pretrained_model_name_or_path, map_location="cpu")['state_dict']
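            # keep only the denoiser weights and strip the 'denoiser_model.' prefix
            # so the keys match this module's state_dict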
self.denoiser_ckpt = {}
for k, v in ckpt.items():
if k.startswith('denoiser_model.'):
self.denoiser_ckpt[k.replace('denoiser_model.', '')] = v
self.load_state_dict(self.denoiser_ckpt, strict=False)
def identity_initialize(self):
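        # zero the output projections of attention, cross-attention and MLP so every
        # residual branch starts as a no-op and each block begins as an identity mapping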
for block in self.blocks:
nn.init.constant_(block.attn.c_proj.weight, 0)
nn.init.constant_(block.attn.c_proj.bias, 0)
nn.init.constant_(block.cross_attn.c_proj.weight, 0)
nn.init.constant_(block.cross_attn.c_proj.bias, 0)
nn.init.constant_(block.mlp.c_proj.weight, 0)
nn.init.constant_(block.mlp.c_proj.bias, 0)
def forward(self,
model_input: torch.FloatTensor,
timestep: torch.LongTensor,
context: torch.FloatTensor):
r"""
Args:
model_input (torch.FloatTensor): [bs, n_data, c]
timestep (torch.LongTensor): [bs,]
context (torch.FloatTensor): [bs, context_tokens, c]
Returns:
sample (torch.FloatTensor): [bs, n_data, c]
"""
B, n_data, _ = model_input.shape
# 1. time
t_emb = self.time_embed(timestep)
# 2. conditions projector
context = context.view(B, self.cfg.n_views, -1, self.cfg.context_dim)
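        # CLIP and DINO tokens are packed along the token axis of each view; split them apart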
clip_feat, dino_feat = context.chunk(2, dim=2)
clip_cond = self.clip_embed(clip_feat.contiguous().view(B, -1, self.cfg.context_dim))
dino_cond = self.dino_embed(dino_feat.contiguous().view(B, -1, self.cfg.context_dim))
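        # 3. weighted fusion of the two visual conditions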
visual_cond = self.cfg.clip_weight * clip_cond + self.cfg.dino_weight * dino_cond
# 4. denoiser
latent = self.x_embed(model_input)
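        # per-timestep modulation parameters, broadcast over the token dimension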
t0 = self.t_block(t_emb).unsqueeze(dim=1)
for block in self.blocks:
latent = auto_grad_checkpoint(block, latent, visual_cond, t0)
latent = self.final_layer(latent, t_emb)
return latent
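
if __name__ == "__main__":
    # Illustrative sketch only (not part of the module): how the `context` tensor passed to
    # forward() is expected to be laid out, and how the CLIP/DINO branches are split and fused.
    # Shapes and token counts below are assumptions, not values taken from the training pipeline.
    B, n_views, n_tokens, context_dim, width = 2, 1, 257, 1024, 768
    clip_embed = nn.Linear(context_dim, width)
    dino_embed = nn.Linear(context_dim, width)

    # CLIP tokens followed by DINO tokens, concatenated per view along the token axis
    context = torch.randn(B, n_views * 2 * n_tokens, context_dim)
    context = context.view(B, n_views, -1, context_dim)
    clip_feat, dino_feat = context.chunk(2, dim=2)  # each [B, n_views, n_tokens, context_dim]

    clip_cond = clip_embed(clip_feat.contiguous().view(B, -1, context_dim))
    dino_cond = dino_embed(dino_feat.contiguous().view(B, -1, context_dim))
    visual_cond = 1.0 * clip_cond + 1.0 * dino_cond  # clip_weight / dino_weight from Config
    print(visual_cond.shape)  # torch.Size([2, 257, 768])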