import json
import re
from typing import List

from tokenizers import NormalizedString, PreTokenizedString
from tokenizers.decoders import Decoder
from tokenizers.pre_tokenizers import PreTokenizer, Sequence as PreTokenizerSequence
from transformers import BertTokenizerFast

from .splinter_json import splinter_data


# Maps Hebrew final letter forms to their regular forms and vice versa, so the same
# table serves both encoding (final -> regular) and decoding (regular -> final).
final_letters_map = {
    'ך': 'כ',
    'ם': 'מ',
    'ץ': 'צ',
    'ף': 'פ',
    'ן': 'נ',
    'כ': 'ך',
    'מ': 'ם',
    'צ': 'ץ',
    'פ': 'ף',
    'נ': 'ן'
}


def get_permutation(word, position, word_length):
    """Return `word` with the character at `position` removed (negative positions count from the end)."""
    if position < 0:
        return word[:word_length + position] + word[word_length + position + 1:]
    return word[:position] + word[position + 1:]


def replace_final_letters(text):
    """If `text` ends with a letter in `final_letters_map`, swap it for its counterpart."""
    if text == '':
        return text
    if text[-1] in final_letters_map:
        return replace_last_letter(text, final_letters_map[text[-1]])
    return text


def replace_last_letter(text, replacement):
    return text[:-1] + replacement


def is_hebrew_letter(char):
    return '\u05D0' <= char <= '\u05EA'


def is_word_contains_non_hebrew_letters(word) -> bool:
    return re.search(r'[^\u05D0-\u05EA]', word) is not None


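# Examples of the helpers above (values follow directly from the definitions):
#   get_permutation('abcd', 1, 4)  -> 'acd'    # drops the character at index 1
#   replace_final_letters('שלום')  -> 'שלומ'   # trailing final mem becomes a regular mem
#   is_hebrew_letter('a')          -> False

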
class Splinter:
    """Encodes Hebrew words into a reduced character alphabet ("splintering") and decodes them back."""

    def __init__(self, path, use_cache=True):
        # `path` may be a filesystem path to the splinter JSON file or an already-parsed dict.
        if isinstance(path, str):
            with open(path, 'r', encoding='utf-8-sig') as r:
                parsed = json.load(r)
        else:
            parsed = path

        self.reductions_map = {int(key): value for key, value in parsed['reductions_map'].items()}
        self.new_unicode_chars_map = parsed['new_unicode_chars']
        self.new_unicode_chars_inverted_map = {v: k for k, v in self.new_unicode_chars_map.items()}
        self.word_reductions_cache = {}
        self.use_cache = use_cache

    def splinter_word(self, word: str):
        if self.use_cache:
            cached = self.word_reductions_cache.get(word)
            if cached is not None:
                return cached

        clean_word = replace_final_letters(word)

        if len(clean_word) > 15 or is_word_contains_non_hebrew_letters(clean_word):
            encoded_word = self.get_word_with_non_hebrew_chars_reduction(clean_word)
        else:
            word_reductions = self.get_word_reductions(clean_word)
            encoded_word = ''.join([self.new_unicode_chars_map[reduction] for reduction in word_reductions])
        if self.use_cache:
            self.word_reductions_cache[word] = encoded_word
        return encoded_word

    def unsplinter_word(self, word: str):
        decoded_word = self.decode_word(word)
        return self.rebuild_reduced_word(decoded_word)

    def decode_word(self, word: str):
        return [self.new_unicode_chars_inverted_map.get(char, char) for char in word]

    def rebuild_reduced_word(self, decoded_word):
        original_word = ""
        for reduction in decoded_word:
            if ':' in reduction and len(reduction) > 1:
                # A "position:letter" reduction: re-insert the letter at its original position.
                position, letter = reduction.split(':')
                position = int(position)
                if position < 0:
                    position = len(original_word) + position + 1
                if len(original_word) == position - 1:
                    # Position points past the current word; keep the raw reduction token as-is.
                    original_word += reduction
                else:
                    original_word = original_word[:position] + letter + original_word[position:]
            else:
                original_word += reduction

        original_word = replace_final_letters(original_word)
        return original_word

    def get_word_reductions(self, word):
        reduced_word = word
        reductions = []
        # Peel off one letter at a time (chosen via the reductions map) until the word is at
        # most 3 letters long, then emit whatever remains as single-character reductions.
        while len(reduced_word) > 3:
            if len(reduced_word) not in self.reductions_map:
                reductions.extend(self.get_single_chars_reductions(reduced_word))
                break
            reduction = self.get_reduction(reduced_word, 3, 3)
            if reduction is not None:
                position = int(reduction.split(':')[0])
                reductions.append(reduction)
                reduced_word = get_permutation(reduced_word, position, len(reduced_word))
            else:
                reductions.extend(self.get_single_chars_reductions(reduced_word))
                break

        if len(reduced_word) < 4:
            reductions.extend(self.get_single_chars_reductions(reduced_word))

        reductions.reverse()
        return reductions

    def get_reduction(self, word, depth, width):
        # Beam search over up to `depth` successive single-letter removals, expanding at most
        # `width` candidates per node; returns the first-step reduction of the best-scoring path.
        curr_step_reductions = [{"word": word, "reduction": None, "root_reduction": None, "score": 1}]
        word_length = len(word)
        i = 0
        while i < depth and len(curr_step_reductions) > 0 and word_length > 3:
            next_step_reductions = list()
            for reduction in curr_step_reductions:
                possible_reductions = self.get_most_frequent_reduction_keys(
                    reduction["word"],
                    reduction["root_reduction"],
                    reduction["score"],
                    width,
                    word_length
                )
                next_step_reductions += possible_reductions
            curr_step_reductions = list(next_step_reductions)
            i += 1
            word_length -= 1

        max_score_reduction = None
        if len(curr_step_reductions) > 0:
            max_score_reduction = max(curr_step_reductions, key=lambda x: x["score"])["root_reduction"]
        return max_score_reduction

    def get_most_frequent_reduction_keys(self, word, root_reduction, parent_score, number_of_reductions, word_length):
        # Reduction keys for this word length are assumed to be ordered by frequency, so the
        # first `number_of_reductions` keys that match the word are kept.
        possible_reductions = list()
        for reduction, score in self.reductions_map[len(word)].items():
            position, letter = reduction.split(':')
            position = int(position)
            if word[position] == letter:
                permutation = get_permutation(word, position, word_length)
                possible_reductions.append({
                    "word": permutation,
                    "reduction": reduction,
                    "root_reduction": root_reduction if root_reduction is not None else reduction,
                    "score": parent_score * score
                })
            if len(possible_reductions) >= number_of_reductions:
                break
        return possible_reductions

    def get_word_with_non_hebrew_chars_reduction(self, word):
        # Fallback encoding: map each Hebrew letter individually and pass other characters through.
        return ''.join(self.new_unicode_chars_map[char] if is_hebrew_letter(char) else char for char in word)

    @staticmethod
    def get_single_chars_reductions(reduced_word):
        # Emit the remaining letters as single-character reductions, last letter first.
        return list(reduced_word[::-1])


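# Illustrative use of the Splinter class above (a sketch; the concrete encoded characters
# depend on the `splinter_data` mapping, so no exact output is shown):
#
#   sp = Splinter(splinter_data, use_cache=True)
#   encoded = sp.splinter_word('שלום')     # word written in the reduced alphabet
#   decoded = sp.unsplinter_word(encoded)  # rebuilds the original surface form

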
class SplinterPreTokenizer:
    def __init__(self, splinter: Splinter):
        super().__init__()
        self.splinter = splinter

    def splinter_split(self, i: int, normalized: NormalizedString):
        # Rewrite the normalized string in place: each original character is replaced by the next
        # character of the splintered encoding, padding with spaces that are then stripped.
        splintered_word = iter(self.splinter.splinter_word(normalized.normalized))
        normalized.map(lambda _: next(splintered_word, ' '))
        normalized.strip()
        return [normalized]

    def pre_tokenize(self, pretok: PreTokenizedString):
        pretok.split(self.splinter_split)


class SplinterDecoder:
    def __init__(self, splinter: Splinter):
        self.splinter = splinter

    def decode_chain(self, tokens: List[str]) -> List[str]:
        # Merge WordPiece continuation tokens ('##...') back into whole words before unsplintering.
        combined_tokens = []
        for token in tokens:
            if token.startswith('##') and combined_tokens:
                combined_tokens[-1] += token[2:]
            else:
                combined_tokens.append(token)

        return [f' {t}' for t in map(self.splinter.unsplinter_word, combined_tokens)]


class SplinterBertTokenizerFast(BertTokenizerFast):
    def __init__(self, *args, use_cache=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.splinter = Splinter(splinter_data, use_cache=use_cache)
        # Run the splinter pre-tokenizer after the tokenizer's existing pre-tokenizer, and
        # decode through the splinter-aware decoder.
        self._tokenizer.pre_tokenizer = PreTokenizerSequence([
            self._tokenizer.pre_tokenizer,
            PreTokenizer.custom(SplinterPreTokenizer(self.splinter))
        ])
        self._tokenizer.decoder = Decoder.custom(SplinterDecoder(self.splinter))

    def save_pretrained(self, *args, **kwargs):
        self._save_pretrained(*args, **kwargs)

    def _save_pretrained(self, *args, **kwargs):
        # The custom pre-tokenizer and decoder cannot be serialized, so saving is disabled.
        print('Cannot save SplinterBertTokenizerFast, please copy the files directly from the repository')
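
# Minimal usage sketch (not part of the original module). The checkpoint path below is a
# placeholder and is assumed to point at a directory holding this tokenizer's vocab/config files.
if __name__ == '__main__':
    tokenizer = SplinterBertTokenizerFast.from_pretrained('path/to/splinter-tokenizer')  # hypothetical path
    enc = tokenizer('שלום עולם')
    print(enc['input_ids'])
    print(tokenizer.decode(enc['input_ids'], skip_special_tokens=True))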
|