ccdv committed
Commit 89b8bf9 · 0 Parent(s):
.gitattributes ADDED
@@ -0,0 +1,27 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
5
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.model filter=lfs diff=lfs merge=lfs -text
12
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
13
+ *.onnx filter=lfs diff=lfs merge=lfs -text
14
+ *.ot filter=lfs diff=lfs merge=lfs -text
15
+ *.parquet filter=lfs diff=lfs merge=lfs -text
16
+ *.pb filter=lfs diff=lfs merge=lfs -text
17
+ *.pt filter=lfs diff=lfs merge=lfs -text
18
+ *.pth filter=lfs diff=lfs merge=lfs -text
19
+ *.rar filter=lfs diff=lfs merge=lfs -text
20
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
21
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
22
+ *.tflite filter=lfs diff=lfs merge=lfs -text
23
+ *.tgz filter=lfs diff=lfs merge=lfs -text
24
+ *.xz filter=lfs diff=lfs merge=lfs -text
25
+ *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,166 @@
1
+ ---
2
+ language: en
3
+ tags:
4
+ - long context
5
+ - legal
6
+ pipeline_tag: fill-mask
7
+ ---
8
+
9
+ # LSG model
10
+ **Transformers >= 4.18.0**\
11
+ **This model relies on a custom modeling file; you need to add trust_remote_code=True**\
12
+ **See [\#13467](https://github.com/huggingface/transformers/pull/13467)**
13
+
14
+ * [Usage](#usage)
15
+ * [Parameters](#parameters)
16
+ * [Sparse selection type](#sparse-selection-type)
17
+ * [Tasks](#tasks)
18
+ * [Training global tokens](#training-global-tokens)
19
+
20
+ This model is a small version of the [LEGAL-BERT](https://huggingface.co/nlpaueb/legal-bert-small-uncased) model without additional pretraining yet. It uses the same number of parameters/layers and the same tokenizer.
21
+
22
+
23
+ This model can handle long sequences, faster and more efficiently than Longformer or BigBird (from Transformers), and relies on Local + Sparse + Global attention (LSG).
24
+
25
+
26
+ The model requires sequences whose length is a multiple of the block size. The model is "adaptive" and automatically pads the sequences if needed (adaptive=True in config). It is however recommended, thanks to the tokenizer, to truncate the inputs (truncation=True) and optionally to pad to a multiple of the block size (pad_to_multiple_of=...). \
27
+
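+ For example, a minimal tokenization sketch (max_length=4096 and pad_to_multiple_of=128 below simply match the default configuration and are not required):
+ ```python
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained("ccdv/legal-lsg-small-uncased-4096")
+
+ # Truncate to the model maximum length and (optionally) pad to a multiple of the block size
+ inputs = tokenizer(
+     "Some very long legal document ...",
+     truncation=True,
+     max_length=4096,
+     padding=True,            # padding must be enabled for pad_to_multiple_of to apply
+     pad_to_multiple_of=128,  # optional, matches the default block_size
+     return_tensors="pt",
+ )
+ ```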
28
+
29
+ Supports encoder-decoder but I didn't test it extensively.\
30
+ Implemented in PyTorch.
31
+
32
+ ![attn](attn.png)
33
+
34
+ ## Usage
35
+ The model relies on a custom modeling file, so you need to add trust_remote_code=True to use it.
36
+
37
+ ```python
38
+ from transformers import AutoModel, AutoTokenizer
39
+
40
+ model = AutoModel.from_pretrained("ccdv/legal-lsg-small-uncased-4096", trust_remote_code=True)
41
+ tokenizer = AutoTokenizer.from_pretrained("ccdv/legal-lsg-small-uncased-4096")
42
+ ```
43
+
44
+ ## Parameters
45
+ You can change various parameters like:
46
+ * the number of global tokens (num_global_tokens=1)
47
+ * local block size (block_size=128)
48
+ * sparse block size (sparse_block_size=128)
49
+ * sparsity factor (sparsity_factor=2)
50
+ * see config.json file
51
+
52
+ Default parameters work well in practice. If you are short on memory, reduce block sizes, increase sparsity factor and remove dropout in the attention score matrix.
53
+
54
+ ```python
55
+ model = AutoModel.from_pretrained("ccdv/legal-lsg-small-uncased-4096",
56
+ trust_remote_code=True,
57
+ num_global_tokens=16,
58
+ block_size=64,
59
+ sparse_block_size=64,
60
+ sparsity_factor=4,
61
+ attention_probs_dropout_prob=0.0
62
+ )
63
+ ```
64
+
65
+ ## Sparse selection type
66
+
67
+ There are 5 different sparse selection patterns. The best type is task dependent (see the loading sketch after the list). \
68
+ Note that for sequences with length < 2*block_size, the type has no effect.
69
+
70
+ * sparsity_type="norm", select highest norm tokens
71
+ * Works best for a small sparsity_factor (2 to 4)
72
+ * Additional parameters:
73
+ * None
74
+ * sparsity_type="pooling", use average pooling to merge tokens
75
+ * Works best for a small sparsity_factor (2 to 4)
76
+ * Additional parameters:
77
+ * None
78
+ * sparsity_type="lsh", use the LSH algorithm to cluster similar tokens
79
+ * Works best for a large sparsity_factor (4+)
80
+ * LSH relies on random projections, thus inference may differ slightly with different seeds
81
+ * Additional parameters:
82
+ * lsh_num_pre_rounds=1, pre-merge tokens n times before computing centroids
83
+ * sparsity_type="stride", use a striding mechanism per head
84
+ * Each head will use different tokens strided by sparsity_factor
85
+ * Not recommended if sparsity_factor > num_heads
86
+ * sparsity_type="block_stride", use a striding mechanism per head
87
+ * Each head will use blocks of tokens strided by sparsity_factor
88
+ * Not recommended if sparsity_factor > num_heads
89
+
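+ To use a given pattern, pass the corresponding parameters when loading the model. A minimal sketch (the values below are illustrative assumptions, not tuned recommendations):
+ ```python
+ from transformers import AutoModel
+
+ model = AutoModel.from_pretrained(
+     "ccdv/legal-lsg-small-uncased-4096",
+     trust_remote_code=True,
+     sparsity_type="lsh",     # one of "norm", "pooling", "lsh", "stride", "block_stride" or None
+     sparsity_factor=4,       # "lsh" tends to prefer a larger factor
+     lsh_num_pre_rounds=1,    # only used when sparsity_type="lsh"
+ )
+ ```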
90
+ ## Tasks
91
+ Fill mask example:
92
+ ```python
93
+ from transformers import FillMaskPipeline, AutoModelForMaskedLM, AutoTokenizer
94
+
95
+ model = AutoModelForMaskedLM.from_pretrained("ccdv/legal-lsg-small-uncased-4096", trust_remote_code=True)
96
+ tokenizer = AutoTokenizer.from_pretrained("ccdv/legal-lsg-small-uncased-4096")
97
+
98
+ SENTENCES = ["Paris is the <mask> of France.", "The goal of life is <mask>."]
99
+ pipeline = FillMaskPipeline(model, tokenizer)
100
+ output = pipeline(SENTENCES, top_k=1)
101
+
102
+ output = [o[0]["sequence"] for o in output]
103
+ > ['Paris is the capital of France.', 'The goal of life is happiness.']
104
+ ```
105
+
106
+
107
+ Classification example:
108
+ ```python
109
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
110
+
111
+ model = AutoModelForSequenceClassification.from_pretrained("ccdv/legal-lsg-small-uncased-4096",
112
+ trust_remote_code=True,
113
+ pool_with_global=True, # pool with a global token instead of first token
114
+ )
115
+ tokenizer = AutoTokenizer.from_pretrained("ccdv/legal-lsg-small-uncased-4096")
116
+
117
+ SENTENCE = "This is a test for sequence classification. " * 300
118
+ token_ids = tokenizer(
119
+ SENTENCE,
120
+ return_tensors="pt",
121
+ #pad_to_multiple_of=... # Optional
122
+ truncation=True
123
+ )
124
+ output = model(**token_ids)
125
+
126
+ > SequenceClassifierOutput(loss=None, logits=tensor([[-0.3051, -0.1762]], grad_fn=<AddmmBackward>), hidden_states=None, attentions=None)
127
+ ```
128
+
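+ To turn the logits into probabilities or a predicted class, standard PyTorch post-processing applies (nothing LSG specific):
+ ```python
+ import torch
+
+ probs = torch.softmax(output.logits, dim=-1)
+ predicted_class = probs.argmax(dim=-1)
+ ```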
129
+ ## Training global tokens
130
+ To train global tokens and the classification head only:
131
+ ```python
132
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
133
+
134
+ model = AutoModelForSequenceClassification.from_pretrained("ccdv/legal-lsg-small-uncased-4096",
135
+ trust_remote_code=True,
136
+ pool_with_global=True, # pool with a global token instead of first token
137
+ num_global_tokens=16
138
+ )
139
+ tokenizer = AutoTokenizer.from_pretrained("ccdv/legal-lsg-small-uncased-4096")
140
+
141
+ for name, param in model.named_parameters():
142
+ if "global_embeddings" not in name:
143
+ param.requires_grad = False
144
+ else:
145
+ param.requires_grad = True
146
+ ```
147
+
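+ You can sanity-check which parameters will actually be updated, for example:
+ ```python
+ # Count trainable vs. total parameters after freezing
+ trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ total = sum(p.numel() for p in model.parameters())
+ print(f"trainable parameters: {trainable} / {total}")
+ ```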
148
+
149
+ **LEGAL-BERT**
150
+ ```
151
+ @inproceedings{chalkidis-etal-2020-legal,
152
+ title = "{LEGAL}-{BERT}: The Muppets straight out of Law School",
153
+ author = "Chalkidis, Ilias and
154
+ Fergadiotis, Manos and
155
+ Malakasiotis, Prodromos and
156
+ Aletras, Nikolaos and
157
+ Androutsopoulos, Ion",
158
+ booktitle = "Findings of the Association for Computational Linguistics: EMNLP 2020",
159
+ month = nov,
160
+ year = "2020",
161
+ address = "Online",
162
+ publisher = "Association for Computational Linguistics",
163
+ doi = "10.18653/v1/2020.findings-emnlp.261",
164
+ pages = "2898--2904"
165
+ }
166
+ ```
attn.png ADDED
config.json ADDED
@@ -0,0 +1,48 @@
1
+ {
2
+ "_name_or_path": "ccdv/legal-lsg-small-uncased-4096",
3
+ "adaptive": true,
4
+ "architectures": [
5
+ "LSGBertForPreTraining"
6
+ ],
7
+ "attention_probs_dropout_prob": 0.1,
8
+ "auto_map": {
9
+ "AutoConfig": "modeling_lsg_bert.LSGBertConfig",
10
+ "AutoModel": "modeling_lsg_bert.LSGBertModel",
11
+ "AutoModelForCausalLM": "modeling_lsg_bert.LSGBertLMHeadModel",
12
+ "AutoModelForMaskedLM": "modeling_lsg_bert.LSGBertForMaskedLM",
13
+ "AutoModelForMultipleChoice": "modeling_lsg_bert.LSGBertForMultipleChoice",
14
+ "AutoModelForPreTraining": "modeling_lsg_bert.LSGBertForPreTraining",
15
+ "AutoModelForQuestionAnswering": "modeling_lsg_bert.LSGBertForQuestionAnswering",
16
+ "AutoModelForSequenceClassification": "modeling_lsg_bert.LSGBertForSequenceClassification",
17
+ "AutoModelForTokenClassification": "modeling_lsg_bert.LSGBertForTokenClassification"
18
+ },
19
+ "base_model_prefix": "lsg",
20
+ "block_size": 128,
21
+ "bos_token_id": 0,
22
+ "classifier_dropout": null,
23
+ "eos_token_ids": 0,
24
+ "hidden_act": "gelu",
25
+ "hidden_dropout_prob": 0.1,
26
+ "hidden_size": 512,
27
+ "initializer_range": 0.02,
28
+ "intermediate_size": 2048,
29
+ "layer_norm_eps": 1e-12,
30
+ "lsh_num_pre_rounds": 1,
31
+ "max_position_embeddings": 4096,
32
+ "model_type": "bert",
33
+ "num_attention_heads": 8,
34
+ "num_global_tokens": 1,
35
+ "num_hidden_layers": 6,
36
+ "output_past": true,
37
+ "pad_token_id": 0,
38
+ "pool_with_global": true,
39
+ "position_embedding_type": "absolute",
40
+ "sparse_block_size": 128,
41
+ "sparsity_factor": 2,
42
+ "sparsity_type": "norm",
43
+ "torch_dtype": "float32",
44
+ "transformers_version": "4.19.2",
45
+ "type_vocab_size": 2,
46
+ "use_cache": true,
47
+ "vocab_size": 30522
48
+ }
modeling_lsg_bert.py ADDED
@@ -0,0 +1,1307 @@
1
+ from logging import warn
2
+ from transformers.models.bert.modeling_bert import *
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers.models.bert.configuration_bert import BertConfig
6
+ import sys
7
+
8
+ AUTO_MAP = {
9
+ "AutoModel": "modeling_lsg_bert.LSGBertModel",
10
+ "AutoModelForCausalLM": "modeling_lsg_bert.LSGBertLMHeadModel",
11
+ "AutoModelForMaskedLM": "modeling_lsg_bert.LSGBertForMaskedLM",
12
+ "AutoModelForPreTraining": "modeling_lsg_bert.LSGBertForPreTraining",
13
+ "AutoModelForMultipleChoice": "modeling_lsg_bert.LSGBertForMultipleChoice",
14
+ "AutoModelForQuestionAnswering": "modeling_lsg_bert.LSGBertForQuestionAnswering",
15
+ "AutoModelForSequenceClassification": "modeling_lsg_bert.LSGBertForSequenceClassification",
16
+ "AutoModelForTokenClassification": "modeling_lsg_bert.LSGBertForTokenClassification"
17
+ }
18
+
19
+ class LSGBertConfig(BertConfig):
20
+ """
21
+ This class overrides :class:`~transformers.BertConfig`. Please check the superclass for the appropriate
22
+ documentation alongside usage examples.
23
+ """
24
+
25
+ base_model_prefix = "lsg"
26
+ model_type = "bert"
27
+
28
+ def __init__(
29
+ self,
30
+ adaptive=True,
31
+ base_model_prefix="lsg",
32
+ block_size=128,
33
+ lsh_num_pre_rounds=1,
34
+ num_global_tokens=1,
35
+ pool_with_global=True,
36
+ sparse_block_size=128,
37
+ sparsity_factor=2,
38
+ sparsity_type="norm",
39
+ **kwargs
40
+ ):
41
+ """Constructs LSGBertConfig."""
42
+ super().__init__(**kwargs)
43
+
44
+ self.adaptive = adaptive
45
+ self.auto_map = AUTO_MAP
46
+ self.base_model_prefix = base_model_prefix
47
+ self.block_size = block_size
48
+ self.lsh_num_pre_rounds = lsh_num_pre_rounds
49
+ self.num_global_tokens = num_global_tokens
50
+ self.pool_with_global = pool_with_global
51
+ self.sparse_block_size = sparse_block_size
52
+ self.sparsity_factor = sparsity_factor
53
+ self.sparsity_type = sparsity_type
54
+
55
+ if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride"]:
56
+ logger.warning(
57
+ "[WARNING CONFIG]: sparsity_type not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'], setting sparsity_type=None, computation will skip sparse attention")
58
+ self.sparsity_type = None
59
+
60
+ if self.sparsity_type in ["stride", "block_stride"]:
61
+ if self.sparsity_factor > self.num_attention_heads:
62
+ logger.warning(
63
+ "[WARNING CONFIG]: sparsity_factor > num_attention_heads is not recommended for stride/block_stride sparsity"
64
+ )
65
+
66
+ if self.num_global_tokens < 1:
67
+ logger.warning(
68
+ "[WARNING CONFIG]: num_global_tokens < 1 is not compatible, setting num_global_tokens=1"
69
+ )
70
+ self.num_global_tokens = 1
71
+ elif self.num_global_tokens > 512:
72
+ logger.warning(
73
+ "[WARNING CONFIG]: num_global_tokens > 512 is not compatible, setting num_global_tokens=512"
74
+ )
75
+ self.num_global_tokens = 512
76
+
77
+ if self.sparsity_factor > 0:
78
+ assert self.block_size % self.sparsity_factor == 0, "[ERROR CONFIG]: block_size must be divisible by sparsity_factor"
79
+ assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"
80
+
81
+
82
+ class BaseSelfAttention(nn.Module):
83
+
84
+ def init_modules(self, config):
85
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
86
+ config, "embedding_size"
87
+ ):
88
+ raise ValueError(
89
+ "The hidden size (%d) is not a multiple of the number of attention "
90
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
91
+ )
92
+
93
+ self.num_attention_heads = config.num_attention_heads
94
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
95
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
96
+
97
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
98
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
99
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
100
+
101
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
102
+
103
+ def transpose_for_scores(self, x):
104
+ new_x_shape = x.size()[:-1] + (
105
+ self.num_attention_heads,
106
+ self.attention_head_size,
107
+ )
108
+ x = x.view(*new_x_shape)
109
+ return x.permute(0, 2, 1, 3)
110
+
111
+ def reshape_output(self, context_layer):
112
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
113
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
114
+ return context_layer.view(*new_context_layer_shape)
115
+
116
+ def project_QKV(self, hidden_states):
117
+
118
+ query_layer = self.transpose_for_scores(self.query(hidden_states))
119
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
120
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
121
+ return query_layer, key_layer, value_layer
122
+
123
+
124
+ class BaseAttentionProduct(nn.Module):
125
+
126
+ def __init__(self, config):
127
+ """
128
+ Compute attention: softmax(Q @ K.T) @ V
129
+ """
130
+ super().__init__()
131
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
132
+
133
+ def forward(self, query_layer, key_layer, value_layer, attention_mask=None):
134
+
135
+ d = query_layer.shape[-1]
136
+
137
+ # Take the dot product between "query" and "key" to get the raw attention scores.
138
+ attention_scores = query_layer @ key_layer.transpose(-1, -2) / math.sqrt(d)
139
+
140
+ del query_layer
141
+ del key_layer
142
+
143
+ if attention_mask is not None:
144
+ # Apply the attention mask (precomputed for all layers in BertModel forward() function)
145
+ attention_scores = attention_scores + attention_mask
146
+ del attention_mask
147
+
148
+ # Normalize the attention scores to probabilities.
149
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
150
+
151
+ # This is actually dropping out entire tokens to attend to, which might
152
+ # seem a bit unusual, but is taken from the original Transformer paper.
153
+ context_layer = self.dropout(attention_probs) @ value_layer
154
+
155
+ return context_layer
156
+
157
+
158
+ class CausalAttentionProduct(nn.Module):
159
+
160
+ def __init__(self, config):
161
+ """
162
+ Compute attention: softmax(Q @ K.T) @ V
163
+ """
164
+ super().__init__()
165
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
166
+ self.block_size = config.block_size
167
+
168
+ def forward(self, query_layer, key_layer, value_layer, attention_mask=None, causal_shape=None):
169
+
170
+ d = query_layer.shape[-1]
171
+
172
+ # Take the dot product between "query" and "key" to get the raw attention scores.
173
+ attention_scores = query_layer @ key_layer.transpose(-1, -2) / math.sqrt(d)
174
+
175
+ del query_layer
176
+ del key_layer
177
+
178
+ if attention_mask is not None:
179
+ # Apply the attention mask (precomputed for all layers in BertModel forward() function)
180
+ attention_scores = attention_scores + attention_mask
181
+
182
+ # Add causal mask
183
+ causal_shape = (self.block_size, self.block_size) if causal_shape is None else causal_shape
184
+ causal_mask = torch.tril(torch.ones(*causal_shape, device=attention_mask.device), diagonal=-1).T * (-10000)
185
+ attention_scores[..., -causal_shape[0]:, -causal_shape[1]:] = causal_mask
186
+
187
+ del attention_mask
188
+
189
+ # Normalize the attention scores to probabilities.
190
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
191
+
192
+ # This is actually dropping out entire tokens to attend to, which might
193
+ # seem a bit unusual, but is taken from the original Transformer paper.
194
+ context_layer = self.dropout(attention_probs) @ value_layer
195
+
196
+ return context_layer
197
+
198
+
199
+ class LSGAttentionProduct(nn.Module):
200
+
201
+ def __init__(self, config, block_size=None, sparse_block_size=None, sparsity_factor=4, is_causal=False):
202
+ """
203
+ Compute block or overlapping blocks attention products
204
+ """
205
+ super().__init__()
206
+
207
+ self.block_size = block_size
208
+ self.sparse_block_size = sparse_block_size
209
+ self.sparsity_factor = sparsity_factor
210
+ self.is_causal = is_causal
211
+
212
+ if self.block_size is None:
213
+ self.block_size = config.block_size
214
+
215
+ if self.sparse_block_size is None:
216
+ self.sparse_block_size = config.sparse_block_size
217
+
218
+ # Shape of blocks
219
+ self.local_shapes = (self.block_size*3, self.block_size)
220
+ if self.sparse_block_size and self.sparsity_factor > 0:
221
+ self.sparse_shapes = (self.sparse_block_size*3, self.block_size//self.sparsity_factor)
222
+
223
+ if is_causal:
224
+ self.attention = CausalAttentionProduct(config)
225
+ else:
226
+ self.attention = BaseAttentionProduct(config)
227
+
228
+ def build_lsg_inputs(self, hidden_states, sparse_hidden_states, global_hidden_states, is_attn_mask=False):
229
+
230
+ # Build local tokens
231
+ local_hidden_states = self.reshape_to_local_block(hidden_states, is_attn_mask)
232
+ del hidden_states
233
+
234
+ # Build sparse tokens
235
+ if sparse_hidden_states is not None:
236
+ sparse_hidden_states = self.reshape_to_sparse_block(sparse_hidden_states, is_attn_mask)
237
+
238
+ return self.cat_global_sparse_local_tokens(global_hidden_states, sparse_hidden_states, local_hidden_states)
239
+
240
+ def forward(
241
+ self,
242
+ query_layer,
243
+ key_layer,
244
+ value_layer,
245
+ attention_mask=None,
246
+ sparse_key=None,
247
+ sparse_value=None,
248
+ sparse_mask=None,
249
+ global_key=None,
250
+ global_value=None,
251
+ global_mask=None
252
+ ):
253
+
254
+ # Input batch, heads, length, hidden_size
255
+ n, h, t, d = query_layer.size()
256
+ n_blocks = t // self.block_size
257
+ assert t % self.block_size == 0
258
+
259
+ key_layer = self.build_lsg_inputs(
260
+ key_layer,
261
+ sparse_key,
262
+ global_key
263
+ )
264
+ del sparse_key
265
+ del global_key
266
+
267
+ value_layer = self.build_lsg_inputs(
268
+ value_layer,
269
+ sparse_value,
270
+ global_value
271
+ )
272
+ del sparse_value
273
+ del global_value
274
+
275
+ attention_mask = self.build_lsg_inputs(
276
+ attention_mask,
277
+ sparse_mask,
278
+ global_mask.transpose(-1, -2),
279
+ is_attn_mask=True
280
+ ).transpose(-1, -2)
281
+ del sparse_mask
282
+ del global_mask
283
+
284
+ # expect (..., t, d) shape
285
+ # Compute attention
286
+ context_layer = self.attention(
287
+ query_layer=self.chunk(query_layer, n_blocks),
288
+ key_layer=key_layer,
289
+ value_layer=value_layer,
290
+ attention_mask=attention_mask
291
+ )
292
+
293
+ return context_layer.reshape(n, h, -1, d)
294
+
295
+ def reshape_to_local_block(self, hidden_states, is_attn_mask=False):
296
+
297
+ size, step = self.local_shapes
298
+ s = (size - step) // 2
299
+
300
+ # Pad before block reshaping
301
+ if is_attn_mask:
302
+ pad_value = -10000
303
+ hidden_states = hidden_states.transpose(-1, -2)
304
+ else:
305
+ pad_value = 0
306
+
307
+ hidden_states = torch.nn.functional.pad(
308
+ hidden_states.transpose(-1, -2),
309
+ pad=(s, s),
310
+ value=pad_value
311
+ ).transpose(-1, -2)
312
+
313
+ # Make blocks
314
+ hidden_states = hidden_states.unfold(-2, size=size, step=step).transpose(-1, -2)
315
+
316
+ # Skip third block if causal
317
+ if self.is_causal:
318
+ return hidden_states[..., :size*2//3, :]
319
+
320
+ return hidden_states
321
+
322
+ def reshape_to_sparse_block(self, hidden_states, is_attn_mask=False):
323
+
324
+ size, step = self.sparse_shapes
325
+
326
+ # Handle the case of an odd step size
327
+ odd_offset = (step % 2)
328
+
329
+ # n, h, t, d*2 + 1
330
+ size = size*2
331
+ s = (size - step) // 2 + odd_offset
332
+
333
+ # Pad before block reshaping
334
+ if is_attn_mask:
335
+ pad_value = -10000
336
+ hidden_states = hidden_states.transpose(-1, -2)
337
+ else:
338
+ pad_value = 0
339
+
340
+ hidden_states = torch.nn.functional.pad(
341
+ hidden_states.transpose(-1, -2),
342
+ pad=(s, s),
343
+ value=pad_value
344
+ ).transpose(-1, -2)
345
+
346
+ # Make blocks
347
+ hidden_states = hidden_states.unfold(-2, size=size, step=step).transpose(-1, -2)
348
+
349
+ # Fix case where block_size == sparsity_factor
350
+ if odd_offset:
351
+ hidden_states = hidden_states[..., :-1, :, :]
352
+
353
+ # Indexes for selection
354
+ u = (size - self.block_size * 3 // self.sparsity_factor) // 2 + odd_offset
355
+ s = self.sparse_block_size
356
+
357
+ # Skip right block if causal
358
+ if self.is_causal:
359
+ return hidden_states[..., u-s:u, :]
360
+
361
+ u_ = u + odd_offset
362
+ return torch.cat([hidden_states[..., u-s:u, :], hidden_states[..., -u_:-u_+s, :]], dim=-2)
363
+
364
+ def cat_global_sparse_local_tokens(self, x_global, x_sparse=None, x_local=None, dim=-2):
365
+
366
+ n, h, b, t, d = x_local.size()
367
+ x_global = x_global.unsqueeze(-3).expand(-1, -1, b, -1, -1)
368
+ if x_sparse is not None:
369
+ return torch.cat([x_global, x_sparse, x_local], dim=dim)
370
+ return torch.cat([x_global, x_local], dim=dim)
371
+
372
+ def chunk(self, x, n_blocks):
373
+
374
+ t, d = x.size()[-2:]
375
+ return x.reshape(*x.size()[:-2], n_blocks, -1, d)
376
+
377
+
378
+ class LSGBertEmbeddings(BertEmbeddings):
379
+
380
+ def __init__(self, config):
381
+ super().__init__(config)
382
+
383
+ self.num_global_tokens = config.num_global_tokens
384
+
385
+ # Hardcoded but partially trained
386
+ self.global_embeddings = nn.Embedding(512, embedding_dim=config.hidden_size, )
387
+
388
+ self.block_size = config.block_size
389
+
390
+ def forward(
391
+ self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
392
+ ):
393
+ if input_ids is not None:
394
+ input_shape = input_ids.size()
395
+ else:
396
+ input_shape = inputs_embeds.size()[:-1]
397
+
398
+ seq_length = input_shape[1]
399
+
400
+ if position_ids is None:
401
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
402
+
403
+ # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
404
+ # when it's auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
405
+ # issue #5664
406
+ if token_type_ids is None:
407
+ if hasattr(self, "token_type_ids"):
408
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
409
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
410
+ token_type_ids = buffered_token_type_ids_expanded
411
+ else:
412
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
413
+
414
+ if inputs_embeds is None:
415
+ inputs_embeds = self.word_embeddings(input_ids)
416
+ token_type_embeddings = self.token_type_embeddings(token_type_ids[:, :seq_length])
417
+
418
+ embeddings = inputs_embeds + token_type_embeddings
419
+ if self.position_embedding_type == "absolute":
420
+ position_embeddings = self.position_embeddings(position_ids[:, :seq_length])
421
+ embeddings += position_embeddings
422
+
423
+ #if self.num_global_tokens < 0:
424
+ n, t, d = embeddings.size()
425
+
426
+ # Add global_tokens
427
+ indexes = torch.arange(self.num_global_tokens, device=embeddings.device).reshape(1, -1)
428
+ global_embeddings = self.global_embeddings(indexes)
429
+ embeddings = torch.cat([global_embeddings.expand(n, -1, d), embeddings], dim=-2)
430
+
431
+ embeddings = self.LayerNorm(embeddings)
432
+ embeddings = self.dropout(embeddings)
433
+ return embeddings
434
+
435
+
436
+ class LSGSelfAttention(BaseSelfAttention):
437
+ '''
438
+ Compute local attention with overlapping blocks
439
+ Use global attention for tokens with highest norm
440
+ '''
441
+ def __init__(self, config):
442
+ super().__init__()
443
+
444
+ self.init_modules(config)
445
+
446
+ self.block_size = config.block_size
447
+ self.sparse_block_size = config.sparse_block_size
448
+ self.num_global_tokens = config.num_global_tokens
449
+ self.sparsity_factor = config.sparsity_factor
450
+ self.is_causal = config.is_decoder
451
+ self.is_decoder = config.is_decoder
452
+
453
+ self.attention = LSGAttentionProduct(
454
+ config,
455
+ block_size=config.block_size,
456
+ sparse_block_size=config.sparse_block_size,
457
+ sparsity_factor=self.sparsity_factor,
458
+ is_causal=self.is_causal
459
+ )
460
+
461
+ if self.is_causal:
462
+ self.causal_attention = CausalAttentionProduct(config)
463
+ self.full_attention = BaseAttentionProduct(config)
464
+
465
+ sparse_functions = {
466
+ "norm": self.get_sparse_tokens_with_norm,
467
+ "pooling": self.get_sparse_tokens_with_pooling,
468
+ "lsh": self.get_sparse_tokens_with_lsh,
469
+ "stride": self.get_sparse_tokens_with_stride,
470
+ "block_stride": self.get_sparse_tokens_with_block_stride,
471
+ }
472
+
473
+ self.sparsity_type = config.sparsity_type
474
+ self.get_sparse_elements = sparse_functions.get(self.sparsity_type, lambda x, y, z: (None, None, None))
475
+
476
+ if config.sparsity_type == "lsh":
477
+ self.lsh_num_pre_rounds = config.lsh_num_pre_rounds
478
+
479
+ def get_sparse_tokens_with_norm(self, keys, values, mask):
480
+
481
+ if self.sparsity_factor == 1:
482
+ return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
483
+
484
+ with torch.no_grad():
485
+
486
+ block_size = min(self.block_size, self.sparse_block_size)
487
+ key_norm = keys.detach().norm(dim=-1, keepdim=True)
488
+ key_norm = key_norm * ~mask.transpose(-1, -2).bool()
489
+ key_norm = self.chunk(key_norm, block_size)
490
+
491
+ n, h, b, t, d = key_norm.size()
492
+
493
+ idx = key_norm.argsort(dim=-2)
494
+ del key_norm
495
+ idx += (torch.arange(b, device=keys.device)*t).reshape(1, 1, b, 1, 1)
496
+
497
+ split = (t - block_size // self.sparsity_factor, block_size // self.sparsity_factor)
498
+ sparse_idx = idx.split(split, -2)[-1].reshape(n, h, -1, 1)
499
+
500
+ d = keys.size()[-1]
501
+ keys = keys.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
502
+ values = values.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
503
+ mask = mask.expand(-1, h, -1, -1).transpose(-1, -2).gather(dim=-2, index=sparse_idx).transpose(-1, -2)
504
+
505
+ return keys, values, mask
506
+
507
+ def get_sparse_tokens_with_pooling(self, keys, values, mask):
508
+
509
+ if self.sparsity_factor == 1:
510
+ return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
511
+
512
+ keys = self.chunk(keys, self.sparsity_factor)
513
+ values = self.chunk(values, self.sparsity_factor)
514
+
515
+ n, h, b, t, d = keys.size()
516
+ mask = mask.reshape(n, 1, b, 1, t)
517
+ mask = ~mask.transpose(-1, -2).bool()
518
+
519
+ keys = keys * mask
520
+ values = values * mask
521
+
522
+ mask = mask.sum(dim=-2)
523
+ keys = keys.sum(dim=-2) / (mask + 1e-6)
524
+ values = values.sum(dim=-2) / (mask + 1e-6)
525
+
526
+ mask = - (1. - mask.clamp(0, 1)) * 1e4
527
+ return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
528
+
529
+ def get_sparse_tokens_with_stride(self, keys, values, mask):
530
+
531
+ if self.sparsity_factor == 1:
532
+ return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
533
+
534
+ n, h, t, d = keys.size()
535
+ sparse_idx = torch.arange(t // self.sparsity_factor, device=keys.device) * self.sparsity_factor
536
+ sparse_idx = sparse_idx.reshape(1, 1, -1, 1) + (torch.arange(h, device=keys.device) % self.sparsity_factor).reshape(1, h, 1, 1)
537
+ sparse_idx = sparse_idx.expand(n, h, -1, 1)
538
+
539
+ keys = keys.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
540
+ values = values.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
541
+ mask = mask.expand(-1, h, -1, -1).transpose(-1, -2).gather(dim=-2, index=sparse_idx).transpose(-1, -2)
542
+
543
+ return keys, values, mask
544
+
545
+ def get_sparse_tokens_with_block_stride(self, keys, values, mask):
546
+
547
+ if self.sparsity_factor == 1:
548
+ return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
549
+
550
+ n, h, t, d = keys.size()
551
+
552
+ t, b = self.block_size, t // self.block_size
553
+ sparse_idx = torch.arange(t // self.sparsity_factor, device=keys.device)
554
+ sparse_idx = sparse_idx.reshape(1, 1, 1, -1, 1) + torch.arange(h, device=keys.device).reshape(1, h, 1, 1, 1) * (t // self.sparsity_factor)
555
+ sparse_idx = (sparse_idx % t)
556
+ sparse_idx = sparse_idx + torch.arange(b, device=keys.device).reshape(1, 1, -1, 1, 1) * t
557
+ sparse_idx = sparse_idx.reshape(1, h, -1, 1).expand(n, h, -1, 1)
558
+
559
+ keys = keys.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
560
+ values = values.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
561
+ mask = mask.expand(-1, h, -1, -1).transpose(-1, -2).gather(dim=-2, index=sparse_idx).transpose(-1, -2)
562
+
563
+ return keys, values, mask
564
+
565
+ def get_sparse_tokens_with_lsh(self, keys, values, mask):
566
+
567
+ if self.sparsity_factor == 1:
568
+ return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
569
+
570
+ block_size = min(self.block_size, self.sparse_block_size)
571
+ keys = self.chunk(keys, block_size)
572
+ values = self.chunk(values, block_size)
573
+
574
+ n, h, b, t, d = keys.size()
575
+ mask = mask.reshape(n, 1, b, 1, t)
576
+ mask = ~mask.transpose(-1, -2).bool()
577
+
578
+ keys = keys * mask
579
+ values = values * mask
580
+ mask = mask.expand(-1, h, -1, -1, -1).float()
581
+
582
+ extra_factor = 1
583
+
584
+ for _ in range(self.lsh_num_pre_rounds):
585
+ keys, values, mask = self.lsh_round(keys, values, mask, t*extra_factor)
586
+
587
+ keys, values, mask = self.lsh_round(keys, values, mask, t//self.sparsity_factor)
588
+ keys /= mask + 1e-8
589
+ values /= mask + 1e-8
590
+
591
+ mask = -10000 * (1. - mask.clamp(0, 1))
592
+
593
+ return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
594
+
595
+ def lsh_round(self, keys, values, mask, output_size):
596
+
597
+ with torch.no_grad():
598
+
599
+ n_hashes = output_size // 2
600
+ n, h, b, t, d = keys.size()
601
+ binary_mask = mask.clamp(0, 1)
602
+
603
+ indexes = (torch.nn.functional.normalize(keys, dim=-1) * binary_mask) @ torch.randn(1, h, 1, d, n_hashes, device=keys.device)
604
+ indexes = torch.cat([indexes, -indexes], dim=-1).argmax(dim=-1, keepdim=True)
605
+
606
+ n, h, b, t, d = keys.size()
607
+
608
+ x_ = torch.zeros(n, h, b, output_size, d, device=keys.device)
609
+ mask_ = torch.zeros(n, h, b, output_size, 1, device=keys.device)
610
+ keys = torch.scatter_add(x_, dim=-2, index=indexes.expand(-1, -1, -1, -1, d), src=keys)
611
+ values = torch.scatter_add(x_, dim=-2, index=indexes.expand(-1, -1, -1, -1, d), src=values)
612
+ mask = torch.scatter_add(mask_, dim=-2, index=indexes, src=mask)
613
+
614
+ return keys[..., :output_size, :], values[..., :output_size, :], mask[..., :output_size, :]
615
+
616
+ def forward(
617
+ self,
618
+ hidden_states,
619
+ attention_mask=None,
620
+ head_mask=None,
621
+ encoder_hidden_states=None,
622
+ encoder_attention_mask=None,
623
+ past_key_value=None,
624
+ output_attentions=False,
625
+ ):
626
+
627
+ query_layer = self.query(hidden_states)
628
+
629
+ # If this is instantiated as a cross-attention module, the keys
630
+ # and values come from an encoder; the attention mask needs to be
631
+ # such that the encoder's padding tokens are not attended to.
632
+ is_cross_attention = encoder_hidden_states is not None
633
+
634
+ if is_cross_attention and past_key_value is not None:
635
+ # reuse k,v, cross_attentions
636
+ key_layer = past_key_value[0]
637
+ value_layer = past_key_value[1]
638
+ attention_mask = encoder_attention_mask
639
+ elif is_cross_attention:
640
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
641
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
642
+ attention_mask = encoder_attention_mask
643
+ elif past_key_value is not None:
644
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
645
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
646
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
647
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
648
+ else:
649
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
650
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
651
+
652
+ query_layer = self.transpose_for_scores(query_layer)
653
+
654
+ if self.is_decoder:
655
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
656
+ # Further calls to cross_attention layer can then reuse all cross-attention
657
+ # key/value_states (first "if" case)
658
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
659
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
660
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
661
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
662
+ past_key_value = (key_layer, value_layer)
663
+
664
+ if is_cross_attention:
665
+ outputs = self.cross_attention_forward(
666
+ query_layer=query_layer,
667
+ key_layer=key_layer,
668
+ value_layer=value_layer,
669
+ attention_mask=attention_mask,
670
+ output_attentions=output_attentions
671
+ )
672
+ else:
673
+ outputs = self.causal_forward(
674
+ query_layer,
675
+ key_layer,
676
+ value_layer,
677
+ attention_mask=attention_mask,
678
+ output_attentions=output_attentions,
679
+ )
680
+
681
+ outputs = outputs + ((key_layer, value_layer),)
682
+
683
+ else:
684
+ outputs = self.not_causal_forward(
685
+ query_layer,
686
+ key_layer,
687
+ value_layer,
688
+ attention_mask=attention_mask,
689
+ output_attentions=output_attentions
690
+ )
691
+
692
+ #if head_mask is not None:
693
+ # outputs = (outputs[0] * head_mask[:, :, :1, :1], ) + outputs[1:]
694
+ return outputs
695
+
696
+ def causal_forward(
697
+ self,
698
+ query_layer,
699
+ key_layer,
700
+ value_layer,
701
+ attention_mask=None,
702
+ output_attentions=False,
703
+ ):
704
+
705
+ n, h, t, d = key_layer.size()
706
+
707
+ # Cat global mask
708
+ attention_mask = torch.nn.functional.pad(attention_mask, (self.num_global_tokens, 0), value=0)
709
+
710
+ # Split input into global tokens and other tokens
711
+ split = (self.num_global_tokens, t - self.num_global_tokens)
712
+ global_query, query_layer = query_layer.split(split, dim=-2)
713
+
714
+ # Use normal causal attention if local attention covers every token
715
+ if t <= 2 * self.block_size + self.num_global_tokens:
716
+ context_layer = self.causal_attention(
717
+ query_layer=query_layer,
718
+ key_layer=key_layer,
719
+ value_layer=value_layer,
720
+ attention_mask=attention_mask,
721
+ causal_shape=(t - self.num_global_tokens, t - self.num_global_tokens)
722
+ )
723
+
724
+ context_layer = torch.cat([global_query, context_layer], dim=-2)
725
+ return (self.reshape_output(context_layer), )
726
+
727
+ # Split K Q M on global and non global
728
+ global_key, key_layer = key_layer.split(split, dim=-2)
729
+ global_value, value_layer = value_layer.split(split, dim=-2)
730
+ global_mask, attention_mask = attention_mask.split(split, dim=-1)
731
+
732
+ n, h, t, d = key_layer.size()
733
+
734
+ # Get sparse idx
735
+ sparse_key, sparse_value, sparse_mask = (None, None, None)
736
+ if self.sparse_block_size and self.sparsity_factor > 0:
737
+ sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(key_layer, value_layer, attention_mask)
738
+
739
+ # Expand masks on heads
740
+ attention_mask = attention_mask.expand(-1, h, -1, -1)
741
+ global_mask = global_mask.expand(-1, h, -1, -1)
742
+
743
+ # Compute dot product attention
744
+ context_layer = self.attention(
745
+ query_layer,
746
+ key_layer,
747
+ value_layer,
748
+ attention_mask,
749
+ sparse_key=sparse_key,
750
+ sparse_value=sparse_value,
751
+ sparse_mask=sparse_mask,
752
+ global_key=global_key,
753
+ global_value=global_value,
754
+ global_mask=global_mask
755
+ )
756
+
757
+ # Merge pseudo global (causal) and local-sparse tokens
758
+ context_layer = torch.cat([global_query, context_layer], dim=-2)
759
+ context_layer = self.reshape_output(context_layer)
760
+
761
+ return (context_layer,)
762
+
763
+ def not_causal_forward(
764
+ self,
765
+ query_layer,
766
+ key_layer,
767
+ value_layer,
768
+ attention_mask=None,
769
+ output_attentions=False,
770
+ ):
771
+
772
+ n, h, t, d = query_layer.size()
773
+
774
+ # Cat global mask
775
+ attention_mask = torch.nn.functional.pad(attention_mask, (self.num_global_tokens, 0), value=0)
776
+
777
+ # Use normal attention if local attention covers every token
778
+ if t <= 2 * self.block_size + self.num_global_tokens:
779
+ context_layer = self.full_attention(
780
+ query_layer=query_layer,
781
+ key_layer=key_layer,
782
+ value_layer=value_layer,
783
+ attention_mask=attention_mask
784
+ )
785
+ return (self.reshape_output(context_layer), )
786
+
787
+ # Split input into global tokens and other tokens
788
+ split = (self.num_global_tokens, t - self.num_global_tokens)
789
+ global_query, query_layer = query_layer.split(split, dim=-2)
790
+
791
+ # Get global_attention
792
+ bos = self.full_attention(
793
+ query_layer=global_query,
794
+ key_layer=key_layer,
795
+ value_layer=value_layer,
796
+ attention_mask=attention_mask
797
+ )
798
+
799
+ # Split K Q M on global and non global
800
+ global_key, key_layer = key_layer.split(split, dim=-2)
801
+ global_value, value_layer = value_layer.split(split, dim=-2)
802
+ global_mask, attention_mask = attention_mask.split(split, dim=-1)
803
+
804
+ n, h, t, d = key_layer.size()
805
+
806
+ # Get sparse idx
807
+ sparse_key, sparse_value, sparse_mask = (None, None, None)
808
+
809
+ if self.sparse_block_size and self.sparsity_factor > 0:
810
+ sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(key_layer, value_layer, attention_mask)
811
+
812
+ # Expand masks on heads
813
+ attention_mask = attention_mask.expand(-1, h, -1, -1)
814
+ global_mask = global_mask.expand(-1, h, -1, -1)
815
+
816
+ # Compute dot product attention
817
+ context_layer = self.attention(
818
+ query_layer,
819
+ key_layer,
820
+ value_layer,
821
+ attention_mask,
822
+ sparse_key=sparse_key,
823
+ sparse_value=sparse_value,
824
+ sparse_mask=sparse_mask,
825
+ global_key=global_key,
826
+ global_value=global_value,
827
+ global_mask=global_mask
828
+ )
829
+
830
+ # Merge global and local-sparse tokens
831
+ context_layer = torch.cat([bos, context_layer], dim=-2)
832
+ context_layer = self.reshape_output(context_layer)
833
+
834
+ return (context_layer,)
835
+
836
+ def cross_attention_forward(
837
+ self,
838
+ query_layer,
839
+ key_layer,
840
+ value_layer,
841
+ attention_mask=None,
842
+ output_attentions=False,
843
+ ):
844
+
845
+ context_layer = self.full_attention(
846
+ query_layer=query_layer,
847
+ key_layer=key_layer,
848
+ value_layer=value_layer,
849
+ attention_mask=attention_mask
850
+ )
851
+ return (self.reshape_output(context_layer), )
852
+
853
+ def chunk(self, x, chunk_size):
854
+
855
+ n, h, t, d = x.size()
856
+ return x.reshape(n, h, -1, chunk_size, d)
857
+
858
+
859
+ class LSGBertSelfOutput(BertSelfOutput):
860
+
861
+ def __init__(self, config):
862
+ super().__init__(config)
863
+
864
+
865
+ class LSGAttention(BertAttention):
866
+
867
+ def __init__(self, config):
868
+
869
+ nn.Module.__init__(self)
870
+
871
+ self.self = LSGSelfAttention(config)
872
+ self.output = LSGBertSelfOutput(config)
873
+ self.pruned_heads = set()
874
+
875
+
876
+ class LSGBertIntermediate(BertIntermediate):
877
+
878
+ def __init__(self, config):
879
+ super().__init__(config)
880
+
881
+
882
+ class LSGBertOutput(BertOutput):
883
+
884
+ def __init__(self, config):
885
+ super().__init__(config)
886
+
887
+
888
+ class LSGBertLayer(BertLayer):
889
+
890
+ def __init__(self, config):
891
+
892
+ nn.Module.__init__(self)
893
+
894
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
895
+ self.seq_len_dim = 1
896
+ self.attention = LSGAttention(config)
897
+ self.is_decoder = config.is_decoder
898
+ self.add_cross_attention = config.add_cross_attention
899
+ if self.add_cross_attention:
900
+ if not self.is_decoder:
901
+ assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
902
+ self.crossattention = LSGAttention(config)
903
+ self.intermediate = LSGBertIntermediate(config)
904
+ self.output = LSGBertOutput(config)
905
+
906
+
907
+ class LSGBertEncoder(BertEncoder):
908
+
909
+ def __init__(self, config):
910
+
911
+ nn.Module.__init__(self)
912
+
913
+ self.config = config
914
+ self.layer = nn.ModuleList([LSGBertLayer(config) for _ in range(config.num_hidden_layers)])
915
+ self.gradient_checkpointing = False
916
+
917
+
918
+ class LSGBertPooler(BertPooler):
919
+
920
+ def __init__(self, config):
921
+ super().__init__(config)
922
+
923
+
924
+ class LSGBertPredictionHeadTransform(BertPredictionHeadTransform):
925
+
926
+ def __init__(self, config):
927
+ super().__init__(config)
928
+
929
+
930
+ class LSGBertLMPredictionHead(BertLMPredictionHead):
931
+
932
+ def __init__(self, config):
933
+
934
+ nn.Module.__init__(self)
935
+
936
+ self.transform = LSGBertPredictionHeadTransform(config)
937
+
938
+ # The output weights are the same as the input embeddings, but there is
939
+ # an output-only bias for each token.
940
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
941
+
942
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
943
+
944
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
945
+ self.decoder.bias = self.bias
946
+
947
+
948
+ class LSGBertOnlyMLMHead(BertOnlyMLMHead):
949
+ """LSG Head for masked language modeling."""
950
+
951
+ def __init__(self, config):
952
+
953
+ nn.Module.__init__(self)
954
+
955
+ self.predictions = LSGBertLMPredictionHead(config)
956
+
957
+
958
+ class LSGBertOnlyNSPHead(BertOnlyNSPHead):
959
+
960
+ def __init__(self, config):
961
+ super().__init__(config)
962
+
963
+
964
+ class LSGBertPreTrainingHeads(BertPreTrainingHeads):
965
+
966
+ def __init__(self, config):
967
+
968
+ nn.Module.__init__(self)
969
+
970
+ self.predictions = BertLMPredictionHead(config)
971
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
972
+
973
+
974
+ class LSGBertPreTrainedModel(BertPreTrainedModel):
975
+ """
976
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
977
+ models.
978
+ """
979
+
980
+ config_class = LSGBertConfig
981
+
982
+ def _set_gradient_checkpointing(self, module, value=False):
983
+ if isinstance(module, (BertEncoder, LSGBertEncoder)):
984
+ module.gradient_checkpointing = value
985
+
986
+
987
+ class LSGBertModel(LSGBertPreTrainedModel, BertModel):
988
+ """
989
+ This class overrides :class:`~transformers.BertModel`. Please check the superclass for the appropriate
990
+ documentation alongside usage examples.
991
+ """
992
+
993
+ config_class = LSGBertConfig
994
+
995
+ def __init__(self, config, add_pooling_layer=True):
996
+
997
+ LSGBertPreTrainedModel.__init__(self, config)
998
+
999
+ self.config = config
1000
+ assert hasattr(config, "num_global_tokens")
1001
+ self.num_global_tokens = config.num_global_tokens
1002
+ self.pad_idx = config.pad_token_id
1003
+
1004
+ assert hasattr(config, "block_size") and hasattr(config, "adaptive")
1005
+ self.block_size = config.block_size
1006
+ self.adaptive = config.adaptive
1007
+ self.pool_with_global = config.pool_with_global
1008
+
1009
+ self.embeddings = LSGBertEmbeddings(config)
1010
+ self.encoder = LSGBertEncoder(config)
1011
+ self.pooler = LSGBertPooler(config) if add_pooling_layer else None
1012
+
1013
+ if config.add_cross_attention:
1014
+ logger.warning(
1015
+ "Cross attention is computed using full attention since it is not LSG compatible."
1016
+ )
1017
+
1018
+ # Initialize weights and apply final processing
1019
+ self.post_init()
1020
+
1021
+ def forward(
1022
+ self,
1023
+ input_ids=None,
1024
+ attention_mask=None,
1025
+ token_type_ids=None,
1026
+ position_ids=None,
1027
+ head_mask=None,
1028
+ inputs_embeds=None,
1029
+ encoder_hidden_states=None,
1030
+ encoder_attention_mask=None,
1031
+ past_key_values=None,
1032
+ use_cache=None,
1033
+ output_attentions=None,
1034
+ output_hidden_states=None,
1035
+ return_dict=None
1036
+ ):
1037
+
1038
+ inputs_ = input_ids if input_ids is not None else inputs_embeds
1039
+ n, t = inputs_.size()[:2]
1040
+
1041
+ if attention_mask is None:
1042
+ attention_mask = torch.ones(n, t, device=inputs_.device)
1043
+ if token_type_ids is None:
1044
+ token_type_ids = torch.zeros(n, t, device=inputs_.device).long()
1045
+
1046
+ b = self.block_size * 2
1047
+ pad = t % self.block_size
1048
+
1049
+ # Check if t is multiple of block_size and pad
1050
+ if self.adaptive and t > b and pad > 0:
1051
+ pad_length = self.block_size - pad
1052
+ if input_ids is not None:
1053
+ input_ids = torch.nn.functional.pad(input_ids, (0, pad_length), value=self.pad_idx)
1054
+ else:
1055
+ inputs_embeds = torch.nn.functional.pad(inputs_embeds.transpose(-1, -2), (0, pad_length), value=0.).transpose(-1, -2)
1056
+
1057
+ attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_length), value=0)
1058
+ token_type_ids = torch.nn.functional.pad(token_type_ids, (0, pad_length), value=0)
1059
+
1060
+ if position_ids is not None:
1061
+ position_ids = torch.nn.functional.pad(position_ids, (0, pad_length), value=0)
1062
+
1063
+ n, t_ = attention_mask.size()
1064
+
1065
+ encoder_outputs = super().forward(
1066
+ input_ids=input_ids,
1067
+ attention_mask=attention_mask,
1068
+ token_type_ids=token_type_ids,
1069
+ position_ids=position_ids,
1070
+ head_mask=head_mask,
1071
+ inputs_embeds=inputs_embeds,
1072
+ encoder_hidden_states=encoder_hidden_states,
1073
+ encoder_attention_mask=encoder_attention_mask,
1074
+ past_key_values=past_key_values,
1075
+ use_cache=use_cache,
1076
+ output_attentions=output_attentions,
1077
+ output_hidden_states=output_hidden_states,
1078
+ return_dict=return_dict
1079
+ )
1080
+
1081
+ context = encoder_outputs[0]
1082
+ if self.pool_with_global:
1083
+ context[:, self.num_global_tokens] = context[:, 0]
1084
+
1085
+ diff = t - t_
1086
+ n, _, d = context.size()
1087
+ context = context[..., self.num_global_tokens:, :]
1088
+
1089
+ # Adapt sequence to initial shape
1090
+ if diff < 0:
1091
+ context = context[:, :t]
1092
+
1093
+ encoder_outputs.last_hidden_state = context
1094
+
1095
+ sequence_output = encoder_outputs[0]
1096
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
1097
+
1098
+ if not return_dict:
1099
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
1100
+
1101
+ return BaseModelOutputWithPoolingAndCrossAttentions(
1102
+ last_hidden_state=sequence_output,
1103
+ pooler_output=pooled_output,
1104
+ past_key_values=encoder_outputs.past_key_values,
1105
+ hidden_states=encoder_outputs.hidden_states,
1106
+ attentions=encoder_outputs.attentions,
1107
+ cross_attentions=encoder_outputs.cross_attentions,
1108
+ )
1109
+
1110
+ def get_extended_attention_mask(self, attention_mask, input_shape, device=None):
1111
+
1112
+ # Do not rely on original triangular mask from BERT/RoBERTa for causalLM
1113
+ if attention_mask.dim() == 3:
1114
+ extended_attention_mask = attention_mask[:, None, :, :]
1115
+ elif attention_mask.dim() == 2:
1116
+ extended_attention_mask = attention_mask[:, None, None, :]
1117
+ else:
1118
+ raise ValueError(
1119
+ f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
1120
+ )
1121
+
1122
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
1123
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
1124
+
1125
+ return extended_attention_mask
1126
+
1127
+
1128
+ class LSGBertForPreTraining(LSGBertPreTrainedModel):
1129
+
1130
+ def __init__(self, config):
1131
+
1132
+ super().__init__(config)
1133
+
1134
+ self.bert = LSGBertModel(config)
1135
+ self.cls = LSGBertPreTrainingHeads(config)
1136
+
1137
+ # Initialize weights and apply final processing
1138
+ self.post_init()
1139
+
1140
+
1141
+ class LSGBertLMHeadModel(BertLMHeadModel):
1142
+
1143
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1144
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1145
+
1146
+ def __init__(self, config):
1147
+
1148
+ BertPreTrainedModel.__init__(self, config)
1149
+
1150
+ if not config.is_decoder:
1151
+ logger.warning("If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`")
1152
+
1153
+ self.bert = LSGBertModel(config, add_pooling_layer=False)
1154
+ self.cls = LSGBertOnlyMLMHead(config)
1155
+
1156
+ # Initialize weights and apply final processing
1157
+ self.post_init()
1158
+
1159
+
1160
+ class LSGBertForMaskedLM(LSGBertPreTrainedModel, BertForMaskedLM):
1161
+ """
1162
+ This class overrides :class:`~transformers.BertForMaskedLM`. Please check the superclass for the appropriate
1163
+ documentation alongside usage examples.
1164
+ """
1165
+
1166
+ config_class = LSGBertConfig
1167
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1168
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1169
+
1170
+ def __init__(self, config):
1171
+
1172
+ LSGBertPreTrainedModel.__init__(self, config)
1173
+
1174
+ if config.is_decoder:
1175
+ logger.warning(
1176
+ "If you want to use `LSGBertForMaskedLM` make sure `config.is_decoder=False` for "
1177
+ "bi-directional self-attention."
1178
+ )
1179
+
1180
+ self.bert = LSGBertModel(config, add_pooling_layer=False)
1181
+ self.cls = LSGBertOnlyMLMHead(config)
1182
+
1183
+ # Initialize weights and apply final processing
1184
+ self.post_init()
1185
+
1186
+
1187
+ class LSGBertForNextSentencePrediction(LSGBertPreTrainedModel, BertForNextSentencePrediction):
1188
+
1189
+ def __init__(self, config):
1190
+
1191
+ LSGBertPreTrainedModel.__init__(self, config)
1192
+
1193
+ self.bert = LSGBertModel(config)
1194
+ self.cls = LSGBertOnlyNSPHead(config)
1195
+
1196
+ # Initialize weights and apply final processing
1197
+ self.post_init()
1198
+
1199
+
1200
+ class LSGBertForSequenceClassification(LSGBertPreTrainedModel, BertForSequenceClassification):
1201
+ """
1202
+ This class overrides :class:`~transformers.BertForSequenceClassification`. Please check the superclass for the
1203
+ appropriate documentation alongside usage examples.
1204
+ """
1205
+
1206
+ config_class = LSGBertConfig
1207
+
1208
+ def __init__(self, config):
1209
+
1210
+ LSGBertPreTrainedModel.__init__(self, config)
1211
+
1212
+ self.num_labels = config.num_labels
1213
+ self.config = config
1214
+
1215
+ self.bert = LSGBertModel(config)
1216
+ classifier_dropout = (
1217
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1218
+ )
1219
+ self.dropout = nn.Dropout(classifier_dropout)
1220
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1221
+
1222
+ # Initialize weights and apply final processing
1223
+ self.post_init()
1224
+
1225
+
1226
+ class LSGBertForMultipleChoice(LSGBertPreTrainedModel, BertForMultipleChoice):
1227
+ """
1228
+ This class overrides :class:`~transformers.BertForMultipleChoice`. Please check the superclass for the
1229
+ appropriate documentation alongside usage examples.
1230
+ """
1231
+
1232
+ config_class = LSGBertConfig
1233
+
1234
+ def __init__(self, config):
1235
+
1236
+ LSGBertPreTrainedModel.__init__(self, config)
1237
+
1238
+ self.bert = LSGBertModel(config)
1239
+ classifier_dropout = (
1240
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1241
+ )
1242
+ self.dropout = nn.Dropout(classifier_dropout)
1243
+ self.classifier = nn.Linear(config.hidden_size, 1)
1244
+
1245
+ # Initialize weights and apply final processing
1246
+ self.post_init()
1247
+
1248
+
1249
+ class LSGBertForTokenClassification(LSGBertPreTrainedModel, BertForTokenClassification):
1250
+ """
1251
+ This class overrides :class:`~transformers.BertForTokenClassification`. Please check the superclass for the
1252
+ appropriate documentation alongside usage examples.
1253
+ """
1254
+
1255
+ config_class = LSGBertConfig
1256
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1257
+
1258
+ def __init__(self, config):
1259
+
1260
+ LSGBertPreTrainedModel.__init__(self, config)
1261
+
1262
+ self.num_labels = config.num_labels
1263
+
1264
+ self.bert = LSGBertModel(config, add_pooling_layer=False)
1265
+ classifier_dropout = (
1266
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1267
+ )
1268
+ self.dropout = nn.Dropout(classifier_dropout)
1269
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1270
+
1271
+ # Initialize weights and apply final processing
1272
+ self.post_init()
1273
+
1274
+
1275
+ class LSGBertForQuestionAnswering(LSGBertPreTrainedModel, BertForQuestionAnswering):
1276
+ """
1277
+ This class overrides :class:`~transformers.BertForQuestionAnswering`. Please check the superclass for the
1278
+ appropriate documentation alongside usage examples.
1279
+ """
1280
+
1281
+ config_class = LSGBertConfig
1282
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1283
+
1284
+ def __init__(self, config):
1285
+
1286
+ LSGBertPreTrainedModel.__init__(self, config)
1287
+
1288
+ self.num_labels = config.num_labels
1289
+
1290
+ self.bert = LSGBertModel(config, add_pooling_layer=False)
1291
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1292
+
1293
+ # Initialize weights and apply final processing
1294
+ self.post_init()
1295
+
1296
+
1297
+ def str_to_class(classname):
1298
+ return getattr(sys.modules[__name__], classname)
1299
+
1300
+ # Register model in Auto API
1301
+ try:
1302
+ LSGBertConfig.register_for_auto_class()
1303
+ for key, value in AUTO_MAP.items():
1304
+ str_to_class(value.split(".")[-1]).register_for_auto_class(key)
1305
+ except:
1306
+ warn("AutoRegister isn't available, you'll have to manually copy modeling.py after .save_pretrained(...).")
1307
+ warn("Update to transformers >= 4.17.0 to fix.")
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8b2ab614efe94be35365d04f75cf89fa9fe02f1798b9b335dc232fdbe0e077b9
3
+ size 213468895
special_tokens_map.json ADDED
@@ -0,0 +1 @@
1
+ {"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1 @@
1
+ {"do_lower_case": true, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "model_max_length": 4096, "special_tokens_map_file": null, "name_or_path": "nlpaueb/legal-bert-small-uncased", "do_basic_tokenize": true, "never_split": null, "tokenizer_class": "BertTokenizer"}
vocab.txt ADDED
The diff for this file is too large to render. See raw diff