peeyushsinghal committed
Commit 10f8413 · verified · 1 Parent(s): 0a0cd27

included hindi tokenizer files

Files changed (4):
  1. app.py +197 -0
  2. config_app.yml +4 -0
  3. requirements.txt +86 -0
  4. tokenizer.py +212 -0
app.py ADDED
@@ -0,0 +1,197 @@
+ import gradio as gr
+ from tokenizer import Tokenizer, load_config
+ import json
+ import html
+
+ # Load the tokenizer
+ config = load_config("config_app.yml")
+ tokenizer = Tokenizer.load(config["tokenizer_file_path"])
+ tokenizer.config = config
+
+ def highlight_tokens(text: str, encoded_tokens: list) -> str:
+     """
+     Create HTML with highlighted tokens in the text.
+
+     Args:
+         text (str): The original input text to be tokenized.
+         encoded_tokens (list): A list of encoded token IDs.
+
+     Returns:
+         str: HTML string with highlighted tokens and tooltips showing token IDs.
+     """
+     decoded_tokens = []
+     current_pos = 0
+     html_text = ""
+
+     # Decode each token and create spans with different colors
+     for i, token in enumerate(encoded_tokens):
+         token_bytes = tokenizer.decode([token])
+         decoded_tokens.append(token_bytes)
+
+         # Find the token in the original text
+         token_pos = text.find(token_bytes, current_pos)
+         if token_pos != -1:
+             # Add any skipped text
+             if token_pos > current_pos:
+                 html_text += html.escape(text[current_pos:token_pos])
+
+             # Add the highlighted token with improved tooltip
+             color = f"hsl({(i * 60) % 360}, 80%, 85%)"
+             html_text += f'''
+                 <span
+                     style="background-color: {color};
+                            border-radius: 3px;
+                            padding: 0 3px;
+                            margin: 0 1px;
+                            position: relative;
+                            cursor: help;"
+                     onmouseover="this.querySelector('.tooltip').style.display='block'"
+                     onmouseout="this.querySelector('.tooltip').style.display='none'">
+                     {html.escape(token_bytes)}
+                     <span class="tooltip"
+                           style="display: none;
+                                  position: absolute;
+                                  bottom: 100%;
+                                  left: 50%;
+                                  transform: translateX(-50%);
+                                  background-color: #333;
+                                  color: white;
+                                  padding: 4px 8px;
+                                  border-radius: 4px;
+                                  font-size: 12px;
+                                  white-space: nowrap;
+                                  z-index: 1000;">
+                         Token ID: {token}
+                     </span>
+                 </span>'''
+             current_pos = token_pos + len(token_bytes)
+
+     # Add any remaining text
+     if current_pos < len(text):
+         html_text += html.escape(text[current_pos:])
+
+     return html_text
+
+ def process_text(text: str) -> tuple:
+     """
+     Process input text through the tokenizer and return results.
+
+     Args:
+         text (str): The input text to be processed.
+
+     Returns:
+         tuple: A tuple containing:
+             - HTML string of highlighted tokens.
+             - HTML string of token statistics.
+             - String of token IDs.
+     """
+     try:
+         # Encode the text
+         encoded = tokenizer.encode(text)
+
+         # Decode back to text
+         decoded = tokenizer.decode(encoded)
+
+         # Create token visualization
+         highlighted_text = highlight_tokens(text, encoded)
+
+         # Token statistics
+         stats = {
+             "Total Tokens": len(encoded),
+             "Unique Tokens": len(set(encoded)),
+             "Characters": len(text),
+             "Bytes": len(text.encode('utf-8')),
+             "Compression Ratio": f"{len(text.encode('utf-8')) / (len(encoded) * 4):.2f}x"
+         }
+
+         # Format statistics
+         stats_html = "<div style='margin-top: 20px;'>"
+         for key, value in stats.items():
+             stats_html += f"<div style='margin: 5px 0;'><b>{key}:</b> {value}</div>"
+         stats_html += "</div>"
+
+         return (
+             gr.HTML(highlighted_text),
+             gr.HTML(stats_html),
+             f"Token IDs: {encoded}"
+         )
+     except Exception as e:
+         return (
+             gr.HTML(f"<span style='color: red'>Error: {str(e)}</span>"),
+             "",
+             ""
+         )
+
+ # Define example inputs
+ examples = [
+     ["यहां वर्तमान में 20 हजार पुस्तकें थी जो अभी रैन बसेरा परिसर के कक्ष में रखी हुई है।"],
+     ["भारत एक विशाल देश है।"],
+     ["मैं हिंदी में बात कर रहा हूं।"],
+     ["नमस्ते, आप कैसे हैं?"],
+     ["दिल्ली भारत की राजधानी है।"]
+ ]
+
+ # Custom CSS
+ custom_css = """
+ .container {
+     max-width: 800px;
+     margin: auto;
+     padding: 20px;
+ }
+ .token-viz {
+     font-family: monospace;
+     line-height: 1.6;
+     padding: 15px;
+     border: 1px solid #ddd;
+     border-radius: 5px;
+     background: white;
+     margin: 10px 0;
+     position: relative;
+ }
+ .stats {
+     background: #f7f7f7;
+     padding: 15px;
+     border-radius: 5px;
+     margin: 10px 0;
+ }
+ .token-ids {
+     font-family: monospace;
+     padding: 15px;
+     background: #f0f0f0;
+     border-radius: 5px;
+     overflow-wrap: break-word;
+ }
+ .tooltip {
+     pointer-events: none;
+     box-shadow: 0 2px 4px rgba(0,0,0,0.2);
+ }
+ """
+
+ # Create the Gradio interface
+ iface = gr.Interface(
+     fn=process_text,
+     inputs=[
+         gr.Textbox(
+             label="Input Text",
+             placeholder="Enter Hindi text here...",
+             lines=3
+         )
+     ],
+     outputs=[
+         gr.HTML(label="Tokenized Text", elem_classes="token-viz"),
+         gr.HTML(label="Statistics", elem_classes="stats"),
+         gr.Textbox(label="Token IDs", elem_classes="token-ids")
+     ],
+     title="Hindi BPE Tokenizer Visualization",
+     description="""
+     This demo shows how the Hindi BPE tokenizer processes text. Each token is highlighted with a different color.
+     Hover over the highlighted tokens to see their token IDs.
+     """,
+     examples=examples,
+     theme=gr.themes.Soft(),
+     css=custom_css,
+     allow_flagging="never"
+ )
+
+ if __name__ == "__main__":
+     iface.launch(share=True)
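
A minimal way to exercise app.py outside the browser UI, assuming config_app.yml and model/hi_tokenizer_regex.json are present in the working directory; the sample text is taken from the examples list above and the printed IDs are illustrative only:

    # Hypothetical smoke test: importing app loads the tokenizer and builds the
    # Gradio interface, but does not launch the server.
    from app import process_text, tokenizer

    text = "नमस्ते, आप कैसे हैं?"
    token_viz, stats_html, ids = process_text(text)          # (gr.HTML, gr.HTML, str)
    print(ids)                                               # "Token IDs: [...]"
    print(tokenizer.decode(tokenizer.encode(text)) == text)  # round-trip check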
config_app.yml ADDED
@@ -0,0 +1,4 @@
+
+ regex_string: r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{N}+| ?(?:[\u0904-\u0939\u093d-\u093d\u0950-\u0950\u0958-\u0961\u0970-\u097f\ua8f2-\ua8fe\U00011b00-\U00011b09\u1cd3-\u1cd3\u1ce9-\u1cec\u1cee-\u1cf3\u1cf5-\u1cf6\u1cfa-\u1cfa][\u0900-\u0903\u093a-\u093c\u093e-\u094f\u0951-\u0957\u0962-\u0963\ua8e0-\ua8f1\ua8ff-\ua8ff\u1cd0-\u1cd2\u1cd4-\u1ce8\u1ced-\u1ced\u1cf4-\u1cf4\u1cf7-\u1cf9]*)+| ?\p{L}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
+
+ tokenizer_file_path: "model/hi_tokenizer_regex.json"
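
For reference, app.py and tokenizer.load_config read this file with yaml.safe_load, so both keys arrive as plain strings; YAML does not interpret Python's r"""...""" syntax, so the leading r""" and trailing """ are stored as part of regex_string. A minimal sketch of inspecting the parsed config (assuming the file is in the working directory):

    # Sketch only: show what load_config returns for config_app.yml.
    from tokenizer import load_config

    cfg = load_config("config_app.yml")
    print(cfg["tokenizer_file_path"])   # model/hi_tokenizer_regex.json
    print(cfg["regex_string"][:20])     # note: the r""" prefix is part of the stored value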
requirements.txt ADDED
@@ -0,0 +1,86 @@
+ aiofiles==23.2.1
+ annotated-types==0.7.0
+ anyio==4.8.0
+ appnope==0.1.4
+ asttokens==3.0.0
+ certifi==2024.12.14
+ charset-normalizer==3.4.1
+ click==8.1.8
+ comm==0.2.2
+ contourpy==1.3.0
+ cycler==0.12.1
+ debugpy==1.8.11
+ decorator==5.1.1
+ exceptiongroup==1.2.2
+ executing==2.1.0
+ fastapi==0.115.6
+ ffmpy==0.5.0
+ filelock==3.16.1
+ fonttools==4.55.3
+ fsspec==2024.12.0
+ gradio==4.44.1
+ gradio_client==1.3.0
+ h11==0.14.0
+ httpcore==1.0.7
+ httpx==0.28.1
+ huggingface-hub==0.27.1
+ idna==3.10
+ importlib_metadata==8.5.0
+ importlib_resources==6.5.2
+ ipykernel==6.29.5
+ ipython==8.18.1
+ jedi==0.19.2
+ Jinja2==3.1.5
+ jupyter_client==8.6.3
+ jupyter_core==5.7.2
+ kiwisolver==1.4.7
+ markdown-it-py==3.0.0
+ MarkupSafe==2.1.5
+ matplotlib==3.9.4
+ matplotlib-inline==0.1.7
+ mdurl==0.1.2
+ nest-asyncio==1.6.0
+ numpy==2.0.2
+ orjson==3.10.14
+ packaging==24.2
+ pandas==2.2.3
+ parso==0.8.4
+ pexpect==4.9.0
+ pillow==10.4.0
+ platformdirs==4.3.6
+ prompt_toolkit==3.0.48
+ psutil==6.1.1
+ ptyprocess==0.7.0
+ pure_eval==0.2.3
+ pydantic==2.10.5
+ pydantic_core==2.27.2
+ pydub==0.25.1
+ Pygments==2.19.1
+ pyparsing==3.2.1
+ python-dateutil==2.9.0.post0
+ python-multipart==0.0.20
+ pytz==2024.2
+ PyYAML==6.0.2
+ pyzmq==26.2.0
+ regex==2024.11.6
+ requests==2.32.3
+ rich==13.9.4
+ ruff==0.9.0
+ semantic-version==2.10.0
+ shellingham==1.5.4
+ six==1.17.0
+ sniffio==1.3.1
+ stack-data==0.6.3
+ starlette==0.41.3
+ tomlkit==0.12.0
+ tornado==6.4.2
+ tqdm==4.67.1
+ traitlets==5.14.3
+ typer==0.15.1
+ typing_extensions==4.12.2
+ tzdata==2024.2
+ urllib3==2.3.0
+ uvicorn==0.34.0
+ wcwidth==0.2.13
+ websockets==12.0
+ zipp==3.21.0
tokenizer.py ADDED
@@ -0,0 +1,212 @@
+ import yaml
+ import regex as re
+ from tqdm import tqdm
+ import gc
+ import json
+
+
+ def load_config(config_file_path: str = "config.yml"):
+     with open(config_file_path, "r") as f:
+         config = yaml.safe_load(f)
+     return config
+
+ def get_input_text(config: dict) -> str:
+     with open(config["input_file_info"]["file_path"], 'r', encoding='utf-8') as _f:
+         hi_text = [line.strip() for line in _f.readlines()]
+
+     hi_text_abridged = hi_text[:int(config["input_file_info"]["input_file_limit"])]
+     hi_text_abridged = '\n'.join(hi_text_abridged)
+
+     if config["input_file_info"]["print_text"]:
+         print(" Sample text: ", hi_text_abridged[:10])
+
+     return hi_text_abridged
+
+ def get_stats(ids, counts=None):
+     counts = {} if counts is None else counts
+     for pair in zip(ids, ids[1:]):
+         counts[pair] = counts.get(pair, 0) + 1
+     return counts
+
+ def merge(ids, pair, idx):
+     newids = []
+     i = 0
+     while i < len(ids):
+         if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
+             newids.append(idx)
+             i += 2
+         else:
+             newids.append(ids[i])
+             i += 1
+     return newids
+
+ def stoi(text: str, config: dict) -> list:
+     # tokenize the text
+     if config["regex_string"] and len(config["regex_string"]) > 0:
+         print("Using regex string: ", config["regex_string"])
+         tokens = re.findall(config["regex_string"], text)
+         # Convert tokens to bytes and then to integers
+         return [b for token in tokens for b in token.encode('utf-8')]
+     else:
+         print("Using default tokenizer")
+         # Instead of splitting, we'll preserve spaces by encoding them directly
+         return [b for ch in text for b in ch.encode('utf-8')]
+
+
+ def encode(text, merges, config: dict):
+     """
+     Encode text into tokens using the learned merges
+     """
+     ids = stoi(text, config)
+
+     sorted_merges = sorted(merges.items(), key=lambda x: x[1])
+     for (p1, p2), idx in sorted_merges:
+         ids = merge(ids, (p1, p2), idx)
+
+     return ids
+
+ def decode(ids, merges, config: dict):
+     """
+     Decode tokens back to text using the learned merges
+     """
+     # Create reverse mapping from token to pair
+     reverse_merges = {idx: pair for pair, idx in merges.items()}
+
+     # Expand all tokens recursively
+     def expand_token(token):
+         if token < 256:  # Base case: token is a byte
+             return bytes([token])
+
+         # Recursive case: expand the token into its constituent pair
+         pair = reverse_merges[token]
+         return expand_token(pair[0]) + expand_token(pair[1])
+
+     # Expand all tokens and concatenate
+     bytes_list = [expand_token(id) for id in ids]
+     bytes_data = b''.join(bytes_list)
+
+     # Convert bytes back to text
+     try:
+         return bytes_data.decode('utf-8')
+     except UnicodeDecodeError:
+         return "[DECODE_ERROR]"
+
+ class Tokenizer:
+     def __init__(self, merges=None, config: dict = None):
+         self.merges = merges or {}
+         self.config = config
+
+     def save(self, file_path):
+         # Convert tuple keys to strings for JSON serialization
+         serializable_merges = {f"{k[0]},{k[1]}": v for k, v in self.merges.items()}
+         with open(file_path, 'w', encoding='utf-8') as f:
+             json.dump(serializable_merges, f)
+
+     @classmethod
+     def load(cls, file_path):
+         with open(file_path, 'r', encoding='utf-8') as f:
+             serialized_merges = json.load(f)
+         # Convert string keys back to tuples
+         merges = {tuple(map(int, k.split(','))): v
+                   for k, v in serialized_merges.items()}
+
+         return cls(merges)
+
+     def encode(self, text):
+         return encode(text, self.merges, self.config)
+
+     def decode(self, ids):
+         return decode(ids, self.merges, self.config)
+
+ def train_tokenizer(config: dict) -> dict:
+     # get input text
+     hi_text = get_input_text(config)
+
+     # convert string to tokens
+     tokens = stoi(hi_text, config)
+     initial_len = len(tokens)
+     print("Tokens length (initial): ", initial_len, " tokens unique: ", len(set(tokens)))
+     print("Example tokens: ", ord('क'), chr(2325), ord("।"), chr(2404))
+
+     print("Training tokenizer....")
+     num_merges = config["vocab_size"] - 256
+     original_token = tokens
+
+     merges = {}
+     pbar = tqdm(range(num_merges), desc="Training tokenizer")
+     output_file = config["output_file_info"]["file_path"]
+
+     for i in pbar:
+         # Get statistics of the tokens
+         stats = get_stats(tokens)
+         # Get the most frequent pair
+         pair = max(stats, key=stats.get)
+         # Get the index of the new token
+         idx = 256 + i
+
+         # Merge the pair
+         tokens = merge(tokens, pair, idx)
+         merges[pair] = idx
+
+         # Show progress
+         if (i + 1) % 100 == 0:
+             current_ratio = initial_len / len(tokens)
+             pbar.write(f"Iteration {i+1}: compression ratio: {current_ratio:.2f}X")
+
+         # Garbage collection periodically
+         if (i + 1) % 1000 == 0:
+             gc.collect()
+
+         # Save intermediate merges
+         if (i + 1) % 1000 == 0:
+             temp_tokenizer = Tokenizer(merges)
+             temp_tokenizer.save(f"{output_file}.checkpoint")
+
+     print("Training tokenizer completed")
+     final_tokenizer = Tokenizer(merges)
+     final_tokenizer.save(f"{output_file}")
+
+     print("\n=== Final Statistics ===")
+     print(f"Vocabulary size: {config['vocab_size']}")
+     print(f"Initial tokens: {initial_len:,}")
+     print(f"Final tokens: {len(tokens):,}")
+     print(f"Initial bytes: {initial_len * 4:,}")
+     print(f"Final bytes: {len(tokens) * 4:,}")
+     print(f"Token compression ratio: {initial_len / len(tokens):.2f}X")
+     print(f"Byte compression ratio: {(initial_len * 4) / (len(tokens) * 4):.2f}X")
+     print(f"Saved tokenizer to: {output_file}")
+
+     return merges
+
+ def load_tokenizer(config: dict) -> Tokenizer:
+     "load the tokenizer from the json file"
+     with open(config["output_file_info"]["file_path"], 'r', encoding='utf-8') as f:
+         serialized_merges = json.load(f)
+
+     merges = {tuple(map(int, k.split(','))): v
+               for k, v in serialized_merges.items()}
+
+     return Tokenizer(merges, config)
+
+ if __name__ == "__main__":
+
+     # TRAIN TOKENIZER
+     config = load_config()
+     merges = train_tokenizer(config)
+     print("Merges: ", merges)
+
+     # USE TOKENIZER
+     # tokenizer = load_tokenizer(config)
+     # test_text = config["test_text"]
+
+     # print("Test text: ", test_text)
+     # print("Encoded text: ", tokenizer.encode(test_text))
+     # decoded = tokenizer.decode(tokenizer.encode(test_text))
+     # print("Decoded text: ", decoded)
+
+     # print(f"Successful roundtrip: {test_text == decoded}")
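
To make the merge mechanics in tokenizer.py concrete, here is a toy walk-through of get_stats/merge plus an encode/decode round trip with a hand-built merge table; the byte values and the merge (104, 105) -> 256 are illustrative and unrelated to the trained hi_tokenizer_regex.json:

    # Toy example of the BPE primitives defined above.
    from tokenizer import get_stats, merge, Tokenizer

    ids = [104, 105, 104, 105, 104]       # raw UTF-8 byte stream for "hihih"
    stats = get_stats(ids)                # {(104, 105): 2, (105, 104): 2}
    pair = max(stats, key=stats.get)      # most frequent pair -> (104, 105)
    print(merge(ids, pair, 256))          # [256, 256, 104]

    # Round trip through the Tokenizer class; an empty regex_string selects the
    # byte-level fallback path in stoi().
    tok = Tokenizer(merges={(104, 105): 256}, config={"regex_string": ""})
    enc = tok.encode("hihih")             # [256, 256, 104]
    print(enc, tok.decode(enc))           # [256, 256, 104] hihih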