Spaces: Running on Zero
| import torch | |
def create_custom_forward(module):
    """Return a plain function that calls ``module`` with whatever it receives.

    Every positional and keyword argument given to the returned function is
    passed through to ``module`` unchanged, and ``module``'s result is
    returned as-is.
    """

    def custom_forward(*call_args, **call_kwargs):
        result = module(*call_args, **call_kwargs)
        return result

    return custom_forward
def gradient_checkpoint_forward(
    model,
    use_gradient_checkpointing,
    use_gradient_checkpointing_offload,
    *args,
    **kwargs,
):
    """Run ``model(*args, **kwargs)``, optionally with activation checkpointing.

    Args:
        model: Callable to run (typically a ``torch.nn.Module``).
        use_gradient_checkpointing: If truthy, run ``model`` through
            ``torch.utils.checkpoint.checkpoint`` (non-reentrant mode).
        use_gradient_checkpointing_offload: If truthy, also checkpoint, and
            do it inside ``torch.autograd.graph.save_on_cpu()`` so that
            tensors saved for backward are kept on the CPU. This flag takes
            precedence: it enables checkpointing even when
            ``use_gradient_checkpointing`` is falsy.
        *args, **kwargs: Forwarded to ``model``.

    Returns:
        Whatever ``model`` returns.

    Notes:
        ``use_reentrant=False`` is always passed to ``checkpoint``. A
        ``use_reentrant`` entry in ``kwargs`` therefore raises ``TypeError``
        (duplicate keyword) on the checkpointed paths.
        NOTE(review): in recent PyTorch versions ``checkpoint`` itself
        consumes some keyword names (e.g. ``context_fn``, ``debug``), so such
        kwargs may not reach ``model`` on the checkpointed paths — confirm
        against the installed torch version.
    """
    # Fast path: no checkpointing requested at all.
    if not (use_gradient_checkpointing or use_gradient_checkpointing_offload):
        return model(*args, **kwargs)

    import contextlib

    # Explicit submodule import: `import torch` alone is not guaranteed to
    # make `torch.utils.checkpoint` available as an attribute. This also
    # binds the local name `torch` to the top-level package.
    import torch.utils.checkpoint

    # Offloading only changes *where* saved tensors live; the checkpoint call
    # itself is identical, so a single call is shared by both modes.
    saved_tensors_ctx = (
        torch.autograd.graph.save_on_cpu()
        if use_gradient_checkpointing_offload
        else contextlib.nullcontext()
    )
    with saved_tensors_ctx:
        return torch.utils.checkpoint.checkpoint(
            create_custom_forward(model),
            *args,
            **kwargs,
            use_reentrant=False,
        )