from typing import List
import json
import re

from tokenizers import NormalizedString, PreTokenizedString
from tokenizers.pre_tokenizers import PreTokenizer, Sequence as PreTokenizerSequence
from tokenizers.decoders import Decoder
from transformers import BertTokenizerFast

from .splinter_json import splinter_data

# Hebrew final letters mapped to their regular forms and back, so the same helper
# both normalizes input words and restores the final form on output.
final_letters_map = {
    'ך': 'כ', 'ם': 'מ', 'ץ': 'צ', 'ף': 'פ', 'ן': 'נ',
    'כ': 'ך', 'מ': 'ם', 'צ': 'ץ', 'פ': 'ף', 'נ': 'ן'
}


def get_permutation(word, position, word_length):
    """Return `word` with the character at `position` removed (negative positions count from the end)."""
    if position < 0:
        permutation = word[:word_length + position] + word[(word_length + position + 1):]
    else:
        permutation = word[:position] + word[(position + 1):]
    return permutation


def replace_final_letters(text):
    """Swap the last letter of `text` between its final and regular form, if applicable."""
    if text == '':
        return text
    if text[-1] in final_letters_map:
        return replace_last_letter(text, final_letters_map[text[-1]])
    return text


def replace_last_letter(text, replacement):
    return text[:-1] + replacement


def is_hebrew_letter(char):
    return '\u05D0' <= char <= '\u05EA'


def is_word_contains_non_hebrew_letters(word) -> bool:
    return re.search(r'[^\u05D0-\u05EA]', word) is not None


class Splinter:
    """Encodes ("splinters") Hebrew words into single reduction characters and decodes them back."""

    def __init__(self, path, use_cache=True):
        if isinstance(path, str):
            with open(path, 'r', encoding='utf-8-sig') as r:
                parsed = json.loads(r.read())
        else:
            parsed = path
        self.reductions_map = {int(key): value for key, value in parsed['reductions_map'].items()}
        self.new_unicode_chars_map = parsed['new_unicode_chars']
        self.new_unicode_chars_inverted_map = {v: k for k, v in self.new_unicode_chars_map.items()}
        self.word_reductions_cache = dict()
        self.use_cache = use_cache

    def splinter_word(self, word: str):
        if self.use_cache:
            cached = self.word_reductions_cache.get(word)
            if cached is not None:
                return cached
        clean_word = replace_final_letters(word)
        # If the word is too long or contains non-Hebrew characters, convert only the Hebrew letters.
        if len(clean_word) > 15 or is_word_contains_non_hebrew_letters(clean_word):
            encoded_word = self.get_word_with_non_hebrew_chars_reduction(clean_word)
        else:
            word_reductions = self.get_word_reductions(clean_word)
            encoded_word = ''.join([self.new_unicode_chars_map[reduction] for reduction in word_reductions])
        if self.use_cache:
            self.word_reductions_cache[word] = encoded_word
        return encoded_word

    def unsplinter_word(self, word: str):
        decoded_word = self.decode_word(word)
        return self.rebuild_reduced_word(decoded_word)

    def decode_word(self, word: str):
        return [self.new_unicode_chars_inverted_map.get(char, char) for char in word]

    def rebuild_reduced_word(self, decoded_word):
        original_word = ""
        for reduction in decoded_word:
            if ':' in reduction and len(reduction) > 1:
                position, letter = reduction.split(':')
                position = int(position)
                if position < 0:
                    position = len(original_word) + position + 1
                if len(original_word) == position - 1:
                    original_word += reduction
                else:
                    original_word = original_word[:position] + letter + original_word[position:]
            else:
                original_word += reduction
        original_word = replace_final_letters(original_word)
        return original_word

    def get_word_reductions(self, word):
        reduced_word = word
        reductions = []
        while len(reduced_word) > 3:
            # if this word length has no known reductions - return what's left of the word as is
            if len(reduced_word) not in self.reductions_map:
                reductions.extend(self.get_single_chars_reductions(reduced_word))
                break
            reduction = self.get_reduction(reduced_word, 3, 3)
            if reduction is not None:
                position = int(reduction.split(':')[0])
                reductions.append(reduction)
                reduced_word = get_permutation(reduced_word, position, len(reduced_word))
            # if we couldn't find a reduction - return what's left of the word as is
            else:
                reductions.extend(self.get_single_chars_reductions(reduced_word))
                break
        # if we found all reductions and are left only with the suspected root - keep it as is
        if len(reduced_word) < 4:
            reductions.extend(self.get_single_chars_reductions(reduced_word))
        reductions.reverse()
        return reductions

    def get_reduction(self, word, depth, width):
        """Look ahead up to `depth` reductions, keeping `width` candidates per step,
        and return the first reduction of the highest-scoring path."""
        curr_step_reductions = [{"word": word, "reduction": None, "root_reduction": None, "score": 1}]
        word_length = len(word)
        i = 0
        while i < depth and len(curr_step_reductions) > 0 and word_length > 3:
            next_step_reductions = list()
            for reduction in curr_step_reductions:
                possible_reductions = self.get_most_frequent_reduction_keys(
                    reduction["word"], reduction["root_reduction"], reduction["score"], width, word_length
                )
                next_step_reductions += possible_reductions
            curr_step_reductions = list(next_step_reductions)
            i += 1
            word_length -= 1
        max_score_reduction = None
        if len(curr_step_reductions) > 0:
            max_score_reduction = max(curr_step_reductions, key=lambda x: x["score"])["root_reduction"]
        return max_score_reduction

    def get_most_frequent_reduction_keys(self, word, root_reduction, parent_score, number_of_reductions, word_length):
        """Return up to `number_of_reductions` applicable reductions for `word`, most frequent first."""
        possible_reductions = list()
        for reduction, score in self.reductions_map[len(word)].items():
            position, letter = reduction.split(':')
            position = int(position)
            if word[position] == letter:
                permutation = get_permutation(word, position, word_length)
                possible_reductions.append({
                    "word": permutation,
                    "reduction": reduction,
                    "root_reduction": root_reduction if root_reduction is not None else reduction,
                    "score": parent_score * score
                })
                if len(possible_reductions) >= number_of_reductions:
                    break
        return possible_reductions

    def get_word_with_non_hebrew_chars_reduction(self, word):
        return ''.join(self.new_unicode_chars_map[char] if is_hebrew_letter(char) else char for char in word)

    @staticmethod
    def get_single_chars_reductions(reduced_word):
        reductions = []
        for char in reduced_word[::-1]:
            reductions.append(char)
        return reductions


class SplinterPreTokenizer:
    def __init__(self, splinter: Splinter):
        super().__init__()
        self.splinter = splinter

    def splinter_split(self, i: int, normalized_string: NormalizedString):
        # Replace each character of the word with the next character of its splintered
        # encoding, padding with spaces if the encoding is shorter and stripping afterwards.
        splintered_word = iter(self.splinter.splinter_word(normalized_string.normalized))
        normalized_string.map(lambda _: next(splintered_word, ' '))
        normalized_string.strip()
        return [normalized_string]

    def pre_tokenize(self, pretok: PreTokenizedString):
        pretok.split(self.splinter_split)


class SplinterDecoder:
    def __init__(self, splinter: Splinter):
        self.splinter = splinter

    def decode_chain(self, tokens: List[str]) -> List[str]:
        # Combine the wordpieces back into whole words, then unsplinter each word.
        combined_tokens = []
        for token in tokens:
            if token.startswith('##') and combined_tokens:
                combined_tokens[-1] += token[2:]
            else:
                combined_tokens.append(token)
        return [f' {t}' for t in map(self.splinter.unsplinter_word, combined_tokens)]


class SplinterBertTokenizerFast(BertTokenizerFast):
    def __init__(self, *args, use_cache=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.splinter = Splinter(splinter_data, use_cache=use_cache)
        self._tokenizer.pre_tokenizer = PreTokenizerSequence([
            self._tokenizer.pre_tokenizer,
            PreTokenizer.custom(SplinterPreTokenizer(self.splinter))
        ])
        self._tokenizer.decoder = Decoder.custom(SplinterDecoder(self.splinter))

    def save_pretrained(self, *args, **kwargs):
        self._save_pretrained(*args, **kwargs)

    def _save_pretrained(self, *args, **kwargs):
        print('Cannot save SplinterBertTokenizerFast, please copy the files directly from the repository')
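

# --- Usage sketch (illustrative only, not part of the module) ---
# A minimal example of how this module might be used; "path/to/splinter-bert" is a
# placeholder for a directory containing a compatible BERT vocabulary, not a
# checkpoint shipped with this file.
#
#   tokenizer = SplinterBertTokenizerFast.from_pretrained("path/to/splinter-bert")
#   encoding = tokenizer("שלום עולם")  # Hebrew words are splintered before WordPiece
#   text = tokenizer.decode(encoding["input_ids"], skip_special_tokens=True)
#
# The Splinter encoder can also be exercised on its own:
#
#   splinter = Splinter(splinter_data)
#   encoded = splinter.splinter_word("כלבים")    # encode a single word
#   decoded = splinter.unsplinter_word(encoded)  # rebuild the original word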