ancv committed on
Commit
97682b0
·
verified ·
1 Parent(s): 2143f77

Update modeling_spark_tts.py

Browse files
Files changed (1) hide show
  1. modeling_spark_tts.py +64 -74
modeling_spark_tts.py CHANGED
@@ -3011,7 +3011,6 @@ class SparkTTSModel(PreTrainedModel, GenerationMixin):
3011
  return outputs # Should be CausalLMOutputWithPast or tuple
3012
 
3013
  @classmethod
3014
- @torch.no_grad() # Decorator often used for loading, though internal ops might need grads later
3015
  def from_pretrained(
3016
  cls,
3017
  pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
@@ -3021,46 +3020,39 @@ class SparkTTSModel(PreTrainedModel, GenerationMixin):
3021
  ignore_mismatched_sizes: bool = False,
3022
  force_download: bool = False,
3023
  local_files_only: bool = False,
3024
- token: Optional[Union[str, bool]] = None,
3025
  revision: str = "main",
3026
  use_safetensors: Optional[bool] = None,
3027
  # New args from base class signature to pass down if relevant
3028
- state_dict = None, # Pass state_dict explicitly is usually avoided with component loading
3029
- device_map = None, # Simplified handling
3030
- low_cpu_mem_usage = None, # Simplified handling
3031
- torch_dtype = "auto", # Keep "auto" as default
3032
- quantization_config = None, # Pass down if needed by components
3033
- trust_remote_code = None, # Default to None, will be set below
3034
  # Add other relevant args from base class if needed: subfolder, variant, etc.
3035
- subfolder: str = "",
3036
  variant: Optional[str] = None,
3037
  **kwargs,
3038
  ):
3039
  # --- Argument Handling & Initial Setup ---
3040
- # Pop device map and dtype early - handle placement later
3041
  if device_map:
3042
  logger.warning("`device_map` is not directly supported for this composite model. Use .to(device) after loading.")
3043
  if low_cpu_mem_usage:
3044
  logger.info("`low_cpu_mem_usage` is set, but simplified loading is used. Memory usage might not be optimized.")
3045
-
3046
- # Handle trust_remote_code explicitly for custom code loading
3047
  if trust_remote_code is None:
3048
- logger.warning(
3049
- "Loading SparkTTSModel requires custom code. Setting `trust_remote_code=True`. "
3050
- "Make sure you trust the source of the code you are loading."
3051
- )
3052
  trust_remote_code = True
3053
  elif not trust_remote_code:
3054
  raise ValueError("Loading SparkTTSModel requires `trust_remote_code=True`.")
3055
 
3056
- # Pop unused kwargs specific to base class loading logic if not handled here
3057
  kwargs.pop("output_loading_info", None)
3058
  kwargs.pop("_from_auto", None)
3059
- kwargs.pop("attn_implementation", None) # LLM loader might handle this
3060
 
3061
  # --- 1. Resolve the main model directory ---
3062
  if state_dict is not None:
3063
- raise ValueError("Explicitly passing `state_dict` is not supported for this composite model. Load components individually if needed.")
3064
  if pretrained_model_name_or_path is None:
3065
  raise ValueError("`pretrained_model_name_or_path` must be provided.")
3066
 
@@ -3075,6 +3067,7 @@ class SparkTTSModel(PreTrainedModel, GenerationMixin):
3075
  logger.info(f"{pretrained_model_name_or_path} is not a local directory. Assuming Hub ID and downloading.")
3076
  try:
3077
  # Use snapshot_download to get all necessary files
 
3078
  resolved_model_path_str = snapshot_download(
3079
  repo_id=str(pretrained_model_name_or_path),
3080
  cache_dir=cache_dir,
@@ -3082,82 +3075,85 @@ class SparkTTSModel(PreTrainedModel, GenerationMixin):
3082
  local_files_only=local_files_only,
3083
  token=token,
3084
  revision=revision,
3085
- allow_patterns=[ # Be more specific if possible
3086
  "*.json", "*.safetensors", "*.bin", "*.yaml", "*.txt",
3087
- "README.md", ".gitattributes", # Common files
3088
- "LLM/*", "BiCodec/*", "wav2vec2-large-xlsr-53/*" # Component folders
3089
  ],
3090
- ignore_patterns=["*.git*", "*.h5", "*.ot", "*.msgpack"], # Ignore unnecessary files
3091
- subfolder=subfolder, # Pass subfolder to snapshot_download
3092
- repo_type="model", # Specify repo type
 
3093
  )
