import ast

import regex as re
from tqdm import tqdm
# Pre-tokenization pattern, compiled once at import time instead of on every
# call. Alternatives are tried left to right. U+0900-U+097F is the Unicode
# Devanagari block.
_SANSKRIT_TOKEN_PATTERN = re.compile(r"""
      \ ?[\u0900-\u097F]+        # Devanagari run, optionally preceded by one space
    | \ ?\d+                     # digit run, optionally preceded by one space
    | \ ?[^\s\u0900-\u097F\d]+   # run of anything that is not whitespace, Devanagari or a digit, optionally preceded by one space
    | \s+(?!\S)                  # whitespace that is not directly followed by a non-whitespace character
    | \s+                        # any remaining whitespace (spaces, tabs, newlines)
""", re.VERBOSE)


def sanskrit_token_preprocessor(text):
    """Split ``text`` into the chunks that BPE merges are learned within.

    Every character of ``text`` ends up in exactly one chunk, so joining the
    result reproduces the input. A single leading space is kept attached to
    the word, number or symbol run that follows it.

    The former trailing ``\\s?[\\r\\n]`` alternative was removed: it could
    never be selected because ``\\s+`` already matches at every position
    where it could, so the produced chunks are unchanged.

    Args:
        text: The string to split.

    Returns:
        A list of non-overlapping substrings of ``text`` in order of
        appearance (empty list for an empty string).
    """
    return _SANSKRIT_TOKEN_PATTERN.findall(text)
def bpeAlgo(tokensList, total_runs, newTokenStartValue):
    """Learn ``total_runs`` BPE merges over a collection of token sequences.

    Each run counts adjacent pairs within every sequence (pairs never span
    two sequences), picks the most frequent pair, and replaces it everywhere
    with a new token id. Ids are assigned consecutively starting at
    ``newTokenStartValue``.

    Args:
        tokensList: Iterable of token-id sequences (one per text chunk).
        total_runs: Number of merges to learn.
        newTokenStartValue: Id given to the first merged token.

    Returns:
        A tuple ``(paired_tokens_vocab, flat_tokens)`` where
        ``paired_tokens_vocab`` maps each merged pair to its new token id, in
        the order the merges were learned, and ``flat_tokens`` is all
        sequences after merging, concatenated into one list.

    Raises:
        ValueError: If no adjacent pairs remain before ``total_runs`` merges
            have been learned.
    """
    paired_tokens_vocab = {}
    newTokensList = list(tokensList)
    for run_index in tqdm(range(total_runs), desc="Learning BPE"):
        pair_stats = {}
        for tokens in newTokensList:
            pair_stats = get_pair_stats(tokens, pair_stats)
        if not pair_stats:
            # Report the vocab size actually requested instead of assuming
            # that new ids start at 256.
            raise ValueError(
                f"Not enough pairs to run BPE's {total_runs} runs "
                f"(stopped after {run_index}); cannot create a vocab of size "
                f"{total_runs + newTokenStartValue}"
            )
        # On ties max() keeps the pair that was counted first, which keeps
        # training deterministic for a given input order.
        top_pair = max(pair_stats, key=pair_stats.get)
        # One merge is added per completed run, so this equals
        # len(paired_tokens_vocab) + newTokenStartValue.
        newTokenVal = newTokenStartValue + run_index
        newTokensList = [merge(tokens, top_pair, newTokenVal) for tokens in newTokensList]
        paired_tokens_vocab[top_pair] = newTokenVal
    return paired_tokens_vocab, [tok for toks in newTokensList for tok in toks]
def get_pair_stats(toks, pair_stats):
    """Add the adjacent-pair counts of ``toks`` to ``pair_stats``.

    Args:
        toks: Sliceable sequence of tokens.
        pair_stats: Dict mapping ``(left, right)`` pairs to counts. It is
            updated in place, so counts accumulate across calls.

    Returns:
        The same ``pair_stats`` object that was passed in.
    """
    for adjacent in zip(toks[:-1], toks[1:]):
        if adjacent in pair_stats:
            pair_stats[adjacent] += 1
        else:
            pair_stats[adjacent] = 1
    return pair_stats
