Update modeling_motif.py
Browse files- modeling_motif.py +1 -20
modeling_motif.py
CHANGED
@@ -51,27 +51,8 @@ class PolyNorm(torch.nn.Module):
|
|
51 |
return self.weight[0] * self._norm(x ** 3) + self.weight[1] * self._norm(
|
52 |
x ** 2) + self.weight[2] * self._norm(x) + self.bias
|
53 |
|
54 |
-
class PolyNorm_Test(torch.nn.Module):
|
55 |
-
"""
|
56 |
-
A trainable activation function introduced in https://arxiv.org/html/2411.03884v1.
|
57 |
-
The code is copied from https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md,
|
58 |
-
with the change `* torch.rsqrt` => `/ torch.sqrt` for potential MAF incompatibility.
|
59 |
-
"""
|
60 |
-
|
61 |
-
def __init__(self, eps=1e-6):
|
62 |
-
super(PolyNorm_Test, self).__init__()
|
63 |
-
self.weight = torch.nn.Parameter(torch.ones(3) / 3)
|
64 |
-
self.bias = torch.nn.Parameter(torch.zeros(1))
|
65 |
-
self.eps = eps
|
66 |
-
|
67 |
-
def forward(self, x):
|
68 |
-
|
69 |
-
#return torch.nn.SiLU(x)
|
70 |
-
return moreh_ops.poly_norm(x, self.weight, self.bias)
|
71 |
-
|
72 |
-
|
73 |
-
CUSTOM_ACT2CLS = {"poly_norm": PolyNorm, "poly_norm_test": PolyNorm_Test}
|
74 |
|
|
|
75 |
ACT2CLS = {**_ACT2CLS, **CUSTOM_ACT2CLS}
|
76 |
ACT2FN = ClassInstantier(ACT2CLS)
|
77 |
|
|
|
51 |
return self.weight[0] * self._norm(x ** 3) + self.weight[1] * self._norm(
|
52 |
x ** 2) + self.weight[2] * self._norm(x) + self.bias
|
53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
+
CUSTOM_ACT2CLS = {"poly_norm": PolyNorm}
|
56 |
ACT2CLS = {**_ACT2CLS, **CUSTOM_ACT2CLS}
|
57 |
ACT2FN = ClassInstantier(ACT2CLS)
|
58 |
|