Update modeling_xlm_roberta.py
#13
by
bwang0911
- opened
- modeling_xlm_roberta.py +3 -0
modeling_xlm_roberta.py
CHANGED
@@ -1168,6 +1168,9 @@ class XLMRobertaClassificationHead(nn.Module):
|
|
1168 |
|
1169 |
def __init__(self, config):
|
1170 |
super().__init__()
|
|
|
|
|
|
|
1171 |
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
|
1172 |
self.dense = linear_cls(config.hidden_size, config.hidden_size)
|
1173 |
classifier_dropout = (
|
|
|
1168 |
|
1169 |
def __init__(self, config):
|
1170 |
super().__init__()
|
1171 |
+
fused_bias_fc = getattr(config, "fused_bias_fc", False)
|
1172 |
+
if fused_bias_fc and FusedDense is None:
|
1173 |
+
raise ImportError("fused_dense is not installed")
|
1174 |
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
|
1175 |
self.dense = linear_cls(config.hidden_size, config.hidden_size)
|
1176 |
classifier_dropout = (
|