davidlvxin
commited on
Commit
•
a93f22e
1
Parent(s):
acb9849
Update modeling_chatglm.py
Browse files- modeling_chatglm.py +23 -10
modeling_chatglm.py
CHANGED
@@ -416,7 +416,10 @@ class SelfAttention(torch.nn.Module):
|
|
416 |
key_layer = torch.cat((cache_k, key_layer), dim=0)
|
417 |
value_layer = torch.cat((cache_v, value_layer), dim=0)
|
418 |
if use_cache:
|
419 |
-
kv_cache
|
|
|
|
|
|
|
420 |
else:
|
421 |
kv_cache = None
|
422 |
|
@@ -627,12 +630,8 @@ class GLMTransformer(torch.nn.Module):
|
|
627 |
if not kv_caches:
|
628 |
kv_caches = [None for _ in range(self.num_layers)]
|
629 |
presents = () if use_cache else None
|
630 |
-
if self.
|
631 |
-
|
632 |
-
logger.warning_once(
|
633 |
-
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
634 |
-
)
|
635 |
-
use_cache = False
|
636 |
|
637 |
all_self_attentions = None
|
638 |
all_hidden_states = () if output_hidden_states else None
|
@@ -660,7 +659,15 @@ class GLMTransformer(torch.nn.Module):
|
|
660 |
)
|
661 |
hidden_states, kv_cache = layer_ret
|
662 |
if use_cache:
|
663 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
664 |
|
665 |
if output_hidden_states:
|
666 |
all_hidden_states = all_hidden_states + (hidden_states,)
|
@@ -845,6 +852,12 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
845 |
inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
|
846 |
kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
|
847 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
848 |
|
849 |
if not return_dict:
|
850 |
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
|
@@ -1036,7 +1049,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
1036 |
|
1037 |
@torch.inference_mode()
|
1038 |
def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
|
1039 |
-
max_length: int = 131072, num_beams=1, do_sample=True, top_p=0.
|
1040 |
**kwargs):
|
1041 |
if history is None:
|
1042 |
history = []
|
@@ -1058,7 +1071,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
1058 |
|
1059 |
@torch.inference_mode()
|
1060 |
def stream_chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
|
1061 |
-
past_key_values=None,max_length: int = 131072, do_sample=True, top_p=0.
|
1062 |
logits_processor=None, return_past_key_values=False, **kwargs):
|
1063 |
if history is None:
|
1064 |
history = []
|
|
|
416 |
key_layer = torch.cat((cache_k, key_layer), dim=0)
|
417 |
value_layer = torch.cat((cache_v, value_layer), dim=0)
|
418 |
if use_cache:
|
419 |
+
if kv_cache is None:
|
420 |
+
kv_cache = torch.cat((key_layer.unsqueeze(0).unsqueeze(0), value_layer.unsqueeze(0).unsqueeze(0)), dim=1)
|
421 |
+
else:
|
422 |
+
kv_cache = (key_layer, value_layer)
|
423 |
else:
|
424 |
kv_cache = None
|
425 |
|
|
|
630 |
if not kv_caches:
|
631 |
kv_caches = [None for _ in range(self.num_layers)]
|
632 |
presents = () if use_cache else None
|
633 |
+
if self.training:
|
634 |
+
use_cache = False
|
|
|
|
|
|
|
|
|
635 |
|
636 |
all_self_attentions = None
|
637 |
all_hidden_states = () if output_hidden_states else None
|
|
|
659 |
)
|
660 |
hidden_states, kv_cache = layer_ret
|
661 |
if use_cache:
|
662 |
+
# token by token decoding, use tuple format
|
663 |
+
if kv_caches[0] is not None:
|
664 |
+
presents = presents + (kv_cache,)
|
665 |
+
# prefilling in decoding, use tensor format to save cuda memory
|
666 |
+
else:
|
667 |
+
if len(presents) == 0:
|
668 |
+
presents = kv_cache
|
669 |
+
else:
|
670 |
+
presents = torch.cat((presents, kv_cache), dim=0)
|
671 |
|
672 |
if output_hidden_states:
|
673 |
all_hidden_states = all_hidden_states + (hidden_states,)
|
|
|
852 |
inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
|
853 |
kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
|
854 |
)
|
855 |
+
if presents is not None and type(presents) is torch.Tensor:
|
856 |
+
presents = presents.split(1, dim=0)
|
857 |
+
presents = list(presents)
|
858 |
+
presents = [list(x.squeeze(0).split(1, dim=0)) for x in presents]
|
859 |
+
presents = [tuple([x.squeeze(0) for x in y]) for y in presents]
|
860 |
+
presents = tuple(presents)
|
861 |
|
862 |
if not return_dict:
|
863 |
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
|
|
|
1049 |
|
1050 |
@torch.inference_mode()
|
1051 |
def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
|
1052 |
+
max_length: int = 131072, num_beams=1, do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None,
|
1053 |
**kwargs):
|
1054 |
if history is None:
|
1055 |
history = []
|
|
|
1071 |
|
1072 |
@torch.inference_mode()
|
1073 |
def stream_chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
|
1074 |
+
past_key_values=None,max_length: int = 131072, do_sample=True, top_p=0.7, temperature=0.95,
|
1075 |
logits_processor=None, return_past_key_values=False, **kwargs):
|
1076 |
if history is None:
|
1077 |
history = []
|