Upload 3 files
Implemented processor.
- config.json +7 -5
- modeling_qwen2.py +175 -107
- processing_qwen2_ts.py +171 -0
config.json
CHANGED
@@ -1,5 +1,5 @@
 {
-  "_name_or_path": "
+  "_name_or_path": "chatts_release",
   "architectures": [
     "Qwen2TSForCausalLM"
   ],
@@ -7,7 +7,8 @@
   "auto_map": {
     "AutoConfig": "configuration_qwen2.Qwen2TSConfig",
     "AutoModel": "modeling_qwen2.Qwen2TSForCausalLM",
-    "AutoModelForCausalLM": "modeling_qwen2.Qwen2TSForCausalLM"
+    "AutoModelForCausalLM": "modeling_qwen2.Qwen2TSForCausalLM",
+    "AutoProcessor": "processing_qwen2_ts.Qwen2TSProcessor"
   },
   "bos_token_id": 151643,
   "eos_token_id": 151645,
@@ -33,10 +34,11 @@
     "hidden_size": 5120,
     "num_features": 2,
     "num_layers": 5,
-    "patch_size": 16
+    "patch_size": 16,
+    "max_length": 2048
   },
-  "ts_token_end_index":
-  "ts_token_start_index":
+  "ts_token_end_index": 151666,
+  "ts_token_start_index": 151665,
   "use_cache": false,
   "use_sliding_window": false,
   "vocab_size": 152064
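The new "AutoProcessor" entry in auto_map means the custom processor is picked up automatically when the repository is loaded with trust_remote_code=True. A minimal loading sketch (the checkpoint path below is a placeholder, not part of this commit):

from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer

ckpt = "path/to/chatts_checkpoint"  # placeholder, substitute the actual repository id

# trust_remote_code=True lets transformers import configuration_qwen2.py,
# modeling_qwen2.py and the newly added processing_qwen2_ts.py from the repo,
# following the auto_map entries above.
model = AutoModelForCausalLM.from_pretrained(ckpt, trust_remote_code=True, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(ckpt, trust_remote_code=True)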
modeling_qwen2.py
CHANGED
@@ -26,7 +26,7 @@
 import inspect
 import math
 import copy
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union, Dict, Any
 from dataclasses import dataclass
 
 import torch
@@ -68,6 +68,44 @@ _CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B-beta"
 _CONFIG_FOR_DOC = "Qwen2TSConfig"
 
 
+@dataclass
+class Qwen2TSCausalLMOutputWithPast(ModelOutput):
+    """
+    Base class for Qwen2TS causal language model (or autoregressive) outputs.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Language modeling loss (for next-token prediction).
+        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+            `past_key_values` input) to speed up sequential decoding.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        attention_mask (`torch.FloatTensor`, *optional*):
+            Attentions mask, used to update attention mask and position_ids.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: torch.FloatTensor = None
+    past_key_values: Optional[List[torch.FloatTensor]] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    attention_mask: Optional[torch.FloatTensor] = None
+
 ########################Naive TS Embedding#####################
 class TimeSeriesEmbedding(nn.Module):
     def __init__(self, config):
@@ -1187,147 +1225,127 @@ class Qwen2TSForCausalLM(Qwen2PreTrainedModel):
 
     def get_decoder(self):
         return self.model
-
-    def _get_real_length(self, timeseries, input_ids):
-        # Return the embed length after inserting timeseries features
-        if timeseries is None:
-            return input_ids.size(1)
-
-        num_time_steps = timeseries.size(1) * timeseries.size(2) // self.config.ts['num_features']
-        num_patches = num_time_steps // self.config.ts['patch_size']
-        special_ts_token_mask_start = input_ids == self.config.ts_token_start_index
-        num_special_ts_tokens = torch.sum(special_ts_token_mask_start, dim=-1)
-        return num_special_ts_tokens * (num_patches - 2) + input_ids.size(1)
-
-    def _get_original_length(self, timeseries, input_ids, past_length):
-        if timeseries is None:
-            if isinstance(past_length, int):
-                original_length = torch.full((input_ids.size(0),), past_length, dtype=torch.long, device=input_ids.device)
-            else:
-                original_length = past_length
-            num_special_ts_tokens_within_past = torch.zeros(input_ids.size(0), dtype=torch.long, device=input_ids.device)
-            return original_length, num_special_ts_tokens_within_past
-
-        patch_size = self.config.ts['patch_size']
-        num_patches = timeseries.size(1) * timeseries.size(2) // patch_size // self.config.ts['num_features']
-        ts_token_start_index = self.config.ts_token_start_index
-
-        ts_mask = (input_ids == ts_token_start_index).long() # (batch_size, seq_length)
-
-        cumsum_ts = torch.cumsum(ts_mask, dim=1) # (batch_size, seq_length)
-
-        seq_length = input_ids.size(1)
-        positions = torch.arange(1, seq_length + 1, device=input_ids.device).unsqueeze(0).expand_as(input_ids) # (batch_size, seq_length)
-
-        transformed_length = positions + cumsum_ts * (num_patches - 2) # (batch_size, seq_length)
-
-        if isinstance(past_length, int):
-            past_length_tensor = torch.full((input_ids.size(0),), past_length, dtype=torch.long, device=input_ids.device)
-        else:
-            past_length_tensor = past_length.to(input_ids.device)
-
-        mask = transformed_length <= past_length_tensor.unsqueeze(1) # (batch_size, seq_length)
-
-        original_length = torch.sum(mask, dim=1) # (batch_size,)
-        original_positions = torch.arange(1, seq_length + 1, device=input_ids.device).unsqueeze(0).expand_as(input_ids) # (batch_size, seq_length)
-        original_mask = original_positions <= original_length.unsqueeze(1) # (batch_size, seq_length)
-        ts_within_original_mask = ts_mask.bool() & original_mask.bool() # (batch_size, seq_length)
-        num_special_ts_tokens_within_past = torch.sum(ts_within_original_mask, dim=1) # (batch_size,)
-
-        original_length = torch.clamp(original_length, min=0)
-
-        return original_length, num_special_ts_tokens_within_past
-
     def _merge_input_ids_with_time_series_features(
-
-
-        total_time_steps, embed_dim = time_series_features.shape
+        self, time_series_features, inputs_embeds, input_ids, attention_mask, labels, patch_cnt
+    ):
         batch_size, sequence_length = input_ids.shape
+        _left_padding = torch.any(attention_mask[:, 0] == 0)
+        _right_padding = torch.any(attention_mask[:, -1] == 0)
         left_padding = False
-
+        if batch_size > 1:
+            if _left_padding and not _right_padding:
+                left_padding = True
+            elif not _left_padding and _right_padding:
+                left_padding = False
+            elif not _left_padding and not _right_padding:
+                left_padding = False
+            else:
+                raise ValueError(f"both side of attention_mask has zero, invalid. {attention_mask}")
+        else:
+            if _left_padding and not _right_padding:
+                left_padding = True
+            else:
+                left_padding = False
+
         # 1. Create a mask to know where special time series tokens are
         special_ts_token_mask_start = input_ids == self.config.ts_token_start_index
         special_ts_token_mask_end = input_ids == self.config.ts_token_end_index
         special_ts_token_mask = special_ts_token_mask_start | special_ts_token_mask_end
+
+        # 2. Calculate patch count
         num_special_ts_tokens = torch.sum(special_ts_token_mask_start, dim=-1)
+        total_time_steps, embed_dim = time_series_features.shape
+
         # Correctly calculate the total number of patches per batch
+        patch_index = 0
         num_total_patches = torch.zeros(batch_size, dtype=patch_cnt.dtype, device=patch_cnt.device)
         special_ts_token_mask_start_nonzero = special_ts_token_mask_start.nonzero()
         special_ts_token_mask_start_with_size = special_ts_token_mask_start.clone().long()
-
+
+        attn_mask_cnt = attention_mask.sum(dim=-1)
         for i in range(batch_size):
             num_ts_in_batch = num_special_ts_tokens[i]
-            num_total_patches[i] = patch_cnt[patch_index:patch_index + num_ts_in_batch].sum() - 2 * num_ts_in_batch
+            num_total_patches[i] = patch_cnt[patch_index : patch_index + num_ts_in_batch].sum() - 2 * num_ts_in_batch
             for idx in range(patch_index, patch_index + num_ts_in_batch):
-
-                special_ts_token_mask_start_with_size[
+                b_idx, pos = special_ts_token_mask_start_nonzero[idx]
+                special_ts_token_mask_start_with_size[b_idx, pos] *= (patch_cnt[idx].item() - 2)
             patch_index += num_ts_in_batch
-
-
+            attn_mask_cnt[i] += num_total_patches[i].item()
+
+        # 3. Embeding length
         max_embed_dim = sequence_length + num_total_patches.max()
-
-        #
+
+        # 4. Non ts tokens
         batch_indices, non_ts_indices = torch.where(~special_ts_token_mask)
 
-        #
+        # 5. Text token in final text positions
         new_token_positions = torch.cumsum((special_ts_token_mask_start_with_size + 1), dim=-1) - 1
+
+        # nb_ts_pad
         nb_ts_pad = max_embed_dim - 1 - new_token_positions[:, -1]
         if left_padding:
-            new_token_positions += nb_ts_pad[:, None]
+            new_token_positions += nb_ts_pad[:, None]
+
         text_to_overwrite = new_token_positions[batch_indices, non_ts_indices]
-
-        #
+
+        # 6. Final embedding and attention masks
         final_embedding = torch.zeros(
             batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
         )
-
-
-        )
+
+        final_attention_mask = torch.zeros(batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device)
+        for i in range(attention_mask.size(0)):
+            if left_padding:
+                final_attention_mask[i, max_embed_dim - attn_mask_cnt[i] :] = 1
+            else:
+                final_attention_mask[i, : attn_mask_cnt[i]] = 1
+
+        final_labels = None
         if labels is not None:
             final_labels = torch.full(
                 (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
             )
+
         target_device = inputs_embeds.device
         batch_indices, non_ts_indices, text_to_overwrite = (
             batch_indices.to(target_device),
             non_ts_indices.to(target_device),
             text_to_overwrite.to(target_device),
         )
-
-
-        # 4. Fill the embeddings based on the mask
+
+        # 7. Move embedding and labels to final positions
         final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_ts_indices]
-        final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_ts_indices]
         if labels is not None:
             final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_ts_indices]
-
-        #
+
+        # 8. Move time series to final positions
         ts_to_overwrite = torch.full(
             (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
        )
         ts_to_overwrite[batch_indices, text_to_overwrite] = False
+
         reversed_cumsum = ts_to_overwrite.flip(dims=[-1]).cumsum(-1).flip(dims=[-1]) - 1
         ts_to_overwrite &= reversed_cumsum >= nb_ts_pad[:, None].to(target_device)
-
+
+        # Check that the number of time series tokens is correct
         if ts_to_overwrite.sum() != time_series_features.shape[:-1].numel():
             raise ValueError(
                 f"The input provided to the model are wrong. The number of time series tokens is {torch.sum(special_ts_token_mask_start)} while"
                 f" the number of time series given to the model is {len(patch_cnt)}. This prevents correct indexing and breaks batch generation."
             )
-
         final_embedding[ts_to_overwrite] = time_series_features.contiguous().reshape(-1, embed_dim).to(target_device)
-
+
+        # 9. Calculate position ids
         position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
-
-
-
-
-
-
+        if position_ids.size(-1) < input_ids.size(-1):
+            position_ids = position_ids[:, -input_ids.size(-1) :]
+
+        # 10. Move attention mask to final positions
+        pad_batch_indices, pad_indices = torch.where(input_ids == self.config.pad_token_id)
+        if len(pad_batch_indices) > 0:
+            indices_to_mask = new_token_positions[pad_batch_indices, pad_indices]
+            final_embedding[pad_batch_indices, indices_to_mask] = 0
 
-        if labels is None:
-            final_labels = None
-
         return final_embedding, final_attention_mask, position_ids, final_labels
 
     @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
@@ -1382,10 +1400,8 @@ class Qwen2TSForCausalLM(Qwen2PreTrainedModel):
         inputs_embeds = self.get_input_embeddings()(input_ids)
 
         if timeseries is not None and timeseries.shape[0] > 0:
-            #
-            use_cache = False
+            # use_cache = False
             ts_features, patch_cnt = self.ts_encoder(timeseries)
-
             inputs_embeds = inputs_embeds.to(ts_features.dtype)
 
             inputs_embeds, attention_mask, position_ids, labels = self._merge_input_ids_with_time_series_features(
@@ -1424,14 +1440,63 @@ class Qwen2TSForCausalLM(Qwen2PreTrainedModel):
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output
 
-
+
+        return Qwen2TSCausalLMOutputWithPast(
             loss=loss,
             logits=logits,
             past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
+            attention_mask=attention_mask
         )
 
+    def _update_model_kwargs_for_generation(
+        self,
+        outputs: ModelOutput,
+        model_kwargs: Dict[str, Any],
+        is_encoder_decoder: bool = False,
+        num_new_tokens: int = 1,
+    ) -> Dict[str, Any]:
+        # update past_key_values keeping its naming used in model code
+        cache_name, cache = self._extract_past_from_model_output(outputs)
+        model_kwargs[cache_name] = cache
+        if getattr(outputs, "state", None) is not None:
+            model_kwargs["state"] = outputs.state
+
+        # update attention_mask
+        if getattr(outputs, "attention_mask", None) is not None:
+            model_kwargs["attention_mask"] = outputs.attention_mask
+
+        # update token_type_ids with last value
+        if "token_type_ids" in model_kwargs:
+            token_type_ids = model_kwargs["token_type_ids"]
+            model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
+
+        if not is_encoder_decoder:
+            # update attention mask
+            if "attention_mask" in model_kwargs:
+                attention_mask = model_kwargs["attention_mask"]
+                model_kwargs["attention_mask"] = torch.cat(
+                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+                )
+        else:
+            # update decoder attention mask
+            if "decoder_attention_mask" in model_kwargs:
+                decoder_attention_mask = model_kwargs["decoder_attention_mask"]
+                model_kwargs["decoder_attention_mask"] = torch.cat(
+                    [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
+                    dim=-1,
+                )
+
+        if model_kwargs.get("use_cache", True):
+            model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
+        else:
+            past_positions = model_kwargs.pop("cache_position")
+            new_positions = torch.arange(
+                past_positions[-1] + 1, past_positions[-1] + num_new_tokens + 1, dtype=past_positions.dtype
+            ).to(past_positions.device)
+            model_kwargs["cache_position"] = torch.cat((past_positions, new_positions))
+        return model_kwargs
 
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, timeseries=None, **kwargs
@@ -1446,20 +1511,23 @@ class Qwen2TSForCausalLM(Qwen2PreTrainedModel):
             cache_length = past_length = past_key_values[0][0].shape[2]
             max_cache_length = None
 
-
-
-
-
-
-
-
+            has_ts = timeseries is not None and len(timeseries) > 0
+
+            if has_ts and kwargs.get("attention_mask") is not None:
+                attention_mask = kwargs["attention_mask"]
+                attention_mask = torch.cat(
+                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+                )
+
+            # Set attention mask and input_ids
+            if has_ts and past_length > 0:
+                # We have only one token added and timeseries are already inferenced
+                input_ids = input_ids[:, -1:]
+                timeseries = None
+            elif attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                 input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
-
-
-            elif past_length < real_len:
-                input_ids = input_ids[:, origin_past_len:]
-                if timeseries is not None:
-                    timeseries = timeseries[past_num_ts:]
+            elif past_length < input_ids.shape[1]:
+                input_ids = input_ids[:, past_length:]
             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
 
             # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
@@ -1476,7 +1544,7 @@ class Qwen2TSForCausalLM(Qwen2PreTrainedModel):
         position_ids = attention_mask.long().cumsum(-1) - 1
         position_ids.masked_fill_(attention_mask == 0, 1)
         if past_key_values:
-            position_ids = position_ids[:, -input_ids.
+            position_ids = position_ids[:, -input_ids.shape[1] :]
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
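For reference, the heart of _merge_input_ids_with_time_series_features is the cumulative-sum bookkeeping: each <ts> start token is widened by patch_cnt[i] - 2 extra positions (the <ts> and <ts/> slots themselves end up holding patch embeddings), so the merged length is sequence_length + patch_cnt.sum() - 2 * num_ts. A self-contained toy sketch of that arithmetic (the token ids and patch count below are made up; only the cumsum trick mirrors the diff above):

import torch

TS_START, TS_END, TXT = 100, 101, 1                  # stand-in token ids, not the real vocab
input_ids = torch.tensor([[TXT, TS_START, TS_END, TXT, TXT]])
patch_cnt = torch.tensor([6])                        # the TS encoder produced 6 patches for the one series

start_mask = (input_ids == TS_START).long()
weights = start_mask.clone()
weights[start_mask.bool()] *= patch_cnt - 2          # widen each <ts> token by patch_cnt - 2 slots

# Final position of every original token after the expansion.
new_token_positions = torch.cumsum(weights + 1, dim=-1) - 1
max_embed_dim = input_ids.shape[1] + int(patch_cnt.sum() - 2 * start_mask.sum())

print(new_token_positions)  # tensor([[0, 5, 6, 7, 8]])
print(max_embed_dim)        # 9: the non-special tokens land at 0, 7, 8; slots 1-6 hold the 6 patch embeddings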
processing_qwen2_ts.py
ADDED
@@ -0,0 +1,171 @@
+# coding=utf-8
+# Copyright 2024 Tsinghua University and ByteDance.
+#
+# Licensed under the MIT License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://opensource.org/license/mit
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from typing import List, Union, Tuple, Optional
+import torch
+
+from transformers.feature_extraction_utils import BatchFeature
+from transformers.processing_utils import ProcessorMixin
+from transformers.tokenization_utils_base import (
+    PreTokenizedInput,
+    TextInput,
+    PaddingStrategy,
+)
+
+def sp_encoding(timeseries: np.ndarray, eots_token: bool = True) -> Tuple[np.ndarray, str, dict]:
+    """
+    Encodes a time series with scalar normalization.
+
+    Args:
+        timeseries (np.ndarray): The raw time series data (1D or 2D).
+
+    Returns:
+        result_timeseries (np.ndarray): The encoded time series, shape [seq_len, 1].
+        prompt (str): The placeholder string with offset and scaling info.
+        metadata (dict): Metadata containing the offset and scaling factor.
+    """
+    mean = np.mean(timeseries)
+    scaled_timeseries = timeseries - mean
+    scale_factor = 1.0
+    if np.any(np.abs(scaled_timeseries) >= 3.0):
+        scale_factor = np.max(np.abs(scaled_timeseries)) / 3.0
+        scaled_timeseries /= scale_factor
+
+    prompt = f"[Value Offset: {-mean:.4f}|Value Scaling: {scale_factor:.4f}]<ts>"
+    if eots_token:
+        prompt += '<ts/>'
+
+    result_timeseries = np.stack([scaled_timeseries, np.ones_like(scaled_timeseries)], axis=-1).reshape(-1, 1)
+
+    return result_timeseries, prompt, {"offset": float(-mean), "scale_factor": float(scale_factor)}
+
+class Qwen2TSProcessor(ProcessorMixin):
+    """
+    A processor for ChatTS that integrates text prompt processing and time series encoding.
+    """
+
+    attributes = ["tokenizer"]
+    feature_extractor_class = None  # You can add a feature extractor if needed
+    tokenizer_class = "AutoTokenizer"
+
+    def __init__(self, tokenizer=None):
+        """
+        Args:
+            tokenizer: An optional tokenizer to process text prompts.
+        """
+        super().__init__(tokenizer=tokenizer)
+
+    def __call__(
+        self,
+        text: List[str],
+        timeseries: List[List[np.ndarray]],
+        padding: Union[bool, str, PaddingStrategy] = False,
+        padding_side: str = 'left',
+        vllm_flag: bool = False,
+        **kwargs,
+    ) -> BatchFeature:
+        """
+        Encodes a prompt and its associated time series.
+
+        Args:
+            prompt (List[str]): The input prompt containing <ts><ts/> placeholders.
+            timeseries (List[np.ndarray]): A list of time series matched to placeholders in the prompt.
+            padding (bool or str or PaddingStrategy, optional): Passed to the tokenizer for text padding.
+            return_tensors (str, optional): "pt" to return PyTorch tensors; None to return NumPy arrays.
+            **kwargs: Additional tokenizer parameters.
+
+        Returns:
+            BatchFeature: Contains processed prompt, encoded time series, and tokenizer outputs.
+        """
+        if type(text) == str:
+            text = [text]
+
+        encoded_ts_arrays = []
+        reconstructed_prompts = []
+        total_ts_cnt = 0
+        for idx, prompt in enumerate(text):
+            # Split prompt by <ts><ts/> placeholders
+            last_ts_cnt = total_ts_cnt
+            prompt_segments = prompt.split("<ts><ts/>")
+            total_ts_cnt = total_ts_cnt + len(prompt_segments) - 1
+
+            # Encode each time series and rebuild the prompt
+            reconstructed_prompt = prompt_segments[0]
+
+            for i, ts in enumerate(timeseries[last_ts_cnt:total_ts_cnt]):
+                encoded_ts, ts_prompt, _ = sp_encoding(ts, eots_token=not vllm_flag)
+                reconstructed_prompt += ts_prompt + prompt_segments[i + 1]
+                # Ensure time series shape [1, seq_len, feature_dim] for batch concatenation
+                encoded_ts_arrays.append(encoded_ts[None, ...])
+
+            reconstructed_prompts.append(reconstructed_prompt)
+
+        if len(timeseries) != len(encoded_ts_arrays):
+            raise ValueError(
+                f"Mismatch between <ts><ts/> placeholders ({total_ts_cnt}) "
+                f"and time series ({len(encoded_ts_arrays)})."
+            )
+
+        if len(encoded_ts_arrays) > 0:
+            # Pad time series to the same length
+            max_length = max(ts.shape[1] for ts in encoded_ts_arrays)
+            padded_ts_arrays = [
+                np.pad(ts, ((0, 0), (0, max_length - ts.shape[1]), (0, 0)), mode="constant", constant_values=0.0)
+                for ts in encoded_ts_arrays
+            ]
+            concatenated_ts = np.concatenate(padded_ts_arrays, axis=0)  # Shape: [batch_size, max_length, feature_dim]
+
+            # Convert to torch
+            concatenated_ts = torch.from_numpy(concatenated_ts).half()
+        else:
+            concatenated_ts = None
+
+        # Tokenize the processed prompt
+        tokenizer_outputs = {}
+        if self.tokenizer is not None:
+            tokenizer_outputs = self.tokenizer(reconstructed_prompts, padding=padding, padding_side=padding_side, **kwargs)
+
+        # Create the final output
+        outputs = {
+            "timeseries": concatenated_ts
+        }
+        outputs.update(tokenizer_outputs)
+
+        return BatchFeature(data=outputs)
+
+    @property
+    def model_input_names(self):
+        """
+        Define the input names expected by the model.
+        """
+        tokenizer_input_names = []
+        if self.tokenizer and hasattr(self.tokenizer, "model_input_names"):
+            tokenizer_input_names = self.tokenizer.model_input_names
+        return list(dict.fromkeys(["processed_prompt", "time_series"] + tokenizer_input_names))
+
+    def batch_decode(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
+        refer to the docstring of this method for more information.
+        """
+        return self.tokenizer.batch_decode(*args, **kwargs)
+
+    def decode(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
+        the docstring of this method for more information.
+        """
+        return self.tokenizer.decode(*args, **kwargs)
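End-to-end, the processor expects one <ts><ts/> placeholder per series, with timeseries passed as a flat list of arrays across the whole batch. A hedged usage sketch (checkpoint path, prompt wording, and generation settings are illustrative, not part of this commit):

import numpy as np
from transformers import AutoModelForCausalLM, AutoProcessor

ckpt = "path/to/chatts_checkpoint"  # placeholder path
model = AutoModelForCausalLM.from_pretrained(ckpt, trust_remote_code=True, device_map="auto")
processor = AutoProcessor.from_pretrained(ckpt, trust_remote_code=True)

ts = np.sin(np.arange(256) * 0.1)  # toy series
prompt = "I have a time series of length 256: <ts><ts/>. Please describe its overall trend."

inputs = processor(text=[prompt], timeseries=[ts], padding=True, return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=128)
print(processor.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))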