# hindi-tokenizer-bpe / tokenizer.py
import yaml
import regex as re
from tqdm import tqdm
import gc
import json
def load_config(config_file_path: str = "config.yml"):
    """Load the YAML configuration file and return it as a dict."""
    with open(config_file_path, "r") as f:
        config = yaml.safe_load(f)
    return config
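# Illustrative config.yml layout, inferred from the keys this file reads (all values below
# are placeholders, not the author's actual settings):
#   input_file_info:
#     file_path: hindi_corpus.txt
#     input_file_limit: 10000
#     print_text: true
#   regex_string: ""              # optional pre-tokenization regex; empty uses the raw byte path
#   vocab_size: 5000
#   output_file_info:
#     file_path: hindi_tokenizer.json
#   test_text: "नमस्ते"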
def get_input_text(config: dict) -> str:
    """Read the input corpus, keep the first `input_file_limit` lines, and return them as one string."""
    with open(config["input_file_info"]["file_path"], 'r', encoding='utf-8') as _f:
        hi_text = [line.strip() for line in _f]
    hi_text_abridged = hi_text[:int(config["input_file_info"]["input_file_limit"])]
    hi_text_abridged = '\n'.join(hi_text_abridged)
    if config["input_file_info"]["print_text"]:
        # Show the first few characters as a sanity check
        print("Sample text: ", hi_text_abridged[:10])
    return hi_text_abridged
def get_stats(ids, counts=None):
    """Count occurrences of each adjacent token pair in `ids`, optionally accumulating into `counts`."""
    counts = {} if counts is None else counts
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts
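# Illustrative example: get_stats([1, 2, 3, 1, 2]) returns {(1, 2): 2, (2, 3): 1, (3, 1): 1},
# i.e. each adjacent pair is counted once per occurrence.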
def merge(ids, pair, idx):
    """Replace every occurrence of the adjacent `pair` in `ids` with the new token id `idx`."""
    newids = []
    i = 0
    while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids
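# Illustrative example: merge([1, 2, 3, 1, 2], (1, 2), 256) returns [256, 3, 256].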
def stoi(text: str, config: dict) -> list:
    """Convert text into a list of UTF-8 byte values, optionally pre-splitting with a regex."""
    if config.get("regex_string"):
        print("Using regex string: ", config["regex_string"])
        tokens = re.findall(config["regex_string"], text)
        # Convert each regex token to bytes and flatten into one list of integers
        return [b for token in tokens for b in token.encode('utf-8')]
    else:
        print("Using default tokenizer")
        # No pre-splitting: encode the raw text (spaces included) directly to bytes
        return [b for ch in text for b in ch.encode('utf-8')]
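# Illustrative example: with the default byte path, stoi("क", config) returns [224, 164, 149],
# since 'क' (U+0915) occupies three bytes in UTF-8.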
def encode(text, merges, config: dict):
"""
Encode text into tokens using the learned merges
"""
ids = stoi(text, config)
sorted_merges = sorted(merges.items(), key=lambda x: x[1])
for (p1, p2), idx in sorted_merges:
ids = merge(ids, (p1, p2), idx)
return ids
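# Illustrative example (assuming the config leaves "regex_string" empty so the raw byte path
# is used): encode("hihi", {(104, 105): 256}, config) returns [256, 256].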
def decode(ids, merges, config: dict):
"""
Decode tokens back to text using the learned merges
"""
# Create reverse mapping from token to pair
reverse_merges = {idx: pair for pair, idx in merges.items()}
# Expand all tokens recursively
def expand_token(token):
if token < 256: # Base case: token is a byte
return bytes([token])
# Recursive case: expand the token into its constituent pair
pair = reverse_merges[token]
return expand_token(pair[0]) + expand_token(pair[1])
# Expand all tokens and concatenate
bytes_list = [expand_token(id) for id in ids]
bytes_data = b''.join(bytes_list)
# Convert bytes back to text
try:
return bytes_data.decode('utf-8')
except UnicodeDecodeError:
return "[DECODE_ERROR]"
class Tokenizer:
def __init__(self, merges = None, config: dict = None):
self.merges = merges or {}
self.config = config
def save(self, file_path):
# Convert tuple keys to strings for JSON serialization
serializable_merges = {f"{k[0]},{k[1]}": v for k, v in self.merges.items()}
with open(file_path, 'w', encoding='utf-8') as f:
json.dump(serializable_merges, f)
    @classmethod
    def load(cls, file_path, config: dict = None):
        with open(file_path, 'r', encoding='utf-8') as f:
            serialized_merges = json.load(f)
        # Convert string keys back to (int, int) tuples
        merges = {tuple(map(int, k.split(','))): v
                  for k, v in serialized_merges.items()}
        # Pass the config through so a loaded tokenizer can still encode/decode
        return cls(merges, config)
def encode(self, text):
return encode(text, self.merges, self.config)
def decode(self, ids):
return decode(ids, self.merges, self.config)
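# Minimal usage sketch (assumes a trained `merges` dict and a loaded config; the round trip is
# exact only when the configured regex, if any, does not drop characters):
#   tok = Tokenizer(merges, config)
#   ids = tok.encode("नमस्ते दुनिया")
#   text = tok.decode(ids)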
def train_tokenizer(config: dict) -> dict:
    """Train a byte-level BPE tokenizer on the input text and return the learned merges."""
# get input text
hi_text = get_input_text(config)
# convert string to tokens
tokens = stoi(hi_text, config)
initial_len = len(tokens)
print("Tokens length (initial): ", initial_len, " tokens unique: ", len(set(tokens)))
print("Example tokens: ", ord('क'), chr(2325), ord("।"), chr(2404))
print("Training tokenizer....")
num_merges = config["vocab_size"] - 256
    merges = {}
pbar = tqdm(range(num_merges), desc="Training tokenizer")
output_file = config["output_file_info"]["file_path"]
for i in pbar:
# Get statistics of the tokens
stats = get_stats(tokens)
# Get the most frequent pair
        pair = max(stats, key=stats.get)
# Get the index of the new token
idx = 256 + i
# Merge the pair
tokens = merge(tokens, pair, idx)
merges[pair] = idx
# Show progress
if (i + 1) % 100 == 0:
current_ratio = initial_len / len(tokens)
pbar.write(f"Iteration {i+1}: compression ratio: {current_ratio:.2f}X")
# Garbage collection periodically
if (i + 1) % 1000 == 0:
gc.collect()
# Save intermediate merges
if (i + 1) % 1000 == 0:
temp_tokenizer = Tokenizer(merges)
temp_tokenizer.save(f"{output_file}.checkpoint")
print("Training tokenizer completed")
final_tokenizer = Tokenizer(merges)
    final_tokenizer.save(output_file)
print("\n=== Final Statistics ===")
print(f"Vocabulary size: {config['vocab_size']}")
print(f"Initial tokens: {initial_len:,}")
print(f"Final tokens: {len(tokens):,}")
print(f"Initial bytes: {initial_len * 4:,}")
print(f"Final bytes: {len(tokens) * 4:,}")
print(f"Token compression ratio: {initial_len / len(tokens):.2f}X")
print(f"Byte compression ratio: {initial_len * 4 / len(tokens) * 4:.2f}X")
print(f"Saved tokenizer to: {output_file}")
return merges
def load_tokenizer(config: dict) -> Tokenizer:
    """Load the trained tokenizer from the JSON file referenced in the config."""
    with open(config["output_file_info"]["file_path"], 'r', encoding='utf-8') as f:
        serialized_merges = json.load(f)
    # Convert string keys back to (int, int) tuples
    merges = {tuple(map(int, k.split(','))): v
              for k, v in serialized_merges.items()}
    return Tokenizer(merges, config)
if __name__ == "__main__":
# TRAIN TOKENIZER
config = load_config()
merges = train_tokenizer(config)
print("Merges: ", merges)
# USE TOKENIZER
# tokenizer = load_tokenizer(config)
# test_text = config["test_text"]
# print("Test text: ", test_text)
# print("Encoded text: ", tokenizer.encode(test_text))
# decoded = tokenizer.decode(tokenizer.encode(test_text))
# print("Decoded text: ", decoded)
# print(f"Successful roundtrip: {test_text == decoded}")