import os

import matplotlib.pyplot as plt
import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoFeatureExtractor, pipeline

from pipeline_utils import compute_speaker_stats, plot_reconstruction


def main():
    # Small dummy VoxPopuli split for a quick end-to-end run.
    dataset = load_dataset(
        "sanchit-gandhi/voxpopuli_dummy",
        # "train",
        split="validation",
    )
    # Alternative datasets:
    # dataset = load_dataset(
    #     "mythicinfinity/libritts",
    #     "clean",
    #     split="test.clean",
    #     # trust_remote_code=True,
    # )
    # dataset = load_dataset(
    #     "facebook/voxpopuli",
    #     "en",
    #     split="test",
    # )

    # Feature extractor that turns raw audio into the F0/energy features
    # expected by the VQ-VAE embedding model.
    preprocessor = AutoFeatureExtractor.from_pretrained(
        'MU-NLPC/F0_Energy_joint_VQVAE_embeddings-preprocessor',
        # trust_remote_code=True,
    )

    processed_dataset = dataset.map(
        lambda x: preprocessor.extract_features(x['audio']['array']),
        load_from_cache_file=False,
        # num_proc=4,
    )
    processed_dataset.save_to_disk("processed_dataset")

    # Per-speaker statistics, passed to the pipeline for F0 normalization.
    speaker_stats = compute_speaker_stats(processed_dataset)
    torch.save(speaker_stats, "speaker_stats.pt")

    embedding_pipeline = pipeline(
        task="prosody-embedding",
        model="MU-NLPC/F0_Energy_joint_VQVAE_embeddings",
        f0_interp=False,
        f0_normalize=True,
        speaker_stats=speaker_stats,
        # trust_remote_code=True,
    )

    results = processed_dataset.map(
        lambda x: embedding_pipeline(x),
        remove_columns=processed_dataset.column_names,
        load_from_cache_file=False,
        # num_proc=4,
    )
    results.save_to_disk("embeddings_dataset")
    print(f"Processed {len(results)} samples")

    # The VQ codebook of the first quantization level.
    embedding_codebook = embedding_pipeline.model.vq.level_blocks[0].k
    print("embedding_codebook.shape", embedding_codebook.shape)

    embeddings_example = results[0]['codes'][0][0]
    print("Embeddings example:", embeddings_example)

    # Inspect the embeddings in the codebook as follows:
    # code_point = embeddings_example[0]
    # print("code_point", code_point)
    # code_point_embedding = embedding_codebook[code_point]
    # print("code_point_embedding", code_point_embedding)
    # print("code_point_embedding.shape", code_point_embedding.shape)

    # Check that they are the same as the hidden states used in the model
    # (requires `import numpy as np`):
    # hidden_states = np.array(results[0]['hidden_states'])
    # hidden_state = hidden_states[0, 0, :, 0]
    # print("hidden_state", hidden_state)

    # Average each reconstruction metric over the whole dataset.
    metrics_list = [result['metrics'] for result in results]
    avg_metrics = {}
    for metric in results[0]['metrics'].keys():
        values = [m[metric] for m in metrics_list]
        avg_metrics[metric] = sum(values) / len(values)
        # print("metric", metric)
        # print("len(values)", len(values))

    print("\nAverage metrics across dataset:")
    print(avg_metrics)

    print("Plotting reconstruction curves...")
    os.makedirs('plots', exist_ok=True)
    for i in tqdm(range(len(results))):
        fig = plot_reconstruction(results[i], i)
        plt.savefig(f'plots/reconstruction_sample{i}.png')
        plt.close()
    print("Done.")


if __name__ == '__main__':
    main()