Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +44 -35
- .venv/.gitignore +1 -0
- .venv/.lock +0 -0
- .venv/CACHEDIR.TAG +1 -0
- .venv/CHANGES.rst +76 -0
- .venv/Lib/site-packages/_cffi_backend.cp39-win_amd64.pyd +0 -0
- .venv/Lib/site-packages/_soundfile.py +11 -0
- .venv/Lib/site-packages/_virtualenv.py +101 -0
- .venv/Lib/site-packages/accelerate/__init__.py +50 -0
- .venv/Lib/site-packages/accelerate/accelerator.py +0 -0
- .venv/Lib/site-packages/accelerate/big_modeling.py +637 -0
- .venv/Lib/site-packages/accelerate/checkpointing.py +306 -0
- .venv/Lib/site-packages/accelerate/commands/__pycache__/accelerate_cli.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/accelerate/commands/accelerate_cli.py +52 -0
- .venv/Lib/site-packages/accelerate/commands/config/__init__.py +52 -0
- .venv/Lib/site-packages/accelerate/commands/config/__pycache__/cluster.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/accelerate/commands/config/__pycache__/config.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/accelerate/commands/config/__pycache__/config_args.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/accelerate/commands/config/__pycache__/config_utils.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/accelerate/commands/config/__pycache__/update.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/accelerate/commands/config/config_args.py +252 -0
- .venv/Lib/site-packages/accelerate/commands/config/default.py +142 -0
- .venv/Lib/site-packages/accelerate/commands/config/sagemaker.py +267 -0
- .venv/Lib/site-packages/accelerate/commands/config/update.py +63 -0
- .venv/Lib/site-packages/accelerate/commands/env.py +113 -0
- .venv/Lib/site-packages/accelerate/commands/menu/__pycache__/__init__.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/accelerate/commands/menu/__pycache__/cursor.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/accelerate/commands/menu/__pycache__/input.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/accelerate/commands/menu/__pycache__/keymap.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/accelerate/commands/menu/__pycache__/selection_menu.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/accelerate/data_loader.py +1323 -0
- .venv/Lib/site-packages/accelerate/hooks.py +726 -0
- .venv/Lib/site-packages/accelerate/inference.py +184 -0
- .venv/Lib/site-packages/accelerate/launchers.py +302 -0
- .venv/Lib/site-packages/accelerate/local_sgd.py +104 -0
- .venv/Lib/site-packages/accelerate/logging.py +125 -0
- .venv/Lib/site-packages/accelerate/memory_utils.py +22 -0
- .venv/Lib/site-packages/accelerate/optimizer.py +212 -0
- .venv/Lib/site-packages/accelerate/scheduler.py +98 -0
- .venv/Lib/site-packages/accelerate/state.py +1257 -0
- .venv/Lib/site-packages/accelerate/tracking.py +1023 -0
- .venv/Lib/site-packages/decorator.py +451 -0
- .venv/Lib/site-packages/isympy.py +342 -0
- .venv/Lib/site-packages/mojimoji.cp39-win_amd64.pyd +0 -0
- .venv/Lib/site-packages/numpy-1.26.3-cp39-cp39-win_amd64.whl +0 -0
- .venv/Lib/site-packages/plac.py +37 -0
- .venv/Lib/site-packages/plac_core.py +439 -0
- .venv/Lib/site-packages/plac_ext.py +1205 -0
- .venv/Lib/site-packages/plac_tk.py +64 -0
- .venv/Lib/site-packages/pylab.py +3 -0
.gitattributes
CHANGED
@@ -1,35 +1,44 @@
|
|
1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
Utils/JDC/bst.t7 filter=lfs diff=lfs merge=lfs -text
|
37 |
+
Utils/PLBERT/step_1050000.t7 filter=lfs diff=lfs merge=lfs -text
|
38 |
+
reference_sample_wavs/01008270.wav filter=lfs diff=lfs merge=lfs -text
|
39 |
+
reference_sample_wavs/kaede_san.wav filter=lfs diff=lfs merge=lfs -text
|
40 |
+
reference_sample_wavs/riamu_zeroshot_02.wav filter=lfs diff=lfs merge=lfs -text
|
41 |
+
reference_sample_wavs/sample_ref01.wav filter=lfs diff=lfs merge=lfs -text
|
42 |
+
reference_sample_wavs/sample_ref02.wav filter=lfs diff=lfs merge=lfs -text
|
43 |
+
reference_sample_wavs/shiki_fine05.wav filter=lfs diff=lfs merge=lfs -text
|
44 |
+
reference_sample_wavs/syuukovoice_200918_3_01.wav filter=lfs diff=lfs merge=lfs -text
|
.venv/.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
*
|
.venv/.lock
ADDED
File without changes
|
.venv/CACHEDIR.TAG
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Signature: 8a477f597d28d172789f06886806bc55
|
.venv/CHANGES.rst
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CHANGES
|
2 |
+
=======
|
3 |
+
|
4 |
+
0.4.0 (2024-7-26)
|
5 |
+
-------------------
|
6 |
+
- Add stub files according to PEP 561 for mypy (thanks @ernix)
|
7 |
+
|
8 |
+
0.3.4 (2023-2-18)
|
9 |
+
-------------------
|
10 |
+
- Fix to support Python2.7 ~ 3.4 (thanks @manjuu-eater)
|
11 |
+
- Support Python 3.11
|
12 |
+
|
13 |
+
0.3.3 (2022-12-31)
|
14 |
+
-------------------
|
15 |
+
- Support Python 3.10
|
16 |
+
- Re-support Python2.7 ~ 3.4 (thanks @manjuu-eater)
|
17 |
+
- Fix z2h, h2z all flag off bug (thanks @manjuu-eater)
|
18 |
+
|
19 |
+
0.3.1 (2022-12-14)
|
20 |
+
-------------------
|
21 |
+
- Fix alpha2kana infinite loop bug (thanks @frog42)
|
22 |
+
|
23 |
+
0.3 (2021-03-29)
|
24 |
+
-------------------
|
25 |
+
- Fix bug (alphabet2kana) thanks @Cuddlemuffin007
|
26 |
+
- Support Python 3.8 and 3.9
|
27 |
+
- Add handy functions: alphabet2kata and kata2alphabet. thanks @kokimame
|
28 |
+
- Add function for julius: hiragana2julius
|
29 |
+
|
30 |
+
0.2.4 (2018-02-04)
|
31 |
+
-------------------
|
32 |
+
- Fix bug (kana2alphabet)
|
33 |
+
- Support Python 3.7
|
34 |
+
- No longer support Python 2.6
|
35 |
+
- Add aliases of z2h -> zenkaku2hankaku and h2z -> hankaku2zenkaku
|
36 |
+
|
37 |
+
0.2.3 (2018-02-03)
|
38 |
+
-------------------
|
39 |
+
- Fix bugs (alphabet2kana, kana2alphabet) thanks @letuananh
|
40 |
+
|
41 |
+
0.2.2 (2018-01-22)
|
42 |
+
-------------------
|
43 |
+
- Fix bug (kana2alphabet) thanks @kokimame
|
44 |
+
- Support Python 3.6
|
45 |
+
|
46 |
+
0.2.1 (2017-09-14)
|
47 |
+
-------------------
|
48 |
+
- Fix bugs (alphabet2kana, kana2alphabet)
|
49 |
+
|
50 |
+
0.2 (2015-04-02)
|
51 |
+
------------------
|
52 |
+
|
53 |
+
- Change module name jctconv -> jaconv
|
54 |
+
- Add alphabet and hiragana interconvert (alphabet2kana, kana2alphabet)
|
55 |
+
|
56 |
+
0.1.1 (2015-03-12)
|
57 |
+
------------------
|
58 |
+
|
59 |
+
- Support Windows
|
60 |
+
- Support Python 3.5
|
61 |
+
|
62 |
+
|
63 |
+
0.1 (2014-11-24)
|
64 |
+
------------------
|
65 |
+
|
66 |
+
- Add some Japanese characters to convert table (ゝゞ・「」。、)
|
67 |
+
- Decresing memory usage
|
68 |
+
- Some function names are deprecated (hankaku2zenkaku, zenkaku2hankaku, H2K, H2hK, K2H)
|
69 |
+
|
70 |
+
|
71 |
+
0.0.7 (2014-03-22)
|
72 |
+
------------------
|
73 |
+
|
74 |
+
z2h and h2z allow mojimoji-like target character type determination.
|
75 |
+
Bug fix about Half Kana conversion.
|
76 |
+
|
.venv/Lib/site-packages/_cffi_backend.cp39-win_amd64.pyd
ADDED
Binary file (178 kB). View file
|
|
.venv/Lib/site-packages/_soundfile.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# auto-generated file
|
2 |
+
import _cffi_backend
|
3 |
+
|
4 |
+
ffi = _cffi_backend.FFI('_soundfile',
|
5 |
+
_version = 0x2601,
|
6 |
+
_types = b'\x00\x00\x17\x0D\x00\x00\x6D\x03\x00\x00\x07\x01\x00\x00\x6C\x03\x00\x00\x7A\x03\x00\x00\x00\x0F\x00\x00\x17\x0D\x00\x00\x6F\x03\x00\x00\x07\x01\x00\x00\x03\x11\x00\x00\x00\x0F\x00\x00\x17\x0D\x00\x00\x07\x01\x00\x00\x07\x01\x00\x00\x03\x11\x00\x00\x07\x01\x00\x00\x00\x0F\x00\x00\x17\x0D\x00\x00\x7B\x03\x00\x00\x07\x01\x00\x00\x03\x11\x00\x00\x00\x0F\x00\x00\x07\x0D\x00\x00\x6E\x03\x00\x00\x00\x0F\x00\x00\x07\x0D\x00\x00\x17\x11\x00\x00\x07\x01\x00\x00\x00\x0F\x00\x00\x07\x0D\x00\x00\x07\x01\x00\x00\x00\x0F\x00\x00\x07\x0D\x00\x00\x00\x0F\x00\x00\x02\x0D\x00\x00\x6C\x03\x00\x00\x00\x0F\x00\x00\x02\x0D\x00\x00\x17\x11\x00\x00\x00\x0F\x00\x00\x02\x0D\x00\x00\x17\x11\x00\x00\x6F\x03\x00\x00\x1C\x01\x00\x00\x00\x0F\x00\x00\x02\x0D\x00\x00\x17\x11\x00\x00\x07\x01\x00\x00\x07\x11\x00\x00\x00\x0F\x00\x00\x02\x0D\x00\x00\x17\x11\x00\x00\x07\x01\x00\x00\x04\x11\x00\x00\x07\x01\x00\x00\x00\x0F\x00\x00\x3B\x0D\x00\x00\x17\x11\x00\x00\x70\x03\x00\x00\x17\x01\x00\x00\x00\x0F\x00\x00\x3B\x0D\x00\x00\x17\x11\x00\x00\x74\x03\x00\x00\x17\x01\x00\x00\x00\x0F\x00\x00\x3B\x0D\x00\x00\x17\x11\x00\x00\x02\x03\x00\x00\x17\x01\x00\x00\x00\x0F\x00\x00\x3B\x0D\x00\x00\x17\x11\x00\x00\x17\x01\x00\x00\x07\x01\x00\x00\x00\x0F\x00\x00\x3B\x0D\x00\x00\x17\x11\x00\x00\x79\x03\x00\x00\x17\x01\x00\x00\x00\x0F\x00\x00\x3B\x0D\x00\x00\x17\x11\x00\x00\x04\x11\x00\x00\x17\x01\x00\x00\x00\x0F\x00\x00\x3B\x0D\x00\x00\x17\x01\x00\x00\x07\x01\x00\x00\x04\x11\x00\x00\x00\x0F\x00\x00\x3B\x0D\x00\x00\x04\x11\x00\x00\x00\x0F\x00\x00\x3B\x0D\x00\x00\x04\x11\x00\x00\x17\x01\x00\x00\x04\x11\x00\x00\x00\x0F\x00\x00\x3B\x0D\x00\x00\x7A\x03\x00\x00\x17\x01\x00\x00\x04\x11\x00\x00\x00\x0F\x00\x00\x7A\x0D\x00\x00\x17\x11\x00\x00\x00\x0F\x00\x00\x00\x09\x00\x00\x01\x09\x00\x00\x02\x09\x00\x00\x03\x09\x00\x00\x02\x01\x00\x00\x0E\x01\x00\x00\x00\x0B\x00\x00\x01\x0B\x00\x00\x02\x0B\x00\x00\x0D\x01\x00\x00\x56\x03\x00\x00\x5B\x03\x00\x00\x5E\x03\x00\x00\x63\x03\x00\x00\x05\x01\x00\x00\x00\x01\x00\x00\x10\x01',
|
7 |
+
_globals = (b'\xFF\xFF\xFF\x0BSFC_FILE_TRUNCATE',4224,b'\xFF\xFF\xFF\x0BSFC_GET_FORMAT_INFO',4136,b'\xFF\xFF\xFF\x0BSFC_GET_FORMAT_MAJOR',4145,b'\xFF\xFF\xFF\x0BSFC_GET_FORMAT_MAJOR_COUNT',4144,b'\xFF\xFF\xFF\x0BSFC_GET_FORMAT_SUBTYPE',4147,b'\xFF\xFF\xFF\x0BSFC_GET_FORMAT_SUBTYPE_COUNT',4146,b'\xFF\xFF\xFF\x0BSFC_GET_LIB_VERSION',4096,b'\xFF\xFF\xFF\x0BSFC_GET_LOG_INFO',4097,b'\xFF\xFF\xFF\x0BSFC_SET_CLIPPING',4288,b'\xFF\xFF\xFF\x0BSFC_SET_SCALE_FLOAT_INT_READ',4116,b'\xFF\xFF\xFF\x0BSFC_SET_SCALE_INT_FLOAT_WRITE',4117,b'\xFF\xFF\xFF\x0BSFM_RDWR',48,b'\xFF\xFF\xFF\x0BSFM_READ',16,b'\xFF\xFF\xFF\x0BSFM_WRITE',32,b'\xFF\xFF\xFF\x0BSF_FALSE',0,b'\xFF\xFF\xFF\x0BSF_FORMAT_ENDMASK',805306368,b'\xFF\xFF\xFF\x0BSF_FORMAT_SUBMASK',65535,b'\xFF\xFF\xFF\x0BSF_FORMAT_TYPEMASK',268369920,b'\xFF\xFF\xFF\x0BSF_TRUE',1,b'\x00\x00\x25\x23sf_close',0,b'\x00\x00\x32\x23sf_command',0,b'\x00\x00\x25\x23sf_error',0,b'\x00\x00\x1D\x23sf_error_number',0,b'\x00\x00\x28\x23sf_error_str',0,b'\x00\x00\x22\x23sf_format_check',0,b'\x00\x00\x19\x23sf_get_string',0,b'\x00\x00\x06\x23sf_open',0,b'\x00\x00\x0B\x23sf_open_fd',0,b'\x00\x00\x00\x23sf_open_virtual',0,b'\x00\x00\x25\x23sf_perror',0,b'\x00\x00\x38\x23sf_read_double',0,b'\x00\x00\x3D\x23sf_read_float',0,b'\x00\x00\x42\x23sf_read_int',0,b'\x00\x00\x51\x23sf_read_raw',0,b'\x00\x00\x4C\x23sf_read_short',0,b'\x00\x00\x51\x23sf_readf_double',0,b'\x00\x00\x51\x23sf_readf_float',0,b'\x00\x00\x51\x23sf_readf_int',0,b'\x00\x00\x51\x23sf_readf_short',0,b'\x00\x00\x47\x23sf_seek',0,b'\x00\x00\x2D\x23sf_set_string',0,b'\x00\x00\x16\x23sf_strerror',0,b'\x00\x00\x20\x23sf_version_string',0,b'\x00\x00\x11\x23sf_wchar_open',0,b'\x00\x00\x38\x23sf_write_double',0,b'\x00\x00\x3D\x23sf_write_float',0,b'\x00\x00\x42\x23sf_write_int',0,b'\x00\x00\x51\x23sf_write_raw',0,b'\x00\x00\x4C\x23sf_write_short',0,b'\x00\x00\x68\x23sf_write_sync',0,b'\x00\x00\x51\x23sf_writef_double',0,b'\x00\x00\x51\x23sf_writef_float',0,b'\x00\x00\x51\x23sf_writef_int',0,b'\x00\x00\x51\x23sf_writef_short',0),
|
8 |
+
_struct_unions = ((b'\x00\x00\x00\x6B\x00\x00\x00\x02SF_FORMAT_INFO',b'\x00\x00\x02\x11format',b'\x00\x00\x07\x11name',b'\x00\x00\x07\x11extension'),(b'\x00\x00\x00\x6C\x00\x00\x00\x02SF_INFO',b'\x00\x00\x3B\x11frames',b'\x00\x00\x02\x11samplerate',b'\x00\x00\x02\x11channels',b'\x00\x00\x02\x11format',b'\x00\x00\x02\x11sections',b'\x00\x00\x02\x11seekable'),(b'\x00\x00\x00\x6D\x00\x00\x00\x02SF_VIRTUAL_IO',b'\x00\x00\x76\x11get_filelen',b'\x00\x00\x75\x11seek',b'\x00\x00\x77\x11read',b'\x00\x00\x78\x11write',b'\x00\x00\x76\x11tell'),(b'\x00\x00\x00\x6E\x00\x00\x00\x10SNDFILE_tag',)),
|
9 |
+
_enums = (b'\x00\x00\x00\x71\x00\x00\x00\x16$1\x00SF_FORMAT_SUBMASK,SF_FORMAT_TYPEMASK,SF_FORMAT_ENDMASK',b'\x00\x00\x00\x72\x00\x00\x00\x16$2\x00SFC_GET_LIB_VERSION,SFC_GET_LOG_INFO,SFC_GET_FORMAT_INFO,SFC_GET_FORMAT_MAJOR_COUNT,SFC_GET_FORMAT_MAJOR,SFC_GET_FORMAT_SUBTYPE_COUNT,SFC_GET_FORMAT_SUBTYPE,SFC_FILE_TRUNCATE,SFC_SET_CLIPPING,SFC_SET_SCALE_FLOAT_INT_READ,SFC_SET_SCALE_INT_FLOAT_WRITE',b'\x00\x00\x00\x73\x00\x00\x00\x16$3\x00SF_FALSE,SF_TRUE,SFM_READ,SFM_WRITE,SFM_RDWR'),
|
10 |
+
_typenames = (b'\x00\x00\x00\x6BSF_FORMAT_INFO',b'\x00\x00\x00\x6CSF_INFO',b'\x00\x00\x00\x6DSF_VIRTUAL_IO',b'\x00\x00\x00\x6ESNDFILE',b'\x00\x00\x00\x3Bsf_count_t',b'\x00\x00\x00\x76sf_vio_get_filelen',b'\x00\x00\x00\x77sf_vio_read',b'\x00\x00\x00\x75sf_vio_seek',b'\x00\x00\x00\x76sf_vio_tell',b'\x00\x00\x00\x78sf_vio_write'),
|
11 |
+
)
|
.venv/Lib/site-packages/_virtualenv.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Patches that are applied at runtime to the virtual environment."""
|
2 |
+
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
|
6 |
+
VIRTUALENV_PATCH_FILE = os.path.join(__file__)
|
7 |
+
|
8 |
+
|
9 |
+
def patch_dist(dist):
|
10 |
+
"""
|
11 |
+
Distutils allows user to configure some arguments via a configuration file:
|
12 |
+
https://docs.python.org/3.11/install/index.html#distutils-configuration-files.
|
13 |
+
|
14 |
+
Some of this arguments though don't make sense in context of the virtual environment files, let's fix them up.
|
15 |
+
""" # noqa: D205
|
16 |
+
# we cannot allow some install config as that would get packages installed outside of the virtual environment
|
17 |
+
old_parse_config_files = dist.Distribution.parse_config_files
|
18 |
+
|
19 |
+
def parse_config_files(self, *args, **kwargs):
|
20 |
+
result = old_parse_config_files(self, *args, **kwargs)
|
21 |
+
install = self.get_option_dict("install")
|
22 |
+
|
23 |
+
if "prefix" in install: # the prefix governs where to install the libraries
|
24 |
+
install["prefix"] = VIRTUALENV_PATCH_FILE, os.path.abspath(sys.prefix)
|
25 |
+
for base in ("purelib", "platlib", "headers", "scripts", "data"):
|
26 |
+
key = f"install_{base}"
|
27 |
+
if key in install: # do not allow global configs to hijack venv paths
|
28 |
+
install.pop(key, None)
|
29 |
+
return result
|
30 |
+
|
31 |
+
dist.Distribution.parse_config_files = parse_config_files
|
32 |
+
|
33 |
+
|
34 |
+
# Import hook that patches some modules to ignore configuration values that break package installation in case
|
35 |
+
# of virtual environments.
|
36 |
+
_DISTUTILS_PATCH = "distutils.dist", "setuptools.dist"
|
37 |
+
# https://docs.python.org/3/library/importlib.html#setting-up-an-importer
|
38 |
+
|
39 |
+
|
40 |
+
class _Finder:
|
41 |
+
"""A meta path finder that allows patching the imported distutils modules."""
|
42 |
+
|
43 |
+
fullname = None
|
44 |
+
|
45 |
+
# lock[0] is threading.Lock(), but initialized lazily to avoid importing threading very early at startup,
|
46 |
+
# because there are gevent-based applications that need to be first to import threading by themselves.
|
47 |
+
# See https://github.com/pypa/virtualenv/issues/1895 for details.
|
48 |
+
lock = [] # noqa: RUF012
|
49 |
+
|
50 |
+
def find_spec(self, fullname, path, target=None): # noqa: ARG002
|
51 |
+
if fullname in _DISTUTILS_PATCH and self.fullname is None:
|
52 |
+
# initialize lock[0] lazily
|
53 |
+
if len(self.lock) == 0:
|
54 |
+
import threading
|
55 |
+
|
56 |
+
lock = threading.Lock()
|
57 |
+
# there is possibility that two threads T1 and T2 are simultaneously running into find_spec,
|
58 |
+
# observing .lock as empty, and further going into hereby initialization. However due to the GIL,
|
59 |
+
# list.append() operation is atomic and this way only one of the threads will "win" to put the lock
|
60 |
+
# - that every thread will use - into .lock[0].
|
61 |
+
# https://docs.python.org/3/faq/library.html#what-kinds-of-global-value-mutation-are-thread-safe
|
62 |
+
self.lock.append(lock)
|
63 |
+
|
64 |
+
from functools import partial
|
65 |
+
from importlib.util import find_spec
|
66 |
+
|
67 |
+
with self.lock[0]:
|
68 |
+
self.fullname = fullname
|
69 |
+
try:
|
70 |
+
spec = find_spec(fullname, path)
|
71 |
+
if spec is not None:
|
72 |
+
# https://www.python.org/dev/peps/pep-0451/#how-loading-will-work
|
73 |
+
is_new_api = hasattr(spec.loader, "exec_module")
|
74 |
+
func_name = "exec_module" if is_new_api else "load_module"
|
75 |
+
old = getattr(spec.loader, func_name)
|
76 |
+
func = self.exec_module if is_new_api else self.load_module
|
77 |
+
if old is not func:
|
78 |
+
try: # noqa: SIM105
|
79 |
+
setattr(spec.loader, func_name, partial(func, old))
|
80 |
+
except AttributeError:
|
81 |
+
pass # C-Extension loaders are r/o such as zipimporter with <3.7
|
82 |
+
return spec
|
83 |
+
finally:
|
84 |
+
self.fullname = None
|
85 |
+
return None
|
86 |
+
|
87 |
+
@staticmethod
|
88 |
+
def exec_module(old, module):
|
89 |
+
old(module)
|
90 |
+
if module.__name__ in _DISTUTILS_PATCH:
|
91 |
+
patch_dist(module)
|
92 |
+
|
93 |
+
@staticmethod
|
94 |
+
def load_module(old, name):
|
95 |
+
module = old(name)
|
96 |
+
if module.__name__ in _DISTUTILS_PATCH:
|
97 |
+
patch_dist(module)
|
98 |
+
return module
|
99 |
+
|
100 |
+
|
101 |
+
sys.meta_path.insert(0, _Finder())
|
.venv/Lib/site-packages/accelerate/__init__.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
__version__ = "1.2.0"
|
15 |
+
|
16 |
+
from .accelerator import Accelerator
|
17 |
+
from .big_modeling import (
|
18 |
+
cpu_offload,
|
19 |
+
cpu_offload_with_hook,
|
20 |
+
disk_offload,
|
21 |
+
dispatch_model,
|
22 |
+
init_empty_weights,
|
23 |
+
init_on_device,
|
24 |
+
load_checkpoint_and_dispatch,
|
25 |
+
)
|
26 |
+
from .data_loader import skip_first_batches
|
27 |
+
from .inference import prepare_pippy
|
28 |
+
from .launchers import debug_launcher, notebook_launcher
|
29 |
+
from .state import PartialState
|
30 |
+
from .utils import (
|
31 |
+
AutocastKwargs,
|
32 |
+
DataLoaderConfiguration,
|
33 |
+
DDPCommunicationHookType,
|
34 |
+
DeepSpeedPlugin,
|
35 |
+
DistributedDataParallelKwargs,
|
36 |
+
DistributedType,
|
37 |
+
FullyShardedDataParallelPlugin,
|
38 |
+
GradScalerKwargs,
|
39 |
+
InitProcessGroupKwargs,
|
40 |
+
ProfileKwargs,
|
41 |
+
find_executable_batch_size,
|
42 |
+
infer_auto_device_map,
|
43 |
+
is_rich_available,
|
44 |
+
load_checkpoint_in_model,
|
45 |
+
synchronize_rng_states,
|
46 |
+
)
|
47 |
+
|
48 |
+
|
49 |
+
if is_rich_available():
|
50 |
+
from .utils import rich
|
.venv/Lib/site-packages/accelerate/accelerator.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
.venv/Lib/site-packages/accelerate/big_modeling.py
ADDED
@@ -0,0 +1,637 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import logging
|
16 |
+
import os
|
17 |
+
from contextlib import contextmanager
|
18 |
+
from functools import wraps
|
19 |
+
from typing import Dict, List, Optional, Union
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.nn as nn
|
23 |
+
|
24 |
+
from .hooks import (
|
25 |
+
AlignDevicesHook,
|
26 |
+
CpuOffload,
|
27 |
+
UserCpuOffloadHook,
|
28 |
+
add_hook_to_module,
|
29 |
+
attach_align_device_hook,
|
30 |
+
attach_align_device_hook_on_blocks,
|
31 |
+
)
|
32 |
+
from .utils import (
|
33 |
+
OffloadedWeightsLoader,
|
34 |
+
check_cuda_p2p_ib_support,
|
35 |
+
check_device_map,
|
36 |
+
extract_submodules_state_dict,
|
37 |
+
find_tied_parameters,
|
38 |
+
get_balanced_memory,
|
39 |
+
infer_auto_device_map,
|
40 |
+
is_bnb_available,
|
41 |
+
is_mlu_available,
|
42 |
+
is_musa_available,
|
43 |
+
is_npu_available,
|
44 |
+
is_torch_version,
|
45 |
+
is_xpu_available,
|
46 |
+
load_checkpoint_in_model,
|
47 |
+
offload_state_dict,
|
48 |
+
parse_flag_from_env,
|
49 |
+
retie_parameters,
|
50 |
+
)
|
51 |
+
from .utils.other import recursive_getattr
|
52 |
+
|
53 |
+
|
54 |
+
logger = logging.getLogger(__name__)
|
55 |
+
|
56 |
+
|
57 |
+
@contextmanager
|
58 |
+
def init_empty_weights(include_buffers: bool = None):
|
59 |
+
"""
|
60 |
+
A context manager under which models are initialized with all parameters on the meta device, therefore creating an
|
61 |
+
empty model. Useful when just initializing the model would blow the available RAM.
|
62 |
+
|
63 |
+
Args:
|
64 |
+
include_buffers (`bool`, *optional*):
|
65 |
+
Whether or not to also put all buffers on the meta device while initializing.
|
66 |
+
|
67 |
+
Example:
|
68 |
+
|
69 |
+
```python
|
70 |
+
import torch.nn as nn
|
71 |
+
from accelerate import init_empty_weights
|
72 |
+
|
73 |
+
# Initialize a model with 100 billions parameters in no time and without using any RAM.
|
74 |
+
with init_empty_weights():
|
75 |
+
tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])
|
76 |
+
```
|
77 |
+
|
78 |
+
<Tip warning={true}>
|
79 |
+
|
80 |
+
Any model created under this context manager has no weights. As such you can't do something like
|
81 |
+
`model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`].
|
82 |
+
Make sure to overwrite the default device_map param for [`load_checkpoint_and_dispatch`], otherwise dispatch is not
|
83 |
+
called.
|
84 |
+
|
85 |
+
</Tip>
|
86 |
+
"""
|
87 |
+
if include_buffers is None:
|
88 |
+
include_buffers = parse_flag_from_env("ACCELERATE_INIT_INCLUDE_BUFFERS", False)
|
89 |
+
with init_on_device(torch.device("meta"), include_buffers=include_buffers) as f:
|
90 |
+
yield f
|
91 |
+
|
92 |
+
|
93 |
+
@contextmanager
|
94 |
+
def init_on_device(device: torch.device, include_buffers: bool = None):
|
95 |
+
"""
|
96 |
+
A context manager under which models are initialized with all parameters on the specified device.
|
97 |
+
|
98 |
+
Args:
|
99 |
+
device (`torch.device`):
|
100 |
+
Device to initialize all parameters on.
|
101 |
+
include_buffers (`bool`, *optional*):
|
102 |
+
Whether or not to also put all buffers on the meta device while initializing.
|
103 |
+
|
104 |
+
Example:
|
105 |
+
|
106 |
+
```python
|
107 |
+
import torch.nn as nn
|
108 |
+
from accelerate import init_on_device
|
109 |
+
|
110 |
+
with init_on_device(device=torch.device("cuda")):
|
111 |
+
tst = nn.Linear(100, 100) # on `cuda` device
|
112 |
+
```
|
113 |
+
"""
|
114 |
+
if include_buffers is None:
|
115 |
+
include_buffers = parse_flag_from_env("ACCELERATE_INIT_INCLUDE_BUFFERS", False)
|
116 |
+
|
117 |
+
# TODO(shingjan): remove the torch version check once older versions are deprecated
|
118 |
+
if is_torch_version(">=", "2.0") and include_buffers:
|
119 |
+
with device:
|
120 |
+
yield
|
121 |
+
return
|
122 |
+
|
123 |
+
old_register_parameter = nn.Module.register_parameter
|
124 |
+
if include_buffers:
|
125 |
+
old_register_buffer = nn.Module.register_buffer
|
126 |
+
|
127 |
+
def register_empty_parameter(module, name, param):
|
128 |
+
old_register_parameter(module, name, param)
|
129 |
+
if param is not None:
|
130 |
+
param_cls = type(module._parameters[name])
|
131 |
+
kwargs = module._parameters[name].__dict__
|
132 |
+
kwargs["requires_grad"] = param.requires_grad
|
133 |
+
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
|
134 |
+
|
135 |
+
def register_empty_buffer(module, name, buffer, persistent=True):
|
136 |
+
old_register_buffer(module, name, buffer, persistent=persistent)
|
137 |
+
if buffer is not None:
|
138 |
+
module._buffers[name] = module._buffers[name].to(device)
|
139 |
+
|
140 |
+
# Patch tensor creation
|
141 |
+
if include_buffers:
|
142 |
+
tensor_constructors_to_patch = {
|
143 |
+
torch_function_name: getattr(torch, torch_function_name)
|
144 |
+
for torch_function_name in ["empty", "zeros", "ones", "full"]
|
145 |
+
}
|
146 |
+
else:
|
147 |
+
tensor_constructors_to_patch = {}
|
148 |
+
|
149 |
+
def patch_tensor_constructor(fn):
|
150 |
+
def wrapper(*args, **kwargs):
|
151 |
+
kwargs["device"] = device
|
152 |
+
return fn(*args, **kwargs)
|
153 |
+
|
154 |
+
return wrapper
|
155 |
+
|
156 |
+
try:
|
157 |
+
nn.Module.register_parameter = register_empty_parameter
|
158 |
+
if include_buffers:
|
159 |
+
nn.Module.register_buffer = register_empty_buffer
|
160 |
+
for torch_function_name in tensor_constructors_to_patch.keys():
|
161 |
+
setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
|
162 |
+
yield
|
163 |
+
finally:
|
164 |
+
nn.Module.register_parameter = old_register_parameter
|
165 |
+
if include_buffers:
|
166 |
+
nn.Module.register_buffer = old_register_buffer
|
167 |
+
for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
|
168 |
+
setattr(torch, torch_function_name, old_torch_function)
|
169 |
+
|
170 |
+
|
171 |
+
def cpu_offload(
|
172 |
+
model: nn.Module,
|
173 |
+
execution_device: Optional[torch.device] = None,
|
174 |
+
offload_buffers: bool = False,
|
175 |
+
state_dict: Optional[Dict[str, torch.Tensor]] = None,
|
176 |
+
preload_module_classes: Optional[List[str]] = None,
|
177 |
+
):
|
178 |
+
"""
|
179 |
+
Activates full CPU offload for a model. As a result, all parameters of the model will be offloaded and only one
|
180 |
+
copy of the state dict of the model will be kept. During the forward pass, parameters will be extracted from that
|
181 |
+
state dict and put on the execution device passed as they are needed, then offloaded again.
|
182 |
+
|
183 |
+
Args:
|
184 |
+
model (`torch.nn.Module`):
|
185 |
+
The model to offload.
|
186 |
+
execution_device (`torch.device`, *optional*):
|
187 |
+
The device on which the forward pass of the model will be executed (should be a GPU). Will default to the
|
188 |
+
model first parameter device.
|
189 |
+
offload_buffers (`bool`, *optional*, defaults to `False`):
|
190 |
+
Whether or not to offload the buffers with the model parameters.
|
191 |
+
state_dict (`Dict[str, torch.Tensor]`, *optional*):
|
192 |
+
The state dict of the model that will be kept on CPU.
|
193 |
+
preload_module_classes (`List[str]`, *optional*):
|
194 |
+
A list of classes whose instances should load all their weights (even in the submodules) at the beginning
|
195 |
+
of the forward. This should only be used for classes that have submodules which are registered but not
|
196 |
+
called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
|
197 |
+
`dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
|
198 |
+
"""
|
199 |
+
if execution_device is None:
|
200 |
+
execution_device = next(iter(model.parameters())).device
|
201 |
+
if state_dict is None:
|
202 |
+
state_dict = {n: p.to("cpu") for n, p in model.state_dict().items()}
|
203 |
+
|
204 |
+
add_hook_to_module(model, AlignDevicesHook(io_same_device=True), append=True)
|
205 |
+
attach_align_device_hook(
|
206 |
+
model,
|
207 |
+
execution_device=execution_device,
|
208 |
+
offload=True,
|
209 |
+
offload_buffers=offload_buffers,
|
210 |
+
weights_map=state_dict,
|
211 |
+
preload_module_classes=preload_module_classes,
|
212 |
+
)
|
213 |
+
|
214 |
+
return model
|
215 |
+
|
216 |
+
|
217 |
+
def cpu_offload_with_hook(
|
218 |
+
model: torch.nn.Module,
|
219 |
+
execution_device: Optional[Union[int, str, torch.device]] = None,
|
220 |
+
prev_module_hook: Optional[UserCpuOffloadHook] = None,
|
221 |
+
):
|
222 |
+
"""
|
223 |
+
Offloads a model on the CPU and puts it back to an execution device when executed. The difference with
|
224 |
+
[`cpu_offload`] is that the model stays on the execution device after the forward and is only offloaded again when
|
225 |
+
the `offload` method of the returned `hook` is called. Useful for pipelines running a model in a loop.
|
226 |
+
|
227 |
+
Args:
|
228 |
+
model (`torch.nn.Module`):
|
229 |
+
The model to offload.
|
230 |
+
execution_device(`str`, `int` or `torch.device`, *optional*):
|
231 |
+
The device on which the model should be executed. Will default to the MPS device if it's available, then
|
232 |
+
GPU 0 if there is a GPU, and finally to the CPU.
|
233 |
+
prev_module_hook (`UserCpuOffloadHook`, *optional*):
|
234 |
+
The hook sent back by this function for a previous model in the pipeline you are running. If passed, its
|
235 |
+
offload method will be called just before the forward of the model to which this hook is attached.
|
236 |
+
|
237 |
+
Example:
|
238 |
+
|
239 |
+
```py
|
240 |
+
model_1, hook_1 = cpu_offload_with_hook(model_1, cuda_device)
|
241 |
+
model_2, hook_2 = cpu_offload_with_hook(model_2, cuda_device, prev_module_hook=hook_1)
|
242 |
+
model_3, hook_3 = cpu_offload_with_hook(model_3, cuda_device, prev_module_hook=hook_2)
|
243 |
+
|
244 |
+
hid_1 = model_1(input)
|
245 |
+
for i in range(50):
|
246 |
+
# model1 is offloaded on the CPU at the first iteration, model 2 stays on the GPU for this whole loop.
|
247 |
+
hid_2 = model_2(hid_1)
|
248 |
+
# model2 is offloaded to the CPU just before this forward.
|
249 |
+
hid_3 = model_3(hid_3)
|
250 |
+
|
251 |
+
# For model3, you need to manually call the hook offload method.
|
252 |
+
hook_3.offload()
|
253 |
+
```
|
254 |
+
"""
|
255 |
+
hook = CpuOffload(execution_device=execution_device, prev_module_hook=prev_module_hook)
|
256 |
+
add_hook_to_module(model, hook, append=True)
|
257 |
+
user_hook = UserCpuOffloadHook(model, hook)
|
258 |
+
return model, user_hook
|
259 |
+
|
260 |
+
|
261 |
+
def disk_offload(
|
262 |
+
model: nn.Module,
|
263 |
+
offload_dir: Union[str, os.PathLike],
|
264 |
+
execution_device: Optional[torch.device] = None,
|
265 |
+
offload_buffers: bool = False,
|
266 |
+
preload_module_classes: Optional[List[str]] = None,
|
267 |
+
):
|
268 |
+
"""
|
269 |
+
Activates full disk offload for a model. As a result, all parameters of the model will be offloaded as
|
270 |
+
memory-mapped array in a given folder. During the forward pass, parameters will be accessed from that folder and
|
271 |
+
put on the execution device passed as they are needed, then offloaded again.
|
272 |
+
|
273 |
+
Args:
|
274 |
+
model (`torch.nn.Module`): The model to offload.
|
275 |
+
offload_dir (`str` or `os.PathLike`):
|
276 |
+
The folder in which to offload the model weights (or where the model weights are already offloaded).
|
277 |
+
execution_device (`torch.device`, *optional*):
|
278 |
+
The device on which the forward pass of the model will be executed (should be a GPU). Will default to the
|
279 |
+
model's first parameter device.
|
280 |
+
offload_buffers (`bool`, *optional*, defaults to `False`):
|
281 |
+
Whether or not to offload the buffers with the model parameters.
|
282 |
+
preload_module_classes (`List[str]`, *optional*):
|
283 |
+
A list of classes whose instances should load all their weights (even in the submodules) at the beginning
|
284 |
+
of the forward. This should only be used for classes that have submodules which are registered but not
|
285 |
+
called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
|
286 |
+
`dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
|
287 |
+
"""
|
288 |
+
if not os.path.isdir(offload_dir) or not os.path.isfile(os.path.join(offload_dir, "index.json")):
|
289 |
+
offload_state_dict(offload_dir, model.state_dict())
|
290 |
+
if execution_device is None:
|
291 |
+
execution_device = next(iter(model.parameters())).device
|
292 |
+
weights_map = OffloadedWeightsLoader(save_folder=offload_dir)
|
293 |
+
|
294 |
+
add_hook_to_module(model, AlignDevicesHook(io_same_device=True), append=True)
|
295 |
+
attach_align_device_hook(
|
296 |
+
model,
|
297 |
+
execution_device=execution_device,
|
298 |
+
offload=True,
|
299 |
+
offload_buffers=offload_buffers,
|
300 |
+
weights_map=weights_map,
|
301 |
+
preload_module_classes=preload_module_classes,
|
302 |
+
)
|
303 |
+
|
304 |
+
return model
|
305 |
+
|
306 |
+
|
307 |
+
def dispatch_model(
|
308 |
+
model: nn.Module,
|
309 |
+
device_map: Dict[str, Union[str, int, torch.device]],
|
310 |
+
main_device: Optional[torch.device] = None,
|
311 |
+
state_dict: Optional[Dict[str, torch.Tensor]] = None,
|
312 |
+
offload_dir: Optional[Union[str, os.PathLike]] = None,
|
313 |
+
offload_index: Optional[Dict[str, str]] = None,
|
314 |
+
offload_buffers: bool = False,
|
315 |
+
skip_keys: Optional[Union[str, List[str]]] = None,
|
316 |
+
preload_module_classes: Optional[List[str]] = None,
|
317 |
+
force_hooks: bool = False,
|
318 |
+
):
|
319 |
+
"""
|
320 |
+
Dispatches a model according to a given device map. Layers of the model might be spread across GPUs, offloaded on
|
321 |
+
the CPU or even the disk.
|
322 |
+
|
323 |
+
Args:
|
324 |
+
model (`torch.nn.Module`):
|
325 |
+
The model to dispatch.
|
326 |
+
device_map (`Dict[str, Union[str, int, torch.device]]`):
|
327 |
+
A dictionary mapping module names in the models `state_dict` to the device they should go to. Note that
|
328 |
+
`"disk"` is accepted even if it's not a proper value for `torch.device`.
|
329 |
+
main_device (`str`, `int` or `torch.device`, *optional*):
|
330 |
+
The main execution device. Will default to the first device in the `device_map` different from `"cpu"` or
|
331 |
+
`"disk"`.
|
332 |
+
state_dict (`Dict[str, torch.Tensor]`, *optional*):
|
333 |
+
The state dict of the part of the model that will be kept on CPU.
|
334 |
+
offload_dir (`str` or `os.PathLike`):
|
335 |
+
The folder in which to offload the model weights (or where the model weights are already offloaded).
|
336 |
+
offload_index (`Dict`, *optional*):
|
337 |
+
A dictionary from weight name to their information (`dtype`/ `shape` or safetensors filename). Will default
|
338 |
+
to the index saved in `save_folder`.
|
339 |
+
offload_buffers (`bool`, *optional*, defaults to `False`):
|
340 |
+
Whether or not to offload the buffers with the model parameters.
|
341 |
+
skip_keys (`str` or `List[str]`, *optional*):
|
342 |
+
A list of keys to ignore when moving inputs or outputs between devices.
|
343 |
+
preload_module_classes (`List[str]`, *optional*):
|
344 |
+
A list of classes whose instances should load all their weights (even in the submodules) at the beginning
|
345 |
+
of the forward. This should only be used for classes that have submodules which are registered but not
|
346 |
+
called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
|
347 |
+
`dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
|
348 |
+
force_hooks (`bool`, *optional*, defaults to `False`):
|
349 |
+
Whether or not to force device hooks to be attached to the model even if all layers are dispatched to a
|
350 |
+
single device.
|
351 |
+
"""
|
352 |
+
# Error early if the device map is incomplete.
|
353 |
+
check_device_map(model, device_map)
|
354 |
+
|
355 |
+
# We need to force hook for quantized model that can't be moved with to()
|
356 |
+
if getattr(model, "quantization_method", "bitsandbytes") == "bitsandbytes":
|
357 |
+
# since bnb 0.43.2, we can move 4-bit model
|
358 |
+
if getattr(model, "is_loaded_in_8bit", False) or (
|
359 |
+
getattr(model, "is_loaded_in_4bit", False) and not is_bnb_available(min_version="0.43.2")
|
360 |
+
):
|
361 |
+
force_hooks = True
|
362 |
+
|
363 |
+
# We attach hooks if the device_map has at least 2 different devices or if
|
364 |
+
# force_hooks is set to `True`. Otherwise, the model in already loaded
|
365 |
+
# in the unique device and the user can decide where to dispatch the model.
|
366 |
+
# If the model is quantized, we always force-dispatch the model
|
367 |
+
if (len(set(device_map.values())) > 1) or force_hooks:
|
368 |
+
if main_device is None:
|
369 |
+
if set(device_map.values()) == {"cpu"} or set(device_map.values()) == {"cpu", "disk"}:
|
370 |
+
main_device = "cpu"
|
371 |
+
else:
|
372 |
+
main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0]
|
373 |
+
|
374 |
+
if main_device != "cpu":
|
375 |
+
cpu_modules = [name for name, device in device_map.items() if device == "cpu"]
|
376 |
+
if state_dict is None and len(cpu_modules) > 0:
|
377 |
+
state_dict = extract_submodules_state_dict(model.state_dict(), cpu_modules)
|
378 |
+
|
379 |
+
disk_modules = [name for name, device in device_map.items() if device == "disk"]
|
380 |
+
if offload_dir is None and offload_index is None and len(disk_modules) > 0:
|
381 |
+
raise ValueError(
|
382 |
+
"We need an `offload_dir` to dispatch this model according to this `device_map`, the following submodules "
|
383 |
+
f"need to be offloaded: {', '.join(disk_modules)}."
|
384 |
+
)
|
385 |
+
if (
|
386 |
+
len(disk_modules) > 0
|
387 |
+
and offload_index is None
|
388 |
+
and (not os.path.isdir(offload_dir) or not os.path.isfile(os.path.join(offload_dir, "index.json")))
|
389 |
+
):
|
390 |
+
disk_state_dict = extract_submodules_state_dict(model.state_dict(), disk_modules)
|
391 |
+
offload_state_dict(offload_dir, disk_state_dict)
|
392 |
+
|
393 |
+
execution_device = {
|
394 |
+
name: main_device if device in ["cpu", "disk"] else device for name, device in device_map.items()
|
395 |
+
}
|
396 |
+
execution_device[""] = main_device
|
397 |
+
offloaded_devices = ["disk"] if main_device == "cpu" or main_device == "mps" else ["cpu", "disk"]
|
398 |
+
offload = {name: device in offloaded_devices for name, device in device_map.items()}
|
399 |
+
save_folder = offload_dir if len(disk_modules) > 0 else None
|
400 |
+
if state_dict is not None or save_folder is not None or offload_index is not None:
|
401 |
+
device = main_device if offload_index is not None else None
|
402 |
+
weights_map = OffloadedWeightsLoader(
|
403 |
+
state_dict=state_dict, save_folder=save_folder, index=offload_index, device=device
|
404 |
+
)
|
405 |
+
else:
|
406 |
+
weights_map = None
|
407 |
+
|
408 |
+
# When dispatching the model's parameters to the devices specified in device_map, we want to avoid allocating memory several times for the
|
409 |
+
# tied parameters. The dictionary tied_params_map keeps track of the already allocated data for a given tied parameter (represented by its
|
410 |
+
# original pointer) on each devices.
|
411 |
+
tied_params = find_tied_parameters(model)
|
412 |
+
|
413 |
+
tied_params_map = {}
|
414 |
+
for group in tied_params:
|
415 |
+
for param_name in group:
|
416 |
+
# data_ptr() is enough here, as `find_tied_parameters` finds tied params simply by comparing `param1 is param2`, so we don't need
|
417 |
+
# to care about views of tensors through storage_offset.
|
418 |
+
data_ptr = recursive_getattr(model, param_name).data_ptr()
|
419 |
+
tied_params_map[data_ptr] = {}
|
420 |
+
|
421 |
+
# Note: To handle the disk offloading case, we can not simply use weights_map[param_name].data_ptr() as the reference pointer,
|
422 |
+
# as we have no guarantee that safetensors' `file.get_tensor()` will always give the same pointer.
|
423 |
+
|
424 |
+
attach_align_device_hook_on_blocks(
|
425 |
+
model,
|
426 |
+
execution_device=execution_device,
|
427 |
+
offload=offload,
|
428 |
+
offload_buffers=offload_buffers,
|
429 |
+
weights_map=weights_map,
|
430 |
+
skip_keys=skip_keys,
|
431 |
+
preload_module_classes=preload_module_classes,
|
432 |
+
tied_params_map=tied_params_map,
|
433 |
+
)
|
434 |
+
|
435 |
+
# warn if there is any params on the meta device
|
436 |
+
offloaded_devices_str = " and ".join(
|
437 |
+
[device for device in set(device_map.values()) if device in ("cpu", "disk")]
|
438 |
+
)
|
439 |
+
if len(offloaded_devices_str) > 0:
|
440 |
+
logger.warning(
|
441 |
+
f"Some parameters are on the meta device because they were offloaded to the {offloaded_devices_str}."
|
442 |
+
)
|
443 |
+
|
444 |
+
# Attaching the hook may break tied weights, so we retie them
|
445 |
+
retie_parameters(model, tied_params)
|
446 |
+
|
447 |
+
# add warning to cuda and to method
|
448 |
+
def add_warning(fn, model):
|
449 |
+
@wraps(fn)
|
450 |
+
def wrapper(*args, **kwargs):
|
451 |
+
warning_msg = "You shouldn't move a model that is dispatched using accelerate hooks."
|
452 |
+
if str(fn.__name__) == "to":
|
453 |
+
to_device = torch._C._nn._parse_to(*args, **kwargs)[0]
|
454 |
+
if to_device is not None:
|
455 |
+
logger.warning(warning_msg)
|
456 |
+
else:
|
457 |
+
logger.warning(warning_msg)
|
458 |
+
for param in model.parameters():
|
459 |
+
if param.device == torch.device("meta"):
|
460 |
+
raise RuntimeError("You can't move a model that has some modules offloaded to cpu or disk.")
|
461 |
+
return fn(*args, **kwargs)
|
462 |
+
|
463 |
+
return wrapper
|
464 |
+
|
465 |
+
# Make sure to update _accelerate_added_attributes in hooks.py if you add any hook
|
466 |
+
model.to = add_warning(model.to, model)
|
467 |
+
if is_npu_available():
|
468 |
+
model.npu = add_warning(model.npu, model)
|
469 |
+
elif is_mlu_available():
|
470 |
+
model.mlu = add_warning(model.mlu, model)
|
471 |
+
elif is_musa_available():
|
472 |
+
model.musa = add_warning(model.musa, model)
|
473 |
+
elif is_xpu_available():
|
474 |
+
model.xpu = add_warning(model.xpu, model)
|
475 |
+
else:
|
476 |
+
model.cuda = add_warning(model.cuda, model)
|
477 |
+
|
478 |
+
# Check if we are using multi-gpus with RTX 4000 series
|
479 |
+
use_multi_gpu = len([device for device in set(device_map.values()) if device not in ("cpu", "disk")]) > 1
|
480 |
+
if use_multi_gpu and not check_cuda_p2p_ib_support():
|
481 |
+
logger.warning(
|
482 |
+
"We've detected an older driver with an RTX 4000 series GPU. These drivers have issues with P2P. "
|
483 |
+
"This can affect the multi-gpu inference when using accelerate device_map."
|
484 |
+
"Please make sure to update your driver to the latest version which resolves this."
|
485 |
+
)
|
486 |
+
else:
|
487 |
+
device = list(device_map.values())[0]
|
488 |
+
# `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
|
489 |
+
if is_npu_available() and isinstance(device, int):
|
490 |
+
device = f"npu:{device}"
|
491 |
+
elif is_mlu_available() and isinstance(device, int):
|
492 |
+
device = f"mlu:{device}"
|
493 |
+
elif is_musa_available() and isinstance(device, int):
|
494 |
+
device = f"musa:{device}"
|
495 |
+
elif is_xpu_available() and isinstance(device, int):
|
496 |
+
device = f"xpu:{device}"
|
497 |
+
if device != "disk":
|
498 |
+
model.to(device)
|
499 |
+
else:
|
500 |
+
raise ValueError(
|
501 |
+
"You are trying to offload the whole model to the disk. Please use the `disk_offload` function instead."
|
502 |
+
)
|
503 |
+
# Convert OrderedDict back to dict for easier usage
|
504 |
+
model.hf_device_map = dict(device_map)
|
505 |
+
return model
|
506 |
+
|
507 |
+
|
508 |
+
def load_checkpoint_and_dispatch(
|
509 |
+
model: nn.Module,
|
510 |
+
checkpoint: Union[str, os.PathLike],
|
511 |
+
device_map: Optional[Union[str, Dict[str, Union[int, str, torch.device]]]] = None,
|
512 |
+
max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None,
|
513 |
+
no_split_module_classes: Optional[List[str]] = None,
|
514 |
+
offload_folder: Optional[Union[str, os.PathLike]] = None,
|
515 |
+
offload_buffers: bool = False,
|
516 |
+
dtype: Optional[Union[str, torch.dtype]] = None,
|
517 |
+
offload_state_dict: Optional[bool] = None,
|
518 |
+
skip_keys: Optional[Union[str, List[str]]] = None,
|
519 |
+
preload_module_classes: Optional[List[str]] = None,
|
520 |
+
force_hooks: bool = False,
|
521 |
+
strict: bool = False,
|
522 |
+
):
|
523 |
+
"""
|
524 |
+
Loads a (potentially sharded) checkpoint inside a model, potentially sending weights to a given device as they are
|
525 |
+
loaded and adds the various hooks that will make this model run properly (even if split across devices).
|
526 |
+
|
527 |
+
Args:
|
528 |
+
model (`torch.nn.Module`): The model in which we want to load a checkpoint.
|
529 |
+
checkpoint (`str` or `os.PathLike`):
|
530 |
+
The folder checkpoint to load. It can be:
|
531 |
+
- a path to a file containing a whole model state dict
|
532 |
+
- a path to a `.json` file containing the index to a sharded checkpoint
|
533 |
+
- a path to a folder containing a unique `.index.json` file and the shards of a checkpoint.
|
534 |
+
device_map (`Dict[str, Union[int, str, torch.device]]`, *optional*):
|
535 |
+
A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer
|
536 |
+
name, once a given module name is inside, every submodule of it will be sent to the same device.
|
537 |
+
|
538 |
+
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For more
|
539 |
+
information about each option see [here](../concept_guides/big_model_inference#designing-a-device-map).
|
540 |
+
Defaults to None, which means [`dispatch_model`] will not be called.
|
541 |
+
max_memory (`Dict`, *optional*):
|
542 |
+
A dictionary device identifier to maximum memory. Will default to the maximum memory available for each GPU
|
543 |
+
and the available CPU RAM if unset.
|
544 |
+
no_split_module_classes (`List[str]`, *optional*):
|
545 |
+
            A list of layer class names that should never be split across device (for instance any layer that has a
            residual connection).
        offload_folder (`str` or `os.PathLike`, *optional*):
            If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
        offload_buffers (`bool`, *optional*, defaults to `False`):
            In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as
            well as the parameters.
        dtype (`str` or `torch.dtype`, *optional*):
            If provided, the weights will be converted to that type when loaded.
        offload_state_dict (`bool`, *optional*):
            If `True`, will temporarily offload the CPU state dict on the hard drive to avoid getting out of CPU RAM if
            the weight of the CPU state dict + the biggest shard does not fit. Will default to `True` if the device map
            picked contains `"disk"` values.
        skip_keys (`str` or `List[str]`, *optional*):
            A list of keys to ignore when moving inputs or outputs between devices.
        preload_module_classes (`List[str]`, *optional*):
            A list of classes whose instances should load all their weights (even in the submodules) at the beginning
            of the forward. This should only be used for classes that have submodules which are registered but not
            called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
            `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
        force_hooks (`bool`, *optional*, defaults to `False`):
            Whether or not to force device hooks to be attached to the model even if all layers are dispatched to a
            single device.
        strict (`bool`, *optional*, defaults to `False`):
            Whether to strictly enforce that the keys in the checkpoint state_dict match the keys of the model's
            state_dict.

    Example:

    ```python
    >>> from accelerate import init_empty_weights, load_checkpoint_and_dispatch
    >>> from huggingface_hub import hf_hub_download
    >>> from transformers import AutoConfig, AutoModelForCausalLM

    >>> # Download the Weights
    >>> checkpoint = "EleutherAI/gpt-j-6B"
    >>> weights_location = hf_hub_download(checkpoint, "pytorch_model.bin")

    >>> # Create a model and initialize it with empty weights
    >>> config = AutoConfig.from_pretrained(checkpoint)
    >>> with init_empty_weights():
    ...     model = AutoModelForCausalLM.from_config(config)

    >>> # Load the checkpoint and dispatch it to the right devices
    >>> model = load_checkpoint_and_dispatch(
    ...     model, weights_location, device_map="auto", no_split_module_classes=["GPTJBlock"]
    ... )
    ```
    """
    if isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
        raise ValueError(
            "If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or "
            "'sequential'."
        )
    if isinstance(device_map, str):
        if device_map != "sequential":
            max_memory = get_balanced_memory(
                model,
                max_memory=max_memory,
                no_split_module_classes=no_split_module_classes,
                dtype=dtype,
                low_zero=(device_map == "balanced_low_0"),
            )
        device_map = infer_auto_device_map(
            model,
            max_memory=max_memory,
            no_split_module_classes=no_split_module_classes,
            dtype=dtype,
            offload_buffers=offload_buffers,
        )
    if offload_state_dict is None and device_map is not None and "disk" in device_map.values():
        offload_state_dict = True
    load_checkpoint_in_model(
        model,
        checkpoint,
        device_map=device_map,
        offload_folder=offload_folder,
        dtype=dtype,
        offload_state_dict=offload_state_dict,
        offload_buffers=offload_buffers,
        strict=strict,
    )
    if device_map is None:
        return model
    return dispatch_model(
        model,
        device_map=device_map,
        offload_dir=offload_folder,
        offload_buffers=offload_buffers,
        skip_keys=skip_keys,
        preload_module_classes=preload_module_classes,
        force_hooks=force_hooks,
    )
.venv/Lib/site-packages/accelerate/checkpointing.py
ADDED
@@ -0,0 +1,306 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
from pathlib import Path
from typing import List

import numpy as np
import torch
from safetensors.torch import load_model
from torch.cuda.amp import GradScaler

from .utils import (
    MODEL_NAME,
    OPTIMIZER_NAME,
    RNG_STATE_NAME,
    SAFE_MODEL_NAME,
    SAFE_WEIGHTS_NAME,
    SAMPLER_NAME,
    SCALER_NAME,
    SCHEDULER_NAME,
    WEIGHTS_NAME,
    get_pretty_name,
    is_mlu_available,
    is_torch_xla_available,
    is_xpu_available,
    load,
    save,
)


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

from .logging import get_logger
from .state import PartialState


logger = get_logger(__name__)


def save_accelerator_state(
    output_dir: str,
    model_states: List[dict],
    optimizers: list,
    schedulers: list,
    dataloaders: list,
    process_index: int,
    step: int,
    scaler: GradScaler = None,
    save_on_each_node: bool = False,
    safe_serialization: bool = True,
):
    """
    Saves the current states of the models, optimizers, scaler, and RNG generators to a given directory.

    <Tip>

    If `safe_serialization` is `True`, models will be saved with `safetensors` while the rest are saved using native
    `pickle`.

    </Tip>

    Args:
        output_dir (`str` or `os.PathLike`):
            The name of the folder to save all relevant weights and states.
        model_states (`List[torch.nn.Module]`):
            A list of model states
        optimizers (`List[torch.optim.Optimizer]`):
            A list of optimizer instances
        schedulers (`List[torch.optim.lr_scheduler._LRScheduler]`):
            A list of learning rate schedulers
        dataloaders (`List[torch.utils.data.DataLoader]`):
            A list of dataloader instances to save their sampler states
        process_index (`int`):
            The current process index in the Accelerator state
        step (`int`):
            The current step in the internal step tracker
        scaler (`torch.amp.GradScaler`, *optional*):
            An optional gradient scaler instance to save;
        save_on_each_node (`bool`, *optional*):
            Whether to save on every node, or only the main node.
        safe_serialization (`bool`, *optional*, defaults to `True`):
            Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
    """
    output_dir = Path(output_dir)
    # Model states
    for i, state in enumerate(model_states):
        weights_name = WEIGHTS_NAME if not safe_serialization else SAFE_WEIGHTS_NAME
        if i > 0:
            weights_name = weights_name.replace(".", f"_{i}.")
        output_model_file = output_dir.joinpath(weights_name)
        save(state, output_model_file, save_on_each_node=save_on_each_node, safe_serialization=safe_serialization)
        logger.info(f"Model weights saved in {output_model_file}")
    # Optimizer states
    for i, opt in enumerate(optimizers):
        state = opt.state_dict()
        optimizer_name = f"{OPTIMIZER_NAME}.bin" if i == 0 else f"{OPTIMIZER_NAME}_{i}.bin"
        output_optimizer_file = output_dir.joinpath(optimizer_name)
        save(state, output_optimizer_file, save_on_each_node=save_on_each_node, safe_serialization=False)
        logger.info(f"Optimizer state saved in {output_optimizer_file}")
    # Scheduler states
    for i, scheduler in enumerate(schedulers):
        state = scheduler.state_dict()
        scheduler_name = f"{SCHEDULER_NAME}.bin" if i == 0 else f"{SCHEDULER_NAME}_{i}.bin"
        output_scheduler_file = output_dir.joinpath(scheduler_name)
        save(state, output_scheduler_file, save_on_each_node=save_on_each_node, safe_serialization=False)
        logger.info(f"Scheduler state saved in {output_scheduler_file}")
    # DataLoader states
    for i, dataloader in enumerate(dataloaders):
        sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
        output_sampler_file = output_dir.joinpath(sampler_name)
        # Only save if we have our custom sampler
        from .data_loader import IterableDatasetShard, SeedableRandomSampler

        if isinstance(dataloader.dataset, IterableDatasetShard):
            sampler = dataloader.get_sampler()
            if isinstance(sampler, SeedableRandomSampler):
                save(sampler, output_sampler_file, save_on_each_node=save_on_each_node, safe_serialization=False)
        if getattr(dataloader, "use_stateful_dataloader", False):
            dataloader_state_dict_name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"
            output_dataloader_state_dict_file = output_dir.joinpath(dataloader_state_dict_name)
            state_dict = dataloader.state_dict()
            torch.save(state_dict, output_dataloader_state_dict_file)
        logger.info(f"Sampler state for dataloader {i} saved in {output_sampler_file}")

    # GradScaler state
    if scaler is not None:
        state = scaler.state_dict()
        output_scaler_file = output_dir.joinpath(SCALER_NAME)
        torch.save(state, output_scaler_file)
        logger.info(f"Gradient scaler state saved in {output_scaler_file}")
    # Random number generator states
    states = {}
    states_name = f"{RNG_STATE_NAME}_{process_index}.pkl"
    states["step"] = step
    states["random_state"] = random.getstate()
    states["numpy_random_seed"] = np.random.get_state()
    states["torch_manual_seed"] = torch.get_rng_state()
    if is_xpu_available():
        states["torch_xpu_manual_seed"] = torch.xpu.get_rng_state_all()
    if is_mlu_available():
        states["torch_mlu_manual_seed"] = torch.mlu.get_rng_state_all()
    else:
        states["torch_cuda_manual_seed"] = torch.cuda.get_rng_state_all()
    if is_torch_xla_available():
        states["xm_seed"] = xm.get_rng_state()
    output_states_file = output_dir.joinpath(states_name)
    torch.save(states, output_states_file)
    logger.info(f"Random states saved in {output_states_file}")
    return output_dir


def load_accelerator_state(
    input_dir,
    models,
    optimizers,
    schedulers,
    dataloaders,
    process_index,
    scaler=None,
    map_location=None,
    **load_model_func_kwargs,
):
    """
    Loads states of the models, optimizers, scaler, and RNG generators from a given directory.

    Args:
        input_dir (`str` or `os.PathLike`):
            The name of the folder to load all relevant weights and states.
        models (`List[torch.nn.Module]`):
            A list of model instances
        optimizers (`List[torch.optim.Optimizer]`):
            A list of optimizer instances
        schedulers (`List[torch.optim.lr_scheduler._LRScheduler]`):
            A list of learning rate schedulers
        process_index (`int`):
            The current process index in the Accelerator state
        scaler (`torch.amp.GradScaler`, *optional*):
            An optional *GradScaler* instance to load
        map_location (`str`, *optional*):
            What device to load the optimizer state onto. Should be one of either "cpu" or "on_device".
        load_model_func_kwargs (`dict`, *optional*):
            Additional arguments that can be passed to the model's `load_state_dict` method.

    Returns:
        `dict`: Contains the `Accelerator` attributes to override while loading the state.
    """
    # stores the `Accelerator` attributes to override
    override_attributes = dict()
    if map_location not in [None, "cpu", "on_device"]:
        raise TypeError(
            "Unsupported optimizer map location passed, please choose one of `None`, `'cpu'`, or `'on_device'`"
        )
    if map_location is None:
        map_location = "cpu"
    elif map_location == "on_device":
        map_location = PartialState().device

    input_dir = Path(input_dir)
    # Model states
    for i, model in enumerate(models):
        ending = f"_{i}" if i > 0 else ""
        input_model_file = input_dir.joinpath(f"{SAFE_MODEL_NAME}{ending}.safetensors")
        if input_model_file.exists():
            load_model(model, input_model_file, device=str(map_location), **load_model_func_kwargs)
        else:
            # Load with torch
            input_model_file = input_dir.joinpath(f"{MODEL_NAME}{ending}.bin")
            state_dict = load(input_model_file, map_location=map_location)
            model.load_state_dict(state_dict, **load_model_func_kwargs)
    logger.info("All model weights loaded successfully")

    # Optimizer states
    for i, opt in enumerate(optimizers):
        optimizer_name = f"{OPTIMIZER_NAME}.bin" if i == 0 else f"{OPTIMIZER_NAME}_{i}.bin"
        input_optimizer_file = input_dir.joinpath(optimizer_name)
        optimizer_state = load(input_optimizer_file, map_location=map_location)
        optimizers[i].load_state_dict(optimizer_state)
    logger.info("All optimizer states loaded successfully")

    # Scheduler states
    for i, scheduler in enumerate(schedulers):
        scheduler_name = f"{SCHEDULER_NAME}.bin" if i == 0 else f"{SCHEDULER_NAME}_{i}.bin"
        input_scheduler_file = input_dir.joinpath(scheduler_name)
        scheduler_state = load(input_scheduler_file)
        scheduler.load_state_dict(scheduler_state)
    logger.info("All scheduler states loaded successfully")

    for i, dataloader in enumerate(dataloaders):
        sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
        input_sampler_file = input_dir.joinpath(sampler_name)
        # Only load if we have our custom sampler
        from .data_loader import IterableDatasetShard, SeedableRandomSampler

        if isinstance(dataloader.dataset, IterableDatasetShard):
            sampler = dataloader.get_sampler()
            if isinstance(sampler, SeedableRandomSampler):
                sampler = dataloader.set_sampler(load(input_sampler_file))
        if getattr(dataloader, "use_stateful_dataloader", False):
            dataloader_state_dict_name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"
            input_dataloader_state_dict_file = input_dir.joinpath(dataloader_state_dict_name)
            if input_dataloader_state_dict_file.exists():
                state_dict = load(input_dataloader_state_dict_file)
                dataloader.load_state_dict(state_dict)
    logger.info("All dataloader sampler states loaded successfully")

    # GradScaler state
    if scaler is not None:
        input_scaler_file = input_dir.joinpath(SCALER_NAME)
        scaler_state = load(input_scaler_file)
        scaler.load_state_dict(scaler_state)
        logger.info("GradScaler state loaded successfully")

    # Random states
    try:
        states = load(input_dir.joinpath(f"{RNG_STATE_NAME}_{process_index}.pkl"))
        if "step" in states:
            override_attributes["step"] = states["step"]
        random.setstate(states["random_state"])
        np.random.set_state(states["numpy_random_seed"])
        torch.set_rng_state(states["torch_manual_seed"])
        if is_xpu_available():
            torch.xpu.set_rng_state_all(states["torch_xpu_manual_seed"])
        if is_mlu_available():
            torch.mlu.set_rng_state_all(states["torch_mlu_manual_seed"])
        else:
            torch.cuda.set_rng_state_all(states["torch_cuda_manual_seed"])
        if is_torch_xla_available():
            xm.set_rng_state(states["xm_seed"])
        logger.info("All random states loaded successfully")
    except Exception:
        logger.info("Could not load random states")

    return override_attributes


def save_custom_state(obj, path, index: int = 0, save_on_each_node: bool = False):
    """
    Saves the state of `obj` to `{path}/custom_checkpoint_{index}.pkl`
    """
    # Should this be the right way to get a qual_name type value from `obj`?
    save_location = Path(path) / f"custom_checkpoint_{index}.pkl"
    logger.info(f"Saving the state of {get_pretty_name(obj)} to {save_location}")
    save(obj.state_dict(), save_location, save_on_each_node=save_on_each_node)


def load_custom_state(obj, path, index: int = 0):
    """
    Loads the state of `obj` at `{path}/custom_checkpoint_{index}.pkl`. Will always set `weights_only=False` when
    loading the state.
    """
    load_location = f"{path}/custom_checkpoint_{index}.pkl"
    logger.info(f"Loading the state of {get_pretty_name(obj)} from {load_location}")
    obj.load_state_dict(load(load_location, map_location="cpu", weights_only=False))
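
The `save_custom_state`/`load_custom_state` helpers above accept any object exposing `state_dict`/`load_state_dict`. As a rough, hedged sketch (the `RunningMean` class and the `checkpoint_dir` path are illustrative placeholders, not part of this file):

```python
# Hedged usage sketch; `RunningMean` and "checkpoint_dir" are hypothetical.
from pathlib import Path

from accelerate.checkpointing import load_custom_state, save_custom_state


class RunningMean:
    """Any object with `state_dict`/`load_state_dict` can be checkpointed."""

    def __init__(self):
        self.total, self.count = 0.0, 0

    def state_dict(self):
        return {"total": self.total, "count": self.count}

    def load_state_dict(self, state):
        self.total, self.count = state["total"], state["count"]


Path("checkpoint_dir").mkdir(exist_ok=True)
tracker = RunningMean()
save_custom_state(tracker, "checkpoint_dir", index=0)  # writes checkpoint_dir/custom_checkpoint_0.pkl
load_custom_state(tracker, "checkpoint_dir", index=0)  # restores the state in place
```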
.venv/Lib/site-packages/accelerate/commands/__pycache__/accelerate_cli.cpython-39.pyc
ADDED
Binary file (1.31 kB)
.venv/Lib/site-packages/accelerate/commands/accelerate_cli.py
ADDED
@@ -0,0 +1,52 @@
#!/usr/bin/env python

# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from accelerate.commands.config import get_config_parser
from accelerate.commands.env import env_command_parser
from accelerate.commands.estimate import estimate_command_parser
from accelerate.commands.launch import launch_command_parser
from accelerate.commands.merge import merge_command_parser
from accelerate.commands.test import test_command_parser
from accelerate.commands.tpu import tpu_command_parser
from accelerate.commands.utils import CustomArgumentParser


def main():
    parser = CustomArgumentParser("Accelerate CLI tool", usage="accelerate <command> [<args>]", allow_abbrev=False)
    subparsers = parser.add_subparsers(help="accelerate command helpers")

    # Register commands
    get_config_parser(subparsers=subparsers)
    estimate_command_parser(subparsers=subparsers)
    env_command_parser(subparsers=subparsers)
    launch_command_parser(subparsers=subparsers)
    merge_command_parser(subparsers=subparsers)
    tpu_command_parser(subparsers=subparsers)
    test_command_parser(subparsers=subparsers)

    # Let's go
    args = parser.parse_args()

    if not hasattr(args, "func"):
        parser.print_help()
        exit(1)

    # Run
    args.func(args)


if __name__ == "__main__":
    main()
.venv/Lib/site-packages/accelerate/commands/config/__init__.py
ADDED
@@ -0,0 +1,52 @@
#!/usr/bin/env python

# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

from .config import config_command_parser
from .config_args import default_config_file, load_config_from_file  # noqa: F401
from .default import default_command_parser
from .update import update_command_parser


def get_config_parser(subparsers=None):
    parent_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    # The main config parser
    config_parser = config_command_parser(subparsers)
    # The subparser to add commands to
    subcommands = config_parser.add_subparsers(title="subcommands", dest="subcommand")

    # Then add other parsers with the parent parser
    default_command_parser(subcommands, parents=[parent_parser])
    update_command_parser(subcommands, parents=[parent_parser])

    return config_parser


def main():
    config_parser = get_config_parser()
    args = config_parser.parse_args()

    if not hasattr(args, "func"):
        config_parser.print_help()
        exit(1)

    # Run
    args.func(args)


if __name__ == "__main__":
    main()
.venv/Lib/site-packages/accelerate/commands/config/__pycache__/cluster.cpython-39.pyc
ADDED
Binary file (17.7 kB)
.venv/Lib/site-packages/accelerate/commands/config/__pycache__/config.cpython-39.pyc
ADDED
Binary file (2.43 kB)
.venv/Lib/site-packages/accelerate/commands/config/__pycache__/config_args.cpython-39.pyc
ADDED
Binary file (7.52 kB)
.venv/Lib/site-packages/accelerate/commands/config/__pycache__/config_utils.cpython-39.pyc
ADDED
Binary file (3.05 kB)
.venv/Lib/site-packages/accelerate/commands/config/__pycache__/update.cpython-39.pyc
ADDED
Binary file (1.86 kB)
.venv/Lib/site-packages/accelerate/commands/config/config_args.py
ADDED
@@ -0,0 +1,252 @@
#!/usr/bin/env python

# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Union

import yaml

from ...utils import ComputeEnvironment, DistributedType, SageMakerDistributedType
from ...utils.constants import SAGEMAKER_PYTHON_VERSION, SAGEMAKER_PYTORCH_VERSION, SAGEMAKER_TRANSFORMERS_VERSION


hf_cache_home = os.path.expanduser(
    os.environ.get("HF_HOME", os.path.join(os.environ.get("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
)
cache_dir = os.path.join(hf_cache_home, "accelerate")
default_json_config_file = os.path.join(cache_dir, "default_config.yaml")
default_yaml_config_file = os.path.join(cache_dir, "default_config.yaml")

# For backward compatibility: the default config is the json one if it's the only existing file.
if os.path.isfile(default_yaml_config_file) or not os.path.isfile(default_json_config_file):
    default_config_file = default_yaml_config_file
else:
    default_config_file = default_json_config_file


def load_config_from_file(config_file):
    if config_file is not None:
        if not os.path.isfile(config_file):
            raise FileNotFoundError(
                f"The passed configuration file `{config_file}` does not exist. "
                "Please pass an existing file to `accelerate launch`, or use the default one "
                "created through `accelerate config` and run `accelerate launch` "
                "without the `--config_file` argument."
            )
    else:
        config_file = default_config_file
    with open(config_file, encoding="utf-8") as f:
        if config_file.endswith(".json"):
            if (
                json.load(f).get("compute_environment", ComputeEnvironment.LOCAL_MACHINE)
                == ComputeEnvironment.LOCAL_MACHINE
            ):
                config_class = ClusterConfig
            else:
                config_class = SageMakerConfig
            return config_class.from_json_file(json_file=config_file)
        else:
            if (
                yaml.safe_load(f).get("compute_environment", ComputeEnvironment.LOCAL_MACHINE)
                == ComputeEnvironment.LOCAL_MACHINE
            ):
                config_class = ClusterConfig
            else:
                config_class = SageMakerConfig
            return config_class.from_yaml_file(yaml_file=config_file)


@dataclass
class BaseConfig:
    compute_environment: ComputeEnvironment
    distributed_type: Union[DistributedType, SageMakerDistributedType]
    mixed_precision: str
    use_cpu: bool
    debug: bool

    def to_dict(self):
        result = self.__dict__
        # For serialization, it's best to convert Enums to strings (or their underlying value type).

        def _convert_enums(value):
            if isinstance(value, Enum):
                return value.value
            if isinstance(value, dict):
                if not bool(value):
                    return None
                for key1, value1 in value.items():
                    value[key1] = _convert_enums(value1)
            return value

        for key, value in result.items():
            result[key] = _convert_enums(value)
        result = {k: v for k, v in result.items() if v is not None}
        return result

    @staticmethod
    def process_config(config_dict):
        """
        Processes `config_dict` and sets default values for any missing keys
        """
        if "compute_environment" not in config_dict:
            config_dict["compute_environment"] = ComputeEnvironment.LOCAL_MACHINE
        if "distributed_type" not in config_dict:
            raise ValueError("A `distributed_type` must be specified in the config file.")
        if "num_processes" not in config_dict and config_dict["distributed_type"] == DistributedType.NO:
            config_dict["num_processes"] = 1
        if "mixed_precision" not in config_dict:
            config_dict["mixed_precision"] = "fp16" if ("fp16" in config_dict and config_dict["fp16"]) else None
        if "fp16" in config_dict:  # Convert the config to the new format.
            del config_dict["fp16"]
        if "dynamo_backend" in config_dict:  # Convert the config to the new format.
            dynamo_backend = config_dict.pop("dynamo_backend")
            config_dict["dynamo_config"] = {} if dynamo_backend == "NO" else {"dynamo_backend": dynamo_backend}
        if "use_cpu" not in config_dict:
            config_dict["use_cpu"] = False
        if "debug" not in config_dict:
            config_dict["debug"] = False
        if "enable_cpu_affinity" not in config_dict:
            config_dict["enable_cpu_affinity"] = False
        return config_dict

    @classmethod
    def from_json_file(cls, json_file=None):
        json_file = default_json_config_file if json_file is None else json_file
        with open(json_file, encoding="utf-8") as f:
            config_dict = json.load(f)
        config_dict = cls.process_config(config_dict)
        extra_keys = sorted(set(config_dict.keys()) - set(cls.__dataclass_fields__.keys()))
        if len(extra_keys) > 0:
            raise ValueError(
                f"The config file at {json_file} had unknown keys ({extra_keys}), please try upgrading your `accelerate`"
                " version or fix (and potentially remove) these keys from your config file."
            )

        return cls(**config_dict)

    def to_json_file(self, json_file):
        with open(json_file, "w", encoding="utf-8") as f:
            content = json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
            f.write(content)

    @classmethod
    def from_yaml_file(cls, yaml_file=None):
        yaml_file = default_yaml_config_file if yaml_file is None else yaml_file
        with open(yaml_file, encoding="utf-8") as f:
            config_dict = yaml.safe_load(f)
        config_dict = cls.process_config(config_dict)
        extra_keys = sorted(set(config_dict.keys()) - set(cls.__dataclass_fields__.keys()))
        if len(extra_keys) > 0:
            raise ValueError(
                f"The config file at {yaml_file} had unknown keys ({extra_keys}), please try upgrading your `accelerate`"
                " version or fix (and potentially remove) these keys from your config file."
            )
        return cls(**config_dict)

    def to_yaml_file(self, yaml_file):
        with open(yaml_file, "w", encoding="utf-8") as f:
            yaml.safe_dump(self.to_dict(), f)

    def __post_init__(self):
        if isinstance(self.compute_environment, str):
            self.compute_environment = ComputeEnvironment(self.compute_environment)
        if isinstance(self.distributed_type, str):
            if self.compute_environment == ComputeEnvironment.AMAZON_SAGEMAKER:
                self.distributed_type = SageMakerDistributedType(self.distributed_type)
            else:
                self.distributed_type = DistributedType(self.distributed_type)
        if getattr(self, "dynamo_config", None) is None:
            self.dynamo_config = {}


@dataclass
class ClusterConfig(BaseConfig):
    num_processes: int = -1  # For instance if we use SLURM and the user manually passes it in
    machine_rank: int = 0
    num_machines: int = 1
    gpu_ids: Optional[str] = None
    main_process_ip: Optional[str] = None
    main_process_port: Optional[int] = None
    rdzv_backend: Optional[str] = "static"
    same_network: Optional[bool] = False
    main_training_function: str = "main"
    enable_cpu_affinity: bool = False

    # args for FP8 training
    fp8_config: dict = None
    # args for deepspeed_plugin
    deepspeed_config: dict = None
    # args for fsdp
    fsdp_config: dict = None
    # args for megatron_lm
    megatron_lm_config: dict = None
    # args for ipex
    ipex_config: dict = None
    # args for mpirun
    mpirun_config: dict = None
    # args for TPU
    downcast_bf16: bool = False

    # args for TPU pods
    tpu_name: str = None
    tpu_zone: str = None
    tpu_use_cluster: bool = False
    tpu_use_sudo: bool = False
    command_file: str = None
    commands: List[str] = None
    tpu_vm: List[str] = None
    tpu_env: List[str] = None

    # args for dynamo
    dynamo_config: dict = None

    def __post_init__(self):
        if self.deepspeed_config is None:
            self.deepspeed_config = {}
        if self.fsdp_config is None:
            self.fsdp_config = {}
        if self.megatron_lm_config is None:
            self.megatron_lm_config = {}
        if self.ipex_config is None:
            self.ipex_config = {}
        if self.mpirun_config is None:
            self.mpirun_config = {}
        if self.fp8_config is None:
            self.fp8_config = {}
        return super().__post_init__()


@dataclass
class SageMakerConfig(BaseConfig):
    ec2_instance_type: str
    iam_role_name: str
    image_uri: Optional[str] = None
    profile: Optional[str] = None
    region: str = "us-east-1"
    num_machines: int = 1
    gpu_ids: str = "all"
    base_job_name: str = f"accelerate-sagemaker-{num_machines}"
    pytorch_version: str = SAGEMAKER_PYTORCH_VERSION
    transformers_version: str = SAGEMAKER_TRANSFORMERS_VERSION
    py_version: str = SAGEMAKER_PYTHON_VERSION
    sagemaker_inputs_file: str = None
    sagemaker_metrics_file: str = None
    additional_args: dict = None
    dynamo_config: dict = None
    enable_cpu_affinity: bool = False
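
As a brief, hedged sketch of how the dataclasses above are typically consumed (it assumes `accelerate config` has already written a default config; the backup filename is an arbitrary example):

```python
# Hedged sketch: load whichever default config exists and write a copy of it.
from accelerate.commands.config.config_args import load_config_from_file

config = load_config_from_file(None)  # falls back to `default_config_file`
print(config.compute_environment, config.distributed_type, config.mixed_precision)
config.to_yaml_file("default_config_backup.yaml")  # round-trip via BaseConfig.to_yaml_file
```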
.venv/Lib/site-packages/accelerate/commands/config/default.py
ADDED
@@ -0,0 +1,142 @@
#!/usr/bin/env python

# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path

import torch

from ...utils import is_mlu_available, is_musa_available, is_npu_available, is_xpu_available
from .config_args import ClusterConfig, default_json_config_file
from .config_utils import SubcommandHelpFormatter


description = "Create a default config file for Accelerate with only a few flags set."


def write_basic_config(mixed_precision="no", save_location: str = default_json_config_file, use_xpu: bool = False):
    """
    Creates and saves a basic cluster config to be used on a local machine with potentially multiple GPUs. Will also
    set CPU if it is a CPU-only machine.

    Args:
        mixed_precision (`str`, *optional*, defaults to "no"):
            Mixed Precision to use. Should be one of "no", "fp16", or "bf16"
        save_location (`str`, *optional*, defaults to `default_json_config_file`):
            Optional custom save location. Should be passed to `--config_file` when using `accelerate launch`. Default
            location is inside the huggingface cache folder (`~/.cache/huggingface`) but can be overriden by setting
            the `HF_HOME` environmental variable, followed by `accelerate/default_config.yaml`.
        use_xpu (`bool`, *optional*, defaults to `False`):
            Whether to use XPU if available.
    """
    path = Path(save_location)
    path.parent.mkdir(parents=True, exist_ok=True)
    if path.exists():
        print(
            f"Configuration already exists at {save_location}, will not override. Run `accelerate config` manually or pass a different `save_location`."
        )
        return False
    mixed_precision = mixed_precision.lower()
    if mixed_precision not in ["no", "fp16", "bf16", "fp8"]:
        raise ValueError(
            f"`mixed_precision` should be one of 'no', 'fp16', 'bf16', or 'fp8'. Received {mixed_precision}"
        )
    config = {
        "compute_environment": "LOCAL_MACHINE",
        "mixed_precision": mixed_precision,
    }
    if is_mlu_available():
        num_mlus = torch.mlu.device_count()
        config["num_processes"] = num_mlus
        config["use_cpu"] = False
        if num_mlus > 1:
            config["distributed_type"] = "MULTI_MLU"
        else:
            config["distributed_type"] = "NO"
    elif is_musa_available():
        num_musas = torch.musa.device_count()
        config["num_processes"] = num_musas
        config["use_cpu"] = False
        if num_musas > 1:
            config["distributed_type"] = "MULTI_MUSA"
        else:
            config["distributed_type"] = "NO"
    elif torch.cuda.is_available():
        num_gpus = torch.cuda.device_count()
        config["num_processes"] = num_gpus
        config["use_cpu"] = False
        if num_gpus > 1:
            config["distributed_type"] = "MULTI_GPU"
        else:
            config["distributed_type"] = "NO"
    elif is_xpu_available() and use_xpu:
        num_xpus = torch.xpu.device_count()
        config["num_processes"] = num_xpus
        config["use_cpu"] = False
        if num_xpus > 1:
            config["distributed_type"] = "MULTI_XPU"
        else:
            config["distributed_type"] = "NO"
    elif is_npu_available():
        num_npus = torch.npu.device_count()
        config["num_processes"] = num_npus
        config["use_cpu"] = False
        if num_npus > 1:
            config["distributed_type"] = "MULTI_NPU"
        else:
            config["distributed_type"] = "NO"
    else:
        num_xpus = 0
        config["use_cpu"] = True
        config["num_processes"] = 1
        config["distributed_type"] = "NO"
    config["debug"] = False
    config["enable_cpu_affinity"] = False
    config = ClusterConfig(**config)
    config.to_json_file(path)
    return path


def default_command_parser(parser, parents):
    parser = parser.add_parser("default", parents=parents, help=description, formatter_class=SubcommandHelpFormatter)
    parser.add_argument(
        "--config_file",
        default=default_json_config_file,
        help=(
            "The path to use to store the config file. Will default to a file named default_config.yaml in the cache "
            "location, which is the content of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have "
            "such an environment variable, your cache directory ('~/.cache' or the content of `XDG_CACHE_HOME`) suffixed "
            "with 'huggingface'."
        ),
        dest="save_location",
    )

    parser.add_argument(
        "--mixed_precision",
        choices=["no", "fp16", "bf16"],
        type=str,
        help="Whether or not to use mixed precision training. "
        "Choose between FP16 and BF16 (bfloat16) training. "
        "BF16 training is only supported on Nvidia Ampere GPUs and PyTorch 1.10 or later.",
        default="no",
    )
    parser.set_defaults(func=default_config_command)
    return parser


def default_config_command(args):
    config_file = write_basic_config(args.mixed_precision, args.save_location)
    if config_file:
        print(f"accelerate configuration saved at {config_file}")
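
A small, hedged example of the helper above for machines without an interactive prompt; the save path is an arbitrary illustration:

```python
# Hedged sketch: create a basic single-node config without running `accelerate config`.
from accelerate.commands.config.default import write_basic_config

write_basic_config(mixed_precision="bf16", save_location="./accelerate_config.json")
```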
.venv/Lib/site-packages/accelerate/commands/config/sagemaker.py
ADDED
@@ -0,0 +1,267 @@
#!/usr/bin/env python

# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os

from ...utils.constants import SAGEMAKER_PARALLEL_EC2_INSTANCES, TORCH_DYNAMO_MODES
from ...utils.dataclasses import ComputeEnvironment, SageMakerDistributedType
from ...utils.imports import is_boto3_available
from .config_args import SageMakerConfig
from .config_utils import (
    DYNAMO_BACKENDS,
    _ask_field,
    _ask_options,
    _convert_dynamo_backend,
    _convert_mixed_precision,
    _convert_sagemaker_distributed_mode,
    _convert_yes_no_to_bool,
)


if is_boto3_available():
    import boto3  # noqa: F401


def _create_iam_role_for_sagemaker(role_name):
    iam_client = boto3.client("iam")

    sagemaker_trust_policy = {
        "Version": "2012-10-17",
        "Statement": [
            {"Effect": "Allow", "Principal": {"Service": "sagemaker.amazonaws.com"}, "Action": "sts:AssumeRole"}
        ],
    }
    try:
        # create the role, associated with the chosen trust policy
        iam_client.create_role(
            RoleName=role_name, AssumeRolePolicyDocument=json.dumps(sagemaker_trust_policy, indent=2)
        )
        policy_document = {
            "Version": "2012-10-17",
            "Statement": [
                {
                    "Effect": "Allow",
                    "Action": [
                        "sagemaker:*",
                        "ecr:GetDownloadUrlForLayer",
                        "ecr:BatchGetImage",
                        "ecr:BatchCheckLayerAvailability",
                        "ecr:GetAuthorizationToken",
                        "cloudwatch:PutMetricData",
                        "cloudwatch:GetMetricData",
                        "cloudwatch:GetMetricStatistics",
                        "cloudwatch:ListMetrics",
                        "logs:CreateLogGroup",
                        "logs:CreateLogStream",
                        "logs:DescribeLogStreams",
                        "logs:PutLogEvents",
                        "logs:GetLogEvents",
                        "s3:CreateBucket",
                        "s3:ListBucket",
                        "s3:GetBucketLocation",
                        "s3:GetObject",
                        "s3:PutObject",
                    ],
                    "Resource": "*",
                }
            ],
        }
        # attach policy to role
        iam_client.put_role_policy(
            RoleName=role_name,
            PolicyName=f"{role_name}_policy_permission",
            PolicyDocument=json.dumps(policy_document, indent=2),
        )
    except iam_client.exceptions.EntityAlreadyExistsException:
        print(f"role {role_name} already exists. Using existing one")


def _get_iam_role_arn(role_name):
    iam_client = boto3.client("iam")
    return iam_client.get_role(RoleName=role_name)["Role"]["Arn"]


def get_sagemaker_input():
    credentials_configuration = _ask_options(
        "How do you want to authorize?",
        ["AWS Profile", "Credentials (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY) "],
        int,
    )
    aws_profile = None
    if credentials_configuration == 0:
        aws_profile = _ask_field("Enter your AWS Profile name: [default] ", default="default")
        os.environ["AWS_PROFILE"] = aws_profile
    else:
        print(
            "Note you will need to provide AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY when you launch you training script with,"
            "`accelerate launch --aws_access_key_id XXX --aws_secret_access_key YYY`"
        )
        aws_access_key_id = _ask_field("AWS Access Key ID: ")
        os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id

        aws_secret_access_key = _ask_field("AWS Secret Access Key: ")
        os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key

    aws_region = _ask_field("Enter your AWS Region: [us-east-1]", default="us-east-1")
    os.environ["AWS_DEFAULT_REGION"] = aws_region

    role_management = _ask_options(
        "Do you already have an IAM Role for executing Amazon SageMaker Training Jobs?",
        ["Provide IAM Role name", "Create new IAM role using credentials"],
        int,
    )
    if role_management == 0:
        iam_role_name = _ask_field("Enter your IAM role name: ")
    else:
        iam_role_name = "accelerate_sagemaker_execution_role"
        print(f'Accelerate will create an iam role "{iam_role_name}" using the provided credentials')
        _create_iam_role_for_sagemaker(iam_role_name)

    is_custom_docker_image = _ask_field(
        "Do you want to use custom Docker image? [yes/NO]: ",
        _convert_yes_no_to_bool,
        default=False,
        error_message="Please enter yes or no.",
    )
    docker_image = None
    if is_custom_docker_image:
        docker_image = _ask_field("Enter your Docker image: ", lambda x: str(x).lower())

    is_sagemaker_inputs_enabled = _ask_field(
        "Do you want to provide SageMaker input channels with data locations? [yes/NO]: ",
        _convert_yes_no_to_bool,
        default=False,
        error_message="Please enter yes or no.",
    )
    sagemaker_inputs_file = None
    if is_sagemaker_inputs_enabled:
        sagemaker_inputs_file = _ask_field(
            "Enter the path to the SageMaker inputs TSV file with columns (channel_name, data_location): ",
            lambda x: str(x).lower(),
        )

    is_sagemaker_metrics_enabled = _ask_field(
        "Do you want to enable SageMaker metrics? [yes/NO]: ",
        _convert_yes_no_to_bool,
        default=False,
        error_message="Please enter yes or no.",
    )
    sagemaker_metrics_file = None
    if is_sagemaker_metrics_enabled:
        sagemaker_metrics_file = _ask_field(
            "Enter the path to the SageMaker metrics TSV file with columns (metric_name, metric_regex): ",
            lambda x: str(x).lower(),
        )

    distributed_type = _ask_options(
        "What is the distributed mode?",
        ["No distributed training", "Data parallelism"],
        _convert_sagemaker_distributed_mode,
    )
    dynamo_config = {}
    use_dynamo = _ask_field(
        "Do you wish to optimize your script with torch dynamo?[yes/NO]:",
        _convert_yes_no_to_bool,
        default=False,
        error_message="Please enter yes or no.",
    )
    if use_dynamo:
        prefix = "dynamo_"
        dynamo_config[prefix + "backend"] = _ask_options(
            "Which dynamo backend would you like to use?",
            [x.lower() for x in DYNAMO_BACKENDS],
            _convert_dynamo_backend,
            default=2,
        )
        use_custom_options = _ask_field(
            "Do you want to customize the defaults sent to torch.compile? [yes/NO]: ",
            _convert_yes_no_to_bool,
            default=False,
            error_message="Please enter yes or no.",
        )

        if use_custom_options:
            dynamo_config[prefix + "mode"] = _ask_options(
                "Which mode do you want to use?",
                TORCH_DYNAMO_MODES,
                lambda x: TORCH_DYNAMO_MODES[int(x)],
                default="default",
            )
            dynamo_config[prefix + "use_fullgraph"] = _ask_field(
                "Do you want the fullgraph mode or it is ok to break model into several subgraphs? [yes/NO]: ",
                _convert_yes_no_to_bool,
                default=False,
                error_message="Please enter yes or no.",
            )
            dynamo_config[prefix + "use_dynamic"] = _ask_field(
                "Do you want to enable dynamic shape tracing? [yes/NO]: ",
                _convert_yes_no_to_bool,
                default=False,
                error_message="Please enter yes or no.",
            )
    ec2_instance_query = "Which EC2 instance type you want to use for your training?"
    if distributed_type != SageMakerDistributedType.NO:
        ec2_instance_type = _ask_options(
            ec2_instance_query, SAGEMAKER_PARALLEL_EC2_INSTANCES, lambda x: SAGEMAKER_PARALLEL_EC2_INSTANCES[int(x)]
        )
    else:
        ec2_instance_query += "? [ml.p3.2xlarge]:"
        ec2_instance_type = _ask_field(ec2_instance_query, lambda x: str(x).lower(), default="ml.p3.2xlarge")

    debug = False
    if distributed_type != SageMakerDistributedType.NO:
        debug = _ask_field(
            "Should distributed operations be checked while running for errors? This can avoid timeout issues but will be slower. [yes/NO]: ",
            _convert_yes_no_to_bool,
            default=False,
            error_message="Please enter yes or no.",
        )

    num_machines = 1
    if distributed_type in (SageMakerDistributedType.DATA_PARALLEL, SageMakerDistributedType.MODEL_PARALLEL):
        num_machines = _ask_field(
            "How many machines do you want use? [1]: ",
            int,
            default=1,
        )

    mixed_precision = _ask_options(
        "Do you wish to use FP16 or BF16 (mixed precision)?",
        ["no", "fp16", "bf16", "fp8"],
        _convert_mixed_precision,
    )

    if use_dynamo and mixed_precision == "no":
        print(
            "Torch dynamo used without mixed precision requires TF32 to be efficient. Accelerate will enable it by default when launching your scripts."
        )

    return SageMakerConfig(
        image_uri=docker_image,
        compute_environment=ComputeEnvironment.AMAZON_SAGEMAKER,
        distributed_type=distributed_type,
        use_cpu=False,
        dynamo_config=dynamo_config,
        ec2_instance_type=ec2_instance_type,
        profile=aws_profile,
        region=aws_region,
        iam_role_name=iam_role_name,
        mixed_precision=mixed_precision,
        num_machines=num_machines,
        sagemaker_inputs_file=sagemaker_inputs_file,
        sagemaker_metrics_file=sagemaker_metrics_file,
        debug=debug,
    )
.venv/Lib/site-packages/accelerate/commands/config/update.py
ADDED
@@ -0,0 +1,63 @@
#!/usr/bin/env python

# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path

from .config_args import default_config_file, load_config_from_file
from .config_utils import SubcommandHelpFormatter


description = "Update an existing config file with the latest defaults while maintaining the old configuration."


def update_config(args):
    """
    Update an existing config file with the latest defaults while maintaining the old configuration.
    """
    config_file = args.config_file
    if config_file is None and Path(default_config_file).exists():
        config_file = default_config_file
    elif not Path(config_file).exists():
        raise ValueError(f"The passed config file located at {config_file} doesn't exist.")
    config = load_config_from_file(config_file)

    if config_file.endswith(".json"):
        config.to_json_file(config_file)
    else:
        config.to_yaml_file(config_file)
    return config_file


def update_command_parser(parser, parents):
    parser = parser.add_parser("update", parents=parents, help=description, formatter_class=SubcommandHelpFormatter)
    parser.add_argument(
        "--config_file",
        default=None,
        help=(
            "The path to the config file to update. Will default to a file named default_config.yaml in the cache "
            "location, which is the content of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have "
            "such an environment variable, your cache directory ('~/.cache' or the content of `XDG_CACHE_HOME`) suffixed "
            "with 'huggingface'."
        ),
    )

    parser.set_defaults(func=update_config_command)
    return parser


def update_config_command(args):
    config_file = update_config(args)
    print(f"Sucessfully updated the configuration file at {config_file}.")
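
A short, hedged sketch of calling `update_config` outside the CLI; `argparse.Namespace` stands in for the parsed arguments and the file name is hypothetical (it must already exist):

```python
# Hedged sketch: rewrite an existing config file with current defaults filled in.
from argparse import Namespace

from accelerate.commands.config.update import update_config

updated_path = update_config(Namespace(config_file="my_config.yaml"))
print(f"Updated {updated_path}")
```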
.venv/Lib/site-packages/accelerate/commands/env.py
ADDED
@@ -0,0 +1,113 @@
#!/usr/bin/env python

# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import platform
import subprocess

import numpy as np
import psutil
import torch

from accelerate import __version__ as version
from accelerate.commands.config import default_config_file, load_config_from_file

from ..utils import is_mlu_available, is_musa_available, is_npu_available, is_xpu_available


def env_command_parser(subparsers=None):
    if subparsers is not None:
        parser = subparsers.add_parser("env")
    else:
        parser = argparse.ArgumentParser("Accelerate env command")

    parser.add_argument(
        "--config_file", default=None, help="The config file to use for the default values in the launching script."
    )

    if subparsers is not None:
        parser.set_defaults(func=env_command)
    return parser


def env_command(args):
    pt_version = torch.__version__
    pt_cuda_available = torch.cuda.is_available()
    pt_xpu_available = is_xpu_available()
    pt_mlu_available = is_mlu_available()
    pt_musa_available = is_musa_available()
    pt_npu_available = is_npu_available()

    accelerate_config = "Not found"
    # Get the default from the config file.
    if args.config_file is not None or os.path.isfile(default_config_file):
        accelerate_config = load_config_from_file(args.config_file).to_dict()

    # if we can run which, get it
    command = None
    bash_location = "Not found"
    if os.name == "nt":
        command = ["where", "accelerate"]
    elif os.name == "posix":
        command = ["which", "accelerate"]
    if command is not None:
        bash_location = subprocess.check_output(command, text=True, stderr=subprocess.STDOUT).strip()
    info = {
        "`Accelerate` version": version,
        "Platform": platform.platform(),
        "`accelerate` bash location": bash_location,
        "Python version": platform.python_version(),
        "Numpy version": np.__version__,
        "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
        "PyTorch XPU available": str(pt_xpu_available),
        "PyTorch NPU available": str(pt_npu_available),
        "PyTorch MLU available": str(pt_mlu_available),
        "PyTorch MUSA available": str(pt_musa_available),
        "System RAM": f"{psutil.virtual_memory().total / 1024 ** 3:.2f} GB",
    }
    if pt_cuda_available:
        info["GPU type"] = torch.cuda.get_device_name()
    if pt_mlu_available:
        info["MLU type"] = torch.mlu.get_device_name()
    if pt_npu_available:
        info["CANN version"] = torch.version.cann

    print("\nCopy-and-paste the text below in your GitHub issue\n")
    print("\n".join([f"- {prop}: {val}" for prop, val in info.items()]))

    print("- `Accelerate` default config:" if args.config_file is None else "- `Accelerate` config passed:")
    accelerate_config_str = (
        "\n".join([f"\t- {prop}: {val}" for prop, val in accelerate_config.items()])
|
95 |
+
if isinstance(accelerate_config, dict)
|
96 |
+
else f"\t{accelerate_config}"
|
97 |
+
)
|
98 |
+
print(accelerate_config_str)
|
99 |
+
|
100 |
+
info["`Accelerate` configs"] = accelerate_config
|
101 |
+
|
102 |
+
return info
|
103 |
+
|
104 |
+
|
105 |
+
def main() -> int:
|
106 |
+
parser = env_command_parser()
|
107 |
+
args = parser.parse_args()
|
108 |
+
env_command(args)
|
109 |
+
return 0
|
110 |
+
|
111 |
+
|
112 |
+
if __name__ == "__main__":
|
113 |
+
raise SystemExit(main())
|
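For reference, the report assembled by `env_command` can be produced programmatically as well as through `accelerate env`. A minimal sketch using only the functions defined above:

```python
# Collect the same environment report that `accelerate env` prints.
from accelerate.commands.env import env_command, env_command_parser

parser = env_command_parser()
args = parser.parse_args([])   # no --config_file: fall back to the default config, if present
info = env_command(args)       # prints the report and also returns it as a dict
print(sorted(info))            # keys such as "`Accelerate` version", "Platform", ...
```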
.venv/Lib/site-packages/accelerate/commands/menu/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (245 Bytes)
.venv/Lib/site-packages/accelerate/commands/menu/__pycache__/cursor.cpython-39.pyc
ADDED
Binary file (1.56 kB)
.venv/Lib/site-packages/accelerate/commands/menu/__pycache__/input.cpython-39.pyc
ADDED
Binary file (2.41 kB)
.venv/Lib/site-packages/accelerate/commands/menu/__pycache__/keymap.cpython-39.pyc
ADDED
Binary file (2.39 kB)
.venv/Lib/site-packages/accelerate/commands/menu/__pycache__/selection_menu.cpython-39.pyc
ADDED
Binary file (4.46 kB)
.venv/Lib/site-packages/accelerate/data_loader.py
ADDED
@@ -0,0 +1,1323 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from contextlib import suppress
from typing import Callable, List, Optional, Union

import torch
from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler

from .logging import get_logger
from .state import DistributedType, GradientState, PartialState, is_torch_xla_available
from .utils import (
    RNGType,
    broadcast,
    broadcast_object_list,
    concatenate,
    find_batch_size,
    get_data_structure,
    initialize_tensors,
    is_torch_version,
    is_torchdata_stateful_dataloader_available,
    send_to_device,
    slice_tensors,
    synchronize_rng_states,
)


logger = get_logger(__name__)

# kwargs of the DataLoader in min version 1.4.0.
_PYTORCH_DATALOADER_KWARGS = {
    "batch_size": 1,
    "shuffle": False,
    "sampler": None,
    "batch_sampler": None,
    "num_workers": 0,
    "collate_fn": None,
    "pin_memory": False,
    "drop_last": False,
    "timeout": 0,
    "worker_init_fn": None,
    "multiprocessing_context": None,
    "generator": None,
    "prefetch_factor": 2,
    "persistent_workers": False,
}

# kwargs added after by version
_PYTORCH_DATALOADER_ADDITIONAL_KWARGS = {}

for v, additional_kwargs in _PYTORCH_DATALOADER_ADDITIONAL_KWARGS.items():
    if is_torch_version(">=", v):
        _PYTORCH_DATALOADER_KWARGS.update(additional_kwargs)

class SeedableRandomSampler(RandomSampler):
    """
    Same as a random sampler, except that in `__iter__` a seed can be used.

    Needed specifically in distributed cases, when the random generator for each GPU needs to start from the same seed
    and be fully reproducible on multiple iterations.

    If a custom `generator` is passed, it will rely on its initial seed as well as the current iteration it is on
    (stored in `self.epoch`).
    """

    def __init__(self, *args, **kwargs):
        data_seed = kwargs.pop("data_seed", None)
        super().__init__(*args, **kwargs)

        self.initial_seed = data_seed if data_seed is not None else torch.random.initial_seed()
        self.epoch = 0

    def __iter__(self):
        if self.generator is None:
            self.generator = torch.Generator()
            self.generator.manual_seed(self.initial_seed)

        # Allow `self.epoch` to modify the seed of the generator
        seed = self.epoch + self.initial_seed
        # print("Setting seed at epoch", self.epoch, seed)
        self.generator.manual_seed(seed)
        yield from super().__iter__()
        self.set_epoch(self.epoch + 1)

    def set_epoch(self, epoch: int):
        "Sets the current iteration of the sampler."
        self.epoch = epoch

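A small illustration of the class above (not part of the diff; the toy data source is made up): resetting the epoch re-seeds the generator, so the shuffling order is reproducible.

```python
# SeedableRandomSampler: the same (seed, epoch) pair always gives the same order.
from accelerate.data_loader import SeedableRandomSampler

sampler = SeedableRandomSampler(range(8), data_seed=42)
first_epoch = list(sampler)     # iterating advances sampler.epoch to 1
sampler.set_epoch(0)            # rewind to epoch 0
assert list(sampler) == first_epoch
```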
class BatchSamplerShard(BatchSampler):
    """
    Wraps a PyTorch `BatchSampler` to generate batches for one of the processes only. Instances of this class will
    always yield a number of batches that is a round multiple of `num_processes` and that all have the same size.
    Depending on the value of the `drop_last` attribute of the batch sampler passed, it will either stop the iteration
    at the first batch that would be too small / not present on all processes or loop with indices from the beginning.

    Args:
        batch_sampler (`torch.utils.data.sampler.BatchSampler`):
            The batch sampler to split in several shards.
        num_processes (`int`, *optional*, defaults to 1):
            The number of processes running concurrently.
        process_index (`int`, *optional*, defaults to 0):
            The index of the current process.
        split_batches (`bool`, *optional*, defaults to `False`):
            Whether the shards should be created by splitting a batch to give a piece of it on each process, or by
            yielding different full batches on each process.

            On two processes with a sampler of `[[0, 1, 2, 3], [4, 5, 6, 7]]`, this will result in:

            - the sampler on process 0 to yield `[0, 1, 2, 3]` and the sampler on process 1 to yield `[4, 5, 6, 7]` if
              this argument is set to `False`.
            - the sampler on process 0 to yield `[0, 1]` then `[4, 5]` and the sampler on process 1 to yield `[2, 3]`
              then `[6, 7]` if this argument is set to `True`.
        even_batches (`bool`, *optional*, defaults to `True`):
            Whether or not to loop back at the beginning of the sampler when the number of samples is not a round
            multiple of (original batch size / number of processes).

    <Tip warning={true}>

    `BatchSampler`s with varying batch sizes are not enabled by default. To enable this behaviour, set `even_batches`
    equal to `False`

    </Tip>"""

    def __init__(
        self,
        batch_sampler: BatchSampler,
        num_processes: int = 1,
        process_index: int = 0,
        split_batches: bool = False,
        even_batches: bool = True,
    ):
        if split_batches and batch_sampler.batch_size % num_processes != 0:
            raise ValueError(
                f"To use `BatchSamplerShard` in `split_batches` mode, the batch size ({batch_sampler.batch_size}) "
                f"needs to be a round multiple of the number of processes ({num_processes})."
            )
        self.batch_sampler = batch_sampler
        self.num_processes = num_processes
        self.process_index = process_index
        self.split_batches = split_batches
        self.even_batches = even_batches
        self.batch_size = getattr(batch_sampler, "batch_size", None)
        self.drop_last = getattr(batch_sampler, "drop_last", False)
        if self.batch_size is None and self.even_batches:
            raise ValueError(
                "You need to use `even_batches=False` when the batch sampler has no batch size. If you "
                "are not calling this method directly, set `accelerator.even_batches=False` instead."
            )

    @property
    def total_length(self):
        return len(self.batch_sampler)

    def __len__(self):
        if self.split_batches:
            # Split batches does not change the length of the batch sampler
            return len(self.batch_sampler)
        if len(self.batch_sampler) % self.num_processes == 0:
            # If the length is a round multiple of the number of processes, it's easy.
            return len(self.batch_sampler) // self.num_processes
        length = len(self.batch_sampler) // self.num_processes
        if self.drop_last:
            # Same if we drop the remainder.
            return length
        elif self.even_batches:
            # When we even batches we always get +1
            return length + 1
        else:
            # Otherwise it depends on the process index.
            return length + 1 if self.process_index < len(self.batch_sampler) % self.num_processes else length

    def __iter__(self):
        return self._iter_with_split() if self.split_batches else self._iter_with_no_split()

    def _iter_with_split(self):
        initial_data = []
        batch_length = self.batch_sampler.batch_size // self.num_processes
        for idx, batch in enumerate(self.batch_sampler):
            if idx == 0:
                initial_data = batch
            if len(batch) == self.batch_size:
                # If the batch is full, we yield the part of it this process is responsible for.
                yield batch[batch_length * self.process_index : batch_length * (self.process_index + 1)]

        # If drop_last is True or the last batch was full, iteration is over, otherwise...
        if not self.drop_last and len(initial_data) > 0 and len(batch) < self.batch_size:
            if not self.even_batches:
                if len(batch) > batch_length * self.process_index:
                    yield batch[batch_length * self.process_index : batch_length * (self.process_index + 1)]
            else:
                # For degenerate cases where the dataset has less than num_process * batch_size samples
                while len(initial_data) < self.batch_size:
                    initial_data += initial_data
                batch = batch + initial_data
                yield batch[batch_length * self.process_index : batch_length * (self.process_index + 1)]

    def _iter_with_no_split(self):
        initial_data = []
        batch_to_yield = []
        for idx, batch in enumerate(self.batch_sampler):
            # We gather the initial indices in case we need to circle back at the end.
            if not self.drop_last and idx < self.num_processes:
                initial_data += batch
            # We identify the batch to yield but wait until we are sure every process gets a full batch before
            # actually yielding it.
            if idx % self.num_processes == self.process_index:
                batch_to_yield = batch
            if idx % self.num_processes == self.num_processes - 1 and (
                self.batch_size is None or len(batch) == self.batch_size
            ):
                yield batch_to_yield
                batch_to_yield = []

        # If drop_last is True, iteration is over, otherwise...
        if not self.drop_last and len(initial_data) > 0:
            if not self.even_batches:
                if len(batch_to_yield) > 0:
                    yield batch_to_yield
            else:
                # ... we yield the complete batch we had saved before if it has the proper length
                if len(batch_to_yield) == self.batch_size:
                    yield batch_to_yield

                # For degenerate cases where the dataset has less than num_process * batch_size samples
                while len(initial_data) < self.num_processes * self.batch_size:
                    initial_data += initial_data

                # If the last batch seen was of the proper size, it has been yielded by its process so we move to the next
                if len(batch) == self.batch_size:
                    batch = []
                    idx += 1

                # Make sure we yield a multiple of self.num_processes batches
                cycle_index = 0
                while idx % self.num_processes != 0 or len(batch) > 0:
                    end_index = cycle_index + self.batch_size - len(batch)
                    batch += initial_data[cycle_index:end_index]
                    if idx % self.num_processes == self.process_index:
                        yield batch
                    cycle_index = end_index
                    batch = []
                    idx += 1

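To make the docstring example above concrete, here is a runnable illustration (not part of the diff) with a sequential sampler of 8 indices split over two processes:

```python
# BatchSamplerShard with split_batches=False: full batches alternate between processes.
from torch.utils.data import BatchSampler, SequentialSampler

from accelerate.data_loader import BatchSamplerShard

base = BatchSampler(SequentialSampler(range(8)), batch_size=4, drop_last=False)
shards = [BatchSamplerShard(base, num_processes=2, process_index=i) for i in range(2)]
print([list(shard) for shard in shards])
# [[[0, 1, 2, 3]], [[4, 5, 6, 7]]]
```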
class IterableDatasetShard(IterableDataset):
    """
    Wraps a PyTorch `IterableDataset` to generate samples for one of the processes only. Instances of this class will
    always yield a number of samples that is a round multiple of the actual batch size (depending on the value of
    `split_batches`, this is either `batch_size` or `batch_size x num_processes`). Depending on the value of the
    `drop_last` attribute of the batch sampler passed, it will either stop the iteration at the first batch that would
    be too small or loop with indices from the beginning.

    Args:
        dataset (`torch.utils.data.dataset.IterableDataset`):
            The batch sampler to split in several shards.
        batch_size (`int`, *optional*, defaults to 1):
            The size of the batches per shard (if `split_batches=False`) or the size of the batches (if
            `split_batches=True`).
        drop_last (`bool`, *optional*, defaults to `False`):
            Whether or not to drop the last incomplete batch or complete the last batches by using the samples from the
            beginning.
        num_processes (`int`, *optional*, defaults to 1):
            The number of processes running concurrently.
        process_index (`int`, *optional*, defaults to 0):
            The index of the current process.
        split_batches (`bool`, *optional*, defaults to `False`):
            Whether the shards should be created by splitting a batch to give a piece of it on each process, or by
            yielding different full batches on each process.

            On two processes with an iterable dataset yielding of `[0, 1, 2, 3, 4, 5, 6, 7]`, this will result in:

            - the shard on process 0 to yield `[0, 1, 2, 3]` and the shard on process 1 to yield `[4, 5, 6, 7]` if this
              argument is set to `False`.
            - the shard on process 0 to yield `[0, 1, 4, 5]` and the sampler on process 1 to yield `[2, 3, 6, 7]` if
              this argument is set to `True`.
    """

    def __init__(
        self,
        dataset: IterableDataset,
        batch_size: int = 1,
        drop_last: bool = False,
        num_processes: int = 1,
        process_index: int = 0,
        split_batches: bool = False,
    ):
        if split_batches and batch_size > 1 and batch_size % num_processes != 0:
            raise ValueError(
                f"To use `IterableDatasetShard` in `split_batches` mode, the batch size ({batch_size}) "
                f"needs to be a round multiple of the number of processes ({num_processes})."
            )
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.num_processes = num_processes
        self.process_index = process_index
        self.split_batches = split_batches

    def set_epoch(self, epoch):
        self.epoch = epoch
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)

    def __len__(self):
        # We will just raise the downstream error if the underlying dataset is not sized
        if self.drop_last:
            return (len(self.dataset) // (self.batch_size * self.num_processes)) * self.batch_size
        else:
            return math.ceil(len(self.dataset) / (self.batch_size * self.num_processes)) * self.batch_size

    def __iter__(self):
        if (
            not hasattr(self.dataset, "set_epoch")
            and hasattr(self.dataset, "generator")
            and isinstance(self.dataset.generator, torch.Generator)
        ):
            self.dataset.generator.manual_seed(self.epoch)
        real_batch_size = self.batch_size if self.split_batches else (self.batch_size * self.num_processes)
        process_batch_size = (self.batch_size // self.num_processes) if self.split_batches else self.batch_size
        process_slice = range(self.process_index * process_batch_size, (self.process_index + 1) * process_batch_size)

        first_batch = None
        current_batch = []
        for element in self.dataset:
            current_batch.append(element)
            # Wait to have a full batch before yielding elements.
            if len(current_batch) == real_batch_size:
                for i in process_slice:
                    yield current_batch[i]
                if first_batch is None:
                    first_batch = current_batch.copy()
                current_batch = []

        # Finished if drop_last is True, otherwise complete the last batch with elements from the beginning.
        if not self.drop_last and len(current_batch) > 0:
            if first_batch is None:
                first_batch = current_batch.copy()
            while len(current_batch) < real_batch_size:
                current_batch += first_batch
            for i in process_slice:
                yield current_batch[i]

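The same kind of illustration for the iterable case (not part of the diff; the toy stream is made up):

```python
# IterableDatasetShard over two processes, batch_size=4, split_batches=False.
from torch.utils.data import IterableDataset as TorchIterableDataset

from accelerate.data_loader import IterableDatasetShard


class Stream(TorchIterableDataset):
    def __iter__(self):
        yield from range(8)


shards = [IterableDatasetShard(Stream(), batch_size=4, num_processes=2, process_index=i) for i in range(2)]
print([list(shard) for shard in shards])
# [[0, 1, 2, 3], [4, 5, 6, 7]]
```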
class DataLoaderStateMixin:
    """
    Mixin class that adds a state to a `DataLoader` to keep track of the status inside the dataloader such as at the
    end of the iteration, the number of items in the dataset in the last batch relative to the batch size, and other
    useful information that might be needed.

    **Available attributes:**

        - **end_of_dataloader** (`bool`) -- Whether at the last iteration or batch
        - **remainder** (`int`) -- The number of items that are remaining in the last batch, relative to the total
          batch size

    <Tip warning={true}>

    Inheritors of this class should ensure that the class creates a `GradientState()` instance, stored in
    `self.gradient_state`.

    </Tip>

    """

    def __init_subclass__(cls, **kwargs):
        cls.end_of_dataloader = False
        cls.remainder = -1

    def reset(self):
        self.end_of_dataloader = False
        self.remainder = -1

    def begin(self):
        "Prepares the gradient state for the current dataloader"
        self.reset()
        with suppress(Exception):
            if not self._drop_last:
                length = getattr(self.dataset, "total_dataset_length", len(self.dataset))
                self.remainder = length % self.total_batch_size
        self.gradient_state._add_dataloader(self)

    def end(self):
        "Cleans up the gradient state after exiting the dataloader"
        self.gradient_state._remove_dataloader(self)

class DataLoaderAdapter:
    """
    A class which wraps around a PyTorch `DataLoader` (or variants of it) to be used with the `Accelerator`. For
    compatibility reasons, this class inherits from the class it wraps around, so it can be used as a drop-in.
    """

    def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, **kwargs):
        self.use_stateful_dataloader = use_stateful_dataloader
        if is_torchdata_stateful_dataloader_available():
            from torchdata.stateful_dataloader import StatefulDataLoader

        if use_stateful_dataloader and not is_torchdata_stateful_dataloader_available():
            raise ImportError(
                "StatefulDataLoader is not available. Please install torchdata version 0.8.0 or higher to use it."
            )
        if use_stateful_dataloader:
            self.base_dataloader = StatefulDataLoader(dataset, batch_sampler=batch_sampler, **kwargs)
        else:
            self.base_dataloader = DataLoader(dataset, batch_sampler=batch_sampler, **kwargs)

        if hasattr(self.base_dataloader, "state_dict"):
            self.dl_state_dict = self.base_dataloader.state_dict()

    def __getattr__(self, name):
        # Avoid infinite recursion if we try to access a nonexistent base_dataloader attribute.
        if name == "base_dataloader":
            raise AttributeError()
        # Delegate attribute access to the internal dataloader
        return getattr(self.base_dataloader, name)

    def state_dict(self):
        return self.dl_state_dict

    def load_state_dict(self, state_dict):
        self.base_dataloader.load_state_dict(state_dict)

    @property
    def __class__(self):
        """
        In order to maintain backwards compatibility with other code, we need to ensure `isinstance(obj, DataLoader)`
        returns true. This is because some downstream code assumes that the `DataLoader` is the base class of the
        object.
        """
        return self.base_dataloader.__class__

    def __len__(self):
        return len(self.base_dataloader)

    def adjust_state_dict_for_prefetch(self):
        """
        Adjusts the state dict for prefetching. Natively, this will adjust all of the iters yielded keys in
        `self.dl_state_dict` by a factor of `num_processes - 1`, however if a custom correction is needed, this can be
        overridden.

        This should modify `self.dl_state_dict` directly
        """
        # The state dict will be off by a factor of `n-1` batch too many during DDP,
        # so we need to adjust it here
        if PartialState().distributed_type != DistributedType.NO:
            factor = PartialState().num_processes - 1
            if self.dl_state_dict["_sampler_iter_yielded"] > 0:
                self.dl_state_dict["_sampler_iter_yielded"] -= factor
            if self.dl_state_dict["_num_yielded"] > 0:
                self.dl_state_dict["_num_yielded"] -= factor
            if self.dl_state_dict["_index_sampler_state"] is not None:
                if (
                    "samples_yielded" in self.dl_state_dict["_index_sampler_state"]
                    and self.dl_state_dict["_index_sampler_state"]["samples_yielded"] > 0
                ):
                    self.dl_state_dict["_index_sampler_state"]["samples_yielded"] -= self.batch_size * factor

    def _update_state_dict(self):
        # The state_dict of the underlying base_dataloader may be ahead of what is currently being yielded.
        # E.g. the implementation of DataLoaderShard involves having an underlying iterator 1 element ahead of
        # what it wants to yield.
        #
        # _update_state_dict is called to snapshot the state_dict that would properly recover the DataLoaderAdapter.
        if hasattr(self.base_dataloader, "state_dict"):
            self.dl_state_dict = self.base_dataloader.state_dict()
            # Potentially modify the state_dict to adjust for prefetching
            self.adjust_state_dict_for_prefetch()
            # Then tag if we are at the end of the dataloader
            self.dl_state_dict["_iterator_finished"] = self.end_of_dataloader

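One behaviour of the adapter worth calling out: because `__class__` is overridden, instances still pass `isinstance` checks against the wrapped class. A quick sketch (not part of the diff):

```python
# DataLoaderAdapter masquerades as the wrapped DataLoader class.
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate.data_loader import DataLoaderAdapter

adapter = DataLoaderAdapter(TensorDataset(torch.arange(6)), batch_size=2)
print(isinstance(adapter, DataLoader), len(adapter))  # True 3
```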
class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
    """
    Subclass of `DataLoaderAdapter` that will deal with device placement and current distributed setup.

    Args:
        dataset (`torch.utils.data.dataset.Dataset`):
            The dataset to use to build this dataloader.
        device (`torch.device`, *optional*):
            If passed, the device to put all batches on.
        rng_types (list of `str` or [`~utils.RNGType`]):
            The list of random number generators to synchronize at the beginning of each iteration. Should be one or
            several of:

            - `"torch"`: the base torch random number generator
            - `"cuda"`: the CUDA random number generator (GPU only)
            - `"xla"`: the XLA random number generator (TPU only)
            - `"generator"`: an optional `torch.Generator`
        synchronized_generator (`torch.Generator`, *optional*):
            A random number generator to keep synchronized across processes.
        skip_batches (`int`, *optional*, defaults to 0):
            The number of batches to skip at the beginning.
        use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
            Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.
        **kwargs (additional keyword arguments, *optional*):
            All other keyword arguments to pass to the regular `DataLoader` initialization.

    **Available attributes:**

        - **total_batch_size** (`int`) -- Total batch size of the dataloader across all processes.
          Equal to the original batch size when `split_batches=True`; otherwise the original batch size * the total
          number of processes

        - **total_dataset_length** (`int`) -- Total length of the inner dataset across all processes.
    """

    def __init__(
        self,
        dataset,
        device=None,
        rng_types=None,
        synchronized_generator=None,
        skip_batches=0,
        use_stateful_dataloader=False,
        _drop_last: bool = False,
        _non_blocking: bool = False,
        **kwargs,
    ):
        super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
        self.device = device
        self.rng_types = rng_types
        self.synchronized_generator = synchronized_generator
        self.skip_batches = skip_batches
        self.gradient_state = GradientState()
        self._drop_last = _drop_last
        self._non_blocking = _non_blocking
        self.iteration = 0

    def __iter__(self):
        if self.rng_types is not None:
            synchronize_rng_states(self.rng_types, self.synchronized_generator)
        self.begin()

        self.set_epoch(self.iteration)
        dataloader_iter = self.base_dataloader.__iter__()
        # We iterate one batch ahead to check when we are at the end
        try:
            current_batch = next(dataloader_iter)
        except StopIteration:
            yield

        batch_index = 0
        while True:
            try:
                # But we still move it to the device so it is done before `StopIteration` is reached
                if self.device is not None:
                    current_batch = send_to_device(current_batch, self.device, non_blocking=self._non_blocking)
                self._update_state_dict()
                next_batch = next(dataloader_iter)
                if batch_index >= self.skip_batches:
                    yield current_batch
                batch_index += 1
                current_batch = next_batch
            except StopIteration:
                self.end_of_dataloader = True
                self._update_state_dict()
                if batch_index >= self.skip_batches:
                    yield current_batch
                break

        self.iteration += 1
        self.end()

    def __reduce__(self):
        """
        Define the `__reduce__` method to ensure a `DataLoaderShard` can be pickled and unpickled. This needs to be
        explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its
        `__class__` member.
        """
        args = super().__reduce__()
        return (DataLoaderShard, *args[1:])

    def set_epoch(self, epoch: int):
        # In case it is manually passed in, the user can set it to what they like
        if self.iteration != epoch:
            self.iteration = epoch
        if hasattr(self.batch_sampler, "set_epoch"):
            self.batch_sampler.set_epoch(epoch)
        if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, "set_epoch"):
            self.batch_sampler.sampler.set_epoch(epoch)
        # We support if a custom `Dataset` implementation has `set_epoch`
        # or in general HF datasets `Datasets`
        elif hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)

    @property
    def total_batch_size(self):
        batch_sampler = self.sampler if isinstance(self.sampler, BatchSampler) else self.batch_sampler
        return (
            batch_sampler.batch_size
            if getattr(batch_sampler, "split_batches", False)
            else (batch_sampler.batch_size * getattr(batch_sampler, "num_processes", 1))
        )

    @property
    def total_dataset_length(self):
        if hasattr(self.dataset, "total_length"):
            return self.dataset.total_length
        else:
            return len(self.dataset)

    def get_sampler(self):
        return get_sampler(self)

    def set_sampler(self, sampler):
        sampler_is_batch_sampler = isinstance(self.sampler, BatchSampler)
        if sampler_is_batch_sampler:
            self.sampler.sampler = sampler
        else:
            self.batch_sampler.sampler = sampler
            if hasattr(self.batch_sampler, "batch_sampler"):
                self.batch_sampler.batch_sampler.sampler = sampler

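Sketch of the shard in single-process use (not part of the diff): it iterates like a plain `DataLoader` but moves every batch to the requested device before yielding it.

```python
# DataLoaderShard yielding CPU-placed batches of a toy TensorDataset.
import torch
from torch.utils.data import TensorDataset

from accelerate.data_loader import DataLoaderShard

dataset = TensorDataset(torch.arange(16, dtype=torch.float32).view(8, 2))
loader = DataLoaderShard(dataset, device=torch.device("cpu"), batch_size=4)
for (batch,) in loader:
    print(batch.shape, batch.device)  # torch.Size([4, 2]) cpu
```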
if is_torch_xla_available():
    import torch_xla.distributed.parallel_loader as xpl

    class MpDeviceLoaderWrapper(xpl.MpDeviceLoader):
        """
        Wrapper for the xpl.MpDeviceLoader class that knows the total batch size.

        XLA preloading threads will all call DataLoaderShard's __iter__(). Remove rng_types from DataLoaderShard to
        prevent it from using the XLA device in the preloading threads, and synchronize the RNG once from the main
        thread only.

        **Available attributes:**

            - **total_batch_size** (`int`) -- Total batch size of the dataloader across all processes.
              Equal to the original batch size when `split_batches=True`; otherwise the original batch size * the total
              number of processes

            - **total_dataset_length** (`int`) -- Total length of the inner dataset across all processes.
        """

        def __init__(self, dataloader: DataLoaderShard, device: torch.device):
            super().__init__(dataloader, device)
            self._rng_types = self._loader.rng_types
            self._loader.rng_types = None
            self.device = device

        def __iter__(self):
            if self._rng_types is not None:
                synchronize_rng_states(self._rng_types, self._loader.synchronized_generator)

            return super().__iter__()

        def set_epoch(self, epoch: int):
            if hasattr(self.dataloader, "set_epoch"):
                self.dataloader.set_epoch(epoch)

        @property
        def total_batch_size(self):
            return self._loader.total_batch_size

        @property
        def total_dataset_length(self):
            return self._loader.total_dataset_length

        @property
        def batch_sampler(self):
            return self._loader.batch_sampler

        @property
        def dataloader(self):
            return self._loader

class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
    """
    Subclass of `DataLoaderAdapter` that will iterate and preprocess on process 0 only, then dispatch on each process
    their part of the batch.

    Args:
        split_batches (`bool`, *optional*, defaults to `False`):
            Whether the resulting `DataLoader` should split the batches of the original data loader across devices or
            yield full batches (in which case it will yield batches starting at the `process_index`-th and advancing of
            `num_processes` batches at each iteration). Another way to see this is that the observed batch size will be
            the same as the initial `dataloader` if this option is set to `True`, the batch size of the initial
            `dataloader` multiplied by `num_processes` otherwise. Setting this option to `True` requires that the batch
            size of the `dataloader` is a round multiple of `batch_size`.
        skip_batches (`int`, *optional*, defaults to 0):
            The number of batches to skip at the beginning of an iteration.
        use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
            Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.

    **Available attributes:**

        - **total_batch_size** (`int`) -- Total batch size of the dataloader across all processes.
          Equal to the original batch size when `split_batches=True`; otherwise the original batch size * the total
          number of processes

        - **total_dataset_length** (`int`) -- Total length of the inner dataset across all processes.
    """

    def __init__(
        self,
        dataset,
        split_batches: bool = False,
        skip_batches=0,
        use_stateful_dataloader=False,
        _drop_last: bool = False,
        _non_blocking: bool = False,
        slice_fn=None,
        **kwargs,
    ):
        shuffle = False
        if is_torch_version(">=", "1.11.0"):
            from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe

            # We need to save the shuffling state of the DataPipe
            if isinstance(dataset, ShufflerIterDataPipe):
                shuffle = dataset._shuffle_enabled
        super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
        self.split_batches = split_batches
        if shuffle:
            torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle)

        self.gradient_state = GradientState()
        self.state = PartialState()
        self._drop_last = _drop_last
        self._non_blocking = _non_blocking
        self.skip_batches = skip_batches

        self.slice_fn = slice_tensors if slice_fn is None else slice_fn
        self.iteration = 0

    def _fetch_batches(self, iterator):
        batches, batch = None, None
        # On process 0, we gather the batch to dispatch.
        if self.state.process_index == 0:
            try:
                if self.split_batches:
                    # One batch of the main iterator is dispatched and split.
                    self._update_state_dict()
                    batch = next(iterator)
                else:
                    # num_processes batches of the main iterator are concatenated then dispatched and split.
                    # We add the batches one by one so we have the remainder available when drop_last=False.
                    batches = []
                    for _ in range(self.state.num_processes):
                        self._update_state_dict()
                        batches.append(next(iterator))
                    try:
                        batch = concatenate(batches, dim=0)
                    except RuntimeError as e:
                        raise RuntimeError(
                            "You can't use batches of different size with `dispatch_batches=True` or when using an `IterableDataset`."
                            "either pass `dispatch_batches=False` and have each process fetch its own batch "
                            " or pass `split_batches=True`. By doing so, the main process will fetch a full batch and "
                            "slice it into `num_processes` batches for each process."
                        ) from e
                # In both cases, we need to get the structure of the batch that we will broadcast on other
                # processes to initialize the tensors with the right shape.
                # data_structure, stop_iteration
                batch_info = [get_data_structure(batch), False]
            except StopIteration:
                batch_info = [None, True]
        else:
            batch_info = [None, self._stop_iteration]
        # This is inplace, so after this instruction, every process has the same `batch_info` as process 0.
        broadcast_object_list(batch_info)
        self._stop_iteration = batch_info[1]
        if self._stop_iteration:
            # If drop_last is False and split_batches is False, we may have a remainder to take care of.
            if not self.split_batches and not self._drop_last:
                if self.state.process_index == 0 and len(batches) > 0:
                    batch = concatenate(batches, dim=0)
                    batch_info = [get_data_structure(batch), False]
                else:
                    batch_info = [None, True]
                broadcast_object_list(batch_info)
        return batch, batch_info

    def __iter__(self):
        self.begin()
        self.set_epoch(self.iteration)
        main_iterator = None
        if is_torch_version(">=", "2.0.1"):
            # NOTE PyTorch DataLoader adds forward compatibilities for DataPipes, which broadcasts
            # shared seed to all dist processes. Thus, we need to create iterator for all dist processes.
            # But, we only iterate through the DataLoader on process 0.
            main_iterator = self.base_dataloader.__iter__()
        elif self.state.process_index == 0:
            main_iterator = self.base_dataloader.__iter__()
        stop_iteration = False
        self._stop_iteration = False
        first_batch = None
        next_batch, next_batch_info = self._fetch_batches(main_iterator)
        batch_index = 0
        while not stop_iteration:
            batch, batch_info = next_batch, next_batch_info

            if self.state.process_index != 0:
                # Initialize tensors on other processes than process 0.
                batch = initialize_tensors(batch_info[0])
            batch = send_to_device(batch, self.state.device, non_blocking=self._non_blocking)
            # Broadcast the batch before splitting it.
            batch = broadcast(batch, from_process=0)

            if not self._drop_last and first_batch is None:
                # We keep at least num processes elements of the first batch to be able to complete the last batch
                first_batch = self.slice_fn(
                    batch,
                    slice(0, self.state.num_processes),
                    process_index=self.state.process_index,
                    num_processes=self.state.num_processes,
                )

            if batch is None:
                raise ValueError(
                    f"Batch does not contain any data (`{batch}`). At the end of all iterable data available before expected stop iteration."
                )

            observed_batch_size = find_batch_size(batch)
            batch_size = observed_batch_size // self.state.num_processes

            stop_iteration = self._stop_iteration
            if not stop_iteration:
                # We may still be at the end of the dataloader without knowing it yet: if there is nothing left in
                # the dataloader since the number of batches is a round multiple of the number of processes.
                next_batch, next_batch_info = self._fetch_batches(main_iterator)
                # next_batch_info[0] is None when there are no more batches, otherwise we still need to process them.
                if self._stop_iteration and next_batch_info[0] is None:
                    stop_iteration = True

            if not self._drop_last and stop_iteration and observed_batch_size % self.state.num_processes != 0:
                # If the last batch is not complete, let's add the first batch to it.
                batch = concatenate([batch, first_batch], dim=0)
                # Batch size computation above is wrong, it's off by 1 so we fix it.
                batch_size += 1

            data_slice = slice(self.state.process_index * batch_size, (self.state.process_index + 1) * batch_size)
            batch = self.slice_fn(
                batch,
                data_slice,
                process_index=self.state.process_index,
                num_processes=self.state.num_processes,
            )

            if stop_iteration:
                self.end_of_dataloader = True
                self._update_state_dict()
                self.remainder = observed_batch_size
            if batch_index >= self.skip_batches:
                yield batch
            batch_index += 1
        self.iteration += 1
        self.end()

    def set_epoch(self, epoch: int):
        # In case it is manually passed in, the user can set it to what they like
        if self.iteration != epoch:
            self.iteration = epoch
        if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, "set_epoch"):
            self.batch_sampler.sampler.set_epoch(epoch)
        elif hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)

    def __len__(self):
        whole_length = len(self.base_dataloader)
        if self.split_batches:
            return whole_length
        elif self._drop_last:
            return whole_length // self.state.num_processes
        else:
            return math.ceil(whole_length / self.state.num_processes)

    def __reduce__(self):
        """
        Define the `__reduce__` method to ensure a `DataLoaderDispatcher` can be pickled and unpickled. This needs to
        be explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its
        `__class__` member.
        """
        args = super().__reduce__()
        return (DataLoaderDispatcher, *args[1:])

    @property
    def total_batch_size(self):
        return (
            self.dataset.batch_size if self.split_batches else (self.dataset.batch_size * self.dataset.num_processes)
        )

    @property
    def total_dataset_length(self):
        return len(self.dataset)

    def get_sampler(self):
        return get_sampler(self)

    def set_sampler(self, sampler):
        sampler_is_batch_sampler = isinstance(self.sampler, BatchSampler)
        if sampler_is_batch_sampler:
            self.sampler.sampler = sampler
        else:
            self.batch_sampler.sampler = sampler
            if hasattr(self.batch_sampler, "batch_sampler"):
                self.batch_sampler.batch_sampler.sampler = sampler

def get_sampler(dataloader):
    """
    Get the sampler associated to the dataloader

    Args:
        dataloader (`torch.utils.data.dataloader.DataLoader`):
            The data loader to split across several devices.
    Returns:
        `torch.utils.data.Sampler`: The sampler associated to the dataloader
    """
    sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
    if sampler_is_batch_sampler:
        sampler = getattr(dataloader.sampler, "sampler", None)
    else:
        sampler = getattr(dataloader.batch_sampler, "sampler", None)
    return sampler

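And a one-liner illustration of the helper above (not part of the diff):

```python
# get_sampler digs the underlying sampler out of a DataLoader.
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate.data_loader import get_sampler

loader = DataLoader(TensorDataset(torch.arange(10)), batch_size=2, shuffle=True)
print(type(get_sampler(loader)).__name__)  # RandomSampler
```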
def prepare_data_loader(
|
933 |
+
dataloader: DataLoader,
|
934 |
+
device: Optional[torch.device] = None,
|
935 |
+
num_processes: Optional[int] = None,
|
936 |
+
process_index: Optional[int] = None,
|
937 |
+
split_batches: bool = False,
|
938 |
+
put_on_device: bool = False,
|
939 |
+
rng_types: Optional[List[Union[str, RNGType]]] = None,
|
940 |
+
dispatch_batches: Optional[bool] = None,
|
941 |
+
even_batches: bool = True,
|
942 |
+
slice_fn_for_dispatch: Optional[Callable] = None,
|
943 |
+
use_seedable_sampler: bool = False,
|
944 |
+
data_seed: Optional[int] = None,
|
945 |
+
non_blocking: bool = False,
|
946 |
+
use_stateful_dataloader: bool = False,
|
947 |
+
) -> DataLoader:
|
948 |
+
"""
|
949 |
+
Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.
|
950 |
+
|
951 |
+
    Depending on the value of the `drop_last` attribute of the `dataloader` passed, it will either stop the iteration
    at the first batch that would be too small / not present on all processes or loop with indices from the beginning.

    Args:
        dataloader (`torch.utils.data.dataloader.DataLoader`):
            The data loader to split across several devices.
        device (`torch.device`):
            The target device for the returned `DataLoader`.
        num_processes (`int`, *optional*):
            The number of processes running concurrently. Will default to the value given by [`~state.PartialState`].
        process_index (`int`, *optional*):
            The index of the current process. Will default to the value given by [`~state.PartialState`].
        split_batches (`bool`, *optional*, defaults to `False`):
            Whether the resulting `DataLoader` should split the batches of the original data loader across devices or
            yield full batches (in which case it will yield batches starting at the `process_index`-th and advancing by
            `num_processes` batches at each iteration).

            Another way to see this is that the observed batch size will be the same as the initial `dataloader` if
            this option is set to `True`, the batch size of the initial `dataloader` multiplied by `num_processes`
            otherwise.

            Setting this option to `True` requires that the batch size of the `dataloader` is a round multiple of
            `num_processes`.
        put_on_device (`bool`, *optional*, defaults to `False`):
            Whether or not to put the batches on `device` (only works if the batches are nested lists, tuples or
            dictionaries of tensors).
        rng_types (list of `str` or [`~utils.RNGType`]):
            The list of random number generators to synchronize at the beginning of each iteration. Should be one or
            several of:

            - `"torch"`: the base torch random number generator
            - `"cuda"`: the CUDA random number generator (GPU only)
            - `"xla"`: the XLA random number generator (TPU only)
            - `"generator"`: the `torch.Generator` of the sampler (or batch sampler if there is no sampler in your
              dataloader) or of the iterable dataset (if it exists) if the underlying dataset is of that type.

        dispatch_batches (`bool`, *optional*):
            If set to `True`, the dataloader prepared is only iterated through on the main process and then the batches
            are split and broadcast to each process. Will default to `True` when the underlying dataset is an
            `IterableDataset`, `False` otherwise.
        even_batches (`bool`, *optional*, defaults to `True`):
            If set to `True`, in cases where the total batch size across all processes does not exactly divide the
            dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among
            all workers.
        slice_fn_for_dispatch (`Callable`, *optional*):
            If passed, this function will be used to slice tensors across `num_processes`. Will default to
            [`~utils.slice_tensors`]. This argument is used only when `dispatch_batches` is set to `True` and will be
            ignored otherwise.
        use_seedable_sampler (`bool`, *optional*, defaults to `False`):
            Whether to use the [`~data_loader.SeedableRandomSampler`] instead of a `RandomSampler` for better
            reproducibility. Comes at a cost of potentially different performance due to different shuffling
            algorithms, but ensures results will be the *exact* same. Should be paired with `set_seed()` at every
            `self.set_epoch` call.
        data_seed (`int`, *optional*, defaults to `None`):
            The seed to use for the underlying generator when using `use_seedable_sampler`. If `None`, the generator
            will use the current default seed from torch.
        non_blocking (`bool`, *optional*, defaults to `False`):
            If set to `True`, the dataloader will utilize non-blocking host-to-device transfers. If the dataloader has
            `pin_memory` set to `True`, this will help to increase overlap between data transfer and computations.
        use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
            If set to `True`, the dataloader prepared by the Accelerator will be backed by
            [torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader).
            This requires `torchdata` version 0.8.0 or higher with StatefulDataLoader support to be installed.


    Returns:
        `torch.utils.data.dataloader.DataLoader`: A new data loader that will yield the portion of the batches

    <Tip warning={true}>

    `BatchSampler`s with varying batch sizes are not enabled by default. To enable this behaviour, set `even_batches`
    equal to `False`.

    </Tip>
    """
    if dispatch_batches is None:
        if not put_on_device:
            dispatch_batches = False
        else:
            dispatch_batches = isinstance(dataloader.dataset, IterableDataset)

    if dispatch_batches and not put_on_device:
        raise ValueError("Using `dispatch_batches=True` requires `put_on_device=True`.")
    # Grab defaults from PartialState
    state = PartialState()
    if num_processes is None:
        num_processes = state.num_processes
    if process_index is None:
        process_index = state.process_index

    # Sanity check
    if split_batches:
        if dataloader.batch_size is not None:
            batch_size_for_check = dataloader.batch_size
        else:
            # For custom batch_sampler
            if hasattr(dataloader.batch_sampler, "batch_size"):
                batch_size_for_check = dataloader.batch_sampler.batch_size
            else:
                raise ValueError(
                    "In order to use `split_batches==True` you must have a `batch_size` attribute either in the passed "
                    "`dataloader` or `dataloader.batch_sampler` objects, and it has to return a natural number. "
                    "Your `dataloader.batch_size` is None and `dataloader.batch_sampler` "
                    f"(`{type(dataloader.batch_sampler)}`) does not have the `batch_size` attribute set."
                )

        if batch_size_for_check > 1 and batch_size_for_check % num_processes != 0:
            raise ValueError(
                f"To use a `DataLoader` in `split_batches` mode, the batch size ({dataloader.batch_size}) "
                f"needs to be a round multiple of the number of processes ({num_processes})."
            )

    new_dataset = dataloader.dataset
    # Iterable dataset doesn't like batch_sampler, but data_loader creates a default one for it
    new_batch_sampler = dataloader.batch_sampler if not isinstance(new_dataset, IterableDataset) else None
    sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
    synchronized_generator = None

    sampler = get_sampler(dataloader)
    if isinstance(sampler, RandomSampler) and use_seedable_sampler:
        # When iterating through the dataloader during distributed processes
        # we want to ensure that on each process we are iterating through the same
        # samples in the same order if a seed is set. This requires a tweak
        # to the `torch.utils.data.RandomSampler` class (if used).
        sampler = SeedableRandomSampler(
            data_source=sampler.data_source,
            replacement=sampler.replacement,
            num_samples=sampler._num_samples,
            generator=getattr(sampler, "generator", torch.Generator()),
            data_seed=data_seed,
        )

    if isinstance(dataloader.sampler, RandomSampler) and state.distributed_type == DistributedType.XLA:
        # isinstance(dataloader.sampler, RandomSampler) indicates the original dataloader has `shuffle` enabled.
        generator = torch.Generator().manual_seed(42)
        dataloader.generator = generator
        dataloader.sampler.generator = generator
    # No change if no multiprocess
    if (num_processes != 1 or state.distributed_type == DistributedType.MEGATRON_LM) and not dispatch_batches:
        if isinstance(new_dataset, IterableDataset):
            if getattr(dataloader.dataset, "generator", None) is not None:
                synchronized_generator = dataloader.dataset.generator
            new_dataset = IterableDatasetShard(
                new_dataset,
                batch_size=dataloader.batch_size,
                drop_last=dataloader.drop_last,
                num_processes=num_processes,
                process_index=process_index,
                split_batches=split_batches,
            )
        else:
            if not use_seedable_sampler and hasattr(sampler, "generator"):
                if sampler.generator is None:
                    sampler.generator = torch.Generator()
                synchronized_generator = sampler.generator
            batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
            new_batch_sampler = BatchSamplerShard(
                batch_sampler,
                num_processes=num_processes,
                process_index=process_index,
                split_batches=split_batches,
                even_batches=even_batches,
            )

    # We ignore all of those since they are all dealt with by our new_batch_sampler
    ignore_kwargs = [
        "batch_size",
        "shuffle",
        "sampler",
        "batch_sampler",
        "drop_last",
    ]

    if rng_types is not None and synchronized_generator is None and "generator" in rng_types:
        rng_types.remove("generator")

    kwargs = {
        k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k])
        for k in _PYTORCH_DATALOADER_KWARGS
        if k not in ignore_kwargs
    }

    # Need to provide batch_size as batch_sampler is None for Iterable dataset
    if new_batch_sampler is None:
        kwargs["drop_last"] = dataloader.drop_last
        kwargs["batch_size"] = (
            dataloader.batch_size // num_processes if split_batches and not dispatch_batches else dataloader.batch_size
        )
    if dispatch_batches:
        kwargs.pop("generator")
        dataloader = DataLoaderDispatcher(
            new_dataset,
            split_batches=split_batches,
            batch_sampler=new_batch_sampler,
            _drop_last=dataloader.drop_last,
            _non_blocking=non_blocking,
            slice_fn=slice_fn_for_dispatch,
            use_stateful_dataloader=use_stateful_dataloader,
            **kwargs,
        )
    elif sampler_is_batch_sampler:
        dataloader = DataLoaderShard(
            new_dataset,
            device=device if put_on_device and state.distributed_type != DistributedType.XLA else None,
            sampler=new_batch_sampler,
            batch_size=dataloader.batch_size,
            rng_types=rng_types,
            _drop_last=dataloader.drop_last,
            _non_blocking=non_blocking,
            synchronized_generator=synchronized_generator,
            use_stateful_dataloader=use_stateful_dataloader,
            **kwargs,
        )
    else:
        dataloader = DataLoaderShard(
            new_dataset,
            device=device if put_on_device and state.distributed_type != DistributedType.XLA else None,
            batch_sampler=new_batch_sampler,
            rng_types=rng_types,
            synchronized_generator=synchronized_generator,
            _drop_last=dataloader.drop_last,
            _non_blocking=non_blocking,
            use_stateful_dataloader=use_stateful_dataloader,
            **kwargs,
        )

    if isinstance(sampler, SeedableRandomSampler) and use_seedable_sampler:
        dataloader.set_sampler(sampler)
    if state.distributed_type == DistributedType.XLA:
        return MpDeviceLoaderWrapper(dataloader, device)
    return dataloader


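# --- Illustrative usage sketch (not part of the library source) ---
# A hedged example of how `prepare_data_loader` might be called in a script launched with
# `accelerate launch`. The toy dataset, batch size and the `_example_*` function name are
# assumptions made purely for illustration.
def _example_prepare_data_loader():
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))  # hypothetical toy data
    loader = DataLoader(dataset, batch_size=16, shuffle=True)
    state = PartialState()
    # With the defaults, each process iterates over its own subset of whole batches, and the
    # batches are moved onto `state.device` because `put_on_device=True`.
    prepared = prepare_data_loader(loader, device=state.device, put_on_device=True)
    for inputs, labels in prepared:
        break
    return inputs.shape

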
class SkipBatchSampler(BatchSampler):
    """
    A `torch.utils.data.BatchSampler` that skips the first `n` batches of another `torch.utils.data.BatchSampler`.
    Should not be used if the original dataloader is a `StatefulDataLoader`.
    """

    def __init__(self, batch_sampler, skip_batches=0):
        self.batch_sampler = batch_sampler
        self.skip_batches = skip_batches

    def __iter__(self):
        for index, samples in enumerate(self.batch_sampler):
            if index >= self.skip_batches:
                yield samples

    @property
    def total_length(self):
        return len(self.batch_sampler)

    def __len__(self):
        return len(self.batch_sampler) - self.skip_batches


class SkipDataLoader(DataLoaderAdapter, DataLoaderStateMixin):
    """
    Subclass of a PyTorch `DataLoader` that will skip the first batches. Generally it's preferable to use
    `skip_first_batches`/`torchdata.StatefulDataLoader` instead of this class.

    Args:
        dataset (`torch.utils.data.dataset.Dataset`):
            The dataset to use to build this dataloader.
        skip_batches (`int`, *optional*, defaults to 0):
            The number of batches to skip at the beginning.
        kwargs:
            All other keyword arguments to pass to the regular `DataLoader` initialization.
    """

    def __init__(self, dataset, skip_batches=0, use_stateful_dataloader=False, **kwargs):
        super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
        self.skip_batches = skip_batches
        self.gradient_state = GradientState()

    def __iter__(self):
        self.begin()
        for index, batch in enumerate(self.base_dataloader.__iter__()):
            if index >= self.skip_batches:
                self._update_state_dict()
                yield batch
        self.end()

    def __len__(self):
        return len(self.base_dataloader) - self.skip_batches

    def __reduce__(self):
        """
        Define the `__reduce__` method to ensure a `SkipDataLoader` can be pickled and unpickled. This needs to be
        explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its
        `__class__` member.
        """
        args = super().__reduce__()
        return (SkipDataLoader, *args[1:])


def skip_first_batches(dataloader, num_batches=0):
    """
    Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`. Should not be used if
    the original dataloader is a `StatefulDataLoader`.
    """
    state = PartialState()
    if state.distributed_type == DistributedType.XLA:
        device = dataloader.device
        dataloader = dataloader.dataloader

    dataset = dataloader.dataset
    sampler_is_batch_sampler = False
    if isinstance(dataset, IterableDataset):
        new_batch_sampler = None
    else:
        sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
        batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
        new_batch_sampler = SkipBatchSampler(batch_sampler, skip_batches=num_batches)

    # We ignore all of those since they are all dealt with by our new_batch_sampler
    ignore_kwargs = [
        "batch_size",
        "shuffle",
        "sampler",
        "batch_sampler",
        "drop_last",
    ]

    kwargs = {
        k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k])
        for k in _PYTORCH_DATALOADER_KWARGS
        if k not in ignore_kwargs
    }

    # Need to provide batch_size as batch_sampler is None for Iterable dataset
    if new_batch_sampler is None:
        kwargs["drop_last"] = dataloader.drop_last
        kwargs["batch_size"] = dataloader.batch_size

    if isinstance(dataloader, DataLoaderDispatcher):
        if new_batch_sampler is None:
            # Need to manually skip batches in the dataloader
            kwargs["skip_batches"] = num_batches
        dataloader = DataLoaderDispatcher(
            dataset,
            split_batches=dataloader.split_batches,
            batch_sampler=new_batch_sampler,
            _drop_last=dataloader._drop_last,
            **kwargs,
        )
    elif isinstance(dataloader, DataLoaderShard):
        if new_batch_sampler is None:
            # Need to manually skip batches in the dataloader
            kwargs["skip_batches"] = num_batches
        elif sampler_is_batch_sampler:
            kwargs["sampler"] = new_batch_sampler
            kwargs["batch_size"] = dataloader.batch_size
        else:
            kwargs["batch_sampler"] = new_batch_sampler
        dataloader = DataLoaderShard(
            dataset,
            device=dataloader.device,
            rng_types=dataloader.rng_types,
            synchronized_generator=dataloader.synchronized_generator,
            **kwargs,
        )
    else:
        if new_batch_sampler is None:
            # Need to manually skip batches in the dataloader
            dataloader = SkipDataLoader(dataset, skip_batches=num_batches, **kwargs)
        else:
            dataloader = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs)

    if state.distributed_type == DistributedType.XLA:
        dataloader = MpDeviceLoaderWrapper(dataloader, device)

    return dataloader
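

# --- Illustrative usage sketch (not part of the library source) ---
# A hedged example of resuming a training loop mid-epoch with `skip_first_batches`. The toy
# dataloader and the number of already-seen batches are assumptions for illustration only.
def _example_skip_first_batches():
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.arange(40).float().unsqueeze(1))  # hypothetical toy data
    loader = DataLoader(dataset, batch_size=4)

    # Suppose a previous run stopped after 3 batches of this epoch; skip them on resume.
    resumed_loader = skip_first_batches(loader, num_batches=3)
    assert len(resumed_loader) == len(loader) - 3
    first_batch = next(iter(resumed_loader))  # this is the 4th batch of the original loader
    return first_batch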
.venv/Lib/site-packages/accelerate/hooks.py
ADDED
@@ -0,0 +1,726 @@
1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import functools
|
16 |
+
from typing import Dict, List, Mapping, Optional, Union
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
|
21 |
+
from .state import PartialState
|
22 |
+
from .utils import (
|
23 |
+
PrefixedDataset,
|
24 |
+
find_device,
|
25 |
+
named_module_tensors,
|
26 |
+
send_to_device,
|
27 |
+
set_module_tensor_to_device,
|
28 |
+
)
|
29 |
+
from .utils.memory import clear_device_cache
|
30 |
+
from .utils.modeling import get_non_persistent_buffers
|
31 |
+
from .utils.other import recursive_getattr
|
32 |
+
|
33 |
+
|
34 |
+
_accelerate_added_attributes = ["to", "cuda", "npu", "xpu", "mlu", "musa"]
|
35 |
+
|
36 |
+
|
37 |
+
class ModelHook:
|
38 |
+
"""
|
39 |
+
A hook that contains callbacks to be executed just before and after the forward method of a model. The difference
|
40 |
+
with PyTorch's existing hooks is that they get passed along the kwargs.
|
41 |
+
|
42 |
+
Class attribute:
|
43 |
+
- **no_grad** (`bool`, *optional*, defaults to `False`) -- Whether or not to execute the actual forward pass under
|
44 |
+
the `torch.no_grad()` context manager.
|
45 |
+
"""
|
46 |
+
|
47 |
+
no_grad = False
|
48 |
+
|
49 |
+
def init_hook(self, module):
|
50 |
+
"""
|
51 |
+
To be executed when the hook is attached to the module.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
module (`torch.nn.Module`): The module attached to this hook.
|
55 |
+
"""
|
56 |
+
return module
|
57 |
+
|
58 |
+
def pre_forward(self, module, *args, **kwargs):
|
59 |
+
"""
|
60 |
+
To be executed just before the forward method of the model.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
module (`torch.nn.Module`): The module whose forward pass will be executed just after this event.
|
64 |
+
args (`Tuple[Any]`): The positional arguments passed to the module.
|
65 |
+
kwargs (`Dict[Str, Any]`): The keyword arguments passed to the module.
|
66 |
+
|
67 |
+
Returns:
|
68 |
+
`Tuple[Tuple[Any], Dict[Str, Any]]`: A tuple with the treated `args` and `kwargs`.
|
69 |
+
"""
|
70 |
+
return args, kwargs
|
71 |
+
|
72 |
+
def post_forward(self, module, output):
|
73 |
+
"""
|
74 |
+
To be executed just after the forward method of the model.
|
75 |
+
|
76 |
+
Args:
|
77 |
+
module (`torch.nn.Module`): The module whose forward pass been executed just before this event.
|
78 |
+
output (`Any`): The output of the module.
|
79 |
+
|
80 |
+
Returns:
|
81 |
+
`Any`: The processed `output`.
|
82 |
+
"""
|
83 |
+
return output
|
84 |
+
|
85 |
+
def detach_hook(self, module):
|
86 |
+
"""
|
87 |
+
To be executed when the hook is detached from a module.
|
88 |
+
|
89 |
+
Args:
|
90 |
+
module (`torch.nn.Module`): The module detached from this hook.
|
91 |
+
"""
|
92 |
+
return module
|
93 |
+
|
94 |
+
|
95 |
+
class SequentialHook(ModelHook):
|
96 |
+
"""
|
97 |
+
A hook that can contain several hooks and iterates through them at each event.
|
98 |
+
"""
|
99 |
+
|
100 |
+
def __init__(self, *hooks):
|
101 |
+
self.hooks = hooks
|
102 |
+
|
103 |
+
def init_hook(self, module):
|
104 |
+
for hook in self.hooks:
|
105 |
+
module = hook.init_hook(module)
|
106 |
+
return module
|
107 |
+
|
108 |
+
def pre_forward(self, module, *args, **kwargs):
|
109 |
+
for hook in self.hooks:
|
110 |
+
args, kwargs = hook.pre_forward(module, *args, **kwargs)
|
111 |
+
return args, kwargs
|
112 |
+
|
113 |
+
def post_forward(self, module, output):
|
114 |
+
for hook in self.hooks:
|
115 |
+
output = hook.post_forward(module, output)
|
116 |
+
return output
|
117 |
+
|
118 |
+
def detach_hook(self, module):
|
119 |
+
for hook in self.hooks:
|
120 |
+
module = hook.detach_hook(module)
|
121 |
+
return module
|
122 |
+
|
123 |
+
|
124 |
+
def add_hook_to_module(module: nn.Module, hook: ModelHook, append: bool = False):
|
125 |
+
"""
|
126 |
+
Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook, to remove
|
127 |
+
this behavior and restore the original `forward` method, use `remove_hook_from_module`.
|
128 |
+
|
129 |
+
<Tip warning={true}>
|
130 |
+
|
131 |
+
If the module already contains a hook, this will replace it with the new hook passed by default. To chain two hooks
|
132 |
+
together, pass `append=True`, so it chains the current and new hook into an instance of the `SequentialHook` class.
|
133 |
+
|
134 |
+
</Tip>
|
135 |
+
|
136 |
+
Args:
|
137 |
+
module (`torch.nn.Module`):
|
138 |
+
The module to attach a hook to.
|
139 |
+
hook (`ModelHook`):
|
140 |
+
The hook to attach.
|
141 |
+
append (`bool`, *optional*, defaults to `False`):
|
142 |
+
Whether the hook should be chained with an existing one (if module already contains a hook) or not.
|
143 |
+
|
144 |
+
Returns:
|
145 |
+
`torch.nn.Module`: The same module, with the hook attached (the module is modified in place, so the result can
|
146 |
+
be discarded).
|
147 |
+
"""
|
148 |
+
|
149 |
+
if append and (getattr(module, "_hf_hook", None) is not None):
|
150 |
+
old_hook = module._hf_hook
|
151 |
+
remove_hook_from_module(module)
|
152 |
+
hook = SequentialHook(old_hook, hook)
|
153 |
+
|
154 |
+
if hasattr(module, "_hf_hook") and hasattr(module, "_old_forward"):
|
155 |
+
# If we already put some hook on this module, we replace it with the new one.
|
156 |
+
old_forward = module._old_forward
|
157 |
+
else:
|
158 |
+
old_forward = module.forward
|
159 |
+
module._old_forward = old_forward
|
160 |
+
|
161 |
+
module = hook.init_hook(module)
|
162 |
+
module._hf_hook = hook
|
163 |
+
|
164 |
+
def new_forward(module, *args, **kwargs):
|
165 |
+
args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
|
166 |
+
if module._hf_hook.no_grad:
|
167 |
+
with torch.no_grad():
|
168 |
+
output = module._old_forward(*args, **kwargs)
|
169 |
+
else:
|
170 |
+
output = module._old_forward(*args, **kwargs)
|
171 |
+
return module._hf_hook.post_forward(module, output)
|
172 |
+
|
173 |
+
# Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
|
174 |
+
# Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
|
175 |
+
if "GraphModuleImpl" in str(type(module)):
|
176 |
+
module.__class__.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)
|
177 |
+
else:
|
178 |
+
module.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)
|
179 |
+
|
180 |
+
return module
|
181 |
+
|
182 |
+
|
183 |
+
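# --- Illustrative usage sketch (not part of the library source) ---
# A hedged example of a custom `ModelHook` attached with `add_hook_to_module`. The tiny linear
# layer, the scaling factor and the `_Example*` names are assumptions made only to show the
# pre-/post-forward flow.
class _ExampleScaleOutputHook(ModelHook):
    def post_forward(self, module, output):
        # Scale the module output after every forward call.
        return output * 2.0


def _example_add_hook_to_module():
    layer = torch.nn.Linear(4, 4)
    add_hook_to_module(layer, _ExampleScaleOutputHook())
    doubled = layer(torch.randn(1, 4))  # output is post-processed by the hook
    remove_hook_from_module(layer)      # restores the original `forward`
    return doubled
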
def remove_hook_from_module(module: nn.Module, recurse=False):
|
184 |
+
"""
|
185 |
+
Removes any hook attached to a module via `add_hook_to_module`.
|
186 |
+
|
187 |
+
Args:
|
188 |
+
module (`torch.nn.Module`): The module to attach a hook to.
|
189 |
+
recurse (`bool`, **optional**): Whether to remove the hooks recursively
|
190 |
+
|
191 |
+
Returns:
|
192 |
+
`torch.nn.Module`: The same module, with the hook detached (the module is modified in place, so the result can
|
193 |
+
be discarded).
|
194 |
+
"""
|
195 |
+
|
196 |
+
if hasattr(module, "_hf_hook"):
|
197 |
+
module._hf_hook.detach_hook(module)
|
198 |
+
delattr(module, "_hf_hook")
|
199 |
+
|
200 |
+
if hasattr(module, "_old_forward"):
|
201 |
+
# Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
|
202 |
+
# Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
|
203 |
+
if "GraphModuleImpl" in str(type(module)):
|
204 |
+
module.__class__.forward = module._old_forward
|
205 |
+
else:
|
206 |
+
module.forward = module._old_forward
|
207 |
+
delattr(module, "_old_forward")
|
208 |
+
|
209 |
+
# Remove accelerate added warning hooks from dispatch_model
|
210 |
+
for attr in _accelerate_added_attributes:
|
211 |
+
module.__dict__.pop(attr, None)
|
212 |
+
|
213 |
+
if recurse:
|
214 |
+
for child in module.children():
|
215 |
+
remove_hook_from_module(child, recurse)
|
216 |
+
|
217 |
+
return module
|
218 |
+
|
219 |
+
|
220 |
+
class AlignDevicesHook(ModelHook):
|
221 |
+
"""
|
222 |
+
A generic `ModelHook` that ensures inputs and model weights are on the same device for the forward pass of the
|
223 |
+
associated module, potentially offloading the weights after the forward pass.
|
224 |
+
|
225 |
+
Args:
|
226 |
+
execution_device (`torch.device`, *optional*):
|
227 |
+
The device on which inputs and model weights should be placed before the forward pass.
|
228 |
+
offload (`bool`, *optional*, defaults to `False`):
|
229 |
+
Whether or not the weights should be offloaded after the forward pass.
|
230 |
+
io_same_device (`bool`, *optional*, defaults to `False`):
|
231 |
+
Whether or not the output should be placed on the same device as the input was.
|
232 |
+
weights_map (`Mapping[str, torch.Tensor]`, *optional*):
|
233 |
+
When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.
|
234 |
+
offload_buffers (`bool`, *optional*, defaults to `False`):
|
235 |
+
Whether or not to include the associated module's buffers when offloading.
|
236 |
+
place_submodules (`bool`, *optional*, defaults to `False`):
|
237 |
+
Whether to place the submodules on `execution_device` during the `init_hook` event.
|
238 |
+
"""
|
239 |
+
|
240 |
+
def __init__(
|
241 |
+
self,
|
242 |
+
execution_device: Optional[Union[int, str, torch.device]] = None,
|
243 |
+
offload: bool = False,
|
244 |
+
io_same_device: bool = False,
|
245 |
+
weights_map: Optional[Mapping] = None,
|
246 |
+
offload_buffers: bool = False,
|
247 |
+
place_submodules: bool = False,
|
248 |
+
skip_keys: Optional[Union[str, List[str]]] = None,
|
249 |
+
tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None,
|
250 |
+
):
|
251 |
+
self.execution_device = execution_device
|
252 |
+
self.offload = offload
|
253 |
+
self.io_same_device = io_same_device
|
254 |
+
self.weights_map = weights_map
|
255 |
+
self.offload_buffers = offload_buffers
|
256 |
+
self.place_submodules = place_submodules
|
257 |
+
self.skip_keys = skip_keys
|
258 |
+
|
259 |
+
# Will contain the input device when `io_same_device=True`.
|
260 |
+
self.input_device = None
|
261 |
+
self.param_original_devices = {}
|
262 |
+
self.buffer_original_devices = {}
|
263 |
+
self.tied_params_names = set()
|
264 |
+
|
265 |
+
# The hook pre_forward/post_forward need to have knowledge of this dictionary, as with offloading we want to avoid duplicating memory
|
266 |
+
# for tied weights already loaded on the target execution device.
|
267 |
+
self.tied_params_map = tied_params_map
|
268 |
+
|
269 |
+
def __repr__(self):
|
270 |
+
return (
|
271 |
+
f"AlignDevicesHook(execution_device={self.execution_device}, offload={self.offload}, "
|
272 |
+
f"io_same_device={self.io_same_device}, offload_buffers={self.offload_buffers}, "
|
273 |
+
f"place_submodules={self.place_submodules}, skip_keys={repr(self.skip_keys)})"
|
274 |
+
)
|
275 |
+
|
276 |
+
def init_hook(self, module):
|
277 |
+
# In case the AlignDevicesHook is on meta device, ignore tied weights as data_ptr() is then always zero.
|
278 |
+
if self.execution_device == "meta" or self.execution_device == torch.device("meta"):
|
279 |
+
self.tied_params_map = None
|
280 |
+
|
281 |
+
if not self.offload and self.execution_device is not None:
|
282 |
+
for name, _ in named_module_tensors(module, recurse=self.place_submodules):
|
283 |
+
set_module_tensor_to_device(module, name, self.execution_device, tied_params_map=self.tied_params_map)
|
284 |
+
elif self.offload:
|
285 |
+
self.original_devices = {
|
286 |
+
name: param.device for name, param in named_module_tensors(module, recurse=self.place_submodules)
|
287 |
+
}
|
288 |
+
if self.weights_map is None:
|
289 |
+
self.weights_map = {
|
290 |
+
name: param.to("cpu")
|
291 |
+
for name, param in named_module_tensors(
|
292 |
+
module, include_buffers=self.offload_buffers, recurse=self.place_submodules
|
293 |
+
)
|
294 |
+
}
|
295 |
+
for name, _ in named_module_tensors(
|
296 |
+
module, include_buffers=self.offload_buffers, recurse=self.place_submodules, remove_non_persistent=True
|
297 |
+
):
|
298 |
+
# When using disk offloading, we can not rely on `weights_map[name].data_ptr()` as the reference pointer,
|
299 |
+
# as we have no guarantee that safetensors' `file.get_tensor()` will always give the same pointer.
|
300 |
+
# As we have no reliable way to track the shared data pointer of tied weights in this case, we use tied_params_names: List[str]
|
301 |
+
# to add on the fly pointers to `tied_params_map` in the pre_forward call.
|
302 |
+
if (
|
303 |
+
self.tied_params_map is not None
|
304 |
+
and recursive_getattr(module, name).data_ptr() in self.tied_params_map
|
305 |
+
):
|
306 |
+
self.tied_params_names.add(name)
|
307 |
+
|
308 |
+
set_module_tensor_to_device(module, name, "meta")
|
309 |
+
|
310 |
+
if not self.offload_buffers and self.execution_device is not None:
|
311 |
+
for name, _ in module.named_buffers(recurse=self.place_submodules):
|
312 |
+
set_module_tensor_to_device(
|
313 |
+
module, name, self.execution_device, tied_params_map=self.tied_params_map
|
314 |
+
)
|
315 |
+
elif self.offload_buffers and self.execution_device is not None:
|
316 |
+
for name in get_non_persistent_buffers(module, recurse=self.place_submodules):
|
317 |
+
set_module_tensor_to_device(
|
318 |
+
module, name, self.execution_device, tied_params_map=self.tied_params_map
|
319 |
+
)
|
320 |
+
|
321 |
+
return module
|
322 |
+
|
323 |
+
def pre_forward(self, module, *args, **kwargs):
|
324 |
+
if self.io_same_device:
|
325 |
+
self.input_device = find_device([args, kwargs])
|
326 |
+
if self.offload:
|
327 |
+
self.tied_pointers_to_remove = set()
|
328 |
+
|
329 |
+
for name, _ in named_module_tensors(
|
330 |
+
module,
|
331 |
+
include_buffers=self.offload_buffers,
|
332 |
+
recurse=self.place_submodules,
|
333 |
+
remove_non_persistent=True,
|
334 |
+
):
|
335 |
+
fp16_statistics = None
|
336 |
+
value = self.weights_map[name]
|
337 |
+
if "weight" in name and name.replace("weight", "SCB") in self.weights_map.keys():
|
338 |
+
if value.dtype == torch.int8:
|
339 |
+
fp16_statistics = self.weights_map[name.replace("weight", "SCB")]
|
340 |
+
|
341 |
+
# In case we are using offloading with tied weights, we need to keep track of the offloaded weights
|
342 |
+
# that are loaded on device at this point, as we will need to remove them as well from the dictionary
|
343 |
+
# self.tied_params_map in order to allow to free memory.
|
344 |
+
if name in self.tied_params_names and value.data_ptr() not in self.tied_params_map:
|
345 |
+
self.tied_params_map[value.data_ptr()] = {}
|
346 |
+
|
347 |
+
if (
|
348 |
+
value is not None
|
349 |
+
and self.tied_params_map is not None
|
350 |
+
and value.data_ptr() in self.tied_params_map
|
351 |
+
and self.execution_device not in self.tied_params_map[value.data_ptr()]
|
352 |
+
):
|
353 |
+
self.tied_pointers_to_remove.add((value.data_ptr(), self.execution_device))
|
354 |
+
|
355 |
+
set_module_tensor_to_device(
|
356 |
+
module,
|
357 |
+
name,
|
358 |
+
self.execution_device,
|
359 |
+
value=value,
|
360 |
+
fp16_statistics=fp16_statistics,
|
361 |
+
tied_params_map=self.tied_params_map,
|
362 |
+
)
|
363 |
+
|
364 |
+
return send_to_device(args, self.execution_device), send_to_device(
|
365 |
+
kwargs, self.execution_device, skip_keys=self.skip_keys
|
366 |
+
)
|
367 |
+
|
368 |
+
def post_forward(self, module, output):
|
369 |
+
if self.offload:
|
370 |
+
for name, _ in named_module_tensors(
|
371 |
+
module,
|
372 |
+
include_buffers=self.offload_buffers,
|
373 |
+
recurse=self.place_submodules,
|
374 |
+
remove_non_persistent=True,
|
375 |
+
):
|
376 |
+
set_module_tensor_to_device(module, name, "meta")
|
377 |
+
if type(module).__name__ == "Linear8bitLt":
|
378 |
+
module.state.SCB = None
|
379 |
+
module.state.CxB = None
|
380 |
+
|
381 |
+
# We may have loaded tied weights into self.tied_params_map (avoiding to load them several times in e.g. submodules): remove them from
|
382 |
+
# this dictionary to allow the garbage collector to do its job.
|
383 |
+
for value_pointer, device in self.tied_pointers_to_remove:
|
384 |
+
del self.tied_params_map[value_pointer][device]
|
385 |
+
self.tied_pointers_to_remove = set()
|
386 |
+
|
387 |
+
if self.io_same_device and self.input_device is not None:
|
388 |
+
output = send_to_device(output, self.input_device, skip_keys=self.skip_keys)
|
389 |
+
|
390 |
+
return output
|
391 |
+
|
392 |
+
def detach_hook(self, module):
|
393 |
+
if self.offload:
|
394 |
+
for name, device in self.original_devices.items():
|
395 |
+
if device != torch.device("meta"):
|
396 |
+
set_module_tensor_to_device(module, name, device, value=self.weights_map.get(name, None))
|
397 |
+
return module
|
398 |
+
|
399 |
+
|
400 |
+
def attach_execution_device_hook(
|
401 |
+
module: torch.nn.Module,
|
402 |
+
execution_device: Union[int, str, torch.device],
|
403 |
+
skip_keys: Optional[Union[str, List[str]]] = None,
|
404 |
+
preload_module_classes: Optional[List[str]] = None,
|
405 |
+
tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None,
|
406 |
+
):
|
407 |
+
"""
|
408 |
+
Recursively attaches `AlignDevicesHook` to all submodules of a given model to make sure they have the right
|
409 |
+
execution device
|
410 |
+
|
411 |
+
Args:
|
412 |
+
module (`torch.nn.Module`):
|
413 |
+
The module where we want to attach the hooks.
|
414 |
+
execution_device (`int`, `str` or `torch.device`):
|
415 |
+
The device on which inputs and model weights should be placed before the forward pass.
|
416 |
+
skip_keys (`str` or `List[str]`, *optional*):
|
417 |
+
A list of keys to ignore when moving inputs or outputs between devices.
|
418 |
+
preload_module_classes (`List[str]`, *optional*):
|
419 |
+
A list of classes whose instances should load all their weights (even in the submodules) at the beginning
|
420 |
+
of the forward. This should only be used for classes that have submodules which are registered but not
|
421 |
+
called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
|
422 |
+
`dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
|
423 |
+
tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`):
|
424 |
+
A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution
|
425 |
+
device, this parameter is useful to reuse the first available pointer of a shared weight for all others,
|
426 |
+
instead of duplicating memory.
|
427 |
+
"""
|
428 |
+
if not hasattr(module, "_hf_hook") and len(module.state_dict()) > 0:
|
429 |
+
add_hook_to_module(
|
430 |
+
module,
|
431 |
+
AlignDevicesHook(execution_device, skip_keys=skip_keys, tied_params_map=tied_params_map),
|
432 |
+
)
|
433 |
+
|
434 |
+
# Break the recursion if we get to a preload module.
|
435 |
+
if preload_module_classes is not None and module.__class__.__name__ in preload_module_classes:
|
436 |
+
return
|
437 |
+
|
438 |
+
for child in module.children():
|
439 |
+
attach_execution_device_hook(
|
440 |
+
child,
|
441 |
+
execution_device,
|
442 |
+
skip_keys=skip_keys,
|
443 |
+
preload_module_classes=preload_module_classes,
|
444 |
+
tied_params_map=tied_params_map,
|
445 |
+
)
|
446 |
+
|
447 |
+
|
448 |
+
def attach_align_device_hook(
|
449 |
+
module: torch.nn.Module,
|
450 |
+
execution_device: Optional[torch.device] = None,
|
451 |
+
offload: bool = False,
|
452 |
+
weights_map: Optional[Mapping] = None,
|
453 |
+
offload_buffers: bool = False,
|
454 |
+
module_name: str = "",
|
455 |
+
skip_keys: Optional[Union[str, List[str]]] = None,
|
456 |
+
preload_module_classes: Optional[List[str]] = None,
|
457 |
+
tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None,
|
458 |
+
):
|
459 |
+
"""
|
460 |
+
Recursively attaches `AlignDevicesHook` to all submodules of a given model that have direct parameters and/or
|
461 |
+
buffers.
|
462 |
+
|
463 |
+
Args:
|
464 |
+
module (`torch.nn.Module`):
|
465 |
+
The module where we want to attach the hooks.
|
466 |
+
execution_device (`torch.device`, *optional*):
|
467 |
+
The device on which inputs and model weights should be placed before the forward pass.
|
468 |
+
offload (`bool`, *optional*, defaults to `False`):
|
469 |
+
Whether or not the weights should be offloaded after the forward pass.
|
470 |
+
weights_map (`Mapping[str, torch.Tensor]`, *optional*):
|
471 |
+
When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.
|
472 |
+
offload_buffers (`bool`, *optional*, defaults to `False`):
|
473 |
+
Whether or not to include the associated module's buffers when offloading.
|
474 |
+
module_name (`str`, *optional*, defaults to `""`):
|
475 |
+
The name of the module.
|
476 |
+
skip_keys (`str` or `List[str]`, *optional*):
|
477 |
+
A list of keys to ignore when moving inputs or outputs between devices.
|
478 |
+
preload_module_classes (`List[str]`, *optional*):
|
479 |
+
A list of classes whose instances should load all their weights (even in the submodules) at the beginning
|
480 |
+
of the forward. This should only be used for classes that have submodules which are registered but not
|
481 |
+
called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
|
482 |
+
`dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
|
483 |
+
tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`):
|
484 |
+
A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution
|
485 |
+
device, this parameter is useful to reuse the first available pointer of a shared weight for all others,
|
486 |
+
instead of duplicating memory.
|
487 |
+
"""
|
488 |
+
# Attach the hook on this module if it has any direct tensor.
|
489 |
+
directs = named_module_tensors(module)
|
490 |
+
full_offload = (
|
491 |
+
offload and preload_module_classes is not None and module.__class__.__name__ in preload_module_classes
|
492 |
+
)
|
493 |
+
|
494 |
+
if len(list(directs)) > 0 or full_offload:
|
495 |
+
if weights_map is not None:
|
496 |
+
prefix = f"{module_name}." if len(module_name) > 0 else ""
|
497 |
+
prefixed_weights_map = PrefixedDataset(weights_map, prefix)
|
498 |
+
else:
|
499 |
+
prefixed_weights_map = None
|
500 |
+
hook = AlignDevicesHook(
|
501 |
+
execution_device=execution_device,
|
502 |
+
offload=offload,
|
503 |
+
weights_map=prefixed_weights_map,
|
504 |
+
offload_buffers=offload_buffers,
|
505 |
+
place_submodules=full_offload,
|
506 |
+
skip_keys=skip_keys,
|
507 |
+
tied_params_map=tied_params_map,
|
508 |
+
)
|
509 |
+
add_hook_to_module(module, hook, append=True)
|
510 |
+
|
511 |
+
# We stop the recursion in case we hit the full offload.
|
512 |
+
if full_offload:
|
513 |
+
return
|
514 |
+
|
515 |
+
# Recurse on all children of the module.
|
516 |
+
for child_name, child in module.named_children():
|
517 |
+
child_name = f"{module_name}.{child_name}" if len(module_name) > 0 else child_name
|
518 |
+
attach_align_device_hook(
|
519 |
+
child,
|
520 |
+
execution_device=execution_device,
|
521 |
+
offload=offload,
|
522 |
+
weights_map=weights_map,
|
523 |
+
offload_buffers=offload_buffers,
|
524 |
+
module_name=child_name,
|
525 |
+
preload_module_classes=preload_module_classes,
|
526 |
+
skip_keys=skip_keys,
|
527 |
+
tied_params_map=tied_params_map,
|
528 |
+
)
|
529 |
+
|
530 |
+
|
531 |
+
def remove_hook_from_submodules(module: nn.Module):
|
532 |
+
"""
|
533 |
+
Recursively removes all hooks attached on the submodules of a given model.
|
534 |
+
|
535 |
+
Args:
|
536 |
+
module (`torch.nn.Module`): The module on which to remove all hooks.
|
537 |
+
"""
|
538 |
+
remove_hook_from_module(module)
|
539 |
+
for child in module.children():
|
540 |
+
remove_hook_from_submodules(child)
|
541 |
+
|
542 |
+
|
543 |
+
def attach_align_device_hook_on_blocks(
|
544 |
+
module: nn.Module,
|
545 |
+
execution_device: Optional[Union[torch.device, Dict[str, torch.device]]] = None,
|
546 |
+
offload: Union[bool, Dict[str, bool]] = False,
|
547 |
+
weights_map: Mapping = None,
|
548 |
+
offload_buffers: bool = False,
|
549 |
+
module_name: str = "",
|
550 |
+
skip_keys: Optional[Union[str, List[str]]] = None,
|
551 |
+
preload_module_classes: Optional[List[str]] = None,
|
552 |
+
tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None,
|
553 |
+
):
|
554 |
+
"""
|
555 |
+
Attaches `AlignDevicesHook` to all blocks of a given model as needed.
|
556 |
+
|
557 |
+
Args:
|
558 |
+
module (`torch.nn.Module`):
|
559 |
+
The module where we want to attach the hooks.
|
560 |
+
execution_device (`torch.device` or `Dict[str, torch.device]`, *optional*):
|
561 |
+
The device on which inputs and model weights should be placed before the forward pass. It can be one device
|
562 |
+
for the whole module, or a dictionary mapping module name to device.
|
563 |
+
offload (`bool`, *optional*, defaults to `False`):
|
564 |
+
Whether or not the weights should be offloaded after the forward pass. It can be one boolean for the whole
|
565 |
+
module, or a dictionary mapping module name to boolean.
|
566 |
+
weights_map (`Mapping[str, torch.Tensor]`, *optional*):
|
567 |
+
When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.
|
568 |
+
offload_buffers (`bool`, *optional*, defaults to `False`):
|
569 |
+
Whether or not to include the associated module's buffers when offloading.
|
570 |
+
module_name (`str`, *optional*, defaults to `""`):
|
571 |
+
The name of the module.
|
572 |
+
skip_keys (`str` or `List[str]`, *optional*):
|
573 |
+
A list of keys to ignore when moving inputs or outputs between devices.
|
574 |
+
preload_module_classes (`List[str]`, *optional*):
|
575 |
+
A list of classes whose instances should load all their weights (even in the submodules) at the beginning
|
576 |
+
of the forward. This should only be used for classes that have submodules which are registered but not
|
577 |
+
called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
|
578 |
+
`dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
|
579 |
+
tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`):
|
580 |
+
A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution
|
581 |
+
device, this parameter is useful to reuse the first available pointer of a shared weight for all others,
|
582 |
+
instead of duplicating memory.
|
583 |
+
"""
|
584 |
+
# If one device and one offload, we've got one hook.
|
585 |
+
if not isinstance(execution_device, Mapping) and not isinstance(offload, dict):
|
586 |
+
if not offload:
|
587 |
+
hook = AlignDevicesHook(
|
588 |
+
execution_device=execution_device,
|
589 |
+
io_same_device=True,
|
590 |
+
skip_keys=skip_keys,
|
591 |
+
place_submodules=True,
|
592 |
+
tied_params_map=tied_params_map,
|
593 |
+
)
|
594 |
+
add_hook_to_module(module, hook)
|
595 |
+
else:
|
596 |
+
attach_align_device_hook(
|
597 |
+
module,
|
598 |
+
execution_device=execution_device,
|
599 |
+
offload=True,
|
600 |
+
weights_map=weights_map,
|
601 |
+
offload_buffers=offload_buffers,
|
602 |
+
module_name=module_name,
|
603 |
+
skip_keys=skip_keys,
|
604 |
+
tied_params_map=tied_params_map,
|
605 |
+
)
|
606 |
+
return
|
607 |
+
|
608 |
+
if not isinstance(execution_device, Mapping):
|
609 |
+
execution_device = {key: execution_device for key in offload.keys()}
|
610 |
+
if not isinstance(offload, Mapping):
|
611 |
+
offload = {key: offload for key in execution_device.keys()}
|
612 |
+
|
613 |
+
if module_name in execution_device and module_name in offload and not offload[module_name]:
|
614 |
+
hook = AlignDevicesHook(
|
615 |
+
execution_device=execution_device[module_name],
|
616 |
+
offload_buffers=offload_buffers,
|
617 |
+
io_same_device=(module_name == ""),
|
618 |
+
place_submodules=True,
|
619 |
+
skip_keys=skip_keys,
|
620 |
+
tied_params_map=tied_params_map,
|
621 |
+
)
|
622 |
+
add_hook_to_module(module, hook)
|
623 |
+
attach_execution_device_hook(
|
624 |
+
module, execution_device[module_name], skip_keys=skip_keys, tied_params_map=tied_params_map
|
625 |
+
)
|
626 |
+
elif module_name in execution_device and module_name in offload:
|
627 |
+
attach_align_device_hook(
|
628 |
+
module,
|
629 |
+
execution_device=execution_device[module_name],
|
630 |
+
offload=True,
|
631 |
+
weights_map=weights_map,
|
632 |
+
offload_buffers=offload_buffers,
|
633 |
+
module_name=module_name,
|
634 |
+
skip_keys=skip_keys,
|
635 |
+
preload_module_classes=preload_module_classes,
|
636 |
+
tied_params_map=tied_params_map,
|
637 |
+
)
|
638 |
+
if not hasattr(module, "_hf_hook"):
|
639 |
+
hook = AlignDevicesHook(
|
640 |
+
execution_device=execution_device[module_name],
|
641 |
+
io_same_device=(module_name == ""),
|
642 |
+
skip_keys=skip_keys,
|
643 |
+
tied_params_map=tied_params_map,
|
644 |
+
)
|
645 |
+
add_hook_to_module(module, hook)
|
646 |
+
attach_execution_device_hook(
|
647 |
+
module,
|
648 |
+
execution_device[module_name],
|
649 |
+
preload_module_classes=preload_module_classes,
|
650 |
+
skip_keys=skip_keys,
|
651 |
+
tied_params_map=tied_params_map,
|
652 |
+
)
|
653 |
+
elif module_name == "":
|
654 |
+
hook = AlignDevicesHook(
|
655 |
+
execution_device=execution_device.get(""),
|
656 |
+
io_same_device=True,
|
657 |
+
skip_keys=skip_keys,
|
658 |
+
tied_params_map=tied_params_map,
|
659 |
+
)
|
660 |
+
add_hook_to_module(module, hook)
|
661 |
+
|
662 |
+
for child_name, child in module.named_children():
|
663 |
+
child_name = f"{module_name}.{child_name}" if len(module_name) > 0 else child_name
|
664 |
+
attach_align_device_hook_on_blocks(
|
665 |
+
child,
|
666 |
+
execution_device=execution_device,
|
667 |
+
offload=offload,
|
668 |
+
weights_map=weights_map,
|
669 |
+
offload_buffers=offload_buffers,
|
670 |
+
module_name=child_name,
|
671 |
+
preload_module_classes=preload_module_classes,
|
672 |
+
skip_keys=skip_keys,
|
673 |
+
tied_params_map=tied_params_map,
|
674 |
+
)
|
675 |
+
|
676 |
+
|
677 |
+
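# --- Illustrative usage sketch (not part of the library source) ---
# A hedged example of `attach_align_device_hook` with CPU offload. The model shape and the
# availability of a CUDA device are assumptions for illustration only.
def _example_attach_align_device_hook():
    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
    if torch.cuda.is_available():
        # Weights stay offloaded (kept on CPU, replaced by meta tensors in the module) and are
        # copied to the execution device only for the duration of each submodule's forward pass.
        attach_align_device_hook(model, execution_device=torch.device("cuda:0"), offload=True)
        output = model(torch.randn(4, 8))  # inputs are moved to the execution device by the hooks
        remove_hook_from_submodules(model)
        return output
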
class CpuOffload(ModelHook):
|
678 |
+
"""
|
679 |
+
Offloads a model on the CPU until its forward pass is called. The model will not be offloaded back to the CPU after
|
680 |
+
the forward pass; the user needs to call the `init_hook` method again for this.
|
681 |
+
|
682 |
+
Args:
|
683 |
+
execution_device(`str`, `int` or `torch.device`, *optional*):
|
684 |
+
The device on which the model should be executed. Will default to the MPS device if it's available, then
|
685 |
+
GPU 0 if there is a GPU, and finally to the CPU.
|
686 |
+
prev_module_hook (`UserCpuOffloadHook`, *optional*):
|
687 |
+
The hook sent back by [`cpu_offload_with_hook`] for a previous model in the pipeline you are running. If
|
688 |
+
passed, its offload method will be called just before the forward of the model to which this hook is
|
689 |
+
attached.
|
690 |
+
"""
|
691 |
+
|
692 |
+
def __init__(
|
693 |
+
self,
|
694 |
+
execution_device: Optional[Union[str, int, torch.device]] = None,
|
695 |
+
prev_module_hook: Optional["UserCpuOffloadHook"] = None,
|
696 |
+
):
|
697 |
+
self.prev_module_hook = prev_module_hook
|
698 |
+
|
699 |
+
self.execution_device = execution_device if execution_device is not None else PartialState().default_device
|
700 |
+
|
701 |
+
def init_hook(self, module):
|
702 |
+
return module.to("cpu")
|
703 |
+
|
704 |
+
def pre_forward(self, module, *args, **kwargs):
|
705 |
+
if self.prev_module_hook is not None:
|
706 |
+
self.prev_module_hook.offload()
|
707 |
+
clear_device_cache()
|
708 |
+
module.to(self.execution_device)
|
709 |
+
return send_to_device(args, self.execution_device), send_to_device(kwargs, self.execution_device)
|
710 |
+
|
711 |
+
|
712 |
+
class UserCpuOffloadHook:
|
713 |
+
"""
|
714 |
+
A simple hook grouping a model and a `ModelHook`, which provides easy APIs to call the init method of the hook
|
715 |
+
or remove it entirely.
|
716 |
+
"""
|
717 |
+
|
718 |
+
def __init__(self, model, hook):
|
719 |
+
self.model = model
|
720 |
+
self.hook = hook
|
721 |
+
|
722 |
+
def offload(self):
|
723 |
+
self.hook.init_hook(self.model)
|
724 |
+
|
725 |
+
def remove(self):
|
726 |
+
remove_hook_from_module(self.model)
|
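

# --- Illustrative usage sketch (not part of the library source) ---
# A hedged example of chaining two CPU-offloaded models with `CpuOffload` and
# `UserCpuOffloadHook`, mirroring what a sequential pipeline of models might do. The toy models
# and the `_example_*` name are assumptions for illustration only.
def _example_cpu_offload_chain():
    model_a = torch.nn.Linear(4, 4)
    model_b = torch.nn.Linear(4, 4)

    device = PartialState().default_device
    hook_a = CpuOffload(execution_device=device)
    add_hook_to_module(model_a, hook_a)  # `init_hook` moves model_a to CPU right away
    user_hook_a = UserCpuOffloadHook(model_a, hook_a)

    # When model_b's forward runs, model_a is offloaded back to CPU first.
    hook_b = CpuOffload(execution_device=device, prev_module_hook=user_hook_a)
    add_hook_to_module(model_b, hook_b)

    x = torch.randn(2, 4)
    return model_b(model_a(x))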
.venv/Lib/site-packages/accelerate/inference.py
ADDED
@@ -0,0 +1,184 @@
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import math
|
15 |
+
from types import MethodType
|
16 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
17 |
+
|
18 |
+
from .state import PartialState
|
19 |
+
from .utils import (
|
20 |
+
calculate_maximum_sizes,
|
21 |
+
convert_bytes,
|
22 |
+
copy_tensor_to_devices,
|
23 |
+
ignorant_find_batch_size,
|
24 |
+
infer_auto_device_map,
|
25 |
+
is_pippy_available,
|
26 |
+
pad_input_tensors,
|
27 |
+
send_to_device,
|
28 |
+
)
|
29 |
+
|
30 |
+
|
31 |
+
def generate_device_map(model, num_processes: int = 1, no_split_module_classes=None, max_memory: dict = None):
|
32 |
+
"""
|
33 |
+
Calculates the device map for `model` with an offset for PiPPy
|
34 |
+
"""
|
35 |
+
if num_processes == 1:
|
36 |
+
return infer_auto_device_map(model, no_split_module_classes=no_split_module_classes, clean_result=False)
|
37 |
+
if max_memory is None:
|
38 |
+
model_size, shared = calculate_maximum_sizes(model)
|
39 |
+
|
40 |
+
# Split into `n` chunks for each GPU
|
41 |
+
memory = (model_size + shared[0]) / num_processes
|
42 |
+
memory = convert_bytes(memory)
|
43 |
+
value, ending = memory.split(" ")
|
44 |
+
|
45 |
+
# Add a chunk to deal with potential extra shared memory instances
|
46 |
+
memory = math.ceil(float(value)) * 1.1
|
47 |
+
memory = f"{memory} {ending}"
|
48 |
+
max_memory = {i: memory for i in range(num_processes)}
|
49 |
+
device_map = infer_auto_device_map(
|
50 |
+
model,
|
51 |
+
max_memory=max_memory,
|
52 |
+
no_split_module_classes=no_split_module_classes,
|
53 |
+
clean_result=False,
|
54 |
+
)
|
55 |
+
return device_map
|
56 |
+
|
57 |
+
|
58 |
+
def find_pippy_batch_size(args, kwargs):
|
59 |
+
found_batch_size = None
|
60 |
+
if args is not None:
|
61 |
+
for arg in args:
|
62 |
+
found_batch_size = ignorant_find_batch_size(arg)
|
63 |
+
if found_batch_size is not None:
|
64 |
+
break
|
65 |
+
if kwargs is not None and found_batch_size is None:
|
66 |
+
for kwarg in kwargs.values():
|
67 |
+
found_batch_size = ignorant_find_batch_size(kwarg)
|
68 |
+
if found_batch_size is not None:
|
69 |
+
break
|
70 |
+
return found_batch_size
|
71 |
+
|
72 |
+
|
73 |
+
def build_pipeline(model, split_points, args, kwargs, num_chunks):
|
74 |
+
"""
|
75 |
+
Attaches the split points to the model based on `self.device_map` and generates a `PipelineStage`. Requires passing
|
76 |
+
in needed `args` and `kwargs` as the model needs on the CPU.
|
77 |
+
|
78 |
+
Users can pass in a custom `num_chunks` as an optional hyper-parameter. By default it will use
|
79 |
+
`AcceleratorState.num_processes`
|
80 |
+
"""
|
81 |
+
# Note: We import here to reduce import time from general modules, and isolate outside dependencies
|
82 |
+
from torch.distributed.pipelining import ScheduleGPipe, SplitPoint, pipeline
|
83 |
+
|
84 |
+
# We need to annotate the split points in the model for PiPPy
|
85 |
+
state = PartialState()
|
86 |
+
split_spec = {split_point: SplitPoint.BEGINNING for split_point in split_points}
|
87 |
+
pipe = pipeline(
|
88 |
+
model,
|
89 |
+
mb_args=args,
|
90 |
+
mb_kwargs=kwargs,
|
91 |
+
split_spec=split_spec,
|
92 |
+
)
|
93 |
+
stage = pipe.build_stage(state.local_process_index, device=state.device)
|
94 |
+
schedule = ScheduleGPipe(stage, num_chunks)
|
95 |
+
|
96 |
+
return schedule
|
97 |
+
|
98 |
+
|
99 |
+
def pippy_forward(forward, num_chunks, gather_output, *args, **kwargs):
|
100 |
+
state = PartialState()
|
101 |
+
output = None
|
102 |
+
|
103 |
+
if state.num_processes == 1:
|
104 |
+
output = forward(*args, **kwargs)
|
105 |
+
elif state.is_local_main_process:
|
106 |
+
found_batch_size = find_pippy_batch_size(args, kwargs)
|
107 |
+
if found_batch_size is None:
|
108 |
+
raise ValueError("Could not find batch size from args or kwargs")
|
109 |
+
else:
|
110 |
+
if found_batch_size != num_chunks:
|
111 |
+
args = pad_input_tensors(args, found_batch_size, num_chunks)
|
112 |
+
kwargs = pad_input_tensors(kwargs, found_batch_size, num_chunks)
|
113 |
+
forward(*args, **kwargs)
|
114 |
+
elif state.is_last_process:
|
115 |
+
output = forward()
|
116 |
+
else:
|
117 |
+
forward()
|
118 |
+
if gather_output:
|
119 |
+
# Each node will get a copy of the full output which is only on the last GPU
|
120 |
+
output = copy_tensor_to_devices(output)
|
121 |
+
return output
|
122 |
+
|
123 |
+
|
124 |
+
def prepare_pippy(
|
125 |
+
model,
|
126 |
+
split_points: Optional[Union[str, List[str]]] = "auto",
|
127 |
+
no_split_module_classes: Optional[List[str]] = None,
|
128 |
+
example_args: Optional[Tuple[Any]] = (),
|
129 |
+
example_kwargs: Optional[Dict[str, Any]] = None,
|
130 |
+
num_chunks: Optional[int] = None,
|
131 |
+
gather_output: Optional[bool] = False,
|
132 |
+
):
|
133 |
+
"""
|
134 |
+
Wraps `model` for pipeline parallel inference.
|
135 |
+
|
136 |
+
Args:
|
137 |
+
model (`torch.nn.Module`):
|
138 |
+
A model we want to split for pipeline-parallel inference
|
139 |
+
split_points (`str` or `List[str]`, defaults to 'auto'):
|
140 |
+
How to generate the split points and chunk the model across each GPU. 'auto' will find the best balanced
|
141 |
+
split given any model. Should be a list of layer names in the model to split by otherwise.
|
142 |
+
no_split_module_classes (`List[str]`):
|
143 |
+
A list of class names for layers we don't want to be split.
|
144 |
+
example_args (tuple of model inputs):
|
145 |
+
The expected inputs for the model that uses order-based inputs for a *single process*. Recommended to use
|
146 |
+
this method if possible.
|
147 |
+
example_kwargs (dict of model inputs)
|
148 |
+
The expected inputs for the model that uses dictionary-based inputs for a *single process*. This is a
|
149 |
+
*highly* limiting structure that requires the same keys be present at *all* inference calls. Not
|
150 |
+
recommended unless the prior condition is true for all cases.
|
151 |
+
num_chunks (`int`, defaults to the number of available GPUs):
|
152 |
+
The number of different stages the Pipeline will have. By default it will assign one chunk per GPU, but
|
153 |
+
this can be tuned and played with. In general one should have num_chunks >= num_gpus.
|
154 |
+
gather_output (`bool`, defaults to `False`):
|
155 |
+
If `True`, the output from the last GPU (which holds the true outputs) is sent across to all GPUs.
|
156 |
+
"""
|
157 |
+
if not is_pippy_available():
|
158 |
+
raise ImportError("Using `torch.distributed.pipelining` requires PyTorch 2.4.0 or later.")
|
159 |
+
state = PartialState()
|
160 |
+
example_args = send_to_device(example_args, "cpu")
|
161 |
+
example_kwargs = send_to_device(example_kwargs, "cpu")
|
162 |
+
if num_chunks is None:
|
163 |
+
num_chunks = state.num_processes
|
164 |
+
if split_points == "auto":
|
165 |
+
device_map = generate_device_map(model, num_chunks, no_split_module_classes=no_split_module_classes)
|
166 |
+
split_points = []
|
167 |
+
for i in range(1, num_chunks):
|
168 |
+
split_points.append(next(k for k, v in device_map.items() if v == i))
|
169 |
+
model.hf_split_points = split_points
|
170 |
+
stage = build_pipeline(model, split_points, example_args, example_kwargs, num_chunks)
|
171 |
+
model._original_forward = model.forward
|
172 |
+
model._original_call = model.__call__
|
173 |
+
model.pippy_stage = stage
|
174 |
+
model.hf_split_points = split_points
|
175 |
+
|
176 |
+
def forward(*args, **kwargs):
|
177 |
+
return pippy_forward(stage.step, num_chunks, gather_output, *args, **kwargs)
|
178 |
+
|
179 |
+
# To act like a decorator so that it can be popped when doing `extract_model_from_parallel`
|
180 |
+
# Note: creates an infinite recursion loop with `generate`
|
181 |
+
model_forward = MethodType(forward, model)
|
182 |
+
forward.__wrapped__ = model_forward
|
183 |
+
model.forward = forward
|
184 |
+
return model
|
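A minimal usage sketch for `prepare_pippy` as documented above. The model class, the input shape, and a two-process `accelerate launch` are illustrative assumptions, not part of the diff:

```python
# Sketch: pipeline-parallel inference with prepare_pippy, run under `accelerate launch` on 2+ GPUs.
import torch

from accelerate.inference import prepare_pippy

model = MyModel()  # hypothetical torch.nn.Module used only for illustration
example_input = torch.randint(0, 1000, (2, 128))  # a batch for a *single* process

# Split the model across the available GPUs and wrap its forward pass.
model = prepare_pippy(model, example_args=(example_input,), gather_output=True)

with torch.no_grad():
    # Every rank receives the output because gather_output=True.
    output = model(example_input)
```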
.venv/Lib/site-packages/accelerate/launchers.py
ADDED
@@ -0,0 +1,302 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import tempfile

import torch

from .state import AcceleratorState, PartialState
from .utils import (
    PrecisionType,
    PrepareForLaunch,
    are_libraries_initialized,
    check_cuda_p2p_ib_support,
    get_gpu_info,
    is_mps_available,
    is_torch_version,
    patch_environment,
)
from .utils.constants import ELASTIC_LOG_LINE_PREFIX_TEMPLATE_PYTORCH_VERSION


def test_launch():
    "Verify a `PartialState` can be initialized."
    _ = PartialState()


def notebook_launcher(
    function,
    args=(),
    num_processes=None,
    mixed_precision="no",
    use_port="29500",
    master_addr="127.0.0.1",
    node_rank=0,
    num_nodes=1,
    rdzv_backend="static",
    rdzv_endpoint="",
    rdzv_conf=None,
    rdzv_id="none",
    max_restarts=0,
    monitor_interval=0.1,
    log_line_prefix_template=None,
):
    """
    Launches a training function, using several processes or multiple nodes if it's possible in the current environment
    (TPU with multiple cores for instance).

    <Tip warning={true}>

    To use this function absolutely zero calls to a CUDA device must be made in the notebook session before calling. If
    any have been made, you will need to restart the notebook and make sure no cells use any CUDA capability.

    Setting `ACCELERATE_DEBUG_MODE="1"` in your environment will run a test before truly launching to ensure that none
    of those calls have been made.

    </Tip>

    Args:
        function (`Callable`):
            The training function to execute. If it accepts arguments, the first argument should be the index of the
            process run.
        args (`Tuple`):
            Tuple of arguments to pass to the function (it will receive `*args`).
        num_processes (`int`, *optional*):
            The number of processes to use for training. Will default to 8 in Colab/Kaggle if a TPU is available, to
            the number of GPUs available otherwise.
        mixed_precision (`str`, *optional*, defaults to `"no"`):
            If `fp16` or `bf16`, will use mixed precision training on multi-GPU.
        use_port (`str`, *optional*, defaults to `"29500"`):
            The port to use to communicate between processes when launching a multi-GPU training.
        master_addr (`str`, *optional*, defaults to `"127.0.0.1"`):
            The address to use for communication between processes.
        node_rank (`int`, *optional*, defaults to 0):
            The rank of the current node.
        num_nodes (`int`, *optional*, defaults to 1):
            The number of nodes to use for training.
        rdzv_backend (`str`, *optional*, defaults to `"static"`):
            The rendezvous method to use, such as 'static' (the default) or 'c10d'
        rdzv_endpoint (`str`, *optional*, defaults to `""`):
            The endpoint of the rdzv sync. storage.
        rdzv_conf (`Dict`, *optional*, defaults to `None`):
            Additional rendezvous configuration.
        rdzv_id (`str`, *optional*, defaults to `"none"`):
            The unique run id of the job.
        max_restarts (`int`, *optional*, defaults to 0):
            The maximum amount of restarts that elastic agent will conduct on workers before failure.
        monitor_interval (`float`, *optional*, defaults to 0.1):
            The interval in seconds that is used by the elastic_agent as a period of monitoring workers.
        log_line_prefix_template (`str`, *optional*, defaults to `None`):
            The prefix template for elastic launch logging. Available from PyTorch 2.2.0.

    Example:

    ```python
    # Assume this is defined in a Jupyter Notebook on an instance with two GPUs
    from accelerate import notebook_launcher


    def train(*args):
        # Your training function here
        ...


    notebook_launcher(train, args=(arg1, arg2), num_processes=2, mixed_precision="fp16")
    ```
    """
    # Are we in a google colab or a Kaggle Kernel?
    in_colab = False
    in_kaggle = False
    if any(key.startswith("KAGGLE") for key in os.environ.keys()):
        in_kaggle = True
    elif "IPython" in sys.modules:
        in_colab = "google.colab" in str(sys.modules["IPython"].get_ipython())

    try:
        mixed_precision = PrecisionType(mixed_precision.lower())
    except ValueError:
        raise ValueError(
            f"Unknown mixed_precision mode: {args.mixed_precision.lower()}. Choose between {PrecisionType.list()}."
        )

    if (in_colab or in_kaggle) and (os.environ.get("TPU_NAME", None) is not None):
        # TPU launch
        import torch_xla.distributed.xla_multiprocessing as xmp

        if len(AcceleratorState._shared_state) > 0:
            raise ValueError(
                "To train on TPU in Colab or Kaggle Kernel, the `Accelerator` should only be initialized inside "
                "your training function. Restart your notebook and make sure no cells initializes an "
                "`Accelerator`."
            )
        if num_processes is None:
            num_processes = 8

        launcher = PrepareForLaunch(function, distributed_type="XLA")
        print(f"Launching a training on {num_processes} TPU cores.")
        xmp.spawn(launcher, args=args, nprocs=num_processes, start_method="fork")
    elif in_colab and get_gpu_info()[1] < 2:
        # No need for a distributed launch otherwise as it's either CPU or one GPU.
        if torch.cuda.is_available():
            print("Launching training on one GPU.")
        else:
            print("Launching training on one CPU.")
        function(*args)
    else:
        if num_processes is None:
            raise ValueError(
                "You have to specify the number of GPUs you would like to use, add `num_processes=...` to your call."
            )
        if node_rank >= num_nodes:
            raise ValueError("The node_rank must be less than the number of nodes.")
        if num_processes > 1:
            # Multi-GPU launch
            from torch.distributed.launcher.api import LaunchConfig, elastic_launch
            from torch.multiprocessing import start_processes
            from torch.multiprocessing.spawn import ProcessRaisedException

            if len(AcceleratorState._shared_state) > 0:
                raise ValueError(
                    "To launch a multi-GPU training from your notebook, the `Accelerator` should only be initialized "
                    "inside your training function. Restart your notebook and make sure no cells initializes an "
                    "`Accelerator`."
                )
            # Check for specific libraries known to initialize CUDA that users constantly use
            problematic_imports = are_libraries_initialized("bitsandbytes")
            if len(problematic_imports) > 0:
                err = (
                    "Could not start distributed process. Libraries known to initialize CUDA upon import have been "
                    "imported already. Please keep these imports inside your training function to try and help with this:"
                )
                for lib_name in problematic_imports:
                    err += f"\n\t* `{lib_name}`"
                raise RuntimeError(err)

            patched_env = dict(
                nproc=num_processes,
                node_rank=node_rank,
                world_size=num_nodes * num_processes,
                master_addr=master_addr,
                master_port=use_port,
                mixed_precision=mixed_precision,
            )

            # Check for CUDA P2P and IB issues
            if not check_cuda_p2p_ib_support():
                patched_env["nccl_p2p_disable"] = "1"
                patched_env["nccl_ib_disable"] = "1"

            # torch.distributed will expect a few environment variable to be here. We set the ones common to each
            # process here (the other ones will be set be the launcher).
            with patch_environment(**patched_env):
                # First dummy launch
                if os.environ.get("ACCELERATE_DEBUG_MODE", "false").lower() == "true":
                    launcher = PrepareForLaunch(test_launch, distributed_type="MULTI_GPU")
                    try:
                        start_processes(launcher, args=(), nprocs=num_processes, start_method="fork")
                    except ProcessRaisedException as e:
                        err = "An issue was found when verifying a stable environment for the notebook launcher."
                        if "Cannot re-initialize CUDA in forked subprocess" in e.args[0]:
                            raise RuntimeError(
                                f"{err}"
                                "This likely stems from an outside import causing issues once the `notebook_launcher()` is called. "
                                "Please review your imports and test them when running the `notebook_launcher()` to identify "
                                "which one is problematic and causing CUDA to be initialized."
                            ) from e
                        else:
                            raise RuntimeError(f"{err} The following error was raised: {e}") from e
                # Now the actual launch
                launcher = PrepareForLaunch(function, distributed_type="MULTI_GPU")
                print(f"Launching training on {num_processes} GPUs.")
                try:
                    if rdzv_conf is None:
                        rdzv_conf = {}
                    if rdzv_backend == "static":
                        rdzv_conf["rank"] = node_rank
                        if not rdzv_endpoint:
                            rdzv_endpoint = f"{master_addr}:{use_port}"
                    launch_config_kwargs = dict(
                        min_nodes=num_nodes,
                        max_nodes=num_nodes,
                        nproc_per_node=num_processes,
                        run_id=rdzv_id,
                        rdzv_endpoint=rdzv_endpoint,
                        rdzv_backend=rdzv_backend,
                        rdzv_configs=rdzv_conf,
                        max_restarts=max_restarts,
                        monitor_interval=monitor_interval,
                        start_method="fork",
                    )
                    if is_torch_version(">=", ELASTIC_LOG_LINE_PREFIX_TEMPLATE_PYTORCH_VERSION):
                        launch_config_kwargs["log_line_prefix_template"] = log_line_prefix_template
                    elastic_launch(config=LaunchConfig(**launch_config_kwargs), entrypoint=function)(*args)
                except ProcessRaisedException as e:
                    if "Cannot re-initialize CUDA in forked subprocess" in e.args[0]:
                        raise RuntimeError(
                            "CUDA has been initialized before the `notebook_launcher` could create a forked subprocess. "
                            "This likely stems from an outside import causing issues once the `notebook_launcher()` is called. "
                            "Please review your imports and test them when running the `notebook_launcher()` to identify "
                            "which one is problematic and causing CUDA to be initialized."
                        ) from e
                    else:
                        raise RuntimeError(f"An issue was found when launching the training: {e}") from e

        else:
            # No need for a distributed launch otherwise as it's either CPU, GPU or MPS.
            if is_mps_available():
                os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
                print("Launching training on MPS.")
            elif torch.cuda.is_available():
                print("Launching training on one GPU.")
            else:
                print("Launching training on CPU.")
            function(*args)


def debug_launcher(function, args=(), num_processes=2):
    """
    Launches a training function using several processes on CPU for debugging purposes.

    <Tip warning={true}>

    This function is provided for internal testing and debugging, but it's not intended for real trainings. It will
    only use the CPU.

    </Tip>

    Args:
        function (`Callable`):
            The training function to execute.
        args (`Tuple`):
            Tuple of arguments to pass to the function (it will receive `*args`).
        num_processes (`int`, *optional*, defaults to 2):
            The number of processes to use for training.
    """
    from torch.multiprocessing import start_processes

    with tempfile.NamedTemporaryFile() as tmp_file:
        # torch.distributed will expect a few environment variable to be here. We set the ones common to each
        # process here (the other ones will be set be the launcher).
        with patch_environment(
            world_size=num_processes,
            master_addr="127.0.0.1",
            master_port="29500",
            accelerate_mixed_precision="no",
            accelerate_debug_rdv_file=tmp_file.name,
            accelerate_use_cpu="yes",
        ):
            launcher = PrepareForLaunch(function, debug=True)
            start_processes(launcher, args=args, nprocs=num_processes, start_method="fork")
.venv/Lib/site-packages/accelerate/local_sgd.py
ADDED
@@ -0,0 +1,104 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from accelerate import Accelerator, DistributedType


class LocalSGD:
    """
    A helper class to support local SGD on top of Accelerator. It simply runs a given number of updates independently
    on each device, and averages model weights every K synchronization step.

    It should be used only in the multi-GPU (or multi-CPU) setup without extensions such as DeepSpeed. In particular,
    this is a simple implementation that cannot support scenarios such as model parallelism.


    Although we are not aware of the true origins of this simple approach, the idea of local SGD is quite old and goes
    back to at least:

    Zhang, J., De Sa, C., Mitliagkas, I., & Ré, C. (2016). [Parallel SGD: When does averaging help?. arXiv preprint
    arXiv:1606.07365.](https://arxiv.org/abs/1606.07365)

    We credit the term Local SGD to the following paper (but there might be earlier references we are not aware of).

    Stich, Sebastian Urban. ["Local SGD Converges Fast and Communicates Little." ICLR 2019-International Conference on
    Learning Representations. No. CONF. 2019.](https://arxiv.org/abs/1805.09767)

    """

    def __enter__(self):
        if self.enabled:
            self.model_sync_obj = self.model.no_sync()
            self.model_sync_obj.__enter__()

        return self

    def __exit__(self, type, value, tb):
        if self.enabled:
            # Average all models on exit
            self._sync_and_avg_model_params()
            self.model_sync_obj.__exit__(type, value, tb)

    def __init__(self, accelerator: Accelerator, model: torch.nn.Module, local_sgd_steps: int, enabled: bool = True):
        """
        Constructor.

        Args:
            model (`torch.nn.Module):
                The model whose parameters we need to average.
            accelerator (`Accelerator`):
                Accelerator object.
            local_sgd_steps (`int`):
                A number of local SGD steps (before model parameters are synchronized).
            enabled (`bool):
                Local SGD is disabled if this parameter set to `False`.
        """
        if accelerator.distributed_type not in [
            DistributedType.NO,
            DistributedType.MULTI_CPU,
            DistributedType.MULTI_GPU,
            DistributedType.MULTI_XPU,
            DistributedType.MULTI_MLU,
            DistributedType.MULTI_MUSA,
            DistributedType.MULTI_NPU,
        ]:
            raise NotImplementedError("LocalSGD is supported only for CPUs and GPUs (no DeepSpeed or MegatronLM)")
        self.enabled = enabled and accelerator.distributed_type != DistributedType.NO
        self.num_steps = 0
        if self.enabled:
            self.accelerator = accelerator
            self.model = model
            self.local_sgd_steps = local_sgd_steps

    def step(self):
        """
        This function makes a "step" and synchronizes model parameters if necessary.
        """
        self.num_steps += 1
        if not self.enabled:
            return

        if self.num_steps % self.local_sgd_steps == 0:
            self._sync_and_avg_model_params()

    def _sync_and_avg_model_params(self):
        """
        Synchronize + Average model parameters across all GPUs
        """

        self.accelerator.wait_for_everyone()
        with self.accelerator.autocast():
            for param in self.model.parameters():
                param.data = self.accelerator.reduce(param.data, reduction="mean")
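A minimal sketch of driving a training loop through `LocalSGD`. The model, optimizer, dataloader, and the transformers-style `.loss` output are assumptions, and all of them are expected to have gone through `accelerator.prepare` first:

```python
# Sketch: average model weights every 8 local steps across processes.
from accelerate import Accelerator
from accelerate.local_sgd import LocalSGD

accelerator = Accelerator()
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)  # assumed objects

with LocalSGD(accelerator=accelerator, model=model, local_sgd_steps=8, enabled=True) as local_sgd:
    for batch in dataloader:
        optimizer.zero_grad()
        loss = model(**batch).loss  # assumes a model that returns an object with .loss
        accelerator.backward(loss)
        optimizer.step()
        local_sgd.step()  # triggers the parameter averaging every 8 steps
```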
.venv/Lib/site-packages/accelerate/logging.py
ADDED
@@ -0,0 +1,125 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import logging
import os

from .state import PartialState


class MultiProcessAdapter(logging.LoggerAdapter):
    """
    An adapter to assist with logging in multiprocess.

    `log` takes in an additional `main_process_only` kwarg, which dictates whether it should be called on all processes
    or only the main executed one. Default is `main_process_only=True`.

    Does not require an `Accelerator` object to be created first.
    """

    @staticmethod
    def _should_log(main_process_only):
        "Check if log should be performed"
        state = PartialState()
        return not main_process_only or (main_process_only and state.is_main_process)

    def log(self, level, msg, *args, **kwargs):
        """
        Delegates logger call after checking if we should log.

        Accepts a new kwarg of `main_process_only`, which will dictate whether it will be logged across all processes
        or only the main executed one. Default is `True` if not passed

        Also accepts "in_order", which if `True` makes the processes log one by one, in order. This is much easier to
        read, but comes at the cost of sometimes needing to wait for the other processes. Default is `False` to not
        break with the previous behavior.

        `in_order` is ignored if `main_process_only` is passed.
        """
        if PartialState._shared_state == {}:
            raise RuntimeError(
                "You must initialize the accelerate state by calling either `PartialState()` or `Accelerator()` before using the logging utility."
            )
        main_process_only = kwargs.pop("main_process_only", True)
        in_order = kwargs.pop("in_order", False)
        # set `stacklevel` to exclude ourself in `Logger.findCaller()` while respecting user's choice
        kwargs.setdefault("stacklevel", 2)

        if self.isEnabledFor(level):
            if self._should_log(main_process_only):
                msg, kwargs = self.process(msg, kwargs)
                self.logger.log(level, msg, *args, **kwargs)

            elif in_order:
                state = PartialState()
                for i in range(state.num_processes):
                    if i == state.process_index:
                        msg, kwargs = self.process(msg, kwargs)
                        self.logger.log(level, msg, *args, **kwargs)
                    state.wait_for_everyone()

    @functools.lru_cache(None)
    def warning_once(self, *args, **kwargs):
        """
        This method is identical to `logger.warning()`, but will emit the warning with the same message only once

        Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the
        cache. The assumption here is that all warning messages are unique across the code. If they aren't then need to
        switch to another type of cache that includes the caller frame information in the hashing function.
        """
        self.warning(*args, **kwargs)


def get_logger(name: str, log_level: str = None):
    """
    Returns a `logging.Logger` for `name` that can handle multiprocessing.

    If a log should be called on all processes, pass `main_process_only=False` If a log should be called on all
    processes and in order, also pass `in_order=True`

    Args:
        name (`str`):
            The name for the logger, such as `__file__`
        log_level (`str`, *optional*):
            The log level to use. If not passed, will default to the `LOG_LEVEL` environment variable, or `INFO` if not

    Example:

    ```python
    >>> from accelerate.logging import get_logger
    >>> from accelerate import Accelerator

    >>> logger = get_logger(__name__)

    >>> accelerator = Accelerator()
    >>> logger.info("My log", main_process_only=False)
    >>> logger.debug("My log", main_process_only=True)

    >>> logger = get_logger(__name__, log_level="DEBUG")
    >>> logger.info("My log")
    >>> logger.debug("My second log")

    >>> array = ["a", "b", "c", "d"]
    >>> letter_at_rank = array[accelerator.process_index]
    >>> logger.info(letter_at_rank, in_order=True)
    ```
    """
    if log_level is None:
        log_level = os.environ.get("ACCELERATE_LOG_LEVEL", None)
    logger = logging.getLogger(name)
    if log_level is not None:
        logger.setLevel(log_level.upper())
        logger.root.setLevel(log_level.upper())
    return MultiProcessAdapter(logger, {})
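The `get_logger` docstring above already demonstrates the main kwargs; a short sketch of `warning_once`, which the docstrings describe but do not demonstrate (the loop and message are illustrative only):

```python
# Sketch: warning_once deduplicates identical warnings across repeated calls.
from accelerate import Accelerator
from accelerate.logging import get_logger

accelerator = Accelerator()
logger = get_logger(__name__)

for _ in range(10):
    logger.warning_once("This message is emitted a single time, even inside the loop.")
```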
.venv/Lib/site-packages/accelerate/memory_utils.py
ADDED
@@ -0,0 +1,22 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings


warnings.warn(
    "memory_utils has been reorganized to utils.memory. Import `find_executable_batchsize` from the main `__init__`: "
    "`from accelerate import find_executable_batch_size` to avoid this warning.",
    FutureWarning,
)
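As the deprecation warning above suggests, the supported import now lives at the package root; a one-line sketch of the replacement import:

```python
# Recommended replacement for the deprecated `accelerate.memory_utils` import path.
from accelerate import find_executable_batch_size
```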
.venv/Lib/site-packages/accelerate/optimizer.py
ADDED
@@ -0,0 +1,212 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect

import torch

from .state import AcceleratorState, GradientState
from .utils import DistributedType, honor_type, is_lomo_available, is_torch_xla_available


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm


def move_to_device(state, device):
    if isinstance(state, (list, tuple)):
        return honor_type(state, (move_to_device(t, device) for t in state))
    elif isinstance(state, dict):
        return type(state)({k: move_to_device(v, device) for k, v in state.items()})
    elif isinstance(state, torch.Tensor):
        return state.to(device)
    return state


class AcceleratedOptimizer(torch.optim.Optimizer):
    """
    Internal wrapper around a torch optimizer.

    Conditionally will perform `step` and `zero_grad` if gradients should be synchronized when performing gradient
    accumulation.

    Args:
        optimizer (`torch.optim.optimizer.Optimizer`):
            The optimizer to wrap.
        device_placement (`bool`, *optional*, defaults to `True`):
            Whether or not the optimizer should handle device placement. If so, it will place the state dictionary of
            `optimizer` on the right device.
        scaler (`torch.cuda.amp.grad_scaler.GradScaler`, *optional*):
            The scaler to use in the step function if training with mixed precision.
    """

    def __init__(self, optimizer, device_placement=True, scaler=None):
        self.optimizer = optimizer
        self.scaler = scaler
        self.accelerator_state = AcceleratorState()
        self.gradient_state = GradientState()
        self.device_placement = device_placement
        self._is_overflow = False

        if self.scaler is not None:
            self._accelerate_step_called = False
            self._optimizer_original_step_method = self.optimizer.step
            self._optimizer_patched_step_method = patch_optimizer_step(self, self.optimizer.step)

        # Handle device placement
        if device_placement:
            state_dict = self.optimizer.state_dict()
            if self.accelerator_state.distributed_type == DistributedType.XLA:
                xm.send_cpu_data_to_device(state_dict, self.accelerator_state.device)
            else:
                state_dict = move_to_device(state_dict, self.accelerator_state.device)
            self.optimizer.load_state_dict(state_dict)

    @property
    def state(self):
        return self.optimizer.state

    @state.setter
    def state(self, state):
        self.optimizer.state = state

    @property
    def param_groups(self):
        return self.optimizer.param_groups

    @param_groups.setter
    def param_groups(self, param_groups):
        self.optimizer.param_groups = param_groups

    @property
    def defaults(self):
        return self.optimizer.defaults

    @defaults.setter
    def defaults(self, defaults):
        self.optimizer.defaults = defaults

    def add_param_group(self, param_group):
        self.optimizer.add_param_group(param_group)

    def load_state_dict(self, state_dict):
        if self.accelerator_state.distributed_type == DistributedType.XLA and self.device_placement:
            xm.send_cpu_data_to_device(state_dict, self.accelerator_state.device)
        self.optimizer.load_state_dict(state_dict)

    def state_dict(self):
        return self.optimizer.state_dict()

    def zero_grad(self, set_to_none=None):
        if self.gradient_state.sync_gradients:
            accept_arg = "set_to_none" in inspect.signature(self.optimizer.zero_grad).parameters
            if accept_arg:
                if set_to_none is None:
                    set_to_none = True
                self.optimizer.zero_grad(set_to_none=set_to_none)
            else:
                if set_to_none is not None:
                    raise ValueError("`set_to_none` for Optimizer.zero_grad` is not supported by this optimizer.")
                self.optimizer.zero_grad()

    def train(self):
        """
        Sets the optimizer to "train" mode. Useful for optimizers like `schedule_free`
        """
        if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
            self.optimizer.train()
        elif (
            hasattr(self.optimizer, "optimizer")
            and hasattr(self.optimizer.optimizer, "train")
            and callable(self.optimizer.optimizer.train)
        ):
            # the deepspeed optimizer further wraps the optimizer
            self.optimizer.optimizer.train()

    def eval(self):
        """
        Sets the optimizer to "eval" mode. Useful for optimizers like `schedule_free`
        """
        if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval):
            self.optimizer.eval()

    def step(self, closure=None):
        if is_lomo_available():
            from lomo_optim import AdaLomo, Lomo

        if (
            not self.gradient_state.is_xla_gradients_synced
            and self.accelerator_state.distributed_type == DistributedType.XLA
        ):
            gradients = xm._fetch_gradients(self.optimizer)
            xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
            self.gradient_state.is_xla_gradients_synced = True

        if is_lomo_available():
            # `step` should be a no-op for LOMO optimizers.
            if isinstance(self.optimizer, (Lomo, AdaLomo)):
                return

        if self.gradient_state.sync_gradients:
            if self.scaler is not None:
                self.optimizer.step = self._optimizer_patched_step_method

                self.scaler.step(self.optimizer, closure)
                self.scaler.update()

                if not self._accelerate_step_called:
                    # If the optimizer step was skipped, gradient overflow was detected.
                    self._is_overflow = True
                else:
                    self._is_overflow = False
                # Reset the step method to the original one
                self.optimizer.step = self._optimizer_original_step_method
                # Reset the indicator
                self._accelerate_step_called = False
            else:
                self.optimizer.step(closure)
            if self.accelerator_state.distributed_type == DistributedType.XLA:
                self.gradient_state.is_xla_gradients_synced = False

    def _switch_parameters(self, parameters_map):
        for param_group in self.optimizer.param_groups:
            param_group["params"] = [parameters_map.get(p, p) for p in param_group["params"]]

    @property
    def step_was_skipped(self):
        """Whether or not the optimizer step was skipped."""
        return self._is_overflow

    def __getstate__(self):
        _ignored_keys = [
            "_accelerate_step_called",
            "_optimizer_original_step_method",
            "_optimizer_patched_step_method",
        ]
        return {k: v for k, v in self.__dict__.items() if k not in _ignored_keys}

    def __setstate__(self, state):
        self.__dict__.update(state)
        if self.scaler is not None:
            self._accelerate_step_called = False
            self._optimizer_original_step_method = self.optimizer.step
            self._optimizer_patched_step_method = patch_optimizer_step(self, self.optimizer.step)


def patch_optimizer_step(accelerated_optimizer: AcceleratedOptimizer, method):
    def patched_step(*args, **kwargs):
        accelerated_optimizer._accelerate_step_called = True
        return method(*args, **kwargs)

    return patched_step
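A small sketch of how this wrapper is typically obtained and how `step_was_skipped` can be read after each update. The model, optimizer, dataloader, and the `.loss` output are assumptions; with `mixed_precision="fp16"` the flag reflects whether the GradScaler skipped the step because of overflow:

```python
# Sketch: accelerator.prepare returns an AcceleratedOptimizer wrapping the original one.
from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="fp16")
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)  # assumed objects

for batch in dataloader:
    optimizer.zero_grad()
    loss = model(**batch).loss  # assumes a model that returns an object with .loss
    accelerator.backward(loss)
    optimizer.step()
    if optimizer.step_was_skipped:
        # The patched step method was never called, so the scaler detected an overflow.
        print("Gradient overflow detected; this update was skipped.")
```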
.venv/Lib/site-packages/accelerate/scheduler.py
ADDED
@@ -0,0 +1,98 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# We ignore warnings about stepping the scheduler since we step it ourselves during gradient accumulation

import warnings

from .state import AcceleratorState, GradientState


warnings.filterwarnings("ignore", category=UserWarning, module="torch.optim.lr_scheduler")


class AcceleratedScheduler:
    """
    A wrapper around a learning rate scheduler that will only step when the optimizer(s) have a training step. Useful
    to avoid making a scheduler step too fast when gradients went overflow and there was no training step (in mixed
    precision training)

    When performing gradient accumulation scheduler lengths should not be changed accordingly, Accelerate will always
    step the scheduler to account for it.

    Args:
        scheduler (`torch.optim.lr_scheduler._LRScheduler`):
            The scheduler to wrap.
        optimizers (one or a list of `torch.optim.Optimizer`):
            The optimizers used.
        step_with_optimizer (`bool`, *optional*, defaults to `True`):
            Whether or not the scheduler should be stepped at each optimizer step.
        split_batches (`bool`, *optional*, defaults to `False`):
            Whether or not the dataloaders split one batch across the different processes (so batch size is the same
            regardless of the number of processes) or create batches on each process (so batch size is the original
            batch size multiplied by the number of processes).
    """

    def __init__(self, scheduler, optimizers, step_with_optimizer: bool = True, split_batches: bool = False):
        self.scheduler = scheduler
        self.optimizers = optimizers if isinstance(optimizers, (list, tuple)) else [optimizers]
        self.split_batches = split_batches
        self.step_with_optimizer = step_with_optimizer
        self.gradient_state = GradientState()

    def step(self, *args, **kwargs):
        if not self.step_with_optimizer:
            # No link between scheduler and optimizer -> just step
            self.scheduler.step(*args, **kwargs)
            return

        # Otherwise, first make sure the optimizer was stepped.
        if not self.gradient_state.sync_gradients:
            if self.gradient_state.adjust_scheduler:
                self.scheduler._step_count += 1
            return

        for opt in self.optimizers:
            if opt.step_was_skipped:
                return
        if self.split_batches:
            # Split batches -> the training dataloader batch size is not changed so one step per training step
            self.scheduler.step(*args, **kwargs)
        else:
            # Otherwise the training dataloader batch size was multiplied by `num_processes`, so we need to do
            # num_processes steps per training step
            num_processes = AcceleratorState().num_processes
            for _ in range(num_processes):
                # Special case when using OneCycle and `drop_last` was not used
                if hasattr(self.scheduler, "total_steps"):
                    if self.scheduler._step_count <= self.scheduler.total_steps:
                        self.scheduler.step(*args, **kwargs)
                else:
                    self.scheduler.step(*args, **kwargs)

    # Passthroughs
    def get_last_lr(self):
        return self.scheduler.get_last_lr()

    def state_dict(self):
        return self.scheduler.state_dict()

    def load_state_dict(self, state_dict):
        self.scheduler.load_state_dict(state_dict)

    def get_lr(self):
        return self.scheduler.get_lr()

    def print_lr(self, *args, **kwargs):
        return self.scheduler.print_lr(*args, **kwargs)
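A small sketch of the prepared scheduler in a mixed-precision loop. The objects fed to `accelerator.prepare` are assumptions; the point is that the usual `scheduler.step()` call stays in the loop unchanged, because the wrapper itself skips steps whose optimizer update was skipped:

```python
# Sketch: the wrapped scheduler only advances when the optimizer actually stepped.
from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="fp16")
model, optimizer, scheduler, dataloader = accelerator.prepare(model, optimizer, scheduler, dataloader)  # assumed

for batch in dataloader:
    optimizer.zero_grad()
    loss = model(**batch).loss  # assumes a model that returns an object with .loss
    accelerator.backward(loss)
    optimizer.step()
    scheduler.step()  # becomes a no-op on steps skipped because of fp16 overflow
```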
.venv/Lib/site-packages/accelerate/state.py
ADDED
@@ -0,0 +1,1257 @@
1 |
+
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import annotations
|
16 |
+
|
17 |
+
import logging
|
18 |
+
import os
|
19 |
+
import threading
|
20 |
+
import warnings
|
21 |
+
from contextlib import contextmanager
|
22 |
+
from functools import partial
|
23 |
+
from typing import Any, Callable, Optional
|
24 |
+
|
25 |
+
import torch
|
26 |
+
|
27 |
+
from .utils import (
|
28 |
+
DistributedType,
|
29 |
+
DynamoBackend,
|
30 |
+
GradientAccumulationPlugin,
|
31 |
+
check_cuda_p2p_ib_support,
|
32 |
+
check_fp8_capability,
|
33 |
+
deepspeed_required,
|
34 |
+
get_ccl_version,
|
35 |
+
get_cpu_distributed_information,
|
36 |
+
get_int_from_env,
|
37 |
+
is_ccl_available,
|
38 |
+
is_datasets_available,
|
39 |
+
is_deepspeed_available,
|
40 |
+
is_fp8_available,
|
41 |
+
is_ipex_available,
|
42 |
+
is_mlu_available,
|
43 |
+
is_mps_available,
|
44 |
+
is_musa_available,
|
45 |
+
is_npu_available,
|
46 |
+
is_torch_xla_available,
|
47 |
+
is_xpu_available,
|
48 |
+
parse_choice_from_env,
|
49 |
+
parse_flag_from_env,
|
50 |
+
set_numa_affinity,
|
51 |
+
)
|
52 |
+
from .utils.dataclasses import SageMakerDistributedType
|
53 |
+
|
54 |
+
|
55 |
+
if is_torch_xla_available():
|
56 |
+
import torch_xla.core.xla_model as xm
|
57 |
+
|
58 |
+
if is_mlu_available(check_device=False):
|
59 |
+
import torch_mlu # noqa: F401
|
60 |
+
|
61 |
+
if is_musa_available(check_device=False):
|
62 |
+
import torch_musa # noqa: F401
|
63 |
+
|
64 |
+
if is_npu_available(check_device=False):
|
65 |
+
import torch_npu # noqa: F401
|
66 |
+
|
67 |
+
logger = logging.getLogger(__name__)
|
68 |
+
|
69 |
+
|
70 |
+
def is_initialized() -> bool:
|
71 |
+
"""
|
72 |
+
Checks if the `AcceleratorState` has been initialized from `Accelerator`. Same as `AcceleratorState.initialized`,
|
73 |
+
but works as a module method.
|
74 |
+
"""
|
75 |
+
return AcceleratorState._shared_state != {}
|
76 |
+
|
77 |
+
|
78 |
+
# Lambda function that does nothing
|
79 |
+
def do_nothing(*args, **kwargs):
|
80 |
+
return None
|
81 |
+
|
82 |
+
|
83 |
+
class ThreadLocalSharedDict(threading.local):
|
84 |
+
"""
|
85 |
+
Descriptor that holds a dict shared between instances of a class in the same thread.
|
86 |
+
|
87 |
+
Note: Descriptors have slightly different semantics than just a dict field on its own.
|
88 |
+
`PartialState(...)._shared_state` and `PartialState._shared_state` (instance vs class) give the same value: the
|
89 |
+
underlying _storage dict. Likewise, `PartialState(...)._shared_state = {...}` overrides the _storage dict inside
|
90 |
+
the descriptor as you would expect. However, `PartialState._shared_state = {}` actually replaces the descriptor
|
91 |
+
object with a dict instead Thus, you should modify the _storage dict in-place (e.g. `_shared_state.clear()`).
|
92 |
+
|
93 |
+
See Python documentation for an explanation of descriptors: https://docs.python.org/3/howto/descriptor.html
|
94 |
+
|
95 |
+
This is required for using PyTorch/XLA with PJRT in multithreaded mode (required for TPU v2 and v3).
|
96 |
+
|
97 |
+
See https://github.com/pytorch/xla/blob/r2.0/docs/pjrt.md#multithreading-on-tpu-v2v3
|
98 |
+
"""
|
99 |
+
|
100 |
+
def __init__(self, thread_local: bool = False):
|
101 |
+
self._storage = {}
|
102 |
+
|
103 |
+
def __get__(self, obj, objtype=None):
|
104 |
+
return self._storage
|
105 |
+
|
106 |
+
def __set__(self, obj, value):
|
107 |
+
self._storage = value
|
108 |
+
|
109 |
+
|
110 |
+
# Prefer global shared dictionary, except when using TPU.
|
111 |
+
SharedDict = dict if not is_torch_xla_available() else ThreadLocalSharedDict
|
112 |
+
|
113 |
+
|
114 |
+
# Inspired by Alex Martelli's 'Borg'.
|
115 |
+
class PartialState:
|
116 |
+
"""
|
117 |
+
Singleton class that has information about the current training environment and functions to help with process
|
118 |
+
control. Designed to be used when only process control and device execution states are needed. Does *not* need to
|
119 |
+
be initialized from `Accelerator`.
|
120 |
+
|
121 |
+
Args:
|
122 |
+
cpu (`bool`, *optional*):
|
123 |
+
Whether or not to force the script to execute on CPU. Will ignore any accelerators available if set to
|
124 |
+
`True` and force the execution on the CPU.
|
125 |
+
kwargs (additional keyword arguments, *optional*):
|
126 |
+
Additional keyword arguments to pass to the relevent `init_process_group` function. Valid `kwargs` can be
|
127 |
+
found in [`utils.InitProcessGroupKwargs`]. See the example section for detailed usage.
|
128 |
+
|
129 |
+
**Available attributes:**
|
130 |
+
|
131 |
+
- **device** (`torch.device`) -- The device to use.
|
132 |
+
- **distributed_type** ([`~accelerate.state.DistributedType`]) -- The type of distributed environment currently
|
133 |
+
in use.
|
134 |
+
- **local_process_index** (`int`) -- The index of the current process on the current server.
|
135 |
+
- **mixed_precision** (`str`) -- Whether or not the current script will use mixed precision, and if so the type
|
136 |
+
of mixed precision being performed. (Choose from 'no','fp16','bf16 or 'fp8').
|
137 |
+
- **num_processes** (`int`) -- The number of processes currently launched in parallel.
|
138 |
+
- **process_index** (`int`) -- The index of the current process.
|
139 |
+
- **is_last_process** (`bool`) -- Whether or not the current process is the last one.
|
140 |
+
- **is_main_process** (`bool`) -- Whether or not the current process is the main one.
|
141 |
+
- **is_local_main_process** (`bool`) -- Whether or not the current process is the main one on the local node.
|
142 |
+
- **debug** (`bool`) -- Whether or not the current script is being run in debug mode.
|
143 |
+
|
144 |
+
Example:
|
145 |
+
```python
|
146 |
+
from accelerate.utils import InitProcessGroupKwargs
|
147 |
+
|
148 |
+
# To include `InitProcessGroupKwargs`, init then call `.to_kwargs()`
|
149 |
+
kwargs = InitProcessGroupKwargs(...).to_kwargs()
|
150 |
+
state = PartialState(**kwargs)
|
151 |
+
```
|
152 |
+
"""
|
153 |
+
|
154 |
+
_shared_state = SharedDict()
|
155 |
+
_known_attrs = [
|
156 |
+
"_cpu",
|
157 |
+
"_mixed_precision",
|
158 |
+
"_shared_state",
|
159 |
+
"backend",
|
160 |
+
"debug",
|
161 |
+
"device",
|
162 |
+
"distributed_type",
|
163 |
+
"fork_launched",
|
164 |
+
"local_process_index",
|
165 |
+
"num_processes",
|
166 |
+
"process_index",
|
167 |
+
]
|
168 |
+
|
169 |
+
def __init__(self, cpu: bool = False, **kwargs):
|
170 |
+
self.__dict__ = self._shared_state
|
171 |
+
if not self.initialized:
|
172 |
+
self._cpu = cpu
|
173 |
+
self.backend = None
|
174 |
+
env_device = os.environ.get("ACCELERATE_TORCH_DEVICE", None)
|
175 |
+
self.device = torch.device(env_device) if env_device is not None else None
|
176 |
+
self.debug = parse_flag_from_env("ACCELERATE_DEBUG_MODE")
|
177 |
+
use_sagemaker_dp = kwargs.pop("_use_sagemaker_dp", None)
|
178 |
+
dist_information = None
|
179 |
+
if use_sagemaker_dp is None:
|
180 |
+
use_sagemaker_dp = (
|
181 |
+
os.environ.get("ACCELERATE_USE_SAGEMAKER", "false") == "true"
|
182 |
+
and os.environ.get("ACCELERATE_SAGEMAKER_DISTRIBUTED_TYPE") != SageMakerDistributedType.NO
|
183 |
+
)
|
184 |
+
|
185 |
+
# Sets up self.backend + imports
|
186 |
+
original_backend = kwargs.pop("backend", None)
|
187 |
+
backend, distributed_type = self._prepare_backend(cpu, use_sagemaker_dp, original_backend)
|
188 |
+
if original_backend is not None and backend != original_backend:
|
189 |
+
raise ValueError(f"Your assigned backend {original_backend} is not avaliable, please use {backend}")
|
190 |
+
self.backend = backend
|
191 |
+
self.distributed_type = distributed_type
|
192 |
+
use_deepspeed = False
|
193 |
+
if not cpu and self.backend != "xla":
|
194 |
+
if int(os.environ.get("LOCAL_RANK", -1)) != -1:
|
195 |
+
# Deal with spawning deepspeed
|
196 |
+
if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true":
|
197 |
+
if not is_deepspeed_available():
|
198 |
+
raise ImportError(
|
199 |
+
"DeepSpeed is not available => install it using `pip3 install deepspeed` or build it from source"
|
200 |
+
)
|
201 |
+
from deepspeed import comm as dist
|
202 |
+
|
203 |
+
if not dist.is_initialized():
|
204 |
+
dist.init_distributed(dist_backend=self.backend, auto_mpi_discovery=False, **kwargs)
|
205 |
+
# We need to flag to `use_deepspeed` to be True to override `distributed_type` later
|
206 |
+
use_deepspeed = True
|
207 |
+
# Deal with all other backends but XPU and CPU, which get handled specially later
|
208 |
+
elif (
|
209 |
+
self.distributed_type not in (DistributedType.MULTI_XPU, DistributedType.MULTI_CPU)
|
210 |
+
and not torch.distributed.is_initialized()
|
211 |
+
):
|
212 |
+
torch.distributed.init_process_group(backend=self.backend, **kwargs)
|
213 |
+
# XPU and CPU require special env configs to be set
|
214 |
+
if self.distributed_type in (DistributedType.MULTI_XPU, DistributedType.MULTI_CPU):
|
215 |
+
dist_information = get_cpu_distributed_information()
|
216 |
+
os.environ["RANK"] = str(dist_information.rank)
|
217 |
+
os.environ["WORLD_SIZE"] = str(dist_information.world_size)
|
218 |
+
os.environ["LOCAL_RANK"] = str(dist_information.local_rank)
|
219 |
+
os.environ["LOCAL_WORLD_SIZE"] = str(dist_information.local_world_size)
|
220 |
+
if not os.environ.get("MASTER_PORT", None):
|
221 |
+
os.environ["MASTER_PORT"] = "29500"
|
222 |
+
if (
|
223 |
+
not os.environ.get("MASTER_ADDR", None)
|
224 |
+
and dist_information.local_world_size != dist_information.world_size
|
225 |
+
and self.backend != "mpi"
|
226 |
+
):
|
227 |
+
raise ValueError(
|
228 |
+
"Tried to launch on distributed with multinode, but `MASTER_ADDR` env was not set, "
|
229 |
+
"please try exporting rank 0's hostname as `MASTER_ADDR`"
|
230 |
+
)
|
231 |
+
kwargs["rank"] = dist_information.rank
|
232 |
+
kwargs["world_size"] = dist_information.world_size
|
233 |
+
|
234 |
+
if (
|
235 |
+
self.distributed_type == DistributedType.MULTI_CPU
|
236 |
+
and get_int_from_env(["OMP_NUM_THREADS"], 0) == 0
|
237 |
+
):
|
238 |
+
import psutil
|
239 |
+
|
240 |
+
num_cpu_threads_per_process = int(
|
241 |
+
psutil.cpu_count(logical=False) / dist_information.local_world_size
|
242 |
+
)
|
243 |
+
if num_cpu_threads_per_process == 0:
|
244 |
+
num_cpu_threads_per_process = 1
|
245 |
+
torch.set_num_threads(num_cpu_threads_per_process)
|
246 |
+
warnings.warn(
|
247 |
+
f"OMP_NUM_THREADS/MKL_NUM_THREADS unset, we set it at {num_cpu_threads_per_process} to improve oob"
|
248 |
+
" performance."
|
249 |
+
)
|
250 |
+
|
251 |
+
if not torch.distributed.is_initialized():
|
252 |
+
torch.distributed.init_process_group(backend=self.backend, **kwargs)
|
253 |
+
|
254 |
+
# No backend == no distributed training
|
255 |
+
if self.backend is None:
|
256 |
+
self.distributed_type = DistributedType.NO
|
257 |
+
self.num_processes = 1
|
258 |
+
self.process_index = 0
|
259 |
+
self.local_process_index = 0
|
260 |
+
elif self.backend == "xla":
|
261 |
+
# XLA needs device setting first for `set_replication`
|
262 |
+
self.set_device()
|
263 |
+
xm.set_replication(self.device, xm.get_xla_supported_devices())
|
264 |
+
self.num_processes = xm.xrt_world_size()
|
265 |
+
self.process_index = xm.get_ordinal()
|
266 |
+
if is_torch_xla_available(check_is_tpu=True):
|
267 |
+
self.local_process_index = xm.get_local_ordinal()
|
268 |
+
else:
|
269 |
+
self.local_process_index = int(os.environ.get("LOCAL_RANK", -1))
|
270 |
+
else:
|
271 |
+
self.num_processes = torch.distributed.get_world_size()
|
272 |
+
self.process_index = torch.distributed.get_rank()
|
273 |
+
self.local_process_index = (
|
274 |
+
int(os.environ.get("LOCAL_RANK", -1)) if dist_information is None else dist_information.local_rank
|
275 |
+
)
|
276 |
+
self.set_device()
|
277 |
+
# Now we can change to deepspeed
|
278 |
+
if use_deepspeed:
|
279 |
+
self.distributed_type = DistributedType.DEEPSPEED
|
280 |
+
|
281 |
+
# Set CPU affinity if enabled
|
282 |
+
if parse_flag_from_env("ACCELERATE_CPU_AFFINITY", False):
|
283 |
+
set_numa_affinity(self.local_process_index)
|
284 |
+
|
285 |
+
# Check for old RTX 4000's that can't use P2P or IB and are on old drivers
|
286 |
+
if self.device.type == "cuda" and not check_cuda_p2p_ib_support():
|
287 |
+
if "NCCL_P2P_DISABLE" not in os.environ or "NCCL_IB_DISABLE" not in os.environ:
|
288 |
+
raise NotImplementedError(
|
289 |
+
"Using RTX 4000 series doesn't support faster communication broadband via P2P or IB. "
|
290 |
+
'Please set `NCCL_P2P_DISABLE="1"` and `NCCL_IB_DISABLE="1"` or use `accelerate launch` which '
|
291 |
+
"will do this automatically."
|
292 |
+
)
|
293 |
+
# Important: This should be the *only* code outside of `self.initialized`!
|
294 |
+
self.fork_launched = parse_flag_from_env("FORK_LAUNCHED", 0)
|
295 |
+
|
296 |
+
def __repr__(self) -> str:
|
297 |
+
return (
|
298 |
+
f"Distributed environment: {self.distributed_type}{(' Backend: ' + self.backend) if self.backend else ''}\n"
|
299 |
+
f"Num processes: {self.num_processes}\n"
|
300 |
+
f"Process index: {self.process_index}\n"
|
301 |
+
f"Local process index: {self.local_process_index}\n"
|
302 |
+
f"Device: {self.device}\n"
|
303 |
+
)
|
304 |
+
|
305 |
+
@staticmethod
|
306 |
+
def _reset_state():
|
307 |
+
"Resets `_shared_state`, is used internally and should not be called"
|
308 |
+
PartialState._shared_state.clear()
|
309 |
+
|
310 |
+
@property
|
311 |
+
def initialized(self) -> bool:
|
312 |
+
"Returns whether the `PartialState` has been initialized"
|
313 |
+
return self._shared_state != {}
|
314 |
+
|
315 |
+
@property
|
316 |
+
def use_distributed(self):
|
317 |
+
"""
|
318 |
+
Whether the Accelerator is configured for distributed training
|
319 |
+
"""
|
320 |
+
return self.distributed_type != DistributedType.NO and self.num_processes > 1
|
321 |
+
|
322 |
+
@property
|
323 |
+
def is_last_process(self) -> bool:
|
324 |
+
"Returns whether the current process is the last one"
|
325 |
+
return self.process_index == self.num_processes - 1
|
326 |
+
|
327 |
+
@property
|
328 |
+
def is_main_process(self) -> bool:
|
329 |
+
"Returns whether the current process is the main process"
|
330 |
+
return (
|
331 |
+
self.process_index == 0 if self.distributed_type != DistributedType.MEGATRON_LM else self.is_last_process
|
332 |
+
)
|
333 |
+
|
334 |
+
@property
|
335 |
+
def is_local_main_process(self) -> bool:
|
336 |
+
"Returns whether the current process is the main process on the local node"
|
337 |
+
return (
|
338 |
+
self.local_process_index == 0
|
339 |
+
if self.distributed_type != DistributedType.MEGATRON_LM
|
340 |
+
else self.is_last_process
|
341 |
+
)
|
342 |
+
|
343 |
+
def wait_for_everyone(self):
|
344 |
+
"""
|
345 |
+
Will stop the execution of the current process until every other process has reached that point (so this does
|
346 |
+
nothing when the script is only run in one process). Useful to do before saving a model.
|
347 |
+
|
348 |
+
Example:
|
349 |
+
|
350 |
+
```python
|
351 |
+
>>> # Assuming two GPU processes
|
352 |
+
>>> import time
|
353 |
+
>>> from accelerate.state import PartialState
|
354 |
+
|
355 |
+
>>> state = PartialState()
|
356 |
+
>>> if state.is_main_process:
|
357 |
+
... time.sleep(2)
|
358 |
+
>>> else:
|
359 |
+
... print("I'm waiting for the main process to finish its sleep...")
|
360 |
+
>>> state.wait_for_everyone()
|
361 |
+
>>> # Should print on every process at the same time
|
362 |
+
>>> print("Everyone is here")
|
363 |
+
```
|
364 |
+
"""
|
365 |
+
if self.distributed_type in (
|
366 |
+
DistributedType.MULTI_GPU,
|
367 |
+
DistributedType.MULTI_MLU,
|
368 |
+
DistributedType.MULTI_MUSA,
|
369 |
+
DistributedType.MULTI_NPU,
|
370 |
+
DistributedType.MULTI_XPU,
|
371 |
+
DistributedType.MULTI_CPU,
|
372 |
+
DistributedType.DEEPSPEED,
|
373 |
+
DistributedType.FSDP,
|
374 |
+
):
|
375 |
+
torch.distributed.barrier()
|
376 |
+
elif self.distributed_type == DistributedType.XLA:
|
377 |
+
xm.rendezvous("accelerate.utils.wait_for_everyone")
|
378 |
+
|
379 |
+
def _goes_first(self, is_main: bool):
|
380 |
+
if not is_main:
|
381 |
+
self.wait_for_everyone()
|
382 |
+
|
383 |
+
yield
|
384 |
+
|
385 |
+
if is_main:
|
386 |
+
self.wait_for_everyone()
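# Note: there is exactly one barrier round here. Non-main processes block before
# the `yield`; the process with `is_main=True` runs the wrapped block first and
# releases the others when it reaches `wait_for_everyone()` after the block.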
|
387 |
+
|
388 |
+
@contextmanager
|
389 |
+
def split_between_processes(self, inputs: list | tuple | dict | torch.Tensor, apply_padding: bool = False):
|
390 |
+
"""
|
391 |
+
Splits `inputs` across `self.num_processes` quickly, so the result can then be used on each process. Useful when doing
|
392 |
+
distributed inference, such as with different prompts.
|
393 |
+
|
394 |
+
Note that when using a `dict`, all keys need to have the same number of elements.
|
395 |
+
|
396 |
+
Args:
|
397 |
+
inputs (`list`, `tuple`, `torch.Tensor`, `dict` of `list`/`tuple`/`torch.Tensor`, or `datasets.Dataset`):
|
398 |
+
The input to split between processes.
|
399 |
+
apply_padding (`bool`, `optional`, defaults to `False`):
|
400 |
+
Whether to apply padding by repeating the last element of the input so that all processes have the same
|
401 |
+
number of elements. Useful when trying to perform actions such as `gather()` on the outputs or passing
|
402 |
+
in fewer inputs than there are processes. If so, just remember to drop the padded elements afterwards.
|
403 |
+
|
404 |
+
|
405 |
+
Example:
|
406 |
+
|
407 |
+
```python
|
408 |
+
# Assume there are two processes
|
409 |
+
from accelerate import PartialState
|
410 |
+
|
411 |
+
state = PartialState()
|
412 |
+
with state.split_between_processes(["A", "B", "C"]) as inputs:
|
413 |
+
print(inputs)
|
414 |
+
# Process 0
|
415 |
+
["A", "B"]
|
416 |
+
# Process 1
|
417 |
+
["C"]
|
418 |
+
|
419 |
+
with state.split_between_processes(["A", "B", "C"], apply_padding=True) as inputs:
|
420 |
+
print(inputs)
|
421 |
+
# Process 0
|
422 |
+
["A", "B"]
|
423 |
+
# Process 1
|
424 |
+
["C", "C"]
|
425 |
+
```
|
426 |
+
"""
|
427 |
+
if self.num_processes == 1:
|
428 |
+
yield inputs
|
429 |
+
return
|
430 |
+
length = len(inputs)
|
431 |
+
# Nested dictionary of any types
|
432 |
+
if isinstance(inputs, dict):
|
433 |
+
length = len(inputs[list(inputs.keys())[0]])
|
434 |
+
if not all(len(v) == length for v in inputs.values()):
|
435 |
+
raise ValueError("All values in the dictionary must have the same length")
|
436 |
+
num_samples_per_process, num_extras = divmod(length, self.num_processes)
|
437 |
+
start_index = self.process_index * num_samples_per_process + min(self.process_index, num_extras)
|
438 |
+
end_index = start_index + num_samples_per_process + (1 if self.process_index < num_extras else 0)
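# Worked example: with 5 inputs on 2 processes, divmod(5, 2) == (2, 1), so
# process 0 takes indices [0, 3) and process 1 takes [3, 5); the first
# `num_extras` processes each receive one extra element.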
|
439 |
+
|
440 |
+
def _split_values(inputs, start_index, end_index):
|
441 |
+
if isinstance(inputs, (list, tuple, torch.Tensor)):
|
442 |
+
if start_index >= len(inputs):
|
443 |
+
result = inputs[-1:]
|
444 |
+
else:
|
445 |
+
result = inputs[start_index:end_index]
|
446 |
+
if apply_padding:
|
447 |
+
if isinstance(result, torch.Tensor):
|
448 |
+
from accelerate.utils import pad_across_processes, send_to_device
|
449 |
+
|
450 |
+
# The tensor needs to be on the device before we can pad it
|
451 |
+
tensorized_result = send_to_device(result, self.device)
|
452 |
+
result = pad_across_processes(tensorized_result, pad_index=inputs[-1])
|
453 |
+
else:
|
454 |
+
result += [result[-1]] * (num_samples_per_process + 1 - len(result))
|
455 |
+
return result
|
456 |
+
elif isinstance(inputs, dict):
|
457 |
+
for key in inputs.keys():
|
458 |
+
inputs[key] = _split_values(inputs[key], start_index, end_index)
|
459 |
+
return inputs
|
460 |
+
else:
|
461 |
+
if is_datasets_available():
|
462 |
+
from datasets import Dataset
|
463 |
+
|
464 |
+
if isinstance(inputs, Dataset):
|
465 |
+
if start_index >= len(inputs):
|
466 |
+
start_index = len(inputs) - 1
|
467 |
+
if end_index > len(inputs):
|
468 |
+
end_index = len(inputs)
|
469 |
+
result_idcs = list(range(start_index, end_index))
|
470 |
+
if apply_padding:
|
471 |
+
result_idcs += [end_index - 1] * (num_samples_per_process + 1 - len(result_idcs))
|
472 |
+
return inputs.select(result_idcs)
|
473 |
+
return inputs
|
474 |
+
|
475 |
+
yield _split_values(inputs, start_index, end_index)
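A further, hypothetical illustration of the dictionary case (assuming two processes, as in the docstring above): each value is sliced with the same indices, so every process receives a dict of equally sized shards.

```python
state = PartialState()
data = {"prompt": ["a", "b", "c"], "id": [1, 2, 3]}
with state.split_between_processes(data) as shard:
    print(shard)
# Process 0: {'prompt': ['a', 'b'], 'id': [1, 2]}
# Process 1: {'prompt': ['c'], 'id': [3]}
```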
|
476 |
+
|
477 |
+
@contextmanager
|
478 |
+
def main_process_first(self):
|
479 |
+
"""
|
480 |
+
Lets the main process go first inside a with block.
|
481 |
+
|
482 |
+
The other processes will enter the with block after the main process exits.
|
483 |
+
|
484 |
+
Example:
|
485 |
+
|
486 |
+
```python
|
487 |
+
>>> from accelerate import Accelerator
|
488 |
+
|
489 |
+
>>> accelerator = Accelerator()
|
490 |
+
>>> with accelerator.main_process_first():
|
491 |
+
... # This will be printed first by process 0 then in a seemingly
|
492 |
+
... # random order by the other processes.
|
493 |
+
... print(f"This will be printed by process {accelerator.process_index}")
|
494 |
+
```
|
495 |
+
"""
|
496 |
+
yield from self._goes_first(self.is_main_process)
|
497 |
+
|
498 |
+
@contextmanager
|
499 |
+
def local_main_process_first(self):
|
500 |
+
"""
|
501 |
+
Lets the local main process go first inside a with block.
|
502 |
+
|
503 |
+
The other processes will enter the with block after the main process exits.
|
504 |
+
|
505 |
+
Example:
|
506 |
+
|
507 |
+
```python
|
508 |
+
>>> from accelerate.state import PartialState
|
509 |
+
|
510 |
+
>>> state = PartialState()
|
511 |
+
>>> with state.local_main_process_first():
|
512 |
+
... # This will be printed first by local process 0 then in a seemingly
|
513 |
+
... # random order by the other processes.
|
514 |
+
... print(f"This will be printed by process {state.local_process_index}")
|
515 |
+
```
|
516 |
+
"""
|
517 |
+
yield from self._goes_first(self.is_local_main_process)
|
518 |
+
|
519 |
+
def on_main_process(self, function: Callable[..., Any] = None):
|
520 |
+
"""
|
521 |
+
Decorator that only runs the decorated function on the main process.
|
522 |
+
|
523 |
+
Args:
|
524 |
+
function (`Callable`): The function to decorate.
|
525 |
+
|
526 |
+
Example:
|
527 |
+
|
528 |
+
```python
|
529 |
+
>>> from accelerate.state import PartialState
|
530 |
+
|
531 |
+
>>> state = PartialState()
|
532 |
+
|
533 |
+
|
534 |
+
>>> @state.on_main_process
|
535 |
+
... def print_something():
|
536 |
+
... print("This will be printed by process 0 only.")
|
537 |
+
|
538 |
+
|
539 |
+
>>> print_something()
|
540 |
+
"This will be printed by process 0 only"
|
541 |
+
```
|
542 |
+
"""
|
543 |
+
if not self.initialized:
|
544 |
+
raise ValueError("The `PartialState` or `Accelerator` must be initialized before calling this function.")
|
545 |
+
if self.is_main_process or not self.use_distributed:
|
546 |
+
return function
|
547 |
+
return do_nothing
|
548 |
+
|
549 |
+
def on_local_main_process(self, function: Callable[..., Any] = None):
|
550 |
+
"""
|
551 |
+
Decorator that only runs the decorated function on the local main process.
|
552 |
+
|
553 |
+
Args:
|
554 |
+
function (`Callable`): The function to decorate.
|
555 |
+
|
556 |
+
Example:
|
557 |
+
```python
|
558 |
+
# Assume we have 2 servers with 4 processes each.
|
559 |
+
from accelerate.state import PartialState
|
560 |
+
|
561 |
+
state = PartialState()
|
562 |
+
|
563 |
+
|
564 |
+
@state.on_local_main_process
|
565 |
+
def print_something():
|
566 |
+
print("This will be printed by process 0 only on each server.")
|
567 |
+
|
568 |
+
|
569 |
+
print_something()
|
570 |
+
# On server 1:
|
571 |
+
"This will be printed by process 0 only"
|
572 |
+
# On server 2:
|
573 |
+
"This will be printed by process 0 only"
|
574 |
+
```
|
575 |
+
"""
|
576 |
+
if self.is_local_main_process or not self.use_distributed:
|
577 |
+
return function
|
578 |
+
return do_nothing
|
579 |
+
|
580 |
+
def on_last_process(self, function: Callable[..., Any]):
|
581 |
+
"""
|
582 |
+
Decorator that only runs the decorated function on the last process.
|
583 |
+
|
584 |
+
Args:
|
585 |
+
function (`Callable`): The function to decorate.
|
586 |
+
|
587 |
+
Example:
|
588 |
+
```python
|
589 |
+
# Assume we have 4 processes.
|
590 |
+
from accelerate.state import PartialState
|
591 |
+
|
592 |
+
state = PartialState()
|
593 |
+
|
594 |
+
|
595 |
+
@state.on_last_process
|
596 |
+
def print_something():
|
597 |
+
print(f"Printed on process {state.process_index}")
|
598 |
+
|
599 |
+
|
600 |
+
print_something()
|
601 |
+
"Printed on process 3"
|
602 |
+
```
|
603 |
+
"""
|
604 |
+
if self.is_last_process or not self.use_distributed:
|
605 |
+
return function
|
606 |
+
return do_nothing
|
607 |
+
|
608 |
+
def on_process(self, function: Callable[..., Any] = None, process_index: int = None):
|
609 |
+
"""
|
610 |
+
Decorator that only runs the decorated function on the process with the given index.
|
611 |
+
|
612 |
+
Args:
|
613 |
+
function (`Callable`, `optional`):
|
614 |
+
The function to decorate.
|
615 |
+
process_index (`int`, `optional`):
|
616 |
+
The index of the process on which to run the function.
|
617 |
+
|
618 |
+
Example:
|
619 |
+
```python
|
620 |
+
# Assume we have 4 processes.
|
621 |
+
from accelerate.state import PartialState
|
622 |
+
|
623 |
+
state = PartialState()
|
624 |
+
|
625 |
+
|
626 |
+
@state.on_process(process_index=2)
|
627 |
+
def print_something():
|
628 |
+
print(f"Printed on process {state.process_index}")
|
629 |
+
|
630 |
+
|
631 |
+
print_something()
|
632 |
+
"Printed on process 2"
|
633 |
+
```
|
634 |
+
"""
|
635 |
+
if function is None:
|
636 |
+
return partial(self.on_process, process_index=process_index)
|
637 |
+
if (self.process_index == process_index) or (not self.use_distributed):
|
638 |
+
return function
|
639 |
+
return do_nothing
|
640 |
+
|
641 |
+
def on_local_process(self, function: Callable[..., Any] = None, local_process_index: int = None):
|
642 |
+
"""
|
643 |
+
Decorator that only runs the decorated function on the process with the given index on the current node.
|
644 |
+
|
645 |
+
Args:
|
646 |
+
function (`Callable`, *optional*):
|
647 |
+
The function to decorate.
|
648 |
+
local_process_index (`int`, *optional*):
|
649 |
+
The index of the local process on which to run the function.
|
650 |
+
|
651 |
+
Example:
|
652 |
+
```python
|
653 |
+
# Assume we have 2 servers with 4 processes each.
|
654 |
+
from accelerate import Accelerator
|
655 |
+
|
656 |
+
accelerator = Accelerator()
|
657 |
+
|
658 |
+
|
659 |
+
@accelerator.on_local_process(local_process_index=2)
|
660 |
+
def print_something():
|
661 |
+
print(f"Printed on process {accelerator.local_process_index}")
|
662 |
+
|
663 |
+
|
664 |
+
print_something()
|
665 |
+
# On server 1:
|
666 |
+
"Printed on process 2"
|
667 |
+
# On server 2:
|
668 |
+
"Printed on process 2"
|
669 |
+
```
|
670 |
+
"""
|
671 |
+
if function is None:
|
672 |
+
return partial(self.on_local_process, local_process_index=local_process_index)
|
673 |
+
if (self.local_process_index == local_process_index) or (not self.use_distributed):
|
674 |
+
return function
|
675 |
+
return do_nothing
|
676 |
+
|
677 |
+
def print(self, *args, **kwargs):
|
678 |
+
if self.is_local_main_process:
|
679 |
+
print(*args, **kwargs)
|
680 |
+
|
681 |
+
@property
|
682 |
+
def default_device(self) -> torch.device:
|
683 |
+
"""
|
684 |
+
Returns the default device which is:
|
685 |
+
- MPS if `torch.backends.mps.is_available()` and `torch.backends.mps.is_built()` both return True.
|
686 |
+
- CUDA if `torch.cuda.is_available()`
|
687 |
+
- MLU if `is_mlu_available()`
|
688 |
+
- MUSA if `is_musa_available()`
|
689 |
+
- NPU if `is_npu_available()`
|
690 |
+
- CPU otherwise
|
691 |
+
"""
|
692 |
+
if is_mps_available():
|
693 |
+
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
694 |
+
return torch.device("mps")
|
695 |
+
elif is_mlu_available():
|
696 |
+
return torch.device("mlu")
|
697 |
+
elif is_musa_available():
|
698 |
+
return torch.device("musa")
|
699 |
+
# NPU should be checked before CUDA when using `transfer_to_npu`
|
700 |
+
# See issue #3020: https://github.com/huggingface/accelerate/issues/3020
|
701 |
+
elif is_npu_available():
|
702 |
+
return torch.device("npu")
|
703 |
+
elif torch.cuda.is_available():
|
704 |
+
return torch.device("cuda")
|
705 |
+
elif is_xpu_available():
|
706 |
+
return torch.device("xpu")
|
707 |
+
else:
|
708 |
+
return torch.device("cpu")
|
709 |
+
|
710 |
+
def _prepare_backend(
|
711 |
+
self, cpu: bool = False, sagemaker_dp=False, backend: str = None
|
712 |
+
) -> tuple[str, DistributedType]:
|
713 |
+
"Prepares any imports needed before initializing the distributed backend and sets `self.backend` properly"
|
714 |
+
distributed_type = None
|
715 |
+
if sagemaker_dp:
|
716 |
+
import smdistributed.dataparallel.torch.torch_smddp # noqa
|
717 |
+
|
718 |
+
backend = "smddp"
|
719 |
+
distributed_type = DistributedType.MULTI_GPU
|
720 |
+
elif is_torch_xla_available():
|
721 |
+
backend = "xla"
|
722 |
+
distributed_type = DistributedType.XLA
|
723 |
+
elif int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu:
|
724 |
+
if is_mlu_available():
|
725 |
+
backend = "cncl"
|
726 |
+
distributed_type = DistributedType.MULTI_MLU
|
727 |
+
elif is_musa_available():
|
728 |
+
backend = "mccl"
|
729 |
+
distributed_type = DistributedType.MULTI_MUSA
|
730 |
+
# NPU should be checked before CUDA when using `transfer_to_npu`
|
731 |
+
# See issue #3020: https://github.com/huggingface/accelerate/issues/3020
|
732 |
+
elif is_npu_available():
|
733 |
+
backend = "hccl"
|
734 |
+
distributed_type = DistributedType.MULTI_NPU
|
735 |
+
elif torch.cuda.is_available():
|
736 |
+
if backend is None:
|
737 |
+
backend = "nccl"
|
738 |
+
distributed_type = DistributedType.MULTI_GPU
|
739 |
+
|
740 |
+
if distributed_type is None and (
|
741 |
+
int(os.environ.get("LOCAL_RANK", -1)) != -1
|
742 |
+
or get_int_from_env(["PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE", "WORLD_SIZE"], 1) > 1
|
743 |
+
):
|
744 |
+
if not cpu and is_xpu_available():
|
745 |
+
distributed_type = DistributedType.MULTI_XPU
|
746 |
+
else:
|
747 |
+
distributed_type = DistributedType.MULTI_CPU
|
748 |
+
|
749 |
+
if (
|
750 |
+
backend in (None, "ccl")
|
751 |
+
and is_ccl_available()
|
752 |
+
and (get_int_from_env(["CCL_WORKER_COUNT"], 0) > 0 or distributed_type == DistributedType.MULTI_XPU)
|
753 |
+
):
|
754 |
+
if get_ccl_version() >= "1.12":
|
755 |
+
import oneccl_bindings_for_pytorch # noqa: F401
|
756 |
+
else:
|
757 |
+
import torch_ccl # noqa: F401
|
758 |
+
|
759 |
+
backend = "ccl"
|
760 |
+
elif backend in (None, "mpi") and torch.distributed.is_mpi_available():
|
761 |
+
backend = "mpi"
|
762 |
+
else:
|
763 |
+
backend = "gloo"
|
764 |
+
if distributed_type is None:
|
765 |
+
distributed_type = DistributedType.NO
|
766 |
+
|
767 |
+
return backend, distributed_type
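# Summary of the mapping above: SageMaker DP -> "smddp", torch_xla -> "xla",
# MLU -> "cncl", MUSA -> "mccl", NPU -> "hccl", CUDA -> "nccl"; multi-CPU/XPU
# runs fall back to "ccl", "mpi", or "gloo" depending on what is available.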
|
768 |
+
|
769 |
+
def set_device(self):
|
770 |
+
"""
|
771 |
+
Sets the device in `self.device` to the current distributed environment.
|
772 |
+
"""
|
773 |
+
if self.device is not None:
|
774 |
+
return
|
775 |
+
if self.distributed_type == DistributedType.NO:
|
776 |
+
self.device = torch.device("cpu") if self._cpu else self.default_device
|
777 |
+
return
|
778 |
+
device = str(self.distributed_type).split(".")[-1].replace("MULTI_", "").lower()
|
779 |
+
if device not in ("cpu", "gpu", "mlu", "musa", "npu", "xpu", "xla"):
|
780 |
+
raise ValueError(
|
781 |
+
f"Can't set device for {self.distributed_type} ({device}), verify we should be calling `_set_device()` for it!"
|
782 |
+
)
|
783 |
+
if device == "xla":
|
784 |
+
self.device = xm.xla_device()
|
785 |
+
else:
|
786 |
+
if device == "gpu":
|
787 |
+
device = "cuda"
|
788 |
+
device_module = getattr(torch, device)
|
789 |
+
device_index = self.local_process_index % device_module.device_count()
|
790 |
+
self.device = torch.device(device, device_index)
|
791 |
+
device_module.set_device(self.device)
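# Example: on a node with 8 visible GPUs, local process index 3 maps to
# `torch.device("cuda", 3)`; the modulo keeps the index in range if more local
# processes than devices are launched.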
|
792 |
+
|
793 |
+
def destroy_process_group(self, group=None):
|
794 |
+
"""
|
795 |
+
Destroys the process group. If one is not specified, the default process group is destroyed.
|
796 |
+
"""
|
797 |
+
if self.fork_launched and group is None:
|
798 |
+
return
|
799 |
+
# needed when using torch.distributed.init_process_group
|
800 |
+
if torch.distributed.is_initialized():
|
801 |
+
torch.distributed.destroy_process_group(group)
|
802 |
+
|
803 |
+
def __getattr__(self, name: str):
|
804 |
+
# By this point we know that no attributes of `self` contain `name`,
|
805 |
+
# so we just modify the error message
|
806 |
+
if name in self._known_attrs:
|
807 |
+
raise AttributeError(
|
808 |
+
f"`PartialState` object has no attribute `{name}`. "
|
809 |
+
"This happens if `PartialState._reset_state()` was called and "
|
810 |
+
"an `Accelerator` or `PartialState` was not reinitialized."
|
811 |
+
)
|
812 |
+
# Raise a typical AttributeError
|
813 |
+
raise AttributeError(f"'PartialState' object has no attribute '{name}'")
|
814 |
+
|
815 |
+
|
816 |
+
class AcceleratorState:
|
817 |
+
"""
|
818 |
+
Singleton class that has information about the current training environment.
|
819 |
+
|
820 |
+
**Available attributes:**
|
821 |
+
|
822 |
+
- **device** (`torch.device`) -- The device to use.
|
823 |
+
- **distributed_type** ([`~accelerate.state.DistributedType`]) -- The type of distributed environment currently
|
824 |
+
in use.
|
825 |
+
- **initialized** (`bool`) -- Whether or not the `AcceleratorState` has been initialized from `Accelerator`.
|
826 |
+
- **local_process_index** (`int`) -- The index of the current process on the current server.
|
827 |
+
- **mixed_precision** (`str`) -- Whether or not the current script will use mixed precision, and if so the type
|
828 |
+
of mixed precision being performed. (Choose from 'no', 'fp16', 'bf16', or 'fp8').
|
829 |
+
- **num_processes** (`int`) -- The number of processes currently launched in parallel.
|
830 |
+
- **process_index** (`int`) -- The index of the current process.
|
831 |
+
- **is_last_process** (`bool`) -- Whether or not the current process is the last one.
|
832 |
+
- **is_main_process** (`bool`) -- Whether or not the current process is the main one.
|
833 |
+
- **is_local_main_process** (`bool`) -- Whether or not the current process is the main one on the local node.
|
834 |
+
- **debug** (`bool`) -- Whether or not the current script is being run in debug mode.
|
835 |
+
"""
|
836 |
+
|
837 |
+
_shared_state = SharedDict()
|
838 |
+
_known_attrs = PartialState._known_attrs + [
|
839 |
+
"deepspeed_plugin",
|
840 |
+
"use_ipex",
|
841 |
+
"fsdp_plugin",
|
842 |
+
"megatron_lm_plugin",
|
843 |
+
"dynamo_plugin",
|
844 |
+
]
|
845 |
+
|
846 |
+
def __init__(
|
847 |
+
self,
|
848 |
+
mixed_precision: str = None,
|
849 |
+
cpu: bool = False,
|
850 |
+
dynamo_plugin=None,
|
851 |
+
deepspeed_plugin=None,
|
852 |
+
fsdp_plugin=None,
|
853 |
+
megatron_lm_plugin=None,
|
854 |
+
_from_accelerator: bool = False,
|
855 |
+
**kwargs,
|
856 |
+
):
|
857 |
+
self.__dict__ = self._shared_state
|
858 |
+
if parse_flag_from_env("ACCELERATE_USE_CPU"):
|
859 |
+
cpu = True
|
860 |
+
if PartialState._shared_state == {}:
|
861 |
+
PartialState(cpu, **kwargs)
|
862 |
+
self.__dict__.update(PartialState._shared_state)
|
863 |
+
self._check_initialized(mixed_precision, cpu)
|
864 |
+
if not self.initialized:
|
865 |
+
self.deepspeed_plugins = None
|
866 |
+
self.use_ipex = None
|
867 |
+
mixed_precision = (
|
868 |
+
parse_choice_from_env("ACCELERATE_MIXED_PRECISION", "no")
|
869 |
+
if mixed_precision is None
|
870 |
+
else mixed_precision.lower()
|
871 |
+
)
|
872 |
+
if mixed_precision == "fp8":
|
873 |
+
if not is_fp8_available():
|
874 |
+
raise ValueError(
|
875 |
+
"Using `fp8` precision requires `transformer_engine` or `MS-AMP` to be installed."
|
876 |
+
)
|
877 |
+
elif not check_fp8_capability():
|
878 |
+
logger.warning(
|
879 |
+
f"The current device has compute capability of {torch.cuda.get_device_capability()} which is "
|
880 |
+
"insufficient for FP8 mixed precision training (requires a GPU Hopper/Ada Lovelace "
|
881 |
+
"or higher, compute capability of 8.9 or higher). Will use FP16 instead."
|
882 |
+
)
|
883 |
+
mixed_precision = "fp16"
|
884 |
+
|
885 |
+
self.dynamo_plugin = dynamo_plugin
|
886 |
+
if not _from_accelerator:
|
887 |
+
raise ValueError(
|
888 |
+
"Please make sure to properly initialize your accelerator via `accelerator = Accelerator()` "
|
889 |
+
"before using any functionality from the `accelerate` library."
|
890 |
+
)
|
891 |
+
# deepspeed handles mixed_precision using deepspeed_config
|
892 |
+
self._mixed_precision = "no" if self.distributed_type == DistributedType.DEEPSPEED else mixed_precision
|
893 |
+
if self.distributed_type == DistributedType.XLA and is_torch_xla_available(check_is_tpu=True):
|
894 |
+
if mixed_precision == "bf16":
|
895 |
+
if os.environ.get("ACCELERATE_DOWNCAST_BF16"):
|
896 |
+
os.environ["XLA_USE_BF16"] = str(0)
|
897 |
+
os.environ["XLA_DOWNCAST_BF16"] = str(1)
|
898 |
+
self.downcast_bfloat = True
|
899 |
+
else:
|
900 |
+
os.environ["XLA_USE_BF16"] = str(1)
|
901 |
+
os.environ["XLA_DOWNCAST_BF16"] = str(0)
|
902 |
+
self.downcast_bfloat = False
|
903 |
+
elif os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true" and not cpu:
|
904 |
+
self.deepspeed_plugins = deepspeed_plugin
|
905 |
+
self.distributed_type = DistributedType.DEEPSPEED
|
906 |
+
elif self.distributed_type in [
|
907 |
+
DistributedType.MULTI_GPU,
|
908 |
+
DistributedType.MULTI_MLU,
|
909 |
+
DistributedType.MULTI_MUSA,
|
910 |
+
DistributedType.MULTI_NPU,
|
911 |
+
DistributedType.MULTI_XPU,
|
912 |
+
]:
|
913 |
+
if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" or fsdp_plugin is not None:
|
914 |
+
self.distributed_type = DistributedType.FSDP
|
915 |
+
if self._mixed_precision != "no":
|
916 |
+
fsdp_plugin.set_mixed_precision(self._mixed_precision)
|
917 |
+
self.fsdp_plugin = fsdp_plugin
|
918 |
+
if os.environ.get("ACCELERATE_USE_MEGATRON_LM", "false") == "true" and self.distributed_type not in [
|
919 |
+
DistributedType.MULTI_XPU,
|
920 |
+
]:
|
921 |
+
self.distributed_type = DistributedType.MEGATRON_LM
|
922 |
+
megatron_lm_plugin.set_mixed_precision(self._mixed_precision)
|
923 |
+
self.megatron_lm_plugin = megatron_lm_plugin
|
924 |
+
elif self.distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_XPU, DistributedType.NO]:
|
925 |
+
if is_ipex_available():
|
926 |
+
# check if user disables it explicitly
|
927 |
+
self.use_ipex = parse_flag_from_env("ACCELERATE_USE_IPEX", default=True)
|
928 |
+
else:
|
929 |
+
self.use_ipex = False
|
930 |
+
if (
|
931 |
+
self.dynamo_plugin.backend != DynamoBackend.NO
|
932 |
+
and self._mixed_precision == "no"
|
933 |
+
and self.device.type == "cuda"
|
934 |
+
):
|
935 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
936 |
+
if (
|
937 |
+
self.dynamo_plugin.backend != DynamoBackend.NO
|
938 |
+
and self._mixed_precision == "no"
|
939 |
+
and self.device.type == "musa"
|
940 |
+
):
|
941 |
+
torch.backends.musa.matmul.allow_tf32 = True
|
942 |
+
PartialState._shared_state["distributed_type"] = self.distributed_type
|
943 |
+
|
944 |
+
@property
|
945 |
+
def initialized(self) -> bool:
|
946 |
+
return self._shared_state != PartialState._shared_state
|
947 |
+
|
948 |
+
def __repr__(self):
|
949 |
+
repr = PartialState().__repr__() + f"\nMixed precision type: {self.mixed_precision}\n"
|
950 |
+
if self.distributed_type == DistributedType.DEEPSPEED:
|
951 |
+
repr += f"ds_config: {self.deepspeed_plugin.deepspeed_config}\n"
|
952 |
+
return repr
|
953 |
+
|
954 |
+
def _check_initialized(self, mixed_precision=None, cpu=None):
|
955 |
+
"Checks if a modification is trying to be made and the `AcceleratorState` has already been initialized"
|
956 |
+
if self.initialized:
|
957 |
+
err = "AcceleratorState has already been initialized and cannot be changed, restart your runtime completely and pass `{flag}` to `Accelerator()`."
|
958 |
+
if cpu and self.device.type != "cpu":
|
959 |
+
raise ValueError(err.format(flag="cpu=True"))
|
960 |
+
if (
|
961 |
+
mixed_precision is not None
|
962 |
+
and mixed_precision != self._mixed_precision
|
963 |
+
and self.distributed_type != DistributedType.DEEPSPEED
|
964 |
+
):
|
965 |
+
raise ValueError(err.format(flag=f"mixed_precision='{mixed_precision}'"))
|
966 |
+
|
967 |
+
@property
|
968 |
+
def mixed_precision(self):
|
969 |
+
if self.distributed_type == DistributedType.DEEPSPEED:
|
970 |
+
config = self.deepspeed_plugin.deepspeed_config
|
971 |
+
if config.get("fp16", {}).get("enabled", False):
|
972 |
+
mixed_precision = "fp16"
|
973 |
+
elif config.get("bf16", {}).get("enabled", False):
|
974 |
+
mixed_precision = "bf16"
|
975 |
+
else:
|
976 |
+
mixed_precision = "no"
|
977 |
+
else:
|
978 |
+
mixed_precision = self._mixed_precision
|
979 |
+
return mixed_precision
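# Illustrative: with a DeepSpeed config containing {"fp16": {"enabled": True}},
# this property reports "fp16"; without DeepSpeed it simply echoes the value
# stored in `_mixed_precision`.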
|
980 |
+
|
981 |
+
@staticmethod
|
982 |
+
def _reset_state(reset_partial_state: bool = False):
|
983 |
+
"Resets `_shared_state`, is used internally and should not be called"
|
984 |
+
AcceleratorState._shared_state.clear()
|
985 |
+
if reset_partial_state:
|
986 |
+
PartialState._reset_state()
|
987 |
+
|
988 |
+
def destroy_process_group(self, group=None):
|
989 |
+
"""
|
990 |
+
Destroys the process group. If one is not specified, the default process group is destroyed.
|
991 |
+
|
992 |
+
If `self.fork_launched` is `True` and `group` is `None`, nothing happens.
|
993 |
+
"""
|
994 |
+
PartialState().destroy_process_group(group)
|
995 |
+
|
996 |
+
@property
|
997 |
+
def fork_launched(self):
|
998 |
+
return PartialState().fork_launched
|
999 |
+
|
1000 |
+
@property
|
1001 |
+
def use_distributed(self):
|
1002 |
+
"""
|
1003 |
+
Whether the Accelerator is configured for distributed training
|
1004 |
+
"""
|
1005 |
+
return PartialState().use_distributed
|
1006 |
+
|
1007 |
+
@property
|
1008 |
+
def is_last_process(self) -> bool:
|
1009 |
+
"Returns whether the current process is the last one"
|
1010 |
+
return PartialState().is_last_process
|
1011 |
+
|
1012 |
+
@property
|
1013 |
+
def is_main_process(self) -> bool:
|
1014 |
+
"Returns whether the current process is the main process"
|
1015 |
+
return PartialState().is_main_process
|
1016 |
+
|
1017 |
+
@property
|
1018 |
+
def is_local_main_process(self) -> bool:
|
1019 |
+
"Returns whether the current process is the main process on the local node"
|
1020 |
+
return PartialState().is_local_main_process
|
1021 |
+
|
1022 |
+
def wait_for_everyone(self):
|
1023 |
+
PartialState().wait_for_everyone()
|
1024 |
+
|
1025 |
+
@contextmanager
|
1026 |
+
def split_between_processes(self, inputs: list | tuple | dict | torch.Tensor, apply_padding: bool = False):
|
1027 |
+
"""
|
1028 |
+
Splits `inputs` across `self.num_processes` quickly, so the result can then be used on each process. Useful when doing
|
1029 |
+
distributed inference, such as with different prompts.
|
1030 |
+
|
1031 |
+
Note that when using a `dict`, all keys need to have the same number of elements.
|
1032 |
+
|
1033 |
+
Args:
|
1034 |
+
inputs (`list`, `tuple`, `torch.Tensor`, or `dict` of `list`/`tuple`/`torch.Tensor`):
|
1035 |
+
The input to split between processes.
|
1036 |
+
apply_padding (`bool`, `optional`, defaults to `False`):
|
1037 |
+
Whether to apply padding by repeating the last element of the input so that all processes have the same
|
1038 |
+
number of elements. Useful when trying to perform actions such as `gather()` on the outputs or passing
|
1039 |
+
in fewer inputs than there are processes. If so, just remember to drop the padded elements afterwards.
|
1040 |
+
|
1041 |
+
|
1042 |
+
Example:
|
1043 |
+
|
1044 |
+
```python
|
1045 |
+
# Assume there are two processes
|
1046 |
+
from accelerate.state import AcceleratorState
|
1047 |
+
|
1048 |
+
state = AcceleratorState()
|
1049 |
+
with state.split_between_processes(["A", "B", "C"]) as inputs:
|
1050 |
+
print(inputs)
|
1051 |
+
# Process 0
|
1052 |
+
["A", "B"]
|
1053 |
+
# Process 1
|
1054 |
+
["C"]
|
1055 |
+
|
1056 |
+
with state.split_between_processes(["A", "B", "C"], apply_padding=True) as inputs:
|
1057 |
+
print(inputs)
|
1058 |
+
# Process 0
|
1059 |
+
["A", "B"]
|
1060 |
+
# Process 1
|
1061 |
+
["C", "C"]
|
1062 |
+
```
|
1063 |
+
"""
|
1064 |
+
with PartialState().split_between_processes(inputs, apply_padding=apply_padding) as inputs:
|
1065 |
+
yield inputs
|
1066 |
+
|
1067 |
+
@contextmanager
|
1068 |
+
def main_process_first(self):
|
1069 |
+
"""
|
1070 |
+
Lets the main process go first inside a with block.
|
1071 |
+
|
1072 |
+
The other processes will enter the with block after the main process exits.
|
1073 |
+
"""
|
1074 |
+
with PartialState().main_process_first():
|
1075 |
+
yield
|
1076 |
+
|
1077 |
+
@contextmanager
|
1078 |
+
def local_main_process_first(self):
|
1079 |
+
"""
|
1080 |
+
Lets the local main process go first inside a with block.
|
1081 |
+
|
1082 |
+
The other processes will enter the with block after the main process exits.
|
1083 |
+
"""
|
1084 |
+
with PartialState().local_main_process_first():
|
1085 |
+
yield
|
1086 |
+
|
1087 |
+
@property
|
1088 |
+
def deepspeed_plugin(self):
|
1089 |
+
"""
|
1090 |
+
Returns the currently active DeepSpeedPlugin.
|
1091 |
+
|
1092 |
+
If not using deepspeed, returns `None`.
|
1093 |
+
"""
|
1094 |
+
# To maintain original behavior, return None if not using deepspeed.
|
1095 |
+
if self.distributed_type != DistributedType.DEEPSPEED:
|
1096 |
+
return None
|
1097 |
+
from accelerate.utils.deepspeed import get_active_deepspeed_plugin
|
1098 |
+
|
1099 |
+
return get_active_deepspeed_plugin(self)
|
1100 |
+
|
1101 |
+
@deepspeed_required
|
1102 |
+
def get_deepspeed_plugin(self, name: str):
|
1103 |
+
"""
|
1104 |
+
Returns the DeepSpeedPlugin with the given plugin_key.
|
1105 |
+
"""
|
1106 |
+
return self.deepspeed_plugins[name]
|
1107 |
+
|
1108 |
+
@deepspeed_required
|
1109 |
+
def select_deepspeed_plugin(self, name: str = None):
|
1110 |
+
"""
|
1111 |
+
Activates the DeepSpeedPlugin with the given `name`, and will disable all other plugins.
|
1112 |
+
"""
|
1113 |
+
for key, plugin in self.deepspeed_plugins.items():
|
1114 |
+
if key != name:
|
1115 |
+
plugin._unselect()
|
1116 |
+
self.deepspeed_plugins[name].select(_from_accelerator_state=True)
|
1117 |
+
|
1118 |
+
def print(self, *args, **kwargs):
|
1119 |
+
PartialState().print(*args, **kwargs)
|
1120 |
+
|
1121 |
+
def __getattr__(self, name: str):
|
1122 |
+
# By this point we know that no attributes of `self` contain `name`,
|
1123 |
+
# so we just modify the error message
|
1124 |
+
if name in self._known_attrs:
|
1125 |
+
raise AttributeError(
|
1126 |
+
f"`AcceleratorState` object has no attribute `{name}`. "
|
1127 |
+
"This happens if `AcceleratorState._reset_state()` was called and "
|
1128 |
+
"an `Accelerator` or `PartialState` was not reinitialized."
|
1129 |
+
)
|
1130 |
+
# Raise a typical AttributeError
|
1131 |
+
raise AttributeError(f"'AcceleratorState' object has no attribute '{name}'")
|
1132 |
+
|
1133 |
+
|
1134 |
+
class GradientState:
|
1135 |
+
"""
|
1136 |
+
Singleton class that has information related to gradient synchronization for gradient accumulation
|
1137 |
+
|
1138 |
+
**Available attributes:**
|
1139 |
+
|
1140 |
+
- **end_of_dataloader** (`bool`) -- Whether we have reached the end of the current dataloader
|
1141 |
+
- **remainder** (`int`) -- The number of extra samples that were added from padding the dataloader
|
1142 |
+
- **sync_gradients** (`bool`) -- Whether the gradients should be synced across all devices
|
1143 |
+
- **active_dataloader** (`Optional[DataLoader]`) -- The dataloader that is currently being iterated over
|
1144 |
+
- **dataloader_references** (`List[Optional[DataLoader]]`) -- A list of references to the dataloaders that are
|
1145 |
+
being iterated over
|
1146 |
+
- **num_steps** (`int`) -- The number of steps to accumulate over
|
1147 |
+
- **adjust_scheduler** (`bool`) -- Whether the scheduler should be adjusted to account for the gradient
|
1148 |
+
accumulation
|
1149 |
+
- **sync_with_dataloader** (`bool`) -- Whether the gradients should be synced at the end of the dataloader
|
1150 |
+
iteration and the number of total steps reset
|
1151 |
+
- **is_xla_gradients_synced** (`bool`) -- Whether the XLA gradients have been synchronized. It is initialized
|
1152 |
+
as false. Once gradients have been reduced before the optimizer step, this flag is set to true. Subsequently,
|
1153 |
+
after each step, the flag is reset to false. FSDP will always synchronize the gradients, hence
|
1154 |
+
is_xla_gradients_synced is always true.
|
1155 |
+
"""
|
1156 |
+
|
1157 |
+
_shared_state = SharedDict()
|
1158 |
+
|
1159 |
+
def __init__(self, gradient_accumulation_plugin: Optional[GradientAccumulationPlugin] = None):
|
1160 |
+
self.__dict__ = self._shared_state
|
1161 |
+
if not self.initialized:
|
1162 |
+
self.sync_gradients = True
|
1163 |
+
self.active_dataloader = None
|
1164 |
+
self.dataloader_references = [None]
|
1165 |
+
self.plugin_kwargs = (
|
1166 |
+
gradient_accumulation_plugin.to_kwargs() if gradient_accumulation_plugin is not None else {}
|
1167 |
+
)
|
1168 |
+
self._is_xla_gradients_synced = False
|
1169 |
+
|
1170 |
+
# Plugin args are different and can be updated
|
1171 |
+
if gradient_accumulation_plugin is not None and self.plugin_kwargs != gradient_accumulation_plugin.to_kwargs():
|
1172 |
+
self.plugin_kwargs = gradient_accumulation_plugin.to_kwargs()
|
1173 |
+
|
1174 |
+
@property
|
1175 |
+
def num_steps(self) -> int:
|
1176 |
+
"Returns the number of steps to accumulate over"
|
1177 |
+
return self.plugin_kwargs.get("num_steps", 1)
|
1178 |
+
|
1179 |
+
@property
|
1180 |
+
def adjust_scheduler(self) -> bool:
|
1181 |
+
"Returns whether the scheduler should be adjusted"
|
1182 |
+
return self.plugin_kwargs.get("adjust_scheduler", False)
|
1183 |
+
|
1184 |
+
@property
|
1185 |
+
def sync_with_dataloader(self) -> bool:
|
1186 |
+
"Returns whether the gradients should be synced at the end of the dataloader iteration and the number of total steps reset"
|
1187 |
+
return self.plugin_kwargs.get("sync_with_dataloader", True)
|
1188 |
+
|
1189 |
+
@property
|
1190 |
+
def initialized(self) -> bool:
|
1191 |
+
"Returns whether the `GradientState` has been initialized"
|
1192 |
+
return GradientState._shared_state != {}
|
1193 |
+
|
1194 |
+
@property
|
1195 |
+
def end_of_dataloader(self) -> bool:
|
1196 |
+
"Returns whether we have reached the end of the current dataloader"
|
1197 |
+
if not self.in_dataloader:
|
1198 |
+
return False
|
1199 |
+
return self.active_dataloader.end_of_dataloader
|
1200 |
+
|
1201 |
+
@property
|
1202 |
+
def remainder(self) -> int:
|
1203 |
+
"Returns the number of extra samples that were added from padding the dataloader"
|
1204 |
+
if not self.in_dataloader:
|
1205 |
+
return -1
|
1206 |
+
return self.active_dataloader.remainder
|
1207 |
+
|
1208 |
+
def __repr__(self):
|
1209 |
+
return (
|
1210 |
+
f"Sync Gradients: {self.sync_gradients}\n"
|
1211 |
+
f"At end of current dataloader: {self.end_of_dataloader}\n"
|
1212 |
+
f"Extra samples added: {self.remainder}\n"
|
1213 |
+
f"Gradient accumulation plugin: {self.plugin_kwargs}\n"
|
1214 |
+
)
|
1215 |
+
|
1216 |
+
@property
|
1217 |
+
def is_xla_gradients_synced(self):
|
1218 |
+
"Returns the value of is_xla_gradients_synced. FSDP will always synchronize the gradients, hence is_xla_gradients_synced is always true."
|
1219 |
+
if parse_flag_from_env("ACCELERATE_USE_FSDP", default=False):
|
1220 |
+
return True
|
1221 |
+
return self._is_xla_gradients_synced
|
1222 |
+
|
1223 |
+
@is_xla_gradients_synced.setter
|
1224 |
+
def is_xla_gradients_synced(self, is_synced):
|
1225 |
+
"Set the _is_xla_gradients_synced attribute."
|
1226 |
+
self._is_xla_gradients_synced = is_synced
|
1227 |
+
|
1228 |
+
def _set_sync_gradients(self, sync_gradients):
|
1229 |
+
"Private function that sets whether gradients should be synchronized. Users should not have to call this."
|
1230 |
+
self.sync_gradients = sync_gradients
|
1231 |
+
# Allow grad-sync to automatically work on TPUs
|
1232 |
+
if (
|
1233 |
+
self.sync_gradients
|
1234 |
+
and is_torch_xla_available(check_is_tpu=True)
|
1235 |
+
and PartialState().distributed_type == DistributedType.XLA
|
1236 |
+
):
|
1237 |
+
xm.mark_step()
|
1238 |
+
|
1239 |
+
def _add_dataloader(self, dataloader):
|
1240 |
+
"Private function that adds a dataloader to `self.dataloader_references` and sets `in_dataloader` to `True`. Users should not have to call this."
|
1241 |
+
self.active_dataloader = dataloader
|
1242 |
+
self.dataloader_references.append(self.active_dataloader)
|
1243 |
+
|
1244 |
+
def _remove_dataloader(self, dataloader):
|
1245 |
+
"Private function that removes a dataloader from `self.dataloader_references` and sets `in_dataloader` to `False` if there are no more dataloaders. Users should not have to call this."
|
1246 |
+
self.dataloader_references.remove(dataloader)
|
1247 |
+
self.active_dataloader = self.dataloader_references[-1]
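# Note: `dataloader_references` acts as a stack, so when a nested dataloader is
# removed, `active_dataloader` falls back to the enclosing one (or `None` once
# only the initial sentinel remains).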
|
1248 |
+
|
1249 |
+
@property
|
1250 |
+
def in_dataloader(self) -> bool:
|
1251 |
+
"Returns whether the current process is in a dataloader"
|
1252 |
+
return self.active_dataloader is not None
|
1253 |
+
|
1254 |
+
@staticmethod
|
1255 |
+
def _reset_state():
|
1256 |
+
"Resets `_shared_state`, is used internally and should not be called"
|
1257 |
+
GradientState._shared_state.clear()
|
.venv/Lib/site-packages/accelerate/tracking.py
ADDED
@@ -0,0 +1,1023 @@
1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# Expectation:
|
16 |
+
# Provide a project dir name, then each type of logger gets stored in project/{`logging_dir`}
|
17 |
+
|
18 |
+
import json
|
19 |
+
import os
|
20 |
+
import time
|
21 |
+
from functools import wraps
|
22 |
+
from typing import Any, Dict, List, Optional, Union
|
23 |
+
|
24 |
+
import yaml
|
25 |
+
|
26 |
+
from .logging import get_logger
|
27 |
+
from .state import PartialState
|
28 |
+
from .utils import (
|
29 |
+
LoggerType,
|
30 |
+
is_aim_available,
|
31 |
+
is_clearml_available,
|
32 |
+
is_comet_ml_available,
|
33 |
+
is_dvclive_available,
|
34 |
+
is_mlflow_available,
|
35 |
+
is_tensorboard_available,
|
36 |
+
is_wandb_available,
|
37 |
+
listify,
|
38 |
+
)
|
39 |
+
|
40 |
+
|
41 |
+
_available_trackers = []
|
42 |
+
|
43 |
+
if is_tensorboard_available():
|
44 |
+
_available_trackers.append(LoggerType.TENSORBOARD)
|
45 |
+
|
46 |
+
if is_wandb_available():
|
47 |
+
_available_trackers.append(LoggerType.WANDB)
|
48 |
+
|
49 |
+
if is_comet_ml_available():
|
50 |
+
_available_trackers.append(LoggerType.COMETML)
|
51 |
+
|
52 |
+
if is_aim_available():
|
53 |
+
_available_trackers.append(LoggerType.AIM)
|
54 |
+
|
55 |
+
if is_mlflow_available():
|
56 |
+
_available_trackers.append(LoggerType.MLFLOW)
|
57 |
+
|
58 |
+
if is_clearml_available():
|
59 |
+
_available_trackers.append(LoggerType.CLEARML)
|
60 |
+
|
61 |
+
if is_dvclive_available():
|
62 |
+
_available_trackers.append(LoggerType.DVCLIVE)
|
63 |
+
|
64 |
+
logger = get_logger(__name__)
|
65 |
+
|
66 |
+
|
67 |
+
def on_main_process(function):
|
68 |
+
"""
|
69 |
+
Decorator to selectively run the decorated function on the main process only based on the `main_process_only`
|
70 |
+
attribute in a class.
|
71 |
+
|
72 |
+
Checks at function execution rather than initialization time, not triggering the initialization of the
|
73 |
+
`PartialState`.
|
74 |
+
"""
|
75 |
+
|
76 |
+
@wraps(function)
|
77 |
+
def execute_on_main_process(self, *args, **kwargs):
|
78 |
+
if getattr(self, "main_process_only", False):
|
79 |
+
return PartialState().on_main_process(function)(self, *args, **kwargs)
|
80 |
+
else:
|
81 |
+
return function(self, *args, **kwargs)
|
82 |
+
|
83 |
+
return execute_on_main_process
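# Descriptive note: trackers that set `main_process_only = False` opt out of this
# gating, so the wrapped method runs on every process instead of only the main one.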
|
84 |
+
|
85 |
+
|
86 |
+
def get_available_trackers():
|
87 |
+
"Returns a list of all supported available trackers in the system"
|
88 |
+
return _available_trackers
|
89 |
+
|
90 |
+
|
91 |
+
class GeneralTracker:
|
92 |
+
"""
|
93 |
+
A base Tracker class to be used for all logging integration implementations.
|
94 |
+
|
95 |
+
Each function should take in `**kwargs` that will automatically be passed in from a base dictionary provided to
|
96 |
+
[`Accelerator`].
|
97 |
+
|
98 |
+
Should implement `name`, `requires_logging_directory`, and `tracker` properties such that:
|
99 |
+
|
100 |
+
`name` (`str`): String representation of the tracker class name, such as "TensorBoard" `requires_logging_directory`
|
101 |
+
(`bool`): Whether the logger requires a directory to store its logs. `tracker` (`object`): Should return internal
|
102 |
+
tracking mechanism used by a tracker class (such as the `run` for wandb)
|
103 |
+
|
104 |
+
Implementations can also include a `main_process_only` (`bool`) attribute to toggle whether relevant logging, init, and
|
105 |
+
other functions should occur on the main process or across all processes (by default will use `True`)
|
106 |
+
"""
|
107 |
+
|
108 |
+
main_process_only = True
|
109 |
+
|
110 |
+
def __init__(self, _blank=False):
|
111 |
+
if not _blank:
|
112 |
+
err = ""
|
113 |
+
if not hasattr(self, "name"):
|
114 |
+
err += "`name`"
|
115 |
+
if not hasattr(self, "requires_logging_directory"):
|
116 |
+
if len(err) > 0:
|
117 |
+
err += ", "
|
118 |
+
err += "`requires_logging_directory`"
|
119 |
+
|
120 |
+
# as tracker is a @property that relies on post-init
|
121 |
+
if "tracker" not in dir(self):
|
122 |
+
if len(err) > 0:
|
123 |
+
err += ", "
|
124 |
+
err += "`tracker`"
|
125 |
+
if len(err) > 0:
|
126 |
+
raise NotImplementedError(
|
127 |
+
f"The implementation for this tracker class is missing the following "
|
128 |
+
f"required attributes. Please define them in the class definition: "
|
129 |
+
f"{err}"
|
130 |
+
)
|
131 |
+
|
132 |
+
def store_init_configuration(self, values: dict):
|
133 |
+
"""
|
134 |
+
Logs `values` as hyperparameters for the run. Implementations should use the experiment configuration
|
135 |
+
functionality of a tracking API.
|
136 |
+
|
137 |
+
Args:
|
138 |
+
values (Dictionary `str` to `bool`, `str`, `float` or `int`):
|
139 |
+
Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
|
140 |
+
`str`, `float`, `int`, or `None`.
|
141 |
+
"""
|
142 |
+
pass
|
143 |
+
|
144 |
+
def log(self, values: dict, step: Optional[int], **kwargs):
|
145 |
+
"""
|
146 |
+
Logs `values` to the current run. Base `log` implementations of a tracking API should go in here, along with
|
147 |
+
special behavior for the `step parameter.
|
148 |
+
|
149 |
+
Args:
|
150 |
+
values (Dictionary `str` to `str`, `float`, or `int`):
|
151 |
+
Values to be logged as key-value pairs. The values need to have type `str`, `float`, or `int`.
|
152 |
+
step (`int`, *optional*):
|
153 |
+
The run step. If included, the log will be affiliated with this step.
|
154 |
+
"""
|
155 |
+
pass
|
156 |
+
|
157 |
+
def finish(self):
|
158 |
+
"""
|
159 |
+
Should run any finalizing functions within the tracking API. If the API should not have one, just don't
|
160 |
+
overwrite that method.
|
161 |
+
"""
|
162 |
+
pass
|
163 |
+
|
164 |
+
|
165 |
+
class TensorBoardTracker(GeneralTracker):
|
166 |
+
"""
|
167 |
+
A `Tracker` class that supports `tensorboard`. Should be initialized at the start of your script.
|
168 |
+
|
169 |
+
Args:
|
170 |
+
run_name (`str`):
|
171 |
+
The name of the experiment run
|
172 |
+
logging_dir (`str`, `os.PathLike`):
|
173 |
+
Location for TensorBoard logs to be stored.
|
174 |
+
**kwargs (additional keyword arguments, *optional*):
|
175 |
+
Additional key word arguments passed along to the `tensorboard.SummaryWriter.__init__` method.
|
176 |
+
"""
|
177 |
+
|
178 |
+
name = "tensorboard"
|
179 |
+
requires_logging_directory = True
|
180 |
+
|
181 |
+
@on_main_process
|
182 |
+
def __init__(self, run_name: str, logging_dir: Union[str, os.PathLike], **kwargs):
|
183 |
+
try:
|
184 |
+
from torch.utils import tensorboard
|
185 |
+
except ModuleNotFoundError:
|
186 |
+
import tensorboardX as tensorboard
|
187 |
+
super().__init__()
|
188 |
+
self.run_name = run_name
|
189 |
+
self.logging_dir = os.path.join(logging_dir, run_name)
|
190 |
+
self.writer = tensorboard.SummaryWriter(self.logging_dir, **kwargs)
|
191 |
+
logger.debug(f"Initialized TensorBoard project {self.run_name} logging to {self.logging_dir}")
|
192 |
+
logger.debug(
|
193 |
+
"Make sure to log any initial configurations with `self.store_init_configuration` before training!"
|
194 |
+
)
|
195 |
+
|
196 |
+
@property
|
197 |
+
def tracker(self):
|
198 |
+
return self.writer
|
199 |
+
|
200 |
+
@on_main_process
|
201 |
+
def store_init_configuration(self, values: dict):
|
202 |
+
"""
|
203 |
+
Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment. Stores the
|
204 |
+
hyperparameters in a yaml file for future use.
|
205 |
+
|
206 |
+
Args:
|
207 |
+
values (Dictionary `str` to `bool`, `str`, `float` or `int`):
|
208 |
+
Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
|
209 |
+
`str`, `float`, `int`, or `None`.
|
210 |
+
"""
|
211 |
+
self.writer.add_hparams(values, metric_dict={})
|
212 |
+
self.writer.flush()
|
213 |
+
project_run_name = time.time()
|
214 |
+
dir_name = os.path.join(self.logging_dir, str(project_run_name))
|
215 |
+
os.makedirs(dir_name, exist_ok=True)
|
216 |
+
with open(os.path.join(dir_name, "hparams.yml"), "w") as outfile:
|
217 |
+
try:
|
218 |
+
yaml.dump(values, outfile)
|
219 |
+
except yaml.representer.RepresenterError:
|
220 |
+
logger.error("Serialization to store hyperparameters failed")
|
221 |
+
raise
|
222 |
+
logger.debug("Stored initial configuration hyperparameters to TensorBoard and hparams yaml file")
|
223 |
+
|
224 |
+
@on_main_process
|
225 |
+
def log(self, values: dict, step: Optional[int] = None, **kwargs):
|
226 |
+
"""
|
227 |
+
Logs `values` to the current run.
|
228 |
+
|
229 |
+
Args:
|
230 |
+
values (Dictionary `str` to `str`, `float`, `int` or `dict` of `str` to `float`/`int`):
|
231 |
+
Values to be logged as key-value pairs. The values need to have type `str`, `float`, `int` or `dict` of
|
232 |
+
`str` to `float`/`int`.
|
233 |
+
step (`int`, *optional*):
|
234 |
+
The run step. If included, the log will be affiliated with this step.
|
235 |
+
kwargs:
|
236 |
+
Additional key word arguments passed along to either `SummaryWriter.add_scaler`,
|
237 |
+
`SummaryWriter.add_text`, or `SummaryWriter.add_scalers` method based on the contents of `values`.
|
238 |
+
"""
|
239 |
+
values = listify(values)
|
240 |
+
for k, v in values.items():
|
241 |
+
if isinstance(v, (int, float)):
|
242 |
+
self.writer.add_scalar(k, v, global_step=step, **kwargs)
|
243 |
+
elif isinstance(v, str):
|
244 |
+
self.writer.add_text(k, v, global_step=step, **kwargs)
|
245 |
+
elif isinstance(v, dict):
|
246 |
+
self.writer.add_scalars(k, v, global_step=step, **kwargs)
|
247 |
+
self.writer.flush()
|
248 |
+
logger.debug("Successfully logged to TensorBoard")
|
249 |
+
|
250 |
+
@on_main_process
|
251 |
+
def log_images(self, values: dict, step: Optional[int], **kwargs):
|
252 |
+
"""
|
253 |
+
Logs `images` to the current run.
|
254 |
+
|
255 |
+
Args:
|
256 |
+
values (Dictionary `str` to `List` of `np.ndarray` or `PIL.Image`):
|
257 |
+
Values to be logged as key-value pairs. The values need to have type `List` of `np.ndarray` or
|
258 |
+
step (`int`, *optional*):
|
259 |
+
The run step. If included, the log will be affiliated with this step.
|
260 |
+
kwargs:
|
261 |
+
Additional key word arguments passed along to the `SummaryWriter.add_image` method.
|
262 |
+
"""
|
263 |
+
for k, v in values.items():
|
264 |
+
self.writer.add_images(k, v, global_step=step, **kwargs)
|
265 |
+
logger.debug("Successfully logged images to TensorBoard")
|
266 |
+
|
267 |
+
@on_main_process
|
268 |
+
def finish(self):
|
269 |
+
"""
|
270 |
+
Closes `TensorBoard` writer
|
271 |
+
"""
|
272 |
+
self.writer.close()
|
273 |
+
logger.debug("TensorBoard writer closed")
|
274 |
+
|
275 |
+
|
276 |
+
class WandBTracker(GeneralTracker):
|
277 |
+
"""
|
278 |
+
A `Tracker` class that supports `wandb`. Should be initialized at the start of your script.
|
279 |
+
|
280 |
+
Args:
|
281 |
+
run_name (`str`):
|
282 |
+
The name of the experiment run.
|
283 |
+
**kwargs (additional keyword arguments, *optional*):
|
284 |
+
Additional key word arguments passed along to the `wandb.init` method.
|
285 |
+
"""
|
286 |
+
|
287 |
+
name = "wandb"
|
288 |
+
requires_logging_directory = False
|
289 |
+
main_process_only = False
|
290 |
+
|
291 |
+
@on_main_process
|
292 |
+
def __init__(self, run_name: str, **kwargs):
|
293 |
+
super().__init__()
|
294 |
+
self.run_name = run_name
|
295 |
+
|
296 |
+
import wandb
|
297 |
+
|
298 |
+
self.run = wandb.init(project=self.run_name, **kwargs)
|
299 |
+
logger.debug(f"Initialized WandB project {self.run_name}")
|
300 |
+
logger.debug(
|
301 |
+
"Make sure to log any initial configurations with `self.store_init_configuration` before training!"
|
302 |
+
)
|
303 |
+
|
304 |
+
@property
|
305 |
+
def tracker(self):
|
306 |
+
return self.run
|
307 |
+
|
308 |
+
@on_main_process
|
309 |
+
def store_init_configuration(self, values: dict):
|
310 |
+
"""
|
311 |
+
Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.
|
312 |
+
|
313 |
+
Args:
|
314 |
+
values (Dictionary `str` to `bool`, `str`, `float` or `int`):
|
315 |
+
Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
|
316 |
+
`str`, `float`, `int`, or `None`.
|
317 |
+
"""
|
318 |
+
import wandb
|
319 |
+
|
320 |
+
wandb.config.update(values, allow_val_change=True)
|
321 |
+
logger.debug("Stored initial configuration hyperparameters to WandB")
|
322 |
+
|
323 |
+
@on_main_process
|
324 |
+
def log(self, values: dict, step: Optional[int] = None, **kwargs):
|
325 |
+
"""
|
326 |
+
Logs `values` to the current run.
|
327 |
+
|
328 |
+
Args:
|
329 |
+
values (Dictionary `str` to `str`, `float`, `int` or `dict` of `str` to `float`/`int`):
|
330 |
+
Values to be logged as key-value pairs. The values need to have type `str`, `float`, `int` or `dict` of
|
331 |
+
`str` to `float`/`int`.
|
332 |
+
step (`int`, *optional*):
|
333 |
+
The run step. If included, the log will be affiliated with this step.
|
334 |
+
kwargs:
|
335 |
+
Additional key word arguments passed along to the `wandb.log` method.
|
336 |
+
"""
|
337 |
+
self.run.log(values, step=step, **kwargs)
|
338 |
+
logger.debug("Successfully logged to WandB")
|
339 |
+
|
340 |
+
@on_main_process
|
341 |
+
def log_images(self, values: dict, step: Optional[int] = None, **kwargs):
|
342 |
+
"""
|
343 |
+
Logs `images` to the current run.
|
344 |
+
|
345 |
+
Args:
|
346 |
+
values (Dictionary `str` to `List` of `np.ndarray` or `PIL.Image`):
|
347 |
+
Values to be logged as key-value pairs. The values need to have type `List` of `np.ndarray` or
|
348 |
+
step (`int`, *optional*):
|
349 |
+
The run step. If included, the log will be affiliated with this step.
|
350 |
+
kwargs:
|
351 |
+
Additional key word arguments passed along to the `wandb.log` method.
|
352 |
+
"""
|
353 |
+
import wandb
|
354 |
+
|
355 |
+
for k, v in values.items():
|
356 |
+
self.log({k: [wandb.Image(image) for image in v]}, step=step, **kwargs)
|
357 |
+
logger.debug("Successfully logged images to WandB")
|
358 |
+
|
359 |
+
@on_main_process
|
360 |
+
def log_table(
|
361 |
+
self,
|
362 |
+
table_name: str,
|
363 |
+
columns: List[str] = None,
|
364 |
+
data: List[List[Any]] = None,
|
365 |
+
dataframe: Any = None,
|
366 |
+
step: Optional[int] = None,
|
367 |
+
**kwargs,
|
368 |
+
):
|
369 |
+
"""
|
370 |
+
Log a Table containing any object type (text, image, audio, video, molecule, html, etc). Can be defined either
|
371 |
+
with `columns` and `data` or with `dataframe`.
|
372 |
+
|
373 |
+
Args:
|
374 |
+
table_name (`str`):
|
375 |
+
The name to give to the logged table on the wandb workspace
|
376 |
+
columns (list of `str`, *optional*):
|
377 |
+
The name of the columns on the table
|
378 |
+
data (List of List of Any data type, *optional*):
|
379 |
+
The data to be logged in the table
|
380 |
+
dataframe (Any data type, *optional*):
|
381 |
+
The data to be logged in the table
|
382 |
+
step (`int`, *optional*):
|
383 |
+
The run step. If included, the log will be affiliated with this step.
|
384 |
+
"""
|
385 |
+
import wandb
|
386 |
+
|
387 |
+
values = {table_name: wandb.Table(columns=columns, data=data, dataframe=dataframe)}
|
388 |
+
self.log(values, step=step, **kwargs)
|
389 |
+
|
390 |
+
@on_main_process
|
391 |
+
def finish(self):
|
392 |
+
"""
|
393 |
+
Closes `wandb` writer
|
394 |
+
"""
|
395 |
+
self.run.finish()
|
396 |
+
logger.debug("WandB run closed")
|
397 |
+
|
398 |
+
|
399 |
+
class CometMLTracker(GeneralTracker):
|
400 |
+
"""
|
401 |
+
A `Tracker` class that supports `comet_ml`. Should be initialized at the start of your script.
|
402 |
+
|
403 |
+
API keys must be stored in a Comet config file.
|
404 |
+
|
405 |
+
Args:
|
406 |
+
run_name (`str`):
|
407 |
+
The name of the experiment run.
|
408 |
+
**kwargs (additional keyword arguments, *optional*):
|
409 |
+
Additional key word arguments passed along to the `Experiment.__init__` method.
|
410 |
+
"""
|
411 |
+
|
412 |
+
name = "comet_ml"
|
413 |
+
requires_logging_directory = False
|
414 |
+
|
415 |
+
@on_main_process
|
416 |
+
def __init__(self, run_name: str, **kwargs):
|
417 |
+
super().__init__()
|
418 |
+
self.run_name = run_name
|
419 |
+
|
420 |
+
from comet_ml import Experiment
|
421 |
+
|
422 |
+
self.writer = Experiment(project_name=run_name, **kwargs)
|
423 |
+
logger.debug(f"Initialized CometML project {self.run_name}")
|
424 |
+
logger.debug(
|
425 |
+
"Make sure to log any initial configurations with `self.store_init_configuration` before training!"
|
426 |
+
)
|
427 |
+
|
428 |
+
@property
|
429 |
+
def tracker(self):
|
430 |
+
return self.writer
|
431 |
+
|
432 |
+
@on_main_process
|
433 |
+
def store_init_configuration(self, values: dict):
|
434 |
+
"""
|
435 |
+
Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.
|
436 |
+
|
437 |
+
Args:
|
438 |
+
values (Dictionary `str` to `bool`, `str`, `float` or `int`):
|
439 |
+
Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
|
440 |
+
`str`, `float`, `int`, or `None`.
|
441 |
+
"""
|
442 |
+
self.writer.log_parameters(values)
|
443 |
+
logger.debug("Stored initial configuration hyperparameters to CometML")
|
444 |
+
|
445 |
+
@on_main_process
|
446 |
+
def log(self, values: dict, step: Optional[int] = None, **kwargs):
|
447 |
+
"""
|
448 |
+
Logs `values` to the current run.
|
449 |
+
|
450 |
+
Args:
|
451 |
+
values (Dictionary `str` to `str`, `float`, `int` or `dict` of `str` to `float`/`int`):
|
452 |
+
Values to be logged as key-value pairs. The values need to have type `str`, `float`, `int` or `dict` of
|
453 |
+
`str` to `float`/`int`.
|
454 |
+
step (`int`, *optional*):
|
455 |
+
The run step. If included, the log will be affiliated with this step.
|
456 |
+
kwargs:
|
457 |
+
Additional key word arguments passed along to either `Experiment.log_metric`, `Experiment.log_other`,
|
458 |
+
or `Experiment.log_metrics` method based on the contents of `values`.
|
459 |
+
"""
|
460 |
+
if step is not None:
|
461 |
+
self.writer.set_step(step)
|
462 |
+
for k, v in values.items():
|
463 |
+
if isinstance(v, (int, float)):
|
464 |
+
self.writer.log_metric(k, v, step=step, **kwargs)
|
465 |
+
elif isinstance(v, str):
|
466 |
+
self.writer.log_other(k, v, **kwargs)
|
467 |
+
elif isinstance(v, dict):
|
468 |
+
self.writer.log_metrics(v, step=step, **kwargs)
|
469 |
+
logger.debug("Successfully logged to CometML")
|
470 |
+
|
471 |
+
@on_main_process
|
472 |
+
def finish(self):
|
473 |
+
"""
|
474 |
+
Closes `comet-ml` writer
|
475 |
+
"""
|
476 |
+
self.writer.end()
|
477 |
+
logger.debug("CometML run closed")
|
478 |
+
|
479 |
+
|
480 |
+
class AimTracker(GeneralTracker):
|
481 |
+
"""
|
482 |
+
A `Tracker` class that supports `aim`. Should be initialized at the start of your script.
|
483 |
+
|
484 |
+
Args:
|
485 |
+
run_name (`str`):
|
486 |
+
The name of the experiment run.
|
487 |
+
**kwargs (additional keyword arguments, *optional*):
|
488 |
+
Additional key word arguments passed along to the `Run.__init__` method.
|
489 |
+
"""
|
490 |
+
|
491 |
+
name = "aim"
|
492 |
+
requires_logging_directory = True
|
493 |
+
|
494 |
+
@on_main_process
|
495 |
+
def __init__(self, run_name: str, logging_dir: Optional[Union[str, os.PathLike]] = ".", **kwargs):
|
496 |
+
self.run_name = run_name
|
497 |
+
|
498 |
+
from aim import Run
|
499 |
+
|
500 |
+
self.writer = Run(repo=logging_dir, **kwargs)
|
501 |
+
self.writer.name = self.run_name
|
502 |
+
logger.debug(f"Initialized Aim project {self.run_name}")
|
503 |
+
logger.debug(
|
504 |
+
"Make sure to log any initial configurations with `self.store_init_configuration` before training!"
|
505 |
+
)
|
506 |
+
|
507 |
+
@property
|
508 |
+
def tracker(self):
|
509 |
+
return self.writer
|
510 |
+
|
511 |
+
@on_main_process
|
512 |
+
def store_init_configuration(self, values: dict):
|
513 |
+
"""
|
514 |
+
Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.
|
515 |
+
|
516 |
+
Args:
|
517 |
+
values (`dict`):
|
518 |
+
Values to be stored as initial hyperparameters as key-value pairs.
|
519 |
+
"""
|
520 |
+
self.writer["hparams"] = values
|
521 |
+
|
522 |
+
@on_main_process
|
523 |
+
def log(self, values: dict, step: Optional[int], **kwargs):
|
524 |
+
"""
|
525 |
+
Logs `values` to the current run.
|
526 |
+
|
527 |
+
Args:
|
528 |
+
values (`dict`):
|
529 |
+
Values to be logged as key-value pairs.
|
530 |
+
step (`int`, *optional*):
|
531 |
+
The run step. If included, the log will be affiliated with this step.
|
532 |
+
kwargs:
|
533 |
+
Additional key word arguments passed along to the `Run.track` method.
|
534 |
+
"""
|
535 |
+
# Note: replace this with the dictionary support when merged
|
536 |
+
for key, value in values.items():
|
537 |
+
self.writer.track(value, name=key, step=step, **kwargs)
|
538 |
+
|
539 |
+
@on_main_process
|
540 |
+
def log_images(self, values: dict, step: Optional[int] = None, kwargs: Optional[Dict[str, dict]] = None):
|
541 |
+
"""
|
542 |
+
Logs `images` to the current run.
|
543 |
+
|
544 |
+
Args:
|
545 |
+
values (`Dict[str, Union[np.ndarray, PIL.Image, Tuple[np.ndarray, str], Tuple[PIL.Image, str]]]`):
|
546 |
+
Values to be logged as key-value pairs. The values need to have type `np.ndarray` or PIL.Image. If a
|
547 |
+
tuple is provided, the first element should be the image and the second element should be the caption.
|
548 |
+
step (`int`, *optional*):
|
549 |
+
The run step. If included, the log will be affiliated with this step.
|
550 |
+
kwargs (`Dict[str, dict]`):
|
551 |
+
Additional key word arguments passed along to the `Run.Image` and `Run.track` method specified by the
|
552 |
+
keys `aim_image` and `track`, respectively.
|
553 |
+
"""
|
554 |
+
import aim
|
555 |
+
|
556 |
+
aim_image_kw = {}
|
557 |
+
track_kw = {}
|
558 |
+
|
559 |
+
if kwargs is not None:
|
560 |
+
aim_image_kw = kwargs.get("aim_image", {})
|
561 |
+
track_kw = kwargs.get("track", {})
|
562 |
+
|
563 |
+
for key, value in values.items():
|
564 |
+
if isinstance(value, tuple):
|
565 |
+
img, caption = value
|
566 |
+
else:
|
567 |
+
img, caption = value, ""
|
568 |
+
aim_image = aim.Image(img, caption=caption, **aim_image_kw)
|
569 |
+
self.writer.track(aim_image, name=key, step=step, **track_kw)
|
570 |
+
|
571 |
+
@on_main_process
|
572 |
+
def finish(self):
|
573 |
+
"""
|
574 |
+
Closes `aim` writer
|
575 |
+
"""
|
576 |
+
self.writer.close()
|
577 |
+
|
578 |
+
|
579 |
+
class MLflowTracker(GeneralTracker):
|
580 |
+
"""
|
581 |
+
A `Tracker` class that supports `mlflow`. Should be initialized at the start of your script.
|
582 |
+
|
583 |
+
Args:
|
584 |
+
experiment_name (`str`, *optional*):
|
585 |
+
Name of the experiment. Environment variable MLFLOW_EXPERIMENT_NAME has priority over this argument.
|
586 |
+
logging_dir (`str` or `os.PathLike`, defaults to `"."`):
|
587 |
+
Location for mlflow logs to be stored.
|
588 |
+
run_id (`str`, *optional*):
|
589 |
+
If specified, get the run with the specified UUID and log parameters and metrics under that run. The run’s
|
590 |
+
end time is unset and its status is set to running, but the run’s other attributes (source_version,
|
591 |
+
source_type, etc.) are not changed. Environment variable MLFLOW_RUN_ID has priority over this argument.
|
592 |
+
tags (`Dict[str, str]`, *optional*):
|
593 |
+
An optional `dict` of `str` keys and values, or a `str` dump from a `dict`, to set as tags on the run. If a
|
594 |
+
run is being resumed, these tags are set on the resumed run. If a new run is being created, these tags are
|
595 |
+
set on the new run. Environment variable MLFLOW_TAGS has priority over this argument.
|
596 |
+
nested_run (`bool`, *optional*, defaults to `False`):
|
597 |
+
Controls whether run is nested in parent run. True creates a nested run. Environment variable
|
598 |
+
MLFLOW_NESTED_RUN has priority over this argument.
|
599 |
+
run_name (`str`, *optional*):
|
600 |
+
Name of new run (stored as a mlflow.runName tag). Used only when `run_id` is unspecified.
|
601 |
+
description (`str`, *optional*):
|
602 |
+
An optional string that populates the description box of the run. If a run is being resumed, the
|
603 |
+
description is set on the resumed run. If a new run is being created, the description is set on the new
|
604 |
+
run.
|
605 |
+
"""
|
606 |
+
|
607 |
+
name = "mlflow"
|
608 |
+
requires_logging_directory = False
|
609 |
+
|
610 |
+
@on_main_process
|
611 |
+
def __init__(
|
612 |
+
self,
|
613 |
+
experiment_name: str = None,
|
614 |
+
logging_dir: Optional[Union[str, os.PathLike]] = None,
|
615 |
+
run_id: Optional[str] = None,
|
616 |
+
tags: Optional[Union[Dict[str, Any], str]] = None,
|
617 |
+
nested_run: Optional[bool] = False,
|
618 |
+
run_name: Optional[str] = None,
|
619 |
+
description: Optional[str] = None,
|
620 |
+
):
|
621 |
+
experiment_name = os.environ.get("MLFLOW_EXPERIMENT_NAME", experiment_name)
|
622 |
+
run_id = os.environ.get("MLFLOW_RUN_ID", run_id)
|
623 |
+
tags = os.environ.get("MLFLOW_TAGS", tags)
|
624 |
+
if isinstance(tags, str):
|
625 |
+
tags = json.loads(tags)
|
626 |
+
|
627 |
+
nested_run = os.environ.get("MLFLOW_NESTED_RUN", nested_run)
|
628 |
+
|
629 |
+
import mlflow
|
630 |
+
|
631 |
+
exps = mlflow.search_experiments(filter_string=f"name = '{experiment_name}'")
|
632 |
+
if len(exps) > 0:
|
633 |
+
if len(exps) > 1:
|
634 |
+
logger.warning("Multiple experiments with the same name found. Using first one.")
|
635 |
+
experiment_id = exps[0].experiment_id
|
636 |
+
else:
|
637 |
+
experiment_id = mlflow.create_experiment(
|
638 |
+
name=experiment_name,
|
639 |
+
artifact_location=logging_dir,
|
640 |
+
tags=tags,
|
641 |
+
)
|
642 |
+
|
643 |
+
self.active_run = mlflow.start_run(
|
644 |
+
run_id=run_id,
|
645 |
+
experiment_id=experiment_id,
|
646 |
+
run_name=run_name,
|
647 |
+
nested=nested_run,
|
648 |
+
tags=tags,
|
649 |
+
description=description,
|
650 |
+
)
|
651 |
+
|
652 |
+
logger.debug(f"Initialized mlflow experiment {experiment_name}")
|
653 |
+
logger.debug(
|
654 |
+
"Make sure to log any initial configurations with `self.store_init_configuration` before training!"
|
655 |
+
)
|
656 |
+
|
657 |
+
@property
|
658 |
+
def tracker(self):
|
659 |
+
return self.active_run
|
660 |
+
|
661 |
+
@on_main_process
|
662 |
+
def store_init_configuration(self, values: dict):
|
663 |
+
"""
|
664 |
+
Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.
|
665 |
+
|
666 |
+
Args:
|
667 |
+
values (`dict`):
|
668 |
+
Values to be stored as initial hyperparameters as key-value pairs.
|
669 |
+
"""
|
670 |
+
import mlflow
|
671 |
+
|
672 |
+
for name, value in list(values.items()):
|
673 |
+
# internally, all values are converted to str in MLflow
|
674 |
+
if len(str(value)) > mlflow.utils.validation.MAX_PARAM_VAL_LENGTH:
|
675 |
+
logger.warning_once(
|
676 |
+
f'Accelerate is attempting to log a value of "{value}" for key "{name}" as a parameter. MLflow\'s'
|
677 |
+
f" log_param() only accepts values no longer than {mlflow.utils.validation.MAX_PARAM_VAL_LENGTH} characters so we dropped this attribute."
|
678 |
+
)
|
679 |
+
del values[name]
|
680 |
+
|
681 |
+
values_list = list(values.items())
|
682 |
+
|
683 |
+
# MLflow cannot log more than 100 values in one go, so we have to split it
|
684 |
+
for i in range(0, len(values_list), mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH):
|
685 |
+
mlflow.log_params(dict(values_list[i : i + mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH]))
|
686 |
+
|
687 |
+
logger.debug("Stored initial configuration hyperparameters to MLflow")
|
688 |
+
|
689 |
+
@on_main_process
|
690 |
+
def log(self, values: dict, step: Optional[int]):
|
691 |
+
"""
|
692 |
+
Logs `values` to the current run.
|
693 |
+
|
694 |
+
Args:
|
695 |
+
values (`dict`):
|
696 |
+
Values to be logged as key-value pairs.
|
697 |
+
step (`int`, *optional*):
|
698 |
+
The run step. If included, the log will be affiliated with this step.
|
699 |
+
"""
|
700 |
+
metrics = {}
|
701 |
+
for k, v in values.items():
|
702 |
+
if isinstance(v, (int, float)):
|
703 |
+
metrics[k] = v
|
704 |
+
else:
|
705 |
+
logger.warning_once(
|
706 |
+
f'MLflowTracker is attempting to log a value of "{v}" of type {type(v)} for key "{k}" as a metric. '
|
707 |
+
"MLflow's log_metric() only accepts float and int types so we dropped this attribute."
|
708 |
+
)
|
709 |
+
import mlflow
|
710 |
+
|
711 |
+
mlflow.log_metrics(metrics, step=step)
|
712 |
+
logger.debug("Successfully logged to mlflow")
|
713 |
+
|
714 |
+
@on_main_process
|
715 |
+
def finish(self):
|
716 |
+
"""
|
717 |
+
End the active MLflow run.
|
718 |
+
"""
|
719 |
+
import mlflow
|
720 |
+
|
721 |
+
mlflow.end_run()
|
722 |
+
|
723 |
+
|
724 |
+
class ClearMLTracker(GeneralTracker):
|
725 |
+
"""
|
726 |
+
A `Tracker` class that supports `clearml`. Should be initialized at the start of your script.
|
727 |
+
|
728 |
+
Args:
|
729 |
+
run_name (`str`, *optional*):
|
730 |
+
Name of the experiment. Environment variables `CLEARML_PROJECT` and `CLEARML_TASK` have priority over this
|
731 |
+
argument.
|
732 |
+
**kwargs (additional keyword arguments, *optional*):
|
733 |
+
Kwargs passed along to the `Task.__init__` method.
|
734 |
+
"""
|
735 |
+
|
736 |
+
name = "clearml"
|
737 |
+
requires_logging_directory = False
|
738 |
+
|
739 |
+
@on_main_process
|
740 |
+
def __init__(self, run_name: str = None, **kwargs):
|
741 |
+
from clearml import Task
|
742 |
+
|
743 |
+
current_task = Task.current_task()
|
744 |
+
self._initialized_externally = False
|
745 |
+
if current_task:
|
746 |
+
self._initialized_externally = True
|
747 |
+
self.task = current_task
|
748 |
+
return
|
749 |
+
|
750 |
+
kwargs.setdefault("project_name", os.environ.get("CLEARML_PROJECT", run_name))
|
751 |
+
kwargs.setdefault("task_name", os.environ.get("CLEARML_TASK", run_name))
|
752 |
+
self.task = Task.init(**kwargs)
|
753 |
+
|
754 |
+
@property
|
755 |
+
def tracker(self):
|
756 |
+
return self.task
|
757 |
+
|
758 |
+
@on_main_process
|
759 |
+
def store_init_configuration(self, values: dict):
|
760 |
+
"""
|
761 |
+
Connect configuration dictionary to the Task object. Should be run at the beginning of your experiment.
|
762 |
+
|
763 |
+
Args:
|
764 |
+
values (`dict`):
|
765 |
+
Values to be stored as initial hyperparameters as key-value pairs.
|
766 |
+
"""
|
767 |
+
return self.task.connect_configuration(values)
|
768 |
+
|
769 |
+
@on_main_process
|
770 |
+
def log(self, values: Dict[str, Union[int, float]], step: Optional[int] = None, **kwargs):
|
771 |
+
"""
|
772 |
+
Logs `values` dictionary to the current run. The dictionary keys must be strings. The dictionary values must be
|
773 |
+
ints or floats
|
774 |
+
|
775 |
+
Args:
|
776 |
+
values (`Dict[str, Union[int, float]]`):
|
777 |
+
Values to be logged as key-value pairs. If the key starts with 'eval_'/'test_'/'train_', the value will
|
778 |
+
be reported under the 'eval'/'test'/'train' series and the respective prefix will be removed.
|
779 |
+
Otherwise, the value will be reported under the 'train' series, and no prefix will be removed.
|
780 |
+
step (`int`, *optional*):
|
781 |
+
If specified, the values will be reported as scalars, with the iteration number equal to `step`.
|
782 |
+
Otherwise they will be reported as single values.
|
783 |
+
kwargs:
|
784 |
+
Additional key word arguments passed along to the `clearml.Logger.report_single_value` or
|
785 |
+
`clearml.Logger.report_scalar` methods.
|
786 |
+
"""
|
787 |
+
clearml_logger = self.task.get_logger()
|
788 |
+
for k, v in values.items():
|
789 |
+
if not isinstance(v, (int, float)):
|
790 |
+
logger.warning_once(
|
791 |
+
"Accelerator is attempting to log a value of "
|
792 |
+
f'"{v}" of type {type(v)} for key "{k}" as a scalar. '
|
793 |
+
"This invocation of ClearML logger's report_scalar() "
|
794 |
+
"is incorrect so we dropped this attribute."
|
795 |
+
)
|
796 |
+
continue
|
797 |
+
if step is None:
|
798 |
+
clearml_logger.report_single_value(name=k, value=v, **kwargs)
|
799 |
+
continue
|
800 |
+
title, series = ClearMLTracker._get_title_series(k)
|
801 |
+
clearml_logger.report_scalar(title=title, series=series, value=v, iteration=step, **kwargs)
|
802 |
+
|
803 |
+
@on_main_process
|
804 |
+
def log_images(self, values: dict, step: Optional[int] = None, **kwargs):
|
805 |
+
"""
|
806 |
+
Logs `images` to the current run.
|
807 |
+
|
808 |
+
Args:
|
809 |
+
values (`Dict[str, List[Union[np.ndarray, PIL.Image]]`):
|
810 |
+
Values to be logged as key-value pairs. The values need to have type `List` of `np.ndarray` or
|
811 |
+
step (`int`, *optional*):
|
812 |
+
The run step. If included, the log will be affiliated with this step.
|
813 |
+
kwargs:
|
814 |
+
Additional key word arguments passed along to the `clearml.Logger.report_image` method.
|
815 |
+
"""
|
816 |
+
clearml_logger = self.task.get_logger()
|
817 |
+
for k, v in values.items():
|
818 |
+
title, series = ClearMLTracker._get_title_series(k)
|
819 |
+
clearml_logger.report_image(title=title, series=series, iteration=step, image=v, **kwargs)
|
820 |
+
|
821 |
+
@on_main_process
|
822 |
+
def log_table(
|
823 |
+
self,
|
824 |
+
table_name: str,
|
825 |
+
columns: List[str] = None,
|
826 |
+
data: List[List[Any]] = None,
|
827 |
+
dataframe: Any = None,
|
828 |
+
step: Optional[int] = None,
|
829 |
+
**kwargs,
|
830 |
+
):
|
831 |
+
"""
|
832 |
+
Log a Table to the task. Can be defined eitherwith `columns` and `data` or with `dataframe`.
|
833 |
+
|
834 |
+
Args:
|
835 |
+
table_name (`str`):
|
836 |
+
The name of the table
|
837 |
+
columns (list of `str`, *optional*):
|
838 |
+
The name of the columns on the table
|
839 |
+
data (List of List of Any data type, *optional*):
|
840 |
+
The data to be logged in the table. If `columns` is not specified, then the first entry in data will be
|
841 |
+
the name of the columns of the table
|
842 |
+
dataframe (Any data type, *optional*):
|
843 |
+
The data to be logged in the table
|
844 |
+
step (`int`, *optional*):
|
845 |
+
The run step. If included, the log will be affiliated with this step.
|
846 |
+
kwargs:
|
847 |
+
Additional key word arguments passed along to the `clearml.Logger.report_table` method.
|
848 |
+
"""
|
849 |
+
to_report = dataframe
|
850 |
+
if dataframe is None:
|
851 |
+
if data is None:
|
852 |
+
raise ValueError(
|
853 |
+
"`ClearMLTracker.log_table` requires that `data` to be supplied if `dataframe` is `None`"
|
854 |
+
)
|
855 |
+
to_report = [columns] + data if columns else data
|
856 |
+
title, series = ClearMLTracker._get_title_series(table_name)
|
857 |
+
self.task.get_logger().report_table(title=title, series=series, table_plot=to_report, iteration=step, **kwargs)
|
858 |
+
|
859 |
+
@on_main_process
|
860 |
+
def finish(self):
|
861 |
+
"""
|
862 |
+
Close the ClearML task. If the task was initialized externally (e.g. by manually calling `Task.init`), this
|
863 |
+
function is a noop
|
864 |
+
"""
|
865 |
+
if self.task and not self._initialized_externally:
|
866 |
+
self.task.close()
|
867 |
+
|
868 |
+
@staticmethod
|
869 |
+
def _get_title_series(name):
|
870 |
+
for prefix in ["eval", "test", "train"]:
|
871 |
+
if name.startswith(prefix + "_"):
|
872 |
+
return name[len(prefix) + 1 :], prefix
|
873 |
+
return name, "train"
|
874 |
+
|
875 |
+
|
876 |
+
class DVCLiveTracker(GeneralTracker):
|
877 |
+
"""
|
878 |
+
A `Tracker` class that supports `dvclive`. Should be initialized at the start of your script.
|
879 |
+
|
880 |
+
Args:
|
881 |
+
run_name (`str`, *optional*):
|
882 |
+
Ignored for dvclive. See `kwargs` instead.
|
883 |
+
kwargs:
|
884 |
+
Additional key word arguments passed along to [`dvclive.Live()`](https://dvc.org/doc/dvclive/live).
|
885 |
+
|
886 |
+
Example:
|
887 |
+
|
888 |
+
```py
|
889 |
+
from accelerate import Accelerator
|
890 |
+
|
891 |
+
accelerator = Accelerator(log_with="dvclive")
|
892 |
+
accelerator.init_trackers(project_name="my_project", init_kwargs={"dvclive": {"dir": "my_directory"}})
|
893 |
+
```
|
894 |
+
"""
|
895 |
+
|
896 |
+
name = "dvclive"
|
897 |
+
requires_logging_directory = False
|
898 |
+
|
899 |
+
@on_main_process
|
900 |
+
def __init__(self, run_name: Optional[str] = None, live: Optional[Any] = None, **kwargs):
|
901 |
+
from dvclive import Live
|
902 |
+
|
903 |
+
super().__init__()
|
904 |
+
self.live = live if live is not None else Live(**kwargs)
|
905 |
+
|
906 |
+
@property
|
907 |
+
def tracker(self):
|
908 |
+
return self.live
|
909 |
+
|
910 |
+
@on_main_process
|
911 |
+
def store_init_configuration(self, values: dict):
|
912 |
+
"""
|
913 |
+
Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment. Stores the
|
914 |
+
hyperparameters in a yaml file for future use.
|
915 |
+
|
916 |
+
Args:
|
917 |
+
values (Dictionary `str` to `bool`, `str`, `float`, `int`, or a List or Dict of those types):
|
918 |
+
Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
|
919 |
+
`str`, `float`, or `int`.
|
920 |
+
"""
|
921 |
+
self.live.log_params(values)
|
922 |
+
|
923 |
+
@on_main_process
|
924 |
+
def log(self, values: dict, step: Optional[int] = None, **kwargs):
|
925 |
+
"""
|
926 |
+
Logs `values` to the current run.
|
927 |
+
|
928 |
+
Args:
|
929 |
+
values (Dictionary `str` to `str`, `float`, or `int`):
|
930 |
+
Values to be logged as key-value pairs. The values need to have type `str`, `float`, or `int`.
|
931 |
+
step (`int`, *optional*):
|
932 |
+
The run step. If included, the log will be affiliated with this step.
|
933 |
+
kwargs:
|
934 |
+
Additional key word arguments passed along to `dvclive.Live.log_metric()`.
|
935 |
+
"""
|
936 |
+
from dvclive.plots import Metric
|
937 |
+
|
938 |
+
if step is not None:
|
939 |
+
self.live.step = step
|
940 |
+
for k, v in values.items():
|
941 |
+
if Metric.could_log(v):
|
942 |
+
self.live.log_metric(k, v, **kwargs)
|
943 |
+
else:
|
944 |
+
logger.warning_once(
|
945 |
+
"Accelerator attempted to log a value of "
|
946 |
+
f'"{v}" of type {type(v)} for key "{k}" as a scalar. '
|
947 |
+
"This invocation of DVCLive's Live.log_metric() "
|
948 |
+
"is incorrect so we dropped this attribute."
|
949 |
+
)
|
950 |
+
self.live.next_step()
|
951 |
+
|
952 |
+
@on_main_process
|
953 |
+
def finish(self):
|
954 |
+
"""
|
955 |
+
Closes `dvclive.Live()`.
|
956 |
+
"""
|
957 |
+
self.live.end()
|
958 |
+
|
959 |
+
|
960 |
+
LOGGER_TYPE_TO_CLASS = {
|
961 |
+
"aim": AimTracker,
|
962 |
+
"comet_ml": CometMLTracker,
|
963 |
+
"mlflow": MLflowTracker,
|
964 |
+
"tensorboard": TensorBoardTracker,
|
965 |
+
"wandb": WandBTracker,
|
966 |
+
"clearml": ClearMLTracker,
|
967 |
+
"dvclive": DVCLiveTracker,
|
968 |
+
}
|
969 |
+
|
970 |
+
|
971 |
+
def filter_trackers(
|
972 |
+
log_with: List[Union[str, LoggerType, GeneralTracker]],
|
973 |
+
logging_dir: Union[str, os.PathLike] = None,
|
974 |
+
):
|
975 |
+
"""
|
976 |
+
Takes in a list of potential tracker types and checks that:
|
977 |
+
- The tracker wanted is available in that environment
|
978 |
+
- Filters out repeats of tracker types
|
979 |
+
- If `all` is in `log_with`, will return all trackers in the environment
|
980 |
+
- If a tracker requires a `logging_dir`, ensures that `logging_dir` is not `None`
|
981 |
+
|
982 |
+
Args:
|
983 |
+
log_with (list of `str`, [`~utils.LoggerType`] or [`~tracking.GeneralTracker`], *optional*):
|
984 |
+
A list of loggers to be setup for experiment tracking. Should be one or several of:
|
985 |
+
|
986 |
+
- `"all"`
|
987 |
+
- `"tensorboard"`
|
988 |
+
- `"wandb"`
|
989 |
+
- `"comet_ml"`
|
990 |
+
- `"mlflow"`
|
991 |
+
- `"dvclive"`
|
992 |
+
If `"all"` is selected, will pick up all available trackers in the environment and initialize them. Can
|
993 |
+
also accept implementations of `GeneralTracker` for custom trackers, and can be combined with `"all"`.
|
994 |
+
logging_dir (`str`, `os.PathLike`, *optional*):
|
995 |
+
A path to a directory for storing logs of locally-compatible loggers.
|
996 |
+
"""
|
997 |
+
loggers = []
|
998 |
+
if log_with is not None:
|
999 |
+
if not isinstance(log_with, (list, tuple)):
|
1000 |
+
log_with = [log_with]
|
1001 |
+
if "all" in log_with or LoggerType.ALL in log_with:
|
1002 |
+
loggers = [o for o in log_with if issubclass(type(o), GeneralTracker)] + get_available_trackers()
|
1003 |
+
else:
|
1004 |
+
for log_type in log_with:
|
1005 |
+
if log_type not in LoggerType and not issubclass(type(log_type), GeneralTracker):
|
1006 |
+
raise ValueError(f"Unsupported logging capability: {log_type}. Choose between {LoggerType.list()}")
|
1007 |
+
if issubclass(type(log_type), GeneralTracker):
|
1008 |
+
loggers.append(log_type)
|
1009 |
+
else:
|
1010 |
+
log_type = LoggerType(log_type)
|
1011 |
+
if log_type not in loggers:
|
1012 |
+
if log_type in get_available_trackers():
|
1013 |
+
tracker_init = LOGGER_TYPE_TO_CLASS[str(log_type)]
|
1014 |
+
if tracker_init.requires_logging_directory:
|
1015 |
+
if logging_dir is None:
|
1016 |
+
raise ValueError(
|
1017 |
+
f"Logging with `{log_type}` requires a `logging_dir` to be passed in."
|
1018 |
+
)
|
1019 |
+
loggers.append(log_type)
|
1020 |
+
else:
|
1021 |
+
logger.debug(f"Tried adding logger {log_type}, but package is unavailable in the system.")
|
1022 |
+
|
1023 |
+
return loggers
|
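The docstrings above spell out the contract a custom tracker has to satisfy (`name`, `requires_logging_directory`, a `tracker` property, plus `store_init_configuration`/`log`), and `filter_trackers` shows that `GeneralTracker` instances can be passed straight through `log_with`. As a minimal, hedged sketch only (the `JSONTracker` class, file names, and run names below are invented for illustration and are not part of the library), a custom tracker wired into `Accelerator` could look like this:

```py
# Hypothetical example: a minimal custom tracker satisfying the GeneralTracker
# contract described above. JSONTracker and its paths are illustrative only.
import json
import os
from typing import Optional

from accelerate import Accelerator
from accelerate.tracking import GeneralTracker, on_main_process


class JSONTracker(GeneralTracker):
    name = "json_lines"                # string identifier for this tracker
    requires_logging_directory = True  # it writes files, so it needs a directory

    @on_main_process
    def __init__(self, run_name: str, logging_dir: str):
        super().__init__()
        os.makedirs(logging_dir, exist_ok=True)
        self._path = os.path.join(logging_dir, f"{run_name}.jsonl")
        self._file = open(self._path, "a")

    @property
    def tracker(self):
        # Expose the underlying handle, mirroring `run`/`writer` in the built-ins
        return self._file

    @on_main_process
    def store_init_configuration(self, values: dict):
        self._file.write(json.dumps({"hparams": values}) + "\n")

    @on_main_process
    def log(self, values: dict, step: Optional[int] = None, **kwargs):
        self._file.write(json.dumps({"step": step, **values}) + "\n")

    @on_main_process
    def finish(self):
        self._file.close()


# Usage sketch: custom tracker instances are passed directly through `log_with`.
tracker = JSONTracker(run_name="baseline", logging_dir="runs")
accelerator = Accelerator(log_with=tracker)
accelerator.init_trackers("baseline")
accelerator.log({"loss": 0.5}, step=1)
accelerator.end_training()
```

Because `main_process_only` defaults to `True` and every method is wrapped in `@on_main_process`, the file is only written on the main process, matching the behavior of the built-in trackers above.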
.venv/Lib/site-packages/decorator.py
ADDED
@@ -0,0 +1,451 @@
# ######################### LICENSE ############################ #

# Copyright (c) 2005-2021, Michele Simionato
# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:

#   Redistributions of source code must retain the above copyright
#   notice, this list of conditions and the following disclaimer.
#   Redistributions in bytecode form must reproduce the above copyright
#   notice, this list of conditions and the following disclaimer in
#   the documentation and/or other materials provided with the
#   distribution.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
# OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
# DAMAGE.

"""
Decorator module, see
https://github.com/micheles/decorator/blob/master/docs/documentation.md
for the documentation.
"""
import re
import sys
import inspect
import operator
import itertools
from contextlib import _GeneratorContextManager
from inspect import getfullargspec, iscoroutinefunction, isgeneratorfunction

__version__ = '5.1.1'

DEF = re.compile(r'\s*def\s*([_\w][_\w\d]*)\s*\(')
POS = inspect.Parameter.POSITIONAL_OR_KEYWORD
EMPTY = inspect.Parameter.empty


# this is not used anymore in the core, but kept for backward compatibility
class FunctionMaker(object):
    """
    An object with the ability to create functions with a given signature.
    It has attributes name, doc, module, signature, defaults, dict and
    methods update and make.
    """

    # Atomic get-and-increment provided by the GIL
    _compile_count = itertools.count()

    # make pylint happy
    args = varargs = varkw = defaults = kwonlyargs = kwonlydefaults = ()

    def __init__(self, func=None, name=None, signature=None,
                 defaults=None, doc=None, module=None, funcdict=None):
        self.shortsignature = signature
        if func:
            # func can be a class or a callable, but not an instance method
            self.name = func.__name__
            if self.name == '<lambda>':  # small hack for lambda functions
                self.name = '_lambda_'
            self.doc = func.__doc__
            self.module = func.__module__
            if inspect.isroutine(func):
                argspec = getfullargspec(func)
                self.annotations = getattr(func, '__annotations__', {})
                for a in ('args', 'varargs', 'varkw', 'defaults', 'kwonlyargs',
                          'kwonlydefaults'):
                    setattr(self, a, getattr(argspec, a))
                for i, arg in enumerate(self.args):
                    setattr(self, 'arg%d' % i, arg)
                allargs = list(self.args)
                allshortargs = list(self.args)
                if self.varargs:
                    allargs.append('*' + self.varargs)
                    allshortargs.append('*' + self.varargs)
                elif self.kwonlyargs:
                    allargs.append('*')  # single star syntax
                for a in self.kwonlyargs:
                    allargs.append('%s=None' % a)
                    allshortargs.append('%s=%s' % (a, a))
                if self.varkw:
                    allargs.append('**' + self.varkw)
                    allshortargs.append('**' + self.varkw)
                self.signature = ', '.join(allargs)
                self.shortsignature = ', '.join(allshortargs)
                self.dict = func.__dict__.copy()
        # func=None happens when decorating a caller
        if name:
            self.name = name
        if signature is not None:
            self.signature = signature
        if defaults:
            self.defaults = defaults
        if doc:
            self.doc = doc
        if module:
            self.module = module
        if funcdict:
            self.dict = funcdict
        # check existence required attributes
        assert hasattr(self, 'name')
        if not hasattr(self, 'signature'):
            raise TypeError('You are decorating a non function: %s' % func)

    def update(self, func, **kw):
        """
        Update the signature of func with the data in self
        """
        func.__name__ = self.name
        func.__doc__ = getattr(self, 'doc', None)
        func.__dict__ = getattr(self, 'dict', {})
        func.__defaults__ = self.defaults
        func.__kwdefaults__ = self.kwonlydefaults or None
        func.__annotations__ = getattr(self, 'annotations', None)
        try:
            frame = sys._getframe(3)
        except AttributeError:  # for IronPython and similar implementations
            callermodule = '?'
        else:
            callermodule = frame.f_globals.get('__name__', '?')
        func.__module__ = getattr(self, 'module', callermodule)
        func.__dict__.update(kw)

    def make(self, src_templ, evaldict=None, addsource=False, **attrs):
        """
        Make a new function from a given template and update the signature
        """
        src = src_templ % vars(self)  # expand name and signature
        evaldict = evaldict or {}
        mo = DEF.search(src)
        if mo is None:
            raise SyntaxError('not a valid function template\n%s' % src)
        name = mo.group(1)  # extract the function name
        names = set([name] + [arg.strip(' *') for arg in
                              self.shortsignature.split(',')])
        for n in names:
            if n in ('_func_', '_call_'):
                raise NameError('%s is overridden in\n%s' % (n, src))

        if not src.endswith('\n'):  # add a newline for old Pythons
            src += '\n'

        # Ensure each generated function has a unique filename for profilers
        # (such as cProfile) that depend on the tuple of (<filename>,
        # <definition line>, <function name>) being unique.
        filename = '<decorator-gen-%d>' % next(self._compile_count)
        try:
            code = compile(src, filename, 'single')
            exec(code, evaldict)
        except Exception:
            print('Error in generated code:', file=sys.stderr)
            print(src, file=sys.stderr)
            raise
        func = evaldict[name]
        if addsource:
            attrs['__source__'] = src
        self.update(func, **attrs)
        return func

    @classmethod
    def create(cls, obj, body, evaldict, defaults=None,
               doc=None, module=None, addsource=True, **attrs):
        """
        Create a function from the strings name, signature and body.
        evaldict is the evaluation dictionary. If addsource is true an
        attribute __source__ is added to the result. The attributes attrs
        are added, if any.
        """
        if isinstance(obj, str):  # "name(signature)"
            name, rest = obj.strip().split('(', 1)
            signature = rest[:-1]  # strip a right parens
            func = None
        else:  # a function
            name = None
            signature = None
            func = obj
        self = cls(func, name, signature, defaults, doc, module)
        ibody = '\n'.join('    ' + line for line in body.splitlines())
        caller = evaldict.get('_call_')  # when called from `decorate`
        if caller and iscoroutinefunction(caller):
            body = ('async def %(name)s(%(signature)s):\n' + ibody).replace(
                'return', 'return await')
        else:
            body = 'def %(name)s(%(signature)s):\n' + ibody
        return self.make(body, evaldict, addsource, **attrs)


def fix(args, kwargs, sig):
    """
    Fix args and kwargs to be consistent with the signature
    """
    ba = sig.bind(*args, **kwargs)
    ba.apply_defaults()  # needed for test_dan_schult
    return ba.args, ba.kwargs


def decorate(func, caller, extras=(), kwsyntax=False):
    """
    Decorates a function/generator/coroutine using a caller.
    If kwsyntax is True calling the decorated functions with keyword
    syntax will pass the named arguments inside the ``kw`` dictionary,
    even if such arguments are positional, similarly to what functools.wraps
    does. By default kwsyntax is False and the arguments are untouched.
    """
    sig = inspect.signature(func)
    if iscoroutinefunction(caller):
        async def fun(*args, **kw):
            if not kwsyntax:
                args, kw = fix(args, kw, sig)
            return await caller(func, *(extras + args), **kw)
    elif isgeneratorfunction(caller):
        def fun(*args, **kw):
            if not kwsyntax:
                args, kw = fix(args, kw, sig)
            for res in caller(func, *(extras + args), **kw):
                yield res
    else:
        def fun(*args, **kw):
            if not kwsyntax:
                args, kw = fix(args, kw, sig)
            return caller(func, *(extras + args), **kw)
    fun.__name__ = func.__name__
    fun.__doc__ = func.__doc__
    fun.__wrapped__ = func
    fun.__signature__ = sig
    fun.__qualname__ = func.__qualname__
    # builtin functions like defaultdict.__setitem__ lack many attributes
    try:
        fun.__defaults__ = func.__defaults__
    except AttributeError:
        pass
    try:
        fun.__kwdefaults__ = func.__kwdefaults__
    except AttributeError:
        pass
    try:
        fun.__annotations__ = func.__annotations__
    except AttributeError:
        pass
    try:
        fun.__module__ = func.__module__
    except AttributeError:
        pass
    try:
        fun.__dict__.update(func.__dict__)
    except AttributeError:
        pass
    return fun


def decoratorx(caller):
    """
    A version of "decorator" implemented via "exec" and not via the
    Signature object. Use this if you want to preserve the `.__code__`
    object properties (https://github.com/micheles/decorator/issues/129).
    """
    def dec(func):
        return FunctionMaker.create(
            func,
            "return _call_(_func_, %(shortsignature)s)",
            dict(_call_=caller, _func_=func),
            __wrapped__=func, __qualname__=func.__qualname__)
    return dec


def decorator(caller, _func=None, kwsyntax=False):
    """
    decorator(caller) converts a caller function into a decorator
    """
    if _func is not None:  # return a decorated function
        # this is obsolete behavior; you should use decorate instead
        return decorate(_func, caller, (), kwsyntax)
    # else return a decorator function
    sig = inspect.signature(caller)
    dec_params = [p for p in sig.parameters.values() if p.kind is POS]

    def dec(func=None, *args, **kw):
        na = len(args) + 1
        extras = args + tuple(kw.get(p.name, p.default)
                              for p in dec_params[na:]
                              if p.default is not EMPTY)
        if func is None:
            return lambda func: decorate(func, caller, extras, kwsyntax)
        else:
            return decorate(func, caller, extras, kwsyntax)
    dec.__signature__ = sig.replace(parameters=dec_params)
    dec.__name__ = caller.__name__
    dec.__doc__ = caller.__doc__
    dec.__wrapped__ = caller
    dec.__qualname__ = caller.__qualname__
    dec.__kwdefaults__ = getattr(caller, '__kwdefaults__', None)
    dec.__dict__.update(caller.__dict__)
    return dec


# ####################### contextmanager ####################### #


class ContextManager(_GeneratorContextManager):
    def __init__(self, g, *a, **k):
        _GeneratorContextManager.__init__(self, g, a, k)

    def __call__(self, func):
        def caller(f, *a, **k):
            with self.__class__(self.func, *self.args, **self.kwds):
                return f(*a, **k)
        return decorate(func, caller)


_contextmanager = decorator(ContextManager)


def contextmanager(func):
    # Enable Pylint config: contextmanager-decorators=decorator.contextmanager
    return _contextmanager(func)


# ############################ dispatch_on ############################ #

def append(a, vancestors):
    """
    Append ``a`` to the list of the virtual ancestors, unless it is already
    included.
    """
    add = True
    for j, va in enumerate(vancestors):
        if issubclass(va, a):
            add = False
            break
        if issubclass(a, va):
            vancestors[j] = a
            add = False
    if add:
        vancestors.append(a)


# inspired from simplegeneric by P.J. Eby and functools.singledispatch
def dispatch_on(*dispatch_args):
    """
    Factory of decorators turning a function into a generic function
    dispatching on the given arguments.
    """
    assert dispatch_args, 'No dispatch args passed'
    dispatch_str = '(%s,)' % ', '.join(dispatch_args)

    def check(arguments, wrong=operator.ne, msg=''):
        """Make sure one passes the expected number of arguments"""
        if wrong(len(arguments), len(dispatch_args)):
            raise TypeError('Expected %d arguments, got %d%s' %
                            (len(dispatch_args), len(arguments), msg))

    def gen_func_dec(func):
        """Decorator turning a function into a generic function"""

        # first check the dispatch arguments
        argset = set(getfullargspec(func).args)
        if not set(dispatch_args) <= argset:
            raise NameError('Unknown dispatch arguments %s' % dispatch_str)

        typemap = {}

        def vancestors(*types):
            """
            Get a list of sets of virtual ancestors for the given types
            """
            check(types)
            ras = [[] for _ in range(len(dispatch_args))]
            for types_ in typemap:
                for t, type_, ra in zip(types, types_, ras):
                    if issubclass(t, type_) and type_ not in t.mro():
                        append(type_, ra)
            return [set(ra) for ra in ras]

        def ancestors(*types):
            """
            Get a list of virtual MROs, one for each type
            """
            check(types)
            lists = []
            for t, vas in zip(types, vancestors(*types)):
                n_vas = len(vas)
                if n_vas > 1:
                    raise RuntimeError(
                        'Ambiguous dispatch for %s: %s' % (t, vas))
                elif n_vas == 1:
                    va, = vas
                    mro = type('t', (t, va), {}).mro()[1:]
                else:
                    mro = t.mro()
                lists.append(mro[:-1])  # discard t and object
            return lists

        def register(*types):
            """
            Decorator to register an implementation for the given types
            """
            check(types)

            def dec(f):
                check(getfullargspec(f).args, operator.lt, ' in ' + f.__name__)
                typemap[types] = f
                return f
            return dec

        def dispatch_info(*types):
            """
            A utility to introspect the dispatch algorithm
            """
            check(types)
            lst = []
            for anc in itertools.product(*ancestors(*types)):
                lst.append(tuple(a.__name__ for a in anc))
            return lst

        def _dispatch(dispatch_args, *args, **kw):
            types = tuple(type(arg) for arg in dispatch_args)
            try:  # fast path
                f = typemap[types]
            except KeyError:
                pass
            else:
                return f(*args, **kw)
            combinations = itertools.product(*ancestors(*types))
            next(combinations)  # the first one has been already tried
            for types_ in combinations:
                f = typemap.get(types_)
                if f is not None:
                    return f(*args, **kw)

            # else call the default implementation
            return func(*args, **kw)

        return FunctionMaker.create(
            func, 'return _f_(%s, %%(shortsignature)s)' % dispatch_str,
            dict(_f_=_dispatch), register=register, default=func,
            typemap=typemap, vancestors=vancestors, ancestors=ancestors,
            dispatch_info=dispatch_info, __wrapped__=func)

    gen_func_dec.__name__ = 'dispatch_on' + dispatch_str
    return gen_func_dec
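The module above builds signature-preserving decorators out of "caller" functions via `decorate`/`decorator`. As a hedged usage sketch only (the `timed` caller and its `label` parameter are invented for illustration; only `decorator.decorator` comes from the module), a parameterized decorator could be written like this:

```py
# Hypothetical usage sketch for the decorator module vendored above.
# `timed` is an invented caller; only `decorator.decorator` is from the module.
import time

from decorator import decorator


@decorator
def timed(func, label="", *args, **kwargs):
    # The caller receives the decorated function first, then its arguments.
    # Default arguments of the caller (here `label`) become decorator parameters.
    start = time.perf_counter()
    try:
        return func(*args, **kwargs)
    finally:
        elapsed = time.perf_counter() - start
        print(f"{label or func.__name__} took {elapsed:.6f}s")


@timed(label="slow addition")
def add(a, b):
    time.sleep(0.1)
    return a + b


if __name__ == "__main__":
    print(add(2, 3))     # prints the timing line, then 5
    print(add.__name__)  # 'add': name, docstring and signature are preserved
```

Compared with a hand-rolled `functools.wraps` wrapper, `decorate` also copies the full `__signature__`, which is why introspection on the decorated function still reports the original parameters.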
.venv/Lib/site-packages/isympy.py
ADDED
@@ -0,0 +1,342 @@
"""
Python shell for SymPy.

This is just a normal Python shell (IPython shell if you have the
IPython package installed), that executes the following commands for
the user:

>>> from __future__ import division
>>> from sympy import *
>>> x, y, z, t = symbols('x y z t')
>>> k, m, n = symbols('k m n', integer=True)
>>> f, g, h = symbols('f g h', cls=Function)
>>> init_printing()

So starting 'isympy' is equivalent to starting Python (or IPython) and
executing the above commands by hand. It is intended for easy and quick
experimentation with SymPy. isympy is a good way to use SymPy as an
interactive calculator. If you have IPython and Matplotlib installed, then
interactive plotting is enabled by default.

COMMAND LINE OPTIONS
--------------------

-c CONSOLE, --console=CONSOLE

    Use the specified shell (Python or IPython) shell as the console
    backend instead of the default one (IPython if present, Python
    otherwise), e.g.:

        $isympy -c python

    CONSOLE must be one of 'ipython' or 'python'

-p PRETTY, --pretty PRETTY

    Setup pretty-printing in SymPy. When pretty-printing is enabled,
    expressions can be printed with Unicode or ASCII. The default is
    to use pretty-printing (with Unicode if the terminal supports it).
    When this option is 'no', expressions will not be pretty-printed
    and ASCII will be used:

        $isympy -p no

    PRETTY must be one of 'unicode', 'ascii', or 'no'

-t TYPES, --types=TYPES

    Setup the ground types for the polys. By default, gmpy ground types
    are used if gmpy2 or gmpy is installed, otherwise it falls back to python
    ground types, which are a little bit slower. You can manually
    choose python ground types even if gmpy is installed (e.g., for
    testing purposes):

        $isympy -t python

    TYPES must be one of 'gmpy', 'gmpy1' or 'python'

    Note that the ground type gmpy1 is primarily intended for testing; it
    forces the use of gmpy version 1 even if gmpy2 is available.

    This is the same as setting the environment variable
    SYMPY_GROUND_TYPES to the given ground type (e.g.,
    SYMPY_GROUND_TYPES='gmpy')

    The ground types can be determined interactively from the variable
    sympy.polys.domains.GROUND_TYPES.

-o ORDER, --order ORDER

    Setup the ordering of terms for printing. The default is lex, which
    orders terms lexicographically (e.g., x**2 + x + 1). You can choose
    other orderings, such as rev-lex, which will use reverse
    lexicographic ordering (e.g., 1 + x + x**2):

        $isympy -o rev-lex

    ORDER must be one of 'lex', 'rev-lex', 'grlex', 'rev-grlex',
    'grevlex', 'rev-grevlex', 'old', or 'none'.

    Note that for very large expressions, ORDER='none' may speed up
    printing considerably but the terms will have no canonical order.

-q, --quiet

    Print only Python's and SymPy's versions to stdout at startup.

-d, --doctest

    Use the same format that should be used for doctests. This is
    equivalent to -c python -p no.

-C, --no-cache

    Disable the caching mechanism. Disabling the cache may slow certain
    operations down considerably. This is useful for testing the cache,
    or for benchmarking, as the cache can result in deceptive timings.

    This is equivalent to setting the environment variable
    SYMPY_USE_CACHE to 'no'.

-a, --auto-symbols (requires at least IPython 0.11)

    Automatically create missing symbols. Normally, typing a name of a
    Symbol that has not been instantiated first would raise NameError,
    but with this option enabled, any undefined name will be
    automatically created as a Symbol.

    Note that this is intended only for interactive, calculator style
    usage. In a script that uses SymPy, Symbols should be instantiated
    at the top, so that it's clear what they are.

    This will not override any names that are already defined, which
    includes the single character letters represented by the mnemonic
    QCOSINE (see the "Gotchas and Pitfalls" document in the
    documentation). You can delete existing names by executing "del
    name". If a name is defined, typing "'name' in dir()" will return True.

    The Symbols that are created using this have default assumptions.
    If you want to place assumptions on symbols, you should create them
    using symbols() or var().

    Finally, this only works in the top level namespace. So, for
    example, if you define a function in isympy with an undefined
    Symbol, it will not work.

    See also the -i and -I options.

-i, --int-to-Integer (requires at least IPython 0.11)

    Automatically wrap int literals with Integer. This makes it so that
    things like 1/2 will come out as Rational(1, 2), rather than 0.5. This
    works by preprocessing the source and wrapping all int literals with
    Integer. Note that this will not change the behavior of int literals
    assigned to variables, and it also won't change the behavior of functions
    that return int literals.

    If you want an int, you can wrap the literal in int(), e.g. int(3)/int(2)
    gives 1.5 (with division imported from __future__).

-I, --interactive (requires at least IPython 0.11)

    This is equivalent to --auto-symbols --int-to-Integer. Future options
    designed for ease of interactive use may be added to this.

-D, --debug

    Enable debugging output. This is the same as setting the
    environment variable SYMPY_DEBUG to 'True'. The debug status is set
    in the variable SYMPY_DEBUG within isympy.

-- IPython options

    Additionally you can pass command line options directly to the IPython
    interpreter (the standard Python shell is not supported). However you
    need to add the '--' separator between two types of options, e.g the
    startup banner option and the colors option. You need to enter the
    options as required by the version of IPython that you are using, too:

    in IPython 0.11,

        $isympy -q -- --colors=NoColor

    or older versions of IPython,

        $isympy -q -- -colors NoColor

See also isympy --help.
"""

import os
import sys

# DO NOT IMPORT SYMPY HERE! Or the setting of the sympy environment variables
# by the command line will break.

def main() -> None:
    from argparse import ArgumentParser, RawDescriptionHelpFormatter

    VERSION = None
    if '--version' in sys.argv:
        # We cannot import sympy before this is run, because flags like -C and
        # -t set environment variables that must be set before SymPy is
        # imported. The only thing we need to import it for is to get the
        # version, which only matters with the --version flag.
        import sympy
        VERSION = sympy.__version__

    usage = 'isympy [options] -- [ipython options]'
    parser = ArgumentParser(
        usage=usage,
        description=__doc__,
        formatter_class=RawDescriptionHelpFormatter,
    )

    parser.add_argument('--version', action='version', version=VERSION)

    parser.add_argument(
        '-c', '--console',
        dest='console',
        action='store',
        default=None,
        choices=['ipython', 'python'],
        metavar='CONSOLE',
        help='select type of interactive session: ipython | python; defaults '
             'to ipython if IPython is installed, otherwise python')

    parser.add_argument(
        '-p', '--pretty',
        dest='pretty',
        action='store',
        default=None,
        metavar='PRETTY',
        choices=['unicode', 'ascii', 'no'],
        help='setup pretty printing: unicode | ascii | no; defaults to '
             'unicode printing if the terminal supports it, otherwise ascii')

    parser.add_argument(
        '-t', '--types',
        dest='types',
        action='store',
        default=None,
        metavar='TYPES',
        choices=['gmpy', 'gmpy1', 'python'],
        help='setup ground types: gmpy | gmpy1 | python; defaults to gmpy if gmpy2 '
             'or gmpy is installed, otherwise python')

    parser.add_argument(
        '-o', '--order',
        dest='order',
        action='store',
        default=None,
        metavar='ORDER',
        choices=['lex', 'grlex', 'grevlex', 'rev-lex', 'rev-grlex', 'rev-grevlex', 'old', 'none'],
        help='setup ordering of terms: [rev-]lex | [rev-]grlex | [rev-]grevlex | old | none; defaults to lex')

    parser.add_argument(
        '-q', '--quiet',
        dest='quiet',
        action='store_true',
        default=False,
        help='print only version information at startup')

    parser.add_argument(
        '-d', '--doctest',
        dest='doctest',
        action='store_true',
        default=False,
        help='use the doctest format for output (you can just copy and paste it)')

    parser.add_argument(
        '-C', '--no-cache',
        dest='cache',
        action='store_false',
        default=True,
        help='disable caching mechanism')

    parser.add_argument(
        '-a', '--auto-symbols',
        dest='auto_symbols',
        action='store_true',
        default=False,
        help='automatically construct missing symbols')

    parser.add_argument(
        '-i', '--int-to-Integer',
        dest='auto_int_to_Integer',
        action='store_true',
        default=False,
        help="automatically wrap int literals with Integer")

    parser.add_argument(
        '-I', '--interactive',
        dest='interactive',
        action='store_true',
        default=False,
        help="equivalent to -a -i")

    parser.add_argument(
        '-D', '--debug',
        dest='debug',
        action='store_true',
        default=False,
        help='enable debugging output')

    (options, ipy_args) = parser.parse_known_args()
    if '--' in ipy_args:
        ipy_args.remove('--')

    if not options.cache:
        os.environ['SYMPY_USE_CACHE'] = 'no'

    if options.types:
        os.environ['SYMPY_GROUND_TYPES'] = options.types

    if options.debug:
        os.environ['SYMPY_DEBUG'] = str(options.debug)

    if options.doctest:
        options.pretty = 'no'
        options.console = 'python'

    session = options.console

    if session is not None:
        ipython = session == 'ipython'
    else:
        try:
            import IPython
            ipython = True
        except ImportError:
            if not options.quiet:
                from sympy.interactive.session import no_ipython
                print(no_ipython)
            ipython = False

    args = {
        'pretty_print': True,
        'use_unicode': None,
        'use_latex': None,
        'order': None,
        'argv': ipy_args,
    }

    if options.pretty == 'unicode':
        args['use_unicode'] = True
    elif options.pretty == 'ascii':
        args['use_unicode'] = False
    elif options.pretty == 'no':
        args['pretty_print'] = False

    if options.order is not None:
        args['order'] = options.order

    args['quiet'] = options.quiet
    args['auto_symbols'] = options.auto_symbols or options.interactive
    args['auto_int_to_Integer'] = options.auto_int_to_Integer or options.interactive

    from sympy.interactive import init_session
    init_session(ipython, **args)

if __name__ == "__main__":
    main()
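For readers skimming the diff: per the module docstring above, an isympy session behaves as if the following had been typed into a plain Python shell by hand (the comment on the last line restates the -i / --int-to-Integer behaviour described in the docstring):

    from sympy import *
    x, y, z, t = symbols('x y z t')
    k, m, n = symbols('k m n', integer=True)
    f, g, h = symbols('f g h', cls=Function)
    init_printing()
    # with -i / --int-to-Integer, typing 1/2 in the shell yields Rational(1, 2) instead of 0.5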
.venv/Lib/site-packages/mojimoji.cp39-win_amd64.pyd
ADDED
Binary file (93.7 kB).
.venv/Lib/site-packages/numpy-1.26.3-cp39-cp39-win_amd64.whl
ADDED
File without changes
.venv/Lib/site-packages/plac.py
ADDED
@@ -0,0 +1,37 @@
# ######################### LICENSE ###############################
#
# Copyright (c) 2010-2021, Michele Simionato
# All rights reserved.
#
# Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# Redistributions in bytecode form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in
# the documentation and/or other materials provided with the
# distribution.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
# OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
# DAMAGE.
"""
See docs/index.html for the documentation.
"""
from plac_core import *
from plac_ext import (import_main, ReadlineInput, Interpreter,
                      stdout, runp, Monitor, default_help)

__version__ = '1.4.3'

try:
    from plac_tk import TkMonitor
except ImportError:
    pass
.venv/Lib/site-packages/plac_core.py
ADDED
@@ -0,0 +1,439 @@
# this module should be kept Python 2.3 compatible
import re
import sys
import time
import inspect
import textwrap
import functools
import argparse
from datetime import datetime, date
from gettext import gettext as _

version = sys.version_info[:2]

if sys.version >= '3':
    from inspect import getfullargspec
else:
    class getfullargspec(object):
        "A quick and dirty replacement for getfullargspec for Python 2.X"
        def __init__(self, f):
            self.args, self.varargs, self.varkw, self.defaults = \
                inspect.getargspec(f)
            self.annotations = getattr(f, '__annotations__', {})


def to_date(s):
    """Returns year-month-day"""
    return date(*time.strptime(s, "%Y-%m-%d")[0:3])


def to_datetime(s):
    """Returns year-month-day hour-minute-second"""
    return datetime(*time.strptime(s, "%Y-%m-%d %H-%M-%S")[0:6])


def getargspec(callableobj):
    """Given a callable return an object with attributes .args, .varargs,
    .varkw, .defaults. It tries to do the "right thing" with functions,
    methods, classes and generic callables."""
    if inspect.isfunction(callableobj):
        argspec = getfullargspec(callableobj)
    elif inspect.ismethod(callableobj):
        argspec = getfullargspec(callableobj)
        del argspec.args[0]  # remove first argument
    elif inspect.isclass(callableobj):
        if callableobj.__init__ is object.__init__:  # to avoid an error
            argspec = getfullargspec(lambda self: None)
        else:
            argspec = getfullargspec(callableobj.__init__)
        del argspec.args[0]  # remove first argument
    elif hasattr(callableobj, '__call__'):
        argspec = getfullargspec(callableobj.__call__)
        del argspec.args[0]  # remove first argument
    else:
        raise TypeError(_('Could not determine the signature of ') +
                        str(callableobj))
    return argspec


def annotations(**ann):
    """
    Returns a decorator annotating a function with the given annotations.
    This is a trick to support function annotations in Python 2.X.
    """
    def annotate(f):
        fas = getfullargspec(f)
        args = fas.args
        if fas.varargs:
            args.append(fas.varargs)
        if fas.varkw:
            args.append(fas.varkw)
        for argname in ann:
            if argname not in args:
                raise NameError(
                    _('Annotating non-existing argument: %s') % argname)
        f.__annotations__ = ann
        return f
    return annotate


def _annotate(arg, ann, f):
    try:
        f.__annotations__[arg] = ann
    except AttributeError:  # Python 2.7
        f.__annotations__ = {arg: ann}
    return f


def pos(arg, help=None, type=None, choices=None, metavar=None):
    """
    Decorator for annotating positional arguments
    """
    return functools.partial(
        _annotate, arg, (help, 'positional', None, type, choices, metavar))


def opt(arg, help=None, type=None, abbrev=None, choices=None, metavar=None):
    """
    Decorator for annotating optional arguments
    """
    abbrev = abbrev or arg[0]
    return functools.partial(
        _annotate, arg, (help, 'option', abbrev, type, choices, metavar))


def flg(arg, help=None, abbrev=None):
    """
    Decorator for annotating flags
    """
    return functools.partial(
        _annotate, arg, (help, 'flag', abbrev or arg[0], None, None, None))


def is_annotation(obj):
    """
    An object is an annotation object if it has the attributes
    help, kind, abbrev, type, choices, metavar.
    """
    return (hasattr(obj, 'help') and hasattr(obj, 'kind')
            and hasattr(obj, 'abbrev') and hasattr(obj, 'type')
            and hasattr(obj, 'choices') and hasattr(obj, 'metavar'))


class Annotation(object):
    def __init__(self, help=None, kind="positional", abbrev=None, type=None,
                 choices=None, metavar=None):
        assert kind in ('positional', 'option', 'flag'), kind
        if kind == "positional":
            assert abbrev is None, abbrev
        self.help = help
        self.kind = kind
        self.abbrev = abbrev
        self.type = type
        self.choices = choices
        self.metavar = metavar

    def from_(cls, obj):
        "Helper to convert an object into an annotation, if needed"
        if is_annotation(obj):
            return obj  # do nothing
        elif inspect.isclass(obj):
            obj = str(obj)
        elif iterable(obj):
            return cls(*obj)
        return cls(obj)
    from_ = classmethod(from_)


NONE = object()  # sentinel use to signal the absence of a default

PARSER_CFG = getfullargspec(argparse.ArgumentParser.__init__).args[1:]
# the default arguments accepted by an ArgumentParser object


def pconf(obj):
    """
    Extracts the configuration of the underlying ArgumentParser from obj
    """
    cfg = dict(description=(textwrap.dedent(obj.__doc__.rstrip())
                            if obj.__doc__ else None),
               formatter_class=argparse.RawDescriptionHelpFormatter)
    for name in dir(obj):
        if name in PARSER_CFG:  # argument of ArgumentParser
            cfg[name] = getattr(obj, name)
    return cfg


_parser_registry = {}


def parser_from(obj, **confparams):
    """
    obj can be a callable or an object with a .commands attribute.
    Returns an ArgumentParser.
    """
    try:  # the underlying parser has been generated already
        return _parser_registry[obj]
    except KeyError:  # generate a new parser
        pass
    conf = pconf(obj).copy()
    conf.update(confparams)
    _parser_registry[obj] = parser = ArgumentParser(**conf)
    parser.obj = obj
    parser.case_sensitive = confparams.get(
        'case_sensitive', getattr(obj, 'case_sensitive', True))
    if hasattr(obj, 'commands') and not inspect.isclass(obj):
        # a command container instance
        parser.addsubcommands(obj.commands, obj, 'subcommands')
    else:
        parser.populate_from(obj)
    return parser


def _extract_kwargs(args):
    """
    Returns two lists: regular args and name=value args
    """
    arglist = []
    kwargs = {}
    for arg in args:
        match = re.match(r'([a-zA-Z_]\w*)=', arg)
        if match:
            name = match.group(1)
            kwargs[name] = arg[len(name)+1:]
        else:
            arglist.append(arg)
    return arglist, kwargs


def _match_cmd(abbrev, commands, case_sensitive=True):
    """
    Extract the command name from an abbreviation or raise a NameError
    """
    if not case_sensitive:
        abbrev = abbrev.upper()
        commands = [c.upper() for c in commands]
    perfect_matches = [name for name in commands if name == abbrev]
    if len(perfect_matches) == 1:
        return perfect_matches[0]
    matches = [name for name in commands if name.startswith(abbrev)]
    n = len(matches)
    if n == 1:
        return matches[0]
    elif n > 1:
        raise NameError(
            _('Ambiguous command %r: matching %s' % (abbrev, matches)))


class ArgumentParser(argparse.ArgumentParser):
    """
    An ArgumentParser with .func and .argspec attributes, and possibly
    .commands and .subparsers.
    """
    case_sensitive = True

    if version < (3, 10):
        def __init__(self, *args, **kwargs):
            super(ArgumentParser, self).__init__(*args, **kwargs)
            if self._action_groups[1].title == _('optional arguments'):
                self._action_groups[1].title = _('options')

    def alias(self, arg):
        "Can be overridden to preprocess command-line arguments"
        return arg

    def consume(self, args):
        """
        Call the underlying function with the args. Works also for
        command containers, by dispatching to the right subparser.
        """
        arglist = [self.alias(a) for a in args]
        cmd = None
        if hasattr(self, 'subparsers'):
            subp, cmd = self._extract_subparser_cmd(arglist)
            if subp is None and cmd is not None:
                return cmd, self.missing(cmd)
            elif subp is not None:  # use the subparser
                self = subp
        if hasattr(self, 'argspec') and self.argspec.varargs:
            # ignore unrecognized arguments
            ns, extraopts = self.parse_known_args(arglist)
        else:
            ns, extraopts = self.parse_args(arglist), []  # may raise an exit
        if not hasattr(self, 'argspec'):
            raise SystemExit
        if hasattr(self, 'argspec') and self.argspec.varkw:
            v = self.argspec.varargs
            varkw = self.argspec.varkw
            if v in ns.__dict__:
                lst = ns.__dict__.pop(v)
                lst, kwargs = _extract_kwargs(lst)
                ns.__dict__[v] = lst
            elif varkw in ns.__dict__:
                lst = ns.__dict__.pop(varkw)
                lst, kwargs = _extract_kwargs(lst)
                ns.__dict__[varkw] = lst
                if lst and not v:
                    self.error(_('Unrecognized arguments: %s') % arglist)
        else:
            kwargs = {}
        collision = set(self.argspec.args) & set(kwargs)
        if collision:
            self.error(
                _('colliding keyword arguments: %s') % ' '.join(collision))
        # Correct options with trailing undescores
        args = [getattr(ns, a.rstrip('_')) for a in self.argspec.args]
        varargs = getattr(ns, self.argspec.varargs or '', [])
        return cmd, self.func(*(args + varargs + extraopts), **kwargs)

    def _extract_subparser_cmd(self, arglist):
        """
        Extract the right subparser from the first recognized argument
        """
        optprefix = self.prefix_chars[0]
        name_parser_map = self.subparsers._name_parser_map
        for i, arg in enumerate(arglist):
            if not arg.startswith(optprefix):
                cmd = _match_cmd(arg, name_parser_map, self.case_sensitive)
                del arglist[i]
                return name_parser_map.get(cmd), cmd or arg
        return None, None

    def addsubcommands(self, commands, obj, title=None, cmdprefix=''):
        """
        Extract a list of subcommands from obj and add them to the parser
        """
        if hasattr(obj, cmdprefix) and obj.cmdprefix in self.prefix_chars:
            raise ValueError(_('The prefix %r is already taken!' % cmdprefix))
        if not hasattr(self, 'subparsers'):
            self.subparsers = self.add_subparsers(title=title)
        elif title:
            self.add_argument_group(title=title)  # populate ._action_groups
        prefixlen = len(getattr(obj, 'cmdprefix', ''))
        add_help = getattr(obj, 'add_help', True)
        for cmd in commands:
            func = getattr(obj, cmd[prefixlen:])  # strip the prefix
            doc = (textwrap.dedent(func.__doc__.rstrip())
                   if func.__doc__ else None)
            self.subparsers.add_parser(
                cmd, add_help=add_help, help=doc, **pconf(func)
            ).populate_from(func)

    def _set_func_argspec(self, obj):
        """
        Extracts the signature from a callable object and adds an .argspec
        attribute to the parser. Also adds a .func reference to the object.
        """
        self.func = obj
        self.argspec = getargspec(obj)
        _parser_registry[obj] = self

    def populate_from(self, func):
        """
        Extract the arguments from the attributes of the passed function
        and return a populated ArgumentParser instance.
        """
        self._set_func_argspec(func)
        f = self.argspec
        defaults = f.defaults or ()
        n_args = len(f.args)
        n_defaults = len(defaults)
        alldefaults = (NONE,) * (n_args - n_defaults) + defaults
        prefix = self.prefix = getattr(func, 'prefix_chars', '-')[0]
        for name, default in zip(f.args, alldefaults):
            ann = f.annotations.get(name, ())
            a = Annotation.from_(ann)
            metavar = a.metavar
            if default is NONE:
                dflt = None
            else:
                dflt = default
                if a.help is None:
                    a.help = '[%s]' % str(dflt)  # dflt can be a tuple
                if a.type is None:
                    # try to infer the type from the default argument
                    if isinstance(default, datetime):
                        a.type = to_datetime
                    elif isinstance(default, date):
                        a.type = to_date
                    elif default is not None:
                        a.type = type(default)
                if not metavar and default == '':
                    metavar = "''"
            if a.kind in ('option', 'flag'):

                if name.endswith("_"):
                    # allows reserved words to be specified with underscores
                    suffix = name.rstrip('_')
                else:
                    # convert undescores to dashes.
                    suffix = name.replace('_', '-')

                if a.abbrev:
                    shortlong = (prefix + a.abbrev,
                                 prefix*2 + suffix)
                else:
                    shortlong = (prefix + suffix,)
            elif default is NONE:  # required argument
                self.add_argument(name, help=a.help, type=a.type,
                                  choices=a.choices, metavar=metavar)
            else:  # default argument
                self.add_argument(
                    name, nargs='?', help=a.help, default=dflt,
                    type=a.type, choices=a.choices, metavar=metavar)
            if a.kind == 'option':
                if default is not NONE:
                    metavar = metavar or str(default)
                self.add_argument(
                    help=a.help, default=dflt, type=a.type,
                    choices=a.choices, metavar=metavar, *shortlong)
            elif a.kind == 'flag':
                if default is not NONE and default is not False:
                    raise TypeError(_('Flag %r wants default False, got %r') %
                                    (name, default))
                self.add_argument(action='store_true', help=a.help, *shortlong)
        if f.varargs:
            a = Annotation.from_(f.annotations.get(f.varargs, ()))
            self.add_argument(f.varargs, nargs='*', help=a.help, default=[],
                              type=a.type, metavar=a.metavar)
        if f.varkw:
            a = Annotation.from_(f.annotations.get(f.varkw, ()))
            self.add_argument(f.varkw, nargs='*', help=a.help, default={},
                              type=a.type, metavar=a.metavar)

    def missing(self, name):
        "May raise a SystemExit"
        miss = getattr(self.obj, '__missing__', lambda name:
                       self.error('No command %r' % name))
        return miss(name)

    def print_actions(self):
        "Useful for debugging"
        print(self)
        for a in self._actions:
            print(a)


def iterable(obj):
    "Any object with an __iter__ method which is not a string or class"
    return hasattr(obj, '__iter__') and not inspect.isclass(obj) and not isinstance(obj, (str, bytes))


def call(obj, arglist=None, eager=True, version=None):
    """
    If obj is a function or a bound method, parse the given arglist
    by using the parser inferred from the annotations of obj
    and call obj with the parsed arguments.
    If obj is an object with attribute .commands, dispatch to the
    associated subparser.
    """
    if arglist is None:
        arglist = sys.argv[1:]
    parser = parser_from(obj)
    if version:
        parser.add_argument(
            '--version', '-v', action='version', version=version)
    cmd, result = parser.consume(arglist)
    if iterable(result) and eager:  # listify the result
        return list(result)
    return result
.venv/Lib/site-packages/plac_ext.py
ADDED
@@ -0,0 +1,1205 @@
1 |
+
# this module requires Python 2.6+
|
2 |
+
from __future__ import with_statement
|
3 |
+
from contextlib import contextmanager
|
4 |
+
from operator import attrgetter
|
5 |
+
from gettext import gettext as _
|
6 |
+
import inspect
|
7 |
+
import os
|
8 |
+
import sys
|
9 |
+
import cmd
|
10 |
+
import shlex
|
11 |
+
import subprocess
|
12 |
+
import argparse
|
13 |
+
import itertools
|
14 |
+
import traceback
|
15 |
+
import multiprocessing
|
16 |
+
import signal
|
17 |
+
import threading
|
18 |
+
import plac_core
|
19 |
+
|
20 |
+
version = sys.version_info[:2]
|
21 |
+
|
22 |
+
if version < (3, 5):
|
23 |
+
from imp import load_source
|
24 |
+
else:
|
25 |
+
import importlib.util
|
26 |
+
|
27 |
+
def load_source(dotname, path):
|
28 |
+
spec = importlib.util.spec_from_file_location(dotname, path)
|
29 |
+
mod = importlib.util.module_from_spec(spec)
|
30 |
+
spec.loader.exec_module(mod)
|
31 |
+
return mod
|
32 |
+
|
33 |
+
|
34 |
+
if sys.version < '3':
|
35 |
+
def exec_(_code_, _globs_=None, _locs_=None):
|
36 |
+
if _globs_ is None:
|
37 |
+
frame = sys._getframe(1)
|
38 |
+
_globs_ = frame.f_globals
|
39 |
+
if _locs_ is None:
|
40 |
+
_locs_ = frame.f_locals
|
41 |
+
del frame
|
42 |
+
elif _locs_ is None:
|
43 |
+
_locs_ = _globs_
|
44 |
+
exec("""exec _code_ in _globs_, _locs_""")
|
45 |
+
|
46 |
+
exec('''
|
47 |
+
def raise_(tp, value=None, tb=None):
|
48 |
+
raise tp, value, tb
|
49 |
+
''')
|
50 |
+
else:
|
51 |
+
exec_ = eval('exec')
|
52 |
+
|
53 |
+
def raise_(tp, value=None, tb=None):
|
54 |
+
"""
|
55 |
+
A function that matches the Python 2.x ``raise`` statement. This
|
56 |
+
allows re-raising exceptions with the cls value and traceback on
|
57 |
+
Python 2 and 3.
|
58 |
+
"""
|
59 |
+
if value is not None and isinstance(tp, Exception):
|
60 |
+
raise TypeError("instance exception may not have a separate value")
|
61 |
+
if value is not None:
|
62 |
+
exc = tp(value)
|
63 |
+
else:
|
64 |
+
exc = tp
|
65 |
+
if exc.__traceback__ is not tb:
|
66 |
+
raise exc.with_traceback(tb)
|
67 |
+
raise exc
|
68 |
+
|
69 |
+
try:
|
70 |
+
raw_input
|
71 |
+
except NameError: # Python 3
|
72 |
+
raw_input = input
|
73 |
+
|
74 |
+
|
75 |
+
def decode(val):
|
76 |
+
"""
|
77 |
+
Decode an object assuming the encoding is UTF-8.
|
78 |
+
"""
|
79 |
+
try:
|
80 |
+
# assume it is an encoded bytes object
|
81 |
+
return val.decode('utf-8')
|
82 |
+
except AttributeError:
|
83 |
+
# it was an already decoded unicode object
|
84 |
+
return str(val)
|
85 |
+
|
86 |
+
# ############################ generic utils ############################### #
|
87 |
+
|
88 |
+
|
89 |
+
@contextmanager
|
90 |
+
def stdout(fileobj):
|
91 |
+
"usage: with stdout(file('out.txt', 'a')): do_something()"
|
92 |
+
orig_stdout = sys.stdout
|
93 |
+
sys.stdout = fileobj
|
94 |
+
try:
|
95 |
+
yield
|
96 |
+
finally:
|
97 |
+
sys.stdout = orig_stdout
|
98 |
+
|
99 |
+
|
100 |
+
def write(x):
|
101 |
+
"Write str(x) on stdout and flush, no newline added"
|
102 |
+
sys.stdout.write(str(x))
|
103 |
+
sys.stdout.flush()
|
104 |
+
|
105 |
+
|
106 |
+
def gen_val(value):
|
107 |
+
"Return a generator object with a single element"
|
108 |
+
yield value
|
109 |
+
|
110 |
+
|
111 |
+
def gen_exc(etype, exc, tb):
|
112 |
+
"Return a generator object raising an exception"
|
113 |
+
raise_(etype, exc, tb)
|
114 |
+
yield
|
115 |
+
|
116 |
+
|
117 |
+
def less(text):
|
118 |
+
"Send a text to less via a pipe"
|
119 |
+
# -c clear the screen before starting less
|
120 |
+
po = subprocess.Popen(['less', '-c'], stdin=subprocess.PIPE)
|
121 |
+
try:
|
122 |
+
po.stdin.write(text)
|
123 |
+
except IOError:
|
124 |
+
pass
|
125 |
+
po.stdin.close()
|
126 |
+
po.wait()
|
127 |
+
|
128 |
+
|
129 |
+
use_less = (sys.platform != 'win32') # unices
|
130 |
+
|
131 |
+
|
132 |
+
class TerminatedProcess(Exception):
|
133 |
+
pass
|
134 |
+
|
135 |
+
|
136 |
+
def terminatedProcess(signum, frame):
|
137 |
+
raise TerminatedProcess
|
138 |
+
|
139 |
+
|
140 |
+
# ########################## readline support ############################ #
|
141 |
+
|
142 |
+
def read_line(stdin, prompt=''):
|
143 |
+
"Read a line from stdin, using readline when possible"
|
144 |
+
if isinstance(stdin, ReadlineInput):
|
145 |
+
return stdin.readline(prompt)
|
146 |
+
else:
|
147 |
+
write(prompt)
|
148 |
+
return stdin.readline()
|
149 |
+
|
150 |
+
|
151 |
+
def read_long_line(stdin, terminator):
|
152 |
+
"""
|
153 |
+
Read multiple lines from stdin until the terminator character is found,
|
154 |
+
then yield a single space-separated long line.
|
155 |
+
"""
|
156 |
+
while True:
|
157 |
+
lines = []
|
158 |
+
while True:
|
159 |
+
line = stdin.readline() # ends with \n
|
160 |
+
if not line: # EOF
|
161 |
+
return
|
162 |
+
line = line.strip()
|
163 |
+
if not line:
|
164 |
+
continue
|
165 |
+
elif line[-1] == terminator:
|
166 |
+
lines.append(line[:-1])
|
167 |
+
break
|
168 |
+
else:
|
169 |
+
lines.append(line)
|
170 |
+
yield ' '.join(lines)
|
171 |
+
|
172 |
+
|
173 |
+
class ReadlineInput(object):
|
174 |
+
"""
|
175 |
+
An iterable with a .readline method reading from stdin.
|
176 |
+
"""
|
177 |
+
def __init__(self, completions, case_sensitive=True, histfile=None):
|
178 |
+
self.completions = completions
|
179 |
+
self.case_sensitive = case_sensitive
|
180 |
+
self.histfile = histfile
|
181 |
+
if not case_sensitive:
|
182 |
+
self.completions = [c.upper() for c in completions]
|
183 |
+
import readline
|
184 |
+
self.rl = readline
|
185 |
+
readline.parse_and_bind("tab: complete")
|
186 |
+
readline.set_completer(self.complete)
|
187 |
+
|
188 |
+
def __enter__(self):
|
189 |
+
self.old_completer = self.rl.get_completer()
|
190 |
+
try:
|
191 |
+
if self.histfile:
|
192 |
+
self.rl.read_history_file(self.histfile)
|
193 |
+
except IOError: # the first time
|
194 |
+
pass
|
195 |
+
return self
|
196 |
+
|
197 |
+
def __exit__(self, etype, exc, tb):
|
198 |
+
self.rl.set_completer(self.old_completer)
|
199 |
+
if self.histfile:
|
200 |
+
self.rl.write_history_file(self.histfile)
|
201 |
+
|
202 |
+
def complete(self, kw, state):
|
203 |
+
# state is 0, 1, 2, ... and increases by hitting TAB
|
204 |
+
if not self.case_sensitive:
|
205 |
+
kw = kw.upper()
|
206 |
+
try:
|
207 |
+
return [k for k in self.completions if k.startswith(kw)][state]
|
208 |
+
except IndexError: # no completions
|
209 |
+
return # exit
|
210 |
+
|
211 |
+
def readline(self, prompt=''):
|
212 |
+
try:
|
213 |
+
return raw_input(prompt) + '\n'
|
214 |
+
except EOFError:
|
215 |
+
return ''
|
216 |
+
|
217 |
+
def __iter__(self):
|
218 |
+
return iter(self.readline, '')
|
219 |
+
|
220 |
+
# ################# help functionality in plac interpreters ################# #
|
221 |
+
|
222 |
+
|
223 |
+
class HelpSummary(object):
|
224 |
+
"Build the help summary consistently with the cmd module"
|
225 |
+
|
226 |
+
@classmethod
|
227 |
+
def add(cls, obj, specialcommands):
|
228 |
+
p = plac_core.parser_from(obj)
|
229 |
+
c = cmd.Cmd(stdout=cls())
|
230 |
+
c.stdout.write('\n')
|
231 |
+
c.print_topics('special commands',
|
232 |
+
sorted(specialcommands), 15, 80)
|
233 |
+
c.print_topics('custom commands',
|
234 |
+
sorted(obj.commands), 15, 80)
|
235 |
+
c.print_topics('commands run in external processes',
|
236 |
+
sorted(obj.mpcommands), 15, 80)
|
237 |
+
c.print_topics('threaded commands',
|
238 |
+
sorted(obj.thcommands), 15, 80)
|
239 |
+
p.helpsummary = str(c.stdout)
|
240 |
+
|
241 |
+
def __init__(self):
|
242 |
+
self._ls = []
|
243 |
+
|
244 |
+
def write(self, s):
|
245 |
+
self._ls.append(s)
|
246 |
+
|
247 |
+
def __str__(self):
|
248 |
+
return ''.join(self._ls)
|
249 |
+
|
250 |
+
|
251 |
+
class PlacFormatter(argparse.RawDescriptionHelpFormatter):
|
252 |
+
def _metavar_formatter(self, action, default_metavar):
|
253 |
+
'Remove special commands from the usage message'
|
254 |
+
choices = action.choices or {}
|
255 |
+
action.choices = dict((n, c) for n, c in choices.items()
|
256 |
+
if not n.startswith('.'))
|
257 |
+
return super(PlacFormatter, self)._metavar_formatter(
|
258 |
+
action, default_metavar)
|
259 |
+
|
260 |
+
|
261 |
+
def format_help(self):
|
262 |
+
"Attached to plac_core.ArgumentParser for plac interpreters"
|
263 |
+
try:
|
264 |
+
return self.helpsummary
|
265 |
+
except AttributeError:
|
266 |
+
return super(plac_core.ArgumentParser, self).format_help()
|
267 |
+
plac_core.ArgumentParser.format_help = format_help
|
268 |
+
|
269 |
+
|
270 |
+
def default_help(obj, cmd=None):
|
271 |
+
"The default help functionality in plac interpreters"
|
272 |
+
parser = plac_core.parser_from(obj)
|
273 |
+
if cmd is None:
|
274 |
+
yield parser.format_help()
|
275 |
+
return
|
276 |
+
subp = parser.subparsers._name_parser_map.get(cmd)
|
277 |
+
if subp is None:
|
278 |
+
yield _('Unknown command %s' % cmd)
|
279 |
+
elif getattr(obj, '_interact_', False): # in interactive mode
|
280 |
+
formatter = subp._get_formatter()
|
281 |
+
formatter._prog = cmd # remove the program name from the usage
|
282 |
+
formatter.add_usage(
|
283 |
+
subp.usage, [a for a in subp._actions if a.dest != 'help'],
|
284 |
+
subp._mutually_exclusive_groups)
|
285 |
+
formatter.add_text(subp.description)
|
286 |
+
for action_group in subp._action_groups:
|
287 |
+
formatter.start_section(action_group.title)
|
288 |
+
formatter.add_text(action_group.description)
|
289 |
+
formatter.add_arguments(a for a in action_group._group_actions
|
290 |
+
if a.dest != 'help')
|
291 |
+
formatter.end_section()
|
292 |
+
yield formatter.format_help()
|
293 |
+
else: # regular argparse help
|
294 |
+
yield subp.format_help()
|
295 |
+
|
296 |
+
# ######################## import management ############################## #
|
297 |
+
|
298 |
+
try:
|
299 |
+
PLACDIRS = os.environ.get('PLACPATH', '.').split(':')
|
300 |
+
except:
|
301 |
+
raise ValueError(_('Ill-formed PLACPATH: got %PLACPATHs') % os.environ)
|
302 |
+
|
303 |
+
|
304 |
+
def partial_call(factory, arglist):
|
305 |
+
"Call a container factory with the arglist and return a plac object"
|
306 |
+
a = plac_core.parser_from(factory).argspec
|
307 |
+
if a.defaults or a.varargs or a.varkw:
|
308 |
+
raise TypeError('Interpreter.call must be invoked on '
|
309 |
+
'factories with required arguments only')
|
310 |
+
required_args = ', '.join(a.args)
|
311 |
+
if required_args:
|
312 |
+
required_args += ',' # trailing comma
|
313 |
+
code = '''def makeobj(interact, %s *args):
|
314 |
+
obj = factory(%s)
|
315 |
+
obj._interact_ = interact
|
316 |
+
obj._args_ = args
|
317 |
+
return obj\n''' % (required_args, required_args)
|
318 |
+
dic = dict(factory=factory)
|
319 |
+
exec_(code, dic)
|
320 |
+
makeobj = dic['makeobj']
|
321 |
+
makeobj.add_help = False
|
322 |
+
if inspect.isclass(factory):
|
323 |
+
makeobj.__annotations__ = getattr(
|
324 |
+
factory.__init__, '__annotations__', {})
|
325 |
+
else:
|
326 |
+
makeobj.__annotations__ = getattr(
|
327 |
+
factory, '__annotations__', {})
|
328 |
+
makeobj.__annotations__['interact'] = (
|
329 |
+
'start interactive interpreter', 'flag', 'i')
|
330 |
+
return plac_core.call(makeobj, arglist)
|
331 |
+
|
332 |
+
|
333 |
+
def import_main(path, *args):
|
334 |
+
"""
|
335 |
+
A utility to import the main function of a plac tool. It also
|
336 |
+
works with command container factories.
|
337 |
+
"""
|
338 |
+
if ':' in path: # importing a factory
|
339 |
+
path, factory_name = path.split(':')
|
340 |
+
else: # importing the main function
|
341 |
+
factory_name = None
|
342 |
+
if not os.path.isabs(path): # relative path, look at PLACDIRS
|
343 |
+
for placdir in PLACDIRS:
|
344 |
+
fullpath = os.path.join(placdir, path)
|
345 |
+
if os.path.exists(fullpath):
|
346 |
+
break
|
347 |
+
else: # no break
|
348 |
+
raise ImportError(_('Cannot find %s' % path))
|
349 |
+
else:
|
350 |
+
fullpath = path
|
351 |
+
name, ext = os.path.splitext(os.path.basename(fullpath))
|
352 |
+
module = load_source(name, fullpath)
|
353 |
+
if factory_name:
|
354 |
+
tool = partial_call(getattr(module, factory_name), args)
|
355 |
+
else:
|
356 |
+
tool = module.main
|
357 |
+
return tool
|
358 |
+
|
359 |
+
# ############################ Task classes ############################# #
|
360 |
+
|
361 |
+
|
362 |
+
# base class not instantiated directly
|
363 |
+
class BaseTask(object):
|
364 |
+
"""
|
365 |
+
A task is a wrapper over a generator object with signature
|
366 |
+
Task(no, arglist, genobj), attributes
|
367 |
+
.no
|
368 |
+
.arglist
|
369 |
+
.outlist
|
370 |
+
.str
|
371 |
+
.etype
|
372 |
+
.exc
|
373 |
+
.tb
|
374 |
+
.status
|
375 |
+
and methods .run and .kill.
|
376 |
+
"""
|
377 |
+
STATES = ('SUBMITTED', 'RUNNING', 'TOBEKILLED', 'KILLED', 'FINISHED',
|
378 |
+
'ABORTED')
|
379 |
+
|
380 |
+
def __init__(self, no, arglist, genobj):
|
381 |
+
self.no = no
|
382 |
+
self.arglist = arglist
|
383 |
+
self._genobj = self._wrap(genobj)
|
384 |
+
self.str, self.etype, self.exc, self.tb = '', None, None, None
|
385 |
+
self.status = 'SUBMITTED'
|
386 |
+
self.outlist = []
|
387 |
+
|
388 |
+
def notify(self, msg):
|
389 |
+
"Notifies the underlying monitor. To be implemented"
|
390 |
+
|
391 |
+
def _wrap(self, genobj, stringify_tb=False):
|
392 |
+
"""
|
393 |
+
Wrap the genobj into a generator managing the exceptions,
|
394 |
+
populating the .outlist, setting the .status and yielding None.
|
395 |
+
stringify_tb must be True if the traceback must be sent to a process.
|
396 |
+
"""
|
397 |
+
self.status = 'RUNNING'
|
398 |
+
try:
|
399 |
+
for value in genobj:
|
400 |
+
if self.status == 'TOBEKILLED': # exit from the loop
|
401 |
+
raise GeneratorExit
|
402 |
+
if value is not None: # add output
|
403 |
+
self.outlist.append(value)
|
404 |
+
self.notify(decode(value))
|
405 |
+
yield
|
406 |
+
except Interpreter.Exit: # wanted exit
|
407 |
+
self._regular_exit()
|
408 |
+
raise
|
409 |
+
except (GeneratorExit, TerminatedProcess, KeyboardInterrupt):
|
410 |
+
# soft termination
|
411 |
+
self.status = 'KILLED'
|
412 |
+
except Exception: # unexpected exception
|
413 |
+
self.etype, self.exc, tb = sys.exc_info()
|
414 |
+
self.tb = ''.join(traceback.format_tb(tb)) if stringify_tb else tb
|
415 |
+
self.status = 'ABORTED'
|
416 |
+
+        else:
+            self._regular_exit()
+
+    def _regular_exit(self):
+        self.status = 'FINISHED'
+        try:
+            self.str = '\n'.join(map(decode, self.outlist))
+        except IndexError:
+            self.str = 'no result'
+
+    def run(self):
+        "Run the inner generator"
+        for none in self._genobj:
+            pass
+
+    def kill(self):
+        "Set a TOBEKILLED status"
+        self.status = 'TOBEKILLED'
+
+    def wait(self):
+        "Wait for the task to finish: to be overridden"
+
+    @property
+    def traceback(self):
+        "Return the traceback as a (possibly empty) string"
+        if self.tb is None:
+            return ''
+        elif isinstance(self.tb, (str, bytes)):
+            return self.tb
+        else:
+            return ''.join(traceback.format_tb(self.tb))
+
+    @property
+    def result(self):
+        self.wait()
+        if self.exc:
+            if isinstance(self.tb, (str, bytes)):
+                raise self.etype(self.tb)
+            else:
+                raise_(self.etype, self.exc, self.tb or None)
+        if not self.outlist:
+            return None
+        return self.outlist[-1]
+
+    def __repr__(self):
+        "String representation containing class name, number, arglist, status"
+        return '<%s %d [%s] %s>' % (
+            self.__class__.__name__, self.no,
+            ' '.join(self.arglist), self.status)
+
+nulltask = BaseTask(0, [], ('skip' for dummy in (1,)))
+
+# ######################## synchronous tasks ############################## #
+
+
+class SynTask(BaseTask):
+    """
+    Synchronous task running in the interpreter loop and displaying its
+    output as soon as available.
+    """
+    def __str__(self):
+        "Return the output string or the error message"
+        if self.etype:  # there was an error
+            return '%s: %s' % (self.etype.__name__, self.exc)
+        else:
+            return '\n'.join(map(str, self.outlist))
+
+
+class ThreadedTask(BaseTask):
+    """
+    A task running in a separated thread.
+    """
+    def __init__(self, no, arglist, genobj):
+        BaseTask.__init__(self, no, arglist, genobj)
+        self.thread = threading.Thread(target=super(ThreadedTask, self).run)
+
+    def run(self):
+        "Run the task into a thread"
+        self.thread.start()
+
+    def wait(self):
+        "Block until the thread ends"
+        self.thread.join()
+
+
+# ######################## multiprocessing tasks ######################### #
+
+def sharedattr(name, on_error):
+    "Return a property to be attached to an MPTask"
+    def get(self):
+        try:
+            return getattr(self.ns, name)
+        except:  # the process was killed or died hard
+            return on_error
+
+    def set(self, value):
+        try:
+            setattr(self.ns, name, value)
+        except:  # the process was killed or died hard
+            pass
+    return property(get, set)
+
+
+class MPTask(BaseTask):
+    """
+    A task running as an external process. The current implementation
+    only works on Unix-like systems, where multiprocessing use forks.
+    """
+    str = sharedattr('str', '')
+    etype = sharedattr('etype', None)
+    exc = sharedattr('exc', None)
+    tb = sharedattr('tb', None)
+    status = sharedattr('status', 'ABORTED')
+
+    @property
+    def outlist(self):
+        try:
+            return self._outlist
+        except:  # the process died hard
+            return []
+
+    def notify(self, msg):
+        self.man.notify_listener(self.no, msg)
+
+    def __init__(self, no, arglist, genobj, manager):
+        """
+        The monitor has a .send method and a .man multiprocessing.Manager
+        """
+        self.no = no
+        self.arglist = arglist
+        self._genobj = self._wrap(genobj, stringify_tb=True)
+        self.man = manager
+        self._outlist = manager.mp.list()
+        self.ns = manager.mp.Namespace()
+        self.status = 'SUBMITTED'
+        self.etype, self.exc, self.tb = None, None, None
+        self.str = repr(self)
+        self.proc = multiprocessing.Process(target=super(MPTask, self).run)
+
+    def run(self):
+        "Run the task into an external process"
+        self.proc.start()
+
+    def wait(self):
+        "Block until the external process ends or is killed"
+        self.proc.join()
+
+    def kill(self):
+        """Kill the process with a SIGTERM inducing a TerminatedProcess
+        exception in the children"""
+        self.proc.terminate()
+
+# ######################## Task Manager ###################### #
+
+
+class TaskManager(object):
+    """
+    Store the given commands into a task registry. Provides methods to
+    manage the submitted tasks.
+    """
+    cmdprefix = '.'
+    specialcommands = set(['.last_tb'])
+
+    def __init__(self, obj):
+        self.obj = obj
+        self.registry = {}  # {taskno : task}
+        if obj.mpcommands or obj.thcommands:
+            self.specialcommands.update(['.kill', '.list', '.output'])
+        interact = getattr(obj, '_interact_', False)
+        self.parser = plac_core.parser_from(
+            obj, prog='' if interact else None, formatter_class=PlacFormatter)
+        HelpSummary.add(obj, self.specialcommands)
+        self.man = Manager() if obj.mpcommands else None
+        signal.signal(signal.SIGTERM, terminatedProcess)
+
+    def close(self):
+        "Kill all the running tasks"
+        for task in self.registry.values():
+            try:
+                if task.status == 'RUNNING':
+                    task.kill()
+                    task.wait()
+            except:  # task killed, nothing to wait
+                pass
+        if self.man:
+            self.man.stop()
+
+    def _get_latest(self, taskno=-1, status=None):
+        "Get the latest submitted task from the registry"
+        assert taskno < 0, 'You must pass a negative number'
+        if status:
+            tasks = [t for t in self.registry.values()
+                     if t.status == status]
+        else:
+            tasks = [t for t in self.registry.values()]
+        tasks.sort(key=attrgetter('no'))
+        if len(tasks) >= abs(taskno):
+            return tasks[taskno]
+
+    # ########################## special commands ######################## #
+
+    @plac_core.annotations(
+        taskno=('task to kill', 'positional', None, int))
+    def kill(self, taskno=-1):
+        'kill the given task (-1 to kill the latest running task)'
+        if taskno < 0:
+            task = self._get_latest(taskno, status='RUNNING')
+            if task is None:
+                yield 'Nothing to kill'
+                return
+        elif taskno not in self.registry:
+            yield 'Unknown task %d' % taskno
+            return
+        else:
+            task = self.registry[taskno]
+        if task.status in ('ABORTED', 'KILLED', 'FINISHED'):
+            yield 'Already finished %s' % task
+            return
+        task.kill()
+        yield task
+
+    @plac_core.annotations(
+        status=('', 'positional', None, str, BaseTask.STATES))
+    def list(self, status='RUNNING'):
+        'list tasks with a given status'
+        for task in self.registry.values():
+            if task.status == status:
+                yield task
+
+    @plac_core.annotations(
+        taskno=('task number', 'positional', None, int))
+    def output(self, taskno=-1, fname=None):
+        'show the output of a given task (and optionally save it to a file)'
+        if taskno < 0:
+            task = self._get_latest(taskno)
+            if task is None:
+                yield 'Nothing to show'
+                return
+        elif taskno not in self.registry:
+            yield 'Unknown task %d' % taskno
+            return
+        else:
+            task = self.registry[taskno]
+        outstr = '\n'.join(map(str, task.outlist))
+        if fname:
+            open(fname, 'w').write(outstr)
+            yield 'saved output of %d into %s' % (taskno, fname)
+            return
+        yield task
+        if len(task.outlist) > 20 and use_less:
+            less(outstr)  # has no meaning for a plac server
+        else:
+            yield outstr
+
+    @plac_core.annotations(
+        taskno=('task number', 'positional', None, int))
+    def last_tb(self, taskno=-1):
+        "show the traceback of a given task, if any"
+        task = self._get_latest(taskno)
+        if task:
+            yield task.traceback
+        else:
+            yield 'Nothing to show'
+
+# ########################## SyncProcess ############################# #
+
+
+class Process(subprocess.Popen):
+    "Start the interpreter specified by the params in a subprocess"
+
+    def __init__(self, params):
+        signal.signal(signal.SIGPIPE, signal.SIG_DFL)
+        # to avoid broken pipe messages
+        code = '''import plac, sys
+sys.argv[0] = '<%s>'
+plac.Interpreter(plac.import_main(*%s)).interact(prompt='i>\\n')
+''' % (params[0], params)
+        subprocess.Popen.__init__(
+            self, [sys.executable, '-u', '-c', code],
+            stdin=subprocess.PIPE, stdout=subprocess.PIPE)
+        self.man = multiprocessing.Manager()
+
+    def close(self):
+        "Close stdin and stdout"
+        self.stdin.close()
+        self.stdout.close()
+        self.man.shutdown()
+
+    def recv(self):  # char-by-char cannot work
+        "Return the output of the subprocess, line-by-line until the prompt"
+        lines = []
+        while True:
+            lines.append(self.stdout.readline())
+            if lines[-1] == 'i>\n':
+                out = ''.join(lines)
+                return out[:-1] + ' '  # remove last newline
+
+    def send(self, line):
+        """Send a line (adding a newline) to the underlying subprocess
+        and wait for the answer"""
+        self.stdin.write(line + os.linesep)
+        return self.recv()
+
+
+class StartStopObject(object):
+    started = False
+
+    def start(self):
+        pass
+
+    def stop(self):
+        pass
+
+
+class Monitor(StartStopObject):
+    """
+    Base monitor class with methods add_listener/del_listener/notify_listener
+    read_queue and and start/stop.
+    """
+    def __init__(self, name, queue=None):
+        self.name = name
+        self.queue = queue or multiprocessing.Queue()
+
+    def add_listener(self, taskno):
+        pass
+
+    def del_listener(self, taskno):
+        pass
+
+    def notify_listener(self, taskno, msg):
+        pass
+
+    def start(self):
+        pass
+
+    def stop(self):
+        pass
+
+    def read_queue(self):
+        pass
+
+
+class Manager(StartStopObject):
+    """
+    The plac Manager contains a multiprocessing.Manager and a set
+    of slave monitor processes to which we can send commands. There
+    is a manager for each interpreter with mpcommands.
+    """
+    def __init__(self):
+        self.registry = {}
+        self.started = False
+        self.mp = None
+
+    def add(self, monitor):
+        'Add or replace a monitor in the registry'
+        proc = multiprocessing.Process(None, monitor.start, monitor.name)
+        proc.queue = monitor.queue
+        self.registry[monitor.name] = proc
+
+    def delete(self, name):
+        'Remove a named monitor from the registry'
+        del self.registry[name]
+
+    # can be called more than once
+    def start(self):
+        if self.mp is None:
+            self.mp = multiprocessing.Manager()
+        for monitor in self.registry.values():
+            monitor.start()
+        self.started = True
+
+    def stop(self):
+        for monitor in self.registry.values():
+            monitor.queue.close()
+            monitor.terminate()
+        if self.mp:
+            self.mp.shutdown()
+            self.mp = None
+        self.started = False
+
+    def notify_listener(self, taskno, msg):
+        for monitor in self.registry.values():
+            monitor.queue.put(('notify_listener', taskno, msg))
+
+    def add_listener(self, no):
+        for monitor in self.registry.values():
+            monitor.queue.put(('add_listener', no))
+
+# ######################### plac server ############################# #
+
+#
+# Removed in version 1.4.0 due to incompatibility with Python 3.12
+#
+'''
+import asyncore
+import asynchat
+import socket
+
+class _AsynHandler(asynchat.async_chat):
+    "asynchat handler starting a new interpreter loop for each connection"
+
+    terminator = '\r\n'  # the standard one for telnet
+    prompt = 'i> '
+
+    def __init__(self, socket, interpreter):
+        asynchat.async_chat.__init__(self, socket)
+        self.set_terminator(self.terminator)
+        self.i = interpreter
+        self.i.__enter__()
+        self.data = []
+        self.write(self.prompt)
+
+    def write(self, data, *args):
+        "Push a string back to the client"
+        if args:
+            data %= args
+        if data.endswith('\n') and not data.endswith(self.terminator):
+            data = data[:-1] + self.terminator  # fix newlines
+        self.push(data)
+
+    def collect_incoming_data(self, data):
+        "Collect one character at the time"
+        self.data.append(data)
+
+    def found_terminator(self):
+        "Put in the queue the line received from the client"
+        line = ''.join(self.data)
+        self.log('Received line %r from %s' % (line, self.addr))
+        if line == 'EOF':
+            self.i.__exit__(None, None, None)
+            self.handle_close()
+        else:
+            task = self.i.submit(line)
+            task.run()  # synchronous or not
+            if task.etype:  # manage exception
+                error = '%s: %s\nReceived: %s' % (
+                    task.etype.__name__, task.exc, ' '.join(task.arglist))
+                self.log_info(task.traceback + error)  # on the server
+                self.write(error + self.terminator)  # back to the client
+            else:  # no exception
+                self.write(task.str + self.terminator)
+            self.data = []
+            self.write(self.prompt)
+
+
+class _AsynServer(asyncore.dispatcher):
+    "asyncore-based server spawning AsynHandlers"
+
+    def __init__(self, interpreter, newhandler, port, listen=5):
+        self.interpreter = interpreter
+        self.newhandler = newhandler
+        self.port = port
+        asyncore.dispatcher.__init__(self)
+        self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
+        self.bind(('', port))
+        self.listen(listen)
+
+    def handle_accept(self):
+        clientsock, clientaddr = self.accept()
+        self.log('Connected from %s' % str(clientaddr))
+        i = self.interpreter.__class__(self.interpreter.obj)  # new interpreter
+        self.newhandler(clientsock, i)  # spawn a new handler
+
+'''
+
+# ########################## the Interpreter ############################ #
+
+class Interpreter(object):
+    """
+    A context manager with a .send method and a few utility methods:
+    execute, test and doctest.
+    """
+    class Exit(Exception):
+        pass
+
+    def __init__(self, obj, commentchar='#', split=shlex.split):
+        self.obj = obj
+        try:
+            self.name = obj.__module__
+        except AttributeError:
+            self.name = 'plac'
+        self.commentchar = commentchar
+        self.split = split
+        self._set_commands(obj)
+        self.tm = TaskManager(obj)
+        self.man = self.tm.man
+        self.parser = self.tm.parser
+        if self.commands:
+            self.parser.addsubcommands(
+                self.tm.specialcommands, self.tm, title='special commands')
+        if obj.mpcommands:
+            self.parser.addsubcommands(
+                obj.mpcommands, obj,
+                title='commands run in external processes')
+        if obj.thcommands:
+            self.parser.addsubcommands(
+                obj.thcommands, obj, title='threaded commands')
+        self.parser.error = lambda msg: sys.exit(msg)  # patch the parser
+        self._interpreter = None
+
+    def _set_commands(self, obj):
+        "Make sure obj has the right command attributes as Python sets"
+        for attrname in ('commands', 'mpcommands', 'thcommands'):
+            setattr(self, attrname, set(getattr(self.__class__, attrname, [])))
+            setattr(obj, attrname, set(getattr(obj, attrname, [])))
+        self.commands = obj.commands
+        self.mpcommands.update(obj.mpcommands)
+        self.thcommands.update(obj.thcommands)
+        if (obj.commands or obj.mpcommands or obj.thcommands) and \
+                not hasattr(obj, 'help'):  # add default help
+            obj.help = default_help.__get__(obj, obj.__class__)
+            self.commands.add('help')
+
+    def __enter__(self):
+        "Start the inner interpreter loop"
+        self._interpreter = self._make_interpreter()
+        self._interpreter.send(None)
+        return self
+
+    def __exit__(self, exctype, exc, tb):
+        "Close the inner interpreter and the task manager"
+        self.close(exctype, exc, tb)
+
+    def submit(self, line):
+        "Send a line to the underlying interpreter and return a task object"
+        if self._interpreter is None:
+            raise RuntimeError(_('%r not initialized: probably you forgot to '
+                                 'use the with statement') % self)
+        if isinstance(line, (str, bytes)):
+            arglist = self.split(line, self.commentchar)
+        else:  # expects a list of strings
+            arglist = line
+        if not arglist:
+            return nulltask
+        m = self.tm.man  # manager
+        if m and not m.started:
+            m.start()
+        task = self._interpreter.send(arglist)  # nonblocking
+        if not plac_core._match_cmd(arglist[0], self.tm.specialcommands):
+            self.tm.registry[task.no] = task
+            if m:
+                m.add_listener(task.no)
+        return task
+
+    def send(self, line):
+        """Send a line to the underlying interpreter and return
+        the finished task"""
+        task = self.submit(line)
+        BaseTask.run(task)  # blocking
+        return task
+
+    def tasks(self):
+        "The full lists of the submitted tasks"
+        return self.tm.registry.values()
+
+    def close(self, exctype=None, exc=None, tb=None):
+        "Can be called to close the interpreter prematurely"
+        self.tm.close()
+        if exctype is not None:
+            self._interpreter.throw(exctype, exc, tb)
+        else:
+            self._interpreter.close()
+
+    def _make_interpreter(self):
+        "The interpreter main loop, from lists of arguments to task objects"
+        enter = getattr(self.obj, '__enter__', lambda: None)
+        exit = getattr(self.obj, '__exit__', lambda et, ex, tb: None)
+        enter()
+        task = None
+        try:
+            for no in itertools.count(1):
+                arglist = yield task
+                try:
+                    cmd, result = self.parser.consume(arglist)
+                except SystemExit as e:  # for invalid commands
+                    if e.args == (0,):  # raised as sys.exit(0)
+                        errlist = []
+                    else:
+                        errlist = [str(e)]
+                    task = SynTask(no, arglist, iter(errlist))
+                    continue
+                except:  # anything else
+                    task = SynTask(no, arglist, gen_exc(*sys.exc_info()))
+                    continue
+                if not plac_core.iterable(result):  # atomic result
+                    task = SynTask(no, arglist, gen_val(result))
+                elif cmd in self.obj.mpcommands:
+                    task = MPTask(no, arglist, result, self.tm.man)
+                elif cmd in self.obj.thcommands:
+                    task = ThreadedTask(no, arglist, result)
+                else:  # blocking task
+                    task = SynTask(no, arglist, result)
+        except GeneratorExit:  # regular exit
+            exit(None, None, None)
+        except:  # exceptional exit
+            exit(*sys.exc_info())
+            raise
+
+    def check(self, given_input, expected_output):
+        "Make sure you get the expected_output from the given_input"
+        output = self.send(given_input).str  # blocking
+        ok = (output == expected_output)
+        if not ok:
+            # the message here is not internationalized on purpose
+            msg = 'input: %s\noutput: %s\nexpected: %s' % (
+                given_input, output, expected_output)
+            raise AssertionError(msg)
+
+    def _parse_doctest(self, lineiter):
+        "Returns the lines of input, the lines of output, and the line number"
+        lines = [line.strip() for line in lineiter]
+        inputs = []
+        positions = []
+        for i, line in enumerate(lines):
+            if line.startswith('i> '):
+                inputs.append(line[3:])
+                positions.append(i)
+        positions.append(len(lines) + 1)  # last position
+        outputs = []
+        for i, start in enumerate(positions[:-1]):
+            end = positions[i + 1]
+            outputs.append('\n'.join(lines[start+1:end]))
+        return zip(inputs, outputs, positions)
+
+    def doctest(self, lineiter, verbose=False):
+        """
+        Parse a text containing doctests in a context and tests of all them.
+        Raise an error even if a single doctest if broken. Use this for
+        sequential tests which are logically grouped.
+        """
+        with self:
+            try:
+                for input, output, no in self._parse_doctest(lineiter):
+                    if verbose:
+                        write('i> %s\n' % input)
+                        write('-> %s\n' % output)
+                    task = self.send(input)  # blocking
+                    if not str(task) == output:
+                        msg = ('line %d: input: %s\noutput: %s\nexpected: %s\n'
+                               % (no + 1, input, task, output))
+                        write(msg)
+                        if task.exc:
+                            raise_(task.etype, task.exc, task.tb)
+            except self.Exit:
+                pass
+
+    def execute(self, lineiter, verbose=False):
+        "Execute a lineiter of commands in a context and print the output"
+        with self:
+            try:
+                for line in lineiter:
+                    if verbose:
+                        write('i> ' + line)
+                    task = self.send(line)  # finished task
+                    if task.etype:  # there was an error
+                        raise_(task.etype, task.exc, task.tb)
+                    write('%s\n' % task.str)
+            except self.Exit:
+                pass
+
+    def multiline(self, stdin=sys.stdin, terminator=';', verbose=False):
+        "The multiline mode is especially suited for usage with emacs"
+        with self:
+            try:
+                for line in read_long_line(stdin, terminator):
+                    task = self.submit(line)
+                    task.run()
+                    write('%s\n' % task.str)
+                    if verbose and task.traceback:
+                        write(task.traceback)
+            except self.Exit:
+                pass
+
+    def interact(self, stdin=sys.stdin, prompt='i> ', verbose=False):
+        "Starts an interactive command loop reading commands from the console"
+        try:
+            import readline
+            readline_present = True
+        except ImportError:
+            readline_present = False
+        if stdin is sys.stdin and readline_present:  # use readline
+            histfile = os.path.expanduser('~/.%s.history' % self.name)
+            completions = list(self.commands) + list(self.mpcommands) + \
+                list(self.thcommands) + list(self.tm.specialcommands)
+            self.stdin = ReadlineInput(completions, histfile=histfile)
+        else:
+            self.stdin = stdin
+        self.prompt = prompt
+        self.verbose = verbose
+        intro = self.obj.__doc__ or ''
+        write(intro + '\n')
+        with self:
+            self.obj._interact_ = True
+            if self.stdin is sys.stdin:  # do not close stdin automatically
+                self._manage_input()
+            else:
+                with self.stdin:  # close stdin automatically
+                    self._manage_input()
+
+    def _manage_input(self):
+        "Convert input lines into task which are then executed"
+        try:
+            for line in iter(lambda: read_line(self.stdin, self.prompt), ''):
+                line = line.strip()
+                if not line:
+                    continue
+                task = self.submit(line)
+                task.run()  # synchronous or not
+                write(str(task) + '\n')
+                if self.verbose and task.etype:
+                    write(task.traceback)
+        except self.Exit:
+            pass
+
+    def start_server(self, port=2199, **kw):
+        """Starts an asyncore server reading commands for clients and opening
+        a new interpreter for each connection."""
+        _AsynServer(self, _AsynHandler, port)  # register the server
+        try:
+            asyncore.loop(**kw)
+        except (KeyboardInterrupt, TerminatedProcess):
+            pass
+        finally:
+            asyncore.close_all()
+
+    def add_monitor(self, mon):
+        self.man.add(mon)
+
+    def del_monitor(self, name):
+        self.man.delete(name)
+
+    @classmethod
+    def call(cls, factory, arglist=sys.argv[1:],
+             commentchar='#', split=shlex.split,
+             stdin=sys.stdin, prompt='i> ', verbose=False):
+        """
+        Call a container factory with the arglist and instantiate an
+        interpreter object. If there are remaining arguments, send them to the
+        interpreter, else start an interactive session.
+        """
+        obj = partial_call(factory, arglist)
+        i = cls(obj, commentchar, split)
+        if i.obj._args_:
+            with i:
+                task = i.send(i.obj._args_)  # synchronous
+                if task.exc:
+                    raise_(task.etype, task.exc, task.tb)
+                out = str(task)
+                if out:
+                    print(out)
+        elif i.obj._interact_:
+            i.interact(stdin, prompt, verbose)
+        else:
+            i.parser.print_usage()
+
+# ################################## runp ################################### #
+
+
+class _TaskLauncher(object):
+    "Helper for runp"
+
+    def __init__(self, genseq, mode):
+        if mode == 'p':
+            self.mpcommands = ['rungen']
+        else:
+            self.thcommands = ['rungen']
+        self.genlist = list(genseq)
+
+    def rungen(self, i):
+        for out in self.genlist[int(i) - 1]:
+            yield out
+
+
+def runp(genseq, mode='p'):
+    """Run a sequence of generators in parallel. Mode can be 'p' (use processes)
+    or 't' (use threads). After all of them are finished, return a list of
+    task objects.
+    """
+    assert mode in 'pt', mode
+    launcher = _TaskLauncher(genseq, mode)
+    res = []
+    with Interpreter(launcher) as inter:
+        for i in range(len(launcher.genlist)):
+            inter.submit('rungen %d' % (i + 1)).run()
+        for task in inter.tasks():
+            try:
+                res.append(task.result)
+            except Exception as e:
+                res.append(e)
+    return res
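Usage note (not part of the diff): the Interpreter defined above is a context manager whose send()/check() methods turn a command line into a finished task, and runp() drives a sequence of generators through threaded or forked tasks and collects each task's last yielded value. A minimal sketch, assuming the vendored plac_ext module from this virtualenv is importable; the Calc container and slow_count generator are made-up examples, not part of the library:

    import plac_ext

    class Calc(object):
        "Made-up command container with one synchronous command"
        commands = ['add']

        def add(self, x, y):
            # yielded values become the task output
            yield '%.1f' % (float(x) + float(y))

    def slow_count(n):
        "Made-up generator for runp()"
        for i in range(n):
            yield i

    if __name__ == '__main__':
        # Interpreter as a context manager: send() blocks, check() asserts
        with plac_ext.Interpreter(Calc()) as i:
            i.check('add 1 2', '3.0')
            print(i.send('add 10 32'))   # prints the task output: 42.0
        # runp: mode='t' uses ThreadedTask, mode='p' forks MPTasks
        # (Unix-only, per the MPTask docstring above)
        print(plac_ext.runp([slow_count(3), slow_count(5)], mode='t'))
        # expected: [2, 4] -- the last value yielded by each generator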
.venv/Lib/site-packages/plac_tk.py ADDED
@@ -0,0 +1,64 @@
+from __future__ import print_function
+import os
+import sys
+if sys.version_info < (3,):
+    import Queue as queue
+else:
+    import queue
+import plac_core
+from Tkinter import Tk
+from ScrolledText import ScrolledText
+from plac_ext import Monitor, TerminatedProcess
+
+
+class TkMonitor(Monitor):
+    """
+    An interface over a dictionary {taskno: scrolledtext widget}, with
+    methods add_listener, del_listener, notify_listener and start/stop.
+    """
+    def __init__(self, name, queue=None):
+        Monitor.__init__(self, name, queue)
+        self.widgets = {}
+
+    @plac_core.annotations(taskno=('task number', 'positional', None, int))
+    def add_listener(self, taskno):
+        "There is a ScrolledText for each task"
+        st = ScrolledText(self.root, height=5)
+        st.insert('end', 'Output of task %d\n' % taskno)
+        st.pack()
+        self.widgets[taskno] = st
+
+    @plac_core.annotations(taskno=('task number', 'positional', None, int))
+    def del_listener(self, taskno):
+        del self.widgets[taskno]
+
+    @plac_core.annotations(taskno=('task number', 'positional', None, int))
+    def notify_listener(self, taskno, msg):
+        w = self.widgets[taskno]
+        w.insert('end', msg + '\n')
+        w.update()
+
+    def start(self):
+        'Start the mainloop'
+        self.root = Tk()
+        self.root.title(self.name)
+        self.root.wm_protocol("WM_DELETE_WINDOW", self.stop)
+        self.root.after(0, self.read_queue)
+        try:
+            self.root.mainloop()
+        except KeyboardInterrupt:
+            print('Process %d killed by CTRL-C' % os.getpid(), file=sys.stderr)
+        except TerminatedProcess:
+            pass
+
+    def stop(self):
+        self.root.quit()
+
+    def read_queue(self):
+        try:
+            cmd_args = self.queue.get_nowait()
+        except queue.Empty:
+            pass
+        else:
+            getattr(self, cmd_args[0])(*cmd_args[1:])
+        self.root.after(100, self.read_queue)
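Usage note (not part of the diff): TkMonitor plugs into the monitor hooks shown in plac_ext.py above (Interpreter.add_monitor forwards to Manager.add, which runs the monitor's start() in its own process and feeds its queue). A rough sketch of the wiring; 'tasks.py' is a hypothetical module exposing a main object with an mpcommands list, and the vendored plac_tk.py imports Tkinter/ScrolledText under their Python 2 names, so on Python 3 the tkinter equivalents would be needed:

    import plac                      # exposes Interpreter and import_main
    from plac_tk import TkMonitor    # the module added above

    main = plac.import_main('tasks.py')      # hypothetical mpcommands container
    i = plac.Interpreter(main)
    i.add_monitor(TkMonitor('tk-monitor'))   # one ScrolledText pane per MP task
    i.interact()                             # task output is pushed to the panes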
.venv/Lib/site-packages/pylab.py ADDED
@@ -0,0 +1,3 @@
+from matplotlib.pylab import *  # noqa: F401, F403
+import matplotlib.pylab
+__doc__ = matplotlib.pylab.__doc__
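Note (not part of the diff): the three-line pylab.py above is only a compatibility shim that re-exports matplotlib.pylab, so importing the shim and importing the real module yield the same objects:

    import pylab                        # the shim added in this diff
    import matplotlib.pylab as mpl_pylab

    # the shim's names are the very objects from matplotlib.pylab
    assert pylab.plot is mpl_pylab.plot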