# DeciCoder custom tokenizer (Deci / Hugging Face Transformers, custom_code).
 
from transformers.dynamic_module_utils import get_class_from_dynamic_module
from transformers.tokenization_utils import AddedToken
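
# The parent tokenizer class and the tiktoken factory are pulled from the
# Salesforce CodeGen2.5 repo at a pinned revision, so upstream edits to that
# repo cannot silently change this tokenizer's behavior.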

_codegen_revision = dict(pretrained_model_name_or_path="Salesforce/codegen25-7b-multi",
                         revision="d4dc9dd90e8b23d5411e6d970e3a11e88dc5c2bc")

CodeGen25Tokenizer = get_class_from_dynamic_module(
    "tokenization_codegen25.CodeGen25Tokenizer", **_codegen_revision)

tiktoken_tokenizer = get_class_from_dynamic_module(
    "tokenization_codegen25.tiktoken_tokenizer", **_codegen_revision)


class DeciCoderTokenizer(CodeGen25Tokenizer):
    """CodeGen25Tokenizer with fixes for pickling, unknown-id conversion,
    and tokenizer_config.json serialization."""

    def __init__(
            self,
            pad_token=None,
            eos_token="<|endoftext|>",
            add_eos_token=False,
            add_special_tokens=True,
            **kwargs,
    ):
        # Keep the tiktoken kwargs around so the encoder can be rebuilt after
        # unpickling (see __setstate__ below).
        self._tiktoken_kwargs = dict(base="gpt2", pad_token=pad_token, add_special=add_special_tokens)
        self.add_eos_token = add_eos_token
        self.encoder = tiktoken_tokenizer(**self._tiktoken_kwargs)
        # Wrap string special tokens in AddedToken so surrounding whitespace
        # is preserved exactly.
        pad_token_added = (
            AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
        )
        eos_token_added = (
            AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
        )
        super().__init__(
            pad_token=pad_token_added,
            eos_token=eos_token_added,
            add_eos_token=add_eos_token,
            add_special_tokens=add_special_tokens,
            **kwargs,
        )

    def _convert_id_to_token(self, index):
        """Work around a bug in CodeGen25Tokenizer: return None instead of
        raising when an id cannot be decoded."""
        try:
            return super()._convert_id_to_token(index)
        except Exception:
            return None

    def __getstate__(self):
        """Drop the tiktoken encoder (not picklable) so the tokenizer can be pickled."""
        return {**self.__dict__, "encoder": None}

    def __setstate__(self, state):
        """Rebuild the tiktoken encoder from the saved kwargs after unpickling."""
        state["encoder"] = tiktoken_tokenizer(**state["_tiktoken_kwargs"])
        self.__dict__ = state

    def save_pretrained(self, *args, **kwargs):
        """
        self.add_special_tokens resolves to a value that is not JSON
        serializable, which crashes save_pretrained(). Temporarily shadow it
        with a plain bool while tokenizer_config.json is written; the stored
        value does not affect from_pretrained().
        """
        add_special_tokens = self.add_special_tokens
        self.add_special_tokens = True
        super().save_pretrained(*args, **kwargs)
        self.add_special_tokens = add_special_tokens
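

if __name__ == "__main__":
    # Minimal smoke test: a sketch only, assuming network access to fetch the
    # pinned CodeGen2.5 module above. The "Deci/DeciCoder-1b" repo id in the
    # comment at the end is an assumption, not something this file defines.
    import pickle

    tokenizer = DeciCoderTokenizer()

    # Encode/decode round-trip.
    ids = tokenizer("def hello_world():").input_ids
    print(ids, "->", repr(tokenizer.decode(ids)))

    # Pickling exercises __getstate__/__setstate__, which drop and rebuild
    # the tiktoken encoder.
    restored = pickle.loads(pickle.dumps(tokenizer))
    assert restored("def hello_world():").input_ids == ids

    # In a published custom_code repo, the usual entry point would instead be:
    # AutoTokenizer.from_pretrained("Deci/DeciCoder-1b", trust_remote_code=True)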