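# NOTE: this file was captured from a Hugging Face Spaces file viewer
# (page status at capture time: "Spaces: Paused"). The header is page chrome,
# not part of the program.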
import csv
import os
import pickle
import pprint
import random
import sys
import time
import warnings

import librosa
import numpy as np
import smplx
import torch
import torch.nn as nn
import torch.nn.functional as F
from loguru import logger
from scipy.spatial.transform import Rotation
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter

import train
from dataloaders import data_tools
from optimizers.loss_factory import get_loss_func
from optimizers.optim_factory import create_optimizer
from optimizers.scheduler_factory import create_scheduler
from utils import config, logger_tools, other_tools, metric
from utils import rotation_conversions as rc
class CustomTrainer(train.BaseTrainer):
    """Adversarial trainer for a generator that outputs rot6d poses plus translation.

    The generator ``self.model`` is optimised with a geodesic rotation loss, an
    optional L1 translation loss (``args.train_trans``) and, from epoch
    ``args.no_adv_epoch`` on, an adversarial loss against ``self.d_model``.

    The models, optimizers, schedulers, data loaders, SMPL-X layer, evaluation
    helpers and ``*_recording`` methods used below are not defined in this
    class; presumably ``train.BaseTrainer`` creates them -- TODO confirm.
    """

    # Frame rate sequences are upsampled to before evaluation / export.
    _EVAL_FPS = 30
    # Joint count of the full skeleton (used after inverse joint selection and in test()).
    _FULL_JOINTS = 55

    def __init__(self, args):
        """Set up the meters and loss modules on top of the base trainer.

        Args:
            args: run configuration; forwarded to ``train.BaseTrainer``.
        """
        super().__init__(args)
        self.joints = self.train_data.joints
        meter_names = [
            "fid", "l1div", "bc", "rec", "trans", "vel", "transv",
            'dis', 'gen', 'acc', 'transa', 'div_reg', "kl",
        ]
        # Per-meter boolean flags for EpochTracker: True only for "l1div" and "bc".
        # NOTE(review): presumably "higher is better" -- confirm in other_tools.EpochTracker.
        meter_flags = [name in ("l1div", "bc") for name in meter_names]
        self.tracker = other_tools.EpochTracker(meter_names, meter_flags)
        if not self.args.rot6d:
            # Only reported; construction deliberately continues as in the original script.
            logger.error(f"this script is for rot6d, your pose rep. is {args.pose_rep}")
        self.rec_loss = get_loss_func("GeodesicLoss").to(self.rank)
        self.vel_loss = torch.nn.L1Loss(reduction='mean').to(self.rank)

    def _load_data(self, dict_data):
        """Move one batch to ``self.rank`` and build the model inputs.

        The axis-angle target pose is converted to rot6d, and a seed-motion
        tensor is built whose first ``args.pre_frames`` frames hold the target
        pose and translation plus a trailing indicator channel set to 1; all
        other entries are zero.

        Args:
            dict_data: batch dict with the keys "pose", "trans", "facial",
                "beta", "id", "word", "audio" and "emo".

        Returns:
            dict with the keys "tar_pose", "in_audio", "in_motion",
            "tar_trans", "tar_exps", "tar_beta", "tar_word", "tar_id", "in_emo".
        """
        device = self.rank
        tar_pose = dict_data["pose"].to(device)
        tar_trans = dict_data["trans"].to(device)
        tar_exps = dict_data["facial"].to(device)
        tar_beta = dict_data["beta"].to(device)
        tar_id = dict_data["id"].to(device).long()
        tar_word = dict_data["word"].to(device)
        in_audio = dict_data["audio"].to(device)
        in_emo = dict_data["emo"].to(device)

        bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints
        tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, j, 3))
        tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j * 6)

        # Channels: j*6 pose values, 3 translation values, 1 "frame is given" indicator.
        pre = self.args.pre_frames
        in_pre_pose = tar_pose.new_zeros((bs, n, j * 6 + 3 + 1))
        in_pre_pose[:, :pre, :-1] = torch.cat([tar_pose[:, :pre], tar_trans[:, :pre]], dim=2)
        in_pre_pose[:, :pre, -1] = 1

        return {
            "tar_pose": tar_pose,
            "in_audio": in_audio,
            "in_motion": in_pre_pose,
            "tar_trans": tar_trans,
            "tar_exps": tar_exps,
            "tar_beta": tar_beta,
            "tar_word": tar_word,
            'tar_id': tar_id,
            'in_emo': in_emo,
        }

    def _generator_forward(self, loaded_data):
        """Run ``self.model`` on a batch produced by ``_load_data``."""
        return self.model(
            in_audio=loaded_data['in_audio'],
            pre_seq=loaded_data["in_motion"],
            in_text=loaded_data["tar_word"],
            in_id=loaded_data["tar_id"],
            in_emo=loaded_data["in_emo"],
            in_facial=loaded_data["tar_exps"],
        )

    @staticmethod
    def _renormalize_rot6d(pose, bs, n, j):
        """Round-trip a flat rot6d tensor through rotation matrices; returns (bs, n, j*6)."""
        matrices = rc.rotation_6d_to_matrix(pose.reshape(bs, n, j, 6))
        return rc.matrix_to_rotation_6d(matrices).reshape(bs, n, j * 6)

    def _expand_to_full_rot6d(self, pose, bs, n, j):
        """Map a ``j``-joint rot6d tensor onto the full skeleton via the joint mask.

        The pose is converted to axis-angle, passed through
        ``self.inverse_selection_tensor`` with ``self.train_data.joint_mask``,
        and converted back to rot6d with ``_FULL_JOINTS`` joints.
        """
        full = self._FULL_JOINTS
        pose = rc.rotation_6d_to_matrix(pose.reshape(bs, n, j, 6))
        pose = rc.matrix_to_axis_angle(pose).reshape(bs, n, j * 3)
        pose = self.inverse_selection_tensor(pose, self.train_data.joint_mask, pose.shape[0])
        pose = rc.axis_angle_to_matrix(pose.reshape(bs, n, full, 3))
        return rc.matrix_to_rotation_6d(pose).reshape(bs, n, full * 6)

    def _upsample_to_eval_fps(self, seq):
        """Linearly upsample a (batch, frames, channels) tensor to ``_EVAL_FPS``.

        Returns ``seq`` unchanged when ``args.pose_fps`` already equals the
        evaluation rate.

        Raises:
            ValueError: if the evaluation rate is not an integer multiple of
                ``args.pose_fps``.
        """
        pose_fps = self.args.pose_fps
        scale = self._EVAL_FPS / pose_fps
        if scale == 1:
            return seq
        if self._EVAL_FPS % pose_fps != 0:
            # Explicit raise instead of assert so the check survives ``python -O``.
            raise ValueError(
                f"pose_fps={pose_fps} must evenly divide the evaluation rate {self._EVAL_FPS}"
            )
        # interpolate() works on the last axis, so time is moved there and back.
        resampled = torch.nn.functional.interpolate(
            seq.permute(0, 2, 1), scale_factor=scale, mode='linear'
        )
        return resampled.permute(0, 2, 1)

    def _d_training(self, loaded_data):
        """Compute the discriminator loss for one batch and log it as "dis".

        Args:
            loaded_data: batch dict returned by ``_load_data``.

        Returns:
            Scalar tensor ``-mean(log(D(real) + 1e-8) + log(1 - D(fake) + 1e-8))``.
        """
        tar_pose = loaded_data["tar_pose"]
        bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints
        # The generator is not updated in this step (train() zeroes its
        # gradients before the generator step), so run it without autograd
        # instead of back-propagating the discriminator loss through it.
        with torch.no_grad():
            net_out = self._generator_forward(loaded_data)
            rec_pose = self._renormalize_rot6d(net_out["rec_pose"][:, :, :j * 6], bs, n, j)
            tar_pose = self._renormalize_rot6d(tar_pose, bs, n, j)
        out_d_fake = self.d_model(rec_pose)
        out_d_real = self.d_model(tar_pose)
        d_loss_adv = -torch.mean(torch.log(out_d_real + 1e-8) + torch.log(1 - out_d_fake + 1e-8))
        self.tracker.update_meter("dis", "train", d_loss_adv.item())
        return d_loss_adv

    def _g_training(self, loaded_data, use_adv, mode="train"):
        """Run the generator on one batch; return a loss or the predictions.

        Args:
            loaded_data: batch dict returned by ``_load_data``.
            use_adv: add the adversarial generator loss (only applied when
                ``mode == 'train'``).
            mode: 'train', 'val', or anything else (treated as test).

        Returns:
            * 'train': scalar loss = adversarial + weighted geodesic + weighted
              translation terms (disabled terms contribute 0).
            * 'val': dict with "rec_pose", "rec_trans", "tar_pose".
            * otherwise: the 'val' dict plus "tar_exps", "tar_beta", "tar_trans".
            Outside training, poses are expanded to the full skeleton when
            ``args.pose_dims`` is smaller than ``_FULL_JOINTS * 6``.
        """
        bs, n, j = loaded_data["tar_pose"].shape[0], loaded_data["tar_pose"].shape[1], self.joints
        net_out = self._generator_forward(loaded_data)
        rec_pose = net_out["rec_pose"][:, :, :j * 6]
        rec_trans = net_out["rec_pose"][:, :, j * 6:j * 6 + 3]

        # The reconstruction loss is evaluated on rotation matrices.
        rec_pose = rc.rotation_6d_to_matrix(rec_pose.reshape(bs, n, j, 6))
        tar_pose = rc.rotation_6d_to_matrix(loaded_data["tar_pose"].reshape(bs, n, j, 6))
        rec_loss = self.rec_loss(tar_pose, rec_pose) * self.args.rec_weight
        self.tracker.update_meter("rec", mode, rec_loss.item())

        rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j * 6)
        tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j * 6)
        if self.args.pose_dims < self._FULL_JOINTS * 6 and mode != "train":
            rec_pose = self._expand_to_full_rot6d(rec_pose, bs, n, j)
            tar_pose = self._expand_to_full_rot6d(tar_pose, bs, n, j)

        if use_adv and mode == 'train':
            out_d_fake = self.d_model(rec_pose)
            d_loss_adv = -torch.mean(torch.log(out_d_fake + 1e-8))
            self.tracker.update_meter("gen", mode, d_loss_adv.item())
        else:
            d_loss_adv = 0

        if self.args.train_trans:
            trans_loss = self.vel_loss(rec_trans, loaded_data["tar_trans"]) * self.args.rec_weight
            self.tracker.update_meter("trans", mode, trans_loss.item())
        else:
            trans_loss = 0

        if mode == 'train':
            return d_loss_adv + rec_loss + trans_loss

        outputs = {
            'rec_pose': rec_pose,
            'rec_trans': rec_trans,
            'tar_pose': tar_pose,
        }
        if mode != 'val':
            outputs['tar_exps'] = loaded_data["tar_exps"]
            outputs['tar_beta'] = loaded_data["tar_beta"]
            outputs['tar_trans'] = loaded_data["tar_trans"]
        return outputs

    def train(self, epoch):
        """Train for one epoch.

        From epoch ``args.no_adv_epoch`` on, each iteration first updates the
        discriminator and then the generator; before that only the generator
        is updated. Both schedulers are stepped once at the end of the epoch.

        Args:
            epoch: current epoch index.
        """
        use_adv = bool(epoch >= self.args.no_adv_epoch)
        self.model.train()
        self.d_model.train()
        self.tracker.reset()
        t_start = time.time()
        for its, batch_data in enumerate(self.train_loader):
            loaded_data = self._load_data(batch_data)
            t_data = time.time() - t_start

            if use_adv:
                self.opt_d.zero_grad()
                d_loss_final = self._d_training(loaded_data)
                d_loss_final.backward()
                self.opt_d.step()

            self.opt.zero_grad()
            g_loss_final = self._g_training(loaded_data, use_adv, 'train')
            g_loss_final.backward()
            self.opt.step()

            # memory_reserved() replaces the deprecated memory_cached().
            mem_cost = torch.cuda.memory_reserved() / 1E9
            lr_g = self.opt.param_groups[0]['lr']
            lr_d = self.opt_d.param_groups[0]['lr']
            t_train = time.time() - t_start - t_data
            t_start = time.time()
            if its % self.args.log_period == 0:
                self.train_recording(epoch, its, t_data, t_train, mem_cost, lr_g, lr_d=lr_d)
            if self.args.debug and its == 1:
                break
        self.opt_s.step(epoch)
        self.opt_d_s.step(epoch)

    def val(self, epoch):
        """Compute the latent-space FID between generated and target motion.

        Sequences are upsampled to ``_EVAL_FPS``, trimmed to a multiple of
        ``args.vae_test_len`` and encoded with ``self.eval_copy.map2latent``.
        The FID is stored in the "fid"/"val" meter and reported through
        ``self.val_recording``.

        Args:
            epoch: current epoch index.
        """
        self.model.eval()
        self.d_model.eval()
        latent_out, latent_ori = [], []
        with torch.no_grad():
            # NOTE(review): this iterates self.train_loader, not a validation
            # loader -- kept as in the original; confirm it is intentional.
            for its, batch_data in enumerate(self.train_loader):
                loaded_data = self._load_data(batch_data)
                net_out = self._g_training(loaded_data, False, 'val')
                tar_pose = self._upsample_to_eval_fps(net_out['tar_pose'])
                rec_pose = self._upsample_to_eval_fps(net_out['rec_pose'])

                n = tar_pose.shape[1]
                usable = n - n % self.args.vae_test_len
                tar_pose = tar_pose[:, :usable, :]
                rec_pose = rec_pose[:, :usable, :]
                latent_out.append(
                    self.eval_copy.map2latent(rec_pose).reshape(-1, self.args.vae_length).cpu().numpy()
                )
                latent_ori.append(
                    self.eval_copy.map2latent(tar_pose).reshape(-1, self.args.vae_length).cpu().numpy()
                )
                if self.args.debug and its == 1:
                    break
        # Concatenate once instead of re-copying the growing arrays every iteration.
        latent_out_motion_all = np.concatenate(latent_out, axis=0)
        latent_ori_all = np.concatenate(latent_ori, axis=0)
        fid_motion = data_tools.FIDCalculator.frechet_distance(latent_out_motion_all, latent_ori_all)
        self.tracker.update_meter("fid", "val", fid_motion)
        self.val_recording(epoch)

    def test(self, epoch):
        """Generate motion for the test set, save it, and report fid / bc / l1div.

        Results go to ``<checkpoint_path>/<epoch>/`` as ``gt_<id>.npz`` and
        ``res_<id>.npz``. If that directory already exists the method returns 0
        without doing anything.

        The joint reshape below hard-codes a batch dimension of 1, so the test
        loader is assumed to yield one sequence per batch.

        Args:
            epoch: current epoch index (also the result sub-directory name).

        Returns:
            0 if the result directory already exists, otherwise ``None``.
        """
        results_save_path = self.checkpoint_path + f"/{epoch}/"
        if os.path.exists(results_save_path):
            return 0
        os.makedirs(results_save_path)
        start_time = time.time()
        total_length = 0
        test_seq_list = self.test_data.selected_file
        align = 0
        latent_out = []
        latent_ori = []
        self.model.eval()
        self.smplx.eval()
        self.eval_copy.eval()
        j = self._FULL_JOINTS
        with torch.no_grad():
            for its, batch_data in enumerate(self.test_loader):
                seq_id = test_seq_list.iloc[its]['id']
                loaded_data = self._load_data(batch_data)
                net_out = self._g_training(loaded_data, False, 'test')
                tar_pose = self._upsample_to_eval_fps(net_out['tar_pose'])
                rec_pose = self._upsample_to_eval_fps(net_out['rec_pose'])
                tar_beta = self._upsample_to_eval_fps(net_out['tar_beta'])
                tar_exps = self._upsample_to_eval_fps(net_out['tar_exps'])
                tar_trans = self._upsample_to_eval_fps(net_out['tar_trans'])
                rec_trans = self._upsample_to_eval_fps(net_out['rec_trans'])
                # n is the frame count at the evaluation rate from here on.
                bs, n = tar_pose.shape[0], tar_pose.shape[1]

                remain = n % self.args.vae_test_len
                latent_out.append(
                    self.eval_copy.map2latent(rec_pose[:, :n - remain])
                    .reshape(-1, self.args.vae_length).detach().cpu().numpy()
                )
                latent_ori.append(
                    self.eval_copy.map2latent(tar_pose[:, :n - remain])
                    .reshape(-1, self.args.vae_length).detach().cpu().numpy()
                )

                # rot6d -> axis-angle, one row per frame.
                rec_pose = rc.rotation_6d_to_matrix(rec_pose.reshape(bs * n, j, 6))
                rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs * n, j * 3)
                tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs * n, j, 6))
                tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs * n, j * 3)

                rec_trans_flat = rec_trans.reshape(bs * n, 3)
                tar_exps_flat = tar_exps.reshape(bs * n, 100)
                # Joints are computed with zero translation and zero expression.
                vertices_rec = self.smplx(
                    betas=tar_beta.reshape(bs * n, 300),
                    transl=torch.zeros_like(rec_trans_flat),
                    expression=torch.zeros_like(tar_exps_flat),
                    jaw_pose=rec_pose[:, 66:69],
                    global_orient=rec_pose[:, :3],
                    body_pose=rec_pose[:, 3:21 * 3 + 3],
                    left_hand_pose=rec_pose[:, 25 * 3:40 * 3],
                    right_hand_pose=rec_pose[:, 40 * 3:55 * 3],
                    return_joints=True,
                    leye_pose=rec_pose[:, 69:72],
                    reye_pose=rec_pose[:, 72:75],
                )
                joints_rec = (
                    vertices_rec["joints"].detach().cpu().numpy()
                    .reshape(1, n, 127 * 3)[0, :n, :55 * 3]
                )
                _ = self.l1_calculator.run(joints_rec)

                if self.alignmenter is not None:
                    in_audio_eval, sr = librosa.load(self.args.data_path + "wave16k/" + seq_id + ".wav")
                    in_audio_eval = librosa.resample(in_audio_eval, orig_sr=sr, target_sr=self.args.audio_sr)
                    # NOTE(review): the audio offsets use args.pose_fps while n
                    # and the pose side use 30 fps; they agree only when
                    # pose_fps == 30 -- kept as in the original, confirm intent.
                    a_offset = int(self.align_mask * (self.args.audio_sr / self.args.pose_fps))
                    onset_bt = self.alignmenter.load_audio(
                        in_audio_eval[:int(self.args.audio_sr / self.args.pose_fps * n)],
                        a_offset, len(in_audio_eval) - a_offset, True,
                    )
                    beat_vel = self.alignmenter.load_pose(
                        joints_rec, self.align_mask, n - self.align_mask, 30, True
                    )
                    align += (
                        self.alignmenter.calculate_align(onset_bt, beat_vel, 30)
                        * (n - 2 * self.align_mask)
                    )

                tar_pose_axis_np = tar_pose.detach().cpu().numpy()
                rec_pose_axis_np = rec_pose.detach().cpu().numpy()
                rec_trans_np = rec_trans.detach().cpu().numpy().reshape(bs * n, 3)
                tar_trans_np = tar_trans.detach().cpu().numpy().reshape(bs * n, 3)
                # Expressions are exported as zeros for both files.
                rec_exp_np = np.zeros_like(tar_exps.detach().cpu().numpy().reshape(bs * n, 100))
                tar_exp_np = np.zeros_like(rec_exp_np)
                if not self.args.train_trans:
                    tar_trans_np = np.zeros_like(tar_trans_np)
                    rec_trans_np = np.zeros_like(rec_trans_np)

                # Close the archive right after reading; np.load keeps the file
                # handle of an .npz open otherwise.
                # NOTE(review): allow_pickle=True can execute code from the
                # file -- only use with trusted dataset files.
                gt_path = self.args.data_path + self.args.pose_rep + "/" + seq_id + ".npz"
                with np.load(gt_path, allow_pickle=True) as gt_npz:
                    gt_betas = gt_npz["betas"]

                np.savez(
                    results_save_path + "gt_" + seq_id + '.npz',
                    betas=gt_betas,
                    poses=tar_pose_axis_np,
                    expressions=tar_exp_np,
                    trans=tar_trans_np,
                    model='smplx2020',
                    gender='neutral',
                    mocap_frame_rate=self._EVAL_FPS,
                )
                np.savez(
                    results_save_path + "res_" + seq_id + '.npz',
                    betas=gt_betas,
                    poses=rec_pose_axis_np,
                    expressions=rec_exp_np,
                    trans=rec_trans_np,
                    model='smplx2020',
                    gender='neutral',
                    mocap_frame_rate=self._EVAL_FPS,
                )
                total_length += n

        latent_out_all = np.concatenate(latent_out, axis=0)
        latent_ori_all = np.concatenate(latent_ori, axis=0)
        fid = data_tools.FIDCalculator.frechet_distance(latent_out_all, latent_ori_all)
        logger.info(f"fid score: {fid}")
        self.test_recording("fid", fid, epoch)

        align_avg = align / (total_length - 2 * len(self.test_loader) * self.align_mask)
        logger.info(f"align score: {align_avg}")
        self.test_recording("bc", align_avg, epoch)

        l1div = self.l1_calculator.avg()
        logger.info(f"l1div score: {l1div}")
        self.test_recording("l1div", l1div, epoch)

        end_time = time.time() - start_time
        # total_length counts frames after upsampling, i.e. at the evaluation rate.
        logger.info(f"total inference time: {int(end_time)} s for {int(total_length/self._EVAL_FPS)} s motion")