import os
import csv
import argparse
import sys

import numpy as np
import torch
import torchaudio
import timm
from torch.cuda.amp import autocast
import IPython

current_directory = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_directory)

from src.models import ASTModel


# Create a new class that inherits the original ASTModel class
class ASTModelVis(ASTModel):
    def get_att_map(self, block, x):
        qkv = block.attn.qkv
        num_heads = block.attn.num_heads
        scale = block.attn.scale
        B, N, C = x.shape
        qkv = qkv(x).reshape(B, N, 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        attn = (q @ k.transpose(-2, -1)) * scale
        attn = attn.softmax(dim=-1)
        return attn

    def forward_visualization(self, x):
        # expect input x = (batch_size, time_frame_num, frequency_bins), e.g., (12, 1024, 128)
        x = x.unsqueeze(1)
        x = x.transpose(2, 3)
        B = x.shape[0]
        x = self.v.patch_embed(x)
        cls_tokens = self.v.cls_token.expand(B, -1, -1)
        dist_token = self.v.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)
        x = x + self.v.pos_embed
        x = self.v.pos_drop(x)
        # save the attention map of each of the 12 Transformer layers
        att_list = []
        for blk in self.v.blocks:
            cur_att = self.get_att_map(blk, x)
            att_list.append(cur_att)
            x = blk(x)
        return att_list


def make_features(wav_name, mel_bins, target_length=1024):
    waveform, sr = torchaudio.load(wav_name)
    # assert sr == 16000, 'input audio sampling rate must be 16kHz'

    fbank = torchaudio.compliance.kaldi.fbank(
        waveform, htk_compat=True, sample_frequency=sr, use_energy=False,
        window_type='hanning', num_mel_bins=mel_bins, dither=0.0, frame_shift=10)

    # pad with zeros or truncate so the spectrogram has exactly target_length frames
    n_frames = fbank.shape[0]
    p = target_length - n_frames
    if p > 0:
        m = torch.nn.ZeroPad2d((0, 0, 0, p))
        fbank = m(fbank)
    elif p < 0:
        fbank = fbank[0:target_length, :]

    # normalize with the dataset mean and std used by AST
    fbank = (fbank - (-4.2677393)) / (4.5689974 * 2)
    return fbank


def load_label(label_csv):
    with open(label_csv, 'r') as f:
        reader = csv.reader(f, delimiter=',')
        lines = list(reader)
    labels = []
    ids = []  # Each label has a unique id such as "/m/068hy"
    for i1 in range(1, len(lines)):
        id = lines[i1][1]
        label = lines[i1][2]
        ids.append(id)
        labels.append(label)
    return labels


def ASTpredict():
    # Assume each input spectrogram has 1024 time frames
    input_tdim = 1024
    checkpoint_path = './ast_master/pretrained_models/audio_mdl.pth'

    # now load the visualization model
    ast_mdl = ASTModelVis(label_dim=527, input_tdim=input_tdim,
                          imagenet_pretrain=False, audioset_pretrain=False)
    print(f'[*INFO] load checkpoint: {checkpoint_path}')
    checkpoint = torch.load(checkpoint_path, map_location='cuda')
    audio_model = torch.nn.DataParallel(ast_mdl, device_ids=[0])
    audio_model.load_state_dict(checkpoint)
    audio_model = audio_model.to(torch.device("cuda:0"))
    audio_model.eval()

    # Load the AudioSet label set (labels and indices for AudioSet data)
    label_csv = './ast_master/egs/audioset/data/class_labels_indices.csv'
    labels = load_label(label_csv)

    feats = make_features("./audio.flac", mel_bins=128)  # shape (1024, 128)
    feats_data = feats.expand(1, input_tdim, 128)  # reshape the feature to (1, 1024, 128)
    feats_data = feats_data.to(torch.device("cuda:0"))
    # do some masking of the input
    # feats_data[:, :512, :] = 0.
    # Make the prediction
    with torch.no_grad():
        with autocast():
            output = audio_model.forward(feats_data)
            output = torch.sigmoid(output)
    result_output = output.data.cpu().numpy()[0]
    sorted_indexes = np.argsort(result_output)[::-1]

    # Print audio tagging top probabilities
    print('Prediction results:')
    for k in range(10):
        print('- {}: {:.4f}'.format(np.array(labels)[sorted_indexes[k]],
                                     result_output[sorted_indexes[k]]))

    # return the top 10 labels and their probabilities
    top_labels_probs = {}
    top_labels = {}
    for k in range(10):
        label = np.array(labels)[sorted_indexes[k]]
        prob = result_output[sorted_indexes[k]]
        top_labels[k] = label
        top_labels_probs[k] = prob
    return top_labels, top_labels_probs
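
# Example entry point (a minimal sketch, not part of the original script): running this
# module directly calls ASTpredict() on the hard-coded ./audio.flac and prints the
# top-10 AudioSet labels it returns. It assumes a CUDA device, the checkpoint, and the
# label CSV are available at the paths set inside ASTpredict().
if __name__ == '__main__':
    top_labels, top_labels_probs = ASTpredict()
    for k in sorted(top_labels):
        print(f'{k + 1:2d}. {top_labels[k]} ({top_labels_probs[k]:.4f})')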