caohy666 commited on
Commit
734f79a
·
1 Parent(s): f4b19f4

<feat> unify hacked_lora_forward

Browse files
Files changed (1) hide show
  1. app.py +60 -67
app.py CHANGED
@@ -99,79 +99,72 @@ def init_basemodel():
99
  image_processor=image_processor,
100
  )
101
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
  @spaces.GPU
104
  def process_image_and_text(condition_image, target_prompt, condition_image_prompt, task, random_seed, num_steps, inpainting, fill_x1, fill_x2, fill_y1, fill_y2):
105
  # set up the model
106
  global pipe, current_task, transformer
107
  if current_task != task:
108
- if current_task is None:
109
- # insert LoRA
110
- lora_config = LoraConfig(
111
- r=16,
112
- lora_alpha=16,
113
- init_lora_weights="gaussian",
114
- target_modules=[
115
- 'attn.to_k', 'attn.to_q', 'attn.to_v', 'attn.to_out.0',
116
- 'attn.add_k_proj', 'attn.add_q_proj', 'attn.add_v_proj', 'attn.to_add_out',
117
- 'ff.net.0.proj', 'ff.net.2',
118
- 'ff_context.net.0.proj', 'ff_context.net.2',
119
- 'norm1_context.linear', 'norm1.linear',
120
- 'norm.linear', 'proj_mlp', 'proj_out',
121
- ]
122
- )
123
- transformer.add_adapter(lora_config)
124
- else:
125
- def restore_forward(module):
126
- def restored_forward(self, x, *args, **kwargs):
127
- return module.original_forward(x, *args, **kwargs)
128
- return restored_forward.__get__(module, type(module))
129
-
130
- for n, m in transformer.named_modules():
131
- if isinstance(m, peft.tuners.lora.layer.Linear):
132
- m.forward = restore_forward(m)
133
-
134
- current_task = task
135
-
136
- # hack LoRA forward
137
- def create_hacked_forward(module):
138
- if not hasattr(module, 'original_forward'):
139
- module.original_forward = module.forward
140
- lora_forward = module.forward
141
- non_lora_forward = module.base_layer.forward
142
- img_sequence_length = int((512 / 8 / 2) ** 2)
143
- encoder_sequence_length = 144 + 252 # encoder sequence: 144 img 252 txt
144
- num_imgs = 4
145
- num_generated_imgs = 3
146
- num_encoder_sequences = 2 if task in ['subject_driven', 'style_transfer'] else 1
147
-
148
- def hacked_lora_forward(self, x, *args, **kwargs):
149
- if x.shape[1] == img_sequence_length * num_imgs and len(x.shape) > 2:
150
- return torch.cat((
151
- lora_forward(x[:, :-img_sequence_length*num_generated_imgs], *args, **kwargs),
152
- non_lora_forward(x[:, -img_sequence_length*num_generated_imgs:], *args, **kwargs)
153
- ), dim=1)
154
- elif x.shape[1] == encoder_sequence_length * num_encoder_sequences or x.shape[1] == encoder_sequence_length:
155
- return lora_forward(x, *args, **kwargs)
156
- elif x.shape[1] == img_sequence_length * num_imgs + encoder_sequence_length * num_encoder_sequences:
157
- return torch.cat((
158
- lora_forward(x[:, :(num_imgs - num_generated_imgs)*img_sequence_length], *args, **kwargs),
159
- non_lora_forward(x[:, (num_imgs - num_generated_imgs)*img_sequence_length:-num_encoder_sequences*encoder_sequence_length], *args, **kwargs),
160
- lora_forward(x[:, -num_encoder_sequences*encoder_sequence_length:], *args, **kwargs)
161
- ), dim=1)
162
- elif x.shape[1] == 3072:
163
- return non_lora_forward(x, *args, **kwargs)
164
- else:
165
- raise ValueError(
166
- f"hacked_lora_forward receives unexpected sequence length: {x.shape[1]}, input shape: {x.shape}!"
167
- )
168
-
169
- return hacked_lora_forward.__get__(module, type(module))
170
-
171
- for n, m in transformer.named_modules():
172
- if isinstance(m, peft.tuners.lora.layer.Linear):
173
- m.forward = create_hacked_forward(m)
174
-
175
  # load LoRA weights
