|
import regex as re |
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Compiled once at import time; recompiling on every call was wasted work.
# NOTE: the original pattern had a final `|\s?[\r\n]` alternative that was
# unreachable (anything it matches starts with whitespace, so the `\s+`
# alternative before it always matches first) — it has been removed.
_SANSKRIT_SPLIT_PAT = re.compile(r"""
    \ ?[\u0900-\u097F]+        # Devanagari (Sanskrit) word, optional leading space
    |\ ?\d+                    # run of digits, optional leading space
    |\ ?[^\s\u0900-\u097F\d]+  # run of anything else (not space/Devanagari/digit), optional leading space
    |\s+(?!\S)                 # trailing whitespace run (spaces, tabs, newlines)
    |\s+                       # any other whitespace run
""", re.VERBOSE)


def sanskrit_token_preprocessor(text):
    """Split *text* into pre-tokenization chunks for BPE training.

    Each chunk is a Devanagari word, a digit run, or a run of other
    non-space characters (each optionally carrying one leading space),
    or a run of whitespace.

    Args:
        text: Input string.

    Returns:
        list[str]: The matched chunks, in order of appearance.
    """
    return _SANSKRIT_SPLIT_PAT.findall(text)
|
|
|
def bpeAlgo(tokensList, total_runs, newTokenStartValue):
    """Learn BPE merges over *tokensList*.

    Args:
        tokensList: Iterable of token sequences (lists of ints).
        total_runs: Number of merge operations to learn.
        newTokenStartValue: Id assigned to the first learned merge;
            later merges get consecutive ids.

    Returns:
        tuple: ``(paired_tokens_vocab, flat_tokens)`` where
        ``paired_tokens_vocab`` maps ``(tok_a, tok_b) -> merged id`` in the
        order the merges were learned, and ``flat_tokens`` is the fully
        merged corpus flattened into a single list.

    Raises:
        ValueError: If the corpus runs out of distinct pairs before
            *total_runs* merges have been learned.
    """
    paired_tokens_vocab = {}
    newTokensList = list(tokensList)

    for _ in tqdm(range(total_runs), desc="Learning BPE"):
        # Count adjacent-pair frequencies across every token sequence.
        pair_stats = {}
        for tokens in newTokensList:
            pair_stats = get_pair_stats(tokens, pair_stats)

        if not pair_stats:
            print(f"Not enough unique pairs to run BPE's {total_runs} runs")
            raise ValueError(f"Not enough pairs to create a vocab of size {total_runs+256}")

        # Merge the most frequent pair everywhere and record its new id.
        top_pair = max(pair_stats, key=pair_stats.get)
        newTokenVal = len(paired_tokens_vocab) + newTokenStartValue
        newTokensList = [merge(tokens, top_pair, newTokenVal) for tokens in newTokensList]
        paired_tokens_vocab[top_pair] = newTokenVal

    return paired_tokens_vocab, [tok for toks in newTokensList for tok in toks]
|
|
|
|
|
def get_pair_stats(toks, pair_stats):
    """Accumulate counts of adjacent token pairs from *toks* into *pair_stats*.

    Args:
        toks: Sequence of tokens.
        pair_stats: Dict mapping ``(tok_a, tok_b) -> count``; mutated in place.

    Returns:
        dict: The same *pair_stats* dict, updated.
    """
    for left, right in zip(toks, toks[1:]):
        key = (left, right)
        if key in pair_stats:
            pair_stats[key] += 1
        else:
            pair_stats[key] = 1
    return pair_stats
|
|
|
|
|
def merge(toks, pair, newTok):
    """Return a copy of *toks* with every non-overlapping occurrence of
    *pair* (scanned left to right) replaced by *newTok*."""
    out = []
    n = len(toks)
    pos = 0
    while pos < n:
        if pos + 1 < n and toks[pos] == pair[0] and toks[pos + 1] == pair[1]:
            out.append(newTok)
            pos += 2
        else:
            out.append(toks[pos])
            pos += 1
    return out
|
|
|
|
|
def create_vocab(paired_tokens_vocab):
    """Build the id -> bytes vocabulary: the 256 single-byte entries plus
    one concatenated entry per learned merge.

    Relies on *paired_tokens_vocab* iterating in insertion (learning) order,
    so each merge's two components already exist when its entry is built.
    """
    vocab = {}
    for byte_val in range(256):
        vocab[byte_val] = bytes([byte_val])
    for (left, right), tok_id in paired_tokens_vocab.items():
        vocab[tok_id] = vocab[left] + vocab[right]
    return vocab
