manueldeprada HF Staff RaushanTurganbay HF Staff commited on
Commit
245ec5a
·
verified ·
1 Parent(s): 1e37df0

Fix a bug (#1)

Browse files

- Fix a bug (8f1c855ae874bdc25b5bd121cdf592b3280ffd26)


Co-authored-by: Raushan Turganbay <[email protected]>

Files changed (1) hide show
  1. custom_generate/generate.py +3 -3
custom_generate/generate.py CHANGED
@@ -119,8 +119,8 @@ def _contrastive_search(
119
  logits_processor: LogitsProcessorList,
120
  stopping_criteria: StoppingCriteriaList,
121
  generation_config: GenerationConfig,
122
- synced_gpus: bool,
123
- streamer: Optional["BaseStreamer"],
124
  **model_kwargs,
125
  ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
126
  r"""
@@ -138,7 +138,7 @@ def _contrastive_search(
138
  used to tell if the generation loop should stop.
139
  generation_config ([`~generation.GenerationConfig`]):
140
  The generation configuration to be used as parametrization of the decoding method.
141
- synced_gpus (`bool`):
142
  Whether to continue running the while loop until max_length (needed to avoid deadlocking with
143
  `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
144
  streamer (`BaseStreamer`, *optional*):
 
119
  logits_processor: LogitsProcessorList,
120
  stopping_criteria: StoppingCriteriaList,
121
  generation_config: GenerationConfig,
122
+ synced_gpus: bool = False,
123
+ streamer: Optional["BaseStreamer"] = None,
124
  **model_kwargs,
125
  ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
126
  r"""
 
138
  used to tell if the generation loop should stop.
139
  generation_config ([`~generation.GenerationConfig`]):
140
  The generation configuration to be used as parametrization of the decoding method.
141
+ synced_gpus (`bool`, *optional*, defaults to `False`):
142
  Whether to continue running the while loop until max_length (needed to avoid deadlocking with
143
  `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
144
  streamer (`BaseStreamer`, *optional*):