Spaces: Running on Zero

Commit 5f0abcc · Parent(s): 4b30dce
Working app locally.

Files changed:
- app.py +31 -9
- concept_attention/binary_segmentation_baselines/__pycache__/raw_cross_attention.cpython-310.pyc +0 -0
- concept_attention/binary_segmentation_baselines/__pycache__/raw_output_space.cpython-310.pyc +0 -0
- concept_attention/concept_attention_pipeline.py +2 -2
- concept_attention/flux/src/flux/util.py +2 -1
- concept_attention/image_generator.py +3 -2
app.py CHANGED

@@ -4,6 +4,9 @@ import io
 import spaces
 import gradio as gr
 from PIL import Image
+import requests
+import numpy as np
+import PIL
 
 from concept_attention import ConceptAttentionFluxPipeline
 
@@ -17,15 +20,28 @@ concept_attention_default_args = {
 }
 IMG_SIZE = 250
 
+def download_image(url):
+    return Image.open(io.BytesIO(requests.get(url).content))
+
 EXAMPLES = [
     [
-        "A …
-        "…
-        "…
+        "A dog by a tree", # prompt
+        download_image("https://github.com/helblazer811/ConceptAttention/blob/master/images/dog_by_tree.png?raw=true"),
+        "tree, dog, grass, background", # words
+        42, # seed
+    ],
+    [
+        "A dragon", # prompt
+        download_image("https://github.com/helblazer811/ConceptAttention/blob/master/images/dragon_image.png?raw=true"),
+        "dragon, sky, rock, cloud", # words
         42, # seed
     ],
-    …
-    …
+    [
+        "A hot air balloon", # prompt
+        download_image("https://github.com/helblazer811/ConceptAttention/blob/master/images/hot_air_balloon.png?raw=true"),
+        "balloon, sky, water, tree", # words
+        42, # seed
+    ]
 ]
 
 pipeline = ConceptAttentionFluxPipeline(model_name="flux-schnell", device="cuda")
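The new `download_image` helper fetches every example image over the network at import time, so a single failed request takes down app startup. For comparison only, a more defensive variant might add a timeout and explicit HTTP error handling (both are assumptions of this sketch, not part of the commit):

    import io
    import requests
    from PIL import Image

    def download_image(url: str, timeout: float = 10.0) -> Image.Image:
        # Fail fast instead of hanging indefinitely on a slow host.
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()  # surface 404/500 responses as exceptions
        return Image.open(io.BytesIO(response.content))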
@@ -40,9 +56,15 @@ def process_inputs(prompt, input_image, word_list, seed):
     concepts = [w.strip() for w in word_list.split(",")]
 
     if input_image is not None:
-        input_image …
-        …
-        …
+        if isinstance(input_image, np.ndarray):
+            input_image = Image.fromarray(input_image)
+            input_image = input_image.convert("RGB")
+            input_image = input_image.resize((1024, 1024))
+        elif isinstance(input_image, PIL.Image.Image):
+            input_image = input_image.convert("RGB")
+            input_image = input_image.resize((1024, 1024))
+
+        print(input_image.size)
 
     pipeline_output = pipeline.encode_image(
         image=input_image,
@@ -128,7 +150,7 @@ with gr.Blocks(
     gr.Examples(examples=EXAMPLES, inputs=[prompt, image_input, words, seed], outputs=[output_image, saliency_display], fn=process_inputs, cache_examples=False)
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch(max_threads=1)
     # share=True,
     # server_name="0.0.0.0",
     # inbrowser=True,
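The two `isinstance` branches in `process_inputs` repeat the same convert/resize calls; they exist because a Gradio image component can hand back either a numpy array or a PIL image depending on its `type` setting. A minimal sketch of folding them into one helper (the name `normalize_input_image` is hypothetical, not from the repo):

    import numpy as np
    import PIL.Image
    from PIL import Image

    def normalize_input_image(image, size=(1024, 1024)):
        # Coerce numpy arrays and PIL images alike to a resized RGB PIL image.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        if not isinstance(image, PIL.Image.Image):
            raise TypeError(f"Unsupported image type: {type(image)!r}")
        return image.convert("RGB").resize(size)

The switch to `demo.launch(max_threads=1)` caps Gradio's worker thread pool at one, effectively serializing requests, presumably to keep concurrent jobs off the single ZeroGPU device.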
concept_attention/binary_segmentation_baselines/__pycache__/raw_cross_attention.cpython-310.pyc CHANGED

Binary files a/concept_attention/binary_segmentation_baselines/__pycache__/raw_cross_attention.cpython-310.pyc and b/concept_attention/binary_segmentation_baselines/__pycache__/raw_cross_attention.cpython-310.pyc differ

concept_attention/binary_segmentation_baselines/__pycache__/raw_output_space.cpython-310.pyc CHANGED

Binary files a/concept_attention/binary_segmentation_baselines/__pycache__/raw_output_space.cpython-310.pyc and b/concept_attention/binary_segmentation_baselines/__pycache__/raw_output_space.cpython-310.pyc differ
concept_attention/concept_attention_pipeline.py CHANGED

@@ -28,7 +28,7 @@ class ConceptAttentionFluxPipeline():
         device="cuda:0"
     ):
         self.model_name = model_name
-        self.offload_model = …
+        self.offload_model = offload_model
         # Load the generator
         self.flux_generator = FluxGenerator(
             model_name=model_name,
@@ -139,7 +139,7 @@ class ConceptAttentionFluxPipeline():
             height=height,
             width=width
         )
-        concept_heatmaps = concept_heatmaps.detach().cpu().numpy()
+        concept_heatmaps = concept_heatmaps.detach().cpu().numpy().squeeze()
 
         # Convert the torch heatmaps to PIL images.
         if return_pil_heatmaps:
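The added `.squeeze()` drops every length-1 axis, so heatmaps returned with a singleton batch dimension, say `(1, num_concepts, H, W)`, come out as `(num_concepts, H, W)` for the PIL conversion that follows. A toy illustration (the shapes are assumptions for the example):

    import numpy as np

    # Suppose the pipeline returns heatmaps with a singleton batch axis.
    heatmaps = np.zeros((1, 4, 64, 64), dtype=np.float32)

    # squeeze() removes all axes of length 1, leaving one 2-D map per concept.
    assert heatmaps.squeeze().shape == (4, 64, 64)

Note that `squeeze()` is indiscriminate: with a single concept, the concept axis would vanish too, so `squeeze(0)` would be the stricter choice if only the batch axis should go.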
concept_attention/flux/src/flux/util.py CHANGED

@@ -136,6 +136,7 @@ class T5Embedder(nn.Module):
         self.hf_module = hf_module
         self.tokenizer = tokenizer
 
+    @torch.no_grad()
     def forward(self, text: list[str]) -> torch.Tensor:
         batch_encoding = self.tokenizer(
             text,
@@ -181,7 +182,7 @@ def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmb
         tokenizer,
         max_length=max_length,
         output_key="last_hidden_state"
-    ).to(device)
+    ).to(device).to(torch.bfloat16)
     # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
     # Load the safe tensors model
     # ckpt_path = hf_hub_download(configs["name"].repo_id, configs["name"].repo_flow)
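Both util.py changes cut the T5 embedder's memory footprint: `@torch.no_grad()` stops autograd from recording the forward pass, and `.to(torch.bfloat16)` halves the weight storage relative to float32. A toy module (hypothetical, not the actual `T5Embedder`) showing the same two mechanisms:

    import torch
    import torch.nn as nn

    class ToyEncoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(8, 8)

        @torch.no_grad()  # inference only: no autograd graph is recorded
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.proj(x)

    encoder = ToyEncoder().to(torch.bfloat16)  # half the weight memory of float32
    out = encoder(torch.ones(2, 8, dtype=torch.bfloat16))
    assert not out.requires_grad  # gradients were never tracked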
concept_attention/image_generator.py CHANGED

@@ -58,8 +58,9 @@ def get_models(
     clip = load_clip(device)
     model = load_flow_model(name, device="cpu" if offload else device, attention_block_class=attention_block_class, dit_class=dit_class)
     ae = load_ae(name, device="cpu" if offload else device)
-    …
-    …
+    nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
+
+    return model, ae, t5, clip, nsfw_classifier
 
 class FluxGenerator():
 
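The new `nsfw_classifier` line calls `pipeline(...)`, which here has to be `transformers.pipeline`; the import is not visible in this hunk, so it presumably exists elsewhere in the file. A standalone sketch of loading the same checkpoint in isolation (the device choice and exact label names are illustrative):

    from PIL import Image
    from transformers import pipeline

    nsfw_classifier = pipeline(
        "image-classification",
        model="Falconsai/nsfw_image_detection",
        device="cpu",  # the commit passes the caller's `device` through instead
    )

    # Returns a list of {"label": ..., "score": ...} dicts for the input image.
    scores = nsfw_classifier(Image.new("RGB", (224, 224)))
    print(scores)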