3094
  resolved_model_path = Path(resolved_model_path_str)
3095
  logger.info(f"Model files downloaded to cache: {resolved_model_path}")
3096
  except Exception as e:
 
 
 
3097
  raise OSError(
3098
- f"Failed to download model '{pretrained_model_name_or_path}' (subfolder: '{subfolder}') from Hugging Face Hub. "
3099
  f"Error: {e}"
3100
  )
3101
 
3102
  if not resolved_model_path.is_dir():
3103
  raise EnvironmentError(f"Resolved model path is not a directory: {resolved_model_path}")
3104
 
3105
- # If subfolder is used, update resolved_model_path to point inside it
3106
  if subfolder:
3107
- resolved_model_path = resolved_model_path / subfolder
3108
- if not resolved_model_path.is_dir():
3109
- raise EnvironmentError(f"Subfolder '{subfolder}' not found within the resolved path: {resolved_model_path.parent}")
 
 
3110
 
3111
 
3112
  # --- 2. Load the main configuration ---
3113
  if not isinstance(config, PretrainedConfig):
 
3114
  config_path = config if config is not None else resolved_model_path
3115
  try:
3116
  loaded_config, model_kwargs = SparkTTSConfig.from_pretrained(
3117
- config_path,
3118
- *model_args, # Pass model_args here
3119
  cache_dir=cache_dir,
3120
  force_download=force_download if not is_local else False,
3121
  local_files_only=local_files_only or is_local,
3122
  token=token,
3123
- revision=revision,
3124
- trust_remote_code=trust_remote_code, # Crucial if config class is remote
3125
- #subfolder="", # Config is usually at the root, not subfolder
3126
  return_unused_kwargs=True,
3127
- **kwargs, # Pass remaining kwargs for config loading
3128
  )
3129
  config = loaded_config
3130
- kwargs = model_kwargs # Update kwargs with unused ones
3131
  except OSError as e:
3132
- raise OSError(f"Cannot load config for '{pretrained_model_name_or_path}'. Check `config.json` exists and is correctly formatted. Error: {e}")
3133
- # else: config object was passed directly
3134
 
3135
  # --- Determine final torch_dtype ---
3136
- final_torch_dtype = torch_dtype # Explicit arg has highest prio
3137
  if final_torch_dtype == "auto":
3138
- final_torch_dtype = getattr(config, "torch_dtype", None) # Use config value if present
3139
- # Convert string to torch.dtype object if needed
3140
  if isinstance(final_torch_dtype, str) and final_torch_dtype != "auto":
3141
  try:
3142
  final_torch_dtype = getattr(torch, final_torch_dtype)
3143
  except AttributeError:
3144
  logger.warning(f"Invalid torch_dtype string: {final_torch_dtype}. Falling back to default.")
3145
- final_torch_dtype = None # Fallback to None (which means float32 usually)
3146
  elif final_torch_dtype == "auto":
3147
- final_torch_dtype = None # Treat "auto" as None for component loading
3148
 
3149
- # --- Helper function to resolve paths relative to the main model directory ---
3150
- # (This handles components potentially being in subfolders specified in config)
3151
  def _resolve_sub_path(sub_path_str):
3152
  p = Path(sub_path_str)
3153
  if p.is_absolute():
3154
  if not p.exists(): logger.warning(f"Absolute path specified for sub-component does not exist: {p}")
3155
  return str(p)
3156
  else:
3157
- # Resolve relative to the main model path (which might be in cache or local)
3158
  resolved = resolved_model_path / p
3159
  if not resolved.exists():
3160
- # Check if the path exists without the leading './' often found in configs
3161
  resolved_alt = resolved_model_path / sub_path_str.lstrip('./')
3162
  if resolved_alt.exists():
3163
  resolved = resolved_alt
@@ -3171,26 +3167,24 @@ class SparkTTSModel(PreTrainedModel, GenerationMixin):
3171
  "force_download": force_download,
