KeiKinn committed
Commit d4be371 · Parent(s): 74dd5ba

evaluation instruction

Files changed:
- .gitignore +142 -0
- README.md +78 -0
- eval.py +61 -0
- models_xin.py +68 -0
- requirements.txt +5 -0
- utils.py +100 -0
- wrapper.py +23 -0
.gitignore
ADDED
@@ -0,0 +1,142 @@
+# Users
+*_.py
+*.pth.tar
+temp
+slurm*
+.envrc
+__pycache__/*
+outputs/*
+templates/*
+sample
+.idea
+.vscode
+main.py
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
README.md
CHANGED
@@ -2,6 +2,84 @@ This repo includes the official PyTorch checkpoint of *ParaCLAP – Towards a ge
 
 ## Abstract
 Contrastive language-audio pretraining (CLAP) has recently emerged as a method for making audio analysis more generalisable. Specifically, CLAP-style models are able to ‘answer’ a diverse set of language queries, extending the capabilities of audio models beyond a closed set of labels. However, CLAP relies on a large set of (audio, query) pairs for pretraining. While such sets are available for general audio tasks, like captioning or sound event detection, there are no datasets with matched audio and text queries for computational paralinguistic (CP) tasks. As a result, the community relies on generic CLAP models trained for general audio with limited success. In the present study, we explore training considerations for ParaCLAP, a CLAP-style model suited to CP, including a novel process for creating audio-language queries. We demonstrate its effectiveness on a set of computational paralinguistic tasks, where it is shown to surpass the performance of open-source state-of-the-art models.
+
+## Instructions
+Before evaluation, we recommend cloning the repo from HuggingFace or [GitHub](https://github.com/KeiKinn/ParaCLAP).
+### Evaluation
+```python
+import os
+import torch
+import librosa
+from transformers import logging
+from transformers import AutoTokenizer
+from models_xin import CLAP
+from utils import compute_similarity
+
+
+if __name__ == '__main__':
+    logging.set_verbosity_error()
+    ckpt = torch.hub.load_state_dict_from_url(
+        url="https://huggingface.co/KeiKinn/paraclap/resolve/main/best.pth.tar?download=true",
+        map_location="cpu",
+        check_hash=True,
+    )
+
+    text_model = 'bert-base-uncased'
+    audio_model = 'audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim'
+
+    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+
+    candidates = ['happy', 'sad', 'surprise', 'angry']  # feel free to adapt to your needs
+    wavpath = '[Waveform path]'  # path to a single-channel waveform
+
+    waveform, sample_rate = librosa.load(wavpath, sr=16000)
+    x = torch.Tensor(waveform)
+
+    tokenizer = AutoTokenizer.from_pretrained(text_model)
+
+    candidate_tokens = tokenizer.batch_encode_plus(
+        candidates,
+        padding=True,
+        truncation=True,
+        return_tensors='pt'
+    )
+
+    model = CLAP(
+        speech_name=audio_model,
+        text_name=text_model,
+        embedding_dim=768,
+    )
+
+    model.load_state_dict(ckpt)
+    model.to(device)
+    print('Checkpoint is loaded')
+    model.eval()
+
+    with torch.no_grad():
+        z = model(
+            x.unsqueeze(0).to(device),
+            candidate_tokens
+        )
+
+    similarity = compute_similarity(z[2], z[0], z[1])
+    prediction = similarity.T.argmax(dim=1)
+
+    result = candidates[prediction]
+```
+
+## Citation Info
+ParaCLAP has been accepted for presentation at Interspeech 2024.
+```bibtex
+@inproceedings{Jing24_PTA,
+    title = {ParaCLAP – Towards a general language-audio model for computational paralinguistic tasks},
+    author = {Xin Jing and Andreas Triantafyllopoulos and Björn Schuller},
+    year = {2024},
+    booktitle = {Interspeech 2024},
+    pages = {1155--1159},
+    doi = {10.21437/Interspeech.2024-1315},
+    issn = {2958-1796},
+}
+```
 ---
 license: cc-by-nc-nd-4.0
 language:
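The README example stops at `result`, the top-ranked candidate. If per-candidate scores are also of interest, the raw similarities can be turned into a probability distribution with a softmax. The sketch below is not part of the repository; it simply continues from the `similarity` and `candidates` variables defined in the example above.

```python
import torch

# Continuing from the README example above (hypothetical addition, not in the repo):
# `similarity.T` has shape [1, len(candidates)], one scaled similarity per text query.
probs = torch.softmax(similarity.T, dim=1).squeeze(0)

# Rank candidates from most to least likely.
for label, p in sorted(zip(candidates, probs.tolist()), key=lambda kv: -kv[1]):
    print(f'{label}: {p:.3f}')
```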
eval.py
ADDED
@@ -0,0 +1,61 @@
+import os
+import torch
+from transformers import logging
+from transformers import AutoTokenizer
+from wrapper import EvalWrapper
+from models_xin import CLAP
+from utils import compute_similarity
+import librosa
+
+
+if __name__ == '__main__':
+    logging.set_verbosity_error()
+    ckpt = torch.hub.load_state_dict_from_url(
+        url="https://huggingface.co/KeiKinn/paraclap/resolve/main/best.pth.tar?download=true",
+        map_location="cpu",
+        check_hash=True,
+    )
+
+    text_model = 'bert-base-uncased'
+    audio_model = 'audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim'
+
+    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+
+    candidates = ['happy', 'sad', 'surprise', 'angry']  # feel free to adapt to your needs
+    wavpath = '[Waveform path]'  # path to a single-channel waveform
+
+    waveform, sample_rate = librosa.load(wavpath, sr=16000)
+    x = torch.Tensor(waveform)
+
+    tokenizer = AutoTokenizer.from_pretrained(text_model)
+
+    candidate_tokens = tokenizer.batch_encode_plus(
+        candidates,
+        padding=True,
+        truncation=True,
+        return_tensors='pt'
+    )
+
+    model = CLAP(
+        speech_name=audio_model,
+        text_name=text_model,
+        embedding_dim=768,
+    )
+
+    model.load_state_dict(ckpt)
+    model.to(device)
+    print('Checkpoint is loaded')
+    model.eval()
+
+    with torch.no_grad():
+        z = model(
+            x.unsqueeze(0).to(device),
+            candidate_tokens
+        )
+
+    similarity = compute_similarity(z[2], z[0], z[1])
+    prediction = similarity.T.argmax(dim=1)
+
+    result = candidates[prediction]
+
+    print(result)
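One caveat when running `eval.py` on a GPU: `candidate_tokens` is produced on the CPU while `model` is moved to `device`, so the text branch can end up receiving CPU tensors alongside GPU weights. A minimal, hypothetical adjustment (not part of the committed script) is to move the tokenized candidates to the same device before the forward pass:

```python
# Hypothetical tweak, not in the committed eval.py: BatchEncoding objects
# returned by the tokenizer support .to(device), so the candidate tokens
# can be placed on the same device as the model before inference.
candidate_tokens = candidate_tokens.to(device)

with torch.no_grad():
    z = model(
        x.unsqueeze(0).to(device),
        candidate_tokens
    )
```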
models_xin.py
ADDED
@@ -0,0 +1,68 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import (
+    AutoModel,
+    Wav2Vec2Model,
+)
+
+class Projection(torch.nn.Module):
+    def __init__(self, d_in: int, d_out: int, p: float = 0.5) -> None:
+        super().__init__()
+        self.linear1 = torch.nn.Linear(d_in, d_out, bias=False)
+        self.linear2 = torch.nn.Linear(d_out, d_out, bias=False)
+        self.layer_norm = torch.nn.LayerNorm(d_out)
+        self.drop = torch.nn.Dropout(p)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        embed1 = self.linear1(x)
+        embed2 = self.drop(self.linear2(F.gelu(embed1)))
+        embeds = self.layer_norm(embed1 + embed2)
+        return embeds
+
+
+class SpeechEncoder(torch.nn.Module):
+    def __init__(self, model_name):
+        super().__init__()
+        self.model_name = model_name
+        self.base = Wav2Vec2Model.from_pretrained(self.model_name)
+        self.hidden_size = self.base.config.hidden_size
+
+    def forward(self, x):
+        x = self.base(x)['last_hidden_state']
+        x = x.mean(1)
+        return x
+
+
+class TextEncoder(torch.nn.Module):
+    def __init__(self, model_name: str) -> None:
+        super().__init__()
+        self.base = AutoModel.from_pretrained(model_name)
+
+    def forward(self, x):
+        out = self.base(**x)[0]
+        out = out[:, 0, :].detach()  # get CLS token output
+        return out
+
+
+class CLAP(torch.nn.Module):
+    def __init__(self, speech_name: str, text_name: str, embedding_dim: int = 1024):
+        super().__init__()
+
+        self.audio_branch = SpeechEncoder(model_name=speech_name)
+
+        self.text_branch = TextEncoder(model_name=text_name)
+        self.audio_projection = Projection(self.audio_branch.hidden_size, embedding_dim)
+        self.text_projection = Projection(self.text_branch.base.config.hidden_size, embedding_dim)
+
+        self.logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+    def forward(self, audio, text):
+        speech_emb = self.audio_branch(audio)
+        text_emb = self.text_branch(text)
+
+        speech_emb = self.audio_projection(speech_emb)
+        text_emb = self.text_projection(text_emb)
+
+        return text_emb, speech_emb, self.logit_scale.exp()
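For orientation, `CLAP.forward` returns a tuple `(text_emb, speech_emb, logit_scale.exp())`, with both embeddings projected into the shared `embedding_dim` space. The sketch below is only illustrative and assumes the `model`, `x`, `device`, and `candidate_tokens` objects from the evaluation example above; nothing in it is defined in `models_xin.py` itself.

```python
# Illustrative shape check (assumes model, x, device, candidate_tokens from eval.py).
with torch.no_grad():
    text_emb, speech_emb, scale = model(x.unsqueeze(0).to(device), candidate_tokens)

print(text_emb.shape)    # [len(candidates), 768] -> one projected embedding per query
print(speech_emb.shape)  # [1, 768]               -> one projected embedding per waveform
print(float(scale))      # exp(logit_scale), the learned temperature
```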
requirements.txt
ADDED
@@ -0,0 +1,5 @@
+audformat
+audmetric
+audtorch
+torch
+transformers==4.25.1
utils.py
ADDED
@@ -0,0 +1,100 @@
+import torch
+import torch.nn.functional as F
+import collections
+
+def compute_similarity(logit_scale, audio_embeddings, text_embeddings):
+    r"""Compute similarity between text and audio embeddings"""
+    audio_embeddings = audio_embeddings/torch.norm(audio_embeddings, dim=-1, keepdim=True)
+    text_embeddings = text_embeddings/torch.norm(text_embeddings, dim=-1, keepdim=True)
+
+    similarity = logit_scale*text_embeddings @ audio_embeddings.T
+    return similarity.T
+
+def compute_logit(logit_scale, audio_embeddings, text_embeddings):
+    logits_per_audio = logit_scale * audio_embeddings @ text_embeddings.T
+    logits_per_text = logit_scale * text_embeddings @ audio_embeddings.T
+    return logits_per_audio, logits_per_text
+
+def laion_compute_similarity(logit_scale, audio_embeddings, text_embeddings):
+    r"""Compute similarity between text and audio embeddings"""
+    audio_embeddings = F.normalize(audio_embeddings, dim=-1)
+    text_embeddings = F.normalize(text_embeddings, dim=-1)
+
+    similarity = logit_scale*audio_embeddings @ text_embeddings.T
+    return similarity
+
+def freeze_branch_parameters(named_parameters, branch_name, freeze_flag):
+    branch_parameters = [
+        p
+        for n, p in named_parameters
+        if branch_name in n
+    ]
+    if freeze_flag:
+        print(f"Freezing {branch_name.capitalize()} parameters.")
+        for param in branch_parameters:
+            param.requires_grad = False
+
+def format_emotion(emotion):
+    if emotion == 'no_agreement':
+        return 'there is no clear emotion.'
+    else:
+        return f'this person is feeling {emotion}.'
+
+
+def preprocess_text(text_queries, tokenizer):
+    r"""Load list of class labels and return tokenized text"""
+    token_keys = ['input_ids', 'token_type_ids', 'attention_mask']
+    tokenized_texts = []
+    for ttext in text_queries:
+        tok = tokenizer.encode_plus(
+            text=ttext, add_special_tokens=True, max_length=77, padding='max_length', return_tensors="pt")
+        for key in token_keys:
+            tok[key] = tok[key].reshape(-1).cuda()
+        tokenized_texts.append(tok)
+    return default_collate(tokenized_texts)
+
+def default_collate(batch):
+    r"""Puts each data field into a tensor with outer dimension batch size"""
+    elem = batch[0]
+    elem_type = type(elem)
+    if isinstance(elem, torch.Tensor):
+        out = None
+        if torch.utils.data.get_worker_info() is not None:
+            # If we're in a background process, concatenate directly into a
+            # shared memory tensor to avoid an extra copy
+            numel = sum([x.numel() for x in batch])
+            storage = elem.storage()._new_shared(numel)
+            out = elem.new(storage)
+        return torch.stack(batch, 0, out=out)
+    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
+            and elem_type.__name__ != 'string_':
+        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
+            # array of string classes and object
+            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
+                raise TypeError(
+                    default_collate_err_msg_format.format(elem.dtype))
+
+            return default_collate([torch.as_tensor(b) for b in batch])
+        elif elem.shape == ():  # scalars
+            return torch.as_tensor(batch)
+    elif isinstance(elem, float):
+        return torch.tensor(batch, dtype=torch.float64)
+    elif isinstance(elem, int):
+        return torch.tensor(batch)
+    elif isinstance(elem, str):
+        return batch
+    elif isinstance(elem, collections.abc.Mapping):
+        return {key: default_collate([d[key] for d in batch]) for key in elem}
+    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
+        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
+    elif isinstance(elem, collections.abc.Sequence):
+        # check to make sure that the elements in batch have consistent size
+        it = iter(batch)
+        elem_size = len(next(it))
+        if not all(len(elem) == elem_size for elem in it):
+            raise RuntimeError(
+                'each element in list of batch should be of equal size')
+        transposed = zip(*batch)
+        return [default_collate(samples) for samples in transposed]
+
+    raise TypeError(default_collate_err_msg_format.format(elem_type))
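Note that `default_collate` references `np_str_obj_array_pattern` and `default_collate_err_msg_format`, which are not defined anywhere in this commit, and `preprocess_text` hard-codes `.cuda()`, so it requires a GPU. Importing `utils.py` on its own therefore hits a `NameError` in the numpy branch of `default_collate`. One way to make the module self-contained would be to add definitions mirroring the helpers in `torch.utils.data._utils.collate`; the exact strings below are an assumption, not part of the commit.

```python
import re

# Hypothetical module-level definitions for utils.py, modelled on torch's
# own collate helpers; only the numpy branch of default_collate() needs them.
np_str_obj_array_pattern = re.compile(r'[SaUO]')
default_collate_err_msg_format = (
    "default_collate: batch must contain tensors, numpy arrays, numbers, "
    "dicts or lists; found {}"
)
```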
wrapper.py
ADDED
@@ -0,0 +1,23 @@
+import os
+
+class EvalWrapper:
+    def __init__(self, dataset_name):
+        self.name = dataset_name.lower()
+        self.evaluate_map = {
+            'iemocap': 'evaluation.evaluate_iemo',
+            'ravdess': 'evaluation.evaluate_ravdess',
+            'cremad-d': 'evaluation.evaluate_cremad',
+            'tess': 'evaluation.evaluate_tess',
+            'aibo': 'evaluation.evaluate_aibo'
+        }
+
+    def set_eval(self):
+        # Get the module path dynamically
+        module_path = self.evaluate_map.get(self.name)
+        if not module_path:
+            supported_datasets = ', '.join(self.evaluate_map.keys())
+            raise ValueError(f"Unsupported dataset name: {self.name}.\nSupported datasets are: {supported_datasets}")
+
+        # Import the evaluate function dynamically
+        evaluate = __import__(module_path, fromlist=['evaluate']).evaluate
+        return self.name, evaluate
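`EvalWrapper` only maps a dataset name to an `evaluate` function that is imported at runtime; the `evaluation.*` modules it points to are not part of this commit. A minimal usage sketch, assuming those modules exist elsewhere on the Python path:

```python
from wrapper import EvalWrapper

# Resolve the dataset-specific evaluation routine by name.
# Raises ValueError for unsupported names; the dynamic import only succeeds
# if the (not included) `evaluation` package is importable.
name, evaluate = EvalWrapper('iemocap').set_eval()
print(f'Selected evaluation routine for: {name}')
# evaluate(...)  # signature defined by evaluation.evaluate_iemo
```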