danieldk (HF Staff) committed
Commit e99cc09 · 1 Parent(s): b0e5c39

Mark activation layers as supporting torch.compile

Files changed (1):
  1. torch-ext/activation/layers.py (+14 −0)
torch-ext/activation/layers.py CHANGED

```diff
@@ -5,6 +5,8 @@ from ._ops import ops
 
 
 class SiluAndMul(nn.Module):
+    can_torch_compile: bool = True
+
     def forward(self, x: torch.Tensor):
         d = x.shape[-1] // 2
         output_shape = x.shape[:-1] + (d,)
@@ -14,6 +16,8 @@ class SiluAndMul(nn.Module):
 
 
 class GeluAndMul(nn.Module):
+    can_torch_compile: bool = True
+
     def forward(self, x: torch.Tensor):
         d = x.shape[-1] // 2
         output_shape = x.shape[:-1] + (d,)
@@ -23,6 +27,8 @@ class GeluAndMul(nn.Module):
 
 
 class GeluTanhAndMul(nn.Module):
+    can_torch_compile: bool = True
+
     def forward(self, x: torch.Tensor):
         d = x.shape[-1] // 2
         output_shape = x.shape[:-1] + (d,)
@@ -32,6 +38,8 @@ class GeluTanhAndMul(nn.Module):
 
 
 class FatreluAndMul(nn.Module):
+    can_torch_compile: bool = True
+
     def __init__(self, threshold: float = 0.0):
         super().__init__()
         self.threshold = threshold
@@ -45,6 +53,8 @@ class FatreluAndMul(nn.Module):
 
 
 class FastGELU(nn.Module):
+    can_torch_compile: bool = True
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         out = torch.empty_like(x)
         ops.gelu_fast(out, x)
@@ -52,6 +62,8 @@ class FastGELU(nn.Module):
 
 
 class NewGELU(nn.Module):
+    can_torch_compile: bool = True
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         out = torch.empty_like(x)
         ops.gelu_new(out, x)
@@ -59,6 +71,8 @@ class NewGELU(nn.Module):
 
 
 class QuickGELU(nn.Module):
+    can_torch_compile: bool = True
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         out = torch.empty_like(x)
         ops.gelu_quick(out, x)
```
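In context: each activation layer now exposes a `can_torch_compile` class attribute, so downstream code can check the flag before wrapping a layer in `torch.compile` and fall back to eager execution otherwise. A minimal sketch of that pattern, assuming the layers are importable as shown and using a hypothetical `maybe_compile` helper (the helper name, the import path, and the CUDA device are illustration-only assumptions, not part of this commit):

```python
import torch

from activation.layers import SiluAndMul  # import path assumed for illustration

def maybe_compile(module: torch.nn.Module) -> torch.nn.Module:
    # Hypothetical helper: only compile layers that opt in via the
    # can_torch_compile flag added in this commit; others stay eager.
    if getattr(module, "can_torch_compile", False):
        return torch.compile(module)
    return module

act = maybe_compile(SiluAndMul())       # compiled: the flag is True
x = torch.randn(4, 256, device="cuda")  # the custom kernels assume a CUDA tensor
y = act(x)                              # SiluAndMul halves the last dim: (4, 128)
```

A plain class attribute keeps the check cheap and backwards compatible: `getattr` with a `False` default means layers that predate the flag are simply left uncompiled.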