caohy666 committed
Commit: f4b19f4
Parent: cb59ffa

<fix> remove pipe_lock

Files changed (1)
  1. app.py +190 -193
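
Note on the change: before this commit, app.py created a module-level threading.Lock (pipe_lock) and ran the entire body of process_image_and_text inside "with pipe_lock:", serializing access to the shared pipe / transformer / current_task globals; the commit deletes the lock and dedents the function body by one level. The sketch below is a minimal, self-contained illustration of that before/after pattern only — the helper _setup_for_task and the return strings are placeholders, not code from app.py:

import threading

# Illustrative stand-ins for the globals that app.py mutates per task.
current_task = None
pipe_lock = threading.Lock()  # the module-level lock removed by this commit

def _setup_for_task(task):
    # Placeholder for the real per-task work in app.py: inserting/restoring the
    # LoRA adapter on the shared transformer and loading task-specific weights.
    global current_task
    current_task = task

def process_with_lock(task):
    # Pattern before this commit: the whole setup-and-generate path ran under
    # pipe_lock, so concurrent requests could not race on the shared globals.
    with pipe_lock:
        if current_task != task:
            _setup_for_task(task)
        return f"generated for task={task}"

def process_without_lock(task):
    # Pattern after this commit: identical logic, no lock. Correctness then
    # depends on requests being handled one at a time by the caller/framework.
    if current_task != task:
        _setup_for_task(task)
    return f"generated for task={task}"

if __name__ == "__main__":
    print(process_with_lock("canny"))
    print(process_without_lock("depth"))

Dropping the lock leaves the global mutation unguarded, so this pattern assumes requests reach process_image_and_text one at a time (for example via the Space's request queue); the commit itself does not state that assumption.
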
app.py CHANGED
@@ -47,8 +47,6 @@ there's no need to manually input edge maps, depth maps, or other condition imag
 The corresponding condition images will be automatically extracted.
 """
 
-pipe_lock = threading.Lock()
-
 
 def init_basemodel():
     global transformer, scheduler, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2, image_processor, pipe, current_task
@@ -105,201 +103,200 @@ def init_basemodel():
 @spaces.GPU
 def process_image_and_text(condition_image, target_prompt, condition_image_prompt, task, random_seed, num_steps, inpainting, fill_x1, fill_x2, fill_y1, fill_y2):
     # set up the model
-    with pipe_lock:
-        global pipe, current_task, transformer
-        if current_task != task:
-            if current_task is None:
-                # insert LoRA
-                lora_config = LoraConfig(
-                    r=16,
-                    lora_alpha=16,
-                    init_lora_weights="gaussian",
-                    target_modules=[
-                        'attn.to_k', 'attn.to_q', 'attn.to_v', 'attn.to_out.0',
-                        'attn.add_k_proj', 'attn.add_q_proj', 'attn.add_v_proj', 'attn.to_add_out',
-                        'ff.net.0.proj', 'ff.net.2',
-                        'ff_context.net.0.proj', 'ff_context.net.2',
-                        'norm1_context.linear', 'norm1.linear',
-                        'norm.linear', 'proj_mlp', 'proj_out',
-                    ]
-                )
-                transformer.add_adapter(lora_config)
-            else:
-                def restore_forward(module):
-                    def restored_forward(self, x, *args, **kwargs):
-                        return module.original_forward(x, *args, **kwargs)
-                    return restored_forward.__get__(module, type(module))
-
-                for n, m in transformer.named_modules():
-                    if isinstance(m, peft.tuners.lora.layer.Linear):
-                        m.forward = restore_forward(m)
-
-            current_task = task
-
-            # hack LoRA forward
-            def create_hacked_forward(module):
-                if not hasattr(module, 'original_forward'):
-                    module.original_forward = module.forward
-                lora_forward = module.forward
-                non_lora_forward = module.base_layer.forward
-                img_sequence_length = int((512 / 8 / 2) ** 2)
-                encoder_sequence_length = 144 + 252 # encoder sequence: 144 img 252 txt
-                num_imgs = 4
-                num_generated_imgs = 3
-                num_encoder_sequences = 2 if task in ['subject_driven', 'style_transfer'] else 1
-
-                def hacked_lora_forward(self, x, *args, **kwargs):
-                    if x.shape[1] == img_sequence_length * num_imgs and len(x.shape) > 2:
-                        return torch.cat((
-                            lora_forward(x[:, :-img_sequence_length*num_generated_imgs], *args, **kwargs),
-                            non_lora_forward(x[:, -img_sequence_length*num_generated_imgs:], *args, **kwargs)
-                        ), dim=1)
-                    elif x.shape[1] == encoder_sequence_length * num_encoder_sequences or x.shape[1] == encoder_sequence_length:
-                        return lora_forward(x, *args, **kwargs)
-                    elif x.shape[1] == img_sequence_length * num_imgs + encoder_sequence_length * num_encoder_sequences:
-                        return torch.cat((
-                            lora_forward(x[:, :(num_imgs - num_generated_imgs)*img_sequence_length], *args, **kwargs),
-                            non_lora_forward(x[:, (num_imgs - num_generated_imgs)*img_sequence_length:-num_encoder_sequences*encoder_sequence_length], *args, **kwargs),
-                            lora_forward(x[:, -num_encoder_sequences*encoder_sequence_length:], *args, **kwargs)
-                        ), dim=1)
-                    elif x.shape[1] == 3072:
-                        return non_lora_forward(x, *args, **kwargs)
-                    else:
-                        raise ValueError(
-                            f"hacked_lora_forward receives unexpected sequence length: {x.shape[1]}, input shape: {x.shape}!"
-                        )
-
-                return hacked_lora_forward.__get__(module, type(module))
-
+    global pipe, current_task, transformer
+    if current_task != task:
+        if current_task is None:
+            # insert LoRA
+            lora_config = LoraConfig(
+                r=16,
+                lora_alpha=16,
+                init_lora_weights="gaussian",
+                target_modules=[
+                    'attn.to_k', 'attn.to_q', 'attn.to_v', 'attn.to_out.0',
+                    'attn.add_k_proj', 'attn.add_q_proj', 'attn.add_v_proj', 'attn.to_add_out',
+                    'ff.net.0.proj', 'ff.net.2',
+                    'ff_context.net.0.proj', 'ff_context.net.2',
+                    'norm1_context.linear', 'norm1.linear',
+                    'norm.linear', 'proj_mlp', 'proj_out',
+                ]
+            )
+            transformer.add_adapter(lora_config)
+        else:
+            def restore_forward(module):
+                def restored_forward(self, x, *args, **kwargs):
+                    return module.original_forward(x, *args, **kwargs)
+                return restored_forward.__get__(module, type(module))
+
             for n, m in transformer.named_modules():
                 if isinstance(m, peft.tuners.lora.layer.Linear):
-                    m.forward = create_hacked_forward(m)
-
-            # load LoRA weights
-            model_root = hf_hub_download(
-                repo_id="Kunbyte/DRA-Ctrl",
-                filename=f"{task}.safetensors",
-                resume_download=True)
-
-            try:
-                with safe_open(model_root, framework="pt") as f:
-                    lora_weights = {}
-                    for k in f.keys():
-                        param = f.get_tensor(k)
-                        if k.endswith(".weight"):
-                            k = k.replace('.weight', '.default.weight')
-                        lora_weights[k] = param
-                    transformer.load_state_dict(lora_weights, strict=False)
-            except Exception as e:
-                raise ValueError(f'{e}')
-
-            transformer.requires_grad_(False)
-
-        # start generation
-        c_txt = None if condition_image_prompt == "" else condition_image_prompt
-        c_img = condition_image.resize((512, 512))
-        t_txt = target_prompt
-
-        if task not in ['subject_driven', 'style_transfer']:
-            if task == "canny":
-                def get_canny_edge(img):
-                    img_np = np.array(img)
-                    img_gray = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
-                    edges = cv2.Canny(img_gray, 100, 200)
-                    edges_tmp = Image.fromarray(edges).convert("RGB")
-                    edges[edges == 0] = 128
-                    return Image.fromarray(edges).convert("RGB")
-                c_img = get_canny_edge(c_img)
-            elif task == "coloring":
-                c_img = (
-                    c_img.resize((512, 512))
-                    .convert("L")
-                    .convert("RGB")
-                )
-            elif task == "deblurring":
-                blur_radius = 10
-                c_img = (
-                    c_img.convert("RGB")
-                    .filter(ImageFilter.GaussianBlur(blur_radius))
-                    .resize((512, 512))
-                    .convert("RGB")
-                )
-            elif task == "depth":
-                def get_depth_map(img):
-                    from transformers import pipeline
-
-                    depth_pipe = pipeline(
-                        task="depth-estimation",
-                        model="LiheYoung/depth-anything-small-hf",
-                        device="cpu",
+                    m.forward = restore_forward(m)
+
+        current_task = task
+
+        # hack LoRA forward
+        def create_hacked_forward(module):
+            if not hasattr(module, 'original_forward'):
+                module.original_forward = module.forward
+            lora_forward = module.forward
+            non_lora_forward = module.base_layer.forward
+            img_sequence_length = int((512 / 8 / 2) ** 2)
+            encoder_sequence_length = 144 + 252 # encoder sequence: 144 img 252 txt
+            num_imgs = 4
+            num_generated_imgs = 3
+            num_encoder_sequences = 2 if task in ['subject_driven', 'style_transfer'] else 1
+
+            def hacked_lora_forward(self, x, *args, **kwargs):
+                if x.shape[1] == img_sequence_length * num_imgs and len(x.shape) > 2:
+                    return torch.cat((
+                        lora_forward(x[:, :-img_sequence_length*num_generated_imgs], *args, **kwargs),
+                        non_lora_forward(x[:, -img_sequence_length*num_generated_imgs:], *args, **kwargs)
+                    ), dim=1)
+                elif x.shape[1] == encoder_sequence_length * num_encoder_sequences or x.shape[1] == encoder_sequence_length:
+                    return lora_forward(x, *args, **kwargs)
+                elif x.shape[1] == img_sequence_length * num_imgs + encoder_sequence_length * num_encoder_sequences:
+                    return torch.cat((
+                        lora_forward(x[:, :(num_imgs - num_generated_imgs)*img_sequence_length], *args, **kwargs),
+                        non_lora_forward(x[:, (num_imgs - num_generated_imgs)*img_sequence_length:-num_encoder_sequences*encoder_sequence_length], *args, **kwargs),
+                        lora_forward(x[:, -num_encoder_sequences*encoder_sequence_length:], *args, **kwargs)
+                    ), dim=1)
+                elif x.shape[1] == 3072:
+                    return non_lora_forward(x, *args, **kwargs)
+                else:
+                    raise ValueError(
+                        f"hacked_lora_forward receives unexpected sequence length: {x.shape[1]}, input shape: {x.shape}!"
                     )
-                    return depth_pipe(img)["depth"].convert("RGB").resize((512, 512))
-                c_img = get_depth_map(c_img)
-                k = (255 - 128) / 255
-                b = 128
-                c_img = c_img.point(lambda x: k * x + b)
-            elif task == "depth_pred":
-                c_img = c_img
-            elif task == "fill":
-                c_img = c_img.resize((512, 512)).convert("RGB")
-                x1, x2 = fill_x1, fill_x2
-                y1, y2 = fill_y1, fill_y2
-                mask = Image.new("L", (512, 512), 0)
-                draw = ImageDraw.Draw(mask)
-                draw.rectangle((x1, y1, x2, y2), fill=255)
-                if inpainting:
-                    mask = Image.eval(mask, lambda a: 255 - a)
-                    c_img = Image.composite(
-                        c_img,
-                        Image.new("RGB", (512, 512), (255, 255, 255)),
-                        mask
-                    )
-                c_img = Image.composite(
-                    c_img,
-                    Image.new("RGB", (512, 512), (128, 128, 128)),
-                    mask
+
+            return hacked_lora_forward.__get__(module, type(module))
+
+        for n, m in transformer.named_modules():
+            if isinstance(m, peft.tuners.lora.layer.Linear):
+                m.forward = create_hacked_forward(m)
+
+        # load LoRA weights
+        model_root = hf_hub_download(
+            repo_id="Kunbyte/DRA-Ctrl",
+            filename=f"{task}.safetensors",
+            resume_download=True)
+
+        try:
+            with safe_open(model_root, framework="pt") as f:
+                lora_weights = {}
+                for k in f.keys():
+                    param = f.get_tensor(k)
+                    if k.endswith(".weight"):
+                        k = k.replace('.weight', '.default.weight')
+                    lora_weights[k] = param
+                transformer.load_state_dict(lora_weights, strict=False)
+        except Exception as e:
+            raise ValueError(f'{e}')
+
+        transformer.requires_grad_(False)
+
+    # start generation
+    c_txt = None if condition_image_prompt == "" else condition_image_prompt
+    c_img = condition_image.resize((512, 512))
+    t_txt = target_prompt
+
+    if task not in ['subject_driven', 'style_transfer']:
+        if task == "canny":
+            def get_canny_edge(img):
+                img_np = np.array(img)
+                img_gray = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
+                edges = cv2.Canny(img_gray, 100, 200)
+                edges_tmp = Image.fromarray(edges).convert("RGB")
+                edges[edges == 0] = 128
+                return Image.fromarray(edges).convert("RGB")
+            c_img = get_canny_edge(c_img)
+        elif task == "coloring":
+            c_img = (
+                c_img.resize((512, 512))
+                .convert("L")
+                .convert("RGB")
+            )
+        elif task == "deblurring":
+            blur_radius = 10
+            c_img = (
+                c_img.convert("RGB")
+                .filter(ImageFilter.GaussianBlur(blur_radius))
+                .resize((512, 512))
+                .convert("RGB")
+            )
+        elif task == "depth":
+            def get_depth_map(img):
+                from transformers import pipeline
+
+                depth_pipe = pipeline(
+                    task="depth-estimation",
+                    model="LiheYoung/depth-anything-small-hf",
+                    device="cpu",
                 )
-            elif task == "sr":
-                c_img = c_img.resize((int(512 / 4), int(512 / 4))).convert("RGB")
-                c_img = c_img.resize((512, 512))
-
-        gen_img = pipe(
-            image=c_img,
-            prompt=[t_txt.strip()],
-            prompt_condition=[c_txt.strip()] if c_txt is not None else None,
-            prompt_2=[t_txt],
-            height=512,
-            width=512,
-            num_frames=5,
-            num_inference_steps=num_steps,
-            guidance_scale=6.0,
-            num_videos_per_prompt=1,
-            generator=torch.Generator(device=pipe.transformer.device).manual_seed(random_seed),
-            output_type='pt',
-            image_embed_interleave=4,
-            frame_gap=48,
-            mixup=True,
-            mixup_num_imgs=2,
-            enhance_tp=task in ['subject_driven'],
-        ).frames
-
-        output_images = []
-        for i in range(10):
-            out = gen_img[:, i:i+1, :, :, :]
-            out = out.squeeze(0).squeeze(0).cpu().to(torch.float32).numpy()
-            out = np.transpose(out, (1, 2, 0))
-            out = (out * 255).astype(np.uint8)
-            out = Image.fromarray(out)
-            output_images.append(out)
-
-        # video = [np.array(img.convert('RGB')) for img in output_images[1:] + [output_images[0]]]
-        # video = np.stack(video, axis=0)
-
-        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
-            video_path = f.name
-            imageio.mimsave(video_path, output_images[1:]+[output_images[0]], fps=5)
-
-        return output_images[0], video_path
+                return depth_pipe(img)["depth"].convert("RGB").resize((512, 512))
+            c_img = get_depth_map(c_img)
+            k = (255 - 128) / 255
+            b = 128
+            c_img = c_img.point(lambda x: k * x + b)
+        elif task == "depth_pred":
+            c_img = c_img
+        elif task == "fill":
+            c_img = c_img.resize((512, 512)).convert("RGB")
+            x1, x2 = fill_x1, fill_x2
+            y1, y2 = fill_y1, fill_y2
+            mask = Image.new("L", (512, 512), 0)
+            draw = ImageDraw.Draw(mask)
+            draw.rectangle((x1, y1, x2, y2), fill=255)
+            if inpainting:
+                mask = Image.eval(mask, lambda a: 255 - a)
+                c_img = Image.composite(
+                    c_img,
+                    Image.new("RGB", (512, 512), (255, 255, 255)),
+                    mask
+                )
+            c_img = Image.composite(
+                c_img,
+                Image.new("RGB", (512, 512), (128, 128, 128)),
+                mask
+            )
+        elif task == "sr":
+            c_img = c_img.resize((int(512 / 4), int(512 / 4))).convert("RGB")
+            c_img = c_img.resize((512, 512))
+
+    gen_img = pipe(
+        image=c_img,
+        prompt=[t_txt.strip()],
+        prompt_condition=[c_txt.strip()] if c_txt is not None else None,
+        prompt_2=[t_txt],
+        height=512,
+        width=512,
+        num_frames=5,
+        num_inference_steps=num_steps,
+        guidance_scale=6.0,
+        num_videos_per_prompt=1,
+        generator=torch.Generator(device=pipe.transformer.device).manual_seed(random_seed),
+        output_type='pt',
+        image_embed_interleave=4,
+        frame_gap=48,
+        mixup=True,
+        mixup_num_imgs=2,
+        enhance_tp=task in ['subject_driven'],
+    ).frames
+
+    output_images = []
+    for i in range(10):
+        out = gen_img[:, i:i+1, :, :, :]
+        out = out.squeeze(0).squeeze(0).cpu().to(torch.float32).numpy()
+        out = np.transpose(out, (1, 2, 0))
+        out = (out * 255).astype(np.uint8)
+        out = Image.fromarray(out)
+        output_images.append(out)
+
+    # video = [np.array(img.convert('RGB')) for img in output_images[1:] + [output_images[0]]]
+    # video = np.stack(video, axis=0)
+
+    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
+        video_path = f.name
+        imageio.mimsave(video_path, output_images[1:]+[output_images[0]], fps=5)
+
+    return output_images[0], video_path
 
 def get_samples():
     sample_list = [