LandyGuo committed
Commit: e92022a · Parent(s): 736eafa
First model version
This view is limited to 50 files because the commit contains too many changes.
- .gitattributes +7 -0
- __pycache__/bailingmm_utils.cpython-38.pyc +0 -0
- __pycache__/chat_format.cpython-38.pyc +0 -0
- __pycache__/configuration_audio.cpython-38.pyc +0 -0
- __pycache__/configuration_bailing_moe.cpython-38.pyc +0 -0
- __pycache__/configuration_bailing_talker.cpython-38.pyc +0 -0
- __pycache__/configuration_bailingmm.cpython-38.pyc +0 -0
- __pycache__/image_processing_bailingmm.cpython-38.pyc +0 -0
- __pycache__/modeling_bailing_moe.cpython-38.pyc +0 -0
- __pycache__/modeling_bailing_talker.cpython-38.pyc +0 -0
- __pycache__/modeling_bailingmm.cpython-38.pyc +0 -0
- __pycache__/modeling_utils.cpython-38.pyc +0 -0
- __pycache__/qwen2_5_vit.cpython-38.pyc +0 -0
- __pycache__/s3bpe_tokenizer.cpython-38.pyc +0 -0
- am.mvn +8 -0
- audio_detokenizer/__init__.py +0 -0
- audio_detokenizer/__pycache__/__init__.cpython-38.pyc +0 -0
- audio_detokenizer/cli/__init__.py +0 -0
- audio_detokenizer/cli/__pycache__/__init__.cpython-38.pyc +0 -0
- audio_detokenizer/cli/__pycache__/model.cpython-38.pyc +0 -0
- audio_detokenizer/cli/model.py +62 -0
- audio_detokenizer/flow/__init__.py +0 -0
- audio_detokenizer/flow/__pycache__/__init__.cpython-38.pyc +0 -0
- audio_detokenizer/flow/__pycache__/decoder.cpython-38.pyc +0 -0
- audio_detokenizer/flow/__pycache__/flow.cpython-38.pyc +0 -0
- audio_detokenizer/flow/__pycache__/flow_matching.cpython-38.pyc +0 -0
- audio_detokenizer/flow/__pycache__/length_regulator.cpython-38.pyc +0 -0
- audio_detokenizer/flow/decoder.py +224 -0
- audio_detokenizer/flow/flow.py +148 -0
- audio_detokenizer/flow/flow_matching.py +242 -0
- audio_detokenizer/flow/length_regulator.py +49 -0
- audio_detokenizer/hifigan/__init__.py +0 -0
- audio_detokenizer/hifigan/__pycache__/__init__.cpython-38.pyc +0 -0
- audio_detokenizer/hifigan/__pycache__/f0_predictor.cpython-38.pyc +0 -0
- audio_detokenizer/hifigan/__pycache__/generator.cpython-38.pyc +0 -0
- audio_detokenizer/hifigan/f0_predictor.py +55 -0
- audio_detokenizer/hifigan/generator.py +392 -0
- audio_detokenizer/transformer/__init__.py +0 -0
- audio_detokenizer/transformer/__pycache__/__init__.cpython-38.pyc +0 -0
- audio_detokenizer/transformer/__pycache__/activation.cpython-38.pyc +0 -0
- audio_detokenizer/transformer/__pycache__/attention.cpython-38.pyc +0 -0
- audio_detokenizer/transformer/__pycache__/convolution.cpython-38.pyc +0 -0
- audio_detokenizer/transformer/__pycache__/embedding.cpython-38.pyc +0 -0
- audio_detokenizer/transformer/__pycache__/encoder.cpython-38.pyc +0 -0
- audio_detokenizer/transformer/__pycache__/encoder_layer.cpython-38.pyc +0 -0
- audio_detokenizer/transformer/__pycache__/positionwise_feed_forward.cpython-38.pyc +0 -0
- audio_detokenizer/transformer/__pycache__/subsampling.cpython-38.pyc +0 -0
- audio_detokenizer/transformer/activation.py +84 -0
- audio_detokenizer/transformer/attention.py +463 -0
- audio_detokenizer/transformer/convolution.py +145 -0
.gitattributes CHANGED
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+data/matcha_tts-0.0.5.1-cp38-cp38-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text
+data/wavs/BAC009S0915W0292.wav filter=lfs diff=lfs merge=lfs -text
+out.wav filter=lfs diff=lfs merge=lfs -text
+talker/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+
+
+
__pycache__/bailingmm_utils.cpython-38.pyc ADDED: binary file (13.5 kB)
__pycache__/chat_format.cpython-38.pyc ADDED: binary file (21.9 kB)
__pycache__/configuration_audio.cpython-38.pyc ADDED: binary file (1.03 kB)
__pycache__/configuration_bailing_moe.cpython-38.pyc ADDED: binary file (1.99 kB)
__pycache__/configuration_bailing_talker.cpython-38.pyc ADDED: binary file (1.14 kB)
__pycache__/configuration_bailingmm.cpython-38.pyc ADDED: binary file (1.2 kB)
__pycache__/image_processing_bailingmm.cpython-38.pyc ADDED: binary file (16.9 kB)
__pycache__/modeling_bailing_moe.cpython-38.pyc ADDED: binary file (48.1 kB)
__pycache__/modeling_bailing_talker.cpython-38.pyc ADDED: binary file (8.19 kB)
__pycache__/modeling_bailingmm.cpython-38.pyc ADDED: binary file (11.6 kB)
__pycache__/modeling_utils.cpython-38.pyc ADDED: binary file (26.3 kB)
__pycache__/qwen2_5_vit.cpython-38.pyc ADDED: binary file (16.3 kB)
__pycache__/s3bpe_tokenizer.cpython-38.pyc ADDED: binary file (2.24 kB)
am.mvn ADDED (new file, 8 lines)

<Nnet>
<Splice> 560 560
[ 0 ]
<AddShift> 560 560
<LearnRateCoef> 0 [ -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 
-13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 ]
<Rescale> 560 560
<LearnRateCoef> 0 [ 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 
0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 ]
</Nnet>
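
Note (not part of the commit): am.mvn looks like a Kaldi-style feature-normalization net. The 560-wide <AddShift> and <Rescale> vectors repeat an 80-value pattern seven times, consistent with 80-dimensional fbank features spliced over 7 frames; <AddShift> holds negated means and <Rescale> inverse standard deviations. A minimal sketch of how such parameters are typically applied; parsing of am.mvn into arrays is assumed, not taken from this repository:

import numpy as np

def apply_cmvn(feats: np.ndarray, add_shift: np.ndarray, rescale: np.ndarray) -> np.ndarray:
    # Kaldi AddShift/Rescale semantics: y = (x + shift) * scale, applied per dimension.
    # feats: (num_frames, 560); add_shift and rescale: (560,) vectors read from am.mvn.
    return (feats + add_shift) * rescale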
audio_detokenizer/__init__.py ADDED: empty file
audio_detokenizer/__pycache__/__init__.cpython-38.pyc ADDED: binary file (178 Bytes)
audio_detokenizer/cli/__init__.py ADDED: empty file
audio_detokenizer/cli/__pycache__/__init__.cpython-38.pyc ADDED: binary file (182 Bytes)
audio_detokenizer/cli/__pycache__/model.cpython-38.pyc ADDED: binary file (2.02 kB)
audio_detokenizer/cli/model.py ADDED (new file, 62 lines)

# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
import time

class AudioDetokenizerModel:

    def __init__(self,
                 flow: torch.nn.Module,
                 hift: torch.nn.Module,
                 lora_config=None):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.flow = flow
        self.hift = hift
        self.dtype = torch.float16
        # self.dtype = torch.bfloat16
        self.max_seq_short = 384
        self.max_seq_long = 2048
        self.max_batch = 1

    def load(self, flow_model, hift_model):
        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
        self.flow.to(self.device).eval().to(self.dtype)
        self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
        self.hift.to(self.device).eval()

    def inference(self, flow_embedding, tts_speech_token,
                  flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
                  prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32), is_en=False):

        torch.cuda.synchronize()
        t0 = time.time()

        torch.cuda.synchronize()
        t1 = time.time()

        tts_mel = self.flow.inference(token=tts_speech_token.to(self.device),
                                      token_len=torch.tensor([tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
                                      prompt_token=flow_prompt_speech_token.to(self.device),
                                      prompt_token_len=flow_prompt_speech_token_len.to(self.device),
                                      prompt_feat=prompt_speech_feat.to(self.device),
                                      prompt_feat_len=prompt_speech_feat_len.to(self.device),
                                      embedding=flow_embedding.to(self.device).to(self.dtype)).float()
        torch.cuda.synchronize()

        tts_speech = self.hift.inference(mel=tts_mel).cpu()
        torch.cuda.synchronize()
        dur = tts_speech.shape[-1]/22050
        torch.cuda.empty_cache()
        return {'tts_speech': tts_speech}
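
A hedged usage sketch (not part of the commit): AudioDetokenizerModel turns discrete speech tokens plus a speaker embedding into a waveform by running the flow model to a mel-spectrogram and the HiFT vocoder to audio at 22050 Hz. The flow/hift instances, checkpoint paths, and tensor sizes below are assumptions for illustration.

import torch

# flow: a MaskedDiffWithXvec instance, hift: the HiFT generator; their construction is defined elsewhere.
model = AudioDetokenizerModel(flow=flow, hift=hift)
model.load("flow.pt", "hift.pt")                      # placeholder checkpoint paths

tts_speech_token = torch.randint(0, 4096, (1, 200), dtype=torch.int32)  # discrete audio tokens
flow_embedding = torch.randn(1, 192)                                     # speaker x-vector
out = model.inference(flow_embedding=flow_embedding, tts_speech_token=tts_speech_token)
waveform = out['tts_speech']                          # (1, num_samples) at 22050 Hz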
audio_detokenizer/flow/__init__.py ADDED: empty file
audio_detokenizer/flow/__pycache__/__init__.cpython-38.pyc ADDED: binary file (183 Bytes)
audio_detokenizer/flow/__pycache__/decoder.cpython-38.pyc ADDED: binary file (5.28 kB)
audio_detokenizer/flow/__pycache__/flow.cpython-38.pyc ADDED: binary file (4.15 kB)
audio_detokenizer/flow/__pycache__/flow_matching.cpython-38.pyc ADDED: binary file (6.1 kB)
audio_detokenizer/flow/__pycache__/length_regulator.cpython-38.pyc ADDED: binary file (1.48 kB)
audio_detokenizer/flow/decoder.py ADDED (new file, 224 lines)

# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# antflake8: noqa
import torch
import torch.nn as nn
from einops import pack, rearrange, repeat
from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
from matcha.models.components.transformer import BasicTransformerBlock


class ConditionalDecoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        channels=(256, 256),
        dropout=0.05,
        attention_head_dim=64,
        n_blocks=1,
        num_mid_blocks=2,
        num_heads=4,
        act_fn="snake",
    ):
        """
        This decoder requires an input with the same shape as the target. So, if your text content
        is shorter or longer than the outputs, please re-sample it before feeding it to the decoder.
        """
        super().__init__()
        channels = tuple(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels,
            time_embed_dim=time_embed_dim,
            act_fn="silu",
        )
        self.down_blocks = nn.ModuleList([])
        self.mid_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])
        self.compiled_infer = None

        output_channel = in_channels
        for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
            input_channel = output_channel
            output_channel = channels[i]
            is_last = i == len(channels) - 1
            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            downsample = (
                Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )
            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))

        for i in range(num_mid_blocks):
            input_channel = channels[-1]
            out_channels = channels[-1]
            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)

            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )

            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))

        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            input_channel = channels[i] * 2
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2
            resnet = ResnetBlock1D(
                dim=input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            )
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )
            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
        self.final_block = Block1D(channels[-1], channels[-1])
        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
        self.initialize_weights()


    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x, mask, mu, t, spks=None, cond=None):
        """Forward pass of the UNet1DConditional model.

        Args:
            x (torch.Tensor): shape (batch_size, in_channels, time)
            mask (_type_): shape (batch_size, 1, time)
            t (_type_): shape (batch_size)
            spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
            cond (_type_, optional): placeholder for future use. Defaults to None.

        Raises:
            ValueError: _description_
            ValueError: _description_

        Returns:
            _type_: _description_
        """

        t = self.time_embeddings(t).to(t.dtype)
        t = self.time_mlp(t)

        x = pack([x, mu], "b * t")[0]

        if spks is not None:
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]
        if cond is not None:
            x = pack([x, cond], "b * t")[0]

        hiddens = []
        masks = [mask]
        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            x = resnet(x, mask_down, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            hiddens.append(x)  # Save hidden states for skip connections
            x = downsample(x * mask_down)
            masks.append(mask_down[:, :, ::2])
        masks = masks[:-1]
        mask_mid = masks[-1]

        for resnet, transformer_blocks in self.mid_blocks:
            x = resnet(x, mask_mid, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()

        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            skip = hiddens.pop()
            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
            x = resnet(x, mask_up, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            x = upsample(x * mask_up)
        x = self.final_block(x, mask_up)
        output = self.final_proj(x * mask_up)
        return output * mask
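
A hedged shape sketch (sizes and hyperparameters assumed, not taken from the model config): ConditionalDecoder is the CFM estimator. Its forward pass stacks the noisy mel x, the condition mu, the repeated speaker embedding and cond along channels (here 4 x 80 = 320), so in_channels must match that packed width; the matcha-tts components are assumed importable.

import torch

B, T = 1, 200
decoder = ConditionalDecoder(in_channels=320, out_channels=80, channels=(256, 256), act_fn="gelu")
x    = torch.randn(B, 80, T)    # noisy mel at the current flow step
mu   = torch.randn(B, 80, T)    # encoder output (condition)
spks = torch.randn(B, 80)       # projected speaker embedding; repeated over time inside forward
cond = torch.zeros(B, 80, T)    # prompt-feature condition
mask = torch.ones(B, 1, T)
t    = torch.rand(B)            # flow time in [0, 1]
v = decoder(x, mask, mu, t, spks=spks, cond=cond)     # (B, 80, T) estimated velocity field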
audio_detokenizer/flow/flow.py ADDED (new file, 148 lines)

# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
from typing import Dict, Optional
import torch
import torch.nn as nn
from torch.nn import functional as F
from omegaconf import DictConfig
from ..utils.mask import make_pad_mask


class MaskedDiffWithXvec(torch.nn.Module):
    def __init__(self,
                 input_size: int = 512,
                 output_size: int = 80,
                 spk_embed_dim: int = 192,
                 output_type: str = "mel",
                 vocab_size: int = 4096,
                 input_frame_rate: int = 50,
                 only_mask_loss: bool = True,
                 encoder: torch.nn.Module = None,
                 length_regulator: torch.nn.Module = None,
                 decoder: torch.nn.Module = None,
                 decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1, 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine', 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}), 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64, 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
                 mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050, 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.decoder_conf = decoder_conf
        self.mel_feat_conf = mel_feat_conf
        self.vocab_size = vocab_size
        self.output_type = output_type
        self.input_frame_rate = input_frame_rate
        logging.info(f"input frame rate={self.input_frame_rate}")
        self.input_embedding = nn.Embedding(vocab_size, input_size)
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
        self.encoder = encoder
        self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
        self.decoder = decoder
        self.length_regulator = length_regulator
        self.only_mask_loss = only_mask_loss
        self.max_seq_long = 2048 * 2
        self.max_seq_short = 256 * 2

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        token = batch['speech_token'].to(device)
        token_len = batch['speech_token_len'].to(device)
        feat = batch['speech_feat'].to(device)
        feat_len = batch['speech_feat_len'].to(device)
        embedding = batch['embedding'].to(device)

        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # concat text and prompt_text
        mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # text encode
        h, h_lengths = self.encoder(token, token_len)
        h = self.encoder_proj(h)
        h, h_lengths = self.length_regulator(h, feat_len)

        # get conditions
        conds = torch.zeros(feat.shape, device=token.device)
        for i, j in enumerate(feat_len):
            if random.random() < 0.5:
                continue
            index = random.randint(0, int(0.3 * j))
            conds[i, :index] = feat[i, :index]
        conds = conds.transpose(1, 2)

        mask = (~make_pad_mask(feat_len)).to(h)
        feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
        loss, _ = self.decoder.compute_loss(
            feat.transpose(1, 2).contiguous(),
            mask.unsqueeze(1),
            h.transpose(1, 2).contiguous(),
            embedding,
            cond=conds
        )
        return {'loss': loss}

    @torch.inference_mode()
    def inference(self,
                  token,
                  token_len,
                  prompt_token,
                  prompt_token_len,
                  prompt_feat,
                  prompt_feat_len,
                  embedding):
        assert token.shape[0] == 1
        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # concat text and prompt_text
        token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
        mask = (~make_pad_mask(token_len)).to(embedding.dtype).unsqueeze(-1).to(embedding)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # text encode
        h, h_lengths = self.encoder(token, token_len)
        h = self.encoder_proj(h)
        feat_len = (token_len / self.input_frame_rate * 22050 / 256).int()
        h, h_lengths = self.length_regulator(h, feat_len)

        fix_max_len = feat_len.max().item()

        # get conditions
        conds = torch.zeros([1, fix_max_len, self.output_size], device=token.device, dtype=embedding.dtype)
        # if prompt_feat.shape[1] != 0:
        #     for i, j in enumerate(prompt_feat_len):
        #         conds[i, :j] = prompt_feat[i]
        conds = conds.transpose(1, 2)

        mask = (~make_pad_mask(feat_len, fix_max_len)).to(h)

        feat = self.decoder.forward(
            mu=h.transpose(1, 2).contiguous(),
            mask=mask.unsqueeze(1),
            spks=embedding,
            cond=conds,
            n_timesteps=8,
            # temperature=0.7,
        )

        if prompt_feat.shape[1] != 0:
            feat = feat[:, :, prompt_feat.shape[1]:]
        return feat
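
A worked example (not from the commit) of the length mapping in MaskedDiffWithXvec.inference: speech tokens arrive at input_frame_rate = 50 Hz and mel frames use a 256-sample hop at 22050 Hz, so each token covers 22050 / (50 * 256), roughly 1.72 mel frames.

token_len = 200                                  # discrete speech tokens (4.0 s at 50 Hz)
feat_len = int(token_len / 50 * 22050 / 256)     # 344 mel frames, matching feat_len in inference()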
audio_detokenizer/flow/flow_matching.py ADDED (new file, 242 lines)

# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# antflake8: noqa
import os
import torch

try:
    import tensorrt as trt
except ImportError:
    import warnings
    warnings.warn("Failed to import TensorRT. Make sure TensorRT is installed and available in your environment.", ImportWarning)

import torch.nn.functional as F
from matcha.models.components.flow_matching import BASECFM

class ConditionalCFM(BASECFM):
    def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, tensorrt_model_path="estimator_fp16.plan", estimator: torch.nn.Module = None):
        super().__init__(
            n_feats=in_channels,
            cfm_params=cfm_params,
            n_spks=n_spks,
            spk_emb_dim=spk_emb_dim,
        )
        self.t_scheduler = cfm_params.t_scheduler
        self.training_cfg_rate = cfm_params.training_cfg_rate
        self.inference_cfg_rate = cfm_params.inference_cfg_rate
        in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
        # Just change the architecture of the estimator here
        self.estimator = estimator
        self.compiled_estimator = None

        self.export_onnx = False
        self.use_tensorrt = False

        if os.path.isfile(tensorrt_model_path):
            trt.init_libnvinfer_plugins(None, "")
            logger = trt.Logger(trt.Logger.WARNING)
            runtime = trt.Runtime(logger)
            with open(tensorrt_model_path, 'rb') as f:
                serialized_engine = f.read()
            self.engine = runtime.deserialize_cuda_engine(serialized_engine)
            self._context = self.engine.create_execution_context()
            self.use_tensorrt = True

    @torch.inference_mode()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
        """Forward diffusion

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
        """
        z = torch.randn_like(mu) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
        if self.t_scheduler == 'cosine':
            t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)

    def estimator_infer(self, x, mask, mu, t, spks, cond):
        if self.use_tensorrt:
            # print("Using tensorrt now !!!!")
            bs = x.shape[0]
            hs = x.shape[1]
            seq_len = x.shape[2]

            assert bs == 1 and hs == 80

            ret = torch.empty_like(x)
            self._context.set_input_shape("x", x.shape)
            self._context.set_input_shape("mask", mask.shape)
            self._context.set_input_shape("mu", mu.shape)
            self._context.set_input_shape("t", t.shape)
            self._context.set_input_shape("spks", spks.shape)
            self._context.set_input_shape("cond", cond.shape)

            bindings = [x.data_ptr(), mask.data_ptr(), mu.data_ptr(), t.data_ptr(), spks.data_ptr(), cond.data_ptr(), ret.data_ptr()]

            for i in range(len(bindings)):
                self._context.set_tensor_address(self.engine.get_tensor_name(i), bindings[i])

            handle = torch.cuda.current_stream().cuda_stream
            self._context.execute_async_v3(stream_handle=handle)
            return ret
        else:
            return self.estimator.forward(x, mask, mu, t, spks, cond)

    def solve_euler(self, x, t_span, mu, mask, spks, cond):
        """
        Fixed euler solver for ODEs.
        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes
        """
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
        t = t.unsqueeze(dim=0)

        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
        # Or in future might add like a return_all_steps flag
        sol = []

        # self.export_onnx= True
        # if self.export_onnx == True:
        #     dummy_input = (x, mask, mu, t, spks, cond)
        #     torch.onnx.export(
        #         self.estimator,
        #         dummy_input,
        #         "estimator_bf16.onnx",
        #         export_params=True,
        #         opset_version=18,
        #         do_constant_folding=True,
        #         input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
        #         output_names=['output'],
        #         dynamic_axes={
        #             'x': {2: 'seq_len'},
        #             'mask': {2: 'seq_len'},
        #             'mu': {2: 'seq_len'},
        #             'cond': {2: 'seq_len'},
        #             'output': {2: 'seq_len'},
        #         }
        #     )
        #     onnx_file_path = "estimator_bf16.onnx"
        #     tensorrt_path = "/root/TensorRT-10.2.0.19"
        #     if not tensorrt_path:
        #         raise EnvironmentError("Please set the 'tensorrt_root_dir' environment variable.")

        #     if not os.path.isdir(tensorrt_path):
        #         raise FileNotFoundError(f"The directory {tensorrt_path} does not exist.")

        #     trt_lib_path = os.path.join(tensorrt_path, "lib")
        #     if trt_lib_path not in os.environ.get('LD_LIBRARY_PATH', ''):
        #         print(f"Adding TensorRT lib path {trt_lib_path} to LD_LIBRARY_PATH.")
        #         os.environ['LD_LIBRARY_PATH'] = f"{os.environ.get('LD_LIBRARY_PATH', '')}:{trt_lib_path}"

        #     trt_file_name = 'estimator_bf16.plan'
        #     flow_model_dir ='.'
        #     # trt_file_path = os.path.join(flow_model_dir, trt_file_name)

        #     trtexec_bin = os.path.join(tensorrt_path, 'bin/trtexec')
        #     trtexec_cmd = f"{trtexec_bin} --onnx={onnx_file_path} --saveEngine={trt_file_name} " \
        #         "--minShapes=x:1x80x1,mask:1x1x1,mu:1x80x1,t:1,spks:1x80,cond:1x80x1 " \
        #         "--maxShapes=x:1x80x4096,mask:1x1x4096,mu:1x80x4096,t:1,spks:1x80,cond:1x80x4096 " + \
        #         "--fp16"

        #     print("execute tensorrt", trtexec_cmd)
        #     os.system(trtexec_cmd)
        #     # """
        #     # ${TensorRT-10.2.0.19}/bin/trtexec --onnx=estimator_fp16.onnx --saveEngine=estimator_fp16.plan \
        #     #     --minShapes=x:1x80x1,mask:1x1x1,mu:1x80x1,t:1,spks:1x80,cond:1x80x1 \
        #     #     --maxShapes=x:1x80x4096,mask:1x1x4096,mu:1x80x4096,t:1,spks:1x80,cond:1x80x4096 \
        #     #     --fp16 --verbose
        #     # """


        for step in range(1, len(t_span)):
            dphi_dt = self.estimator_infer(x, mask, mu, t, spks, cond).clone()
            # Classifier-Free Guidance inference introduced in VoiceBox
            if self.inference_cfg_rate > 0:
                cfg_dphi_dt = self.estimator_infer(
                    x, mask,
                    torch.zeros_like(mu), t,
                    torch.zeros_like(spks) if spks is not None else None,
                    torch.zeros_like(cond)
                ).clone()
                dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
            x = x + dt * dphi_dt
            t = t + dt
            sol.append(x)
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t

        return sol[-1]

    def compute_loss(self, x1, mask, mu, spks=None, cond=None):
        """Computes diffusion loss

        Args:
            x1 (torch.Tensor): Target
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): target mask
                shape: (batch_size, 1, mel_timesteps)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)

        Returns:
            loss: conditional flow matching loss
            y: conditional flow
                shape: (batch_size, n_feats, mel_timesteps)
        """
        b, _, t = mu.shape

        # random timestep
        t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
        if self.t_scheduler == 'cosine':
            t = 1 - torch.cos(t * 0.5 * torch.pi)
        # sample noise p(x_0)
        z = torch.randn_like(x1)

        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z

        # during training, we randomly drop condition to trade off mode coverage and sample fidelity
        if self.training_cfg_rate > 0:
            cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
            mu = mu * cfg_mask.view(-1, 1, 1)
            spks = spks * cfg_mask.view(-1, 1)
            cond = cond * cfg_mask.view(-1, 1, 1)

        pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
        loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
        return loss, y
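
For reference, a restatement in equations of what compute_loss and solve_euler implement (not additional repository code), with sigma = sigma_min = 1e-6, alpha = inference_cfg_rate, mask m, and estimator v_theta:

% Training: conditional flow matching interpolant and target velocity, z ~ N(0, I)
% y_t = (1 - (1 - \sigma) t)\, z + t\, x_1, \qquad u_t = x_1 - (1 - \sigma)\, z
% Masked MSE loss:
% \mathcal{L} = \frac{\lVert (v_\theta(y_t, t; \mu, s, c) - u_t) \odot m \rVert_2^2}{(\sum m)\, n_\mathrm{feats}}
% Inference: cosine time grid t_i = 1 - \cos(\tfrac{\pi}{2}\, i / N) and Euler steps with
% classifier-free guidance:
% \tilde v = (1 + \alpha)\, v_\theta(x, t; \mu, s, c) - \alpha\, v_\theta(x, t; 0, 0, 0), \qquad
% x \leftarrow x + \tilde v\, \Delta t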
audio_detokenizer/flow/length_regulator.py ADDED (new file, 49 lines)

# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch.nn as nn
from torch.nn import functional as F
from ..utils.mask import make_pad_mask


class InterpolateRegulator(nn.Module):
    def __init__(
            self,
            channels: int,
            sampling_ratios: Tuple,
            out_channels: int = None,
            groups: int = 1,
    ):
        super().__init__()
        self.sampling_ratios = sampling_ratios
        out_channels = out_channels or channels
        model = nn.ModuleList([])
        if len(sampling_ratios) > 0:
            for _ in sampling_ratios:
                module = nn.Conv1d(channels, channels, 3, 1, 1)
                norm = nn.GroupNorm(groups, channels)
                act = nn.Mish()
                model.extend([module, norm, act])
        model.append(
            nn.Conv1d(channels, out_channels, 1, 1)
        )
        self.model = nn.Sequential(*model)

    def forward(self, x, ylens=None):
        # x in (B, T, D)
        mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
        x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
        out = self.model(x).transpose(1, 2).contiguous()
        olens = ylens
        return out * mask, olens
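
A hedged usage sketch (channels and sampling_ratios are assumptions, not values from the model config): InterpolateRegulator stretches encoder states from token length to the target mel length with nearest-neighbor interpolation along time, then smooths them with Conv1d/GroupNorm/Mish blocks.

import torch

reg = InterpolateRegulator(channels=80, sampling_ratios=[1, 1, 1, 1])
h = torch.randn(1, 200, 80)        # (B, T_tokens, D) encoder output after encoder_proj
ylens = torch.tensor([344])        # target mel-frame lengths, as computed in flow.py
out, olens = reg(h, ylens)         # out: (1, 344, 80), aligned to the mel timeline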
audio_detokenizer/hifigan/__init__.py ADDED: empty file
audio_detokenizer/hifigan/__pycache__/__init__.cpython-38.pyc ADDED: binary file (186 Bytes)
audio_detokenizer/hifigan/__pycache__/f0_predictor.cpython-38.pyc ADDED: binary file (1.37 kB)
audio_detokenizer/hifigan/__pycache__/generator.cpython-38.pyc ADDED: binary file (11.3 kB)
audio_detokenizer/hifigan/f0_predictor.py ADDED (new file, 55 lines)

# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm


class ConvRNNF0Predictor(nn.Module):
    def __init__(self,
                 num_class: int = 1,
                 in_channels: int = 80,
                 cond_channels: int = 512
                 ):
        super().__init__()

        self.num_class = num_class
        self.condnet = nn.Sequential(
            weight_norm(
                nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
            ),
            nn.ELU(),
            weight_norm(
                nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
            ),
            nn.ELU(),
            weight_norm(
                nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
            ),
            nn.ELU(),
            weight_norm(
                nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
            ),
            nn.ELU(),
            weight_norm(
                nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
            ),
            nn.ELU(),
        )
        self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.condnet(x)
        x = x.transpose(1, 2)
        return torch.abs(self.classifier(x).squeeze(-1))
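
A hedged shape sketch (sizes assumed): ConvRNNF0Predictor maps an 80-bin mel-spectrogram to one non-negative F0 value per frame through five weight-normalized Conv1d + ELU layers and a linear head.

import torch

predictor = ConvRNNF0Predictor(num_class=1, in_channels=80, cond_channels=512)
mel = torch.randn(2, 80, 344)      # (B, n_mels, T)
f0 = predictor(mel)                # (B, T): torch.abs keeps the predicted F0 non-negative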
audio_detokenizer/hifigan/generator.py
ADDED
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""HIFI-GAN"""
|
16 |
+
# antflake8: noqa
|
17 |
+
|
18 |
+
import typing as tp
|
19 |
+
import numpy as np
|
20 |
+
from scipy.signal import get_window
|
21 |
+
import torch
|
22 |
+
import torch.nn as nn
|
23 |
+
import torch.nn.functional as F
|
24 |
+
from torch.nn import Conv1d
|
25 |
+
from torch.nn import ConvTranspose1d
|
26 |
+
from torch.nn.utils import remove_weight_norm
|
27 |
+
from torch.nn.utils import weight_norm
|
28 |
+
from torch.distributions.uniform import Uniform
|
29 |
+
|
30 |
+
from ..transformer.activation import Snake
|
31 |
+
from ..utils.common import get_padding
|
32 |
+
from ..utils.common import init_weights
|
33 |
+
|
34 |
+
|
35 |
+
"""hifigan based generator implementation.
|
36 |
+
|
37 |
+
This code is modified from https://github.com/jik876/hifi-gan
|
38 |
+
,https://github.com/kan-bayashi/ParallelWaveGAN and
|
39 |
+
https://github.com/NVIDIA/BigVGAN
|
40 |
+
|
41 |
+
"""
|
42 |
+
class ResBlock(torch.nn.Module):
|
43 |
+
"""Residual block module in HiFiGAN/BigVGAN."""
|
44 |
+
def __init__(
|
45 |
+
self,
|
46 |
+
channels: int = 512,
|
47 |
+
kernel_size: int = 3,
|
48 |
+
dilations: tp.List[int] = [1, 3, 5],
|
49 |
+
):
|
50 |
+
super(ResBlock, self).__init__()
|
51 |
+
self.convs1 = nn.ModuleList()
|
52 |
+
self.convs2 = nn.ModuleList()
|
53 |
+
|
54 |
+
for dilation in dilations:
|
55 |
+
self.convs1.append(
|
56 |
+
weight_norm(
|
57 |
+
Conv1d(
|
58 |
+
channels,
|
59 |
+
channels,
|
60 |
+
kernel_size,
|
61 |
+
1,
|
62 |
+
dilation=dilation,
|
63 |
+
padding=get_padding(kernel_size, dilation)
|
64 |
+
)
|
65 |
+
)
|
66 |
+
)
|
67 |
+
self.convs2.append(
|
68 |
+
weight_norm(
|
69 |
+
Conv1d(
|
70 |
+
channels,
|
71 |
+
channels,
|
72 |
+
kernel_size,
|
73 |
+
1,
|
74 |
+
dilation=1,
|
75 |
+
padding=get_padding(kernel_size, 1)
|
76 |
+
)
|
77 |
+
)
|
78 |
+
)
|
79 |
+
self.convs1.apply(init_weights)
|
80 |
+
self.convs2.apply(init_weights)
|
81 |
+
self.activations1 = nn.ModuleList([
|
82 |
+
Snake(channels, alpha_logscale=False)
|
83 |
+
for _ in range(len(self.convs1))
|
84 |
+
])
|
85 |
+
self.activations2 = nn.ModuleList([
|
86 |
+
Snake(channels, alpha_logscale=False)
|
87 |
+
for _ in range(len(self.convs2))
|
88 |
+
])
|
89 |
+
|
90 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
91 |
+
for idx in range(len(self.convs1)):
|
92 |
+
xt = self.activations1[idx](x)
|
93 |
+
xt = self.convs1[idx](xt)
|
94 |
+
xt = self.activations2[idx](xt)
|
95 |
+
xt = self.convs2[idx](xt)
|
96 |
+
x = xt + x
|
97 |
+
return x
|
98 |
+
|
99 |
+
def remove_weight_norm(self):
|
100 |
+
for idx in range(len(self.convs1)):
|
101 |
+
remove_weight_norm(self.convs1[idx])
|
102 |
+
remove_weight_norm(self.convs2[idx])
|
103 |
+
|
104 |
+
class SineGen(torch.nn.Module):
|
105 |
+
""" Definition of sine generator
|
106 |
+
SineGen(samp_rate, harmonic_num = 0,
|
107 |
+
sine_amp = 0.1, noise_std = 0.003,
|
108 |
+
voiced_threshold = 0,
|
109 |
+
flag_for_pulse=False)
|
110 |
+
samp_rate: sampling rate in Hz
|
111 |
+
harmonic_num: number of harmonic overtones (default 0)
|
112 |
+
sine_amp: amplitude of sine waveform (default 0.1)
|
113 |
+
noise_std: std of Gaussian noise (default 0.003)
|
114 |
+
voiced_threshold: F0 threshold for U/V classification (default 0)
|
115 |
+
flag_for_pulse: this SineGen is used inside PulseGen (default False)
|
116 |
+
Note: when flag_for_pulse is True, the first time step of a voiced
|
117 |
+
segment is always sin(np.pi) or cos(0)
|
118 |
+
"""
|
119 |
+
|
120 |
+
def __init__(self, samp_rate, harmonic_num=0,
|
121 |
+
sine_amp=0.1, noise_std=0.003,
|
122 |
+
voiced_threshold=0):
|
123 |
+
super(SineGen, self).__init__()
|
124 |
+
self.sine_amp = sine_amp
|
125 |
+
self.noise_std = noise_std
|
126 |
+
self.harmonic_num = harmonic_num
|
127 |
+
self.sampling_rate = samp_rate
|
128 |
+
self.voiced_threshold = voiced_threshold
|
129 |
+
|
130 |
+
def _f02uv(self, f0):
|
131 |
+
# generate uv signal
|
132 |
+
uv = (f0 > self.voiced_threshold).type(torch.float32)
|
133 |
+
return uv
|
134 |
+
|
135 |
+
@torch.no_grad()
|
136 |
+
def forward(self, f0):
|
137 |
+
"""
|
138 |
+
:param f0: [B, 1, sample_len], Hz
|
139 |
+
:return: [B, 1, sample_len]
|
140 |
+
"""
|
141 |
+
|
142 |
+
F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
|
143 |
+
for i in range(self.harmonic_num + 1):
|
144 |
+
F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
|
145 |
+
|
146 |
+
theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
|
147 |
+
u_dist = Uniform(low=-np.pi, high=np.pi)
|
148 |
+
phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
|
149 |
+
phase_vec[:, 0, :] = 0
|
150 |
+
|
151 |
+
# generate sine waveforms
|
152 |
+
sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
|
153 |
+
|
154 |
+
# generate uv signal
|
155 |
+
uv = self._f02uv(f0)
|
156 |
+
|
157 |
+
# noise: for unvoiced should be similar to sine_amp
|
158 |
+
# std = self.sine_amp/3 -> max value ~ self.sine_amp
|
159 |
+
# for voiced regions the std is self.noise_std
|
160 |
+
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
|
161 |
+
noise = noise_amp * torch.randn_like(sine_waves)
|
162 |
+
|
163 |
+
# first: set the unvoiced part to 0 by uv
|
164 |
+
# then: additive noise
|
165 |
+
sine_waves = sine_waves * uv + noise
|
166 |
+
return sine_waves, uv, noise
|
167 |
+
|
168 |
+
|
169 |
+
class SourceModuleHnNSF(torch.nn.Module):
|
170 |
+
""" SourceModule for hn-nsf
|
171 |
+
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
|
172 |
+
add_noise_std=0.003, voiced_threshod=0)
|
173 |
+
sampling_rate: sampling_rate in Hz
|
174 |
+
harmonic_num: number of harmonic above F0 (default: 0)
|
175 |
+
sine_amp: amplitude of sine source signal (default: 0.1)
|
176 |
+
add_noise_std: std of additive Gaussian noise (default: 0.003)
|
177 |
+
note that amplitude of noise in unvoiced is decided
|
178 |
+
by sine_amp
|
179 |
+
voiced_threshold: threshold to set U/V given F0 (default: 0)
|
180 |
+
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
181 |
+
F0_sampled (batchsize, length, 1)
|
182 |
+
Sine_source (batchsize, length, 1)
|
183 |
+
noise_source (batchsize, length, 1)
|
184 |
+
uv (batchsize, length, 1)
|
185 |
+
"""
|
186 |
+
|
187 |
+
def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
|
188 |
+
add_noise_std=0.003, voiced_threshod=0):
|
189 |
+
super(SourceModuleHnNSF, self).__init__()
|
190 |
+
|
191 |
+
self.sine_amp = sine_amp
|
192 |
+
self.noise_std = add_noise_std
|
193 |
+
|
194 |
+
# to produce sine waveforms
|
195 |
+
self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
|
196 |
+
sine_amp, add_noise_std, voiced_threshod)
|
197 |
+
|
198 |
+
# to merge source harmonics into a single excitation
|
199 |
+
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
|
200 |
+
self.l_tanh = torch.nn.Tanh()
|
201 |
+
|
202 |
+
def forward(self, x):
|
203 |
+
"""
|
204 |
+
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
205 |
+
F0_sampled (batchsize, length, 1)
|
206 |
+
Sine_source (batchsize, length, 1)
|
207 |
+
noise_source (batchsize, length, 1)
|
208 |
+
"""
|
209 |
+
# source for harmonic branch
|
210 |
+
with torch.no_grad():
|
211 |
+
sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
|
212 |
+
sine_wavs = sine_wavs.transpose(1, 2)
|
213 |
+
uv = uv.transpose(1, 2)
|
214 |
+
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
|
215 |
+
|
216 |
+
# source for noise branch, in the same shape as uv
|
217 |
+
noise = torch.randn_like(uv) * self.sine_amp / 3
|
218 |
+
return sine_merge, noise, uv
|
219 |
+
|
220 |
+
|
221 |
+
class HiFTGenerator(nn.Module):
|
222 |
+
"""
|
223 |
+
HiFTNet Generator: Neural Source Filter + ISTFTNet
|
224 |
+
https://arxiv.org/abs/2309.09493
|
225 |
+
"""
|
226 |
+
def __init__(
|
227 |
+
self,
|
228 |
+
in_channels: int = 80,
|
229 |
+
base_channels: int = 512,
|
230 |
+
nb_harmonics: int = 8,
|
231 |
+
sampling_rate: int = 22050,
|
232 |
+
nsf_alpha: float = 0.1,
|
233 |
+
nsf_sigma: float = 0.003,
|
234 |
+
nsf_voiced_threshold: float = 10,
|
235 |
+
upsample_rates: tp.List[int] = [8, 8],
|
236 |
+
upsample_kernel_sizes: tp.List[int] = [16, 16],
|
237 |
+
istft_params: tp.Dict[str, int] = {"n_fft": 16, "hop_len": 4},
|
238 |
+
resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
|
239 |
+
resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
240 |
+
source_resblock_kernel_sizes: tp.List[int] = [7, 11],
|
241 |
+
source_resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5]],
|
242 |
+
lrelu_slope: float = 0.1,
|
243 |
+
audio_limit: float = 0.99,
|
244 |
+
f0_predictor: torch.nn.Module = None,
|
245 |
+
):
|
246 |
+
super(HiFTGenerator, self).__init__()
|
247 |
+
|
248 |
+
self.out_channels = 1
|
249 |
+
self.nb_harmonics = nb_harmonics
|
250 |
+
self.sampling_rate = sampling_rate
|
251 |
+
self.istft_params = istft_params
|
252 |
+
self.lrelu_slope = lrelu_slope
|
253 |
+
self.audio_limit = audio_limit
|
254 |
+
|
255 |
+
self.num_kernels = len(resblock_kernel_sizes)
|
256 |
+
self.num_upsamples = len(upsample_rates)
|
257 |
+
self.m_source = SourceModuleHnNSF(
|
258 |
+
sampling_rate=sampling_rate,
|
259 |
+
upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
|
260 |
+
harmonic_num=nb_harmonics,
|
261 |
+
sine_amp=nsf_alpha,
|
262 |
+
add_noise_std=nsf_sigma,
|
263 |
+
voiced_threshod=nsf_voiced_threshold)
|
264 |
+
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
|
265 |
+
|
266 |
+
self.conv_pre = weight_norm(
|
267 |
+
Conv1d(in_channels, base_channels, 7, 1, padding=3)
|
268 |
+
)
|
269 |
+
|
270 |
+
# Up
|
271 |
+
self.ups = nn.ModuleList()
|
272 |
+
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
273 |
+
self.ups.append(
|
274 |
+
weight_norm(
|
275 |
+
ConvTranspose1d(
|
276 |
+
base_channels // (2**i),
|
277 |
+
base_channels // (2**(i + 1)),
|
278 |
+
k,
|
279 |
+
u,
|
280 |
+
padding=(k - u) // 2,
|
281 |
+
)
|
282 |
+
)
|
283 |
+
)
|
284 |
+
|
285 |
+
# Down
|
286 |
+
self.source_downs = nn.ModuleList()
|
287 |
+
self.source_resblocks = nn.ModuleList()
|
288 |
+
downsample_rates = [1] + upsample_rates[::-1][:-1]
|
289 |
+
downsample_cum_rates = np.cumprod(downsample_rates)
|
290 |
+
for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes,
|
291 |
+
source_resblock_dilation_sizes)):
|
292 |
+
if u == 1:
|
293 |
+
self.source_downs.append(
|
294 |
+
Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
|
295 |
+
)
|
296 |
+
else:
|
297 |
+
self.source_downs.append(
|
298 |
+
Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
|
299 |
+
)
|
300 |
+
|
301 |
+
self.source_resblocks.append(
|
302 |
+
ResBlock(base_channels // (2 ** (i + 1)), k, d)
|
303 |
+
)
|
304 |
+
|
305 |
+
self.resblocks = nn.ModuleList()
|
306 |
+
for i in range(len(self.ups)):
|
307 |
+
ch = base_channels // (2**(i + 1))
|
308 |
+
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
309 |
+
self.resblocks.append(ResBlock(ch, k, d))
|
310 |
+
|
311 |
+
self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
|
312 |
+
self.ups.apply(init_weights)
|
313 |
+
self.conv_post.apply(init_weights)
|
314 |
+
self.reflection_pad = nn.ReflectionPad1d((1, 0))
|
315 |
+
self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
|
316 |
+
self.f0_predictor = f0_predictor
|
317 |
+
|
318 |
+
def _f02source(self, f0: torch.Tensor) -> torch.Tensor:
|
319 |
+
f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
320 |
+
|
321 |
+
har_source, _, _ = self.m_source(f0)
|
322 |
+
return har_source.transpose(1, 2)
|
323 |
+
|
324 |
+
def _stft(self, x):
|
325 |
+
spec = torch.stft(
|
326 |
+
x,
|
327 |
+
self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
|
328 |
+
return_complex=True)
|
329 |
+
spec = torch.view_as_real(spec) # [B, F, TT, 2]
|
330 |
+
return spec[..., 0], spec[..., 1]
|
331 |
+
|
332 |
+
def _istft(self, magnitude, phase):
|
333 |
+
magnitude = torch.clip(magnitude, max=1e2)
|
334 |
+
real = magnitude * torch.cos(phase)
|
335 |
+
img = magnitude * torch.sin(phase)
|
336 |
+
inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
|
337 |
+
return inverse_transform
|
338 |
+
|
339 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
340 |
+
f0 = self.f0_predictor(x)
|
341 |
+
s = self._f02source(f0)
|
342 |
+
|
343 |
+
s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
|
344 |
+
s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
|
345 |
+
|
346 |
+
x = self.conv_pre(x)
|
347 |
+
for i in range(self.num_upsamples):
|
348 |
+
x = F.leaky_relu(x, self.lrelu_slope)
|
349 |
+
x = self.ups[i](x)
|
350 |
+
|
351 |
+
if i == self.num_upsamples - 1:
|
352 |
+
x = self.reflection_pad(x)
|
353 |
+
|
354 |
+
# fusion
|
355 |
+
si = self.source_downs[i](s_stft)
|
356 |
+
si = self.source_resblocks[i](si)
|
357 |
+
x = x + si
|
358 |
+
|
359 |
+
xs = None
|
360 |
+
for j in range(self.num_kernels):
|
361 |
+
if xs is None:
|
362 |
+
xs = self.resblocks[i * self.num_kernels + j](x)
|
363 |
+
else:
|
364 |
+
xs += self.resblocks[i * self.num_kernels + j](x)
|
365 |
+
x = xs / self.num_kernels
|
366 |
+
|
367 |
+
x = F.leaky_relu(x)
|
368 |
+
x = self.conv_post(x)
|
369 |
+
magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
|
370 |
+
phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :])  # actually, sin is redundant
|
371 |
+
|
372 |
+
x = self._istft(magnitude, phase)
|
373 |
+
x = torch.clamp(x, -self.audio_limit, self.audio_limit)
|
374 |
+
return x
|
375 |
+
|
376 |
+
def remove_weight_norm(self):
|
377 |
+
print('Removing weight norm...')
|
378 |
+
for l in self.ups:
|
379 |
+
remove_weight_norm(l)
|
380 |
+
for l in self.resblocks:
|
381 |
+
l.remove_weight_norm()
|
382 |
+
remove_weight_norm(self.conv_pre)
|
383 |
+
remove_weight_norm(self.conv_post)
|
384 |
+
self.source_module.remove_weight_norm()
|
385 |
+
for l in self.source_downs:
|
386 |
+
remove_weight_norm(l)
|
387 |
+
for l in self.source_resblocks:
|
388 |
+
l.remove_weight_norm()
|
389 |
+
|
390 |
+
@torch.inference_mode()
|
391 |
+
def inference(self, mel: torch.Tensor) -> torch.Tensor:
|
392 |
+
return self.forward(x=mel)
|
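For reference, a minimal usage sketch of the generator above, assuming the repository root is importable as `audio_detokenizer` and that the F0 predictor added earlier in this commit can be built with its defaults (class name and constructor arguments are assumptions, not a documented API; shapes are illustrative only):

import torch
from audio_detokenizer.hifigan.generator import HiFTGenerator
from audio_detokenizer.hifigan.f0_predictor import ConvRNNF0Predictor  # class name assumed

# Vocoder with the defaults shown in this diff; an F0 predictor must be supplied
# because forward() calls self.f0_predictor(x) on the input mel.
hift = HiFTGenerator(in_channels=80, sampling_rate=22050,
                     f0_predictor=ConvRNNF0Predictor())
hift.eval()

mel = torch.randn(1, 80, 200)   # (batch, n_mels, frames)
wav = hift.inference(mel)       # (batch, samples), clamped to +/- audio_limit
print(wav.shape)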
audio_detokenizer/transformer/__init__.py
ADDED
File without changes
|
audio_detokenizer/transformer/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (190 Bytes).
|
|
audio_detokenizer/transformer/__pycache__/activation.cpython-38.pyc
ADDED
Binary file (2.51 kB).
|
|
audio_detokenizer/transformer/__pycache__/attention.cpython-38.pyc
ADDED
Binary file (10.9 kB).
|
|
audio_detokenizer/transformer/__pycache__/convolution.cpython-38.pyc
ADDED
Binary file (3.1 kB).
|
|
audio_detokenizer/transformer/__pycache__/embedding.cpython-38.pyc
ADDED
Binary file (9.78 kB).
|
|
audio_detokenizer/transformer/__pycache__/encoder.cpython-38.pyc
ADDED
Binary file (19.6 kB).
|
|
audio_detokenizer/transformer/__pycache__/encoder_layer.cpython-38.pyc
ADDED
Binary file (8.64 kB).
|
|
audio_detokenizer/transformer/__pycache__/positionwise_feed_forward.cpython-38.pyc
ADDED
Binary file (3.79 kB).
|
|
audio_detokenizer/transformer/__pycache__/subsampling.cpython-38.pyc
ADDED
Binary file (10.6 kB).
|
|
audio_detokenizer/transformer/activation.py
ADDED
@@ -0,0 +1,84 @@
1 |
+
# Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe)
|
2 |
+
# 2020 Northwestern Polytechnical University (Pengcheng Guo)
|
3 |
+
# 2020 Mobvoi Inc (Binbin Zhang)
|
4 |
+
# 2024 Alibaba Inc (Xiang Lyu)
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
"""Swish() activation function for Conformer."""
|
18 |
+
|
19 |
+
import torch
|
20 |
+
from torch import nn, sin, pow
|
21 |
+
from torch.nn import Parameter
|
22 |
+
|
23 |
+
|
24 |
+
class Swish(torch.nn.Module):
|
25 |
+
"""Construct an Swish object."""
|
26 |
+
|
27 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
28 |
+
"""Return Swish activation function."""
|
29 |
+
return x * torch.sigmoid(x)
|
30 |
+
|
31 |
+
|
32 |
+
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
|
33 |
+
# LICENSE is in incl_licenses directory.
|
34 |
+
class Snake(nn.Module):
|
35 |
+
'''
|
36 |
+
Implementation of a sine-based periodic activation function
|
37 |
+
Shape:
|
38 |
+
- Input: (B, C, T)
|
39 |
+
- Output: (B, C, T), same shape as the input
|
40 |
+
Parameters:
|
41 |
+
- alpha - trainable parameter
|
42 |
+
References:
|
43 |
+
- This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
44 |
+
https://arxiv.org/abs/2006.08195
|
45 |
+
Examples:
|
46 |
+
>>> a1 = Snake(256)
|
47 |
+
>>> x = torch.randn(4, 256, 100)  # (B, C, T)
|
48 |
+
>>> x = a1(x)
|
49 |
+
'''
|
50 |
+
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
51 |
+
'''
|
52 |
+
Initialization.
|
53 |
+
INPUT:
|
54 |
+
- in_features: shape of the input
|
55 |
+
- alpha: trainable parameter
|
56 |
+
alpha is initialized to 1 by default, higher values = higher-frequency.
|
57 |
+
alpha will be trained along with the rest of your model.
|
58 |
+
'''
|
59 |
+
super(Snake, self).__init__()
|
60 |
+
self.in_features = in_features
|
61 |
+
|
62 |
+
# initialize alpha
|
63 |
+
self.alpha_logscale = alpha_logscale
|
64 |
+
if self.alpha_logscale: # log scale alphas initialized to zeros
|
65 |
+
self.alpha = Parameter(torch.zeros(in_features) * alpha)
|
66 |
+
else: # linear scale alphas initialized to ones
|
67 |
+
self.alpha = Parameter(torch.ones(in_features) * alpha)
|
68 |
+
|
69 |
+
self.alpha.requires_grad = alpha_trainable
|
70 |
+
|
71 |
+
self.no_div_by_zero = 0.000000001
|
72 |
+
|
73 |
+
def forward(self, x):
|
74 |
+
'''
|
75 |
+
Forward pass of the function.
|
76 |
+
Applies the function to the input elementwise.
|
77 |
+
Snake ∶= x + 1/a * sin^2 (xa)
|
78 |
+
'''
|
79 |
+
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
80 |
+
if self.alpha_logscale:
|
81 |
+
alpha = torch.exp(alpha)
|
82 |
+
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
83 |
+
|
84 |
+
return x
|
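A small sketch of the Snake activation applied to a (batch, channels, time) feature map, matching the shape convention stated in the class docstring; the import path follows this commit's layout and the sizes are arbitrary:

import torch
from audio_detokenizer.transformer.activation import Snake

act = Snake(in_features=256, alpha=1.0, alpha_logscale=False)
feats = torch.randn(2, 256, 120)   # (B, C, T), e.g. ResBlock features in the generator
out = act(feats)                   # same shape: x + (1/alpha) * sin^2(alpha * x)
assert out.shape == feats.shape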
audio_detokenizer/transformer/attention.py
ADDED
@@ -0,0 +1,463 @@
1 |
+
# Copyright (c) 2019 Shigeki Karita
|
2 |
+
# 2020 Mobvoi Inc (Binbin Zhang)
|
3 |
+
# 2022 Xingchen Song ([email protected])
|
4 |
+
# 2024 Alibaba Inc (Xiang Lyu)
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
"""Multi-Head Attention layer definition."""
|
18 |
+
|
19 |
+
import math
|
20 |
+
from typing import Tuple
|
21 |
+
|
22 |
+
import torch
|
23 |
+
from torch import nn
|
24 |
+
|
25 |
+
|
26 |
+
class MultiHeadedAttention(nn.Module):
|
27 |
+
"""Multi-Head Attention layer.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
n_head (int): The number of heads.
|
31 |
+
n_feat (int): The number of features.
|
32 |
+
dropout_rate (float): Dropout rate.
|
33 |
+
|
34 |
+
"""
|
35 |
+
|
36 |
+
def __init__(self,
|
37 |
+
n_head: int,
|
38 |
+
n_feat: int,
|
39 |
+
dropout_rate: float,
|
40 |
+
key_bias: bool = True):
|
41 |
+
"""Construct an MultiHeadedAttention object."""
|
42 |
+
super().__init__()
|
43 |
+
assert n_feat % n_head == 0
|
44 |
+
# We assume d_v always equals d_k
|
45 |
+
self.d_k = n_feat // n_head
|
46 |
+
self.h = n_head
|
47 |
+
self.linear_q = nn.Linear(n_feat, n_feat)
|
48 |
+
self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
|
49 |
+
self.linear_v = nn.Linear(n_feat, n_feat)
|
50 |
+
self.linear_out = nn.Linear(n_feat, n_feat)
|
51 |
+
self.dropout = nn.Dropout(p=dropout_rate)
|
52 |
+
self.dropout_rate = dropout_rate
|
53 |
+
self.kv_cache = None
|
54 |
+
|
55 |
+
def forward_qkv(
|
56 |
+
self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
|
57 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
58 |
+
"""Transform query, key and value.
|
59 |
+
|
60 |
+
Args:
|
61 |
+
query (torch.Tensor): Query tensor (#batch, time1, size).
|
62 |
+
key (torch.Tensor): Key tensor (#batch, time2, size).
|
63 |
+
value (torch.Tensor): Value tensor (#batch, time2, size).
|
64 |
+
|
65 |
+
Returns:
|
66 |
+
torch.Tensor: Transformed query tensor, size
|
67 |
+
(#batch, n_head, time1, d_k).
|
68 |
+
torch.Tensor: Transformed key tensor, size
|
69 |
+
(#batch, n_head, time2, d_k).
|
70 |
+
torch.Tensor: Transformed value tensor, size
|
71 |
+
(#batch, n_head, time2, d_k).
|
72 |
+
|
73 |
+
"""
|
74 |
+
n_batch = query.size(0)
|
75 |
+
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
|
76 |
+
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
|
77 |
+
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
|
78 |
+
|
79 |
+
return q, k, v
|
80 |
+
|
81 |
+
def forward_attention(
|
82 |
+
self,
|
83 |
+
value: torch.Tensor,
|
84 |
+
scores: torch.Tensor,
|
85 |
+
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
|
86 |
+
) -> torch.Tensor:
|
87 |
+
"""Compute attention context vector.
|
88 |
+
|
89 |
+
Args:
|
90 |
+
value (torch.Tensor): Transformed value, size
|
91 |
+
(#batch, n_head, time2, d_k).
|
92 |
+
scores (torch.Tensor): Attention score, size
|
93 |
+
(#batch, n_head, time1, time2).
|
94 |
+
mask (torch.Tensor): Mask, size (#batch, 1, time2) or
|
95 |
+
(#batch, time1, time2), (0, 0, 0) means fake mask.
|
96 |
+
|
97 |
+
Returns:
|
98 |
+
torch.Tensor: Transformed value (#batch, time1, d_model)
|
99 |
+
weighted by the attention score (#batch, time1, time2).
|
100 |
+
|
101 |
+
"""
|
102 |
+
n_batch = value.size(0)
|
103 |
+
# NOTE(xcsong): When will `if mask.size(2) > 0` be True?
|
104 |
+
# 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
|
105 |
+
# 1st chunk to ease the onnx export.]
|
106 |
+
# 2. pytorch training
|
107 |
+
if mask.size(2) > 0: # time2 > 0
|
108 |
+
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
|
109 |
+
# For last chunk, time2 might be larger than scores.size(-1)
|
110 |
+
mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
|
111 |
+
scores = scores.masked_fill(mask, -float('inf'))
|
112 |
+
attn = torch.softmax(scores, dim=-1).masked_fill(
|
113 |
+
mask, 0.0) # (batch, head, time1, time2)
|
114 |
+
# NOTE(xcsong): When will `if mask.size(2) > 0` be False?
|
115 |
+
# 1. onnx(16/-1, -1/-1, 16/0)
|
116 |
+
# 2. jit (16/-1, -1/-1, 16/0, 16/4)
|
117 |
+
else:
|
118 |
+
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
|
119 |
+
|
120 |
+
p_attn = self.dropout(attn)
|
121 |
+
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
|
122 |
+
x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
|
123 |
+
self.h * self.d_k)
|
124 |
+
) # (batch, time1, d_model)
|
125 |
+
|
126 |
+
return self.linear_out(x) # (batch, time1, d_model)
|
127 |
+
|
128 |
+
def forward(
|
129 |
+
self,
|
130 |
+
query: torch.Tensor,
|
131 |
+
key: torch.Tensor,
|
132 |
+
value: torch.Tensor,
|
133 |
+
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
134 |
+
pos_emb: torch.Tensor = torch.empty(0),
|
135 |
+
cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
|
136 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
137 |
+
"""Compute scaled dot product attention.
|
138 |
+
|
139 |
+
Args:
|
140 |
+
query (torch.Tensor): Query tensor (#batch, time1, size).
|
141 |
+
key (torch.Tensor): Key tensor (#batch, time2, size).
|
142 |
+
value (torch.Tensor): Value tensor (#batch, time2, size).
|
143 |
+
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
144 |
+
(#batch, time1, time2).
|
145 |
+
1.When applying cross attention between decoder and encoder,
|
146 |
+
the batch padding mask for input is in (#batch, 1, T) shape.
|
147 |
+
2.When applying self attention of encoder,
|
148 |
+
the mask is in (#batch, T, T) shape.
|
149 |
+
3.When applying self attention of decoder,
|
150 |
+
the mask is in (#batch, L, L) shape.
|
151 |
+
4. If different positions in the decoder see different blocks
|
152 |
+
of the encoder, such as Mocha, the passed in mask could be
|
153 |
+
in (#batch, L, T) shape.
|
154 |
+
cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
|
155 |
+
where `cache_t == chunk_size * num_decoding_left_chunks`
|
156 |
+
and `head * d_k == size`
|
157 |
+
|
158 |
+
|
159 |
+
Returns:
|
160 |
+
torch.Tensor: Output tensor (#batch, time1, d_model).
|
161 |
+
torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
|
162 |
+
where `cache_t == chunk_size * num_decoding_left_chunks`
|
163 |
+
and `head * d_k == size`
|
164 |
+
|
165 |
+
"""
|
166 |
+
q, k, v = self.forward_qkv(query, key, value)
|
167 |
+
q = q.transpose(1, 2) # (batch, head, time1, d_k)
|
168 |
+
k = k.transpose(1, 2) # (batch, head, time2, d_k)
|
169 |
+
v = v.transpose(1, 2) # (batch, head, time2, d_k)
|
170 |
+
|
171 |
+
# NOTE(xcsong):
|
172 |
+
# when export onnx model, for 1st chunk, we feed
|
173 |
+
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
|
174 |
+
# or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
|
175 |
+
# In all modes, `if cache.size(0) > 0` will always be `True`
|
176 |
+
# and we will always do splitting and
|
177 |
+
# concatenation (this will simplify onnx export). Note that
|
178 |
+
# it's OK to concat & split zero-shaped tensors(see code below).
|
179 |
+
# when export jit model, for 1st chunk, we always feed
|
180 |
+
# cache(0, 0, 0, 0) since jit supports dynamic if-branch.
|
181 |
+
# >>> a = torch.ones((1, 2, 0, 4))
|
182 |
+
# >>> b = torch.ones((1, 2, 3, 4))
|
183 |
+
# >>> c = torch.cat((a, b), dim=2)
|
184 |
+
# >>> torch.equal(b, c) # True
|
185 |
+
# >>> d = torch.split(a, 2, dim=-1)
|
186 |
+
# >>> torch.equal(d[0], d[1]) # True
|
187 |
+
if cache.size(0) > 0:
|
188 |
+
key_cache, value_cache = torch.split(cache,
|
189 |
+
cache.size(-1) // 2,
|
190 |
+
dim=-1)
|
191 |
+
k = torch.cat([key_cache, k], dim=2)
|
192 |
+
v = torch.cat([value_cache, v], dim=2)
|
193 |
+
# NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
|
194 |
+
# non-trivial to calculate `next_cache_start` here.
|
195 |
+
new_cache = torch.cat((k, v), dim=-1)
|
196 |
+
|
197 |
+
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
|
198 |
+
return self.forward_attention(v, scores, mask), new_cache
|
199 |
+
|
200 |
+
def inference(
|
201 |
+
self,
|
202 |
+
query: torch.Tensor,
|
203 |
+
key: torch.Tensor,
|
204 |
+
value: torch.Tensor,
|
205 |
+
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
206 |
+
pos_emb: torch.Tensor = torch.empty(0),
|
207 |
+
cache_offset: torch.Tensor = None,
|
208 |
+
is_infer_short: bool = False,
|
209 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
210 |
+
"""Compute scaled dot product attention.
|
211 |
+
|
212 |
+
Args:
|
213 |
+
query (torch.Tensor): Query tensor (#batch, time1, size).
|
214 |
+
key (torch.Tensor): Key tensor (#batch, time2, size).
|
215 |
+
value (torch.Tensor): Value tensor (#batch, time2, size).
|
216 |
+
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
217 |
+
(#batch, time1, time2).
|
218 |
+
1.When applying cross attention between decoder and encoder,
|
219 |
+
the batch padding mask for input is in (#batch, 1, T) shape.
|
220 |
+
2.When applying self attention of encoder,
|
221 |
+
the mask is in (#batch, T, T) shape.
|
222 |
+
3.When applying self attention of decoder,
|
223 |
+
the mask is in (#batch, L, L) shape.
|
224 |
+
4. If different positions in the decoder see different blocks
|
225 |
+
of the encoder, such as Mocha, the passed in mask could be
|
226 |
+
in (#batch, L, T) shape.
|
227 |
+
cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
|
228 |
+
where `cache_t == chunk_size * num_decoding_left_chunks`
|
229 |
+
and `head * d_k == size`
|
230 |
+
|
231 |
+
|
232 |
+
Returns:
|
233 |
+
torch.Tensor: Output tensor (#batch, time1, d_model).
|
234 |
+
torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
|
235 |
+
where `cache_t == chunk_size * num_decoding_left_chunks`
|
236 |
+
and `head * d_k == size`
|
237 |
+
|
238 |
+
"""
|
239 |
+
q, k, v = self.forward_qkv(query, key, value)
|
240 |
+
q = q.transpose(1, 2) # (batch, head, time1, d_k)
|
241 |
+
k = k.transpose(1, 2) # (batch, head, time2, d_k)
|
242 |
+
v = v.transpose(1, 2) # (batch, head, time2, d_k)
|
243 |
+
|
244 |
+
if self.kv_cache is not None:
|
245 |
+
k, v = self.kv_cache.update(cache_offset, k, v, is_infer_short)
|
246 |
+
|
247 |
+
assert mask.dtype == torch.bool
|
248 |
+
mask = mask.unsqueeze(1).eq(False) * torch.finfo(q.dtype).min
|
249 |
+
|
250 |
+
output = torch.nn.functional.scaled_dot_product_attention(
|
251 |
+
q,
|
252 |
+
k,
|
253 |
+
v,
|
254 |
+
attn_mask=mask,
|
255 |
+
dropout_p=self.dropout_rate,
|
256 |
+
scale=1 / math.sqrt(self.d_k),
|
257 |
+
)
|
258 |
+
output = (output.transpose(1, 2).contiguous().view(
|
259 |
+
query.size(0), -1,
|
260 |
+
self.h * self.d_k)) # (batch, time1, d_model)
|
261 |
+
return self.linear_out(output)
|
262 |
+
|
263 |
+
|
264 |
+
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
|
265 |
+
"""Multi-Head Attention layer with relative position encoding.
|
266 |
+
Paper: https://arxiv.org/abs/1901.02860
|
267 |
+
Args:
|
268 |
+
n_head (int): The number of heads.
|
269 |
+
n_feat (int): The number of features.
|
270 |
+
dropout_rate (float): Dropout rate.
|
271 |
+
"""
|
272 |
+
|
273 |
+
def __init__(self,
|
274 |
+
n_head: int,
|
275 |
+
n_feat: int,
|
276 |
+
dropout_rate: float,
|
277 |
+
key_bias: bool = True):
|
278 |
+
"""Construct an RelPositionMultiHeadedAttention object."""
|
279 |
+
super().__init__(n_head, n_feat, dropout_rate, key_bias)
|
280 |
+
# linear transformation for positional encoding
|
281 |
+
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
|
282 |
+
# these two learnable bias are used in matrix c and matrix d
|
283 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
284 |
+
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
285 |
+
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
286 |
+
torch.nn.init.xavier_uniform_(self.pos_bias_u)
|
287 |
+
torch.nn.init.xavier_uniform_(self.pos_bias_v)
|
288 |
+
|
289 |
+
def rel_shift(self, x):
|
290 |
+
"""Compute relative positional encoding.
|
291 |
+
|
292 |
+
Args:
|
293 |
+
x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
|
294 |
+
time1 means the length of query vector.
|
295 |
+
|
296 |
+
Returns:
|
297 |
+
torch.Tensor: Output tensor.
|
298 |
+
|
299 |
+
"""
|
300 |
+
zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
|
301 |
+
x_padded = torch.cat([zero_pad, x], dim=-1)
|
302 |
+
|
303 |
+
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
|
304 |
+
x = x_padded[:, :, 1:].view_as(x)[
|
305 |
+
:, :, :, : x.size(-1) // 2 + 1
|
306 |
+
] # only keep the positions from 0 to time2
|
307 |
+
return x
|
308 |
+
|
309 |
+
def forward(
|
310 |
+
self,
|
311 |
+
query: torch.Tensor,
|
312 |
+
key: torch.Tensor,
|
313 |
+
value: torch.Tensor,
|
314 |
+
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
315 |
+
pos_emb: torch.Tensor = torch.empty(0),
|
316 |
+
cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
|
317 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
318 |
+
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
|
319 |
+
Args:
|
320 |
+
query (torch.Tensor): Query tensor (#batch, time1, size).
|
321 |
+
key (torch.Tensor): Key tensor (#batch, time2, size).
|
322 |
+
value (torch.Tensor): Value tensor (#batch, time2, size).
|
323 |
+
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
324 |
+
(#batch, time1, time2), (0, 0, 0) means fake mask.
|
325 |
+
pos_emb (torch.Tensor): Positional embedding tensor
|
326 |
+
(#batch, time2, size).
|
327 |
+
cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
|
328 |
+
where `cache_t == chunk_size * num_decoding_left_chunks`
|
329 |
+
and `head * d_k == size`
|
330 |
+
Returns:
|
331 |
+
torch.Tensor: Output tensor (#batch, time1, d_model).
|
332 |
+
torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
|
333 |
+
where `cache_t == chunk_size * num_decoding_left_chunks`
|
334 |
+
and `head * d_k == size`
|
335 |
+
"""
|
336 |
+
q, k, v = self.forward_qkv(query, key, value)
|
337 |
+
k = k.transpose(1, 2) # (batch, head, time2, d_k)
|
338 |
+
v = v.transpose(1, 2) # (batch, head, time2, d_k)
|
339 |
+
|
340 |
+
# NOTE(xcsong):
|
341 |
+
# when export onnx model, for 1st chunk, we feed
|
342 |
+
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
|
343 |
+
# or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
|
344 |
+
# In all modes, `if cache.size(0) > 0` will always be `True`
|
345 |
+
# and we will always do splitting and
|
346 |
+
# concatenation (this will simplify onnx export). Note that
|
347 |
+
# it's OK to concat & split zero-shaped tensors(see code below).
|
348 |
+
# when export jit model, for 1st chunk, we always feed
|
349 |
+
# cache(0, 0, 0, 0) since jit supports dynamic if-branch.
|
350 |
+
# >>> a = torch.ones((1, 2, 0, 4))
|
351 |
+
# >>> b = torch.ones((1, 2, 3, 4))
|
352 |
+
# >>> c = torch.cat((a, b), dim=2)
|
353 |
+
# >>> torch.equal(b, c) # True
|
354 |
+
# >>> d = torch.split(a, 2, dim=-1)
|
355 |
+
# >>> torch.equal(d[0], d[1]) # True
|
356 |
+
if cache.size(0) > 0:
|
357 |
+
key_cache, value_cache = torch.split(cache,
|
358 |
+
cache.size(-1) // 2,
|
359 |
+
dim=-1)
|
360 |
+
k = torch.cat([key_cache, k], dim=2)
|
361 |
+
v = torch.cat([value_cache, v], dim=2)
|
362 |
+
# NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
|
363 |
+
# non-trivial to calculate `next_cache_start` here.
|
364 |
+
new_cache = torch.cat((k, v), dim=-1)
|
365 |
+
|
366 |
+
n_batch_pos = pos_emb.size(0)
|
367 |
+
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
|
368 |
+
p = p.transpose(1, 2) # (batch, head, time1, d_k)
|
369 |
+
|
370 |
+
# (batch, head, time1, d_k)
|
371 |
+
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
|
372 |
+
# (batch, head, time1, d_k)
|
373 |
+
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
|
374 |
+
|
375 |
+
# compute attention score
|
376 |
+
# first compute matrix a and matrix c
|
377 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
378 |
+
# (batch, head, time1, time2)
|
379 |
+
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
|
380 |
+
|
381 |
+
# compute matrix b and matrix d
|
382 |
+
# (batch, head, time1, time2)
|
383 |
+
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
|
384 |
+
# NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
|
385 |
+
if matrix_ac.shape != matrix_bd.shape:
|
386 |
+
matrix_bd = self.rel_shift(matrix_bd)
|
387 |
+
|
388 |
+
scores = (matrix_ac + matrix_bd) / math.sqrt(
|
389 |
+
self.d_k) # (batch, head, time1, time2)
|
390 |
+
|
391 |
+
return self.forward_attention(v, scores, mask), new_cache
|
392 |
+
|
393 |
+
def inference(
|
394 |
+
self,
|
395 |
+
query: torch.Tensor,
|
396 |
+
key: torch.Tensor,
|
397 |
+
value: torch.Tensor,
|
398 |
+
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
399 |
+
pos_emb: torch.Tensor = torch.empty(0),
|
400 |
+
cache_offset: torch.Tensor = None,
|
401 |
+
is_infer_short: bool = False,
|
402 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
403 |
+
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
|
404 |
+
Args:
|
405 |
+
query (torch.Tensor): Query tensor (#batch, time1, size).
|
406 |
+
key (torch.Tensor): Key tensor (#batch, time2, size).
|
407 |
+
value (torch.Tensor): Value tensor (#batch, time2, size).
|
408 |
+
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
409 |
+
(#batch, time1, time2), (0, 0, 0) means fake mask.
|
410 |
+
pos_emb (torch.Tensor): Positional embedding tensor
|
411 |
+
(#batch, time2, size).
|
412 |
+
cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
|
413 |
+
where `cache_t == chunk_size * num_decoding_left_chunks`
|
414 |
+
and `head * d_k == size`
|
415 |
+
Returns:
|
416 |
+
torch.Tensor: Output tensor (#batch, time1, d_model).
|
417 |
+
torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
|
418 |
+
where `cache_t == chunk_size * num_decoding_left_chunks`
|
419 |
+
and `head * d_k == size`
|
420 |
+
"""
|
421 |
+
q, k, v = self.forward_qkv(query, key, value)
|
422 |
+
k = k.transpose(1, 2) # (batch, head, time2, d_k)
|
423 |
+
v = v.transpose(1, 2) # (batch, head, time2, d_k)
|
424 |
+
|
425 |
+
if self.kv_cache is not None:
|
426 |
+
k, v = self.kv_cache.update(cache_offset, k, v, is_infer_short)
|
427 |
+
|
428 |
+
n_batch_pos = pos_emb.size(0)
|
429 |
+
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
|
430 |
+
p = p.transpose(1, 2) # (batch, head, time1, d_k)
|
431 |
+
|
432 |
+
# (batch, head, time1, d_k)
|
433 |
+
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
|
434 |
+
# (batch, head, time1, d_k)
|
435 |
+
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
|
436 |
+
|
437 |
+
# compute matrix b and matrix d
|
438 |
+
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
|
439 |
+
# NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
|
440 |
+
matrix_bd = self.rel_shift(matrix_bd)
|
441 |
+
|
442 |
+
assert mask.dtype == torch.bool
|
443 |
+
# mask = (mask.unsqueeze(1).eq(False) * torch.finfo(k.dtype).min).to(matrix_bd.dtype)
|
444 |
+
mask = mask.unsqueeze(1).eq(False)
|
445 |
+
mask = (matrix_bd / math.sqrt(self.d_k)).masked_fill(mask, torch.tensor(-float('inf'), dtype=matrix_bd.dtype))
|
446 |
+
# import pdb; pdb.set_trace()
|
447 |
+
# print("q_with_bias_u.shape", q_with_bias_u.shape)
|
448 |
+
# print("k.shape", k.shape)
|
449 |
+
# print("v.shape", v.shape)
|
450 |
+
# print("mask.shape", mask.shape)
|
451 |
+
# import pdb; pdb.set_trace()
|
452 |
+
output = torch.nn.functional.scaled_dot_product_attention(
|
453 |
+
q_with_bias_u,
|
454 |
+
k,
|
455 |
+
v,
|
456 |
+
attn_mask=mask,
|
457 |
+
dropout_p=self.dropout_rate,
|
458 |
+
scale=1 / math.sqrt(self.d_k),
|
459 |
+
)
|
460 |
+
|
461 |
+
output = (output.transpose(1, 2).contiguous().view(
|
462 |
+
query.size(0), -1, self.h * self.d_k)) # (batch, time1, d_model)
|
463 |
+
return self.linear_out(output)
|
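For orientation, a minimal self-attention call against the layer defined above; the import path assumes the repository root is on PYTHONPATH, and the mask/cache shapes follow the docstrings in this file:

import torch
from audio_detokenizer.transformer.attention import MultiHeadedAttention

mha = MultiHeadedAttention(n_head=8, n_feat=512, dropout_rate=0.1)
x = torch.randn(2, 50, 512)                      # (batch, time, size)
mask = torch.ones(2, 1, 50, dtype=torch.bool)    # (batch, 1, time2) padding mask
out, new_cache = mha(x, x, x, mask=mask)         # self-attention, empty cache by default
# new_cache concatenates K and V on the last dim: (batch, head, time2, d_k * 2)
print(out.shape, new_cache.shape)                # torch.Size([2, 50, 512]) torch.Size([2, 8, 50, 128])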
audio_detokenizer/transformer/convolution.py
ADDED
@@ -0,0 +1,145 @@
1 |
+
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
|
2 |
+
# 2024 Alibaba Inc (Xiang Lyu)
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# Modified from ESPnet(https://github.com/espnet/espnet)
|
16 |
+
"""ConvolutionModule definition."""
|
17 |
+
|
18 |
+
from typing import Tuple
|
19 |
+
|
20 |
+
import torch
|
21 |
+
from torch import nn
|
22 |
+
|
23 |
+
|
24 |
+
class ConvolutionModule(nn.Module):
|
25 |
+
"""ConvolutionModule in Conformer model."""
|
26 |
+
|
27 |
+
def __init__(self,
|
28 |
+
channels: int,
|
29 |
+
kernel_size: int = 15,
|
30 |
+
activation: nn.Module = nn.ReLU(),
|
31 |
+
norm: str = "batch_norm",
|
32 |
+
causal: bool = False,
|
33 |
+
bias: bool = True):
|
34 |
+
"""Construct an ConvolutionModule object.
|
35 |
+
Args:
|
36 |
+
channels (int): The number of channels of conv layers.
|
37 |
+
kernel_size (int): Kernel size of conv layers.
|
38 |
+
causal (bool): Whether to use causal convolution or not
|
39 |
+
"""
|
40 |
+
super().__init__()
|
41 |
+
|
42 |
+
self.pointwise_conv1 = nn.Conv1d(
|
43 |
+
channels,
|
44 |
+
2 * channels,
|
45 |
+
kernel_size=1,
|
46 |
+
stride=1,
|
47 |
+
padding=0,
|
48 |
+
bias=bias,
|
49 |
+
)
|
50 |
+
# self.lorder is used to distinguish if it's a causal convolution,
|
51 |
+
# if self.lorder > 0: it's a causal convolution, the input will be
|
52 |
+
# padded with self.lorder frames on the left in forward.
|
53 |
+
# else: it's a symmetrical convolution
|
54 |
+
if causal:
|
55 |
+
padding = 0
|
56 |
+
self.lorder = kernel_size - 1
|
57 |
+
else:
|
58 |
+
# kernel_size should be an odd number for non-causal convolution
|
59 |
+
assert (kernel_size - 1) % 2 == 0
|
60 |
+
padding = (kernel_size - 1) // 2
|
61 |
+
self.lorder = 0
|
62 |
+
self.depthwise_conv = nn.Conv1d(
|
63 |
+
channels,
|
64 |
+
channels,
|
65 |
+
kernel_size,
|
66 |
+
stride=1,
|
67 |
+
padding=padding,
|
68 |
+
groups=channels,
|
69 |
+
bias=bias,
|
70 |
+
)
|
71 |
+
|
72 |
+
assert norm in ['batch_norm', 'layer_norm']
|
73 |
+
if norm == "batch_norm":
|
74 |
+
self.use_layer_norm = False
|
75 |
+
self.norm = nn.BatchNorm1d(channels)
|
76 |
+
else:
|
77 |
+
self.use_layer_norm = True
|
78 |
+
self.norm = nn.LayerNorm(channels)
|
79 |
+
|
80 |
+
self.pointwise_conv2 = nn.Conv1d(
|
81 |
+
channels,
|
82 |
+
channels,
|
83 |
+
kernel_size=1,
|
84 |
+
stride=1,
|
85 |
+
padding=0,
|
86 |
+
bias=bias,
|
87 |
+
)
|
88 |
+
self.activation = activation
|
89 |
+
|
90 |
+
def forward(
|
91 |
+
self,
|
92 |
+
x: torch.Tensor,
|
93 |
+
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
94 |
+
cache: torch.Tensor = torch.zeros((0, 0, 0)),
|
95 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
96 |
+
"""Compute convolution module.
|
97 |
+
Args:
|
98 |
+
x (torch.Tensor): Input tensor (#batch, time, channels).
|
99 |
+
mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
|
100 |
+
(0, 0, 0) means fake mask.
|
101 |
+
cache (torch.Tensor): left context cache, it is only
|
102 |
+
used in causal convolution (#batch, channels, cache_t),
|
103 |
+
(0, 0, 0) means fake cache.
|
104 |
+
Returns:
|
105 |
+
torch.Tensor: Output tensor (#batch, time, channels).
|
106 |
+
"""
|
107 |
+
# exchange the temporal dimension and the feature dimension
|
108 |
+
x = x.transpose(1, 2) # (#batch, channels, time)
|
109 |
+
|
110 |
+
# mask batch padding
|
111 |
+
if mask_pad.size(2) > 0: # time > 0
|
112 |
+
x.masked_fill_(~mask_pad, 0.0)
|
113 |
+
|
114 |
+
if self.lorder > 0:
|
115 |
+
if cache.size(2) == 0: # cache_t == 0
|
116 |
+
x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
|
117 |
+
else:
|
118 |
+
assert cache.size(0) == x.size(0) # equal batch
|
119 |
+
assert cache.size(1) == x.size(1) # equal channel
|
120 |
+
x = torch.cat((cache, x), dim=2)
|
121 |
+
assert (x.size(2) > self.lorder)
|
122 |
+
new_cache = x[:, :, -self.lorder:]
|
123 |
+
else:
|
124 |
+
# It's better we just return None if no cache is required,
|
125 |
+
# However, for JIT export, here we just fake one tensor instead of
|
126 |
+
# None.
|
127 |
+
new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
|
128 |
+
|
129 |
+
# GLU mechanism
|
130 |
+
x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
|
131 |
+
x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
|
132 |
+
|
133 |
+
# 1D Depthwise Conv
|
134 |
+
x = self.depthwise_conv(x)
|
135 |
+
if self.use_layer_norm:
|
136 |
+
x = x.transpose(1, 2)
|
137 |
+
x = self.activation(self.norm(x))
|
138 |
+
if self.use_layer_norm:
|
139 |
+
x = x.transpose(1, 2)
|
140 |
+
x = self.pointwise_conv2(x)
|
141 |
+
# mask batch padding
|
142 |
+
if mask_pad.size(2) > 0: # time > 0
|
143 |
+
x.masked_fill_(~mask_pad, 0.0)
|
144 |
+
|
145 |
+
return x.transpose(1, 2), new_cache
|
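A short sketch of the causal path of the convolution module above, showing how the returned cache carries the left context for the next streaming chunk; the import path is assumed and the sizes are illustrative:

import torch
from audio_detokenizer.transformer.convolution import ConvolutionModule

conv = ConvolutionModule(channels=256, kernel_size=15, causal=True)
x = torch.randn(2, 100, 256)                        # (batch, time, channels)
mask_pad = torch.ones(2, 1, 100, dtype=torch.bool)  # all frames valid
y, cache = conv(x, mask_pad=mask_pad)
# With causal=True the last (kernel_size - 1) frames come back as the
# left-context cache to prepend on the next chunk.
print(y.shape, cache.shape)                         # torch.Size([2, 100, 256]) torch.Size([2, 256, 14])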