import gc
import json

import regex as re
import yaml
from tqdm import tqdm

def load_config(config_file_path: str = "config.yml"):
    """Load the YAML configuration file."""
    with open(config_file_path, "r") as f:
        config = yaml.safe_load(f)
    return config

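# Sketch of the config keys this script reads. Only the key names are taken
# from the code below; the values are illustrative assumptions.
#
#   input_file_info:
#     file_path: input.txt      # hypothetical path
#     input_file_limit: 10000   # number of lines to keep
#     print_text: true
#   regex_string: ""            # empty -> default byte-level tokenizer
#   vocab_size: 5000
#   output_file_info:
#     file_path: tokenizer.json # hypothetical path
#   test_text: "..."            # used by the commented-out demo in __main__
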
def get_input_text(config: dict) -> str:
    """Read the input file and return the first `input_file_limit` lines as one string."""
    with open(config["input_file_info"]["file_path"], 'r', encoding='utf-8') as _f:
        hi_text = [line.strip() for line in _f]
    hi_text_abridged = hi_text[:int(config["input_file_info"]["input_file_limit"])]
    hi_text_abridged = '\n'.join(hi_text_abridged)
    if config["input_file_info"]["print_text"]:
        print("Sample text: ", hi_text_abridged[:10])
    return hi_text_abridged

def get_stats(ids, counts=None):
    """Count the frequency of each adjacent token pair in `ids`."""
    counts = {} if counts is None else counts
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

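# Minimal example of what get_stats returns: in the sequence [1, 2, 1, 2]
# the pair (1, 2) occurs twice and (2, 1) once, so
#   get_stats([1, 2, 1, 2]) == {(1, 2): 2, (2, 1): 1}
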
def merge(ids, pair, idx):
    """Replace every non-overlapping occurrence of `pair` in `ids` with the new token `idx`."""
    newids = []
    i = 0
    while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids

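# Example: merging the pair (1, 2) into new token 256, scanning left to right:
#   merge([1, 2, 1, 2, 3], (1, 2), 256) == [256, 256, 3]
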
def stoi(text: str, config: dict) -> list:
    """Convert text to a list of integer byte values, optionally pre-splitting with a regex."""
    if config.get("regex_string"):
        print("Using regex string: ", config["regex_string"])
        tokens = re.findall(config["regex_string"], text)
        # Convert tokens to bytes and then to integers
        return [b for token in tokens for b in token.encode('utf-8')]
    else:
        print("Using default tokenizer")
        # No splitting: encode the raw text (spaces included) directly to bytes
        return [b for ch in text for b in ch.encode('utf-8')]

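# Sketch (assuming a config with no regex_string set): stoi falls through to
# the default branch and returns the raw UTF-8 bytes of the text, e.g.
#   stoi("ab", {"regex_string": ""}) == [97, 98]
# Multi-byte characters expand to several byte values: 'क' (U+0915) becomes
# [224, 164, 149].
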
def encode(text, merges, config: dict):
    """
    Encode text into tokens using the learned merges
    """
    ids = stoi(text, config)
    # Apply merges in the order they were learned (ascending new-token index)
    sorted_merges = sorted(merges.items(), key=lambda x: x[1])
    for (p1, p2), idx in sorted_merges:
        ids = merge(ids, (p1, p2), idx)
    return ids

def decode(ids, merges, config: dict):
    """
    Decode tokens back to text using the learned merges
    """
    # Create reverse mapping from token to pair
    reverse_merges = {idx: pair for pair, idx in merges.items()}
    # Expand all tokens recursively
    def expand_token(token):
        if token < 256:  # Base case: token is a byte
            return bytes([token])
        # Recursive case: expand the token into its constituent pair
        pair = reverse_merges[token]
        return expand_token(pair[0]) + expand_token(pair[1])
    # Expand all tokens and concatenate
    bytes_list = [expand_token(id) for id in ids]
    bytes_data = b''.join(bytes_list)
    # Convert bytes back to text
    try:
        return bytes_data.decode('utf-8')
    except UnicodeDecodeError:
        return "[DECODE_ERROR]"

class Tokenizer:
    def __init__(self, merges=None, config: dict = None):
        self.merges = merges or {}
        self.config = config
    def save(self, file_path):
        # Convert tuple keys to strings for JSON serialization
        serializable_merges = {f"{k[0]},{k[1]}": v for k, v in self.merges.items()}
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(serializable_merges, f)
    @classmethod
    def load(cls, file_path, config: dict = None):
        with open(file_path, 'r', encoding='utf-8') as f:
            serialized_merges = json.load(f)
        # Convert string keys back to tuples
        merges = {tuple(map(int, k.split(','))): v
                  for k, v in serialized_merges.items()}
        return cls(merges, config)
    def encode(self, text):
        return encode(text, self.merges, self.config)
    def decode(self, ids):
        return decode(ids, self.merges, self.config)

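# Usage sketch for save/load (file path and merges are illustrative only):
#   tok = Tokenizer({(104, 105): 256}, {"regex_string": ""})
#   tok.save("tiny_tokenizer.json")
#   restored = Tokenizer.load("tiny_tokenizer.json", {"regex_string": ""})
#   assert restored.merges == tok.merges
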
def train_tokenizer(config: dict) -> dict:
    """Train a byte-pair-encoding tokenizer and save the learned merges."""
    # get input text
    hi_text = get_input_text(config)
    # convert string to tokens
    tokens = stoi(hi_text, config)
    initial_len = len(tokens)
    print("Tokens length (initial): ", initial_len, " tokens unique: ", len(set(tokens)))
    print("Example tokens: ", ord('क'), chr(2325), ord("।"), chr(2404))
    print("Training tokenizer....")
    # Every merge adds one token on top of the 256 base bytes
    num_merges = config["vocab_size"] - 256
    merges = {}
    pbar = tqdm(range(num_merges), desc="Training tokenizer")
    output_file = config["output_file_info"]["file_path"]
    for i in pbar:
        # Get statistics of the tokens
        stats = get_stats(tokens)
        # Get the most frequent pair
        pair = max(stats, key=stats.get)
        # Get the index of the new token
        idx = 256 + i
        # Merge the pair
        tokens = merge(tokens, pair, idx)
        merges[pair] = idx
        # Show progress
        if (i + 1) % 100 == 0:
            current_ratio = initial_len / len(tokens)
            pbar.write(f"Iteration {i+1}: compression ratio: {current_ratio:.2f}X")
        # Periodically garbage-collect and save intermediate merges
        if (i + 1) % 1000 == 0:
            gc.collect()
            temp_tokenizer = Tokenizer(merges)
            temp_tokenizer.save(f"{output_file}.checkpoint")
print("Training tokenizer completed") | |
final_tokenizer = Tokenizer(merges) | |
final_tokenizer.save(f"{output_file}") | |
print("\n=== Final Statistics ===") | |
print(f"Vocabulary size: {config['vocab_size']}") | |
print(f"Initial tokens: {initial_len:,}") | |
print(f"Final tokens: {len(tokens):,}") | |
print(f"Initial bytes: {initial_len * 4:,}") | |
print(f"Final bytes: {len(tokens) * 4:,}") | |
print(f"Token compression ratio: {initial_len / len(tokens):.2f}X") | |
print(f"Byte compression ratio: {initial_len * 4 / len(tokens) * 4:.2f}X") | |
print(f"Saved tokenizer to: {output_file}") | |
return merges | |
def load_tokenizer(config: dict) -> Tokenizer:
    """Load the tokenizer merges from the saved JSON file."""
    return Tokenizer.load(config["output_file_info"]["file_path"], config)

if __name__ == "__main__":
    # TRAIN TOKENIZER
    config = load_config()
    merges = train_tokenizer(config)
    print("Merges: ", merges)
    # USE TOKENIZER
    # tokenizer = load_tokenizer(config)
    # test_text = config["test_text"]
    # print("Test text: ", test_text)
    # print("Encoded text: ", tokenizer.encode(test_text))
    # decoded = tokenizer.decode(tokenizer.encode(test_text))
    # print("Decoded text: ", decoded)
    # print(f"Successful roundtrip: {test_text == decoded}")