# pipeline.py
#
# Subclass the Stable Diffusion pipeline to replace CLIP-L with T5.

import torch
import torch.nn as nn

from transformers import AutoTokenizer, AutoModel
from transformers import T5Tokenizer, T5EncoderModel
from transformers import CLIPTextModelWithProjection
from transformers.models.clip.modeling_clip import CLIPTextModel, CLIPTextConfig

from diffusers import StableDiffusionPipeline
from diffusers.utils import logging

logger = logging.get_logger(__name__)

T5_NAME = "mcmonkey/google_t5-v1_1-xxl_encoderonly"


class StableDiffusionT5Pipeline(StableDiffusionPipeline):
    ###################################################################################
    # Create a minimal CLIPTextModelWithProjection with minimal layers and vocab size,
    # plus the projection layer we need.
    # We only actually care about the .text_projection,
    # but this is the only (easy) way to save it using pipe.save_pretrained().
    def create_clipholder(self):
        config = CLIPTextConfig(
            vocab_size=1,          # minimal vocab size
            hidden_size=4096,      # input hidden size (T5-XXL embedding width)
            projection_dim=768,    # output dimension of the projection
            num_hidden_layers=0,   # no transformer layers
            num_attention_heads=1,
            intermediate_size=4,   # minimal intermediate size
        )
        model = CLIPTextModelWithProjection(config)
        # This should automatically have generated the following:
        # model.text_projection = nn.Linear(4096, 768)
        return model

    ###################################################
    # Override this so we can auto-init text_encoder.
    # These are the original values?
    # _optional_components = ["safety_checker", "feature_extractor", "image_encoder", "text_encoder"]
    # t5_projection is not really optional, but it has to be listed here to stop internal whining.
    _optional_components = StableDiffusionPipeline._optional_components + ["text_encoder", "t5_projection"]

    def __init__(
        self,
        vae,
        text_encoder,
        tokenizer,
        unet,
        scheduler,
        safety_checker=None,
        feature_extractor=None,
        image_encoder=None,
        requires_safety_checker=True,
        t5_projection=None,
    ):
        self.tokenizer = (
            tokenizer
            if tokenizer is not None
            else T5Tokenizer.from_pretrained(T5_NAME)
        )

        if text_encoder is None:
            self.text_encoder = T5EncoderModel.from_pretrained(T5_NAME, torch_dtype=unet.dtype)
        else:
            self.text_encoder = text_encoder

        super().__init__(
            vae=vae,
            tokenizer=self.tokenizer,
            text_encoder=self.text_encoder,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            image_encoder=image_encoder,
            requires_safety_checker=requires_safety_checker,
        )

        if t5_projection is None:
            print("WARNING: no CLIPTextModelWithProjection found. This may indicate an error.")
            answer = input("Should I auto-generate one? Type 'Yes' to proceed: ")
            if answer != "Yes":
                exit(1)
            self.t5_projection = self.create_clipholder().to(vae.device, dtype=vae.dtype)
        else:
            if isinstance(t5_projection, CLIPTextModelWithProjection):
                self.t5_projection = t5_projection
            else:
                raise TypeError("Error: expected t5_projection to be of type CLIPTextModelWithProjection")

        checkval = getattr(self.t5_projection.config, "scaling_factor", None)
        if not checkval:
            # scaling_factor = 0.013    # my roughly calculated factor, for norms ~ 1.0
            # scaling_factor = 0.13025  # this would be the VAE scaling factor
            # scaling_factor = 0.035    # a commonly used factor for T5
            # ... but to make the output std dev similar to CLIP, use scaling_factor = 1.8
            # (see check-cache-stdd-t5.py)
            scaling_factor = 1.8
            print("INFO: Pipeline setting empty t5 scaling factor to", scaling_factor)
            self.t5_projection.config.scaling_factor = scaling_factor

        # Ensure everything is properly registered for .to("cuda"),
        # and also for saving the model.
        self.register_modules(t5_projection=self.t5_projection)

    # Returns the raw 4096-dim T5 embeddings, not the ones projected down to 768.
    def encode_prompt_t5(
        self,
        prompt,
        negative_prompt,  # can be None
        device,
    ):
        def _tok(text):
            out = self.tokenizer(
                text,
                return_tensors="pt",
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
            )
            return out.input_ids.to(device=device, dtype=torch.long), out.attention_mask.to(device)

        pos_ids, pos_mask = _tok(prompt)
        pos_hidden = self.text_encoder(pos_ids, attention_mask=pos_mask).last_hidden_state

        neg_prompt = negative_prompt if negative_prompt is not None else ""
        neg_ids, neg_mask = _tok(neg_prompt)
        neg_hidden = self.text_encoder(neg_ids, attention_mask=neg_mask).last_hidden_state

        return pos_hidden, neg_hidden

    def encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        **kwargs,
    ):
        scaling_factor = self.t5_projection.config.scaling_factor

        pos_hidden, neg_hidden = self.encode_prompt_t5(prompt, negative_prompt, device)

        # Project the 4096-dim T5 hidden states down to the 768-dim space the UNet expects,
        # then rescale so the embedding statistics roughly match CLIP's.
        pos_embeds = self.t5_projection.text_projection(pos_hidden)
        pos_embeds = pos_embeds * scaling_factor

        if do_classifier_free_guidance:
            neg_embeds = self.t5_projection.text_projection(neg_hidden)
            neg_embeds = neg_embeds * scaling_factor

            pos_embeds = pos_embeds.repeat_interleave(num_images_per_prompt, dim=0)
            neg_embeds = neg_embeds.repeat_interleave(num_images_per_prompt, dim=0)
            return [neg_embeds, pos_embeds]
        else:
            pos_embeds = pos_embeds.repeat_interleave(num_images_per_prompt, dim=0)
            return pos_embeds
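
################################################################################
# Usage sketch (illustrative, not part of the original pipeline code).
# This shows one way the class above could be loaded and exercised. The
# checkpoint path "path/to/sd-t5-checkpoint" is a placeholder: it assumes a
# directory previously written by StableDiffusionT5Pipeline.save_pretrained(),
# with a trained t5_projection inside. It also assumes an installed diffusers
# version compatible with the encode_prompt override above.
if __name__ == "__main__":
    pipe = StableDiffusionT5Pipeline.from_pretrained(
        "path/to/sd-t5-checkpoint",  # hypothetical path; substitute your own
        torch_dtype=torch.float16,
    ).to("cuda")

    # Sanity-check the T5 -> 768-dim projection path without running the UNet.
    # With do_classifier_free_guidance=True, encode_prompt returns
    # [negative_embeds, positive_embeds].
    neg_embeds, pos_embeds = pipe.encode_prompt(
        "a photograph of an astronaut riding a horse",
        device="cuda",
        num_images_per_prompt=1,
        do_classifier_free_guidance=True,
    )
    print("positive embeds:", pos_embeds.shape, "negative embeds:", neg_embeds.shape)

    # Full generation goes through the normal StableDiffusionPipeline call path.
    image = pipe("a photograph of an astronaut riding a horse").images[0]
    image.save("out.png")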