shibing624 committed on
Commit
196a035
·
1 Parent(s): 7af717c

Upload modeling_baichuan.py

Browse files
Files changed (1) hide show
  1. modeling_baichuan.py +13 -13
modeling_baichuan.py CHANGED
@@ -552,41 +552,41 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
552
  )
553
  return self
554
 
555
- def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int=0):
556
  max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens
557
  max_input_tokens = self.config.model_max_length - max_new_tokens
558
  max_input_tokens = max(self.config.model_max_length // 2, max_input_tokens)
559
  total_input, round_input = [], []
560
- for i, message in enumerate(messages[::-1]):
561
- content_tokens = tokenizer.encode(message['content'])
562
  if message['role'] == 'user':
563
- round_input = [self.generation_config.user_token_id] + content_tokens + round_input
 
 
 
 
564
  if total_input and len(total_input) + len(round_input) > max_input_tokens:
565
  break
566
  else:
567
- total_input = round_input + total_input
568
  if len(total_input) >= max_input_tokens:
569
  break
570
  else:
571
  round_input = []
572
  elif message['role'] == 'assistant':
573
- round_input = [
574
- self.generation_config.assistant_token_id
575
- ] + content_tokens + [
576
- self.generation_config.eos_token_id
577
- ] + round_input
578
  else:
579
  raise ValueError(f"message role not supported yet: {message['role']}")
580
  total_input = total_input[-max_input_tokens:] # truncate left
581
- total_input.append(self.generation_config.assistant_token_id)
582
  total_input = torch.LongTensor([total_input]).to(self.device)
583
  return total_input
584
 
585
  @torch.no_grad()
586
- def chat(self, tokenizer, messages: List[dict], stream=False,
587
  generation_config: Optional[GenerationConfig]=None):
588
  generation_config = generation_config or self.generation_config
589
- input_ids = self._build_chat_input(tokenizer, messages, generation_config.max_new_tokens)
590
  if stream:
591
  from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
592
  self.__class__.generate = NewGenerationMixin.generate
 
552
  )
553
  return self
554
 
555
+ def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int=0, system_prompt=""):
556
  max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens
557
  max_input_tokens = self.config.model_max_length - max_new_tokens
558
  max_input_tokens = max(self.config.model_max_length // 2, max_input_tokens)
559
  total_input, round_input = [], []
560
+ for i, message in enumerate(messages):
 
561
  if message['role'] == 'user':
562
+ if i == 0:
563
+ content_tokens = tokenizer.encode(system_prompt + "USER: " + message['content'] + " ASSISTANT: ")
564
+ else:
565
+ content_tokens = tokenizer.encode("USER: " + message['content'] + " ASSISTANT: ")
566
+ round_input += content_tokens
567
  if total_input and len(total_input) + len(round_input) > max_input_tokens:
568
  break
569
  else:
570
+ total_input += round_input
571
  if len(total_input) >= max_input_tokens:
572
  break
573
  else:
574
  round_input = []
575
  elif message['role'] == 'assistant':
576
+ content_tokens = tokenizer.encode(message['content'])
577
+ round_input += content_tokens + [self.generation_config.eos_token_id]
 
 
 
578
  else:
579
  raise ValueError(f"message role not supported yet: {message['role']}")
580
  total_input = total_input[-max_input_tokens:] # truncate left
581
+ # total_input.append(self.generation_config.eos_token_id)
582
  total_input = torch.LongTensor([total_input]).to(self.device)
583
  return total_input
584
 
585
  @torch.no_grad()
586
+ def chat(self, tokenizer, messages: List[dict], stream=False, system_prompt="",
587
  generation_config: Optional[GenerationConfig]=None):
588
  generation_config = generation_config or self.generation_config
589
+ input_ids = self._build_chat_input(tokenizer, messages, generation_config.max_new_tokens, system_prompt)
590
  if stream:
591
  from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
592
  self.__class__.generate = NewGenerationMixin.generate