Update ts_generation_mixin.py
ts_generation_mixin.py (+7 -4)
@@ -13,7 +13,7 @@ class TSGenerationMixin(GenerationMixin):
 
     def _greedy_search(
         self,
-        input_ids: torch.LongTensor,
+        input_ids: torch.Tensor,
         logits_processor: Optional[LogitsProcessorList] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         max_length: Optional[int] = None,
@@ -27,7 +27,11 @@ class TSGenerationMixin(GenerationMixin):
         synced_gpus: bool = False,
         streamer: Optional["BaseStreamer"] = None,
         **model_kwargs,
-    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
+    ) -> Union[GenerateNonBeamOutput, torch.Tensor]:
+        if len(input_ids.shape) == 2:
+            batch_size, cur_len = input_ids.shape
+        else:
+            raise ValueError('Input shape must be: [batch_size, seq_len]')
         # init values
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
@@ -82,7 +86,6 @@ class TSGenerationMixin(GenerationMixin):
         )
 
         # keep track of which sequences are already finished
-        batch_size, cur_len = input_ids.shape
         if "inputs_embeds" in model_kwargs:
             cur_len = model_kwargs["inputs_embeds"].shape[1]
         this_peer_finished = False
@@ -189,7 +192,7 @@ class TSGenerationMixin(GenerationMixin):
                 past_key_values=model_kwargs.get("past_key_values"),
             )
         else:
-            return input_ids
+            return input_ids.squeeze(dim=-1)
 
     def _update_model_kwargs_for_generation(
         self,
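The net effect of the commit: _greedy_search is annotated to take and return a plain torch.Tensor rather than a LongTensor (time-series inputs are float values, not token ids), the input is validated up front to be 2-D with shape [batch_size, seq_len], and the non-dict return path now calls squeeze(dim=-1) on the result. The sketch below illustrates that contract only; toy_greedy_search, the dummy "repeat the last value" decoding rule, and the shapes are illustrative assumptions, not code from this file.

import torch

def toy_greedy_search(input_ids: torch.Tensor, steps: int) -> torch.Tensor:
    # Same up-front check the commit adds: reject anything that is not
    # [batch_size, seq_len] before the decoding loop starts.
    if len(input_ids.shape) == 2:
        batch_size, cur_len = input_ids.shape
    else:
        raise ValueError('Input shape must be: [batch_size, seq_len]')
    # Dummy stand-in for the real decoding loop: append one "predicted"
    # value per series at each step (here, a copy of the last value).
    for _ in range(steps):
        input_ids = torch.cat([input_ids, input_ids[:, -1:]], dim=1)
    # Same return path the commit changes: squeeze(dim=-1) removes the
    # trailing dimension only if its size is 1, so an ordinary
    # [batch_size, seq_len] result passes through unchanged.
    return input_ids.squeeze(dim=-1)

series = torch.randn(4, 96)                  # 4 series, 96 observed steps
print(toy_greedy_search(series, 24).shape)   # torch.Size([4, 120])
# toy_greedy_search(series[0], 24)           # would raise: 1-D input rejected

Since squeeze(dim=-1) is a no-op whenever the last dimension is larger than 1, existing callers that pass 2-D inputs see the same output shape as before; only a trailing singleton dimension, if one is ever carried through the loop, gets dropped.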