helblazer811 committed
Commit 3a5de53 · Parent: 5f5f2bb

Added cross attention to the UI.

Files changed (2):
  1. app.py +198 -83
  2. concept_attention/concept_attention_pipeline.py +105 -76
app.py CHANGED
@@ -2,6 +2,8 @@ import spaces
 import gradio as gr
 from PIL import Image
 import math
+import io
+import base64

 from concept_attention import ConceptAttentionFluxPipeline

@@ -14,31 +16,53 @@ EXAMPLES = [
         "tree, dog, grass, background", # words
         42, # seed
     ],
-    [
-        "A dragon", # prompt
-        "dragon, sky, rock, cloud", # words
-        42, # seed
-    ],
-    [
-        "A hot air balloon", # prompt
-        "balloon, sky, water, tree", # words
-        42, # seed
-    ]
+    # [
+    #     "A dragon", # prompt
+    #     "dragon, sky, rock, cloud", # words
+    #     42, # seed
+    # ],
+    # [
+    #     "A hot air balloon", # prompt
+    #     "balloon, sky, water, tree", # words
+    #     42, # seed
+    # ]
 ]

-pipeline = ConceptAttentionFluxPipeline(model_name="flux-schnell", device="cuda")
+def update_default_concepts(prompt):
+    default_concepts = {
+        "A dog by a tree": ["dog", "grass", "tree", "background"],
+        "A dragon": ["dragon", "sky", "rock", "cloud"],
+        "A hot air balloon": ["balloon", "sky", "water", "tree"]
+    }
+
+    return gr.update(value=default_concepts.get(prompt, []))
+
+pipeline = ConceptAttentionFluxPipeline(model_name="flux-schnell", device="cuda", offload_model=True)
+
+def convert_pil_to_bytes(img):
+    img = img.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST)
+    buffered = io.BytesIO()
+    img.save(buffered, format="PNG")
+    img_str = base64.b64encode(buffered.getvalue()).decode()
+
+    return img_str

 @spaces.GPU(duration=60)
-def process_inputs(prompt, word_list, seed, layer_start_index, timestep_start_index):
-    print("Processing inputs")
-    assert layer_start_index is not None
-    assert timestep_start_index is not None
+def process_inputs(prompt, concepts, seed, layer_start_index, timestep_start_index):
+    # print("Processing inputs")
+    # assert layer_start_index is not None
+    # assert timestep_start_index is not None
+
+    if not prompt.strip():
+        raise gr.exceptions.InputError("prompt", "Please enter a prompt")

     prompt = prompt.strip()
-    if not word_list.strip():
-        gr.exceptions.InputError("words", "Please enter comma-separated words")

-    concepts = [w.strip() for w in word_list.split(",")]
+    print(concepts)
+    # if not word_list.strip():
+    #     gr.exceptions.InputError("words", "Please enter comma-separated words")
+
+    # concepts = [w.strip() for w in word_list.split(",")]

     if len(concepts) == 0:
         raise gr.exceptions.InputError("words", "Please enter at least 1 concept")
@@ -59,101 +83,192 @@ def process_inputs(prompt, word_list, seed, layer_start_index, timestep_start_in
     )

     output_image = pipeline_output.image
-    concept_heatmaps = pipeline_output.concept_heatmaps
-    concept_heatmaps = [heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST) for heatmap in concept_heatmaps]

-    heatmaps_and_labels = [(concept_heatmaps[concept_index], concepts[concept_index]) for concept_index in range(len(concepts))]
-    all_images_and_labels = [(output_image, "Generated Image")] + heatmaps_and_labels
+    output_space_heatmaps = pipeline_output.concept_heatmaps
+    output_space_heatmaps = [heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST) for heatmap in output_space_heatmaps]
+    output_space_maps_and_labels = [(output_space_heatmaps[concept_index], concepts[concept_index]) for concept_index in range(len(concepts))]

-    num_rows = math.ceil(len(all_images_and_labels) / COLUMNS)
+    cross_attention_heatmaps = pipeline_output.cross_attention_maps
+    cross_attention_heatmaps = [heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST) for heatmap in cross_attention_heatmaps]
+    cross_attention_maps_and_labels = [(cross_attention_heatmaps[concept_index], concepts[concept_index]) for concept_index in range(len(concepts))]

-    print(num_rows)
+    # heatmaps_and_labels = [(concept_heatmaps[concept_index], concepts[concept_index]) for concept_index in range(len(concepts))]
+    # all_images_and_labels = [(output_image, "Generated Image")] + heatmaps_and_labels
+    # num_rows = math.ceil(len(all_images_and_labels) / COLUMNS)

-    return all_images_and_labels, num_rows
+    return output_image, \
+        gr.update(value=output_space_maps_and_labels, columns=len(output_space_maps_and_labels)), \
+        gr.update(value=cross_attention_maps_and_labels, columns=len(cross_attention_maps_and_labels))

 with gr.Blocks(
     css="""
-    .container { max-width: 1200px; margin: 0 auto; padding: 20px; }
-    .title { text-align: center; margin-bottom: 10px; }
-    .authors { text-align: center; margin-bottom: 10px; }
-    .affiliations { text-align: center; color: #666; margin-bottom: 10px; }
-    .abstract { text-align: center; margin-bottom: 40px; }
-    """
+    .container { max-width: 1200px; margin: 0 auto; padding: 20px; }
+    .title { text-align: center; margin-bottom: 10px; }
+    .authors { text-align: center; margin-bottom: 10px; }
+    .affiliations { text-align: center; color: #666; margin-bottom: 10px; }
+    .abstract { text-align: center; margin-bottom: 40px; }
+    .generated-image {
+        display: flex;
+        align-items: center;
+        justify-content: center;
+        height: 100%; /* Ensures full height */
+    }
+    .input {
+        height: 47px;
+    }
+    .input-column {
+        flex-direction: column;
+        gap: 0px;
+    }
+    .input-column-label {}
+    .gallery {
+        # scrollbar-width: thin;
+        # scrollbar-color: #27272A;
+    }
+
+    .run-button-column {
+        width: 100px !important;
+    }
+    """
+    # ,
+    # elem_classes="container"
 ) as demo:
     with gr.Column(elem_classes="container"):
-        gr.Markdown("# ConceptAttention: Diffusion Transformers Learn Highly Interpretable Features", elem_classes="title")
-        gr.Markdown("### Alec Helbling¹, Tuna Meral², Ben Hoover¹³, Pinar Yanardag², Duen Horng (Polo) Chau¹", elem_classes="authors")
-        gr.Markdown("### ¹Georgia Tech · ²Virginia Tech · ³IBM Research", elem_classes="affiliations")
-        gr.Markdown(
-            """
-            We introduce ConceptAttention, an approach to interpreting the intermediate representations of diffusion transformers.
-            The user just gives a list of textual concepts and ConceptAttention will produce a set of saliency maps depicting
-            the location and intensity of these concepts in generated images. Check out our paper: [here](https://arxiv.org/abs/2502.04320).
-            """,
-            elem_classes="abstract"
-        )
+        gr.Markdown("# ConceptAttention: Visualize Any Concepts in Your Generated Images", elem_classes="title")
+        # gr.Markdown("### Alec Helbling¹, Tuna Meral², Ben Hoover¹³, Pinar Yanardag², Duen Horng (Polo) Chau¹", elem_classes="authors")
+        # gr.Markdown("### ¹Georgia Tech · ²Virginia Tech · ³IBM Research", elem_classes="affiliations")
+        gr.Markdown("## Interpret generative models with precise, high-quality heatmaps. Check out our paper [here](https://arxiv.org/abs/2502.04320).", elem_classes="abstract")

-        with gr.Row(scale=1):
-            prompt = gr.Textbox(
-                label="Enter your prompt",
-                placeholder="Enter your prompt",
-                value=EXAMPLES[0][0],
-                scale=4,
-                # show_label=True,
-                container=False
-                # height="80px"
-            )
-            words = gr.Textbox(
-                label="Enter a list of concepts (comma-separated)",
-                placeholder="Enter a list of concepts (comma-separated)",
-                value=EXAMPLES[0][1],
-                scale=4,
-                # show_label=True,
-                container=False
-                # height="80px"
-            )
-            submit_btn = gr.Button(
-                "Run",
-                min_width="100px",
-                scale=1
-            )
+        with gr.Row(scale=1, equal_height=True):
+            with gr.Column(scale=3, elem_classes="input-column"):
+                gr.HTML(
+                    "Write a Prompt",
+                    elem_classes="input-column-label"
+                )
+                prompt = gr.Dropdown(
+                    ["A dog by a tree", "A dragon", "A hot air balloon"],
+                    # label="Prompt",
+                    container=False,
+                    # scale=3,
+                    allow_custom_value=True,
+                    elem_classes="input"
+                )
+
+            with gr.Column(scale=7, elem_classes="input-column"):
+                gr.HTML(
+                    "Select or Write Concepts",
+                    elem_classes="input-column-label"
+                )
+                concepts = gr.Dropdown(
+                    ["dog", "grass", "tree", "dragon", "sky", "rock", "cloud", "balloon", "water", "background"],
+                    value=["dog", "grass", "tree", "background"],
+                    multiselect=True,
+                    label="Concepts",
+                    container=False,
+                    allow_custom_value=True,
+                    # scale=4,
+                    elem_classes="input",
+                    max_choices=5
+                )
+
+            with gr.Column(scale=1, min_width=100, elem_classes="input-column run-button-column"):
+                gr.HTML(
+                    "​",
+                    elem_classes="input-column-label"
+                )
+                submit_btn = gr.Button(
+                    "Run",
+                    # scale=1,
+                    elem_classes="input"
+                )
+        # prompt = gr.Textbox(
+        #     label="Enter your prompt",
+        #     placeholder="Enter your prompt",
+        #     value=EXAMPLES[0][0],
+        #     scale=4,
+        #     # show_label=True,
+        #     container=False
+        #     # height="80px"
+        # )
+        # words = gr.Textbox(
+        #     label="Enter a list of concepts (comma-separated)",
+        #     placeholder="Enter a list of concepts (comma-separated)",
+        #     value=EXAMPLES[0][1],
+        #     scale=4,
+        #     # show_label=True,
+        #     container=False
+        #     # height="80px"
+        # )

        num_rows_state = gr.State(value=1) # Initial number of rows

        # generated_image = gr.Image(label="Generated Image", elem_classes="input-image")
-        gallery = gr.Gallery(
-            label="Generated images",
-            show_label=True,
-            # elem_id="gallery",
-            columns=COLUMNS,
-            rows=1,
-            # object_fit="contain",
-            height="auto",
-            elem_classes="gallery"
-        )
+        # gallery = gr.Gallery(
+        #     label="Generated images",
+        #     show_label=True,
+        #     # elem_id="gallery",
+        #     columns=COLUMNS,
+        #     rows=1,
+        #     # object_fit="contain",
+        #     height="auto",
+        #     elem_classes="gallery"
+        # )
+
+        with gr.Row(elem_classes="gallery", scale=8):
+
+            with gr.Column(scale=1):
+                generated_image = gr.Image(
+                    elem_classes="generated-image",
+                    show_label=False
+                )
+
+            with gr.Column(scale=4):
+                concept_attention_gallery = gr.Gallery(
+                    label="Concept Attention (Ours)",
+                    show_label=True,
+                    # columns=3,
+                    rows=1,
+                    object_fit="contain",
+                    height="200px",
+                    elem_classes="gallery"
+                )
+
+                cross_attention_gallery = gr.Gallery(
+                    label="Cross Attention",
+                    show_label=True,
+                    # columns=3,
+                    rows=1,
+                    object_fit="contain",
+                    height="200px",
+                    elem_classes="gallery"
+                )
+
        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(minimum=0, maximum=10000, step=1, label="Seed", value=42)
            layer_start_index = gr.Slider(minimum=0, maximum=18, step=1, label="Layer Start Index", value=10)
            timestep_start_index = gr.Slider(minimum=0, maximum=4, step=1, label="Timestep Start Index", value=2)

-
        submit_btn.click(
            fn=process_inputs,
-            inputs=[prompt, words, seed, layer_start_index, timestep_start_index],
-            outputs=[gallery, num_rows_state]
+            inputs=[prompt, concepts, seed, layer_start_index, timestep_start_index],
+            outputs=[generated_image, concept_attention_gallery, cross_attention_gallery]
        )

-        gr.Examples(examples=EXAMPLES, inputs=[prompt, words, seed, layer_start_index, timestep_start_index], outputs=[gallery, num_rows_state], fn=process_inputs, cache_examples=False)
-
+        # gr.Examples(examples=EXAMPLES, inputs=[prompt, concepts, seed, layer_start_index, timestep_start_index], outputs=[gallery, num_rows_state], fn=process_inputs, cache_examples=False)
        # num_rows_state.change(
        #     fn=lambda rows: gr.Gallery.update(rows=int(rows)),
        #     inputs=[num_rows_state],
        #     outputs=[gallery]
        # )

-        # Automatically process the first example on launch
-        demo.load(process_inputs, inputs=[prompt, words, seed, layer_start_index, timestep_start_index], outputs=[gallery, num_rows_state])
+        prompt.change(update_default_concepts, inputs=[prompt], outputs=[concepts])

+        # Automatically process the first example on launch
+        demo.load(
+            process_inputs,
+            inputs=[prompt, concepts, seed, layer_start_index, timestep_start_index],
+            outputs=[generated_image, concept_attention_gallery, cross_attention_gallery]
+        )

 if __name__ == "__main__":
     demo.launch(max_threads=1)
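Stripped of the Gradio wiring above, the new process_inputs boils down to one pipeline call followed by resizing both sets of maps. A minimal sketch of that flow is below; the IMG_SIZE value, the generate_image keyword names beyond those visible in this commit, and the example prompt/concepts are assumptions, not part of the commit:

from PIL import Image
from concept_attention import ConceptAttentionFluxPipeline

IMG_SIZE = 256  # assumption: app.py defines this constant elsewhere

pipeline = ConceptAttentionFluxPipeline(model_name="flux-schnell", device="cuda", offload_model=True)

# Assumed keyword names; the commit shows the output fields
# (image, concept_heatmaps, cross_attention_maps) but not the full call signature.
output = pipeline.generate_image(
    prompt="A dog by a tree",
    concepts=["dog", "grass", "tree", "background"],
    seed=42,
)

generated = output.image
# One heatmap per concept: ConceptAttention (output space) vs. raw cross attention.
concept_maps = [m.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST) for m in output.concept_heatmaps]
cross_maps = [m.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST) for m in output.cross_attention_maps]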
concept_attention/concept_attention_pipeline.py CHANGED
@@ -6,6 +6,7 @@ import PIL
6
  import numpy as np
7
  import matplotlib.pyplot as plt
8
  import torch
 
9
 
10
  from concept_attention.binary_segmentation_baselines.raw_cross_attention import RawCrossAttentionBaseline, RawCrossAttentionSegmentationModel
11
  from concept_attention.binary_segmentation_baselines.raw_output_space import RawOutputSpaceBaseline, RawOutputSpaceSegmentationModel
@@ -15,6 +16,7 @@ from concept_attention.image_generator import FluxGenerator
15
  class ConceptAttentionPipelineOutput():
16
  image: PIL.Image.Image | np.ndarray
17
  concept_heatmaps: list[PIL.Image.Image]
 
18
 
19
  class ConceptAttentionFluxPipeline():
20
  """
@@ -36,19 +38,6 @@ class ConceptAttentionFluxPipeline():
36
  offload=offload_model,
37
  device=device
38
  )
39
- # Make a Raw Cross Attention Segmentation Model and Raw Output space segmentation model
40
- self.cross_attention_segmentation_model = RawCrossAttentionSegmentationModel(
41
- generator=self.flux_generator
42
- )
43
- self.output_space_segmentation_model = RawOutputSpaceSegmentationModel(
44
- generator=self.flux_generator
45
- )
46
- self.raw_output_space_generator = RawOutputSpaceBaseline(
47
- generator=self.flux_generator
48
- )
49
- self.raw_cross_attention_generator = RawCrossAttentionBaseline(
50
- generator=self.flux_generator
51
- )
52
 
53
  @torch.no_grad()
54
  def generate_image(
@@ -77,20 +66,50 @@ class ConceptAttentionFluxPipeline():
77
  if timesteps is None:
78
  timesteps = list(range(num_inference_steps))
79
  # Run the raw output space object
80
- concept_heatmaps, image = self.raw_output_space_generator(
81
- prompt,
82
- concepts,
83
- seed=seed,
84
- num_steps=num_inference_steps,
85
- timesteps=timesteps,
86
- layers=layer_indices,
87
- softmax=softmax,
88
- height=width,
89
  width=width,
 
 
 
 
 
90
  guidance=guidance,
91
  )
92
- # Convert to numpy
93
- concept_heatmaps = concept_heatmaps.detach().cpu().numpy()[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  # Convert the torch heatmaps to PIL images.
95
  if return_pil_heatmaps:
96
  # Convert to a matplotlib color scheme
@@ -103,63 +122,73 @@ class ConceptAttentionFluxPipeline():
103
 
104
  concept_heatmaps = [PIL.Image.fromarray(concept_heatmap) for concept_heatmap in colored_heatmaps]
105
 
 
 
 
 
 
 
 
 
 
106
  return ConceptAttentionPipelineOutput(
107
  image=image,
108
- concept_heatmaps=concept_heatmaps
 
109
  )
110
 
111
- def encode_image(
112
- self,
113
- image: PIL.Image.Image,
114
- concepts: list[str],
115
- prompt: str = "", # Optional
116
- width: int = 1024,
117
- height: int = 1024,
118
- return_cross_attention = False,
119
- layer_indices = list(range(15, 19)),
120
- num_samples: int = 1,
121
- device: str = "cuda:0",
122
- return_pil_heatmaps: bool = True,
123
- seed: int = 0,
124
- cmap="plasma"
125
- ) -> ConceptAttentionPipelineOutput:
126
- """
127
- Encode an image with flux, given a list of concepts.
128
- """
129
- assert return_cross_attention is False, "Not supported yet"
130
- assert all([layer_index >= 0 and layer_index < 19 for layer_index in layer_indices]), "Invalid layer index"
131
- assert height == width, "Height and width must be the same for now"
132
- # Run the raw output space object
133
- concept_heatmaps, _ = self.output_space_segmentation_model.segment_individual_image(
134
- image=image,
135
- concepts=concepts,
136
- caption=prompt,
137
- device=device,
138
- softmax=True,
139
- layers=layer_indices,
140
- num_samples=num_samples,
141
- height=height,
142
- width=width
143
- )
144
- concept_heatmaps = concept_heatmaps.detach().cpu().numpy().squeeze()
145
 
146
- # Convert the torch heatmaps to PIL images.
147
- if return_pil_heatmaps:
148
- min_val = concept_heatmaps.min()
149
- max_val = concept_heatmaps.max()
150
- # Convert to a matplotlib color scheme
151
- colored_heatmaps = []
152
- for concept_heatmap in concept_heatmaps:
153
- # concept_heatmap = (concept_heatmap - concept_heatmap.min()) / (concept_heatmap.max() - concept_heatmap.min())
154
- concept_heatmap = (concept_heatmap - min_val) / (max_val - min_val)
155
- colored_heatmap = plt.get_cmap(cmap)(concept_heatmap)
156
- rgb_image = (colored_heatmap[:, :, :3] * 255).astype(np.uint8)
157
- colored_heatmaps.append(rgb_image)
158
 
159
- concept_heatmaps = [PIL.Image.fromarray(concept_heatmap) for concept_heatmap in colored_heatmaps]
160
 
161
- return ConceptAttentionPipelineOutput(
162
- image=image,
163
- concept_heatmaps=concept_heatmaps
164
- )
165
 
 
6
  import numpy as np
7
  import matplotlib.pyplot as plt
8
  import torch
9
+ import einops
10
 
11
  from concept_attention.binary_segmentation_baselines.raw_cross_attention import RawCrossAttentionBaseline, RawCrossAttentionSegmentationModel
12
  from concept_attention.binary_segmentation_baselines.raw_output_space import RawOutputSpaceBaseline, RawOutputSpaceSegmentationModel
 
16
  class ConceptAttentionPipelineOutput():
17
  image: PIL.Image.Image | np.ndarray
18
  concept_heatmaps: list[PIL.Image.Image]
19
+ cross_attention_maps: list[PIL.Image.Image]
20
 
21
  class ConceptAttentionFluxPipeline():
22
  """
 
38
  offload=offload_model,
39
  device=device
40
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  @torch.no_grad()
43
  def generate_image(
 
66
  if timesteps is None:
67
  timesteps = list(range(num_inference_steps))
68
  # Run the raw output space object
69
+ image, cross_attention_maps, concept_heatmaps = self.flux_generator.generate_image(
 
 
 
 
 
 
 
 
70
  width=width,
71
+ height=height,
72
+ prompt=prompt,
73
+ num_steps=num_inference_steps,
74
+ concepts=concepts,
75
+ seed=seed,
76
  guidance=guidance,
77
  )
78
+ # Concept heamaps extraction
79
+ if softmax:
80
+ concept_heatmaps = torch.nn.functional.softmax(concept_heatmaps, dim=-2)
81
+
82
+ concept_heatmaps = concept_heatmaps[:, layer_indices]
83
+ concept_heatmaps = einops.reduce(
84
+ concept_heatmaps,
85
+ "time layers batch concepts patches -> batch concepts patches",
86
+ reduction="mean"
87
+ )
88
+ concept_heatmaps = einops.rearrange(
89
+ concept_heatmaps,
90
+ "batch concepts (h w) -> batch concepts h w",
91
+ h=64,
92
+ w=64
93
+ )
94
+ # Cross attention maps
95
+ if softmax:
96
+ cross_attention_maps = torch.nn.functional.softmax(cross_attention_maps, dim=-2)
97
+
98
+ cross_attention_maps = cross_attention_maps[:, layer_indices]
99
+ cross_attention_maps = einops.reduce(
100
+ cross_attention_maps,
101
+ "time layers batch concepts patches -> batch concepts patches",
102
+ reduction="mean"
103
+ )
104
+ cross_attention_maps = einops.rearrange(
105
+ cross_attention_maps,
106
+ "batch concepts (h w) -> batch concepts h w",
107
+ h=64,
108
+ w=64
109
+ )
110
+
111
+ concept_heatmaps = concept_heatmaps.to(torch.float32).detach().cpu().numpy()[0]
112
+ cross_attention_maps = cross_attention_maps.to(torch.float32).detach().cpu().numpy()[0]
113
  # Convert the torch heatmaps to PIL images.
114
  if return_pil_heatmaps:
115
  # Convert to a matplotlib color scheme
 
122
 
123
  concept_heatmaps = [PIL.Image.fromarray(concept_heatmap) for concept_heatmap in colored_heatmaps]
124
 
125
+ colored_cross_attention_maps = []
126
+ for cross_attention_map in cross_attention_maps:
127
+ cross_attention_map = (cross_attention_map - cross_attention_map.min()) / (cross_attention_map.max() - cross_attention_map.min())
128
+ colored_cross_attention_map = plt.get_cmap(cmap)(cross_attention_map)
129
+ rgb_image = (colored_cross_attention_map[:, :, :3] * 255).astype(np.uint8)
130
+ colored_cross_attention_maps.append(rgb_image)
131
+
132
+ cross_attention_maps = [PIL.Image.fromarray(cross_attention_map) for cross_attention_map in colored_cross_attention_maps]
133
+
134
  return ConceptAttentionPipelineOutput(
135
  image=image,
136
+ concept_heatmaps=concept_heatmaps,
137
+ cross_attention_maps=cross_attention_maps
138
  )
139
 
140
+ # def encode_image(
141
+ # self,
142
+ # image: PIL.Image.Image,
143
+ # concepts: list[str],
144
+ # prompt: str = "", # Optional
145
+ # width: int = 1024,
146
+ # height: int = 1024,
147
+ # return_cross_attention = False,
148
+ # layer_indices = list(range(15, 19)),
149
+ # num_samples: int = 1,
150
+ # device: str = "cuda:0",
151
+ # return_pil_heatmaps: bool = True,
152
+ # seed: int = 0,
153
+ # cmap="plasma"
154
+ # ) -> ConceptAttentionPipelineOutput:
155
+ # """
156
+ # Encode an image with flux, given a list of concepts.
157
+ # """
158
+ # assert return_cross_attention is False, "Not supported yet"
159
+ # assert all([layer_index >= 0 and layer_index < 19 for layer_index in layer_indices]), "Invalid layer index"
160
+ # assert height == width, "Height and width must be the same for now"
161
+ # # Run the raw output space object
162
+ # concept_heatmaps, _ = self.output_space_segmentation_model.segment_individual_image(
163
+ # image=image,
164
+ # concepts=concepts,
165
+ # caption=prompt,
166
+ # device=device,
167
+ # softmax=True,
168
+ # layers=layer_indices,
169
+ # num_samples=num_samples,
170
+ # height=height,
171
+ # width=width
172
+ # )
173
+ # concept_heatmaps = concept_heatmaps.detach().cpu().numpy().squeeze()
174
 
175
+ # # Convert the torch heatmaps to PIL images.
176
+ # if return_pil_heatmaps:
177
+ # min_val = concept_heatmaps.min()
178
+ # max_val = concept_heatmaps.max()
179
+ # # Convert to a matplotlib color scheme
180
+ # colored_heatmaps = []
181
+ # for concept_heatmap in concept_heatmaps:
182
+ # # concept_heatmap = (concept_heatmap - concept_heatmap.min()) / (concept_heatmap.max() - concept_heatmap.min())
183
+ # concept_heatmap = (concept_heatmap - min_val) / (max_val - min_val)
184
+ # colored_heatmap = plt.get_cmap(cmap)(concept_heatmap)
185
+ # rgb_image = (colored_heatmap[:, :, :3] * 255).astype(np.uint8)
186
+ # colored_heatmaps.append(rgb_image)
187
 
188
+ # concept_heatmaps = [PIL.Image.fromarray(concept_heatmap) for concept_heatmap in colored_heatmaps]
189
 
190
+ # return ConceptAttentionPipelineOutput(
191
+ # image=image,
192
+ # concept_heatmaps=concept_heatmaps
193
+ # )
194
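Inside generate_image, the concept heatmaps and the raw cross-attention maps now go through the same reduction: an optional softmax over the concept axis, selection of layer_indices, a mean over timesteps and layers, and a reshape to a 64×64 patch grid. A standalone sketch of that aggregation follows; the tensor sizes and the layer_indices value are assumptions, only the einops patterns come from the code above:

import torch
import einops

# Assumed raw shape: (time, layers, batch, concepts, patches), with 64*64 patches.
maps = torch.randn(4, 19, 1, 4, 64 * 64)

maps = torch.nn.functional.softmax(maps, dim=-2)  # optional: normalize across concepts for each patch
layer_indices = list(range(10, 19))               # assumption: keep only the later layers
maps = maps[:, layer_indices]                     # select layers
maps = einops.reduce(
    maps,
    "time layers batch concepts patches -> batch concepts patches",
    reduction="mean"                              # average over timesteps and layers
)
maps = einops.rearrange(maps, "batch concepts (h w) -> batch concepts h w", h=64, w=64)
print(maps.shape)  # torch.Size([1, 4, 64, 64])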