Spaces: Running on Zero
Commit e6d4da8 · 1 Parent: 6391136
revision to the UI.
app.py CHANGED
@@ -10,14 +10,14 @@ import PIL
 
 from concept_attention import ConceptAttentionFluxPipeline
 
-concept_attention_default_args = {
-    "model_name": "flux-schnell",
-    "device": "cuda",
-    "layer_indices": list(range(10, 19)),
-    "timesteps": list(range(2, 4)),
-    "num_samples": 4,
-    "num_inference_steps": 4
-}
+# concept_attention_default_args = {
+#     "model_name": "flux-schnell",
+#     "device": "cuda",
+#     "layer_indices": list(range(10, 19)),
+#     "timesteps": list(range(2, 4)),
+#     "num_samples": 4,
+#     "num_inference_steps": 4
+# }
 IMG_SIZE = 250
 
 def download_image(url):
@@ -47,7 +47,7 @@ EXAMPLES = [
 pipeline = ConceptAttentionFluxPipeline(model_name="flux-schnell", device="cuda")
 
 @spaces.GPU(duration=60)
-def process_inputs(prompt, input_image, word_list, seed):
+def process_inputs(prompt, input_image, word_list, seed, num_samples, layer_start_index, timestep_start_index):
     print("Processing inputs")
     prompt = prompt.strip()
     if not word_list.strip():
@@ -64,8 +64,6 @@ def process_inputs(prompt, input_image, word_list, seed):
         input_image = input_image.convert("RGB")
         input_image = input_image.resize((1024, 1024))
 
-        print(input_image.size)
-
         pipeline_output = pipeline.encode_image(
             image=input_image,
             concepts=concepts,
@@ -73,8 +71,10 @@ def process_inputs(prompt, input_image, word_list, seed):
             width=1024,
             height=1024,
             seed=seed,
-            num_samples=
+            num_samples=num_samples,
+            layer_indices=list(range(layer_start_index, 19)),
         )
+
     else:
         pipeline_output = pipeline.generate_image(
             prompt=prompt,
@@ -82,8 +82,9 @@ def process_inputs(prompt, input_image, word_list, seed):
             width=1024,
             height=1024,
             seed=seed,
-            timesteps=
-            num_inference_steps=
+            timesteps=list(range(timestep_start_index, 4)),
+            num_inference_steps=4,
+            layer_indices=list(range(layer_start_index, 19)),
         )
 
     output_image = pipeline_output.image
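For orientation, the hunks above replace the hard-coded defaults with per-request arguments. Below is a minimal sketch of the call pattern they converge on; any keyword not visible in this diff (for example whether generate_image also accepts concepts) is an assumption, and the helper name is illustrative rather than taken from the repository.

# Sketch only -- mirrors the keyword arguments introduced in this commit.
# `pipeline` is a ConceptAttentionFluxPipeline, `concepts` a list of strings.
def run_concept_attention(pipeline, prompt, concepts, input_image, seed,
                          num_samples, layer_start_index, timestep_start_index):
    layer_indices = list(range(layer_start_index, 19))  # transformer layers to read concept maps from
    if input_image is not None:
        # Explain an uploaded image: no sampling schedule is needed.
        return pipeline.encode_image(
            image=input_image,
            concepts=concepts,
            width=1024,
            height=1024,
            seed=seed,
            num_samples=num_samples,
            layer_indices=layer_indices,
        )
    # Generate an image and collect concept maps while sampling.
    return pipeline.generate_image(
        prompt=prompt,
        concepts=concepts,  # assumed keyword; not shown in this hunk
        width=1024,
        height=1024,
        seed=seed,
        timesteps=list(range(timestep_start_index, 4)),  # flux-schnell uses 4 sampling steps
        num_inference_steps=4,
        layer_indices=layer_indices,
    )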
@@ -105,32 +106,47 @@ def process_inputs(prompt, input_image, word_list, seed):
         html_elements.append(html)
 
     combined_html = "<div style='display: flex; flex-wrap: wrap; justify-content: center;'>" + "".join(html_elements) + "</div>"
-    return output_image, combined_html
+    return output_image, combined_html, None # None fills input_image with None
 
 
 with gr.Blocks(
     css="""
     .container { max-width: 1200px; margin: 0 auto; padding: 20px; }
     .title { text-align: center; margin-bottom: 10px; }
-    .authors { text-align: center; margin-bottom:
-    .affiliations { text-align: center; color: #666; margin-bottom:
+    .authors { text-align: center; margin-bottom: 10px; }
+    .affiliations { text-align: center; color: #666; margin-bottom: 10px; }
     .content { display: grid; grid-template-columns: 1fr 1fr; gap: 20px; }
-    .section {
+    .section { }
+    .input-image { width: 100%; height: 200px; }
+    .abstract { text-align: center; margin-bottom: 40px; }
     """
 ) as demo:
     with gr.Column(elem_classes="container"):
         gr.Markdown("# ConceptAttention: Diffusion Transformers Learn Highly Interpretable Features", elem_classes="title")
-        gr.Markdown("
-        gr.Markdown("¹Georgia Tech · ²Virginia Tech · ³IBM Research", elem_classes="affiliations")
+        gr.Markdown("### Alec Helbling¹, Tuna Meral², Ben Hoover¹³, Pinar Yanardag², Duen Horng (Polo) Chau¹", elem_classes="authors")
+        gr.Markdown("### ¹Georgia Tech · ²Virginia Tech · ³IBM Research", elem_classes="affiliations")
+        gr.Markdown(
+            """
+            We introduce ConceptAttention, an approach to interpreting the intermediate representations of diffusion transformers.
+            The user just gives a list of textual concepts and ConceptAttention will produce a set of saliency maps depicting
+            the location and intensity of these concepts in generated images. Check out our paper: [here](https://arxiv.org/abs/2502.04320).
+            """,
+            elem_classes="abstract"
+        )
 
         with gr.Row(elem_classes="content"):
             with gr.Column(elem_classes="section"):
                 gr.Markdown("### Input")
                 prompt = gr.Textbox(label="Enter your prompt")
-                words = gr.Textbox(label="Enter
-
-                gr.
-
+                words = gr.Textbox(label="Enter a list of concepts (comma-separated)")
+                # gr.HTML("<div style='text-align: center;'> <h3> Or </h3> </div>")
+                image_input = gr.Image(type="numpy", label="Upload image (optional)", elem_classes="input-image")
+                # Set up advanced options
+                with gr.Accordion("Advanced Settings", open=False):
+                    seed = gr.Slider(minimum=0, maximum=10000, step=1, label="Seed", value=42)
+                    num_samples = gr.Slider(minimum=1, maximum=10, step=1, label="Number of Samples", value=4)
+                    layer_start_index = gr.Slider(minimum=0, maximum=18, step=1, label="Layer Start Index", value=10)
+                    timestep_start_index = gr.Slider(minimum=0, maximum=4, step=1, label="Timestep Start Index", value=2)
 
             with gr.Column(elem_classes="section"):
                 gr.Markdown("### Output")
@@ -144,8 +160,13 @@ with gr.Blocks(
 
         submit_btn.click(
            fn=process_inputs,
-            inputs=[prompt, image_input, words, seed], outputs=[output_image, saliency_display]
+            inputs=[prompt, image_input, words, seed, num_samples, layer_start_index, timestep_start_index], outputs=[output_image, saliency_display, image_input]
         )
+        # .then(
+        #     fn=lambda component: gr.update(value=None),
+        #     inputs=[image_input],
+        #     outputs=[]
+        # )
 
         gr.Examples(examples=EXAMPLES, inputs=[prompt, image_input, words, seed], outputs=[output_image, saliency_display], fn=process_inputs, cache_examples=False)
 
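The UI hunks above follow a standard Gradio Blocks pattern: sliders inside a collapsed Accordion feed extra arguments to the click handler, and listing the image component among the outputs lets the handler clear it by returning None. Below is a minimal, self-contained sketch of that wiring; the placeholder handler body and layout are illustrative and not the app's exact code.

import gradio as gr

def process(prompt, image, words, seed, num_samples, layer_start_index, timestep_start_index):
    # Placeholder body -- the real app calls ConceptAttentionFluxPipeline here.
    html = f"<p>{prompt} | {words} | seed={seed}, samples={num_samples}</p>"
    return image, html, None  # third value resets the image input component

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Enter your prompt")
    words = gr.Textbox(label="Enter a list of concepts (comma-separated)")
    image_input = gr.Image(type="numpy", label="Upload image (optional)")
    with gr.Accordion("Advanced Settings", open=False):
        seed = gr.Slider(minimum=0, maximum=10000, step=1, value=42, label="Seed")
        num_samples = gr.Slider(minimum=1, maximum=10, step=1, value=4, label="Number of Samples")
        layer_start_index = gr.Slider(minimum=0, maximum=18, step=1, value=10, label="Layer Start Index")
        timestep_start_index = gr.Slider(minimum=0, maximum=4, step=1, value=2, label="Timestep Start Index")
    output_image = gr.Image(label="Output")
    saliency_display = gr.HTML()
    submit_btn = gr.Button("Submit")
    # Putting image_input among the outputs replaces the commented-out .then(...) reset step.
    submit_btn.click(
        fn=process,
        inputs=[prompt, image_input, words, seed, num_samples, layer_start_index, timestep_start_index],
        outputs=[output_image, saliency_display, image_input],
    )

demo.launch()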
concept_attention/concept_attention_pipeline.py CHANGED
@@ -5,6 +5,7 @@ from dataclasses import dataclass
 import PIL
 import numpy as np
 import matplotlib.pyplot as plt
+import torch
 
 from concept_attention.binary_segmentation_baselines.raw_cross_attention import RawCrossAttentionBaseline, RawCrossAttentionSegmentationModel
 from concept_attention.binary_segmentation_baselines.raw_output_space import RawOutputSpaceBaseline, RawOutputSpaceSegmentationModel
@@ -49,6 +50,7 @@ class ConceptAttentionFluxPipeline():
             generator=self.flux_generator
         )
 
+    @torch.no_grad()
     def generate_image(
         self,
         prompt: str,
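The functional change here is wrapping generate_image in torch.no_grad(), so no autograd graph is built during inference and activation memory is not retained. A minimal illustration of the decorator's effect, using a toy module that is not from this repository:

import torch

class TinyPipeline:
    def __init__(self):
        self.model = torch.nn.Linear(8, 8)

    @torch.no_grad()
    def generate(self, x):
        # Inside this method no operations are recorded for backprop,
        # which is why inference entry points are commonly decorated this way.
        return self.model(x)

out = TinyPipeline().generate(torch.randn(2, 8))
print(out.requires_grad)  # False -- no graph was built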
concept_attention/flux/src/flux/util.py CHANGED
@@ -169,7 +169,7 @@ def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmb
     state_dict.update(load_sft(safe_tensor_1, device=str(device)))
     state_dict.update(load_sft(safe_tensor_2, device=str(device)))
     # Load the state dict
-    t5_encoder = T5EncoderModel(config=model_config).to(torch.bfloat16)
+    t5_encoder = T5EncoderModel(config=model_config).to(torch.bfloat16).to(device)
     t5_encoder.load_state_dict(state_dict, strict=False)
 
     # Load the tokenizer
@@ -182,7 +182,7 @@ def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmb
         tokenizer,
         max_length=max_length,
         output_key="last_hidden_state"
-    )
+    )
     # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
     # Load the safe tensors model
     # ckpt_path = hf_hub_download(configs["name"].repo_id, configs["name"].repo_flow)
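The util.py change appends .to(device) when constructing the T5 encoder. Without it, load_state_dict copies the device-resident weights from load_sft into parameters that still live on the CPU, so the encoder itself stays on the CPU. A small stand-in sketch of the construct, cast, move, then load pattern, using a Linear layer rather than T5EncoderModel:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the module, cast to bfloat16, move to the target device, then load weights.
model = torch.nn.Linear(4, 4).to(torch.bfloat16).to(device)

# A state dict whose tensors already sit on the target device, analogous to load_sft output.
donor = torch.nn.Linear(4, 4).to(torch.bfloat16).to(device)
model.load_state_dict(donor.state_dict(), strict=False)

print(next(model.parameters()).device)  # the target device, not cpu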