Daporte committed
Commit 9558cab (verified)
1 Parent(s): 4d494d9

Create generate_embeddings.py

Files changed (1):
  1. generate_embeddings.py +133 -0
generate_embeddings.py ADDED
@@ -0,0 +1,133 @@
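# generate_embeddings.py
# Extracts prosody (F0/energy) VQ-VAE embeddings for an audio dataset:
# preprocess the audio, compute per-speaker statistics, run the embedding
# pipeline, average the reported metrics, and save reconstruction plots.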
import os

import matplotlib.pyplot as plt
import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoFeatureExtractor, pipeline

from pipeline_utils import compute_speaker_stats, plot_reconstruction


def main():
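    # Preprocessing strategies: each selects a model variant (via suffix) and
    # the F0 interpolation/normalization flags passed to the embedding pipeline.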
    preprocessing_strategy = {
        "norm_mask": {"model_suffix": "",  # norm_mask is the base configuration
                      "f0_interp": False,
                      "f0_normalize": True,
                      },
        "norm_interp": {"model_suffix": "-norm_interp",
                        "f0_interp": True,
                        "f0_normalize": True,
                        },
        "interp": {"model_suffix": "-interp",
                   "f0_interp": True,
                   "f0_normalize": False,
                   },
    }

    selected_strategy = "norm_mask"

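    # Small LibriSpeech dummy split for quick runs; alternative evaluation
    # datasets are left commented out below.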
    dataset = load_dataset(
        "patrickvonplaten/librispeech_asr_dummy",
        "clean",
        split="validation"
    )
    # dataset = load_dataset(
    #     "mythicinfinity/libritts",
    #     "clean",
    #     split="test.clean"
    # )
    # dataset = load_dataset(
    #     "facebook/voxpopuli",
    #     "en",
    #     split="test"
    # )

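    # Preprocess each example's audio array into the model's input features.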
    preprocessor = AutoFeatureExtractor.from_pretrained(
        'MU-NLPC/F0_Energy_joint_VQVAE_embeddings-preprocessor',
        trust_remote_code=True
    )

    processed_dataset = dataset.map(
        lambda x: preprocessor.extract_features(x['audio']['array']),
        load_from_cache_file=False,
        # num_proc=4
    )

    processed_dataset.save_to_disk("processed_dataset")

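    # Per-speaker statistics, saved for reuse and passed to the embedding
    # pipeline below (presumably for speaker-level F0 normalization).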
    speaker_stats = compute_speaker_stats(processed_dataset)
    torch.save(speaker_stats, "speaker_stats.pt")

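    # Custom "prosody-embedding" pipeline, loaded with trust_remote_code for
    # the model variant matching the selected preprocessing strategy.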
    strategy = preprocessing_strategy[selected_strategy]
    embedding_pipeline = pipeline(
        task="prosody-embedding",
        model="MU-NLPC/F0_Energy_joint_VQVAE_embeddings_final" + strategy["model_suffix"],
        f0_interp=strategy["f0_interp"],
        f0_normalize=strategy["f0_normalize"],
        speaker_stats=speaker_stats,
        trust_remote_code=True,
    )

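    # Run the pipeline over the whole dataset; input columns are dropped so
    # only the pipeline outputs (codes, hidden states, metrics) are kept.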
    results = processed_dataset.map(
        lambda x: embedding_pipeline(x),
        remove_columns=processed_dataset.column_names,
        load_from_cache_file=False,
        # num_proc=4
    )

    results.save_to_disk("embeddings_dataset")

    print(f"Processed {len(results)} samples")

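    # Inspect the VQ codebook and the discrete codes of the first sample.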
    embedding_codebook = embedding_pipeline.model.vq.level_blocks[0].k
    print("embedding_codebook.shape:", embedding_codebook.shape)

    embeddings_example = results[0]['codes'][0][0]
    print("Embeddings example:", embeddings_example)

    # inspect the embeddings in the codebook as follows:
    # code_point = embeddings_example[0]
    # print("code_point:", code_point)
    # code_point_embedding = embedding_codebook[code_point]
    # print("code_point_embedding:", code_point_embedding)
    # print("code_point_embedding.shape:", code_point_embedding.shape)

    # check that they are the same as the hidden states used in the model
    # (requires `import numpy as np`):
    # hidden_states = np.array(results[0]['hidden_states'])
    # hidden_state = hidden_states[0, 0, :, 0]
    # print("hidden_state:", hidden_state)

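    # Average the numeric metrics reported per sample; string-valued entries
    # are skipped.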
    metrics_list = [result['metrics'] for result in results]
    avg_metrics = {}

    for metric in results[0]['metrics'].keys():
        values = [m[metric] for m in metrics_list if not isinstance(m[metric], str)]
        if values:  # guard against metrics with no numeric values
            avg_metrics[metric] = sum(values) / len(values)
        # print("metric:", metric)
        # print("len(values):", len(values))

    print("\nAverage metrics across dataset:")
    print(avg_metrics)

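    # Save a reconstruction plot for every sample under plots/.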
    print("Plotting reconstruction curves...")
    os.makedirs('plots', exist_ok=True)
    for i in tqdm(range(len(results))):
        fig = plot_reconstruction(results[i], i)
        plt.savefig(f'plots/reconstruction_sample{i}.png')
        plt.close()
    print("Done.")


if __name__ == '__main__':
    main()