def merge(toks, pair, newTok):
    """Replace every non-overlapping occurrence of ``pair`` with ``newTok``.

    The sequence is scanned left to right, so in ``[a, a, a]`` with pair
    ``(a, a)`` only the first two elements are merged.

    Args:
        toks: Indexable sequence of tokens.
        pair: Two-element tuple ``(left, right)`` to look for.
        newTok: Token that replaces each occurrence of the pair.

    Returns:
        A new list; ``toks`` itself is not modified.
    """
    merged = []
    final_index = len(toks) - 1
    consumed_by_previous = False
    for index, tok in enumerate(toks):
        if consumed_by_previous:
            # This element was the right half of the pair merged one step ago.
            consumed_by_previous = False
            continue
        if index < final_index and (tok, toks[index + 1]) == pair:
            merged.append(newTok)
            consumed_by_previous = True
        else:
            merged.append(tok)
    return merged
def create_vocab(paired_tokens_vocab):
    """Build the token-id -> bytes table from the learned merges.

    Ids 0-255 map to their single byte. Each merged id maps to the
    concatenation of the byte strings of its two parts, so the merges must be
    iterated in the order they were learned (a part has to be in the table
    before it is used).

    Args:
        paired_tokens_vocab: Dict mapping ``(left_id, right_id)`` to the
            merged token id.

    Returns:
        Dict mapping every token id to its byte string.
    """
    vocab = {}
    for byte_value in range(256):
        vocab[byte_value] = bytes([byte_value])
    for pair, token_id in paired_tokens_vocab.items():
        left, right = pair
        vocab[token_id] = vocab[left] + vocab[right]
    return vocab
def save_paired_tokens_vocab(filepath, paired_tokens_vocab):
    """Write the learned merges to ``filepath`` as text.

    The file starts with the header line ``Token version 1`` followed by one
    ``<pair>:<token id>`` line per merge, in the dict's iteration order. Any
    existing file is overwritten.

    Args:
        filepath: Destination path.
        paired_tokens_vocab: Dict mapping merged pairs to their token ids.
    """
    with open(filepath, 'w') as out:
        out.write('Token version 1\n')
        out.writelines(
            f"{pair}:{token_id}\n" for pair, token_id in paired_tokens_vocab.items()
        )
def save_vocab(filepath, vocab):
    """Write the token-id -> value table to ``filepath`` as text.

    The file starts with the header line ``Token version 1`` followed by one
    ``<token id>:<value>`` line per entry, in the dict's iteration order,
    with the value formatted the way an f-string formats it. Any existing
    file is overwritten.

    Args:
        filepath: Destination path.
        vocab: Dict mapping token ids to their values.
    """
    with open(filepath, 'w') as out:
        out.write('Token version 1\n')
        out.writelines(f"{token_id}:{value}\n" for token_id, value in vocab.items())
def load_paired_tokens_vocab(filepath):
    """Read merges from a text file of ``(left, right):token_id`` lines.

    A ``Token version ...`` header line and blank lines are skipped. Keys are
    parsed back into ``(int, int)`` tuples and values into ``int`` so the
    result can be used directly for pair lookups; previously the header line
    made the load fail and entries came back as raw strings.

    Args:
        filepath: Path of the merges file.

    Returns:
        Dict mapping ``(left_id, right_id)`` to the merged token id, in file
        order.

    Raises:
        ValueError: If a non-header line is not a valid merge entry.
    """
    paired_tokens = {}
    with open(filepath, 'r') as f:
        for line_number, line in enumerate(f, start=1):
            entry = line.strip()
            if not entry or entry.startswith("Token version"):
                continue
            # The pair text contains no colon, so the last colon is the
            # separator between pair and token id.
            pair_text, separator, token_text = entry.rpartition(":")
            if not separator:
                raise ValueError(
                    f"{filepath}:{line_number}: malformed merge entry {entry!r}"
                )
            try:
                left, right = (
                    int(part) for part in pair_text.strip("() ").split(",")
                )
                token_id = int(token_text)
            except ValueError as err:
                raise ValueError(
                    f"{filepath}:{line_number}: malformed merge entry {entry!r}"
                ) from err
            paired_tokens[(left, right)] = token_id
    return paired_tokens
