import json
import os
from typing import List, Optional

from transformers import PreTrainedTokenizerFast


class TessarTokenizer(PreTrainedTokenizerFast):
    """
    Tessar Tokenizer implementation for Hugging Face Transformers.
    """

    model_input_names = ['input_ids', 'attention_mask']

    def __init__(
        self,
        vocab_file=None,
        tokenizer_file=None,
        do_lower_case=True,
        unk_token="<unk>",
        sep_token="</s>",
        pad_token="<pad>",
        cls_token="<s>",
        mask_token="<mask>",
        bos_token="<s>",
        eos_token="</s>",
        max_cell_length=15,
        **kwargs,
    ):
""" |
|
Initialize the Tessar Tokenizer with specific token configurations |
|
|
|
Args: |
|
vocab_file (str, optional): Path to the vocabulary file |
|
tokenizer_file (str, optional): Path to the pre-trained tokenizer file |
|
do_lower_case (bool, optional): Whether to lowercase the input. Defaults to True. |
|
max_cell_length (int, optional): Maximum length for cell tokenization. Defaults to 15. |
|
""" |
|
|
|
        # Collect the special tokens, dropping any that were explicitly set to
        # None so they are not passed on to the base class.
        special_tokens = {
            "unk_token": unk_token,
            "sep_token": sep_token,
            "pad_token": pad_token,
            "cls_token": cls_token,
            "mask_token": mask_token,
            "bos_token": bos_token,
            "eos_token": eos_token,
        }
        special_tokens = {k: v for k, v in special_tokens.items() if v is not None}

        # Note: PreTrainedTokenizerFast requires a backend tokenizer, so in
        # practice either tokenizer_file or a tokenizer_object kwarg must be
        # available when instantiating this class directly.
        super().__init__(
            vocab_file=vocab_file,
            tokenizer_file=tokenizer_file,
            do_lower_case=do_lower_case,
            **special_tokens,
            **kwargs,
        )

        self.do_lower_case = do_lower_case
        self.max_cell_length = max_cell_length

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple:
        """
        Save the tokenizer vocabulary and special tokens file.

        Args:
            save_directory (str): Directory to save the vocabulary.
            filename_prefix (str, optional): Prefix for the saved files.

        Returns:
            tuple: Paths to the saved files.
        """
        os.makedirs(save_directory, exist_ok=True)

        vocab_file = os.path.join(
            save_directory,
            f"{filename_prefix + '-' if filename_prefix else ''}vocab.json"
        )
        special_tokens_file = os.path.join(
            save_directory,
            f"{filename_prefix + '-' if filename_prefix else ''}special_tokens.json"
        )

        # Write the vocabulary; get_vocab() is the canonical accessor across
        # tokenizer implementations.
        with open(vocab_file, 'w', encoding='utf-8') as f:
            json.dump(self.get_vocab(), f, ensure_ascii=False, indent=2)

        # Persist the special tokens together with the Tessar-specific settings.
        special_tokens_config = {
            "unk_token": self.unk_token,
            "sep_token": self.sep_token,
            "pad_token": self.pad_token,
            "cls_token": self.cls_token,
            "mask_token": self.mask_token,
            "bos_token": self.bos_token,
            "eos_token": self.eos_token,
            "do_lower_case": self.do_lower_case,
            "max_cell_length": self.max_cell_length,
        }
        with open(special_tokens_file, 'w', encoding='utf-8') as f:
            json.dump(special_tokens_config, f, ensure_ascii=False, indent=2)

        return (vocab_file, special_tokens_file)

    def tokenize(self, text: str, **kwargs) -> List[str]:
        """
        Custom tokenization method.

        Args:
            text (str): Input text to tokenize

        Returns:
            List[str]: List of tokens
        """
        # Fast tokenizers do not route through `_tokenize`, so the public
        # `tokenize` method is overridden instead.
        if self.do_lower_case:
            text = text.lower()

        tokens = super().tokenize(text, **kwargs)

        # Truncate each cell to at most `max_cell_length` tokens.
        tokens = tokens[:self.max_cell_length]

        return tokens

    def prepare_for_model(
        self,
        ids: List[int],
        pair_ids: Optional[List[int]] = None,
        **kwargs,
    ) -> dict:
        """
        Prepare tokenized inputs for the model.

        Args:
            ids (List[int]): List of input token ids.
            pair_ids (Optional[List[int]], optional): List of pair token ids.

        Returns:
            dict: Prepared model inputs.
        """
        return super().prepare_for_model(ids, pair_ids, **kwargs)


def load_tessar_tokenizer(pretrained_model_name_or_path: str):
    """
    Load a pretrained Tessar tokenizer.

    Args:
        pretrained_model_name_or_path (str): Path to the pretrained model.

    Returns:
        TessarTokenizer: Initialized tokenizer.
    """
    return TessarTokenizer.from_pretrained(pretrained_model_name_or_path)
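

# --- Usage sketch ---------------------------------------------------------
# A minimal, illustrative example of how this module is intended to be used.
# The checkpoint path below is hypothetical; any directory containing a
# tokenizer.json (plus the usual Hugging Face tokenizer config files) would
# work in its place.
if __name__ == "__main__":
    tokenizer = load_tessar_tokenizer("path/to/tessar-checkpoint")  # hypothetical path

    # Lowercased (when do_lower_case=True) and truncated to `max_cell_length` tokens.
    tokens = tokenizer.tokenize("Total revenue for fiscal year 2023")
    print(tokens)

    # Encoding a cell for the model returns `input_ids` and `attention_mask`.
    encoded = tokenizer("Total revenue for fiscal year 2023")
    print(encoded["input_ids"])

    # Persist the vocabulary and special-token configuration to a local directory.
    vocab_path, special_tokens_path = tokenizer.save_vocabulary("./tessar_tokenizer")
    print(vocab_path, special_tokens_path)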