import torch
from torch.autograd import Function


class GradientReversal(Function):
    """Gradient reversal: identity on the forward pass, ``-alpha * grad`` on backward.

    Call it through ``revgrad(x, alpha)`` (i.e. ``GradientReversal.apply``).
    ``alpha`` may be a tensor or a plain Python number. No gradient is ever
    propagated to ``alpha`` itself.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        """Return ``x`` unchanged and remember ``alpha`` for the backward pass.

        Args:
            ctx: Autograd context object.
            x: Input tensor; returned as-is.
            alpha: Gradient scaling factor, either a tensor or a Python number.

        Returns:
            ``x`` unchanged.
        """
        # Only alpha is needed in backward. Saving x as well would keep it alive
        # and make later in-place edits of x trip autograd's version check.
        ctx.alpha_is_tensor = torch.is_tensor(alpha)
        if ctx.alpha_is_tensor:
            # Tensors go through save_for_backward so autograd tracks them properly.
            ctx.save_for_backward(alpha)
        else:
            # save_for_backward only accepts tensors; keep plain numbers on ctx.
            ctx.alpha = alpha
        return x

    @staticmethod
    def backward(ctx, grad_output):
        """Negate and scale the incoming gradient by ``alpha``.

        Args:
            ctx: Autograd context populated by ``forward``.
            grad_output: Gradient w.r.t. the forward output.

        Returns:
            ``(-alpha * grad_output, None)``. The first element is ``None``
            when ``x`` does not require a gradient; the second is always
            ``None`` because ``alpha`` receives no gradient.
        """
        grad_input = None
        if ctx.needs_input_grad[0]:
            alpha = ctx.saved_tensors[0] if ctx.alpha_is_tensor else ctx.alpha
            grad_input = -alpha * grad_output
        return grad_input, None


revgrad = GradientReversal.apply