def load_vocab(filepath):
    """Read a token-id -> bytes table from a text file.

    Each entry line is ``<token id>:<bytes literal>`` or
    ``<token id><::><bytes literal>``; both separators are accepted. A
    ``Token version ...`` header line and blank lines are skipped.
    Previously only ``<::>`` was understood and the header was not skipped,
    so files using the plain ``:`` separator or starting with the header
    could not be loaded.

    Args:
        filepath: Path of the vocab file.

    Returns:
        Dict mapping ``int`` token ids to ``bytes`` values, in file order.

    Raises:
        ValueError: If a non-header line is not a valid vocab entry.
    """
    vocab = {}
    with open(filepath, 'r') as f:
        for line_number, line in enumerate(f, start=1):
            entry = line.rstrip("\r\n")
            if not entry.strip() or entry.startswith("Token version"):
                continue
            # The id is digits only, so the first colon always belongs to the
            # separator even when the bytes literal itself contains ":" or
            # "<::>".
            key_text, separator, value_text = entry.partition(":")
            if key_text.endswith("<") and value_text.startswith(":>"):
                key_text = key_text[:-1]
                value_text = value_text[2:]
            try:
                if not separator:
                    raise ValueError("missing separator")
                token_id = int(key_text)
                # literal_eval only accepts Python literals, so unlike eval()
                # it cannot execute code placed in the vocab file.
                value = ast.literal_eval(value_text)
            except (ValueError, SyntaxError) as err:
                raise ValueError(
                    f"{filepath}:{line_number}: malformed vocab entry {entry!r}"
                ) from err
            if not isinstance(value, bytes):
                raise ValueError(
                    f"{filepath}:{line_number}: expected a bytes literal, got {value!r}"
                )
            vocab[token_id] = value
    return vocab
def decode(intTokens, vocab, errors="strict"):
    """Turn a sequence of token ids back into text.

    Args:
        intTokens: Iterable of token ids; each must be a key of ``vocab``.
        vocab: Dict mapping token ids to their ``bytes`` values.
        errors: Error handler passed to ``bytes.decode``. The default
            ``"strict"`` keeps the original behaviour of raising on byte
            sequences that are not valid UTF-8; pass ``"replace"`` to decode
            token streams that may cut a multi-byte character in half.

    Returns:
        The decoded string.

    Raises:
        KeyError: If a token id is not in ``vocab``.
        UnicodeDecodeError: If the bytes are not valid UTF-8 and ``errors``
            is ``"strict"``.
    """
    raw_bytes = b"".join(vocab[token_id] for token_id in intTokens)
    return raw_bytes.decode("utf-8", errors=errors)
def encode(text, paired_tokens_vocab):
    """Encode ``text`` into token ids by applying the learned merges.

    The text is turned into its UTF-8 bytes, then the adjacent pair with the
    lowest merged-token id (i.e. the earliest learned merge) is merged
    repeatedly until no adjacent pair is a known merge.

    Args:
        text: String to encode.
        paired_tokens_vocab: Dict mapping ``(left_id, right_id)`` to the
            merged token id.

    Returns:
        List of token ids.
    """
    ids = list(text.encode('utf-8'))
    unranked = float("inf")
    while len(ids) >= 2:
        # Unknown pairs rank as infinity, so a known merge always wins if one
        # is present among the adjacent pairs.
        best = min(
            zip(ids, ids[1:]),
            key=lambda candidate: paired_tokens_vocab.get(candidate, unranked),
        )
        if best not in paired_tokens_vocab:
            break
        ids = merge(ids, best, paired_tokens_vocab[best])
    return ids