Included Hindi tokenizer files
- app.py +197 -0
- config_app.yml +4 -0
- requirements.txt +86 -0
- tokenizer.py +212 -0
app.py
ADDED
@@ -0,0 +1,197 @@
import gradio as gr
from tokenizer import Tokenizer, load_config
import json
import html

# Load the tokenizer
config = load_config("config_app.yml")
tokenizer = Tokenizer.load(config["tokenizer_file_path"])
tokenizer.config = config

def highlight_tokens(text: str, encoded_tokens: list) -> str:
    """
    Create HTML with highlighted tokens in the text.

    Args:
        text (str): The original input text to be tokenized.
        encoded_tokens (list): A list of encoded token IDs.

    Returns:
        str: HTML string with highlighted tokens and tooltips showing token IDs.
    """
    decoded_tokens = []
    current_pos = 0
    html_text = ""

    # Decode each token and create spans with different colors
    for i, token in enumerate(encoded_tokens):
        token_bytes = tokenizer.decode([token])
        decoded_tokens.append(token_bytes)

        # Find the token in the original text
        token_pos = text.find(token_bytes, current_pos)
        if token_pos != -1:
            # Add any skipped text
            if token_pos > current_pos:
                html_text += html.escape(text[current_pos:token_pos])

            # Add the highlighted token with improved tooltip
            color = f"hsl({(i * 60) % 360}, 80%, 85%)"
            html_text += f'''
                <span style="background-color: {color};
                             border-radius: 3px;
                             padding: 0 3px;
                             margin: 0 1px;
                             position: relative;
                             cursor: help;"
                      onmouseover="this.querySelector('.tooltip').style.display='block'"
                      onmouseout="this.querySelector('.tooltip').style.display='none'">
                    {html.escape(token_bytes)}
                    <span class="tooltip"
                          style="display: none;
                                 position: absolute;
                                 bottom: 100%;
                                 left: 50%;
                                 transform: translateX(-50%);
                                 background-color: #333;
                                 color: white;
                                 padding: 4px 8px;
                                 border-radius: 4px;
                                 font-size: 12px;
                                 white-space: nowrap;
                                 z-index: 1000;">
                        Token ID: {token}
                    </span>
                </span>'''
            current_pos = token_pos + len(token_bytes)

    # Add any remaining text
    if current_pos < len(text):
        html_text += html.escape(text[current_pos:])

    return html_text

def process_text(text: str) -> tuple:
    """
    Process input text through the tokenizer and return results.

    Args:
        text (str): The input text to be processed.

    Returns:
        tuple: A tuple containing:
            - HTML string of highlighted tokens.
            - HTML string of token statistics.
            - String of token IDs.
    """
    try:
        # Encode the text
        encoded = tokenizer.encode(text)

        # Decode back to text
        decoded = tokenizer.decode(encoded)

        # Create token visualization
        highlighted_text = highlight_tokens(text, encoded)

        # Token statistics
        stats = {
            "Total Tokens": len(encoded),
            "Unique Tokens": len(set(encoded)),
            "Characters": len(text),
            "Bytes": len(text.encode('utf-8')),
            "Compression Ratio": f"{len(text.encode('utf-8')) / (len(encoded) * 4):.2f}x"
        }

        # Format statistics
        stats_html = "<div style='margin-top: 20px;'>"
        for key, value in stats.items():
            stats_html += f"<div style='margin: 5px 0;'><b>{key}:</b> {value}</div>"
        stats_html += "</div>"

        return (
            gr.HTML(highlighted_text),
            gr.HTML(stats_html),
            f"Token IDs: {encoded}"
        )
    except Exception as e:
        return (
            gr.HTML(f"<span style='color: red'>Error: {str(e)}</span>"),
            "",
            ""
        )

# Define example inputs
examples = [
    ["यहां वर्तमान में 20 हजार पुस्तकें थी जो अभी रैन बसेरा परिसर के कक्ष में रखी हुई है।"],
    ["भारत एक विशाल देश है।"],
    ["मैं हिंदी में बात कर रहा हूं।"],
    ["नमस्ते, आप कैसे हैं?"],
    ["दिल्ली भारत की राजधानी है।"]
]

# Custom CSS
custom_css = """
.container {
    max-width: 800px;
    margin: auto;
    padding: 20px;
}
.token-viz {
    font-family: monospace;
    line-height: 1.6;
    padding: 15px;
    border: 1px solid #ddd;
    border-radius: 5px;
    background: white;
    margin: 10px 0;
    position: relative;
}
.stats {
    background: #f7f7f7;
    padding: 15px;
    border-radius: 5px;
    margin: 10px 0;
}
.token-ids {
    font-family: monospace;
    padding: 15px;
    background: #f0f0f0;
    border-radius: 5px;
    overflow-wrap: break-word;
}
.tooltip {
    pointer-events: none;
    box-shadow: 0 2px 4px rgba(0,0,0,0.2);
}
"""

# Create the Gradio interface
iface = gr.Interface(
    fn=process_text,
    inputs=[
        gr.Textbox(
            label="Input Text",
            placeholder="Enter Hindi text here...",
            lines=3
        )
    ],
    outputs=[
        gr.HTML(label="Tokenized Text", elem_classes="token-viz"),
        gr.HTML(label="Statistics", elem_classes="stats"),
        gr.Textbox(label="Token IDs", elem_classes="token-ids")
    ],
    title="Hindi BPE Tokenizer Visualization",
    description="""
    This demo shows how the Hindi BPE tokenizer processes text. Each token is highlighted with a different color.
    Hover over the highlighted tokens to see their token IDs.
    """,
    examples=examples,
    theme=gr.themes.Soft(),
    css=custom_css,
    allow_flagging="never"
)

if __name__ == "__main__":
    iface.launch(share=True)
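For a quick sanity check without launching the UI, process_text can be called directly. A minimal sketch, assuming the merges file referenced by tokenizer_file_path exists locally and that importing app (which loads the tokenizer and builds the interface at module level) is acceptable:

# Minimal sketch; importing app loads the tokenizer and builds the Gradio interface.
from app import process_text

highlighted, stats, token_ids = process_text("भारत एक विशाल देश है।")
print(token_ids)  # "Token IDs: [...]"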
config_app.yml
ADDED
@@ -0,0 +1,4 @@
regex_string: r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{N}+| ?(?:[\u0904-\u0939\u093d-\u093d\u0950-\u0950\u0958-\u0961\u0970-\u097f\ua8f2-\ua8fe\U00011b00-\U00011b09\u1cd3-\u1cd3\u1ce9-\u1cec\u1cee-\u1cf3\u1cf5-\u1cf6\u1cfa-\u1cfa][\u0900-\u0903\u093a-\u093c\u093e-\u094f\u0951-\u0957\u0962-\u0963\ua8e0-\ua8f1\ua8ff-\ua8ff\u1cd0-\u1cd2\u1cd4-\u1ce8\u1ced-\u1ced\u1cf4-\u1cf4\u1cf7-\u1cf9]*)+| ?\p{L}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

tokenizer_file_path: "model/hi_tokenizer_regex.json"
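At runtime this file is parsed by load_config (defined in tokenizer.py) into a plain dict, and app.py passes tokenizer_file_path to Tokenizer.load. A minimal sketch of that path:

# Sketch: config_app.yml is read as plain YAML; values arrive as verbatim strings.
from tokenizer import load_config

config = load_config("config_app.yml")
print(config["tokenizer_file_path"])  # model/hi_tokenizer_regex.json
print(type(config["regex_string"]))   # <class 'str'>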
requirements.txt
ADDED
@@ -0,0 +1,86 @@
aiofiles==23.2.1
annotated-types==0.7.0
anyio==4.8.0
appnope==0.1.4
asttokens==3.0.0
certifi==2024.12.14
charset-normalizer==3.4.1
click==8.1.8
comm==0.2.2
contourpy==1.3.0
cycler==0.12.1
debugpy==1.8.11
decorator==5.1.1
exceptiongroup==1.2.2
executing==2.1.0
fastapi==0.115.6
ffmpy==0.5.0
filelock==3.16.1
fonttools==4.55.3
fsspec==2024.12.0
gradio==4.44.1
gradio_client==1.3.0
h11==0.14.0
httpcore==1.0.7
httpx==0.28.1
huggingface-hub==0.27.1
idna==3.10
importlib_metadata==8.5.0
importlib_resources==6.5.2
ipykernel==6.29.5
ipython==8.18.1
jedi==0.19.2
Jinja2==3.1.5
jupyter_client==8.6.3
jupyter_core==5.7.2
kiwisolver==1.4.7
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.9.4
matplotlib-inline==0.1.7
mdurl==0.1.2
nest-asyncio==1.6.0
numpy==2.0.2
orjson==3.10.14
packaging==24.2
pandas==2.2.3
parso==0.8.4
pexpect==4.9.0
pillow==10.4.0
platformdirs==4.3.6
prompt_toolkit==3.0.48
psutil==6.1.1
ptyprocess==0.7.0
pure_eval==0.2.3
pydantic==2.10.5
pydantic_core==2.27.2
pydub==0.25.1
Pygments==2.19.1
pyparsing==3.2.1
python-dateutil==2.9.0.post0
python-multipart==0.0.20
pytz==2024.2
PyYAML==6.0.2
pyzmq==26.2.0
regex==2024.11.6
requests==2.32.3
rich==13.9.4
ruff==0.9.0
semantic-version==2.10.0
shellingham==1.5.4
six==1.17.0
sniffio==1.3.1
stack-data==0.6.3
starlette==0.41.3
tomlkit==0.12.0
tornado==6.4.2
tqdm==4.67.1
traitlets==5.14.3
typer==0.15.1
typing_extensions==4.12.2
tzdata==2024.2
urllib3==2.3.0
uvicorn==0.34.0
wcwidth==0.2.13
websockets==12.0
zipp==3.21.0
|
tokenizer.py
ADDED
@@ -0,0 +1,212 @@
import yaml
import regex as re
from tqdm import tqdm
import gc
import json


def load_config(config_file_path: str = "config.yml"):
    with open(config_file_path, "r") as f:
        config = yaml.safe_load(f)
    return config

def get_input_text(config: dict) -> str:
    with open(config["input_file_info"]["file_path"], 'r', encoding='utf-8') as _f:
        hi_text = [line.strip() for line in _f.readlines()]

    hi_text_abridged = hi_text[:int(config["input_file_info"]["input_file_limit"])]
    hi_text_abridged = '\n'.join(hi_text_abridged)

    if config["input_file_info"]["print_text"]:
        print(" Sample text: ", hi_text_abridged[:10])

    return hi_text_abridged

def get_stats(ids, counts=None):
    counts = {} if counts is None else counts
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

def merge(ids, pair, idx):
    newids = []
    i = 0
    while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids

def stoi(text: str, config: dict) -> list:
    # tokenize the text
    if config["regex_string"] and len(config["regex_string"]) > 0:
        print("Using regex string: ", config["regex_string"])
        tokens = re.findall(config["regex_string"], text)
        # Convert tokens to bytes and then to integers
        return [b for token in tokens for b in token.encode('utf-8')]
    else:
        print("Using default tokenizer")
        # Instead of splitting, we'll preserve spaces by encoding them directly
        return [b for ch in text for b in ch.encode('utf-8')]


def encode(text, merges, config: dict):
    """
    Encode text into tokens using the learned merges
    """
    ids = stoi(text, config)

    sorted_merges = sorted(merges.items(), key=lambda x: x[1])
    for (p1, p2), idx in sorted_merges:
        ids = merge(ids, (p1, p2), idx)

    return ids

def decode(ids, merges, config: dict):
    """
    Decode tokens back to text using the learned merges
    """
    # Create reverse mapping from token to pair
    reverse_merges = {idx: pair for pair, idx in merges.items()}

    # Expand all tokens recursively
    def expand_token(token):
        if token < 256:  # Base case: token is a byte
            return bytes([token])

        # Recursive case: expand the token into its constituent pair
        pair = reverse_merges[token]
        return expand_token(pair[0]) + expand_token(pair[1])

    # Expand all tokens and concatenate
    bytes_list = [expand_token(id) for id in ids]
    bytes_data = b''.join(bytes_list)

    # Convert bytes back to text
    try:
        return bytes_data.decode('utf-8')
    except UnicodeDecodeError:
        return "[DECODE_ERROR]"

class Tokenizer:
    def __init__(self, merges=None, config: dict = None):
        self.merges = merges or {}
        self.config = config

    def save(self, file_path):
        # Convert tuple keys to strings for JSON serialization
        serializable_merges = {f"{k[0]},{k[1]}": v for k, v in self.merges.items()}
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(serializable_merges, f)

    @classmethod
    def load(cls, file_path):
        with open(file_path, 'r', encoding='utf-8') as f:
            serialized_merges = json.load(f)
        # Convert string keys back to tuples
        merges = {tuple(map(int, k.split(','))): v
                  for k, v in serialized_merges.items()}

        return cls(merges)

    def encode(self, text):
        return encode(text, self.merges, self.config)

    def decode(self, ids):
        return decode(ids, self.merges, self.config)

def train_tokenizer(config: dict) -> dict:
    # get input text
    hi_text = get_input_text(config)

    # convert string to tokens
    tokens = stoi(hi_text, config)
    initial_len = len(tokens)
    print("Tokens length (initial): ", initial_len, " tokens unique: ", len(set(tokens)))
    print("Example tokens: ", ord('क'), chr(2325), ord("।"), chr(2404))

    print("Training tokenizer....")
    num_merges = config["vocab_size"] - 256
    original_token = tokens

    merges = {}
    pbar = tqdm(range(num_merges), desc="Training tokenizer")
    output_file = config["output_file_info"]["file_path"]

    for i in pbar:
        # Get statistics of the tokens
        stats = get_stats(tokens)
        # Get the most frequent pair
        pair = max(stats, key=stats.get)
        # Get the index of the new token
        idx = 256 + i

        # Merge the pair
        tokens = merge(tokens, pair, idx)
        merges[pair] = idx

        # Show progress
        if (i + 1) % 100 == 0:
            current_ratio = initial_len / len(tokens)
            pbar.write(f"Iteration {i+1}: compression ratio: {current_ratio:.2f}X")

        # Garbage collection periodically
        if (i + 1) % 1000 == 0:
            gc.collect()

        # Save intermediate merges
        if (i + 1) % 1000 == 0:
            temp_tokenizer = Tokenizer(merges)
            temp_tokenizer.save(f"{output_file}.checkpoint")

    print("Training tokenizer completed")
    final_tokenizer = Tokenizer(merges)
    final_tokenizer.save(f"{output_file}")

    print("\n=== Final Statistics ===")
    print(f"Vocabulary size: {config['vocab_size']}")
    print(f"Initial tokens: {initial_len:,}")
    print(f"Final tokens: {len(tokens):,}")
    print(f"Initial bytes: {initial_len * 4:,}")
    print(f"Final bytes: {len(tokens) * 4:,}")
    print(f"Token compression ratio: {initial_len / len(tokens):.2f}X")
    print(f"Byte compression ratio: {(initial_len * 4) / (len(tokens) * 4):.2f}X")
    print(f"Saved tokenizer to: {output_file}")

    return merges

def load_tokenizer(config: dict) -> Tokenizer:
    """Load the tokenizer from the JSON file."""
    with open(config["output_file_info"]["file_path"], 'r', encoding='utf-8') as f:
        serialized_merges = json.load(f)

    merges = {tuple(map(int, k.split(','))): v
              for k, v in serialized_merges.items()}

    return Tokenizer(merges, config)

if __name__ == "__main__":

    # TRAIN TOKENIZER
    config = load_config()
    merges = train_tokenizer(config)
    print("Merges: ", merges)

    # USE TOKENIZER
    # tokenizer = load_tokenizer(config)
    # test_text = config["test_text"]

    # print("Test text: ", test_text)
    # print("Encoded text: ", tokenizer.encode(test_text))
    # decoded = tokenizer.decode(tokenizer.encode(test_text))
    # print("Decoded text: ", decoded)

    # print(f"Successful roundtrip: {test_text == decoded}")
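To make the training loop concrete, here is a toy illustration of a single BPE step using get_stats and merge from this file: the most frequent adjacent pair of ids is replaced by a new id.

from tokenizer import get_stats, merge

ids = [1, 2, 1, 2, 3]
stats = get_stats(ids)                # {(1, 2): 2, (2, 1): 1, (2, 3): 1}
top_pair = max(stats, key=stats.get)  # (1, 2)
print(merge(ids, top_pair, 256))      # [256, 256, 3]

train_tokenizer also expects a few config keys that config_app.yml does not carry. A hypothetical minimal training config, with key names taken from the code above and placeholder paths and sizes:

# Hypothetical training config; corpus path and vocab_size are placeholders, not from the repo.
train_config = {
    "input_file_info": {
        "file_path": "data/hi_corpus.txt",   # placeholder corpus path
        "input_file_limit": 100000,
        "print_text": False,
    },
    "regex_string": "",                      # empty string -> byte-level fallback in stoi()
    "vocab_size": 5000,
    "output_file_info": {"file_path": "model/hi_tokenizer_regex.json"},
}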