LandyGuo committed
Commit e92022a · 1 Parent(s): 736eafa

First model version

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +7 -0
  2. __pycache__/bailingmm_utils.cpython-38.pyc +0 -0
  3. __pycache__/chat_format.cpython-38.pyc +0 -0
  4. __pycache__/configuration_audio.cpython-38.pyc +0 -0
  5. __pycache__/configuration_bailing_moe.cpython-38.pyc +0 -0
  6. __pycache__/configuration_bailing_talker.cpython-38.pyc +0 -0
  7. __pycache__/configuration_bailingmm.cpython-38.pyc +0 -0
  8. __pycache__/image_processing_bailingmm.cpython-38.pyc +0 -0
  9. __pycache__/modeling_bailing_moe.cpython-38.pyc +0 -0
  10. __pycache__/modeling_bailing_talker.cpython-38.pyc +0 -0
  11. __pycache__/modeling_bailingmm.cpython-38.pyc +0 -0
  12. __pycache__/modeling_utils.cpython-38.pyc +0 -0
  13. __pycache__/qwen2_5_vit.cpython-38.pyc +0 -0
  14. __pycache__/s3bpe_tokenizer.cpython-38.pyc +0 -0
  15. am.mvn +8 -0
  16. audio_detokenizer/__init__.py +0 -0
  17. audio_detokenizer/__pycache__/__init__.cpython-38.pyc +0 -0
  18. audio_detokenizer/cli/__init__.py +0 -0
  19. audio_detokenizer/cli/__pycache__/__init__.cpython-38.pyc +0 -0
  20. audio_detokenizer/cli/__pycache__/model.cpython-38.pyc +0 -0
  21. audio_detokenizer/cli/model.py +62 -0
  22. audio_detokenizer/flow/__init__.py +0 -0
  23. audio_detokenizer/flow/__pycache__/__init__.cpython-38.pyc +0 -0
  24. audio_detokenizer/flow/__pycache__/decoder.cpython-38.pyc +0 -0
  25. audio_detokenizer/flow/__pycache__/flow.cpython-38.pyc +0 -0
  26. audio_detokenizer/flow/__pycache__/flow_matching.cpython-38.pyc +0 -0
  27. audio_detokenizer/flow/__pycache__/length_regulator.cpython-38.pyc +0 -0
  28. audio_detokenizer/flow/decoder.py +224 -0
  29. audio_detokenizer/flow/flow.py +148 -0
  30. audio_detokenizer/flow/flow_matching.py +242 -0
  31. audio_detokenizer/flow/length_regulator.py +49 -0
  32. audio_detokenizer/hifigan/__init__.py +0 -0
  33. audio_detokenizer/hifigan/__pycache__/__init__.cpython-38.pyc +0 -0
  34. audio_detokenizer/hifigan/__pycache__/f0_predictor.cpython-38.pyc +0 -0
  35. audio_detokenizer/hifigan/__pycache__/generator.cpython-38.pyc +0 -0
  36. audio_detokenizer/hifigan/f0_predictor.py +55 -0
  37. audio_detokenizer/hifigan/generator.py +392 -0
  38. audio_detokenizer/transformer/__init__.py +0 -0
  39. audio_detokenizer/transformer/__pycache__/__init__.cpython-38.pyc +0 -0
  40. audio_detokenizer/transformer/__pycache__/activation.cpython-38.pyc +0 -0
  41. audio_detokenizer/transformer/__pycache__/attention.cpython-38.pyc +0 -0
  42. audio_detokenizer/transformer/__pycache__/convolution.cpython-38.pyc +0 -0
  43. audio_detokenizer/transformer/__pycache__/embedding.cpython-38.pyc +0 -0
  44. audio_detokenizer/transformer/__pycache__/encoder.cpython-38.pyc +0 -0
  45. audio_detokenizer/transformer/__pycache__/encoder_layer.cpython-38.pyc +0 -0
  46. audio_detokenizer/transformer/__pycache__/positionwise_feed_forward.cpython-38.pyc +0 -0
  47. audio_detokenizer/transformer/__pycache__/subsampling.cpython-38.pyc +0 -0
  48. audio_detokenizer/transformer/activation.py +84 -0
  49. audio_detokenizer/transformer/attention.py +463 -0
  50. audio_detokenizer/transformer/convolution.py +145 -0
.gitattributes CHANGED
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ data/matcha_tts-0.0.5.1-cp38-cp38-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text
37
+ data/wavs/BAC009S0915W0292.wav filter=lfs diff=lfs merge=lfs -text
38
+ out.wav filter=lfs diff=lfs merge=lfs -text
39
+ talker/tokenizer.json filter=lfs diff=lfs merge=lfs -text
40
+
41
+
42
+
__pycache__/bailingmm_utils.cpython-38.pyc ADDED
Binary file (13.5 kB).
 
__pycache__/chat_format.cpython-38.pyc ADDED
Binary file (21.9 kB).
 
__pycache__/configuration_audio.cpython-38.pyc ADDED
Binary file (1.03 kB).
 
__pycache__/configuration_bailing_moe.cpython-38.pyc ADDED
Binary file (1.99 kB).
 
__pycache__/configuration_bailing_talker.cpython-38.pyc ADDED
Binary file (1.14 kB).
 
__pycache__/configuration_bailingmm.cpython-38.pyc ADDED
Binary file (1.2 kB).
 
__pycache__/image_processing_bailingmm.cpython-38.pyc ADDED
Binary file (16.9 kB).
 
__pycache__/modeling_bailing_moe.cpython-38.pyc ADDED
Binary file (48.1 kB).
 
__pycache__/modeling_bailing_talker.cpython-38.pyc ADDED
Binary file (8.19 kB).
 
__pycache__/modeling_bailingmm.cpython-38.pyc ADDED
Binary file (11.6 kB).
 
__pycache__/modeling_utils.cpython-38.pyc ADDED
Binary file (26.3 kB).
 
__pycache__/qwen2_5_vit.cpython-38.pyc ADDED
Binary file (16.3 kB).
 
__pycache__/s3bpe_tokenizer.cpython-38.pyc ADDED
Binary file (2.24 kB).
 
am.mvn ADDED
@@ -0,0 +1,8 @@
1
+ <Nnet>
2
+ <Splice> 560 560
3
+ [ 0 ]
4
+ <AddShift> 560 560
5
+ <LearnRateCoef> 0 [ -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 
-13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 ]
6
+ <Rescale> 560 560
7
+ <LearnRateCoef> 0 [ 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 
0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 ]
8
+ </Nnet>
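Note for readers: am.mvn is a Kaldi-style mean/variance normalization net (Splice, AddShift, Rescale over 560 dimensions, which looks like 7 stacked 80-dim filterbank frames). The sketch below is not part of the commit; it only illustrates how such a file is typically parsed and applied, assuming the last two bracketed rows hold the shift (negative means) and scale (inverse standard deviations).

import re
import numpy as np

def load_cmvn(path: str):
    """Parse the <AddShift> and <Rescale> rows of a Kaldi-style am.mvn file."""
    with open(path) as f:
        text = f.read()
    blocks = re.findall(r"\[([^\]]*)\]", text)                  # every bracketed number block
    vectors = [np.array(b.split(), dtype=np.float32) for b in blocks]
    shift, scale = vectors[-2], vectors[-1]                     # AddShift row, then Rescale row
    return shift, scale

def apply_cmvn(feats: np.ndarray, shift: np.ndarray, scale: np.ndarray) -> np.ndarray:
    """feats: (num_frames, 560) stacked filterbank features; returns normalized features."""
    return (feats + shift) * scale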
audio_detokenizer/__init__.py ADDED
File without changes
audio_detokenizer/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (178 Bytes).
 
audio_detokenizer/cli/__init__.py ADDED
File without changes
audio_detokenizer/cli/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (182 Bytes).
 
audio_detokenizer/cli/__pycache__/model.cpython-38.pyc ADDED
Binary file (2.02 kB).
 
audio_detokenizer/cli/model.py ADDED
@@ -0,0 +1,62 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import torch
16
+ import time
17
+
18
+ class AudioDetokenizerModel:
19
+
20
+ def __init__(self,
21
+ flow: torch.nn.Module,
22
+ hift: torch.nn.Module,
23
+ lora_config=None):
24
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
25
+ self.flow = flow
26
+ self.hift = hift
27
+ self.dtype = torch.float16
28
+ # self.dtype = torch.bfloat16
29
+ self.max_seq_short = 384
30
+ self.max_seq_long = 2048
31
+ self.max_batch = 1
32
+
33
+ def load(self, flow_model, hift_model):
34
+ self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
35
+ self.flow.to(self.device).eval().to(self.dtype)
36
+ self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
37
+ self.hift.to(self.device).eval()
38
+
39
+ def inference(self, flow_embedding, tts_speech_token,
40
+ flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
41
+ prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32), is_en=False):
42
+
43
+ torch.cuda.synchronize()
44
+ t0 = time.time()
45
+
46
+ torch.cuda.synchronize()
47
+ t1 = time.time()
48
+
49
+ tts_mel = self.flow.inference(token=tts_speech_token.to(self.device),
50
+ token_len=torch.tensor([tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
51
+ prompt_token=flow_prompt_speech_token.to(self.device),
52
+ prompt_token_len=flow_prompt_speech_token_len.to(self.device),
53
+ prompt_feat=prompt_speech_feat.to(self.device),
54
+ prompt_feat_len=prompt_speech_feat_len.to(self.device),
55
+ embedding=flow_embedding.to(self.device).to(self.dtype)).float()
56
+ torch.cuda.synchronize()
57
+
58
+ tts_speech = self.hift.inference(mel=tts_mel).cpu()
59
+ torch.cuda.synchronize()
60
+ dur = tts_speech.shape[-1]/22050
61
+ torch.cuda.empty_cache()
62
+ return {'tts_speech': tts_speech}
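A hedged usage sketch, not taken from the commit: it assumes flow is a MaskedDiffWithXvec and hift a HiFTGenerator built elsewhere from the repo's config, that a CUDA device is available (inference() calls torch.cuda.synchronize() unconditionally), and that "flow.pt" / "hift.pt" are hypothetical checkpoint paths. Only the call pattern of AudioDetokenizerModel is shown.

import torch
from audio_detokenizer.cli.model import AudioDetokenizerModel

def detokenize(flow, hift, speech_tokens, spk_embedding):
    """speech_tokens: (1, T) int tensor of discrete speech tokens;
    spk_embedding: (1, spk_embed_dim) speaker x-vector."""
    model = AudioDetokenizerModel(flow, hift)
    model.load("flow.pt", "hift.pt")          # hypothetical checkpoint file names
    out = model.inference(flow_embedding=spk_embedding,
                          tts_speech_token=speech_tokens)
    return out["tts_speech"]                  # (1, num_samples) waveform at 22.05 kHz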
audio_detokenizer/flow/__init__.py ADDED
File without changes
audio_detokenizer/flow/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (183 Bytes).
 
audio_detokenizer/flow/__pycache__/decoder.cpython-38.pyc ADDED
Binary file (5.28 kB).
 
audio_detokenizer/flow/__pycache__/flow.cpython-38.pyc ADDED
Binary file (4.15 kB).
 
audio_detokenizer/flow/__pycache__/flow_matching.cpython-38.pyc ADDED
Binary file (6.1 kB).
 
audio_detokenizer/flow/__pycache__/length_regulator.cpython-38.pyc ADDED
Binary file (1.48 kB).
 
audio_detokenizer/flow/decoder.py ADDED
@@ -0,0 +1,224 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # antflake8: noqa
15
+ import torch
16
+ import torch.nn as nn
17
+ from einops import pack, rearrange, repeat
18
+ from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
19
+ from matcha.models.components.transformer import BasicTransformerBlock
20
+
21
+
22
+ class ConditionalDecoder(nn.Module):
23
+ def __init__(
24
+ self,
25
+ in_channels,
26
+ out_channels,
27
+ channels=(256, 256),
28
+ dropout=0.05,
29
+ attention_head_dim=64,
30
+ n_blocks=1,
31
+ num_mid_blocks=2,
32
+ num_heads=4,
33
+ act_fn="snake",
34
+ ):
35
+ """
36
+ This decoder requires an input with the same shape as the target. So, if your text content
37
+ is shorter or longer than the output, please resample it before feeding it to the decoder.
38
+ """
39
+ super().__init__()
40
+ channels = tuple(channels)
41
+ self.in_channels = in_channels
42
+ self.out_channels = out_channels
43
+
44
+ self.time_embeddings = SinusoidalPosEmb(in_channels)
45
+ time_embed_dim = channels[0] * 4
46
+ self.time_mlp = TimestepEmbedding(
47
+ in_channels=in_channels,
48
+ time_embed_dim=time_embed_dim,
49
+ act_fn="silu",
50
+ )
51
+ self.down_blocks = nn.ModuleList([])
52
+ self.mid_blocks = nn.ModuleList([])
53
+ self.up_blocks = nn.ModuleList([])
54
+ self.compiled_infer = None
55
+
56
+ output_channel = in_channels
57
+ for i in range(len(channels)): # pylint: disable=consider-using-enumerate
58
+ input_channel = output_channel
59
+ output_channel = channels[i]
60
+ is_last = i == len(channels) - 1
61
+ resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
62
+ transformer_blocks = nn.ModuleList(
63
+ [
64
+ BasicTransformerBlock(
65
+ dim=output_channel,
66
+ num_attention_heads=num_heads,
67
+ attention_head_dim=attention_head_dim,
68
+ dropout=dropout,
69
+ activation_fn=act_fn,
70
+ )
71
+ for _ in range(n_blocks)
72
+ ]
73
+ )
74
+ downsample = (
75
+ Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
76
+ )
77
+ self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
78
+
79
+ for i in range(num_mid_blocks):
80
+ input_channel = channels[-1]
81
+ out_channels = channels[-1]
82
+ resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
83
+
84
+ transformer_blocks = nn.ModuleList(
85
+ [
86
+ BasicTransformerBlock(
87
+ dim=output_channel,
88
+ num_attention_heads=num_heads,
89
+ attention_head_dim=attention_head_dim,
90
+ dropout=dropout,
91
+ activation_fn=act_fn,
92
+ )
93
+ for _ in range(n_blocks)
94
+ ]
95
+ )
96
+
97
+ self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
98
+
99
+ channels = channels[::-1] + (channels[0],)
100
+ for i in range(len(channels) - 1):
101
+ input_channel = channels[i] * 2
102
+ output_channel = channels[i + 1]
103
+ is_last = i == len(channels) - 2
104
+ resnet = ResnetBlock1D(
105
+ dim=input_channel,
106
+ dim_out=output_channel,
107
+ time_emb_dim=time_embed_dim,
108
+ )
109
+ transformer_blocks = nn.ModuleList(
110
+ [
111
+ BasicTransformerBlock(
112
+ dim=output_channel,
113
+ num_attention_heads=num_heads,
114
+ attention_head_dim=attention_head_dim,
115
+ dropout=dropout,
116
+ activation_fn=act_fn,
117
+ )
118
+ for _ in range(n_blocks)
119
+ ]
120
+ )
121
+ upsample = (
122
+ Upsample1D(output_channel, use_conv_transpose=True)
123
+ if not is_last
124
+ else nn.Conv1d(output_channel, output_channel, 3, padding=1)
125
+ )
126
+ self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
127
+ self.final_block = Block1D(channels[-1], channels[-1])
128
+ self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
129
+ self.initialize_weights()
130
+
131
+
132
+ def initialize_weights(self):
133
+ for m in self.modules():
134
+ if isinstance(m, nn.Conv1d):
135
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
136
+ if m.bias is not None:
137
+ nn.init.constant_(m.bias, 0)
138
+ elif isinstance(m, nn.GroupNorm):
139
+ nn.init.constant_(m.weight, 1)
140
+ nn.init.constant_(m.bias, 0)
141
+ elif isinstance(m, nn.Linear):
142
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
143
+ if m.bias is not None:
144
+ nn.init.constant_(m.bias, 0)
145
+
146
+ def forward(self, x, mask, mu, t, spks=None, cond=None):
147
+ """Forward pass of the UNet1DConditional model.
148
+
149
+ Args:
150
+ x (torch.Tensor): shape (batch_size, in_channels, time)
151
+ mask (_type_): shape (batch_size, 1, time)
152
+ t (_type_): shape (batch_size)
153
+ spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
154
+ cond (_type_, optional): placeholder for future use. Defaults to None.
155
+
156
+ Raises:
157
+ ValueError: _description_
158
+ ValueError: _description_
159
+
160
+ Returns:
161
+ _type_: _description_
162
+ """
163
+
164
+ t = self.time_embeddings(t).to(t.dtype)
165
+ t = self.time_mlp(t)
166
+
167
+ x = pack([x, mu], "b * t")[0]
168
+
169
+ if spks is not None:
170
+ spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
171
+ x = pack([x, spks], "b * t")[0]
172
+ if cond is not None:
173
+ x = pack([x, cond], "b * t")[0]
174
+
175
+ hiddens = []
176
+ masks = [mask]
177
+ for resnet, transformer_blocks, downsample in self.down_blocks:
178
+ mask_down = masks[-1]
179
+ x = resnet(x, mask_down, t)
180
+ x = rearrange(x, "b c t -> b t c").contiguous()
181
+ attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
182
+ for transformer_block in transformer_blocks:
183
+ x = transformer_block(
184
+ hidden_states=x,
185
+ attention_mask=attn_mask,
186
+ timestep=t,
187
+ )
188
+ x = rearrange(x, "b t c -> b c t").contiguous()
189
+ hiddens.append(x) # Save hidden states for skip connections
190
+ x = downsample(x * mask_down)
191
+ masks.append(mask_down[:, :, ::2])
192
+ masks = masks[:-1]
193
+ mask_mid = masks[-1]
194
+
195
+ for resnet, transformer_blocks in self.mid_blocks:
196
+ x = resnet(x, mask_mid, t)
197
+ x = rearrange(x, "b c t -> b t c").contiguous()
198
+ attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
199
+ for transformer_block in transformer_blocks:
200
+ x = transformer_block(
201
+ hidden_states=x,
202
+ attention_mask=attn_mask,
203
+ timestep=t,
204
+ )
205
+ x = rearrange(x, "b t c -> b c t").contiguous()
206
+
207
+ for resnet, transformer_blocks, upsample in self.up_blocks:
208
+ mask_up = masks.pop()
209
+ skip = hiddens.pop()
210
+ x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
211
+ x = resnet(x, mask_up, t)
212
+ x = rearrange(x, "b c t -> b t c").contiguous()
213
+ attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
214
+ for transformer_block in transformer_blocks:
215
+ x = transformer_block(
216
+ hidden_states=x,
217
+ attention_mask=attn_mask,
218
+ timestep=t,
219
+ )
220
+ x = rearrange(x, "b t c -> b c t").contiguous()
221
+ x = upsample(x * mask_up)
222
+ x = self.final_block(x, mask_up)
223
+ output = self.final_proj(x * mask_up)
224
+ return output * mask
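A shape-contract sketch for ConditionalDecoder, assuming the bundled matcha_tts wheel is installed so the matcha imports resolve. The sizes below are illustrative (n_blocks / num_mid_blocks are reduced, and in_channels=320 is simply 80 channels each for x, mu, spks and cond as packed in forward()); they are not claimed to be the shipped configuration.

import torch
from audio_detokenizer.flow.decoder import ConditionalDecoder

dec = ConditionalDecoder(in_channels=320, out_channels=80, channels=(256, 256),
                         attention_head_dim=64, n_blocks=1, num_mid_blocks=2,
                         num_heads=4, act_fn="gelu")
x    = torch.randn(1, 80, 128)   # noisy mel being denoised: (batch, n_feats, time)
mu   = torch.randn(1, 80, 128)   # encoder output, packed with x inside forward()
spks = torch.randn(1, 80)        # projected speaker embedding, broadcast over time
cond = torch.randn(1, 80, 128)   # prompt-mel condition slot
mask = torch.ones(1, 1, 128)
t    = torch.rand(1)             # flow-matching timestep
out  = dec(x, mask, mu, t, spks=spks, cond=cond)
print(out.shape)                 # torch.Size([1, 80, 128])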
audio_detokenizer/flow/flow.py ADDED
@@ -0,0 +1,148 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import logging
15
+ import random
16
+ from typing import Dict, Optional
17
+ import torch
18
+ import torch.nn as nn
19
+ from torch.nn import functional as F
20
+ from omegaconf import DictConfig
21
+ from ..utils.mask import make_pad_mask
22
+
23
+
24
+ class MaskedDiffWithXvec(torch.nn.Module):
25
+ def __init__(self,
26
+ input_size: int = 512,
27
+ output_size: int = 80,
28
+ spk_embed_dim: int = 192,
29
+ output_type: str = "mel",
30
+ vocab_size: int = 4096,
31
+ input_frame_rate: int = 50,
32
+ only_mask_loss: bool = True,
33
+ encoder: torch.nn.Module = None,
34
+ length_regulator: torch.nn.Module = None,
35
+ decoder: torch.nn.Module = None,
36
+ decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1, 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine', 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}), 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64, 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
37
+ mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050, 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
38
+ super().__init__()
39
+ self.input_size = input_size
40
+ self.output_size = output_size
41
+ self.decoder_conf = decoder_conf
42
+ self.mel_feat_conf = mel_feat_conf
43
+ self.vocab_size = vocab_size
44
+ self.output_type = output_type
45
+ self.input_frame_rate = input_frame_rate
46
+ logging.info(f"input frame rate={self.input_frame_rate}")
47
+ self.input_embedding = nn.Embedding(vocab_size, input_size)
48
+ self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
49
+ self.encoder = encoder
50
+ self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
51
+ self.decoder = decoder
52
+ self.length_regulator = length_regulator
53
+ self.only_mask_loss = only_mask_loss
54
+ self.max_seq_long = 2048 * 2
55
+ self.max_seq_short = 256 * 2
56
+
57
+ def forward(
58
+ self,
59
+ batch: dict,
60
+ device: torch.device,
61
+ ) -> Dict[str, Optional[torch.Tensor]]:
62
+ token = batch['speech_token'].to(device)
63
+ token_len = batch['speech_token_len'].to(device)
64
+ feat = batch['speech_feat'].to(device)
65
+ feat_len = batch['speech_feat_len'].to(device)
66
+ embedding = batch['embedding'].to(device)
67
+
68
+ # xvec projection
69
+ embedding = F.normalize(embedding, dim=1)
70
+ embedding = self.spk_embed_affine_layer(embedding)
71
+
72
+ # concat text and prompt_text
73
+ mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
74
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
75
+
76
+ # text encode
77
+ h, h_lengths = self.encoder(token, token_len)
78
+ h = self.encoder_proj(h)
79
+ h, h_lengths = self.length_regulator(h, feat_len)
80
+
81
+ # get conditions
82
+ conds = torch.zeros(feat.shape, device=token.device)
83
+ for i, j in enumerate(feat_len):
84
+ if random.random() < 0.5:
85
+ continue
86
+ index = random.randint(0, int(0.3 * j))
87
+ conds[i, :index] = feat[i, :index]
88
+ conds = conds.transpose(1, 2)
89
+
90
+ mask = (~make_pad_mask(feat_len)).to(h)
91
+ feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
92
+ loss, _ = self.decoder.compute_loss(
93
+ feat.transpose(1, 2).contiguous(),
94
+ mask.unsqueeze(1),
95
+ h.transpose(1, 2).contiguous(),
96
+ embedding,
97
+ cond=conds
98
+ )
99
+ return {'loss': loss}
100
+
101
+ @torch.inference_mode()
102
+ def inference(self,
103
+ token,
104
+ token_len,
105
+ prompt_token,
106
+ prompt_token_len,
107
+ prompt_feat,
108
+ prompt_feat_len,
109
+ embedding):
110
+ assert token.shape[0] == 1
111
+ # xvec projection
112
+ embedding = F.normalize(embedding, dim=1)
113
+ embedding = self.spk_embed_affine_layer(embedding)
114
+
115
+ # concat text and prompt_text
116
+ token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
117
+ mask = (~make_pad_mask(token_len)).to(embedding.dtype).unsqueeze(-1).to(embedding)
118
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
119
+
120
+ # text encode
121
+ h, h_lengths = self.encoder(token, token_len)
122
+ h = self.encoder_proj(h)
123
+ feat_len = (token_len / self.input_frame_rate * 22050 / 256).int()
124
+ h, h_lengths = self.length_regulator(h, feat_len)
125
+
126
+ fix_max_len = feat_len.max().item()
127
+
128
+ # get conditions
129
+ conds = torch.zeros([1, fix_max_len, self.output_size], device=token.device, dtype=embedding.dtype)
130
+ # if prompt_feat.shape[1] != 0:
131
+ # for i, j in enumerate(prompt_feat_len):
132
+ # conds[i, :j] = prompt_feat[i]
133
+ conds = conds.transpose(1, 2)
134
+
135
+ mask = (~make_pad_mask(feat_len, fix_max_len)).to(h)
136
+
137
+ feat = self.decoder.forward(
138
+ mu=h.transpose(1, 2).contiguous(),
139
+ mask=mask.unsqueeze(1),
140
+ spks=embedding,
141
+ cond=conds,
142
+ n_timesteps=8,
143
+ # temperature=0.7,
144
+ )
145
+
146
+ if prompt_feat.shape[1] != 0:
147
+ feat = feat[:, :, prompt_feat.shape[1]:]
148
+ return feat
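One detail worth spelling out from inference() above: the number of mel frames is derived from the token count, since speech tokens arrive at input_frame_rate (50 Hz by default) while mels are produced at 22050 / 256 ≈ 86.1 frames per second. A quick illustrative check:

import torch

token_len = torch.tensor([500])                    # 500 speech tokens ≈ 10 s at 50 Hz
feat_len = (token_len / 50 * 22050 / 256).int()    # same formula as in inference()
print(feat_len)                                    # tensor([861], dtype=torch.int32)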
audio_detokenizer/flow/flow_matching.py ADDED
@@ -0,0 +1,242 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # antflake8: noqa
15
+ import os
16
+ import torch
17
+
18
+ try:
19
+ import tensorrt as trt
20
+ except ImportError:
21
+ import warnings
22
+ warnings.warn("Failed to import TensorRT. Make sure TensorRT is installed and available in your environment.", ImportWarning)
23
+
24
+ import torch.nn.functional as F
25
+ from matcha.models.components.flow_matching import BASECFM
26
+
27
+ class ConditionalCFM(BASECFM):
28
+ def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, tensorrt_model_path="estimator_fp16.plan", estimator: torch.nn.Module = None):
29
+ super().__init__(
30
+ n_feats=in_channels,
31
+ cfm_params=cfm_params,
32
+ n_spks=n_spks,
33
+ spk_emb_dim=spk_emb_dim,
34
+ )
35
+ self.t_scheduler = cfm_params.t_scheduler
36
+ self.training_cfg_rate = cfm_params.training_cfg_rate
37
+ self.inference_cfg_rate = cfm_params.inference_cfg_rate
38
+ in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
39
+ # Just change the architecture of the estimator here
40
+ self.estimator = estimator
41
+ self.compiled_estimator = None
42
+
43
+ self.export_onnx = False
44
+ self.use_tensorrt = False
45
+
46
+ if os.path.isfile(tensorrt_model_path):
47
+ trt.init_libnvinfer_plugins(None, "")
48
+ logger = trt.Logger(trt.Logger.WARNING)
49
+ runtime = trt.Runtime(logger)
50
+ with open(tensorrt_model_path, 'rb') as f:
51
+ serialized_engine = f.read()
52
+ self.engine = runtime.deserialize_cuda_engine(serialized_engine)
53
+ self._context = self.engine.create_execution_context()
54
+ self.use_tensorrt = True
55
+
56
+ @torch.inference_mode()
57
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
58
+ """Forward diffusion
59
+
60
+ Args:
61
+ mu (torch.Tensor): output of encoder
62
+ shape: (batch_size, n_feats, mel_timesteps)
63
+ mask (torch.Tensor): output_mask
64
+ shape: (batch_size, 1, mel_timesteps)
65
+ n_timesteps (int): number of diffusion steps
66
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
67
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
68
+ shape: (batch_size, spk_emb_dim)
69
+ cond: Not used but kept for future purposes
70
+
71
+ Returns:
72
+ sample: generated mel-spectrogram
73
+ shape: (batch_size, n_feats, mel_timesteps)
74
+ """
75
+ z = torch.randn_like(mu) * temperature
76
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
77
+ if self.t_scheduler == 'cosine':
78
+ t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
79
+ return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
80
+
81
+ def estimator_infer(self, x, mask, mu, t, spks, cond):
82
+ if self.use_tensorrt:
83
+ # print("Using tensorrt now !!!!")
84
+ bs = x.shape[0]
85
+ hs = x.shape[1]
86
+ seq_len = x.shape[2]
87
+
88
+ assert bs == 1 and hs == 80
89
+
90
+ ret = torch.empty_like(x)
91
+ self._context.set_input_shape("x", x.shape)
92
+ self._context.set_input_shape("mask", mask.shape)
93
+ self._context.set_input_shape("mu", mu.shape)
94
+ self._context.set_input_shape("t", t.shape)
95
+ self._context.set_input_shape("spks", spks.shape)
96
+ self._context.set_input_shape("cond", cond.shape)
97
+
98
+ bindings = [x.data_ptr(), mask.data_ptr(), mu.data_ptr(), t.data_ptr(), spks.data_ptr(), cond.data_ptr(), ret.data_ptr()]
99
+
100
+ for i in range(len(bindings)):
101
+ self._context.set_tensor_address(self.engine.get_tensor_name(i), bindings[i])
102
+
103
+ handle = torch.cuda.current_stream().cuda_stream
104
+ self._context.execute_async_v3(stream_handle=handle)
105
+ return ret
106
+ else:
107
+ return self.estimator.forward(x, mask, mu, t, spks, cond)
108
+
109
+ def solve_euler(self, x, t_span, mu, mask, spks, cond):
110
+ """
111
+ Fixed euler solver for ODEs.
112
+ Args:
113
+ x (torch.Tensor): random noise
114
+ t_span (torch.Tensor): n_timesteps interpolated
115
+ shape: (n_timesteps + 1,)
116
+ mu (torch.Tensor): output of encoder
117
+ shape: (batch_size, n_feats, mel_timesteps)
118
+ mask (torch.Tensor): output_mask
119
+ shape: (batch_size, 1, mel_timesteps)
120
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
121
+ shape: (batch_size, spk_emb_dim)
122
+ cond: Not used but kept for future purposes
123
+ """
124
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
125
+ t = t.unsqueeze(dim=0)
126
+
127
+ # I am storing this because I can later plot it by putting a debugger here and saving it to a file
128
+ # Or in future might add like a return_all_steps flag
129
+ sol = []
130
+
131
+ # self.export_onnx= True
132
+ # if self.export_onnx == True:
133
+ # dummy_input = (x, mask, mu, t, spks, cond)
134
+ # torch.onnx.export(
135
+ # self.estimator,
136
+ # dummy_input,
137
+ # "estimator_bf16.onnx",
138
+ # export_params=True,
139
+ # opset_version=18,
140
+ # do_constant_folding=True,
141
+ # input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
142
+ # output_names=['output'],
143
+ # dynamic_axes={
144
+ # 'x': {2: 'seq_len'},
145
+ # 'mask': {2: 'seq_len'},
146
+ # 'mu': {2: 'seq_len'},
147
+ # 'cond': {2: 'seq_len'},
148
+ # 'output': {2: 'seq_len'},
149
+ # }
150
+ # )
151
+ # onnx_file_path = "estimator_bf16.onnx"
152
+ # tensorrt_path = "/root/TensorRT-10.2.0.19"
153
+ # if not tensorrt_path:
154
+ # raise EnvironmentError("Please set the 'tensorrt_root_dir' environment variable.")
155
+
156
+ # if not os.path.isdir(tensorrt_path):
157
+ # raise FileNotFoundError(f"The directory {tensorrt_path} does not exist.")
158
+
159
+ # trt_lib_path = os.path.join(tensorrt_path, "lib")
160
+ # if trt_lib_path not in os.environ.get('LD_LIBRARY_PATH', ''):
161
+ # print(f"Adding TensorRT lib path {trt_lib_path} to LD_LIBRARY_PATH.")
162
+ # os.environ['LD_LIBRARY_PATH'] = f"{os.environ.get('LD_LIBRARY_PATH', '')}:{trt_lib_path}"
163
+
164
+ # trt_file_name = 'estimator_bf16.plan'
165
+ # flow_model_dir ='.'
166
+ # # trt_file_path = os.path.join(flow_model_dir, trt_file_name)
167
+
168
+ # trtexec_bin = os.path.join(tensorrt_path, 'bin/trtexec')
169
+ # trtexec_cmd = f"{trtexec_bin} --onnx={onnx_file_path} --saveEngine={trt_file_name} " \
170
+ # "--minShapes=x:1x80x1,mask:1x1x1,mu:1x80x1,t:1,spks:1x80,cond:1x80x1 " \
171
+ # "--maxShapes=x:1x80x4096,mask:1x1x4096,mu:1x80x4096,t:1,spks:1x80,cond:1x80x4096 " + \
172
+ # "--fp16"
173
+
174
+ # print("execute tensorrt", trtexec_cmd)
175
+ # os.system(trtexec_cmd)
176
+ # # """
177
+ # # ${TensorRT-10.2.0.19}/bin/trtexec --onnx=estimator_fp16.onnx --saveEngine=estimator_fp16.plan \
178
+ # # --minShapes=x:1x80x1,mask:1x1x1,mu:1x80x1,t:1,spks:1x80,cond:1x80x1 \
179
+ # # --maxShapes=x:1x80x4096,mask:1x1x4096,mu:1x80x4096,t:1,spks:1x80,cond:1x80x4096 \
180
+ # # --fp16 --verbose
181
+ # # """
182
+
183
+
184
+ for step in range(1, len(t_span)):
185
+ dphi_dt = self.estimator_infer(x, mask, mu, t, spks, cond).clone()
186
+ # Classifier-Free Guidance inference introduced in VoiceBox
187
+ if self.inference_cfg_rate > 0:
188
+ cfg_dphi_dt = self.estimator_infer(
189
+ x, mask,
190
+ torch.zeros_like(mu), t,
191
+ torch.zeros_like(spks) if spks is not None else None,
192
+ torch.zeros_like(cond)
193
+ ).clone()
194
+ dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
195
+ x = x + dt * dphi_dt
196
+ t = t + dt
197
+ sol.append(x)
198
+ if step < len(t_span) - 1:
199
+ dt = t_span[step + 1] - t
200
+
201
+ return sol[-1]
202
+
203
+ def compute_loss(self, x1, mask, mu, spks=None, cond=None):
204
+ """Computes diffusion loss
205
+
206
+ Args:
207
+ x1 (torch.Tensor): Target
208
+ shape: (batch_size, n_feats, mel_timesteps)
209
+ mask (torch.Tensor): target mask
210
+ shape: (batch_size, 1, mel_timesteps)
211
+ mu (torch.Tensor): output of encoder
212
+ shape: (batch_size, n_feats, mel_timesteps)
213
+ spks (torch.Tensor, optional): speaker embedding. Defaults to None.
214
+ shape: (batch_size, spk_emb_dim)
215
+
216
+ Returns:
217
+ loss: conditional flow matching loss
218
+ y: conditional flow
219
+ shape: (batch_size, n_feats, mel_timesteps)
220
+ """
221
+ b, _, t = mu.shape
222
+
223
+ # random timestep
224
+ t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
225
+ if self.t_scheduler == 'cosine':
226
+ t = 1 - torch.cos(t * 0.5 * torch.pi)
227
+ # sample noise p(x_0)
228
+ z = torch.randn_like(x1)
229
+
230
+ y = (1 - (1 - self.sigma_min) * t) * z + t * x1
231
+ u = x1 - (1 - self.sigma_min) * z
232
+
233
+ # during training, we randomly drop condition to trade off mode coverage and sample fidelity
234
+ if self.training_cfg_rate > 0:
235
+ cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
236
+ mu = mu * cfg_mask.view(-1, 1, 1)
237
+ spks = spks * cfg_mask.view(-1, 1)
238
+ cond = cond * cfg_mask.view(-1, 1, 1)
239
+
240
+ pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
241
+ loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
242
+ return loss, y
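The heart of solve_euler above, restated as a standalone sketch: one fixed-step Euler update per timestep, with classifier-free guidance mixing a conditional and an unconditional velocity estimate. Here estimator stands in for ConditionalDecoder.forward (or the TensorRT engine); the function is illustrative, not a drop-in replacement.

import torch

def euler_cfg_solve(estimator, z, t_span, mu, mask, spks, cond, cfg_rate=0.7):
    """z: initial noise; t_span: (n_timesteps + 1,) increasing time grid."""
    x, t = z, t_span[0:1]
    for step in range(1, len(t_span)):
        dt = t_span[step] - t_span[step - 1]
        dphi_dt = estimator(x, mask, mu, t, spks, cond)
        if cfg_rate > 0:
            uncond = estimator(x, mask, torch.zeros_like(mu), t,
                               torch.zeros_like(spks), torch.zeros_like(cond))
            dphi_dt = (1.0 + cfg_rate) * dphi_dt - cfg_rate * uncond
        x = x + dt * dphi_dt
        t = t + dt
    return x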
audio_detokenizer/flow/length_regulator.py ADDED
@@ -0,0 +1,49 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Tuple
15
+ import torch.nn as nn
16
+ from torch.nn import functional as F
17
+ from ..utils.mask import make_pad_mask
18
+
19
+
20
+ class InterpolateRegulator(nn.Module):
21
+ def __init__(
22
+ self,
23
+ channels: int,
24
+ sampling_ratios: Tuple,
25
+ out_channels: int = None,
26
+ groups: int = 1,
27
+ ):
28
+ super().__init__()
29
+ self.sampling_ratios = sampling_ratios
30
+ out_channels = out_channels or channels
31
+ model = nn.ModuleList([])
32
+ if len(sampling_ratios) > 0:
33
+ for _ in sampling_ratios:
34
+ module = nn.Conv1d(channels, channels, 3, 1, 1)
35
+ norm = nn.GroupNorm(groups, channels)
36
+ act = nn.Mish()
37
+ model.extend([module, norm, act])
38
+ model.append(
39
+ nn.Conv1d(channels, out_channels, 1, 1)
40
+ )
41
+ self.model = nn.Sequential(*model)
42
+
43
+ def forward(self, x, ylens=None):
44
+ # x in (B, T, D)
45
+ mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
46
+ x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
47
+ out = self.model(x).transpose(1, 2).contiguous()
48
+ olens = ylens
49
+ return out * mask, olens
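The essence of InterpolateRegulator.forward, shown standalone so it does not depend on the repo's make_pad_mask: stretch or squeeze a (B, T, D) sequence to a target frame count with nearest-neighbour interpolation; the conv/GroupNorm/Mish stack above then refines the result.

import torch
import torch.nn.functional as F

x = torch.randn(1, 500, 80)                  # (B, T_tokens, D) encoder output
target_len = 861                             # mel frames wanted for this utterance
y = F.interpolate(x.transpose(1, 2), size=target_len, mode="nearest").transpose(1, 2)
print(y.shape)                               # torch.Size([1, 861, 80])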
audio_detokenizer/hifigan/__init__.py ADDED
File without changes
audio_detokenizer/hifigan/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (186 Bytes).
 
audio_detokenizer/hifigan/__pycache__/f0_predictor.cpython-38.pyc ADDED
Binary file (1.37 kB).
 
audio_detokenizer/hifigan/__pycache__/generator.cpython-38.pyc ADDED
Binary file (11.3 kB).
 
audio_detokenizer/hifigan/f0_predictor.py ADDED
@@ -0,0 +1,55 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ import torch.nn as nn
16
+ from torch.nn.utils import weight_norm
17
+
18
+
19
+ class ConvRNNF0Predictor(nn.Module):
20
+ def __init__(self,
21
+ num_class: int = 1,
22
+ in_channels: int = 80,
23
+ cond_channels: int = 512
24
+ ):
25
+ super().__init__()
26
+
27
+ self.num_class = num_class
28
+ self.condnet = nn.Sequential(
29
+ weight_norm(
30
+ nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
31
+ ),
32
+ nn.ELU(),
33
+ weight_norm(
34
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
35
+ ),
36
+ nn.ELU(),
37
+ weight_norm(
38
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
39
+ ),
40
+ nn.ELU(),
41
+ weight_norm(
42
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
43
+ ),
44
+ nn.ELU(),
45
+ weight_norm(
46
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
47
+ ),
48
+ nn.ELU(),
49
+ )
50
+ self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
51
+
52
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
53
+ x = self.condnet(x)
54
+ x = x.transpose(1, 2)
55
+ return torch.abs(self.classifier(x).squeeze(-1))
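Shape contract of ConvRNNF0Predictor, illustrated with dummy tensors (the class only depends on torch, so this should run once the package import resolves): it maps a mel spectrogram to one non-negative F0 value per frame.

import torch
from audio_detokenizer.hifigan.f0_predictor import ConvRNNF0Predictor

predictor = ConvRNNF0Predictor(in_channels=80, cond_channels=512)
mel = torch.randn(2, 80, 200)     # (batch, mel_bins, frames)
f0 = predictor(mel)               # (batch, frames), non-negative values in Hz
print(f0.shape)                   # torch.Size([2, 200])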
audio_detokenizer/hifigan/generator.py ADDED
@@ -0,0 +1,392 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """HIFI-GAN"""
16
+ # antflake8: noqa
17
+
18
+ import typing as tp
19
+ import numpy as np
20
+ from scipy.signal import get_window
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ from torch.nn import Conv1d
25
+ from torch.nn import ConvTranspose1d
26
+ from torch.nn.utils import remove_weight_norm
27
+ from torch.nn.utils import weight_norm
28
+ from torch.distributions.uniform import Uniform
29
+
30
+ from ..transformer.activation import Snake
31
+ from ..utils.common import get_padding
32
+ from ..utils.common import init_weights
33
+
34
+
35
+ """hifigan based generator implementation.
36
+
37
+ This code is modified from https://github.com/jik876/hifi-gan
38
+ , https://github.com/kan-bayashi/ParallelWaveGAN and
39
+ https://github.com/NVIDIA/BigVGAN
40
+
41
+ """
42
+ class ResBlock(torch.nn.Module):
43
+ """Residual block module in HiFiGAN/BigVGAN."""
44
+ def __init__(
45
+ self,
46
+ channels: int = 512,
47
+ kernel_size: int = 3,
48
+ dilations: tp.List[int] = [1, 3, 5],
49
+ ):
50
+ super(ResBlock, self).__init__()
51
+ self.convs1 = nn.ModuleList()
52
+ self.convs2 = nn.ModuleList()
53
+
54
+ for dilation in dilations:
55
+ self.convs1.append(
56
+ weight_norm(
57
+ Conv1d(
58
+ channels,
59
+ channels,
60
+ kernel_size,
61
+ 1,
62
+ dilation=dilation,
63
+ padding=get_padding(kernel_size, dilation)
64
+ )
65
+ )
66
+ )
67
+ self.convs2.append(
68
+ weight_norm(
69
+ Conv1d(
70
+ channels,
71
+ channels,
72
+ kernel_size,
73
+ 1,
74
+ dilation=1,
75
+ padding=get_padding(kernel_size, 1)
76
+ )
77
+ )
78
+ )
79
+ self.convs1.apply(init_weights)
80
+ self.convs2.apply(init_weights)
81
+ self.activations1 = nn.ModuleList([
82
+ Snake(channels, alpha_logscale=False)
83
+ for _ in range(len(self.convs1))
84
+ ])
85
+ self.activations2 = nn.ModuleList([
86
+ Snake(channels, alpha_logscale=False)
87
+ for _ in range(len(self.convs2))
88
+ ])
89
+
90
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
91
+ for idx in range(len(self.convs1)):
92
+ xt = self.activations1[idx](x)
93
+ xt = self.convs1[idx](xt)
94
+ xt = self.activations2[idx](xt)
95
+ xt = self.convs2[idx](xt)
96
+ x = xt + x
97
+ return x
98
+
99
+ def remove_weight_norm(self):
100
+ for idx in range(len(self.convs1)):
101
+ remove_weight_norm(self.convs1[idx])
102
+ remove_weight_norm(self.convs2[idx])
103
+
104
+ class SineGen(torch.nn.Module):
105
+ """ Definition of sine generator
106
+ SineGen(samp_rate, harmonic_num = 0,
107
+ sine_amp = 0.1, noise_std = 0.003,
108
+ voiced_threshold = 0,
109
+ flag_for_pulse=False)
110
+ samp_rate: sampling rate in Hz
111
+ harmonic_num: number of harmonic overtones (default 0)
112
+ sine_amp: amplitude of sine waveform (default 0.1)
113
+ noise_std: std of Gaussian noise (default 0.003)
114
+ voiced_threshold: F0 threshold for U/V classification (default 0)
115
+ flag_for_pulse: this SineGen is used inside PulseGen (default False)
116
+ Note: when flag_for_pulse is True, the first time step of a voiced
117
+ segment is always sin(np.pi) or cos(0)
118
+ """
119
+
120
+ def __init__(self, samp_rate, harmonic_num=0,
121
+ sine_amp=0.1, noise_std=0.003,
122
+ voiced_threshold=0):
123
+ super(SineGen, self).__init__()
124
+ self.sine_amp = sine_amp
125
+ self.noise_std = noise_std
126
+ self.harmonic_num = harmonic_num
127
+ self.sampling_rate = samp_rate
128
+ self.voiced_threshold = voiced_threshold
129
+
130
+ def _f02uv(self, f0):
131
+ # generate uv signal
132
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
133
+ return uv
134
+
135
+ @torch.no_grad()
136
+ def forward(self, f0):
137
+ """
138
+ :param f0: [B, 1, sample_len], Hz
139
+ :return: [B, 1, sample_len]
140
+ """
141
+
142
+ F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
143
+ for i in range(self.harmonic_num + 1):
144
+ F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
145
+
146
+ theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
147
+ u_dist = Uniform(low=-np.pi, high=np.pi)
148
+ phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
149
+ phase_vec[:, 0, :] = 0
150
+
151
+ # generate sine waveforms
152
+ sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
153
+
154
+ # generate uv signal
155
+ uv = self._f02uv(f0)
156
+
157
+ # noise: for unvoiced should be similar to sine_amp
158
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
159
+ # . for voiced regions is self.noise_std
160
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
161
+ noise = noise_amp * torch.randn_like(sine_waves)
162
+
163
+ # first: set the unvoiced part to 0 by uv
164
+ # then: additive noise
165
+ sine_waves = sine_waves * uv + noise
166
+ return sine_waves, uv, noise
167
+
168
+
169
+ class SourceModuleHnNSF(torch.nn.Module):
170
+ """ SourceModule for hn-nsf
171
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
172
+ add_noise_std=0.003, voiced_threshod=0)
173
+ sampling_rate: sampling_rate in Hz
174
+ harmonic_num: number of harmonic above F0 (default: 0)
175
+ sine_amp: amplitude of sine source signal (default: 0.1)
176
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
177
+ note that amplitude of noise in unvoiced is decided
178
+ by sine_amp
179
+ voiced_threshold: threshold to set U/V given F0 (default: 0)
180
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
181
+ F0_sampled (batchsize, length, 1)
182
+ Sine_source (batchsize, length, 1)
183
+ noise_source (batchsize, length, 1)
184
+ uv (batchsize, length, 1)
185
+ """
186
+
187
+ def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
188
+ add_noise_std=0.003, voiced_threshod=0):
189
+ super(SourceModuleHnNSF, self).__init__()
190
+
191
+ self.sine_amp = sine_amp
192
+ self.noise_std = add_noise_std
193
+
194
+ # to produce sine waveforms
195
+ self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
196
+ sine_amp, add_noise_std, voiced_threshod)
197
+
198
+ # to merge source harmonics into a single excitation
199
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
200
+ self.l_tanh = torch.nn.Tanh()
201
+
202
+ def forward(self, x):
203
+ """
204
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
205
+ F0_sampled (batchsize, length, 1)
206
+ Sine_source (batchsize, length, 1)
207
+ noise_source (batchsize, length, 1)
208
+ """
209
+ # source for harmonic branch
210
+ with torch.no_grad():
211
+ sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
212
+ sine_wavs = sine_wavs.transpose(1, 2)
213
+ uv = uv.transpose(1, 2)
214
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
215
+
216
+ # source for noise branch, in the same shape as uv
217
+ noise = torch.randn_like(uv) * self.sine_amp / 3
218
+ return sine_merge, noise, uv
219
+
220
+
221
+ class HiFTGenerator(nn.Module):
222
+ """
223
+ HiFTNet Generator: Neural Source Filter + ISTFTNet
224
+ https://arxiv.org/abs/2309.09493
225
+ """
226
+ def __init__(
227
+ self,
228
+ in_channels: int = 80,
229
+ base_channels: int = 512,
230
+ nb_harmonics: int = 8,
231
+ sampling_rate: int = 22050,
232
+ nsf_alpha: float = 0.1,
233
+ nsf_sigma: float = 0.003,
234
+ nsf_voiced_threshold: float = 10,
235
+ upsample_rates: tp.List[int] = [8, 8],
236
+ upsample_kernel_sizes: tp.List[int] = [16, 16],
237
+ istft_params: tp.Dict[str, int] = {"n_fft": 16, "hop_len": 4},
238
+ resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
239
+ resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
240
+ source_resblock_kernel_sizes: tp.List[int] = [7, 11],
241
+ source_resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5]],
242
+ lrelu_slope: float = 0.1,
243
+ audio_limit: float = 0.99,
244
+ f0_predictor: torch.nn.Module = None,
245
+ ):
246
+ super(HiFTGenerator, self).__init__()
247
+
248
+ self.out_channels = 1
249
+ self.nb_harmonics = nb_harmonics
250
+ self.sampling_rate = sampling_rate
251
+ self.istft_params = istft_params
252
+ self.lrelu_slope = lrelu_slope
253
+ self.audio_limit = audio_limit
254
+
255
+ self.num_kernels = len(resblock_kernel_sizes)
256
+ self.num_upsamples = len(upsample_rates)
257
+ self.m_source = SourceModuleHnNSF(
258
+ sampling_rate=sampling_rate,
259
+ upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
260
+ harmonic_num=nb_harmonics,
261
+ sine_amp=nsf_alpha,
262
+ add_noise_std=nsf_sigma,
263
+ voiced_threshod=nsf_voiced_threshold)
264
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
265
+
266
+ self.conv_pre = weight_norm(
267
+ Conv1d(in_channels, base_channels, 7, 1, padding=3)
268
+ )
269
+
270
+ # Up
271
+ self.ups = nn.ModuleList()
272
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
273
+ self.ups.append(
274
+ weight_norm(
275
+ ConvTranspose1d(
276
+ base_channels // (2**i),
277
+ base_channels // (2**(i + 1)),
278
+ k,
279
+ u,
280
+ padding=(k - u) // 2,
281
+ )
282
+ )
283
+ )
284
+
285
+ # Down
286
+ self.source_downs = nn.ModuleList()
287
+ self.source_resblocks = nn.ModuleList()
288
+ downsample_rates = [1] + upsample_rates[::-1][:-1]
289
+ downsample_cum_rates = np.cumprod(downsample_rates)
290
+ for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes,
291
+ source_resblock_dilation_sizes)):
292
+ if u == 1:
293
+ self.source_downs.append(
294
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
295
+ )
296
+ else:
297
+ self.source_downs.append(
298
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
299
+ )
300
+
301
+ self.source_resblocks.append(
302
+ ResBlock(base_channels // (2 ** (i + 1)), k, d)
303
+ )
304
+
305
+ self.resblocks = nn.ModuleList()
306
+ for i in range(len(self.ups)):
307
+ ch = base_channels // (2**(i + 1))
308
+ for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
309
+ self.resblocks.append(ResBlock(ch, k, d))
310
+
311
+ self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
312
+ self.ups.apply(init_weights)
313
+ self.conv_post.apply(init_weights)
314
+ self.reflection_pad = nn.ReflectionPad1d((1, 0))
315
+ self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
316
+ self.f0_predictor = f0_predictor
317
+
318
+ def _f02source(self, f0: torch.Tensor) -> torch.Tensor:
319
+ f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
320
+
321
+ har_source, _, _ = self.m_source(f0)
322
+ return har_source.transpose(1, 2)
323
+
324
+ def _stft(self, x):
325
+ spec = torch.stft(
326
+ x,
327
+ self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
328
+ return_complex=True)
329
+ spec = torch.view_as_real(spec) # [B, F, TT, 2]
330
+ return spec[..., 0], spec[..., 1]
331
+
332
+ def _istft(self, magnitude, phase):
333
+ magnitude = torch.clip(magnitude, max=1e2)
334
+ real = magnitude * torch.cos(phase)
335
+ img = magnitude * torch.sin(phase)
336
+ inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
337
+ return inverse_transform
338
+
339
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
340
+ f0 = self.f0_predictor(x)
341
+ s = self._f02source(f0)
342
+
343
+ s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
344
+ s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
345
+
346
+ x = self.conv_pre(x)
347
+ for i in range(self.num_upsamples):
348
+ x = F.leaky_relu(x, self.lrelu_slope)
349
+ x = self.ups[i](x)
350
+
351
+ if i == self.num_upsamples - 1:
352
+ x = self.reflection_pad(x)
353
+
354
+ # fusion
355
+ si = self.source_downs[i](s_stft)
356
+ si = self.source_resblocks[i](si)
357
+ x = x + si
358
+
359
+ xs = None
360
+ for j in range(self.num_kernels):
361
+ if xs is None:
362
+ xs = self.resblocks[i * self.num_kernels + j](x)
363
+ else:
364
+ xs += self.resblocks[i * self.num_kernels + j](x)
365
+ x = xs / self.num_kernels
366
+
367
+ x = F.leaky_relu(x)
368
+ x = self.conv_post(x)
369
+ magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
370
+ phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, the sin here is redundant
371
+
372
+ x = self._istft(magnitude, phase)
373
+ x = torch.clamp(x, -self.audio_limit, self.audio_limit)
374
+ return x
375
+
376
+ def remove_weight_norm(self):
377
+ print('Removing weight norm...')
378
+ for l in self.ups:
379
+ remove_weight_norm(l)
380
+ for l in self.resblocks:
381
+ l.remove_weight_norm()
382
+ remove_weight_norm(self.conv_pre)
383
+ remove_weight_norm(self.conv_post)
384
+ # NOTE: the source module (self.m_source) defines no weight-normalised layers, so there is nothing to remove for it here
385
+ for l in self.source_downs:
386
+ remove_weight_norm(l)
387
+ for l in self.source_resblocks:
388
+ l.remove_weight_norm()
389
+
390
+ @torch.inference_mode()
391
+ def inference(self, mel: torch.Tensor) -> torch.Tensor:
392
+ return self.forward(x=mel)
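Taken together, `SineGen`, `SourceModuleHnNSF` and `HiFTGenerator` implement an NSF-style vocoder: an F0 contour predicted from the mel spectrogram is upsampled to the waveform rate, converted into a harmonic sine source plus noise, and fused with the upsampled mel features before the ISTFT head, so each mel frame expands to `prod(upsample_rates) * hop_len` (8 × 8 × 4 = 256) output samples at 22.05 kHz. The snippet below is an illustrative sketch only, not part of the committed file: it reuses the defaults shown in the diff, and the constant-F0 predictor is a hypothetical stand-in for the repository's real `f0_predictor` module.

```python
# Sketch: run the HiFTGenerator defined above on random mel input.
# Assumes `HiFTGenerator` is importable from audio_detokenizer/hifigan/generator.py;
# the constant-F0 predictor below is a stand-in, not the repo's predictor.
import torch
# from audio_detokenizer.hifigan.generator import HiFTGenerator  # assumed import path

class ConstantF0(torch.nn.Module):
    """Maps a mel spectrogram [B, 80, T] to a flat 220 Hz contour [B, T]."""
    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        return torch.full((mel.size(0), mel.size(2)), 220.0, device=mel.device)

mel = torch.randn(1, 80, 50)                      # [B, in_channels, T_mel]
vocoder = HiFTGenerator(f0_predictor=ConstantF0())
wav = vocoder.inference(mel)
print(wav.shape)                                  # [1, 12800] == T_mel * 8 * 8 * 4 samples
```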
audio_detokenizer/transformer/__init__.py ADDED
File without changes
audio_detokenizer/transformer/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (190 Bytes). View file
 
audio_detokenizer/transformer/__pycache__/activation.cpython-38.pyc ADDED
Binary file (2.51 kB). View file
 
audio_detokenizer/transformer/__pycache__/attention.cpython-38.pyc ADDED
Binary file (10.9 kB). View file
 
audio_detokenizer/transformer/__pycache__/convolution.cpython-38.pyc ADDED
Binary file (3.1 kB). View file
 
audio_detokenizer/transformer/__pycache__/embedding.cpython-38.pyc ADDED
Binary file (9.78 kB). View file
 
audio_detokenizer/transformer/__pycache__/encoder.cpython-38.pyc ADDED
Binary file (19.6 kB). View file
 
audio_detokenizer/transformer/__pycache__/encoder_layer.cpython-38.pyc ADDED
Binary file (8.64 kB). View file
 
audio_detokenizer/transformer/__pycache__/positionwise_feed_forward.cpython-38.pyc ADDED
Binary file (3.79 kB). View file
 
audio_detokenizer/transformer/__pycache__/subsampling.cpython-38.pyc ADDED
Binary file (10.6 kB). View file
 
audio_detokenizer/transformer/activation.py ADDED
@@ -0,0 +1,84 @@
1
+ # Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe)
2
+ # 2020 Northwestern Polytechnical University (Pengcheng Guo)
3
+ # 2020 Mobvoi Inc (Binbin Zhang)
4
+ # 2024 Alibaba Inc (Xiang Lyu)
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """Swish() activation function for Conformer."""
18
+
19
+ import torch
20
+ from torch import nn, sin, pow
21
+ from torch.nn import Parameter
22
+
23
+
24
+ class Swish(torch.nn.Module):
25
+ """Construct an Swish object."""
26
+
27
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
28
+ """Return Swish activation function."""
29
+ return x * torch.sigmoid(x)
30
+
31
+
32
+ # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
33
+ # LICENSE is in incl_licenses directory.
34
+ class Snake(nn.Module):
35
+ '''
36
+ Implementation of a sine-based periodic activation function
37
+ Shape:
38
+ - Input: (B, C, T)
39
+ - Output: (B, C, T), same shape as the input
40
+ Parameters:
41
+ - alpha - trainable parameter
42
+ References:
43
+ - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
44
+ https://arxiv.org/abs/2006.08195
45
+ Examples:
46
+ >>> a1 = Snake(256)
47
+ >>> x = torch.randn(256)
48
+ >>> x = a1(x)
49
+ '''
50
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
51
+ '''
52
+ Initialization.
53
+ INPUT:
54
+ - in_features: shape of the input
55
+ - alpha: trainable parameter
56
+ alpha is initialized to 1 by default, higher values = higher-frequency.
57
+ alpha will be trained along with the rest of your model.
58
+ '''
59
+ super(Snake, self).__init__()
60
+ self.in_features = in_features
61
+
62
+ # initialize alpha
63
+ self.alpha_logscale = alpha_logscale
64
+ if self.alpha_logscale: # log scale alphas initialized to zeros
65
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
66
+ else: # linear scale alphas initialized to ones
67
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
68
+
69
+ self.alpha.requires_grad = alpha_trainable
70
+
71
+ self.no_div_by_zero = 0.000000001
72
+
73
+ def forward(self, x):
74
+ '''
75
+ Forward pass of the function.
76
+ Applies the function to the input elementwise.
77
+ Snake(x) := x + (1/a) * sin^2(a * x)
78
+ '''
79
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
80
+ if self.alpha_logscale:
81
+ alpha = torch.exp(alpha)
82
+ x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
83
+
84
+ return x
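As a quick sanity check of the two activations above: `Swish` is the usual `x * sigmoid(x)`, and `Snake` adds a periodic `(1/α) * sin²(αx)` term with one learnable `α` per channel, so it expects `(B, C, T)` input. The snippet is a sketch, not part of the commit, and assumes both classes are imported from `audio_detokenizer/transformer/activation.py`.

```python
# Sketch: exercise Swish and Snake on toy tensors (assumed imports, see above).
import torch

x = torch.linspace(-3.0, 3.0, 7)
print(Swish()(x))                  # elementwise x * sigmoid(x)

snake = Snake(in_features=4)       # one alpha per channel, initialised to 1
y = snake(torch.randn(2, 4, 10))   # input and output are both (B, C, T)
print(y.shape)                     # torch.Size([2, 4, 10])
```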
audio_detokenizer/transformer/attention.py ADDED
@@ -0,0 +1,463 @@
1
+ # Copyright (c) 2019 Shigeki Karita
2
+ # 2020 Mobvoi Inc (Binbin Zhang)
3
+ # 2022 Xingchen Song ([email protected])
4
+ # 2024 Alibaba Inc (Xiang Lyu)
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """Multi-Head Attention layer definition."""
18
+
19
+ import math
20
+ from typing import Tuple
21
+
22
+ import torch
23
+ from torch import nn
24
+
25
+
26
+ class MultiHeadedAttention(nn.Module):
27
+ """Multi-Head Attention layer.
28
+
29
+ Args:
30
+ n_head (int): The number of heads.
31
+ n_feat (int): The number of features.
32
+ dropout_rate (float): Dropout rate.
33
+
34
+ """
35
+
36
+ def __init__(self,
37
+ n_head: int,
38
+ n_feat: int,
39
+ dropout_rate: float,
40
+ key_bias: bool = True):
41
+ """Construct an MultiHeadedAttention object."""
42
+ super().__init__()
43
+ assert n_feat % n_head == 0
44
+ # We assume d_v always equals d_k
45
+ self.d_k = n_feat // n_head
46
+ self.h = n_head
47
+ self.linear_q = nn.Linear(n_feat, n_feat)
48
+ self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
49
+ self.linear_v = nn.Linear(n_feat, n_feat)
50
+ self.linear_out = nn.Linear(n_feat, n_feat)
51
+ self.dropout = nn.Dropout(p=dropout_rate)
52
+ self.dropout_rate = dropout_rate
53
+ self.kv_cache = None
54
+
55
+ def forward_qkv(
56
+ self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
57
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
58
+ """Transform query, key and value.
59
+
60
+ Args:
61
+ query (torch.Tensor): Query tensor (#batch, time1, size).
62
+ key (torch.Tensor): Key tensor (#batch, time2, size).
63
+ value (torch.Tensor): Value tensor (#batch, time2, size).
64
+
65
+ Returns:
66
+ torch.Tensor: Transformed query tensor, size
67
+ (#batch, n_head, time1, d_k).
68
+ torch.Tensor: Transformed key tensor, size
69
+ (#batch, n_head, time2, d_k).
70
+ torch.Tensor: Transformed value tensor, size
71
+ (#batch, n_head, time2, d_k).
72
+
73
+ """
74
+ n_batch = query.size(0)
75
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
76
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
77
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
78
+
79
+ return q, k, v
80
+
81
+ def forward_attention(
82
+ self,
83
+ value: torch.Tensor,
84
+ scores: torch.Tensor,
85
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
86
+ ) -> torch.Tensor:
87
+ """Compute attention context vector.
88
+
89
+ Args:
90
+ value (torch.Tensor): Transformed value, size
91
+ (#batch, n_head, time2, d_k).
92
+ scores (torch.Tensor): Attention score, size
93
+ (#batch, n_head, time1, time2).
94
+ mask (torch.Tensor): Mask, size (#batch, 1, time2) or
95
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
96
+
97
+ Returns:
98
+ torch.Tensor: Transformed value (#batch, time1, d_model)
99
+ weighted by the attention score (#batch, time1, time2).
100
+
101
+ """
102
+ n_batch = value.size(0)
103
+ # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
104
+ # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
105
+ # 1st chunk to ease the onnx export.]
106
+ # 2. pytorch training
107
+ if mask.size(2) > 0: # time2 > 0
108
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
109
+ # For last chunk, time2 might be larger than scores.size(-1)
110
+ mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
111
+ scores = scores.masked_fill(mask, -float('inf'))
112
+ attn = torch.softmax(scores, dim=-1).masked_fill(
113
+ mask, 0.0) # (batch, head, time1, time2)
114
+ # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
115
+ # 1. onnx(16/-1, -1/-1, 16/0)
116
+ # 2. jit (16/-1, -1/-1, 16/0, 16/4)
117
+ else:
118
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
119
+
120
+ p_attn = self.dropout(attn)
121
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
122
+ x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
123
+ self.h * self.d_k)
124
+ ) # (batch, time1, d_model)
125
+
126
+ return self.linear_out(x) # (batch, time1, d_model)
127
+
128
+ def forward(
129
+ self,
130
+ query: torch.Tensor,
131
+ key: torch.Tensor,
132
+ value: torch.Tensor,
133
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
134
+ pos_emb: torch.Tensor = torch.empty(0),
135
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
136
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
137
+ """Compute scaled dot product attention.
138
+
139
+ Args:
140
+ query (torch.Tensor): Query tensor (#batch, time1, size).
141
+ key (torch.Tensor): Key tensor (#batch, time2, size).
142
+ value (torch.Tensor): Value tensor (#batch, time2, size).
143
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
144
+ (#batch, time1, time2).
145
+ 1.When applying cross attention between decoder and encoder,
146
+ the batch padding mask for input is in (#batch, 1, T) shape.
147
+ 2.When applying self attention of encoder,
148
+ the mask is in (#batch, T, T) shape.
149
+ 3.When applying self attention of decoder,
150
+ the mask is in (#batch, L, L) shape.
151
+ 4.If different positions in the decoder see different blocks
152
+ of the encoder, such as Mocha, the passed-in mask could be
153
+ in (#batch, L, T) shape.
154
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
155
+ where `cache_t == chunk_size * num_decoding_left_chunks`
156
+ and `head * d_k == size`
157
+
158
+
159
+ Returns:
160
+ torch.Tensor: Output tensor (#batch, time1, d_model).
161
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
162
+ where `cache_t == chunk_size * num_decoding_left_chunks`
163
+ and `head * d_k == size`
164
+
165
+ """
166
+ q, k, v = self.forward_qkv(query, key, value)
167
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
168
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
169
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
170
+
171
+ # NOTE(xcsong):
172
+ # when export onnx model, for 1st chunk, we feed
173
+ # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
174
+ # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
175
+ # In all modes, `if cache.size(0) > 0` will always be `True`
176
+ # and we will always do splitting and
177
+ # concatenation (this will simplify onnx export). Note that
178
+ # it's OK to concat & split zero-shaped tensors(see code below).
179
+ # when export jit model, for 1st chunk, we always feed
180
+ # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
181
+ # >>> a = torch.ones((1, 2, 0, 4))
182
+ # >>> b = torch.ones((1, 2, 3, 4))
183
+ # >>> c = torch.cat((a, b), dim=2)
184
+ # >>> torch.equal(b, c) # True
185
+ # >>> d = torch.split(a, 2, dim=-1)
186
+ # >>> torch.equal(d[0], d[1]) # True
187
+ if cache.size(0) > 0:
188
+ key_cache, value_cache = torch.split(cache,
189
+ cache.size(-1) // 2,
190
+ dim=-1)
191
+ k = torch.cat([key_cache, k], dim=2)
192
+ v = torch.cat([value_cache, v], dim=2)
193
+ # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
194
+ # non-trivial to calculate `next_cache_start` here.
195
+ new_cache = torch.cat((k, v), dim=-1)
196
+
197
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
198
+ return self.forward_attention(v, scores, mask), new_cache
199
+
200
+ def inference(
201
+ self,
202
+ query: torch.Tensor,
203
+ key: torch.Tensor,
204
+ value: torch.Tensor,
205
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
206
+ pos_emb: torch.Tensor = torch.empty(0),
207
+ cache_offset: torch.Tensor = None,
208
+ is_infer_short: bool = False,
209
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
210
+ """Compute scaled dot product attention.
211
+
212
+ Args:
213
+ query (torch.Tensor): Query tensor (#batch, time1, size).
214
+ key (torch.Tensor): Key tensor (#batch, time2, size).
215
+ value (torch.Tensor): Value tensor (#batch, time2, size).
216
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
217
+ (#batch, time1, time2).
218
+ 1.When applying cross attention between decoder and encoder,
219
+ the batch padding mask for input is in (#batch, 1, T) shape.
220
+ 2.When applying self attention of encoder,
221
+ the mask is in (#batch, T, T) shape.
222
+ 3.When applying self attention of decoder,
223
+ the mask is in (#batch, L, L) shape.
224
+ 4.If different positions in the decoder see different blocks
225
+ of the encoder, such as Mocha, the passed-in mask could be
226
+ in (#batch, L, T) shape.
227
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
228
+ where `cache_t == chunk_size * num_decoding_left_chunks`
229
+ and `head * d_k == size`
230
+
231
+
232
+ Returns:
233
+ torch.Tensor: Output tensor (#batch, time1, d_model).
234
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
235
+ where `cache_t == chunk_size * num_decoding_left_chunks`
236
+ and `head * d_k == size`
237
+
238
+ """
239
+ q, k, v = self.forward_qkv(query, key, value)
240
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
241
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
242
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
243
+
244
+ if self.kv_cache is not None:
245
+ k, v = self.kv_cache.update(cache_offset, k, v, is_infer_short)
246
+
247
+ assert mask.dtype == torch.bool
248
+ mask = mask.unsqueeze(1).eq(False) * torch.finfo(q.dtype).min
249
+
250
+ output = torch.nn.functional.scaled_dot_product_attention(
251
+ q,
252
+ k,
253
+ v,
254
+ attn_mask=mask,
255
+ dropout_p=self.dropout_rate,
256
+ scale=1 / math.sqrt(self.d_k),
257
+ )
258
+ output = (output.transpose(1, 2).contiguous().view(
259
+ query.size(0), -1,
260
+ self.h * self.d_k)) # (batch, time1, d_model)
261
+ return self.linear_out(output)
262
+
263
+
264
+ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
265
+ """Multi-Head Attention layer with relative position encoding.
266
+ Paper: https://arxiv.org/abs/1901.02860
267
+ Args:
268
+ n_head (int): The number of heads.
269
+ n_feat (int): The number of features.
270
+ dropout_rate (float): Dropout rate.
271
+ """
272
+
273
+ def __init__(self,
274
+ n_head: int,
275
+ n_feat: int,
276
+ dropout_rate: float,
277
+ key_bias: bool = True):
278
+ """Construct an RelPositionMultiHeadedAttention object."""
279
+ super().__init__(n_head, n_feat, dropout_rate, key_bias)
280
+ # linear transformation for positional encoding
281
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
282
+ # these two learnable bias are used in matrix c and matrix d
283
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
284
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
285
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
286
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
287
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
288
+
289
+ def rel_shift(self, x):
290
+ """Compute relative positional encoding.
291
+
292
+ Args:
293
+ x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
294
+ time1 means the length of query vector.
295
+
296
+ Returns:
297
+ torch.Tensor: Output tensor.
298
+
299
+ """
300
+ zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
301
+ x_padded = torch.cat([zero_pad, x], dim=-1)
302
+
303
+ x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
304
+ x = x_padded[:, :, 1:].view_as(x)[
305
+ :, :, :, : x.size(-1) // 2 + 1
306
+ ] # only keep the positions from 0 to time2
307
+ return x
308
+
309
+ def forward(
310
+ self,
311
+ query: torch.Tensor,
312
+ key: torch.Tensor,
313
+ value: torch.Tensor,
314
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
315
+ pos_emb: torch.Tensor = torch.empty(0),
316
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
317
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
318
+ """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
319
+ Args:
320
+ query (torch.Tensor): Query tensor (#batch, time1, size).
321
+ key (torch.Tensor): Key tensor (#batch, time2, size).
322
+ value (torch.Tensor): Value tensor (#batch, time2, size).
323
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
324
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
325
+ pos_emb (torch.Tensor): Positional embedding tensor
326
+ (#batch, time2, size).
327
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
328
+ where `cache_t == chunk_size * num_decoding_left_chunks`
329
+ and `head * d_k == size`
330
+ Returns:
331
+ torch.Tensor: Output tensor (#batch, time1, d_model).
332
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
333
+ where `cache_t == chunk_size * num_decoding_left_chunks`
334
+ and `head * d_k == size`
335
+ """
336
+ q, k, v = self.forward_qkv(query, key, value)
337
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
338
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
339
+
340
+ # NOTE(xcsong):
341
+ # when export onnx model, for 1st chunk, we feed
342
+ # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
343
+ # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
344
+ # In all modes, `if cache.size(0) > 0` will always be `True`
345
+ # and we will always do splitting and
346
+ # concatenation (this will simplify onnx export). Note that
347
+ # it's OK to concat & split zero-shaped tensors(see code below).
348
+ # when export jit model, for 1st chunk, we always feed
349
+ # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
350
+ # >>> a = torch.ones((1, 2, 0, 4))
351
+ # >>> b = torch.ones((1, 2, 3, 4))
352
+ # >>> c = torch.cat((a, b), dim=2)
353
+ # >>> torch.equal(b, c) # True
354
+ # >>> d = torch.split(a, 2, dim=-1)
355
+ # >>> torch.equal(d[0], d[1]) # True
356
+ if cache.size(0) > 0:
357
+ key_cache, value_cache = torch.split(cache,
358
+ cache.size(-1) // 2,
359
+ dim=-1)
360
+ k = torch.cat([key_cache, k], dim=2)
361
+ v = torch.cat([value_cache, v], dim=2)
362
+ # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
363
+ # non-trivial to calculate `next_cache_start` here.
364
+ new_cache = torch.cat((k, v), dim=-1)
365
+
366
+ n_batch_pos = pos_emb.size(0)
367
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
368
+ p = p.transpose(1, 2) # (batch, head, time1, d_k)
369
+
370
+ # (batch, head, time1, d_k)
371
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
372
+ # (batch, head, time1, d_k)
373
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
374
+
375
+ # compute attention score
376
+ # first compute matrix a and matrix c
377
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
378
+ # (batch, head, time1, time2)
379
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
380
+
381
+ # compute matrix b and matrix d
382
+ # (batch, head, time1, time2)
383
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
384
+ # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
385
+ if matrix_ac.shape != matrix_bd.shape:
386
+ matrix_bd = self.rel_shift(matrix_bd)
387
+
388
+ scores = (matrix_ac + matrix_bd) / math.sqrt(
389
+ self.d_k) # (batch, head, time1, time2)
390
+
391
+ return self.forward_attention(v, scores, mask), new_cache
392
+
393
+ def inference(
394
+ self,
395
+ query: torch.Tensor,
396
+ key: torch.Tensor,
397
+ value: torch.Tensor,
398
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
399
+ pos_emb: torch.Tensor = torch.empty(0),
400
+ cache_offset: torch.Tensor = None,
401
+ is_infer_short: bool = False,
402
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
403
+ """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
404
+ Args:
405
+ query (torch.Tensor): Query tensor (#batch, time1, size).
406
+ key (torch.Tensor): Key tensor (#batch, time2, size).
407
+ value (torch.Tensor): Value tensor (#batch, time2, size).
408
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
409
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
410
+ pos_emb (torch.Tensor): Positional embedding tensor
411
+ (#batch, time2, size).
412
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
413
+ where `cache_t == chunk_size * num_decoding_left_chunks`
414
+ and `head * d_k == size`
415
+ Returns:
416
+ torch.Tensor: Output tensor (#batch, time1, d_model).
417
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
418
+ where `cache_t == chunk_size * num_decoding_left_chunks`
419
+ and `head * d_k == size`
420
+ """
421
+ q, k, v = self.forward_qkv(query, key, value)
422
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
423
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
424
+
425
+ if self.kv_cache is not None:
426
+ k, v = self.kv_cache.update(cache_offset, k, v, is_infer_short)
427
+
428
+ n_batch_pos = pos_emb.size(0)
429
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
430
+ p = p.transpose(1, 2) # (batch, head, time1, d_k)
431
+
432
+ # (batch, head, time1, d_k)
433
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
434
+ # (batch, head, time1, d_k)
435
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
436
+
437
+ # compute matrix b and matrix d
438
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
439
+ # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
440
+ matrix_bd = self.rel_shift(matrix_bd)
441
+
442
+ assert mask.dtype == torch.bool
443
+ # mask = (mask.unsqueeze(1).eq(False) * torch.finfo(k.dtype).min).to(matrix_bd.dtype)
444
+ mask = mask.unsqueeze(1).eq(False)
445
+ mask = (matrix_bd / math.sqrt(self.d_k)).masked_fill(mask, torch.tensor(-float('inf'), dtype=matrix_bd.dtype))
452
+ output = torch.nn.functional.scaled_dot_product_attention(
453
+ q_with_bias_u,
454
+ k,
455
+ v,
456
+ attn_mask=mask,
457
+ dropout_p=self.dropout_rate,
458
+ scale=1 / math.sqrt(self.d_k),
459
+ )
460
+
461
+ output = (output.transpose(1, 2).contiguous().view(
462
+ query.size(0), -1, self.h * self.d_k)) # (batch, time1, d_model)
463
+ return self.linear_out(output)
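The plain `MultiHeadedAttention.forward` above returns the concatenated key/value tensor as `new_cache`, so a decoder can feed it back on the next step and attend over every previous position, while `inference` instead relies on an externally attached `kv_cache` plus `torch.nn.functional.scaled_dot_product_attention`. A minimal sketch of the cache round-trip (not part of the commit; assumes the class is imported from `audio_detokenizer/transformer/attention.py`):

```python
# Sketch: incremental self-attention using the cache returned by forward().
import torch

attn = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.0)
x1 = torch.randn(1, 10, 256)
out1, cache = attn(x1, x1, x1)                  # default empty mask and cache
x2 = torch.randn(1, 1, 256)
out2, cache = attn(x2, x2, x2, cache=cache)     # keys/values now cover 11 positions
print(out2.shape, cache.shape)                  # (1, 1, 256) and (1, 4, 11, 128)
```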
audio_detokenizer/transformer/convolution.py ADDED
@@ -0,0 +1,145 @@
1
+ # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
2
+ # 2024 Alibaba Inc (Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """ConvolutionModule definition."""
17
+
18
+ from typing import Tuple
19
+
20
+ import torch
21
+ from torch import nn
22
+
23
+
24
+ class ConvolutionModule(nn.Module):
25
+ """ConvolutionModule in Conformer model."""
26
+
27
+ def __init__(self,
28
+ channels: int,
29
+ kernel_size: int = 15,
30
+ activation: nn.Module = nn.ReLU(),
31
+ norm: str = "batch_norm",
32
+ causal: bool = False,
33
+ bias: bool = True):
34
+ """Construct an ConvolutionModule object.
35
+ Args:
36
+ channels (int): The number of channels of conv layers.
37
+ kernel_size (int): Kernel size of conv layers.
38
+ causal (bool): Whether to use causal convolution or not
39
+ """
40
+ super().__init__()
41
+
42
+ self.pointwise_conv1 = nn.Conv1d(
43
+ channels,
44
+ 2 * channels,
45
+ kernel_size=1,
46
+ stride=1,
47
+ padding=0,
48
+ bias=bias,
49
+ )
50
+ # self.lorder is used to distinguish if it's a causal convolution,
51
+ # if self.lorder > 0: it's a causal convolution, the input will be
52
+ # padded with self.lorder frames on the left in forward.
53
+ # else: it's a symmetrical convolution
54
+ if causal:
55
+ padding = 0
56
+ self.lorder = kernel_size - 1
57
+ else:
58
+ # kernel_size should be an odd number for non-causal convolution
59
+ assert (kernel_size - 1) % 2 == 0
60
+ padding = (kernel_size - 1) // 2
61
+ self.lorder = 0
62
+ self.depthwise_conv = nn.Conv1d(
63
+ channels,
64
+ channels,
65
+ kernel_size,
66
+ stride=1,
67
+ padding=padding,
68
+ groups=channels,
69
+ bias=bias,
70
+ )
71
+
72
+ assert norm in ['batch_norm', 'layer_norm']
73
+ if norm == "batch_norm":
74
+ self.use_layer_norm = False
75
+ self.norm = nn.BatchNorm1d(channels)
76
+ else:
77
+ self.use_layer_norm = True
78
+ self.norm = nn.LayerNorm(channels)
79
+
80
+ self.pointwise_conv2 = nn.Conv1d(
81
+ channels,
82
+ channels,
83
+ kernel_size=1,
84
+ stride=1,
85
+ padding=0,
86
+ bias=bias,
87
+ )
88
+ self.activation = activation
89
+
90
+ def forward(
91
+ self,
92
+ x: torch.Tensor,
93
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
94
+ cache: torch.Tensor = torch.zeros((0, 0, 0)),
95
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
96
+ """Compute convolution module.
97
+ Args:
98
+ x (torch.Tensor): Input tensor (#batch, time, channels).
99
+ mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
100
+ (0, 0, 0) means fake mask.
101
+ cache (torch.Tensor): left context cache, it is only
102
+ used in causal convolution (#batch, channels, cache_t),
103
+ (0, 0, 0) means fake cache.
104
+ Returns:
105
+ torch.Tensor: Output tensor (#batch, time, channels).
106
+ """
107
+ # exchange the temporal dimension and the feature dimension
108
+ x = x.transpose(1, 2) # (#batch, channels, time)
109
+
110
+ # mask batch padding
111
+ if mask_pad.size(2) > 0: # time > 0
112
+ x.masked_fill_(~mask_pad, 0.0)
113
+
114
+ if self.lorder > 0:
115
+ if cache.size(2) == 0: # cache_t == 0
116
+ x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
117
+ else:
118
+ assert cache.size(0) == x.size(0) # equal batch
119
+ assert cache.size(1) == x.size(1) # equal channel
120
+ x = torch.cat((cache, x), dim=2)
121
+ assert (x.size(2) > self.lorder)
122
+ new_cache = x[:, :, -self.lorder:]
123
+ else:
124
+ # It's better we just return None if no cache is required,
125
+ # However, for JIT export, here we just fake one tensor instead of
126
+ # None.
127
+ new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
128
+
129
+ # GLU mechanism
130
+ x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
131
+ x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
132
+
133
+ # 1D Depthwise Conv
134
+ x = self.depthwise_conv(x)
135
+ if self.use_layer_norm:
136
+ x = x.transpose(1, 2)
137
+ x = self.activation(self.norm(x))
138
+ if self.use_layer_norm:
139
+ x = x.transpose(1, 2)
140
+ x = self.pointwise_conv2(x)
141
+ # mask batch padding
142
+ if mask_pad.size(2) > 0: # time > 0
143
+ x.masked_fill_(~mask_pad, 0.0)
144
+
145
+ return x.transpose(1, 2), new_cache
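For reference, the causal branch above left-pads the input with `kernel_size - 1` frames (or with the previous chunk's cache) and hands the last `kernel_size - 1` frames back as `new_cache`, so the output always has the same number of frames as the input. A shape-check sketch (not part of the commit; assumes the class is imported from `audio_detokenizer/transformer/convolution.py`):

```python
# Sketch: causal ConvolutionModule on a first chunk with no cache yet.
import torch

conv = ConvolutionModule(channels=256, kernel_size=15, causal=True)
x = torch.randn(2, 50, 256)        # (batch, time, channels)
y, new_cache = conv(x)             # default fake mask and empty cache
print(y.shape, new_cache.shape)    # (2, 50, 256) and (2, 256, 14)
```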