3172
  "local_files_only": local_files_only,
3173
  "token": token,
3174
- "revision": revision,
3175
- "trust_remote_code": trust_remote_code, # Pass this down
3176
- "torch_dtype": final_torch_dtype, # Pass resolved dtype
3177
  "use_safetensors": use_safetensors,
3178
- # Pass quantization config if provided and relevant to component
3179
  "quantization_config": quantization_config if quantization_config else None,
3180
- # Pass variant if needed for specific component checkpoints
3181
  "variant": variant,
3182
- # Filter kwargs? For now, pass all remaining, component loaders should ignore unused ones.
3183
- **kwargs,
3184
  }
3185
 
3186
  # --- 3. Load Sub-components ---
3187
-
3188
  # --- Load LLM ---
3189
  llm_path = _resolve_sub_path(config.llm_model_name_or_path)
3190
  logger.info(f"Loading LLM from resolved path: {llm_path}")
3191
  try:
 
3192
  llm = AutoModelForCausalLM.from_pretrained(
3193
- llm_path, **component_loading_kwargs
3194
  )
3195
  except Exception as e:
3196
  raise OSError(f"Failed to load LLM from {llm_path}: {e}")
@@ -3199,47 +3193,46 @@ class SparkTTSModel(PreTrainedModel, GenerationMixin):
3199
  w2v_path = _resolve_sub_path(config.wav2vec2_model_name_or_path)
3200
  logger.info(f"Loading Wav2Vec2 components from resolved path: {w2v_path}")
3201
  try:
3202
- # Use specific class for extractor, Auto* might not work if only config is present
3203
  wav2vec2_processor = Wav2Vec2FeatureExtractor.from_pretrained(
3204
  w2v_path,
3205
- cache_dir=cache_dir, # Pass relevant args
3206
  force_download=force_download,
3207
  local_files_only=local_files_only,
3208
  token=token,
3209
  revision=revision,
3210
- # No trust_remote_code needed usually for feature extractors
3211
  )
 
3212
  wav2vec2_model = Wav2Vec2Model.from_pretrained(
3213
- w2v_path, **component_loading_kwargs # Pass full kwargs here
3214
  )
3215
- wav2vec2_model.config.output_hidden_states = True # Ensure this is set
3216
  except Exception as e:
3217
  raise OSError(f"Failed to load Wav2Vec2 components from {w2v_path}: {e}")
3218
 
3219
  # --- Load BiCodec ---
3220
  bicodec_path = _resolve_sub_path(config.bicodec_model_name_or_path)
3221
  logger.info(f"Loading BiCodec from resolved path: {bicodec_path}")
3222
- if not config.bicodec_config: # Check if the nested config object exists
3223
- raise ValueError("BiCodec configuration (`bicodec_config`) not found or properly instantiated in SparkTTSConfig.")
3224
  try:
3225
- # Pass the SparkTTSBiCodecConfig *object* directly
3226
  bicodec = BiCodec.load_from_config_and_checkpoint(
3227
  model_dir=Path(bicodec_path),
3228
- bicodec_config_object=config.bicodec_config # Pass the object
3229
  )
3230
  if not isinstance(bicodec, torch.nn.Module):
3231
  logger.warning("Loaded BiCodec component is not an instance of torch.nn.Module.")
3232
- # Apply torch_dtype to BiCodec if it's an nn.Module and dtype is set
3233
  if isinstance(bicodec, torch.nn.Module) and final_torch_dtype:
3234
  bicodec = bicodec.to(dtype=final_torch_dtype)
3235
-
3236
  except FileNotFoundError as e:
3237
- raise OSError(f"Failed to load BiCodec: A required file was not found in {bicodec_path}. Original error: {e}")
3238
  except Exception as e:
3239
  logger.error(f"Raw error loading BiCodec: {type(e).__name__}: {e}")
3240
  import traceback
3241
  traceback.print_exc()
3242
- raise OSError(f"Failed to load BiCodec from {bicodec_path}. Check BiCodec implementation, config, and file paths. Error: {e}")
 
3243
 
3244
  # --- 4. Instantiate the main model wrapper ---
3245
  model = cls(
@@ -3251,20 +3244,17 @@ class SparkTTSModel(PreTrainedModel, GenerationMixin):
3251
  )
