helblazer811 committed
Commit e6d4da8 · 1 Parent(s): 6391136

revision to the UI.

app.py CHANGED
@@ -10,14 +10,14 @@ import PIL
 
 from concept_attention import ConceptAttentionFluxPipeline
 
-concept_attention_default_args = {
-    "model_name": "flux-schnell",
-    "device": "cuda",
-    "layer_indices": list(range(10, 19)),
-    "timesteps": list(range(4)),
-    "num_samples": 4,
-    "num_inference_steps": 4
-}
+# concept_attention_default_args = {
+#     "model_name": "flux-schnell",
+#     "device": "cuda",
+#     "layer_indices": list(range(10, 19)),
+#     "timesteps": list(range(2, 4)),
+#     "num_samples": 4,
+#     "num_inference_steps": 4
+# }
 IMG_SIZE = 250
 
 def download_image(url):
@@ -47,7 +47,7 @@ EXAMPLES = [
 pipeline = ConceptAttentionFluxPipeline(model_name="flux-schnell", device="cuda")
 
 @spaces.GPU(duration=60)
-def process_inputs(prompt, input_image, word_list, seed):
+def process_inputs(prompt, input_image, word_list, seed, num_samples, layer_start_index, timestep_start_index):
     print("Processing inputs")
     prompt = prompt.strip()
     if not word_list.strip():
@@ -64,8 +64,6 @@ def process_inputs(prompt, input_image, word_list, seed):
         input_image = input_image.convert("RGB")
         input_image = input_image.resize((1024, 1024))
 
-        print(input_image.size)
-
         pipeline_output = pipeline.encode_image(
             image=input_image,
             concepts=concepts,
@@ -73,8 +71,10 @@ def process_inputs(prompt, input_image, word_list, seed):
             width=1024,
             height=1024,
             seed=seed,
-            num_samples=concept_attention_default_args["num_samples"]
+            num_samples=num_samples,
+            layer_indices=list(range(layer_start_index, 19)),
         )
+
     else:
         pipeline_output = pipeline.generate_image(
             prompt=prompt,
@@ -82,8 +82,9 @@ def process_inputs(prompt, input_image, word_list, seed):
             width=1024,
             height=1024,
             seed=seed,
-            timesteps=concept_attention_default_args["timesteps"],
-            num_inference_steps=concept_attention_default_args["num_inference_steps"],
+            timesteps=list(range(timestep_start_index, 4)),
+            num_inference_steps=4,
+            layer_indices=list(range(layer_start_index, 19)),
         )
 
     output_image = pipeline_output.image
@@ -105,32 +106,47 @@ def process_inputs(prompt, input_image, word_list, seed):
         html_elements.append(html)
 
     combined_html = "<div style='display: flex; flex-wrap: wrap; justify-content: center;'>" + "".join(html_elements) + "</div>"
-    return output_image, combined_html
+    return output_image, combined_html, None  # returning None clears the input_image component
 
 
 with gr.Blocks(
     css="""
         .container { max-width: 1200px; margin: 0 auto; padding: 20px; }
         .title { text-align: center; margin-bottom: 10px; }
-        .authors { text-align: center; margin-bottom: 20px; }
-        .affiliations { text-align: center; color: #666; margin-bottom: 40px; }
+        .authors { text-align: center; margin-bottom: 10px; }
+        .affiliations { text-align: center; color: #666; margin-bottom: 10px; }
         .content { display: grid; grid-template-columns: 1fr 1fr; gap: 20px; }
-        .section { border: 2px solid #ddd; border-radius: 10px; padding: 20px; }
+        .section { }
+        .input-image { width: 100%; height: 200px; }
+        .abstract { text-align: center; margin-bottom: 40px; }
     """
 ) as demo:
     with gr.Column(elem_classes="container"):
         gr.Markdown("# ConceptAttention: Diffusion Transformers Learn Highly Interpretable Features", elem_classes="title")
-        gr.Markdown("**Alec Helbling**¹, **Tuna Meral**², **Ben Hoover**¹³, **Pinar Yanardag**², **Duen Horng (Polo) Chau**¹", elem_classes="authors")
-        gr.Markdown("¹Georgia Tech · ²Virginia Tech · ³IBM Research", elem_classes="affiliations")
+        gr.Markdown("### Alec Helbling¹, Tuna Meral², Ben Hoover¹³, Pinar Yanardag², Duen Horng (Polo) Chau¹", elem_classes="authors")
+        gr.Markdown("### ¹Georgia Tech · ²Virginia Tech · ³IBM Research", elem_classes="affiliations")
+        gr.Markdown(
+            """
+            We introduce ConceptAttention, an approach to interpreting the intermediate representations of diffusion transformers.
+            The user simply provides a list of textual concepts, and ConceptAttention produces a set of saliency maps depicting
+            the location and intensity of these concepts in generated images. Check out our paper [here](https://arxiv.org/abs/2502.04320).
+            """,
+            elem_classes="abstract"
+        )
 
         with gr.Row(elem_classes="content"):
             with gr.Column(elem_classes="section"):
                 gr.Markdown("### Input")
                 prompt = gr.Textbox(label="Enter your prompt")
-                words = gr.Textbox(label="Enter words (comma-separated)")
-                seed = gr.Slider(minimum=0, maximum=10000, step=1, label="Seed", value=42)
-                gr.HTML("<div style='text-align: center;'> <h1> Or </h1> </div>")
-                image_input = gr.Image(type="numpy", label="Upload image (optional)")
+                words = gr.Textbox(label="Enter a list of concepts (comma-separated)")
+                # gr.HTML("<div style='text-align: center;'> <h3> Or </h3> </div>")
+                image_input = gr.Image(type="numpy", label="Upload image (optional)", elem_classes="input-image")
+                # Set up advanced options
+                with gr.Accordion("Advanced Settings", open=False):
+                    seed = gr.Slider(minimum=0, maximum=10000, step=1, label="Seed", value=42)
+                    num_samples = gr.Slider(minimum=1, maximum=10, step=1, label="Number of Samples", value=4)
+                    layer_start_index = gr.Slider(minimum=0, maximum=18, step=1, label="Layer Start Index", value=10)
+                    timestep_start_index = gr.Slider(minimum=0, maximum=4, step=1, label="Timestep Start Index", value=2)
 
             with gr.Column(elem_classes="section"):
                 gr.Markdown("### Output")
@@ -144,8 +160,13 @@ with gr.Blocks(
 
         submit_btn.click(
             fn=process_inputs,
-            inputs=[prompt, image_input, words, seed], outputs=[output_image, saliency_display]
+            inputs=[prompt, image_input, words, seed, num_samples, layer_start_index, timestep_start_index], outputs=[output_image, saliency_display, image_input]
         )
+        # .then(
+        #     fn=lambda component: gr.update(value=None),
+        #     inputs=[image_input],
+        #     outputs=[]
+        # )
 
         gr.Examples(examples=EXAMPLES, inputs=[prompt, image_input, words, seed], outputs=[output_image, saliency_display], fn=process_inputs, cache_examples=False)
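Note on the callback wiring above: `image_input` now appears in both `inputs` and `outputs` of `submit_btn.click`, and `process_inputs` returns `None` as its third value, which the commit uses to reset the uploaded image after each run (the commented-out `.then(...)` chain was the alternative approach). A minimal, self-contained sketch of the same pattern, with illustrative component names that are not from this app:

import gradio as gr

def run(prompt, image):
    # Do some work with the inputs, then return None for the image slot.
    result = f"processed: {prompt!r}, got image: {image is not None}"
    # Returning None for a component listed in `outputs` resets it in the UI.
    return result, None

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    image = gr.Image(type="numpy", label="Optional image")
    result = gr.Textbox(label="Result")
    run_btn = gr.Button("Run")
    # `image` is both read as an input and cleared as an output.
    run_btn.click(fn=run, inputs=[prompt, image], outputs=[result, image])

demo.launch()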
 
concept_attention/concept_attention_pipeline.py CHANGED
@@ -5,6 +5,7 @@ from dataclasses import dataclass
 import PIL
 import numpy as np
 import matplotlib.pyplot as plt
+import torch
 
 from concept_attention.binary_segmentation_baselines.raw_cross_attention import RawCrossAttentionBaseline, RawCrossAttentionSegmentationModel
 from concept_attention.binary_segmentation_baselines.raw_output_space import RawOutputSpaceBaseline, RawOutputSpaceSegmentationModel
@@ -49,6 +50,7 @@ class ConceptAttentionFluxPipeline():
             generator=self.flux_generator
         )
 
+    @torch.no_grad()
     def generate_image(
         self,
         prompt: str,
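Note on the `@torch.no_grad()` decorator added here: it disables gradient tracking for the whole `generate_image` call, so activations are not retained for backpropagation during inference, reducing memory use. A small illustration of the effect on a toy module (not from this repo):

import torch

layer = torch.nn.Linear(4, 4)

@torch.no_grad()
def infer(x):
    # No autograd graph is recorded inside the decorated call.
    return layer(x)

out = infer(torch.randn(2, 4))
print(out.requires_grad)  # False: the result is detached from autograd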
concept_attention/flux/src/flux/util.py CHANGED
@@ -169,7 +169,7 @@ def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
     state_dict.update(load_sft(safe_tensor_1, device=str(device)))
     state_dict.update(load_sft(safe_tensor_2, device=str(device)))
     # Load the state dict
-    t5_encoder = T5EncoderModel(config=model_config).to(torch.bfloat16)
+    t5_encoder = T5EncoderModel(config=model_config).to(torch.bfloat16).to(device)
     t5_encoder.load_state_dict(state_dict, strict=False)
 
     # Load the tokenizer
@@ -182,7 +182,7 @@ def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
         tokenizer,
         max_length=max_length,
         output_key="last_hidden_state"
-    ).to(device).to(torch.bfloat16)
+    )
     # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
     # Load the safe tensors model
     # ckpt_path = hf_hub_download(configs["name"].repo_id, configs["name"].repo_flow)
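Note on the device/dtype change: the T5 encoder is now cast to bfloat16 and moved to the target device in one chain at construction time, before `load_state_dict` runs against the checkpoint tensors (which `load_sft` already placed on that device), and the later redundant `.to(device).to(torch.bfloat16)` on the wrapper is dropped. A minimal sketch of the ordering, using a toy module rather than the real T5 encoder:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# `.to` calls apply in sequence: first the dtype cast, then the device move.
encoder = torch.nn.Linear(8, 8).to(torch.bfloat16).to(device)

# Copying weights into a module that is already on its final device/dtype
# keeps the module and the checkpoint tensors on the same device.
state_dict = {k: v.clone() for k, v in encoder.state_dict().items()}
encoder.load_state_dict(state_dict, strict=False)
print(next(encoder.parameters()).dtype, next(encoder.parameters()).device)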