Update modeling_chatglm.py
modeling_chatglm.py  CHANGED  (+23 -10)
@@ -416,7 +416,10 @@ class SelfAttention(torch.nn.Module):
             key_layer = torch.cat((cache_k, key_layer), dim=0)
             value_layer = torch.cat((cache_v, value_layer), dim=0)
         if use_cache:
-            kv_cache = (key_layer, value_layer)
+            if kv_cache is None:
+                kv_cache = torch.cat((key_layer.unsqueeze(0).unsqueeze(0), value_layer.unsqueeze(0).unsqueeze(0)), dim=1)
+            else:
+                kv_cache = (key_layer, value_layer)
         else:
             kv_cache = None
 
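In the prefill pass (kv_cache is None) each attention layer now returns its cache as one stacked tensor rather than a (key, value) tuple. A minimal standalone sketch of that packing, with toy shapes standing in for ChatGLM's [seq_len, batch, num_heads, head_dim] layout:

import torch

# Toy prefill shapes; stand-ins for ChatGLM's [seq_len, batch, num_heads, head_dim] layout.
key_layer = torch.randn(8, 1, 2, 128)
value_layer = torch.randn(8, 1, 2, 128)

# Prefill branch: pack key and value into a single tensor instead of a (key, value) tuple.
packed = torch.cat(
    (key_layer.unsqueeze(0).unsqueeze(0), value_layer.unsqueeze(0).unsqueeze(0)),
    dim=1,
)
print(packed.shape)  # torch.Size([1, 2, 8, 1, 2, 128])

# Decode steps (kv_cache already populated) keep the original tuple format.
kv_cache = (key_layer, value_layer)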
@@ -627,12 +630,8 @@ class GLMTransformer(torch.nn.Module):
         if not kv_caches:
             kv_caches = [None for _ in range(self.num_layers)]
         presents = () if use_cache else None
-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
-                use_cache = False
+        if self.training:
+            use_cache = False
 
         all_self_attentions = None
         all_hidden_states = () if output_hidden_states else None
@@ -660,7 +659,15 @@
             )
             hidden_states, kv_cache = layer_ret
             if use_cache:
-                presents = presents + (kv_cache,)
+                # token by token decoding, use tuple format
+                if kv_caches[0] is not None:
+                    presents = presents + (kv_cache,)
+                # prefilling in decoding, use tensor format to save cuda memory
+                else:
+                    if len(presents) == 0:
+                        presents = kv_cache
+                    else:
+                        presents = torch.cat((presents, kv_cache), dim=0)
 
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
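To illustrate the two branches above: during token-by-token decoding presents stays a tuple of per-layer caches, while during prefill the per-layer tensors are concatenated along dim 0 into one cache tensor for the whole stack. A rough standalone sketch with invented shapes (num_layers and the per-layer tensors are placeholders for what SelfAttention returns during prefill):

import torch

num_layers = 4
seq, batch, heads, head_dim = 8, 1, 2, 128

presents = ()
for _ in range(num_layers):
    # Stand-in for one layer's prefill cache: [1, 2, seq, batch, heads, head_dim]
    # (key and value stacked on dim 1).
    kv_cache = torch.randn(1, 2, seq, batch, heads, head_dim)
    if len(presents) == 0:
        presents = kv_cache
    else:
        presents = torch.cat((presents, kv_cache), dim=0)

print(presents.shape)  # torch.Size([4, 2, 8, 1, 2, 128]): one tensor for all layers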
@@ -845,6 +852,12 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
             kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
         )
+        if presents is not None and type(presents) is torch.Tensor:
+            presents = presents.split(1, dim=0)
+            presents = list(presents)
+            presents = [list(x.squeeze(0).split(1, dim=0)) for x in presents]
+            presents = [tuple([x.squeeze(0) for x in y]) for y in presents]
+            presents = tuple(presents)
 
         if not return_dict:
             return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
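The block added here converts the stacked prefill cache back into the per-layer (key, value) tuples the rest of the pipeline expects. A self-contained sketch of that round trip, with made-up shapes:

import torch

num_layers, seq, batch, heads, head_dim = 4, 8, 1, 2, 128

# Stand-in for the tensor-format cache produced during prefill.
presents = torch.randn(num_layers, 2, seq, batch, heads, head_dim)

unpacked = presents.split(1, dim=0)                                # one [1, 2, ...] chunk per layer
unpacked = [list(x.squeeze(0).split(1, dim=0)) for x in unpacked]  # split each layer into key / value
unpacked = tuple(tuple(x.squeeze(0) for x in y) for y in unpacked)

assert len(unpacked) == num_layers
assert torch.equal(unpacked[0][0], presents[0, 0])  # layer-0 key survives the round trip
assert torch.equal(unpacked[0][1], presents[0, 1])  # layer-0 value survives the round trip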
@@ -1036,7 +1049,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
 
     @torch.inference_mode()
     def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
-             max_length: int = 131072, num_beams=1, do_sample=True, top_p=0.…
+             max_length: int = 131072, num_beams=1, do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None,
              **kwargs):
         if history is None:
             history = []
@@ -1058,7 +1071,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
 
     @torch.inference_mode()
     def stream_chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
-                    past_key_values=None,max_length: int = 131072, do_sample=True, top_p=0.…
+                    past_key_values=None,max_length: int = 131072, do_sample=True, top_p=0.7, temperature=0.95,
                     logits_processor=None, return_past_key_values=False, **kwargs):
         if history is None:
             history = []
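The two signature hunks appear to adjust only the sampling defaults for chat and stream_chat (top_p=0.7, temperature=0.95); the calling convention itself is unchanged. A hedged usage sketch, with a placeholder repo id (substitute the checkpoint this modeling_chatglm.py ships with):

from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "THUDM/glm-4-9b-chat"  # placeholder; use the repository this modeling file belongs to
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True, device_map="auto").eval()

# The new defaults apply automatically; they can still be overridden per call.
response, history = model.chat(tokenizer, "Hello", history=[], role="user")
response, history = model.chat(tokenizer, "Tell me more", history=history, top_p=0.7, temperature=0.95)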