quant12 committed on
Commit 40f6c31 · 1 Parent(s): 3b3c2ff

update to plamo2

Files changed (4)
  1. config.json +47 -47
  2. modeling_plamo.py +103 -95
  3. tokenization_plamo.py +1 -1
  4. tokenizer_config.json +52 -52
config.json CHANGED
@@ -1,49 +1,49 @@
1
  {
2
- "architectures": [
3
- "PlamoForCausalLM"
4
- ],
5
- "attention_window_size": 2048,
6
- "auto_map": {
7
- "AutoConfig": "modeling_plamo.PlamoConfig",
8
- "AutoModelForCausalLM": "modeling_plamo.PlamoForCausalLM"
9
- },
10
- "bos_token_id": 1,
11
- "capacity_factor": 1.0,
12
- "eos_token_id": 2,
13
- "eval_attention_n_bit": null,
14
- "eval_mlp_n_bit": null,
15
- "expert_dropout": 0.0,
16
- "fp8_accum_dtype": "bfloat16",
17
- "group_size": 1024,
18
- "hidden_size": 2048,
19
- "hidden_size_per_head": 128,
20
- "image_feature_size": null,
21
- "image_proj_type": "linear",
22
- "image_token_id": null,
23
- "intermediate_size": 8192,
24
- "k_expert": null,
25
- "linear_type": "fp8",
26
- "mamba_chunk_size": 256,
27
- "mamba_d_conv": 4,
28
- "mamba_d_state": 64,
29
- "mamba_enabled": true,
30
- "mamba_num_heads": 32,
31
- "mamba_step": 2,
32
- "max_position_embeddings": 10485760,
33
- "model_type": "plamo2",
34
- "n_expert": null,
35
- "num_attention_heads": 16,
36
- "num_hidden_layers": 16,
37
- "num_key_value_heads": 1,
38
- "rms_norm_eps": 1e-06,
39
- "shared_intermediate_size": null,
40
- "sliding_window": 2048,
41
- "sparse_intermediate_size": null,
42
- "sparse_step": null,
43
- "tokenizer_class": "PlamoTokenizer",
44
- "torch_dtype": "float32",
45
- "transformers_version": "4.44.2",
46
- "use_cache": true,
47
- "use_predefined_initial_state": false,
48
- "vocab_size": 100000
49
  }
 
1
  {
2
+ "architectures": [
3
+ "Plamo2ForCausalLM"
4
+ ],
5
+ "attention_window_size": 2048,
6
+ "auto_map": {
7
+ "AutoConfig": "modeling_plamo.Plamo2Config",
8
+ "AutoModelForCausalLM": "modeling_plamo.Plamo2ForCausalLM"
9
+ },
10
+ "bos_token_id": 1,
11
+ "capacity_factor": 1.0,
12
+ "eos_token_id": 2,
13
+ "eval_attention_n_bit": null,
14
+ "eval_mlp_n_bit": null,
15
+ "expert_dropout": 0.0,
16
+ "fp8_accum_dtype": "bfloat16",
17
+ "group_size": 1024,
18
+ "hidden_size": 2048,
19
+ "hidden_size_per_head": 128,
20
+ "image_feature_size": null,
21
+ "image_proj_type": "linear",
22
+ "image_token_id": null,
23
+ "intermediate_size": 8192,
24
+ "k_expert": null,
25
+ "linear_type": "fp8",
26
+ "mamba_chunk_size": 256,
27
+ "mamba_d_conv": 4,
28
+ "mamba_d_state": 64,
29
+ "mamba_enabled": true,
30
+ "mamba_num_heads": 32,
31
+ "mamba_step": 2,
32
+ "max_position_embeddings": 10485760,
33
+ "model_type": "plamo2",
34
+ "n_expert": null,
35
+ "num_attention_heads": 16,
36
+ "num_hidden_layers": 16,
37
+ "num_key_value_heads": 1,
38
+ "rms_norm_eps": 1e-06,
39
+ "shared_intermediate_size": null,
40
+ "sliding_window": 2048,
41
+ "sparse_intermediate_size": null,
42
+ "sparse_step": null,
43
+ "tokenizer_class": "Plamo2Tokenizer",
44
+ "torch_dtype": "float32",
45
+ "transformers_version": "4.44.2",
46
+ "use_cache": true,
47
+ "use_predefined_initial_state": false,
48
+ "vocab_size": 100000
49
  }
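
The new auto_map in config.json routes AutoConfig and AutoModelForCausalLM to the renamed Plamo2Config and Plamo2ForCausalLM classes in modeling_plamo.py, so the checkpoint has to be loaded with remote code enabled. A minimal loading sketch; the checkpoint path below is a placeholder, not part of this commit:

from transformers import AutoConfig, AutoModelForCausalLM

# "path/to/plamo2-checkpoint" is a placeholder for wherever this repository is downloaded.
config = AutoConfig.from_pretrained("path/to/plamo2-checkpoint", trust_remote_code=True)
print(config.model_type)     # "plamo2"
print(config.architectures)  # ["Plamo2ForCausalLM"]

# trust_remote_code=True is required so that modeling_plamo.py from the repository is imported.
model = AutoModelForCausalLM.from_pretrained("path/to/plamo2-checkpoint", trust_remote_code=True)
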
modeling_plamo.py CHANGED
@@ -105,8 +105,8 @@ class LinearType(str, enum.Enum):
105
  Fp8Retain = "fp8-retain"
106
 
107
 
108
- class PlamoConfig(PretrainedConfig): # type: ignore
109
- model_type: str = "plamo"
110
 
111
  def __init__(
112
  self,
@@ -121,6 +121,8 @@ class PlamoConfig(PretrainedConfig): # type: ignore
121
  max_position_embeddings: int = 2048,
122
  attention_window_size: int = 2048,
123
  full_attention_idx: list[int] | None = None,
 
 
124
  # Mamba
125
  mamba_d_state: int = 64,
126
  mamba_d_conv: int = 4,
@@ -132,7 +134,7 @@ class PlamoConfig(PretrainedConfig): # type: ignore
132
  intermediate_size: int = 13312,
133
  # Tokenizer
134
  vocab_size: int = 32000,
135
- tokenizer_class: str = "PlamoTokenizer",
136
  pad_token_id: Optional[int] = None,
137
  bos_token_id: int = 1,
138
  eos_token_id: int = 2,
@@ -161,6 +163,8 @@ class PlamoConfig(PretrainedConfig): # type: ignore
161
  self.num_key_value_heads = num_key_value_heads
162
  self.attention_window_size = attention_window_size
163
  self.full_attention_idx = full_attention_idx if full_attention_idx is not None else []
 
 
164
 
165
  self.mamba_d_state = mamba_d_state
166
  self.mamba_d_conv = mamba_d_conv
@@ -196,8 +200,16 @@ class PlamoConfig(PretrainedConfig): # type: ignore
196
  **kwargs,
197
  )
198
 
 
 
 
 
 
 
 
199
 
200
- class PlamoAttentionCache(torch.nn.Module):
 
201
  def __init__(self, key: torch.Tensor, value: torch.Tensor) -> None:
202
  super().__init__()
203
  B, nh, L, c = key.shape
@@ -208,7 +220,7 @@ class PlamoAttentionCache(torch.nn.Module):
208
  self.register_parameter("value", torch.nn.Parameter(value, requires_grad=False))
209
 
210
 
211
- class PlamoMambaCache(torch.nn.Module):
212
  def __init__(self, conv_state: torch.Tensor, ssm_state: torch.Tensor) -> None:
213
  super().__init__()
214
  # conv_state: [B, C, d_conv]
@@ -220,10 +232,10 @@ class PlamoMambaCache(torch.nn.Module):
220
  self.register_parameter("ssm_state", torch.nn.Parameter(ssm_state, requires_grad=False))
221
 
222
 
223
- PlamoLayerCache = PlamoAttentionCache | PlamoMambaCache
224
 
225
 
226
- class PlamoCache(torch.nn.Module):
227
  """
228
  stores states of the model for fast decoding.
229
  `transformers` uses `transformers.Cache` for this purpose, but the interface and variable names are
@@ -233,7 +245,7 @@ class PlamoCache(torch.nn.Module):
233
  the state of Mamba properly.
234
  """
235
 
236
- def __init__(self, config: PlamoConfig) -> None:
237
  super().__init__()
238
  self.config = config
239
  self.cache = torch.nn.ModuleList([None for _ in range(config.num_hidden_layers)]) # type: ignore
@@ -242,7 +254,7 @@ class PlamoCache(torch.nn.Module):
242
  c = self.cache[layer_idx]
243
  if c is None:
244
  return key, value
245
- assert isinstance(c, PlamoAttentionCache)
246
 
247
  def _validate(cache: torch.Tensor, new_tensor: torch.Tensor) -> None:
248
  assert len(cache.shape) == 4
@@ -258,20 +270,20 @@ class PlamoCache(torch.nn.Module):
258
 
259
  def update_attention(
260
  self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int
261
- ) -> PlamoAttentionCache:
262
  full_attn = layer_idx in self.config.full_attention_idx
263
  window_size = self.config.attention_window_size
264
 
265
  if self.cache[layer_idx] is None:
266
  if full_attn:
267
- self.cache[layer_idx] = PlamoAttentionCache(key_states, value_states)
268
  else:
269
- self.cache[layer_idx] = PlamoAttentionCache(
270
  key_states[:, :, -window_size:, :], value_states[:, :, -window_size:, :]
271
  )
272
  else:
273
  c = self.cache[layer_idx]
274
- assert isinstance(c, PlamoAttentionCache)
275
  k, v = self.append_kv(key_states, value_states, layer_idx)
276
  if full_attn:
277
  c.key.data = k
@@ -281,19 +293,19 @@ class PlamoCache(torch.nn.Module):
281
  c.value.data = v[:, :, -window_size:, :]
282
  return self.cache[layer_idx] # type: ignore
283
 
284
- def update_mamba(self, conv_state: torch.Tensor, ssm_state: torch.Tensor, layer_idx: int) -> PlamoMambaCache:
285
  if self.cache[layer_idx] is None:
286
- self.cache[layer_idx] = PlamoMambaCache(conv_state, ssm_state)
287
  else:
288
  c = self.cache[layer_idx]
289
- assert isinstance(c, PlamoMambaCache)
290
  assert c.conv_state.shape == conv_state.shape
291
  assert c.ssm_state.shape == ssm_state.shape
292
  c.conv_state.data = conv_state
293
  c.ssm_state.data = ssm_state
294
  return self.cache[layer_idx] # type: ignore
295
 
296
- def __getitem__(self, layer_idx: int) -> PlamoLayerCache | None:
297
  assert layer_idx < len(self.cache)
298
  layer_cache = self.cache[layer_idx]
299
  return layer_cache # type: ignore
@@ -304,12 +316,12 @@ class PlamoCache(torch.nn.Module):
304
  def get_seq_length(self, layer_idx: Optional[int] = None) -> int:
305
  if layer_idx is not None:
306
  c = self.cache[layer_idx]
307
- assert isinstance(c, PlamoAttentionCache)
308
  return c.key.shape[2] # type: ignore
309
 
310
  sequence_length: int | None = None
311
  for layer_cache in self.cache:
312
- if isinstance(layer_cache, PlamoAttentionCache):
313
  sequence_length = (
314
  max(layer_cache.key.shape[2], sequence_length)
315
  if sequence_length is not None
@@ -333,14 +345,14 @@ class PlamoCache(torch.nn.Module):
333
  return previous_seq_length
334
 
335
  def reorder_cache(self, beam_idx: torch.Tensor) -> None:
336
- def _mamba(cache: PlamoMambaCache) -> PlamoMambaCache:
337
- return PlamoMambaCache(
338
  conv_state=cache.conv_state.index_select(0, beam_idx),
339
  ssm_state=cache.ssm_state.index_select(0, beam_idx),
340
  )
341
 
342
- def _attention(cache: PlamoAttentionCache) -> PlamoAttentionCache:
343
- return PlamoAttentionCache(
344
  key=cache.key.index_select(0, beam_idx),
345
  value=cache.value.index_select(0, beam_idx),
346
  )
@@ -349,10 +361,10 @@ class PlamoCache(torch.nn.Module):
349
  if self.cache[i] is None:
350
  continue
351
  layer_cache = self.cache[i]
352
- if isinstance(layer_cache, PlamoMambaCache):
353
  self.cache[i] = _mamba(layer_cache)
354
  else:
355
- assert isinstance(layer_cache, PlamoAttentionCache)
356
  self.cache[i] = _attention(layer_cache)
357
 
358
  @property
@@ -363,7 +375,7 @@ class PlamoCache(torch.nn.Module):
363
  class DecoderInput(NamedTuple):
364
  hidden_states: torch.Tensor
365
  attention_mask: Optional[torch.Tensor] = None
366
- past_states: Optional[PlamoCache] = None
367
  output_hidden_states: Optional[bool] = False
368
  output_attentions: Optional[bool] = False
369
  gradient_checkpointing: bool = False
@@ -810,7 +822,7 @@ def _causal_conv1d(
810
 
811
 
812
  class Mamba(torch.nn.Module):
813
- def __init__(self, config: PlamoConfig, layer_idx: int) -> None:
814
  super().__init__()
815
  self.config = config
816
  self.layer_idx = layer_idx
@@ -862,8 +874,8 @@ class Mamba(torch.nn.Module):
862
  self,
863
  hidden_states: torch.Tensor,
864
  attention_mask: Optional[torch.Tensor] = None,
865
- past_states: Optional[PlamoCache] = None,
866
- ) -> Tuple[torch.Tensor, Optional[PlamoCache]]:
867
  bsize, length, _ = hidden_states.shape
868
  is_update = length == 1 and past_states is not None
869
 
@@ -905,7 +917,7 @@ class Mamba(torch.nn.Module):
905
  )
906
  else:
907
  c = past_states[self.layer_idx]
908
- assert isinstance(c, PlamoMambaCache)
909
  conv_state = c.conv_state
910
  ssm_state = c.ssm_state
911
 
@@ -1022,7 +1034,7 @@ def swa_mask(q_len: int, kv_len: int, device: torch.device, window_size: int) ->
1022
 
1023
 
1024
  class Attention(torch.nn.Module):
1025
- def __init__(self, config: PlamoConfig, layer_idx: int) -> None:
1026
  super().__init__()
1027
  self.config = config
1028
  self.layer_idx = layer_idx
@@ -1045,15 +1057,19 @@ class Attention(torch.nn.Module):
1045
  self.q_weight = torch.nn.Parameter(torch.ones((self.q_num_heads, self.qk_dim)))
1046
  self.k_weight = torch.nn.Parameter(torch.ones((self.k_num_heads, self.qk_dim)))
1047
 
1048
- self.rotary_emb = RotaryEmbedding(self.qk_dim, max_position_embeddings=self.config.attention_window_size)
 
 
 
 
1049
 
1050
  def forward(
1051
  self,
1052
  hidden_states: torch.Tensor,
1053
  attention_mask: Optional[torch.Tensor] = None,
1054
- past_states: Optional[PlamoCache] = None,
1055
  output_attentions: bool = False,
1056
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[PlamoCache]]:
1057
  bsz, q_len, _ = hidden_states.size()
1058
 
1059
  qkv = self.qkv_proj(hidden_states)
@@ -1094,15 +1110,13 @@ class Attention(torch.nn.Module):
1094
  key_states = _expand_kv(key_states, self.n_group, self.q_num_heads)
1095
  value_states = _expand_kv(value_states, self.n_group, self.q_num_heads)
1096
 
1097
- full_attn = self.layer_idx in self.config.full_attention_idx
1098
-
1099
  query_states = query_states.to(attn_dtype)
1100
  key_states = key_states.to(attn_dtype)
1101
  value_states = value_states.to(attn_dtype)
1102
  if attention_mask is not None and attention_mask.dtype != torch.bool:
1103
  attention_mask = attention_mask.to(attn_dtype)
1104
  if attention_mask is None:
1105
- if not full_attn:
1106
  assert key_states.shape[2] <= self.config.attention_window_size + 1
1107
  attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=True)
1108
  else:
@@ -1112,7 +1126,7 @@ class Attention(torch.nn.Module):
1112
  attention_mask = attention_mask[None, None]
1113
  assert len(attention_mask.shape) == 4
1114
 
1115
- if not full_attn:
1116
  m_swa = swa_mask(
1117
  query_states.shape[2], key_states.shape[2], query_states.device, self.config.attention_window_size
1118
  )
@@ -1142,7 +1156,7 @@ class Attention(torch.nn.Module):
1142
 
1143
 
1144
  class MLP(nn.Module):
1145
- def __init__(self, config: PlamoConfig) -> None:
1146
  super().__init__()
1147
  self.config = config
1148
  self.hidden_size = config.hidden_size
@@ -1156,14 +1170,14 @@ class MLP(nn.Module):
1156
  return self.down_proj(h) # type: ignore
1157
 
1158
 
1159
- class PlamoDecoderLayer(torch.nn.Module):
1160
- def __init__(self, config: PlamoConfig, is_mamba: bool, layer_idx: int) -> None:
1161
  super().__init__()
1162
  self.config = config
1163
  self.hidden_size = config.hidden_size
1164
- self.is_mamba = is_mamba
1165
  self.mixer: torch.nn.Module
1166
- if is_mamba:
1167
  self.mixer = Mamba(config, layer_idx)
1168
  else:
1169
  self.mixer = Attention(config, layer_idx)
@@ -1180,7 +1194,7 @@ class PlamoDecoderLayer(torch.nn.Module):
1180
  self,
1181
  hidden_states: torch.Tensor,
1182
  attention_mask: Optional[torch.Tensor] = None,
1183
- past_state: Optional[PlamoCache] = None,
1184
  output_attentions: Optional[bool] = False,
1185
  ) -> Tuple[Any, ...]:
1186
  # from LlamaDecoder
@@ -1224,7 +1238,7 @@ class PlamoDecoderLayer(torch.nn.Module):
1224
  return outputs # type: ignore
1225
 
1226
 
1227
- def is_mamba(config: PlamoConfig, i: int) -> bool:
1228
  if not config.mamba_enabled:
1229
  return False
1230
  assert config.mamba_step > 1
@@ -1236,15 +1250,12 @@ def is_mamba(config: PlamoConfig, i: int) -> bool:
1236
  return (i % config.mamba_step) != (config.mamba_step // 2)
1237
 
1238
 
1239
- class PlamoDecoder(torch.nn.Module):
1240
- def __init__(self, config: PlamoConfig) -> None:
1241
  super().__init__()
1242
 
1243
  self.layers = torch.nn.ModuleList(
1244
- [
1245
- PlamoDecoderLayer(config, is_mamba=is_mamba(config, i), layer_idx=i)
1246
- for i in range(config.num_hidden_layers)
1247
- ]
1248
  )
1249
  self.gradient_checkpointing = False
1250
 
@@ -1283,8 +1294,8 @@ class PlamoDecoder(torch.nn.Module):
1283
  return DecoderOutput(hidden_states, all_hidden_states, all_self_attns)
1284
 
1285
 
1286
- class PlamoPreTrainedModel(PreTrainedModel): # type: ignore
1287
- config_class = PlamoConfig
1288
  _no_split_modules: List[str]
1289
  base_model_prefix = "model"
1290
  supports_gradient_checkpointing = True
@@ -1304,8 +1315,8 @@ class PlamoPreTrainedModel(PreTrainedModel): # type: ignore
1304
  module.weight.data[module.padding_idx].zero_()
1305
 
1306
 
1307
- class PlamoModel(PlamoPreTrainedModel):
1308
- def __init__(self, config: PlamoConfig):
1309
  super().__init__(config)
1310
  assert config.eval_attention_n_bit is None
1311
  assert config.eval_mlp_n_bit is None
@@ -1321,7 +1332,7 @@ class PlamoModel(PlamoPreTrainedModel):
1321
  self.image_proj = nn.Linear(config.image_feature_size, config.hidden_size, bias=False) # type: ignore
1322
  else:
1323
  raise ValueError(f"Unknown image_proj_type: {config.image_proj_type}")
1324
- self.layers = PlamoDecoder(config) # type: ignore
1325
  self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1326
 
1327
  self.gradient_checkpointing = False
@@ -1376,15 +1387,16 @@ class PlamoModel(PlamoPreTrainedModel):
1376
  input_ids: Optional[torch.LongTensor] = None,
1377
  attention_mask: Optional[torch.Tensor] = None,
1378
  position_ids: Optional[torch.Tensor] = None,
1379
- past_key_values: Optional[PlamoCache] = None,
1380
  inputs_embeds: Optional[torch.Tensor] = None,
1381
  image_features: Optional[torch.Tensor] = None,
1382
  use_cache: Optional[bool] = None,
1383
  output_attentions: Optional[bool] = None,
1384
  output_hidden_states: Optional[bool] = None,
1385
  return_dict: Optional[bool] = None,
 
 
1386
  ) -> Union[Tuple, BaseModelOutputWithPast]:
1387
- assert input_ids is not None
1388
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1389
  output_hidden_states = (
1390
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1394,22 +1406,22 @@ class PlamoModel(PlamoPreTrainedModel):
1394
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1395
 
1396
  # retrieve input_ids and inputs_embeds
1397
- if input_ids is not None and inputs_embeds is not None:
1398
- raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
1399
- elif input_ids is not None:
1400
- batch_size, seq_length = input_ids.shape
1401
- else:
1402
- raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
 
 
 
1403
 
1404
  seq_length_with_past = seq_length
1405
  past_key_values_length = 0
1406
-
1407
  if past_key_values is not None:
1408
  past_key_values_length = past_key_values.get_seq_length()
1409
  seq_length_with_past = seq_length_with_past + past_key_values_length
1410
-
1411
- if inputs_embeds is None:
1412
- inputs_embeds = self.embed_tokens(input_ids)
1413
 
1414
  if image_features is not None:
1415
  assert self.config.image_token_id is not None
@@ -1435,12 +1447,8 @@ class PlamoModel(PlamoPreTrainedModel):
1435
 
1436
  hidden_states = inputs_embeds
1437
 
1438
- if self.gradient_checkpointing and self.training:
1439
- if use_cache:
1440
- use_cache = False
1441
-
1442
  if use_cache and past_key_values is None:
1443
- past_key_values = PlamoCache(self.config)
1444
 
1445
  # decoder layers
1446
  out = self.layers(
@@ -1477,7 +1485,7 @@ class PlamoModel(PlamoPreTrainedModel):
1477
  )
1478
 
1479
 
1480
- class PlamoForCausalLM(PlamoPreTrainedModel):
1481
  _tied_weights_keys = ["lm_head.weight"]
1482
 
1483
  # Without this, the model cannot be loaded into a meta device.
@@ -1487,9 +1495,9 @@ class PlamoForCausalLM(PlamoPreTrainedModel):
1487
  # https://github.com/pytorch/pytorch/blob/v2.4.1/torch/nn/modules/module.py#L2068
1488
  _supports_param_buffer_assignment = False
1489
 
1490
- def __init__(self, config: PlamoConfig) -> None:
1491
  super().__init__(config)
1492
- self.model = PlamoModel(config)
1493
 
1494
  self.vocab_size = config.vocab_size
1495
  vocab_size = ((self.vocab_size + 15) // 16) * 16
@@ -1510,10 +1518,10 @@ class PlamoForCausalLM(PlamoPreTrainedModel):
1510
  def set_output_embeddings(self, new_embeddings: torch.nn.Module) -> None:
1511
  self.lm_head = new_embeddings
1512
 
1513
- def set_decoder(self, decoder: PlamoModel) -> None:
1514
  self.model = decoder
1515
 
1516
- def get_decoder(self) -> PlamoModel:
1517
  return self.model
1518
 
1519
  def forward( # type: ignore
@@ -1521,7 +1529,7 @@ class PlamoForCausalLM(PlamoPreTrainedModel):
1521
  input_ids: Optional[torch.LongTensor] = None,
1522
  attention_mask: Optional[torch.Tensor] = None,
1523
  position_ids: Optional[torch.Tensor] = None,
1524
- past_key_values: Optional[PlamoCache] = None,
1525
  inputs_embeds: Optional[torch.FloatTensor] = None,
1526
  image_features: Optional[torch.Tensor] = None,
1527
  labels: Optional[torch.LongTensor] = None,
@@ -1529,6 +1537,9 @@ class PlamoForCausalLM(PlamoPreTrainedModel):
1529
  output_attentions: Optional[bool] = None,
1530
  output_hidden_states: Optional[bool] = None,
1531
  return_dict: Optional[bool] = None,
 
 
 
1532
  ) -> Union[Tuple, CausalLMOutputWithPast]:
1533
  r"""
1534
  Args:
@@ -1555,8 +1566,6 @@ class PlamoForCausalLM(PlamoPreTrainedModel):
1555
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1556
  "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
1557
  ```"""
1558
- assert input_ids is not None
1559
-
1560
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1561
  output_hidden_states = (
1562
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1575,24 +1584,23 @@ class PlamoForCausalLM(PlamoPreTrainedModel):
1575
  output_attentions=output_attentions,
1576
  output_hidden_states=output_hidden_states,
1577
  return_dict=return_dict,
 
 
1578
  )
1579
 
1580
  hidden_states = outputs[0]
1581
  logits = self.lm_head(hidden_states)
1582
- logits = logits[..., : self.vocab_size]
 
1583
 
1584
  loss = None
1585
  if labels is not None:
1586
- # Shift so that tokens < n predict n
1587
- shift_logits = logits[..., :-1, :].contiguous()
1588
- shift_labels = labels[..., 1:].contiguous()
1589
- # Flatten the tokens
1590
- loss_fct = nn.CrossEntropyLoss()
1591
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
1592
- shift_labels = shift_labels.view(-1)
1593
- # Enable model parallelism
1594
- shift_labels = shift_labels.to(shift_logits.device)
1595
- loss = loss_fct(shift_logits, shift_labels)
1596
 
1597
  if not return_dict:
1598
  output = (logits,) + outputs[1:]
@@ -1609,7 +1617,7 @@ class PlamoForCausalLM(PlamoPreTrainedModel):
1609
  def prepare_inputs_for_generation(
1610
  self,
1611
  input_ids: torch.Tensor,
1612
- past_key_values: Optional[PlamoCache] = None,
1613
  attention_mask: Optional[torch.Tensor] = None,
1614
  inputs_embeds: Optional[torch.Tensor] = None,
1615
  image_features: Optional[torch.Tensor] = None,
@@ -1646,13 +1654,13 @@ class PlamoForCausalLM(PlamoPreTrainedModel):
1646
  return model_inputs
1647
 
1648
  @staticmethod
1649
- def _reorder_cache(past_key_values: PlamoCache, beam_idx: torch.Tensor) -> PlamoCache:
1650
  past_key_values.reorder_cache(beam_idx)
1651
  return past_key_values
1652
 
1653
 
1654
  class MLPImageProjector(nn.Module):
1655
- def __init__(self, config: PlamoConfig) -> None:
1656
  super().__init__()
1657
  self.config = config
1658
 
 
105
  Fp8Retain = "fp8-retain"
106
 
107
 
108
+ class Plamo2Config(PretrainedConfig): # type: ignore
109
+ model_type: str = "plamo2"
110
 
111
  def __init__(
112
  self,
 
121
  max_position_embeddings: int = 2048,
122
  attention_window_size: int = 2048,
123
  full_attention_idx: list[int] | None = None,
124
+ rope_theta: int = 10000,
125
+ rope_local_theta: int = 10000,
126
  # Mamba
127
  mamba_d_state: int = 64,
128
  mamba_d_conv: int = 4,
 
134
  intermediate_size: int = 13312,
135
  # Tokenizer
136
  vocab_size: int = 32000,
137
+ tokenizer_class: str = "Plamo2Tokenizer",
138
  pad_token_id: Optional[int] = None,
139
  bos_token_id: int = 1,
140
  eos_token_id: int = 2,
 
163
  self.num_key_value_heads = num_key_value_heads
164
  self.attention_window_size = attention_window_size
165
  self.full_attention_idx = full_attention_idx if full_attention_idx is not None else []
166
+ self.rope_theta = rope_theta
167
+ self.rope_local_theta = rope_local_theta
168
 
169
  self.mamba_d_state = mamba_d_state
170
  self.mamba_d_conv = mamba_d_conv
 
200
  **kwargs,
201
  )
202
 
203
+ @property
204
+ def layers_block_type(self) -> list[str]:
205
+ return ["mamba" if is_mamba(self, i) else "attention" for i in range(self.num_hidden_layers)]
206
+
207
+ @property
208
+ def rope_local_base_freq(self) -> int:
209
+ return self.rope_local_theta
210
 
211
+
212
+ class Plamo2AttentionCache(torch.nn.Module):
213
  def __init__(self, key: torch.Tensor, value: torch.Tensor) -> None:
214
  super().__init__()
215
  B, nh, L, c = key.shape
 
220
  self.register_parameter("value", torch.nn.Parameter(value, requires_grad=False))
221
 
222
 
223
+ class Plamo2MambaCache(torch.nn.Module):
224
  def __init__(self, conv_state: torch.Tensor, ssm_state: torch.Tensor) -> None:
225
  super().__init__()
226
  # conv_state: [B, C, d_conv]
 
232
  self.register_parameter("ssm_state", torch.nn.Parameter(ssm_state, requires_grad=False))
233
 
234
 
235
+ Plamo2LayerCache = Plamo2AttentionCache | Plamo2MambaCache
236
 
237
 
238
+ class Plamo2Cache(torch.nn.Module):
239
  """
240
  stores states of the model for fast decoding.
241
  `transformers` uses `transformers.Cache` for this purpose, but the interface and variable names are
 
245
  the state of Mamba properly.
246
  """
247
 
248
+ def __init__(self, config: Plamo2Config) -> None:
249
  super().__init__()
250
  self.config = config
251
  self.cache = torch.nn.ModuleList([None for _ in range(config.num_hidden_layers)]) # type: ignore
 
254
  c = self.cache[layer_idx]
255
  if c is None:
256
  return key, value
257
+ assert isinstance(c, Plamo2AttentionCache)
258
 
259
  def _validate(cache: torch.Tensor, new_tensor: torch.Tensor) -> None:
260
  assert len(cache.shape) == 4
 
270
 
271
  def update_attention(
272
  self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int
273
+ ) -> Plamo2AttentionCache:
274
  full_attn = layer_idx in self.config.full_attention_idx
275
  window_size = self.config.attention_window_size
276
 
277
  if self.cache[layer_idx] is None:
278
  if full_attn:
279
+ self.cache[layer_idx] = Plamo2AttentionCache(key_states, value_states)
280
  else:
281
+ self.cache[layer_idx] = Plamo2AttentionCache(
282
  key_states[:, :, -window_size:, :], value_states[:, :, -window_size:, :]
283
  )
284
  else:
285
  c = self.cache[layer_idx]
286
+ assert isinstance(c, Plamo2AttentionCache)
287
  k, v = self.append_kv(key_states, value_states, layer_idx)
288
  if full_attn:
289
  c.key.data = k
 
293
  c.value.data = v[:, :, -window_size:, :]
294
  return self.cache[layer_idx] # type: ignore
295
 
296
+ def update_mamba(self, conv_state: torch.Tensor, ssm_state: torch.Tensor, layer_idx: int) -> Plamo2MambaCache:
297
  if self.cache[layer_idx] is None:
298
+ self.cache[layer_idx] = Plamo2MambaCache(conv_state, ssm_state)
299
  else:
300
  c = self.cache[layer_idx]
301
+ assert isinstance(c, Plamo2MambaCache)
302
  assert c.conv_state.shape == conv_state.shape
303
  assert c.ssm_state.shape == ssm_state.shape
304
  c.conv_state.data = conv_state
305
  c.ssm_state.data = ssm_state
306
  return self.cache[layer_idx] # type: ignore
307
 
308
+ def __getitem__(self, layer_idx: int) -> Plamo2LayerCache | None:
309
  assert layer_idx < len(self.cache)
310
  layer_cache = self.cache[layer_idx]
311
  return layer_cache # type: ignore
 
316
  def get_seq_length(self, layer_idx: Optional[int] = None) -> int:
317
  if layer_idx is not None:
318
  c = self.cache[layer_idx]
319
+ assert isinstance(c, Plamo2AttentionCache)
320
  return c.key.shape[2] # type: ignore
321
 
322
  sequence_length: int | None = None
323
  for layer_cache in self.cache:
324
+ if isinstance(layer_cache, Plamo2AttentionCache):
325
  sequence_length = (
326
  max(layer_cache.key.shape[2], sequence_length)
327
  if sequence_length is not None
 
345
  return previous_seq_length
346
 
347
  def reorder_cache(self, beam_idx: torch.Tensor) -> None:
348
+ def _mamba(cache: Plamo2MambaCache) -> Plamo2MambaCache:
349
+ return Plamo2MambaCache(
350
  conv_state=cache.conv_state.index_select(0, beam_idx),
351
  ssm_state=cache.ssm_state.index_select(0, beam_idx),
352
  )
353
 
354
+ def _attention(cache: Plamo2AttentionCache) -> Plamo2AttentionCache:
355
+ return Plamo2AttentionCache(
356
  key=cache.key.index_select(0, beam_idx),
357
  value=cache.value.index_select(0, beam_idx),
358
  )
 
361
  if self.cache[i] is None:
362
  continue
363
  layer_cache = self.cache[i]
364
+ if isinstance(layer_cache, Plamo2MambaCache):
365
  self.cache[i] = _mamba(layer_cache)
366
  else:
367
+ assert isinstance(layer_cache, Plamo2AttentionCache)
368
  self.cache[i] = _attention(layer_cache)
369
 
370
  @property
 
375
  class DecoderInput(NamedTuple):
376
  hidden_states: torch.Tensor
377
  attention_mask: Optional[torch.Tensor] = None
378
+ past_states: Optional[Plamo2Cache] = None
379
  output_hidden_states: Optional[bool] = False
380
  output_attentions: Optional[bool] = False
381
  gradient_checkpointing: bool = False
 
822
 
823
 
824
  class Mamba(torch.nn.Module):
825
+ def __init__(self, config: Plamo2Config, layer_idx: int) -> None:
826
  super().__init__()
827
  self.config = config
828
  self.layer_idx = layer_idx
 
874
  self,
875
  hidden_states: torch.Tensor,
876
  attention_mask: Optional[torch.Tensor] = None,
877
+ past_states: Optional[Plamo2Cache] = None,
878
+ ) -> Tuple[torch.Tensor, Optional[Plamo2Cache]]:
879
  bsize, length, _ = hidden_states.shape
880
  is_update = length == 1 and past_states is not None
881
 
 
917
  )
918
  else:
919
  c = past_states[self.layer_idx]
920
+ assert isinstance(c, Plamo2MambaCache)
921
  conv_state = c.conv_state
922
  ssm_state = c.ssm_state
923
 
 
1034
 
1035
 
1036
  class Attention(torch.nn.Module):
1037
+ def __init__(self, config: Plamo2Config, layer_idx: int) -> None:
1038
  super().__init__()
1039
  self.config = config
1040
  self.layer_idx = layer_idx
 
1057
  self.q_weight = torch.nn.Parameter(torch.ones((self.q_num_heads, self.qk_dim)))
1058
  self.k_weight = torch.nn.Parameter(torch.ones((self.k_num_heads, self.qk_dim)))
1059
 
1060
+ self.full_attn = self.layer_idx in self.config.full_attention_idx
1061
+ base = self.config.rope_theta if self.full_attn else self.config.rope_local_theta
1062
+ self.rotary_emb = RotaryEmbedding(
1063
+ self.qk_dim, max_position_embeddings=self.config.attention_window_size, base=base
1064
+ )
1065
 
1066
  def forward(
1067
  self,
1068
  hidden_states: torch.Tensor,
1069
  attention_mask: Optional[torch.Tensor] = None,
1070
+ past_states: Optional[Plamo2Cache] = None,
1071
  output_attentions: bool = False,
1072
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Plamo2Cache]]:
1073
  bsz, q_len, _ = hidden_states.size()
1074
 
1075
  qkv = self.qkv_proj(hidden_states)
 
1110
  key_states = _expand_kv(key_states, self.n_group, self.q_num_heads)
1111
  value_states = _expand_kv(value_states, self.n_group, self.q_num_heads)
1112
 
 
 
1113
  query_states = query_states.to(attn_dtype)
1114
  key_states = key_states.to(attn_dtype)
1115
  value_states = value_states.to(attn_dtype)
1116
  if attention_mask is not None and attention_mask.dtype != torch.bool:
1117
  attention_mask = attention_mask.to(attn_dtype)
1118
  if attention_mask is None:
1119
+ if not self.full_attn:
1120
  assert key_states.shape[2] <= self.config.attention_window_size + 1
1121
  attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=True)
1122
  else:
 
1126
  attention_mask = attention_mask[None, None]
1127
  assert len(attention_mask.shape) == 4
1128
 
1129
+ if not self.full_attn:
1130
  m_swa = swa_mask(
1131
  query_states.shape[2], key_states.shape[2], query_states.device, self.config.attention_window_size
1132
  )
 
1156
 
1157
 
1158
  class MLP(nn.Module):
1159
+ def __init__(self, config: Plamo2Config) -> None:
1160
  super().__init__()
1161
  self.config = config
1162
  self.hidden_size = config.hidden_size
 
1170
  return self.down_proj(h) # type: ignore
1171
 
1172
 
1173
+ class Plamo2DecoderLayer(torch.nn.Module):
1174
+ def __init__(self, config: Plamo2Config, layer_idx: int) -> None:
1175
  super().__init__()
1176
  self.config = config
1177
  self.hidden_size = config.hidden_size
1178
+ self.is_mamba = config.layers_block_type[layer_idx] == "mamba"
1179
  self.mixer: torch.nn.Module
1180
+ if self.is_mamba:
1181
  self.mixer = Mamba(config, layer_idx)
1182
  else:
1183
  self.mixer = Attention(config, layer_idx)
 
1194
  self,
1195
  hidden_states: torch.Tensor,
1196
  attention_mask: Optional[torch.Tensor] = None,
1197
+ past_state: Optional[Plamo2Cache] = None,
1198
  output_attentions: Optional[bool] = False,
1199
  ) -> Tuple[Any, ...]:
1200
  # from LlamaDecoder
 
1238
  return outputs # type: ignore
1239
 
1240
 
1241
+ def is_mamba(config: Plamo2Config, i: int) -> bool:
1242
  if not config.mamba_enabled:
1243
  return False
1244
  assert config.mamba_step > 1
 
1250
  return (i % config.mamba_step) != (config.mamba_step // 2)
1251
 
1252
 
1253
+ class Plamo2Decoder(torch.nn.Module):
1254
+ def __init__(self, config: Plamo2Config) -> None:
1255
  super().__init__()
1256
 
1257
  self.layers = torch.nn.ModuleList(
1258
+ [Plamo2DecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]
 
 
 
1259
  )
1260
  self.gradient_checkpointing = False
1261
 
 
1294
  return DecoderOutput(hidden_states, all_hidden_states, all_self_attns)
1295
 
1296
 
1297
+ class Plamo2PreTrainedModel(PreTrainedModel): # type: ignore
1298
+ config_class = Plamo2Config
1299
  _no_split_modules: List[str]
1300
  base_model_prefix = "model"
1301
  supports_gradient_checkpointing = True
 
1315
  module.weight.data[module.padding_idx].zero_()
1316
 
1317
 
1318
+ class Plamo2Model(Plamo2PreTrainedModel):
1319
+ def __init__(self, config: Plamo2Config):
1320
  super().__init__(config)
1321
  assert config.eval_attention_n_bit is None
1322
  assert config.eval_mlp_n_bit is None
 
1332
  self.image_proj = nn.Linear(config.image_feature_size, config.hidden_size, bias=False) # type: ignore
1333
  else:
1334
  raise ValueError(f"Unknown image_proj_type: {config.image_proj_type}")
1335
+ self.layers = Plamo2Decoder(config) # type: ignore
1336
  self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1337
 
1338
  self.gradient_checkpointing = False
 
1387
  input_ids: Optional[torch.LongTensor] = None,
1388
  attention_mask: Optional[torch.Tensor] = None,
1389
  position_ids: Optional[torch.Tensor] = None,
1390
+ past_key_values: Optional[Plamo2Cache] = None,
1391
  inputs_embeds: Optional[torch.Tensor] = None,
1392
  image_features: Optional[torch.Tensor] = None,
1393
  use_cache: Optional[bool] = None,
1394
  output_attentions: Optional[bool] = None,
1395
  output_hidden_states: Optional[bool] = None,
1396
  return_dict: Optional[bool] = None,
1397
+ cache_position: Optional[torch.LongTensor] = None,
1398
+ **kwargs: Any,
1399
  ) -> Union[Tuple, BaseModelOutputWithPast]:
 
1400
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1401
  output_hidden_states = (
1402
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
1406
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1407
 
1408
  # retrieve input_ids and inputs_embeds
1409
+ if (input_ids is None) ^ (inputs_embeds is not None):
1410
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
1411
+
1412
+ if self.gradient_checkpointing and self.training and use_cache:
1413
+ use_cache = False
1414
+
1415
+ if inputs_embeds is None:
1416
+ inputs_embeds = self.embed_tokens(input_ids)
1417
+ batch_size, seq_length, _ = inputs_embeds.shape
1418
 
1419
  seq_length_with_past = seq_length
1420
  past_key_values_length = 0
 
1421
  if past_key_values is not None:
1422
  past_key_values_length = past_key_values.get_seq_length()
1423
  seq_length_with_past = seq_length_with_past + past_key_values_length
1424
+ assert cache_position is None, "cache_position is not supported yet"
 
 
1425
 
1426
  if image_features is not None:
1427
  assert self.config.image_token_id is not None
 
1447
 
1448
  hidden_states = inputs_embeds
1449
 
 
 
 
 
1450
  if use_cache and past_key_values is None:
1451
+ past_key_values = Plamo2Cache(self.config)
1452
 
1453
  # decoder layers
1454
  out = self.layers(
 
1485
  )
1486
 
1487
 
1488
+ class Plamo2ForCausalLM(Plamo2PreTrainedModel):
1489
  _tied_weights_keys = ["lm_head.weight"]
1490
 
1491
  # Without this, the model cannot be loaded into a meta device.
 
1495
  # https://github.com/pytorch/pytorch/blob/v2.4.1/torch/nn/modules/module.py#L2068
1496
  _supports_param_buffer_assignment = False
1497
 
1498
+ def __init__(self, config: Plamo2Config) -> None:
1499
  super().__init__(config)
1500
+ self.model = Plamo2Model(config)
1501
 
1502
  self.vocab_size = config.vocab_size
1503
  vocab_size = ((self.vocab_size + 15) // 16) * 16
 
1518
  def set_output_embeddings(self, new_embeddings: torch.nn.Module) -> None:
1519
  self.lm_head = new_embeddings
1520
 
1521
+ def set_decoder(self, decoder: Plamo2Model) -> None:
1522
  self.model = decoder
1523
 
1524
+ def get_decoder(self) -> Plamo2Model:
1525
  return self.model
1526
 
1527
  def forward( # type: ignore
 
1529
  input_ids: Optional[torch.LongTensor] = None,
1530
  attention_mask: Optional[torch.Tensor] = None,
1531
  position_ids: Optional[torch.Tensor] = None,
1532
+ past_key_values: Optional[Plamo2Cache] = None,
1533
  inputs_embeds: Optional[torch.FloatTensor] = None,
1534
  image_features: Optional[torch.Tensor] = None,
1535
  labels: Optional[torch.LongTensor] = None,
 
1537
  output_attentions: Optional[bool] = None,
1538
  output_hidden_states: Optional[bool] = None,
1539
  return_dict: Optional[bool] = None,
1540
+ cache_position: Optional[torch.LongTensor] = None,
1541
+ logits_to_keep: int | torch.Tensor = 0,
1542
+ **kwargs: Any,
1543
  ) -> Union[Tuple, CausalLMOutputWithPast]:
1544
  r"""
1545
  Args:
 
1566
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1567
  "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
1568
  ```"""
 
 
1569
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1570
  output_hidden_states = (
1571
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
1584
  output_attentions=output_attentions,
1585
  output_hidden_states=output_hidden_states,
1586
  return_dict=return_dict,
1587
+ cache_position=cache_position,
1588
+ **kwargs,
1589
  )
1590
 
1591
  hidden_states = outputs[0]
1592
  logits = self.lm_head(hidden_states)
1593
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
1594
+ logits = logits[:, slice_indices, : self.vocab_size]
1595
 
1596
  loss = None
1597
  if labels is not None:
1598
+ if len(kwargs) > 0 and set(kwargs.keys()) != set(["ignore_index"]):
1599
+ warnings.warn(
1600
+ f"The following kwargs may not be supported: {', '.join(kwargs.keys())}. ",
1601
+ stacklevel=2,
1602
+ )
1603
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
 
 
 
 
1604
 
1605
  if not return_dict:
1606
  output = (logits,) + outputs[1:]
 
1617
  def prepare_inputs_for_generation(
1618
  self,
1619
  input_ids: torch.Tensor,
1620
+ past_key_values: Optional[Plamo2Cache] = None,
1621
  attention_mask: Optional[torch.Tensor] = None,
1622
  inputs_embeds: Optional[torch.Tensor] = None,
1623
  image_features: Optional[torch.Tensor] = None,
 
1654
  return model_inputs
1655
 
1656
  @staticmethod
1657
+ def _reorder_cache(past_key_values: Plamo2Cache, beam_idx: torch.Tensor) -> Plamo2Cache:
1658
  past_key_values.reorder_cache(beam_idx)
1659
  return past_key_values
1660
 
1661
 
1662
  class MLPImageProjector(nn.Module):
1663
+ def __init__(self, config: Plamo2Config) -> None:
1664
  super().__init__()
1665
  self.config = config
1666
 
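Beyond the class renames, the modeling changes above add a layers_block_type property and give each attention layer its own rotary base (rope_theta for full-attention layers, rope_local_theta for sliding-window layers). A small standalone sketch of that selection logic, using only what is visible in the diff (the elided early branches of is_mamba are omitted) and the values from config.json (mamba_step=2, num_hidden_layers=16, full_attention_idx defaulting to []):

# Assumed reconstruction for illustration; mirrors the visible return expression of is_mamba
# and the new rotary-base choice in Attention.__init__.
mamba_step = 2
num_hidden_layers = 16
full_attention_idx: list[int] = []   # Plamo2Config default shown in the diff
rope_theta = 10000
rope_local_theta = 10000             # new Plamo2Config defaults shown in the diff

def is_mamba(i: int) -> bool:
    return (i % mamba_step) != (mamba_step // 2)

layers_block_type = ["mamba" if is_mamba(i) else "attention" for i in range(num_hidden_layers)]
# -> ["mamba", "attention", "mamba", "attention", ...]: even layers use Mamba, odd layers use Attention

for i, kind in enumerate(layers_block_type):
    if kind == "attention":
        # base passed to RotaryEmbedding in the updated Attention.__init__
        base = rope_theta if i in full_attention_idx else rope_local_theta
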
tokenization_plamo.py CHANGED
@@ -237,7 +237,7 @@ class AhoCorasick:
237
  return [self._tokens[token_id] for token_id in self.encode(data)]
238
 
239
 
240
- class PlamoTokenizer(PreTrainedTokenizer): # type: ignore
241
  vocab_files_names = VOCAB_FILES_NAMES
242
  model_input_names = ["input_ids", "attention_mask"]
243
 
 
237
  return [self._tokens[token_id] for token_id in self.encode(data)]
238
 
239
 
240
+ class Plamo2Tokenizer(PreTrainedTokenizer): # type: ignore
241
  vocab_files_names = VOCAB_FILES_NAMES
242
  model_input_names = ["input_ids", "attention_mask"]
243
 
tokenizer_config.json CHANGED
@@ -1,55 +1,55 @@
1
  {
2
- "add_bos_token": true,
3
- "add_eos_token": false,
4
- "added_tokens_decoder": {
5
- "0": {
6
- "content": "<|plamo:unk|>",
7
- "lstrip": false,
8
- "normalized": false,
9
- "rstrip": false,
10
- "single_word": false,
11
- "special": true
12
- },
13
- "1": {
14
- "content": "<|plamo:bos|>",
15
- "lstrip": false,
16
- "normalized": false,
17
- "rstrip": false,
18
- "single_word": false,
19
- "special": true
20
- },
21
- "2": {
22
- "content": "<|plamo:eos|>",
23
- "lstrip": false,
24
- "normalized": false,
25
- "rstrip": false,
26
- "single_word": false,
27
- "special": true
28
- },
29
- "3": {
30
- "content": "<|plamo:pad|>",
31
- "lstrip": false,
32
- "normalized": false,
33
- "rstrip": false,
34
- "single_word": false,
35
- "special": true
36
- }
37
  },
38
- "auto_map": {
39
- "AutoTokenizer": [
40
- "tokenization_plamo.PlamoTokenizer",
41
- null
42
- ]
 
 
43
  },
44
- "bos_token": "<|plamo:bos|>",
45
- "clean_up_tokenization_spaces": false,
46
- "cls_token": null,
47
- "eos_token": "<|plamo:eos|>",
48
- "local_file_only": true,
49
- "mask_token": null,
50
- "model_max_length": 1000000000000000019884624838656,
51
- "pad_token": "<|plamo:pad|>",
52
- "sep_token": null,
53
- "tokenizer_class": "PlamoTokenizer",
54
- "unk_token": "<|plamo:unk|>"
55
- }
 
1
  {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "added_tokens_decoder": {
5
+ "0": {
6
+ "content": "<|plamo:unk|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
  },
13
+ "1": {
14
+ "content": "<|plamo:bos|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
  },
21
+ "2": {
22
+ "content": "<|plamo:eos|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "3": {
30
+ "content": "<|plamo:pad|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ }
37
+ },
38
+ "auto_map": {
39
+ "AutoTokenizer": [
40
+ "tokenization_plamo.Plamo2Tokenizer",
41
+ null
42
+ ]
43
+ },
44
+ "bos_token": "<|plamo:bos|>",
45
+ "clean_up_tokenization_spaces": false,
46
+ "cls_token": null,
47
+ "eos_token": "<|plamo:eos|>",
48
+ "local_file_only": true,
49
+ "mask_token": null,
50
+ "model_max_length": 1000000000000000019884624838656,
51
+ "pad_token": "<|plamo:pad|>",
52
+ "sep_token": null,
53
+ "tokenizer_class": "Plamo2Tokenizer",
54
+ "unk_token": "<|plamo:unk|>"
55
+ }
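
With tokenization_plamo.py and tokenizer_config.json now pointing AutoTokenizer at Plamo2Tokenizer, the whole stack loads through the standard auto classes. A minimal end-to-end sketch, assuming the files in this commit live at a placeholder path; the prompt is illustrative only:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# "path/to/plamo2-checkpoint" is a placeholder, not part of this commit.
tokenizer = AutoTokenizer.from_pretrained("path/to/plamo2-checkpoint", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("path/to/plamo2-checkpoint", trust_remote_code=True)

inputs = tokenizer("Hello, how are you?", return_tensors="pt")  # <|plamo:bos|> is prepended (add_bos_token: true)
with torch.no_grad():
    generated = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(generated[0], skip_special_tokens=True))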