3252
 
3253
  # --- 5. Handle device placement (Simplified) ---
3254
- # Determine target device (simple logic: CUDA > MPS > CPU)
3255
  if torch.cuda.is_available():
3256
  final_device = torch.device("cuda")
3257
- elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): # Check MPS availability
3258
  final_device = torch.device("mps")
3259
  else:
3260
  final_device = torch.device("cpu")
3261
-
3262
  logger.info(f"Placing SparkTTSModel and components on device: {final_device}")
3263
  try:
3264
  model.to(final_device)
3265
  except Exception as e:
3266
  logger.error(f"Failed to move model to device {final_device}. Error: {e}")
3267
- logger.warning("Device placement might be incomplete. Check component types and implementations.")
3268
 
3269
  # --- 6. Return the loaded and prepared model ---
3270
  return model
 
3011
  return outputs # Should be CausalLMOutputWithPast or tuple
3012
 
3013
  @classmethod
 
3014
  def from_pretrained(
3015
  cls,
3016
  pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
 
3020
  ignore_mismatched_sizes: bool = False,
3021
  force_download: bool = False,
3022
  local_files_only: bool = False,
3023
+ token: Optional[Union[bool, str]] = None,
3024
  revision: str = "main",
3025
  use_safetensors: Optional[bool] = None,
3026
  # New args from base class signature to pass down if relevant
3027
+ state_dict = None,
3028
+ device_map = None,
3029
+ low_cpu_mem_usage = None,
3030
+ torch_dtype = "auto",
3031
+ quantization_config = None,
3032
+ trust_remote_code = None,
3033
  # Add other relevant args from base class if needed: subfolder, variant, etc.
3034
+ subfolder: str = "", # Keep subfolder arg for overall loading logic
3035
  variant: Optional[str] = None,
3036
  **kwargs,
3037
  ):
3038
  # --- Argument Handling & Initial Setup ---
 
3039
  if device_map:
3040
  logger.warning("`device_map` is not directly supported for this composite model. Use .to(device) after loading.")
3041
  if low_cpu_mem_usage:
3042
  logger.info("`low_cpu_mem_usage` is set, but simplified loading is used. Memory usage might not be optimized.")
 
 
3043
  if trust_remote_code is None:
3044
+ logger.warning("Loading SparkTTSModel requires custom code. Setting `trust_remote_code=True`.")
 
 
 
3045
  trust_remote_code = True
3046
  elif not trust_remote_code:
3047
  raise ValueError("Loading SparkTTSModel requires `trust_remote_code=True`.")
3048
 
 
3049
  kwargs.pop("output_loading_info", None)
3050
  kwargs.pop("_from_auto", None)
3051
+ kwargs.pop("attn_implementation", None)
3052
 
3053
  # --- 1. Resolve the main model directory ---
3054
  if state_dict is not None:
3055
+ raise ValueError("Explicitly passing `state_dict` is not supported for this composite model.")
3056
  if pretrained_model_name_or_path is None:
3057
  raise ValueError("`pretrained_model_name_or_path` must be provided.")
3058
 
 
3067
  logger.info(f"{pretrained_model_name_or_path} is not a local directory. Assuming Hub ID and downloading.")
3068
  try:
3069
  # Use snapshot_download to get all necessary files
3070
+ # REMOVED subfolder=subfolder from this call
3071
  resolved_model_path_str = snapshot_download(
3072
  repo_id=str(pretrained_model_name_or_path),
3073
  cache_dir=cache_dir,
 
3075
  local_files_only=local_files_only,
3076
  token=token,
3077
  revision=revision,
3078
+ allow_patterns=[
3079
  "*.json", "*.safetensors", "*.bin", "*.yaml", "*.txt",
3080
+ "README.md", ".gitattributes",
3081
+ "LLM/*", "BiCodec/*", "wav2vec2-large-xlsr-53/*"
3082
  ],
3083
+ ignore_patterns=["*.git*", "*.h5", "*.ot", "*.msgpack"],
3084
+ repo_type="model", # Explicitly set repo_type
3085
+ # max_workers=..., # Can adjust workers if needed
3086
+ # user_agent=..., # Can add user agent
3087
  )
