huzefa11 commited on
Commit
b1a89db
·
verified ·
1 Parent(s): 568e86f

Update gradio_utils.py

Browse files
Files changed (1) hide show
  1. gradio_utils.py +0 -5
gradio_utils.py CHANGED
@@ -42,7 +42,6 @@ class SpatialAttnProcessor2_0(torch.nn.Module):
42
  temb=None):
43
  # un_cond_hidden_states, cond_hidden_states = hidden_states.chunk(2)
44
  # un_cond_hidden_states = self.__call2__(attn, un_cond_hidden_states,encoder_hidden_states,attention_mask,temb)
45
- # 生成一个0到1之间的随机数
46
  global total_count,attn_count,cur_step,mask256,mask1024,mask4096
47
  global sa16, sa32, sa64
48
  global write
@@ -50,7 +49,6 @@ class SpatialAttnProcessor2_0(torch.nn.Module):
50
  self.id_bank[cur_step] = [hidden_states[:self.id_length], hidden_states[self.id_length:]]
51
  else:
52
  encoder_hidden_states = torch.cat(self.id_bank[cur_step][0],hidden_states[:1],self.id_bank[cur_step][1],hidden_states[1:])
53
- # 判断随机数是否大于0.5
54
  if cur_step <5:
55
  hidden_states = self.__call2__(attn, hidden_states,encoder_hidden_states,attention_mask,temb)
56
  else: # 256 1024 4096
@@ -260,7 +258,6 @@ def cal_attn_indice_xl_effcient_memory(total_length,id_length,sa32,sa64,height,w
260
  nums_4096 = (height // 16) * (width // 16)
261
  bool_matrix1024 = torch.rand((total_length,nums_1024),device = device,dtype = dtype) < sa32
262
  bool_matrix4096 = torch.rand((total_length,nums_4096),device = device,dtype = dtype) < sa64
263
- # 用nonzero()函数获取所有为True的值的索引
264
  indices1024 = [torch.nonzero(bool_matrix1024[i], as_tuple=True)[0] for i in range(total_length)]
265
  indices4096 = [torch.nonzero(bool_matrix4096[i], as_tuple=True)[0] for i in range(total_length)]
266
 
@@ -431,7 +428,6 @@ def is_torch2_available():
431
  return hasattr(F, "scaled_dot_product_attention")
432
 
433
 
434
- # 将列表转换为字典的函数
435
  def character_to_dict(general_prompt):
436
  character_dict = {}
437
  generate_prompt_arr = general_prompt.splitlines()
@@ -439,7 +435,6 @@ def character_to_dict(general_prompt):
439
  invert_character_index_dict = {}
440
  character_list = []
441
  for ind,string in enumerate(generate_prompt_arr):
442
- # 分割字符串寻找key和value
443
  start = string.find('[')
444
  end = string.find(']')
445
  if start != -1 and end != -1:
 
42
  temb=None):
43
  # un_cond_hidden_states, cond_hidden_states = hidden_states.chunk(2)
44
  # un_cond_hidden_states = self.__call2__(attn, un_cond_hidden_states,encoder_hidden_states,attention_mask,temb)
 
45
  global total_count,attn_count,cur_step,mask256,mask1024,mask4096
46
  global sa16, sa32, sa64
47
  global write
 
49
  self.id_bank[cur_step] = [hidden_states[:self.id_length], hidden_states[self.id_length:]]
50
  else:
51
  encoder_hidden_states = torch.cat(self.id_bank[cur_step][0],hidden_states[:1],self.id_bank[cur_step][1],hidden_states[1:])
 
52
  if cur_step <5:
53
  hidden_states = self.__call2__(attn, hidden_states,encoder_hidden_states,attention_mask,temb)
54
  else: # 256 1024 4096
 
258
  nums_4096 = (height // 16) * (width // 16)
259
  bool_matrix1024 = torch.rand((total_length,nums_1024),device = device,dtype = dtype) < sa32
260
  bool_matrix4096 = torch.rand((total_length,nums_4096),device = device,dtype = dtype) < sa64
 
261
  indices1024 = [torch.nonzero(bool_matrix1024[i], as_tuple=True)[0] for i in range(total_length)]
262
  indices4096 = [torch.nonzero(bool_matrix4096[i], as_tuple=True)[0] for i in range(total_length)]
263
 
 
428
  return hasattr(F, "scaled_dot_product_attention")
429
 
430
 
 
431
  def character_to_dict(general_prompt):
432
  character_dict = {}
433
  generate_prompt_arr = general_prompt.splitlines()
 
435
  invert_character_index_dict = {}
436
  character_list = []
437
  for ind,string in enumerate(generate_prompt_arr):
 
438
  start = string.find('[')
439
  end = string.find(']')
440
  if start != -1 and end != -1: