Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes.
- .venv/Lib/site-packages/transformers-4.47.0.dist-info/INSTALLER +1 -0
- .venv/Lib/site-packages/transformers-4.47.0.dist-info/LICENSE +203 -0
- .venv/Lib/site-packages/transformers-4.47.0.dist-info/entry_points.txt +2 -0
- .venv/Lib/site-packages/transformers/__pycache__/__init__.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/transformers/__pycache__/modeling_utils.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/transformers/__pycache__/pytorch_utils.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/transformers/__pycache__/safetensors_conversion.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/transformers/__pycache__/tokenization_utils.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/transformers/__pycache__/tokenization_utils_base.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/transformers/__pycache__/tokenization_utils_fast.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/transformers/pipelines/feature_extraction.py +86 -0
- .venv/Lib/site-packages/transformers/pipelines/fill_mask.py +273 -0
- .venv/Lib/site-packages/transformers/pipelines/image_classification.py +226 -0
- .venv/Lib/site-packages/transformers/pipelines/image_feature_extraction.py +112 -0
- .venv/Lib/site-packages/transformers/pipelines/image_segmentation.py +220 -0
- .venv/Lib/site-packages/transformers/pipelines/image_text_to_text.py +432 -0
- .venv/Lib/site-packages/transformers/pipelines/image_to_image.py +136 -0
- .venv/Lib/site-packages/transformers/pipelines/image_to_text.py +216 -0
- .venv/Lib/site-packages/transformers/pipelines/mask_generation.py +287 -0
- .venv/Lib/site-packages/transformers/pipelines/object_detection.py +191 -0
- .venv/Lib/site-packages/transformers/utils/__init__.py +315 -0
- .venv/Lib/site-packages/transformers/utils/__pycache__/backbone_utils.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/transformers/utils/__pycache__/chat_template_utils.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/transformers/utils/__pycache__/constants.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/transformers/utils/__pycache__/deprecation.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/transformers/utils/__pycache__/doc.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/transformers/utils/quantization_config.py +1344 -0
- .venv/Lib/site-packages/transformers/utils/sentencepiece_model_pb2.py +1511 -0
- .venv/Lib/site-packages/transformers/utils/sentencepiece_model_pb2_new.py +48 -0
- .venv/Lib/site-packages/transformers/utils/versions.py +117 -0
- .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Abidjan +0 -0
- .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Accra +0 -0
- .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Addis_Ababa +0 -0
- .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Algiers +0 -0
- .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Asmara +0 -0
- .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Asmera +0 -0
- .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Bamako +0 -0
- .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Bangui +0 -0
- .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Dar_es_Salaam +0 -0
- .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Djibouti +0 -0
- .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Douala +0 -0
- .venv/Lib/site-packages/tzdata/zoneinfo/Africa/El_Aaiun +0 -0
- .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Freetown +0 -0
- .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Gaborone +0 -0
- .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Harare +0 -0
- .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Johannesburg +0 -0
- .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Juba +0 -0
- .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Kampala +0 -0
- .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Khartoum +0 -0
- .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Kigali +0 -0
.venv/Lib/site-packages/transformers-4.47.0.dist-info/INSTALLER
ADDED
@@ -0,0 +1 @@
uv
.venv/Lib/site-packages/transformers-4.47.0.dist-info/LICENSE
ADDED
@@ -0,0 +1,203 @@
Copyright 2018- The Hugging Face team. All rights reserved.

Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
.venv/Lib/site-packages/transformers-4.47.0.dist-info/entry_points.txt
ADDED
@@ -0,0 +1,2 @@
[console_scripts]
transformers-cli = transformers.commands.transformers_cli:main
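Aside (not part of the diff): the `[console_scripts]` entry above is what makes a `transformers-cli` executable appear on PATH when this wheel is installed; the command is only a thin wrapper around the named function. A minimal sketch of the equivalent direct call, assuming the vendored package is importable:

# Hypothetical equivalent of invoking `transformers-cli` from a shell:
# the entry point maps the command name to transformers.commands.transformers_cli:main.
from transformers.commands.transformers_cli import main

if __name__ == "__main__":
    main()  # builds the CLI argument parser and dispatches to the selected subcommand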
.venv/Lib/site-packages/transformers/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (166 kB).
.venv/Lib/site-packages/transformers/__pycache__/modeling_utils.cpython-39.pyc
ADDED
Binary file (169 kB).
.venv/Lib/site-packages/transformers/__pycache__/pytorch_utils.cpython-39.pyc
ADDED
Binary file (12.5 kB).
.venv/Lib/site-packages/transformers/__pycache__/safetensors_conversion.cpython-39.pyc
ADDED
Binary file (3.3 kB).
.venv/Lib/site-packages/transformers/__pycache__/tokenization_utils.cpython-39.pyc
ADDED
Binary file (32.4 kB).
.venv/Lib/site-packages/transformers/__pycache__/tokenization_utils_base.cpython-39.pyc
ADDED
Binary file (149 kB).
.venv/Lib/site-packages/transformers/__pycache__/tokenization_utils_fast.cpython-39.pyc
ADDED
Binary file (27.7 kB).
.venv/Lib/site-packages/transformers/pipelines/feature_extraction.py
ADDED
@@ -0,0 +1,86 @@
from typing import Dict

from ..utils import add_end_docstrings
from .base import GenericTensor, Pipeline, build_pipeline_init_args


@add_end_docstrings(
    build_pipeline_init_args(has_tokenizer=True, supports_binary_output=False),
    r"""
        tokenize_kwargs (`dict`, *optional*):
            Additional dictionary of keyword arguments passed along to the tokenizer.
        return_tensors (`bool`, *optional*):
            If `True`, returns a tensor according to the specified framework, otherwise returns a list.""",
)
class FeatureExtractionPipeline(Pipeline):
    """
    Feature extraction pipeline uses no model head. This pipeline extracts the hidden states from the base
    transformer, which can be used as features in downstream tasks.

    Example:

    ```python
    >>> from transformers import pipeline

    >>> extractor = pipeline(model="google-bert/bert-base-uncased", task="feature-extraction")
    >>> result = extractor("This is a simple test.", return_tensors=True)
    >>> result.shape  # This is a tensor of shape [1, sequence_length, hidden_dimension] representing the input string.
    torch.Size([1, 8, 768])
    ```

    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)

    This feature extraction pipeline can currently be loaded from [`pipeline`] using the task identifier:
    `"feature-extraction"`.

    All models may be used for this pipeline. See a list of all models, including community-contributed models on
    [huggingface.co/models](https://huggingface.co/models).
    """

    def _sanitize_parameters(self, truncation=None, tokenize_kwargs=None, return_tensors=None, **kwargs):
        if tokenize_kwargs is None:
            tokenize_kwargs = {}

        if truncation is not None:
            if "truncation" in tokenize_kwargs:
                raise ValueError(
                    "truncation parameter defined twice (given as keyword argument as well as in tokenize_kwargs)"
                )
            tokenize_kwargs["truncation"] = truncation

        preprocess_params = tokenize_kwargs

        postprocess_params = {}
        if return_tensors is not None:
            postprocess_params["return_tensors"] = return_tensors

        return preprocess_params, {}, postprocess_params

    def preprocess(self, inputs, **tokenize_kwargs) -> Dict[str, GenericTensor]:
        model_inputs = self.tokenizer(inputs, return_tensors=self.framework, **tokenize_kwargs)
        return model_inputs

    def _forward(self, model_inputs):
        model_outputs = self.model(**model_inputs)
        return model_outputs

    def postprocess(self, model_outputs, return_tensors=False):
        # [0] is the first available tensor, logits or last_hidden_state.
        if return_tensors:
            return model_outputs[0]
        if self.framework == "pt":
            return model_outputs[0].tolist()
        elif self.framework == "tf":
            return model_outputs[0].numpy().tolist()

    def __call__(self, *args, **kwargs):
        """
        Extract the features of the input(s).

        Args:
            args (`str` or `List[str]`): One or several texts (or one list of texts) to get the features of.

        Return:
            A nested list of `float`: The features computed by the model.
        """
        return super().__call__(*args, **kwargs)
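Aside (not part of the diff): per `_sanitize_parameters` above, `tokenize_kwargs`/`truncation` are routed to `preprocess` and `return_tensors` to `postprocess`, so all of them can be passed directly at call time. A minimal usage sketch, reusing the model name from the docstring example (illustrative only):

from transformers import pipeline

extractor = pipeline(task="feature-extraction", model="google-bert/bert-base-uncased")

# tokenize_kwargs is forwarded to the tokenizer in preprocess();
# return_tensors=True makes postprocess() return the framework tensor instead of nested lists.
features = extractor(
    "This is a simple test.",
    tokenize_kwargs={"truncation": True},
    return_tensors=True,
)
print(features.shape)  # expected torch.Size([1, sequence_length, hidden_size]) with the PyTorch backend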
.venv/Lib/site-packages/transformers/pipelines/fill_mask.py
ADDED
@@ -0,0 +1,273 @@
from typing import Dict

import numpy as np

from ..utils import add_end_docstrings, is_tf_available, is_torch_available, logging
from .base import GenericTensor, Pipeline, PipelineException, build_pipeline_init_args


if is_tf_available():
    import tensorflow as tf

    from ..tf_utils import stable_softmax


if is_torch_available():
    import torch


logger = logging.get_logger(__name__)


@add_end_docstrings(
    build_pipeline_init_args(has_tokenizer=True),
    r"""
        top_k (`int`, *optional*, defaults to 5):
            The number of predictions to return.
        targets (`str` or `List[str]`, *optional*):
            When passed, the model will limit the scores to the passed targets instead of looking up in the whole
            vocab. If the provided targets are not in the model vocab, they will be tokenized and the first resulting
            token will be used (with a warning, and that might be slower).
        tokenizer_kwargs (`dict`, *optional*):
            Additional dictionary of keyword arguments passed along to the tokenizer.""",
)
class FillMaskPipeline(Pipeline):
    """
    Masked language modeling prediction pipeline using any `ModelWithLMHead`. See the [masked language modeling
    examples](../task_summary#masked-language-modeling) for more information.

    Example:

    ```python
    >>> from transformers import pipeline

    >>> fill_masker = pipeline(model="google-bert/bert-base-uncased")
    >>> fill_masker("This is a simple [MASK].")
    [{'score': 0.042, 'token': 3291, 'token_str': 'problem', 'sequence': 'this is a simple problem.'}, {'score': 0.031, 'token': 3160, 'token_str': 'question', 'sequence': 'this is a simple question.'}, {'score': 0.03, 'token': 8522, 'token_str': 'equation', 'sequence': 'this is a simple equation.'}, {'score': 0.027, 'token': 2028, 'token_str': 'one', 'sequence': 'this is a simple one.'}, {'score': 0.024, 'token': 3627, 'token_str': 'rule', 'sequence': 'this is a simple rule.'}]
    ```

    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)

    This mask filling pipeline can currently be loaded from [`pipeline`] using the following task identifier:
    `"fill-mask"`.

    The models that this pipeline can use are models that have been trained with a masked language modeling objective,
    which includes the bi-directional models in the library. See the up-to-date list of available models on
    [huggingface.co/models](https://huggingface.co/models?filter=fill-mask).

    <Tip>

    This pipeline only works for inputs with exactly one token masked. Experimental: We added support for multiple
    masks. The returned values are raw model output, and correspond to disjoint probabilities where one might expect
    joint probabilities (See [discussion](https://github.com/huggingface/transformers/pull/10222)).

    </Tip>

    <Tip>

    This pipeline now supports tokenizer_kwargs. For example try:

    ```python
    >>> from transformers import pipeline

    >>> fill_masker = pipeline(model="google-bert/bert-base-uncased")
    >>> tokenizer_kwargs = {"truncation": True}
    >>> fill_masker(
    ...     "This is a simple [MASK]. " + "...with a large amount of repeated text appended. " * 100,
    ...     tokenizer_kwargs=tokenizer_kwargs,
    ... )
    ```


    </Tip>


    """

    def get_masked_index(self, input_ids: GenericTensor) -> np.ndarray:
        if self.framework == "tf":
            masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy()
        elif self.framework == "pt":
            masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False)
        else:
            raise ValueError("Unsupported framework")
        return masked_index

    def _ensure_exactly_one_mask_token(self, input_ids: GenericTensor) -> np.ndarray:
        masked_index = self.get_masked_index(input_ids)
        numel = np.prod(masked_index.shape)
        if numel < 1:
            raise PipelineException(
                "fill-mask",
                self.model.base_model_prefix,
                f"No mask_token ({self.tokenizer.mask_token}) found on the input",
            )

    def ensure_exactly_one_mask_token(self, model_inputs: GenericTensor):
        if isinstance(model_inputs, list):
            for model_input in model_inputs:
                self._ensure_exactly_one_mask_token(model_input["input_ids"][0])
        else:
            for input_ids in model_inputs["input_ids"]:
                self._ensure_exactly_one_mask_token(input_ids)

    def preprocess(
        self, inputs, return_tensors=None, tokenizer_kwargs=None, **preprocess_parameters
    ) -> Dict[str, GenericTensor]:
        if return_tensors is None:
            return_tensors = self.framework
        if tokenizer_kwargs is None:
            tokenizer_kwargs = {}

        model_inputs = self.tokenizer(inputs, return_tensors=return_tensors, **tokenizer_kwargs)
        self.ensure_exactly_one_mask_token(model_inputs)
        return model_inputs

    def _forward(self, model_inputs):
        model_outputs = self.model(**model_inputs)
        model_outputs["input_ids"] = model_inputs["input_ids"]
        return model_outputs

    def postprocess(self, model_outputs, top_k=5, target_ids=None):
        # Cap top_k if there are targets
        if target_ids is not None and target_ids.shape[0] < top_k:
            top_k = target_ids.shape[0]
        input_ids = model_outputs["input_ids"][0]
        outputs = model_outputs["logits"]

        if self.framework == "tf":
            masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy()[:, 0]

            outputs = outputs.numpy()

            logits = outputs[0, masked_index, :]
            probs = stable_softmax(logits, axis=-1)
            if target_ids is not None:
                probs = tf.gather_nd(tf.squeeze(probs, 0), target_ids.reshape(-1, 1))
                probs = tf.expand_dims(probs, 0)

            topk = tf.math.top_k(probs, k=top_k)
            values, predictions = topk.values.numpy(), topk.indices.numpy()
        else:
            masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False).squeeze(-1)
            # Fill mask pipeline supports only one ${mask_token} per sample

            logits = outputs[0, masked_index, :]
            probs = logits.softmax(dim=-1)
            if target_ids is not None:
                probs = probs[..., target_ids]

            values, predictions = probs.topk(top_k)

        result = []
        single_mask = values.shape[0] == 1
        for i, (_values, _predictions) in enumerate(zip(values.tolist(), predictions.tolist())):
            row = []
            for v, p in zip(_values, _predictions):
                # Copy is important since we're going to modify this array in place
                tokens = input_ids.numpy().copy()
                if target_ids is not None:
                    p = target_ids[p].tolist()

                tokens[masked_index[i]] = p
                # Filter padding out:
                tokens = tokens[np.where(tokens != self.tokenizer.pad_token_id)]
                # Originally we skip special tokens to give readable output.
                # For multi masks though, the other [MASK] would be removed otherwise
                # making the output look odd, so we add them back
                sequence = self.tokenizer.decode(tokens, skip_special_tokens=single_mask)
                proposition = {"score": v, "token": p, "token_str": self.tokenizer.decode([p]), "sequence": sequence}
                row.append(proposition)
            result.append(row)
        if single_mask:
            return result[0]
        return result

    def get_target_ids(self, targets, top_k=None):
        if isinstance(targets, str):
            targets = [targets]
        try:
            vocab = self.tokenizer.get_vocab()
        except Exception:
            vocab = {}
        target_ids = []
        for target in targets:
            id_ = vocab.get(target, None)
            if id_ is None:
                input_ids = self.tokenizer(
                    target,
                    add_special_tokens=False,
                    return_attention_mask=False,
                    return_token_type_ids=False,
                    max_length=1,
                    truncation=True,
                )["input_ids"]
                if len(input_ids) == 0:
                    logger.warning(
                        f"The specified target token `{target}` does not exist in the model vocabulary. "
                        "We cannot replace it with anything meaningful, ignoring it"
                    )
                    continue
                id_ = input_ids[0]
                # XXX: If users encounter this pass
                # it becomes pretty slow, so let's make sure
                # The warning enables them to fix the input to
                # get faster performance.
                logger.warning(
                    f"The specified target token `{target}` does not exist in the model vocabulary. "
                    f"Replacing with `{self.tokenizer.convert_ids_to_tokens(id_)}`."
                )
            target_ids.append(id_)
        target_ids = list(set(target_ids))
        if len(target_ids) == 0:
            raise ValueError("At least one target must be provided when passed.")
        target_ids = np.array(target_ids)
        return target_ids

    def _sanitize_parameters(self, top_k=None, targets=None, tokenizer_kwargs=None):
        preprocess_params = {}

        if tokenizer_kwargs is not None:
            preprocess_params["tokenizer_kwargs"] = tokenizer_kwargs

        postprocess_params = {}

        if targets is not None:
            target_ids = self.get_target_ids(targets, top_k)
            postprocess_params["target_ids"] = target_ids

        if top_k is not None:
            postprocess_params["top_k"] = top_k

        if self.tokenizer.mask_token_id is None:
            raise PipelineException(
                "fill-mask", self.model.base_model_prefix, "The tokenizer does not define a `mask_token`."
            )
        return preprocess_params, {}, postprocess_params

    def __call__(self, inputs, **kwargs):
        """
        Fill the masked token in the text(s) given as inputs.

        Args:
            inputs (`str` or `List[str]`):
                One or several texts (or one list of prompts) with masked tokens.
            targets (`str` or `List[str]`, *optional*):
                When passed, the model will limit the scores to the passed targets instead of looking up in the whole
                vocab. If the provided targets are not in the model vocab, they will be tokenized and the first
                resulting token will be used (with a warning, and that might be slower).
            top_k (`int`, *optional*):
                When passed, overrides the number of predictions to return.

        Return:
            A list or a list of list of `dict`: Each result comes as list of dictionaries with the following keys:

            - **sequence** (`str`) -- The corresponding input with the mask token prediction.
            - **score** (`float`) -- The corresponding probability.
            - **token** (`int`) -- The predicted token id (to replace the masked one).
            - **token_str** (`str`) -- The predicted token (to replace the masked one).
        """
        outputs = super().__call__(inputs, **kwargs)
        if isinstance(inputs, list) and len(inputs) == 1:
            return outputs[0]
        return outputs
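Aside (not part of the diff): a minimal sketch of the `targets` / `top_k` path implemented by `get_target_ids` and `postprocess` above, reusing the docstring's model; the example sentence and target words are illustrative.

from transformers import pipeline

fill_masker = pipeline(task="fill-mask", model="google-bert/bert-base-uncased")

# targets restricts scoring to the given tokens; postprocess() caps top_k at len(targets).
predictions = fill_masker(
    "The capital of France is [MASK].",
    targets=["paris", "london"],
    top_k=2,
)
for pred in predictions:
    print(pred["token_str"], round(pred["score"], 4))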
.venv/Lib/site-packages/transformers/pipelines/image_classification.py
ADDED
@@ -0,0 +1,226 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union

import numpy as np

from ..utils import (
    ExplicitEnum,
    add_end_docstrings,
    is_tf_available,
    is_torch_available,
    is_vision_available,
    logging,
    requires_backends,
)
from .base import Pipeline, build_pipeline_init_args


if is_vision_available():
    from PIL import Image

    from ..image_utils import load_image

if is_tf_available():
    from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES

if is_torch_available():
    import torch

    from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES

logger = logging.get_logger(__name__)


# Copied from transformers.pipelines.text_classification.sigmoid
def sigmoid(_outputs):
    return 1.0 / (1.0 + np.exp(-_outputs))


# Copied from transformers.pipelines.text_classification.softmax
def softmax(_outputs):
    maxes = np.max(_outputs, axis=-1, keepdims=True)
    shifted_exp = np.exp(_outputs - maxes)
    return shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)


# Copied from transformers.pipelines.text_classification.ClassificationFunction
class ClassificationFunction(ExplicitEnum):
    SIGMOID = "sigmoid"
    SOFTMAX = "softmax"
    NONE = "none"


@add_end_docstrings(
    build_pipeline_init_args(has_image_processor=True),
    r"""
        function_to_apply (`str`, *optional*, defaults to `"default"`):
            The function to apply to the model outputs in order to retrieve the scores. Accepts four different values:

            - `"default"`: if the model has a single label, will apply the sigmoid function on the output. If the model
              has several labels, will apply the softmax function on the output.
            - `"sigmoid"`: Applies the sigmoid function on the output.
            - `"softmax"`: Applies the softmax function on the output.
            - `"none"`: Does not apply any function on the output.""",
)
class ImageClassificationPipeline(Pipeline):
    """
    Image classification pipeline using any `AutoModelForImageClassification`. This pipeline predicts the class of an
    image.

    Example:

    ```python
    >>> from transformers import pipeline

    >>> classifier = pipeline(model="microsoft/beit-base-patch16-224-pt22k-ft22k")
    >>> classifier("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png")
    [{'score': 0.442, 'label': 'macaw'}, {'score': 0.088, 'label': 'popinjay'}, {'score': 0.075, 'label': 'parrot'}, {'score': 0.073, 'label': 'parodist, lampooner'}, {'score': 0.046, 'label': 'poll, poll_parrot'}]
    ```

    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)

    This image classification pipeline can currently be loaded from [`pipeline`] using the following task identifier:
    `"image-classification"`.

    See the list of available models on
    [huggingface.co/models](https://huggingface.co/models?filter=image-classification).
    """

    function_to_apply: ClassificationFunction = ClassificationFunction.NONE

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        requires_backends(self, "vision")
        self.check_model_type(
            TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
            if self.framework == "tf"
            else MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
        )

    def _sanitize_parameters(self, top_k=None, function_to_apply=None, timeout=None):
        preprocess_params = {}
        if timeout is not None:
            preprocess_params["timeout"] = timeout
        postprocess_params = {}
        if top_k is not None:
            postprocess_params["top_k"] = top_k
        if isinstance(function_to_apply, str):
            function_to_apply = ClassificationFunction(function_to_apply.lower())
        if function_to_apply is not None:
            postprocess_params["function_to_apply"] = function_to_apply
        return preprocess_params, {}, postprocess_params

    def __call__(self, inputs: Union[str, List[str], "Image.Image", List["Image.Image"]] = None, **kwargs):
        """
        Assign labels to the image(s) passed as inputs.

        Args:
            inputs (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
                The pipeline handles three types of images:

                - A string containing a http link pointing to an image
                - A string containing a local path to an image
                - An image loaded in PIL directly

                The pipeline accepts either a single image or a batch of images, which must then be passed as a string.
                Images in a batch must all be in the same format: all as http links, all as local paths, or all as PIL
                images.
            function_to_apply (`str`, *optional*, defaults to `"default"`):
                The function to apply to the model outputs in order to retrieve the scores. Accepts four different
                values:

                If this argument is not specified, then it will apply the following functions according to the number
                of labels:

                - If the model has a single label, will apply the sigmoid function on the output.
                - If the model has several labels, will apply the softmax function on the output.

                Possible values are:

                - `"sigmoid"`: Applies the sigmoid function on the output.
                - `"softmax"`: Applies the softmax function on the output.
                - `"none"`: Does not apply any function on the output.
            top_k (`int`, *optional*, defaults to 5):
                The number of top labels that will be returned by the pipeline. If the provided number is higher than
                the number of labels available in the model configuration, it will default to the number of labels.
            timeout (`float`, *optional*, defaults to None):
                The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
                the call may block forever.

        Return:
            A dictionary or a list of dictionaries containing result. If the input is a single image, will return a
            dictionary, if the input is a list of several images, will return a list of dictionaries corresponding to
            the images.

            The dictionaries contain the following keys:

            - **label** (`str`) -- The label identified by the model.
            - **score** (`int`) -- The score attributed by the model for that label.
        """
        # After deprecation of this is completed, remove the default `None` value for `images`
        if "images" in kwargs:
            inputs = kwargs.pop("images")
        if inputs is None:
            raise ValueError("Cannot call the image-classification pipeline without an inputs argument!")
        return super().__call__(inputs, **kwargs)

    def preprocess(self, image, timeout=None):
        image = load_image(image, timeout=timeout)
        model_inputs = self.image_processor(images=image, return_tensors=self.framework)
        if self.framework == "pt":
            model_inputs = model_inputs.to(self.torch_dtype)
        return model_inputs

    def _forward(self, model_inputs):
        model_outputs = self.model(**model_inputs)
        return model_outputs

    def postprocess(self, model_outputs, function_to_apply=None, top_k=5):
        if function_to_apply is None:
            if self.model.config.problem_type == "single_label_classification" or self.model.config.num_labels == 1:
                function_to_apply = ClassificationFunction.SIGMOID
            elif self.model.config.problem_type == "multi_label_classification" or self.model.config.num_labels > 1:
                function_to_apply = ClassificationFunction.SOFTMAX
            elif hasattr(self.model.config, "function_to_apply") and function_to_apply is None:
                function_to_apply = self.model.config.function_to_apply
            else:
                function_to_apply = ClassificationFunction.NONE

        if top_k > self.model.config.num_labels:
            top_k = self.model.config.num_labels

        outputs = model_outputs["logits"][0]
        if self.framework == "pt" and outputs.dtype in (torch.bfloat16, torch.float16):
            outputs = outputs.to(torch.float32).numpy()
        else:
            outputs = outputs.numpy()

        if function_to_apply == ClassificationFunction.SIGMOID:
            scores = sigmoid(outputs)
        elif function_to_apply == ClassificationFunction.SOFTMAX:
            scores = softmax(outputs)
        elif function_to_apply == ClassificationFunction.NONE:
            scores = outputs
        else:
            raise ValueError(f"Unrecognized `function_to_apply` argument: {function_to_apply}")

        dict_scores = [
            {"label": self.model.config.id2label[i], "score": score.item()} for i, score in enumerate(scores)
        ]
        dict_scores.sort(key=lambda x: x["score"], reverse=True)
        if top_k is not None:
            dict_scores = dict_scores[:top_k]

        return dict_scores
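Aside (not part of the diff): a minimal sketch of the `top_k` / `function_to_apply` parameters handled by `_sanitize_parameters` and `postprocess` above, reusing the docstring's checkpoint and test image (illustrative only):

from transformers import pipeline

classifier = pipeline(task="image-classification", model="microsoft/beit-base-patch16-224-pt22k-ft22k")

# top_k is clamped to config.num_labels in postprocess();
# function_to_apply overrides the default sigmoid/softmax selection.
preds = classifier(
    "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png",
    top_k=3,
    function_to_apply="softmax",
)
print(preds)  # [{'label': ..., 'score': ...}, ...] sorted by score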
.venv/Lib/site-packages/transformers/pipelines/image_feature_extraction.py
ADDED
@@ -0,0 +1,112 @@
from typing import Dict

from ..utils import add_end_docstrings, is_vision_available
from .base import GenericTensor, Pipeline, build_pipeline_init_args


if is_vision_available():
    from ..image_utils import load_image


@add_end_docstrings(
    build_pipeline_init_args(has_image_processor=True),
    """
        image_processor_kwargs (`dict`, *optional*):
            Additional dictionary of keyword arguments passed along to the image processor e.g.
            {"size": {"height": 100, "width": 100}}
        pool (`bool`, *optional*, defaults to `False`):
            Whether or not to return the pooled output. If `False`, the model will return the raw hidden states.
    """,
)
class ImageFeatureExtractionPipeline(Pipeline):
    """
    Image feature extraction pipeline uses no model head. This pipeline extracts the hidden states from the base
    transformer, which can be used as features in downstream tasks.

    Example:

    ```python
    >>> from transformers import pipeline

    >>> extractor = pipeline(model="google/vit-base-patch16-224", task="image-feature-extraction")
    >>> result = extractor("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png", return_tensors=True)
    >>> result.shape  # This is a tensor of shape [1, sequence_lenth, hidden_dimension] representing the input image.
    torch.Size([1, 197, 768])
    ```

    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)

    This image feature extraction pipeline can currently be loaded from [`pipeline`] using the task identifier:
    `"image-feature-extraction"`.

    All vision models may be used for this pipeline. See a list of all models, including community-contributed models on
    [huggingface.co/models](https://huggingface.co/models).
    """

    def _sanitize_parameters(self, image_processor_kwargs=None, return_tensors=None, pool=None, **kwargs):
        preprocess_params = {} if image_processor_kwargs is None else image_processor_kwargs

        postprocess_params = {}
        if pool is not None:
            postprocess_params["pool"] = pool
        if return_tensors is not None:
            postprocess_params["return_tensors"] = return_tensors

        if "timeout" in kwargs:
            preprocess_params["timeout"] = kwargs["timeout"]

        return preprocess_params, {}, postprocess_params

    def preprocess(self, image, timeout=None, **image_processor_kwargs) -> Dict[str, GenericTensor]:
        image = load_image(image, timeout=timeout)
        model_inputs = self.image_processor(image, return_tensors=self.framework, **image_processor_kwargs)
        if self.framework == "pt":
            model_inputs = model_inputs.to(self.torch_dtype)
        return model_inputs

    def _forward(self, model_inputs):
        model_outputs = self.model(**model_inputs)
        return model_outputs

    def postprocess(self, model_outputs, pool=None, return_tensors=False):
        pool = pool if pool is not None else False

        if pool:
            if "pooler_output" not in model_outputs:
                raise ValueError(
                    "No pooled output was returned. Make sure the model has a `pooler` layer when using the `pool` option."
                )
            outputs = model_outputs["pooler_output"]
        else:
            # [0] is the first available tensor, logits or last_hidden_state.
            outputs = model_outputs[0]

        if return_tensors:
            return outputs
        if self.framework == "pt":
            return outputs.tolist()
        elif self.framework == "tf":
            return outputs.numpy().tolist()

    def __call__(self, *args, **kwargs):
        """
        Extract the features of the input(s).

        Args:
            images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
                The pipeline handles three types of images:

                - A string containing a http link pointing to an image
                - A string containing a local path to an image
                - An image loaded in PIL directly

                The pipeline accepts either a single image or a batch of images, which must then be passed as a string.
                Images in a batch must all be in the same format: all as http links, all as local paths, or all as PIL
                images.
            timeout (`float`, *optional*, defaults to None):
                The maximum time in seconds to wait for fetching images from the web. If None, no timeout is used and
                the call may block forever.
        Return:
            A nested list of `float`: The features computed by the model.
        """
        return super().__call__(*args, **kwargs)
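Aside (not part of the diff): a minimal sketch of the `pool` option handled in `postprocess` above, reusing the docstring's ViT checkpoint; it assumes the loaded backbone exposes a pooler layer, otherwise the `ValueError` above is raised.

from transformers import pipeline

extractor = pipeline(task="image-feature-extraction", model="google/vit-base-patch16-224")

# pool=True returns model_outputs["pooler_output"];
# return_tensors=True skips the conversion to nested Python lists.
pooled = extractor(
    "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png",
    pool=True,
    return_tensors=True,
)
print(pooled.shape)  # expected torch.Size([1, hidden_size]) with the PyTorch backend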
.venv/Lib/site-packages/transformers/pipelines/image_segmentation.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Any, Dict, List, Union

import numpy as np

from ..utils import add_end_docstrings, is_torch_available, is_vision_available, logging, requires_backends
from .base import Pipeline, build_pipeline_init_args


if is_vision_available():
    from PIL import Image

    from ..image_utils import load_image

if is_torch_available():
    from ..models.auto.modeling_auto import (
        MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES,
        MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES,
        MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
        MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES,
    )


logger = logging.get_logger(__name__)


Prediction = Dict[str, Any]
Predictions = List[Prediction]


@add_end_docstrings(build_pipeline_init_args(has_image_processor=True))
class ImageSegmentationPipeline(Pipeline):
    """
    Image segmentation pipeline using any `AutoModelForXXXSegmentation`. This pipeline predicts masks of objects and
    their classes.

    Example:

    ```python
    >>> from transformers import pipeline

    >>> segmenter = pipeline(model="facebook/detr-resnet-50-panoptic")
    >>> segments = segmenter("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png")
    >>> len(segments)
    2

    >>> segments[0]["label"]
    'bird'

    >>> segments[1]["label"]
    'bird'

    >>> type(segments[0]["mask"])  # This is a black and white mask showing where is the bird on the original image.
    <class 'PIL.Image.Image'>

    >>> segments[0]["mask"].size
    (768, 512)
    ```


    This image segmentation pipeline can currently be loaded from [`pipeline`] using the following task identifier:
    `"image-segmentation"`.

    See the list of available models on
    [huggingface.co/models](https://huggingface.co/models?filter=image-segmentation).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        if self.framework == "tf":
            raise ValueError(f"The {self.__class__} is only available in PyTorch.")

        requires_backends(self, "vision")
        mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES.copy()
        mapping.update(MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES)
        mapping.update(MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES)
        mapping.update(MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES)
        self.check_model_type(mapping)

    def _sanitize_parameters(self, **kwargs):
        preprocess_kwargs = {}
        postprocess_kwargs = {}
        if "subtask" in kwargs:
            postprocess_kwargs["subtask"] = kwargs["subtask"]
            preprocess_kwargs["subtask"] = kwargs["subtask"]
        if "threshold" in kwargs:
            postprocess_kwargs["threshold"] = kwargs["threshold"]
        if "mask_threshold" in kwargs:
            postprocess_kwargs["mask_threshold"] = kwargs["mask_threshold"]
        if "overlap_mask_area_threshold" in kwargs:
            postprocess_kwargs["overlap_mask_area_threshold"] = kwargs["overlap_mask_area_threshold"]
        if "timeout" in kwargs:
            preprocess_kwargs["timeout"] = kwargs["timeout"]

        return preprocess_kwargs, {}, postprocess_kwargs

    def __call__(self, inputs=None, **kwargs) -> Union[Predictions, List[Prediction]]:
        """
        Perform segmentation (detect masks & classes) in the image(s) passed as inputs.

        Args:
            inputs (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
                The pipeline handles three types of images:

                - A string containing an HTTP(S) link pointing to an image
                - A string containing a local path to an image
                - An image loaded in PIL directly

                The pipeline accepts either a single image or a batch of images. Images in a batch must all be in the
                same format: all as HTTP(S) links, all as local paths, or all as PIL images.
            subtask (`str`, *optional*):
                Segmentation task to be performed, choose [`semantic`, `instance` and `panoptic`] depending on model
                capabilities. If not set, the pipeline will attempt to resolve in the following order:
                `panoptic`, `instance`, `semantic`.
            threshold (`float`, *optional*, defaults to 0.9):
                Probability threshold to filter out predicted masks.
            mask_threshold (`float`, *optional*, defaults to 0.5):
                Threshold to use when turning the predicted masks into binary values.
            overlap_mask_area_threshold (`float`, *optional*, defaults to 0.5):
                Mask overlap threshold to eliminate small, disconnected segments.
            timeout (`float`, *optional*, defaults to None):
                The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
                the call may block forever.

        Return:
            A dictionary or a list of dictionaries containing the result. If the input is a single image, will return a
            list of dictionaries, if the input is a list of several images, will return a list of list of dictionaries
            corresponding to each image.

            The dictionaries contain the mask, label and score (where applicable) of each detected object and contains
            the following keys:

            - **label** (`str`) -- The class label identified by the model.
            - **mask** (`PIL.Image`) -- A binary mask of the detected object as a PIL Image of shape (width, height) of
              the original image. Returns a mask filled with zeros if no object is found.
            - **score** (*optional* `float`) -- Optionally, when the model is capable of estimating a confidence of the
              "object" described by the label and the mask.
        """
        # After deprecation of this is completed, remove the default `None` value for `images`
        if "images" in kwargs:
            inputs = kwargs.pop("images")
        if inputs is None:
            raise ValueError("Cannot call the image-classification pipeline without an inputs argument!")
        return super().__call__(inputs, **kwargs)

    def preprocess(self, image, subtask=None, timeout=None):
        image = load_image(image, timeout=timeout)
        target_size = [(image.height, image.width)]
        if self.model.config.__class__.__name__ == "OneFormerConfig":
            if subtask is None:
                kwargs = {}
            else:
                kwargs = {"task_inputs": [subtask]}
            inputs = self.image_processor(images=[image], return_tensors="pt", **kwargs)
            if self.framework == "pt":
                inputs = inputs.to(self.torch_dtype)
            inputs["task_inputs"] = self.tokenizer(
                inputs["task_inputs"],
                padding="max_length",
                max_length=self.model.config.task_seq_len,
                return_tensors=self.framework,
            )["input_ids"]
        else:
            inputs = self.image_processor(images=[image], return_tensors="pt")
            if self.framework == "pt":
                inputs = inputs.to(self.torch_dtype)
        inputs["target_size"] = target_size
        return inputs

    def _forward(self, model_inputs):
        target_size = model_inputs.pop("target_size")
        model_outputs = self.model(**model_inputs)
        model_outputs["target_size"] = target_size
        return model_outputs

    def postprocess(
        self, model_outputs, subtask=None, threshold=0.9, mask_threshold=0.5, overlap_mask_area_threshold=0.5
    ):
        fn = None
        if subtask in {"panoptic", None} and hasattr(self.image_processor, "post_process_panoptic_segmentation"):
            fn = self.image_processor.post_process_panoptic_segmentation
        elif subtask in {"instance", None} and hasattr(self.image_processor, "post_process_instance_segmentation"):
            fn = self.image_processor.post_process_instance_segmentation

        if fn is not None:
            outputs = fn(
                model_outputs,
                threshold=threshold,
                mask_threshold=mask_threshold,
                overlap_mask_area_threshold=overlap_mask_area_threshold,
                target_sizes=model_outputs["target_size"],
            )[0]

            annotation = []
            segmentation = outputs["segmentation"]

            for segment in outputs["segments_info"]:
                mask = (segmentation == segment["id"]) * 255
                mask = Image.fromarray(mask.numpy().astype(np.uint8), mode="L")
                label = self.model.config.id2label[segment["label_id"]]
                score = segment["score"]
                annotation.append({"score": score, "label": label, "mask": mask})

        elif subtask in {"semantic", None} and hasattr(self.image_processor, "post_process_semantic_segmentation"):
            outputs = self.image_processor.post_process_semantic_segmentation(
                model_outputs, target_sizes=model_outputs["target_size"]
            )[0]

            annotation = []
            segmentation = outputs.numpy()
            labels = np.unique(segmentation)

            for label in labels:
                mask = (segmentation == label) * 255
                mask = Image.fromarray(mask.astype(np.uint8), mode="L")
                label = self.model.config.id2label[label]
                annotation.append({"score": None, "label": label, "mask": mask})
        else:
            raise ValueError(f"Subtask {subtask} is not supported for model {type(self.model)}")
        return annotation
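The pipeline above routes keyword arguments such as `subtask`, `threshold`, `mask_threshold` and `overlap_mask_area_threshold` through `_sanitize_parameters` into `preprocess` and `postprocess`. A minimal usage sketch, not part of the library file; it reuses the checkpoint from the docstring and assumes `transformers`, `torch` and `Pillow` are installed:

```python
from transformers import pipeline

# Checkpoint taken from the docstring example above.
segmenter = pipeline("image-segmentation", model="facebook/detr-resnet-50-panoptic")

# These keyword arguments are forwarded by `_sanitize_parameters` to preprocess/postprocess.
segments = segmenter(
    "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png",
    subtask="panoptic",
    threshold=0.9,
)
for segment in segments:
    # Each entry has "label", "score" (may be None for semantic) and a PIL "mask".
    print(segment["label"], segment["score"], segment["mask"].size)
```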
.venv/Lib/site-packages/transformers/pipelines/image_text_to_text.py
ADDED
@@ -0,0 +1,432 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum
from typing import Dict, List, Optional, Union

from ..processing_utils import ProcessingKwargs, Unpack
from ..utils import (
    add_end_docstrings,
    is_torch_available,
    is_vision_available,
    logging,
    requires_backends,
)
from .base import Pipeline, build_pipeline_init_args


if is_vision_available():
    from PIL import Image

    from ..image_utils import load_images, valid_images


if is_torch_available():
    from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
    from .pt_utils import KeyDataset

logger = logging.get_logger(__name__)

IMAGE_TOKEN = "<image>"


class ReturnType(enum.Enum):
    TENSORS = 0
    NEW_TEXT = 1
    FULL_TEXT = 2


class Chat:
    """This class is intended to just be used internally in this pipeline and not exposed to users. We convert chats
    to this format because the rest of the pipeline code tends to assume that lists of messages are
    actually a batch of samples rather than messages in the same conversation."""

    def __init__(self, messages: Dict, images: Union[str, List[str], "Image.Image", List["Image.Image"]]):
        for message in messages:
            if not ("role" in message and "content" in message):
                raise ValueError("When passing chat dicts as input, each dict must have a 'role' and 'content' key.")
        images = retrieve_images_in_messages(messages, images)

        self.messages = messages
        self.images = images


def retrieve_images_in_messages(
    messages: dict, images: Optional[Union[str, List[str], "Image.Image", List["Image.Image"]]]
):
    """
    Retrieve and combine images from the chat and the images passed as input.
    """
    if images is None:
        images = []
    idx_images = 0
    retrieved_images = []
    for message in messages:
        for content in message["content"]:
            if isinstance(content, dict):
                if content.get("type") == "image":
                    for key in ["image", "url", "path", "base64"]:
                        if key in content:
                            retrieved_images.append(content[key])
                            break
                    else:
                        if idx_images < len(images):
                            retrieved_images.append(images[idx_images])
                            idx_images += 1
                        else:
                            raise ValueError(
                                "The number of images in the chat messages should be the same as the number of images passed to the pipeline."
                            )
                # Add support for OpenAI/TGI chat format
                elif content.get("type") == "image_url":
                    if isinstance(content.get("image_url"), dict) and "url" in content["image_url"]:
                        retrieved_images.append(content["image_url"]["url"])
                        # Rewrite content to be in the Transformers chat format
                        content["type"] = "image"
                        content["image"] = content["image_url"]["url"]
                        del content["image_url"]
                    else:
                        raise ValueError(
                            "Wrong format for 'image_url' content type. The content should have an 'image_url' dict with a 'url' key."
                        )

    # The number of images passed should be consistent with the number of images in the chat without an image key
    if idx_images != len(images):
        raise ValueError(
            "The number of images in the chat messages should be the same as the number of images passed to the pipeline."
        )

    return retrieved_images


@add_end_docstrings(build_pipeline_init_args(has_processor=True))
class ImageTextToTextPipeline(Pipeline):
    """
    Image-text-to-text pipeline using an `AutoModelForImageTextToText`. This pipeline generates text given an image and text.
    When the underlying model is a conversational model, it can also accept one or more chats,
    in which case the pipeline will operate in chat mode and will continue the chat(s) by adding its response(s).
    Each chat takes the form of a list of dicts, where each dict contains "role" and "content" keys.

    Example:

    ```python
    >>> from transformers import pipeline

    >>> pipe = pipeline(task="image-text-to-text", model="Salesforce/blip-image-captioning-base")
    >>> pipe("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png", text="A photo of")
    [{'generated_text': 'a photo of two birds'}]
    ```

    ```python
    >>> from transformers import pipeline

    >>> pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")
    >>> messages = [
    >>>     {
    >>>         "role": "user",
    >>>         "content": [
    >>>             {
    >>>                 "type": "image",
    >>>                 "url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
    >>>             },
    >>>             {"type": "text", "text": "Describe this image."},
    >>>         ],
    >>>     },
    >>>     {
    >>>         "role": "assistant",
    >>>         "content": [
    >>>             {"type": "text", "text": "There is a dog and"},
    >>>         ],
    >>>     },
    >>> ]
    >>> pipe(text=messages, max_new_tokens=20, return_full_text=False)
    [{'input_text': [{'role': 'user',
        'content': [{'type': 'image',
            'url': 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg'},
        {'type': 'text', 'text': 'Describe this image.'}]},
    {'role': 'assistant',
        'content': [{'type': 'text', 'text': 'There is a dog and'}]}],
    'generated_text': ' a person in the image. The dog is sitting on the sand, and the person is sitting on'}]
    ```

    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)

    This image-text to text pipeline can currently be loaded from pipeline() using the following task identifier:
    "image-text-to-text".

    See the list of available models on
    [huggingface.co/models](https://huggingface.co/models?pipeline_tag=image-text-to-text).
    """

    _load_processor = True
    _load_image_processor = False
    _load_feature_extractor = False
    _load_tokenizer = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        requires_backends(self, "vision")
        self.check_model_type(MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES)

    def _sanitize_parameters(
        self,
        max_new_tokens=None,
        generate_kwargs=None,
        timeout=None,
        return_full_text=None,
        return_tensors=None,
        return_type=None,
        continue_final_message=None,
        **kwargs: Unpack[ProcessingKwargs],
    ):
        forward_kwargs = {}
        preprocess_params = {}
        postprocess_params = {}

        preprocess_params["processing_kwargs"] = kwargs

        if timeout is not None:
            preprocess_params["timeout"] = timeout

        if continue_final_message is not None:
            preprocess_params["continue_final_message"] = continue_final_message

        if generate_kwargs is not None:
            forward_kwargs["generate_kwargs"] = generate_kwargs

        if max_new_tokens is not None:
            if "generate_kwargs" not in forward_kwargs:
                forward_kwargs["generate_kwargs"] = {}
            if "max_new_tokens" in forward_kwargs["generate_kwargs"]:
                raise ValueError(
                    "'max_new_tokens' is defined twice, once in 'generate_kwargs' and once as a direct parameter,"
                    " please use only one"
                )
            forward_kwargs["generate_kwargs"]["max_new_tokens"] = max_new_tokens

        if return_full_text is not None and return_type is None:
            if return_tensors is not None:
                raise ValueError("`return_full_text` is mutually exclusive with `return_tensors`")
            return_type = ReturnType.FULL_TEXT if return_full_text else ReturnType.NEW_TEXT
        if return_tensors is not None and return_type is None:
            return_type = ReturnType.TENSORS
        if return_type is not None:
            postprocess_params["return_type"] = return_type
        if continue_final_message is not None:
            postprocess_params["continue_final_message"] = continue_final_message

        return preprocess_params, forward_kwargs, postprocess_params

    def __call__(
        self,
        images: Optional[
            Union[str, List[str], List[List[str]], "Image.Image", List["Image.Image"], List[List["Image.Image"]]]
        ] = None,
        text: Optional[Union[str, List[str], List[dict]]] = None,
        **kwargs,
    ):
        """
        Generate a text given text and the image(s) passed as inputs.

        Args:
            images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
                The pipeline handles three types of images:

                - A string containing a HTTP(s) link pointing to an image
                - A string containing a local path to an image
                - An image loaded in PIL directly

                The pipeline accepts either a single image or a batch of images.
            text (str, List[str], `List[Dict[str, Union[str, PIL.Image]]]`):
                The text to be used for generation. If a list of strings is passed, the length of the list should be the
                same as the number of images. Text can also follow the chat format: a list of dictionaries where each
                dictionary represents a message in a conversation. Each dictionary should have two keys: 'role' and
                'content'. 'role' should be one of 'user', 'system' or 'assistant'. 'content' should be a list of dictionaries
                containing the text of the message and the type of the message. The type of the message can be either
                'text' or 'image'. If the type is 'image', no text is needed.
            return_tensors (`bool`, *optional*, defaults to `False`):
                Returns the tensors of predictions (as token indices) in the outputs. If set to
                `True`, the decoded text is not returned.
            return_text (`bool`, *optional*):
                Returns the decoded texts in the outputs.
            return_full_text (`bool`, *optional*, defaults to `True`):
                If set to `False` only added text is returned, otherwise the full text is returned. Cannot be
                specified at the same time as `return_text`.
            continue_final_message( `bool`, *optional*): This indicates that you want the model to continue the
                last message in the input chat rather than starting a new one, allowing you to "prefill" its response.
                By default this is `True` when the final message in the input chat has the `assistant` role and
                `False` otherwise, but you can manually override that behaviour by setting this flag.

        Return:
            A list or a list of list of `dict`: Each result comes as a dictionary with the following key (cannot return a combination
            of both `generated_text` and `generated_token_ids`):

            - **generated_text** (`str`, present when `return_text=True`) -- The generated text.
            - **generated_token_ids** (`torch.Tensor`, present when `return_tensors=True`) -- The token
              ids of the generated text.
            - **input_text** (`str`) -- The input text.
        """
        if images is None and text is None:
            raise ValueError("You must at least provide either text or images.")
        if images is not None and text is None and not valid_images(images):
            """
            Supports the following format
            - {"image": image, "text": text}
            - [{"image": image, "text": text}]
            - Generator and datasets
            This is a common pattern in other multimodal pipelines, so we support it here as well.
            """
            return super().__call__(images, **kwargs)

        if isinstance(text, (list, tuple, KeyDataset)) and isinstance(text[0], (list, tuple, dict)):
            # We have one or more prompts in list-of-dicts format, so this is chat mode
            if isinstance(text[0], dict):
                return super().__call__(Chat(text, images), **kwargs)
            else:
                if images is None:
                    images = [None] * len(text)
                chats = [Chat(chat, image) for chat, image in zip(text, images)]  # 🐈 🐈 🐈
                return super().__call__(chats, **kwargs)

        # encourage the user to use the chat format if supported
        if getattr(self.processor, "chat_template", None) is not None:
            logger.warning_once(
                "The input data was not formatted as a chat with dicts containing 'role' and 'content' keys, even though this model supports chat. "
                "Consider using the chat format for better results. For more information, see https://huggingface.co/docs/transformers/en/chat_templating"
            )

        # support text only generation
        if images is None:
            return super().__call__(text, **kwargs)
        if text is None:
            raise ValueError("You must provide text for this pipeline.")

        return super().__call__({"images": images, "text": text}, **kwargs)

    def preprocess(self, inputs=None, timeout=None, continue_final_message=None, processing_kwargs=None):
        # In case we only have text inputs
        if isinstance(inputs, (list, tuple, str)):
            images = None
            text = inputs
            inputs_text = inputs
        else:
            if isinstance(inputs, Chat):
                # If the user passes a chat that ends in an assistant message, we treat it as a prefill by default
                # because very few models support multiple separate, consecutive assistant messages
                if continue_final_message is None:
                    continue_final_message = inputs.messages[-1]["role"] == "assistant"
                text = self.processor.apply_chat_template(
                    inputs.messages,
                    add_generation_prompt=not continue_final_message,
                    continue_final_message=continue_final_message,
                    return_tensors=self.framework,
                )
                inputs_text = inputs
                images = inputs.images
            else:
                text = inputs["text"]
                inputs_text = inputs["text"]
                images = inputs["images"]

            images = load_images(images)

        # if batched text inputs, we set padding to True unless specified otherwise
        if isinstance(text, (list, tuple)) and len(text) > 1:
            processing_kwargs.setdefault("padding", True)
        model_inputs = self.processor(
            images=images, text=text, return_tensors=self.framework, legacy=False, **processing_kwargs
        ).to(dtype=self.torch_dtype)

        model_inputs["text"] = inputs_text

        return model_inputs

    def _forward(self, model_inputs, generate_kwargs=None):
        generate_kwargs = {} if generate_kwargs is None else generate_kwargs
        prompt_text = model_inputs.pop("text")
        input_ids = (
            model_inputs["input_ids"] if "input_ids" in model_inputs else model_inputs["decoder_input_ids"]
        )  # for decoder-only models
        generated_sequence = self.model.generate(**model_inputs, **generate_kwargs)

        return {"generated_sequence": generated_sequence, "prompt_text": prompt_text, "input_ids": input_ids}

    def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, continue_final_message=None):
        input_texts = model_outputs["prompt_text"]
        input_texts = [input_texts] if isinstance(input_texts, (str, Chat)) else input_texts
        generated_sequence = model_outputs["generated_sequence"]
        input_ids = model_outputs["input_ids"]
        if return_type == ReturnType.TENSORS:
            return [
                {"input_text": input_texts[i], "generated_token_ids": generated_sequence[i]}
                for i in range(len(input_texts))
            ]

        # Decode inputs and outputs the same way to remove input text from generated text if present
        generated_texts = self.processor.post_process_image_text_to_text(generated_sequence)
        decoded_inputs = self.processor.post_process_image_text_to_text(input_ids)

        # Force consistent behavior for including the input text in the output
        if return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}:
            # Remove the input text from the generated text if the generated text starts with the input text
            # (accounting for the possibility of a space between the input and generated text)
            new_generated_texts = []
            for text_generated, decoded_input in zip(generated_texts, decoded_inputs):
                # There can be added characters before the input text, so we need to find the beginning of the input text in the generated text
                index_input_text = text_generated.find(decoded_input)
                # Limit the search to 2 residual characters, like spaces or new lines, to avoid removing a large part of the answer
                if 0 <= index_input_text <= 2:
                    # If the input text is found, we remove it
                    new_generated_texts.append(text_generated[index_input_text + len(decoded_input) :])
                else:
                    new_generated_texts.append(text_generated)
            generated_texts = new_generated_texts
        if return_type == ReturnType.FULL_TEXT:
            full_texts = []
            for prompt_text, generated_text in zip(input_texts, generated_texts):
                if isinstance(prompt_text, str):
                    generated_text = prompt_text + generated_text
                elif isinstance(prompt_text, Chat):
                    if continue_final_message is None:
                        # If the user passes a chat ending in an assistant message, we treat it as a prefill by
                        # default because very few models support multiple separate, consecutive assistant messages
                        continue_final_message = prompt_text.messages[-1]["role"] == "assistant"
                    if continue_final_message:
                        # With assistant prefill, concat onto the end of the last message
                        new_text = dict(prompt_text.messages[-1]["content"][-1].items())
                        new_text["text"] += generated_text
                        generated_text = list(prompt_text.messages)[:-1] + [
                            {
                                "role": prompt_text.messages[-1]["role"],
                                "content": prompt_text.messages[-1]["content"][:-1] + [new_text],
                            }
                        ]
                    else:
                        # When we're not starting from a prefill, the output is a new assistant message
                        generated_text = list(prompt_text.messages) + [
                            {"role": "assistant", "content": generated_text}
                        ]
                full_texts.append(generated_text)
            generated_texts = full_texts

        records = [
            {
                "input_text": input_text.messages if isinstance(input_text, Chat) else input_text,
                "generated_text": generated_text,
            }
            for input_text, generated_text in zip(input_texts, generated_texts)
        ]

        return records
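A short sketch of the chat path implemented above, not part of the library file: passing a list of message dicts triggers chat mode in `__call__`, and image URLs embedded in the messages are collected by `retrieve_images_in_messages`, so no separate `images` argument is needed. It reuses the "llava-hf/llava-interleave-qwen-0.5b-hf" checkpoint from the docstring and assumes it can be downloaded:

```python
from transformers import pipeline

pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png"},
            {"type": "text", "text": "What is in this image?"},
        ],
    }
]
# `max_new_tokens` and `return_full_text` are routed through `_sanitize_parameters` as shown above.
out = pipe(text=messages, max_new_tokens=20, return_full_text=False)
print(out[0]["generated_text"])
```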
.venv/Lib/site-packages/transformers/pipelines/image_to_image.py
ADDED
@@ -0,0 +1,136 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union

import numpy as np

from ..utils import (
    add_end_docstrings,
    is_torch_available,
    is_vision_available,
    logging,
    requires_backends,
)
from .base import Pipeline, build_pipeline_init_args


if is_vision_available():
    from PIL import Image

    from ..image_utils import load_image

if is_torch_available():
    from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES

logger = logging.get_logger(__name__)


@add_end_docstrings(build_pipeline_init_args(has_image_processor=True))
class ImageToImagePipeline(Pipeline):
    """
    Image to Image pipeline using any `AutoModelForImageToImage`. This pipeline generates an image based on a previous
    image input.

    Example:

    ```python
    >>> from PIL import Image
    >>> import requests

    >>> from transformers import pipeline

    >>> upscaler = pipeline("image-to-image", model="caidas/swin2SR-classical-sr-x2-64")
    >>> img = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
    >>> img = img.resize((64, 64))
    >>> upscaled_img = upscaler(img)
    >>> img.size
    (64, 64)

    >>> upscaled_img.size
    (144, 144)
    ```

    This image to image pipeline can currently be loaded from [`pipeline`] using the following task identifier:
    `"image-to-image"`.

    See the list of available models on [huggingface.co/models](https://huggingface.co/models?filter=image-to-image).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        requires_backends(self, "vision")
        self.check_model_type(MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES)

    def _sanitize_parameters(self, **kwargs):
        preprocess_params = {}
        postprocess_params = {}
        forward_params = {}

        if "timeout" in kwargs:
            preprocess_params["timeout"] = kwargs["timeout"]
        if "head_mask" in kwargs:
            forward_params["head_mask"] = kwargs["head_mask"]

        return preprocess_params, forward_params, postprocess_params

    def __call__(
        self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], **kwargs
    ) -> Union["Image.Image", List["Image.Image"]]:
        """
        Transform the image(s) passed as inputs.

        Args:
            images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
                The pipeline handles three types of images:

                - A string containing a http link pointing to an image
                - A string containing a local path to an image
                - An image loaded in PIL directly

                The pipeline accepts either a single image or a batch of images, which must then be passed as a string.
                Images in a batch must all be in the same format: all as http links, all as local paths, or all as PIL
                images.
            timeout (`float`, *optional*, defaults to None):
                The maximum time in seconds to wait for fetching images from the web. If None, no timeout is used and
                the call may block forever.

        Return:
            An image (Image.Image) or a list of images (List["Image.Image"]) containing result(s). If the input is a
            single image, the return will be also a single image, if the input is a list of several images, it will
            return a list of transformed images.
        """
        return super().__call__(images, **kwargs)

    def _forward(self, model_inputs):
        model_outputs = self.model(**model_inputs)
        return model_outputs

    def preprocess(self, image, timeout=None):
        image = load_image(image, timeout=timeout)
        inputs = self.image_processor(images=[image], return_tensors="pt")
        if self.framework == "pt":
            inputs = inputs.to(self.torch_dtype)
        return inputs

    def postprocess(self, model_outputs):
        images = []
        if "reconstruction" in model_outputs.keys():
            outputs = model_outputs.reconstruction
        for output in outputs:
            output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
            output = np.moveaxis(output, source=0, destination=-1)
            output = (output * 255.0).round().astype(np.uint8)  # float32 to uint8
            images.append(Image.fromarray(output))

        return images if len(images) > 1 else images[0]
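A minimal sketch of the upscaling flow implemented above, not part of the library file: `postprocess` converts the model's `reconstruction` tensor back into a uint8 PIL image, so a single input image yields a single PIL image. It reuses the "caidas/swin2SR-classical-sr-x2-64" checkpoint from the docstring and assumes `requests` and `Pillow` are installed:

```python
import requests
from PIL import Image

from transformers import pipeline

upscaler = pipeline("image-to-image", model="caidas/swin2SR-classical-sr-x2-64")
img = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)

# Single image in -> single PIL image out (see `postprocess` above).
upscaled = upscaler(img)
upscaled.save("upscaled.png")
```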
.venv/Lib/site-packages/transformers/pipelines/image_to_text.py
ADDED
@@ -0,0 +1,216 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Union

from ..utils import (
    add_end_docstrings,
    is_tf_available,
    is_torch_available,
    is_vision_available,
    logging,
    requires_backends,
)
from .base import Pipeline, build_pipeline_init_args


if is_vision_available():
    from PIL import Image

    from ..image_utils import load_image

if is_tf_available():
    from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES

if is_torch_available():
    import torch

    from ..models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES

logger = logging.get_logger(__name__)


@add_end_docstrings(build_pipeline_init_args(has_tokenizer=True, has_image_processor=True))
class ImageToTextPipeline(Pipeline):
    """
    Image To Text pipeline using a `AutoModelForVision2Seq`. This pipeline predicts a caption for a given image.

    Example:

    ```python
    >>> from transformers import pipeline

    >>> captioner = pipeline(model="ydshieh/vit-gpt2-coco-en")
    >>> captioner("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png")
    [{'generated_text': 'two birds are standing next to each other '}]
    ```

    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)

    This image to text pipeline can currently be loaded from pipeline() using the following task identifier:
    "image-to-text".

    See the list of available models on
    [huggingface.co/models](https://huggingface.co/models?pipeline_tag=image-to-text).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        requires_backends(self, "vision")
        self.check_model_type(
            TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES if self.framework == "tf" else MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
        )

    def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None, prompt=None, timeout=None):
        forward_params = {}
        preprocess_params = {}

        if prompt is not None:
            preprocess_params["prompt"] = prompt
        if timeout is not None:
            preprocess_params["timeout"] = timeout

        if max_new_tokens is not None:
            forward_params["max_new_tokens"] = max_new_tokens
        if generate_kwargs is not None:
            if max_new_tokens is not None and "max_new_tokens" in generate_kwargs:
                raise ValueError(
                    "`max_new_tokens` is defined both as an argument and inside `generate_kwargs` argument, please use"
                    " only 1 version"
                )
            forward_params.update(generate_kwargs)

        return preprocess_params, forward_params, {}

    def __call__(self, inputs: Union[str, List[str], "Image.Image", List["Image.Image"]] = None, **kwargs):
        """
        Assign labels to the image(s) passed as inputs.

        Args:
            inputs (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
                The pipeline handles three types of images:

                - A string containing a HTTP(s) link pointing to an image
                - A string containing a local path to an image
                - An image loaded in PIL directly

                The pipeline accepts either a single image or a batch of images.

            max_new_tokens (`int`, *optional*):
                The amount of maximum tokens to generate. By default it will use `generate` default.

            generate_kwargs (`Dict`, *optional*):
                Pass it to send all of these arguments directly to `generate` allowing full control of this function.

            timeout (`float`, *optional*, defaults to None):
                The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
                the call may block forever.

        Return:
            A list or a list of list of `dict`: Each result comes as a dictionary with the following key:

            - **generated_text** (`str`) -- The generated text.
        """
        # After deprecation of this is completed, remove the default `None` value for `images`
        if "images" in kwargs:
            inputs = kwargs.pop("images")
        if inputs is None:
            raise ValueError("Cannot call the image-to-text pipeline without an inputs argument!")
        return super().__call__(inputs, **kwargs)

    def preprocess(self, image, prompt=None, timeout=None):
        image = load_image(image, timeout=timeout)

        if prompt is not None:
            logger.warning_once(
                "Passing `prompt` to the `image-to-text` pipeline is deprecated and will be removed in version 4.48"
                " of 🤗 Transformers. Use the `image-text-to-text` pipeline instead",
            )
            if not isinstance(prompt, str):
                raise ValueError(
                    f"Received an invalid text input, got - {type(prompt)} - but expected a single string. "
                    "Note also that one single text can be provided for conditional image to text generation."
                )

            model_type = self.model.config.model_type

            if model_type == "git":
                model_inputs = self.image_processor(images=image, return_tensors=self.framework)
                if self.framework == "pt":
                    model_inputs = model_inputs.to(self.torch_dtype)
                input_ids = self.tokenizer(text=prompt, add_special_tokens=False).input_ids
                input_ids = [self.tokenizer.cls_token_id] + input_ids
                input_ids = torch.tensor(input_ids).unsqueeze(0)
                model_inputs.update({"input_ids": input_ids})

            elif model_type == "pix2struct":
                model_inputs = self.image_processor(images=image, header_text=prompt, return_tensors=self.framework)
                if self.framework == "pt":
                    model_inputs = model_inputs.to(self.torch_dtype)

            elif model_type != "vision-encoder-decoder":
                # vision-encoder-decoder does not support conditional generation
                model_inputs = self.image_processor(images=image, return_tensors=self.framework)
                if self.framework == "pt":
                    model_inputs = model_inputs.to(self.torch_dtype)
                text_inputs = self.tokenizer(prompt, return_tensors=self.framework)
                model_inputs.update(text_inputs)

            else:
                raise ValueError(f"Model type {model_type} does not support conditional text generation")

        else:
            model_inputs = self.image_processor(images=image, return_tensors=self.framework)
            if self.framework == "pt":
                model_inputs = model_inputs.to(self.torch_dtype)

        if self.model.config.model_type == "git" and prompt is None:
            model_inputs["input_ids"] = None

        return model_inputs

    def _forward(self, model_inputs, **generate_kwargs):
        # Git model sets `model_inputs["input_ids"] = None` in `preprocess` (when `prompt=None`). In batch model, the
        # pipeline will group them into a list of `None`, which fail `_forward`. Avoid this by checking it first.
        if (
            "input_ids" in model_inputs
            and isinstance(model_inputs["input_ids"], list)
            and all(x is None for x in model_inputs["input_ids"])
        ):
            model_inputs["input_ids"] = None

        # User-defined `generation_config` passed to the pipeline call take precedence
        if "generation_config" not in generate_kwargs:
            generate_kwargs["generation_config"] = self.generation_config

        # FIXME: We need to pop here due to a difference in how `generation.py` and `generation.tf_utils.py`
        # parse inputs. In the Tensorflow version, `generate` raises an error if we don't use `input_ids` whereas
        # the PyTorch version matches it with `self.model.main_input_name` or `self.model.encoder.main_input_name`
        # in the `_prepare_model_inputs` method.
        inputs = model_inputs.pop(self.model.main_input_name)
        model_outputs = self.model.generate(inputs, **model_inputs, **generate_kwargs)
        return model_outputs

    def postprocess(self, model_outputs):
        records = []
        for output_ids in model_outputs:
            record = {
                "generated_text": self.tokenizer.decode(
                    output_ids,
                    skip_special_tokens=True,
                )
            }
            records.append(record)
        return records
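A short sketch of batched captioning with the pipeline above, not part of the library file: each input image yields a list of dicts with a "generated_text" key, as built in `postprocess`. It reuses the "ydshieh/vit-gpt2-coco-en" checkpoint from the docstring:

```python
from transformers import pipeline

captioner = pipeline("image-to-text", model="ydshieh/vit-gpt2-coco-en")

urls = [
    "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png",
    "http://images.cocodataset.org/val2017/000000039769.jpg",
]
# `max_new_tokens` is forwarded to `generate` via `_sanitize_parameters`.
for result in captioner(urls, max_new_tokens=20):
    print(result[0]["generated_text"])
```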
.venv/Lib/site-packages/transformers/pipelines/mask_generation.py
ADDED
@@ -0,0 +1,287 @@
from collections import defaultdict
from typing import Optional

from ..image_utils import load_image
from ..utils import (
    add_end_docstrings,
    is_torch_available,
    logging,
    requires_backends,
)
from .base import ChunkPipeline, build_pipeline_init_args


if is_torch_available():
    import torch

    from ..models.auto.modeling_auto import MODEL_FOR_MASK_GENERATION_MAPPING_NAMES

logger = logging.get_logger(__name__)


@add_end_docstrings(
    build_pipeline_init_args(has_image_processor=True),
    r"""
        points_per_batch (*optional*, int, default to 64):
            Sets the number of points run simultaneously by the model. Higher numbers may be faster but use more GPU
            memory.
        output_bboxes_mask (`bool`, *optional*, default to `False`):
            Whether or not to output the bounding box predictions.
        output_rle_masks (`bool`, *optional*, default to `False`):
            Whether or not to output the masks in `RLE` format""",
)
class MaskGenerationPipeline(ChunkPipeline):
    """
    Automatic mask generation for images using `SamForMaskGeneration`. This pipeline predicts binary masks for an
    image, given an image. It is a `ChunkPipeline` because you can separate the points in a mini-batch in order to
    avoid OOM issues. Use the `points_per_batch` argument to control the number of points that will be processed at the
    same time. Default is `64`.

    The pipeline works in 3 steps:
        1. `preprocess`: A grid of 1024 points evenly separated is generated along with bounding boxes and point
           labels.
            For more details on how the points and bounding boxes are created, check the `_generate_crop_boxes`
            function. The image is also preprocessed using the `image_processor`. This function `yields` a minibatch of
            `points_per_batch`.

        2. `forward`: feeds the outputs of `preprocess` to the model. The image embedding is computed only once.
            Calls both `self.model.get_image_embeddings` and makes sure that the gradients are not computed, and the
            tensors and models are on the same device.

        3. `postprocess`: The most important part of the automatic mask generation happens here. Three steps
            are induced:
                - image_processor.postprocess_masks (run on each minibatch loop): takes in the raw output masks,
                  resizes them according to the image size, and transforms them to binary masks.
                - image_processor.filter_masks (on each minibatch loop): uses both `pred_iou_thresh` and
                  `stability_scores`. Also applies a variety of filters based on non maximum suppression to remove
                  bad masks.
                - image_processor.postprocess_masks_for_amg applies the NSM on the mask to only keep relevant ones.

    Example:

    ```python
    >>> from transformers import pipeline

    >>> generator = pipeline(model="facebook/sam-vit-base", task="mask-generation")
    >>> outputs = generator(
    ...     "http://images.cocodataset.org/val2017/000000039769.jpg",
    ... )

    >>> outputs = generator(
    ...     "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png", points_per_batch=128
    ... )
    ```

    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)

    This segmentation pipeline can currently be loaded from [`pipeline`] using the following task identifier:
    `"mask-generation"`.

    See the list of available models on [huggingface.co/models](https://huggingface.co/models?filter=mask-generation).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        requires_backends(self, "vision")
        requires_backends(self, "torch")

        if self.framework != "pt":
            raise ValueError(f"The {self.__class__} is only available in PyTorch.")

        self.check_model_type(MODEL_FOR_MASK_GENERATION_MAPPING_NAMES)

    def _sanitize_parameters(self, **kwargs):
        preprocess_kwargs = {}
        postprocess_kwargs = {}
        forward_params = {}
        # preprocess args
        if "points_per_batch" in kwargs:
            preprocess_kwargs["points_per_batch"] = kwargs["points_per_batch"]
        if "points_per_crop" in kwargs:
            preprocess_kwargs["points_per_crop"] = kwargs["points_per_crop"]
        if "crops_n_layers" in kwargs:
            preprocess_kwargs["crops_n_layers"] = kwargs["crops_n_layers"]
        if "crop_overlap_ratio" in kwargs:
            preprocess_kwargs["crop_overlap_ratio"] = kwargs["crop_overlap_ratio"]
        if "crop_n_points_downscale_factor" in kwargs:
            preprocess_kwargs["crop_n_points_downscale_factor"] = kwargs["crop_n_points_downscale_factor"]
        if "timeout" in kwargs:
            preprocess_kwargs["timeout"] = kwargs["timeout"]
        # postprocess args
        if "pred_iou_thresh" in kwargs:
            forward_params["pred_iou_thresh"] = kwargs["pred_iou_thresh"]
        if "stability_score_offset" in kwargs:
            forward_params["stability_score_offset"] = kwargs["stability_score_offset"]
        if "mask_threshold" in kwargs:
            forward_params["mask_threshold"] = kwargs["mask_threshold"]
        if "stability_score_thresh" in kwargs:
            forward_params["stability_score_thresh"] = kwargs["stability_score_thresh"]
        if "crops_nms_thresh" in kwargs:
            postprocess_kwargs["crops_nms_thresh"] = kwargs["crops_nms_thresh"]
        if "output_rle_mask" in kwargs:
            postprocess_kwargs["output_rle_mask"] = kwargs["output_rle_mask"]
        if "output_bboxes_mask" in kwargs:
            postprocess_kwargs["output_bboxes_mask"] = kwargs["output_bboxes_mask"]
        return preprocess_kwargs, forward_params, postprocess_kwargs

    def __call__(self, image, *args, num_workers=None, batch_size=None, **kwargs):
        """
        Generates binary segmentation masks

        Args:
            inputs (`np.ndarray` or `bytes` or `str` or `dict`):
                Image or list of images.
            mask_threshold (`float`, *optional*, defaults to 0.0):
                Threshold to use when turning the predicted masks into binary values.
            pred_iou_thresh (`float`, *optional*, defaults to 0.88):
                A filtering threshold in `[0,1]` applied on the model's predicted mask quality.
            stability_score_thresh (`float`, *optional*, defaults to 0.95):
                A filtering threshold in `[0,1]`, using the stability of the mask under changes to the cutoff used to
                binarize the model's mask predictions.
            stability_score_offset (`int`, *optional*, defaults to 1):
                The amount to shift the cutoff when calculated the stability score.
            crops_nms_thresh (`float`, *optional*, defaults to 0.7):
                The box IoU cutoff used by non-maximal suppression to filter duplicate masks.
            crops_n_layers (`int`, *optional*, defaults to 0):
                If `crops_n_layers>0`, mask prediction will be run again on crops of the image. Sets the number of
                layers to run, where each layer has 2**i_layer number of image crops.
            crop_overlap_ratio (`float`, *optional*, defaults to `512 / 1500`):
                Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of
                the image length. Later layers with more crops scale down this overlap.
            crop_n_points_downscale_factor (`int`, *optional*, defaults to `1`):
                The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
            timeout (`float`, *optional*, defaults to None):
                The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
                the call may block forever.

        Return:
            `Dict`: A dictionary with the following keys:
                - **mask** (`PIL.Image`) -- A binary mask of the detected object as a PIL Image of shape `(width,
                  height)` of the original image. Returns a mask filled with zeros if no object is found.
                - **score** (*optional* `float`) -- Optionally, when the model is capable of estimating a confidence of
                  the "object" described by the label and the mask.

        """
        return super().__call__(image, *args, num_workers=num_workers, batch_size=batch_size, **kwargs)

    def preprocess(
        self,
        image,
        points_per_batch=64,
        crops_n_layers: int = 0,
        crop_overlap_ratio: float = 512 / 1500,
        points_per_crop: Optional[int] = 32,
        crop_n_points_downscale_factor: Optional[int] = 1,
        timeout: Optional[float] = None,
    ):
        image = load_image(image, timeout=timeout)
        target_size = self.image_processor.size["longest_edge"]
        crop_boxes, grid_points, cropped_images, input_labels = self.image_processor.generate_crop_boxes(
            image, target_size, crops_n_layers, crop_overlap_ratio, points_per_crop, crop_n_points_downscale_factor
        )
        model_inputs = self.image_processor(images=cropped_images, return_tensors="pt")
        if self.framework == "pt":
            model_inputs = model_inputs.to(self.torch_dtype)

        with self.device_placement():
            if self.framework == "pt":
                inference_context = self.get_inference_context()
                with inference_context():
                    model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device)
                    image_embeddings = self.model.get_image_embeddings(model_inputs.pop("pixel_values"))
                    model_inputs["image_embeddings"] = image_embeddings

        n_points = grid_points.shape[1]
        points_per_batch = points_per_batch if points_per_batch is not None else n_points

        if points_per_batch <= 0:
            raise ValueError(
                "Cannot have points_per_batch<=0. Must be >=1 to returned batched outputs. "
                "To return all points at once, set points_per_batch to None"
            )

        for i in range(0, n_points, points_per_batch):
            batched_points = grid_points[:, i : i + points_per_batch, :, :]
            labels = input_labels[:, i : i + points_per_batch]
            is_last = i == n_points - points_per_batch
            yield {
                "input_points": batched_points,
                "input_labels": labels,
                "input_boxes": crop_boxes,
                "is_last": is_last,
                **model_inputs,
            }

    def _forward(
        self,
        model_inputs,
        pred_iou_thresh=0.88,
        stability_score_thresh=0.95,
        mask_threshold=0,
        stability_score_offset=1,
    ):
        input_boxes = model_inputs.pop("input_boxes")
        is_last = model_inputs.pop("is_last")
        original_sizes = model_inputs.pop("original_sizes").tolist()
        reshaped_input_sizes = model_inputs.pop("reshaped_input_sizes").tolist()

        model_outputs = self.model(**model_inputs)

        # post processing happens here in order to avoid CPU GPU copies of ALL the masks
        low_resolution_masks = model_outputs["pred_masks"]
        masks = self.image_processor.post_process_masks(
            low_resolution_masks, original_sizes, reshaped_input_sizes, mask_threshold, binarize=False
        )
        iou_scores = model_outputs["iou_scores"]
        masks, iou_scores, boxes = self.image_processor.filter_masks(
            masks[0],
            iou_scores[0],
            original_sizes[0],
            input_boxes[0],
            pred_iou_thresh,
            stability_score_thresh,
            mask_threshold,
            stability_score_offset,
        )
        return {
            "masks": masks,
            "is_last": is_last,
            "boxes": boxes,
            "iou_scores": iou_scores,
        }

    def postprocess(
        self,
        model_outputs,
        output_rle_mask=False,
        output_bboxes_mask=False,
        crops_nms_thresh=0.7,
    ):
        all_scores = []
        all_masks = []
        all_boxes = []
        for model_output in model_outputs:
            all_scores.append(model_output.pop("iou_scores"))
            all_masks.extend(model_output.pop("masks"))
            all_boxes.append(model_output.pop("boxes"))

        all_scores = torch.cat(all_scores)
        all_boxes = torch.cat(all_boxes)
        output_masks, iou_scores, rle_mask, bounding_boxes = self.image_processor.post_process_for_mask_generation(
            all_masks, all_scores, all_boxes, crops_nms_thresh
        )

        extra = defaultdict(list)
        for output in model_outputs:
            for k, v in output.items():
                extra[k].append(v)

        optional = {}
        if output_rle_mask:
            optional["rle_mask"] = rle_mask

        if output_bboxes_mask:
|
285 |
+
optional["bounding_boxes"] = bounding_boxes
|
286 |
+
|
287 |
+
return {"masks": output_masks, "scores": iou_scores, **optional, **extra}
|
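The pipeline above routes the documented keyword arguments to `preprocess`, `_forward`, and `postprocess`. A minimal usage sketch, assuming a SAM checkpoint such as `facebook/sam-vit-base` and a reachable example image URL (both are illustrative placeholders, not part of this diff):

```python
from transformers import pipeline

# Illustrative checkpoint and image URL; any SAM-style mask-generation model works the same way.
generator = pipeline("mask-generation", model="facebook/sam-vit-base")
outputs = generator(
    "http://images.cocodataset.org/val2017/000000039769.jpg",
    points_per_batch=64,      # consumed by preprocess()
    pred_iou_thresh=0.88,     # consumed by _forward()
    crops_nms_thresh=0.7,     # consumed by postprocess()
)
# postprocess() returns a dict with "masks" and "scores" (plus optional RLE masks / bounding boxes)
print(len(outputs["masks"]), outputs["scores"][:3])
```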
.venv/Lib/site-packages/transformers/pipelines/object_detection.py
ADDED
@@ -0,0 +1,191 @@
from typing import Any, Dict, List, Union

from ..utils import add_end_docstrings, is_torch_available, is_vision_available, logging, requires_backends
from .base import Pipeline, build_pipeline_init_args


if is_vision_available():
    from ..image_utils import load_image


if is_torch_available():
    import torch

    from ..models.auto.modeling_auto import (
        MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES,
        MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
    )

logger = logging.get_logger(__name__)


Prediction = Dict[str, Any]
Predictions = List[Prediction]


@add_end_docstrings(build_pipeline_init_args(has_image_processor=True))
class ObjectDetectionPipeline(Pipeline):
    """
    Object detection pipeline using any `AutoModelForObjectDetection`. This pipeline predicts bounding boxes of objects
    and their classes.

    Example:

    ```python
    >>> from transformers import pipeline

    >>> detector = pipeline(model="facebook/detr-resnet-50")
    >>> detector("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png")
    [{'score': 0.997, 'label': 'bird', 'box': {'xmin': 69, 'ymin': 171, 'xmax': 396, 'ymax': 507}}, {'score': 0.999, 'label': 'bird', 'box': {'xmin': 398, 'ymin': 105, 'xmax': 767, 'ymax': 507}}]

    >>> # x, y are expressed relative to the top left hand corner.
    ```

    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)

    This object detection pipeline can currently be loaded from [`pipeline`] using the following task identifier:
    `"object-detection"`.

    See the list of available models on [huggingface.co/models](https://huggingface.co/models?filter=object-detection).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        if self.framework == "tf":
            raise ValueError(f"The {self.__class__} is only available in PyTorch.")

        requires_backends(self, "vision")
        mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES.copy()
        mapping.update(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES)
        self.check_model_type(mapping)

    def _sanitize_parameters(self, **kwargs):
        preprocess_params = {}
        if "timeout" in kwargs:
            preprocess_params["timeout"] = kwargs["timeout"]
        postprocess_kwargs = {}
        if "threshold" in kwargs:
            postprocess_kwargs["threshold"] = kwargs["threshold"]
        return preprocess_params, {}, postprocess_kwargs

    def __call__(self, *args, **kwargs) -> Union[Predictions, List[Prediction]]:
        """
        Detect objects (bounding boxes & classes) in the image(s) passed as inputs.

        Args:
            inputs (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
                The pipeline handles three types of images:

                - A string containing an HTTP(S) link pointing to an image
                - A string containing a local path to an image
                - An image loaded in PIL directly

                The pipeline accepts either a single image or a batch of images. Images in a batch must all be in the
                same format: all as HTTP(S) links, all as local paths, or all as PIL images.
            threshold (`float`, *optional*, defaults to 0.5):
                The probability necessary to make a prediction.
            timeout (`float`, *optional*, defaults to None):
                The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
                the call may block forever.

        Return:
            A list of dictionaries or a list of list of dictionaries containing the result. If the input is a single
            image, will return a list of dictionaries, if the input is a list of several images, will return a list of
            list of dictionaries corresponding to each image.

            The dictionaries contain the following keys:

            - **label** (`str`) -- The class label identified by the model.
            - **score** (`float`) -- The score attributed by the model for that label.
            - **box** (`List[Dict[str, int]]`) -- The bounding box of detected object in image's original size.
        """
        # After deprecation of this is completed, remove the default `None` value for `images`
        if "images" in kwargs and "inputs" not in kwargs:
            kwargs["inputs"] = kwargs.pop("images")
        return super().__call__(*args, **kwargs)

    def preprocess(self, image, timeout=None):
        image = load_image(image, timeout=timeout)
        target_size = torch.IntTensor([[image.height, image.width]])
        inputs = self.image_processor(images=[image], return_tensors="pt")
        if self.framework == "pt":
            inputs = inputs.to(self.torch_dtype)
        if self.tokenizer is not None:
            inputs = self.tokenizer(text=inputs["words"], boxes=inputs["boxes"], return_tensors="pt")
        inputs["target_size"] = target_size
        return inputs

    def _forward(self, model_inputs):
        target_size = model_inputs.pop("target_size")
        outputs = self.model(**model_inputs)
        model_outputs = outputs.__class__({"target_size": target_size, **outputs})
        if self.tokenizer is not None:
            model_outputs["bbox"] = model_inputs["bbox"]
        return model_outputs

    def postprocess(self, model_outputs, threshold=0.5):
        target_size = model_outputs["target_size"]
        if self.tokenizer is not None:
            # This is a LayoutLMForTokenClassification variant.
            # The OCR got the boxes and the model classified the words.
            height, width = target_size[0].tolist()

            def unnormalize(bbox):
                return self._get_bounding_box(
                    torch.Tensor(
                        [
                            (width * bbox[0] / 1000),
                            (height * bbox[1] / 1000),
                            (width * bbox[2] / 1000),
                            (height * bbox[3] / 1000),
                        ]
                    )
                )

            scores, classes = model_outputs["logits"].squeeze(0).softmax(dim=-1).max(dim=-1)
            labels = [self.model.config.id2label[prediction] for prediction in classes.tolist()]
            boxes = [unnormalize(bbox) for bbox in model_outputs["bbox"].squeeze(0)]
            keys = ["score", "label", "box"]
            annotation = [dict(zip(keys, vals)) for vals in zip(scores.tolist(), labels, boxes) if vals[0] > threshold]
        else:
            # This is a regular ForObjectDetectionModel
            raw_annotations = self.image_processor.post_process_object_detection(model_outputs, threshold, target_size)
            raw_annotation = raw_annotations[0]
            scores = raw_annotation["scores"]
            labels = raw_annotation["labels"]
            boxes = raw_annotation["boxes"]

            raw_annotation["scores"] = scores.tolist()
            raw_annotation["labels"] = [self.model.config.id2label[label.item()] for label in labels]
            raw_annotation["boxes"] = [self._get_bounding_box(box) for box in boxes]

            # {"scores": [...], ...} --> [{"score":x, ...}, ...]
            keys = ["score", "label", "box"]
            annotation = [
                dict(zip(keys, vals))
                for vals in zip(raw_annotation["scores"], raw_annotation["labels"], raw_annotation["boxes"])
            ]

        return annotation

    def _get_bounding_box(self, box: "torch.Tensor") -> Dict[str, int]:
        """
        Turns list [xmin, ymin, xmax, ymax] into dict { "xmin": xmin, ... }

        Args:
            box (`torch.Tensor`): Tensor containing the coordinates in corners format.

        Returns:
            bbox (`Dict[str, int]`): Dict containing the coordinates in corners format.
        """
        if self.framework != "pt":
            raise ValueError("The ObjectDetectionPipeline is only available in PyTorch.")
        xmin, ymin, xmax, ymax = box.int().tolist()
        bbox = {
            "xmin": xmin,
            "ymin": ymin,
            "xmax": xmax,
            "ymax": ymax,
        }
        return bbox
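As a quick sketch of the `threshold` and `timeout` parameters documented in `__call__` above (the checkpoint is the one from the docstring example; the local path is a placeholder):

```python
from transformers import pipeline

detector = pipeline("object-detection", model="facebook/detr-resnet-50")
# "room.jpg" is a placeholder local path; URLs and PIL images are accepted as well.
results = detector("room.jpg", threshold=0.8, timeout=10.0)
for obj in results:
    box = obj["box"]
    print(f"{obj['label']}: {obj['score']:.2f} at ({box['xmin']}, {box['ymin']})")
```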
.venv/Lib/site-packages/transformers/utils/__init__.py
ADDED
@@ -0,0 +1,315 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import lru_cache
from typing import FrozenSet

from huggingface_hub import get_full_repo_name  # for backward compatibility
from huggingface_hub.constants import HF_HUB_DISABLE_TELEMETRY as DISABLE_TELEMETRY  # for backward compatibility
from packaging import version

from .. import __version__
from .backbone_utils import BackboneConfigMixin, BackboneMixin
from .chat_template_utils import DocstringParsingException, TypeHintParsingException, get_json_schema
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD
from .doc import (
    add_code_sample_docstrings,
    add_end_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    copy_func,
    replace_return_docstrings,
)
from .generic import (
    ContextManagers,
    ExplicitEnum,
    LossKwargs,
    ModelOutput,
    PaddingStrategy,
    TensorType,
    add_model_info_to_auto_map,
    add_model_info_to_custom_pipelines,
    cached_property,
    can_return_loss,
    expand_dims,
    filter_out_non_signature_kwargs,
    find_labels,
    flatten_dict,
    infer_framework,
    is_jax_tensor,
    is_numpy_array,
    is_tensor,
    is_tf_symbolic_tensor,
    is_tf_tensor,
    is_torch_device,
    is_torch_dtype,
    is_torch_tensor,
    reshape,
    squeeze,
    strtobool,
    tensor_size,
    to_numpy,
    to_py_obj,
    torch_float,
    torch_int,
    transpose,
    working_or_temp_dir,
)
from .hub import (
    CLOUDFRONT_DISTRIB_PREFIX,
    HF_MODULES_CACHE,
    HUGGINGFACE_CO_PREFIX,
    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
    PYTORCH_PRETRAINED_BERT_CACHE,
    PYTORCH_TRANSFORMERS_CACHE,
    S3_BUCKET_PREFIX,
    TRANSFORMERS_CACHE,
    TRANSFORMERS_DYNAMIC_MODULE_NAME,
    EntryNotFoundError,
    PushInProgress,
    PushToHubMixin,
    RepositoryNotFoundError,
    RevisionNotFoundError,
    cached_file,
    default_cache_path,
    define_sagemaker_information,
    download_url,
    extract_commit_hash,
    get_cached_models,
    get_file_from_repo,
    has_file,
    http_user_agent,
    is_offline_mode,
    is_remote_url,
    move_cache,
    send_example_telemetry,
    try_to_load_from_cache,
)
from .import_utils import (
    ACCELERATE_MIN_VERSION,
    ENV_VARS_TRUE_AND_AUTO_VALUES,
    ENV_VARS_TRUE_VALUES,
    GGUF_MIN_VERSION,
    TORCH_FX_REQUIRED_VERSION,
    USE_JAX,
    USE_TF,
    USE_TORCH,
    XLA_FSDPV2_MIN_VERSION,
    DummyObject,
    OptionalDependencyNotAvailable,
    _LazyModule,
    ccl_version,
    direct_transformers_import,
    get_torch_version,
    is_accelerate_available,
    is_apex_available,
    is_aqlm_available,
    is_auto_awq_available,
    is_auto_gptq_available,
    is_av_available,
    is_bitsandbytes_available,
    is_bitsandbytes_multi_backend_available,
    is_bs4_available,
    is_coloredlogs_available,
    is_compressed_tensors_available,
    is_cv2_available,
    is_cython_available,
    is_datasets_available,
    is_detectron2_available,
    is_eetq_available,
    is_essentia_available,
    is_faiss_available,
    is_fbgemm_gpu_available,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal,
    is_flash_attn_greater_or_equal_2_10,
    is_flax_available,
    is_fsdp_available,
    is_ftfy_available,
    is_g2p_en_available,
    is_galore_torch_available,
    is_gguf_available,
    is_grokadamw_available,
    is_hqq_available,
    is_in_notebook,
    is_ipex_available,
    is_jieba_available,
    is_jinja_available,
    is_jumanpp_available,
    is_kenlm_available,
    is_keras_nlp_available,
    is_levenshtein_available,
    is_librosa_available,
    is_liger_kernel_available,
    is_lomo_available,
    is_mlx_available,
    is_natten_available,
    is_ninja_available,
    is_nltk_available,
    is_onnx_available,
    is_openai_available,
    is_optimum_available,
    is_optimum_quanto_available,
    is_pandas_available,
    is_peft_available,
    is_phonemizer_available,
    is_pretty_midi_available,
    is_protobuf_available,
    is_psutil_available,
    is_py3nvml_available,
    is_pyctcdecode_available,
    is_pytesseract_available,
    is_pytest_available,
    is_pytorch_quantization_available,
    is_rjieba_available,
    is_sacremoses_available,
    is_safetensors_available,
    is_sagemaker_dp_enabled,
    is_sagemaker_mp_enabled,
    is_schedulefree_available,
    is_scipy_available,
    is_sentencepiece_available,
    is_seqio_available,
    is_sklearn_available,
    is_soundfile_availble,
    is_spacy_available,
    is_speech_available,
    is_sudachi_available,
    is_sudachi_projection_available,
    is_tensorflow_probability_available,
    is_tensorflow_text_available,
    is_tf2onnx_available,
    is_tf_available,
    is_tiktoken_available,
    is_timm_available,
    is_tokenizers_available,
    is_torch_available,
    is_torch_bf16_available,
    is_torch_bf16_available_on_device,
    is_torch_bf16_cpu_available,
    is_torch_bf16_gpu_available,
    is_torch_compile_available,
    is_torch_cuda_available,
    is_torch_deterministic,
    is_torch_flex_attn_available,
    is_torch_fp16_available_on_device,
    is_torch_fx_available,
    is_torch_fx_proxy,
    is_torch_greater_or_equal,
    is_torch_mlu_available,
    is_torch_mps_available,
    is_torch_musa_available,
    is_torch_neuroncore_available,
    is_torch_npu_available,
    is_torch_sdpa_available,
    is_torch_tensorrt_fx_available,
    is_torch_tf32_available,
    is_torch_tpu_available,
    is_torch_xla_available,
    is_torch_xpu_available,
    is_torchao_available,
    is_torchaudio_available,
    is_torchdistx_available,
    is_torchdynamo_available,
    is_torchdynamo_compiling,
    is_torchvision_available,
    is_torchvision_v2_available,
    is_training_run_on_sagemaker,
    is_uroman_available,
    is_vision_available,
    requires_backends,
    torch_only_method,
)
from .peft_utils import (
    ADAPTER_CONFIG_NAME,
    ADAPTER_SAFE_WEIGHTS_NAME,
    ADAPTER_WEIGHTS_NAME,
    check_peft_version,
    find_adapter_config_file,
)


WEIGHTS_NAME = "pytorch_model.bin"
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
TF2_WEIGHTS_NAME = "tf_model.h5"
TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json"
TF_WEIGHTS_NAME = "model.ckpt"
FLAX_WEIGHTS_NAME = "flax_model.msgpack"
FLAX_WEIGHTS_INDEX_NAME = "flax_model.msgpack.index.json"
SAFE_WEIGHTS_NAME = "model.safetensors"
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
CONFIG_NAME = "config.json"
FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
IMAGE_PROCESSOR_NAME = FEATURE_EXTRACTOR_NAME
PROCESSOR_NAME = "processor_config.json"
CHAT_TEMPLATE_NAME = "chat_template.json"
GENERATION_CONFIG_NAME = "generation_config.json"
MODEL_CARD_NAME = "modelcard.json"

SENTENCEPIECE_UNDERLINE = "▁"
SPIECE_UNDERLINE = SENTENCEPIECE_UNDERLINE  # Kept for backward compatibility

MULTIPLE_CHOICE_DUMMY_INPUTS = [
    [[0, 1, 0, 1], [1, 0, 0, 1]]
] * 2  # Needs to have 0s and 1s only since XLM uses it for langs too.
DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]


def check_min_version(min_version):
    if version.parse(__version__) < version.parse(min_version):
        if "dev" in min_version:
            error_message = (
                "This example requires a source install from HuggingFace Transformers (see "
                "`https://huggingface.co/docs/transformers/installation#install-from-source`),"
            )
        else:
            error_message = f"This example requires a minimum version of {min_version},"
        error_message += f" but the version found is {__version__}.\n"
        raise ImportError(
            error_message
            + "Check out https://github.com/huggingface/transformers/tree/main/examples#important-note for the examples corresponding to other "
            "versions of HuggingFace Transformers."
        )


@lru_cache()
def get_available_devices() -> FrozenSet[str]:
    """
    Returns a frozenset of devices available for the current PyTorch installation.
    """
    devices = {"cpu"}  # `cpu` is always supported as a device in PyTorch

    if is_torch_cuda_available():
        devices.add("cuda")

    if is_torch_mps_available():
        devices.add("mps")

    if is_torch_xpu_available():
        devices.add("xpu")

    if is_torch_npu_available():
        devices.add("npu")

    if is_torch_mlu_available():
        devices.add("mlu")

    if is_torch_musa_available():
        devices.add("musa")

    return frozenset(devices)
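A short sketch of the two helpers defined at the end of this module, assuming the package is importable as `transformers` in the active environment:

```python
from transformers.utils import check_min_version, get_available_devices

check_min_version("4.47.0")     # raises ImportError if the installed transformers is older
print(get_available_devices())  # e.g. frozenset({'cpu', 'cuda'})
```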
.venv/Lib/site-packages/transformers/utils/__pycache__/backbone_utils.cpython-39.pyc
ADDED
Binary file (13.8 kB).
.venv/Lib/site-packages/transformers/utils/__pycache__/chat_template_utils.cpython-39.pyc
ADDED
Binary file (14.6 kB).
.venv/Lib/site-packages/transformers/utils/__pycache__/constants.cpython-39.pyc
ADDED
Binary file (510 Bytes).
.venv/Lib/site-packages/transformers/utils/__pycache__/deprecation.cpython-39.pyc
ADDED
Binary file (5.43 kB).
.venv/Lib/site-packages/transformers/utils/__pycache__/doc.cpython-39.pyc
ADDED
Binary file (36 kB).
.venv/Lib/site-packages/transformers/utils/quantization_config.py
ADDED
@@ -0,0 +1,1344 @@
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding=utf-8
|
3 |
+
|
4 |
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
import copy
|
18 |
+
import importlib.metadata
|
19 |
+
import json
|
20 |
+
import os
|
21 |
+
from dataclasses import dataclass
|
22 |
+
from enum import Enum
|
23 |
+
from inspect import Parameter, signature
|
24 |
+
from typing import Any, Dict, List, Optional, Union
|
25 |
+
|
26 |
+
from packaging import version
|
27 |
+
|
28 |
+
from ..utils import is_auto_awq_available, is_hqq_available, is_torch_available, is_torchao_available, logging
|
29 |
+
|
30 |
+
|
31 |
+
if is_torch_available():
|
32 |
+
import torch
|
33 |
+
|
34 |
+
logger = logging.get_logger(__name__)
|
35 |
+
|
36 |
+
|
37 |
+
class QuantizationMethod(str, Enum):
|
38 |
+
BITS_AND_BYTES = "bitsandbytes"
|
39 |
+
GPTQ = "gptq"
|
40 |
+
AWQ = "awq"
|
41 |
+
AQLM = "aqlm"
|
42 |
+
QUANTO = "quanto"
|
43 |
+
EETQ = "eetq"
|
44 |
+
HQQ = "hqq"
|
45 |
+
COMPRESSED_TENSORS = "compressed-tensors"
|
46 |
+
FBGEMM_FP8 = "fbgemm_fp8"
|
47 |
+
TORCHAO = "torchao"
|
48 |
+
BITNET = "bitnet"
|
49 |
+
|
50 |
+
|
51 |
+
class AWQLinearVersion(str, Enum):
|
52 |
+
GEMM = "gemm"
|
53 |
+
GEMV = "gemv"
|
54 |
+
EXLLAMA = "exllama"
|
55 |
+
IPEX = "ipex"
|
56 |
+
|
57 |
+
@staticmethod
|
58 |
+
def from_str(version: str):
|
59 |
+
version = version.lower()
|
60 |
+
if version == "gemm":
|
61 |
+
return AWQLinearVersion.GEMM
|
62 |
+
elif version == "gemv":
|
63 |
+
return AWQLinearVersion.GEMV
|
64 |
+
elif version == "exllama":
|
65 |
+
return AWQLinearVersion.EXLLAMA
|
66 |
+
elif version == "ipex":
|
67 |
+
return AWQLinearVersion.IPEX
|
68 |
+
else:
|
69 |
+
raise ValueError(f"Unknown AWQLinearVersion {version}")
|
70 |
+
|
71 |
+
|
72 |
+
class AwqBackendPackingMethod(str, Enum):
|
73 |
+
AUTOAWQ = "autoawq"
|
74 |
+
LLMAWQ = "llm-awq"
|
75 |
+
|
76 |
+
|
77 |
+
@dataclass
|
78 |
+
class QuantizationConfigMixin:
|
79 |
+
"""
|
80 |
+
Mixin class for quantization config
|
81 |
+
"""
|
82 |
+
|
83 |
+
quant_method: QuantizationMethod
|
84 |
+
|
85 |
+
@classmethod
|
86 |
+
def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
|
87 |
+
"""
|
88 |
+
Instantiates a [`QuantizationConfigMixin`] from a Python dictionary of parameters.
|
89 |
+
|
90 |
+
Args:
|
91 |
+
config_dict (`Dict[str, Any]`):
|
92 |
+
Dictionary that will be used to instantiate the configuration object.
|
93 |
+
return_unused_kwargs (`bool`,*optional*, defaults to `False`):
|
94 |
+
Whether or not to return a list of unused keyword arguments. Used for `from_pretrained` method in
|
95 |
+
`PreTrainedModel`.
|
96 |
+
kwargs (`Dict[str, Any]`):
|
97 |
+
Additional parameters from which to initialize the configuration object.
|
98 |
+
|
99 |
+
Returns:
|
100 |
+
[`QuantizationConfigMixin`]: The configuration object instantiated from those parameters.
|
101 |
+
"""
|
102 |
+
config = cls(**config_dict)
|
103 |
+
|
104 |
+
to_remove = []
|
105 |
+
for key, value in kwargs.items():
|
106 |
+
if hasattr(config, key):
|
107 |
+
setattr(config, key, value)
|
108 |
+
to_remove.append(key)
|
109 |
+
for key in to_remove:
|
110 |
+
kwargs.pop(key, None)
|
111 |
+
|
112 |
+
if return_unused_kwargs:
|
113 |
+
return config, kwargs
|
114 |
+
else:
|
115 |
+
return config
|
116 |
+
|
117 |
+
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
|
118 |
+
"""
|
119 |
+
Save this instance to a JSON file.
|
120 |
+
|
121 |
+
Args:
|
122 |
+
json_file_path (`str` or `os.PathLike`):
|
123 |
+
Path to the JSON file in which this configuration instance's parameters will be saved.
|
124 |
+
use_diff (`bool`, *optional*, defaults to `True`):
|
125 |
+
If set to `True`, only the difference between the config instance and the default
|
126 |
+
`QuantizationConfig()` is serialized to JSON file.
|
127 |
+
"""
|
128 |
+
with open(json_file_path, "w", encoding="utf-8") as writer:
|
129 |
+
config_dict = self.to_dict()
|
130 |
+
json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
131 |
+
|
132 |
+
writer.write(json_string)
|
133 |
+
|
134 |
+
def to_dict(self) -> Dict[str, Any]:
|
135 |
+
"""
|
136 |
+
Serializes this instance to a Python dictionary. Returns:
|
137 |
+
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
|
138 |
+
"""
|
139 |
+
return copy.deepcopy(self.__dict__)
|
140 |
+
|
141 |
+
def __iter__(self):
|
142 |
+
"""allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin"""
|
143 |
+
for attr, value in copy.deepcopy(self.__dict__).items():
|
144 |
+
yield attr, value
|
145 |
+
|
146 |
+
def __repr__(self):
|
147 |
+
return f"{self.__class__.__name__} {self.to_json_string()}"
|
148 |
+
|
149 |
+
def to_json_string(self, use_diff: bool = True) -> str:
|
150 |
+
"""
|
151 |
+
Serializes this instance to a JSON string.
|
152 |
+
|
153 |
+
Args:
|
154 |
+
use_diff (`bool`, *optional*, defaults to `True`):
|
155 |
+
If set to `True`, only the difference between the config instance and the default `PretrainedConfig()`
|
156 |
+
is serialized to JSON string.
|
157 |
+
|
158 |
+
Returns:
|
159 |
+
`str`: String containing all the attributes that make up this configuration instance in JSON format.
|
160 |
+
"""
|
161 |
+
if use_diff is True:
|
162 |
+
config_dict = self.to_diff_dict()
|
163 |
+
else:
|
164 |
+
config_dict = self.to_dict()
|
165 |
+
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
166 |
+
|
167 |
+
def update(self, **kwargs):
|
168 |
+
"""
|
169 |
+
Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
|
170 |
+
returning all the unused kwargs.
|
171 |
+
|
172 |
+
Args:
|
173 |
+
kwargs (`Dict[str, Any]`):
|
174 |
+
Dictionary of attributes to tentatively update this class.
|
175 |
+
|
176 |
+
Returns:
|
177 |
+
`Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
|
178 |
+
"""
|
179 |
+
to_remove = []
|
180 |
+
for key, value in kwargs.items():
|
181 |
+
if hasattr(self, key):
|
182 |
+
setattr(self, key, value)
|
183 |
+
to_remove.append(key)
|
184 |
+
|
185 |
+
# Remove all the attributes that were updated, without modifying the input dict
|
186 |
+
unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
|
187 |
+
return unused_kwargs
|
188 |
+
|
189 |
+
|
190 |
+
@dataclass
|
191 |
+
class HqqConfig(QuantizationConfigMixin):
|
192 |
+
"""
|
193 |
+
This is wrapper around hqq's BaseQuantizeConfig.
|
194 |
+
|
195 |
+
Args:
|
196 |
+
nbits (`int`, *optional*, defaults to 4):
|
197 |
+
Number of bits. Supported values are (8, 4, 3, 2, 1).
|
198 |
+
group_size (`int`, *optional*, defaults to 64):
|
199 |
+
Group-size value. Supported values are any value that is divisble by weight.shape[axis]).
|
200 |
+
view_as_float (`bool`, *optional*, defaults to `False`):
|
201 |
+
View the quantized weight as float (used in distributed training) if set to `True`.
|
202 |
+
axis (`Optional[int]`, *optional*):
|
203 |
+
Axis along which grouping is performed. Supported values are 0 or 1.
|
204 |
+
dynamic_config (dict, *optional*):
|
205 |
+
Parameters for dynamic configuration. The key is the name tag of the layer and the value is a quantization config.
|
206 |
+
If set, each layer specified by its id will use its dedicated quantization configuration.
|
207 |
+
skip_modules (`List[str]`, *optional*, defaults to `['lm_head']`):
|
208 |
+
List of `nn.Linear` layers to skip.
|
209 |
+
kwargs (`Dict[str, Any]`, *optional*):
|
210 |
+
Additional parameters from which to initialize the configuration object.
|
211 |
+
"""
|
212 |
+
|
213 |
+
def __init__(
|
214 |
+
self,
|
215 |
+
nbits: int = 4,
|
216 |
+
group_size: int = 64,
|
217 |
+
view_as_float: bool = False,
|
218 |
+
axis: Optional[int] = None,
|
219 |
+
dynamic_config: Optional[dict] = None,
|
220 |
+
skip_modules: List[str] = ["lm_head"],
|
221 |
+
**kwargs,
|
222 |
+
):
|
223 |
+
if is_hqq_available():
|
224 |
+
from hqq.core.quantize import BaseQuantizeConfig as HQQBaseQuantizeConfig
|
225 |
+
|
226 |
+
for deprecated_key in ["quant_zero", "quant_scale", "offload_meta"]:
|
227 |
+
if deprecated_key in kwargs:
|
228 |
+
logger.info(
|
229 |
+
deprecated_key + " is deprecated. This parameter will be ignored in quantization settings."
|
230 |
+
)
|
231 |
+
|
232 |
+
if axis is None:
|
233 |
+
axis = 1
|
234 |
+
logger.info("Setting axis=1 as faster backends such as TorchAO or BitBlas are only compatible with it.")
|
235 |
+
|
236 |
+
if axis not in [0, 1]:
|
237 |
+
raise ValueError("Invalid axis value. Only 0 and 1 are allowed.")
|
238 |
+
|
239 |
+
if dynamic_config is not None:
|
240 |
+
self.quant_config = {}
|
241 |
+
for key in dynamic_config:
|
242 |
+
self.quant_config[key] = HQQBaseQuantizeConfig(**dynamic_config[key])
|
243 |
+
else:
|
244 |
+
self.quant_config = HQQBaseQuantizeConfig(
|
245 |
+
**{
|
246 |
+
"nbits": nbits,
|
247 |
+
"group_size": group_size,
|
248 |
+
"view_as_float": view_as_float,
|
249 |
+
"axis": axis,
|
250 |
+
}
|
251 |
+
)
|
252 |
+
|
253 |
+
self.quant_method = QuantizationMethod.HQQ
|
254 |
+
self.skip_modules = skip_modules
|
255 |
+
|
256 |
+
self.post_init()
|
257 |
+
|
258 |
+
def post_init(self):
|
259 |
+
r"""
|
260 |
+
Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
|
261 |
+
"""
|
262 |
+
pass
|
263 |
+
|
264 |
+
@classmethod
|
265 |
+
def from_dict(cls, config: Dict[str, Any]):
|
266 |
+
"""
|
267 |
+
Override from_dict, used in AutoQuantizationConfig.from_dict in quantizers/auto.py
|
268 |
+
"""
|
269 |
+
instance = cls()
|
270 |
+
instance.quant_config = config["quant_config"]
|
271 |
+
instance.skip_modules = config["skip_modules"]
|
272 |
+
return instance
|
273 |
+
|
274 |
+
def to_dict(self) -> Dict[str, Any]:
|
275 |
+
"""
|
276 |
+
Serializes this instance to a Python dictionary. Returns:
|
277 |
+
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
|
278 |
+
"""
|
279 |
+
return {
|
280 |
+
"quant_config": self.quant_config,
|
281 |
+
"quant_method": self.quant_method,
|
282 |
+
"skip_modules": self.skip_modules,
|
283 |
+
}
|
284 |
+
|
285 |
+
def __repr__(self):
|
286 |
+
config_dict = self.to_dict()
|
287 |
+
return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
|
288 |
+
|
289 |
+
def to_diff_dict(self) -> Dict[str, Any]:
|
290 |
+
"""
|
291 |
+
Removes all attributes from config which correspond to the default config attributes for better readability and
|
292 |
+
serializes to a Python dictionary.
|
293 |
+
Returns:
|
294 |
+
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance,
|
295 |
+
"""
|
296 |
+
config_dict = self.to_dict()
|
297 |
+
|
298 |
+
# get the default config dict
|
299 |
+
default_config_dict = HqqConfig().to_dict()
|
300 |
+
|
301 |
+
serializable_config_dict = {}
|
302 |
+
|
303 |
+
# only serialize values that differ from the default config
|
304 |
+
for key, value in config_dict.items():
|
305 |
+
if value != default_config_dict[key]:
|
306 |
+
serializable_config_dict[key] = value
|
307 |
+
|
308 |
+
return serializable_config_dict
|
309 |
+
|
310 |
+
|
311 |
+
@dataclass
|
312 |
+
class BitsAndBytesConfig(QuantizationConfigMixin):
|
313 |
+
"""
|
314 |
+
This is a wrapper class about all possible attributes and features that you can play with a model that has been
|
315 |
+
loaded using `bitsandbytes`.
|
316 |
+
|
317 |
+
This replaces `load_in_8bit` or `load_in_4bit`therefore both options are mutually exclusive.
|
318 |
+
|
319 |
+
Currently only supports `LLM.int8()`, `FP4`, and `NF4` quantization. If more methods are added to `bitsandbytes`,
|
320 |
+
then more arguments will be added to this class.
|
321 |
+
|
322 |
+
Args:
|
323 |
+
load_in_8bit (`bool`, *optional*, defaults to `False`):
|
324 |
+
This flag is used to enable 8-bit quantization with LLM.int8().
|
325 |
+
load_in_4bit (`bool`, *optional*, defaults to `False`):
|
326 |
+
This flag is used to enable 4-bit quantization by replacing the Linear layers with FP4/NF4 layers from
|
327 |
+
`bitsandbytes`.
|
328 |
+
llm_int8_threshold (`float`, *optional*, defaults to 6.0):
|
329 |
+
This corresponds to the outlier threshold for outlier detection as described in `LLM.int8() : 8-bit Matrix
|
330 |
+
Multiplication for Transformers at Scale` paper: https://arxiv.org/abs/2208.07339 Any hidden states value
|
331 |
+
that is above this threshold will be considered an outlier and the operation on those values will be done
|
332 |
+
in fp16. Values are usually normally distributed, that is, most values are in the range [-3.5, 3.5], but
|
333 |
+
there are some exceptional systematic outliers that are very differently distributed for large models.
|
334 |
+
These outliers are often in the interval [-60, -6] or [6, 60]. Int8 quantization works well for values of
|
335 |
+
magnitude ~5, but beyond that, there is a significant performance penalty. A good default threshold is 6,
|
336 |
+
but a lower threshold might be needed for more unstable models (small models, fine-tuning).
|
337 |
+
llm_int8_skip_modules (`List[str]`, *optional*):
|
338 |
+
An explicit list of the modules that we do not want to convert in 8-bit. This is useful for models such as
|
339 |
+
Jukebox that has several heads in different places and not necessarily at the last position. For example
|
340 |
+
for `CausalLM` models, the last `lm_head` is kept in its original `dtype`.
|
341 |
+
llm_int8_enable_fp32_cpu_offload (`bool`, *optional*, defaults to `False`):
|
342 |
+
This flag is used for advanced use cases and users that are aware of this feature. If you want to split
|
343 |
+
your model in different parts and run some parts in int8 on GPU and some parts in fp32 on CPU, you can use
|
344 |
+
this flag. This is useful for offloading large models such as `google/flan-t5-xxl`. Note that the int8
|
345 |
+
operations will not be run on CPU.
|
346 |
+
llm_int8_has_fp16_weight (`bool`, *optional*, defaults to `False`):
|
347 |
+
This flag runs LLM.int8() with 16-bit main weights. This is useful for fine-tuning as the weights do not
|
348 |
+
have to be converted back and forth for the backward pass.
|
349 |
+
bnb_4bit_compute_dtype (`torch.dtype` or str, *optional*, defaults to `torch.float32`):
|
350 |
+
This sets the computational type which might be different than the input type. For example, inputs might be
|
351 |
+
fp32, but computation can be set to bf16 for speedups.
|
352 |
+
bnb_4bit_quant_type (`str`, *optional*, defaults to `"fp4"`):
|
353 |
+
This sets the quantization data type in the bnb.nn.Linear4Bit layers. Options are FP4 and NF4 data types
|
354 |
+
which are specified by `fp4` or `nf4`.
|
355 |
+
bnb_4bit_use_double_quant (`bool`, *optional*, defaults to `False`):
|
356 |
+
This flag is used for nested quantization where the quantization constants from the first quantization are
|
357 |
+
quantized again.
|
358 |
+
bnb_4bit_quant_storage (`torch.dtype` or str, *optional*, defaults to `torch.uint8`):
|
359 |
+
This sets the storage type to pack the quanitzed 4-bit prarams.
|
360 |
+
kwargs (`Dict[str, Any]`, *optional*):
|
361 |
+
Additional parameters from which to initialize the configuration object.
|
362 |
+
"""
|
363 |
+
|
364 |
+
def __init__(
|
365 |
+
self,
|
366 |
+
load_in_8bit=False,
|
367 |
+
load_in_4bit=False,
|
368 |
+
llm_int8_threshold=6.0,
|
369 |
+
llm_int8_skip_modules=None,
|
370 |
+
llm_int8_enable_fp32_cpu_offload=False,
|
371 |
+
llm_int8_has_fp16_weight=False,
|
372 |
+
bnb_4bit_compute_dtype=None,
|
373 |
+
bnb_4bit_quant_type="fp4",
|
374 |
+
bnb_4bit_use_double_quant=False,
|
375 |
+
bnb_4bit_quant_storage=None,
|
376 |
+
**kwargs,
|
377 |
+
):
|
378 |
+
self.quant_method = QuantizationMethod.BITS_AND_BYTES
|
379 |
+
|
380 |
+
if load_in_4bit and load_in_8bit:
|
381 |
+
raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time")
|
382 |
+
|
383 |
+
self._load_in_8bit = load_in_8bit
|
384 |
+
self._load_in_4bit = load_in_4bit
|
385 |
+
self.llm_int8_threshold = llm_int8_threshold
|
386 |
+
self.llm_int8_skip_modules = llm_int8_skip_modules
|
387 |
+
self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
|
388 |
+
self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight
|
389 |
+
self.bnb_4bit_quant_type = bnb_4bit_quant_type
|
390 |
+
self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant
|
391 |
+
|
392 |
+
if bnb_4bit_compute_dtype is None:
|
393 |
+
self.bnb_4bit_compute_dtype = torch.float32
|
394 |
+
elif isinstance(bnb_4bit_compute_dtype, str):
|
395 |
+
self.bnb_4bit_compute_dtype = getattr(torch, bnb_4bit_compute_dtype)
|
396 |
+
elif isinstance(bnb_4bit_compute_dtype, torch.dtype):
|
397 |
+
self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype
|
398 |
+
else:
|
399 |
+
raise ValueError("bnb_4bit_compute_dtype must be a string or a torch.dtype")
|
400 |
+
|
401 |
+
if bnb_4bit_quant_storage is None:
|
402 |
+
self.bnb_4bit_quant_storage = torch.uint8
|
403 |
+
elif isinstance(bnb_4bit_quant_storage, str):
|
404 |
+
if bnb_4bit_quant_storage not in ["float16", "float32", "int8", "uint8", "float64", "bfloat16"]:
|
405 |
+
raise ValueError(
|
406 |
+
"`bnb_4bit_quant_storage` must be a valid string (one of 'float16', 'float32', 'int8', 'uint8', 'float64', 'bfloat16') "
|
407 |
+
)
|
408 |
+
self.bnb_4bit_quant_storage = getattr(torch, bnb_4bit_quant_storage)
|
409 |
+
elif isinstance(bnb_4bit_quant_storage, torch.dtype):
|
410 |
+
self.bnb_4bit_quant_storage = bnb_4bit_quant_storage
|
411 |
+
else:
|
412 |
+
raise ValueError("bnb_4bit_quant_storage must be a string or a torch.dtype")
|
413 |
+
|
414 |
+
if kwargs:
|
415 |
+
logger.warning(f"Unused kwargs: {list(kwargs.keys())}. These kwargs are not used in {self.__class__}.")
|
416 |
+
|
417 |
+
self.post_init()
|
418 |
+
|
419 |
+
@property
|
420 |
+
def load_in_4bit(self):
|
421 |
+
return self._load_in_4bit
|
422 |
+
|
423 |
+
@load_in_4bit.setter
|
424 |
+
def load_in_4bit(self, value: bool):
|
425 |
+
if not isinstance(value, bool):
|
426 |
+
raise TypeError("load_in_4bit must be a boolean")
|
427 |
+
|
428 |
+
if self.load_in_8bit and value:
|
429 |
+
raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time")
|
430 |
+
self._load_in_4bit = value
|
431 |
+
|
432 |
+
@property
|
433 |
+
def load_in_8bit(self):
|
434 |
+
return self._load_in_8bit
|
435 |
+
|
436 |
+
@load_in_8bit.setter
|
437 |
+
def load_in_8bit(self, value: bool):
|
438 |
+
if not isinstance(value, bool):
|
439 |
+
raise TypeError("load_in_8bit must be a boolean")
|
440 |
+
|
441 |
+
if self.load_in_4bit and value:
|
442 |
+
raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time")
|
443 |
+
self._load_in_8bit = value
|
444 |
+
|
445 |
+
def post_init(self):
|
446 |
+
r"""
|
447 |
+
Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
|
448 |
+
"""
|
449 |
+
if not isinstance(self.load_in_4bit, bool):
|
450 |
+
raise TypeError("load_in_4bit must be a boolean")
|
451 |
+
|
452 |
+
if not isinstance(self.load_in_8bit, bool):
|
453 |
+
raise TypeError("load_in_8bit must be a boolean")
|
454 |
+
|
455 |
+
if not isinstance(self.llm_int8_threshold, float):
|
456 |
+
raise TypeError("llm_int8_threshold must be a float")
|
457 |
+
|
458 |
+
if self.llm_int8_skip_modules is not None and not isinstance(self.llm_int8_skip_modules, list):
|
459 |
+
raise TypeError("llm_int8_skip_modules must be a list of strings")
|
460 |
+
if not isinstance(self.llm_int8_enable_fp32_cpu_offload, bool):
|
461 |
+
raise TypeError("llm_int8_enable_fp32_cpu_offload must be a boolean")
|
462 |
+
|
463 |
+
if not isinstance(self.llm_int8_has_fp16_weight, bool):
|
464 |
+
raise TypeError("llm_int8_has_fp16_weight must be a boolean")
|
465 |
+
|
466 |
+
if self.bnb_4bit_compute_dtype is not None and not isinstance(self.bnb_4bit_compute_dtype, torch.dtype):
|
467 |
+
raise TypeError("bnb_4bit_compute_dtype must be torch.dtype")
|
468 |
+
|
469 |
+
if not isinstance(self.bnb_4bit_quant_type, str):
|
470 |
+
raise TypeError("bnb_4bit_quant_type must be a string")
|
471 |
+
|
472 |
+
if not isinstance(self.bnb_4bit_use_double_quant, bool):
|
473 |
+
raise TypeError("bnb_4bit_use_double_quant must be a boolean")
|
474 |
+
|
475 |
+
if self.load_in_4bit and not version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse(
|
476 |
+
"0.39.0"
|
477 |
+
):
|
478 |
+
raise ValueError(
|
479 |
+
"4 bit quantization requires bitsandbytes>=0.39.0 - please upgrade your bitsandbytes version"
|
480 |
+
)
|
481 |
+
|
482 |
+
def is_quantizable(self):
|
483 |
+
r"""
|
484 |
+
Returns `True` if the model is quantizable, `False` otherwise.
|
485 |
+
"""
|
486 |
+
return self.load_in_8bit or self.load_in_4bit
|
487 |
+
|
488 |
+
def quantization_method(self):
|
489 |
+
r"""
|
490 |
+
This method returns the quantization method used for the model. If the model is not quantizable, it returns
|
491 |
+
`None`.
|
492 |
+
"""
|
493 |
+
if self.load_in_8bit:
|
494 |
+
return "llm_int8"
|
495 |
+
elif self.load_in_4bit and self.bnb_4bit_quant_type == "fp4":
|
496 |
+
return "fp4"
|
497 |
+
elif self.load_in_4bit and self.bnb_4bit_quant_type == "nf4":
|
498 |
+
return "nf4"
|
499 |
+
else:
|
500 |
+
return None
|
501 |
+
|
502 |
+
def to_dict(self) -> Dict[str, Any]:
|
503 |
+
"""
|
504 |
+
Serializes this instance to a Python dictionary. Returns:
|
505 |
+
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
|
506 |
+
"""
|
507 |
+
output = copy.deepcopy(self.__dict__)
|
508 |
+
output["bnb_4bit_compute_dtype"] = str(output["bnb_4bit_compute_dtype"]).split(".")[1]
|
509 |
+
output["bnb_4bit_quant_storage"] = str(output["bnb_4bit_quant_storage"]).split(".")[1]
|
510 |
+
output["load_in_4bit"] = self.load_in_4bit
|
511 |
+
output["load_in_8bit"] = self.load_in_8bit
|
512 |
+
|
513 |
+
return output
|
514 |
+
|
515 |
+
def __repr__(self):
|
516 |
+
config_dict = self.to_dict()
|
517 |
+
return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
|
518 |
+
|
519 |
+
def to_diff_dict(self) -> Dict[str, Any]:
|
520 |
+
"""
|
521 |
+
Removes all attributes from config which correspond to the default config attributes for better readability and
|
522 |
+
serializes to a Python dictionary.
|
523 |
+
|
524 |
+
Returns:
|
525 |
+
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance,
|
526 |
+
"""
|
527 |
+
config_dict = self.to_dict()
|
528 |
+
|
529 |
+
# get the default config dict
|
530 |
+
default_config_dict = BitsAndBytesConfig().to_dict()
|
531 |
+
|
532 |
+
serializable_config_dict = {}
|
533 |
+
|
534 |
+
# only serialize values that differ from the default config
|
535 |
+
for key, value in config_dict.items():
|
536 |
+
if value != default_config_dict[key]:
|
537 |
+
serializable_config_dict[key] = value
|
538 |
+
|
539 |
+
return serializable_config_dict


class ExllamaVersion(int, Enum):
    ONE = 1
    TWO = 2


@dataclass
class GPTQConfig(QuantizationConfigMixin):
    """
    This is a wrapper class about all possible attributes and features that you can play with a model that has been
    loaded using the `optimum` api for GPTQ quantization relying on the auto_gptq backend.

    Args:
        bits (`int`):
            The number of bits to quantize to, supported numbers are (2, 3, 4, 8).
        tokenizer (`str` or `PreTrainedTokenizerBase`, *optional*):
            The tokenizer used to process the dataset. You can pass either:
                - A custom tokenizer object.
                - A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co.
                - A path to a *directory* containing vocabulary files required by the tokenizer, for instance saved
                  using the [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
        dataset (`Union[List[str]]`, *optional*):
            The dataset used for quantization. You can provide your own dataset in a list of string or just use the
            original datasets used in the GPTQ paper ['wikitext2','c4','c4-new'].
        group_size (`int`, *optional*, defaults to 128):
            The group size to use for quantization. Recommended value is 128 and -1 uses per-column quantization.
        damp_percent (`float`, *optional*, defaults to 0.1):
            The percent of the average Hessian diagonal to use for dampening. Recommended value is 0.1.
        desc_act (`bool`, *optional*, defaults to `False`):
            Whether to quantize columns in order of decreasing activation size. Setting it to False can significantly
            speed up inference but the perplexity may become slightly worse. Also known as act-order.
        sym (`bool`, *optional*, defaults to `True`):
            Whether to use symmetric quantization.
        true_sequential (`bool`, *optional*, defaults to `True`):
            Whether to perform sequential quantization even within a single Transformer block. Instead of quantizing
            the entire block at once, we perform layer-wise quantization. As a result, each layer undergoes
            quantization using inputs that have passed through the previously quantized layers.
        use_cuda_fp16 (`bool`, *optional*, defaults to `False`):
            Whether or not to use the optimized cuda kernel for fp16 models. Requires the model to be in fp16.
        model_seqlen (`int`, *optional*):
            The maximum sequence length that the model can take.
        block_name_to_quantize (`str`, *optional*):
            The transformers block name to quantize. If None, we will infer the block name using common patterns (e.g. model.layers).
        module_name_preceding_first_block (`List[str]`, *optional*):
            The layers that are preceding the first Transformer block.
        batch_size (`int`, *optional*, defaults to 1):
            The batch size used when processing the dataset.
        pad_token_id (`int`, *optional*):
            The pad token id. Needed to prepare the dataset when `batch_size` > 1.
        use_exllama (`bool`, *optional*):
            Whether to use the exllama backend. Defaults to `True` if unset. Only works with `bits` = 4.
        max_input_length (`int`, *optional*):
            The maximum input length. This is needed to initialize a buffer that depends on the maximum expected input
            length. It is specific to the exllama backend with act-order.
        exllama_config (`Dict[str, Any]`, *optional*):
            The exllama config. You can specify the version of the exllama kernel through the `version` key. Defaults
            to `{"version": 1}` if unset.
        cache_block_outputs (`bool`, *optional*, defaults to `True`):
            Whether to cache block outputs to reuse as inputs for the succeeding block.
        modules_in_block_to_quantize (`List[List[str]]`, *optional*):
            List of list of module names to quantize in the specified block. This argument is useful to exclude certain linear modules from being quantized.
            The block to quantize can be specified by setting `block_name_to_quantize`. We will quantize each list sequentially. If not set, we will quantize all linear layers.
            Example: `modules_in_block_to_quantize = [["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"], ["self_attn.o_proj"]]`.
            In this example, we will first quantize the q, k, v layers simultaneously since they are independent.
            Then, we will quantize the `self_attn.o_proj` layer with the q, k, v layers quantized. This way, we will get
            better results since it reflects the real input `self_attn.o_proj` will get when the model is quantized.
    """

    def __init__(
        self,
        bits: int,
        tokenizer: Any = None,
        dataset: Optional[Union[List[str], str]] = None,
        group_size: int = 128,
        damp_percent: float = 0.1,
        desc_act: bool = False,
        sym: bool = True,
        true_sequential: bool = True,
        use_cuda_fp16: bool = False,
        model_seqlen: Optional[int] = None,
        block_name_to_quantize: Optional[str] = None,
        module_name_preceding_first_block: Optional[List[str]] = None,
        batch_size: int = 1,
        pad_token_id: Optional[int] = None,
        use_exllama: Optional[bool] = None,
        max_input_length: Optional[int] = None,
        exllama_config: Optional[Dict[str, Any]] = None,
        cache_block_outputs: bool = True,
        modules_in_block_to_quantize: Optional[List[List[str]]] = None,
        **kwargs,
    ):
        self.quant_method = QuantizationMethod.GPTQ
        self.bits = bits
        self.tokenizer = tokenizer
        self.dataset = dataset
        self.group_size = group_size
        self.damp_percent = damp_percent
        self.desc_act = desc_act
        self.sym = sym
        self.true_sequential = true_sequential
        self.use_cuda_fp16 = use_cuda_fp16
        self.model_seqlen = model_seqlen
        self.block_name_to_quantize = block_name_to_quantize
        self.module_name_preceding_first_block = module_name_preceding_first_block
        self.batch_size = batch_size
        self.pad_token_id = pad_token_id
        self.use_exllama = use_exllama
        self.max_input_length = max_input_length
        self.exllama_config = exllama_config
        self.disable_exllama = kwargs.pop("disable_exllama", None)
        self.cache_block_outputs = cache_block_outputs
        self.modules_in_block_to_quantize = modules_in_block_to_quantize
        self.post_init()

    def get_loading_attributes(self):
        attibutes_dict = copy.deepcopy(self.__dict__)
        loading_attibutes = ["disable_exllama", "use_exllama", "exllama_config", "use_cuda_fp16", "max_input_length"]
        loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes}
        return loading_attibutes_dict

    def post_init(self):
        r"""
        Safety checker that arguments are correct
        """
        if self.bits not in [2, 3, 4, 8]:
            raise ValueError(f"Only support quantization to [2,3,4,8] bits but found {self.bits}")
        if self.group_size != -1 and self.group_size <= 0:
            raise ValueError("group_size must be greater than 0 or equal to -1")
        if not (0 < self.damp_percent < 1):
            raise ValueError("damp_percent must be between 0 and 1.")
        if self.dataset is not None:
            if isinstance(self.dataset, str):
                if self.dataset in ["ptb", "ptb-new"]:
                    raise ValueError(
                        f"""{self.dataset} dataset was deprecated. You can only choose between
                        ['wikitext2','c4','c4-new']"""
                    )
                if self.dataset not in ["wikitext2", "c4", "c4-new"]:
                    raise ValueError(
                        f"""You have entered a string value for dataset. You can only choose between
                        ['wikitext2','c4','c4-new'], but we found {self.dataset}"""
                    )
            elif not isinstance(self.dataset, list):
                raise ValueError(
                    f"""dataset needs to be either a list of string or a value in
                    ['wikitext2','c4','c4-new'], but we found {self.dataset}"""
                )

        if self.disable_exllama is None and self.use_exllama is None:
            # New default behaviour
            self.use_exllama = True
        elif self.disable_exllama is not None and self.use_exllama is None:
            # Follow pattern of old config
            logger.warning(
                "Using `disable_exllama` is deprecated and will be removed in version 4.37. Use `use_exllama` instead and specify the version with `exllama_config`. "
                "The value of `use_exllama` will be overwritten by `disable_exllama` passed in `GPTQConfig` or stored in your config file."
            )
            self.use_exllama = not self.disable_exllama
            self.disable_exllama = None
        elif self.disable_exllama is not None and self.use_exllama is not None:
            # Only happens if user explicitly passes in both arguments
            raise ValueError("Cannot specify both `disable_exllama` and `use_exllama`. Please use just `use_exllama`")

        if self.exllama_config is None:
            self.exllama_config = {"version": ExllamaVersion.ONE}
        else:
            if "version" not in self.exllama_config:
                raise ValueError("`exllama_config` needs to have a `version` key.")
            elif self.exllama_config["version"] not in [ExllamaVersion.ONE, ExllamaVersion.TWO]:
                exllama_version = self.exllama_config["version"]
                raise ValueError(
                    f"Only supported versions are in [ExllamaVersion.ONE, ExllamaVersion.TWO] - not recognized version {exllama_version}"
                )

        if self.bits == 4 and self.use_exllama:
            if self.exllama_config["version"] == ExllamaVersion.ONE:
                logger.info(
                    "You have activated exllama backend. Note that you can get better inference "
                    "speed using exllamav2 kernel by setting `exllama_config`."
                )
            elif self.exllama_config["version"] == ExllamaVersion.TWO:
                optimum_version = version.parse(importlib.metadata.version("optimum"))
                autogptq_version = version.parse(importlib.metadata.version("auto_gptq"))
                if optimum_version <= version.parse("1.13.2") or autogptq_version <= version.parse("0.4.2"):
                    raise ValueError(
                        f"You need optimum > 1.13.2 and auto-gptq > 0.4.2. Make sure to have those versions installed - detected versions: optimum {optimum_version} and autogptq {autogptq_version}"
                    )
        if self.modules_in_block_to_quantize is not None:
            optimum_version = version.parse(importlib.metadata.version("optimum"))
            if optimum_version < version.parse("1.15.0"):
                raise ValueError(
                    "Your current version of `optimum` does not support the `modules_in_block_to_quantize` quantization argument, please upgrade the `optimum` package to version 1.15.0 or higher."
                )

    def to_dict(self):
        config_dict = super().to_dict()
        config_dict.pop("disable_exllama", None)
        return config_dict

    def to_dict_optimum(self):
        """
        Get compatible dict for optimum gptq config
        """
        quant_dict = self.to_dict()
        # make it compatible with optimum config
        quant_dict["disable_exllama"] = not self.use_exllama
        return quant_dict

    @classmethod
    def from_dict_optimum(cls, config_dict):
        """
        Get compatible class with optimum gptq config dict
        """

        if "disable_exllama" in config_dict:
            config_dict["use_exllama"] = not config_dict["disable_exllama"]
            # switch to None to not trigger the warning
            config_dict["disable_exllama"] = None

        config = cls(**config_dict)
        return config
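
# NOTE (illustrative sketch added for this review, not part of the original transformers source):
# how the exllama-related defaults above resolve in practice. The helper name and model id are
# hypothetical examples.
def _example_gptq_config_defaults():
    config = GPTQConfig(bits=4, dataset="c4", tokenizer="facebook/opt-125m", group_size=128)
    # post_init() turned the unset `use_exllama` into True and filled in the v1 kernel by default
    assert config.use_exllama is True
    assert config.exllama_config == {"version": ExllamaVersion.ONE}
    # `to_dict_optimum()` re-exposes the legacy `disable_exllama` flag expected by optimum
    assert config.to_dict_optimum()["disable_exllama"] is False
    return config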


@dataclass
class AwqConfig(QuantizationConfigMixin):
    """
    This is a wrapper class about all possible attributes and features that you can play with a model that has been
    loaded using the `auto-awq` library for AWQ quantization, relying on the auto_awq backend.

    Args:
        bits (`int`, *optional*, defaults to 4):
            The number of bits to quantize to.
        group_size (`int`, *optional*, defaults to 128):
            The group size to use for quantization. Recommended value is 128 and -1 uses per-column quantization.
        zero_point (`bool`, *optional*, defaults to `True`):
            Whether to use zero point quantization.
        version (`AWQLinearVersion`, *optional*, defaults to `AWQLinearVersion.GEMM`):
            The version of the quantization algorithm to use. GEMM is better for big batch_size (e.g. >= 8); otherwise,
            GEMV is better (e.g. < 8). GEMM models are compatible with Exllama kernels.
        backend (`AwqBackendPackingMethod`, *optional*, defaults to `AwqBackendPackingMethod.AUTOAWQ`):
            The quantization backend. Some models might be quantized using the `llm-awq` backend. This is useful for users
            that quantize their own models using the `llm-awq` library.
        do_fuse (`bool`, *optional*, defaults to `False`):
            Whether to fuse attention and mlp layers together for faster inference.
        fuse_max_seq_len (`int`, *optional*):
            The maximum sequence length to generate when using fusing.
        modules_to_fuse (`dict`, *optional*, defaults to `None`):
            Overwrite the natively supported fusing scheme with the one specified by the users.
        modules_to_not_convert (`list`, *optional*, defaults to `None`):
            The list of modules to not quantize, useful for quantizing models that explicitly require to have
            some modules left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers).
            Note you cannot quantize directly with transformers, please refer to the `AutoAWQ` documentation for quantizing HF models.
        exllama_config (`Dict[str, Any]`, *optional*):
            You can specify the version of the exllama kernel through the `version` key, the maximum sequence
            length through the `max_input_len` key, and the maximum batch size through the `max_batch_size` key.
            Defaults to `{"version": 2, "max_input_len": 2048, "max_batch_size": 8}` if unset.
    """

    def __init__(
        self,
        bits: int = 4,
        group_size: int = 128,
        zero_point: bool = True,
        version: AWQLinearVersion = AWQLinearVersion.GEMM,
        backend: AwqBackendPackingMethod = AwqBackendPackingMethod.AUTOAWQ,
        do_fuse: Optional[bool] = None,
        fuse_max_seq_len: Optional[int] = None,
        modules_to_fuse: Optional[dict] = None,
        modules_to_not_convert: Optional[List] = None,
        exllama_config: Optional[Dict[str, int]] = None,
        **kwargs,
    ):
        self.quant_method = QuantizationMethod.AWQ

        self.bits = bits
        self.group_size = group_size
        self.zero_point = zero_point
        self.version = version
        self.backend = backend
        self.fuse_max_seq_len = fuse_max_seq_len
        self.modules_to_not_convert = modules_to_not_convert
        self.exllama_config = exllama_config

        self.modules_to_fuse = modules_to_fuse
        if do_fuse is None:
            self.do_fuse = modules_to_fuse is not None and len(modules_to_fuse) > 0
        else:
            self.do_fuse = do_fuse
        self.fuse_max_seq_len = fuse_max_seq_len

        self.post_init()

    def post_init(self):
        r"""
        Safety checker that arguments are correct
        """
        if self.backend not in [AwqBackendPackingMethod.AUTOAWQ, AwqBackendPackingMethod.LLMAWQ]:
            raise ValueError(
                f"Only supported quantization backends in {AwqBackendPackingMethod.AUTOAWQ} and {AwqBackendPackingMethod.LLMAWQ} - not recognized backend {self.backend}"
            )

        self.version = AWQLinearVersion.from_str(self.version)
        if self.version not in [
            AWQLinearVersion.GEMM,
            AWQLinearVersion.GEMV,
            AWQLinearVersion.EXLLAMA,
            AWQLinearVersion.IPEX,
        ]:
            raise ValueError(
                f"Only supported versions are in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV, AWQLinearVersion.EXLLAMA, AWQLinearVersion.IPEX] - not recognized version {self.version}"
            )

        if self.backend == AwqBackendPackingMethod.LLMAWQ:
            compute_capability = torch.cuda.get_device_capability()
            major, minor = compute_capability
            if major < 8:
                raise ValueError("LLM-AWQ backend is only supported on GPUs with compute capability >= 8.0")

        if self.do_fuse and self.fuse_max_seq_len is None:
            raise ValueError(
                "You cannot enable fused modules without specifying a `fuse_max_seq_len`, make sure to pass a valid `fuse_max_seq_len` for your usecase"
            )

        if self.do_fuse:
            awq_version_supports_fusing = False
            MIN_AWQ_VERSION = "0.1.7"
            if is_auto_awq_available():
                awq_version_supports_fusing = version.parse(importlib.metadata.version("autoawq")) >= version.parse(
                    MIN_AWQ_VERSION
                )

            if not awq_version_supports_fusing:
                raise ValueError(
                    f"Your current version of `autoawq` does not support module fusing, please upgrade `autoawq` package to at least {MIN_AWQ_VERSION}."
                )

        if self.modules_to_not_convert is not None:
            awq_version_supports_non_conversion = False
            MIN_AWQ_VERSION = "0.1.8"
            if is_auto_awq_available():
                awq_version_supports_non_conversion = version.parse(
                    importlib.metadata.version("autoawq")
                ) >= version.parse(MIN_AWQ_VERSION)

            if not awq_version_supports_non_conversion:
                raise ValueError(
                    f"Your current version of `autoawq` does not support module quantization skipping, please upgrade `autoawq` package to at least {MIN_AWQ_VERSION}."
                )

        if self.do_fuse and self.modules_to_fuse is not None:
            required_keys = [
                "hidden_size",
                "num_attention_heads",
                "num_key_value_heads",
                "mlp",
                "attention",
                "layernorm",
                "use_alibi",
            ]
            if not all(key in self.modules_to_fuse for key in required_keys):
                raise ValueError(
                    f"Required fields are missing in the fusing mapping, required fields are {required_keys}"
                )

        if self.version == AWQLinearVersion.EXLLAMA:
            awq_version_supports_exllama = False
            MIN_AWQ_VERSION = "0.2.0"
            if is_auto_awq_available():
                awq_version_supports_exllama = version.parse(importlib.metadata.version("autoawq")) >= version.parse(
                    MIN_AWQ_VERSION
                )

            if not awq_version_supports_exllama:
                raise ValueError(
                    f"Your current version of `autoawq` does not support the exllama backend, "
                    f"please upgrade `autoawq` package to at least {MIN_AWQ_VERSION}."
                )

            if self.exllama_config is None:
                self.exllama_config = {"version": ExllamaVersion.TWO, "max_input_len": 2048, "max_batch_size": 8}
            else:
                if "version" not in self.exllama_config:
                    raise ValueError("`exllama_config` needs to have a `version` key.")
                elif self.exllama_config["version"] not in [ExllamaVersion.ONE, ExllamaVersion.TWO]:
                    exllama_version = self.exllama_config["version"]
                    raise ValueError(
                        f"Only supported versions are in [ExllamaVersion.ONE, ExllamaVersion.TWO] - not recognized version {exllama_version}"
                    )

    def get_loading_attributes(self):
        attibutes_dict = copy.deepcopy(self.__dict__)
        loading_attibutes = ["version", "do_fuse", "modules_to_fuse", "fuse_max_seq_len", "exllama_config"]
        loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes}
        return loading_attibutes_dict
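
# NOTE (illustrative sketch added for this review, not part of the original transformers source):
# the GEMM default passes the safety checks above without autoawq installed; fused modules would
# additionally require `fuse_max_seq_len`. The helper name is hypothetical.
def _example_awq_config():
    config = AwqConfig(bits=4, group_size=128, zero_point=True, version="gemm")
    # post_init() normalizes the string into the AWQLinearVersion enum
    assert config.version == AWQLinearVersion.GEMM
    # only these keys are forwarded to the weight-loading code path
    assert set(config.get_loading_attributes()) == {
        "version", "do_fuse", "modules_to_fuse", "fuse_max_seq_len", "exllama_config"
    }
    return config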


@dataclass
class AqlmConfig(QuantizationConfigMixin):
    """
    This is a wrapper class about `aqlm` parameters.

    Args:
        in_group_size (`int`, *optional*, defaults to 8):
            The group size along the input dimension.
        out_group_size (`int`, *optional*, defaults to 1):
            The group size along the output dimension. It's recommended to always use 1.
        num_codebooks (`int`, *optional*, defaults to 1):
            Number of codebooks for the Additive Quantization procedure.
        nbits_per_codebook (`int`, *optional*, defaults to 16):
            Number of bits encoding a single codebook vector. Codebooks size is 2**nbits_per_codebook.
        linear_weights_not_to_quantize (`Optional[List[str]]`, *optional*):
            List of full paths of `nn.Linear` weight parameters that shall not be quantized.
        kwargs (`Dict[str, Any]`, *optional*):
            Additional parameters from which to initialize the configuration object.
    """

    def __init__(
        self,
        in_group_size: int = 8,
        out_group_size: int = 1,
        num_codebooks: int = 1,
        nbits_per_codebook: int = 16,
        linear_weights_not_to_quantize: Optional[List[str]] = None,
        **kwargs,
    ):
        self.quant_method = QuantizationMethod.AQLM
        self.in_group_size = in_group_size
        self.out_group_size = out_group_size
        self.num_codebooks = num_codebooks
        self.nbits_per_codebook = nbits_per_codebook
        self.linear_weights_not_to_quantize = linear_weights_not_to_quantize

        self.post_init()

    def post_init(self):
        r"""
        Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
        """
        if not isinstance(self.in_group_size, int):
            raise TypeError("in_group_size must be an int")
        if not isinstance(self.out_group_size, int):
            raise TypeError("out_group_size must be an int")
        if not isinstance(self.num_codebooks, int):
            raise TypeError("num_codebooks must be an int")
        if not isinstance(self.nbits_per_codebook, int):
            raise TypeError("nbits_per_codebook must be an int")

        if self.linear_weights_not_to_quantize is not None and not isinstance(
            self.linear_weights_not_to_quantize, list
        ):
            raise ValueError("linear_weights_not_to_quantize must be a list of strings")

        if self.linear_weights_not_to_quantize is None:
            self.linear_weights_not_to_quantize = []
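
# NOTE (illustrative sketch added for this review, not part of the original transformers source):
# the defaults above describe AQLM's usual "1x16" scheme (one codebook with 16-bit codes over
# groups of 8 input channels). The helper name is hypothetical.
def _example_aqlm_config():
    config = AqlmConfig(in_group_size=8, num_codebooks=1, nbits_per_codebook=16)
    # post_init() replaces the unset skip-list with an empty list
    assert config.linear_weights_not_to_quantize == []
    return config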


@dataclass
class QuantoConfig(QuantizationConfigMixin):
    """
    This is a wrapper class about all possible attributes and features that you can play with a model that has been
    loaded using `quanto`.

    Args:
        weights (`str`, *optional*, defaults to `"int8"`):
            The target dtype for the weights after quantization. Supported values are ("float8","int8","int4","int2").
        activations (`str`, *optional*):
            The target dtype for the activations after quantization. Supported values are (None,"int8","float8").
        modules_to_not_convert (`list`, *optional*, defaults to `None`):
            The list of modules to not quantize, useful for quantizing models that explicitly require to have
            some modules left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers).
    """

    def __init__(
        self,
        weights="int8",
        activations=None,
        modules_to_not_convert: Optional[List] = None,
        **kwargs,
    ):
        self.quant_method = QuantizationMethod.QUANTO
        self.weights = weights
        self.activations = activations
        self.modules_to_not_convert = modules_to_not_convert
        self.post_init()

    def post_init(self):
        r"""
        Safety checker that arguments are correct
        """
        accepted_weights = ["float8", "int8", "int4", "int2"]
        accepted_activations = [None, "int8", "float8"]
        if self.weights not in accepted_weights:
            raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights}")
        if self.activations not in accepted_activations:
            raise ValueError(f"Only support activations in {accepted_activations} but found {self.activations}")


@dataclass
class EetqConfig(QuantizationConfigMixin):
    """
    This is a wrapper class about all possible attributes and features that you can play with a model that has been
    loaded using `eetq`.

    Args:
        weights (`str`, *optional*, defaults to `"int8"`):
            The target dtype for the weights. The only supported value is "int8".
        modules_to_not_convert (`list`, *optional*, defaults to `None`):
            The list of modules to not quantize, useful for quantizing models that explicitly require to have
            some modules left in their original precision.
    """

    def __init__(
        self,
        weights: str = "int8",
        modules_to_not_convert: Optional[List] = None,
        **kwargs,
    ):
        self.quant_method = QuantizationMethod.EETQ
        self.weights = weights
        self.modules_to_not_convert = modules_to_not_convert
        self.post_init()

    def post_init(self):
        r"""
        Safety checker that arguments are correct
        """
        accepted_weights = ["int8"]
        if self.weights not in accepted_weights:
            raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights}")


class CompressedTensorsConfig(QuantizationConfigMixin):
    """
    This is a wrapper class that handles compressed-tensors quantization config options.
    It is a wrapper around `compressed_tensors.QuantizationConfig`.

    Args:
        config_groups (`typing.Dict[str, typing.Union[ForwardRef('QuantizationScheme'), typing.List[str]]]`, *optional*):
            dictionary mapping group name to a quantization scheme definition
        format (`str`, *optional*, defaults to `"dense"`):
            format the model is represented as
        quantization_status (`QuantizationStatus`, *optional*, defaults to `"initialized"`):
            status of model in the quantization lifecycle, ie 'initialized', 'calibration', 'frozen'
        kv_cache_scheme (`typing.Union[QuantizationArgs, NoneType]`, *optional*):
            specifies quantization of the kv cache. If None, kv cache is not quantized.
        global_compression_ratio (`typing.Union[float, NoneType]`, *optional*):
            0-1 float percentage of model compression
        ignore (`typing.Union[typing.List[str], NoneType]`, *optional*):
            layer names or types to not quantize, supports regex prefixed by 're:'
        sparsity_config (`typing.Dict[str, typing.Any]`, *optional*):
            configuration for sparsity compression
        quant_method (`str`, *optional*, defaults to `"compressed-tensors"`):
            do not override, should be compressed-tensors
    """

    def __init__(
        self,
        config_groups: Dict[str, Union["QuantizationScheme", List[str]]] = None,  # noqa: F821
        format: str = "dense",
        quantization_status: "QuantizationStatus" = "initialized",  # noqa: F821
        kv_cache_scheme: Optional["QuantizationArgs"] = None,  # noqa: F821
        global_compression_ratio: Optional[float] = None,
        ignore: Optional[List[str]] = None,
        sparsity_config: Dict[str, Any] = None,
        quant_method: str = "compressed-tensors",
        **kwargs,
    ):
        from compressed_tensors import QuantizationConfig
        from compressed_tensors.config import SparsityCompressionConfig

        self.quantization_config = None
        self.sparsity_config = None

        # parse from dict to load nested QuantizationScheme objects
        if config_groups or kv_cache_scheme:
            self.quantization_config = QuantizationConfig.parse_obj(
                {
                    "config_groups": config_groups,
                    "quant_method": quant_method,
                    "format": format,
                    "quantization_status": quantization_status,
                    "kv_cache_scheme": kv_cache_scheme,
                    "global_compression_ratio": global_compression_ratio,
                    "ignore": ignore,
                    **kwargs,
                }
            )

        if sparsity_config:
            self.sparsity_config = SparsityCompressionConfig.load_from_registry(
                sparsity_config.get("format"), **sparsity_config
            )

        super().__init__(quant_method=QuantizationMethod.COMPRESSED_TENSORS)

    @classmethod
    def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
        """
        Instantiates a [`CompressedTensorsConfig`] from a Python dictionary of parameters.
        Optionally unwraps any args from the nested quantization_config

        Args:
            config_dict (`Dict[str, Any]`):
                Dictionary that will be used to instantiate the configuration object.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                Whether or not to return a list of unused keyword arguments. Used for `from_pretrained` method in
                `PreTrainedModel`.
            kwargs (`Dict[str, Any]`):
                Additional parameters from which to initialize the configuration object.

        Returns:
            [`QuantizationConfigMixin`]: The configuration object instantiated from those parameters.
        """

        if "quantization_config" in config_dict:
            config_dict = dict(
                sparsity_config=config_dict.get("sparsity_config"),
                **config_dict["quantization_config"],
            )

        return super().from_dict(config_dict, return_unused_kwargs=return_unused_kwargs, **kwargs)

    def to_dict(self) -> Dict[str, Any]:
        """
        Quantization config to be added to config.json

        Serializes this instance to a Python dictionary. Returns:
            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        quantization_config = {}
        if self.quantization_config is not None:
            quantization_config = self.quantization_config.dict()
        else:
            quantization_config["quant_method"] = QuantizationMethod.COMPRESSED_TENSORS

        if self.sparsity_config is not None:
            quantization_config["sparsity_config"] = self.sparsity_config.dict()
        else:
            quantization_config["sparsity_config"] = {}

        return quantization_config

    def to_diff_dict(self) -> Dict[str, Any]:
        """
        Removes all attributes from config which correspond to the default config attributes for better readability and
        serializes to a Python dictionary.

        Returns:
            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        config_dict = self.to_dict()

        # get the default config dict
        default_config_dict = CompressedTensorsConfig().to_dict()

        serializable_config_dict = {}

        # only serialize values that differ from the default config
        for key, value in config_dict.items():
            if value != default_config_dict[key]:
                serializable_config_dict[key] = value

        return serializable_config_dict
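
# NOTE (illustrative sketch added for this review, not part of the original transformers source):
# `from_dict` above unwraps the nested "quantization_config" entry found in a model's config.json.
# Constructing the object requires the `compressed-tensors` package to be installed; the dict
# below is a hypothetical minimal example, not taken from a real checkpoint.
def _example_compressed_tensors_from_dict():
    raw = {
        "quantization_config": {"quant_method": "compressed-tensors", "format": "dense"},
        "sparsity_config": None,
    }
    return CompressedTensorsConfig.from_dict(raw)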


@dataclass
class FbgemmFp8Config(QuantizationConfigMixin):
    """
    This is a wrapper class about all possible attributes and features that you can play with a model that has been
    loaded using fbgemm fp8 quantization.

    Args:
        activation_scale_ub (`float`, *optional*, defaults to 1200.0):
            The activation scale upper bound. This is used when quantizing the input activation.
        modules_to_not_convert (`list`, *optional*, defaults to `None`):
            The list of modules to not quantize, useful for quantizing models that explicitly require to have
            some modules left in their original precision.
    """

    def __init__(
        self,
        activation_scale_ub: float = 1200.0,
        modules_to_not_convert: Optional[List] = None,
        **kwargs,
    ):
        self.quant_method = QuantizationMethod.FBGEMM_FP8
        self.activation_scale_ub = activation_scale_ub
        self.modules_to_not_convert = modules_to_not_convert

    def get_loading_attributes(self):
        attibutes_dict = copy.deepcopy(self.__dict__)
        loading_attibutes = ["activation_scale_ub"]
        loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes}
        return loading_attibutes_dict
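
# NOTE (illustrative sketch added for this review, not part of the original transformers source):
# only `activation_scale_ub` is forwarded by `get_loading_attributes()`; `modules_to_not_convert`
# is used earlier, when modules are converted. The helper name is hypothetical.
def _example_fbgemm_fp8_config():
    config = FbgemmFp8Config(activation_scale_ub=1200.0, modules_to_not_convert=["lm_head"])
    assert config.get_loading_attributes() == {"activation_scale_ub": 1200.0}
    return config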


@dataclass
class TorchAoConfig(QuantizationConfigMixin):
    """This is a config class for torchao quantization/sparsity techniques.

    Args:
        quant_type (`str`):
            The type of quantization we want to use, currently supporting: `int4_weight_only`, `int8_weight_only` and `int8_dynamic_activation_int8_weight`.
        modules_to_not_convert (`list`, *optional*, defaults to `None`):
            The list of modules to not quantize, useful for quantizing models that explicitly require to have
            some modules left in their original precision.
        kwargs (`Dict[str, Any]`, *optional*):
            The keyword arguments for the chosen type of quantization, for example, int4_weight_only quantization supports two keyword arguments
            `group_size` and `inner_k_tiles` currently. More API examples and documentation of arguments can be found in
            https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques

    Example:

    ```python
    quantization_config = TorchAoConfig("int4_weight_only", group_size=32)
    # int4_weight_only quant is only working with *torch.bfloat16* dtype right now
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
    ```
    """

    def __init__(self, quant_type: str, modules_to_not_convert: Optional[List] = None, **kwargs):
        self.quant_method = QuantizationMethod.TORCHAO
        self.quant_type = quant_type
        self.modules_to_not_convert = modules_to_not_convert
        # when we load from a serialized config, "quant_type_kwargs" will be the key
        if "quant_type_kwargs" in kwargs:
            self.quant_type_kwargs = kwargs["quant_type_kwargs"]
        else:
            self.quant_type_kwargs = kwargs

        self.post_init()

    def post_init(self):
        r"""
        Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
        """
        if is_torchao_available():
            if not version.parse(importlib.metadata.version("torchao")) >= version.parse("0.4.0"):
                raise ValueError("Requires torchao 0.4.0 version and above")
        else:
            raise ValueError(
                "TorchAoConfig requires torchao to be installed, please install with `pip install torchao`"
            )

        _STR_TO_METHOD = self._get_torchao_quant_type_to_method()
        if self.quant_type not in _STR_TO_METHOD.keys():
            raise ValueError(
                f"Requested quantization type: {self.quant_type} is not supported yet, please add support in TorchAoConfig and TorchAoHfQuantizer."
            )

        method = _STR_TO_METHOD[self.quant_type]
        sig = signature(method)
        all_kwargs = [
            param.name
            for param in sig.parameters.values()
            if param.kind in [Parameter.KEYWORD_ONLY, Parameter.POSITIONAL_OR_KEYWORD]
        ]
        for k in self.quant_type_kwargs:
            if k not in all_kwargs:
                raise ValueError(
                    f"Unexpected keyword arg: {k} for API: {method}, accepted keyword args are: {all_kwargs}"
                )

    def _get_torchao_quant_type_to_method(self):
        if is_torchao_available():
            from torchao.quantization import (
                int4_weight_only,
                int8_dynamic_activation_int8_weight,
                int8_weight_only,
            )

            return {
                "int4_weight_only": int4_weight_only,
                "int8_weight_only": int8_weight_only,
                "int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight,
            }
        else:
            raise ValueError(
                "TorchAoConfig requires torchao to be installed, please install with `pip install torchao`"
            )

    def get_apply_tensor_subclass(self):
        _STR_TO_METHOD = self._get_torchao_quant_type_to_method()
        return _STR_TO_METHOD[self.quant_type](**self.quant_type_kwargs)

    def __repr__(self):
        config_dict = self.to_dict()
        return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"


@dataclass
class BitNetConfig(QuantizationConfigMixin):
    def __init__(
        self,
        modules_to_not_convert: Optional[List] = None,
        **kwargs,
    ):
        self.quant_method = QuantizationMethod.BITNET
        self.modules_to_not_convert = modules_to_not_convert
        self.post_init()

    def post_init(self):
        r"""
        Safety checker that arguments are correct
        """
        pass
.venv/Lib/site-packages/transformers/utils/sentencepiece_model_pb2.py
ADDED
@@ -0,0 +1,1511 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: sentencepiece_model.proto

# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database


# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()
|
26 |
+
|
27 |
+
|
28 |
+
DESCRIPTOR = _descriptor.FileDescriptor(
|
29 |
+
name="sentencepiece_model.proto",
|
30 |
+
package="sentencepiece",
|
31 |
+
syntax="proto2",
|
32 |
+
serialized_options=b"H\003",
|
33 |
+
create_key=_descriptor._internal_create_key,
|
34 |
+
serialized_pb=(
|
35 |
+
b'\n\x19sentencepiece_model.proto\x12\rsentencepiece"\xa1\n\n\x0bTrainerSpec\x12\r\n\x05input\x18\x01'
|
36 |
+
b" \x03(\t\x12\x14\n\x0cinput_format\x18\x07 \x01(\t\x12\x14\n\x0cmodel_prefix\x18\x02"
|
37 |
+
b" \x01(\t\x12\x41\n\nmodel_type\x18\x03"
|
38 |
+
b" \x01(\x0e\x32$.sentencepiece.TrainerSpec.ModelType:\x07UNIGRAM\x12\x18\n\nvocab_size\x18\x04"
|
39 |
+
b" \x01(\x05:\x04\x38\x30\x30\x30\x12\x17\n\x0f\x61\x63\x63\x65pt_language\x18\x05 \x03(\t\x12"
|
40 |
+
b' \n\x15self_test_sample_size\x18\x06 \x01(\x05:\x01\x30\x12"\n\x12\x63haracter_coverage\x18\n'
|
41 |
+
b" \x01(\x02:\x06\x30.9995\x12\x1e\n\x13input_sentence_size\x18\x0b"
|
42 |
+
b" \x01(\x04:\x01\x30\x12$\n\x16shuffle_input_sentence\x18\x13 \x01(\x08:\x04true\x12"
|
43 |
+
b' \n\x14mining_sentence_size\x18\x0c \x01(\x05\x42\x02\x18\x01\x12"\n\x16training_sentence_size\x18\r'
|
44 |
+
b" \x01(\x05\x42\x02\x18\x01\x12(\n\x17seed_sentencepiece_size\x18\x0e"
|
45 |
+
b" \x01(\x05:\x07\x31\x30\x30\x30\x30\x30\x30\x12\x1e\n\x10shrinking_factor\x18\x0f"
|
46 |
+
b" \x01(\x02:\x04\x30.75\x12!\n\x13max_sentence_length\x18\x12"
|
47 |
+
b" \x01(\x05:\x04\x34\x31\x39\x32\x12\x17\n\x0bnum_threads\x18\x10"
|
48 |
+
b" \x01(\x05:\x02\x31\x36\x12\x1d\n\x12num_sub_iterations\x18\x11"
|
49 |
+
b" \x01(\x05:\x01\x32\x12$\n\x18max_sentencepiece_length\x18\x14"
|
50 |
+
b" \x01(\x05:\x02\x31\x36\x12%\n\x17split_by_unicode_script\x18\x15"
|
51 |
+
b" \x01(\x08:\x04true\x12\x1d\n\x0fsplit_by_number\x18\x17"
|
52 |
+
b" \x01(\x08:\x04true\x12!\n\x13split_by_whitespace\x18\x16"
|
53 |
+
b" \x01(\x08:\x04true\x12)\n\x1atreat_whitespace_as_suffix\x18\x18"
|
54 |
+
b" \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x0csplit_digits\x18\x19"
|
55 |
+
b" \x01(\x08:\x05\x66\x61lse\x12\x17\n\x0f\x63ontrol_symbols\x18\x1e"
|
56 |
+
b" \x03(\t\x12\x1c\n\x14user_defined_symbols\x18\x1f \x03(\t\x12\x16\n\x0erequired_chars\x18$"
|
57 |
+
b" \x01(\t\x12\x1c\n\rbyte_fallback\x18# \x01(\x08:\x05\x66\x61lse\x12+\n\x1dvocabulary_output_piece_score\x18"
|
58 |
+
b' \x01(\x08:\x04true\x12\x1e\n\x10hard_vocab_limit\x18! \x01(\x08:\x04true\x12\x1c\n\ruse_all_vocab\x18"'
|
59 |
+
b" \x01(\x08:\x05\x66\x61lse\x12\x11\n\x06unk_id\x18( \x01(\x05:\x01\x30\x12\x11\n\x06\x62os_id\x18)"
|
60 |
+
b" \x01(\x05:\x01\x31\x12\x11\n\x06\x65os_id\x18* \x01(\x05:\x01\x32\x12\x12\n\x06pad_id\x18+"
|
61 |
+
b" \x01(\x05:\x02-1\x12\x18\n\tunk_piece\x18- \x01(\t:\x05<unk>\x12\x16\n\tbos_piece\x18."
|
62 |
+
b" \x01(\t:\x03<s>\x12\x17\n\teos_piece\x18/ \x01(\t:\x04</s>\x12\x18\n\tpad_piece\x18\x30"
|
63 |
+
b" \x01(\t:\x05<pad>\x12\x1a\n\x0bunk_surface\x18, \x01(\t:\x05 \xe2\x81\x87"
|
64 |
+
b" \x12+\n\x1ctrain_extremely_large_corpus\x18\x31"
|
65 |
+
b' \x01(\x08:\x05\x66\x61lse"5\n\tModelType\x12\x0b\n\x07UNIGRAM\x10\x01\x12\x07\n\x03\x42PE\x10\x02\x12\x08\n\x04WORD\x10\x03\x12\x08\n\x04\x43HAR\x10\x04*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"\xd1\x01\n\x0eNormalizerSpec\x12\x0c\n\x04name\x18\x01'
|
66 |
+
b" \x01(\t\x12\x1c\n\x14precompiled_charsmap\x18\x02 \x01(\x0c\x12\x1e\n\x10\x61\x64\x64_dummy_prefix\x18\x03"
|
67 |
+
b" \x01(\x08:\x04true\x12&\n\x18remove_extra_whitespaces\x18\x04 \x01(\x08:\x04true\x12"
|
68 |
+
b" \n\x12\x65scape_whitespaces\x18\x05 \x01(\x08:\x04true\x12\x1e\n\x16normalization_rule_tsv\x18\x06"
|
69 |
+
b' \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"y\n\x0cSelfTestData\x12\x33\n\x07samples\x18\x01'
|
70 |
+
b' \x03(\x0b\x32".sentencepiece.SelfTestData.Sample\x1a)\n\x06Sample\x12\r\n\x05input\x18\x01'
|
71 |
+
b" \x01(\t\x12\x10\n\x08\x65xpected\x18\x02"
|
72 |
+
b' \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"\xfe\x03\n\nModelProto\x12\x37\n\x06pieces\x18\x01'
|
73 |
+
b" \x03(\x0b\x32'.sentencepiece.ModelProto.SentencePiece\x12\x30\n\x0ctrainer_spec\x18\x02"
|
74 |
+
b" \x01(\x0b\x32\x1a.sentencepiece.TrainerSpec\x12\x36\n\x0fnormalizer_spec\x18\x03"
|
75 |
+
b" \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x12\x33\n\x0eself_test_data\x18\x04"
|
76 |
+
b" \x01(\x0b\x32\x1b.sentencepiece.SelfTestData\x12\x38\n\x11\x64\x65normalizer_spec\x18\x05"
|
77 |
+
b" \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x1a\xd2\x01\n\rSentencePiece\x12\r\n\x05piece\x18\x01"
|
78 |
+
b" \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x02\x12\x42\n\x04type\x18\x03"
|
79 |
+
b' \x01(\x0e\x32,.sentencepiece.ModelProto.SentencePiece.Type:\x06NORMAL"T\n\x04Type\x12\n\n\x06NORMAL\x10\x01\x12\x0b\n\x07UNKNOWN\x10\x02\x12\x0b\n\x07\x43ONTROL\x10\x03\x12\x10\n\x0cUSER_DEFINED\x10\x04\x12\x08\n\x04\x42YTE\x10\x06\x12\n\n\x06UNUSED\x10\x05*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\x42\x02H\x03'
|
80 |
+
),
|
81 |
+
)
|
82 |
+
|
83 |
+
|
84 |
+
_TRAINERSPEC_MODELTYPE = _descriptor.EnumDescriptor(
|
85 |
+
name="ModelType",
|
86 |
+
full_name="sentencepiece.TrainerSpec.ModelType",
|
87 |
+
filename=None,
|
88 |
+
file=DESCRIPTOR,
|
89 |
+
create_key=_descriptor._internal_create_key,
|
90 |
+
values=[
|
91 |
+
_descriptor.EnumValueDescriptor(
|
92 |
+
name="UNIGRAM",
|
93 |
+
index=0,
|
94 |
+
number=1,
|
95 |
+
serialized_options=None,
|
96 |
+
type=None,
|
97 |
+
create_key=_descriptor._internal_create_key,
|
98 |
+
),
|
99 |
+
_descriptor.EnumValueDescriptor(
|
100 |
+
name="BPE",
|
101 |
+
index=1,
|
102 |
+
number=2,
|
103 |
+
serialized_options=None,
|
104 |
+
type=None,
|
105 |
+
create_key=_descriptor._internal_create_key,
|
106 |
+
),
|
107 |
+
_descriptor.EnumValueDescriptor(
|
108 |
+
name="WORD",
|
109 |
+
index=2,
|
110 |
+
number=3,
|
111 |
+
serialized_options=None,
|
112 |
+
type=None,
|
113 |
+
create_key=_descriptor._internal_create_key,
|
114 |
+
),
|
115 |
+
_descriptor.EnumValueDescriptor(
|
116 |
+
name="CHAR",
|
117 |
+
index=3,
|
118 |
+
number=4,
|
119 |
+
serialized_options=None,
|
120 |
+
type=None,
|
121 |
+
create_key=_descriptor._internal_create_key,
|
122 |
+
),
|
123 |
+
],
|
124 |
+
containing_type=None,
|
125 |
+
serialized_options=None,
|
126 |
+
serialized_start=1294,
|
127 |
+
serialized_end=1347,
|
128 |
+
)
|
129 |
+
_sym_db.RegisterEnumDescriptor(_TRAINERSPEC_MODELTYPE)
|
130 |
+
|
131 |
+
_MODELPROTO_SENTENCEPIECE_TYPE = _descriptor.EnumDescriptor(
|
132 |
+
name="Type",
|
133 |
+
full_name="sentencepiece.ModelProto.SentencePiece.Type",
|
134 |
+
filename=None,
|
135 |
+
file=DESCRIPTOR,
|
136 |
+
create_key=_descriptor._internal_create_key,
|
137 |
+
values=[
|
138 |
+
_descriptor.EnumValueDescriptor(
|
139 |
+
name="NORMAL",
|
140 |
+
index=0,
|
141 |
+
number=1,
|
142 |
+
serialized_options=None,
|
143 |
+
type=None,
|
144 |
+
create_key=_descriptor._internal_create_key,
|
145 |
+
),
|
146 |
+
_descriptor.EnumValueDescriptor(
|
147 |
+
name="UNKNOWN",
|
148 |
+
index=1,
|
149 |
+
number=2,
|
150 |
+
serialized_options=None,
|
151 |
+
type=None,
|
152 |
+
create_key=_descriptor._internal_create_key,
|
153 |
+
),
|
154 |
+
_descriptor.EnumValueDescriptor(
|
155 |
+
name="CONTROL",
|
156 |
+
index=2,
|
157 |
+
number=3,
|
158 |
+
serialized_options=None,
|
159 |
+
type=None,
|
160 |
+
create_key=_descriptor._internal_create_key,
|
161 |
+
),
|
162 |
+
_descriptor.EnumValueDescriptor(
|
163 |
+
name="USER_DEFINED",
|
164 |
+
index=3,
|
165 |
+
number=4,
|
166 |
+
serialized_options=None,
|
167 |
+
type=None,
|
168 |
+
create_key=_descriptor._internal_create_key,
|
169 |
+
),
|
170 |
+
_descriptor.EnumValueDescriptor(
|
171 |
+
name="BYTE",
|
172 |
+
index=4,
|
173 |
+
number=6,
|
174 |
+
serialized_options=None,
|
175 |
+
type=None,
|
176 |
+
create_key=_descriptor._internal_create_key,
|
177 |
+
),
|
178 |
+
_descriptor.EnumValueDescriptor(
|
179 |
+
name="UNUSED",
|
180 |
+
index=5,
|
181 |
+
number=5,
|
182 |
+
serialized_options=None,
|
183 |
+
type=None,
|
184 |
+
create_key=_descriptor._internal_create_key,
|
185 |
+
),
|
186 |
+
],
|
187 |
+
containing_type=None,
|
188 |
+
serialized_options=None,
|
189 |
+
serialized_start=2100,
|
190 |
+
serialized_end=2184,
|
191 |
+
)
|
192 |
+
_sym_db.RegisterEnumDescriptor(_MODELPROTO_SENTENCEPIECE_TYPE)
|
193 |
+
|
194 |
+
|
195 |
+
_TRAINERSPEC = _descriptor.Descriptor(
|
196 |
+
name="TrainerSpec",
|
197 |
+
full_name="sentencepiece.TrainerSpec",
|
198 |
+
filename=None,
|
199 |
+
file=DESCRIPTOR,
|
200 |
+
containing_type=None,
|
201 |
+
create_key=_descriptor._internal_create_key,
|
202 |
+
fields=[
|
203 |
+
_descriptor.FieldDescriptor(
|
204 |
+
name="input",
|
205 |
+
full_name="sentencepiece.TrainerSpec.input",
|
206 |
+
index=0,
|
207 |
+
number=1,
|
208 |
+
type=9,
|
209 |
+
cpp_type=9,
|
210 |
+
label=3,
|
211 |
+
has_default_value=False,
|
212 |
+
default_value=[],
|
213 |
+
message_type=None,
|
214 |
+
enum_type=None,
|
215 |
+
containing_type=None,
|
216 |
+
is_extension=False,
|
217 |
+
extension_scope=None,
|
218 |
+
serialized_options=None,
|
219 |
+
file=DESCRIPTOR,
|
220 |
+
create_key=_descriptor._internal_create_key,
|
221 |
+
),
|
222 |
+
_descriptor.FieldDescriptor(
|
223 |
+
name="input_format",
|
224 |
+
full_name="sentencepiece.TrainerSpec.input_format",
|
225 |
+
index=1,
|
226 |
+
number=7,
|
227 |
+
type=9,
|
228 |
+
cpp_type=9,
|
229 |
+
label=1,
|
230 |
+
has_default_value=False,
|
231 |
+
default_value=b"".decode("utf-8"),
|
232 |
+
message_type=None,
|
233 |
+
enum_type=None,
|
234 |
+
containing_type=None,
|
235 |
+
is_extension=False,
|
236 |
+
extension_scope=None,
|
237 |
+
serialized_options=None,
|
238 |
+
file=DESCRIPTOR,
|
239 |
+
create_key=_descriptor._internal_create_key,
|
240 |
+
),
|
241 |
+
_descriptor.FieldDescriptor(
|
242 |
+
name="model_prefix",
|
243 |
+
full_name="sentencepiece.TrainerSpec.model_prefix",
|
244 |
+
index=2,
|
245 |
+
number=2,
|
246 |
+
type=9,
|
247 |
+
cpp_type=9,
|
248 |
+
label=1,
|
249 |
+
has_default_value=False,
|
250 |
+
default_value=b"".decode("utf-8"),
|
251 |
+
message_type=None,
|
252 |
+
enum_type=None,
|
253 |
+
containing_type=None,
|
254 |
+
is_extension=False,
|
255 |
+
extension_scope=None,
|
256 |
+
serialized_options=None,
|
257 |
+
file=DESCRIPTOR,
|
258 |
+
create_key=_descriptor._internal_create_key,
|
259 |
+
),
|
260 |
+
_descriptor.FieldDescriptor(
|
261 |
+
name="model_type",
|
262 |
+
full_name="sentencepiece.TrainerSpec.model_type",
|
263 |
+
index=3,
|
264 |
+
number=3,
|
265 |
+
type=14,
|
266 |
+
cpp_type=8,
|
267 |
+
label=1,
|
268 |
+
has_default_value=True,
|
269 |
+
default_value=1,
|
270 |
+
message_type=None,
|
271 |
+
enum_type=None,
|
272 |
+
containing_type=None,
|
273 |
+
is_extension=False,
|
274 |
+
extension_scope=None,
|
275 |
+
serialized_options=None,
|
276 |
+
file=DESCRIPTOR,
|
277 |
+
create_key=_descriptor._internal_create_key,
|
278 |
+
),
|
279 |
+
_descriptor.FieldDescriptor(
|
280 |
+
name="vocab_size",
|
281 |
+
full_name="sentencepiece.TrainerSpec.vocab_size",
|
282 |
+
index=4,
|
283 |
+
number=4,
|
284 |
+
type=5,
|
285 |
+
cpp_type=1,
|
286 |
+
label=1,
|
287 |
+
has_default_value=True,
|
288 |
+
default_value=8000,
|
289 |
+
message_type=None,
|
290 |
+
enum_type=None,
|
291 |
+
containing_type=None,
|
292 |
+
is_extension=False,
|
293 |
+
extension_scope=None,
|
294 |
+
serialized_options=None,
|
295 |
+
file=DESCRIPTOR,
|
296 |
+
create_key=_descriptor._internal_create_key,
|
297 |
+
),
|
298 |
+
_descriptor.FieldDescriptor(
|
299 |
+
name="accept_language",
|
300 |
+
full_name="sentencepiece.TrainerSpec.accept_language",
|
301 |
+
index=5,
|
302 |
+
number=5,
|
303 |
+
type=9,
|
304 |
+
cpp_type=9,
|
305 |
+
label=3,
|
306 |
+
has_default_value=False,
|
307 |
+
default_value=[],
|
308 |
+
message_type=None,
|
309 |
+
enum_type=None,
|
310 |
+
containing_type=None,
|
311 |
+
is_extension=False,
|
312 |
+
extension_scope=None,
|
313 |
+
serialized_options=None,
|
314 |
+
file=DESCRIPTOR,
|
315 |
+
create_key=_descriptor._internal_create_key,
|
316 |
+
),
|
317 |
+
_descriptor.FieldDescriptor(
|
318 |
+
name="self_test_sample_size",
|
319 |
+
full_name="sentencepiece.TrainerSpec.self_test_sample_size",
|
320 |
+
index=6,
|
321 |
+
number=6,
|
322 |
+
type=5,
|
323 |
+
cpp_type=1,
|
324 |
+
label=1,
|
325 |
+
has_default_value=True,
|
326 |
+
default_value=0,
|
327 |
+
message_type=None,
|
328 |
+
enum_type=None,
|
329 |
+
containing_type=None,
|
330 |
+
is_extension=False,
|
331 |
+
extension_scope=None,
|
332 |
+
serialized_options=None,
|
333 |
+
file=DESCRIPTOR,
|
334 |
+
create_key=_descriptor._internal_create_key,
|
335 |
+
),
|
336 |
+
_descriptor.FieldDescriptor(
|
337 |
+
name="character_coverage",
|
338 |
+
full_name="sentencepiece.TrainerSpec.character_coverage",
|
339 |
+
index=7,
|
340 |
+
number=10,
|
341 |
+
type=2,
|
342 |
+
cpp_type=6,
|
343 |
+
label=1,
|
344 |
+
has_default_value=True,
|
345 |
+
default_value=float(0.9995),
|
346 |
+
message_type=None,
|
347 |
+
enum_type=None,
|
348 |
+
containing_type=None,
|
349 |
+
is_extension=False,
|
350 |
+
extension_scope=None,
|
351 |
+
serialized_options=None,
|
352 |
+
file=DESCRIPTOR,
|
353 |
+
create_key=_descriptor._internal_create_key,
|
354 |
+
),
|
355 |
+
_descriptor.FieldDescriptor(
|
356 |
+
name="input_sentence_size",
|
357 |
+
full_name="sentencepiece.TrainerSpec.input_sentence_size",
|
358 |
+
index=8,
|
359 |
+
number=11,
|
360 |
+
type=4,
|
361 |
+
cpp_type=4,
|
362 |
+
label=1,
|
363 |
+
has_default_value=True,
|
364 |
+
default_value=0,
|
365 |
+
message_type=None,
|
366 |
+
enum_type=None,
|
367 |
+
containing_type=None,
|
368 |
+
is_extension=False,
|
369 |
+
extension_scope=None,
|
370 |
+
serialized_options=None,
|
371 |
+
file=DESCRIPTOR,
|
372 |
+
create_key=_descriptor._internal_create_key,
|
373 |
+
),
|
374 |
+
_descriptor.FieldDescriptor(
|
375 |
+
name="shuffle_input_sentence",
|
376 |
+
full_name="sentencepiece.TrainerSpec.shuffle_input_sentence",
|
377 |
+
index=9,
|
378 |
+
number=19,
|
379 |
+
type=8,
|
380 |
+
cpp_type=7,
|
381 |
+
label=1,
|
382 |
+
has_default_value=True,
|
383 |
+
default_value=True,
|
384 |
+
message_type=None,
|
385 |
+
enum_type=None,
|
386 |
+
containing_type=None,
|
387 |
+
is_extension=False,
|
388 |
+
extension_scope=None,
|
389 |
+
serialized_options=None,
|
390 |
+
file=DESCRIPTOR,
|
391 |
+
create_key=_descriptor._internal_create_key,
|
392 |
+
),
|
393 |
+
_descriptor.FieldDescriptor(
|
394 |
+
name="mining_sentence_size",
|
395 |
+
full_name="sentencepiece.TrainerSpec.mining_sentence_size",
|
396 |
+
index=10,
|
397 |
+
number=12,
|
398 |
+
type=5,
|
399 |
+
cpp_type=1,
|
400 |
+
label=1,
|
401 |
+
has_default_value=False,
|
402 |
+
default_value=0,
|
403 |
+
message_type=None,
|
404 |
+
enum_type=None,
|
405 |
+
containing_type=None,
|
406 |
+
is_extension=False,
|
407 |
+
extension_scope=None,
|
408 |
+
serialized_options=b"\030\001",
|
409 |
+
file=DESCRIPTOR,
|
410 |
+
create_key=_descriptor._internal_create_key,
|
411 |
+
),
|
412 |
+
_descriptor.FieldDescriptor(
|
413 |
+
name="training_sentence_size",
|
414 |
+
full_name="sentencepiece.TrainerSpec.training_sentence_size",
|
415 |
+
index=11,
|
416 |
+
number=13,
|
417 |
+
type=5,
|
418 |
+
cpp_type=1,
|
419 |
+
label=1,
|
420 |
+
has_default_value=False,
|
421 |
+
default_value=0,
|
422 |
+
message_type=None,
|
423 |
+
enum_type=None,
|
424 |
+
containing_type=None,
|
425 |
+
is_extension=False,
|
426 |
+
extension_scope=None,
|
427 |
+
serialized_options=b"\030\001",
|
428 |
+
file=DESCRIPTOR,
|
429 |
+
create_key=_descriptor._internal_create_key,
|
430 |
+
),
|
431 |
+
_descriptor.FieldDescriptor(
|
432 |
+
name="seed_sentencepiece_size",
|
433 |
+
full_name="sentencepiece.TrainerSpec.seed_sentencepiece_size",
|
434 |
+
index=12,
|
435 |
+
number=14,
|
436 |
+
type=5,
|
437 |
+
cpp_type=1,
|
438 |
+
label=1,
|
439 |
+
has_default_value=True,
|
440 |
+
default_value=1000000,
|
441 |
+
message_type=None,
|
442 |
+
enum_type=None,
|
443 |
+
containing_type=None,
|
444 |
+
is_extension=False,
|
445 |
+
extension_scope=None,
|
446 |
+
serialized_options=None,
|
447 |
+
file=DESCRIPTOR,
|
448 |
+
create_key=_descriptor._internal_create_key,
|
449 |
+
),
|
450 |
+
_descriptor.FieldDescriptor(
|
451 |
+
name="shrinking_factor",
|
452 |
+
full_name="sentencepiece.TrainerSpec.shrinking_factor",
|
453 |
+
index=13,
|
454 |
+
number=15,
|
455 |
+
type=2,
|
456 |
+
cpp_type=6,
|
457 |
+
label=1,
|
458 |
+
has_default_value=True,
|
459 |
+
default_value=float(0.75),
|
460 |
+
message_type=None,
|
461 |
+
enum_type=None,
|
462 |
+
containing_type=None,
|
463 |
+
is_extension=False,
|
464 |
+
extension_scope=None,
|
465 |
+
serialized_options=None,
|
466 |
+
file=DESCRIPTOR,
|
467 |
+
create_key=_descriptor._internal_create_key,
|
468 |
+
),
|
469 |
+
_descriptor.FieldDescriptor(
|
470 |
+
name="max_sentence_length",
|
471 |
+
full_name="sentencepiece.TrainerSpec.max_sentence_length",
|
472 |
+
index=14,
|
473 |
+
number=18,
|
474 |
+
type=5,
|
475 |
+
cpp_type=1,
|
476 |
+
label=1,
|
477 |
+
has_default_value=True,
|
478 |
+
default_value=4192,
|
479 |
+
message_type=None,
|
480 |
+
enum_type=None,
|
481 |
+
containing_type=None,
|
482 |
+
is_extension=False,
|
483 |
+
extension_scope=None,
|
484 |
+
serialized_options=None,
|
485 |
+
file=DESCRIPTOR,
|
486 |
+
create_key=_descriptor._internal_create_key,
|
487 |
+
),
|
488 |
+
_descriptor.FieldDescriptor(
|
489 |
+
name="num_threads",
|
490 |
+
full_name="sentencepiece.TrainerSpec.num_threads",
|
491 |
+
index=15,
|
492 |
+
number=16,
|
493 |
+
type=5,
|
494 |
+
cpp_type=1,
|
495 |
+
label=1,
|
496 |
+
has_default_value=True,
|
497 |
+
default_value=16,
|
498 |
+
message_type=None,
|
499 |
+
enum_type=None,
|
500 |
+
containing_type=None,
|
501 |
+
is_extension=False,
|
502 |
+
extension_scope=None,
|
503 |
+
serialized_options=None,
|
504 |
+
file=DESCRIPTOR,
|
505 |
+
create_key=_descriptor._internal_create_key,
|
506 |
+
),
|
507 |
+
_descriptor.FieldDescriptor(
|
508 |
+
name="num_sub_iterations",
|
509 |
+
full_name="sentencepiece.TrainerSpec.num_sub_iterations",
|
510 |
+
index=16,
|
511 |
+
number=17,
|
512 |
+
type=5,
|
513 |
+
cpp_type=1,
|
514 |
+
label=1,
|
515 |
+
has_default_value=True,
|
516 |
+
default_value=2,
|
517 |
+
message_type=None,
|
518 |
+
enum_type=None,
|
519 |
+
containing_type=None,
|
520 |
+
is_extension=False,
|
521 |
+
extension_scope=None,
|
522 |
+
serialized_options=None,
|
523 |
+
file=DESCRIPTOR,
|
524 |
+
create_key=_descriptor._internal_create_key,
|
525 |
+
),
|
526 |
+
_descriptor.FieldDescriptor(
|
527 |
+
name="max_sentencepiece_length",
|
528 |
+
full_name="sentencepiece.TrainerSpec.max_sentencepiece_length",
|
529 |
+
index=17,
|
530 |
+
number=20,
|
531 |
+
type=5,
|
532 |
+
cpp_type=1,
|
533 |
+
label=1,
|
534 |
+
has_default_value=True,
|
535 |
+
default_value=16,
|
536 |
+
message_type=None,
|
537 |
+
enum_type=None,
|
538 |
+
containing_type=None,
|
539 |
+
is_extension=False,
|
540 |
+
extension_scope=None,
|
541 |
+
serialized_options=None,
|
542 |
+
file=DESCRIPTOR,
|
543 |
+
create_key=_descriptor._internal_create_key,
|
544 |
+
),
|
545 |
+
_descriptor.FieldDescriptor(
|
546 |
+
name="split_by_unicode_script",
|
547 |
+
full_name="sentencepiece.TrainerSpec.split_by_unicode_script",
|
548 |
+
index=18,
|
549 |
+
number=21,
|
550 |
+
type=8,
|
551 |
+
cpp_type=7,
|
552 |
+
label=1,
|
553 |
+
has_default_value=True,
|
554 |
+
default_value=True,
|
555 |
+
message_type=None,
|
556 |
+
enum_type=None,
|
557 |
+
containing_type=None,
|
558 |
+
is_extension=False,
|
559 |
+
extension_scope=None,
|
560 |
+
serialized_options=None,
|
561 |
+
file=DESCRIPTOR,
|
562 |
+
create_key=_descriptor._internal_create_key,
|
563 |
+
),
|
564 |
+
_descriptor.FieldDescriptor(
|
565 |
+
name="split_by_number",
|
566 |
+
full_name="sentencepiece.TrainerSpec.split_by_number",
|
567 |
+
index=19,
|
568 |
+
number=23,
|
569 |
+
type=8,
|
570 |
+
cpp_type=7,
|
571 |
+
label=1,
|
572 |
+
has_default_value=True,
|
573 |
+
default_value=True,
|
574 |
+
message_type=None,
|
575 |
+
enum_type=None,
|
576 |
+
containing_type=None,
|
577 |
+
is_extension=False,
|
578 |
+
extension_scope=None,
|
579 |
+
serialized_options=None,
|
580 |
+
file=DESCRIPTOR,
|
581 |
+
create_key=_descriptor._internal_create_key,
|
582 |
+
),
|
583 |
+
_descriptor.FieldDescriptor(
|
584 |
+
name="split_by_whitespace",
|
585 |
+
full_name="sentencepiece.TrainerSpec.split_by_whitespace",
|
586 |
+
index=20,
|
587 |
+
number=22,
|
588 |
+
type=8,
|
589 |
+
cpp_type=7,
|
590 |
+
label=1,
|
591 |
+
has_default_value=True,
|
592 |
+
default_value=True,
|
593 |
+
message_type=None,
|
594 |
+
enum_type=None,
|
595 |
+
containing_type=None,
|
596 |
+
is_extension=False,
|
597 |
+
extension_scope=None,
|
598 |
+
serialized_options=None,
|
599 |
+
file=DESCRIPTOR,
|
600 |
+
create_key=_descriptor._internal_create_key,
|
601 |
+
),
|
602 |
+
_descriptor.FieldDescriptor(
|
603 |
+
name="treat_whitespace_as_suffix",
|
604 |
+
full_name="sentencepiece.TrainerSpec.treat_whitespace_as_suffix",
|
605 |
+
index=21,
|
606 |
+
number=24,
|
607 |
+
type=8,
|
608 |
+
cpp_type=7,
|
609 |
+
label=1,
|
610 |
+
has_default_value=True,
|
611 |
+
default_value=False,
|
612 |
+
message_type=None,
|
613 |
+
enum_type=None,
|
614 |
+
containing_type=None,
|
615 |
+
is_extension=False,
|
616 |
+
extension_scope=None,
|
617 |
+
serialized_options=None,
|
618 |
+
file=DESCRIPTOR,
|
619 |
+
create_key=_descriptor._internal_create_key,
|
620 |
+
),
|
621 |
+
_descriptor.FieldDescriptor(
|
622 |
+
name="split_digits",
|
623 |
+
full_name="sentencepiece.TrainerSpec.split_digits",
|
624 |
+
index=22,
|
625 |
+
number=25,
|
626 |
+
type=8,
|
627 |
+
cpp_type=7,
|
628 |
+
label=1,
|
629 |
+
has_default_value=True,
|
630 |
+
default_value=False,
|
631 |
+
message_type=None,
|
632 |
+
enum_type=None,
|
633 |
+
containing_type=None,
|
634 |
+
is_extension=False,
|
635 |
+
extension_scope=None,
|
636 |
+
serialized_options=None,
|
637 |
+
file=DESCRIPTOR,
|
638 |
+
create_key=_descriptor._internal_create_key,
|
639 |
+
),
|
640 |
+
_descriptor.FieldDescriptor(
|
641 |
+
name="control_symbols",
|
642 |
+
full_name="sentencepiece.TrainerSpec.control_symbols",
|
643 |
+
index=23,
|
644 |
+
number=30,
|
645 |
+
type=9,
|
646 |
+
cpp_type=9,
|
647 |
+
label=3,
|
648 |
+
has_default_value=False,
|
649 |
+
default_value=[],
|
650 |
+
message_type=None,
|
651 |
+
enum_type=None,
|
652 |
+
containing_type=None,
|
653 |
+
is_extension=False,
|
654 |
+
extension_scope=None,
|
655 |
+
serialized_options=None,
|
656 |
+
file=DESCRIPTOR,
|
657 |
+
create_key=_descriptor._internal_create_key,
|
658 |
+
),
|
659 |
+
_descriptor.FieldDescriptor(
|
660 |
+
name="user_defined_symbols",
|
661 |
+
full_name="sentencepiece.TrainerSpec.user_defined_symbols",
|
662 |
+
index=24,
|
663 |
+
number=31,
|
664 |
+
type=9,
|
665 |
+
cpp_type=9,
|
666 |
+
label=3,
|
667 |
+
has_default_value=False,
|
668 |
+
default_value=[],
|
669 |
+
message_type=None,
|
670 |
+
enum_type=None,
|
671 |
+
containing_type=None,
|
672 |
+
is_extension=False,
|
673 |
+
extension_scope=None,
|
674 |
+
serialized_options=None,
|
675 |
+
file=DESCRIPTOR,
|
676 |
+
create_key=_descriptor._internal_create_key,
|
677 |
+
),
|
678 |
+
_descriptor.FieldDescriptor(
|
679 |
+
name="required_chars",
|
680 |
+
full_name="sentencepiece.TrainerSpec.required_chars",
|
681 |
+
index=25,
|
682 |
+
number=36,
|
683 |
+
type=9,
|
684 |
+
cpp_type=9,
|
685 |
+
label=1,
|
686 |
+
has_default_value=False,
|
687 |
+
default_value=b"".decode("utf-8"),
|
688 |
+
message_type=None,
|
689 |
+
enum_type=None,
|
690 |
+
containing_type=None,
|
691 |
+
is_extension=False,
|
692 |
+
extension_scope=None,
|
693 |
+
serialized_options=None,
|
694 |
+
file=DESCRIPTOR,
|
695 |
+
create_key=_descriptor._internal_create_key,
|
696 |
+
),
|
697 |
+
_descriptor.FieldDescriptor(
|
698 |
+
name="byte_fallback",
|
699 |
+
full_name="sentencepiece.TrainerSpec.byte_fallback",
|
700 |
+
index=26,
|
701 |
+
number=35,
|
702 |
+
type=8,
|
703 |
+
cpp_type=7,
|
704 |
+
label=1,
|
705 |
+
has_default_value=True,
|
706 |
+
default_value=False,
|
707 |
+
message_type=None,
|
708 |
+
enum_type=None,
|
709 |
+
containing_type=None,
|
710 |
+
is_extension=False,
|
711 |
+
extension_scope=None,
|
712 |
+
serialized_options=None,
|
713 |
+
file=DESCRIPTOR,
|
714 |
+
create_key=_descriptor._internal_create_key,
|
715 |
+
),
|
716 |
+
_descriptor.FieldDescriptor(
|
717 |
+
name="vocabulary_output_piece_score",
|
718 |
+
full_name="sentencepiece.TrainerSpec.vocabulary_output_piece_score",
|
719 |
+
index=27,
|
720 |
+
number=32,
|
721 |
+
type=8,
|
722 |
+
cpp_type=7,
|
723 |
+
label=1,
|
724 |
+
has_default_value=True,
|
725 |
+
default_value=True,
|
726 |
+
message_type=None,
|
727 |
+
enum_type=None,
|
728 |
+
containing_type=None,
|
729 |
+
is_extension=False,
|
730 |
+
extension_scope=None,
|
731 |
+
serialized_options=None,
|
732 |
+
file=DESCRIPTOR,
|
733 |
+
create_key=_descriptor._internal_create_key,
|
734 |
+
),
|
735 |
+
_descriptor.FieldDescriptor(
|
736 |
+
name="hard_vocab_limit",
|
737 |
+
full_name="sentencepiece.TrainerSpec.hard_vocab_limit",
|
738 |
+
index=28,
|
739 |
+
number=33,
|
740 |
+
type=8,
|
741 |
+
cpp_type=7,
|
742 |
+
label=1,
|
743 |
+
has_default_value=True,
|
744 |
+
default_value=True,
|
745 |
+
message_type=None,
|
746 |
+
enum_type=None,
|
747 |
+
containing_type=None,
|
748 |
+
is_extension=False,
|
749 |
+
extension_scope=None,
|
750 |
+
serialized_options=None,
|
751 |
+
file=DESCRIPTOR,
|
752 |
+
create_key=_descriptor._internal_create_key,
|
753 |
+
),
|
754 |
+
_descriptor.FieldDescriptor(
|
755 |
+
name="use_all_vocab",
|
756 |
+
full_name="sentencepiece.TrainerSpec.use_all_vocab",
|
757 |
+
index=29,
|
758 |
+
number=34,
|
759 |
+
type=8,
|
760 |
+
cpp_type=7,
|
761 |
+
label=1,
|
762 |
+
has_default_value=True,
|
763 |
+
default_value=False,
|
764 |
+
message_type=None,
|
765 |
+
enum_type=None,
|
766 |
+
containing_type=None,
|
767 |
+
is_extension=False,
|
768 |
+
extension_scope=None,
|
769 |
+
serialized_options=None,
|
770 |
+
file=DESCRIPTOR,
|
771 |
+
create_key=_descriptor._internal_create_key,
|
772 |
+
),
|
773 |
+
_descriptor.FieldDescriptor(
|
774 |
+
name="unk_id",
|
775 |
+
full_name="sentencepiece.TrainerSpec.unk_id",
|
776 |
+
index=30,
|
777 |
+
number=40,
|
778 |
+
type=5,
|
779 |
+
cpp_type=1,
|
780 |
+
label=1,
|
781 |
+
has_default_value=True,
|
782 |
+
default_value=0,
|
783 |
+
message_type=None,
|
784 |
+
enum_type=None,
|
785 |
+
containing_type=None,
|
786 |
+
is_extension=False,
|
787 |
+
extension_scope=None,
|
788 |
+
serialized_options=None,
|
789 |
+
file=DESCRIPTOR,
|
790 |
+
create_key=_descriptor._internal_create_key,
|
791 |
+
),
|
792 |
+
_descriptor.FieldDescriptor(
|
793 |
+
name="bos_id",
|
794 |
+
full_name="sentencepiece.TrainerSpec.bos_id",
|
795 |
+
index=31,
|
796 |
+
number=41,
|
797 |
+
type=5,
|
798 |
+
cpp_type=1,
|
799 |
+
label=1,
|
800 |
+
has_default_value=True,
|
801 |
+
default_value=1,
|
802 |
+
message_type=None,
|
803 |
+
enum_type=None,
|
804 |
+
containing_type=None,
|
805 |
+
is_extension=False,
|
806 |
+
extension_scope=None,
|
807 |
+
serialized_options=None,
|
808 |
+
file=DESCRIPTOR,
|
809 |
+
create_key=_descriptor._internal_create_key,
|
810 |
+
),
|
811 |
+
_descriptor.FieldDescriptor(
|
812 |
+
name="eos_id",
|
813 |
+
full_name="sentencepiece.TrainerSpec.eos_id",
|
814 |
+
index=32,
|
815 |
+
number=42,
|
816 |
+
type=5,
|
817 |
+
cpp_type=1,
|
818 |
+
label=1,
|
819 |
+
has_default_value=True,
|
820 |
+
default_value=2,
|
821 |
+
message_type=None,
|
822 |
+
enum_type=None,
|
823 |
+
containing_type=None,
|
824 |
+
is_extension=False,
|
825 |
+
extension_scope=None,
|
826 |
+
serialized_options=None,
|
827 |
+
file=DESCRIPTOR,
|
828 |
+
create_key=_descriptor._internal_create_key,
|
829 |
+
),
|
830 |
+
_descriptor.FieldDescriptor(
|
831 |
+
name="pad_id",
|
832 |
+
full_name="sentencepiece.TrainerSpec.pad_id",
|
833 |
+
index=33,
|
834 |
+
number=43,
|
835 |
+
type=5,
|
836 |
+
cpp_type=1,
|
837 |
+
label=1,
|
838 |
+
has_default_value=True,
|
839 |
+
default_value=-1,
|
840 |
+
message_type=None,
|
841 |
+
enum_type=None,
|
842 |
+
containing_type=None,
|
843 |
+
is_extension=False,
|
844 |
+
extension_scope=None,
|
845 |
+
serialized_options=None,
|
846 |
+
file=DESCRIPTOR,
|
847 |
+
create_key=_descriptor._internal_create_key,
|
848 |
+
),
|
849 |
+
_descriptor.FieldDescriptor(
|
850 |
+
name="unk_piece",
|
851 |
+
full_name="sentencepiece.TrainerSpec.unk_piece",
|
852 |
+
index=34,
|
853 |
+
number=45,
|
854 |
+
type=9,
|
855 |
+
cpp_type=9,
|
856 |
+
label=1,
|
857 |
+
has_default_value=True,
|
858 |
+
default_value=b"<unk>".decode("utf-8"),
|
859 |
+
message_type=None,
|
860 |
+
enum_type=None,
|
861 |
+
containing_type=None,
|
862 |
+
is_extension=False,
|
863 |
+
extension_scope=None,
|
864 |
+
serialized_options=None,
|
865 |
+
file=DESCRIPTOR,
|
866 |
+
create_key=_descriptor._internal_create_key,
|
867 |
+
),
|
868 |
+
_descriptor.FieldDescriptor(
|
869 |
+
name="bos_piece",
|
870 |
+
full_name="sentencepiece.TrainerSpec.bos_piece",
|
871 |
+
index=35,
|
872 |
+
number=46,
|
873 |
+
type=9,
|
874 |
+
cpp_type=9,
|
875 |
+
label=1,
|
876 |
+
has_default_value=True,
|
877 |
+
default_value=b"<s>".decode("utf-8"),
|
878 |
+
message_type=None,
|
879 |
+
enum_type=None,
|
880 |
+
containing_type=None,
|
881 |
+
is_extension=False,
|
882 |
+
extension_scope=None,
|
883 |
+
serialized_options=None,
|
884 |
+
file=DESCRIPTOR,
|
885 |
+
create_key=_descriptor._internal_create_key,
|
886 |
+
),
|
887 |
+
_descriptor.FieldDescriptor(
|
888 |
+
name="eos_piece",
|
889 |
+
full_name="sentencepiece.TrainerSpec.eos_piece",
|
890 |
+
index=36,
|
891 |
+
number=47,
|
892 |
+
type=9,
|
893 |
+
cpp_type=9,
|
894 |
+
label=1,
|
895 |
+
has_default_value=True,
|
896 |
+
default_value=b"</s>".decode("utf-8"),
|
897 |
+
message_type=None,
|
898 |
+
enum_type=None,
|
899 |
+
containing_type=None,
|
900 |
+
is_extension=False,
|
901 |
+
extension_scope=None,
|
902 |
+
serialized_options=None,
|
903 |
+
file=DESCRIPTOR,
|
904 |
+
create_key=_descriptor._internal_create_key,
|
905 |
+
),
|
906 |
+
_descriptor.FieldDescriptor(
|
907 |
+
name="pad_piece",
|
908 |
+
full_name="sentencepiece.TrainerSpec.pad_piece",
|
909 |
+
index=37,
|
910 |
+
number=48,
|
911 |
+
type=9,
|
912 |
+
cpp_type=9,
|
913 |
+
label=1,
|
914 |
+
has_default_value=True,
|
915 |
+
default_value=b"<pad>".decode("utf-8"),
|
916 |
+
message_type=None,
|
917 |
+
enum_type=None,
|
918 |
+
containing_type=None,
|
919 |
+
is_extension=False,
|
920 |
+
extension_scope=None,
|
921 |
+
serialized_options=None,
|
922 |
+
file=DESCRIPTOR,
|
923 |
+
create_key=_descriptor._internal_create_key,
|
924 |
+
),
|
925 |
+
_descriptor.FieldDescriptor(
|
926 |
+
name="unk_surface",
|
927 |
+
full_name="sentencepiece.TrainerSpec.unk_surface",
|
928 |
+
index=38,
|
929 |
+
number=44,
|
930 |
+
type=9,
|
931 |
+
cpp_type=9,
|
932 |
+
label=1,
|
933 |
+
has_default_value=True,
|
934 |
+
default_value=b" \342\201\207 ".decode("utf-8"),
|
935 |
+
message_type=None,
|
936 |
+
enum_type=None,
|
937 |
+
containing_type=None,
|
938 |
+
is_extension=False,
|
939 |
+
extension_scope=None,
|
940 |
+
serialized_options=None,
|
941 |
+
file=DESCRIPTOR,
|
942 |
+
create_key=_descriptor._internal_create_key,
|
943 |
+
),
|
944 |
+
_descriptor.FieldDescriptor(
|
945 |
+
name="train_extremely_large_corpus",
|
946 |
+
full_name="sentencepiece.TrainerSpec.train_extremely_large_corpus",
|
947 |
+
index=39,
|
948 |
+
number=49,
|
949 |
+
type=8,
|
950 |
+
cpp_type=7,
|
951 |
+
label=1,
|
952 |
+
has_default_value=True,
|
953 |
+
default_value=False,
|
954 |
+
message_type=None,
|
955 |
+
enum_type=None,
|
956 |
+
containing_type=None,
|
957 |
+
is_extension=False,
|
958 |
+
extension_scope=None,
|
959 |
+
serialized_options=None,
|
960 |
+
file=DESCRIPTOR,
|
961 |
+
create_key=_descriptor._internal_create_key,
|
962 |
+
),
|
963 |
+
],
|
964 |
+
extensions=[],
|
965 |
+
nested_types=[],
|
966 |
+
enum_types=[
|
967 |
+
_TRAINERSPEC_MODELTYPE,
|
968 |
+
],
|
969 |
+
serialized_options=None,
|
970 |
+
is_extendable=True,
|
971 |
+
syntax="proto2",
|
972 |
+
extension_ranges=[
|
973 |
+
(200, 536870912),
|
974 |
+
],
|
975 |
+
oneofs=[],
|
976 |
+
serialized_start=45,
|
977 |
+
serialized_end=1358,
|
978 |
+
)
|
979 |
+
|
980 |
+
|
981 |
+
_NORMALIZERSPEC = _descriptor.Descriptor(
|
982 |
+
name="NormalizerSpec",
|
983 |
+
full_name="sentencepiece.NormalizerSpec",
|
984 |
+
filename=None,
|
985 |
+
file=DESCRIPTOR,
|
986 |
+
containing_type=None,
|
987 |
+
create_key=_descriptor._internal_create_key,
|
988 |
+
fields=[
|
989 |
+
_descriptor.FieldDescriptor(
|
990 |
+
name="name",
|
991 |
+
full_name="sentencepiece.NormalizerSpec.name",
|
992 |
+
index=0,
|
993 |
+
number=1,
|
994 |
+
type=9,
|
995 |
+
cpp_type=9,
|
996 |
+
label=1,
|
997 |
+
has_default_value=False,
|
998 |
+
default_value=b"".decode("utf-8"),
|
999 |
+
message_type=None,
|
1000 |
+
enum_type=None,
|
1001 |
+
containing_type=None,
|
1002 |
+
is_extension=False,
|
1003 |
+
extension_scope=None,
|
1004 |
+
serialized_options=None,
|
1005 |
+
file=DESCRIPTOR,
|
1006 |
+
create_key=_descriptor._internal_create_key,
|
1007 |
+
),
|
1008 |
+
_descriptor.FieldDescriptor(
|
1009 |
+
name="precompiled_charsmap",
|
1010 |
+
full_name="sentencepiece.NormalizerSpec.precompiled_charsmap",
|
1011 |
+
index=1,
|
1012 |
+
number=2,
|
1013 |
+
type=12,
|
1014 |
+
cpp_type=9,
|
1015 |
+
label=1,
|
1016 |
+
has_default_value=False,
|
1017 |
+
default_value=b"",
|
1018 |
+
message_type=None,
|
1019 |
+
enum_type=None,
|
1020 |
+
containing_type=None,
|
1021 |
+
is_extension=False,
|
1022 |
+
extension_scope=None,
|
1023 |
+
serialized_options=None,
|
1024 |
+
file=DESCRIPTOR,
|
1025 |
+
create_key=_descriptor._internal_create_key,
|
1026 |
+
),
|
1027 |
+
_descriptor.FieldDescriptor(
|
1028 |
+
name="add_dummy_prefix",
|
1029 |
+
full_name="sentencepiece.NormalizerSpec.add_dummy_prefix",
|
1030 |
+
index=2,
|
1031 |
+
number=3,
|
1032 |
+
type=8,
|
1033 |
+
cpp_type=7,
|
1034 |
+
label=1,
|
1035 |
+
has_default_value=True,
|
1036 |
+
default_value=True,
|
1037 |
+
message_type=None,
|
1038 |
+
enum_type=None,
|
1039 |
+
containing_type=None,
|
1040 |
+
is_extension=False,
|
1041 |
+
extension_scope=None,
|
1042 |
+
serialized_options=None,
|
1043 |
+
file=DESCRIPTOR,
|
1044 |
+
create_key=_descriptor._internal_create_key,
|
1045 |
+
),
|
1046 |
+
_descriptor.FieldDescriptor(
|
1047 |
+
name="remove_extra_whitespaces",
|
1048 |
+
full_name="sentencepiece.NormalizerSpec.remove_extra_whitespaces",
|
1049 |
+
index=3,
|
1050 |
+
number=4,
|
1051 |
+
type=8,
|
1052 |
+
cpp_type=7,
|
1053 |
+
label=1,
|
1054 |
+
has_default_value=True,
|
1055 |
+
default_value=True,
|
1056 |
+
message_type=None,
|
1057 |
+
enum_type=None,
|
1058 |
+
containing_type=None,
|
1059 |
+
is_extension=False,
|
1060 |
+
extension_scope=None,
|
1061 |
+
serialized_options=None,
|
1062 |
+
file=DESCRIPTOR,
|
1063 |
+
create_key=_descriptor._internal_create_key,
|
1064 |
+
),
|
1065 |
+
_descriptor.FieldDescriptor(
|
1066 |
+
name="escape_whitespaces",
|
1067 |
+
full_name="sentencepiece.NormalizerSpec.escape_whitespaces",
|
1068 |
+
index=4,
|
1069 |
+
number=5,
|
1070 |
+
type=8,
|
1071 |
+
cpp_type=7,
|
1072 |
+
label=1,
|
1073 |
+
has_default_value=True,
|
1074 |
+
default_value=True,
|
1075 |
+
message_type=None,
|
1076 |
+
enum_type=None,
|
1077 |
+
containing_type=None,
|
1078 |
+
is_extension=False,
|
1079 |
+
extension_scope=None,
|
1080 |
+
serialized_options=None,
|
1081 |
+
file=DESCRIPTOR,
|
1082 |
+
create_key=_descriptor._internal_create_key,
|
1083 |
+
),
|
1084 |
+
_descriptor.FieldDescriptor(
|
1085 |
+
name="normalization_rule_tsv",
|
1086 |
+
full_name="sentencepiece.NormalizerSpec.normalization_rule_tsv",
|
1087 |
+
index=5,
|
1088 |
+
number=6,
|
1089 |
+
type=9,
|
1090 |
+
cpp_type=9,
|
1091 |
+
label=1,
|
1092 |
+
has_default_value=False,
|
1093 |
+
default_value=b"".decode("utf-8"),
|
1094 |
+
message_type=None,
|
1095 |
+
enum_type=None,
|
1096 |
+
containing_type=None,
|
1097 |
+
is_extension=False,
|
1098 |
+
extension_scope=None,
|
1099 |
+
serialized_options=None,
|
1100 |
+
file=DESCRIPTOR,
|
1101 |
+
create_key=_descriptor._internal_create_key,
|
1102 |
+
),
|
1103 |
+
],
|
1104 |
+
extensions=[],
|
1105 |
+
nested_types=[],
|
1106 |
+
enum_types=[],
|
1107 |
+
serialized_options=None,
|
1108 |
+
is_extendable=True,
|
1109 |
+
syntax="proto2",
|
1110 |
+
extension_ranges=[
|
1111 |
+
(200, 536870912),
|
1112 |
+
],
|
1113 |
+
oneofs=[],
|
1114 |
+
serialized_start=1361,
|
1115 |
+
serialized_end=1570,
|
1116 |
+
)
|
1117 |
+
|
1118 |
+
|
1119 |
+
_SELFTESTDATA_SAMPLE = _descriptor.Descriptor(
|
1120 |
+
name="Sample",
|
1121 |
+
full_name="sentencepiece.SelfTestData.Sample",
|
1122 |
+
filename=None,
|
1123 |
+
file=DESCRIPTOR,
|
1124 |
+
containing_type=None,
|
1125 |
+
create_key=_descriptor._internal_create_key,
|
1126 |
+
fields=[
|
1127 |
+
_descriptor.FieldDescriptor(
|
1128 |
+
name="input",
|
1129 |
+
full_name="sentencepiece.SelfTestData.Sample.input",
|
1130 |
+
index=0,
|
1131 |
+
number=1,
|
1132 |
+
type=9,
|
1133 |
+
cpp_type=9,
|
1134 |
+
label=1,
|
1135 |
+
has_default_value=False,
|
1136 |
+
default_value=b"".decode("utf-8"),
|
1137 |
+
message_type=None,
|
1138 |
+
enum_type=None,
|
1139 |
+
containing_type=None,
|
1140 |
+
is_extension=False,
|
1141 |
+
extension_scope=None,
|
1142 |
+
serialized_options=None,
|
1143 |
+
file=DESCRIPTOR,
|
1144 |
+
create_key=_descriptor._internal_create_key,
|
1145 |
+
),
|
1146 |
+
_descriptor.FieldDescriptor(
|
1147 |
+
name="expected",
|
1148 |
+
full_name="sentencepiece.SelfTestData.Sample.expected",
|
1149 |
+
index=1,
|
1150 |
+
number=2,
|
1151 |
+
type=9,
|
1152 |
+
cpp_type=9,
|
1153 |
+
label=1,
|
1154 |
+
has_default_value=False,
|
1155 |
+
default_value=b"".decode("utf-8"),
|
1156 |
+
message_type=None,
|
1157 |
+
enum_type=None,
|
1158 |
+
containing_type=None,
|
1159 |
+
is_extension=False,
|
1160 |
+
extension_scope=None,
|
1161 |
+
serialized_options=None,
|
1162 |
+
file=DESCRIPTOR,
|
1163 |
+
create_key=_descriptor._internal_create_key,
|
1164 |
+
),
|
1165 |
+
],
|
1166 |
+
extensions=[],
|
1167 |
+
nested_types=[],
|
1168 |
+
enum_types=[],
|
1169 |
+
serialized_options=None,
|
1170 |
+
is_extendable=False,
|
1171 |
+
syntax="proto2",
|
1172 |
+
extension_ranges=[],
|
1173 |
+
oneofs=[],
|
1174 |
+
serialized_start=1641,
|
1175 |
+
serialized_end=1682,
|
1176 |
+
)
|
1177 |
+
|
1178 |
+
_SELFTESTDATA = _descriptor.Descriptor(
|
1179 |
+
name="SelfTestData",
|
1180 |
+
full_name="sentencepiece.SelfTestData",
|
1181 |
+
filename=None,
|
1182 |
+
file=DESCRIPTOR,
|
1183 |
+
containing_type=None,
|
1184 |
+
create_key=_descriptor._internal_create_key,
|
1185 |
+
fields=[
|
1186 |
+
_descriptor.FieldDescriptor(
|
1187 |
+
name="samples",
|
1188 |
+
full_name="sentencepiece.SelfTestData.samples",
|
1189 |
+
index=0,
|
1190 |
+
number=1,
|
1191 |
+
type=11,
|
1192 |
+
cpp_type=10,
|
1193 |
+
label=3,
|
1194 |
+
has_default_value=False,
|
1195 |
+
default_value=[],
|
1196 |
+
message_type=None,
|
1197 |
+
enum_type=None,
|
1198 |
+
containing_type=None,
|
1199 |
+
is_extension=False,
|
1200 |
+
extension_scope=None,
|
1201 |
+
serialized_options=None,
|
1202 |
+
file=DESCRIPTOR,
|
1203 |
+
create_key=_descriptor._internal_create_key,
|
1204 |
+
),
|
1205 |
+
],
|
1206 |
+
extensions=[],
|
1207 |
+
nested_types=[
|
1208 |
+
_SELFTESTDATA_SAMPLE,
|
1209 |
+
],
|
1210 |
+
enum_types=[],
|
1211 |
+
serialized_options=None,
|
1212 |
+
is_extendable=True,
|
1213 |
+
syntax="proto2",
|
1214 |
+
extension_ranges=[
|
1215 |
+
(200, 536870912),
|
1216 |
+
],
|
1217 |
+
oneofs=[],
|
1218 |
+
serialized_start=1572,
|
1219 |
+
serialized_end=1693,
|
1220 |
+
)
|
1221 |
+
|
1222 |
+
|
1223 |
+
_MODELPROTO_SENTENCEPIECE = _descriptor.Descriptor(
|
1224 |
+
name="SentencePiece",
|
1225 |
+
full_name="sentencepiece.ModelProto.SentencePiece",
|
1226 |
+
filename=None,
|
1227 |
+
file=DESCRIPTOR,
|
1228 |
+
containing_type=None,
|
1229 |
+
create_key=_descriptor._internal_create_key,
|
1230 |
+
fields=[
|
1231 |
+
_descriptor.FieldDescriptor(
|
1232 |
+
name="piece",
|
1233 |
+
full_name="sentencepiece.ModelProto.SentencePiece.piece",
|
1234 |
+
index=0,
|
1235 |
+
number=1,
|
1236 |
+
type=9,
|
1237 |
+
cpp_type=9,
|
1238 |
+
label=1,
|
1239 |
+
has_default_value=False,
|
1240 |
+
default_value=b"".decode("utf-8"),
|
1241 |
+
message_type=None,
|
1242 |
+
enum_type=None,
|
1243 |
+
containing_type=None,
|
1244 |
+
is_extension=False,
|
1245 |
+
extension_scope=None,
|
1246 |
+
serialized_options=None,
|
1247 |
+
file=DESCRIPTOR,
|
1248 |
+
create_key=_descriptor._internal_create_key,
|
1249 |
+
),
|
1250 |
+
_descriptor.FieldDescriptor(
|
1251 |
+
name="score",
|
1252 |
+
full_name="sentencepiece.ModelProto.SentencePiece.score",
|
1253 |
+
index=1,
|
1254 |
+
number=2,
|
1255 |
+
type=2,
|
1256 |
+
cpp_type=6,
|
1257 |
+
label=1,
|
1258 |
+
has_default_value=False,
|
1259 |
+
default_value=float(0),
|
1260 |
+
message_type=None,
|
1261 |
+
enum_type=None,
|
1262 |
+
containing_type=None,
|
1263 |
+
is_extension=False,
|
1264 |
+
extension_scope=None,
|
1265 |
+
serialized_options=None,
|
1266 |
+
file=DESCRIPTOR,
|
1267 |
+
create_key=_descriptor._internal_create_key,
|
1268 |
+
),
|
1269 |
+
_descriptor.FieldDescriptor(
|
1270 |
+
name="type",
|
1271 |
+
full_name="sentencepiece.ModelProto.SentencePiece.type",
|
1272 |
+
index=2,
|
1273 |
+
number=3,
|
1274 |
+
type=14,
|
1275 |
+
cpp_type=8,
|
1276 |
+
label=1,
|
1277 |
+
has_default_value=True,
|
1278 |
+
default_value=1,
|
1279 |
+
message_type=None,
|
1280 |
+
enum_type=None,
|
1281 |
+
containing_type=None,
|
1282 |
+
is_extension=False,
|
1283 |
+
extension_scope=None,
|
1284 |
+
serialized_options=None,
|
1285 |
+
file=DESCRIPTOR,
|
1286 |
+
create_key=_descriptor._internal_create_key,
|
1287 |
+
),
|
1288 |
+
],
|
1289 |
+
extensions=[],
|
1290 |
+
nested_types=[],
|
1291 |
+
enum_types=[
|
1292 |
+
_MODELPROTO_SENTENCEPIECE_TYPE,
|
1293 |
+
],
|
1294 |
+
serialized_options=None,
|
1295 |
+
is_extendable=True,
|
1296 |
+
syntax="proto2",
|
1297 |
+
extension_ranges=[
|
1298 |
+
(200, 536870912),
|
1299 |
+
],
|
1300 |
+
oneofs=[],
|
1301 |
+
serialized_start=1985,
|
1302 |
+
serialized_end=2195,
|
1303 |
+
)
|
1304 |
+
|
1305 |
+
_MODELPROTO = _descriptor.Descriptor(
|
1306 |
+
name="ModelProto",
|
1307 |
+
full_name="sentencepiece.ModelProto",
|
1308 |
+
filename=None,
|
1309 |
+
file=DESCRIPTOR,
|
1310 |
+
containing_type=None,
|
1311 |
+
create_key=_descriptor._internal_create_key,
|
1312 |
+
fields=[
|
1313 |
+
_descriptor.FieldDescriptor(
|
1314 |
+
name="pieces",
|
1315 |
+
full_name="sentencepiece.ModelProto.pieces",
|
1316 |
+
index=0,
|
1317 |
+
number=1,
|
1318 |
+
type=11,
|
1319 |
+
cpp_type=10,
|
1320 |
+
label=3,
|
1321 |
+
has_default_value=False,
|
1322 |
+
default_value=[],
|
1323 |
+
message_type=None,
|
1324 |
+
enum_type=None,
|
1325 |
+
containing_type=None,
|
1326 |
+
is_extension=False,
|
1327 |
+
extension_scope=None,
|
1328 |
+
serialized_options=None,
|
1329 |
+
file=DESCRIPTOR,
|
1330 |
+
create_key=_descriptor._internal_create_key,
|
1331 |
+
),
|
1332 |
+
_descriptor.FieldDescriptor(
|
1333 |
+
name="trainer_spec",
|
1334 |
+
full_name="sentencepiece.ModelProto.trainer_spec",
|
1335 |
+
index=1,
|
1336 |
+
number=2,
|
1337 |
+
type=11,
|
1338 |
+
cpp_type=10,
|
1339 |
+
label=1,
|
1340 |
+
has_default_value=False,
|
1341 |
+
default_value=None,
|
1342 |
+
message_type=None,
|
1343 |
+
enum_type=None,
|
1344 |
+
containing_type=None,
|
1345 |
+
is_extension=False,
|
1346 |
+
extension_scope=None,
|
1347 |
+
serialized_options=None,
|
1348 |
+
file=DESCRIPTOR,
|
1349 |
+
create_key=_descriptor._internal_create_key,
|
1350 |
+
),
|
1351 |
+
_descriptor.FieldDescriptor(
|
1352 |
+
name="normalizer_spec",
|
1353 |
+
full_name="sentencepiece.ModelProto.normalizer_spec",
|
1354 |
+
index=2,
|
1355 |
+
number=3,
|
1356 |
+
type=11,
|
1357 |
+
cpp_type=10,
|
1358 |
+
label=1,
|
1359 |
+
has_default_value=False,
|
1360 |
+
default_value=None,
|
1361 |
+
message_type=None,
|
1362 |
+
enum_type=None,
|
1363 |
+
containing_type=None,
|
1364 |
+
is_extension=False,
|
1365 |
+
extension_scope=None,
|
1366 |
+
serialized_options=None,
|
1367 |
+
file=DESCRIPTOR,
|
1368 |
+
create_key=_descriptor._internal_create_key,
|
1369 |
+
),
|
1370 |
+
_descriptor.FieldDescriptor(
|
1371 |
+
name="self_test_data",
|
1372 |
+
full_name="sentencepiece.ModelProto.self_test_data",
|
1373 |
+
index=3,
|
1374 |
+
number=4,
|
1375 |
+
type=11,
|
1376 |
+
cpp_type=10,
|
1377 |
+
label=1,
|
1378 |
+
has_default_value=False,
|
1379 |
+
default_value=None,
|
1380 |
+
message_type=None,
|
1381 |
+
enum_type=None,
|
1382 |
+
containing_type=None,
|
1383 |
+
is_extension=False,
|
1384 |
+
extension_scope=None,
|
1385 |
+
serialized_options=None,
|
1386 |
+
file=DESCRIPTOR,
|
1387 |
+
create_key=_descriptor._internal_create_key,
|
1388 |
+
),
|
1389 |
+
_descriptor.FieldDescriptor(
|
1390 |
+
name="denormalizer_spec",
|
1391 |
+
full_name="sentencepiece.ModelProto.denormalizer_spec",
|
1392 |
+
index=4,
|
1393 |
+
number=5,
|
1394 |
+
type=11,
|
1395 |
+
cpp_type=10,
|
1396 |
+
label=1,
|
1397 |
+
has_default_value=False,
|
1398 |
+
default_value=None,
|
1399 |
+
message_type=None,
|
1400 |
+
enum_type=None,
|
1401 |
+
containing_type=None,
|
1402 |
+
is_extension=False,
|
1403 |
+
extension_scope=None,
|
1404 |
+
serialized_options=None,
|
1405 |
+
file=DESCRIPTOR,
|
1406 |
+
create_key=_descriptor._internal_create_key,
|
1407 |
+
),
|
1408 |
+
],
|
1409 |
+
extensions=[],
|
1410 |
+
nested_types=[
|
1411 |
+
_MODELPROTO_SENTENCEPIECE,
|
1412 |
+
],
|
1413 |
+
enum_types=[],
|
1414 |
+
serialized_options=None,
|
1415 |
+
is_extendable=True,
|
1416 |
+
syntax="proto2",
|
1417 |
+
extension_ranges=[
|
1418 |
+
(200, 536870912),
|
1419 |
+
],
|
1420 |
+
oneofs=[],
|
1421 |
+
serialized_start=1696,
|
1422 |
+
serialized_end=2206,
|
1423 |
+
)
|
1424 |
+
|
1425 |
+
_TRAINERSPEC.fields_by_name["model_type"].enum_type = _TRAINERSPEC_MODELTYPE
|
1426 |
+
_TRAINERSPEC_MODELTYPE.containing_type = _TRAINERSPEC
|
1427 |
+
_SELFTESTDATA_SAMPLE.containing_type = _SELFTESTDATA
|
1428 |
+
_SELFTESTDATA.fields_by_name["samples"].message_type = _SELFTESTDATA_SAMPLE
|
1429 |
+
_MODELPROTO_SENTENCEPIECE.fields_by_name["type"].enum_type = _MODELPROTO_SENTENCEPIECE_TYPE
|
1430 |
+
_MODELPROTO_SENTENCEPIECE.containing_type = _MODELPROTO
|
1431 |
+
_MODELPROTO_SENTENCEPIECE_TYPE.containing_type = _MODELPROTO_SENTENCEPIECE
|
1432 |
+
_MODELPROTO.fields_by_name["pieces"].message_type = _MODELPROTO_SENTENCEPIECE
|
1433 |
+
_MODELPROTO.fields_by_name["trainer_spec"].message_type = _TRAINERSPEC
|
1434 |
+
_MODELPROTO.fields_by_name["normalizer_spec"].message_type = _NORMALIZERSPEC
|
1435 |
+
_MODELPROTO.fields_by_name["self_test_data"].message_type = _SELFTESTDATA
|
1436 |
+
_MODELPROTO.fields_by_name["denormalizer_spec"].message_type = _NORMALIZERSPEC
|
1437 |
+
DESCRIPTOR.message_types_by_name["TrainerSpec"] = _TRAINERSPEC
|
1438 |
+
DESCRIPTOR.message_types_by_name["NormalizerSpec"] = _NORMALIZERSPEC
|
1439 |
+
DESCRIPTOR.message_types_by_name["SelfTestData"] = _SELFTESTDATA
|
1440 |
+
DESCRIPTOR.message_types_by_name["ModelProto"] = _MODELPROTO
|
1441 |
+
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
|
1442 |
+
|
1443 |
+
TrainerSpec = _reflection.GeneratedProtocolMessageType(
|
1444 |
+
"TrainerSpec",
|
1445 |
+
(_message.Message,),
|
1446 |
+
{
|
1447 |
+
"DESCRIPTOR": _TRAINERSPEC,
|
1448 |
+
"__module__": "sentencepiece_model_pb2",
|
1449 |
+
# @@protoc_insertion_point(class_scope:sentencepiece.TrainerSpec)
|
1450 |
+
},
|
1451 |
+
)
|
1452 |
+
_sym_db.RegisterMessage(TrainerSpec)
|
1453 |
+
|
1454 |
+
NormalizerSpec = _reflection.GeneratedProtocolMessageType(
|
1455 |
+
"NormalizerSpec",
|
1456 |
+
(_message.Message,),
|
1457 |
+
{
|
1458 |
+
"DESCRIPTOR": _NORMALIZERSPEC,
|
1459 |
+
"__module__": "sentencepiece_model_pb2",
|
1460 |
+
# @@protoc_insertion_point(class_scope:sentencepiece.NormalizerSpec)
|
1461 |
+
},
|
1462 |
+
)
|
1463 |
+
_sym_db.RegisterMessage(NormalizerSpec)
|
1464 |
+
|
1465 |
+
SelfTestData = _reflection.GeneratedProtocolMessageType(
|
1466 |
+
"SelfTestData",
|
1467 |
+
(_message.Message,),
|
1468 |
+
{
|
1469 |
+
"Sample": _reflection.GeneratedProtocolMessageType(
|
1470 |
+
"Sample",
|
1471 |
+
(_message.Message,),
|
1472 |
+
{
|
1473 |
+
"DESCRIPTOR": _SELFTESTDATA_SAMPLE,
|
1474 |
+
"__module__": "sentencepiece_model_pb2",
|
1475 |
+
# @@protoc_insertion_point(class_scope:sentencepiece.SelfTestData.Sample)
|
1476 |
+
},
|
1477 |
+
),
|
1478 |
+
"DESCRIPTOR": _SELFTESTDATA,
|
1479 |
+
"__module__": "sentencepiece_model_pb2",
|
1480 |
+
# @@protoc_insertion_point(class_scope:sentencepiece.SelfTestData)
|
1481 |
+
},
|
1482 |
+
)
|
1483 |
+
_sym_db.RegisterMessage(SelfTestData)
|
1484 |
+
_sym_db.RegisterMessage(SelfTestData.Sample)
|
1485 |
+
|
1486 |
+
ModelProto = _reflection.GeneratedProtocolMessageType(
|
1487 |
+
"ModelProto",
|
1488 |
+
(_message.Message,),
|
1489 |
+
{
|
1490 |
+
"SentencePiece": _reflection.GeneratedProtocolMessageType(
|
1491 |
+
"SentencePiece",
|
1492 |
+
(_message.Message,),
|
1493 |
+
{
|
1494 |
+
"DESCRIPTOR": _MODELPROTO_SENTENCEPIECE,
|
1495 |
+
"__module__": "sentencepiece_model_pb2",
|
1496 |
+
# @@protoc_insertion_point(class_scope:sentencepiece.ModelProto.SentencePiece)
|
1497 |
+
},
|
1498 |
+
),
|
1499 |
+
"DESCRIPTOR": _MODELPROTO,
|
1500 |
+
"__module__": "sentencepiece_model_pb2",
|
1501 |
+
# @@protoc_insertion_point(class_scope:sentencepiece.ModelProto)
|
1502 |
+
},
|
1503 |
+
)
|
1504 |
+
_sym_db.RegisterMessage(ModelProto)
|
1505 |
+
_sym_db.RegisterMessage(ModelProto.SentencePiece)
|
1506 |
+
|
1507 |
+
|
1508 |
+
DESCRIPTOR._options = None
|
1509 |
+
_TRAINERSPEC.fields_by_name["mining_sentence_size"]._options = None
|
1510 |
+
_TRAINERSPEC.fields_by_name["training_sentence_size"]._options = None
|
1511 |
+
# @@protoc_insertion_point(module_scope)
|
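Note (editor's sketch, not part of the diff): the module above only declares the protobuf schema for SentencePiece models. A minimal example of how the generated ModelProto message can be used to inspect a raw SentencePiece model file follows; the file name "spiece.model" is a hypothetical placeholder, not a path referenced by this diff.

# Sketch: parse a SentencePiece model file with the generated ModelProto message.
# "spiece.model" is a placeholder path; any serialized SentencePiece model works.
from transformers.utils import sentencepiece_model_pb2 as model_pb2

proto = model_pb2.ModelProto()
with open("spiece.model", "rb") as f:
    proto.ParseFromString(f.read())

# The descriptors defined above become ordinary protobuf fields here.
print(proto.trainer_spec.vocab_size)           # e.g. 8000 (the schema default)
print(proto.normalizer_spec.add_dummy_prefix)  # True unless the trainer changed it
if proto.pieces:
    print(proto.pieces[0].piece, proto.pieces[0].score)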
.venv/Lib/site-packages/transformers/utils/sentencepiece_model_pb2_new.py
ADDED
@@ -0,0 +1,48 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: sentencepiece_model.proto
"""Generated protocol buffer code."""

from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder


# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x19sentencepiece_model.proto\x12\rsentencepiece"\x80\x0c\n\x0bTrainerSpec\x12\r\n\x05input\x18\x01 \x03(\t\x12\x14\n\x0cinput_format\x18\x07 \x01(\t\x12\x14\n\x0cmodel_prefix\x18\x02 \x01(\t\x12\x41\n\nmodel_type\x18\x03 \x01(\x0e\x32$.sentencepiece.TrainerSpec.ModelType:\x07UNIGRAM\x12\x18\n\nvocab_size\x18\x04 \x01(\x05:\x04\x38\x30\x30\x30\x12\x17\n\x0f\x61\x63\x63\x65pt_language\x18\x05 \x03(\t\x12 \n\x15self_test_sample_size\x18\x06 \x01(\x05:\x01\x30\x12*\n\x1b\x65nable_differential_privacy\x18\x32 \x01(\x08:\x05\x66\x61lse\x12+\n differential_privacy_noise_level\x18\x33 \x01(\x02:\x01\x30\x12\x32\n\'differential_privacy_clipping_threshold\x18\x34 \x01(\x04:\x01\x30\x12"\n\x12\x63haracter_coverage\x18\n \x01(\x02:\x06\x30.9995\x12\x1e\n\x13input_sentence_size\x18\x0b \x01(\x04:\x01\x30\x12$\n\x16shuffle_input_sentence\x18\x13 \x01(\x08:\x04true\x12 \n\x14mining_sentence_size\x18\x0c \x01(\x05\x42\x02\x18\x01\x12"\n\x16training_sentence_size\x18\r \x01(\x05\x42\x02\x18\x01\x12(\n\x17seed_sentencepiece_size\x18\x0e \x01(\x05:\x07\x31\x30\x30\x30\x30\x30\x30\x12\x1e\n\x10shrinking_factor\x18\x0f \x01(\x02:\x04\x30.75\x12!\n\x13max_sentence_length\x18\x12 \x01(\x05:\x04\x34\x31\x39\x32\x12\x17\n\x0bnum_threads\x18\x10 \x01(\x05:\x02\x31\x36\x12\x1d\n\x12num_sub_iterations\x18\x11 \x01(\x05:\x01\x32\x12$\n\x18max_sentencepiece_length\x18\x14 \x01(\x05:\x02\x31\x36\x12%\n\x17split_by_unicode_script\x18\x15 \x01(\x08:\x04true\x12\x1d\n\x0fsplit_by_number\x18\x17 \x01(\x08:\x04true\x12!\n\x13split_by_whitespace\x18\x16 \x01(\x08:\x04true\x12)\n\x1atreat_whitespace_as_suffix\x18\x18 \x01(\x08:\x05\x66\x61lse\x12+\n\x1c\x61llow_whitespace_only_pieces\x18\x1a \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x0csplit_digits\x18\x19 \x01(\x08:\x05\x66\x61lse\x12#\n\x19pretokenization_delimiter\x18\x35 \x01(\t:\x00\x12\x17\n\x0f\x63ontrol_symbols\x18\x1e \x03(\t\x12\x1c\n\x14user_defined_symbols\x18\x1f \x03(\t\x12\x16\n\x0erequired_chars\x18$ \x01(\t\x12\x1c\n\rbyte_fallback\x18# \x01(\x08:\x05\x66\x61lse\x12+\n\x1dvocabulary_output_piece_score\x18 \x01(\x08:\x04true\x12\x1e\n\x10hard_vocab_limit\x18! \x01(\x08:\x04true\x12\x1c\n\ruse_all_vocab\x18" \x01(\x08:\x05\x66\x61lse\x12\x11\n\x06unk_id\x18( \x01(\x05:\x01\x30\x12\x11\n\x06\x62os_id\x18) \x01(\x05:\x01\x31\x12\x11\n\x06\x65os_id\x18* \x01(\x05:\x01\x32\x12\x12\n\x06pad_id\x18+ \x01(\x05:\x02-1\x12\x18\n\tunk_piece\x18- \x01(\t:\x05<unk>\x12\x16\n\tbos_piece\x18. 
\x01(\t:\x03<s>\x12\x17\n\teos_piece\x18/ \x01(\t:\x04</s>\x12\x18\n\tpad_piece\x18\x30 \x01(\t:\x05<pad>\x12\x1a\n\x0bunk_surface\x18, \x01(\t:\x05 \xe2\x81\x87 \x12+\n\x1ctrain_extremely_large_corpus\x18\x31 \x01(\x08:\x05\x66\x61lse"5\n\tModelType\x12\x0b\n\x07UNIGRAM\x10\x01\x12\x07\n\x03\x42PE\x10\x02\x12\x08\n\x04WORD\x10\x03\x12\x08\n\x04\x43HAR\x10\x04*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"\xd1\x01\n\x0eNormalizerSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1c\n\x14precompiled_charsmap\x18\x02 \x01(\x0c\x12\x1e\n\x10\x61\x64\x64_dummy_prefix\x18\x03 \x01(\x08:\x04true\x12&\n\x18remove_extra_whitespaces\x18\x04 \x01(\x08:\x04true\x12 \n\x12\x65scape_whitespaces\x18\x05 \x01(\x08:\x04true\x12\x1e\n\x16normalization_rule_tsv\x18\x06 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"y\n\x0cSelfTestData\x12\x33\n\x07samples\x18\x01 \x03(\x0b\x32".sentencepiece.SelfTestData.Sample\x1a)\n\x06Sample\x12\r\n\x05input\x18\x01 \x01(\t\x12\x10\n\x08\x65xpected\x18\x02 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"\xfe\x03\n\nModelProto\x12\x37\n\x06pieces\x18\x01 \x03(\x0b\x32\'.sentencepiece.ModelProto.SentencePiece\x12\x30\n\x0ctrainer_spec\x18\x02 \x01(\x0b\x32\x1a.sentencepiece.TrainerSpec\x12\x36\n\x0fnormalizer_spec\x18\x03 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x12\x33\n\x0eself_test_data\x18\x04 \x01(\x0b\x32\x1b.sentencepiece.SelfTestData\x12\x38\n\x11\x64\x65normalizer_spec\x18\x05 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x1a\xd2\x01\n\rSentencePiece\x12\r\n\x05piece\x18\x01 \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x02\x12\x42\n\x04type\x18\x03 \x01(\x0e\x32,.sentencepiece.ModelProto.SentencePiece.Type:\x06NORMAL"T\n\x04Type\x12\n\n\x06NORMAL\x10\x01\x12\x0b\n\x07UNKNOWN\x10\x02\x12\x0b\n\x07\x43ONTROL\x10\x03\x12\x10\n\x0cUSER_DEFINED\x10\x04\x12\x08\n\x04\x42YTE\x10\x06\x12\n\n\x06UNUSED\x10\x05*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\x42\x02H\x03'
)

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "sentencepiece_model_pb2", _globals)
if _descriptor._USE_C_DESCRIPTORS is False:
    DESCRIPTOR._options = None
    DESCRIPTOR._serialized_options = b"H\003"
    # (generated by protobuf compiler, but `_TRAINERSPEC` is not defined)
    # _TRAINERSPEC.fields_by_name["mining_sentence_size"]._options = None
    # _TRAINERSPEC.fields_by_name["mining_sentence_size"]._serialized_options = b"\030\001"
    # _TRAINERSPEC.fields_by_name["training_sentence_size"]._options = None
    # _TRAINERSPEC.fields_by_name["training_sentence_size"]._serialized_options = b"\030\001"
    _globals["_TRAINERSPEC"]._serialized_start = 45
    _globals["_TRAINERSPEC"]._serialized_end = 1581
    _globals["_TRAINERSPEC_MODELTYPE"]._serialized_start = 1517
    _globals["_TRAINERSPEC_MODELTYPE"]._serialized_end = 1570
    _globals["_NORMALIZERSPEC"]._serialized_start = 1584
    _globals["_NORMALIZERSPEC"]._serialized_end = 1793
    _globals["_SELFTESTDATA"]._serialized_start = 1795
    _globals["_SELFTESTDATA"]._serialized_end = 1916
    _globals["_SELFTESTDATA_SAMPLE"]._serialized_start = 1864
    _globals["_SELFTESTDATA_SAMPLE"]._serialized_end = 1905
    _globals["_MODELPROTO"]._serialized_start = 1919
    _globals["_MODELPROTO"]._serialized_end = 2429
    _globals["_MODELPROTO_SENTENCEPIECE"]._serialized_start = 2208
    _globals["_MODELPROTO_SENTENCEPIECE"]._serialized_end = 2418
    _globals["_MODELPROTO_SENTENCEPIECE_TYPE"]._serialized_start = 2323
    _globals["_MODELPROTO_SENTENCEPIECE_TYPE"]._serialized_end = 2407
    # @@protoc_insertion_point(module_scope)
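Note (editor's sketch, not part of the diff): sentencepiece_model_pb2_new.py describes the same sentencepiece_model.proto schema as the module above, but through the descriptor_pool/_builder API used by newer protobuf releases, while the sibling module uses the older explicit Descriptor objects. Below is a hedged sketch of how a caller might pick between the two based on the installed google.protobuf version; the helper name pick_sentencepiece_pb2 is illustrative and not part of transformers.

# Illustrative sketch only: select the generated module that matches the installed protobuf.
from packaging import version


def pick_sentencepiece_pb2():
    import google.protobuf

    if version.parse(google.protobuf.__version__) >= version.parse("4.0.0"):
        # builder-based module generated for newer protobuf
        from transformers.utils import sentencepiece_model_pb2_new as model_pb2
    else:
        # legacy descriptor-based module
        from transformers.utils import sentencepiece_model_pb2 as model_pb2
    return model_pb2


model_pb2 = pick_sentencepiece_pb2()
proto = model_pb2.ModelProto()  # both modules expose the same message classes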
.venv/Lib/site-packages/transformers/utils/versions.py
ADDED
@@ -0,0 +1,117 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities for working with package versions
"""

import importlib.metadata
import operator
import re
import sys
from typing import Optional

from packaging import version


ops = {
    "<": operator.lt,
    "<=": operator.le,
    "==": operator.eq,
    "!=": operator.ne,
    ">=": operator.ge,
    ">": operator.gt,
}


def _compare_versions(op, got_ver, want_ver, requirement, pkg, hint):
    if got_ver is None or want_ver is None:
        raise ValueError(
            f"Unable to compare versions for {requirement}: need={want_ver} found={got_ver}. This is unusual. Consider"
            f" reinstalling {pkg}."
        )
    if not ops[op](version.parse(got_ver), version.parse(want_ver)):
        raise ImportError(
            f"{requirement} is required for a normal functioning of this module, but found {pkg}=={got_ver}.{hint}"
        )


def require_version(requirement: str, hint: Optional[str] = None) -> None:
    """
    Perform a runtime check of the dependency versions, using the exact same syntax used by pip.

    The installed module version comes from the *site-packages* dir via *importlib.metadata*.

    Args:
        requirement (`str`): pip style definition, e.g., "tokenizers==0.9.4", "tqdm>=4.27", "numpy"
        hint (`str`, *optional*): what suggestion to print in case of requirements not being met

    Example:

    ```python
    require_version("pandas>1.1.2")
    require_version("numpy>1.18.5", "this is important to have for whatever reason")
    ```"""

    hint = f"\n{hint}" if hint is not None else ""

    # non-versioned check
    if re.match(r"^[\w_\-\d]+$", requirement):
        pkg, op, want_ver = requirement, None, None
    else:
        match = re.findall(r"^([^!=<>\s]+)([\s!=<>]{1,2}.+)", requirement)
        if not match:
            raise ValueError(
                "requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23, but"
                f" got {requirement}"
            )
        pkg, want_full = match[0]
        want_range = want_full.split(",")  # there could be multiple requirements
        wanted = {}
        for w in want_range:
            match = re.findall(r"^([\s!=<>]{1,2})(.+)", w)
            if not match:
                raise ValueError(
                    "requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23,"
                    f" but got {requirement}"
                )
            op, want_ver = match[0]
            wanted[op] = want_ver
            if op not in ops:
                raise ValueError(f"{requirement}: need one of {list(ops.keys())}, but got {op}")

    # special case
    if pkg == "python":
        got_ver = ".".join([str(x) for x in sys.version_info[:3]])
        for op, want_ver in wanted.items():
            _compare_versions(op, got_ver, want_ver, requirement, pkg, hint)
        return

    # check if any version is installed
    try:
        got_ver = importlib.metadata.version(pkg)
    except importlib.metadata.PackageNotFoundError:
        raise importlib.metadata.PackageNotFoundError(
            f"The '{requirement}' distribution was not found and is required by this application. {hint}"
        )

    # check that the right version is installed if version number or a range was provided
    if want_ver is not None:
        for op, want_ver in wanted.items():
            _compare_versions(op, got_ver, want_ver, requirement, pkg, hint)


def require_version_core(requirement):
    """require_version wrapper which emits a core-specific hint on failure"""
    hint = "Try: `pip install transformers -U` or `pip install -e '.[dev]'` if you're working with git main"
    return require_version(requirement, hint)
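Note (editor's sketch, not part of the diff): versions.py is self-contained, so a short usage example may help; the requirement strings below are arbitrary illustrations, not pins taken from this repository.

# Usage sketch for require_version / require_version_core.
from transformers.utils.versions import require_version, require_version_core

require_version("numpy")                   # bare name: only checks that the package is installed
require_version("tokenizers>=0.13,<0.21")  # comma-separated constraints are checked one by one
require_version("python>=3.9.0")           # "python" is special-cased against sys.version_info

try:
    require_version_core("datasets>=2.0.0")
except ImportError as err:
    # raised when the package is missing or the installed version violates the constraint
    print(err)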
.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Abidjan
ADDED
Binary file (130 Bytes).

.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Accra
ADDED
Binary file (130 Bytes).

.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Addis_Ababa
ADDED
Binary file (191 Bytes).

.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Algiers
ADDED
Binary file (470 Bytes).

.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Asmara
ADDED
Binary file (191 Bytes).

.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Asmera
ADDED
Binary file (191 Bytes).

.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Bamako
ADDED
Binary file (130 Bytes).

.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Bangui
ADDED
Binary file (180 Bytes).

.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Dar_es_Salaam
ADDED
Binary file (191 Bytes).

.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Djibouti
ADDED
Binary file (191 Bytes).

.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Douala
ADDED
Binary file (180 Bytes).

.venv/Lib/site-packages/tzdata/zoneinfo/Africa/El_Aaiun
ADDED
Binary file (1.83 kB).

.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Freetown
ADDED
Binary file (130 Bytes).

.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Gaborone
ADDED
Binary file (131 Bytes).

.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Harare
ADDED
Binary file (131 Bytes).

.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Johannesburg
ADDED
Binary file (190 Bytes).

.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Juba
ADDED
Binary file (458 Bytes).

.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Kampala
ADDED
Binary file (191 Bytes).

.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Khartoum
ADDED
Binary file (458 Bytes).

.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Kigali
ADDED
Binary file (131 Bytes).