3088
  resolved_model_path = Path(resolved_model_path_str)
3089
  logger.info(f"Model files downloaded to cache: {resolved_model_path}")
3090
  except Exception as e:
3091
+ # Catch potential TypeErrors from snapshot_download if args change again
3092
+ if isinstance(e, TypeError) and 'unexpected keyword argument' in str(e):
3093
+ logger.error(f"snapshot_download() received an unexpected keyword argument. Check huggingface_hub version compatibility. Error: {e}")
3094
  raise OSError(
3095
+ f"Failed to download model '{pretrained_model_name_or_path}' (revision: '{revision}') from Hugging Face Hub. "
3096
  f"Error: {e}"
3097
  )
3098
 
3099
  if not resolved_model_path.is_dir():
3100
  raise EnvironmentError(f"Resolved model path is not a directory: {resolved_model_path}")
3101
 
3102
+ # If subfolder was specified for from_pretrained, adjust the path *after* download
3103
  if subfolder:
3104
+ resolved_model_path_with_subfolder = resolved_model_path / subfolder
3105
+ if not resolved_model_path_with_subfolder.is_dir():
3106
+ raise EnvironmentError(f"Subfolder '{subfolder}' not found within the resolved path: {resolved_model_path}")
3107
+ resolved_model_path = resolved_model_path_with_subfolder # Update path to include subfolder
3108
+ logger.info(f"Using subfolder within resolved path: {resolved_model_path}")
3109
 
3110
 
3111
  # --- 2. Load the main configuration ---
3112
  if not isinstance(config, PretrainedConfig):
3113
+ # Load config from the potentially subfolder-adjusted path
3114
  config_path = config if config is not None else resolved_model_path
3115
  try:
3116
  loaded_config, model_kwargs = SparkTTSConfig.from_pretrained(
3117
+ config_path, # Load from the final resolved path
3118
+ *model_args,
3119
  cache_dir=cache_dir,
3120
  force_download=force_download if not is_local else False,
3121
  local_files_only=local_files_only or is_local,
3122
  token=token,
3123
+ revision=revision, # Pass revision for config loading too
3124
+ trust_remote_code=trust_remote_code,
3125
+ subfolder="", # Config is expected at the root of resolved_model_path
3126
  return_unused_kwargs=True,
3127
+ **kwargs,
3128
  )
3129
  config = loaded_config
3130
+ kwargs = model_kwargs
3131
  except OSError as e:
3132
+ raise OSError(f"Cannot load config from {config_path}. Check `config.json` exists and is correctly formatted. Error: {e}")
 
3133
 
3134
  # --- Determine final torch_dtype ---
3135
+ final_torch_dtype = torch_dtype
3136
  if final_torch_dtype == "auto":
3137
+ final_torch_dtype = getattr(config, "torch_dtype", None)
 
3138
  if isinstance(final_torch_dtype, str) and final_torch_dtype != "auto":
3139
  try:
3140
  final_torch_dtype = getattr(torch, final_torch_dtype)
3141
  except AttributeError:
3142
  logger.warning(f"Invalid torch_dtype string: {final_torch_dtype}. Falling back to default.")
3143
+ final_torch_dtype = None
3144
  elif final_torch_dtype == "auto":
3145
+ final_torch_dtype = None
3146
 
3147
+ # --- Helper function to resolve component paths relative to the final resolved_model_path ---
 
3148
  def _resolve_sub_path(sub_path_str):
3149
  p = Path(sub_path_str)
3150
  if p.is_absolute():
3151
  if not p.exists(): logger.warning(f"Absolute path specified for sub-component does not exist: {p}")
3152
  return str(p)
3153
  else:
3154
+ # Resolve relative to the potentially subfolder-adjusted main model path
3155
  resolved = resolved_model_path / p
3156
  if not resolved.exists():
 
3157
  resolved_alt = resolved_model_path / sub_path_str.lstrip('./')
3158
  if resolved_alt.exists():
3159
  resolved = resolved_alt
 
3167
  "force_download": force_download,
3168
  "local_files_only": local_files_only,
3169
  "token": token,
3170
+ "revision": revision, # Pass revision to component loaders
3171
+ "trust_remote_code": trust_remote_code,
3172
+ "torch_dtype": final_torch_dtype,
3173
  "use_safetensors": use_safetensors,
 
