# DiffuSynthV0.2 / metrics / visualizations.py
# Uploaded by WeixuanYuan ("Upload 66 files", commit ae1bdf7, 4.88 kB).
import numpy as np
from matplotlib import pyplot as plt
from scipy.fft import fft
from scipy.signal import savgol_filter
from tools import rms_normalize
# Line colours as 0-255 RGB triples. The plotting functions below take them in order,
# cycling with `i % len(colors)`, and divide each channel by 255 for matplotlib.
# The values match the colour-blind-friendly Okabe-Ito palette; the remaining palette
# entries are intentionally unused here:
#   black (0, 0, 0), sky blue (86, 180, 233), yellow (240, 228, 66),
#   reddish purple (204, 121, 167).
colors = list({
    "vermilion": (213, 94, 0),
    "blue": (0, 114, 178),
    "orange": (230, 159, 0),
    "bluish_green": (0, 158, 115),
}.values())
def plot_psd_multiple_signals(signals_list, labels_list, sample_rate=16000, window_size=500,
                              figsize=(10, 6), save_path=None, normalize=False):
    """
    Compare the power spectra of several groups of audio signals on a single figure.

    For each group, every signal is RMS-normalized and its FFT power |X|^2 is computed.
    The power is averaged over the signals of the group. Only the non-negative-frequency
    half is kept; it is compressed with log2(1 + P) and smoothed with a cubic
    Savitzky-Golay filter before being plotted against frequency.

    Parameters:
        signals_list: List of signal groups. Each group is converted to a numpy array of
            shape [sample_number, sample_length].
        labels_list: Legend label for each group; must have the same length as
            ``signals_list``.
        sample_rate: Sampling rate of the audio, in Hz (used for the frequency axis).
        window_size: Window length of the Savitzky-Golay filter (polynomial order 3).
            With scipy's default ``mode='interp'`` it must not exceed
            ``sample_length // 2``, the number of plotted points.
        figsize: Matplotlib figure size.
        save_path: If truthy, the figure is saved to this path; otherwise ``plt.show()``
            is called.
        normalize: If True, each smoothed curve is divided by its own mean.
    """
    # Every group needs exactly one label.
    assert len(signals_list) == len(labels_list), "每组信号必须有一个对应的标签。"
    # NOTE(review): `rms_normalize` comes from `tools`; presumably it rescales one 1-D
    # signal to a fixed RMS level - confirm.
    signals_list = [np.array([rms_normalize(signal) for signal in signals]) for signals in signals_list]

    fig = plt.figure(figsize=figsize)
    for i, (batch, label) in enumerate(zip(signals_list, labels_list)):
        n_samples = batch.shape[1]
        # fftfreq puts the non-negative frequencies in the first half of the spectrum.
        half = n_samples // 2
        # Mean power spectrum over all signals of the group.
        mean_power = np.mean(np.abs(fft(batch, axis=1)) ** 2, axis=0)
        freqs = np.fft.fftfreq(n_samples, 1 / sample_rate)
        # log2(1 + P) maps zero power to 0; smooth with a cubic Savitzky-Golay filter of
        # length `window_size`.
        psd_smoothed = savgol_filter(np.log2(mean_power[:half] + 1), window_size, 3)
        if normalize:
            psd_smoothed /= np.mean(psd_smoothed)
        rgb = [channel / 255.0 for channel in colors[i % len(colors)]]
        plt.plot(freqs[:half], psd_smoothed, label=label, color=rgb, linewidth=1)

    plt.xlabel('Frequency (Hz)')
    # NOTE(review): the plotted quantity is log2(1 + mean power), not an amplitude;
    # the label is kept unchanged for compatibility with existing figures.
    plt.ylabel('Mean Log-Amplitude')
    plt.legend()
    if save_path:
        plt.savefig(save_path)
    else:
        plt.show()
    # Release the figure. Otherwise pyplot keeps every figure created by repeated calls
    # alive, which leaks memory and eventually triggers the too-many-figures warning.
    plt.close(fig)
def plot_amplitude_over_time(signals_list, labels_list, sample_rate=16000, window_size=500,
                             figsize=(10, 6), save_path=None, normalize=False, start_time=0):
    """
    Plot the loudness of multiple sets of audio signals over time on the same graph.

    Each signal is RMS-normalized and then truncated to start at ``start_time``. For every
    set, the absolute amplitude is averaged over its signals sample by sample. The mean
    is compressed with log2(1 + A) and smoothed with a cubic Savitzky-Golay filter.

    Parameters:
        signals_list: List of sets of audio signals. Each set is converted to a numpy
            array of shape [sample_number, sample_length]. Sets may differ in length;
            each gets its own time axis.
        labels_list: List of labels corresponding to each set of audio signals.
        sample_rate: Sampling rate of the audio, in Hz.
        window_size: Window length of the Savitzky-Golay filter (polynomial order 3).
            With scipy's default ``mode='interp'`` it must not exceed the number of
            samples remaining after ``start_time``.
        figsize: Figure size.
        save_path: Path to save the figure; if falsy, the figure is displayed instead.
        normalize: If True, each smoothed curve is divided by its own mean (so every
            curve has mean 1).
        start_time: Time in seconds to start plotting; only samples from this point on
            are kept.
    """
    assert len(signals_list) == len(labels_list), (
        f"len(signals_list) != len(labels_list) for "
        f"len(signals_list) = {len(signals_list)} and len(labels_list) = {len(labels_list)}"
    )
    start_sample = int(start_time * sample_rate)
    # Normalize each whole signal first, then drop the samples before `start_time`.
    # NOTE(review): `rms_normalize` comes from `tools`; presumably it rescales one 1-D
    # signal to a fixed RMS level - confirm.
    signals_list = [np.array([rms_normalize(signal)[start_sample:] for signal in signals]) for signals in signals_list]

    fig = plt.figure(figsize=figsize)
    for i, (batch, label) in enumerate(zip(signals_list, labels_list)):
        # Build the time axis per set so that sets of different lengths plot correctly.
        time_axis = np.arange(start_sample, start_sample + batch.shape[1]) / sample_rate
        amplitude_mean = np.mean(np.abs(batch), axis=0)
        amplitude_smoothed = savgol_filter(np.log2(amplitude_mean + 1), window_size, 3)
        if normalize:
            amplitude_smoothed /= np.mean(amplitude_smoothed)
        rgb = [channel / 255.0 for channel in colors[i % len(colors)]]
        plt.plot(time_axis, amplitude_smoothed, label=label, color=rgb, linewidth=1)

    plt.xlabel('Time (seconds)')
    plt.ylabel('Mean Log-Amplitude')
    plt.legend()
    if save_path:
        plt.savefig(save_path)
    else:
        plt.show()
    # Release the figure so repeated calls do not accumulate open pyplot figures.
    plt.close(fig)