helblazer811 committed
Commit 5f0abcc · 1 parent: 4b30dce

Working app locally.
app.py CHANGED
@@ -4,6 +4,9 @@ import io
 import spaces
 import gradio as gr
 from PIL import Image
+import requests
+import numpy as np
+import PIL
 
 from concept_attention import ConceptAttentionFluxPipeline
 
@@ -17,15 +20,28 @@ concept_attention_default_args = {
 }
 IMG_SIZE = 250
 
+def download_image(url):
+    return Image.open(io.BytesIO(requests.get(url).content))
+
 EXAMPLES = [
     [
-        "A fluffy cat sitting on a windowsill",  # prompt
-        "cat.jpg",  # image
-        "fur, whiskers, eyes",  # words
+        "A dog by a tree",  # prompt
+        download_image("https://github.com/helblazer811/ConceptAttention/blob/master/images/dog_by_tree.png?raw=true"),
+        "tree, dog, grass, background",  # words
+        42,  # seed
+    ],
+    [
+        "A dragon",  # prompt
+        download_image("https://github.com/helblazer811/ConceptAttention/blob/master/images/dragon_image.png?raw=true"),
+        "dragon, sky, rock, cloud",  # words
         42,  # seed
     ],
-    # ["Mountain landscape with lake", "cat.jpg", "sky, trees, water", 123],
-    # ["Portrait of a young woman", "monkey.png", "face, hair, eyes", 456],
+    [
+        "A hot air balloon",  # prompt
+        download_image("https://github.com/helblazer811/ConceptAttention/blob/master/images/hot_air_balloon.png?raw=true"),
+        "balloon, sky, water, tree",  # words
+        42,  # seed
+    ]
 ]
 
 pipeline = ConceptAttentionFluxPipeline(model_name="flux-schnell", device="cuda")
@@ -40,9 +56,15 @@ def process_inputs(prompt, input_image, word_list, seed):
     concepts = [w.strip() for w in word_list.split(",")]
 
     if input_image is not None:
-        input_image = Image.fromarray(input_image)
-        input_image = input_image.convert("RGB")
-        input_image = input_image.resize((1024, 1024))
+        if isinstance(input_image, np.ndarray):
+            input_image = Image.fromarray(input_image)
+            input_image = input_image.convert("RGB")
+            input_image = input_image.resize((1024, 1024))
+        elif isinstance(input_image, PIL.Image.Image):
+            input_image = input_image.convert("RGB")
+            input_image = input_image.resize((1024, 1024))
+
+        print(input_image.size)
 
     pipeline_output = pipeline.encode_image(
         image=input_image,
@@ -128,7 +150,7 @@ with gr.Blocks(
     gr.Examples(examples=EXAMPLES, inputs=[prompt, image_input, words, seed], outputs=[output_image, saliency_display], fn=process_inputs, cache_examples=False)
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch(max_threads=1)
     # share=True,
     # server_name="0.0.0.0",
     # inbrowser=True,
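
A note on the `process_inputs` change: Gradio can hand the callback either a NumPy array or a PIL image depending on how the `gr.Image` component is configured, and with `download_image(...)` baked into EXAMPLES the example rows now arrive as PIL images. A minimal sketch of the same normalization in isolation (the helper name `normalize_input_image` is mine, not from this commit):

```python
import numpy as np
import PIL.Image
from PIL import Image


def normalize_input_image(input_image, size=(1024, 1024)):
    """Coerce a Gradio image input (ndarray or PIL) into a fixed-size RGB PIL image."""
    if isinstance(input_image, np.ndarray):
        # gr.Image(type="numpy") delivers an HxWxC uint8 array
        input_image = Image.fromarray(input_image)
    elif not isinstance(input_image, PIL.Image.Image):
        raise TypeError(f"unsupported image type: {type(input_image)}")
    # Both branches in the commit end with the same convert/resize steps
    return input_image.convert("RGB").resize(size)
```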
concept_attention/binary_segmentation_baselines/__pycache__/raw_cross_attention.cpython-310.pyc CHANGED
Binary files a/concept_attention/binary_segmentation_baselines/__pycache__/raw_cross_attention.cpython-310.pyc and b/concept_attention/binary_segmentation_baselines/__pycache__/raw_cross_attention.cpython-310.pyc differ
 
concept_attention/binary_segmentation_baselines/__pycache__/raw_output_space.cpython-310.pyc CHANGED
Binary files a/concept_attention/binary_segmentation_baselines/__pycache__/raw_output_space.cpython-310.pyc and b/concept_attention/binary_segmentation_baselines/__pycache__/raw_output_space.cpython-310.pyc differ
 
concept_attention/concept_attention_pipeline.py CHANGED
@@ -28,7 +28,7 @@ class ConceptAttentionFluxPipeline():
         device="cuda:0"
     ):
         self.model_name = model_name
-        self.offload_model = False
+        self.offload_model = offload_model
         # Load the generator
         self.flux_generator = FluxGenerator(
             model_name=model_name,
@@ -139,7 +139,7 @@ class ConceptAttentionFluxPipeline():
             height=height,
             width=width
         )
-        concept_heatmaps = concept_heatmaps.detach().cpu().numpy()
+        concept_heatmaps = concept_heatmaps.detach().cpu().numpy().squeeze()
 
         # Convert the torch heatmaps to PIL images.
         if return_pil_heatmaps:
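
The added `.squeeze()` drops singleton axes (typically a batch axis of 1) so that the downstream PIL conversion sees one 2-D map per concept. A self-contained illustration of the shape change; the shapes are illustrative, not taken from the pipeline:

```python
import torch

# e.g. (batch=1, concepts=4, H=64, W=64) -> (4, 64, 64) after squeeze
concept_heatmaps = torch.rand(1, 4, 64, 64)
squeezed = concept_heatmaps.detach().cpu().numpy().squeeze()
print(squeezed.shape)  # (4, 64, 64)
```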
concept_attention/flux/src/flux/util.py CHANGED
@@ -136,6 +136,7 @@ class T5Embedder(nn.Module):
         self.hf_module = hf_module
         self.tokenizer = tokenizer
 
+    @torch.no_grad()
     def forward(self, text: list[str]) -> torch.Tensor:
         batch_encoding = self.tokenizer(
             text,
@@ -181,7 +182,7 @@ def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
         tokenizer,
         max_length=max_length,
         output_key="last_hidden_state"
-    ).to(device)
+    ).to(device).to(torch.bfloat16)
     # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
     # Load the safe tensors model
     # ckpt_path = hf_hub_download(configs["name"].repo_id, configs["name"].repo_flow)
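
Both edits here trade gradients and precision for memory: `@torch.no_grad()` keeps autograd from recording the T5 forward pass, and casting the loaded module with `.to(torch.bfloat16)` halves its weight footprint relative to float32. A minimal sketch of the same pattern on a stand-in module (not the actual `HFEmbedder`):

```python
import torch
import torch.nn as nn


class TextEncoder(nn.Module):
    """Stand-in for an HF text-encoder wrapper."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(512, 512)

    @torch.no_grad()  # no autograd graph is recorded, cutting inference memory
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


encoder = TextEncoder().to(torch.bfloat16)  # weights stored in bfloat16
out = encoder(torch.randn(1, 512).to(torch.bfloat16))
print(out.dtype)  # torch.bfloat16
```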
concept_attention/image_generator.py CHANGED
@@ -58,8 +58,9 @@ def get_models(
     clip = load_clip(device)
     model = load_flow_model(name, device="cpu" if offload else device, attention_block_class=attention_block_class, dit_class=dit_class)
     ae = load_ae(name, device="cpu" if offload else device)
-    # nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
-    return model, ae, t5, clip, None
+    nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
+
+    return model, ae, t5, clip, nsfw_classifier
 
 class FluxGenerator():
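
`pipeline` in `get_models` is the `transformers.pipeline` factory, so the re-enabled classifier returns a list of label/score dicts per image. A hedged usage sketch; the label names and threshold are assumptions about the Falconsai checkpoint, not taken from this commit:

```python
from PIL import Image
from transformers import pipeline

# Same checkpoint get_models() now loads
nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection")

image = Image.new("RGB", (512, 512))  # placeholder input
results = nsfw_classifier(image)  # e.g. [{"label": "normal", "score": 0.99}, ...]
scores = {r["label"]: r["score"] for r in results}
is_flagged = scores.get("nsfw", 0.0) > 0.85  # threshold is illustrative
```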