Fix a bug (#1)
Browse files- Fix a bug (8f1c855ae874bdc25b5bd121cdf592b3280ffd26)
Co-authored-by: Raushan Turganbay <[email protected]>
custom_generate/generate.py
CHANGED
|
@@ -119,8 +119,8 @@ def _contrastive_search(
|
|
| 119 |
logits_processor: LogitsProcessorList,
|
| 120 |
stopping_criteria: StoppingCriteriaList,
|
| 121 |
generation_config: GenerationConfig,
|
| 122 |
-
synced_gpus: bool,
|
| 123 |
-
streamer: Optional["BaseStreamer"],
|
| 124 |
**model_kwargs,
|
| 125 |
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
|
| 126 |
r"""
|
|
@@ -138,7 +138,7 @@ def _contrastive_search(
|
|
| 138 |
used to tell if the generation loop should stop.
|
| 139 |
generation_config ([`~generation.GenerationConfig`]):
|
| 140 |
The generation configuration to be used as parametrization of the decoding method.
|
| 141 |
-
synced_gpus (`bool`):
|
| 142 |
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
|
| 143 |
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
|
| 144 |
streamer (`BaseStreamer`, *optional*):
|
|
|
|
| 119 |
logits_processor: LogitsProcessorList,
|
| 120 |
stopping_criteria: StoppingCriteriaList,
|
| 121 |
generation_config: GenerationConfig,
|
| 122 |
+
synced_gpus: bool = False,
|
| 123 |
+
streamer: Optional["BaseStreamer"] = None,
|
| 124 |
**model_kwargs,
|
| 125 |
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
|
| 126 |
r"""
|
|
|
|
| 138 |
used to tell if the generation loop should stop.
|
| 139 |
generation_config ([`~generation.GenerationConfig`]):
|
| 140 |
The generation configuration to be used as parametrization of the decoding method.
|
| 141 |
+
synced_gpus (`bool`, *optional*, defaults to `False`):
|
| 142 |
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
|
| 143 |
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
|
| 144 |
streamer (`BaseStreamer`, *optional*):
|