Update modeling_xlm_roberta.py

#13
Files changed (1) hide show
  1. modeling_xlm_roberta.py +3 -0
modeling_xlm_roberta.py CHANGED
@@ -1168,6 +1168,9 @@ class XLMRobertaClassificationHead(nn.Module):
1168
 
1169
  def __init__(self, config):
1170
  super().__init__()
 
 
 
1171
  linear_cls = nn.Linear if not fused_bias_fc else FusedDense
1172
  self.dense = linear_cls(config.hidden_size, config.hidden_size)
1173
  classifier_dropout = (
 
1168
 
1169
  def __init__(self, config):
1170
  super().__init__()
1171
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
1172
+ if fused_bias_fc and FusedDense is None:
1173
+ raise ImportError("fused_dense is not installed")
1174
  linear_cls = nn.Linear if not fused_bias_fc else FusedDense
1175
  self.dense = linear_cls(config.hidden_size, config.hidden_size)
1176
  classifier_dropout = (