Update gradio_utils.py
gradio_utils.py CHANGED (+0, -5)
@@ -42,7 +42,6 @@ class SpatialAttnProcessor2_0(torch.nn.Module):
                  temb=None):
         # un_cond_hidden_states, cond_hidden_states = hidden_states.chunk(2)
         # un_cond_hidden_states = self.__call2__(attn, un_cond_hidden_states,encoder_hidden_states,attention_mask,temb)
-        # generate a random number between 0 and 1
         global total_count,attn_count,cur_step,mask256,mask1024,mask4096
         global sa16, sa32, sa64
         global write
@@ -50,7 +49,6 @@ class SpatialAttnProcessor2_0(torch.nn.Module):
             self.id_bank[cur_step] = [hidden_states[:self.id_length], hidden_states[self.id_length:]]
         else:
             encoder_hidden_states = torch.cat(self.id_bank[cur_step][0],hidden_states[:1],self.id_bank[cur_step][1],hidden_states[1:])
-        # check whether the random number is greater than 0.5
         if cur_step <5:
             hidden_states = self.__call2__(attn, hidden_states,encoder_hidden_states,attention_mask,temb)
         else:   # 256 1024 4096
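Context note, not part of the diff: the hunk above is the write/read branch of the processor's id_bank cache. Identity hidden states are stored while write is set and later concatenated back in as encoder_hidden_states. A minimal standalone sketch of that pattern, with toy shapes and plain variables standing in for the processor's state:

import torch

id_bank = {}            # cur_step -> cached identity hidden states
id_length = 2           # number of identity images, toy value
write = True            # True while the identity batch is being denoised
cur_step = 0

hidden_states = torch.randn(id_length + 1, 16, 8)   # (batch, tokens, dim), toy sizes

if write:
    # writing phase: remember this step's identity tokens
    id_bank[cur_step] = [hidden_states[:id_length], hidden_states[id_length:]]
else:
    # reading phase: let the new frame attend to the cached identity tokens
    encoder_hidden_states = torch.cat(
        (id_bank[cur_step][0], hidden_states[:1],
         id_bank[cur_step][1], hidden_states[1:]),
        dim=0,
    )

Note that the sketch wraps the tensors in a tuple, which is the call signature torch.cat documents, while the context line in the file passes them as separate positional arguments.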
@@ -260,7 +258,6 @@ def cal_attn_indice_xl_effcient_memory(total_length,id_length,sa32,sa64,height,w
     nums_4096 = (height // 16) * (width // 16)
     bool_matrix1024 = torch.rand((total_length,nums_1024),device = device,dtype = dtype) < sa32
     bool_matrix4096 = torch.rand((total_length,nums_4096),device = device,dtype = dtype) < sa64
-    # use nonzero() to get the indices of all True entries
     indices1024 = [torch.nonzero(bool_matrix1024[i], as_tuple=True)[0] for i in range(total_length)]
     indices4096 = [torch.nonzero(bool_matrix4096[i], as_tuple=True)[0] for i in range(total_length)]
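Context note on the hunk above: the function samples, per frame, a random Boolean row whose True rate is sa32 or sa64, then keeps only the indices of the True entries via torch.nonzero. A self-contained sketch of that sampling step; the //32 factor for nums_1024 is inferred from the //16 factor shown for nums_4096, and the sizes are toy values:

import torch

total_length, sa32 = 4, 0.5
height = width = 1024
nums_1024 = (height // 32) * (width // 32)   # token count at this resolution (assumed factor)

# one Boolean row per frame; roughly a sa32 fraction of the entries are True
bool_matrix1024 = torch.rand((total_length, nums_1024)) < sa32

# nonzero(..., as_tuple=True)[0] gives the positions of the True entries
indices1024 = [torch.nonzero(bool_matrix1024[i], as_tuple=True)[0]
               for i in range(total_length)]

print(indices1024[0].shape)   # about sa32 * nums_1024 selected token indices per frame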
@@ -431,7 +428,6 @@ def is_torch2_available():
     return hasattr(F, "scaled_dot_product_attention")
 
 
-# function that converts the prompt list into a dictionary
 def character_to_dict(general_prompt):
     character_dict = {}
     generate_prompt_arr = general_prompt.splitlines()
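The hunk above only drops a stray comment ahead of character_to_dict; the is_torch2_available helper shown as context is a plain feature probe for PyTorch 2's fused attention. A hedged sketch of how such a probe is typically consumed (the branch bodies here are illustrative, not taken from this file):

import torch.nn.functional as F

def is_torch2_available():
    # F.scaled_dot_product_attention exists from PyTorch 2.0 onward
    return hasattr(F, "scaled_dot_product_attention")

# typical usage: choose an attention implementation once, at import time
if is_torch2_available():
    print("using fused F.scaled_dot_product_attention")
else:
    print("falling back to a manual softmax(QK^T)V attention path")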
@@ -439,7 +435,6 @@ def character_to_dict(general_prompt):
     invert_character_index_dict = {}
     character_list = []
     for ind,string in enumerate(generate_prompt_arr):
-        # split the string to find the key and value
         start = string.find('[')
         end = string.find(']')
         if start != -1 and end != -1:
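Context for the loop the final hunk touches: each prompt line is expected to carry a character tag in square brackets, and find('[') / find(']') locate it. A small sketch of that parsing step, assuming a '[name] description' line format; the full bookkeeping in character_to_dict (index and inversion dicts) is omitted:

def parse_character_lines(general_prompt):
    """Map '[name] description' lines to {tag: [descriptions]}; a sketch, not the file's full logic."""
    character_dict = {}
    for string in general_prompt.splitlines():
        start = string.find('[')
        end = string.find(']')
        if start != -1 and end != -1:
            key = string[start:end + 1]          # e.g. '[Taylor]'
            value = string[end + 1:].strip()     # the description after the tag
            character_dict.setdefault(key, []).append(value)
    return character_dict

print(parse_character_lines("[Taylor] a woman with a hat\n[Lecun] a man in a suit"))
# {'[Taylor]': ['a woman with a hat'], '[Lecun]': ['a man in a suit']}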