priyankad199 committed on
Commit
f2e14a3
verified · 1 Parent(s): bf215a9

Upload 10 files

Files changed (10)
  1. README.md +166 -12
  2. app.py +48 -0
  3. audio.py +136 -0
  4. color_syncnet_train.py +279 -0
  5. hparams.py +101 -0
  6. hq_wav2lip_train.py +443 -0
  7. inference.py +280 -0
  8. preprocess.py +113 -0
  9. requirements.txt +8 -0
  10. wav2lip_train.py +374 -0
README.md CHANGED
@@ -1,12 +1,166 @@
1
- ---
2
- title: Wav2lip1
3
- emoji: 💻
4
- colorFrom: green
5
- colorTo: purple
6
- sdk: gradio
7
- sdk_version: 4.39.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ # **Wav2Lip**: *Accurately Lip-syncing Videos In The Wild*
2
+
3
+ ### Wav2Lip is hosted for free at [Sync Labs](https://synclabs.so/)
4
+
5
+ Are you looking to integrate this into a product? We have a turn-key hosted API with new and improved lip-syncing models here: https://synclabs.so/
6
+
7
+ For any other commercial / enterprise requests, please contact us at pavan@synclabs.so and prady@synclabs.so
8
+
9
+ To reach the authors directly, you can contact us at [email protected] or [email protected].
10
+
11
+ This code is part of the paper: _A Lip Sync Expert Is All You Need for Speech to Lip Generation In the Wild_ published at ACM Multimedia 2020.
12
+
13
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/a-lip-sync-expert-is-all-you-need-for-speech/lip-sync-on-lrs2)](https://paperswithcode.com/sota/lip-sync-on-lrs2?p=a-lip-sync-expert-is-all-you-need-for-speech)
14
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/a-lip-sync-expert-is-all-you-need-for-speech/lip-sync-on-lrs3)](https://paperswithcode.com/sota/lip-sync-on-lrs3?p=a-lip-sync-expert-is-all-you-need-for-speech)
15
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/a-lip-sync-expert-is-all-you-need-for-speech/lip-sync-on-lrw)](https://paperswithcode.com/sota/lip-sync-on-lrw?p=a-lip-sync-expert-is-all-you-need-for-speech)
16
+
17
+ |📑 Original Paper|📰 Project Page|🌀 Demo|⚡ Live Testing|📔 Colab Notebook|
18
+ |:-:|:-:|:-:|:-:|:-:|
19
+ [Paper](http://arxiv.org/abs/2008.10010) | [Project Page](http://cvit.iiit.ac.in/research/projects/cvit-projects/a-lip-sync-expert-is-all-you-need-for-speech-to-lip-generation-in-the-wild/) | [Demo Video](https://youtu.be/0fXaDCZNOJc) | [Interactive Demo](https://synclabs.so/) | [Colab Notebook](https://colab.research.google.com/drive/1tZpDWXz49W6wDcTprANRGLo2D_EbD5J8?usp=sharing) / [Updated Colab Notebook](https://colab.research.google.com/drive/1IjFW1cLevs6Ouyu4Yht4mnR4yeuMqO7Y#scrollTo=MH1m608OymLH)
20
+
21
+ ![Logo](https://drive.google.com/uc?export=view&id=1Wn0hPmpo4GRbCIJR8Tf20Akzdi1qjjG9)
22
+
23
+ ----------
24
+ **Highlights**
25
+ ----------
26
+ - Weights of the visual quality discriminator have been updated in the README!
27
+ - Lip-sync videos to any target speech with high accuracy :100:. Try our [interactive demo](https://synclabs.so/).
28
+ - :sparkles: Works for any identity, voice, and language. Also works for CGI faces and synthetic voices.
29
+ - Complete training code, inference code, and pretrained models are available :boom:
30
+ - Or, quick-start with the Google Colab Notebook: [Link](https://colab.research.google.com/drive/1tZpDWXz49W6wDcTprANRGLo2D_EbD5J8?usp=sharing). Checkpoints and samples are available in a Google Drive [folder](https://drive.google.com/drive/folders/1I-0dNLfFOSFwrfqjNa-SXuwaURHE5K4k?usp=sharing) as well. There is also a [tutorial video](https://www.youtube.com/watch?v=Ic0TBhfuOrA) on this, courtesy of [What Make Art](https://www.youtube.com/channel/UCmGXH-jy0o2CuhqtpxbaQgA). Also, thanks to [Eyal Gruss](https://eyalgruss.com), there is a more accessible [Google Colab notebook](https://j.mp/wav2lip) with more useful features. A tutorial Colab notebook is also available at this [link](https://colab.research.google.com/drive/1IjFW1cLevs6Ouyu4Yht4mnR4yeuMqO7Y#scrollTo=MH1m608OymLH).
31
+ - :fire: :fire: Several new, reliable evaluation benchmarks and metrics [[`evaluation/` folder of this repo]](https://github.com/Rudrabha/Wav2Lip/tree/master/evaluation) released. Instructions to calculate the metrics reported in the paper are also present.
32
+
33
+ --------
34
+ **Disclaimer**
35
+ --------
36
+ All results from this open-source code or our [demo website](https://bhaasha.iiit.ac.in/lipsync) should be used for research/academic/personal purposes only. As the models are trained on the <a href="http://www.robots.ox.ac.uk/~vgg/data/lip_reading/lrs2.html">LRS2 dataset</a>, any form of commercial use is strictly prohibited. For commercial requests, please contact us directly!
37
+
38
+ Prerequisites
39
+ -------------
40
+ - `Python 3.6`
41
+ - ffmpeg: `sudo apt-get install ffmpeg`
42
+ - Install necessary packages using `pip install -r requirements.txt`. Alternatively, instructions for using a Docker image are provided [here](https://gist.github.com/xenogenesi/e62d3d13dadbc164124c830e9c453668). Have a look at [this comment](https://github.com/Rudrabha/Wav2Lip/issues/131#issuecomment-725478562) and comment on [the gist](https://gist.github.com/xenogenesi/e62d3d13dadbc164124c830e9c453668) if you encounter any issues.
43
+ - Face detection [pre-trained model](https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth) should be downloaded to `face_detection/detection/sfd/s3fd.pth`. Alternative [link](https://iiitaphyd-my.sharepoint.com/:u:/g/personal/prajwal_k_research_iiit_ac_in/EZsy6qWuivtDnANIG73iHjIBjMSoojcIV0NULXV-yiuiIg?e=qTasa8) if the above does not work.
44
+
45
+ Getting the weights
46
+ ----------
47
+ | Model | Description | Link to the model |
48
+ | :-------------: | :---------------: | :---------------: |
49
+ | Wav2Lip | Highly accurate lip-sync | [Link](https://iiitaphyd-my.sharepoint.com/:u:/g/personal/radrabha_m_research_iiit_ac_in/Eb3LEzbfuKlJiR600lQWRxgBIY27JZg80f7V9jtMfbNDaQ?e=TBFBVW) |
50
+ | Wav2Lip + GAN | Slightly inferior lip-sync, but better visual quality | [Link](https://iiitaphyd-my.sharepoint.com/:u:/g/personal/radrabha_m_research_iiit_ac_in/EdjI7bZlgApMqsVoEUUXpLsBxqXbn5z8VTmoxp55YNDcIA?e=n9ljGW) |
51
+ | Expert Discriminator | Weights of the expert discriminator | [Link](https://iiitaphyd-my.sharepoint.com/:u:/g/personal/radrabha_m_research_iiit_ac_in/EQRvmiZg-HRAjvI6zqN9eTEBP74KefynCwPWVmF57l-AYA?e=ZRPHKP) |
52
+ | Visual Quality Discriminator | Weights of the visual disc trained in a GAN setup | [Link](https://iiitaphyd-my.sharepoint.com/:u:/g/personal/radrabha_m_research_iiit_ac_in/EQVqH88dTm1HjlK11eNba5gBbn15WMS0B0EZbDBttqrqkg?e=ic0ljo) |
53
+
54
+ Lip-syncing videos using the pre-trained models (Inference)
55
+ -------
56
+ You can lip-sync any video to any audio:
57
+ ```bash
58
+ python inference.py --checkpoint_path <ckpt> --face <video.mp4> --audio <an-audio-source>
59
+ ```
60
+ The result is saved (by default) in `results/result_voice.mp4`. You can specify it as an argument, similar to several other available options. The audio source can be any file supported by `FFMPEG` containing audio data: `*.wav`, `*.mp3` or even a video file, from which the code will automatically extract the audio.
61
+
62
+ ##### Tips for better results:
63
+ - Experiment with the `--pads` argument to adjust the detected face bounding box; this often leads to improved results. You might need to increase the bottom padding to include the chin region, e.g. `--pads 0 20 0 0`.
64
+ - If the mouth position looks dislocated or you see strange artifacts such as two mouths, the face detections may be over-smoothed. Use the `--nosmooth` argument and give it another try.
65
+ - Experiment with the `--resize_factor` argument to get a lower-resolution video. Why? The models were trained on faces at a lower resolution, so you might get better, more visually pleasing results for 720p videos than for 1080p videos (in many cases, the latter works well too).
66
+ - The Wav2Lip model without GAN usually needs more experimentation with the two arguments above to get the best results, and can sometimes give you a better result as well. A sketch combining these flags is shown after this list.
67
+
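+ For reference, here is a minimal Python sketch (mirroring the approach used by `app.py` in this upload) that wires the flags above into a call to `inference.py`. The checkpoint, input, and output paths are placeholders, not files shipped with the repo:
+ ```python
+ # Hypothetical helper, not part of the repo: builds and runs an inference.py
+ # command using the tip flags above. Adjust paths and values to your setup.
+ from subprocess import call
+
+ cmd = [
+     "python", "inference.py",
+     "--checkpoint_path", "checkpoints/wav2lip_gan.pth",
+     "--face", "input_video.mp4",
+     "--audio", "speech.wav",
+     "--pads", "0", "20", "0", "0",       # extra bottom padding to include the chin
+     "--resize_factor", "2",              # halve the input resolution
+     "--nosmooth",                        # disable temporal smoothing of face boxes
+     "--outfile", "results/result_voice.mp4",
+ ]
+ call(cmd)
+ ```
+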
68
+ Preparing LRS2 for training
69
+ ----------
70
+ Our models are trained on LRS2. See [here](#training-on-datasets-other-than-lrs2) for a few suggestions regarding training on other datasets.
71
+ ##### LRS2 dataset folder structure
72
+
73
+ ```
74
+ data_root (mvlrs_v1)
75
+ ├── main, pretrain (we use only main folder in this work)
76
+ | ├── list of folders
77
+ | │ ├── five-digit numbered video IDs ending with (.mp4)
78
+ ```
79
+
80
+ Place the LRS2 filelists (train, val, test) `.txt` files in the `filelists/` folder.
81
+
82
+ ##### Preprocess the dataset for fast training
83
+
84
+ ```bash
85
+ python preprocess.py --data_root data_root/main --preprocessed_root lrs2_preprocessed/
86
+ ```
87
+ Additional options, such as `batch_size` and the number of GPUs to use in parallel, can also be set.
88
+
89
+ ##### Preprocessed LRS2 folder structure
90
+ ```
91
+ preprocessed_root (lrs2_preprocessed)
92
+ ├── list of folders
93
+ | ├── Folders with five-digit numbered video IDs
94
+ | │ ├── *.jpg
95
+ | │ ├── audio.wav
96
+ ```
97
+
98
+ Train!
99
+ ----------
100
+ There are two major steps: (i) Train the expert lip-sync discriminator, (ii) Train the Wav2Lip model(s).
101
+
102
+ ##### Training the expert discriminator
103
+ You can download [the pre-trained weights](#getting-the-weights) if you want to skip this step. To train it:
104
+ ```bash
105
+ python color_syncnet_train.py --data_root lrs2_preprocessed/ --checkpoint_dir <folder_to_save_checkpoints>
106
+ ```
107
+ ##### Training the Wav2Lip models
108
+ You can either train the model without the additional visual quality discriminator (< 1 day of training) or use the discriminator (~2 days). For the former, run:
109
+ ```bash
110
+ python wav2lip_train.py --data_root lrs2_preprocessed/ --checkpoint_dir <folder_to_save_checkpoints> --syncnet_checkpoint_path <path_to_expert_disc_checkpoint>
111
+ ```
112
+
113
+ To train with the visual quality discriminator, you should run `hq_wav2lip_train.py` instead. The arguments for both files are similar. In both cases, you can resume training as well. Look at `python wav2lip_train.py --help` for more details. You can also set additional less commonly-used hyper-parameters at the bottom of the `hparams.py` file.
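+
+ As a rough illustration (not part of the upload), the hyper-parameters defined in `hparams.py` can be inspected and overridden programmatically through the `HParams` helper it exposes; the override value below is purely illustrative:
+ ```python
+ # Minimal sketch: read and tweak hparams before launching your training code.
+ from hparams import hparams
+
+ print(hparams.img_size)            # 96 (face crop size used for training)
+ print(hparams.syncnet_batch_size)  # 64 (expert discriminator batch size)
+
+ # set_hparam is defined in hparams.py; the value 8 here is only an example.
+ hparams.set_hparam('batch_size', 8)
+ ```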
114
+
115
+ Training on datasets other than LRS2
116
+ ------------------------------------
117
+ Training on other datasets might require modifications to the code. Please read the following before you raise an issue:
118
+
119
+ - You might not get good results by training/fine-tuning on a few minutes of a single speaker. This is a separate research problem, to which we do not have a solution yet. Thus, we would most likely not be able to resolve your issue.
120
+ - You must train the expert discriminator for your own dataset before training Wav2Lip.
121
+ - If you are using your own dataset downloaded from the web, it will in most cases need to be sync-corrected.
122
+ - Be mindful of the FPS of your dataset's videos; changing the FPS would require significant code changes.
123
+ - The expert discriminator's eval loss should go down to ~0.25 and the Wav2Lip eval sync loss should go down to ~0.2 to get good results.
124
+
125
+ When raising an issue on this topic, please let us know that you are aware of all these points.
126
+
127
+ We have an HD model trained on a dataset allowing commercial usage. The size of the generated face will be 192 x 288 in our new model.
128
+
129
+ Evaluation
130
+ ----------
131
+ Please check the `evaluation/` folder for the instructions.
132
+
133
+ License and Citation
134
+ ----------
135
+ This repository can only be used for personal/research/non-commercial purposes. For commercial requests, please contact us directly at [email protected] or [email protected]. We have a turn-key hosted API with new and improved lip-syncing models here: https://synclabs.so/
136
+ The size of the generated face will be 192 x 288 in our new models. Please cite the following paper if you use this repository:
137
+ ```
138
+ @inproceedings{10.1145/3394171.3413532,
139
+ author = {Prajwal, K R and Mukhopadhyay, Rudrabha and Namboodiri, Vinay P. and Jawahar, C.V.},
140
+ title = {A Lip Sync Expert Is All You Need for Speech to Lip Generation In the Wild},
141
+ year = {2020},
142
+ isbn = {9781450379885},
143
+ publisher = {Association for Computing Machinery},
144
+ address = {New York, NY, USA},
145
+ url = {https://doi.org/10.1145/3394171.3413532},
146
+ doi = {10.1145/3394171.3413532},
147
+ booktitle = {Proceedings of the 28th ACM International Conference on Multimedia},
148
+ pages = {484–492},
149
+ numpages = {9},
150
+ keywords = {lip sync, talking face generation, video generation},
151
+ location = {Seattle, WA, USA},
152
+ series = {MM '20}
153
+ }
154
+ ```
155
+
156
+
157
+ Acknowledgments
158
+ ----------
159
+ Parts of the code structure are inspired by this [TTS repository](https://github.com/r9y9/deepvoice3_pytorch). We thank the author for this wonderful code. The code for Face Detection has been taken from the [face_alignment](https://github.com/1adrianb/face-alignment) repository. We thank the authors for releasing their code and models. We thank [zabique](https://github.com/zabique) for the tutorial Colab notebook.
160
+
161
+ ## Acknowledgements
162
+
163
+ - [Awesome Readme Templates](https://awesomeopensource.com/project/elangosundar/awesome-README-templates)
164
+ - [Awesome README](https://github.com/matiassingers/awesome-readme)
165
+ - [How to write a Good readme](https://bulldogjob.com/news/449-how-to-write-a-good-readme-for-your-github-project)
166
+
app.py ADDED
@@ -0,0 +1,48 @@
1
+
2
+ import gradio as gr
3
+ import subprocess
4
+ from subprocess import call
5
+
6
+ with gr.Blocks() as ui:
7
+ with gr.Row():
8
+ video = gr.File(label="Video or Image")
9
+ audio = gr.File(label="Audio")
10
+ with gr.Column():
11
+ checkpoint = gr.Radio(["wav2lip", "wav2lip_gan"], label="Checkpoint")
12
+ no_smooth = gr.Checkbox(label="No Smooth")
13
+ resize_factor = gr.Slider(minimum=1, maximum=4, step=1, label="Resize Factor")
14
+ with gr.Row():
15
+ with gr.Column():
16
+ pad_top = gr.Slider(minimum=0, maximum=50, step=1, value=0, label="Pad Top")
17
+ pad_bottom = gr.Slider(minimum=0, maximum=50, step=1, value=10, label="Pad Bottom (Often increasing this to 20 allows chin to be included)")
18
+ pad_left = gr.Slider(minimum=0, maximum=50, step=1, value=0, label="Pad Left")
19
+ pad_right = gr.Slider(minimum=0, maximum=50, step=1, value=0, label="Pad Right")
20
+ generate_btn = gr.Button("Generate")
21
+ with gr.Column():
22
+ result = gr.Video()
23
+
24
+ def generate(video, audio, checkpoint, no_smooth, resize_factor, pad_top, pad_bottom, pad_left, pad_right):
25
+ if video is None or audio is None or checkpoint is None:
26
+ return
27
+
28
+ smooth_args = ["--nosmooth"] if no_smooth else []
29
+
30
+
31
+ cmd = [
+ "python",
+ "inference.py",
+ "--checkpoint_path", f"checkpoints/{checkpoint}.pth",
+ "--face", video.name,
+ "--audio", audio.name,
+ "--outfile", "results/output.mp4",
+ "--pads", str(int(pad_top)), str(int(pad_bottom)), str(int(pad_left)), str(int(pad_right)),
+ "--resize_factor", str(int(resize_factor)),
+ ] + smooth_args
39
+
40
+ call(cmd)
41
+ return "results/output.mp4"
42
+
43
+ generate_btn.click(
44
+ generate,
45
+ [video, audio, checkpoint, no_smooth, resize_factor, pad_top, pad_bottom, pad_left, pad_right],
46
+ result)
47
+
48
+ ui.queue().launch(share=True, debug=True)
audio.py ADDED
@@ -0,0 +1,136 @@
1
+ import librosa
2
+ import librosa.filters
3
+ import numpy as np
4
+ # import tensorflow as tf
5
+ from scipy import signal
6
+ from scipy.io import wavfile
7
+ from hparams import hparams as hp
8
+
9
+ def load_wav(path, sr):
10
+ return librosa.core.load(path, sr=sr)[0]
11
+
12
+ def save_wav(wav, path, sr):
13
+ wav *= 32767 / max(0.01, np.max(np.abs(wav)))
14
+ #proposed by @dsmiller
15
+ wavfile.write(path, sr, wav.astype(np.int16))
16
+
17
+ def save_wavenet_wav(wav, path, sr):
18
+ librosa.output.write_wav(path, wav, sr=sr)
19
+
20
+ def preemphasis(wav, k, preemphasize=True):
21
+ if preemphasize:
22
+ return signal.lfilter([1, -k], [1], wav)
23
+ return wav
24
+
25
+ def inv_preemphasis(wav, k, inv_preemphasize=True):
26
+ if inv_preemphasize:
27
+ return signal.lfilter([1], [1, -k], wav)
28
+ return wav
29
+
30
+ def get_hop_size():
31
+ hop_size = hp.hop_size
32
+ if hop_size is None:
33
+ assert hp.frame_shift_ms is not None
34
+ hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate)
35
+ return hop_size
36
+
37
+ def linearspectrogram(wav):
38
+ D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
39
+ S = _amp_to_db(np.abs(D)) - hp.ref_level_db
40
+
41
+ if hp.signal_normalization:
42
+ return _normalize(S)
43
+ return S
44
+
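+ # melspectrogram returns a (num_mels x frames) array; the training/inference code consumes its transpose (frames x 80).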
45
+ def melspectrogram(wav):
46
+ D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
47
+ S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db
48
+
49
+ if hp.signal_normalization:
50
+ return _normalize(S)
51
+ return S
52
+
53
+ def _lws_processor():
54
+ import lws
55
+ return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech")
56
+
57
+ def _stft(y):
58
+ if hp.use_lws:
59
+ return _lws_processor(hp).stft(y).T
60
+ else:
61
+ return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size)
62
+
63
+ ##########################################################
64
+ #Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
65
+ def num_frames(length, fsize, fshift):
66
+ """Compute number of time frames of spectrogram
67
+ """
68
+ pad = (fsize - fshift)
69
+ if length % fshift == 0:
70
+ M = (length + pad * 2 - fsize) // fshift + 1
71
+ else:
72
+ M = (length + pad * 2 - fsize) // fshift + 2
73
+ return M
74
+
75
+
76
+ def pad_lr(x, fsize, fshift):
77
+ """Compute left and right padding
78
+ """
79
+ M = num_frames(len(x), fsize, fshift)
80
+ pad = (fsize - fshift)
81
+ T = len(x) + 2 * pad
82
+ r = (M - 1) * fshift + fsize - T
83
+ return pad, pad + r
84
+ ##########################################################
85
+ #Librosa correct padding
86
+ def librosa_pad_lr(x, fsize, fshift):
87
+ return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
88
+
89
+ # Conversions
90
+ _mel_basis = None
91
+
92
+ def _linear_to_mel(spectogram):
93
+ global _mel_basis
94
+ if _mel_basis is None:
95
+ _mel_basis = _build_mel_basis()
96
+ return np.dot(_mel_basis, spectogram)
97
+
98
+ def _build_mel_basis():
99
+ assert hp.fmax <= hp.sample_rate // 2
100
+ return librosa.filters.mel(sr=hp.sample_rate, n_fft= hp.n_fft, n_mels=hp.num_mels,
101
+ fmin=hp.fmin, fmax=hp.fmax)
102
+
103
+ def _amp_to_db(x):
104
+ min_level = np.exp(hp.min_level_db / 20 * np.log(10))
105
+ return 20 * np.log10(np.maximum(min_level, x))
106
+
107
+ def _db_to_amp(x):
108
+ return np.power(10.0, (x) * 0.05)
109
+
110
+ def _normalize(S):
111
+ if hp.allow_clipping_in_normalization:
112
+ if hp.symmetric_mels:
113
+ return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value,
114
+ -hp.max_abs_value, hp.max_abs_value)
115
+ else:
116
+ return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value)
117
+
118
+ assert S.max() <= 0 and S.min() - hp.min_level_db >= 0
119
+ if hp.symmetric_mels:
120
+ return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value
121
+ else:
122
+ return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))
123
+
124
+ def _denormalize(D):
125
+ if hp.allow_clipping_in_normalization:
126
+ if hp.symmetric_mels:
127
+ return (((np.clip(D, -hp.max_abs_value,
128
+ hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value))
129
+ + hp.min_level_db)
130
+ else:
131
+ return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
132
+
133
+ if hp.symmetric_mels:
134
+ return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db)
135
+ else:
136
+ return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
color_syncnet_train.py ADDED
@@ -0,0 +1,279 @@
1
+ from os.path import dirname, join, basename, isfile
2
+ from tqdm import tqdm
3
+
4
+ from models import SyncNet_color as SyncNet
5
+ import audio
6
+
7
+ import torch
8
+ from torch import nn
9
+ from torch import optim
10
+ import torch.backends.cudnn as cudnn
11
+ from torch.utils import data as data_utils
12
+ import numpy as np
13
+
14
+ from glob import glob
15
+
16
+ import os, random, cv2, argparse
17
+ from hparams import hparams, get_image_list
18
+
19
+ parser = argparse.ArgumentParser(description='Code to train the expert lip-sync discriminator')
20
+
21
+ parser.add_argument("--data_root", help="Root folder of the preprocessed LRS2 dataset", required=True)
22
+
23
+ parser.add_argument('--checkpoint_dir', help='Save checkpoints to this directory', required=True, type=str)
24
+ parser.add_argument('--checkpoint_path', help='Resumed from this checkpoint', default=None, type=str)
25
+
26
+ args = parser.parse_args()
27
+
28
+
29
+ global_step = 0
30
+ global_epoch = 0
31
+ use_cuda = torch.cuda.is_available()
32
+ print('use_cuda: {}'.format(use_cuda))
33
+
34
+ syncnet_T = 5
35
+ syncnet_mel_step_size = 16
36
+
37
+ class Dataset(object):
38
+ def __init__(self, split):
39
+ self.all_videos = get_image_list(args.data_root, split)
40
+
41
+ def get_frame_id(self, frame):
42
+ return int(basename(frame).split('.')[0])
43
+
44
+ def get_window(self, start_frame):
45
+ start_id = self.get_frame_id(start_frame)
46
+ vidname = dirname(start_frame)
47
+
48
+ window_fnames = []
49
+ for frame_id in range(start_id, start_id + syncnet_T):
50
+ frame = join(vidname, '{}.jpg'.format(frame_id))
51
+ if not isfile(frame):
52
+ return None
53
+ window_fnames.append(frame)
54
+ return window_fnames
55
+
56
+ def crop_audio_window(self, spec, start_frame):
57
+ # num_frames = (T x hop_size * fps) / sample_rate
58
+ start_frame_num = self.get_frame_id(start_frame)
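+ # 80 mel frames per second of audio: sample_rate / hop_size = 16000 / 200 (see hparams.py)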
59
+ start_idx = int(80. * (start_frame_num / float(hparams.fps)))
60
+
61
+ end_idx = start_idx + syncnet_mel_step_size
62
+
63
+ return spec[start_idx : end_idx, :]
64
+
65
+
66
+ def __len__(self):
67
+ return len(self.all_videos)
68
+
69
+ def __getitem__(self, idx):
70
+ while 1:
71
+ idx = random.randint(0, len(self.all_videos) - 1)
72
+ vidname = self.all_videos[idx]
73
+
74
+ img_names = list(glob(join(vidname, '*.jpg')))
75
+ if len(img_names) <= 3 * syncnet_T:
76
+ continue
77
+ img_name = random.choice(img_names)
78
+ wrong_img_name = random.choice(img_names)
79
+ while wrong_img_name == img_name:
80
+ wrong_img_name = random.choice(img_names)
81
+
82
+ if random.choice([True, False]):
83
+ y = torch.ones(1).float()
84
+ chosen = img_name
85
+ else:
86
+ y = torch.zeros(1).float()
87
+ chosen = wrong_img_name
88
+
89
+ window_fnames = self.get_window(chosen)
90
+ if window_fnames is None:
91
+ continue
92
+
93
+ window = []
94
+ all_read = True
95
+ for fname in window_fnames:
96
+ img = cv2.imread(fname)
97
+ if img is None:
98
+ all_read = False
99
+ break
100
+ try:
101
+ img = cv2.resize(img, (hparams.img_size, hparams.img_size))
102
+ except Exception as e:
103
+ all_read = False
104
+ break
105
+
106
+ window.append(img)
107
+
108
+ if not all_read: continue
109
+
110
+ try:
111
+ wavpath = join(vidname, "audio.wav")
112
+ wav = audio.load_wav(wavpath, hparams.sample_rate)
113
+
114
+ orig_mel = audio.melspectrogram(wav).T
115
+ except Exception as e:
116
+ continue
117
+
118
+ mel = self.crop_audio_window(orig_mel.copy(), img_name)
119
+
120
+ if (mel.shape[0] != syncnet_mel_step_size):
121
+ continue
122
+
123
+ # H x W x 3 * T
124
+ x = np.concatenate(window, axis=2) / 255.
125
+ x = x.transpose(2, 0, 1)
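+ # keep only the lower half of each face crop; SyncNet judges sync from the mouth region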
126
+ x = x[:, x.shape[1]//2:]
127
+
128
+ x = torch.FloatTensor(x)
129
+ mel = torch.FloatTensor(mel.T).unsqueeze(0)
130
+
131
+ return x, mel, y
132
+
133
+ logloss = nn.BCELoss()
134
+ def cosine_loss(a, v, y):
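+ # BCE on the audio-visual cosine similarity: y = 1 for a synced window, y = 0 for a mismatched one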
135
+ d = nn.functional.cosine_similarity(a, v)
136
+ loss = logloss(d.unsqueeze(1), y)
137
+
138
+ return loss
139
+
140
+ def train(device, model, train_data_loader, test_data_loader, optimizer,
141
+ checkpoint_dir=None, checkpoint_interval=None, nepochs=None):
142
+
143
+ global global_step, global_epoch
144
+ resumed_step = global_step
145
+
146
+ while global_epoch < nepochs:
147
+ running_loss = 0.
148
+ prog_bar = tqdm(enumerate(train_data_loader))
149
+ for step, (x, mel, y) in prog_bar:
150
+ model.train()
151
+ optimizer.zero_grad()
152
+
153
+ # Transform data to CUDA device
154
+ x = x.to(device)
155
+
156
+ mel = mel.to(device)
157
+
158
+ a, v = model(mel, x)
159
+ y = y.to(device)
160
+
161
+ loss = cosine_loss(a, v, y)
162
+ loss.backward()
163
+ optimizer.step()
164
+
165
+ global_step += 1
166
+ cur_session_steps = global_step - resumed_step
167
+ running_loss += loss.item()
168
+
169
+ if global_step == 1 or global_step % checkpoint_interval == 0:
170
+ save_checkpoint(
171
+ model, optimizer, global_step, checkpoint_dir, global_epoch)
172
+
173
+ if global_step % hparams.syncnet_eval_interval == 0:
174
+ with torch.no_grad():
175
+ eval_model(test_data_loader, global_step, device, model, checkpoint_dir)
176
+
177
+ prog_bar.set_description('Loss: {}'.format(running_loss / (step + 1)))
178
+
179
+ global_epoch += 1
180
+
181
+ def eval_model(test_data_loader, global_step, device, model, checkpoint_dir):
182
+ eval_steps = 1400
183
+ print('Evaluating for {} steps'.format(eval_steps))
184
+ losses = []
185
+ while 1:
186
+ for step, (x, mel, y) in enumerate(test_data_loader):
187
+
188
+ model.eval()
189
+
190
+ # Transform data to CUDA device
191
+ x = x.to(device)
192
+
193
+ mel = mel.to(device)
194
+
195
+ a, v = model(mel, x)
196
+ y = y.to(device)
197
+
198
+ loss = cosine_loss(a, v, y)
199
+ losses.append(loss.item())
200
+
201
+ if step > eval_steps: break
202
+
203
+ averaged_loss = sum(losses) / len(losses)
204
+ print(averaged_loss)
205
+
206
+ return
207
+
208
+ def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch):
209
+
210
+ checkpoint_path = join(
211
+ checkpoint_dir, "checkpoint_step{:09d}.pth".format(global_step))
212
+ optimizer_state = optimizer.state_dict() if hparams.save_optimizer_state else None
213
+ torch.save({
214
+ "state_dict": model.state_dict(),
215
+ "optimizer": optimizer_state,
216
+ "global_step": step,
217
+ "global_epoch": epoch,
218
+ }, checkpoint_path)
219
+ print("Saved checkpoint:", checkpoint_path)
220
+
221
+ def _load(checkpoint_path):
222
+ if use_cuda:
223
+ checkpoint = torch.load(checkpoint_path)
224
+ else:
225
+ checkpoint = torch.load(checkpoint_path,
226
+ map_location=lambda storage, loc: storage)
227
+ return checkpoint
228
+
229
+ def load_checkpoint(path, model, optimizer, reset_optimizer=False):
230
+ global global_step
231
+ global global_epoch
232
+
233
+ print("Load checkpoint from: {}".format(path))
234
+ checkpoint = _load(path)
235
+ model.load_state_dict(checkpoint["state_dict"])
236
+ if not reset_optimizer:
237
+ optimizer_state = checkpoint["optimizer"]
238
+ if optimizer_state is not None:
239
+ print("Load optimizer state from {}".format(path))
240
+ optimizer.load_state_dict(checkpoint["optimizer"])
241
+ global_step = checkpoint["global_step"]
242
+ global_epoch = checkpoint["global_epoch"]
243
+
244
+ return model
245
+
246
+ if __name__ == "__main__":
247
+ checkpoint_dir = args.checkpoint_dir
248
+ checkpoint_path = args.checkpoint_path
249
+
250
+ if not os.path.exists(checkpoint_dir): os.mkdir(checkpoint_dir)
251
+
252
+ # Dataset and Dataloader setup
253
+ train_dataset = Dataset('train')
254
+ test_dataset = Dataset('val')
255
+
256
+ train_data_loader = data_utils.DataLoader(
257
+ train_dataset, batch_size=hparams.syncnet_batch_size, shuffle=True,
258
+ num_workers=hparams.num_workers)
259
+
260
+ test_data_loader = data_utils.DataLoader(
261
+ test_dataset, batch_size=hparams.syncnet_batch_size,
262
+ num_workers=8)
263
+
264
+ device = torch.device("cuda" if use_cuda else "cpu")
265
+
266
+ # Model
267
+ model = SyncNet().to(device)
268
+ print('total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))
269
+
270
+ optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
271
+ lr=hparams.syncnet_lr)
272
+
273
+ if checkpoint_path is not None:
274
+ load_checkpoint(checkpoint_path, model, optimizer, reset_optimizer=False)
275
+
276
+ train(device, model, train_data_loader, test_data_loader, optimizer,
277
+ checkpoint_dir=checkpoint_dir,
278
+ checkpoint_interval=hparams.syncnet_checkpoint_interval,
279
+ nepochs=hparams.nepochs)
hparams.py ADDED
@@ -0,0 +1,101 @@
1
+ from glob import glob
2
+ import os
3
+
4
+ def get_image_list(data_root, split):
5
+ filelist = []
6
+
7
+ with open('filelists/{}.txt'.format(split)) as f:
8
+ for line in f:
9
+ line = line.strip()
10
+ if ' ' in line: line = line.split()[0]
11
+ filelist.append(os.path.join(data_root, line))
12
+
13
+ return filelist
14
+
15
+ class HParams:
16
+ def __init__(self, **kwargs):
17
+ self.data = {}
18
+
19
+ for key, value in kwargs.items():
20
+ self.data[key] = value
21
+
22
+ def __getattr__(self, key):
23
+ if key not in self.data:
24
+ raise AttributeError("'HParams' object has no attribute %s" % key)
25
+ return self.data[key]
26
+
27
+ def set_hparam(self, key, value):
28
+ self.data[key] = value
29
+
30
+
31
+ # Default hyperparameters
32
+ hparams = HParams(
33
+ num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality
34
+ # network
35
+ rescale=True, # Whether to rescale audio prior to preprocessing
36
+ rescaling_max=0.9, # Rescaling value
37
+
38
+ # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction
39
+ # It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder
40
+ # Does not work if n_ffit is not multiple of hop_size!!
41
+ use_lws=False,
42
+
43
+ n_fft=800, # Extra window size is filled with 0 paddings to match this parameter
44
+ hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate)
45
+ win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate)
46
+ sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i <filename>)
47
+
48
+ frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5)
49
+
50
+ # Mel and Linear spectrograms normalization/scaling and clipping
51
+ signal_normalization=True,
52
+ # Whether to normalize mel spectrograms to some predefined range (following below parameters)
53
+ allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True
54
+ symmetric_mels=True,
55
+ # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2,
56
+ # faster and cleaner convergence)
57
+ max_abs_value=4.,
58
+ # max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not
59
+ # be too big to avoid gradient explosion,
60
+ # not too small for fast convergence)
61
+ # Contribution by @begeekmyfriend
62
+ # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude
63
+ # levels. Also allows for better G&L phase reconstruction)
64
+ preemphasize=True, # whether to apply filter
65
+ preemphasis=0.97, # filter coefficient.
66
+
67
+ # Limits
68
+ min_level_db=-100,
69
+ ref_level_db=20,
70
+ fmin=55,
71
+ # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To
72
+ # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
73
+ fmax=7600, # To be increased/reduced depending on data.
74
+
75
+ ###################### Our training parameters #################################
76
+ img_size=96,
77
+ fps=25,
78
+
79
+ batch_size=16,
80
+ initial_learning_rate=1e-4,
81
+ nepochs=200000000000000000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs
82
+ num_workers=16,
83
+ checkpoint_interval=3000,
84
+ eval_interval=3000,
85
+ save_optimizer_state=True,
86
+
87
+ syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence.
88
+ syncnet_batch_size=64,
89
+ syncnet_lr=1e-4,
90
+ syncnet_eval_interval=10000,
91
+ syncnet_checkpoint_interval=10000,
92
+
93
+ disc_wt=0.07,
94
+ disc_initial_learning_rate=1e-4,
95
+ )
96
+
97
+
98
+ def hparams_debug_string():
99
+ values = hparams.data  # HParams has no values() method; its settings live in the data dict
100
+ hp = [" %s: %s" % (name, values[name]) for name in sorted(values) if name != "sentences"]
101
+ return "Hyperparameters:\n" + "\n".join(hp)
hq_wav2lip_train.py ADDED
@@ -0,0 +1,443 @@
1
+ from os.path import dirname, join, basename, isfile
2
+ from tqdm import tqdm
3
+
4
+ from models import SyncNet_color as SyncNet
5
+ from models import Wav2Lip, Wav2Lip_disc_qual
6
+ import audio
7
+
8
+ import torch
9
+ from torch import nn
10
+ from torch.nn import functional as F
11
+ from torch import optim
12
+ import torch.backends.cudnn as cudnn
13
+ from torch.utils import data as data_utils
14
+ import numpy as np
15
+
16
+ from glob import glob
17
+
18
+ import os, random, cv2, argparse
19
+ from hparams import hparams, get_image_list
20
+
21
+ parser = argparse.ArgumentParser(description='Code to train the Wav2Lip model WITH the visual quality discriminator')
22
+
23
+ parser.add_argument("--data_root", help="Root folder of the preprocessed LRS2 dataset", required=True, type=str)
24
+
25
+ parser.add_argument('--checkpoint_dir', help='Save checkpoints to this directory', required=True, type=str)
26
+ parser.add_argument('--syncnet_checkpoint_path', help='Load the pre-trained Expert discriminator', required=True, type=str)
27
+
28
+ parser.add_argument('--checkpoint_path', help='Resume generator from this checkpoint', default=None, type=str)
29
+ parser.add_argument('--disc_checkpoint_path', help='Resume quality disc from this checkpoint', default=None, type=str)
30
+
31
+ args = parser.parse_args()
32
+
33
+
34
+ global_step = 0
35
+ global_epoch = 0
36
+ use_cuda = torch.cuda.is_available()
37
+ print('use_cuda: {}'.format(use_cuda))
38
+
39
+ syncnet_T = 5
40
+ syncnet_mel_step_size = 16
41
+
42
+ class Dataset(object):
43
+ def __init__(self, split):
44
+ self.all_videos = get_image_list(args.data_root, split)
45
+
46
+ def get_frame_id(self, frame):
47
+ return int(basename(frame).split('.')[0])
48
+
49
+ def get_window(self, start_frame):
50
+ start_id = self.get_frame_id(start_frame)
51
+ vidname = dirname(start_frame)
52
+
53
+ window_fnames = []
54
+ for frame_id in range(start_id, start_id + syncnet_T):
55
+ frame = join(vidname, '{}.jpg'.format(frame_id))
56
+ if not isfile(frame):
57
+ return None
58
+ window_fnames.append(frame)
59
+ return window_fnames
60
+
61
+ def read_window(self, window_fnames):
62
+ if window_fnames is None: return None
63
+ window = []
64
+ for fname in window_fnames:
65
+ img = cv2.imread(fname)
66
+ if img is None:
67
+ return None
68
+ try:
69
+ img = cv2.resize(img, (hparams.img_size, hparams.img_size))
70
+ except Exception as e:
71
+ return None
72
+
73
+ window.append(img)
74
+
75
+ return window
76
+
77
+ def crop_audio_window(self, spec, start_frame):
78
+ if type(start_frame) == int:
79
+ start_frame_num = start_frame
80
+ else:
81
+ start_frame_num = self.get_frame_id(start_frame)
82
+ start_idx = int(80. * (start_frame_num / float(hparams.fps)))
83
+
84
+ end_idx = start_idx + syncnet_mel_step_size
85
+
86
+ return spec[start_idx : end_idx, :]
87
+
88
+ def get_segmented_mels(self, spec, start_frame):
89
+ mels = []
90
+ assert syncnet_T == 5
91
+ start_frame_num = self.get_frame_id(start_frame) + 1 # 0-indexing ---> 1-indexing
92
+ if start_frame_num - 2 < 0: return None
93
+ for i in range(start_frame_num, start_frame_num + syncnet_T):
94
+ m = self.crop_audio_window(spec, i - 2)
95
+ if m.shape[0] != syncnet_mel_step_size:
96
+ return None
97
+ mels.append(m.T)
98
+
99
+ mels = np.asarray(mels)
100
+
101
+ return mels
102
+
103
+ def prepare_window(self, window):
104
+ # 3 x T x H x W
105
+ x = np.asarray(window) / 255.
106
+ x = np.transpose(x, (3, 0, 1, 2))
107
+
108
+ return x
109
+
110
+ def __len__(self):
111
+ return len(self.all_videos)
112
+
113
+ def __getitem__(self, idx):
114
+ while 1:
115
+ idx = random.randint(0, len(self.all_videos) - 1)
116
+ vidname = self.all_videos[idx]
117
+ img_names = list(glob(join(vidname, '*.jpg')))
118
+ if len(img_names) <= 3 * syncnet_T:
119
+ continue
120
+
121
+ img_name = random.choice(img_names)
122
+ wrong_img_name = random.choice(img_names)
123
+ while wrong_img_name == img_name:
124
+ wrong_img_name = random.choice(img_names)
125
+
126
+ window_fnames = self.get_window(img_name)
127
+ wrong_window_fnames = self.get_window(wrong_img_name)
128
+ if window_fnames is None or wrong_window_fnames is None:
129
+ continue
130
+
131
+ window = self.read_window(window_fnames)
132
+ if window is None:
133
+ continue
134
+
135
+ wrong_window = self.read_window(wrong_window_fnames)
136
+ if wrong_window is None:
137
+ continue
138
+
139
+ try:
140
+ wavpath = join(vidname, "audio.wav")
141
+ wav = audio.load_wav(wavpath, hparams.sample_rate)
142
+
143
+ orig_mel = audio.melspectrogram(wav).T
144
+ except Exception as e:
145
+ continue
146
+
147
+ mel = self.crop_audio_window(orig_mel.copy(), img_name)
148
+
149
+ if (mel.shape[0] != syncnet_mel_step_size):
150
+ continue
151
+
152
+ indiv_mels = self.get_segmented_mels(orig_mel.copy(), img_name)
153
+ if indiv_mels is None: continue
154
+
155
+ window = self.prepare_window(window)
156
+ y = window.copy()
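+ # zero out the lower half of the input frames below; the generator must reconstruct the mouth from audio (y keeps the full ground truth)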
157
+ window[:, :, window.shape[2]//2:] = 0.
158
+
159
+ wrong_window = self.prepare_window(wrong_window)
160
+ x = np.concatenate([window, wrong_window], axis=0)
161
+
162
+ x = torch.FloatTensor(x)
163
+ mel = torch.FloatTensor(mel.T).unsqueeze(0)
164
+ indiv_mels = torch.FloatTensor(indiv_mels).unsqueeze(1)
165
+ y = torch.FloatTensor(y)
166
+ return x, indiv_mels, mel, y
167
+
168
+ def save_sample_images(x, g, gt, global_step, checkpoint_dir):
169
+ x = (x.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
170
+ g = (g.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
171
+ gt = (gt.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
172
+
173
+ refs, inps = x[..., 3:], x[..., :3]
174
+ folder = join(checkpoint_dir, "samples_step{:09d}".format(global_step))
175
+ if not os.path.exists(folder): os.mkdir(folder)
176
+ collage = np.concatenate((refs, inps, g, gt), axis=-2)
177
+ for batch_idx, c in enumerate(collage):
178
+ for t in range(len(c)):
179
+ cv2.imwrite('{}/{}_{}.jpg'.format(folder, batch_idx, t), c[t])
180
+
181
+ logloss = nn.BCELoss()
182
+ def cosine_loss(a, v, y):
183
+ d = nn.functional.cosine_similarity(a, v)
184
+ loss = logloss(d.unsqueeze(1), y)
185
+
186
+ return loss
187
+
188
+ device = torch.device("cuda" if use_cuda else "cpu")
189
+ syncnet = SyncNet().to(device)
190
+ for p in syncnet.parameters():
191
+ p.requires_grad = False
192
+
193
+ recon_loss = nn.L1Loss()
194
+ def get_sync_loss(mel, g):
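+ # evaluate lip-sync on the lower half (mouth region) of the generated frames using the frozen SyncNet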
195
+ g = g[:, :, :, g.size(3)//2:]
196
+ g = torch.cat([g[:, :, i] for i in range(syncnet_T)], dim=1)
197
+ # B, 3 * T, H//2, W
198
+ a, v = syncnet(mel, g)
199
+ y = torch.ones(g.size(0), 1).float().to(device)
200
+ return cosine_loss(a, v, y)
201
+
202
+ def train(device, model, disc, train_data_loader, test_data_loader, optimizer, disc_optimizer,
203
+ checkpoint_dir=None, checkpoint_interval=None, nepochs=None):
204
+ global global_step, global_epoch
205
+ resumed_step = global_step
206
+
207
+ while global_epoch < nepochs:
208
+ print('Starting Epoch: {}'.format(global_epoch))
209
+ running_sync_loss, running_l1_loss, disc_loss, running_perceptual_loss = 0., 0., 0., 0.
210
+ running_disc_real_loss, running_disc_fake_loss = 0., 0.
211
+ prog_bar = tqdm(enumerate(train_data_loader))
212
+ for step, (x, indiv_mels, mel, gt) in prog_bar:
213
+ disc.train()
214
+ model.train()
215
+
216
+ x = x.to(device)
217
+ mel = mel.to(device)
218
+ indiv_mels = indiv_mels.to(device)
219
+ gt = gt.to(device)
220
+
221
+ ### Train generator now. Remove ALL grads.
222
+ optimizer.zero_grad()
223
+ disc_optimizer.zero_grad()
224
+
225
+ g = model(indiv_mels, x)
226
+
227
+ if hparams.syncnet_wt > 0.:
228
+ sync_loss = get_sync_loss(mel, g)
229
+ else:
230
+ sync_loss = 0.
231
+
232
+ if hparams.disc_wt > 0.:
233
+ perceptual_loss = disc.perceptual_forward(g)
234
+ else:
235
+ perceptual_loss = 0.
236
+
237
+ l1loss = recon_loss(g, gt)
238
+
239
+ loss = hparams.syncnet_wt * sync_loss + hparams.disc_wt * perceptual_loss + \
240
+ (1. - hparams.syncnet_wt - hparams.disc_wt) * l1loss
241
+
242
+ loss.backward()
243
+ optimizer.step()
244
+
245
+ ### Remove all gradients before Training disc
246
+ disc_optimizer.zero_grad()
247
+
248
+ pred = disc(gt)
249
+ disc_real_loss = F.binary_cross_entropy(pred, torch.ones((len(pred), 1)).to(device))
250
+ disc_real_loss.backward()
251
+
252
+ pred = disc(g.detach())
253
+ disc_fake_loss = F.binary_cross_entropy(pred, torch.zeros((len(pred), 1)).to(device))
254
+ disc_fake_loss.backward()
255
+
256
+ disc_optimizer.step()
257
+
258
+ running_disc_real_loss += disc_real_loss.item()
259
+ running_disc_fake_loss += disc_fake_loss.item()
260
+
261
+ if global_step % checkpoint_interval == 0:
262
+ save_sample_images(x, g, gt, global_step, checkpoint_dir)
263
+
264
+ # Logs
265
+ global_step += 1
266
+ cur_session_steps = global_step - resumed_step
267
+
268
+ running_l1_loss += l1loss.item()
269
+ if hparams.syncnet_wt > 0.:
270
+ running_sync_loss += sync_loss.item()
271
+ else:
272
+ running_sync_loss += 0.
273
+
274
+ if hparams.disc_wt > 0.:
275
+ running_perceptual_loss += perceptual_loss.item()
276
+ else:
277
+ running_perceptual_loss += 0.
278
+
279
+ if global_step == 1 or global_step % checkpoint_interval == 0:
280
+ save_checkpoint(
281
+ model, optimizer, global_step, checkpoint_dir, global_epoch)
282
+ save_checkpoint(disc, disc_optimizer, global_step, checkpoint_dir, global_epoch, prefix='disc_')
283
+
284
+
285
+ if global_step % hparams.eval_interval == 0:
286
+ with torch.no_grad():
287
+ average_sync_loss = eval_model(test_data_loader, global_step, device, model, disc)
288
+
289
+ if average_sync_loss < .75:
290
+ hparams.set_hparam('syncnet_wt', 0.03)
291
+
292
+ prog_bar.set_description('L1: {}, Sync: {}, Percep: {} | Fake: {}, Real: {}'.format(running_l1_loss / (step + 1),
293
+ running_sync_loss / (step + 1),
294
+ running_perceptual_loss / (step + 1),
295
+ running_disc_fake_loss / (step + 1),
296
+ running_disc_real_loss / (step + 1)))
297
+
298
+ global_epoch += 1
299
+
300
+ def eval_model(test_data_loader, global_step, device, model, disc):
301
+ eval_steps = 300
302
+ print('Evaluating for {} steps'.format(eval_steps))
303
+ running_sync_loss, running_l1_loss, running_disc_real_loss, running_disc_fake_loss, running_perceptual_loss = [], [], [], [], []
304
+ while 1:
305
+ for step, (x, indiv_mels, mel, gt) in enumerate((test_data_loader)):
306
+ model.eval()
307
+ disc.eval()
308
+
309
+ x = x.to(device)
310
+ mel = mel.to(device)
311
+ indiv_mels = indiv_mels.to(device)
312
+ gt = gt.to(device)
313
+
314
+ pred = disc(gt)
315
+ disc_real_loss = F.binary_cross_entropy(pred, torch.ones((len(pred), 1)).to(device))
316
+
317
+ g = model(indiv_mels, x)
318
+ pred = disc(g)
319
+ disc_fake_loss = F.binary_cross_entropy(pred, torch.zeros((len(pred), 1)).to(device))
320
+
321
+ running_disc_real_loss.append(disc_real_loss.item())
322
+ running_disc_fake_loss.append(disc_fake_loss.item())
323
+
324
+ sync_loss = get_sync_loss(mel, g)
325
+
326
+ if hparams.disc_wt > 0.:
327
+ perceptual_loss = disc.perceptual_forward(g)
328
+ else:
329
+ perceptual_loss = 0.
330
+
331
+ l1loss = recon_loss(g, gt)
332
+
333
+ loss = hparams.syncnet_wt * sync_loss + hparams.disc_wt * perceptual_loss + \
334
+ (1. - hparams.syncnet_wt - hparams.disc_wt) * l1loss
335
+
336
+ running_l1_loss.append(l1loss.item())
337
+ running_sync_loss.append(sync_loss.item())
338
+
339
+ if hparams.disc_wt > 0.:
340
+ running_perceptual_loss.append(perceptual_loss.item())
341
+ else:
342
+ running_perceptual_loss.append(0.)
343
+
344
+ if step > eval_steps: break
345
+
346
+ print('L1: {}, Sync: {}, Percep: {} | Fake: {}, Real: {}'.format(sum(running_l1_loss) / len(running_l1_loss),
347
+ sum(running_sync_loss) / len(running_sync_loss),
348
+ sum(running_perceptual_loss) / len(running_perceptual_loss),
349
+ sum(running_disc_fake_loss) / len(running_disc_fake_loss),
350
+ sum(running_disc_real_loss) / len(running_disc_real_loss)))
351
+ return sum(running_sync_loss) / len(running_sync_loss)
352
+
353
+
354
+ def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch, prefix=''):
355
+ checkpoint_path = join(
356
+ checkpoint_dir, "{}checkpoint_step{:09d}.pth".format(prefix, global_step))
357
+ optimizer_state = optimizer.state_dict() if hparams.save_optimizer_state else None
358
+ torch.save({
359
+ "state_dict": model.state_dict(),
360
+ "optimizer": optimizer_state,
361
+ "global_step": step,
362
+ "global_epoch": epoch,
363
+ }, checkpoint_path)
364
+ print("Saved checkpoint:", checkpoint_path)
365
+
366
+ def _load(checkpoint_path):
367
+ if use_cuda:
368
+ checkpoint = torch.load(checkpoint_path)
369
+ else:
370
+ checkpoint = torch.load(checkpoint_path,
371
+ map_location=lambda storage, loc: storage)
372
+ return checkpoint
373
+
374
+
375
+ def load_checkpoint(path, model, optimizer, reset_optimizer=False, overwrite_global_states=True):
376
+ global global_step
377
+ global global_epoch
378
+
379
+ print("Load checkpoint from: {}".format(path))
380
+ checkpoint = _load(path)
381
+ s = checkpoint["state_dict"]
382
+ new_s = {}
383
+ for k, v in s.items():
384
+ new_s[k.replace('module.', '')] = v
385
+ model.load_state_dict(new_s)
386
+ if not reset_optimizer:
387
+ optimizer_state = checkpoint["optimizer"]
388
+ if optimizer_state is not None:
389
+ print("Load optimizer state from {}".format(path))
390
+ optimizer.load_state_dict(checkpoint["optimizer"])
391
+ if overwrite_global_states:
392
+ global_step = checkpoint["global_step"]
393
+ global_epoch = checkpoint["global_epoch"]
394
+
395
+ return model
396
+
397
+ if __name__ == "__main__":
398
+ checkpoint_dir = args.checkpoint_dir
399
+
400
+ # Dataset and Dataloader setup
401
+ train_dataset = Dataset('train')
402
+ test_dataset = Dataset('val')
403
+
404
+ train_data_loader = data_utils.DataLoader(
405
+ train_dataset, batch_size=hparams.batch_size, shuffle=True,
406
+ num_workers=hparams.num_workers)
407
+
408
+ test_data_loader = data_utils.DataLoader(
409
+ test_dataset, batch_size=hparams.batch_size,
410
+ num_workers=4)
411
+
412
+ device = torch.device("cuda" if use_cuda else "cpu")
413
+
414
+ # Model
415
+ model = Wav2Lip().to(device)
416
+ disc = Wav2Lip_disc_qual().to(device)
417
+
418
+ print('total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))
419
+ print('total DISC trainable params {}'.format(sum(p.numel() for p in disc.parameters() if p.requires_grad)))
420
+
421
+ optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
422
+ lr=hparams.initial_learning_rate, betas=(0.5, 0.999))
423
+ disc_optimizer = optim.Adam([p for p in disc.parameters() if p.requires_grad],
424
+ lr=hparams.disc_initial_learning_rate, betas=(0.5, 0.999))
425
+
426
+ if args.checkpoint_path is not None:
427
+ load_checkpoint(args.checkpoint_path, model, optimizer, reset_optimizer=False)
428
+
429
+ if args.disc_checkpoint_path is not None:
430
+ load_checkpoint(args.disc_checkpoint_path, disc, disc_optimizer,
431
+ reset_optimizer=False, overwrite_global_states=False)
432
+
433
+ load_checkpoint(args.syncnet_checkpoint_path, syncnet, None, reset_optimizer=True,
434
+ overwrite_global_states=False)
435
+
436
+ if not os.path.exists(checkpoint_dir):
437
+ os.mkdir(checkpoint_dir)
438
+
439
+ # Train!
440
+ train(device, model, disc, train_data_loader, test_data_loader, optimizer, disc_optimizer,
441
+ checkpoint_dir=checkpoint_dir,
442
+ checkpoint_interval=hparams.checkpoint_interval,
443
+ nepochs=hparams.nepochs)
inference.py ADDED
@@ -0,0 +1,280 @@
1
+ from os import listdir, path
2
+ import numpy as np
3
+ import scipy, cv2, os, sys, argparse, audio
4
+ import json, subprocess, random, string
5
+ from tqdm import tqdm
6
+ from glob import glob
7
+ import torch, face_detection
8
+ from models import Wav2Lip
9
+ import platform
10
+
11
+ parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models')
12
+
13
+ parser.add_argument('--checkpoint_path', type=str,
14
+ help='Name of saved checkpoint to load weights from', required=True)
15
+
16
+ parser.add_argument('--face', type=str,
17
+ help='Filepath of video/image that contains faces to use', required=True)
18
+ parser.add_argument('--audio', type=str,
19
+ help='Filepath of video/audio file to use as raw audio source', required=True)
20
+ parser.add_argument('--outfile', type=str, help='Video path to save result. See default for an e.g.',
21
+ default='results/result_voice.mp4')
22
+
23
+ parser.add_argument('--static', type=bool,
24
+ help='If True, then use only first video frame for inference', default=False)
25
+ parser.add_argument('--fps', type=float, help='Can be specified only if input is a static image (default: 25)',
26
+ default=25., required=False)
27
+
28
+ parser.add_argument('--pads', nargs='+', type=int, default=[0, 10, 0, 0],
29
+ help='Padding (top, bottom, left, right). Please adjust to include chin at least')
30
+
31
+ parser.add_argument('--face_det_batch_size', type=int,
32
+ help='Batch size for face detection', default=16)
33
+ parser.add_argument('--wav2lip_batch_size', type=int, help='Batch size for Wav2Lip model(s)', default=128)
34
+
35
+ parser.add_argument('--resize_factor', default=1, type=int,
36
+ help='Reduce the resolution by this factor. Sometimes, best results are obtained at 480p or 720p')
37
+
38
+ parser.add_argument('--crop', nargs='+', type=int, default=[0, -1, 0, -1],
39
+                     help='Crop video to a smaller region (top, bottom, left, right). Applied after resize_factor and rotate arg. '
+                          'Useful if multiple faces are present. -1 implies the value will be auto-inferred based on height, width')
+
+ parser.add_argument('--box', nargs='+', type=int, default=[-1, -1, -1, -1],
+                     help='Specify a constant bounding box for the face. Use only as a last resort if the face is not detected. '
+                          'Also, might work only if the face is not moving around much. Syntax: (top, bottom, left, right).')
+
+ parser.add_argument('--rotate', default=False, action='store_true',
+                     help='Sometimes videos taken from a phone can be flipped 90 degrees. If set, rotates the video right by 90 degrees. '
+                          'Use if you get a flipped result despite feeding a normal-looking video.')
+
+ parser.add_argument('--nosmooth', default=False, action='store_true',
+                     help='Prevent smoothing face detections over a short temporal window')
+
+ args = parser.parse_args()
+ args.img_size = 96
+
+ if os.path.isfile(args.face) and args.face.split('.')[-1] in ['jpg', 'png', 'jpeg']:
+     args.static = True
+
+ def get_smoothened_boxes(boxes, T):
+     for i in range(len(boxes)):
+         if i + T > len(boxes):
+             window = boxes[len(boxes) - T:]
+         else:
+             window = boxes[i : i + T]
+         boxes[i] = np.mean(window, axis=0)
+     return boxes
+
+ def face_detect(images):
+     detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
+                                             flip_input=False, device=device)
+
+     batch_size = args.face_det_batch_size
+
+     while 1:
+         predictions = []
+         try:
+             for i in tqdm(range(0, len(images), batch_size)):
+                 predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
+         except RuntimeError:
+             if batch_size == 1:
+                 raise RuntimeError('Image too big to run face detection on GPU. Please use the --resize_factor argument')
+             batch_size //= 2
+             print('Recovering from OOM error; new batch size: {}'.format(batch_size))
+             continue
+         break
+
+     results = []
+     pady1, pady2, padx1, padx2 = args.pads
+     for rect, image in zip(predictions, images):
+         if rect is None:
+             cv2.imwrite('temp/faulty_frame.jpg', image)  # check this frame where the face was not detected.
+             raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')
+
+         y1 = max(0, rect[1] - pady1)
+         y2 = min(image.shape[0], rect[3] + pady2)
+         x1 = max(0, rect[0] - padx1)
+         x2 = min(image.shape[1], rect[2] + padx2)
+
+         results.append([x1, y1, x2, y2])
+
+     boxes = np.array(results)
+     if not args.nosmooth: boxes = get_smoothened_boxes(boxes, T=5)
+     results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]
+
+     del detector
+     return results
+
+ def datagen(frames, mels):
+     img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
+
+     if args.box[0] == -1:
+         if not args.static:
+             face_det_results = face_detect(frames)  # BGR2RGB for CNN face detection
+         else:
+             face_det_results = face_detect([frames[0]])
+     else:
+         print('Using the specified bounding box instead of face detection...')
+         y1, y2, x1, x2 = args.box
+         face_det_results = [[f[y1: y2, x1:x2], (y1, y2, x1, x2)] for f in frames]
+
+     for i, m in enumerate(mels):
+         idx = 0 if args.static else i % len(frames)
+         frame_to_save = frames[idx].copy()
+         face, coords = face_det_results[idx].copy()
+
+         face = cv2.resize(face, (args.img_size, args.img_size))
+
+         img_batch.append(face)
+         mel_batch.append(m)
+         frame_batch.append(frame_to_save)
+         coords_batch.append(coords)
+
+         if len(img_batch) >= args.wav2lip_batch_size:
+             img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
+
+             img_masked = img_batch.copy()
+             img_masked[:, args.img_size//2:] = 0
+
+             img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
+             mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
+
+             yield img_batch, mel_batch, frame_batch, coords_batch
+             img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
+
+     if len(img_batch) > 0:
+         img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
+
+         img_masked = img_batch.copy()
+         img_masked[:, args.img_size//2:] = 0
+
+         img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
+         mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
+
+         yield img_batch, mel_batch, frame_batch, coords_batch
+
+ mel_step_size = 16
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ print('Using {} for inference.'.format(device))
+
+ def _load(checkpoint_path):
+     if device == 'cuda':
+         checkpoint = torch.load(checkpoint_path)
+     else:
+         checkpoint = torch.load(checkpoint_path,
+                                 map_location=lambda storage, loc: storage)
+     return checkpoint
+
+ def load_model(path):
+     model = Wav2Lip()
+     print("Load checkpoint from: {}".format(path))
+     checkpoint = _load(path)
+     s = checkpoint["state_dict"]
+     new_s = {}
+     for k, v in s.items():
+         new_s[k.replace('module.', '')] = v
+     model.load_state_dict(new_s)
+
+     model = model.to(device)
+     return model.eval()
+
+ def main():
+     if not os.path.isfile(args.face):
+         raise ValueError('--face argument must be a valid path to a video/image file')
+
+     elif args.face.split('.')[-1] in ['jpg', 'png', 'jpeg']:
+         full_frames = [cv2.imread(args.face)]
+         fps = args.fps
+
+     else:
+         video_stream = cv2.VideoCapture(args.face)
+         fps = video_stream.get(cv2.CAP_PROP_FPS)
+
+         print('Reading video frames...')
+
+         full_frames = []
+         while 1:
+             still_reading, frame = video_stream.read()
+             if not still_reading:
+                 video_stream.release()
+                 break
+             if args.resize_factor > 1:
+                 frame = cv2.resize(frame, (frame.shape[1]//args.resize_factor, frame.shape[0]//args.resize_factor))
+
+             if args.rotate:
+                 frame = cv2.rotate(frame, cv2.ROTATE_90_CLOCKWISE)
+
+             y1, y2, x1, x2 = args.crop
+             if x2 == -1: x2 = frame.shape[1]
+             if y2 == -1: y2 = frame.shape[0]
+
+             frame = frame[y1:y2, x1:x2]
+
+             full_frames.append(frame)
+
+     print("Number of frames available for inference: " + str(len(full_frames)))
+
+     if not args.audio.endswith('.wav'):
+         print('Extracting raw audio...')
+         command = 'ffmpeg -y -i {} -strict -2 {}'.format(args.audio, 'temp/temp.wav')
+
+         subprocess.call(command, shell=True)
+         args.audio = 'temp/temp.wav'
+
+     wav = audio.load_wav(args.audio, 16000)
+     mel = audio.melspectrogram(wav)
+     print(mel.shape)
+
+     if np.isnan(mel.reshape(-1)).sum() > 0:
+         raise ValueError('Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')
+
+     mel_chunks = []
+     mel_idx_multiplier = 80./fps
+     i = 0
+     while 1:
+         start_idx = int(i * mel_idx_multiplier)
+         if start_idx + mel_step_size > len(mel[0]):
+             mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
+             break
+         mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
+         i += 1
+
+     print("Length of mel chunks: {}".format(len(mel_chunks)))
+
+     full_frames = full_frames[:len(mel_chunks)]
+
+     batch_size = args.wav2lip_batch_size
+     gen = datagen(full_frames.copy(), mel_chunks)
+
+     for i, (img_batch, mel_batch, frames, coords) in enumerate(tqdm(gen,
+                                             total=int(np.ceil(float(len(mel_chunks))/batch_size)))):
+         if i == 0:
+             model = load_model(args.checkpoint_path)
+             print("Model loaded")
+
+             frame_h, frame_w = full_frames[0].shape[:-1]
+             out = cv2.VideoWriter('temp/result.avi',
+                                   cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h))
+
+         img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
+         mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
+
+         with torch.no_grad():
+             pred = model(mel_batch, img_batch)
+
+         pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
+
+         for p, f, c in zip(pred, frames, coords):
+             y1, y2, x1, x2 = c
+             p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))
+
+             f[y1:y2, x1:x2] = p
+             out.write(f)
+
+     out.release()
+
+     command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format(args.audio, 'temp/result.avi', args.outfile)
+     subprocess.call(command, shell=platform.system() != 'Windows')
+
+ if __name__ == '__main__':
+     main()
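
The least obvious part of `inference.py` is how `main()` aligns audio with video: the mel spectrogram has 80 bins and advances at roughly 80 steps per second, so each video frame at `fps` maps to `80/fps` mel steps, and every frame is paired with a 16-step (~0.2 s) mel window starting at its position. Below is a minimal sketch of that chunking loop, using a random array as a stand-in for `audio.melspectrogram(wav)`; the array shape and the `fps` value are assumptions for illustration only.

```python
import numpy as np

mel_step_size = 16               # 16 mel steps (~0.2 s of audio) per generated frame
fps = 25.0                       # example frame rate; inference.py reads it from the input video
mel = np.random.rand(80, 400)    # stand-in for audio.melspectrogram(wav): 80 bins x time steps

mel_chunks = []
mel_idx_multiplier = 80. / fps   # mel steps advanced per video frame
i = 0
while True:
    start_idx = int(i * mel_idx_multiplier)
    if start_idx + mel_step_size > mel.shape[1]:
        # final frame(s): reuse the last full 16-step window so every frame gets complete audio context
        mel_chunks.append(mel[:, -mel_step_size:])
        break
    mel_chunks.append(mel[:, start_idx:start_idx + mel_step_size])
    i += 1

print(len(mel_chunks), mel_chunks[0].shape)   # number of frames to generate, each paired with an (80, 16) mel window
```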
preprocess.py ADDED
@@ -0,0 +1,113 @@
+ import sys
+
+ if sys.version_info < (3, 2):
+     raise Exception("Must be using >= Python 3.2")
+
+ from os import listdir, path
+
+ if not path.isfile('face_detection/detection/sfd/s3fd.pth'):
+     raise FileNotFoundError('Save the s3fd model to face_detection/detection/sfd/s3fd.pth '
+                             'before running this script!')
+
+ import multiprocessing as mp
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ import numpy as np
+ import argparse, os, cv2, traceback, subprocess
+ from tqdm import tqdm
+ from glob import glob
+ import audio
+ from hparams import hparams as hp
+
+ import face_detection
+
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument('--ngpu', help='Number of GPUs across which to run in parallel', default=1, type=int)
+ parser.add_argument('--batch_size', help='Single-GPU face detection batch size', default=32, type=int)
+ parser.add_argument("--data_root", help="Root folder of the LRS2 dataset", required=True)
+ parser.add_argument("--preprocessed_root", help="Root folder of the preprocessed dataset", required=True)
+
+ args = parser.parse_args()
+
+ fa = [face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False,
+                                    device='cuda:{}'.format(id)) for id in range(args.ngpu)]
+
+ template = 'ffmpeg -loglevel panic -y -i {} -strict -2 {}'
+ # template2 = 'ffmpeg -hide_banner -loglevel panic -threads 1 -y -i {} -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 {}'
+
+ def process_video_file(vfile, args, gpu_id):
+     video_stream = cv2.VideoCapture(vfile)
+
+     frames = []
+     while 1:
+         still_reading, frame = video_stream.read()
+         if not still_reading:
+             video_stream.release()
+             break
+         frames.append(frame)
+
+     vidname = os.path.basename(vfile).split('.')[0]
+     dirname = vfile.split('/')[-2]
+
+     fulldir = path.join(args.preprocessed_root, dirname, vidname)
+     os.makedirs(fulldir, exist_ok=True)
+
+     batches = [frames[i:i + args.batch_size] for i in range(0, len(frames), args.batch_size)]
+
+     i = -1
+     for fb in batches:
+         preds = fa[gpu_id].get_detections_for_batch(np.asarray(fb))
+
+         for j, f in enumerate(preds):
+             i += 1
+             if f is None:
+                 continue
+
+             x1, y1, x2, y2 = f
+             cv2.imwrite(path.join(fulldir, '{}.jpg'.format(i)), fb[j][y1:y2, x1:x2])
+
+ def process_audio_file(vfile, args):
+     vidname = os.path.basename(vfile).split('.')[0]
+     dirname = vfile.split('/')[-2]
+
+     fulldir = path.join(args.preprocessed_root, dirname, vidname)
+     os.makedirs(fulldir, exist_ok=True)
+
+     wavpath = path.join(fulldir, 'audio.wav')
+
+     command = template.format(vfile, wavpath)
+     subprocess.call(command, shell=True)
+
+
+ def mp_handler(job):
+     vfile, args, gpu_id = job
+     try:
+         process_video_file(vfile, args, gpu_id)
+     except KeyboardInterrupt:
+         exit(0)
+     except:
+         traceback.print_exc()
+
+ def main(args):
+     print('Started processing {} with {} GPUs'.format(args.data_root, args.ngpu))
+
+     filelist = glob(path.join(args.data_root, '*/*.mp4'))
+
+     jobs = [(vfile, args, i % args.ngpu) for i, vfile in enumerate(filelist)]
+     p = ThreadPoolExecutor(args.ngpu)
+     futures = [p.submit(mp_handler, j) for j in jobs]
+     _ = [r.result() for r in tqdm(as_completed(futures), total=len(futures))]
+
+     print('Dumping audios...')
+
+     for vfile in tqdm(filelist):
+         try:
+             process_audio_file(vfile, args)
+         except KeyboardInterrupt:
+             exit(0)
+         except:
+             traceback.print_exc()
+             continue
+
+ if __name__ == '__main__':
+     main(args)
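
For reference, `preprocess.py` leaves one folder per clip under `--preprocessed_root`, holding the cropped face frames as `0.jpg`, `1.jpg`, … plus the extracted `audio.wav`; the training scripts expect exactly this layout. The snippet below is a small sanity-check sketch (a convenience helper, not part of the repo) that flags incomplete clip folders before training starts.

```python
import os
from glob import glob

def check_preprocessed(preprocessed_root):
    """Flag clip folders that are missing face crops or the extracted audio."""
    clip_dirs = [d for d in glob(os.path.join(preprocessed_root, '*', '*')) if os.path.isdir(d)]
    incomplete = []
    for d in clip_dirs:
        n_frames = len(glob(os.path.join(d, '*.jpg')))
        has_audio = os.path.isfile(os.path.join(d, 'audio.wav'))
        if n_frames == 0 or not has_audio:
            incomplete.append((d, n_frames, has_audio))
    print('{} clips checked, {} incomplete'.format(len(clip_dirs), len(incomplete)))
    return incomplete

# Example (the path is a placeholder):
# check_preprocessed('lrs2_preprocessed/')
```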
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ librosa==0.7.0
+ numpy==1.17.1
+ opencv-contrib-python>=4.2.0.34
+ opencv-python==4.3.0.38
+ torch==1.11.0
+ torchvision==0.12.0
+ tqdm==4.45.0
+ numba==0.48
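
These pins date from the original release, so on a fresh environment it is worth confirming what actually got installed before debugging anything else. A small check (Python 3.8+, standard-library `importlib.metadata`; not part of the repo) that compares installed versions against the pins above:

```python
from importlib.metadata import version, PackageNotFoundError

# Versions pinned in requirements.txt above.
pins = {'librosa': '0.7.0', 'numpy': '1.17.1', 'opencv-python': '4.3.0.38',
        'torch': '1.11.0', 'torchvision': '0.12.0', 'tqdm': '4.45.0', 'numba': '0.48'}

for pkg, pinned in pins.items():
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        installed = 'not installed'
    marker = '' if installed.startswith(pinned) else '   <-- differs from requirements.txt'
    print('{:15s} pinned={:10s} installed={}{}'.format(pkg, pinned, installed, marker))
```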
wav2lip_train.py ADDED
@@ -0,0 +1,374 @@
+ from os.path import dirname, join, basename, isfile
+ from tqdm import tqdm
+
+ from models import SyncNet_color as SyncNet
+ from models import Wav2Lip as Wav2Lip
+ import audio
+
+ import torch
+ from torch import nn
+ from torch import optim
+ import torch.backends.cudnn as cudnn
+ from torch.utils import data as data_utils
+ import numpy as np
+
+ from glob import glob
+
+ import os, random, cv2, argparse
+ from hparams import hparams, get_image_list
+
+ parser = argparse.ArgumentParser(description='Code to train the Wav2Lip model without the visual quality discriminator')
+
+ parser.add_argument("--data_root", help="Root folder of the preprocessed LRS2 dataset", required=True, type=str)
+
+ parser.add_argument('--checkpoint_dir', help='Save checkpoints to this directory', required=True, type=str)
+ parser.add_argument('--syncnet_checkpoint_path', help='Load the pre-trained Expert discriminator', required=True, type=str)
+
+ parser.add_argument('--checkpoint_path', help='Resume from this checkpoint', default=None, type=str)
+
+ args = parser.parse_args()
+
+
+ global_step = 0
+ global_epoch = 0
+ use_cuda = torch.cuda.is_available()
+ print('use_cuda: {}'.format(use_cuda))
+
+ syncnet_T = 5
+ syncnet_mel_step_size = 16
+
+ class Dataset(object):
+     def __init__(self, split):
+         self.all_videos = get_image_list(args.data_root, split)
+
+     def get_frame_id(self, frame):
+         return int(basename(frame).split('.')[0])
+
+     def get_window(self, start_frame):
+         start_id = self.get_frame_id(start_frame)
+         vidname = dirname(start_frame)
+
+         window_fnames = []
+         for frame_id in range(start_id, start_id + syncnet_T):
+             frame = join(vidname, '{}.jpg'.format(frame_id))
+             if not isfile(frame):
+                 return None
+             window_fnames.append(frame)
+         return window_fnames
+
+     def read_window(self, window_fnames):
+         if window_fnames is None: return None
+         window = []
+         for fname in window_fnames:
+             img = cv2.imread(fname)
+             if img is None:
+                 return None
+             try:
+                 img = cv2.resize(img, (hparams.img_size, hparams.img_size))
+             except Exception as e:
+                 return None
+
+             window.append(img)
+
+         return window
+
+     def crop_audio_window(self, spec, start_frame):
+         if type(start_frame) == int:
+             start_frame_num = start_frame
+         else:
+             start_frame_num = self.get_frame_id(start_frame)  # 0-indexing ---> 1-indexing
+         start_idx = int(80. * (start_frame_num / float(hparams.fps)))
+
+         end_idx = start_idx + syncnet_mel_step_size
+
+         return spec[start_idx : end_idx, :]
+
+     def get_segmented_mels(self, spec, start_frame):
+         mels = []
+         assert syncnet_T == 5
+         start_frame_num = self.get_frame_id(start_frame) + 1  # 0-indexing ---> 1-indexing
+         if start_frame_num - 2 < 0: return None
+         for i in range(start_frame_num, start_frame_num + syncnet_T):
+             m = self.crop_audio_window(spec, i - 2)
+             if m.shape[0] != syncnet_mel_step_size:
+                 return None
+             mels.append(m.T)
+
+         mels = np.asarray(mels)
+
+         return mels
+
+     def prepare_window(self, window):
+         # 3 x T x H x W
+         x = np.asarray(window) / 255.
+         x = np.transpose(x, (3, 0, 1, 2))
+
+         return x
+
+     def __len__(self):
+         return len(self.all_videos)
+
+     def __getitem__(self, idx):
+         while 1:
+             idx = random.randint(0, len(self.all_videos) - 1)
+             vidname = self.all_videos[idx]
+             img_names = list(glob(join(vidname, '*.jpg')))
+             if len(img_names) <= 3 * syncnet_T:
+                 continue
+
+             img_name = random.choice(img_names)
+             wrong_img_name = random.choice(img_names)
+             while wrong_img_name == img_name:
+                 wrong_img_name = random.choice(img_names)
+
+             window_fnames = self.get_window(img_name)
+             wrong_window_fnames = self.get_window(wrong_img_name)
+             if window_fnames is None or wrong_window_fnames is None:
+                 continue
+
+             window = self.read_window(window_fnames)
+             if window is None:
+                 continue
+
+             wrong_window = self.read_window(wrong_window_fnames)
+             if wrong_window is None:
+                 continue
+
+             try:
+                 wavpath = join(vidname, "audio.wav")
+                 wav = audio.load_wav(wavpath, hparams.sample_rate)
+
+                 orig_mel = audio.melspectrogram(wav).T
+             except Exception as e:
+                 continue
+
+             mel = self.crop_audio_window(orig_mel.copy(), img_name)
+
+             if (mel.shape[0] != syncnet_mel_step_size):
+                 continue
+
+             indiv_mels = self.get_segmented_mels(orig_mel.copy(), img_name)
+             if indiv_mels is None: continue
+
+             window = self.prepare_window(window)
+             y = window.copy()
+             window[:, :, window.shape[2]//2:] = 0.
+
+             wrong_window = self.prepare_window(wrong_window)
+             x = np.concatenate([window, wrong_window], axis=0)
+
+             x = torch.FloatTensor(x)
+             mel = torch.FloatTensor(mel.T).unsqueeze(0)
+             indiv_mels = torch.FloatTensor(indiv_mels).unsqueeze(1)
+             y = torch.FloatTensor(y)
+             return x, indiv_mels, mel, y
+
+ def save_sample_images(x, g, gt, global_step, checkpoint_dir):
+     x = (x.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
+     g = (g.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
+     gt = (gt.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
+
+     refs, inps = x[..., 3:], x[..., :3]
+     folder = join(checkpoint_dir, "samples_step{:09d}".format(global_step))
+     if not os.path.exists(folder): os.mkdir(folder)
+     collage = np.concatenate((refs, inps, g, gt), axis=-2)
+     for batch_idx, c in enumerate(collage):
+         for t in range(len(c)):
+             cv2.imwrite('{}/{}_{}.jpg'.format(folder, batch_idx, t), c[t])
+
+ logloss = nn.BCELoss()
+ def cosine_loss(a, v, y):
+     d = nn.functional.cosine_similarity(a, v)
+     loss = logloss(d.unsqueeze(1), y)
+
+     return loss
+
+ device = torch.device("cuda" if use_cuda else "cpu")
+ syncnet = SyncNet().to(device)
+ for p in syncnet.parameters():
+     p.requires_grad = False
+
+ recon_loss = nn.L1Loss()
+ def get_sync_loss(mel, g):
+     g = g[:, :, :, g.size(3)//2:]
+     g = torch.cat([g[:, :, i] for i in range(syncnet_T)], dim=1)
+     # B, 3 * T, H//2, W
+     a, v = syncnet(mel, g)
+     y = torch.ones(g.size(0), 1).float().to(device)
+     return cosine_loss(a, v, y)
+
+ def train(device, model, train_data_loader, test_data_loader, optimizer,
+           checkpoint_dir=None, checkpoint_interval=None, nepochs=None):
+
+     global global_step, global_epoch
+     resumed_step = global_step
+
+     while global_epoch < nepochs:
+         print('Starting Epoch: {}'.format(global_epoch))
+         running_sync_loss, running_l1_loss = 0., 0.
+         prog_bar = tqdm(enumerate(train_data_loader))
+         for step, (x, indiv_mels, mel, gt) in prog_bar:
+             model.train()
+             optimizer.zero_grad()
+
+             # Move data to CUDA device
+             x = x.to(device)
+             mel = mel.to(device)
+             indiv_mels = indiv_mels.to(device)
+             gt = gt.to(device)
+
+             g = model(indiv_mels, x)
+
+             if hparams.syncnet_wt > 0.:
+                 sync_loss = get_sync_loss(mel, g)
+             else:
+                 sync_loss = 0.
+
+             l1loss = recon_loss(g, gt)
+
+             loss = hparams.syncnet_wt * sync_loss + (1 - hparams.syncnet_wt) * l1loss
+             loss.backward()
+             optimizer.step()
+
+             if global_step % checkpoint_interval == 0:
+                 save_sample_images(x, g, gt, global_step, checkpoint_dir)
+
+             global_step += 1
+             cur_session_steps = global_step - resumed_step
+
+             running_l1_loss += l1loss.item()
+             if hparams.syncnet_wt > 0.:
+                 running_sync_loss += sync_loss.item()
+             else:
+                 running_sync_loss += 0.
+
+             if global_step == 1 or global_step % checkpoint_interval == 0:
+                 save_checkpoint(
+                     model, optimizer, global_step, checkpoint_dir, global_epoch)
+
+             if global_step == 1 or global_step % hparams.eval_interval == 0:
+                 with torch.no_grad():
+                     average_sync_loss = eval_model(test_data_loader, global_step, device, model, checkpoint_dir)
+
+                     if average_sync_loss < .75:
+                         hparams.set_hparam('syncnet_wt', 0.01)  # without image GAN a lesser weight is sufficient
+
+             prog_bar.set_description('L1: {}, Sync Loss: {}'.format(running_l1_loss / (step + 1),
+                                                                     running_sync_loss / (step + 1)))
+
+         global_epoch += 1
+
+
+ def eval_model(test_data_loader, global_step, device, model, checkpoint_dir):
+     eval_steps = 700
+     print('Evaluating for {} steps'.format(eval_steps))
+     sync_losses, recon_losses = [], []
+     step = 0
+     while 1:
+         for x, indiv_mels, mel, gt in test_data_loader:
+             step += 1
+             model.eval()
+
+             # Move data to CUDA device
+             x = x.to(device)
+             gt = gt.to(device)
+             indiv_mels = indiv_mels.to(device)
+             mel = mel.to(device)
+
+             g = model(indiv_mels, x)
+
+             sync_loss = get_sync_loss(mel, g)
+             l1loss = recon_loss(g, gt)
+
+             sync_losses.append(sync_loss.item())
+             recon_losses.append(l1loss.item())
+
+             if step > eval_steps:
+                 averaged_sync_loss = sum(sync_losses) / len(sync_losses)
+                 averaged_recon_loss = sum(recon_losses) / len(recon_losses)
+
+                 print('L1: {}, Sync loss: {}'.format(averaged_recon_loss, averaged_sync_loss))
+
+                 return averaged_sync_loss
+
+ def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch):
+
+     checkpoint_path = join(
+         checkpoint_dir, "checkpoint_step{:09d}.pth".format(global_step))
+     optimizer_state = optimizer.state_dict() if hparams.save_optimizer_state else None
+     torch.save({
+         "state_dict": model.state_dict(),
+         "optimizer": optimizer_state,
+         "global_step": step,
+         "global_epoch": epoch,
+     }, checkpoint_path)
+     print("Saved checkpoint:", checkpoint_path)
+
+
+ def _load(checkpoint_path):
+     if use_cuda:
+         checkpoint = torch.load(checkpoint_path)
+     else:
+         checkpoint = torch.load(checkpoint_path,
+                                 map_location=lambda storage, loc: storage)
+     return checkpoint
+
+ def load_checkpoint(path, model, optimizer, reset_optimizer=False, overwrite_global_states=True):
+     global global_step
+     global global_epoch
+
+     print("Load checkpoint from: {}".format(path))
+     checkpoint = _load(path)
+     s = checkpoint["state_dict"]
+     new_s = {}
+     for k, v in s.items():
+         new_s[k.replace('module.', '')] = v
+     model.load_state_dict(new_s)
+     if not reset_optimizer:
+         optimizer_state = checkpoint["optimizer"]
+         if optimizer_state is not None:
+             print("Load optimizer state from {}".format(path))
+             optimizer.load_state_dict(checkpoint["optimizer"])
+     if overwrite_global_states:
+         global_step = checkpoint["global_step"]
+         global_epoch = checkpoint["global_epoch"]
+
+     return model
+
+ if __name__ == "__main__":
+     checkpoint_dir = args.checkpoint_dir
+
+     # Dataset and Dataloader setup
+     train_dataset = Dataset('train')
+     test_dataset = Dataset('val')
+
+     train_data_loader = data_utils.DataLoader(
+         train_dataset, batch_size=hparams.batch_size, shuffle=True,
+         num_workers=hparams.num_workers)
+
+     test_data_loader = data_utils.DataLoader(
+         test_dataset, batch_size=hparams.batch_size,
+         num_workers=4)
+
+     device = torch.device("cuda" if use_cuda else "cpu")
+
+     # Model
+     model = Wav2Lip().to(device)
+     print('total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))
+
+     optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
+                            lr=hparams.initial_learning_rate)
+
+     if args.checkpoint_path is not None:
+         load_checkpoint(args.checkpoint_path, model, optimizer, reset_optimizer=False)
+
+     load_checkpoint(args.syncnet_checkpoint_path, syncnet, None, reset_optimizer=True, overwrite_global_states=False)
+
+     if not os.path.exists(checkpoint_dir):
+         os.mkdir(checkpoint_dir)
+
+     # Train!
+     train(device, model, train_data_loader, test_data_loader, optimizer,
+           checkpoint_dir=checkpoint_dir,
+           checkpoint_interval=hparams.checkpoint_interval,
+           nepochs=hparams.nepochs)
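
The shape handling in `get_sync_loss` above is the easiest place to go wrong when modifying the training code: the generator output `g` is `(B, 3, T, H, W)`, the loss keeps only the lower half of each face, and the `T` frames are stacked along the channel axis to give `(B, 3*T, H//2, W)` before being fed to the expert SyncNet together with the mel window. Here is a shape-only sketch with dummy tensors (no SyncNet is loaded; the sizes mirror `syncnet_T = 5` and the 96x96 face size used elsewhere in the repo):

```python
import torch

B, C, T, H, W = 2, 3, 5, 96, 96          # batch, channels, syncnet_T frames, face height/width
g = torch.rand(B, C, T, H, W)            # dummy stand-in for the generator output

g_half = g[:, :, :, g.size(3) // 2:]     # keep the lower half of each face
g_stacked = torch.cat([g_half[:, :, i] for i in range(T)], dim=1)  # stack the T frames along channels

print(g_half.shape)     # torch.Size([2, 3, 5, 48, 96])
print(g_stacked.shape)  # torch.Size([2, 15, 48, 96]) -- the B x (3*T) x H//2 x W layout noted in the code comment
```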