ayousanz committed
Commit b9aaa56 · verified · 1 parent: 9e8681b

Add files using upload-large-folder tool
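The `upload-large-folder` tool named in the commit message is the resumable, multi-worker uploader in `huggingface_hub`. A commit of this shape is typically produced with something like the sketch below; the repo id and local path are placeholders, not values taken from this commit.

```python
from huggingface_hub import HfApi

api = HfApi()
# Scans the folder, uploads files in parallel with resume support, and
# creates commits like this one on the Hub. Ids and paths are illustrative.
api.upload_large_folder(
    repo_id="<user>/<repo-name>",   # placeholder
    repo_type="model",
    folder_path="./local-folder",   # placeholder
)
```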

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete change set.
Files changed (50)
  1. .gitattributes +44 -35
  2. .venv/.gitignore +1 -0
  3. .venv/.lock +0 -0
  4. .venv/CACHEDIR.TAG +1 -0
  5. .venv/CHANGES.rst +76 -0
  6. .venv/Lib/site-packages/_cffi_backend.cp39-win_amd64.pyd +0 -0
  7. .venv/Lib/site-packages/_soundfile.py +11 -0
  8. .venv/Lib/site-packages/_virtualenv.py +101 -0
  9. .venv/Lib/site-packages/accelerate/__init__.py +50 -0
  10. .venv/Lib/site-packages/accelerate/accelerator.py +0 -0
  11. .venv/Lib/site-packages/accelerate/big_modeling.py +637 -0
  12. .venv/Lib/site-packages/accelerate/checkpointing.py +306 -0
  13. .venv/Lib/site-packages/accelerate/commands/__pycache__/accelerate_cli.cpython-39.pyc +0 -0
  14. .venv/Lib/site-packages/accelerate/commands/accelerate_cli.py +52 -0
  15. .venv/Lib/site-packages/accelerate/commands/config/__init__.py +52 -0
  16. .venv/Lib/site-packages/accelerate/commands/config/__pycache__/cluster.cpython-39.pyc +0 -0
  17. .venv/Lib/site-packages/accelerate/commands/config/__pycache__/config.cpython-39.pyc +0 -0
  18. .venv/Lib/site-packages/accelerate/commands/config/__pycache__/config_args.cpython-39.pyc +0 -0
  19. .venv/Lib/site-packages/accelerate/commands/config/__pycache__/config_utils.cpython-39.pyc +0 -0
  20. .venv/Lib/site-packages/accelerate/commands/config/__pycache__/update.cpython-39.pyc +0 -0
  21. .venv/Lib/site-packages/accelerate/commands/config/config_args.py +252 -0
  22. .venv/Lib/site-packages/accelerate/commands/config/default.py +142 -0
  23. .venv/Lib/site-packages/accelerate/commands/config/sagemaker.py +267 -0
  24. .venv/Lib/site-packages/accelerate/commands/config/update.py +63 -0
  25. .venv/Lib/site-packages/accelerate/commands/env.py +113 -0
  26. .venv/Lib/site-packages/accelerate/commands/menu/__pycache__/__init__.cpython-39.pyc +0 -0
  27. .venv/Lib/site-packages/accelerate/commands/menu/__pycache__/cursor.cpython-39.pyc +0 -0
  28. .venv/Lib/site-packages/accelerate/commands/menu/__pycache__/input.cpython-39.pyc +0 -0
  29. .venv/Lib/site-packages/accelerate/commands/menu/__pycache__/keymap.cpython-39.pyc +0 -0
  30. .venv/Lib/site-packages/accelerate/commands/menu/__pycache__/selection_menu.cpython-39.pyc +0 -0
  31. .venv/Lib/site-packages/accelerate/data_loader.py +1323 -0
  32. .venv/Lib/site-packages/accelerate/hooks.py +726 -0
  33. .venv/Lib/site-packages/accelerate/inference.py +184 -0
  34. .venv/Lib/site-packages/accelerate/launchers.py +302 -0
  35. .venv/Lib/site-packages/accelerate/local_sgd.py +104 -0
  36. .venv/Lib/site-packages/accelerate/logging.py +125 -0
  37. .venv/Lib/site-packages/accelerate/memory_utils.py +22 -0
  38. .venv/Lib/site-packages/accelerate/optimizer.py +212 -0
  39. .venv/Lib/site-packages/accelerate/scheduler.py +98 -0
  40. .venv/Lib/site-packages/accelerate/state.py +1257 -0
  41. .venv/Lib/site-packages/accelerate/tracking.py +1023 -0
  42. .venv/Lib/site-packages/decorator.py +451 -0
  43. .venv/Lib/site-packages/isympy.py +342 -0
  44. .venv/Lib/site-packages/mojimoji.cp39-win_amd64.pyd +0 -0
  45. .venv/Lib/site-packages/numpy-1.26.3-cp39-cp39-win_amd64.whl +0 -0
  46. .venv/Lib/site-packages/plac.py +37 -0
  47. .venv/Lib/site-packages/plac_core.py +439 -0
  48. .venv/Lib/site-packages/plac_ext.py +1205 -0
  49. .venv/Lib/site-packages/plac_tk.py +64 -0
  50. .venv/Lib/site-packages/pylab.py +3 -0
.gitattributes CHANGED
@@ -1,35 +1,44 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ Utils/JDC/bst.t7 filter=lfs diff=lfs merge=lfs -text
+ Utils/PLBERT/step_1050000.t7 filter=lfs diff=lfs merge=lfs -text
+ reference_sample_wavs/01008270.wav filter=lfs diff=lfs merge=lfs -text
+ reference_sample_wavs/kaede_san.wav filter=lfs diff=lfs merge=lfs -text
+ reference_sample_wavs/riamu_zeroshot_02.wav filter=lfs diff=lfs merge=lfs -text
+ reference_sample_wavs/sample_ref01.wav filter=lfs diff=lfs merge=lfs -text
+ reference_sample_wavs/sample_ref02.wav filter=lfs diff=lfs merge=lfs -text
+ reference_sample_wavs/shiki_fine05.wav filter=lfs diff=lfs merge=lfs -text
+ reference_sample_wavs/syuukovoice_200918_3_01.wav filter=lfs diff=lfs merge=lfs -text
.venv/.gitignore ADDED
@@ -0,0 +1 @@
+ *
.venv/.lock ADDED
File without changes
.venv/CACHEDIR.TAG ADDED
@@ -0,0 +1 @@
+ Signature: 8a477f597d28d172789f06886806bc55
.venv/CHANGES.rst ADDED
@@ -0,0 +1,76 @@
+ CHANGES
+ =======
+
+ 0.4.0 (2024-7-26)
+ -------------------
+ - Add stub files according to PEP 561 for mypy (thanks @ernix)
+
+ 0.3.4 (2023-2-18)
+ -------------------
+ - Fix to support Python2.7 ~ 3.4 (thanks @manjuu-eater)
+ - Support Python 3.11
+
+ 0.3.3 (2022-12-31)
+ -------------------
+ - Support Python 3.10
+ - Re-support Python2.7 ~ 3.4 (thanks @manjuu-eater)
+ - Fix z2h, h2z all flag off bug (thanks @manjuu-eater)
+
+ 0.3.1 (2022-12-14)
+ -------------------
+ - Fix alpha2kana infinite loop bug (thanks @frog42)
+
+ 0.3 (2021-03-29)
+ -------------------
+ - Fix bug (alphabet2kana) thanks @Cuddlemuffin007
+ - Support Python 3.8 and 3.9
+ - Add handy functions: alphabet2kata and kata2alphabet. thanks @kokimame
+ - Add function for julius: hiragana2julius
+
+ 0.2.4 (2018-02-04)
+ -------------------
+ - Fix bug (kana2alphabet)
+ - Support Python 3.7
+ - No longer support Python 2.6
+ - Add aliases of z2h -> zenkaku2hankaku and h2z -> hankaku2zenkaku
+
+ 0.2.3 (2018-02-03)
+ -------------------
+ - Fix bugs (alphabet2kana, kana2alphabet) thanks @letuananh
+
+ 0.2.2 (2018-01-22)
+ -------------------
+ - Fix bug (kana2alphabet) thanks @kokimame
+ - Support Python 3.6
+
+ 0.2.1 (2017-09-14)
+ -------------------
+ - Fix bugs (alphabet2kana, kana2alphabet)
+
+ 0.2 (2015-04-02)
+ ------------------
+
+ - Change module name jctconv -> jaconv
+ - Add alphabet and hiragana interconvert (alphabet2kana, kana2alphabet)
+
+ 0.1.1 (2015-03-12)
+ ------------------
+
+ - Support Windows
+ - Support Python 3.5
+
+
+ 0.1 (2014-11-24)
+ ------------------
+
+ - Add some Japanese characters to convert table (ゝゞ・「」。、)
+ - Decresing memory usage
+ - Some function names are deprecated (hankaku2zenkaku, zenkaku2hankaku, H2K, H2hK, K2H)
+
+
+ 0.0.7 (2014-03-22)
+ ------------------
+
+ z2h and h2z allow mojimoji-like target character type determination.
+ Bug fix about Half Kana conversion.
+
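The changelog above belongs to the vendored `jaconv` package (Japanese half-width/full-width and kana conversion) that ships inside this `.venv`. As a rough illustration of the functions the changelog mentions, not taken from this diff and with indicative outputs only:

```python
import jaconv

# Half-width katakana to full-width (h2z) and back (z2h), per the changelog entries.
print(jaconv.h2z("ﾃｽﾄ"))                # -> テスト (expected output, indicative)
print(jaconv.z2h("テスト", kana=True))   # -> ﾃｽﾄ

# Romaji/kana helpers added around 0.2/0.3.
print(jaconv.alphabet2kana("ohayou"))    # -> おはよう
print(jaconv.kana2alphabet("おはよう"))   # -> ohayou
```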
.venv/Lib/site-packages/_cffi_backend.cp39-win_amd64.pyd ADDED
Binary file (178 kB).
 
.venv/Lib/site-packages/_soundfile.py ADDED
@@ -0,0 +1,11 @@
+ # auto-generated file
+ import _cffi_backend
+
+ ffi = _cffi_backend.FFI('_soundfile',
+ _version = 0x2601,
+ _types = b'\x00\x00\x17\x0D\x00\x00\x6D\x03\x00\x00\x07\x01\x00\x00\x6C\x03\x00\x00\x7A\x03\x00\x00\x00\x0F\x00\x00\x17\x0D\x00\x00\x6F\x03\x00\x00\x07\x01\x00\x00\x03\x11\x00\x00\x00\x0F\x00\x00\x17\x0D\x00\x00\x07\x01\x00\x00\x07\x01\x00\x00\x03\x11\x00\x00\x07\x01\x00\x00\x00\x0F\x00\x00\x17\x0D\x00\x00\x7B\x03\x00\x00\x07\x01\x00\x00\x03\x11\x00\x00\x00\x0F\x00\x00\x07\x0D\x00\x00\x6E\x03\x00\x00\x00\x0F\x00\x00\x07\x0D\x00\x00\x17\x11\x00\x00\x07\x01\x00\x00\x00\x0F\x00\x00\x07\x0D\x00\x00\x07\x01\x00\x00\x00\x0F\x00\x00\x07\x0D\x00\x00\x00\x0F\x00\x00\x02\x0D\x00\x00\x6C\x03\x00\x00\x00\x0F\x00\x00\x02\x0D\x00\x00\x17\x11\x00\x00\x00\x0F\x00\x00\x02\x0D\x00\x00\x17\x11\x00\x00\x6F\x03\x00\x00\x1C\x01\x00\x00\x00\x0F\x00\x00\x02\x0D\x00\x00\x17\x11\x00\x00\x07\x01\x00\x00\x07\x11\x00\x00\x00\x0F\x00\x00\x02\x0D\x00\x00\x17\x11\x00\x00\x07\x01\x00\x00\x04\x11\x00\x00\x07\x01\x00\x00\x00\x0F\x00\x00\x3B\x0D\x00\x00\x17\x11\x00\x00\x70\x03\x00\x00\x17\x01\x00\x00\x00\x0F\x00\x00\x3B\x0D\x00\x00\x17\x11\x00\x00\x74\x03\x00\x00\x17\x01\x00\x00\x00\x0F\x00\x00\x3B\x0D\x00\x00\x17\x11\x00\x00\x02\x03\x00\x00\x17\x01\x00\x00\x00\x0F\x00\x00\x3B\x0D\x00\x00\x17\x11\x00\x00\x17\x01\x00\x00\x07\x01\x00\x00\x00\x0F\x00\x00\x3B\x0D\x00\x00\x17\x11\x00\x00\x79\x03\x00\x00\x17\x01\x00\x00\x00\x0F\x00\x00\x3B\x0D\x00\x00\x17\x11\x00\x00\x04\x11\x00\x00\x17\x01\x00\x00\x00\x0F\x00\x00\x3B\x0D\x00\x00\x17\x01\x00\x00\x07\x01\x00\x00\x04\x11\x00\x00\x00\x0F\x00\x00\x3B\x0D\x00\x00\x04\x11\x00\x00\x00\x0F\x00\x00\x3B\x0D\x00\x00\x04\x11\x00\x00\x17\x01\x00\x00\x04\x11\x00\x00\x00\x0F\x00\x00\x3B\x0D\x00\x00\x7A\x03\x00\x00\x17\x01\x00\x00\x04\x11\x00\x00\x00\x0F\x00\x00\x7A\x0D\x00\x00\x17\x11\x00\x00\x00\x0F\x00\x00\x00\x09\x00\x00\x01\x09\x00\x00\x02\x09\x00\x00\x03\x09\x00\x00\x02\x01\x00\x00\x0E\x01\x00\x00\x00\x0B\x00\x00\x01\x0B\x00\x00\x02\x0B\x00\x00\x0D\x01\x00\x00\x56\x03\x00\x00\x5B\x03\x00\x00\x5E\x03\x00\x00\x63\x03\x00\x00\x05\x01\x00\x00\x00\x01\x00\x00\x10\x01',
+ _globals = (b'\xFF\xFF\xFF\x0BSFC_FILE_TRUNCATE',4224,b'\xFF\xFF\xFF\x0BSFC_GET_FORMAT_INFO',4136,b'\xFF\xFF\xFF\x0BSFC_GET_FORMAT_MAJOR',4145,b'\xFF\xFF\xFF\x0BSFC_GET_FORMAT_MAJOR_COUNT',4144,b'\xFF\xFF\xFF\x0BSFC_GET_FORMAT_SUBTYPE',4147,b'\xFF\xFF\xFF\x0BSFC_GET_FORMAT_SUBTYPE_COUNT',4146,b'\xFF\xFF\xFF\x0BSFC_GET_LIB_VERSION',4096,b'\xFF\xFF\xFF\x0BSFC_GET_LOG_INFO',4097,b'\xFF\xFF\xFF\x0BSFC_SET_CLIPPING',4288,b'\xFF\xFF\xFF\x0BSFC_SET_SCALE_FLOAT_INT_READ',4116,b'\xFF\xFF\xFF\x0BSFC_SET_SCALE_INT_FLOAT_WRITE',4117,b'\xFF\xFF\xFF\x0BSFM_RDWR',48,b'\xFF\xFF\xFF\x0BSFM_READ',16,b'\xFF\xFF\xFF\x0BSFM_WRITE',32,b'\xFF\xFF\xFF\x0BSF_FALSE',0,b'\xFF\xFF\xFF\x0BSF_FORMAT_ENDMASK',805306368,b'\xFF\xFF\xFF\x0BSF_FORMAT_SUBMASK',65535,b'\xFF\xFF\xFF\x0BSF_FORMAT_TYPEMASK',268369920,b'\xFF\xFF\xFF\x0BSF_TRUE',1,b'\x00\x00\x25\x23sf_close',0,b'\x00\x00\x32\x23sf_command',0,b'\x00\x00\x25\x23sf_error',0,b'\x00\x00\x1D\x23sf_error_number',0,b'\x00\x00\x28\x23sf_error_str',0,b'\x00\x00\x22\x23sf_format_check',0,b'\x00\x00\x19\x23sf_get_string',0,b'\x00\x00\x06\x23sf_open',0,b'\x00\x00\x0B\x23sf_open_fd',0,b'\x00\x00\x00\x23sf_open_virtual',0,b'\x00\x00\x25\x23sf_perror',0,b'\x00\x00\x38\x23sf_read_double',0,b'\x00\x00\x3D\x23sf_read_float',0,b'\x00\x00\x42\x23sf_read_int',0,b'\x00\x00\x51\x23sf_read_raw',0,b'\x00\x00\x4C\x23sf_read_short',0,b'\x00\x00\x51\x23sf_readf_double',0,b'\x00\x00\x51\x23sf_readf_float',0,b'\x00\x00\x51\x23sf_readf_int',0,b'\x00\x00\x51\x23sf_readf_short',0,b'\x00\x00\x47\x23sf_seek',0,b'\x00\x00\x2D\x23sf_set_string',0,b'\x00\x00\x16\x23sf_strerror',0,b'\x00\x00\x20\x23sf_version_string',0,b'\x00\x00\x11\x23sf_wchar_open',0,b'\x00\x00\x38\x23sf_write_double',0,b'\x00\x00\x3D\x23sf_write_float',0,b'\x00\x00\x42\x23sf_write_int',0,b'\x00\x00\x51\x23sf_write_raw',0,b'\x00\x00\x4C\x23sf_write_short',0,b'\x00\x00\x68\x23sf_write_sync',0,b'\x00\x00\x51\x23sf_writef_double',0,b'\x00\x00\x51\x23sf_writef_float',0,b'\x00\x00\x51\x23sf_writef_int',0,b'\x00\x00\x51\x23sf_writef_short',0),
+ _struct_unions = ((b'\x00\x00\x00\x6B\x00\x00\x00\x02SF_FORMAT_INFO',b'\x00\x00\x02\x11format',b'\x00\x00\x07\x11name',b'\x00\x00\x07\x11extension'),(b'\x00\x00\x00\x6C\x00\x00\x00\x02SF_INFO',b'\x00\x00\x3B\x11frames',b'\x00\x00\x02\x11samplerate',b'\x00\x00\x02\x11channels',b'\x00\x00\x02\x11format',b'\x00\x00\x02\x11sections',b'\x00\x00\x02\x11seekable'),(b'\x00\x00\x00\x6D\x00\x00\x00\x02SF_VIRTUAL_IO',b'\x00\x00\x76\x11get_filelen',b'\x00\x00\x75\x11seek',b'\x00\x00\x77\x11read',b'\x00\x00\x78\x11write',b'\x00\x00\x76\x11tell'),(b'\x00\x00\x00\x6E\x00\x00\x00\x10SNDFILE_tag',)),
+ _enums = (b'\x00\x00\x00\x71\x00\x00\x00\x16$1\x00SF_FORMAT_SUBMASK,SF_FORMAT_TYPEMASK,SF_FORMAT_ENDMASK',b'\x00\x00\x00\x72\x00\x00\x00\x16$2\x00SFC_GET_LIB_VERSION,SFC_GET_LOG_INFO,SFC_GET_FORMAT_INFO,SFC_GET_FORMAT_MAJOR_COUNT,SFC_GET_FORMAT_MAJOR,SFC_GET_FORMAT_SUBTYPE_COUNT,SFC_GET_FORMAT_SUBTYPE,SFC_FILE_TRUNCATE,SFC_SET_CLIPPING,SFC_SET_SCALE_FLOAT_INT_READ,SFC_SET_SCALE_INT_FLOAT_WRITE',b'\x00\x00\x00\x73\x00\x00\x00\x16$3\x00SF_FALSE,SF_TRUE,SFM_READ,SFM_WRITE,SFM_RDWR'),
+ _typenames = (b'\x00\x00\x00\x6BSF_FORMAT_INFO',b'\x00\x00\x00\x6CSF_INFO',b'\x00\x00\x00\x6DSF_VIRTUAL_IO',b'\x00\x00\x00\x6ESNDFILE',b'\x00\x00\x00\x3Bsf_count_t',b'\x00\x00\x00\x76sf_vio_get_filelen',b'\x00\x00\x00\x77sf_vio_read',b'\x00\x00\x00\x75sf_vio_seek',b'\x00\x00\x00\x76sf_vio_tell',b'\x00\x00\x00\x78sf_vio_write'),
+ )
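`_soundfile.py` is the auto-generated CFFI binding that the `soundfile` package uses to call libsndfile; it is not meant to be imported directly. Typical use goes through the high-level API, roughly as below (the wav path is one of the reference files added by this commit and is only illustrative here):

```python
import soundfile as sf

# Read a wav into a NumPy array plus its sample rate.
data, samplerate = sf.read("reference_sample_wavs/sample_ref01.wav")
print(data.shape, samplerate)

# Write it back out; the container format is inferred from the extension.
sf.write("copy.wav", data, samplerate)
```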
.venv/Lib/site-packages/_virtualenv.py ADDED
@@ -0,0 +1,101 @@
1
+ """Patches that are applied at runtime to the virtual environment."""
2
+
3
+ import os
4
+ import sys
5
+
6
+ VIRTUALENV_PATCH_FILE = os.path.join(__file__)
7
+
8
+
9
+ def patch_dist(dist):
10
+ """
11
+ Distutils allows user to configure some arguments via a configuration file:
12
+ https://docs.python.org/3.11/install/index.html#distutils-configuration-files.
13
+
14
+ Some of this arguments though don't make sense in context of the virtual environment files, let's fix them up.
15
+ """ # noqa: D205
16
+ # we cannot allow some install config as that would get packages installed outside of the virtual environment
17
+ old_parse_config_files = dist.Distribution.parse_config_files
18
+
19
+ def parse_config_files(self, *args, **kwargs):
20
+ result = old_parse_config_files(self, *args, **kwargs)
21
+ install = self.get_option_dict("install")
22
+
23
+ if "prefix" in install: # the prefix governs where to install the libraries
24
+ install["prefix"] = VIRTUALENV_PATCH_FILE, os.path.abspath(sys.prefix)
25
+ for base in ("purelib", "platlib", "headers", "scripts", "data"):
26
+ key = f"install_{base}"
27
+ if key in install: # do not allow global configs to hijack venv paths
28
+ install.pop(key, None)
29
+ return result
30
+
31
+ dist.Distribution.parse_config_files = parse_config_files
32
+
33
+
34
+ # Import hook that patches some modules to ignore configuration values that break package installation in case
35
+ # of virtual environments.
36
+ _DISTUTILS_PATCH = "distutils.dist", "setuptools.dist"
37
+ # https://docs.python.org/3/library/importlib.html#setting-up-an-importer
38
+
39
+
40
+ class _Finder:
41
+ """A meta path finder that allows patching the imported distutils modules."""
42
+
43
+ fullname = None
44
+
45
+ # lock[0] is threading.Lock(), but initialized lazily to avoid importing threading very early at startup,
46
+ # because there are gevent-based applications that need to be first to import threading by themselves.
47
+ # See https://github.com/pypa/virtualenv/issues/1895 for details.
48
+ lock = [] # noqa: RUF012
49
+
50
+ def find_spec(self, fullname, path, target=None): # noqa: ARG002
51
+ if fullname in _DISTUTILS_PATCH and self.fullname is None:
52
+ # initialize lock[0] lazily
53
+ if len(self.lock) == 0:
54
+ import threading
55
+
56
+ lock = threading.Lock()
57
+ # there is possibility that two threads T1 and T2 are simultaneously running into find_spec,
58
+ # observing .lock as empty, and further going into hereby initialization. However due to the GIL,
59
+ # list.append() operation is atomic and this way only one of the threads will "win" to put the lock
60
+ # - that every thread will use - into .lock[0].
61
+ # https://docs.python.org/3/faq/library.html#what-kinds-of-global-value-mutation-are-thread-safe
62
+ self.lock.append(lock)
63
+
64
+ from functools import partial
65
+ from importlib.util import find_spec
66
+
67
+ with self.lock[0]:
68
+ self.fullname = fullname
69
+ try:
70
+ spec = find_spec(fullname, path)
71
+ if spec is not None:
72
+ # https://www.python.org/dev/peps/pep-0451/#how-loading-will-work
73
+ is_new_api = hasattr(spec.loader, "exec_module")
74
+ func_name = "exec_module" if is_new_api else "load_module"
75
+ old = getattr(spec.loader, func_name)
76
+ func = self.exec_module if is_new_api else self.load_module
77
+ if old is not func:
78
+ try: # noqa: SIM105
79
+ setattr(spec.loader, func_name, partial(func, old))
80
+ except AttributeError:
81
+ pass # C-Extension loaders are r/o such as zipimporter with <3.7
82
+ return spec
83
+ finally:
84
+ self.fullname = None
85
+ return None
86
+
87
+ @staticmethod
88
+ def exec_module(old, module):
89
+ old(module)
90
+ if module.__name__ in _DISTUTILS_PATCH:
91
+ patch_dist(module)
92
+
93
+ @staticmethod
94
+ def load_module(old, name):
95
+ module = old(name)
96
+ if module.__name__ in _DISTUTILS_PATCH:
97
+ patch_dist(module)
98
+ return module
99
+
100
+
101
+ sys.meta_path.insert(0, _Finder())
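The `_Finder` above relies on Python's meta path hook protocol: any object inserted at the front of `sys.meta_path` is consulted for every import before the default machinery runs, which is how virtualenv gets a chance to patch `distutils.dist` and `setuptools.dist` as they load. A minimal standalone sketch of the same mechanism, with invented names and no connection to virtualenv:

```python
import sys


class ImportLogger:
    """Meta path finder that only observes import requests and never resolves them."""

    def find_spec(self, fullname, path, target=None):
        print(f"import requested: {fullname}")
        return None  # defer to the next finder on sys.meta_path


sys.meta_path.insert(0, ImportLogger())
import json  # noqa: F401  (only logged if json is not already in sys.modules)
```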
.venv/Lib/site-packages/accelerate/__init__.py ADDED
@@ -0,0 +1,50 @@
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ __version__ = "1.2.0"
+
+ from .accelerator import Accelerator
+ from .big_modeling import (
+ cpu_offload,
+ cpu_offload_with_hook,
+ disk_offload,
+ dispatch_model,
+ init_empty_weights,
+ init_on_device,
+ load_checkpoint_and_dispatch,
+ )
+ from .data_loader import skip_first_batches
+ from .inference import prepare_pippy
+ from .launchers import debug_launcher, notebook_launcher
+ from .state import PartialState
+ from .utils import (
+ AutocastKwargs,
+ DataLoaderConfiguration,
+ DDPCommunicationHookType,
+ DeepSpeedPlugin,
+ DistributedDataParallelKwargs,
+ DistributedType,
+ FullyShardedDataParallelPlugin,
+ GradScalerKwargs,
+ InitProcessGroupKwargs,
+ ProfileKwargs,
+ find_executable_batch_size,
+ infer_auto_device_map,
+ is_rich_available,
+ load_checkpoint_in_model,
+ synchronize_rng_states,
+ )
+
+
+ if is_rich_available():
+ from .utils import rich
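`accelerate/__init__.py` defines the package's public surface, with `Accelerator` as the main training entry point. For orientation, a minimal usage sketch with stand-in model, optimizer, and data (not taken from this commit):

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
dataset = torch.utils.data.TensorDataset(torch.randn(32, 10), torch.randint(0, 2, (32,)))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)

# prepare() moves everything to the right device(s) and wraps them for whatever
# distributed setup (if any) the current launch configuration selects.
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for batch, labels in dataloader:
    optimizer.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(batch), labels)
    accelerator.backward(loss)  # used instead of loss.backward()
    optimizer.step()
```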
.venv/Lib/site-packages/accelerate/accelerator.py ADDED
The diff for this file is too large to render. See raw diff
 
.venv/Lib/site-packages/accelerate/big_modeling.py ADDED
@@ -0,0 +1,637 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import logging
16
+ import os
17
+ from contextlib import contextmanager
18
+ from functools import wraps
19
+ from typing import Dict, List, Optional, Union
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+
24
+ from .hooks import (
25
+ AlignDevicesHook,
26
+ CpuOffload,
27
+ UserCpuOffloadHook,
28
+ add_hook_to_module,
29
+ attach_align_device_hook,
30
+ attach_align_device_hook_on_blocks,
31
+ )
32
+ from .utils import (
33
+ OffloadedWeightsLoader,
34
+ check_cuda_p2p_ib_support,
35
+ check_device_map,
36
+ extract_submodules_state_dict,
37
+ find_tied_parameters,
38
+ get_balanced_memory,
39
+ infer_auto_device_map,
40
+ is_bnb_available,
41
+ is_mlu_available,
42
+ is_musa_available,
43
+ is_npu_available,
44
+ is_torch_version,
45
+ is_xpu_available,
46
+ load_checkpoint_in_model,
47
+ offload_state_dict,
48
+ parse_flag_from_env,
49
+ retie_parameters,
50
+ )
51
+ from .utils.other import recursive_getattr
52
+
53
+
54
+ logger = logging.getLogger(__name__)
55
+
56
+
57
+ @contextmanager
58
+ def init_empty_weights(include_buffers: bool = None):
59
+ """
60
+ A context manager under which models are initialized with all parameters on the meta device, therefore creating an
61
+ empty model. Useful when just initializing the model would blow the available RAM.
62
+
63
+ Args:
64
+ include_buffers (`bool`, *optional*):
65
+ Whether or not to also put all buffers on the meta device while initializing.
66
+
67
+ Example:
68
+
69
+ ```python
70
+ import torch.nn as nn
71
+ from accelerate import init_empty_weights
72
+
73
+ # Initialize a model with 100 billions parameters in no time and without using any RAM.
74
+ with init_empty_weights():
75
+ tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])
76
+ ```
77
+
78
+ <Tip warning={true}>
79
+
80
+ Any model created under this context manager has no weights. As such you can't do something like
81
+ `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`].
82
+ Make sure to overwrite the default device_map param for [`load_checkpoint_and_dispatch`], otherwise dispatch is not
83
+ called.
84
+
85
+ </Tip>
86
+ """
87
+ if include_buffers is None:
88
+ include_buffers = parse_flag_from_env("ACCELERATE_INIT_INCLUDE_BUFFERS", False)
89
+ with init_on_device(torch.device("meta"), include_buffers=include_buffers) as f:
90
+ yield f
91
+
92
+
93
+ @contextmanager
94
+ def init_on_device(device: torch.device, include_buffers: bool = None):
95
+ """
96
+ A context manager under which models are initialized with all parameters on the specified device.
97
+
98
+ Args:
99
+ device (`torch.device`):
100
+ Device to initialize all parameters on.
101
+ include_buffers (`bool`, *optional*):
102
+ Whether or not to also put all buffers on the meta device while initializing.
103
+
104
+ Example:
105
+
106
+ ```python
107
+ import torch.nn as nn
108
+ from accelerate import init_on_device
109
+
110
+ with init_on_device(device=torch.device("cuda")):
111
+ tst = nn.Linear(100, 100) # on `cuda` device
112
+ ```
113
+ """
114
+ if include_buffers is None:
115
+ include_buffers = parse_flag_from_env("ACCELERATE_INIT_INCLUDE_BUFFERS", False)
116
+
117
+ # TODO(shingjan): remove the torch version check once older versions are deprecated
118
+ if is_torch_version(">=", "2.0") and include_buffers:
119
+ with device:
120
+ yield
121
+ return
122
+
123
+ old_register_parameter = nn.Module.register_parameter
124
+ if include_buffers:
125
+ old_register_buffer = nn.Module.register_buffer
126
+
127
+ def register_empty_parameter(module, name, param):
128
+ old_register_parameter(module, name, param)
129
+ if param is not None:
130
+ param_cls = type(module._parameters[name])
131
+ kwargs = module._parameters[name].__dict__
132
+ kwargs["requires_grad"] = param.requires_grad
133
+ module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
134
+
135
+ def register_empty_buffer(module, name, buffer, persistent=True):
136
+ old_register_buffer(module, name, buffer, persistent=persistent)
137
+ if buffer is not None:
138
+ module._buffers[name] = module._buffers[name].to(device)
139
+
140
+ # Patch tensor creation
141
+ if include_buffers:
142
+ tensor_constructors_to_patch = {
143
+ torch_function_name: getattr(torch, torch_function_name)
144
+ for torch_function_name in ["empty", "zeros", "ones", "full"]
145
+ }
146
+ else:
147
+ tensor_constructors_to_patch = {}
148
+
149
+ def patch_tensor_constructor(fn):
150
+ def wrapper(*args, **kwargs):
151
+ kwargs["device"] = device
152
+ return fn(*args, **kwargs)
153
+
154
+ return wrapper
155
+
156
+ try:
157
+ nn.Module.register_parameter = register_empty_parameter
158
+ if include_buffers:
159
+ nn.Module.register_buffer = register_empty_buffer
160
+ for torch_function_name in tensor_constructors_to_patch.keys():
161
+ setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
162
+ yield
163
+ finally:
164
+ nn.Module.register_parameter = old_register_parameter
165
+ if include_buffers:
166
+ nn.Module.register_buffer = old_register_buffer
167
+ for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
168
+ setattr(torch, torch_function_name, old_torch_function)
169
+
170
+
171
+ def cpu_offload(
172
+ model: nn.Module,
173
+ execution_device: Optional[torch.device] = None,
174
+ offload_buffers: bool = False,
175
+ state_dict: Optional[Dict[str, torch.Tensor]] = None,
176
+ preload_module_classes: Optional[List[str]] = None,
177
+ ):
178
+ """
179
+ Activates full CPU offload for a model. As a result, all parameters of the model will be offloaded and only one
180
+ copy of the state dict of the model will be kept. During the forward pass, parameters will be extracted from that
181
+ state dict and put on the execution device passed as they are needed, then offloaded again.
182
+
183
+ Args:
184
+ model (`torch.nn.Module`):
185
+ The model to offload.
186
+ execution_device (`torch.device`, *optional*):
187
+ The device on which the forward pass of the model will be executed (should be a GPU). Will default to the
188
+ model first parameter device.
189
+ offload_buffers (`bool`, *optional*, defaults to `False`):
190
+ Whether or not to offload the buffers with the model parameters.
191
+ state_dict (`Dict[str, torch.Tensor]`, *optional*):
192
+ The state dict of the model that will be kept on CPU.
193
+ preload_module_classes (`List[str]`, *optional*):
194
+ A list of classes whose instances should load all their weights (even in the submodules) at the beginning
195
+ of the forward. This should only be used for classes that have submodules which are registered but not
196
+ called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
197
+ `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
198
+ """
199
+ if execution_device is None:
200
+ execution_device = next(iter(model.parameters())).device
201
+ if state_dict is None:
202
+ state_dict = {n: p.to("cpu") for n, p in model.state_dict().items()}
203
+
204
+ add_hook_to_module(model, AlignDevicesHook(io_same_device=True), append=True)
205
+ attach_align_device_hook(
206
+ model,
207
+ execution_device=execution_device,
208
+ offload=True,
209
+ offload_buffers=offload_buffers,
210
+ weights_map=state_dict,
211
+ preload_module_classes=preload_module_classes,
212
+ )
213
+
214
+ return model
215
+
216
+
217
+ def cpu_offload_with_hook(
218
+ model: torch.nn.Module,
219
+ execution_device: Optional[Union[int, str, torch.device]] = None,
220
+ prev_module_hook: Optional[UserCpuOffloadHook] = None,
221
+ ):
222
+ """
223
+ Offloads a model on the CPU and puts it back to an execution device when executed. The difference with
224
+ [`cpu_offload`] is that the model stays on the execution device after the forward and is only offloaded again when
225
+ the `offload` method of the returned `hook` is called. Useful for pipelines running a model in a loop.
226
+
227
+ Args:
228
+ model (`torch.nn.Module`):
229
+ The model to offload.
230
+ execution_device(`str`, `int` or `torch.device`, *optional*):
231
+ The device on which the model should be executed. Will default to the MPS device if it's available, then
232
+ GPU 0 if there is a GPU, and finally to the CPU.
233
+ prev_module_hook (`UserCpuOffloadHook`, *optional*):
234
+ The hook sent back by this function for a previous model in the pipeline you are running. If passed, its
235
+ offload method will be called just before the forward of the model to which this hook is attached.
236
+
237
+ Example:
238
+
239
+ ```py
240
+ model_1, hook_1 = cpu_offload_with_hook(model_1, cuda_device)
241
+ model_2, hook_2 = cpu_offload_with_hook(model_2, cuda_device, prev_module_hook=hook_1)
242
+ model_3, hook_3 = cpu_offload_with_hook(model_3, cuda_device, prev_module_hook=hook_2)
243
+
244
+ hid_1 = model_1(input)
245
+ for i in range(50):
246
+ # model1 is offloaded on the CPU at the first iteration, model 2 stays on the GPU for this whole loop.
247
+ hid_2 = model_2(hid_1)
248
+ # model2 is offloaded to the CPU just before this forward.
249
+ hid_3 = model_3(hid_3)
250
+
251
+ # For model3, you need to manually call the hook offload method.
252
+ hook_3.offload()
253
+ ```
254
+ """
255
+ hook = CpuOffload(execution_device=execution_device, prev_module_hook=prev_module_hook)
256
+ add_hook_to_module(model, hook, append=True)
257
+ user_hook = UserCpuOffloadHook(model, hook)
258
+ return model, user_hook
259
+
260
+
261
+ def disk_offload(
262
+ model: nn.Module,
263
+ offload_dir: Union[str, os.PathLike],
264
+ execution_device: Optional[torch.device] = None,
265
+ offload_buffers: bool = False,
266
+ preload_module_classes: Optional[List[str]] = None,
267
+ ):
268
+ """
269
+ Activates full disk offload for a model. As a result, all parameters of the model will be offloaded as
270
+ memory-mapped array in a given folder. During the forward pass, parameters will be accessed from that folder and
271
+ put on the execution device passed as they are needed, then offloaded again.
272
+
273
+ Args:
274
+ model (`torch.nn.Module`): The model to offload.
275
+ offload_dir (`str` or `os.PathLike`):
276
+ The folder in which to offload the model weights (or where the model weights are already offloaded).
277
+ execution_device (`torch.device`, *optional*):
278
+ The device on which the forward pass of the model will be executed (should be a GPU). Will default to the
279
+ model's first parameter device.
280
+ offload_buffers (`bool`, *optional*, defaults to `False`):
281
+ Whether or not to offload the buffers with the model parameters.
282
+ preload_module_classes (`List[str]`, *optional*):
283
+ A list of classes whose instances should load all their weights (even in the submodules) at the beginning
284
+ of the forward. This should only be used for classes that have submodules which are registered but not
285
+ called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
286
+ `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
287
+ """
288
+ if not os.path.isdir(offload_dir) or not os.path.isfile(os.path.join(offload_dir, "index.json")):
289
+ offload_state_dict(offload_dir, model.state_dict())
290
+ if execution_device is None:
291
+ execution_device = next(iter(model.parameters())).device
292
+ weights_map = OffloadedWeightsLoader(save_folder=offload_dir)
293
+
294
+ add_hook_to_module(model, AlignDevicesHook(io_same_device=True), append=True)
295
+ attach_align_device_hook(
296
+ model,
297
+ execution_device=execution_device,
298
+ offload=True,
299
+ offload_buffers=offload_buffers,
300
+ weights_map=weights_map,
301
+ preload_module_classes=preload_module_classes,
302
+ )
303
+
304
+ return model
305
+
306
+
307
+ def dispatch_model(
308
+ model: nn.Module,
309
+ device_map: Dict[str, Union[str, int, torch.device]],
310
+ main_device: Optional[torch.device] = None,
311
+ state_dict: Optional[Dict[str, torch.Tensor]] = None,
312
+ offload_dir: Optional[Union[str, os.PathLike]] = None,
313
+ offload_index: Optional[Dict[str, str]] = None,
314
+ offload_buffers: bool = False,
315
+ skip_keys: Optional[Union[str, List[str]]] = None,
316
+ preload_module_classes: Optional[List[str]] = None,
317
+ force_hooks: bool = False,
318
+ ):
319
+ """
320
+ Dispatches a model according to a given device map. Layers of the model might be spread across GPUs, offloaded on
321
+ the CPU or even the disk.
322
+
323
+ Args:
324
+ model (`torch.nn.Module`):
325
+ The model to dispatch.
326
+ device_map (`Dict[str, Union[str, int, torch.device]]`):
327
+ A dictionary mapping module names in the models `state_dict` to the device they should go to. Note that
328
+ `"disk"` is accepted even if it's not a proper value for `torch.device`.
329
+ main_device (`str`, `int` or `torch.device`, *optional*):
330
+ The main execution device. Will default to the first device in the `device_map` different from `"cpu"` or
331
+ `"disk"`.
332
+ state_dict (`Dict[str, torch.Tensor]`, *optional*):
333
+ The state dict of the part of the model that will be kept on CPU.
334
+ offload_dir (`str` or `os.PathLike`):
335
+ The folder in which to offload the model weights (or where the model weights are already offloaded).
336
+ offload_index (`Dict`, *optional*):
337
+ A dictionary from weight name to their information (`dtype`/ `shape` or safetensors filename). Will default
338
+ to the index saved in `save_folder`.
339
+ offload_buffers (`bool`, *optional*, defaults to `False`):
340
+ Whether or not to offload the buffers with the model parameters.
341
+ skip_keys (`str` or `List[str]`, *optional*):
342
+ A list of keys to ignore when moving inputs or outputs between devices.
343
+ preload_module_classes (`List[str]`, *optional*):
344
+ A list of classes whose instances should load all their weights (even in the submodules) at the beginning
345
+ of the forward. This should only be used for classes that have submodules which are registered but not
346
+ called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
347
+ `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
348
+ force_hooks (`bool`, *optional*, defaults to `False`):
349
+ Whether or not to force device hooks to be attached to the model even if all layers are dispatched to a
350
+ single device.
351
+ """
352
+ # Error early if the device map is incomplete.
353
+ check_device_map(model, device_map)
354
+
355
+ # We need to force hook for quantized model that can't be moved with to()
356
+ if getattr(model, "quantization_method", "bitsandbytes") == "bitsandbytes":
357
+ # since bnb 0.43.2, we can move 4-bit model
358
+ if getattr(model, "is_loaded_in_8bit", False) or (
359
+ getattr(model, "is_loaded_in_4bit", False) and not is_bnb_available(min_version="0.43.2")
360
+ ):
361
+ force_hooks = True
362
+
363
+ # We attach hooks if the device_map has at least 2 different devices or if
364
+ # force_hooks is set to `True`. Otherwise, the model in already loaded
365
+ # in the unique device and the user can decide where to dispatch the model.
366
+ # If the model is quantized, we always force-dispatch the model
367
+ if (len(set(device_map.values())) > 1) or force_hooks:
368
+ if main_device is None:
369
+ if set(device_map.values()) == {"cpu"} or set(device_map.values()) == {"cpu", "disk"}:
370
+ main_device = "cpu"
371
+ else:
372
+ main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0]
373
+
374
+ if main_device != "cpu":
375
+ cpu_modules = [name for name, device in device_map.items() if device == "cpu"]
376
+ if state_dict is None and len(cpu_modules) > 0:
377
+ state_dict = extract_submodules_state_dict(model.state_dict(), cpu_modules)
378
+
379
+ disk_modules = [name for name, device in device_map.items() if device == "disk"]
380
+ if offload_dir is None and offload_index is None and len(disk_modules) > 0:
381
+ raise ValueError(
382
+ "We need an `offload_dir` to dispatch this model according to this `device_map`, the following submodules "
383
+ f"need to be offloaded: {', '.join(disk_modules)}."
384
+ )
385
+ if (
386
+ len(disk_modules) > 0
387
+ and offload_index is None
388
+ and (not os.path.isdir(offload_dir) or not os.path.isfile(os.path.join(offload_dir, "index.json")))
389
+ ):
390
+ disk_state_dict = extract_submodules_state_dict(model.state_dict(), disk_modules)
391
+ offload_state_dict(offload_dir, disk_state_dict)
392
+
393
+ execution_device = {
394
+ name: main_device if device in ["cpu", "disk"] else device for name, device in device_map.items()
395
+ }
396
+ execution_device[""] = main_device
397
+ offloaded_devices = ["disk"] if main_device == "cpu" or main_device == "mps" else ["cpu", "disk"]
398
+ offload = {name: device in offloaded_devices for name, device in device_map.items()}
399
+ save_folder = offload_dir if len(disk_modules) > 0 else None
400
+ if state_dict is not None or save_folder is not None or offload_index is not None:
401
+ device = main_device if offload_index is not None else None
402
+ weights_map = OffloadedWeightsLoader(
403
+ state_dict=state_dict, save_folder=save_folder, index=offload_index, device=device
404
+ )
405
+ else:
406
+ weights_map = None
407
+
408
+ # When dispatching the model's parameters to the devices specified in device_map, we want to avoid allocating memory several times for the
409
+ # tied parameters. The dictionary tied_params_map keeps track of the already allocated data for a given tied parameter (represented by its
410
+ # original pointer) on each devices.
411
+ tied_params = find_tied_parameters(model)
412
+
413
+ tied_params_map = {}
414
+ for group in tied_params:
415
+ for param_name in group:
416
+ # data_ptr() is enough here, as `find_tied_parameters` finds tied params simply by comparing `param1 is param2`, so we don't need
417
+ # to care about views of tensors through storage_offset.
418
+ data_ptr = recursive_getattr(model, param_name).data_ptr()
419
+ tied_params_map[data_ptr] = {}
420
+
421
+ # Note: To handle the disk offloading case, we can not simply use weights_map[param_name].data_ptr() as the reference pointer,
422
+ # as we have no guarantee that safetensors' `file.get_tensor()` will always give the same pointer.
423
+
424
+ attach_align_device_hook_on_blocks(
425
+ model,
426
+ execution_device=execution_device,
427
+ offload=offload,
428
+ offload_buffers=offload_buffers,
429
+ weights_map=weights_map,
430
+ skip_keys=skip_keys,
431
+ preload_module_classes=preload_module_classes,
432
+ tied_params_map=tied_params_map,
433
+ )
434
+
435
+ # warn if there is any params on the meta device
436
+ offloaded_devices_str = " and ".join(
437
+ [device for device in set(device_map.values()) if device in ("cpu", "disk")]
438
+ )
439
+ if len(offloaded_devices_str) > 0:
440
+ logger.warning(
441
+ f"Some parameters are on the meta device because they were offloaded to the {offloaded_devices_str}."
442
+ )
443
+
444
+ # Attaching the hook may break tied weights, so we retie them
445
+ retie_parameters(model, tied_params)
446
+
447
+ # add warning to cuda and to method
448
+ def add_warning(fn, model):
449
+ @wraps(fn)
450
+ def wrapper(*args, **kwargs):
451
+ warning_msg = "You shouldn't move a model that is dispatched using accelerate hooks."
452
+ if str(fn.__name__) == "to":
453
+ to_device = torch._C._nn._parse_to(*args, **kwargs)[0]
454
+ if to_device is not None:
455
+ logger.warning(warning_msg)
456
+ else:
457
+ logger.warning(warning_msg)
458
+ for param in model.parameters():
459
+ if param.device == torch.device("meta"):
460
+ raise RuntimeError("You can't move a model that has some modules offloaded to cpu or disk.")
461
+ return fn(*args, **kwargs)
462
+
463
+ return wrapper
464
+
465
+ # Make sure to update _accelerate_added_attributes in hooks.py if you add any hook
466
+ model.to = add_warning(model.to, model)
467
+ if is_npu_available():
468
+ model.npu = add_warning(model.npu, model)
469
+ elif is_mlu_available():
470
+ model.mlu = add_warning(model.mlu, model)
471
+ elif is_musa_available():
472
+ model.musa = add_warning(model.musa, model)
473
+ elif is_xpu_available():
474
+ model.xpu = add_warning(model.xpu, model)
475
+ else:
476
+ model.cuda = add_warning(model.cuda, model)
477
+
478
+ # Check if we are using multi-gpus with RTX 4000 series
479
+ use_multi_gpu = len([device for device in set(device_map.values()) if device not in ("cpu", "disk")]) > 1
480
+ if use_multi_gpu and not check_cuda_p2p_ib_support():
481
+ logger.warning(
482
+ "We've detected an older driver with an RTX 4000 series GPU. These drivers have issues with P2P. "
483
+ "This can affect the multi-gpu inference when using accelerate device_map."
484
+ "Please make sure to update your driver to the latest version which resolves this."
485
+ )
486
+ else:
487
+ device = list(device_map.values())[0]
488
+ # `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
489
+ if is_npu_available() and isinstance(device, int):
490
+ device = f"npu:{device}"
491
+ elif is_mlu_available() and isinstance(device, int):
492
+ device = f"mlu:{device}"
493
+ elif is_musa_available() and isinstance(device, int):
494
+ device = f"musa:{device}"
495
+ elif is_xpu_available() and isinstance(device, int):
496
+ device = f"xpu:{device}"
497
+ if device != "disk":
498
+ model.to(device)
499
+ else:
500
+ raise ValueError(
501
+ "You are trying to offload the whole model to the disk. Please use the `disk_offload` function instead."
502
+ )
503
+ # Convert OrderedDict back to dict for easier usage
504
+ model.hf_device_map = dict(device_map)
505
+ return model
506
+
507
+
508
+ def load_checkpoint_and_dispatch(
509
+ model: nn.Module,
510
+ checkpoint: Union[str, os.PathLike],
511
+ device_map: Optional[Union[str, Dict[str, Union[int, str, torch.device]]]] = None,
512
+ max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None,
513
+ no_split_module_classes: Optional[List[str]] = None,
514
+ offload_folder: Optional[Union[str, os.PathLike]] = None,
515
+ offload_buffers: bool = False,
516
+ dtype: Optional[Union[str, torch.dtype]] = None,
517
+ offload_state_dict: Optional[bool] = None,
518
+ skip_keys: Optional[Union[str, List[str]]] = None,
519
+ preload_module_classes: Optional[List[str]] = None,
520
+ force_hooks: bool = False,
521
+ strict: bool = False,
522
+ ):
523
+ """
524
+ Loads a (potentially sharded) checkpoint inside a model, potentially sending weights to a given device as they are
525
+ loaded and adds the various hooks that will make this model run properly (even if split across devices).
526
+
527
+ Args:
528
+ model (`torch.nn.Module`): The model in which we want to load a checkpoint.
529
+ checkpoint (`str` or `os.PathLike`):
530
+ The folder checkpoint to load. It can be:
531
+ - a path to a file containing a whole model state dict
532
+ - a path to a `.json` file containing the index to a sharded checkpoint
533
+ - a path to a folder containing a unique `.index.json` file and the shards of a checkpoint.
534
+ device_map (`Dict[str, Union[int, str, torch.device]]`, *optional*):
535
+ A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer
536
+ name, once a given module name is inside, every submodule of it will be sent to the same device.
537
+
538
+ To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For more
539
+ information about each option see [here](../concept_guides/big_model_inference#designing-a-device-map).
540
+ Defaults to None, which means [`dispatch_model`] will not be called.
541
+ max_memory (`Dict`, *optional*):
542
+ A dictionary device identifier to maximum memory. Will default to the maximum memory available for each GPU
543
+ and the available CPU RAM if unset.
544
+ no_split_module_classes (`List[str]`, *optional*):
545
+ A list of layer class names that should never be split across device (for instance any layer that has a
546
+ residual connection).
547
+ offload_folder (`str` or `os.PathLike`, *optional*):
548
+ If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
549
+ offload_buffers (`bool`, *optional*, defaults to `False`):
550
+ In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as
551
+ well as the parameters.
552
+ dtype (`str` or `torch.dtype`, *optional*):
553
+ If provided, the weights will be converted to that type when loaded.
554
+ offload_state_dict (`bool`, *optional*):
555
+ If `True`, will temporarily offload the CPU state dict on the hard drive to avoid getting out of CPU RAM if
556
+ the weight of the CPU state dict + the biggest shard does not fit. Will default to `True` if the device map
557
+ picked contains `"disk"` values.
558
+ skip_keys (`str` or `List[str]`, *optional*):
559
+ A list of keys to ignore when moving inputs or outputs between devices.
560
+ preload_module_classes (`List[str]`, *optional*):
561
+ A list of classes whose instances should load all their weights (even in the submodules) at the beginning
562
+ of the forward. This should only be used for classes that have submodules which are registered but not
563
+ called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
564
+ `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
565
+ force_hooks (`bool`, *optional*, defaults to `False`):
566
+ Whether or not to force device hooks to be attached to the model even if all layers are dispatched to a
567
+ single device.
568
+ strict (`bool`, *optional*, defaults to `False`):
569
+ Whether to strictly enforce that the keys in the checkpoint state_dict match the keys of the model's
570
+ state_dict.
571
+
572
+ Example:
573
+
574
+ ```python
575
+ >>> from accelerate import init_empty_weights, load_checkpoint_and_dispatch
576
+ >>> from huggingface_hub import hf_hub_download
577
+ >>> from transformers import AutoConfig, AutoModelForCausalLM
578
+
579
+ >>> # Download the Weights
580
+ >>> checkpoint = "EleutherAI/gpt-j-6B"
581
+ >>> weights_location = hf_hub_download(checkpoint, "pytorch_model.bin")
582
+
583
+ >>> # Create a model and initialize it with empty weights
584
+ >>> config = AutoConfig.from_pretrained(checkpoint)
585
+ >>> with init_empty_weights():
586
+ ... model = AutoModelForCausalLM.from_config(config)
587
+
588
+ >>> # Load the checkpoint and dispatch it to the right devices
589
+ >>> model = load_checkpoint_and_dispatch(
590
+ ... model, weights_location, device_map="auto", no_split_module_classes=["GPTJBlock"]
591
+ ... )
592
+ ```
593
+ """
594
+ if isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
595
+ raise ValueError(
596
+ "If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or "
597
+ "'sequential'."
598
+ )
599
+ if isinstance(device_map, str):
600
+ if device_map != "sequential":
601
+ max_memory = get_balanced_memory(
602
+ model,
603
+ max_memory=max_memory,
604
+ no_split_module_classes=no_split_module_classes,
605
+ dtype=dtype,
606
+ low_zero=(device_map == "balanced_low_0"),
607
+ )
608
+ device_map = infer_auto_device_map(
609
+ model,
610
+ max_memory=max_memory,
611
+ no_split_module_classes=no_split_module_classes,
612
+ dtype=dtype,
613
+ offload_buffers=offload_buffers,
614
+ )
615
+ if offload_state_dict is None and device_map is not None and "disk" in device_map.values():
616
+ offload_state_dict = True
617
+ load_checkpoint_in_model(
618
+ model,
619
+ checkpoint,
620
+ device_map=device_map,
621
+ offload_folder=offload_folder,
622
+ dtype=dtype,
623
+ offload_state_dict=offload_state_dict,
624
+ offload_buffers=offload_buffers,
625
+ strict=strict,
626
+ )
627
+ if device_map is None:
628
+ return model
629
+ return dispatch_model(
630
+ model,
631
+ device_map=device_map,
632
+ offload_dir=offload_folder,
633
+ offload_buffers=offload_buffers,
634
+ skip_keys=skip_keys,
635
+ preload_module_classes=preload_module_classes,
636
+ force_hooks=force_hooks,
637
+ )
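`big_modeling.py` above provides the big-model-inference utilities (`init_empty_weights`, `infer_auto_device_map`, `dispatch_model`, `load_checkpoint_and_dispatch`, and the offload helpers). As a small sketch of how a placement plan is usually produced before any weights exist, with made-up memory limits:

```python
import torch.nn as nn
from accelerate import infer_auto_device_map, init_empty_weights

# Build the architecture on the meta device so no real memory is allocated.
with init_empty_weights():
    model = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(16)])

# Plan which submodules go to GPU 0, CPU RAM, or disk under invented limits;
# anything that does not fit on the GPU spills to CPU, and then to disk.
device_map = infer_auto_device_map(model, max_memory={0: "30MB", "cpu": "30MB"})
print(device_map)  # e.g. {"0": 0, ..., "8": "cpu", ..., "15": "disk"}, depending on the limits

# Real weights would then be loaded and placed with load_checkpoint_and_dispatch(),
# as in the docstring example above.
```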
.venv/Lib/site-packages/accelerate/checkpointing.py ADDED
@@ -0,0 +1,306 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import random
16
+ from pathlib import Path
17
+ from typing import List
18
+
19
+ import numpy as np
20
+ import torch
21
+ from safetensors.torch import load_model
22
+ from torch.cuda.amp import GradScaler
23
+
24
+ from .utils import (
25
+ MODEL_NAME,
26
+ OPTIMIZER_NAME,
27
+ RNG_STATE_NAME,
28
+ SAFE_MODEL_NAME,
29
+ SAFE_WEIGHTS_NAME,
30
+ SAMPLER_NAME,
31
+ SCALER_NAME,
32
+ SCHEDULER_NAME,
33
+ WEIGHTS_NAME,
34
+ get_pretty_name,
35
+ is_mlu_available,
36
+ is_torch_xla_available,
37
+ is_xpu_available,
38
+ load,
39
+ save,
40
+ )
41
+
42
+
43
+ if is_torch_xla_available():
44
+ import torch_xla.core.xla_model as xm
45
+
46
+ from .logging import get_logger
47
+ from .state import PartialState
48
+
49
+
50
+ logger = get_logger(__name__)
51
+
52
+
53
+ def save_accelerator_state(
54
+ output_dir: str,
55
+ model_states: List[dict],
56
+ optimizers: list,
57
+ schedulers: list,
58
+ dataloaders: list,
59
+ process_index: int,
60
+ step: int,
61
+ scaler: GradScaler = None,
62
+ save_on_each_node: bool = False,
63
+ safe_serialization: bool = True,
64
+ ):
65
+ """
66
+ Saves the current states of the models, optimizers, scaler, and RNG generators to a given directory.
67
+
68
+ <Tip>
69
+
70
+ If `safe_serialization` is `True`, models will be saved with `safetensors` while the rest are saved using native
71
+ `pickle`.
72
+
73
+ </Tip>
74
+
75
+ Args:
76
+ output_dir (`str` or `os.PathLike`):
77
+ The name of the folder to save all relevant weights and states.
78
+ model_states (`List[torch.nn.Module]`):
79
+ A list of model states
80
+ optimizers (`List[torch.optim.Optimizer]`):
81
+ A list of optimizer instances
82
+ schedulers (`List[torch.optim.lr_scheduler._LRScheduler]`):
83
+ A list of learning rate schedulers
84
+ dataloaders (`List[torch.utils.data.DataLoader]`):
85
+ A list of dataloader instances to save their sampler states
86
+ process_index (`int`):
87
+ The current process index in the Accelerator state
88
+ step (`int`):
89
+ The current step in the internal step tracker
90
+ scaler (`torch.amp.GradScaler`, *optional*):
91
+ An optional gradient scaler instance to save;
92
+ save_on_each_node (`bool`, *optional*):
93
+ Whether to save on every node, or only the main node.
94
+ safe_serialization (`bool`, *optional*, defaults to `True`):
95
+ Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
96
+ """
97
+ output_dir = Path(output_dir)
98
+ # Model states
99
+ for i, state in enumerate(model_states):
100
+ weights_name = WEIGHTS_NAME if not safe_serialization else SAFE_WEIGHTS_NAME
101
+ if i > 0:
102
+ weights_name = weights_name.replace(".", f"_{i}.")
103
+ output_model_file = output_dir.joinpath(weights_name)
104
+ save(state, output_model_file, save_on_each_node=save_on_each_node, safe_serialization=safe_serialization)
105
+ logger.info(f"Model weights saved in {output_model_file}")
106
+ # Optimizer states
107
+ for i, opt in enumerate(optimizers):
108
+ state = opt.state_dict()
109
+ optimizer_name = f"{OPTIMIZER_NAME}.bin" if i == 0 else f"{OPTIMIZER_NAME}_{i}.bin"
110
+ output_optimizer_file = output_dir.joinpath(optimizer_name)
111
+ save(state, output_optimizer_file, save_on_each_node=save_on_each_node, safe_serialization=False)
112
+ logger.info(f"Optimizer state saved in {output_optimizer_file}")
113
+ # Scheduler states
114
+ for i, scheduler in enumerate(schedulers):
115
+ state = scheduler.state_dict()
116
+ scheduler_name = f"{SCHEDULER_NAME}.bin" if i == 0 else f"{SCHEDULER_NAME}_{i}.bin"
117
+ output_scheduler_file = output_dir.joinpath(scheduler_name)
118
+ save(state, output_scheduler_file, save_on_each_node=save_on_each_node, safe_serialization=False)
119
+ logger.info(f"Scheduler state saved in {output_scheduler_file}")
120
+ # DataLoader states
121
+ for i, dataloader in enumerate(dataloaders):
122
+ sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
123
+ output_sampler_file = output_dir.joinpath(sampler_name)
124
+ # Only save if we have our custom sampler
125
+ from .data_loader import IterableDatasetShard, SeedableRandomSampler
126
+
127
+ if isinstance(dataloader.dataset, IterableDatasetShard):
128
+ sampler = dataloader.get_sampler()
129
+ if isinstance(sampler, SeedableRandomSampler):
130
+ save(sampler, output_sampler_file, save_on_each_node=save_on_each_node, safe_serialization=False)
131
+ if getattr(dataloader, "use_stateful_dataloader", False):
132
+ dataloader_state_dict_name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"
133
+ output_dataloader_state_dict_file = output_dir.joinpath(dataloader_state_dict_name)
134
+ state_dict = dataloader.state_dict()
135
+ torch.save(state_dict, output_dataloader_state_dict_file)
136
+ logger.info(f"Sampler state for dataloader {i} saved in {output_sampler_file}")
137
+
138
+ # GradScaler state
139
+ if scaler is not None:
140
+ state = scaler.state_dict()
141
+ output_scaler_file = output_dir.joinpath(SCALER_NAME)
142
+ torch.save(state, output_scaler_file)
143
+ logger.info(f"Gradient scaler state saved in {output_scaler_file}")
144
+ # Random number generator states
145
+ states = {}
146
+ states_name = f"{RNG_STATE_NAME}_{process_index}.pkl"
147
+ states["step"] = step
148
+ states["random_state"] = random.getstate()
149
+ states["numpy_random_seed"] = np.random.get_state()
150
+ states["torch_manual_seed"] = torch.get_rng_state()
151
+ if is_xpu_available():
152
+ states["torch_xpu_manual_seed"] = torch.xpu.get_rng_state_all()
153
+ if is_mlu_available():
154
+ states["torch_mlu_manual_seed"] = torch.mlu.get_rng_state_all()
155
+ else:
156
+ states["torch_cuda_manual_seed"] = torch.cuda.get_rng_state_all()
157
+ if is_torch_xla_available():
158
+ states["xm_seed"] = xm.get_rng_state()
159
+ output_states_file = output_dir.joinpath(states_name)
160
+ torch.save(states, output_states_file)
161
+ logger.info(f"Random states saved in {output_states_file}")
162
+ return output_dir
163
+
164
+
165
+ def load_accelerator_state(
166
+ input_dir,
167
+ models,
168
+ optimizers,
169
+ schedulers,
170
+ dataloaders,
171
+ process_index,
172
+ scaler=None,
173
+ map_location=None,
174
+ **load_model_func_kwargs,
175
+ ):
176
+ """
177
+ Loads states of the models, optimizers, scaler, and RNG generators from a given directory.
178
+
179
+ Args:
180
+ input_dir (`str` or `os.PathLike`):
181
+ The name of the folder to load all relevant weights and states.
182
+ models (`List[torch.nn.Module]`):
183
+ A list of model instances
184
+ optimizers (`List[torch.optim.Optimizer]`):
185
+ A list of optimizer instances
186
+ schedulers (`List[torch.optim.lr_scheduler._LRScheduler]`):
187
+ A list of learning rate schedulers
188
+ process_index (`int`):
189
+ The current process index in the Accelerator state
190
+ scaler (`torch.amp.GradScaler`, *optional*):
191
+ An optional *GradScaler* instance to load
192
+ map_location (`str`, *optional*):
193
+ What device to load the optimizer state onto. Should be one of either "cpu" or "on_device".
194
+ load_model_func_kwargs (`dict`, *optional*):
195
+ Additional arguments that can be passed to the model's `load_state_dict` method.
196
+
197
+ Returns:
198
+ `dict`: Contains the `Accelerator` attributes to override while loading the state.
199
+ """
200
+ # stores the `Accelerator` attributes to override
201
+ override_attributes = dict()
202
+ if map_location not in [None, "cpu", "on_device"]:
203
+ raise TypeError(
204
+ "Unsupported optimizer map location passed, please choose one of `None`, `'cpu'`, or `'on_device'`"
205
+ )
206
+ if map_location is None:
207
+ map_location = "cpu"
208
+ elif map_location == "on_device":
209
+ map_location = PartialState().device
210
+
211
+ input_dir = Path(input_dir)
212
+ # Model states
213
+ for i, model in enumerate(models):
214
+ ending = f"_{i}" if i > 0 else ""
215
+ input_model_file = input_dir.joinpath(f"{SAFE_MODEL_NAME}{ending}.safetensors")
216
+ if input_model_file.exists():
217
+ load_model(model, input_model_file, device=str(map_location), **load_model_func_kwargs)
218
+ else:
219
+ # Load with torch
220
+ input_model_file = input_dir.joinpath(f"{MODEL_NAME}{ending}.bin")
221
+ state_dict = load(input_model_file, map_location=map_location)
222
+ model.load_state_dict(state_dict, **load_model_func_kwargs)
223
+ logger.info("All model weights loaded successfully")
224
+
225
+ # Optimizer states
226
+ for i, opt in enumerate(optimizers):
227
+ optimizer_name = f"{OPTIMIZER_NAME}.bin" if i == 0 else f"{OPTIMIZER_NAME}_{i}.bin"
228
+ input_optimizer_file = input_dir.joinpath(optimizer_name)
229
+ optimizer_state = load(input_optimizer_file, map_location=map_location)
230
+ optimizers[i].load_state_dict(optimizer_state)
231
+ logger.info("All optimizer states loaded successfully")
232
+
233
+ # Scheduler states
234
+ for i, scheduler in enumerate(schedulers):
235
+ scheduler_name = f"{SCHEDULER_NAME}.bin" if i == 0 else f"{SCHEDULER_NAME}_{i}.bin"
236
+ input_scheduler_file = input_dir.joinpath(scheduler_name)
237
+ scheduler_state = load(input_scheduler_file)
238
+ scheduler.load_state_dict(scheduler_state)
239
+ logger.info("All scheduler states loaded successfully")
240
+
241
+ for i, dataloader in enumerate(dataloaders):
242
+ sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin"
243
+ input_sampler_file = input_dir.joinpath(sampler_name)
244
+ # Only load if we have our custom sampler
245
+ from .data_loader import IterableDatasetShard, SeedableRandomSampler
246
+
247
+ if isinstance(dataloader.dataset, IterableDatasetShard):
248
+ sampler = dataloader.get_sampler()
249
+ if isinstance(sampler, SeedableRandomSampler):
250
+ sampler = dataloader.set_sampler(load(input_sampler_file))
251
+ if getattr(dataloader, "use_stateful_dataloader", False):
252
+ dataloader_state_dict_name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"
253
+ input_dataloader_state_dict_file = input_dir.joinpath(dataloader_state_dict_name)
254
+ if input_dataloader_state_dict_file.exists():
255
+ state_dict = load(input_dataloader_state_dict_file)
256
+ dataloader.load_state_dict(state_dict)
257
+ logger.info("All dataloader sampler states loaded successfully")
258
+
259
+ # GradScaler state
260
+ if scaler is not None:
261
+ input_scaler_file = input_dir.joinpath(SCALER_NAME)
262
+ scaler_state = load(input_scaler_file)
263
+ scaler.load_state_dict(scaler_state)
264
+ logger.info("GradScaler state loaded successfully")
265
+
266
+ # Random states
267
+ try:
268
+ states = load(input_dir.joinpath(f"{RNG_STATE_NAME}_{process_index}.pkl"))
269
+ if "step" in states:
270
+ override_attributes["step"] = states["step"]
271
+ random.setstate(states["random_state"])
272
+ np.random.set_state(states["numpy_random_seed"])
273
+ torch.set_rng_state(states["torch_manual_seed"])
274
+ if is_xpu_available():
275
+ torch.xpu.set_rng_state_all(states["torch_xpu_manual_seed"])
276
+ if is_mlu_available():
277
+ torch.mlu.set_rng_state_all(states["torch_mlu_manual_seed"])
278
+ else:
279
+ torch.cuda.set_rng_state_all(states["torch_cuda_manual_seed"])
280
+ if is_torch_xla_available():
281
+ xm.set_rng_state(states["xm_seed"])
282
+ logger.info("All random states loaded successfully")
283
+ except Exception:
284
+ logger.info("Could not load random states")
285
+
286
+ return override_attributes
287
+
288
+
289
+ def save_custom_state(obj, path, index: int = 0, save_on_each_node: bool = False):
290
+ """
291
+ Saves the state of `obj` to `{path}/custom_checkpoint_{index}.pkl`
292
+ """
293
+ # Should this be the right way to get a qual_name type value from `obj`?
294
+ save_location = Path(path) / f"custom_checkpoint_{index}.pkl"
295
+ logger.info(f"Saving the state of {get_pretty_name(obj)} to {save_location}")
296
+ save(obj.state_dict(), save_location, save_on_each_node=save_on_each_node)
297
+
298
+
299
+ def load_custom_state(obj, path, index: int = 0):
300
+ """
301
+ Loads the state of `obj` at `{path}/custom_checkpoint_{index}.pkl`. Will always set `weights_only=False` when
302
+ loading the state.
303
+ """
304
+ load_location = f"{path}/custom_checkpoint_{index}.pkl"
305
+ logger.info(f"Loading the state of {get_pretty_name(obj)} from {load_location}")
306
+ obj.load_state_dict(load(load_location, map_location="cpu", weights_only=False))
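For orientation, `save_accelerator_state` and `load_accelerator_state` above are the internals behind the public `Accelerator.save_state` / `Accelerator.load_state` API. A minimal sketch of that public entry point (illustrative only, not from the diff; the toy model, optimizer, and checkpoint directory name are assumptions):

from accelerate import Accelerator
import torch

accelerator = Accelerator()
model = torch.nn.Linear(4, 2)                             # toy model (illustrative)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)   # toy optimizer (illustrative)
model, optimizer = accelerator.prepare(model, optimizer)

accelerator.save_state("checkpoint_dir")   # routes through save_accelerator_state(...)
accelerator.load_state("checkpoint_dir")   # routes through load_accelerator_state(...)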
.venv/Lib/site-packages/accelerate/commands/__pycache__/accelerate_cli.cpython-39.pyc ADDED
Binary file (1.31 kB).
 
.venv/Lib/site-packages/accelerate/commands/accelerate_cli.py ADDED
@@ -0,0 +1,52 @@
+ #!/usr/bin/env python
+
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from accelerate.commands.config import get_config_parser
+ from accelerate.commands.env import env_command_parser
+ from accelerate.commands.estimate import estimate_command_parser
+ from accelerate.commands.launch import launch_command_parser
+ from accelerate.commands.merge import merge_command_parser
+ from accelerate.commands.test import test_command_parser
+ from accelerate.commands.tpu import tpu_command_parser
+ from accelerate.commands.utils import CustomArgumentParser
+
+
+ def main():
+     parser = CustomArgumentParser("Accelerate CLI tool", usage="accelerate <command> [<args>]", allow_abbrev=False)
+     subparsers = parser.add_subparsers(help="accelerate command helpers")
+
+     # Register commands
+     get_config_parser(subparsers=subparsers)
+     estimate_command_parser(subparsers=subparsers)
+     env_command_parser(subparsers=subparsers)
+     launch_command_parser(subparsers=subparsers)
+     merge_command_parser(subparsers=subparsers)
+     tpu_command_parser(subparsers=subparsers)
+     test_command_parser(subparsers=subparsers)
+
+     # Let's go
+     args = parser.parse_args()
+
+     if not hasattr(args, "func"):
+         parser.print_help()
+         exit(1)
+
+     # Run
+     args.func(args)
+
+
+ if __name__ == "__main__":
+     main()
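For reference, `main()` above is what the `accelerate` console script points at. A minimal sketch of driving it programmatically (illustrative only; the argument list is an assumption, equivalent to running `accelerate env` in a shell):

import sys
from accelerate.commands.accelerate_cli import main

sys.argv = ["accelerate", "env"]   # pretend command-line arguments
main()                             # dispatches to the registered `env` subcommand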
.venv/Lib/site-packages/accelerate/commands/config/__init__.py ADDED
@@ -0,0 +1,52 @@
+ #!/usr/bin/env python
+
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import argparse
+
+ from .config import config_command_parser
+ from .config_args import default_config_file, load_config_from_file  # noqa: F401
+ from .default import default_command_parser
+ from .update import update_command_parser
+
+
+ def get_config_parser(subparsers=None):
+     parent_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
+     # The main config parser
+     config_parser = config_command_parser(subparsers)
+     # The subparser to add commands to
+     subcommands = config_parser.add_subparsers(title="subcommands", dest="subcommand")
+
+     # Then add other parsers with the parent parser
+     default_command_parser(subcommands, parents=[parent_parser])
+     update_command_parser(subcommands, parents=[parent_parser])
+
+     return config_parser
+
+
+ def main():
+     config_parser = get_config_parser()
+     args = config_parser.parse_args()
+
+     if not hasattr(args, "func"):
+         config_parser.print_help()
+         exit(1)
+
+     # Run
+     args.func(args)
+
+
+ if __name__ == "__main__":
+     main()
.venv/Lib/site-packages/accelerate/commands/config/__pycache__/cluster.cpython-39.pyc ADDED
Binary file (17.7 kB).
 
.venv/Lib/site-packages/accelerate/commands/config/__pycache__/config.cpython-39.pyc ADDED
Binary file (2.43 kB).
 
.venv/Lib/site-packages/accelerate/commands/config/__pycache__/config_args.cpython-39.pyc ADDED
Binary file (7.52 kB).
 
.venv/Lib/site-packages/accelerate/commands/config/__pycache__/config_utils.cpython-39.pyc ADDED
Binary file (3.05 kB).
 
.venv/Lib/site-packages/accelerate/commands/config/__pycache__/update.cpython-39.pyc ADDED
Binary file (1.86 kB).
 
.venv/Lib/site-packages/accelerate/commands/config/config_args.py ADDED
@@ -0,0 +1,252 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import json
18
+ import os
19
+ from dataclasses import dataclass
20
+ from enum import Enum
21
+ from typing import List, Optional, Union
22
+
23
+ import yaml
24
+
25
+ from ...utils import ComputeEnvironment, DistributedType, SageMakerDistributedType
26
+ from ...utils.constants import SAGEMAKER_PYTHON_VERSION, SAGEMAKER_PYTORCH_VERSION, SAGEMAKER_TRANSFORMERS_VERSION
27
+
28
+
29
+ hf_cache_home = os.path.expanduser(
30
+ os.environ.get("HF_HOME", os.path.join(os.environ.get("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
31
+ )
32
+ cache_dir = os.path.join(hf_cache_home, "accelerate")
33
+ default_json_config_file = os.path.join(cache_dir, "default_config.yaml")
34
+ default_yaml_config_file = os.path.join(cache_dir, "default_config.yaml")
35
+
36
+ # For backward compatibility: the default config is the json one if it's the only existing file.
37
+ if os.path.isfile(default_yaml_config_file) or not os.path.isfile(default_json_config_file):
38
+ default_config_file = default_yaml_config_file
39
+ else:
40
+ default_config_file = default_json_config_file
41
+
42
+
43
+ def load_config_from_file(config_file):
44
+ if config_file is not None:
45
+ if not os.path.isfile(config_file):
46
+ raise FileNotFoundError(
47
+ f"The passed configuration file `{config_file}` does not exist. "
48
+ "Please pass an existing file to `accelerate launch`, or use the default one "
49
+ "created through `accelerate config` and run `accelerate launch` "
50
+ "without the `--config_file` argument."
51
+ )
52
+ else:
53
+ config_file = default_config_file
54
+ with open(config_file, encoding="utf-8") as f:
55
+ if config_file.endswith(".json"):
56
+ if (
57
+ json.load(f).get("compute_environment", ComputeEnvironment.LOCAL_MACHINE)
58
+ == ComputeEnvironment.LOCAL_MACHINE
59
+ ):
60
+ config_class = ClusterConfig
61
+ else:
62
+ config_class = SageMakerConfig
63
+ return config_class.from_json_file(json_file=config_file)
64
+ else:
65
+ if (
66
+ yaml.safe_load(f).get("compute_environment", ComputeEnvironment.LOCAL_MACHINE)
67
+ == ComputeEnvironment.LOCAL_MACHINE
68
+ ):
69
+ config_class = ClusterConfig
70
+ else:
71
+ config_class = SageMakerConfig
72
+ return config_class.from_yaml_file(yaml_file=config_file)
73
+
74
+
75
+ @dataclass
76
+ class BaseConfig:
77
+ compute_environment: ComputeEnvironment
78
+ distributed_type: Union[DistributedType, SageMakerDistributedType]
79
+ mixed_precision: str
80
+ use_cpu: bool
81
+ debug: bool
82
+
83
+ def to_dict(self):
84
+ result = self.__dict__
85
+ # For serialization, it's best to convert Enums to strings (or their underlying value type).
86
+
87
+ def _convert_enums(value):
88
+ if isinstance(value, Enum):
89
+ return value.value
90
+ if isinstance(value, dict):
91
+ if not bool(value):
92
+ return None
93
+ for key1, value1 in value.items():
94
+ value[key1] = _convert_enums(value1)
95
+ return value
96
+
97
+ for key, value in result.items():
98
+ result[key] = _convert_enums(value)
99
+ result = {k: v for k, v in result.items() if v is not None}
100
+ return result
101
+
102
+ @staticmethod
103
+ def process_config(config_dict):
104
+ """
105
+ Processes `config_dict` and sets default values for any missing keys
106
+ """
107
+ if "compute_environment" not in config_dict:
108
+ config_dict["compute_environment"] = ComputeEnvironment.LOCAL_MACHINE
109
+ if "distributed_type" not in config_dict:
110
+ raise ValueError("A `distributed_type` must be specified in the config file.")
111
+ if "num_processes" not in config_dict and config_dict["distributed_type"] == DistributedType.NO:
112
+ config_dict["num_processes"] = 1
113
+ if "mixed_precision" not in config_dict:
114
+ config_dict["mixed_precision"] = "fp16" if ("fp16" in config_dict and config_dict["fp16"]) else None
115
+ if "fp16" in config_dict: # Convert the config to the new format.
116
+ del config_dict["fp16"]
117
+ if "dynamo_backend" in config_dict: # Convert the config to the new format.
118
+ dynamo_backend = config_dict.pop("dynamo_backend")
119
+ config_dict["dynamo_config"] = {} if dynamo_backend == "NO" else {"dynamo_backend": dynamo_backend}
120
+ if "use_cpu" not in config_dict:
121
+ config_dict["use_cpu"] = False
122
+ if "debug" not in config_dict:
123
+ config_dict["debug"] = False
124
+ if "enable_cpu_affinity" not in config_dict:
125
+ config_dict["enable_cpu_affinity"] = False
126
+ return config_dict
127
+
128
+ @classmethod
129
+ def from_json_file(cls, json_file=None):
130
+ json_file = default_json_config_file if json_file is None else json_file
131
+ with open(json_file, encoding="utf-8") as f:
132
+ config_dict = json.load(f)
133
+ config_dict = cls.process_config(config_dict)
134
+ extra_keys = sorted(set(config_dict.keys()) - set(cls.__dataclass_fields__.keys()))
135
+ if len(extra_keys) > 0:
136
+ raise ValueError(
137
+ f"The config file at {json_file} had unknown keys ({extra_keys}), please try upgrading your `accelerate`"
138
+ " version or fix (and potentially remove) these keys from your config file."
139
+ )
140
+
141
+ return cls(**config_dict)
142
+
143
+ def to_json_file(self, json_file):
144
+ with open(json_file, "w", encoding="utf-8") as f:
145
+ content = json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
146
+ f.write(content)
147
+
148
+ @classmethod
149
+ def from_yaml_file(cls, yaml_file=None):
150
+ yaml_file = default_yaml_config_file if yaml_file is None else yaml_file
151
+ with open(yaml_file, encoding="utf-8") as f:
152
+ config_dict = yaml.safe_load(f)
153
+ config_dict = cls.process_config(config_dict)
154
+ extra_keys = sorted(set(config_dict.keys()) - set(cls.__dataclass_fields__.keys()))
155
+ if len(extra_keys) > 0:
156
+ raise ValueError(
157
+ f"The config file at {yaml_file} had unknown keys ({extra_keys}), please try upgrading your `accelerate`"
158
+ " version or fix (and potentially remove) these keys from your config file."
159
+ )
160
+ return cls(**config_dict)
161
+
162
+ def to_yaml_file(self, yaml_file):
163
+ with open(yaml_file, "w", encoding="utf-8") as f:
164
+ yaml.safe_dump(self.to_dict(), f)
165
+
166
+ def __post_init__(self):
167
+ if isinstance(self.compute_environment, str):
168
+ self.compute_environment = ComputeEnvironment(self.compute_environment)
169
+ if isinstance(self.distributed_type, str):
170
+ if self.compute_environment == ComputeEnvironment.AMAZON_SAGEMAKER:
171
+ self.distributed_type = SageMakerDistributedType(self.distributed_type)
172
+ else:
173
+ self.distributed_type = DistributedType(self.distributed_type)
174
+ if getattr(self, "dynamo_config", None) is None:
175
+ self.dynamo_config = {}
176
+
177
+
178
+ @dataclass
179
+ class ClusterConfig(BaseConfig):
180
+ num_processes: int = -1 # For instance if we use SLURM and the user manually passes it in
181
+ machine_rank: int = 0
182
+ num_machines: int = 1
183
+ gpu_ids: Optional[str] = None
184
+ main_process_ip: Optional[str] = None
185
+ main_process_port: Optional[int] = None
186
+ rdzv_backend: Optional[str] = "static"
187
+ same_network: Optional[bool] = False
188
+ main_training_function: str = "main"
189
+ enable_cpu_affinity: bool = False
190
+
191
+ # args for FP8 training
192
+ fp8_config: dict = None
193
+ # args for deepspeed_plugin
194
+ deepspeed_config: dict = None
195
+ # args for fsdp
196
+ fsdp_config: dict = None
197
+ # args for megatron_lm
198
+ megatron_lm_config: dict = None
199
+ # args for ipex
200
+ ipex_config: dict = None
201
+ # args for mpirun
202
+ mpirun_config: dict = None
203
+ # args for TPU
204
+ downcast_bf16: bool = False
205
+
206
+ # args for TPU pods
207
+ tpu_name: str = None
208
+ tpu_zone: str = None
209
+ tpu_use_cluster: bool = False
210
+ tpu_use_sudo: bool = False
211
+ command_file: str = None
212
+ commands: List[str] = None
213
+ tpu_vm: List[str] = None
214
+ tpu_env: List[str] = None
215
+
216
+ # args for dynamo
217
+ dynamo_config: dict = None
218
+
219
+ def __post_init__(self):
220
+ if self.deepspeed_config is None:
221
+ self.deepspeed_config = {}
222
+ if self.fsdp_config is None:
223
+ self.fsdp_config = {}
224
+ if self.megatron_lm_config is None:
225
+ self.megatron_lm_config = {}
226
+ if self.ipex_config is None:
227
+ self.ipex_config = {}
228
+ if self.mpirun_config is None:
229
+ self.mpirun_config = {}
230
+ if self.fp8_config is None:
231
+ self.fp8_config = {}
232
+ return super().__post_init__()
233
+
234
+
235
+ @dataclass
236
+ class SageMakerConfig(BaseConfig):
237
+ ec2_instance_type: str
238
+ iam_role_name: str
239
+ image_uri: Optional[str] = None
240
+ profile: Optional[str] = None
241
+ region: str = "us-east-1"
242
+ num_machines: int = 1
243
+ gpu_ids: str = "all"
244
+ base_job_name: str = f"accelerate-sagemaker-{num_machines}"
245
+ pytorch_version: str = SAGEMAKER_PYTORCH_VERSION
246
+ transformers_version: str = SAGEMAKER_TRANSFORMERS_VERSION
247
+ py_version: str = SAGEMAKER_PYTHON_VERSION
248
+ sagemaker_inputs_file: str = None
249
+ sagemaker_metrics_file: str = None
250
+ additional_args: dict = None
251
+ dynamo_config: dict = None
252
+ enable_cpu_affinity: bool = False
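As a usage note, the config loader defined above is also importable on its own. A minimal sketch (illustrative only; it assumes a default config file has already been created, e.g. with `accelerate config`, otherwise opening the missing file raises an error):

from accelerate.commands.config import load_config_from_file

config = load_config_from_file(None)   # None falls back to the default config file path
print(config.to_dict())                # ClusterConfig or SageMakerConfig as a plain dict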
.venv/Lib/site-packages/accelerate/commands/config/default.py ADDED
@@ -0,0 +1,142 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from pathlib import Path
18
+
19
+ import torch
20
+
21
+ from ...utils import is_mlu_available, is_musa_available, is_npu_available, is_xpu_available
22
+ from .config_args import ClusterConfig, default_json_config_file
23
+ from .config_utils import SubcommandHelpFormatter
24
+
25
+
26
+ description = "Create a default config file for Accelerate with only a few flags set."
27
+
28
+
29
+ def write_basic_config(mixed_precision="no", save_location: str = default_json_config_file, use_xpu: bool = False):
30
+ """
31
+ Creates and saves a basic cluster config to be used on a local machine with potentially multiple GPUs. Will also
32
+ set CPU if it is a CPU-only machine.
33
+
34
+ Args:
35
+ mixed_precision (`str`, *optional*, defaults to "no"):
36
+ Mixed Precision to use. Should be one of "no", "fp16", or "bf16"
37
+ save_location (`str`, *optional*, defaults to `default_json_config_file`):
38
+ Optional custom save location. Should be passed to `--config_file` when using `accelerate launch`. Default
39
+ location is inside the huggingface cache folder (`~/.cache/huggingface`) but can be overridden by setting
40
+ the `HF_HOME` environment variable, followed by `accelerate/default_config.yaml`.
41
+ use_xpu (`bool`, *optional*, defaults to `False`):
42
+ Whether to use XPU if available.
43
+ """
44
+ path = Path(save_location)
45
+ path.parent.mkdir(parents=True, exist_ok=True)
46
+ if path.exists():
47
+ print(
48
+ f"Configuration already exists at {save_location}, will not override. Run `accelerate config` manually or pass a different `save_location`."
49
+ )
50
+ return False
51
+ mixed_precision = mixed_precision.lower()
52
+ if mixed_precision not in ["no", "fp16", "bf16", "fp8"]:
53
+ raise ValueError(
54
+ f"`mixed_precision` should be one of 'no', 'fp16', 'bf16', or 'fp8'. Received {mixed_precision}"
55
+ )
56
+ config = {
57
+ "compute_environment": "LOCAL_MACHINE",
58
+ "mixed_precision": mixed_precision,
59
+ }
60
+ if is_mlu_available():
61
+ num_mlus = torch.mlu.device_count()
62
+ config["num_processes"] = num_mlus
63
+ config["use_cpu"] = False
64
+ if num_mlus > 1:
65
+ config["distributed_type"] = "MULTI_MLU"
66
+ else:
67
+ config["distributed_type"] = "NO"
68
+ elif is_musa_available():
69
+ num_musas = torch.musa.device_count()
70
+ config["num_processes"] = num_musas
71
+ config["use_cpu"] = False
72
+ if num_musas > 1:
73
+ config["distributed_type"] = "MULTI_MUSA"
74
+ else:
75
+ config["distributed_type"] = "NO"
76
+ elif torch.cuda.is_available():
77
+ num_gpus = torch.cuda.device_count()
78
+ config["num_processes"] = num_gpus
79
+ config["use_cpu"] = False
80
+ if num_gpus > 1:
81
+ config["distributed_type"] = "MULTI_GPU"
82
+ else:
83
+ config["distributed_type"] = "NO"
84
+ elif is_xpu_available() and use_xpu:
85
+ num_xpus = torch.xpu.device_count()
86
+ config["num_processes"] = num_xpus
87
+ config["use_cpu"] = False
88
+ if num_xpus > 1:
89
+ config["distributed_type"] = "MULTI_XPU"
90
+ else:
91
+ config["distributed_type"] = "NO"
92
+ elif is_npu_available():
93
+ num_npus = torch.npu.device_count()
94
+ config["num_processes"] = num_npus
95
+ config["use_cpu"] = False
96
+ if num_npus > 1:
97
+ config["distributed_type"] = "MULTI_NPU"
98
+ else:
99
+ config["distributed_type"] = "NO"
100
+ else:
101
+ num_xpus = 0
102
+ config["use_cpu"] = True
103
+ config["num_processes"] = 1
104
+ config["distributed_type"] = "NO"
105
+ config["debug"] = False
106
+ config["enable_cpu_affinity"] = False
107
+ config = ClusterConfig(**config)
108
+ config.to_json_file(path)
109
+ return path
110
+
111
+
112
+ def default_command_parser(parser, parents):
113
+ parser = parser.add_parser("default", parents=parents, help=description, formatter_class=SubcommandHelpFormatter)
114
+ parser.add_argument(
115
+ "--config_file",
116
+ default=default_json_config_file,
117
+ help=(
118
+ "The path to use to store the config file. Will default to a file named default_config.yaml in the cache "
119
+ "location, which is the content of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have "
120
+ "such an environment variable, your cache directory ('~/.cache' or the content of `XDG_CACHE_HOME`) suffixed "
121
+ "with 'huggingface'."
122
+ ),
123
+ dest="save_location",
124
+ )
125
+
126
+ parser.add_argument(
127
+ "--mixed_precision",
128
+ choices=["no", "fp16", "bf16"],
129
+ type=str,
130
+ help="Whether or not to use mixed precision training. "
131
+ "Choose between FP16 and BF16 (bfloat16) training. "
132
+ "BF16 training is only supported on Nvidia Ampere GPUs and PyTorch 1.10 or later.",
133
+ default="no",
134
+ )
135
+ parser.set_defaults(func=default_config_command)
136
+ return parser
137
+
138
+
139
+ def default_config_command(args):
140
+ config_file = write_basic_config(args.mixed_precision, args.save_location)
141
+ if config_file:
142
+ print(f"accelerate configuration saved at {config_file}")
.venv/Lib/site-packages/accelerate/commands/config/sagemaker.py ADDED
@@ -0,0 +1,267 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import json
17
+ import os
18
+
19
+ from ...utils.constants import SAGEMAKER_PARALLEL_EC2_INSTANCES, TORCH_DYNAMO_MODES
20
+ from ...utils.dataclasses import ComputeEnvironment, SageMakerDistributedType
21
+ from ...utils.imports import is_boto3_available
22
+ from .config_args import SageMakerConfig
23
+ from .config_utils import (
24
+ DYNAMO_BACKENDS,
25
+ _ask_field,
26
+ _ask_options,
27
+ _convert_dynamo_backend,
28
+ _convert_mixed_precision,
29
+ _convert_sagemaker_distributed_mode,
30
+ _convert_yes_no_to_bool,
31
+ )
32
+
33
+
34
+ if is_boto3_available():
35
+ import boto3 # noqa: F401
36
+
37
+
38
+ def _create_iam_role_for_sagemaker(role_name):
39
+ iam_client = boto3.client("iam")
40
+
41
+ sagemaker_trust_policy = {
42
+ "Version": "2012-10-17",
43
+ "Statement": [
44
+ {"Effect": "Allow", "Principal": {"Service": "sagemaker.amazonaws.com"}, "Action": "sts:AssumeRole"}
45
+ ],
46
+ }
47
+ try:
48
+ # create the role, associated with the chosen trust policy
49
+ iam_client.create_role(
50
+ RoleName=role_name, AssumeRolePolicyDocument=json.dumps(sagemaker_trust_policy, indent=2)
51
+ )
52
+ policy_document = {
53
+ "Version": "2012-10-17",
54
+ "Statement": [
55
+ {
56
+ "Effect": "Allow",
57
+ "Action": [
58
+ "sagemaker:*",
59
+ "ecr:GetDownloadUrlForLayer",
60
+ "ecr:BatchGetImage",
61
+ "ecr:BatchCheckLayerAvailability",
62
+ "ecr:GetAuthorizationToken",
63
+ "cloudwatch:PutMetricData",
64
+ "cloudwatch:GetMetricData",
65
+ "cloudwatch:GetMetricStatistics",
66
+ "cloudwatch:ListMetrics",
67
+ "logs:CreateLogGroup",
68
+ "logs:CreateLogStream",
69
+ "logs:DescribeLogStreams",
70
+ "logs:PutLogEvents",
71
+ "logs:GetLogEvents",
72
+ "s3:CreateBucket",
73
+ "s3:ListBucket",
74
+ "s3:GetBucketLocation",
75
+ "s3:GetObject",
76
+ "s3:PutObject",
77
+ ],
78
+ "Resource": "*",
79
+ }
80
+ ],
81
+ }
82
+ # attach policy to role
83
+ iam_client.put_role_policy(
84
+ RoleName=role_name,
85
+ PolicyName=f"{role_name}_policy_permission",
86
+ PolicyDocument=json.dumps(policy_document, indent=2),
87
+ )
88
+ except iam_client.exceptions.EntityAlreadyExistsException:
89
+ print(f"role {role_name} already exists. Using existing one")
90
+
91
+
92
+ def _get_iam_role_arn(role_name):
93
+ iam_client = boto3.client("iam")
94
+ return iam_client.get_role(RoleName=role_name)["Role"]["Arn"]
95
+
96
+
97
+ def get_sagemaker_input():
98
+ credentials_configuration = _ask_options(
99
+ "How do you want to authorize?",
100
+ ["AWS Profile", "Credentials (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY) "],
101
+ int,
102
+ )
103
+ aws_profile = None
104
+ if credentials_configuration == 0:
105
+ aws_profile = _ask_field("Enter your AWS Profile name: [default] ", default="default")
106
+ os.environ["AWS_PROFILE"] = aws_profile
107
+ else:
108
+ print(
109
+ "Note you will need to provide AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY when you launch you training script with,"
110
+ "`accelerate launch --aws_access_key_id XXX --aws_secret_access_key YYY`"
111
+ )
112
+ aws_access_key_id = _ask_field("AWS Access Key ID: ")
113
+ os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id
114
+
115
+ aws_secret_access_key = _ask_field("AWS Secret Access Key: ")
116
+ os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key
117
+
118
+ aws_region = _ask_field("Enter your AWS Region: [us-east-1]", default="us-east-1")
119
+ os.environ["AWS_DEFAULT_REGION"] = aws_region
120
+
121
+ role_management = _ask_options(
122
+ "Do you already have an IAM Role for executing Amazon SageMaker Training Jobs?",
123
+ ["Provide IAM Role name", "Create new IAM role using credentials"],
124
+ int,
125
+ )
126
+ if role_management == 0:
127
+ iam_role_name = _ask_field("Enter your IAM role name: ")
128
+ else:
129
+ iam_role_name = "accelerate_sagemaker_execution_role"
130
+ print(f'Accelerate will create an iam role "{iam_role_name}" using the provided credentials')
131
+ _create_iam_role_for_sagemaker(iam_role_name)
132
+
133
+ is_custom_docker_image = _ask_field(
134
+ "Do you want to use custom Docker image? [yes/NO]: ",
135
+ _convert_yes_no_to_bool,
136
+ default=False,
137
+ error_message="Please enter yes or no.",
138
+ )
139
+ docker_image = None
140
+ if is_custom_docker_image:
141
+ docker_image = _ask_field("Enter your Docker image: ", lambda x: str(x).lower())
142
+
143
+ is_sagemaker_inputs_enabled = _ask_field(
144
+ "Do you want to provide SageMaker input channels with data locations? [yes/NO]: ",
145
+ _convert_yes_no_to_bool,
146
+ default=False,
147
+ error_message="Please enter yes or no.",
148
+ )
149
+ sagemaker_inputs_file = None
150
+ if is_sagemaker_inputs_enabled:
151
+ sagemaker_inputs_file = _ask_field(
152
+ "Enter the path to the SageMaker inputs TSV file with columns (channel_name, data_location): ",
153
+ lambda x: str(x).lower(),
154
+ )
155
+
156
+ is_sagemaker_metrics_enabled = _ask_field(
157
+ "Do you want to enable SageMaker metrics? [yes/NO]: ",
158
+ _convert_yes_no_to_bool,
159
+ default=False,
160
+ error_message="Please enter yes or no.",
161
+ )
162
+ sagemaker_metrics_file = None
163
+ if is_sagemaker_metrics_enabled:
164
+ sagemaker_metrics_file = _ask_field(
165
+ "Enter the path to the SageMaker metrics TSV file with columns (metric_name, metric_regex): ",
166
+ lambda x: str(x).lower(),
167
+ )
168
+
169
+ distributed_type = _ask_options(
170
+ "What is the distributed mode?",
171
+ ["No distributed training", "Data parallelism"],
172
+ _convert_sagemaker_distributed_mode,
173
+ )
174
+ dynamo_config = {}
175
+ use_dynamo = _ask_field(
176
+ "Do you wish to optimize your script with torch dynamo?[yes/NO]:",
177
+ _convert_yes_no_to_bool,
178
+ default=False,
179
+ error_message="Please enter yes or no.",
180
+ )
181
+ if use_dynamo:
182
+ prefix = "dynamo_"
183
+ dynamo_config[prefix + "backend"] = _ask_options(
184
+ "Which dynamo backend would you like to use?",
185
+ [x.lower() for x in DYNAMO_BACKENDS],
186
+ _convert_dynamo_backend,
187
+ default=2,
188
+ )
189
+ use_custom_options = _ask_field(
190
+ "Do you want to customize the defaults sent to torch.compile? [yes/NO]: ",
191
+ _convert_yes_no_to_bool,
192
+ default=False,
193
+ error_message="Please enter yes or no.",
194
+ )
195
+
196
+ if use_custom_options:
197
+ dynamo_config[prefix + "mode"] = _ask_options(
198
+ "Which mode do you want to use?",
199
+ TORCH_DYNAMO_MODES,
200
+ lambda x: TORCH_DYNAMO_MODES[int(x)],
201
+ default="default",
202
+ )
203
+ dynamo_config[prefix + "use_fullgraph"] = _ask_field(
204
+ "Do you want the fullgraph mode or it is ok to break model into several subgraphs? [yes/NO]: ",
205
+ _convert_yes_no_to_bool,
206
+ default=False,
207
+ error_message="Please enter yes or no.",
208
+ )
209
+ dynamo_config[prefix + "use_dynamic"] = _ask_field(
210
+ "Do you want to enable dynamic shape tracing? [yes/NO]: ",
211
+ _convert_yes_no_to_bool,
212
+ default=False,
213
+ error_message="Please enter yes or no.",
214
+ )
215
+ ec2_instance_query = "Which EC2 instance type you want to use for your training?"
216
+ if distributed_type != SageMakerDistributedType.NO:
217
+ ec2_instance_type = _ask_options(
218
+ ec2_instance_query, SAGEMAKER_PARALLEL_EC2_INSTANCES, lambda x: SAGEMAKER_PARALLEL_EC2_INSTANCES[int(x)]
219
+ )
220
+ else:
221
+ ec2_instance_query += "? [ml.p3.2xlarge]:"
222
+ ec2_instance_type = _ask_field(ec2_instance_query, lambda x: str(x).lower(), default="ml.p3.2xlarge")
223
+
224
+ debug = False
225
+ if distributed_type != SageMakerDistributedType.NO:
226
+ debug = _ask_field(
227
+ "Should distributed operations be checked while running for errors? This can avoid timeout issues but will be slower. [yes/NO]: ",
228
+ _convert_yes_no_to_bool,
229
+ default=False,
230
+ error_message="Please enter yes or no.",
231
+ )
232
+
233
+ num_machines = 1
234
+ if distributed_type in (SageMakerDistributedType.DATA_PARALLEL, SageMakerDistributedType.MODEL_PARALLEL):
235
+ num_machines = _ask_field(
236
+ "How many machines do you want use? [1]: ",
237
+ int,
238
+ default=1,
239
+ )
240
+
241
+ mixed_precision = _ask_options(
242
+ "Do you wish to use FP16 or BF16 (mixed precision)?",
243
+ ["no", "fp16", "bf16", "fp8"],
244
+ _convert_mixed_precision,
245
+ )
246
+
247
+ if use_dynamo and mixed_precision == "no":
248
+ print(
249
+ "Torch dynamo used without mixed precision requires TF32 to be efficient. Accelerate will enable it by default when launching your scripts."
250
+ )
251
+
252
+ return SageMakerConfig(
253
+ image_uri=docker_image,
254
+ compute_environment=ComputeEnvironment.AMAZON_SAGEMAKER,
255
+ distributed_type=distributed_type,
256
+ use_cpu=False,
257
+ dynamo_config=dynamo_config,
258
+ ec2_instance_type=ec2_instance_type,
259
+ profile=aws_profile,
260
+ region=aws_region,
261
+ iam_role_name=iam_role_name,
262
+ mixed_precision=mixed_precision,
263
+ num_machines=num_machines,
264
+ sagemaker_inputs_file=sagemaker_inputs_file,
265
+ sagemaker_metrics_file=sagemaker_metrics_file,
266
+ debug=debug,
267
+ )
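For context, the wizard above ultimately just fills in a `SageMakerConfig`; the dataclass can also be built directly. A minimal sketch (illustrative only; the instance type, region default, and role name are assumptions, not values from the diff):

from accelerate.commands.config.config_args import SageMakerConfig
from accelerate.utils import ComputeEnvironment, SageMakerDistributedType

cfg = SageMakerConfig(
    compute_environment=ComputeEnvironment.AMAZON_SAGEMAKER,
    distributed_type=SageMakerDistributedType.DATA_PARALLEL,
    mixed_precision="fp16",
    use_cpu=False,
    debug=False,
    ec2_instance_type="ml.p3.16xlarge",   # placeholder instance type
    iam_role_name="my-sagemaker-role",    # placeholder IAM role name
)
cfg.to_yaml_file("sagemaker_config.yaml")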
.venv/Lib/site-packages/accelerate/commands/config/update.py ADDED
@@ -0,0 +1,63 @@
+ #!/usr/bin/env python
+
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from pathlib import Path
+
+ from .config_args import default_config_file, load_config_from_file
+ from .config_utils import SubcommandHelpFormatter
+
+
+ description = "Update an existing config file with the latest defaults while maintaining the old configuration."
+
+
+ def update_config(args):
+     """
+     Update an existing config file with the latest defaults while maintaining the old configuration.
+     """
+     config_file = args.config_file
+     if config_file is None and Path(default_config_file).exists():
+         config_file = default_config_file
+     elif not Path(config_file).exists():
+         raise ValueError(f"The passed config file located at {config_file} doesn't exist.")
+     config = load_config_from_file(config_file)
+
+     if config_file.endswith(".json"):
+         config.to_json_file(config_file)
+     else:
+         config.to_yaml_file(config_file)
+     return config_file
+
+
+ def update_command_parser(parser, parents):
+     parser = parser.add_parser("update", parents=parents, help=description, formatter_class=SubcommandHelpFormatter)
+     parser.add_argument(
+         "--config_file",
+         default=None,
+         help=(
+             "The path to the config file to update. Will default to a file named default_config.yaml in the cache "
+             "location, which is the content of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have "
+             "such an environment variable, your cache directory ('~/.cache' or the content of `XDG_CACHE_HOME`) suffixed "
+             "with 'huggingface'."
+         ),
+     )
+
+     parser.set_defaults(func=update_config_command)
+     return parser
+
+
+ def update_config_command(args):
+     config_file = update_config(args)
+     print(f"Successfully updated the configuration file at {config_file}.")
.venv/Lib/site-packages/accelerate/commands/env.py ADDED
@@ -0,0 +1,113 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import argparse
18
+ import os
19
+ import platform
20
+ import subprocess
21
+
22
+ import numpy as np
23
+ import psutil
24
+ import torch
25
+
26
+ from accelerate import __version__ as version
27
+ from accelerate.commands.config import default_config_file, load_config_from_file
28
+
29
+ from ..utils import is_mlu_available, is_musa_available, is_npu_available, is_xpu_available
30
+
31
+
32
+ def env_command_parser(subparsers=None):
33
+ if subparsers is not None:
34
+ parser = subparsers.add_parser("env")
35
+ else:
36
+ parser = argparse.ArgumentParser("Accelerate env command")
37
+
38
+ parser.add_argument(
39
+ "--config_file", default=None, help="The config file to use for the default values in the launching script."
40
+ )
41
+
42
+ if subparsers is not None:
43
+ parser.set_defaults(func=env_command)
44
+ return parser
45
+
46
+
47
+ def env_command(args):
48
+ pt_version = torch.__version__
49
+ pt_cuda_available = torch.cuda.is_available()
50
+ pt_xpu_available = is_xpu_available()
51
+ pt_mlu_available = is_mlu_available()
52
+ pt_musa_available = is_musa_available()
53
+ pt_npu_available = is_npu_available()
54
+
55
+ accelerate_config = "Not found"
56
+ # Get the default from the config file.
57
+ if args.config_file is not None or os.path.isfile(default_config_file):
58
+ accelerate_config = load_config_from_file(args.config_file).to_dict()
59
+
60
+ # if we can run which, get it
61
+ command = None
62
+ bash_location = "Not found"
63
+ if os.name == "nt":
64
+ command = ["where", "accelerate"]
65
+ elif os.name == "posix":
66
+ command = ["which", "accelerate"]
67
+ if command is not None:
68
+ bash_location = subprocess.check_output(command, text=True, stderr=subprocess.STDOUT).strip()
69
+ info = {
70
+ "`Accelerate` version": version,
71
+ "Platform": platform.platform(),
72
+ "`accelerate` bash location": bash_location,
73
+ "Python version": platform.python_version(),
74
+ "Numpy version": np.__version__,
75
+ "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
76
+ "PyTorch XPU available": str(pt_xpu_available),
77
+ "PyTorch NPU available": str(pt_npu_available),
78
+ "PyTorch MLU available": str(pt_mlu_available),
79
+ "PyTorch MUSA available": str(pt_musa_available),
80
+ "System RAM": f"{psutil.virtual_memory().total / 1024 ** 3:.2f} GB",
81
+ }
82
+ if pt_cuda_available:
83
+ info["GPU type"] = torch.cuda.get_device_name()
84
+ if pt_mlu_available:
85
+ info["MLU type"] = torch.mlu.get_device_name()
86
+ if pt_npu_available:
87
+ info["CANN version"] = torch.version.cann
88
+
89
+ print("\nCopy-and-paste the text below in your GitHub issue\n")
90
+ print("\n".join([f"- {prop}: {val}" for prop, val in info.items()]))
91
+
92
+ print("- `Accelerate` default config:" if args.config_file is None else "- `Accelerate` config passed:")
93
+ accelerate_config_str = (
94
+ "\n".join([f"\t- {prop}: {val}" for prop, val in accelerate_config.items()])
95
+ if isinstance(accelerate_config, dict)
96
+ else f"\t{accelerate_config}"
97
+ )
98
+ print(accelerate_config_str)
99
+
100
+ info["`Accelerate` configs"] = accelerate_config
101
+
102
+ return info
103
+
104
+
105
+ def main() -> int:
106
+ parser = env_command_parser()
107
+ args = parser.parse_args()
108
+ env_command(args)
109
+ return 0
110
+
111
+
112
+ if __name__ == "__main__":
113
+ raise SystemExit(main())
.venv/Lib/site-packages/accelerate/commands/menu/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (245 Bytes).
 
.venv/Lib/site-packages/accelerate/commands/menu/__pycache__/cursor.cpython-39.pyc ADDED
Binary file (1.56 kB).
 
.venv/Lib/site-packages/accelerate/commands/menu/__pycache__/input.cpython-39.pyc ADDED
Binary file (2.41 kB).
 
.venv/Lib/site-packages/accelerate/commands/menu/__pycache__/keymap.cpython-39.pyc ADDED
Binary file (2.39 kB).
 
.venv/Lib/site-packages/accelerate/commands/menu/__pycache__/selection_menu.cpython-39.pyc ADDED
Binary file (4.46 kB).
 
.venv/Lib/site-packages/accelerate/data_loader.py ADDED
@@ -0,0 +1,1323 @@
1
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from contextlib import suppress
17
+ from typing import Callable, List, Optional, Union
18
+
19
+ import torch
20
+ from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler
21
+
22
+ from .logging import get_logger
23
+ from .state import DistributedType, GradientState, PartialState, is_torch_xla_available
24
+ from .utils import (
25
+ RNGType,
26
+ broadcast,
27
+ broadcast_object_list,
28
+ concatenate,
29
+ find_batch_size,
30
+ get_data_structure,
31
+ initialize_tensors,
32
+ is_torch_version,
33
+ is_torchdata_stateful_dataloader_available,
34
+ send_to_device,
35
+ slice_tensors,
36
+ synchronize_rng_states,
37
+ )
38
+
39
+
40
+ logger = get_logger(__name__)
41
+
42
+ # kwargs of the DataLoader in min version 1.4.0.
43
+ _PYTORCH_DATALOADER_KWARGS = {
44
+ "batch_size": 1,
45
+ "shuffle": False,
46
+ "sampler": None,
47
+ "batch_sampler": None,
48
+ "num_workers": 0,
49
+ "collate_fn": None,
50
+ "pin_memory": False,
51
+ "drop_last": False,
52
+ "timeout": 0,
53
+ "worker_init_fn": None,
54
+ "multiprocessing_context": None,
55
+ "generator": None,
56
+ "prefetch_factor": 2,
57
+ "persistent_workers": False,
58
+ }
59
+
60
+ # kwargs added after by version
61
+ _PYTORCH_DATALOADER_ADDITIONAL_KWARGS = {}
62
+
63
+ for v, additional_kwargs in _PYTORCH_DATALOADER_ADDITIONAL_KWARGS.items():
64
+ if is_torch_version(">=", v):
65
+ _PYTORCH_DATALOADER_KWARGS.update(additional_kwargs)
66
+
67
+
68
+ class SeedableRandomSampler(RandomSampler):
69
+ """
70
+ Same as a random sampler, except that in `__iter__` a seed can be used.
71
+
72
+ Needed specifically in distributed cases, when the random generator for each GPU needs to start from the same seed
73
+ and be fully reproducable on multiple iterations.
74
+
75
+ If a custom `generator` is passed, it will rely on its initial seed as well as the current iteration it is on
76
+ (stored in `self.epoch`).
77
+ """
78
+
79
+ def __init__(self, *args, **kwargs):
80
+ data_seed = kwargs.pop("data_seed", None)
81
+ super().__init__(*args, **kwargs)
82
+
83
+ self.initial_seed = data_seed if data_seed is not None else torch.random.initial_seed()
84
+ self.epoch = 0
85
+
86
+ def __iter__(self):
87
+ if self.generator is None:
88
+ self.generator = torch.Generator()
89
+ self.generator.manual_seed(self.initial_seed)
90
+
91
+ # Allow `self.epoch` to modify the seed of the generator
92
+ seed = self.epoch + self.initial_seed
93
+ # print("Setting seed at epoch", self.epoch, seed)
94
+ self.generator.manual_seed(seed)
95
+ yield from super().__iter__()
96
+ self.set_epoch(self.epoch + 1)
97
+
98
+ def set_epoch(self, epoch: int):
99
+ "Sets the current iteration of the sampler."
100
+ self.epoch = epoch
101
+
102
+
103
+ class BatchSamplerShard(BatchSampler):
104
+ """
105
+ Wraps a PyTorch `BatchSampler` to generate batches for one of the processes only. Instances of this class will
106
+ always yield a number of batches that is a round multiple of `num_processes` and that all have the same size.
107
+ Depending on the value of the `drop_last` attribute of the batch sampler passed, it will either stop the iteration
108
+ at the first batch that would be too small / not present on all processes or loop with indices from the beginning.
109
+
110
+ Args:
111
+ batch_sampler (`torch.utils.data.sampler.BatchSampler`):
112
+ The batch sampler to split in several shards.
113
+ num_processes (`int`, *optional*, defaults to 1):
114
+ The number of processes running concurrently.
115
+ process_index (`int`, *optional*, defaults to 0):
116
+ The index of the current process.
117
+ split_batches (`bool`, *optional*, defaults to `False`):
118
+ Whether the shards should be created by splitting a batch to give a piece of it on each process, or by
119
+ yielding different full batches on each process.
120
+
121
+ On two processes with a sampler of `[[0, 1, 2, 3], [4, 5, 6, 7]]`, this will result in:
122
+
123
+ - the sampler on process 0 to yield `[0, 1, 2, 3]` and the sampler on process 1 to yield `[4, 5, 6, 7]` if
124
+ this argument is set to `False`.
125
+ - the sampler on process 0 to yield `[0, 1]` then `[4, 5]` and the sampler on process 1 to yield `[2, 3]`
126
+ then `[6, 7]` if this argument is set to `True`.
127
+ even_batches (`bool`, *optional*, defaults to `True`):
128
+ Whether or not to loop back at the beginning of the sampler when the number of samples is not a round
129
+ multiple of (original batch size / number of processes).
130
+
131
+ <Tip warning={true}>
132
+
133
+ `BatchSampler`s with varying batch sizes are not enabled by default. To enable this behaviour, set `even_batches`
134
+ equal to `False`
135
+
136
+ </Tip>"""
137
+
138
+ def __init__(
139
+ self,
140
+ batch_sampler: BatchSampler,
141
+ num_processes: int = 1,
142
+ process_index: int = 0,
143
+ split_batches: bool = False,
144
+ even_batches: bool = True,
145
+ ):
146
+ if split_batches and batch_sampler.batch_size % num_processes != 0:
147
+ raise ValueError(
148
+ f"To use `BatchSamplerShard` in `split_batches` mode, the batch size ({batch_sampler.batch_size}) "
149
+ f"needs to be a round multiple of the number of processes ({num_processes})."
150
+ )
151
+ self.batch_sampler = batch_sampler
152
+ self.num_processes = num_processes
153
+ self.process_index = process_index
154
+ self.split_batches = split_batches
155
+ self.even_batches = even_batches
156
+ self.batch_size = getattr(batch_sampler, "batch_size", None)
157
+ self.drop_last = getattr(batch_sampler, "drop_last", False)
158
+ if self.batch_size is None and self.even_batches:
159
+ raise ValueError(
160
+ "You need to use `even_batches=False` when the batch sampler has no batch size. If you "
161
+ "are not calling this method directly, set `accelerator.even_batches=False` instead."
162
+ )
163
+
164
+ @property
165
+ def total_length(self):
166
+ return len(self.batch_sampler)
167
+
168
+ def __len__(self):
169
+ if self.split_batches:
170
+ # Split batches does not change the length of the batch sampler
171
+ return len(self.batch_sampler)
172
+ if len(self.batch_sampler) % self.num_processes == 0:
173
+ # If the length is a round multiple of the number of processes, it's easy.
174
+ return len(self.batch_sampler) // self.num_processes
175
+ length = len(self.batch_sampler) // self.num_processes
176
+ if self.drop_last:
177
+ # Same if we drop the remainder.
178
+ return length
179
+ elif self.even_batches:
180
+ # When using even batches we always get +1
181
+ return length + 1
182
+ else:
183
+ # Otherwise it depends on the process index.
184
+ return length + 1 if self.process_index < len(self.batch_sampler) % self.num_processes else length
185
+
186
+ def __iter__(self):
187
+ return self._iter_with_split() if self.split_batches else self._iter_with_no_split()
188
+
189
+ def _iter_with_split(self):
190
+ initial_data = []
191
+ batch_length = self.batch_sampler.batch_size // self.num_processes
192
+ for idx, batch in enumerate(self.batch_sampler):
193
+ if idx == 0:
194
+ initial_data = batch
195
+ if len(batch) == self.batch_size:
196
+ # If the batch is full, we yield the part of it this process is responsible for.
197
+ yield batch[batch_length * self.process_index : batch_length * (self.process_index + 1)]
198
+
199
+ # If drop_last is True or the last batch was full, iteration is over, otherwise...
200
+ if not self.drop_last and len(initial_data) > 0 and len(batch) < self.batch_size:
201
+ if not self.even_batches:
202
+ if len(batch) > batch_length * self.process_index:
203
+ yield batch[batch_length * self.process_index : batch_length * (self.process_index + 1)]
204
+ else:
205
+ # For degenerate cases where the dataset has fewer than num_processes * batch_size samples
206
+ while len(initial_data) < self.batch_size:
207
+ initial_data += initial_data
208
+ batch = batch + initial_data
209
+ yield batch[batch_length * self.process_index : batch_length * (self.process_index + 1)]
210
+
211
+ def _iter_with_no_split(self):
212
+ initial_data = []
213
+ batch_to_yield = []
214
+ for idx, batch in enumerate(self.batch_sampler):
215
+ # We gather the initial indices in case we need to circle back at the end.
216
+ if not self.drop_last and idx < self.num_processes:
217
+ initial_data += batch
218
+ # We identify the batch to yield but wait until we are sure every process gets a full batch before actually
219
+ # yielding it.
220
+ if idx % self.num_processes == self.process_index:
221
+ batch_to_yield = batch
222
+ if idx % self.num_processes == self.num_processes - 1 and (
223
+ self.batch_size is None or len(batch) == self.batch_size
224
+ ):
225
+ yield batch_to_yield
226
+ batch_to_yield = []
227
+
228
+ # If drop_last is True, iteration is over, otherwise...
229
+ if not self.drop_last and len(initial_data) > 0:
230
+ if not self.even_batches:
231
+ if len(batch_to_yield) > 0:
232
+ yield batch_to_yield
233
+ else:
234
+ # ... we yield the complete batch we had saved before if it has the proper length
235
+ if len(batch_to_yield) == self.batch_size:
236
+ yield batch_to_yield
237
+
238
+ # For degenerate cases where the dataset has fewer than num_processes * batch_size samples
239
+ while len(initial_data) < self.num_processes * self.batch_size:
240
+ initial_data += initial_data
241
+
242
+ # If the last batch seen was of the proper size, it has been yielded by its process so we move to the next
243
+ if len(batch) == self.batch_size:
244
+ batch = []
245
+ idx += 1
246
+
247
+ # Make sure we yield a multiple of self.num_processes batches
248
+ cycle_index = 0
249
+ while idx % self.num_processes != 0 or len(batch) > 0:
250
+ end_index = cycle_index + self.batch_size - len(batch)
251
+ batch += initial_data[cycle_index:end_index]
252
+ if idx % self.num_processes == self.process_index:
253
+ yield batch
254
+ cycle_index = end_index
255
+ batch = []
256
+ idx += 1
257
+
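+ # Illustrative sketch (not part of the library source): sharding a standard `BatchSampler`
+ # across two processes with the default `split_batches=False`. Names below are hypothetical.
+ #
+ #     from torch.utils.data import SequentialSampler
+ #     base = BatchSampler(SequentialSampler(range(16)), batch_size=4, drop_last=False)
+ #     shard0 = BatchSamplerShard(base, num_processes=2, process_index=0)
+ #     shard1 = BatchSamplerShard(base, num_processes=2, process_index=1)
+ #     list(shard0)  # [[0, 1, 2, 3], [8, 9, 10, 11]]
+ #     list(shard1)  # [[4, 5, 6, 7], [12, 13, 14, 15]]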
258
+
259
+ class IterableDatasetShard(IterableDataset):
260
+ """
261
+ Wraps a PyTorch `IterableDataset` to generate samples for one of the processes only. Instances of this class will
262
+ always yield a number of samples that is a round multiple of the actual batch size (depending on the value of
263
+ `split_batches`, this is either `batch_size` or `batch_size x num_processes`). Depending on the value of the
264
+ `drop_last` attribute of the batch sampler passed, it will either stop the iteration at the first batch that would
265
+ be too small or loop with indices from the beginning.
266
+
267
+ Args:
268
+ dataset (`torch.utils.data.dataset.IterableDataset`):
269
+ The iterable dataset to split in several shards.
270
+ batch_size (`int`, *optional*, defaults to 1):
271
+ The size of the batches per shard (if `split_batches=False`) or the size of the batches (if
272
+ `split_batches=True`).
273
+ drop_last (`bool`, *optional*, defaults to `False`):
274
+ Whether or not to drop the last incomplete batch or complete the last batches by using the samples from the
275
+ beginning.
276
+ num_processes (`int`, *optional*, defaults to 1):
277
+ The number of processes running concurrently.
278
+ process_index (`int`, *optional*, defaults to 0):
279
+ The index of the current process.
280
+ split_batches (`bool`, *optional*, defaults to `False`):
281
+ Whether the shards should be created by splitting a batch to give a piece of it on each process, or by
282
+ yielding different full batches on each process.
283
+
284
+ On two processes with an iterable dataset yielding of `[0, 1, 2, 3, 4, 5, 6, 7]`, this will result in:
285
+
286
+ - the shard on process 0 to yield `[0, 1, 2, 3]` and the shard on process 1 to yield `[4, 5, 6, 7]` if this
287
+ argument is set to `False`.
288
+ - the shard on process 0 to yield `[0, 1, 4, 5]` and the shard on process 1 to yield `[2, 3, 6, 7]` if
289
+ this argument is set to `True`.
290
+ """
291
+
292
+ def __init__(
293
+ self,
294
+ dataset: IterableDataset,
295
+ batch_size: int = 1,
296
+ drop_last: bool = False,
297
+ num_processes: int = 1,
298
+ process_index: int = 0,
299
+ split_batches: bool = False,
300
+ ):
301
+ if split_batches and batch_size > 1 and batch_size % num_processes != 0:
302
+ raise ValueError(
303
+ f"To use `IterableDatasetShard` in `split_batches` mode, the batch size ({batch_size}) "
304
+ f"needs to be a round multiple of the number of processes ({num_processes})."
305
+ )
306
+ self.dataset = dataset
307
+ self.batch_size = batch_size
308
+ self.drop_last = drop_last
309
+ self.num_processes = num_processes
310
+ self.process_index = process_index
311
+ self.split_batches = split_batches
312
+
313
+ def set_epoch(self, epoch):
314
+ self.epoch = epoch
315
+ if hasattr(self.dataset, "set_epoch"):
316
+ self.dataset.set_epoch(epoch)
317
+
318
+ def __len__(self):
319
+ # We will just raise the downstream error if the underlying dataset is not sized
320
+ if self.drop_last:
321
+ return (len(self.dataset) // (self.batch_size * self.num_processes)) * self.batch_size
322
+ else:
323
+ return math.ceil(len(self.dataset) / (self.batch_size * self.num_processes)) * self.batch_size
324
+
325
+ def __iter__(self):
326
+ if (
327
+ not hasattr(self.dataset, "set_epoch")
328
+ and hasattr(self.dataset, "generator")
329
+ and isinstance(self.dataset.generator, torch.Generator)
330
+ ):
331
+ self.dataset.generator.manual_seed(self.epoch)
332
+ real_batch_size = self.batch_size if self.split_batches else (self.batch_size * self.num_processes)
333
+ process_batch_size = (self.batch_size // self.num_processes) if self.split_batches else self.batch_size
334
+ process_slice = range(self.process_index * process_batch_size, (self.process_index + 1) * process_batch_size)
335
+
336
+ first_batch = None
337
+ current_batch = []
338
+ for element in self.dataset:
339
+ current_batch.append(element)
340
+ # Wait to have a full batch before yielding elements.
341
+ if len(current_batch) == real_batch_size:
342
+ for i in process_slice:
343
+ yield current_batch[i]
344
+ if first_batch is None:
345
+ first_batch = current_batch.copy()
346
+ current_batch = []
347
+
348
+ # Finished if drop_last is True, otherwise complete the last batch with elements from the beginning.
349
+ if not self.drop_last and len(current_batch) > 0:
350
+ if first_batch is None:
351
+ first_batch = current_batch.copy()
352
+ while len(current_batch) < real_batch_size:
353
+ current_batch += first_batch
354
+ for i in process_slice:
355
+ yield current_batch[i]
356
+
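+ # Illustrative sketch (not part of the library source): sharding a small iterable dataset across
+ # two processes. `MyIterable` is a hypothetical stand-in for any `IterableDataset`.
+ #
+ #     class MyIterable(IterableDataset):
+ #         def __iter__(self):
+ #             yield from range(8)
+ #
+ #     shard0 = IterableDatasetShard(MyIterable(), batch_size=4, num_processes=2, process_index=0)
+ #     list(shard0)  # [0, 1, 2, 3]; process_index=1 would see [4, 5, 6, 7]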
357
+
358
+ class DataLoaderStateMixin:
359
+ """
360
+ Mixin class that adds a state to a `DataLoader` to keep track of the status inside the dataloader such as at the
361
+ end of the iteration, the number of items in the dataset in the last batch relative to the batch size, and other
362
+ useful information that might be needed.
363
+
364
+ **Available attributes:**
365
+
366
+ - **end_of_dataloader** (`bool`) -- Whether at the last iteration or batch
367
+ - **remainder** (`int`) -- The number of items that are remaining in the last batch, relative to the total
368
+ batch size
369
+
370
+ <Tip warning={true}>
371
+
372
+ Inheritors of this class should ensure that the class creates a `GradientState()` instance, stored in
373
+ `self.gradient_state`.
374
+
375
+ </Tip>
376
+
377
+ """
378
+
379
+ def __init_subclass__(cls, **kwargs):
380
+ cls.end_of_dataloader = False
381
+ cls.remainder = -1
382
+
383
+ def reset(self):
384
+ self.end_of_dataloader = False
385
+ self.remainder = -1
386
+
387
+ def begin(self):
388
+ "Prepares the gradient state for the current dataloader"
389
+ self.reset()
390
+ with suppress(Exception):
391
+ if not self._drop_last:
392
+ length = getattr(self.dataset, "total_dataset_length", len(self.dataset))
393
+ self.remainder = length % self.total_batch_size
394
+ self.gradient_state._add_dataloader(self)
395
+
396
+ def end(self):
397
+ "Cleans up the gradient state after exiting the dataloader"
398
+ self.gradient_state._remove_dataloader(self)
399
+
400
+
401
+ class DataLoaderAdapter:
402
+ """
403
+ A class which wraps around a PyTorch `DataLoader` (or variants of it) to be used with the `Accelerator`. For
404
+ compatibility reasons, this class inherits from the class it wraps around, so it can be used as a drop-in.
405
+ """
406
+
407
+ def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, **kwargs):
408
+ self.use_stateful_dataloader = use_stateful_dataloader
409
+ if is_torchdata_stateful_dataloader_available():
410
+ from torchdata.stateful_dataloader import StatefulDataLoader
411
+
412
+ if use_stateful_dataloader and not is_torchdata_stateful_dataloader_available():
413
+ raise ImportError(
414
+ "StatefulDataLoader is not available. Please install torchdata version 0.8.0 or higher to use it."
415
+ )
416
+ if use_stateful_dataloader:
417
+ self.base_dataloader = StatefulDataLoader(dataset, batch_sampler=batch_sampler, **kwargs)
418
+ else:
419
+ self.base_dataloader = DataLoader(dataset, batch_sampler=batch_sampler, **kwargs)
420
+
421
+ if hasattr(self.base_dataloader, "state_dict"):
422
+ self.dl_state_dict = self.base_dataloader.state_dict()
423
+
424
+ def __getattr__(self, name):
425
+ # Avoid infinite recursion if we try to access a nonexistent base_dataloader attribute.
426
+ if name == "base_dataloader":
427
+ raise AttributeError()
428
+ # Delegate attribute access to the internal dataloader
429
+ return getattr(self.base_dataloader, name)
430
+
431
+ def state_dict(self):
432
+ return self.dl_state_dict
433
+
434
+ def load_state_dict(self, state_dict):
435
+ self.base_dataloader.load_state_dict(state_dict)
436
+
437
+ @property
438
+ def __class__(self):
439
+ """
440
+ In order to maintain backwards compatibility with other code, we need to ensure `isinstance(obj, DataLoader)`
441
+ returns true. This is because some downstream code assumes that the `DataLoader` is the base class of the
442
+ object.
443
+ """
444
+ return self.base_dataloader.__class__
445
+
446
+ def __len__(self):
447
+ return len(self.base_dataloader)
448
+
449
+ def adjust_state_dict_for_prefetch(self):
450
+ """
451
+ Adjusts the state dict for prefetching. Natively, this will adjust all of the iterator-yielded counters in
452
+ `self.dl_state_dict` by a factor of `num_processes - 1`, however if a custom correction is needed, this can be
453
+ overridden.
454
+
455
+ This should modify `self.dl_state_dict` directly
456
+ """
457
+ # During DDP the state dict will be ahead by `n-1` batches (where `n` is the number of processes),
458
+ # so we need to adjust it here
459
+ if PartialState().distributed_type != DistributedType.NO:
460
+ factor = PartialState().num_processes - 1
461
+ if self.dl_state_dict["_sampler_iter_yielded"] > 0:
462
+ self.dl_state_dict["_sampler_iter_yielded"] -= factor
463
+ if self.dl_state_dict["_num_yielded"] > 0:
464
+ self.dl_state_dict["_num_yielded"] -= factor
465
+ if self.dl_state_dict["_index_sampler_state"] is not None:
466
+ if (
467
+ "samples_yielded" in self.dl_state_dict["_index_sampler_state"]
468
+ and self.dl_state_dict["_index_sampler_state"]["samples_yielded"] > 0
469
+ ):
470
+ self.dl_state_dict["_index_sampler_state"]["samples_yielded"] -= self.batch_size * factor
471
+
472
+ def _update_state_dict(self):
473
+ # The state_dict of the underlying base_dataloader may be ahead of what is currently being yielded.
474
+ # E.g. the implementation of DataLoaderShard involves having an underlying iterator 1 element ahead of
475
+ # what it wants to yield.
476
+ #
477
+ # _update_state_dict is called to snapshot the state_dict that would properly recover the DataLoaderAdapter.
478
+ if hasattr(self.base_dataloader, "state_dict"):
479
+ self.dl_state_dict = self.base_dataloader.state_dict()
480
+ # Potentially modify the state_dict to adjust for prefetching
481
+ self.adjust_state_dict_for_prefetch()
482
+ # Then tag if we are at the end of the dataloader
483
+ self.dl_state_dict["_iterator_finished"] = self.end_of_dataloader
484
+
485
+
486
+ class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
487
+ """
488
+ Subclass of `DataLoaderAdapter` that will deal with device placement and current distributed setup.
489
+
490
+ Args:
491
+ dataset (`torch.utils.data.dataset.Dataset`):
492
+ The dataset to use to build this dataloader.
493
+ device (`torch.device`, *optional*):
494
+ If passed, the device to put all batches on.
495
+ rng_types (list of `str` or [`~utils.RNGType`]):
496
+ The list of random number generators to synchronize at the beginning of each iteration. Should be one or
497
+ several of:
498
+
499
+ - `"torch"`: the base torch random number generator
500
+ - `"cuda"`: the CUDA random number generator (GPU only)
501
+ - `"xla"`: the XLA random number generator (TPU only)
502
+ - `"generator"`: an optional `torch.Generator`
503
+ synchronized_generator (`torch.Generator`, *optional*):
504
+ A random number generator to keep synchronized across processes.
505
+ skip_batches (`int`, *optional*, defaults to 0):
506
+ The number of batches to skip at the beginning.
507
+ use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
508
+ Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.
509
+ **kwargs (additional keyword arguments, *optional*):
510
+ All other keyword arguments to pass to the regular `DataLoader` initialization.
511
+
512
+ **Available attributes:**
513
+
514
+ - **total_batch_size** (`int`) -- Total batch size of the dataloader across all processes.
515
+ Equal to the original batch size when `split_batches=True`; otherwise the original batch size * the total
516
+ number of processes
517
+
518
+ - **total_dataset_length** (`int`) -- Total length of the inner dataset across all processes.
519
+ """
520
+
521
+ def __init__(
522
+ self,
523
+ dataset,
524
+ device=None,
525
+ rng_types=None,
526
+ synchronized_generator=None,
527
+ skip_batches=0,
528
+ use_stateful_dataloader=False,
529
+ _drop_last: bool = False,
530
+ _non_blocking: bool = False,
531
+ **kwargs,
532
+ ):
533
+ super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
534
+ self.device = device
535
+ self.rng_types = rng_types
536
+ self.synchronized_generator = synchronized_generator
537
+ self.skip_batches = skip_batches
538
+ self.gradient_state = GradientState()
539
+ self._drop_last = _drop_last
540
+ self._non_blocking = _non_blocking
541
+ self.iteration = 0
542
+
543
+ def __iter__(self):
544
+ if self.rng_types is not None:
545
+ synchronize_rng_states(self.rng_types, self.synchronized_generator)
546
+ self.begin()
547
+
548
+ self.set_epoch(self.iteration)
549
+ dataloader_iter = self.base_dataloader.__iter__()
550
+ # We iterate one batch ahead to check when we are at the end
551
+ try:
552
+ current_batch = next(dataloader_iter)
553
+ except StopIteration:
554
+ yield
555
+
556
+ batch_index = 0
557
+ while True:
558
+ try:
559
+ # But we still move it to the device so it is done before `StopIteration` is reached
560
+ if self.device is not None:
561
+ current_batch = send_to_device(current_batch, self.device, non_blocking=self._non_blocking)
562
+ self._update_state_dict()
563
+ next_batch = next(dataloader_iter)
564
+ if batch_index >= self.skip_batches:
565
+ yield current_batch
566
+ batch_index += 1
567
+ current_batch = next_batch
568
+ except StopIteration:
569
+ self.end_of_dataloader = True
570
+ self._update_state_dict()
571
+ if batch_index >= self.skip_batches:
572
+ yield current_batch
573
+ break
574
+
575
+ self.iteration += 1
576
+ self.end()
577
+
578
+ def __reduce__(self):
579
+ """
580
+ Define the `__reduce__` method to ensure a `DataLoaderShard` can be pickled and unpickled. This needs to be
581
+ explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its
582
+ `__class__` member.
583
+ """
584
+ args = super().__reduce__()
585
+ return (DataLoaderShard, *args[1:])
586
+
587
+ def set_epoch(self, epoch: int):
588
+ # In case it is manually passed in, the user can set it to what they like
589
+ if self.iteration != epoch:
590
+ self.iteration = epoch
591
+ if hasattr(self.batch_sampler, "set_epoch"):
592
+ self.batch_sampler.set_epoch(epoch)
593
+ if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, "set_epoch"):
594
+ self.batch_sampler.sampler.set_epoch(epoch)
595
+ # We support if a custom `Dataset` implementation has `set_epoch`
596
+ # or in general HF datasets `Datasets`
597
+ elif hasattr(self.dataset, "set_epoch"):
598
+ self.dataset.set_epoch(epoch)
599
+
600
+ @property
601
+ def total_batch_size(self):
602
+ batch_sampler = self.sampler if isinstance(self.sampler, BatchSampler) else self.batch_sampler
603
+ return (
604
+ batch_sampler.batch_size
605
+ if getattr(batch_sampler, "split_batches", False)
606
+ else (batch_sampler.batch_size * getattr(batch_sampler, "num_processes", 1))
607
+ )
608
+
609
+ @property
610
+ def total_dataset_length(self):
611
+ if hasattr(self.dataset, "total_length"):
612
+ return self.dataset.total_length
613
+ else:
614
+ return len(self.dataset)
615
+
616
+ def get_sampler(self):
617
+ return get_sampler(self)
618
+
619
+ def set_sampler(self, sampler):
620
+ sampler_is_batch_sampler = isinstance(self.sampler, BatchSampler)
621
+ if sampler_is_batch_sampler:
622
+ self.sampler.sampler = sampler
623
+ else:
624
+ self.batch_sampler.sampler = sampler
625
+ if hasattr(self.batch_sampler, "batch_sampler"):
626
+ self.batch_sampler.batch_sampler.sampler = sampler
627
+
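+ # Illustrative sketch (not part of the library source): building a `DataLoaderShard` directly in a
+ # single-process setting just to see the device placement. The toy dataset below is hypothetical;
+ # in practice this class is normally created for you by `prepare_data_loader`.
+ #
+ #     shard = DataLoaderShard(list(range(16)), device=torch.device("cpu"), batch_size=4)
+ #     next(iter(shard))  # tensor([0, 1, 2, 3]), already placed on the requested device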
628
+
629
+ if is_torch_xla_available():
630
+ import torch_xla.distributed.parallel_loader as xpl
631
+
632
+ class MpDeviceLoaderWrapper(xpl.MpDeviceLoader):
633
+ """
634
+ Wrapper for the xpl.MpDeviceLoader class that knows the total batch size.
635
+
636
+ XLA preloading threads will all call DataLoaderShard's __iter__(). Remove rng_types from DataLoaderShard to
637
+ prevent it from using the XLA device in the preloading threads, and synchronize the RNG once from the main
638
+ thread only.
639
+
640
+ **Available attributes:**
641
+
642
+ - **total_batch_size** (`int`) -- Total batch size of the dataloader across all processes.
643
+ Equal to the original batch size when `split_batches=True`; otherwise the original batch size * the total
644
+ number of processes
645
+
646
+ - **total_dataset_length** (`int`) -- Total length of the inner dataset across all processes.
647
+ """
648
+
649
+ def __init__(self, dataloader: DataLoaderShard, device: torch.device):
650
+ super().__init__(dataloader, device)
651
+ self._rng_types = self._loader.rng_types
652
+ self._loader.rng_types = None
653
+ self.device = device
654
+
655
+ def __iter__(self):
656
+ if self._rng_types is not None:
657
+ synchronize_rng_states(self._rng_types, self._loader.synchronized_generator)
658
+
659
+ return super().__iter__()
660
+
661
+ def set_epoch(self, epoch: int):
662
+ if hasattr(self.dataloader, "set_epoch"):
663
+ self.dataloader.set_epoch(epoch)
664
+
665
+ @property
666
+ def total_batch_size(self):
667
+ return self._loader.total_batch_size
668
+
669
+ @property
670
+ def total_dataset_length(self):
671
+ return self._loader.total_dataset_length
672
+
673
+ @property
674
+ def batch_sampler(self):
675
+ return self._loader.batch_sampler
676
+
677
+ @property
678
+ def dataloader(self):
679
+ return self._loader
680
+
681
+
682
+ class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
683
+ """
684
+ Subclass of `DataLoaderAdapter` that will iterate and preprocess on process 0 only, then dispatch on each process
685
+ their part of the batch.
686
+
687
+ Args:
688
+ split_batches (`bool`, *optional*, defaults to `False`):
689
+ Whether the resulting `DataLoader` should split the batches of the original data loader across devices or
690
+ yield full batches (in which case it will yield batches starting at the `process_index`-th and advancing by
691
+ `num_processes` batches at each iteration). Another way to see this is that the observed batch size will be
692
+ the same as the initial `dataloader` if this option is set to `True`, the batch size of the initial
693
+ `dataloader` multiplied by `num_processes` otherwise. Setting this option to `True` requires that the batch
694
+ size of the `dataloader` is a round multiple of `num_processes`.
695
+ skip_batches (`int`, *optional*, defaults to 0):
696
+ The number of batches to skip at the beginning of an iteration.
697
+ use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
698
+ Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.
699
+
700
+ **Available attributes:**
701
+
702
+ - **total_batch_size** (`int`) -- Total batch size of the dataloader across all processes.
703
+ Equal to the original batch size when `split_batches=True`; otherwise the original batch size * the total
704
+ number of processes
705
+
706
+ - **total_dataset_length** (`int`) -- Total length of the inner dataset across all processes.
707
+ """
708
+
709
+ def __init__(
710
+ self,
711
+ dataset,
712
+ split_batches: bool = False,
713
+ skip_batches=0,
714
+ use_stateful_dataloader=False,
715
+ _drop_last: bool = False,
716
+ _non_blocking: bool = False,
717
+ slice_fn=None,
718
+ **kwargs,
719
+ ):
720
+ shuffle = False
721
+ if is_torch_version(">=", "1.11.0"):
722
+ from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe
723
+
724
+ # We need to save the shuffling state of the DataPipe
725
+ if isinstance(dataset, ShufflerIterDataPipe):
726
+ shuffle = dataset._shuffle_enabled
727
+ super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
728
+ self.split_batches = split_batches
729
+ if shuffle:
730
+ torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle)
731
+
732
+ self.gradient_state = GradientState()
733
+ self.state = PartialState()
734
+ self._drop_last = _drop_last
735
+ self._non_blocking = _non_blocking
736
+ self.skip_batches = skip_batches
737
+
738
+ self.slice_fn = slice_tensors if slice_fn is None else slice_fn
739
+ self.iteration = 0
740
+
741
+ def _fetch_batches(self, iterator):
742
+ batches, batch = None, None
743
+ # On process 0, we gather the batch to dispatch.
744
+ if self.state.process_index == 0:
745
+ try:
746
+ if self.split_batches:
747
+ # One batch of the main iterator is dispatched and split.
748
+ self._update_state_dict()
749
+ batch = next(iterator)
750
+ else:
751
+ # num_processes batches of the main iterator are concatenated then dispatched and split.
752
+ # We add the batches one by one so we have the remainder available when drop_last=False.
753
+ batches = []
754
+ for _ in range(self.state.num_processes):
755
+ self._update_state_dict()
756
+ batches.append(next(iterator))
757
+ try:
758
+ batch = concatenate(batches, dim=0)
759
+ except RuntimeError as e:
760
+ raise RuntimeError(
761
+ "You can't use batches of different size with `dispatch_batches=True` or when using an `IterableDataset`."
762
+ "either pass `dispatch_batches=False` and have each process fetch its own batch "
763
+ " or pass `split_batches=True`. By doing so, the main process will fetch a full batch and "
764
+ "slice it into `num_processes` batches for each process."
765
+ ) from e
766
+ # In both cases, we need to get the structure of the batch that we will broadcast on other
767
+ # processes to initialize the tensors with the right shape.
768
+ # data_structure, stop_iteration
769
+ batch_info = [get_data_structure(batch), False]
770
+ except StopIteration:
771
+ batch_info = [None, True]
772
+ else:
773
+ batch_info = [None, self._stop_iteration]
774
+ # This is inplace, so after this instruction, every process has the same `batch_info` as process 0.
775
+ broadcast_object_list(batch_info)
776
+ self._stop_iteration = batch_info[1]
777
+ if self._stop_iteration:
778
+ # If drop_last is False and split_batches is False, we may have a remainder to take care of.
779
+ if not self.split_batches and not self._drop_last:
780
+ if self.state.process_index == 0 and len(batches) > 0:
781
+ batch = concatenate(batches, dim=0)
782
+ batch_info = [get_data_structure(batch), False]
783
+ else:
784
+ batch_info = [None, True]
785
+ broadcast_object_list(batch_info)
786
+ return batch, batch_info
787
+
788
+ def __iter__(self):
789
+ self.begin()
790
+ self.set_epoch(self.iteration)
791
+ main_iterator = None
792
+ if is_torch_version(">=", "2.0.1"):
793
+ # NOTE PyTorch DataLoader adds forward compatibilities for DataPipes, which broadcasts
794
+ # shared seed to all dist processes. Thus, we need to create iterator for all dist processes.
795
+ # But, we only iterate through the DataLoader on process 0.
796
+ main_iterator = self.base_dataloader.__iter__()
797
+ elif self.state.process_index == 0:
798
+ main_iterator = self.base_dataloader.__iter__()
799
+ stop_iteration = False
800
+ self._stop_iteration = False
801
+ first_batch = None
802
+ next_batch, next_batch_info = self._fetch_batches(main_iterator)
803
+ batch_index = 0
804
+ while not stop_iteration:
805
+ batch, batch_info = next_batch, next_batch_info
806
+
807
+ if self.state.process_index != 0:
808
+ # Initialize tensors on other processes than process 0.
809
+ batch = initialize_tensors(batch_info[0])
810
+ batch = send_to_device(batch, self.state.device, non_blocking=self._non_blocking)
811
+ # Broadcast the batch before splitting it.
812
+ batch = broadcast(batch, from_process=0)
813
+
814
+ if not self._drop_last and first_batch is None:
815
+ # We keep at least num processes elements of the first batch to be able to complete the last batch
816
+ first_batch = self.slice_fn(
817
+ batch,
818
+ slice(0, self.state.num_processes),
819
+ process_index=self.state.process_index,
820
+ num_processes=self.state.num_processes,
821
+ )
822
+
823
+ if batch is None:
824
+ raise ValueError(
825
+ f"Batch does not contain any data (`{batch}`). At the end of all iterable data available before expected stop iteration."
826
+ )
827
+
828
+ observed_batch_size = find_batch_size(batch)
829
+ batch_size = observed_batch_size // self.state.num_processes
830
+
831
+ stop_iteration = self._stop_iteration
832
+ if not stop_iteration:
833
+ # We may still be at the end of the dataloader without knowing it yet: if there is nothing left in
834
+ # the dataloader since the number of batches is a round multiple of the number of processes.
835
+ next_batch, next_batch_info = self._fetch_batches(main_iterator)
836
+ # next_batch_info[0] is None when there are no more batches, otherwise we still need to process them.
837
+ if self._stop_iteration and next_batch_info[0] is None:
838
+ stop_iteration = True
839
+
840
+ if not self._drop_last and stop_iteration and observed_batch_size % self.state.num_processes != 0:
841
+ # If the last batch is not complete, let's add the first batch to it.
842
+ batch = concatenate([batch, first_batch], dim=0)
843
+ # Batch size computation above is wrong, it's off by 1 so we fix it.
844
+ batch_size += 1
845
+
846
+ data_slice = slice(self.state.process_index * batch_size, (self.state.process_index + 1) * batch_size)
847
+ batch = self.slice_fn(
848
+ batch,
849
+ data_slice,
850
+ process_index=self.state.process_index,
851
+ num_processes=self.state.num_processes,
852
+ )
853
+
854
+ if stop_iteration:
855
+ self.end_of_dataloader = True
856
+ self._update_state_dict()
857
+ self.remainder = observed_batch_size
858
+ if batch_index >= self.skip_batches:
859
+ yield batch
860
+ batch_index += 1
861
+ self.iteration += 1
862
+ self.end()
863
+
864
+ def set_epoch(self, epoch: int):
865
+ # In case it is manually passed in, the user can set it to what they like
866
+ if self.iteration != epoch:
867
+ self.iteration = epoch
868
+ if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, "set_epoch"):
869
+ self.batch_sampler.sampler.set_epoch(epoch)
870
+ elif hasattr(self.dataset, "set_epoch"):
871
+ self.dataset.set_epoch(epoch)
872
+
873
+ def __len__(self):
874
+ whole_length = len(self.base_dataloader)
875
+ if self.split_batches:
876
+ return whole_length
877
+ elif self._drop_last:
878
+ return whole_length // self.state.num_processes
879
+ else:
880
+ return math.ceil(whole_length / self.state.num_processes)
881
+
882
+ def __reduce__(self):
883
+ """
884
+ Define the `__reduce__` method to ensure a `DataLoaderDispatcher` can be pickled and unpickled. This needs to
885
+ be explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its
886
+ `__class__` member.
887
+ """
888
+ args = super().__reduce__()
889
+ return (DataLoaderDispatcher, *args[1:])
890
+
891
+ @property
892
+ def total_batch_size(self):
893
+ return (
894
+ self.dataset.batch_size if self.split_batches else (self.dataset.batch_size * self.dataset.num_processes)
895
+ )
896
+
897
+ @property
898
+ def total_dataset_length(self):
899
+ return len(self.dataset)
900
+
901
+ def get_sampler(self):
902
+ return get_sampler(self)
903
+
904
+ def set_sampler(self, sampler):
905
+ sampler_is_batch_sampler = isinstance(self.sampler, BatchSampler)
906
+ if sampler_is_batch_sampler:
907
+ self.sampler.sampler = sampler
908
+ else:
909
+ self.batch_sampler.sampler = sampler
910
+ if hasattr(self.batch_sampler, "batch_sampler"):
911
+ self.batch_sampler.batch_sampler.sampler = sampler
912
+
913
+
914
+ def get_sampler(dataloader):
915
+ """
916
+ Get the sampler associated to the dataloader
917
+
918
+ Args:
919
+ dataloader (`torch.utils.data.dataloader.DataLoader`):
920
+ The data loader to split across several devices.
921
+ Returns:
922
+ `torch.utils.data.Sampler`: The sampler associated to the dataloader
923
+ """
924
+ sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
925
+ if sampler_is_batch_sampler:
926
+ sampler = getattr(dataloader.sampler, "sampler", None)
927
+ else:
928
+ sampler = getattr(dataloader.batch_sampler, "sampler", None)
929
+ return sampler
930
+
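+ # Illustrative sketch (not part of the library source): retrieving the sampler of a plain dataloader.
+ #
+ #     dl = DataLoader(list(range(8)), batch_size=2, shuffle=True)
+ #     isinstance(get_sampler(dl), RandomSampler)  # True, since shuffle=True uses a RandomSampler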
931
+
932
+ def prepare_data_loader(
933
+ dataloader: DataLoader,
934
+ device: Optional[torch.device] = None,
935
+ num_processes: Optional[int] = None,
936
+ process_index: Optional[int] = None,
937
+ split_batches: bool = False,
938
+ put_on_device: bool = False,
939
+ rng_types: Optional[List[Union[str, RNGType]]] = None,
940
+ dispatch_batches: Optional[bool] = None,
941
+ even_batches: bool = True,
942
+ slice_fn_for_dispatch: Optional[Callable] = None,
943
+ use_seedable_sampler: bool = False,
944
+ data_seed: Optional[int] = None,
945
+ non_blocking: bool = False,
946
+ use_stateful_dataloader: bool = False,
947
+ ) -> DataLoader:
948
+ """
949
+ Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.
950
+
951
+ Depending on the value of the `drop_last` attribute of the `dataloader` passed, it will either stop the iteration
952
+ at the first batch that would be too small / not present on all processes or loop with indices from the beginning.
953
+
954
+ Args:
955
+ dataloader (`torch.utils.data.dataloader.DataLoader`):
956
+ The data loader to split across several devices.
957
+ device (`torch.device`):
958
+ The target device for the returned `DataLoader`.
959
+ num_processes (`int`, *optional*):
960
+ The number of processes running concurrently. Will default to the value given by [`~state.PartialState`].
961
+ process_index (`int`, *optional*):
962
+ The index of the current process. Will default to the value given by [`~state.PartialState`].
963
+ split_batches (`bool`, *optional*, defaults to `False`):
964
+ Whether the resulting `DataLoader` should split the batches of the original data loader across devices or
965
+ yield full batches (in which case it will yield batches starting at the `process_index`-th and advancing by
966
+ `num_processes` batches at each iteration).
967
+
968
+ Another way to see this is that the observed batch size will be the same as the initial `dataloader` if
969
+ this option is set to `True`, the batch size of the initial `dataloader` multiplied by `num_processes`
970
+ otherwise.
971
+
972
+ Setting this option to `True` requires that the batch size of the `dataloader` is a round multiple of
973
+ `num_processes`.
974
+ put_on_device (`bool`, *optional*, defaults to `False`):
975
+ Whether or not to put the batches on `device` (only works if the batches are nested lists, tuples, or
976
+ dictionaries of tensors).
977
+ rng_types (list of `str` or [`~utils.RNGType`]):
978
+ The list of random number generators to synchronize at the beginning of each iteration. Should be one or
979
+ several of:
980
+
981
+ - `"torch"`: the base torch random number generator
982
+ - `"cuda"`: the CUDA random number generator (GPU only)
983
+ - `"xla"`: the XLA random number generator (TPU only)
984
+ - `"generator"`: the `torch.Generator` of the sampler (or batch sampler if there is no sampler in your
985
+ dataloader) or of the iterable dataset (if it exists) if the underlying dataset is of that type.
986
+
987
+ dispatch_batches (`bool`, *optional*):
988
+ If set to `True`, the dataloader prepared is only iterated through on the main process and then the batches
989
+ are split and broadcast to each process. Will default to `True` when the underlying dataset is an
990
+ `IterableDataset`, `False` otherwise.
991
+ even_batches (`bool`, *optional*, defaults to `True`):
992
+ If set to `True`, in cases where the total batch size across all processes does not exactly divide the
993
+ dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among
994
+ all workers.
995
+ slice_fn_for_dispatch (`Callable`, *optional*):
996
+ If passed, this function will be used to slice tensors across `num_processes`. Will default to
997
+ [`~utils.slice_tensors`]. This argument is used only when `dispatch_batches` is set to `True` and will be
998
+ ignored otherwise.
999
+ use_seedable_sampler (`bool`, *optional*, defaults to `False`):
1000
+ Whether to use the [`~data_loader.SeedableRandomSampler`] instead of a `RandomSampler` for better
1001
+ reproducibility. Comes at the cost of potentially different performance due to different shuffling
1002
+ algorithms but ensures results will be the *exact* same. Should be paired with `set_seed()` at every
1003
+ `self.set_epoch`
1004
+ data_seed (`int`, *optional*, defaults to `None`):
1005
+ The seed to use for the underlying generator when using `use_seedable_sampler`. If `None`, the generator
1006
+ will use the current default seed from torch.
1007
+ non_blocking (`bool`, *optional*, defaults to `False`):
1008
+ If set to `True`, dataloader will utilize non-blocking host-to-device transfers. If the dataloader has
1009
+ `pin_memory` set to `True`, this will help to increase overlap between data transfer and computations.
1010
+ use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
1011
+ "If set to true, the dataloader prepared by the Accelerator will be backed by "
1012
+ "[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader).
1013
+ This requires `torchdata` version 0.8.0 or higher (which supports StatefulDataLoader) to be installed.
1014
+
1015
+
1016
+ Returns:
1017
+ `torch.utils.data.dataloader.DataLoader`: A new data loader that will yield the portion of the batches for the current process.
1018
+
1019
+ <Tip warning={true}>
1020
+
1021
+ `BatchSampler`s with varying batch sizes are not enabled by default. To enable this behaviour, set `even_batches`
1022
+ equal to `False`
1023
+
1024
+ </Tip>
1025
+ """
1026
+ if dispatch_batches is None:
1027
+ if not put_on_device:
1028
+ dispatch_batches = False
1029
+ else:
1030
+ dispatch_batches = isinstance(dataloader.dataset, IterableDataset)
1031
+
1032
+ if dispatch_batches and not put_on_device:
1033
+ raise ValueError("Using `dispatch_batches=True` requires `put_on_device=True`.")
1034
+ # Grab defaults from PartialState
1035
+ state = PartialState()
1036
+ if num_processes is None:
1037
+ num_processes = state.num_processes
1038
+ if process_index is None:
1039
+ process_index = state.process_index
1040
+
1041
+ # Sanity check
1042
+ if split_batches:
1043
+ if dataloader.batch_size is not None:
1044
+ batch_size_for_check = dataloader.batch_size
1045
+ else:
1046
+ # For custom batch_sampler
1047
+ if hasattr(dataloader.batch_sampler, "batch_size"):
1048
+ batch_size_for_check = dataloader.batch_sampler.batch_size
1049
+ else:
1050
+ raise ValueError(
1051
+ "In order to use `split_batches==True` you must have a `batch_size` attribute either in the passed "
1052
+ "`dataloader` or `dataloader.batch_sampler` objects, and it has to return a natural number. "
1053
+ "Your `dataloader.batch_size` is None and `dataloader.batch_sampler` "
1054
+ f"(`{type(dataloader.batch_sampler)}`) does not have the `batch_size` attribute set."
1055
+ )
1056
+
1057
+ if batch_size_for_check > 1 and batch_size_for_check % num_processes != 0:
1058
+ raise ValueError(
1059
+ f"To use a `DataLoader` in `split_batches` mode, the batch size ({dataloader.batch_size}) "
1060
+ f"needs to be a round multiple of the number of processes ({num_processes})."
1061
+ )
1062
+
1063
+ new_dataset = dataloader.dataset
1064
+ # Iterable dataset doesn't like batch_sampler, but data_loader creates a default one for it
1065
+ new_batch_sampler = dataloader.batch_sampler if not isinstance(new_dataset, IterableDataset) else None
1066
+ sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
1067
+ synchronized_generator = None
1068
+
1069
+ sampler = get_sampler(dataloader)
1070
+ if isinstance(sampler, RandomSampler) and use_seedable_sampler:
1071
+ # When iterating through the dataloader during distributed processes
1072
+ # we want to ensure that on each process we are iterating through the same
1073
+ # samples in the same order if a seed is set. This requires a tweak
1074
+ # to the `torch.utils.data.RandomSampler` class (if used).
1075
+ sampler = SeedableRandomSampler(
1076
+ data_source=sampler.data_source,
1077
+ replacement=sampler.replacement,
1078
+ num_samples=sampler._num_samples,
1079
+ generator=getattr(sampler, "generator", torch.Generator()),
1080
+ data_seed=data_seed,
1081
+ )
1082
+
1083
+ if isinstance(dataloader.sampler, RandomSampler) and state.distributed_type == DistributedType.XLA:
1084
+ # isinstance(dataloader.sampler, RandomSampler) indicates the original dataloader has `shuffle` enabled.
1085
+ generator = torch.Generator().manual_seed(42)
1086
+ dataloader.generator = generator
1087
+ dataloader.sampler.generator = generator
1088
+ # No change if no multiprocess
1089
+ if (num_processes != 1 or state.distributed_type == DistributedType.MEGATRON_LM) and not dispatch_batches:
1090
+ if isinstance(new_dataset, IterableDataset):
1091
+ if getattr(dataloader.dataset, "generator", None) is not None:
1092
+ synchronized_generator = dataloader.dataset.generator
1093
+ new_dataset = IterableDatasetShard(
1094
+ new_dataset,
1095
+ batch_size=dataloader.batch_size,
1096
+ drop_last=dataloader.drop_last,
1097
+ num_processes=num_processes,
1098
+ process_index=process_index,
1099
+ split_batches=split_batches,
1100
+ )
1101
+ else:
1102
+ if not use_seedable_sampler and hasattr(sampler, "generator"):
1103
+ if sampler.generator is None:
1104
+ sampler.generator = torch.Generator()
1105
+ synchronized_generator = sampler.generator
1106
+ batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
1107
+ new_batch_sampler = BatchSamplerShard(
1108
+ batch_sampler,
1109
+ num_processes=num_processes,
1110
+ process_index=process_index,
1111
+ split_batches=split_batches,
1112
+ even_batches=even_batches,
1113
+ )
1114
+
1115
+ # We ignore all of those since they are all dealt with by our new_batch_sampler
1116
+ ignore_kwargs = [
1117
+ "batch_size",
1118
+ "shuffle",
1119
+ "sampler",
1120
+ "batch_sampler",
1121
+ "drop_last",
1122
+ ]
1123
+
1124
+ if rng_types is not None and synchronized_generator is None and "generator" in rng_types:
1125
+ rng_types.remove("generator")
1126
+
1127
+ kwargs = {
1128
+ k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k])
1129
+ for k in _PYTORCH_DATALOADER_KWARGS
1130
+ if k not in ignore_kwargs
1131
+ }
1132
+
1133
+ # Need to provide batch_size as batch_sampler is None for Iterable dataset
1134
+ if new_batch_sampler is None:
1135
+ kwargs["drop_last"] = dataloader.drop_last
1136
+ kwargs["batch_size"] = (
1137
+ dataloader.batch_size // num_processes if split_batches and not dispatch_batches else dataloader.batch_size
1138
+ )
1139
+ if dispatch_batches:
1140
+ kwargs.pop("generator")
1141
+ dataloader = DataLoaderDispatcher(
1142
+ new_dataset,
1143
+ split_batches=split_batches,
1144
+ batch_sampler=new_batch_sampler,
1145
+ _drop_last=dataloader.drop_last,
1146
+ _non_blocking=non_blocking,
1147
+ slice_fn=slice_fn_for_dispatch,
1148
+ use_stateful_dataloader=use_stateful_dataloader,
1149
+ **kwargs,
1150
+ )
1151
+ elif sampler_is_batch_sampler:
1152
+ dataloader = DataLoaderShard(
1153
+ new_dataset,
1154
+ device=device if put_on_device and state.distributed_type != DistributedType.XLA else None,
1155
+ sampler=new_batch_sampler,
1156
+ batch_size=dataloader.batch_size,
1157
+ rng_types=rng_types,
1158
+ _drop_last=dataloader.drop_last,
1159
+ _non_blocking=non_blocking,
1160
+ synchronized_generator=synchronized_generator,
1161
+ use_stateful_dataloader=use_stateful_dataloader,
1162
+ **kwargs,
1163
+ )
1164
+ else:
1165
+ dataloader = DataLoaderShard(
1166
+ new_dataset,
1167
+ device=device if put_on_device and state.distributed_type != DistributedType.XLA else None,
1168
+ batch_sampler=new_batch_sampler,
1169
+ rng_types=rng_types,
1170
+ synchronized_generator=synchronized_generator,
1171
+ _drop_last=dataloader.drop_last,
1172
+ _non_blocking=non_blocking,
1173
+ use_stateful_dataloader=use_stateful_dataloader,
1174
+ **kwargs,
1175
+ )
1176
+
1177
+ if isinstance(sampler, SeedableRandomSampler) and use_seedable_sampler:
1178
+ dataloader.set_sampler(sampler)
1179
+ if state.distributed_type == DistributedType.XLA:
1180
+ return MpDeviceLoaderWrapper(dataloader, device)
1181
+ return dataloader
1182
+
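+ # Illustrative sketch (not part of the library source): sharding a plain dataloader for one of two
+ # processes on CPU. The toy dataset is hypothetical; under `Accelerator.prepare` the process count
+ # and index are filled in from the current distributed state instead of being passed explicitly.
+ #
+ #     dl = DataLoader(list(range(16)), batch_size=4)
+ #     shard = prepare_data_loader(dl, device=torch.device("cpu"), num_processes=2, process_index=0,
+ #                                 put_on_device=True)
+ #     [b.tolist() for b in shard]  # [[0, 1, 2, 3], [8, 9, 10, 11]]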
1183
+
1184
+ class SkipBatchSampler(BatchSampler):
1185
+ """
1186
+ A `torch.utils.data.BatchSampler` that skips the first `n` batches of another `torch.utils.data.BatchSampler`.
1187
+ Should not be used if the original dataloader is a `StatefulDataLoader`.
1188
+ """
1189
+
1190
+ def __init__(self, batch_sampler, skip_batches=0):
1191
+ self.batch_sampler = batch_sampler
1192
+ self.skip_batches = skip_batches
1193
+
1194
+ def __iter__(self):
1195
+ for index, samples in enumerate(self.batch_sampler):
1196
+ if index >= self.skip_batches:
1197
+ yield samples
1198
+
1199
+ @property
1200
+ def total_length(self):
1201
+ return len(self.batch_sampler)
1202
+
1203
+ def __len__(self):
1204
+ return len(self.batch_sampler) - self.skip_batches
1205
+
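+ # Illustrative sketch (not part of the library source): skipping the first batch of a standard
+ # `BatchSampler` when resuming mid-epoch.
+ #
+ #     from torch.utils.data import SequentialSampler
+ #     base = BatchSampler(SequentialSampler(range(12)), batch_size=4, drop_last=False)
+ #     list(SkipBatchSampler(base, skip_batches=1))  # [[4, 5, 6, 7], [8, 9, 10, 11]]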
1206
+
1207
+ class SkipDataLoader(DataLoaderAdapter, DataLoaderStateMixin):
1208
+ """
1209
+ Subclass of a PyTorch `DataLoader` that will skip the first batches. Generally it's preferable to use
1210
+ `skip_first_batches`/`torchdata.StatefulDataLoader` instead of this class.
1211
+
1212
+ Args:
1213
+ dataset (`torch.utils.data.dataset.Dataset`):
1214
+ The dataset to use to build this dataloader.
1215
+ skip_batches (`int`, *optional*, defaults to 0):
1216
+ The number of batches to skip at the beginning.
1217
+ kwargs:
1218
+ All other keyword arguments to pass to the regular `DataLoader` initialization.
1219
+ """
1220
+
1221
+ def __init__(self, dataset, skip_batches=0, use_stateful_dataloader=False, **kwargs):
1222
+ super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
1223
+ self.skip_batches = skip_batches
1224
+ self.gradient_state = GradientState()
1225
+
1226
+ def __iter__(self):
1227
+ self.begin()
1228
+ for index, batch in enumerate(self.base_dataloader.__iter__()):
1229
+ if index >= self.skip_batches:
1230
+ self._update_state_dict()
1231
+ yield batch
1232
+ self.end()
1233
+
1234
+ def __len__(self):
1235
+ return len(self.base_dataloader) - self.skip_batches
1236
+
1237
+ def __reduce__(self):
1238
+ """
1239
+ Define the `__reduce__` method to ensure a `SkipDataLoader` can be pickled and unpickled. This needs to be
1240
+ explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its
1241
+ `__class__` member.
1242
+ """
1243
+ args = super().__reduce__()
1244
+ return (SkipDataLoader, *args[1:])
1245
+
1246
+
1247
+ def skip_first_batches(dataloader, num_batches=0):
1248
+ """
1249
+ Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`. Should not be used if
1250
+ the original dataloader is a `StatefulDataLoader`.
1251
+ """
1252
+ state = PartialState()
1253
+ if state.distributed_type == DistributedType.XLA:
1254
+ device = dataloader.device
1255
+ dataloader = dataloader.dataloader
1256
+
1257
+ dataset = dataloader.dataset
1258
+ sampler_is_batch_sampler = False
1259
+ if isinstance(dataset, IterableDataset):
1260
+ new_batch_sampler = None
1261
+ else:
1262
+ sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
1263
+ batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
1264
+ new_batch_sampler = SkipBatchSampler(batch_sampler, skip_batches=num_batches)
1265
+
1266
+ # We ignore all of those since they are all dealt with by our new_batch_sampler
1267
+ ignore_kwargs = [
1268
+ "batch_size",
1269
+ "shuffle",
1270
+ "sampler",
1271
+ "batch_sampler",
1272
+ "drop_last",
1273
+ ]
1274
+
1275
+ kwargs = {
1276
+ k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k])
1277
+ for k in _PYTORCH_DATALOADER_KWARGS
1278
+ if k not in ignore_kwargs
1279
+ }
1280
+
1281
+ # Need to provide batch_size as batch_sampler is None for Iterable dataset
1282
+ if new_batch_sampler is None:
1283
+ kwargs["drop_last"] = dataloader.drop_last
1284
+ kwargs["batch_size"] = dataloader.batch_size
1285
+
1286
+ if isinstance(dataloader, DataLoaderDispatcher):
1287
+ if new_batch_sampler is None:
1288
+ # Need to manually skip batches in the dataloader
1289
+ kwargs["skip_batches"] = num_batches
1290
+ dataloader = DataLoaderDispatcher(
1291
+ dataset,
1292
+ split_batches=dataloader.split_batches,
1293
+ batch_sampler=new_batch_sampler,
1294
+ _drop_last=dataloader._drop_last,
1295
+ **kwargs,
1296
+ )
1297
+ elif isinstance(dataloader, DataLoaderShard):
1298
+ if new_batch_sampler is None:
1299
+ # Need to manually skip batches in the dataloader
1300
+ kwargs["skip_batches"] = num_batches
1301
+ elif sampler_is_batch_sampler:
1302
+ kwargs["sampler"] = new_batch_sampler
1303
+ kwargs["batch_size"] = dataloader.batch_size
1304
+ else:
1305
+ kwargs["batch_sampler"] = new_batch_sampler
1306
+ dataloader = DataLoaderShard(
1307
+ dataset,
1308
+ device=dataloader.device,
1309
+ rng_types=dataloader.rng_types,
1310
+ synchronized_generator=dataloader.synchronized_generator,
1311
+ **kwargs,
1312
+ )
1313
+ else:
1314
+ if new_batch_sampler is None:
1315
+ # Need to manually skip batches in the dataloader
1316
+ dataloader = SkipDataLoader(dataset, skip_batches=num_batches, **kwargs)
1317
+ else:
1318
+ dataloader = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs)
1319
+
1320
+ if state.distributed_type == DistributedType.XLA:
1321
+ dataloader = MpDeviceLoaderWrapper(dataloader, device)
1322
+
1323
+ return dataloader
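+ # Illustrative sketch (not part of the library source): resuming a plain dataloader two batches in.
+ # The toy dataset below is hypothetical.
+ #
+ #     dl = DataLoader(list(range(16)), batch_size=4)
+ #     resumed = skip_first_batches(dl, num_batches=2)
+ #     [b.tolist() for b in resumed]  # [[8, 9, 10, 11], [12, 13, 14, 15]]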
.venv/Lib/site-packages/accelerate/hooks.py ADDED
@@ -0,0 +1,726 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import functools
16
+ from typing import Dict, List, Mapping, Optional, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from .state import PartialState
22
+ from .utils import (
23
+ PrefixedDataset,
24
+ find_device,
25
+ named_module_tensors,
26
+ send_to_device,
27
+ set_module_tensor_to_device,
28
+ )
29
+ from .utils.memory import clear_device_cache
30
+ from .utils.modeling import get_non_persistent_buffers
31
+ from .utils.other import recursive_getattr
32
+
33
+
34
+ _accelerate_added_attributes = ["to", "cuda", "npu", "xpu", "mlu", "musa"]
35
+
36
+
37
+ class ModelHook:
38
+ """
39
+ A hook that contains callbacks to be executed just before and after the forward method of a model. The difference
40
+ with PyTorch existing hooks is that they get passed along the kwargs.
41
+
42
+ Class attribute:
43
+ - **no_grad** (`bool`, *optional*, defaults to `False`) -- Whether or not to execute the actual forward pass under
44
+ the `torch.no_grad()` context manager.
45
+ """
46
+
47
+ no_grad = False
48
+
49
+ def init_hook(self, module):
50
+ """
51
+ To be executed when the hook is attached to the module.
52
+
53
+ Args:
54
+ module (`torch.nn.Module`): The module attached to this hook.
55
+ """
56
+ return module
57
+
58
+ def pre_forward(self, module, *args, **kwargs):
59
+ """
60
+ To be executed just before the forward method of the model.
61
+
62
+ Args:
63
+ module (`torch.nn.Module`): The module whose forward pass will be executed just after this event.
64
+ args (`Tuple[Any]`): The positional arguments passed to the module.
65
+ kwargs (`Dict[Str, Any]`): The keyword arguments passed to the module.
66
+
67
+ Returns:
68
+ `Tuple[Tuple[Any], Dict[Str, Any]]`: A tuple with the treated `args` and `kwargs`.
69
+ """
70
+ return args, kwargs
71
+
72
+ def post_forward(self, module, output):
73
+ """
74
+ To be executed just after the forward method of the model.
75
+
76
+ Args:
77
+ module (`torch.nn.Module`): The module whose forward pass been executed just before this event.
78
+ output (`Any`): The output of the module.
79
+
80
+ Returns:
81
+ `Any`: The processed `output`.
82
+ """
83
+ return output
84
+
85
+ def detach_hook(self, module):
86
+ """
87
+ To be executed when the hook is detached from a module.
88
+
89
+ Args:
90
+ module (`torch.nn.Module`): The module detached from this hook.
91
+ """
92
+ return module
93
+
94
+
95
+ class SequentialHook(ModelHook):
96
+ """
97
+ A hook that can contain several hooks and iterates through them at each event.
98
+ """
99
+
100
+ def __init__(self, *hooks):
101
+ self.hooks = hooks
102
+
103
+ def init_hook(self, module):
104
+ for hook in self.hooks:
105
+ module = hook.init_hook(module)
106
+ return module
107
+
108
+ def pre_forward(self, module, *args, **kwargs):
109
+ for hook in self.hooks:
110
+ args, kwargs = hook.pre_forward(module, *args, **kwargs)
111
+ return args, kwargs
112
+
113
+ def post_forward(self, module, output):
114
+ for hook in self.hooks:
115
+ output = hook.post_forward(module, output)
116
+ return output
117
+
118
+ def detach_hook(self, module):
119
+ for hook in self.hooks:
120
+ module = hook.detach_hook(module)
121
+ return module
122
+
123
+
124
+ def add_hook_to_module(module: nn.Module, hook: ModelHook, append: bool = False):
125
+ """
126
+ Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook, to remove
127
+ this behavior and restore the original `forward` method, use `remove_hook_from_module`.
128
+
129
+ <Tip warning={true}>
130
+
131
+ If the module already contains a hook, this will replace it with the new hook passed by default. To chain two hooks
132
+ together, pass `append=True`, so it chains the current and new hook into an instance of the `SequentialHook` class.
133
+
134
+ </Tip>
135
+
136
+ Args:
137
+ module (`torch.nn.Module`):
138
+ The module to attach a hook to.
139
+ hook (`ModelHook`):
140
+ The hook to attach.
141
+ append (`bool`, *optional*, defaults to `False`):
142
+ Whether the hook should be chained with an existing one (if module already contains a hook) or not.
143
+
144
+ Returns:
145
+ `torch.nn.Module`: The same module, with the hook attached (the module is modified in place, so the result can
146
+ be discarded).
147
+ """
148
+
149
+ if append and (getattr(module, "_hf_hook", None) is not None):
150
+ old_hook = module._hf_hook
151
+ remove_hook_from_module(module)
152
+ hook = SequentialHook(old_hook, hook)
153
+
154
+ if hasattr(module, "_hf_hook") and hasattr(module, "_old_forward"):
155
+ # If we already put some hook on this module, we replace it with the new one.
156
+ old_forward = module._old_forward
157
+ else:
158
+ old_forward = module.forward
159
+ module._old_forward = old_forward
160
+
161
+ module = hook.init_hook(module)
162
+ module._hf_hook = hook
163
+
164
+ def new_forward(module, *args, **kwargs):
165
+ args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
166
+ if module._hf_hook.no_grad:
167
+ with torch.no_grad():
168
+ output = module._old_forward(*args, **kwargs)
169
+ else:
170
+ output = module._old_forward(*args, **kwargs)
171
+ return module._hf_hook.post_forward(module, output)
172
+
173
+ # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
174
+ # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
175
+ if "GraphModuleImpl" in str(type(module)):
176
+ module.__class__.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)
177
+ else:
178
+ module.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)
179
+
180
+ return module
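
For illustration, here is a minimal usage sketch of the hook API above (not part of this file): a hypothetical `DeviceLoggingHook` subclassing `ModelHook`, attached with `add_hook_to_module` and later removed with `remove_hook_from_module`.

```python
# Minimal illustration (not part of hooks.py): attach and remove a custom hook.
import torch
import torch.nn as nn

from accelerate.hooks import ModelHook, add_hook_to_module, remove_hook_from_module


class DeviceLoggingHook(ModelHook):  # hypothetical example hook
    def pre_forward(self, module, *args, **kwargs):
        # Runs just before the wrapped forward; args/kwargs can be modified here.
        print("input device:", args[0].device)
        return args, kwargs

    def post_forward(self, module, output):
        # Runs just after the wrapped forward; the output can be post-processed here.
        return output


layer = nn.Linear(4, 4)
add_hook_to_module(layer, DeviceLoggingHook())
_ = layer(torch.randn(2, 4))        # prints "input device: cpu"
remove_hook_from_module(layer)      # restores the original forward
```
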
181
+
182
+
183
+ def remove_hook_from_module(module: nn.Module, recurse=False):
184
+ """
185
+ Removes any hook attached to a module via `add_hook_to_module`.
186
+
187
+ Args:
188
+ module (`torch.nn.Module`): The module to detach the hook from.
189
+ recurse (`bool`, *optional*, defaults to `False`): Whether to remove the hooks recursively.
190
+
191
+ Returns:
192
+ `torch.nn.Module`: The same module, with the hook detached (the module is modified in place, so the result can
193
+ be discarded).
194
+ """
195
+
196
+ if hasattr(module, "_hf_hook"):
197
+ module._hf_hook.detach_hook(module)
198
+ delattr(module, "_hf_hook")
199
+
200
+ if hasattr(module, "_old_forward"):
201
+ # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
202
+ # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
203
+ if "GraphModuleImpl" in str(type(module)):
204
+ module.__class__.forward = module._old_forward
205
+ else:
206
+ module.forward = module._old_forward
207
+ delattr(module, "_old_forward")
208
+
209
+ # Remove accelerate added warning hooks from dispatch_model
210
+ for attr in _accelerate_added_attributes:
211
+ module.__dict__.pop(attr, None)
212
+
213
+ if recurse:
214
+ for child in module.children():
215
+ remove_hook_from_module(child, recurse)
216
+
217
+ return module
218
+
219
+
220
+ class AlignDevicesHook(ModelHook):
221
+ """
222
+ A generic `ModelHook` that ensures inputs and model weights are on the same device for the forward pass of the
223
+ associated module, potentially offloading the weights after the forward pass.
224
+
225
+ Args:
226
+ execution_device (`torch.device`, *optional*):
227
+ The device on which inputs and model weights should be placed before the forward pass.
228
+ offload (`bool`, *optional*, defaults to `False`):
229
+ Whether or not the weights should be offloaded after the forward pass.
230
+ io_same_device (`bool`, *optional*, defaults to `False`):
231
+ Whether or not the output should be placed on the same device as the input was.
232
+ weights_map (`Mapping[str, torch.Tensor]`, *optional*):
233
+ When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.
234
+ offload_buffers (`bool`, *optional*, defaults to `False`):
235
+ Whether or not to include the associated module's buffers when offloading.
236
+ place_submodules (`bool`, *optional*, defaults to `False`):
237
+ Whether to place the submodules on `execution_device` during the `init_hook` event.
238
+ """
239
+
240
+ def __init__(
241
+ self,
242
+ execution_device: Optional[Union[int, str, torch.device]] = None,
243
+ offload: bool = False,
244
+ io_same_device: bool = False,
245
+ weights_map: Optional[Mapping] = None,
246
+ offload_buffers: bool = False,
247
+ place_submodules: bool = False,
248
+ skip_keys: Optional[Union[str, List[str]]] = None,
249
+ tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None,
250
+ ):
251
+ self.execution_device = execution_device
252
+ self.offload = offload
253
+ self.io_same_device = io_same_device
254
+ self.weights_map = weights_map
255
+ self.offload_buffers = offload_buffers
256
+ self.place_submodules = place_submodules
257
+ self.skip_keys = skip_keys
258
+
259
+ # Will contain the input device when `io_same_device=True`.
260
+ self.input_device = None
261
+ self.param_original_devices = {}
262
+ self.buffer_original_devices = {}
263
+ self.tied_params_names = set()
264
+
265
+ # The hook pre_forward/post_forward need to have knowledge of this dictionary, as with offloading we want to avoid duplicating memory
266
+ # for tied weights already loaded on the target execution device.
267
+ self.tied_params_map = tied_params_map
268
+
269
+ def __repr__(self):
270
+ return (
271
+ f"AlignDevicesHook(execution_device={self.execution_device}, offload={self.offload}, "
272
+ f"io_same_device={self.io_same_device}, offload_buffers={self.offload_buffers}, "
273
+ f"place_submodules={self.place_submodules}, skip_keys={repr(self.skip_keys)})"
274
+ )
275
+
276
+ def init_hook(self, module):
277
+ # In case the AlignDevicesHook is on meta device, ignore tied weights as data_ptr() is then always zero.
278
+ if self.execution_device == "meta" or self.execution_device == torch.device("meta"):
279
+ self.tied_params_map = None
280
+
281
+ if not self.offload and self.execution_device is not None:
282
+ for name, _ in named_module_tensors(module, recurse=self.place_submodules):
283
+ set_module_tensor_to_device(module, name, self.execution_device, tied_params_map=self.tied_params_map)
284
+ elif self.offload:
285
+ self.original_devices = {
286
+ name: param.device for name, param in named_module_tensors(module, recurse=self.place_submodules)
287
+ }
288
+ if self.weights_map is None:
289
+ self.weights_map = {
290
+ name: param.to("cpu")
291
+ for name, param in named_module_tensors(
292
+ module, include_buffers=self.offload_buffers, recurse=self.place_submodules
293
+ )
294
+ }
295
+ for name, _ in named_module_tensors(
296
+ module, include_buffers=self.offload_buffers, recurse=self.place_submodules, remove_non_persistent=True
297
+ ):
298
+ # When using disk offloading, we can not rely on `weights_map[name].data_ptr()` as the reference pointer,
299
+ # as we have no guarantee that safetensors' `file.get_tensor()` will always give the same pointer.
300
+ # As we have no reliable way to track the shared data pointer of tied weights in this case, we use tied_params_names: List[str]
301
+ # to add on the fly pointers to `tied_params_map` in the pre_forward call.
302
+ if (
303
+ self.tied_params_map is not None
304
+ and recursive_getattr(module, name).data_ptr() in self.tied_params_map
305
+ ):
306
+ self.tied_params_names.add(name)
307
+
308
+ set_module_tensor_to_device(module, name, "meta")
309
+
310
+ if not self.offload_buffers and self.execution_device is not None:
311
+ for name, _ in module.named_buffers(recurse=self.place_submodules):
312
+ set_module_tensor_to_device(
313
+ module, name, self.execution_device, tied_params_map=self.tied_params_map
314
+ )
315
+ elif self.offload_buffers and self.execution_device is not None:
316
+ for name in get_non_persistent_buffers(module, recurse=self.place_submodules):
317
+ set_module_tensor_to_device(
318
+ module, name, self.execution_device, tied_params_map=self.tied_params_map
319
+ )
320
+
321
+ return module
322
+
323
+ def pre_forward(self, module, *args, **kwargs):
324
+ if self.io_same_device:
325
+ self.input_device = find_device([args, kwargs])
326
+ if self.offload:
327
+ self.tied_pointers_to_remove = set()
328
+
329
+ for name, _ in named_module_tensors(
330
+ module,
331
+ include_buffers=self.offload_buffers,
332
+ recurse=self.place_submodules,
333
+ remove_non_persistent=True,
334
+ ):
335
+ fp16_statistics = None
336
+ value = self.weights_map[name]
337
+ if "weight" in name and name.replace("weight", "SCB") in self.weights_map.keys():
338
+ if value.dtype == torch.int8:
339
+ fp16_statistics = self.weights_map[name.replace("weight", "SCB")]
340
+
341
+ # In case we are using offloading with tied weights, we need to keep track of the offloaded weights
342
+ # that are loaded on device at this point, as we will need to remove them as well from the dictionary
343
+ # self.tied_params_map in order to allow the memory to be freed.
344
+ if name in self.tied_params_names and value.data_ptr() not in self.tied_params_map:
345
+ self.tied_params_map[value.data_ptr()] = {}
346
+
347
+ if (
348
+ value is not None
349
+ and self.tied_params_map is not None
350
+ and value.data_ptr() in self.tied_params_map
351
+ and self.execution_device not in self.tied_params_map[value.data_ptr()]
352
+ ):
353
+ self.tied_pointers_to_remove.add((value.data_ptr(), self.execution_device))
354
+
355
+ set_module_tensor_to_device(
356
+ module,
357
+ name,
358
+ self.execution_device,
359
+ value=value,
360
+ fp16_statistics=fp16_statistics,
361
+ tied_params_map=self.tied_params_map,
362
+ )
363
+
364
+ return send_to_device(args, self.execution_device), send_to_device(
365
+ kwargs, self.execution_device, skip_keys=self.skip_keys
366
+ )
367
+
368
+ def post_forward(self, module, output):
369
+ if self.offload:
370
+ for name, _ in named_module_tensors(
371
+ module,
372
+ include_buffers=self.offload_buffers,
373
+ recurse=self.place_submodules,
374
+ remove_non_persistent=True,
375
+ ):
376
+ set_module_tensor_to_device(module, name, "meta")
377
+ if type(module).__name__ == "Linear8bitLt":
378
+ module.state.SCB = None
379
+ module.state.CxB = None
380
+
381
+ # We may have loaded tied weights into self.tied_params_map (avoiding loading them several times in e.g. submodules): remove them from
382
+ # this dictionary to allow the garbage collector to do its job.
383
+ for value_pointer, device in self.tied_pointers_to_remove:
384
+ del self.tied_params_map[value_pointer][device]
385
+ self.tied_pointers_to_remove = set()
386
+
387
+ if self.io_same_device and self.input_device is not None:
388
+ output = send_to_device(output, self.input_device, skip_keys=self.skip_keys)
389
+
390
+ return output
391
+
392
+ def detach_hook(self, module):
393
+ if self.offload:
394
+ for name, device in self.original_devices.items():
395
+ if device != torch.device("meta"):
396
+ set_module_tensor_to_device(module, name, device, value=self.weights_map.get(name, None))
397
+ return module
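
As an illustrative sketch of `AlignDevicesHook` in its simplest (non-offloading) mode; a CUDA device is assumed here, so the device string is an assumption for the example:

```python
# Illustration: keep a module and whatever inputs it receives on the same device.
# A CUDA device is assumed; adapt the device string to your setup.
import torch
import torch.nn as nn

from accelerate.hooks import AlignDevicesHook, add_hook_to_module

model = nn.Linear(8, 8)
hook = AlignDevicesHook(execution_device="cuda:0", io_same_device=True)
add_hook_to_module(model, hook)     # init_hook moves the weights to cuda:0

cpu_input = torch.randn(2, 8)       # input created on the CPU
output = model(cpu_input)           # input moved to cuda:0 for the forward
print(output.device)                # back on CPU, because io_same_device=True
```
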
398
+
399
+
400
+ def attach_execution_device_hook(
401
+ module: torch.nn.Module,
402
+ execution_device: Union[int, str, torch.device],
403
+ skip_keys: Optional[Union[str, List[str]]] = None,
404
+ preload_module_classes: Optional[List[str]] = None,
405
+ tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None,
406
+ ):
407
+ """
408
+ Recursively attaches `AlignDevicesHook` to all submodules of a given model to make sure they have the right
409
+ execution device
410
+
411
+ Args:
412
+ module (`torch.nn.Module`):
413
+ The module where we want to attach the hooks.
414
+ execution_device (`int`, `str` or `torch.device`):
415
+ The device on which inputs and model weights should be placed before the forward pass.
416
+ skip_keys (`str` or `List[str]`, *optional*):
417
+ A list of keys to ignore when moving inputs or outputs between devices.
418
+ preload_module_classes (`List[str]`, *optional*):
419
+ A list of classes whose instances should load all their weights (even in the submodules) at the beginning
420
+ of the forward. This should only be used for classes that have submodules which are registered but not
421
+ called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
422
+ `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
423
+ tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`):
424
+ A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution
425
+ device, this parameter is useful to reuse the first available pointer of a shared weight for all others,
426
+ instead of duplicating memory.
427
+ """
428
+ if not hasattr(module, "_hf_hook") and len(module.state_dict()) > 0:
429
+ add_hook_to_module(
430
+ module,
431
+ AlignDevicesHook(execution_device, skip_keys=skip_keys, tied_params_map=tied_params_map),
432
+ )
433
+
434
+ # Break the recursion if we get to a preload module.
435
+ if preload_module_classes is not None and module.__class__.__name__ in preload_module_classes:
436
+ return
437
+
438
+ for child in module.children():
439
+ attach_execution_device_hook(
440
+ child,
441
+ execution_device,
442
+ skip_keys=skip_keys,
443
+ preload_module_classes=preload_module_classes,
444
+ tied_params_map=tied_params_map,
445
+ )
446
+
447
+
448
+ def attach_align_device_hook(
449
+ module: torch.nn.Module,
450
+ execution_device: Optional[torch.device] = None,
451
+ offload: bool = False,
452
+ weights_map: Optional[Mapping] = None,
453
+ offload_buffers: bool = False,
454
+ module_name: str = "",
455
+ skip_keys: Optional[Union[str, List[str]]] = None,
456
+ preload_module_classes: Optional[List[str]] = None,
457
+ tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None,
458
+ ):
459
+ """
460
+ Recursively attaches `AlignDevicesHook` to all submodules of a given model that have direct parameters and/or
461
+ buffers.
462
+
463
+ Args:
464
+ module (`torch.nn.Module`):
465
+ The module where we want to attach the hooks.
466
+ execution_device (`torch.device`, *optional*):
467
+ The device on which inputs and model weights should be placed before the forward pass.
468
+ offload (`bool`, *optional*, defaults to `False`):
469
+ Whether or not the weights should be offloaded after the forward pass.
470
+ weights_map (`Mapping[str, torch.Tensor]`, *optional*):
471
+ When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.
472
+ offload_buffers (`bool`, *optional*, defaults to `False`):
473
+ Whether or not to include the associated module's buffers when offloading.
474
+ module_name (`str`, *optional*, defaults to `""`):
475
+ The name of the module.
476
+ skip_keys (`str` or `List[str]`, *optional*):
477
+ A list of keys to ignore when moving inputs or outputs between devices.
478
+ preload_module_classes (`List[str]`, *optional*):
479
+ A list of classes whose instances should load all their weights (even in the submodules) at the beginning
480
+ of the forward. This should only be used for classes that have submodules which are registered but not
481
+ called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
482
+ `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
483
+ tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`):
484
+ A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution
485
+ device, this parameter is useful to reuse the first available pointer of a shared weight for all others,
486
+ instead of duplicating memory.
487
+ """
488
+ # Attach the hook on this module if it has any direct tensor.
489
+ directs = named_module_tensors(module)
490
+ full_offload = (
491
+ offload and preload_module_classes is not None and module.__class__.__name__ in preload_module_classes
492
+ )
493
+
494
+ if len(list(directs)) > 0 or full_offload:
495
+ if weights_map is not None:
496
+ prefix = f"{module_name}." if len(module_name) > 0 else ""
497
+ prefixed_weights_map = PrefixedDataset(weights_map, prefix)
498
+ else:
499
+ prefixed_weights_map = None
500
+ hook = AlignDevicesHook(
501
+ execution_device=execution_device,
502
+ offload=offload,
503
+ weights_map=prefixed_weights_map,
504
+ offload_buffers=offload_buffers,
505
+ place_submodules=full_offload,
506
+ skip_keys=skip_keys,
507
+ tied_params_map=tied_params_map,
508
+ )
509
+ add_hook_to_module(module, hook, append=True)
510
+
511
+ # We stop the recursion in case we hit the full offload.
512
+ if full_offload:
513
+ return
514
+
515
+ # Recurse on all children of the module.
516
+ for child_name, child in module.named_children():
517
+ child_name = f"{module_name}.{child_name}" if len(module_name) > 0 else child_name
518
+ attach_align_device_hook(
519
+ child,
520
+ execution_device=execution_device,
521
+ offload=offload,
522
+ weights_map=weights_map,
523
+ offload_buffers=offload_buffers,
524
+ module_name=child_name,
525
+ preload_module_classes=preload_module_classes,
526
+ skip_keys=skip_keys,
527
+ tied_params_map=tied_params_map,
528
+ )
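
A minimal offloading sketch using `attach_align_device_hook`: with `offload=True` and no `weights_map`, the hook builds a CPU copy of the weights itself and keeps the parameters on the `meta` device between forwards. The tiny model below is a placeholder.

```python
# Illustration: offload weights between forwards with attach_align_device_hook.
import torch
import torch.nn as nn

from accelerate.hooks import attach_align_device_hook, remove_hook_from_submodules

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
attach_align_device_hook(model, execution_device=torch.device("cpu"), offload=True)

print(model[0].weight.device)        # meta: weights are offloaded between calls
out = model(torch.randn(1, 8))       # weights are materialized just for this forward
remove_hook_from_submodules(model)   # detach all hooks and restore the weights
```
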
529
+
530
+
531
+ def remove_hook_from_submodules(module: nn.Module):
532
+ """
533
+ Recursively removes all hooks attached on the submodules of a given model.
534
+
535
+ Args:
536
+ module (`torch.nn.Module`): The module on which to remove all hooks.
537
+ """
538
+ remove_hook_from_module(module)
539
+ for child in module.children():
540
+ remove_hook_from_submodules(child)
541
+
542
+
543
+ def attach_align_device_hook_on_blocks(
544
+ module: nn.Module,
545
+ execution_device: Optional[Union[torch.device, Dict[str, torch.device]]] = None,
546
+ offload: Union[bool, Dict[str, bool]] = False,
547
+ weights_map: Mapping = None,
548
+ offload_buffers: bool = False,
549
+ module_name: str = "",
550
+ skip_keys: Optional[Union[str, List[str]]] = None,
551
+ preload_module_classes: Optional[List[str]] = None,
552
+ tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None,
553
+ ):
554
+ """
555
+ Attaches `AlignDevicesHook` to all blocks of a given model as needed.
556
+
557
+ Args:
558
+ module (`torch.nn.Module`):
559
+ The module where we want to attach the hooks.
560
+ execution_device (`torch.device` or `Dict[str, torch.device]`, *optional*):
561
+ The device on which inputs and model weights should be placed before the forward pass. It can be one device
562
+ for the whole module, or a dictionary mapping module name to device.
563
+ offload (`bool`, *optional*, defaults to `False`):
564
+ Whether or not the weights should be offloaded after the forward pass. It can be one boolean for the whole
565
+ module, or a dictionary mapping module name to boolean.
566
+ weights_map (`Mapping[str, torch.Tensor]`, *optional*):
567
+ When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.
568
+ offload_buffers (`bool`, *optional*, defaults to `False`):
569
+ Whether or not to include the associated module's buffers when offloading.
570
+ module_name (`str`, *optional*, defaults to `""`):
571
+ The name of the module.
572
+ skip_keys (`str` or `List[str]`, *optional*):
573
+ A list of keys to ignore when moving inputs or outputs between devices.
574
+ preload_module_classes (`List[str]`, *optional*):
575
+ A list of classes whose instances should load all their weights (even in the submodules) at the beginning
576
+ of the forward. This should only be used for classes that have submodules which are registered but not
577
+ called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
578
+ `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
579
+ tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`):
580
+ A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution
581
+ device, this parameter is useful to reuse the first available pointer of a shared weight for all others,
582
+ instead of duplicating memory.
583
+ """
584
+ # If one device and one offload, we've got one hook.
585
+ if not isinstance(execution_device, Mapping) and not isinstance(offload, dict):
586
+ if not offload:
587
+ hook = AlignDevicesHook(
588
+ execution_device=execution_device,
589
+ io_same_device=True,
590
+ skip_keys=skip_keys,
591
+ place_submodules=True,
592
+ tied_params_map=tied_params_map,
593
+ )
594
+ add_hook_to_module(module, hook)
595
+ else:
596
+ attach_align_device_hook(
597
+ module,
598
+ execution_device=execution_device,
599
+ offload=True,
600
+ weights_map=weights_map,
601
+ offload_buffers=offload_buffers,
602
+ module_name=module_name,
603
+ skip_keys=skip_keys,
604
+ tied_params_map=tied_params_map,
605
+ )
606
+ return
607
+
608
+ if not isinstance(execution_device, Mapping):
609
+ execution_device = {key: execution_device for key in offload.keys()}
610
+ if not isinstance(offload, Mapping):
611
+ offload = {key: offload for key in execution_device.keys()}
612
+
613
+ if module_name in execution_device and module_name in offload and not offload[module_name]:
614
+ hook = AlignDevicesHook(
615
+ execution_device=execution_device[module_name],
616
+ offload_buffers=offload_buffers,
617
+ io_same_device=(module_name == ""),
618
+ place_submodules=True,
619
+ skip_keys=skip_keys,
620
+ tied_params_map=tied_params_map,
621
+ )
622
+ add_hook_to_module(module, hook)
623
+ attach_execution_device_hook(
624
+ module, execution_device[module_name], skip_keys=skip_keys, tied_params_map=tied_params_map
625
+ )
626
+ elif module_name in execution_device and module_name in offload:
627
+ attach_align_device_hook(
628
+ module,
629
+ execution_device=execution_device[module_name],
630
+ offload=True,
631
+ weights_map=weights_map,
632
+ offload_buffers=offload_buffers,
633
+ module_name=module_name,
634
+ skip_keys=skip_keys,
635
+ preload_module_classes=preload_module_classes,
636
+ tied_params_map=tied_params_map,
637
+ )
638
+ if not hasattr(module, "_hf_hook"):
639
+ hook = AlignDevicesHook(
640
+ execution_device=execution_device[module_name],
641
+ io_same_device=(module_name == ""),
642
+ skip_keys=skip_keys,
643
+ tied_params_map=tied_params_map,
644
+ )
645
+ add_hook_to_module(module, hook)
646
+ attach_execution_device_hook(
647
+ module,
648
+ execution_device[module_name],
649
+ preload_module_classes=preload_module_classes,
650
+ skip_keys=skip_keys,
651
+ tied_params_map=tied_params_map,
652
+ )
653
+ elif module_name == "":
654
+ hook = AlignDevicesHook(
655
+ execution_device=execution_device.get(""),
656
+ io_same_device=True,
657
+ skip_keys=skip_keys,
658
+ tied_params_map=tied_params_map,
659
+ )
660
+ add_hook_to_module(module, hook)
661
+
662
+ for child_name, child in module.named_children():
663
+ child_name = f"{module_name}.{child_name}" if len(module_name) > 0 else child_name
664
+ attach_align_device_hook_on_blocks(
665
+ child,
666
+ execution_device=execution_device,
667
+ offload=offload,
668
+ weights_map=weights_map,
669
+ offload_buffers=offload_buffers,
670
+ module_name=child_name,
671
+ preload_module_classes=preload_module_classes,
672
+ skip_keys=skip_keys,
673
+ tied_params_map=tied_params_map,
674
+ )
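
For illustration, a sketch of dispatching named blocks of a model with `attach_align_device_hook_on_blocks`; the block names and devices are assumptions for the example (on a multi-GPU machine the devices would typically be integers such as 0 and 1).

```python
# Illustration: per-block dispatch. Block names and devices are example values.
import torch.nn as nn

from accelerate.hooks import attach_align_device_hook_on_blocks

model = nn.Sequential()
model.add_module("block1", nn.Linear(16, 16))
model.add_module("block2", nn.Linear(16, 4))

# One execution device and one offload flag per named block ("" is the root module).
execution_device = {"": "cpu", "block1": "cpu", "block2": "cpu"}
offload = {"": False, "block1": False, "block2": False}
attach_align_device_hook_on_blocks(model, execution_device=execution_device, offload=offload)
```
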
675
+
676
+
677
+ class CpuOffload(ModelHook):
678
+ """
679
+ Offloads a model to the CPU until its forward pass is called. The model will not be offloaded back to the CPU after
680
+ the forward; the user needs to call the `init_hook` method again for this.
681
+
682
+ Args:
683
+ execution_device(`str`, `int` or `torch.device`, *optional*):
684
+ The device on which the model should be executed. Will default to the MPS device if it's available, then
685
+ GPU 0 if there is a GPU, and finally to the CPU.
686
+ prev_module_hook (`UserCpuOffloadHook`, *optional*):
687
+ The hook sent back by [`cpu_offload_with_hook`] for a previous model in the pipeline you are running. If
688
+ passed, its offload method will be called just before the forward of the model to which this hook is
689
+ attached.
690
+ """
691
+
692
+ def __init__(
693
+ self,
694
+ execution_device: Optional[Union[str, int, torch.device]] = None,
695
+ prev_module_hook: Optional["UserCpuOffloadHook"] = None,
696
+ ):
697
+ self.prev_module_hook = prev_module_hook
698
+
699
+ self.execution_device = execution_device if execution_device is not None else PartialState().default_device
700
+
701
+ def init_hook(self, module):
702
+ return module.to("cpu")
703
+
704
+ def pre_forward(self, module, *args, **kwargs):
705
+ if self.prev_module_hook is not None:
706
+ self.prev_module_hook.offload()
707
+ clear_device_cache()
708
+ module.to(self.execution_device)
709
+ return send_to_device(args, self.execution_device), send_to_device(kwargs, self.execution_device)
710
+
711
+
712
+ class UserCpuOffloadHook:
713
+ """
714
+ A simple hook grouping a model and a `ModelHook`, which provides easy APIs to call the init method of the hook
715
+ or remove it entirely.
716
+ """
717
+
718
+ def __init__(self, model, hook):
719
+ self.model = model
720
+ self.hook = hook
721
+
722
+ def offload(self):
723
+ self.hook.init_hook(self.model)
724
+
725
+ def remove(self):
726
+ remove_hook_from_module(self.model)
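
To illustrate how `CpuOffload` and `UserCpuOffloadHook` can be chained for a simple two-stage pipeline (a CUDA device is assumed; this is a sketch of the sequential-offload pattern, not a definitive recipe):

```python
# Illustration: keep two stages on CPU and move each to the execution device
# only for its own forward pass.
import torch
import torch.nn as nn

from accelerate.hooks import CpuOffload, UserCpuOffloadHook, add_hook_to_module

stage_1, stage_2 = nn.Linear(8, 8), nn.Linear(8, 2)

hook_1 = CpuOffload(execution_device="cuda:0")
add_hook_to_module(stage_1, hook_1)            # init_hook places stage_1 on CPU
user_hook_1 = UserCpuOffloadHook(stage_1, hook_1)

# stage_2 offloads stage_1 back to CPU right before its own forward.
hook_2 = CpuOffload(execution_device="cuda:0", prev_module_hook=user_hook_1)
add_hook_to_module(stage_2, hook_2)

x = torch.randn(4, 8)
x = stage_1(x)   # stage_1 moved to cuda:0
x = stage_2(x)   # stage_1 moved back to CPU, stage_2 moved to cuda:0
```
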
.venv/Lib/site-packages/accelerate/inference.py ADDED
@@ -0,0 +1,184 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+ from types import MethodType
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ from .state import PartialState
19
+ from .utils import (
20
+ calculate_maximum_sizes,
21
+ convert_bytes,
22
+ copy_tensor_to_devices,
23
+ ignorant_find_batch_size,
24
+ infer_auto_device_map,
25
+ is_pippy_available,
26
+ pad_input_tensors,
27
+ send_to_device,
28
+ )
29
+
30
+
31
+ def generate_device_map(model, num_processes: int = 1, no_split_module_classes=None, max_memory: dict = None):
32
+ """
33
+ Calculates the device map for `model` with an offset for PiPPy
34
+ """
35
+ if num_processes == 1:
36
+ return infer_auto_device_map(model, no_split_module_classes=no_split_module_classes, clean_result=False)
37
+ if max_memory is None:
38
+ model_size, shared = calculate_maximum_sizes(model)
39
+
40
+ # Split into `n` chunks for each GPU
41
+ memory = (model_size + shared[0]) / num_processes
42
+ memory = convert_bytes(memory)
43
+ value, ending = memory.split(" ")
44
+
45
+ # Add a chunk to deal with potential extra shared memory instances
46
+ memory = math.ceil(float(value)) * 1.1
47
+ memory = f"{memory} {ending}"
48
+ max_memory = {i: memory for i in range(num_processes)}
49
+ device_map = infer_auto_device_map(
50
+ model,
51
+ max_memory=max_memory,
52
+ no_split_module_classes=no_split_module_classes,
53
+ clean_result=False,
54
+ )
55
+ return device_map
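
A rough worked example of the memory-splitting arithmetic above, with made-up sizes:

```python
# Made-up numbers to illustrate the per-process budget computed above.
import math

model_size, shared = 12 * 1024**3, (2 * 1024**3,)       # 12 GB of weights, 2 GB shared
num_processes = 2
per_process = (model_size + shared[0]) / num_processes   # 7 GB per process
padded = math.ceil(per_process / 1024**3) * 1.1           # 10% head-room, ~7.7
max_memory = {i: f"{padded} GB" for i in range(num_processes)}   # roughly {0: "7.7 GB", 1: "7.7 GB"}
```
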
56
+
57
+
58
+ def find_pippy_batch_size(args, kwargs):
59
+ found_batch_size = None
60
+ if args is not None:
61
+ for arg in args:
62
+ found_batch_size = ignorant_find_batch_size(arg)
63
+ if found_batch_size is not None:
64
+ break
65
+ if kwargs is not None and found_batch_size is None:
66
+ for kwarg in kwargs.values():
67
+ found_batch_size = ignorant_find_batch_size(kwarg)
68
+ if found_batch_size is not None:
69
+ break
70
+ return found_batch_size
71
+
72
+
73
+ def build_pipeline(model, split_points, args, kwargs, num_chunks):
74
+ """
75
+ Attaches the split points to the model based on `self.device_map` and generates a `PipelineStage`. Requires passing
76
+ in needed `args` and `kwargs` as the model needs on the CPU.
77
+
78
+ Users can pass in custom `num_chunks` as an optional hyper-parameter. By default will use
79
+ `AcceleratorState.num_processes`
80
+ """
81
+ # Note: We import here to reduce import time from general modules, and isolate outside dependencies
82
+ from torch.distributed.pipelining import ScheduleGPipe, SplitPoint, pipeline
83
+
84
+ # We need to annotate the split points in the model for PiPPy
85
+ state = PartialState()
86
+ split_spec = {split_point: SplitPoint.BEGINNING for split_point in split_points}
87
+ pipe = pipeline(
88
+ model,
89
+ mb_args=args,
90
+ mb_kwargs=kwargs,
91
+ split_spec=split_spec,
92
+ )
93
+ stage = pipe.build_stage(state.local_process_index, device=state.device)
94
+ schedule = ScheduleGPipe(stage, num_chunks)
95
+
96
+ return schedule
97
+
98
+
99
+ def pippy_forward(forward, num_chunks, gather_output, *args, **kwargs):
100
+ state = PartialState()
101
+ output = None
102
+
103
+ if state.num_processes == 1:
104
+ output = forward(*args, **kwargs)
105
+ elif state.is_local_main_process:
106
+ found_batch_size = find_pippy_batch_size(args, kwargs)
107
+ if found_batch_size is None:
108
+ raise ValueError("Could not find batch size from args or kwargs")
109
+ else:
110
+ if found_batch_size != num_chunks:
111
+ args = pad_input_tensors(args, found_batch_size, num_chunks)
112
+ kwargs = pad_input_tensors(kwargs, found_batch_size, num_chunks)
113
+ forward(*args, **kwargs)
114
+ elif state.is_last_process:
115
+ output = forward()
116
+ else:
117
+ forward()
118
+ if gather_output:
119
+ # Each node will get a copy of the full output which is only on the last GPU
120
+ output = copy_tensor_to_devices(output)
121
+ return output
122
+
123
+
124
+ def prepare_pippy(
125
+ model,
126
+ split_points: Optional[Union[str, List[str]]] = "auto",
127
+ no_split_module_classes: Optional[List[str]] = None,
128
+ example_args: Optional[Tuple[Any]] = (),
129
+ example_kwargs: Optional[Dict[str, Any]] = None,
130
+ num_chunks: Optional[int] = None,
131
+ gather_output: Optional[bool] = False,
132
+ ):
133
+ """
134
+ Wraps `model` for pipeline parallel inference.
135
+
136
+ Args:
137
+ model (`torch.nn.Module`):
138
+ A model we want to split for pipeline-parallel inference
139
+ split_points (`str` or `List[str]`, defaults to 'auto'):
140
+ How to generate the split points and chunk the model across each GPU. 'auto' will find the best balanced
141
+ split given any model. Should be a list of layer names in the model to split by otherwise.
142
+ no_split_module_classes (`List[str]`):
143
+ A list of class names for layers we don't want to be split.
144
+ example_args (tuple of model inputs):
145
+ The expected inputs for the model that uses order-based inputs for a *single process*. Recommended to use
146
+ this method if possible.
147
+ example_kwargs (dict of model inputs)
148
+ The expected inputs for the model that uses dictionary-based inputs for a *single process*. This is a
149
+ *highly* limiting structure that requires the same keys be present at *all* inference calls. Not
150
+ recommended unless the prior condition is true for all cases.
151
+ num_chunks (`int`, defaults to the number of available GPUs):
152
+ The number of different stages the Pipeline will have. By default it will assign one chunk per GPU, but
153
+ this can be tuned and played with. In general one should have num_chunks >= num_gpus.
154
+ gather_output (`bool`, defaults to `False`):
155
+ If `True`, the output from the last GPU (which holds the true outputs) is sent across to all GPUs.
156
+ """
157
+ if not is_pippy_available():
158
+ raise ImportError("Using `torch.distributed.pipelining` requires PyTorch 2.4.0 or later.")
159
+ state = PartialState()
160
+ example_args = send_to_device(example_args, "cpu")
161
+ example_kwargs = send_to_device(example_kwargs, "cpu")
162
+ if num_chunks is None:
163
+ num_chunks = state.num_processes
164
+ if split_points == "auto":
165
+ device_map = generate_device_map(model, num_chunks, no_split_module_classes=no_split_module_classes)
166
+ split_points = []
167
+ for i in range(1, num_chunks):
168
+ split_points.append(next(k for k, v in device_map.items() if v == i))
169
+ model.hf_split_points = split_points
170
+ stage = build_pipeline(model, split_points, example_args, example_kwargs, num_chunks)
171
+ model._original_forward = model.forward
172
+ model._original_call = model.__call__
173
+ model.pippy_stage = stage
174
+ model.hf_split_points = split_points
175
+
176
+ def forward(*args, **kwargs):
177
+ return pippy_forward(stage.step, num_chunks, gather_output, *args, **kwargs)
178
+
179
+ # To act like a decorator so that it can be popped when doing `extract_model_from_parallel`
180
+ # Note: creates an infinite recursion loop with `generate`
181
+ model_forward = MethodType(forward, model)
182
+ forward.__wrapped__ = model_forward
183
+ model.forward = forward
184
+ return model
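
An illustrative sketch of `prepare_pippy`; the tiny `nn.Sequential` is a stand-in for a real model, and the script is expected to be started with `accelerate launch` so that one process exists per GPU:

```python
# Illustration only: run via `accelerate launch` with one process per GPU.
import torch
import torch.nn as nn

from accelerate.inference import prepare_pippy

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8))
example_input = torch.randn(2, 64)

model = prepare_pippy(model, split_points="auto", example_args=(example_input,))

with torch.no_grad():
    output = model(example_input)
# Only the last rank holds the real output unless gather_output=True is passed.
```
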
.venv/Lib/site-packages/accelerate/launchers.py ADDED
@@ -0,0 +1,302 @@
1
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import sys
17
+ import tempfile
18
+
19
+ import torch
20
+
21
+ from .state import AcceleratorState, PartialState
22
+ from .utils import (
23
+ PrecisionType,
24
+ PrepareForLaunch,
25
+ are_libraries_initialized,
26
+ check_cuda_p2p_ib_support,
27
+ get_gpu_info,
28
+ is_mps_available,
29
+ is_torch_version,
30
+ patch_environment,
31
+ )
32
+ from .utils.constants import ELASTIC_LOG_LINE_PREFIX_TEMPLATE_PYTORCH_VERSION
33
+
34
+
35
+ def test_launch():
36
+ "Verify a `PartialState` can be initialized."
37
+ _ = PartialState()
38
+
39
+
40
+ def notebook_launcher(
41
+ function,
42
+ args=(),
43
+ num_processes=None,
44
+ mixed_precision="no",
45
+ use_port="29500",
46
+ master_addr="127.0.0.1",
47
+ node_rank=0,
48
+ num_nodes=1,
49
+ rdzv_backend="static",
50
+ rdzv_endpoint="",
51
+ rdzv_conf=None,
52
+ rdzv_id="none",
53
+ max_restarts=0,
54
+ monitor_interval=0.1,
55
+ log_line_prefix_template=None,
56
+ ):
57
+ """
58
+ Launches a training function, using several processes or multiple nodes if it's possible in the current environment
59
+ (TPU with multiple cores for instance).
60
+
61
+ <Tip warning={true}>
62
+
63
+ To use this function, absolutely zero calls to a CUDA device must be made in the notebook session before calling. If
64
+ any have been made, you will need to restart the notebook and make sure no cells use any CUDA capability.
65
+
66
+ Setting `ACCELERATE_DEBUG_MODE="1"` in your environment will run a test before truly launching to ensure that none
67
+ of those calls have been made.
68
+
69
+ </Tip>
70
+
71
+ Args:
72
+ function (`Callable`):
73
+ The training function to execute. If it accepts arguments, the first argument should be the index of the
74
+ process run.
75
+ args (`Tuple`):
76
+ Tuple of arguments to pass to the function (it will receive `*args`).
77
+ num_processes (`int`, *optional*):
78
+ The number of processes to use for training. Will default to 8 in Colab/Kaggle if a TPU is available, to
79
+ the number of GPUs available otherwise.
80
+ mixed_precision (`str`, *optional*, defaults to `"no"`):
81
+ If `fp16` or `bf16`, will use mixed precision training on multi-GPU.
82
+ use_port (`str`, *optional*, defaults to `"29500"`):
83
+ The port to use to communicate between processes when launching a multi-GPU training.
84
+ master_addr (`str`, *optional*, defaults to `"127.0.0.1"`):
85
+ The address to use for communication between processes.
86
+ node_rank (`int`, *optional*, defaults to 0):
87
+ The rank of the current node.
88
+ num_nodes (`int`, *optional*, defaults to 1):
89
+ The number of nodes to use for training.
90
+ rdzv_backend (`str`, *optional*, defaults to `"static"`):
91
+ The rendezvous method to use, such as 'static' (the default) or 'c10d'
92
+ rdzv_endpoint (`str`, *optional*, defaults to `""`):
93
+ The endpoint of the rdzv sync. storage.
94
+ rdzv_conf (`Dict`, *optional*, defaults to `None`):
95
+ Additional rendezvous configuration.
96
+ rdzv_id (`str`, *optional*, defaults to `"none"`):
97
+ The unique run id of the job.
98
+ max_restarts (`int`, *optional*, defaults to 0):
99
+ The maximum amount of restarts that elastic agent will conduct on workers before failure.
100
+ monitor_interval (`float`, *optional*, defaults to 0.1):
101
+ The interval in seconds that is used by the elastic_agent as a period of monitoring workers.
102
+ log_line_prefix_template (`str`, *optional*, defaults to `None`):
103
+ The prefix template for elastic launch logging. Available from PyTorch 2.2.0.
104
+
105
+ Example:
106
+
107
+ ```python
108
+ # Assume this is defined in a Jupyter Notebook on an instance with two GPUs
109
+ from accelerate import notebook_launcher
110
+
111
+
112
+ def train(*args):
113
+ # Your training function here
114
+ ...
115
+
116
+
117
+ notebook_launcher(train, args=(arg1, arg2), num_processes=2, mixed_precision="fp16")
118
+ ```
119
+ """
120
+ # Are we in a google colab or a Kaggle Kernel?
121
+ in_colab = False
122
+ in_kaggle = False
123
+ if any(key.startswith("KAGGLE") for key in os.environ.keys()):
124
+ in_kaggle = True
125
+ elif "IPython" in sys.modules:
126
+ in_colab = "google.colab" in str(sys.modules["IPython"].get_ipython())
127
+
128
+ try:
129
+ mixed_precision = PrecisionType(mixed_precision.lower())
130
+ except ValueError:
131
+ raise ValueError(
132
+ f"Unknown mixed_precision mode: {args.mixed_precision.lower()}. Choose between {PrecisionType.list()}."
133
+ )
134
+
135
+ if (in_colab or in_kaggle) and (os.environ.get("TPU_NAME", None) is not None):
136
+ # TPU launch
137
+ import torch_xla.distributed.xla_multiprocessing as xmp
138
+
139
+ if len(AcceleratorState._shared_state) > 0:
140
+ raise ValueError(
141
+ "To train on TPU in Colab or Kaggle Kernel, the `Accelerator` should only be initialized inside "
142
+ "your training function. Restart your notebook and make sure no cells initializes an "
143
+ "`Accelerator`."
144
+ )
145
+ if num_processes is None:
146
+ num_processes = 8
147
+
148
+ launcher = PrepareForLaunch(function, distributed_type="XLA")
149
+ print(f"Launching a training on {num_processes} TPU cores.")
150
+ xmp.spawn(launcher, args=args, nprocs=num_processes, start_method="fork")
151
+ elif in_colab and get_gpu_info()[1] < 2:
152
+ # No need for a distributed launch otherwise as it's either CPU or one GPU.
153
+ if torch.cuda.is_available():
154
+ print("Launching training on one GPU.")
155
+ else:
156
+ print("Launching training on one CPU.")
157
+ function(*args)
158
+ else:
159
+ if num_processes is None:
160
+ raise ValueError(
161
+ "You have to specify the number of GPUs you would like to use, add `num_processes=...` to your call."
162
+ )
163
+ if node_rank >= num_nodes:
164
+ raise ValueError("The node_rank must be less than the number of nodes.")
165
+ if num_processes > 1:
166
+ # Multi-GPU launch
167
+ from torch.distributed.launcher.api import LaunchConfig, elastic_launch
168
+ from torch.multiprocessing import start_processes
169
+ from torch.multiprocessing.spawn import ProcessRaisedException
170
+
171
+ if len(AcceleratorState._shared_state) > 0:
172
+ raise ValueError(
173
+ "To launch a multi-GPU training from your notebook, the `Accelerator` should only be initialized "
174
+ "inside your training function. Restart your notebook and make sure no cells initializes an "
175
+ "`Accelerator`."
176
+ )
177
+ # Check for specific libraries known to initialize CUDA that users constantly use
178
+ problematic_imports = are_libraries_initialized("bitsandbytes")
179
+ if len(problematic_imports) > 0:
180
+ err = (
181
+ "Could not start distributed process. Libraries known to initialize CUDA upon import have been "
182
+ "imported already. Please keep these imports inside your training function to try and help with this:"
183
+ )
184
+ for lib_name in problematic_imports:
185
+ err += f"\n\t* `{lib_name}`"
186
+ raise RuntimeError(err)
187
+
188
+ patched_env = dict(
189
+ nproc=num_processes,
190
+ node_rank=node_rank,
191
+ world_size=num_nodes * num_processes,
192
+ master_addr=master_addr,
193
+ master_port=use_port,
194
+ mixed_precision=mixed_precision,
195
+ )
196
+
197
+ # Check for CUDA P2P and IB issues
198
+ if not check_cuda_p2p_ib_support():
199
+ patched_env["nccl_p2p_disable"] = "1"
200
+ patched_env["nccl_ib_disable"] = "1"
201
+
202
+ # torch.distributed will expect a few environment variables to be here. We set the ones common to each
203
+ # process here (the other ones will be set by the launcher).
204
+ with patch_environment(**patched_env):
205
+ # First dummy launch
206
+ if os.environ.get("ACCELERATE_DEBUG_MODE", "false").lower() == "true":
207
+ launcher = PrepareForLaunch(test_launch, distributed_type="MULTI_GPU")
208
+ try:
209
+ start_processes(launcher, args=(), nprocs=num_processes, start_method="fork")
210
+ except ProcessRaisedException as e:
211
+ err = "An issue was found when verifying a stable environment for the notebook launcher."
212
+ if "Cannot re-initialize CUDA in forked subprocess" in e.args[0]:
213
+ raise RuntimeError(
214
+ f"{err}"
215
+ "This likely stems from an outside import causing issues once the `notebook_launcher()` is called. "
216
+ "Please review your imports and test them when running the `notebook_launcher()` to identify "
217
+ "which one is problematic and causing CUDA to be initialized."
218
+ ) from e
219
+ else:
220
+ raise RuntimeError(f"{err} The following error was raised: {e}") from e
221
+ # Now the actual launch
222
+ launcher = PrepareForLaunch(function, distributed_type="MULTI_GPU")
223
+ print(f"Launching training on {num_processes} GPUs.")
224
+ try:
225
+ if rdzv_conf is None:
226
+ rdzv_conf = {}
227
+ if rdzv_backend == "static":
228
+ rdzv_conf["rank"] = node_rank
229
+ if not rdzv_endpoint:
230
+ rdzv_endpoint = f"{master_addr}:{use_port}"
231
+ launch_config_kwargs = dict(
232
+ min_nodes=num_nodes,
233
+ max_nodes=num_nodes,
234
+ nproc_per_node=num_processes,
235
+ run_id=rdzv_id,
236
+ rdzv_endpoint=rdzv_endpoint,
237
+ rdzv_backend=rdzv_backend,
238
+ rdzv_configs=rdzv_conf,
239
+ max_restarts=max_restarts,
240
+ monitor_interval=monitor_interval,
241
+ start_method="fork",
242
+ )
243
+ if is_torch_version(">=", ELASTIC_LOG_LINE_PREFIX_TEMPLATE_PYTORCH_VERSION):
244
+ launch_config_kwargs["log_line_prefix_template"] = log_line_prefix_template
245
+ elastic_launch(config=LaunchConfig(**launch_config_kwargs), entrypoint=function)(*args)
246
+ except ProcessRaisedException as e:
247
+ if "Cannot re-initialize CUDA in forked subprocess" in e.args[0]:
248
+ raise RuntimeError(
249
+ "CUDA has been initialized before the `notebook_launcher` could create a forked subprocess. "
250
+ "This likely stems from an outside import causing issues once the `notebook_launcher()` is called. "
251
+ "Please review your imports and test them when running the `notebook_launcher()` to identify "
252
+ "which one is problematic and causing CUDA to be initialized."
253
+ ) from e
254
+ else:
255
+ raise RuntimeError(f"An issue was found when launching the training: {e}") from e
256
+
257
+ else:
258
+ # No need for a distributed launch otherwise as it's either CPU, GPU or MPS.
259
+ if is_mps_available():
260
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
261
+ print("Launching training on MPS.")
262
+ elif torch.cuda.is_available():
263
+ print("Launching training on one GPU.")
264
+ else:
265
+ print("Launching training on CPU.")
266
+ function(*args)
267
+
268
+
269
+ def debug_launcher(function, args=(), num_processes=2):
270
+ """
271
+ Launches a training function using several processes on CPU for debugging purposes.
272
+
273
+ <Tip warning={true}>
274
+
275
+ This function is provided for internal testing and debugging, but it's not intended for real trainings. It will
276
+ only use the CPU.
277
+
278
+ </Tip>
279
+
280
+ Args:
281
+ function (`Callable`):
282
+ The training function to execute.
283
+ args (`Tuple`):
284
+ Tuple of arguments to pass to the function (it will receive `*args`).
285
+ num_processes (`int`, *optional*, defaults to 2):
286
+ The number of processes to use for training.
287
+ """
288
+ from torch.multiprocessing import start_processes
289
+
290
+ with tempfile.NamedTemporaryFile() as tmp_file:
291
+ # torch.distributed will expect a few environment variables to be here. We set the ones common to each
292
+ # process here (the other ones will be set by the launcher).
293
+ with patch_environment(
294
+ world_size=num_processes,
295
+ master_addr="127.0.0.1",
296
+ master_port="29500",
297
+ accelerate_mixed_precision="no",
298
+ accelerate_debug_rdv_file=tmp_file.name,
299
+ accelerate_use_cpu="yes",
300
+ ):
301
+ launcher = PrepareForLaunch(function, debug=True)
302
+ start_processes(launcher, args=args, nprocs=num_processes, start_method="fork")
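
A short usage sketch for `debug_launcher` (requires a platform where the `fork` start method is available):

```python
# Illustration: run a training function on 2 CPU processes for debugging.
from accelerate import Accelerator, debug_launcher


def training_function():
    accelerator = Accelerator()
    accelerator.print(f"Running on {accelerator.num_processes} processes")


debug_launcher(training_function, num_processes=2)
```
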
.venv/Lib/site-packages/accelerate/local_sgd.py ADDED
@@ -0,0 +1,104 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+
16
+ from accelerate import Accelerator, DistributedType
17
+
18
+
19
+ class LocalSGD:
20
+ """
21
+ A helper class to support local SGD on top of Accelerator. It simply runs a given number of updates independently
22
+ on each device, and averages model weights every K synchronization step.
23
+
24
+ It should be used only in the multi-GPU (or multi-CPU) setup without extensions such as DeepSpeed. In particular,
25
+ this is a simple implementation that cannot support scenarios such as model parallelism.
26
+
27
+
28
+ Although we are not aware of the true origins of this simple approach, the idea of local SGD is quite old and goes
29
+ back to at least:
30
+
31
+ Zhang, J., De Sa, C., Mitliagkas, I., & Ré, C. (2016). [Parallel SGD: When does averaging help?. arXiv preprint
32
+ arXiv:1606.07365.](https://arxiv.org/abs/1606.07365)
33
+
34
+ We credit the term Local SGD to the following paper (but there might be earlier references we are not aware of).
35
+
36
+ Stich, Sebastian Urban. ["Local SGD Converges Fast and Communicates Little." ICLR 2019-International Conference on
37
+ Learning Representations. No. CONF. 2019.](https://arxiv.org/abs/1805.09767)
38
+
39
+ """
40
+
41
+ def __enter__(self):
42
+ if self.enabled:
43
+ self.model_sync_obj = self.model.no_sync()
44
+ self.model_sync_obj.__enter__()
45
+
46
+ return self
47
+
48
+ def __exit__(self, type, value, tb):
49
+ if self.enabled:
50
+ # Average all models on exit
51
+ self._sync_and_avg_model_params()
52
+ self.model_sync_obj.__exit__(type, value, tb)
53
+
54
+ def __init__(self, accelerator: Accelerator, model: torch.nn.Module, local_sgd_steps: int, enabled: bool = True):
55
+ """
56
+ Constructor.
57
+
58
+ Args:
59
+ model (`torch.nn.Module`):
60
+ The model whose parameters we need to average.
61
+ accelerator (`Accelerator`):
62
+ Accelerator object.
63
+ local_sgd_steps (`int`):
64
+ A number of local SGD steps (before model parameters are synchronized).
65
+ enabled (`bool`):
66
+ Local SGD is disabled if this parameter is set to `False`.
67
+ """
68
+ if accelerator.distributed_type not in [
69
+ DistributedType.NO,
70
+ DistributedType.MULTI_CPU,
71
+ DistributedType.MULTI_GPU,
72
+ DistributedType.MULTI_XPU,
73
+ DistributedType.MULTI_MLU,
74
+ DistributedType.MULTI_MUSA,
75
+ DistributedType.MULTI_NPU,
76
+ ]:
77
+ raise NotImplementedError("LocalSGD is supported only for CPUs and GPUs (no DeepSpeed or MegatronLM)")
78
+ self.enabled = enabled and accelerator.distributed_type != DistributedType.NO
79
+ self.num_steps = 0
80
+ if self.enabled:
81
+ self.accelerator = accelerator
82
+ self.model = model
83
+ self.local_sgd_steps = local_sgd_steps
84
+
85
+ def step(self):
86
+ """
87
+ This function makes a "step" and synchronizes model parameters if necessary.
88
+ """
89
+ self.num_steps += 1
90
+ if not self.enabled:
91
+ return
92
+
93
+ if self.num_steps % self.local_sgd_steps == 0:
94
+ self._sync_and_avg_model_params()
95
+
96
+ def _sync_and_avg_model_params(self):
97
+ """
98
+ Synchronize + Average model parameters across all GPUs
99
+ """
100
+
101
+ self.accelerator.wait_for_everyone()
102
+ with self.accelerator.autocast():
103
+ for param in self.model.parameters():
104
+ param.data = self.accelerator.reduce(param.data, reduction="mean")
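
An illustrative training-loop sketch for `LocalSGD`; `model`, `optimizer`, `training_dataloader` and `loss_function` are placeholders that would come from the surrounding training script:

```python
# Illustration: average model parameters every 8 local steps.
from accelerate import Accelerator
from accelerate.local_sgd import LocalSGD

accelerator = Accelerator()
model, optimizer, training_dataloader = accelerator.prepare(model, optimizer, training_dataloader)

with LocalSGD(accelerator=accelerator, model=model, local_sgd_steps=8, enabled=True) as local_sgd:
    for batch in training_dataloader:
        inputs, targets = batch
        outputs = model(inputs)
        loss = loss_function(outputs, targets)   # placeholder loss function
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
        local_sgd.step()   # parameters are averaged every 8 steps
```
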
.venv/Lib/site-packages/accelerate/logging.py ADDED
@@ -0,0 +1,125 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import functools
16
+ import logging
17
+ import os
18
+
19
+ from .state import PartialState
20
+
21
+
22
+ class MultiProcessAdapter(logging.LoggerAdapter):
23
+ """
24
+ An adapter to assist with logging in multiprocess.
25
+
26
+ `log` takes in an additional `main_process_only` kwarg, which dictates whether it should be called on all processes
27
+ or only the main executed one. Default is `main_process_only=True`.
28
+
29
+ Does not require an `Accelerator` object to be created first.
30
+ """
31
+
32
+ @staticmethod
33
+ def _should_log(main_process_only):
34
+ "Check if log should be performed"
35
+ state = PartialState()
36
+ return not main_process_only or (main_process_only and state.is_main_process)
37
+
38
+ def log(self, level, msg, *args, **kwargs):
39
+ """
40
+ Delegates logger call after checking if we should log.
41
+
42
+ Accepts a new kwarg of `main_process_only`, which will dictate whether it will be logged across all processes
43
+ or only the main executed one. Default is `True` if not passed
44
+
45
+ Also accepts "in_order", which if `True` makes the processes log one by one, in order. This is much easier to
46
+ read, but comes at the cost of sometimes needing to wait for the other processes. Default is `False` to not
47
+ break with the previous behavior.
48
+
49
+ `in_order` is ignored if `main_process_only` is passed.
50
+ """
51
+ if PartialState._shared_state == {}:
52
+ raise RuntimeError(
53
+ "You must initialize the accelerate state by calling either `PartialState()` or `Accelerator()` before using the logging utility."
54
+ )
55
+ main_process_only = kwargs.pop("main_process_only", True)
56
+ in_order = kwargs.pop("in_order", False)
57
+ # set `stacklevel` to exclude ourself in `Logger.findCaller()` while respecting user's choice
58
+ kwargs.setdefault("stacklevel", 2)
59
+
60
+ if self.isEnabledFor(level):
61
+ if self._should_log(main_process_only):
62
+ msg, kwargs = self.process(msg, kwargs)
63
+ self.logger.log(level, msg, *args, **kwargs)
64
+
65
+ elif in_order:
66
+ state = PartialState()
67
+ for i in range(state.num_processes):
68
+ if i == state.process_index:
69
+ msg, kwargs = self.process(msg, kwargs)
70
+ self.logger.log(level, msg, *args, **kwargs)
71
+ state.wait_for_everyone()
72
+
73
+ @functools.lru_cache(None)
74
+ def warning_once(self, *args, **kwargs):
75
+ """
76
+ This method is identical to `logger.warning()`, but will emit the warning with the same message only once
77
+
78
+ Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the
79
+ cache. The assumption here is that all warning messages are unique across the code. If they aren't then need to
80
+ switch to another type of cache that includes the caller frame information in the hashing function.
81
+ """
82
+ self.warning(*args, **kwargs)
83
+
84
+
85
+ def get_logger(name: str, log_level: str = None):
86
+ """
87
+ Returns a `logging.Logger` for `name` that can handle multiprocessing.
88
+
89
+ If a log should be called on all processes, pass `main_process_only=False` If a log should be called on all
90
+ processes and in order, also pass `in_order=True`
91
+
92
+ Args:
93
+ name (`str`):
94
+ The name for the logger, such as `__file__`
95
+ log_level (`str`, *optional*):
96
+ The log level to use. If not passed, will default to the `LOG_LEVEL` environment variable, or `INFO` if not
97
+
98
+ Example:
99
+
100
+ ```python
101
+ >>> from accelerate.logging import get_logger
102
+ >>> from accelerate import Accelerator
103
+
104
+ >>> logger = get_logger(__name__)
105
+
106
+ >>> accelerator = Accelerator()
107
+ >>> logger.info("My log", main_process_only=False)
108
+ >>> logger.debug("My log", main_process_only=True)
109
+
110
+ >>> logger = get_logger(__name__, log_level="DEBUG")
111
+ >>> logger.info("My log")
112
+ >>> logger.debug("My second log")
113
+
114
+ >>> array = ["a", "b", "c", "d"]
115
+ >>> letter_at_rank = array[accelerator.process_index]
116
+ >>> logger.info(letter_at_rank, in_order=True)
117
+ ```
118
+ """
119
+ if log_level is None:
120
+ log_level = os.environ.get("ACCELERATE_LOG_LEVEL", None)
121
+ logger = logging.getLogger(name)
122
+ if log_level is not None:
123
+ logger.setLevel(log_level.upper())
124
+ logger.root.setLevel(log_level.upper())
125
+ return MultiProcessAdapter(logger, {})
.venv/Lib/site-packages/accelerate/memory_utils.py ADDED
@@ -0,0 +1,22 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import warnings
16
+
17
+
18
+ warnings.warn(
19
+ "memory_utils has been reorganized to utils.memory. Import `find_executable_batchsize` from the main `__init__`: "
20
+ "`from accelerate import find_executable_batch_size` to avoid this warning.",
21
+ FutureWarning,
22
+ )
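`memory_utils` is now only a deprecation shim: importing it emits the `FutureWarning` above and nothing else, and the helper it refers to lives in the package root. A usage sketch of the recommended import (the decorator arguments shown here are illustrative):

```python
from accelerate import find_executable_batch_size


@find_executable_batch_size(starting_batch_size=128)
def training_loop(batch_size):
    ...  # re-run with a smaller batch size whenever an out-of-memory error is raised
```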
.venv/Lib/site-packages/accelerate/optimizer.py ADDED
@@ -0,0 +1,212 @@
1
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+
17
+ import torch
18
+
19
+ from .state import AcceleratorState, GradientState
20
+ from .utils import DistributedType, honor_type, is_lomo_available, is_torch_xla_available
21
+
22
+
23
+ if is_torch_xla_available():
24
+ import torch_xla.core.xla_model as xm
25
+
26
+
27
+ def move_to_device(state, device):
28
+ if isinstance(state, (list, tuple)):
29
+ return honor_type(state, (move_to_device(t, device) for t in state))
30
+ elif isinstance(state, dict):
31
+ return type(state)({k: move_to_device(v, device) for k, v in state.items()})
32
+ elif isinstance(state, torch.Tensor):
33
+ return state.to(device)
34
+ return state
35
+
36
+
37
+ class AcceleratedOptimizer(torch.optim.Optimizer):
38
+ """
39
+ Internal wrapper around a torch optimizer.
40
+
41
+ Conditionally will perform `step` and `zero_grad` if gradients should be synchronized when performing gradient
42
+ accumulation.
43
+
44
+ Args:
45
+ optimizer (`torch.optim.optimizer.Optimizer`):
46
+ The optimizer to wrap.
47
+ device_placement (`bool`, *optional*, defaults to `True`):
48
+ Whether or not the optimizer should handle device placement. If so, it will place the state dictionary of
49
+ `optimizer` on the right device.
50
+ scaler (`torch.cuda.amp.grad_scaler.GradScaler`, *optional*):
51
+ The scaler to use in the step function if training with mixed precision.
52
+ """
53
+
54
+ def __init__(self, optimizer, device_placement=True, scaler=None):
55
+ self.optimizer = optimizer
56
+ self.scaler = scaler
57
+ self.accelerator_state = AcceleratorState()
58
+ self.gradient_state = GradientState()
59
+ self.device_placement = device_placement
60
+ self._is_overflow = False
61
+
62
+ if self.scaler is not None:
63
+ self._accelerate_step_called = False
64
+ self._optimizer_original_step_method = self.optimizer.step
65
+ self._optimizer_patched_step_method = patch_optimizer_step(self, self.optimizer.step)
66
+
67
+ # Handle device placement
68
+ if device_placement:
69
+ state_dict = self.optimizer.state_dict()
70
+ if self.accelerator_state.distributed_type == DistributedType.XLA:
71
+ xm.send_cpu_data_to_device(state_dict, self.accelerator_state.device)
72
+ else:
73
+ state_dict = move_to_device(state_dict, self.accelerator_state.device)
74
+ self.optimizer.load_state_dict(state_dict)
75
+
76
+ @property
77
+ def state(self):
78
+ return self.optimizer.state
79
+
80
+ @state.setter
81
+ def state(self, state):
82
+ self.optimizer.state = state
83
+
84
+ @property
85
+ def param_groups(self):
86
+ return self.optimizer.param_groups
87
+
88
+ @param_groups.setter
89
+ def param_groups(self, param_groups):
90
+ self.optimizer.param_groups = param_groups
91
+
92
+ @property
93
+ def defaults(self):
94
+ return self.optimizer.defaults
95
+
96
+ @defaults.setter
97
+ def defaults(self, defaults):
98
+ self.optimizer.defaults = defaults
99
+
100
+ def add_param_group(self, param_group):
101
+ self.optimizer.add_param_group(param_group)
102
+
103
+ def load_state_dict(self, state_dict):
104
+ if self.accelerator_state.distributed_type == DistributedType.XLA and self.device_placement:
105
+ xm.send_cpu_data_to_device(state_dict, self.accelerator_state.device)
106
+ self.optimizer.load_state_dict(state_dict)
107
+
108
+ def state_dict(self):
109
+ return self.optimizer.state_dict()
110
+
111
+ def zero_grad(self, set_to_none=None):
112
+ if self.gradient_state.sync_gradients:
113
+ accept_arg = "set_to_none" in inspect.signature(self.optimizer.zero_grad).parameters
114
+ if accept_arg:
115
+ if set_to_none is None:
116
+ set_to_none = True
117
+ self.optimizer.zero_grad(set_to_none=set_to_none)
118
+ else:
119
+ if set_to_none is not None:
120
+ raise ValueError("`set_to_none` for `Optimizer.zero_grad` is not supported by this optimizer.")
121
+ self.optimizer.zero_grad()
122
+
123
+ def train(self):
124
+ """
125
+ Sets the optimizer to "train" mode. Useful for optimizers like `schedule_free`
126
+ """
127
+ if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
128
+ self.optimizer.train()
129
+ elif (
130
+ hasattr(self.optimizer, "optimizer")
131
+ and hasattr(self.optimizer.optimizer, "train")
132
+ and callable(self.optimizer.optimizer.train)
133
+ ):
134
+ # the deepspeed optimizer further wraps the optimizer
135
+ self.optimizer.optimizer.train()
136
+
137
+ def eval(self):
138
+ """
139
+ Sets the optimizer to "eval" mode. Useful for optimizers like `schedule_free`
140
+ """
141
+ if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval):
142
+ self.optimizer.eval()
143
+
144
+ def step(self, closure=None):
145
+ if is_lomo_available():
146
+ from lomo_optim import AdaLomo, Lomo
147
+
148
+ if (
149
+ not self.gradient_state.is_xla_gradients_synced
150
+ and self.accelerator_state.distributed_type == DistributedType.XLA
151
+ ):
152
+ gradients = xm._fetch_gradients(self.optimizer)
153
+ xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
154
+ self.gradient_state.is_xla_gradients_synced = True
155
+
156
+ if is_lomo_available():
157
+ # `step` should be a no-op for LOMO optimizers.
158
+ if isinstance(self.optimizer, (Lomo, AdaLomo)):
159
+ return
160
+
161
+ if self.gradient_state.sync_gradients:
162
+ if self.scaler is not None:
163
+ self.optimizer.step = self._optimizer_patched_step_method
164
+
165
+ self.scaler.step(self.optimizer, closure)
166
+ self.scaler.update()
167
+
168
+ if not self._accelerate_step_called:
169
+ # If the optimizer step was skipped, gradient overflow was detected.
170
+ self._is_overflow = True
171
+ else:
172
+ self._is_overflow = False
173
+ # Reset the step method to the original one
174
+ self.optimizer.step = self._optimizer_original_step_method
175
+ # Reset the indicator
176
+ self._accelerate_step_called = False
177
+ else:
178
+ self.optimizer.step(closure)
179
+ if self.accelerator_state.distributed_type == DistributedType.XLA:
180
+ self.gradient_state.is_xla_gradients_synced = False
181
+
182
+ def _switch_parameters(self, parameters_map):
183
+ for param_group in self.optimizer.param_groups:
184
+ param_group["params"] = [parameters_map.get(p, p) for p in param_group["params"]]
185
+
186
+ @property
187
+ def step_was_skipped(self):
188
+ """Whether or not the optimizer step was skipped."""
189
+ return self._is_overflow
190
+
191
+ def __getstate__(self):
192
+ _ignored_keys = [
193
+ "_accelerate_step_called",
194
+ "_optimizer_original_step_method",
195
+ "_optimizer_patched_step_method",
196
+ ]
197
+ return {k: v for k, v in self.__dict__.items() if k not in _ignored_keys}
198
+
199
+ def __setstate__(self, state):
200
+ self.__dict__.update(state)
201
+ if self.scaler is not None:
202
+ self._accelerate_step_called = False
203
+ self._optimizer_original_step_method = self.optimizer.step
204
+ self._optimizer_patched_step_method = patch_optimizer_step(self, self.optimizer.step)
205
+
206
+
207
+ def patch_optimizer_step(accelerated_optimizer: AcceleratedOptimizer, method):
208
+ def patched_step(*args, **kwargs):
209
+ accelerated_optimizer._accelerate_step_called = True
210
+ return method(*args, **kwargs)
211
+
212
+ return patched_step
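`AcceleratedOptimizer` is not meant to be constructed by hand: `Accelerator.prepare` wraps the user's optimizer and, under fp16, attaches the `GradScaler` whose patched `step` is shown above. A minimal sketch of the wrapper in use (assumes a CUDA device so `mixed_precision="fp16"` is valid; the model and data are placeholders):

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="fp16")
model = torch.nn.Linear(8, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model, optimizer = accelerator.prepare(model, optimizer)  # `optimizer` is now an AcceleratedOptimizer

inputs = torch.randn(4, 8, device=accelerator.device)
targets = torch.randn(4, 2, device=accelerator.device)
loss = torch.nn.functional.mse_loss(model(inputs), targets)
accelerator.backward(loss)

optimizer.step()                # routed through scaler.step() / scaler.update()
if optimizer.step_was_skipped:  # True when the scaler detected a gradient overflow
    print("optimizer step skipped due to overflow")
optimizer.zero_grad()           # forwards `set_to_none` when the wrapped optimizer supports it
```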
.venv/Lib/site-packages/accelerate/scheduler.py ADDED
@@ -0,0 +1,98 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # We ignore warnings about stepping the scheduler since we step it ourselves during gradient accumulation
16
+
17
+ import warnings
18
+
19
+ from .state import AcceleratorState, GradientState
20
+
21
+
22
+ warnings.filterwarnings("ignore", category=UserWarning, module="torch.optim.lr_scheduler")
23
+
24
+
25
+ class AcceleratedScheduler:
26
+ """
27
+ A wrapper around a learning rate scheduler that will only step when the optimizer(s) have a training step. Useful
28
+ to avoid making a scheduler step too fast when gradients overflowed and there was no training step (in mixed
29
+ precision training)
30
+
31
+ When performing gradient accumulation, scheduler lengths should not be changed accordingly; Accelerate will always
32
+ step the scheduler to account for it.
33
+
34
+ Args:
35
+ scheduler (`torch.optim.lr_scheduler._LRScheduler`):
36
+ The scheduler to wrap.
37
+ optimizers (one or a list of `torch.optim.Optimizer`):
38
+ The optimizers used.
39
+ step_with_optimizer (`bool`, *optional*, defaults to `True`):
40
+ Whether or not the scheduler should be stepped at each optimizer step.
41
+ split_batches (`bool`, *optional*, defaults to `False`):
42
+ Whether or not the dataloaders split one batch across the different processes (so batch size is the same
43
+ regardless of the number of processes) or create batches on each process (so batch size is the original
44
+ batch size multiplied by the number of processes).
45
+ """
46
+
47
+ def __init__(self, scheduler, optimizers, step_with_optimizer: bool = True, split_batches: bool = False):
48
+ self.scheduler = scheduler
49
+ self.optimizers = optimizers if isinstance(optimizers, (list, tuple)) else [optimizers]
50
+ self.split_batches = split_batches
51
+ self.step_with_optimizer = step_with_optimizer
52
+ self.gradient_state = GradientState()
53
+
54
+ def step(self, *args, **kwargs):
55
+ if not self.step_with_optimizer:
56
+ # No link between scheduler and optimizer -> just step
57
+ self.scheduler.step(*args, **kwargs)
58
+ return
59
+
60
+ # Otherwise, first make sure the optimizer was stepped.
61
+ if not self.gradient_state.sync_gradients:
62
+ if self.gradient_state.adjust_scheduler:
63
+ self.scheduler._step_count += 1
64
+ return
65
+
66
+ for opt in self.optimizers:
67
+ if opt.step_was_skipped:
68
+ return
69
+ if self.split_batches:
70
+ # Split batches -> the training dataloader batch size is not changed so one step per training step
71
+ self.scheduler.step(*args, **kwargs)
72
+ else:
73
+ # Otherwise the training dataloader batch size was multiplied by `num_processes`, so we need to do
74
+ # num_processes steps per training step
75
+ num_processes = AcceleratorState().num_processes
76
+ for _ in range(num_processes):
77
+ # Special case when using OneCycle and `drop_last` was not used
78
+ if hasattr(self.scheduler, "total_steps"):
79
+ if self.scheduler._step_count <= self.scheduler.total_steps:
80
+ self.scheduler.step(*args, **kwargs)
81
+ else:
82
+ self.scheduler.step(*args, **kwargs)
83
+
84
+ # Passthroughs
85
+ def get_last_lr(self):
86
+ return self.scheduler.get_last_lr()
87
+
88
+ def state_dict(self):
89
+ return self.scheduler.state_dict()
90
+
91
+ def load_state_dict(self, state_dict):
92
+ self.scheduler.load_state_dict(state_dict)
93
+
94
+ def get_lr(self):
95
+ return self.scheduler.get_lr()
96
+
97
+ def print_lr(self, *args, **kwargs):
98
+ return self.scheduler.print_lr(*args, **kwargs)
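Like the optimizer wrapper, `AcceleratedScheduler` is produced by `Accelerator.prepare`. In the default (non-`split_batches`) setup the effective batch size is multiplied by `num_processes`, so the wrapper steps the underlying scheduler `num_processes` times per real optimizer step, and it does not step at all while gradients are still being accumulated or when the optimizer step was skipped. A minimal sketch (names are illustrative):

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100)
model, optimizer, scheduler = accelerator.prepare(model, optimizer, scheduler)

for _ in range(4):
    loss = model(torch.randn(4, 8, device=accelerator.device)).sum()
    accelerator.backward(loss)
    optimizer.step()
    scheduler.step()   # skipped on overflow, otherwise stepped once per process
    optimizer.zero_grad()
```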
.venv/Lib/site-packages/accelerate/state.py ADDED
@@ -0,0 +1,1257 @@
1
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import logging
18
+ import os
19
+ import threading
20
+ import warnings
21
+ from contextlib import contextmanager
22
+ from functools import partial
23
+ from typing import Any, Callable, Optional
24
+
25
+ import torch
26
+
27
+ from .utils import (
28
+ DistributedType,
29
+ DynamoBackend,
30
+ GradientAccumulationPlugin,
31
+ check_cuda_p2p_ib_support,
32
+ check_fp8_capability,
33
+ deepspeed_required,
34
+ get_ccl_version,
35
+ get_cpu_distributed_information,
36
+ get_int_from_env,
37
+ is_ccl_available,
38
+ is_datasets_available,
39
+ is_deepspeed_available,
40
+ is_fp8_available,
41
+ is_ipex_available,
42
+ is_mlu_available,
43
+ is_mps_available,
44
+ is_musa_available,
45
+ is_npu_available,
46
+ is_torch_xla_available,
47
+ is_xpu_available,
48
+ parse_choice_from_env,
49
+ parse_flag_from_env,
50
+ set_numa_affinity,
51
+ )
52
+ from .utils.dataclasses import SageMakerDistributedType
53
+
54
+
55
+ if is_torch_xla_available():
56
+ import torch_xla.core.xla_model as xm
57
+
58
+ if is_mlu_available(check_device=False):
59
+ import torch_mlu # noqa: F401
60
+
61
+ if is_musa_available(check_device=False):
62
+ import torch_musa # noqa: F401
63
+
64
+ if is_npu_available(check_device=False):
65
+ import torch_npu # noqa: F401
66
+
67
+ logger = logging.getLogger(__name__)
68
+
69
+
70
+ def is_initialized() -> bool:
71
+ """
72
+ Checks if the `AcceleratorState` has been initialized from `Accelerator`. Same as `AcceleratorState.initialized`,
73
+ but works as a module method.
74
+ """
75
+ return AcceleratorState._shared_state != {}
76
+
77
+
78
+ # Lambda function that does nothing
79
+ def do_nothing(*args, **kwargs):
80
+ return None
81
+
82
+
83
+ class ThreadLocalSharedDict(threading.local):
84
+ """
85
+ Descriptor that holds a dict shared between instances of a class in the same thread.
86
+
87
+ Note: Descriptors have slightly different semantics than just a dict field on its own.
88
+ `PartialState(...)._shared_state` and `PartialState._shared_state` (instance vs class) give the same value: the
89
+ underlying _storage dict. Likewise, `PartialState(...)._shared_state = {...}` overrides the _storage dict inside
90
+ the descriptor as you would expect. However, `PartialState._shared_state = {}` actually replaces the descriptor
91
+ object with a plain dict instead. Thus, you should modify the _storage dict in-place (e.g. `_shared_state.clear()`).
92
+
93
+ See Python documentation for an explanation of descriptors: https://docs.python.org/3/howto/descriptor.html
94
+
95
+ This is required for using PyTorch/XLA with PJRT in multithreaded mode (required for TPU v2 and v3).
96
+
97
+ See https://github.com/pytorch/xla/blob/r2.0/docs/pjrt.md#multithreading-on-tpu-v2v3
98
+ """
99
+
100
+ def __init__(self, thread_local: bool = False):
101
+ self._storage = {}
102
+
103
+ def __get__(self, obj, objtype=None):
104
+ return self._storage
105
+
106
+ def __set__(self, obj, value):
107
+ self._storage = value
108
+
109
+
110
+ # Prefer global shared dictionary, except when using TPU.
111
+ SharedDict = dict if not is_torch_xla_available() else ThreadLocalSharedDict
112
+
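A short illustration of the caveat described in the docstring above (purely illustrative; on non-XLA setups `SharedDict` is a plain `dict`, but modifying the storage in place is the safe habit either way):

```python
from accelerate import PartialState

state = PartialState()
# Instance and class access resolve to the same underlying storage.
assert state._shared_state is PartialState._shared_state

# Safe way to wipe the shared storage (this is what `_reset_state()` does):
PartialState._shared_state.clear()

# PartialState._shared_state = {}   # on XLA this would replace the descriptor itself -- avoid
```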
113
+
114
+ # Inspired by Alex Martelli's 'Borg'.
115
+ class PartialState:
116
+ """
117
+ Singleton class that has information about the current training environment and functions to help with process
118
+ control. Designed to be used when only process control and device execution states are needed. Does *not* need to
119
+ be initialized from `Accelerator`.
120
+
121
+ Args:
122
+ cpu (`bool`, *optional*):
123
+ Whether or not to force the script to execute on CPU. Will ignore any accelerators available if set to
124
+ `True` and force the execution on the CPU.
125
+ kwargs (additional keyword arguments, *optional*):
126
+ Additional keyword arguments to pass to the relevant `init_process_group` function. Valid `kwargs` can be
127
+ found in [`utils.InitProcessGroupKwargs`]. See the example section for detailed usage.
128
+
129
+ **Available attributes:**
130
+
131
+ - **device** (`torch.device`) -- The device to use.
132
+ - **distributed_type** ([`~accelerate.state.DistributedType`]) -- The type of distributed environment currently
133
+ in use.
134
+ - **local_process_index** (`int`) -- The index of the current process on the current server.
135
+ - **mixed_precision** (`str`) -- Whether or not the current script will use mixed precision, and if so the type
136
+ of mixed precision being performed. (Choose from 'no', 'fp16', 'bf16' or 'fp8').
137
+ - **num_processes** (`int`) -- The number of processes currently launched in parallel.
138
+ - **process_index** (`int`) -- The index of the current process.
139
+ - **is_last_process** (`bool`) -- Whether or not the current process is the last one.
140
+ - **is_main_process** (`bool`) -- Whether or not the current process is the main one.
141
+ - **is_local_main_process** (`bool`) -- Whether or not the current process is the main one on the local node.
142
+ - **debug** (`bool`) -- Whether or not the current script is being run in debug mode.
143
+
144
+ Example:
145
+ ```python
146
+ from accelerate.utils import InitProcessGroupKwargs
147
+
148
+ # To include `InitProcessGroupKwargs`, init then call `.to_kwargs()`
149
+ kwargs = InitProcessGroupKwargs(...).to_kwargs()
150
+ state = PartialState(**kwargs)
151
+ ```
152
+ """
153
+
154
+ _shared_state = SharedDict()
155
+ _known_attrs = [
156
+ "_cpu",
157
+ "_mixed_precision",
158
+ "_shared_state",
159
+ "backend",
160
+ "debug",
161
+ "device",
162
+ "distributed_type",
163
+ "fork_launched",
164
+ "local_process_index",
165
+ "num_processes",
166
+ "process_index",
167
+ ]
168
+
169
+ def __init__(self, cpu: bool = False, **kwargs):
170
+ self.__dict__ = self._shared_state
171
+ if not self.initialized:
172
+ self._cpu = cpu
173
+ self.backend = None
174
+ env_device = os.environ.get("ACCELERATE_TORCH_DEVICE", None)
175
+ self.device = torch.device(env_device) if env_device is not None else None
176
+ self.debug = parse_flag_from_env("ACCELERATE_DEBUG_MODE")
177
+ use_sagemaker_dp = kwargs.pop("_use_sagemaker_dp", None)
178
+ dist_information = None
179
+ if use_sagemaker_dp is None:
180
+ use_sagemaker_dp = (
181
+ os.environ.get("ACCELERATE_USE_SAGEMAKER", "false") == "true"
182
+ and os.environ.get("ACCELERATE_SAGEMAKER_DISTRIBUTED_TYPE") != SageMakerDistributedType.NO
183
+ )
184
+
185
+ # Sets up self.backend + imports
186
+ original_backend = kwargs.pop("backend", None)
187
+ backend, distributed_type = self._prepare_backend(cpu, use_sagemaker_dp, original_backend)
188
+ if original_backend is not None and backend != original_backend:
189
+ raise ValueError(f"Your assigned backend {original_backend} is not available, please use {backend}")
190
+ self.backend = backend
191
+ self.distributed_type = distributed_type
192
+ use_deepspeed = False
193
+ if not cpu and self.backend != "xla":
194
+ if int(os.environ.get("LOCAL_RANK", -1)) != -1:
195
+ # Deal with spawning deepspeed
196
+ if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true":
197
+ if not is_deepspeed_available():
198
+ raise ImportError(
199
+ "DeepSpeed is not available => install it using `pip3 install deepspeed` or build it from source"
200
+ )
201
+ from deepspeed import comm as dist
202
+
203
+ if not dist.is_initialized():
204
+ dist.init_distributed(dist_backend=self.backend, auto_mpi_discovery=False, **kwargs)
205
+ # We need to flag to `use_deepspeed` to be True to override `distributed_type` later
206
+ use_deepspeed = True
207
+ # Deal with all other backends but XPU and CPU, that gets handled special later
208
+ elif (
209
+ self.distributed_type not in (DistributedType.MULTI_XPU, DistributedType.MULTI_CPU)
210
+ and not torch.distributed.is_initialized()
211
+ ):
212
+ torch.distributed.init_process_group(backend=self.backend, **kwargs)
213
+ # XPU and CPU require special env configs to be set
214
+ if self.distributed_type in (DistributedType.MULTI_XPU, DistributedType.MULTI_CPU):
215
+ dist_information = get_cpu_distributed_information()
216
+ os.environ["RANK"] = str(dist_information.rank)
217
+ os.environ["WORLD_SIZE"] = str(dist_information.world_size)
218
+ os.environ["LOCAL_RANK"] = str(dist_information.local_rank)
219
+ os.environ["LOCAL_WORLD_SIZE"] = str(dist_information.local_world_size)
220
+ if not os.environ.get("MASTER_PORT", None):
221
+ os.environ["MASTER_PORT"] = "29500"
222
+ if (
223
+ not os.environ.get("MASTER_ADDR", None)
224
+ and dist_information.local_world_size != dist_information.world_size
225
+ and self.backend != "mpi"
226
+ ):
227
+ raise ValueError(
228
+ "Tried to launch on distributed with multinode, but `MASTER_ADDR` env was not set, "
229
+ "please try exporting rank 0's hostname as `MASTER_ADDR`"
230
+ )
231
+ kwargs["rank"] = dist_information.rank
232
+ kwargs["world_size"] = dist_information.world_size
233
+
234
+ if (
235
+ self.distributed_type == DistributedType.MULTI_CPU
236
+ and get_int_from_env(["OMP_NUM_THREADS"], 0) == 0
237
+ ):
238
+ import psutil
239
+
240
+ num_cpu_threads_per_process = int(
241
+ psutil.cpu_count(logical=False) / dist_information.local_world_size
242
+ )
243
+ if num_cpu_threads_per_process == 0:
244
+ num_cpu_threads_per_process = 1
245
+ torch.set_num_threads(num_cpu_threads_per_process)
246
+ warnings.warn(
247
+ f"OMP_NUM_THREADS/MKL_NUM_THREADS unset, we set it at {num_cpu_threads_per_process} to improve out-of-box"
248
+ " performance."
249
+ )
250
+
251
+ if not torch.distributed.is_initialized():
252
+ torch.distributed.init_process_group(backend=self.backend, **kwargs)
253
+
254
+ # No backend == no distributed training
255
+ if self.backend is None:
256
+ self.distributed_type = DistributedType.NO
257
+ self.num_processes = 1
258
+ self.process_index = 0
259
+ self.local_process_index = 0
260
+ elif self.backend == "xla":
261
+ # XLA needs device setting first for `set_replication`
262
+ self.set_device()
263
+ xm.set_replication(self.device, xm.get_xla_supported_devices())
264
+ self.num_processes = xm.xrt_world_size()
265
+ self.process_index = xm.get_ordinal()
266
+ if is_torch_xla_available(check_is_tpu=True):
267
+ self.local_process_index = xm.get_local_ordinal()
268
+ else:
269
+ self.local_process_index = int(os.environ.get("LOCAL_RANK", -1))
270
+ else:
271
+ self.num_processes = torch.distributed.get_world_size()
272
+ self.process_index = torch.distributed.get_rank()
273
+ self.local_process_index = (
274
+ int(os.environ.get("LOCAL_RANK", -1)) if dist_information is None else dist_information.local_rank
275
+ )
276
+ self.set_device()
277
+ # Now we can change to deepspeed
278
+ if use_deepspeed:
279
+ self.distributed_type = DistributedType.DEEPSPEED
280
+
281
+ # Set CPU affinity if enabled
282
+ if parse_flag_from_env("ACCELERATE_CPU_AFFINITY", False):
283
+ set_numa_affinity(self.local_process_index)
284
+
285
+ # Check for old RTX 4000's that can't use P2P or IB and are on old drivers
286
+ if self.device.type == "cuda" and not check_cuda_p2p_ib_support():
287
+ if "NCCL_P2P_DISABLE" not in os.environ or "NCCL_IB_DISABLE" not in os.environ:
288
+ raise NotImplementedError(
289
+ "The RTX 4000 series doesn't support faster communication via P2P or IB. "
290
+ 'Please set `NCCL_P2P_DISABLE="1"` and `NCCL_IB_DISABLE="1"` or use `accelerate launch` which '
291
+ "will do this automatically."
292
+ )
293
+ # Important: This should be the *only* code outside of `self.initialized`!
294
+ self.fork_launched = parse_flag_from_env("FORK_LAUNCHED", 0)
295
+
296
+ def __repr__(self) -> str:
297
+ return (
298
+ f"Distributed environment: {self.distributed_type}{(' Backend: ' + self.backend) if self.backend else ''}\n"
299
+ f"Num processes: {self.num_processes}\n"
300
+ f"Process index: {self.process_index}\n"
301
+ f"Local process index: {self.local_process_index}\n"
302
+ f"Device: {self.device}\n"
303
+ )
304
+
305
+ @staticmethod
306
+ def _reset_state():
307
+ "Resets `_shared_state`, is used internally and should not be called"
308
+ PartialState._shared_state.clear()
309
+
310
+ @property
311
+ def initialized(self) -> bool:
312
+ "Returns whether the `PartialState` has been initialized"
313
+ return self._shared_state != {}
314
+
315
+ @property
316
+ def use_distributed(self):
317
+ """
318
+ Whether the Accelerator is configured for distributed training
319
+ """
320
+ return self.distributed_type != DistributedType.NO and self.num_processes > 1
321
+
322
+ @property
323
+ def is_last_process(self) -> bool:
324
+ "Returns whether the current process is the last one"
325
+ return self.process_index == self.num_processes - 1
326
+
327
+ @property
328
+ def is_main_process(self) -> bool:
329
+ "Returns whether the current process is the main process"
330
+ return (
331
+ self.process_index == 0 if self.distributed_type != DistributedType.MEGATRON_LM else self.is_last_process
332
+ )
333
+
334
+ @property
335
+ def is_local_main_process(self) -> bool:
336
+ "Returns whether the current process is the main process on the local node"
337
+ return (
338
+ self.local_process_index == 0
339
+ if self.distributed_type != DistributedType.MEGATRON_LM
340
+ else self.is_last_process
341
+ )
342
+
343
+ def wait_for_everyone(self):
344
+ """
345
+ Will stop the execution of the current process until every other process has reached that point (so this does
346
+ nothing when the script is only run in one process). Useful to do before saving a model.
347
+
348
+ Example:
349
+
350
+ ```python
351
+ >>> # Assuming two GPU processes
352
+ >>> import time
353
+ >>> from accelerate.state import PartialState
354
+
355
+ >>> state = PartialState()
356
+ >>> if state.is_main_process:
357
+ ... time.sleep(2)
358
+ ... else:
359
+ ... print("I'm waiting for the main process to finish its sleep...")
360
+ >>> state.wait_for_everyone()
361
+ >>> # Should print on every process at the same time
362
+ >>> print("Everyone is here")
363
+ ```
364
+ """
365
+ if self.distributed_type in (
366
+ DistributedType.MULTI_GPU,
367
+ DistributedType.MULTI_MLU,
368
+ DistributedType.MULTI_MUSA,
369
+ DistributedType.MULTI_NPU,
370
+ DistributedType.MULTI_XPU,
371
+ DistributedType.MULTI_CPU,
372
+ DistributedType.DEEPSPEED,
373
+ DistributedType.FSDP,
374
+ ):
375
+ torch.distributed.barrier()
376
+ elif self.distributed_type == DistributedType.XLA:
377
+ xm.rendezvous("accelerate.utils.wait_for_everyone")
378
+
379
+ def _goes_first(self, is_main: bool):
380
+ if not is_main:
381
+ self.wait_for_everyone()
382
+
383
+ yield
384
+
385
+ if is_main:
386
+ self.wait_for_everyone()
387
+
388
+ @contextmanager
389
+ def split_between_processes(self, inputs: list | tuple | dict | torch.Tensor, apply_padding: bool = False):
390
+ """
391
+ Splits `input` between `self.num_processes` quickly and can be then used on that process. Useful when doing
392
+ distributed inference, such as with different prompts.
393
+
394
+ Note that when using a `dict`, all keys need to have the same number of elements.
395
+
396
+ Args:
397
+ inputs (`list`, `tuple`, `torch.Tensor`, `dict` of `list`/`tuple`/`torch.Tensor`, or `datasets.Dataset`):
398
+ The input to split between processes.
399
+ apply_padding (`bool`, `optional`, defaults to `False`):
400
+ Whether to apply padding by repeating the last element of the input so that all processes have the same
401
+ number of elements. Useful when trying to perform actions such as `gather()` on the outputs or passing
402
+ in fewer inputs than there are processes. If so, just remember to drop the padded elements afterwards.
403
+
404
+
405
+ Example:
406
+
407
+ ```python
408
+ # Assume there are two processes
409
+ from accelerate import PartialState
410
+
411
+ state = PartialState()
412
+ with state.split_between_processes(["A", "B", "C"]) as inputs:
413
+ print(inputs)
414
+ # Process 0
415
+ ["A", "B"]
416
+ # Process 1
417
+ ["C"]
418
+
419
+ with state.split_between_processes(["A", "B", "C"], apply_padding=True) as inputs:
420
+ print(inputs)
421
+ # Process 0
422
+ ["A", "B"]
423
+ # Process 1
424
+ ["C", "C"]
425
+ ```
426
+ """
427
+ if self.num_processes == 1:
428
+ yield inputs
429
+ return
430
+ length = len(inputs)
431
+ # Nested dictionary of any types
432
+ if isinstance(inputs, dict):
433
+ length = len(inputs[list(inputs.keys())[0]])
434
+ if not all(len(v) == length for v in inputs.values()):
435
+ raise ValueError("All values in the dictionary must have the same length")
436
+ num_samples_per_process, num_extras = divmod(length, self.num_processes)
437
+ start_index = self.process_index * num_samples_per_process + min(self.process_index, num_extras)
438
+ end_index = start_index + num_samples_per_process + (1 if self.process_index < num_extras else 0)
439
+
440
+ def _split_values(inputs, start_index, end_index):
441
+ if isinstance(inputs, (list, tuple, torch.Tensor)):
442
+ if start_index >= len(inputs):
443
+ result = inputs[-1:]
444
+ else:
445
+ result = inputs[start_index:end_index]
446
+ if apply_padding:
447
+ if isinstance(result, torch.Tensor):
448
+ from accelerate.utils import pad_across_processes, send_to_device
449
+
450
+ # The tensor needs to be on the device before we can pad it
451
+ tensorized_result = send_to_device(result, self.device)
452
+ result = pad_across_processes(tensorized_result, pad_index=inputs[-1])
453
+ else:
454
+ result += [result[-1]] * (num_samples_per_process + 1 - len(result))
455
+ return result
456
+ elif isinstance(inputs, dict):
457
+ for key in inputs.keys():
458
+ inputs[key] = _split_values(inputs[key], start_index, end_index)
459
+ return inputs
460
+ else:
461
+ if is_datasets_available():
462
+ from datasets import Dataset
463
+
464
+ if isinstance(inputs, Dataset):
465
+ if start_index >= len(inputs):
466
+ start_index = len(inputs) - 1
467
+ if end_index > len(inputs):
468
+ end_index = len(inputs)
469
+ result_idcs = list(range(start_index, end_index))
470
+ if apply_padding:
471
+ result_idcs += [end_index - 1] * (num_samples_per_process + 1 - len(result_idcs))
472
+ return inputs.select(result_idcs)
473
+ return inputs
474
+
475
+ yield _split_values(inputs, start_index, end_index)
476
+
477
+ @contextmanager
478
+ def main_process_first(self):
479
+ """
480
+ Lets the main process go first inside a with block.
481
+
482
+ The other processes will enter the with block after the main process exits.
483
+
484
+ Example:
485
+
486
+ ```python
487
+ >>> from accelerate import Accelerator
488
+
489
+ >>> accelerator = Accelerator()
490
+ >>> with accelerator.main_process_first():
491
+ ... # This will be printed first by process 0 then in a seemingly
492
+ ... # random order by the other processes.
493
+ ... print(f"This will be printed by process {accelerator.process_index}")
494
+ ```
495
+ """
496
+ yield from self._goes_first(self.is_main_process)
497
+
498
+ @contextmanager
499
+ def local_main_process_first(self):
500
+ """
501
+ Lets the local main process go first inside a with block.
502
+
503
+ The other processes will enter the with block after the main process exits.
504
+
505
+ Example:
506
+
507
+ ```python
508
+ >>> from accelerate.state import PartialState
509
+
510
+ >>> state = PartialState()
511
+ >>> with state.local_main_process_first():
512
+ ... # This will be printed first by local process 0 then in a seemingly
513
+ ... # random order by the other processes.
514
+ ... print(f"This will be printed by process {state.local_process_index}")
515
+ ```
516
+ """
517
+ yield from self._goes_first(self.is_local_main_process)
518
+
519
+ def on_main_process(self, function: Callable[..., Any] = None):
520
+ """
521
+ Decorator that only runs the decorated function on the main process.
522
+
523
+ Args:
524
+ function (`Callable`): The function to decorate.
525
+
526
+ Example:
527
+
528
+ ```python
529
+ >>> from accelerate.state import PartialState
530
+
531
+ >>> state = PartialState()
532
+
533
+
534
+ >>> @state.on_main_process
535
+ ... def print_something():
536
+ ... print("This will be printed by process 0 only.")
537
+
538
+
539
+ >>> print_something()
540
+ "This will be printed by process 0 only"
541
+ ```
542
+ """
543
+ if not self.initialized:
544
+ raise ValueError("The `PartialState` or `Accelerator` must be initialized before calling this function.")
545
+ if self.is_main_process or not self.use_distributed:
546
+ return function
547
+ return do_nothing
548
+
549
+ def on_local_main_process(self, function: Callable[..., Any] = None):
550
+ """
551
+ Decorator that only runs the decorated function on the local main process.
552
+
553
+ Args:
554
+ function (`Callable`): The function to decorate.
555
+
556
+ Example:
557
+ ```python
558
+ # Assume we have 2 servers with 4 processes each.
559
+ from accelerate.state import PartialState
560
+
561
+ state = PartialState()
562
+
563
+
564
+ @state.on_local_main_process
565
+ def print_something():
566
+ print("This will be printed by process 0 only on each server.")
567
+
568
+
569
+ print_something()
570
+ # On server 1:
571
+ "This will be printed by process 0 only"
572
+ # On server 2:
573
+ "This will be printed by process 0 only"
574
+ ```
575
+ """
576
+ if self.is_local_main_process or not self.use_distributed:
577
+ return function
578
+ return do_nothing
579
+
580
+ def on_last_process(self, function: Callable[..., Any]):
581
+ """
582
+ Decorator that only runs the decorated function on the last process.
583
+
584
+ Args:
585
+ function (`Callable`): The function to decorate.
586
+
587
+ Example:
588
+ ```python
589
+ # Assume we have 4 processes.
590
+ from accelerate.state import PartialState
591
+
592
+ state = PartialState()
593
+
594
+
595
+ @state.on_last_process
596
+ def print_something():
597
+ print(f"Printed on process {state.process_index}")
598
+
599
+
600
+ print_something()
601
+ "Printed on process 3"
602
+ ```
603
+ """
604
+ if self.is_last_process or not self.use_distributed:
605
+ return function
606
+ return do_nothing
607
+
608
+ def on_process(self, function: Callable[..., Any] = None, process_index: int = None):
609
+ """
610
+ Decorator that only runs the decorated function on the process with the given index.
611
+
612
+ Args:
613
+ function (`Callable`, `optional`):
614
+ The function to decorate.
615
+ process_index (`int`, `optional`):
616
+ The index of the process on which to run the function.
617
+
618
+ Example:
619
+ ```python
620
+ # Assume we have 4 processes.
621
+ from accelerate.state import PartialState
622
+
623
+ state = PartialState()
624
+
625
+
626
+ @state.on_process(process_index=2)
627
+ def print_something():
628
+ print(f"Printed on process {state.process_index}")
629
+
630
+
631
+ print_something()
632
+ "Printed on process 2"
633
+ ```
634
+ """
635
+ if function is None:
636
+ return partial(self.on_process, process_index=process_index)
637
+ if (self.process_index == process_index) or (not self.use_distributed):
638
+ return function
639
+ return do_nothing
640
+
641
+ def on_local_process(self, function: Callable[..., Any] = None, local_process_index: int = None):
642
+ """
643
+ Decorator that only runs the decorated function on the process with the given index on the current node.
644
+
645
+ Args:
646
+ function (`Callable`, *optional*):
647
+ The function to decorate.
648
+ local_process_index (`int`, *optional*):
649
+ The index of the local process on which to run the function.
650
+
651
+ Example:
652
+ ```python
653
+ # Assume we have 2 servers with 4 processes each.
654
+ from accelerate import Accelerator
655
+
656
+ accelerator = Accelerator()
657
+
658
+
659
+ @accelerator.on_local_process(local_process_index=2)
660
+ def print_something():
661
+ print(f"Printed on process {accelerator.local_process_index}")
662
+
663
+
664
+ print_something()
665
+ # On server 1:
666
+ "Printed on process 2"
667
+ # On server 2:
668
+ "Printed on process 2"
669
+ ```
670
+ """
671
+ if function is None:
672
+ return partial(self.on_local_process, local_process_index=local_process_index)
673
+ if (self.local_process_index == local_process_index) or (not self.use_distributed):
674
+ return function
675
+ return do_nothing
676
+
677
+ def print(self, *args, **kwargs):
678
+ if self.is_local_main_process:
679
+ print(*args, **kwargs)
680
+
681
+ @property
682
+ def default_device(self) -> torch.device:
683
+ """
684
+ Returns the default device which is:
685
+ - MPS if `torch.backends.mps.is_available()` and `torch.backends.mps.is_built()` both return True.
686
+ - CUDA if `torch.cuda.is_available()`
687
+ - MLU if `is_mlu_available()`
688
+ - MUSA if `is_musa_available()`
689
+ - NPU if `is_npu_available()`
690
+ - CPU otherwise
691
+ """
692
+ if is_mps_available():
693
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
694
+ return torch.device("mps")
695
+ elif is_mlu_available():
696
+ return torch.device("mlu")
697
+ elif is_musa_available():
698
+ return torch.device("musa")
699
+ # NPU should be checked before CUDA when using `transfer_to_npu`
700
+ # See issue #3020: https://github.com/huggingface/accelerate/issues/3020
701
+ elif is_npu_available():
702
+ return torch.device("npu")
703
+ elif torch.cuda.is_available():
704
+ return torch.device("cuda")
705
+ elif is_xpu_available():
706
+ return torch.device("xpu")
707
+ else:
708
+ return torch.device("cpu")
709
+
710
+ def _prepare_backend(
711
+ self, cpu: bool = False, sagemaker_dp=False, backend: str = None
712
+ ) -> tuple[str, DistributedType]:
713
+ "Prepares any imports needed before initializing the distributed backend and sets `self.backend` properly"
714
+ distributed_type = None
715
+ if sagemaker_dp:
716
+ import smdistributed.dataparallel.torch.torch_smddp # noqa
717
+
718
+ backend = "smddp"
719
+ distributed_type = DistributedType.MULTI_GPU
720
+ elif is_torch_xla_available():
721
+ backend = "xla"
722
+ distributed_type = DistributedType.XLA
723
+ elif int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu:
724
+ if is_mlu_available():
725
+ backend = "cncl"
726
+ distributed_type = DistributedType.MULTI_MLU
727
+ elif is_musa_available():
728
+ backend = "mccl"
729
+ distributed_type = DistributedType.MULTI_MUSA
730
+ # NPU should be checked before CUDA when using `transfer_to_npu`
731
+ # See issue #3020: https://github.com/huggingface/accelerate/issues/3020
732
+ elif is_npu_available():
733
+ backend = "hccl"
734
+ distributed_type = DistributedType.MULTI_NPU
735
+ elif torch.cuda.is_available():
736
+ if backend is None:
737
+ backend = "nccl"
738
+ distributed_type = DistributedType.MULTI_GPU
739
+
740
+ if distributed_type is None and (
741
+ int(os.environ.get("LOCAL_RANK", -1)) != -1
742
+ or get_int_from_env(["PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE", "WORLD_SIZE"], 1) > 1
743
+ ):
744
+ if not cpu and is_xpu_available():
745
+ distributed_type = DistributedType.MULTI_XPU
746
+ else:
747
+ distributed_type = DistributedType.MULTI_CPU
748
+
749
+ if (
750
+ backend in (None, "ccl")
751
+ and is_ccl_available()
752
+ and (get_int_from_env(["CCL_WORKER_COUNT"], 0) > 0 or distributed_type == DistributedType.MULTI_XPU)
753
+ ):
754
+ if get_ccl_version() >= "1.12":
755
+ import oneccl_bindings_for_pytorch # noqa: F401
756
+ else:
757
+ import torch_ccl # noqa: F401
758
+
759
+ backend = "ccl"
760
+ elif backend in (None, "mpi") and torch.distributed.is_mpi_available():
761
+ backend = "mpi"
762
+ else:
763
+ backend = "gloo"
764
+ if distributed_type is None:
765
+ distributed_type = DistributedType.NO
766
+
767
+ return backend, distributed_type
768
+
769
+ def set_device(self):
770
+ """
771
+ Sets the device in `self.device` to the current distributed environment.
772
+ """
773
+ if self.device is not None:
774
+ return
775
+ if self.distributed_type == DistributedType.NO:
776
+ self.device = torch.device("cpu") if self._cpu else self.default_device
777
+ return
778
+ device = str(self.distributed_type).split(".")[-1].replace("MULTI_", "").lower()
779
+ if device not in ("cpu", "gpu", "mlu", "musa", "npu", "xpu", "xla"):
780
+ raise ValueError(
781
+ f"Can't set device for {self.distributed_type} ({device}), verify we should be calling `_set_device()` for it!"
782
+ )
783
+ if device == "xla":
784
+ self.device = xm.xla_device()
785
+ else:
786
+ if device == "gpu":
787
+ device = "cuda"
788
+ device_module = getattr(torch, device)
789
+ device_index = self.local_process_index % device_module.device_count()
790
+ self.device = torch.device(device, device_index)
791
+ device_module.set_device(self.device)
792
+
793
+ def destroy_process_group(self, group=None):
794
+ """
795
+ Destroys the process group. If one is not specified, the default process group is destroyed.
796
+ """
797
+ if self.fork_launched and group is None:
798
+ return
799
+ # needed when using torch.distributed.init_process_group
800
+ if torch.distributed.is_initialized():
801
+ torch.distributed.destroy_process_group(group)
802
+
803
+ def __getattr__(self, name: str):
804
+ # By this point we know that no attributes of `self` contain `name`,
805
+ # so we just modify the error message
806
+ if name in self._known_attrs:
807
+ raise AttributeError(
808
+ f"`PartialState` object has no attribute `{name}`. "
809
+ "This happens if `PartialState._reset_state()` was called and "
810
+ "an `Accelerator` or `PartialState` was not reinitialized."
811
+ )
812
+ # Raise a typical AttributeError
813
+ raise AttributeError(f"'PartialState' object has no attribute '{name}'")
814
+
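Taken together, `PartialState` can be used on its own, without a full `Accelerator`, for lightweight process control; a minimal sketch:

```python
from accelerate import PartialState

state = PartialState()                          # also safe in a plain single-process run
state.print(f"running on {state.device}")       # printed by the local main process only
if state.is_main_process:
    ...  # e.g. save artifacts exactly once
state.wait_for_everyone()
```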
815
+
816
+ class AcceleratorState:
817
+ """
818
+ Singleton class that has information about the current training environment.
819
+
820
+ **Available attributes:**
821
+
822
+ - **device** (`torch.device`) -- The device to use.
823
+ - **distributed_type** ([`~accelerate.state.DistributedType`]) -- The type of distributed environment currently
824
+ in use.
825
+ - **initialized** (`bool`) -- Whether or not the `AcceleratorState` has been initialized from `Accelerator`.
826
+ - **local_process_index** (`int`) -- The index of the current process on the current server.
827
+ - **mixed_precision** (`str`) -- Whether or not the current script will use mixed precision, and if so the type
828
+ of mixed precision being performed. (Choose from 'no', 'fp16', 'bf16' or 'fp8').
829
+ - **num_processes** (`int`) -- The number of processes currently launched in parallel.
830
+ - **process_index** (`int`) -- The index of the current process.
831
+ - **is_last_process** (`bool`) -- Whether or not the current process is the last one.
832
+ - **is_main_process** (`bool`) -- Whether or not the current process is the main one.
833
+ - **is_local_main_process** (`bool`) -- Whether or not the current process is the main one on the local node.
834
+ - **debug** (`bool`) -- Whether or not the current script is being run in debug mode.
835
+ """
836
+
837
+ _shared_state = SharedDict()
838
+ _known_attrs = PartialState._known_attrs + [
839
+ "deepspeed_plugin",
840
+ "use_ipex",
841
+ "fsdp_plugin",
842
+ "megatron_lm_plugin",
843
+ "dynamo_plugin",
844
+ ]
845
+
846
+ def __init__(
847
+ self,
848
+ mixed_precision: str = None,
849
+ cpu: bool = False,
850
+ dynamo_plugin=None,
851
+ deepspeed_plugin=None,
852
+ fsdp_plugin=None,
853
+ megatron_lm_plugin=None,
854
+ _from_accelerator: bool = False,
855
+ **kwargs,
856
+ ):
857
+ self.__dict__ = self._shared_state
858
+ if parse_flag_from_env("ACCELERATE_USE_CPU"):
859
+ cpu = True
860
+ if PartialState._shared_state == {}:
861
+ PartialState(cpu, **kwargs)
862
+ self.__dict__.update(PartialState._shared_state)
863
+ self._check_initialized(mixed_precision, cpu)
864
+ if not self.initialized:
865
+ self.deepspeed_plugins = None
866
+ self.use_ipex = None
867
+ mixed_precision = (
868
+ parse_choice_from_env("ACCELERATE_MIXED_PRECISION", "no")
869
+ if mixed_precision is None
870
+ else mixed_precision.lower()
871
+ )
872
+ if mixed_precision == "fp8":
873
+ if not is_fp8_available():
874
+ raise ValueError(
875
+ "Using `fp8` precision requires `transformer_engine` or `MS-AMP` to be installed."
876
+ )
877
+ elif not check_fp8_capability():
878
+ logger.warning(
879
+ f"The current device has compute capability of {torch.cuda.get_device_capability()} which is "
880
+ "insufficient for FP8 mixed precision training (requires a Hopper/Ada Lovelace GPU "
881
+ "or higher, compute capability of 8.9 or higher). Will use FP16 instead."
882
+ )
883
+ mixed_precision = "fp16"
884
+
885
+ self.dynamo_plugin = dynamo_plugin
886
+ if not _from_accelerator:
887
+ raise ValueError(
888
+ "Please make sure to properly initialize your accelerator via `accelerator = Accelerator()` "
889
+ "before using any functionality from the `accelerate` library."
890
+ )
891
+ # deepspeed handles mixed_precision using deepspeed_config
892
+ self._mixed_precision = "no" if self.distributed_type == DistributedType.DEEPSPEED else mixed_precision
893
+ if self.distributed_type == DistributedType.XLA and is_torch_xla_available(check_is_tpu=True):
894
+ if mixed_precision == "bf16":
895
+ if os.environ.get("ACCELERATE_DOWNCAST_BF16"):
896
+ os.environ["XLA_USE_BF16"] = str(0)
897
+ os.environ["XLA_DOWNCAST_BF16"] = str(1)
898
+ self.downcast_bfloat = True
899
+ else:
900
+ os.environ["XLA_USE_BF16"] = str(1)
901
+ os.environ["XLA_DOWNCAST_BF16"] = str(0)
902
+ self.downcast_bfloat = False
903
+ elif os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true" and not cpu:
904
+ self.deepspeed_plugins = deepspeed_plugin
905
+ self.distributed_type = DistributedType.DEEPSPEED
906
+ elif self.distributed_type in [
907
+ DistributedType.MULTI_GPU,
908
+ DistributedType.MULTI_MLU,
909
+ DistributedType.MULTI_MUSA,
910
+ DistributedType.MULTI_NPU,
911
+ DistributedType.MULTI_XPU,
912
+ ]:
913
+ if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" or fsdp_plugin is not None:
914
+ self.distributed_type = DistributedType.FSDP
915
+ if self._mixed_precision != "no":
916
+ fsdp_plugin.set_mixed_precision(self._mixed_precision)
917
+ self.fsdp_plugin = fsdp_plugin
918
+ if os.environ.get("ACCELERATE_USE_MEGATRON_LM", "false") == "true" and self.distributed_type not in [
919
+ DistributedType.MULTI_XPU,
920
+ ]:
921
+ self.distributed_type = DistributedType.MEGATRON_LM
922
+ megatron_lm_plugin.set_mixed_precision(self._mixed_precision)
923
+ self.megatron_lm_plugin = megatron_lm_plugin
924
+ elif self.distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_XPU, DistributedType.NO]:
925
+ if is_ipex_available():
926
+ # check if user disables it explicitly
927
+ self.use_ipex = parse_flag_from_env("ACCELERATE_USE_IPEX", default=True)
928
+ else:
929
+ self.use_ipex = False
930
+ if (
931
+ self.dynamo_plugin.backend != DynamoBackend.NO
932
+ and self._mixed_precision == "no"
933
+ and self.device.type == "cuda"
934
+ ):
935
+ torch.backends.cuda.matmul.allow_tf32 = True
936
+ if (
937
+ self.dynamo_plugin.backend != DynamoBackend.NO
938
+ and self._mixed_precision == "no"
939
+ and self.device.type == "musa"
940
+ ):
941
+ torch.backends.musa.matmul.allow_tf32 = True
942
+ PartialState._shared_state["distributed_type"] = self.distributed_type
943
+
944
+ @property
945
+ def initialized(self) -> bool:
946
+ return self._shared_state != PartialState._shared_state
947
+
948
+ def __repr__(self):
949
+ repr = PartialState().__repr__() + f"\nMixed precision type: {self.mixed_precision}\n"
950
+ if self.distributed_type == DistributedType.DEEPSPEED:
951
+ repr += f"ds_config: {self.deepspeed_plugin.deepspeed_config}\n"
952
+ return repr
953
+
954
+ def _check_initialized(self, mixed_precision=None, cpu=None):
955
+ "Checks if a modification is trying to be made and the `AcceleratorState` has already been initialized"
956
+ if self.initialized:
957
+ err = "AcceleratorState has already been initialized and cannot be changed, restart your runtime completely and pass `{flag}` to `Accelerator()`."
958
+ if cpu and self.device.type != "cpu":
959
+ raise ValueError(err.format(flag="cpu=True"))
960
+ if (
961
+ mixed_precision is not None
962
+ and mixed_precision != self._mixed_precision
963
+ and self.distributed_type != DistributedType.DEEPSPEED
964
+ ):
965
+ raise ValueError(err.format(flag=f"mixed_precision='{mixed_precision}'"))
966
+
967
+ @property
968
+ def mixed_precision(self):
969
+ if self.distributed_type == DistributedType.DEEPSPEED:
970
+ config = self.deepspeed_plugin.deepspeed_config
971
+ if config.get("fp16", {}).get("enabled", False):
972
+ mixed_precision = "fp16"
973
+ elif config.get("bf16", {}).get("enabled", False):
974
+ mixed_precision = "bf16"
975
+ else:
976
+ mixed_precision = "no"
977
+ else:
978
+ mixed_precision = self._mixed_precision
979
+ return mixed_precision
980
+
981
+ @staticmethod
982
+ def _reset_state(reset_partial_state: bool = False):
983
+ "Resets `_shared_state`, is used internally and should not be called"
984
+ AcceleratorState._shared_state.clear()
985
+ if reset_partial_state:
986
+ PartialState._reset_state()
987
+
988
+ def destroy_process_group(self, group=None):
989
+ """
990
+ Destroys the process group. If one is not specified, the default process group is destroyed.
991
+
992
+ If `self.fork_launched` is `True` and `group` is `None`, nothing happens.
993
+ """
994
+ PartialState().destroy_process_group(group)
995
+
996
+ @property
997
+ def fork_launched(self):
998
+ return PartialState().fork_launched
999
+
1000
+ @property
1001
+ def use_distributed(self):
1002
+ """
1003
+ Whether the Accelerator is configured for distributed training
1004
+ """
1005
+ return PartialState().use_distributed
1006
+
1007
+ @property
1008
+ def is_last_process(self) -> bool:
1009
+ "Returns whether the current process is the last one"
1010
+ return PartialState().is_last_process
1011
+
1012
+ @property
1013
+ def is_main_process(self) -> bool:
1014
+ "Returns whether the current process is the main process"
1015
+ return PartialState().is_main_process
1016
+
1017
+ @property
1018
+ def is_local_main_process(self) -> bool:
1019
+ "Returns whether the current process is the main process on the local node"
1020
+ return PartialState().is_local_main_process
1021
+
1022
+ def wait_for_everyone(self):
1023
+ PartialState().wait_for_everyone()
1024
+
1025
+ @contextmanager
1026
+ def split_between_processes(self, inputs: list | tuple | dict | torch.Tensor, apply_padding: bool = False):
1027
+ """
1028
+ Splits `inputs` between `self.num_processes` quickly so the result can then be used on that process. Useful when doing
1029
+ distributed inference, such as with different prompts.
1030
+
1031
+ Note that when using a `dict`, all keys need to have the same number of elements.
1032
+
1033
+ Args:
1034
+ inputs (`list`, `tuple`, `torch.Tensor`, or `dict` of `list`/`tuple`/`torch.Tensor`):
1035
+ The input to split between processes.
1036
+ apply_padding (`bool`, `optional`, defaults to `False`):
1037
+ Whether to apply padding by repeating the last element of the input so that all processes have the same
1038
+ number of elements. Useful when trying to perform actions such as `gather()` on the outputs or passing
1039
+ in less inputs than there are processes. If so, just remember to drop the padded elements afterwards.
1040
+
1041
+
1042
+ Example:
1043
+
1044
+ ```python
1045
+ # Assume there are two processes
1046
+ from accelerate.state import AcceleratorState
1047
+
1048
+ state = AcceleratorState()
1049
+ with state.split_between_processes(["A", "B", "C"]) as inputs:
1050
+ print(inputs)
1051
+ # Process 0
1052
+ ["A", "B"]
1053
+ # Process 1
1054
+ ["C"]
1055
+
1056
+ with state.split_between_processes(["A", "B", "C"], apply_padding=True) as inputs:
1057
+ print(inputs)
1058
+ # Process 0
1059
+ ["A", "B"]
1060
+ # Process 1
1061
+ ["C", "C"]
1062
+ ```
1063
+ """
1064
+ with PartialState().split_between_processes(inputs, apply_padding=apply_padding) as inputs:
1065
+ yield inputs
1066
+
1067
+ @contextmanager
1068
+ def main_process_first(self):
1069
+ """
1070
+ Lets the main process go first inside a with block.
1071
+
1072
+ The other processes will enter the with block after the main process exits.
1073
+ """
1074
+ with PartialState().main_process_first():
1075
+ yield
1076
+
1077
+ @contextmanager
1078
+ def local_main_process_first(self):
1079
+ """
1080
+ Lets the local main process go first inside a with block.
1081
+
1082
+ The other processes will enter the with block after the main process exits.
1083
+ """
1084
+ with PartialState().local_main_process_first():
1085
+ yield
1086
+
1087
+ @property
1088
+ def deepspeed_plugin(self):
1089
+ """
1090
+ Returns the currently active DeepSpeedPlugin.
1091
+
1092
+ If not using deepspeed, returns `None`.
1093
+ """
1094
+ # To maintain original behavior, return None if not using deepspeed.
1095
+ if self.distributed_type != DistributedType.DEEPSPEED:
1096
+ return None
1097
+ from accelerate.utils.deepspeed import get_active_deepspeed_plugin
1098
+
1099
+ return get_active_deepspeed_plugin(self)
1100
+
1101
+ @deepspeed_required
1102
+ def get_deepspeed_plugin(self, name: str):
1103
+ """
1104
+ Returns the DeepSpeedPlugin with the given plugin_key.
1105
+ """
1106
+ return self.deepspeed_plugins[name]
1107
+
1108
+ @deepspeed_required
1109
+ def select_deepspeed_plugin(self, name: str = None):
1110
+ """
1111
+ Activates the DeepSpeedPlugin with the given `name`, and will disable all other plugins.
1112
+ """
1113
+ for key, plugin in self.deepspeed_plugins.items():
1114
+ if key != name:
1115
+ plugin._unselect()
1116
+ self.deepspeed_plugins[name].select(_from_accelerator_state=True)
1117
+
1118
+ def print(self, *args, **kwargs):
1119
+ PartialState().print(*args, **kwargs)
1120
+
1121
+ def __getattr__(self, name: str):
1122
+ # By this point we know that no attributes of `self` contain `name`,
1123
+ # so we just modify the error message
1124
+ if name in self._known_attrs:
1125
+ raise AttributeError(
1126
+ f"`AcceleratorState` object has no attribute `{name}`. "
1127
+ "This happens if `AcceleratorState._reset_state()` was called and "
1128
+ "an `Accelerator` or `PartialState` was not reinitialized."
1129
+ )
1130
+ # Raise a typical AttributeError
1131
+ raise AttributeError(f"'AcceleratorState' object has no attribute '{name}'")
1132
+
1133
+
1134
+ class GradientState:
1135
+ """
1136
+ Singleton class that has information related to gradient synchronization for gradient accumulation
1137
+
1138
+ **Available attributes:**
1139
+
1140
+ - **end_of_dataloader** (`bool`) -- Whether we have reached the end of the current dataloader
1141
+ - **remainder** (`int`) -- The number of extra samples that were added from padding the dataloader
1142
+ - **sync_gradients** (`bool`) -- Whether the gradients should be synced across all devices
1143
+ - **active_dataloader** (`Optional[DataLoader]`) -- The dataloader that is currently being iterated over
1144
+ - **dataloader_references** (`List[Optional[DataLoader]]`) -- A list of references to the dataloaders that are
1145
+ being iterated over
1146
+ - **num_steps** (`int`) -- The number of steps to accumulate over
1147
+ - **adjust_scheduler** (`bool`) -- Whether the scheduler should be adjusted to account for the gradient
1148
+ accumulation
1149
+ - **sync_with_dataloader** (`bool`) -- Whether the gradients should be synced at the end of the dataloader
1150
+ iteration and the number of total steps reset
1151
+ - **is_xla_gradients_synced** (`bool`) -- Whether the XLA gradients have been synchronized. It is initialized
1152
+ as false. Once gradients have been reduced before the optimizer step, this flag is set to true. Subsequently,
1153
+ after each step, the flag is reset to false. FSDP will always synchronize the gradients, hence
1154
+ is_xla_gradients_synced is always true.
1155
+ """
1156
+
1157
+ _shared_state = SharedDict()
1158
+
1159
+ def __init__(self, gradient_accumulation_plugin: Optional[GradientAccumulationPlugin] = None):
1160
+ self.__dict__ = self._shared_state
1161
+ if not self.initialized:
1162
+ self.sync_gradients = True
1163
+ self.active_dataloader = None
1164
+ self.dataloader_references = [None]
1165
+ self.plugin_kwargs = (
1166
+ gradient_accumulation_plugin.to_kwargs() if gradient_accumulation_plugin is not None else {}
1167
+ )
1168
+ self._is_xla_gradients_synced = False
1169
+
1170
+ # Plugin args are different and can be updated
1171
+ if gradient_accumulation_plugin is not None and self.plugin_kwargs != gradient_accumulation_plugin.to_kwargs():
1172
+ self.plugin_kwargs = gradient_accumulation_plugin.to_kwargs()
1173
+
1174
+ @property
1175
+ def num_steps(self) -> int:
1176
+ "Returns the number of steps to accumulate over"
1177
+ return self.plugin_kwargs.get("num_steps", 1)
1178
+
1179
+ @property
1180
+ def adjust_scheduler(self) -> bool:
1181
+ "Returns whether the scheduler should be adjusted"
1182
+ return self.plugin_kwargs.get("adjust_scheduler", False)
1183
+
1184
+ @property
1185
+ def sync_with_dataloader(self) -> bool:
1186
+ "Returns whether the gradients should be synced at the end of the dataloader iteration and the number of total steps reset"
1187
+ return self.plugin_kwargs.get("sync_with_dataloader", True)
1188
+
1189
+ @property
1190
+ def initialized(self) -> bool:
1191
+ "Returns whether the `GradientState` has been initialized"
1192
+ return GradientState._shared_state != {}
1193
+
1194
+ @property
1195
+ def end_of_dataloader(self) -> bool:
1196
+ "Returns whether we have reached the end of the current dataloader"
1197
+ if not self.in_dataloader:
1198
+ return False
1199
+ return self.active_dataloader.end_of_dataloader
1200
+
1201
+ @property
1202
+ def remainder(self) -> int:
1203
+ "Returns the number of extra samples that were added from padding the dataloader"
1204
+ if not self.in_dataloader:
1205
+ return -1
1206
+ return self.active_dataloader.remainder
1207
+
1208
+ def __repr__(self):
1209
+ return (
1210
+ f"Sync Gradients: {self.sync_gradients}\n"
1211
+ f"At end of current dataloader: {self.end_of_dataloader}\n"
1212
+ f"Extra samples added: {self.remainder}\n"
1213
+ f"Gradient accumulation plugin: {self.plugin_kwargs}\n"
1214
+ )
1215
+
1216
+ @property
1217
+ def is_xla_gradients_synced(self):
1218
+ "Returns the value of is_xla_gradients_synced. FSDP will always synchronize the gradients, hence is_xla_gradients_synced is always true."
1219
+ if parse_flag_from_env("ACCELERATE_USE_FSDP", default=False):
1220
+ return True
1221
+ return self._is_xla_gradients_synced
1222
+
1223
+ @is_xla_gradients_synced.setter
1224
+ def is_xla_gradients_synced(self, is_synced):
1225
+ "Set the _is_xla_gradients_synced attribute."
1226
+ self._is_xla_gradients_synced = is_synced
1227
+
1228
+ def _set_sync_gradients(self, sync_gradients):
1229
+ "Private function that sets whether gradients should be synchronized. Users should not have to call this."
1230
+ self.sync_gradients = sync_gradients
1231
+ # Allow grad-sync to automatically work on TPUs
1232
+ if (
1233
+ self.sync_gradients
1234
+ and is_torch_xla_available(check_is_tpu=True)
1235
+ and PartialState().distributed_type == DistributedType.XLA
1236
+ ):
1237
+ xm.mark_step()
1238
+
1239
+ def _add_dataloader(self, dataloader):
1240
+ "Private function that adds a dataloader to `self.dataloader_references` and sets `in_dataloader` to `True`. Users should not have to call this."
1241
+ self.active_dataloader = dataloader
1242
+ self.dataloader_references.append(self.active_dataloader)
1243
+
1244
+ def _remove_dataloader(self, dataloader):
1245
+ "Private function that removes a dataloader from `self.dataloader_references` and sets `in_dataloader` to `False` if there are no more dataloaders. Users should not have to call this."
1246
+ self.dataloader_references.remove(dataloader)
1247
+ self.active_dataloader = self.dataloader_references[-1]
1248
+
1249
+ @property
1250
+ def in_dataloader(self) -> bool:
1251
+ "Returns whether the current process is in a dataloader"
1252
+ return self.active_dataloader is not None
1253
+
1254
+ @staticmethod
1255
+ def _reset_state():
1256
+ "Resets `_shared_state`; used internally and should not be called"
1257
+ GradientState._shared_state.clear()
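The `GradientState` class above is a process-wide singleton: every instance aliases the same `_shared_state` dictionary, and its accumulation settings come from whichever `GradientAccumulationPlugin` first initialized it. Below is a minimal sketch of that behavior (my own illustration, not part of this diff), assuming a plain single-process environment:

```python
# Illustrative sketch only -- not part of the vendored source above.
from accelerate.state import GradientState
from accelerate.utils import GradientAccumulationPlugin

# The first construction seeds the shared state with the plugin's kwargs.
state_a = GradientState(GradientAccumulationPlugin(num_steps=4))
# Later constructions reuse the same underlying dictionary.
state_b = GradientState()

assert state_b.num_steps == 4          # plugin kwargs are shared
assert state_b.sync_gradients is True  # default set on first init

# Attributes live in the shared dict, so a change made through one handle
# is visible through every other handle.
state_a.sync_gradients = False
assert state_b.sync_gradients is False
```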
.venv/Lib/site-packages/accelerate/tracking.py ADDED
@@ -0,0 +1,1023 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Expectation:
16
+ # Provide a project dir name, then each type of logger gets stored in project/{`logging_dir`}
17
+
18
+ import json
19
+ import os
20
+ import time
21
+ from functools import wraps
22
+ from typing import Any, Dict, List, Optional, Union
23
+
24
+ import yaml
25
+
26
+ from .logging import get_logger
27
+ from .state import PartialState
28
+ from .utils import (
29
+ LoggerType,
30
+ is_aim_available,
31
+ is_clearml_available,
32
+ is_comet_ml_available,
33
+ is_dvclive_available,
34
+ is_mlflow_available,
35
+ is_tensorboard_available,
36
+ is_wandb_available,
37
+ listify,
38
+ )
39
+
40
+
41
+ _available_trackers = []
42
+
43
+ if is_tensorboard_available():
44
+ _available_trackers.append(LoggerType.TENSORBOARD)
45
+
46
+ if is_wandb_available():
47
+ _available_trackers.append(LoggerType.WANDB)
48
+
49
+ if is_comet_ml_available():
50
+ _available_trackers.append(LoggerType.COMETML)
51
+
52
+ if is_aim_available():
53
+ _available_trackers.append(LoggerType.AIM)
54
+
55
+ if is_mlflow_available():
56
+ _available_trackers.append(LoggerType.MLFLOW)
57
+
58
+ if is_clearml_available():
59
+ _available_trackers.append(LoggerType.CLEARML)
60
+
61
+ if is_dvclive_available():
62
+ _available_trackers.append(LoggerType.DVCLIVE)
63
+
64
+ logger = get_logger(__name__)
65
+
66
+
67
+ def on_main_process(function):
68
+ """
69
+ Decorator to selectively run the decorated function on the main process only based on the `main_process_only`
70
+ attribute in a class.
71
+
72
+ Checks at function execution rather than initialization time, not triggering the initialization of the
73
+ `PartialState`.
74
+ """
75
+
76
+ @wraps(function)
77
+ def execute_on_main_process(self, *args, **kwargs):
78
+ if getattr(self, "main_process_only", False):
79
+ return PartialState().on_main_process(function)(self, *args, **kwargs)
80
+ else:
81
+ return function(self, *args, **kwargs)
82
+
83
+ return execute_on_main_process
84
+
85
+
86
+ def get_available_trackers():
87
+ "Returns a list of all supported available trackers in the system"
88
+ return _available_trackers
89
+
90
+
91
+ class GeneralTracker:
92
+ """
93
+ A base Tracker class to be used for all logging integration implementations.
94
+
95
+ Each function should take in `**kwargs` that will automatically be passed in from a base dictionary provided to
96
+ [`Accelerator`].
97
+
98
+ Should implement `name`, `requires_logging_directory`, and `tracker` properties such that:
99
+
100
+ `name` (`str`): String representation of the tracker class name, such as "TensorBoard". `requires_logging_directory`
101
+ (`bool`): Whether the logger requires a directory to store its logs. `tracker` (`object`): Should return the internal
102
+ tracking mechanism used by a tracker class (such as the `run` for wandb)
103
+
104
+ Implementations can also include a `main_process_only` (`bool`) attribute to toggle whether relevant logging, init, and
105
+ other functions should occur on the main process or across all processes (by default will use `True`)
106
+ """
107
+
108
+ main_process_only = True
109
+
110
+ def __init__(self, _blank=False):
111
+ if not _blank:
112
+ err = ""
113
+ if not hasattr(self, "name"):
114
+ err += "`name`"
115
+ if not hasattr(self, "requires_logging_directory"):
116
+ if len(err) > 0:
117
+ err += ", "
118
+ err += "`requires_logging_directory`"
119
+
120
+ # as tracker is a @property that relies on post-init
121
+ if "tracker" not in dir(self):
122
+ if len(err) > 0:
123
+ err += ", "
124
+ err += "`tracker`"
125
+ if len(err) > 0:
126
+ raise NotImplementedError(
127
+ f"The implementation for this tracker class is missing the following "
128
+ f"required attributes. Please define them in the class definition: "
129
+ f"{err}"
130
+ )
131
+
132
+ def store_init_configuration(self, values: dict):
133
+ """
134
+ Logs `values` as hyperparameters for the run. Implementations should use the experiment configuration
135
+ functionality of a tracking API.
136
+
137
+ Args:
138
+ values (Dictionary `str` to `bool`, `str`, `float` or `int`):
139
+ Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
140
+ `str`, `float`, `int`, or `None`.
141
+ """
142
+ pass
143
+
144
+ def log(self, values: dict, step: Optional[int], **kwargs):
145
+ """
146
+ Logs `values` to the current run. Base `log` implementations of a tracking API should go in here, along with
147
+ special behavior for the `step` parameter.
148
+
149
+ Args:
150
+ values (Dictionary `str` to `str`, `float`, or `int`):
151
+ Values to be logged as key-value pairs. The values need to have type `str`, `float`, or `int`.
152
+ step (`int`, *optional*):
153
+ The run step. If included, the log will be affiliated with this step.
154
+ """
155
+ pass
156
+
157
+ def finish(self):
158
+ """
159
+ Should run any finalizing functions within the tracking API. If the API should not have one, just don't
160
+ overwrite that method.
161
+ """
162
+ pass
163
+
164
+
165
+ class TensorBoardTracker(GeneralTracker):
166
+ """
167
+ A `Tracker` class that supports `tensorboard`. Should be initialized at the start of your script.
168
+
169
+ Args:
170
+ run_name (`str`):
171
+ The name of the experiment run
172
+ logging_dir (`str`, `os.PathLike`):
173
+ Location for TensorBoard logs to be stored.
174
+ **kwargs (additional keyword arguments, *optional*):
175
+ Additional key word arguments passed along to the `tensorboard.SummaryWriter.__init__` method.
176
+ """
177
+
178
+ name = "tensorboard"
179
+ requires_logging_directory = True
180
+
181
+ @on_main_process
182
+ def __init__(self, run_name: str, logging_dir: Union[str, os.PathLike], **kwargs):
183
+ try:
184
+ from torch.utils import tensorboard
185
+ except ModuleNotFoundError:
186
+ import tensorboardX as tensorboard
187
+ super().__init__()
188
+ self.run_name = run_name
189
+ self.logging_dir = os.path.join(logging_dir, run_name)
190
+ self.writer = tensorboard.SummaryWriter(self.logging_dir, **kwargs)
191
+ logger.debug(f"Initialized TensorBoard project {self.run_name} logging to {self.logging_dir}")
192
+ logger.debug(
193
+ "Make sure to log any initial configurations with `self.store_init_configuration` before training!"
194
+ )
195
+
196
+ @property
197
+ def tracker(self):
198
+ return self.writer
199
+
200
+ @on_main_process
201
+ def store_init_configuration(self, values: dict):
202
+ """
203
+ Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment. Stores the
204
+ hyperparameters in a yaml file for future use.
205
+
206
+ Args:
207
+ values (Dictionary `str` to `bool`, `str`, `float` or `int`):
208
+ Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
209
+ `str`, `float`, `int`, or `None`.
210
+ """
211
+ self.writer.add_hparams(values, metric_dict={})
212
+ self.writer.flush()
213
+ project_run_name = time.time()
214
+ dir_name = os.path.join(self.logging_dir, str(project_run_name))
215
+ os.makedirs(dir_name, exist_ok=True)
216
+ with open(os.path.join(dir_name, "hparams.yml"), "w") as outfile:
217
+ try:
218
+ yaml.dump(values, outfile)
219
+ except yaml.representer.RepresenterError:
220
+ logger.error("Serialization to store hyperparameters failed")
221
+ raise
222
+ logger.debug("Stored initial configuration hyperparameters to TensorBoard and hparams yaml file")
223
+
224
+ @on_main_process
225
+ def log(self, values: dict, step: Optional[int] = None, **kwargs):
226
+ """
227
+ Logs `values` to the current run.
228
+
229
+ Args:
230
+ values (Dictionary `str` to `str`, `float`, `int` or `dict` of `str` to `float`/`int`):
231
+ Values to be logged as key-value pairs. The values need to have type `str`, `float`, `int` or `dict` of
232
+ `str` to `float`/`int`.
233
+ step (`int`, *optional*):
234
+ The run step. If included, the log will be affiliated with this step.
235
+ kwargs:
236
+ Additional key word arguments passed along to either `SummaryWriter.add_scalar`,
237
+ `SummaryWriter.add_text`, or `SummaryWriter.add_scalars` method based on the contents of `values`.
238
+ """
239
+ values = listify(values)
240
+ for k, v in values.items():
241
+ if isinstance(v, (int, float)):
242
+ self.writer.add_scalar(k, v, global_step=step, **kwargs)
243
+ elif isinstance(v, str):
244
+ self.writer.add_text(k, v, global_step=step, **kwargs)
245
+ elif isinstance(v, dict):
246
+ self.writer.add_scalars(k, v, global_step=step, **kwargs)
247
+ self.writer.flush()
248
+ logger.debug("Successfully logged to TensorBoard")
249
+
250
+ @on_main_process
251
+ def log_images(self, values: dict, step: Optional[int], **kwargs):
252
+ """
253
+ Logs `images` to the current run.
254
+
255
+ Args:
256
+ values (Dictionary `str` to `List` of `np.ndarray` or `PIL.Image`):
257
+ Values to be logged as key-value pairs. The values need to have type `List` of `np.ndarray` or `PIL.Image`.
258
+ step (`int`, *optional*):
259
+ The run step. If included, the log will be affiliated with this step.
260
+ kwargs:
261
+ Additional key word arguments passed along to the `SummaryWriter.add_image` method.
262
+ """
263
+ for k, v in values.items():
264
+ self.writer.add_images(k, v, global_step=step, **kwargs)
265
+ logger.debug("Successfully logged images to TensorBoard")
266
+
267
+ @on_main_process
268
+ def finish(self):
269
+ """
270
+ Closes `TensorBoard` writer
271
+ """
272
+ self.writer.close()
273
+ logger.debug("TensorBoard writer closed")
274
+
275
+
276
+ class WandBTracker(GeneralTracker):
277
+ """
278
+ A `Tracker` class that supports `wandb`. Should be initialized at the start of your script.
279
+
280
+ Args:
281
+ run_name (`str`):
282
+ The name of the experiment run.
283
+ **kwargs (additional keyword arguments, *optional*):
284
+ Additional key word arguments passed along to the `wandb.init` method.
285
+ """
286
+
287
+ name = "wandb"
288
+ requires_logging_directory = False
289
+ main_process_only = False
290
+
291
+ @on_main_process
292
+ def __init__(self, run_name: str, **kwargs):
293
+ super().__init__()
294
+ self.run_name = run_name
295
+
296
+ import wandb
297
+
298
+ self.run = wandb.init(project=self.run_name, **kwargs)
299
+ logger.debug(f"Initialized WandB project {self.run_name}")
300
+ logger.debug(
301
+ "Make sure to log any initial configurations with `self.store_init_configuration` before training!"
302
+ )
303
+
304
+ @property
305
+ def tracker(self):
306
+ return self.run
307
+
308
+ @on_main_process
309
+ def store_init_configuration(self, values: dict):
310
+ """
311
+ Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.
312
+
313
+ Args:
314
+ values (Dictionary `str` to `bool`, `str`, `float` or `int`):
315
+ Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
316
+ `str`, `float`, `int`, or `None`.
317
+ """
318
+ import wandb
319
+
320
+ wandb.config.update(values, allow_val_change=True)
321
+ logger.debug("Stored initial configuration hyperparameters to WandB")
322
+
323
+ @on_main_process
324
+ def log(self, values: dict, step: Optional[int] = None, **kwargs):
325
+ """
326
+ Logs `values` to the current run.
327
+
328
+ Args:
329
+ values (Dictionary `str` to `str`, `float`, `int` or `dict` of `str` to `float`/`int`):
330
+ Values to be logged as key-value pairs. The values need to have type `str`, `float`, `int` or `dict` of
331
+ `str` to `float`/`int`.
332
+ step (`int`, *optional*):
333
+ The run step. If included, the log will be affiliated with this step.
334
+ kwargs:
335
+ Additional key word arguments passed along to the `wandb.log` method.
336
+ """
337
+ self.run.log(values, step=step, **kwargs)
338
+ logger.debug("Successfully logged to WandB")
339
+
340
+ @on_main_process
341
+ def log_images(self, values: dict, step: Optional[int] = None, **kwargs):
342
+ """
343
+ Logs `images` to the current run.
344
+
345
+ Args:
346
+ values (Dictionary `str` to `List` of `np.ndarray` or `PIL.Image`):
347
+ Values to be logged as key-value pairs. The values need to have type `List` of `np.ndarray` or `PIL.Image`.
348
+ step (`int`, *optional*):
349
+ The run step. If included, the log will be affiliated with this step.
350
+ kwargs:
351
+ Additional key word arguments passed along to the `wandb.log` method.
352
+ """
353
+ import wandb
354
+
355
+ for k, v in values.items():
356
+ self.log({k: [wandb.Image(image) for image in v]}, step=step, **kwargs)
357
+ logger.debug("Successfully logged images to WandB")
358
+
359
+ @on_main_process
360
+ def log_table(
361
+ self,
362
+ table_name: str,
363
+ columns: List[str] = None,
364
+ data: List[List[Any]] = None,
365
+ dataframe: Any = None,
366
+ step: Optional[int] = None,
367
+ **kwargs,
368
+ ):
369
+ """
370
+ Log a Table containing any object type (text, image, audio, video, molecule, html, etc). Can be defined either
371
+ with `columns` and `data` or with `dataframe`.
372
+
373
+ Args:
374
+ table_name (`str`):
375
+ The name to give to the logged table on the wandb workspace
376
+ columns (list of `str`, *optional*):
377
+ The name of the columns on the table
378
+ data (List of List of Any data type, *optional*):
379
+ The data to be logged in the table
380
+ dataframe (Any data type, *optional*):
381
+ The data to be logged in the table
382
+ step (`int`, *optional*):
383
+ The run step. If included, the log will be affiliated with this step.
384
+ """
385
+ import wandb
386
+
387
+ values = {table_name: wandb.Table(columns=columns, data=data, dataframe=dataframe)}
388
+ self.log(values, step=step, **kwargs)
389
+
390
+ @on_main_process
391
+ def finish(self):
392
+ """
393
+ Closes `wandb` writer
394
+ """
395
+ self.run.finish()
396
+ logger.debug("WandB run closed")
397
+
398
+
399
+ class CometMLTracker(GeneralTracker):
400
+ """
401
+ A `Tracker` class that supports `comet_ml`. Should be initialized at the start of your script.
402
+
403
+ API keys must be stored in a Comet config file.
404
+
405
+ Args:
406
+ run_name (`str`):
407
+ The name of the experiment run.
408
+ **kwargs (additional keyword arguments, *optional*):
409
+ Additional key word arguments passed along to the `Experiment.__init__` method.
410
+ """
411
+
412
+ name = "comet_ml"
413
+ requires_logging_directory = False
414
+
415
+ @on_main_process
416
+ def __init__(self, run_name: str, **kwargs):
417
+ super().__init__()
418
+ self.run_name = run_name
419
+
420
+ from comet_ml import Experiment
421
+
422
+ self.writer = Experiment(project_name=run_name, **kwargs)
423
+ logger.debug(f"Initialized CometML project {self.run_name}")
424
+ logger.debug(
425
+ "Make sure to log any initial configurations with `self.store_init_configuration` before training!"
426
+ )
427
+
428
+ @property
429
+ def tracker(self):
430
+ return self.writer
431
+
432
+ @on_main_process
433
+ def store_init_configuration(self, values: dict):
434
+ """
435
+ Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.
436
+
437
+ Args:
438
+ values (Dictionary `str` to `bool`, `str`, `float` or `int`):
439
+ Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
440
+ `str`, `float`, `int`, or `None`.
441
+ """
442
+ self.writer.log_parameters(values)
443
+ logger.debug("Stored initial configuration hyperparameters to CometML")
444
+
445
+ @on_main_process
446
+ def log(self, values: dict, step: Optional[int] = None, **kwargs):
447
+ """
448
+ Logs `values` to the current run.
449
+
450
+ Args:
451
+ values (Dictionary `str` to `str`, `float`, `int` or `dict` of `str` to `float`/`int`):
452
+ Values to be logged as key-value pairs. The values need to have type `str`, `float`, `int` or `dict` of
453
+ `str` to `float`/`int`.
454
+ step (`int`, *optional*):
455
+ The run step. If included, the log will be affiliated with this step.
456
+ kwargs:
457
+ Additional key word arguments passed along to either `Experiment.log_metric`, `Experiment.log_other`,
458
+ or `Experiment.log_metrics` method based on the contents of `values`.
459
+ """
460
+ if step is not None:
461
+ self.writer.set_step(step)
462
+ for k, v in values.items():
463
+ if isinstance(v, (int, float)):
464
+ self.writer.log_metric(k, v, step=step, **kwargs)
465
+ elif isinstance(v, str):
466
+ self.writer.log_other(k, v, **kwargs)
467
+ elif isinstance(v, dict):
468
+ self.writer.log_metrics(v, step=step, **kwargs)
469
+ logger.debug("Successfully logged to CometML")
470
+
471
+ @on_main_process
472
+ def finish(self):
473
+ """
474
+ Closes `comet-ml` writer
475
+ """
476
+ self.writer.end()
477
+ logger.debug("CometML run closed")
478
+
479
+
480
+ class AimTracker(GeneralTracker):
481
+ """
482
+ A `Tracker` class that supports `aim`. Should be initialized at the start of your script.
483
+
484
+ Args:
485
+ run_name (`str`):
486
+ The name of the experiment run.
487
+ **kwargs (additional keyword arguments, *optional*):
488
+ Additional key word arguments passed along to the `Run.__init__` method.
489
+ """
490
+
491
+ name = "aim"
492
+ requires_logging_directory = True
493
+
494
+ @on_main_process
495
+ def __init__(self, run_name: str, logging_dir: Optional[Union[str, os.PathLike]] = ".", **kwargs):
496
+ self.run_name = run_name
497
+
498
+ from aim import Run
499
+
500
+ self.writer = Run(repo=logging_dir, **kwargs)
501
+ self.writer.name = self.run_name
502
+ logger.debug(f"Initialized Aim project {self.run_name}")
503
+ logger.debug(
504
+ "Make sure to log any initial configurations with `self.store_init_configuration` before training!"
505
+ )
506
+
507
+ @property
508
+ def tracker(self):
509
+ return self.writer
510
+
511
+ @on_main_process
512
+ def store_init_configuration(self, values: dict):
513
+ """
514
+ Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.
515
+
516
+ Args:
517
+ values (`dict`):
518
+ Values to be stored as initial hyperparameters as key-value pairs.
519
+ """
520
+ self.writer["hparams"] = values
521
+
522
+ @on_main_process
523
+ def log(self, values: dict, step: Optional[int], **kwargs):
524
+ """
525
+ Logs `values` to the current run.
526
+
527
+ Args:
528
+ values (`dict`):
529
+ Values to be logged as key-value pairs.
530
+ step (`int`, *optional*):
531
+ The run step. If included, the log will be affiliated with this step.
532
+ kwargs:
533
+ Additional key word arguments passed along to the `Run.track` method.
534
+ """
535
+ # Note: replace this with the dictionary support when merged
536
+ for key, value in values.items():
537
+ self.writer.track(value, name=key, step=step, **kwargs)
538
+
539
+ @on_main_process
540
+ def log_images(self, values: dict, step: Optional[int] = None, kwargs: Optional[Dict[str, dict]] = None):
541
+ """
542
+ Logs `images` to the current run.
543
+
544
+ Args:
545
+ values (`Dict[str, Union[np.ndarray, PIL.Image, Tuple[np.ndarray, str], Tuple[PIL.Image, str]]]`):
546
+ Values to be logged as key-value pairs. The values need to have type `np.ndarray` or PIL.Image. If a
547
+ tuple is provided, the first element should be the image and the second element should be the caption.
548
+ step (`int`, *optional*):
549
+ The run step. If included, the log will be affiliated with this step.
550
+ kwargs (`Dict[str, dict]`):
551
+ Additional key word arguments passed along to the `Run.Image` and `Run.track` method specified by the
552
+ keys `aim_image` and `track`, respectively.
553
+ """
554
+ import aim
555
+
556
+ aim_image_kw = {}
557
+ track_kw = {}
558
+
559
+ if kwargs is not None:
560
+ aim_image_kw = kwargs.get("aim_image", {})
561
+ track_kw = kwargs.get("track", {})
562
+
563
+ for key, value in values.items():
564
+ if isinstance(value, tuple):
565
+ img, caption = value
566
+ else:
567
+ img, caption = value, ""
568
+ aim_image = aim.Image(img, caption=caption, **aim_image_kw)
569
+ self.writer.track(aim_image, name=key, step=step, **track_kw)
570
+
571
+ @on_main_process
572
+ def finish(self):
573
+ """
574
+ Closes `aim` writer
575
+ """
576
+ self.writer.close()
577
+
578
+
579
+ class MLflowTracker(GeneralTracker):
580
+ """
581
+ A `Tracker` class that supports `mlflow`. Should be initialized at the start of your script.
582
+
583
+ Args:
584
+ experiment_name (`str`, *optional*):
585
+ Name of the experiment. Environment variable MLFLOW_EXPERIMENT_NAME has priority over this argument.
586
+ logging_dir (`str` or `os.PathLike`, defaults to `"."`):
587
+ Location for mlflow logs to be stored.
588
+ run_id (`str`, *optional*):
589
+ If specified, get the run with the specified UUID and log parameters and metrics under that run. The run’s
590
+ end time is unset and its status is set to running, but the run’s other attributes (source_version,
591
+ source_type, etc.) are not changed. Environment variable MLFLOW_RUN_ID has priority over this argument.
592
+ tags (`Dict[str, str]`, *optional*):
593
+ An optional `dict` of `str` keys and values, or a `str` dump from a `dict`, to set as tags on the run. If a
594
+ run is being resumed, these tags are set on the resumed run. If a new run is being created, these tags are
595
+ set on the new run. Environment variable MLFLOW_TAGS has priority over this argument.
596
+ nested_run (`bool`, *optional*, defaults to `False`):
597
+ Controls whether run is nested in parent run. True creates a nested run. Environment variable
598
+ MLFLOW_NESTED_RUN has priority over this argument.
599
+ run_name (`str`, *optional*):
600
+ Name of new run (stored as a mlflow.runName tag). Used only when `run_id` is unspecified.
601
+ description (`str`, *optional*):
602
+ An optional string that populates the description box of the run. If a run is being resumed, the
603
+ description is set on the resumed run. If a new run is being created, the description is set on the new
604
+ run.
605
+ """
606
+
607
+ name = "mlflow"
608
+ requires_logging_directory = False
609
+
610
+ @on_main_process
611
+ def __init__(
612
+ self,
613
+ experiment_name: str = None,
614
+ logging_dir: Optional[Union[str, os.PathLike]] = None,
615
+ run_id: Optional[str] = None,
616
+ tags: Optional[Union[Dict[str, Any], str]] = None,
617
+ nested_run: Optional[bool] = False,
618
+ run_name: Optional[str] = None,
619
+ description: Optional[str] = None,
620
+ ):
621
+ experiment_name = os.environ.get("MLFLOW_EXPERIMENT_NAME", experiment_name)
622
+ run_id = os.environ.get("MLFLOW_RUN_ID", run_id)
623
+ tags = os.environ.get("MLFLOW_TAGS", tags)
624
+ if isinstance(tags, str):
625
+ tags = json.loads(tags)
626
+
627
+ nested_run = os.environ.get("MLFLOW_NESTED_RUN", nested_run)
628
+
629
+ import mlflow
630
+
631
+ exps = mlflow.search_experiments(filter_string=f"name = '{experiment_name}'")
632
+ if len(exps) > 0:
633
+ if len(exps) > 1:
634
+ logger.warning("Multiple experiments with the same name found. Using first one.")
635
+ experiment_id = exps[0].experiment_id
636
+ else:
637
+ experiment_id = mlflow.create_experiment(
638
+ name=experiment_name,
639
+ artifact_location=logging_dir,
640
+ tags=tags,
641
+ )
642
+
643
+ self.active_run = mlflow.start_run(
644
+ run_id=run_id,
645
+ experiment_id=experiment_id,
646
+ run_name=run_name,
647
+ nested=nested_run,
648
+ tags=tags,
649
+ description=description,
650
+ )
651
+
652
+ logger.debug(f"Initialized mlflow experiment {experiment_name}")
653
+ logger.debug(
654
+ "Make sure to log any initial configurations with `self.store_init_configuration` before training!"
655
+ )
656
+
657
+ @property
658
+ def tracker(self):
659
+ return self.active_run
660
+
661
+ @on_main_process
662
+ def store_init_configuration(self, values: dict):
663
+ """
664
+ Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.
665
+
666
+ Args:
667
+ values (`dict`):
668
+ Values to be stored as initial hyperparameters as key-value pairs.
669
+ """
670
+ import mlflow
671
+
672
+ for name, value in list(values.items()):
673
+ # internally, all values are converted to str in MLflow
674
+ if len(str(value)) > mlflow.utils.validation.MAX_PARAM_VAL_LENGTH:
675
+ logger.warning_once(
676
+ f'Accelerate is attempting to log a value of "{value}" for key "{name}" as a parameter. MLflow\'s'
677
+ f" log_param() only accepts values no longer than {mlflow.utils.validation.MAX_PARAM_VAL_LENGTH} characters so we dropped this attribute."
678
+ )
679
+ del values[name]
680
+
681
+ values_list = list(values.items())
682
+
683
+ # MLflow cannot log more than 100 values in one go, so we have to split it
684
+ for i in range(0, len(values_list), mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH):
685
+ mlflow.log_params(dict(values_list[i : i + mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH]))
686
+
687
+ logger.debug("Stored initial configuration hyperparameters to MLflow")
688
+
689
+ @on_main_process
690
+ def log(self, values: dict, step: Optional[int]):
691
+ """
692
+ Logs `values` to the current run.
693
+
694
+ Args:
695
+ values (`dict`):
696
+ Values to be logged as key-value pairs.
697
+ step (`int`, *optional*):
698
+ The run step. If included, the log will be affiliated with this step.
699
+ """
700
+ metrics = {}
701
+ for k, v in values.items():
702
+ if isinstance(v, (int, float)):
703
+ metrics[k] = v
704
+ else:
705
+ logger.warning_once(
706
+ f'MLflowTracker is attempting to log a value of "{v}" of type {type(v)} for key "{k}" as a metric. '
707
+ "MLflow's log_metric() only accepts float and int types so we dropped this attribute."
708
+ )
709
+ import mlflow
710
+
711
+ mlflow.log_metrics(metrics, step=step)
712
+ logger.debug("Successfully logged to mlflow")
713
+
714
+ @on_main_process
715
+ def finish(self):
716
+ """
717
+ End the active MLflow run.
718
+ """
719
+ import mlflow
720
+
721
+ mlflow.end_run()
722
+
723
+
724
+ class ClearMLTracker(GeneralTracker):
725
+ """
726
+ A `Tracker` class that supports `clearml`. Should be initialized at the start of your script.
727
+
728
+ Args:
729
+ run_name (`str`, *optional*):
730
+ Name of the experiment. Environment variables `CLEARML_PROJECT` and `CLEARML_TASK` have priority over this
731
+ argument.
732
+ **kwargs (additional keyword arguments, *optional*):
733
+ Kwargs passed along to the `Task.__init__` method.
734
+ """
735
+
736
+ name = "clearml"
737
+ requires_logging_directory = False
738
+
739
+ @on_main_process
740
+ def __init__(self, run_name: str = None, **kwargs):
741
+ from clearml import Task
742
+
743
+ current_task = Task.current_task()
744
+ self._initialized_externally = False
745
+ if current_task:
746
+ self._initialized_externally = True
747
+ self.task = current_task
748
+ return
749
+
750
+ kwargs.setdefault("project_name", os.environ.get("CLEARML_PROJECT", run_name))
751
+ kwargs.setdefault("task_name", os.environ.get("CLEARML_TASK", run_name))
752
+ self.task = Task.init(**kwargs)
753
+
754
+ @property
755
+ def tracker(self):
756
+ return self.task
757
+
758
+ @on_main_process
759
+ def store_init_configuration(self, values: dict):
760
+ """
761
+ Connect configuration dictionary to the Task object. Should be run at the beginning of your experiment.
762
+
763
+ Args:
764
+ values (`dict`):
765
+ Values to be stored as initial hyperparameters as key-value pairs.
766
+ """
767
+ return self.task.connect_configuration(values)
768
+
769
+ @on_main_process
770
+ def log(self, values: Dict[str, Union[int, float]], step: Optional[int] = None, **kwargs):
771
+ """
772
+ Logs `values` dictionary to the current run. The dictionary keys must be strings. The dictionary values must be
773
+ ints or floats
774
+
775
+ Args:
776
+ values (`Dict[str, Union[int, float]]`):
777
+ Values to be logged as key-value pairs. If the key starts with 'eval_'/'test_'/'train_', the value will
778
+ be reported under the 'eval'/'test'/'train' series and the respective prefix will be removed.
779
+ Otherwise, the value will be reported under the 'train' series, and no prefix will be removed.
780
+ step (`int`, *optional*):
781
+ If specified, the values will be reported as scalars, with the iteration number equal to `step`.
782
+ Otherwise they will be reported as single values.
783
+ kwargs:
784
+ Additional key word arguments passed along to the `clearml.Logger.report_single_value` or
785
+ `clearml.Logger.report_scalar` methods.
786
+ """
787
+ clearml_logger = self.task.get_logger()
788
+ for k, v in values.items():
789
+ if not isinstance(v, (int, float)):
790
+ logger.warning_once(
791
+ "Accelerator is attempting to log a value of "
792
+ f'"{v}" of type {type(v)} for key "{k}" as a scalar. '
793
+ "This invocation of ClearML logger's report_scalar() "
794
+ "is incorrect so we dropped this attribute."
795
+ )
796
+ continue
797
+ if step is None:
798
+ clearml_logger.report_single_value(name=k, value=v, **kwargs)
799
+ continue
800
+ title, series = ClearMLTracker._get_title_series(k)
801
+ clearml_logger.report_scalar(title=title, series=series, value=v, iteration=step, **kwargs)
802
+
803
+ @on_main_process
804
+ def log_images(self, values: dict, step: Optional[int] = None, **kwargs):
805
+ """
806
+ Logs `images` to the current run.
807
+
808
+ Args:
809
+ values (`Dict[str, List[Union[np.ndarray, PIL.Image]]`):
810
+ Values to be logged as key-value pairs. The values need to have type `List` of `np.ndarray` or `PIL.Image`.
811
+ step (`int`, *optional*):
812
+ The run step. If included, the log will be affiliated with this step.
813
+ kwargs:
814
+ Additional key word arguments passed along to the `clearml.Logger.report_image` method.
815
+ """
816
+ clearml_logger = self.task.get_logger()
817
+ for k, v in values.items():
818
+ title, series = ClearMLTracker._get_title_series(k)
819
+ clearml_logger.report_image(title=title, series=series, iteration=step, image=v, **kwargs)
820
+
821
+ @on_main_process
822
+ def log_table(
823
+ self,
824
+ table_name: str,
825
+ columns: List[str] = None,
826
+ data: List[List[Any]] = None,
827
+ dataframe: Any = None,
828
+ step: Optional[int] = None,
829
+ **kwargs,
830
+ ):
831
+ """
832
+ Log a Table to the task. Can be defined either with `columns` and `data` or with `dataframe`.
833
+
834
+ Args:
835
+ table_name (`str`):
836
+ The name of the table
837
+ columns (list of `str`, *optional*):
838
+ The name of the columns on the table
839
+ data (List of List of Any data type, *optional*):
840
+ The data to be logged in the table. If `columns` is not specified, then the first entry in data will be
841
+ the name of the columns of the table
842
+ dataframe (Any data type, *optional*):
843
+ The data to be logged in the table
844
+ step (`int`, *optional*):
845
+ The run step. If included, the log will be affiliated with this step.
846
+ kwargs:
847
+ Additional key word arguments passed along to the `clearml.Logger.report_table` method.
848
+ """
849
+ to_report = dataframe
850
+ if dataframe is None:
851
+ if data is None:
852
+ raise ValueError(
853
+ "`ClearMLTracker.log_table` requires that `data` be supplied if `dataframe` is `None`"
854
+ )
855
+ to_report = [columns] + data if columns else data
856
+ title, series = ClearMLTracker._get_title_series(table_name)
857
+ self.task.get_logger().report_table(title=title, series=series, table_plot=to_report, iteration=step, **kwargs)
858
+
859
+ @on_main_process
860
+ def finish(self):
861
+ """
862
+ Close the ClearML task. If the task was initialized externally (e.g. by manually calling `Task.init`), this
863
+ function is a noop
864
+ """
865
+ if self.task and not self._initialized_externally:
866
+ self.task.close()
867
+
868
+ @staticmethod
869
+ def _get_title_series(name):
870
+ for prefix in ["eval", "test", "train"]:
871
+ if name.startswith(prefix + "_"):
872
+ return name[len(prefix) + 1 :], prefix
873
+ return name, "train"
874
+
875
+
876
+ class DVCLiveTracker(GeneralTracker):
877
+ """
878
+ A `Tracker` class that supports `dvclive`. Should be initialized at the start of your script.
879
+
880
+ Args:
881
+ run_name (`str`, *optional*):
882
+ Ignored for dvclive. See `kwargs` instead.
883
+ kwargs:
884
+ Additional key word arguments passed along to [`dvclive.Live()`](https://dvc.org/doc/dvclive/live).
885
+
886
+ Example:
887
+
888
+ ```py
889
+ from accelerate import Accelerator
890
+
891
+ accelerator = Accelerator(log_with="dvclive")
892
+ accelerator.init_trackers(project_name="my_project", init_kwargs={"dvclive": {"dir": "my_directory"}})
893
+ ```
894
+ """
895
+
896
+ name = "dvclive"
897
+ requires_logging_directory = False
898
+
899
+ @on_main_process
900
+ def __init__(self, run_name: Optional[str] = None, live: Optional[Any] = None, **kwargs):
901
+ from dvclive import Live
902
+
903
+ super().__init__()
904
+ self.live = live if live is not None else Live(**kwargs)
905
+
906
+ @property
907
+ def tracker(self):
908
+ return self.live
909
+
910
+ @on_main_process
911
+ def store_init_configuration(self, values: dict):
912
+ """
913
+ Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment. Stores the
914
+ hyperparameters in a yaml file for future use.
915
+
916
+ Args:
917
+ values (Dictionary `str` to `bool`, `str`, `float`, `int`, or a List or Dict of those types):
918
+ Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
919
+ `str`, `float`, or `int`.
920
+ """
921
+ self.live.log_params(values)
922
+
923
+ @on_main_process
924
+ def log(self, values: dict, step: Optional[int] = None, **kwargs):
925
+ """
926
+ Logs `values` to the current run.
927
+
928
+ Args:
929
+ values (Dictionary `str` to `str`, `float`, or `int`):
930
+ Values to be logged as key-value pairs. The values need to have type `str`, `float`, or `int`.
931
+ step (`int`, *optional*):
932
+ The run step. If included, the log will be affiliated with this step.
933
+ kwargs:
934
+ Additional key word arguments passed along to `dvclive.Live.log_metric()`.
935
+ """
936
+ from dvclive.plots import Metric
937
+
938
+ if step is not None:
939
+ self.live.step = step
940
+ for k, v in values.items():
941
+ if Metric.could_log(v):
942
+ self.live.log_metric(k, v, **kwargs)
943
+ else:
944
+ logger.warning_once(
945
+ "Accelerator attempted to log a value of "
946
+ f'"{v}" of type {type(v)} for key "{k}" as a scalar. '
947
+ "This invocation of DVCLive's Live.log_metric() "
948
+ "is incorrect so we dropped this attribute."
949
+ )
950
+ self.live.next_step()
951
+
952
+ @on_main_process
953
+ def finish(self):
954
+ """
955
+ Closes `dvclive.Live()`.
956
+ """
957
+ self.live.end()
958
+
959
+
960
+ LOGGER_TYPE_TO_CLASS = {
961
+ "aim": AimTracker,
962
+ "comet_ml": CometMLTracker,
963
+ "mlflow": MLflowTracker,
964
+ "tensorboard": TensorBoardTracker,
965
+ "wandb": WandBTracker,
966
+ "clearml": ClearMLTracker,
967
+ "dvclive": DVCLiveTracker,
968
+ }
969
+
970
+
971
+ def filter_trackers(
972
+ log_with: List[Union[str, LoggerType, GeneralTracker]],
973
+ logging_dir: Union[str, os.PathLike] = None,
974
+ ):
975
+ """
976
+ Takes in a list of potential tracker types and checks that:
977
+ - The tracker wanted is available in that environment
978
+ - Filters out repeats of tracker types
979
+ - If `all` is in `log_with`, will return all trackers in the environment
980
+ - If a tracker requires a `logging_dir`, ensures that `logging_dir` is not `None`
981
+
982
+ Args:
983
+ log_with (list of `str`, [`~utils.LoggerType`] or [`~tracking.GeneralTracker`], *optional*):
984
+ A list of loggers to be setup for experiment tracking. Should be one or several of:
985
+
986
+ - `"all"`
987
+ - `"tensorboard"`
988
+ - `"wandb"`
989
+ - `"comet_ml"`
990
+ - `"mlflow"`
991
+ - `"dvclive"`
992
+ If `"all"` is selected, will pick up all available trackers in the environment and initialize them. Can
993
+ also accept implementations of `GeneralTracker` for custom trackers, and can be combined with `"all"`.
994
+ logging_dir (`str`, `os.PathLike`, *optional*):
995
+ A path to a directory for storing logs of locally-compatible loggers.
996
+ """
997
+ loggers = []
998
+ if log_with is not None:
999
+ if not isinstance(log_with, (list, tuple)):
1000
+ log_with = [log_with]
1001
+ if "all" in log_with or LoggerType.ALL in log_with:
1002
+ loggers = [o for o in log_with if issubclass(type(o), GeneralTracker)] + get_available_trackers()
1003
+ else:
1004
+ for log_type in log_with:
1005
+ if log_type not in LoggerType and not issubclass(type(log_type), GeneralTracker):
1006
+ raise ValueError(f"Unsupported logging capability: {log_type}. Choose between {LoggerType.list()}")
1007
+ if issubclass(type(log_type), GeneralTracker):
1008
+ loggers.append(log_type)
1009
+ else:
1010
+ log_type = LoggerType(log_type)
1011
+ if log_type not in loggers:
1012
+ if log_type in get_available_trackers():
1013
+ tracker_init = LOGGER_TYPE_TO_CLASS[str(log_type)]
1014
+ if tracker_init.requires_logging_directory:
1015
+ if logging_dir is None:
1016
+ raise ValueError(
1017
+ f"Logging with `{log_type}` requires a `logging_dir` to be passed in."
1018
+ )
1019
+ loggers.append(log_type)
1020
+ else:
1021
+ logger.debug(f"Tried adding logger {log_type}, but package is unavailable in the system.")
1022
+
1023
+ return loggers
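As the `GeneralTracker` docstring above notes, a custom tracker only needs `name`, `requires_logging_directory`, and a `tracker` property, and `filter_trackers` passes such instances straight through to the `Accelerator`. A minimal sketch of a custom tracker follows (my own illustration, not part of this diff; the class name and print format are invented for the example):

```python
# Illustrative sketch only -- a bare-bones custom tracker, not part of the vendored source.
from typing import Optional

from accelerate import Accelerator
from accelerate.tracking import GeneralTracker, on_main_process


class PrintTracker(GeneralTracker):
    name = "print"                      # required attribute
    requires_logging_directory = False  # required attribute

    @on_main_process
    def __init__(self, run_name: str = "demo"):
        super().__init__()
        self.run_name = run_name

    @property
    def tracker(self):                  # required property
        return self

    @on_main_process
    def log(self, values: dict, step: Optional[int] = None, **kwargs):
        print(f"[{self.run_name}] step={step} {values}")


accelerator = Accelerator(log_with=PrintTracker())
accelerator.init_trackers("my_project")
accelerator.log({"loss": 0.5}, step=0)
accelerator.end_training()
```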
.venv/Lib/site-packages/decorator.py ADDED
@@ -0,0 +1,451 @@
1
+ # ######################### LICENSE ############################ #
2
+
3
+ # Copyright (c) 2005-2021, Michele Simionato
4
+ # All rights reserved.
5
+
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are
8
+ # met:
9
+
10
+ # Redistributions of source code must retain the above copyright
11
+ # notice, this list of conditions and the following disclaimer.
12
+ # Redistributions in bytecode form must reproduce the above copyright
13
+ # notice, this list of conditions and the following disclaimer in
14
+ # the documentation and/or other materials provided with the
15
+ # distribution.
16
+
17
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
18
+ # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
19
+ # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
20
+ # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
21
+ # HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
22
+ # INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
23
+ # BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
24
+ # OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
25
+ # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
26
+ # TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
27
+ # USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
28
+ # DAMAGE.
29
+
30
+ """
31
+ Decorator module, see
32
+ https://github.com/micheles/decorator/blob/master/docs/documentation.md
33
+ for the documentation.
34
+ """
35
+ import re
36
+ import sys
37
+ import inspect
38
+ import operator
39
+ import itertools
40
+ from contextlib import _GeneratorContextManager
41
+ from inspect import getfullargspec, iscoroutinefunction, isgeneratorfunction
42
+
43
+ __version__ = '5.1.1'
44
+
45
+ DEF = re.compile(r'\s*def\s*([_\w][_\w\d]*)\s*\(')
46
+ POS = inspect.Parameter.POSITIONAL_OR_KEYWORD
47
+ EMPTY = inspect.Parameter.empty
48
+
49
+
50
+ # this is not used anymore in the core, but kept for backward compatibility
51
+ class FunctionMaker(object):
52
+ """
53
+ An object with the ability to create functions with a given signature.
54
+ It has attributes name, doc, module, signature, defaults, dict and
55
+ methods update and make.
56
+ """
57
+
58
+ # Atomic get-and-increment provided by the GIL
59
+ _compile_count = itertools.count()
60
+
61
+ # make pylint happy
62
+ args = varargs = varkw = defaults = kwonlyargs = kwonlydefaults = ()
63
+
64
+ def __init__(self, func=None, name=None, signature=None,
65
+ defaults=None, doc=None, module=None, funcdict=None):
66
+ self.shortsignature = signature
67
+ if func:
68
+ # func can be a class or a callable, but not an instance method
69
+ self.name = func.__name__
70
+ if self.name == '<lambda>': # small hack for lambda functions
71
+ self.name = '_lambda_'
72
+ self.doc = func.__doc__
73
+ self.module = func.__module__
74
+ if inspect.isroutine(func):
75
+ argspec = getfullargspec(func)
76
+ self.annotations = getattr(func, '__annotations__', {})
77
+ for a in ('args', 'varargs', 'varkw', 'defaults', 'kwonlyargs',
78
+ 'kwonlydefaults'):
79
+ setattr(self, a, getattr(argspec, a))
80
+ for i, arg in enumerate(self.args):
81
+ setattr(self, 'arg%d' % i, arg)
82
+ allargs = list(self.args)
83
+ allshortargs = list(self.args)
84
+ if self.varargs:
85
+ allargs.append('*' + self.varargs)
86
+ allshortargs.append('*' + self.varargs)
87
+ elif self.kwonlyargs:
88
+ allargs.append('*') # single star syntax
89
+ for a in self.kwonlyargs:
90
+ allargs.append('%s=None' % a)
91
+ allshortargs.append('%s=%s' % (a, a))
92
+ if self.varkw:
93
+ allargs.append('**' + self.varkw)
94
+ allshortargs.append('**' + self.varkw)
95
+ self.signature = ', '.join(allargs)
96
+ self.shortsignature = ', '.join(allshortargs)
97
+ self.dict = func.__dict__.copy()
98
+ # func=None happens when decorating a caller
99
+ if name:
100
+ self.name = name
101
+ if signature is not None:
102
+ self.signature = signature
103
+ if defaults:
104
+ self.defaults = defaults
105
+ if doc:
106
+ self.doc = doc
107
+ if module:
108
+ self.module = module
109
+ if funcdict:
110
+ self.dict = funcdict
111
+ # check existence of required attributes
112
+ assert hasattr(self, 'name')
113
+ if not hasattr(self, 'signature'):
114
+ raise TypeError('You are decorating a non function: %s' % func)
115
+
116
+ def update(self, func, **kw):
117
+ """
118
+ Update the signature of func with the data in self
119
+ """
120
+ func.__name__ = self.name
121
+ func.__doc__ = getattr(self, 'doc', None)
122
+ func.__dict__ = getattr(self, 'dict', {})
123
+ func.__defaults__ = self.defaults
124
+ func.__kwdefaults__ = self.kwonlydefaults or None
125
+ func.__annotations__ = getattr(self, 'annotations', None)
126
+ try:
127
+ frame = sys._getframe(3)
128
+ except AttributeError: # for IronPython and similar implementations
129
+ callermodule = '?'
130
+ else:
131
+ callermodule = frame.f_globals.get('__name__', '?')
132
+ func.__module__ = getattr(self, 'module', callermodule)
133
+ func.__dict__.update(kw)
134
+
135
+ def make(self, src_templ, evaldict=None, addsource=False, **attrs):
136
+ """
137
+ Make a new function from a given template and update the signature
138
+ """
139
+ src = src_templ % vars(self) # expand name and signature
140
+ evaldict = evaldict or {}
141
+ mo = DEF.search(src)
142
+ if mo is None:
143
+ raise SyntaxError('not a valid function template\n%s' % src)
144
+ name = mo.group(1) # extract the function name
145
+ names = set([name] + [arg.strip(' *') for arg in
146
+ self.shortsignature.split(',')])
147
+ for n in names:
148
+ if n in ('_func_', '_call_'):
149
+ raise NameError('%s is overridden in\n%s' % (n, src))
150
+
151
+ if not src.endswith('\n'): # add a newline for old Pythons
152
+ src += '\n'
153
+
154
+ # Ensure each generated function has a unique filename for profilers
155
+ # (such as cProfile) that depend on the tuple of (<filename>,
156
+ # <definition line>, <function name>) being unique.
157
+ filename = '<decorator-gen-%d>' % next(self._compile_count)
158
+ try:
159
+ code = compile(src, filename, 'single')
160
+ exec(code, evaldict)
161
+ except Exception:
162
+ print('Error in generated code:', file=sys.stderr)
163
+ print(src, file=sys.stderr)
164
+ raise
165
+ func = evaldict[name]
166
+ if addsource:
167
+ attrs['__source__'] = src
168
+ self.update(func, **attrs)
169
+ return func
170
+
171
+ @classmethod
172
+ def create(cls, obj, body, evaldict, defaults=None,
173
+ doc=None, module=None, addsource=True, **attrs):
174
+ """
175
+ Create a function from the strings name, signature and body.
176
+ evaldict is the evaluation dictionary. If addsource is true an
177
+ attribute __source__ is added to the result. The attributes attrs
178
+ are added, if any.
179
+ """
180
+ if isinstance(obj, str): # "name(signature)"
181
+ name, rest = obj.strip().split('(', 1)
182
+ signature = rest[:-1] # strip a right parens
183
+ func = None
184
+ else: # a function
185
+ name = None
186
+ signature = None
187
+ func = obj
188
+ self = cls(func, name, signature, defaults, doc, module)
189
+ ibody = '\n'.join(' ' + line for line in body.splitlines())
190
+ caller = evaldict.get('_call_') # when called from `decorate`
191
+ if caller and iscoroutinefunction(caller):
192
+ body = ('async def %(name)s(%(signature)s):\n' + ibody).replace(
193
+ 'return', 'return await')
194
+ else:
195
+ body = 'def %(name)s(%(signature)s):\n' + ibody
196
+ return self.make(body, evaldict, addsource, **attrs)
197
+
198
+
199
+ def fix(args, kwargs, sig):
200
+ """
201
+ Fix args and kwargs to be consistent with the signature
202
+ """
203
+ ba = sig.bind(*args, **kwargs)
204
+ ba.apply_defaults() # needed for test_dan_schult
205
+ return ba.args, ba.kwargs
206
+
207
+
208
+ def decorate(func, caller, extras=(), kwsyntax=False):
209
+ """
210
+ Decorates a function/generator/coroutine using a caller.
211
+ If kwsyntax is True, calling the decorated functions with keyword
212
+ syntax will pass the named arguments inside the ``kw`` dictionary,
213
+ even if such arguments are positional, similarly to what functools.wraps
214
+ does. By default kwsyntax is False and the arguments are untouched.
215
+ """
216
+ sig = inspect.signature(func)
217
+ if iscoroutinefunction(caller):
218
+ async def fun(*args, **kw):
219
+ if not kwsyntax:
220
+ args, kw = fix(args, kw, sig)
221
+ return await caller(func, *(extras + args), **kw)
222
+ elif isgeneratorfunction(caller):
223
+ def fun(*args, **kw):
224
+ if not kwsyntax:
225
+ args, kw = fix(args, kw, sig)
226
+ for res in caller(func, *(extras + args), **kw):
227
+ yield res
228
+ else:
229
+ def fun(*args, **kw):
230
+ if not kwsyntax:
231
+ args, kw = fix(args, kw, sig)
232
+ return caller(func, *(extras + args), **kw)
233
+ fun.__name__ = func.__name__
234
+ fun.__doc__ = func.__doc__
235
+ fun.__wrapped__ = func
236
+ fun.__signature__ = sig
237
+ fun.__qualname__ = func.__qualname__
238
+ # builtin functions like defaultdict.__setitem__ lack many attributes
239
+ try:
240
+ fun.__defaults__ = func.__defaults__
241
+ except AttributeError:
242
+ pass
243
+ try:
244
+ fun.__kwdefaults__ = func.__kwdefaults__
245
+ except AttributeError:
246
+ pass
247
+ try:
248
+ fun.__annotations__ = func.__annotations__
249
+ except AttributeError:
250
+ pass
251
+ try:
252
+ fun.__module__ = func.__module__
253
+ except AttributeError:
254
+ pass
255
+ try:
256
+ fun.__dict__.update(func.__dict__)
257
+ except AttributeError:
258
+ pass
259
+ return fun
260
+
261
+
262
+ def decoratorx(caller):
263
+ """
264
+ A version of "decorator" implemented via "exec" and not via the
265
+ Signature object. Use this if you want to preserve the `.__code__`
266
+ object properties (https://github.com/micheles/decorator/issues/129).
267
+ """
268
+ def dec(func):
269
+ return FunctionMaker.create(
270
+ func,
271
+ "return _call_(_func_, %(shortsignature)s)",
272
+ dict(_call_=caller, _func_=func),
273
+ __wrapped__=func, __qualname__=func.__qualname__)
274
+ return dec
275
+
276
+
277
+ def decorator(caller, _func=None, kwsyntax=False):
278
+ """
279
+ decorator(caller) converts a caller function into a decorator
280
+ """
281
+ if _func is not None: # return a decorated function
282
+ # this is obsolete behavior; you should use decorate instead
283
+ return decorate(_func, caller, (), kwsyntax)
284
+ # else return a decorator function
285
+ sig = inspect.signature(caller)
286
+ dec_params = [p for p in sig.parameters.values() if p.kind is POS]
287
+
288
+ def dec(func=None, *args, **kw):
289
+ na = len(args) + 1
290
+ extras = args + tuple(kw.get(p.name, p.default)
291
+ for p in dec_params[na:]
292
+ if p.default is not EMPTY)
293
+ if func is None:
294
+ return lambda func: decorate(func, caller, extras, kwsyntax)
295
+ else:
296
+ return decorate(func, caller, extras, kwsyntax)
297
+ dec.__signature__ = sig.replace(parameters=dec_params)
298
+ dec.__name__ = caller.__name__
299
+ dec.__doc__ = caller.__doc__
300
+ dec.__wrapped__ = caller
301
+ dec.__qualname__ = caller.__qualname__
302
+ dec.__kwdefaults__ = getattr(caller, '__kwdefaults__', None)
303
+ dec.__dict__.update(caller.__dict__)
304
+ return dec
305
+
306
+
307
+ # ####################### contextmanager ####################### #
308
+
309
+
310
+ class ContextManager(_GeneratorContextManager):
311
+ def __init__(self, g, *a, **k):
312
+ _GeneratorContextManager.__init__(self, g, a, k)
313
+
314
+ def __call__(self, func):
315
+ def caller(f, *a, **k):
316
+ with self.__class__(self.func, *self.args, **self.kwds):
317
+ return f(*a, **k)
318
+ return decorate(func, caller)
319
+
320
+
321
+ _contextmanager = decorator(ContextManager)
322
+
323
+
324
+ def contextmanager(func):
325
+ # Enable Pylint config: contextmanager-decorators=decorator.contextmanager
326
+ return _contextmanager(func)
327
+
328
+
329
+ # ############################ dispatch_on ############################ #
330
+
331
+ def append(a, vancestors):
332
+ """
333
+ Append ``a`` to the list of the virtual ancestors, unless it is already
334
+ included.
335
+ """
336
+ add = True
337
+ for j, va in enumerate(vancestors):
338
+ if issubclass(va, a):
339
+ add = False
340
+ break
341
+ if issubclass(a, va):
342
+ vancestors[j] = a
343
+ add = False
344
+ if add:
345
+ vancestors.append(a)
346
+
347
+
348
+ # inspired from simplegeneric by P.J. Eby and functools.singledispatch
349
+ def dispatch_on(*dispatch_args):
350
+ """
351
+ Factory of decorators turning a function into a generic function
352
+ dispatching on the given arguments.
353
+ """
354
+ assert dispatch_args, 'No dispatch args passed'
355
+ dispatch_str = '(%s,)' % ', '.join(dispatch_args)
356
+
357
+ def check(arguments, wrong=operator.ne, msg=''):
358
+ """Make sure one passes the expected number of arguments"""
359
+ if wrong(len(arguments), len(dispatch_args)):
360
+ raise TypeError('Expected %d arguments, got %d%s' %
361
+ (len(dispatch_args), len(arguments), msg))
362
+
363
+ def gen_func_dec(func):
364
+ """Decorator turning a function into a generic function"""
365
+
366
+ # first check the dispatch arguments
367
+ argset = set(getfullargspec(func).args)
368
+ if not set(dispatch_args) <= argset:
369
+ raise NameError('Unknown dispatch arguments %s' % dispatch_str)
370
+
371
+ typemap = {}
372
+
373
+ def vancestors(*types):
374
+ """
375
+ Get a list of sets of virtual ancestors for the given types
376
+ """
377
+ check(types)
378
+ ras = [[] for _ in range(len(dispatch_args))]
379
+ for types_ in typemap:
380
+ for t, type_, ra in zip(types, types_, ras):
381
+ if issubclass(t, type_) and type_ not in t.mro():
382
+ append(type_, ra)
383
+ return [set(ra) for ra in ras]
384
+
385
+ def ancestors(*types):
386
+ """
387
+ Get a list of virtual MROs, one for each type
388
+ """
389
+ check(types)
390
+ lists = []
391
+ for t, vas in zip(types, vancestors(*types)):
392
+ n_vas = len(vas)
393
+ if n_vas > 1:
394
+ raise RuntimeError(
395
+ 'Ambiguous dispatch for %s: %s' % (t, vas))
396
+ elif n_vas == 1:
397
+ va, = vas
398
+ mro = type('t', (t, va), {}).mro()[1:]
399
+ else:
400
+ mro = t.mro()
401
+ lists.append(mro[:-1]) # discard t and object
402
+ return lists
403
+
404
+ def register(*types):
405
+ """
406
+ Decorator to register an implementation for the given types
407
+ """
408
+ check(types)
409
+
410
+ def dec(f):
411
+ check(getfullargspec(f).args, operator.lt, ' in ' + f.__name__)
412
+ typemap[types] = f
413
+ return f
414
+ return dec
415
+
416
+ def dispatch_info(*types):
417
+ """
418
+ A utility to introspect the dispatch algorithm
419
+ """
420
+ check(types)
421
+ lst = []
422
+ for anc in itertools.product(*ancestors(*types)):
423
+ lst.append(tuple(a.__name__ for a in anc))
424
+ return lst
425
+
426
+ def _dispatch(dispatch_args, *args, **kw):
427
+ types = tuple(type(arg) for arg in dispatch_args)
428
+ try: # fast path
429
+ f = typemap[types]
430
+ except KeyError:
431
+ pass
432
+ else:
433
+ return f(*args, **kw)
434
+ combinations = itertools.product(*ancestors(*types))
435
+ next(combinations) # the first one has been already tried
436
+ for types_ in combinations:
437
+ f = typemap.get(types_)
438
+ if f is not None:
439
+ return f(*args, **kw)
440
+
441
+ # else call the default implementation
442
+ return func(*args, **kw)
443
+
444
+ return FunctionMaker.create(
445
+ func, 'return _f_(%s, %%(shortsignature)s)' % dispatch_str,
446
+ dict(_f_=_dispatch), register=register, default=func,
447
+ typemap=typemap, vancestors=vancestors, ancestors=ancestors,
448
+ dispatch_info=dispatch_info, __wrapped__=func)
449
+
450
+ gen_func_dec.__name__ = 'dispatch_on' + dispatch_str
451
+ return gen_func_dec
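
The decorator.py file added above exposes decorator(), decorate() and FunctionMaker for writing signature-preserving decorators. As a quick orientation (a minimal sketch, not part of the committed file; the trace caller and add function below are hypothetical), a caller taking (func, *args, **kw) becomes a decorator whose wrappers keep the original signature:

    from decorator import decorator
    import inspect

    @decorator
    def trace(func, *args, **kw):
        # the caller runs around the wrapped function; defaults are already applied
        print('calling %s with %s %s' % (func.__name__, args, kw))
        return func(*args, **kw)

    @trace
    def add(x, y=1):
        return x + y

    add(2)                    # prints "calling add with (2, 1) {}" and returns 3
    inspect.signature(add)    # (x, y=1) -- the wrapper keeps the original signature
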
.venv/Lib/site-packages/isympy.py ADDED
@@ -0,0 +1,342 @@
1
+ """
2
+ Python shell for SymPy.
3
+
4
+ This is just a normal Python shell (IPython shell if you have the
5
+ IPython package installed), that executes the following commands for
6
+ the user:
7
+
8
+ >>> from __future__ import division
9
+ >>> from sympy import *
10
+ >>> x, y, z, t = symbols('x y z t')
11
+ >>> k, m, n = symbols('k m n', integer=True)
12
+ >>> f, g, h = symbols('f g h', cls=Function)
13
+ >>> init_printing()
14
+
15
+ So starting 'isympy' is equivalent to starting Python (or IPython) and
16
+ executing the above commands by hand. It is intended for easy and quick
17
+ experimentation with SymPy. isympy is a good way to use SymPy as an
18
+ interactive calculator. If you have IPython and Matplotlib installed, then
19
+ interactive plotting is enabled by default.
20
+
21
+ COMMAND LINE OPTIONS
22
+ --------------------
23
+
24
+ -c CONSOLE, --console=CONSOLE
25
+
26
+ Use the specified shell (Python or IPython) shell as the console
27
+ backend instead of the default one (IPython if present, Python
28
+ otherwise), e.g.:
29
+
30
+ $isympy -c python
31
+
32
+ CONSOLE must be one of 'ipython' or 'python'
33
+
34
+ -p PRETTY, --pretty PRETTY
35
+
36
+ Setup pretty-printing in SymPy. When pretty-printing is enabled,
37
+ expressions can be printed with Unicode or ASCII. The default is
38
+ to use pretty-printing (with Unicode if the terminal supports it).
39
+ When this option is 'no', expressions will not be pretty-printed
40
+ and ASCII will be used:
41
+
42
+ $isympy -p no
43
+
44
+ PRETTY must be one of 'unicode', 'ascii', or 'no'
45
+
46
+ -t TYPES, --types=TYPES
47
+
48
+ Setup the ground types for the polys. By default, gmpy ground types
49
+ are used if gmpy2 or gmpy is installed, otherwise it falls back to python
50
+ ground types, which are a little bit slower. You can manually
51
+ choose python ground types even if gmpy is installed (e.g., for
52
+ testing purposes):
53
+
54
+ $isympy -t python
55
+
56
+ TYPES must be one of 'gmpy', 'gmpy1' or 'python'
57
+
58
+ Note that the ground type gmpy1 is primarily intended for testing; it
59
+ forces the use of gmpy version 1 even if gmpy2 is available.
60
+
61
+ This is the same as setting the environment variable
62
+ SYMPY_GROUND_TYPES to the given ground type (e.g.,
63
+ SYMPY_GROUND_TYPES='gmpy')
64
+
65
+ The ground types can be determined interactively from the variable
66
+ sympy.polys.domains.GROUND_TYPES.
67
+
68
+ -o ORDER, --order ORDER
69
+
70
+ Setup the ordering of terms for printing. The default is lex, which
71
+ orders terms lexicographically (e.g., x**2 + x + 1). You can choose
72
+ other orderings, such as rev-lex, which will use reverse
73
+ lexicographic ordering (e.g., 1 + x + x**2):
74
+
75
+ $isympy -o rev-lex
76
+
77
+ ORDER must be one of 'lex', 'rev-lex', 'grlex', 'rev-grlex',
78
+ 'grevlex', 'rev-grevlex', 'old', or 'none'.
79
+
80
+ Note that for very large expressions, ORDER='none' may speed up
81
+ printing considerably but the terms will have no canonical order.
82
+
83
+ -q, --quiet
84
+
85
+ Print only Python's and SymPy's versions to stdout at startup.
86
+
87
+ -d, --doctest
88
+
89
+ Use the same format that should be used for doctests. This is
90
+ equivalent to -c python -p no.
91
+
92
+ -C, --no-cache
93
+
94
+ Disable the caching mechanism. Disabling the cache may slow certain
95
+ operations down considerably. This is useful for testing the cache,
96
+ or for benchmarking, as the cache can result in deceptive timings.
97
+
98
+ This is equivalent to setting the environment variable
99
+ SYMPY_USE_CACHE to 'no'.
100
+
101
+ -a, --auto-symbols (requires at least IPython 0.11)
102
+
103
+ Automatically create missing symbols. Normally, typing a name of a
104
+ Symbol that has not been instantiated first would raise NameError,
105
+ but with this option enabled, any undefined name will be
106
+ automatically created as a Symbol.
107
+
108
+ Note that this is intended only for interactive, calculator style
109
+ usage. In a script that uses SymPy, Symbols should be instantiated
110
+ at the top, so that it's clear what they are.
111
+
112
+ This will not override any names that are already defined, which
113
+ includes the single character letters represented by the mnemonic
114
+ QCOSINE (see the "Gotchas and Pitfalls" document in the
115
+ documentation). You can delete existing names by executing "del
116
+ name". If a name is defined, typing "'name' in dir()" will return True.
117
+
118
+ The Symbols that are created using this have default assumptions.
119
+ If you want to place assumptions on symbols, you should create them
120
+ using symbols() or var().
121
+
122
+ Finally, this only works in the top level namespace. So, for
123
+ example, if you define a function in isympy with an undefined
124
+ Symbol, it will not work.
125
+
126
+ See also the -i and -I options.
127
+
128
+ -i, --int-to-Integer (requires at least IPython 0.11)
129
+
130
+ Automatically wrap int literals with Integer. This makes it so that
131
+ things like 1/2 will come out as Rational(1, 2), rather than 0.5. This
132
+ works by preprocessing the source and wrapping all int literals with
133
+ Integer. Note that this will not change the behavior of int literals
134
+ assigned to variables, and it also won't change the behavior of functions
135
+ that return int literals.
136
+
137
+ If you want an int, you can wrap the literal in int(), e.g. int(3)/int(2)
138
+ gives 1.5 (with division imported from __future__).
139
+
140
+ -I, --interactive (requires at least IPython 0.11)
141
+
142
+ This is equivalent to --auto-symbols --int-to-Integer. Future options
143
+ designed for ease of interactive use may be added to this.
144
+
145
+ -D, --debug
146
+
147
+ Enable debugging output. This is the same as setting the
148
+ environment variable SYMPY_DEBUG to 'True'. The debug status is set
149
+ in the variable SYMPY_DEBUG within isympy.
150
+
151
+ -- IPython options
152
+
153
+ Additionally you can pass command line options directly to the IPython
154
+ interpreter (the standard Python shell is not supported). However you
155
+ need to add the '--' separator between the two types of options, e.g. the
156
+ startup banner option and the colors option. You need to enter the
157
+ options as required by the version of IPython that you are using, too:
158
+
159
+ in IPython 0.11,
160
+
161
+ $isympy -q -- --colors=NoColor
162
+
163
+ or older versions of IPython,
164
+
165
+ $isympy -q -- -colors NoColor
166
+
167
+ See also isympy --help.
168
+ """
169
+
170
+ import os
171
+ import sys
172
+
173
+ # DO NOT IMPORT SYMPY HERE! Or the setting of the sympy environment variables
174
+ # by the command line will break.
175
+
176
+ def main() -> None:
177
+ from argparse import ArgumentParser, RawDescriptionHelpFormatter
178
+
179
+ VERSION = None
180
+ if '--version' in sys.argv:
181
+ # We cannot import sympy before this is run, because flags like -C and
182
+ # -t set environment variables that must be set before SymPy is
183
+ # imported. The only thing we need to import it for is to get the
184
+ # version, which only matters with the --version flag.
185
+ import sympy
186
+ VERSION = sympy.__version__
187
+
188
+ usage = 'isympy [options] -- [ipython options]'
189
+ parser = ArgumentParser(
190
+ usage=usage,
191
+ description=__doc__,
192
+ formatter_class=RawDescriptionHelpFormatter,
193
+ )
194
+
195
+ parser.add_argument('--version', action='version', version=VERSION)
196
+
197
+ parser.add_argument(
198
+ '-c', '--console',
199
+ dest='console',
200
+ action='store',
201
+ default=None,
202
+ choices=['ipython', 'python'],
203
+ metavar='CONSOLE',
204
+ help='select type of interactive session: ipython | python; defaults '
205
+ 'to ipython if IPython is installed, otherwise python')
206
+
207
+ parser.add_argument(
208
+ '-p', '--pretty',
209
+ dest='pretty',
210
+ action='store',
211
+ default=None,
212
+ metavar='PRETTY',
213
+ choices=['unicode', 'ascii', 'no'],
214
+ help='setup pretty printing: unicode | ascii | no; defaults to '
215
+ 'unicode printing if the terminal supports it, otherwise ascii')
216
+
217
+ parser.add_argument(
218
+ '-t', '--types',
219
+ dest='types',
220
+ action='store',
221
+ default=None,
222
+ metavar='TYPES',
223
+ choices=['gmpy', 'gmpy1', 'python'],
224
+ help='setup ground types: gmpy | gmpy1 | python; defaults to gmpy if gmpy2 '
225
+ 'or gmpy is installed, otherwise python')
226
+
227
+ parser.add_argument(
228
+ '-o', '--order',
229
+ dest='order',
230
+ action='store',
231
+ default=None,
232
+ metavar='ORDER',
233
+ choices=['lex', 'grlex', 'grevlex', 'rev-lex', 'rev-grlex', 'rev-grevlex', 'old', 'none'],
234
+ help='setup ordering of terms: [rev-]lex | [rev-]grlex | [rev-]grevlex | old | none; defaults to lex')
235
+
236
+ parser.add_argument(
237
+ '-q', '--quiet',
238
+ dest='quiet',
239
+ action='store_true',
240
+ default=False,
241
+ help='print only version information at startup')
242
+
243
+ parser.add_argument(
244
+ '-d', '--doctest',
245
+ dest='doctest',
246
+ action='store_true',
247
+ default=False,
248
+ help='use the doctest format for output (you can just copy and paste it)')
249
+
250
+ parser.add_argument(
251
+ '-C', '--no-cache',
252
+ dest='cache',
253
+ action='store_false',
254
+ default=True,
255
+ help='disable caching mechanism')
256
+
257
+ parser.add_argument(
258
+ '-a', '--auto-symbols',
259
+ dest='auto_symbols',
260
+ action='store_true',
261
+ default=False,
262
+ help='automatically construct missing symbols')
263
+
264
+ parser.add_argument(
265
+ '-i', '--int-to-Integer',
266
+ dest='auto_int_to_Integer',
267
+ action='store_true',
268
+ default=False,
269
+ help="automatically wrap int literals with Integer")
270
+
271
+ parser.add_argument(
272
+ '-I', '--interactive',
273
+ dest='interactive',
274
+ action='store_true',
275
+ default=False,
276
+ help="equivalent to -a -i")
277
+
278
+ parser.add_argument(
279
+ '-D', '--debug',
280
+ dest='debug',
281
+ action='store_true',
282
+ default=False,
283
+ help='enable debugging output')
284
+
285
+ (options, ipy_args) = parser.parse_known_args()
286
+ if '--' in ipy_args:
287
+ ipy_args.remove('--')
288
+
289
+ if not options.cache:
290
+ os.environ['SYMPY_USE_CACHE'] = 'no'
291
+
292
+ if options.types:
293
+ os.environ['SYMPY_GROUND_TYPES'] = options.types
294
+
295
+ if options.debug:
296
+ os.environ['SYMPY_DEBUG'] = str(options.debug)
297
+
298
+ if options.doctest:
299
+ options.pretty = 'no'
300
+ options.console = 'python'
301
+
302
+ session = options.console
303
+
304
+ if session is not None:
305
+ ipython = session == 'ipython'
306
+ else:
307
+ try:
308
+ import IPython
309
+ ipython = True
310
+ except ImportError:
311
+ if not options.quiet:
312
+ from sympy.interactive.session import no_ipython
313
+ print(no_ipython)
314
+ ipython = False
315
+
316
+ args = {
317
+ 'pretty_print': True,
318
+ 'use_unicode': None,
319
+ 'use_latex': None,
320
+ 'order': None,
321
+ 'argv': ipy_args,
322
+ }
323
+
324
+ if options.pretty == 'unicode':
325
+ args['use_unicode'] = True
326
+ elif options.pretty == 'ascii':
327
+ args['use_unicode'] = False
328
+ elif options.pretty == 'no':
329
+ args['pretty_print'] = False
330
+
331
+ if options.order is not None:
332
+ args['order'] = options.order
333
+
334
+ args['quiet'] = options.quiet
335
+ args['auto_symbols'] = options.auto_symbols or options.interactive
336
+ args['auto_int_to_Integer'] = options.auto_int_to_Integer or options.interactive
337
+
338
+ from sympy.interactive import init_session
339
+ init_session(ipython, **args)
340
+
341
+ if __name__ == "__main__":
342
+ main()
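
The main() above only translates the command-line options into keyword arguments for sympy.interactive.init_session. For reference, a rough equivalent of running isympy -c python -p ascii -q from a plain Python session (a sketch, assuming SymPy is installed; not part of the committed file):

    from sympy.interactive import init_session

    # approximately what `isympy -c python -p ascii -q` ends up calling
    init_session(ipython=False, pretty_print=True, use_unicode=False,
                 use_latex=None, order=None, argv=[], quiet=True,
                 auto_symbols=False, auto_int_to_Integer=False)
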
.venv/Lib/site-packages/mojimoji.cp39-win_amd64.pyd ADDED
Binary file (93.7 kB). View file
 
.venv/Lib/site-packages/numpy-1.26.3-cp39-cp39-win_amd64.whl ADDED
File without changes
.venv/Lib/site-packages/plac.py ADDED
@@ -0,0 +1,37 @@
1
+ # ######################### LICENSE ###############################
2
+ #
3
+ # Copyright (c) 2010-2021, Michele Simionato
4
+ # All rights reserved.
5
+ #
6
+ # Redistributions of source code must retain the above copyright
7
+ # notice, this list of conditions and the following disclaimer.
8
+ # Redistributions in bytecode form must reproduce the above copyright
9
+ # notice, this list of conditions and the following disclaimer in
10
+ # the documentation and/or other materials provided with the
11
+ # distribution.
12
+
13
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
14
+ # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
15
+ # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
16
+ # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
17
+ # HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
18
+ # INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19
+ # BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
20
+ # OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
21
+ # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
22
+ # TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
23
+ # USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
24
+ # DAMAGE.
25
+ """
26
+ See docs/index.html for the documentation.
27
+ """
28
+ from plac_core import *
29
+ from plac_ext import (import_main, ReadlineInput, Interpreter,
30
+ stdout, runp, Monitor, default_help)
31
+
32
+ __version__ = '1.4.3'
33
+
34
+ try:
35
+ from plac_tk import TkMonitor
36
+ except ImportError:
37
+ pass
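
plac.py itself is only a thin facade over plac_core and plac_ext. A minimal command-line tool built on the vendored package (a sketch; the main function and its arguments are hypothetical, not part of the commit) looks like this:

    import plac

    @plac.annotations(
        name=('name to greet', 'positional'),
        shout=('uppercase the greeting', 'flag', 's'))
    def main(name, shout=False):
        msg = 'hello %s' % name
        return msg.upper() if shout else msg

    if __name__ == '__main__':
        # parses sys.argv[1:] and calls main with the resulting values
        print(plac.call(main))
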
.venv/Lib/site-packages/plac_core.py ADDED
@@ -0,0 +1,439 @@
1
+ # this module should be kept Python 2.3 compatible
2
+ import re
3
+ import sys
4
+ import time
5
+ import inspect
6
+ import textwrap
7
+ import functools
8
+ import argparse
9
+ from datetime import datetime, date
10
+ from gettext import gettext as _
11
+
12
+ version = sys.version_info[:2]
13
+
14
+ if sys.version >= '3':
15
+ from inspect import getfullargspec
16
+ else:
17
+ class getfullargspec(object):
18
+ "A quick and dirty replacement for getfullargspec for Python 2.X"
19
+ def __init__(self, f):
20
+ self.args, self.varargs, self.varkw, self.defaults = \
21
+ inspect.getargspec(f)
22
+ self.annotations = getattr(f, '__annotations__', {})
23
+
24
+
25
+ def to_date(s):
26
+ """Returns year-month-day"""
27
+ return date(*time.strptime(s, "%Y-%m-%d")[0:3])
28
+
29
+
30
+ def to_datetime(s):
31
+ """Returns year-month-day hour-minute-second"""
32
+ return datetime(*time.strptime(s, "%Y-%m-%d %H-%M-%S")[0:6])
33
+
34
+
35
+ def getargspec(callableobj):
36
+ """Given a callable return an object with attributes .args, .varargs,
37
+ .varkw, .defaults. It tries to do the "right thing" with functions,
38
+ methods, classes and generic callables."""
39
+ if inspect.isfunction(callableobj):
40
+ argspec = getfullargspec(callableobj)
41
+ elif inspect.ismethod(callableobj):
42
+ argspec = getfullargspec(callableobj)
43
+ del argspec.args[0] # remove first argument
44
+ elif inspect.isclass(callableobj):
45
+ if callableobj.__init__ is object.__init__: # to avoid an error
46
+ argspec = getfullargspec(lambda self: None)
47
+ else:
48
+ argspec = getfullargspec(callableobj.__init__)
49
+ del argspec.args[0] # remove first argument
50
+ elif hasattr(callableobj, '__call__'):
51
+ argspec = getfullargspec(callableobj.__call__)
52
+ del argspec.args[0] # remove first argument
53
+ else:
54
+ raise TypeError(_('Could not determine the signature of ') +
55
+ str(callableobj))
56
+ return argspec
57
+
58
+
59
+ def annotations(**ann):
60
+ """
61
+ Returns a decorator annotating a function with the given annotations.
62
+ This is a trick to support function annotations in Python 2.X.
63
+ """
64
+ def annotate(f):
65
+ fas = getfullargspec(f)
66
+ args = fas.args
67
+ if fas.varargs:
68
+ args.append(fas.varargs)
69
+ if fas.varkw:
70
+ args.append(fas.varkw)
71
+ for argname in ann:
72
+ if argname not in args:
73
+ raise NameError(
74
+ _('Annotating non-existing argument: %s') % argname)
75
+ f.__annotations__ = ann
76
+ return f
77
+ return annotate
78
+
79
+
80
+ def _annotate(arg, ann, f):
81
+ try:
82
+ f.__annotations__[arg] = ann
83
+ except AttributeError: # Python 2.7
84
+ f.__annotations__ = {arg: ann}
85
+ return f
86
+
87
+
88
+ def pos(arg, help=None, type=None, choices=None, metavar=None):
89
+ """
90
+ Decorator for annotating positional arguments
91
+ """
92
+ return functools.partial(
93
+ _annotate, arg, (help, 'positional', None, type, choices, metavar))
94
+
95
+
96
+ def opt(arg, help=None, type=None, abbrev=None, choices=None, metavar=None):
97
+ """
98
+ Decorator for annotating optional arguments
99
+ """
100
+ abbrev = abbrev or arg[0]
101
+ return functools.partial(
102
+ _annotate, arg, (help, 'option', abbrev, type, choices, metavar))
103
+
104
+
105
+ def flg(arg, help=None, abbrev=None):
106
+ """
107
+ Decorator for annotating flags
108
+ """
109
+ return functools.partial(
110
+ _annotate, arg, (help, 'flag', abbrev or arg[0], None, None, None))
111
+
112
+
113
+ def is_annotation(obj):
114
+ """
115
+ An object is an annotation object if it has the attributes
116
+ help, kind, abbrev, type, choices, metavar.
117
+ """
118
+ return (hasattr(obj, 'help') and hasattr(obj, 'kind')
119
+ and hasattr(obj, 'abbrev') and hasattr(obj, 'type')
120
+ and hasattr(obj, 'choices') and hasattr(obj, 'metavar'))
121
+
122
+
123
+ class Annotation(object):
124
+ def __init__(self, help=None, kind="positional", abbrev=None, type=None,
125
+ choices=None, metavar=None):
126
+ assert kind in ('positional', 'option', 'flag'), kind
127
+ if kind == "positional":
128
+ assert abbrev is None, abbrev
129
+ self.help = help
130
+ self.kind = kind
131
+ self.abbrev = abbrev
132
+ self.type = type
133
+ self.choices = choices
134
+ self.metavar = metavar
135
+
136
+ def from_(cls, obj):
137
+ "Helper to convert an object into an annotation, if needed"
138
+ if is_annotation(obj):
139
+ return obj # do nothing
140
+ elif inspect.isclass(obj):
141
+ obj = str(obj)
142
+ elif iterable(obj):
143
+ return cls(*obj)
144
+ return cls(obj)
145
+ from_ = classmethod(from_)
146
+
147
+
148
+ NONE = object() # sentinel used to signal the absence of a default
149
+
150
+ PARSER_CFG = getfullargspec(argparse.ArgumentParser.__init__).args[1:]
151
+ # the default arguments accepted by an ArgumentParser object
152
+
153
+
154
+ def pconf(obj):
155
+ """
156
+ Extracts the configuration of the underlying ArgumentParser from obj
157
+ """
158
+ cfg = dict(description=(textwrap.dedent(obj.__doc__.rstrip())
159
+ if obj.__doc__ else None),
160
+ formatter_class=argparse.RawDescriptionHelpFormatter)
161
+ for name in dir(obj):
162
+ if name in PARSER_CFG: # argument of ArgumentParser
163
+ cfg[name] = getattr(obj, name)
164
+ return cfg
165
+
166
+
167
+ _parser_registry = {}
168
+
169
+
170
+ def parser_from(obj, **confparams):
171
+ """
172
+ obj can be a callable or an object with a .commands attribute.
173
+ Returns an ArgumentParser.
174
+ """
175
+ try: # the underlying parser has been generated already
176
+ return _parser_registry[obj]
177
+ except KeyError: # generate a new parser
178
+ pass
179
+ conf = pconf(obj).copy()
180
+ conf.update(confparams)
181
+ _parser_registry[obj] = parser = ArgumentParser(**conf)
182
+ parser.obj = obj
183
+ parser.case_sensitive = confparams.get(
184
+ 'case_sensitive', getattr(obj, 'case_sensitive', True))
185
+ if hasattr(obj, 'commands') and not inspect.isclass(obj):
186
+ # a command container instance
187
+ parser.addsubcommands(obj.commands, obj, 'subcommands')
188
+ else:
189
+ parser.populate_from(obj)
190
+ return parser
191
+
192
+
193
+ def _extract_kwargs(args):
194
+ """
195
+ Returns two lists: regular args and name=value args
196
+ """
197
+ arglist = []
198
+ kwargs = {}
199
+ for arg in args:
200
+ match = re.match(r'([a-zA-Z_]\w*)=', arg)
201
+ if match:
202
+ name = match.group(1)
203
+ kwargs[name] = arg[len(name)+1:]
204
+ else:
205
+ arglist.append(arg)
206
+ return arglist, kwargs
207
+
208
+
209
+ def _match_cmd(abbrev, commands, case_sensitive=True):
210
+ """
211
+ Extract the command name from an abbreviation or raise a NameError
212
+ """
213
+ if not case_sensitive:
214
+ abbrev = abbrev.upper()
215
+ commands = [c.upper() for c in commands]
216
+ perfect_matches = [name for name in commands if name == abbrev]
217
+ if len(perfect_matches) == 1:
218
+ return perfect_matches[0]
219
+ matches = [name for name in commands if name.startswith(abbrev)]
220
+ n = len(matches)
221
+ if n == 1:
222
+ return matches[0]
223
+ elif n > 1:
224
+ raise NameError(
225
+ _('Ambiguous command %r: matching %s' % (abbrev, matches)))
226
+
227
+
228
+ class ArgumentParser(argparse.ArgumentParser):
229
+ """
230
+ An ArgumentParser with .func and .argspec attributes, and possibly
231
+ .commands and .subparsers.
232
+ """
233
+ case_sensitive = True
234
+
235
+ if version < (3, 10):
236
+ def __init__(self, *args, **kwargs):
237
+ super(ArgumentParser, self).__init__(*args, **kwargs)
238
+ if self._action_groups[1].title == _('optional arguments'):
239
+ self._action_groups[1].title = _('options')
240
+
241
+ def alias(self, arg):
242
+ "Can be overridden to preprocess command-line arguments"
243
+ return arg
244
+
245
+ def consume(self, args):
246
+ """
247
+ Call the underlying function with the args. Works also for
248
+ command containers, by dispatching to the right subparser.
249
+ """
250
+ arglist = [self.alias(a) for a in args]
251
+ cmd = None
252
+ if hasattr(self, 'subparsers'):
253
+ subp, cmd = self._extract_subparser_cmd(arglist)
254
+ if subp is None and cmd is not None:
255
+ return cmd, self.missing(cmd)
256
+ elif subp is not None: # use the subparser
257
+ self = subp
258
+ if hasattr(self, 'argspec') and self.argspec.varargs:
259
+ # ignore unrecognized arguments
260
+ ns, extraopts = self.parse_known_args(arglist)
261
+ else:
262
+ ns, extraopts = self.parse_args(arglist), [] # may raise an exit
263
+ if not hasattr(self, 'argspec'):
264
+ raise SystemExit
265
+ if hasattr(self, 'argspec') and self.argspec.varkw:
266
+ v = self.argspec.varargs
267
+ varkw = self.argspec.varkw
268
+ if v in ns.__dict__:
269
+ lst = ns.__dict__.pop(v)
270
+ lst, kwargs = _extract_kwargs(lst)
271
+ ns.__dict__[v] = lst
272
+ elif varkw in ns.__dict__:
273
+ lst = ns.__dict__.pop(varkw)
274
+ lst, kwargs = _extract_kwargs(lst)
275
+ ns.__dict__[varkw] = lst
276
+ if lst and not v:
277
+ self.error(_('Unrecognized arguments: %s') % arglist)
278
+ else:
279
+ kwargs = {}
280
+ collision = set(self.argspec.args) & set(kwargs)
281
+ if collision:
282
+ self.error(
283
+ _('colliding keyword arguments: %s') % ' '.join(collision))
284
+ # Correct options with trailing underscores
285
+ args = [getattr(ns, a.rstrip('_')) for a in self.argspec.args]
286
+ varargs = getattr(ns, self.argspec.varargs or '', [])
287
+ return cmd, self.func(*(args + varargs + extraopts), **kwargs)
288
+
289
+ def _extract_subparser_cmd(self, arglist):
290
+ """
291
+ Extract the right subparser from the first recognized argument
292
+ """
293
+ optprefix = self.prefix_chars[0]
294
+ name_parser_map = self.subparsers._name_parser_map
295
+ for i, arg in enumerate(arglist):
296
+ if not arg.startswith(optprefix):
297
+ cmd = _match_cmd(arg, name_parser_map, self.case_sensitive)
298
+ del arglist[i]
299
+ return name_parser_map.get(cmd), cmd or arg
300
+ return None, None
301
+
302
+ def addsubcommands(self, commands, obj, title=None, cmdprefix=''):
303
+ """
304
+ Extract a list of subcommands from obj and add them to the parser
305
+ """
306
+ if hasattr(obj, cmdprefix) and obj.cmdprefix in self.prefix_chars:
307
+ raise ValueError(_('The prefix %r is already taken!' % cmdprefix))
308
+ if not hasattr(self, 'subparsers'):
309
+ self.subparsers = self.add_subparsers(title=title)
310
+ elif title:
311
+ self.add_argument_group(title=title) # populate ._action_groups
312
+ prefixlen = len(getattr(obj, 'cmdprefix', ''))
313
+ add_help = getattr(obj, 'add_help', True)
314
+ for cmd in commands:
315
+ func = getattr(obj, cmd[prefixlen:]) # strip the prefix
316
+ doc = (textwrap.dedent(func.__doc__.rstrip())
317
+ if func.__doc__ else None)
318
+ self.subparsers.add_parser(
319
+ cmd, add_help=add_help, help=doc, **pconf(func)
320
+ ).populate_from(func)
321
+
322
+ def _set_func_argspec(self, obj):
323
+ """
324
+ Extracts the signature from a callable object and adds an .argspec
325
+ attribute to the parser. Also adds a .func reference to the object.
326
+ """
327
+ self.func = obj
328
+ self.argspec = getargspec(obj)
329
+ _parser_registry[obj] = self
330
+
331
+ def populate_from(self, func):
332
+ """
333
+ Extract the arguments from the attributes of the passed function
334
+ and return a populated ArgumentParser instance.
335
+ """
336
+ self._set_func_argspec(func)
337
+ f = self.argspec
338
+ defaults = f.defaults or ()
339
+ n_args = len(f.args)
340
+ n_defaults = len(defaults)
341
+ alldefaults = (NONE,) * (n_args - n_defaults) + defaults
342
+ prefix = self.prefix = getattr(func, 'prefix_chars', '-')[0]
343
+ for name, default in zip(f.args, alldefaults):
344
+ ann = f.annotations.get(name, ())
345
+ a = Annotation.from_(ann)
346
+ metavar = a.metavar
347
+ if default is NONE:
348
+ dflt = None
349
+ else:
350
+ dflt = default
351
+ if a.help is None:
352
+ a.help = '[%s]' % str(dflt) # dflt can be a tuple
353
+ if a.type is None:
354
+ # try to infer the type from the default argument
355
+ if isinstance(default, datetime):
356
+ a.type = to_datetime
357
+ elif isinstance(default, date):
358
+ a.type = to_date
359
+ elif default is not None:
360
+ a.type = type(default)
361
+ if not metavar and default == '':
362
+ metavar = "''"
363
+ if a.kind in ('option', 'flag'):
364
+
365
+ if name.endswith("_"):
366
+ # allows reserved words to be specified with underscores
367
+ suffix = name.rstrip('_')
368
+ else:
369
+ # convert underscores to dashes.
370
+ suffix = name.replace('_', '-')
371
+
372
+ if a.abbrev:
373
+ shortlong = (prefix + a.abbrev,
374
+ prefix*2 + suffix)
375
+ else:
376
+ shortlong = (prefix + suffix,)
377
+ elif default is NONE: # required argument
378
+ self.add_argument(name, help=a.help, type=a.type,
379
+ choices=a.choices, metavar=metavar)
380
+ else: # default argument
381
+ self.add_argument(
382
+ name, nargs='?', help=a.help, default=dflt,
383
+ type=a.type, choices=a.choices, metavar=metavar)
384
+ if a.kind == 'option':
385
+ if default is not NONE:
386
+ metavar = metavar or str(default)
387
+ self.add_argument(
388
+ help=a.help, default=dflt, type=a.type,
389
+ choices=a.choices, metavar=metavar, *shortlong)
390
+ elif a.kind == 'flag':
391
+ if default is not NONE and default is not False:
392
+ raise TypeError(_('Flag %r wants default False, got %r') %
393
+ (name, default))
394
+ self.add_argument(action='store_true', help=a.help, *shortlong)
395
+ if f.varargs:
396
+ a = Annotation.from_(f.annotations.get(f.varargs, ()))
397
+ self.add_argument(f.varargs, nargs='*', help=a.help, default=[],
398
+ type=a.type, metavar=a.metavar)
399
+ if f.varkw:
400
+ a = Annotation.from_(f.annotations.get(f.varkw, ()))
401
+ self.add_argument(f.varkw, nargs='*', help=a.help, default={},
402
+ type=a.type, metavar=a.metavar)
403
+
404
+ def missing(self, name):
405
+ "May raise a SystemExit"
406
+ miss = getattr(self.obj, '__missing__', lambda name:
407
+ self.error('No command %r' % name))
408
+ return miss(name)
409
+
410
+ def print_actions(self):
411
+ "Useful for debugging"
412
+ print(self)
413
+ for a in self._actions:
414
+ print(a)
415
+
416
+
417
+ def iterable(obj):
418
+ "Any object with an __iter__ method which is not a string or class"
419
+ return hasattr(obj, '__iter__') and not inspect.isclass(obj) and not isinstance(obj, (str, bytes))
420
+
421
+
422
+ def call(obj, arglist=None, eager=True, version=None):
423
+ """
424
+ If obj is a function or a bound method, parse the given arglist
425
+ by using the parser inferred from the annotations of obj
426
+ and call obj with the parsed arguments.
427
+ If obj is an object with attribute .commands, dispatch to the
428
+ associated subparser.
429
+ """
430
+ if arglist is None:
431
+ arglist = sys.argv[1:]
432
+ parser = parser_from(obj)
433
+ if version:
434
+ parser.add_argument(
435
+ '--version', '-v', action='version', version=version)
436
+ cmd, result = parser.consume(arglist)
437
+ if iterable(result) and eager: # listify the result
438
+ return list(result)
439
+ return result
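
Besides plain functions, parser_from() and call() above also accept an object exposing a .commands attribute, in which case addsubcommands() builds one subparser per command and consume() dispatches to it. A small sketch of that container style (the Calc class is hypothetical, not part of the commit):

    import plac_core

    class Calc(object):
        commands = ['add', 'mul']

        def add(self, x: ('first operand', 'positional', None, float),
                y: ('second operand', 'positional', None, float)):
            return x + y

        def mul(self, x: ('first operand', 'positional', None, float),
                y: ('second operand', 'positional', None, float)):
            return x * y

    # dispatches to the 'add' subparser and returns 5.0
    plac_core.call(Calc(), ['add', '2', '3'])
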
.venv/Lib/site-packages/plac_ext.py ADDED
@@ -0,0 +1,1205 @@
1
+ # this module requires Python 2.6+
2
+ from __future__ import with_statement
3
+ from contextlib import contextmanager
4
+ from operator import attrgetter
5
+ from gettext import gettext as _
6
+ import inspect
7
+ import os
8
+ import sys
9
+ import cmd
10
+ import shlex
11
+ import subprocess
12
+ import argparse
13
+ import itertools
14
+ import traceback
15
+ import multiprocessing
16
+ import signal
17
+ import threading
18
+ import plac_core
19
+
20
+ version = sys.version_info[:2]
21
+
22
+ if version < (3, 5):
23
+ from imp import load_source
24
+ else:
25
+ import importlib.util
26
+
27
+ def load_source(dotname, path):
28
+ spec = importlib.util.spec_from_file_location(dotname, path)
29
+ mod = importlib.util.module_from_spec(spec)
30
+ spec.loader.exec_module(mod)
31
+ return mod
32
+
33
+
34
+ if sys.version < '3':
35
+ def exec_(_code_, _globs_=None, _locs_=None):
36
+ if _globs_ is None:
37
+ frame = sys._getframe(1)
38
+ _globs_ = frame.f_globals
39
+ if _locs_ is None:
40
+ _locs_ = frame.f_locals
41
+ del frame
42
+ elif _locs_ is None:
43
+ _locs_ = _globs_
44
+ exec("""exec _code_ in _globs_, _locs_""")
45
+
46
+ exec('''
47
+ def raise_(tp, value=None, tb=None):
48
+ raise tp, value, tb
49
+ ''')
50
+ else:
51
+ exec_ = eval('exec')
52
+
53
+ def raise_(tp, value=None, tb=None):
54
+ """
55
+ A function that matches the Python 2.x ``raise`` statement. This
56
+ allows re-raising exceptions with the cls value and traceback on
57
+ Python 2 and 3.
58
+ """
59
+ if value is not None and isinstance(tp, Exception):
60
+ raise TypeError("instance exception may not have a separate value")
61
+ if value is not None:
62
+ exc = tp(value)
63
+ else:
64
+ exc = tp
65
+ if exc.__traceback__ is not tb:
66
+ raise exc.with_traceback(tb)
67
+ raise exc
68
+
69
+ try:
70
+ raw_input
71
+ except NameError: # Python 3
72
+ raw_input = input
73
+
74
+
75
+ def decode(val):
76
+ """
77
+ Decode an object assuming the encoding is UTF-8.
78
+ """
79
+ try:
80
+ # assume it is an encoded bytes object
81
+ return val.decode('utf-8')
82
+ except AttributeError:
83
+ # it was an already decoded unicode object
84
+ return str(val)
85
+
86
+ # ############################ generic utils ############################### #
87
+
88
+
89
+ @contextmanager
90
+ def stdout(fileobj):
91
+ "usage: with stdout(file('out.txt', 'a')): do_something()"
92
+ orig_stdout = sys.stdout
93
+ sys.stdout = fileobj
94
+ try:
95
+ yield
96
+ finally:
97
+ sys.stdout = orig_stdout
98
+
99
+
100
+ def write(x):
101
+ "Write str(x) on stdout and flush, no newline added"
102
+ sys.stdout.write(str(x))
103
+ sys.stdout.flush()
104
+
105
+
106
+ def gen_val(value):
107
+ "Return a generator object with a single element"
108
+ yield value
109
+
110
+
111
+ def gen_exc(etype, exc, tb):
112
+ "Return a generator object raising an exception"
113
+ raise_(etype, exc, tb)
114
+ yield
115
+
116
+
117
+ def less(text):
118
+ "Send a text to less via a pipe"
119
+ # -c clear the screen before starting less
120
+ po = subprocess.Popen(['less', '-c'], stdin=subprocess.PIPE)
121
+ try:
122
+ po.stdin.write(text)
123
+ except IOError:
124
+ pass
125
+ po.stdin.close()
126
+ po.wait()
127
+
128
+
129
+ use_less = (sys.platform != 'win32') # unices
130
+
131
+
132
+ class TerminatedProcess(Exception):
133
+ pass
134
+
135
+
136
+ def terminatedProcess(signum, frame):
137
+ raise TerminatedProcess
138
+
139
+
140
+ # ########################## readline support ############################ #
141
+
142
+ def read_line(stdin, prompt=''):
143
+ "Read a line from stdin, using readline when possible"
144
+ if isinstance(stdin, ReadlineInput):
145
+ return stdin.readline(prompt)
146
+ else:
147
+ write(prompt)
148
+ return stdin.readline()
149
+
150
+
151
+ def read_long_line(stdin, terminator):
152
+ """
153
+ Read multiple lines from stdin until the terminator character is found,
154
+ then yield a single space-separated long line.
155
+ """
156
+ while True:
157
+ lines = []
158
+ while True:
159
+ line = stdin.readline() # ends with \n
160
+ if not line: # EOF
161
+ return
162
+ line = line.strip()
163
+ if not line:
164
+ continue
165
+ elif line[-1] == terminator:
166
+ lines.append(line[:-1])
167
+ break
168
+ else:
169
+ lines.append(line)
170
+ yield ' '.join(lines)
171
+
172
+
173
+ class ReadlineInput(object):
174
+ """
175
+ An iterable with a .readline method reading from stdin.
176
+ """
177
+ def __init__(self, completions, case_sensitive=True, histfile=None):
178
+ self.completions = completions
179
+ self.case_sensitive = case_sensitive
180
+ self.histfile = histfile
181
+ if not case_sensitive:
182
+ self.completions = [c.upper() for c in completions]
183
+ import readline
184
+ self.rl = readline
185
+ readline.parse_and_bind("tab: complete")
186
+ readline.set_completer(self.complete)
187
+
188
+ def __enter__(self):
189
+ self.old_completer = self.rl.get_completer()
190
+ try:
191
+ if self.histfile:
192
+ self.rl.read_history_file(self.histfile)
193
+ except IOError: # the first time
194
+ pass
195
+ return self
196
+
197
+ def __exit__(self, etype, exc, tb):
198
+ self.rl.set_completer(self.old_completer)
199
+ if self.histfile:
200
+ self.rl.write_history_file(self.histfile)
201
+
202
+ def complete(self, kw, state):
203
+ # state is 0, 1, 2, ... and increases by hitting TAB
204
+ if not self.case_sensitive:
205
+ kw = kw.upper()
206
+ try:
207
+ return [k for k in self.completions if k.startswith(kw)][state]
208
+ except IndexError: # no completions
209
+ return # exit
210
+
211
+ def readline(self, prompt=''):
212
+ try:
213
+ return raw_input(prompt) + '\n'
214
+ except EOFError:
215
+ return ''
216
+
217
+ def __iter__(self):
218
+ return iter(self.readline, '')
219
+
220
+ # ################# help functionality in plac interpreters ################# #
221
+
222
+
223
+ class HelpSummary(object):
224
+ "Build the help summary consistently with the cmd module"
225
+
226
+ @classmethod
227
+ def add(cls, obj, specialcommands):
228
+ p = plac_core.parser_from(obj)
229
+ c = cmd.Cmd(stdout=cls())
230
+ c.stdout.write('\n')
231
+ c.print_topics('special commands',
232
+ sorted(specialcommands), 15, 80)
233
+ c.print_topics('custom commands',
234
+ sorted(obj.commands), 15, 80)
235
+ c.print_topics('commands run in external processes',
236
+ sorted(obj.mpcommands), 15, 80)
237
+ c.print_topics('threaded commands',
238
+ sorted(obj.thcommands), 15, 80)
239
+ p.helpsummary = str(c.stdout)
240
+
241
+ def __init__(self):
242
+ self._ls = []
243
+
244
+ def write(self, s):
245
+ self._ls.append(s)
246
+
247
+ def __str__(self):
248
+ return ''.join(self._ls)
249
+
250
+
251
+ class PlacFormatter(argparse.RawDescriptionHelpFormatter):
252
+ def _metavar_formatter(self, action, default_metavar):
253
+ 'Remove special commands from the usage message'
254
+ choices = action.choices or {}
255
+ action.choices = dict((n, c) for n, c in choices.items()
256
+ if not n.startswith('.'))
257
+ return super(PlacFormatter, self)._metavar_formatter(
258
+ action, default_metavar)
259
+
260
+
261
+ def format_help(self):
262
+ "Attached to plac_core.ArgumentParser for plac interpreters"
263
+ try:
264
+ return self.helpsummary
265
+ except AttributeError:
266
+ return super(plac_core.ArgumentParser, self).format_help()
267
+ plac_core.ArgumentParser.format_help = format_help
268
+
269
+
270
+ def default_help(obj, cmd=None):
271
+ "The default help functionality in plac interpreters"
272
+ parser = plac_core.parser_from(obj)
273
+ if cmd is None:
274
+ yield parser.format_help()
275
+ return
276
+ subp = parser.subparsers._name_parser_map.get(cmd)
277
+ if subp is None:
278
+ yield _('Unknown command %s' % cmd)
279
+ elif getattr(obj, '_interact_', False): # in interactive mode
280
+ formatter = subp._get_formatter()
281
+ formatter._prog = cmd # remove the program name from the usage
282
+ formatter.add_usage(
283
+ subp.usage, [a for a in subp._actions if a.dest != 'help'],
284
+ subp._mutually_exclusive_groups)
285
+ formatter.add_text(subp.description)
286
+ for action_group in subp._action_groups:
287
+ formatter.start_section(action_group.title)
288
+ formatter.add_text(action_group.description)
289
+ formatter.add_arguments(a for a in action_group._group_actions
290
+ if a.dest != 'help')
291
+ formatter.end_section()
292
+ yield formatter.format_help()
293
+ else: # regular argparse help
294
+ yield subp.format_help()
295
+
296
+ # ######################## import management ############################## #
297
+
298
+ try:
299
+ PLACDIRS = os.environ.get('PLACPATH', '.').split(':')
300
+ except:
301
+ raise ValueError(_('Ill-formed PLACPATH: got %PLACPATHs') % os.environ)
302
+
303
+
304
+ def partial_call(factory, arglist):
305
+ "Call a container factory with the arglist and return a plac object"
306
+ a = plac_core.parser_from(factory).argspec
307
+ if a.defaults or a.varargs or a.varkw:
308
+ raise TypeError('Interpreter.call must be invoked on '
309
+ 'factories with required arguments only')
310
+ required_args = ', '.join(a.args)
311
+ if required_args:
312
+ required_args += ',' # trailing comma
313
+ code = '''def makeobj(interact, %s *args):
314
+ obj = factory(%s)
315
+ obj._interact_ = interact
316
+ obj._args_ = args
317
+ return obj\n''' % (required_args, required_args)
318
+ dic = dict(factory=factory)
319
+ exec_(code, dic)
320
+ makeobj = dic['makeobj']
321
+ makeobj.add_help = False
322
+ if inspect.isclass(factory):
323
+ makeobj.__annotations__ = getattr(
324
+ factory.__init__, '__annotations__', {})
325
+ else:
326
+ makeobj.__annotations__ = getattr(
327
+ factory, '__annotations__', {})
328
+ makeobj.__annotations__['interact'] = (
329
+ 'start interactive interpreter', 'flag', 'i')
330
+ return plac_core.call(makeobj, arglist)
331
+
332
+
333
+ def import_main(path, *args):
334
+ """
335
+ A utility to import the main function of a plac tool. It also
336
+ works with command container factories.
337
+ """
338
+ if ':' in path: # importing a factory
339
+ path, factory_name = path.split(':')
340
+ else: # importing the main function
341
+ factory_name = None
342
+ if not os.path.isabs(path): # relative path, look at PLACDIRS
343
+ for placdir in PLACDIRS:
344
+ fullpath = os.path.join(placdir, path)
345
+ if os.path.exists(fullpath):
346
+ break
347
+ else: # no break
348
+ raise ImportError(_('Cannot find %s' % path))
349
+ else:
350
+ fullpath = path
351
+ name, ext = os.path.splitext(os.path.basename(fullpath))
352
+ module = load_source(name, fullpath)
353
+ if factory_name:
354
+ tool = partial_call(getattr(module, factory_name), args)
355
+ else:
356
+ tool = module.main
357
+ return tool
358
+
359
+ # ############################ Task classes ############################# #
360
+
361
+
362
+ # base class not instantiated directly
363
+ class BaseTask(object):
364
+ """
365
+ A task is a wrapper over a generator object with signature
366
+ Task(no, arglist, genobj), attributes
367
+ .no
368
+ .arglist
369
+ .outlist
370
+ .str
371
+ .etype
372
+ .exc
373
+ .tb
374
+ .status
375
+ and methods .run and .kill.
376
+ """
377
+ STATES = ('SUBMITTED', 'RUNNING', 'TOBEKILLED', 'KILLED', 'FINISHED',
378
+ 'ABORTED')
379
+
380
+ def __init__(self, no, arglist, genobj):
381
+ self.no = no
382
+ self.arglist = arglist
383
+ self._genobj = self._wrap(genobj)
384
+ self.str, self.etype, self.exc, self.tb = '', None, None, None
385
+ self.status = 'SUBMITTED'
386
+ self.outlist = []
387
+
388
+ def notify(self, msg):
389
+ "Notifies the underlying monitor. To be implemented"
390
+
391
+ def _wrap(self, genobj, stringify_tb=False):
392
+ """
393
+ Wrap the genobj into a generator managing the exceptions,
394
+ populating the .outlist, setting the .status and yielding None.
395
+ stringify_tb must be True if the traceback must be sent to a process.
396
+ """
397
+ self.status = 'RUNNING'
398
+ try:
399
+ for value in genobj:
400
+ if self.status == 'TOBEKILLED': # exit from the loop
401
+ raise GeneratorExit
402
+ if value is not None: # add output
403
+ self.outlist.append(value)
404
+ self.notify(decode(value))
405
+ yield
406
+ except Interpreter.Exit: # wanted exit
407
+ self._regular_exit()
408
+ raise
409
+ except (GeneratorExit, TerminatedProcess, KeyboardInterrupt):
410
+ # soft termination
411
+ self.status = 'KILLED'
412
+ except Exception: # unexpected exception
413
+ self.etype, self.exc, tb = sys.exc_info()
414
+ self.tb = ''.join(traceback.format_tb(tb)) if stringify_tb else tb
415
+ self.status = 'ABORTED'
416
+ else:
417
+ self._regular_exit()
418
+
419
+ def _regular_exit(self):
420
+ self.status = 'FINISHED'
421
+ try:
422
+ self.str = '\n'.join(map(decode, self.outlist))
423
+ except IndexError:
424
+ self.str = 'no result'
425
+
426
+ def run(self):
427
+ "Run the inner generator"
428
+ for none in self._genobj:
429
+ pass
430
+
431
+ def kill(self):
432
+ "Set a TOBEKILLED status"
433
+ self.status = 'TOBEKILLED'
434
+
435
+ def wait(self):
436
+ "Wait for the task to finish: to be overridden"
437
+
438
+ @property
439
+ def traceback(self):
440
+ "Return the traceback as a (possibly empty) string"
441
+ if self.tb is None:
442
+ return ''
443
+ elif isinstance(self.tb, (str, bytes)):
444
+ return self.tb
445
+ else:
446
+ return ''.join(traceback.format_tb(self.tb))
447
+
448
+ @property
449
+ def result(self):
450
+ self.wait()
451
+ if self.exc:
452
+ if isinstance(self.tb, (str, bytes)):
453
+ raise self.etype(self.tb)
454
+ else:
455
+ raise_(self.etype, self.exc, self.tb or None)
456
+ if not self.outlist:
457
+ return None
458
+ return self.outlist[-1]
459
+
460
+ def __repr__(self):
461
+ "String representation containing class name, number, arglist, status"
462
+ return '<%s %d [%s] %s>' % (
463
+ self.__class__.__name__, self.no,
464
+ ' '.join(self.arglist), self.status)
465
+
466
+ nulltask = BaseTask(0, [], ('skip' for dummy in (1,)))
467
+
468
+ # ######################## synchronous tasks ############################## #
469
+
470
+
471
+ class SynTask(BaseTask):
472
+ """
473
+ Synchronous task running in the interpreter loop and displaying its
474
+ output as soon as available.
475
+ """
476
+ def __str__(self):
477
+ "Return the output string or the error message"
478
+ if self.etype: # there was an error
479
+ return '%s: %s' % (self.etype.__name__, self.exc)
480
+ else:
481
+ return '\n'.join(map(str, self.outlist))
482
+
483
+
484
+ class ThreadedTask(BaseTask):
485
+ """
486
+ A task running in a separate thread.
487
+ """
488
+ def __init__(self, no, arglist, genobj):
489
+ BaseTask.__init__(self, no, arglist, genobj)
490
+ self.thread = threading.Thread(target=super(ThreadedTask, self).run)
491
+
492
+ def run(self):
493
+ "Run the task into a thread"
494
+ self.thread.start()
495
+
496
+ def wait(self):
497
+ "Block until the thread ends"
498
+ self.thread.join()
499
+
500
+
501
+ # ######################## multiprocessing tasks ######################### #
502
+
503
+ def sharedattr(name, on_error):
504
+ "Return a property to be attached to an MPTask"
505
+ def get(self):
506
+ try:
507
+ return getattr(self.ns, name)
508
+ except: # the process was killed or died hard
509
+ return on_error
510
+
511
+ def set(self, value):
512
+ try:
513
+ setattr(self.ns, name, value)
514
+ except: # the process was killed or died hard
515
+ pass
516
+ return property(get, set)
517
+
518
+
519
+ class MPTask(BaseTask):
520
+ """
521
+ A task running as an external process. The current implementation
522
+ only works on Unix-like systems, where multiprocessing uses forks.
523
+ """
524
+ str = sharedattr('str', '')
525
+ etype = sharedattr('etype', None)
526
+ exc = sharedattr('exc', None)
527
+ tb = sharedattr('tb', None)
528
+ status = sharedattr('status', 'ABORTED')
529
+
530
+ @property
531
+ def outlist(self):
532
+ try:
533
+ return self._outlist
534
+ except: # the process died hard
535
+ return []
536
+
537
+ def notify(self, msg):
538
+ self.man.notify_listener(self.no, msg)
539
+
540
+ def __init__(self, no, arglist, genobj, manager):
541
+ """
542
+ The monitor has a .send method and a .man multiprocessing.Manager
543
+ """
544
+ self.no = no
545
+ self.arglist = arglist
546
+ self._genobj = self._wrap(genobj, stringify_tb=True)
547
+ self.man = manager
548
+ self._outlist = manager.mp.list()
549
+ self.ns = manager.mp.Namespace()
550
+ self.status = 'SUBMITTED'
551
+ self.etype, self.exc, self.tb = None, None, None
552
+ self.str = repr(self)
553
+ self.proc = multiprocessing.Process(target=super(MPTask, self).run)
554
+
555
+ def run(self):
556
+ "Run the task into an external process"
557
+ self.proc.start()
558
+
559
+ def wait(self):
560
+ "Block until the external process ends or is killed"
561
+ self.proc.join()
562
+
563
+ def kill(self):
564
+ """Kill the process with a SIGTERM inducing a TerminatedProcess
565
+ exception in the children"""
566
+ self.proc.terminate()
567
+
568
+ # ######################## Task Manager ###################### #
569
+
570
+
571
+ class TaskManager(object):
572
+ """
573
+ Store the given commands into a task registry. Provides methods to
574
+ manage the submitted tasks.
575
+ """
576
+ cmdprefix = '.'
577
+ specialcommands = set(['.last_tb'])
578
+
579
+ def __init__(self, obj):
580
+ self.obj = obj
581
+ self.registry = {} # {taskno : task}
582
+ if obj.mpcommands or obj.thcommands:
583
+ self.specialcommands.update(['.kill', '.list', '.output'])
584
+ interact = getattr(obj, '_interact_', False)
585
+ self.parser = plac_core.parser_from(
586
+ obj, prog='' if interact else None, formatter_class=PlacFormatter)
587
+ HelpSummary.add(obj, self.specialcommands)
588
+ self.man = Manager() if obj.mpcommands else None
589
+ signal.signal(signal.SIGTERM, terminatedProcess)
590
+
591
+ def close(self):
592
+ "Kill all the running tasks"
593
+ for task in self.registry.values():
594
+ try:
595
+ if task.status == 'RUNNING':
596
+ task.kill()
597
+ task.wait()
598
+ except: # task killed, nothing to wait
599
+ pass
600
+ if self.man:
601
+ self.man.stop()
602
+
603
+ def _get_latest(self, taskno=-1, status=None):
604
+ "Get the latest submitted task from the registry"
605
+ assert taskno < 0, 'You must pass a negative number'
606
+ if status:
607
+ tasks = [t for t in self.registry.values()
608
+ if t.status == status]
609
+ else:
610
+ tasks = [t for t in self.registry.values()]
611
+ tasks.sort(key=attrgetter('no'))
612
+ if len(tasks) >= abs(taskno):
613
+ return tasks[taskno]
614
+
615
+ # ########################## special commands ######################## #
616
+
617
+ @plac_core.annotations(
618
+ taskno=('task to kill', 'positional', None, int))
619
+ def kill(self, taskno=-1):
620
+ 'kill the given task (-1 to kill the latest running task)'
621
+ if taskno < 0:
622
+ task = self._get_latest(taskno, status='RUNNING')
623
+ if task is None:
624
+ yield 'Nothing to kill'
625
+ return
626
+ elif taskno not in self.registry:
627
+ yield 'Unknown task %d' % taskno
628
+ return
629
+ else:
630
+ task = self.registry[taskno]
631
+ if task.status in ('ABORTED', 'KILLED', 'FINISHED'):
632
+ yield 'Already finished %s' % task
633
+ return
634
+ task.kill()
635
+ yield task
636
+
637
+ @plac_core.annotations(
638
+ status=('', 'positional', None, str, BaseTask.STATES))
639
+ def list(self, status='RUNNING'):
640
+ 'list tasks with a given status'
641
+ for task in self.registry.values():
642
+ if task.status == status:
643
+ yield task
644
+
645
+ @plac_core.annotations(
646
+ taskno=('task number', 'positional', None, int))
647
+ def output(self, taskno=-1, fname=None):
648
+ 'show the output of a given task (and optionally save it to a file)'
649
+ if taskno < 0:
650
+ task = self._get_latest(taskno)
651
+ if task is None:
652
+ yield 'Nothing to show'
653
+ return
654
+ elif taskno not in self.registry:
655
+ yield 'Unknown task %d' % taskno
656
+ return
657
+ else:
658
+ task = self.registry[taskno]
659
+ outstr = '\n'.join(map(str, task.outlist))
660
+ if fname:
661
+ open(fname, 'w').write(outstr)
662
+ yield 'saved output of %d into %s' % (taskno, fname)
663
+ return
664
+ yield task
665
+ if len(task.outlist) > 20 and use_less:
666
+ less(outstr) # has no meaning for a plac server
667
+ else:
668
+ yield outstr
669
+
670
+ @plac_core.annotations(
671
+ taskno=('task number', 'positional', None, int))
672
+ def last_tb(self, taskno=-1):
673
+ "show the traceback of a given task, if any"
674
+ task = self._get_latest(taskno)
675
+ if task:
676
+ yield task.traceback
677
+ else:
678
+ yield 'Nothing to show'
679
+
680
+ # ########################## SyncProcess ############################# #
681
+
682
+
683
+ class Process(subprocess.Popen):
684
+ "Start the interpreter specified by the params in a subprocess"
685
+
686
+ def __init__(self, params):
687
+ signal.signal(signal.SIGPIPE, signal.SIG_DFL)
688
+ # to avoid broken pipe messages
689
+ code = '''import plac, sys
690
+ sys.argv[0] = '<%s>'
691
+ plac.Interpreter(plac.import_main(*%s)).interact(prompt='i>\\n')
692
+ ''' % (params[0], params)
693
+ subprocess.Popen.__init__(
694
+ self, [sys.executable, '-u', '-c', code],
695
+ stdin=subprocess.PIPE, stdout=subprocess.PIPE)
696
+ self.man = multiprocessing.Manager()
697
+
698
+ def close(self):
699
+ "Close stdin and stdout"
700
+ self.stdin.close()
701
+ self.stdout.close()
702
+ self.man.shutdown()
703
+
704
+ def recv(self): # char-by-char cannot work
705
+ "Return the output of the subprocess, line-by-line until the prompt"
706
+ lines = []
707
+ while True:
708
+ lines.append(self.stdout.readline())
709
+ if lines[-1] == 'i>\n':
710
+ out = ''.join(lines)
711
+ return out[:-1] + ' ' # remove last newline
712
+
713
+ def send(self, line):
714
+ """Send a line (adding a newline) to the underlying subprocess
715
+ and wait for the answer"""
716
+ self.stdin.write(line + os.linesep)
717
+ return self.recv()
718
+
719
+
720
+ class StartStopObject(object):
721
+ started = False
722
+
723
+ def start(self):
724
+ pass
725
+
726
+ def stop(self):
727
+ pass
728
+
729
+
730
+ class Monitor(StartStopObject):
731
+ """
732
+ Base monitor class with methods add_listener/del_listener/notify_listener,
733
+ read_queue, and start/stop.
734
+ """
735
+ def __init__(self, name, queue=None):
736
+ self.name = name
737
+ self.queue = queue or multiprocessing.Queue()
738
+
739
+ def add_listener(self, taskno):
740
+ pass
741
+
742
+ def del_listener(self, taskno):
743
+ pass
744
+
745
+ def notify_listener(self, taskno, msg):
746
+ pass
747
+
748
+ def start(self):
749
+ pass
750
+
751
+ def stop(self):
752
+ pass
753
+
754
+ def read_queue(self):
755
+ pass
756
+
757
+
758
+ class Manager(StartStopObject):
759
+ """
760
+ The plac Manager contains a multiprocessing.Manager and a set
761
+ of slave monitor processes to which we can send commands. There
762
+ is a manager for each interpreter with mpcommands.
763
+ """
764
+ def __init__(self):
765
+ self.registry = {}
766
+ self.started = False
767
+ self.mp = None
768
+
769
+ def add(self, monitor):
770
+ 'Add or replace a monitor in the registry'
771
+ proc = multiprocessing.Process(None, monitor.start, monitor.name)
772
+ proc.queue = monitor.queue
773
+ self.registry[monitor.name] = proc
774
+
775
+ def delete(self, name):
776
+ 'Remove a named monitor from the registry'
777
+ del self.registry[name]
778
+
779
+ # can be called more than once
780
+ def start(self):
781
+ if self.mp is None:
782
+ self.mp = multiprocessing.Manager()
783
+ for monitor in self.registry.values():
784
+ monitor.start()
785
+ self.started = True
786
+
787
+ def stop(self):
788
+ for monitor in self.registry.values():
789
+ monitor.queue.close()
790
+ monitor.terminate()
791
+ if self.mp:
792
+ self.mp.shutdown()
793
+ self.mp = None
794
+ self.started = False
795
+
796
+ def notify_listener(self, taskno, msg):
797
+ for monitor in self.registry.values():
798
+ monitor.queue.put(('notify_listener', taskno, msg))
799
+
800
+ def add_listener(self, no):
801
+ for monitor in self.registry.values():
802
+ monitor.queue.put(('add_listener', no))
803
+
804
+ # ######################### plac server ############################# #
805
+
806
+ #
807
+ # Removed in version 1.4.0 due to incompatibility with Python 3.12
808
+ #
809
+ '''
810
+ import asyncore
811
+ import asynchat
812
+ import socket
813
+
814
+ class _AsynHandler(asynchat.async_chat):
815
+ "asynchat handler starting a new interpreter loop for each connection"
816
+
817
+ terminator = '\r\n' # the standard one for telnet
818
+ prompt = 'i> '
819
+
820
+ def __init__(self, socket, interpreter):
821
+ asynchat.async_chat.__init__(self, socket)
822
+ self.set_terminator(self.terminator)
823
+ self.i = interpreter
824
+ self.i.__enter__()
825
+ self.data = []
826
+ self.write(self.prompt)
827
+
828
+ def write(self, data, *args):
829
+ "Push a string back to the client"
830
+ if args:
831
+ data %= args
832
+ if data.endswith('\n') and not data.endswith(self.terminator):
833
+ data = data[:-1] + self.terminator # fix newlines
834
+ self.push(data)
835
+
836
+ def collect_incoming_data(self, data):
837
+ "Collect one character at the time"
838
+ self.data.append(data)
839
+
840
+ def found_terminator(self):
841
+ "Put in the queue the line received from the client"
842
+ line = ''.join(self.data)
843
+ self.log('Received line %r from %s' % (line, self.addr))
844
+ if line == 'EOF':
845
+ self.i.__exit__(None, None, None)
846
+ self.handle_close()
847
+ else:
848
+ task = self.i.submit(line)
849
+ task.run() # synchronous or not
850
+ if task.etype: # manage exception
851
+ error = '%s: %s\nReceived: %s' % (
852
+ task.etype.__name__, task.exc, ' '.join(task.arglist))
853
+ self.log_info(task.traceback + error) # on the server
854
+ self.write(error + self.terminator) # back to the client
855
+ else: # no exception
856
+ self.write(task.str + self.terminator)
857
+ self.data = []
858
+ self.write(self.prompt)
859
+
860
+
861
+ class _AsynServer(asyncore.dispatcher):
862
+ "asyncore-based server spawning AsynHandlers"
863
+
864
+ def __init__(self, interpreter, newhandler, port, listen=5):
865
+ self.interpreter = interpreter
866
+ self.newhandler = newhandler
867
+ self.port = port
868
+ asyncore.dispatcher.__init__(self)
869
+ self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
870
+ self.bind(('', port))
871
+ self.listen(listen)
872
+
873
+ def handle_accept(self):
874
+ clientsock, clientaddr = self.accept()
875
+ self.log('Connected from %s' % str(clientaddr))
876
+ i = self.interpreter.__class__(self.interpreter.obj) # new interpreter
877
+ self.newhandler(clientsock, i) # spawn a new handler
878
+
879
+ '''
880
+
881
+ # ########################## the Interpreter ############################ #
882
+
883
+ class Interpreter(object):
884
+ """
885
+ A context manager with a .send method and a few utility methods:
886
+ execute, test and doctest.
887
+ """
888
+ class Exit(Exception):
889
+ pass
890
+
891
+ def __init__(self, obj, commentchar='#', split=shlex.split):
892
+ self.obj = obj
893
+ try:
894
+ self.name = obj.__module__
895
+ except AttributeError:
896
+ self.name = 'plac'
897
+ self.commentchar = commentchar
898
+ self.split = split
899
+ self._set_commands(obj)
900
+ self.tm = TaskManager(obj)
901
+ self.man = self.tm.man
902
+ self.parser = self.tm.parser
903
+ if self.commands:
904
+ self.parser.addsubcommands(
905
+ self.tm.specialcommands, self.tm, title='special commands')
906
+ if obj.mpcommands:
907
+ self.parser.addsubcommands(
908
+ obj.mpcommands, obj,
909
+ title='commands run in external processes')
910
+ if obj.thcommands:
911
+ self.parser.addsubcommands(
912
+ obj.thcommands, obj, title='threaded commands')
913
+ self.parser.error = lambda msg: sys.exit(msg) # patch the parser
914
+ self._interpreter = None
915
+
916
+ def _set_commands(self, obj):
917
+ "Make sure obj has the right command attributes as Python sets"
918
+ for attrname in ('commands', 'mpcommands', 'thcommands'):
919
+ setattr(self, attrname, set(getattr(self.__class__, attrname, [])))
920
+ setattr(obj, attrname, set(getattr(obj, attrname, [])))
921
+ self.commands = obj.commands
922
+ self.mpcommands.update(obj.mpcommands)
923
+ self.thcommands.update(obj.thcommands)
924
+ if (obj.commands or obj.mpcommands or obj.thcommands) and \
925
+ not hasattr(obj, 'help'): # add default help
926
+ obj.help = default_help.__get__(obj, obj.__class__)
927
+ self.commands.add('help')
928
+
929
+ def __enter__(self):
930
+ "Start the inner interpreter loop"
931
+ self._interpreter = self._make_interpreter()
932
+ self._interpreter.send(None)
933
+ return self
934
+
935
+ def __exit__(self, exctype, exc, tb):
936
+ "Close the inner interpreter and the task manager"
937
+ self.close(exctype, exc, tb)
938
+
939
+ def submit(self, line):
940
+ "Send a line to the underlying interpreter and return a task object"
941
+ if self._interpreter is None:
942
+ raise RuntimeError(_('%r not initialized: probably you forgot to '
943
+ 'use the with statement') % self)
944
+ if isinstance(line, (str, bytes)):
945
+ arglist = self.split(line, self.commentchar)
946
+ else: # expects a list of strings
947
+ arglist = line
948
+ if not arglist:
949
+ return nulltask
950
+ m = self.tm.man # manager
951
+ if m and not m.started:
952
+ m.start()
953
+ task = self._interpreter.send(arglist) # nonblocking
954
+ if not plac_core._match_cmd(arglist[0], self.tm.specialcommands):
955
+ self.tm.registry[task.no] = task
956
+ if m:
957
+ m.add_listener(task.no)
958
+ return task
959
+
960
+ def send(self, line):
961
+ """Send a line to the underlying interpreter and return
962
+ the finished task"""
963
+ task = self.submit(line)
964
+ BaseTask.run(task) # blocking
965
+ return task
966
+
967
+ def tasks(self):
968
+ "The full lists of the submitted tasks"
969
+ return self.tm.registry.values()
970
+
971
+ def close(self, exctype=None, exc=None, tb=None):
972
+ "Can be called to close the interpreter prematurely"
973
+ self.tm.close()
974
+ if exctype is not None:
975
+ self._interpreter.throw(exctype, exc, tb)
976
+ else:
977
+ self._interpreter.close()
978
+
979
+ def _make_interpreter(self):
980
+ "The interpreter main loop, from lists of arguments to task objects"
981
+ enter = getattr(self.obj, '__enter__', lambda: None)
982
+ exit = getattr(self.obj, '__exit__', lambda et, ex, tb: None)
983
+ enter()
984
+ task = None
985
+ try:
986
+ for no in itertools.count(1):
987
+ arglist = yield task
988
+ try:
989
+ cmd, result = self.parser.consume(arglist)
990
+ except SystemExit as e: # for invalid commands
991
+ if e.args == (0,): # raised as sys.exit(0)
992
+ errlist = []
993
+ else:
994
+ errlist = [str(e)]
995
+ task = SynTask(no, arglist, iter(errlist))
996
+ continue
997
+ except: # anything else
998
+ task = SynTask(no, arglist, gen_exc(*sys.exc_info()))
999
+ continue
1000
+ if not plac_core.iterable(result): # atomic result
1001
+ task = SynTask(no, arglist, gen_val(result))
1002
+ elif cmd in self.obj.mpcommands:
1003
+ task = MPTask(no, arglist, result, self.tm.man)
1004
+ elif cmd in self.obj.thcommands:
1005
+ task = ThreadedTask(no, arglist, result)
1006
+ else: # blocking task
1007
+ task = SynTask(no, arglist, result)
1008
+ except GeneratorExit: # regular exit
1009
+ exit(None, None, None)
1010
+ except: # exceptional exit
1011
+ exit(*sys.exc_info())
1012
+ raise
1013
+
1014
+ def check(self, given_input, expected_output):
1015
+ "Make sure you get the expected_output from the given_input"
1016
+ output = self.send(given_input).str # blocking
1017
+ ok = (output == expected_output)
1018
+ if not ok:
1019
+ # the message here is not internationalized on purpose
1020
+ msg = 'input: %s\noutput: %s\nexpected: %s' % (
1021
+ given_input, output, expected_output)
1022
+ raise AssertionError(msg)
1023
+
1024
+ def _parse_doctest(self, lineiter):
1025
+ "Returns the lines of input, the lines of output, and the line number"
1026
+ lines = [line.strip() for line in lineiter]
1027
+ inputs = []
1028
+ positions = []
1029
+ for i, line in enumerate(lines):
1030
+ if line.startswith('i> '):
1031
+ inputs.append(line[3:])
1032
+ positions.append(i)
1033
+ positions.append(len(lines) + 1) # last position
1034
+ outputs = []
1035
+ for i, start in enumerate(positions[:-1]):
1036
+ end = positions[i + 1]
1037
+ outputs.append('\n'.join(lines[start+1:end]))
1038
+ return zip(inputs, outputs, positions)
1039
+
1040
+ def doctest(self, lineiter, verbose=False):
1041
+ """
1042
+ Parse a text containing doctests in a context and test them all.
1043
+ Raise an error if even a single doctest is broken. Use this for
1044
+ sequential tests which are logically grouped.
1045
+ """
1046
+ with self:
1047
+ try:
1048
+ for input, output, no in self._parse_doctest(lineiter):
1049
+ if verbose:
1050
+ write('i> %s\n' % input)
1051
+ write('-> %s\n' % output)
1052
+ task = self.send(input) # blocking
1053
+ if not str(task) == output:
1054
+ msg = ('line %d: input: %s\noutput: %s\nexpected: %s\n'
1055
+ % (no + 1, input, task, output))
1056
+ write(msg)
1057
+ if task.exc:
1058
+ raise_(task.etype, task.exc, task.tb)
1059
+ except self.Exit:
1060
+ pass
1061
+
1062
+ def execute(self, lineiter, verbose=False):
1063
+ "Execute a lineiter of commands in a context and print the output"
1064
+ with self:
1065
+ try:
1066
+ for line in lineiter:
1067
+ if verbose:
1068
+ write('i> ' + line)
1069
+ task = self.send(line) # finished task
1070
+ if task.etype: # there was an error
1071
+ raise_(task.etype, task.exc, task.tb)
1072
+ write('%s\n' % task.str)
1073
+ except self.Exit:
1074
+ pass
1075
+
1076
+ def multiline(self, stdin=sys.stdin, terminator=';', verbose=False):
1077
+ "The multiline mode is especially suited for usage with emacs"
1078
+ with self:
1079
+ try:
1080
+ for line in read_long_line(stdin, terminator):
1081
+ task = self.submit(line)
1082
+ task.run()
1083
+ write('%s\n' % task.str)
1084
+ if verbose and task.traceback:
1085
+ write(task.traceback)
1086
+ except self.Exit:
1087
+ pass
1088
+
1089
+ def interact(self, stdin=sys.stdin, prompt='i> ', verbose=False):
1090
+ "Starts an interactive command loop reading commands from the console"
1091
+ try:
1092
+ import readline
1093
+ readline_present = True
1094
+ except ImportError:
1095
+ readline_present = False
1096
+ if stdin is sys.stdin and readline_present: # use readline
1097
+ histfile = os.path.expanduser('~/.%s.history' % self.name)
1098
+ completions = list(self.commands) + list(self.mpcommands) + \
1099
+ list(self.thcommands) + list(self.tm.specialcommands)
1100
+ self.stdin = ReadlineInput(completions, histfile=histfile)
1101
+ else:
1102
+ self.stdin = stdin
1103
+ self.prompt = prompt
1104
+ self.verbose = verbose
1105
+ intro = self.obj.__doc__ or ''
1106
+ write(intro + '\n')
1107
+ with self:
1108
+ self.obj._interact_ = True
1109
+ if self.stdin is sys.stdin: # do not close stdin automatically
1110
+ self._manage_input()
1111
+ else:
1112
+ with self.stdin: # close stdin automatically
1113
+ self._manage_input()
1114
+
1115
+ def _manage_input(self):
1116
+ "Convert input lines into task which are then executed"
1117
+ try:
1118
+ for line in iter(lambda: read_line(self.stdin, self.prompt), ''):
1119
+ line = line.strip()
1120
+ if not line:
1121
+ continue
1122
+ task = self.submit(line)
1123
+ task.run() # synchronous or not
1124
+ write(str(task) + '\n')
1125
+ if self.verbose and task.etype:
1126
+ write(task.traceback)
1127
+ except self.Exit:
1128
+ pass
1129
+
1130
+ def start_server(self, port=2199, **kw):
1131
+ """Starts an asyncore server reading commands for clients and opening
1132
+ a new interpreter for each connection."""
1133
+ _AsynServer(self, _AsynHandler, port) # register the server
1134
+ try:
1135
+ asyncore.loop(**kw)
1136
+ except (KeyboardInterrupt, TerminatedProcess):
1137
+ pass
1138
+ finally:
1139
+ asyncore.close_all()
1140
+
1141
+ def add_monitor(self, mon):
1142
+ self.man.add(mon)
1143
+
1144
+ def del_monitor(self, name):
1145
+ self.man.delete(name)
1146
+
1147
+ @classmethod
1148
+ def call(cls, factory, arglist=sys.argv[1:],
1149
+ commentchar='#', split=shlex.split,
1150
+ stdin=sys.stdin, prompt='i> ', verbose=False):
1151
+ """
1152
+ Call a container factory with the arglist and instantiate an
1153
+ interpreter object. If there are remaining arguments, send them to the
1154
+ interpreter, else start an interactive session.
1155
+ """
1156
+ obj = partial_call(factory, arglist)
1157
+ i = cls(obj, commentchar, split)
1158
+ if i.obj._args_:
1159
+ with i:
1160
+ task = i.send(i.obj._args_) # synchronous
1161
+ if task.exc:
1162
+ raise_(task.etype, task.exc, task.tb)
1163
+ out = str(task)
1164
+ if out:
1165
+ print(out)
1166
+ elif i.obj._interact_:
1167
+ i.interact(stdin, prompt, verbose)
1168
+ else:
1169
+ i.parser.print_usage()
1170
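For reference, a minimal sketch of driving the Interpreter class above from Python code. It assumes the plac package re-exports Interpreter from plac_ext (as its __init__ normally does); the Calc container and its add command are purely illustrative.

    import plac

    class Calc(object):
        "Toy container object exposing a single synchronous command"
        commands = ['add']

        def add(self, a, b):
            return float(a) + float(b)

    # Interpreter is a context manager; .send returns a finished task
    with plac.Interpreter(Calc()) as i:
        task = i.send('add 1 2')
        print(task)  # expected to print 3.0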
+
1171
+ # ################################## runp ################################### #
1172
+
1173
+
1174
+ class _TaskLauncher(object):
1175
+ "Helper for runp"
1176
+
1177
+ def __init__(self, genseq, mode):
1178
+ if mode == 'p':
1179
+ self.mpcommands = ['rungen']
1180
+ else:
1181
+ self.thcommands = ['rungen']
1182
+ self.genlist = list(genseq)
1183
+
1184
+ def rungen(self, i):
1185
+ for out in self.genlist[int(i) - 1]:
1186
+ yield out
1187
+
1188
+
1189
+ def runp(genseq, mode='p'):
1190
+ """Run a sequence of generators in parallel. Mode can be 'p' (use processes)
1191
+ or 't' (use threads). After all of them are finished, return a list of
1192
+ task objects.
1193
+ """
1194
+ assert mode in 'pt', mode
1195
+ launcher = _TaskLauncher(genseq, mode)
1196
+ res = []
1197
+ with Interpreter(launcher) as inter:
1198
+ for i in range(len(launcher.genlist)):
1199
+ inter.submit('rungen %d' % (i + 1)).run()
1200
+ for task in inter.tasks():
1201
+ try:
1202
+ res.append(task.result)
1203
+ except Exception as e:
1204
+ res.append(e)
1205
+ return res
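A minimal usage sketch for runp, assuming it is re-exported as plac.runp. Threads ('t') are used here because mode 'p' relies on fork and, per the MPTask docstring above, only works on Unix-like systems.

    import plac

    def counter(n):
        # the last value yielded by each generator becomes the task result
        for i in range(n):
            yield i

    # run the generators in parallel threads and collect the results
    results = plac.runp([counter(3), counter(5)], mode='t')
    print(results)  # expected: [2, 4]; failed tasks contribute exception objects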
.venv/Lib/site-packages/plac_tk.py ADDED
@@ -0,0 +1,64 @@
1
+ from __future__ import print_function
2
+ import os
3
+ import sys
4
+ if sys.version_info < (3,):
5
+ import Queue as queue
6
+ else:
7
+ import queue
8
+ import plac_core
9
+ from Tkinter import Tk
10
+ from ScrolledText import ScrolledText
11
+ from plac_ext import Monitor, TerminatedProcess
12
+
13
+
14
+ class TkMonitor(Monitor):
15
+ """
16
+ An interface over a dictionary {taskno: scrolledtext widget}, with
17
+ methods add_listener, del_listener, notify_listener and start/stop.
18
+ """
19
+ def __init__(self, name, queue=None):
20
+ Monitor.__init__(self, name, queue)
21
+ self.widgets = {}
22
+
23
+ @plac_core.annotations(taskno=('task number', 'positional', None, int))
24
+ def add_listener(self, taskno):
25
+ "There is a ScrolledText for each task"
26
+ st = ScrolledText(self.root, height=5)
27
+ st.insert('end', 'Output of task %d\n' % taskno)
28
+ st.pack()
29
+ self.widgets[taskno] = st
30
+
31
+ @plac_core.annotations(taskno=('task number', 'positional', None, int))
32
+ def del_listener(self, taskno):
33
+ del self.widgets[taskno]
34
+
35
+ @plac_core.annotations(taskno=('task number', 'positional', None, int))
36
+ def notify_listener(self, taskno, msg):
37
+ w = self.widgets[taskno]
38
+ w.insert('end', msg + '\n')
39
+ w.update()
40
+
41
+ def start(self):
42
+ 'Start the mainloop'
43
+ self.root = Tk()
44
+ self.root.title(self.name)
45
+ self.root.wm_protocol("WM_DELETE_WINDOW", self.stop)
46
+ self.root.after(0, self.read_queue)
47
+ try:
48
+ self.root.mainloop()
49
+ except KeyboardInterrupt:
50
+ print('Process %d killed by CTRL-C' % os.getpid(), file=sys.stderr)
51
+ except TerminatedProcess:
52
+ pass
53
+
54
+ def stop(self):
55
+ self.root.quit()
56
+
57
+ def read_queue(self):
58
+ try:
59
+ cmd_args = self.queue.get_nowait()
60
+ except queue.Empty:
61
+ pass
62
+ else:
63
+ getattr(self, cmd_args[0])(*cmd_args[1:])
64
+ self.root.after(100, self.read_queue)
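A rough sketch of plugging the TkMonitor into an interpreter via the add_monitor hook defined in plac_ext above. Since this module imports Tkinter and ScrolledText under their Python 2 names, the sketch assumes a Python 2 environment (or correspondingly adjusted imports); the Hello container and its command are illustrative.

    import plac
    from plac_tk import TkMonitor

    class Hello(object):
        mpcommands = ['hello']

        def hello(self):
            yield 'hello from a subprocess'

    i = plac.Interpreter(Hello())
    i.add_monitor(TkMonitor('tkmon'))  # output of mp tasks is mirrored to Tk widgets
    i.interact()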
.venv/Lib/site-packages/pylab.py ADDED
@@ -0,0 +1,3 @@
1
+ from matplotlib.pylab import * # noqa: F401, F403
2
+ import matplotlib.pylab
3
+ __doc__ = matplotlib.pylab.__doc__
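The pylab module above is only a thin shim re-exporting matplotlib.pylab, so, assuming matplotlib is installed, the following goes through the same functions as importing matplotlib.pylab directly:

    # assuming matplotlib is installed
    import pylab

    pylab.plot([1, 2, 3])  # same object as matplotlib.pylab.plot
    pylab.show()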