Commit 196a035
Parent(s): 7af717c
Upload modeling_baichuan.py

- modeling_baichuan.py +13 -13

modeling_baichuan.py CHANGED
@@ -552,41 +552,41 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
         )
         return self
 
-    def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int=0):
+    def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int=0, system_prompt=""):
         max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens
         max_input_tokens = self.config.model_max_length - max_new_tokens
         max_input_tokens = max(self.config.model_max_length // 2, max_input_tokens)
         total_input, round_input = [], []
-        for i, message in enumerate(messages[::-1]):
-            content_tokens = tokenizer.encode(message['content'])
+        for i, message in enumerate(messages):
             if message['role'] == 'user':
-                round_input = [self.generation_config.user_token_id] + content_tokens + round_input
+                if i == 0:
+                    content_tokens = tokenizer.encode(system_prompt + "USER: " + message['content'] + " ASSISTANT: ")
+                else:
+                    content_tokens = tokenizer.encode("USER: " + message['content'] + " ASSISTANT: ")
+                round_input += content_tokens
                 if total_input and len(total_input) + len(round_input) > max_input_tokens:
                     break
                 else:
-                    total_input = round_input + total_input
+                    total_input += round_input
                     if len(total_input) >= max_input_tokens:
                         break
                     else:
                         round_input = []
             elif message['role'] == 'assistant':
-                round_input = [
-                    self.generation_config.assistant_token_id
-                ] + content_tokens + [
-                    self.generation_config.eos_token_id
-                ] + round_input
+                content_tokens = tokenizer.encode(message['content'])
+                round_input += content_tokens + [self.generation_config.eos_token_id]
             else:
                 raise ValueError(f"message role not supported yet: {message['role']}")
         total_input = total_input[-max_input_tokens:]  # truncate left
-        total_input.append(self.generation_config.assistant_token_id)
+        # total_input.append(self.generation_config.eos_token_id)
         total_input = torch.LongTensor([total_input]).to(self.device)
         return total_input
 
     @torch.no_grad()
-    def chat(self, tokenizer, messages: List[dict], stream=False,
+    def chat(self, tokenizer, messages: List[dict], stream=False, system_prompt="",
              generation_config: Optional[GenerationConfig]=None):
         generation_config = generation_config or self.generation_config
-        input_ids = self._build_chat_input(tokenizer, messages, generation_config.max_new_tokens)
+        input_ids = self._build_chat_input(tokenizer, messages, generation_config.max_new_tokens, system_prompt)
         if stream:
             from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
             self.__class__.generate = NewGenerationMixin.generate
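In short, this commit swaps the upstream Baichuan chat format, which wrapped each round in special user_token_id / assistant_token_id markers, for a Vicuna-style plain-text prompt ("USER: ... ASSISTANT: "). It threads a new system_prompt argument from chat() into _build_chat_input(), prepends it only to the first user turn, and appends eos_token_id after each assistant reply instead of once at the end (the final append is commented out, since the prompt now ends with the " ASSISTANT: " cue for the model's next reply).

A minimal usage sketch of the new signature; the repo id and messages are placeholders, and it assumes chat() still returns the decoded response on the non-streaming path, as in the upstream Baichuan code:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

repo = "your-namespace/your-baichuan-fork"  # placeholder repo id

tokenizer = AutoTokenizer.from_pretrained(repo, use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo, torch_dtype=torch.float16, device_map="auto", trust_remote_code=True
)
model.generation_config = GenerationConfig.from_pretrained(repo)

messages = [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "The capital of France is Paris."},
    {"role": "user", "content": "And of Germany?"},
]

# system_prompt is encoded only into the first user turn (the i == 0 branch),
# so _build_chat_input assembles, before tokenization:
#   system_prompt + "USER: " + q1 + " ASSISTANT: " + a1 + <eos>
#                 + "USER: " + q2 + " ASSISTANT: "
response = model.chat(tokenizer, messages,
                      system_prompt="You are a helpful assistant. ")
print(response)

Note that total_input[-max_input_tokens:] still truncates from the left, so the oldest tokens, including the system prompt, are dropped first once the history exceeds the input budget.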