3174
  "quantization_config": quantization_config if quantization_config else None,
 
3175
  "variant": variant,
3176
+ **kwargs, # Pass remaining kwargs
 
3177
  }
3178
 
3179
  # --- 3. Load Sub-components ---
3180
+ # (LLM, Wav2Vec2, BiCodec loading logic remains the same as previous version)
3181
  # --- Load LLM ---
3182
  llm_path = _resolve_sub_path(config.llm_model_name_or_path)
3183
  logger.info(f"Loading LLM from resolved path: {llm_path}")
3184
  try:
3185
+ # Pass subfolder="" because llm_path is now absolute or correctly relative
3186
  llm = AutoModelForCausalLM.from_pretrained(
3187
+ llm_path, subfolder="", **component_loading_kwargs
3188
  )
3189
  except Exception as e:
3190
  raise OSError(f"Failed to load LLM from {llm_path}: {e}")
 
3193
  w2v_path = _resolve_sub_path(config.wav2vec2_model_name_or_path)
3194
  logger.info(f"Loading Wav2Vec2 components from resolved path: {w2v_path}")
3195
  try:
3196
+ # Load extractor without full component_loading_kwargs if they cause issues
3197
  wav2vec2_processor = Wav2Vec2FeatureExtractor.from_pretrained(
3198
  w2v_path,
3199
+ cache_dir=cache_dir,
3200
  force_download=force_download,
3201
  local_files_only=local_files_only,
3202
  token=token,
3203
  revision=revision,
3204
+ subfolder="", # Path is resolved
3205
  )
3206
+ # Load model with full kwargs
3207
  wav2vec2_model = Wav2Vec2Model.from_pretrained(
3208
+ w2v_path, subfolder="", **component_loading_kwargs
3209
  )
3210
+ wav2vec2_model.config.output_hidden_states = True
3211
  except Exception as e:
3212
  raise OSError(f"Failed to load Wav2Vec2 components from {w2v_path}: {e}")
3213
 
3214
  # --- Load BiCodec ---
3215
  bicodec_path = _resolve_sub_path(config.bicodec_model_name_or_path)
3216
  logger.info(f"Loading BiCodec from resolved path: {bicodec_path}")
3217
+ if not config.bicodec_config:
3218
+ raise ValueError("BiCodec configuration (`bicodec_config`) not found in SparkTTSConfig.")
3219
  try:
 
3220
  bicodec = BiCodec.load_from_config_and_checkpoint(
3221
  model_dir=Path(bicodec_path),
3222
+ bicodec_config_object=config.bicodec_config
3223
  )
3224
  if not isinstance(bicodec, torch.nn.Module):
3225
  logger.warning("Loaded BiCodec component is not an instance of torch.nn.Module.")
 
3226
  if isinstance(bicodec, torch.nn.Module) and final_torch_dtype:
3227
  bicodec = bicodec.to(dtype=final_torch_dtype)
 
3228
  except FileNotFoundError as e:
3229
+ raise OSError(f"Failed to load BiCodec: Required file not found in {bicodec_path}. Error: {e}")
3230
  except Exception as e:
3231
  logger.error(f"Raw error loading BiCodec: {type(e).__name__}: {e}")
3232
  import traceback
3233
  traceback.print_exc()
3234
+ raise OSError(f"Failed to load BiCodec from {bicodec_path}. Error: {e}")
3235
+
3236
 
3237
  # --- 4. Instantiate the main model wrapper ---
3238
  model = cls(
 
3244
  )
3245
 
3246
  # --- 5. Handle device placement (Simplified) ---
 
3247
  if torch.cuda.is_available():
3248
  final_device = torch.device("cuda")
3249
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
3250
  final_device = torch.device("mps")
3251
  else:
3252
  final_device = torch.device("cpu")
 
3253
  logger.info(f"Placing SparkTTSModel and components on device: {final_device}")
3254
  try:
3255
  model.to(final_device)
3256
  except Exception as e:
3257
  logger.error(f"Failed to move model to device {final_device}. Error: {e}")
 
3258
 
3259
  # --- 6. Return the loaded and prepared model ---
3260
  return model