from typing import Dict, List, Union
class OpenPeerTokenizer:
    """Simple tokenizer implementation for testing."""

    def __init__(self, unk_token="<|endoftext|>",
                 bos_token="<|endoftext|>",
                 eos_token="<|endoftext|>",
                 pad_token="<|endoftext|>"):
        self.unk_token = unk_token
        self.bos_token = bos_token
        self.eos_token = eos_token
        self.pad_token = pad_token
        # Build the vocabulary first, then derive the eos id from it rather
        # than hardcoding it, so the two cannot drift apart.
        self.vocab = self._get_default_vocab()
        self.vocab_size = len(self.vocab)
        self.eos_token_id = self.vocab[self.eos_token]
    def _get_default_vocab(self) -> Dict[str, int]:
        """Get a basic default vocabulary."""
        vocab = {}
        # Add special tokens. unk_token and pad_token may be the same string
        # (e.g. "<|endoftext|>"), so assign ids with len(vocab) instead of
        # hardcoding 0/1/2; otherwise the duplicate key overwrites its own id
        # and later entries collide with "<|mask|>".
        for token in (self.unk_token, self.pad_token, "<|mask|>"):
            if token not in vocab:
                vocab[token] = len(vocab)
        # Add printable ASCII characters
        for i in range(32, 127):
            vocab[chr(i)] = len(vocab)
        # Add some common words, skipping any already present (e.g. "a" is
        # already in the vocab as a single ASCII character).
        common_words = ["the", "be", "to", "of", "and", "a", "in", "that", "have"]
        for word in common_words:
            if word not in vocab:
                vocab[word] = len(vocab)
        return vocab
    def _encode(self, text: str) -> List[int]:
        """Encode one string: whole words when in the vocab, else per character."""
        tokens = []
        for word in text.split():
            if word in self.vocab:
                tokens.append(self.vocab[word])
            else:
                for char in word:
                    tokens.append(self.vocab.get(char, self.vocab[self.unk_token]))
        return tokens

    def __call__(self, text: Union[str, List[str]], **kwargs) -> Dict[str, List[int]]:
        """Tokenize a single string or a batch of strings."""
        if isinstance(text, str):
            tokens = self._encode(text)
            return {"input_ids": tokens, "attention_mask": [1] * len(tokens)}
        tokens = [self._encode(t) for t in text]
        return {"input_ids": tokens, "attention_mask": [[1] * len(t) for t in tokens]}
    def decode(self, token_ids: Union[List[int], List[List[int]]],
               skip_special_tokens: bool = True) -> Union[str, List[str]]:
        """Decode token ids (a single sequence or a batch) back to text."""
        # Create reverse vocab mapping
        id_to_token = {v: k for k, v in self.vocab.items()}
        special_tokens = {self.unk_token, self.pad_token, "<|mask|>"}

        def decode_one(ids: List[int]) -> str:
            text = []
            # Avoid shadowing the builtin `id`
            for token_id in ids:
                token = id_to_token.get(token_id, self.unk_token)
                if not skip_special_tokens or token not in special_tokens:
                    text.append(token)
            return " ".join(text)

        # Guard against empty input before inspecting the first element
        if token_ids and isinstance(token_ids[0], list):
            # Batch decoding
            return [decode_one(ids) for ids in token_ids]
        # Single sequence decoding
        return decode_one(token_ids)