刘虹雨 committed on
Commit 8f481d2 · 1 Parent(s): 909e7c5
Files changed (6)
  1. .gitattributes +1 -0
  2. .gitignore +174 -0
  3. LICENSE +201 -0
  4. app.py +905 -0
  5. inference.py +476 -0
  6. requirements.txt +36 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
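The new rule routes PNG files through Git LFS. As an optional sanity check (not part of this commit), git's check-attr can confirm that the pattern resolves to the lfs filter; the file name below is a placeholder:

# Minimal sketch: verify that *.png now resolves to the LFS filter (hypothetical file name).
import subprocess
result = subprocess.run(
    ["git", "check-attr", "filter", "example.png"],  # any path matching *.png will do
    capture_output=True, text=True, check=True,
)
print(result.stdout.strip())  # expected output: "example.png: filter: lfs"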
.gitignore ADDED
@@ -0,0 +1,174 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ ckpts*
3
+ gradio_tmp*
4
+ output*
5
+ logs*
6
+ taming*
7
+ samples*
8
+ datasets*
9
+ asset*
10
+ temp_samples*
11
+ wandb*
12
+ output_dir*
13
+ temp_output*
14
+ temp_to_be_delete*
15
+ __pycache__/
16
+ error_log.txt
17
+ .deepspeed_env
18
+ *.py[cod]
19
+ *$py.class
20
+
21
+ # C extensions
22
+ *.so
23
+
24
+ # Distribution / packaging
25
+ .Python
26
+ build/
27
+ develop-eggs/
28
+ dist/
29
+ downloads/
30
+ eggs/
31
+ .eggs/
32
+ lib64/
33
+ parts/
34
+ sdist/
35
+ var/
36
+ wheels/
37
+ share/python-wheels/
38
+ *.egg-info/
39
+ .installed.cfg
40
+ *.egg
41
+ MANIFEST
42
+
43
+ # PyInstaller
44
+ # Usually these files are written by a python script from a template
45
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
46
+ *.manifest
47
+ *.spec
48
+
49
+ # Installer logs
50
+ pip-log.txt
51
+ pip-delete-this-directory.txt
52
+
53
+ # Unit test / coverage reports
54
+ htmlcov/
55
+ .tox/
56
+ .nox/
57
+ .coverage
58
+ .coverage.*
59
+ .cache
60
+ nosetests.xml
61
+ coverage.xml
62
+ *.cover
63
+ *.py,cover
64
+ .hypothesis/
65
+ .pytest_cache/
66
+ cover/
67
+
68
+ # Translations
69
+ *.mo
70
+ *.pot
71
+
72
+ # Django stuff:
73
+ *.log
74
+ local_settings.py
75
+ db.sqlite3
76
+ db.sqlite3-journal
77
+
78
+ # Flask stuff:
79
+ instance/
80
+ .webassets-cache
81
+
82
+ # Scrapy stuff:
83
+ .scrapy
84
+
85
+ # Sphinx documentation
86
+ docs/_build/
87
+
88
+ # PyBuilder
89
+ .pybuilder/
90
+ target/
91
+
92
+ # Jupyter Notebook
93
+ .ipynb_checkpoints
94
+
95
+ # IPython
96
+ profile_default/
97
+ ipython_config.py
98
+
99
+ # pyenv
100
+ # For a library or package, you might want to ignore these files since the code is
101
+ # intended to run in multiple environments; otherwise, check them in:
102
+ # .python-version
103
+
104
+ # pipenv
105
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
106
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
107
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
108
+ # install all needed dependencies.
109
+ #Pipfile.lock
110
+
111
+ # poetry
112
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
113
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
114
+ # commonly ignored for libraries.
115
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
116
+ #poetry.lock
117
+
118
+ # pdm
119
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
120
+ #pdm.lock
121
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
122
+ # in version control.
123
+ # https://pdm.fming.dev/#use-with-ide
124
+ .pdm.toml
125
+
126
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
127
+ __pypackages__/
128
+
129
+ # Celery stuff
130
+ celerybeat-schedule
131
+ celerybeat.pid
132
+
133
+ # SageMath parsed files
134
+ *.sage.py
135
+
136
+ # Environments
137
+ .env
138
+ .venv
139
+ env/
140
+ venv/
141
+ ENV/
142
+ env.bak/
143
+ venv.bak/
144
+
145
+ # Spyder project settings
146
+ .spyderproject
147
+ .spyproject
148
+
149
+ # Rope project settings
150
+ .ropeproject
151
+
152
+ # mkdocs documentation
153
+ /site
154
+
155
+ # mypy
156
+ .mypy_cache/
157
+ .dmypy.json
158
+ dmypy.json
159
+
160
+ # Pyre type checker
161
+ .pyre/
162
+
163
+ # pytype static type analyzer
164
+ .pytype/
165
+
166
+ # Cython debug symbols
167
+ cython_debug/
168
+
169
+ # PyCharm
170
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
171
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
172
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
173
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
174
+ .idea/
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
app.py ADDED
@@ -0,0 +1,905 @@
1
+ import os
2
+ import sys
3
+ import warnings
4
+ import logging
5
+ import argparse
6
+ import json
7
+ import random
8
+ from datetime import datetime
9
+
10
+ import torch
11
+ import numpy as np
12
+ import cv2
13
+ from PIL import Image
14
+ from tqdm import tqdm
15
+ from natsort import natsorted, ns
16
+ from einops import rearrange
17
+ from omegaconf import OmegaConf
18
+ from huggingface_hub import snapshot_download
19
+ import spaces
20
+ import gradio as gr
21
+ import base64
22
+ import imageio_ffmpeg as ffmpeg
23
+ import subprocess
24
+ from different_domain_imge_gen.landmark_generation import generate_annotation
25
+
26
+ from transformers import (
27
+ Dinov2Model, CLIPImageProcessor, CLIPVisionModelWithProjection, AutoImageProcessor
28
+ )
29
+ from Next3d.training_avatar_texture.camera_utils import LookAtPoseSampler, FOV_to_intrinsics
30
+
31
+ import recon.dnnlib as dnnlib
32
+ import recon.legacy as legacy
33
+
34
+ from DiT_VAE.diffusion.utils.misc import read_config
35
+ from DiT_VAE.vae.triplane_vae import AutoencoderKL as AutoencoderKLTriplane
36
+ from DiT_VAE.diffusion import IDDPM, DPMS
37
+ from DiT_VAE.diffusion.model.nets import TriDitCLIPDINO_XL_2
38
+ from DiT_VAE.diffusion.data.datasets import get_chunks
39
+
40
+ # Get the directory of the current script
41
+ father_path = os.path.dirname(os.path.abspath(__file__))
42
+
43
+ # Add necessary paths dynamically
44
+ sys.path.extend([
45
+ os.path.join(father_path, 'recon'),
46
+ os.path.join(father_path, 'Next3d'),
47
+ os.path.join(father_path, 'data_process'),
48
+ os.path.join(father_path, 'data_process/lib')
49
+
50
+ ])
51
+
52
+ from lib.FaceVerse.renderer import Faceverse_manager
53
+ from data_process.input_img_align_extract_ldm_demo import Process
54
+ from lib.config.config_demo import cfg
55
+ import shutil
56
+
57
+ # Suppress warnings (especially for PyTorch)
58
+ warnings.filterwarnings("ignore")
59
+
60
+ # Configure logging settings
61
+ logging.basicConfig(
62
+ level=logging.INFO,
63
+ format="%(asctime)s - %(levelname)s - %(message)s"
64
+ )
65
+ from diffusers import (
66
+ StableDiffusionControlNetImg2ImgPipeline,
67
+ ControlNetModel,
68
+ DPMSolverMultistepScheduler,
69
+ AutoencoderKL,
70
+ )
71
+
72
+
73
+ def get_args():
74
+ """Parse and return command-line arguments."""
75
+ parser = argparse.ArgumentParser(description="4D Triplane Generation Arguments")
76
+
77
+ # Configuration and model checkpoints
78
+ parser.add_argument("--config", type=str, default="./configs/infer_config.py",
79
+ help="Path to the configuration file.")
80
+
81
+ # Generation parameters
82
+ parser.add_argument("--bs", type=int, default=1,
83
+ help="Batch size for processing.")
84
+ parser.add_argument("--cfg_scale", type=float, default=4.5,
85
+ help="CFG scale parameter.")
86
+ parser.add_argument("--sampling_algo", type=str, default="dpm-solver",
87
+ choices=["iddpm", "dpm-solver"],
88
+ help="Sampling algorithm to be used.")
89
+ parser.add_argument("--seed", type=int, default=0,
90
+ help="Random seed for reproducibility.")
91
+ # parser.add_argument("--select_img", type=str, default=None,
92
+ # help="Optional: Select a specific image.")
93
+ parser.add_argument('--step', default=-1, type=int)
94
+ # parser.add_argument('--use_demo_cam', action='store_true', help="Enable predefined camera parameters")
95
+ return parser.parse_args()
96
+
97
+
98
+ def set_env(seed=0):
99
+ """Set random seed for reproducibility across multiple frameworks."""
100
+ torch.manual_seed(seed) # Set PyTorch seed
101
+ torch.cuda.manual_seed_all(seed) # If using multi-GPU
102
+ np.random.seed(seed) # Set NumPy seed
103
+ random.seed(seed) # Set Python built-in random module seed
104
+ torch.set_grad_enabled(False) # Disable gradients for inference
105
+
106
+
107
+ def to_rgb_image(image: Image.Image):
108
+ """Convert an image to RGB format if necessary."""
109
+ if image.mode == 'RGB':
110
+ return image
111
+ elif image.mode == 'RGBA':
112
+ img = Image.new("RGB", image.size, (127, 127, 127))
113
+ img.paste(image, mask=image.getchannel('A'))
114
+ return img
115
+ else:
116
+ raise ValueError(f"Unsupported image type: {image.mode}")
117
+
118
+
119
+ def image_process(image_path, clip_image_processor, dino_img_processor, device):
120
+ """Preprocess an image for CLIP and DINO models."""
121
+ image = to_rgb_image(Image.open(image_path))
122
+ clip_image = clip_image_processor(images=image, return_tensors="pt").pixel_values.to(device)
123
+ dino_image = dino_img_processor(images=image, return_tensors="pt").pixel_values.to(device)
124
+ return dino_image, clip_image
125
+
126
+
127
+ # def video_gen(frames_dir, output_path, fps=30):
128
+ # """Generate a video from image frames."""
129
+ # frame_files = natsorted(os.listdir(frames_dir), alg=ns.PATH)
130
+ # frames = [cv2.imread(os.path.join(frames_dir, f)) for f in frame_files]
131
+ # H, W = frames[0].shape[:2]
132
+ # video_writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'MP4V'), fps, (W, H))
133
+ # for frame in frames:
134
+ # video_writer.write(frame)
135
+ # video_writer.release()
136
+
137
+
138
+ def trans(tensor_img):
139
+ img = (tensor_img.permute(0, 2, 3, 1) * 0.5 + 0.5).clamp(0, 1) * 255.
140
+ img = img.to(torch.uint8)
141
+ img = img[0].detach().cpu().numpy()
142
+
143
+ return img
144
+
145
+
146
+ def get_vert(vert_dir):
147
+ uvcoords_image = np.load(os.path.join(vert_dir))[..., :3]
148
+ uvcoords_image[..., -1][uvcoords_image[..., -1] < 0.5] = 0
149
+ uvcoords_image[..., -1][uvcoords_image[..., -1] >= 0.5] = 1
150
+ return torch.tensor(uvcoords_image.copy()).float().unsqueeze(0)
151
+
152
+
153
+ def generate_samples(DiT_model, cfg_scale, sample_steps, clip_feature, dino_feature, uncond_clip_feature,
154
+ uncond_dino_feature, device, latent_size, sampling_algo):
155
+ """
156
+ Generate latent samples using the specified diffusion model.
157
+
158
+ Args:
159
+ DiT_model (torch.nn.Module): The diffusion model.
160
+ cfg_scale (float): The classifier-free guidance scale.
161
+ sample_steps (int): Number of sampling steps.
162
+ clip_feature (torch.Tensor): CLIP feature tensor.
163
+ dino_feature (torch.Tensor): DINO feature tensor.
164
+ uncond_clip_feature (torch.Tensor): Unconditional CLIP feature tensor.
165
+ uncond_dino_feature (torch.Tensor): Unconditional DINO feature tensor.
166
+ device (str): Device for computation.
167
+ latent_size (tuple): The latent space size.
168
+ sampling_algo (str): The sampling algorithm ('iddpm' or 'dpm-solver').
169
+
170
+ Returns:
171
+ torch.Tensor: The generated samples.
172
+ """
173
+ n = 1 # Batch size
174
+ z = torch.randn(n, 8, latent_size[0], latent_size[1], device=device)
175
+
176
+ if sampling_algo == 'iddpm':
177
+ z = z.repeat(2, 1, 1, 1) # Duplicate for classifier-free guidance
178
+ model_kwargs = dict(y=torch.cat([clip_feature, uncond_clip_feature]),
179
+ img_feature=torch.cat([dino_feature, dino_feature]),
180
+ cfg_scale=cfg_scale)
181
+ diffusion = IDDPM(str(sample_steps))
182
+ samples = diffusion.p_sample_loop(DiT_model.forward_with_cfg, z.shape, z, clip_denoised=False,
183
+ model_kwargs=model_kwargs, progress=True, device=device)
184
+ samples, _ = samples.chunk(2, dim=0) # Remove unconditional samples
185
+
186
+ elif sampling_algo == 'dpm-solver':
187
+ dpm_solver = DPMS(DiT_model.forward_with_dpmsolver,
188
+ condition=[clip_feature, dino_feature],
189
+ uncondition=[uncond_clip_feature, dino_feature],
190
+ cfg_scale=cfg_scale)
191
+ samples = dpm_solver.sample(z, steps=sample_steps, order=2, skip_type="time_uniform", method="multistep")
192
+ else:
193
+ raise ValueError(f"Invalid sampling_algo '{sampling_algo}'. Choose either 'iddpm' or 'dpm-solver'.")
194
+
195
+ return samples
196
+
197
+
198
+ def load_motion_aware_render_model(ckpt_path, device):
199
+ """Load the motion-aware render model from a checkpoint."""
200
+ logging.info("Loading motion-aware render model...")
201
+ with dnnlib.util.open_url(ckpt_path, 'rb') as f:
202
+ network = legacy.load_network_pkl(f) # type: ignore
203
+ logging.info("Motion-aware render model loaded.")
204
+ return network['G_ema'].to(device)
205
+
206
+
207
+ def load_diffusion_model(ckpt_path, latent_size, device):
208
+ """Load the diffusion model (DiT)."""
209
+ logging.info("Loading diffusion model (DiT)...")
210
+
211
+ DiT_model = TriDitCLIPDINO_XL_2(input_size=latent_size).to(device)
212
+ ckpt = torch.load(ckpt_path, map_location="cpu")
213
+
214
+ # Remove keys that can cause mismatches
215
+ for key in ['pos_embed', 'base_model.pos_embed', 'model.pos_embed']:
216
+ ckpt['state_dict'].pop(key, None)
217
+ ckpt.get('state_dict_ema', {}).pop(key, None)
218
+
219
+ state_dict = ckpt.get('state_dict_ema', ckpt)
220
+ DiT_model.load_state_dict(state_dict, strict=False)
221
+ DiT_model.eval()
222
+ logging.info("Diffusion model (DiT) loaded.")
223
+ return DiT_model
224
+
225
+
226
+ def load_vae_clip_dino(config, device):
227
+ """Load VAE, CLIP, and DINO models."""
228
+ logging.info("Loading VAE, CLIP, and DINO models...")
229
+
230
+ # Load CLIP image encoder
231
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
232
+ config.image_encoder_path)
233
+ image_encoder.requires_grad_(False)
234
+ image_encoder.to(device)
235
+
236
+ # Load VAE
237
+ config_vae = OmegaConf.load(config.vae_triplane_config_path)
238
+ vae_triplane = AutoencoderKLTriplane(ddconfig=config_vae['ddconfig'], lossconfig=None, embed_dim=8)
239
+ vae_triplane.to(device)
240
+
241
+ vae_ckpt_path = os.path.join(config.vae_pretrained, 'pytorch_model.bin')
242
+ if not os.path.isfile(vae_ckpt_path):
243
+ raise RuntimeError(f"VAE checkpoint not found at {vae_ckpt_path}")
244
+
245
+ vae_triplane.load_state_dict(torch.load(vae_ckpt_path, map_location="cpu"))
246
+ vae_triplane.requires_grad_(False)
247
+
248
+ # Load DINO model
249
+ dinov2 = Dinov2Model.from_pretrained(config.dino_pretrained)
250
+ dinov2.requires_grad_(False)
251
+ dinov2.to(device)
252
+
253
+ # Load image processors
254
+ dino_img_processor = AutoImageProcessor.from_pretrained(config.dino_pretrained)
255
+ clip_image_processor = CLIPImageProcessor()
256
+
257
+ logging.info("VAE, CLIP, and DINO models loaded.")
258
+ return vae_triplane, image_encoder, dinov2, dino_img_processor, clip_image_processor
259
+
260
+
261
+ def prepare_working_dir(dir, style):
262
+ print('prepare_working_dir: style =', style)
263
+ if style:
264
+ return dir
265
+ else:
266
+ import tempfile
267
+ working_dir = tempfile.TemporaryDirectory()
268
+ return working_dir.name
269
+
270
+
271
+ def launch_pretrained():
272
+ from huggingface_hub import hf_hub_download, snapshot_download
273
+ hf_hub_download(repo_id="KumaPower/AvatarArtist", repo_type='model', local_dir="./pretrained_model")
274
+
275
+
276
+ def prepare_image_list(img_dir, selected_img):
277
+ """Prepare the list of image paths for processing."""
278
+ if selected_img and selected_img in os.listdir(img_dir):
279
+ return [os.path.join(img_dir, selected_img)]
280
+
281
+ return sorted([os.path.join(img_dir, img) for img in os.listdir(img_dir)])
282
+
283
+
284
+ def images_to_video(image_folder, output_video, fps=30):
285
+ # Get all image files and ensure correct order
286
+ images = [img for img in os.listdir(image_folder) if img.endswith((".png", ".jpg", ".jpeg"))]
287
+ images = natsorted(images) # Sort filenames naturally to preserve frame order
288
+
289
+ if not images:
290
+ print("❌ No images found in the directory!")
291
+ return
292
+
293
+ # Get the path to the FFmpeg executable
294
+ ffmpeg_exe = ffmpeg.get_ffmpeg_exe()
295
+ print(f"Using FFmpeg from: {ffmpeg_exe}")
296
+
297
+ # Define input image pattern (expects images named like "%04d.png")
298
+ image_pattern = os.path.join(image_folder, "%04d.png")
299
+
300
+ # FFmpeg command to encode video
301
+ command = [
302
+ ffmpeg_exe, '-framerate', str(fps), '-i', image_pattern,
303
+ '-c:v', 'libx264', '-preset', 'slow', '-crf', '18', # High-quality H.264 encoding
304
+ '-pix_fmt', 'yuv420p', '-b:v', '5000k', # Ensure compatibility & increase bitrate
305
+ output_video
306
+ ]
307
+
308
+ # Run FFmpeg command
309
+ subprocess.run(command, check=True)
310
+
311
+ print(f"✅ High-quality MP4 video has been generated: {output_video}")
312
+
313
+
314
+ def model_define():
315
+ args = get_args()
316
+ set_env(args.seed)
317
+ input_process_model = Process(cfg)
318
+
319
+ device = "cuda" if torch.cuda.is_available() else "cpu"
320
+ weight_dtype = torch.float32
321
+ logging.info(f"Running inference with {weight_dtype}")
322
+
323
+ # Load configuration
324
+ default_config = read_config(args.config)
325
+
326
+ # Ensure valid sampling algorithm
327
+ assert args.sampling_algo in ['iddpm', 'dpm-solver', 'sa-solver']
328
+ # Load motion-aware render model
329
+ motion_aware_render_model = load_motion_aware_render_model(default_config.motion_aware_render_model_ckpt, device)
330
+
331
+ # Load diffusion model (DiT)
332
+ triplane_size = (256 * 4, 256)
333
+ latent_size = (triplane_size[0] // 8, triplane_size[1] // 8)
334
+ sample_steps = args.step if args.step != -1 else {'iddpm': 100, 'dpm-solver': 20, 'sa-solver': 25}[
335
+ args.sampling_algo]
336
+ DiT_model = load_diffusion_model(default_config.DiT_model_ckpt, latent_size, device)
337
+
338
+ # Load VAE, CLIP, and DINO
339
+ vae_triplane, image_encoder, dinov2, dino_img_processor, clip_image_processor = load_vae_clip_dino(default_config,
340
+ device)
341
+
342
+ # Load normalization parameters
343
+ triplane_std = torch.load(default_config.std_dir).to(device).reshape(1, -1, 1, 1, 1)
344
+ triplane_mean = torch.load(default_config.mean_dir).to(device).reshape(1, -1, 1, 1, 1)
345
+
346
+ # Load average latent vector
347
+ ws_avg = torch.load(default_config.ws_avg_pkl).to(device)[0]
348
+
349
+ # Set up FaceVerse for animation
350
+ base_coff = np.load(
351
+ 'pretrained_model/temp.npy').astype(
352
+ np.float32)
353
+ base_coff = torch.from_numpy(base_coff).float()
354
+ Faceverse = Faceverse_manager(device=device, base_coeff=base_coff)
355
+
356
+ return motion_aware_render_model, sample_steps, DiT_model, \
357
+ vae_triplane, image_encoder, dinov2, dino_img_processor, clip_image_processor, triplane_std, triplane_mean, ws_avg, Faceverse, device, input_process_model
358
+
359
+
360
+ def duplicate_batch(tensor, batch_size=2):
361
+ if tensor is None:
362
+ return None # if the tensor is None, return it as-is
363
+ return tensor.repeat(batch_size, *([1] * (tensor.dim() - 1))) # replicate along the batch dimension
364
+
365
+
366
+ @torch.inference_mode()
367
+ @spaces.GPU(duration=200)
368
+ def avatar_generation(items, save_path_base, video_path_input, source_type, is_styled, styled_img):
369
+ """
370
+ Generate avatars from input images.
371
+
372
+ Args:
373
+ items (list): List of image paths.
374
+ bs (int): Batch size.
375
+ sample_steps (int): Number of sampling steps.
376
+ cfg_scale (float): Classifier-free guidance scale.
377
+ save_path_base (str): Base directory for saving results.
378
+ DiT_model (torch.nn.Module): The diffusion model.
379
+ render_model (torch.nn.Module): The rendering model.
380
+ std (torch.Tensor): Standard deviation normalization tensor.
381
+ mean (torch.Tensor): Mean normalization tensor.
382
+ ws_avg (torch.Tensor): Latent average tensor.
383
+ """
384
+ if is_styled:
385
+ items = [styled_img]
386
+ else:
387
+ items = [items]
388
+ video_folder = "./demo_data/target_video"
389
+ video_name = os.path.basename(video_path_input).split(".")[0]
390
+ target_path = os.path.join(video_folder, 'data_' + video_name)
391
+ exp_base_dir = os.path.join(target_path, 'coeffs')
392
+ exp_img_base_dir = os.path.join(target_path, 'images512x512')
393
+ motion_base_dir = os.path.join(target_path, 'motions')
394
+ label_file_test = os.path.join(target_path, 'images512x512/dataset_realcam.json')
395
+
396
+ if source_type == 'example':
397
+ input_img_fvid = './demo_data/source_img/img_generate_different_domain/coeffs/trained_input_imgs'
398
+ input_img_motion = './demo_data/source_img/img_generate_different_domain/motions/trained_input_imgs'
399
+ elif source_type == 'custom':
400
+ input_img_fvid = os.path.join(save_path_base, 'processed_img/dataset/coeffs/input_image')
401
+ input_img_motion = os.path.join(save_path_base, 'processed_img/dataset/motions/input_image')
402
+ else:
403
+ raise ValueError("Wrong type")
404
+ bs = 1
405
+ sample_steps = 20
406
+ cfg_scale = 4.5
407
+ pitch_range = 0.25
408
+ yaw_range = 0.35
409
+ triplane_size = (256 * 4, 256)
410
+ latent_size = (triplane_size[0] // 8, triplane_size[1] // 8)
411
+ for chunk in tqdm(list(get_chunks(items, 1)), unit='batch'):
412
+ if bs != 1:
413
+ raise ValueError("Batch size > 1 not implemented")
414
+
415
+ image_dir = chunk[0]
416
+
417
+ image_name = os.path.splitext(os.path.basename(image_dir))[0]
418
+ dino_img, clip_image = image_process(image_dir, clip_image_processor, dino_img_processor, device)
419
+
420
+ clip_feature = image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
421
+ uncond_clip_feature = image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[
422
+ -2]
423
+ dino_feature = dinov2(dino_img).last_hidden_state
424
+ uncond_dino_feature = dinov2(torch.zeros_like(dino_img)).last_hidden_state
425
+
426
+ samples = generate_samples(DiT_model, cfg_scale, sample_steps, clip_feature, dino_feature,
427
+ uncond_clip_feature, uncond_dino_feature, device, latent_size,
428
+ 'dpm-solver')
429
+
430
+ samples = (samples / 0.3994218)
431
+ samples = rearrange(samples, "b c (f h) w -> b c f h w", f=4)
432
+ samples = vae_triplane.decode(samples)
433
+ samples = rearrange(samples, "b c f h w -> b f c h w")
434
+ samples = samples * std + mean
435
+ torch.cuda.empty_cache()
436
+
437
+ save_frames_path_out = os.path.join(save_path_base, image_name, 'out')
438
+ save_frames_path_outshow = os.path.join(save_path_base, image_name, 'out_show')
439
+ save_frames_path_depth = os.path.join(save_path_base, image_name, 'depth')
440
+
441
+ os.makedirs(save_frames_path_out, exist_ok=True)
442
+ os.makedirs(save_frames_path_outshow, exist_ok=True)
443
+ os.makedirs(save_frames_path_depth, exist_ok=True)
444
+
445
+ img_ref = np.array(Image.open(image_dir))
446
+ img_ref_out = img_ref.copy()
447
+ img_ref = torch.from_numpy(img_ref.astype(np.float32) / 127.5 - 1).permute(2, 0, 1).unsqueeze(0).to(device)
448
+
449
+ motion_app_dir = os.path.join(input_img_motion, image_name + '.npy')
450
+ motion_app = torch.tensor(np.load(motion_app_dir), dtype=torch.float32).unsqueeze(0).to(device)
451
+
452
+ id_motions = os.path.join(input_img_fvid, image_name + '.npy')
453
+
454
+ all_pose = json.loads(open(label_file_test).read())['labels']
455
+ all_pose = dict(all_pose)
456
+ if os.path.exists(id_motions):
457
+ coeff = np.load(id_motions).astype(np.float32)
458
+ coeff = torch.from_numpy(coeff).to(device).float().unsqueeze(0)
459
+ Faceverse.id_coeff = Faceverse.recon_model.split_coeffs(coeff)[0]
460
+ motion_dir = os.path.join(motion_base_dir, video_name)
461
+ exp_dir = os.path.join(exp_base_dir, video_name)
462
+ for frame_index, motion_name in enumerate(
463
+ tqdm(natsorted(os.listdir(motion_dir), alg=ns.PATH), desc="Processing Frames")):
464
+ exp_each_dir_img = os.path.join(exp_img_base_dir, video_name, motion_name.replace('.npy', '.png'))
465
+ exp_each_dir = os.path.join(exp_dir, motion_name)
466
+ motion_each_dir = os.path.join(motion_dir, motion_name)
467
+
468
+ # Load pose data
469
+ pose_key = os.path.join(video_name, motion_name.replace('.npy', '.png'))
470
+
471
+ cam2world_pose = LookAtPoseSampler.sample(
472
+ 3.14 / 2 + yaw_range * np.sin(2 * 3.14 * frame_index / len(os.listdir(motion_dir))),
473
+ 3.14 / 2 - 0.05 + pitch_range * np.cos(2 * 3.14 * frame_index / len(os.listdir(motion_dir))),
474
+ torch.tensor([0, 0, 0], device=device), radius=2.7, device=device)
475
+ pose_show = torch.cat([cam2world_pose.reshape(-1, 16),
476
+ FOV_to_intrinsics(fov_degrees=18.837, device=device).reshape(-1, 9)], 1).to(device)
477
+
478
+ pose = torch.tensor(np.array(all_pose[pose_key]).astype(np.float32)).float().unsqueeze(0).to(device)
479
+
480
+ # Load and resize expression image
481
+ exp_img = np.array(Image.open(exp_each_dir_img).resize((512, 512)))
482
+
483
+ # Load expression coefficients
484
+ exp_coeff = torch.from_numpy(np.load(exp_each_dir).astype(np.float32)).to(device).float().unsqueeze(0)
485
+ exp_target = Faceverse.make_driven_rendering(exp_coeff, res=256)
486
+
487
+ # Load motion data
488
+ motion = torch.tensor(np.load(motion_each_dir)).float().unsqueeze(0).to(device)
489
+
490
+ img_ref_double = duplicate_batch(img_ref, batch_size=2)
491
+ motion_app_double = duplicate_batch(motion_app, batch_size=2)
492
+ motion_double = duplicate_batch(motion, batch_size=2)
493
+ pose_double = torch.cat([pose_show, pose], dim=0)
494
+ exp_target_double = duplicate_batch(exp_target, batch_size=2)
495
+ samples_double = duplicate_batch(samples, batch_size=2)
496
+ # Select refine_net processing method
497
+ final_out = render_model(
498
+ img_ref_double, None, motion_app_double, motion_double, c=pose_double, mesh=exp_target_double,
499
+ triplane_recon=samples_double,
500
+ ws_avg=ws_avg, motion_scale=1.
501
+ )
502
+
503
+ # Process output image
504
+ final_out_show = trans(final_out['image_sr'][0].unsqueeze(0))
505
+ final_out_notshow = trans(final_out['image_sr'][1].unsqueeze(0))
506
+ depth = final_out['image_depth'][0].unsqueeze(0)
507
+ depth = -depth
508
+ depth = (depth - depth.min()) / (depth.max() - depth.min()) * 2 - 1
509
+ depth = trans(depth)
510
+
511
+ depth = np.repeat(depth[:, :, :], 3, axis=2)
512
+ # Save output images
513
+ frame_name = f'{str(frame_index).zfill(4)}.png'
514
+ Image.fromarray(depth, 'RGB').save(os.path.join(save_frames_path_depth, frame_name))
515
+ Image.fromarray(final_out_notshow, 'RGB').save(os.path.join(save_frames_path_out, frame_name))
516
+
517
+ Image.fromarray(final_out_show, 'RGB').save(os.path.join(save_frames_path_outshow, frame_name))
518
+
519
+ # Generate videos
520
+ images_to_video(save_frames_path_out, os.path.join(save_path_base, image_name + '_out.mp4'))
521
+ images_to_video(save_frames_path_outshow, os.path.join(save_path_base, image_name + '_outshow.mp4'))
522
+ images_to_video(save_frames_path_depth, os.path.join(save_path_base, image_name + '_depth.mp4'))
523
+
524
+ logging.info(f"✅ Video generation completed successfully!")
525
+ return os.path.join(save_path_base, image_name + '_out.mp4'), os.path.join(save_path_base,
526
+ image_name + '_outshow.mp4'), os.path.join(save_path_base, image_name + '_depth.mp4')
527
+
528
+
529
+ def get_image_base64(path):
530
+ with open(path, "rb") as image_file:
531
+ encoded_string = base64.b64encode(image_file.read()).decode()
532
+ return f"data:image/png;base64,{encoded_string}"
533
+
534
+
535
+ def assert_input_image(input_image):
536
+ if input_image is None:
537
+ raise gr.Error("No image selected or uploaded!")
538
+
539
+
540
+ def process_image(input_image, source_type, is_style, save_dir):
541
+ """ 🎯 Process input_image, applying different logic depending on whether it is an example image or a user upload. """
542
+ process_img_input_dir = os.path.join(save_dir, 'input_image')
543
+ process_img_save_dir = os.path.join(save_dir, 'processed_img')
544
+ os.makedirs(process_img_save_dir, exist_ok=True)
545
+ os.makedirs(process_img_input_dir, exist_ok=True)
546
+ if source_type == "example":
547
+ return input_image, source_type
548
+ else:
549
+ # input_process_model.inference(input_image, process_img_save_dir)
550
+ shutil.copy(input_image, process_img_input_dir)
551
+ input_process_model.inference(process_img_input_dir, process_img_save_dir, is_img=True, is_video=False)
552
+ img_name = os.path.basename(input_image)
553
+ imge_dir = os.path.join(save_dir, 'processed_img/dataset/images512x512/input_image', img_name)
554
+ return imge_dir, source_type # this branch handles a user-uploaded image
555
+
556
+
557
+ def style_transfer(processed_image, style_prompt, cfg, strength, save_base):
558
+ """
559
+ 🎭 This function performs style transfer.
560
+ ✅ You can plug your own stylization code in here.
561
+ """
562
+ src_img_pil = Image.open(processed_image)
563
+ img_name = os.path.basename(processed_image)
564
+ save_dir = os.path.join(save_base, 'style_img')
565
+ os.makedirs(save_dir, exist_ok=True)
566
+ control_image = generate_annotation(src_img_pil, max_faces=1)
567
+ trg_img_pil = pipeline_sd(
568
+ prompt=style_prompt,
569
+ image=src_img_pil,
570
+ strength=strength,
571
+ control_image=Image.fromarray(control_image),
572
+ guidance_scale=cfg,
573
+ negative_prompt='worst quality, normal quality, low quality, low res, blurry',
574
+ num_inference_steps=30,
575
+ controlnet_conditioning_scale=1.5
576
+ )['images'][0]
577
+ trg_img_pil.save(os.path.join(save_dir, img_name))
578
+ return os.path.join(save_dir, img_name) # 🚨 replace this with your own style-transfer logic if needed
579
+
580
+
581
+ def reset_flag():
582
+ return False
583
+ css = """
584
+ /* ✅ Center all Image components and let them scale to the available width */
585
+ .gr-image img {
586
+ display: block;
587
+ margin-left: auto;
588
+ margin-right: auto;
589
+ max-width: 100%;
590
+ height: auto;
591
+ }
592
+
593
+ /* ✅ Center all Video components and let them scale to the available width */
594
+ .gr-video video {
595
+ display: block;
596
+ margin-left: auto;
597
+ margin-right: auto;
598
+ max-width: 100%;
599
+ height: auto;
600
+ }
601
+
602
+ /* ✅ Optional: center the button and markdown */
603
+ #generate_block {
604
+ display: flex;
605
+ flex-direction: column;
606
+ align-items: center;
607
+ justify-content: center;
608
+ margin-top: 1rem;
609
+ }
610
+
611
+
612
+ /* Optional: make the whole container a bit wider */
613
+ #main_container {
614
+ max-width: 1280px; /* ✅ e.g. cap the width at 1280px */
615
+ margin-left: auto; /* ✅ center horizontally */
616
+ margin-right: auto;
617
+ padding-left: 1rem;
618
+ padding-right: 1rem;
619
+ }
620
+
621
+ """
622
+
623
+ def launch_gradio_app():
624
+ styles = {
625
+ "Ghibli": "Ghibli style avatar, anime style",
626
+ "Pixar": "a 3D render of a face in Pixar style",
627
+ "Lego": "a 3D render of a head of a lego man 3D model",
628
+ "Greek Statue": "a FHD photo of a white Greek statue",
629
+ "Elf": "a FHD photo of a face of a beautiful elf with silver hair in live action movie",
630
+ "Zombie": "a FHD photo of a face of a zombie",
631
+ "Tekken": "a 3D render of a Tekken game character",
632
+ "Devil": "a FHD photo of a face of a devil in fantasy movie",
633
+ "Steampunk": "Steampunk style portrait, mechanical, brass and copper tones",
634
+ "Mario": "a 3D render of a face of Super Mario",
635
+ "Orc": "a FHD photo of a face of an orc in fantasy movie",
636
+ "Masque": "a FHD photo of a face of a person in masquerade",
637
+ "Skeleton": "a FHD photo of a face of a skeleton in fantasy movie",
638
+ "Peking Opera": "a FHD photo of face of character in Peking opera with heavy make-up",
639
+ "Yoda": "a FHD photo of a face of Yoda in Star Wars",
640
+ "Hobbit": "a FHD photo of a face of Hobbit in Lord of the Rings",
641
+ "Stained Glass": "Stained glass style, portrait, beautiful, translucent",
642
+ "Graffiti": "Graffiti style portrait, street art, vibrant, urban, detailed, tag",
643
+ "Pixel-art": "pixel art style portrait, low res, blocky, pixel art style",
644
+ "Retro": "Retro game art style portrait, vibrant colors",
645
+ "Ink": "a portrait in ink style, black and white image",
646
+ }
647
+
648
+ with gr.Blocks(analytics_enabled=False, delete_cache=[3600, 3600], css=css, elem_id="main_container") as demo:
649
+ logo_url = "./docs/AvatarArtist.png"
650
+ logo_base64 = get_image_base64(logo_url)
651
+ # 🚀 Center the logo and align the title
652
+ gr.HTML(
653
+ f"""
654
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center; margin-bottom: 20px;">
655
+ <img src="{logo_base64}" style="height:50px; margin-right: 15px; display: block;" onerror="this.style.display='none'"/>
656
+ <h1 style="font-size: 32px; font-weight: bold;">AvatarArtist: Open-Domain 4D Avatarization</h1>
657
+ </div>
658
+ """
659
+ )
660
+
661
+ # 🚀 Keep the badge links aligned in one row
662
+ gr.HTML(
663
+ """
664
+ <div style="display: flex; justify-content: center; gap: 10px; margin-top: 10px;">
665
+ <a title="Website" href="https://kumapowerliu.github.io/AvatarArtist/" target="_blank" rel="noopener noreferrer">
666
+ <img src="https://img.shields.io/badge/Website-Visit-blue?style=for-the-badge&logo=GoogleChrome">
667
+ </a>
668
+ <a title="arXiv" href="https://arxiv.org/abs/2503.19906" target="_blank" rel="noopener noreferrer">
669
+ <img src="https://img.shields.io/badge/arXiv-Paper-red?style=for-the-badge&logo=arXiv">
670
+ </a>
671
+ <a title="Github" href="https://github.com/ant-research/AvatarArtist" target="_blank" rel="noopener noreferrer">
672
+ <img src="https://img.shields.io/github/stars/ant-research/AvatarArtist?style=for-the-badge&logo=github&logoColor=white&color=orange">
673
+ </a>
674
+ </div>
675
+ """
676
+ )
677
+ gr.HTML(
678
+ """
679
+ <div style="text-align: left; font-size: 16px; line-height: 1.6; margin-top: 20px; padding: 10px; border: 1px solid #ddd; border-radius: 8px; background-color: #f9f9f9;">
680
+ <strong>🧑‍🎨 How to use this demo:</strong>
681
+ <ol style="margin-top: 10px; padding-left: 20px;">
682
+ <li><strong>Select or upload a source image</strong> – this will be the avatar's face.</li>
683
+ <li><strong>Select or upload a target video</strong> – the avatar will mimic this motion.</li>
684
+ <li><strong>Click the <em>Process Image</em> button</strong> – this prepares the source image to meet our model's input requirements.</li>
685
+ <li><strong>(Optional)</strong> Click <em>Apply Style</em> to change the appearance of the processed image – we offer a variety of fun styles to choose from!</li>
686
+ <li><strong>Click <em>Generate Avatar</em></strong> to create the final animated result driven by the target video.</li>
687
+ </ol>
688
+ <p style="margin-top: 10px;"><strong>🎨 Tip:</strong> Try different styles to get various artistic effects for your avatar!</p>
689
+ </div>
690
+ """
691
+ )
692
+ # 🚀 Add the important-notes box
693
+ gr.HTML(
694
+ """
695
+ <div style="background-color: #FFDDDD; padding: 15px; border-radius: 10px; border: 2px solid red; text-align: center; margin-top: 20px;">
696
+ <h4 style="color: red; font-size: 18px;">
697
+ 🚨 <strong>Important Notes:</strong> Please try to provide a <u>front-facing</u> or <u>full-face</u> image without obstructions.
698
+ </h4>
699
+ <p style="color: black; font-size: 16px;">
700
+ ❌ Our demo does <strong>not</strong> support uploading videos with specific motions because processing requires time.<br>
701
+ ✅ Feel free to check out our <a href="https://github.com/ant-research/AvatarArtist" target="_blank" style="color: red; font-weight: bold;">GitHub repository</a> to drive portraits using your desired motions.
702
+ </p>
703
+ </div>
704
+ """
705
+ )
706
+ # DISPLAY
707
+ image_folder = "./demo_data/source_img/img_generate_different_domain/images512x512/trained_input_imgs"
708
+ video_folder = "./demo_data/target_video"
709
+
710
+ examples_images = sorted(
711
+ [os.path.join(image_folder, f) for f in os.listdir(image_folder) if
712
+ f.lower().endswith(('.png', '.jpg', '.jpeg'))]
713
+ )
714
+ examples_videos = sorted(
715
+ [os.path.join(video_folder, f) for f in os.listdir(video_folder) if f.lower().endswith('.mp4')]
716
+ )
717
+ print(examples_videos)
718
+ source_type = gr.State("example")
719
+ is_from_example = gr.State(value=True)
720
+ is_styled = gr.State(value=False)
721
+ working_dir = gr.State()
722
+
723
+ with gr.Row():
724
+ with gr.Column(variant='panel'):
725
+ with gr.Tabs(elem_id="input_image"):
726
+ with gr.TabItem('🎨 Upload Image'):
727
+ input_image = gr.Image(
728
+ label="Upload Source Image",
729
+ value=os.path.join(image_folder, '02057_(2).png'),
730
+ image_mode="RGB", height=512, container=True,
731
+ sources="upload", type="filepath"
732
+ )
733
+
734
+ def mark_as_example(example_image):
735
+ print("✅ mark_as_example called")
736
+ return "example", True, False
737
+
738
+ def mark_as_custom(user_image, is_from_example_flag):
739
+ print("✅ mark_as_custom called")
740
+ if is_from_example_flag:
741
+ print("⚠️ Ignored mark_as_custom triggered by example")
742
+ return "example", False, False
743
+ return "custom", False, False
744
+
745
+ input_image.change(
746
+ mark_as_custom,
747
+ inputs=[input_image, is_from_example],
748
+ outputs=[source_type, is_from_example, is_styled] # ✅ only update source_type; do not write back to input_image
749
+ )
750
+
751
+ # ✅ Give the `Examples` component its own row and bind its click event
752
+ with gr.Row():
753
+ example_component = gr.Examples(
754
+ examples=examples_images,
755
+ inputs=[input_image],
756
+ examples_per_page=10,
757
+ )
758
+ # ✅ Listen for the `click` event of the `Examples` dataset
759
+ example_component.dataset.click(
760
+ fn=mark_as_example,
761
+ inputs=[input_image],
762
+ outputs=[source_type, is_from_example, is_styled]
763
+ )
764
+
765
+ with gr.Column(variant='panel' ):
766
+ with gr.Tabs(elem_id="input_video"):
767
+ with gr.TabItem('🎬 Target Video'):
768
+ video_input = gr.Video(
769
+ label="Select Target Motion",
770
+ height=512, container=True,interactive=False, format="mp4",
771
+ value=examples_videos[0]
772
+ )
773
+
774
+ with gr.Row():
775
+ gr.Examples(
776
+ examples=examples_videos,
777
+ inputs=[video_input],
778
+ examples_per_page=10,
779
+ )
780
+ with gr.Column(variant='panel' ):
781
+ with gr.Tabs(elem_id="processed_image"):
782
+ with gr.TabItem('🖼️ Processed Image'):
783
+ processed_image = gr.Image(
784
+ label="Processed Image",
785
+ image_mode="RGB", type="filepath",
786
+ elem_id="processed_image",
787
+ height=512, container=True,
788
+ interactive=False
789
+ )
790
+ processed_image_button = gr.Button("🔧 Process Image", variant="primary")
791
+ with gr.Column(variant='panel' ):
792
+ with gr.Tabs(elem_id="style_transfer"):
793
+ with gr.TabItem('🎭 Style Transfer'):
794
+ style_image = gr.Image(
795
+ label="Style Image",
796
+ image_mode="RGB", type="filepath",
797
+ elem_id="style_image",
798
+ height=512, container=True,
799
+ interactive=False
800
+ )
801
+ style_choice = gr.Dropdown(
802
+ choices=list(styles.keys()),
803
+ label="Choose Style",
804
+ value="Pixar"
805
+ )
806
+ cfg_slider = gr.Slider(
807
+ minimum=3.0, maximum=10.0, value=7.5, step=0.1,
808
+ label="CFG Scale"
809
+ )
810
+ strength_slider = gr.Slider(
811
+ minimum=0.4, maximum=0.85, value=0.65, step=0.05,
812
+ label="SDEdit Strength"
813
+ )
814
+ style_button = gr.Button("🎨 Apply Style", interactive=False)
815
+ gr.Markdown(
816
+ "⬅️ Please click **Process Image** first. "
817
+ "**Apply Style** will transform the image in the **Processed Image** panel "
818
+ "according to the selected style."
819
+ )
820
+
821
+
822
+ with gr.Row():
823
+ with gr.Tabs(elem_id="render_output"):
824
+ with gr.TabItem('🎥 Animation Results'):
825
+ # ✅ Put the `Generate Avatar` button on its own row
826
+ with gr.Row():
827
+ with gr.Column(scale=1, elem_id="generate_block", min_width=200):
828
+ submit = gr.Button('🚀 Generate Avatar', elem_id="avatarartist_generate", variant='primary',
829
+ interactive=False)
830
+ gr.Markdown("⬇️ Please click **Process Image** first before generating.",
831
+ elem_id="generate_tip")
832
+
833
+ # ✅ Show the `Animation Results` videos side by side
834
+ with gr.Row():
835
+ output_video = gr.Video(
836
+ label="Generated Animation Input Video View",
837
+ format="mp4", height=512, width=512,
838
+ autoplay=True
839
+ )
840
+
841
+ output_video_2 = gr.Video(
842
+ label="Generated Animation Rotate View",
843
+ format="mp4", height=512, width=512,
844
+ autoplay=True
845
+ )
846
+
847
+ output_video_3 = gr.Video(
848
+ label="Generated Animation Rotate View Depth",
849
+ format="mp4", height=512, width=512,
850
+ autoplay=True
851
+ )
852
+ def apply_style_and_mark(processed_image, style_choice, cfg, strength, working_dir):
853
+ styled = style_transfer(processed_image, styles[style_choice], cfg, strength, working_dir)
854
+ return styled, True
855
+
856
+ def process_image_and_enable_style(input_image, source_type, is_styled, wd):
857
+ processed_result, updated_source_type = process_image(input_image, source_type, is_styled, wd)
858
+ return processed_result, updated_source_type, gr.update(interactive=True), gr.update(interactive=True)
859
+ processed_image_button.click(
860
+ fn=prepare_working_dir,
861
+ inputs=[working_dir, is_styled],
862
+ outputs=[working_dir],
863
+ queue=False,
864
+ ).success(
865
+ fn=process_image_and_enable_style,
866
+ inputs=[input_image, source_type, is_styled, working_dir],
867
+ outputs=[processed_image, source_type, style_button, submit],
868
+ queue=True
869
+ )
870
+ style_button.click(
871
+ fn=apply_style_and_mark,
872
+ inputs=[processed_image, style_choice, cfg_slider, strength_slider, working_dir],
873
+ outputs=[style_image, is_styled]
874
+ )
875
+ submit.click(
876
+ fn=avatar_generation,
877
+ inputs=[processed_image, working_dir, video_input, source_type, is_styled, style_image],
878
+ outputs=[output_video, output_video_2, output_video_3], # ⏳ the videos are displayed later
879
+ queue=True
880
+ )
881
+
882
+
883
+ demo.queue()
884
+ demo.launch(server_name="0.0.0.0")
885
+
886
+
887
+ if __name__ == '__main__':
888
+ import torch.multiprocessing as mp
889
+
890
+ mp.set_start_method('spawn', force=True)
891
+ image_folder = "./demo_data/source_img/img_generate_different_domain/images512x512/trained_input_imgs"
892
+ example_img_names = os.listdir(image_folder)
893
+ render_model, sample_steps, DiT_model, \
894
+ vae_triplane, image_encoder, dinov2, dino_img_processor, clip_image_processor, std, mean, ws_avg, Faceverse, device, input_process_model = model_define()
895
+ controlnet_path = '/nas8/liuhongyu/model/ControlNetMediaPipeFaceold'
896
+ controlnet = ControlNetModel.from_pretrained(
897
+ controlnet_path, torch_dtype=torch.float16
898
+ )
899
+ sd_path = '/nas8/liuhongyu/model/stable-diffusion-2-1-base'
900
+ pipeline_sd = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
901
+ sd_path, torch_dtype=torch.float16,
902
+ use_safetensors=True, controlnet=controlnet, variant="fp16"
903
+ ).to(device)
904
+ demo_cam = False
905
+ launch_gradio_app()
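For reference, the frame-to-video step that app.py's images_to_video performs can be exercised on its own. The sketch below mirrors the ffmpeg invocation used above; the frames directory and output path are placeholders, and frames are assumed to be zero-padded PNGs (0000.png, 0001.png, ...):

# Minimal sketch of the encoding step used by images_to_video() above (paths are hypothetical).
import os
import subprocess
import imageio_ffmpeg as ffmpeg

frames_dir = "./output/frames"      # placeholder folder containing 0000.png, 0001.png, ...
output_mp4 = "./output/result.mp4"  # placeholder output path

ffmpeg_exe = ffmpeg.get_ffmpeg_exe()            # path to the bundled ffmpeg binary
pattern = os.path.join(frames_dir, "%04d.png")  # same zero-padded naming scheme app.py writes
subprocess.run(
    [ffmpeg_exe, "-framerate", "30", "-i", pattern,
     "-c:v", "libx264", "-preset", "slow", "-crf", "18",
     "-pix_fmt", "yuv420p", output_mp4],
    check=True,
)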
inference.py ADDED
@@ -0,0 +1,476 @@
1
+ import os
2
+ import sys
3
+ import warnings
4
+ import logging
5
+ import argparse
6
+ import json
7
+ import random
8
+ from datetime import datetime
9
+
10
+ import torch
11
+ import numpy as np
12
+ import cv2
+ import subprocess  # used by images_to_video to run FFmpeg
+ import imageio_ffmpeg as ffmpeg  # provides get_ffmpeg_exe() used in images_to_video
13
+ from PIL import Image
14
+ from tqdm import tqdm
15
+ from natsort import natsorted, ns
16
+ from einops import rearrange
17
+ from omegaconf import OmegaConf
18
+ from huggingface_hub import snapshot_download
19
+
20
+ from transformers import (
21
+ Dinov2Model, CLIPImageProcessor, CLIPVisionModelWithProjection, AutoImageProcessor
22
+ )
23
+ from Next3d.training_avatar_texture.camera_utils import LookAtPoseSampler, FOV_to_intrinsics
24
+
25
+ from data_process.lib.FaceVerse.renderer import Faceverse_manager
26
+ import recon.dnnlib as dnnlib
27
+ import recon.legacy as legacy
28
+
29
+ from DiT_VAE.diffusion.utils.misc import read_config
30
+ from DiT_VAE.vae.triplane_vae import AutoencoderKL as AutoencoderKLTriplane
31
+ from DiT_VAE.diffusion import IDDPM, DPMS
32
+ from DiT_VAE.diffusion.model.nets import TriDitCLIPDINO_XL_2
33
+ from DiT_VAE.diffusion.data.datasets import get_chunks
34
+
35
+ # Get the directory of the current script
36
+ father_path = os.path.dirname(os.path.abspath(__file__))
37
+
38
+ # Add necessary paths dynamically
39
+ sys.path.extend([
40
+ os.path.join(father_path, 'recon'),
41
+ os.path.join(father_path, 'Next3d')
42
+ ])
43
+
44
+ # Suppress warnings (especially for PyTorch)
45
+ warnings.filterwarnings("ignore")
46
+
47
+ # Configure logging settings
48
+ logging.basicConfig(
49
+ level=logging.INFO,
50
+ format="%(asctime)s - %(levelname)s - %(message)s"
51
+ )
52
+
53
+
54
+ def get_args():
55
+ """Parse and return command-line arguments."""
56
+ parser = argparse.ArgumentParser(description="4D Triplane Generation Arguments")
57
+
58
+ # Configuration and model checkpoints
59
+ parser.add_argument("--config", type=str, default="./configs/infer_config.py",
60
+ help="Path to the configuration file.")
61
+
62
+ # Input data paths
63
+ parser.add_argument("--target_path", type=str, required=True, default='./demo_data/target_video/data_obama',
64
+ help="Base path of the dataset.")
65
+ parser.add_argument("--img_file", type=str, required=True, default='./demo_data/source_img/img_generate_different_domain/images512x512/demo_imgs',
66
+ help="Directory containing input images.")
67
+ parser.add_argument("--input_img_motion", type=str,
68
+ default="./demo_data/source_img/img_generate_different_domain/motions/demo_imgs",
69
+ help="Directory containing motion features.")
70
+ parser.add_argument("--video_name", type=str, required=True, default='Obama',
71
+ help="Name of the video.")
72
+ parser.add_argument("--input_img_fvid", type=str,
73
+ default="./demo_data/source_img/img_generate_different_domain/coeffs/demo_imgs",
74
+ help="Path to input image coefficients.")
75
+
76
+ # Output settings
77
+ parser.add_argument("--output_basedir", type=str, default="./output",
78
+ help="Base directory for saving output results.")
79
+
80
+ # Generation parameters
81
+ parser.add_argument("--bs", type=int, default=1,
82
+ help="Batch size for processing.")
83
+ parser.add_argument("--cfg_scale", type=float, default=4.5,
84
+ help="CFG scale parameter.")
85
+ parser.add_argument("--sampling_algo", type=str, default="dpm-solver",
86
+ choices=["iddpm", "dpm-solver"],
87
+ help="Sampling algorithm to be used.")
88
+ parser.add_argument("--seed", type=int, default=0,
89
+ help="Random seed for reproducibility.")
90
+ parser.add_argument("--select_img", type=str, default=None,
91
+ help="Optional: Select a specific image.")
92
+ parser.add_argument('--step', default=-1, type=int)
93
+ parser.add_argument('--use_demo_cam', action='store_true', help="Enable predefined camera parameters")
94
+ return parser.parse_args()
95
+
96
+
97
+ def set_env(seed=0):
98
+ """Set random seed for reproducibility across multiple frameworks."""
99
+ torch.manual_seed(seed) # Set PyTorch seed
100
+ torch.cuda.manual_seed_all(seed) # If using multi-GPU
101
+ np.random.seed(seed) # Set NumPy seed
102
+ random.seed(seed) # Set Python built-in random module seed
103
+ torch.set_grad_enabled(False) # Disable gradients for inference
104
+
105
+
106
+ def to_rgb_image(image: Image.Image):
107
+ """Convert an image to RGB format if necessary."""
108
+ if image.mode == 'RGB':
109
+ return image
110
+ elif image.mode == 'RGBA':
111
+ img = Image.new("RGB", image.size, (127, 127, 127))
112
+ img.paste(image, mask=image.getchannel('A'))
113
+ return img
114
+ else:
115
+ raise ValueError(f"Unsupported image type: {image.mode}")
116
+
117
+
118
+ def image_process(image_path):
119
+ """Preprocess an image for CLIP and DINO models."""
120
+ image = to_rgb_image(Image.open(image_path))
121
+ clip_image = clip_image_processor(images=image, return_tensors="pt").pixel_values.to(device)
122
+ dino_image = dino_img_processor(images=image, return_tensors="pt").pixel_values.to(device)
123
+ return dino_image, clip_image
124
+
125
+
126
+ def video_gen(frames_dir, output_path, fps=30):
127
+ """Generate a video from image frames."""
128
+ frame_files = natsorted(os.listdir(frames_dir), alg=ns.PATH)
129
+ frames = [cv2.imread(os.path.join(frames_dir, f)) for f in frame_files]
130
+ H, W = frames[0].shape[:2]
131
+ video_writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'MP4V'), fps, (W, H))
132
+ for frame in frames:
133
+ video_writer.write(frame)
134
+ video_writer.release()
135
+
136
+
137
+ def trans(tensor_img):
138
+ img = (tensor_img.permute(0, 2, 3, 1) * 0.5 + 0.5).clamp(0, 1) * 255.
139
+ img = img.to(torch.uint8)
140
+ img = img[0].detach().cpu().numpy()
141
+
142
+ return img
143
+
144
+
145
+ def get_vert(vert_dir):
146
+ uvcoords_image = np.load(os.path.join(vert_dir))[..., :3]
147
+ uvcoords_image[..., -1][uvcoords_image[..., -1] < 0.5] = 0
148
+ uvcoords_image[..., -1][uvcoords_image[..., -1] >= 0.5] = 1
149
+ return torch.tensor(uvcoords_image.copy()).float().unsqueeze(0)
150
+
151
+
152
+ def generate_samples(DiT_model, cfg_scale, sample_steps, clip_feature, dino_feature, uncond_clip_feature,
153
+ uncond_dino_feature, device, latent_size, sampling_algo):
154
+ """
155
+ Generate latent samples using the specified diffusion model.
156
+
157
+ Args:
158
+ DiT_model (torch.nn.Module): The diffusion model.
159
+ cfg_scale (float): The classifier-free guidance scale.
160
+ sample_steps (int): Number of sampling steps.
161
+ clip_feature (torch.Tensor): CLIP feature tensor.
162
+ dino_feature (torch.Tensor): DINO feature tensor.
163
+ uncond_clip_feature (torch.Tensor): Unconditional CLIP feature tensor.
164
+ uncond_dino_feature (torch.Tensor): Unconditional DINO feature tensor.
165
+ device (str): Device for computation.
166
+ latent_size (tuple): The latent space size.
167
+ sampling_algo (str): The sampling algorithm ('iddpm' or 'dpm-solver').
168
+
169
+ Returns:
170
+ torch.Tensor: The generated samples.
171
+ """
172
+ n = 1 # Batch size
173
+ z = torch.randn(n, 8, latent_size[0], latent_size[1], device=device)
174
+
175
+ if sampling_algo == 'iddpm':
176
+ z = z.repeat(2, 1, 1, 1) # Duplicate for classifier-free guidance
177
+ model_kwargs = dict(y=torch.cat([clip_feature, uncond_clip_feature]),
178
+ img_feature=torch.cat([dino_feature, dino_feature]),
179
+ cfg_scale=cfg_scale)
180
+ diffusion = IDDPM(str(sample_steps))
181
+ samples = diffusion.p_sample_loop(DiT_model.forward_with_cfg, z.shape, z, clip_denoised=False,
182
+ model_kwargs=model_kwargs, progress=True, device=device)
183
+ samples, _ = samples.chunk(2, dim=0) # Remove unconditional samples
184
+
185
+ elif sampling_algo == 'dpm-solver':
186
+ dpm_solver = DPMS(DiT_model.forward_with_dpmsolver,
187
+ condition=[clip_feature, dino_feature],
188
+ uncondition=[uncond_clip_feature, dino_feature],
189
+ cfg_scale=cfg_scale)
190
+ samples = dpm_solver.sample(z, steps=sample_steps, order=2, skip_type="time_uniform", method="multistep")
191
+ else:
192
+ raise ValueError(f"Invalid sampling_algo '{sampling_algo}'. Choose either 'iddpm' or 'dpm-solver'.")
193
+
194
+ return samples
195
+
196
+ def images_to_video(image_folder, output_video, fps=30):
197
+ # Get all image files and ensure correct order
198
+ images = [img for img in os.listdir(image_folder) if img.endswith((".png", ".jpg", ".jpeg"))]
199
+ images = natsorted(images) # Sort filenames naturally to preserve frame order
200
+
201
+ if not images:
202
+ print("❌ No images found in the directory!")
203
+ return
204
+
205
+ # Get the path to the FFmpeg executable
206
+ ffmpeg_exe = ffmpeg.get_ffmpeg_exe()
207
+ print(f"Using FFmpeg from: {ffmpeg_exe}")
208
+
209
+ # Define input image pattern (expects images named like "%04d.png")
210
+ image_pattern = os.path.join(image_folder, "%04d.png")
211
+
212
+ # FFmpeg command to encode video
213
+ command = [
214
+ ffmpeg_exe, '-framerate', str(fps), '-i', image_pattern,
215
+ '-c:v', 'libx264', '-preset', 'slow', '-crf', '18', # High-quality H.264 encoding
216
+ '-pix_fmt', 'yuv420p', '-b:v', '5000k', # Ensure compatibility & increase bitrate
217
+ output_video
218
+ ]
219
+
220
+ # Run FFmpeg command
221
+ subprocess.run(command, check=True)
222
+
223
+ print(f"✅ High-quality MP4 video has been generated: {output_video}")
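+ # Frames must be named with a zero-padded four-digit index (0000.png, 0001.png, ...) to match the
+ # "%04d.png" input pattern above; avatar_generation below saves frames in exactly that format.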
224
+ @torch.inference_mode()
225
+ def avatar_generation(items, bs, sample_steps, cfg_scale, save_path_base, DiT_model, render_model, std, mean, ws_avg,
226
+ Faceverse, pitch_range=0.25, yaw_range=0.35, demo_cam=False):
227
+ """
228
+ Generate avatars from input images.
229
+
230
+ Args:
231
+ items (list): List of image paths.
232
+ bs (int): Batch size.
233
+ sample_steps (int): Number of sampling steps.
234
+ cfg_scale (float): Classifier-free guidance scale.
235
+ save_path_base (str): Base directory for saving results.
236
+ DiT_model (torch.nn.Module): The diffusion model.
237
+ render_model (torch.nn.Module): The rendering model.
238
+ std (torch.Tensor): Standard deviation normalization tensor.
239
+ mean (torch.Tensor): Mean normalization tensor.
240
+ ws_avg (torch.Tensor): Latent average tensor.
+ Faceverse (Faceverse_manager): FaceVerse manager used to render the driven expression meshes.
+ pitch_range (float): Pitch range of the sampled demo camera trajectory.
+ yaw_range (float): Yaw range of the sampled demo camera trajectory.
+ demo_cam (bool): If True, render with a sampled orbiting camera instead of the dataset poses.
+ """
242
+ for chunk in tqdm(list(get_chunks(items, bs)), unit='batch'):
243
+ if bs != 1:
244
+ raise ValueError("Batch size > 1 not implemented")
245
+
246
+ image_dir = chunk[0]
247
+ image_name = os.path.splitext(os.path.basename(image_dir))[0]
248
+ dino_img, clip_image = image_process(image_dir)
249
+
250
+ clip_feature = image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
251
+ uncond_clip_feature = image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[-2]
253
+ dino_feature = dinov2(dino_img).last_hidden_state
254
+ uncond_dino_feature = dinov2(torch.zeros_like(dino_img)).last_hidden_state
255
+
256
+ samples = generate_samples(DiT_model, cfg_scale, sample_steps, clip_feature, dino_feature,
257
+ uncond_clip_feature, uncond_dino_feature, device, latent_size,
258
+ args.sampling_algo)
259
+
260
+ samples = (samples / default_config.scale_factor)
261
+ samples = rearrange(samples, "b c (f h) w -> b c f h w", f=4)
262
+ samples = vae_triplane.decode(samples)
263
+ samples = rearrange(samples, "b c f h w -> b f c h w")
264
+ samples = samples * std + mean
265
+ torch.cuda.empty_cache()
266
+
267
+ save_frames_path_combine = os.path.join(save_path_base, image_name, 'combine')
268
+ save_frames_path_out = os.path.join(save_path_base, image_name, 'out')
269
+ os.makedirs(save_frames_path_combine, exist_ok=True)
270
+ os.makedirs(save_frames_path_out, exist_ok=True)
271
+
272
+ img_ref = np.array(Image.open(image_dir))
273
+ img_ref_out = img_ref.copy()
274
+ img_ref = torch.from_numpy(img_ref.astype(np.float32) / 127.5 - 1).permute(2, 0, 1).unsqueeze(0).to(device)
275
+
276
+ motion_app_dir = os.path.join(args.input_img_motion, image_name + '.npy')
277
+ motion_app = torch.tensor(np.load(motion_app_dir), dtype=torch.float32).unsqueeze(0).to(device)
278
+
279
+ id_motions = os.path.join(args.input_img_fvid, image_name + '.npy')
280
+
281
+ all_pose = json.loads(open(label_file_test).read())['labels']
282
+ all_pose = dict(all_pose)
283
+ if os.path.exists(id_motions):
284
+ coeff = np.load(id_motions).astype(np.float32)
285
+ coeff = torch.from_numpy(coeff).to(device).float().unsqueeze(0)
286
+ Faceverse.id_coeff = Faceverse.recon_model.split_coeffs(coeff)[0]
287
+ motion_dir = os.path.join(motion_base_dir, args.video_name)
288
+ exp_dir = os.path.join(exp_base_dir, args.video_name)
289
+ for frame_index, motion_name in enumerate(
290
+ tqdm(natsorted(os.listdir(motion_dir), alg=ns.PATH), desc="Processing Frames")):
291
+ exp_each_dir_img = os.path.join(exp_img_base_dir, args.video_name, motion_name.replace('.npy', '.png'))
292
+ exp_each_dir = os.path.join(exp_dir, motion_name)
293
+ motion_each_dir = os.path.join(motion_dir, motion_name)
294
+
295
+ # Load pose data
296
+ pose_key = os.path.join(args.video_name, motion_name.replace('.npy', '.png'))
297
+ if demo_cam:
298
+ cam2world_pose = LookAtPoseSampler.sample(
299
+ 3.14 / 2 + yaw_range * np.sin(2 * 3.14 * frame_index / len(os.listdir(motion_dir))),
300
+ 3.14 / 2 - 0.05 + pitch_range * np.cos(2 * 3.14 * frame_index / len(os.listdir(motion_dir))),
301
+ torch.tensor([0, 0, 0], device=device), radius=2.7, device=device)
302
+ pose = torch.cat([cam2world_pose.reshape(-1, 16),
303
+ FOV_to_intrinsics(fov_degrees=18.837, device=device).reshape(-1, 9)], 1).to(device)
304
+ else:
305
+ pose = torch.tensor(np.array(all_pose[pose_key]).astype(np.float32)).float().unsqueeze(0).to(device)
306
+
307
+ # Load and resize expression image
308
+ exp_img = np.array(Image.open(exp_each_dir_img).resize((512, 512)))
309
+
310
+ # Load expression coefficients
311
+ exp_coeff = torch.from_numpy(np.load(exp_each_dir).astype(np.float32)).to(device).float().unsqueeze(0)
312
+ exp_target = Faceverse.make_driven_rendering(exp_coeff, res=256)
313
+
314
+ # Load motion data
315
+ motion = torch.tensor(np.load(motion_each_dir)).float().unsqueeze(0).to(device)
316
+
317
+ # Select refine_net processing method
318
+ final_out = render_model(
319
+ img_ref, None, motion_app, motion, c=pose, mesh=exp_target, triplane_recon=samples,
320
+ ws_avg=ws_avg, motion_scale=1.
321
+ )
322
+
323
+ # Process output image
324
+ final_out = trans(final_out['image_sr'])
325
+ output_img_combine = np.hstack((img_ref_out, exp_img, final_out))
326
+
327
+ # Save output images
328
+ frame_name = f'{str(frame_index).zfill(4)}.png'
329
+ Image.fromarray(output_img_combine, 'RGB').save(os.path.join(save_frames_path_combine, frame_name))
330
+ Image.fromarray(final_out, 'RGB').save(os.path.join(save_frames_path_out, frame_name))
331
+
332
+ # Generate videos
333
+ images_to_video(save_frames_path_combine, os.path.join(save_path_base, image_name + '_combine.mp4'))
334
+ images_to_video(save_frames_path_out, os.path.join(save_path_base, image_name + '_out.mp4'))
335
+ logging.info("✅ Video generation completed successfully!")
336
+ logging.info(f"📂 Combined video saved at: {os.path.join(save_path_base, image_name + '_combine.mp4')}")
337
+ logging.info(f"📂 Output video saved at: {os.path.join(save_path_base, image_name + '_out.mp4')}")
338
+
339
+
340
+ def load_motion_aware_render_model(ckpt_path):
341
+ """Load the motion-aware render model from a checkpoint."""
342
+ logging.info("Loading motion-aware render model...")
343
+ with dnnlib.util.open_url(ckpt_path, 'rb') as f:
344
+ network = legacy.load_network_pkl(f) # type: ignore
345
+ logging.info("Motion-aware render model loaded.")
346
+ return network['G_ema'].to(device)
347
+
348
+
349
+ def load_diffusion_model(ckpt_path, latent_size):
350
+ """Load the diffusion model (DiT)."""
351
+ logging.info("Loading diffusion model (DiT)...")
352
+
353
+ DiT_model = TriDitCLIPDINO_XL_2(input_size=latent_size).to(device)
354
+ ckpt = torch.load(ckpt_path, map_location="cpu")
355
+
356
+ # Remove keys that can cause mismatches
357
+ for key in ['pos_embed', 'base_model.pos_embed', 'model.pos_embed']:
358
+ ckpt['state_dict'].pop(key, None)
359
+ ckpt.get('state_dict_ema', {}).pop(key, None)
360
+
361
+ state_dict = ckpt.get('state_dict_ema', ckpt)
362
+ DiT_model.load_state_dict(state_dict, strict=False)
363
+ DiT_model.eval()
364
+ logging.info("Diffusion model (DiT) loaded.")
365
+ return DiT_model
366
+
367
+
368
+ def load_vae_clip_dino(config, device):
369
+ """Load VAE, CLIP, and DINO models."""
370
+ logging.info("Loading VAE, CLIP, and DINO models...")
371
+
372
+ # Load CLIP image encoder
373
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
374
+ config.image_encoder_path)
375
+ image_encoder.requires_grad_(False)
376
+ image_encoder.to(device)
377
+
378
+ # Load VAE
379
+ config_vae = OmegaConf.load(config.vae_triplane_config_path)
380
+ vae_triplane = AutoencoderKLTriplane(ddconfig=config_vae['ddconfig'], lossconfig=None, embed_dim=8)
381
+ vae_triplane.to(device)
382
+
383
+ vae_ckpt_path = os.path.join(config.vae_pretrained, 'pytorch_model.bin')
384
+ if not os.path.isfile(vae_ckpt_path):
385
+ raise RuntimeError(f"VAE checkpoint not found at {vae_ckpt_path}")
386
+
387
+ vae_triplane.load_state_dict(torch.load(vae_ckpt_path, map_location="cpu"))
388
+ vae_triplane.requires_grad_(False)
389
+
390
+ # Load DINO model
391
+ dinov2 = Dinov2Model.from_pretrained(config.dino_pretrained)
392
+ dinov2.requires_grad_(False)
393
+ dinov2.to(device)
394
+
395
+ # Load image processors
396
+ dino_img_processor = AutoImageProcessor.from_pretrained(config.dino_pretrained)
397
+ clip_image_processor = CLIPImageProcessor()
398
+
399
+ logging.info("VAE, CLIP, and DINO models loaded.")
400
+ return vae_triplane, image_encoder, dinov2, dino_img_processor, clip_image_processor
401
+
402
+
403
+ def prepare_image_list(img_dir, selected_img):
404
+ """Prepare the list of image paths for processing."""
405
+ if selected_img and selected_img in os.listdir(img_dir):
406
+ return [os.path.join(img_dir, selected_img)]
407
+
408
+ return sorted([os.path.join(img_dir, img) for img in os.listdir(img_dir)])
409
+
410
+
411
+ if __name__ == '__main__':
412
+
413
+ model_path = "./pretrained_model"
414
+ if not os.path.exists(model_path):
415
+ logging.info("📥 Model not found. Downloading from Hugging Face...")
416
+ snapshot_download(repo_id="KumaPower/AvatarArtist", local_dir=model_path)
417
+ logging.info("✅ Model downloaded successfully!")
418
+ else:
419
+ logging.info("🎉 Pretrained model already exists. Skipping download.")
420
+
421
+ args = get_args()
422
+ exp_base_dir = os.path.join(args.target_path, 'coeffs')
423
+ exp_img_base_dir = os.path.join(args.target_path, 'images512x512')
424
+ motion_base_dir = os.path.join(args.target_path, 'motions')
425
+ label_file_test = os.path.join(args.target_path, 'images512x512/dataset_realcam.json')
426
+ set_env(args.seed)
427
+
428
+ device = "cuda" if torch.cuda.is_available() else "cpu"
429
+ weight_dtype = torch.float32
430
+ logging.info(f"Running inference with {weight_dtype}")
431
+
432
+ # Load configuration
433
+ default_config = read_config(args.config)
434
+
435
+ # Ensure valid sampling algorithm
436
+ assert args.sampling_algo in ['iddpm', 'dpm-solver', 'sa-solver']
437
+
438
+ # Prepare image list
439
+ items = prepare_image_list(args.img_file, args.select_img)
440
+ logging.info(f"Input images: {items}")
441
+
442
+ # Load motion-aware render model
443
+ motion_aware_render_model = load_motion_aware_render_model(default_config.motion_aware_render_model_ckpt)
444
+
445
+ # Load diffusion model (DiT)
446
+ triplane_size = (256 * 4, 256)
447
+ latent_size = (triplane_size[0] // 8, triplane_size[1] // 8)
448
+ sample_steps = args.step if args.step != -1 else {'iddpm': 100, 'dpm-solver': 20, 'sa-solver': 25}[
449
+ args.sampling_algo]
450
+ DiT_model = load_diffusion_model(default_config.DiT_model_ckpt, latent_size)
451
+
452
+ # Load VAE, CLIP, and DINO
453
+ vae_triplane, image_encoder, dinov2, dino_img_processor, clip_image_processor = load_vae_clip_dino(default_config,
454
+ device)
455
+
456
+ # Load normalization parameters
457
+ triplane_std = torch.load(default_config.std_dir).to(device).reshape(1, -1, 1, 1, 1)
458
+ triplane_mean = torch.load(default_config.mean_dir).to(device).reshape(1, -1, 1, 1, 1)
459
+
460
+ # Load average latent vector
461
+ ws_avg = torch.load(default_config.ws_avg_pkl).to(device)[0]
462
+
463
+ # Set up save directory
464
+ save_root = os.path.join(args.output_basedir, f'{datetime.now().date()}', args.video_name)
465
+ os.makedirs(save_root, exist_ok=True)
466
+
467
+ # Set up FaceVerse for animation
468
+ base_coff = np.load('pretrained_model/temp.npy').astype(np.float32)
471
+ base_coff = torch.from_numpy(base_coff).float()
472
+ Faceverse = Faceverse_manager(device=device, base_coeff=base_coff)
473
+
474
+ # Run avatar generation
475
+ avatar_generation(items, args.bs, sample_steps, args.cfg_scale, save_root, DiT_model, motion_aware_render_model,
476
+ triplane_std, triplane_mean, ws_avg, Faceverse, demo_cam=args.use_demo_cam)
requirements.txt ADDED
@@ -0,0 +1,36 @@
1
+ transformers==4.34.1
2
+ scipy==1.13.1
3
+ scikit-image==0.21.0
4
+ scikit-learn==1.2.2
5
+ opencv-python==4.10.0.84
6
+ Pillow==10.4.0
7
+ termcolor==1.1.0
8
+ PyYAML==6.0.2
9
+ tqdm==4.67.1
10
+ absl-py==2.1.0
11
+ tensorboard==2.17.0
12
+ tensorboardX==2.6.1
13
+ PyOpenGL==3.1.0
14
+ pyrender==0.1.45
15
+ trimesh==3.22.0
16
+ click==8.1.7
17
+ omegaconf==2.2.3
18
+ segmentation_models_pytorch
19
+ timm
20
+ psutil==5.9.5
21
+ lmdb==1.4.1
22
+ einops==0.8.1
23
+ kornia==0.6.7
24
+ gdown==5.2.0
25
+ plyfile==1.0.3
26
+ natsort
27
+ mmcv==1.7.0
28
+ xformers==0.0.22.post3
29
+ https://download.pytorch.org/whl/cu121/torch-2.1.1%2Bcu121-cp39-cp39-linux_x86_64.whl
30
+ https://download.pytorch.org/whl/cu121/torchvision-0.16.1%2Bcu121-cp39-cp39-linux_x86_64.whl
31
+ git+https://github.com/facebookresearch/[email protected]
32
+ matplotlib==3.9.1.post1
33
+ mediapipe
34
+ open_clip_torch
35
+ imageio_ffmpeg
36
+ spaces
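+ # Note: the torch and torchvision wheels above are CUDA 12.1 builds for Python 3.9 (cp39).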