"""HuggingFace Pipeline for prosody embeddings: per-speaker f0/intensity
normalization, model inference, and reconstruction metrics."""

import logging
from pathlib import Path
from typing import Dict, Optional, Union

import numpy as np
import torch
from datasets import Dataset
from transformers import Pipeline

logger = logging.getLogger(__name__)


class ProsodyEmbeddingPipeline(Pipeline):
    def __init__(
        self,
        speaker_stats,
        f0_interp,
        f0_normalize,
        stats_dir: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.stats_dir = Path(stats_dir) if stats_dir else None
        self.speaker_stats = speaker_stats
        self.f0_interp = f0_interp
        self.f0_normalize = f0_normalize

    def _sanitize_parameters(self, **kwargs):
        preprocess_kwargs = {}
        forward_kwargs = {}
        postprocess_kwargs = {
            "return_tensors": kwargs.pop("return_tensors", "pt"),
        }
        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def preprocess(self, inputs: Union[str, Dict, Dataset]) -> Dict:
        """Normalize f0/intensity per speaker; keep the originals for metrics."""
        spkr_id = inputs['speaker_id']
        stats = self.speaker_stats[spkr_id]

        if self.f0_interp:
            f0 = torch.Tensor(inputs['f0_interp'])
        else:
            f0 = torch.Tensor(inputs['f0'])
        f0_orig = f0.clone()  # save original f0 before normalization

        intensity = torch.Tensor(inputs['intensity'])
        intensity_orig = intensity.clone()  # save original intensity before normalization

        if self.f0_normalize:
            # Normalize only non-zero (voiced) frames.
            ii = f0 != 0
            if stats.f0_std != 0:
                f0[ii] = (f0[ii] - stats.f0_mean) / stats.f0_std

            intensity_ii = intensity != 0
            if stats.intensity_std != 0:
                intensity[intensity_ii] = (
                    intensity[intensity_ii] - stats.intensity_mean
                ) / stats.intensity_std

        # The unvoiced-frame mask is needed whenever f0 is not interpolated,
        # regardless of normalization; _forward stacks it as a third channel.
        # (Gating it on f0_normalize, as before, left zero_mask None and made
        # _forward crash when f0_normalize was False.)
        zero_mask = None
        if not self.f0_interp:
            zero_mask = (f0 == 0).float()

        return {
            'f0': f0,
            'intensity': intensity,
            'zero_mask': zero_mask,
            'f0_mean': stats.f0_mean,
            'f0_std': stats.f0_std,
            'intensity_mean': stats.intensity_mean,
            'intensity_std': stats.intensity_std,
            'f0_orig': f0_orig,  # original features before normalization
            'intensity_orig': intensity_orig,
            'speaker_id': spkr_id,
        }

    def _forward(self, features: Dict) -> Dict:
        """Run the model on the preprocessed features."""
        self.model.eval()
        f0 = torch.Tensor(features['f0'])
        intensity = torch.Tensor(features['intensity'])

        zero_mask = None
        if self.f0_interp:
            stacked_features = torch.stack([f0, intensity], dim=0).to(self.model.device)
        else:
            zero_mask = torch.Tensor(features['zero_mask'])
            stacked_features = torch.stack(
                [f0, intensity, zero_mask], dim=0
            ).to(self.model.device)
        stacked_features = stacked_features.unsqueeze(0)  # add batch dimension

        with torch.no_grad():
            model_outputs = self.model(features=stacked_features)

        return {
            **model_outputs,
            'input_features': {
                'zero_mask': zero_mask,
                'f0_orig': features['f0_orig'],
                'f0_mean': features['f0_mean'],
                'f0_std': features['f0_std'],
                'intensity_mean': features['intensity_mean'],
                'intensity_std': features['intensity_std'],
                'intensity_orig': features['intensity_orig'],
            },
        }

    def postprocess(self, outputs: Dict, return_tensors: str = "pt") -> Dict:
        """Denormalize the reconstruction and compute f0/intensity metrics."""
        input_f0 = outputs['input_features']['f0_orig']
        f0_recon = outputs['f0'][0, :, :]

        # Revert the per-speaker normalization applied in preprocess().
        if self.f0_normalize:
            f0_recon[0] = (
                f0_recon[0] * outputs['input_features']['f0_std']
                + outputs['input_features']['f0_mean']
            )
            if not self.f0_interp:
                # Channel 2 carries the predicted unvoiced mask: zero out
                # frames the model considers unvoiced. The boolean-to-float
                # mask avoids the device mismatch that CPU-side
                # torch.tensor([1.0]) literals would cause for CUDA tensors.
                mask = (f0_recon[2] < 0.5).float()
                f0_recon[0] = f0_recon[0] * mask
            f0_recon[1] = (
                f0_recon[1] * outputs['input_features']['intensity_std']
                + outputs['input_features']['intensity_mean']
            )

        epsilon = 1e-10
        DIFF_THRESHOLD = 0.2

        # F0 metric: percentage of frames whose relative error exceeds DIFF_THRESHOLD.
        input_f0_np = input_f0.cpu().numpy()
        output_f0_np = f0_recon[0].cpu().numpy()

        # Truncate both arrays to a multiple of 16.
        truncated_length = (len(input_f0_np) // 16) * 16
        input_f0_np = input_f0_np[:truncated_length]
        output_f0_np = output_f0_np[:truncated_length]

        input_f0_safe = np.where(input_f0_np == 0, epsilon, input_f0_np)
        rel_diff = np.abs(input_f0_np - output_f0_np) / np.abs(input_f0_safe)
        diff_count = np.sum(rel_diff > DIFF_THRESHOLD)
        total_points = len(input_f0_np)
        f0_large_diff_percent = (diff_count / total_points) * 100

        # Intensity metric: RMSE between original and reconstructed contours.
        input_intensity_np = outputs['input_features']['intensity_orig'].cpu().numpy()
        output_intensity_np = f0_recon[1].cpu().numpy()

        # Truncate intensity arrays to a multiple of 16.
        truncated_length = (len(input_intensity_np) // 16) * 16
        input_intensity_np = input_intensity_np[:truncated_length]
        output_intensity_np = output_intensity_np[:truncated_length]

        intensity_rmse = np.sqrt(np.mean((input_intensity_np - output_intensity_np) ** 2))

        outputs['f0_recon'] = output_f0_np
        outputs['intensity_recon'] = output_intensity_np
        outputs['metrics'] = {
            'f0_large_diff_percent': float(f0_large_diff_percent),
            'intensity_rmse': float(intensity_rmse),
        }
        logger.debug("metrics: %s", outputs['metrics'])

        if return_tensors == "np":
            outputs = {
                k: v.cpu().numpy() if torch.is_tensor(v) else v
                for k, v in outputs.items()
            }
        return outputs
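

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the pipeline). The stub config
# and model below are assumptions for the example: the real model is expected
# to accept a `features` tensor of shape (batch, channels, frames) and return
# a dict with an "f0" tensor of the same shape. `SpeakerStats` is a
# hypothetical container exposing the four attributes the pipeline reads, and
# the exact `Pipeline` constructor kwargs can vary across transformers
# versions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from dataclasses import dataclass

    from transformers import PretrainedConfig, PreTrainedModel

    @dataclass
    class SpeakerStats:  # hypothetical per-speaker statistics container
        f0_mean: float
        f0_std: float
        intensity_mean: float
        intensity_std: float

    class ProsodyConfig(PretrainedConfig):
        model_type = "prosody-autoencoder"  # placeholder name

    class EchoProsodyModel(PreTrainedModel):
        """Stand-in that echoes its input back as the "reconstruction"."""

        config_class = ProsodyConfig

        def __init__(self, config):
            super().__init__(config)
            self.dummy = torch.nn.Parameter(torch.zeros(1))  # so .device resolves

        def forward(self, features):
            return {'f0': features}

    pipe = ProsodyEmbeddingPipeline(
        model=EchoProsodyModel(ProsodyConfig()),
        speaker_stats={'spk1': SpeakerStats(120.0, 30.0, 65.0, 10.0)},
        f0_interp=True,
        f0_normalize=True,
        framework="pt",
        task="prosody-embedding",
    )

    rng = np.random.default_rng(0)
    n_frames = 32  # at least 16, so truncation keeps some frames
    sample = {
        'speaker_id': 'spk1',
        'f0': rng.uniform(80.0, 200.0, n_frames),         # raw f0 (Hz), 0 = unvoiced
        'f0_interp': rng.uniform(80.0, 200.0, n_frames),  # interpolated f0
        'intensity': rng.uniform(50.0, 80.0, n_frames),
    }

    # Since the stub echoes normalized features and postprocess() reverts the
    # normalization, the round trip should yield near-zero error metrics.
    result = pipe(sample)
    print(result['metrics'])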