caohy666 committed
Commit 931adb9 · Parent: 68fc027

<feat> support changing models.

Files changed (3):
  1. README.md +1 -0
  2. app.py +29 -28
  3. app.sh +4 -0
README.md CHANGED
@@ -9,6 +9,7 @@ app_file: app.py
 pinned: false
 license: apache-2.0
 short_description: huggingface space for DRA-Ctrl.
+entrypoint: app.sh
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
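With the new `entrypoint: app.sh` key, the Space no longer launches app.py directly; it runs app.sh (added below), which installs the pinned dependencies first. For reference, the relevant part of the front matter after this change should read roughly as follows; keys that sit above the hunk are omitted because the diff does not show them, and the exact position of `app_file` is taken from the hunk header only:

```yaml
app_file: app.py
pinned: false
license: apache-2.0
short_description: huggingface space for DRA-Ctrl.
entrypoint: app.sh
```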
app.py CHANGED
@@ -1,8 +1,4 @@
 import os
-
-if 'SPACES_APP' in os.environ:
-    os.system("pip install flash-attn==2.7.3 --no-build-isolation")
-
 import sys
 import torch
 import diffusers
@@ -102,35 +98,40 @@ def init_basemodel():
 @spaces.GPU
 def process_image_and_text(condition_image, target_prompt, condition_image_prompt, task, random_seed, inpainting, fill_x1, fill_x2, fill_y1, fill_y2):
     # set up the model
-    global pipe, current_task
-    if pipe is None or current_task != task:
-        if current_task is not None:
-            transformer.delete_adapters('default')
+    global pipe, current_task, transformer
+    if current_task != task:
+        if current_task is None:
+            current_task = task
+
+            # insert LoRA
+            lora_config = LoraConfig(
+                r=16,
+                lora_alpha=16,
+                init_lora_weights="gaussian",
+                target_modules=[
+                    'attn.to_k', 'attn.to_q', 'attn.to_v', 'attn.to_out.0',
+                    'attn.add_k_proj', 'attn.add_q_proj', 'attn.add_v_proj', 'attn.to_add_out',
+                    'ff.net.0.proj', 'ff.net.2',
+                    'ff_context.net.0.proj', 'ff_context.net.2',
+                    'norm1_context.linear', 'norm1.linear',
+                    'norm.linear', 'proj_mlp', 'proj_out',
+                ]
+            )
+            transformer.add_adapter(lora_config)
+        else:
+            def restore_forward(module):
+                def restored_forward(self, x, *args, **kwargs):
+                    return module.original_forward(x, *args, **kwargs)
+                return restored_forward.__get__(module, type(module))
+
             for n, m in transformer.named_modules():
                 if isinstance(m, peft.tuners.lora.layer.Linear):
-                    if hasattr(m, 'base_layer'):
-                        m.forward = m.base_layer.forward.__get__(m, type(m))
-
-        current_task = task
-
-        # insert LoRA
-        lora_config = LoraConfig(
-            r=16,
-            lora_alpha=16,
-            init_lora_weights="gaussian",
-            target_modules=[
-                'attn.to_k', 'attn.to_q', 'attn.to_v', 'attn.to_out.0',
-                'attn.add_k_proj', 'attn.add_q_proj', 'attn.add_v_proj', 'attn.to_add_out',
-                'ff.net.0.proj', 'ff.net.2',
-                'ff_context.net.0.proj', 'ff_context.net.2',
-                'norm1_context.linear', 'norm1.linear',
-                'norm.linear', 'proj_mlp', 'proj_out',
-            ]
-        )
-        transformer.add_adapter(lora_config)
+                    m.forward = restore_forward(m)
 
     # hack LoRA forward
     def create_hacked_forward(module):
+        if not hasattr(module, 'original_forward'):
+            module.original_forward = module.forward
         lora_forward = module.forward
         non_lora_forward = module.base_layer.forward
         img_sequence_length = int((512 / 8 / 2) ** 2)
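The task-switching change hinges on a save-and-restore pattern for monkey-patched `forward` methods: `create_hacked_forward` now caches a module's current `forward` in `original_forward` before replacing it, and `restore_forward` rebinds that cached function to the instance with `__get__` when the user picks a different task. Below is a minimal, self-contained sketch of that pattern; `DummyLayer` and the patched lambda are illustrative stand-ins, not code from the Space.

```python
class DummyLayer:
    """Stand-in for a peft LoRA Linear wrapper (illustrative only)."""
    def forward(self, x):
        return x + 1  # the behaviour we want to be able to restore


def restore_forward(module):
    # Mirror the diff: wrap the cached forward and rebind it to the instance
    # via __get__, so it behaves like a normal bound method again.
    def restored_forward(self, x, *args, **kwargs):
        return module.original_forward(x, *args, **kwargs)
    return restored_forward.__get__(module, type(module))


layer = DummyLayer()

# First use: cache the current forward before it gets replaced
# (what create_hacked_forward now does on its first call).
if not hasattr(layer, 'original_forward'):
    layer.original_forward = layer.forward

# The app then monkey-patches forward for the current task.
layer.forward = lambda x: x * 100          # stand-in for the "hacked" forward

# On a task switch, put the cached forward back before re-patching.
layer.forward = restore_forward(layer)
print(layer.forward(1))                    # 2 -> original behaviour restored
```

Because `restored_forward.__get__(module, type(module))` produces a real bound method, the next call to `create_hacked_forward` can wrap it again exactly as it wrapped the original.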
app.sh ADDED
@@ -0,0 +1,4 @@
+#!/bin/bash
+pip install torch==2.5.1
+pip install git+https://github.com/Dao-AILab/[email protected]#subdirectory=csrc
+python app.py
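The new app.sh replaces the import-time `pip install flash-attn` that was removed from app.py: as the Space's entrypoint it installs a pinned torch plus a source build of flash-attention, then launches app.py. A purely illustrative addition (not part of the commit) would be a startup sanity check in app.py; `flash_attn` as the importable module name is an assumption here.

```python
# Hypothetical startup sanity check, not part of the commit.
import importlib.util
import torch

# app.sh pins torch==2.5.1; fail fast if a different build got pulled in.
assert torch.__version__.startswith("2.5.1"), torch.__version__

# flash-attention is installed from source by app.sh; "flash_attn" is the
# assumed module name exposed by that build.
if importlib.util.find_spec("flash_attn") is None:
    raise RuntimeError("flash-attn not found; run app.sh (the Space entrypoint) first")
```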