duzx16 committed
Commit 591fa87 · Parent(s): 85ba2d2

Add system prompt

Files changed:
- modeling_chatglm.py +14 -17
- tokenization_chatglm.py +7 -2
modeling_chatglm.py CHANGED

@@ -1001,19 +1001,15 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         response = response.replace("[[训练时间]]", "2023年")
         return response
 
-    def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None):
-        inputs = tokenizer.build_chat_input(query, history=history)
-        inputs = inputs.to(self.device)
-        return inputs
-
-    def build_stream_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None):
-        inputs = tokenizer.build_chat_input(query)
+    def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, system: str = None):
+        inputs = tokenizer.build_chat_input(query, history=history, system=system)
         inputs = inputs.to(self.device)
         return inputs
 
     @torch.inference_mode()
-    def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 8192, num_beams=1,
-             do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, **kwargs):
+    def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, system: str = None,
+             max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
+             **kwargs):
         if history is None:
             history = []
         if logits_processor is None:

@@ -1021,7 +1017,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         logits_processor.append(InvalidScoreLogitsProcessor())
         gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
                       "temperature": temperature, "logits_processor": logits_processor, **kwargs}
-        inputs = self.build_inputs(tokenizer, query, history=history)
+        inputs = self.build_inputs(tokenizer, query, history=history, system=system)
         eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>")]
         outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
         outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]

@@ -1031,21 +1027,22 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         return response, history
 
     @torch.inference_mode()
-    def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, past_key_values=None,
-                    max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
-                    return_past_key_values=False, **kwargs):
+    def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, system: str = None,
+                    past_key_values=None,max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
+                    logits_processor=None, return_past_key_values=False, **kwargs):
         if history is None:
             history = []
         if logits_processor is None:
             logits_processor = LogitsProcessorList()
         logits_processor.append(InvalidScoreLogitsProcessor())
-        eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>")]
+        eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
+                        tokenizer.get_command("<|observation|>")]
         gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
                       "temperature": temperature, "logits_processor": logits_processor, **kwargs}
-        if past_key_values is None:
-            inputs = self.build_inputs(tokenizer, query, history=history)
+        if past_key_values is None:
+            inputs = self.build_inputs(tokenizer, query, history=history, system=system)
         else:
-            inputs = self.build_stream_inputs(tokenizer, query)
+            inputs = self.build_inputs(tokenizer, query)
         if past_key_values is not None:
             past_length = past_key_values[0][0].shape[0]
             if self.transformer.pre_seq_len is not None:
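With this change, chat() and stream_chat() accept an optional system argument that is passed through build_inputs() to tokenizer.build_chat_input(), so a system prompt can be placed ahead of the conversation. A minimal usage sketch follows; the checkpoint id "THUDM/chatglm3-6b", the .half().cuda() placement, and the sample strings are illustrative assumptions, not part of this commit:

from transformers import AutoModel, AutoTokenizer

# Assumed checkpoint id; any repository shipping this modeling/tokenization code works the same way.
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True).half().cuda()
model = model.eval()

system = "You are a helpful assistant. Answer concisely."

# chat() forwards `system` to build_inputs(), which forwards it to build_chat_input().
response, history = model.chat(tokenizer, "Hello", history=[], system=system)

# stream_chat() takes the same argument and yields incrementally longer (response, history) pairs.
for response, history in model.stream_chat(tokenizer, "Introduce yourself.", history=history, system=system):
    pass

Because the new code only rebuilds the full prompt when past_key_values is None, the system text is encoded once at the start of a streamed conversation and then lives in the cached prefix rather than being re-encoded on every call.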
tokenization_chatglm.py CHANGED

@@ -67,7 +67,9 @@ class SPTokenizer:
 
     def convert_id_to_token(self, index):
         """Converts an index (integer) in a token (str) using the vocab."""
-        if index in self.index_special_tokens or index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
+        if index in self.index_special_tokens:
+            return self.index_special_tokens[index]
+        if index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
             return ""
         return self.sp_model.IdToPiece(index)
 
@@ -171,10 +173,13 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
         return prefix_tokens
 
-    def build_chat_input(self, query, history=None):
+    def build_chat_input(self, query, history=None, system=None):
         if history is None:
             history = []
         input_ids = []
+        if system is not None:
+            input_ids.extend(
+                [self.get_command("<|system|>")] + self.tokenizer.encode("\n") + self.tokenizer.encode(system))
         for i, (old_query, old_response) in enumerate(history):
             input_ids.extend(
                 [self.get_command("<|user|>")] + self.tokenizer.encode("\n") + self.tokenizer.encode(old_query))