176
  model_root = hf_hub_download(
177
  repo_id="Kunbyte/DRA-Ctrl",
 
99
  image_processor=image_processor,
100
  )
101
 
102
+ # insert LoRA
103
+ lora_config = LoraConfig(
104
+ r=16,
105
+ lora_alpha=16,
106
+ init_lora_weights="gaussian",
107
+ target_modules=[
108
+ 'attn.to_k', 'attn.to_q', 'attn.to_v', 'attn.to_out.0',
109
+ 'attn.add_k_proj', 'attn.add_q_proj', 'attn.add_v_proj', 'attn.to_add_out',
110
+ 'ff.net.0.proj', 'ff.net.2',
111
+ 'ff_context.net.0.proj', 'ff_context.net.2',
112
+ 'norm1_context.linear', 'norm1.linear',
113
+ 'norm.linear', 'proj_mlp', 'proj_out',
114
+ ]
115
+ )
116
+ transformer.add_adapter(lora_config)
117
+
118
+ # hack LoRA forward
119
+ def create_hacked_forward(module):
120
+ if not hasattr(module, 'original_forward'):
121
+ module.original_forward = module.forward
122
+ lora_forward = module.forward
123
+ non_lora_forward = module.base_layer.forward
124
+ img_sequence_length = int((512 / 8 / 2) ** 2)
125
+ encoder_sequence_length = 144 + 252 # encoder sequence: 144 img 252 txt
126
+ num_imgs = 4
127
+ num_generated_imgs = 3
128
+
129
+ def hacked_lora_forward(self, x, *args, **kwargs):
130
+ if x.shape[1] == img_sequence_length * num_imgs and len(x.shape) > 2:
131
+ return torch.cat((
132
+ lora_forward(x[:, :-img_sequence_length*num_generated_imgs], *args, **kwargs),
133
+ non_lora_forward(x[:, -img_sequence_length*num_generated_imgs:], *args, **kwargs)
134
+ ), dim=1)
135
+ elif x.shape[1] == encoder_sequence_length * 2 or x.shape[1] == encoder_sequence_length:
136
+ return lora_forward(x, *args, **kwargs)
137
+ elif x.shape[1] == img_sequence_length * num_imgs + encoder_sequence_length:
138
+ return torch.cat((
139
+ lora_forward(x[:, :(num_imgs - num_generated_imgs)*img_sequence_length], *args, **kwargs),
140
+ non_lora_forward(x[:, (num_imgs - num_generated_imgs)*img_sequence_length:-encoder_sequence_length], *args, **kwargs),
141
+ lora_forward(x[:, -encoder_sequence_length:], *args, **kwargs)
142
+ ), dim=1)
143
+ elif x.shape[1] == img_sequence_length * num_imgs + encoder_sequence_length * 2:
144
+ return torch.cat((
145
+ lora_forward(x[:, :(num_imgs - num_generated_imgs)*img_sequence_length], *args, **kwargs),
146
+ non_lora_forward(x[:, (num_imgs - num_generated_imgs)*img_sequence_length:-2*encoder_sequence_length], *args, **kwargs),
147
+ lora_forward(x[:, -2*encoder_sequence_length:], *args, **kwargs)
148
+ ), dim=1)
149
+ elif x.shape[1] == 3072:
150
+ return non_lora_forward(x, *args, **kwargs)
151
+ else:
152
+ raise ValueError(
153
+ f"hacked_lora_forward receives unexpected sequence length: {x.shape[1]}, input shape: {x.shape}!"
154
+ )
155
+
156
+ return hacked_lora_forward.__get__(module, type(module))
157
+
158
+ for n, m in transformer.named_modules():
159
+ if isinstance(m, peft.tuners.lora.layer.Linear):
160
+ m.forward = create_hacked_forward(m)
161
+
162
 
163
  @spaces.GPU
164
  def process_image_and_text(condition_image, target_prompt, condition_image_prompt, task, random_seed, num_steps, inpainting, fill_x1, fill_x2, fill_y1, fill_y2):
165
  # set up the model
166
  global pipe, current_task, transformer
167
  if current_task != task:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  # load LoRA weights
169
  model_root = hf_hub_download(
170
  repo_id="Kunbyte/DRA-Ctrl",