Upload ContextualDocumentEmbeddingTransformer
Files changed:
- README.md +199 -0
- config.json +28 -0
- misc.py +518 -0
- model.py +622 -0
- model.safetensors +3 -0
README.md
ADDED
@@ -0,0 +1,199 @@
---
library_name: transformers
tags: []
---

# Model Card for Model ID

<!-- Provide a quick summary of what the model is/does. -->



## Model Details

### Model Description

<!-- Provide a longer summary of what this model is. -->

This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.

- **Developed by:** [More Information Needed]
- **Funded by [optional]:** [More Information Needed]
- **Shared by [optional]:** [More Information Needed]
- **Model type:** [More Information Needed]
- **Language(s) (NLP):** [More Information Needed]
- **License:** [More Information Needed]
- **Finetuned from model [optional]:** [More Information Needed]

### Model Sources [optional]

<!-- Provide the basic links for the model. -->

- **Repository:** [More Information Needed]
- **Paper [optional]:** [More Information Needed]
- **Demo [optional]:** [More Information Needed]

## Uses

<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->

### Direct Use

<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->

[More Information Needed]

### Downstream Use [optional]

<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->

[More Information Needed]

### Out-of-Scope Use

<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->

[More Information Needed]

## Bias, Risks, and Limitations

<!-- This section is meant to convey both technical and sociotechnical limitations. -->

[More Information Needed]

### Recommendations

<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->

Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.

## How to Get Started with the Model

Use the code below to get started with the model.
A minimal sketch, not from the original card: `<repo-id>` is a placeholder for this model's Hub id, and the call signature follows `model.py` in this repository.
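```python
import torch
from transformers import AutoModel, AutoTokenizer

# "<repo-id>" is a placeholder; this auto-generated card does not state the Hub id.
model = AutoModel.from_pretrained("<repo-id>", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("nomic-ai/nomic-bert-2048")  # embedder named in config.json

corpus = ["First context document.", "Second context document."]
docs = ["A document to embed."]

corpus_inputs = tokenizer(corpus, padding=True, truncation=True, max_length=512, return_tensors="pt")
doc_inputs = tokenizer(docs, padding=True, truncation=True, max_length=512, return_tensors="pt")

with torch.no_grad():
    # Two-stage forward pass (see model.py): the first stage embeds the corpus,
    # the second stage conditions the document embeddings on it.
    embeddings = model(
        input_ids=doc_inputs["input_ids"],
        attention_mask=doc_inputs["attention_mask"],
        dataset_input_ids=corpus_inputs["input_ids"],
        dataset_attention_mask=corpus_inputs["attention_mask"],
    )
```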
## Training Details

### Training Data

<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->

[More Information Needed]

### Training Procedure

<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->

#### Preprocessing [optional]

[More Information Needed]


#### Training Hyperparameters

- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->

#### Speeds, Sizes, Times [optional]

<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->

[More Information Needed]

## Evaluation

<!-- This section describes the evaluation protocols and provides the results. -->

### Testing Data, Factors & Metrics

#### Testing Data

<!-- This should link to a Dataset Card if possible. -->

[More Information Needed]

#### Factors

<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->

[More Information Needed]

#### Metrics

<!-- These are the evaluation metrics being used, ideally with a description of why. -->

[More Information Needed]

### Results

[More Information Needed]

#### Summary



## Model Examination [optional]

<!-- Relevant interpretability work for the model goes here -->

[More Information Needed]

## Environmental Impact

<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->

Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).

- **Hardware Type:** [More Information Needed]
- **Hours used:** [More Information Needed]
- **Cloud Provider:** [More Information Needed]
- **Compute Region:** [More Information Needed]
- **Carbon Emitted:** [More Information Needed]

## Technical Specifications [optional]

### Model Architecture and Objective

[More Information Needed]

### Compute Infrastructure

[More Information Needed]

#### Hardware

[More Information Needed]

#### Software

[More Information Needed]

## Citation [optional]

<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->

**BibTeX:**

[More Information Needed]

**APA:**

[More Information Needed]

## Glossary [optional]

<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->

[More Information Needed]

## More Information [optional]

[More Information Needed]

## Model Card Authors [optional]

[More Information Needed]

## Model Card Contact

[More Information Needed]
config.json
ADDED
@@ -0,0 +1,28 @@
{
  "_name_or_path": "/fsx-checkpoints/jxm/cde/2024-08-06-transductive-pretrain-transductive-long-10node-3/checkpoint-7176",
  "architecture": "transductive",
  "architectures": [
    "ContextualDocumentEmbeddingTransformer"
  ],
  "attn_implementation": null,
  "auto_map": {
    "AutoConfig": "misc.ContextualModelConfig",
    "AutoModel": "model.ContextualDocumentEmbeddingTransformer"
  },
  "cache_dir": null,
  "config_name": null,
  "disable_dropout": true,
  "disable_transductive_rotary_embedding": true,
  "embedder": "nomic-ai/nomic-bert-2048",
  "embedder_rerank": "sentence-transformers/gtr-t5-base",
  "embedding_output_dim": null,
  "limit_layers": null,
  "logit_scale": 50.0,
  "max_seq_length": 512,
  "model_revision": "main",
  "tokenizer_name": null,
  "torch_dtype": "float32",
  "transductive_corpus_size": 512,
  "transductive_sequence_dropout_prob": 0.0,
  "transformers_version": "4.48.0.dev0"
}
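Because of the `auto_map` entries above, the config (and model) resolve to the bundled `misc.py`/`model.py` when remote code is trusted. A minimal sketch, with a placeholder repo id:

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("<repo-id>", trust_remote_code=True)
print(config.embedder)                  # nomic-ai/nomic-bert-2048
print(config.transductive_corpus_size)  # 512
```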
misc.py
ADDED
@@ -0,0 +1,518 @@
from typing import Dict, Iterable, List, Optional, Tuple, Union

import collections
import glob
import json
import hashlib
import itertools
import logging
import multiprocessing
import os
import pickle
import random
import requests
import sys
import zipfile

import datasets
import numpy as np
import torch
import tqdm
import transformers

from cde.lib.dist import get_num_proc, get_rank


def get_cde_cache_dir() -> str:
    script_directory = os.path.normpath(
        os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            os.pardir, os.pardir,
        )
    )
    return os.path.join(script_directory, "data")


def get_cache_location_from_kwargs(**kwargs):
    cache_location = os.path.join(
        get_cde_cache_dir(), "cluster"
    )
    os.makedirs(cache_location, exist_ok=True)
    return os.path.join(cache_location, md5_hash_kwargs(**kwargs))


def process_qrels_uncached(corpus: datasets.Dataset, qrels: datasets.Dataset) -> Tuple[Dict[str, List[float]], Dict[str, List[str]]]:
    qrels_idxs = collections.defaultdict(list)
    qrels_scores = collections.defaultdict(list)
    corpus_ids = np.array(corpus['_id'])
    skipped_qrels = 0

    for ex in tqdm.tqdm(qrels, desc='processing qrels', colour='#964B00', leave=False):
        # example:
        # {
        #     'query-id': 1,
        #     'corpus-id': 'b0680508-2019-04-18T13:48:51Z-00002-000',
        #     'score': 2
        # }
        q_id = str(ex['query-id'])
        c_idxs = (corpus_ids == str(ex['corpus-id'])).nonzero()[0]
        assert len(c_idxs) <= 1, f"error - duplicate corpus ID? (found {len(c_idxs)} matches)"
        if len(c_idxs):
            qrels_idxs[q_id].append(c_idxs[0])
            qrels_scores[q_id].append(ex['score'])
        else:
            skipped_qrels += 1

    if skipped_qrels > 0:
        logging.warning(f'Warning: Skipped {skipped_qrels}/{len(qrels)} qrels.')

    return qrels_idxs, qrels_scores


def process_qrels(
        corpus: datasets.Dataset, qrels: datasets.Dataset,
        use_cache: bool = True
    ) -> Tuple[Dict[str, List[float]], Dict[str, List[str]]]:
    dataset_cache_file = '_'.join(
        (corpus.cache_files[0]['filename'], qrels.cache_files[0]['filename'])
    )
    cache_file = strip_extension(dataset_cache_file) + '_processed_qrels.p'
    os.makedirs(os.path.dirname(cache_file), exist_ok=True)

    if not (use_cache and os.path.exists(cache_file)):
        qrels_idxs, qrels_scores = process_qrels_uncached(
            corpus=corpus, qrels=qrels
        )
        if use_cache:
            pickle.dump((qrels_idxs, qrels_scores), open(cache_file, 'wb'))
    else:
        qrels_idxs, qrels_scores = pickle.load(open(cache_file, 'rb'))

    return qrels_idxs, qrels_scores


def strip_extension(filename: str) -> str:
    """Strips file extension.

    Ex:
        >> strip_extension('/root/dir/sub/file.ext')
        '/root/dir/sub/file'
    """
    return os.path.splitext(filename)[0]


def md5_hash(t: Tuple[str]) -> str:
    return hashlib.md5('__'.join(t).encode()).hexdigest()


def md5_hash_kwargs(**kwargs) -> str:
    # We ignore special hf args that start with _ like '__cached__setup_devices'.
    safe_kwargs = {k: str(v) for k, v in kwargs.items() if not k.startswith('_')}
    s = json.dumps(safe_kwargs, sort_keys=True)
    return hashlib.md5(s.encode()).hexdigest()


def download_url(url: str, save_path: str, chunk_size: int = 1024):
    """Download url with progress bar using tqdm
    https://stackoverflow.com/questions/15644964/python-progress-bar-and-downloads

    Args:
        url (str): downloadable url
        save_path (str): local path to save the downloaded file
        chunk_size (int, optional): chunking of files. Defaults to 1024.
    """
    r = requests.get(url, stream=True)
    total = int(r.headers.get('Content-Length', 0))
    with open(save_path, 'wb') as fd, tqdm.tqdm(
        desc=save_path,
        total=total,
        unit='iB',
        unit_scale=True,
        unit_divisor=chunk_size,
    ) as bar:
        for data in r.iter_content(chunk_size=chunk_size):
            size = fd.write(data)
            bar.update(size)


def unzip(zip_file: str, out_dir: str):
    print("unzipping =>", zip_file)
    zip_ = zipfile.ZipFile(zip_file, "r")
    zip_.extractall(path=out_dir)
    zip_.close()


def download_url_and_unzip(url: str, out_dir: str, chunk_size: int = 1024) -> str:
    os.makedirs(out_dir, exist_ok=True)
    dataset = url.split("/")[-1]
    zip_file = os.path.join(out_dir, dataset)

    if not os.path.isfile(zip_file):
        logging.info("Downloading {} ...".format(dataset))
        download_url(url, zip_file, chunk_size)

    if not os.path.isdir(zip_file.replace(".zip", "")):
        logging.info("Unzipping {} ...".format(dataset))
        unzip(zip_file, out_dir)

    return os.path.join(out_dir, dataset.replace(".zip", ""))


def tqdm_if_main_worker(iterable: Iterable, **kwargs) -> Iterable:
    if get_rank() == 0:
        return tqdm.tqdm(iterable, **kwargs)
    else:
        return iterable


class ContextualModelConfig(transformers.configuration_utils.PretrainedConfig):
    """We create a dummy configuration class that will just set properties
    based on whatever kwargs we pass in.

    When this class is initialized (see experiments.py) we pass in the
    union of all data, model, and training args, all of which should
    get saved to the config json.
    """

    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            try:
                json.dumps(value)
                setattr(self, key, value)
            except TypeError:
                # value was not JSON-serializable, skip
                continue
        super().__init__()


def independent_crop(
        input_ids: torch.Tensor, pad_token_id: int,
        l1: int = 256, l2: int = 256) -> Tuple[torch.Tensor, torch.Tensor]:
    """Returns two independent crops from input_ids.

    Assumes input_ids has a beginning and end token, like
    [101, ..., 102, 0, 0, 0].

    Args:
        input_ids: tensor of IDs
        pad_token_id: ID of pad tokens in input_ids
        l1: length of span 1, cropped
        l2: length of span 2, cropped
    Returns:
        span1: first crop (of length l1)
        span2: second crop (of length l2)
    """
    # Count tokens until pad.
    if (input_ids == pad_token_id).sum() == 0:
        N = len(input_ids)
    else:
        N = (input_ids == pad_token_id).int().argmax().item()

    # Contriever: "We use the random cropping data augmentation, with
    # documents of 256 tokens and span sizes sampled between 5% and 50%
    # of the document length."
    # LaPraDor: "The maximum lengths set for queries and documents are
    # 64 and 350..."
    # TODO is this divide-by-two a good idea? (Don't want s1=s2 ever..)
    nl1 = min(N // 2, l1)
    nl2 = min(N // 2, l2)

    s1_start = random.randint(1, N - nl1)
    s2_start = random.randint(1, N - nl2)

    s1_idxs = itertools.chain(
        [0], range(s1_start, s1_start + nl1), [N - 1]
    )
    s1 = input_ids[torch.tensor(list(s1_idxs))]
    s2_idxs = itertools.chain(
        [0], range(s2_start, s2_start + nl2), [N - 1]
    )
    s2 = input_ids[torch.tensor(list(s2_idxs))]
    return (s1, s2)

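# Illustrative sketch (not in the original file) of what `independent_crop`
# produces, assuming BERT-style special tokens as in the docstring above:
#
#     ids = torch.tensor([101, 5, 6, 7, 8, 9, 102, 0, 0, 0])
#     s1, s2 = independent_crop(ids, pad_token_id=0, l1=3, l2=3)
#     # Each crop keeps the first and last non-pad tokens around a random
#     # contiguous span, e.g. s1 = tensor([101, 6, 7, 8, 102]).
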
def load_dataset_tables(
        files: Iterable[str], num_workers: int = 16
    ) -> Iterable[datasets.table.MemoryMappedTable]:
    import concurrent.futures  # fixed: `import concurrent` alone does not expose `concurrent.futures`
    from multiprocessing import Pool

    # num_workers = min(num_workers, len(files))
    num_workers = min(32, len(files))

    use_threads = True
    if use_threads:
        pool_cls = concurrent.futures.ThreadPoolExecutor
        pool_kwargs = {"max_workers": num_workers}
    else:
        pool_cls = Pool
        pool_kwargs = {"processes": num_workers}

    with pool_cls(**pool_kwargs) as pool:
        if len(files) > 10:
            files = tqdm_if_main_worker(
                files,
                desc=f"Loading {len(files)} files with {num_workers} workers",
                total=len(files),
                colour="#ffbd88"
            )
        result = list(
            pool.map(datasets.table.MemoryMappedTable.from_file, files)
        )
    return result


def datasets_fast_load_from_disk(cache_path: str) -> datasets.Dataset:
    logging.info(f"fast_load_from_disk called with path: {cache_path}")  # fixed: path was passed outside the f-string
    dataset_info_path = os.path.join(cache_path, "dataset_info.json")
    with open(dataset_info_path, encoding="utf-8") as dataset_info_file:
        dataset_info = datasets.DatasetInfo.from_dict(json.load(dataset_info_file))

    dataset_state_path = os.path.join(cache_path, "state.json")
    with open(dataset_state_path, encoding="utf-8") as state_file:
        state = json.load(state_file)

    files = glob.glob(os.path.join(cache_path, "data-*.arrow"))
    files = sorted(files)
    num_workers = get_num_proc()
    ds_tables = load_dataset_tables(
        files=files,
        num_workers=num_workers
    )
    arrow_table = datasets.table.concat_tables(ds_tables)

    split = state["_split"]
    split = datasets.splits.Split(split) if split is not None else split

    # print("returning dataset")
    return datasets.Dataset(
        arrow_table=arrow_table,
        info=dataset_info,
        split=split,
        fingerprint=state["_fingerprint"],
    )


def tokenize_dataset(
    dataset: datasets.Dataset,
    tokenizer: transformers.PreTrainedTokenizer,
    max_length: int,
    text_key: str,
    padding_strategy: str
) -> datasets.Dataset:
    def tokenize_text(ex: Dict) -> Dict:
        tt = tokenizer(
            ex[text_key],
            max_length=max_length,
            truncation=True,
            padding=padding_strategy,
        )
        for k, v in tt.items():
            ex[f"{text_key}_{k}"] = v
        ex["length"] = [len(ids) for ids in ex[f"{text_key}_input_ids"]]  # loop var renamed; it shadowed `tt`
        return ex

    # generate unique hash for tokenizer
    vocab = tokenizer.vocab
    vocab_words = tuple(sorted(vocab.keys(), key=lambda word: vocab[word]))
    vocab_hash = md5_hash(vocab_words)

    data_fingerprint = '__'.join((
        dataset._fingerprint, str(vocab_hash), str(max_length),
        text_key, padding_strategy
    ))
    data_fingerprint = md5_hash(data_fingerprint)
    dataset = dataset.map(
        tokenize_text,
        new_fingerprint=data_fingerprint,
        batched=True,
        load_from_cache_file=True,
    )
    return dataset


class TensorRunningAverages:
    _store_sum: Dict[str, torch.Tensor]
    _store_total: Dict[str, torch.Tensor]

    def __init__(self):
        self._store_sum = {}
        self._store_total = {}

    def __iter__(self) -> Iterable[str]:
        return iter(self._store_sum.keys())

    def update(self, key: str, val: Union[int, float, torch.Tensor]) -> None:
        if key not in self._store_sum:
            self.clear(key)
        if isinstance(val, torch.Tensor):
            val = val.item()  # tensor -> num
        self._store_sum[key] += val
        self._store_total[key] += 1

    def get(self, key: str) -> float:
        if key not in self._store_sum:
            return 0.0  # fixed: the original crashed with AttributeError on an unknown key
        total = max(self._store_total[key].item(), 1.0)
        return (self._store_sum[key] / float(total)).item() or 0.0

    def clear(self, key: str) -> None:
        self._store_sum[key] = torch.tensor(0.0, dtype=torch.float32)
        self._store_total[key] = torch.tensor(0, dtype=torch.int32)

    def clear_all(self) -> None:
        for key in self._store_sum:
            self.clear(key)

    def get_and_clear_all(self) -> Dict[str, float]:
        metrics = {}
        for key in self:
            metrics[key] = self.get(key)
            self.clear(key)
        return metrics

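# Usage sketch (illustrative, not in the original file): accumulate scalar
# metrics during training and read back running means.
#
#     avgs = TensorRunningAverages()
#     avgs.update("loss", torch.tensor(0.5))
#     avgs.update("loss", 0.7)
#     avgs.get("loss")           # -> 0.6  (sum 1.2 over 2 updates)
#     avgs.get_and_clear_all()   # -> {"loss": 0.6} and resets counters
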
def load_embedder_and_tokenizer(name: str) -> Tuple[
        transformers.PreTrainedModel,
        transformers.PreTrainedTokenizer
    ]:
    if name.startswith("nomic") or (name == "bert-base-uncased"):
        from cde.lib.nomic_bert import NomicBertModel
        if name.endswith("--from-scratch"):
            name = name.replace("--from-scratch", "")
            config = transformers.AutoConfig.from_pretrained(name, trust_remote_code=True)
            model = NomicBertModel._from_config(config)
        else:
            model = NomicBertModel.from_pretrained(
                name, add_pooling_layer=False
            )
        tokenizer = transformers.AutoTokenizer.from_pretrained(name)
    elif name in ["gtr-base", "gtr_base"]:
        model = transformers.AutoModel.from_pretrained(
            "sentence-transformers/gtr-t5-base"
        ).encoder
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            "sentence-transformers/gtr-t5-base"
        )
    elif name == "pile-t5-base-encoder":
        model = transformers.AutoModel.from_pretrained(
            "EleutherAI/pile-t5-base"
        ).encoder
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            "EleutherAI/pile-t5-base"
        )
        tokenizer.pad_token = tokenizer.eos_token
    elif name == "pile-t5-base-decoder":
        model = transformers.AutoModel.from_pretrained(
            "EleutherAI/pile-t5-base"
        ).decoder
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            "EleutherAI/pile-t5-base"
        )
        tokenizer.pad_token = tokenizer.eos_token
    elif name.startswith("gpt2") or name.startswith("meta-llama") or ("Llama" in name):
        model = transformers.AutoModelForCausalLM.from_pretrained(
            name,
            # torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2" if torch.cuda.is_available() else "sdpa",
            low_cpu_mem_usage=True,
            # device_map="auto",
        )
        model.padding_side = "right"
        tokenizer = transformers.AutoTokenizer.from_pretrained(name)
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.add_eos_token = True
    elif "Modern" in name:
        print("special loading for ModernBERT!")
        # [1] needed for faster training
        # model = transformers.AutoModel.from_pretrained(name, trust_remote_code=True, reference_compile=True)
        # [2] needed for non-breaking inference
        model = transformers.AutoModel.from_pretrained(name, trust_remote_code=True, reference_compile=False)
        tokenizer = transformers.AutoTokenizer.from_pretrained(name)
    else:
        model = transformers.AutoModel.from_pretrained(name, trust_remote_code=True)
        tokenizer = transformers.AutoTokenizer.from_pretrained(name)
    return model, tokenizer


def inputs_for_key(inputs: Dict[str, torch.Tensor], key: str):
    key += "_"
    return {k.replace(key, ""): v for k, v in inputs.items() if k.startswith(key)}


def count_cpus() -> int:
    try:
        return len(os.sched_getaffinity(0))
    except AttributeError:
        return multiprocessing.cpu_count()


def shuffle_batches(g: torch.Generator, list_of_tensors: List[torch.Tensor]) -> List[int]:
    all_indices = []
    for batch_tensor in tqdm_if_main_worker(list_of_tensors, colour="green", desc="Sampler shuffling per-batch"):
        rand_perm = torch.randperm(len(batch_tensor), generator=g)
        batch_list = batch_tensor[rand_perm].tolist()
        all_indices.extend(batch_list)
    return all_indices


# def shuffle_batches_multiproc(g: torch.Generator, list_of_tensors: List[torch.Tensor], num_processes: int = 8) -> List[int]:
#     all_indices = []
#     print(f"Shuffling {len(list_of_tensors)} tensors with {num_processes} workers.")
#     pbar = tqdm_if_main_worker(list_of_tensors, colour="orange", desc=f"Sampler shuffling per-batch (nproc={num_processes})")
#     pool = multiprocessing.Pool(processes=num_processes)
#     chunk_size = len(list_of_tensors) // num_processes
#     chunks = [list_of_tensors[i:i + chunk_size] for i in range(0, len(list_of_tensors), chunk_size)]
#     worker_func = functools.partial(shuffle_batches, g=g)
#     results = pool.map(worker_func, chunks)
#     all_indices = []
#     for result in results:
#         all_indices.extend(result)
#         pbar.update()
#     return all_indices


def exit_if_running_or_finished_wandb(
        project_name: str,
        exp_group: str, exp_name: str
    ) -> None:
    print("Checking if experiment is already running...")
    import wandb

    api = wandb.Api()
    running_runs = api.runs(
        path="cde-0",
        filters={
            "display_name": exp_name,
            "state": {"$regex": "Running|Finished"},
            "config.exp_group": exp_group,
        }
    )
    print("Found", len(running_runs), f"runs with name {exp_name} and group {exp_group} in {project_name}.")

    if len(running_runs) > 0:
        print("Exiting because experiment is already running or completed.")
        sys.exit(0)


HN_FILTER_TOKENIZER_MAP = {
    "nomic": "nomic-ai/nomic-embed-text-v1",
    "stella": "dunzhang/stella_en_400M_v5",
    "sbert": "sentence-transformers/all-MiniLM-L6-v2",
    "sentence_t5": "sentence-transformers/sentence-t5-base",
    "gte": "Alibaba-NLP/gte-large-en-v1.5",
}


def load_hn_filter_tokenizer(tokenizer_name: str) -> Optional[transformers.PreTrainedTokenizer]:
    if tokenizer_name in HN_FILTER_TOKENIZER_MAP:
        return transformers.AutoTokenizer.from_pretrained(HN_FILTER_TOKENIZER_MAP[tokenizer_name])
    else:
        return None
model.py
ADDED
@@ -0,0 +1,622 @@
from typing import Optional

import copy
import torch
import torch.nn as nn
import transformers

from cde.lib.dist import print0
from cde.lib.tensor import mean_pool, mean_pool_3d, mean_pool_weighted, last_token_pool

from cde.lib import load_embedder_and_tokenizer, ContextualModelConfig


gpt_tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")


def limit_layers(model: transformers.PreTrainedModel, n_layers: int) -> None:
    if hasattr(model, 'transformer'):
        if hasattr(model.transformer, 'h'):
            # gpt2
            model.transformer.h = model.transformer.h[:n_layers]
        else:
            model.transformer.layer = model.transformer.layer[:n_layers]
    elif hasattr(model, 'encoder'):
        if hasattr(model.encoder, 'layers'):
            model.encoder.layers = model.encoder.layers[:n_layers]
        else:
            model.encoder.layer = model.encoder.layer[:n_layers]
    else:
        raise RuntimeError(f"unknown how to limit layers of model {type(model)}")


def disable_dropout(model: torch.nn.Module):
    dropout_modules = [m for m in model.modules() if isinstance(m, torch.nn.Dropout)]
    for m in dropout_modules:
        m.p = 0.0
    print0(
        f"Disabled {len(dropout_modules)} dropout modules from model type {type(model)}"
    )


def disable_causality(model: torch.nn.Module):
    disabled_modules = 0
    for m in model.modules():
        if hasattr(m, "is_causal"):
            m.is_causal = False
            disabled_modules += 1
    print0(
        f"Set is_causal=False in {disabled_modules} modules from model type {type(model)}"
    )


class ContextualModelMixin(nn.Module):
    @property
    def num_corpus_tokens(self) -> int:
        return self.transductive_corpus_size * self.transductive_tokens_per_document

    def contextual_init(self):
        self.n_soft_prompt = 8
        self.prompt_projection = torch.nn.Sequential(
            torch.nn.Linear(self.hidden_size, self.hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(self.hidden_size, self.hidden_size * self.n_soft_prompt)
        )
        self.transductive_corpus_size = vars(self.config).get("transductive_corpus_size", 1)
        self.transductive_tokens_per_document = vars(self.config).get("transductive_tokens_per_document", 1)
        self.randomize_dataset_sequence_order = True
        self.sequence_dropout_prob = vars(self.config).get("transductive_sequence_dropout_prob", 0.0)
        if self.sequence_dropout_prob > 0.0:
            self.sequence_dropout_null_embedding = torch.nn.Parameter(
                torch.randn(self.hidden_size) * 0.01,
                requires_grad=True
            )
        self.output_projection = torch.nn.Sequential(
            torch.nn.Linear(self.hidden_size, self.hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(self.hidden_size, self.hidden_size)
        )

    def _prepare_dataset_embeddings(
            self,
            input_ids: torch.Tensor,
            dataset_embeddings: torch.Tensor,
            null_dataset_embedding: bool = False,
        ) -> torch.Tensor:
        if not isinstance(dataset_embeddings, torch.Tensor):
            dataset_embeddings = torch.tensor(dataset_embeddings)

        if len(dataset_embeddings.shape) == 2:
            # Auto-expand for a batch.
            dataset_embeddings = dataset_embeddings[None, :, :]  # (b, d) -> (1, b, d)
        dataset_embeddings = dataset_embeddings.to(input_ids.device)

        batch_size = input_ids.shape[0]
        if (self.transductive_tokens_per_document > 1):
            if self.training:
                # Choose N random documents to fill our context window with.
                # This logic is a little confusing but allows us to sample a
                # different batch *per-document*.
                assert dataset_embeddings.shape[1] == self.transductive_tokens_per_document
                R = torch.randint(
                    low=0,
                    high=len(dataset_embeddings),
                    size=(batch_size, self.config.transductive_corpus_size),
                    device=dataset_embeddings.device
                )
                # TODO make this deterministic somehow for evaluation?
                dataset_embeddings = dataset_embeddings[R].reshape((batch_size, self.num_corpus_tokens, self.hidden_size))
            else:
                dataset_embeddings = dataset_embeddings.reshape((1, self.num_corpus_tokens, self.hidden_size))
                # print("reshaped to dataset_embeddings.shape =", dataset_embeddings.shape)

        if dataset_embeddings.shape[1] > self.num_corpus_tokens:
            # If too many dataset embeddings are passed in, just take the first N until
            # we have the proper number.
            dataset_embeddings = dataset_embeddings[:, :self.num_corpus_tokens, :]

        num_batches, corpus_size, _hidden_size = dataset_embeddings.shape  # renamed from `_` for clarity
        if num_batches == 1:
            # Auto-expand for a batch.
            dataset_embeddings = dataset_embeddings.expand((batch_size, -1, -1))

        if self.training and self.sequence_dropout_prob > 0.0:
            sequence_dropout_mask = (
                torch.rand((batch_size, corpus_size), device=dataset_embeddings.device) < self.sequence_dropout_prob
            )
            null_embeddings = self.sequence_dropout_null_embedding[None, None].expand(batch_size, corpus_size, -1)
            dataset_embeddings = torch.where(
                sequence_dropout_mask[..., None], null_embeddings, dataset_embeddings
            )
        elif null_dataset_embedding:
            null_embeddings = self.sequence_dropout_null_embedding[None, None].expand(batch_size, corpus_size, -1)
            dataset_embeddings = null_embeddings

        # backbone_max_seq_length = self.backbone.config.max_trained_positions
        # assert batch_size + (2 * self.n_soft_prompt + corpus_size) <= backbone_max_seq_length, "too many hard negatives for backbone model"
        soft_prompt = torch.ones((1, self.hidden_size), device=dataset_embeddings.device, dtype=dataset_embeddings.dtype)
        soft_prompt = self.prompt_projection(soft_prompt).reshape((1, self.n_soft_prompt, self.hidden_size))
        soft_prompt = soft_prompt.expand((len(dataset_embeddings), -1, -1))  # soft_prompt.repeat((len(input_ids), 1, 1))
        soft_prompt = torch.cat((dataset_embeddings, soft_prompt), dim=1)

        return soft_prompt

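# Shape note (illustrative, not in the original file): with this checkpoint's
# config.json (transductive_corpus_size=512, transductive_tokens_per_document=1)
# and n_soft_prompt=8, `_prepare_dataset_embeddings` returns a tensor of shape
# (batch_size, 512 * 1 + 8, hidden_size): the 512 contextual corpus embeddings
# followed by 8 learned soft-prompt vectors, prepended to the token embeddings.
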
class BiEncoder(transformers.PreTrainedModel):
    config_class = ContextualModelConfig
    embedder: transformers.PreTrainedModel

    def __init__(
            self,
            config,  #: transformers.PreTrainedConfig,
        ):
        super().__init__(config=config)
        embedder, _ = load_embedder_and_tokenizer(
            config.embedder,
        )

        if config.limit_layers:
            print0(f"Limiting layers to {config.limit_layers}")
            limit_layers(embedder, config.limit_layers)

        self.embedder = embedder
        # if ("t5" in embedder.config.model_type):
        #     print0(f"using torch.compile() on embedder of type `{embedder.config.model_type}`")
        #     self.embedder = torch.compile(self.embedder)
        self.hidden_size = self.embedder.config.hidden_size
        # Allow pooling to multiple tokens per document
        self.transductive_tokens_per_document = vars(self.config).get("transductive_tokens_per_document", 1)
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(self.hidden_size, self.hidden_size),
            torch.nn.GELU(),
            torch.nn.Linear(self.hidden_size, self.config.embedding_output_dim or self.hidden_size),
        )
        self.temp = config.logit_scale

        if config.disable_dropout:
            disable_dropout(self)
        self.pooling_strategy = vars(config).get("pooling_strategy", "mean")

    def forward(
            self,
            input_ids: torch.Tensor,
            attention_mask: torch.Tensor,
            dataset_input_ids: Optional[torch.Tensor] = None,
            dataset_attention_mask: Optional[torch.Tensor] = None,
            token_type_ids=None,
            output_hidden_states: bool = False,
        ) -> torch.Tensor:
        """
        query_embedding (float torch.Tensor) - shape (batch_size, embedding_dim)
        document_embeddings (float torch.Tensor) - shape (corpus_size, embedding_dim)
            where corpus_size >= batch_size and is structured like this:
                [d1, d2, d3, hn1_1, hn1_2, hn2_1, hn2_2, hn3_1, hn3_2]
            for a corpus with three documents and two hard negatives per document
        """
        del token_type_ids

        outputs = (
            self.embedder(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ).last_hidden_state
        )
        if self.transductive_tokens_per_document > 1:
            document_embeddings = None
            batch_size, seq_length, output_dim = outputs.shape

            if seq_length % self.transductive_tokens_per_document != 0:
                # Pad to nearest multiple
                n_extra_embeds = self.transductive_tokens_per_document - (seq_length % self.transductive_tokens_per_document)
                outputs = torch.cat(
                    (outputs, torch.zeros((batch_size, n_extra_embeds, output_dim), device=outputs.device)),
                    dim=1
                )
                attention_mask = torch.cat(
                    (attention_mask, torch.zeros((batch_size, n_extra_embeds), device=attention_mask.device)),
                    dim=1
                )
                seq_length += n_extra_embeds
                print(f"Added {n_extra_embeds} padding tokens to input_ids and attention_mask")

            # print(f"transductive_tokens_per_document {self.transductive_tokens_per_document} outputs.shape =", outputs.shape)

            outputs = outputs.reshape(
                (batch_size, self.transductive_tokens_per_document, seq_length // self.transductive_tokens_per_document, output_dim)
            )

            attention_mask = attention_mask.reshape((batch_size, self.transductive_tokens_per_document, -1))
            document_embeddings = mean_pool_3d(outputs, attention_mask)

            document_embeddings = document_embeddings.reshape((batch_size, self.transductive_tokens_per_document, output_dim))
        else:
            if self.pooling_strategy == "mean":
                document_embeddings = mean_pool(outputs, attention_mask)
            else:
                # Fixed: the original referenced `document_embeddings` before
                # assignment here; max-pool over the sequence dimension instead.
                document_embeddings = outputs.max(dim=1).values
        output = self.mlp(document_embeddings)

        if output_hidden_states:
            return {
                "hidden_states": outputs,
                "pooled": output,
            }
        else:
            return output


class DatasetConditionedAutoregressive(transformers.PreTrainedModel, ContextualModelMixin):
    def __init__(
            self,
            config,
            dataset_backbone: transformers.PreTrainedModel,
            first_stage_hidden_size: int,
        ):
        super().__init__(config=config)
        self.backbone = dataset_backbone
        self.backbone_hidden_size = self.backbone.config.hidden_size
        self.hidden_size = first_stage_hidden_size  # Input token size
        self.contextual_init()
        disable_causality(self.backbone)

        self.pool_ignore_contextual_tokens = vars(self.config).get("pool_ignore_contextual_tokens", False)
        self.pool_ignore_instruction_tokens = vars(self.config).get("pool_ignore_instruction_tokens", False)
        self.pool_instruction_end_id = self.backbone.config.bos_token_id

        # Override contextual init
        self.output_projection = torch.nn.Sequential(
            torch.nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size)
        )
        self._shift_rotary_embedding()

    @property
    def num_corpus_tokens(self) -> int:
        return self.config.transductive_corpus_size * self.transductive_tokens_per_document

    @property
    def corpus_token_ratio(self) -> float:
        # How many tokens from the first stage make one token in the second
        # stage?
        return self.backbone_hidden_size / self.hidden_size

    def corpus_token_pad_size(self, n_tokens: int) -> int:
        return self.hidden_size % self.backbone_hidden_size

    def _shift_rotary_embedding(self) -> None:
        disable_transductive_rotary_embedding = vars(self.config).get("disable_transductive_rotary_embedding", True)
        # TODO: Can we do this for LLAMA?
        print0("Warning: Positional embedding disabling not implemented for LLAMA.")

    def forward(
            self,
            input_ids: torch.Tensor,
            attention_mask: torch.Tensor,
            dataset_embeddings: torch.Tensor,
            output_hidden_states: bool = False,
            null_dataset_embedding: bool = False,
        ) -> torch.Tensor:
        soft_prompt = self._prepare_dataset_embeddings(
            input_ids=input_ids,
            dataset_embeddings=dataset_embeddings,
            null_dataset_embedding=null_dataset_embedding,
        )

        # Reshape for this model.
        # print("[DatasetConditionedAutoregressive] 1 -> soft_prompt.shape =", soft_prompt.shape)
        num_soft_elements = torch.prod(torch.tensor(soft_prompt.shape[1:])).item()
        soft_prompt = soft_prompt.reshape((soft_prompt.shape[0], num_soft_elements))
        num_padding_elements = self.backbone_hidden_size - (num_soft_elements % self.backbone_hidden_size)
        padding = torch.ones((soft_prompt.shape[0], num_padding_elements), device=soft_prompt.device)
        soft_prompt = torch.cat((soft_prompt, padding), dim=1)
        soft_prompt = soft_prompt.reshape(
            (soft_prompt.shape[0], -1, self.backbone_hidden_size)
        )
        # print("[DatasetConditionedAutoregressive] 2 -> soft_prompt.shape =", soft_prompt.shape)

        backbone_attention_mask = torch.ones(
            soft_prompt.shape[0:2],
            dtype=torch.long,
            device=soft_prompt.device,
        )
        token_embeddings = self.backbone.get_input_embeddings()
        inputs_embeds = token_embeddings(input_ids)  # (b, s) -> (b, s, d)
        inputs_embeds = torch.cat((soft_prompt, inputs_embeds), dim=1)  # (b, 4+b+s, d)
        input_attention_mask = torch.cat((backbone_attention_mask, attention_mask), dim=1)

        output = self.backbone(
            inputs_embeds=inputs_embeds,
            attention_mask=input_attention_mask,
            output_hidden_states=True,
        )  # (1, 4 + b + s, d)
        # trim soft prompt
        output_vectors = output.hidden_states[-1]
        n_soft_prompt_tokens = soft_prompt.shape[1]

        if self.pool_ignore_instruction_tokens:
            # Denote the end of an instruction with an extra BOS token.
            # This is a bit arcane but relies on the fact that there will be a BOS token after the
            # instruction, but also there may or may not be a BOS token at the beginning.
            instruction_end_idx = (
                (input_ids == self.pool_instruction_end_id) &
                attention_mask &
                (torch.arange(input_ids.shape[1], device=input_ids.device)[None, :] > 0)
            ).int().argmax(1)
            is_instruction_token_mask = (
                torch.arange(input_ids.shape[1], device=input_ids.device)[None, :] <= instruction_end_idx[:, None]
            )
            # catch edge case where there is no instruction
            is_instruction_token_mask = is_instruction_token_mask.where(
                (instruction_end_idx > 0)[:, None], torch.zeros_like(is_instruction_token_mask)
            )
            input_attention_mask = torch.cat((
                backbone_attention_mask,
                attention_mask & ~is_instruction_token_mask), dim=1
            )

        output_attention_mask = input_attention_mask
        if self.pool_ignore_contextual_tokens:
            output_vectors = output_vectors[:, n_soft_prompt_tokens:, :]
            output_attention_mask = output_attention_mask[:, n_soft_prompt_tokens:]

        # Take last token position
        if vars(self.config).get("pooling_strategy") == "last_token":
            output_pooled = last_token_pool(output_vectors, output_attention_mask)
        elif vars(self.config).get("pooling_strategy") == "mean":
            output_pooled = mean_pool(output_vectors, output_attention_mask)
        else:
            output_pooled = mean_pool_weighted(output_vectors, output_attention_mask)

        # average with original vectors
        output = self.output_projection(output_pooled)  # (b, 2d) -> (b, d)

        if output_hidden_states:
            return {
                "hidden_states": output_vectors,
                "pooled": output,
            }
        else:
            return output


class DatasetConditionedBiencoder(transformers.PreTrainedModel, ContextualModelMixin):
    def __init__(
            self,
            config,
            dataset_backbone: transformers.PreTrainedModel,
        ):
        super().__init__(config=config)
        self.backbone = dataset_backbone
        self.hidden_size = self.backbone.config.hidden_size  # fixed: was assigned twice in a row
        self.contextual_init()
        self._shift_rotary_embedding()

        self.pool_ignore_contextual_tokens = vars(self.config).get("pool_ignore_contextual_tokens", True)
        self.pool_ignore_instruction_tokens = vars(self.config).get("pool_ignore_instruction_tokens", False)

        tokenizer = transformers.AutoTokenizer.from_pretrained(self.config.embedder)
        self.pool_instruction_end_id = tokenizer.encode(": ", add_special_tokens=False)[0]  # Hardcoded for colon-ending prefixes.

    @property
    def num_corpus_tokens(self) -> int:
        return self.config.transductive_corpus_size * self.transductive_tokens_per_document

    def _shift_rotary_embedding(self) -> None:
        disable_transductive_rotary_embedding = vars(self.config).get("disable_transductive_rotary_embedding", True)
        if self.backbone.config.model_type.startswith("nomic") and disable_transductive_rotary_embedding:
            # We only want to apply positional embeddings to the
            # *text* portion of the backbone network.
            self.backbone.config.rotary_start_pos = 0.0
            rotary_disabled = 0

            rotary_start_pos = self.num_corpus_tokens
            for module in self.backbone.modules():
                if hasattr(module, "rotary_emb_dim"):
                    module.rotary_start_pos = rotary_start_pos
                    rotary_disabled += 1
            print0(f"modified {rotary_disabled} rotary modules - set rotary_start_pos to {rotary_start_pos}")

    def forward(
            self,
            input_ids: torch.Tensor,
            attention_mask: torch.Tensor,
            dataset_embeddings: torch.Tensor,
            output_hidden_states: bool = False,
            null_dataset_embedding: bool = False,
        ) -> torch.Tensor:
        soft_prompt = self._prepare_dataset_embeddings(
            input_ids=input_ids,
            dataset_embeddings=dataset_embeddings,
            null_dataset_embedding=null_dataset_embedding,
        )
        backbone_attention_mask = torch.ones(
            soft_prompt.shape[0:2],
            dtype=torch.long,
            device=soft_prompt.device,
        )
        inputs_embeds = self.backbone.embeddings(input_ids)  # (b, s) -> (b, s, d)
        inputs_embeds = torch.cat((soft_prompt, inputs_embeds), dim=1)  # (b, 4+b+s, d)
        input_attention_mask = torch.cat((backbone_attention_mask, attention_mask), dim=1)
        output = self.backbone(
            inputs_embeds=inputs_embeds,
            attention_mask=input_attention_mask,
        )  # (1, 4 + b + s, d)
        # trim soft prompt
        output_vectors = output.last_hidden_state

        # use only these tokens
        n_soft_prompt_tokens = soft_prompt.shape[1]

        if self.pool_ignore_instruction_tokens:
            # Denote the end of an instruction with an extra BOS token.
            # This is a bit arcane but relies on the fact that there will be a BOS token after the
            # instruction, but also there may or may not be a BOS token at the beginning.
            instruction_end_idx = (
                (input_ids == self.pool_instruction_end_id) &
                attention_mask &
                (torch.arange(input_ids.shape[1], device=input_ids.device)[None, :] > 0)
            ).int().argmax(1)
            is_instruction_token_mask = (
                torch.arange(input_ids.shape[1], device=input_ids.device)[None, :] <= instruction_end_idx[:, None]
            )
            # catch edge case where there is no instruction
            is_instruction_token_mask = is_instruction_token_mask.where(
                (instruction_end_idx > 0)[:, None], torch.zeros_like(is_instruction_token_mask)
            )
            output_attention_mask = torch.cat((backbone_attention_mask, attention_mask & ~is_instruction_token_mask), dim=1)
        else:
            output_attention_mask = input_attention_mask

        if self.pool_ignore_contextual_tokens:
            output_vectors = output_vectors[:, n_soft_prompt_tokens:, :]
            output_attention_mask = output_attention_mask[:, n_soft_prompt_tokens:]
        output_pooled = mean_pool(output_vectors, output_attention_mask)
        # average with original vectors
        output = self.output_projection(output_pooled) + output_pooled  # (b, d) -> (b, d) / with residual connection

        if output_hidden_states:
            return {
                "hidden_states": output_vectors,
                "pooled": output,
            }
        else:
            return output


class DatasetPrefixBiencoder(transformers.PreTrainedModel, ContextualModelMixin):
    def __init__(
            self,
            config,  #: transformers.PreTrainedConfig,
            embedder: transformers.PreTrainedModel,
        ):
        super().__init__(config=config)
        self.embedder = embedder
        self.hidden_size = self.embedder.config.hidden_size
        self.contextual_init()

    def forward(
            self,
            input_ids: torch.Tensor,
            attention_mask: torch.Tensor,
            dataset_input_ids: torch.Tensor,
            dataset_attention_mask: torch.Tensor,
            output_hidden_states: bool = False,
        ) -> torch.Tensor:
        R = torch.randint(low=0, high=len(dataset_input_ids), size=(len(input_ids),), device=dataset_input_ids.device)

        dataset_input_ids = dataset_input_ids[R]
        input_ids = torch.cat((dataset_input_ids, input_ids), dim=1)

        dataset_attention_mask = torch.ones_like(dataset_attention_mask, device=dataset_attention_mask.device)
        input_attention_mask = torch.cat((dataset_attention_mask, attention_mask), dim=1)
        output_attention_mask = torch.cat(
            (torch.zeros_like(dataset_input_ids), attention_mask), dim=1
        )

        output = self.embedder(
            input_ids=input_ids,
            attention_mask=input_attention_mask,
        )

        output_vectors = output.last_hidden_state
        output_pooled = mean_pool(output_vectors, output_attention_mask)
        output = self.output_projection(output_pooled)  # (b, 2d) -> (b, d)

        if output_hidden_states:
            S_d = dataset_attention_mask.shape[1]
            output_vectors = output_vectors[:, S_d:, :]
            return {
                "hidden_states": output_vectors,
                "pooled": output,
            }
        else:
            return output


class ContextualDocumentEmbeddingTransformer(transformers.PreTrainedModel):
    config_class = ContextualModelConfig
    embedder: transformers.PreTrainedModel
    dataset_backbone: transformers.PreTrainedModel

    def __init__(
            self,
            config,
        ):
        super().__init__(config=config)
        dataset_backbone, _ = load_embedder_and_tokenizer(
            vars(config).get("dataset_backbone") or config.embedder
        )

        if config.limit_layers:
            print0(f"Limiting layers to {config.limit_layers}")
            limit_layers(dataset_backbone, config.limit_layers)

        biencoder_config = copy.deepcopy(config)
        biencoder_config.embedding_output_dim = None
        biencoder_config.limit_layers = vars(self.config).get("limit_layers_first_stage", None)
        self.first_stage_model = BiEncoder(
            config=biencoder_config,
        )

        if vars(config).get("autoregressive_backbone", False):
            self.second_stage_model = DatasetConditionedAutoregressive(
                config=config,
                dataset_backbone=dataset_backbone,
                first_stage_hidden_size=self.first_stage_model.hidden_size,
            )
        else:
            self.second_stage_model = DatasetConditionedBiencoder(
                config=config,
                dataset_backbone=dataset_backbone
            )

        self.temp = config.logit_scale
        if config.disable_dropout:
            disable_dropout(self)

        transductive_tie_token_embeddings = vars(self.config).get("transductive_tie_token_embeddings", False)
        if transductive_tie_token_embeddings:
            self.second_stage_model.backbone.embeddings.word_embeddings.weight = (
                self.first_stage_model.embedder.embeddings.word_embeddings.weight
            )

    def forward(
            self,
            input_ids: torch.Tensor,
            attention_mask: torch.Tensor,
            dataset_input_ids: Optional[torch.Tensor],
            dataset_attention_mask: Optional[torch.Tensor],
            output_hidden_states: bool = False,
        ) -> torch.Tensor:
        """
        input_ids (long torch.Tensor) - ids of input tokens
        attention_mask (bool torch.Tensor)
        """
        dataset_embeddings = self.first_stage_model(
            input_ids=dataset_input_ids,
            attention_mask=dataset_attention_mask
        )
        return self.second_stage_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            dataset_embeddings=dataset_embeddings,
            output_hidden_states=output_hidden_states,
        )


def get_model_class(name: str):
    if name == 'transductive':  # fixed: was `name in 'transductive'`, a substring check that also matched '' and 'duct'
        return ContextualDocumentEmbeddingTransformer
    elif name == 'biencoder':
        return BiEncoder
    elif name == "biencoder_plus_plus":
        from cde.model_extra import BiEncoderPlusPlus
        return BiEncoderPlusPlus
    elif name == "dataset_prefix_biencoder":
        return DatasetPrefixBiencoder
    else:
        raise ValueError(f'unknown model cls {name}')
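# Factory sketch (illustrative): config.json sets "architecture": "transductive",
# which maps to the class uploaded in this commit:
#
#     cls = get_model_class("transductive")  # -> ContextualDocumentEmbeddingTransformer
#     model = cls(config)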
model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e7cca261c510de07c012f3019366f1b6c5720761b6966b0388faea6e70398983
size 1124594680