chenyuhe committed on
Commit 03eedcf · verified · 1 Parent(s): 56ac58a

Upload 11 files

config.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "_name_or_path": "/home/huang/chy/ESMC/esm plus",
3
+ "architectures": [
4
+ "ESMplusplusForSequenceClassification"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "modeling_esm_plusplus.ESMplusplusConfig",
8
+ "AutoModel": "modeling_esm_plusplus.ESMplusplusModel",
9
+ "AutoModelForMaskedLM": "modeling_esm_plusplus.ESMplusplusForMaskedLM",
10
+ "AutoModelForSequenceClassification": "modeling_esm_plusplus.ESMplusplusForSequenceClassification",
11
+ "AutoModelForTokenClassification": "modeling_esm_plusplus.ESMplusplusForTokenClassification"
12
+ },
13
+ "dropout": 0.0,
14
+ "hidden_size": 1152,
15
+ "initializer_range": 0.02,
16
+ "model_type": "ESMplusplus",
17
+ "num_attention_heads": 18,
18
+ "num_hidden_layers": 36,
19
+ "problem_type": "single_label_classification",
20
+ "torch_dtype": "float32",
21
+ "transformers_version": "4.46.3",
22
+ "vocab_size": 64
23
+ }
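
The auto_map entries above bind this checkpoint to the custom classes defined in modeling_esm_plusplus.py (below), so it loads through the standard Auto* API with trust_remote_code=True. A minimal loading sketch; the local checkpoint path is a placeholder and the example sequence is arbitrary, not part of the upload:

import torch
from transformers import AutoModelForSequenceClassification

path = "./checkpoint"  # placeholder: a local copy of this upload
model = AutoModelForSequenceClassification.from_pretrained(path, trust_remote_code=True)
tokenizer = model.tokenizer  # the model constructs its own EsmSequenceTokenizer

batch = tokenizer(["MKTAYIAKQRQISFVK"], return_tensors="pt")
with torch.no_grad():
    logits = model(**batch).logits  # shape (1, num_labels); single_label_classification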
modeling_esm_plusplus.py ADDED
@@ -0,0 +1,1081 @@
1
+ """
2
+ ESM++ model implementation.
3
+
4
+ ESM++ is a faithful implementation of ESMC that supports batching and standard Hugging Face compatibility.
5
+ The ESM Python package is not required.
6
+
7
+ Modified from https://github.com/evolutionaryscale/esm
8
+ License: https://www.evolutionaryscale.ai/policies/cambrian-non-commercial-license-agreement
9
+ """
10
+
11
+ import math
12
+ import os
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from dataclasses import dataclass
17
+ from functools import cache, partial
18
+ from pathlib import Path
19
+ from typing import Optional, Tuple, Union
20
+ from einops import rearrange, repeat
21
+ from huggingface_hub import snapshot_download
22
+ from tokenizers import Tokenizer
23
+ from tokenizers.models import BPE
24
+ from tokenizers.processors import TemplateProcessing
25
+ from torch.utils.data import Dataset, DataLoader
26
+ from tqdm.auto import tqdm
27
+ from transformers import PreTrainedModel, PreTrainedTokenizerFast, PretrainedConfig
28
+ from transformers.modeling_outputs import ModelOutput
29
+
30
+
31
+ class ESMplusplusConfig(PretrainedConfig):
32
+ """Configuration class for ESM++ model.
33
+
34
+ Args:
35
+ vocab_size: Size of the vocabulary
36
+ hidden_size: Dimension of hidden layers
37
+ num_attention_heads: Number of attention heads
38
+ num_hidden_layers: Number of transformer layers
39
+ num_labels: Number of output labels for classification
40
+ problem_type: One of 'regression', 'single_label_classification', or 'multi_label_classification'
41
+ """
42
+ model_type = "ESMplusplus"
43
+ def __init__(
44
+ self,
45
+ vocab_size: int = 64,
46
+ hidden_size: int = 960,
47
+ num_attention_heads: int = 15,
48
+ num_hidden_layers: int = 30,
49
+ num_labels: int = 2,
50
+ problem_type: str | None = None,
51
+ dropout: float = 0.0,
52
+ initializer_range: float = 0.02,
53
+ **kwargs,
54
+ ):
55
+ super().__init__(**kwargs)
56
+ self.vocab_size = vocab_size
57
+ self.hidden_size = hidden_size
58
+ self.num_attention_heads = num_attention_heads
59
+ self.num_hidden_layers = num_hidden_layers
60
+ self.num_labels = num_labels
61
+ self.problem_type = problem_type
62
+ self.dropout = dropout
63
+ self.initializer_range = initializer_range
64
+
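+ # Editorial note (not part of the original upload): the accompanying config.json
+ # instantiates this class at the 600M scale, i.e.
+ # ESMplusplusConfig(hidden_size=1152, num_attention_heads=18, num_hidden_layers=36).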
65
+
66
+ ### Rotary Embeddings
67
+ def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
68
+ """Rotates half the hidden dims of the input."""
69
+ if not interleaved:
70
+ x1, x2 = x.chunk(2, dim=-1)
71
+ return torch.cat((-x2, x1), dim=-1)
72
+ else:
73
+ x1, x2 = x[..., ::2], x[..., 1::2]
74
+ return rearrange(
75
+ torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
76
+ )
77
+
78
+
79
+ def apply_rotary_emb_torch(
80
+ x: torch.Tensor,
81
+ cos: torch.Tensor,
82
+ sin: torch.Tensor,
83
+ interleaved: bool = False,
84
+ _inplace: bool = False,
85
+ ) -> torch.Tensor:
86
+ """Apply rotary embeddings to input based on cos and sin."""
87
+ ro_dim = cos.shape[-1] * 2
88
+ assert ro_dim <= x.shape[-1]
89
+ seqlen = x.size(1)
90
+ cos = cos[:seqlen]
91
+ sin = sin[:seqlen]
92
+ cos = repeat(cos, "s d -> s 1 (2 d)")
93
+ sin = repeat(sin, "s d -> s 1 (2 d)")
94
+ return torch.cat(
95
+ [
96
+ x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
97
+ x[..., ro_dim:],
98
+ ],
99
+ dim=-1,
100
+ )
101
+
102
+
103
+ class RotaryEmbedding(torch.nn.Module):
104
+ """Rotary position embeddings.
105
+
106
+ Based on the paper "RoFormer: Enhanced Transformer with Rotary Position Embedding"
107
+
108
+ Args:
109
+ dim: Dimension of the embedding
110
+ base: Base for computing angular frequencies
111
+ interleaved: Whether to use interleaved rotations
112
+ scale_base: Base for scaling
113
+ scaling_factor: Factor for scaling positions
114
+ pos_idx_in_fp32: Whether to compute position indices in fp32
115
+ device: Computation device
116
+ """
117
+ def __init__(
118
+ self,
119
+ dim: int,
120
+ base: float = 10000.0,
121
+ interleaved: bool = False,
122
+ scale_base: Optional[float] = None,
123
+ scaling_factor: float = 1.0,
124
+ pos_idx_in_fp32: bool = True,
125
+ device: Optional[torch.device] = None,
126
+ ):
127
+ super().__init__()
128
+ self.dim = dim
129
+ self.base = float(base)
130
+ self.pos_idx_in_fp32 = pos_idx_in_fp32
131
+ self.interleaved = interleaved
132
+ self.scale_base = scale_base
133
+ self.scaling_factor = scaling_factor
134
+ self.device = device
135
+
136
+ self._seq_len_cached = 0
137
+ self._cos_cached = None
138
+ self._sin_cached = None
139
+ self._cos_k_cached = None
140
+ self._sin_k_cached = None
141
+ self.reset_parameters()
142
+
143
+ def reset_parameters(self):
144
+ """Reset the parameters of the embedding."""
145
+ inv_freq = self._compute_inv_freq(self.device)
146
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
147
+ arange = torch.arange(0, self.dim, 2, device=self.device, dtype=torch.float32)
148
+ scale = (
149
+ (arange + 0.4 * self.dim) / (1.4 * self.dim)
150
+ if self.scale_base is not None
151
+ else None
152
+ )
153
+ self.register_buffer("scale", scale)
154
+
155
+ def _compute_inv_freq(self, device: Optional[torch.device] = None) -> torch.Tensor:
156
+ """Compute inverse frequency bands."""
157
+ return 1 / (
158
+ self.base
159
+ ** (
160
+ torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
161
+ / self.dim
162
+ )
163
+ )
164
+
165
+ def _update_cos_sin_cache(self, seqlen: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
166
+ """Update the cached cosine and sine values."""
167
+ if (
168
+ seqlen > self._seq_len_cached
169
+ or self._cos_cached is None
170
+ or self._cos_cached.device != device
171
+ or self._cos_cached.dtype != dtype
172
+ or (self.training and self._cos_cached.is_inference())
173
+ ):
174
+ self._seq_len_cached = seqlen
175
+ if self.pos_idx_in_fp32:
176
+ t = torch.arange(seqlen, device=device, dtype=torch.float32)
177
+ t /= self.scaling_factor
178
+ if self.inv_freq.dtype != torch.float32:
179
+ inv_freq = self.inv_freq.to(torch.float32)
180
+ else:
181
+ inv_freq = self.inv_freq
182
+ else:
183
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
184
+ t /= self.scaling_factor
185
+ inv_freq = self.inv_freq
186
+ freqs = torch.outer(t, inv_freq)
187
+
188
+ if self.scale is None:
189
+ self._cos_cached = torch.cos(freqs).to(dtype)
190
+ self._sin_cached = torch.sin(freqs).to(dtype)
191
+ else:
192
+ power = (
193
+ torch.arange(
194
+ seqlen, dtype=self.scale.dtype, device=self.scale.device
195
+ )
196
+ - seqlen // 2
197
+ ) / self.scale_base
198
+ scale = self.scale.to(device=power.device) ** power.unsqueeze(-1)
199
+ self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
200
+ self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
201
+ self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
202
+ self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
203
+
204
+ def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
205
+ """Apply rotary embeddings to queries and keys.
206
+
207
+ Args:
208
+ q: Query tensor of shape (batch, seqlen, nheads, headdim)
209
+ k: Key tensor of shape (batch, seqlen, nheads, headdim)
210
+
211
+ Returns:
212
+ Tuple of rotated query and key tensors
213
+ """
214
+ self._update_cos_sin_cache(q.shape[1], device=q.device, dtype=q.dtype)
215
+ assert self._cos_cached is not None
216
+ assert self._sin_cached is not None
217
+ if self.scale is None:
218
+ return (
219
+ apply_rotary_emb_torch(
220
+ q,
221
+ self._cos_cached,
222
+ self._sin_cached,
223
+ self.interleaved,
224
+ True, # inplace=True
225
+ ),
226
+ apply_rotary_emb_torch(
227
+ k,
228
+ self._cos_cached,
229
+ self._sin_cached,
230
+ self.interleaved,
231
+ True, # inplace=True
232
+ ),
233
+ ) # type: ignore
234
+ else:
235
+ assert False
236
+
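+ # Editorial worked example: with interleaved=False, rotate_half splits the last dim
+ # in half and maps (x1, x2) -> (-x2, x1), e.g. [1., 2., 3., 4.] -> [-3., -4., 1., 2.].
+ # RotaryEmbedding.forward expects q/k shaped (batch, seqlen, nheads, headdim) and
+ # returns tensors of the same shape with positions encoded by rotation.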
237
+
238
+ ### Feedforward Network Components
239
+ def swiglu_correction_fn(expansion_ratio: float, d_model: int) -> int:
240
+ """Compute corrected dimension for SwiGLU."""
241
+ return int(((expansion_ratio * d_model) + 255) // 256 * 256)
242
+
243
+
244
+ class SwiGLU(nn.Module):
245
+ """SwiGLU activation function."""
246
+ def __init__(self):
247
+ super(SwiGLU, self).__init__()
248
+
249
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
250
+ x1, x2 = x.chunk(2, dim=-1)
251
+ return F.silu(x1) * x2
252
+
253
+
254
+ def swiglu_ln_ffn(d_model: int, expansion_ratio: float) -> nn.Sequential:
255
+ """Create SwiGLU feedforward network with layer normalization."""
256
+ return nn.Sequential(
257
+ nn.LayerNorm(d_model),
258
+ nn.Linear(
259
+ d_model, swiglu_correction_fn(expansion_ratio, d_model) * 2, bias=False
260
+ ),
261
+ SwiGLU(),
262
+ nn.Linear(swiglu_correction_fn(expansion_ratio, d_model), d_model, bias=False),
263
+ )
264
+
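+ # Editorial worked example: for d_model=1152 and expansion_ratio=8/3,
+ # swiglu_correction_fn returns int((3072 + 255) // 256 * 256) = 3072, so the first
+ # Linear maps 1152 -> 6144 and SwiGLU chunks that into two 3072-dim halves.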
265
+
266
+ ### Attention
267
+ class MultiHeadAttention(nn.Module):
268
+ """Multi-head attention with rotary embeddings.
269
+
270
+ Args:
271
+ d_model: Model dimension
272
+ n_heads: Number of attention heads
273
+ """
274
+ def __init__(self, d_model: int, n_heads: int):
275
+ super().__init__()
276
+ self.d_model = d_model
277
+ self.n_heads = n_heads
278
+ self.d_head = self.d_model // self.n_heads
279
+ self.layernorm_qkv = nn.Sequential(
280
+ nn.LayerNorm(d_model), nn.Linear(d_model, d_model * 3, bias=False)
281
+ )
282
+ self.out_proj = nn.Linear(d_model, d_model, bias=False)
283
+ self.q_ln = nn.LayerNorm(d_model, bias=False)
284
+ self.k_ln = nn.LayerNorm(d_model, bias=False)
285
+ self.reshaper = partial(rearrange, pattern="b s (h d) -> b h s d", h=n_heads)
286
+ self.rotary = RotaryEmbedding(d_model // n_heads)
287
+
288
+ def _apply_rotary(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
289
+ """Apply rotary embeddings to query and key."""
290
+ q = q.unflatten(-1, (self.n_heads, self.d_head))
291
+ k = k.unflatten(-1, (self.n_heads, self.d_head))
292
+ q, k = self.rotary(q, k)
293
+ q = q.flatten(-2, -1)
294
+ k = k.flatten(-2, -1)
295
+ return q, k
296
+
297
+ def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, output_attentions: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
298
+ """
299
+ Args:
300
+ x: Input tensor
301
+ attention_mask: Optional attention mask
302
+ output_attentions: Whether to return attention weights
303
+
304
+ Returns:
305
+ Output tensor after self attention, and optionally attention weights
306
+ """
307
+ attn_weights = None
308
+ qkv_BLD3 = self.layernorm_qkv(x)
309
+ query_BLD, key_BLD, value_BLD = torch.chunk(qkv_BLD3, 3, dim=-1)
310
+ query_BLD, key_BLD = (
311
+ self.q_ln(query_BLD).to(query_BLD.dtype),
312
+ self.k_ln(key_BLD).to(query_BLD.dtype),
313
+ )
314
+ query_BLD, key_BLD = self._apply_rotary(query_BLD, key_BLD)
315
+ query_BHLD, key_BHLD, value_BHLD = map(self.reshaper, (query_BLD, key_BLD, value_BLD))
316
+
317
+ if output_attentions: # Manual attention computation
318
+ L, S = query_BLD.size(-2), key_BLD.size(-2)
319
+ scale = 1 / math.sqrt(query_BLD.size(-1))
320
+ attn_bias = torch.zeros(L, S, dtype=query_BLD.dtype, device=query_BLD.device)
321
+ if attention_mask is not None:
322
+ if attention_mask.dtype == torch.bool:
323
+ attn_bias = attn_bias.masked_fill(attention_mask.logical_not(), float('-inf'))
324
+ else:
325
+ attn_bias += attention_mask
326
+
327
+ attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-2, -1)) * scale
328
+ attn_weights += attn_bias
329
+ attn_weights = F.softmax(attn_weights, dim=-1)
330
+ context_BHLD = torch.matmul(attn_weights, value_BHLD)
331
+ else:
332
+ context_BHLD = F.scaled_dot_product_attention(
333
+ query_BHLD, key_BHLD, value_BHLD, attention_mask
334
+ )
335
+
336
+ context_BLD = rearrange(context_BHLD, "b h s d -> b s (h d)")
337
+ output = self.out_proj(context_BLD)
338
+ return output, attn_weights
339
+
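+ # Editorial usage sketch (shapes follow the code above):
+ # attn = MultiHeadAttention(d_model=960, n_heads=15)
+ # out, weights = attn(torch.randn(2, 16, 960), output_attentions=True)
+ # gives out of shape (2, 16, 960) and weights of shape (2, 15, 16, 16);
+ # with output_attentions=False, weights is None and SDPA is used instead.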
340
+
341
+ ### Regression Head
342
+ def RegressionHead(d_model: int, output_dim: int, hidden_dim: Optional[int] = None) -> nn.Module:
343
+ """Create a regression head with optional hidden dimension.
344
+
345
+ Args:
346
+ d_model: Input dimension
347
+ output_dim: Output dimension
348
+ hidden_dim: Optional hidden dimension (defaults to d_model)
349
+ """
350
+ hidden_dim = hidden_dim if hidden_dim is not None else d_model
351
+ return nn.Sequential(
352
+ nn.Linear(d_model, hidden_dim),
353
+ nn.GELU(),
354
+ nn.LayerNorm(hidden_dim),
355
+ nn.Linear(hidden_dim, output_dim),
356
+ )
357
+
358
+
359
+ ### Transformer Block
360
+ class UnifiedTransformerBlock(nn.Module):
361
+ """Transformer block with attention and feedforward layers.
362
+
363
+ Args:
364
+ d_model: Model dimension
365
+ n_heads: Number of attention heads
366
+ residue_scaling_factor: Factor for scaling residual connections
367
+ expansion_ratio: Expansion ratio for feedforward network
368
+ """
369
+ def __init__(
370
+ self,
371
+ d_model: int,
372
+ n_heads: int,
373
+ residue_scaling_factor: float = 1,
374
+ expansion_ratio: float = 8 / 3,
375
+ dropout: float = 0.0,
376
+ ):
377
+ super().__init__()
378
+ self.attn = MultiHeadAttention(d_model, n_heads)
379
+ self.ffn = swiglu_ln_ffn(d_model, expansion_ratio)
380
+ self.scaling_factor = residue_scaling_factor
381
+ self.dropout = nn.Dropout(dropout)
382
+
383
+ def forward(
384
+ self,
385
+ x: torch.Tensor,
386
+ attention_mask: Optional[torch.Tensor] = None,
387
+ output_attentions: bool = False,
388
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
389
+ """
390
+ Args:
391
+ x: Input tensor
392
+ attention_mask: Optional attention mask
393
+ output_attentions: Whether to return attention weights
394
+
395
+ Returns:
396
+ Output tensor after transformer block, and optionally attention weights
397
+ """
398
+ attn_output, attn_weights = self.attn(x, attention_mask, output_attentions)
399
+ x = x + self.dropout(attn_output) / self.scaling_factor
400
+ x = x + self.dropout(self.ffn(x)) / self.scaling_factor
401
+ return x, attn_weights
402
+
403
+
404
+ ### Model Outputs
405
+ @dataclass
406
+ class TransformerOutput(ModelOutput):
407
+ """Output type for transformer encoder."""
408
+ last_hidden_state: Optional[torch.Tensor] = None
409
+ hidden_states: Optional[Tuple[torch.Tensor]] = None
410
+ attentions: Optional[Tuple[torch.Tensor]] = None
411
+
412
+
413
+ @dataclass
414
+ class ESMplusplusOutput(ModelOutput):
415
+ """Output type for ESM++ models."""
416
+ loss: Optional[torch.Tensor] = None
417
+ logits: Optional[torch.Tensor] = None
418
+ last_hidden_state: Optional[torch.Tensor] = None
419
+ hidden_states: Optional[Tuple[torch.Tensor]] = None
420
+ attentions: Optional[Tuple[torch.Tensor]] = None
421
+
422
+
423
+ ### Transformer Stack
424
+ class TransformerStack(nn.Module):
425
+ """Stack of transformer blocks.
426
+
427
+ Args:
428
+ d_model: Model dimension
429
+ n_heads: Number of attention heads
430
+ n_layers: Number of transformer layers
431
+ dropout: Dropout rate
432
+ """
433
+ def __init__(
434
+ self,
435
+ d_model: int,
436
+ n_heads: int,
437
+ n_layers: int,
438
+ dropout: float = 0.0,
439
+ ):
440
+ super().__init__()
441
+ self.blocks = nn.ModuleList(
442
+ [
443
+ UnifiedTransformerBlock(
444
+ d_model,
445
+ n_heads,
446
+ residue_scaling_factor=math.sqrt(n_layers / 36),
447
+ dropout=dropout,
448
+ )
449
+ for i in range(n_layers)
450
+ ]
451
+ )
452
+ self.norm = nn.LayerNorm(d_model, bias=False)
453
+ self.gradient_checkpointing = False
454
+
455
+ def forward(
456
+ self,
457
+ x: torch.Tensor,
458
+ attention_mask: Optional[torch.Tensor] = None,
459
+ output_hidden_states: bool = False,
460
+ output_attentions: bool = False,
461
+ ) -> TransformerOutput:
462
+ """
463
+ Args:
464
+ x: Input tensor
465
+ attention_mask: Optional attention mask
466
+ output_hidden_states: Whether to return all hidden states
467
+ output_attentions: Whether to return attention weights
468
+
469
+ Returns:
470
+ TransformerOutput containing last hidden state and optionally all hidden states and attention weights
471
+ """
472
+ batch_size, seq_len, _ = x.shape
473
+ hidden_states = () if output_hidden_states else None
474
+ attentions = () if output_attentions else None
475
+
476
+ if attention_mask is not None:
477
+ attention_mask = attention_mask[:, None, None, :].expand(batch_size, 1, seq_len, seq_len).bool()
478
+
479
+ for block in self.blocks:
480
+ if self.gradient_checkpointing and self.training:
481
+ x, attn_weights = self._gradient_checkpointing_func(
482
+ block.__call__,
483
+ x,
484
+ attention_mask,
485
+ output_attentions,
486
+ )
487
+ else:
488
+ x, attn_weights = block(x, attention_mask, output_attentions)
489
+
490
+ if attentions is not None:
491
+ attentions += (attn_weights,)
492
+
493
+ if output_hidden_states:
494
+ assert hidden_states is not None
495
+ hidden_states += (x,)
496
+
497
+ return TransformerOutput(
498
+ last_hidden_state=self.norm(x),
499
+ hidden_states=hidden_states,
500
+ attentions=attentions
501
+ )
502
+
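+ # Editorial note: residue_scaling_factor = sqrt(n_layers / 36), so the 36-layer
+ # (600M-scale) stack divides residual updates by exactly 1.0, while a 30-layer
+ # stack divides by sqrt(30/36) ~= 0.913, slightly amplifying each update.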
503
+
504
+ ### Dataset for Embedding
505
+ class ProteinDataset(Dataset):
506
+ """Simple dataset for protein sequences."""
507
+ def __init__(self, sequences: list[str]):
508
+ self.sequences = sequences
509
+
510
+ def __len__(self) -> int:
511
+ return len(self.sequences)
512
+
513
+ def __getitem__(self, idx: int) -> str:
514
+ return self.sequences[idx]
515
+
516
+
517
+ class PreTrainedESMplusplusModel(PreTrainedModel):
518
+ """
519
+ Base class for ESM++ models: handles weight initialization and shared pooling/embedding utilities.
520
+ """
521
+ config_class = ESMplusplusConfig
522
+ base_model_prefix = "esm++"
523
+ supports_gradient_checkpointing = True
524
+
525
+ def _init_weights(self, module):
526
+ """Initialize the weights"""
527
+ if isinstance(module, nn.Linear):
528
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
529
+ if module.bias is not None:
530
+ module.bias.data.zero_()
531
+ elif isinstance(module, nn.Embedding):
532
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
533
+ if module.padding_idx is not None:
534
+ module.weight.data[module.padding_idx].zero_()
535
+ elif isinstance(module, nn.LayerNorm):
536
+ if module.bias is not None:
537
+ module.bias.data.zero_()
538
+ module.weight.data.fill_(1.0)
539
+
540
+ @classmethod
541
+ def from_pretrained_esm(cls, model_name: str):
542
+ """Load a pretrained ESM++ model."""
543
+ if '300' in model_name:
544
+ return ESMplusplus_300M()
545
+ elif '600' in model_name:
546
+ return ESMplusplus_600M()
547
+ else:
548
+ raise ValueError(f"Invalid model name: {model_name}")
549
+
550
+ @property
551
+ def device(self) -> torch.device:
552
+ """Get the device of the model."""
553
+ return next(self.parameters()).device
554
+
555
+ def mean_pooling(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
556
+ """Apply mean pooling to sequence outputs."""
557
+ if attention_mask is None:
558
+ return x.mean(dim=1)
559
+ else:
560
+ attention_mask = attention_mask.unsqueeze(-1)
561
+ return (x * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
562
+
563
+ def max_pooling(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
564
+ """Apply max pooling to sequence outputs."""
565
+ if attention_mask is None:
566
+ return x.max(dim=1).values
567
+ else:
568
+ attention_mask = attention_mask.unsqueeze(-1).bool()
569
+ return x.masked_fill(~attention_mask, float('-inf')).max(dim=1).values
570
+
571
+ def cls_pooling(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
572
+ """Apply cls pooling to sequence outputs."""
573
+ return x[:, 0, :]
574
+
575
+ def _collate_fn(self, sequences: list[str]) -> tuple[torch.Tensor, torch.Tensor]:
576
+ """Collate function for batching sequences."""
577
+ return self.tokenizer(sequences, return_tensors="pt", padding='longest', pad_to_multiple_of=8)
578
+
579
+ def _read_sequences_from_db(self, db_path: str) -> set[str]:
580
+ """Read sequences from SQLite database."""
581
+ import sqlite3
582
+ sequences = []
583
+ with sqlite3.connect(db_path) as conn:
584
+ c = conn.cursor()
585
+ c.execute("SELECT sequence FROM embeddings")
586
+ while True:
587
+ row = c.fetchone()
588
+ if row is None:
589
+ break
590
+ sequences.append(row[0])
591
+ return set(sequences)
592
+
593
+ def embed_dataset(
594
+ self,
595
+ sequences: list[str],
596
+ batch_size: int = 2,
597
+ max_len: int = 512,
598
+ full_embeddings: bool = False,
599
+ full_precision: bool = False,
600
+ pooling_type: str = 'mean',
601
+ num_workers: int = 0,
602
+ sql: bool = False,
603
+ sql_db_path: str = 'embeddings.db',
604
+ ) -> Optional[dict[str, torch.Tensor]]:
605
+ """Embed a dataset of protein sequences.
606
+
607
+ Args:
608
+ sequences: List of protein sequences
609
+ batch_size: Batch size for processing
610
+ max_len: Maximum sequence length
611
+ full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False)
612
+ full_precision: Whether to cast to full precision (float32) before storage - relevant for dict storage
613
+ pooling_type: Type of pooling ('mean' or 'cls')
614
+ num_workers: Number of workers for data loading, 0 for the main process
615
+ sql: Whether to store embeddings in SQLite database - will be stored in float32
616
+ sql_db_path: Path to SQLite database
617
+
618
+ Returns:
619
+ Dictionary mapping sequences to embeddings, or None if sql=True
620
+ """
621
+ sequences = list(set([seq[:max_len] for seq in sequences]))
622
+ device = self.device
623
+
624
+ def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
625
+ if full_embeddings:
626
+ return residue_embeddings
627
+ elif pooling_type == 'mean':
628
+ return self.mean_pooling(residue_embeddings, attention_mask)
629
+ elif pooling_type == 'max':
630
+ return self.max_pooling(residue_embeddings, attention_mask)
631
+ elif pooling_type == 'cls':
632
+ return self.cls_pooling(residue_embeddings, attention_mask)
633
+ else:
634
+ raise ValueError(f"Invalid pooling type: {pooling_type}")
635
+
637
+ if sql:
638
+ import sqlite3
639
+ conn = sqlite3.connect(sql_db_path)
640
+ c = conn.cursor()
641
+ c.execute('CREATE TABLE IF NOT EXISTS embeddings (sequence text PRIMARY KEY, embedding blob)')
642
+ already_embedded = self._read_sequences_from_db(sql_db_path)
643
+ to_embed = [seq for seq in sequences if seq not in already_embedded]
644
+ print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
645
+ print(f"Embedding {len(to_embed)} new sequences")
646
+ if len(to_embed) > 0:
647
+ to_embed = sorted(to_embed, key=len, reverse=True)
648
+ dataset = ProteinDataset(to_embed)
649
+ dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate_fn, shuffle=False)
650
+ with torch.no_grad():
651
+ for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
652
+ seqs = to_embed[i * batch_size:(i + 1) * batch_size]
653
+ input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
654
+ x = self.embed(input_ids)
655
+ residue_embeddings = self.transformer(x, attention_mask).last_hidden_state.detach().float() # required for sql
656
+ embeddings = get_embeddings(residue_embeddings, attention_mask)
657
+
658
+ for seq, emb, mask in zip(seqs, embeddings, attention_mask):
659
+ if full_embeddings:
660
+ emb = emb[mask.bool()]
661
+ c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
662
+ (seq, emb.cpu().numpy().tobytes()))
663
+
664
+ if (i + 1) % 100 == 0:
665
+ conn.commit()
666
+
667
+ conn.commit()
668
+ conn.close()
669
+ return None
670
+
671
+ embeddings_dict = {}
672
+ sequences = sorted(sequences, key=len, reverse=True)
673
+ dataset = ProteinDataset(sequences)
674
+ dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate_fn, shuffle=False)
675
+ with torch.no_grad():
676
+ for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
677
+ seqs = sequences[i * batch_size:(i + 1) * batch_size]
678
+ input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
679
+ x = self.embed(input_ids)
680
+ residue_embeddings = self.transformer(x, attention_mask).last_hidden_state.detach()
681
+ if full_precision:
682
+ residue_embeddings = residue_embeddings.float()
683
+ embeddings = get_embeddings(residue_embeddings, attention_mask).cpu()
684
+ for seq, emb in zip(seqs, embeddings):
685
+ embeddings_dict[seq] = emb
686
+
687
+ return embeddings_dict
688
+
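+ # Editorial usage sketch (assumes `model` is an instantiated ESM++ model):
+ # emb = model.embed_dataset(["MKTAYIAK", "MSILVTRP"], batch_size=2, pooling_type='mean')
+ # returns a dict mapping each deduplicated, truncated sequence to a pooled embedding;
+ # with sql=True the embeddings are written to a SQLite file and None is returned.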
689
+
690
+ ### ESM++ Models
691
+ class ESMplusplusModel(PreTrainedESMplusplusModel):
692
+ """
693
+ ESM++ base model: the transformer encoder without task-specific heads.
694
+ """
695
+ config_class = ESMplusplusConfig
696
+ def __init__(self, config: ESMplusplusConfig, **kwargs):
697
+ super().__init__(config, **kwargs)
698
+ self.config = config
699
+ self.vocab_size = config.vocab_size
700
+ self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
701
+ self.transformer = TransformerStack(config.hidden_size, config.num_attention_heads, config.num_hidden_layers, config.dropout)
702
+ self.tokenizer = EsmSequenceTokenizer()
703
+ self.init_weights()
704
+
705
+ def get_input_embeddings(self):
706
+ return self.embed
707
+
708
+ def set_input_embeddings(self, value):
709
+ self.embed = value
710
+
711
+ def forward(
712
+ self,
713
+ input_ids: Optional[torch.Tensor] = None,
714
+ attention_mask: Optional[torch.Tensor] = None,
715
+ inputs_embeds: Optional[torch.Tensor] = None,
716
+ output_attentions: Optional[bool] = None,
717
+ output_hidden_states: Optional[bool] = None,
718
+ return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
719
+ ) -> TransformerOutput:
720
+ """Forward pass for masked language modeling.
721
+
722
+ Args:
723
+ input_ids: Input token IDs
724
+ attention_mask: Attention mask
725
+ inputs_embeds: Optional precomputed embeddings
726
+ output_hidden_states: Whether to return all hidden states
727
+ output_attentions: Whether to return attention weights
728
+
729
+ Returns:
730
+ TransformerOutput containing last hidden state and optionally all hidden states and attention weights
731
+ """
732
+ if inputs_embeds is None:
733
+ x = self.embed(input_ids)
734
+ else:
735
+ x = inputs_embeds
736
+ return self.transformer(x, attention_mask, output_hidden_states, output_attentions)
737
+
738
+
739
+ class ESMplusplusForMaskedLM(PreTrainedESMplusplusModel):
740
+ """
741
+ ESM++ model for masked language modeling.
742
+ Implements the base ESM++ architecture with a masked language modeling head.
743
+ """
744
+ config_class = ESMplusplusConfig
745
+ def __init__(self, config: ESMplusplusConfig, **kwargs):
746
+ super().__init__(config, **kwargs)
747
+ self.config = config
748
+ self.vocab_size = config.vocab_size
749
+ self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
750
+ self.transformer = TransformerStack(config.hidden_size, config.num_attention_heads, config.num_hidden_layers, config.dropout)
751
+ self.sequence_head = RegressionHead(config.hidden_size, self.vocab_size)
752
+ self.ce_loss = nn.CrossEntropyLoss()
753
+ self.tokenizer = EsmSequenceTokenizer()
754
+ self.init_weights()
755
+
756
+ def get_input_embeddings(self):
757
+ return self.embed
758
+
759
+ def set_input_embeddings(self, value):
760
+ self.embed = value
761
+
762
+ def get_output_embeddings(self):
763
+ return self.sequence_head[-1]
764
+
765
+ def set_output_embeddings(self, new_embeddings):
766
+ self.sequence_head[-1] = new_embeddings
767
+
768
+ def forward(
769
+ self,
770
+ input_ids: Optional[torch.Tensor] = None,
771
+ attention_mask: Optional[torch.Tensor] = None,
772
+ inputs_embeds: Optional[torch.Tensor] = None,
773
+ labels: Optional[torch.Tensor] = None,
774
+ output_attentions: Optional[bool] = None,
775
+ output_hidden_states: Optional[bool] = None,
776
+ return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
777
+ ) -> ESMplusplusOutput:
778
+ """Forward pass for masked language modeling.
779
+
780
+ Args:
781
+ input_ids: Input token IDs
782
+ attention_mask: Attention mask
783
+ inputs_embeds: Optional precomputed embeddings
784
+ labels: Optional labels for masked tokens
785
+ output_hidden_states: Whether to return all hidden states
786
+ output_attentions: Whether to return attention weights
787
+
788
+ Returns:
789
+ ESMplusplusOutput containing loss, logits, hidden states and attention weights
790
+ """
791
+ if inputs_embeds is None:
792
+ x = self.embed(input_ids)
793
+ else:
794
+ x = inputs_embeds
795
+ output = self.transformer(x, attention_mask, output_hidden_states, output_attentions)
796
+ x = output.last_hidden_state
797
+ logits = self.sequence_head(x)
798
+ loss = None
799
+ if labels is not None:
800
+ loss = self.ce_loss(logits.view(-1, self.vocab_size), labels.view(-1))
801
+ return ESMplusplusOutput(
802
+ loss=loss,
803
+ logits=logits,
804
+ last_hidden_state=x,
805
+ hidden_states=output.hidden_states,
806
+ attentions=output.attentions,
807
+ )
808
+
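+ # Editorial usage sketch:
+ # mlm = ESMplusplusForMaskedLM(ESMplusplusConfig())
+ # batch = mlm.tokenizer(["MKTAYIAK"], return_tensors="pt")
+ # out = mlm(**batch)
+ # out.logits has shape (1, 10, 64): 8 residues plus <cls>/<eos>, over vocab_size=64.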
809
+
810
+ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
811
+ """
812
+ ESM++ model for sequence classification.
813
+ Extends the base ESM++ model with a classification head.
814
+ """
815
+ def __init__(self, config: ESMplusplusConfig, **kwargs):
816
+ super().__init__(config, **kwargs)
817
+ self.config = config
818
+ self.num_labels = config.num_labels
819
+ self.classifier = RegressionHead(config.hidden_size * 2, config.num_labels, config.hidden_size * 4)
820
+ # Large intermediate projections help with sequence classification tasks (*4)
821
+ self.mse = nn.MSELoss()
822
+ self.ce = nn.CrossEntropyLoss()
823
+ self.bce = nn.BCEWithLogitsLoss()
824
+ self.init_weights()
825
+
826
+ def forward(
827
+ self,
828
+ input_ids: Optional[torch.Tensor] = None,
829
+ attention_mask: Optional[torch.Tensor] = None,
830
+ inputs_embeds: Optional[torch.Tensor] = None,
831
+ labels: Optional[torch.Tensor] = None,
832
+ output_attentions: Optional[bool] = None,
833
+ output_hidden_states: Optional[bool] = None,
834
+ return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
835
+ ) -> ESMplusplusOutput:
836
+ """Forward pass for sequence classification.
837
+
838
+ Args:
839
+ input_ids: Input token IDs
840
+ attention_mask: Attention mask
841
+ inputs_embeds: Optional precomputed embeddings
842
+ labels: Optional labels for classification
843
+ output_hidden_states: Whether to return all hidden states
844
+ output_attentions: Whether to return attention weights
845
+
846
+ Returns:
847
+ ESMplusplusOutput containing loss, logits, and hidden states
848
+ """
849
+ output = super().forward(
850
+ input_ids=input_ids,
851
+ attention_mask=attention_mask,
852
+ inputs_embeds=inputs_embeds,
853
+ labels=None,
854
+ output_attentions=output_attentions,
855
+ output_hidden_states=output_hidden_states
856
+ )
857
+ x = output.last_hidden_state
858
+ cls_features = x[:, 0, :]
859
+ mean_features = self.mean_pooling(x, attention_mask)
860
+ # We include mean-pooled features to help with early convergence; the extra cost is negligible.
861
+ features = torch.cat([cls_features, mean_features], dim=-1)
862
+ logits = self.classifier(features)
863
+ loss = None
864
+ if labels is not None:
865
+ labels = labels.to(logits.device)
866
+ if self.config.problem_type is None:
867
+ if self.num_labels == 1:
868
+ self.config.problem_type = "regression"
869
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
870
+ self.config.problem_type = "single_label_classification"
871
+ else:
872
+ self.config.problem_type = "multi_label_classification"
873
+
874
+ if self.config.problem_type == "regression":
875
+ if self.num_labels == 1:
876
+ loss = self.mse(logits.flatten(), labels.flatten())
877
+ else:
878
+ loss = self.mse(logits, labels)
879
+ elif self.config.problem_type == "single_label_classification":
880
+ loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
881
+ elif self.config.problem_type == "multi_label_classification":
882
+ loss = self.bce(logits, labels)
883
+ return ESMplusplusOutput(
884
+ loss=loss,
885
+ logits=logits,
886
+ last_hidden_state=x,
887
+ hidden_states=output.hidden_states,
888
+ )
889
+
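+ # Editorial note: the classifier consumes torch.cat([cls_features, mean_features]),
+ # of width 2 * hidden_size, which is why the head above is constructed as
+ # RegressionHead(config.hidden_size * 2, config.num_labels, config.hidden_size * 4).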
890
+
891
+ class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM):
892
+ """
893
+ ESM++ model for token classification.
894
+ Extends the base ESM++ model with a token classification head.
895
+ """
896
+ def __init__(self, config: ESMplusplusConfig):
897
+ super().__init__(config)
898
+ self.config = config
899
+ self.num_labels = config.num_labels
900
+ self.classifier = RegressionHead(config.hidden_size, config.num_labels, config.hidden_size * 4)
901
+ # Large intermediate projections help with token classification tasks (*4)
902
+ self.loss_fct = nn.CrossEntropyLoss()
903
+ self.init_weights()
904
+
905
+ def forward(
906
+ self,
907
+ input_ids: Optional[torch.Tensor] = None,
908
+ attention_mask: Optional[torch.Tensor] = None,
909
+ inputs_embeds: Optional[torch.Tensor] = None,
910
+ labels: Optional[torch.Tensor] = None,
911
+ output_attentions: Optional[bool] = None,
912
+ output_hidden_states: Optional[bool] = None,
913
+ return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
914
+ ) -> ESMplusplusOutput:
915
+ """Forward pass for token classification.
916
+
917
+ Args:
918
+ input_ids: Input token IDs
919
+ attention_mask: Attention mask
920
+ inputs_embeds: Optional precomputed embeddings
921
+ labels: Optional labels for token classification
922
+ output_hidden_states: Whether to return all hidden states
923
+ output_attentions: Whether to return attention weights
924
+
925
+ Returns:
926
+ ESMplusplusOutput containing loss, logits, and hidden states
927
+ """
928
+ output = super().forward(
929
+ input_ids=input_ids,
930
+ attention_mask=attention_mask,
931
+ inputs_embeds=inputs_embeds,
932
+ labels=None,
933
+ output_attentions=output_attentions,
934
+ output_hidden_states=output_hidden_states
935
+ )
936
+ x = output.last_hidden_state
937
+ logits = self.classifier(x)
938
+ loss = None
939
+ if labels is not None:
940
+ loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
941
+ return ESMplusplusOutput(
942
+ loss=loss,
943
+ logits=logits,
944
+ last_hidden_state=x,
945
+ hidden_states=output.hidden_states,
946
+ )
947
+
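+ # Editorial note: token logits have shape (batch, seq_len, num_labels) and the loss
+ # flattens them to (batch * seq_len, num_labels). Labels at <cls>/<eos>/<pad>
+ # positions should be set to -100, CrossEntropyLoss's default ignore_index; this is
+ # an assumed convention, not something the file enforces.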
948
+
949
+ ### Loading from EvolutionaryScale
951
+ @cache
952
+ def data_root(model: str):
953
+ if "INFRA_PROVIDER" in os.environ:
954
+ return Path("")
955
+ # Try to download from Hugging Face if the weights don't exist locally
956
+ if model.startswith("esmc-300"):
957
+ path = Path(snapshot_download(repo_id="EvolutionaryScale/esmc-300m-2024-12"))
958
+ elif model.startswith("esmc-600"):
959
+ path = Path(snapshot_download(repo_id="EvolutionaryScale/esmc-600m-2024-12"))
960
+ else:
961
+ raise ValueError(f"{model=} is an invalid model name.")
962
+ return path
963
+
964
+
965
+ def ESMplusplus_300M(device: torch.device | str = "cpu"):
966
+ with torch.device(device):
967
+ config = ESMplusplusConfig(
968
+ hidden_size=960,
969
+ num_attention_heads=15,
970
+ num_hidden_layers=30,
971
+ )
972
+ model = ESMplusplusForMaskedLM(config)
973
+ state_dict = torch.load(
974
+ data_root("esmc-300") / "data/weights/esmc_300m_2024_12_v0.pth",
975
+ map_location=device,
976
+ )
977
+ model.load_state_dict(state_dict)
978
+ return model
979
+
980
+
981
+ def ESMplusplus_600M(device: torch.device | str = "cpu"):
982
+ with torch.device(device):
983
+ config = ESMplusplusConfig(
984
+ hidden_size=1152,
985
+ num_attention_heads=18,
986
+ num_hidden_layers=36,
987
+ )
988
+ model = ESMplusplusForMaskedLM(config)
989
+ state_dict = torch.load(
990
+ data_root("esmc-600") / "data/weights/esmc_600m_2024_12_v0.pth",
991
+ map_location=device,
992
+ )
993
+ model.load_state_dict(state_dict)
994
+ return model
995
+
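+ # Editorial usage sketch (downloads the EvolutionaryScale weights on first call):
+ # model = ESMplusplus_600M(device="cpu")
+ # gives the configuration this upload's config.json mirrors
+ # (hidden_size=1152, 18 heads, 36 layers).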
996
+
997
+ ### Tokenization
998
+ SEQUENCE_VOCAB = [
999
+ "<cls>", "<pad>", "<eos>", "<unk>",
1000
+ "L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K",
1001
+ "Q", "N", "F", "Y", "M", "H", "W", "C", "X", "B", "U", "Z",
1002
+ "O", ".", "-", "|",
1003
+ "<mask>",
1004
+ ]
1005
+
1006
+ class EsmSequenceTokenizer(PreTrainedTokenizerFast):
1007
+ model_input_names = ["input_ids", "attention_mask"]
1008
+
1009
+ def __init__(
1010
+ self,
1011
+ unk_token="<unk>",
1012
+ cls_token="<cls>",
1013
+ pad_token="<pad>",
1014
+ mask_token="<mask>",
1015
+ eos_token="<eos>",
1016
+ chain_break_token="|",
1017
+ **kwargs,
1018
+ ):
1019
+ all_tokens = SEQUENCE_VOCAB
1020
+ token_to_id = {tok: ind for ind, tok in enumerate(all_tokens)}
1021
+
1022
+ # a character-level tokenizer is the same as BPE with no token merges
1023
+ bpe = BPE(token_to_id, merges=[], unk_token=unk_token)
1024
+ tokenizer = Tokenizer(bpe)
1025
+ special_tokens = [
1026
+ cls_token,
1027
+ pad_token,
1028
+ mask_token,
1029
+ eos_token,
1030
+ chain_break_token,
1031
+ ]
1032
+ self.cb_token = chain_break_token
1033
+ additional_special_tokens = [chain_break_token]
1034
+
1035
+ tokenizer.add_special_tokens(special_tokens)
1036
+
1037
+ # This is where we configure the automatic addition of special tokens when we call
1038
+ # tokenizer(text, add_special_tokens=True). Note that you can also configure how two
1039
+ # sequences are merged if you want.
1040
+ tokenizer.post_processor = TemplateProcessing( # type: ignore
1041
+ single="<cls> $A <eos>",
1042
+ special_tokens=[
1043
+ ("<cls>", tokenizer.token_to_id("<cls>")),
1044
+ ("<eos>", tokenizer.token_to_id("<eos>")),
1045
+ ],
1046
+ )
1047
+ super().__init__(
1048
+ tokenizer_object=tokenizer,
1049
+ unk_token=unk_token,
1050
+ cls_token=cls_token,
1051
+ pad_token=pad_token,
1052
+ mask_token=mask_token,
1053
+ eos_token=eos_token,
1054
+ additional_special_tokens=additional_special_tokens,
1055
+ **kwargs,
1056
+ )
1057
+
1058
+ # These are a footgun: we never use the `bos` token anywhere, so we override it with `cls` here.
1059
+ @property
1060
+ def bos_token(self):
1061
+ return self.cls_token
1062
+
1063
+ @property
1064
+ def bos_token_id(self):
1065
+ return self.cls_token_id
1066
+
1067
+ @property
1068
+ def chain_break_token(self):
1069
+ return self.cb_token
1070
+
1071
+ @property
1072
+ def chain_break_token_id(self):
1073
+ return self.convert_tokens_to_ids(self.chain_break_token)
1074
+
1075
+ @property
1076
+ def all_token_ids(self):
1077
+ return list(range(self.vocab_size))
1078
+
1079
+ @property
1080
+ def special_token_ids(self):
1081
+ return self.all_special_ids
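
The tokenizer above is character-level: each residue maps to one id from SEQUENCE_VOCAB, and TemplateProcessing wraps sequences as <cls> ... <eos>. A minimal sketch of its behavior (ids follow the vocab order defined above):

from modeling_esm_plusplus import EsmSequenceTokenizer

tok = EsmSequenceTokenizer()
ids = tok("LAG")["input_ids"]
print(ids)  # [0, 4, 5, 6, 2] -> <cls> L A G <eos>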
optimizer.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:65da6c60c1bce7631099ae5107c966b1a6a140c7ab5e6f84f6332b574580e612
3
+ size 212600235
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:79246c8e03186bc8a2eecef837174b61b169a5f4a4f659bfa3b5a475f95bb0a0
3
+ size 2342537950
rng_state.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bb85e4337de2376f71e4e72e2c65a95701ab80e2018c4e3c0cdfcbdb0b0c6947
3
+ size 14244
scheduler.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a5c41e07b12a2ac2fb1a259324594e38d5300825b5c8641ee7574777826112ae
3
+ size 1064
special_tokens_map.json ADDED
@@ -0,0 +1,10 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "|"
4
+ ],
5
+ "cls_token": "<cls>",
6
+ "eos_token": "<eos>",
7
+ "mask_token": "<mask>",
8
+ "pad_token": "<pad>",
9
+ "unk_token": "<unk>"
10
+ }
tokenizer.json ADDED
@@ -0,0 +1,167 @@
1
+ {
2
+ "version": "1.0",
3
+ "truncation": null,
4
+ "padding": null,
5
+ "added_tokens": [
6
+ {
7
+ "id": 0,
8
+ "content": "<cls>",
9
+ "single_word": false,
10
+ "lstrip": false,
11
+ "rstrip": false,
12
+ "normalized": false,
13
+ "special": true
14
+ },
15
+ {
16
+ "id": 1,
17
+ "content": "<pad>",
18
+ "single_word": false,
19
+ "lstrip": false,
20
+ "rstrip": false,
21
+ "normalized": false,
22
+ "special": true
23
+ },
24
+ {
25
+ "id": 2,
26
+ "content": "<eos>",
27
+ "single_word": false,
28
+ "lstrip": false,
29
+ "rstrip": false,
30
+ "normalized": false,
31
+ "special": true
32
+ },
33
+ {
34
+ "id": 3,
35
+ "content": "<unk>",
36
+ "single_word": false,
37
+ "lstrip": false,
38
+ "rstrip": false,
39
+ "normalized": false,
40
+ "special": true
41
+ },
42
+ {
43
+ "id": 31,
44
+ "content": "|",
45
+ "single_word": false,
46
+ "lstrip": false,
47
+ "rstrip": false,
48
+ "normalized": false,
49
+ "special": true
50
+ },
51
+ {
52
+ "id": 32,
53
+ "content": "<mask>",
54
+ "single_word": false,
55
+ "lstrip": false,
56
+ "rstrip": false,
57
+ "normalized": false,
58
+ "special": true
59
+ }
60
+ ],
61
+ "normalizer": null,
62
+ "pre_tokenizer": null,
63
+ "post_processor": {
64
+ "type": "TemplateProcessing",
65
+ "single": [
66
+ {
67
+ "SpecialToken": {
68
+ "id": "<cls>",
69
+ "type_id": 0
70
+ }
71
+ },
72
+ {
73
+ "Sequence": {
74
+ "id": "A",
75
+ "type_id": 0
76
+ }
77
+ },
78
+ {
79
+ "SpecialToken": {
80
+ "id": "<eos>",
81
+ "type_id": 0
82
+ }
83
+ }
84
+ ],
85
+ "pair": [
86
+ {
87
+ "Sequence": {
88
+ "id": "A",
89
+ "type_id": 0
90
+ }
91
+ },
92
+ {
93
+ "Sequence": {
94
+ "id": "B",
95
+ "type_id": 1
96
+ }
97
+ }
98
+ ],
99
+ "special_tokens": {
100
+ "<cls>": {
101
+ "id": "<cls>",
102
+ "ids": [
103
+ 0
104
+ ],
105
+ "tokens": [
106
+ "<cls>"
107
+ ]
108
+ },
109
+ "<eos>": {
110
+ "id": "<eos>",
111
+ "ids": [
112
+ 2
113
+ ],
114
+ "tokens": [
115
+ "<eos>"
116
+ ]
117
+ }
118
+ }
119
+ },
120
+ "decoder": null,
121
+ "model": {
122
+ "type": "BPE",
123
+ "dropout": null,
124
+ "unk_token": "<unk>",
125
+ "continuing_subword_prefix": null,
126
+ "end_of_word_suffix": null,
127
+ "fuse_unk": false,
128
+ "byte_fallback": false,
129
+ "ignore_merges": false,
130
+ "vocab": {
131
+ "<cls>": 0,
132
+ "<pad>": 1,
133
+ "<eos>": 2,
134
+ "<unk>": 3,
135
+ "L": 4,
136
+ "A": 5,
137
+ "G": 6,
138
+ "V": 7,
139
+ "S": 8,
140
+ "E": 9,
141
+ "R": 10,
142
+ "T": 11,
143
+ "I": 12,
144
+ "D": 13,
145
+ "P": 14,
146
+ "K": 15,
147
+ "Q": 16,
148
+ "N": 17,
149
+ "F": 18,
150
+ "Y": 19,
151
+ "M": 20,
152
+ "H": 21,
153
+ "W": 22,
154
+ "C": 23,
155
+ "X": 24,
156
+ "B": 25,
157
+ "U": 26,
158
+ "Z": 27,
159
+ "O": 28,
160
+ ".": 29,
161
+ "-": 30,
162
+ "|": 31,
163
+ "<mask>": 32
164
+ },
165
+ "merges": []
166
+ }
167
+ }
tokenizer_config.json ADDED
@@ -0,0 +1,64 @@
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "<cls>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "<pad>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "<eos>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "<unk>",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "31": {
36
+ "content": "|",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ },
43
+ "32": {
44
+ "content": "<mask>",
45
+ "lstrip": false,
46
+ "normalized": false,
47
+ "rstrip": false,
48
+ "single_word": false,
49
+ "special": true
50
+ }
51
+ },
52
+ "additional_special_tokens": [
53
+ "|"
54
+ ],
55
+ "bos_token": "<cls>",
56
+ "clean_up_tokenization_spaces": false,
57
+ "cls_token": "<cls>",
58
+ "eos_token": "<eos>",
59
+ "mask_token": "<mask>",
60
+ "model_max_length": 1000000000000000019884624838656,
61
+ "pad_token": "<pad>",
62
+ "tokenizer_class": "EsmSequenceTokenizer",
63
+ "unk_token": "<unk>"
64
+ }
trainer_state.json ADDED
@@ -0,0 +1,337 @@
1
+ {
2
+ "best_metric": 0.8709677419354839,
3
+ "best_model_checkpoint": "ESMC_plus-finetuned-TP53_201AA/checkpoint-800",
4
+ "epoch": 36.36363636363637,
5
+ "eval_steps": 50,
6
+ "global_step": 800,
7
+ "is_hyper_param_search": false,
8
+ "is_local_process_zero": true,
9
+ "is_world_process_zero": true,
10
+ "log_history": [
11
+ {
12
+ "epoch": 2.2727272727272725,
13
+ "grad_norm": 3.7824575901031494,
14
+ "learning_rate": 5.6999999999999996e-05,
15
+ "loss": 0.6625,
16
+ "step": 50
17
+ },
18
+ {
19
+ "epoch": 2.2727272727272725,
20
+ "eval_f1_score": 0.7755102040816326,
21
+ "eval_loss": 0.5577281713485718,
22
+ "eval_precision": 0.6855670103092784,
23
+ "eval_recall": 0.8926174496644296,
24
+ "eval_roc_auc": 0.7809697045725397,
25
+ "eval_runtime": 0.5094,
26
+ "eval_samples_per_second": 693.021,
27
+ "eval_steps_per_second": 5.89,
28
+ "step": 50
29
+ },
30
+ {
31
+ "epoch": 4.545454545454545,
32
+ "grad_norm": 10.261860847473145,
33
+ "learning_rate": 5.4000000000000005e-05,
34
+ "loss": 0.4772,
35
+ "step": 100
36
+ },
37
+ {
38
+ "epoch": 4.545454545454545,
39
+ "eval_f1_score": 0.7908496732026143,
40
+ "eval_loss": 0.5206712484359741,
41
+ "eval_precision": 0.7707006369426752,
42
+ "eval_recall": 0.8120805369127517,
43
+ "eval_roc_auc": 0.80794078906103,
44
+ "eval_runtime": 0.5068,
45
+ "eval_samples_per_second": 696.586,
46
+ "eval_steps_per_second": 5.92,
47
+ "step": 100
48
+ },
49
+ {
50
+ "epoch": 6.818181818181818,
51
+ "grad_norm": 11.708925247192383,
52
+ "learning_rate": 5.1e-05,
53
+ "loss": 0.4047,
54
+ "step": 150
55
+ },
56
+ {
57
+ "epoch": 6.818181818181818,
58
+ "eval_f1_score": 0.7753623188405797,
59
+ "eval_loss": 0.49258285760879517,
60
+ "eval_precision": 0.84251968503937,
61
+ "eval_recall": 0.7181208053691275,
62
+ "eval_roc_auc": 0.8690334316000753,
63
+ "eval_runtime": 0.5185,
64
+ "eval_samples_per_second": 680.81,
65
+ "eval_steps_per_second": 5.786,
66
+ "step": 150
67
+ },
68
+ {
69
+ "epoch": 9.090909090909092,
70
+ "grad_norm": 14.289894104003906,
71
+ "learning_rate": 4.8e-05,
72
+ "loss": 0.3611,
73
+ "step": 200
74
+ },
75
+ {
76
+ "epoch": 9.090909090909092,
77
+ "eval_f1_score": 0.7753623188405797,
78
+ "eval_loss": 0.46306541562080383,
79
+ "eval_precision": 0.84251968503937,
80
+ "eval_recall": 0.7181208053691275,
81
+ "eval_roc_auc": 0.8886658721696042,
82
+ "eval_runtime": 0.487,
83
+ "eval_samples_per_second": 724.818,
84
+ "eval_steps_per_second": 6.16,
85
+ "step": 200
86
+ },
87
+ {
88
+ "epoch": 11.363636363636363,
89
+ "grad_norm": 2.303349494934082,
90
+ "learning_rate": 4.5e-05,
91
+ "loss": 0.3402,
92
+ "step": 250
93
+ },
94
+ {
95
+ "epoch": 11.363636363636363,
96
+ "eval_f1_score": 0.8417508417508418,
97
+ "eval_loss": 0.39569616317749023,
98
+ "eval_precision": 0.8445945945945946,
99
+ "eval_recall": 0.8389261744966443,
100
+ "eval_roc_auc": 0.8988270714420122,
101
+ "eval_runtime": 0.4881,
102
+ "eval_samples_per_second": 723.205,
103
+ "eval_steps_per_second": 6.146,
104
+ "step": 250
105
+ },
106
+ {
107
+ "epoch": 13.636363636363637,
108
+ "grad_norm": 4.852221488952637,
109
+ "learning_rate": 4.2e-05,
110
+ "loss": 0.3342,
111
+ "step": 300
112
+ },
113
+ {
114
+ "epoch": 13.636363636363637,
115
+ "eval_f1_score": 0.8617363344051447,
116
+ "eval_loss": 0.42518067359924316,
117
+ "eval_precision": 0.8271604938271605,
118
+ "eval_recall": 0.8993288590604027,
119
+ "eval_roc_auc": 0.8945618766856928,
120
+ "eval_runtime": 0.4865,
121
+ "eval_samples_per_second": 725.541,
122
+ "eval_steps_per_second": 6.166,
123
+ "step": 300
124
+ },
125
+ {
126
+ "epoch": 15.909090909090908,
127
+ "grad_norm": 5.0347747802734375,
128
+ "learning_rate": 3.9e-05,
129
+ "loss": 0.2885,
130
+ "step": 350
131
+ },
132
+ {
133
+ "epoch": 15.909090909090908,
134
+ "eval_f1_score": 0.8571428571428571,
135
+ "eval_loss": 0.40493202209472656,
136
+ "eval_precision": 0.8486842105263158,
137
+ "eval_recall": 0.8657718120805369,
138
+ "eval_roc_auc": 0.9029668192937339,
139
+ "eval_runtime": 0.4878,
140
+ "eval_samples_per_second": 723.589,
141
+ "eval_steps_per_second": 6.149,
142
+ "step": 350
143
+ },
144
+ {
145
+ "epoch": 18.181818181818183,
146
+ "grad_norm": 1.1147843599319458,
147
+ "learning_rate": 3.6e-05,
148
+ "loss": 0.2654,
149
+ "step": 400
150
+ },
151
+ {
152
+ "epoch": 18.181818181818183,
153
+ "eval_f1_score": 0.8304498269896193,
154
+ "eval_loss": 0.4450823664665222,
155
+ "eval_precision": 0.8571428571428571,
156
+ "eval_recall": 0.8053691275167785,
157
+ "eval_roc_auc": 0.9034058834598256,
158
+ "eval_runtime": 0.4878,
159
+ "eval_samples_per_second": 723.718,
160
+ "eval_steps_per_second": 6.151,
161
+ "step": 400
162
+ },
163
+ {
164
+ "epoch": 20.454545454545453,
165
+ "grad_norm": 7.004922866821289,
166
+ "learning_rate": 3.3e-05,
167
+ "loss": 0.2596,
168
+ "step": 450
169
+ },
170
+ {
171
+ "epoch": 20.454545454545453,
172
+ "eval_f1_score": 0.8395904436860068,
173
+ "eval_loss": 0.4173749089241028,
174
+ "eval_precision": 0.8541666666666666,
175
+ "eval_recall": 0.825503355704698,
176
+ "eval_roc_auc": 0.9070438436931568,
177
+ "eval_runtime": 0.4865,
178
+ "eval_samples_per_second": 725.609,
179
+ "eval_steps_per_second": 6.167,
180
+ "step": 450
181
+ },
182
+ {
183
+ "epoch": 22.727272727272727,
184
+ "grad_norm": 2.1273648738861084,
185
+ "learning_rate": 3e-05,
186
+ "loss": 0.2366,
187
+ "step": 500
188
+ },
189
+ {
190
+ "epoch": 22.727272727272727,
191
+ "eval_f1_score": 0.8142857142857143,
192
+ "eval_loss": 0.4525018334388733,
193
+ "eval_precision": 0.8702290076335878,
194
+ "eval_recall": 0.7651006711409396,
195
+ "eval_roc_auc": 0.9105563570218905,
196
+ "eval_runtime": 0.4875,
197
+ "eval_samples_per_second": 724.141,
198
+ "eval_steps_per_second": 6.154,
199
+ "step": 500
200
+ },
201
+ {
202
+ "epoch": 25.0,
203
+ "grad_norm": 4.955398082733154,
204
+ "learning_rate": 2.7000000000000002e-05,
205
+ "loss": 0.2306,
206
+ "step": 550
207
+ },
208
+ {
209
+ "epoch": 25.0,
210
+ "eval_f1_score": 0.8350877192982457,
211
+ "eval_loss": 0.4550461769104004,
212
+ "eval_precision": 0.875,
213
+ "eval_recall": 0.7986577181208053,
214
+ "eval_roc_auc": 0.9146333814213135,
215
+ "eval_runtime": 0.487,
216
+ "eval_samples_per_second": 724.913,
217
+ "eval_steps_per_second": 6.161,
218
+ "step": 550
219
+ },
220
+ {
221
+ "epoch": 27.272727272727273,
222
+ "grad_norm": 1.5066548585891724,
223
+ "learning_rate": 2.4e-05,
224
+ "loss": 0.2214,
225
+ "step": 600
226
+ },
227
+ {
228
+ "epoch": 27.272727272727273,
229
+ "eval_f1_score": 0.8542372881355932,
230
+ "eval_loss": 0.4234406054019928,
231
+ "eval_precision": 0.863013698630137,
232
+ "eval_recall": 0.8456375838926175,
233
+ "eval_roc_auc": 0.9131280185661419,
234
+ "eval_runtime": 0.487,
235
+ "eval_samples_per_second": 724.849,
236
+ "eval_steps_per_second": 6.16,
237
+ "step": 600
238
+ },
239
+ {
240
+ "epoch": 29.545454545454547,
241
+ "grad_norm": 2.199934482574463,
242
+ "learning_rate": 2.1e-05,
243
+ "loss": 0.2007,
244
+ "step": 650
245
+ },
246
+ {
247
+ "epoch": 29.545454545454547,
248
+ "eval_f1_score": 0.8406779661016949,
249
+ "eval_loss": 0.4141705632209778,
250
+ "eval_precision": 0.8493150684931506,
251
+ "eval_recall": 0.8322147651006712,
252
+ "eval_roc_auc": 0.9175186602270589,
253
+ "eval_runtime": 0.4861,
254
+ "eval_samples_per_second": 726.128,
255
+ "eval_steps_per_second": 6.171,
256
+ "step": 650
257
+ },
258
+ {
259
+ "epoch": 31.818181818181817,
260
+ "grad_norm": 3.5876011848449707,
261
+ "learning_rate": 1.8e-05,
262
+ "loss": 0.1869,
263
+ "step": 700
264
+ },
265
+ {
266
+ "epoch": 31.818181818181817,
267
+ "eval_f1_score": 0.8533333333333334,
268
+ "eval_loss": 0.42080163955688477,
269
+ "eval_precision": 0.847682119205298,
270
+ "eval_recall": 0.8590604026845637,
271
+ "eval_roc_auc": 0.9163896380856803,
272
+ "eval_runtime": 0.486,
273
+ "eval_samples_per_second": 726.361,
274
+ "eval_steps_per_second": 6.173,
275
+ "step": 700
276
+ },
277
+ {
278
+ "epoch": 34.09090909090909,
279
+ "grad_norm": 2.6010570526123047,
280
+ "learning_rate": 1.5e-05,
281
+ "loss": 0.1791,
282
+ "step": 750
283
+ },
284
+ {
285
+ "epoch": 34.09090909090909,
286
+ "eval_f1_score": 0.8655737704918033,
287
+ "eval_loss": 0.405379056930542,
288
+ "eval_precision": 0.8461538461538461,
289
+ "eval_recall": 0.8859060402684564,
290
+ "eval_roc_auc": 0.9218465784356771,
291
+ "eval_runtime": 0.4871,
292
+ "eval_samples_per_second": 724.696,
293
+ "eval_steps_per_second": 6.159,
294
+ "step": 750
295
+ },
296
+ {
297
+ "epoch": 36.36363636363637,
298
+ "grad_norm": 4.041784286499023,
299
+ "learning_rate": 1.2e-05,
300
+ "loss": 0.1666,
301
+ "step": 800
302
+ },
303
+ {
304
+ "epoch": 36.36363636363637,
305
+ "eval_f1_score": 0.8709677419354839,
306
+ "eval_loss": 0.4173583388328552,
307
+ "eval_precision": 0.8385093167701864,
308
+ "eval_recall": 0.9060402684563759,
309
+ "eval_roc_auc": 0.9189612996299317,
310
+ "eval_runtime": 0.4862,
311
+ "eval_samples_per_second": 725.969,
312
+ "eval_steps_per_second": 6.17,
313
+ "step": 800
314
+ }
315
+ ],
316
+ "logging_steps": 50,
317
+ "max_steps": 1000,
318
+ "num_input_tokens_seen": 0,
319
+ "num_train_epochs": 46,
320
+ "save_steps": 50,
321
+ "stateful_callbacks": {
322
+ "TrainerControl": {
323
+ "args": {
324
+ "should_epoch_stop": false,
325
+ "should_evaluate": false,
326
+ "should_log": false,
327
+ "should_save": true,
328
+ "should_training_stop": false
329
+ },
330
+ "attributes": {}
331
+ }
332
+ },
333
+ "total_flos": 1.9066711283712e+16,
334
+ "train_batch_size": 128,
335
+ "trial_name": null,
336
+ "trial_params": null
337
+ }
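
The log above records an evaluation every 50 steps; the best eval_f1_score (0.8710) occurs at step 800, consistent with best_metric and best_model_checkpoint. A short sketch for pulling the best evaluation entry out of this file:

import json

with open("trainer_state.json") as f:
    state = json.load(f)

# Keep only the evaluation entries (they carry eval_* keys).
evals = [e for e in state["log_history"] if "eval_f1_score" in e]
best = max(evals, key=lambda e: e["eval_f1_score"])
print(best["step"], round(best["eval_f1_score"], 4))  # 800 0.871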
training_args.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b6cb8550a209622e86a866422a75efe9cb3f0d1718a7f1cd6d94caa3ce9ec57a
3
+ size 5304