|
|
|
def save_paired_tokens_vocab(filepath, paired_tokens_vocab):
    """Write the merge table to *filepath*: a version header line followed
    by one ``(a, b):id`` line per learned merge, in insertion order."""
    lines = ["Token version 1\n"]
    lines.extend(f"{pair}:{tok_id}\n" for pair, tok_id in paired_tokens_vocab.items())
    with open(filepath, 'w') as f:
        f.writelines(lines)
|
|
|
def save_vocab(filepath, vocab):
    """Write the id -> bytes vocabulary to *filepath*.

    Format: a version header line, then one ``id<::>repr(bytes)`` line per
    entry. The ``<::>`` separator matches what load_vocab() splits on — the
    previous ``:`` separator produced files load_vocab() could never parse,
    and ``:`` can also legitimately appear inside a bytes repr.
    """
    with open(filepath, 'w') as f:
        f.write('Token version 1\n')
        for tok_id, byte_seq in vocab.items():
            f.write(f"{tok_id}<::>{byte_seq!r}\n")
|
|
|
def load_paired_tokens_vocab(filepath):
    """Read a merge table written by save_paired_tokens_vocab().

    Fixes two defects in the previous version: the "Token version 1" header
    line was not skipped (its split produced one field and the unpack raised
    ValueError), and keys/values were left as raw strings, unusable by
    encode() / create_vocab() which expect tuple keys and int values.

    Returns:
        dict: ``(tok_a, tok_b) -> merged token id``, in file order.
    """
    from ast import literal_eval

    paired_tokens = {}
    with open(filepath, 'r') as f:
        next(f)  # skip the "Token version 1" header line
        for line in f:
            # rsplit guards against any ':' inside the tuple text.
            key_text, value_text = line.rsplit(":", 1)
            paired_tokens[literal_eval(key_text)] = int(value_text)
    return paired_tokens
|
|
|
def load_vocab(filepath):
    """Read an id -> bytes vocabulary in the ``id<::>repr(bytes)`` format.

    Fixes two defects in the previous version: the "Token version 1" header
    line was not skipped (splitting it on "<::>" yields one field, so the
    unpack raised ValueError), and values were parsed with eval(), which
    would execute arbitrary code from a tampered vocab file —
    ast.literal_eval only accepts Python literals.

    Returns:
        dict: token id (int) -> bytes.
    """
    from ast import literal_eval

    vocab = {}
    with open(filepath, 'r') as f:
        next(f)  # skip the "Token version 1" header line
        for line in f:
            key_text, value_text = line.split("<::>", 1)
            vocab[int(key_text)] = literal_eval(value_text)
    return vocab
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def decode(intTokens, vocab):
    """Map each token id through *vocab*, concatenate the resulting byte
    chunks, and decode the whole thing as UTF-8 text."""
    chunks = [vocab[token_id] for token_id in intTokens]
    return b"".join(chunks).decode("utf-8")
|
|
|
def encode(text, paired_tokens_vocab):
    """Encode *text* into BPE token ids.

    Starts from the raw UTF-8 bytes and repeatedly applies the earliest
    learned merge (lowest merged-token id) among the adjacent pairs present,
    stopping when no adjacent pair has a learned merge. The pair-ranking
    scan and the pair replacement are done inline here.
    """
    tokens = list(text.encode('utf-8'))

    while len(tokens) > 1:
        # Find the adjacent pair with the lowest merge id; pairs the
        # vocab does not know rank as infinity.
        best_pair = None
        best_rank = float("inf")
        for candidate in zip(tokens, tokens[1:]):
            rank = paired_tokens_vocab.get(candidate, float("inf"))
            if rank < best_rank:
                best_rank = rank
                best_pair = candidate

        # No learned merge applies anywhere — encoding is complete.
        if best_pair is None:
            return tokens

        # Replace every non-overlapping occurrence of best_pair, left to right.
        merged_id = paired_tokens_vocab[best_pair]
        rebuilt = []
        pos = 0
        while pos < len(tokens):
            if pos < len(tokens) - 1 and (tokens[pos], tokens[pos + 1]) == best_pair:
                rebuilt.append(merged_id)
                pos += 2
            else:
                rebuilt.append(tokens[pos])
                pos += 1
        tokens = rebuilt

    return tokens