Update processing_qwen2_ts.py to work with the latest vllm patch for ChatTS support. (#16)
Commit: d9c80001adf853e6d9c275b6e981b8d352ee0e5f
Co-authored-by: Alexander Chemeris <[email protected]>
- processing_qwen2_ts.py +55 -36
processing_qwen2_ts.py
CHANGED
@@ -91,45 +91,62 @@ class Qwen2TSProcessor(ProcessorMixin):
         if timeseries is None:
             timeseries = []
 
-        encoded_ts_arrays = []
         reconstructed_prompts = []
-        total_ts_cnt = 0
-        for idx, prompt in enumerate(text):
-            # Split prompt by <ts><ts/> placeholders
-            last_ts_cnt = total_ts_cnt
-            prompt_segments = prompt.split("<ts><ts/>")
-            total_ts_cnt = total_ts_cnt + len(prompt_segments) - 1
-
-            # Encode each time series and rebuild the prompt
-            reconstructed_prompt = prompt_segments[0]
-
-            for i, ts in enumerate(timeseries[last_ts_cnt:total_ts_cnt]):
-                encoded_ts, ts_prompt, _ = sp_encoding(ts, eots_token=not vllm_flag)
-                reconstructed_prompt += ts_prompt + prompt_segments[i + 1]
-                # Ensure time series shape [1, seq_len, feature_dim] for batch concatenation
-                encoded_ts_arrays.append(encoded_ts[None, ...])
+        concatenated_ts = None
+        ts_tokens = []
 
-            reconstructed_prompts.append(reconstructed_prompt)
-
-        if len(timeseries) != len(encoded_ts_arrays):
-            raise ValueError(
-                f"Mismatch between <ts><ts/> placeholders ({total_ts_cnt}) "
-                f"and time series ({len(encoded_ts_arrays)})."
-            )
-
-        if len(encoded_ts_arrays) > 0:
-            # Pad time series to the same length
-            max_length = max(ts.shape[1] for ts in encoded_ts_arrays)
-            padded_ts_arrays = [
-                np.pad(ts, ((0, 0), (0, max_length - ts.shape[1]), (0, 0)), mode="constant", constant_values=0.0)
-                for ts in encoded_ts_arrays
-            ]
-            concatenated_ts = np.concatenate(padded_ts_arrays, axis=0)  # Shape: [batch_size, max_length, feature_dim]
+        if vllm_flag:
+            # All prompt modifications have to be done inside of the vLLM
+            # to work correctly with its caching mechanism.
+            reconstructed_prompts = text
 
-            # Convert to torch
-            concatenated_ts = torch.from_numpy(concatenated_ts).half()
+            # Process timeseries data
+            encoded_ts_arrays = []
+            for ts in timeseries:
+                # Get the normalized data and prompt text
+                encoded_ts, ts_prompt, _ = sp_encoding(ts, eots_token=False)
+                # Tokenize the ts_prompt and add to the tokens list
+                if self.tokenizer is not None:
+                    tokens = self.tokenizer.encode(ts_prompt, add_special_tokens=False)
+                    ts_tokens.append(tokens)
+                encoded_ts_arrays.append(encoded_ts[None, ...])
         else:
-            concatenated_ts = None
+            encoded_ts_arrays = []
+            total_ts_cnt = 0
+            for idx, prompt in enumerate(text):
+                # Split prompt by <ts><ts/> placeholders
+                last_ts_cnt = total_ts_cnt
+                prompt_segments = prompt.split("<ts><ts/>")
+                total_ts_cnt = total_ts_cnt + len(prompt_segments) - 1
+
+                # Encode each time series and rebuild the prompt
+                reconstructed_prompt = prompt_segments[0]
+
+                for i, ts in enumerate(timeseries[last_ts_cnt:total_ts_cnt]):
+                    encoded_ts, ts_prompt, _ = sp_encoding(ts, eots_token=not vllm_flag)
+                    reconstructed_prompt += ts_prompt + prompt_segments[i + 1]
+                    # Ensure time series shape [1, seq_len, feature_dim] for batch concatenation
+                    encoded_ts_arrays.append(encoded_ts[None, ...])
+
+                reconstructed_prompts.append(reconstructed_prompt)
+
+            if len(timeseries) != len(encoded_ts_arrays):
+                raise ValueError(
+                    f"Mismatch between <ts><ts/> placeholders ({total_ts_cnt}) "
+                    f"and time series ({len(encoded_ts_arrays)})."
+                )
+
+            if len(encoded_ts_arrays) > 0:
+                # Pad time series to the same length
+                max_length = max(ts.shape[1] for ts in encoded_ts_arrays)
+                padded_ts_arrays = [
+                    np.pad(ts, ((0, 0), (0, max_length - ts.shape[1]), (0, 0)), mode="constant", constant_values=0.0)
+                    for ts in encoded_ts_arrays
+                ]
+                concatenated_ts = np.concatenate(padded_ts_arrays, axis=0)  # Shape: [batch_size, max_length, feature_dim]
+
+                # Convert to torch
+                concatenated_ts = torch.from_numpy(concatenated_ts).half()
 
         # Tokenize the processed prompt
         tokenizer_outputs = {}
@@ -138,7 +155,9 @@ class Qwen2TSProcessor(ProcessorMixin):
 
         # Create the final output
         outputs = tokenizer_outputs
-        if concatenated_ts is not None:
+        if vllm_flag:
+            outputs["timeseries"] = zip(ts_tokens, encoded_ts_arrays)
+        elif concatenated_ts is not None:
            outputs["timeseries"] = concatenated_ts
 
         return BatchFeature(data=outputs)
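For reference, here is a minimal standalone sketch (not part of the commit) of what the non-vLLM branch above does with the encoded series: pad every series along the sequence axis to the longest length, stack them into one batch, and convert to half-precision torch. The array shapes are made up for illustration; only numpy and torch are assumed.

# Illustrative only: mirrors the padding/stacking logic of the non-vLLM branch.
import numpy as np
import torch

# Two dummy "encoded" series of different lengths, each shaped [1, seq_len, feature_dim]
encoded_ts_arrays = [np.ones((1, 5, 3)), np.ones((1, 8, 3))]

# Pad along the sequence axis to the longest length, then stack into one batch
max_length = max(ts.shape[1] for ts in encoded_ts_arrays)
padded_ts_arrays = [
    np.pad(ts, ((0, 0), (0, max_length - ts.shape[1]), (0, 0)), mode="constant", constant_values=0.0)
    for ts in encoded_ts_arrays
]
concatenated_ts = torch.from_numpy(np.concatenate(padded_ts_arrays, axis=0)).half()
print(concatenated_ts.shape)  # torch.Size([2, 8, 3])

In the vLLM path, by contrast, the processor leaves the prompt text untouched and sets outputs["timeseries"] to zip(ts_tokens, encoded_ts_arrays), pairing each time series' token ids with its encoded array so that the vLLM-side ChatTS patch can perform the placeholder expansion itself, as required by its prompt-caching mechanism.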