Spaces · Sleeping
Rasmus Lellep committed · Commit 76b1ec5
Parent(s): 9e93eb6
add loader
Browse files
- kuidastaltsutadalaamat/.gitignore +162 -0
- kuidastaltsutadalaamat/LICENSE +21 -0
- kuidastaltsutadalaamat/README.md +2 -0
- kuidastaltsutadalaamat/aux.py +212 -0
- kuidastaltsutadalaamat/data.py +142 -0
- kuidastaltsutadalaamat/inference.py +170 -0
- kuidastaltsutadalaamat/legacy/accel.py +328 -0
- kuidastaltsutadalaamat/legacy/accel_backup.py +237 -0
- kuidastaltsutadalaamat/legacy/benchmark.py +190 -0
- kuidastaltsutadalaamat/legacy/data.py +164 -0
- kuidastaltsutadalaamat/legacy/data_backup.py +804 -0
- kuidastaltsutadalaamat/legacy/diffmdl.py +69 -0
- kuidastaltsutadalaamat/legacy/initmodel.py +46 -0
- kuidastaltsutadalaamat/legacy/langconv.py +260 -0
- kuidastaltsutadalaamat/legacy/localizemodel.py +45 -0
- kuidastaltsutadalaamat/legacy/modelops.py +122 -0
- kuidastaltsutadalaamat/legacy/oldtrainllm.py +90 -0
- kuidastaltsutadalaamat/legacy/parasynth.py +139 -0
- kuidastaltsutadalaamat/legacy/pretok.py +65 -0
- kuidastaltsutadalaamat/legacy/testmem.py +100 -0
- kuidastaltsutadalaamat/legacy/tokops.py +350 -0
- kuidastaltsutadalaamat/legacy/trainmodel.py +96 -0
- kuidastaltsutadalaamat/legacy/translate_backup.py +309 -0
- kuidastaltsutadalaamat/metrics.py +79 -0
- kuidastaltsutadalaamat/promptops.py +70 -0
- kuidastaltsutadalaamat/trainllm.py +252 -0
kuidastaltsutadalaamat/.gitignore
ADDED
@@ -0,0 +1,162 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
#   .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
kuidastaltsutadalaamat/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 TartuNLP

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
kuidastaltsutadalaamat/README.md
ADDED
@@ -0,0 +1,2 @@
# Kuidas taltsutada laamat
Implementation of LLM continued training and inference.
kuidastaltsutadalaamat/aux.py
ADDED
@@ -0,0 +1,212 @@
#!/usr/bin/env python3

import numpy as np
import pickle
import re
import sys

from datetime import datetime


def log(msg, accelerator=None, all_threads=False):
    if accelerator is not None and all_threads:
        report_proc = f" ({accelerator.process_index+1}/{accelerator.num_processes})"
    else:
        report_proc = ""

    if accelerator is None or accelerator.is_main_process or all_threads:
        sys.stderr.write(str(datetime.now()) + report_proc + ": " + msg + '\n')


def _same_line_log(msg, len_to_del=0):
    """if sys.stderr.isatty():
        if len_to_del > 0:
            sys.stderr.write("\b" * len_to_del)

        new_len = len(msg)

        sys.stderr.write(msg)
        sys.stderr.flush()

        return new_len
    else:"""
    log(msg)


def debug(msg):
    pass
    ### log("\n(DEBUG) " + msg)


def maybe_convert(value):
    try:
        return int(value)
    except (ValueError, TypeError):
        try:
            return float(value)
        except (ValueError, TypeError):
            return value


def get_changed_config(conf, args):
    arg_dict = args.to_dict()

    for kwarg in arg_dict:
        if hasattr(conf, kwarg) and arg_dict[kwarg] is not None:
            setattr(conf, kwarg, maybe_convert(arg_dict[kwarg]))

    return conf


class SameLineLogger:
    def __init__(self, epoch_len, epoch_num, data_state):
        self.epoch_len = epoch_len
        self.epoch_num = epoch_num
        self.start_global_step = epoch_len * data_state.epoch_idx + data_state.elem_idx

        self.totalx = epoch_len * epoch_num

        self.log_after = []
        self.log_len = 0

        self.start_time = datetime.now()

    def line_start(self):
        _same_line_log(str(datetime.now()) + ": training batches ")

    def step(self, global_batch_idx, epoch_batch_idx, epoch_idx, loss, lr, grad):
        passed_time = datetime.now() - self.start_time
        time_per_batch = passed_time / (global_batch_idx - self.start_global_step)
        prediction = time_per_batch * (self.totalx - global_batch_idx)

        msg = f"{epoch_batch_idx} / {self.epoch_len}, epoch {epoch_idx + 1} / {self.epoch_num}, loss={loss}, avg {time_per_batch}/iter, {prediction} to finish, LR={lr:.2e}, grad={grad:.2e} "

        new_len = _same_line_log(msg, self.log_len)

        self.log_len = new_len

    def line_break(self):
        sys.stderr.write("\n")


class CmdlineArgs:
    def __init__(self,
                 description,
                 pos_arg_list=None,
                 pos_arg_types=None,
                 kw_arg_dict=None,
                 input_args=None):

        self.description = description

        self.raw_pos_arg_list = pos_arg_list if pos_arg_list is not None else []
        self.raw_pos_arg_types = pos_arg_types \
            if pos_arg_types is not None \
            else [None] * len(self.raw_pos_arg_list)

        self.kw_arg_dict_with_defaults = kw_arg_dict if kw_arg_dict is not None else {}

        kw_vals, cmdline_values = self._to_kwargs(sys.argv[1:] if input_args is None else input_args)

        self._maybe_help(cmdline_values)

        self._handle_positional_args(cmdline_values)

        self._handle_keyword_args(kw_vals)

    @staticmethod
    def _to_kwargs(arg_list):
        key_args = dict(raw_entry.lstrip("-").split("=") for raw_entry in arg_list if "=" in raw_entry)
        filtered_arg_list = [arg for arg in arg_list if "=" not in arg]

        return key_args, filtered_arg_list

    def _handle_keyword_args(self, kw_vals):
        for kw in self.kw_arg_dict_with_defaults:
            if kw in kw_vals:
                val = self._convert_kw(kw_vals, kw)
                del kw_vals[kw]
            else:
                val = self.kw_arg_dict_with_defaults[kw]

            setattr(self, kw, val)

        if kw_vals:
            extra_keys = ", ".join(kw_vals.keys())
            msg = f"command-line keyword arguments '{extra_keys}' are not recognized."

            self._help_message_and_die(extra=msg)

    def _convert_kw(self, kw_vals, kw):
        if self.kw_arg_dict_with_defaults[kw] is None:
            return kw_vals[kw]
        else:
            this_typ = type(self.kw_arg_dict_with_defaults[kw])

            try:
                return this_typ(kw_vals[kw])
            except ValueError:
                self._help_message_and_die(extra=f"could not convert '{kw_vals[kw]}' to '{this_typ}'")

    def _sanity_check_pos_args(self, cmdline_values):
        cmdline_len = len(cmdline_values)

        if cmdline_len < len(self.raw_pos_arg_list):
            self._help_message_and_die(
                extra=f"positional arguments missing: {', '.join(self.raw_pos_arg_list[cmdline_len:])}")

        if cmdline_len > len(self.raw_pos_arg_list):
            self._help_message_and_die(
                extra=f"superfluous positional arguments: {', '.join(cmdline_values[len(self.raw_pos_arg_list):])}")

    def _handle_positional_args(self, cmdline_values):
        self._sanity_check_pos_args(cmdline_values)

        for arg, val, typ in zip(self.raw_pos_arg_list, cmdline_values, self.raw_pos_arg_types):
            try:
                val = val if typ is None else typ(val)
            except ValueError:
                self._help_message_and_die(extra=f"could not convert '{val}' to '{typ}'")

            setattr(self, arg, val)

    def _maybe_help(self, cmdline_values):
        if len(cmdline_values) == 1 and cmdline_values[0] in {"--help", "-h", "-?"}:
            self._help_message_and_die()

    def _help_message_and_die(self, extra=None):
        sys.stderr.write("Help message: " + self.description + "\n")

        if self.raw_pos_arg_list:
            args_descr = ", ".join([f"'{arg}' ({typ.__name__ if typ is not None else 'any'})"
                                    for arg, typ in zip(self.raw_pos_arg_list, self.raw_pos_arg_types)])

            sys.stderr.write(f"Positional arguments: {args_descr}\n")

        if self.kw_arg_dict_with_defaults:
            kw_descr = ", ".join([f"'{kw}' (default: {val})"
                                  for kw, val in self.kw_arg_dict_with_defaults.items()])

            sys.stderr.write(f"Keyword arguments: {kw_descr}\n")

        if extra is not None:
            sys.stderr.write("Error: " + extra + "\n")

        sys.stderr.write("\n")
        sys.exit(-1)

    def to_dict(self):
        return {k: v for k, v in self.__dict__.items()
                if k not in {'description', 'raw_pos_arg_list', 'raw_pos_arg_types', 'kw_arg_dict_with_defaults'}}

    def __str__(self):
        return str(self.to_dict())

    def __repr__(self):
        return self.__str__()

if __name__ == "__main__":
    for dname in sys.argv[1:]:
        d = np.load(dname + "/custom_checkpoint_1.pkl", allow_pickle=True)
        p = pickle.loads(d['custom_checkpoint_1/data.pkl'])
        print(dname, p)
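For reference, a minimal usage sketch of the CmdlineArgs parser defined above. The argument values and the package-style import are illustrative assumptions, not part of the commit; they only demonstrate the positional/keyword conventions the class implements.

# hypothetical driver for CmdlineArgs; the values below are placeholders
from kuidastaltsutadalaamat.aux import CmdlineArgs, log

# equivalent to a command line like: script.py models/llama3.2-1b lr=5e-6 epochs=3
args = CmdlineArgs("Toy driver for CmdlineArgs",
                   pos_arg_list=["mdl_id"],
                   kw_arg_dict={"lr": 1e-5, "epochs": 2},
                   input_args=["models/llama3.2-1b", "lr=5e-6", "epochs=3"])

# keyword values are converted to the type of their default (float and int here)
log(f"model: {args.mdl_id}, lr: {args.lr}, epochs: {args.epochs}")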
kuidastaltsutadalaamat/data.py
ADDED
@@ -0,0 +1,142 @@
#!/usr/bin/env python3
from .promptops import *

import json
import sys

from random import shuffle

from torch.utils.data import Dataset as TorchDataset, DataLoader

from .aux import log


def tokenize_str(tokenizer, entry, add_eos=True, max_len=3000, for_inf=False):
    if for_inf:
        tokens = tokenizer(
            entry,
            truncation=True,
            max_length=max_len,
            return_attention_mask=True,
            return_tensors="pt"
        )
    else:
        tokens = tokenizer(
            entry,
            truncation=True,
            max_length=max_len,
            return_attention_mask=True
        )

    if add_eos:
        tokens['attention_mask'].append(1)
        tokens['input_ids'].append(tokenizer.eos_token_id)

    return tokens

"""
Load texts into memory and allow to loop through it,
returning tokenized tensors.

Currently no support for text data that does not fit into memory,
need to add it. Or do HF datasets have something out of the box?
"""
class LazyTokenizingDataset(TorchDataset):
    def __init__(self, texts, tokenizer, max_length=512, prompt_format="raw"):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.prompt_format = prompt_format

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # Return plain Python lists; let the collator pad & build labels.
        entry = self.texts[idx]

        prompt = prep_prompt(entry, self.prompt_format)

        return tokenize_str(self.tokenizer, prompt)


class LazyTokenizingInferenceDataset(TorchDataset):
    def __init__(self, texts, tokenizer, prompt_format, max_length=512, debug=False):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.prompt_format = prompt_format
        self.debug = debug

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        entry = self.texts[idx]

        prompt = prep_prompt(entry, self.prompt_format, inference=True)
        result = tokenize_str(self.tokenizer, prompt, add_eos=False, for_inf=True)

        if self.debug:
            log(f"Input: {prompt}")
            log(f"Tokenized: {result}")

        return result


def read_input(path, formt):
    if path is None:
        log("Reading from STDIN")
        fh = sys.stdin
    else:
        fh = open(path, 'r')

    if formt == PF_RAW:
        result = [fh.read()]
    elif formt == PF_RAWLINES:
        result = fh.readlines()
    else:
        result = json.load(fh)

    return result


def get_data_loader(path, prompt_format, tokenizer, debug=False):
    inputs = read_input(path, prompt_format)

    dataset = LazyTokenizingInferenceDataset(inputs, tokenizer, prompt_format, debug=debug)

    """
    data_coll = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
        pad_to_multiple_of=None,  # helps performance; set None if you prefer exact lengths
    )

    data_loader = DataLoader(dataset, collate_fn=data_coll, batch_size=1)
    """

    return dataset


def load_training_data(path, tokenizer, cmd_args):
    with open(path, "r") as f:
        data = json.load(f)

    train_set_iter = LazyTokenizingDataset(data, tokenizer, cmd_args.max_length, cmd_args.prompt_format)

    return train_set_iter


if __name__ == '__main__':
    all_data = []

    for input_file in sys.argv[1:]:
        with open(input_file, "r") as f:
            this_data = json.load(f)
            all_data += this_data

    shuffle(all_data)

    json.dump(all_data, sys.stdout)
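For orientation, a small sketch of what tokenize_str returns in its two modes, based on the code above: training mode yields plain Python lists with the EOS id appended, inference mode yields "pt" tensors without EOS. The "gpt2" tokenizer id is a placeholder assumption; any Hugging Face tokenizer that defines an eos_token behaves the same way here.

from transformers import AutoTokenizer

from kuidastaltsutadalaamat.data import tokenize_str

# placeholder tokenizer; any HF tokenizer with an eos_token_id works for this sketch
tok = AutoTokenizer.from_pretrained("gpt2")

train_tokens = tokenize_str(tok, "Tere, maailm!")                             # lists, EOS appended
inf_tokens = tokenize_str(tok, "Tere, maailm!", add_eos=False, for_inf=True)  # "pt" tensors, no EOS

print(train_tokens["input_ids"][-1] == tok.eos_token_id)   # True
print(inf_tokens["input_ids"].shape)                        # torch.Size([1, sequence_length])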
kuidastaltsutadalaamat/inference.py
ADDED
@@ -0,0 +1,170 @@
#!/usr/bin/env python3

from .promptops import *

from .aux import CmdlineArgs, log
from .data import get_data_loader
from .trainllm import env_stuff, load_model, load_tokenizer


import sys
import torch
import json
import torch.distributed as dist

from accelerate import Accelerator

from datetime import datetime

"""
This currently assumes the batch size to be 1. With larger batches the padding tokens went
into the decoder. Right-padding as a solution?
"""
def llm_generate(model, tokenizer, tok_batch, debug=False, max_len=2000):
    tok_batch['input_ids'] = tok_batch['input_ids'].to(model.device)
    tok_batch['attention_mask'] = tok_batch['attention_mask'].to(model.device)
    start_time = datetime.now()

    if debug:
        log(f"Tokenized input: {tok_batch['input_ids']}")

    raw_output_toks = model.generate(**tok_batch, tokenizer=tokenizer,
                                     do_sample=False, num_beams=4, max_length=max_len, top_p=None, temperature=None,
                                     eos_token_id=[tokenizer.eos_token_id,
                                                   tokenizer.convert_tokens_to_ids("<|reserved_special_token_14|>")])

    #clean_output_toks = remove_prompt_from_output(tok_batch['attention_mask'], raw_output_toks, filler_id)
    assert len(raw_output_toks) == 1, "Only batch size=1 supported %-("
    gen_idx = len(tok_batch['attention_mask'][0])

    if debug:
        log(f"Full tokenized output: {raw_output_toks[0]}")
        log(f"Full tokens: {tokenizer.convert_ids_to_tokens(raw_output_toks[0])}")
        full_out = tokenizer.batch_decode([raw_output_toks[0]], skip_special_tokens=True)
        log(f"Full text: {full_out[0]}")

    clean_output_toks = raw_output_toks[0][gen_idx:]
    clean_outputs = tokenizer.batch_decode([clean_output_toks], skip_special_tokens=True)

    if debug:
        log(f"Pruned tokenized output: {clean_output_toks}")
        log(f"Pruned tokens: {tokenizer.convert_ids_to_tokens(clean_output_toks)}")
        log(f"Cleaned output: {clean_outputs[0]}")

    end_time = datetime.now()
    log(f"This took: {end_time - start_time}")

    return clean_outputs


def reassemble_multi(list_of_lists):
    result = []

    for gen_idx in range(len(list_of_lists[0])):
        for i in range(len(list_of_lists)):
            if gen_idx < len(list_of_lists[i]):
                result.append(list_of_lists[i][gen_idx])

    return result


def predict(model, tokenizer, data_loader, accel, multi=False, debug=False, max_len=2000):
    outs_final = []

    with torch.no_grad():
        for idx, batch in enumerate(data_loader):
            if idx % accel.num_processes == accel.process_index:
                start_time = datetime.now()
                outputs = llm_generate(model, tokenizer, batch, debug=debug, max_len=max_len)
                end_time = datetime.now()
                log(f"Generated for {idx} in proc {accel.process_index} in {end_time - start_time}")
                outs_final += outputs

    if multi:
        accel.wait_for_everyone()

        rank0_buffer = [None] * accel.num_processes if accel.is_main_process else None
        dist.gather_object(outs_final, rank0_buffer, dst=0)
        if accel.is_main_process:
            outs_final = reassemble_multi(rank0_buffer)
        else:
            outs_final = None

    return outs_final


def _cmdline_args():
    inputs = sys.argv[1:]

    description = """Predict output for an input via prompting"""

    pos_args = ["mdl_id"]

    #post-process the arguments
    args = CmdlineArgs(description, pos_args, input_args=inputs,
                       kw_arg_dict={"debug": False,
                                    "input_file": "none",
                                    "output_file": "none",
                                    "multiproc": False,
                                    "max_len": 2000,
                                    "prompt_format": PF_ALPACA})

    if args.input_file == "none":
        args.input_file = None
    if args.output_file == "none":
        args.output_file = None

    log(f"Launched as {args}")

    return args


def save_all(outputs, args, acc):
    if acc.is_main_process:
        if args.output_file is None:
            log("Writing to STDOUT")
            out_fh = sys.stdout
        else:
            out_fh = open(args.output_file, "w")

        if args.prompt_format in {PF_RAW, PF_RAWLINES}:
            for line in outputs:
                out_fh.write(line + "\n")
        else:
            json.dump(outputs, out_fh)


def and_i_called_this_function_do_main_too():
    args = _cmdline_args()

    if args.multiproc:
        env_stuff()

    acc = Accelerator()
    device = acc.device

    log(f"Device: {device}.", accelerator=acc)

    if not args.multiproc and not acc.is_main_process:
        log("Not launched in multi-processing mode, exiting non-main process.")
        sys.exit(0)

    tokenizer = load_tokenizer(args.mdl_id, acc)

    data_loader = get_data_loader(args.input_file, args.prompt_format, tokenizer, debug=args.debug)

    model = load_model(args.mdl_id, device, acc, attention="eager")
    model.eval()

    log(f"Device: {model.device}.", accelerator=acc)

    log("Model loaded, starting to generate")
    outputs = predict(model, tokenizer, data_loader, acc, multi=args.multiproc, debug=args.debug, max_len=args.max_len)

    save_all(outputs, args, acc)

    log("Done")


if __name__ == "__main__":
    and_i_called_this_function_do_main_too()
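A minimal smoke-test sketch for llm_generate, assuming a local Llama-3-style checkpoint; the "models/llama3.2-1b" path is borrowed from the test code elsewhere in this commit and is a placeholder. The repository's own load_model/load_tokenizer in trainllm.py add accelerator handling; plain transformers loading is only enough to exercise the function, and since the eos list above references a Llama-3 reserved special token, a tokenizer without that token would need adjusting.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from kuidastaltsutadalaamat.inference import llm_generate

mdl_id = "models/llama3.2-1b"  # placeholder local checkpoint path
tokenizer = AutoTokenizer.from_pretrained(mdl_id)
model = AutoModelForCausalLM.from_pretrained(mdl_id, torch_dtype=torch.bfloat16).eval()

# batch size 1, as llm_generate asserts
batch = tokenizer("Tõlgi eesti keelde: Hello, world!", return_tensors="pt")
with torch.no_grad():
    outputs = llm_generate(model, tokenizer, batch, debug=True, max_len=200)

print(outputs[0])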
kuidastaltsutadalaamat/legacy/accel.py
ADDED
@@ -0,0 +1,328 @@
import os

import torch

from accelerate import Accelerator
from datetime import datetime
from transformers import get_scheduler

from aux import SameLineLogger, log
from data import DataState, BatchingIterator
from modelops import save_all_models, report_devices


def chain_params(coupling_specs):
    for spec in coupling_specs:
        yield from spec.model.parameters()


class TrainLossList:
    def __init__(self):
        self.data = []

    def append(self, loss_val, sub_batch_idx, epoch_batch_idx, _epoch_idx):
        self.data.append((loss_val, sub_batch_idx, epoch_batch_idx, _epoch_idx))

    def state_dict(self):
        return {'data': self.data}

    def load_state_dict(self, state_dict):
        self.data = state_dict['data']


class SwitchingAccelerator:
    def __init__(self, train_set, train_kwargs, model, tokenizer, preinit_acc=None):
        self.kwargs = train_kwargs
        self.train_set_iter = BatchingIterator(train_set, self.kwargs.batch_size, tokenizer, train_kwargs.max_length)

        self.model = model
        self.tokenizer = tokenizer

        self.train_loss_list = TrainLossList()
        self.data_state = DataState(epoch_idx=0)

        self._init_acc_and_stuff(preinit_acc)

        self._init_time_keepers()

    def _init_time_keepers(self):
        if self.kwargs.log_steps < 0 and self.accelerator.is_main_process:
            t = datetime.now()
            self._tk_zero = t - t

            self._tk_stats = {}
            self._tk_time = {}

    def _add_timekeeper(self, msg):
        if self.kwargs.log_steps < 0 and self.accelerator.is_main_process:
            self._tk_stats[msg] = []
            self._tk_time[msg] = None

    def _add_timekeepers(self, msgs):
        for msg in msgs:
            self._add_timekeeper(msg)

    def _tk_start(self, msg):
        if self.kwargs.log_steps < 0 and self.accelerator.is_main_process:
            assert self._tk_time[msg] is None

            self._tk_time[msg] = datetime.now()

    def _tk_stop(self, msg):
        if self.kwargs.log_steps < 0 and self.accelerator.is_main_process:
            assert self._tk_time[msg] is not None

            this_time = datetime.now() - self._tk_time[msg]
            self._tk_time[msg] = None
            self._tk_stats[msg].append(this_time)

            log(f"{msg} took {this_time}, avg time: " +
                f" {sum(self._tk_stats[msg], self._tk_zero) / len(self._tk_stats[msg])}" +
                f" over {len(self._tk_stats[msg])} samples")

    def __handle_accum(self):

        assert self.kwargs.batch_size % (self.accelerator.num_processes * self.kwargs.nr_sents_per_gpu) == 0,\
            "batch size must be divisible by number of processes and number of segments per GPU"

        accum_steps = int((self.kwargs.batch_size / self.accelerator.num_processes) / self.kwargs.nr_sents_per_gpu)
        self.accelerator.gradient_accumulation_steps = accum_steps

        log(f"Nr sents/GPU: {self.kwargs.nr_sents_per_gpu}, accum steps: {accum_steps}, " +
            f"nr. procs: {self.accelerator.num_processes}, batch size: {self.kwargs.batch_size}",
            accelerator=self.accelerator)

    def ___get_train_scalars(self):
        epoch_len = len(self.train_set_iter)
        train_len = epoch_len * self.kwargs.epochs

        num_warmup = 0 #int(train_len * 0.01)

        log(f"Warmup steps: {num_warmup}, epoch len: {epoch_len}, train len: {train_len}", accelerator=self.accelerator)

        return train_len, num_warmup

    def __init_opt_lr_and_what_else(self):
        train_len, num_warmup = self.___get_train_scalars()

        opt = torch.optim.AdamW(self.model.parameters(), lr=self.kwargs.lr)

        numtr = train_len * self.accelerator.num_processes
        lr_scheduler = get_scheduler("linear", optimizer=opt, num_warmup_steps=num_warmup, num_training_steps=numtr)

        self.optimizer, self.lr_scheduler, self.model = self.accelerator.prepare(opt, lr_scheduler, self.model)

        self.accelerator.register_for_checkpointing(self.data_state, self.train_loss_list)

    def _init_acc_and_stuff(self, preinit_acc=None):
        #self.accelerator = Accelerator(gradient_accumulation_steps=self.kwargs.accum_steps, kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)])

        if preinit_acc is None:
            self.accelerator = Accelerator()
        else:
            self.accelerator = preinit_acc

        self.__handle_accum()

        self.__init_opt_lr_and_what_else()

        if self.kwargs.continue_training:
            self.accelerator.load_state(self.kwargs.mdl_id)
            log(f"Reloaded data state: {self.data_state}", accelerator=self.accelerator)

    def train(self, dry_run=False):
        try:
            self._main_loop(dry_run)
        except Exception as e:
            #in multiprocess scenarios it is hard to read the stack trace, so just show one:
            if self.accelerator.is_main_process:
                raise e

        self.accelerator.wait_for_everyone()

        unwr_coupled_model = self.accelerator.unwrap_model(self.model)

        return unwr_coupled_model

    def _prepare_inputs(self, batch, sub_batch_idx, sub_batch_size, proc_batch_size):
        from_proc_idx = proc_batch_size * self.accelerator.process_index + sub_batch_size * sub_batch_idx
        to_proc_idx = from_proc_idx + sub_batch_size

        #log(f"----> DEBUG for sub_b idx {sub_batch_idx}, proc {self.accelerator.process_index}: {from_proc_idx}:{to_proc_idx}")

        return {k: batch[k][from_proc_idx:to_proc_idx].to(self.accelerator.device) for k in batch}

    def _get_split_batch_params(self):
        batch_nr_snts = self.kwargs.batch_size

        assert batch_nr_snts % self.accelerator.num_processes == 0, "Batch size must be divisible by number of processes."

        proc_batch_nr_snts = batch_nr_snts // self.accelerator.num_processes

        sub_batch_size = self.kwargs.nr_sents_per_gpu

        nr_steps = -(proc_batch_nr_snts // -sub_batch_size)

        #log(f"--> DEBUG: sub_batch {sub_batch_size} X steps {nr_steps} ~ {proc_batch_nr_snts} ({batch_nr_snts} / {self.accelerator.num_processes})", accelerator=self.accelerator)
        return sub_batch_size, nr_steps, proc_batch_nr_snts

    def _report_mem_every_once_in_a_while(self, sub_batch_idx, epoch_batch_idx, batch_dim):
        if sub_batch_idx == 0:
            report_devices(f"training memory usage (batch size: {self.kwargs.batch_size} / {batch_dim[1]}",
                           self.accelerator, self.model)

    def _main_loop(self, dry_run):
        if self.accelerator.is_main_process:
            logger = SameLineLogger(len(self.train_set_iter), self.kwargs.epochs, self.data_state)
            logger.line_start()
        else:
            logger = None

        self.model.train()
        self.train_set_iter.thats_where(self.data_state)

        tks = "full_batch", "prep_inputs", "forward", "backward", "upd_step"
        tk_batch, tk_prep, tk_fw, tk_bk, tk_step = tks
        self._add_timekeepers(tks)

        with self.accelerator.accumulate(self.model):
            for _epoch_idx in range(self.data_state.epoch_idx, self.kwargs.epochs):
                for batch, epoch_batch_idx in self.train_set_iter:
                    if dry_run:
                        log(f"Dry run, batch width: {batch['input_ids'].size()}")
                    else:
                        self._report_mem_every_once_in_a_while(0, epoch_batch_idx, batch['input_ids'].size())
                        sub_batch_size, nr_steps, proc_batch_size = self._get_split_batch_params()

                        self._tk_start(tk_batch)

                        loss = None
                        for sub_batch_idx in range(nr_steps):
                            self._tk_start(tk_prep) ########
                            inputs = self._prepare_inputs(batch, sub_batch_idx, sub_batch_size, proc_batch_size)

                            inputs['labels'] = inputs['input_ids'].copy()
                            self._tk_stop(tk_prep) ########

                            self._tk_start(tk_fw) ########
                            outputs = self.model(**inputs)

                            loss = outputs.loss
                            self._tk_stop(tk_fw) ########

                            self.train_loss_list.append(loss.item(), sub_batch_idx, epoch_batch_idx, _epoch_idx)

                            self._tk_start(tk_bk) ########
                            self.accelerator.backward(loss)
                            self._tk_stop(tk_bk) ########

                            self._tk_start(tk_step) ########
                            self.optimizer.step()
                            self.lr_scheduler.step()
                            self.optimizer.zero_grad()
                            self._tk_stop(tk_step) ########

                        self._tk_stop(tk_batch)

                        #assert self.accelerator.sync_gradients, "It is not time to sync gradients yet."
                        self._step_and_perhaps_save(logger, epoch_batch_idx, _epoch_idx, float(loss.item()))

        if self.accelerator.is_main_process:
            logger.line_break()

    def get_total_grad(self):
        result = 0
        grad_count = 0
        all_count = 0

        for p in self.model.parameters():
            if p.grad is not None:
                result += p.grad.abs().mean().item()
                grad_count += 1
            all_count += 1

        return result/grad_count if grad_count != 0 else -1

    def _step_and_perhaps_save(self, logger, epoch_batch_idx, epoch_i, loss):
        epoch_len = len(self.train_set_iter)
        global_batch_idx = epoch_batch_idx + epoch_i * epoch_len

        is_end_of_epoch = (epoch_batch_idx == epoch_len)

        if self.accelerator.is_main_process \
                and (epoch_batch_idx % self.kwargs.log_steps == 0 or is_end_of_epoch):
            grad = self.get_total_grad()

            logger.step(global_batch_idx, epoch_batch_idx, epoch_i, loss, self.lr_scheduler.get_last_lr()[0], grad)

        #self.optimizer.zero_grad()

        if (global_batch_idx % self.kwargs.save_steps == 0) or is_end_of_epoch:
            self.accelerator.wait_for_everyone()

            if self.accelerator.is_main_process:
                logger.line_break()
                log(f"Saving at {epoch_batch_idx} steps, epoch {epoch_i + 1} ({global_batch_idx} global steps)", accelerator=self.accelerator)

                self._save_all(global_batch_idx, epoch_i)

                logger.line_start()

    def _save_all(self, global_batch_idx, epoch_i):
        epoch_len = len(self.train_set_iter)

        ckpt_name = (f"checkpoint-e{epoch_i + 1:02}-" +
                     (f"b{global_batch_idx:07}" if (global_batch_idx % epoch_len) else f"full"))

        this_location = os.path.join(self.kwargs.save_location, ckpt_name)
        if os.path.exists(this_location):
            raise FileExistsError(f"Cannot overwrite existing checkpoint {this_location}!")

        self.data_state.copy_from(self.train_set_iter.where_are_we(), epoch_idx=epoch_i)

        model_to_save = self.accelerator.unwrap_model(self.model)

        save_all_models(this_location, model_to_save, self.tokenizer, trainer=self.accelerator)

def test_this_damn_thing():
    # testing
    import torch
    import json
    from torch.optim import AdamW
    from modelops import hf_tok
    from transformers import AutoModelForCausalLM, AutoTokenizer

    mdl_id = "models/llama3.2-1b"

    tokenizer = AutoTokenizer.from_pretrained(mdl_id, token=hf_tok)
    model = AutoModelForCausalLM.from_pretrained(mdl_id, token=hf_tok, torch_dtype=torch.bfloat16)
    with open("tmpx.json", "r") as f:
        training_data_raw = json.load(f)

    optimizer = AdamW(model.parameters(), lr=5e-6)

    print("Initial 0:", optimizer.param_groups[0]['lr'])  # Should be [5e-6]

    scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=2445
    )

    accel = Accelerator()

    p_optimizer, p_lr_scheduler, p_model = accel.prepare(optimizer, scheduler, model)

    print("Initial 1:", p_lr_scheduler.get_last_lr())  # Should be [5e-6]

    """
    for _ in range(2):
        optimizer.step()
        scheduler.step()
        print("Step:", scheduler.get_last_lr())
    """


if __name__ == "__main__":
    test_this_damn_thing()
kuidastaltsutadalaamat/legacy/accel_backup.py
ADDED
@@ -0,0 +1,237 @@
"""
import os

import torch

from accelerate import Accelerator, DistributedDataParallelKwargs
from transformers import get_scheduler

from aux import SameLineLogger, log
from data import DataState
from langconv import is_dec_only_llm
from modelops import save_all_models, report_devices
from translate import encode


raise NotImplementedError("This is a backup package, do not run or import from it")


def chain_params(coupling_specs):
    for spec in coupling_specs:
        yield from spec.model.parameters()


class TrainLossList:
    def __init__(self):
        self.data = []

    def append(self, loss_val, src_k, tgt_k):
        self.data.append((loss_val, src_k, tgt_k))

    def state_dict(self):
        return {'data': self.data}

    def load_state_dict(self, state_dict):
        self.data = state_dict['data']



class SwitchingAccelerator:
    def __init__(self, coupling_specs, train_set, train_kwargs):
        self.coupling_specs = coupling_specs

        self.train_set = train_set
        self.kwargs = train_kwargs

        self.is_generative = is_dec_only_llm(self.coupling_specs[0].tokenizer)

        self.train_loss_list = TrainLossList()
        self.data_state = DataState(epoch_idx=0)

        self._init_acc_and_stuff()

    def _init_acc_and_stuff(self):
        #self.accelerator = Accelerator(gradient_accumulation_steps=self.kwargs.accum_steps, kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)])
        #self.accelerator = Accelerator(gradient_accumulation_steps=self.kwargs.accum_steps)
        self.accelerator = Accelerator()

        epoch_len = len(self.train_set)
        train_len = epoch_len * self.kwargs.epochs

        num_warmup = int(train_len * 0.01)

        log(f"Warmup steps: {num_warmup}, epoch len: {epoch_len}, train len: {train_len}", accelerator=self.accelerator)

        opt = torch.optim.AdamW(chain_params(self.coupling_specs), lr=self.kwargs.lr)
        lr_scheduler = get_scheduler("linear", optimizer=opt, num_warmup_steps=num_warmup,
                                     num_training_steps=train_len * self.accelerator.num_processes)
        models = [s.model for s in self.coupling_specs]

        self.optimizer, self.lr_scheduler, *self.models = self.accelerator.prepare(opt, lr_scheduler, *models)

        self.accelerator.register_for_checkpointing(self.lr_scheduler, self.data_state, self.train_loss_list)

        if self.kwargs.continue_training:
            self.accelerator.load_state(self.kwargs.mdl_id)
            log(f"Reloaded data state: {self.data_state}", accelerator=self.accelerator)

    def train(self):
        try:
            self._main_loop()
        except Exception as e:
            #in multi-process scenarios it is hard to read the stack trace, so just show one:
            if self.accelerator.is_main_process:
                raise e

        self.accelerator.wait_for_everyone()

        unwr_coupled_model = self.accelerator.unwrap_model(self.models[0])

        return unwr_coupled_model, self.train_loss_list

    def _split_batch_and_bin_idxs(self, batch_with_idxs):
        if self.is_generative:
            batch, _ = batch_with_idxs
            src_k = 0
            tgt_k = 0
        else:
            batch, src_k, tgt_k, _ = batch_with_idxs
        return batch, src_k, tgt_k

    def _prepare_inputs(self, batch, sub_batch_idx, sub_batch_size, proc_batch_size):
        from_proc_idx = proc_batch_size * self.accelerator.process_index + sub_batch_size * sub_batch_idx
        to_proc_idx = from_proc_idx + sub_batch_size

        #log(f"----> DEBUG for sub_b idx {sub_batch_idx}, proc {self.accelerator.process_index}: {from_proc_idx}:{to_proc_idx}")

        return {k: batch[k][from_proc_idx:to_proc_idx].to(self.accelerator.device) for k in batch}

    def _get_split_batch_params(self, batch):
        batch_nr_snts = batch['input_ids'].size()[0]
        snt_nr_words = batch['input_ids'].size()[1]

        assert batch_nr_snts % self.accelerator.num_processes == 0, "Batch size must be divisible by number of processes."

        proc_batch_nr_snts = batch_nr_snts // self.accelerator.num_processes

        if self.kwargs.nr_snts_in_batch > 0:
            sub_batch_size = self.kwargs.nr_snts_in_batch
        else:
            sub_batch_size = max(1, self.kwargs.nr_words_in_batch // snt_nr_words)
            #log(f"DEBUG: #words/snt {snt_nr_words} X #snt in sub batch {sub_batch_size} = {snt_nr_words*sub_batch_size} ~ {self.kwargs.nr_words_in_batch}", accelerator=self.accelerator)

        nr_steps = -(proc_batch_nr_snts // -sub_batch_size)

        #log(f"--> DEBUG: sub_batch {sub_batch_size} X steps {nr_steps} ~ {proc_batch_nr_snts} ({batch_nr_snts} / {self.accelerator.num_processes})", accelerator=self.accelerator)
        return sub_batch_size, nr_steps, proc_batch_nr_snts

    def _main_loop(self):
        #countdown_till_do_it_once = 0

        if self.accelerator.is_main_process:
            logger = SameLineLogger(len(self.train_set), self.kwargs.epochs)
            logger.line_start()
        else:
            logger = None

        self.models[0].train()
        self.train_set.thats_where(self.data_state)

        for _epoch_idx in range(self.data_state.epoch_idx, self.kwargs.epochs):
            for batch_with_bin_idxs, epoch_batch_idx in self.train_set:
                batch, src_k, tgt_k = self._split_batch_and_bin_idxs(batch_with_bin_idxs)
                sub_batch_size, nr_steps, proc_batch_size = self._get_split_batch_params(batch)

                loss = None

                for sub_batch_idx in range(nr_steps):
                    inputs = self._prepare_inputs(batch, sub_batch_idx, sub_batch_size, proc_batch_size)

                    if self.is_generative:
                        inputs['labels'] = inputs['input_ids']
                        outputs = self.models[0](**inputs)
                    else:
                        encoder_vecs = encode(self.models[src_k], inputs)
                        outputs = self.models[tgt_k](attention_mask=inputs['attention_mask'], labels=inputs['labels'], encoder_outputs=encoder_vecs)

                    loss = outputs.loss

                    #if countdown_till_do_it_once > 0:
                    #    countdown_till_do_it_once -= 1
                    #elif countdown_till_do_it_once == 0:
                    if sub_batch_idx == 5:
                        batch_size = sum([inputs[k].size()[0] * inputs[k].size()[1] for k in 'input_ids labels attention_mask'.split(' ')])
                        report_devices(f"training memory usage (batch size: {batch_size}; inputs:" +
                                       f"snts {inputs['input_ids'].size()[0]} X words {inputs['input_ids'].size()[1]})",
                                       self.accelerator, self.models[0])
                        countdown_till_do_it_once = 0

                    self.train_loss_list.append(loss.item(), src_k, tgt_k)

                    self.accelerator.backward(loss)

                    for k in inputs:
                        inputs[k] = inputs[k].to('cpu')

                self._step_and_perhaps_save(logger, epoch_batch_idx, _epoch_idx, float(loss.item()))

        if self.accelerator.is_main_process:
            logger.line_break()

    def get_total_grad(self):
        result = 0
        grad_count = 0
        all_count = 0

        for p in self.models[0].parameters():
            if p.grad is not None:
                result += p.grad.abs().mean().item()
                grad_count += 1
            all_count += 1

        return result/grad_count if grad_count > 0 else -1

    def _step_and_perhaps_save(self, logger, epoch_batch_idx, epoch_i, loss):
        epoch_len = len(self.train_set)
        global_batch_idx = epoch_batch_idx + epoch_i * epoch_len

        self.optimizer.step()
        self.lr_scheduler.step()
        self.accelerator.wait_for_everyone()

        is_end_of_epoch = (epoch_batch_idx == epoch_len)

        if self.accelerator.is_main_process and (epoch_batch_idx % self.kwargs.log_steps == 0 or is_end_of_epoch):
            grad = self.get_total_grad()
            logger.step(global_batch_idx, epoch_batch_idx, epoch_i, loss, self.lr_scheduler.get_last_lr()[0], grad)

        self.optimizer.zero_grad()

        if (global_batch_idx % self.kwargs.save_steps == 0) or is_end_of_epoch:
            self.accelerator.wait_for_everyone()

            if self.accelerator.is_main_process:
                logger.line_break()
                log(f"Saving at {epoch_batch_idx} steps, epoch {epoch_i + 1} ({global_batch_idx} global steps)", accelerator=self.accelerator)

                self._save_all(global_batch_idx, epoch_i)

                logger.line_start()

    def _save_all(self, global_batch_idx, epoch_i):
        epoch_len = len(self.train_set)

        ckpt_name = (f"checkpoint-e{epoch_i + 1:02}-" +
                     (f"b{global_batch_idx:07}" if (global_batch_idx % epoch_len) else f"full"))

        this_location = os.path.join(self.kwargs.save_location, ckpt_name)
        if os.path.exists(this_location):
            raise FileExistsError(f"Cannot overwrite existing checkpoint {this_location}!")

        self.data_state.copy_from(self.train_set.where_are_we(), epoch_idx=epoch_i)

        model_to_save = self.accelerator.unwrap_model(self.models[0])

        save_all_models(this_location, model_to_save, self.coupling_specs[0].tokenizer,
                        self.coupling_specs, trainer=self.accelerator)
"""
kuidastaltsutadalaamat/legacy/benchmark.py
ADDED
@@ -0,0 +1,190 @@
#!/usr/bin/env python3

import sys
import os
import json

from collections import defaultdict
from data import split_by_lang, make_path_compatible, get_tr_pairs
from inference import coupled_translate, load_and_init_module_config, neurotolge_in_batches
from evaluate import load as load_metric
from legacy.langconv import get_mdl_type, get_joshi_class
from accelerate import Accelerator

from aux import log


def get_hyp_cache_dir(model_location, create=False):
    hyp_location = os.path.join(model_location, "hyp_cache")
    if create:
        os.makedirs(hyp_location, exist_ok=True)
    return hyp_location


def get_hyp_cache_filename(model_location, benchmark_corpus, src_lang, tgt_lang):
    hyp_location = get_hyp_cache_dir(model_location)

    corpus_base = os.path.basename(benchmark_corpus)
    basename = f"{corpus_base}-{src_lang}-to-{tgt_lang}"

    hyp_file = os.path.join(hyp_location, f"{basename}.hyp")
    src_file = os.path.join(hyp_location, f"{basename}.src")

    return hyp_file, src_file


def get_benchmark_filename(model_location, benchmark_corpus):
    corpus_base = os.path.basename(benchmark_corpus)
    hyp_file = f"{corpus_base}-scores.json"
    return os.path.join(model_location, hyp_file)


def load_hyps_from_file(filename):
    with open(filename, "r", encoding="utf-8") as f:
        return [line.strip() for line in f]


def save_hyps_to_file(hypos, filename):
    if hypos is not None:
        with open(filename, "w", encoding="utf-8") as f:
            for hyp in hypos:
                f.write(hyp + "\n")


def load_or_translate(mod_config, input_output_list, lp, model_location, benchmark_corpus):
    src_lang, tgt_lang = lp.split("-")

    inputs, _ = zip(*input_output_list)

    cache_filename, src_filename = get_hyp_cache_filename(model_location, benchmark_corpus, src_lang, tgt_lang)

    try:
        hypos = load_hyps_from_file(cache_filename)
    except FileNotFoundError:
        if model_location == "models/neurotolge":
            hypos = neurotolge_in_batches(inputs, src_lang, tgt_lang)
        else:
            hypos = coupled_translate(mod_config, inputs, src_lang, tgt_lang)

        if hypos is not None:
            save_hyps_to_file(hypos, cache_filename)
            save_hyps_to_file(inputs, src_filename)

    return zip(inputs, hypos)


def translate_all_hyps(lp_test_set_dict, module_conf, model_id, corpus_id, accelerator=None):
    if accelerator is not None:
        key_list = sorted(lp_test_set_dict.keys())
        for idx, lp in enumerate(key_list):
            if idx % accelerator.num_processes == accelerator.process_index:
                log(f"Process {accelerator.process_index} translating {lp}")
                load_or_translate(module_conf, lp_test_set_dict[lp], lp, model_id, corpus_id)
        accelerator.wait_for_everyone()
    else:
        result = dict()
        for i, lp in enumerate(lp_test_set_dict.keys()):
            log(f"Translating {lp}, {i + 1}/{len(lp_test_set_dict)}")
            result[lp] = load_or_translate(module_conf, lp_test_set_dict[lp], lp, model_id, corpus_id)
        return result


def get_joshi_lp(from_lang, to_lang):
    from_joshi = get_joshi_class(from_lang)
    to_joshi = get_joshi_class(to_lang)

    return f"{from_joshi}-{to_joshi}"


def get_all_scores(hyps_dict, lp_test_sets, metric_dict):
    scores = dict()
    avgs = defaultdict(list)

    for lp in lp_test_sets:
        from_lang, to_lang = lp.split("-")
        jlp = get_joshi_lp(from_lang, to_lang)

        _, outputs = zip(*lp_test_sets[lp])

        preds = None if hyps_dict[lp] is None else [hyp for _, hyp in hyps_dict[lp]]

        for metric_name in metric_dict:
            metric_func = metric_dict[metric_name]

            if preds is not None:
                metric_value = metric_func.compute(predictions=preds, references=outputs)

                scores[lp + "-" + metric_name] = metric_value['score']

                avgs[jlp + "-" + metric_name].append(metric_value['score'])

    for avg_k in avgs:
        scores[avg_k] = sum(avgs[avg_k]) / len(avgs[avg_k])

    return scores


def save_scores(scores, mdl_id, corpus):
    filename = get_benchmark_filename(mdl_id, corpus)
    with open(filename, "w") as ofh:
        json.dump(scores, ofh, indent=2, sort_keys=True)


def benchmark_neurotolge(corpus):
    log("Loading data")
    lp_test_sets = split_by_lang(filename=corpus, model_type=None)

    log("Starting benchmarking")
    _ = get_hyp_cache_dir("models/neurotolge", create=True)

    hyps_dict = translate_all_hyps(lp_test_sets, None, "models/neurotolge", corpus)

    log("Loading metrics")
    exp_id = "neurotõlge---" + make_path_compatible(corpus)
    metric_dict = {
        'bleu': load_metric("sacrebleu", experiment_id=exp_id),
        'chrf': load_metric("chrf", experiment_id=exp_id) }
|
147 |
+
|
148 |
+
scores = get_all_scores(hyps_dict, lp_test_sets, metric_dict)
|
149 |
+
|
150 |
+
save_scores(scores, "models/neurotolge", corpus)
|
151 |
+
|
152 |
+
|
153 |
+
def benchmark_local_model(mdl_id, corpus):
|
154 |
+
accelerator = Accelerator()
|
155 |
+
|
156 |
+
main_model, module_config = load_and_init_module_config(mdl_id, accelerator)
|
157 |
+
|
158 |
+
log("Loading data", accelerator=accelerator)
|
159 |
+
lp_test_sets = split_by_lang(filename=corpus, model_type=get_mdl_type(main_model))
|
160 |
+
|
161 |
+
log("Loading metrics", accelerator=accelerator)
|
162 |
+
exp_id = make_path_compatible(mdl_id) + "---" + make_path_compatible(corpus)
|
163 |
+
|
164 |
+
metric_dict = {
|
165 |
+
'bleu': load_metric("sacrebleu", experiment_id=exp_id),
|
166 |
+
'chrf': load_metric("chrf", experiment_id=exp_id) }
|
167 |
+
|
168 |
+
log("Starting benchmarking", accelerator=accelerator)
|
169 |
+
|
170 |
+
if accelerator.is_main_process:
|
171 |
+
_ = get_hyp_cache_dir(mdl_id, create=True)
|
172 |
+
|
173 |
+
translate_all_hyps(lp_test_sets, module_config, mdl_id, corpus, accelerator)
|
174 |
+
|
175 |
+
if accelerator.is_main_process:
|
176 |
+
fin_hyps_dict = translate_all_hyps(lp_test_sets, module_config, mdl_id, corpus)
|
177 |
+
|
178 |
+
scores = get_all_scores(fin_hyps_dict, lp_test_sets, metric_dict)
|
179 |
+
|
180 |
+
save_scores(scores, mdl_id, corpus)
|
181 |
+
|
182 |
+
|
183 |
+
if __name__ == '__main__':
|
184 |
+
mdl_id_param = sys.argv[1]
|
185 |
+
corpus_param = sys.argv[2]
|
186 |
+
|
187 |
+
if mdl_id_param == "neurotolge":
|
188 |
+
benchmark_neurotolge(corpus_param)
|
189 |
+
else:
|
190 |
+
benchmark_local_model(mdl_id_param, corpus_param)
|
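A minimal usage sketch for the entry point above, assuming placeholder model and corpus paths (not part of the committed diff); the scores file name follows get_benchmark_filename:

# python kuidastaltsutadalaamat/legacy/benchmark.py models/my-mt-model data/benchmark.json
# or, programmatically:
from legacy.benchmark import benchmark_local_model, benchmark_neurotolge

benchmark_local_model("models/my-mt-model", "data/benchmark.json")  # scores written to models/my-mt-model/benchmark.json-scores.json
benchmark_neurotolge("data/benchmark.json")                         # routes translations through the external neurotõlge service instead of a local model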
kuidastaltsutadalaamat/legacy/data.py
ADDED
@@ -0,0 +1,164 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import json
|
4 |
+
import sys
|
5 |
+
|
6 |
+
from torch.utils.data import IterableDataset
|
7 |
+
from random import shuffle, randint
|
8 |
+
from legacy.tokops import tokenize_batch
|
9 |
+
from aux import log
|
10 |
+
|
11 |
+
def prep_llm_input(ljmftpl):
|
12 |
+
raise NotImplementedError
|
13 |
+
|
14 |
+
#{'task': 'translate' / 'approx-translate' / 'generate',
|
15 |
+
# 'src_segm': src_segm,
|
16 |
+
# 'tgt_segm': tgt_segm,
|
17 |
+
# 'src_lang': src_lang,
|
18 |
+
# 'tgt_lang': tgt_lang}
|
19 |
+
|
20 |
+
result = f"{ljmftpl['src_segm']}\n=====\nis in {ljmftpl['src_lang']}"
|
21 |
+
|
22 |
+
if ljmftpl['task'] in {'translate', 'approx-translate'}:
|
23 |
+
result += f"; {ljmftpl['task']} to {ljmftpl['tgt_lang']}:\n{ljmftpl['tgt_segm']}"
|
24 |
+
|
25 |
+
return result
|
26 |
+
|
27 |
+
def make_path_compatible(filename):
|
28 |
+
return filename.replace("/", "_").replace(":", "-")
|
29 |
+
|
30 |
+
def do_list_in_batches(data, batch_size):
|
31 |
+
i = 0
|
32 |
+
|
33 |
+
while i < len(data):
|
34 |
+
yield data[i:i + batch_size]
|
35 |
+
i += batch_size
|
36 |
+
|
37 |
+
|
38 |
+
class DataState:
|
39 |
+
def __init__(self, elem_idx = 0, shard_idx = 0, epoch_idx = None):
|
40 |
+
self.elem_idx = elem_idx
|
41 |
+
self.shard_idx = shard_idx
|
42 |
+
self.epoch_idx = epoch_idx
|
43 |
+
|
44 |
+
def state_dict(self):
|
45 |
+
return {'elem_idx': self.elem_idx, 'shard_idx': self.shard_idx, 'epoch_idx': self.epoch_idx}
|
46 |
+
|
47 |
+
def load_state_dict(self, state_dict):
|
48 |
+
self.elem_idx = state_dict['elem_idx']
|
49 |
+
self.shard_idx = state_dict['shard_idx']
|
50 |
+
self.epoch_idx = state_dict['epoch_idx']
|
51 |
+
|
52 |
+
def copy_from(self, src_ds, epoch_idx = None):
|
53 |
+
self.shard_idx = src_ds.shard_idx
|
54 |
+
self.elem_idx = src_ds.elem_idx
|
55 |
+
|
56 |
+
if epoch_idx is not None:
|
57 |
+
self.epoch_idx = epoch_idx
|
58 |
+
|
59 |
+
def __str__(self):
|
60 |
+
return 'DataState(elem_idx={}, shard_idx={}, epoch_idx={})'.format(self.elem_idx, self.shard_idx, self.epoch_idx)
|
61 |
+
|
62 |
+
def __repr__(self):
|
63 |
+
return self.__str__()
|
64 |
+
|
65 |
+
|
66 |
+
class BatchingIterator(IterableDataset):
|
67 |
+
def __init__(self, batched_data, batch_size, tokenizer, max_len=8000):
|
68 |
+
assert len(batched_data[0]) == batch_size, "loaded data batch size and specified batch size differ"
|
69 |
+
|
70 |
+
self.batched_data = batched_data
|
71 |
+
|
72 |
+
self.tokenizer = tokenizer
|
73 |
+
self.max_len = max_len
|
74 |
+
|
75 |
+
self.curr_elem_idx = 0
|
76 |
+
|
77 |
+
self.data_len = len(self.batched_data)
|
78 |
+
|
79 |
+
def __len__(self):
|
80 |
+
return self.data_len
|
81 |
+
|
82 |
+
def __iter__(self):
|
83 |
+
#self.curr_elem_idx = 0
|
84 |
+
return self
|
85 |
+
|
86 |
+
def where_are_we(self):
|
87 |
+
return DataState(shard_idx=0, elem_idx=self.curr_elem_idx)
|
88 |
+
|
89 |
+
def thats_where(self, data_state):
|
90 |
+
self.curr_elem_idx = data_state.elem_idx
|
91 |
+
|
92 |
+
def _tokenize(self, prepped_segm_list):
|
93 |
+
#self.tokenizer.pad_token = '<|reserved_special_token_0|>'
|
94 |
+
#tokenized_batch = self.tokenizer(prepped_segm_list, return_tensors="pt", max_length=self.max_len,
|
95 |
+
# truncation=True, add_special_tokens=True,
|
96 |
+
# padding=True)
|
97 |
+
tokenized_batch = tokenize_batch(self.tokenizer, prepped_segm_list, maxlen=self.max_len)
|
98 |
+
return tokenized_batch, self.curr_elem_idx + 1
|
99 |
+
|
100 |
+
def __next__(self):
|
101 |
+
if self.curr_elem_idx >= self.data_len:
|
102 |
+
self.curr_elem_idx = 0
|
103 |
+
raise StopIteration
|
104 |
+
else:
|
105 |
+
batch = self._tokenize(self.batched_data[self.curr_elem_idx])
|
106 |
+
self.curr_elem_idx += 1
|
107 |
+
return batch
|
108 |
+
|
109 |
+
|
110 |
+
def shuffle_data():
|
111 |
+
# open a list of tuples, save a list of batches of strings made of these tuples
|
112 |
+
input_file = sys.argv[1]
|
113 |
+
output_file = sys.argv[2]
|
114 |
+
|
115 |
+
try:
|
116 |
+
batch_size = int(sys.argv[3])
|
117 |
+
except IndexError:
|
118 |
+
batch_size = None
|
119 |
+
|
120 |
+
log("Reading data")
|
121 |
+
# read the tuples
|
122 |
+
with open(input_file, "r") as f:
|
123 |
+
#raw_data = json.load(f)
|
124 |
+
raw_data = json.load(f)
|
125 |
+
|
126 |
+
log("Making strings")
|
127 |
+
# make strings out of tuples
|
128 |
+
unsorted_data_in_elems = [prep_llm_input(s) for s in raw_data]
|
129 |
+
|
130 |
+
if batch_size is None:
|
131 |
+
final_data = unsorted_data_in_elems
|
132 |
+
else:
|
133 |
+
# if last batch is undersized, get some random elements to compensate
|
134 |
+
while len(unsorted_data_in_elems) % batch_size != 0:
|
135 |
+
new_elem_idx = randint(0, len(unsorted_data_in_elems) - 1)
|
136 |
+
unsorted_data_in_elems.append(unsorted_data_in_elems[new_elem_idx])
|
137 |
+
|
138 |
+
log("Sorting and grouping")
|
139 |
+
# sort by length
|
140 |
+
sorted_data_in_elems = sorted(unsorted_data_in_elems, key=lambda x: len(x), reverse=True)
|
141 |
+
|
142 |
+
# group into batches
|
143 |
+
final_data = list(do_list_in_batches(sorted_data_in_elems, batch_size))
|
144 |
+
|
145 |
+
log("Shuffling")
|
146 |
+
# shuffle the batches / sentences
|
147 |
+
shuffle(final_data)
|
148 |
+
|
149 |
+
log("Saving")
|
150 |
+
# save the result
|
151 |
+
with open(output_file, "w") as f:
|
152 |
+
json.dump(final_data, f)
|
153 |
+
|
154 |
+
if __name__ == '__main__':
|
155 |
+
all_data = []
|
156 |
+
|
157 |
+
for input_file in sys.argv[1:]:
|
158 |
+
with open(input_file, "r") as f:
|
159 |
+
this_data = json.load(f)
|
160 |
+
all_data += this_data
|
161 |
+
|
162 |
+
shuffle(all_data)
|
163 |
+
|
164 |
+
json.dump(all_data, sys.stdout)
|
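A minimal sketch of how BatchingIterator and DataState above fit together, assuming a placeholder tokenizer id and a placeholder JSON file containing equally sized batches of strings:

# Iterate over pre-batched string data and record the position with DataState.
import json
from transformers import AutoTokenizer
from legacy.data import BatchingIterator, DataState

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")  # placeholder model id
with open("train_batches.json") as f:                           # placeholder: list of equal-size string batches
    batched_data = json.load(f)

dataset = BatchingIterator(batched_data, batch_size=len(batched_data[0]), tokenizer=tok)
resume_state = DataState()

for tokenized_batch, next_elem_idx in dataset:
    # a training step would go here
    resume_state.copy_from(dataset.where_are_we())  # position can be saved/restored via state_dict()/load_state_dict()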
kuidastaltsutadalaamat/legacy/data_backup.py
ADDED
@@ -0,0 +1,804 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import json
|
4 |
+
#import os
|
5 |
+
import sys
|
6 |
+
import torch
|
7 |
+
#import re
|
8 |
+
import math
|
9 |
+
|
10 |
+
from torch.utils.data import IterableDataset
|
11 |
+
from collections import namedtuple, defaultdict
|
12 |
+
from random import randrange, shuffle, randint
|
13 |
+
#from pathlib import Path
|
14 |
+
|
15 |
+
#from aux import log
|
16 |
+
#from langconv import any_to_madlad, any_to_nllb, is_nllb, is_madlad, get_mdl_type, any_to_mdl_type, is_dec_only_llm, \
|
17 |
+
# base_to_nllb
|
18 |
+
#from tokops import tokenizeit
|
19 |
+
|
20 |
+
# TrPair = namedtuple('TrPair', ["src_lang", "tgt_lang", "input", "output"])
|
21 |
+
|
22 |
+
"""
|
23 |
+
def prep_llm_input(ljmftpl):
|
24 |
+
#{'task': 'translate' / 'approx-translate' / 'generate',
|
25 |
+
# 'src_segm': src_segm,
|
26 |
+
# 'tgt_segm': tgt_segm,
|
27 |
+
# 'src_lang': src_lang,
|
28 |
+
# 'tgt_lang': tgt_lang}
|
29 |
+
|
30 |
+
# it's a tuple
|
31 |
+
if "src_segm" in ljmftpl and "task" in ljmftpl:
|
32 |
+
if ljmftpl['task'] in {'translate', 'approx-translate'}:
|
33 |
+
return (f"{ljmftpl['src_segm']}\n=====\n{ljmftpl['task']} from {ljmftpl['src_lang']}; " +
|
34 |
+
f"to {ljmftpl['tgt_lang']}:\n{ljmftpl['tgt_segm']}")
|
35 |
+
|
36 |
+
elif ljmftpl['task'] == 'generate':
|
37 |
+
return f"{ljmftpl['src_segm']}\n=====\nis in {ljmftpl['src_lang']};"
|
38 |
+
|
39 |
+
# it's a string
|
40 |
+
else:
|
41 |
+
return ljmftpl
|
42 |
+
|
43 |
+
|
44 |
+
def make_path_compatible(filename):
|
45 |
+
return filename.replace("/", "_").replace(":", "-")
|
46 |
+
|
47 |
+
def do_list_in_batches(data, batch_size):
|
48 |
+
i = 0
|
49 |
+
|
50 |
+
while i < len(data):
|
51 |
+
yield data[i:i + batch_size]
|
52 |
+
i += batch_size
|
53 |
+
"""
|
54 |
+
"""
|
55 |
+
def do_bins_in_batches(bins, batch_size, sort_by_length):
|
56 |
+
result_list = []
|
57 |
+
|
58 |
+
for src_k in bins:
|
59 |
+
for tgt_k in bins[src_k]:
|
60 |
+
if src_k == 0 or tgt_k == 0:
|
61 |
+
result_list += [(e, src_k, tgt_k) for e in do_list_in_batches(bins[src_k][tgt_k], batch_size)]
|
62 |
+
|
63 |
+
shuffle(result_list)
|
64 |
+
|
65 |
+
return result_list
|
66 |
+
|
67 |
+
|
68 |
+
def _post_proc(text, lang):
|
69 |
+
if lang == 'liv' and "’" in text and "O’R" not in text:
|
70 |
+
return text.replace("’", "")
|
71 |
+
else:
|
72 |
+
return text
|
73 |
+
|
74 |
+
|
75 |
+
def clean_entry(entry, leave_out):
|
76 |
+
result = {k: _post_proc(entry[k], k) for k in entry if entry[k].strip() and k not in leave_out}
|
77 |
+
return result
|
78 |
+
|
79 |
+
|
80 |
+
def load_json_data(path, leave_out={}, skip_cats=True, load_mono=True):
|
81 |
+
with open(path, 'r') as f:
|
82 |
+
data = json.load(f)
|
83 |
+
|
84 |
+
if skip_cats:
|
85 |
+
# skip categories
|
86 |
+
resx = [clean_entry(entry, leave_out)
|
87 |
+
for cat in data for entry in cat['sentences']]
|
88 |
+
res = [e for e in resx if e]
|
89 |
+
else:
|
90 |
+
raise NotImplementedError
|
91 |
+
|
92 |
+
# resx = {cat['source']: [clean_entry(entry, leave_out) for entry in cat['sentences']] for cat in data}
|
93 |
+
# res = {k: resx[k] for k in resx if resx[k]}
|
94 |
+
|
95 |
+
return res
|
96 |
+
|
97 |
+
|
98 |
+
def get_tr_pairs(raw_data=None, filename=None, leave_out=None, leave_only=None, model_type=None, exclude_set=None):
|
99 |
+
if filename is not None:
|
100 |
+
raw_data = load_json_data(filename)
|
101 |
+
|
102 |
+
if raw_data is None:
|
103 |
+
raise ValueError("Neither file nor data are provided")
|
104 |
+
|
105 |
+
i = 0
|
106 |
+
log("Loading data")
|
107 |
+
for tup in raw_data:
|
108 |
+
for l1 in tup:
|
109 |
+
for l2 in tup:
|
110 |
+
if l1 != l2 and not "dia" in l1 and not "dia" in l2:
|
111 |
+
if leave_out is None or f"{l1}-{l2}" not in leave_out:
|
112 |
+
if leave_only is None or f"{l1}-{l2}" in leave_only:
|
113 |
+
i += 1
|
114 |
+
if not i % 1000000:
|
115 |
+
log(f"Loaded {i/1000000}M pairs")
|
116 |
+
dia_key = f"{l2}-dia"
|
117 |
+
|
118 |
+
if exclude_set is None or (tup[l1] not in exclude_set[l1] and tup[l2] not in exclude_set[l2]):
|
119 |
+
input = tup[l1]
|
120 |
+
if dia_key in tup:
|
121 |
+
input = f"<{tup[dia_key]}> {input}"
|
122 |
+
|
123 |
+
conv_l1 = any_to_mdl_type(model_type, l1)
|
124 |
+
conv_l2 = any_to_mdl_type(model_type, l2)
|
125 |
+
|
126 |
+
if not snt_is_fishy(input, conv_l1) and not snt_is_fishy(tup[l2], conv_l2):
|
127 |
+
yield TrPair(conv_l1, conv_l2, input, tup[l2])
|
128 |
+
|
129 |
+
def split_by_lang(filename, model_type):
|
130 |
+
result = defaultdict(list)
|
131 |
+
|
132 |
+
# if filename is not None:
|
133 |
+
# tr_pairs = load_json_datax(filename)
|
134 |
+
|
135 |
+
tr_pairs = get_tr_pairs(filename=filename, model_type=model_type)
|
136 |
+
|
137 |
+
for tup in tr_pairs:
|
138 |
+
#for l1 in tup:
|
139 |
+
# for l2 in tup:
|
140 |
+
# if l1 != l2 and not "dia" in l1 and not "dia" in l2:
|
141 |
+
l1 = tup.src_lang
|
142 |
+
l2 = tup.tgt_lang
|
143 |
+
lp = f"{l1}-{l2}"
|
144 |
+
result[lp].append((tup.input, tup.output))
|
145 |
+
|
146 |
+
return result
|
147 |
+
|
148 |
+
|
149 |
+
def data_iter_for_tok_train(raw_data, langs_to_include):
|
150 |
+
for tup in raw_data:
|
151 |
+
for lang in tup:
|
152 |
+
if lang in langs_to_include:
|
153 |
+
yield tup[lang]
|
154 |
+
|
155 |
+
|
156 |
+
def lang_bin_mapping(coupling_specs):
|
157 |
+
lang_to_idx = dict()
|
158 |
+
|
159 |
+
for i, spec_pair in enumerate(coupling_specs):
|
160 |
+
for lang in spec_pair.lang_set:
|
161 |
+
if lang not in lang_to_idx:
|
162 |
+
lang_to_idx[lang] = {i}
|
163 |
+
else:
|
164 |
+
lang_to_idx[lang].add(i)
|
165 |
+
|
166 |
+
return lang_to_idx
|
167 |
+
|
168 |
+
|
169 |
+
def mix_and_sample_idxs_carefully(src_idxs, tgt_idxs):
|
170 |
+
idx_pairs = [(s, t) for s in src_idxs for t in tgt_idxs if not (s == 1 and t == 1)]
|
171 |
+
|
172 |
+
if len(idx_pairs) == 0:
|
173 |
+
result = (None, None)
|
174 |
+
else:
|
175 |
+
pair_idx = randrange(len(idx_pairs))
|
176 |
+
result = idx_pairs[pair_idx]
|
177 |
+
|
178 |
+
# debug(f"src lang: {tr_pair.src_lang}, tgt_lang: {tr_pair.tgt_lang}, idx list: {idx_pairs}, result: {result}")
|
179 |
+
|
180 |
+
return result
|
181 |
+
|
182 |
+
|
183 |
+
def inject_bin_indices(batch, src_k, tgt_k):
|
184 |
+
batch['input_ids'][0,0] += src_k << 30
|
185 |
+
|
186 |
+
batch['labels'][0,0] += tgt_k << 30
|
187 |
+
|
188 |
+
def get_data_cache_location(cache_meta_path, idx):
|
189 |
+
cache_folder, cache_file = os.path.split(cache_meta_path)
|
190 |
+
|
191 |
+
if cache_folder:
|
192 |
+
Path(cache_folder).mkdir(parents=True, exist_ok=True)
|
193 |
+
|
194 |
+
if cache_meta_path.endswith(".json"):
|
195 |
+
return cache_meta_path[:-5] + f"_{idx:04}.pt"
|
196 |
+
else:
|
197 |
+
raise ValueError(f"Expected a json file for the cache meta-location ({cache_meta_path})")
|
198 |
+
|
199 |
+
|
200 |
+
def make_gen_text(src_lang, tgt_lang, input_text, output_text=None, tok=None):
|
201 |
+
if input_text.startswith("<"):
|
202 |
+
posit = input_text.find(">") + 1
|
203 |
+
dialect = input_text[1:posit-1]
|
204 |
+
diatxt = f", variety: {dialect}"
|
205 |
+
txt = input_text[posit+1:]
|
206 |
+
else:
|
207 |
+
dialect = None
|
208 |
+
diatxt = ""
|
209 |
+
txt = input_text
|
210 |
+
|
211 |
+
return (f"Translate:\n== From: {src_lang}\n== To: {tgt_lang}{diatxt}\n== Input: {txt}\n== Output: " +
|
212 |
+
("" if (output_text is None or tok is None) else f"{output_text}{tok.eos_token}"))
|
213 |
+
|
214 |
+
|
215 |
+
class MultilingualBatchingCachingDataset:
|
216 |
+
def _post_proc_bins(self, bins):
|
217 |
+
for src_k in bins:
|
218 |
+
for tgt_k in bins[src_k]:
|
219 |
+
while len(bins[src_k][tgt_k]) % self.args.batch_size != 0:
|
220 |
+
rnd_elem_idx = randrange(len(bins[src_k][tgt_k]))
|
221 |
+
rnd_elem = bins[src_k][tgt_k][rnd_elem_idx]
|
222 |
+
bins[src_k][tgt_k].append(rnd_elem)
|
223 |
+
|
224 |
+
if self.args.sort_by_len:
|
225 |
+
bins[src_k][tgt_k] = sorted(bins[src_k][tgt_k], key=lambda e: len(e.input))
|
226 |
+
else:
|
227 |
+
shuffle(bins[src_k][tgt_k])
|
228 |
+
return bins
|
229 |
+
|
230 |
+
def _get_idxs(self, tr_pair):
|
231 |
+
src_idxs = self._lang_to_idx[tr_pair.src_lang]
|
232 |
+
tgt_idxs = self._lang_to_idx[tr_pair.tgt_lang]
|
233 |
+
|
234 |
+
return mix_and_sample_idxs_carefully(src_idxs, tgt_idxs)
|
235 |
+
|
236 |
+
def _fill_bins(self):
|
237 |
+
bins = defaultdict(lambda: defaultdict(list))
|
238 |
+
|
239 |
+
for tr_pair in get_tr_pairs(filename=self.filename, model_type=self.model_type, exclude_set=self.exclude_set):
|
240 |
+
src_bin_idx, tgt_bin_idx = self._get_idxs(tr_pair)
|
241 |
+
|
242 |
+
if src_bin_idx is not None and tgt_bin_idx is not None:
|
243 |
+
bins[src_bin_idx][tgt_bin_idx].append(tr_pair)
|
244 |
+
|
245 |
+
return self._post_proc_bins(bins)
|
246 |
+
|
247 |
+
def report_update_stats(self, bins):
|
248 |
+
total = 0
|
249 |
+
totalx = 0
|
250 |
+
updates = 0
|
251 |
+
duds = 0
|
252 |
+
|
253 |
+
enc_count = 0
|
254 |
+
dec_count = 0
|
255 |
+
|
256 |
+
for src_k in bins:
|
257 |
+
for tgt_k in bins[src_k]:
|
258 |
+
l = len(bins[src_k][tgt_k])
|
259 |
+
|
260 |
+
total += l
|
261 |
+
if src_k == 0 or tgt_k == 0:
|
262 |
+
totalx += l
|
263 |
+
updates += l * (1 - (src_k + tgt_k) / 2)
|
264 |
+
|
265 |
+
enc_count += l * (1 - src_k)
|
266 |
+
dec_count += l * (1 - tgt_k)
|
267 |
+
|
268 |
+
if src_k == 1 and tgt_k == 1:
|
269 |
+
duds += 1
|
270 |
+
# log(str(self._lang_to_idx))
|
271 |
+
|
272 |
+
log(f"### Ratio of coupled model updates: {100 * updates / total:.2f}% ({100 * updates / totalx:.2f}%); " + \
|
273 |
+
f"frozen meaningless updates: {100 * duds / total:.2f}%; " + \
|
274 |
+
f"enc samples: {enc_count}, dec samples: {dec_count}")
|
275 |
+
|
276 |
+
def tokenize_input(self, cplspec, input_list, rawbatch):
|
277 |
+
src_tokenizer = cplspec.tokenizer
|
278 |
+
src_tokenizer.src_lang = rawbatch[0].src_lang
|
279 |
+
#prep_batch_grouped = src_tokenizer(text=input_list, return_tensors="pt",
|
280 |
+
# padding="longest", truncation=True, max_length=self.args.max_snt_len)
|
281 |
+
prep_batch_grouped = tokenizeit((src_tokenizer, cplspec.postokenizer), input_list, self.args.max_snt_len, False)
|
282 |
+
|
283 |
+
if is_nllb(src_tokenizer):
|
284 |
+
src_lang_list = [any_to_nllb(e.src_lang) for e in rawbatch]
|
285 |
+
src_lang_vec = src_tokenizer.convert_tokens_to_ids(src_lang_list)
|
286 |
+
prep_batch_grouped['input_ids'][:,0] = torch.tensor(src_lang_vec)
|
287 |
+
|
288 |
+
return prep_batch_grouped
|
289 |
+
|
290 |
+
def tokenize_output(self, tgttokenizer, tgtposttok, rawbatch):
|
291 |
+
outputs = [e.output for e in rawbatch]
|
292 |
+
tgttokenizer.tgt_lang = rawbatch[0].tgt_lang
|
293 |
+
#labels = tgttokenizer(text_target=outputs, return_tensors="pt",
|
294 |
+
# padding="longest", truncation=True, max_length=self.args.max_snt_len)
|
295 |
+
labels = tokenizeit((tgttokenizer, tgtposttok), outputs, self.args.max_snt_len, True)
|
296 |
+
|
297 |
+
if is_nllb(tgttokenizer):
|
298 |
+
tgt_lang_list = [any_to_nllb(e.tgt_lang) for e in rawbatch]
|
299 |
+
tgt_lang_vec = tgttokenizer.convert_tokens_to_ids(tgt_lang_list)
|
300 |
+
labels['input_ids'][:, 0] = torch.tensor(tgt_lang_vec)
|
301 |
+
|
302 |
+
return labels
|
303 |
+
|
304 |
+
def tokenize_gen_batch(self, raw_batch):
|
305 |
+
tokenizer = self.coupling_specs[0].tokenizer
|
306 |
+
tokenizer.pad_token = '<|reserved_special_token_0|>'
|
307 |
+
tokenizer.padding_side = 'left'
|
308 |
+
|
309 |
+
texts = [make_gen_text(e.src_lang, e.tgt_lang, e.input, e.output, tokenizer) for e in raw_batch]
|
310 |
+
|
311 |
+
#batch = tokenizer(texts, return_tensors="pt", max_length=512, truncation=True, add_special_tokens=True, padding=True)
|
312 |
+
batch = tokenizeit((tokenizer, self.coupling_specs[0].postokenizer), texts, self.args.max_snt_len, False)
|
313 |
+
|
314 |
+
return batch
|
315 |
+
|
316 |
+
def tokenize_and_pad(self, raw_batch, src_k, tgt_k):
|
317 |
+
tgt_tokenizer = self.coupling_specs[tgt_k].tokenizer
|
318 |
+
tgt_postok = self.coupling_specs[tgt_k].postokenizer
|
319 |
+
|
320 |
+
if is_madlad(tgt_tokenizer):
|
321 |
+
inputs = [f"{any_to_madlad(e.tgt_lang)} {e.input}" for e in raw_batch]
|
322 |
+
else:
|
323 |
+
inputs = [e.input for e in raw_batch]
|
324 |
+
|
325 |
+
prep_batch_grouped = self.tokenize_input(self.coupling_specs[src_k], inputs, raw_batch)
|
326 |
+
labels = self.tokenize_output(tgt_tokenizer, tgt_postok, raw_batch)
|
327 |
+
prep_batch_grouped['labels'] = labels['input_ids']
|
328 |
+
|
329 |
+
# inject_bin_indices(prep_batch_grouped, src_k, tgt_k)
|
330 |
+
|
331 |
+
#split_prep_batch = [{k: prep_batch_grouped[k][i] for k in prep_batch_grouped}
|
332 |
+
# for i, trp in enumerate(raw_batch)]
|
333 |
+
|
334 |
+
return prep_batch_grouped
|
335 |
+
|
336 |
+
def _bins_to_tokenized_batched_cached_data(self, bins, cache_path):
|
337 |
+
shard_i = 0
|
338 |
+
batch_i = 0
|
339 |
+
total_i = 0
|
340 |
+
|
341 |
+
metainfo = []
|
342 |
+
data = []
|
343 |
+
|
344 |
+
log("Tokenizing data")
|
345 |
+
|
346 |
+
for raw_batch, src_k, tgt_k in do_bins_in_batches(bins, self.args.batch_size, self.args.sort_by_len):
|
347 |
+
batch_i += 1
|
348 |
+
if not batch_i % 10000:
|
349 |
+
log(f"Tokenized {batch_i + shard_i * self.args.shard_size} batches (shard {shard_i})")
|
350 |
+
|
351 |
+
if is_dec_only_llm(self.coupling_specs[tgt_k].tokenizer):
|
352 |
+
prepared_batch = self.tokenize_gen_batch(raw_batch)
|
353 |
+
data.append((prepared_batch, total_i))
|
354 |
+
else:
|
355 |
+
prepared_batch = self.tokenize_and_pad(raw_batch, src_k, tgt_k)
|
356 |
+
data.append((prepared_batch, src_k, tgt_k, total_i))
|
357 |
+
|
358 |
+
if batch_i >= self.args.shard_size:
|
359 |
+
shard_i += 1
|
360 |
+
batch_i = 0
|
361 |
+
fn = self._save_cache_file(data, cache_path, shard_i)
|
362 |
+
metainfo.append({'shard_filename': fn, 'shard_size': len(data)})
|
363 |
+
|
364 |
+
del data
|
365 |
+
|
366 |
+
data = []
|
367 |
+
|
368 |
+
total_i += 1
|
369 |
+
|
370 |
+
if len(data) > 0:
|
371 |
+
fn = self._save_cache_file(data, cache_path, shard_i + 1)
|
372 |
+
metainfo.append({'shard_filename': fn, 'shard_size': len(data)})
|
373 |
+
|
374 |
+
with open(cache_path, 'w') as f:
|
375 |
+
json.dump(metainfo, f)
|
376 |
+
|
377 |
+
del data
|
378 |
+
|
379 |
+
@staticmethod
|
380 |
+
def _save_cache_file(data, cache_location, idx):
|
381 |
+
cache_location = get_data_cache_location(cache_location, idx)
|
382 |
+
|
383 |
+
if os.path.exists(cache_location):
|
384 |
+
raise Exception("Cache already exists")
|
385 |
+
|
386 |
+
torch.save(data, cache_location)
|
387 |
+
log(f"Saved data into cache (shard {idx})")
|
388 |
+
|
389 |
+
return cache_location
|
390 |
+
|
391 |
+
def set_model_type(self):
|
392 |
+
result = None
|
393 |
+
|
394 |
+
for spec_tuple in self.coupling_specs:
|
395 |
+
this_type = get_mdl_type(spec_tuple.tokenizer)
|
396 |
+
if result is None:
|
397 |
+
result = this_type
|
398 |
+
else:
|
399 |
+
assert result == this_type, "in this implementation model types (NLLB/MADLAD/...) must be the same for all included models"
|
400 |
+
|
401 |
+
return result
|
402 |
+
|
403 |
+
|
404 |
+
def __init__(self, tr_file, coupling_specs, args):
|
405 |
+
self.args = args
|
406 |
+
self.filename = tr_file
|
407 |
+
self.coupling_specs = coupling_specs
|
408 |
+
|
409 |
+
self.exclude_set = _dev_to_dict(args.exclude_set) if args.exclude_set is not None else None
|
410 |
+
|
411 |
+
self.model_type = self.set_model_type()
|
412 |
+
|
413 |
+
# init lang to idx
|
414 |
+
self._lang_to_idx = lang_bin_mapping(coupling_specs)
|
415 |
+
|
416 |
+
def load_and_cache_data(self, cache_path):
|
417 |
+
# collect data into bins and cache it
|
418 |
+
bins = self._fill_bins()
|
419 |
+
|
420 |
+
self.report_update_stats(bins)
|
421 |
+
|
422 |
+
self._bins_to_tokenized_batched_cached_data(bins, cache_path)
|
423 |
+
"""
|
424 |
+
|
425 |
+
"""
|
426 |
+
class DataState:
|
427 |
+
def __init__(self, elem_idx = 0, shard_idx = 0, epoch_idx = None):
|
428 |
+
self.elem_idx = elem_idx
|
429 |
+
self.shard_idx = shard_idx
|
430 |
+
self.epoch_idx = epoch_idx
|
431 |
+
|
432 |
+
def state_dict(self):
|
433 |
+
return {'elem_idx': self.elem_idx, 'shard_idx': self.shard_idx, 'epoch_idx': self.epoch_idx}
|
434 |
+
|
435 |
+
def load_state_dict(self, state_dict):
|
436 |
+
self.elem_idx = state_dict['elem_idx']
|
437 |
+
self.shard_idx = state_dict['shard_idx']
|
438 |
+
self.epoch_idx = state_dict['epoch_idx']
|
439 |
+
|
440 |
+
def copy_from(self, src_ds, epoch_idx = None):
|
441 |
+
self.shard_idx = src_ds.shard_idx
|
442 |
+
self.elem_idx = src_ds.elem_idx
|
443 |
+
|
444 |
+
if epoch_idx is not None:
|
445 |
+
self.epoch_idx = epoch_idx
|
446 |
+
|
447 |
+
def __str__(self):
|
448 |
+
return 'DataState(elem_idx={}, shard_idx={}, epoch_idx={})'.format(self.elem_idx, self.shard_idx, self.epoch_idx)
|
449 |
+
|
450 |
+
def __repr__(self):
|
451 |
+
return self.__str__()
|
452 |
+
|
453 |
+
|
454 |
+
class BatchingIterator(IterableDataset):
|
455 |
+
def __init__(self, segment_list, batch_size, tokenizer, max_len=8000):
|
456 |
+
self.data = segment_list
|
457 |
+
shuffle(self.data)
|
458 |
+
|
459 |
+
self.batch_size = batch_size
|
460 |
+
self.tokenizer = tokenizer
|
461 |
+
self.max_len = max_len
|
462 |
+
|
463 |
+
self.curr_elem_idx = 0
|
464 |
+
|
465 |
+
self.data_len = math.ceil(len(self.data) / self.batch_size)
|
466 |
+
|
467 |
+
def __len__(self):
|
468 |
+
return self.data_len
|
469 |
+
|
470 |
+
def __iter__(self):
|
471 |
+
self.curr_elem_idx = 0
|
472 |
+
return self
|
473 |
+
|
474 |
+
def where_are_we(self):
|
475 |
+
return DataState(shard_idx=0, elem_idx=self.curr_elem_idx)
|
476 |
+
|
477 |
+
def thats_where(self, data_state):
|
478 |
+
self.curr_elem_idx = data_state.elem_idx
|
479 |
+
|
480 |
+
def _get_properly_sized_segment_list(self):
|
481 |
+
i = self.curr_elem_idx * self.batch_size
|
482 |
+
|
483 |
+
segment_list = self.data[i:i + self.batch_size]
|
484 |
+
if len(segment_list) < self.batch_size:
|
485 |
+
orig_len = len(segment_list)
|
486 |
+
while len(segment_list) < self.batch_size:
|
487 |
+
segment_list.append(segment_list[randint(0, orig_len - 1)])
|
488 |
+
|
489 |
+
return segment_list
|
490 |
+
|
491 |
+
def _tokenize(self, segment_list):
|
492 |
+
#{'task': 'translate',
|
493 |
+
# 'src_segm': src_segm,
|
494 |
+
# 'tgt_segm': tgt_segm,
|
495 |
+
# 'src_lang': src_lang,
|
496 |
+
# 'tgt_lang': tgt_lang}
|
497 |
+
|
498 |
+
prepped_segm_list = [prep_llm_input(s) for s in segment_list]
|
499 |
+
|
500 |
+
self.tokenizer.pad_token = '<|reserved_special_token_0|>'
|
501 |
+
tokenized_batch = self.tokenizer(prepped_segm_list, return_tensors="pt", max_length=self.max_len,
|
502 |
+
truncation=True, add_special_tokens=True,
|
503 |
+
padding=True)
|
504 |
+
return tokenized_batch, self.curr_elem_idx + 1
|
505 |
+
|
506 |
+
def __next__(self):
|
507 |
+
if self.curr_elem_idx >= self.data_len:
|
508 |
+
raise StopIteration
|
509 |
+
else:
|
510 |
+
segment_list = self._get_properly_sized_segment_list()
|
511 |
+
|
512 |
+
batch = self._tokenize(segment_list)
|
513 |
+
self.curr_elem_idx += 1
|
514 |
+
return batch
|
515 |
+
"""
|
516 |
+
"""
|
517 |
+
class MultilingualDatasetIterator(IterableDataset):
|
518 |
+
def _load_metafile(self, cache_metafile):
|
519 |
+
with open(cache_metafile, 'r') as f:
|
520 |
+
self.metainfo = json.load(f)
|
521 |
+
self.data_len = sum([e['shard_size'] for e in self.metainfo])
|
522 |
+
|
523 |
+
def _init_curr_shard(self):
|
524 |
+
cache_location = self.metainfo[self.curr_shard_idx]['shard_filename']
|
525 |
+
|
526 |
+
self.curr_shard_data = torch.load(cache_location, weights_only=False)
|
527 |
+
|
528 |
+
assert len(self.curr_shard_data) == self.metainfo[self.curr_shard_idx]['shard_size']
|
529 |
+
|
530 |
+
def __init__(self, filename):
|
531 |
+
self.curr_shard_idx = 0
|
532 |
+
self.curr_elem_idx = 0
|
533 |
+
self.prev_shard_sum_len = 0
|
534 |
+
|
535 |
+
if filename is not None:
|
536 |
+
self._load_metafile(filename)
|
537 |
+
|
538 |
+
def __iter__(self):
|
539 |
+
self._init_curr_shard()
|
540 |
+
return self
|
541 |
+
|
542 |
+
def where_are_we(self):
|
543 |
+
return DataState(shard_idx=self.curr_shard_idx, elem_idx=self.curr_elem_idx)
|
544 |
+
|
545 |
+
def thats_where(self, data_state):
|
546 |
+
self.curr_shard_idx = data_state.shard_idx
|
547 |
+
self.curr_elem_idx = data_state.elem_idx
|
548 |
+
self.prev_shard_sum_len = sum([e['shard_size'] for i, e in enumerate(self.metainfo) if i < self.curr_shard_idx])
|
549 |
+
|
550 |
+
def __next__(self):
|
551 |
+
try:
|
552 |
+
result_data = self.curr_shard_data[self.curr_elem_idx]
|
553 |
+
|
554 |
+
self.curr_elem_idx += 1
|
555 |
+
except IndexError:
|
556 |
+
self.prev_shard_sum_len += self.metainfo[self.curr_shard_idx]['shard_size']
|
557 |
+
self.curr_shard_idx += 1
|
558 |
+
|
559 |
+
if self.curr_shard_idx >= len(self.metainfo):
|
560 |
+
self.__init__(None)
|
561 |
+
raise StopIteration
|
562 |
+
else:
|
563 |
+
self._init_curr_shard()
|
564 |
+
self.curr_elem_idx = 0
|
565 |
+
|
566 |
+
result_data = self.curr_shard_data[self.curr_elem_idx]
|
567 |
+
|
568 |
+
self.curr_elem_idx += 1
|
569 |
+
|
570 |
+
index_in_epoch = self.prev_shard_sum_len + self.curr_elem_idx
|
571 |
+
return result_data, index_in_epoch
|
572 |
+
|
573 |
+
def __len__(self):
|
574 |
+
return self.data_len
|
575 |
+
|
576 |
+
|
577 |
+
|
578 |
+
|
579 |
+
def dump_to_stdout():
|
580 |
+
filename = sys.argv[1]
|
581 |
+
|
582 |
+
lc_src = defaultdict(int)
|
583 |
+
|
584 |
+
tot_len = 0
|
585 |
+
tot_count = 0
|
586 |
+
|
587 |
+
for tr_pair in get_tr_pairs(filename=filename):
|
588 |
+
print(tr_pair.src_lang + "\t" + tr_pair.input + "\t" + tr_pair.tgt_lang + "\t" + tr_pair.output)
|
589 |
+
|
590 |
+
tot_len += upd_lc(lc_src, tr_pair.src_lang, tr_pair.input)
|
591 |
+
tot_len += upd_lc(lc_src, tr_pair.tgt_lang, tr_pair.output)
|
592 |
+
|
593 |
+
tot_count += 2
|
594 |
+
|
595 |
+
totes = sum(lc_src.values())
|
596 |
+
for k in sorted(lc_src):
|
597 |
+
sys.stderr.write(f"{k}: {100*lc_src[k]/totes:.1f}%\n")
|
598 |
+
sys.stderr.write(f"Avg length: {tot_len/float(tot_count):.1f}\n")
|
599 |
+
|
600 |
+
|
601 |
+
def do_stats(filename):
|
602 |
+
stats = defaultdict(int)
|
603 |
+
raw_data = load_json_data(filename)
|
604 |
+
|
605 |
+
for data in raw_data:
|
606 |
+
langs = sorted([k for k in data.keys() if data[k].strip() != ""])
|
607 |
+
stats["-".join(langs)] += 1
|
608 |
+
for k in stats:
|
609 |
+
print(k, stats[k])
|
610 |
+
|
611 |
+
|
612 |
+
def lang_from_name(filename):
|
613 |
+
return filename.split(".")[-1]
|
614 |
+
|
615 |
+
|
616 |
+
def moses_to_json(file1, file2):
|
617 |
+
result = list()
|
618 |
+
|
619 |
+
l1 = lang_from_name(file1)
|
620 |
+
l2 = lang_from_name(file2)
|
621 |
+
|
622 |
+
with open(file1, "r") as h1, open(file2, "r") as h2:
|
623 |
+
for line1 in h1:
|
624 |
+
line2 = h2.readline()
|
625 |
+
|
626 |
+
result.append({l1: line1.strip(), l2: line2.strip()})
|
627 |
+
|
628 |
+
return result
|
629 |
+
|
630 |
+
|
631 |
+
def multi_moses_to_json(output_file, init_json, input_file_tuples):
|
632 |
+
try:
|
633 |
+
with open(init_json, "r") as h:
|
634 |
+
result = json.load(h)
|
635 |
+
except:
|
636 |
+
result = list()
|
637 |
+
|
638 |
+
for input_file_tuple in input_file_tuples:
|
639 |
+
this_result = moses_to_json(*input_file_tuple)
|
640 |
+
result.append({"source": f"{input_file_tuple[0]}-{input_file_tuple[1]}", "sentences": this_result})
|
641 |
+
|
642 |
+
with open(output_file, "w") as f:
|
643 |
+
json.dump(result, f, indent=2, sort_keys=True)
|
644 |
+
|
645 |
+
|
646 |
+
def group_tuples(input_tuples):
|
647 |
+
return [(input_tuples[2 * i], input_tuples[2 * i + 1]) for i in range(int(len(input_tuples) / 2))]
|
648 |
+
|
649 |
+
|
650 |
+
def combine_two_jsons(json_target, json_addition):
|
651 |
+
for k in json_addition:
|
652 |
+
if k in json_target:
|
653 |
+
json_target[k] += json_addition[k]
|
654 |
+
else:
|
655 |
+
json_target[k] = json_addition[k]
|
656 |
+
|
657 |
+
|
658 |
+
def combine_jsons(filelist):
|
659 |
+
result = dict()
|
660 |
+
|
661 |
+
for filename in filelist:
|
662 |
+
data = json.load(open(filename))
|
663 |
+
|
664 |
+
combine_two_jsons(result, data)
|
665 |
+
|
666 |
+
json.dumps(result)
|
667 |
+
|
668 |
+
|
669 |
+
def _dev_to_dict(filename):
|
670 |
+
result = defaultdict(lambda: defaultdict(int))
|
671 |
+
|
672 |
+
for dev_sample in load_json_data(filename):
|
673 |
+
for lang in dev_sample:
|
674 |
+
if not "dia" in lang:
|
675 |
+
result[lang][dev_sample[lang]] = 1
|
676 |
+
|
677 |
+
return result
|
678 |
+
|
679 |
+
|
680 |
+
def check_cross_pollination(small_path, large_path):
|
681 |
+
print("preparing dev set")
|
682 |
+
dct = _dev_to_dict(small_path)
|
683 |
+
|
684 |
+
print("reading train set")
|
685 |
+
for train_sample in load_json_data(large_path):
|
686 |
+
for lang in train_sample:
|
687 |
+
if not "dia" in lang and lang in dct:
|
688 |
+
snt = train_sample[lang]
|
689 |
+
|
690 |
+
if snt in dct[lang]:
|
691 |
+
dct[lang][snt] += 1
|
692 |
+
|
693 |
+
print("---------------------")
|
694 |
+
print("contamination report:")
|
695 |
+
print("---------------------")
|
696 |
+
for lang in dct:
|
697 |
+
total = 0
|
698 |
+
counts = 0
|
699 |
+
freqs = 0
|
700 |
+
|
701 |
+
for snt in dct[lang]:
|
702 |
+
total += 1
|
703 |
+
if dct[lang][snt] > 1:
|
704 |
+
counts += 1
|
705 |
+
freqs += (dct[lang][snt] - 1)
|
706 |
+
|
707 |
+
print(f"{lang}: contaminated: {counts} ({100*counts/float(total):.1f}%), total occurrence: {freqs}")
|
708 |
+
|
709 |
+
|
710 |
+
def char_class(c):
|
711 |
+
lc = c.lower()
|
712 |
+
if re.match("[a-z]", lc):
|
713 |
+
return "latn"
|
714 |
+
elif re.match("[а-я]", lc):
|
715 |
+
return "cyrl"
|
716 |
+
else:
|
717 |
+
return "other"
|
718 |
+
|
719 |
+
|
720 |
+
def snt_is_fishy(snt_raw, lang, detailed=False):
|
721 |
+
snt = re.sub(r'^<[^>]+> ', '', snt_raw)
|
722 |
+
|
723 |
+
snt_db = defaultdict(int)
|
724 |
+
for c in snt:
|
725 |
+
c_c = char_class(c)
|
726 |
+
snt_db[c_c] += 1
|
727 |
+
|
728 |
+
tot = snt_db['latn'] + snt_db['cyrl']
|
729 |
+
|
730 |
+
if tot > 0:
|
731 |
+
if snt_db['latn'] / tot > 0.7:
|
732 |
+
this_is = 'latn'
|
733 |
+
elif snt_db['cyrl'] / tot > 0.7:
|
734 |
+
this_is = 'cyrl'
|
735 |
+
else:
|
736 |
+
this_is = 'mix'
|
737 |
+
|
738 |
+
should_be = any_to_nllb(lang).split("_")[1].lower()
|
739 |
+
|
740 |
+
if should_be != this_is:
|
741 |
+
return (True, this_is, should_be) if detailed else True
|
742 |
+
|
743 |
+
return (False, None, None) if detailed else False
|
744 |
+
|
745 |
+
|
746 |
+
def script_stats():
|
747 |
+
db = defaultdict(lambda: defaultdict(int))
|
748 |
+
|
749 |
+
# corp = []
|
750 |
+
|
751 |
+
for raw_line in sys.stdin:
|
752 |
+
lang, snt_raw = raw_line.strip().split("\t")
|
753 |
+
|
754 |
+
is_fishy, this_is, should_be = snt_is_fishy(snt_raw, lang, detailed=True)
|
755 |
+
if is_fishy:
|
756 |
+
print(f"{lang}: should be {should_be}, is actually {this_is}:\n{snt_raw}")
|
757 |
+
|
758 |
+
|
759 |
+
|
760 |
+
def get_full_lang(lang, tupl):
|
761 |
+
dia_key = f"{lang}-dia"
|
762 |
+
|
763 |
+
if dia_key in tupl:
|
764 |
+
return f"{lang}, {tupl[dia_key]}"
|
765 |
+
else:
|
766 |
+
return lang
|
767 |
+
|
768 |
+
|
769 |
+
def convert_json_to_json(src_json, dest_json):
|
770 |
+
raw_data = load_json_data(src_json)
|
771 |
+
|
772 |
+
output_data = []
|
773 |
+
|
774 |
+
for tupl in raw_data:
|
775 |
+
for l1 in tupl:
|
776 |
+
for l2 in tupl:
|
777 |
+
if l1 != l2 and not "dia" in l1 and not "dia" in l2:
|
778 |
+
src_segm = tupl[l1]
|
779 |
+
tgt_segm = tupl[l2]
|
780 |
+
|
781 |
+
src_lang = get_full_lang(l1, tupl)
|
782 |
+
tgt_lang = get_full_lang(l2, tupl)
|
783 |
+
|
784 |
+
output_data.append({ 'task': 'translate',
|
785 |
+
'src_segm': src_segm,
|
786 |
+
'tgt_segm': tgt_segm,
|
787 |
+
'src_lang': src_lang,
|
788 |
+
'tgt_lang': tgt_lang})
|
789 |
+
|
790 |
+
with open(dest_json, "w") as f:
|
791 |
+
json.dump(output_data, f, indent=2)
|
792 |
+
"""
|
793 |
+
|
794 |
+
if __name__ == "__main__":
|
795 |
+
# check_cross_pollination(sys.argv[1], sys.argv[2])
|
796 |
+
# multi_moses_to_json(sys.argv[1], sys.argv[2], group_tuples(sys.argv[3:]))
|
797 |
+
# combine_jsons(sys.argv[1:])
|
798 |
+
# do_stats("data/train.json")
|
799 |
+
|
800 |
+
# dump_to_stdout()
|
801 |
+
# script_stats()
|
802 |
+
|
803 |
+
# convert_json_to_json(sys.argv[1], sys.argv[2])
|
804 |
+
pass
|
kuidastaltsutadalaamat/legacy/diffmdl.py
ADDED
@@ -0,0 +1,69 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import sys
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from datetime import datetime
|
7 |
+
from transformers import AutoModelForSeq2SeqLM
|
8 |
+
|
9 |
+
|
10 |
+
def get_mdl_param_dict(mdl):
|
11 |
+
return {n: p for n, p in mdl.named_parameters()}
|
12 |
+
|
13 |
+
|
14 |
+
def log(msg):
|
15 |
+
sys.stderr.write(str(datetime.now()) + ": " + msg + '\n')
|
16 |
+
|
17 |
+
|
18 |
+
def _avg_diff(pd1, pd2, skip_emb):
|
19 |
+
result = 0
|
20 |
+
count = 0
|
21 |
+
|
22 |
+
raw_count = 0
|
23 |
+
|
24 |
+
for k in pd1.keys():
|
25 |
+
# log(k)
|
26 |
+
if not (skip_emb and "shared" in k):
|
27 |
+
delta = pd1[k] - pd2[k]
|
28 |
+
|
29 |
+
raw_count += 1
|
30 |
+
|
31 |
+
if len(delta.shape) == 1:
|
32 |
+
thiscount = delta.shape[0]
|
33 |
+
elif len(delta.shape) == 2:
|
34 |
+
thiscount = delta.shape[0] * delta.shape[1]
|
35 |
+
else:
|
36 |
+
raise Exception("Unexpected shape")
|
37 |
+
count += thiscount
|
38 |
+
deltasum = torch.sum(delta)
|
39 |
+
#log(f"DETDIFF {k}: {deltasum/thiscount}")
|
40 |
+
result += deltasum
|
41 |
+
# print(f"Count {count}, raw count {raw_count}")
|
42 |
+
|
43 |
+
return result / count
|
44 |
+
|
45 |
+
|
46 |
+
def avg_mdl_diff(m1, m2, skip_emb=False):
|
47 |
+
pd1 = get_mdl_param_dict(m1)
|
48 |
+
pd2 = get_mdl_param_dict(m2)
|
49 |
+
|
50 |
+
assert (pd1.keys() == pd2.keys())
|
51 |
+
|
52 |
+
return _avg_diff(pd1, pd2, skip_emb)
|
53 |
+
|
54 |
+
|
55 |
+
if __name__ == "__main__":
|
56 |
+
mdl1_id = sys.argv[1]
|
57 |
+
mdl2_id = sys.argv[2]
|
58 |
+
|
59 |
+
log(f"Load mdl 1: {mdl1_id}")
|
60 |
+
model1 = AutoModelForSeq2SeqLM.from_pretrained(mdl1_id)
|
61 |
+
|
62 |
+
log(f"Load mdl 2: {mdl2_id}")
|
63 |
+
model2 = AutoModelForSeq2SeqLM.from_pretrained(mdl2_id)
|
64 |
+
|
65 |
+
log(f"Full diff: {avg_mdl_diff(model1, model2)}")
|
66 |
+
|
67 |
+
#log(f"Encoder diff: {avg_mdl_diff(model1.get_encoder(), model2.get_encoder(), True)}")
|
68 |
+
#log(f"Decoder diff: {avg_mdl_diff(model1.get_decoder(), model2.get_decoder(), True)}")
|
69 |
+
#log(f"Embedding diff: {avg_mdl_diff(model1.get_input_embeddings(), model2.get_input_embeddings())}")
|
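A minimal programmatic sketch of the same comparison the __main__ block performs, with placeholder model ids:

# Average signed parameter difference between two seq2seq checkpoints.
from transformers import AutoModelForSeq2SeqLM
from legacy.diffmdl import avg_mdl_diff

model_a = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")  # placeholder
model_b = AutoModelForSeq2SeqLM.from_pretrained("models/my-finetuned-nllb")          # placeholder
print(avg_mdl_diff(model_a, model_b, skip_emb=True))  # skip_emb=True leaves out the shared embedding weights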
kuidastaltsutadalaamat/legacy/initmodel.py
ADDED
@@ -0,0 +1,46 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import os
|
4 |
+
|
5 |
+
from transformers import AutoConfig, AutoModelForSeq2SeqLM
|
6 |
+
|
7 |
+
from legacy.modelops import mdl_param_count
|
8 |
+
from legacy.tokops import train_or_extend_tokenizer_and_upd_model, save_postokens
|
9 |
+
|
10 |
+
from aux import get_changed_config, CmdlineArgs
|
11 |
+
from legacy.langconv import lang_set_maybe_smugri
|
12 |
+
|
13 |
+
|
14 |
+
def just_do_main_stuff_and_avoid_global_ctx_variables():
|
15 |
+
args = CmdlineArgs("Initialize a new HuggingFace model randomly, off of an existing configuration, with possible changes",
|
16 |
+
pos_arg_list=["mdl_id", "save_location"],
|
17 |
+
kw_arg_dict={ k: None for k in ["tok_train_file", "new_langs", "vocab_size", "merge_tokenizers", "merge_tok_mdl_id",
|
18 |
+
"tok_mdl_id", "activation_dropout", "activation_function", "d_model",
|
19 |
+
"decoder_attention_heads", "decoder_ffn_dim", "decoder_layerdrop", "decoder_layers",
|
20 |
+
"encoder_attention_heads", "encoder_ffn_dim", "encoder_layerdrop", "encoder_layers",
|
21 |
+
"num_hidden_layers"] })
|
22 |
+
if not args.tok_mdl_id:
|
23 |
+
args.tok_mdl_id = args.mdl_id
|
24 |
+
|
25 |
+
if args.new_langs:
|
26 |
+
args.new_langs = lang_set_maybe_smugri(args.new_langs)
|
27 |
+
|
28 |
+
if os.path.exists(args.save_location):
|
29 |
+
raise Exception(f"Save location '{args.save_location}' already exists, don't want to overwrite")
|
30 |
+
|
31 |
+
config = get_changed_config(AutoConfig.from_pretrained(args.mdl_id), args)
|
32 |
+
|
33 |
+
model = AutoModelForSeq2SeqLM.from_config(config)
|
34 |
+
|
35 |
+
tokenizer, added = train_or_extend_tokenizer_and_upd_model(args, model)
|
36 |
+
|
37 |
+
tokenizer.save_pretrained(args.save_location)
|
38 |
+
save_postokens(added, args.save_location)
|
39 |
+
model.save_pretrained(args.save_location)
|
40 |
+
|
41 |
+
mdl_size, emb_size = mdl_param_count(model)
|
42 |
+
print(f"Created model with {mdl_size} parameters" +
|
43 |
+
("" if emb_size < 0 else f" of which {emb_size} ({100 * emb_size / mdl_size:.2f}%) are embeddings"))
|
44 |
+
|
45 |
+
if __name__ == '__main__':
|
46 |
+
just_do_main_stuff_and_avoid_global_ctx_variables()
|
kuidastaltsutadalaamat/legacy/langconv.py
ADDED
@@ -0,0 +1,260 @@
1 |
+
"""
|
2 |
+
Convert lang. codes between different schemas
|
3 |
+
|
4 |
+
NLLB uses codes like "eng_Latn": ISO-639-3 and script
|
5 |
+
|
6 |
+
MADLAD uses codes like "<2en>": ISO-639-1 where it's available, ISO-639-3 elsewhere
|
7 |
+
(and some codes include the script, but we'll ignore them here)
|
8 |
+
|
9 |
+
Functions at the end of the file (any_to_nllb, any_to_madlad) should
|
10 |
+
cope with a lang code in any style ('en', 'eng', 'eng_Latn', '<2en>', '<2eng>', etc)
|
11 |
+
and convert it to the corresponding representation (NLLB/MADLAD).
|
12 |
+
"""
|
13 |
+
from collections import defaultdict
|
14 |
+
|
15 |
+
SMUGRI_LOW = "fkv,izh,kca,koi,kpv,krl,liv,lud,mdf,mhr,mns,mrj,myv,olo,sjd,sje,sju,sma,sme,smj,smn,sms,udm,vep,vot,vro"
|
16 |
+
SMUGRI_HIGH = "deu,eng,est,fin,hun,lvs,nor,rus,swe"
|
17 |
+
SMUGRI = SMUGRI_HIGH + "," + SMUGRI_LOW
|
18 |
+
|
19 |
+
import pycountry
|
20 |
+
|
21 |
+
# madlad all codes
|
22 |
+
MADLAD_CODES = ['<2meo>', '<2lo>', '<2Grek>', '<2ada>', '<2ps>', '<2arn>', '<2Armn>', '<2to>', '<2raj>', '<2bas>', '<2ny>', '<2>', '<2zza>', '<2Thai>', '<2kaa_Latn>', '<2yap>', '<2en_xx_simple>', '<2ta>', '<2bg_Latn>', '<2mkn>', '<2lhu>', '<2gu_Latn>', '<2nzi>', '<2uz>', '<2pis>', '<2cfm>', '<2min>', '<2fon>', '<2tn>', '<2msi>', '<2sw>', '<2Tfng>', '<2teo>', '<2taj>', '<2pap>', '<2sd>', '<2Jpan>', '<2tca>', '<2sr>', '<2an>', '<2fr>', '<2gor>', '<2az>', '<2qvi>', '<2pck>', '<2cak>', '<2ltg>', '<2sah>', '<2tly_IR>', '<2ts>', '<2yo>', '<2hne>', '<2bzj>', '<2tuc>', '<2sh>', '<2da>', '<2gui>', '<2translate>', '<2et>', '<2sja>', '<2nhe>', '<2scn>', '<2dje>', '<2pt>', '<2nog>', '<2fil>', '<2mai>', '<2lb>', '<2bm>', '<2Guru>', '<2gom>', '<2hr>', '<2kg>', '<2uk>', '<2rw>', '<2izz>', '<2Telu>', '<2wuu>', '<2Deva>', '<2or>', '<2is>', '<2om>', '<2iso>', '<2sn>', '<2kjh>', '<2tbz>', '<2suz>', '<2bjn>', '<2lv>', '<2mfe>', '<2tcy>', '<2tyz>', '<2ksw>', '<2nds_NL>', '<2ms>', '<2mam>', '<2ubu>', '<2hil>', '<2mh>', '<2gl>', '<2bew>', '<2ilo>', '<2kbd>', '<2toj>', '<2quf>', '<2jam>', '<2Beng>', '<2tyv>', '<2lmo>', '<2ace>', '<2cab>', '<2sq>', '<2ug>', '<2kac>', '<2ay>', '<2mag>', '<2Arab>', '<2mrj>', '<2cs>', '<2bci>', '<2doi>', '<2zu>', '<2ndc_ZW>', '<2smt>', '<2ho>', '<2ss>', '<2he>', '<2twu>', '<2kjg>', '<2pag>', '<2Latn>', '<2gym>', '<2sus>', '<2zh_Latn>', '<2mps>', '<2lg>', '<2ko>', '<2se>', '<2guc>', '<2mr>', '<2mwl>', '<2dwr>', '<2din>', '<2ffm>', '<2maz>', '<2nia>', '<2nl>', '<2Knda>', '<2jv>', '<2noa>', '<2udm>', '<2kr>', '<2de>', '<2ar>', '<2ZW>', '<2dln>', '<2mn>', '<2ml>', '<2crh>', '<2ha>', '<2ks>', '<2qvc>', '<2fur>', '<2myv>', '<2nv>', '<2ak>', '<2Gujr>', '<2cce>', '<2nso>', '<2sg>', '<2rmc>', '<2mas>', '<2mni>', '<2frp>', '<2my>', '<2xal>', '<2th>', '<2bik>', '<2bho>', '<2inb>', '<2Mlym>', '<2oj>', '<2back_translated>', '<2tet>', '<2gsw>', '<2ff>', '<2hy>', '<2otq>', '<2el>', '<2agr>', '<2br>', '<2alt>', '<2tzo>', '<2chm>', '<2transliterate>', '<2hu>', '<2btx>', '<2vi>', '<2iba>', '<2bg>', '<2gub>', '<2li>', '<2ace_Arab>', '<2qub>', '<2ktu>', '<2bru>', '<2bbc>', '<2ca>', '<2hvn>', '<2sat_Latn>', '<2ku>', '<2shn>', '<2djk>', '<2krc>', '<2io>', '<2ig>', '<2chk>', '<2sm>', '<2Mymr>', '<2Kore>', '<2ary>', '<2lu>', '<2fa>', '<2spp>', '<2af>', '<2ti>', '<2Tibt>', '<2emp>', '<2enq>', '<2kl>', '<2be>', '<2srn>', '<2ms_Arab_BN>', '<2kri>', '<2gd>', '<2mk>', '<2syr>', '<2kmz_Latn>', '<2CA>', '<2ium>', '<2abt>', '<2ngu>', '<2tab>', '<2it>', '<2ru>', '<2ann>', '<2msm>', '<2fo>', '<2ne>', '<2akb>', '<2kv>', '<2jac>', '<2ceb>', '<2ang>', '<2tdx>', '<2tr>', '<2kbp>', '<2mgh>', '<2az_RU>', '<2acf>', '<2tg>', '<2dov>', '<2pau>', '<2mg>', '<2fuv>', '<2nn>', '<2Hant>', '<2hui>', '<2ml_Latn>', '<2ja>', '<2lus>', '<2te>', '<2qu>', '<2rom>', '<2tsg>', '<2el_Latn>', '<2cr_Latn>', '<2ur>', '<2fi>', '<2shp>', '<2brx>', '<2laj>', '<2sda>', '<2lij>', '<2st>', '<2bn>', '<2zxx_xx_dtynoise>', '<2yua>', '<2no>', '<2fr_CA>', '<2miq>', '<2trp>', '<2es>', '<2ch>', '<2mass>', '<2os>', '<2bts>', '<2ady>', '<2lrc>', '<2seh>', '<2adh>', '<2new>', '<2mak>', '<2grc>', '<2nus>', '<2tzj>', '<2nut>', '<2gu>', '<2oc>', '<2ppk>', '<2Hans>', '<2tzh>', '<2si>', '<2wo>', '<2nyu>', '<2Hebr>', '<2mad>', '<2tll>', '<2kr_Arab>', '<2pon>', '<2mbt>', '<2kw>', '<2bjn_Arab>', '<2gn>', '<2eu>', '<2dz>', '<2kaa>', '<2crh_Latn>', '<2te_Latn>', '<2ky>', '<2kn_Latn>', '<2kum>', '<2fip>', '<2ksd>', '<2sk>', '<2NL>', '<2ctd_Latn>', '<2Khmr>', '<2gbm>', '<2Cans>', '<2haw>', '<2gag>', '<2Taml>', '<2cnh>', '<2bim>', '<2ms_Arab>', '<2Thaa>', '<2kha>', 
'<2tvl>', '<2Cyrl>', '<2chr>', '<2dtp>', '<2ba>', '<2nan_Latn_TW>', '<2ro>', '<2ctu>', '<2Ethi>', '<2zh>', '<2ln>', '<2ve>', '<2xh>', '<2skr>', '<2ber>', '<2niq>', '<2ibb>', '<2jvn>', '<2tks>', '<2av>', '<2ahk>', '<2tk>', '<2tt>', '<2ka>', '<2tsc>', '<2km>', '<2co>', '<2id>', '<2prs>', '<2rki>', '<2kmb>', '<2ks_Deva>', '<2ify>', '<2wal>', '<2arz>', '<2amu>', '<2rm>', '<2pa>', '<2RU>', '<2ce>', '<2hi>', '<2eo>', '<2taq>', '<2ga>', '<2qxr>', '<2la>', '<2bi>', '<2rwo>', '<2dyu>', '<2zh_Hant>', '<2mt>', '<2bqc>', '<2bn_Latn>', '<2zne>', '<2szl>', '<2lt>', '<2sl>', '<2hif>', '<2alz>', '<2ber_Latn>', '<2ckb>', '<2wa>', '<2Cher>', '<2msb>', '<2gom_Latn>', '<2ru_Latn>', '<2crs>', '<2kk>', '<2gvl>', '<2qvz>', '<2bar>', '<2qup>', '<2bgp>', '<2bo>', '<2su>', '<2tzm>', '<2IR>', '<2sv>', '<2srm>', '<2rn>', '<2bus>', '<2jiv>', '<2awa>', '<2gv>', '<2knj>', '<2as>', '<2quc>', '<2en>', '<2sa>', '<2bug>', '<2quy>', '<2hi_Latn>', '<2nds>', '<2kek>', '<2mrw>', '<2kos>', '<2cy>', '<2ta_Latn>', '<2kn>', '<2nr>', '<2ape>', '<2bs>', '<2iu>', '<2nnb>', '<2Geor>', '<2rcf>', '<2meu>', '<2cac>', '<2cuk>', '<2bua>', '<2vec>', '<2so>', '<2fj>', '<2gof>', '<2koi>', '<2cv>', '<2guh>', '<2war>', '<2pl>', '<2cbk>', '<2kj>', '<2dv>', '<2mdf>', '<2fy>', '<2am>', '<2sc>', '<2taq_Tfng>', '<2mi>', '<2zap>', '<2mqy>', '<2yi>', '<2kwi>', '<2hmn>', '<2tiv>', '<2sxn>', '<2hus>', '<2ban>', '<2nij>', '<2tlh>', '<2Orya>', '<2quh>', '<2ee>', '<2ht>', '<2bum>', '<2stq>']
|
23 |
+
|
24 |
+
# NLLB all codes
|
25 |
+
NLLB_CODES = ['ace_Arab', 'ace_Latn', 'acm_Arab', 'acq_Arab', 'aeb_Arab', 'afr_Latn', 'ajp_Arab', 'aka_Latn', 'amh_Ethi', 'apc_Arab', 'arb_Arab', 'ars_Arab', 'ary_Arab', 'arz_Arab', 'asm_Beng', 'ast_Latn', 'awa_Deva', 'ayr_Latn', 'azb_Arab', 'azj_Latn', 'bak_Cyrl', 'bam_Latn', 'ban_Latn', 'bel_Cyrl', 'bem_Latn', 'ben_Beng', 'bho_Deva', 'bjn_Arab', 'bjn_Latn', 'bod_Tibt', 'bos_Latn', 'bug_Latn', 'bul_Cyrl', 'cat_Latn', 'ceb_Latn', 'ces_Latn', 'cjk_Latn', 'ckb_Arab', 'crh_Latn', 'cym_Latn', 'dan_Latn', 'deu_Latn', 'dik_Latn', 'dyu_Latn', 'dzo_Tibt', 'ell_Grek', 'eng_Latn', 'epo_Latn', 'est_Latn', 'eus_Latn', 'ewe_Latn', 'fao_Latn', 'pes_Arab', 'fij_Latn', 'fin_Latn', 'fon_Latn', 'fra_Latn', 'fur_Latn', 'fuv_Latn', 'gla_Latn', 'gle_Latn', 'glg_Latn', 'grn_Latn', 'guj_Gujr', 'hat_Latn', 'hau_Latn', 'heb_Hebr', 'hin_Deva', 'hne_Deva', 'hrv_Latn', 'hun_Latn', 'hye_Armn', 'ibo_Latn', 'ilo_Latn', 'ind_Latn', 'isl_Latn', 'ita_Latn', 'jav_Latn', 'jpn_Jpan', 'kab_Latn', 'kac_Latn', 'kam_Latn', 'kan_Knda', 'kas_Arab', 'kas_Deva', 'kat_Geor', 'knc_Arab', 'knc_Latn', 'kaz_Cyrl', 'kbp_Latn', 'kea_Latn', 'khm_Khmr', 'kik_Latn', 'kin_Latn', 'kir_Cyrl', 'kmb_Latn', 'kon_Latn', 'kor_Hang', 'kmr_Latn', 'lao_Laoo', 'lvs_Latn', 'lij_Latn', 'lim_Latn', 'lin_Latn', 'lit_Latn', 'lmo_Latn', 'ltg_Latn', 'ltz_Latn', 'lua_Latn', 'lug_Latn', 'luo_Latn', 'lus_Latn', 'mag_Deva', 'mai_Deva', 'mal_Mlym', 'mar_Deva', 'min_Latn', 'mkd_Cyrl', 'plt_Latn', 'mlt_Latn', 'mni_Beng', 'khk_Cyrl', 'mos_Latn', 'mri_Latn', 'zsm_Latn', 'mya_Mymr', 'nld_Latn', 'nno_Latn', 'nob_Latn', 'npi_Deva', 'nso_Latn', 'nus_Latn', 'nya_Latn', 'oci_Latn', 'gaz_Latn', 'ory_Orya', 'pag_Latn', 'pan_Guru', 'pap_Latn', 'pol_Latn', 'por_Latn', 'prs_Arab', 'pbt_Arab', 'quy_Latn', 'ron_Latn', 'run_Latn', 'rus_Cyrl', 'sag_Latn', 'san_Deva', 'sat_Beng', 'scn_Latn', 'shn_Mymr', 'sin_Sinh', 'slk_Latn', 'slv_Latn', 'smo_Latn', 'sna_Latn', 'snd_Arab', 'som_Latn', 'sot_Latn', 'spa_Latn', 'als_Latn', 'srd_Latn', 'srp_Cyrl', 'ssw_Latn', 'sun_Latn', 'swe_Latn', 'swh_Latn', 'szl_Latn', 'tam_Taml', 'tat_Cyrl', 'tel_Telu', 'tgk_Cyrl', 'tgl_Latn', 'tha_Thai', 'tir_Ethi', 'taq_Latn', 'taq_Tfng', 'tpi_Latn', 'tsn_Latn', 'tso_Latn', 'tuk_Latn', 'tum_Latn', 'tur_Latn', 'twi_Latn', 'tzm_Tfng', 'uig_Arab', 'ukr_Cyrl', 'umb_Latn', 'urd_Arab', 'uzn_Latn', 'vec_Latn', 'vie_Latn', 'war_Latn', 'wol_Latn', 'xho_Latn', 'ydd_Hebr', 'yor_Latn', 'yue_Hant', 'zho_Hans', 'zho_Hant', 'zul_Latn']

MDL_NLLB = "MDL_NLLB"
MDL_MADLAD = "MDL_MADLAD"
MDL_NEUROTOLGE = "MDL_NEUROTÕLGE"
MDL_LLAMA = "MDL_LLAMA"

_iso3_to_script = dict([nllb_code.split("_") for nllb_code in NLLB_CODES])

iso3_to_nllb = { code: f"{code}_{_iso3_to_script[code]}" for code in _iso3_to_script }

iso3_to_nllb['lav'] = "lvs_Latn"
iso3_to_nllb['nor'] = "nob_Latn"
iso3_to_nllb['yid'] = "ydd_Hebr"

for lang in "fkv izh krl liv lud olo sje sju sma sme smj smn sms vep vot vro".split():
    iso3_to_nllb[lang] = f"{lang}_Latn"

for lang in "kca koi kpv mdf mhr mns mrj myv sjd udm".split():
    iso3_to_nllb[lang] = f"{lang}_Cyrl"


_rev_joshi = defaultdict(lambda: "?")

for k in "krl,sma,vep,smj,smn,lud,liv,izh,vot,kca,sms,sje,mns,fkv,sju,sjd".split(","):
    _rev_joshi[k] = "0"
for k in "kpv,sme,mhr,udm,olo,myv,mdf,vro,mrj,koi".split(","):
    _rev_joshi[k] = "1"
for k in SMUGRI_HIGH.split(","):
    _rev_joshi[k] = "2+"


def guess_script(lang):
    return "Unk"


def get_high_set():
    return set(SMUGRI_HIGH.split(",")) - {"deu", "swe"}


def clean_lang(raw_lang):
    if "<2" in raw_lang:
        raw_lang = raw_lang[2:-1]

    if "_" in raw_lang:
        return raw_lang.split("_")[0]
    else:
        return raw_lang


def any_to_base(lang):
    clang = clean_lang(lang)

    res = pycountry.languages.get(alpha_2=clang)

    if res is None:
        return pycountry.languages.get(alpha_3=clang)
    else:
        return res


def base_to_nllb(lang_entry=None, lang_code=None):
    if lang_code is None:
        lang_code = lang_entry.alpha_3

    try:
        #script = iso3_to_script[lang_code]
        return iso3_to_nllb[lang_code]
    except KeyError:
        script = guess_script(lang_code)
        return f"{lang_code}_{script}"


def base_to_madlad(lang_entry=None, lang_code=None):
    if lang_code is None:
        if hasattr(lang_entry, 'alpha_2'):
            lang_code = lang_entry.alpha_2
        else:
            lang_code = lang_entry.alpha_3

    return f"<2{lang_code}>"


def any_to_something(lang, conv_func):
    base = any_to_base(lang)

    if base is None:
        clang = clean_lang(lang)
        return conv_func(None, clang)
    else:
        return conv_func(base)


def run_test(src_list, tgt_list, conv_func, msg_prefix, verbose=False):
    ok_count = 0
    err_count = 0
    fail_count = 0

    for raw_c in src_list:
        try:
            test = conv_func(raw_c)
            if test in tgt_list:
                ok_count += 1
            else:
                fail_count += 1
                if verbose:
                    print("FAIL:", test)
        except KeyError:
            err_count += 1
            if verbose:
                print("ERR:", raw_c)

    print(f"{msg_prefix}: {ok_count} good, {fail_count} fail, {err_count} err")


def any_to_madlad(lang):
    return any_to_something(lang, base_to_madlad)


def any_to_nllb(lang):
    return any_to_something(lang, base_to_nllb)


def any_to_neurotolge(lang):
    l = any_to_base(lang).alpha_3

    return l if l != 'lvs' else 'lv'


def any_to_mdl_type(mdl_type, lang):
    if mdl_type == MDL_NLLB:
        return any_to_nllb(lang)
    elif mdl_type == MDL_MADLAD:
        return any_to_madlad(lang)
    elif mdl_type is None:
        return lang
    elif mdl_type == MDL_LLAMA:
        return lang
    else:
        raise ValueError(f"Unknown mdl_type {mdl_type}")

def langs_to_madlad(lang_set):
    return [any_to_madlad(l) for l in lang_set] if lang_set is not None else []


def langs_to_nllb(lang_set):
    return [any_to_nllb(l) for l in lang_set] if lang_set is not None else []


if __name__ == "__main__":
    run_test(NLLB_CODES, MADLAD_CODES, any_to_madlad, "NLLB to MADLAD")
    run_test(NLLB_CODES, NLLB_CODES, any_to_nllb, "NLLB to NLLB")
    run_test(MADLAD_CODES, NLLB_CODES, any_to_nllb, "MADLAD TO NLLB")
    run_test(MADLAD_CODES, MADLAD_CODES, any_to_madlad, "MADLAD TO MADLAD")


def is_nllb(object):
    """
    Check if the object is an NLLB model or tokenizer
    """
    name = object.__class__.__name__.lower()
    return "m2m100" in name or "nllb" in name


def is_madlad(object):
    """
    Check if the object is a MADLAD model or tokenizer
    """
    return "t5" in object.__class__.__name__.lower()


def is_dec_only_llm(obj):
    lcname = obj.__class__.__name__.lower()
    return any(k in lcname for k in ["pretrainedtokenizerfast", "llama", "gemma"])


def get_mdl_type(obj):
    obj = obj.module if hasattr(obj, "module") else obj

    if is_nllb(obj):
        return MDL_NLLB
    elif is_madlad(obj):
        return MDL_MADLAD
    elif is_dec_only_llm(obj):
        return MDL_LLAMA
    else:
        raise ValueError(f"Object {str(obj)[:200]} is not supported")


def langs_to_mdl_type(mdl_type, lang_set):
    if mdl_type == MDL_NLLB:
        return langs_to_nllb(lang_set)
    elif mdl_type == MDL_MADLAD:
        return langs_to_madlad(lang_set)
    elif mdl_type == MDL_LLAMA:
        return lang_set
    else:
        raise ValueError(f"Model type {mdl_type} is not supported")


def get_joshi_class(lang_code):
    norm_code = any_to_base(lang_code)

    if norm_code is None:
        return "?"
    else:
        norm_code = norm_code.alpha_3

    return _rev_joshi[norm_code]

def lang_set_maybe_smugri(lang_def):
    if lang_def == "smugri-low":
        preresult = SMUGRI_LOW
    elif lang_def == "smugri-high":
        preresult = SMUGRI_HIGH
    elif lang_def == "smugri":
        preresult = SMUGRI
    else:
        preresult = lang_def

    return set(preresult.split(","))


def smugri_back(lang_list):
    sll = sorted(lang_list)

    sll_str = ",".join(sll)

    if sll_str == SMUGRI_LOW:
        return "smugri-low"
    elif sll_str == SMUGRI_HIGH:
        return "smugri-high"
    elif sll_str == SMUGRI:
        return "smugri-full"
    else:
        return sll_str
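The conversion helpers above are easiest to see on a couple of concrete codes. A minimal usage sketch (assuming the scripts' package root is the working directory so that legacy.langconv is importable; exact outputs depend on the installed pycountry data):

# Usage sketch for the helpers above (assumption: run from the package root;
# the outputs in the comments depend on the installed pycountry data).
from legacy.langconv import any_to_nllb, any_to_madlad, get_joshi_class

print(any_to_nllb("et"))       # ISO-639-1 code -> NLLB tag, e.g. "est_Latn"
print(any_to_madlad("est"))    # ISO-639-3 code -> MADLAD tag, e.g. "<2et>"
print(any_to_nllb("vro"))      # Võro, added manually above -> "vro_Latn"
print(get_joshi_class("vep"))  # Joshi resource class, "0" for Veps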
kuidastaltsutadalaamat/legacy/localizemodel.py
ADDED
@@ -0,0 +1,45 @@
#!/usr/bin/env python3

import os

from transformers import AutoModelForSeq2SeqLM, AutoModelForCausalLM
from modelops import mdl_param_count, is_gen_ai, hf_tok
from tokops import train_or_extend_tokenizer_and_upd_model, save_postokens
from aux import CmdlineArgs, log
from legacy.langconv import lang_set_maybe_smugri


def i_dont_like_global_scope_variable_dangers():
    args = CmdlineArgs("Localize an existing HuggingFace model, possibly expanding the tokenizer",
                       pos_arg_list=["mdl_id", "save_location"],
                       kw_arg_dict={"tok_train_file": None,
                                    "tok_mdl_id": None,
                                    "new_langs": None,
                                    "merge_tokenizers": 0,
                                    "merge_tok_mdl_id": None })
    if not args.tok_mdl_id:
        args.tok_mdl_id = args.mdl_id

    if os.path.exists(args.save_location):
        raise Exception(f"Save location '{args.save_location}' already exists, don't want to overwrite")

    if args.new_langs:
        args.new_langs = lang_set_maybe_smugri(args.new_langs)

    if is_gen_ai(args.mdl_id):
        model = AutoModelForCausalLM.from_pretrained(args.mdl_id, token=hf_tok)
    else:
        model = AutoModelForSeq2SeqLM.from_pretrained(args.mdl_id, token=hf_tok)

    tokenizer, added = train_or_extend_tokenizer_and_upd_model(args, model)

    mdl_size, emb_size = mdl_param_count(model)
    log(f"Cached model with {mdl_size} parameters" +
        ("" if emb_size < 0 else f" of which {emb_size} ({100 * emb_size / mdl_size:.2f}%) are embeddings"))

    tokenizer.save_pretrained(args.save_location)
    save_postokens(added, args.save_location)
    model.save_pretrained(args.save_location)

if __name__ == '__main__':
    i_dont_like_global_scope_variable_dangers()
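An invocation sketch for the script above, in the same sys.argv-patching style the other legacy scripts use in their __main__ blocks; the model id and paths are purely illustrative, not part of the repository:

# Hypothetical invocation of localizemodel.py; arguments follow the CmdlineArgs
# convention visible above: positional args first, then key=value keyword args.
import sys
import localizemodel

sys.argv = ["localizemodel.py", "facebook/nllb-200-distilled-600M", "models/nllb600m-ext",
            "new_langs=smugri", "tok_train_file=data/mono.txt"]
localizemodel.i_dont_like_global_scope_variable_dangers()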
kuidastaltsutadalaamat/legacy/modelops.py
ADDED
@@ -0,0 +1,122 @@
import os
from collections import namedtuple

import torch

from aux import log

CouplingSpecTuple = namedtuple("CouplingSpecPair", ["lang_set", "tokenizer", "postokenizer", "model_id", "model"])

hf_tok = None
with open("../hf_token", 'r') as fh:
    hf_tok = fh.read().strip()



MODULE_CONFIG_FILE = "coupled_module_config.json"
DATA_STATE_FILE = "data_state.json"
LOSS_LIST_FILE = "loss_list.json"


def mdl_param_count(model):
    result = 0
    embedding_size = -1

    for n, p in model.named_parameters():
        this_count = 1

        for s in p.shape:
            this_count *= s

        result += this_count

        # if n == "model.shared.weight":

        if "shared.weight" in n:
            embedding_size = this_count

    return result, embedding_size

"""

def to_cpl_spec(langs, model, tokenizer, postokenizer, location):
    mdl_type = get_mdl_type(tokenizer)
    cpl_langs = set(langs_to_mdl_type(mdl_type, langs))

    return [CouplingSpecTuple(cpl_langs, tokenizer, postokenizer, location, model)]


def _save_json_config(model_dir, filename, data):
    with open(os.path.join(model_dir, filename), "w") as f:
        json.dump(data, f, indent=2, sort_keys=True)
        f.write("\n")


def _load_json_config(model_dir, filename):
    try:
        with open(os.path.join(model_dir, filename), "r") as f:
            return json.load(f)
    except FileNotFoundError:
        return None

def save_module_config(model_dir, coupling_specs):
    config = [{'lang_set': list(spec.lang_set), 'model_id': spec.model_id if i > 0 else model_dir} for i, spec in enumerate(coupling_specs)]
    _save_json_config(model_dir, MODULE_CONFIG_FILE, config)


def load_module_config(model_dir):
    result = _load_json_config(model_dir, MODULE_CONFIG_FILE)

    return result if result is not None else [{"model_id": model_dir, "lang_set": {}}]
"""

def save_all_models(location, model, tokenizer, cpl_specs=None, trainer=None):
    if not os.path.exists(location):
        os.makedirs(location)

    if trainer is not None:
        trainer.save_state(location)

    model.config.save_pretrained(location)
    model.generation_config.save_pretrained(location)

    tokenizer.save_pretrained(location)
    """
    if cpl_specs is not None:
        save_module_config(location, cpl_specs)
    """


def report_devices(msg = "", accelerator = None, mdl = None):
    if torch.cuda.is_available():
        # Get the visible devices from CUDA
        visible_devices = torch.cuda.device_count()

        #log(f"Number of visible GPUs: {visible_devices}")
        msg = f"{msg:30} {visible_devices} GPUs:"

        # List the actual GPUs being used
        gpu_names = [torch.cuda.get_device_name(i) for i in range(visible_devices)]
        for i, name in enumerate(gpu_names):
            mem_alloc = torch.cuda.memory_allocated(i) / 1024**2
            mem_res = torch.cuda.memory_reserved(i) / 1024**2

            if mem_alloc > 0.01 or mem_res > 0.01:
                msg += f" {i}: alloc {mem_alloc:.2f} Mb / res {mem_res:.2f} Mb;"

        log(msg, accelerator=accelerator)
    elif accelerator is not None and accelerator.device.type == "mps":
        mem_alloc = torch.mps.current_allocated_memory() / 1024**2
        log(f"{msg:30} device being used: {accelerator.device}, mem alloc: {mem_alloc} Mb", accelerator=accelerator)
    else:
        log(f"No acceleration")

    #if mdl is not None:
    #    log(f"Model device: {mdl.device}", accelerator=accelerator)


def is_gen_ai(mdl_id):
    lc = mdl_id.lower()
    return not ("madlad" in lc or "nllb" in lc or "m2m" in lc or "bart" in lc)
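A short sketch of how mdl_param_count and is_gen_ai above are meant to be used; the model id is only an example, and hf_tok is read from ../hf_token as defined at the top of the module:

# Usage sketch; "facebook/nllb-200-distilled-600M" is just an illustrative id.
from transformers import AutoModelForSeq2SeqLM
from modelops import mdl_param_count, is_gen_ai, hf_tok

mdl_id = "facebook/nllb-200-distilled-600M"
print(is_gen_ai(mdl_id))   # False, since "nllb" appears in the id
model = AutoModelForSeq2SeqLM.from_pretrained(mdl_id, token=hf_tok)
total, emb = mdl_param_count(model)   # emb stays -1 if no "shared.weight" parameter is found
print(f"{total} parameters, {emb} in the shared embedding")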
kuidastaltsutadalaamat/legacy/oldtrainllm.py
ADDED
@@ -0,0 +1,90 @@
#!/usr/bin/env python3

import os
import json
import torch
import sys

from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer

from accel import SwitchingAccelerator
from modelops import hf_tok, save_all_models

from aux import log, CmdlineArgs
from data import do_list_in_batches


def _cmdline_args():
    description = """Train or tune decoder models"""

    result = CmdlineArgs(description,
                         pos_arg_list=["mdl_id", "save_location", "train_file"],
                         pos_arg_types=[str, str, str],
                         kw_arg_dict={ "continue_training": False, "save_steps": 100, "lr": 1.5e-5,
                                       "batch_size": 1024, "nr_sents_per_gpu": 4, "log_steps": 1, "epochs": 4,
                                       "max_length": 3000 })

    # if the directory args.save_location already exists, raise an exception:
    if not result.continue_training and os.path.exists(result.save_location):
        raise Exception(f"Save location '{result.save_location}' already exists, don't want to overwrite.")

    if result.nr_sents_per_gpu == 0:
        result.nr_sents_per_gpu = result.batch_size

    return result


def load_json_list(json_file):
    with open(json_file, "r") as f:
        data = json.load(f)
    return data


def load_hf_model(mdl_id, accelerator=None):
    if accelerator is None:
        model = AutoModelForCausalLM.from_pretrained(mdl_id, token=hf_tok, torch_dtype=torch.bfloat16)
    else:
        model = AutoModelForCausalLM.from_pretrained(mdl_id, token=hf_tok, torch_dtype=torch.bfloat16, device_map=accelerator.device)
    return model


def load_hf_tokenizer(mdl_id):
    tokenizer = AutoTokenizer.from_pretrained(mdl_id, token=hf_tok)
    return tokenizer


def _no_globals_main():
    accelerator = Accelerator()

    try:
        args = _cmdline_args()

        log(f"Num proc: {accelerator.num_processes}, proc ID: {accelerator.process_index}")
        log("loading model", accelerator=accelerator)
        mdl = load_hf_model(args.mdl_id)

        log("loading tokenizer", accelerator=accelerator)
        tok = load_hf_tokenizer(args.mdl_id)

        log("loading data", accelerator=accelerator, all_threads=True)
        train_set = load_json_list(args.train_file)

        log("training", accelerator=accelerator)

        acc_trainer = SwitchingAccelerator(train_set, args, mdl, tok, preinit_acc=accelerator)
        upd_model = acc_trainer.train()

        log("saving", accelerator=accelerator)
        save_all_models(args.save_location, upd_model, tok)
    except Exception as e:
        # in multiprocess scenarios it is hard to read the stack trace, so just show one:
        if accelerator.is_main_process:
            raise e


if __name__ == "__main__":
    #sys.argv = "_ models/llama3.2-1b models/newmdl tmp.json".split()
    #sys.argv = "_ models/llama3.2-1b models/newmdl2 tmpx.json batch_size=16 nr_sents_per_gpu=1 log_steps=1 save_steps=2000 epochs=1".split()

    _no_globals_main()
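The two loaders above are thin wrappers around the HF Auto classes and can be reused on their own; a small reuse sketch with a hypothetical local model directory (bfloat16 is what load_hf_model always requests):

# Reuse sketch for the loaders above (the model path is hypothetical).
from oldtrainllm import load_hf_model, load_hf_tokenizer

mdl = load_hf_model("models/llama3.2-1b")        # AutoModelForCausalLM in torch.bfloat16
tok = load_hf_tokenizer("models/llama3.2-1b")
print(sum(p.numel() for p in mdl.parameters()), len(tok))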
kuidastaltsutadalaamat/legacy/parasynth.py
ADDED
@@ -0,0 +1,139 @@
#!/usr/bin/env python3

import sys
import json

from collections import defaultdict

from benchmark import get_hyp_cache_dir, translate_all_hyps
from inference import load_and_init_module_config
from legacy.langconv import get_high_set, any_to_mdl_type, get_mdl_type
from accelerate import Accelerator
from aux import log


def load_raw_data(path):
    with open(path, 'r') as f:
        return json.load(f)


def save_raw_data(path, data):
    with open(path, 'w') as f:
        json.dump(data, f, indent=2)


def apply_func_to_hires_snts(snt_set, func):
    high_set = get_high_set()

    for tupl in snt_set:
        langs = [k for k in tupl if not "-dia" in k and k in high_set]

        if langs:
            revlangs = high_set - set(langs)

            for revlang in revlangs:
                for lang in langs:
                    # translate sentences tupl[lang] from lang to revlang
                    # OR
                    # add the result as tupl[revlang]
                    func(tupl, lang, revlang)


def report_part_stats(part, part_index, num_parts):
    hi_set = get_high_set()

    num_snts = len(part['sentences'])
    hires_langs = {k for k in part['sentences'][0] if "dia" not in k and k in hi_set}
    num_hires_langs = len(hires_langs)
    langs_to_do = hi_set - hires_langs
    num_to_translate = num_hires_langs * len(langs_to_do)

    log(f"Part {part_index + 1}/{num_parts}; {num_snts} sentences, num hires: {num_hires_langs}, to translate: {num_to_translate}")

    return num_snts * num_hires_langs, num_snts * num_to_translate


def add_hires_synth_data(mdl_id, corpus_in, corpus_out, dry=False):
    accelerator = Accelerator()

    log("Loading data", accelerator)
    data = load_raw_data(corpus_in)


    log("Loading model", accelerator)
    if dry:
        main_model, module_config = None, None
        mdl_type = None
    else:
        main_model, module_config = load_and_init_module_config(mdl_id, accelerator)
        mdl_type = get_mdl_type(main_model)

    if accelerator.is_main_process:
        _ = get_hyp_cache_dir(mdl_id, create=True)
    l = len(data)

    tot_snt = 0
    tot_tr = 0

    for i, part in enumerate(data):
        tr_dict = defaultdict(lambda: defaultdict(lambda: None))

        num_snt, num_tr = report_part_stats(part, i, l)
        tot_snt += num_snt
        tot_tr += num_tr

        if not dry:
            def _transfer(tup, src, tgt):
                srcm = any_to_mdl_type(mdl_type, src)
                tgtm = any_to_mdl_type(mdl_type, tgt)

                lp = f"{srcm}-{tgtm}"
                inp_snt = tup[src]

                # this "touches" the value: if it was not there, now it is None
                # and if it was there, then we use it
                if tr_dict[lp][inp_snt] is not None:
                    tup[tgt] = tr_dict[lp][inp_snt]

            # collect sentences to translate
            apply_func_to_hires_snts(part['sentences'], _transfer)

            in_tr_dict_list = { lp: sorted(tr_dict[lp].items()) for lp in tr_dict }

            log(f"Translating part {i+1}/{l}", accelerator)
            #translate_cache_dict(tr_dict, mdl_id, module_config, corpus_in, accelerator)
            translate_all_hyps(in_tr_dict_list, module_config, mdl_id, f"{corpus_in}-{i}", accelerator)

            log(f"Collecting part {i+1}/{l}", accelerator)
            out_tr_dict_list = translate_all_hyps(in_tr_dict_list, module_config, mdl_id, corpus_in)

            for lp in out_tr_dict_list:
                for inp, outp in out_tr_dict_list[lp]:
                    tr_dict[lp][inp] = outp

            # put translations back into data structure
            log(f"Integrating part {i+1}/{l}", accelerator)
            apply_func_to_hires_snts(part['sentences'], _transfer)

    log(f"Total sentences: {tot_snt}, total to generate: {tot_tr}", accelerator)
    if not dry:
        log("Saving data", accelerator)
        save_raw_data(corpus_out, data)

if __name__ == '__main__':
    try:
        mdl_id_param = sys.argv[1]
        corpus_param = sys.argv[2]
        corpus_output_param = sys.argv[3]
    except IndexError:
        mdl_id_param = "models/nllb600m"
        corpus_param = "data/flt.json"
        corpus_output_param = "data/fltout.json"

    try:
        _ = sys.argv[4]
        dry_run = True
    except IndexError:
        dry_run = False

    add_hires_synth_data(mdl_id_param, corpus_param, corpus_output_param, dry_run)
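The two-pass trick in _transfer above is subtle: the first apply_func_to_hires_snts call only "touches" the nested defaultdict to register which sentence/language-pair combinations still need a translation, and the second call copies the filled-in values back into the sentence tuples. A standalone sketch of just that mechanism, with no repository dependencies and the actual translation step replaced by a dummy string:

# Standalone sketch of the two-pass defaultdict trick used by _transfer above.
from collections import defaultdict

tr_dict = defaultdict(lambda: defaultdict(lambda: None))
sentences = [{"est": "Tere"}, {"fin": "Hei"}]

for snt in sentences:                      # pass 1: register needed translations
    for lang, text in list(snt.items()):
        _ = tr_dict[f"{lang}-eng"][text]   # touching creates a None placeholder

for lp in tr_dict:                         # stand-in for translate_all_hyps(...)
    for src in tr_dict[lp]:
        tr_dict[lp][src] = f"<{lp}> {src}"

for snt in sentences:                      # pass 2: copy translations back
    for lang, text in list(snt.items()):
        snt["eng"] = tr_dict[f"{lang}-eng"][text]

print(sentences)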
kuidastaltsutadalaamat/legacy/pretok.py
ADDED
@@ -0,0 +1,65 @@
#!/usr/bin/env python3

import os

from data import MultilingualBatchingCachingDataset
from aux import log, CmdlineArgs
from legacy.langconv import lang_set_maybe_smugri
from modelops import to_cpl_spec
from tokops import load_tokenizer

"""
def load_hf_tok(mdl_id, tok_id=None, verbose=False):
    if tok_id is None:
        tok_id = mdl_id

    tokenizer = AutoTokenizer.fromm_pretrained(tok_id, token=hf_tok)

    return tokenizer
"""


def _cmdline_args():
    description = """Pre-tokenize data and cache the results"""

    pos_args = ["mdl_id", "train_file", "langs", "cache_path"]
    pos_types = [str, str, lang_set_maybe_smugri, str]

    kw_args = { "anchor_mdl_id": None, "anchor_langs": None, "batch_size": 16, "shard_size": 100000,
                "exclude_set": None, "max_snt_len": 1024, "sort_by_len": False }

    #post-process the arguments
    args = CmdlineArgs(description, pos_arg_list=pos_args, pos_arg_types=pos_types, kw_arg_dict=kw_args)

    if args.anchor_langs is not None:
        args.anchor_langs = lang_set_maybe_smugri(args.anchor_langs)

    # if the directory args.save_location already exists, raise an exception:
    if os.path.exists(args.cache_path):
        raise Exception(f"Save location '{args.cache_path}' already exists, don't want to overwrite")

    log(f"Launched as {args}")

    return args


def oh_look_another_do_main_function():
    args = _cmdline_args()

    log("loading tokenizer")
    main_tokenizer, main_postok = load_tokenizer(args.mdl_id) #load_hf_tok(args.mdl_id, verbose=True)

    coupling_specs = to_cpl_spec(args.langs, None, main_tokenizer, main_postok, None)

    if args.anchor_mdl_id is not None:
        log("loading anchor model tokenizer")
        anchor_tokenizer, anc_postok = load_tokenizer(args.anchor_mdl_id)

        coupling_specs += to_cpl_spec(args.anchor_langs, None, anchor_tokenizer, anc_postok, None)

    mbd = MultilingualBatchingCachingDataset(args.train_file, coupling_specs, args)
    mbd.load_and_cache_data(args.cache_path)


if __name__ == "__main__":
    oh_look_another_do_main_function()
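An invocation sketch for the pre-tokenization script above, again with hypothetical paths; "smugri" is expanded by lang_set_maybe_smugri, and keyword arguments use the key=value convention of CmdlineArgs:

# Hypothetical invocation of pretok.py via sys.argv patching, as the legacy scripts do.
import sys
import pretok

sys.argv = ["pretok.py", "models/nllb600m", "data/train.json", "smugri",
            "data/train.json-tokcache", "batch_size=16", "max_snt_len=512"]
pretok.oh_look_another_do_main_function()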
kuidastaltsutadalaamat/legacy/testmem.py
ADDED
@@ -0,0 +1,100 @@
#!/usr/bin/env python3

import torch.optim
import sys
import subprocess
import random

from accelerate import Accelerator
from transformers import AutoModelForCausalLM, get_scheduler, AutoModelForSeq2SeqLM
from datasets import load_dataset

from aux import CmdlineArgs, log
from legacy.langconv import is_dec_only_llm
from modelops import report_devices, hf_tok
from tokops import load_tokenizer, tokenizeit


def run_test(mdl_id, batch_sizes, ctxlen, acc):
    #state = AcceleratorState()
    log(f"Num proc: {acc.num_processes}, proc ID: {acc.process_index}")

    report_devices("Initial state:", accelerator=acc)

    t, pt = load_tokenizer(mdl_id) # AutoTokenizer.from_mpretrained(mdl_id, token=hf_tok)
    if is_dec_only_llm(t):
        m = AutoModelForCausalLM.from_pretrained(mdl_id, token=hf_tok, torch_dtype=torch.bfloat16)
        log("Decoder-only model")
    else:
        m = AutoModelForSeq2SeqLM.from_pretrained(mdl_id, token=hf_tok, torch_dtype=torch.bfloat16)
        log("Encoder-decoder model")

    opt = torch.optim.AdamW(m.parameters(), lr=1e-5)
    lrs = get_scheduler("linear", optimizer=opt, num_warmup_steps=100, num_training_steps=1000)
    opt, lrs, m = acc.prepare(opt, lrs, m)

    report_devices("Models in VRAM:", accelerator=acc)
    m.train()

    ds = load_dataset("Helsinki-NLP/europarl", "en-et")
    max_idx = len(ds['train'])

    for batch_size in batch_sizes:
        print("")

        for _ in range(10):
            inp_idx = random.randint(0, max_idx-batch_size)

            raw_inp = [ds['train'][i]['translation']['et'] for i in range(inp_idx, inp_idx+batch_size)]

            if is_dec_only_llm(t):
                inp = tokenizeit((t, pt), raw_inp, ctxlen, is_target=False, is_llm=True)
            else:
                inp = tokenizeit((t, pt), raw_inp, ctxlen, is_target=False, is_llm=False)

            inp['labels'] = inp['input_ids']
            inp.to(m.device)

            outputs = m(**inp)

            loss = outputs.loss
            report_devices(f"While training:", accelerator=acc)
            log(f"Batches : {[inp[k].size() for k in 'input_ids labels attention_mask'.split(' ')]}")
            log(f"Batch total: {sum([inp[k].size()[0] * inp[k].size()[1] for k in 'input_ids labels attention_mask'.split(' ')])}")

            try:
                if acc.is_main_process:
                    result = subprocess.run(['rocm-smi'], capture_output=True, text=True)
                    print(result.stdout)
            except:
                pass

            acc.backward(loss)
            acc.wait_for_everyone()

        report_devices(f"Models gradients in VRAM, batch size {batch_size}:", accelerator=acc)

        print(f"Testing {mdl_id} with batch size {batch_size}: success!")



if __name__ == "__main__":
    if len(sys.argv) > 1:
        args = CmdlineArgs("Test the VRAM usage by a model with different batch sizes, comma-separated",
                           pos_arg_list=["mdl_id", "batch_sizes"],
                           kw_arg_dict={"ctxlen": 2048})

        clean_bs = [int(bs) for bs in args.batch_sizes.split(",")]
        mdl_id = args.mdl_id
        ctxlen = args.ctxlen
    else:
        mdl_id = "meta-llama/Llama-3.2-1B"
        clean_bs = [16, 32, 64]
        ctxlen = 2048

    acc = Accelerator()
    try:
        run_test(mdl_id, clean_bs, ctxlen, acc)
    except Exception as e:
        if acc.is_main_process:
            raise e
kuidastaltsutadalaamat/legacy/tokops.py
ADDED
@@ -0,0 +1,350 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
import os
|
3 |
+
import sentencepiece as spm
|
4 |
+
import json
|
5 |
+
|
6 |
+
from transformers import AutoTokenizer
|
7 |
+
from transformers.models.nllb import NllbTokenizer
|
8 |
+
from transformers.models.t5 import T5Tokenizer
|
9 |
+
from collections import defaultdict
|
10 |
+
|
11 |
+
from aux import log
|
12 |
+
from legacy.langconv import langs_to_madlad, langs_to_nllb, is_nllb, is_madlad, is_dec_only_llm
|
13 |
+
from modelops import hf_tok
|
14 |
+
|
15 |
+
|
16 |
+
def test_tok(tok, snt, lang):
|
17 |
+
tok.src_lang = lang
|
18 |
+
out = tok(text = snt)
|
19 |
+
print(out['input_ids'])
|
20 |
+
print(tok.tokenize(snt))
|
21 |
+
print(tok.convert_ids_to_tokens(out['input_ids']))
|
22 |
+
print("-")
|
23 |
+
|
24 |
+
|
25 |
+
def get_stupid_correction(mdl_id):
|
26 |
+
l_mdl_id = mdl_id.lower()
|
27 |
+
|
28 |
+
if "m2m" in l_mdl_id:
|
29 |
+
correction = 108
|
30 |
+
elif "nllb" in l_mdl_id:
|
31 |
+
correction = 2
|
32 |
+
else:
|
33 |
+
correction = 0
|
34 |
+
|
35 |
+
return correction
|
36 |
+
|
37 |
+
|
38 |
+
def tsv_to_json_vocab(location):
|
39 |
+
new_location = location + ".json"
|
40 |
+
|
41 |
+
with open(location, "r") as f, open(new_location, "w") as w:
|
42 |
+
idx_dict = { "<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3 }
|
43 |
+
|
44 |
+
for line in f:
|
45 |
+
tok, _ = line.strip().split("\t")
|
46 |
+
if tok not in idx_dict:
|
47 |
+
idx_dict[tok] = len(idx_dict)
|
48 |
+
|
49 |
+
json.dump(idx_dict, w)
|
50 |
+
|
51 |
+
return new_location
|
52 |
+
|
53 |
+
|
54 |
+
def get_unk_toks(tokenizer, corpus, verbose=False):
|
55 |
+
unk_id = tokenizer.unk_token_id
|
56 |
+
unk_toks = defaultdict(int)
|
57 |
+
|
58 |
+
all_toks = set()
|
59 |
+
|
60 |
+
total_count = 0
|
61 |
+
unk_count = 0
|
62 |
+
|
63 |
+
with open(corpus, "r", encoding='utf-8') as f:
|
64 |
+
for snt in f:
|
65 |
+
toks = tokenizer.tokenize(snt.strip())
|
66 |
+
ids = tokenizer.convert_tokens_to_ids(toks)
|
67 |
+
|
68 |
+
for t, i in zip(toks, ids):
|
69 |
+
if i == unk_id:
|
70 |
+
unk_toks[t] += 1
|
71 |
+
unk_count += 1
|
72 |
+
total_count += 1
|
73 |
+
|
74 |
+
all_toks.add(t)
|
75 |
+
|
76 |
+
if verbose:
|
77 |
+
print(f"Tokenizer vocab size: {tokenizer.vocab_size}, nr of actually used tokens: {len(all_toks)}")
|
78 |
+
print(f"Corpus token count: {total_count}, UNK token percentage: {100*unk_count/total_count:.2f}%")
|
79 |
+
|
80 |
+
return list(unk_toks)
|
81 |
+
|
82 |
+
|
83 |
+
def get_top_toks(tokenizer, corpus, num_top_toks):
|
84 |
+
freq_count = defaultdict(int)
|
85 |
+
|
86 |
+
with open(corpus, "r", encoding='utf-8') as f:
|
87 |
+
for snt in f:
|
88 |
+
toks = tokenizer.tokenize(snt.strip())
|
89 |
+
|
90 |
+
for t in toks:
|
91 |
+
freq_count[t] += 1
|
92 |
+
|
93 |
+
sorted_freq_count = sorted(freq_count.keys(), key=lambda x: -freq_count[x])
|
94 |
+
|
95 |
+
return sorted_freq_count[:num_top_toks]
|
96 |
+
|
97 |
+
|
98 |
+
def extend_tok_langs(tokenizer, lang_set_raw):
|
99 |
+
if is_nllb(tokenizer):
|
100 |
+
lang_set = langs_to_nllb(lang_set_raw)
|
101 |
+
elif is_madlad(tokenizer):
|
102 |
+
lang_set = langs_to_madlad(lang_set_raw)
|
103 |
+
elif is_dec_only_llm(tokenizer):
|
104 |
+
return
|
105 |
+
else:
|
106 |
+
raise NotImplementedError
|
107 |
+
|
108 |
+
if 'additional_special_tokens' in tokenizer.special_tokens_map:
|
109 |
+
orig_langs = tokenizer.special_tokens_map['additional_special_tokens']
|
110 |
+
orig_lang_set = set(orig_langs)
|
111 |
+
|
112 |
+
addable_langs = list(set(lang_set) - orig_lang_set)
|
113 |
+
else:
|
114 |
+
orig_langs = []
|
115 |
+
addable_langs = lang_set
|
116 |
+
|
117 |
+
tokenizer.add_special_tokens({'additional_special_tokens': orig_langs + addable_langs})
|
118 |
+
|
119 |
+
|
120 |
+
def wrap_tok_in_correct_class(location, base_model_id, lang_set):
|
121 |
+
l_base_mdl_id = base_model_id.lower()
|
122 |
+
|
123 |
+
if "nllb" in l_base_mdl_id:
|
124 |
+
nllb_lang_set = langs_to_nllb(lang_set)
|
125 |
+
return NllbTokenizer(location + ".model", additional_special_tokens=nllb_lang_set)
|
126 |
+
|
127 |
+
elif "madlad" in l_base_mdl_id or "t5" in l_base_mdl_id:
|
128 |
+
madlad_lang_set = langs_to_madlad(lang_set)
|
129 |
+
return T5Tokenizer(location + ".model", additional_special_tokens=madlad_lang_set)
|
130 |
+
else:
|
131 |
+
raise ValueError("Incompatible model type for tokenizer")
|
132 |
+
|
133 |
+
|
134 |
+
def remove_tmp_spm_files(location):
|
135 |
+
for tmp_file in (".vocab", ".model"):
|
136 |
+
os.remove(location + tmp_file)
|
137 |
+
|
138 |
+
|
139 |
+
def learn_spm_tokenizer(corpus, save_location, base_model_id, vocab_size, lang_set=None):
|
140 |
+
tmp_location = os.path.join(save_location, "sentencepiece.bpe.tmp")
|
141 |
+
os.makedirs(save_location, exist_ok=True)
|
142 |
+
|
143 |
+
spm.SentencePieceTrainer.train(input=corpus, model_prefix=tmp_location, vocab_size=vocab_size)
|
144 |
+
|
145 |
+
tok = wrap_tok_in_correct_class(tmp_location, base_model_id, lang_set)
|
146 |
+
|
147 |
+
remove_tmp_spm_files(tmp_location)
|
148 |
+
|
149 |
+
return tok
|
150 |
+
|
151 |
+
|
152 |
+
def do_new_tok(tokargs):
|
153 |
+
correction = get_stupid_correction(tokargs.mdl_id)
|
154 |
+
voc_size = tokargs.vocab_size - correction
|
155 |
+
location = tokargs.save_location
|
156 |
+
|
157 |
+
return learn_spm_tokenizer(tokargs.tok_train_file, location, base_model_id=tokargs.tok_mdl_id,
|
158 |
+
vocab_size=voc_size, lang_set=tokargs.new_langs)
|
159 |
+
|
160 |
+
|
161 |
+
def remove_known_toks(toks, tokenizer):
|
162 |
+
return [t for t in toks if not t in tokenizer.get_vocab()]
|
163 |
+
|
164 |
+
|
165 |
+
def _handle_new_tokenizer(args):
|
166 |
+
assert args.new_langs is not None, "lang_set must be provided"
|
167 |
+
assert args.tok_train_file is not None, "tok_train_file must be provided"
|
168 |
+
args.vocab_size = int(args.vocab_size)
|
169 |
+
|
170 |
+
log("Training new tokenizer")
|
171 |
+
tokenizer = do_new_tok(args)
|
172 |
+
|
173 |
+
return tokenizer
|
174 |
+
|
175 |
+
|
176 |
+
def get_postoken_filename(save_location):
|
177 |
+
return os.path.join(save_location, "postokens.json")
|
178 |
+
|
179 |
+
|
180 |
+
def save_postokens(added_tokens, location):
|
181 |
+
if added_tokens is not None:
|
182 |
+
os.makedirs(location, exist_ok=True)
|
183 |
+
with open(get_postoken_filename(location), "w") as f:
|
184 |
+
json.dump(added_tokens, f)
|
185 |
+
|
186 |
+
|
187 |
+
def _handle_adding_tokens(tokenizer, toks_to_add, args):
|
188 |
+
if len(toks_to_add) == 0:
|
189 |
+
return None
|
190 |
+
|
191 |
+
log(f"Adding tokens: {toks_to_add}")
|
192 |
+
|
193 |
+
base_idx = len(tokenizer)
|
194 |
+
|
195 |
+
added_tok_dict = { t: (base_idx + i) for i, t in enumerate(toks_to_add) }
|
196 |
+
added_tok_rev_dict = { int(i): t for t, i in added_tok_dict.items() }
|
197 |
+
|
198 |
+
comb_dict = { 'tok2idx': added_tok_dict, 'idx2tok': added_tok_rev_dict }
|
199 |
+
|
200 |
+
save_postokens(comb_dict, args.save_location)
|
201 |
+
|
202 |
+
return comb_dict
|
203 |
+
|
204 |
+
|
205 |
+
def _handle_existing_tokenizer(args):
|
206 |
+
log("Reusing existing tokenizer")
|
207 |
+
tokenizer, added_tokens = load_tokenizer(args.tok_mdl_id)
|
208 |
+
|
209 |
+
if args.new_langs is not None:
|
210 |
+
log("Extending existing tokenizer with languages")
|
211 |
+
extend_tok_langs(tokenizer, args.new_langs)
|
212 |
+
|
213 |
+
if args.merge_tokenizers or args.merge_tok_mdl_id:
|
214 |
+
"""
|
215 |
+
assert args.tok_train_file is not None, "For merging tokenizers a text file must be provided" \
|
216 |
+
+ " to find the top N tokens to merge"
|
217 |
+
assert args.merge_tokenizers is not None and args.merge_tok_mdl_id is not None, \
|
218 |
+
"Both merge_tokenizers and merge_tok_mdl_id must be provided"
|
219 |
+
"""
|
220 |
+
raise NotImplementedError("Merging is currently not supported")
|
221 |
+
|
222 |
+
added_tok_count = 0
|
223 |
+
|
224 |
+
if args.tok_train_file:
|
225 |
+
if args.merge_tokenizers:
|
226 |
+
"""
|
227 |
+
merge_tok_max = int(args.merge_tokenizers)
|
228 |
+
log(f"Extending existing tokenizer ({args.merge_tok_mdl_id}) with up to {merge_tok_max} top tokens" +
|
229 |
+
f" from another tokenizer and corpus ({args.tok_train_file})")
|
230 |
+
new_tok = AutoTokenizer.from_pretrained(args.merge_tok_mdl_id, token=hf_tok)
|
231 |
+
toks_to_maybe_add = get_top_toks(new_tok, args.tok_train_file, merge_tok_max)
|
232 |
+
"""
|
233 |
+
raise NotImplementedError("Merging is currently not supported")
|
234 |
+
|
235 |
+
else:
|
236 |
+
log(f"Extending existing tokenizer with UNK tokens from corpus ({args.tok_train_file})")
|
237 |
+
toks_to_maybe_add = get_unk_toks(tokenizer, args.tok_train_file, verbose=True)
|
238 |
+
|
239 |
+
toks_to_add = remove_known_toks(toks_to_maybe_add, tokenizer)
|
240 |
+
added_tok_count = len(toks_to_add)
|
241 |
+
added_tokens = _handle_adding_tokens(tokenizer, toks_to_add, args)
|
242 |
+
|
243 |
+
return tokenizer, added_tok_count, added_tokens
|
244 |
+
|
245 |
+
|
246 |
+
def train_or_extend_tokenizer_and_upd_model(args, model):
|
247 |
+
if hasattr(args, "vocab_size") and args.vocab_size:
|
248 |
+
# train a new sentence-piece tokenizer
|
249 |
+
tokenizer = _handle_new_tokenizer(args)
|
250 |
+
added_tok_count = 0
|
251 |
+
added_dict = None
|
252 |
+
else:
|
253 |
+
# save the pre-trained model's tokenizer, possibly adding new languages and tokens
|
254 |
+
tokenizer, added_tok_count, added_dict = _handle_existing_tokenizer(args)
|
255 |
+
|
256 |
+
upd_amt = get_stupid_correction(args.mdl_id)
|
257 |
+
new_len = len(tokenizer) + added_tok_count
|
258 |
+
|
259 |
+
model.resize_token_embeddings(new_len + upd_amt)
|
260 |
+
|
261 |
+
return tokenizer, added_dict
|
262 |
+
|
263 |
+
|
264 |
+
def load_tokenizer(tok_mdl_id):
|
265 |
+
orig_tokenizer = AutoTokenizer.from_pretrained(tok_mdl_id, token=hf_tok)
|
266 |
+
|
267 |
+
postoken_file = get_postoken_filename(tok_mdl_id)
|
268 |
+
if os.path.exists(postoken_file):
|
269 |
+
with open(postoken_file, "r") as f:
|
270 |
+
postokens = json.load(f)
|
271 |
+
else:
|
272 |
+
postokens = None
|
273 |
+
|
274 |
+
return orig_tokenizer, postokens
|
275 |
+
|
276 |
+
|
277 |
+
def tokenize_batch(tokenizer, sntlist, maxlen=8000):
|
278 |
+
#tokenizer.pad_token = '<|reserved_special_token_0|>'
|
279 |
+
tokenizer.pad_token = tokenizer.eos_token
|
280 |
+
output = tokenizer(sntlist, return_tensors="pt", max_length=maxlen, truncation=True, add_special_tokens=True,
|
281 |
+
padding=True)
|
282 |
+
output["labels"] = output["input_ids"].detach().clone()
|
283 |
+
return output
|
284 |
+
|
285 |
+
"""
|
286 |
+
|
287 |
+
def detokenizeit(toktup, tok_ids):
|
288 |
+
#return toktup[0].decode(tok_ids, skip_special_tokens=True)
|
289 |
+
|
290 |
+
toks = []
|
291 |
+
|
292 |
+
for tok_id_tensor in tok_ids:
|
293 |
+
tok_id = tok_id_tensor.item()
|
294 |
+
try:
|
295 |
+
if tok_id not in toktup[0].added_tokens_decoder:
|
296 |
+
toks.append(toktup[0].convert_ids_to_tokens(tok_id))
|
297 |
+
except IndexError:
|
298 |
+
toks.append(toktup[1]['idx2tok'][str(tok_id)])
|
299 |
+
|
300 |
+
result = "".join(toks).replace("▁", " ")[1:]
|
301 |
+
|
302 |
+
return result, toks
|
303 |
+
|
304 |
+
|
305 |
+
def detokenizemany(toktup, tok_mtx):
|
306 |
+
result = [detokenizeit(toktup, tok_ids)[0] for tok_ids in tok_mtx]
|
307 |
+
|
308 |
+
return result
|
309 |
+
|
310 |
+
|
311 |
+
|
312 |
+
|
313 |
+
|
314 |
+
def run_tokenizer_testing():
|
315 |
+
args = CmdlineArgs("Test a tokenizer: tokenize & de-tokenize some text and check if these match",
|
316 |
+
pos_arg_list=["tok_mdl_id", "txt_file"])
|
317 |
+
|
318 |
+
#tokenizer = AutoTokenizer.fromm_pretrained(args.tok_mdl_id, token=hf_tok) if os.path.exists()
|
319 |
+
toktup = load_tokenizer(args.tok_mdl_id)
|
320 |
+
|
321 |
+
success = 0
|
322 |
+
failure = 0
|
323 |
+
|
324 |
+
with open(args.txt_file, "r", encoding="utf-8") as f:
|
325 |
+
snts = f.read().split("\n")
|
326 |
+
|
327 |
+
toks = tokenizeit(toktup, snts, 1024, False)
|
328 |
+
|
329 |
+
for i, snt in enumerate(snts):
|
330 |
+
tok_ids = toks['input_ids'][i]
|
331 |
+
|
332 |
+
#detoks = toktup[0].decode(tok_ids, skip_special_tokens=True)
|
333 |
+
detoks, tok_strs = detokenizeit(toktup, tok_ids)
|
334 |
+
|
335 |
+
if detoks != snt:
|
336 |
+
failure += 1
|
337 |
+
#log(f"Tokens: {toktup[0].convert_ids_to_tokens(tok_ids)}")
|
338 |
+
log(f"Tokens: {tok_strs}")
|
339 |
+
log(f"Test failed:\n{snt} !=\n{detoks}")
|
340 |
+
else:
|
341 |
+
success += 1
|
342 |
+
i += 1
|
343 |
+
|
344 |
+
log(f"Test result: {success} successful / {failure} failed")
|
345 |
+
|
346 |
+
|
347 |
+
if __name__ == "__main__":
|
348 |
+
sys.argv = ['', 'models/nllbxt', 'data/tok-test.txt']
|
349 |
+
run_tokenizer_testing()
|
350 |
+
"""
|
kuidastaltsutadalaamat/legacy/trainmodel.py
ADDED
@@ -0,0 +1,96 @@
#!/usr/bin/env python3

import os
import torch

from transformers import AutoModelForSeq2SeqLM, AutoModelForCausalLM

from legacy.accel import SwitchingAccelerator
from accelerate import Accelerator
from data import MultilingualDatasetIterator
from aux import log, CmdlineArgs
from legacy.langconv import lang_set_maybe_smugri, is_dec_only_llm
from modelops import mdl_param_count, to_cpl_spec, hf_tok
from tokops import load_tokenizer


def freeze_model(model):
    for n, p in model.named_parameters():
        p.requires_grad = False


def load_hf_mdl_and_tok(mdl_id, tok_id=None, verbose=False):
    if tok_id is None:
        tok_id = mdl_id

    tokenizer = load_tokenizer(tok_id) # AutoTokenizer.fromm_pretrained(tok_id, token=hf_tok)

    if is_dec_only_llm(tokenizer[0]):
        model = AutoModelForCausalLM.from_pretrained(mdl_id, token=hf_tok, torch_dtype=torch.bfloat16)
    else:
        model = AutoModelForSeq2SeqLM.from_pretrained(mdl_id, token=hf_tok, torch_dtype=torch.bfloat16)

    if verbose:
        mdl_size, _ = mdl_param_count(model)
        log(f"Loaded {mdl_id} with {mdl_size} params, voc size {model.config.vocab_size}")

    return model, tokenizer


def _cmdline_args():
    description = """Train or tune models"""

    pos_args = ["mdl_id", "save_location", "train_pretok_file", "langs"]
    pos_types = [str, str, str, lang_set_maybe_smugri]

    kw_args = { "anchor_mdl_id": None, "anchor_langs": None, "continue_training": False,
                "save_steps": 100000, "lr": 1.5e-5, "nr_snts_in_batch": 0, "nr_words_in_batch": 0,
                "log_steps": 100, "epochs": 4 }

    #post-process the arguments
    args = CmdlineArgs(description, pos_arg_list=pos_args, pos_arg_types=pos_types, kw_arg_dict=kw_args)

    if args.anchor_langs is not None:
        args.anchor_langs = lang_set_maybe_smugri(args.anchor_langs)

    if (args.nr_snts_in_batch > 0) == (args.nr_words_in_batch > 0):
        raise Exception(f"Specify the batch size either in words or in sentences.")

    # if the directory args.save_location already exists, raise an exception:
    if not args.continue_training and os.path.exists(args.save_location):
        raise Exception(f"Save location '{args.save_location}' already exists, don't want to overwrite.")

    return args


def yes_i_called_this_function_do_main():
    args = _cmdline_args()
    tmp_acc = Accelerator()

    log(f"Num proc: {tmp_acc.num_processes}, proc ID: {tmp_acc.process_index}")

    log("loading coupled model and tokenizer", accelerator=tmp_acc)
    main_model, main_tokenizer = load_hf_mdl_and_tok(args.mdl_id, verbose=True)

    coupling_specs = to_cpl_spec(args.langs, main_model, main_tokenizer[0], main_tokenizer[1], args.save_location)

    if args.anchor_mdl_id:
        log("loading anchor model and tokenizer", accelerator=tmp_acc)
        anchor_model, anchor_tokenizer = load_hf_mdl_and_tok(args.anchor_mdl_id, verbose=True)
        freeze_model(anchor_model)

        coupling_specs += to_cpl_spec(args.anchor_langs, anchor_model, anchor_tokenizer[0], anchor_tokenizer[1], args.anchor_mdl_id)

    train_set = MultilingualDatasetIterator(args.train_pretok_file)

    acc_trainer = SwitchingAccelerator(coupling_specs, train_set, args)

    upd_model, loss_list = acc_trainer.train()

    #save_all_models(args.save_location, upd_model, main_tokenizer, coupling_specs, loss_list, trainer=acc_trainer.accelerator)


if __name__ == "__main__":
    #sys.argv = ". models/smol models/smol_next data/smugri4a-dev.json-tokcache/thiscache.json smugri log_steps=1 lr=1e-5".split()
    #sys.argv = ". models/llama3.2-1b models/llama-tuned data/smugri4a-dev.json-tokcache/llama.json smugri".split()
    yes_i_called_this_function_do_main()
kuidastaltsutadalaamat/legacy/translate_backup.py
ADDED
@@ -0,0 +1,309 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
"""
|
4 |
+
|
5 |
+
import sys
|
6 |
+
import requests
|
7 |
+
import re
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from aux import CmdlineArgs, log
|
11 |
+
#from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForCausalLM
|
12 |
+
from trainllm import load_hf_tokenizer, load_hf_model
|
13 |
+
from data import do_list_in_batches
|
14 |
+
from modelops import hf_tok, is_gen_ai
|
15 |
+
from collections import defaultdict
|
16 |
+
from langconv import is_nllb, is_madlad, any_to_mdl_type, get_mdl_type, any_to_neurotolge, is_dec_only_llm
|
17 |
+
from tokops import load_tokenizer, tokenizeit, detokenizemany
|
18 |
+
|
19 |
+
|
20 |
+
def prepare_for_translation(provided_inputs, toktup, input_language, output_language=None, device=None):
|
21 |
+
if is_nllb(toktup[0]):
|
22 |
+
toktup[0].src_lang = input_language
|
23 |
+
inputs_to_process = provided_inputs
|
24 |
+
elif is_madlad(toktup[0]):
|
25 |
+
madlad_tgt_lang = output_language
|
26 |
+
inputs_to_process = [f"{madlad_tgt_lang} {inp}" for inp in provided_inputs]
|
27 |
+
else:
|
28 |
+
raise NotImplementedError("Model type not supported")
|
29 |
+
|
30 |
+
prepared_inputs = tokenizeit(toktup, inputs_to_process, 1024, False) #tokenizer(inputs_to_process, return_tensors="pt", padding=True, truncation=True, max_length=512)
|
31 |
+
|
32 |
+
if device is not None:
|
33 |
+
prepared_inputs.to(device)
|
34 |
+
|
35 |
+
frc_bos = toktup[0].get_lang_id(output_language) if output_language is not None else None
|
36 |
+
|
37 |
+
return prepared_inputs, frc_bos
|
38 |
+
|
39 |
+
|
40 |
+
def finalize_translation(outputs, toktup):
|
41 |
+
result = detokenizemany(toktup, outputs) # tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
42 |
+
|
43 |
+
return result
|
44 |
+
|
45 |
+
|
46 |
+
def loadmodel(mdlname="facebook/m2m100_418M", accelerator=None):
|
47 |
+
cl = AutoModelForCausalLM if is_gen_ai(mdlname) else AutoModelForSeq2SeqLM
|
48 |
+
|
49 |
+
if accelerator is not None:
|
50 |
+
model = cl.from_pretrained(mdlname, token=hf_tok, torch_dtype=torch.bfloat16)
|
51 |
+
model = accelerator.prepare(model)
|
52 |
+
else:
|
53 |
+
model = cl.from_pretrained(mdlname, token=hf_tok, torch_dtype=torch.bfloat16, device_map="auto")
|
54 |
+
|
55 |
+
return model
|
56 |
+
|
57 |
+
|
58 |
+
def encode(model, input_batch):
|
59 |
+
model = model.module if hasattr(model, "module") else model
|
60 |
+
|
61 |
+
if is_nllb(model):
|
62 |
+
enc = model.model.encoder
|
63 |
+
elif is_madlad(model):
|
64 |
+
enc = model.base_model.encoder
|
65 |
+
else:
|
66 |
+
raise NotImplementedError(f"Model {model} is not supported yet.")
|
67 |
+
|
68 |
+
inputs_without_labels = { k: input_batch[k] for k in input_batch if k != "labels" }
|
69 |
+
|
70 |
+
return enc(**inputs_without_labels)
|
71 |
+
|
72 |
+
|
73 |
+
def coupled_encode(coupling_specs, lang_to_bin, input_lang, input_texts, debug=False):
|
74 |
+
|
75 |
+
mdl_type = get_mdl_type(coupling_specs[0].model)
|
76 |
+
conv_input_lang = any_to_mdl_type(mdl_type, input_lang)
|
77 |
+
|
78 |
+
this = coupling_specs[lang_to_bin[conv_input_lang]]
|
79 |
+
|
80 |
+
# 0. input text --> input token IDs
|
81 |
+
these_inputs, _ = prepare_for_translation(input_texts, (this.tokenizer, this.postokenizer), conv_input_lang, device=this.model.device)
|
82 |
+
attention_mask = these_inputs["attention_mask"]
|
83 |
+
if debug:
|
84 |
+
for iii in range(len(input_texts)):
|
85 |
+
toklist = []
|
86 |
+
for tok_idx in these_inputs['input_ids'][iii]:
|
87 |
+
try:
|
88 |
+
tok = this.tokenizer.convert_ids_to_tokens([tok_idx])[0]
|
89 |
+
except IndexError:
|
90 |
+
tok = this.postokenizer['idx2tok'][str(tok_idx.item())]
|
91 |
+
toklist.append(tok)
|
92 |
+
print(these_inputs['input_ids'][iii])
|
93 |
+
print(toklist)
|
94 |
+
|
95 |
+
# 1. input token IDs --> encoder vectors
|
96 |
+
#embeddings = this.model.model.encoder(**these_inputs)
|
97 |
+
return encode(this.model, these_inputs), attention_mask
|
98 |
+
|
99 |
+
|
100 |
+
def postproc_llm_output(raw_outputs, tok):
|
101 |
+
eos_id = tok.convert_tokens_to_ids(tok.eos_token)
|
102 |
+
|
103 |
+
for i, _ in enumerate(raw_outputs):
|
104 |
+
repl = None
|
105 |
+
for ii, t in enumerate(raw_outputs[i]):
|
106 |
+
if t.item() == eos_id:
|
107 |
+
repl = eos_id
|
108 |
+
if repl is not None:
|
109 |
+
raw_outputs[i][ii] = repl
|
110 |
+
|
111 |
+
return raw_outputs
|
112 |
+
|
113 |
+
|
114 |
+
def llm_generate(coupling_specs, input_language, output_language, input_texts, debug=False):
    mdl_type = get_mdl_type(coupling_specs[0].model)
    conv_input_lang = any_to_mdl_type(mdl_type, input_language)
    conv_output_lang = any_to_mdl_type(mdl_type, output_language)

    tokenizer = coupling_specs[0].tokenizer

    prep_texts = [make_gen_text(conv_input_lang, conv_output_lang, input_txt, None) for input_txt in input_texts]

    tokenized = tokenizeit((tokenizer, None), prep_texts, 1024, is_target=False, is_llm=True)

    obj = coupling_specs[0].model
    obj = obj.module if hasattr(obj, "module") else obj

    tokenized['input_ids'] = tokenized['input_ids'].to(obj.device)
    tokenized['attention_mask'] = tokenized['attention_mask'].to(obj.device)

    raw_outputs = obj.generate(**tokenized, max_length=1024)  # 1024 assumed here, matching the tokenization limit above

    # 3. output token IDs --> output text
    pre_result = tokenizer.batch_decode(postproc_llm_output(raw_outputs, tokenizer), skip_special_tokens=True)

    result = [raw_out[len(prep_texts[i]):].split("\n")[0] for i, raw_out in enumerate(pre_result)]
    """
    # for i, raw_out in enumerate(pre_result):
    #     print("====")
    #     print(i, raw_out)
    #     print("%%%%")
    #     print(raw_out[len(prep_texts[i])-3:])
    #     print("----")
    """

    return result

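The final list comprehension in llm_generate strips the prompt from the decoded text by character count and keeps only the first generated line; a minimal sketch with a stand-in prompt (the real prompt is built by make_gen_text):

    prompt = "Estonian -> English: Tere!\n"
    decoded = prompt + "Hello!\nand some extra text the model kept generating"
    print(decoded[len(prompt):].split("\n")[0])   # Hello!
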
def coupled_generate(coupling_specs, lang_to_bin, output_lang, encoder_embeddings, att_mask, debug=False):
    mdl_type = get_mdl_type(coupling_specs[0].model)
    conv_output_lang = any_to_mdl_type(mdl_type, output_lang)

    dec_idx = lang_to_bin[conv_output_lang]

    tokenizer = coupling_specs[dec_idx].tokenizer

    # 2. encoder vectors --> output token IDs
    frc_bos = tokenizer.convert_tokens_to_ids(conv_output_lang)
    obj = coupling_specs[dec_idx].model
    obj = obj.module if hasattr(obj, "module") else obj

    raw_outputs = obj.generate(forced_bos_token_id=frc_bos, encoder_outputs=encoder_embeddings, attention_mask=att_mask)
    if debug:
        for rwout in raw_outputs:
            print(rwout)
            print(tokenizer.convert_ids_to_tokens(rwout))

    # 3. output token IDs --> output text
    result = finalize_translation(raw_outputs, (tokenizer, coupling_specs[dec_idx].postokenizer))

    return result


def make_uniq(lang_to_bin):
    result = defaultdict(lambda: 0)

    for lang in lang_to_bin:
        bin_set = lang_to_bin[lang]
        result[lang] = 0 if 0 in bin_set else list(bin_set)[0]

    return result


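A toy illustration of make_uniq, assuming the function above is in scope: any language whose bin set contains 0 is pinned to bin 0, the rest get an arbitrary member of their set:

    print(dict(make_uniq({"et": {0, 1}, "en": {0}, "fi": {1}})))   # {'et': 0, 'en': 0, 'fi': 1}
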
def translate_with_neurotolge(translation_input: str, src_lang: str, tgt_lang: str) -> dict:
    url = "https://api.tartunlp.ai/translation/v2"

    payload = {
        "text": translation_input,
        "src": any_to_neurotolge(src_lang),
        "tgt": any_to_neurotolge(tgt_lang),
        "domain": "general",
        "application": "benchmarking"
    }

    error = None

    for i in range(5):
        try:
            response = requests.post(url, json=payload)
            response.raise_for_status()  # Raise an error for bad status codes
            return response.json()['result']
        except requests.exceptions.RequestException as e:
            error = {"error": str(e)}

    return error


def remove_dia(snt):
    if ">" in snt:
        return re.sub(r'^<[^>]+> ', '', snt)
    else:
        return snt


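remove_dia only strips a leading angle-bracket tag (for example a dialect marker) followed by a space; a quick illustration, assuming the function above is in scope:

    print(remove_dia("<vro> Tere, maailm!"))   # Tere, maailm!
    print(remove_dia("Tere, maailm!"))         # unchanged: Tere, maailm!
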
def neurotolge_in_batches(input_texts, src_lang, tgt_lang):
    neurotolge_langs = {'eng', 'est', 'ger', 'lit', 'lav', 'lvs', 'fin', 'rus', 'ukr', 'kca', 'koi', 'kpv', 'krl', 'lud', 'mdf', 'mhr', 'mns', 'mrj', 'myv', 'olo', 'udm', 'vep', 'liv', 'vro', 'sma', 'sme', 'smn', 'sms', 'smj', 'nor', 'hun'}

    if src_lang in neurotolge_langs and tgt_lang in neurotolge_langs:
        all_outputs = list()

        for inp_batch in do_list_in_batches(input_texts, 8):
            inp_batch_no_dia = [remove_dia(s) for s in inp_batch]
            these_outputs = translate_with_neurotolge(inp_batch_no_dia, src_lang, tgt_lang)
            if len(these_outputs) != len(inp_batch_no_dia):
                raise Exception(f"Something went wrong: {src_lang}/{tgt_lang}/{these_outputs}")
            all_outputs += these_outputs
            log(f"Translated {len(all_outputs)}/{len(input_texts)} sentences")

        return all_outputs
    else:
        return None


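Usage sketch for the Neurotõlge wrapper above (this makes real HTTP requests; "est" and "eng" are codes from the neurotolge_langs set):

    hyps = neurotolge_in_batches(["Tere, maailm!", "Kuidas läheb?"], "est", "eng")
    print(hyps)   # a list of two English translations, or None if either language is unsupported
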
def coupled_translate(coupling_specs, input_texts, input_language, output_language, debug=False):
    lang_to_bin = make_uniq(lang_bin_mapping(coupling_specs))

    all_outputs = list()

    for inp_batch in do_list_in_batches(input_texts, 32):
        if is_dec_only_llm(coupling_specs[0].tokenizer):
            # translate the current batch, not the full input list
            these_outputs = llm_generate(coupling_specs, input_language, output_language, inp_batch, debug=debug)
        else:
            encoder_embeddings, att_mask = coupled_encode(coupling_specs, lang_to_bin, input_language, inp_batch, debug=debug)
            these_outputs = coupled_generate(coupling_specs, lang_to_bin, output_language, encoder_embeddings, att_mask, debug=debug)

        all_outputs += these_outputs

    return all_outputs


def load_and_init_module_config(model_id, accelerator=None):
    config = load_module_config(model_id)

    coupling_specs = list()

    main_model = None

    for i, entry in enumerate(config):
        lang_set = entry["lang_set"]
        model_id = entry["model_id"] if i > 0 else model_id

        log(f"Loading model and tokenizer from '{model_id}'")
        model = loadmodel(model_id, accelerator)
        tokenizer, postok = load_tokenizer(model_id)

        if i == 0:
            main_model = model

        #(langs, model, tokenizer, location):
        coupling_specs += to_cpl_spec(lang_set, model, tokenizer, postok, model_id)

    return main_model, coupling_specs


def _cmdline_args(inputs):
    description = "Translate STDIN text with a translation model"

    pos_args = ["mdl_id", "from_lang", "to_lang"]

    # post-process the arguments
    args = CmdlineArgs(description, pos_args, input_args=inputs, kw_arg_dict={"debug": False})

    log(f"Launched as {args}")

    return args


def and_i_called_this_function_do_main_too(iv):
    args = _cmdline_args(iv)

    inputs = [line.strip() for line in sys.stdin]
    # inputs = ["See on ikka tore uudis.", "Ma ikka katsetaks ka täpitähtedega tõlkimist.", "Mis tähed on täpitähed?"]

    log(f"Inputs: {inputs}")

    main_model, module_config = load_and_init_module_config(args.mdl_id)
    log("Model loaded, starting to translate")
    outputs = coupled_translate(module_config, inputs, args.from_lang, args.to_lang, debug=args.debug)

    print("\n".join(outputs))

    log("Done...")


if __name__ == "__main__":
    input_values = sys.argv[1:] if len(sys.argv) > 1 \
        else ["models/nllb", "et", "en"]

    and_i_called_this_function_do_main_too(input_values)
"""
kuidastaltsutadalaamat/metrics.py
ADDED
@@ -0,0 +1,79 @@
#!/usr/bin/env python3

from data import read_input
from aux import log

import sys
from collections import defaultdict
from evaluate import load as load_metric

SMUGRI_RES = {
    'high': set("Estonian,English,Russian,Finnish,Hungarian,Latvian,German,Swedish,Norwegian,French".split(",")),
    'mid': set("Komi,Komi-Zyrian,Northern Sami,Meadow Mari".split(",")),
    'low': set("Udmurt,Proper Karelian,Southern Sami,Livvi,Veps,Moksha,Erzya,Lule Sami,Võro,Hill Mari,"
               "Komi-Permyak,Inari Sami".split(",")),
    'xlow': set("Ludian,Livonian,Izhorian,Votic,Shur Khanty,Skolt Sami,Meänkieli,"
                "Sred Khanty,Surgut Khanty,Priur Khanty,Vakh Khanty,Unk Khanty,"
                "Pite Sami,Mansi,Kazym Khanty,Kven,Ume Sami,Kildin Sami".split(","))
}


def _gen_lang(lang):
    return lang.split(",")[0]


def _hi_or_lo_lang(lang):
    gen_lang = _gen_lang(lang)

    for k, v in SMUGRI_RES.items():
        if gen_lang in v:
            return k

    log(f"Unrecognized language: {lang} / {gen_lang}")
    return '?'


def _collect_lp_pairs(json_inputs, str_outputs):
    sets_by_lp = defaultdict(list)

    for i, o in zip(json_inputs, str_outputs):
        ref = i["tgt_segm"]
        hyp = o
        det_lp = 'detailed: ' + i["src_lang"] + " -> " + i["tgt_lang"]
        gen_lp = 'general: ' + _gen_lang(i["src_lang"]) + " -> " + _gen_lang(i["tgt_lang"])
        hilo_lp = 'classes: ' + _hi_or_lo_lang(i["src_lang"]) + " -> " + _hi_or_lo_lang(i["tgt_lang"])

        sets_by_lp[det_lp].append((hyp, ref))
        sets_by_lp[gen_lp].append((hyp, ref))
        sets_by_lp[hilo_lp].append((hyp, ref))

    return sets_by_lp


def compute_metrics(json_inputs, str_outputs):
    sets_by_lp = _collect_lp_pairs(json_inputs, str_outputs)

    metric = load_metric("chrf")

    result = []

    for lp in sets_by_lp:
        preds, outputs = zip(*sets_by_lp[lp])
        metric_value = metric.compute(predictions=preds, references=outputs)

        result.append((lp, metric_value, len(preds)))

    return result


def avoid_global_scope():
    json_inputs = read_input(sys.argv[1], "json")
    str_outputs = read_input(sys.argv[2], "json")

    lp_metric_dict = compute_metrics(json_inputs, str_outputs)

    for lp, metric, size in lp_metric_dict:
        print(f"{lp}: {metric['score']:.2f} ({size})")

if __name__ == "__main__":
    avoid_global_scope()
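A standalone sketch of the chrF call that compute_metrics makes through the evaluate library, on toy strings:

    from evaluate import load as load_metric

    chrf = load_metric("chrf")
    out = chrf.compute(predictions=["Tere tulemast koju"], references=["Tere tulemast koju!"])
    print(f"chrF: {out['score']:.2f}")
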
kuidastaltsutadalaamat/promptops.py
ADDED
@@ -0,0 +1,70 @@

# first, keyword identifiers for selecting prompt templates in scripts:

PF_RAW = "raw"
PF_RAWLINES = "rawlines"
PF_SMUGRI_MT = "smugri_mt"
PF_SMUGRI_LID = "smugri_lid"
PF_ALPACA = "alpaca"

# now the prompt templates themselves, SMUGRI LID / MT template:

SMUGRI_INF_PROMPT_LID = "<|reserved_special_token_12|>{src_segm}<|reserved_special_token_13|>"

_SMUGRI_INF_PROMPT_TMPMID = "<|reserved_special_token_14|>{task} to {tgt_lang}<|reserved_special_token_15|>"
SMUGRI_INF_PROMPT_MT = SMUGRI_INF_PROMPT_LID + "{src_lang}" + _SMUGRI_INF_PROMPT_TMPMID

_SMUGRI_TRAIN_PROMPT_PREF = SMUGRI_INF_PROMPT_LID + "{src_lang}"
_SMUGRI_TRAIN_PROMPT_MID = _SMUGRI_INF_PROMPT_TMPMID + "{tgt_segm}"
_SMUGRI_TRAIN_PROMPT_SUF = "<|reserved_special_token_16|><|end_of_text|>"

SMUGRI_PROMPT_TRAIN_PARA = _SMUGRI_TRAIN_PROMPT_PREF + _SMUGRI_TRAIN_PROMPT_MID + _SMUGRI_TRAIN_PROMPT_SUF
SMUGRI_PROMPT_TRAIN_MONO = _SMUGRI_TRAIN_PROMPT_PREF + _SMUGRI_TRAIN_PROMPT_SUF

# Alpaca instructions prompt template:

ALPACA_PROMPT_INF = ("Below is an instruction that describes a task, paired with an input that provides further context. "
                     "Write a response that appropriately completes the request.\n\n"
                     "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n")

ALPACA_PROMPT_TRAIN = (ALPACA_PROMPT_INF + "{output}")


def prep_prompt(data, prompt_format, inference=False):
    if prompt_format in {PF_RAW, PF_RAWLINES}:
        # data is a string, return it
        return data

    elif prompt_format in {PF_SMUGRI_MT, PF_SMUGRI_LID}:
        # data has src_segm, src_lang, tgt_lang, etc.
        return _prep_ljmf_entry(data, prompt_format, inference)

    elif prompt_format == PF_ALPACA:
        # data has instruction and input in it
        return _prep_alpaca_entry(data, inference)

    else:
        raise NotImplementedError(f"Prompt format {prompt_format} is not implemented.")


def _prep_alpaca_entry(entry, inference=False):
    fmt = ALPACA_PROMPT_INF if inference else ALPACA_PROMPT_TRAIN
    prompt = fmt.format(**entry)
    return prompt


def _prep_ljmf_entry(entry, fmt, inference=False):
    if inference:
        if fmt == PF_SMUGRI_MT:
            prompt = SMUGRI_INF_PROMPT_MT.format(**entry)
        elif fmt == PF_SMUGRI_LID:
            prompt = SMUGRI_INF_PROMPT_LID.format(**entry)
        else:
            raise NotImplementedError(f"Prompt format {fmt} is not implemented.")
    else:
        if entry['task'] in {'translate', 'approx-translate'} and entry['tgt_segm'] and entry['tgt_lang']:
            prompt = SMUGRI_PROMPT_TRAIN_PARA.format(**entry)
        else:
            prompt = SMUGRI_PROMPT_TRAIN_MONO.format(**entry)

    return prompt
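What the inference-time MT prompt looks like for a toy entry, assuming prep_prompt and PF_SMUGRI_MT are imported from this module (the field values here are made up):

    entry = {"src_segm": "Tere!", "src_lang": "Estonian", "task": "translate", "tgt_lang": "English"}
    print(prep_prompt(entry, PF_SMUGRI_MT, inference=True))
    # <|reserved_special_token_12|>Tere!<|reserved_special_token_13|>Estonian<|reserved_special_token_14|>translate to English<|reserved_special_token_15|>
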
kuidastaltsutadalaamat/trainllm.py
ADDED
@@ -0,0 +1,252 @@
#!/usr/bin/env python3

from .promptops import PF_SMUGRI_MT
from .aux import log, CmdlineArgs
from .data import load_training_data

import json
import os, socket, torch

from datetime import datetime

from accelerate import Accelerator
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    logging,
    TrainerCallback
)

"""
1/3 This simply reads in command-line arguments
"""

def _cmdline_args():
    description = """Train or tune decoder models"""

    result = CmdlineArgs(description,
                         pos_arg_list=["mdl_id", "save_location", "train_file"],
                         pos_arg_types=[str, str, str],
                         kw_arg_dict={"continue_training": False, "save_steps": 100, "lr": 1.5e-5,
                                      "batch_size": 1024, "nr_sents_per_gpu": 4, "log_steps": 1, "epochs": 4,
                                      "max_length": 2000, "prompt_format": PF_SMUGRI_MT,
                                      "deepspeed": "none"})

    # if the directory args.save_location already exists, raise an exception:
    if not result.continue_training and os.path.exists(result.save_location):
        raise Exception(f"Save location '{result.save_location}' already exists, don't want to overwrite.")

    if result.nr_sents_per_gpu == 0:
        result.nr_sents_per_gpu = result.batch_size

    if result.deepspeed == "none":
        result.deepspeed = None

    return result

"""
|
51 |
+
2/3 This here is used in training in order to report timing and predictions
|
52 |
+
"""
|
53 |
+
|
54 |
+
class StepTimerCallback(TrainerCallback):
|
55 |
+
def __init__(self):
|
56 |
+
self._step_start = None
|
57 |
+
self.lengths = []
|
58 |
+
self.abs_start = datetime.now()
|
59 |
+
|
60 |
+
self.actual_first_step = None
|
61 |
+
|
62 |
+
self.zero = self.abs_start - self.abs_start
|
63 |
+
|
64 |
+
def on_step_begin(self, args, state, control, **kwargs):
|
65 |
+
# called right before each training step
|
66 |
+
self._step_start = datetime.now()
|
67 |
+
|
68 |
+
def on_step_end(self, args, state, control, **kwargs):
|
69 |
+
if self.actual_first_step is None:
|
70 |
+
self.actual_first_step = state.global_step - 1
|
71 |
+
|
72 |
+
# called right after each training step
|
73 |
+
now = datetime.now()
|
74 |
+
elapsed = now - self._step_start
|
75 |
+
tot_elapsed = now - self.abs_start
|
76 |
+
self.lengths.append(elapsed)
|
77 |
+
|
78 |
+
avg = sum(self.lengths, start=self.zero) / len(self.lengths)
|
79 |
+
|
80 |
+
remaining = state.max_steps - self.actual_first_step - state.global_step
|
81 |
+
prediction = (tot_elapsed/(state.global_step - self.actual_first_step)) * remaining
|
82 |
+
|
83 |
+
# you can use logging.get_logger(...) instead of print
|
84 |
+
print(f"[step {state.global_step}/{state.max_steps}] took {elapsed}, avg {avg}; approx {prediction} remaining")
|
85 |
+
|
86 |
+
"""
|
87 |
+
3/3 Finally, the filling of TrainingArguments and the launching of Trainer:
|
88 |
+
"""
|
89 |
+
|
90 |
+
def get_training_args(cmdline_args, acc):
    world_size = acc.num_processes

    assert cmdline_args.batch_size % (cmdline_args.nr_sents_per_gpu * world_size) == 0, \
        "Batch size must be divisible by the number of GPUs and nr of sents per GPU"

    accum_steps = cmdline_args.batch_size // (cmdline_args.nr_sents_per_gpu * world_size)

    log(f"Nr of processes (GPUs): {world_size}, per-device batch: {cmdline_args.nr_sents_per_gpu}, accum. steps: {accum_steps}")

    if cmdline_args.deepspeed is not None:
        with open(cmdline_args.deepspeed, "r") as f:
            dpspd = json.load(f)

        # correct the dictionary with current values, so that we wouldn't need to update the JSON every time
        dpspd['train_batch_size'] = cmdline_args.batch_size
        dpspd['train_micro_batch_size_per_gpu'] = cmdline_args.nr_sents_per_gpu
        dpspd['gradient_accumulation_steps'] = accum_steps

        log(f"Using deepspeed with config {dpspd}")
    else:
        dpspd = None

    tr_args = TrainingArguments(
        output_dir=cmdline_args.save_location,
        per_device_train_batch_size=cmdline_args.nr_sents_per_gpu,
        gradient_accumulation_steps=accum_steps,
        num_train_epochs=cmdline_args.epochs,
        save_steps=cmdline_args.save_steps,
        save_total_limit=10,
        logging_steps=cmdline_args.log_steps,
        deepspeed=dpspd,
        learning_rate=cmdline_args.lr,
        save_strategy="epoch",
        disable_tqdm=True,
        report_to="none",
        # Optional but often helpful on LUMI/ROCm if you enable it in your args:
        bf16=True,
        ddp_find_unused_parameters=False,
        #dataloader_num_workers=1,
        #group_by_length=True,
        log_level="debug",
        #gradient_checkpointing=True,
        #dataloader_persistent_workers=True
    )

    return tr_args


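The accumulation arithmetic in get_training_args, spelled out with the default values from _cmdline_args (batch_size=1024, nr_sents_per_gpu=4) and an assumed 8-GPU world size:

    batch_size, per_gpu, world_size = 1024, 4, 8
    assert batch_size % (per_gpu * world_size) == 0
    accum_steps = batch_size // (per_gpu * world_size)
    print(accum_steps)   # 32 gradient accumulation steps per optimizer update
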
def load_model(mdl_id, device, accelerator=None, attention="flash_attention_2"):
    log(f"Load model", accelerator=accelerator)
    model = AutoModelForCausalLM.from_pretrained(mdl_id,
                                                 low_cpu_mem_usage=False,
                                                 torch_dtype=torch.bfloat16,
                                                 attn_implementation=attention)

    model.config.use_cache = False
    model = model.to(device)
    log(f"Model loaded on device: {model.device}.", accelerator=accelerator)

    return model


def load_tokenizer(mdl_id, accelerator=None):
    log(f"Load tokenizer", accelerator=accelerator)
    tokenizer = AutoTokenizer.from_pretrained(mdl_id)

    # LLaMA 3.x: no pad token by default
    if tokenizer.pad_token is None:
        tokenizer.pad_token = "<|reserved_special_token_100|>"

    return tokenizer


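Why load_tokenizer assigns a concrete pad token: with mlm=False the collator copies input_ids into labels and masks padding positions with -100, so padding needs a real token id. A minimal sketch with made-up ids (not the actual LLaMA vocabulary):

    import torch

    pad_id = 128011                       # assumed id of the reserved pad token
    input_ids = torch.tensor([[17, 52, 9, pad_id]])
    labels = input_ids.clone()
    labels[input_ids == pad_id] = -100    # mirrors what DataCollatorForLanguageModeling does for padding
    print(labels)                         # tensor([[  17,   52,    9, -100]])
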
def simple_train():
    cmd_args = _cmdline_args()
    acc = Accelerator()
    device = acc.device  # it seems that the accelerator loses/changes this info later

    training_args = get_training_args(cmd_args, acc)

    tokenizer = load_tokenizer(cmd_args.mdl_id, acc)
    model = load_model(cmd_args.mdl_id, device, acc)

    if getattr(model.config, "pad_token_id", None) is None:
        model.config.pad_token_id = tokenizer.pad_token_id

    log(f"Load data", accelerator=acc)
    tokenized_train_data = load_training_data(cmd_args.train_file, tokenizer, cmd_args)

    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
        pad_to_multiple_of=8,  # GPT says this helps performance
    )

    log(f"Preparing to train", accelerator=acc)

    clbks = [StepTimerCallback] if acc.is_main_process else []

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train_data,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=clbks,
    )

    logging.set_verbosity_debug()

    log(f"Starting training", accelerator=acc)
    trainer.train(resume_from_checkpoint=cmd_args.continue_training)

    log(f"Done, saving model", accelerator=acc)
    trainer.save_model()


def env_stuff():
    os.environ.setdefault("LOCAL_RANK", os.environ.get("SLURM_LOCALID", "---"))
    os.environ.setdefault("RANK", os.environ.get("SLURM_PROCID", "0"))
    os.environ.setdefault("WORLD_SIZE", os.environ.get("SLURM_NTASKS", "1"))
    os.environ.setdefault("MASTER_ADDR", os.environ.get("SLURM_LAUNCH_NODE_IPADDR", "127.0.0.1"))
    os.environ.setdefault("MASTER_PORT", "29500")  # pick an open port

    # Optional: make sure each process selects its own GPU
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    try:
        log(
            f"host={socket.gethostname()} "
            f"RANK={os.environ['RANK']}/{os.environ['WORLD_SIZE']} "
            f"LOCAL_RANK={os.environ['LOCAL_RANK']} "
            f"HIP_VISIBLE_DEVICES={os.environ.get('HIP_VISIBLE_DEVICES')} "
            f"ROCR_VISIBLE_DEVICES={os.environ.get('ROCR_VISIBLE_DEVICES')} "
            f"cuda_count={torch.cuda.device_count()} curr_dev={torch.cuda.current_device()}"
        )
    except AssertionError:
        log(
            f"host={socket.gethostname()} "
            f"RANK={os.environ['RANK']}/{os.environ['WORLD_SIZE']} "
            f"LOCAL_RANK={os.environ['LOCAL_RANK']} "
            f"HIP_VISIBLE_DEVICES={os.environ.get('HIP_VISIBLE_DEVICES')} "
            f"ROCR_VISIBLE_DEVICES={os.environ.get('ROCR_VISIBLE_DEVICES')} "
            f"no cuda"
        )

"""
This replaces the trainer in order to print out
the final batch when training, and then commit
harakiri. So it is only for temporary,
debugging-related usage.
"""
class LoggingKillingTrainer(Trainer):
    def compute_loss(self, model, inputs, **kwargs):
        log(f"Here is the batch for training: {inputs}")
        raise NotImplementedError
        #return super().compute_loss(model, inputs, **kwargs)

if __name__ == "__main__":
    env_stuff()

    simple_train()