# F0_Energy_joint_VQVAE_embeddings / prosody_embedding_pipeline.py
# Last updated by Daporte: "Update prosody_embedding_pipeline.py" (commit e92e60e, verified)
# File size: 6.72 kB
from transformers import Pipeline
import numpy as np
import torch
from typing import Dict, Union, List, Optional
from pathlib import Path
import logging
from datasets import Dataset
logger = logging.getLogger(__name__)
class ProsodyEmbeddingPipeline(Pipeline):
    """Pipeline that runs an f0/intensity model and scores its reconstruction.

    Each input example supplies a ``speaker_id`` plus frame-level ``f0`` (or
    ``f0_interp``) and ``intensity`` sequences. The pipeline optionally
    standardizes them with per-speaker statistics, feeds the stacked channels
    to the model, reverts the normalization on the model's ``f0`` output and
    attaches simple reconstruction metrics.

    Args:
        speaker_stats: Mapping from speaker id to an object exposing
            ``f0_mean``, ``f0_std``, ``intensity_mean`` and ``intensity_std``.
        f0_interp: If true, read the ``f0_interp`` field and feed two channels
            (f0, intensity); otherwise read ``f0`` and add a third channel that
            flags frames whose f0 is exactly zero.
        f0_normalize: If true, standardize non-zero f0 and intensity values
            with the speaker statistics and undo it in ``postprocess``.
        stats_dir: Optional directory, stored as ``self.stats_dir`` (a
            ``Path``); not used by the methods of this class.
        **kwargs: Forwarded to ``transformers.Pipeline``.
    """

    # Metric inputs are cut to a multiple of this many frames; presumably the
    # model's temporal down-sampling factor -- TODO confirm.
    _FRAME_MULTIPLE = 16
    # Relative f0 error above which a frame counts as a "large difference".
    _DIFF_THRESHOLD = 0.2
    # Stand-in denominator for frames whose reference f0 is exactly zero.
    _EPSILON = 1e-10

    def __init__(
        self,
        speaker_stats,
        f0_interp,
        f0_normalize,
        stats_dir: Optional[str] = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.stats_dir = Path(stats_dir) if stats_dir else None
        self.speaker_stats = speaker_stats
        self.f0_interp = f0_interp
        self.f0_normalize = f0_normalize

    def _sanitize_parameters(self, **kwargs):
        """Route call-time kwargs; only ``return_tensors`` (default "pt") is used."""
        preprocess_kwargs = {}
        forward_kwargs = {}
        postprocess_kwargs = {
            "return_tensors": kwargs.pop("return_tensors", "pt")
        }
        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    @staticmethod
    def _to_float_tensor(values) -> torch.Tensor:
        """Return a float32 copy of ``values`` that is safe to modify in place."""
        # ``torch.Tensor(t)`` aliases an existing float tensor, so the in-place
        # normalization below would otherwise write into the caller's data.
        return torch.as_tensor(values, dtype=torch.float32).clone()

    def preprocess(self, inputs: Union[str, Dict, Dataset]) -> Dict:
        """Build the (optionally normalized) model features for one example.

        Args:
            inputs: A mapping (e.g. one dataset row) with ``speaker_id``,
                ``intensity`` and either ``f0_interp`` or ``f0``.

        Returns:
            Dict with ``f0`` / ``intensity`` tensors (normalized when
            ``f0_normalize`` is set), ``zero_mask`` (``None`` when
            ``f0_interp`` is set), the speaker statistics, the un-normalized
            ``f0_orig`` / ``intensity_orig`` copies and ``speaker_id``.
        """
        speaker_id = inputs['speaker_id']
        stats = self.speaker_stats[speaker_id]

        f0 = self._to_float_tensor(inputs['f0_interp'] if self.f0_interp else inputs['f0'])
        intensity = self._to_float_tensor(inputs['intensity'])
        # Untouched copies: postprocess measures the reconstruction against them.
        f0_orig = f0.clone()
        intensity_orig = intensity.clone()

        # Frames with f0 == 0 are treated as "no f0". Take this from the raw
        # values so a frame that normalizes to exactly 0 (f0 == mean) is not
        # mistaken for one of them.
        f0_is_zero = f0_orig == 0

        if self.f0_normalize:
            nonzero_f0 = ~f0_is_zero
            if stats.f0_std != 0:
                f0[nonzero_f0] = (f0[nonzero_f0] - stats.f0_mean) / stats.f0_std
            nonzero_intensity = intensity != 0
            if stats.intensity_std != 0:
                intensity[nonzero_intensity] = (
                    intensity[nonzero_intensity] - stats.intensity_mean
                ) / stats.intensity_std

        # The 3-channel model input (f0_interp false) always needs the mask,
        # whether or not normalization is enabled.
        zero_mask = None if self.f0_interp else f0_is_zero.float()

        return {
            'f0': f0,
            'intensity': intensity,
            'zero_mask': zero_mask,
            'f0_mean': stats.f0_mean,
            'f0_std': stats.f0_std,
            'intensity_mean': stats.intensity_mean,
            'intensity_std': stats.intensity_std,
            'f0_orig': f0_orig,  # original features before normalization
            'intensity_orig': intensity_orig,  # original features before normalization
            'speaker_id': speaker_id,
        }

    def _forward(self, features: Dict) -> Dict:
        """Run the model on the stacked feature channels of one example.

        Channels are ``(f0, intensity)`` when ``f0_interp`` is set and
        ``(f0, intensity, zero_mask)`` otherwise; a leading batch dimension of
        1 is added before calling ``self.model(features=...)``.

        Returns:
            The model outputs merged with an ``input_features`` dict holding
            the statistics and original features that ``postprocess`` needs.
        """
        self.model.eval()
        channels = [
            torch.as_tensor(features['f0'], dtype=torch.float32),
            torch.as_tensor(features['intensity'], dtype=torch.float32),
        ]
        zero_mask = None
        if not self.f0_interp:
            zero_mask = torch.as_tensor(features['zero_mask'], dtype=torch.float32)
            channels.append(zero_mask)
        stacked_features = torch.stack(channels, dim=0).unsqueeze(0).to(self.model.device)

        with torch.no_grad():
            model_outputs = self.model(features=stacked_features)

        return {
            **model_outputs,
            'input_features': {
                'zero_mask': zero_mask,
                'f0_orig': features['f0_orig'],
                'f0_mean': features['f0_mean'],
                'f0_std': features['f0_std'],
                'intensity_mean': features['intensity_mean'],
                'intensity_std': features['intensity_std'],
                'intensity_orig': features['intensity_orig'],
            },
        }

    @classmethod
    def _truncate(cls, reference, reconstruction):
        """Cut both arrays to the largest multiple of ``_FRAME_MULTIPLE`` frames of ``reference``."""
        length = (len(reference) // cls._FRAME_MULTIPLE) * cls._FRAME_MULTIPLE
        return reference[:length], reconstruction[:length]

    @classmethod
    def _large_diff_percent(cls, reference, reconstruction) -> float:
        """Percentage of frames whose relative error exceeds ``_DIFF_THRESHOLD``; NaN if empty."""
        if len(reference) == 0:
            return float("nan")
        # Frames with a zero reference count as "large" unless the
        # reconstruction is also exactly zero.
        safe_reference = np.where(reference == 0, cls._EPSILON, reference)
        rel_diff = np.abs(reference - reconstruction) / np.abs(safe_reference)
        diff_count = np.sum(rel_diff > cls._DIFF_THRESHOLD)
        return float((diff_count / len(reference)) * 100)

    @staticmethod
    def _rmse(reference, reconstruction) -> float:
        """Root-mean-square error between two equal-length arrays; NaN if empty."""
        if len(reference) == 0:
            return float("nan")
        return float(np.sqrt(np.mean((reference - reconstruction) ** 2)))

    def postprocess(self, outputs: Dict, return_tensors: str = "pt") -> Dict:
        """Revert the normalization on the model output and attach metrics.

        NOTE: the de-normalization is written in place into ``outputs['f0'][0]``
        (a view), so the returned ``f0`` entry holds the de-normalized values.

        Adds ``f0_recon`` / ``intensity_recon`` (numpy arrays truncated to a
        multiple of ``_FRAME_MULTIPLE`` frames) and ``metrics`` with
        ``f0_large_diff_percent`` and ``intensity_rmse``. Both metrics are NaN
        when fewer than ``_FRAME_MULTIPLE`` frames are available.

        Args:
            outputs: Output of ``_forward``. ``outputs['f0']`` is indexed as
                ``[batch, channel, frame]``: channel 0 is f0, 1 is intensity
                and, when ``f0_interp`` is false, 2 is a predicted zero mask
                (presumably a probability; it is thresholded at 0.5).
            return_tensors: ``"np"`` converts top-level tensor values to numpy;
                any other value leaves them as tensors.
        """
        features = outputs['input_features']
        # First batch item; basic indexing yields a view into outputs['f0'].
        f0_recon = outputs['f0'][0, :, :]

        if self.f0_normalize:
            f0_recon[0] = f0_recon[0] * features["f0_std"] + features["f0_mean"]
            if not self.f0_interp:
                # Zero f0 on frames whose predicted mask is >= 0.5. Built from
                # f0_recon itself so it matches its device and dtype.
                keep = (f0_recon[2] < 0.5).to(f0_recon.dtype)
                f0_recon[0] = f0_recon[0] * keep
            f0_recon[1] = f0_recon[1] * features["intensity_std"] + features["intensity_mean"]

        input_f0_np, output_f0_np = self._truncate(
            features['f0_orig'].cpu().numpy(), f0_recon[0].cpu().numpy()
        )
        input_intensity_np, output_intensity_np = self._truncate(
            features['intensity_orig'].cpu().numpy(), f0_recon[1].cpu().numpy()
        )

        outputs['f0_recon'] = output_f0_np
        outputs['intensity_recon'] = output_intensity_np
        outputs['metrics'] = {
            'f0_large_diff_percent': self._large_diff_percent(input_f0_np, output_f0_np),
            'intensity_rmse': self._rmse(input_intensity_np, output_intensity_np),
        }
        logger.debug("outputs['metrics'] %s", outputs['metrics'])

        if return_tensors == "np":
            outputs = {
                k: v.cpu().numpy() if torch.is_tensor(v) else v
                for k, v in outputs.items()
            }
        return outputs