Rasmus Lellep committed
Commit 76b1ec5
1 Parent(s): 9e93eb6

add loader

kuidastaltsutadalaamat/.gitignore ADDED
@@ -0,0 +1,162 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
kuidastaltsutadalaamat/LICENSE ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 TartuNLP

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
kuidastaltsutadalaamat/README.md ADDED
@@ -0,0 +1,2 @@
# Kuidas taltsutada laamat
Implementation of LLM continued training and inference.
kuidastaltsutadalaamat/aux.py ADDED
@@ -0,0 +1,212 @@
#!/usr/bin/env python3

import numpy as np
import pickle
import re
import sys

from datetime import datetime


def log(msg, accelerator=None, all_threads=False):
    if accelerator is not None and all_threads:
        report_proc = f" ({accelerator.process_index+1}/{accelerator.num_processes})"
    else:
        report_proc = ""

    if accelerator is None or accelerator.is_main_process or all_threads:
        sys.stderr.write(str(datetime.now()) + report_proc + ": " + msg + '\n')


def _same_line_log(msg, len_to_del=0):
    """if sys.stderr.isatty():
        if len_to_del > 0:
            sys.stderr.write("\b" * len_to_del)

        new_len = len(msg)

        sys.stderr.write(msg)
        sys.stderr.flush()

        return new_len
    else:"""
    log(msg)


def debug(msg):
    pass
    ### log("\n(DEBUG) " + msg)


def maybe_convert(value):
    try:
        return int(value)
    except (ValueError, TypeError):
        try:
            return float(value)
        except (ValueError, TypeError):
            return value


def get_changed_config(conf, args):
    arg_dict = args.to_dict()

    for kwarg in arg_dict:
        if hasattr(conf, kwarg) and arg_dict[kwarg] is not None:
            setattr(conf, kwarg, maybe_convert(arg_dict[kwarg]))

    return conf


class SameLineLogger:
    def __init__(self, epoch_len, epoch_num, data_state):
        self.epoch_len = epoch_len
        self.epoch_num = epoch_num
        self.start_global_step = epoch_len * data_state.epoch_idx + data_state.elem_idx

        self.totalx = epoch_len * epoch_num

        self.log_after = []
        self.log_len = 0

        self.start_time = datetime.now()

    def line_start(self):
        _same_line_log(str(datetime.now()) + ": training batches ")

    def step(self, global_batch_idx, epoch_batch_idx, epoch_idx, loss, lr, grad):
        passed_time = datetime.now() - self.start_time
        time_per_batch = passed_time / (global_batch_idx - self.start_global_step)
        prediction = time_per_batch * (self.totalx - global_batch_idx)

        msg = f"{epoch_batch_idx} / {self.epoch_len}, epoch {epoch_idx + 1} / {self.epoch_num}, loss={loss}, avg {time_per_batch}/iter, {prediction} to finish, LR={lr:.2e}, grad={grad:.2e}  "

        new_len = _same_line_log(msg, self.log_len)

        self.log_len = new_len

    def line_break(self):
        sys.stderr.write("\n")


class CmdlineArgs:
    def __init__(self,
                 description,
                 pos_arg_list=None,
                 pos_arg_types=None,
                 kw_arg_dict=None,
                 input_args=None):

        self.description = description

        self.raw_pos_arg_list = pos_arg_list if pos_arg_list is not None else []
        self.raw_pos_arg_types = pos_arg_types \
            if pos_arg_types is not None \
            else [None] * len(self.raw_pos_arg_list)

        self.kw_arg_dict_with_defaults = kw_arg_dict if kw_arg_dict is not None else {}

        kw_vals, cmdline_values = self._to_kwargs(sys.argv[1:] if input_args is None else input_args)

        self._maybe_help(cmdline_values)

        self._handle_positional_args(cmdline_values)

        self._handle_keyword_args(kw_vals)

    @staticmethod
    def _to_kwargs(arg_list):
        key_args = dict(raw_entry.lstrip("-").split("=") for raw_entry in arg_list if "=" in raw_entry)
        filtered_arg_list = [arg for arg in arg_list if "=" not in arg]

        return key_args, filtered_arg_list

    def _handle_keyword_args(self, kw_vals):
        for kw in self.kw_arg_dict_with_defaults:
            if kw in kw_vals:
                val = self._convert_kw(kw_vals, kw)
                del kw_vals[kw]
            else:
                val = self.kw_arg_dict_with_defaults[kw]

            setattr(self, kw, val)

        if kw_vals:
            extra_keys = ", ".join(kw_vals.keys())
            msg = f"command-line keyword arguments '{extra_keys}' are not recognized."

            self._help_message_and_die(extra=msg)

    def _convert_kw(self, kw_vals, kw):
        if self.kw_arg_dict_with_defaults[kw] is None:
            return kw_vals[kw]
        else:
            this_typ = type(self.kw_arg_dict_with_defaults[kw])

            try:
                return this_typ(kw_vals[kw])
            except ValueError:
                self._help_message_and_die(extra=f"could not convert '{kw_vals[kw]}' to '{this_typ}'")

    def _sanity_check_pos_args(self, cmdline_values):
        cmdline_len = len(cmdline_values)

        if cmdline_len < len(self.raw_pos_arg_list):
            self._help_message_and_die(
                extra=f"positional arguments missing: {', '.join(self.raw_pos_arg_list[cmdline_len:])}")

        if cmdline_len > len(self.raw_pos_arg_list):
            self._help_message_and_die(
                extra=f"superfluous positional arguments: {', '.join(cmdline_values[len(self.raw_pos_arg_list):])}")

    def _handle_positional_args(self, cmdline_values):
        self._sanity_check_pos_args(cmdline_values)

        for arg, val, typ in zip(self.raw_pos_arg_list, cmdline_values, self.raw_pos_arg_types):
            try:
                val = val if typ is None else typ(val)
            except ValueError:
                self._help_message_and_die(extra=f"could not convert '{val}' to '{typ}'")

            setattr(self, arg, val)

    def _maybe_help(self, cmdline_values):
        if len(cmdline_values) == 1 and cmdline_values[0] in {"--help", "-h", "-?"}:
            self._help_message_and_die()

    def _help_message_and_die(self, extra=None):
        sys.stderr.write("Help message: " + self.description + "\n")

        if self.raw_pos_arg_list:
            args_descr = ", ".join([f"'{arg}' ({typ.__name__ if typ is not None else 'any'})"
                                    for arg, typ in zip(self.raw_pos_arg_list, self.raw_pos_arg_types)])

            sys.stderr.write(f"Positional arguments: {args_descr}\n")

        if self.kw_arg_dict_with_defaults:
            kw_descr = ", ".join([f"'{kw}' (default: {val})"
                                  for kw, val in self.kw_arg_dict_with_defaults.items()])

            sys.stderr.write(f"Keyword arguments: {kw_descr}\n")

        if extra is not None:
            sys.stderr.write("Error: " + extra + "\n")

        sys.stderr.write("\n")
        sys.exit(-1)

    def to_dict(self):
        return {k: v for k, v in self.__dict__.items()
                if k not in {'description', 'raw_pos_arg_list', 'raw_pos_arg_types', 'kw_arg_dict_with_defaults'}}

    def __str__(self):
        return str(self.to_dict())

    def __repr__(self):
        return self.__str__()


if __name__ == "__main__":
    for dname in sys.argv[1:]:
        d = np.load(dname + "/custom_checkpoint_1.pkl", allow_pickle=True)
        p = pickle.loads(d['custom_checkpoint_1/data.pkl'])
        print(dname, p)
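
Editor's note: a minimal usage sketch for the CmdlineArgs helper above; the tool description, argument names and values are hypothetical and only illustrate the key=value convention the parser expects (types are inferred from the defaults).

# Hypothetical example, not part of the commit.
from kuidastaltsutadalaamat.aux import CmdlineArgs, log

args = CmdlineArgs("demo tool",
                   pos_arg_list=["mdl_id"],
                   kw_arg_dict={"batch_size": 8, "lr": 1e-5},
                   input_args=["models/llama3.2-1b", "batch_size=16"])
log(f"Launched as {args}")  # batch_size is converted to int 16, lr keeps its default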
kuidastaltsutadalaamat/data.py ADDED
@@ -0,0 +1,142 @@
#!/usr/bin/env python3
from .promptops import *

import json
import sys

from random import shuffle

from torch.utils.data import Dataset as TorchDataset, DataLoader

from .aux import log


def tokenize_str(tokenizer, entry, add_eos=True, max_len=3000, for_inf=False):
    if for_inf:
        tokens = tokenizer(
            entry,
            truncation=True,
            max_length=max_len,
            return_attention_mask=True,
            return_tensors="pt"
        )
    else:
        tokens = tokenizer(
            entry,
            truncation=True,
            max_length=max_len,
            return_attention_mask=True
        )

    if add_eos:
        tokens['attention_mask'].append(1)
        tokens['input_ids'].append(tokenizer.eos_token_id)

    return tokens

"""
Load texts into memory and allow looping through them,
returning tokenized tensors.

Currently there is no support for text data that does not fit into memory;
that still needs to be added. Or do HF datasets have something out of the box?
"""
class LazyTokenizingDataset(TorchDataset):
    def __init__(self, texts, tokenizer, max_length=512, prompt_format="raw"):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.prompt_format = prompt_format

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # Return plain Python lists; let the collator pad & build labels.
        entry = self.texts[idx]

        prompt = prep_prompt(entry, self.prompt_format)

        return tokenize_str(self.tokenizer, prompt)


class LazyTokenizingInferenceDataset(TorchDataset):
    def __init__(self, texts, tokenizer, prompt_format, max_length=512, debug=False):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.prompt_format = prompt_format
        self.debug = debug

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        entry = self.texts[idx]

        prompt = prep_prompt(entry, self.prompt_format, inference=True)
        result = tokenize_str(self.tokenizer, prompt, add_eos=False, for_inf=True)

        if self.debug:
            log(f"Input: {prompt}")
            log(f"Tokenized: {result}")

        return result


def read_input(path, formt):
    if path is None:
        log("Reading from STDIN")
        fh = sys.stdin
    else:
        fh = open(path, 'r')

    if formt == PF_RAW:
        result = [fh.read()]
    elif formt == PF_RAWLINES:
        result = fh.readlines()
    else:
        result = json.load(fh)

    return result


def get_data_loader(path, prompt_format, tokenizer, debug=False):
    inputs = read_input(path, prompt_format)

    dataset = LazyTokenizingInferenceDataset(inputs, tokenizer, prompt_format, debug=debug)

    """
    data_coll = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
        pad_to_multiple_of=None,  # helps performance; set None if you prefer exact lengths
    )

    data_loader = DataLoader(dataset, collate_fn=data_coll, batch_size=1)
    """

    return dataset


def load_training_data(path, tokenizer, cmd_args):
    with open(path, "r") as f:
        data = json.load(f)

    train_set_iter = LazyTokenizingDataset(data, tokenizer, cmd_args.max_length, cmd_args.prompt_format)

    return train_set_iter


if __name__ == '__main__':
    all_data = []

    for input_file in sys.argv[1:]:
        with open(input_file, "r") as f:
            this_data = json.load(f)
        all_data += this_data

    shuffle(all_data)

    json.dump(all_data, sys.stdout)
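
Editor's note: a hedged sketch of driving LazyTokenizingDataset above; promptops.py is not part of this commit, so the "raw" prompt-format string and the plain-string entry are assumptions, and the model path is a placeholder.

# Illustrative only; entry structure and format string are assumptions.
from transformers import AutoTokenizer
from kuidastaltsutadalaamat.data import LazyTokenizingDataset

tokenizer = AutoTokenizer.from_pretrained("models/llama3.2-1b")   # placeholder path
texts = ["Tere, see on näitelause."]                              # one raw training text
ds = LazyTokenizingDataset(texts, tokenizer, max_length=512, prompt_format="raw")
print(len(ds), ds[0]["input_ids"][:10])                           # token ids, EOS appended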
kuidastaltsutadalaamat/inference.py ADDED
@@ -0,0 +1,170 @@
#!/usr/bin/env python3

from .promptops import *

from .aux import CmdlineArgs, log
from .data import get_data_loader
from .trainllm import env_stuff, load_model, load_tokenizer


import sys
import torch
import json
import torch.distributed as dist

from accelerate import Accelerator

from datetime import datetime

"""
This currently assumes a batch size of 1. With larger batches the padding tokens went
into the decoder. Right-padding as a solution?
"""
def llm_generate(model, tokenizer, tok_batch, debug=False, max_len=2000):
    tok_batch['input_ids'] = tok_batch['input_ids'].to(model.device)
    tok_batch['attention_mask'] = tok_batch['attention_mask'].to(model.device)
    start_time = datetime.now()

    if debug:
        log(f"Tokenized input: {tok_batch['input_ids']}")

    raw_output_toks = model.generate(**tok_batch, tokenizer=tokenizer,
                                     do_sample=False, num_beams=4, max_length=max_len, top_p=None, temperature=None,
                                     eos_token_id=[tokenizer.eos_token_id,
                                                   tokenizer.convert_tokens_to_ids("<|reserved_special_token_14|>")])

    #clean_output_toks = remove_prompt_from_output(tok_batch['attention_mask'], raw_output_toks, filler_id)
    assert len(raw_output_toks) == 1, "Only batch size=1 supported %-("
    gen_idx = len(tok_batch['attention_mask'][0])

    if debug:
        log(f"Full tokenized output: {raw_output_toks[0]}")
        log(f"Full tokens: {tokenizer.convert_ids_to_tokens(raw_output_toks[0])}")
        full_out = tokenizer.batch_decode([raw_output_toks[0]], skip_special_tokens=True)
        log(f"Full text: {full_out[0]}")

    clean_output_toks = raw_output_toks[0][gen_idx:]
    clean_outputs = tokenizer.batch_decode([clean_output_toks], skip_special_tokens=True)

    if debug:
        log(f"Pruned tokenized output: {clean_output_toks}")
        log(f"Pruned tokens: {tokenizer.convert_ids_to_tokens(clean_output_toks)}")
        log(f"Cleaned output: {clean_outputs[0]}")

    end_time = datetime.now()
    log(f"This took: {end_time - start_time}")

    return clean_outputs


def reassemble_multi(list_of_lists):
    result = []

    for gen_idx in range(len(list_of_lists[0])):
        for i in range(len(list_of_lists)):
            if gen_idx < len(list_of_lists[i]):
                result.append(list_of_lists[i][gen_idx])

    return result


def predict(model, tokenizer, data_loader, accel, multi=False, debug=False, max_len=2000):
    outs_final = []

    with torch.no_grad():
        for idx, batch in enumerate(data_loader):
            if idx % accel.num_processes == accel.process_index:
                start_time = datetime.now()
                outputs = llm_generate(model, tokenizer, batch, debug=debug, max_len=max_len)
                end_time = datetime.now()
                log(f"Generated for {idx} in proc {accel.process_index} in {end_time - start_time}")
                outs_final += outputs

    if multi:
        accel.wait_for_everyone()

        rank0_buffer = [None] * accel.num_processes if accel.is_main_process else None
        dist.gather_object(outs_final, rank0_buffer, dst=0)
        if accel.is_main_process:
            outs_final = reassemble_multi(rank0_buffer)
        else:
            outs_final = None

    return outs_final


def _cmdline_args():
    inputs = sys.argv[1:]

    description = """Predict output for an input via prompting"""

    pos_args = ["mdl_id"]

    #post-process the arguments
    args = CmdlineArgs(description, pos_args, input_args=inputs,
                       kw_arg_dict={"debug": False,
                                    "input_file": "none",
                                    "output_file": "none",
                                    "multiproc": False,
                                    "max_len": 2000,
                                    "prompt_format": PF_ALPACA})

    if args.input_file == "none":
        args.input_file = None
    if args.output_file == "none":
        args.output_file = None

    log(f"Launched as {args}")

    return args


def save_all(outputs, args, acc):
    if acc.is_main_process:
        if args.output_file is None:
            log("Writing to STDOUT")
            out_fh = sys.stdout
        else:
            out_fh = open(args.output_file, "w")

        if args.prompt_format in {PF_RAW, PF_RAWLINES}:
            for line in outputs:
                out_fh.write(line + "\n")
        else:
            json.dump(outputs, out_fh)


def and_i_called_this_function_do_main_too():
    args = _cmdline_args()

    if args.multiproc:
        env_stuff()

    acc = Accelerator()
    device = acc.device

    log(f"Device: {device}.", accelerator=acc)

    if not args.multiproc and not acc.is_main_process:
        log("Not launched in multi-processing mode, exiting non-main process.")
        sys.exit(0)

    tokenizer = load_tokenizer(args.mdl_id, acc)

    data_loader = get_data_loader(args.input_file, args.prompt_format, tokenizer, debug=args.debug)

    model = load_model(args.mdl_id, device, acc, attention="eager")
    model.eval()

    log(f"Device: {model.device}.", accelerator=acc)

    log("Model loaded, starting to generate")
    outputs = predict(model, tokenizer, data_loader, acc, multi=args.multiproc, debug=args.debug, max_len=args.max_len)

    save_all(outputs, args, acc)

    log("Done")


if __name__ == "__main__":
    and_i_called_this_function_do_main_too()
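
Editor's note: a minimal single-prompt sketch of llm_generate above, assuming standard Hugging Face loaders in place of trainllm.load_model/load_tokenizer (trainllm.py is not in this commit) and a Llama-3-style tokenizer that defines <|reserved_special_token_14|>; the model path is a placeholder.

# Illustrative only; loading is simplified compared to the module's own main().
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from kuidastaltsutadalaamat.inference import llm_generate

tok = AutoTokenizer.from_pretrained("models/llama3.2-1b")
model = AutoModelForCausalLM.from_pretrained("models/llama3.2-1b", torch_dtype=torch.bfloat16)
batch = tok("Tõlgi inglise keelde: tere hommikust!", return_tensors="pt")  # batch size 1, as assumed above
print(llm_generate(model, tok, batch, max_len=200)[0])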
kuidastaltsutadalaamat/legacy/accel.py ADDED
@@ -0,0 +1,328 @@
import os

import torch

from accelerate import Accelerator
from datetime import datetime
from transformers import get_scheduler

from aux import SameLineLogger, log
from data import DataState, BatchingIterator
from modelops import save_all_models, report_devices


def chain_params(coupling_specs):
    for spec in coupling_specs:
        yield from spec.model.parameters()


class TrainLossList:
    def __init__(self):
        self.data = []

    def append(self, loss_val, sub_batch_idx, epoch_batch_idx, _epoch_idx):
        self.data.append((loss_val, sub_batch_idx, epoch_batch_idx, _epoch_idx))

    def state_dict(self):
        return {'data': self.data}

    def load_state_dict(self, state_dict):
        self.data = state_dict['data']


class SwitchingAccelerator:
    def __init__(self, train_set, train_kwargs, model, tokenizer, preinit_acc=None):
        self.kwargs = train_kwargs
        self.train_set_iter = BatchingIterator(train_set, self.kwargs.batch_size, tokenizer, train_kwargs.max_length)

        self.model = model
        self.tokenizer = tokenizer

        self.train_loss_list = TrainLossList()
        self.data_state = DataState(epoch_idx=0)

        self._init_acc_and_stuff(preinit_acc)

        self._init_time_keepers()

    def _init_time_keepers(self):
        if self.kwargs.log_steps < 0 and self.accelerator.is_main_process:
            t = datetime.now()
            self._tk_zero = t - t

            self._tk_stats = {}
            self._tk_time = {}

    def _add_timekeeper(self, msg):
        if self.kwargs.log_steps < 0 and self.accelerator.is_main_process:
            self._tk_stats[msg] = []
            self._tk_time[msg] = None

    def _add_timekeepers(self, msgs):
        for msg in msgs:
            self._add_timekeeper(msg)

    def _tk_start(self, msg):
        if self.kwargs.log_steps < 0 and self.accelerator.is_main_process:
            assert self._tk_time[msg] is None

            self._tk_time[msg] = datetime.now()

    def _tk_stop(self, msg):
        if self.kwargs.log_steps < 0 and self.accelerator.is_main_process:
            assert self._tk_time[msg] is not None

            this_time = datetime.now() - self._tk_time[msg]
            self._tk_time[msg] = None
            self._tk_stats[msg].append(this_time)

            log(f"{msg} took {this_time}, avg time: " +
                f" {sum(self._tk_stats[msg], self._tk_zero) / len(self._tk_stats[msg])}" +
                f" over {len(self._tk_stats[msg])} samples")

    def __handle_accum(self):

        assert self.kwargs.batch_size % (self.accelerator.num_processes * self.kwargs.nr_sents_per_gpu) == 0,\
            "batch size must be divisible by number of processes and number of segments per GPU"

        accum_steps = int((self.kwargs.batch_size / self.accelerator.num_processes) / self.kwargs.nr_sents_per_gpu)
        self.accelerator.gradient_accumulation_steps = accum_steps

        log(f"Nr sents/GPU: {self.kwargs.nr_sents_per_gpu}, accum steps: {accum_steps}, " +
            f"nr. procs: {self.accelerator.num_processes}, batch size: {self.kwargs.batch_size}",
            accelerator=self.accelerator)

    def ___get_train_scalars(self):
        epoch_len = len(self.train_set_iter)
        train_len = epoch_len * self.kwargs.epochs

        num_warmup = 0  #int(train_len * 0.01)

        log(f"Warmup steps: {num_warmup}, epoch len: {epoch_len}, train len: {train_len}", accelerator=self.accelerator)

        return train_len, num_warmup

    def __init_opt_lr_and_what_else(self):
        train_len, num_warmup = self.___get_train_scalars()

        opt = torch.optim.AdamW(self.model.parameters(), lr=self.kwargs.lr)

        numtr = train_len * self.accelerator.num_processes
        lr_scheduler = get_scheduler("linear", optimizer=opt, num_warmup_steps=num_warmup, num_training_steps=numtr)

        self.optimizer, self.lr_scheduler, self.model = self.accelerator.prepare(opt, lr_scheduler, self.model)

        self.accelerator.register_for_checkpointing(self.data_state, self.train_loss_list)

    def _init_acc_and_stuff(self, preinit_acc=None):
        #self.accelerator = Accelerator(gradient_accumulation_steps=self.kwargs.accum_steps, kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)])

        if preinit_acc is None:
            self.accelerator = Accelerator()
        else:
            self.accelerator = preinit_acc

        self.__handle_accum()

        self.__init_opt_lr_and_what_else()

        if self.kwargs.continue_training:
            self.accelerator.load_state(self.kwargs.mdl_id)
            log(f"Reloaded data state: {self.data_state}", accelerator=self.accelerator)

    def train(self, dry_run=False):
        try:
            self._main_loop(dry_run)
        except Exception as e:
            #in multiprocess scenarios it is hard to read the stack trace, so just show one:
            if self.accelerator.is_main_process:
                raise e

        self.accelerator.wait_for_everyone()

        unwr_coupled_model = self.accelerator.unwrap_model(self.model)

        return unwr_coupled_model

    def _prepare_inputs(self, batch, sub_batch_idx, sub_batch_size, proc_batch_size):
        from_proc_idx = proc_batch_size * self.accelerator.process_index + sub_batch_size * sub_batch_idx
        to_proc_idx = from_proc_idx + sub_batch_size

        #log(f"----> DEBUG for sub_b idx {sub_batch_idx}, proc {self.accelerator.process_index}: {from_proc_idx}:{to_proc_idx}")

        return {k: batch[k][from_proc_idx:to_proc_idx].to(self.accelerator.device) for k in batch}

    def _get_split_batch_params(self):
        batch_nr_snts = self.kwargs.batch_size

        assert batch_nr_snts % self.accelerator.num_processes == 0, "Batch size must be divisible by number of processes."

        proc_batch_nr_snts = batch_nr_snts // self.accelerator.num_processes

        sub_batch_size = self.kwargs.nr_sents_per_gpu

        nr_steps = -(proc_batch_nr_snts // -sub_batch_size)

        #log(f"--> DEBUG: sub_batch {sub_batch_size} X steps {nr_steps} ~ {proc_batch_nr_snts} ({batch_nr_snts} / {self.accelerator.num_processes})", accelerator=self.accelerator)
        return sub_batch_size, nr_steps, proc_batch_nr_snts

    def _report_mem_every_once_in_a_while(self, sub_batch_idx, epoch_batch_idx, batch_dim):
        if sub_batch_idx == 0:
            report_devices(f"training memory usage (batch size: {self.kwargs.batch_size} / {batch_dim[1]}",
                           self.accelerator, self.model)

    def _main_loop(self, dry_run):
        if self.accelerator.is_main_process:
            logger = SameLineLogger(len(self.train_set_iter), self.kwargs.epochs, self.data_state)
            logger.line_start()
        else:
            logger = None

        self.model.train()
        self.train_set_iter.thats_where(self.data_state)

        tks = "full_batch", "prep_inputs", "forward", "backward", "upd_step"
        tk_batch, tk_prep, tk_fw, tk_bk, tk_step = tks
        self._add_timekeepers(tks)

        with self.accelerator.accumulate(self.model):
            for _epoch_idx in range(self.data_state.epoch_idx, self.kwargs.epochs):
                for batch, epoch_batch_idx in self.train_set_iter:
                    if dry_run:
                        log(f"Dry run, batch width: {batch['input_ids'].size()}")
                    else:
                        self._report_mem_every_once_in_a_while(0, epoch_batch_idx, batch['input_ids'].size())
                        sub_batch_size, nr_steps, proc_batch_size = self._get_split_batch_params()

                        self._tk_start(tk_batch)

                        loss = None
                        for sub_batch_idx in range(nr_steps):
                            self._tk_start(tk_prep)  ########
                            inputs = self._prepare_inputs(batch, sub_batch_idx, sub_batch_size, proc_batch_size)

                            inputs['labels'] = inputs['input_ids'].copy()
                            self._tk_stop(tk_prep)  ########

                            self._tk_start(tk_fw)  ########
                            outputs = self.model(**inputs)

                            loss = outputs.loss
                            self._tk_stop(tk_fw)  ########

                            self.train_loss_list.append(loss.item(), sub_batch_idx, epoch_batch_idx, _epoch_idx)

                            self._tk_start(tk_bk)  ########
                            self.accelerator.backward(loss)
                            self._tk_stop(tk_bk)  ########

                            self._tk_start(tk_step)  ########
                            self.optimizer.step()
                            self.lr_scheduler.step()
                            self.optimizer.zero_grad()
                            self._tk_stop(tk_step)  ########

                        self._tk_stop(tk_batch)

                        #assert self.accelerator.sync_gradients, "It is not time to sync gradients yet."
                        self._step_and_perhaps_save(logger, epoch_batch_idx, _epoch_idx, float(loss.item()))

        if self.accelerator.is_main_process:
            logger.line_break()

    def get_total_grad(self):
        result = 0
        grad_count = 0
        all_count = 0

        for p in self.model.parameters():
            if p.grad is not None:
                result += p.grad.abs().mean().item()
                grad_count += 1
            all_count += 1

        return result/grad_count if grad_count != 0 else -1

    def _step_and_perhaps_save(self, logger, epoch_batch_idx, epoch_i, loss):
        epoch_len = len(self.train_set_iter)
        global_batch_idx = epoch_batch_idx + epoch_i * epoch_len

        is_end_of_epoch = (epoch_batch_idx == epoch_len)

        if self.accelerator.is_main_process \
                and (epoch_batch_idx % self.kwargs.log_steps == 0 or is_end_of_epoch):
            grad = self.get_total_grad()

            logger.step(global_batch_idx, epoch_batch_idx, epoch_i, loss, self.lr_scheduler.get_last_lr()[0], grad)

        #self.optimizer.zero_grad()

        if (global_batch_idx % self.kwargs.save_steps == 0) or is_end_of_epoch:
            self.accelerator.wait_for_everyone()

            if self.accelerator.is_main_process:
                logger.line_break()
                log(f"Saving at {epoch_batch_idx} steps, epoch {epoch_i + 1} ({global_batch_idx} global steps)", accelerator=self.accelerator)

            self._save_all(global_batch_idx, epoch_i)

            logger.line_start()

    def _save_all(self, global_batch_idx, epoch_i):
        epoch_len = len(self.train_set_iter)

        ckpt_name = (f"checkpoint-e{epoch_i + 1:02}-" +
                     (f"b{global_batch_idx:07}" if (global_batch_idx % epoch_len) else f"full"))

        this_location = os.path.join(self.kwargs.save_location, ckpt_name)
        if os.path.exists(this_location):
            raise FileExistsError(f"Cannot overwrite existing checkpoint {this_location}!")

        self.data_state.copy_from(self.train_set_iter.where_are_we(), epoch_idx=epoch_i)

        model_to_save = self.accelerator.unwrap_model(self.model)

        save_all_models(this_location, model_to_save, self.tokenizer, trainer=self.accelerator)


def test_this_damn_thing():
    # testing
    import torch
    import json
    from torch.optim import AdamW
    from modelops import hf_tok
    from transformers import AutoModelForCausalLM, AutoTokenizer

    mdl_id = "models/llama3.2-1b"

    tokenizer = AutoTokenizer.from_pretrained(mdl_id, token=hf_tok)
    model = AutoModelForCausalLM.from_pretrained(mdl_id, token=hf_tok, torch_dtype=torch.bfloat16)
    with open("tmpx.json", "r") as f:
        training_data_raw = json.load(f)

    optimizer = AdamW(model.parameters(), lr=5e-6)

    print("Initial 0:", optimizer.param_groups[0]['lr'])  # Should be [5e-6]

    scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=2445
    )

    accel = Accelerator()

    p_optimizer, p_lr_scheduler, p_model = accel.prepare(optimizer, scheduler, model)

    print("Initial 1:", p_lr_scheduler.get_last_lr())  # Should be [5e-6]

    """
    for _ in range(2):
        optimizer.step()
        scheduler.step()
        print("Step:", scheduler.get_last_lr())
    """


if __name__ == "__main__":
    test_this_damn_thing()
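
Editor's note: the accumulation arithmetic in __handle_accum and _get_split_batch_params above reduces to the small worked example below; the numbers are illustrative only.

# With an effective batch of 64 sentences, 4 processes and 4 sentences per GPU:
batch_size, num_processes, nr_sents_per_gpu = 64, 4, 4
assert batch_size % (num_processes * nr_sents_per_gpu) == 0
accum_steps = int((batch_size / num_processes) / nr_sents_per_gpu)   # 16 sentences per process -> 4 accumulation steps
nr_steps = -((batch_size // num_processes) // -nr_sents_per_gpu)     # ceiling division, also 4 sub-batches
print(accum_steps, nr_steps)                                         # 4 4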
kuidastaltsutadalaamat/legacy/accel_backup.py ADDED
@@ -0,0 +1,237 @@
"""
import os

import torch

from accelerate import Accelerator, DistributedDataParallelKwargs
from transformers import get_scheduler

from aux import SameLineLogger, log
from data import DataState
from langconv import is_dec_only_llm
from modelops import save_all_models, report_devices
from translate import encode


raise NotImplementedError("This is a backup package, do not run or import from it")


def chain_params(coupling_specs):
    for spec in coupling_specs:
        yield from spec.model.parameters()


class TrainLossList:
    def __init__(self):
        self.data = []

    def append(self, loss_val, src_k, tgt_k):
        self.data.append((loss_val, src_k, tgt_k))

    def state_dict(self):
        return {'data': self.data}

    def load_state_dict(self, state_dict):
        self.data = state_dict['data']


class SwitchingAccelerator:
    def __init__(self, coupling_specs, train_set, train_kwargs):
        self.coupling_specs = coupling_specs

        self.train_set = train_set
        self.kwargs = train_kwargs

        self.is_generative = is_dec_only_llm(self.coupling_specs[0].tokenizer)

        self.train_loss_list = TrainLossList()
        self.data_state = DataState(epoch_idx=0)

        self._init_acc_and_stuff()

    def _init_acc_and_stuff(self):
        #self.accelerator = Accelerator(gradient_accumulation_steps=self.kwargs.accum_steps, kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)])
        #self.accelerator = Accelerator(gradient_accumulation_steps=self.kwargs.accum_steps)
        self.accelerator = Accelerator()

        epoch_len = len(self.train_set)
        train_len = epoch_len * self.kwargs.epochs

        num_warmup = int(train_len * 0.01)

        log(f"Warmup steps: {num_warmup}, epoch len: {epoch_len}, train len: {train_len}", accelerator=self.accelerator)

        opt = torch.optim.AdamW(chain_params(self.coupling_specs), lr=self.kwargs.lr)
        lr_scheduler = get_scheduler("linear", optimizer=opt, num_warmup_steps=num_warmup,
                                     num_training_steps=train_len * self.accelerator.num_processes)
        models = [s.model for s in self.coupling_specs]

        self.optimizer, self.lr_scheduler, *self.models = self.accelerator.prepare(opt, lr_scheduler, *models)

        self.accelerator.register_for_checkpointing(self.lr_scheduler, self.data_state, self.train_loss_list)

        if self.kwargs.continue_training:
            self.accelerator.load_state(self.kwargs.mdl_id)
            log(f"Reloaded data state: {self.data_state}", accelerator=self.accelerator)

    def train(self):
        try:
            self._main_loop()
        except Exception as e:
            #in multi-process scenarios it is hard to read the stack trace, so just show one:
            if self.accelerator.is_main_process:
                raise e

        self.accelerator.wait_for_everyone()

        unwr_coupled_model = self.accelerator.unwrap_model(self.models[0])

        return unwr_coupled_model, self.train_loss_list

    def _split_batch_and_bin_idxs(self, batch_with_idxs):
        if self.is_generative:
            batch, _ = batch_with_idxs
            src_k = 0
            tgt_k = 0
        else:
            batch, src_k, tgt_k, _ = batch_with_idxs
        return batch, src_k, tgt_k

    def _prepare_inputs(self, batch, sub_batch_idx, sub_batch_size, proc_batch_size):
        from_proc_idx = proc_batch_size * self.accelerator.process_index + sub_batch_size * sub_batch_idx
        to_proc_idx = from_proc_idx + sub_batch_size

        #log(f"----> DEBUG for sub_b idx {sub_batch_idx}, proc {self.accelerator.process_index}: {from_proc_idx}:{to_proc_idx}")

        return {k: batch[k][from_proc_idx:to_proc_idx].to(self.accelerator.device) for k in batch}

    def _get_split_batch_params(self, batch):
        batch_nr_snts = batch['input_ids'].size()[0]
        snt_nr_words = batch['input_ids'].size()[1]

        assert batch_nr_snts % self.accelerator.num_processes == 0, "Batch size must be divisible by number of processes."

        proc_batch_nr_snts = batch_nr_snts // self.accelerator.num_processes

        if self.kwargs.nr_snts_in_batch > 0:
            sub_batch_size = self.kwargs.nr_snts_in_batch
        else:
            sub_batch_size = max(1, self.kwargs.nr_words_in_batch // snt_nr_words)
            #log(f"DEBUG: #words/snt {snt_nr_words} X #snt in sub batch {sub_batch_size} = {snt_nr_words*sub_batch_size} ~ {self.kwargs.nr_words_in_batch}", accelerator=self.accelerator)

        nr_steps = -(proc_batch_nr_snts // -sub_batch_size)

        #log(f"--> DEBUG: sub_batch {sub_batch_size} X steps {nr_steps} ~ {proc_batch_nr_snts} ({batch_nr_snts} / {self.accelerator.num_processes})", accelerator=self.accelerator)
        return sub_batch_size, nr_steps, proc_batch_nr_snts

    def _main_loop(self):
        #countdown_till_do_it_once = 0

        if self.accelerator.is_main_process:
            logger = SameLineLogger(len(self.train_set), self.kwargs.epochs)
            logger.line_start()
        else:
            logger = None

        self.models[0].train()
        self.train_set.thats_where(self.data_state)

        for _epoch_idx in range(self.data_state.epoch_idx, self.kwargs.epochs):
            for batch_with_bin_idxs, epoch_batch_idx in self.train_set:
                batch, src_k, tgt_k = self._split_batch_and_bin_idxs(batch_with_bin_idxs)
                sub_batch_size, nr_steps, proc_batch_size = self._get_split_batch_params(batch)

                loss = None

                for sub_batch_idx in range(nr_steps):
                    inputs = self._prepare_inputs(batch, sub_batch_idx, sub_batch_size, proc_batch_size)

                    if self.is_generative:
                        inputs['labels'] = inputs['input_ids']
                        outputs = self.models[0](**inputs)
                    else:
                        encoder_vecs = encode(self.models[src_k], inputs)
                        outputs = self.models[tgt_k](attention_mask=inputs['attention_mask'], labels=inputs['labels'], encoder_outputs=encoder_vecs)

                    loss = outputs.loss

                    #if countdown_till_do_it_once > 0:
                    #    countdown_till_do_it_once -= 1
                    #elif countdown_till_do_it_once == 0:
                    if sub_batch_idx == 5:
                        batch_size = sum([inputs[k].size()[0] * inputs[k].size()[1] for k in 'input_ids labels attention_mask'.split(' ')])
                        report_devices(f"training memory usage (batch size: {batch_size}; inputs:" +
                                       f"snts {inputs['input_ids'].size()[0]} X words {inputs['input_ids'].size()[1]})",
                                       self.accelerator, self.models[0])
                        countdown_till_do_it_once = 0

                    self.train_loss_list.append(loss.item(), src_k, tgt_k)

                    self.accelerator.backward(loss)

                    for k in inputs:
                        inputs[k] = inputs[k].to('cpu')

                self._step_and_perhaps_save(logger, epoch_batch_idx, _epoch_idx, float(loss.item()))

        if self.accelerator.is_main_process:
            logger.line_break()

    def get_total_grad(self):
        result = 0
        grad_count = 0
        all_count = 0

        for p in self.models[0].parameters():
            if p.grad is not None:
                result += p.grad.abs().mean().item()
                grad_count += 1
            all_count += 1

        return result/grad_count if grad_count > 0 else -1

    def _step_and_perhaps_save(self, logger, epoch_batch_idx, epoch_i, loss):
        epoch_len = len(self.train_set)
        global_batch_idx = epoch_batch_idx + epoch_i * epoch_len

        self.optimizer.step()
        self.lr_scheduler.step()
        self.accelerator.wait_for_everyone()

        is_end_of_epoch = (epoch_batch_idx == epoch_len)

        if self.accelerator.is_main_process and (epoch_batch_idx % self.kwargs.log_steps == 0 or is_end_of_epoch):
            grad = self.get_total_grad()
            logger.step(global_batch_idx, epoch_batch_idx, epoch_i, loss, self.lr_scheduler.get_last_lr()[0], grad)

        self.optimizer.zero_grad()

        if (global_batch_idx % self.kwargs.save_steps == 0) or is_end_of_epoch:
            self.accelerator.wait_for_everyone()

            if self.accelerator.is_main_process:
                logger.line_break()
                log(f"Saving at {epoch_batch_idx} steps, epoch {epoch_i + 1} ({global_batch_idx} global steps)", accelerator=self.accelerator)

            self._save_all(global_batch_idx, epoch_i)

            logger.line_start()

    def _save_all(self, global_batch_idx, epoch_i):
        epoch_len = len(self.train_set)

        ckpt_name = (f"checkpoint-e{epoch_i + 1:02}-" +
                     (f"b{global_batch_idx:07}" if (global_batch_idx % epoch_len) else f"full"))

        this_location = os.path.join(self.kwargs.save_location, ckpt_name)
        if os.path.exists(this_location):
            raise FileExistsError(f"Cannot overwrite existing checkpoint {this_location}!")

        self.data_state.copy_from(self.train_set.where_are_we(), epoch_idx=epoch_i)

        model_to_save = self.accelerator.unwrap_model(self.models[0])

        save_all_models(this_location, model_to_save, self.coupling_specs[0].tokenizer,
                        self.coupling_specs, trainer=self.accelerator)
"""
kuidastaltsutadalaamat/legacy/benchmark.py ADDED
@@ -0,0 +1,190 @@
#!/usr/bin/env python3

import sys
import os
import json

from collections import defaultdict
from data import split_by_lang, make_path_compatible, get_tr_pairs
from inference import coupled_translate, load_and_init_module_config, neurotolge_in_batches
from evaluate import load as load_metric
from legacy.langconv import get_mdl_type, get_joshi_class
from accelerate import Accelerator

from aux import log


def get_hyp_cache_dir(model_location, create=False):
    hyp_location = os.path.join(model_location, "hyp_cache")
    if create:
        os.makedirs(hyp_location, exist_ok=True)
    return hyp_location


def get_hyp_cache_filename(model_location, benchmark_corpus, src_lang, tgt_lang):
    hyp_location = get_hyp_cache_dir(model_location)

    corpus_base = os.path.basename(benchmark_corpus)
    basename = f"{corpus_base}-{src_lang}-to-{tgt_lang}"

    hyp_file = os.path.join(hyp_location, f"{basename}.hyp")
    src_file = os.path.join(hyp_location, f"{basename}.src")

    return hyp_file, src_file


def get_benchmark_filename(model_location, benchmark_corpus):
    corpus_base = os.path.basename(benchmark_corpus)
    hyp_file = f"{corpus_base}-scores.json"
    return os.path.join(model_location, hyp_file)


def load_hyps_from_file(filename):
    with open(filename, "r", encoding="utf-8") as f:
        return [line.strip() for line in f]


def save_hyps_to_file(hypos, filename):
    if hypos is not None:
        with open(filename, "w", encoding="utf-8") as f:
            for hyp in hypos:
                f.write(hyp + "\n")


def load_or_translate(mod_config, input_output_list, lp, model_location, benchmark_corpus):
    src_lang, tgt_lang = lp.split("-")

    inputs, _ = zip(*input_output_list)

    cache_filename, src_filename = get_hyp_cache_filename(model_location, benchmark_corpus, src_lang, tgt_lang)

    try:
        hypos = load_hyps_from_file(cache_filename)
    except FileNotFoundError:
        if model_location == "models/neurotolge":
            hypos = neurotolge_in_batches(inputs, src_lang, tgt_lang)
        else:
            hypos = coupled_translate(mod_config, inputs, src_lang, tgt_lang)

        if hypos is not None:
            save_hyps_to_file(hypos, cache_filename)
            save_hyps_to_file(inputs, src_filename)

    return zip(inputs, hypos)


def translate_all_hyps(lp_test_set_dict, module_conf, model_id, corpus_id, accelerator=None):
    if accelerator is not None:
        key_list = sorted(lp_test_set_dict.keys())
        for idx, lp in enumerate(key_list):
            if idx % accelerator.num_processes == accelerator.process_index:
                log(f"Process {accelerator.process_index} translating {lp}")
                load_or_translate(module_conf, lp_test_set_dict[lp], lp, model_id, corpus_id)
        accelerator.wait_for_everyone()
    else:
        result = dict()
        for i, lp in enumerate(lp_test_set_dict.keys()):
            log(f"Translating {lp}, {i + 1}/{len(lp_test_set_dict)}")
            result[lp] = load_or_translate(module_conf, lp_test_set_dict[lp], lp, model_id, corpus_id)
        return result


def get_joshi_lp(from_lang, to_lang):
    from_joshi = get_joshi_class(from_lang)
    to_joshi = get_joshi_class(to_lang)

    return f"{from_joshi}-{to_joshi}"


def get_all_scores(hyps_dict, lp_test_sets, metric_dict):
    scores = dict()
    avgs = defaultdict(list)

    for lp in lp_test_sets:
        from_lang, to_lang = lp.split("-")
        jlp = get_joshi_lp(from_lang, to_lang)

        _, outputs = zip(*lp_test_sets[lp])

        preds = None if hyps_dict[lp] is None else [hyp for _, hyp in hyps_dict[lp]]

        for metric_name in metric_dict:
            metric_func = metric_dict[metric_name]

            if preds is not None:
                metric_value = metric_func.compute(predictions=preds, references=outputs)

                scores[lp + "-" + metric_name] = metric_value['score']

                avgs[jlp + "-" + metric_name].append(metric_value['score'])

    for avg_k in avgs:
        scores[avg_k] = sum(avgs[avg_k]) / len(avgs[avg_k])

    return scores


def save_scores(scores, mdl_id, corpus):
    filename = get_benchmark_filename(mdl_id, corpus)
    with open(filename, "w") as ofh:
        json.dump(scores, ofh, indent=2, sort_keys=True)


def benchmark_neurotolge(corpus):
    log("Loading data")
    lp_test_sets = split_by_lang(filename=corpus, model_type=None)

    log("Starting benchmarking")
    _ = get_hyp_cache_dir("models/neurotolge", create=True)

    hyps_dict = translate_all_hyps(lp_test_sets, None, "models/neurotolge", corpus)

    log("Loading metrics")
    exp_id = "neurotõlge---" + make_path_compatible(corpus)
    metric_dict = {
        'bleu': load_metric("sacrebleu", experiment_id=exp_id),
        'chrf': load_metric("chrf", experiment_id=exp_id) }

    scores = get_all_scores(hyps_dict, lp_test_sets, metric_dict)

    save_scores(scores, "models/neurotolge", corpus)


def benchmark_local_model(mdl_id, corpus):
    accelerator = Accelerator()

    main_model, module_config = load_and_init_module_config(mdl_id, accelerator)

    log("Loading data", accelerator=accelerator)
    lp_test_sets = split_by_lang(filename=corpus, model_type=get_mdl_type(main_model))

    log("Loading metrics", accelerator=accelerator)
    exp_id = make_path_compatible(mdl_id) + "---" + make_path_compatible(corpus)

    metric_dict = {
        'bleu': load_metric("sacrebleu", experiment_id=exp_id),
        'chrf': load_metric("chrf", experiment_id=exp_id) }

    log("Starting benchmarking", accelerator=accelerator)

    if accelerator.is_main_process:
        _ = get_hyp_cache_dir(mdl_id, create=True)

    translate_all_hyps(lp_test_sets, module_config, mdl_id, corpus, accelerator)

    if accelerator.is_main_process:
        fin_hyps_dict = translate_all_hyps(lp_test_sets, module_config, mdl_id, corpus)

        scores = get_all_scores(fin_hyps_dict, lp_test_sets, metric_dict)

        save_scores(scores, mdl_id, corpus)


if __name__ == '__main__':
    mdl_id_param = sys.argv[1]
    corpus_param = sys.argv[2]

    if mdl_id_param == "neurotolge":
        benchmark_neurotolge(corpus_param)
    else:
        benchmark_local_model(mdl_id_param, corpus_param)
kuidastaltsutadalaamat/legacy/data.py ADDED
@@ -0,0 +1,164 @@
1
+ #!/usr/bin/env python3
2
+
3
+ import json
4
+ import sys
5
+
6
+ from torch.utils.data import IterableDataset
7
+ from random import shuffle, randint
8
+ from legacy.tokops import tokenize_batch
9
+ from aux import log
10
+
11
+ def prep_llm_input(ljmftpl):
12
+ raise NotImplementedError
13
+
14
+ #{'task': 'translate' / 'approx-translate' / 'generate',
15
+ # 'src_segm': src_segm,
16
+ # 'tgt_segm': tgt_segm,
17
+ # 'src_lang': src_lang,
18
+ # 'tgt_lang': tgt_lang}
19
+
20
+ result = f"{ljmftpl['src_segm']}\n=====\nis in {ljmftpl['src_lang']}"
21
+
22
+ if ljmftpl['task'] in {'translate', 'approx-translate'}:
23
+ result += f"; {ljmftpl['task']} to {ljmftpl['tgt_lang']}:\n{ljmftpl['tgt_segm']}"
24
+
25
+ return result
26
+
27
+ def make_path_compatible(filename):
28
+ return filename.replace("/", "_").replace(":", "-")
29
+
30
+ def do_list_in_batches(data, batch_size):
31
+ i = 0
32
+
33
+ while i < len(data):
34
+ yield data[i:i + batch_size]
35
+ i += batch_size
36
+
37
+
38
+ class DataState:
39
+ def __init__(self, elem_idx = 0, shard_idx = 0, epoch_idx = None):
40
+ self.elem_idx = elem_idx
41
+ self.shard_idx = shard_idx
42
+ self.epoch_idx = epoch_idx
43
+
44
+ def state_dict(self):
45
+ return {'elem_idx': self.elem_idx, 'shard_idx': self.shard_idx, 'epoch_idx': self.epoch_idx}
46
+
47
+ def load_state_dict(self, state_dict):
48
+ self.elem_idx = state_dict['elem_idx']
49
+ self.shard_idx = state_dict['shard_idx']
50
+ self.epoch_idx = state_dict['epoch_idx']
51
+
52
+ def copy_from(self, src_ds, epoch_idx = None):
53
+ self.shard_idx = src_ds.shard_idx
54
+ self.elem_idx = src_ds.elem_idx
55
+
56
+ if epoch_idx is not None:
57
+ self.epoch_idx = epoch_idx
58
+
59
+ def __str__(self):
60
+ return 'DataState(elem_idx={}, shard_idx={}, epoch_idx={})'.format(self.elem_idx, self.shard_idx, self.epoch_idx)
61
+
62
+ def __repr__(self):
63
+ return self.__str__()
64
+
65
+
66
+ class BatchingIterator(IterableDataset):
67
+ def __init__(self, batched_data, batch_size, tokenizer, max_len=8000):
68
+ assert len(batched_data[0]) == batch_size, "loaded data batch size and specified batch size differ"
69
+
70
+ self.batched_data = batched_data
71
+
72
+ self.tokenizer = tokenizer
73
+ self.max_len = max_len
74
+
75
+ self.curr_elem_idx = 0
76
+
77
+ self.data_len = len(self.batched_data)
78
+
79
+ def __len__(self):
80
+ return self.data_len
81
+
82
+ def __iter__(self):
83
+ #self.curr_elem_idx = 0
84
+ return self
85
+
86
+ def where_are_we(self):
87
+ return DataState(shard_idx=0, elem_idx=self.curr_elem_idx)
88
+
89
+ def thats_where(self, data_state):
90
+ self.curr_elem_idx = data_state.elem_idx
91
+
92
+ def _tokenize(self, prepped_segm_list):
93
+ #self.tokenizer.pad_token = '<|reserved_special_token_0|>'
94
+ #tokenized_batch = self.tokenizer(prepped_segm_list, return_tensors="pt", max_length=self.max_len,
95
+ # truncation=True, add_special_tokens=True,
96
+ # padding=True)
97
+ tokenized_batch = tokenize_batch(self.tokenizer, prepped_segm_list, maxlen=self.max_len)
98
+ return tokenized_batch, self.curr_elem_idx + 1
99
+
100
+ def __next__(self):
101
+ if self.curr_elem_idx >= self.data_len:
102
+ self.curr_elem_idx = 0
103
+ raise StopIteration
104
+ else:
105
+ batch = self._tokenize(self.batched_data[self.curr_elem_idx])
106
+ self.curr_elem_idx += 1
107
+ return batch
108
+
109
+
110
+ def shuffle_data():
111
+ # open a list of tuples, save a list of batches of strings made of these tuples
112
+ input_file = sys.argv[1]
113
+ output_file = sys.argv[2]
114
+
115
+ try:
116
+ batch_size = int(sys.argv[3])
117
+ except IndexError:
118
+ batch_size = None
119
+
120
+ log("Reading data")
121
+ # read the tuples
122
+ with open(input_file, "r") as f:
123
+ #raw_data = json.load(f)
124
+ final_data = json.load(f)
125
+
126
+ log("Making strings")
127
+ # make strings out of tuples
128
+ unsorted_data_in_elems = [prep_llm_input(s) for s in raw_data]
129
+
130
+ if batch_size is None:
131
+ final_data = unsorted_data_in_elems
132
+ else:
133
+ # if last batch is undersized, get some random elements to compensate
134
+ while len(unsorted_data_in_elems) % batch_size != 0:
135
+ new_elem_idx = randint(0, len(unsorted_data_in_elems) - 1)
136
+ unsorted_data_in_elems.append(unsorted_data_in_elems[new_elem_idx])
137
+
138
+ log("Sorting and grouping")
139
+ # sort by length
140
+ sorted_data_in_elems = sorted(unsorted_data_in_elems, key=lambda x: len(x), reverse=True)
141
+
142
+ # group into batches
143
+ final_data = list(do_list_in_batches(sorted_data_in_elems, batch_size))
144
+
145
+ log("Shuffling")
146
+ # shuffle the batches / sentences
147
+ shuffle(final_data)
148
+
149
+ log("Saving")
150
+ # save the result
151
+ with open(output_file, "w") as f:
152
+ json.dump(final_data, f)
153
+
154
+ if __name__ == '__main__':
155
+ all_data = []
156
+
157
+ for input_file in sys.argv[1:]:
158
+ with open(input_file, "r") as f:
159
+ this_data = json.load(f)
160
+ all_data += this_data
161
+
162
+ shuffle(all_data)
163
+
164
+ json.dump(all_data, sys.stdout)
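
A minimal usage sketch for the loader above (not part of the commit; the tokenizer id and file name are placeholders). BatchingIterator expects the pre-batched JSON written by shuffle_data and yields (tokenized_batch, next_elem_idx) pairs, while DataState makes the position checkpointable via where_are_we()/thats_where():

    import json
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")  # placeholder model id
    with open("shuffled_batches.json") as f:
        batched = json.load(f)                 # list of batches, each of length batch_size

    it = BatchingIterator(batched, batch_size=len(batched[0]), tokenizer=tok)

    for step, (tokenized, next_idx) in enumerate(it):
        # ... training step on `tokenized` ...
        if step == 100:
            state = it.where_are_we()          # DataState(elem_idx=..., shard_idx=0)
            break

    # later, e.g. after reloading a checkpoint, resume from the saved position:
    it.thats_where(state)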
kuidastaltsutadalaamat/legacy/data_backup.py ADDED
@@ -0,0 +1,804 @@
1
+ #!/usr/bin/env python3
2
+
3
+ import json
4
+ #import os
5
+ import sys
6
+ import torch
7
+ #import re
8
+ import math
9
+
10
+ from torch.utils.data import IterableDataset
11
+ from collections import namedtuple, defaultdict
12
+ from random import randrange, shuffle, randint
13
+ #from pathlib import Path
14
+
15
+ #from aux import log
16
+ #from langconv import any_to_madlad, any_to_nllb, is_nllb, is_madlad, get_mdl_type, any_to_mdl_type, is_dec_only_llm, \
17
+ # base_to_nllb
18
+ #from tokops import tokenizeit
19
+
20
+ # TrPair = namedtuple('TrPair', ["src_lang", "tgt_lang", "input", "output"])
21
+
22
+ """
23
+ def prep_llm_input(ljmftpl):
24
+ #{'task': 'translate' / 'approx-translate' / 'generate',
25
+ # 'src_segm': src_segm,
26
+ # 'tgt_segm': tgt_segm,
27
+ # 'src_lang': src_lang,
28
+ # 'tgt_lang': tgt_lang}
29
+
30
+ # it's a tuple
31
+ if "src_segm" in ljmftpl and "task" in ljmftpl:
32
+ if ljmftpl['task'] in {'translate', 'approx-translate'}:
33
+ return (f"{ljmftpl['src_segm']}\n=====\n{ljmftpl['task']} from {ljmftpl['src_lang']}; " +
34
+ f"to {ljmftpl['tgt_lang']}:\n{ljmftpl['tgt_segm']}")
35
+
36
+ elif ljmftpl['task'] == 'generate':
37
+ return f"{ljmftpl['src_segm']}\n=====\nis in {ljmftpl['src_lang']};"
38
+
39
+ # it's a string
40
+ else:
41
+ return ljmftpl
42
+
43
+
44
+ def make_path_compatible(filename):
45
+ return filename.replace("/", "_").replace(":", "-")
46
+
47
+ def do_list_in_batches(data, batch_size):
48
+ i = 0
49
+
50
+ while i < len(data):
51
+ yield data[i:i + batch_size]
52
+ i += batch_size
53
+ """
54
+ """
55
+ def do_bins_in_batches(bins, batch_size, sort_by_length):
56
+ result_list = []
57
+
58
+ for src_k in bins:
59
+ for tgt_k in bins[src_k]:
60
+ if src_k == 0 or tgt_k == 0:
61
+ result_list += [(e, src_k, tgt_k) for e in do_list_in_batches(bins[src_k][tgt_k], batch_size)]
62
+
63
+ shuffle(result_list)
64
+
65
+ return result_list
66
+
67
+
68
+ def _post_proc(text, lang):
69
+ if lang == 'liv' and "’" in text and "O’R" not in text:
70
+ return text.replace("’", "")
71
+ else:
72
+ return text
73
+
74
+
75
+ def clean_entry(entry, leave_out):
76
+ result = {k: _post_proc(entry[k], k) for k in entry if entry[k].strip() and k not in leave_out}
77
+ return result
78
+
79
+
80
+ def load_json_data(path, leave_out={}, skip_cats=True, load_mono=True):
81
+ with open(path, 'r') as f:
82
+ data = json.load(f)
83
+
84
+ if skip_cats:
85
+ # skip categories
86
+ resx = [clean_entry(entry, leave_out)
87
+ for cat in data for entry in cat['sentences']]
88
+ res = [e for e in resx if e]
89
+ else:
90
+ raise NotImplementedError
91
+
92
+ # resx = {cat['source']: [clean_entry(entry, leave_out) for entry in cat['sentences']] for cat in data}
93
+ # res = {k: resx[k] for k in resx if resx[k]}
94
+
95
+ return res
96
+
97
+
98
+ def get_tr_pairs(raw_data=None, filename=None, leave_out=None, leave_only=None, model_type=None, exclude_set=None):
99
+ if filename is not None:
100
+ raw_data = load_json_data(filename)
101
+
102
+ if raw_data is None:
103
+ raise ValueError("Neither file nor data are provided")
104
+
105
+ i = 0
106
+ log("Loading data")
107
+ for tup in raw_data:
108
+ for l1 in tup:
109
+ for l2 in tup:
110
+ if l1 != l2 and not "dia" in l1 and not "dia" in l2:
111
+ if leave_out is None or f"{l1}-{l2}" not in leave_out:
112
+ if leave_only is None or f"{l1}-{l2}" in leave_only:
113
+ i += 1
114
+ if not i % 1000000:
115
+ log(f"Loaded {i/1000000}M pairs")
116
+ dia_key = f"{l2}-dia"
117
+
118
+ if exclude_set is None or (tup[l1] not in exclude_set[l1] and tup[l2] not in exclude_set[l2]):
119
+ input = tup[l1]
120
+ if dia_key in tup:
121
+ input = f"<{tup[dia_key]}> {input}"
122
+
123
+ conv_l1 = any_to_mdl_type(model_type, l1)
124
+ conv_l2 = any_to_mdl_type(model_type, l2)
125
+
126
+ if not snt_is_fishy(input, conv_l1) and not snt_is_fishy(tup[l2], conv_l2):
127
+ yield TrPair(conv_l1, conv_l2, input, tup[l2])
128
+
129
+ def split_by_lang(filename, model_type):
130
+ result = defaultdict(list)
131
+
132
+ # if filename is not None:
133
+ # tr_pairs = load_json_datax(filename)
134
+
135
+ tr_pairs = get_tr_pairs(filename=filename, model_type=model_type)
136
+
137
+ for tup in tr_pairs:
138
+ #for l1 in tup:
139
+ # for l2 in tup:
140
+ # if l1 != l2 and not "dia" in l1 and not "dia" in l2:
141
+ l1 = tup.src_lang
142
+ l2 = tup.tgt_lang
143
+ lp = f"{l1}-{l2}"
144
+ result[lp].append((tup.input, tup.output))
145
+
146
+ return result
147
+
148
+
149
+ def data_iter_for_tok_train(raw_data, langs_to_include):
150
+ for tup in raw_data:
151
+ for lang in tup:
152
+ if lang in langs_to_include:
153
+ yield tup[lang]
154
+
155
+
156
+ def lang_bin_mapping(coupling_specs):
157
+ lang_to_idx = dict()
158
+
159
+ for i, spec_pair in enumerate(coupling_specs):
160
+ for lang in spec_pair.lang_set:
161
+ if lang not in lang_to_idx:
162
+ lang_to_idx[lang] = {i}
163
+ else:
164
+ lang_to_idx[lang].add(i)
165
+
166
+ return lang_to_idx
167
+
168
+
169
+ def mix_and_sample_idxs_carefully(src_idxs, tgt_idxs):
170
+ idx_pairs = [(s, t) for s in src_idxs for t in tgt_idxs if not (s == 1 and t == 1)]
171
+
172
+ if len(idx_pairs) == 0:
173
+ result = (None, None)
174
+ else:
175
+ pair_idx = randrange(len(idx_pairs))
176
+ result = idx_pairs[pair_idx]
177
+
178
+ # debug(f"src lang: {tr_pair.src_lang}, tgt_lang: {tr_pair.tgt_lang}, idx list: {idx_pairs}, result: {result}")
179
+
180
+ return result
181
+
182
+
183
+ def inject_bin_indices(batch, src_k, tgt_k):
184
+ batch['input_ids'][0,0] += src_k << 30
185
+
186
+ batch['labels'][0,0] += tgt_k << 30
187
+
188
+ def get_data_cache_location(cache_meta_path, idx):
189
+ cache_folder, cache_file = os.path.split(cache_meta_path)
190
+
191
+ if cache_folder:
192
+ Path(cache_folder).mkdir(parents=True, exist_ok=True)
193
+
194
+ if cache_meta_path.endswith(".json"):
195
+ return cache_meta_path[:-5] + f"_{idx:04}.pt"
196
+ else:
197
+ raise ValueError(f"Expected a json file for the cache meta-location ({cache_meta_path})")
198
+
199
+
200
+ def make_gen_text(src_lang, tgt_lang, input_text, output_text=None, tok=None):
201
+ if input_text.startswith("<"):
202
+ posit = input_text.find(">") + 1
203
+ dialect = input_text[1:posit-1]
204
+ diatxt = f", variety: {dialect}"
205
+ txt = input_text[posit+1:]
206
+ else:
207
+ dialect = None
208
+ diatxt = ""
209
+ txt = input_text
210
+
211
+ return (f"Translate:\n== From: {src_lang}\n== To: {tgt_lang}{diatxt}\n== Input: {txt}\n== Output: " +
212
+ ("" if (output_text is None or tok is None) else f"{output_text}{tok.eos_token}"))
213
+
214
+
215
+ class MultilingualBatchingCachingDataset:
216
+ def _post_proc_bins(self, bins):
217
+ for src_k in bins:
218
+ for tgt_k in bins[src_k]:
219
+ while len(bins[src_k][tgt_k]) % self.args.batch_size != 0:
220
+ rnd_elem_idx = randrange(len(bins[src_k][tgt_k]))
221
+ rnd_elem = bins[src_k][tgt_k][rnd_elem_idx]
222
+ bins[src_k][tgt_k].append(rnd_elem)
223
+
224
+ if self.args.sort_by_len:
225
+ bins[src_k][tgt_k] = sorted(bins[src_k][tgt_k], key=lambda e: len(e.input))
226
+ else:
227
+ shuffle(bins[src_k][tgt_k])
228
+ return bins
229
+
230
+ def _get_idxs(self, tr_pair):
231
+ src_idxs = self._lang_to_idx[tr_pair.src_lang]
232
+ tgt_idxs = self._lang_to_idx[tr_pair.tgt_lang]
233
+
234
+ return mix_and_sample_idxs_carefully(src_idxs, tgt_idxs)
235
+
236
+ def _fill_bins(self):
237
+ bins = defaultdict(lambda: defaultdict(list))
238
+
239
+ for tr_pair in get_tr_pairs(filename=self.filename, model_type=self.model_type, exclude_set=self.exclude_set):
240
+ src_bin_idx, tgt_bin_idx = self._get_idxs(tr_pair)
241
+
242
+ if src_bin_idx is not None and tgt_bin_idx is not None:
243
+ bins[src_bin_idx][tgt_bin_idx].append(tr_pair)
244
+
245
+ return self._post_proc_bins(bins)
246
+
247
+ def report_update_stats(self, bins):
248
+ total = 0
249
+ totalx = 0
250
+ updates = 0
251
+ duds = 0
252
+
253
+ enc_count = 0
254
+ dec_count = 0
255
+
256
+ for src_k in bins:
257
+ for tgt_k in bins[src_k]:
258
+ l = len(bins[src_k][tgt_k])
259
+
260
+ total += l
261
+ if src_k == 0 or tgt_k == 0:
262
+ totalx += l
263
+ updates += l * (1 - (src_k + tgt_k) / 2)
264
+
265
+ enc_count += l * (1 - src_k)
266
+ dec_count += l * (1 - tgt_k)
267
+
268
+ if src_k == 1 and tgt_k == 1:
269
+ duds += 1
270
+ # log(str(self._lang_to_idx))
271
+
272
+ log(f"### Ratio of coupled model updates: {100 * updates / total:.2f}% ({100 * updates / totalx:.2f}%); " + \
273
+ f"frozen meaningless updates: {100 * duds / total:.2f}%; " + \
274
+ f"enc samples: {enc_count}, dec samples: {dec_count}")
275
+
276
+ def tokenize_input(self, cplspec, input_list, rawbatch):
277
+ src_tokenizer = cplspec.tokenizer
278
+ src_tokenizer.src_lang = rawbatch[0].src_lang
279
+ #prep_batch_grouped = src_tokenizer(text=input_list, return_tensors="pt",
280
+ # padding="longest", truncation=True, max_length=self.args.max_snt_len)
281
+ prep_batch_grouped = tokenizeit((src_tokenizer, cplspec.postokenizer), input_list, self.args.max_snt_len, False)
282
+
283
+ if is_nllb(src_tokenizer):
284
+ src_lang_list = [any_to_nllb(e.src_lang) for e in rawbatch]
285
+ src_lang_vec = src_tokenizer.convert_tokens_to_ids(src_lang_list)
286
+ prep_batch_grouped['input_ids'][:,0] = torch.tensor(src_lang_vec)
287
+
288
+ return prep_batch_grouped
289
+
290
+ def tokenize_output(self, tgttokenizer, tgtposttok, rawbatch):
291
+ outputs = [e.output for e in rawbatch]
292
+ tgttokenizer.tgt_lang = rawbatch[0].tgt_lang
293
+ #labels = tgttokenizer(text_target=outputs, return_tensors="pt",
294
+ # padding="longest", truncation=True, max_length=self.args.max_snt_len)
295
+ labels = tokenizeit((tgttokenizer, tgtposttok), outputs, self.args.max_snt_len, True)
296
+
297
+ if is_nllb(tgttokenizer):
298
+ tgt_lang_list = [any_to_nllb(e.tgt_lang) for e in rawbatch]
299
+ tgt_lang_vec = tgttokenizer.convert_tokens_to_ids(tgt_lang_list)
300
+ labels['input_ids'][:, 0] = torch.tensor(tgt_lang_vec)
301
+
302
+ return labels
303
+
304
+ def tokenize_gen_batch(self, raw_batch):
305
+ tokenizer = self.coupling_specs[0].tokenizer
306
+ tokenizer.pad_token = '<|reserved_special_token_0|>'
307
+ tokenizer.padding_side = 'left'
308
+
309
+ texts = [make_gen_text(e.src_lang, e.tgt_lang, e.input, e.output, tokenizer) for e in raw_batch]
310
+
311
+ #batch = tokenizer(texts, return_tensors="pt", max_length=512, truncation=True, add_special_tokens=True, padding=True)
312
+ batch = tokenizeit((tokenizer, self.coupling_specs[0].postokenizer), texts, self.args.max_snt_len, False)
313
+
314
+ return batch
315
+
316
+ def tokenize_and_pad(self, raw_batch, src_k, tgt_k):
317
+ tgt_tokenizer = self.coupling_specs[tgt_k].tokenizer
318
+ tgt_postok = self.coupling_specs[tgt_k].postokenizer
319
+
320
+ if is_madlad(tgt_tokenizer):
321
+ inputs = [f"{any_to_madlad(e.tgt_lang)} {e.input}" for e in raw_batch]
322
+ else:
323
+ inputs = [e.input for e in raw_batch]
324
+
325
+ prep_batch_grouped = self.tokenize_input(self.coupling_specs[src_k], inputs, raw_batch)
326
+ labels = self.tokenize_output(tgt_tokenizer, tgt_postok, raw_batch)
327
+ prep_batch_grouped['labels'] = labels['input_ids']
328
+
329
+ # inject_bin_indices(prep_batch_grouped, src_k, tgt_k)
330
+
331
+ #split_prep_batch = [{k: prep_batch_grouped[k][i] for k in prep_batch_grouped}
332
+ # for i, trp in enumerate(raw_batch)]
333
+
334
+ return prep_batch_grouped
335
+
336
+ def _bins_to_tokenized_batched_cached_data(self, bins, cache_path):
337
+ shard_i = 0
338
+ batch_i = 0
339
+ total_i = 0
340
+
341
+ metainfo = []
342
+ data = []
343
+
344
+ log("Tokenizing data")
345
+
346
+ for raw_batch, src_k, tgt_k in do_bins_in_batches(bins, self.args.batch_size, self.args.sort_by_len):
347
+ batch_i += 1
348
+ if not batch_i % 10000:
349
+ log(f"Tokenized {batch_i + shard_i * self.args.shard_size} batches (shard {shard_i})")
350
+
351
+ if is_dec_only_llm(self.coupling_specs[tgt_k].tokenizer):
352
+ prepared_batch = self.tokenize_gen_batch(raw_batch)
353
+ data.append((prepared_batch, total_i))
354
+ else:
355
+ prepared_batch = self.tokenize_and_pad(raw_batch, src_k, tgt_k)
356
+ data.append((prepared_batch, src_k, tgt_k, total_i))
357
+
358
+ if batch_i >= self.args.shard_size:
359
+ shard_i += 1
360
+ batch_i = 0
361
+ fn = self._save_cache_file(data, cache_path, shard_i)
362
+ metainfo.append({'shard_filename': fn, 'shard_size': len(data)})
363
+
364
+ del data
365
+
366
+ data = []
367
+
368
+ total_i += 1
369
+
370
+ if len(data) > 0:
371
+ fn = self._save_cache_file(data, cache_path, shard_i + 1)
372
+ metainfo.append({'shard_filename': fn, 'shard_size': len(data)})
373
+
374
+ with open(cache_path, 'w') as f:
375
+ json.dump(metainfo, f)
376
+
377
+ del data
378
+
379
+ @staticmethod
380
+ def _save_cache_file(data, cache_location, idx):
381
+ cache_location = get_data_cache_location(cache_location, idx)
382
+
383
+ if os.path.exists(cache_location):
384
+ raise Exception("Cache already exists")
385
+
386
+ torch.save(data, cache_location)
387
+ log(f"Saved data into cache (shard {idx})")
388
+
389
+ return cache_location
390
+
391
+ def set_model_type(self):
392
+ result = None
393
+
394
+ for spec_tuple in self.coupling_specs:
395
+ this_type = get_mdl_type(spec_tuple.tokenizer)
396
+ if result is None:
397
+ result = this_type
398
+ else:
399
+ assert result == this_type, "in this implementation model types (NLLB/MADLAD/...) must be the same for all included models"
400
+
401
+ return result
402
+
403
+
404
+ def __init__(self, tr_file, coupling_specs, args):
405
+ self.args = args
406
+ self.filename = tr_file
407
+ self.coupling_specs = coupling_specs
408
+
409
+ self.exclude_set = _dev_to_dict(args.exclude_set) if args.exclude_set is not None else None
410
+
411
+ self.model_type = self.set_model_type()
412
+
413
+ # init lang to idx
414
+ self._lang_to_idx = lang_bin_mapping(coupling_specs)
415
+
416
+ def load_and_cache_data(self, cache_path):
417
+ # collect data into bins and cache it
418
+ bins = self._fill_bins()
419
+
420
+ self.report_update_stats(bins)
421
+
422
+ self._bins_to_tokenized_batched_cached_data(bins, cache_path)
423
+ """
424
+
425
+ """
426
+ class DataState:
427
+ def __init__(self, elem_idx = 0, shard_idx = 0, epoch_idx = None):
428
+ self.elem_idx = elem_idx
429
+ self.shard_idx = shard_idx
430
+ self.epoch_idx = epoch_idx
431
+
432
+ def state_dict(self):
433
+ return {'elem_idx': self.elem_idx, 'shard_idx': self.shard_idx, 'epoch_idx': self.epoch_idx}
434
+
435
+ def load_state_dict(self, state_dict):
436
+ self.elem_idx = state_dict['elem_idx']
437
+ self.shard_idx = state_dict['shard_idx']
438
+ self.epoch_idx = state_dict['epoch_idx']
439
+
440
+ def copy_from(self, src_ds, epoch_idx = None):
441
+ self.shard_idx = src_ds.shard_idx
442
+ self.elem_idx = src_ds.elem_idx
443
+
444
+ if epoch_idx is not None:
445
+ self.epoch_idx = epoch_idx
446
+
447
+ def __str__(self):
448
+ return 'DataState(elem_idx={}, shard_idx={}, epoch_idx={})'.format(self.elem_idx, self.shard_idx, self.epoch_idx)
449
+
450
+ def __repr__(self):
451
+ return self.__str__()
452
+
453
+
454
+ class BatchingIterator(IterableDataset):
455
+ def __init__(self, segment_list, batch_size, tokenizer, max_len=8000):
456
+ self.data = segment_list
457
+ shuffle(self.data)
458
+
459
+ self.batch_size = batch_size
460
+ self.tokenizer = tokenizer
461
+ self.max_len = max_len
462
+
463
+ self.curr_elem_idx = 0
464
+
465
+ self.data_len = math.ceil(len(self.data) / self.batch_size)
466
+
467
+ def __len__(self):
468
+ return self.data_len
469
+
470
+ def __iter__(self):
471
+ self.curr_elem_idx = 0
472
+ return self
473
+
474
+ def where_are_we(self):
475
+ return DataState(shard_idx=0, elem_idx=self.curr_elem_idx)
476
+
477
+ def thats_where(self, data_state):
478
+ self.curr_elem_idx = data_state.elem_idx
479
+
480
+ def _get_properly_sized_segment_list(self):
481
+ i = self.curr_elem_idx * self.batch_size
482
+
483
+ segment_list = self.data[i:i + self.batch_size]
484
+ if len(segment_list) < self.batch_size:
485
+ orig_len = len(segment_list)
486
+ while len(segment_list) < self.batch_size:
487
+ segment_list.append(segment_list[randint(0, orig_len - 1)])
488
+
489
+ return segment_list
490
+
491
+ def _tokenize(self, segment_list):
492
+ #{'task': 'translate',
493
+ # 'src_segm': src_segm,
494
+ # 'tgt_segm': tgt_segm,
495
+ # 'src_lang': src_lang,
496
+ # 'tgt_lang': tgt_lang}
497
+
498
+ prepped_segm_list = [prep_llm_input(s) for s in segment_list]
499
+
500
+ self.tokenizer.pad_token = '<|reserved_special_token_0|>'
501
+ tokenized_batch = self.tokenizer(prepped_segm_list, return_tensors="pt", max_length=self.max_len,
502
+ truncation=True, add_special_tokens=True,
503
+ padding=True)
504
+ return tokenized_batch, self.curr_elem_idx + 1
505
+
506
+ def __next__(self):
507
+ if self.curr_elem_idx >= self.data_len:
508
+ raise StopIteration
509
+ else:
510
+ segment_list = self._get_properly_sized_segment_list()
511
+
512
+ batch = self._tokenize(segment_list)
513
+ self.curr_elem_idx += 1
514
+ return batch
515
+ """
516
+ """
517
+ class MultilingualDatasetIterator(IterableDataset):
518
+ def _load_metafile(self, cache_metafile):
519
+ with open(cache_metafile, 'r') as f:
520
+ self.metainfo = json.load(f)
521
+ self.data_len = sum([e['shard_size'] for e in self.metainfo])
522
+
523
+ def _init_curr_shard(self):
524
+ cache_location = self.metainfo[self.curr_shard_idx]['shard_filename']
525
+
526
+ self.curr_shard_data = torch.load(cache_location, weights_only=False)
527
+
528
+ assert len(self.curr_shard_data) == self.metainfo[self.curr_shard_idx]['shard_size']
529
+
530
+ def __init__(self, filename):
531
+ self.curr_shard_idx = 0
532
+ self.curr_elem_idx = 0
533
+ self.prev_shard_sum_len = 0
534
+
535
+ if filename is not None:
536
+ self._load_metafile(filename)
537
+
538
+ def __iter__(self):
539
+ self._init_curr_shard()
540
+ return self
541
+
542
+ def where_are_we(self):
543
+ return DataState(shard_idx=self.curr_shard_idx, elem_idx=self.curr_elem_idx)
544
+
545
+ def thats_where(self, data_state):
546
+ self.curr_shard_idx = data_state.shard_idx
547
+ self.curr_elem_idx = data_state.elem_idx
548
+ self.prev_shard_sum_len = sum([e['shard_size'] for i, e in enumerate(self.metainfo) if i < self.curr_shard_idx])
549
+
550
+ def __next__(self):
551
+ try:
552
+ result_data = self.curr_shard_data[self.curr_elem_idx]
553
+
554
+ self.curr_elem_idx += 1
555
+ except IndexError:
556
+ self.prev_shard_sum_len += self.metainfo[self.curr_shard_idx]['shard_size']
557
+ self.curr_shard_idx += 1
558
+
559
+ if self.curr_shard_idx >= len(self.metainfo):
560
+ self.__init__(None)
561
+ raise StopIteration
562
+ else:
563
+ self._init_curr_shard()
564
+ self.curr_elem_idx = 0
565
+
566
+ result_data = self.curr_shard_data[self.curr_elem_idx]
567
+
568
+ self.curr_elem_idx += 1
569
+
570
+ index_in_epoch = self.prev_shard_sum_len + self.curr_elem_idx
571
+ return result_data, index_in_epoch
572
+
573
+ def __len__(self):
574
+ return self.data_len
575
+
576
+
577
+
578
+
579
+ def dump_to_stdout():
580
+ filename = sys.argv[1]
581
+
582
+ lc_src = defaultdict(int)
583
+
584
+ tot_len = 0
585
+ tot_count = 0
586
+
587
+ for tr_pair in get_tr_pairs(filename=filename):
588
+ print(tr_pair.src_lang + "\t" + tr_pair.input + "\t" + tr_pair.tgt_lang + "\t" + tr_pair.output)
589
+
590
+ tot_len += upd_lc(lc_src, tr_pair.src_lang, tr_pair.input)
591
+ tot_len += upd_lc(lc_src, tr_pair.tgt_lang, tr_pair.output)
592
+
593
+ tot_count += 2
594
+
595
+ totes = sum(lc_src.values())
596
+ for k in sorted(lc_src):
597
+ sys.stderr.write(f"{k}: {100*lc_src[k]/totes:.1f}%\n")
598
+ sys.stderr.write(f"Avg length: {tot_len/float(tot_count):.1f}\n")
599
+
600
+
601
+ def do_stats(filename):
602
+ stats = defaultdict(int)
603
+ raw_data = load_json_data(filename)
604
+
605
+ for data in raw_data:
606
+ langs = sorted([k for k in data.keys() if data[k].strip() != ""])
607
+ stats["-".join(langs)] += 1
608
+ for k in stats:
609
+ print(k, stats[k])
610
+
611
+
612
+ def lang_from_name(filename):
613
+ return filename.split(".")[-1]
614
+
615
+
616
+ def moses_to_json(file1, file2):
617
+ result = list()
618
+
619
+ l1 = lang_from_name(file1)
620
+ l2 = lang_from_name(file2)
621
+
622
+ with open(file1, "r") as h1, open(file2, "r") as h2:
623
+ for line1 in h1:
624
+ line2 = h2.readline()
625
+
626
+ result.append({l1: line1.strip(), l2: line2.strip()})
627
+
628
+ return result
629
+
630
+
631
+ def multi_moses_to_json(output_file, init_json, input_file_tuples):
632
+ try:
633
+ with open(init_json, "r") as h:
634
+ result = json.load(h)
635
+ except:
636
+ result = list()
637
+
638
+ for input_file_tuple in input_file_tuples:
639
+ this_result = moses_to_json(*input_file_tuple)
640
+ result.append({"source": f"{input_file_tuple[0]}-{input_file_tuple[1]}", "sentences": this_result})
641
+
642
+ with open(output_file, "w") as f:
643
+ json.dump(result, f, indent=2, sort_keys=True)
644
+
645
+
646
+ def group_tuples(input_tuples):
647
+ return [(input_tuples[2 * i], input_tuples[2 * i + 1]) for i in range(int(len(input_tuples) / 2))]
648
+
649
+
650
+ def combine_two_jsons(json_target, json_addition):
651
+ for k in json_addition:
652
+ if k in json_target:
653
+ json_target[k] += json_addition[k]
654
+ else:
655
+ json_target[k] = json_addition[k]
656
+
657
+
658
+ def combine_jsons(filelist):
659
+ result = dict()
660
+
661
+ for filename in filelist:
662
+ data = json.load(open(filename))
663
+
664
+ combine_two_jsons(result, data)
665
+
666
+ json.dumps(result)
667
+
668
+
669
+ def _dev_to_dict(filename):
670
+ result = defaultdict(lambda: defaultdict(int))
671
+
672
+ for dev_sample in load_json_data(filename):
673
+ for lang in dev_sample:
674
+ if not "dia" in lang:
675
+ result[lang][dev_sample[lang]] = 1
676
+
677
+ return result
678
+
679
+
680
+ def check_cross_pollination(small_path, large_path):
681
+ print("preparing dev set")
682
+ dct = _dev_to_dict(small_path)
683
+
684
+ print("reading train set")
685
+ for train_sample in load_json_data(large_path):
686
+ for lang in train_sample:
687
+ if not "dia" in lang and lang in dct:
688
+ snt = train_sample[lang]
689
+
690
+ if snt in dct[lang]:
691
+ dct[lang][snt] += 1
692
+
693
+ print("---------------------")
694
+ print("contamination report:")
695
+ print("---------------------")
696
+ for lang in dct:
697
+ total = 0
698
+ counts = 0
699
+ freqs = 0
700
+
701
+ for snt in dct[lang]:
702
+ total += 1
703
+ if dct[lang][snt] > 1:
704
+ counts += 1
705
+ freqs += (dct[lang][snt] - 1)
706
+
707
+ print(f"{lang}: contaminated: {counts} ({100*counts/float(total):.1f}%), total occurrence: {freqs}")
708
+
709
+
710
+ def char_class(c):
711
+ lc = c.lower()
712
+ if re.match("[a-z]", lc):
713
+ return "latn"
714
+ elif re.match("[а-я]", lc):
715
+ return "cyrl"
716
+ else:
717
+ return "other"
718
+
719
+
720
+ def snt_is_fishy(snt_raw, lang, detailed=False):
721
+ snt = re.sub(r'^<[^>]+> ', '', snt_raw)
722
+
723
+ snt_db = defaultdict(int)
724
+ for c in snt:
725
+ c_c = char_class(c)
726
+ snt_db[c_c] += 1
727
+
728
+ tot = snt_db['latn'] + snt_db['cyrl']
729
+
730
+ if tot > 0:
731
+ if snt_db['latn'] / tot > 0.7:
732
+ this_is = 'latn'
733
+ elif snt_db['cyrl'] / tot > 0.7:
734
+ this_is = 'cyrl'
735
+ else:
736
+ this_is = 'mix'
737
+
738
+ should_be = any_to_nllb(lang).split("_")[1].lower()
739
+
740
+ if should_be != this_is:
741
+ return (True, this_is, should_be) if detailed else True
742
+
743
+ return (False, None, None) if detailed else False
744
+
745
+
746
+ def script_stats():
747
+ db = defaultdict(lambda: defaultdict(int))
748
+
749
+ # corp = []
750
+
751
+ for raw_line in sys.stdin:
752
+ lang, snt_raw = raw_line.strip().split("\t")
753
+
754
+ is_fishy, this_is, should_be = snt_is_fishy(snt_raw, lang, detailed=True)
755
+ if is_fishy:
756
+ print(f"{lang}: should be {should_be}, is actually {this_is}:\n{snt_raw}")
757
+
758
+
759
+
760
+ def get_full_lang(lang, tupl):
761
+ dia_key = f"{lang}-dia"
762
+
763
+ if dia_key in tupl:
764
+ return f"{lang}, {tupl[dia_key]}"
765
+ else:
766
+ return lang
767
+
768
+
769
+ def convert_json_to_json(src_json, dest_json):
770
+ raw_data = load_json_data(src_json)
771
+
772
+ output_data = []
773
+
774
+ for tupl in raw_data:
775
+ for l1 in tupl:
776
+ for l2 in tupl:
777
+ if l1 != l2 and not "dia" in l1 and not "dia" in l2:
778
+ src_segm = tupl[l1]
779
+ tgt_segm = tupl[l2]
780
+
781
+ src_lang = get_full_lang(l1, tupl)
782
+ tgt_lang = get_full_lang(l2, tupl)
783
+
784
+ output_data.append({ 'task': 'translate',
785
+ 'src_segm': src_segm,
786
+ 'tgt_segm': tgt_segm,
787
+ 'src_lang': src_lang,
788
+ 'tgt_lang': tgt_lang})
789
+
790
+ with open(dest_json, "w") as f:
791
+ json.dump(output_data, f, indent=2)
792
+ """
793
+
794
+ if __name__ == "__main__":
795
+ # check_cross_pollination(sys.argv[1], sys.argv[2])
796
+ # multi_moses_to_json(sys.argv[1], sys.argv[2], group_tuples(sys.argv[3:]))
797
+ # combine_jsons(sys.argv[1:])
798
+ # do_stats("data/train.json")
799
+
800
+ # dump_to_stdout()
801
+ # script_stats()
802
+
803
+ # convert_json_to_json(sys.argv[1], sys.argv[2])
804
+ pass
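
The bulk of this backup file is kept inside string literals, so only the imports and the empty __main__ guard are live. For reference, the commented-out MultilingualDatasetIterator reads a JSON meta-file that lists torch-saved shards; a sketch of that layout, with hypothetical file names:

    import json, torch

    # each shard file is a torch.save()-d list of pre-tokenized batches
    metainfo = [
        {"shard_filename": "cache_0001.pt", "shard_size": 10000},
        {"shard_filename": "cache_0002.pt", "shard_size": 7321},
    ]
    with open("cache.json", "w") as f:
        json.dump(metainfo, f)

    # the iterator then calls torch.load(shard_filename, weights_only=False)
    # and checks that len(shard) == shard_size before serving batches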
kuidastaltsutadalaamat/legacy/diffmdl.py ADDED
@@ -0,0 +1,69 @@
1
+ #!/usr/bin/env python3
2
+
3
+ import sys
4
+ import torch
5
+
6
+ from datetime import datetime
7
+ from transformers import AutoModelForSeq2SeqLM
8
+
9
+
10
+ def get_mdl_param_dict(mdl):
11
+ return {n: p for n, p in mdl.named_parameters()}
12
+
13
+
14
+ def log(msg):
15
+ sys.stderr.write(str(datetime.now()) + ": " + msg + '\n')
16
+
17
+
18
+ def _avg_diff(pd1, pd2, skip_emb):
19
+ result = 0
20
+ count = 0
21
+
22
+ raw_count = 0
23
+
24
+ for k in pd1.keys():
25
+ # log(k)
26
+ if not (skip_emb and "shared" in k):
27
+ delta = pd1[k] - pd2[k]
28
+
29
+ raw_count += 1
30
+
31
+ if len(delta.shape) == 1:
32
+ thiscount = delta.shape[0]
33
+ elif len(delta.shape) == 2:
34
+ thiscount = delta.shape[0] * delta.shape[1]
35
+ else:
36
+ raise Exception("Unexpected shape")
37
+ count += thiscount
38
+ deltasum = torch.sum(delta)
39
+ #log(f"DETDIFF {k}: {deltasum/thiscount}")
40
+ result += deltasum
41
+ # print(f"Count {count}, raw count {raw_count}")
42
+
43
+ return result / count
44
+
45
+
46
+ def avg_mdl_diff(m1, m2, skip_emb=False):
47
+ pd1 = get_mdl_param_dict(m1)
48
+ pd2 = get_mdl_param_dict(m2)
49
+
50
+ assert (pd1.keys() == pd2.keys())
51
+
52
+ return _avg_diff(pd1, pd2, skip_emb)
53
+
54
+
55
+ if __name__ == "__main__":
56
+ mdl1_id = sys.argv[1]
57
+ mdl2_id = sys.argv[2]
58
+
59
+ log(f"Load mdl 1: {mdl1_id}")
60
+ model1 = AutoModelForSeq2SeqLM.from_pretrained(mdl1_id)
61
+
62
+ log(f"Load mdl 2: {mdl2_id}")
63
+ model2 = AutoModelForSeq2SeqLM.from_pretrained(mdl2_id)
64
+
65
+ log(f"Full diff: {avg_mdl_diff(model1, model2)}")
66
+
67
+ #log(f"Encoder diff: {avg_mdl_diff(model1.get_encoder(), model2.get_encoder(), True)}")
68
+ #log(f"Decoder diff: {avg_mdl_diff(model1.get_decoder(), model2.get_decoder(), True)}")
69
+ #log(f"Embedding diff: {avg_mdl_diff(model1.get_input_embeddings(), model2.get_input_embeddings())}")
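
Usage sketch (model ids are placeholders): the script loads two seq2seq checkpoints and prints the mean of the signed parameter deltas, so a value near zero can still hide large changes of opposite sign, since deltas are summed rather than taken in absolute value.

    python diffmdl.py facebook/nllb-200-distilled-600M models/my-finetuned-nllb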
kuidastaltsutadalaamat/legacy/initmodel.py ADDED
@@ -0,0 +1,46 @@
1
+ #!/usr/bin/env python3
2
+
3
+ import os
4
+
5
+ from transformers import AutoConfig, AutoModelForSeq2SeqLM
6
+
7
+ from legacy.modelops import mdl_param_count
8
+ from legacy.tokops import train_or_extend_tokenizer_and_upd_model, save_postokens
9
+
10
+ from aux import get_changed_config, CmdlineArgs
11
+ from legacy.langconv import lang_set_maybe_smugri
12
+
13
+
14
+ def just_do_main_stuff_and_avoid_global_ctx_variables():
15
+ args = CmdlineArgs("Initialize a new HuggingFace model randomly, off of an existing configuration, with possible changes",
16
+ pos_arg_list=["mdl_id", "save_location"],
17
+ kw_arg_dict={ k: None for k in ["tok_train_file", "new_langs", "vocab_size", "merge_tokenizers", "merge_tok_mdl_id",
18
+ "tok_mdl_id", "activation_dropout", "activation_function", "d_model",
19
+ "decoder_attention_heads", "decoder_ffn_dim", "decoder_layerdrop", "decoder_layers",
20
+ "encoder_attention_heads", "encoder_ffn_dim", "encoder_layerdrop", "encoder_layers",
21
+ "num_hidden_layers"] })
22
+ if not args.tok_mdl_id:
23
+ args.tok_mdl_id = args.mdl_id
24
+
25
+ if args.new_langs:
26
+ args.new_langs = lang_set_maybe_smugri(args.new_langs)
27
+
28
+ if os.path.exists(args.save_location):
29
+ raise Exception(f"Save location '{args.save_location}' already exists, don't want to overwrite")
30
+
31
+ config = get_changed_config(AutoConfig.from_pretrained(args.mdl_id), args)
32
+
33
+ model = AutoModelForSeq2SeqLM.from_config(config)
34
+
35
+ tokenizer, added = train_or_extend_tokenizer_and_upd_model(args, model)
36
+
37
+ tokenizer.save_pretrained(args.save_location)
38
+ save_postokens(added, args.save_location)
39
+ model.save_pretrained(args.save_location)
40
+
41
+ mdl_size, emb_size = mdl_param_count(model)
42
+ print(f"Created model with {mdl_size} parameters" +
43
+ ("" if emb_size < 0 else f" of which {emb_size} ({100 * emb_size / mdl_size:.2f}%) are embeddings"))
44
+
45
+ if __name__ == '__main__':
46
+ just_do_main_stuff_and_avoid_global_ctx_variables()
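
A hypothetical invocation, assuming CmdlineArgs takes positional arguments followed by key=value pairs as in the commented-out examples in oldtrainllm.py further below; the model id, paths and values are placeholders:

    python initmodel.py facebook/nllb-200-distilled-600M models/scratch-nllb \
        tok_train_file=data/mono.txt new_langs=smugri vocab_size=32000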
kuidastaltsutadalaamat/legacy/langconv.py ADDED
@@ -0,0 +1,260 @@
1
+ """
2
+ Convert lang. codes between different schemas
3
+
4
+ NLLB uses codes like "eng_Latn": ISO-639-3 and script
5
+
6
+ MADLAD uses codes like "<2en>": ISO-639-1 for where it's available, ISO-639-3 elsewhere
7
+ (and some codes include the script, but we'll ignore them here)
8
+
9
+ Functions at the end of the file (any_to_nllb, any_to_madlad) should
10
+ cope with a lang code in any style ('en', 'eng', 'eng_Latn', '<2en>', '<2eng>', etc)
11
+ and convert them to corresponding representations (NLLB/MADLAD).
12
+ """
13
+ from collections import defaultdict
14
+
15
+ SMUGRI_LOW = "fkv,izh,kca,koi,kpv,krl,liv,lud,mdf,mhr,mns,mrj,myv,olo,sjd,sje,sju,sma,sme,smj,smn,sms,udm,vep,vot,vro"
16
+ SMUGRI_HIGH = "deu,eng,est,fin,hun,lvs,nor,rus,swe"
17
+ SMUGRI = SMUGRI_HIGH + "," + SMUGRI_LOW
18
+
19
+ import pycountry
20
+
21
+ # madlad all codes
22
+ MADLAD_CODES = ['<2meo>', '<2lo>', '<2Grek>', '<2ada>', '<2ps>', '<2arn>', '<2Armn>', '<2to>', '<2raj>', '<2bas>', '<2ny>', '<2>', '<2zza>', '<2Thai>', '<2kaa_Latn>', '<2yap>', '<2en_xx_simple>', '<2ta>', '<2bg_Latn>', '<2mkn>', '<2lhu>', '<2gu_Latn>', '<2nzi>', '<2uz>', '<2pis>', '<2cfm>', '<2min>', '<2fon>', '<2tn>', '<2msi>', '<2sw>', '<2Tfng>', '<2teo>', '<2taj>', '<2pap>', '<2sd>', '<2Jpan>', '<2tca>', '<2sr>', '<2an>', '<2fr>', '<2gor>', '<2az>', '<2qvi>', '<2pck>', '<2cak>', '<2ltg>', '<2sah>', '<2tly_IR>', '<2ts>', '<2yo>', '<2hne>', '<2bzj>', '<2tuc>', '<2sh>', '<2da>', '<2gui>', '<2translate>', '<2et>', '<2sja>', '<2nhe>', '<2scn>', '<2dje>', '<2pt>', '<2nog>', '<2fil>', '<2mai>', '<2lb>', '<2bm>', '<2Guru>', '<2gom>', '<2hr>', '<2kg>', '<2uk>', '<2rw>', '<2izz>', '<2Telu>', '<2wuu>', '<2Deva>', '<2or>', '<2is>', '<2om>', '<2iso>', '<2sn>', '<2kjh>', '<2tbz>', '<2suz>', '<2bjn>', '<2lv>', '<2mfe>', '<2tcy>', '<2tyz>', '<2ksw>', '<2nds_NL>', '<2ms>', '<2mam>', '<2ubu>', '<2hil>', '<2mh>', '<2gl>', '<2bew>', '<2ilo>', '<2kbd>', '<2toj>', '<2quf>', '<2jam>', '<2Beng>', '<2tyv>', '<2lmo>', '<2ace>', '<2cab>', '<2sq>', '<2ug>', '<2kac>', '<2ay>', '<2mag>', '<2Arab>', '<2mrj>', '<2cs>', '<2bci>', '<2doi>', '<2zu>', '<2ndc_ZW>', '<2smt>', '<2ho>', '<2ss>', '<2he>', '<2twu>', '<2kjg>', '<2pag>', '<2Latn>', '<2gym>', '<2sus>', '<2zh_Latn>', '<2mps>', '<2lg>', '<2ko>', '<2se>', '<2guc>', '<2mr>', '<2mwl>', '<2dwr>', '<2din>', '<2ffm>', '<2maz>', '<2nia>', '<2nl>', '<2Knda>', '<2jv>', '<2noa>', '<2udm>', '<2kr>', '<2de>', '<2ar>', '<2ZW>', '<2dln>', '<2mn>', '<2ml>', '<2crh>', '<2ha>', '<2ks>', '<2qvc>', '<2fur>', '<2myv>', '<2nv>', '<2ak>', '<2Gujr>', '<2cce>', '<2nso>', '<2sg>', '<2rmc>', '<2mas>', '<2mni>', '<2frp>', '<2my>', '<2xal>', '<2th>', '<2bik>', '<2bho>', '<2inb>', '<2Mlym>', '<2oj>', '<2back_translated>', '<2tet>', '<2gsw>', '<2ff>', '<2hy>', '<2otq>', '<2el>', '<2agr>', '<2br>', '<2alt>', '<2tzo>', '<2chm>', '<2transliterate>', '<2hu>', '<2btx>', '<2vi>', '<2iba>', '<2bg>', '<2gub>', '<2li>', '<2ace_Arab>', '<2qub>', '<2ktu>', '<2bru>', '<2bbc>', '<2ca>', '<2hvn>', '<2sat_Latn>', '<2ku>', '<2shn>', '<2djk>', '<2krc>', '<2io>', '<2ig>', '<2chk>', '<2sm>', '<2Mymr>', '<2Kore>', '<2ary>', '<2lu>', '<2fa>', '<2spp>', '<2af>', '<2ti>', '<2Tibt>', '<2emp>', '<2enq>', '<2kl>', '<2be>', '<2srn>', '<2ms_Arab_BN>', '<2kri>', '<2gd>', '<2mk>', '<2syr>', '<2kmz_Latn>', '<2CA>', '<2ium>', '<2abt>', '<2ngu>', '<2tab>', '<2it>', '<2ru>', '<2ann>', '<2msm>', '<2fo>', '<2ne>', '<2akb>', '<2kv>', '<2jac>', '<2ceb>', '<2ang>', '<2tdx>', '<2tr>', '<2kbp>', '<2mgh>', '<2az_RU>', '<2acf>', '<2tg>', '<2dov>', '<2pau>', '<2mg>', '<2fuv>', '<2nn>', '<2Hant>', '<2hui>', '<2ml_Latn>', '<2ja>', '<2lus>', '<2te>', '<2qu>', '<2rom>', '<2tsg>', '<2el_Latn>', '<2cr_Latn>', '<2ur>', '<2fi>', '<2shp>', '<2brx>', '<2laj>', '<2sda>', '<2lij>', '<2st>', '<2bn>', '<2zxx_xx_dtynoise>', '<2yua>', '<2no>', '<2fr_CA>', '<2miq>', '<2trp>', '<2es>', '<2ch>', '<2mass>', '<2os>', '<2bts>', '<2ady>', '<2lrc>', '<2seh>', '<2adh>', '<2new>', '<2mak>', '<2grc>', '<2nus>', '<2tzj>', '<2nut>', '<2gu>', '<2oc>', '<2ppk>', '<2Hans>', '<2tzh>', '<2si>', '<2wo>', '<2nyu>', '<2Hebr>', '<2mad>', '<2tll>', '<2kr_Arab>', '<2pon>', '<2mbt>', '<2kw>', '<2bjn_Arab>', '<2gn>', '<2eu>', '<2dz>', '<2kaa>', '<2crh_Latn>', '<2te_Latn>', '<2ky>', '<2kn_Latn>', '<2kum>', '<2fip>', '<2ksd>', '<2sk>', '<2NL>', '<2ctd_Latn>', '<2Khmr>', '<2gbm>', '<2Cans>', '<2haw>', '<2gag>', '<2Taml>', '<2cnh>', '<2bim>', '<2ms_Arab>', '<2Thaa>', '<2kha>', 
'<2tvl>', '<2Cyrl>', '<2chr>', '<2dtp>', '<2ba>', '<2nan_Latn_TW>', '<2ro>', '<2ctu>', '<2Ethi>', '<2zh>', '<2ln>', '<2ve>', '<2xh>', '<2skr>', '<2ber>', '<2niq>', '<2ibb>', '<2jvn>', '<2tks>', '<2av>', '<2ahk>', '<2tk>', '<2tt>', '<2ka>', '<2tsc>', '<2km>', '<2co>', '<2id>', '<2prs>', '<2rki>', '<2kmb>', '<2ks_Deva>', '<2ify>', '<2wal>', '<2arz>', '<2amu>', '<2rm>', '<2pa>', '<2RU>', '<2ce>', '<2hi>', '<2eo>', '<2taq>', '<2ga>', '<2qxr>', '<2la>', '<2bi>', '<2rwo>', '<2dyu>', '<2zh_Hant>', '<2mt>', '<2bqc>', '<2bn_Latn>', '<2zne>', '<2szl>', '<2lt>', '<2sl>', '<2hif>', '<2alz>', '<2ber_Latn>', '<2ckb>', '<2wa>', '<2Cher>', '<2msb>', '<2gom_Latn>', '<2ru_Latn>', '<2crs>', '<2kk>', '<2gvl>', '<2qvz>', '<2bar>', '<2qup>', '<2bgp>', '<2bo>', '<2su>', '<2tzm>', '<2IR>', '<2sv>', '<2srm>', '<2rn>', '<2bus>', '<2jiv>', '<2awa>', '<2gv>', '<2knj>', '<2as>', '<2quc>', '<2en>', '<2sa>', '<2bug>', '<2quy>', '<2hi_Latn>', '<2nds>', '<2kek>', '<2mrw>', '<2kos>', '<2cy>', '<2ta_Latn>', '<2kn>', '<2nr>', '<2ape>', '<2bs>', '<2iu>', '<2nnb>', '<2Geor>', '<2rcf>', '<2meu>', '<2cac>', '<2cuk>', '<2bua>', '<2vec>', '<2so>', '<2fj>', '<2gof>', '<2koi>', '<2cv>', '<2guh>', '<2war>', '<2pl>', '<2cbk>', '<2kj>', '<2dv>', '<2mdf>', '<2fy>', '<2am>', '<2sc>', '<2taq_Tfng>', '<2mi>', '<2zap>', '<2mqy>', '<2yi>', '<2kwi>', '<2hmn>', '<2tiv>', '<2sxn>', '<2hus>', '<2ban>', '<2nij>', '<2tlh>', '<2Orya>', '<2quh>', '<2ee>', '<2ht>', '<2bum>', '<2stq>']
23
+
24
+ # NLLB all codes
25
+ NLLB_CODES = ['ace_Arab', 'ace_Latn', 'acm_Arab', 'acq_Arab', 'aeb_Arab', 'afr_Latn', 'ajp_Arab', 'aka_Latn', 'amh_Ethi', 'apc_Arab', 'arb_Arab', 'ars_Arab', 'ary_Arab', 'arz_Arab', 'asm_Beng', 'ast_Latn', 'awa_Deva', 'ayr_Latn', 'azb_Arab', 'azj_Latn', 'bak_Cyrl', 'bam_Latn', 'ban_Latn', 'bel_Cyrl', 'bem_Latn', 'ben_Beng', 'bho_Deva', 'bjn_Arab', 'bjn_Latn', 'bod_Tibt', 'bos_Latn', 'bug_Latn', 'bul_Cyrl', 'cat_Latn', 'ceb_Latn', 'ces_Latn', 'cjk_Latn', 'ckb_Arab', 'crh_Latn', 'cym_Latn', 'dan_Latn', 'deu_Latn', 'dik_Latn', 'dyu_Latn', 'dzo_Tibt', 'ell_Grek', 'eng_Latn', 'epo_Latn', 'est_Latn', 'eus_Latn', 'ewe_Latn', 'fao_Latn', 'pes_Arab', 'fij_Latn', 'fin_Latn', 'fon_Latn', 'fra_Latn', 'fur_Latn', 'fuv_Latn', 'gla_Latn', 'gle_Latn', 'glg_Latn', 'grn_Latn', 'guj_Gujr', 'hat_Latn', 'hau_Latn', 'heb_Hebr', 'hin_Deva', 'hne_Deva', 'hrv_Latn', 'hun_Latn', 'hye_Armn', 'ibo_Latn', 'ilo_Latn', 'ind_Latn', 'isl_Latn', 'ita_Latn', 'jav_Latn', 'jpn_Jpan', 'kab_Latn', 'kac_Latn', 'kam_Latn', 'kan_Knda', 'kas_Arab', 'kas_Deva', 'kat_Geor', 'knc_Arab', 'knc_Latn', 'kaz_Cyrl', 'kbp_Latn', 'kea_Latn', 'khm_Khmr', 'kik_Latn', 'kin_Latn', 'kir_Cyrl', 'kmb_Latn', 'kon_Latn', 'kor_Hang', 'kmr_Latn', 'lao_Laoo', 'lvs_Latn', 'lij_Latn', 'lim_Latn', 'lin_Latn', 'lit_Latn', 'lmo_Latn', 'ltg_Latn', 'ltz_Latn', 'lua_Latn', 'lug_Latn', 'luo_Latn', 'lus_Latn', 'mag_Deva', 'mai_Deva', 'mal_Mlym', 'mar_Deva', 'min_Latn', 'mkd_Cyrl', 'plt_Latn', 'mlt_Latn', 'mni_Beng', 'khk_Cyrl', 'mos_Latn', 'mri_Latn', 'zsm_Latn', 'mya_Mymr', 'nld_Latn', 'nno_Latn', 'nob_Latn', 'npi_Deva', 'nso_Latn', 'nus_Latn', 'nya_Latn', 'oci_Latn', 'gaz_Latn', 'ory_Orya', 'pag_Latn', 'pan_Guru', 'pap_Latn', 'pol_Latn', 'por_Latn', 'prs_Arab', 'pbt_Arab', 'quy_Latn', 'ron_Latn', 'run_Latn', 'rus_Cyrl', 'sag_Latn', 'san_Deva', 'sat_Beng', 'scn_Latn', 'shn_Mymr', 'sin_Sinh', 'slk_Latn', 'slv_Latn', 'smo_Latn', 'sna_Latn', 'snd_Arab', 'som_Latn', 'sot_Latn', 'spa_Latn', 'als_Latn', 'srd_Latn', 'srp_Cyrl', 'ssw_Latn', 'sun_Latn', 'swe_Latn', 'swh_Latn', 'szl_Latn', 'tam_Taml', 'tat_Cyrl', 'tel_Telu', 'tgk_Cyrl', 'tgl_Latn', 'tha_Thai', 'tir_Ethi', 'taq_Latn', 'taq_Tfng', 'tpi_Latn', 'tsn_Latn', 'tso_Latn', 'tuk_Latn', 'tum_Latn', 'tur_Latn', 'twi_Latn', 'tzm_Tfng', 'uig_Arab', 'ukr_Cyrl', 'umb_Latn', 'urd_Arab', 'uzn_Latn', 'vec_Latn', 'vie_Latn', 'war_Latn', 'wol_Latn', 'xho_Latn', 'ydd_Hebr', 'yor_Latn', 'yue_Hant', 'zho_Hans', 'zho_Hant', 'zul_Latn']
26
+
27
+ MDL_NLLB = "MDL_NLLB"
28
+ MDL_MADLAD = "MDL_MADLAD"
29
+ MDL_NEUROTOLGE = "MDL_NEUROTÕLGE"
30
+ MDL_LLAMA = "MDL_LLAMA"
31
+
32
+ _iso3_to_script = dict([nllb_code.split("_") for nllb_code in NLLB_CODES])
33
+
34
+ iso3_to_nllb = { code: f"{code}_{_iso3_to_script[code]}" for code in _iso3_to_script }
35
+
36
+ iso3_to_nllb['lav'] = "lvs_Latn"
37
+ iso3_to_nllb['nor'] = "nob_Latn"
38
+ iso3_to_nllb['yid'] = "ydd_Hebr"
39
+
40
+ for lang in "fkv izh krl liv lud olo sje sju sma sme smj smn sms vep vot vro".split():
41
+ iso3_to_nllb[lang] = f"{lang}_Latn"
42
+
43
+ for lang in "kca koi kpv mdf mhr mns mrj myv sjd udm".split():
44
+ iso3_to_nllb[lang] = f"{lang}_Cyrl"
45
+
46
+
47
+ _rev_joshi = defaultdict(lambda: "?")
48
+
49
+ for k in "krl,sma,vep,smj,smn,lud,liv,izh,vot,kca,sms,sje,mns,fkv,sju,sjd".split(","):
50
+ _rev_joshi[k] = "0"
51
+ for k in "kpv,sme,mhr,udm,olo,myv,mdf,vro,mrj,koi".split(","):
52
+ _rev_joshi[k] = "1"
53
+ for k in SMUGRI_HIGH.split(","):
54
+ _rev_joshi[k] = "2+"
55
+
56
+
57
+ def guess_script(lang):
58
+ return "Unk"
59
+
60
+
61
+ def get_high_set():
62
+ return set(SMUGRI_HIGH.split(",")) - {"deu", "swe"}
63
+
64
+
65
+ def clean_lang(raw_lang):
66
+ if "<2" in raw_lang:
67
+ raw_lang = raw_lang[2:-1]
68
+
69
+ if "_" in raw_lang:
70
+ return raw_lang.split("_")[0]
71
+ else:
72
+ return raw_lang
73
+
74
+
75
+ def any_to_base(lang):
76
+ clang = clean_lang(lang)
77
+
78
+ res = pycountry.languages.get(alpha_2=clang)
79
+
80
+ if res is None:
81
+ return pycountry.languages.get(alpha_3=clang)
82
+ else:
83
+ return res
84
+
85
+
86
+ def base_to_nllb(lang_entry=None, lang_code=None):
87
+ if lang_code is None:
88
+ lang_code = lang_entry.alpha_3
89
+
90
+ try:
91
+ #script = iso3_to_script[lang_code]
92
+ return iso3_to_nllb[lang_code]
93
+ except KeyError:
94
+ script = guess_script(lang_code)
95
+ return f"{lang_code}_{script}"
96
+
97
+
98
+ def base_to_madlad(lang_entry=None, lang_code=None):
99
+ if lang_code is None:
100
+ if hasattr(lang_entry, 'alpha_2'):
101
+ lang_code = lang_entry.alpha_2
102
+ else:
103
+ lang_code = lang_entry.alpha_3
104
+
105
+ return f"<2{lang_code}>"
106
+
107
+
108
+ def any_to_something(lang, conv_func):
109
+ base = any_to_base(lang)
110
+
111
+ if base is None:
112
+ clang = clean_lang(lang)
113
+ return conv_func(None, clang)
114
+ else:
115
+ return conv_func(base)
116
+
117
+
118
+ def run_test(src_list, tgt_list, conv_func, msg_prefix, verbose=False):
119
+ ok_count = 0
120
+ err_count = 0
121
+ fail_count = 0
122
+
123
+ for raw_c in src_list:
124
+ try:
125
+ test = conv_func(raw_c)
126
+ if test in tgt_list:
127
+ ok_count += 1
128
+ else:
129
+ fail_count += 1
130
+ if verbose:
131
+ print("FAIL:", test)
132
+ except KeyError:
133
+ err_count += 1
134
+ if verbose:
135
+ print("ERR:", raw_c)
136
+
137
+ print(f"{msg_prefix}: {ok_count} good, {fail_count} fail, {err_count} err")
138
+
139
+
140
+ def any_to_madlad(lang):
141
+ return any_to_something(lang, base_to_madlad)
142
+
143
+
144
+ def any_to_nllb(lang):
145
+ return any_to_something(lang, base_to_nllb)
146
+
147
+
148
+ def any_to_neurotolge(lang):
149
+ l = any_to_base(lang).alpha_3
150
+
151
+ return l if l != 'lvs' else 'lv'
152
+
153
+
154
+ def any_to_mdl_type(mdl_type, lang):
155
+ if mdl_type == MDL_NLLB:
156
+ return any_to_nllb(lang)
157
+ elif mdl_type == MDL_MADLAD:
158
+ return any_to_madlad(lang)
159
+ elif mdl_type is None:
160
+ return lang
161
+ elif mdl_type == MDL_LLAMA:
162
+ return lang
163
+ else:
164
+ raise ValueError(f"Unknown mdl_type {mdl_type}")
165
+
166
+ def langs_to_madlad(lang_set):
167
+ return [any_to_madlad(l) for l in lang_set] if lang_set is not None else []
168
+
169
+
170
+ def langs_to_nllb(lang_set):
171
+ return [any_to_nllb(l) for l in lang_set] if lang_set is not None else []
172
+
173
+
174
+ if __name__ == "__main__":
175
+ run_test(NLLB_CODES, MADLAD_CODES, any_to_madlad, "NLLB to MADLAD")
176
+ run_test(NLLB_CODES, NLLB_CODES, any_to_nllb, "NLLB to NLLB")
177
+ run_test(MADLAD_CODES, NLLB_CODES, any_to_nllb, "MADLAD TO NLLB")
178
+ run_test(MADLAD_CODES, MADLAD_CODES, any_to_madlad, "MADLAD TO MADLAD")
179
+
180
+
181
+ def is_nllb(object):
182
+ """
183
+ Check if the object is an NLLB model or tokenizer
184
+ """
185
+ name = object.__class__.__name__.lower()
186
+ return "m2m100" in name or "nllb" in name
187
+
188
+
189
+ def is_madlad(object):
190
+ """
191
+ Check if the object is a MADLAD model or tokenizer
192
+ """
193
+ return "t5" in object.__class__.__name__.lower()
194
+
195
+
196
+ def is_dec_only_llm(obj):
197
+ lcname = obj.__class__.__name__.lower()
198
+ return any(k in lcname for k in ["pretrainedtokenizerfast", "llama", "gemma"])
199
+
200
+
201
+ def get_mdl_type(obj):
202
+ obj = obj.module if hasattr(obj, "module") else obj
203
+
204
+ if is_nllb(obj):
205
+ return MDL_NLLB
206
+ elif is_madlad(obj):
207
+ return MDL_MADLAD
208
+ elif is_dec_only_llm(obj):
209
+ return MDL_LLAMA
210
+ else:
211
+ raise ValueError(f"Object {str(obj)[:200]} is not supported")
212
+
213
+
214
+ def langs_to_mdl_type(mdl_type, lang_set):
215
+ if mdl_type == MDL_NLLB:
216
+ return langs_to_nllb(lang_set)
217
+ elif mdl_type == MDL_MADLAD:
218
+ return langs_to_madlad(lang_set)
219
+ elif mdl_type == MDL_LLAMA:
220
+ return lang_set
221
+ else:
222
+ raise ValueError(f"Model type {mdl_type} is not supported")
223
+
224
+
225
+ def get_joshi_class(lang_code):
226
+ norm_code = any_to_base(lang_code)
227
+
228
+ if norm_code is None:
229
+ return "?"
230
+ else:
231
+ norm_code = norm_code.alpha_3
232
+
233
+ return _rev_joshi[norm_code]
234
+
235
+ def lang_set_maybe_smugri(lang_def):
236
+ if lang_def == "smugri-low":
237
+ preresult = SMUGRI_LOW
238
+ elif lang_def == "smugri-high":
239
+ preresult = SMUGRI_HIGH
240
+ elif lang_def == "smugri":
241
+ preresult = SMUGRI
242
+ else:
243
+ preresult = lang_def
244
+
245
+ return set(preresult.split(","))
246
+
247
+
248
+ def smugri_back(lang_list):
249
+ sll = sorted(lang_list)
250
+
251
+ sll_str = ",".join(sll)
252
+
253
+ if sll_str == SMUGRI_LOW:
254
+ return "smugri-low"
255
+ elif sll_str == SMUGRI_HIGH:
256
+ return "smugri-high"
257
+ elif sll_str == SMUGRI:
258
+ return "smugri-full"
259
+ else:
260
+ return sll_str
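
A few conversions these helpers should yield, as a quick sanity sketch (not part of the commit):

    any_to_nllb("et")          # -> 'est_Latn'  (ISO-639-1 resolved via pycountry, script from the table)
    any_to_nllb("<2fi>")       # -> 'fin_Latn'  (MADLAD-style tag is stripped first)
    any_to_madlad("est_Latn")  # -> '<2et>'     (alpha_2 preferred when it exists)
    any_to_madlad("liv")       # -> '<2liv>'    (no ISO-639-1 code, falls back to alpha_3)
    get_joshi_class("vro")     # -> '1'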
kuidastaltsutadalaamat/legacy/localizemodel.py ADDED
@@ -0,0 +1,45 @@
1
+ #!/usr/bin/env python3
2
+
3
+ import os
4
+
5
+ from transformers import AutoModelForSeq2SeqLM, AutoModelForCausalLM
6
+ from modelops import mdl_param_count, is_gen_ai, hf_tok
7
+ from tokops import train_or_extend_tokenizer_and_upd_model, save_postokens
8
+ from aux import CmdlineArgs, log
9
+ from legacy.langconv import lang_set_maybe_smugri
10
+
11
+
12
+ def i_dont_like_global_scope_variable_dangers():
13
+ args = CmdlineArgs("Localize an existing HuggingFace model, possibly expanding the tokenizer",
14
+ pos_arg_list=["mdl_id", "save_location"],
15
+ kw_arg_dict={"tok_train_file": None,
16
+ "tok_mdl_id": None,
17
+ "new_langs": None,
18
+ "merge_tokenizers": 0,
19
+ "merge_tok_mdl_id": None })
20
+ if not args.tok_mdl_id:
21
+ args.tok_mdl_id = args.mdl_id
22
+
23
+ if os.path.exists(args.save_location):
24
+ raise Exception(f"Save location '{args.save_location}' already exists, don't want to overwrite")
25
+
26
+ if args.new_langs:
27
+ args.new_langs = lang_set_maybe_smugri(args.new_langs)
28
+
29
+ if is_gen_ai(args.mdl_id):
30
+ model = AutoModelForCausalLM.from_pretrained(args.mdl_id, token=hf_tok)
31
+ else:
32
+ model = AutoModelForSeq2SeqLM.from_pretrained(args.mdl_id, token=hf_tok)
33
+
34
+ tokenizer, added = train_or_extend_tokenizer_and_upd_model(args, model)
35
+
36
+ mdl_size, emb_size = mdl_param_count(model)
37
+ log(f"Cached model with {mdl_size} parameters" +
38
+ ("" if emb_size < 0 else f" of which {emb_size} ({100 * emb_size / mdl_size:.2f}%) are embeddings"))
39
+
40
+ tokenizer.save_pretrained(args.save_location)
41
+ save_postokens(added, args.save_location)
42
+ model.save_pretrained(args.save_location)
43
+
44
+ if __name__ == '__main__':
45
+ i_dont_like_global_scope_variable_dangers()
kuidastaltsutadalaamat/legacy/modelops.py ADDED
@@ -0,0 +1,122 @@
1
+ import os
2
+ from collections import namedtuple
3
+
4
+ import torch
5
+
6
+ from aux import log
7
+
8
+ CouplingSpecTuple = namedtuple("CouplingSpecPair", ["lang_set", "tokenizer", "postokenizer", "model_id", "model"])
9
+
10
+ hf_tok = None
11
+ with open("../hf_token", 'r') as fh:
12
+ hf_tok = fh.read().strip()
13
+
14
+
15
+
16
+ MODULE_CONFIG_FILE = "coupled_module_config.json"
17
+ DATA_STATE_FILE = "data_state.json"
18
+ LOSS_LIST_FILE = "loss_list.json"
19
+
20
+
21
+ def mdl_param_count(model):
22
+ result = 0
23
+ embedding_size = -1
24
+
25
+ for n, p in model.named_parameters():
26
+ this_count = 1
27
+
28
+ for s in p.shape:
29
+ this_count *= s
30
+
31
+ result += this_count
32
+
33
+ # if n == "model.shared.weight":
34
+
35
+ if "shared.weight" in n:
36
+ embedding_size = this_count
37
+
38
+ return result, embedding_size
39
+
40
+ """
41
+
42
+ def to_cpl_spec(langs, model, tokenizer, postokenizer, location):
43
+ mdl_type = get_mdl_type(tokenizer)
44
+ cpl_langs = set(langs_to_mdl_type(mdl_type, langs))
45
+
46
+ return [CouplingSpecTuple(cpl_langs, tokenizer, postokenizer, location, model)]
47
+
48
+
49
+ def _save_json_config(model_dir, filename, data):
50
+ with open(os.path.join(model_dir, filename), "w") as f:
51
+ json.dump(data, f, indent=2, sort_keys=True)
52
+ f.write("\n")
53
+
54
+
55
+ def _load_json_config(model_dir, filename):
56
+ try:
57
+ with open(os.path.join(model_dir, filename), "r") as f:
58
+ return json.load(f)
59
+ except FileNotFoundError:
60
+ return None
61
+
62
+ def save_module_config(model_dir, coupling_specs):
63
+ config = [{'lang_set': list(spec.lang_set), 'model_id': spec.model_id if i > 0 else model_dir} for i, spec in enumerate(coupling_specs)]
64
+ _save_json_config(model_dir, MODULE_CONFIG_FILE, config)
65
+
66
+
67
+ def load_module_config(model_dir):
68
+ result = _load_json_config(model_dir, MODULE_CONFIG_FILE)
69
+
70
+ return result if result is not None else [{"model_id": model_dir, "lang_set": {}}]
71
+ """
72
+
73
+ def save_all_models(location, model, tokenizer, cpl_specs=None, trainer=None):
74
+ if not os.path.exists(location):
75
+ os.makedirs(location)
76
+
77
+ if trainer is not None:
78
+ trainer.save_state(location)
79
+
80
+ model.config.save_pretrained(location)
81
+ model.generation_config.save_pretrained(location)
82
+
83
+ tokenizer.save_pretrained(location)
84
+ """
85
+ if cpl_specs is not None:
86
+ save_module_config(location, cpl_specs)
87
+ """
88
+
89
+
90
+ def report_devices(msg = "", accelerator = None, mdl = None):
91
+ if torch.cuda.is_available():
92
+ # Get the visible devices from CUDA
93
+ visible_devices = torch.cuda.device_count()
94
+
95
+ #log(f"Number of visible GPUs: {visible_devices}")
96
+ msg = f"{msg:30} {visible_devices} GPUs:"
97
+
98
+ # List the actual GPUs being used
99
+ gpu_names = [torch.cuda.get_device_name(i) for i in range(visible_devices)]
100
+ for i, name in enumerate(gpu_names):
101
+ mem_alloc = torch.cuda.memory_allocated(i) / 1024**2
102
+ mem_res = torch.cuda.memory_reserved(i) / 1024**2
103
+
104
+ if mem_alloc > 0.01 or mem_res > 0.01:
105
+ msg += f" {i}: alloc {mem_alloc:.2f} Mb / res {mem_res:.2f} Mb;"
106
+
107
+ log(msg, accelerator=accelerator)
108
+ elif accelerator is not None and accelerator.device.type == "mps":
109
+ mem_alloc = torch.mps.current_allocated_memory() / 1024**2
110
+ log(f"{msg:30} device being used: {accelerator.device}, mem alloc: {mem_alloc} Mb", accelerator=accelerator)
111
+ else:
112
+ log(f"No acceleration")
113
+
114
+ #if mdl is not None:
115
+ # log(f"Model device: {mdl.device}", accelerator=accelerator)
116
+
117
+
118
+ def is_gen_ai(mdl_id):
119
+ lc = mdl_id.lower()
120
+ return not ("madlad" in lc or "nllb" in lc or "m2m" in lc or "bart" in lc)
121
+
122
+
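
A small sketch of the parameter-count helper (model id is a placeholder); the second return value is the size of the shared embedding matrix, or -1 if none was found:

    from transformers import AutoModelForSeq2SeqLM

    mdl = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
    total, emb = mdl_param_count(mdl)
    print(f"{total} parameters, embeddings: {emb} ({100 * emb / total:.2f}%)")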
kuidastaltsutadalaamat/legacy/oldtrainllm.py ADDED
@@ -0,0 +1,90 @@
1
+ #!/usr/bin/env python3
2
+
3
+ import os
4
+ import json
5
+ import torch
6
+ import sys
7
+
8
+ from accelerate import Accelerator
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer
10
+
11
+ from accel import SwitchingAccelerator
12
+ from modelops import hf_tok, save_all_models
13
+
14
+ from aux import log, CmdlineArgs
15
+ from data import do_list_in_batches
16
+
17
+
18
+ def _cmdline_args():
19
+ description = """Train or tune decoder models"""
20
+
21
+ result = CmdlineArgs(description,
22
+ pos_arg_list=["mdl_id", "save_location", "train_file"],
23
+ pos_arg_types=[str, str, str],
24
+ kw_arg_dict={ "continue_training": False, "save_steps": 100, "lr": 1.5e-5,
25
+ "batch_size": 1024, "nr_sents_per_gpu": 4, "log_steps": 1, "epochs": 4,
26
+ "max_length": 3000 })
27
+
28
+ # if the directory args.save_location already exists, raise an exception:
29
+ if not result.continue_training and os.path.exists(result.save_location):
30
+ raise Exception(f"Save location '{result.save_location}' already exists, don't want to overwrite.")
31
+
32
+ if result.nr_sents_per_gpu == 0:
33
+ result.nr_sents_per_gpu = result.batch_size
34
+
35
+ return result
36
+
37
+
38
+ def load_json_list(json_file):
39
+ with open(json_file, "r") as f:
40
+ data = json.load(f)
41
+ return data
42
+
43
+
44
+ def load_hf_model(mdl_id, accelerator=None):
45
+ if accelerator is None:
46
+ model = AutoModelForCausalLM.from_pretrained(mdl_id, token=hf_tok, torch_dtype=torch.bfloat16)
47
+ else:
48
+ model = AutoModelForCausalLM.from_pretrained(mdl_id, token=hf_tok, torch_dtype=torch.bfloat16, device_map=accelerator.device)
49
+ return model
50
+
51
+
52
+ def load_hf_tokenizer(mdl_id):
53
+ tokenizer = AutoTokenizer.from_pretrained(mdl_id, token=hf_tok)
54
+ return tokenizer
55
+
56
+
57
+ def _no_globals_main():
58
+ accelerator = Accelerator()
59
+
60
+ try:
61
+ args = _cmdline_args()
62
+
63
+ log(f"Num proc: {accelerator.num_processes}, proc ID: {accelerator.process_index}")
64
+ log("loading model", accelerator=accelerator)
65
+ mdl = load_hf_model(args.mdl_id)
66
+
67
+ log("loading tokenizer", accelerator=accelerator)
68
+ tok = load_hf_tokenizer(args.mdl_id)
69
+
70
+ log("loading data", accelerator=accelerator, all_threads=True)
71
+ train_set = load_json_list(args.train_file)
72
+
73
+ log("training", accelerator=accelerator)
74
+
75
+ acc_trainer = SwitchingAccelerator(train_set, args, mdl, tok, preinit_acc=accelerator)
76
+ upd_model = acc_trainer.train()
77
+
78
+ log("saving", accelerator=accelerator)
79
+ save_all_models(args.save_location, upd_model, tok)
80
+ except Exception as e:
81
+ # in multiprocess scenarios it is hard to read the stack trace, so just show one:
82
+ if accelerator.is_main_process:
83
+ raise e
84
+
85
+
86
+ if __name__ == "__main__":
87
+ #sys.argv = "_ models/llama3.2-1b models/newmdl tmp.json".split()
88
+ #sys.argv = "_ models/llama3.2-1b models/newmdl2 tmpx.json batch_size=16 nr_sents_per_gpu=1 log_steps=1 save_steps=2000 epochs=1".split()
89
+
90
+ _no_globals_main()
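
A launch sketch mirroring the commented-out sys.argv examples above; since the script builds an Accelerator, it would typically be started through accelerate (paths and values are placeholders):

    accelerate launch oldtrainllm.py models/llama3.2-1b models/newmdl train.json \
        batch_size=16 nr_sents_per_gpu=1 log_steps=1 save_steps=2000 epochs=1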
kuidastaltsutadalaamat/legacy/parasynth.py ADDED
@@ -0,0 +1,139 @@
1
+ #!/usr/bin/env python3
2
+
3
+ import sys
4
+ import json
5
+
6
+ from collections import defaultdict
7
+
8
+ from benchmark import get_hyp_cache_dir, translate_all_hyps
9
+ from inference import load_and_init_module_config
10
+ from legacy.langconv import get_high_set, any_to_mdl_type, get_mdl_type
11
+ from accelerate import Accelerator
12
+ from aux import log
13
+
14
+
15
+ def load_raw_data(path):
16
+ with open(path, 'r') as f:
17
+ return json.load(f)
18
+
19
+
20
+ def save_raw_data(path, data):
21
+ with open(path, 'w') as f:
22
+ json.dump(data, f, indent=2)
23
+
24
+
25
+ def apply_func_to_hires_snts(snt_set, func):
26
+ high_set = get_high_set()
27
+
28
+ for tupl in snt_set:
29
+ langs = [k for k in tupl if "-dia" not in k and k in high_set]
30
+
31
+ if langs:
32
+ revlangs = high_set - set(langs)
33
+
34
+ for revlang in revlangs:
35
+ for lang in langs:
36
+ # translate sentences tupl[lang] from lang to revlang
37
+ # OR
38
+ # add the result as tupl[revlang]
39
+ func(tupl, lang, revlang)
40
+
41
+
42
+ def report_part_stats(part, part_index, num_parts):
43
+ hi_set = get_high_set()
44
+
45
+ num_snts = len(part['sentences'])
46
+ hires_langs = {k for k in part['sentences'][0] if "dia" not in k and k in hi_set}
47
+ num_hires_langs = len(hires_langs)
48
+ langs_to_do = hi_set - hires_langs
49
+ num_to_translate = num_hires_langs * len(langs_to_do)
50
+
51
+ log(f"Part {part_index + 1}/{num_parts}; {num_snts} sentences, num hires: {num_hires_langs}, to translate: {num_to_translate}")
52
+
53
+ return num_snts * num_hires_langs, num_snts * num_to_translate
54
+
55
+
56
+ def add_hires_synth_data(mdl_id, corpus_in, corpus_out, dry=False):
57
+ accelerator = Accelerator()
58
+
59
+ log("Loading data", accelerator)
60
+ data = load_raw_data(corpus_in)
61
+
62
+
63
+ log("Loading model", accelerator)
64
+ if dry:
65
+ main_model, module_config = None, None
66
+ mdl_type = None
67
+ else:
68
+ main_model, module_config = load_and_init_module_config(mdl_id, accelerator)
69
+ mdl_type = get_mdl_type(main_model)
70
+
71
+ if accelerator.is_main_process:
72
+ _ = get_hyp_cache_dir(mdl_id, create=True)
73
+ l = len(data)
74
+
75
+ tot_snt = 0
76
+ tot_tr = 0
77
+
78
+ for i, part in enumerate(data):
79
+ tr_dict = defaultdict(lambda: defaultdict(lambda: None))
80
+
81
+ num_snt, num_tr = report_part_stats(part, i, l)
82
+ tot_snt += num_snt
83
+ tot_tr += num_tr
84
+
85
+ if not dry:
86
+ def _transfer(tup, src, tgt):
87
+ srcm = any_to_mdl_type(mdl_type, src)
88
+ tgtm = any_to_mdl_type(mdl_type, tgt)
89
+
90
+ lp = f"{srcm}-{tgtm}"
91
+ inp_snt = tup[src]
92
+
93
+ # this "touches" the value: if it was not there, now it is None
94
+ # and if it was there, then we use it
95
+ if tr_dict[lp][inp_snt] is not None:
96
+ tup[tgt] = tr_dict[lp][inp_snt]
97
+
98
+ # collect sentences to translate
99
+ apply_func_to_hires_snts(part['sentences'], _transfer)
100
+
101
+ in_tr_dict_list = { lp: sorted(tr_dict[lp].items()) for lp in tr_dict }
102
+
103
+ log(f"Translating part {i+1}/{l}", accelerator)
104
+ #translate_cache_dict(tr_dict, mdl_id, module_config, corpus_in, accelerator)
105
+ translate_all_hyps(in_tr_dict_list, module_config, mdl_id, f"{corpus_in}-{i}", accelerator)
106
+
107
+ log(f"Collecting part {i+1}/{l}", accelerator)
108
+ out_tr_dict_list = translate_all_hyps(in_tr_dict_list, module_config, mdl_id, corpus_in)
109
+
110
+ for lp in out_tr_dict_list:
111
+ for inp, outp in out_tr_dict_list[lp]:
112
+ tr_dict[lp][inp] = outp
113
+
114
+ # put translations back into data structure
115
+ log(f"Integrating part {i+1}/{l}", accelerator)
116
+ apply_func_to_hires_snts(part['sentences'], _transfer)
117
+
118
+ log(f"Total sentences: {tot_snt}, total to generate: {tot_tr}", accelerator)
119
+ if not dry:
120
+ log("Saving data", accelerator)
121
+ save_raw_data(corpus_out, data)
122
+
123
+ if __name__ == '__main__':
124
+ try:
125
+ mdl_id_param = sys.argv[1]
126
+ corpus_param = sys.argv[2]
127
+ corpus_output_param = sys.argv[3]
128
+ except IndexError:
129
+ mdl_id_param = "models/nllb600m"
130
+ corpus_param = "data/flt.json"
131
+ corpus_output_param = "data/fltout.json"
132
+
133
+ try:
134
+ _ = sys.argv[4]
135
+ dry_run = True
136
+ except IndexError:
137
+ dry_run = False
138
+
139
+ add_hires_synth_data(mdl_id_param, corpus_param, corpus_output_param, dry_run)
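
The two passes around _transfer rely on a nested defaultdict "touch": the first pass registers every needed (language pair, source sentence) key with value None, translations are then produced for exactly those keys, and the second pass reads the filled-in values back. A minimal self-contained sketch of that pattern (a dummy translator stands in for translate_all_hyps):

from collections import defaultdict

# tr_dict[lang_pair][source_sentence] -> translation, or None if not translated yet
tr_dict = defaultdict(lambda: defaultdict(lambda: None))

needed = [("et-en", "Tere"), ("et-en", "Aitäh")]

# Pass 1: "touch" the keys; reading a missing key creates it with value None.
for lp, snt in needed:
    _ = tr_dict[lp][snt]

# Translate exactly the touched keys (dummy translator instead of translate_all_hyps).
for lp in tr_dict:
    for snt in tr_dict[lp]:
        tr_dict[lp][snt] = f"<{lp}> {snt}"

# Pass 2: the same lookups now return the filled-in translations.
assert all(tr_dict[lp][snt] is not None for lp, snt in needed)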
kuidastaltsutadalaamat/legacy/pretok.py ADDED
@@ -0,0 +1,65 @@
1
+ #!/usr/bin/env python3
2
+
3
+ import os
4
+
5
+ from data import MultilingualBatchingCachingDataset
6
+ from aux import log, CmdlineArgs
7
+ from legacy.langconv import lang_set_maybe_smugri
8
+ from modelops import to_cpl_spec
9
+ from tokops import load_tokenizer
10
+
11
+ """
12
+ def load_hf_tok(mdl_id, tok_id=None, verbose=False):
13
+ if tok_id is None:
14
+ tok_id = mdl_id
15
+
16
+ tokenizer = AutoTokenizer.from_pretrained(tok_id, token=hf_tok)
17
+
18
+ return tokenizer
19
+ """
20
+
21
+
22
+ def _cmdline_args():
23
+ description = """Pre-tokenize data and cache the results"""
24
+
25
+ pos_args = ["mdl_id", "train_file", "langs", "cache_path"]
26
+ pos_types = [str, str, lang_set_maybe_smugri, str]
27
+
28
+ kw_args = { "anchor_mdl_id": None, "anchor_langs": None, "batch_size": 16, "shard_size": 100000,
29
+ "exclude_set": None, "max_snt_len": 1024, "sort_by_len": False }
30
+
31
+ #post-process the arguments
32
+ args = CmdlineArgs(description, pos_arg_list=pos_args, pos_arg_types=pos_types, kw_arg_dict=kw_args)
33
+
34
+ if args.anchor_langs is not None:
35
+ args.anchor_langs = lang_set_maybe_smugri(args.anchor_langs)
36
+
37
+ # if the directory args.cache_path already exists, raise an exception:
38
+ if os.path.exists(args.cache_path):
39
+ raise Exception(f"Save location '{args.cache_path}' already exists, don't want to overwrite")
40
+
41
+ log(f"Launched as {args}")
42
+
43
+ return args
44
+
45
+
46
+ def oh_look_another_do_main_function():
47
+ args = _cmdline_args()
48
+
49
+ log("loading tokenizer")
50
+ main_tokenizer, main_postok = load_tokenizer(args.mdl_id) #load_hf_tok(args.mdl_id, verbose=True)
51
+
52
+ coupling_specs = to_cpl_spec(args.langs, None, main_tokenizer, main_postok, None)
53
+
54
+ if args.anchor_mdl_id is not None:
55
+ log("loading anchor model tokenizer")
56
+ anchor_tokenizer, anc_postok = load_tokenizer(args.anchor_mdl_id)
57
+
58
+ coupling_specs += to_cpl_spec(args.anchor_langs, None, anchor_tokenizer, anc_postok, None)
59
+
60
+ mbd = MultilingualBatchingCachingDataset(args.train_file, coupling_specs, args)
61
+ mbd.load_and_cache_data(args.cache_path)
62
+
63
+
64
+ if __name__ == "__main__":
65
+ oh_look_another_do_main_function()
kuidastaltsutadalaamat/legacy/testmem.py ADDED
@@ -0,0 +1,100 @@
1
+ #!/usr/bin/env python3
2
+
3
+ import torch.optim
4
+ import sys
5
+ import subprocess
6
+ import random
7
+
8
+ from accelerate import Accelerator
9
+ from transformers import AutoModelForCausalLM, get_scheduler, AutoModelForSeq2SeqLM
10
+ from datasets import load_dataset
11
+
12
+ from aux import CmdlineArgs, log
13
+ from legacy.langconv import is_dec_only_llm
14
+ from modelops import report_devices, hf_tok
15
+ from tokops import load_tokenizer, tokenizeit
16
+
17
+
18
+ def run_test(mdl_id, batch_sizes, ctxlen, acc):
19
+ #state = AcceleratorState()
20
+ log(f"Num proc: {acc.num_processes}, proc ID: {acc.process_index}")
21
+
22
+ report_devices("Initial state:", accelerator=acc)
23
+
24
+ t, pt = load_tokenizer(mdl_id) # AutoTokenizer.from_pretrained(mdl_id, token=hf_tok)
25
+ if is_dec_only_llm(t):
26
+ m = AutoModelForCausalLM.from_pretrained(mdl_id, token=hf_tok, torch_dtype=torch.bfloat16)
27
+ log("Decoder-only model")
28
+ else:
29
+ m = AutoModelForSeq2SeqLM.from_pretrained(mdl_id, token=hf_tok, torch_dtype=torch.bfloat16)
30
+ log("Encoder-decoder model")
31
+
32
+ opt = torch.optim.AdamW(m.parameters(), lr=1e-5)
33
+ lrs = get_scheduler("linear", optimizer=opt, num_warmup_steps=100, num_training_steps=1000)
34
+ opt, lrs, m = acc.prepare(opt, lrs, m)
35
+
36
+ report_devices("Models in VRAM:", accelerator=acc)
37
+ m.train()
38
+
39
+ ds = load_dataset("Helsinki-NLP/europarl", "en-et")
40
+ max_idx = len(ds['train'])
41
+
42
+ for batch_size in batch_sizes:
43
+ print("")
44
+
45
+ for _ in range(10):
46
+ inp_idx = random.randint(0, max_idx-batch_size)
47
+
48
+ raw_inp = [ds['train'][i]['translation']['et'] for i in range(inp_idx, inp_idx+batch_size)]
49
+
50
+ if is_dec_only_llm(t):
51
+ inp = tokenizeit((t, pt), raw_inp, ctxlen, is_target=False, is_llm=True)
52
+ else:
53
+ inp = tokenizeit((t, pt), raw_inp, ctxlen, is_target=False, is_llm=False)
54
+
55
+ inp['labels'] = inp['input_ids']
56
+ inp.to(m.device)
57
+
58
+ outputs = m(**inp)
59
+
60
+ loss = outputs.loss
61
+ report_devices(f"While training:", accelerator=acc)
62
+ log(f"Batches : {[inp[k].size() for k in 'input_ids labels attention_mask'.split(' ')]}")
63
+ log(f"Batch total: {sum([inp[k].size()[0] * inp[k].size()[1] for k in 'input_ids labels attention_mask'.split(' ')])}")
64
+
65
+ try:
66
+ if acc.is_main_process:
67
+ result = subprocess.run(['rocm-smi'], capture_output=True, text=True)
68
+ print(result.stdout)
69
+ except Exception:  # rocm-smi may be unavailable; ignore
70
+ pass
71
+
72
+ acc.backward(loss)
73
+ acc.wait_for_everyone()
74
+
75
+ report_devices(f"Models gradients in VRAM, batch size {batch_size}:", accelerator=acc)
76
+
77
+ print(f"Testing {mdl_id} with batch size {batch_size}: success!")
78
+
79
+
80
+
81
+ if __name__ == "__main__":
82
+ if len(sys.argv) > 1:
83
+ args = CmdlineArgs("Test the VRAM usage by a model with different batch sizes, comma-separated",
84
+ pos_arg_list=["mdl_id", "batch_sizes"],
85
+ kw_arg_dict={"ctxlen": 2048})
86
+
87
+ clean_bs = [int(bs) for bs in args.batch_sizes.split(",")]
88
+ mdl_id = args.mdl_id
89
+ ctxlen = args.ctxlen
90
+ else:
91
+ mdl_id = "meta-llama/Llama-3.2-1B"
92
+ clean_bs = [16, 32, 64]
93
+ ctxlen = 2048
94
+
95
+ acc = Accelerator()
96
+ try:
97
+ run_test(mdl_id, clean_bs, ctxlen, acc)
98
+ except Exception as e:
99
+ if acc.is_main_process:
100
+ raise e
kuidastaltsutadalaamat/legacy/tokops.py ADDED
@@ -0,0 +1,350 @@
1
+ #!/usr/bin/env python3
2
+ import os
3
+ import sentencepiece as spm
4
+ import json
5
+
6
+ from transformers import AutoTokenizer
7
+ from transformers.models.nllb import NllbTokenizer
8
+ from transformers.models.t5 import T5Tokenizer
9
+ from collections import defaultdict
10
+
11
+ from aux import log
12
+ from legacy.langconv import langs_to_madlad, langs_to_nllb, is_nllb, is_madlad, is_dec_only_llm
13
+ from modelops import hf_tok
14
+
15
+
16
+ def test_tok(tok, snt, lang):
17
+ tok.src_lang = lang
18
+ out = tok(text = snt)
19
+ print(out['input_ids'])
20
+ print(tok.tokenize(snt))
21
+ print(tok.convert_ids_to_tokens(out['input_ids']))
22
+ print("-")
23
+
24
+
25
+ def get_stupid_correction(mdl_id):
26
+ l_mdl_id = mdl_id.lower()
27
+
28
+ if "m2m" in l_mdl_id:
29
+ correction = 108
30
+ elif "nllb" in l_mdl_id:
31
+ correction = 2
32
+ else:
33
+ correction = 0
34
+
35
+ return correction
36
+
37
+
38
+ def tsv_to_json_vocab(location):
39
+ new_location = location + ".json"
40
+
41
+ with open(location, "r") as f, open(new_location, "w") as w:
42
+ idx_dict = { "<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3 }
43
+
44
+ for line in f:
45
+ tok, _ = line.strip().split("\t")
46
+ if tok not in idx_dict:
47
+ idx_dict[tok] = len(idx_dict)
48
+
49
+ json.dump(idx_dict, w)
50
+
51
+ return new_location
52
+
53
+
54
+ def get_unk_toks(tokenizer, corpus, verbose=False):
55
+ unk_id = tokenizer.unk_token_id
56
+ unk_toks = defaultdict(int)
57
+
58
+ all_toks = set()
59
+
60
+ total_count = 0
61
+ unk_count = 0
62
+
63
+ with open(corpus, "r", encoding='utf-8') as f:
64
+ for snt in f:
65
+ toks = tokenizer.tokenize(snt.strip())
66
+ ids = tokenizer.convert_tokens_to_ids(toks)
67
+
68
+ for t, i in zip(toks, ids):
69
+ if i == unk_id:
70
+ unk_toks[t] += 1
71
+ unk_count += 1
72
+ total_count += 1
73
+
74
+ all_toks.add(t)
75
+
76
+ if verbose:
77
+ print(f"Tokenizer vocab size: {tokenizer.vocab_size}, nr of actually used tokens: {len(all_toks)}")
78
+ print(f"Corpus token count: {total_count}, UNK token percentage: {100*unk_count/total_count:.2f}%")
79
+
80
+ return list(unk_toks)
81
+
82
+
83
+ def get_top_toks(tokenizer, corpus, num_top_toks):
84
+ freq_count = defaultdict(int)
85
+
86
+ with open(corpus, "r", encoding='utf-8') as f:
87
+ for snt in f:
88
+ toks = tokenizer.tokenize(snt.strip())
89
+
90
+ for t in toks:
91
+ freq_count[t] += 1
92
+
93
+ sorted_freq_count = sorted(freq_count.keys(), key=lambda x: -freq_count[x])
94
+
95
+ return sorted_freq_count[:num_top_toks]
96
+
97
+
98
+ def extend_tok_langs(tokenizer, lang_set_raw):
99
+ if is_nllb(tokenizer):
100
+ lang_set = langs_to_nllb(lang_set_raw)
101
+ elif is_madlad(tokenizer):
102
+ lang_set = langs_to_madlad(lang_set_raw)
103
+ elif is_dec_only_llm(tokenizer):
104
+ return
105
+ else:
106
+ raise NotImplementedError
107
+
108
+ if 'additional_special_tokens' in tokenizer.special_tokens_map:
109
+ orig_langs = tokenizer.special_tokens_map['additional_special_tokens']
110
+ orig_lang_set = set(orig_langs)
111
+
112
+ addable_langs = list(set(lang_set) - orig_lang_set)
113
+ else:
114
+ orig_langs = []
115
+ addable_langs = lang_set
116
+
117
+ tokenizer.add_special_tokens({'additional_special_tokens': orig_langs + addable_langs})
118
+
119
+
120
+ def wrap_tok_in_correct_class(location, base_model_id, lang_set):
121
+ l_base_mdl_id = base_model_id.lower()
122
+
123
+ if "nllb" in l_base_mdl_id:
124
+ nllb_lang_set = langs_to_nllb(lang_set)
125
+ return NllbTokenizer(location + ".model", additional_special_tokens=nllb_lang_set)
126
+
127
+ elif "madlad" in l_base_mdl_id or "t5" in l_base_mdl_id:
128
+ madlad_lang_set = langs_to_madlad(lang_set)
129
+ return T5Tokenizer(location + ".model", additional_special_tokens=madlad_lang_set)
130
+ else:
131
+ raise ValueError("Incompatible model type for tokenizer")
132
+
133
+
134
+ def remove_tmp_spm_files(location):
135
+ for tmp_file in (".vocab", ".model"):
136
+ os.remove(location + tmp_file)
137
+
138
+
139
+ def learn_spm_tokenizer(corpus, save_location, base_model_id, vocab_size, lang_set=None):
140
+ tmp_location = os.path.join(save_location, "sentencepiece.bpe.tmp")
141
+ os.makedirs(save_location, exist_ok=True)
142
+
143
+ spm.SentencePieceTrainer.train(input=corpus, model_prefix=tmp_location, vocab_size=vocab_size)
144
+
145
+ tok = wrap_tok_in_correct_class(tmp_location, base_model_id, lang_set)
146
+
147
+ remove_tmp_spm_files(tmp_location)
148
+
149
+ return tok
150
+
151
+
152
+ def do_new_tok(tokargs):
153
+ correction = get_stupid_correction(tokargs.mdl_id)
154
+ voc_size = tokargs.vocab_size - correction
155
+ location = tokargs.save_location
156
+
157
+ return learn_spm_tokenizer(tokargs.tok_train_file, location, base_model_id=tokargs.tok_mdl_id,
158
+ vocab_size=voc_size, lang_set=tokargs.new_langs)
159
+
160
+
161
+ def remove_known_toks(toks, tokenizer):
162
+ return [t for t in toks if t not in tokenizer.get_vocab()]
163
+
164
+
165
+ def _handle_new_tokenizer(args):
166
+ assert args.new_langs is not None, "lang_set must be provided"
167
+ assert args.tok_train_file is not None, "tok_train_file must be provided"
168
+ args.vocab_size = int(args.vocab_size)
169
+
170
+ log("Training new tokenizer")
171
+ tokenizer = do_new_tok(args)
172
+
173
+ return tokenizer
174
+
175
+
176
+ def get_postoken_filename(save_location):
177
+ return os.path.join(save_location, "postokens.json")
178
+
179
+
180
+ def save_postokens(added_tokens, location):
181
+ if added_tokens is not None:
182
+ os.makedirs(location, exist_ok=True)
183
+ with open(get_postoken_filename(location), "w") as f:
184
+ json.dump(added_tokens, f)
185
+
186
+
187
+ def _handle_adding_tokens(tokenizer, toks_to_add, args):
188
+ if len(toks_to_add) == 0:
189
+ return None
190
+
191
+ log(f"Adding tokens: {toks_to_add}")
192
+
193
+ base_idx = len(tokenizer)
194
+
195
+ added_tok_dict = { t: (base_idx + i) for i, t in enumerate(toks_to_add) }
196
+ added_tok_rev_dict = { int(i): t for t, i in added_tok_dict.items() }
197
+
198
+ comb_dict = { 'tok2idx': added_tok_dict, 'idx2tok': added_tok_rev_dict }
199
+
200
+ save_postokens(comb_dict, args.save_location)
201
+
202
+ return comb_dict
203
+
204
+
205
+ def _handle_existing_tokenizer(args):
206
+ log("Reusing existing tokenizer")
207
+ tokenizer, added_tokens = load_tokenizer(args.tok_mdl_id)
208
+
209
+ if args.new_langs is not None:
210
+ log("Extending existing tokenizer with languages")
211
+ extend_tok_langs(tokenizer, args.new_langs)
212
+
213
+ if args.merge_tokenizers or args.merge_tok_mdl_id:
214
+ """
215
+ assert args.tok_train_file is not None, "For merging tokenizers a text file must be provided" \
216
+ + " to find the top N tokens to merge"
217
+ assert args.merge_tokenizers is not None and args.merge_tok_mdl_id is not None, \
218
+ "Both merge_tokenizers and merge_tok_mdl_id must be provided"
219
+ """
220
+ raise NotImplementedError("Merging is currently not supported")
221
+
222
+ added_tok_count = 0
223
+
224
+ if args.tok_train_file:
225
+ if args.merge_tokenizers:
226
+ """
227
+ merge_tok_max = int(args.merge_tokenizers)
228
+ log(f"Extending existing tokenizer ({args.merge_tok_mdl_id}) with up to {merge_tok_max} top tokens" +
229
+ f" from another tokenizer and corpus ({args.tok_train_file})")
230
+ new_tok = AutoTokenizer.from_pretrained(args.merge_tok_mdl_id, token=hf_tok)
231
+ toks_to_maybe_add = get_top_toks(new_tok, args.tok_train_file, merge_tok_max)
232
+ """
233
+ raise NotImplementedError("Merging is currently not supported")
234
+
235
+ else:
236
+ log(f"Extending existing tokenizer with UNK tokens from corpus ({args.tok_train_file})")
237
+ toks_to_maybe_add = get_unk_toks(tokenizer, args.tok_train_file, verbose=True)
238
+
239
+ toks_to_add = remove_known_toks(toks_to_maybe_add, tokenizer)
240
+ added_tok_count = len(toks_to_add)
241
+ added_tokens = _handle_adding_tokens(tokenizer, toks_to_add, args)
242
+
243
+ return tokenizer, added_tok_count, added_tokens
244
+
245
+
246
+ def train_or_extend_tokenizer_and_upd_model(args, model):
247
+ if hasattr(args, "vocab_size") and args.vocab_size:
248
+ # train a new sentence-piece tokenizer
249
+ tokenizer = _handle_new_tokenizer(args)
250
+ added_tok_count = 0
251
+ added_dict = None
252
+ else:
253
+ # save the pre-trained model's tokenizer, possibly adding new languages and tokens
254
+ tokenizer, added_tok_count, added_dict = _handle_existing_tokenizer(args)
255
+
256
+ upd_amt = get_stupid_correction(args.mdl_id)
257
+ new_len = len(tokenizer) + added_tok_count
258
+
259
+ model.resize_token_embeddings(new_len + upd_amt)
260
+
261
+ return tokenizer, added_dict
262
+
263
+
264
+ def load_tokenizer(tok_mdl_id):
265
+ orig_tokenizer = AutoTokenizer.from_pretrained(tok_mdl_id, token=hf_tok)
266
+
267
+ postoken_file = get_postoken_filename(tok_mdl_id)
268
+ if os.path.exists(postoken_file):
269
+ with open(postoken_file, "r") as f:
270
+ postokens = json.load(f)
271
+ else:
272
+ postokens = None
273
+
274
+ return orig_tokenizer, postokens
275
+
276
+
277
+ def tokenize_batch(tokenizer, sntlist, maxlen=8000):
278
+ #tokenizer.pad_token = '<|reserved_special_token_0|>'
279
+ tokenizer.pad_token = tokenizer.eos_token
280
+ output = tokenizer(sntlist, return_tensors="pt", max_length=maxlen, truncation=True, add_special_tokens=True,
281
+ padding=True)
282
+ output["labels"] = output["input_ids"].detach().clone()
283
+ return output
284
+
285
+ """
286
+
287
+ def detokenizeit(toktup, tok_ids):
288
+ #return toktup[0].decode(tok_ids, skip_special_tokens=True)
289
+
290
+ toks = []
291
+
292
+ for tok_id_tensor in tok_ids:
293
+ tok_id = tok_id_tensor.item()
294
+ try:
295
+ if tok_id not in toktup[0].added_tokens_decoder:
296
+ toks.append(toktup[0].convert_ids_to_tokens(tok_id))
297
+ except IndexError:
298
+ toks.append(toktup[1]['idx2tok'][str(tok_id)])
299
+
300
+ result = "".join(toks).replace("▁", " ")[1:]
301
+
302
+ return result, toks
303
+
304
+
305
+ def detokenizemany(toktup, tok_mtx):
306
+ result = [detokenizeit(toktup, tok_ids)[0] for tok_ids in tok_mtx]
307
+
308
+ return result
309
+
310
+
311
+
312
+
313
+
314
+ def run_tokenizer_testing():
315
+ args = CmdlineArgs("Test a tokenizer: tokenize & de-tokenize some text and check if these match",
316
+ pos_arg_list=["tok_mdl_id", "txt_file"])
317
+
318
+ #tokenizer = AutoTokenizer.fromm_pretrained(args.tok_mdl_id, token=hf_tok) if os.path.exists()
319
+ toktup = load_tokenizer(args.tok_mdl_id)
320
+
321
+ success = 0
322
+ failure = 0
323
+
324
+ with open(args.txt_file, "r", encoding="utf-8") as f:
325
+ snts = f.read().split("\n")
326
+
327
+ toks = tokenizeit(toktup, snts, 1024, False)
328
+
329
+ for i, snt in enumerate(snts):
330
+ tok_ids = toks['input_ids'][i]
331
+
332
+ #detoks = toktup[0].decode(tok_ids, skip_special_tokens=True)
333
+ detoks, tok_strs = detokenizeit(toktup, tok_ids)
334
+
335
+ if detoks != snt:
336
+ failure += 1
337
+ #log(f"Tokens: {toktup[0].convert_ids_to_tokens(tok_ids)}")
338
+ log(f"Tokens: {tok_strs}")
339
+ log(f"Test failed:\n{snt} !=\n{detoks}")
340
+ else:
341
+ success += 1
342
+ i += 1
343
+
344
+ log(f"Test result: {success} successful / {failure} failed")
345
+
346
+
347
+ if __name__ == "__main__":
348
+ sys.argv = ['', 'models/nllbxt', 'data/tok-test.txt']
349
+ run_tokenizer_testing()
350
+ """
kuidastaltsutadalaamat/legacy/trainmodel.py ADDED
@@ -0,0 +1,96 @@
1
+ #!/usr/bin/env python3
2
+
3
+ import os
4
+ import torch
5
+
6
+ from transformers import AutoModelForSeq2SeqLM, AutoModelForCausalLM
7
+
8
+ from legacy.accel import SwitchingAccelerator
9
+ from accelerate import Accelerator
10
+ from data import MultilingualDatasetIterator
11
+ from aux import log, CmdlineArgs
12
+ from legacy.langconv import lang_set_maybe_smugri, is_dec_only_llm
13
+ from modelops import mdl_param_count, to_cpl_spec, hf_tok
14
+ from tokops import load_tokenizer
15
+
16
+
17
+ def freeze_model(model):
18
+ for n, p in model.named_parameters():
19
+ p.requires_grad = False
20
+
21
+
22
+ def load_hf_mdl_and_tok(mdl_id, tok_id=None, verbose=False):
23
+ if tok_id is None:
24
+ tok_id = mdl_id
25
+
26
+ tokenizer = load_tokenizer(tok_id) # AutoTokenizer.from_pretrained(tok_id, token=hf_tok)
27
+
28
+ if is_dec_only_llm(tokenizer[0]):
29
+ model = AutoModelForCausalLM.from_pretrained(mdl_id, token=hf_tok, torch_dtype=torch.bfloat16)
30
+ else:
31
+ model = AutoModelForSeq2SeqLM.from_pretrained(mdl_id, token=hf_tok, torch_dtype=torch.bfloat16)
32
+
33
+ if verbose:
34
+ mdl_size, _ = mdl_param_count(model)
35
+ log(f"Loaded {mdl_id} with {mdl_size} params, voc size {model.config.vocab_size}")
36
+
37
+ return model, tokenizer
38
+
39
+
40
+ def _cmdline_args():
41
+ description = """Train or tune models"""
42
+
43
+ pos_args = ["mdl_id", "save_location", "train_pretok_file", "langs"]
44
+ pos_types = [str, str, str, lang_set_maybe_smugri]
45
+
46
+ kw_args = { "anchor_mdl_id": None, "anchor_langs": None, "continue_training": False,
47
+ "save_steps": 100000, "lr": 1.5e-5, "nr_snts_in_batch": 0, "nr_words_in_batch": 0,
48
+ "log_steps": 100, "epochs": 4 }
49
+
50
+ #post-process the arguments
51
+ args = CmdlineArgs(description, pos_arg_list=pos_args, pos_arg_types=pos_types, kw_arg_dict=kw_args)
52
+
53
+ if args.anchor_langs is not None:
54
+ args.anchor_langs = lang_set_maybe_smugri(args.anchor_langs)
55
+
56
+ if (args.nr_snts_in_batch > 0) == (args.nr_words_in_batch > 0):
57
+ raise Exception(f"Specify the batch size either in words or in sentences.")
58
+
59
+ # if the directory args.save_location already exists, raise an exception:
60
+ if not args.continue_training and os.path.exists(args.save_location):
61
+ raise Exception(f"Save location '{args.save_location}' already exists, don't want to overwrite.")
62
+
63
+ return args
64
+
65
+
66
+ def yes_i_called_this_function_do_main():
67
+ args = _cmdline_args()
68
+ tmp_acc = Accelerator()
69
+
70
+ log(f"Num proc: {tmp_acc.num_processes}, proc ID: {tmp_acc.process_index}")
71
+
72
+ log("loading coupled model and tokenizer", accelerator=tmp_acc)
73
+ main_model, main_tokenizer = load_hf_mdl_and_tok(args.mdl_id, verbose=True)
74
+
75
+ coupling_specs = to_cpl_spec(args.langs, main_model, main_tokenizer[0], main_tokenizer[1], args.save_location)
76
+
77
+ if args.anchor_mdl_id:
78
+ log("loading anchor model and tokenizer", accelerator=tmp_acc)
79
+ anchor_model, anchor_tokenizer = load_hf_mdl_and_tok(args.anchor_mdl_id, verbose=True)
80
+ freeze_model(anchor_model)
81
+
82
+ coupling_specs += to_cpl_spec(args.anchor_langs, anchor_model, anchor_tokenizer[0], anchor_tokenizer[1], args.anchor_mdl_id)
83
+
84
+ train_set = MultilingualDatasetIterator(args.train_pretok_file)
85
+
86
+ acc_trainer = SwitchingAccelerator(coupling_specs, train_set, args)
87
+
88
+ upd_model, loss_list = acc_trainer.train()
89
+
90
+ #save_all_models(args.save_location, upd_model, main_tokenizer, coupling_specs, loss_list, trainer=acc_trainer.accelerator)
91
+
92
+
93
+ if __name__ == "__main__":
94
+ #sys.argv = ". models/smol models/smol_next data/smugri4a-dev.json-tokcache/thiscache.json smugri log_steps=1 lr=1e-5".split()
95
+ #sys.argv = ". models/llama3.2-1b models/llama-tuned data/smugri4a-dev.json-tokcache/llama.json smugri".split()
96
+ yes_i_called_this_function_do_main()
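
freeze_model above simply switches off gradients for every parameter of the anchor model, so only the main model's weights are updated during coupled training. A tiny self-contained sketch of the same pattern, with a Linear layer standing in for the anchor model:

import torch

anchor = torch.nn.Linear(8, 8)          # stand-in for the anchor model
for _, p in anchor.named_parameters():  # same loop as freeze_model
    p.requires_grad = False

assert all(not p.requires_grad for p in anchor.parameters())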
kuidastaltsutadalaamat/legacy/translate_backup.py ADDED
@@ -0,0 +1,309 @@
1
+ #!/usr/bin/env python3
2
+
3
+ """
4
+
5
+ import sys
6
+ import requests
7
+ import re
8
+ import torch
9
+
10
+ from aux import CmdlineArgs, log
11
+ #from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForCausalLM
12
+ from trainllm import load_hf_tokenizer, load_hf_model
13
+ from data import do_list_in_batches
14
+ from modelops import hf_tok, is_gen_ai
15
+ from collections import defaultdict
16
+ from langconv import is_nllb, is_madlad, any_to_mdl_type, get_mdl_type, any_to_neurotolge, is_dec_only_llm
17
+ from tokops import load_tokenizer, tokenizeit, detokenizemany
18
+
19
+
20
+ def prepare_for_translation(provided_inputs, toktup, input_language, output_language=None, device=None):
21
+ if is_nllb(toktup[0]):
22
+ toktup[0].src_lang = input_language
23
+ inputs_to_process = provided_inputs
24
+ elif is_madlad(toktup[0]):
25
+ madlad_tgt_lang = output_language
26
+ inputs_to_process = [f"{madlad_tgt_lang} {inp}" for inp in provided_inputs]
27
+ else:
28
+ raise NotImplementedError("Model type not supported")
29
+
30
+ prepared_inputs = tokenizeit(toktup, inputs_to_process, 1024, False) #tokenizer(inputs_to_process, return_tensors="pt", padding=True, truncation=True, max_length=512)
31
+
32
+ if device is not None:
33
+ prepared_inputs.to(device)
34
+
35
+ frc_bos = toktup[0].get_lang_id(output_language) if output_language is not None else None
36
+
37
+ return prepared_inputs, frc_bos
38
+
39
+
40
+ def finalize_translation(outputs, toktup):
41
+ result = detokenizemany(toktup, outputs) # tokenizer.batch_decode(outputs, skip_special_tokens=True)
42
+
43
+ return result
44
+
45
+
46
+ def loadmodel(mdlname="facebook/m2m100_418M", accelerator=None):
47
+ cl = AutoModelForCausalLM if is_gen_ai(mdlname) else AutoModelForSeq2SeqLM
48
+
49
+ if accelerator is not None:
50
+ model = cl.from_pretrained(mdlname, token=hf_tok, torch_dtype=torch.bfloat16)
51
+ model = accelerator.prepare(model)
52
+ else:
53
+ model = cl.from_pretrained(mdlname, token=hf_tok, torch_dtype=torch.bfloat16, device_map="auto")
54
+
55
+ return model
56
+
57
+
58
+ def encode(model, input_batch):
59
+ model = model.module if hasattr(model, "module") else model
60
+
61
+ if is_nllb(model):
62
+ enc = model.model.encoder
63
+ elif is_madlad(model):
64
+ enc = model.base_model.encoder
65
+ else:
66
+ raise NotImplementedError(f"Model {model} is not supported yet.")
67
+
68
+ inputs_without_labels = { k: input_batch[k] for k in input_batch if k != "labels" }
69
+
70
+ return enc(**inputs_without_labels)
71
+
72
+
73
+ def coupled_encode(coupling_specs, lang_to_bin, input_lang, input_texts, debug=False):
74
+
75
+ mdl_type = get_mdl_type(coupling_specs[0].model)
76
+ conv_input_lang = any_to_mdl_type(mdl_type, input_lang)
77
+
78
+ this = coupling_specs[lang_to_bin[conv_input_lang]]
79
+
80
+ # 0. input text --> input token IDs
81
+ these_inputs, _ = prepare_for_translation(input_texts, (this.tokenizer, this.postokenizer), conv_input_lang, device=this.model.device)
82
+ attention_mask = these_inputs["attention_mask"]
83
+ if debug:
84
+ for iii in range(len(input_texts)):
85
+ toklist = []
86
+ for tok_idx in these_inputs['input_ids'][iii]:
87
+ try:
88
+ tok = this.tokenizer.convert_ids_to_tokens([tok_idx])[0]
89
+ except IndexError:
90
+ tok = this.postokenizer['idx2tok'][str(tok_idx.item())]
91
+ toklist.append(tok)
92
+ print(these_inputs['input_ids'][iii])
93
+ print(toklist)
94
+
95
+ # 1. input token IDs --> encoder vectors
96
+ #embeddings = this.model.model.encoder(**these_inputs)
97
+ return encode(this.model, these_inputs), attention_mask
98
+
99
+
100
+ def postproc_llm_output(raw_outputs, tok):
101
+ eos_id = tok.convert_tokens_to_ids(tok.eos_token)
102
+
103
+ for i, _ in enumerate(raw_outputs):
104
+ repl = None
105
+ for ii, t in enumerate(raw_outputs[i]):
106
+ if t.item() == eos_id:
107
+ repl = eos_id
108
+ if repl is not None:
109
+ raw_outputs[i][ii] = repl
110
+
111
+ return raw_outputs
112
+
113
+
114
+ def llm_generate(coupling_specs, input_language, output_language, input_texts, debug=False):
115
+ mdl_type = get_mdl_type(coupling_specs[0].model)
116
+ conv_input_lang = any_to_mdl_type(mdl_type, input_language)
117
+ conv_output_lang = any_to_mdl_type(mdl_type, output_language)
118
+
119
+ tokenizer = coupling_specs[0].tokenizer
120
+
121
+ prep_texts = [make_gen_text(conv_input_lang, conv_output_lang, input_txt, None) for input_txt in input_texts]
122
+
123
+ tokenized = tokenizeit((tokenizer, None), prep_texts, 1024, is_target=False, is_llm=True)
124
+
125
+ obj = coupling_specs[0].model
126
+ obj = obj.module if hasattr(obj, "module") else obj
127
+
128
+ tokenized['input_ids'] = tokenized['input_ids'].to(obj.device)
129
+ tokenized['attention_mask'] = tokenized['attention_mask'].to(obj.device)
130
+
131
+ raw_outputs = obj.generate(**tokenized, max_length)
132
+
133
+ # 3. output token IDs --> output text
134
+ pre_result = tokenizer.batch_decode(postproc_llm_output(raw_outputs, tokenizer), skip_special_tokens=True)
135
+
136
+ result = [raw_out[len(prep_texts[i]):].split("\n")[0] for i, raw_out in enumerate(pre_result)]
137
+ """
138
+ # for i, raw_out in enumerate(pre_result):
139
+ # print("====")
140
+ # print(i, raw_out)
141
+ # print("%%%%")
142
+ # print(raw_out[len(prep_texts[i])-3:])
143
+ # print("----")
144
+ """
145
+
146
+ return result
147
+
148
+ def coupled_generate(coupling_specs, lang_to_bin, output_lang, encoder_embeddings, att_mask, debug=False):
149
+ mdl_type = get_mdl_type(coupling_specs[0].model)
150
+ conv_output_lang = any_to_mdl_type(mdl_type, output_lang)
151
+
152
+ dec_idx = lang_to_bin[conv_output_lang]
153
+
154
+ tokenizer = coupling_specs[dec_idx].tokenizer
155
+
156
+ # 2. encoder vectors --> output token IDs
157
+ frc_bos = tokenizer.convert_tokens_to_ids(conv_output_lang)
158
+ obj = coupling_specs[dec_idx].model
159
+ obj = obj.module if hasattr(obj, "module") else obj
160
+
161
+ raw_outputs = obj.generate(forced_bos_token_id=frc_bos, encoder_outputs=encoder_embeddings, attention_mask=att_mask)
162
+ if debug:
163
+ for rwout in raw_outputs:
164
+ print(rwout)
165
+ print(tokenizer.convert_ids_to_tokens(rwout))
166
+
167
+ # 3. output token IDs --> output text
168
+ result = finalize_translation(raw_outputs, (tokenizer, coupling_specs[dec_idx].postokenizer))
169
+
170
+ return result
171
+
172
+
173
+ def make_uniq(lang_to_bin):
174
+ result = defaultdict(lambda: 0)
175
+
176
+ for lang in lang_to_bin:
177
+ bin_set = lang_to_bin[lang]
178
+ result[lang] = 0 if 0 in bin_set else list(bin_set)[0]
179
+
180
+ return result
181
+
182
+
183
+ def translate_with_neurotolge(translation_input: str, src_lang: str, tgt_lang: str) -> dict:
184
+ url = "https://api.tartunlp.ai/translation/v2"
185
+
186
+ payload = {
187
+ "text": translation_input,
188
+ "src": any_to_neurotolge(src_lang),
189
+ "tgt": any_to_neurotolge(tgt_lang),
190
+ "domain": "general",
191
+ "application": "benchmarking"
192
+ }
193
+
194
+ error = None
195
+
196
+ for i in range(5):
197
+ try:
198
+ response = requests.post(url, json=payload)
199
+ response.raise_for_status() # Raise an error for bad status codes
200
+ return response.json()['result']
201
+ except requests.exceptions.RequestException as e:
202
+ error = {"error": str(e)}
203
+
204
+ return error
205
+
206
+
207
+ def remove_dia(snt):
208
+ if ">" in snt:
209
+ return re.sub(r'^<[^>]+> ', '', snt)
210
+ else:
211
+ return snt
212
+
213
+
214
+ def neurotolge_in_batches(input_texts, src_lang, tgt_lang):
215
+ neurotolge_langs = {'eng', 'est', 'ger', 'lit', 'lav', 'lvs', 'fin', 'rus', 'ukr', 'kca', 'koi', 'kpv', 'krl', 'lud', 'mdf', 'mhr', 'mns', 'mrj', 'myv', 'olo', 'udm', 'vep', 'liv', 'vro', 'sma', 'sme', 'smn', 'sms', 'smj', 'nor', 'hun'}
216
+
217
+ if src_lang in neurotolge_langs and tgt_lang in neurotolge_langs:
218
+ all_outputs = list()
219
+
220
+ for inp_batch in do_list_in_batches(input_texts, 8):
221
+ inp_batch_no_dia = [remove_dia(s) for s in inp_batch]
222
+ these_outputs = translate_with_neurotolge(inp_batch_no_dia, src_lang, tgt_lang)
223
+ if len(these_outputs) != len(inp_batch_no_dia):
224
+ raise Exception(f"Something went wrong.: {src_lang}/{tgt_lang}/{these_outputs}")
225
+ all_outputs += these_outputs
226
+ log(f"Translated {len(all_outputs)}/{len(input_texts)} sentences")
227
+
228
+ return all_outputs
229
+ else:
230
+ return None
231
+
232
+
233
+ def coupled_translate(coupling_specs, input_texts, input_language, output_language, debug=False):
234
+ lang_to_bin = make_uniq(lang_bin_mapping(coupling_specs))
235
+
236
+ all_outputs = list()
237
+
238
+ for inp_batch in do_list_in_batches(input_texts, 32):
239
+ if is_dec_only_llm(coupling_specs[0].tokenizer):
240
+ these_outputs = llm_generate(coupling_specs, input_language, output_language, input_texts, debug=debug)
241
+ else:
242
+ encoder_embeddings, att_mask = coupled_encode(coupling_specs, lang_to_bin, input_language, inp_batch, debug=debug)
243
+ these_outputs = coupled_generate(coupling_specs, lang_to_bin, output_language, encoder_embeddings, att_mask, debug=debug)
244
+
245
+ all_outputs += these_outputs
246
+
247
+ return all_outputs
248
+
249
+
250
+ def load_and_init_module_config(model_id, accelerator=None):
251
+ config = load_module_config(model_id)
252
+
253
+ coupling_specs = list()
254
+
255
+ main_model = None
256
+
257
+ for i, entry in enumerate(config):
258
+ lang_set = entry["lang_set"]
259
+ model_id = entry["model_id"] if i > 0 else model_id
260
+
261
+ log(f"Loading model and tokenizer from '{model_id}'")
262
+ model = loadmodel(model_id, accelerator)
263
+ tokenizer, postok = load_tokenizer(model_id)
264
+
265
+ if i == 0:
266
+ main_model = model
267
+
268
+ #(langs, model, tokenizer, location):
269
+ coupling_specs += to_cpl_spec(lang_set, model, tokenizer, postok, model_id)
270
+
271
+ return main_model, coupling_specs
272
+
273
+
274
+ def _cmdline_args(inputs):
275
+ # description = ""Translate STDIN text with a translation model""
276
+
277
+ pos_args = ["mdl_id", "from_lang", "to_lang"]
278
+
279
+ #post-process the arguments
280
+ args = CmdlineArgs(description, pos_args, input_args=inputs, kw_arg_dict={"debug": False})
281
+
282
+ log(f"Launched as {args}")
283
+
284
+ return args
285
+
286
+
287
+ def and_i_called_this_function_do_main_too(iv):
288
+ args = _cmdline_args(iv)
289
+
290
+ inputs = [line.strip() for line in sys.stdin]
291
+ # inputs = ["See on ikka tore uudis.", "Ma ikka katsetaks ka täpitähtedega tõlkimist.", "Mis tähed on täpitähed?"]
292
+
293
+ log(f"Inputs: {inputs}")
294
+
295
+ main_model, module_config = load_and_init_module_config(args.mdl_id)
296
+ log("Model loaded, starting to translate")
297
+ outputs = coupled_translate(module_config, inputs, args.from_lang, args.to_lang, debug=args.debug)
298
+
299
+ print("\n".join(outputs))
300
+
301
+ log("Done...")
302
+
303
+
304
+ if __name__ == "__main__":
305
+ input_values = sys.argv[1:] if len(sys.argv) > 1 \
306
+ else ["models/nllb", "et", "en"]
307
+
308
+ and_i_called_this_function_do_main_too(input_values)
309
+ """
kuidastaltsutadalaamat/metrics.py ADDED
@@ -0,0 +1,79 @@
1
+ #!/usr/bin/env python3
2
+
3
+ from data import read_input
4
+ from aux import log
5
+
6
+ import sys
7
+ from collections import defaultdict
8
+ from evaluate import load as load_metric
9
+
10
+ SMUGRI_RES = {
11
+ 'high': set("Estonian,English,Russian,Finnish,Hungarian,Latvian,German,Swedish,Norwegian,French".split(",")),
12
+ 'mid': set("Komi,Komi-Zyrian,Northern Sami,Meadow Mari".split(",")),
13
+ 'low': set("Udmurt,Proper Karelian,Southern Sami,Livvi,Veps,Moksha,Erzya,Lule Sami,Võro,Hill Mari,"
14
+ "Komi-Permyak,Inari Sami".split(",")),
15
+ 'xlow': set("Ludian,Livonian,Izhorian,Votic,Shur Khanty,Skolt Sami,Meänkieli,"
16
+ "Sred Khanty,Surgut Khanty,Priur Khanty,Vakh Khanty,Unk Khanty,"
17
+ "Pite Sami,Mansi,Kazym Khanty,Kven,Ume Sami,Kildin Sami".split(","))
18
+ }
19
+
20
+
21
+ def _gen_lang(lang):
22
+ return lang.split(",")[0]
23
+
24
+
25
+ def _hi_or_lo_lang(lang):
26
+ gen_lang = _gen_lang(lang)
27
+
28
+ for k, v in SMUGRI_RES.items():
29
+ if gen_lang in v:
30
+ return k
31
+
32
+ log(f"Unrecognized language: {lang} / {gen_lang}")
33
+ return '?'
34
+
35
+
36
+ def _collect_lp_pairs(json_inputs, str_outputs):
37
+ sets_by_lp = defaultdict(list)
38
+
39
+ for i, o in zip(json_inputs, str_outputs):
40
+ ref = i["tgt_segm"]
41
+ hyp = o
42
+ det_lp = 'detailed: ' + i["src_lang"] + " -> " + i["tgt_lang"]
43
+ gen_lp = 'general: ' + _gen_lang(i["src_lang"]) + " -> " + _gen_lang(i["tgt_lang"])
44
+ hilo_lp = 'classes: ' + _hi_or_lo_lang(i["src_lang"]) + " -> " + _hi_or_lo_lang(i["tgt_lang"])
45
+
46
+ sets_by_lp[det_lp].append((hyp, ref))
47
+ sets_by_lp[gen_lp].append((hyp, ref))
48
+ sets_by_lp[hilo_lp].append((hyp, ref))
49
+
50
+ return sets_by_lp
51
+
52
+
53
+ def compute_metrics(json_inputs, str_outputs):
54
+ sets_by_lp = _collect_lp_pairs(json_inputs, str_outputs)
55
+
56
+ metric = load_metric("chrf")
57
+
58
+ result = []
59
+
60
+ for lp in sets_by_lp:
61
+ preds, outputs = zip(*sets_by_lp[lp])
62
+ metric_value = metric.compute(predictions=preds, references=outputs)
63
+
64
+ result.append((lp, metric_value, len(preds)))
65
+
66
+ return result
67
+
68
+
69
+ def avoid_global_scope():
70
+ json_inputs = read_input(sys.argv[1], "json")
71
+ str_outputs = read_input(sys.argv[2], "json")
72
+
73
+ lp_metric_dict = compute_metrics(json_inputs, str_outputs)
74
+
75
+ for lp, metric, size in lp_metric_dict:
76
+ print(f"{lp}: {metric['score']:.2f} ({size})")
77
+
78
+ if __name__ == "__main__":
79
+ avoid_global_scope()
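
A minimal usage sketch of compute_metrics with toy data, mainly to show the expected input shapes: json_inputs carries src_lang, tgt_lang and the reference tgt_segm, while str_outputs is the parallel list of system hypotheses (the import path is an assumption):

from kuidastaltsutadalaamat.metrics import compute_metrics  # import path assumed

json_inputs = [
    {"src_lang": "Estonian", "tgt_lang": "Võro", "tgt_segm": "reference one"},
    {"src_lang": "Estonian", "tgt_lang": "Võro", "tgt_segm": "reference two"},
]
str_outputs = ["hypothesis one", "hypothesis two"]

# Prints chrF per detailed, general and resource-class language pair, as in avoid_global_scope().
for lp, metric_value, size in compute_metrics(json_inputs, str_outputs):
    print(f"{lp}: {metric_value['score']:.2f} ({size})")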
kuidastaltsutadalaamat/promptops.py ADDED
@@ -0,0 +1,70 @@
1
+
2
+ # first, keyword identifiers for selecting prompt templates in scripts:
3
+
4
+ PF_RAW = "raw"
5
+ PF_RAWLINES = "rawlines"
6
+ PF_SMUGRI_MT = "smugri_mt"
7
+ PF_SMUGRI_LID = "smugri_lid"
8
+ PF_ALPACA = "alpaca"
9
+
10
+ # now the prompt templates themselves, SMUGRI LID / MT template:
11
+
12
+ SMUGRI_INF_PROMPT_LID = "<|reserved_special_token_12|>{src_segm}<|reserved_special_token_13|>"
13
+
14
+ _SMUGRI_INF_PROMPT_TMPMID = "<|reserved_special_token_14|>{task} to {tgt_lang}<|reserved_special_token_15|>"
15
+ SMUGRI_INF_PROMPT_MT = SMUGRI_INF_PROMPT_LID + "{src_lang}" + _SMUGRI_INF_PROMPT_TMPMID
16
+
17
+ _SMUGRI_TRAIN_PROMPT_PREF = SMUGRI_INF_PROMPT_LID + "{src_lang}"
18
+ _SMUGRI_TRAIN_PROMPT_MID = _SMUGRI_INF_PROMPT_TMPMID + "{tgt_segm}"
19
+ _SMUGRI_TRAIN_PROMPT_SUF = "<|reserved_special_token_16|><|end_of_text|>"
20
+
21
+ SMUGRI_PROMPT_TRAIN_PARA = _SMUGRI_TRAIN_PROMPT_PREF + _SMUGRI_TRAIN_PROMPT_MID + _SMUGRI_TRAIN_PROMPT_SUF
22
+ SMUGRI_PROMPT_TRAIN_MONO = _SMUGRI_TRAIN_PROMPT_PREF + _SMUGRI_TRAIN_PROMPT_SUF
23
+
24
+ # Alpaca instructions prompt template:
25
+
26
+ ALPACA_PROMPT_INF = ("Below is an instruction that describes a task, paired with an input that provides further context. "
27
+ "Write a response that appropriately completes the request.\n\n"
28
+ "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n")
29
+
30
+ ALPACA_PROMPT_TRAIN = (ALPACA_PROMPT_INF + "{output}")
31
+
32
+
33
+ def prep_prompt(data, prompt_format, inference=False):
34
+ if prompt_format in {PF_RAW, PF_RAWLINES}:
35
+ # data is a string, return it
36
+ return data
37
+
38
+ elif prompt_format in {PF_SMUGRI_MT, PF_SMUGRI_LID}:
39
+ # data has src_segm, src_lang, tgt_lang, etc
40
+ return _prep_ljmf_entry(data, prompt_format, inference)
41
+
42
+ elif prompt_format == PF_ALPACA:
43
+ # data has instruction and input in it
44
+ return _prep_alpaca_entry(data, inference)
45
+
46
+ else:
47
+ raise NotImplementedError(f"Prompt format {prompt_format} is not implemented.")
48
+
49
+
50
+ def _prep_alpaca_entry(entry, inference=False):
51
+ fmt = ALPACA_PROMPT_INF if inference else ALPACA_PROMPT_TRAIN
52
+ prompt = fmt.format(**entry)
53
+ return prompt
54
+
55
+
56
+ def _prep_ljmf_entry(entry, fmt, inference=False):
57
+ if inference:
58
+ if fmt == PF_SMUGRI_MT:
59
+ prompt = SMUGRI_INF_PROMPT_MT.format(**entry)
60
+ elif fmt == PF_SMUGRI_LID:
61
+ prompt = SMUGRI_INF_PROMPT_LID.format(**entry)
62
+ else:
63
+ raise NotImplementedError(f"Prompt format {fmt} is not implemented.")
64
+ else:
65
+ if entry['task'] in {'translate', 'approx-translate'} and entry['tgt_segm'] and entry['tgt_lang']:
66
+ prompt = SMUGRI_PROMPT_TRAIN_PARA.format(**entry)
67
+ else:
68
+ prompt = SMUGRI_PROMPT_TRAIN_MONO.format(**entry)
69
+
70
+ return prompt
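
A short sketch of how the templates above compose, assuming an entry with the fields referenced by the format strings (values are made up; the import path is an assumption):

from kuidastaltsutadalaamat.promptops import prep_prompt, PF_SMUGRI_MT  # import path assumed

entry = {
    "task": "translate",
    "src_lang": "Estonian",
    "tgt_lang": "Livonian",
    "src_segm": "source sentence here",
    "tgt_segm": "target sentence here",
}

# Training: LID block + source language + task/target instruction + target segment + end-of-text suffix.
train_prompt = prep_prompt(entry, PF_SMUGRI_MT, inference=False)

# Inference: the prompt stops right after "... to {tgt_lang}", leaving the target segment to be generated.
inf_prompt = prep_prompt(entry, PF_SMUGRI_MT, inference=True)

print(train_prompt)
print(inf_prompt)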
kuidastaltsutadalaamat/trainllm.py ADDED
@@ -0,0 +1,252 @@
1
+ #!/usr/bin/env python3
2
+
3
+ from .promptops import PF_SMUGRI_MT
4
+ from .aux import log, CmdlineArgs
5
+ from .data import load_training_data
6
+
7
+ import json
8
+ import os, socket, torch
9
+
10
+ from datetime import datetime
11
+
12
+ from accelerate import Accelerator
13
+ from transformers import (
14
+ AutoTokenizer,
15
+ AutoModelForCausalLM,
16
+ TrainingArguments,
17
+ Trainer,
18
+ DataCollatorForLanguageModeling,
19
+ logging,
20
+ TrainerCallback
21
+ )
22
+
23
+ """
24
+ 1/3 This simply reads in command-line arguments
25
+ """
26
+
27
+ def _cmdline_args():
28
+ description = """Train or tune decoder models"""
29
+
30
+ result = CmdlineArgs(description,
31
+ pos_arg_list=["mdl_id", "save_location", "train_file"],
32
+ pos_arg_types=[str, str, str],
33
+ kw_arg_dict={ "continue_training": False, "save_steps": 100, "lr": 1.5e-5,
34
+ "batch_size": 1024, "nr_sents_per_gpu": 4, "log_steps": 1, "epochs": 4,
35
+ "max_length": 2000, "prompt_format": PF_SMUGRI_MT,
36
+ "deepspeed": "none"})
37
+
38
+ # if the directory args.save_location already exists, raise an exception:
39
+ if not result.continue_training and os.path.exists(result.save_location):
40
+ raise Exception(f"Save location '{result.save_location}' already exists, don't want to overwrite.")
41
+
42
+ if result.nr_sents_per_gpu == 0:
43
+ result.nr_sents_per_gpu = result.batch_size
44
+
45
+ if result.deepspeed == "none":
46
+ result.deepspeed = None
47
+
48
+ return result
49
+
50
+ """
51
+ 2/3 This here is used in training in order to report timing and predictions
52
+ """
53
+
54
+ class StepTimerCallback(TrainerCallback):
55
+ def __init__(self):
56
+ self._step_start = None
57
+ self.lengths = []
58
+ self.abs_start = datetime.now()
59
+
60
+ self.actual_first_step = None
61
+
62
+ self.zero = self.abs_start - self.abs_start
63
+
64
+ def on_step_begin(self, args, state, control, **kwargs):
65
+ # called right before each training step
66
+ self._step_start = datetime.now()
67
+
68
+ def on_step_end(self, args, state, control, **kwargs):
69
+ if self.actual_first_step is None:
70
+ self.actual_first_step = state.global_step - 1
71
+
72
+ # called right after each training step
73
+ now = datetime.now()
74
+ elapsed = now - self._step_start
75
+ tot_elapsed = now - self.abs_start
76
+ self.lengths.append(elapsed)
77
+
78
+ avg = sum(self.lengths, start=self.zero) / len(self.lengths)
79
+
80
+ remaining = state.max_steps - state.global_step
81
+ prediction = (tot_elapsed/(state.global_step - self.actual_first_step)) * remaining
82
+
83
+ # you can use logging.get_logger(...) instead of print
84
+ print(f"[step {state.global_step}/{state.max_steps}] took {elapsed}, avg {avg}; approx {prediction} remaining")
85
+
86
+ """
87
+ 3/3 Finally, the filling of TrainingArguments and the launching of Trainer:
88
+ """
89
+
90
+ def get_training_args(cmdline_args, acc):
91
+ world_size = acc.num_processes
92
+
93
+ assert cmdline_args.batch_size % (cmdline_args.nr_sents_per_gpu * world_size) == 0, \
94
+ "Batch size must be divisible by the number of GPUs and nr of sents per GPU"
95
+
96
+ accum_steps = cmdline_args.batch_size // (cmdline_args.nr_sents_per_gpu * world_size)
97
+
98
+ log(f"Nr of processes (GPUs): {world_size}, per-device batch: {cmdline_args.nr_sents_per_gpu}, accum. steps: {accum_steps}")
99
+
100
+ if cmdline_args.deepspeed is not None:
101
+ with open(cmdline_args.deepspeed, "r") as f:
102
+ dpspd = json.load(f)
103
+
104
+ #correct the dictionary with current values, so that we wouldn't need to update the JSON every time
105
+ dpspd['train_batch_size'] = cmdline_args.batch_size
106
+ dpspd['train_micro_batch_size_per_gpu'] = cmdline_args.nr_sents_per_gpu
107
+ dpspd['gradient_accumulation_steps'] = accum_steps
108
+
109
+ log(f"Using deepspeed with config {dpspd}")
110
+ else:
111
+ dpspd = None
112
+
113
+ tr_args = TrainingArguments(
114
+ output_dir=cmdline_args.save_location,
115
+ per_device_train_batch_size=cmdline_args.nr_sents_per_gpu,
116
+ gradient_accumulation_steps=accum_steps,
117
+ num_train_epochs=cmdline_args.epochs,
118
+ save_steps=cmdline_args.save_steps,
119
+ save_total_limit=10,
120
+ logging_steps=cmdline_args.log_steps,
121
+ deepspeed=dpspd,
122
+ learning_rate=cmdline_args.lr,
123
+ save_strategy="epoch",
124
+ disable_tqdm=True,
125
+ report_to="none",
126
+ # Optional but often helpful on LUMI/ROCm if you enable it in your args:
127
+ bf16=True,
128
+ ddp_find_unused_parameters=False,
129
+ #dataloader_num_workers=1,
130
+ #group_by_length=True,
131
+ log_level="debug",
132
+ #gradient_checkpointing=True,
133
+ #dataloader_persistent_workers=True
134
+ )
135
+
136
+ return tr_args
137
+
138
+
139
+ def load_model(mdl_id, device, accelerator=None, attention="flash_attention_2"):
140
+ log(f"Load model", accelerator=accelerator)
141
+ model = AutoModelForCausalLM.from_pretrained(mdl_id,
142
+ low_cpu_mem_usage=False,
143
+ torch_dtype=torch.bfloat16,
144
+ attn_implementation=attention)
145
+
146
+ model.config.use_cache = False
147
+ model = model.to(device)
148
+ log(f"Model loaded on device: {model.device}.", accelerator=accelerator)
149
+
150
+ return model
151
+
152
+
153
+ def load_tokenizer(mdl_id, accelerator=None):
154
+ log(f"Load tokenizer", accelerator=accelerator)
155
+ tokenizer = AutoTokenizer.from_pretrained(mdl_id)
156
+
157
+ # LLaMA 3.x: no pad token by default
158
+ if tokenizer.pad_token is None:
159
+ tokenizer.pad_token = "<|reserved_special_token_100|>"
160
+
161
+ return tokenizer
162
+
163
+
164
+ def simple_train():
165
+ cmd_args = _cmdline_args()
166
+ acc = Accelerator()
167
+ device = acc.device # it seems that the accelerator loses/changes this info later
168
+
169
+ training_args = get_training_args(cmd_args, acc)
170
+
171
+ tokenizer = load_tokenizer(cmd_args.mdl_id, acc)
172
+ model = load_model(cmd_args.mdl_id, device, acc)
173
+
174
+ if getattr(model.config, "pad_token_id", None) is None:
175
+ model.config.pad_token_id = tokenizer.pad_token_id
176
+
177
+ log(f"Load data", accelerator=acc)
178
+ tokenized_train_data = load_training_data(cmd_args.train_file, tokenizer, cmd_args)
179
+
180
+ data_collator = DataCollatorForLanguageModeling(
181
+ tokenizer=tokenizer,
182
+ mlm=False,
183
+ pad_to_multiple_of=8, # GPT says this helps performance
184
+ )
185
+
186
+ log(f"Preparing to train", accelerator=acc)
187
+
188
+ clbks = [StepTimerCallback] if acc.is_main_process else []
189
+
190
+ trainer = Trainer(
191
+ model=model,
192
+ args=training_args,
193
+ train_dataset=tokenized_train_data,
194
+ tokenizer=tokenizer,
195
+ data_collator=data_collator,
196
+ callbacks=clbks,
197
+ )
198
+
199
+ logging.set_verbosity_debug()
200
+
201
+ log(f"Starting training", accelerator=acc)
202
+ trainer.train(resume_from_checkpoint=cmd_args.continue_training)
203
+
204
+ log(f"Done, saving model", accelerator=acc)
205
+ trainer.save_model()
206
+
207
+
208
+ def env_stuff():
209
+ os.environ.setdefault("LOCAL_RANK", os.environ.get("SLURM_LOCALID", "---"))
210
+ os.environ.setdefault("RANK", os.environ.get("SLURM_PROCID", "0"))
211
+ os.environ.setdefault("WORLD_SIZE", os.environ.get("SLURM_NTASKS", "1"))
212
+ os.environ.setdefault("MASTER_ADDR", os.environ.get("SLURM_LAUNCH_NODE_IPADDR", "127.0.0.1"))
213
+ os.environ.setdefault("MASTER_PORT", "29500") # pick an open port
214
+
215
+ # Optional: make sure each process selects its own GPU
216
+ torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
217
+
218
+ try:
219
+ log(
220
+ f"host={socket.gethostname()} "
221
+ f"RANK={os.environ['RANK']}/{os.environ['WORLD_SIZE']} "
222
+ f"LOCAL_RANK={os.environ['LOCAL_RANK']} "
223
+ f"HIP_VISIBLE_DEVICES={os.environ.get('HIP_VISIBLE_DEVICES')} "
224
+ f"ROCR_VISIBLE_DEVICES={os.environ.get('ROCR_VISIBLE_DEVICES')} "
225
+ f"cuda_count={torch.cuda.device_count()} curr_dev={torch.cuda.current_device()}"
226
+ )
227
+ except AssertionError:
228
+ log(
229
+ f"host={socket.gethostname()} "
230
+ f"RANK={os.environ['RANK']}/{os.environ['WORLD_SIZE']} "
231
+ f"LOCAL_RANK={os.environ['LOCAL_RANK']} "
232
+ f"HIP_VISIBLE_DEVICES={os.environ.get('HIP_VISIBLE_DEVICES')} "
233
+ f"ROCR_VISIBLE_DEVICES={os.environ.get('ROCR_VISIBLE_DEVICES')} "
234
+ f"no cuda"
235
+ )
236
+
237
+ """
238
+ This replaces the trainer, in order to
239
+ print out the final batch when training,
240
+ and commit harakiri. So only for temporary
241
+ debugging-related usage
242
+ """
243
+ class LoggingKillingTrainer(Trainer):
244
+ def compute_loss(self, model, inputs, **kwargs):
245
+ log(f"Here is the batch for training: {inputs}")
246
+ raise NotImplementedError
247
+ #return super().compute_loss(model, inputs, **kwargs)
248
+
249
+ if __name__ == "__main__":
250
+ env_stuff()
251
+
252
+ simple_train()
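
To make the batch arithmetic in get_training_args concrete, a small worked example (the process count is illustrative): with the default batch_size=1024 and nr_sents_per_gpu=4 on, say, 8 processes, the divisibility assertion holds and gradient accumulation is 32 steps, so each optimizer update still sees 4 * 8 * 32 = 1024 sentences.

# Illustrative check of the effective-batch arithmetic used in get_training_args.
batch_size = 1024        # sentences per optimizer update (cmdline default)
nr_sents_per_gpu = 4     # per-device micro-batch (cmdline default)
world_size = 8           # assumed number of GPU processes

assert batch_size % (nr_sents_per_gpu * world_size) == 0
accum_steps = batch_size // (nr_sents_per_gpu * world_size)

assert nr_sents_per_gpu * world_size * accum_steps == batch_size
print(accum_steps)  # -> 32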