ayousanz committed
Commit 27712eb · verified · 1 Parent(s): 373d77f

Add files using upload-large-folder tool
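For reference, uploads like this are typically driven by `huggingface_hub`'s large-folder uploader. The sketch below is illustrative only: it assumes a recent `huggingface_hub` release that provides `HfApi.upload_large_folder`, and the repo id and local path are placeholders, not values taken from this commit.

```python
# Minimal sketch of the upload-large-folder workflow (assumes huggingface_hub >= 0.25).
# Repo id and folder path are hypothetical placeholders.
from huggingface_hub import HfApi

api = HfApi()  # authenticates via `huggingface-cli login` or the HF_TOKEN env var
api.upload_large_folder(
    repo_id="your-username/your-repo",
    repo_type="model",
    folder_path="./local-folder",  # everything under this folder is uploaded, a bundled .venv included
)
```

The uploader splits a large folder into multiple resumable commits, which is why a single logical upload can surface as a commit touching thousands of files, as here.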

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the complete change set.
Files changed (50)
  1. .venv/Lib/site-packages/transformers-4.47.0.dist-info/INSTALLER +1 -0
  2. .venv/Lib/site-packages/transformers-4.47.0.dist-info/LICENSE +203 -0
  3. .venv/Lib/site-packages/transformers-4.47.0.dist-info/entry_points.txt +2 -0
  4. .venv/Lib/site-packages/transformers/__pycache__/__init__.cpython-39.pyc +0 -0
  5. .venv/Lib/site-packages/transformers/__pycache__/modeling_utils.cpython-39.pyc +0 -0
  6. .venv/Lib/site-packages/transformers/__pycache__/pytorch_utils.cpython-39.pyc +0 -0
  7. .venv/Lib/site-packages/transformers/__pycache__/safetensors_conversion.cpython-39.pyc +0 -0
  8. .venv/Lib/site-packages/transformers/__pycache__/tokenization_utils.cpython-39.pyc +0 -0
  9. .venv/Lib/site-packages/transformers/__pycache__/tokenization_utils_base.cpython-39.pyc +0 -0
  10. .venv/Lib/site-packages/transformers/__pycache__/tokenization_utils_fast.cpython-39.pyc +0 -0
  11. .venv/Lib/site-packages/transformers/pipelines/feature_extraction.py +86 -0
  12. .venv/Lib/site-packages/transformers/pipelines/fill_mask.py +273 -0
  13. .venv/Lib/site-packages/transformers/pipelines/image_classification.py +226 -0
  14. .venv/Lib/site-packages/transformers/pipelines/image_feature_extraction.py +112 -0
  15. .venv/Lib/site-packages/transformers/pipelines/image_segmentation.py +220 -0
  16. .venv/Lib/site-packages/transformers/pipelines/image_text_to_text.py +432 -0
  17. .venv/Lib/site-packages/transformers/pipelines/image_to_image.py +136 -0
  18. .venv/Lib/site-packages/transformers/pipelines/image_to_text.py +216 -0
  19. .venv/Lib/site-packages/transformers/pipelines/mask_generation.py +287 -0
  20. .venv/Lib/site-packages/transformers/pipelines/object_detection.py +191 -0
  21. .venv/Lib/site-packages/transformers/utils/__init__.py +315 -0
  22. .venv/Lib/site-packages/transformers/utils/__pycache__/backbone_utils.cpython-39.pyc +0 -0
  23. .venv/Lib/site-packages/transformers/utils/__pycache__/chat_template_utils.cpython-39.pyc +0 -0
  24. .venv/Lib/site-packages/transformers/utils/__pycache__/constants.cpython-39.pyc +0 -0
  25. .venv/Lib/site-packages/transformers/utils/__pycache__/deprecation.cpython-39.pyc +0 -0
  26. .venv/Lib/site-packages/transformers/utils/__pycache__/doc.cpython-39.pyc +0 -0
  27. .venv/Lib/site-packages/transformers/utils/quantization_config.py +1344 -0
  28. .venv/Lib/site-packages/transformers/utils/sentencepiece_model_pb2.py +1511 -0
  29. .venv/Lib/site-packages/transformers/utils/sentencepiece_model_pb2_new.py +48 -0
  30. .venv/Lib/site-packages/transformers/utils/versions.py +117 -0
  31. .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Abidjan +0 -0
  32. .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Accra +0 -0
  33. .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Addis_Ababa +0 -0
  34. .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Algiers +0 -0
  35. .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Asmara +0 -0
  36. .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Asmera +0 -0
  37. .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Bamako +0 -0
  38. .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Bangui +0 -0
  39. .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Dar_es_Salaam +0 -0
  40. .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Djibouti +0 -0
  41. .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Douala +0 -0
  42. .venv/Lib/site-packages/tzdata/zoneinfo/Africa/El_Aaiun +0 -0
  43. .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Freetown +0 -0
  44. .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Gaborone +0 -0
  45. .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Harare +0 -0
  46. .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Johannesburg +0 -0
  47. .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Juba +0 -0
  48. .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Kampala +0 -0
  49. .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Khartoum +0 -0
  50. .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Kigali +0 -0
.venv/Lib/site-packages/transformers-4.47.0.dist-info/INSTALLER ADDED
@@ -0,0 +1 @@
1
+ uv
.venv/Lib/site-packages/transformers-4.47.0.dist-info/LICENSE ADDED
@@ -0,0 +1,203 @@
1
+ Copyright 2018- The Hugging Face team. All rights reserved.
2
+
3
+ Apache License
4
+ Version 2.0, January 2004
5
+ http://www.apache.org/licenses/
6
+
7
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
8
+
9
+ 1. Definitions.
10
+
11
+ "License" shall mean the terms and conditions for use, reproduction,
12
+ and distribution as defined by Sections 1 through 9 of this document.
13
+
14
+ "Licensor" shall mean the copyright owner or entity authorized by
15
+ the copyright owner that is granting the License.
16
+
17
+ "Legal Entity" shall mean the union of the acting entity and all
18
+ other entities that control, are controlled by, or are under common
19
+ control with that entity. For the purposes of this definition,
20
+ "control" means (i) the power, direct or indirect, to cause the
21
+ direction or management of such entity, whether by contract or
22
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
23
+ outstanding shares, or (iii) beneficial ownership of such entity.
24
+
25
+ "You" (or "Your") shall mean an individual or Legal Entity
26
+ exercising permissions granted by this License.
27
+
28
+ "Source" form shall mean the preferred form for making modifications,
29
+ including but not limited to software source code, documentation
30
+ source, and configuration files.
31
+
32
+ "Object" form shall mean any form resulting from mechanical
33
+ transformation or translation of a Source form, including but
34
+ not limited to compiled object code, generated documentation,
35
+ and conversions to other media types.
36
+
37
+ "Work" shall mean the work of authorship, whether in Source or
38
+ Object form, made available under the License, as indicated by a
39
+ copyright notice that is included in or attached to the work
40
+ (an example is provided in the Appendix below).
41
+
42
+ "Derivative Works" shall mean any work, whether in Source or Object
43
+ form, that is based on (or derived from) the Work and for which the
44
+ editorial revisions, annotations, elaborations, or other modifications
45
+ represent, as a whole, an original work of authorship. For the purposes
46
+ of this License, Derivative Works shall not include works that remain
47
+ separable from, or merely link (or bind by name) to the interfaces of,
48
+ the Work and Derivative Works thereof.
49
+
50
+ "Contribution" shall mean any work of authorship, including
51
+ the original version of the Work and any modifications or additions
52
+ to that Work or Derivative Works thereof, that is intentionally
53
+ submitted to Licensor for inclusion in the Work by the copyright owner
54
+ or by an individual or Legal Entity authorized to submit on behalf of
55
+ the copyright owner. For the purposes of this definition, "submitted"
56
+ means any form of electronic, verbal, or written communication sent
57
+ to the Licensor or its representatives, including but not limited to
58
+ communication on electronic mailing lists, source code control systems,
59
+ and issue tracking systems that are managed by, or on behalf of, the
60
+ Licensor for the purpose of discussing and improving the Work, but
61
+ excluding communication that is conspicuously marked or otherwise
62
+ designated in writing by the copyright owner as "Not a Contribution."
63
+
64
+ "Contributor" shall mean Licensor and any individual or Legal Entity
65
+ on behalf of whom a Contribution has been received by Licensor and
66
+ subsequently incorporated within the Work.
67
+
68
+ 2. Grant of Copyright License. Subject to the terms and conditions of
69
+ this License, each Contributor hereby grants to You a perpetual,
70
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
71
+ copyright license to reproduce, prepare Derivative Works of,
72
+ publicly display, publicly perform, sublicense, and distribute the
73
+ Work and such Derivative Works in Source or Object form.
74
+
75
+ 3. Grant of Patent License. Subject to the terms and conditions of
76
+ this License, each Contributor hereby grants to You a perpetual,
77
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
78
+ (except as stated in this section) patent license to make, have made,
79
+ use, offer to sell, sell, import, and otherwise transfer the Work,
80
+ where such license applies only to those patent claims licensable
81
+ by such Contributor that are necessarily infringed by their
82
+ Contribution(s) alone or by combination of their Contribution(s)
83
+ with the Work to which such Contribution(s) was submitted. If You
84
+ institute patent litigation against any entity (including a
85
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
86
+ or a Contribution incorporated within the Work constitutes direct
87
+ or contributory patent infringement, then any patent licenses
88
+ granted to You under this License for that Work shall terminate
89
+ as of the date such litigation is filed.
90
+
91
+ 4. Redistribution. You may reproduce and distribute copies of the
92
+ Work or Derivative Works thereof in any medium, with or without
93
+ modifications, and in Source or Object form, provided that You
94
+ meet the following conditions:
95
+
96
+ (a) You must give any other recipients of the Work or
97
+ Derivative Works a copy of this License; and
98
+
99
+ (b) You must cause any modified files to carry prominent notices
100
+ stating that You changed the files; and
101
+
102
+ (c) You must retain, in the Source form of any Derivative Works
103
+ that You distribute, all copyright, patent, trademark, and
104
+ attribution notices from the Source form of the Work,
105
+ excluding those notices that do not pertain to any part of
106
+ the Derivative Works; and
107
+
108
+ (d) If the Work includes a "NOTICE" text file as part of its
109
+ distribution, then any Derivative Works that You distribute must
110
+ include a readable copy of the attribution notices contained
111
+ within such NOTICE file, excluding those notices that do not
112
+ pertain to any part of the Derivative Works, in at least one
113
+ of the following places: within a NOTICE text file distributed
114
+ as part of the Derivative Works; within the Source form or
115
+ documentation, if provided along with the Derivative Works; or,
116
+ within a display generated by the Derivative Works, if and
117
+ wherever such third-party notices normally appear. The contents
118
+ of the NOTICE file are for informational purposes only and
119
+ do not modify the License. You may add Your own attribution
120
+ notices within Derivative Works that You distribute, alongside
121
+ or as an addendum to the NOTICE text from the Work, provided
122
+ that such additional attribution notices cannot be construed
123
+ as modifying the License.
124
+
125
+ You may add Your own copyright statement to Your modifications and
126
+ may provide additional or different license terms and conditions
127
+ for use, reproduction, or distribution of Your modifications, or
128
+ for any such Derivative Works as a whole, provided Your use,
129
+ reproduction, and distribution of the Work otherwise complies with
130
+ the conditions stated in this License.
131
+
132
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
133
+ any Contribution intentionally submitted for inclusion in the Work
134
+ by You to the Licensor shall be under the terms and conditions of
135
+ this License, without any additional terms or conditions.
136
+ Notwithstanding the above, nothing herein shall supersede or modify
137
+ the terms of any separate license agreement you may have executed
138
+ with Licensor regarding such Contributions.
139
+
140
+ 6. Trademarks. This License does not grant permission to use the trade
141
+ names, trademarks, service marks, or product names of the Licensor,
142
+ except as required for reasonable and customary use in describing the
143
+ origin of the Work and reproducing the content of the NOTICE file.
144
+
145
+ 7. Disclaimer of Warranty. Unless required by applicable law or
146
+ agreed to in writing, Licensor provides the Work (and each
147
+ Contributor provides its Contributions) on an "AS IS" BASIS,
148
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
149
+ implied, including, without limitation, any warranties or conditions
150
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
151
+ PARTICULAR PURPOSE. You are solely responsible for determining the
152
+ appropriateness of using or redistributing the Work and assume any
153
+ risks associated with Your exercise of permissions under this License.
154
+
155
+ 8. Limitation of Liability. In no event and under no legal theory,
156
+ whether in tort (including negligence), contract, or otherwise,
157
+ unless required by applicable law (such as deliberate and grossly
158
+ negligent acts) or agreed to in writing, shall any Contributor be
159
+ liable to You for damages, including any direct, indirect, special,
160
+ incidental, or consequential damages of any character arising as a
161
+ result of this License or out of the use or inability to use the
162
+ Work (including but not limited to damages for loss of goodwill,
163
+ work stoppage, computer failure or malfunction, or any and all
164
+ other commercial damages or losses), even if such Contributor
165
+ has been advised of the possibility of such damages.
166
+
167
+ 9. Accepting Warranty or Additional Liability. While redistributing
168
+ the Work or Derivative Works thereof, You may choose to offer,
169
+ and charge a fee for, acceptance of support, warranty, indemnity,
170
+ or other liability obligations and/or rights consistent with this
171
+ License. However, in accepting such obligations, You may act only
172
+ on Your own behalf and on Your sole responsibility, not on behalf
173
+ of any other Contributor, and only if You agree to indemnify,
174
+ defend, and hold each Contributor harmless for any liability
175
+ incurred by, or claims asserted against, such Contributor by reason
176
+ of your accepting any such warranty or additional liability.
177
+
178
+ END OF TERMS AND CONDITIONS
179
+
180
+ APPENDIX: How to apply the Apache License to your work.
181
+
182
+ To apply the Apache License to your work, attach the following
183
+ boilerplate notice, with the fields enclosed by brackets "[]"
184
+ replaced with your own identifying information. (Don't include
185
+ the brackets!) The text should be enclosed in the appropriate
186
+ comment syntax for the file format. We also recommend that a
187
+ file or class name and description of purpose be included on the
188
+ same "printed page" as the copyright notice for easier
189
+ identification within third-party archives.
190
+
191
+ Copyright [yyyy] [name of copyright owner]
192
+
193
+ Licensed under the Apache License, Version 2.0 (the "License");
194
+ you may not use this file except in compliance with the License.
195
+ You may obtain a copy of the License at
196
+
197
+ http://www.apache.org/licenses/LICENSE-2.0
198
+
199
+ Unless required by applicable law or agreed to in writing, software
200
+ distributed under the License is distributed on an "AS IS" BASIS,
201
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
202
+ See the License for the specific language governing permissions and
203
+ limitations under the License.
.venv/Lib/site-packages/transformers-4.47.0.dist-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
1
+ [console_scripts]
2
+ transformers-cli = transformers.commands.transformers_cli:main
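
The entry point above is what exposes the `transformers-cli` command after installation: the generated console script simply imports and runs `main` from the module named on the right-hand side. A rough, illustrative equivalent (module path taken directly from the entry point; the subcommand in the comment is just an example):

```python
# Rough equivalent of the `transformers-cli` console script generated from entry_points.txt.
from transformers.commands.transformers_cli import main

if __name__ == "__main__":
    main()  # parses sys.argv, e.g. `transformers-cli env` prints environment info for bug reports
```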
.venv/Lib/site-packages/transformers/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (166 kB).
 
.venv/Lib/site-packages/transformers/__pycache__/modeling_utils.cpython-39.pyc ADDED
Binary file (169 kB).
 
.venv/Lib/site-packages/transformers/__pycache__/pytorch_utils.cpython-39.pyc ADDED
Binary file (12.5 kB).
 
.venv/Lib/site-packages/transformers/__pycache__/safetensors_conversion.cpython-39.pyc ADDED
Binary file (3.3 kB).
 
.venv/Lib/site-packages/transformers/__pycache__/tokenization_utils.cpython-39.pyc ADDED
Binary file (32.4 kB).
 
.venv/Lib/site-packages/transformers/__pycache__/tokenization_utils_base.cpython-39.pyc ADDED
Binary file (149 kB).
 
.venv/Lib/site-packages/transformers/__pycache__/tokenization_utils_fast.cpython-39.pyc ADDED
Binary file (27.7 kB).
 
.venv/Lib/site-packages/transformers/pipelines/feature_extraction.py ADDED
@@ -0,0 +1,86 @@
1
+ from typing import Dict
2
+
3
+ from ..utils import add_end_docstrings
4
+ from .base import GenericTensor, Pipeline, build_pipeline_init_args
5
+
6
+
7
+ @add_end_docstrings(
8
+ build_pipeline_init_args(has_tokenizer=True, supports_binary_output=False),
9
+ r"""
10
+ tokenize_kwargs (`dict`, *optional*):
11
+ Additional dictionary of keyword arguments passed along to the tokenizer.
12
+ return_tensors (`bool`, *optional*):
13
+ If `True`, returns a tensor according to the specified framework, otherwise returns a list.""",
14
+ )
15
+ class FeatureExtractionPipeline(Pipeline):
16
+ """
17
+ Feature extraction pipeline uses no model head. This pipeline extracts the hidden states from the base
18
+ transformer, which can be used as features in downstream tasks.
19
+
20
+ Example:
21
+
22
+ ```python
23
+ >>> from transformers import pipeline
24
+
25
+ >>> extractor = pipeline(model="google-bert/bert-base-uncased", task="feature-extraction")
26
+ >>> result = extractor("This is a simple test.", return_tensors=True)
27
+ >>> result.shape # This is a tensor of shape [1, sequence_length, hidden_dimension] representing the input string.
28
+ torch.Size([1, 8, 768])
29
+ ```
30
+
31
+ Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
32
+
33
+ This feature extraction pipeline can currently be loaded from [`pipeline`] using the task identifier:
34
+ `"feature-extraction"`.
35
+
36
+ All models may be used for this pipeline. See a list of all models, including community-contributed models on
37
+ [huggingface.co/models](https://huggingface.co/models).
38
+ """
39
+
40
+ def _sanitize_parameters(self, truncation=None, tokenize_kwargs=None, return_tensors=None, **kwargs):
41
+ if tokenize_kwargs is None:
42
+ tokenize_kwargs = {}
43
+
44
+ if truncation is not None:
45
+ if "truncation" in tokenize_kwargs:
46
+ raise ValueError(
47
+ "truncation parameter defined twice (given as keyword argument as well as in tokenize_kwargs)"
48
+ )
49
+ tokenize_kwargs["truncation"] = truncation
50
+
51
+ preprocess_params = tokenize_kwargs
52
+
53
+ postprocess_params = {}
54
+ if return_tensors is not None:
55
+ postprocess_params["return_tensors"] = return_tensors
56
+
57
+ return preprocess_params, {}, postprocess_params
58
+
59
+ def preprocess(self, inputs, **tokenize_kwargs) -> Dict[str, GenericTensor]:
60
+ model_inputs = self.tokenizer(inputs, return_tensors=self.framework, **tokenize_kwargs)
61
+ return model_inputs
62
+
63
+ def _forward(self, model_inputs):
64
+ model_outputs = self.model(**model_inputs)
65
+ return model_outputs
66
+
67
+ def postprocess(self, model_outputs, return_tensors=False):
68
+ # [0] is the first available tensor, logits or last_hidden_state.
69
+ if return_tensors:
70
+ return model_outputs[0]
71
+ if self.framework == "pt":
72
+ return model_outputs[0].tolist()
73
+ elif self.framework == "tf":
74
+ return model_outputs[0].numpy().tolist()
75
+
76
+ def __call__(self, *args, **kwargs):
77
+ """
78
+ Extract the features of the input(s).
79
+
80
+ Args:
81
+ args (`str` or `List[str]`): One or several texts (or one list of texts) to get the features of.
82
+
83
+ Return:
84
+ A nested list of `float`: The features computed by the model.
85
+ """
86
+ return super().__call__(*args, **kwargs)
.venv/Lib/site-packages/transformers/pipelines/fill_mask.py ADDED
@@ -0,0 +1,273 @@
1
+ from typing import Dict
2
+
3
+ import numpy as np
4
+
5
+ from ..utils import add_end_docstrings, is_tf_available, is_torch_available, logging
6
+ from .base import GenericTensor, Pipeline, PipelineException, build_pipeline_init_args
7
+
8
+
9
+ if is_tf_available():
10
+ import tensorflow as tf
11
+
12
+ from ..tf_utils import stable_softmax
13
+
14
+
15
+ if is_torch_available():
16
+ import torch
17
+
18
+
19
+ logger = logging.get_logger(__name__)
20
+
21
+
22
+ @add_end_docstrings(
23
+ build_pipeline_init_args(has_tokenizer=True),
24
+ r"""
25
+ top_k (`int`, *optional*, defaults to 5):
26
+ The number of predictions to return.
27
+ targets (`str` or `List[str]`, *optional*):
28
+ When passed, the model will limit the scores to the passed targets instead of looking up in the whole
29
+ vocab. If the provided targets are not in the model vocab, they will be tokenized and the first resulting
30
+ token will be used (with a warning, and that might be slower).
31
+ tokenizer_kwargs (`dict`, *optional*):
32
+ Additional dictionary of keyword arguments passed along to the tokenizer.""",
33
+ )
34
+ class FillMaskPipeline(Pipeline):
35
+ """
36
+ Masked language modeling prediction pipeline using any `ModelWithLMHead`. See the [masked language modeling
37
+ examples](../task_summary#masked-language-modeling) for more information.
38
+
39
+ Example:
40
+
41
+ ```python
42
+ >>> from transformers import pipeline
43
+
44
+ >>> fill_masker = pipeline(model="google-bert/bert-base-uncased")
45
+ >>> fill_masker("This is a simple [MASK].")
46
+ [{'score': 0.042, 'token': 3291, 'token_str': 'problem', 'sequence': 'this is a simple problem.'}, {'score': 0.031, 'token': 3160, 'token_str': 'question', 'sequence': 'this is a simple question.'}, {'score': 0.03, 'token': 8522, 'token_str': 'equation', 'sequence': 'this is a simple equation.'}, {'score': 0.027, 'token': 2028, 'token_str': 'one', 'sequence': 'this is a simple one.'}, {'score': 0.024, 'token': 3627, 'token_str': 'rule', 'sequence': 'this is a simple rule.'}]
47
+ ```
48
+
49
+ Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
50
+
51
+ This mask filling pipeline can currently be loaded from [`pipeline`] using the following task identifier:
52
+ `"fill-mask"`.
53
+
54
+ The models that this pipeline can use are models that have been trained with a masked language modeling objective,
55
+ which includes the bi-directional models in the library. See the up-to-date list of available models on
56
+ [huggingface.co/models](https://huggingface.co/models?filter=fill-mask).
57
+
58
+ <Tip>
59
+
60
+ This pipeline only works for inputs with exactly one token masked. Experimental: We added support for multiple
61
+ masks. The returned values are raw model output, and correspond to disjoint probabilities where one might expect
62
+ joint probabilities (See [discussion](https://github.com/huggingface/transformers/pull/10222)).
63
+
64
+ </Tip>
65
+
66
+ <Tip>
67
+
68
+ This pipeline now supports tokenizer_kwargs. For example try:
69
+
70
+ ```python
71
+ >>> from transformers import pipeline
72
+
73
+ >>> fill_masker = pipeline(model="google-bert/bert-base-uncased")
74
+ >>> tokenizer_kwargs = {"truncation": True}
75
+ >>> fill_masker(
76
+ ... "This is a simple [MASK]. " + "...with a large amount of repeated text appended. " * 100,
77
+ ... tokenizer_kwargs=tokenizer_kwargs,
78
+ ... )
79
+ ```
80
+
81
+
82
+ </Tip>
83
+
84
+
85
+ """
86
+
87
+ def get_masked_index(self, input_ids: GenericTensor) -> np.ndarray:
88
+ if self.framework == "tf":
89
+ masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy()
90
+ elif self.framework == "pt":
91
+ masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False)
92
+ else:
93
+ raise ValueError("Unsupported framework")
94
+ return masked_index
95
+
96
+ def _ensure_exactly_one_mask_token(self, input_ids: GenericTensor) -> np.ndarray:
97
+ masked_index = self.get_masked_index(input_ids)
98
+ numel = np.prod(masked_index.shape)
99
+ if numel < 1:
100
+ raise PipelineException(
101
+ "fill-mask",
102
+ self.model.base_model_prefix,
103
+ f"No mask_token ({self.tokenizer.mask_token}) found on the input",
104
+ )
105
+
106
+ def ensure_exactly_one_mask_token(self, model_inputs: GenericTensor):
107
+ if isinstance(model_inputs, list):
108
+ for model_input in model_inputs:
109
+ self._ensure_exactly_one_mask_token(model_input["input_ids"][0])
110
+ else:
111
+ for input_ids in model_inputs["input_ids"]:
112
+ self._ensure_exactly_one_mask_token(input_ids)
113
+
114
+ def preprocess(
115
+ self, inputs, return_tensors=None, tokenizer_kwargs=None, **preprocess_parameters
116
+ ) -> Dict[str, GenericTensor]:
117
+ if return_tensors is None:
118
+ return_tensors = self.framework
119
+ if tokenizer_kwargs is None:
120
+ tokenizer_kwargs = {}
121
+
122
+ model_inputs = self.tokenizer(inputs, return_tensors=return_tensors, **tokenizer_kwargs)
123
+ self.ensure_exactly_one_mask_token(model_inputs)
124
+ return model_inputs
125
+
126
+ def _forward(self, model_inputs):
127
+ model_outputs = self.model(**model_inputs)
128
+ model_outputs["input_ids"] = model_inputs["input_ids"]
129
+ return model_outputs
130
+
131
+ def postprocess(self, model_outputs, top_k=5, target_ids=None):
132
+ # Cap top_k if there are targets
133
+ if target_ids is not None and target_ids.shape[0] < top_k:
134
+ top_k = target_ids.shape[0]
135
+ input_ids = model_outputs["input_ids"][0]
136
+ outputs = model_outputs["logits"]
137
+
138
+ if self.framework == "tf":
139
+ masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy()[:, 0]
140
+
141
+ outputs = outputs.numpy()
142
+
143
+ logits = outputs[0, masked_index, :]
144
+ probs = stable_softmax(logits, axis=-1)
145
+ if target_ids is not None:
146
+ probs = tf.gather_nd(tf.squeeze(probs, 0), target_ids.reshape(-1, 1))
147
+ probs = tf.expand_dims(probs, 0)
148
+
149
+ topk = tf.math.top_k(probs, k=top_k)
150
+ values, predictions = topk.values.numpy(), topk.indices.numpy()
151
+ else:
152
+ masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False).squeeze(-1)
153
+ # Fill mask pipeline supports only one ${mask_token} per sample
154
+
155
+ logits = outputs[0, masked_index, :]
156
+ probs = logits.softmax(dim=-1)
157
+ if target_ids is not None:
158
+ probs = probs[..., target_ids]
159
+
160
+ values, predictions = probs.topk(top_k)
161
+
162
+ result = []
163
+ single_mask = values.shape[0] == 1
164
+ for i, (_values, _predictions) in enumerate(zip(values.tolist(), predictions.tolist())):
165
+ row = []
166
+ for v, p in zip(_values, _predictions):
167
+ # Copy is important since we're going to modify this array in place
168
+ tokens = input_ids.numpy().copy()
169
+ if target_ids is not None:
170
+ p = target_ids[p].tolist()
171
+
172
+ tokens[masked_index[i]] = p
173
+ # Filter padding out:
174
+ tokens = tokens[np.where(tokens != self.tokenizer.pad_token_id)]
175
+ # Originally we skip special tokens to give readable output.
176
+ # For multi masks though, the other [MASK] would be removed otherwise
177
+ # making the output look odd, so we add them back
178
+ sequence = self.tokenizer.decode(tokens, skip_special_tokens=single_mask)
179
+ proposition = {"score": v, "token": p, "token_str": self.tokenizer.decode([p]), "sequence": sequence}
180
+ row.append(proposition)
181
+ result.append(row)
182
+ if single_mask:
183
+ return result[0]
184
+ return result
185
+
186
+ def get_target_ids(self, targets, top_k=None):
187
+ if isinstance(targets, str):
188
+ targets = [targets]
189
+ try:
190
+ vocab = self.tokenizer.get_vocab()
191
+ except Exception:
192
+ vocab = {}
193
+ target_ids = []
194
+ for target in targets:
195
+ id_ = vocab.get(target, None)
196
+ if id_ is None:
197
+ input_ids = self.tokenizer(
198
+ target,
199
+ add_special_tokens=False,
200
+ return_attention_mask=False,
201
+ return_token_type_ids=False,
202
+ max_length=1,
203
+ truncation=True,
204
+ )["input_ids"]
205
+ if len(input_ids) == 0:
206
+ logger.warning(
207
+ f"The specified target token `{target}` does not exist in the model vocabulary. "
208
+ "We cannot replace it with anything meaningful, ignoring it"
209
+ )
210
+ continue
211
+ id_ = input_ids[0]
212
+ # XXX: If users encounter this pass
213
+ # it becomes pretty slow, so let's make sure
214
+ # The warning enables them to fix the input to
215
+ # get faster performance.
216
+ logger.warning(
217
+ f"The specified target token `{target}` does not exist in the model vocabulary. "
218
+ f"Replacing with `{self.tokenizer.convert_ids_to_tokens(id_)}`."
219
+ )
220
+ target_ids.append(id_)
221
+ target_ids = list(set(target_ids))
222
+ if len(target_ids) == 0:
223
+ raise ValueError("At least one target must be provided when passed.")
224
+ target_ids = np.array(target_ids)
225
+ return target_ids
226
+
227
+ def _sanitize_parameters(self, top_k=None, targets=None, tokenizer_kwargs=None):
228
+ preprocess_params = {}
229
+
230
+ if tokenizer_kwargs is not None:
231
+ preprocess_params["tokenizer_kwargs"] = tokenizer_kwargs
232
+
233
+ postprocess_params = {}
234
+
235
+ if targets is not None:
236
+ target_ids = self.get_target_ids(targets, top_k)
237
+ postprocess_params["target_ids"] = target_ids
238
+
239
+ if top_k is not None:
240
+ postprocess_params["top_k"] = top_k
241
+
242
+ if self.tokenizer.mask_token_id is None:
243
+ raise PipelineException(
244
+ "fill-mask", self.model.base_model_prefix, "The tokenizer does not define a `mask_token`."
245
+ )
246
+ return preprocess_params, {}, postprocess_params
247
+
248
+ def __call__(self, inputs, **kwargs):
249
+ """
250
+ Fill the masked token in the text(s) given as inputs.
251
+
252
+ Args:
253
+ inputs (`str` or `List[str]`):
254
+ One or several texts (or one list of prompts) with masked tokens.
255
+ targets (`str` or `List[str]`, *optional*):
256
+ When passed, the model will limit the scores to the passed targets instead of looking up in the whole
257
+ vocab. If the provided targets are not in the model vocab, they will be tokenized and the first
258
+ resulting token will be used (with a warning, and that might be slower).
259
+ top_k (`int`, *optional*):
260
+ When passed, overrides the number of predictions to return.
261
+
262
+ Return:
263
+ A list or a list of list of `dict`: Each result comes as list of dictionaries with the following keys:
264
+
265
+ - **sequence** (`str`) -- The corresponding input with the mask token prediction.
266
+ - **score** (`float`) -- The corresponding probability.
267
+ - **token** (`int`) -- The predicted token id (to replace the masked one).
268
+ - **token_str** (`str`) -- The predicted token (to replace the masked one).
269
+ """
270
+ outputs = super().__call__(inputs, **kwargs)
271
+ if isinstance(inputs, list) and len(inputs) == 1:
272
+ return outputs[0]
273
+ return outputs
.venv/Lib/site-packages/transformers/pipelines/image_classification.py ADDED
@@ -0,0 +1,226 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import List, Union
15
+
16
+ import numpy as np
17
+
18
+ from ..utils import (
19
+ ExplicitEnum,
20
+ add_end_docstrings,
21
+ is_tf_available,
22
+ is_torch_available,
23
+ is_vision_available,
24
+ logging,
25
+ requires_backends,
26
+ )
27
+ from .base import Pipeline, build_pipeline_init_args
28
+
29
+
30
+ if is_vision_available():
31
+ from PIL import Image
32
+
33
+ from ..image_utils import load_image
34
+
35
+ if is_tf_available():
36
+ from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
37
+
38
+ if is_torch_available():
39
+ import torch
40
+
41
+ from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
42
+
43
+ logger = logging.get_logger(__name__)
44
+
45
+
46
+ # Copied from transformers.pipelines.text_classification.sigmoid
47
+ def sigmoid(_outputs):
48
+ return 1.0 / (1.0 + np.exp(-_outputs))
49
+
50
+
51
+ # Copied from transformers.pipelines.text_classification.softmax
52
+ def softmax(_outputs):
53
+ maxes = np.max(_outputs, axis=-1, keepdims=True)
54
+ shifted_exp = np.exp(_outputs - maxes)
55
+ return shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)
56
+
57
+
58
+ # Copied from transformers.pipelines.text_classification.ClassificationFunction
59
+ class ClassificationFunction(ExplicitEnum):
60
+ SIGMOID = "sigmoid"
61
+ SOFTMAX = "softmax"
62
+ NONE = "none"
63
+
64
+
65
+ @add_end_docstrings(
66
+ build_pipeline_init_args(has_image_processor=True),
67
+ r"""
68
+ function_to_apply (`str`, *optional*, defaults to `"default"`):
69
+ The function to apply to the model outputs in order to retrieve the scores. Accepts four different values:
70
+
71
+ - `"default"`: if the model has a single label, will apply the sigmoid function on the output. If the model
72
+ has several labels, will apply the softmax function on the output.
73
+ - `"sigmoid"`: Applies the sigmoid function on the output.
74
+ - `"softmax"`: Applies the softmax function on the output.
75
+ - `"none"`: Does not apply any function on the output.""",
76
+ )
77
+ class ImageClassificationPipeline(Pipeline):
78
+ """
79
+ Image classification pipeline using any `AutoModelForImageClassification`. This pipeline predicts the class of an
80
+ image.
81
+
82
+ Example:
83
+
84
+ ```python
85
+ >>> from transformers import pipeline
86
+
87
+ >>> classifier = pipeline(model="microsoft/beit-base-patch16-224-pt22k-ft22k")
88
+ >>> classifier("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png")
89
+ [{'score': 0.442, 'label': 'macaw'}, {'score': 0.088, 'label': 'popinjay'}, {'score': 0.075, 'label': 'parrot'}, {'score': 0.073, 'label': 'parodist, lampooner'}, {'score': 0.046, 'label': 'poll, poll_parrot'}]
90
+ ```
91
+
92
+ Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
93
+
94
+ This image classification pipeline can currently be loaded from [`pipeline`] using the following task identifier:
95
+ `"image-classification"`.
96
+
97
+ See the list of available models on
98
+ [huggingface.co/models](https://huggingface.co/models?filter=image-classification).
99
+ """
100
+
101
+ function_to_apply: ClassificationFunction = ClassificationFunction.NONE
102
+
103
+ def __init__(self, *args, **kwargs):
104
+ super().__init__(*args, **kwargs)
105
+ requires_backends(self, "vision")
106
+ self.check_model_type(
107
+ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
108
+ if self.framework == "tf"
109
+ else MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
110
+ )
111
+
112
+ def _sanitize_parameters(self, top_k=None, function_to_apply=None, timeout=None):
113
+ preprocess_params = {}
114
+ if timeout is not None:
115
+ preprocess_params["timeout"] = timeout
116
+ postprocess_params = {}
117
+ if top_k is not None:
118
+ postprocess_params["top_k"] = top_k
119
+ if isinstance(function_to_apply, str):
120
+ function_to_apply = ClassificationFunction(function_to_apply.lower())
121
+ if function_to_apply is not None:
122
+ postprocess_params["function_to_apply"] = function_to_apply
123
+ return preprocess_params, {}, postprocess_params
124
+
125
+ def __call__(self, inputs: Union[str, List[str], "Image.Image", List["Image.Image"]] = None, **kwargs):
126
+ """
127
+ Assign labels to the image(s) passed as inputs.
128
+
129
+ Args:
130
+ inputs (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
131
+ The pipeline handles three types of images:
132
+
133
+ - A string containing a http link pointing to an image
134
+ - A string containing a local path to an image
135
+ - An image loaded in PIL directly
136
+
137
+ The pipeline accepts either a single image or a batch of images, which must then be passed as a string.
138
+ Images in a batch must all be in the same format: all as http links, all as local paths, or all as PIL
139
+ images.
140
+ function_to_apply (`str`, *optional*, defaults to `"default"`):
141
+ The function to apply to the model outputs in order to retrieve the scores. Accepts four different
142
+ values:
143
+
144
+ If this argument is not specified, then it will apply the following functions according to the number
145
+ of labels:
146
+
147
+ - If the model has a single label, will apply the sigmoid function on the output.
148
+ - If the model has several labels, will apply the softmax function on the output.
149
+
150
+ Possible values are:
151
+
152
+ - `"sigmoid"`: Applies the sigmoid function on the output.
153
+ - `"softmax"`: Applies the softmax function on the output.
154
+ - `"none"`: Does not apply any function on the output.
155
+ top_k (`int`, *optional*, defaults to 5):
156
+ The number of top labels that will be returned by the pipeline. If the provided number is higher than
157
+ the number of labels available in the model configuration, it will default to the number of labels.
158
+ timeout (`float`, *optional*, defaults to None):
159
+ The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
160
+ the call may block forever.
161
+
162
+ Return:
163
+ A dictionary or a list of dictionaries containing result. If the input is a single image, will return a
164
+ dictionary, if the input is a list of several images, will return a list of dictionaries corresponding to
165
+ the images.
166
+
167
+ The dictionaries contain the following keys:
168
+
169
+ - **label** (`str`) -- The label identified by the model.
170
+ - **score** (`int`) -- The score attributed by the model for that label.
171
+ """
172
+ # Once this deprecation is completed, remove the default `None` value for `images`
173
+ if "images" in kwargs:
174
+ inputs = kwargs.pop("images")
175
+ if inputs is None:
176
+ raise ValueError("Cannot call the image-classification pipeline without an inputs argument!")
177
+ return super().__call__(inputs, **kwargs)
178
+
179
+ def preprocess(self, image, timeout=None):
180
+ image = load_image(image, timeout=timeout)
181
+ model_inputs = self.image_processor(images=image, return_tensors=self.framework)
182
+ if self.framework == "pt":
183
+ model_inputs = model_inputs.to(self.torch_dtype)
184
+ return model_inputs
185
+
186
+ def _forward(self, model_inputs):
187
+ model_outputs = self.model(**model_inputs)
188
+ return model_outputs
189
+
190
+ def postprocess(self, model_outputs, function_to_apply=None, top_k=5):
191
+ if function_to_apply is None:
192
+ if self.model.config.problem_type == "single_label_classification" or self.model.config.num_labels == 1:
193
+ function_to_apply = ClassificationFunction.SIGMOID
194
+ elif self.model.config.problem_type == "multi_label_classification" or self.model.config.num_labels > 1:
195
+ function_to_apply = ClassificationFunction.SOFTMAX
196
+ elif hasattr(self.model.config, "function_to_apply") and function_to_apply is None:
197
+ function_to_apply = self.model.config.function_to_apply
198
+ else:
199
+ function_to_apply = ClassificationFunction.NONE
200
+
201
+ if top_k > self.model.config.num_labels:
202
+ top_k = self.model.config.num_labels
203
+
204
+ outputs = model_outputs["logits"][0]
205
+ if self.framework == "pt" and outputs.dtype in (torch.bfloat16, torch.float16):
206
+ outputs = outputs.to(torch.float32).numpy()
207
+ else:
208
+ outputs = outputs.numpy()
209
+
210
+ if function_to_apply == ClassificationFunction.SIGMOID:
211
+ scores = sigmoid(outputs)
212
+ elif function_to_apply == ClassificationFunction.SOFTMAX:
213
+ scores = softmax(outputs)
214
+ elif function_to_apply == ClassificationFunction.NONE:
215
+ scores = outputs
216
+ else:
217
+ raise ValueError(f"Unrecognized `function_to_apply` argument: {function_to_apply}")
218
+
219
+ dict_scores = [
220
+ {"label": self.model.config.id2label[i], "score": score.item()} for i, score in enumerate(scores)
221
+ ]
222
+ dict_scores.sort(key=lambda x: x["score"], reverse=True)
223
+ if top_k is not None:
224
+ dict_scores = dict_scores[:top_k]
225
+
226
+ return dict_scores
.venv/Lib/site-packages/transformers/pipelines/image_feature_extraction.py ADDED
@@ -0,0 +1,112 @@
1
+ from typing import Dict
2
+
3
+ from ..utils import add_end_docstrings, is_vision_available
4
+ from .base import GenericTensor, Pipeline, build_pipeline_init_args
5
+
6
+
7
+ if is_vision_available():
8
+ from ..image_utils import load_image
9
+
10
+
11
+ @add_end_docstrings(
12
+ build_pipeline_init_args(has_image_processor=True),
13
+ """
14
+ image_processor_kwargs (`dict`, *optional*):
15
+ Additional dictionary of keyword arguments passed along to the image processor e.g.
16
+ {"size": {"height": 100, "width": 100}}
17
+ pool (`bool`, *optional*, defaults to `False`):
18
+ Whether or not to return the pooled output. If `False`, the model will return the raw hidden states.
19
+ """,
20
+ )
21
+ class ImageFeatureExtractionPipeline(Pipeline):
22
+ """
23
+ Image feature extraction pipeline uses no model head. This pipeline extracts the hidden states from the base
24
+ transformer, which can be used as features in downstream tasks.
25
+
26
+ Example:
27
+
28
+ ```python
29
+ >>> from transformers import pipeline
30
+
31
+ >>> extractor = pipeline(model="google/vit-base-patch16-224", task="image-feature-extraction")
32
+ >>> result = extractor("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png", return_tensors=True)
33
+ >>> result.shape # This is a tensor of shape [1, sequence_length, hidden_dimension] representing the input image.
34
+ torch.Size([1, 197, 768])
35
+ ```
36
+
37
+ Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
38
+
39
+ This image feature extraction pipeline can currently be loaded from [`pipeline`] using the task identifier:
40
+ `"image-feature-extraction"`.
41
+
42
+ All vision models may be used for this pipeline. See a list of all models, including community-contributed models on
43
+ [huggingface.co/models](https://huggingface.co/models).
44
+ """
45
+
46
+ def _sanitize_parameters(self, image_processor_kwargs=None, return_tensors=None, pool=None, **kwargs):
47
+ preprocess_params = {} if image_processor_kwargs is None else image_processor_kwargs
48
+
49
+ postprocess_params = {}
50
+ if pool is not None:
51
+ postprocess_params["pool"] = pool
52
+ if return_tensors is not None:
53
+ postprocess_params["return_tensors"] = return_tensors
54
+
55
+ if "timeout" in kwargs:
56
+ preprocess_params["timeout"] = kwargs["timeout"]
57
+
58
+ return preprocess_params, {}, postprocess_params
59
+
60
+ def preprocess(self, image, timeout=None, **image_processor_kwargs) -> Dict[str, GenericTensor]:
61
+ image = load_image(image, timeout=timeout)
62
+ model_inputs = self.image_processor(image, return_tensors=self.framework, **image_processor_kwargs)
63
+ if self.framework == "pt":
64
+ model_inputs = model_inputs.to(self.torch_dtype)
65
+ return model_inputs
66
+
67
+ def _forward(self, model_inputs):
68
+ model_outputs = self.model(**model_inputs)
69
+ return model_outputs
70
+
71
+ def postprocess(self, model_outputs, pool=None, return_tensors=False):
72
+ pool = pool if pool is not None else False
73
+
74
+ if pool:
75
+ if "pooler_output" not in model_outputs:
76
+ raise ValueError(
77
+ "No pooled output was returned. Make sure the model has a `pooler` layer when using the `pool` option."
78
+ )
79
+ outputs = model_outputs["pooler_output"]
80
+ else:
81
+ # [0] is the first available tensor, logits or last_hidden_state.
82
+ outputs = model_outputs[0]
83
+
84
+ if return_tensors:
85
+ return outputs
86
+ if self.framework == "pt":
87
+ return outputs.tolist()
88
+ elif self.framework == "tf":
89
+ return outputs.numpy().tolist()
90
+
91
+ def __call__(self, *args, **kwargs):
92
+ """
93
+ Extract the features of the input(s).
94
+
95
+ Args:
96
+ images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
97
+ The pipeline handles three types of images:
98
+
99
+ - A string containing a http link pointing to an image
100
+ - A string containing a local path to an image
101
+ - An image loaded in PIL directly
102
+
103
+ The pipeline accepts either a single image or a batch of images, which must then be passed as a string.
104
+ Images in a batch must all be in the same format: all as http links, all as local paths, or all as PIL
105
+ images.
106
+ timeout (`float`, *optional*, defaults to None):
107
+ The maximum time in seconds to wait for fetching images from the web. If None, no timeout is used and
108
+ the call may block forever.
109
+ Return:
110
+ A nested list of `float`: The features computed by the model.
111
+ """
112
+ return super().__call__(*args, **kwargs)
.venv/Lib/site-packages/transformers/pipelines/image_segmentation.py ADDED
@@ -0,0 +1,220 @@
1
+ from typing import Any, Dict, List, Union
2
+
3
+ import numpy as np
4
+
5
+ from ..utils import add_end_docstrings, is_torch_available, is_vision_available, logging, requires_backends
6
+ from .base import Pipeline, build_pipeline_init_args
7
+
8
+
9
+ if is_vision_available():
10
+ from PIL import Image
11
+
12
+ from ..image_utils import load_image
13
+
14
+ if is_torch_available():
15
+ from ..models.auto.modeling_auto import (
16
+ MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES,
17
+ MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES,
18
+ MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
19
+ MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES,
20
+ )
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
+ Prediction = Dict[str, Any]
27
+ Predictions = List[Prediction]
28
+
29
+
30
+ @add_end_docstrings(build_pipeline_init_args(has_image_processor=True))
31
+ class ImageSegmentationPipeline(Pipeline):
32
+ """
33
+ Image segmentation pipeline using any `AutoModelForXXXSegmentation`. This pipeline predicts masks of objects and
34
+ their classes.
35
+
36
+ Example:
37
+
38
+ ```python
39
+ >>> from transformers import pipeline
40
+
41
+ >>> segmenter = pipeline(model="facebook/detr-resnet-50-panoptic")
42
+ >>> segments = segmenter("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png")
43
+ >>> len(segments)
44
+ 2
45
+
46
+ >>> segments[0]["label"]
47
+ 'bird'
48
+
49
+ >>> segments[1]["label"]
50
+ 'bird'
51
+
52
+ >>> type(segments[0]["mask"]) # This is a black and white mask showing where the bird is in the original image.
53
+ <class 'PIL.Image.Image'>
54
+
55
+ >>> segments[0]["mask"].size
56
+ (768, 512)
57
+ ```
58
+
59
+
60
+ This image segmentation pipeline can currently be loaded from [`pipeline`] using the following task identifier:
61
+ `"image-segmentation"`.
62
+
63
+ See the list of available models on
64
+ [huggingface.co/models](https://huggingface.co/models?filter=image-segmentation).
65
+ """
66
+
67
+ def __init__(self, *args, **kwargs):
68
+ super().__init__(*args, **kwargs)
69
+
70
+ if self.framework == "tf":
71
+ raise ValueError(f"The {self.__class__} is only available in PyTorch.")
72
+
73
+ requires_backends(self, "vision")
74
+ mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES.copy()
75
+ mapping.update(MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES)
76
+ mapping.update(MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES)
77
+ mapping.update(MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES)
78
+ self.check_model_type(mapping)
79
+
80
+ def _sanitize_parameters(self, **kwargs):
81
+ preprocess_kwargs = {}
82
+ postprocess_kwargs = {}
83
+ if "subtask" in kwargs:
84
+ postprocess_kwargs["subtask"] = kwargs["subtask"]
85
+ preprocess_kwargs["subtask"] = kwargs["subtask"]
86
+ if "threshold" in kwargs:
87
+ postprocess_kwargs["threshold"] = kwargs["threshold"]
88
+ if "mask_threshold" in kwargs:
89
+ postprocess_kwargs["mask_threshold"] = kwargs["mask_threshold"]
90
+ if "overlap_mask_area_threshold" in kwargs:
91
+ postprocess_kwargs["overlap_mask_area_threshold"] = kwargs["overlap_mask_area_threshold"]
92
+ if "timeout" in kwargs:
93
+ preprocess_kwargs["timeout"] = kwargs["timeout"]
94
+
95
+ return preprocess_kwargs, {}, postprocess_kwargs
96
+
97
+ def __call__(self, inputs=None, **kwargs) -> Union[Predictions, List[Prediction]]:
98
+ """
99
+ Perform segmentation (detect masks & classes) in the image(s) passed as inputs.
100
+
101
+ Args:
102
+ inputs (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
103
+ The pipeline handles three types of images:
104
+
105
+ - A string containing an HTTP(S) link pointing to an image
106
+ - A string containing a local path to an image
107
+ - An image loaded in PIL directly
108
+
109
+ The pipeline accepts either a single image or a batch of images. Images in a batch must all be in the
110
+ same format: all as HTTP(S) links, all as local paths, or all as PIL images.
111
+ subtask (`str`, *optional*):
112
+ Segmentation task to be performed, choose [`semantic`, `instance` and `panoptic`] depending on model
113
+ capabilities. If not set, the pipeline will attempt to resolve in the following order:
114
+ `panoptic`, `instance`, `semantic`.
115
+ threshold (`float`, *optional*, defaults to 0.9):
116
+ Probability threshold to filter out predicted masks.
117
+ mask_threshold (`float`, *optional*, defaults to 0.5):
118
+ Threshold to use when turning the predicted masks into binary values.
119
+ overlap_mask_area_threshold (`float`, *optional*, defaults to 0.5):
120
+ Mask overlap threshold to eliminate small, disconnected segments.
121
+ timeout (`float`, *optional*, defaults to None):
122
+ The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
123
+ the call may block forever.
124
+
125
+ Return:
126
+ A dictionary or a list of dictionaries containing the result. If the input is a single image, will return a
127
+ list of dictionaries, if the input is a list of several images, will return a list of list of dictionaries
128
+ corresponding to each image.
129
+
130
+ The dictionaries contain the mask, label and score (where applicable) of each detected object and contains
131
+ the following keys:
132
+
133
+ - **label** (`str`) -- The class label identified by the model.
134
+ - **mask** (`PIL.Image`) -- A binary mask of the detected object as a PIL Image of shape (width, height) of
135
+ the original image. Returns a mask filled with zeros if no object is found.
136
+ - **score** (*optional* `float`) -- Optionally, when the model is capable of estimating a confidence of the
137
+ "object" described by the label and the mask.
138
+ """
139
+ # Once this deprecation is completed, remove the default `None` value for `images`
140
+ if "images" in kwargs:
141
+ inputs = kwargs.pop("images")
142
+ if inputs is None:
143
+ raise ValueError("Cannot call the image-segmentation pipeline without an inputs argument!")
144
+ return super().__call__(inputs, **kwargs)
145
+
146
+ def preprocess(self, image, subtask=None, timeout=None):
147
+ image = load_image(image, timeout=timeout)
148
+ target_size = [(image.height, image.width)]
149
+ if self.model.config.__class__.__name__ == "OneFormerConfig":
150
+ if subtask is None:
151
+ kwargs = {}
152
+ else:
153
+ kwargs = {"task_inputs": [subtask]}
154
+ inputs = self.image_processor(images=[image], return_tensors="pt", **kwargs)
155
+ if self.framework == "pt":
156
+ inputs = inputs.to(self.torch_dtype)
157
+ inputs["task_inputs"] = self.tokenizer(
158
+ inputs["task_inputs"],
159
+ padding="max_length",
160
+ max_length=self.model.config.task_seq_len,
161
+ return_tensors=self.framework,
162
+ )["input_ids"]
163
+ else:
164
+ inputs = self.image_processor(images=[image], return_tensors="pt")
165
+ if self.framework == "pt":
166
+ inputs = inputs.to(self.torch_dtype)
167
+ inputs["target_size"] = target_size
168
+ return inputs
169
+
170
+ def _forward(self, model_inputs):
171
+ target_size = model_inputs.pop("target_size")
172
+ model_outputs = self.model(**model_inputs)
173
+ model_outputs["target_size"] = target_size
174
+ return model_outputs
175
+
176
+ def postprocess(
177
+ self, model_outputs, subtask=None, threshold=0.9, mask_threshold=0.5, overlap_mask_area_threshold=0.5
178
+ ):
179
+ fn = None
180
+ if subtask in {"panoptic", None} and hasattr(self.image_processor, "post_process_panoptic_segmentation"):
181
+ fn = self.image_processor.post_process_panoptic_segmentation
182
+ elif subtask in {"instance", None} and hasattr(self.image_processor, "post_process_instance_segmentation"):
183
+ fn = self.image_processor.post_process_instance_segmentation
184
+
185
+ if fn is not None:
186
+ outputs = fn(
187
+ model_outputs,
188
+ threshold=threshold,
189
+ mask_threshold=mask_threshold,
190
+ overlap_mask_area_threshold=overlap_mask_area_threshold,
191
+ target_sizes=model_outputs["target_size"],
192
+ )[0]
193
+
194
+ annotation = []
195
+ segmentation = outputs["segmentation"]
196
+
197
+ for segment in outputs["segments_info"]:
198
+ mask = (segmentation == segment["id"]) * 255
199
+ mask = Image.fromarray(mask.numpy().astype(np.uint8), mode="L")
200
+ label = self.model.config.id2label[segment["label_id"]]
201
+ score = segment["score"]
202
+ annotation.append({"score": score, "label": label, "mask": mask})
203
+
204
+ elif subtask in {"semantic", None} and hasattr(self.image_processor, "post_process_semantic_segmentation"):
205
+ outputs = self.image_processor.post_process_semantic_segmentation(
206
+ model_outputs, target_sizes=model_outputs["target_size"]
207
+ )[0]
208
+
209
+ annotation = []
210
+ segmentation = outputs.numpy()
211
+ labels = np.unique(segmentation)
212
+
213
+ for label in labels:
214
+ mask = (segmentation == label) * 255
215
+ mask = Image.fromarray(mask.astype(np.uint8), mode="L")
216
+ label = self.model.config.id2label[label]
217
+ annotation.append({"score": None, "label": label, "mask": mask})
218
+ else:
219
+ raise ValueError(f"Subtask {subtask} is not supported for model {type(self.model)}")
220
+ return annotation
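For reference, a minimal usage sketch of the segmentation pipeline defined above. The checkpoint and subtask choice are assumptions for illustration; any model whose image processor exposes the corresponding `post_process_*_segmentation` method should behave the same way.

```python
from transformers import pipeline

# Assumed checkpoint; any panoptic-capable segmentation model works similarly.
segmenter = pipeline("image-segmentation", model="facebook/detr-resnet-50-panoptic")
results = segmenter(
    "http://images.cocodataset.org/val2017/000000039769.jpg",
    subtask="panoptic",
    threshold=0.9,
)
for result in results:
    # Each entry is {"score": float or None, "label": str, "mask": PIL.Image}
    print(result["label"], result["score"], result["mask"].size)
```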
.venv/Lib/site-packages/transformers/pipelines/image_text_to_text.py ADDED
@@ -0,0 +1,432 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import enum
17
+ from typing import Dict, List, Optional, Union
18
+
19
+ from ..processing_utils import ProcessingKwargs, Unpack
20
+ from ..utils import (
21
+ add_end_docstrings,
22
+ is_torch_available,
23
+ is_vision_available,
24
+ logging,
25
+ requires_backends,
26
+ )
27
+ from .base import Pipeline, build_pipeline_init_args
28
+
29
+
30
+ if is_vision_available():
31
+ from PIL import Image
32
+
33
+ from ..image_utils import load_images, valid_images
34
+
35
+
36
+ if is_torch_available():
37
+ from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
38
+ from .pt_utils import KeyDataset
39
+
40
+ logger = logging.get_logger(__name__)
41
+
42
+ IMAGE_TOKEN = "<image>"
43
+
44
+
45
+ class ReturnType(enum.Enum):
46
+ TENSORS = 0
47
+ NEW_TEXT = 1
48
+ FULL_TEXT = 2
49
+
50
+
51
+ class Chat:
52
+ """This class is intended to just be used internally in this pipeline and not exposed to users. We convert chats
53
+ to this format because the rest of the pipeline code tends to assume that lists of messages are
54
+ actually a batch of samples rather than messages in the same conversation."""
55
+
56
+ def __init__(self, messages: Dict, images: Union[str, List[str], "Image.Image", List["Image.Image"]]):
57
+ for message in messages:
58
+ if not ("role" in message and "content" in message):
59
+ raise ValueError("When passing chat dicts as input, each dict must have a 'role' and 'content' key.")
60
+ images = retrieve_images_in_messages(messages, images)
61
+
62
+ self.messages = messages
63
+ self.images = images
64
+
65
+
66
+ def retrieve_images_in_messages(
67
+ messages: dict, images: Optional[Union[str, List[str], "Image.Image", List["Image.Image"]]]
68
+ ):
69
+ """
70
+ Retrieve and combine images from the chat and the images passed as input.
71
+ """
72
+ if images is None:
73
+ images = []
74
+ idx_images = 0
75
+ retrieved_images = []
76
+ for message in messages:
77
+ for content in message["content"]:
78
+ if isinstance(content, dict):
79
+ if content.get("type") == "image":
80
+ for key in ["image", "url", "path", "base64"]:
81
+ if key in content:
82
+ retrieved_images.append(content[key])
83
+ break
84
+ else:
85
+ if idx_images < len(images):
86
+ retrieved_images.append(images[idx_images])
87
+ idx_images += 1
88
+ else:
89
+ raise ValueError(
90
+ "The number of images in the chat messages should be the same as the number of images passed to the pipeline."
91
+ )
92
+ # Add support for OpenAI/TGI chat format
93
+ elif content.get("type") == "image_url":
94
+ if isinstance(content.get("image_url"), dict) and "url" in content["image_url"]:
95
+ retrieved_images.append(content["image_url"]["url"])
96
+ # Rewrite content to be in the Transformers chat format
97
+ content["type"] = "image"
98
+ content["image"] = content["image_url"]["url"]
99
+ del content["image_url"]
100
+ else:
101
+ raise ValueError(
102
+ "Wrong format for 'image_url' content type. The content should have an 'image_url' dict with a 'url' key."
103
+ )
104
+
105
+ # The number of images passed should be consistent with the number of images in the chat without an image key
106
+ if idx_images != len(images):
107
+ raise ValueError(
108
+ "The number of images in the chat messages should be the same as the number of images passed to the pipeline."
109
+ )
110
+
111
+ return retrieved_images
112
+
113
+
114
+ @add_end_docstrings(build_pipeline_init_args(has_processor=True))
115
+ class ImageTextToTextPipeline(Pipeline):
116
+ """
117
+ Image-text-to-text pipeline using an `AutoModelForImageTextToText`. This pipeline generates text given an image and text.
118
+ When the underlying model is a conversational model, it can also accept one or more chats,
119
+ in which case the pipeline will operate in chat mode and will continue the chat(s) by adding its response(s).
120
+ Each chat takes the form of a list of dicts, where each dict contains "role" and "content" keys.
121
+
122
+ Example:
123
+
124
+ ```python
125
+ >>> from transformers import pipeline
126
+
127
+ >>> pipe = pipeline(task="image-text-to-text", model="Salesforce/blip-image-captioning-base")
128
+ >>> pipe("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png", text="A photo of")
129
+ [{'generated_text': 'a photo of two birds'}]
130
+ ```
131
+
132
+ ```python
133
+ >>> from transformers import pipeline
134
+
135
+ >>> pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")
136
+ >>> messages = [
137
+ >>> {
138
+ >>> "role": "user",
139
+ >>> "content": [
140
+ >>> {
141
+ >>> "type": "image",
142
+ >>> "url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
143
+ >>> },
144
+ >>> {"type": "text", "text": "Describe this image."},
145
+ >>> ],
146
+ >>> },
147
+ >>> {
148
+ >>> "role": "assistant",
149
+ >>> "content": [
150
+ >>> {"type": "text", "text": "There is a dog and"},
151
+ >>> ],
152
+ >>> },
153
+ >>> ]
154
+ >>> pipe(text=messages, max_new_tokens=20, return_full_text=False)
155
+ [{'input_text': [{'role': 'user',
156
+ 'content': [{'type': 'image',
157
+ 'url': 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg'},
158
+ {'type': 'text', 'text': 'Describe this image.'}]},
159
+ {'role': 'assistant',
160
+ 'content': [{'type': 'text', 'text': 'There is a dog and'}]}],
161
+ 'generated_text': ' a person in the image. The dog is sitting on the sand, and the person is sitting on'}]
162
+ ```
163
+
164
+ Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
165
+
166
+ This image-text to text pipeline can currently be loaded from pipeline() using the following task identifier:
167
+ "image-text-to-text".
168
+
169
+ See the list of available models on
170
+ [huggingface.co/models](https://huggingface.co/models?pipeline_tag=image-text-to-text).
171
+ """
172
+
173
+ _load_processor = True
174
+ _load_image_processor = False
175
+ _load_feature_extractor = False
176
+ _load_tokenizer = False
177
+
178
+ def __init__(self, *args, **kwargs):
179
+ super().__init__(*args, **kwargs)
180
+ requires_backends(self, "vision")
181
+ self.check_model_type(MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES)
182
+
183
+ def _sanitize_parameters(
184
+ self,
185
+ max_new_tokens=None,
186
+ generate_kwargs=None,
187
+ timeout=None,
188
+ return_full_text=None,
189
+ return_tensors=None,
190
+ return_type=None,
191
+ continue_final_message=None,
192
+ **kwargs: Unpack[ProcessingKwargs],
193
+ ):
194
+ forward_kwargs = {}
195
+ preprocess_params = {}
196
+ postprocess_params = {}
197
+
198
+ preprocess_params["processing_kwargs"] = kwargs
199
+
200
+ if timeout is not None:
201
+ preprocess_params["timeout"] = timeout
202
+
203
+ if continue_final_message is not None:
204
+ preprocess_params["continue_final_message"] = continue_final_message
205
+
206
+ if generate_kwargs is not None:
207
+ forward_kwargs["generate_kwargs"] = generate_kwargs
208
+
209
+ if max_new_tokens is not None:
210
+ if "generate_kwargs" not in forward_kwargs:
211
+ forward_kwargs["generate_kwargs"] = {}
212
+ if "max_new_tokens" in forward_kwargs["generate_kwargs"]:
213
+ raise ValueError(
214
+ "'max_new_tokens' is defined twice, once in 'generate_kwargs' and once as a direct parameter,"
215
+ " please use only one"
216
+ )
217
+ forward_kwargs["generate_kwargs"]["max_new_tokens"] = max_new_tokens
218
+
219
+ if return_full_text is not None and return_type is None:
220
+ if return_tensors is not None:
221
+ raise ValueError("`return_full_text` is mutually exclusive with `return_tensors`")
222
+ return_type = ReturnType.FULL_TEXT if return_full_text else ReturnType.NEW_TEXT
223
+ if return_tensors is not None and return_type is None:
224
+ return_type = ReturnType.TENSORS
225
+ if return_type is not None:
226
+ postprocess_params["return_type"] = return_type
227
+ if continue_final_message is not None:
228
+ postprocess_params["continue_final_message"] = continue_final_message
229
+
230
+ return preprocess_params, forward_kwargs, postprocess_params
231
+
232
+ def __call__(
233
+ self,
234
+ images: Optional[
235
+ Union[str, List[str], List[List[str]], "Image.Image", List["Image.Image"], List[List["Image.Image"]]]
236
+ ] = None,
237
+ text: Optional[Union[str, List[str], List[dict]]] = None,
238
+ **kwargs,
239
+ ):
240
+ """
241
+ Generate a text given text and the image(s) passed as inputs.
242
+
243
+ Args:
244
+ images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
245
+ The pipeline handles three types of images:
246
+
247
+ - A string containing a HTTP(s) link pointing to an image
248
+ - A string containing a local path to an image
249
+ - An image loaded in PIL directly
250
+
251
+ The pipeline accepts either a single image or a batch of images.
252
+ text (str, List[str], `List[Dict[str, Union[str, PIL.Image]]]`):
253
+ The text to be used for generation. If a list of strings is passed, the length of the list should be the
254
+ same as the number of images. Text can also follow the chat format: a list of dictionaries where each
255
+ dictionary represents a message in a conversation. Each dictionary should have two keys: 'role' and
256
+ 'content'. 'role' should be one of 'user', 'system' or 'assistant'. 'content' should be a list of dictionaries
257
+ containing the text of the message and the type of the message. The type of the message can be either
258
+ 'text' or 'image'. If the type is 'image', no text is needed.
259
+ return_tensors (`bool`, *optional*, defaults to `False`):
260
+ Returns the tensors of predictions (as token indices) in the outputs. If set to
261
+ `True`, the decoded text is not returned.
262
+ return_text (`bool`, *optional*):
263
+ Returns the decoded texts in the outputs.
264
+ return_full_text (`bool`, *optional*, defaults to `True`):
265
+ If set to `False` only added text is returned, otherwise the full text is returned. Cannot be
266
+ specified at the same time as `return_text`.
267
+ continue_final_message (`bool`, *optional*): This indicates that you want the model to continue the
268
+ last message in the input chat rather than starting a new one, allowing you to "prefill" its response.
269
+ By default this is `True` when the final message in the input chat has the `assistant` role and
270
+ `False` otherwise, but you can manually override that behaviour by setting this flag.
271
+
272
+ Return:
273
+ A list or a list of lists of `dict`: Each result comes as a dictionary with the following key (cannot return a combination
274
+ of both `generated_text` and `generated_token_ids`):
275
+
276
+ - **generated_text** (`str`, present when `return_text=True`) -- The generated text.
277
+ - **generated_token_ids** (`torch.Tensor`, present when `return_tensors=True`) -- The token
278
+ ids of the generated text.
279
+ - **input_text** (`str`) -- The input text.
280
+ """
281
+ if images is None and text is None:
282
+ raise ValueError("You must at least provide either text or images.")
283
+ if images is not None and text is None and not valid_images(images):
284
+ """
285
+ Supports the following format
286
+ - {"image": image, "text": text}
287
+ - [{"image": image, "text": text}]
288
+ - Generator and datasets
289
+ This is a common pattern in other multimodal pipelines, so we support it here as well.
290
+ """
291
+ return super().__call__(images, **kwargs)
292
+
293
+ if isinstance(text, (list, tuple, KeyDataset)) and isinstance(text[0], (list, tuple, dict)):
294
+ # We have one or more prompts in list-of-dicts format, so this is chat mode
295
+ if isinstance(text[0], dict):
296
+ return super().__call__(Chat(text, images), **kwargs)
297
+ else:
298
+ if images is None:
299
+ images = [None] * len(text)
300
+ chats = [Chat(chat, image) for chat, image in zip(text, images)] # 🐈 🐈 🐈
301
+ return super().__call__(chats, **kwargs)
302
+
303
+ # encourage the user to use the chat format if supported
304
+ if getattr(self.processor, "chat_template", None) is not None:
305
+ logger.warning_once(
306
+ "The input data was not formatted as a chat with dicts containing 'role' and 'content' keys, even though this model supports chat. "
307
+ "Consider using the chat format for better results. For more information, see https://huggingface.co/docs/transformers/en/chat_templating"
308
+ )
309
+
310
+ # support text only generation
311
+ if images is None:
312
+ return super().__call__(text, **kwargs)
313
+ if text is None:
314
+ raise ValueError("You must provide text for this pipeline.")
315
+
316
+ return super().__call__({"images": images, "text": text}, **kwargs)
317
+
318
+ def preprocess(self, inputs=None, timeout=None, continue_final_message=None, processing_kwargs=None):
319
+ # In case we only have text inputs
320
+ if isinstance(inputs, (list, tuple, str)):
321
+ images = None
322
+ text = inputs
323
+ inputs_text = inputs
324
+ else:
325
+ if isinstance(inputs, Chat):
326
+ # If the user passes a chat that ends in an assistant message, we treat it as a prefill by default
327
+ # because very few models support multiple separate, consecutive assistant messages
328
+ if continue_final_message is None:
329
+ continue_final_message = inputs.messages[-1]["role"] == "assistant"
330
+ text = self.processor.apply_chat_template(
331
+ inputs.messages,
332
+ add_generation_prompt=not continue_final_message,
333
+ continue_final_message=continue_final_message,
334
+ return_tensors=self.framework,
335
+ )
336
+ inputs_text = inputs
337
+ images = inputs.images
338
+ else:
339
+ text = inputs["text"]
340
+ inputs_text = inputs["text"]
341
+ images = inputs["images"]
342
+
343
+ images = load_images(images)
344
+
345
+ # if batched text inputs, we set padding to True unless specified otherwise
346
+ if isinstance(text, (list, tuple)) and len(text) > 1:
347
+ processing_kwargs.setdefault("padding", True)
348
+ model_inputs = self.processor(
349
+ images=images, text=text, return_tensors=self.framework, legacy=False, **processing_kwargs
350
+ ).to(dtype=self.torch_dtype)
351
+
352
+ model_inputs["text"] = inputs_text
353
+
354
+ return model_inputs
355
+
356
+ def _forward(self, model_inputs, generate_kwargs=None):
357
+ generate_kwargs = {} if generate_kwargs is None else generate_kwargs
358
+ prompt_text = model_inputs.pop("text")
359
+ input_ids = (
360
+ model_inputs["input_ids"] if "input_ids" in model_inputs else model_inputs["decoder_input_ids"]
361
+ ) # for decoder-only models
362
+ generated_sequence = self.model.generate(**model_inputs, **generate_kwargs)
363
+
364
+ return {"generated_sequence": generated_sequence, "prompt_text": prompt_text, "input_ids": input_ids}
365
+
366
+ def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, continue_final_message=None):
367
+ input_texts = model_outputs["prompt_text"]
368
+ input_texts = [input_texts] if isinstance(input_texts, (str, Chat)) else input_texts
369
+ generated_sequence = model_outputs["generated_sequence"]
370
+ input_ids = model_outputs["input_ids"]
371
+ if return_type == ReturnType.TENSORS:
372
+ return [
373
+ {"input_text": input_texts[i], "generated_token_ids": generated_sequence[i]}
374
+ for i in range(len(input_texts))
375
+ ]
376
+
377
+ # Decode inputs and outputs the same way to remove input text from generated text if present
378
+ generated_texts = self.processor.post_process_image_text_to_text(generated_sequence)
379
+ decoded_inputs = self.processor.post_process_image_text_to_text(input_ids)
380
+
381
+ # Force consistent behavior for including the input text in the output
382
+ if return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}:
383
+ # Remove the input text from the generated text if the generated text starts with the input text
384
+ # (accounting for the possibility of a space between the input and generated text)
385
+ new_generated_texts = []
386
+ for text_generated, decoded_input in zip(generated_texts, decoded_inputs):
387
+ # There can be added characters before the input text, so we need to find the beginning of the input text in the generated text
388
+ index_input_text = text_generated.find(decoded_input)
389
+ # Limit the search to 2 residual characters, like spaces or new lines, to avoid removing a large part of the answer
390
+ if 0 <= index_input_text <= 2:
391
+ # If the input text is found, we remove it
392
+ new_generated_texts.append(text_generated[index_input_text + len(decoded_input) :])
393
+ else:
394
+ new_generated_texts.append(text_generated)
395
+ generated_texts = new_generated_texts
396
+ if return_type == ReturnType.FULL_TEXT:
397
+ full_texts = []
398
+ for prompt_text, generated_text in zip(input_texts, generated_texts):
399
+ if isinstance(prompt_text, str):
400
+ generated_text = prompt_text + generated_text
401
+ elif isinstance(prompt_text, Chat):
402
+ if continue_final_message is None:
403
+ # If the user passes a chat ending in an assistant message, we treat it as a prefill by
404
+ # default because very few models support multiple separate, consecutive assistant messages
405
+ continue_final_message = prompt_text.messages[-1]["role"] == "assistant"
406
+ if continue_final_message:
407
+ # With assistant prefill, concat onto the end of the last message
408
+ new_text = dict(prompt_text.messages[-1]["content"][-1].items())
409
+ new_text["text"] += generated_text
410
+ generated_text = list(prompt_text.messages)[:-1] + [
411
+ {
412
+ "role": prompt_text.messages[-1]["role"],
413
+ "content": prompt_text.messages[-1]["content"][:-1] + [new_text],
414
+ }
415
+ ]
416
+ else:
417
+ # When we're not starting from a prefill, the output is a new assistant message
418
+ generated_text = list(prompt_text.messages) + [
419
+ {"role": "assistant", "content": generated_text}
420
+ ]
421
+ full_texts.append(generated_text)
422
+ generated_texts = full_texts
423
+
424
+ records = [
425
+ {
426
+ "input_text": input_text.messages if isinstance(input_text, Chat) else input_text,
427
+ "generated_text": generated_text,
428
+ }
429
+ for input_text, generated_text in zip(input_texts, generated_texts)
430
+ ]
431
+
432
+ return records
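The `retrieve_images_in_messages` helper above also accepts OpenAI/TGI-style `image_url` content entries and rewrites them to the Transformers chat format. A minimal sketch of that input shape, reusing the checkpoint and image URL from the docstring examples:

```python
from transformers import pipeline

pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")
messages = [
    {
        "role": "user",
        "content": [
            # OpenAI/TGI-style entry; the pipeline rewrites it to {"type": "image", "image": <url>}
            {
                "type": "image_url",
                "image_url": {"url": "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png"},
            },
            {"type": "text", "text": "What is in this image?"},
        ],
    }
]
out = pipe(text=messages, max_new_tokens=20, return_full_text=False)
print(out[0]["generated_text"])
```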
.venv/Lib/site-packages/transformers/pipelines/image_to_image.py ADDED
@@ -0,0 +1,136 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import List, Union
15
+
16
+ import numpy as np
17
+
18
+ from ..utils import (
19
+ add_end_docstrings,
20
+ is_torch_available,
21
+ is_vision_available,
22
+ logging,
23
+ requires_backends,
24
+ )
25
+ from .base import Pipeline, build_pipeline_init_args
26
+
27
+
28
+ if is_vision_available():
29
+ from PIL import Image
30
+
31
+ from ..image_utils import load_image
32
+
33
+ if is_torch_available():
34
+ from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES
35
+
36
+ logger = logging.get_logger(__name__)
37
+
38
+
39
+ @add_end_docstrings(build_pipeline_init_args(has_image_processor=True))
40
+ class ImageToImagePipeline(Pipeline):
41
+ """
42
+ Image to Image pipeline using any `AutoModelForImageToImage`. This pipeline generates an image based on a previous
43
+ image input.
44
+
45
+ Example:
46
+
47
+ ```python
48
+ >>> from PIL import Image
49
+ >>> import requests
50
+
51
+ >>> from transformers import pipeline
52
+
53
+ >>> upscaler = pipeline("image-to-image", model="caidas/swin2SR-classical-sr-x2-64")
54
+ >>> img = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
55
+ >>> img = img.resize((64, 64))
56
+ >>> upscaled_img = upscaler(img)
57
+ >>> img.size
58
+ (64, 64)
59
+
60
+ >>> upscaled_img.size
61
+ (144, 144)
62
+ ```
63
+
64
+ This image to image pipeline can currently be loaded from [`pipeline`] using the following task identifier:
65
+ `"image-to-image"`.
66
+
67
+ See the list of available models on [huggingface.co/models](https://huggingface.co/models?filter=image-to-image).
68
+ """
69
+
70
+ def __init__(self, *args, **kwargs):
71
+ super().__init__(*args, **kwargs)
72
+ requires_backends(self, "vision")
73
+ self.check_model_type(MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES)
74
+
75
+ def _sanitize_parameters(self, **kwargs):
76
+ preprocess_params = {}
77
+ postprocess_params = {}
78
+ forward_params = {}
79
+
80
+ if "timeout" in kwargs:
81
+ preprocess_params["timeout"] = kwargs["timeout"]
82
+ if "head_mask" in kwargs:
83
+ forward_params["head_mask"] = kwargs["head_mask"]
84
+
85
+ return preprocess_params, forward_params, postprocess_params
86
+
87
+ def __call__(
88
+ self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], **kwargs
89
+ ) -> Union["Image.Image", List["Image.Image"]]:
90
+ """
91
+ Transform the image(s) passed as inputs.
92
+
93
+ Args:
94
+ images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
95
+ The pipeline handles three types of images:
96
+
97
+ - A string containing an HTTP(S) link pointing to an image
98
+ - A string containing a local path to an image
99
+ - An image loaded in PIL directly
100
+
101
+ The pipeline accepts either a single image or a batch of images.
102
+ Images in a batch must all be in the same format: all as http links, all as local paths, or all as PIL
103
+ images.
104
+ timeout (`float`, *optional*, defaults to None):
105
+ The maximum time in seconds to wait for fetching images from the web. If None, no timeout is used and
106
+ the call may block forever.
107
+
108
+ Return:
109
+ An image (Image.Image) or a list of images (List["Image.Image"]) containing result(s). If the input is a
110
+ single image, the return will be also a single image, if the input is a list of several images, it will
111
+ return a list of transformed images.
112
+ """
113
+ return super().__call__(images, **kwargs)
114
+
115
+ def _forward(self, model_inputs):
116
+ model_outputs = self.model(**model_inputs)
117
+ return model_outputs
118
+
119
+ def preprocess(self, image, timeout=None):
120
+ image = load_image(image, timeout=timeout)
121
+ inputs = self.image_processor(images=[image], return_tensors="pt")
122
+ if self.framework == "pt":
123
+ inputs = inputs.to(self.torch_dtype)
124
+ return inputs
125
+
126
+ def postprocess(self, model_outputs):
127
+ images = []
128
+ if "reconstruction" in model_outputs.keys():
129
+ outputs = model_outputs.reconstruction
130
+ for output in outputs:
131
+ output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
132
+ output = np.moveaxis(output, source=0, destination=-1)
133
+ output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
134
+ images.append(Image.fromarray(output))
135
+
136
+ return images if len(images) > 1 else images[0]
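A short usage sketch of the image-to-image pipeline defined above, reusing the Swin2SR upscaler checkpoint from the docstring example; saving the output file is illustrative only.

```python
from transformers import pipeline

upscaler = pipeline("image-to-image", model="caidas/swin2SR-classical-sr-x2-64")
# A URL or local path works just as well as a PIL image.
upscaled = upscaler("http://images.cocodataset.org/val2017/000000039769.jpg")
upscaled.save("upscaled.png")  # single input -> single PIL.Image output
```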
.venv/Lib/site-packages/transformers/pipelines/image_to_text.py ADDED
@@ -0,0 +1,216 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import List, Union
17
+
18
+ from ..utils import (
19
+ add_end_docstrings,
20
+ is_tf_available,
21
+ is_torch_available,
22
+ is_vision_available,
23
+ logging,
24
+ requires_backends,
25
+ )
26
+ from .base import Pipeline, build_pipeline_init_args
27
+
28
+
29
+ if is_vision_available():
30
+ from PIL import Image
31
+
32
+ from ..image_utils import load_image
33
+
34
+ if is_tf_available():
35
+ from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
36
+
37
+ if is_torch_available():
38
+ import torch
39
+
40
+ from ..models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
41
+
42
+ logger = logging.get_logger(__name__)
43
+
44
+
45
+ @add_end_docstrings(build_pipeline_init_args(has_tokenizer=True, has_image_processor=True))
46
+ class ImageToTextPipeline(Pipeline):
47
+ """
48
+ Image To Text pipeline using an `AutoModelForVision2Seq`. This pipeline predicts a caption for a given image.
49
+
50
+ Example:
51
+
52
+ ```python
53
+ >>> from transformers import pipeline
54
+
55
+ >>> captioner = pipeline(model="ydshieh/vit-gpt2-coco-en")
56
+ >>> captioner("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png")
57
+ [{'generated_text': 'two birds are standing next to each other '}]
58
+ ```
59
+
60
+ Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
61
+
62
+ This image to text pipeline can currently be loaded from pipeline() using the following task identifier:
63
+ "image-to-text".
64
+
65
+ See the list of available models on
66
+ [huggingface.co/models](https://huggingface.co/models?pipeline_tag=image-to-text).
67
+ """
68
+
69
+ def __init__(self, *args, **kwargs):
70
+ super().__init__(*args, **kwargs)
71
+ requires_backends(self, "vision")
72
+ self.check_model_type(
73
+ TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES if self.framework == "tf" else MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
74
+ )
75
+
76
+ def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None, prompt=None, timeout=None):
77
+ forward_params = {}
78
+ preprocess_params = {}
79
+
80
+ if prompt is not None:
81
+ preprocess_params["prompt"] = prompt
82
+ if timeout is not None:
83
+ preprocess_params["timeout"] = timeout
84
+
85
+ if max_new_tokens is not None:
86
+ forward_params["max_new_tokens"] = max_new_tokens
87
+ if generate_kwargs is not None:
88
+ if max_new_tokens is not None and "max_new_tokens" in generate_kwargs:
89
+ raise ValueError(
90
+ "`max_new_tokens` is defined both as an argument and inside `generate_kwargs` argument, please use"
91
+ " only 1 version"
92
+ )
93
+ forward_params.update(generate_kwargs)
94
+
95
+ return preprocess_params, forward_params, {}
96
+
97
+ def __call__(self, inputs: Union[str, List[str], "Image.Image", List["Image.Image"]] = None, **kwargs):
98
+ """
99
+ Generate a caption for the image(s) passed as inputs.
100
+
101
+ Args:
102
+ inputs (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
103
+ The pipeline handles three types of images:
104
+
105
+ - A string containing a HTTP(s) link pointing to an image
106
+ - A string containing a local path to an image
107
+ - An image loaded in PIL directly
108
+
109
+ The pipeline accepts either a single image or a batch of images.
110
+
111
+ max_new_tokens (`int`, *optional*):
112
+ The maximum number of tokens to generate. By default, the `generate` default will be used.
113
+
114
+ generate_kwargs (`Dict`, *optional*):
115
+ Additional keyword arguments passed directly to `generate`, allowing full control of the generation function.
116
+
117
+ timeout (`float`, *optional*, defaults to None):
118
+ The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
119
+ the call may block forever.
120
+
121
+ Return:
122
+ A list or a list of lists of `dict`: Each result comes as a dictionary with the following key:
123
+
124
+ - **generated_text** (`str`) -- The generated text.
125
+ """
126
+ # After deprecation of this is completed, remove the default `None` value for `images`
127
+ if "images" in kwargs:
128
+ inputs = kwargs.pop("images")
129
+ if inputs is None:
130
+ raise ValueError("Cannot call the image-to-text pipeline without an inputs argument!")
131
+ return super().__call__(inputs, **kwargs)
132
+
133
+ def preprocess(self, image, prompt=None, timeout=None):
134
+ image = load_image(image, timeout=timeout)
135
+
136
+ if prompt is not None:
137
+ logger.warning_once(
138
+ "Passing `prompt` to the `image-to-text` pipeline is deprecated and will be removed in version 4.48"
139
+ " of 🤗 Transformers. Use the `image-text-to-text` pipeline instead",
140
+ )
141
+ if not isinstance(prompt, str):
142
+ raise ValueError(
143
+ f"Received an invalid text input, got - {type(prompt)} - but expected a single string. "
144
+ "Note also that one single text can be provided for conditional image to text generation."
145
+ )
146
+
147
+ model_type = self.model.config.model_type
148
+
149
+ if model_type == "git":
150
+ model_inputs = self.image_processor(images=image, return_tensors=self.framework)
151
+ if self.framework == "pt":
152
+ model_inputs = model_inputs.to(self.torch_dtype)
153
+ input_ids = self.tokenizer(text=prompt, add_special_tokens=False).input_ids
154
+ input_ids = [self.tokenizer.cls_token_id] + input_ids
155
+ input_ids = torch.tensor(input_ids).unsqueeze(0)
156
+ model_inputs.update({"input_ids": input_ids})
157
+
158
+ elif model_type == "pix2struct":
159
+ model_inputs = self.image_processor(images=image, header_text=prompt, return_tensors=self.framework)
160
+ if self.framework == "pt":
161
+ model_inputs = model_inputs.to(self.torch_dtype)
162
+
163
+ elif model_type != "vision-encoder-decoder":
164
+ # vision-encoder-decoder does not support conditional generation
165
+ model_inputs = self.image_processor(images=image, return_tensors=self.framework)
166
+ if self.framework == "pt":
167
+ model_inputs = model_inputs.to(self.torch_dtype)
168
+ text_inputs = self.tokenizer(prompt, return_tensors=self.framework)
169
+ model_inputs.update(text_inputs)
170
+
171
+ else:
172
+ raise ValueError(f"Model type {model_type} does not support conditional text generation")
173
+
174
+ else:
175
+ model_inputs = self.image_processor(images=image, return_tensors=self.framework)
176
+ if self.framework == "pt":
177
+ model_inputs = model_inputs.to(self.torch_dtype)
178
+
179
+ if self.model.config.model_type == "git" and prompt is None:
180
+ model_inputs["input_ids"] = None
181
+
182
+ return model_inputs
183
+
184
+ def _forward(self, model_inputs, **generate_kwargs):
185
+ # The GIT model sets `model_inputs["input_ids"] = None` in `preprocess` (when `prompt=None`). In batch mode, the
186
+ # pipeline will group them into a list of `None`, which fails in `_forward`. Avoid this by checking it first.
187
+ if (
188
+ "input_ids" in model_inputs
189
+ and isinstance(model_inputs["input_ids"], list)
190
+ and all(x is None for x in model_inputs["input_ids"])
191
+ ):
192
+ model_inputs["input_ids"] = None
193
+
194
+ # User-defined `generation_config` passed to the pipeline call take precedence
195
+ if "generation_config" not in generate_kwargs:
196
+ generate_kwargs["generation_config"] = self.generation_config
197
+
198
+ # FIXME: We need to pop here due to a difference in how `generation.py` and `generation.tf_utils.py`
199
+ # parse inputs. In the Tensorflow version, `generate` raises an error if we don't use `input_ids` whereas
200
+ # the PyTorch version matches it with `self.model.main_input_name` or `self.model.encoder.main_input_name`
201
+ # in the `_prepare_model_inputs` method.
202
+ inputs = model_inputs.pop(self.model.main_input_name)
203
+ model_outputs = self.model.generate(inputs, **model_inputs, **generate_kwargs)
204
+ return model_outputs
205
+
206
+ def postprocess(self, model_outputs):
207
+ records = []
208
+ for output_ids in model_outputs:
209
+ record = {
210
+ "generated_text": self.tokenizer.decode(
211
+ output_ids,
212
+ skip_special_tokens=True,
213
+ )
214
+ }
215
+ records.append(record)
216
+ return records
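A brief batched-usage sketch of the captioning pipeline above, reusing the checkpoint and image URLs that appear in the docstring examples; the nested output structure reflects how `postprocess` returns one record list per image.

```python
from transformers import pipeline

captioner = pipeline("image-to-text", model="ydshieh/vit-gpt2-coco-en")
urls = [
    "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png",
    "http://images.cocodataset.org/val2017/000000039769.jpg",
]
# One list of results per input image, each result holding a "generated_text" key.
for result in captioner(urls, max_new_tokens=20):
    print(result[0]["generated_text"])
```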
.venv/Lib/site-packages/transformers/pipelines/mask_generation.py ADDED
@@ -0,0 +1,287 @@
1
+ from collections import defaultdict
2
+ from typing import Optional
3
+
4
+ from ..image_utils import load_image
5
+ from ..utils import (
6
+ add_end_docstrings,
7
+ is_torch_available,
8
+ logging,
9
+ requires_backends,
10
+ )
11
+ from .base import ChunkPipeline, build_pipeline_init_args
12
+
13
+
14
+ if is_torch_available():
15
+ import torch
16
+
17
+ from ..models.auto.modeling_auto import MODEL_FOR_MASK_GENERATION_MAPPING_NAMES
18
+
19
+ logger = logging.get_logger(__name__)
20
+
21
+
22
+ @add_end_docstrings(
23
+ build_pipeline_init_args(has_image_processor=True),
24
+ r"""
25
+ points_per_batch (`int`, *optional*, defaults to 64):
26
+ Sets the number of points run simultaneously by the model. Higher numbers may be faster but use more GPU
27
+ memory.
28
+ output_bboxes_mask (`bool`, *optional*, defaults to `False`):
29
+ Whether or not to output the bounding box predictions.
30
+ output_rle_mask (`bool`, *optional*, defaults to `False`):
31
+ Whether or not to output the masks in `RLE` format""",
32
+ )
33
+ class MaskGenerationPipeline(ChunkPipeline):
34
+ """
35
+ Automatic mask generation for images using `SamForMaskGeneration`. This pipeline predicts binary masks for an
36
+ image, given only the image as input. It is a `ChunkPipeline` because you can split the points into mini-batches in order to
37
+ avoid OOM issues. Use the `points_per_batch` argument to control the number of points that will be processed at the
38
+ same time. Default is `64`.
39
+
40
+ The pipeline works in 3 steps:
41
+ 1. `preprocess`: A grid of 1024 points evenly separated is generated along with bounding boxes and point
42
+ labels.
43
+ For more details on how the points and bounding boxes are created, check the `_generate_crop_boxes`
44
+ function. The image is also preprocessed using the `image_processor`. This function `yields` a minibatch of
45
+ `points_per_batch`.
46
+
47
+ 2. `forward`: feeds the outputs of `preprocess` to the model. The image embedding is computed only once.
48
+ Calls both `self.model.get_image_embeddings` and makes sure that the gradients are not computed, and the
49
+ tensors and models are on the same device.
50
+
51
+ 3. `postprocess`: The most important part of the automatic mask generation happens here. Three steps
52
+ are involved:
53
+ - image_processor.postprocess_masks (run on each minibatch loop): takes in the raw output masks,
54
+ resizes them according
55
+ to the image size, and transforms them to binary masks.
56
+ - image_processor.filter_masks (on each minibatch loop): uses both `pred_iou_thresh` and
57
+ `stability_scores`. Also
58
+ applies a variety of filters based on non-maximum suppression to remove bad masks.
59
+ - image_processor.postprocess_masks_for_amg applies NMS on the masks to only keep relevant ones.
60
+
61
+ Example:
62
+
63
+ ```python
64
+ >>> from transformers import pipeline
65
+
66
+ >>> generator = pipeline(model="facebook/sam-vit-base", task="mask-generation")
67
+ >>> outputs = generator(
68
+ ... "http://images.cocodataset.org/val2017/000000039769.jpg",
69
+ ... )
70
+
71
+ >>> outputs = generator(
72
+ ... "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png", points_per_batch=128
73
+ ... )
74
+ ```
75
+
76
+ Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
77
+
78
+ This segmentation pipeline can currently be loaded from [`pipeline`] using the following task identifier:
79
+ `"mask-generation"`.
80
+
81
+ See the list of available models on [huggingface.co/models](https://huggingface.co/models?filter=mask-generation).
82
+ """
83
+
84
+ def __init__(self, **kwargs):
85
+ super().__init__(**kwargs)
86
+ requires_backends(self, "vision")
87
+ requires_backends(self, "torch")
88
+
89
+ if self.framework != "pt":
90
+ raise ValueError(f"The {self.__class__} is only available in PyTorch.")
91
+
92
+ self.check_model_type(MODEL_FOR_MASK_GENERATION_MAPPING_NAMES)
93
+
94
+ def _sanitize_parameters(self, **kwargs):
95
+ preprocess_kwargs = {}
96
+ postprocess_kwargs = {}
97
+ forward_params = {}
98
+ # preprocess args
99
+ if "points_per_batch" in kwargs:
100
+ preprocess_kwargs["points_per_batch"] = kwargs["points_per_batch"]
101
+ if "points_per_crop" in kwargs:
102
+ preprocess_kwargs["points_per_crop"] = kwargs["points_per_crop"]
103
+ if "crops_n_layers" in kwargs:
104
+ preprocess_kwargs["crops_n_layers"] = kwargs["crops_n_layers"]
105
+ if "crop_overlap_ratio" in kwargs:
106
+ preprocess_kwargs["crop_overlap_ratio"] = kwargs["crop_overlap_ratio"]
107
+ if "crop_n_points_downscale_factor" in kwargs:
108
+ preprocess_kwargs["crop_n_points_downscale_factor"] = kwargs["crop_n_points_downscale_factor"]
109
+ if "timeout" in kwargs:
110
+ preprocess_kwargs["timeout"] = kwargs["timeout"]
111
+ # postprocess args
112
+ if "pred_iou_thresh" in kwargs:
113
+ forward_params["pred_iou_thresh"] = kwargs["pred_iou_thresh"]
114
+ if "stability_score_offset" in kwargs:
115
+ forward_params["stability_score_offset"] = kwargs["stability_score_offset"]
116
+ if "mask_threshold" in kwargs:
117
+ forward_params["mask_threshold"] = kwargs["mask_threshold"]
118
+ if "stability_score_thresh" in kwargs:
119
+ forward_params["stability_score_thresh"] = kwargs["stability_score_thresh"]
120
+ if "crops_nms_thresh" in kwargs:
121
+ postprocess_kwargs["crops_nms_thresh"] = kwargs["crops_nms_thresh"]
122
+ if "output_rle_mask" in kwargs:
123
+ postprocess_kwargs["output_rle_mask"] = kwargs["output_rle_mask"]
124
+ if "output_bboxes_mask" in kwargs:
125
+ postprocess_kwargs["output_bboxes_mask"] = kwargs["output_bboxes_mask"]
126
+ return preprocess_kwargs, forward_params, postprocess_kwargs
127
+
128
+ def __call__(self, image, *args, num_workers=None, batch_size=None, **kwargs):
129
+ """
130
+ Generates binary segmentation masks
131
+
132
+ Args:
133
+ image (`np.ndarray` or `bytes` or `str` or `dict`):
134
+ Image or list of images.
135
+ mask_threshold (`float`, *optional*, defaults to 0.0):
136
+ Threshold to use when turning the predicted masks into binary values.
137
+ pred_iou_thresh (`float`, *optional*, defaults to 0.88):
138
+ A filtering threshold in `[0,1]` applied on the model's predicted mask quality.
139
+ stability_score_thresh (`float`, *optional*, defaults to 0.95):
140
+ A filtering threshold in `[0,1]`, using the stability of the mask under changes to the cutoff used to
141
+ binarize the model's mask predictions.
142
+ stability_score_offset (`int`, *optional*, defaults to 1):
143
+ The amount to shift the cutoff when calculating the stability score.
144
+ crops_nms_thresh (`float`, *optional*, defaults to 0.7):
145
+ The box IoU cutoff used by non-maximal suppression to filter duplicate masks.
146
+ crops_n_layers (`int`, *optional*, defaults to 0):
147
+ If `crops_n_layers>0`, mask prediction will be run again on crops of the image. Sets the number of
148
+ layers to run, where each layer has 2**i_layer number of image crops.
149
+ crop_overlap_ratio (`float`, *optional*, defaults to `512 / 1500`):
150
+ Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of
151
+ the image length. Later layers with more crops scale down this overlap.
152
+ crop_n_points_downscale_factor (`int`, *optional*, defaults to `1`):
153
+ The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
154
+ timeout (`float`, *optional*, defaults to None):
155
+ The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
156
+ the call may block forever.
157
+
158
+ Return:
159
+ `Dict`: A dictionary with the following keys:
160
+ - **mask** (`PIL.Image`) -- A binary mask of the detected object as a PIL Image of shape `(width,
161
+ height)` of the original image. Returns a mask filled with zeros if no object is found.
162
+ - **score** (*optional* `float`) -- The confidence score of the "object" described by the label and
163
+ the "object" described by the label and the mask.
164
+
165
+ """
166
+ return super().__call__(image, *args, num_workers=num_workers, batch_size=batch_size, **kwargs)
167
+
168
+ def preprocess(
169
+ self,
170
+ image,
171
+ points_per_batch=64,
172
+ crops_n_layers: int = 0,
173
+ crop_overlap_ratio: float = 512 / 1500,
174
+ points_per_crop: Optional[int] = 32,
175
+ crop_n_points_downscale_factor: Optional[int] = 1,
176
+ timeout: Optional[float] = None,
177
+ ):
178
+ image = load_image(image, timeout=timeout)
179
+ target_size = self.image_processor.size["longest_edge"]
180
+ crop_boxes, grid_points, cropped_images, input_labels = self.image_processor.generate_crop_boxes(
181
+ image, target_size, crops_n_layers, crop_overlap_ratio, points_per_crop, crop_n_points_downscale_factor
182
+ )
183
+ model_inputs = self.image_processor(images=cropped_images, return_tensors="pt")
184
+ if self.framework == "pt":
185
+ model_inputs = model_inputs.to(self.torch_dtype)
186
+
187
+ with self.device_placement():
188
+ if self.framework == "pt":
189
+ inference_context = self.get_inference_context()
190
+ with inference_context():
191
+ model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device)
192
+ image_embeddings = self.model.get_image_embeddings(model_inputs.pop("pixel_values"))
193
+ model_inputs["image_embeddings"] = image_embeddings
194
+
195
+ n_points = grid_points.shape[1]
196
+ points_per_batch = points_per_batch if points_per_batch is not None else n_points
197
+
198
+ if points_per_batch <= 0:
199
+ raise ValueError(
200
+ "Cannot have points_per_batch<=0. Must be >=1 to returned batched outputs. "
201
+ "To return all points at once, set points_per_batch to None"
202
+ )
203
+
204
+ for i in range(0, n_points, points_per_batch):
205
+ batched_points = grid_points[:, i : i + points_per_batch, :, :]
206
+ labels = input_labels[:, i : i + points_per_batch]
207
+ is_last = i == n_points - points_per_batch
208
+ yield {
209
+ "input_points": batched_points,
210
+ "input_labels": labels,
211
+ "input_boxes": crop_boxes,
212
+ "is_last": is_last,
213
+ **model_inputs,
214
+ }
215
+
216
+ def _forward(
217
+ self,
218
+ model_inputs,
219
+ pred_iou_thresh=0.88,
220
+ stability_score_thresh=0.95,
221
+ mask_threshold=0,
222
+ stability_score_offset=1,
223
+ ):
224
+ input_boxes = model_inputs.pop("input_boxes")
225
+ is_last = model_inputs.pop("is_last")
226
+ original_sizes = model_inputs.pop("original_sizes").tolist()
227
+ reshaped_input_sizes = model_inputs.pop("reshaped_input_sizes").tolist()
228
+
229
+ model_outputs = self.model(**model_inputs)
230
+
231
+ # post processing happens here in order to avoid CPU GPU copies of ALL the masks
232
+ low_resolution_masks = model_outputs["pred_masks"]
233
+ masks = self.image_processor.post_process_masks(
234
+ low_resolution_masks, original_sizes, reshaped_input_sizes, mask_threshold, binarize=False
235
+ )
236
+ iou_scores = model_outputs["iou_scores"]
237
+ masks, iou_scores, boxes = self.image_processor.filter_masks(
238
+ masks[0],
239
+ iou_scores[0],
240
+ original_sizes[0],
241
+ input_boxes[0],
242
+ pred_iou_thresh,
243
+ stability_score_thresh,
244
+ mask_threshold,
245
+ stability_score_offset,
246
+ )
247
+ return {
248
+ "masks": masks,
249
+ "is_last": is_last,
250
+ "boxes": boxes,
251
+ "iou_scores": iou_scores,
252
+ }
253
+
254
+ def postprocess(
255
+ self,
256
+ model_outputs,
257
+ output_rle_mask=False,
258
+ output_bboxes_mask=False,
259
+ crops_nms_thresh=0.7,
260
+ ):
261
+ all_scores = []
262
+ all_masks = []
263
+ all_boxes = []
264
+ for model_output in model_outputs:
265
+ all_scores.append(model_output.pop("iou_scores"))
266
+ all_masks.extend(model_output.pop("masks"))
267
+ all_boxes.append(model_output.pop("boxes"))
268
+
269
+ all_scores = torch.cat(all_scores)
270
+ all_boxes = torch.cat(all_boxes)
271
+ output_masks, iou_scores, rle_mask, bounding_boxes = self.image_processor.post_process_for_mask_generation(
272
+ all_masks, all_scores, all_boxes, crops_nms_thresh
273
+ )
274
+
275
+ extra = defaultdict(list)
276
+ for output in model_outputs:
277
+ for k, v in output.items():
278
+ extra[k].append(v)
279
+
280
+ optional = {}
281
+ if output_rle_mask:
282
+ optional["rle_mask"] = rle_mask
283
+
284
+ if output_bboxes_mask:
285
+ optional["bounding_boxes"] = bounding_boxes
286
+
287
+ return {"masks": output_masks, "scores": iou_scores, **optional, **extra}
.venv/Lib/site-packages/transformers/pipelines/object_detection.py ADDED
@@ -0,0 +1,191 @@
1
+ from typing import Any, Dict, List, Union
2
+
3
+ from ..utils import add_end_docstrings, is_torch_available, is_vision_available, logging, requires_backends
4
+ from .base import Pipeline, build_pipeline_init_args
5
+
6
+
7
+ if is_vision_available():
8
+ from ..image_utils import load_image
9
+
10
+
11
+ if is_torch_available():
12
+ import torch
13
+
14
+ from ..models.auto.modeling_auto import (
15
+ MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES,
16
+ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
17
+ )
18
+
19
+ logger = logging.get_logger(__name__)
20
+
21
+
22
+ Prediction = Dict[str, Any]
23
+ Predictions = List[Prediction]
24
+
25
+
26
+ @add_end_docstrings(build_pipeline_init_args(has_image_processor=True))
27
+ class ObjectDetectionPipeline(Pipeline):
28
+ """
29
+ Object detection pipeline using any `AutoModelForObjectDetection`. This pipeline predicts bounding boxes of objects
30
+ and their classes.
31
+
32
+ Example:
33
+
34
+ ```python
35
+ >>> from transformers import pipeline
36
+
37
+ >>> detector = pipeline(model="facebook/detr-resnet-50")
38
+ >>> detector("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png")
39
+ [{'score': 0.997, 'label': 'bird', 'box': {'xmin': 69, 'ymin': 171, 'xmax': 396, 'ymax': 507}}, {'score': 0.999, 'label': 'bird', 'box': {'xmin': 398, 'ymin': 105, 'xmax': 767, 'ymax': 507}}]
40
+
41
+ >>> # x, y are expressed relative to the top left hand corner.
42
+ ```
43
+
44
+ Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
45
+
46
+ This object detection pipeline can currently be loaded from [`pipeline`] using the following task identifier:
47
+ `"object-detection"`.
48
+
49
+ See the list of available models on [huggingface.co/models](https://huggingface.co/models?filter=object-detection).
50
+ """
51
+
52
+ def __init__(self, *args, **kwargs):
53
+ super().__init__(*args, **kwargs)
54
+
55
+ if self.framework == "tf":
56
+ raise ValueError(f"The {self.__class__} is only available in PyTorch.")
57
+
58
+ requires_backends(self, "vision")
59
+ mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES.copy()
60
+ mapping.update(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES)
61
+ self.check_model_type(mapping)
62
+
63
+ def _sanitize_parameters(self, **kwargs):
64
+ preprocess_params = {}
65
+ if "timeout" in kwargs:
66
+ preprocess_params["timeout"] = kwargs["timeout"]
67
+ postprocess_kwargs = {}
68
+ if "threshold" in kwargs:
69
+ postprocess_kwargs["threshold"] = kwargs["threshold"]
70
+ return preprocess_params, {}, postprocess_kwargs
71
+
72
+ def __call__(self, *args, **kwargs) -> Union[Predictions, List[Prediction]]:
73
+ """
74
+ Detect objects (bounding boxes & classes) in the image(s) passed as inputs.
75
+
76
+ Args:
77
+ inputs (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
78
+ The pipeline handles three types of images:
79
+
80
+ - A string containing an HTTP(S) link pointing to an image
81
+ - A string containing a local path to an image
82
+ - An image loaded in PIL directly
83
+
84
+ The pipeline accepts either a single image or a batch of images. Images in a batch must all be in the
85
+ same format: all as HTTP(S) links, all as local paths, or all as PIL images.
86
+ threshold (`float`, *optional*, defaults to 0.5):
87
+ The minimum confidence score required for a detection to be returned.
88
+ timeout (`float`, *optional*, defaults to None):
89
+ The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
90
+ the call may block forever.
91
+
92
+ Return:
93
+ A list of dictionaries or a list of list of dictionaries containing the result. If the input is a single
94
+ image, will return a list of dictionaries, if the input is a list of several images, will return a list of
95
+ list of dictionaries corresponding to each image.
96
+
97
+ The dictionaries contain the following keys:
98
+
99
+ - **label** (`str`) -- The class label identified by the model.
100
+ - **score** (`float`) -- The score attributed by the model for that label.
101
+ - **box** (`Dict[str, int]`) -- The bounding box of the detected object in the image's original size.
102
+ """
103
+ # After deprecation of this is completed, remove the default `None` value for `images`
104
+ if "images" in kwargs and "inputs" not in kwargs:
105
+ kwargs["inputs"] = kwargs.pop("images")
106
+ return super().__call__(*args, **kwargs)
107
+
108
+ def preprocess(self, image, timeout=None):
109
+ image = load_image(image, timeout=timeout)
110
+ target_size = torch.IntTensor([[image.height, image.width]])
111
+ inputs = self.image_processor(images=[image], return_tensors="pt")
112
+ if self.framework == "pt":
113
+ inputs = inputs.to(self.torch_dtype)
114
+ if self.tokenizer is not None:
115
+ inputs = self.tokenizer(text=inputs["words"], boxes=inputs["boxes"], return_tensors="pt")
116
+ inputs["target_size"] = target_size
117
+ return inputs
118
+
119
+ def _forward(self, model_inputs):
120
+ target_size = model_inputs.pop("target_size")
121
+ outputs = self.model(**model_inputs)
122
+ model_outputs = outputs.__class__({"target_size": target_size, **outputs})
123
+ if self.tokenizer is not None:
124
+ model_outputs["bbox"] = model_inputs["bbox"]
125
+ return model_outputs
126
+
127
+ def postprocess(self, model_outputs, threshold=0.5):
128
+ target_size = model_outputs["target_size"]
129
+ if self.tokenizer is not None:
130
+ # This is a LayoutLMForTokenClassification variant.
131
+ # The OCR got the boxes and the model classified the words.
132
+ height, width = target_size[0].tolist()
133
+
134
+ def unnormalize(bbox):
135
+ return self._get_bounding_box(
136
+ torch.Tensor(
137
+ [
138
+ (width * bbox[0] / 1000),
139
+ (height * bbox[1] / 1000),
140
+ (width * bbox[2] / 1000),
141
+ (height * bbox[3] / 1000),
142
+ ]
143
+ )
144
+ )
145
+
146
+ scores, classes = model_outputs["logits"].squeeze(0).softmax(dim=-1).max(dim=-1)
147
+ labels = [self.model.config.id2label[prediction] for prediction in classes.tolist()]
148
+ boxes = [unnormalize(bbox) for bbox in model_outputs["bbox"].squeeze(0)]
149
+ keys = ["score", "label", "box"]
150
+ annotation = [dict(zip(keys, vals)) for vals in zip(scores.tolist(), labels, boxes) if vals[0] > threshold]
151
+ else:
152
+ # This is a regular ForObjectDetectionModel
153
+ raw_annotations = self.image_processor.post_process_object_detection(model_outputs, threshold, target_size)
154
+ raw_annotation = raw_annotations[0]
155
+ scores = raw_annotation["scores"]
156
+ labels = raw_annotation["labels"]
157
+ boxes = raw_annotation["boxes"]
158
+
159
+ raw_annotation["scores"] = scores.tolist()
160
+ raw_annotation["labels"] = [self.model.config.id2label[label.item()] for label in labels]
161
+ raw_annotation["boxes"] = [self._get_bounding_box(box) for box in boxes]
162
+
163
+ # {"scores": [...], ...} --> [{"score":x, ...}, ...]
164
+ keys = ["score", "label", "box"]
165
+ annotation = [
166
+ dict(zip(keys, vals))
167
+ for vals in zip(raw_annotation["scores"], raw_annotation["labels"], raw_annotation["boxes"])
168
+ ]
169
+
170
+ return annotation
171
+
172
+ def _get_bounding_box(self, box: "torch.Tensor") -> Dict[str, int]:
173
+ """
174
+ Turns list [xmin, ymin, xmax, ymax] into dict { "xmin": xmin, ... }
175
+
176
+ Args:
177
+ box (`torch.Tensor`): Tensor containing the coordinates in corners format.
178
+
179
+ Returns:
180
+ bbox (`Dict[str, int]`): Dict containing the coordinates in corners format.
181
+ """
182
+ if self.framework != "pt":
183
+ raise ValueError("The ObjectDetectionPipeline is only available in PyTorch.")
184
+ xmin, ymin, xmax, ymax = box.int().tolist()
185
+ bbox = {
186
+ "xmin": xmin,
187
+ "ymin": ymin,
188
+ "xmax": xmax,
189
+ "ymax": ymax,
190
+ }
191
+ return bbox
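To round out the detection pipeline above, a short sketch using the DETR checkpoint and image URL from the docstring example, with an explicit `threshold`:

```python
from transformers import pipeline

detector = pipeline("object-detection", model="facebook/detr-resnet-50")
detections = detector(
    "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png",
    threshold=0.9,  # keep only high-confidence boxes
)
for det in detections:
    box = det["box"]  # {"xmin": ..., "ymin": ..., "xmax": ..., "ymax": ...}
    print(f"{det['label']}: {det['score']:.3f} at {box}")
```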
.venv/Lib/site-packages/transformers/utils/__init__.py ADDED
@@ -0,0 +1,315 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+
4
+ # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ from functools import lru_cache
19
+ from typing import FrozenSet
20
+
21
+ from huggingface_hub import get_full_repo_name # for backward compatibility
22
+ from huggingface_hub.constants import HF_HUB_DISABLE_TELEMETRY as DISABLE_TELEMETRY # for backward compatibility
23
+ from packaging import version
24
+
25
+ from .. import __version__
26
+ from .backbone_utils import BackboneConfigMixin, BackboneMixin
27
+ from .chat_template_utils import DocstringParsingException, TypeHintParsingException, get_json_schema
28
+ from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD
29
+ from .doc import (
30
+ add_code_sample_docstrings,
31
+ add_end_docstrings,
32
+ add_start_docstrings,
33
+ add_start_docstrings_to_model_forward,
34
+ copy_func,
35
+ replace_return_docstrings,
36
+ )
37
+ from .generic import (
38
+ ContextManagers,
39
+ ExplicitEnum,
40
+ LossKwargs,
41
+ ModelOutput,
42
+ PaddingStrategy,
43
+ TensorType,
44
+ add_model_info_to_auto_map,
45
+ add_model_info_to_custom_pipelines,
46
+ cached_property,
47
+ can_return_loss,
48
+ expand_dims,
49
+ filter_out_non_signature_kwargs,
50
+ find_labels,
51
+ flatten_dict,
52
+ infer_framework,
53
+ is_jax_tensor,
54
+ is_numpy_array,
55
+ is_tensor,
56
+ is_tf_symbolic_tensor,
57
+ is_tf_tensor,
58
+ is_torch_device,
59
+ is_torch_dtype,
60
+ is_torch_tensor,
61
+ reshape,
62
+ squeeze,
63
+ strtobool,
64
+ tensor_size,
65
+ to_numpy,
66
+ to_py_obj,
67
+ torch_float,
68
+ torch_int,
69
+ transpose,
70
+ working_or_temp_dir,
71
+ )
72
+ from .hub import (
73
+ CLOUDFRONT_DISTRIB_PREFIX,
74
+ HF_MODULES_CACHE,
75
+ HUGGINGFACE_CO_PREFIX,
76
+ HUGGINGFACE_CO_RESOLVE_ENDPOINT,
77
+ PYTORCH_PRETRAINED_BERT_CACHE,
78
+ PYTORCH_TRANSFORMERS_CACHE,
79
+ S3_BUCKET_PREFIX,
80
+ TRANSFORMERS_CACHE,
81
+ TRANSFORMERS_DYNAMIC_MODULE_NAME,
82
+ EntryNotFoundError,
83
+ PushInProgress,
84
+ PushToHubMixin,
85
+ RepositoryNotFoundError,
86
+ RevisionNotFoundError,
87
+ cached_file,
88
+ default_cache_path,
89
+ define_sagemaker_information,
90
+ download_url,
91
+ extract_commit_hash,
92
+ get_cached_models,
93
+ get_file_from_repo,
94
+ has_file,
95
+ http_user_agent,
96
+ is_offline_mode,
97
+ is_remote_url,
98
+ move_cache,
99
+ send_example_telemetry,
100
+ try_to_load_from_cache,
101
+ )
102
+ from .import_utils import (
103
+ ACCELERATE_MIN_VERSION,
104
+ ENV_VARS_TRUE_AND_AUTO_VALUES,
105
+ ENV_VARS_TRUE_VALUES,
106
+ GGUF_MIN_VERSION,
107
+ TORCH_FX_REQUIRED_VERSION,
108
+ USE_JAX,
109
+ USE_TF,
110
+ USE_TORCH,
111
+ XLA_FSDPV2_MIN_VERSION,
112
+ DummyObject,
113
+ OptionalDependencyNotAvailable,
114
+ _LazyModule,
115
+ ccl_version,
116
+ direct_transformers_import,
117
+ get_torch_version,
118
+ is_accelerate_available,
119
+ is_apex_available,
120
+ is_aqlm_available,
121
+ is_auto_awq_available,
122
+ is_auto_gptq_available,
123
+ is_av_available,
124
+ is_bitsandbytes_available,
125
+ is_bitsandbytes_multi_backend_available,
126
+ is_bs4_available,
127
+ is_coloredlogs_available,
128
+ is_compressed_tensors_available,
129
+ is_cv2_available,
130
+ is_cython_available,
131
+ is_datasets_available,
132
+ is_detectron2_available,
133
+ is_eetq_available,
134
+ is_essentia_available,
135
+ is_faiss_available,
136
+ is_fbgemm_gpu_available,
137
+ is_flash_attn_2_available,
138
+ is_flash_attn_greater_or_equal,
139
+ is_flash_attn_greater_or_equal_2_10,
140
+ is_flax_available,
141
+ is_fsdp_available,
142
+ is_ftfy_available,
143
+ is_g2p_en_available,
144
+ is_galore_torch_available,
145
+ is_gguf_available,
146
+ is_grokadamw_available,
147
+ is_hqq_available,
148
+ is_in_notebook,
149
+ is_ipex_available,
150
+ is_jieba_available,
151
+ is_jinja_available,
152
+ is_jumanpp_available,
153
+ is_kenlm_available,
154
+ is_keras_nlp_available,
155
+ is_levenshtein_available,
156
+ is_librosa_available,
157
+ is_liger_kernel_available,
158
+ is_lomo_available,
159
+ is_mlx_available,
160
+ is_natten_available,
161
+ is_ninja_available,
162
+ is_nltk_available,
163
+ is_onnx_available,
164
+ is_openai_available,
165
+ is_optimum_available,
166
+ is_optimum_quanto_available,
167
+ is_pandas_available,
168
+ is_peft_available,
169
+ is_phonemizer_available,
170
+ is_pretty_midi_available,
171
+ is_protobuf_available,
172
+ is_psutil_available,
173
+ is_py3nvml_available,
174
+ is_pyctcdecode_available,
175
+ is_pytesseract_available,
176
+ is_pytest_available,
177
+ is_pytorch_quantization_available,
178
+ is_rjieba_available,
179
+ is_sacremoses_available,
180
+ is_safetensors_available,
181
+ is_sagemaker_dp_enabled,
182
+ is_sagemaker_mp_enabled,
183
+ is_schedulefree_available,
184
+ is_scipy_available,
185
+ is_sentencepiece_available,
186
+ is_seqio_available,
187
+ is_sklearn_available,
188
+ is_soundfile_availble,
189
+ is_spacy_available,
190
+ is_speech_available,
191
+ is_sudachi_available,
192
+ is_sudachi_projection_available,
193
+ is_tensorflow_probability_available,
194
+ is_tensorflow_text_available,
195
+ is_tf2onnx_available,
196
+ is_tf_available,
197
+ is_tiktoken_available,
198
+ is_timm_available,
199
+ is_tokenizers_available,
200
+ is_torch_available,
201
+ is_torch_bf16_available,
202
+ is_torch_bf16_available_on_device,
203
+ is_torch_bf16_cpu_available,
204
+ is_torch_bf16_gpu_available,
205
+ is_torch_compile_available,
206
+ is_torch_cuda_available,
207
+ is_torch_deterministic,
208
+ is_torch_flex_attn_available,
209
+ is_torch_fp16_available_on_device,
210
+ is_torch_fx_available,
211
+ is_torch_fx_proxy,
212
+ is_torch_greater_or_equal,
213
+ is_torch_mlu_available,
214
+ is_torch_mps_available,
215
+ is_torch_musa_available,
216
+ is_torch_neuroncore_available,
217
+ is_torch_npu_available,
218
+ is_torch_sdpa_available,
219
+ is_torch_tensorrt_fx_available,
220
+ is_torch_tf32_available,
221
+ is_torch_tpu_available,
222
+ is_torch_xla_available,
223
+ is_torch_xpu_available,
224
+ is_torchao_available,
225
+ is_torchaudio_available,
226
+ is_torchdistx_available,
227
+ is_torchdynamo_available,
228
+ is_torchdynamo_compiling,
229
+ is_torchvision_available,
230
+ is_torchvision_v2_available,
231
+ is_training_run_on_sagemaker,
232
+ is_uroman_available,
233
+ is_vision_available,
234
+ requires_backends,
235
+ torch_only_method,
236
+ )
237
+ from .peft_utils import (
238
+ ADAPTER_CONFIG_NAME,
239
+ ADAPTER_SAFE_WEIGHTS_NAME,
240
+ ADAPTER_WEIGHTS_NAME,
241
+ check_peft_version,
242
+ find_adapter_config_file,
243
+ )
244
+
245
+
246
+ WEIGHTS_NAME = "pytorch_model.bin"
247
+ WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
248
+ TF2_WEIGHTS_NAME = "tf_model.h5"
249
+ TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json"
250
+ TF_WEIGHTS_NAME = "model.ckpt"
251
+ FLAX_WEIGHTS_NAME = "flax_model.msgpack"
252
+ FLAX_WEIGHTS_INDEX_NAME = "flax_model.msgpack.index.json"
253
+ SAFE_WEIGHTS_NAME = "model.safetensors"
254
+ SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
255
+ CONFIG_NAME = "config.json"
256
+ FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
257
+ IMAGE_PROCESSOR_NAME = FEATURE_EXTRACTOR_NAME
258
+ PROCESSOR_NAME = "processor_config.json"
259
+ CHAT_TEMPLATE_NAME = "chat_template.json"
260
+ GENERATION_CONFIG_NAME = "generation_config.json"
261
+ MODEL_CARD_NAME = "modelcard.json"
262
+
263
+ SENTENCEPIECE_UNDERLINE = "▁"
264
+ SPIECE_UNDERLINE = SENTENCEPIECE_UNDERLINE # Kept for backward compatibility
265
+
266
+ MULTIPLE_CHOICE_DUMMY_INPUTS = [
267
+ [[0, 1, 0, 1], [1, 0, 0, 1]]
268
+ ] * 2 # Needs to have 0s and 1s only since XLM uses it for langs too.
269
+ DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
270
+ DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]
271
+
272
+
273
+ def check_min_version(min_version):
274
+ if version.parse(__version__) < version.parse(min_version):
275
+ if "dev" in min_version:
276
+ error_message = (
277
+ "This example requires a source install from HuggingFace Transformers (see "
278
+ "`https://huggingface.co/docs/transformers/installation#install-from-source`),"
279
+ )
280
+ else:
281
+ error_message = f"This example requires a minimum version of {min_version},"
282
+ error_message += f" but the version found is {__version__}.\n"
283
+ raise ImportError(
284
+ error_message
285
+ + "Check out https://github.com/huggingface/transformers/tree/main/examples#important-note for the examples corresponding to other "
286
+ "versions of HuggingFace Transformers."
287
+ )
288
+
289
+
290
+ @lru_cache()
291
+ def get_available_devices() -> FrozenSet[str]:
292
+ """
293
+ Returns a frozenset of devices available for the current PyTorch installation.
294
+ """
295
+ devices = {"cpu"} # `cpu` is always supported as a device in PyTorch
296
+
297
+ if is_torch_cuda_available():
298
+ devices.add("cuda")
299
+
300
+ if is_torch_mps_available():
301
+ devices.add("mps")
302
+
303
+ if is_torch_xpu_available():
304
+ devices.add("xpu")
305
+
306
+ if is_torch_npu_available():
307
+ devices.add("npu")
308
+
309
+ if is_torch_mlu_available():
310
+ devices.add("mlu")
311
+
312
+ if is_torch_musa_available():
313
+ devices.add("musa")
314
+
315
+ return frozenset(devices)
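A short sketch of how these module-level helpers are typically called; the minimum version string is only an example value.

from transformers.utils import check_min_version, get_available_devices

check_min_version("4.47.0")      # raises ImportError if the installed transformers is older
print(get_available_devices())   # e.g. frozenset({'cpu', 'cuda'}), depending on the PyTorch build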
.venv/Lib/site-packages/transformers/utils/__pycache__/backbone_utils.cpython-39.pyc ADDED
Binary file (13.8 kB).
 
.venv/Lib/site-packages/transformers/utils/__pycache__/chat_template_utils.cpython-39.pyc ADDED
Binary file (14.6 kB).
 
.venv/Lib/site-packages/transformers/utils/__pycache__/constants.cpython-39.pyc ADDED
Binary file (510 Bytes).
 
.venv/Lib/site-packages/transformers/utils/__pycache__/deprecation.cpython-39.pyc ADDED
Binary file (5.43 kB).
 
.venv/Lib/site-packages/transformers/utils/__pycache__/doc.cpython-39.pyc ADDED
Binary file (36 kB).
 
.venv/Lib/site-packages/transformers/utils/quantization_config.py ADDED
@@ -0,0 +1,1344 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+
4
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ import copy
18
+ import importlib.metadata
19
+ import json
20
+ import os
21
+ from dataclasses import dataclass
22
+ from enum import Enum
23
+ from inspect import Parameter, signature
24
+ from typing import Any, Dict, List, Optional, Union
25
+
26
+ from packaging import version
27
+
28
+ from ..utils import is_auto_awq_available, is_hqq_available, is_torch_available, is_torchao_available, logging
29
+
30
+
31
+ if is_torch_available():
32
+ import torch
33
+
34
+ logger = logging.get_logger(__name__)
35
+
36
+
37
+ class QuantizationMethod(str, Enum):
38
+ BITS_AND_BYTES = "bitsandbytes"
39
+ GPTQ = "gptq"
40
+ AWQ = "awq"
41
+ AQLM = "aqlm"
42
+ QUANTO = "quanto"
43
+ EETQ = "eetq"
44
+ HQQ = "hqq"
45
+ COMPRESSED_TENSORS = "compressed-tensors"
46
+ FBGEMM_FP8 = "fbgemm_fp8"
47
+ TORCHAO = "torchao"
48
+ BITNET = "bitnet"
49
+
50
+
51
+ class AWQLinearVersion(str, Enum):
52
+ GEMM = "gemm"
53
+ GEMV = "gemv"
54
+ EXLLAMA = "exllama"
55
+ IPEX = "ipex"
56
+
57
+ @staticmethod
58
+ def from_str(version: str):
59
+ version = version.lower()
60
+ if version == "gemm":
61
+ return AWQLinearVersion.GEMM
62
+ elif version == "gemv":
63
+ return AWQLinearVersion.GEMV
64
+ elif version == "exllama":
65
+ return AWQLinearVersion.EXLLAMA
66
+ elif version == "ipex":
67
+ return AWQLinearVersion.IPEX
68
+ else:
69
+ raise ValueError(f"Unknown AWQLinearVersion {version}")
70
+
71
+
72
+ class AwqBackendPackingMethod(str, Enum):
73
+ AUTOAWQ = "autoawq"
74
+ LLMAWQ = "llm-awq"
75
+
76
+
77
+ @dataclass
78
+ class QuantizationConfigMixin:
79
+ """
80
+ Mixin class for quantization config
81
+ """
82
+
83
+ quant_method: QuantizationMethod
84
+
85
+ @classmethod
86
+ def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
87
+ """
88
+ Instantiates a [`QuantizationConfigMixin`] from a Python dictionary of parameters.
89
+
90
+ Args:
91
+ config_dict (`Dict[str, Any]`):
92
+ Dictionary that will be used to instantiate the configuration object.
93
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
94
+ Whether or not to return a list of unused keyword arguments. Used for `from_pretrained` method in
95
+ `PreTrainedModel`.
96
+ kwargs (`Dict[str, Any]`):
97
+ Additional parameters from which to initialize the configuration object.
98
+
99
+ Returns:
100
+ [`QuantizationConfigMixin`]: The configuration object instantiated from those parameters.
101
+ """
102
+ config = cls(**config_dict)
103
+
104
+ to_remove = []
105
+ for key, value in kwargs.items():
106
+ if hasattr(config, key):
107
+ setattr(config, key, value)
108
+ to_remove.append(key)
109
+ for key in to_remove:
110
+ kwargs.pop(key, None)
111
+
112
+ if return_unused_kwargs:
113
+ return config, kwargs
114
+ else:
115
+ return config
116
+
117
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
118
+ """
119
+ Save this instance to a JSON file.
120
+
121
+ Args:
122
+ json_file_path (`str` or `os.PathLike`):
123
+ Path to the JSON file in which this configuration instance's parameters will be saved.
124
+ use_diff (`bool`, *optional*, defaults to `True`):
125
+ If set to `True`, only the difference between the config instance and the default
126
+ `QuantizationConfig()` is serialized to JSON file.
127
+ """
128
+ with open(json_file_path, "w", encoding="utf-8") as writer:
129
+ config_dict = self.to_dict()
130
+ json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
131
+
132
+ writer.write(json_string)
133
+
134
+ def to_dict(self) -> Dict[str, Any]:
135
+ """
136
+ Serializes this instance to a Python dictionary. Returns:
137
+ `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
138
+ """
139
+ return copy.deepcopy(self.__dict__)
140
+
141
+ def __iter__(self):
142
+ """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin"""
143
+ for attr, value in copy.deepcopy(self.__dict__).items():
144
+ yield attr, value
145
+
146
+ def __repr__(self):
147
+ return f"{self.__class__.__name__} {self.to_json_string()}"
148
+
149
+ def to_json_string(self, use_diff: bool = True) -> str:
150
+ """
151
+ Serializes this instance to a JSON string.
152
+
153
+ Args:
154
+ use_diff (`bool`, *optional*, defaults to `True`):
155
+ If set to `True`, only the difference between the config instance and the default `PretrainedConfig()`
156
+ is serialized to JSON string.
157
+
158
+ Returns:
159
+ `str`: String containing all the attributes that make up this configuration instance in JSON format.
160
+ """
161
+ if use_diff is True:
162
+ config_dict = self.to_diff_dict()
163
+ else:
164
+ config_dict = self.to_dict()
165
+ return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
166
+
167
+ def update(self, **kwargs):
168
+ """
169
+ Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
170
+ returning all the unused kwargs.
171
+
172
+ Args:
173
+ kwargs (`Dict[str, Any]`):
174
+ Dictionary of attributes to tentatively update this class.
175
+
176
+ Returns:
177
+ `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
178
+ """
179
+ to_remove = []
180
+ for key, value in kwargs.items():
181
+ if hasattr(self, key):
182
+ setattr(self, key, value)
183
+ to_remove.append(key)
184
+
185
+ # Remove all the attributes that were updated, without modifying the input dict
186
+ unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
187
+ return unused_kwargs
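A small sketch of the serialization helpers this mixin provides, shown through one of the concrete configs defined further down in this file; the values are illustrative.

from transformers import BitsAndBytesConfig  # a QuantizationConfigMixin subclass (defined below)

cfg = BitsAndBytesConfig(load_in_8bit=True, llm_int8_threshold=5.0)
print(cfg.to_json_string(use_diff=True))          # only the fields that differ from the defaults
unused = cfg.update(llm_int8_threshold=6.0, not_a_real_field=1)
print(unused)                                     # {'not_a_real_field': 1} - kwargs that matched no attribute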
188
+
189
+
190
+ @dataclass
191
+ class HqqConfig(QuantizationConfigMixin):
192
+ """
193
+ This is a wrapper around hqq's BaseQuantizeConfig.
194
+
195
+ Args:
196
+ nbits (`int`, *optional*, defaults to 4):
197
+ Number of bits. Supported values are (8, 4, 3, 2, 1).
198
+ group_size (`int`, *optional*, defaults to 64):
199
+ Group-size value. Supported values are any value that is divisible by weight.shape[axis].
200
+ view_as_float (`bool`, *optional*, defaults to `False`):
201
+ View the quantized weight as float (used in distributed training) if set to `True`.
202
+ axis (`Optional[int]`, *optional*):
203
+ Axis along which grouping is performed. Supported values are 0 or 1.
204
+ dynamic_config (dict, *optional*):
205
+ Parameters for dynamic configuration. The key is the name tag of the layer and the value is a quantization config.
206
+ If set, each layer specified by its id will use its dedicated quantization configuration.
207
+ skip_modules (`List[str]`, *optional*, defaults to `['lm_head']`):
208
+ List of `nn.Linear` layers to skip.
209
+ kwargs (`Dict[str, Any]`, *optional*):
210
+ Additional parameters from which to initialize the configuration object.
211
+ """
212
+
213
+ def __init__(
214
+ self,
215
+ nbits: int = 4,
216
+ group_size: int = 64,
217
+ view_as_float: bool = False,
218
+ axis: Optional[int] = None,
219
+ dynamic_config: Optional[dict] = None,
220
+ skip_modules: List[str] = ["lm_head"],
221
+ **kwargs,
222
+ ):
223
+ if is_hqq_available():
224
+ from hqq.core.quantize import BaseQuantizeConfig as HQQBaseQuantizeConfig
225
+
226
+ for deprecated_key in ["quant_zero", "quant_scale", "offload_meta"]:
227
+ if deprecated_key in kwargs:
228
+ logger.info(
229
+ deprecated_key + " is deprecated. This parameter will be ignored in quantization settings."
230
+ )
231
+
232
+ if axis is None:
233
+ axis = 1
234
+ logger.info("Setting axis=1 as faster backends such as TorchAO or BitBlas are only compatible with it.")
235
+
236
+ if axis not in [0, 1]:
237
+ raise ValueError("Invalid axis value. Only 0 and 1 are allowed.")
238
+
239
+ if dynamic_config is not None:
240
+ self.quant_config = {}
241
+ for key in dynamic_config:
242
+ self.quant_config[key] = HQQBaseQuantizeConfig(**dynamic_config[key])
243
+ else:
244
+ self.quant_config = HQQBaseQuantizeConfig(
245
+ **{
246
+ "nbits": nbits,
247
+ "group_size": group_size,
248
+ "view_as_float": view_as_float,
249
+ "axis": axis,
250
+ }
251
+ )
252
+
253
+ self.quant_method = QuantizationMethod.HQQ
254
+ self.skip_modules = skip_modules
255
+
256
+ self.post_init()
257
+
258
+ def post_init(self):
259
+ r"""
260
+ Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
261
+ """
262
+ pass
263
+
264
+ @classmethod
265
+ def from_dict(cls, config: Dict[str, Any]):
266
+ """
267
+ Override from_dict, used in AutoQuantizationConfig.from_dict in quantizers/auto.py
268
+ """
269
+ instance = cls()
270
+ instance.quant_config = config["quant_config"]
271
+ instance.skip_modules = config["skip_modules"]
272
+ return instance
273
+
274
+ def to_dict(self) -> Dict[str, Any]:
275
+ """
276
+ Serializes this instance to a Python dictionary. Returns:
277
+ `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
278
+ """
279
+ return {
280
+ "quant_config": self.quant_config,
281
+ "quant_method": self.quant_method,
282
+ "skip_modules": self.skip_modules,
283
+ }
284
+
285
+ def __repr__(self):
286
+ config_dict = self.to_dict()
287
+ return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
288
+
289
+ def to_diff_dict(self) -> Dict[str, Any]:
290
+ """
291
+ Removes all attributes from config which correspond to the default config attributes for better readability and
292
+ serializes to a Python dictionary.
293
+ Returns:
294
+ `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance,
295
+ """
296
+ config_dict = self.to_dict()
297
+
298
+ # get the default config dict
299
+ default_config_dict = HqqConfig().to_dict()
300
+
301
+ serializable_config_dict = {}
302
+
303
+ # only serialize values that differ from the default config
304
+ for key, value in config_dict.items():
305
+ if value != default_config_dict[key]:
306
+ serializable_config_dict[key] = value
307
+
308
+ return serializable_config_dict
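A hedged sketch of a 4-bit HQQ setup built from the arguments documented above; it assumes the optional `hqq` package is installed, and the checkpoint id is illustrative.

from transformers import AutoModelForCausalLM, HqqConfig

quant_config = HqqConfig(nbits=4, group_size=64)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",        # hypothetical checkpoint, not part of this file
    device_map="cuda",
    quantization_config=quant_config,
)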
309
+
310
+
311
+ @dataclass
312
+ class BitsAndBytesConfig(QuantizationConfigMixin):
313
+ """
314
+ This is a wrapper class about all possible attributes and features that you can play with a model that has been
315
+ loaded using `bitsandbytes`.
316
+
317
+ This replaces `load_in_8bit` or `load_in_4bit`, therefore both options are mutually exclusive.
318
+
319
+ Currently only supports `LLM.int8()`, `FP4`, and `NF4` quantization. If more methods are added to `bitsandbytes`,
320
+ then more arguments will be added to this class.
321
+
322
+ Args:
323
+ load_in_8bit (`bool`, *optional*, defaults to `False`):
324
+ This flag is used to enable 8-bit quantization with LLM.int8().
325
+ load_in_4bit (`bool`, *optional*, defaults to `False`):
326
+ This flag is used to enable 4-bit quantization by replacing the Linear layers with FP4/NF4 layers from
327
+ `bitsandbytes`.
328
+ llm_int8_threshold (`float`, *optional*, defaults to 6.0):
329
+ This corresponds to the outlier threshold for outlier detection as described in `LLM.int8() : 8-bit Matrix
330
+ Multiplication for Transformers at Scale` paper: https://arxiv.org/abs/2208.07339 Any hidden states value
331
+ that is above this threshold will be considered an outlier and the operation on those values will be done
332
+ in fp16. Values are usually normally distributed, that is, most values are in the range [-3.5, 3.5], but
333
+ there are some exceptional systematic outliers that are very differently distributed for large models.
334
+ These outliers are often in the interval [-60, -6] or [6, 60]. Int8 quantization works well for values of
335
+ magnitude ~5, but beyond that, there is a significant performance penalty. A good default threshold is 6,
336
+ but a lower threshold might be needed for more unstable models (small models, fine-tuning).
337
+ llm_int8_skip_modules (`List[str]`, *optional*):
338
+ An explicit list of the modules that we do not want to convert in 8-bit. This is useful for models such as
339
+ Jukebox that has several heads in different places and not necessarily at the last position. For example
340
+ for `CausalLM` models, the last `lm_head` is kept in its original `dtype`.
341
+ llm_int8_enable_fp32_cpu_offload (`bool`, *optional*, defaults to `False`):
342
+ This flag is used for advanced use cases and users that are aware of this feature. If you want to split
343
+ your model in different parts and run some parts in int8 on GPU and some parts in fp32 on CPU, you can use
344
+ this flag. This is useful for offloading large models such as `google/flan-t5-xxl`. Note that the int8
345
+ operations will not be run on CPU.
346
+ llm_int8_has_fp16_weight (`bool`, *optional*, defaults to `False`):
347
+ This flag runs LLM.int8() with 16-bit main weights. This is useful for fine-tuning as the weights do not
348
+ have to be converted back and forth for the backward pass.
349
+ bnb_4bit_compute_dtype (`torch.dtype` or str, *optional*, defaults to `torch.float32`):
350
+ This sets the computational type which might be different than the input type. For example, inputs might be
351
+ fp32, but computation can be set to bf16 for speedups.
352
+ bnb_4bit_quant_type (`str`, *optional*, defaults to `"fp4"`):
353
+ This sets the quantization data type in the bnb.nn.Linear4Bit layers. Options are FP4 and NF4 data types
354
+ which are specified by `fp4` or `nf4`.
355
+ bnb_4bit_use_double_quant (`bool`, *optional*, defaults to `False`):
356
+ This flag is used for nested quantization where the quantization constants from the first quantization are
357
+ quantized again.
358
+ bnb_4bit_quant_storage (`torch.dtype` or str, *optional*, defaults to `torch.uint8`):
359
+ This sets the storage type to pack the quantized 4-bit params.
360
+ kwargs (`Dict[str, Any]`, *optional*):
361
+ Additional parameters from which to initialize the configuration object.
362
+ """
363
+
364
+ def __init__(
365
+ self,
366
+ load_in_8bit=False,
367
+ load_in_4bit=False,
368
+ llm_int8_threshold=6.0,
369
+ llm_int8_skip_modules=None,
370
+ llm_int8_enable_fp32_cpu_offload=False,
371
+ llm_int8_has_fp16_weight=False,
372
+ bnb_4bit_compute_dtype=None,
373
+ bnb_4bit_quant_type="fp4",
374
+ bnb_4bit_use_double_quant=False,
375
+ bnb_4bit_quant_storage=None,
376
+ **kwargs,
377
+ ):
378
+ self.quant_method = QuantizationMethod.BITS_AND_BYTES
379
+
380
+ if load_in_4bit and load_in_8bit:
381
+ raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time")
382
+
383
+ self._load_in_8bit = load_in_8bit
384
+ self._load_in_4bit = load_in_4bit
385
+ self.llm_int8_threshold = llm_int8_threshold
386
+ self.llm_int8_skip_modules = llm_int8_skip_modules
387
+ self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
388
+ self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight
389
+ self.bnb_4bit_quant_type = bnb_4bit_quant_type
390
+ self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant
391
+
392
+ if bnb_4bit_compute_dtype is None:
393
+ self.bnb_4bit_compute_dtype = torch.float32
394
+ elif isinstance(bnb_4bit_compute_dtype, str):
395
+ self.bnb_4bit_compute_dtype = getattr(torch, bnb_4bit_compute_dtype)
396
+ elif isinstance(bnb_4bit_compute_dtype, torch.dtype):
397
+ self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype
398
+ else:
399
+ raise ValueError("bnb_4bit_compute_dtype must be a string or a torch.dtype")
400
+
401
+ if bnb_4bit_quant_storage is None:
402
+ self.bnb_4bit_quant_storage = torch.uint8
403
+ elif isinstance(bnb_4bit_quant_storage, str):
404
+ if bnb_4bit_quant_storage not in ["float16", "float32", "int8", "uint8", "float64", "bfloat16"]:
405
+ raise ValueError(
406
+ "`bnb_4bit_quant_storage` must be a valid string (one of 'float16', 'float32', 'int8', 'uint8', 'float64', 'bfloat16') "
407
+ )
408
+ self.bnb_4bit_quant_storage = getattr(torch, bnb_4bit_quant_storage)
409
+ elif isinstance(bnb_4bit_quant_storage, torch.dtype):
410
+ self.bnb_4bit_quant_storage = bnb_4bit_quant_storage
411
+ else:
412
+ raise ValueError("bnb_4bit_quant_storage must be a string or a torch.dtype")
413
+
414
+ if kwargs:
415
+ logger.warning(f"Unused kwargs: {list(kwargs.keys())}. These kwargs are not used in {self.__class__}.")
416
+
417
+ self.post_init()
418
+
419
+ @property
420
+ def load_in_4bit(self):
421
+ return self._load_in_4bit
422
+
423
+ @load_in_4bit.setter
424
+ def load_in_4bit(self, value: bool):
425
+ if not isinstance(value, bool):
426
+ raise TypeError("load_in_4bit must be a boolean")
427
+
428
+ if self.load_in_8bit and value:
429
+ raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time")
430
+ self._load_in_4bit = value
431
+
432
+ @property
433
+ def load_in_8bit(self):
434
+ return self._load_in_8bit
435
+
436
+ @load_in_8bit.setter
437
+ def load_in_8bit(self, value: bool):
438
+ if not isinstance(value, bool):
439
+ raise TypeError("load_in_8bit must be a boolean")
440
+
441
+ if self.load_in_4bit and value:
442
+ raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time")
443
+ self._load_in_8bit = value
444
+
445
+ def post_init(self):
446
+ r"""
447
+ Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
448
+ """
449
+ if not isinstance(self.load_in_4bit, bool):
450
+ raise TypeError("load_in_4bit must be a boolean")
451
+
452
+ if not isinstance(self.load_in_8bit, bool):
453
+ raise TypeError("load_in_8bit must be a boolean")
454
+
455
+ if not isinstance(self.llm_int8_threshold, float):
456
+ raise TypeError("llm_int8_threshold must be a float")
457
+
458
+ if self.llm_int8_skip_modules is not None and not isinstance(self.llm_int8_skip_modules, list):
459
+ raise TypeError("llm_int8_skip_modules must be a list of strings")
460
+ if not isinstance(self.llm_int8_enable_fp32_cpu_offload, bool):
461
+ raise TypeError("llm_int8_enable_fp32_cpu_offload must be a boolean")
462
+
463
+ if not isinstance(self.llm_int8_has_fp16_weight, bool):
464
+ raise TypeError("llm_int8_has_fp16_weight must be a boolean")
465
+
466
+ if self.bnb_4bit_compute_dtype is not None and not isinstance(self.bnb_4bit_compute_dtype, torch.dtype):
467
+ raise TypeError("bnb_4bit_compute_dtype must be torch.dtype")
468
+
469
+ if not isinstance(self.bnb_4bit_quant_type, str):
470
+ raise TypeError("bnb_4bit_quant_type must be a string")
471
+
472
+ if not isinstance(self.bnb_4bit_use_double_quant, bool):
473
+ raise TypeError("bnb_4bit_use_double_quant must be a boolean")
474
+
475
+ if self.load_in_4bit and not version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse(
476
+ "0.39.0"
477
+ ):
478
+ raise ValueError(
479
+ "4 bit quantization requires bitsandbytes>=0.39.0 - please upgrade your bitsandbytes version"
480
+ )
481
+
482
+ def is_quantizable(self):
483
+ r"""
484
+ Returns `True` if the model is quantizable, `False` otherwise.
485
+ """
486
+ return self.load_in_8bit or self.load_in_4bit
487
+
488
+ def quantization_method(self):
489
+ r"""
490
+ This method returns the quantization method used for the model. If the model is not quantizable, it returns
491
+ `None`.
492
+ """
493
+ if self.load_in_8bit:
494
+ return "llm_int8"
495
+ elif self.load_in_4bit and self.bnb_4bit_quant_type == "fp4":
496
+ return "fp4"
497
+ elif self.load_in_4bit and self.bnb_4bit_quant_type == "nf4":
498
+ return "nf4"
499
+ else:
500
+ return None
501
+
502
+ def to_dict(self) -> Dict[str, Any]:
503
+ """
504
+ Serializes this instance to a Python dictionary. Returns:
505
+ `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
506
+ """
507
+ output = copy.deepcopy(self.__dict__)
508
+ output["bnb_4bit_compute_dtype"] = str(output["bnb_4bit_compute_dtype"]).split(".")[1]
509
+ output["bnb_4bit_quant_storage"] = str(output["bnb_4bit_quant_storage"]).split(".")[1]
510
+ output["load_in_4bit"] = self.load_in_4bit
511
+ output["load_in_8bit"] = self.load_in_8bit
512
+
513
+ return output
514
+
515
+ def __repr__(self):
516
+ config_dict = self.to_dict()
517
+ return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
518
+
519
+ def to_diff_dict(self) -> Dict[str, Any]:
520
+ """
521
+ Removes all attributes from config which correspond to the default config attributes for better readability and
522
+ serializes to a Python dictionary.
523
+
524
+ Returns:
525
+ `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance,
526
+ """
527
+ config_dict = self.to_dict()
528
+
529
+ # get the default config dict
530
+ default_config_dict = BitsAndBytesConfig().to_dict()
531
+
532
+ serializable_config_dict = {}
533
+
534
+ # only serialize values that differ from the default config
535
+ for key, value in config_dict.items():
536
+ if value != default_config_dict[key]:
537
+ serializable_config_dict[key] = value
538
+
539
+ return serializable_config_dict
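A typical 4-bit NF4 configuration assembled from the arguments documented above; the checkpoint id is an illustrative assumption and `bitsandbytes` must be installed.

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",      # hypothetical checkpoint
    quantization_config=bnb_config,
    device_map="auto",
)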
540
+
541
+
542
+ class ExllamaVersion(int, Enum):
543
+ ONE = 1
544
+ TWO = 2
545
+
546
+
547
+ @dataclass
548
+ class GPTQConfig(QuantizationConfigMixin):
549
+ """
550
+ This is a wrapper class about all possible attributes and features that you can play with a model that has been
551
+ loaded using the `optimum` API for GPTQ quantization, relying on the auto_gptq backend.
552
+
553
+ Args:
554
+ bits (`int`):
555
+ The number of bits to quantize to, supported numbers are (2, 3, 4, 8).
556
+ tokenizer (`str` or `PreTrainedTokenizerBase`, *optional*):
557
+ The tokenizer used to process the dataset. You can pass either:
558
+ - A custom tokenizer object.
559
+ - A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co.
560
+ - A path to a *directory* containing vocabulary files required by the tokenizer, for instance saved
561
+ using the [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
562
+ dataset (`Union[List[str]]`, *optional*):
563
+ The dataset used for quantization. You can provide your own dataset in a list of string or just use the
564
+ original datasets used in GPTQ paper ['wikitext2','c4','c4-new']
565
+ group_size (`int`, *optional*, defaults to 128):
566
+ The group size to use for quantization. Recommended value is 128 and -1 uses per-column quantization.
567
+ damp_percent (`float`, *optional*, defaults to 0.1):
568
+ The percent of the average Hessian diagonal to use for dampening. Recommended value is 0.1.
569
+ desc_act (`bool`, *optional*, defaults to `False`):
570
+ Whether to quantize columns in order of decreasing activation size. Setting it to False can significantly
571
+ speed up inference but the perplexity may become slightly worse. Also known as act-order.
572
+ sym (`bool`, *optional*, defaults to `True`):
573
+ Whether to use symmetric quantization.
574
+ true_sequential (`bool`, *optional*, defaults to `True`):
575
+ Whether to perform sequential quantization even within a single Transformer block. Instead of quantizing
576
+ the entire block at once, we perform layer-wise quantization. As a result, each layer undergoes
577
+ quantization using inputs that have passed through the previously quantized layers.
578
+ use_cuda_fp16 (`bool`, *optional*, defaults to `False`):
579
+ Whether or not to use optimized cuda kernel for fp16 model. Need to have model in fp16.
580
+ model_seqlen (`int`, *optional*):
581
+ The maximum sequence length that the model can take.
582
+ block_name_to_quantize (`str`, *optional*):
583
+ The transformers block name to quantize. If None, we will infer the block name using common patterns (e.g. model.layers)
584
+ module_name_preceding_first_block (`List[str]`, *optional*):
585
+ The layers that are preceding the first Transformer block.
586
+ batch_size (`int`, *optional*, defaults to 1):
587
+ The batch size used when processing the dataset
588
+ pad_token_id (`int`, *optional*):
589
+ The pad token id. Needed to prepare the dataset when `batch_size` > 1.
590
+ use_exllama (`bool`, *optional*):
591
+ Whether to use exllama backend. Defaults to `True` if unset. Only works with `bits` = 4.
592
+ max_input_length (`int`, *optional*):
593
+ The maximum input length. This is needed to initialize a buffer that depends on the maximum expected input
594
+ length. It is specific to the exllama backend with act-order.
595
+ exllama_config (`Dict[str, Any]`, *optional*):
596
+ The exllama config. You can specify the version of the exllama kernel through the `version` key. Defaults
597
+ to `{"version": 1}` if unset.
598
+ cache_block_outputs (`bool`, *optional*, defaults to `True`):
599
+ Whether to cache block outputs to reuse as inputs for the succeeding block.
600
+ modules_in_block_to_quantize (`List[List[str]]`, *optional*):
601
+ List of list of module names to quantize in the specified block. This argument is useful to exclude certain linear modules from being quantized.
602
+ The block to quantize can be specified by setting `block_name_to_quantize`. We will quantize each list sequentially. If not set, we will quantize all linear layers.
603
+ Example: `modules_in_block_to_quantize =[["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"], ["self_attn.o_proj"]]`.
604
+ In this example, we will first quantize the q,k,v layers simultaneously since they are independent.
605
+ Then, we will quantize `self_attn.o_proj` layer with the q,k,v layers quantized. This way, we will get
606
+ better results since it reflects the real input `self_attn.o_proj` will get when the model is quantized.
607
+ """
608
+
609
+ def __init__(
610
+ self,
611
+ bits: int,
612
+ tokenizer: Any = None,
613
+ dataset: Optional[Union[List[str], str]] = None,
614
+ group_size: int = 128,
615
+ damp_percent: float = 0.1,
616
+ desc_act: bool = False,
617
+ sym: bool = True,
618
+ true_sequential: bool = True,
619
+ use_cuda_fp16: bool = False,
620
+ model_seqlen: Optional[int] = None,
621
+ block_name_to_quantize: Optional[str] = None,
622
+ module_name_preceding_first_block: Optional[List[str]] = None,
623
+ batch_size: int = 1,
624
+ pad_token_id: Optional[int] = None,
625
+ use_exllama: Optional[bool] = None,
626
+ max_input_length: Optional[int] = None,
627
+ exllama_config: Optional[Dict[str, Any]] = None,
628
+ cache_block_outputs: bool = True,
629
+ modules_in_block_to_quantize: Optional[List[List[str]]] = None,
630
+ **kwargs,
631
+ ):
632
+ self.quant_method = QuantizationMethod.GPTQ
633
+ self.bits = bits
634
+ self.tokenizer = tokenizer
635
+ self.dataset = dataset
636
+ self.group_size = group_size
637
+ self.damp_percent = damp_percent
638
+ self.desc_act = desc_act
639
+ self.sym = sym
640
+ self.true_sequential = true_sequential
641
+ self.use_cuda_fp16 = use_cuda_fp16
642
+ self.model_seqlen = model_seqlen
643
+ self.block_name_to_quantize = block_name_to_quantize
644
+ self.module_name_preceding_first_block = module_name_preceding_first_block
645
+ self.batch_size = batch_size
646
+ self.pad_token_id = pad_token_id
647
+ self.use_exllama = use_exllama
648
+ self.max_input_length = max_input_length
649
+ self.exllama_config = exllama_config
650
+ self.disable_exllama = kwargs.pop("disable_exllama", None)
651
+ self.cache_block_outputs = cache_block_outputs
652
+ self.modules_in_block_to_quantize = modules_in_block_to_quantize
653
+ self.post_init()
654
+
655
+ def get_loading_attributes(self):
656
+ attibutes_dict = copy.deepcopy(self.__dict__)
657
+ loading_attibutes = ["disable_exllama", "use_exllama", "exllama_config", "use_cuda_fp16", "max_input_length"]
658
+ loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes}
659
+ return loading_attibutes_dict
660
+
661
+ def post_init(self):
662
+ r"""
663
+ Safety checker that arguments are correct
664
+ """
665
+ if self.bits not in [2, 3, 4, 8]:
666
+ raise ValueError(f"Only support quantization to [2,3,4,8] bits but found {self.bits}")
667
+ if self.group_size != -1 and self.group_size <= 0:
668
+ raise ValueError("group_size must be greater than 0 or equal to -1")
669
+ if not (0 < self.damp_percent < 1):
670
+ raise ValueError("damp_percent must between 0 and 1.")
671
+ if self.dataset is not None:
672
+ if isinstance(self.dataset, str):
673
+ if self.dataset in ["ptb", "ptb-new"]:
674
+ raise ValueError(
675
+ f"""{self.dataset} dataset was deprecated. You can only choose between
676
+ ['wikitext2','c4','c4-new']"""
677
+ )
678
+ if self.dataset not in ["wikitext2", "c4", "c4-new"]:
679
+ raise ValueError(
680
+ f"""You have entered a string value for dataset. You can only choose between
681
+ ['wikitext2','c4','c4-new'], but we found {self.dataset}"""
682
+ )
683
+ elif not isinstance(self.dataset, list):
684
+ raise ValueError(
685
+ f"""dataset needs to be either a list of string or a value in
686
+ ['wikitext2','c4','c4-new'], but we found {self.dataset}"""
687
+ )
688
+
689
+ if self.disable_exllama is None and self.use_exllama is None:
690
+ # New default behaviour
691
+ self.use_exllama = True
692
+ elif self.disable_exllama is not None and self.use_exllama is None:
693
+ # Follow pattern of old config
694
+ logger.warning(
695
+ "Using `disable_exllama` is deprecated and will be removed in version 4.37. Use `use_exllama` instead and specify the version with `exllama_config`."
696
+ "The value of `use_exllama` will be overwritten by `disable_exllama` passed in `GPTQConfig` or stored in your config file."
697
+ )
698
+ self.use_exllama = not self.disable_exllama
699
+ self.disable_exllama = None
700
+ elif self.disable_exllama is not None and self.use_exllama is not None:
701
+ # Only happens if user explicitly passes in both arguments
702
+ raise ValueError("Cannot specify both `disable_exllama` and `use_exllama`. Please use just `use_exllama`")
703
+
704
+ if self.exllama_config is None:
705
+ self.exllama_config = {"version": ExllamaVersion.ONE}
706
+ else:
707
+ if "version" not in self.exllama_config:
708
+ raise ValueError("`exllama_config` needs to have a `version` key.")
709
+ elif self.exllama_config["version"] not in [ExllamaVersion.ONE, ExllamaVersion.TWO]:
710
+ exllama_version = self.exllama_config["version"]
711
+ raise ValueError(
712
+ f"Only supported versions are in [ExllamaVersion.ONE, ExllamaVersion.TWO] - not recognized version {exllama_version}"
713
+ )
714
+
715
+ if self.bits == 4 and self.use_exllama:
716
+ if self.exllama_config["version"] == ExllamaVersion.ONE:
717
+ logger.info(
718
+ "You have activated exllama backend. Note that you can get better inference "
719
+ "speed using exllamav2 kernel by setting `exllama_config`."
720
+ )
721
+ elif self.exllama_config["version"] == ExllamaVersion.TWO:
722
+ optimum_version = version.parse(importlib.metadata.version("optimum"))
723
+ autogptq_version = version.parse(importlib.metadata.version("auto_gptq"))
724
+ if optimum_version <= version.parse("1.13.2") or autogptq_version <= version.parse("0.4.2"):
725
+ raise ValueError(
726
+ f"You need optimum > 1.13.2 and auto-gptq > 0.4.2 . Make sure to have that version installed - detected version : optimum {optimum_version} and autogptq {autogptq_version}"
727
+ )
728
+ if self.modules_in_block_to_quantize is not None:
729
+ optimum_version = version.parse(importlib.metadata.version("optimum"))
730
+ if optimum_version < version.parse("1.15.0"):
731
+ raise ValueError(
732
+ "You current version of `optimum` does not support `modules_in_block_to_quantize` quantization argument, please upgrade `optimum` package to a version superior than 1.15.0 ."
733
+ )
734
+
735
+ def to_dict(self):
736
+ config_dict = super().to_dict()
737
+ config_dict.pop("disable_exllama", None)
738
+ return config_dict
739
+
740
+ def to_dict_optimum(self):
741
+ """
742
+ Get compatible dict for optimum gptq config
743
+ """
744
+ quant_dict = self.to_dict()
745
+ # make it compatible with optimum config
746
+ quant_dict["disable_exllama"] = not self.use_exllama
747
+ return quant_dict
748
+
749
+ @classmethod
750
+ def from_dict_optimum(cls, config_dict):
751
+ """
752
+ Get compatible class with optimum gptq config dict
753
+ """
754
+
755
+ if "disable_exllama" in config_dict:
756
+ config_dict["use_exllama"] = not config_dict["disable_exllama"]
757
+ # switch to None to not trigger the warning
758
+ config_dict["disable_exllama"] = None
759
+
760
+ config = cls(**config_dict)
761
+ return config
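A sketch of quantizing a model with GPTQ through this config; it assumes the optional `optimum` and `auto-gptq` packages are installed, and the checkpoint id is illustrative.

from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

model_id = "facebook/opt-125m"        # hypothetical small checkpoint for the sketch
tokenizer = AutoTokenizer.from_pretrained(model_id)
gptq_config = GPTQConfig(bits=4, dataset="c4", tokenizer=tokenizer, group_size=128)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=gptq_config,  # quantization happens while the model is loaded
)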
762
+
763
+
764
+ @dataclass
765
+ class AwqConfig(QuantizationConfigMixin):
766
+ """
767
+ This is a wrapper class about all possible attributes and features that you can play with a model that has been
768
+ loaded using the `auto-awq` library for AWQ quantization, relying on the auto_awq backend.
769
+
770
+ Args:
771
+ bits (`int`, *optional*, defaults to 4):
772
+ The number of bits to quantize to.
773
+ group_size (`int`, *optional*, defaults to 128):
774
+ The group size to use for quantization. Recommended value is 128 and -1 uses per-column quantization.
775
+ zero_point (`bool`, *optional*, defaults to `True`):
776
+ Whether to use zero point quantization.
777
+ version (`AWQLinearVersion`, *optional*, defaults to `AWQLinearVersion.GEMM`):
778
+ The version of the quantization algorithm to use. GEMM is better for big batch_size (e.g. >= 8) otherwise,
779
+ GEMV is better (e.g. < 8 ). GEMM models are compatible with Exllama kernels.
780
+ backend (`AwqBackendPackingMethod`, *optional*, defaults to `AwqBackendPackingMethod.AUTOAWQ`):
781
+ The quantization backend. Some models might be quantized using `llm-awq` backend. This is useful for users
782
+ that quantize their own models using `llm-awq` library.
783
+ do_fuse (`bool`, *optional*, defaults to `False`):
784
+ Whether to fuse attention and mlp layers together for faster inference
785
+ fuse_max_seq_len (`int`, *optional*):
786
+ The maximum sequence length to generate when using fusing.
787
+ modules_to_fuse (`dict`, *optional*, defaults to `None`):
788
+ Overwrite the natively supported fusing scheme with the one specified by the users.
789
+ modules_to_not_convert (`list`, *optional*, defaults to `None`):
790
+ The list of modules to not quantize, useful for quantizing models that explicitly require to have
791
+ some modules left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers).
792
+ Note that you cannot quantize directly with transformers; please refer to the `AutoAWQ` documentation for quantizing HF models.
793
+ exllama_config (`Dict[str, Any]`, *optional*):
794
+ You can specify the version of the exllama kernel through the `version` key, the maximum sequence
795
+ length through the `max_input_len` key, and the maximum batch size through the `max_batch_size` key.
796
+ Defaults to `{"version": 2, "max_input_len": 2048, "max_batch_size": 8}` if unset.
797
+ """
798
+
799
+ def __init__(
800
+ self,
801
+ bits: int = 4,
802
+ group_size: int = 128,
803
+ zero_point: bool = True,
804
+ version: AWQLinearVersion = AWQLinearVersion.GEMM,
805
+ backend: AwqBackendPackingMethod = AwqBackendPackingMethod.AUTOAWQ,
806
+ do_fuse: Optional[bool] = None,
807
+ fuse_max_seq_len: Optional[int] = None,
808
+ modules_to_fuse: Optional[dict] = None,
809
+ modules_to_not_convert: Optional[List] = None,
810
+ exllama_config: Optional[Dict[str, int]] = None,
811
+ **kwargs,
812
+ ):
813
+ self.quant_method = QuantizationMethod.AWQ
814
+
815
+ self.bits = bits
816
+ self.group_size = group_size
817
+ self.zero_point = zero_point
818
+ self.version = version
819
+ self.backend = backend
820
+ self.fuse_max_seq_len = fuse_max_seq_len
821
+ self.modules_to_not_convert = modules_to_not_convert
822
+ self.exllama_config = exllama_config
823
+
824
+ self.modules_to_fuse = modules_to_fuse
825
+ if do_fuse is None:
826
+ self.do_fuse = modules_to_fuse is not None and len(modules_to_fuse) > 0
827
+ else:
828
+ self.do_fuse = do_fuse
829
+ self.fuse_max_seq_len = fuse_max_seq_len
830
+
831
+ self.post_init()
832
+
833
+ def post_init(self):
834
+ r"""
835
+ Safety checker that arguments are correct
836
+ """
837
+ if self.backend not in [AwqBackendPackingMethod.AUTOAWQ, AwqBackendPackingMethod.LLMAWQ]:
838
+ raise ValueError(
839
+ f"Only supported quantization backends in {AwqBackendPackingMethod.AUTOAWQ} and {AwqBackendPackingMethod.LLMAWQ} - not recognized backend {self.backend}"
840
+ )
841
+
842
+ self.version = AWQLinearVersion.from_str(self.version)
843
+ if self.version not in [
844
+ AWQLinearVersion.GEMM,
845
+ AWQLinearVersion.GEMV,
846
+ AWQLinearVersion.EXLLAMA,
847
+ AWQLinearVersion.IPEX,
848
+ ]:
849
+ raise ValueError(
850
+ f"Only supported versions are in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV, AWQLinearVersion.EXLLAMA, AWQLinearVersion.IPEX] - not recognized version {self.version}"
851
+ )
852
+
853
+ if self.backend == AwqBackendPackingMethod.LLMAWQ:
854
+ compute_capability = torch.cuda.get_device_capability()
855
+ major, minor = compute_capability
856
+ if major < 8:
857
+ raise ValueError("LLM-AWQ backend is only supported on GPUs with compute capability >= 8.0")
858
+
859
+ if self.do_fuse and self.fuse_max_seq_len is None:
860
+ raise ValueError(
861
+ "You cannot enable fused modules without specifying a `fuse_max_seq_len`, make sure to pass a valid `fuse_max_seq_len` for your usecase"
862
+ )
863
+
864
+ if self.do_fuse:
865
+ awq_version_supports_fusing = False
866
+ MIN_AWQ_VERSION = "0.1.7"
867
+ if is_auto_awq_available():
868
+ awq_version_supports_fusing = version.parse(importlib.metadata.version("autoawq")) >= version.parse(
869
+ MIN_AWQ_VERSION
870
+ )
871
+
872
+ if not awq_version_supports_fusing:
873
+ raise ValueError(
874
+ f"You current version of `autoawq` does not support module fusing, please upgrade `autoawq` package to at least {MIN_AWQ_VERSION}."
875
+ )
876
+
877
+ if self.modules_to_not_convert is not None:
878
+ awq_version_supports_non_conversion = False
879
+ MIN_AWQ_VERSION = "0.1.8"
880
+ if is_auto_awq_available():
881
+ awq_version_supports_non_conversion = version.parse(
882
+ importlib.metadata.version("autoawq")
883
+ ) >= version.parse(MIN_AWQ_VERSION)
884
+
885
+ if not awq_version_supports_non_conversion:
886
+ raise ValueError(
887
+ f"You current version of `autoawq` does not support module quantization skipping, please upgrade `autoawq` package to at least {MIN_AWQ_VERSION}."
888
+ )
889
+
890
+ if self.do_fuse and self.modules_to_fuse is not None:
891
+ required_keys = [
892
+ "hidden_size",
893
+ "num_attention_heads",
894
+ "num_key_value_heads",
895
+ "mlp",
896
+ "attention",
897
+ "layernorm",
898
+ "use_alibi",
899
+ ]
900
+ if not all(key in self.modules_to_fuse for key in required_keys):
901
+ raise ValueError(
902
+ f"Required fields are missing in the fusing mapping, required fields are {required_keys}"
903
+ )
904
+
905
+ if self.version == AWQLinearVersion.EXLLAMA:
906
+ awq_version_supports_exllama = False
907
+ MIN_AWQ_VERSION = "0.2.0"
908
+ if is_auto_awq_available():
909
+ awq_version_supports_exllama = version.parse(importlib.metadata.version("autoawq")) >= version.parse(
910
+ MIN_AWQ_VERSION
911
+ )
912
+
913
+ if not awq_version_supports_exllama:
914
+ raise ValueError(
915
+ f"You current version of `autoawq` does not support exllama backend, "
916
+ f"please upgrade `autoawq` package to at least {MIN_AWQ_VERSION}."
917
+ )
918
+
919
+ if self.exllama_config is None:
920
+ self.exllama_config = {"version": ExllamaVersion.TWO, "max_input_len": 2048, "max_batch_size": 8}
921
+ else:
922
+ if "version" not in self.exllama_config:
923
+ raise ValueError("`exllama_config` needs to have a `version` key.")
924
+ elif self.exllama_config["version"] not in [ExllamaVersion.ONE, ExllamaVersion.TWO]:
925
+ exllama_version = self.exllama_config["version"]
926
+ raise ValueError(
927
+ f"Only supported versions are in [ExllamaVersion.ONE, ExllamaVersion.TWO] - not recognized version {exllama_version}"
928
+ )
929
+
930
+ def get_loading_attributes(self):
931
+ attibutes_dict = copy.deepcopy(self.__dict__)
932
+ loading_attibutes = ["version", "do_fuse", "modules_to_fuse", "fuse_max_seq_len", "exllama_config"]
933
+ loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes}
934
+ return loading_attibutes_dict
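A sketch of loading a pre-quantized AWQ checkpoint with module fusing enabled, per the arguments documented above; it assumes the optional `autoawq` package is installed, and the checkpoint id is illustrative.

from transformers import AutoModelForCausalLM, AwqConfig

awq_config = AwqConfig(bits=4, do_fuse=True, fuse_max_seq_len=512)
model = AutoModelForCausalLM.from_pretrained(
    "TheBloke/Mistral-7B-Instruct-v0.1-AWQ",  # hypothetical pre-quantized checkpoint
    quantization_config=awq_config,
    device_map="cuda:0",
)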
935
+
936
+
937
+ @dataclass
938
+ class AqlmConfig(QuantizationConfigMixin):
939
+ """
940
+ This is a wrapper class about `aqlm` parameters.
941
+
942
+ Args:
943
+ in_group_size (`int`, *optional*, defaults to 8):
944
+ The group size along the input dimension.
945
+ out_group_size (`int`, *optional*, defaults to 1):
946
+ The group size along the output dimension. It's recommended to always use 1.
947
+ num_codebooks (`int`, *optional*, defaults to 1):
948
+ Number of codebooks for the Additive Quantization procedure.
949
+ nbits_per_codebook (`int`, *optional*, defaults to 16):
950
+ Number of bits encoding a single codebook vector. Codebooks size is 2**nbits_per_codebook.
951
+ linear_weights_not_to_quantize (`Optional[List[str]]`, *optional*):
952
+ List of full paths of `nn.Linear` weight parameters that shall not be quantized.
953
+ kwargs (`Dict[str, Any]`, *optional*):
954
+ Additional parameters from which to initialize the configuration object.
955
+ """
956
+
957
+ def __init__(
958
+ self,
959
+ in_group_size: int = 8,
960
+ out_group_size: int = 1,
961
+ num_codebooks: int = 1,
962
+ nbits_per_codebook: int = 16,
963
+ linear_weights_not_to_quantize: Optional[List[str]] = None,
964
+ **kwargs,
965
+ ):
966
+ self.quant_method = QuantizationMethod.AQLM
967
+ self.in_group_size = in_group_size
968
+ self.out_group_size = out_group_size
969
+ self.num_codebooks = num_codebooks
970
+ self.nbits_per_codebook = nbits_per_codebook
971
+ self.linear_weights_not_to_quantize = linear_weights_not_to_quantize
972
+
973
+ self.post_init()
974
+
975
+ def post_init(self):
976
+ r"""
977
+ Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
978
+ """
979
+ if not isinstance(self.in_group_size, int):
980
+ raise TypeError("in_group_size must be a float")
981
+ if not isinstance(self.out_group_size, int):
982
+ raise TypeError("out_group_size must be a float")
983
+ if not isinstance(self.num_codebooks, int):
984
+ raise TypeError("num_codebooks must be a float")
985
+ if not isinstance(self.nbits_per_codebook, int):
986
+ raise TypeError("nbits_per_codebook must be a float")
987
+
988
+ if self.linear_weights_not_to_quantize is not None and not isinstance(
989
+ self.linear_weights_not_to_quantize, list
990
+ ):
991
+ raise ValueError("linear_weights_not_to_quantize must be a list of strings")
992
+
993
+ if self.linear_weights_not_to_quantize is None:
994
+ self.linear_weights_not_to_quantize = []
995
+
996
+
997
+ @dataclass
998
+ class QuantoConfig(QuantizationConfigMixin):
999
+ """
1000
+ This is a wrapper class about all possible attributes and features that you can play with a model that has been
1001
+ loaded using `quanto`.
1002
+
1003
+ Args:
1004
+ weights (`str`, *optional*, defaults to `"int8"`):
1005
+ The target dtype for the weights after quantization. Supported values are ("float8","int8","int4","int2")
1006
+ activations (`str`, *optional*):
1007
+ The target dtype for the activations after quantization. Supported values are (None,"int8","float8")
1008
+ modules_to_not_convert (`list`, *optional*, defaults to `None`):
1009
+ The list of modules to not quantize, useful for quantizing models that explicitly require to have
1010
+ some modules left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers).
1011
+ """
1012
+
1013
+ def __init__(
1014
+ self,
1015
+ weights="int8",
1016
+ activations=None,
1017
+ modules_to_not_convert: Optional[List] = None,
1018
+ **kwargs,
1019
+ ):
1020
+ self.quant_method = QuantizationMethod.QUANTO
1021
+ self.weights = weights
1022
+ self.activations = activations
1023
+ self.modules_to_not_convert = modules_to_not_convert
1024
+ self.post_init()
1025
+
1026
+ def post_init(self):
1027
+ r"""
1028
+ Safety checker that arguments are correct
1029
+ """
1030
+ accepted_weights = ["float8", "int8", "int4", "int2"]
1031
+ accepted_activations = [None, "int8", "float8"]
1032
+ if self.weights not in accepted_weights:
1033
+ raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights}")
1034
+ if self.activations not in accepted_activations:
1035
+ raise ValueError(f"Only support weights in {accepted_activations} but found {self.activations}")
1036
+
1037
+
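[Editor's note] A short sketch of the QuantoConfig above, assuming the quanto / optimum-quanto backend is installed; the model id is a placeholder and the snippet is not part of the diffed file.

from transformers import AutoModelForCausalLM, QuantoConfig

quanto_config = QuantoConfig(weights="int8")  # activations default to None (left unquantized)
model = AutoModelForCausalLM.from_pretrained(
    "org/some-model",  # placeholder model id
    quantization_config=quanto_config,
)
# Unsupported dtypes fail fast in post_init, e.g. QuantoConfig(weights="int3") raises ValueError.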
1038
+ @dataclass
1039
+ class EetqConfig(QuantizationConfigMixin):
1040
+ """
1041
+ This is a wrapper class for all the attributes and features that you can play with once a model has been
1042
+ loaded using `eetq`.
1043
+
1044
+ Args:
1045
+ weights (`str`, *optional*, defaults to `"int8"`):
1046
+ The target dtype for the weights. Supported value is only "int8"
1047
+ modules_to_not_convert (`list`, *optional*, defaults to `None`):
1048
+ The list of modules to not quantize, useful for quantizing models that explicitly require to have
1049
+ some modules left in their original precision.
1050
+ """
1051
+
1052
+ def __init__(
1053
+ self,
1054
+ weights: str = "int8",
1055
+ modules_to_not_convert: Optional[List] = None,
1056
+ **kwargs,
1057
+ ):
1058
+ self.quant_method = QuantizationMethod.EETQ
1059
+ self.weights = weights
1060
+ self.modules_to_not_convert = modules_to_not_convert
1061
+ self.post_init()
1062
+
1063
+ def post_init(self):
1064
+ r"""
1065
+ Safety checker that arguments are correct
1066
+ """
1067
+ accepted_weights = ["int8"]
1068
+ if self.weights not in accepted_weights:
1069
+ raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights}")
1070
+
1071
+
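[Editor's note] An analogous sketch for EetqConfig, assuming the `eetq` kernels are installed; only "int8" weights pass the post_init check. Names are placeholders and the snippet is not part of the diffed file.

from transformers import AutoModelForCausalLM, EetqConfig

eetq_config = EetqConfig(weights="int8", modules_to_not_convert=["lm_head"])  # example exclusion
model = AutoModelForCausalLM.from_pretrained(
    "org/some-model",  # placeholder model id
    quantization_config=eetq_config,
)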
1072
+ class CompressedTensorsConfig(QuantizationConfigMixin):
1073
+ """
1074
+ This is a wrapper class that handles compressed-tensors quantization config options.
1075
+ It is a wrapper around `compressed_tensors.QuantizationConfig`
1076
+ Args:
1077
+ config_groups (`typing.Dict[str, typing.Union[ForwardRef('QuantizationScheme'), typing.List[str]]]`, *optional*):
1078
+ dictionary mapping group name to a quantization scheme definition
1079
+ format (`str`, *optional*, defaults to `"dense"`):
1080
+ format the model is represented as
1081
+ quantization_status (`QuantizationStatus`, *optional*, defaults to `"initialized"`):
1082
+ status of model in the quantization lifecycle, ie 'initialized', 'calibration', 'frozen'
1083
+ kv_cache_scheme (`typing.Union[QuantizationArgs, NoneType]`, *optional*):
1084
+ specifies quantization of the kv cache. If None, kv cache is not quantized.
1085
+ global_compression_ratio (`typing.Union[float, NoneType]`, *optional*):
1086
+ 0-1 float percentage of model compression
1087
+ ignore (`typing.Union[typing.List[str], NoneType]`, *optional*):
1088
+ layer names or types to not quantize, supports regex prefixed by 're:'
1089
+ sparsity_config (`typing.Dict[str, typing.Any]`, *optional*):
1090
+ configuration for sparsity compression
1091
+ quant_method (`str`, *optional*, defaults to `"compressed-tensors"`):
1092
+ do not override, should be compressed-tensors
1093
+ """
1094
+
1095
+ def __init__(
1096
+ self,
1097
+ config_groups: Dict[str, Union["QuantizationScheme", List[str]]] = None, # noqa: F821
1098
+ format: str = "dense",
1099
+ quantization_status: "QuantizationStatus" = "initialized", # noqa: F821
1100
+ kv_cache_scheme: Optional["QuantizationArgs"] = None, # noqa: F821
1101
+ global_compression_ratio: Optional[float] = None,
1102
+ ignore: Optional[List[str]] = None,
1103
+ sparsity_config: Dict[str, Any] = None,
1104
+ quant_method: str = "compressed-tensors",
1105
+ **kwargs,
1106
+ ):
1107
+ from compressed_tensors import QuantizationConfig
1108
+ from compressed_tensors.config import SparsityCompressionConfig
1109
+
1110
+ self.quantization_config = None
1111
+ self.sparsity_config = None
1112
+
1113
+ # parse from dict to load nested QuantizationScheme objects
1114
+ if config_groups or kv_cache_scheme:
1115
+ self.quantization_config = QuantizationConfig.parse_obj(
1116
+ {
1117
+ "config_groups": config_groups,
1118
+ "quant_method": quant_method,
1119
+ "format": format,
1120
+ "quantization_status": quantization_status,
1121
+ "kv_cache_scheme": kv_cache_scheme,
1122
+ "global_compression_ratio": global_compression_ratio,
1123
+ "ignore": ignore,
1124
+ **kwargs,
1125
+ }
1126
+ )
1127
+
1128
+ if sparsity_config:
1129
+ self.sparsity_config = SparsityCompressionConfig.load_from_registry(
1130
+ sparsity_config.get("format"), **sparsity_config
1131
+ )
1132
+
1133
+ super().__init__(quant_method=QuantizationMethod.COMPRESSED_TENSORS)
1134
+
1135
+ @classmethod
1136
+ def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
1137
+ """
1138
+ Instantiates a [`CompressedTensorsConfig`] from a Python dictionary of parameters.
1139
+ Optionally unwraps any args from the nested quantization_config
1140
+
1141
+ Args:
1142
+ config_dict (`Dict[str, Any]`):
1143
+ Dictionary that will be used to instantiate the configuration object.
1144
+ return_unused_kwargs (`bool`,*optional*, defaults to `False`):
1145
+ Whether or not to return a list of unused keyword arguments. Used for `from_pretrained` method in
1146
+ `PreTrainedModel`.
1147
+ kwargs (`Dict[str, Any]`):
1148
+ Additional parameters from which to initialize the configuration object.
1149
+
1150
+ Returns:
1151
+ [`QuantizationConfigMixin`]: The configuration object instantiated from those parameters.
1152
+ """
1153
+
1154
+ if "quantization_config" in config_dict:
1155
+ config_dict = dict(
1156
+ sparsity_config=config_dict.get("sparsity_config"),
1157
+ **config_dict["quantization_config"],
1158
+ )
1159
+
1160
+ return super().from_dict(config_dict, return_unused_kwargs=return_unused_kwargs, **kwargs)
1161
+
1162
+ def to_dict(self) -> Dict[str, Any]:
1163
+ """
1164
+ Quantization config to be added to config.json
1165
+
1166
+ Serializes this instance to a Python dictionary. Returns:
1167
+ `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
1168
+ """
1169
+ quantization_config = {}
1170
+ if self.quantization_config is not None:
1171
+ quantization_config = self.quantization_config.dict()
1172
+ else:
1173
+ quantization_config["quant_method"] = QuantizationMethod.COMPRESSED_TENSORS
1174
+
1175
+ if self.sparsity_config is not None:
1176
+ quantization_config["sparsity_config"] = self.sparsity_config.dict()
1177
+ else:
1178
+ quantization_config["sparsity_config"] = {}
1179
+
1180
+ return quantization_config
1181
+
1182
+ def to_diff_dict(self) -> Dict[str, Any]:
1183
+ """
1184
+ Removes all attributes from config which correspond to the default config attributes for better readability and
1185
+ serializes to a Python dictionary.
1186
+ Returns:
1187
+ `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance,
1188
+ """
1189
+ config_dict = self.to_dict()
1190
+
1191
+ # get the default config dict
1192
+ default_config_dict = CompressedTensorsConfig().to_dict()
1193
+
1194
+ serializable_config_dict = {}
1195
+
1196
+ # only serialize values that differ from the default config
1197
+ for key, value in config_dict.items():
1198
+ if value != default_config_dict[key]:
1199
+ serializable_config_dict[key] = value
1200
+
1201
+ return serializable_config_dict
1202
+
1203
+
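[Editor's note] A sketch of how the from_dict/to_dict pair above behaves for a checkpoint that nests its settings under "quantization_config"; the values are illustrative, the `compressed-tensors` package must be installed, and the snippet is not part of the diffed file.

from transformers.utils.quantization_config import CompressedTensorsConfig

serialized = {
    "quantization_config": {
        "quant_method": "compressed-tensors",
        "format": "dense",
        "config_groups": None,
    },
    "sparsity_config": None,
}
# from_dict unwraps the nested "quantization_config" entry before calling __init__.
cfg = CompressedTensorsConfig.from_dict(serialized)
print(cfg.to_dict())  # roughly {"quant_method": ..., "sparsity_config": {}}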
1204
+ @dataclass
1205
+ class FbgemmFp8Config(QuantizationConfigMixin):
1206
+ """
1207
+ This is a wrapper class for all the attributes and features that you can play with once a model has been
1208
+ loaded using fbgemm fp8 quantization.
1209
+
1210
+ Args:
1211
+ activation_scale_ub (`float`, *optional*, defaults to 1200.0):
1212
+ The activation scale upper bound. This is used when quantizing the input activation.
1213
+ modules_to_not_convert (`list`, *optional*, defaults to `None`):
1214
+ The list of modules to not quantize, useful for quantizing models that explicitly require to have
1215
+ some modules left in their original precision.
1216
+ """
1217
+
1218
+ def __init__(
1219
+ self,
1220
+ activation_scale_ub: float = 1200.0,
1221
+ modules_to_not_convert: Optional[List] = None,
1222
+ **kwargs,
1223
+ ):
1224
+ self.quant_method = QuantizationMethod.FBGEMM_FP8
1225
+ self.activation_scale_ub = activation_scale_ub
1226
+ self.modules_to_not_convert = modules_to_not_convert
1227
+
1228
+ def get_loading_attributes(self):
1229
+ attibutes_dict = copy.deepcopy(self.__dict__)
1230
+ loading_attibutes = ["activation_scale_ub"]
1231
+ loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes}
1232
+ return loading_attibutes_dict
1233
+
1234
+
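[Editor's note] Illustrative use of FbgemmFp8Config, assuming `fbgemm-gpu` is installed and a CUDA device with fp8 support is available; the model id is a placeholder and the snippet is not part of the diffed file.

from transformers import AutoModelForCausalLM, FbgemmFp8Config

fp8_config = FbgemmFp8Config(activation_scale_ub=1200.0)  # upper bound applied when quantizing activations
model = AutoModelForCausalLM.from_pretrained(
    "org/some-model",  # placeholder model id
    device_map="cuda",
    quantization_config=fp8_config,
)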
1235
+ @dataclass
1236
+ class TorchAoConfig(QuantizationConfigMixin):
1237
+ """This is a config class for torchao quantization/sparsity techniques.
1238
+
1239
+ Args:
1240
+ quant_type (`str`):
1241
+ The type of quantization we want to use, currently supporting: `int4_weight_only`, `int8_weight_only` and `int8_dynamic_activation_int8_weight`.
1242
+ modules_to_not_convert (`list`, *optional*, defaults to `None`):
1243
+ The list of modules to not quantize, useful for quantizing models that explicitly require to have
1244
+ some modules left in their original precision.
1245
+ kwargs (`Dict[str, Any]`, *optional*):
1246
+ The keyword arguments for the chosen type of quantization, for example, int4_weight_only quantization supports two keyword arguments
1247
+ `group_size` and `inner_k_tiles` currently. More API examples and documentation of arguments can be found in
1248
+ https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques
1249
+
1250
+ Example:
1251
+
1252
+ ```python
1253
+ quantization_config = TorchAoConfig("int4_weight_only", group_size=32)
1254
+ # int4_weight_only quant is only working with *torch.bfloat16* dtype right now
1255
+ model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
1256
+ ```
1257
+ """
1258
+
1259
+ def __init__(self, quant_type: str, modules_to_not_convert: Optional[List] = None, **kwargs):
1260
+ self.quant_method = QuantizationMethod.TORCHAO
1261
+ self.quant_type = quant_type
1262
+ self.modules_to_not_convert = modules_to_not_convert
1263
+ # when we load from serialized config, "quant_type_kwargs" will be the key
1264
+ if "quant_type_kwargs" in kwargs:
1265
+ self.quant_type_kwargs = kwargs["quant_type_kwargs"]
1266
+ else:
1267
+ self.quant_type_kwargs = kwargs
1268
+
1269
+ self.post_init()
1270
+
1271
+ def post_init(self):
1272
+ r"""
1273
+ Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
1274
+ """
1275
+ if is_torchao_available():
1276
+ if not version.parse(importlib.metadata.version("torchao")) >= version.parse("0.4.0"):
1277
+ raise ValueError("Requires torchao 0.4.0 version and above")
1278
+ else:
1279
+ raise ValueError(
1280
+ "TorchAoConfig requires torchao to be installed, please install with `pip install torchao`"
1281
+ )
1282
+
1283
+ _STR_TO_METHOD = self._get_torchao_quant_type_to_method()
1284
+ if self.quant_type not in _STR_TO_METHOD.keys():
1285
+ raise ValueError(
1286
+ f"Requested quantization type: {self.quant_type} is not supported yet, please add support in TorchAoConfig and TorchAoHfQuantizer."
1287
+ )
1288
+
1289
+ method = _STR_TO_METHOD[self.quant_type]
1290
+ sig = signature(method)
1291
+ all_kwargs = [
1292
+ param.name
1293
+ for param in sig.parameters.values()
1294
+ if param.kind in [Parameter.KEYWORD_ONLY, Parameter.POSITIONAL_OR_KEYWORD]
1295
+ ]
1296
+ for k in self.quant_type_kwargs:
1297
+ if k not in all_kwargs:
1298
+ raise ValueError(
1299
+ f"Unexpected keyword arg: {k} for API: {method}, accepted keyword args are: {all_kwargs}"
1300
+ )
1301
+
1302
+ def _get_torchao_quant_type_to_method(self):
1303
+ if is_torchao_available():
1304
+ from torchao.quantization import (
1305
+ int4_weight_only,
1306
+ int8_dynamic_activation_int8_weight,
1307
+ int8_weight_only,
1308
+ )
1309
+
1310
+ return {
1311
+ "int4_weight_only": int4_weight_only,
1312
+ "int8_weight_only": int8_weight_only,
1313
+ "int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight,
1314
+ }
1315
+ else:
1316
+ raise ValueError(
1317
+ "TorchAoConfig requires torchao to be installed, please install with `pip install torchao`"
1318
+ )
1319
+
1320
+ def get_apply_tensor_subclass(self):
1321
+ _STR_TO_METHOD = self._get_torchao_quant_type_to_method()
1322
+ return _STR_TO_METHOD[self.quant_type](**self.quant_type_kwargs)
1323
+
1324
+ def __repr__(self):
1325
+ config_dict = self.to_dict()
1326
+ return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
1327
+
1328
+
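[Editor's note] A sketch tying the TorchAoConfig above to torchao's apply functions; the keyword arguments are validated against the chosen function's signature in post_init. The model id is a placeholder and the snippet is not part of the diffed file.

import torch
from transformers import AutoModelForCausalLM, TorchAoConfig

config = TorchAoConfig("int4_weight_only", group_size=128)  # group_size checked against int4_weight_only's signature
model = AutoModelForCausalLM.from_pretrained(
    "org/some-model",            # placeholder model id
    torch_dtype=torch.bfloat16,  # int4_weight_only currently expects bfloat16 weights
    device_map="auto",
    quantization_config=config,
)
# Internally the quantizer applies config.get_apply_tensor_subclass(), i.e. roughly
# torchao.quantization.int4_weight_only(group_size=128), to each eligible nn.Linear.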
1329
+ @dataclass
1330
+ class BitNetConfig(QuantizationConfigMixin):
1331
+ def __init__(
1332
+ self,
1333
+ modules_to_not_convert: Optional[List] = None,
1334
+ **kwargs,
1335
+ ):
1336
+ self.quant_method = QuantizationMethod.BITNET
1337
+ self.modules_to_not_convert = modules_to_not_convert
1338
+ self.post_init()
1339
+
1340
+ def post_init(self):
1341
+ r"""
1342
+ Safety checker that arguments are correct
1343
+ """
1344
+ pass
.venv/Lib/site-packages/transformers/utils/sentencepiece_model_pb2.py ADDED
@@ -0,0 +1,1511 @@
1
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
2
+ # source: sentencepiece_model.proto
3
+
4
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ from google.protobuf import descriptor as _descriptor
18
+ from google.protobuf import message as _message
19
+ from google.protobuf import reflection as _reflection
20
+ from google.protobuf import symbol_database as _symbol_database
21
+
22
+
23
+ # @@protoc_insertion_point(imports)
24
+
25
+ _sym_db = _symbol_database.Default()
26
+
27
+
28
+ DESCRIPTOR = _descriptor.FileDescriptor(
29
+ name="sentencepiece_model.proto",
30
+ package="sentencepiece",
31
+ syntax="proto2",
32
+ serialized_options=b"H\003",
33
+ create_key=_descriptor._internal_create_key,
34
+ serialized_pb=(
35
+ b'\n\x19sentencepiece_model.proto\x12\rsentencepiece"\xa1\n\n\x0bTrainerSpec\x12\r\n\x05input\x18\x01'
36
+ b" \x03(\t\x12\x14\n\x0cinput_format\x18\x07 \x01(\t\x12\x14\n\x0cmodel_prefix\x18\x02"
37
+ b" \x01(\t\x12\x41\n\nmodel_type\x18\x03"
38
+ b" \x01(\x0e\x32$.sentencepiece.TrainerSpec.ModelType:\x07UNIGRAM\x12\x18\n\nvocab_size\x18\x04"
39
+ b" \x01(\x05:\x04\x38\x30\x30\x30\x12\x17\n\x0f\x61\x63\x63\x65pt_language\x18\x05 \x03(\t\x12"
40
+ b' \n\x15self_test_sample_size\x18\x06 \x01(\x05:\x01\x30\x12"\n\x12\x63haracter_coverage\x18\n'
41
+ b" \x01(\x02:\x06\x30.9995\x12\x1e\n\x13input_sentence_size\x18\x0b"
42
+ b" \x01(\x04:\x01\x30\x12$\n\x16shuffle_input_sentence\x18\x13 \x01(\x08:\x04true\x12"
43
+ b' \n\x14mining_sentence_size\x18\x0c \x01(\x05\x42\x02\x18\x01\x12"\n\x16training_sentence_size\x18\r'
44
+ b" \x01(\x05\x42\x02\x18\x01\x12(\n\x17seed_sentencepiece_size\x18\x0e"
45
+ b" \x01(\x05:\x07\x31\x30\x30\x30\x30\x30\x30\x12\x1e\n\x10shrinking_factor\x18\x0f"
46
+ b" \x01(\x02:\x04\x30.75\x12!\n\x13max_sentence_length\x18\x12"
47
+ b" \x01(\x05:\x04\x34\x31\x39\x32\x12\x17\n\x0bnum_threads\x18\x10"
48
+ b" \x01(\x05:\x02\x31\x36\x12\x1d\n\x12num_sub_iterations\x18\x11"
49
+ b" \x01(\x05:\x01\x32\x12$\n\x18max_sentencepiece_length\x18\x14"
50
+ b" \x01(\x05:\x02\x31\x36\x12%\n\x17split_by_unicode_script\x18\x15"
51
+ b" \x01(\x08:\x04true\x12\x1d\n\x0fsplit_by_number\x18\x17"
52
+ b" \x01(\x08:\x04true\x12!\n\x13split_by_whitespace\x18\x16"
53
+ b" \x01(\x08:\x04true\x12)\n\x1atreat_whitespace_as_suffix\x18\x18"
54
+ b" \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x0csplit_digits\x18\x19"
55
+ b" \x01(\x08:\x05\x66\x61lse\x12\x17\n\x0f\x63ontrol_symbols\x18\x1e"
56
+ b" \x03(\t\x12\x1c\n\x14user_defined_symbols\x18\x1f \x03(\t\x12\x16\n\x0erequired_chars\x18$"
57
+ b" \x01(\t\x12\x1c\n\rbyte_fallback\x18# \x01(\x08:\x05\x66\x61lse\x12+\n\x1dvocabulary_output_piece_score\x18"
58
+ b' \x01(\x08:\x04true\x12\x1e\n\x10hard_vocab_limit\x18! \x01(\x08:\x04true\x12\x1c\n\ruse_all_vocab\x18"'
59
+ b" \x01(\x08:\x05\x66\x61lse\x12\x11\n\x06unk_id\x18( \x01(\x05:\x01\x30\x12\x11\n\x06\x62os_id\x18)"
60
+ b" \x01(\x05:\x01\x31\x12\x11\n\x06\x65os_id\x18* \x01(\x05:\x01\x32\x12\x12\n\x06pad_id\x18+"
61
+ b" \x01(\x05:\x02-1\x12\x18\n\tunk_piece\x18- \x01(\t:\x05<unk>\x12\x16\n\tbos_piece\x18."
62
+ b" \x01(\t:\x03<s>\x12\x17\n\teos_piece\x18/ \x01(\t:\x04</s>\x12\x18\n\tpad_piece\x18\x30"
63
+ b" \x01(\t:\x05<pad>\x12\x1a\n\x0bunk_surface\x18, \x01(\t:\x05 \xe2\x81\x87"
64
+ b" \x12+\n\x1ctrain_extremely_large_corpus\x18\x31"
65
+ b' \x01(\x08:\x05\x66\x61lse"5\n\tModelType\x12\x0b\n\x07UNIGRAM\x10\x01\x12\x07\n\x03\x42PE\x10\x02\x12\x08\n\x04WORD\x10\x03\x12\x08\n\x04\x43HAR\x10\x04*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"\xd1\x01\n\x0eNormalizerSpec\x12\x0c\n\x04name\x18\x01'
66
+ b" \x01(\t\x12\x1c\n\x14precompiled_charsmap\x18\x02 \x01(\x0c\x12\x1e\n\x10\x61\x64\x64_dummy_prefix\x18\x03"
67
+ b" \x01(\x08:\x04true\x12&\n\x18remove_extra_whitespaces\x18\x04 \x01(\x08:\x04true\x12"
68
+ b" \n\x12\x65scape_whitespaces\x18\x05 \x01(\x08:\x04true\x12\x1e\n\x16normalization_rule_tsv\x18\x06"
69
+ b' \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"y\n\x0cSelfTestData\x12\x33\n\x07samples\x18\x01'
70
+ b' \x03(\x0b\x32".sentencepiece.SelfTestData.Sample\x1a)\n\x06Sample\x12\r\n\x05input\x18\x01'
71
+ b" \x01(\t\x12\x10\n\x08\x65xpected\x18\x02"
72
+ b' \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"\xfe\x03\n\nModelProto\x12\x37\n\x06pieces\x18\x01'
73
+ b" \x03(\x0b\x32'.sentencepiece.ModelProto.SentencePiece\x12\x30\n\x0ctrainer_spec\x18\x02"
74
+ b" \x01(\x0b\x32\x1a.sentencepiece.TrainerSpec\x12\x36\n\x0fnormalizer_spec\x18\x03"
75
+ b" \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x12\x33\n\x0eself_test_data\x18\x04"
76
+ b" \x01(\x0b\x32\x1b.sentencepiece.SelfTestData\x12\x38\n\x11\x64\x65normalizer_spec\x18\x05"
77
+ b" \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x1a\xd2\x01\n\rSentencePiece\x12\r\n\x05piece\x18\x01"
78
+ b" \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x02\x12\x42\n\x04type\x18\x03"
79
+ b' \x01(\x0e\x32,.sentencepiece.ModelProto.SentencePiece.Type:\x06NORMAL"T\n\x04Type\x12\n\n\x06NORMAL\x10\x01\x12\x0b\n\x07UNKNOWN\x10\x02\x12\x0b\n\x07\x43ONTROL\x10\x03\x12\x10\n\x0cUSER_DEFINED\x10\x04\x12\x08\n\x04\x42YTE\x10\x06\x12\n\n\x06UNUSED\x10\x05*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\x42\x02H\x03'
80
+ ),
81
+ )
82
+
83
+
84
+ _TRAINERSPEC_MODELTYPE = _descriptor.EnumDescriptor(
85
+ name="ModelType",
86
+ full_name="sentencepiece.TrainerSpec.ModelType",
87
+ filename=None,
88
+ file=DESCRIPTOR,
89
+ create_key=_descriptor._internal_create_key,
90
+ values=[
91
+ _descriptor.EnumValueDescriptor(
92
+ name="UNIGRAM",
93
+ index=0,
94
+ number=1,
95
+ serialized_options=None,
96
+ type=None,
97
+ create_key=_descriptor._internal_create_key,
98
+ ),
99
+ _descriptor.EnumValueDescriptor(
100
+ name="BPE",
101
+ index=1,
102
+ number=2,
103
+ serialized_options=None,
104
+ type=None,
105
+ create_key=_descriptor._internal_create_key,
106
+ ),
107
+ _descriptor.EnumValueDescriptor(
108
+ name="WORD",
109
+ index=2,
110
+ number=3,
111
+ serialized_options=None,
112
+ type=None,
113
+ create_key=_descriptor._internal_create_key,
114
+ ),
115
+ _descriptor.EnumValueDescriptor(
116
+ name="CHAR",
117
+ index=3,
118
+ number=4,
119
+ serialized_options=None,
120
+ type=None,
121
+ create_key=_descriptor._internal_create_key,
122
+ ),
123
+ ],
124
+ containing_type=None,
125
+ serialized_options=None,
126
+ serialized_start=1294,
127
+ serialized_end=1347,
128
+ )
129
+ _sym_db.RegisterEnumDescriptor(_TRAINERSPEC_MODELTYPE)
130
+
131
+ _MODELPROTO_SENTENCEPIECE_TYPE = _descriptor.EnumDescriptor(
132
+ name="Type",
133
+ full_name="sentencepiece.ModelProto.SentencePiece.Type",
134
+ filename=None,
135
+ file=DESCRIPTOR,
136
+ create_key=_descriptor._internal_create_key,
137
+ values=[
138
+ _descriptor.EnumValueDescriptor(
139
+ name="NORMAL",
140
+ index=0,
141
+ number=1,
142
+ serialized_options=None,
143
+ type=None,
144
+ create_key=_descriptor._internal_create_key,
145
+ ),
146
+ _descriptor.EnumValueDescriptor(
147
+ name="UNKNOWN",
148
+ index=1,
149
+ number=2,
150
+ serialized_options=None,
151
+ type=None,
152
+ create_key=_descriptor._internal_create_key,
153
+ ),
154
+ _descriptor.EnumValueDescriptor(
155
+ name="CONTROL",
156
+ index=2,
157
+ number=3,
158
+ serialized_options=None,
159
+ type=None,
160
+ create_key=_descriptor._internal_create_key,
161
+ ),
162
+ _descriptor.EnumValueDescriptor(
163
+ name="USER_DEFINED",
164
+ index=3,
165
+ number=4,
166
+ serialized_options=None,
167
+ type=None,
168
+ create_key=_descriptor._internal_create_key,
169
+ ),
170
+ _descriptor.EnumValueDescriptor(
171
+ name="BYTE",
172
+ index=4,
173
+ number=6,
174
+ serialized_options=None,
175
+ type=None,
176
+ create_key=_descriptor._internal_create_key,
177
+ ),
178
+ _descriptor.EnumValueDescriptor(
179
+ name="UNUSED",
180
+ index=5,
181
+ number=5,
182
+ serialized_options=None,
183
+ type=None,
184
+ create_key=_descriptor._internal_create_key,
185
+ ),
186
+ ],
187
+ containing_type=None,
188
+ serialized_options=None,
189
+ serialized_start=2100,
190
+ serialized_end=2184,
191
+ )
192
+ _sym_db.RegisterEnumDescriptor(_MODELPROTO_SENTENCEPIECE_TYPE)
193
+
194
+
195
+ _TRAINERSPEC = _descriptor.Descriptor(
196
+ name="TrainerSpec",
197
+ full_name="sentencepiece.TrainerSpec",
198
+ filename=None,
199
+ file=DESCRIPTOR,
200
+ containing_type=None,
201
+ create_key=_descriptor._internal_create_key,
202
+ fields=[
203
+ _descriptor.FieldDescriptor(
204
+ name="input",
205
+ full_name="sentencepiece.TrainerSpec.input",
206
+ index=0,
207
+ number=1,
208
+ type=9,
209
+ cpp_type=9,
210
+ label=3,
211
+ has_default_value=False,
212
+ default_value=[],
213
+ message_type=None,
214
+ enum_type=None,
215
+ containing_type=None,
216
+ is_extension=False,
217
+ extension_scope=None,
218
+ serialized_options=None,
219
+ file=DESCRIPTOR,
220
+ create_key=_descriptor._internal_create_key,
221
+ ),
222
+ _descriptor.FieldDescriptor(
223
+ name="input_format",
224
+ full_name="sentencepiece.TrainerSpec.input_format",
225
+ index=1,
226
+ number=7,
227
+ type=9,
228
+ cpp_type=9,
229
+ label=1,
230
+ has_default_value=False,
231
+ default_value=b"".decode("utf-8"),
232
+ message_type=None,
233
+ enum_type=None,
234
+ containing_type=None,
235
+ is_extension=False,
236
+ extension_scope=None,
237
+ serialized_options=None,
238
+ file=DESCRIPTOR,
239
+ create_key=_descriptor._internal_create_key,
240
+ ),
241
+ _descriptor.FieldDescriptor(
242
+ name="model_prefix",
243
+ full_name="sentencepiece.TrainerSpec.model_prefix",
244
+ index=2,
245
+ number=2,
246
+ type=9,
247
+ cpp_type=9,
248
+ label=1,
249
+ has_default_value=False,
250
+ default_value=b"".decode("utf-8"),
251
+ message_type=None,
252
+ enum_type=None,
253
+ containing_type=None,
254
+ is_extension=False,
255
+ extension_scope=None,
256
+ serialized_options=None,
257
+ file=DESCRIPTOR,
258
+ create_key=_descriptor._internal_create_key,
259
+ ),
260
+ _descriptor.FieldDescriptor(
261
+ name="model_type",
262
+ full_name="sentencepiece.TrainerSpec.model_type",
263
+ index=3,
264
+ number=3,
265
+ type=14,
266
+ cpp_type=8,
267
+ label=1,
268
+ has_default_value=True,
269
+ default_value=1,
270
+ message_type=None,
271
+ enum_type=None,
272
+ containing_type=None,
273
+ is_extension=False,
274
+ extension_scope=None,
275
+ serialized_options=None,
276
+ file=DESCRIPTOR,
277
+ create_key=_descriptor._internal_create_key,
278
+ ),
279
+ _descriptor.FieldDescriptor(
280
+ name="vocab_size",
281
+ full_name="sentencepiece.TrainerSpec.vocab_size",
282
+ index=4,
283
+ number=4,
284
+ type=5,
285
+ cpp_type=1,
286
+ label=1,
287
+ has_default_value=True,
288
+ default_value=8000,
289
+ message_type=None,
290
+ enum_type=None,
291
+ containing_type=None,
292
+ is_extension=False,
293
+ extension_scope=None,
294
+ serialized_options=None,
295
+ file=DESCRIPTOR,
296
+ create_key=_descriptor._internal_create_key,
297
+ ),
298
+ _descriptor.FieldDescriptor(
299
+ name="accept_language",
300
+ full_name="sentencepiece.TrainerSpec.accept_language",
301
+ index=5,
302
+ number=5,
303
+ type=9,
304
+ cpp_type=9,
305
+ label=3,
306
+ has_default_value=False,
307
+ default_value=[],
308
+ message_type=None,
309
+ enum_type=None,
310
+ containing_type=None,
311
+ is_extension=False,
312
+ extension_scope=None,
313
+ serialized_options=None,
314
+ file=DESCRIPTOR,
315
+ create_key=_descriptor._internal_create_key,
316
+ ),
317
+ _descriptor.FieldDescriptor(
318
+ name="self_test_sample_size",
319
+ full_name="sentencepiece.TrainerSpec.self_test_sample_size",
320
+ index=6,
321
+ number=6,
322
+ type=5,
323
+ cpp_type=1,
324
+ label=1,
325
+ has_default_value=True,
326
+ default_value=0,
327
+ message_type=None,
328
+ enum_type=None,
329
+ containing_type=None,
330
+ is_extension=False,
331
+ extension_scope=None,
332
+ serialized_options=None,
333
+ file=DESCRIPTOR,
334
+ create_key=_descriptor._internal_create_key,
335
+ ),
336
+ _descriptor.FieldDescriptor(
337
+ name="character_coverage",
338
+ full_name="sentencepiece.TrainerSpec.character_coverage",
339
+ index=7,
340
+ number=10,
341
+ type=2,
342
+ cpp_type=6,
343
+ label=1,
344
+ has_default_value=True,
345
+ default_value=float(0.9995),
346
+ message_type=None,
347
+ enum_type=None,
348
+ containing_type=None,
349
+ is_extension=False,
350
+ extension_scope=None,
351
+ serialized_options=None,
352
+ file=DESCRIPTOR,
353
+ create_key=_descriptor._internal_create_key,
354
+ ),
355
+ _descriptor.FieldDescriptor(
356
+ name="input_sentence_size",
357
+ full_name="sentencepiece.TrainerSpec.input_sentence_size",
358
+ index=8,
359
+ number=11,
360
+ type=4,
361
+ cpp_type=4,
362
+ label=1,
363
+ has_default_value=True,
364
+ default_value=0,
365
+ message_type=None,
366
+ enum_type=None,
367
+ containing_type=None,
368
+ is_extension=False,
369
+ extension_scope=None,
370
+ serialized_options=None,
371
+ file=DESCRIPTOR,
372
+ create_key=_descriptor._internal_create_key,
373
+ ),
374
+ _descriptor.FieldDescriptor(
375
+ name="shuffle_input_sentence",
376
+ full_name="sentencepiece.TrainerSpec.shuffle_input_sentence",
377
+ index=9,
378
+ number=19,
379
+ type=8,
380
+ cpp_type=7,
381
+ label=1,
382
+ has_default_value=True,
383
+ default_value=True,
384
+ message_type=None,
385
+ enum_type=None,
386
+ containing_type=None,
387
+ is_extension=False,
388
+ extension_scope=None,
389
+ serialized_options=None,
390
+ file=DESCRIPTOR,
391
+ create_key=_descriptor._internal_create_key,
392
+ ),
393
+ _descriptor.FieldDescriptor(
394
+ name="mining_sentence_size",
395
+ full_name="sentencepiece.TrainerSpec.mining_sentence_size",
396
+ index=10,
397
+ number=12,
398
+ type=5,
399
+ cpp_type=1,
400
+ label=1,
401
+ has_default_value=False,
402
+ default_value=0,
403
+ message_type=None,
404
+ enum_type=None,
405
+ containing_type=None,
406
+ is_extension=False,
407
+ extension_scope=None,
408
+ serialized_options=b"\030\001",
409
+ file=DESCRIPTOR,
410
+ create_key=_descriptor._internal_create_key,
411
+ ),
412
+ _descriptor.FieldDescriptor(
413
+ name="training_sentence_size",
414
+ full_name="sentencepiece.TrainerSpec.training_sentence_size",
415
+ index=11,
416
+ number=13,
417
+ type=5,
418
+ cpp_type=1,
419
+ label=1,
420
+ has_default_value=False,
421
+ default_value=0,
422
+ message_type=None,
423
+ enum_type=None,
424
+ containing_type=None,
425
+ is_extension=False,
426
+ extension_scope=None,
427
+ serialized_options=b"\030\001",
428
+ file=DESCRIPTOR,
429
+ create_key=_descriptor._internal_create_key,
430
+ ),
431
+ _descriptor.FieldDescriptor(
432
+ name="seed_sentencepiece_size",
433
+ full_name="sentencepiece.TrainerSpec.seed_sentencepiece_size",
434
+ index=12,
435
+ number=14,
436
+ type=5,
437
+ cpp_type=1,
438
+ label=1,
439
+ has_default_value=True,
440
+ default_value=1000000,
441
+ message_type=None,
442
+ enum_type=None,
443
+ containing_type=None,
444
+ is_extension=False,
445
+ extension_scope=None,
446
+ serialized_options=None,
447
+ file=DESCRIPTOR,
448
+ create_key=_descriptor._internal_create_key,
449
+ ),
450
+ _descriptor.FieldDescriptor(
451
+ name="shrinking_factor",
452
+ full_name="sentencepiece.TrainerSpec.shrinking_factor",
453
+ index=13,
454
+ number=15,
455
+ type=2,
456
+ cpp_type=6,
457
+ label=1,
458
+ has_default_value=True,
459
+ default_value=float(0.75),
460
+ message_type=None,
461
+ enum_type=None,
462
+ containing_type=None,
463
+ is_extension=False,
464
+ extension_scope=None,
465
+ serialized_options=None,
466
+ file=DESCRIPTOR,
467
+ create_key=_descriptor._internal_create_key,
468
+ ),
469
+ _descriptor.FieldDescriptor(
470
+ name="max_sentence_length",
471
+ full_name="sentencepiece.TrainerSpec.max_sentence_length",
472
+ index=14,
473
+ number=18,
474
+ type=5,
475
+ cpp_type=1,
476
+ label=1,
477
+ has_default_value=True,
478
+ default_value=4192,
479
+ message_type=None,
480
+ enum_type=None,
481
+ containing_type=None,
482
+ is_extension=False,
483
+ extension_scope=None,
484
+ serialized_options=None,
485
+ file=DESCRIPTOR,
486
+ create_key=_descriptor._internal_create_key,
487
+ ),
488
+ _descriptor.FieldDescriptor(
489
+ name="num_threads",
490
+ full_name="sentencepiece.TrainerSpec.num_threads",
491
+ index=15,
492
+ number=16,
493
+ type=5,
494
+ cpp_type=1,
495
+ label=1,
496
+ has_default_value=True,
497
+ default_value=16,
498
+ message_type=None,
499
+ enum_type=None,
500
+ containing_type=None,
501
+ is_extension=False,
502
+ extension_scope=None,
503
+ serialized_options=None,
504
+ file=DESCRIPTOR,
505
+ create_key=_descriptor._internal_create_key,
506
+ ),
507
+ _descriptor.FieldDescriptor(
508
+ name="num_sub_iterations",
509
+ full_name="sentencepiece.TrainerSpec.num_sub_iterations",
510
+ index=16,
511
+ number=17,
512
+ type=5,
513
+ cpp_type=1,
514
+ label=1,
515
+ has_default_value=True,
516
+ default_value=2,
517
+ message_type=None,
518
+ enum_type=None,
519
+ containing_type=None,
520
+ is_extension=False,
521
+ extension_scope=None,
522
+ serialized_options=None,
523
+ file=DESCRIPTOR,
524
+ create_key=_descriptor._internal_create_key,
525
+ ),
526
+ _descriptor.FieldDescriptor(
527
+ name="max_sentencepiece_length",
528
+ full_name="sentencepiece.TrainerSpec.max_sentencepiece_length",
529
+ index=17,
530
+ number=20,
531
+ type=5,
532
+ cpp_type=1,
533
+ label=1,
534
+ has_default_value=True,
535
+ default_value=16,
536
+ message_type=None,
537
+ enum_type=None,
538
+ containing_type=None,
539
+ is_extension=False,
540
+ extension_scope=None,
541
+ serialized_options=None,
542
+ file=DESCRIPTOR,
543
+ create_key=_descriptor._internal_create_key,
544
+ ),
545
+ _descriptor.FieldDescriptor(
546
+ name="split_by_unicode_script",
547
+ full_name="sentencepiece.TrainerSpec.split_by_unicode_script",
548
+ index=18,
549
+ number=21,
550
+ type=8,
551
+ cpp_type=7,
552
+ label=1,
553
+ has_default_value=True,
554
+ default_value=True,
555
+ message_type=None,
556
+ enum_type=None,
557
+ containing_type=None,
558
+ is_extension=False,
559
+ extension_scope=None,
560
+ serialized_options=None,
561
+ file=DESCRIPTOR,
562
+ create_key=_descriptor._internal_create_key,
563
+ ),
564
+ _descriptor.FieldDescriptor(
565
+ name="split_by_number",
566
+ full_name="sentencepiece.TrainerSpec.split_by_number",
567
+ index=19,
568
+ number=23,
569
+ type=8,
570
+ cpp_type=7,
571
+ label=1,
572
+ has_default_value=True,
573
+ default_value=True,
574
+ message_type=None,
575
+ enum_type=None,
576
+ containing_type=None,
577
+ is_extension=False,
578
+ extension_scope=None,
579
+ serialized_options=None,
580
+ file=DESCRIPTOR,
581
+ create_key=_descriptor._internal_create_key,
582
+ ),
583
+ _descriptor.FieldDescriptor(
584
+ name="split_by_whitespace",
585
+ full_name="sentencepiece.TrainerSpec.split_by_whitespace",
586
+ index=20,
587
+ number=22,
588
+ type=8,
589
+ cpp_type=7,
590
+ label=1,
591
+ has_default_value=True,
592
+ default_value=True,
593
+ message_type=None,
594
+ enum_type=None,
595
+ containing_type=None,
596
+ is_extension=False,
597
+ extension_scope=None,
598
+ serialized_options=None,
599
+ file=DESCRIPTOR,
600
+ create_key=_descriptor._internal_create_key,
601
+ ),
602
+ _descriptor.FieldDescriptor(
603
+ name="treat_whitespace_as_suffix",
604
+ full_name="sentencepiece.TrainerSpec.treat_whitespace_as_suffix",
605
+ index=21,
606
+ number=24,
607
+ type=8,
608
+ cpp_type=7,
609
+ label=1,
610
+ has_default_value=True,
611
+ default_value=False,
612
+ message_type=None,
613
+ enum_type=None,
614
+ containing_type=None,
615
+ is_extension=False,
616
+ extension_scope=None,
617
+ serialized_options=None,
618
+ file=DESCRIPTOR,
619
+ create_key=_descriptor._internal_create_key,
620
+ ),
621
+ _descriptor.FieldDescriptor(
622
+ name="split_digits",
623
+ full_name="sentencepiece.TrainerSpec.split_digits",
624
+ index=22,
625
+ number=25,
626
+ type=8,
627
+ cpp_type=7,
628
+ label=1,
629
+ has_default_value=True,
630
+ default_value=False,
631
+ message_type=None,
632
+ enum_type=None,
633
+ containing_type=None,
634
+ is_extension=False,
635
+ extension_scope=None,
636
+ serialized_options=None,
637
+ file=DESCRIPTOR,
638
+ create_key=_descriptor._internal_create_key,
639
+ ),
640
+ _descriptor.FieldDescriptor(
641
+ name="control_symbols",
642
+ full_name="sentencepiece.TrainerSpec.control_symbols",
643
+ index=23,
644
+ number=30,
645
+ type=9,
646
+ cpp_type=9,
647
+ label=3,
648
+ has_default_value=False,
649
+ default_value=[],
650
+ message_type=None,
651
+ enum_type=None,
652
+ containing_type=None,
653
+ is_extension=False,
654
+ extension_scope=None,
655
+ serialized_options=None,
656
+ file=DESCRIPTOR,
657
+ create_key=_descriptor._internal_create_key,
658
+ ),
659
+ _descriptor.FieldDescriptor(
660
+ name="user_defined_symbols",
661
+ full_name="sentencepiece.TrainerSpec.user_defined_symbols",
662
+ index=24,
663
+ number=31,
664
+ type=9,
665
+ cpp_type=9,
666
+ label=3,
667
+ has_default_value=False,
668
+ default_value=[],
669
+ message_type=None,
670
+ enum_type=None,
671
+ containing_type=None,
672
+ is_extension=False,
673
+ extension_scope=None,
674
+ serialized_options=None,
675
+ file=DESCRIPTOR,
676
+ create_key=_descriptor._internal_create_key,
677
+ ),
678
+ _descriptor.FieldDescriptor(
679
+ name="required_chars",
680
+ full_name="sentencepiece.TrainerSpec.required_chars",
681
+ index=25,
682
+ number=36,
683
+ type=9,
684
+ cpp_type=9,
685
+ label=1,
686
+ has_default_value=False,
687
+ default_value=b"".decode("utf-8"),
688
+ message_type=None,
689
+ enum_type=None,
690
+ containing_type=None,
691
+ is_extension=False,
692
+ extension_scope=None,
693
+ serialized_options=None,
694
+ file=DESCRIPTOR,
695
+ create_key=_descriptor._internal_create_key,
696
+ ),
697
+ _descriptor.FieldDescriptor(
698
+ name="byte_fallback",
699
+ full_name="sentencepiece.TrainerSpec.byte_fallback",
700
+ index=26,
701
+ number=35,
702
+ type=8,
703
+ cpp_type=7,
704
+ label=1,
705
+ has_default_value=True,
706
+ default_value=False,
707
+ message_type=None,
708
+ enum_type=None,
709
+ containing_type=None,
710
+ is_extension=False,
711
+ extension_scope=None,
712
+ serialized_options=None,
713
+ file=DESCRIPTOR,
714
+ create_key=_descriptor._internal_create_key,
715
+ ),
716
+ _descriptor.FieldDescriptor(
717
+ name="vocabulary_output_piece_score",
718
+ full_name="sentencepiece.TrainerSpec.vocabulary_output_piece_score",
719
+ index=27,
720
+ number=32,
721
+ type=8,
722
+ cpp_type=7,
723
+ label=1,
724
+ has_default_value=True,
725
+ default_value=True,
726
+ message_type=None,
727
+ enum_type=None,
728
+ containing_type=None,
729
+ is_extension=False,
730
+ extension_scope=None,
731
+ serialized_options=None,
732
+ file=DESCRIPTOR,
733
+ create_key=_descriptor._internal_create_key,
734
+ ),
735
+ _descriptor.FieldDescriptor(
736
+ name="hard_vocab_limit",
737
+ full_name="sentencepiece.TrainerSpec.hard_vocab_limit",
738
+ index=28,
739
+ number=33,
740
+ type=8,
741
+ cpp_type=7,
742
+ label=1,
743
+ has_default_value=True,
744
+ default_value=True,
745
+ message_type=None,
746
+ enum_type=None,
747
+ containing_type=None,
748
+ is_extension=False,
749
+ extension_scope=None,
750
+ serialized_options=None,
751
+ file=DESCRIPTOR,
752
+ create_key=_descriptor._internal_create_key,
753
+ ),
754
+ _descriptor.FieldDescriptor(
755
+ name="use_all_vocab",
756
+ full_name="sentencepiece.TrainerSpec.use_all_vocab",
757
+ index=29,
758
+ number=34,
759
+ type=8,
760
+ cpp_type=7,
761
+ label=1,
762
+ has_default_value=True,
763
+ default_value=False,
764
+ message_type=None,
765
+ enum_type=None,
766
+ containing_type=None,
767
+ is_extension=False,
768
+ extension_scope=None,
769
+ serialized_options=None,
770
+ file=DESCRIPTOR,
771
+ create_key=_descriptor._internal_create_key,
772
+ ),
773
+ _descriptor.FieldDescriptor(
774
+ name="unk_id",
775
+ full_name="sentencepiece.TrainerSpec.unk_id",
776
+ index=30,
777
+ number=40,
778
+ type=5,
779
+ cpp_type=1,
780
+ label=1,
781
+ has_default_value=True,
782
+ default_value=0,
783
+ message_type=None,
784
+ enum_type=None,
785
+ containing_type=None,
786
+ is_extension=False,
787
+ extension_scope=None,
788
+ serialized_options=None,
789
+ file=DESCRIPTOR,
790
+ create_key=_descriptor._internal_create_key,
791
+ ),
792
+ _descriptor.FieldDescriptor(
793
+ name="bos_id",
794
+ full_name="sentencepiece.TrainerSpec.bos_id",
795
+ index=31,
796
+ number=41,
797
+ type=5,
798
+ cpp_type=1,
799
+ label=1,
800
+ has_default_value=True,
801
+ default_value=1,
802
+ message_type=None,
803
+ enum_type=None,
804
+ containing_type=None,
805
+ is_extension=False,
806
+ extension_scope=None,
807
+ serialized_options=None,
808
+ file=DESCRIPTOR,
809
+ create_key=_descriptor._internal_create_key,
810
+ ),
811
+ _descriptor.FieldDescriptor(
812
+ name="eos_id",
813
+ full_name="sentencepiece.TrainerSpec.eos_id",
814
+ index=32,
815
+ number=42,
816
+ type=5,
817
+ cpp_type=1,
818
+ label=1,
819
+ has_default_value=True,
820
+ default_value=2,
821
+ message_type=None,
822
+ enum_type=None,
823
+ containing_type=None,
824
+ is_extension=False,
825
+ extension_scope=None,
826
+ serialized_options=None,
827
+ file=DESCRIPTOR,
828
+ create_key=_descriptor._internal_create_key,
829
+ ),
830
+ _descriptor.FieldDescriptor(
831
+ name="pad_id",
832
+ full_name="sentencepiece.TrainerSpec.pad_id",
833
+ index=33,
834
+ number=43,
835
+ type=5,
836
+ cpp_type=1,
837
+ label=1,
838
+ has_default_value=True,
839
+ default_value=-1,
840
+ message_type=None,
841
+ enum_type=None,
842
+ containing_type=None,
843
+ is_extension=False,
844
+ extension_scope=None,
845
+ serialized_options=None,
846
+ file=DESCRIPTOR,
847
+ create_key=_descriptor._internal_create_key,
848
+ ),
849
+ _descriptor.FieldDescriptor(
850
+ name="unk_piece",
851
+ full_name="sentencepiece.TrainerSpec.unk_piece",
852
+ index=34,
853
+ number=45,
854
+ type=9,
855
+ cpp_type=9,
856
+ label=1,
857
+ has_default_value=True,
858
+ default_value=b"<unk>".decode("utf-8"),
859
+ message_type=None,
860
+ enum_type=None,
861
+ containing_type=None,
862
+ is_extension=False,
863
+ extension_scope=None,
864
+ serialized_options=None,
865
+ file=DESCRIPTOR,
866
+ create_key=_descriptor._internal_create_key,
867
+ ),
868
+ _descriptor.FieldDescriptor(
869
+ name="bos_piece",
870
+ full_name="sentencepiece.TrainerSpec.bos_piece",
871
+ index=35,
872
+ number=46,
873
+ type=9,
874
+ cpp_type=9,
875
+ label=1,
876
+ has_default_value=True,
877
+ default_value=b"<s>".decode("utf-8"),
878
+ message_type=None,
879
+ enum_type=None,
880
+ containing_type=None,
881
+ is_extension=False,
882
+ extension_scope=None,
883
+ serialized_options=None,
884
+ file=DESCRIPTOR,
885
+ create_key=_descriptor._internal_create_key,
886
+ ),
887
+ _descriptor.FieldDescriptor(
888
+ name="eos_piece",
889
+ full_name="sentencepiece.TrainerSpec.eos_piece",
890
+ index=36,
891
+ number=47,
892
+ type=9,
893
+ cpp_type=9,
894
+ label=1,
895
+ has_default_value=True,
896
+ default_value=b"</s>".decode("utf-8"),
897
+ message_type=None,
898
+ enum_type=None,
899
+ containing_type=None,
900
+ is_extension=False,
901
+ extension_scope=None,
902
+ serialized_options=None,
903
+ file=DESCRIPTOR,
904
+ create_key=_descriptor._internal_create_key,
905
+ ),
906
+ _descriptor.FieldDescriptor(
907
+ name="pad_piece",
908
+ full_name="sentencepiece.TrainerSpec.pad_piece",
909
+ index=37,
910
+ number=48,
911
+ type=9,
912
+ cpp_type=9,
913
+ label=1,
914
+ has_default_value=True,
915
+ default_value=b"<pad>".decode("utf-8"),
916
+ message_type=None,
917
+ enum_type=None,
918
+ containing_type=None,
919
+ is_extension=False,
920
+ extension_scope=None,
921
+ serialized_options=None,
922
+ file=DESCRIPTOR,
923
+ create_key=_descriptor._internal_create_key,
924
+ ),
925
+ _descriptor.FieldDescriptor(
926
+ name="unk_surface",
927
+ full_name="sentencepiece.TrainerSpec.unk_surface",
928
+ index=38,
929
+ number=44,
930
+ type=9,
931
+ cpp_type=9,
932
+ label=1,
933
+ has_default_value=True,
934
+ default_value=b" \342\201\207 ".decode("utf-8"),
935
+ message_type=None,
936
+ enum_type=None,
937
+ containing_type=None,
938
+ is_extension=False,
939
+ extension_scope=None,
940
+ serialized_options=None,
941
+ file=DESCRIPTOR,
942
+ create_key=_descriptor._internal_create_key,
943
+ ),
944
+ _descriptor.FieldDescriptor(
945
+ name="train_extremely_large_corpus",
946
+ full_name="sentencepiece.TrainerSpec.train_extremely_large_corpus",
947
+ index=39,
948
+ number=49,
949
+ type=8,
950
+ cpp_type=7,
951
+ label=1,
952
+ has_default_value=True,
953
+ default_value=False,
954
+ message_type=None,
955
+ enum_type=None,
956
+ containing_type=None,
957
+ is_extension=False,
958
+ extension_scope=None,
959
+ serialized_options=None,
960
+ file=DESCRIPTOR,
961
+ create_key=_descriptor._internal_create_key,
962
+ ),
963
+ ],
964
+ extensions=[],
965
+ nested_types=[],
966
+ enum_types=[
967
+ _TRAINERSPEC_MODELTYPE,
968
+ ],
969
+ serialized_options=None,
970
+ is_extendable=True,
971
+ syntax="proto2",
972
+ extension_ranges=[
973
+ (200, 536870912),
974
+ ],
975
+ oneofs=[],
976
+ serialized_start=45,
977
+ serialized_end=1358,
978
+ )
979
+
980
+
981
+ _NORMALIZERSPEC = _descriptor.Descriptor(
982
+ name="NormalizerSpec",
983
+ full_name="sentencepiece.NormalizerSpec",
984
+ filename=None,
985
+ file=DESCRIPTOR,
986
+ containing_type=None,
987
+ create_key=_descriptor._internal_create_key,
988
+ fields=[
989
+ _descriptor.FieldDescriptor(
990
+ name="name",
991
+ full_name="sentencepiece.NormalizerSpec.name",
992
+ index=0,
993
+ number=1,
994
+ type=9,
995
+ cpp_type=9,
996
+ label=1,
997
+ has_default_value=False,
998
+ default_value=b"".decode("utf-8"),
999
+ message_type=None,
1000
+ enum_type=None,
1001
+ containing_type=None,
1002
+ is_extension=False,
1003
+ extension_scope=None,
1004
+ serialized_options=None,
1005
+ file=DESCRIPTOR,
1006
+ create_key=_descriptor._internal_create_key,
1007
+ ),
1008
+ _descriptor.FieldDescriptor(
1009
+ name="precompiled_charsmap",
1010
+ full_name="sentencepiece.NormalizerSpec.precompiled_charsmap",
1011
+ index=1,
1012
+ number=2,
1013
+ type=12,
1014
+ cpp_type=9,
1015
+ label=1,
1016
+ has_default_value=False,
1017
+ default_value=b"",
1018
+ message_type=None,
1019
+ enum_type=None,
1020
+ containing_type=None,
1021
+ is_extension=False,
1022
+ extension_scope=None,
1023
+ serialized_options=None,
1024
+ file=DESCRIPTOR,
1025
+ create_key=_descriptor._internal_create_key,
1026
+ ),
1027
+ _descriptor.FieldDescriptor(
1028
+ name="add_dummy_prefix",
1029
+ full_name="sentencepiece.NormalizerSpec.add_dummy_prefix",
1030
+ index=2,
1031
+ number=3,
1032
+ type=8,
1033
+ cpp_type=7,
1034
+ label=1,
1035
+ has_default_value=True,
1036
+ default_value=True,
1037
+ message_type=None,
1038
+ enum_type=None,
1039
+ containing_type=None,
1040
+ is_extension=False,
1041
+ extension_scope=None,
1042
+ serialized_options=None,
1043
+ file=DESCRIPTOR,
1044
+ create_key=_descriptor._internal_create_key,
1045
+ ),
1046
+ _descriptor.FieldDescriptor(
1047
+ name="remove_extra_whitespaces",
1048
+ full_name="sentencepiece.NormalizerSpec.remove_extra_whitespaces",
1049
+ index=3,
1050
+ number=4,
1051
+ type=8,
1052
+ cpp_type=7,
1053
+ label=1,
1054
+ has_default_value=True,
1055
+ default_value=True,
1056
+ message_type=None,
1057
+ enum_type=None,
1058
+ containing_type=None,
1059
+ is_extension=False,
1060
+ extension_scope=None,
1061
+ serialized_options=None,
1062
+ file=DESCRIPTOR,
1063
+ create_key=_descriptor._internal_create_key,
1064
+ ),
1065
+ _descriptor.FieldDescriptor(
1066
+ name="escape_whitespaces",
1067
+ full_name="sentencepiece.NormalizerSpec.escape_whitespaces",
1068
+ index=4,
1069
+ number=5,
1070
+ type=8,
1071
+ cpp_type=7,
1072
+ label=1,
1073
+ has_default_value=True,
1074
+ default_value=True,
1075
+ message_type=None,
1076
+ enum_type=None,
1077
+ containing_type=None,
1078
+ is_extension=False,
1079
+ extension_scope=None,
1080
+ serialized_options=None,
1081
+ file=DESCRIPTOR,
1082
+ create_key=_descriptor._internal_create_key,
1083
+ ),
1084
+ _descriptor.FieldDescriptor(
1085
+ name="normalization_rule_tsv",
1086
+ full_name="sentencepiece.NormalizerSpec.normalization_rule_tsv",
1087
+ index=5,
1088
+ number=6,
1089
+ type=9,
1090
+ cpp_type=9,
1091
+ label=1,
1092
+ has_default_value=False,
1093
+ default_value=b"".decode("utf-8"),
1094
+ message_type=None,
1095
+ enum_type=None,
1096
+ containing_type=None,
1097
+ is_extension=False,
1098
+ extension_scope=None,
1099
+ serialized_options=None,
1100
+ file=DESCRIPTOR,
1101
+ create_key=_descriptor._internal_create_key,
1102
+ ),
1103
+ ],
1104
+ extensions=[],
1105
+ nested_types=[],
1106
+ enum_types=[],
1107
+ serialized_options=None,
1108
+ is_extendable=True,
1109
+ syntax="proto2",
1110
+ extension_ranges=[
1111
+ (200, 536870912),
1112
+ ],
1113
+ oneofs=[],
1114
+ serialized_start=1361,
1115
+ serialized_end=1570,
1116
+ )
1117
+
1118
+
1119
+ _SELFTESTDATA_SAMPLE = _descriptor.Descriptor(
1120
+ name="Sample",
1121
+ full_name="sentencepiece.SelfTestData.Sample",
1122
+ filename=None,
1123
+ file=DESCRIPTOR,
1124
+ containing_type=None,
1125
+ create_key=_descriptor._internal_create_key,
1126
+ fields=[
1127
+ _descriptor.FieldDescriptor(
1128
+ name="input",
1129
+ full_name="sentencepiece.SelfTestData.Sample.input",
1130
+ index=0,
1131
+ number=1,
1132
+ type=9,
1133
+ cpp_type=9,
1134
+ label=1,
1135
+ has_default_value=False,
1136
+ default_value=b"".decode("utf-8"),
1137
+ message_type=None,
1138
+ enum_type=None,
1139
+ containing_type=None,
1140
+ is_extension=False,
1141
+ extension_scope=None,
1142
+ serialized_options=None,
1143
+ file=DESCRIPTOR,
1144
+ create_key=_descriptor._internal_create_key,
1145
+ ),
1146
+ _descriptor.FieldDescriptor(
1147
+ name="expected",
1148
+ full_name="sentencepiece.SelfTestData.Sample.expected",
1149
+ index=1,
1150
+ number=2,
1151
+ type=9,
1152
+ cpp_type=9,
1153
+ label=1,
1154
+ has_default_value=False,
1155
+ default_value=b"".decode("utf-8"),
1156
+ message_type=None,
1157
+ enum_type=None,
1158
+ containing_type=None,
1159
+ is_extension=False,
1160
+ extension_scope=None,
1161
+ serialized_options=None,
1162
+ file=DESCRIPTOR,
1163
+ create_key=_descriptor._internal_create_key,
1164
+ ),
1165
+ ],
1166
+ extensions=[],
1167
+ nested_types=[],
1168
+ enum_types=[],
1169
+ serialized_options=None,
1170
+ is_extendable=False,
1171
+ syntax="proto2",
1172
+ extension_ranges=[],
1173
+ oneofs=[],
1174
+ serialized_start=1641,
1175
+ serialized_end=1682,
1176
+ )
1177
+
1178
+ _SELFTESTDATA = _descriptor.Descriptor(
1179
+ name="SelfTestData",
1180
+ full_name="sentencepiece.SelfTestData",
1181
+ filename=None,
1182
+ file=DESCRIPTOR,
1183
+ containing_type=None,
1184
+ create_key=_descriptor._internal_create_key,
1185
+ fields=[
1186
+ _descriptor.FieldDescriptor(
1187
+ name="samples",
1188
+ full_name="sentencepiece.SelfTestData.samples",
1189
+ index=0,
1190
+ number=1,
1191
+ type=11,
1192
+ cpp_type=10,
1193
+ label=3,
1194
+ has_default_value=False,
1195
+ default_value=[],
1196
+ message_type=None,
1197
+ enum_type=None,
1198
+ containing_type=None,
1199
+ is_extension=False,
1200
+ extension_scope=None,
1201
+ serialized_options=None,
1202
+ file=DESCRIPTOR,
1203
+ create_key=_descriptor._internal_create_key,
1204
+ ),
1205
+ ],
1206
+ extensions=[],
1207
+ nested_types=[
1208
+ _SELFTESTDATA_SAMPLE,
1209
+ ],
1210
+ enum_types=[],
1211
+ serialized_options=None,
1212
+ is_extendable=True,
1213
+ syntax="proto2",
1214
+ extension_ranges=[
1215
+ (200, 536870912),
1216
+ ],
1217
+ oneofs=[],
1218
+ serialized_start=1572,
1219
+ serialized_end=1693,
1220
+ )
1221
+
1222
+
1223
+ _MODELPROTO_SENTENCEPIECE = _descriptor.Descriptor(
1224
+ name="SentencePiece",
1225
+ full_name="sentencepiece.ModelProto.SentencePiece",
1226
+ filename=None,
1227
+ file=DESCRIPTOR,
1228
+ containing_type=None,
1229
+ create_key=_descriptor._internal_create_key,
1230
+ fields=[
1231
+ _descriptor.FieldDescriptor(
1232
+ name="piece",
1233
+ full_name="sentencepiece.ModelProto.SentencePiece.piece",
1234
+ index=0,
1235
+ number=1,
1236
+ type=9,
1237
+ cpp_type=9,
1238
+ label=1,
1239
+ has_default_value=False,
1240
+ default_value=b"".decode("utf-8"),
1241
+ message_type=None,
1242
+ enum_type=None,
1243
+ containing_type=None,
1244
+ is_extension=False,
1245
+ extension_scope=None,
1246
+ serialized_options=None,
1247
+ file=DESCRIPTOR,
1248
+ create_key=_descriptor._internal_create_key,
1249
+ ),
1250
+ _descriptor.FieldDescriptor(
1251
+ name="score",
1252
+ full_name="sentencepiece.ModelProto.SentencePiece.score",
1253
+ index=1,
1254
+ number=2,
1255
+ type=2,
1256
+ cpp_type=6,
1257
+ label=1,
1258
+ has_default_value=False,
1259
+ default_value=float(0),
1260
+ message_type=None,
1261
+ enum_type=None,
1262
+ containing_type=None,
1263
+ is_extension=False,
1264
+ extension_scope=None,
1265
+ serialized_options=None,
1266
+ file=DESCRIPTOR,
1267
+ create_key=_descriptor._internal_create_key,
1268
+ ),
1269
+ _descriptor.FieldDescriptor(
1270
+ name="type",
1271
+ full_name="sentencepiece.ModelProto.SentencePiece.type",
1272
+ index=2,
1273
+ number=3,
1274
+ type=14,
1275
+ cpp_type=8,
1276
+ label=1,
1277
+ has_default_value=True,
1278
+ default_value=1,
1279
+ message_type=None,
1280
+ enum_type=None,
1281
+ containing_type=None,
1282
+ is_extension=False,
1283
+ extension_scope=None,
1284
+ serialized_options=None,
1285
+ file=DESCRIPTOR,
1286
+ create_key=_descriptor._internal_create_key,
1287
+ ),
1288
+ ],
1289
+ extensions=[],
1290
+ nested_types=[],
1291
+ enum_types=[
1292
+ _MODELPROTO_SENTENCEPIECE_TYPE,
1293
+ ],
1294
+ serialized_options=None,
1295
+ is_extendable=True,
1296
+ syntax="proto2",
1297
+ extension_ranges=[
1298
+ (200, 536870912),
1299
+ ],
1300
+ oneofs=[],
1301
+ serialized_start=1985,
1302
+ serialized_end=2195,
1303
+ )
1304
+
1305
+ _MODELPROTO = _descriptor.Descriptor(
1306
+ name="ModelProto",
1307
+ full_name="sentencepiece.ModelProto",
1308
+ filename=None,
1309
+ file=DESCRIPTOR,
1310
+ containing_type=None,
1311
+ create_key=_descriptor._internal_create_key,
1312
+ fields=[
1313
+ _descriptor.FieldDescriptor(
1314
+ name="pieces",
1315
+ full_name="sentencepiece.ModelProto.pieces",
1316
+ index=0,
1317
+ number=1,
1318
+ type=11,
1319
+ cpp_type=10,
1320
+ label=3,
1321
+ has_default_value=False,
1322
+ default_value=[],
1323
+ message_type=None,
1324
+ enum_type=None,
1325
+ containing_type=None,
1326
+ is_extension=False,
1327
+ extension_scope=None,
1328
+ serialized_options=None,
1329
+ file=DESCRIPTOR,
1330
+ create_key=_descriptor._internal_create_key,
1331
+ ),
1332
+ _descriptor.FieldDescriptor(
1333
+ name="trainer_spec",
1334
+ full_name="sentencepiece.ModelProto.trainer_spec",
1335
+ index=1,
1336
+ number=2,
1337
+ type=11,
1338
+ cpp_type=10,
1339
+ label=1,
1340
+ has_default_value=False,
1341
+ default_value=None,
1342
+ message_type=None,
1343
+ enum_type=None,
1344
+ containing_type=None,
1345
+ is_extension=False,
1346
+ extension_scope=None,
1347
+ serialized_options=None,
1348
+ file=DESCRIPTOR,
1349
+ create_key=_descriptor._internal_create_key,
1350
+ ),
1351
+ _descriptor.FieldDescriptor(
1352
+ name="normalizer_spec",
1353
+ full_name="sentencepiece.ModelProto.normalizer_spec",
1354
+ index=2,
1355
+ number=3,
1356
+ type=11,
1357
+ cpp_type=10,
1358
+ label=1,
1359
+ has_default_value=False,
1360
+ default_value=None,
1361
+ message_type=None,
1362
+ enum_type=None,
1363
+ containing_type=None,
1364
+ is_extension=False,
1365
+ extension_scope=None,
1366
+ serialized_options=None,
1367
+ file=DESCRIPTOR,
1368
+ create_key=_descriptor._internal_create_key,
1369
+ ),
1370
+ _descriptor.FieldDescriptor(
1371
+ name="self_test_data",
1372
+ full_name="sentencepiece.ModelProto.self_test_data",
1373
+ index=3,
1374
+ number=4,
1375
+ type=11,
1376
+ cpp_type=10,
1377
+ label=1,
1378
+ has_default_value=False,
1379
+ default_value=None,
1380
+ message_type=None,
1381
+ enum_type=None,
1382
+ containing_type=None,
1383
+ is_extension=False,
1384
+ extension_scope=None,
1385
+ serialized_options=None,
1386
+ file=DESCRIPTOR,
1387
+ create_key=_descriptor._internal_create_key,
1388
+ ),
1389
+ _descriptor.FieldDescriptor(
1390
+ name="denormalizer_spec",
1391
+ full_name="sentencepiece.ModelProto.denormalizer_spec",
1392
+ index=4,
1393
+ number=5,
1394
+ type=11,
1395
+ cpp_type=10,
1396
+ label=1,
1397
+ has_default_value=False,
1398
+ default_value=None,
1399
+ message_type=None,
1400
+ enum_type=None,
1401
+ containing_type=None,
1402
+ is_extension=False,
1403
+ extension_scope=None,
1404
+ serialized_options=None,
1405
+ file=DESCRIPTOR,
1406
+ create_key=_descriptor._internal_create_key,
1407
+ ),
1408
+ ],
1409
+ extensions=[],
1410
+ nested_types=[
1411
+ _MODELPROTO_SENTENCEPIECE,
1412
+ ],
1413
+ enum_types=[],
1414
+ serialized_options=None,
1415
+ is_extendable=True,
1416
+ syntax="proto2",
1417
+ extension_ranges=[
1418
+ (200, 536870912),
1419
+ ],
1420
+ oneofs=[],
1421
+ serialized_start=1696,
1422
+ serialized_end=2206,
1423
+ )
1424
+
1425
+ _TRAINERSPEC.fields_by_name["model_type"].enum_type = _TRAINERSPEC_MODELTYPE
1426
+ _TRAINERSPEC_MODELTYPE.containing_type = _TRAINERSPEC
1427
+ _SELFTESTDATA_SAMPLE.containing_type = _SELFTESTDATA
1428
+ _SELFTESTDATA.fields_by_name["samples"].message_type = _SELFTESTDATA_SAMPLE
1429
+ _MODELPROTO_SENTENCEPIECE.fields_by_name["type"].enum_type = _MODELPROTO_SENTENCEPIECE_TYPE
1430
+ _MODELPROTO_SENTENCEPIECE.containing_type = _MODELPROTO
1431
+ _MODELPROTO_SENTENCEPIECE_TYPE.containing_type = _MODELPROTO_SENTENCEPIECE
1432
+ _MODELPROTO.fields_by_name["pieces"].message_type = _MODELPROTO_SENTENCEPIECE
1433
+ _MODELPROTO.fields_by_name["trainer_spec"].message_type = _TRAINERSPEC
1434
+ _MODELPROTO.fields_by_name["normalizer_spec"].message_type = _NORMALIZERSPEC
1435
+ _MODELPROTO.fields_by_name["self_test_data"].message_type = _SELFTESTDATA
1436
+ _MODELPROTO.fields_by_name["denormalizer_spec"].message_type = _NORMALIZERSPEC
1437
+ DESCRIPTOR.message_types_by_name["TrainerSpec"] = _TRAINERSPEC
1438
+ DESCRIPTOR.message_types_by_name["NormalizerSpec"] = _NORMALIZERSPEC
1439
+ DESCRIPTOR.message_types_by_name["SelfTestData"] = _SELFTESTDATA
1440
+ DESCRIPTOR.message_types_by_name["ModelProto"] = _MODELPROTO
1441
+ _sym_db.RegisterFileDescriptor(DESCRIPTOR)
1442
+
1443
+ TrainerSpec = _reflection.GeneratedProtocolMessageType(
1444
+ "TrainerSpec",
1445
+ (_message.Message,),
1446
+ {
1447
+ "DESCRIPTOR": _TRAINERSPEC,
1448
+ "__module__": "sentencepiece_model_pb2",
1449
+ # @@protoc_insertion_point(class_scope:sentencepiece.TrainerSpec)
1450
+ },
1451
+ )
1452
+ _sym_db.RegisterMessage(TrainerSpec)
1453
+
1454
+ NormalizerSpec = _reflection.GeneratedProtocolMessageType(
1455
+ "NormalizerSpec",
1456
+ (_message.Message,),
1457
+ {
1458
+ "DESCRIPTOR": _NORMALIZERSPEC,
1459
+ "__module__": "sentencepiece_model_pb2",
1460
+ # @@protoc_insertion_point(class_scope:sentencepiece.NormalizerSpec)
1461
+ },
1462
+ )
1463
+ _sym_db.RegisterMessage(NormalizerSpec)
1464
+
1465
+ SelfTestData = _reflection.GeneratedProtocolMessageType(
1466
+ "SelfTestData",
1467
+ (_message.Message,),
1468
+ {
1469
+ "Sample": _reflection.GeneratedProtocolMessageType(
1470
+ "Sample",
1471
+ (_message.Message,),
1472
+ {
1473
+ "DESCRIPTOR": _SELFTESTDATA_SAMPLE,
1474
+ "__module__": "sentencepiece_model_pb2",
1475
+ # @@protoc_insertion_point(class_scope:sentencepiece.SelfTestData.Sample)
1476
+ },
1477
+ ),
1478
+ "DESCRIPTOR": _SELFTESTDATA,
1479
+ "__module__": "sentencepiece_model_pb2",
1480
+ # @@protoc_insertion_point(class_scope:sentencepiece.SelfTestData)
1481
+ },
1482
+ )
1483
+ _sym_db.RegisterMessage(SelfTestData)
1484
+ _sym_db.RegisterMessage(SelfTestData.Sample)
1485
+
1486
+ ModelProto = _reflection.GeneratedProtocolMessageType(
1487
+ "ModelProto",
1488
+ (_message.Message,),
1489
+ {
1490
+ "SentencePiece": _reflection.GeneratedProtocolMessageType(
1491
+ "SentencePiece",
1492
+ (_message.Message,),
1493
+ {
1494
+ "DESCRIPTOR": _MODELPROTO_SENTENCEPIECE,
1495
+ "__module__": "sentencepiece_model_pb2",
1496
+ # @@protoc_insertion_point(class_scope:sentencepiece.ModelProto.SentencePiece)
1497
+ },
1498
+ ),
1499
+ "DESCRIPTOR": _MODELPROTO,
1500
+ "__module__": "sentencepiece_model_pb2",
1501
+ # @@protoc_insertion_point(class_scope:sentencepiece.ModelProto)
1502
+ },
1503
+ )
1504
+ _sym_db.RegisterMessage(ModelProto)
1505
+ _sym_db.RegisterMessage(ModelProto.SentencePiece)
1506
+
1507
+
1508
+ DESCRIPTOR._options = None
1509
+ _TRAINERSPEC.fields_by_name["mining_sentence_size"]._options = None
1510
+ _TRAINERSPEC.fields_by_name["training_sentence_size"]._options = None
1511
+ # @@protoc_insertion_point(module_scope)
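Both this legacy-descriptor module and the `_new` variant added below expose the same `sentencepiece.ModelProto` message, so either can deserialize a SentencePiece model file. A minimal sketch of that use, assuming a placeholder path `tokenizer.model` (not something referenced in this diff):

```python
# Sketch only: parse a SentencePiece model file with the generated ModelProto message.
# "tokenizer.model" is a placeholder path, not part of this diff.
from transformers.utils import sentencepiece_model_pb2 as sp_pb2

proto = sp_pb2.ModelProto()
with open("tokenizer.model", "rb") as f:
    proto.ParseFromString(f.read())  # standard protobuf deserialization

print(proto.trainer_spec.model_type)  # enum value from TrainerSpec.ModelType (1 = UNIGRAM)
print(len(proto.pieces))              # number of vocabulary entries (piece, score, type)
```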
.venv/Lib/site-packages/transformers/utils/sentencepiece_model_pb2_new.py ADDED
@@ -0,0 +1,48 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # source: sentencepiece_model.proto
4
+ """Generated protocol buffer code."""
5
+
6
+ from google.protobuf import descriptor as _descriptor
7
+ from google.protobuf import descriptor_pool as _descriptor_pool
8
+ from google.protobuf import symbol_database as _symbol_database
9
+ from google.protobuf.internal import builder as _builder
10
+
11
+
12
+ # @@protoc_insertion_point(imports)
13
+
14
+ _sym_db = _symbol_database.Default()
15
+
16
+
17
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
18
+ b'\n\x19sentencepiece_model.proto\x12\rsentencepiece"\x80\x0c\n\x0bTrainerSpec\x12\r\n\x05input\x18\x01 \x03(\t\x12\x14\n\x0cinput_format\x18\x07 \x01(\t\x12\x14\n\x0cmodel_prefix\x18\x02 \x01(\t\x12\x41\n\nmodel_type\x18\x03 \x01(\x0e\x32$.sentencepiece.TrainerSpec.ModelType:\x07UNIGRAM\x12\x18\n\nvocab_size\x18\x04 \x01(\x05:\x04\x38\x30\x30\x30\x12\x17\n\x0f\x61\x63\x63\x65pt_language\x18\x05 \x03(\t\x12 \n\x15self_test_sample_size\x18\x06 \x01(\x05:\x01\x30\x12*\n\x1b\x65nable_differential_privacy\x18\x32 \x01(\x08:\x05\x66\x61lse\x12+\n differential_privacy_noise_level\x18\x33 \x01(\x02:\x01\x30\x12\x32\n\'differential_privacy_clipping_threshold\x18\x34 \x01(\x04:\x01\x30\x12"\n\x12\x63haracter_coverage\x18\n \x01(\x02:\x06\x30.9995\x12\x1e\n\x13input_sentence_size\x18\x0b \x01(\x04:\x01\x30\x12$\n\x16shuffle_input_sentence\x18\x13 \x01(\x08:\x04true\x12 \n\x14mining_sentence_size\x18\x0c \x01(\x05\x42\x02\x18\x01\x12"\n\x16training_sentence_size\x18\r \x01(\x05\x42\x02\x18\x01\x12(\n\x17seed_sentencepiece_size\x18\x0e \x01(\x05:\x07\x31\x30\x30\x30\x30\x30\x30\x12\x1e\n\x10shrinking_factor\x18\x0f \x01(\x02:\x04\x30.75\x12!\n\x13max_sentence_length\x18\x12 \x01(\x05:\x04\x34\x31\x39\x32\x12\x17\n\x0bnum_threads\x18\x10 \x01(\x05:\x02\x31\x36\x12\x1d\n\x12num_sub_iterations\x18\x11 \x01(\x05:\x01\x32\x12$\n\x18max_sentencepiece_length\x18\x14 \x01(\x05:\x02\x31\x36\x12%\n\x17split_by_unicode_script\x18\x15 \x01(\x08:\x04true\x12\x1d\n\x0fsplit_by_number\x18\x17 \x01(\x08:\x04true\x12!\n\x13split_by_whitespace\x18\x16 \x01(\x08:\x04true\x12)\n\x1atreat_whitespace_as_suffix\x18\x18 \x01(\x08:\x05\x66\x61lse\x12+\n\x1c\x61llow_whitespace_only_pieces\x18\x1a \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x0csplit_digits\x18\x19 \x01(\x08:\x05\x66\x61lse\x12#\n\x19pretokenization_delimiter\x18\x35 \x01(\t:\x00\x12\x17\n\x0f\x63ontrol_symbols\x18\x1e \x03(\t\x12\x1c\n\x14user_defined_symbols\x18\x1f \x03(\t\x12\x16\n\x0erequired_chars\x18$ \x01(\t\x12\x1c\n\rbyte_fallback\x18# \x01(\x08:\x05\x66\x61lse\x12+\n\x1dvocabulary_output_piece_score\x18 \x01(\x08:\x04true\x12\x1e\n\x10hard_vocab_limit\x18! \x01(\x08:\x04true\x12\x1c\n\ruse_all_vocab\x18" \x01(\x08:\x05\x66\x61lse\x12\x11\n\x06unk_id\x18( \x01(\x05:\x01\x30\x12\x11\n\x06\x62os_id\x18) \x01(\x05:\x01\x31\x12\x11\n\x06\x65os_id\x18* \x01(\x05:\x01\x32\x12\x12\n\x06pad_id\x18+ \x01(\x05:\x02-1\x12\x18\n\tunk_piece\x18- \x01(\t:\x05<unk>\x12\x16\n\tbos_piece\x18. 
\x01(\t:\x03<s>\x12\x17\n\teos_piece\x18/ \x01(\t:\x04</s>\x12\x18\n\tpad_piece\x18\x30 \x01(\t:\x05<pad>\x12\x1a\n\x0bunk_surface\x18, \x01(\t:\x05 \xe2\x81\x87 \x12+\n\x1ctrain_extremely_large_corpus\x18\x31 \x01(\x08:\x05\x66\x61lse"5\n\tModelType\x12\x0b\n\x07UNIGRAM\x10\x01\x12\x07\n\x03\x42PE\x10\x02\x12\x08\n\x04WORD\x10\x03\x12\x08\n\x04\x43HAR\x10\x04*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"\xd1\x01\n\x0eNormalizerSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1c\n\x14precompiled_charsmap\x18\x02 \x01(\x0c\x12\x1e\n\x10\x61\x64\x64_dummy_prefix\x18\x03 \x01(\x08:\x04true\x12&\n\x18remove_extra_whitespaces\x18\x04 \x01(\x08:\x04true\x12 \n\x12\x65scape_whitespaces\x18\x05 \x01(\x08:\x04true\x12\x1e\n\x16normalization_rule_tsv\x18\x06 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"y\n\x0cSelfTestData\x12\x33\n\x07samples\x18\x01 \x03(\x0b\x32".sentencepiece.SelfTestData.Sample\x1a)\n\x06Sample\x12\r\n\x05input\x18\x01 \x01(\t\x12\x10\n\x08\x65xpected\x18\x02 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"\xfe\x03\n\nModelProto\x12\x37\n\x06pieces\x18\x01 \x03(\x0b\x32\'.sentencepiece.ModelProto.SentencePiece\x12\x30\n\x0ctrainer_spec\x18\x02 \x01(\x0b\x32\x1a.sentencepiece.TrainerSpec\x12\x36\n\x0fnormalizer_spec\x18\x03 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x12\x33\n\x0eself_test_data\x18\x04 \x01(\x0b\x32\x1b.sentencepiece.SelfTestData\x12\x38\n\x11\x64\x65normalizer_spec\x18\x05 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x1a\xd2\x01\n\rSentencePiece\x12\r\n\x05piece\x18\x01 \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x02\x12\x42\n\x04type\x18\x03 \x01(\x0e\x32,.sentencepiece.ModelProto.SentencePiece.Type:\x06NORMAL"T\n\x04Type\x12\n\n\x06NORMAL\x10\x01\x12\x0b\n\x07UNKNOWN\x10\x02\x12\x0b\n\x07\x43ONTROL\x10\x03\x12\x10\n\x0cUSER_DEFINED\x10\x04\x12\x08\n\x04\x42YTE\x10\x06\x12\n\n\x06UNUSED\x10\x05*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\x42\x02H\x03'
19
+ )
20
+
21
+ _globals = globals()
22
+ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
23
+ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "sentencepiece_model_pb2", _globals)
24
+ if _descriptor._USE_C_DESCRIPTORS is False:
25
+ DESCRIPTOR._options = None
26
+ DESCRIPTOR._serialized_options = b"H\003"
27
+ # (generated by protobuf compiler, but `_TRAINERSPEC` is not defined)
28
+ # _TRAINERSPEC.fields_by_name["mining_sentence_size"]._options = None
29
+ # _TRAINERSPEC.fields_by_name["mining_sentence_size"]._serialized_options = b"\030\001"
30
+ # _TRAINERSPEC.fields_by_name["training_sentence_size"]._options = None
31
+ # _TRAINERSPEC.fields_by_name["training_sentence_size"]._serialized_options = b"\030\001"
32
+ _globals["_TRAINERSPEC"]._serialized_start = 45
33
+ _globals["_TRAINERSPEC"]._serialized_end = 1581
34
+ _globals["_TRAINERSPEC_MODELTYPE"]._serialized_start = 1517
35
+ _globals["_TRAINERSPEC_MODELTYPE"]._serialized_end = 1570
36
+ _globals["_NORMALIZERSPEC"]._serialized_start = 1584
37
+ _globals["_NORMALIZERSPEC"]._serialized_end = 1793
38
+ _globals["_SELFTESTDATA"]._serialized_start = 1795
39
+ _globals["_SELFTESTDATA"]._serialized_end = 1916
40
+ _globals["_SELFTESTDATA_SAMPLE"]._serialized_start = 1864
41
+ _globals["_SELFTESTDATA_SAMPLE"]._serialized_end = 1905
42
+ _globals["_MODELPROTO"]._serialized_start = 1919
43
+ _globals["_MODELPROTO"]._serialized_end = 2429
44
+ _globals["_MODELPROTO_SENTENCEPIECE"]._serialized_start = 2208
45
+ _globals["_MODELPROTO_SENTENCEPIECE"]._serialized_end = 2418
46
+ _globals["_MODELPROTO_SENTENCEPIECE_TYPE"]._serialized_start = 2323
47
+ _globals["_MODELPROTO_SENTENCEPIECE_TYPE"]._serialized_end = 2407
48
+ # @@protoc_insertion_point(module_scope)
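Unlike the file above, `sentencepiece_model_pb2_new.py` builds its descriptors through the protobuf 4.x `descriptor_pool`/`_builder` API rather than the legacy `_descriptor`/`_reflection` path. A hedged sketch of picking whichever module matches the installed runtime; the version cut-off and selection logic here are assumptions for illustration, not code from this diff:

```python
# Assumption/illustration only: select the generated module compatible with the
# installed protobuf runtime. Not code from this diff.
from packaging import version
import google.protobuf

if version.parse(google.protobuf.__version__) >= version.parse("4.0.0"):
    from transformers.utils import sentencepiece_model_pb2_new as sp_pb2
else:
    from transformers.utils import sentencepiece_model_pb2 as sp_pb2

model = sp_pb2.ModelProto()  # the message classes are interchangeable for callers
```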
.venv/Lib/site-packages/transformers/utils/versions.py ADDED
@@ -0,0 +1,117 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Utilities for working with package versions
16
+ """
17
+
18
+ import importlib.metadata
19
+ import operator
20
+ import re
21
+ import sys
22
+ from typing import Optional
23
+
24
+ from packaging import version
25
+
26
+
27
+ ops = {
28
+ "<": operator.lt,
29
+ "<=": operator.le,
30
+ "==": operator.eq,
31
+ "!=": operator.ne,
32
+ ">=": operator.ge,
33
+ ">": operator.gt,
34
+ }
35
+
36
+
37
+ def _compare_versions(op, got_ver, want_ver, requirement, pkg, hint):
38
+ if got_ver is None or want_ver is None:
39
+ raise ValueError(
40
+ f"Unable to compare versions for {requirement}: need={want_ver} found={got_ver}. This is unusual. Consider"
41
+ f" reinstalling {pkg}."
42
+ )
43
+ if not ops[op](version.parse(got_ver), version.parse(want_ver)):
44
+ raise ImportError(
45
+ f"{requirement} is required for a normal functioning of this module, but found {pkg}=={got_ver}.{hint}"
46
+ )
47
+
48
+
49
+ def require_version(requirement: str, hint: Optional[str] = None) -> None:
50
+ """
51
+ Perform a runtime check of the dependency versions, using the exact same syntax used by pip.
52
+
53
+ The installed module version comes from the *site-packages* dir via *importlib.metadata*.
54
+
55
+ Args:
56
+ requirement (`str`): pip style definition, e.g., "tokenizers==0.9.4", "tqdm>=4.27", "numpy"
57
+ hint (`str`, *optional*): what suggestion to print in case of requirements not being met
58
+
59
+ Example:
60
+
61
+ ```python
62
+ require_version("pandas>1.1.2")
63
+ require_version("numpy>1.18.5", "this is important to have for whatever reason")
64
+ ```"""
65
+
66
+ hint = f"\n{hint}" if hint is not None else ""
67
+
68
+ # non-versioned check
69
+ if re.match(r"^[\w_\-\d]+$", requirement):
70
+ pkg, op, want_ver = requirement, None, None
71
+ else:
72
+ match = re.findall(r"^([^!=<>\s]+)([\s!=<>]{1,2}.+)", requirement)
73
+ if not match:
74
+ raise ValueError(
75
+ "requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23, but"
76
+ f" got {requirement}"
77
+ )
78
+ pkg, want_full = match[0]
79
+ want_range = want_full.split(",") # there could be multiple requirements
80
+ wanted = {}
81
+ for w in want_range:
82
+ match = re.findall(r"^([\s!=<>]{1,2})(.+)", w)
83
+ if not match:
84
+ raise ValueError(
85
+ "requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23,"
86
+ f" but got {requirement}"
87
+ )
88
+ op, want_ver = match[0]
89
+ wanted[op] = want_ver
90
+ if op not in ops:
91
+ raise ValueError(f"{requirement}: need one of {list(ops.keys())}, but got {op}")
92
+
93
+ # special case
94
+ if pkg == "python":
95
+ got_ver = ".".join([str(x) for x in sys.version_info[:3]])
96
+ for op, want_ver in wanted.items():
97
+ _compare_versions(op, got_ver, want_ver, requirement, pkg, hint)
98
+ return
99
+
100
+ # check if any version is installed
101
+ try:
102
+ got_ver = importlib.metadata.version(pkg)
103
+ except importlib.metadata.PackageNotFoundError:
104
+ raise importlib.metadata.PackageNotFoundError(
105
+ f"The '{requirement}' distribution was not found and is required by this application. {hint}"
106
+ )
107
+
108
+ # check that the right version is installed if version number or a range was provided
109
+ if want_ver is not None:
110
+ for op, want_ver in wanted.items():
111
+ _compare_versions(op, got_ver, want_ver, requirement, pkg, hint)
112
+
113
+
114
+ def require_version_core(requirement):
115
+ """require_version wrapper which emits a core-specific hint on failure"""
116
+ hint = "Try: `pip install transformers -U` or `pip install -e '.[dev]'` if you're working with git main"
117
+ return require_version(requirement, hint)
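A short usage sketch of the helpers above; the requirement strings are placeholders chosen for illustration, not taken from this diff:

```python
# Illustrative only: placeholder requirement strings, not from this diff.
from transformers.utils.versions import require_version, require_version_core

require_version("packaging>=20.0", "needed to parse version specifiers")
require_version("numpy")             # bare name: only checks that the package is installed
require_version_core("python>=3.9")  # "python" is special-cased against sys.version_info
```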
.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Abidjan ADDED
Binary file (130 Bytes). View file
 
.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Accra ADDED
Binary file (130 Bytes). View file
 
.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Addis_Ababa ADDED
Binary file (191 Bytes). View file
 
.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Algiers ADDED
Binary file (470 Bytes). View file
 
.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Asmara ADDED
Binary file (191 Bytes). View file
 
.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Asmera ADDED
Binary file (191 Bytes). View file
 
.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Bamako ADDED
Binary file (130 Bytes). View file
 
.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Bangui ADDED
Binary file (180 Bytes). View file
 
.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Dar_es_Salaam ADDED
Binary file (191 Bytes). View file
 
.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Djibouti ADDED
Binary file (191 Bytes). View file
 
.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Douala ADDED
Binary file (180 Bytes). View file
 
.venv/Lib/site-packages/tzdata/zoneinfo/Africa/El_Aaiun ADDED
Binary file (1.83 kB). View file
 
.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Freetown ADDED
Binary file (130 Bytes). View file
 
.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Gaborone ADDED
Binary file (131 Bytes). View file
 
.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Harare ADDED
Binary file (131 Bytes). View file
 
.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Johannesburg ADDED
Binary file (190 Bytes). View file
 
.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Juba ADDED
Binary file (458 Bytes). View file
 
.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Kampala ADDED
Binary file (191 Bytes). View file
 
.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Khartoum ADDED
Binary file (458 Bytes). View file
 
.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Kigali ADDED
Binary file (131 Bytes). View file