diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..c7d95e8e73173cc9703d7c062f03f6ea05ad1ac1
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,37 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+damo/dreamtalk/data/pose/RichardShelby_front_neutral_level1_001.mat filter=lfs diff=lfs merge=lfs -text
+damo/dreamtalk/media/teaser.gif filter=lfs diff=lfs merge=lfs -text
diff --git a/damo/dreamtalk/.mdl b/damo/dreamtalk/.mdl
new file mode 100644
index 0000000000000000000000000000000000000000..1f61457c284d943ae6f79ff6af46a55d01393f35
Binary files /dev/null and b/damo/dreamtalk/.mdl differ
diff --git a/damo/dreamtalk/.msc b/damo/dreamtalk/.msc
new file mode 100644
index 0000000000000000000000000000000000000000..9bd0da68a4543a71ccdaa8ba270d6b8f3fe2a6c9
Binary files /dev/null and b/damo/dreamtalk/.msc differ
diff --git a/damo/dreamtalk/README.md b/damo/dreamtalk/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..cf013752f083007c3a25fbbdc81d785ce1e962e2
--- /dev/null
+++ b/damo/dreamtalk/README.md
@@ -0,0 +1,131 @@
+# DreamTalk: When Expressive Talking Head Generation Meets Diffusion Probabilistic Models
+
+
+[Demo Video](https://youtu.be/VF4vlE6ZqWQ)
+
+DreamTalk is a diffusion-based audio-driven expressive talking head generation framework that can produce high-quality talking head videos across diverse speaking styles. DreamTalk exhibits robust performance with a diverse array of inputs, including songs, speech in multiple languages, noisy audio, and out-of-domain portraits.
+
+
+
+## News
+- __[2023.12]__ Release inference code and pretrained checkpoint.
+
+## Install Dependencies
+```
+pip install dlib
+```
+
+## Usage (ModelScope pipeline)
+
+ Some pre-generated results are already included in the `output_video` folder; you can run the script below and compare your outputs against them.
+
+```python
+from modelscope.utils.constant import Tasks
+from modelscope.pipelines import pipeline
+
+pipe = pipeline(
+    task=Tasks.text_to_video_synthesis,
+    model='damo/dreamtalk',
+    style_clip_path="data/style_clip/3DMM/M030_front_surprised_level3_001.mat",
+    pose_path="data/pose/RichardShelby_front_neutral_level1_001.mat",
+    model_revision='master',
+)
+inputs = {
+    "output_name": "songbie_yk_male",
+    "wav_path": "data/audio/acknowledgement_english.m4a",
+    "img_crop": True,
+    "image_path": "data/src_img/uncropped/male_face.png",
+    "max_gen_len": 20,
+}
+pipe(input=inputs)
+print("end")
+```
+
+ `wav_path`: path to the input audio.
+
+ `style_clip_path`: expression reference file, extracted from an emotional video; it can be used to control the expressions of the generated video.
+
+ `pose_path`: head-motion reference file, extracted from a video; it can be used to control the head motion of the generated video.
+
+ `image_path`: the speaker portrait, ideally a frontal face. In principle any input resolution is supported; the image will be cropped to $256\times256$.
+
+ `max_gen_len`: maximum video duration in seconds; if the input audio is longer than this, it will be truncated.
+
+ `output_name`: output name. The final video is written to the `output_video` folder, and intermediate results to the `tmp` folder.
+
+ If the input image is already $256\times256$ and suitably framed so that no cropping is needed, you can use `disable_img_crop` to skip the cropping step, as follows:
+
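+A minimal sketch (an assumption, since the pipeline's exact flag name is not documented in this README): the example above passes `img_crop`, so setting it to `False` should skip cropping for an already-cropped portrait.
+
+```python
+inputs = {
+    "output_name": "zp1_no_crop",  # hypothetical output name, for illustration only
+    "wav_path": "data/audio/acknowledgement_english.m4a",
+    "img_crop": False,  # assumption: False skips the cropping step
+    "image_path": "data/src_img/cropped/zp1.png",
+    "max_gen_len": 20,
+}
+pipe(input=inputs)
+```
+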
+## Download Checkpoints
+Download the checkpoint of the denoising network:
+* [ModelScope](tmp)
+
+
+Download the checkpoint of the renderer:
+* [ModelScope](tmp)
+
+Put the downloaded checkpoints into the `checkpoints` folder.
+
+
+## Inference
+Run the script:
+
+```
+python inference_for_demo_video.py \
+--wav_path data/audio/acknowledgement_english.m4a \
+--style_clip_path data/style_clip/3DMM/M030_front_neutral_level1_001.mat \
+--pose_path data/pose/RichardShelby_front_neutral_level1_001.mat \
+--image_path data/src_img/uncropped/male_face.png \
+--cfg_scale 1.0 \
+--max_gen_len 30 \
+--output_name acknowledgement_english@M030_front_neutral_level1_001@male_face
+```
+
+`wav_path` specifies the input audio. Common audio formats such as wav, mp3, m4a, and mp4 (video with sound) are all supported.
+
+`style_clip_path` specifies the reference speaking style and `pose_path` specifies the head pose. Both are 3DMM parameter sequences extracted from reference videos. You can follow [PIRenderer](https://github.com/RenYurui/PIRender) to extract 3DMM parameters from your own videos. Note that the video frame rate should be 25 FPS. In addition, videos used as head pose references should first be cropped to $256\times256$ using the scripts in [FOMM video preprocessing](https://github.com/AliaksandrSiarohin/video-preprocessing).
+
+`image_path` specifies the input portrait. Its resolution should be larger than $256\times256$. Frontal portraits, with the face directly facing forward and not tilted to one side, usually achieve satisfactory results. The input portrait will be cropped to $256\times256$. If your portrait is already cropped to $256\times256$ and you want to disable cropping, use option `--disable_img_crop` like this:
+
+```
+python inference_for_demo_video.py \
+--wav_path data/audio/acknowledgement_chinese.m4a \
+--style_clip_path data/style_clip/3DMM/M030_front_surprised_level3_001.mat \
+--pose_path data/pose/RichardShelby_front_neutral_level1_001.mat \
+--image_path data/src_img/cropped/zp1.png \
+--disable_img_crop \
+--cfg_scale 1.0 \
+--max_gen_len 30 \
+--output_name acknowledgement_chinese@M030_front_surprised_level3_001@zp1
+```
+
+`cfg_scale` controls the scale of classifier-free guidance; it adjusts the intensity of the speaking style.
+
+`max_gen_len` is the maximum video generation duration, measured in seconds. If the input audio exceeds this length, it will be truncated.
+
+The generated video will be named `$(output_name).mp4` and placed in the `output_video` folder. Intermediate results, including the cropped portrait, will be in the `tmp/$(output_name)` folder.
+
+Sample inputs are provided in the `data` folder. Due to copyright issues, we are unable to include the songs we used in this folder.
+
+
+## Acknowledgements
+
+We extend our heartfelt thanks to the preceding works whose invaluable contributions enabled the development of DreamTalk. These include, but are not limited to:
+[PIRenderer](https://github.com/RenYurui/PIRender),
+[AVCT](https://github.com/FuxiVirtualHuman/AAAI22-one-shot-talking-face),
+[StyleTalk](https://github.com/FuxiVirtualHuman/styletalk),
+[Deep3DFaceRecon_pytorch](https://github.com/sicxu/Deep3DFaceRecon_pytorch),
+[Wav2vec2.0](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-english),
+[diffusion-point-cloud](https://github.com/luost26/diffusion-point-cloud),
+and [FOMM video preprocessing](https://github.com/AliaksandrSiarohin/video-preprocessing). We are dedicated to advancing upon these foundational works with the utmost respect for their original contributions.
+
+## Citation
+If you find this codebase useful for your research, please use the following entry.
+```BibTeX
+@article{ma2023dreamtalk,
+ title={DreamTalk: When Expressive Talking Head Generation Meets Diffusion Probabilistic Models},
+ author={Ma, Yifeng and Zhang, Shiwei and Wang, Jiayu and Wang, Xiang and Zhang, Yingya and Deng, Zhidong},
+ journal={arXiv preprint arXiv:2312.09767},
+ year={2023}
+}
+```
+
+
diff --git a/damo/dreamtalk/checkpoints/denoising_network.pth b/damo/dreamtalk/checkpoints/denoising_network.pth
new file mode 100644
index 0000000000000000000000000000000000000000..7a9ed4ba5a6cbbef64cdc86efd7222cb49dd5561
--- /dev/null
+++ b/damo/dreamtalk/checkpoints/denoising_network.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:93864d1316f60e75b40bd820707bb2464f790b1636ae2b9275ee500d41c4e3ae
+size 47908943
diff --git a/damo/dreamtalk/checkpoints/renderer.pt b/damo/dreamtalk/checkpoints/renderer.pt
new file mode 100644
index 0000000000000000000000000000000000000000..618008af470692c0caac5f9ec59ee8286a0a2173
--- /dev/null
+++ b/damo/dreamtalk/checkpoints/renderer.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a67014839d42d592255c9fc3b3ceecbcd62c27ce0c0a89ed6628292447404242
+size 335281551
diff --git a/damo/dreamtalk/configs/default.py b/damo/dreamtalk/configs/default.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c552c827ca82dc74389c4e8b47502457d7a0abc
--- /dev/null
+++ b/damo/dreamtalk/configs/default.py
@@ -0,0 +1,91 @@
+from yacs.config import CfgNode as CN
+
+
+_C = CN()
+_C.TAG = "style_id_emotion"
+_C.DECODER_TYPE = "DisentangleDecoder"
+_C.CONTENT_ENCODER_TYPE = "ContentW2VEncoder"
+_C.STYLE_ENCODER_TYPE = "StyleEncoder"
+
+_C.DIFFNET_TYPE = "DiffusionNet"
+
+_C.WIN_SIZE = 5
+_C.D_MODEL = 256
+
+_C.DATASET = CN()
+_C.DATASET.FACE3D_DIM = 64
+_C.DATASET.NUM_FRAMES = 64
+_C.DATASET.STYLE_MAX_LEN = 256
+
+_C.TRAIN = CN()
+_C.TRAIN.FACE3D_LATENT = CN()
+_C.TRAIN.FACE3D_LATENT.TYPE = "face3d"
+
+_C.DIFFUSION = CN()
+_C.DIFFUSION.PREDICT_WHAT = "x0" # noise | x0
+_C.DIFFUSION.SCHEDULE = CN()
+_C.DIFFUSION.SCHEDULE.NUM_STEPS = 1000
+_C.DIFFUSION.SCHEDULE.BETA_1 = 1e-4
+_C.DIFFUSION.SCHEDULE.BETA_T = 0.02
+_C.DIFFUSION.SCHEDULE.MODE = "linear"
+
+_C.CONTENT_ENCODER = CN()
+_C.CONTENT_ENCODER.d_model = _C.D_MODEL
+_C.CONTENT_ENCODER.nhead = 8
+_C.CONTENT_ENCODER.num_encoder_layers = 3
+_C.CONTENT_ENCODER.dim_feedforward = 4 * _C.D_MODEL
+_C.CONTENT_ENCODER.dropout = 0.1
+_C.CONTENT_ENCODER.activation = "relu"
+_C.CONTENT_ENCODER.normalize_before = False
+_C.CONTENT_ENCODER.pos_embed_len = 2 * _C.WIN_SIZE + 1
+
+_C.STYLE_ENCODER = CN()
+_C.STYLE_ENCODER.d_model = _C.D_MODEL
+_C.STYLE_ENCODER.nhead = 8
+_C.STYLE_ENCODER.num_encoder_layers = 3
+_C.STYLE_ENCODER.dim_feedforward = 4 * _C.D_MODEL
+_C.STYLE_ENCODER.dropout = 0.1
+_C.STYLE_ENCODER.activation = "relu"
+_C.STYLE_ENCODER.normalize_before = False
+_C.STYLE_ENCODER.pos_embed_len = _C.DATASET.STYLE_MAX_LEN
+_C.STYLE_ENCODER.aggregate_method = (
+ "self_attention_pooling" # average | self_attention_pooling
+)
+# _C.STYLE_ENCODER.input_dim = _C.DATASET.FACE3D_DIM
+
+_C.DECODER = CN()
+_C.DECODER.d_model = _C.D_MODEL
+_C.DECODER.nhead = 8
+_C.DECODER.num_decoder_layers = 3
+_C.DECODER.dim_feedforward = 4 * _C.D_MODEL
+_C.DECODER.dropout = 0.1
+_C.DECODER.activation = "relu"
+_C.DECODER.normalize_before = False
+_C.DECODER.return_intermediate_dec = False
+_C.DECODER.pos_embed_len = 2 * _C.WIN_SIZE + 1
+_C.DECODER.network_type = "TransformerDecoder"
+_C.DECODER.dynamic_K = None
+_C.DECODER.dynamic_ratio = None
+# _C.DECODER.output_dim = _C.DATASET.FACE3D_DIM
+# LSFM basis:
+# _C.DECODER.upper_face3d_indices = tuple(list(range(19)) + list(range(46, 51)))
+# _C.DECODER.lower_face3d_indices = tuple(range(19, 46))
+# BFM basis:
+# fmt: off
+_C.DECODER.upper_face3d_indices = [6, 8, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]
+# fmt: on
+_C.DECODER.lower_face3d_indices = [0, 1, 2, 3, 4, 5, 7, 9, 10, 11, 12, 13, 14]
+
+_C.CF_GUIDANCE = CN()
+_C.CF_GUIDANCE.TRAINING = True
+_C.CF_GUIDANCE.INFERENCE = True
+_C.CF_GUIDANCE.NULL_PROB = 0.1
+_C.CF_GUIDANCE.SCALE = 1.0
+
+_C.INFERENCE = CN()
+_C.INFERENCE.CHECKPOINT = "checkpoints/denoising_network.pth"
+
+
+def get_cfg_defaults():
+    """Get a yacs CfgNode object with default values for DreamTalk."""
+ return _C.clone()
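+
+
+# A usage sketch (the YAML filename below is hypothetical): load the defaults, optionally
+# override fields from an experiment file via yacs, then freeze the config.
+#
+#   cfg = get_cfg_defaults()
+#   cfg.merge_from_file("configs/my_experiment.yaml")  # hypothetical override file
+#   cfg.freeze()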
diff --git a/damo/dreamtalk/configuration.json b/damo/dreamtalk/configuration.json
new file mode 100644
index 0000000000000000000000000000000000000000..568d1b31a3347adb2fb9b8c37d9fed4030f4e84f
--- /dev/null
+++ b/damo/dreamtalk/configuration.json
@@ -0,0 +1,11 @@
+{
+ "framework": "pytorch",
+ "task": "text-to-video-synthesis",
+ "model": {
+ "type": "Dreamtalk-Generation"
+ },
+ "pipeline": {
+ "type": "Dreamtalk-generation-pipe"
+ },
+ "allow_remote": true
+}
\ No newline at end of file
diff --git a/damo/dreamtalk/core/networks/__init__.py b/damo/dreamtalk/core/networks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6947c4e303a82b6a3b0fb00517f35deaf65783fb
--- /dev/null
+++ b/damo/dreamtalk/core/networks/__init__.py
@@ -0,0 +1,14 @@
+from core.networks.generator import (
+ StyleEncoder,
+ Decoder,
+ ContentW2VEncoder,
+)
+from core.networks.disentangle_decoder import DisentangleDecoder
+
+
+def get_network(name: str):
+ obj = globals().get(name)
+ if obj is None:
+ raise KeyError("Unknown Network: %s" % name)
+ else:
+ return obj
diff --git a/damo/dreamtalk/core/networks/diffusion_net.py b/damo/dreamtalk/core/networks/diffusion_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..3545e790fc44ac35c17c46012c694d7ace3b5b62
--- /dev/null
+++ b/damo/dreamtalk/core/networks/diffusion_net.py
@@ -0,0 +1,340 @@
+import math
+import torch
+import torch.nn.functional as F
+from torch.nn import Module
+from core.networks.diffusion_util import VarianceSchedule
+import numpy as np
+
+
+def face3d_raw_to_norm(face3d_raw, exp_min, exp_max):
+ """
+
+ Args:
+ face3d_raw (_type_): (B, L, C_face3d)
+ exp_min (_type_): (C_face3d)
+ exp_max (_type_): (C_face3d)
+
+ Returns:
+ _type_: (B, L, C_face3d) in [-1, 1]
+ """
+ exp_min_expand = exp_min[None, None, :]
+ exp_max_expand = exp_max[None, None, :]
+ face3d_norm_01 = (face3d_raw - exp_min_expand) / (exp_max_expand - exp_min_expand)
+ face3d_norm = face3d_norm_01 * 2 - 1
+ return face3d_norm
+
+
+def face3d_norm_to_raw(face3d_norm, exp_min, exp_max):
+ """
+
+ Args:
+ face3d_norm (_type_): (B, L, C_face3d)
+ exp_min (_type_): (C_face3d)
+ exp_max (_type_): (C_face3d)
+
+ Returns:
+ _type_: (B, L, C_face3d)
+ """
+ exp_min_expand = exp_min[None, None, :]
+ exp_max_expand = exp_max[None, None, :]
+ face3d_norm_01 = (face3d_norm + 1) / 2
+ face3d_raw = face3d_norm_01 * (exp_max_expand - exp_min_expand) + exp_min_expand
+ return face3d_raw
+
+
+class DiffusionNet(Module):
+ def __init__(self, cfg, net, var_sched: VarianceSchedule):
+ super().__init__()
+ self.cfg = cfg
+ self.net = net
+ self.var_sched = var_sched
+ self.face3d_latent_type = self.cfg.TRAIN.FACE3D_LATENT.TYPE
+ self.predict_what = self.cfg.DIFFUSION.PREDICT_WHAT
+
+ if self.cfg.CF_GUIDANCE.TRAINING:
+ null_style_clip = torch.zeros(
+ self.cfg.DATASET.STYLE_MAX_LEN, self.cfg.DATASET.FACE3D_DIM
+ )
+ self.register_buffer("null_style_clip", null_style_clip)
+
+ null_pad_mask = torch.tensor([False] * self.cfg.DATASET.STYLE_MAX_LEN)
+ self.register_buffer("null_pad_mask", null_pad_mask)
+
+ def _face3d_to_latent(self, face3d):
+ latent = None
+ if self.face3d_latent_type == "face3d":
+ latent = face3d
+ elif self.face3d_latent_type == "normalized_face3d":
+ latent = face3d_raw_to_norm(
+ face3d, exp_min=self.exp_min, exp_max=self.exp_max
+ )
+ else:
+ raise ValueError(f"Invalid face3d latent type: {self.face3d_latent_type}")
+ return latent
+
+ def _latent_to_face3d(self, latent):
+ face3d = None
+ if self.face3d_latent_type == "face3d":
+ face3d = latent
+ elif self.face3d_latent_type == "normalized_face3d":
+ latent = torch.clamp(latent, min=-1, max=1)
+ face3d = face3d_norm_to_raw(
+ latent, exp_min=self.exp_min, exp_max=self.exp_max
+ )
+ else:
+ raise ValueError(f"Invalid face3d latent type: {self.face3d_latent_type}")
+ return face3d
+
+ def ddim_sample(
+ self,
+ audio,
+ style_clip,
+ style_pad_mask,
+ output_dim,
+ flexibility=0.0,
+ ret_traj=False,
+ use_cf_guidance=False,
+ cfg_scale=2.0,
+ ddim_num_step=50,
+ ready_style_code=None,
+ ):
+ """
+
+ Args:
+ audio (_type_): (B, L, W) or (B, L, W, C)
+ style_clip (_type_): (B, L_clipmax, C_face3d)
+ style_pad_mask : (B, L_clipmax)
+            output_dim (int): dimension of the generated face3d parameters
+ flexibility (float, optional): _description_. Defaults to 0.0.
+ ret_traj (bool, optional): _description_. Defaults to False.
+
+
+ Returns:
+ _type_: (B, L, C_face)
+ """
+ if self.predict_what != "x0":
+ raise NotImplementedError(self.predict_what)
+
+ if ready_style_code is not None and use_cf_guidance:
+ raise NotImplementedError("not implement cfg for ready style code")
+
+ c = self.var_sched.num_steps // ddim_num_step
+ time_steps = torch.tensor(
+ np.asarray(list(range(0, self.var_sched.num_steps, c))) + 1
+ )
+ assert len(time_steps) == ddim_num_step
+ prev_time_steps = torch.cat((torch.tensor([0]), time_steps[:-1]))
+
+ batch_size, output_len = audio.shape[:2]
+ # batch_size = context.size(0)
+ context = {
+ "audio": audio,
+ "style_clip": style_clip,
+ "style_pad_mask": style_pad_mask,
+ "ready_style_code": ready_style_code,
+ }
+ if use_cf_guidance:
+ uncond_style_clip = self.null_style_clip.unsqueeze(0).repeat(
+ batch_size, 1, 1
+ )
+ uncond_pad_mask = self.null_pad_mask.unsqueeze(0).repeat(batch_size, 1)
+
+ context_double = {
+ "audio": torch.cat([audio] * 2, dim=0),
+ "style_clip": torch.cat([style_clip, uncond_style_clip], dim=0),
+ "style_pad_mask": torch.cat([style_pad_mask, uncond_pad_mask], dim=0),
+ "ready_style_code": None
+ if ready_style_code is None
+ else torch.cat(
+ [
+ ready_style_code,
+ self.net.style_encoder(uncond_style_clip, uncond_pad_mask),
+ ],
+ dim=0,
+ ),
+ }
+
+ x_t = torch.randn([batch_size, output_len, output_dim]).to(audio.device)
+
+ for idx in list(range(ddim_num_step))[::-1]:
+ t = time_steps[idx]
+ t_prev = prev_time_steps[idx]
+ ddim_alpha = self.var_sched.alpha_bars[t]
+ ddim_alpha_prev = self.var_sched.alpha_bars[t_prev]
+
+ t_tensor = torch.tensor([t] * batch_size).to(audio.device).float()
+ if use_cf_guidance:
+ x_t_double = torch.cat([x_t] * 2, dim=0)
+ t_tensor_double = torch.cat([t_tensor] * 2, dim=0)
+ cond_output, uncond_output = self.net(
+ x_t_double, t=t_tensor_double, **context_double
+ ).chunk(2)
+ diff_output = uncond_output + cfg_scale * (cond_output - uncond_output)
+ else:
+ diff_output = self.net(x_t, t=t_tensor, **context)
+
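+            # The network predicts x0 directly (PREDICT_WHAT == "x0"); recover the implied
+            # noise eps and take the deterministic DDIM step:
+            #   x_{t_prev} = sqrt(alpha_bar_prev) * x0 + sqrt(1 - alpha_bar_prev) * eps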
+ pred_x0 = diff_output
+ eps = (x_t - torch.sqrt(ddim_alpha) * pred_x0) / torch.sqrt(1 - ddim_alpha)
+ c1 = torch.sqrt(ddim_alpha_prev)
+ c2 = torch.sqrt(1 - ddim_alpha_prev)
+
+ x_t = c1 * pred_x0 + c2 * eps
+
+ latent_output = x_t
+ face3d_output = self._latent_to_face3d(latent_output)
+ return face3d_output
+
+ def sample(
+ self,
+ audio,
+ style_clip,
+ style_pad_mask,
+ output_dim,
+ flexibility=0.0,
+ ret_traj=False,
+ use_cf_guidance=False,
+ cfg_scale=2.0,
+ sample_method="ddpm",
+ ddim_num_step=50,
+ ready_style_code=None,
+ ):
+ # sample_method = kwargs["sample_method"]
+ if sample_method == "ddpm":
+ if ready_style_code is not None:
+ raise NotImplementedError("ready style code in ddpm")
+ return self.ddpm_sample(
+ audio,
+ style_clip,
+ style_pad_mask,
+ output_dim,
+ flexibility=flexibility,
+ ret_traj=ret_traj,
+ use_cf_guidance=use_cf_guidance,
+ cfg_scale=cfg_scale,
+ )
+ elif sample_method == "ddim":
+ return self.ddim_sample(
+ audio,
+ style_clip,
+ style_pad_mask,
+ output_dim,
+ flexibility=flexibility,
+ ret_traj=ret_traj,
+ use_cf_guidance=use_cf_guidance,
+ cfg_scale=cfg_scale,
+ ddim_num_step=ddim_num_step,
+ ready_style_code=ready_style_code,
+ )
+
+ def ddpm_sample(
+ self,
+ audio,
+ style_clip,
+ style_pad_mask,
+ output_dim,
+ flexibility=0.0,
+ ret_traj=False,
+ use_cf_guidance=False,
+ cfg_scale=2.0,
+ ):
+ """
+
+ Args:
+ audio (_type_): (B, L, W) or (B, L, W, C)
+ style_clip (_type_): (B, L_clipmax, C_face3d)
+ style_pad_mask : (B, L_clipmax)
+            output_dim (int): dimension of the generated face3d parameters
+ flexibility (float, optional): _description_. Defaults to 0.0.
+ ret_traj (bool, optional): _description_. Defaults to False.
+
+
+ Returns:
+ _type_: (B, L, C_face)
+ """
+ batch_size, output_len = audio.shape[:2]
+ # batch_size = context.size(0)
+ context = {
+ "audio": audio,
+ "style_clip": style_clip,
+ "style_pad_mask": style_pad_mask,
+ }
+ if use_cf_guidance:
+ uncond_style_clip = self.null_style_clip.unsqueeze(0).repeat(
+ batch_size, 1, 1
+ )
+ uncond_pad_mask = self.null_pad_mask.unsqueeze(0).repeat(batch_size, 1)
+ context_double = {
+ "audio": torch.cat([audio] * 2, dim=0),
+ "style_clip": torch.cat([style_clip, uncond_style_clip], dim=0),
+ "style_pad_mask": torch.cat([style_pad_mask, uncond_pad_mask], dim=0),
+ }
+
+ x_T = torch.randn([batch_size, output_len, output_dim]).to(audio.device)
+ traj = {self.var_sched.num_steps: x_T}
+ for t in range(self.var_sched.num_steps, 0, -1):
+ alpha = self.var_sched.alphas[t]
+ alpha_bar = self.var_sched.alpha_bars[t]
+ alpha_bar_prev = self.var_sched.alpha_bars[t - 1]
+ sigma = self.var_sched.get_sigmas(t, flexibility)
+
+ z = torch.randn_like(x_T) if t > 1 else torch.zeros_like(x_T)
+ x_t = traj[t]
+ t_tensor = torch.tensor([t] * batch_size).to(audio.device).float()
+ if use_cf_guidance:
+ x_t_double = torch.cat([x_t] * 2, dim=0)
+ t_tensor_double = torch.cat([t_tensor] * 2, dim=0)
+ cond_output, uncond_output = self.net(
+ x_t_double, t=t_tensor_double, **context_double
+ ).chunk(2)
+ diff_output = uncond_output + cfg_scale * (cond_output - uncond_output)
+ else:
+ diff_output = self.net(x_t, t=t_tensor, **context)
+
+            if self.predict_what == "noise":
+                # Standard DDPM update for an eps-prediction network.
+                c0 = 1.0 / torch.sqrt(alpha)
+                c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
+                x_next = c0 * (x_t - c1 * diff_output) + sigma * z
+            elif self.predict_what == "x0":
+                # Posterior mean of q(x_{t-1} | x_t, x0) when the network predicts x0
+                # (note that 1 - alpha equals beta_t), plus the sampled variance term.
+                d0 = torch.sqrt(alpha) * (1 - alpha_bar_prev) / (1 - alpha_bar)
+                d1 = torch.sqrt(alpha_bar_prev) * (1 - alpha) / (1 - alpha_bar)
+                x_next = d0 * x_t + d1 * diff_output + sigma * z
+ traj[t - 1] = x_next.detach()
+ traj[t] = traj[t].cpu()
+ if not ret_traj:
+ del traj[t]
+
+ if ret_traj:
+ raise NotImplementedError
+ return traj
+ else:
+ latent_output = traj[0]
+ face3d_output = self._latent_to_face3d(latent_output)
+ return face3d_output
+
+
+if __name__ == "__main__":
+    # Smoke test: build the network from the default config and run a short DDIM sampling pass.
+    from configs.default import get_cfg_defaults
+    from core.networks.diffusion_util import NoisePredictor, VarianceSchedule
+
+    cfg = get_cfg_defaults()
+    diffnet = DiffusionNet(
+        cfg=cfg,
+        net=NoisePredictor(cfg),
+        var_sched=VarianceSchedule(
+            num_steps=cfg.DIFFUSION.SCHEDULE.NUM_STEPS,
+            beta_1=cfg.DIFFUSION.SCHEDULE.BETA_1,
+            beta_T=cfg.DIFFUSION.SCHEDULE.BETA_T,
+            mode=cfg.DIFFUSION.SCHEDULE.MODE,
+        ),
+    )
+
+    audio = torch.randn(2, 64, 11, 1024)  # wav2vec features: (B, num_frames, window, C)
+    style_clip = torch.randn(2, 256, 64)
+    style_pad_mask = torch.zeros(2, 256, dtype=torch.bool)
+
+    with torch.no_grad():
+        face3d = diffnet.sample(
+            audio,
+            style_clip,
+            style_pad_mask,
+            output_dim=cfg.DATASET.FACE3D_DIM,
+            sample_method="ddim",
+            ddim_num_step=10,
+        )
+    print(face3d.shape)  # expected: (2, 64, 64)
diff --git a/damo/dreamtalk/core/networks/diffusion_util.py b/damo/dreamtalk/core/networks/diffusion_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..584866cbd790960c3f7e6d67478c612ef938a3cc
--- /dev/null
+++ b/damo/dreamtalk/core/networks/diffusion_util.py
@@ -0,0 +1,131 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn import Module
+from core.networks import get_network
+from core.utils import sinusoidal_embedding
+
+
+class VarianceSchedule(Module):
+ def __init__(self, num_steps, beta_1, beta_T, mode="linear"):
+ super().__init__()
+ assert mode in ("linear",)
+ self.num_steps = num_steps
+ self.beta_1 = beta_1
+ self.beta_T = beta_T
+ self.mode = mode
+
+ if mode == "linear":
+ betas = torch.linspace(beta_1, beta_T, steps=num_steps)
+
+ betas = torch.cat([torch.zeros([1]), betas], dim=0) # Padding
+
+ alphas = 1 - betas
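+        # alpha_bars[t] = prod_{s<=t} alphas[s]; accumulated in log space for numerical stability.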
+ log_alphas = torch.log(alphas)
+ for i in range(1, log_alphas.size(0)): # 1 to T
+ log_alphas[i] += log_alphas[i - 1]
+ alpha_bars = log_alphas.exp()
+
+ sigmas_flex = torch.sqrt(betas)
+ sigmas_inflex = torch.zeros_like(sigmas_flex)
+ for i in range(1, sigmas_flex.size(0)):
+ sigmas_inflex[i] = ((1 - alpha_bars[i - 1]) / (1 - alpha_bars[i])) * betas[
+ i
+ ]
+ sigmas_inflex = torch.sqrt(sigmas_inflex)
+
+ self.register_buffer("betas", betas)
+ self.register_buffer("alphas", alphas)
+ self.register_buffer("alpha_bars", alpha_bars)
+ self.register_buffer("sigmas_flex", sigmas_flex)
+ self.register_buffer("sigmas_inflex", sigmas_inflex)
+
+ def uniform_sample_t(self, batch_size):
+ ts = np.random.choice(np.arange(1, self.num_steps + 1), batch_size)
+ return ts.tolist()
+
+ def get_sigmas(self, t, flexibility):
+ assert 0 <= flexibility and flexibility <= 1
+ sigmas = self.sigmas_flex[t] * flexibility + self.sigmas_inflex[t] * (
+ 1 - flexibility
+ )
+ return sigmas
+
+
+class NoisePredictor(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+
+ content_encoder_class = get_network(cfg.CONTENT_ENCODER_TYPE)
+ self.content_encoder = content_encoder_class(**cfg.CONTENT_ENCODER)
+
+ style_encoder_class = get_network(cfg.STYLE_ENCODER_TYPE)
+ cfg.defrost()
+ cfg.STYLE_ENCODER.input_dim = cfg.DATASET.FACE3D_DIM
+ cfg.freeze()
+ self.style_encoder = style_encoder_class(**cfg.STYLE_ENCODER)
+
+ decoder_class = get_network(cfg.DECODER_TYPE)
+ cfg.defrost()
+ cfg.DECODER.output_dim = cfg.DATASET.FACE3D_DIM
+ cfg.freeze()
+ self.decoder = decoder_class(**cfg.DECODER)
+
+ self.content_xt_to_decoder_input_wo_time = nn.Sequential(
+ nn.Linear(cfg.D_MODEL + cfg.DATASET.FACE3D_DIM, cfg.D_MODEL),
+ nn.ReLU(),
+ nn.Linear(cfg.D_MODEL, cfg.D_MODEL),
+ nn.ReLU(),
+ nn.Linear(cfg.D_MODEL, cfg.D_MODEL),
+ )
+
+ self.time_sinusoidal_dim = cfg.D_MODEL
+ self.time_embed_net = nn.Sequential(
+ nn.Linear(cfg.D_MODEL, cfg.D_MODEL),
+ nn.SiLU(),
+ nn.Linear(cfg.D_MODEL, cfg.D_MODEL),
+ )
+
+ def forward(self, x_t, t, audio, style_clip, style_pad_mask, ready_style_code=None):
+ """_summary_
+
+ Args:
+ x_t (_type_): (B, L, C_face)
+ t (_type_): (B,) dtype:float32
+ audio (_type_): (B, L, W)
+ style_clip (_type_): (B, L_clipmax, C_face3d)
+ style_pad_mask : (B, L_clipmax)
+ ready_style_code: (B, C_model)
+ Returns:
+ e_theta : (B, L, C_face)
+ """
+ W = audio.shape[2]
+ content = self.content_encoder(audio)
+ # (B, L, W, C_model)
+ x_t_expand = x_t.unsqueeze(2).repeat(1, 1, W, 1)
+ # (B, L, C_face) -> (B, L, W, C_face)
+ content_xt_concat = torch.cat((content, x_t_expand), dim=3)
+ # (B, L, W, C_model+C_face)
+ decoder_input_without_time = self.content_xt_to_decoder_input_wo_time(
+ content_xt_concat
+ )
+ # (B, L, W, C_model)
+
+ time_sinusoidal = sinusoidal_embedding(t, self.time_sinusoidal_dim)
+ # (B, C_embed)
+ time_embedding = self.time_embed_net(time_sinusoidal)
+ # (B, C_model)
+ B, C = time_embedding.shape
+ time_embed_expand = time_embedding.view(B, 1, 1, C)
+ decoder_input = decoder_input_without_time + time_embed_expand
+ # (B, L, W, C_model)
+
+ if ready_style_code is not None:
+ style_code = ready_style_code
+ else:
+ style_code = self.style_encoder(style_clip, style_pad_mask)
+ # (B, C_model)
+
+ e_theta = self.decoder(decoder_input, style_code)
+ # (B, L, C_face)
+ return e_theta
diff --git a/damo/dreamtalk/core/networks/disentangle_decoder.py b/damo/dreamtalk/core/networks/disentangle_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..dab626a2cedd28444951ec0ab421b5a2a744d4ed
--- /dev/null
+++ b/damo/dreamtalk/core/networks/disentangle_decoder.py
@@ -0,0 +1,240 @@
+import torch
+from torch import nn
+
+from .transformer import (
+ PositionalEncoding,
+ TransformerDecoderLayer,
+ TransformerDecoder,
+)
+from core.networks.dynamic_fc_decoder import DynamicFCDecoderLayer, DynamicFCDecoder
+from core.utils import _reset_parameters
+
+
+def get_decoder_network(
+ network_type,
+ d_model,
+ nhead,
+ dim_feedforward,
+ dropout,
+ activation,
+ normalize_before,
+ num_decoder_layers,
+ return_intermediate_dec,
+ dynamic_K,
+ dynamic_ratio,
+):
+ decoder = None
+ if network_type == "TransformerDecoder":
+ decoder_layer = TransformerDecoderLayer(
+ d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+ )
+ norm = nn.LayerNorm(d_model)
+ decoder = TransformerDecoder(
+ decoder_layer,
+ num_decoder_layers,
+ norm,
+ return_intermediate_dec,
+ )
+ elif network_type == "DynamicFCDecoder":
+ d_style = d_model
+ decoder_layer = DynamicFCDecoderLayer(
+ d_model,
+ nhead,
+ d_style,
+ dynamic_K,
+ dynamic_ratio,
+ dim_feedforward,
+ dropout,
+ activation,
+ normalize_before,
+ )
+ norm = nn.LayerNorm(d_model)
+ decoder = DynamicFCDecoder(
+ decoder_layer, num_decoder_layers, norm, return_intermediate_dec
+ )
+ elif network_type == "DynamicFCEncoder":
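+        # Note: DynamicFCEncoderLayer / DynamicFCEncoder are not imported in this module;
+        # this branch assumes they are provided elsewhere in the codebase.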
+ d_style = d_model
+ decoder_layer = DynamicFCEncoderLayer(
+ d_model,
+ nhead,
+ d_style,
+ dynamic_K,
+ dynamic_ratio,
+ dim_feedforward,
+ dropout,
+ activation,
+ normalize_before,
+ )
+ norm = nn.LayerNorm(d_model)
+ decoder = DynamicFCEncoder(decoder_layer, num_decoder_layers, norm)
+
+ else:
+ raise ValueError(f"Invalid network_type {network_type}")
+
+ return decoder
+
+
+class DisentangleDecoder(nn.Module):
+ def __init__(
+ self,
+ d_model=512,
+ nhead=8,
+ num_decoder_layers=3,
+ dim_feedforward=2048,
+ dropout=0.1,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=False,
+ pos_embed_len=80,
+ upper_face3d_indices=tuple(list(range(19)) + list(range(46, 51))),
+ lower_face3d_indices=tuple(range(19, 46)),
+ network_type="None",
+ dynamic_K=None,
+ dynamic_ratio=None,
+ **_,
+ ) -> None:
+ super().__init__()
+
+ self.upper_face3d_indices = upper_face3d_indices
+ self.lower_face3d_indices = lower_face3d_indices
+
+ # upper_decoder_layer = TransformerDecoderLayer(
+ # d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+ # )
+ # upper_decoder_norm = nn.LayerNorm(d_model)
+ # self.upper_decoder = TransformerDecoder(
+ # upper_decoder_layer,
+ # num_decoder_layers,
+ # upper_decoder_norm,
+ # return_intermediate=return_intermediate_dec,
+ # )
+ self.upper_decoder = get_decoder_network(
+ network_type,
+ d_model,
+ nhead,
+ dim_feedforward,
+ dropout,
+ activation,
+ normalize_before,
+ num_decoder_layers,
+ return_intermediate_dec,
+ dynamic_K,
+ dynamic_ratio,
+ )
+ _reset_parameters(self.upper_decoder)
+
+ # lower_decoder_layer = TransformerDecoderLayer(
+ # d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+ # )
+ # lower_decoder_norm = nn.LayerNorm(d_model)
+ # self.lower_decoder = TransformerDecoder(
+ # lower_decoder_layer,
+ # num_decoder_layers,
+ # lower_decoder_norm,
+ # return_intermediate=return_intermediate_dec,
+ # )
+ self.lower_decoder = get_decoder_network(
+ network_type,
+ d_model,
+ nhead,
+ dim_feedforward,
+ dropout,
+ activation,
+ normalize_before,
+ num_decoder_layers,
+ return_intermediate_dec,
+ dynamic_K,
+ dynamic_ratio,
+ )
+ _reset_parameters(self.lower_decoder)
+
+ self.pos_embed = PositionalEncoding(d_model, pos_embed_len)
+
+ tail_hidden_dim = d_model // 2
+ self.upper_tail_fc = nn.Sequential(
+ nn.Linear(d_model, tail_hidden_dim),
+ nn.ReLU(),
+ nn.Linear(tail_hidden_dim, tail_hidden_dim),
+ nn.ReLU(),
+ nn.Linear(tail_hidden_dim, len(upper_face3d_indices)),
+ )
+ self.lower_tail_fc = nn.Sequential(
+ nn.Linear(d_model, tail_hidden_dim),
+ nn.ReLU(),
+ nn.Linear(tail_hidden_dim, tail_hidden_dim),
+ nn.ReLU(),
+ nn.Linear(tail_hidden_dim, len(lower_face3d_indices)),
+ )
+
+ def forward(self, content, style_code):
+ """
+
+ Args:
+ content (_type_): (B, num_frames, window, C_dmodel)
+ style_code (_type_): (B, C_dmodel)
+
+ Returns:
+ face3d: (B, L_clip, C_3dmm)
+ """
+ B, N, W, C = content.shape
+ style = style_code.reshape(B, 1, 1, C).expand(B, N, W, C)
+ style = style.permute(2, 0, 1, 3).reshape(W, B * N, C)
+ # (W, B*N, C)
+
+ content = content.permute(2, 0, 1, 3).reshape(W, B * N, C)
+ # (W, B*N, C)
+ tgt = torch.zeros_like(style)
+ pos_embed = self.pos_embed(W)
+ pos_embed = pos_embed.permute(1, 0, 2)
+
+ upper_face3d_feat = self.upper_decoder(
+ tgt, content, pos=pos_embed, query_pos=style
+ )[0]
+ # (W, B*N, C)
+ upper_face3d_feat = upper_face3d_feat.permute(1, 0, 2).reshape(B, N, W, C)[
+ :, :, W // 2, :
+ ]
+ # (B, N, C)
+ upper_face3d = self.upper_tail_fc(upper_face3d_feat)
+ # (B, N, C_exp)
+
+ lower_face3d_feat = self.lower_decoder(
+ tgt, content, pos=pos_embed, query_pos=style
+ )[0]
+ lower_face3d_feat = lower_face3d_feat.permute(1, 0, 2).reshape(B, N, W, C)[
+ :, :, W // 2, :
+ ]
+ lower_face3d = self.lower_tail_fc(lower_face3d_feat)
+ C_exp = len(self.upper_face3d_indices) + len(self.lower_face3d_indices)
+ face3d = torch.zeros(B, N, C_exp).to(upper_face3d)
+ face3d[:, :, self.upper_face3d_indices] = upper_face3d
+ face3d[:, :, self.lower_face3d_indices] = lower_face3d
+ return face3d
+
+
+if __name__ == "__main__":
+    # Smoke test with the default (BFM basis) configuration; run from the repo root,
+    # e.g. `python -m core.networks.disentangle_decoder`.
+    from configs.default import get_cfg_defaults
+
+    cfg = get_cfg_defaults()
+    cfg.freeze()
+
+    decoder = DisentangleDecoder(**cfg.DECODER)
+    dummy_content = torch.randn(5, 64, 11, 256)
+    dummy_style = torch.randn(5, 256)
+    dummy_output = decoder(dummy_content, dummy_style)
+
+    print(dummy_output.shape)  # expected: (5, 64, 64)
diff --git a/damo/dreamtalk/core/networks/dynamic_conv.py b/damo/dreamtalk/core/networks/dynamic_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1b836406e4041b12c21e5defdf74841178cbe00
--- /dev/null
+++ b/damo/dreamtalk/core/networks/dynamic_conv.py
@@ -0,0 +1,156 @@
+import math
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+class Attention(nn.Module):
+ def __init__(self, cond_planes, ratio, K, temperature=30, init_weight=True):
+ super().__init__()
+ # self.avgpool = nn.AdaptiveAvgPool2d(1)
+ self.temprature = temperature
+ assert cond_planes > ratio
+ hidden_planes = cond_planes // ratio
+ self.net = nn.Sequential(
+ nn.Conv2d(cond_planes, hidden_planes, kernel_size=1, bias=False),
+ nn.ReLU(),
+ nn.Conv2d(hidden_planes, K, kernel_size=1, bias=False),
+ )
+
+ if init_weight:
+ self._initialize_weights()
+
+ def update_temprature(self):
+ if self.temprature > 1:
+ self.temprature -= 1
+
+ def _initialize_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ if isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, cond):
+ """
+
+ Args:
+ cond (_type_): (B, C_style)
+
+ Returns:
+ _type_: (B, K)
+ """
+
+ # att = self.avgpool(cond) # bs,dim,1,1
+ att = cond.view(cond.shape[0], cond.shape[1], 1, 1)
+ att = self.net(att).view(cond.shape[0], -1) # bs,K
+ return F.softmax(att / self.temprature, -1)
+
+
+class DynamicConv(nn.Module):
+ def __init__(
+ self,
+ in_planes,
+ out_planes,
+ cond_planes,
+ kernel_size,
+ stride,
+ padding=0,
+ dilation=1,
+ groups=1,
+ bias=True,
+ K=4,
+ temperature=30,
+ ratio=4,
+ init_weight=True,
+ ):
+ super().__init__()
+ self.in_planes = in_planes
+ self.out_planes = out_planes
+ self.cond_planes = cond_planes
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.padding = padding
+ self.dilation = dilation
+ self.groups = groups
+ self.bias = bias
+ self.K = K
+ self.init_weight = init_weight
+ self.attention = Attention(
+ cond_planes=cond_planes, ratio=ratio, K=K, temperature=temperature, init_weight=init_weight
+ )
+
+ self.weight = nn.Parameter(
+ torch.randn(K, out_planes, in_planes // groups, kernel_size, kernel_size), requires_grad=True
+ )
+ if bias:
+ self.bias = nn.Parameter(torch.randn(K, out_planes), requires_grad=True)
+ else:
+ self.bias = None
+
+ if self.init_weight:
+ self._initialize_weights()
+
+ def _initialize_weights(self):
+ for i in range(self.K):
+ nn.init.kaiming_uniform_(self.weight[i], a=math.sqrt(5))
+ if self.bias is not None:
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[i])
+ if fan_in != 0:
+ bound = 1 / math.sqrt(fan_in)
+ nn.init.uniform_(self.bias, -bound, bound)
+
+ def forward(self, x, cond):
+ """
+
+ Args:
+ x (_type_): (B, C_in, L, 1)
+ cond (_type_): (B, C_style)
+
+ Returns:
+ _type_: (B, C_out, L, 1)
+ """
+ bs, in_planels, h, w = x.shape
+ softmax_att = self.attention(cond) # bs,K
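+        # Grouped-convolution trick: fold the batch into the channel dimension and use
+        # groups = self.groups * bs so each sample is convolved with its own attention-weighted kernel.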
+ x = x.view(1, -1, h, w)
+ weight = self.weight.view(self.K, -1) # K,-1
+ aggregate_weight = torch.mm(softmax_att, weight).view(
+ bs * self.out_planes, self.in_planes // self.groups, self.kernel_size, self.kernel_size
+ ) # bs*out_p,in_p,k,k
+
+ if self.bias is not None:
+ bias = self.bias.view(self.K, -1) # K,out_p
+ aggregate_bias = torch.mm(softmax_att, bias).view(-1) # bs*out_p
+ output = F.conv2d(
+ x, # 1, bs*in_p, L, 1
+ weight=aggregate_weight,
+ bias=aggregate_bias,
+ stride=self.stride,
+ padding=self.padding,
+ groups=self.groups * bs,
+ dilation=self.dilation,
+ )
+ else:
+ output = F.conv2d(
+ x,
+ weight=aggregate_weight,
+ bias=None,
+ stride=self.stride,
+ padding=self.padding,
+ groups=self.groups * bs,
+ dilation=self.dilation,
+ )
+
+ output = output.view(bs, self.out_planes, h, w)
+ return output
+
+
+if __name__ == "__main__":
+    # Smoke test: a convolution whose kernels are aggregated per sample from a condition vector.
+    input = torch.randn(3, 32, 64, 64)
+    cond = torch.randn(3, 128)
+    m = DynamicConv(in_planes=32, out_planes=64, cond_planes=128, kernel_size=3, stride=1, padding=1, bias=True)
+    out = m(input, cond)
+    print(out.shape)  # expected: (3, 64, 64, 64)
diff --git a/damo/dreamtalk/core/networks/dynamic_fc_decoder.py b/damo/dreamtalk/core/networks/dynamic_fc_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..7eee68bdc77bde90527540b7240c5403c9026fc0
--- /dev/null
+++ b/damo/dreamtalk/core/networks/dynamic_fc_decoder.py
@@ -0,0 +1,178 @@
+import torch.nn as nn
+import torch
+
+from core.networks.transformer import _get_activation_fn, _get_clones
+from core.networks.dynamic_linear import DynamicLinear
+
+
+class DynamicFCDecoderLayer(nn.Module):
+ def __init__(
+ self,
+ d_model,
+ nhead,
+ d_style,
+ dynamic_K,
+ dynamic_ratio,
+ dim_feedforward=2048,
+ dropout=0.1,
+ activation="relu",
+ normalize_before=False,
+ ):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ # Implementation of Feedforward model
+ # self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.linear1 = DynamicLinear(d_model, dim_feedforward, d_style, K=dynamic_K, ratio=dynamic_ratio)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+ # self.linear2 = DynamicLinear(dim_feedforward, d_model, d_style, K=dynamic_K, ratio=dynamic_ratio)
+
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ self.norm3 = nn.LayerNorm(d_model)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+ self.dropout3 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+ self.normalize_before = normalize_before
+
+ def with_pos_embed(self, tensor, pos):
+ return tensor if pos is None else tensor + pos
+
+ def forward_post(
+ self,
+ tgt,
+ memory,
+ style,
+ tgt_mask=None,
+ memory_mask=None,
+ tgt_key_padding_mask=None,
+ memory_key_padding_mask=None,
+ pos=None,
+ query_pos=None,
+ ):
+ # q = k = self.with_pos_embed(tgt, query_pos)
+ tgt2 = self.self_attn(tgt, tgt, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
+ tgt = tgt + self.dropout1(tgt2)
+ tgt = self.norm1(tgt)
+ tgt2 = self.multihead_attn(
+ query=tgt, key=memory, value=memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask
+ )[0]
+ tgt = tgt + self.dropout2(tgt2)
+ tgt = self.norm2(tgt)
+ # tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt, style))), style)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt, style))))
+ tgt = tgt + self.dropout3(tgt2)
+ tgt = self.norm3(tgt)
+ return tgt
+
+ # def forward_pre(
+ # self,
+ # tgt,
+ # memory,
+ # tgt_mask=None,
+ # memory_mask=None,
+ # tgt_key_padding_mask=None,
+ # memory_key_padding_mask=None,
+ # pos=None,
+ # query_pos=None,
+ # ):
+ # tgt2 = self.norm1(tgt)
+ # # q = k = self.with_pos_embed(tgt2, query_pos)
+ # tgt2 = self.self_attn(tgt2, tgt2, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
+ # tgt = tgt + self.dropout1(tgt2)
+ # tgt2 = self.norm2(tgt)
+ # tgt2 = self.multihead_attn(
+ # query=tgt2, key=memory, value=memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask
+ # )[0]
+ # tgt = tgt + self.dropout2(tgt2)
+ # tgt2 = self.norm3(tgt)
+ # tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+ # tgt = tgt + self.dropout3(tgt2)
+ # return tgt
+
+ def forward(
+ self,
+ tgt,
+ memory,
+ style,
+ tgt_mask=None,
+ memory_mask=None,
+ tgt_key_padding_mask=None,
+ memory_key_padding_mask=None,
+ pos=None,
+ query_pos=None,
+ ):
+ if self.normalize_before:
+ raise NotImplementedError
+ # return self.forward_pre(
+ # tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos
+ # )
+ return self.forward_post(
+ tgt, memory, style, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos
+ )
+
+
+class DynamicFCDecoder(nn.Module):
+ def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
+ super().__init__()
+ self.layers = _get_clones(decoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.norm = norm
+ self.return_intermediate = return_intermediate
+
+ def forward(
+ self,
+ tgt,
+ memory,
+ tgt_mask=None,
+ memory_mask=None,
+ tgt_key_padding_mask=None,
+ memory_key_padding_mask=None,
+ pos=None,
+ query_pos=None,
+ ):
+ style = query_pos[0]
+ # (B*N, C)
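+        # query_pos carries the style code repeated along the window axis, so its first
+        # step is the (B*N, C) style vector used to condition the dynamic FC layers.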
+ output = tgt + pos + query_pos
+
+ intermediate = []
+
+ for layer in self.layers:
+ output = layer(
+ output,
+ memory,
+ style,
+ tgt_mask=tgt_mask,
+ memory_mask=memory_mask,
+ tgt_key_padding_mask=tgt_key_padding_mask,
+ memory_key_padding_mask=memory_key_padding_mask,
+ pos=pos,
+ query_pos=query_pos,
+ )
+ if self.return_intermediate:
+ intermediate.append(self.norm(output))
+
+ if self.norm is not None:
+ output = self.norm(output)
+ if self.return_intermediate:
+ intermediate.pop()
+ intermediate.append(output)
+
+ if self.return_intermediate:
+ return torch.stack(intermediate)
+
+ return output.unsqueeze(0)
+
+
+if __name__ == "__main__":
+ query = torch.randn(11, 1024, 256)
+ content = torch.randn(11, 1024, 256)
+ style = torch.randn(1024, 256)
+ pos = torch.randn(11, 1, 256)
+ m = DynamicFCDecoderLayer(256, 4, 256, 4, 4, 1024)
+
+ out = m(query, content, style, pos=pos)
+ print(out.shape)
diff --git a/damo/dreamtalk/core/networks/dynamic_linear.py b/damo/dreamtalk/core/networks/dynamic_linear.py
new file mode 100644
index 0000000000000000000000000000000000000000..32b35ac7e7845d4250b5d56a9c41affb61e7e4da
--- /dev/null
+++ b/damo/dreamtalk/core/networks/dynamic_linear.py
@@ -0,0 +1,50 @@
+import math
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from core.networks.dynamic_conv import DynamicConv
+
+
+class DynamicLinear(nn.Module):
+ def __init__(self, in_planes, out_planes, cond_planes, bias=True, K=4, temperature=30, ratio=4, init_weight=True):
+ super().__init__()
+
+ self.dynamic_conv = DynamicConv(
+ in_planes,
+ out_planes,
+ cond_planes,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=bias,
+ K=K,
+ ratio=ratio,
+ temperature=temperature,
+ init_weight=init_weight,
+ )
+
+ def forward(self, x, cond):
+ """
+
+ Args:
+ x (_type_): (L, B, C_in)
+ cond (_type_): (B, C_style)
+
+ Returns:
+ _type_: (L, B, C_out)
+ """
+ x = x.permute(1, 2, 0).unsqueeze(-1)
+ out = self.dynamic_conv(x, cond)
+ # (B, C_out, L, 1)
+        out = out.squeeze(-1).permute(2, 0, 1)  # squeeze only the trailing dim so size-1 batch/length dims survive
+ return out
+
+
+if __name__ == "__main__":
+ input = torch.randn(11, 1024, 255)
+ cond = torch.randn(1024, 256)
+ m = DynamicLinear(255, 1000, 256, K=7, temperature=5, ratio=8)
+ out = m(input, cond)
+ print(out.shape)
diff --git a/damo/dreamtalk/core/networks/generator.py b/damo/dreamtalk/core/networks/generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..4cd33c17a405906f4ae825a9015f6920a22d6c29
--- /dev/null
+++ b/damo/dreamtalk/core/networks/generator.py
@@ -0,0 +1,309 @@
+import torch
+from torch import nn
+
+from .transformer import (
+ TransformerEncoder,
+ TransformerEncoderLayer,
+ PositionalEncoding,
+ TransformerDecoderLayer,
+ TransformerDecoder,
+)
+from core.utils import _reset_parameters
+from core.networks.self_attention_pooling import SelfAttentionPooling
+
+
+# class ContentEncoder(nn.Module):
+# def __init__(
+# self,
+# d_model=512,
+# nhead=8,
+# num_encoder_layers=6,
+# dim_feedforward=2048,
+# dropout=0.1,
+# activation="relu",
+# normalize_before=False,
+# pos_embed_len=80,
+# ph_embed_dim=128,
+# ):
+# super().__init__()
+
+# encoder_layer = TransformerEncoderLayer(
+# d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+# )
+# encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
+# self.encoder = TransformerEncoder(
+# encoder_layer, num_encoder_layers, encoder_norm
+# )
+
+# _reset_parameters(self.encoder)
+
+# self.pos_embed = PositionalEncoding(d_model, pos_embed_len)
+
+# self.ph_embedding = nn.Embedding(41, ph_embed_dim)
+# self.increase_embed_dim = nn.Linear(ph_embed_dim, d_model)
+
+# def forward(self, x):
+# """
+
+# Args:
+# x (_type_): (B, num_frames, window)
+
+# Returns:
+# content: (B, num_frames, window, C_dmodel)
+# """
+# x_embedding = self.ph_embedding(x)
+# x_embedding = self.increase_embed_dim(x_embedding)
+# # (B, N, W, C)
+# B, N, W, C = x_embedding.shape
+# x_embedding = x_embedding.reshape(B * N, W, C)
+# x_embedding = x_embedding.permute(1, 0, 2)
+# # (W, B*N, C)
+
+# pos = self.pos_embed(W)
+# pos = pos.permute(1, 0, 2)
+# # (W, 1, C)
+
+# content = self.encoder(x_embedding, pos=pos)
+# # (W, B*N, C)
+# content = content.permute(1, 0, 2).reshape(B, N, W, C)
+# # (B, N, W, C)
+
+# return content
+
+
+class ContentW2VEncoder(nn.Module):
+ def __init__(
+ self,
+ d_model=512,
+ nhead=8,
+ num_encoder_layers=6,
+ dim_feedforward=2048,
+ dropout=0.1,
+ activation="relu",
+ normalize_before=False,
+ pos_embed_len=80,
+ ph_embed_dim=128,
+ ):
+ super().__init__()
+
+ encoder_layer = TransformerEncoderLayer(
+ d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+ )
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
+ self.encoder = TransformerEncoder(
+ encoder_layer, num_encoder_layers, encoder_norm
+ )
+
+ _reset_parameters(self.encoder)
+
+ self.pos_embed = PositionalEncoding(d_model, pos_embed_len)
+
+ self.increase_embed_dim = nn.Linear(1024, d_model)
+
+ def forward(self, x):
+ """
+ Args:
+ x (_type_): (B, num_frames, window, C_wav2vec)
+
+ Returns:
+ content: (B, num_frames, window, C_dmodel)
+ """
+ x_embedding = self.increase_embed_dim(
+ x
+ ) # [16, 64, 11, 1024] -> [16, 64, 11, 256]
+ # (B, N, W, C)
+ B, N, W, C = x_embedding.shape
+ x_embedding = x_embedding.reshape(B * N, W, C)
+ x_embedding = x_embedding.permute(1, 0, 2) # [11, 1024, 256]
+ # (W, B*N, C)
+
+ pos = self.pos_embed(W)
+ pos = pos.permute(1, 0, 2) # [11, 1, 256]
+ # (W, 1, C)
+
+ content = self.encoder(x_embedding, pos=pos) # [11, 1024, 256]
+ # (W, B*N, C)
+ content = content.permute(1, 0, 2).reshape(B, N, W, C)
+ # (B, N, W, C)
+
+ return content
+
+
+class StyleEncoder(nn.Module):
+ def __init__(
+ self,
+ d_model=512,
+ nhead=8,
+ num_encoder_layers=6,
+ dim_feedforward=2048,
+ dropout=0.1,
+ activation="relu",
+ normalize_before=False,
+ pos_embed_len=80,
+ input_dim=128,
+ aggregate_method="average",
+ ):
+ super().__init__()
+ encoder_layer = TransformerEncoderLayer(
+ d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+ )
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
+ self.encoder = TransformerEncoder(
+ encoder_layer, num_encoder_layers, encoder_norm
+ )
+ _reset_parameters(self.encoder)
+
+ self.pos_embed = PositionalEncoding(d_model, pos_embed_len)
+
+ self.increase_embed_dim = nn.Linear(input_dim, d_model)
+
+ self.aggregate_method = None
+ if aggregate_method == "self_attention_pooling":
+ self.aggregate_method = SelfAttentionPooling(d_model)
+ elif aggregate_method == "average":
+ pass
+ else:
+ raise ValueError(f"Invalid aggregate method {aggregate_method}")
+
+ def forward(self, x, pad_mask=None):
+ """
+
+ Args:
+ x (_type_): (B, num_frames(L), C_exp)
+ pad_mask: (B, num_frames)
+
+ Returns:
+ style_code: (B, C_model)
+ """
+ x = self.increase_embed_dim(x)
+ # (B, L, C)
+ x = x.permute(1, 0, 2)
+ # (L, B, C)
+
+ pos = self.pos_embed(x.shape[0])
+ pos = pos.permute(1, 0, 2)
+ # (L, 1, C)
+
+ style = self.encoder(x, pos=pos, src_key_padding_mask=pad_mask)
+ # (L, B, C)
+
+ if self.aggregate_method is not None:
+ permute_style = style.permute(1, 0, 2)
+ # (B, L, C)
+ style_code = self.aggregate_method(permute_style, pad_mask)
+ return style_code
+
+ if pad_mask is None:
+ style = style.permute(1, 2, 0)
+ # (B, C, L)
+ style_code = style.mean(2)
+ # (B, C)
+ else:
+ permute_style = style.permute(1, 0, 2)
+ # (B, L, C)
+ permute_style[pad_mask] = 0
+ sum_style_code = permute_style.sum(dim=1)
+ # (B, C)
+ valid_token_num = (~pad_mask).sum(dim=1).unsqueeze(-1)
+ # (B, 1)
+ style_code = sum_style_code / valid_token_num
+ # (B, C)
+
+ return style_code
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ d_model=512,
+ nhead=8,
+ num_decoder_layers=3,
+ dim_feedforward=2048,
+ dropout=0.1,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=False,
+ pos_embed_len=80,
+ output_dim=64,
+ **_,
+ ) -> None:
+ super().__init__()
+
+ decoder_layer = TransformerDecoderLayer(
+ d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+ )
+ decoder_norm = nn.LayerNorm(d_model)
+ self.decoder = TransformerDecoder(
+ decoder_layer,
+ num_decoder_layers,
+ decoder_norm,
+ return_intermediate=return_intermediate_dec,
+ )
+ _reset_parameters(self.decoder)
+
+ self.pos_embed = PositionalEncoding(d_model, pos_embed_len)
+
+ tail_hidden_dim = d_model // 2
+ self.tail_fc = nn.Sequential(
+ nn.Linear(d_model, tail_hidden_dim),
+ nn.ReLU(),
+ nn.Linear(tail_hidden_dim, tail_hidden_dim),
+ nn.ReLU(),
+ nn.Linear(tail_hidden_dim, output_dim),
+ )
+
+ def forward(self, content, style_code):
+ """
+
+ Args:
+ content (_type_): (B, num_frames, window, C_dmodel)
+ style_code (_type_): (B, C_dmodel)
+
+ Returns:
+ face3d: (B, num_frames, C_3dmm)
+ """
+ B, N, W, C = content.shape
+ style = style_code.reshape(B, 1, 1, C).expand(B, N, W, C)
+ style = style.permute(2, 0, 1, 3).reshape(W, B * N, C)
+ # (W, B*N, C)
+
+ content = content.permute(2, 0, 1, 3).reshape(W, B * N, C)
+ # (W, B*N, C)
+ tgt = torch.zeros_like(style)
+ pos_embed = self.pos_embed(W)
+ pos_embed = pos_embed.permute(1, 0, 2)
+ face3d_feat = self.decoder(tgt, content, pos=pos_embed, query_pos=style)[0]
+ # (W, B*N, C)
+ face3d_feat = face3d_feat.permute(1, 0, 2).reshape(B, N, W, C)[:, :, W // 2, :]
+ # (B, N, C)
+ face3d = self.tail_fc(face3d_feat)
+ # (B, N, C_exp)
+ return face3d
+
+
+if __name__ == "__main__":
+    # Smoke test with the default configuration; run from the repo root,
+    # e.g. `python -m core.networks.generator`.
+    from configs.default import get_cfg_defaults
+
+    cfg = get_cfg_defaults()
+    cfg.freeze()
+
+    decoder = Decoder(**cfg.DECODER)
+    dummy_content = torch.randn(5, 64, 11, 256)
+    dummy_style = torch.randn(5, 256)
+    dummy_output = decoder(dummy_content, dummy_style)
+
+    print(dummy_output.shape)  # expected: (5, 64, 64)
diff --git a/damo/dreamtalk/core/networks/mish.py b/damo/dreamtalk/core/networks/mish.py
new file mode 100644
index 0000000000000000000000000000000000000000..607b95d33edd40bb53f93682bdcd9e0ff31ffbe4
--- /dev/null
+++ b/damo/dreamtalk/core/networks/mish.py
@@ -0,0 +1,51 @@
+"""
+Applies the mish function element-wise:
+mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
+"""
+
+# import pytorch
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+@torch.jit.script
+def mish(input):
+ """
+ Applies the mish function element-wise:
+ mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
+ See additional documentation for mish class.
+ """
+ return input * torch.tanh(F.softplus(input))
+
+class Mish(nn.Module):
+ """
+ Applies the mish function element-wise:
+ mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
+
+ Shape:
+ - Input: (N, *) where * means, any number of additional
+ dimensions
+ - Output: (N, *), same shape as the input
+
+ Examples:
+ >>> m = Mish()
+ >>> input = torch.randn(2)
+ >>> output = m(input)
+
+ Reference: https://pytorch.org/docs/stable/generated/torch.nn.Mish.html
+ """
+
+ def __init__(self):
+ """
+ Init method.
+ """
+ super().__init__()
+
+ def forward(self, input):
+ """
+ Forward pass of the function.
+ """
+        # F.mish is available in recent PyTorch releases; comparing torch.__version__ as a
+        # string is unreliable (e.g. "1.10" < "1.9"), so check for the function directly.
+        if hasattr(F, "mish"):
+            return F.mish(input)
+        else:
+            return mish(input)
\ No newline at end of file
diff --git a/damo/dreamtalk/core/networks/self_attention_pooling.py b/damo/dreamtalk/core/networks/self_attention_pooling.py
new file mode 100644
index 0000000000000000000000000000000000000000..f93f1791f57092b704d0547c0402a80cb579c7a3
--- /dev/null
+++ b/damo/dreamtalk/core/networks/self_attention_pooling.py
@@ -0,0 +1,53 @@
+import torch
+import torch.nn as nn
+from core.networks.mish import Mish
+
+
+class SelfAttentionPooling(nn.Module):
+ """
+ Implementation of SelfAttentionPooling
+ Original Paper: Self-Attention Encoding and Pooling for Speaker Recognition
+ https://arxiv.org/pdf/2008.01077v1.pdf
+ """
+
+ def __init__(self, input_dim):
+ super(SelfAttentionPooling, self).__init__()
+ self.W = nn.Sequential(nn.Linear(input_dim, input_dim), Mish(), nn.Linear(input_dim, 1))
+ self.softmax = nn.functional.softmax
+
+ def forward(self, batch_rep, att_mask=None):
+ """
+ N: batch size, T: sequence length, H: Hidden dimension
+ input:
+ batch_rep : size (N, T, H)
+ attention_weight:
+ att_w : size (N, T, 1)
+ att_mask:
+ att_mask: size (N, T): if True, mask this item.
+ return:
+ utter_rep: size (N, H)
+ """
+
+ att_logits = self.W(batch_rep).squeeze(-1)
+ # (N, T)
+ if att_mask is not None:
+ att_mask_logits = att_mask.to(dtype=batch_rep.dtype) * -100000.0
+ # (N, T)
+ att_logits = att_mask_logits + att_logits
+
+ att_w = self.softmax(att_logits, dim=-1).unsqueeze(-1)
+ utter_rep = torch.sum(batch_rep * att_w, dim=1)
+
+ return utter_rep
+
+
+if __name__ == "__main__":
+ batch = torch.randn(8, 64, 256)
+ self_attn_pool = SelfAttentionPooling(256)
+ att_mask = torch.zeros(8, 64)
+ att_mask[:, 60:] = 1
+ att_mask = att_mask.to(torch.bool)
+ output = self_attn_pool(batch, att_mask)
+ # (8, 256)
+
+ print("hello")
diff --git a/damo/dreamtalk/core/networks/transformer.py b/damo/dreamtalk/core/networks/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..499cbcd65a6c211d7c8fde1d202e58ccbfb2c5a0
--- /dev/null
+++ b/damo/dreamtalk/core/networks/transformer.py
@@ -0,0 +1,293 @@
+import torch.nn as nn
+import torch
+import numpy as np
+import torch.nn.functional as F
+import copy
+
+
+class PositionalEncoding(nn.Module):
+
+ def __init__(self, d_hid, n_position=200):
+ super(PositionalEncoding, self).__init__()
+
+ # Not a parameter
+ self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))
+
+ def _get_sinusoid_encoding_table(self, n_position, d_hid):
+ ''' Sinusoid position encoding table '''
+ # TODO: make it with torch instead of numpy
+
+ def get_position_angle_vec(position):
+ return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
+
+ sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
+ sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
+ sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
+
+ return torch.FloatTensor(sinusoid_table).unsqueeze(0)
+
+ def forward(self, winsize):
+ return self.pos_table[:, :winsize].clone().detach()
+
+def _get_activation_fn(activation):
+ """Return an activation function given a string"""
+ if activation == "relu":
+ return F.relu
+ if activation == "gelu":
+ return F.gelu
+ if activation == "glu":
+ return F.glu
+    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
+
+def _get_clones(module, N):
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+class Transformer(nn.Module):
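+    """Encoder-decoder transformer with optional pre-norm layers (normalize_before)
+    and optional intermediate decoder outputs (return_intermediate_dec)."""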
+
+ def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
+ num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
+ activation="relu", normalize_before=False,
+ return_intermediate_dec=True):
+ super().__init__()
+
+ encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
+ dropout, activation, normalize_before)
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
+ self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
+
+ decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
+ dropout, activation, normalize_before)
+ decoder_norm = nn.LayerNorm(d_model)
+ self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
+ return_intermediate=return_intermediate_dec)
+
+ self._reset_parameters()
+
+ self.d_model = d_model
+ self.nhead = nhead
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+    def forward(self, opt, src, query_embed, pos_embed):
+        # `opt` is unused here but kept for interface compatibility.
+        # inputs arrive as (N, T, C); nn.MultiheadAttention expects (T, N, C)
+
+ src = src.permute(1, 0, 2)
+ pos_embed = pos_embed.permute(1, 0, 2)
+ query_embed = query_embed.permute(1, 0, 2)
+
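+        # the decoder input starts as zeros; its content comes entirely from the
+        # positional and query embeddings added in TransformerDecoder.forward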
+ tgt = torch.zeros_like(query_embed)
+ memory = self.encoder(src, pos=pos_embed)
+
+ hs = self.decoder(tgt, memory,
+ pos=pos_embed, query_pos=query_embed)
+ return hs
+
+
+class TransformerEncoder(nn.Module):
+
+ def __init__(self, encoder_layer, num_layers, norm=None):
+ super().__init__()
+ self.layers = _get_clones(encoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.norm = norm
+
+ def forward(self, src, mask = None, src_key_padding_mask = None, pos = None):
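+        # positional embeddings are added once to the input here; the per-layer
+        # with_pos_embed calls inside the encoder layers are disabled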
+ output = src+pos
+
+ for layer in self.layers:
+ output = layer(output, src_mask=mask,
+ src_key_padding_mask=src_key_padding_mask, pos=pos)
+
+ if self.norm is not None:
+ output = self.norm(output)
+
+ return output
+
+
+class TransformerDecoder(nn.Module):
+
+ def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
+ super().__init__()
+ self.layers = _get_clones(decoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.norm = norm
+ self.return_intermediate = return_intermediate
+
+ def forward(self, tgt, memory, tgt_mask = None, memory_mask = None, tgt_key_padding_mask = None,
+ memory_key_padding_mask = None,
+ pos = None,
+ query_pos = None):
+ output = tgt+pos+query_pos
+
+ intermediate = []
+
+ for layer in self.layers:
+ output = layer(output, memory, tgt_mask=tgt_mask,
+ memory_mask=memory_mask,
+ tgt_key_padding_mask=tgt_key_padding_mask,
+ memory_key_padding_mask=memory_key_padding_mask,
+ pos=pos, query_pos=query_pos)
+ if self.return_intermediate:
+ intermediate.append(self.norm(output))
+
+ if self.norm is not None:
+ output = self.norm(output)
+ if self.return_intermediate:
+ intermediate.pop()
+ intermediate.append(output)
+
+ if self.return_intermediate:
+ return torch.stack(intermediate)
+
+ return output.unsqueeze(0)
+
+
+class TransformerEncoderLayer(nn.Module):
+
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
+ activation="relu", normalize_before=False):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+ self.normalize_before = normalize_before
+
+ def with_pos_embed(self, tensor, pos):
+ return tensor if pos is None else tensor + pos
+
+ def forward_post(self,
+ src,
+ src_mask = None,
+ src_key_padding_mask = None,
+ pos = None):
+ # q = k = self.with_pos_embed(src, pos)
+ src2 = self.self_attn(src, src, value=src, attn_mask=src_mask,
+ key_padding_mask=src_key_padding_mask)[0]
+ src = src + self.dropout1(src2)
+ src = self.norm1(src)
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+ src = src + self.dropout2(src2)
+ src = self.norm2(src)
+ return src
+
+ def forward_pre(self, src,
+ src_mask = None,
+ src_key_padding_mask = None,
+ pos = None):
+ src2 = self.norm1(src)
+ # q = k = self.with_pos_embed(src2, pos)
+ src2 = self.self_attn(src2, src2, value=src2, attn_mask=src_mask,
+ key_padding_mask=src_key_padding_mask)[0]
+ src = src + self.dropout1(src2)
+ src2 = self.norm2(src)
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
+ src = src + self.dropout2(src2)
+ return src
+
+ def forward(self, src,
+ src_mask = None,
+ src_key_padding_mask = None,
+ pos = None):
+ if self.normalize_before:
+ return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
+ return self.forward_post(src, src_mask, src_key_padding_mask, pos)
+
+
+class TransformerDecoderLayer(nn.Module):
+
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
+ activation="relu", normalize_before=False):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ self.norm3 = nn.LayerNorm(d_model)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+ self.dropout3 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+ self.normalize_before = normalize_before
+
+ def with_pos_embed(self, tensor, pos):
+ return tensor if pos is None else tensor + pos
+
+ def forward_post(self, tgt, memory,
+ tgt_mask = None,
+ memory_mask = None,
+ tgt_key_padding_mask = None,
+ memory_key_padding_mask = None,
+ pos = None,
+ query_pos = None):
+ # q = k = self.with_pos_embed(tgt, query_pos)
+ tgt2 = self.self_attn(tgt, tgt, value=tgt, attn_mask=tgt_mask,
+ key_padding_mask=tgt_key_padding_mask)[0]
+ tgt = tgt + self.dropout1(tgt2)
+ tgt = self.norm1(tgt)
+ tgt2 = self.multihead_attn(query=tgt,
+ key=memory,
+ value=memory, attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask)[0]
+ tgt = tgt + self.dropout2(tgt2)
+ tgt = self.norm2(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+ tgt = tgt + self.dropout3(tgt2)
+ tgt = self.norm3(tgt)
+ return tgt
+
+ def forward_pre(self, tgt, memory,
+ tgt_mask = None,
+ memory_mask = None,
+ tgt_key_padding_mask = None,
+ memory_key_padding_mask = None,
+ pos = None,
+ query_pos = None):
+ tgt2 = self.norm1(tgt)
+ # q = k = self.with_pos_embed(tgt2, query_pos)
+ tgt2 = self.self_attn(tgt2, tgt2, value=tgt2, attn_mask=tgt_mask,
+ key_padding_mask=tgt_key_padding_mask)[0]
+ tgt = tgt + self.dropout1(tgt2)
+ tgt2 = self.norm2(tgt)
+ tgt2 = self.multihead_attn(query=tgt2,
+ key=memory,
+ value=memory, attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask)[0]
+ tgt = tgt + self.dropout2(tgt2)
+ tgt2 = self.norm3(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+ tgt = tgt + self.dropout3(tgt2)
+ return tgt
+
+ def forward(self, tgt, memory,
+ tgt_mask = None,
+ memory_mask = None,
+ tgt_key_padding_mask = None,
+ memory_key_padding_mask = None,
+ pos = None,
+ query_pos = None):
+ if self.normalize_before:
+ return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
+ tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
+ return self.forward_post(tgt, memory, tgt_mask, memory_mask,
+ tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
+
+
+
diff --git a/damo/dreamtalk/core/utils.py b/damo/dreamtalk/core/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bb4e3934995b06e7c8c3ac8bea6e6d655d3e71f
--- /dev/null
+++ b/damo/dreamtalk/core/utils.py
@@ -0,0 +1,456 @@
+import os
+import argparse
+from collections import defaultdict
+import logging
+import pickle
+import json
+import random
+
+import numpy as np
+import torch
+from torch import nn
+from scipy.io import loadmat
+
+from configs.default import get_cfg_defaults
+import dlib
+import cv2
+
+
+def _reset_parameters(model):
+ for p in model.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+
+def get_video_style(video_name, style_type):
+ person_id, direction, emotion, level, *_ = video_name.split("_")
+ if style_type == "id_dir_emo_level":
+ style = "_".join([person_id, direction, emotion, level])
+ elif style_type == "emotion":
+ style = emotion
+ elif style_type == "id":
+ style = person_id
+ else:
+ raise ValueError("Unknown style type")
+
+ return style
+
+
+def get_style_video_lists(video_list, style_type):
+ style2video_list = defaultdict(list)
+ for video in video_list:
+ style = get_video_style(video, style_type)
+ style2video_list[style].append(video)
+
+ return style2video_list
+
+
+def get_face3d_clip(
+ video_name, video_root_dir, num_frames, start_idx, dtype=torch.float32
+):
+    """Load a fixed-length clip of 3DMM expression coefficients.
+
+    Args:
+        video_name (str): file name of the 3DMM sequence (.mat or .txt).
+        video_root_dir (str): directory containing the 3DMM files.
+        num_frames (int): number of frames in the returned clip.
+        start_idx ("random", "middle" or int): where the clip starts.
+        dtype (torch.dtype, optional): output dtype. Defaults to torch.float32.
+
+    Raises:
+        ValueError: if the file extension or start_idx is invalid.
+
+    Returns:
+        torch.Tensor: expression clip of shape (num_frames, C_exp),
+        with C_exp = 64 for .mat inputs.
+    """
+ video_path = os.path.join(video_root_dir, video_name)
+ if video_path[-3:] == "mat":
+ face3d_all = loadmat(video_path)["coeff"]
+ face3d_exp = face3d_all[:, 80:144] # expression 3DMM range
+ elif video_path[-3:] == "txt":
+ face3d_exp = np.loadtxt(video_path)
+ else:
+ raise ValueError("Invalid 3DMM file extension")
+
+ length = face3d_exp.shape[0]
+ clip_num_frames = num_frames
+ if start_idx == "random":
+ clip_start_idx = np.random.randint(low=0, high=length - clip_num_frames + 1)
+ elif start_idx == "middle":
+ clip_start_idx = (length - clip_num_frames + 1) // 2
+ elif isinstance(start_idx, int):
+ clip_start_idx = start_idx
+ else:
+ raise ValueError(f"Invalid start_idx {start_idx}")
+
+ face3d_clip = face3d_exp[clip_start_idx : clip_start_idx + clip_num_frames]
+ face3d_clip = torch.tensor(face3d_clip, dtype=dtype)
+
+ return face3d_clip
+
+
+def get_video_style_clip(
+ video_name,
+ video_root_dir,
+ style_max_len,
+ start_idx="random",
+ dtype=torch.float32,
+ return_start_idx=False,
+):
+ video_path = os.path.join(video_root_dir, video_name)
+ if video_path[-3:] == "mat":
+ face3d_all = loadmat(video_path)["coeff"]
+ face3d_exp = face3d_all[:, 80:144] # expression 3DMM range
+ elif video_path[-3:] == "txt":
+ face3d_exp = np.loadtxt(video_path)
+ else:
+ raise ValueError("Invalid 3DMM file extension")
+
+ face3d_exp = torch.tensor(face3d_exp, dtype=dtype)
+
+ length = face3d_exp.shape[0]
+ if length >= style_max_len:
+ clip_num_frames = style_max_len
+ if start_idx == "random":
+ clip_start_idx = np.random.randint(low=0, high=length - clip_num_frames + 1)
+ elif start_idx == "middle":
+ clip_start_idx = (length - clip_num_frames + 1) // 2
+ elif isinstance(start_idx, int):
+ clip_start_idx = start_idx
+ else:
+ raise ValueError(f"Invalid start_idx {start_idx}")
+
+ face3d_clip = face3d_exp[clip_start_idx : clip_start_idx + clip_num_frames]
+ pad_mask = torch.tensor([False] * style_max_len)
+ else:
+ clip_start_idx = None
+ padding = torch.zeros(style_max_len - length, face3d_exp.shape[1])
+ face3d_clip = torch.cat((face3d_exp, padding), dim=0)
+ pad_mask = torch.tensor([False] * length + [True] * (style_max_len - length))
+
+ if return_start_idx:
+ return face3d_clip, pad_mask, clip_start_idx
+ else:
+ return face3d_clip, pad_mask
+
+
+def get_video_style_clip_from_np(
+ face3d_exp,
+ style_max_len,
+ start_idx="random",
+ dtype=torch.float32,
+ return_start_idx=False,
+):
+ face3d_exp = torch.tensor(face3d_exp, dtype=dtype)
+
+ length = face3d_exp.shape[0]
+ if length >= style_max_len:
+ clip_num_frames = style_max_len
+ if start_idx == "random":
+ clip_start_idx = np.random.randint(low=0, high=length - clip_num_frames + 1)
+ elif start_idx == "middle":
+ clip_start_idx = (length - clip_num_frames + 1) // 2
+ elif isinstance(start_idx, int):
+ clip_start_idx = start_idx
+ else:
+ raise ValueError(f"Invalid start_idx {start_idx}")
+
+ face3d_clip = face3d_exp[clip_start_idx : clip_start_idx + clip_num_frames]
+ pad_mask = torch.tensor([False] * style_max_len)
+ else:
+ clip_start_idx = None
+ padding = torch.zeros(style_max_len - length, face3d_exp.shape[1])
+ face3d_clip = torch.cat((face3d_exp, padding), dim=0)
+ pad_mask = torch.tensor([False] * length + [True] * (style_max_len - length))
+
+ if return_start_idx:
+ return face3d_clip, pad_mask, clip_start_idx
+ else:
+ return face3d_clip, pad_mask
+
+
+def get_wav2vec_audio_window(audio_feat, start_idx, num_frames, win_size):
+    """Slice overlapping windows of wav2vec features, one window per video frame.
+
+    Args:
+        audio_feat (np.ndarray): audio features of shape (N, 1024).
+        start_idx (int): index of the first video frame.
+        num_frames (int): number of video frames to cover.
+        win_size (int): half window size, in audio frames, on each side.
+
+    Returns:
+        np.ndarray: windows of shape (num_frames, 2 * win_size + 1, 1024),
+        zero-padded at the sequence boundaries.
+    """
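+    # audio features are assumed to run at twice the video frame rate, so video
+    # frame idx is centered at audio index 2 * idx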
+ center_idx_list = [2 * idx for idx in range(start_idx, start_idx + num_frames)]
+ audio_window_list = []
+ padding = np.zeros(audio_feat.shape[1], dtype=np.float32)
+ for center_idx in center_idx_list:
+ cur_audio_window = []
+ for i in range(center_idx - win_size, center_idx + win_size + 1):
+ if i < 0:
+ cur_audio_window.append(padding)
+ elif i >= len(audio_feat):
+ cur_audio_window.append(padding)
+ else:
+ cur_audio_window.append(audio_feat[i])
+ cur_audio_win_array = np.stack(cur_audio_window, axis=0)
+ audio_window_list.append(cur_audio_win_array)
+
+ audio_window_array = np.stack(audio_window_list, axis=0)
+ return audio_window_array
+
+
+def setup_config():
+ parser = argparse.ArgumentParser(description="voice2pose main program")
+ parser.add_argument(
+ "--config_file", default="", metavar="FILE", help="path to config file"
+ )
+ parser.add_argument(
+ "--resume_from", type=str, default=None, help="the checkpoint to resume from"
+ )
+ parser.add_argument(
+ "--test_only", action="store_true", help="perform testing and evaluation only"
+ )
+ parser.add_argument(
+ "--demo_input", type=str, default=None, help="path to input for demo"
+ )
+ parser.add_argument(
+ "--checkpoint", type=str, default=None, help="the checkpoint to test with"
+ )
+ parser.add_argument("--tag", type=str, default="", help="tag for the experiment")
+ parser.add_argument(
+ "opts",
+ help="Modify config options using the command-line",
+ default=None,
+ nargs=argparse.REMAINDER,
+ )
+ parser.add_argument(
+ "--local_rank",
+ type=int,
+ help="local rank for DistributedDataParallel",
+ )
+ parser.add_argument(
+ "--master_port",
+ type=str,
+ default="12345",
+ )
+ parser.add_argument(
+ "--max_audio_len",
+ type=int,
+ default=450,
+ help="max_audio_len for inference",
+ )
+ parser.add_argument(
+ "--ddim_num_step",
+ type=int,
+ default=10,
+ )
+ parser.add_argument(
+ "--inference_seed",
+ type=int,
+ default=1,
+ )
+ parser.add_argument(
+ "--inference_sample_method",
+ type=str,
+ default="ddim",
+ )
+ args = parser.parse_args()
+
+ cfg = get_cfg_defaults()
+ cfg.merge_from_file(args.config_file)
+ cfg.merge_from_list(args.opts)
+ cfg.freeze()
+ return args, cfg
+
+
+def setup_logger(base_path, exp_name):
+ rootLogger = logging.getLogger()
+ rootLogger.setLevel(logging.INFO)
+
+ logFormatter = logging.Formatter("%(asctime)s [%(levelname)-0.5s] %(message)s")
+
+ log_path = "{0}/{1}.log".format(base_path, exp_name)
+ fileHandler = logging.FileHandler(log_path)
+ fileHandler.setFormatter(logFormatter)
+ rootLogger.addHandler(fileHandler)
+
+ consoleHandler = logging.StreamHandler()
+ consoleHandler.setFormatter(logFormatter)
+ rootLogger.addHandler(consoleHandler)
+ rootLogger.handlers[0].setLevel(logging.INFO)
+
+ logging.info("log path: %s" % log_path)
+
+
+def cosine_loss(a, v, y, logloss=nn.BCELoss()):
+ d = nn.functional.cosine_similarity(a, v)
+ loss = logloss(d.unsqueeze(1), y)
+ return loss
+
+
+def get_pose_params(mat_path):
+ """Get pose parameters from mat file
+
+ Args:
+ mat_path (str): path of mat file
+
+ Returns:
+        pose_params (numpy.ndarray): shape (L_video, 9): rotation angles (3),
+            translations (3) and crop parameters (3)
+ """
+ mat_dict = loadmat(mat_path)
+
+ np_3dmm = mat_dict["coeff"]
+ angles = np_3dmm[:, 224:227]
+ translations = np_3dmm[:, 254:257]
+
+ np_trans_params = mat_dict["transform_params"]
+ crop = np_trans_params[:, -3:]
+
+ pose_params = np.concatenate((angles, translations, crop), axis=1)
+
+ return pose_params
+
+
+def sinusoidal_embedding(timesteps, dim):
+ """
+
+ Args:
+ timesteps (_type_): (B,)
+ dim (_type_): (C_embed)
+
+ Returns:
+ _type_: (B, C_embed)
+ """
+    # split the embedding into cos / sin halves
+ half = dim // 2
+ timesteps = timesteps.float()
+
+ # compute sinusoidal embedding
+ sinusoid = torch.outer(
+ timesteps, torch.pow(10000, -torch.arange(half).to(timesteps).div(half))
+ )
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
+ if dim % 2 != 0:
+ x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1)
+ return x
+
+
+def reshape_audio_feat(style_audio_all_raw, stride):
+    """Group consecutive audio feature frames into one frame per `stride`.
+
+    Args:
+        style_audio_all_raw: shape (stride * L, C); trailing frames that do not
+            fill a full group are dropped.
+        stride (int): number of consecutive frames to concatenate.
+
+    Returns:
+        same type as the input, with shape (L, C * stride).
+    """
+ style_audio_all_raw = style_audio_all_raw[
+ : style_audio_all_raw.shape[0] // stride * stride
+ ]
+ style_audio_all_raw = style_audio_all_raw.reshape(
+ style_audio_all_raw.shape[0] // stride, stride, style_audio_all_raw.shape[1]
+ )
+ style_audio_all = style_audio_all_raw.reshape(style_audio_all_raw.shape[0], -1)
+ return style_audio_all
+
+
+
+def get_derangement_tuple(n):
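+    """Return a random derangement of range(n) as a tuple, i.e. a permutation
+    with no fixed point, using rejection sampling: restart whenever an element
+    would be left in place."""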
+ while True:
+ v = [i for i in range(n)]
+ for j in range(n - 1, -1, -1):
+ p = random.randint(0, j)
+ if v[p] == j:
+ break
+ else:
+ v[j], v[p] = v[p], v[j]
+ else:
+ if v[0] != 0:
+ return tuple(v)
+
+
+def compute_aspect_preserved_bbox(bbox, increase_area, h, w):
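+    """Enlarge a face bbox (left, top, right, bot) so that both sides become about
+    (1 + 2 * increase_area) times the longer original side (a near-square box),
+    then, if the enlarged box leaves the (h, w) image, pull it back in uniformly
+    on all sides."""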
+ left, top, right, bot = bbox
+ width = right - left
+ height = bot - top
+
+ width_increase = max(
+ increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width)
+ )
+ height_increase = max(
+ increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height)
+ )
+
+ left_t = int(left - width_increase * width)
+ top_t = int(top - height_increase * height)
+ right_t = int(right + width_increase * width)
+ bot_t = int(bot + height_increase * height)
+
+ left_oob = -min(0, left_t)
+ right_oob = right - min(right_t, w)
+ top_oob = -min(0, top_t)
+ bot_oob = bot - min(bot_t, h)
+
+ if max(left_oob, right_oob, top_oob, bot_oob) > 0:
+ max_w = max(left_oob, right_oob)
+ max_h = max(top_oob, bot_oob)
+ if max_w > max_h:
+ return left_t + max_w, top_t + max_w, right_t - max_w, bot_t - max_w
+ else:
+ return left_t + max_h, top_t + max_h, right_t - max_h, bot_t - max_h
+
+ else:
+ return (left_t, top_t, right_t, bot_t)
+
+
+def crop_src_image(src_img, save_img, increase_ratio, detector=None):
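+    """Detect a face in `src_img` with dlib, shift the detection box slightly
+    upward, enlarge it with compute_aspect_preserved_bbox(increase_ratio), crop,
+    resize to 256x256 and write the result to `save_img`. Raises ValueError if
+    no face is detected."""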
+ if detector is None:
+ detector = dlib.get_frontal_face_detector()
+
+ img = cv2.imread(src_img)
+ faces = detector(img, 0)
+ h, width, _ = img.shape
+ if len(faces) > 0:
+ bbox = [faces[0].left(), faces[0].top(), faces[0].right(), faces[0].bottom()]
+ l = bbox[3] - bbox[1]
+ bbox[1] = bbox[1] - l * 0.1
+ bbox[3] = bbox[3] - l * 0.1
+ bbox[1] = max(0, bbox[1])
+ bbox[3] = min(h, bbox[3])
+ bbox = compute_aspect_preserved_bbox(
+ tuple(bbox), increase_ratio, img.shape[0], img.shape[1]
+ )
+ img = img[bbox[1] : bbox[3], bbox[0] : bbox[2]]
+ img = cv2.resize(img, (256, 256))
+ cv2.imwrite(save_img, img)
+ else:
+ raise ValueError("No face detected in the input image")
+ # img = cv2.resize(img, (256, 256))
+ # cv2.imwrite(save_img, img)
diff --git a/damo/dreamtalk/data/audio/German1.wav b/damo/dreamtalk/data/audio/German1.wav
new file mode 100644
index 0000000000000000000000000000000000000000..0be224e38f3274559fb575151d4403a3250a0346
Binary files /dev/null and b/damo/dreamtalk/data/audio/German1.wav differ
diff --git a/damo/dreamtalk/data/audio/German2.wav b/damo/dreamtalk/data/audio/German2.wav
new file mode 100644
index 0000000000000000000000000000000000000000..51112450a80c68dfd075bf0c5c8beeee7a996e4d
Binary files /dev/null and b/damo/dreamtalk/data/audio/German2.wav differ
diff --git a/damo/dreamtalk/data/audio/German3.wav b/damo/dreamtalk/data/audio/German3.wav
new file mode 100644
index 0000000000000000000000000000000000000000..48ec6fe7db2a4e5bee628bccf15e1a8fb8b364fc
Binary files /dev/null and b/damo/dreamtalk/data/audio/German3.wav differ
diff --git a/damo/dreamtalk/data/audio/German4.wav b/damo/dreamtalk/data/audio/German4.wav
new file mode 100644
index 0000000000000000000000000000000000000000..49d0c304740e589971639bf0de823f5c9e68156a
Binary files /dev/null and b/damo/dreamtalk/data/audio/German4.wav differ
diff --git a/damo/dreamtalk/data/audio/acknowledgement_chinese.m4a b/damo/dreamtalk/data/audio/acknowledgement_chinese.m4a
new file mode 100644
index 0000000000000000000000000000000000000000..229cfe91ec64d6b875145e3c04756e64b02da9be
Binary files /dev/null and b/damo/dreamtalk/data/audio/acknowledgement_chinese.m4a differ
diff --git a/damo/dreamtalk/data/audio/acknowledgement_english.m4a b/damo/dreamtalk/data/audio/acknowledgement_english.m4a
new file mode 100644
index 0000000000000000000000000000000000000000..6bc865bfef5fd7d558f0225230ac047c151380a2
Binary files /dev/null and b/damo/dreamtalk/data/audio/acknowledgement_english.m4a differ
diff --git a/damo/dreamtalk/data/audio/chinese1_haierlizhi.wav b/damo/dreamtalk/data/audio/chinese1_haierlizhi.wav
new file mode 100644
index 0000000000000000000000000000000000000000..df69fd0897ab3278462176467b3acdd71b3e37f6
Binary files /dev/null and b/damo/dreamtalk/data/audio/chinese1_haierlizhi.wav differ
diff --git a/damo/dreamtalk/data/audio/chinese2_guanyu.wav b/damo/dreamtalk/data/audio/chinese2_guanyu.wav
new file mode 100644
index 0000000000000000000000000000000000000000..8ec3ed9d7553e1bb9fbb16cb4b200deb945acc69
Binary files /dev/null and b/damo/dreamtalk/data/audio/chinese2_guanyu.wav differ
diff --git a/damo/dreamtalk/data/audio/french1.wav b/damo/dreamtalk/data/audio/french1.wav
new file mode 100644
index 0000000000000000000000000000000000000000..5cf06bdf707f781ff43a4dfba3a6cb841ae7dcd1
Binary files /dev/null and b/damo/dreamtalk/data/audio/french1.wav differ
diff --git a/damo/dreamtalk/data/audio/french2.wav b/damo/dreamtalk/data/audio/french2.wav
new file mode 100644
index 0000000000000000000000000000000000000000..e1298d1b088536d1e21d197ca86222b2a3a96371
Binary files /dev/null and b/damo/dreamtalk/data/audio/french2.wav differ
diff --git a/damo/dreamtalk/data/audio/french3.wav b/damo/dreamtalk/data/audio/french3.wav
new file mode 100644
index 0000000000000000000000000000000000000000..02c96b1293e0cb51ed5ad4b08d5c581c61ad15d0
Binary files /dev/null and b/damo/dreamtalk/data/audio/french3.wav differ
diff --git a/damo/dreamtalk/data/audio/italian1.wav b/damo/dreamtalk/data/audio/italian1.wav
new file mode 100644
index 0000000000000000000000000000000000000000..2428b1912db7ddb33269c42379e2a52479096c26
Binary files /dev/null and b/damo/dreamtalk/data/audio/italian1.wav differ
diff --git a/damo/dreamtalk/data/audio/italian2.wav b/damo/dreamtalk/data/audio/italian2.wav
new file mode 100644
index 0000000000000000000000000000000000000000..8f535af7ff113232378ceea190b1bc55c30a39d3
Binary files /dev/null and b/damo/dreamtalk/data/audio/italian2.wav differ
diff --git a/damo/dreamtalk/data/audio/italian3.wav b/damo/dreamtalk/data/audio/italian3.wav
new file mode 100644
index 0000000000000000000000000000000000000000..2f275e7a144aeed17c26cfdff85ec5d2816bef4e
Binary files /dev/null and b/damo/dreamtalk/data/audio/italian3.wav differ
diff --git a/damo/dreamtalk/data/audio/japan1.wav b/damo/dreamtalk/data/audio/japan1.wav
new file mode 100644
index 0000000000000000000000000000000000000000..779e7464683d0d89be84311bfdd19737893d2295
Binary files /dev/null and b/damo/dreamtalk/data/audio/japan1.wav differ
diff --git a/damo/dreamtalk/data/audio/japan2.wav b/damo/dreamtalk/data/audio/japan2.wav
new file mode 100644
index 0000000000000000000000000000000000000000..6af6a81762b591c13aec9748eef60b512d4e21c5
Binary files /dev/null and b/damo/dreamtalk/data/audio/japan2.wav differ
diff --git a/damo/dreamtalk/data/audio/japan3.wav b/damo/dreamtalk/data/audio/japan3.wav
new file mode 100644
index 0000000000000000000000000000000000000000..3279aa45bb03212b47046b214d81219cc32377da
Binary files /dev/null and b/damo/dreamtalk/data/audio/japan3.wav differ
diff --git a/damo/dreamtalk/data/audio/korean1.wav b/damo/dreamtalk/data/audio/korean1.wav
new file mode 100644
index 0000000000000000000000000000000000000000..52c94f6773dd7f6b580f25f4593daa7b3fe79c84
Binary files /dev/null and b/damo/dreamtalk/data/audio/korean1.wav differ
diff --git a/damo/dreamtalk/data/audio/korean2.wav b/damo/dreamtalk/data/audio/korean2.wav
new file mode 100644
index 0000000000000000000000000000000000000000..0340950ccd021290db94ee97c6ec476db646a11d
Binary files /dev/null and b/damo/dreamtalk/data/audio/korean2.wav differ
diff --git a/damo/dreamtalk/data/audio/korean3.wav b/damo/dreamtalk/data/audio/korean3.wav
new file mode 100644
index 0000000000000000000000000000000000000000..376ab876815d8302773d31b23c136b170640c589
Binary files /dev/null and b/damo/dreamtalk/data/audio/korean3.wav differ
diff --git a/damo/dreamtalk/data/audio/noisy_audio_cafeter_snr_0.wav b/damo/dreamtalk/data/audio/noisy_audio_cafeter_snr_0.wav
new file mode 100644
index 0000000000000000000000000000000000000000..ef037f924f0454dd2dc042d2c96eb11186282c90
Binary files /dev/null and b/damo/dreamtalk/data/audio/noisy_audio_cafeter_snr_0.wav differ
diff --git a/damo/dreamtalk/data/audio/noisy_audio_meeting_snr_0.wav b/damo/dreamtalk/data/audio/noisy_audio_meeting_snr_0.wav
new file mode 100644
index 0000000000000000000000000000000000000000..271b3a8bd21b69f9a4c596464e4c6e53928f4ef2
Binary files /dev/null and b/damo/dreamtalk/data/audio/noisy_audio_meeting_snr_0.wav differ
diff --git a/damo/dreamtalk/data/audio/noisy_audio_meeting_snr_10.wav b/damo/dreamtalk/data/audio/noisy_audio_meeting_snr_10.wav
new file mode 100644
index 0000000000000000000000000000000000000000..5ce9cc48ec417476c9eef275941969688dc1bc1c
Binary files /dev/null and b/damo/dreamtalk/data/audio/noisy_audio_meeting_snr_10.wav differ
diff --git a/damo/dreamtalk/data/audio/noisy_audio_meeting_snr_20.wav b/damo/dreamtalk/data/audio/noisy_audio_meeting_snr_20.wav
new file mode 100644
index 0000000000000000000000000000000000000000..edca183b668b8e47dfd701cd9178da0caf2254b8
Binary files /dev/null and b/damo/dreamtalk/data/audio/noisy_audio_meeting_snr_20.wav differ
diff --git a/damo/dreamtalk/data/audio/noisy_audio_narrative.wav b/damo/dreamtalk/data/audio/noisy_audio_narrative.wav
new file mode 100644
index 0000000000000000000000000000000000000000..ff996b007363160eaed06296f7c4cf5bcdf19bc9
Binary files /dev/null and b/damo/dreamtalk/data/audio/noisy_audio_narrative.wav differ
diff --git a/damo/dreamtalk/data/audio/noisy_audio_office_snr_0.wav b/damo/dreamtalk/data/audio/noisy_audio_office_snr_0.wav
new file mode 100644
index 0000000000000000000000000000000000000000..36b98841f52a34c36a77c16ad309cb349b90b12a
Binary files /dev/null and b/damo/dreamtalk/data/audio/noisy_audio_office_snr_0.wav differ
diff --git a/damo/dreamtalk/data/audio/out_of_domain_narrative.wav b/damo/dreamtalk/data/audio/out_of_domain_narrative.wav
new file mode 100644
index 0000000000000000000000000000000000000000..17f04c7384ec352b527077b737d0cfbf6782147f
Binary files /dev/null and b/damo/dreamtalk/data/audio/out_of_domain_narrative.wav differ
diff --git a/damo/dreamtalk/data/audio/spanish1.wav b/damo/dreamtalk/data/audio/spanish1.wav
new file mode 100644
index 0000000000000000000000000000000000000000..3513ead1cc2ddc0153150f2909e4d1431875a0f8
Binary files /dev/null and b/damo/dreamtalk/data/audio/spanish1.wav differ
diff --git a/damo/dreamtalk/data/audio/spanish2.wav b/damo/dreamtalk/data/audio/spanish2.wav
new file mode 100644
index 0000000000000000000000000000000000000000..8147b41983140cafef1800da589c5ebc241f3240
Binary files /dev/null and b/damo/dreamtalk/data/audio/spanish2.wav differ
diff --git a/damo/dreamtalk/data/audio/spanish3.wav b/damo/dreamtalk/data/audio/spanish3.wav
new file mode 100644
index 0000000000000000000000000000000000000000..83e8663fef0e40a8edb5439d2c9f56e582a7f44f
Binary files /dev/null and b/damo/dreamtalk/data/audio/spanish3.wav differ
diff --git a/damo/dreamtalk/data/pose/RichardShelby_front_neutral_level1_001.mat b/damo/dreamtalk/data/pose/RichardShelby_front_neutral_level1_001.mat
new file mode 100644
index 0000000000000000000000000000000000000000..e3852095787dfd4ce449c3193e5daae410eb9854
--- /dev/null
+++ b/damo/dreamtalk/data/pose/RichardShelby_front_neutral_level1_001.mat
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:68f323fb87e49174f911b05fe6e68131ae45e3189ac6d2289976e2585789cd06
+size 2176968
diff --git a/damo/dreamtalk/data/src_img/cropped/chpa5.png b/damo/dreamtalk/data/src_img/cropped/chpa5.png
new file mode 100644
index 0000000000000000000000000000000000000000..28d920861833df3281bb657da39991cd72ab7178
Binary files /dev/null and b/damo/dreamtalk/data/src_img/cropped/chpa5.png differ
diff --git a/damo/dreamtalk/data/src_img/cropped/cut_img.png b/damo/dreamtalk/data/src_img/cropped/cut_img.png
new file mode 100644
index 0000000000000000000000000000000000000000..327972e6cc3c431ac8774453c9d06ffc9dda6830
Binary files /dev/null and b/damo/dreamtalk/data/src_img/cropped/cut_img.png differ
diff --git a/damo/dreamtalk/data/src_img/cropped/f30.png b/damo/dreamtalk/data/src_img/cropped/f30.png
new file mode 100644
index 0000000000000000000000000000000000000000..5c5118ef14e28d290a72ed4aaea25a7f19c6c16b
Binary files /dev/null and b/damo/dreamtalk/data/src_img/cropped/f30.png differ
diff --git a/damo/dreamtalk/data/src_img/cropped/menglu2.png b/damo/dreamtalk/data/src_img/cropped/menglu2.png
new file mode 100644
index 0000000000000000000000000000000000000000..7c869aaac7a7f5e5aeded9dd8c248c61bc1eb9b6
Binary files /dev/null and b/damo/dreamtalk/data/src_img/cropped/menglu2.png differ
diff --git a/damo/dreamtalk/data/src_img/cropped/nscu2.png b/damo/dreamtalk/data/src_img/cropped/nscu2.png
new file mode 100644
index 0000000000000000000000000000000000000000..4b16dc9c89286b7a5a116a25c696cd3280ad2187
Binary files /dev/null and b/damo/dreamtalk/data/src_img/cropped/nscu2.png differ
diff --git a/damo/dreamtalk/data/src_img/cropped/zp1.png b/damo/dreamtalk/data/src_img/cropped/zp1.png
new file mode 100644
index 0000000000000000000000000000000000000000..20997f7900ecc40d1c5833c491f4fbf7f3d06fda
Binary files /dev/null and b/damo/dreamtalk/data/src_img/cropped/zp1.png differ
diff --git a/damo/dreamtalk/data/src_img/cropped/zt12.png b/damo/dreamtalk/data/src_img/cropped/zt12.png
new file mode 100644
index 0000000000000000000000000000000000000000..998551d51e497862518d8b7bf2c6636a93cc98d8
Binary files /dev/null and b/damo/dreamtalk/data/src_img/cropped/zt12.png differ
diff --git a/damo/dreamtalk/data/src_img/uncropped/face3.png b/damo/dreamtalk/data/src_img/uncropped/face3.png
new file mode 100644
index 0000000000000000000000000000000000000000..f9962172854227c47b7d572db0290eb9b3b5b974
Binary files /dev/null and b/damo/dreamtalk/data/src_img/uncropped/face3.png differ
diff --git a/damo/dreamtalk/data/src_img/uncropped/male_face.png b/damo/dreamtalk/data/src_img/uncropped/male_face.png
new file mode 100644
index 0000000000000000000000000000000000000000..5f62eba092d45ca304534eb7cffc959c11c7978a
Binary files /dev/null and b/damo/dreamtalk/data/src_img/uncropped/male_face.png differ
diff --git a/damo/dreamtalk/data/src_img/uncropped/uncut_src_img.jpg b/damo/dreamtalk/data/src_img/uncropped/uncut_src_img.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..607873fbff5a91b61b14a64cada0f1b6a7de22f4
Binary files /dev/null and b/damo/dreamtalk/data/src_img/uncropped/uncut_src_img.jpg differ
diff --git a/damo/dreamtalk/data/style_clip/3DMM/M030_front_angry_level3_001.mat b/damo/dreamtalk/data/style_clip/3DMM/M030_front_angry_level3_001.mat
new file mode 100644
index 0000000000000000000000000000000000000000..111bb62c9a9fbf75d53b3f94193d92b6a20ec3eb
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/3DMM/M030_front_angry_level3_001.mat differ
diff --git a/damo/dreamtalk/data/style_clip/3DMM/M030_front_contempt_level3_001.mat b/damo/dreamtalk/data/style_clip/3DMM/M030_front_contempt_level3_001.mat
new file mode 100644
index 0000000000000000000000000000000000000000..99bc6b0f8512170aa4214ce3a5ef84a925898687
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/3DMM/M030_front_contempt_level3_001.mat differ
diff --git a/damo/dreamtalk/data/style_clip/3DMM/M030_front_disgusted_level3_001.mat b/damo/dreamtalk/data/style_clip/3DMM/M030_front_disgusted_level3_001.mat
new file mode 100644
index 0000000000000000000000000000000000000000..22defd9436ac192818d112f07a53d7102f533f34
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/3DMM/M030_front_disgusted_level3_001.mat differ
diff --git a/damo/dreamtalk/data/style_clip/3DMM/M030_front_fear_level3_001.mat b/damo/dreamtalk/data/style_clip/3DMM/M030_front_fear_level3_001.mat
new file mode 100644
index 0000000000000000000000000000000000000000..bca8ca65d1d3f6123afc57eebd38a6e6f821d7fa
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/3DMM/M030_front_fear_level3_001.mat differ
diff --git a/damo/dreamtalk/data/style_clip/3DMM/M030_front_happy_level3_001.mat b/damo/dreamtalk/data/style_clip/3DMM/M030_front_happy_level3_001.mat
new file mode 100644
index 0000000000000000000000000000000000000000..e5698aabc0f4bd0dfde84c2c7e34f02e0c2be41b
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/3DMM/M030_front_happy_level3_001.mat differ
diff --git a/damo/dreamtalk/data/style_clip/3DMM/M030_front_neutral_level1_001.mat b/damo/dreamtalk/data/style_clip/3DMM/M030_front_neutral_level1_001.mat
new file mode 100644
index 0000000000000000000000000000000000000000..a89c415af4266ca566c358af927634689c2a481f
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/3DMM/M030_front_neutral_level1_001.mat differ
diff --git a/damo/dreamtalk/data/style_clip/3DMM/M030_front_sad_level3_001.mat b/damo/dreamtalk/data/style_clip/3DMM/M030_front_sad_level3_001.mat
new file mode 100644
index 0000000000000000000000000000000000000000..575299be97a817ed368f615cf7c6cd7384f5dc96
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/3DMM/M030_front_sad_level3_001.mat differ
diff --git a/damo/dreamtalk/data/style_clip/3DMM/M030_front_surprised_level3_001.mat b/damo/dreamtalk/data/style_clip/3DMM/M030_front_surprised_level3_001.mat
new file mode 100644
index 0000000000000000000000000000000000000000..cc06d238d8a87fc7342e231993c41d5620b5e00f
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/3DMM/M030_front_surprised_level3_001.mat differ
diff --git a/damo/dreamtalk/data/style_clip/3DMM/W009_front_angry_level3_001.mat b/damo/dreamtalk/data/style_clip/3DMM/W009_front_angry_level3_001.mat
new file mode 100644
index 0000000000000000000000000000000000000000..c00e7169f9c0d3b7407d732b3c858403ca2d52a0
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/3DMM/W009_front_angry_level3_001.mat differ
diff --git a/damo/dreamtalk/data/style_clip/3DMM/W009_front_contempt_level3_001.mat b/damo/dreamtalk/data/style_clip/3DMM/W009_front_contempt_level3_001.mat
new file mode 100644
index 0000000000000000000000000000000000000000..83643c977dbe3fa7173e110d1e8771aee2def3ab
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/3DMM/W009_front_contempt_level3_001.mat differ
diff --git a/damo/dreamtalk/data/style_clip/3DMM/W009_front_disgusted_level3_001.mat b/damo/dreamtalk/data/style_clip/3DMM/W009_front_disgusted_level3_001.mat
new file mode 100644
index 0000000000000000000000000000000000000000..8f60c904f106acf73d6ce496d83e2d3b1ba3f371
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/3DMM/W009_front_disgusted_level3_001.mat differ
diff --git a/damo/dreamtalk/data/style_clip/3DMM/W009_front_fear_level3_001.mat b/damo/dreamtalk/data/style_clip/3DMM/W009_front_fear_level3_001.mat
new file mode 100644
index 0000000000000000000000000000000000000000..6553751c3372b621f1e674df168f3cc2dc20dd45
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/3DMM/W009_front_fear_level3_001.mat differ
diff --git a/damo/dreamtalk/data/style_clip/3DMM/W009_front_happy_level3_001.mat b/damo/dreamtalk/data/style_clip/3DMM/W009_front_happy_level3_001.mat
new file mode 100644
index 0000000000000000000000000000000000000000..9ec3cc3098f29d7053eb9bfa645558c0c0b2157c
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/3DMM/W009_front_happy_level3_001.mat differ
diff --git a/damo/dreamtalk/data/style_clip/3DMM/W009_front_neutral_level1_001.mat b/damo/dreamtalk/data/style_clip/3DMM/W009_front_neutral_level1_001.mat
new file mode 100644
index 0000000000000000000000000000000000000000..91ec8d7524271becf96ac20f9e889549da9aa963
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/3DMM/W009_front_neutral_level1_001.mat differ
diff --git a/damo/dreamtalk/data/style_clip/3DMM/W009_front_sad_level3_001.mat b/damo/dreamtalk/data/style_clip/3DMM/W009_front_sad_level3_001.mat
new file mode 100644
index 0000000000000000000000000000000000000000..75fb8f25f920b446fc5d4b5c0160c2b229f2fe63
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/3DMM/W009_front_sad_level3_001.mat differ
diff --git a/damo/dreamtalk/data/style_clip/3DMM/W009_front_surprised_level3_001.mat b/damo/dreamtalk/data/style_clip/3DMM/W009_front_surprised_level3_001.mat
new file mode 100644
index 0000000000000000000000000000000000000000..8abea0f5e4d70d255320c3bf1b6b585aac3bc9bf
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/3DMM/W009_front_surprised_level3_001.mat differ
diff --git a/damo/dreamtalk/data/style_clip/3DMM/W011_front_angry_level3_001.mat b/damo/dreamtalk/data/style_clip/3DMM/W011_front_angry_level3_001.mat
new file mode 100644
index 0000000000000000000000000000000000000000..2b7a3f7d059cb6dd336af5cbf9851e99098cf271
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/3DMM/W011_front_angry_level3_001.mat differ
diff --git a/damo/dreamtalk/data/style_clip/3DMM/W011_front_contempt_level3_001.mat b/damo/dreamtalk/data/style_clip/3DMM/W011_front_contempt_level3_001.mat
new file mode 100644
index 0000000000000000000000000000000000000000..b17954df381995545c78fc383fa4dfef8750a45a
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/3DMM/W011_front_contempt_level3_001.mat differ
diff --git a/damo/dreamtalk/data/style_clip/3DMM/W011_front_disgusted_level3_001.mat b/damo/dreamtalk/data/style_clip/3DMM/W011_front_disgusted_level3_001.mat
new file mode 100644
index 0000000000000000000000000000000000000000..ef80145417f226b4c73ed54bb2e7b0cd9964a7c0
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/3DMM/W011_front_disgusted_level3_001.mat differ
diff --git a/damo/dreamtalk/data/style_clip/3DMM/W011_front_fear_level3_001.mat b/damo/dreamtalk/data/style_clip/3DMM/W011_front_fear_level3_001.mat
new file mode 100644
index 0000000000000000000000000000000000000000..ec1938f6a564baa5c877d00d5bd7336f24bcfea2
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/3DMM/W011_front_fear_level3_001.mat differ
diff --git a/damo/dreamtalk/data/style_clip/3DMM/W011_front_happy_level3_001.mat b/damo/dreamtalk/data/style_clip/3DMM/W011_front_happy_level3_001.mat
new file mode 100644
index 0000000000000000000000000000000000000000..a5d2adff0cf167098a6b16a904ea79daabc61098
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/3DMM/W011_front_happy_level3_001.mat differ
diff --git a/damo/dreamtalk/data/style_clip/3DMM/W011_front_neutral_level1_001.mat b/damo/dreamtalk/data/style_clip/3DMM/W011_front_neutral_level1_001.mat
new file mode 100644
index 0000000000000000000000000000000000000000..f4e5cfabcb37dbb5b7c726a52ed1a677c88bd8bb
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/3DMM/W011_front_neutral_level1_001.mat differ
diff --git a/damo/dreamtalk/data/style_clip/3DMM/W011_front_sad_level3_001.mat b/damo/dreamtalk/data/style_clip/3DMM/W011_front_sad_level3_001.mat
new file mode 100644
index 0000000000000000000000000000000000000000..ec6ee03f7b5c3e0f5396e6b2d5c4210f8de70729
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/3DMM/W011_front_sad_level3_001.mat differ
diff --git a/damo/dreamtalk/data/style_clip/3DMM/W011_front_surprised_level3_001.mat b/damo/dreamtalk/data/style_clip/3DMM/W011_front_surprised_level3_001.mat
new file mode 100644
index 0000000000000000000000000000000000000000..7df4523eed7503e751558ff651227001977834b4
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/3DMM/W011_front_surprised_level3_001.mat differ
diff --git a/damo/dreamtalk/data/style_clip/video/M030_front_angry_level3_001.mp4 b/damo/dreamtalk/data/style_clip/video/M030_front_angry_level3_001.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..c760ca665da973dc7ec22b2ce0f282b6e94f21f0
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/video/M030_front_angry_level3_001.mp4 differ
diff --git a/damo/dreamtalk/data/style_clip/video/M030_front_contempt_level3_001.mp4 b/damo/dreamtalk/data/style_clip/video/M030_front_contempt_level3_001.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..7b829124878cf2d0372548a9c9d834665ae82011
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/video/M030_front_contempt_level3_001.mp4 differ
diff --git a/damo/dreamtalk/data/style_clip/video/M030_front_disgusted_level3_001.mp4 b/damo/dreamtalk/data/style_clip/video/M030_front_disgusted_level3_001.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..8fecdabe99430ca3af6298a41d472646665c6170
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/video/M030_front_disgusted_level3_001.mp4 differ
diff --git a/damo/dreamtalk/data/style_clip/video/M030_front_fear_level3_001.mp4 b/damo/dreamtalk/data/style_clip/video/M030_front_fear_level3_001.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..3f6a706146370a6ba8bc208ba6c8cb9d62d9b1c4
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/video/M030_front_fear_level3_001.mp4 differ
diff --git a/damo/dreamtalk/data/style_clip/video/M030_front_happy_level3_001.mp4 b/damo/dreamtalk/data/style_clip/video/M030_front_happy_level3_001.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..c2f7f56199830a1964fd629c5b7d2e94340d9614
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/video/M030_front_happy_level3_001.mp4 differ
diff --git a/damo/dreamtalk/data/style_clip/video/M030_front_neutral_level1_001.mp4 b/damo/dreamtalk/data/style_clip/video/M030_front_neutral_level1_001.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..45711ed965a749c9d8bba9dc42428ec2ac367190
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/video/M030_front_neutral_level1_001.mp4 differ
diff --git a/damo/dreamtalk/data/style_clip/video/M030_front_sad_level3_001.mp4 b/damo/dreamtalk/data/style_clip/video/M030_front_sad_level3_001.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..a81804683132ea2a27046a1d4f6de8c826f691fe
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/video/M030_front_sad_level3_001.mp4 differ
diff --git a/damo/dreamtalk/data/style_clip/video/M030_front_surprised_level3_001.mp4 b/damo/dreamtalk/data/style_clip/video/M030_front_surprised_level3_001.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..883983428c516b04f9d3697f999f7f5882002ccc
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/video/M030_front_surprised_level3_001.mp4 differ
diff --git a/damo/dreamtalk/data/style_clip/video/W009_front_angry_level3_001.mp4 b/damo/dreamtalk/data/style_clip/video/W009_front_angry_level3_001.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..45f6df25621c55a84a95ac1767532325f6f4d78d
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/video/W009_front_angry_level3_001.mp4 differ
diff --git a/damo/dreamtalk/data/style_clip/video/W009_front_contempt_level3_001.mp4 b/damo/dreamtalk/data/style_clip/video/W009_front_contempt_level3_001.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..0668cd1dbee1613a047fadf40d7ef892734dbf7c
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/video/W009_front_contempt_level3_001.mp4 differ
diff --git a/damo/dreamtalk/data/style_clip/video/W009_front_disgusted_level3_001.mp4 b/damo/dreamtalk/data/style_clip/video/W009_front_disgusted_level3_001.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..9a78a57ca6520adadb8e69a9b4546d5eb89f979d
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/video/W009_front_disgusted_level3_001.mp4 differ
diff --git a/damo/dreamtalk/data/style_clip/video/W009_front_fear_level3_001.mp4 b/damo/dreamtalk/data/style_clip/video/W009_front_fear_level3_001.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..3a288bd80c09c94d07f31d5b28201a0acfad6c8d
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/video/W009_front_fear_level3_001.mp4 differ
diff --git a/damo/dreamtalk/data/style_clip/video/W009_front_happy_level3_001.mp4 b/damo/dreamtalk/data/style_clip/video/W009_front_happy_level3_001.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..e713c3a0b69a0300dbb35b06afb2ca3bfa97acca
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/video/W009_front_happy_level3_001.mp4 differ
diff --git a/damo/dreamtalk/data/style_clip/video/W009_front_neutral_level1_001.mp4 b/damo/dreamtalk/data/style_clip/video/W009_front_neutral_level1_001.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..95627a6443e9b062033ca48b9004ea755978d7d9
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/video/W009_front_neutral_level1_001.mp4 differ
diff --git a/damo/dreamtalk/data/style_clip/video/W009_front_sad_level3_001.mp4 b/damo/dreamtalk/data/style_clip/video/W009_front_sad_level3_001.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..d2fe6774d0960a244dd240e90098908e118682f1
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/video/W009_front_sad_level3_001.mp4 differ
diff --git a/damo/dreamtalk/data/style_clip/video/W009_front_surprised_level3_001.mp4 b/damo/dreamtalk/data/style_clip/video/W009_front_surprised_level3_001.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..959b327eebd1453e095f60e77e2260da4969a0d4
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/video/W009_front_surprised_level3_001.mp4 differ
diff --git a/damo/dreamtalk/data/style_clip/video/W011_front_angry_level3_001.mp4 b/damo/dreamtalk/data/style_clip/video/W011_front_angry_level3_001.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..a3c115ce906eefa937a26d21f0bb248414f10e2e
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/video/W011_front_angry_level3_001.mp4 differ
diff --git a/damo/dreamtalk/data/style_clip/video/W011_front_contempt_level3_001.mp4 b/damo/dreamtalk/data/style_clip/video/W011_front_contempt_level3_001.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..f72b5a69e5ad29f800de631cf82ad369772616ee
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/video/W011_front_contempt_level3_001.mp4 differ
diff --git a/damo/dreamtalk/data/style_clip/video/W011_front_disgusted_level3_001.mp4 b/damo/dreamtalk/data/style_clip/video/W011_front_disgusted_level3_001.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..e6d11603479371719fa6f8204942f1e2252b6214
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/video/W011_front_disgusted_level3_001.mp4 differ
diff --git a/damo/dreamtalk/data/style_clip/video/W011_front_fear_level3_001.mp4 b/damo/dreamtalk/data/style_clip/video/W011_front_fear_level3_001.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..da8172c3e36c4edd15f9bd3716c1053651487a2c
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/video/W011_front_fear_level3_001.mp4 differ
diff --git a/damo/dreamtalk/data/style_clip/video/W011_front_happy_level3_001.mp4 b/damo/dreamtalk/data/style_clip/video/W011_front_happy_level3_001.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..c72089c182c87ab4b751ca2cbd264b207697eb41
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/video/W011_front_happy_level3_001.mp4 differ
diff --git a/damo/dreamtalk/data/style_clip/video/W011_front_neutral_level1_001.mp4 b/damo/dreamtalk/data/style_clip/video/W011_front_neutral_level1_001.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..0e89207480b08d94546fe4b76019ccc019523399
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/video/W011_front_neutral_level1_001.mp4 differ
diff --git a/damo/dreamtalk/data/style_clip/video/W011_front_sad_level3_001.mp4 b/damo/dreamtalk/data/style_clip/video/W011_front_sad_level3_001.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..b67568a49ab1439dc22b5a41765a3387b226536e
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/video/W011_front_sad_level3_001.mp4 differ
diff --git a/damo/dreamtalk/data/style_clip/video/W011_front_surprised_level3_001.mp4 b/damo/dreamtalk/data/style_clip/video/W011_front_surprised_level3_001.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..c1e62ffe96401b6bf8326c0a082b9407e956b4a6
Binary files /dev/null and b/damo/dreamtalk/data/style_clip/video/W011_front_surprised_level3_001.mp4 differ
diff --git a/damo/dreamtalk/generators/base_function.py b/damo/dreamtalk/generators/base_function.py
new file mode 100644
index 0000000000000000000000000000000000000000..49fe4cf3d07c4a22f7d7db4bf0a97ebddc87dd72
--- /dev/null
+++ b/damo/dreamtalk/generators/base_function.py
@@ -0,0 +1,368 @@
+import sys
+import math
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.autograd import Function
+from torch.nn.utils.spectral_norm import spectral_norm as SpectralNorm
+
+
+class LayerNorm2d(nn.Module):
+ def __init__(self, n_out, affine=True):
+ super(LayerNorm2d, self).__init__()
+ self.n_out = n_out
+ self.affine = affine
+
+ if self.affine:
+ self.weight = nn.Parameter(torch.ones(n_out, 1, 1))
+ self.bias = nn.Parameter(torch.zeros(n_out, 1, 1))
+
+ def forward(self, x):
+ normalized_shape = x.size()[1:]
+ if self.affine:
+            return F.layer_norm(x, normalized_shape,
+                                self.weight.expand(normalized_shape),
+                                self.bias.expand(normalized_shape))
+
+ else:
+ return F.layer_norm(x, normalized_shape)
+
+class ADAINHourglass(nn.Module):
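+    """Hourglass generator: an ADAIN encoder and an ADAIN decoder with skip
+    connections, both modulated by the driving feature `z`."""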
+ def __init__(self, image_nc, pose_nc, ngf, img_f, encoder_layers, decoder_layers, nonlinearity, use_spect):
+ super(ADAINHourglass, self).__init__()
+ self.encoder = ADAINEncoder(image_nc, pose_nc, ngf, img_f, encoder_layers, nonlinearity, use_spect)
+ self.decoder = ADAINDecoder(pose_nc, ngf, img_f, encoder_layers, decoder_layers, True, nonlinearity, use_spect)
+ self.output_nc = self.decoder.output_nc
+
+ def forward(self, x, z):
+ return self.decoder(self.encoder(x, z), z)
+
+
+
+class ADAINEncoder(nn.Module):
+ def __init__(self, image_nc, pose_nc, ngf, img_f, layers, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(ADAINEncoder, self).__init__()
+ self.layers = layers
+ self.input_layer = nn.Conv2d(image_nc, ngf, kernel_size=7, stride=1, padding=3)
+ for i in range(layers):
+ in_channels = min(ngf * (2**i), img_f)
+ out_channels = min(ngf *(2**(i+1)), img_f)
+ model = ADAINEncoderBlock(in_channels, out_channels, pose_nc, nonlinearity, use_spect)
+ setattr(self, 'encoder' + str(i), model)
+ self.output_nc = out_channels
+
+ def forward(self, x, z):
+ out = self.input_layer(x)
+ out_list = [out]
+ for i in range(self.layers):
+ model = getattr(self, 'encoder' + str(i))
+ out = model(out, z)
+ out_list.append(out)
+ return out_list
+
+class ADAINDecoder(nn.Module):
+    """ADAIN-conditioned decoder with optional skip connections from the encoder feature list."""
+ def __init__(self, pose_nc, ngf, img_f, encoder_layers, decoder_layers, skip_connect=True,
+ nonlinearity=nn.LeakyReLU(), use_spect=False):
+
+ super(ADAINDecoder, self).__init__()
+ self.encoder_layers = encoder_layers
+ self.decoder_layers = decoder_layers
+ self.skip_connect = skip_connect
+ use_transpose = True
+
+ for i in range(encoder_layers-decoder_layers, encoder_layers)[::-1]:
+ in_channels = min(ngf * (2**(i+1)), img_f)
+ in_channels = in_channels*2 if i != (encoder_layers-1) and self.skip_connect else in_channels
+ out_channels = min(ngf * (2**i), img_f)
+ model = ADAINDecoderBlock(in_channels, out_channels, out_channels, pose_nc, use_transpose, nonlinearity, use_spect)
+ setattr(self, 'decoder' + str(i), model)
+
+ self.output_nc = out_channels*2 if self.skip_connect else out_channels
+
+ def forward(self, x, z):
+ out = x.pop() if self.skip_connect else x
+ for i in range(self.encoder_layers-self.decoder_layers, self.encoder_layers)[::-1]:
+ model = getattr(self, 'decoder' + str(i))
+ out = model(out, z)
+ out = torch.cat([out, x.pop()], 1) if self.skip_connect else out
+ return out
+
+class ADAINEncoderBlock(nn.Module):
+ def __init__(self, input_nc, output_nc, feature_nc, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(ADAINEncoderBlock, self).__init__()
+ kwargs_down = {'kernel_size': 4, 'stride': 2, 'padding': 1}
+ kwargs_fine = {'kernel_size': 3, 'stride': 1, 'padding': 1}
+
+ self.conv_0 = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs_down), use_spect)
+ self.conv_1 = spectral_norm(nn.Conv2d(output_nc, output_nc, **kwargs_fine), use_spect)
+
+
+ self.norm_0 = ADAIN(input_nc, feature_nc)
+ self.norm_1 = ADAIN(output_nc, feature_nc)
+ self.actvn = nonlinearity
+
+ def forward(self, x, z):
+ x = self.conv_0(self.actvn(self.norm_0(x, z)))
+ x = self.conv_1(self.actvn(self.norm_1(x, z)))
+ return x
+
+class ADAINDecoderBlock(nn.Module):
+ def __init__(self, input_nc, output_nc, hidden_nc, feature_nc, use_transpose=True, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(ADAINDecoderBlock, self).__init__()
+ # Attributes
+ self.actvn = nonlinearity
+ hidden_nc = min(input_nc, output_nc) if hidden_nc is None else hidden_nc
+
+ kwargs_fine = {'kernel_size':3, 'stride':1, 'padding':1}
+ if use_transpose:
+ kwargs_up = {'kernel_size':3, 'stride':2, 'padding':1, 'output_padding':1}
+ else:
+ kwargs_up = {'kernel_size':3, 'stride':1, 'padding':1}
+
+ # create conv layers
+ self.conv_0 = spectral_norm(nn.Conv2d(input_nc, hidden_nc, **kwargs_fine), use_spect)
+ if use_transpose:
+ self.conv_1 = spectral_norm(nn.ConvTranspose2d(hidden_nc, output_nc, **kwargs_up), use_spect)
+ self.conv_s = spectral_norm(nn.ConvTranspose2d(input_nc, output_nc, **kwargs_up), use_spect)
+ else:
+ self.conv_1 = nn.Sequential(spectral_norm(nn.Conv2d(hidden_nc, output_nc, **kwargs_up), use_spect),
+ nn.Upsample(scale_factor=2))
+ self.conv_s = nn.Sequential(spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs_up), use_spect),
+ nn.Upsample(scale_factor=2))
+ # define normalization layers
+ self.norm_0 = ADAIN(input_nc, feature_nc)
+ self.norm_1 = ADAIN(hidden_nc, feature_nc)
+ self.norm_s = ADAIN(input_nc, feature_nc)
+
+ def forward(self, x, z):
+ x_s = self.shortcut(x, z)
+ dx = self.conv_0(self.actvn(self.norm_0(x, z)))
+ dx = self.conv_1(self.actvn(self.norm_1(dx, z)))
+ out = x_s + dx
+ return out
+
+ def shortcut(self, x, z):
+ x_s = self.conv_s(self.actvn(self.norm_s(x, z)))
+ return x_s
+
+
+def spectral_norm(module, use_spect=True):
+    """Wrap the module with spectral normalization (to stabilize training) when use_spect is True."""
+ if use_spect:
+ return SpectralNorm(module)
+ else:
+ return module
+
+
+class ADAIN(nn.Module):
+ def __init__(self, norm_nc, feature_nc):
+ super().__init__()
+
+ self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
+
+ nhidden = 128
+ use_bias=True
+
+ self.mlp_shared = nn.Sequential(
+ nn.Linear(feature_nc, nhidden, bias=use_bias),
+ nn.ReLU()
+ )
+ self.mlp_gamma = nn.Linear(nhidden, norm_nc, bias=use_bias)
+ self.mlp_beta = nn.Linear(nhidden, norm_nc, bias=use_bias)
+
+ def forward(self, x, feature):
+
+ # Part 1. generate parameter-free normalized activations
+ normalized = self.param_free_norm(x)
+
+ # Part 2. produce scaling and bias conditioned on feature
+ feature = feature.view(feature.size(0), -1)
+ actv = self.mlp_shared(feature)
+ gamma = self.mlp_gamma(actv)
+ beta = self.mlp_beta(actv)
+
+ # apply scale and bias
+ gamma = gamma.view(*gamma.size()[:2], 1,1)
+ beta = beta.view(*beta.size()[:2], 1,1)
+ out = normalized * (1 + gamma) + beta
+ return out
+
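+# Shape sketch for the ADAIN block above (illustrative values): the driving
+# feature is flattened, mapped to per-channel gamma/beta, and used to rescale
+# the instance-normalized input as normalized * (1 + gamma) + beta, e.g.
+#   adain = ADAIN(norm_nc=256, feature_nc=256)
+#   y = adain(torch.randn(2, 256, 32, 32), torch.randn(2, 256))  # -> (2, 256, 32, 32)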
+
+class FineEncoder(nn.Module):
+    """Multi-scale encoder that returns the feature map of every resolution level."""
+ def __init__(self, image_nc, ngf, img_f, layers, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(FineEncoder, self).__init__()
+ self.layers = layers
+ self.first = FirstBlock2d(image_nc, ngf, norm_layer, nonlinearity, use_spect)
+ for i in range(layers):
+ in_channels = min(ngf*(2**i), img_f)
+ out_channels = min(ngf*(2**(i+1)), img_f)
+ model = DownBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
+ setattr(self, 'down' + str(i), model)
+ self.output_nc = out_channels
+
+ def forward(self, x):
+ x = self.first(x)
+ out=[x]
+ for i in range(self.layers):
+ model = getattr(self, 'down'+str(i))
+ x = model(x)
+ out.append(x)
+ return out
+
+class FineDecoder(nn.Module):
+    """Decoder with ADAIN residual blocks and encoder skip connections that produces the output image."""
+ def __init__(self, image_nc, feature_nc, ngf, img_f, layers, num_block, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(FineDecoder, self).__init__()
+ self.layers = layers
+ for i in range(layers)[::-1]:
+ in_channels = min(ngf*(2**(i+1)), img_f)
+ out_channels = min(ngf*(2**i), img_f)
+ up = UpBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
+ res = FineADAINResBlocks(num_block, in_channels, feature_nc, norm_layer, nonlinearity, use_spect)
+ jump = Jump(out_channels, norm_layer, nonlinearity, use_spect)
+
+ setattr(self, 'up' + str(i), up)
+ setattr(self, 'res' + str(i), res)
+ setattr(self, 'jump' + str(i), jump)
+
+ self.final = FinalBlock2d(out_channels, image_nc, use_spect, 'tanh')
+
+ self.output_nc = out_channels
+
+ def forward(self, x, z):
+ out = x.pop()
+ for i in range(self.layers)[::-1]:
+ res_model = getattr(self, 'res' + str(i))
+ up_model = getattr(self, 'up' + str(i))
+ jump_model = getattr(self, 'jump' + str(i))
+ out = res_model(out, z)
+ out = up_model(out)
+ out = jump_model(x.pop()) + out
+ out_image = self.final(out)
+ return out_image
+
+class FirstBlock2d(nn.Module):
+ """
+    Initial 7x7 convolution block used at the start of the encoder (keeps spatial resolution).
+ """
+ def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(FirstBlock2d, self).__init__()
+ kwargs = {'kernel_size': 7, 'stride': 1, 'padding': 3}
+ conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
+
+        if norm_layer is None:
+ self.model = nn.Sequential(conv, nonlinearity)
+ else:
+ self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity)
+
+
+ def forward(self, x):
+ out = self.model(x)
+ return out
+
+class DownBlock2d(nn.Module):
+ def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(DownBlock2d, self).__init__()
+
+
+ kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
+ conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
+ pool = nn.AvgPool2d(kernel_size=(2, 2))
+
+        if norm_layer is None:
+ self.model = nn.Sequential(conv, nonlinearity, pool)
+ else:
+ self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity, pool)
+
+ def forward(self, x):
+ out = self.model(x)
+ return out
+
+class UpBlock2d(nn.Module):
+ def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(UpBlock2d, self).__init__()
+ kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
+ conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
+        if norm_layer is None:
+ self.model = nn.Sequential(conv, nonlinearity)
+ else:
+ self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity)
+
+ def forward(self, x):
+ out = self.model(F.interpolate(x, scale_factor=2))
+ return out
+
+class FineADAINResBlocks(nn.Module):
+ def __init__(self, num_block, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(FineADAINResBlocks, self).__init__()
+ self.num_block = num_block
+ for i in range(num_block):
+ model = FineADAINResBlock2d(input_nc, feature_nc, norm_layer, nonlinearity, use_spect)
+ setattr(self, 'res'+str(i), model)
+
+ def forward(self, x, z):
+ for i in range(self.num_block):
+ model = getattr(self, 'res'+str(i))
+ x = model(x, z)
+ return x
+
+class Jump(nn.Module):
+ def __init__(self, input_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(Jump, self).__init__()
+ kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
+ conv = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)
+
+        if norm_layer is None:
+ self.model = nn.Sequential(conv, nonlinearity)
+ else:
+ self.model = nn.Sequential(conv, norm_layer(input_nc), nonlinearity)
+
+ def forward(self, x):
+ out = self.model(x)
+ return out
+
+class FineADAINResBlock2d(nn.Module):
+ """
+    Residual block whose normalization layers are ADAIN, modulated by the driving feature.
+ """
+ def __init__(self, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+ super(FineADAINResBlock2d, self).__init__()
+
+ kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
+
+ self.conv1 = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)
+ self.conv2 = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)
+ self.norm1 = ADAIN(input_nc, feature_nc)
+ self.norm2 = ADAIN(input_nc, feature_nc)
+
+ self.actvn = nonlinearity
+
+
+ def forward(self, x, z):
+ dx = self.actvn(self.norm1(self.conv1(x), z))
+ dx = self.norm2(self.conv2(x), z)
+ out = dx + x
+ return out
+
+class FinalBlock2d(nn.Module):
+ """
+ Define the output layer
+ """
+ def __init__(self, input_nc, output_nc, use_spect=False, tanh_or_sigmoid='tanh'):
+ super(FinalBlock2d, self).__init__()
+
+ kwargs = {'kernel_size': 7, 'stride': 1, 'padding':3}
+ conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
+
+ if tanh_or_sigmoid == 'sigmoid':
+ out_nonlinearity = nn.Sigmoid()
+ else:
+ out_nonlinearity = nn.Tanh()
+
+        self.model = nn.Sequential(conv, out_nonlinearity)
+
+ def forward(self, x):
+ out = self.model(x)
+ return out
\ No newline at end of file
diff --git a/damo/dreamtalk/generators/face_model.py b/damo/dreamtalk/generators/face_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..20392f0cb3bdbefb6ecdc20b43ffe0d7b87cbe6c
--- /dev/null
+++ b/damo/dreamtalk/generators/face_model.py
@@ -0,0 +1,127 @@
+import functools
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import generators.flow_util as flow_util
+from generators.base_function import LayerNorm2d, ADAINHourglass, FineEncoder, FineDecoder
+
+class FaceGenerator(nn.Module):
+ def __init__(
+ self,
+ mapping_net,
+ warpping_net,
+ editing_net,
+ common
+ ):
+ super(FaceGenerator, self).__init__()
+ self.mapping_net = MappingNet(**mapping_net)
+ self.warpping_net = WarpingNet(**warpping_net, **common)
+ self.editing_net = EditingNet(**editing_net, **common)
+
+ def forward(
+ self,
+ input_image,
+ driving_source,
+ stage=None
+ ):
+ if stage == 'warp':
+ descriptor = self.mapping_net(driving_source)
+ output = self.warpping_net(input_image, descriptor)
+ else:
+ descriptor = self.mapping_net(driving_source)
+ output = self.warpping_net(input_image, descriptor)
+ output['fake_image'] = self.editing_net(input_image, output['warp_image'], descriptor)
+ return output
+
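+# The generator runs in three stages: MappingNet turns the windowed 3DMM
+# coefficients into a compact descriptor, WarpingNet predicts a dense flow
+# field and warps the source portrait with it, and EditingNet refines the
+# warped image into the final frame (skipped when stage == 'warp').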
+class MappingNet(nn.Module):
+ def __init__(self, coeff_nc, descriptor_nc, layer):
+        super(MappingNet, self).__init__()
+
+ self.layer = layer
+ nonlinearity = nn.LeakyReLU(0.1)
+
+ self.first = nn.Sequential(
+ torch.nn.Conv1d(coeff_nc, descriptor_nc, kernel_size=7, padding=0, bias=True))
+
+ for i in range(layer):
+ net = nn.Sequential(nonlinearity,
+ torch.nn.Conv1d(descriptor_nc, descriptor_nc, kernel_size=3, padding=0, dilation=3))
+ setattr(self, 'encoder' + str(i), net)
+
+ self.pooling = nn.AdaptiveAvgPool1d(1)
+ self.output_nc = descriptor_nc
+
+ def forward(self, input_3dmm):
+ out = self.first(input_3dmm)
+ for i in range(self.layer):
+ model = getattr(self, 'encoder' + str(i))
+            out = model(out) + out[:, :, 3:-3]  # each dilated conv (kernel 3, dilation 3, no padding) trims 3 frames per side; crop the residual to match
+ out = self.pooling(out)
+ return out
+
+class WarpingNet(nn.Module):
+ def __init__(
+ self,
+ image_nc,
+ descriptor_nc,
+ base_nc,
+ max_nc,
+ encoder_layer,
+ decoder_layer,
+ use_spect
+ ):
+        super(WarpingNet, self).__init__()
+
+ nonlinearity = nn.LeakyReLU(0.1)
+ norm_layer = functools.partial(LayerNorm2d, affine=True)
+ kwargs = {'nonlinearity':nonlinearity, 'use_spect':use_spect}
+
+ self.descriptor_nc = descriptor_nc
+ self.hourglass = ADAINHourglass(image_nc, self.descriptor_nc, base_nc,
+ max_nc, encoder_layer, decoder_layer, **kwargs)
+
+ self.flow_out = nn.Sequential(norm_layer(self.hourglass.output_nc),
+ nonlinearity,
+ nn.Conv2d(self.hourglass.output_nc, 2, kernel_size=7, stride=1, padding=3))
+
+ self.pool = nn.AdaptiveAvgPool2d(1)
+
+ def forward(self, input_image, descriptor):
+ final_output={}
+ output = self.hourglass(input_image, descriptor)
+ final_output['flow_field'] = self.flow_out(output)
+
+ deformation = flow_util.convert_flow_to_deformation(final_output['flow_field'])
+ final_output['warp_image'] = flow_util.warp_image(input_image, deformation)
+ return final_output
+
+
+class EditingNet(nn.Module):
+ def __init__(
+ self,
+ image_nc,
+ descriptor_nc,
+ layer,
+ base_nc,
+ max_nc,
+ num_res_blocks,
+ use_spect):
+ super(EditingNet, self).__init__()
+
+ nonlinearity = nn.LeakyReLU(0.1)
+ norm_layer = functools.partial(LayerNorm2d, affine=True)
+ kwargs = {'norm_layer':norm_layer, 'nonlinearity':nonlinearity, 'use_spect':use_spect}
+ self.descriptor_nc = descriptor_nc
+
+ # encoder part
+ self.encoder = FineEncoder(image_nc*2, base_nc, max_nc, layer, **kwargs)
+ self.decoder = FineDecoder(image_nc, self.descriptor_nc, base_nc, max_nc, layer, num_res_blocks, **kwargs)
+
+ def forward(self, input_image, warp_image, descriptor):
+ x = torch.cat([input_image, warp_image], 1)
+ x = self.encoder(x)
+ gen_image = self.decoder(x, descriptor)
+ return gen_image
diff --git a/damo/dreamtalk/generators/flow_util.py b/damo/dreamtalk/generators/flow_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..376a6cbe222bfe3e1833b954e764e4e6c086c766
--- /dev/null
+++ b/damo/dreamtalk/generators/flow_util.py
@@ -0,0 +1,56 @@
+import torch
+
+def convert_flow_to_deformation(flow):
+ r"""convert flow fields to deformations.
+
+ Args:
+ flow (tensor): Flow field obtained by the model
+ Returns:
+        deformation (tensor): The deformation used for warping
+ """
+ b,c,h,w = flow.shape
+ flow_norm = 2 * torch.cat([flow[:,:1,...]/(w-1),flow[:,1:,...]/(h-1)], 1)
+ grid = make_coordinate_grid(flow)
+ deformation = grid + flow_norm.permute(0,2,3,1)
+ return deformation
+
+def make_coordinate_grid(flow):
+    r"""obtain a coordinate grid with the same size as the flow field.
+
+ Args:
+ flow (tensor): Flow field obtained by the model
+ Returns:
+ grid (tensor): The grid with the same size as the input flow
+ """
+ b,c,h,w = flow.shape
+
+ x = torch.arange(w).to(flow)
+ y = torch.arange(h).to(flow)
+
+ x = (2 * (x / (w - 1)) - 1)
+ y = (2 * (y / (h - 1)) - 1)
+
+ yy = y.view(-1, 1).repeat(1, w)
+ xx = x.view(1, -1).repeat(h, 1)
+
+ meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)
+ meshed = meshed.expand(b, -1, -1, -1)
+ return meshed
+
+
+def warp_image(source_image, deformation):
+ r"""warp the input image according to the deformation
+
+ Args:
+        source_image (tensor): source images to be warped
+ deformation (tensor): deformations used to warp the images; value in range (-1, 1)
+ Returns:
+        output (tensor): the warped images
+ """
+ _, h_old, w_old, _ = deformation.shape
+ _, _, h, w = source_image.shape
+ if h_old != h or w_old != w:
+ deformation = deformation.permute(0, 3, 1, 2)
+ deformation = torch.nn.functional.interpolate(deformation, size=(h, w), mode='bilinear')
+ deformation = deformation.permute(0, 2, 3, 1)
+ return torch.nn.functional.grid_sample(source_image, deformation)
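+# Usage sketch (illustrative shapes): a predicted flow of shape (B, 2, H, W)
+# is converted to a normalized sampling grid and then used to warp the source
+# image, e.g.
+#   flow = torch.zeros(1, 2, 64, 64)                             # all-zero flow for illustration
+#   deformation = convert_flow_to_deformation(flow)              # (1, 64, 64, 2)
+#   warped = warp_image(torch.randn(1, 3, 64, 64), deformation)  # (1, 3, 64, 64)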
\ No newline at end of file
diff --git a/damo/dreamtalk/generators/renderer_conf.yaml b/damo/dreamtalk/generators/renderer_conf.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..bd1a7973853f52338a13d0d80341c295190f3e75
--- /dev/null
+++ b/damo/dreamtalk/generators/renderer_conf.yaml
@@ -0,0 +1,17 @@
+common:
+ descriptor_nc: 256
+ image_nc: 3
+ max_nc: 256
+ use_spect: false
+editing_net:
+ base_nc: 64
+ layer: 3
+ num_res_blocks: 2
+mapping_net:
+ coeff_nc: 73
+ descriptor_nc: 256
+ layer: 3
+warpping_net:
+ base_nc: 32
+ decoder_layer: 3
+ encoder_layer: 5
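+# Each top-level section maps onto a sub-network of FaceGenerator in
+# generators/face_model.py: mapping_net, warpping_net and editing_net are
+# passed as keyword arguments to MappingNet, WarpingNet and EditingNet,
+# while the common keys are shared by the warping and editing networks.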
diff --git a/damo/dreamtalk/generators/utils.py b/damo/dreamtalk/generators/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..efc3a62c6435add011c27e9a7fc941d18fc52559
--- /dev/null
+++ b/damo/dreamtalk/generators/utils.py
@@ -0,0 +1,114 @@
+import argparse
+import cv2
+import json
+import os
+
+import numpy as np
+import torch
+import torchvision
+import torchvision.transforms as transforms
+from PIL import Image
+
+
+def obtain_seq_index(index, num_frames, radius):
+ seq = list(range(index - radius, index + radius + 1))
+ seq = [min(max(item, 0), num_frames - 1) for item in seq]
+ return seq
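+# e.g. obtain_seq_index(0, num_frames=100, radius=3) -> [0, 0, 0, 0, 1, 2, 3]:
+# a window of 2*radius+1 frame indices centered on `index`, clamped to the clip.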
+
+
+@torch.no_grad()
+def get_netG(checkpoint_path):
+ from generators.face_model import FaceGenerator
+ import yaml
+
+ with open("generators/renderer_conf.yaml", "r") as f:
+ renderer_config = yaml.load(f, Loader=yaml.FullLoader)
+
+ renderer = FaceGenerator(**renderer_config).to(torch.cuda.current_device())
+
+ checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
+ renderer.load_state_dict(checkpoint["net_G_ema"], strict=False)
+
+ renderer.eval()
+
+ return renderer
+
+
+@torch.no_grad()
+def render_video(
+ net_G,
+ src_img_path,
+ exp_path,
+ wav_path,
+ output_path,
+ silent=False,
+ semantic_radius=13,
+ fps=30,
+ split_size=16,
+ no_move=False,
+):
+ """
+ exp: (N, 73)
+ """
+ target_exp_seq = np.load(exp_path)
+ if target_exp_seq.shape[1] == 257:
+ exp_coeff = target_exp_seq[:, 80:144]
+ angle_trans_crop = np.array(
+ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9370641, 126.84911, 129.03864],
+ dtype=np.float32,
+ )
+ target_exp_seq = np.concatenate(
+ [exp_coeff, angle_trans_crop[None, ...].repeat(exp_coeff.shape[0], axis=0)],
+ axis=1,
+ )
+ # (L, 73)
+ elif target_exp_seq.shape[1] == 73:
+ if no_move:
+ target_exp_seq[:, 64:] = np.array(
+ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9370641, 126.84911, 129.03864],
+ dtype=np.float32,
+ )
+ else:
+ raise NotImplementedError
+
+ frame = cv2.imread(src_img_path)
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ src_img_raw = Image.fromarray(frame)
+ image_transform = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
+ ]
+ )
+ src_img = image_transform(src_img_raw)
+
+ target_win_exps = []
+ for frame_idx in range(len(target_exp_seq)):
+ win_indices = obtain_seq_index(
+ frame_idx, target_exp_seq.shape[0], semantic_radius
+ )
+ win_exp = torch.tensor(target_exp_seq[win_indices]).permute(1, 0)
+ # (73, 27)
+ target_win_exps.append(win_exp)
+
+ target_exp_concat = torch.stack(target_win_exps, dim=0)
+ target_splited_exps = torch.split(target_exp_concat, split_size, dim=0)
+ output_imgs = []
+ for win_exp in target_splited_exps:
+ win_exp = win_exp.cuda()
+ cur_src_img = src_img.expand(win_exp.shape[0], -1, -1, -1).cuda()
+ output_dict = net_G(cur_src_img, win_exp)
+ output_imgs.append(output_dict["fake_image"].cpu().clamp_(-1, 1))
+
+ output_imgs = torch.cat(output_imgs, 0)
+ transformed_imgs = ((output_imgs + 1) / 2 * 255).to(torch.uint8).permute(0, 2, 3, 1)
+
+ if silent:
+ torchvision.io.write_video(output_path, transformed_imgs.cpu(), fps)
+ else:
+ silent_video_path = f"{output_path}-silent.mp4"
+ torchvision.io.write_video(silent_video_path, transformed_imgs.cpu(), fps)
+ os.system(
+ f"ffmpeg -loglevel quiet -y -i {silent_video_path} -i {wav_path} -shortest {output_path}"
+ )
+ os.remove(silent_video_path)
diff --git a/damo/dreamtalk/inference_for_demo_video.py b/damo/dreamtalk/inference_for_demo_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..86f8a8520ccd6b6746a6420b749f33ba5882a833
--- /dev/null
+++ b/damo/dreamtalk/inference_for_demo_video.py
@@ -0,0 +1,234 @@
+import argparse
+import torch
+import json
+import os
+
+from scipy.io import loadmat
+import subprocess
+
+import numpy as np
+import torchaudio
+import shutil
+
+from core.utils import (
+ get_pose_params,
+ get_video_style_clip,
+ get_wav2vec_audio_window,
+ crop_src_image,
+)
+from configs.default import get_cfg_defaults
+from generators.utils import get_netG, render_video
+from core.networks.diffusion_net import DiffusionNet
+from core.networks.diffusion_util import NoisePredictor, VarianceSchedule
+from transformers import Wav2Vec2Processor
+from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Model
+
+
+@torch.no_grad()
+def get_diff_net(cfg):
+ diff_net = DiffusionNet(
+ cfg=cfg,
+ net=NoisePredictor(cfg),
+ var_sched=VarianceSchedule(
+ num_steps=cfg.DIFFUSION.SCHEDULE.NUM_STEPS,
+ beta_1=cfg.DIFFUSION.SCHEDULE.BETA_1,
+ beta_T=cfg.DIFFUSION.SCHEDULE.BETA_T,
+ mode=cfg.DIFFUSION.SCHEDULE.MODE,
+ ),
+ )
+ checkpoint = torch.load(cfg.INFERENCE.CHECKPOINT)
+ model_state_dict = checkpoint["model_state_dict"]
+ diff_net_dict = {
+ k[9:]: v for k, v in model_state_dict.items() if k[:9] == "diff_net."
+ }
+ diff_net.load_state_dict(diff_net_dict, strict=True)
+ diff_net.eval()
+
+ return diff_net
+
+
+@torch.no_grad()
+def get_audio_feat(wav_path, output_name, wav2vec_model):
+    # Unused stub; wav2vec features are extracted inline in the __main__ block below.
+    pass
+
+
+@torch.no_grad()
+def inference_one_video(
+ cfg,
+ audio_path,
+ style_clip_path,
+ pose_path,
+ output_path,
+ diff_net,
+ max_audio_len=None,
+ sample_method="ddim",
+ ddim_num_step=10,
+):
+    audio_raw = np.load(audio_path)
+
+ if max_audio_len is not None:
+ audio_raw = audio_raw[: max_audio_len * 50]
+ gen_num_frames = len(audio_raw) // 2
+
+ audio_win_array = get_wav2vec_audio_window(
+ audio_raw,
+ start_idx=0,
+ num_frames=gen_num_frames,
+ win_size=cfg.WIN_SIZE,
+ )
+
+ audio_win = torch.tensor(audio_win_array).cuda()
+ audio = audio_win.unsqueeze(0)
+
+ # the second parameter is "" because of bad interface design...
+ style_clip_raw, style_pad_mask_raw = get_video_style_clip(
+ style_clip_path, "", style_max_len=256, start_idx=0
+ )
+
+ style_clip = style_clip_raw.unsqueeze(0).cuda()
+ style_pad_mask = (
+ style_pad_mask_raw.unsqueeze(0).cuda()
+ if style_pad_mask_raw is not None
+ else None
+ )
+
+ gen_exp_stack = diff_net.sample(
+ audio,
+ style_clip,
+ style_pad_mask,
+ output_dim=cfg.DATASET.FACE3D_DIM,
+ use_cf_guidance=cfg.CF_GUIDANCE.INFERENCE,
+ cfg_scale=cfg.CF_GUIDANCE.SCALE,
+ sample_method=sample_method,
+ ddim_num_step=ddim_num_step,
+ )
+ gen_exp = gen_exp_stack[0].cpu().numpy()
+
+    pose = get_pose_params(pose_path)
+ # (L, 9)
+
+ selected_pose = None
+ if len(pose) >= len(gen_exp):
+ selected_pose = pose[: len(gen_exp)]
+ else:
+ selected_pose = pose[-1].unsqueeze(0).repeat(len(gen_exp), 1)
+ selected_pose[: len(pose)] = pose
+
+ gen_exp_pose = np.concatenate((gen_exp, selected_pose), axis=1)
+ np.save(output_path, gen_exp_pose)
+ return output_path
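+# The saved motion array has shape (L, 73): the generated expression
+# coefficients concatenated with the 9 pose/crop parameters per frame,
+# i.e. the 73-dim input that render_video and coeff_nc in
+# generators/renderer_conf.yaml expect.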
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="inference for demo")
+ parser.add_argument("--wav_path", type=str, default="", help="path for wav")
+ parser.add_argument("--image_path", type=str, default="", help="path for image")
+ parser.add_argument("--disable_img_crop", dest="img_crop", action="store_false")
+ parser.set_defaults(img_crop=True)
+
+ parser.add_argument(
+ "--style_clip_path", type=str, default="", help="path for style_clip_mat"
+ )
+ parser.add_argument("--pose_path", type=str, default="", help="path for pose")
+ parser.add_argument(
+ "--max_gen_len",
+ type=int,
+ default=1000,
+        help="Maximum length in seconds of the generated video",
+ )
+ parser.add_argument(
+ "--cfg_scale",
+ type=float,
+ default=1.0,
+ help="The scale of classifier-free guidance",
+ )
+ parser.add_argument(
+ "--output_name",
+ type=str,
+ default="test",
+ )
+ args = parser.parse_args()
+
+ cfg = get_cfg_defaults()
+ cfg.CF_GUIDANCE.SCALE = args.cfg_scale
+ cfg.freeze()
+
+ tmp_dir = f"tmp/{args.output_name}"
+ os.makedirs(tmp_dir, exist_ok=True)
+
+ # get audio in 16000Hz
+ wav_16k_path = os.path.join(tmp_dir, f"{args.output_name}_16K.wav")
+ command = f"ffmpeg -y -i {args.wav_path} -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 {wav_16k_path}"
+ subprocess.run(command.split())
+
+ # get wav2vec feat from audio
+ wav2vec_processor = Wav2Vec2Processor.from_pretrained(
+ "jonatasgrosman/wav2vec2-large-xlsr-53-english"
+ )
+ wav2vec_model = (
+ Wav2Vec2Model.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-english")
+ .eval()
+ .cuda()
+ )
+
+ speech_array, sampling_rate = torchaudio.load(wav_16k_path)
+ audio_data = speech_array.squeeze().numpy()
+ inputs = wav2vec_processor(
+ audio_data, sampling_rate=16_000, return_tensors="pt", padding=True
+ )
+
+ with torch.no_grad():
+ audio_embedding = wav2vec_model(inputs.input_values.cuda(), return_dict=False)[
+ 0
+ ]
+
+ audio_feat_path = os.path.join(tmp_dir, f"{args.output_name}_wav2vec.npy")
+ np.save(audio_feat_path, audio_embedding[0].cpu().numpy())
+
+ # get src image
+ src_img_path = os.path.join(tmp_dir, "src_img.png")
+ if args.img_crop:
+ crop_src_image(args.image_path, src_img_path, 0.4)
+ else:
+ shutil.copy(args.image_path, src_img_path)
+
+ with torch.no_grad():
+ # get diff model and load checkpoint
+ diff_net = get_diff_net(cfg).cuda()
+ # generate face motion
+ face_motion_path = os.path.join(tmp_dir, f"{args.output_name}_facemotion.npy")
+ inference_one_video(
+ cfg,
+ audio_feat_path,
+ args.style_clip_path,
+ args.pose_path,
+ face_motion_path,
+ diff_net,
+ max_audio_len=args.max_gen_len,
+ )
+ # get renderer
+ renderer = get_netG("checkpoints/renderer.pt")
+ # render video
+ output_video_path = f"output_video/{args.output_name}.mp4"
+ render_video(
+ renderer,
+ src_img_path,
+ face_motion_path,
+ wav_16k_path,
+ output_video_path,
+ fps=25,
+ no_move=False,
+ )
+
+ # add watermark
+ # if you want to generate videos with no watermark (for evaluation), remove this code block.
+ no_watermark_video_path = f"{output_video_path}-no_watermark.mp4"
+ shutil.move(output_video_path, no_watermark_video_path)
+ os.system(
+ f'ffmpeg -y -i {no_watermark_video_path} -vf "movie=media/watermark.png,scale= 120: 36[watermask]; [in] [watermask] overlay=140:220 [out]" {output_video_path}'
+ )
+ os.remove(no_watermark_video_path)
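+
+# Example invocation (a sketch; all paths below are placeholders to replace
+# with your own audio, portrait, style-clip and pose files):
+#   python inference_for_demo_video.py \
+#       --wav_path path/to/audio.wav \
+#       --image_path path/to/portrait.png \
+#       --style_clip_path path/to/style_clip.mat \
+#       --pose_path path/to/pose.mat \
+#       --max_gen_len 30 \
+#       --cfg_scale 1.0 \
+#       --output_name demo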
diff --git a/damo/dreamtalk/media/teaser.gif b/damo/dreamtalk/media/teaser.gif
new file mode 100644
index 0000000000000000000000000000000000000000..a58b41e0e9f84b5848a0a51ab10273a39223b73a
--- /dev/null
+++ b/damo/dreamtalk/media/teaser.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:73f1dc7c8baae253b789e9bf7be33214e6b3e68107433802d2d0e62117e6cd29
+size 6599699
diff --git a/damo/dreamtalk/media/watermark.png b/damo/dreamtalk/media/watermark.png
new file mode 100644
index 0000000000000000000000000000000000000000..fd0970845bf2032aa78ff6d3fad38af7df8963d3
Binary files /dev/null and b/damo/dreamtalk/media/watermark.png differ
diff --git a/damo/dreamtalk/ms_wrapper.py b/damo/dreamtalk/ms_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..96ca56d9c753006c3507e832c8b744f7f017ac5c
--- /dev/null
+++ b/damo/dreamtalk/ms_wrapper.py
@@ -0,0 +1,250 @@
+import argparse
+import torch
+import json
+import os
+from typing import Union, Dict, Any
+import sys
+
+from scipy.io import loadmat
+import subprocess
+
+import numpy as np
+import torchaudio
+import shutil
+
+from core.utils import (
+ get_pose_params,
+ get_video_style_clip,
+ get_wav2vec_audio_window,
+ crop_src_image,
+)
+
+from configs.default import get_cfg_defaults
+from generators.utils import get_netG, render_video
+from core.networks.diffusion_net import DiffusionNet
+from core.networks.diffusion_util import NoisePredictor, VarianceSchedule
+from transformers import Wav2Vec2Processor
+from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Model
+
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.models.builder import MODELS
+from modelscope.utils.constant import Tasks
+from modelscope.pipelines.base import Pipeline
+from modelscope.models.base import Model, TorchModel
+from modelscope.utils.logger import get_logger
+from modelscope import snapshot_download
+
+
+@torch.no_grad()
+def get_diff_net(cfg, model_dir=None):
+ diff_net = DiffusionNet(
+ cfg=cfg,
+ net=NoisePredictor(cfg),
+ var_sched=VarianceSchedule(
+ num_steps=cfg.DIFFUSION.SCHEDULE.NUM_STEPS,
+ beta_1=cfg.DIFFUSION.SCHEDULE.BETA_1,
+ beta_T=cfg.DIFFUSION.SCHEDULE.BETA_T,
+ mode=cfg.DIFFUSION.SCHEDULE.MODE,
+ ),
+ )
+ checkpoint = torch.load(model_dir+"/"+cfg.INFERENCE.CHECKPOINT)
+ model_state_dict = checkpoint["model_state_dict"]
+ diff_net_dict = {
+ k[9:]: v for k, v in model_state_dict.items() if k[:9] == "diff_net."
+ }
+ diff_net.load_state_dict(diff_net_dict, strict=True)
+ diff_net.eval()
+
+ return diff_net
+
+@torch.no_grad()
+def get_audio_feat(wav_path, output_name, wav2vec_model):
+    # Unused stub; wav2vec features are extracted in DreamTalkMS.forward below.
+    pass
+
+@torch.no_grad()
+def inference_one_video(
+ cfg,
+ audio_path,
+ style_clip_path,
+ pose_path,
+ output_path,
+ diff_net,
+ max_audio_len=None,
+ sample_method="ddim",
+ ddim_num_step=10,
+):
+    audio_raw = np.load(audio_path)
+
+ if max_audio_len is not None:
+ audio_raw = audio_raw[: max_audio_len * 50]
+ gen_num_frames = len(audio_raw) // 2
+
+ audio_win_array = get_wav2vec_audio_window(
+ audio_raw,
+ start_idx=0,
+ num_frames=gen_num_frames,
+ win_size=cfg.WIN_SIZE,
+ )
+
+ audio_win = torch.tensor(audio_win_array).cuda()
+ audio = audio_win.unsqueeze(0)
+
+ # the second parameter is "" because of bad interface design...
+ style_clip_raw, style_pad_mask_raw = get_video_style_clip(
+ style_clip_path, "", style_max_len=256, start_idx=0
+ )
+
+ style_clip = style_clip_raw.unsqueeze(0).cuda()
+ style_pad_mask = (
+ style_pad_mask_raw.unsqueeze(0).cuda()
+ if style_pad_mask_raw is not None
+ else None
+ )
+
+ gen_exp_stack = diff_net.sample(
+ audio,
+ style_clip,
+ style_pad_mask,
+ output_dim=cfg.DATASET.FACE3D_DIM,
+ use_cf_guidance=cfg.CF_GUIDANCE.INFERENCE,
+ cfg_scale=cfg.CF_GUIDANCE.SCALE,
+ sample_method=sample_method,
+ ddim_num_step=ddim_num_step,
+ )
+ gen_exp = gen_exp_stack[0].cpu().numpy()
+
+    pose = get_pose_params(pose_path)
+ # (L, 9)
+
+ selected_pose = None
+ if len(pose) >= len(gen_exp):
+ selected_pose = pose[: len(gen_exp)]
+ else:
+ selected_pose = pose[-1].unsqueeze(0).repeat(len(gen_exp), 1)
+ selected_pose[: len(pose)] = pose
+
+ gen_exp_pose = np.concatenate((gen_exp, selected_pose), axis=1)
+ np.save(output_path, gen_exp_pose)
+ return output_path
+
+@PIPELINES.register_module(Tasks.text_to_video_synthesis, module_name='Dreamtalk-generation-pipe')
+class DreamTalkPipeline(Pipeline):
+ def __init__(
+ self,
+ model: Union[Model, str],
+ *args,
+ **kwargs):
+ model = DreamTalkMS(model, **kwargs) if isinstance(model, str) else model
+ super().__init__(model=model, **kwargs)
+
+ def preprocess(self, inputs, **preprocess_params) -> Dict[str, Any]:
+ return inputs
+
+ def _sanitize_parameters(self, **pipeline_parameters):
+        return {}, pipeline_parameters, {}
+
+ # define the forward pass
+ def forward(self, inputs: Dict, **forward_params) -> Dict[str, Any]:
+        return self.model(inputs, **forward_params)
+
+ # format the outputs from pipeline
+ def postprocess(self, input, **kwargs) -> Dict[str, Any]:
+ return input
+
+
+@MODELS.register_module(Tasks.text_to_video_synthesis, module_name='Dreamtalk-Generation')
+class DreamTalkMS(TorchModel):
+ def __init__(self, model_dir=None, *args, **kwargs):
+ super().__init__(model_dir, *args, **kwargs)
+ self.logger = get_logger()
+ self.style_clip_path = kwargs.get("style_clip_path", "")
+ self.pose_path = kwargs.get("pose_path", "")
+ os.chdir(model_dir)
+
+ if not os.path.exists(self.style_clip_path):
+ self.style_clip_path = os.path.join(model_dir, self.style_clip_path)
+
+ if not os.path.exists(self.pose_path):
+ self.pose_path = os.path.join(model_dir, self.pose_path)
+
+ self.cfg = get_cfg_defaults()
+ self.cfg.freeze()
+
+ # get wav2vec feat from audio
+ wav2vec_local_dir = snapshot_download("AI-ModelScope/wav2vec2-large-xlsr-53-english",revision='master')
+ self.wav2vec_processor = Wav2Vec2Processor.from_pretrained(wav2vec_local_dir)
+ self.wav2vec_model = (
+ Wav2Vec2Model.from_pretrained(wav2vec_local_dir)
+ .eval()
+ .cuda()
+ )
+ self.diff_net = get_diff_net(self.cfg, model_dir).cuda()
+ # get renderer
+ self.renderer = get_netG(model_dir+"/"+"checkpoints/renderer.pt")
+ self.model_dir = model_dir
+
+ def forward(self, input: Dict, *args, **kwargs) -> Dict[str, Any]:
+ output_name = input.get("output_name", "")
+ wav_path = input.get("wav_path", "")
+ img_crop = input.get("img_crop", True)
+ image_path = input.get("image_path", "")
+ max_gen_len = input.get("max_gen_len",1000)
+ sys.path.append(self.model_dir)
+
+ tmp_dir = f"tmp/{output_name}"
+ os.makedirs(tmp_dir, exist_ok=True)
+
+ # get audio in 16000Hz
+ wav_16k_path = os.path.join(tmp_dir, f"{output_name}_16K.wav")
+ if not os.path.exists(wav_path):
+ wav_path = os.path.join(self.model_dir, wav_path)
+ command = f"ffmpeg -y -i {wav_path} -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 {wav_16k_path}"
+ subprocess.run(command.split())
+
+ speech_array, sampling_rate = torchaudio.load(wav_16k_path)
+ audio_data = speech_array.squeeze().numpy()
+
+ inputs = self.wav2vec_processor(
+ audio_data, sampling_rate=16_000, return_tensors="pt", padding=True
+ )
+ with torch.no_grad():
+ audio_embedding = self.wav2vec_model(inputs.input_values.cuda(), return_dict=False)[0]
+
+ audio_feat_path = os.path.join(tmp_dir, f"{output_name}_wav2vec.npy")
+ np.save(audio_feat_path, audio_embedding[0].cpu().numpy())
+
+ # get src image
+ src_img_path = os.path.join(tmp_dir, "src_img.png")
+ if not os.path.exists(image_path):
+ image_path = os.path.join(self.model_dir, image_path)
+ if img_crop:
+ crop_src_image(image_path, src_img_path, 0.4)
+ else:
+ shutil.copy(image_path, src_img_path)
+
+        with torch.no_grad():
+            face_motion_path = os.path.join(tmp_dir, f"{output_name}_facemotion.npy")
+            inference_one_video(
+                self.cfg,
+                audio_feat_path,
+                self.style_clip_path,
+                self.pose_path,
+                face_motion_path,
+                self.diff_net,
+                max_audio_len=max_gen_len,
+            )
+            # render video
+            output_video_path = f"output_video/{output_name}.mp4"
+            render_video(
+                self.renderer,
+                src_img_path,
+                face_motion_path,
+                wav_16k_path,
+                output_video_path,
+                fps=25,
+                no_move=False,
+            )
+        return {"output_video_path": output_video_path}
\ No newline at end of file
diff --git a/damo/dreamtalk/output_video/.gitkeep b/damo/dreamtalk/output_video/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/damo/dreamtalk/output_video/acknowledgement_chinese@M030_front_surprised_level3_001@zp1.mp4 b/damo/dreamtalk/output_video/acknowledgement_chinese@M030_front_surprised_level3_001@zp1.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..3622b10a0a6cba0dcb14a30d876ad3db8a5cec20
Binary files /dev/null and b/damo/dreamtalk/output_video/acknowledgement_chinese@M030_front_surprised_level3_001@zp1.mp4 differ
diff --git a/damo/dreamtalk/output_video/acknowledgement_english@M030_front_neutral_level1_001@male_face.mp4 b/damo/dreamtalk/output_video/acknowledgement_english@M030_front_neutral_level1_001@male_face.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..1ab55c59056c05bb16ee265ded23e80a6153f0c9
Binary files /dev/null and b/damo/dreamtalk/output_video/acknowledgement_english@M030_front_neutral_level1_001@male_face.mp4 differ
diff --git a/damo/dreamtalk/tmp/.gitkeep b/damo/dreamtalk/tmp/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/damo/dreamtalk/tmp/acknowledgement_chinese@M030_front_surprised_level3_001@zp1/acknowledgement_chinese@M030_front_surprised_level3_001@zp1_16K.wav b/damo/dreamtalk/tmp/acknowledgement_chinese@M030_front_surprised_level3_001@zp1/acknowledgement_chinese@M030_front_surprised_level3_001@zp1_16K.wav
new file mode 100644
index 0000000000000000000000000000000000000000..ac75156b793e236da85e9dff780c6f239f4b21fb
Binary files /dev/null and b/damo/dreamtalk/tmp/acknowledgement_chinese@M030_front_surprised_level3_001@zp1/acknowledgement_chinese@M030_front_surprised_level3_001@zp1_16K.wav differ
diff --git a/damo/dreamtalk/tmp/acknowledgement_chinese@M030_front_surprised_level3_001@zp1/acknowledgement_chinese@M030_front_surprised_level3_001@zp1_facemotion.npy b/damo/dreamtalk/tmp/acknowledgement_chinese@M030_front_surprised_level3_001@zp1/acknowledgement_chinese@M030_front_surprised_level3_001@zp1_facemotion.npy
new file mode 100644
index 0000000000000000000000000000000000000000..abacb6015fe728a5af85ad929c693ec666035509
--- /dev/null
+++ b/damo/dreamtalk/tmp/acknowledgement_chinese@M030_front_surprised_level3_001@zp1/acknowledgement_chinese@M030_front_surprised_level3_001@zp1_facemotion.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:213f7819a3c04ca9612794c6703e1d304631e1c7bc40517ed56c8ac7a3ca57ab
+size 127148
diff --git a/damo/dreamtalk/tmp/acknowledgement_chinese@M030_front_surprised_level3_001@zp1/acknowledgement_chinese@M030_front_surprised_level3_001@zp1_wav2vec.npy b/damo/dreamtalk/tmp/acknowledgement_chinese@M030_front_surprised_level3_001@zp1/acknowledgement_chinese@M030_front_surprised_level3_001@zp1_wav2vec.npy
new file mode 100644
index 0000000000000000000000000000000000000000..9b7c40200b441cb62e1847958f548149f7ceddf7
--- /dev/null
+++ b/damo/dreamtalk/tmp/acknowledgement_chinese@M030_front_surprised_level3_001@zp1/acknowledgement_chinese@M030_front_surprised_level3_001@zp1_wav2vec.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:82cd887f905120ffbb6b5232a0ccd9c28474b9eea1afc95cd5ab88cc22d8fb29
+size 3563648
diff --git a/damo/dreamtalk/tmp/acknowledgement_chinese@M030_front_surprised_level3_001@zp1/src_img.png b/damo/dreamtalk/tmp/acknowledgement_chinese@M030_front_surprised_level3_001@zp1/src_img.png
new file mode 100644
index 0000000000000000000000000000000000000000..20997f7900ecc40d1c5833c491f4fbf7f3d06fda
Binary files /dev/null and b/damo/dreamtalk/tmp/acknowledgement_chinese@M030_front_surprised_level3_001@zp1/src_img.png differ
diff --git a/damo/dreamtalk/tmp/acknowledgement_english@M030_front_neutral_level1_001@male_face/acknowledgement_english@M030_front_neutral_level1_001@male_face_16K.wav b/damo/dreamtalk/tmp/acknowledgement_english@M030_front_neutral_level1_001@male_face/acknowledgement_english@M030_front_neutral_level1_001@male_face_16K.wav
new file mode 100644
index 0000000000000000000000000000000000000000..d599afcf69e8d79bc2b50b84bbb86baddbc0317f
Binary files /dev/null and b/damo/dreamtalk/tmp/acknowledgement_english@M030_front_neutral_level1_001@male_face/acknowledgement_english@M030_front_neutral_level1_001@male_face_16K.wav differ
diff --git a/damo/dreamtalk/tmp/acknowledgement_english@M030_front_neutral_level1_001@male_face/acknowledgement_english@M030_front_neutral_level1_001@male_face_facemotion.npy b/damo/dreamtalk/tmp/acknowledgement_english@M030_front_neutral_level1_001@male_face/acknowledgement_english@M030_front_neutral_level1_001@male_face_facemotion.npy
new file mode 100644
index 0000000000000000000000000000000000000000..f1b9b15990ea0d87903c0c8b05e2b975b2eb6dfe
--- /dev/null
+++ b/damo/dreamtalk/tmp/acknowledgement_english@M030_front_neutral_level1_001@male_face/acknowledgement_english@M030_front_neutral_level1_001@male_face_facemotion.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2f91c73ed2bd9bb16534f591d660288482010f505dc924d59c848497335d215a
+size 121016
diff --git a/damo/dreamtalk/tmp/acknowledgement_english@M030_front_neutral_level1_001@male_face/acknowledgement_english@M030_front_neutral_level1_001@male_face_wav2vec.npy b/damo/dreamtalk/tmp/acknowledgement_english@M030_front_neutral_level1_001@male_face/acknowledgement_english@M030_front_neutral_level1_001@male_face_wav2vec.npy
new file mode 100644
index 0000000000000000000000000000000000000000..9ea693de888ca66e39133631bcae8b82c8227ade
--- /dev/null
+++ b/damo/dreamtalk/tmp/acknowledgement_english@M030_front_neutral_level1_001@male_face/acknowledgement_english@M030_front_neutral_level1_001@male_face_wav2vec.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ec4208ce8353cf48b09515c0a367dafe76ab666800e5cac6e71050b4770b76f2
+size 3391616
diff --git a/damo/dreamtalk/tmp/acknowledgement_english@M030_front_neutral_level1_001@male_face/src_img.png b/damo/dreamtalk/tmp/acknowledgement_english@M030_front_neutral_level1_001@male_face/src_img.png
new file mode 100644
index 0000000000000000000000000000000000000000..58ca7bcd07d2a57b7f827bfdf47aa376c5b341c5
Binary files /dev/null and b/damo/dreamtalk/tmp/acknowledgement_english@M030_front_neutral_level1_001@male_face/src_img.png differ