Commit 344bcbc by Team Finetuner
Parent(s): 5ee2c37

chore: update from afe81ca705ca1a5bd6b7d90548fcac068850b2af

Files changed:
- configuration_bert.py +1 -7
- modeling_bert.py +0 -3
configuration_bert.py
CHANGED
@@ -84,14 +84,10 @@ class JinaBertConfig(PretrainedConfig):
         emb_pooler (`str`, *optional*, defaults to `None`):
             The function to use for pooling the last layer embeddings to get the sentence embeddings.
             Should be one of `None`, `"mean"`.
-        with_flash (`bool`, *optional*, defaults to `False`):
-            Whether to use triton flash attention. Only works for `triton==2.0.0.dev20230208`.
-            This argument will be deprecated in the future. Use `attention_implementation` instead.
-        attn_implementation (`str`, *optional*, defaults to `None`):
+        attn_implementation (`str`, *optional*, defaults to `"torch"`):
             The implementation of the self-attention layer. Can be one of:
             - `None` for the original implementation,
             - `torch` for the PyTorch SDPA implementation,
-            - `triton` for the Triton Flash implementation. Only works for `triton==2.0.0.dev20230208`
 
     Examples:
 
@@ -132,7 +128,6 @@ class JinaBertConfig(PretrainedConfig):
         classifier_dropout=None,
         feed_forward_type="original",
         emb_pooler=None,
-        with_flash=False,
         attn_implementation='torch',
         **kwargs,
     ):
@@ -156,7 +151,6 @@ class JinaBertConfig(PretrainedConfig):
         self.feed_forward_type = feed_forward_type
         self.emb_pooler = emb_pooler
         self.attn_implementation = attn_implementation
-        self.with_flash = with_flash
 
 class JinaBertOnnxConfig(OnnxConfig):
     @property
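With this change, the attention backend of the custom JinaBERT code is selected only through `attn_implementation` (now defaulting to `'torch'`, i.e. PyTorch SDPA), and the deprecated `with_flash` flag is removed from `JinaBertConfig`. A minimal usage sketch, assuming the model is loaded via `transformers` with `trust_remote_code=True`; the checkpoint path is a placeholder, not something this commit specifies:

from transformers import AutoConfig, AutoModel

# Hypothetical checkpoint path; this commit does not name one.
checkpoint = "path/to/jina-bert-checkpoint"

config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
# After this commit the only switch is `attn_implementation`:
#   None    -> original explicit attention
#   'torch' -> PyTorch scaled_dot_product_attention (the new default)
config.attn_implementation = "torch"

model = AutoModel.from_pretrained(checkpoint, config=config, trust_remote_code=True)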
modeling_bert.py
CHANGED
@@ -273,9 +273,6 @@ class JinaBertSelfAttention(nn.Module):
         )
 
         self.attn_implementation = config.attn_implementation
-        if config.with_flash:
-            self.attn_implementation = 'triton'
-
         self.num_attention_heads = config.num_attention_heads
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
         self.all_head_size = self.num_attention_heads * self.attention_head_size
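The modeling change drops the `config.with_flash` override, so `JinaBertSelfAttention` now consults only `attn_implementation`. Below is a minimal sketch of that dispatch pattern, not the repository's actual module: the class name, constructor arguments, and the omission of attention masks, dropout, and ALiBi bias are simplifications for illustration.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimplifiedSelfAttention(nn.Module):
    """Illustrative sketch of branching on `attn_implementation`:
    'torch' uses PyTorch SDPA, None falls back to explicit softmax(QK^T/sqrt(d))V."""

    def __init__(self, hidden_size: int, num_attention_heads: int, attn_implementation=None):
        super().__init__()
        self.attn_implementation = attn_implementation
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = hidden_size // num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)

    def _split_heads(self, x):
        # (batch, seq, hidden) -> (batch, heads, seq, head_size)
        b, s, _ = x.shape
        return x.view(b, s, self.num_attention_heads, self.attention_head_size).transpose(1, 2)

    def forward(self, hidden_states):
        q = self._split_heads(self.query(hidden_states))
        k = self._split_heads(self.key(hidden_states))
        v = self._split_heads(self.value(hidden_states))
        if self.attn_implementation == 'torch':
            # PyTorch SDPA backend (the path this commit standardizes on)
            ctx = F.scaled_dot_product_attention(q, k, v)
        else:
            # Original explicit attention computation
            scores = torch.matmul(q, k.transpose(-1, -2)) / (self.attention_head_size ** 0.5)
            probs = F.softmax(scores, dim=-1)
            ctx = torch.matmul(probs, v)
        b, _, s, _ = ctx.shape
        return ctx.transpose(1, 2).reshape(b, s, self.all_head_size)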