# dictabert-splinter / tokenization_splinter.py
import json
import re
from typing import List

from tokenizers import NormalizedString, PreTokenizedString
from tokenizers.decoders import Decoder
from tokenizers.pre_tokenizers import PreTokenizer, Sequence as PreTokenizerSequence
from transformers import BertTokenizerFast

from .splinter_json import splinter_data
final_letters_map = {
'ך': 'כ',
'ם': 'מ',
'ץ': 'צ',
'ף': 'פ',
'ן': 'נ',
'כ': 'ך',
'מ': 'ם',
'צ': 'ץ',
'פ': 'ף',
'נ': 'ן'
}
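# This map is bidirectional: final forms ('ך', 'ם', 'ץ', 'ף', 'ן') normalize to their medial
# forms and medial forms map back to finals, so the last letter of a word can be normalized
# before encoding and restored when decoding (e.g. 'שלום' <-> 'שלומ').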
def get_permutation(word, position, word_length):
    """Return `word` with one character removed. A non-negative `position` indexes from the
    start of the word; a negative `position` counts back from `word_length`."""
    if position < 0:
        permutation = word[:word_length + position] + word[(word_length + position + 1):]
    else:
        permutation = word[:position] + word[(position + 1):]
    return permutation
def replace_final_letters(text):
if text == '': return text
if text[-1] in final_letters_map:
return replace_last_letter(text, final_letters_map[text[-1]])
return text
def replace_last_letter(text, replacement):
return text[:-1] + replacement
def is_hebrew_letter(char):
return '\u05D0' <= char <= '\u05EA'
def is_word_contains_non_hebrew_letters(word) -> bool:
    return re.search(r'[^\u05D0-\u05EA]', word) is not None
class Splinter:
    def __init__(self, path, use_cache=True):
        # `path` may be a path to a JSON file or an already-parsed dict (e.g. `splinter_data`)
        if isinstance(path, str):
            with open(path, 'r', encoding='utf-8-sig') as r:
                parsed = json.loads(r.read())
        else:
            parsed = path
self.reductions_map = {int(key): value for key, value in parsed['reductions_map'].items()}
self.new_unicode_chars_map = parsed['new_unicode_chars']
self.new_unicode_chars_inverted_map = {v:k for k,v in self.new_unicode_chars_map.items()}
self.word_reductions_cache = dict()
self.use_cache = use_cache
    def splinter_word(self, word: str):
        if self.use_cache:
            ret = self.word_reductions_cache.get(word)
            if ret is not None:
                return ret
clean_word = replace_final_letters(word)
# if a word contains non-Hebrew characters, convert only the Hebrew.
if len(clean_word) > 15 or is_word_contains_non_hebrew_letters(clean_word):
encoded_word = self.get_word_with_non_hebrew_chars_reduction(clean_word)
else:
word_reductions = self.get_word_reductions(clean_word)
encoded_word = ''.join([self.new_unicode_chars_map[reduction] for reduction in word_reductions])
if self.use_cache:
self.word_reductions_cache[word] = encoded_word
return encoded_word
def unsplinter_word(self, word: str):
decoded_word = self.decode_word(word)
return self.rebuild_reduced_word(decoded_word)
def decode_word(self, word: str):
return [self.new_unicode_chars_inverted_map.get(char, char) for char in word]
def rebuild_reduced_word(self, decoded_word):
original_word = ""
for reduction in decoded_word:
if ':' in reduction and len(reduction) > 1:
position, letter = reduction.split(':')
position = int(position)
if position < 0:
position = len(original_word) + position + 1
                if len(original_word) == position - 1:
                    # the recorded position is just past the current end - fall back to appending the letter
                    original_word += letter
else:
original_word = original_word[:position] + letter + original_word[position:]
else:
original_word += reduction
original_word = replace_final_letters(original_word)
return original_word
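    # Worked example: rebuild_reduced_word(['כ', 'ת', 'ב', '2:ו']) rebuilds the root 'כתב'
    # from the bare characters and then re-inserts 'ו' at position 2, yielding 'כתוב'.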
    def get_word_reductions(self, word):
        # produce one reduction per character of `word`: either 'position:letter' for a letter
        # peeled off the word, or a bare character for the remaining root
        reduced_word = word
        reductions = []
while len(reduced_word) > 3:
# if this word length has no known reductions - return what's left of the word as is
if len(reduced_word) not in self.reductions_map:
reductions.extend(self.get_single_chars_reductions(reduced_word))
break
reduction = self.get_reduction(reduced_word, 3, 3)
if reduction is not None:
position = int(reduction.split(':')[0])
reductions.append(reduction)
reduced_word = get_permutation(reduced_word, position, len(reduced_word))
# if we couldn't find a reduction - return what's left of the word as is
else:
reductions.extend(self.get_single_chars_reductions(reduced_word))
break
# if we found all reductions and left only with the suspected root - keep it as is
if len(reduced_word) < 4:
reductions.extend(self.get_single_chars_reductions(reduced_word))
reductions.reverse()
return reductions
    def get_reduction(self, word, depth, width):
        # explore up to `width` candidate reductions per partial word for at most `depth` steps,
        # and return the first-step reduction on the highest-scoring path
        curr_step_reductions = [{"word": word, "reduction": None, "root_reduction": None, "score": 1}]
word_length = len(word)
i = 0
while i < depth and len(curr_step_reductions) > 0 and word_length > 3:
next_step_reductions = list()
for reduction in curr_step_reductions:
possible_reductions = self.get_most_frequent_reduction_keys(
reduction["word"],
reduction["root_reduction"],
reduction["score"],
width,
word_length
)
next_step_reductions += possible_reductions
curr_step_reductions = list(next_step_reductions)
i += 1
word_length -= 1
max_score_reduction = None
if len(curr_step_reductions) > 0:
max_score_reduction = max(curr_step_reductions, key=lambda x: x["score"])["root_reduction"]
return max_score_reduction
    def get_most_frequent_reduction_keys(self, word, root_reduction, parent_score, number_of_reductions, word_length):
        possible_reductions = list()
        # scan the candidate reductions recorded for words of this length and keep the first
        # `number_of_reductions` whose letter matches the word at the recorded position
        for reduction, score in self.reductions_map[len(word)].items():
position, letter = reduction.split(':')
position = int(position)
if word[position] == letter:
permutation = get_permutation(word, position, word_length)
possible_reductions.append({
"word": permutation,
"reduction": reduction,
"root_reduction": root_reduction if root_reduction is not None else reduction,
"score": parent_score * score
})
if len(possible_reductions) >= number_of_reductions:
break
return possible_reductions
def get_word_with_non_hebrew_chars_reduction(self, word):
return ''.join(self.new_unicode_chars_map[char] if is_hebrew_letter(char) else char for char in word)
    @staticmethod
    def get_single_chars_reductions(reduced_word):
        # emit the remaining (root) characters one by one, last character first
        return list(reduced_word[::-1])
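# Round-trip sketch for the Splinter class above (a minimal illustration; assumes
# `splinter_data` carries the 'reductions_map' and 'new_unicode_chars' keys used there):
#   sp = Splinter(splinter_data)
#   encoded = sp.splinter_word('כתוב')      # word encoded as reduction characters
#   restored = sp.unsplinter_word(encoded)  # expected to recover the original word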
class SplinterPreTokenizer:
def __init__(self, splinter: Splinter):
super().__init__()
self.splinter = splinter
    def splinter_split(self, i: int, normalized_string: NormalizedString):
        # rewrite the word in place: map each original character to the next character of the
        # splintered encoding, pad with spaces if the encoding is shorter, then strip the padding
        splintered_word = iter(self.splinter.splinter_word(normalized_string.normalized))
        normalized_string.map(lambda _: next(splintered_word, ' '))
        normalized_string.strip()
        return [normalized_string]
def pre_tokenize(self, pretok: PreTokenizedString):
pretok.split(self.splinter_split)
class SplinterDecoder:
def __init__(self, splinter: Splinter):
self.splinter = splinter
def decode_chain(self, tokens: List[str]) -> List[str]:
# combine the wordpieces
combined_tokens = []
for token in tokens:
if token.startswith('##') and combined_tokens:
combined_tokens[-1] += token[2:]
else: combined_tokens.append(token)
return [f' {t}' for t in map(self.splinter.unsplinter_word, combined_tokens)]
class SplinterBertTokenizerFast(BertTokenizerFast):
def __init__(self, *args, use_cache=False, **kwargs):
super().__init__(*args, **kwargs)
self.splinter = Splinter(splinter_data, use_cache=use_cache)
        # wrap the existing pre-tokenizer with the splinter pre-tokenizer and install the matching
        # decoder; custom Python components like these cannot be serialized, which is why saving
        # is disabled below
        self._tokenizer.pre_tokenizer = PreTokenizerSequence([
            self._tokenizer.pre_tokenizer,
            PreTokenizer.custom(SplinterPreTokenizer(self.splinter))
        ])
        self._tokenizer.decoder = Decoder.custom(SplinterDecoder(self.splinter))
    def save_pretrained(self, *args, **kwargs):
        self._save_pretrained(*args, **kwargs)

    def _save_pretrained(self, *args, **kwargs):
        print('Cannot save SplinterBertTokenizerFast; please copy the files directly from the repository')
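# Usage sketch (the repository id below is a placeholder - point it at wherever this file is hosted):
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained('dicta-il/dictabert-splinter', trust_remote_code=True)
#   encoding = tokenizer('שלום עולם')
#   print(tokenizer.decode(encoding['input_ids']))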