import torch
import einops

from concept_attention.utils import linear_normalization
from concept_attention.image_generator import FluxGenerator


def generate_concept_basis_and_image_queries(
    prompt: str,
    concepts: list[str],
    layer_index: list[int] = [18],
    average_over_time: bool = True,
    model_name="flux-dev",
    num_steps=50,
    seed=42,
    average_after=0,
    target_space="output",
    generator=None,
    normalize_concepts=False,
    device="cuda",
    include_images_in_basis=False,
    offload=True,
    joint_attention_kwargs=None
):
    """
    Given a prompt, generate a basis of concept vectors for particular
    layers of the model, along with the encoded image queries.
    """
    assert target_space in ["output", "value", "cross_attention"], "Invalid target space"
    if generator is None:
        generator = FluxGenerator(
            model_name,
            device,
            offload=offload,
        )

    image = generator.generate_image(
        width=1024,
        height=1024,
        num_steps=num_steps,
        guidance=0.0,
        seed=seed,
        prompt=prompt,
        concepts=concepts,
        joint_attention_kwargs=joint_attention_kwargs,
    )

    concept_vectors = []
    image_vectors = []
    supplemental_vectors = []
    for double_block in generator.model.double_blocks:
        # Stack the cached per-timestep vectors and drop the batch dimension
        if target_space == "output":
            image_vecs = torch.stack(
                double_block.image_output_vectors
            ).squeeze(1)
            concept_vecs = torch.stack(
                double_block.concept_output_vectors
            ).squeeze(1)
            image_supplemental_vecs = image_vecs
        elif target_space == "value":
            image_vecs = torch.stack(
                double_block.image_value_vectors
            ).squeeze(1)
            concept_vecs = torch.stack(
                double_block.concept_value_vectors
            ).squeeze(1)
            image_supplemental_vecs = image_vecs
        elif target_space == "cross_attention":
            image_vecs = torch.stack(
                double_block.image_query_vectors
            ).squeeze(1)
            concept_vecs = torch.stack(
                double_block.concept_key_vectors
            ).squeeze(1)
            image_supplemental_vecs = torch.stack(
                double_block.image_key_vectors
            ).squeeze(1)
        else:
            raise ValueError("Invalid target space")
        # Clear out the cached vectors on this layer
        double_block.clear_cached_vectors()

        # Average over diffusion timesteps, skipping the first `average_after` steps
        if average_over_time:
            image_vecs = image_vecs[average_after:].mean(dim=0)
            concept_vecs = concept_vecs[average_after:].mean(dim=0)
            image_supplemental_vecs = image_supplemental_vecs[average_after:].mean(dim=0)

        # Add this layer's vectors to the per-layer lists
        concept_vectors.append(concept_vecs)
        image_vectors.append(image_vecs)
        supplemental_vectors.append(image_supplemental_vecs)

    # Stack the per-layer results
    concept_vectors = torch.stack(concept_vectors)
    if include_images_in_basis:
        # Append the image vectors to the concept basis along the token axis
        supplemental_vectors = torch.stack(supplemental_vectors)
        concept_vectors = torch.cat([concept_vectors, supplemental_vectors], dim=-2)
    image_vectors = torch.stack(image_vectors)

    if layer_index is not None:
        # Pull out only the requested layers
        concept_vectors = concept_vectors[layer_index]
        image_vectors = image_vectors[layer_index]

    # Apply linear normalization to the concept vectors.
    # NOTE: This is important, since it compensates for not being able to apply a softmax.
    if normalize_concepts:
        concept_vectors = linear_normalization(concept_vectors, dim=-2)

    return image, concept_vectors, image_vectors
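

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the library): one plausible way to
# turn the returned concept basis and image queries into per-concept heatmaps
# is to project each image vector onto each concept vector. The prompt, the
# concept list, and the 64x64 patch grid (assuming a 1024x1024 image with
# 16x16 latent patches) are all assumptions for illustration, not values
# taken from the source.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    image, concept_vectors, image_vectors = generate_concept_basis_and_image_queries(
        prompt="a cat sitting on a windowsill",
        concepts=["cat", "window", "sky"],
        normalize_concepts=True,
    )
    # concept_vectors: (num_layers, num_concepts, dim)
    # image_vectors:   (num_layers, num_patches, dim)
    # Score each image patch against each concept, summing over layers and dim.
    scores = einops.einsum(
        image_vectors, concept_vectors,
        "layers patches dim, layers concepts dim -> concepts patches",
    )
    # Reshape the flat patch scores into one square heatmap per concept
    # (64x64 is an assumed grid size; adjust to the actual patch count).
    heatmaps = einops.rearrange(scores, "concepts (h w) -> concepts h w", h=64, w=64)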