Update ultravox_model.py
Browse files- ultravox_model.py +4 -1
ultravox_model.py
CHANGED
@@ -412,7 +412,10 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
|
|
412 |
cls, config: UltravoxConfig
|
413 |
) -> "UltravoxProjector":
|
414 |
projector = UltravoxProjector(config)
|
415 |
-
|
|
|
|
|
|
|
416 |
return projector
|
417 |
|
418 |
@classmethod
|
|
|
412 |
cls, config: UltravoxConfig
|
413 |
) -> "UltravoxProjector":
|
414 |
projector = UltravoxProjector(config)
|
415 |
+
dtype = config.torch_dtype
|
416 |
+
if isinstance(dtype, str):
|
417 |
+
dtype = getattr(torch, dtype)
|
418 |
+
projector.to(dtype)
|
419 |
return projector
|
420 |
|
421 |
@classmethod
|