import activation
import torch


def norm(x: torch.Tensor, eps: float) -> torch.Tensor:
    # RMS normalization along the last dimension.
    return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)


def poly_norm(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float
) -> torch.Tensor:
    # Reference PolyNorm: weighted sum of RMS-normalized x**3, x**2, and x, plus bias,
    # computed in float32 and cast back to the weight dtype.
    x = x.float()
    return (
        weight[0] * norm(x**3, eps)
        + weight[1] * norm(x**2, eps)
        + weight[2] * norm(x, eps)
        + bias
    ).to(weight.dtype)


dtype = torch.bfloat16
torch.set_default_device("cuda:0")

a = torch.randn(3, 3, dtype=dtype, requires_grad=True)
w = torch.randn(3, dtype=dtype, requires_grad=True)
b = torch.randn(1, dtype=dtype, requires_grad=True)
a.retain_grad()
w.retain_grad()
b.retain_grad()

# Fused kernel from the `activation` extension; swap in the commented line below
# to run the pure PyTorch reference instead.
out = activation.poly_norm(a, w, b, 1e-6)
# out = poly_norm(a, w, b, 1e-6)
out.backward(torch.ones_like(out))

print(a.grad)
print(w.grad)
print(b.grad)
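
# Optional sanity check (a sketch, not part of the original script): compare the
# extension kernel against the pure PyTorch reference defined above. Detached
# clones are used so gradients do not accumulate on the tensors from the first
# backward pass; the tolerances are a rough guess for bfloat16 precision.
a_ref = a.detach().clone().requires_grad_(True)
w_ref = w.detach().clone().requires_grad_(True)
b_ref = b.detach().clone().requires_grad_(True)

out_ref = poly_norm(a_ref, w_ref, b_ref, 1e-6)
out_ref.backward(torch.ones_like(out_ref))

print(torch.allclose(out.float(), out_ref.float(), rtol=1e-2, atol=1e-2))
print(torch.allclose(a.grad.float(), a_ref.grad.float(), rtol=1e-2, atol=1e-2))
print(torch.allclose(w.grad.float(), w_ref.grad.float(), rtol=1e-2, atol=1e-2))
print(torch.allclose(b.grad.float(), b_ref.grad.float(), rtol=1e-2, atol=1e-2))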