Torch compile + dynamo error

#11
by ParadiseN - opened

Hi, I'm trying to use your torch.compile example as-is with:
accelerate==0.34.2
torch==2.4.1
transformers==4.45.1

I'm getting a torch.compile error from the Dynamo backend:

   raise exc.UserError(
torch._dynamo.exc.UserError: Dynamic control flow is not supported at the moment. Please use functorch.experimental.control_flow.cond to explicitly capture the control flow. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#cond-operands

from user code:
  File "/home/jupyter/miniconda3/envs/whisper_turbo/lib/python3.11/site-packages/transformers/models/whisper/modeling_whisper.py", line 1765, in forward
    outputs = self.model(
  File "/home/jupyter/miniconda3/envs/whisper_turbo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jupyter/miniconda3/envs/whisper_turbo/lib/python3.11/site-packages/transformers/models/whisper/modeling_whisper.py", line 1616, in forward
    encoder_outputs = self.encoder(
  File "/home/jupyter/miniconda3/envs/whisper_turbo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jupyter/miniconda3/envs/whisper_turbo/lib/python3.11/site-packages/transformers/models/whisper/modeling_whisper.py", line 1121, in forward
    layer_outputs = encoder_layer(
  File "/home/jupyter/miniconda3/envs/whisper_turbo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jupyter/miniconda3/envs/whisper_turbo/lib/python3.11/site-packages/transformers/models/whisper/modeling_whisper.py", line 689, in forward
    if hidden_states.dtype == torch.float16 and (

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
   import torch._dynamo
   torch._dynamo.config.suppress_errors = True

Adding

torch._dynamo.config.suppress_errors = True

does not help.
So I'm wondering: does it need more changes than your example shows to run with torch.compile?

I can confirm this issue as well.

torch 2.4.1
accelerate 0.34.2
transformers 4.45.1

Same here, with torch 2.4.1 and accelerate 0.34.2.

Same issue under:

Torch: 2.4.1+cu121
accelerate: 1.0.0
transformers: 4.45.2

You can get it to work with

model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=False)

but the speed-up is far from impressive (18.76 s vs 14.45 s). I used the code from https://github.com/sanchit-gandhi/notebooks/blob/main/whisper_compile.ipynb
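When comparing compiled and eager timings, the first calls of a torch.compile'd function pay one-off costs. These include compilation and, with mode="reduce-overhead", CUDA graph capture. If those calls are included in a single wall-clock number, the comparison can look worse than steady-state behaviour. I don't know whether that applies to the numbers above. The helper below is a hypothetical sketch (the names benchmark, warmup, runs and sync are mine) that runs a few warm-up calls and then reports the median of the timed calls.

```python
import statistics
import time
from typing import Callable, Optional


def benchmark(fn: Callable[[], object], warmup: int = 3, runs: int = 5,
              sync: Optional[Callable[[], None]] = None) -> float:
    """Return the median wall-clock seconds per call of fn, excluding warm-up calls."""
    # Warm-up calls absorb one-off costs such as compilation or CUDA graph capture.
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()  # make sure warm-up work has finished before timing starts
    timings = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # e.g. torch.cuda.synchronize, so queued GPU work is counted
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)
```

Usage would look like benchmark(lambda: model.generate(input_features), sync=torch.cuda.synchronize), where model and input_features are placeholders for your own objects. Run it once for the eager model and once for the compiled one, and compare the two medians.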

Any updates?
