import os
import random
import string

import numpy as np
import torch
from torch.utils.data import Dataset
class TokenClfDataset(Dataset):
    """Token-classification dataset: tokenizes each text, adds special
    tokens, and pads/truncates to a fixed length."""

    def __init__(
        self,
        texts,
        max_len=512,
        tokenizer=None,
        model_name="bert-base-multilingual-cased",
    ):
        self.len = len(texts)
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.model_name = model_name
        # Special-token strings differ between BERT-style and RoBERTa-style vocabularies.
        if "bert-base-multilingual-cased" in model_name:
            self.cls_token = "[CLS]"
            self.sep_token = "[SEP]"
            self.unk_token = "[UNK]"
            self.pad_token = "[PAD]"
            self.mask_token = "[MASK]"
        elif "xlm-roberta-large" in model_name:
            self.bos_token = "<s>"
            self.eos_token = "</s>"
            self.sep_token = "</s>"
            self.cls_token = "<s>"
            self.unk_token = "<unk>"
            self.pad_token = "<pad>"
            self.mask_token = "<mask>"
        else:
            raise NotImplementedError()
    def __getitem__(self, index):
        text = self.texts[index]
        tokenized_text = self.tokenizer.tokenize(text)

        tokenized_text = (
            [self.cls_token] + tokenized_text + [self.sep_token]
        )  # add special tokens

        # Truncate to max_len, or right-pad with the pad token.
        if len(tokenized_text) > self.max_len:
            tokenized_text = tokenized_text[: self.max_len]
        else:
            tokenized_text = tokenized_text + [
                self.pad_token for _ in range(self.max_len - len(tokenized_text))
            ]

        # Attention mask: 1 for real tokens, 0 for padding.
        attn_mask = [1 if tok != self.pad_token else 0 for tok in tokenized_text]
        ids = self.tokenizer.convert_tokens_to_ids(tokenized_text)

        return {
            "ids": torch.tensor(ids, dtype=torch.long),
            "mask": torch.tensor(attn_mask, dtype=torch.long),
        }

    def __len__(self):
        return self.len
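

# Example usage (a minimal sketch; assumes the Hugging Face `transformers`
# package is available and that the tokenizer matches `model_name`):
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
#   dataset = TokenClfDataset(["Hello world"], max_len=16, tokenizer=tokenizer)
#   sample = dataset[0]
#   sample["ids"].shape   # torch.Size([16])
#   sample["mask"].sum()  # number of non-padding positions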
def seed_everything(seed: int):
    """Seed Python, NumPy, and PyTorch RNGs for reproducible runs."""
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Trade cuDNN autotuning for deterministic kernel selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
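

# Example: call once at startup, before building models or data loaders.
#
#   seed_everything(42)
#
# Note: this fixes the RNG state within one environment; it does not
# guarantee bit-exact results across different hardware or library versions.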
def is_begin_of_new_word(token, model_name, force_tokens, token_map):
    """Return True if `token` starts a new word under the given tokenizer.

    Forced tokens and mapped (added) tokens always count as word beginnings
    so they are never merged into a neighboring word.
    """
    if "bert-base-multilingual-cased" in model_name:
        # WordPiece prefixes word-internal pieces with "##".
        if token.lstrip("##") in force_tokens or token.lstrip("##") in set(
            token_map.values()
        ):
            return True
        return not token.startswith("##")
    elif "xlm-roberta-large" in model_name:
        # SentencePiece prefixes word-initial pieces with "▁" (U+2581).
        if (
            token in string.punctuation
            or token in force_tokens
            or token in set(token_map.values())
        ):
            return True
        return token.startswith("▁")
    else:
        raise NotImplementedError()
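

# Example behavior (a sketch with empty `force_tokens`/`token_map`):
#
#   is_begin_of_new_word("To", "bert-base-multilingual-cased", [], {})     # True
#   is_begin_of_new_word("##ken", "bert-base-multilingual-cased", [], {})  # False
#   is_begin_of_new_word("▁Token", "xlm-roberta-large", [], {})            # True
#   is_begin_of_new_word("izer", "xlm-roberta-large", [], {})              # False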
def replace_added_token(token, token_map):
    """Map each added placeholder token back to the original string it replaced."""
    for ori_token, new_token in token_map.items():
        token = token.replace(new_token, ori_token)
    return token
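

# Example (a sketch; the "<e0>" placeholder name is illustrative, not part
# of this module — `token_map` maps original strings to their placeholders):
#
#   replace_added_token("<e0>", {"\n": "<e0>"})  # "\n"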
def get_pure_token(token, model_name):
    """Strip the tokenizer's subword prefix, returning the bare surface form."""
    if "bert-base-multilingual-cased" in model_name:
        return token.lstrip("##")
    elif "xlm-roberta-large" in model_name:
        return token.lstrip("▁")
    else:
        raise NotImplementedError()
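

# Example (sketch):
#
#   get_pure_token("##ken", "bert-base-multilingual-cased")  # "ken"
#   get_pure_token("▁Token", "xlm-roberta-large")            # "Token"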