farzadab commited on
Commit
b4f3a7b
·
verified ·
1 Parent(s): e91d05d

Update ultravox_model.py

Browse files
Files changed (1) hide show
  1. ultravox_model.py +4 -1
ultravox_model.py CHANGED
@@ -412,7 +412,10 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
412
  cls, config: UltravoxConfig
413
  ) -> "UltravoxProjector":
414
  projector = UltravoxProjector(config)
415
- projector.to(config.torch_dtype)
 
 
 
416
  return projector
417
 
418
  @classmethod
 
412
  cls, config: UltravoxConfig
413
  ) -> "UltravoxProjector":
414
  projector = UltravoxProjector(config)
415
+ dtype = config.torch_dtype
416
+ if isinstance(dtype, str):
417
+ dtype = getattr(torch, dtype)
418
+ projector.to(dtype)
419
  return projector
420
 
421
  @classmethod