farzadab commited on
Commit
bb86f58
·
verified ·
1 Parent(s): 7928241

Update ultravox_model.py

Browse files
Files changed (1) hide show
  1. ultravox_model.py +4 -1
ultravox_model.py CHANGED
@@ -76,6 +76,9 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
76
  return model
77
 
78
  def _load_child_model_weights(self, *args, **kwargs) -> "UltravoxModel":
 
 
 
79
  if (
80
  self.config.text_model_id is not None
81
  and self.language_model.device.type == "meta"
@@ -850,4 +853,4 @@ UltravoxModel.register_for_auto_class()
850
  transformers.AutoConfig.register("ultravox", UltravoxConfig)
851
  transformers.AutoModel.register(UltravoxConfig, UltravoxModel)
852
 
853
- transformers.activations.ACT2FN["swiglu"] = SwiGLU
 
76
  return model
77
 
78
  def _load_child_model_weights(self, *args, **kwargs) -> "UltravoxModel":
79
+ if "torch_dtype" in kwargs:
80
+ self.config.torch_dtype = kwargs.pop("torch_dtype")
81
+
82
  if (
83
  self.config.text_model_id is not None
84
  and self.language_model.device.type == "meta"
 
853
  transformers.AutoConfig.register("ultravox", UltravoxConfig)
854
  transformers.AutoModel.register(UltravoxConfig, UltravoxModel)
855
 
856
+ transformers.activations.ACT2FN["swiglu"] = SwiGLU