Update modeling_spark_tts.py
Browse files- modeling_spark_tts.py +64 -74
modeling_spark_tts.py
CHANGED
@@ -3011,7 +3011,6 @@ class SparkTTSModel(PreTrainedModel, GenerationMixin):
|
|
3011 |
return outputs # Should be CausalLMOutputWithPast or tuple
|
3012 |
|
3013 |
@classmethod
|
3014 |
-
@torch.no_grad() # Decorator often used for loading, though internal ops might need grads later
|
3015 |
def from_pretrained(
|
3016 |
cls,
|
3017 |
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
|
@@ -3021,46 +3020,39 @@ class SparkTTSModel(PreTrainedModel, GenerationMixin):
|
|
3021 |
ignore_mismatched_sizes: bool = False,
|
3022 |
force_download: bool = False,
|
3023 |
local_files_only: bool = False,
|
3024 |
-
token: Optional[Union[
|
3025 |
revision: str = "main",
|
3026 |
use_safetensors: Optional[bool] = None,
|
3027 |
# New args from base class signature to pass down if relevant
|
3028 |
-
state_dict = None,
|
3029 |
-
device_map = None,
|
3030 |
-
low_cpu_mem_usage = None,
|
3031 |
-
torch_dtype = "auto",
|
3032 |
-
quantization_config = None,
|
3033 |
-
trust_remote_code = None,
|
3034 |
# Add other relevant args from base class if needed: subfolder, variant, etc.
|
3035 |
-
subfolder: str = "",
|
3036 |
variant: Optional[str] = None,
|
3037 |
**kwargs,
|
3038 |
):
|
3039 |
# --- Argument Handling & Initial Setup ---
|
3040 |
-
# Pop device map and dtype early - handle placement later
|
3041 |
if device_map:
|
3042 |
logger.warning("`device_map` is not directly supported for this composite model. Use .to(device) after loading.")
|
3043 |
if low_cpu_mem_usage:
|
3044 |
logger.info("`low_cpu_mem_usage` is set, but simplified loading is used. Memory usage might not be optimized.")
|
3045 |
-
|
3046 |
-
# Handle trust_remote_code explicitly for custom code loading
|
3047 |
if trust_remote_code is None:
|
3048 |
-
logger.warning(
|
3049 |
-
"Loading SparkTTSModel requires custom code. Setting `trust_remote_code=True`. "
|
3050 |
-
"Make sure you trust the source of the code you are loading."
|
3051 |
-
)
|
3052 |
trust_remote_code = True
|
3053 |
elif not trust_remote_code:
|
3054 |
raise ValueError("Loading SparkTTSModel requires `trust_remote_code=True`.")
|
3055 |
|
3056 |
-
# Pop unused kwargs specific to base class loading logic if not handled here
|
3057 |
kwargs.pop("output_loading_info", None)
|
3058 |
kwargs.pop("_from_auto", None)
|
3059 |
-
kwargs.pop("attn_implementation", None)
|
3060 |
|
3061 |
# --- 1. Resolve the main model directory ---
|
3062 |
if state_dict is not None:
|
3063 |
-
raise ValueError("Explicitly passing `state_dict` is not supported for this composite model.
|
3064 |
if pretrained_model_name_or_path is None:
|
3065 |
raise ValueError("`pretrained_model_name_or_path` must be provided.")
|
3066 |
|
@@ -3075,6 +3067,7 @@ class SparkTTSModel(PreTrainedModel, GenerationMixin):
|
|
3075 |
logger.info(f"{pretrained_model_name_or_path} is not a local directory. Assuming Hub ID and downloading.")
|
3076 |
try:
|
3077 |
# Use snapshot_download to get all necessary files
|
|
|
3078 |
resolved_model_path_str = snapshot_download(
|
3079 |
repo_id=str(pretrained_model_name_or_path),
|
3080 |
cache_dir=cache_dir,
|
@@ -3082,82 +3075,85 @@ class SparkTTSModel(PreTrainedModel, GenerationMixin):
|
|
3082 |
local_files_only=local_files_only,
|
3083 |
token=token,
|
3084 |
revision=revision,
|
3085 |
-
allow_patterns=[
|
3086 |
"*.json", "*.safetensors", "*.bin", "*.yaml", "*.txt",
|
3087 |
-
"README.md", ".gitattributes",
|
3088 |
-
"LLM/*", "BiCodec/*", "wav2vec2-large-xlsr-53/*"
|
3089 |
],
|
3090 |
-
ignore_patterns=["*.git*", "*.h5", "*.ot", "*.msgpack"],
|
3091 |
-
|
3092 |
-
|
|
|
3093 |
)
|
3094 |
resolved_model_path = Path(resolved_model_path_str)
|
3095 |
logger.info(f"Model files downloaded to cache: {resolved_model_path}")
|
3096 |
except Exception as e:
|
|
|
|
|
|
|
3097 |
raise OSError(
|
3098 |
-
f"Failed to download model '{pretrained_model_name_or_path}' (
|
3099 |
f"Error: {e}"
|
3100 |
)
|
3101 |
|
3102 |
if not resolved_model_path.is_dir():
|
3103 |
raise EnvironmentError(f"Resolved model path is not a directory: {resolved_model_path}")
|
3104 |
|
3105 |
-
# If subfolder
|
3106 |
if subfolder:
|
3107 |
-
|
3108 |
-
if not
|
3109 |
-
raise EnvironmentError(f"Subfolder '{subfolder}' not found within the resolved path: {resolved_model_path
|
|
|
|
|
3110 |
|
3111 |
|
3112 |
# --- 2. Load the main configuration ---
|
3113 |
if not isinstance(config, PretrainedConfig):
|
|
|
3114 |
config_path = config if config is not None else resolved_model_path
|
3115 |
try:
|
3116 |
loaded_config, model_kwargs = SparkTTSConfig.from_pretrained(
|
3117 |
-
config_path,
|
3118 |
-
*model_args,
|
3119 |
cache_dir=cache_dir,
|
3120 |
force_download=force_download if not is_local else False,
|
3121 |
local_files_only=local_files_only or is_local,
|
3122 |
token=token,
|
3123 |
-
revision=revision,
|
3124 |
-
trust_remote_code=trust_remote_code,
|
3125 |
-
|
3126 |
return_unused_kwargs=True,
|
3127 |
-
**kwargs,
|
3128 |
)
|
3129 |
config = loaded_config
|
3130 |
-
kwargs = model_kwargs
|
3131 |
except OSError as e:
|
3132 |
-
raise OSError(f"Cannot load config
|
3133 |
-
# else: config object was passed directly
|
3134 |
|
3135 |
# --- Determine final torch_dtype ---
|
3136 |
-
final_torch_dtype = torch_dtype
|
3137 |
if final_torch_dtype == "auto":
|
3138 |
-
final_torch_dtype = getattr(config, "torch_dtype", None)
|
3139 |
-
# Convert string to torch.dtype object if needed
|
3140 |
if isinstance(final_torch_dtype, str) and final_torch_dtype != "auto":
|
3141 |
try:
|
3142 |
final_torch_dtype = getattr(torch, final_torch_dtype)
|
3143 |
except AttributeError:
|
3144 |
logger.warning(f"Invalid torch_dtype string: {final_torch_dtype}. Falling back to default.")
|
3145 |
-
final_torch_dtype = None
|
3146 |
elif final_torch_dtype == "auto":
|
3147 |
-
final_torch_dtype = None
|
3148 |
|
3149 |
-
# --- Helper function to resolve paths relative to the
|
3150 |
-
# (This handles components potentially being in subfolders specified in config)
|
3151 |
def _resolve_sub_path(sub_path_str):
|
3152 |
p = Path(sub_path_str)
|
3153 |
if p.is_absolute():
|
3154 |
if not p.exists(): logger.warning(f"Absolute path specified for sub-component does not exist: {p}")
|
3155 |
return str(p)
|
3156 |
else:
|
3157 |
-
# Resolve relative to the main model path
|
3158 |
resolved = resolved_model_path / p
|
3159 |
if not resolved.exists():
|
3160 |
-
# Check if the path exists without the leading './' often found in configs
|
3161 |
resolved_alt = resolved_model_path / sub_path_str.lstrip('./')
|
3162 |
if resolved_alt.exists():
|
3163 |
resolved = resolved_alt
|
@@ -3171,26 +3167,24 @@ class SparkTTSModel(PreTrainedModel, GenerationMixin):
|
|
3171 |
"force_download": force_download,
|
3172 |
"local_files_only": local_files_only,
|
3173 |
"token": token,
|
3174 |
-
"revision": revision,
|
3175 |
-
"trust_remote_code": trust_remote_code,
|
3176 |
-
"torch_dtype": final_torch_dtype,
|
3177 |
"use_safetensors": use_safetensors,
|
3178 |
-
# Pass quantization config if provided and relevant to component
|
3179 |
"quantization_config": quantization_config if quantization_config else None,
|
3180 |
-
# Pass variant if needed for specific component checkpoints
|
3181 |
"variant": variant,
|
3182 |
-
|
3183 |
-
**kwargs,
|
3184 |
}
|
3185 |
|
3186 |
# --- 3. Load Sub-components ---
|
3187 |
-
|
3188 |
# --- Load LLM ---
|
3189 |
llm_path = _resolve_sub_path(config.llm_model_name_or_path)
|
3190 |
logger.info(f"Loading LLM from resolved path: {llm_path}")
|
3191 |
try:
|
|
|
3192 |
llm = AutoModelForCausalLM.from_pretrained(
|
3193 |
-
llm_path, **component_loading_kwargs
|
3194 |
)
|
3195 |
except Exception as e:
|
3196 |
raise OSError(f"Failed to load LLM from {llm_path}: {e}")
|
@@ -3199,47 +3193,46 @@ class SparkTTSModel(PreTrainedModel, GenerationMixin):
|
|
3199 |
w2v_path = _resolve_sub_path(config.wav2vec2_model_name_or_path)
|
3200 |
logger.info(f"Loading Wav2Vec2 components from resolved path: {w2v_path}")
|
3201 |
try:
|
3202 |
-
#
|
3203 |
wav2vec2_processor = Wav2Vec2FeatureExtractor.from_pretrained(
|
3204 |
w2v_path,
|
3205 |
-
cache_dir=cache_dir,
|
3206 |
force_download=force_download,
|
3207 |
local_files_only=local_files_only,
|
3208 |
token=token,
|
3209 |
revision=revision,
|
3210 |
-
#
|
3211 |
)
|
|
|
3212 |
wav2vec2_model = Wav2Vec2Model.from_pretrained(
|
3213 |
-
w2v_path, **component_loading_kwargs
|
3214 |
)
|
3215 |
-
wav2vec2_model.config.output_hidden_states = True
|
3216 |
except Exception as e:
|
3217 |
raise OSError(f"Failed to load Wav2Vec2 components from {w2v_path}: {e}")
|
3218 |
|
3219 |
# --- Load BiCodec ---
|
3220 |
bicodec_path = _resolve_sub_path(config.bicodec_model_name_or_path)
|
3221 |
logger.info(f"Loading BiCodec from resolved path: {bicodec_path}")
|
3222 |
-
if not config.bicodec_config:
|
3223 |
-
raise ValueError("BiCodec configuration (`bicodec_config`) not found
|
3224 |
try:
|
3225 |
-
# Pass the SparkTTSBiCodecConfig *object* directly
|
3226 |
bicodec = BiCodec.load_from_config_and_checkpoint(
|
3227 |
model_dir=Path(bicodec_path),
|
3228 |
-
bicodec_config_object=config.bicodec_config
|
3229 |
)
|
3230 |
if not isinstance(bicodec, torch.nn.Module):
|
3231 |
logger.warning("Loaded BiCodec component is not an instance of torch.nn.Module.")
|
3232 |
-
# Apply torch_dtype to BiCodec if it's an nn.Module and dtype is set
|
3233 |
if isinstance(bicodec, torch.nn.Module) and final_torch_dtype:
|
3234 |
bicodec = bicodec.to(dtype=final_torch_dtype)
|
3235 |
-
|
3236 |
except FileNotFoundError as e:
|
3237 |
-
raise OSError(f"Failed to load BiCodec:
|
3238 |
except Exception as e:
|
3239 |
logger.error(f"Raw error loading BiCodec: {type(e).__name__}: {e}")
|
3240 |
import traceback
|
3241 |
traceback.print_exc()
|
3242 |
-
raise OSError(f"Failed to load BiCodec from {bicodec_path}.
|
|
|
3243 |
|
3244 |
# --- 4. Instantiate the main model wrapper ---
|
3245 |
model = cls(
|
@@ -3251,20 +3244,17 @@ class SparkTTSModel(PreTrainedModel, GenerationMixin):
|
|
3251 |
)
|
3252 |
|
3253 |
# --- 5. Handle device placement (Simplified) ---
|
3254 |
-
# Determine target device (simple logic: CUDA > MPS > CPU)
|
3255 |
if torch.cuda.is_available():
|
3256 |
final_device = torch.device("cuda")
|
3257 |
-
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
3258 |
final_device = torch.device("mps")
|
3259 |
else:
|
3260 |
final_device = torch.device("cpu")
|
3261 |
-
|
3262 |
logger.info(f"Placing SparkTTSModel and components on device: {final_device}")
|
3263 |
try:
|
3264 |
model.to(final_device)
|
3265 |
except Exception as e:
|
3266 |
logger.error(f"Failed to move model to device {final_device}. Error: {e}")
|
3267 |
-
logger.warning("Device placement might be incomplete. Check component types and implementations.")
|
3268 |
|
3269 |
# --- 6. Return the loaded and prepared model ---
|
3270 |
return model
|
|
|
3011 |
return outputs # Should be CausalLMOutputWithPast or tuple
|
3012 |
|
3013 |
@classmethod
|
|
|
3014 |
def from_pretrained(
|
3015 |
cls,
|
3016 |
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
|
|
|
3020 |
ignore_mismatched_sizes: bool = False,
|
3021 |
force_download: bool = False,
|
3022 |
local_files_only: bool = False,
|
3023 |
+
token: Optional[Union[bool, str]] = None,
|
3024 |
revision: str = "main",
|
3025 |
use_safetensors: Optional[bool] = None,
|
3026 |
# New args from base class signature to pass down if relevant
|
3027 |
+
state_dict = None,
|
3028 |
+
device_map = None,
|
3029 |
+
low_cpu_mem_usage = None,
|
3030 |
+
torch_dtype = "auto",
|
3031 |
+
quantization_config = None,
|
3032 |
+
trust_remote_code = None,
|
3033 |
# Add other relevant args from base class if needed: subfolder, variant, etc.
|
3034 |
+
subfolder: str = "", # Keep subfolder arg for overall loading logic
|
3035 |
variant: Optional[str] = None,
|
3036 |
**kwargs,
|
3037 |
):
|
3038 |
# --- Argument Handling & Initial Setup ---
|
|
|
3039 |
if device_map:
|
3040 |
logger.warning("`device_map` is not directly supported for this composite model. Use .to(device) after loading.")
|
3041 |
if low_cpu_mem_usage:
|
3042 |
logger.info("`low_cpu_mem_usage` is set, but simplified loading is used. Memory usage might not be optimized.")
|
|
|
|
|
3043 |
if trust_remote_code is None:
|
3044 |
+
logger.warning("Loading SparkTTSModel requires custom code. Setting `trust_remote_code=True`.")
|
|
|
|
|
|
|
3045 |
trust_remote_code = True
|
3046 |
elif not trust_remote_code:
|
3047 |
raise ValueError("Loading SparkTTSModel requires `trust_remote_code=True`.")
|
3048 |
|
|
|
3049 |
kwargs.pop("output_loading_info", None)
|
3050 |
kwargs.pop("_from_auto", None)
|
3051 |
+
kwargs.pop("attn_implementation", None)
|
3052 |
|
3053 |
# --- 1. Resolve the main model directory ---
|
3054 |
if state_dict is not None:
|
3055 |
+
raise ValueError("Explicitly passing `state_dict` is not supported for this composite model.")
|
3056 |
if pretrained_model_name_or_path is None:
|
3057 |
raise ValueError("`pretrained_model_name_or_path` must be provided.")
|
3058 |
|
|
|
3067 |
logger.info(f"{pretrained_model_name_or_path} is not a local directory. Assuming Hub ID and downloading.")
|
3068 |
try:
|
3069 |
# Use snapshot_download to get all necessary files
|
3070 |
+
# REMOVED subfolder=subfolder from this call
|
3071 |
resolved_model_path_str = snapshot_download(
|
3072 |
repo_id=str(pretrained_model_name_or_path),
|
3073 |
cache_dir=cache_dir,
|
|
|
3075 |
local_files_only=local_files_only,
|
3076 |
token=token,
|
3077 |
revision=revision,
|
3078 |
+
allow_patterns=[
|
3079 |
"*.json", "*.safetensors", "*.bin", "*.yaml", "*.txt",
|
3080 |
+
"README.md", ".gitattributes",
|
3081 |
+
"LLM/*", "BiCodec/*", "wav2vec2-large-xlsr-53/*"
|
3082 |
],
|
3083 |
+
ignore_patterns=["*.git*", "*.h5", "*.ot", "*.msgpack"],
|
3084 |
+
repo_type="model", # Explicitly set repo_type
|
3085 |
+
# max_workers=..., # Can adjust workers if needed
|
3086 |
+
# user_agent=..., # Can add user agent
|
3087 |
)
|
3088 |
resolved_model_path = Path(resolved_model_path_str)
|
3089 |
logger.info(f"Model files downloaded to cache: {resolved_model_path}")
|
3090 |
except Exception as e:
|
3091 |
+
# Catch potential TypeErrors from snapshot_download if args change again
|
3092 |
+
if isinstance(e, TypeError) and 'unexpected keyword argument' in str(e):
|
3093 |
+
logger.error(f"snapshot_download() received an unexpected keyword argument. Check huggingface_hub version compatibility. Error: {e}")
|
3094 |
raise OSError(
|
3095 |
+
f"Failed to download model '{pretrained_model_name_or_path}' (revision: '{revision}') from Hugging Face Hub. "
|
3096 |
f"Error: {e}"
|
3097 |
)
|
3098 |
|
3099 |
if not resolved_model_path.is_dir():
|
3100 |
raise EnvironmentError(f"Resolved model path is not a directory: {resolved_model_path}")
|
3101 |
|
3102 |
+
# If subfolder was specified for from_pretrained, adjust the path *after* download
|
3103 |
if subfolder:
|
3104 |
+
resolved_model_path_with_subfolder = resolved_model_path / subfolder
|
3105 |
+
if not resolved_model_path_with_subfolder.is_dir():
|
3106 |
+
raise EnvironmentError(f"Subfolder '{subfolder}' not found within the resolved path: {resolved_model_path}")
|
3107 |
+
resolved_model_path = resolved_model_path_with_subfolder # Update path to include subfolder
|
3108 |
+
logger.info(f"Using subfolder within resolved path: {resolved_model_path}")
|
3109 |
|
3110 |
|
3111 |
# --- 2. Load the main configuration ---
|
3112 |
if not isinstance(config, PretrainedConfig):
|
3113 |
+
# Load config from the potentially subfolder-adjusted path
|
3114 |
config_path = config if config is not None else resolved_model_path
|
3115 |
try:
|
3116 |
loaded_config, model_kwargs = SparkTTSConfig.from_pretrained(
|
3117 |
+
config_path, # Load from the final resolved path
|
3118 |
+
*model_args,
|
3119 |
cache_dir=cache_dir,
|
3120 |
force_download=force_download if not is_local else False,
|
3121 |
local_files_only=local_files_only or is_local,
|
3122 |
token=token,
|
3123 |
+
revision=revision, # Pass revision for config loading too
|
3124 |
+
trust_remote_code=trust_remote_code,
|
3125 |
+
subfolder="", # Config is expected at the root of resolved_model_path
|
3126 |
return_unused_kwargs=True,
|
3127 |
+
**kwargs,
|
3128 |
)
|
3129 |
config = loaded_config
|
3130 |
+
kwargs = model_kwargs
|
3131 |
except OSError as e:
|
3132 |
+
raise OSError(f"Cannot load config from {config_path}. Check `config.json` exists and is correctly formatted. Error: {e}")
|
|
|
3133 |
|
3134 |
# --- Determine final torch_dtype ---
|
3135 |
+
final_torch_dtype = torch_dtype
|
3136 |
if final_torch_dtype == "auto":
|
3137 |
+
final_torch_dtype = getattr(config, "torch_dtype", None)
|
|
|
3138 |
if isinstance(final_torch_dtype, str) and final_torch_dtype != "auto":
|
3139 |
try:
|
3140 |
final_torch_dtype = getattr(torch, final_torch_dtype)
|
3141 |
except AttributeError:
|
3142 |
logger.warning(f"Invalid torch_dtype string: {final_torch_dtype}. Falling back to default.")
|
3143 |
+
final_torch_dtype = None
|
3144 |
elif final_torch_dtype == "auto":
|
3145 |
+
final_torch_dtype = None
|
3146 |
|
3147 |
+
# --- Helper function to resolve component paths relative to the final resolved_model_path ---
|
|
|
3148 |
def _resolve_sub_path(sub_path_str):
|
3149 |
p = Path(sub_path_str)
|
3150 |
if p.is_absolute():
|
3151 |
if not p.exists(): logger.warning(f"Absolute path specified for sub-component does not exist: {p}")
|
3152 |
return str(p)
|
3153 |
else:
|
3154 |
+
# Resolve relative to the potentially subfolder-adjusted main model path
|
3155 |
resolved = resolved_model_path / p
|
3156 |
if not resolved.exists():
|
|
|
3157 |
resolved_alt = resolved_model_path / sub_path_str.lstrip('./')
|
3158 |
if resolved_alt.exists():
|
3159 |
resolved = resolved_alt
|
|
|
3167 |
"force_download": force_download,
|
3168 |
"local_files_only": local_files_only,
|
3169 |
"token": token,
|
3170 |
+
"revision": revision, # Pass revision to component loaders
|
3171 |
+
"trust_remote_code": trust_remote_code,
|
3172 |
+
"torch_dtype": final_torch_dtype,
|
3173 |
"use_safetensors": use_safetensors,
|
|
|
3174 |
"quantization_config": quantization_config if quantization_config else None,
|
|
|
3175 |
"variant": variant,
|
3176 |
+
**kwargs, # Pass remaining kwargs
|
|
|
3177 |
}
|
3178 |
|
3179 |
# --- 3. Load Sub-components ---
|
3180 |
+
# (LLM, Wav2Vec2, BiCodec loading logic remains the same as previous version)
|
3181 |
# --- Load LLM ---
|
3182 |
llm_path = _resolve_sub_path(config.llm_model_name_or_path)
|
3183 |
logger.info(f"Loading LLM from resolved path: {llm_path}")
|
3184 |
try:
|
3185 |
+
# Pass subfolder="" because llm_path is now absolute or correctly relative
|
3186 |
llm = AutoModelForCausalLM.from_pretrained(
|
3187 |
+
llm_path, subfolder="", **component_loading_kwargs
|
3188 |
)
|
3189 |
except Exception as e:
|
3190 |
raise OSError(f"Failed to load LLM from {llm_path}: {e}")
|
|
|
3193 |
w2v_path = _resolve_sub_path(config.wav2vec2_model_name_or_path)
|
3194 |
logger.info(f"Loading Wav2Vec2 components from resolved path: {w2v_path}")
|
3195 |
try:
|
3196 |
+
# Load extractor without full component_loading_kwargs if they cause issues
|
3197 |
wav2vec2_processor = Wav2Vec2FeatureExtractor.from_pretrained(
|
3198 |
w2v_path,
|
3199 |
+
cache_dir=cache_dir,
|
3200 |
force_download=force_download,
|
3201 |
local_files_only=local_files_only,
|
3202 |
token=token,
|
3203 |
revision=revision,
|
3204 |
+
subfolder="", # Path is resolved
|
3205 |
)
|
3206 |
+
# Load model with full kwargs
|
3207 |
wav2vec2_model = Wav2Vec2Model.from_pretrained(
|
3208 |
+
w2v_path, subfolder="", **component_loading_kwargs
|
3209 |
)
|
3210 |
+
wav2vec2_model.config.output_hidden_states = True
|
3211 |
except Exception as e:
|
3212 |
raise OSError(f"Failed to load Wav2Vec2 components from {w2v_path}: {e}")
|
3213 |
|
3214 |
# --- Load BiCodec ---
|
3215 |
bicodec_path = _resolve_sub_path(config.bicodec_model_name_or_path)
|
3216 |
logger.info(f"Loading BiCodec from resolved path: {bicodec_path}")
|
3217 |
+
if not config.bicodec_config:
|
3218 |
+
raise ValueError("BiCodec configuration (`bicodec_config`) not found in SparkTTSConfig.")
|
3219 |
try:
|
|
|
3220 |
bicodec = BiCodec.load_from_config_and_checkpoint(
|
3221 |
model_dir=Path(bicodec_path),
|
3222 |
+
bicodec_config_object=config.bicodec_config
|
3223 |
)
|
3224 |
if not isinstance(bicodec, torch.nn.Module):
|
3225 |
logger.warning("Loaded BiCodec component is not an instance of torch.nn.Module.")
|
|
|
3226 |
if isinstance(bicodec, torch.nn.Module) and final_torch_dtype:
|
3227 |
bicodec = bicodec.to(dtype=final_torch_dtype)
|
|
|
3228 |
except FileNotFoundError as e:
|
3229 |
+
raise OSError(f"Failed to load BiCodec: Required file not found in {bicodec_path}. Error: {e}")
|
3230 |
except Exception as e:
|
3231 |
logger.error(f"Raw error loading BiCodec: {type(e).__name__}: {e}")
|
3232 |
import traceback
|
3233 |
traceback.print_exc()
|
3234 |
+
raise OSError(f"Failed to load BiCodec from {bicodec_path}. Error: {e}")
|
3235 |
+
|
3236 |
|
3237 |
# --- 4. Instantiate the main model wrapper ---
|
3238 |
model = cls(
|
|
|
3244 |
)
|
3245 |
|
3246 |
# --- 5. Handle device placement (Simplified) ---
|
|
|
3247 |
if torch.cuda.is_available():
|
3248 |
final_device = torch.device("cuda")
|
3249 |
+
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
3250 |
final_device = torch.device("mps")
|
3251 |
else:
|
3252 |
final_device = torch.device("cpu")
|
|
|
3253 |
logger.info(f"Placing SparkTTSModel and components on device: {final_device}")
|
3254 |
try:
|
3255 |
model.to(final_device)
|
3256 |
except Exception as e:
|
3257 |
logger.error(f"Failed to move model to device {final_device}. Error: {e}")
|
|
|
3258 |
|
3259 |
# --- 6. Return the loaded and prepared model ---
|
3260 |
return model
|