Spaces:
Running
Running
import numpy as np | |
from matplotlib import pyplot as plt | |
from scipy.fft import fft | |
from scipy.signal import savgol_filter | |
from tools import rms_normalize | |
# Colour-blind-safe palette (Okabe & Ito) as 0-255 RGB triples; the plotting
# helpers cycle through it in order. Palette members deliberately left out:
# black (0, 0, 0), sky blue (86, 180, 233), yellow (240, 228, 66),
# reddish purple (204, 121, 167).
colors = [
    (213, 94, 0),    # vermilion
    (0, 114, 178),   # blue
    (230, 159, 0),   # orange
    (0, 158, 115),   # bluish green
]
def plot_psd_multiple_signals(signals_list, labels_list, sample_rate=16000, window_size=500,
                              figsize=(10, 6), save_path=None, normalize=False):
    """
    Plot the averaged power spectra of several groups of audio signals on one figure.

    Every signal is RMS-normalized, transformed with an FFT, and the squared
    magnitudes are averaged over the signals of its group. The plotted curve is
    log2(mean power + 1) over the first ``sample_length // 2`` (non-negative)
    frequency bins, smoothed with a Savitzky-Golay filter of polynomial order 3.

    Parameters:
        signals_list: list of groups; each group is a numpy array of shape
            [sample_number, sample_length]. A single 1-D signal is also accepted
            and treated as a group of one.
        labels_list: legend label for each group; must have the same length as
            ``signals_list``.
        sample_rate: sampling rate of the audio in Hz (used for the frequency axis).
        window_size: Savitzky-Golay window length. It is clamped to the number of
            frequency bins so short signals do not make the filter raise; if the
            clamped window is not larger than the polynomial order, the curve is
            left unsmoothed.
        figsize: matplotlib figure size.
        save_path: if given, the figure is saved to this path and closed;
            otherwise it is shown with ``plt.show()``.
        normalize: if True, each curve is divided by its mean (so its mean
            becomes 1); curves whose mean is 0 are left unchanged.
    """
    assert len(signals_list) == len(labels_list), "每组信号必须有一个对应的标签。"
    polyorder = 3

    fig = plt.figure(figsize=figsize)
    for idx, (signals, label) in enumerate(zip(signals_list, labels_list)):
        # RMS-normalize each signal so groups recorded at different levels are comparable.
        batch = np.array([rms_normalize(s) for s in np.atleast_2d(signals)])
        n_samples = batch.shape[1]
        half = n_samples // 2

        # Mean power spectrum over the signals in this group.
        power = np.mean(np.abs(fft(batch, axis=1)) ** 2, axis=0)
        freqs = np.fft.fftfreq(n_samples, 1 / sample_rate)[:half]

        # +1 keeps log2 finite (and non-negative) for bins with zero power.
        curve = np.log2(power[:half] + 1)
        win = min(window_size, curve.size)
        if win > polyorder:
            curve = savgol_filter(curve, win, polyorder)

        if normalize:
            mean = np.mean(curve)
            if mean != 0:  # avoid producing NaNs for an all-zero curve
                curve = curve / mean

        rgb = colors[idx % len(colors)]
        plt.plot(freqs, curve, label=label, color=[c / 255.0 for c in rgb], linewidth=1)

    plt.xlabel('Frequency (Hz)')
    # The curve is log2 of |X|^2 (power), not of amplitude.
    plt.ylabel('Log2 Mean Power')
    plt.legend()
    if save_path:
        plt.savefig(save_path)
        # Release the figure; otherwise repeated calls accumulate open figures.
        plt.close(fig)
    else:
        plt.show()
def plot_amplitude_over_time(signals_list, labels_list, sample_rate=16000, window_size=500,
                             figsize=(10, 6), save_path=None, normalize=False, start_time=0):
    """
    Plot the smoothed log-amplitude envelope of several groups of audio signals over time.

    Every signal is RMS-normalized over its full length, then samples before
    ``start_time`` are dropped. For each group, the absolute sample values are
    averaged across signals, mapped through log2(x + 1), and smoothed with a
    Savitzky-Golay filter of polynomial order 3.

    Parameters:
        signals_list: list of groups; each group is a numpy array of shape
            [sample_number, sample_length]. A single 1-D signal is also accepted
            and treated as a group of one. Groups may have different lengths.
        labels_list: legend label for each group; must match len(signals_list).
        sample_rate: sampling rate of the audio in Hz.
        window_size: Savitzky-Golay window length. It is clamped to the number of
            retained samples; if the clamped window is not larger than the
            polynomial order, the curve is left unsmoothed.
        figsize: matplotlib figure size.
        save_path: if given, the figure is saved to this path and closed;
            otherwise it is shown with ``plt.show()``.
        normalize: if True, each curve is divided by its mean (so its mean
            becomes 1); curves whose mean is 0 are left unchanged.
        start_time: time in seconds from which to plot; negative values are
            treated as 0.
    """
    assert len(signals_list) == len(labels_list), f"len(signals_list) != len(labels_list) for " \
        f"len(signals_list) = {len(signals_list)} and len(labels_list) = {len(labels_list)}"
    polyorder = 3
    # Clamp at 0: a negative index would make the slice below count from the end.
    start_sample = max(0, int(start_time * sample_rate))

    fig = plt.figure(figsize=figsize)
    for idx, (signals, label) in enumerate(zip(signals_list, labels_list)):
        # Normalize over the full signal first, then truncate to the plotted range.
        batch = np.array([rms_normalize(s)[start_sample:] for s in np.atleast_2d(signals)])
        # Per-group time axis, so groups of different lengths can share one plot.
        time_axis = np.arange(start_sample, start_sample + batch.shape[1]) / sample_rate

        # +1 keeps log2 finite (and non-negative) where the mean amplitude is 0.
        curve = np.log2(np.mean(np.abs(batch), axis=0) + 1)
        win = min(window_size, curve.size)
        if win > polyorder:
            curve = savgol_filter(curve, win, polyorder)

        if normalize:
            mean = np.mean(curve)
            if mean != 0:  # avoid producing NaNs for an all-zero curve
                curve = curve / mean

        rgb = colors[idx % len(colors)]
        plt.plot(time_axis, curve, label=label, color=[c / 255.0 for c in rgb], linewidth=1)

    plt.xlabel('Time (seconds)')
    plt.ylabel('Mean Log-Amplitude')
    plt.legend()
    if save_path:
        plt.savefig(save_path)
        # Release the figure; otherwise repeated calls accumulate open figures.
        plt.close(fig)
    else:
        plt.show()