Spaces: Running on Zero
刘虹雨 committed · Commit 8f481d2 · 1 Parent: 909e7c5
update

Browse files:
- .gitattributes +1 -0
- .gitignore +174 -0
- LICENSE +201 -0
- app.py +905 -0
- inference.py +476 -0
- requirements.txt +36 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
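The added rule routes *.png files through Git LFS, matching the existing *.zip, *.zst, and *tfevents* patterns above; it is the line that `git lfs track "*.png"` would normally append to this file.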
.gitignore
ADDED
@@ -0,0 +1,174 @@
# Byte-compiled / optimized / DLL files
ckpts*
gradio_tmp*
output*
logs*
taming*
samples*
datasets*
asset*
temp_samples*
wandb*
output_dir*
temp_output*
temp_to_be_delete*
__pycache__/
error_log.txt
.deepspeed_env
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
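Apart from the project-specific entries near the top (ckpts*, gradio_tmp*, output*, logs*, wandb* and the other temp_* patterns, which keep the Space's run-time artifacts out of the repository), the file follows the standard GitHub Python .gitignore template.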
LICENSE
ADDED
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
app.py
ADDED
@@ -0,0 +1,905 @@
import os
import sys
import warnings
import logging
import argparse
import json
import random
from datetime import datetime

import torch
import numpy as np
import cv2
from PIL import Image
from tqdm import tqdm
from natsort import natsorted, ns
from einops import rearrange
from omegaconf import OmegaConf
from huggingface_hub import snapshot_download
import spaces
import gradio as gr
import base64
import imageio_ffmpeg as ffmpeg
import subprocess
from different_domain_imge_gen.landmark_generation import generate_annotation

from transformers import (
    Dinov2Model, CLIPImageProcessor, CLIPVisionModelWithProjection, AutoImageProcessor
)
from Next3d.training_avatar_texture.camera_utils import LookAtPoseSampler, FOV_to_intrinsics

import recon.dnnlib as dnnlib
import recon.legacy as legacy

from DiT_VAE.diffusion.utils.misc import read_config
from DiT_VAE.vae.triplane_vae import AutoencoderKL as AutoencoderKLTriplane
from DiT_VAE.diffusion import IDDPM, DPMS
from DiT_VAE.diffusion.model.nets import TriDitCLIPDINO_XL_2
from DiT_VAE.diffusion.data.datasets import get_chunks

# Get the directory of the current script
father_path = os.path.dirname(os.path.abspath(__file__))

# Add necessary paths dynamically
sys.path.extend([
    os.path.join(father_path, 'recon'),
    os.path.join(father_path, 'Next3d'),
    os.path.join(father_path, 'data_process'),
    os.path.join(father_path, 'data_process/lib')
])

from lib.FaceVerse.renderer import Faceverse_manager
from data_process.input_img_align_extract_ldm_demo import Process
from lib.config.config_demo import cfg
import shutil

# Suppress warnings (especially for PyTorch)
warnings.filterwarnings("ignore")

# Configure logging settings
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s"
)
from diffusers import (
    StableDiffusionControlNetImg2ImgPipeline,
    ControlNetModel,
    DPMSolverMultistepScheduler,
    AutoencoderKL,
)

def get_args():
    """Parse and return command-line arguments."""
    parser = argparse.ArgumentParser(description="4D Triplane Generation Arguments")

    # Configuration and model checkpoints
    parser.add_argument("--config", type=str, default="./configs/infer_config.py",
                        help="Path to the configuration file.")

    # Generation parameters
    parser.add_argument("--bs", type=int, default=1,
                        help="Batch size for processing.")
    parser.add_argument("--cfg_scale", type=float, default=4.5,
                        help="CFG scale parameter.")
    parser.add_argument("--sampling_algo", type=str, default="dpm-solver",
                        choices=["iddpm", "dpm-solver"],
                        help="Sampling algorithm to be used.")
    parser.add_argument("--seed", type=int, default=0,
                        help="Random seed for reproducibility.")
    # parser.add_argument("--select_img", type=str, default=None,
    #                     help="Optional: Select a specific image.")
    parser.add_argument('--step', default=-1, type=int)
    # parser.add_argument('--use_demo_cam', action='store_true', help="Enable predefined camera parameters")
    return parser.parse_args()


def set_env(seed=0):
    """Set random seed for reproducibility across multiple frameworks."""
    torch.manual_seed(seed)           # Set PyTorch seed
    torch.cuda.manual_seed_all(seed)  # If using multi-GPU
    np.random.seed(seed)              # Set NumPy seed
    random.seed(seed)                 # Set Python built-in random module seed
    torch.set_grad_enabled(False)     # Disable gradients for inference


def to_rgb_image(image: Image.Image):
    """Convert an image to RGB format if necessary."""
    if image.mode == 'RGB':
        return image
    elif image.mode == 'RGBA':
        img = Image.new("RGB", image.size, (127, 127, 127))
        img.paste(image, mask=image.getchannel('A'))
        return img
    else:
        raise ValueError(f"Unsupported image type: {image.mode}")


def image_process(image_path, clip_image_processor, dino_img_processor, device):
    """Preprocess an image for CLIP and DINO models."""
    image = to_rgb_image(Image.open(image_path))
    clip_image = clip_image_processor(images=image, return_tensors="pt").pixel_values.to(device)
    dino_image = dino_img_processor(images=image, return_tensors="pt").pixel_values.to(device)
    return dino_image, clip_image


# def video_gen(frames_dir, output_path, fps=30):
#     """Generate a video from image frames."""
#     frame_files = natsorted(os.listdir(frames_dir), alg=ns.PATH)
#     frames = [cv2.imread(os.path.join(frames_dir, f)) for f in frame_files]
#     H, W = frames[0].shape[:2]
#     video_writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'MP4V'), fps, (W, H))
#     for frame in frames:
#         video_writer.write(frame)
#     video_writer.release()


def trans(tensor_img):
    img = (tensor_img.permute(0, 2, 3, 1) * 0.5 + 0.5).clamp(0, 1) * 255.
    img = img.to(torch.uint8)
    img = img[0].detach().cpu().numpy()

    return img


def get_vert(vert_dir):
    uvcoords_image = np.load(os.path.join(vert_dir))[..., :3]
    uvcoords_image[..., -1][uvcoords_image[..., -1] < 0.5] = 0
    uvcoords_image[..., -1][uvcoords_image[..., -1] >= 0.5] = 1
    return torch.tensor(uvcoords_image.copy()).float().unsqueeze(0)

def generate_samples(DiT_model, cfg_scale, sample_steps, clip_feature, dino_feature, uncond_clip_feature,
                     uncond_dino_feature, device, latent_size, sampling_algo):
    """
    Generate latent samples using the specified diffusion model.

    Args:
        DiT_model (torch.nn.Module): The diffusion model.
        cfg_scale (float): The classifier-free guidance scale.
        sample_steps (int): Number of sampling steps.
        clip_feature (torch.Tensor): CLIP feature tensor.
        dino_feature (torch.Tensor): DINO feature tensor.
        uncond_clip_feature (torch.Tensor): Unconditional CLIP feature tensor.
        uncond_dino_feature (torch.Tensor): Unconditional DINO feature tensor.
        device (str): Device for computation.
        latent_size (tuple): The latent space size.
        sampling_algo (str): The sampling algorithm ('iddpm' or 'dpm-solver').

    Returns:
        torch.Tensor: The generated samples.
    """
    n = 1  # Batch size
    z = torch.randn(n, 8, latent_size[0], latent_size[1], device=device)

    if sampling_algo == 'iddpm':
        z = z.repeat(2, 1, 1, 1)  # Duplicate for classifier-free guidance
        model_kwargs = dict(y=torch.cat([clip_feature, uncond_clip_feature]),
                            img_feature=torch.cat([dino_feature, dino_feature]),
                            cfg_scale=cfg_scale)
        diffusion = IDDPM(str(sample_steps))
        samples = diffusion.p_sample_loop(DiT_model.forward_with_cfg, z.shape, z, clip_denoised=False,
                                          model_kwargs=model_kwargs, progress=True, device=device)
        samples, _ = samples.chunk(2, dim=0)  # Remove unconditional samples

    elif sampling_algo == 'dpm-solver':
        dpm_solver = DPMS(DiT_model.forward_with_dpmsolver,
                          condition=[clip_feature, dino_feature],
                          uncondition=[uncond_clip_feature, dino_feature],
                          cfg_scale=cfg_scale)
        samples = dpm_solver.sample(z, steps=sample_steps, order=2, skip_type="time_uniform", method="multistep")
    else:
        raise ValueError(f"Invalid sampling_algo '{sampling_algo}'. Choose either 'iddpm' or 'dpm-solver'.")

    return samples

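# Usage sketch for generate_samples (illustrative only; the real call site is in
# avatar_generation below, and the feature tensors come from the CLIP/DINO encoders):
#   latent_size = (256 * 4 // 8, 256 // 8)          # (128, 32), as in model_define()
#   samples = generate_samples(DiT_model, 4.5, 20, clip_feat, dino_feat,
#                              uncond_clip_feat, uncond_dino_feat,
#                              'cuda', latent_size, 'dpm-solver')
#   # -> a (1, 8, 128, 32) latent that is then rescaled and decoded by the triplane VAE.

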
def load_motion_aware_render_model(ckpt_path, device):
    """Load the motion-aware render model from a checkpoint."""
    logging.info("Loading motion-aware render model...")
    with dnnlib.util.open_url(ckpt_path, 'rb') as f:
        network = legacy.load_network_pkl(f)  # type: ignore
    logging.info("Motion-aware render model loaded.")
    return network['G_ema'].to(device)


def load_diffusion_model(ckpt_path, latent_size, device):
    """Load the diffusion model (DiT)."""
    logging.info("Loading diffusion model (DiT)...")

    DiT_model = TriDitCLIPDINO_XL_2(input_size=latent_size).to(device)
    ckpt = torch.load(ckpt_path, map_location="cpu")

    # Remove keys that can cause mismatches
    for key in ['pos_embed', 'base_model.pos_embed', 'model.pos_embed']:
        ckpt['state_dict'].pop(key, None)
        ckpt.get('state_dict_ema', {}).pop(key, None)

    state_dict = ckpt.get('state_dict_ema', ckpt)
    DiT_model.load_state_dict(state_dict, strict=False)
    DiT_model.eval()
    logging.info("Diffusion model (DiT) loaded.")
    return DiT_model


def load_vae_clip_dino(config, device):
    """Load VAE, CLIP, and DINO models."""
    logging.info("Loading VAE, CLIP, and DINO models...")

    # Load CLIP image encoder
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        config.image_encoder_path)
    image_encoder.requires_grad_(False)
    image_encoder.to(device)

    # Load VAE
    config_vae = OmegaConf.load(config.vae_triplane_config_path)
    vae_triplane = AutoencoderKLTriplane(ddconfig=config_vae['ddconfig'], lossconfig=None, embed_dim=8)
    vae_triplane.to(device)

    vae_ckpt_path = os.path.join(config.vae_pretrained, 'pytorch_model.bin')
    if not os.path.isfile(vae_ckpt_path):
        raise RuntimeError(f"VAE checkpoint not found at {vae_ckpt_path}")

    vae_triplane.load_state_dict(torch.load(vae_ckpt_path, map_location="cpu"))
    vae_triplane.requires_grad_(False)

    # Load DINO model
    dinov2 = Dinov2Model.from_pretrained(config.dino_pretrained)
    dinov2.requires_grad_(False)
    dinov2.to(device)

    # Load image processors
    dino_img_processor = AutoImageProcessor.from_pretrained(config.dino_pretrained)
    clip_image_processor = CLIPImageProcessor()

    logging.info("VAE, CLIP, and DINO models loaded.")
    return vae_triplane, image_encoder, dinov2, dino_img_processor, clip_image_processor


def prepare_working_dir(dir, style):
    print('prepare_working_dir, style =', style)
    if style:
        # Reuse the existing working directory when a styled image is being processed.
        return dir
    else:
        import tempfile
        # mkdtemp keeps the directory alive after this function returns; returning
        # tempfile.TemporaryDirectory().name would hand back a path that is removed
        # as soon as the object is garbage-collected.
        return tempfile.mkdtemp()


def launch_pretrained():
    from huggingface_hub import hf_hub_download, snapshot_download
    hf_hub_download(repo_id="KumaPower/AvatarArtist", repo_type='model', local_dir="./pretrained_model")


def prepare_image_list(img_dir, selected_img):
    """Prepare the list of image paths for processing."""
    if selected_img and selected_img in os.listdir(img_dir):
        return [os.path.join(img_dir, selected_img)]

    return sorted([os.path.join(img_dir, img) for img in os.listdir(img_dir)])

def images_to_video(image_folder, output_video, fps=30):
    # Get all image files and ensure correct order
    images = [img for img in os.listdir(image_folder) if img.endswith((".png", ".jpg", ".jpeg"))]
    images = natsorted(images)  # Sort filenames naturally to preserve frame order

    if not images:
        print("❌ No images found in the directory!")
        return

    # Get the path to the FFmpeg executable
    ffmpeg_exe = ffmpeg.get_ffmpeg_exe()
    print(f"Using FFmpeg from: {ffmpeg_exe}")

    # Define input image pattern (expects images named like "%04d.png")
    image_pattern = os.path.join(image_folder, "%04d.png")

    # FFmpeg command to encode video
    command = [
        ffmpeg_exe, '-framerate', str(fps), '-i', image_pattern,
        '-c:v', 'libx264', '-preset', 'slow', '-crf', '18',  # High-quality H.264 encoding
        '-pix_fmt', 'yuv420p', '-b:v', '5000k',  # Ensure compatibility & increase bitrate
        output_video
    ]

    # Run FFmpeg command
    subprocess.run(command, check=True)

    print(f"✅ High-quality MP4 video has been generated: {output_video}")

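# Usage sketch (illustrative paths; avatar_generation below calls this on the per-frame
# output folders it has just written, whose frames are named 0000.png, 0001.png, ...):
#   images_to_video('/tmp/result/avatar/out', '/tmp/result/avatar_out.mp4', fps=30)
# The frames must follow the %04d.png naming convention used by the FFmpeg input pattern.

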
def model_define():
    args = get_args()
    set_env(args.seed)
    input_process_model = Process(cfg)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    weight_dtype = torch.float32
    logging.info(f"Running inference with {weight_dtype}")

    # Load configuration
    default_config = read_config(args.config)

    # Ensure valid sampling algorithm
    assert args.sampling_algo in ['iddpm', 'dpm-solver', 'sa-solver']
    # Load motion-aware render model
    motion_aware_render_model = load_motion_aware_render_model(default_config.motion_aware_render_model_ckpt, device)

    # Load diffusion model (DiT)
    triplane_size = (256 * 4, 256)
    latent_size = (triplane_size[0] // 8, triplane_size[1] // 8)
    sample_steps = args.step if args.step != -1 else {'iddpm': 100, 'dpm-solver': 20, 'sa-solver': 25}[
        args.sampling_algo]
    DiT_model = load_diffusion_model(default_config.DiT_model_ckpt, latent_size, device)

    # Load VAE, CLIP, and DINO
    vae_triplane, image_encoder, dinov2, dino_img_processor, clip_image_processor = load_vae_clip_dino(default_config,
                                                                                                       device)

    # Load normalization parameters
    triplane_std = torch.load(default_config.std_dir).to(device).reshape(1, -1, 1, 1, 1)
    triplane_mean = torch.load(default_config.mean_dir).to(device).reshape(1, -1, 1, 1, 1)

    # Load average latent vector
    ws_avg = torch.load(default_config.ws_avg_pkl).to(device)[0]

    # Set up FaceVerse for animation
    base_coff = np.load(
        'pretrained_model/temp.npy').astype(
        np.float32)
    base_coff = torch.from_numpy(base_coff).float()
    Faceverse = Faceverse_manager(device=device, base_coeff=base_coff)

    return motion_aware_render_model, sample_steps, DiT_model, \
        vae_triplane, image_encoder, dinov2, dino_img_processor, clip_image_processor, triplane_std, triplane_mean, ws_avg, Faceverse, device, input_process_model


def duplicate_batch(tensor, batch_size=2):
    if tensor is None:
        return None  # Nothing to duplicate
    return tensor.repeat(batch_size, *([1] * (tensor.dim() - 1)))  # Repeat along the batch dimension

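# duplicate_batch turns a (1, ...) tensor into a (batch_size, ...) tensor, e.g. a
# (1, 3, 512, 512) reference image becomes (2, 3, 512, 512). avatar_generation uses it
# so the renderer can produce two views per frame in one pass: one with the orbiting
# "show" camera and one with the camera pose recorded for the target video.

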
@torch.inference_mode()
@spaces.GPU(duration=200)
def avatar_generation(items, save_path_base, video_path_input, source_type, is_styled, styled_img):
    """
    Generate an animated avatar from the input image, driven by the target video.

    Args:
        items (str): Path to the (processed) source image.
        save_path_base (str): Base directory for saving results.
        video_path_input (str): Path to the target motion video.
        source_type (str): 'example' for bundled images, 'custom' for user uploads.
        is_styled (bool): Whether the styled image should be used as the source.
        styled_img (str): Path to the styled image, if any.

    The DiT model, render model, triplane VAE, CLIP/DINO encoders, normalization
    tensors (std/mean), ws_avg and the FaceVerse manager are module-level globals
    set up via model_define().
    """
    if is_styled:
        items = [styled_img]
    else:
        items = [items]
    video_folder = "./demo_data/target_video"
    video_name = os.path.basename(video_path_input).split(".")[0]
    target_path = os.path.join(video_folder, 'data_' + video_name)
    exp_base_dir = os.path.join(target_path, 'coeffs')
    exp_img_base_dir = os.path.join(target_path, 'images512x512')
    motion_base_dir = os.path.join(target_path, 'motions')
    label_file_test = os.path.join(target_path, 'images512x512/dataset_realcam.json')

    if source_type == 'example':
        input_img_fvid = './demo_data/source_img/img_generate_different_domain/coeffs/trained_input_imgs'
        input_img_motion = './demo_data/source_img/img_generate_different_domain/motions/trained_input_imgs'
    elif source_type == 'custom':
        input_img_fvid = os.path.join(save_path_base, 'processed_img/dataset/coeffs/input_image')
        input_img_motion = os.path.join(save_path_base, 'processed_img/dataset/motions/input_image')
    else:
        raise ValueError("Wrong type")
    bs = 1
    sample_steps = 20
    cfg_scale = 4.5
    pitch_range = 0.25
    yaw_range = 0.35
    triplane_size = (256 * 4, 256)
    latent_size = (triplane_size[0] // 8, triplane_size[1] // 8)
    for chunk in tqdm(list(get_chunks(items, 1)), unit='batch'):
        if bs != 1:
            raise ValueError("Batch size > 1 not implemented")

        image_dir = chunk[0]

        image_name = os.path.splitext(os.path.basename(image_dir))[0]
        dino_img, clip_image = image_process(image_dir, clip_image_processor, dino_img_processor, device)

        clip_feature = image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
        uncond_clip_feature = image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[
            -2]
        dino_feature = dinov2(dino_img).last_hidden_state
        uncond_dino_feature = dinov2(torch.zeros_like(dino_img)).last_hidden_state

        samples = generate_samples(DiT_model, cfg_scale, sample_steps, clip_feature, dino_feature,
                                   uncond_clip_feature, uncond_dino_feature, device, latent_size,
                                   'dpm-solver')

        samples = (samples / 0.3994218)
        samples = rearrange(samples, "b c (f h) w -> b c f h w", f=4)
        samples = vae_triplane.decode(samples)
        samples = rearrange(samples, "b c f h w -> b f c h w")
        samples = samples * std + mean
        torch.cuda.empty_cache()

        save_frames_path_out = os.path.join(save_path_base, image_name, 'out')
        save_frames_path_outshow = os.path.join(save_path_base, image_name, 'out_show')
        save_frames_path_depth = os.path.join(save_path_base, image_name, 'depth')

        os.makedirs(save_frames_path_out, exist_ok=True)
        os.makedirs(save_frames_path_outshow, exist_ok=True)
        os.makedirs(save_frames_path_depth, exist_ok=True)

        img_ref = np.array(Image.open(image_dir))
        img_ref_out = img_ref.copy()
        img_ref = torch.from_numpy(img_ref.astype(np.float32) / 127.5 - 1).permute(2, 0, 1).unsqueeze(0).to(device)

        motion_app_dir = os.path.join(input_img_motion, image_name + '.npy')
        motion_app = torch.tensor(np.load(motion_app_dir), dtype=torch.float32).unsqueeze(0).to(device)

        id_motions = os.path.join(input_img_fvid, image_name + '.npy')

        all_pose = json.loads(open(label_file_test).read())['labels']
        all_pose = dict(all_pose)
        if os.path.exists(id_motions):
            coeff = np.load(id_motions).astype(np.float32)
            coeff = torch.from_numpy(coeff).to(device).float().unsqueeze(0)
            Faceverse.id_coeff = Faceverse.recon_model.split_coeffs(coeff)[0]
        motion_dir = os.path.join(motion_base_dir, video_name)
        exp_dir = os.path.join(exp_base_dir, video_name)
        for frame_index, motion_name in enumerate(
                tqdm(natsorted(os.listdir(motion_dir), alg=ns.PATH), desc="Processing Frames")):
            exp_each_dir_img = os.path.join(exp_img_base_dir, video_name, motion_name.replace('.npy', '.png'))
            exp_each_dir = os.path.join(exp_dir, motion_name)
            motion_each_dir = os.path.join(motion_dir, motion_name)

            # Load pose data
            pose_key = os.path.join(video_name, motion_name.replace('.npy', '.png'))

            # "Show" camera: orbit around the head with sinusoidal yaw/pitch offsets
            cam2world_pose = LookAtPoseSampler.sample(
                3.14 / 2 + yaw_range * np.sin(2 * 3.14 * frame_index / len(os.listdir(motion_dir))),
                3.14 / 2 - 0.05 + pitch_range * np.cos(2 * 3.14 * frame_index / len(os.listdir(motion_dir))),
                torch.tensor([0, 0, 0], device=device), radius=2.7, device=device)
            pose_show = torch.cat([cam2world_pose.reshape(-1, 16),
                                   FOV_to_intrinsics(fov_degrees=18.837, device=device).reshape(-1, 9)], 1).to(device)

            pose = torch.tensor(np.array(all_pose[pose_key]).astype(np.float32)).float().unsqueeze(0).to(device)

            # Load and resize expression image
            exp_img = np.array(Image.open(exp_each_dir_img).resize((512, 512)))

            # Load expression coefficients
            exp_coeff = torch.from_numpy(np.load(exp_each_dir).astype(np.float32)).to(device).float().unsqueeze(0)
            exp_target = Faceverse.make_driven_rendering(exp_coeff, res=256)

            # Load motion data
            motion = torch.tensor(np.load(motion_each_dir)).float().unsqueeze(0).to(device)

            img_ref_double = duplicate_batch(img_ref, batch_size=2)
            motion_app_double = duplicate_batch(motion_app, batch_size=2)
            motion_double = duplicate_batch(motion, batch_size=2)
            pose_double = torch.cat([pose_show, pose], dim=0)
            exp_target_double = duplicate_batch(exp_target, batch_size=2)
            samples_double = duplicate_batch(samples, batch_size=2)
            # Select refine_net processing method
            final_out = render_model(
                img_ref_double, None, motion_app_double, motion_double, c=pose_double, mesh=exp_target_double,
                triplane_recon=samples_double,
                ws_avg=ws_avg, motion_scale=1.
            )

            # Process output image
            final_out_show = trans(final_out['image_sr'][0].unsqueeze(0))
            final_out_notshow = trans(final_out['image_sr'][1].unsqueeze(0))
            depth = final_out['image_depth'][0].unsqueeze(0)
            depth = -depth
            depth = (depth - depth.min()) / (depth.max() - depth.min()) * 2 - 1
            depth = trans(depth)

            depth = np.repeat(depth[:, :, :], 3, axis=2)
            # Save output images
            frame_name = f'{str(frame_index).zfill(4)}.png'
            Image.fromarray(depth, 'RGB').save(os.path.join(save_frames_path_depth, frame_name))
            Image.fromarray(final_out_notshow, 'RGB').save(os.path.join(save_frames_path_out, frame_name))

            Image.fromarray(final_out_show, 'RGB').save(os.path.join(save_frames_path_outshow, frame_name))

        # Generate videos
        images_to_video(save_frames_path_out, os.path.join(save_path_base, image_name + '_out.mp4'))
        images_to_video(save_frames_path_outshow, os.path.join(save_path_base, image_name + '_outshow.mp4'))
        images_to_video(save_frames_path_depth, os.path.join(save_path_base, image_name + '_depth.mp4'))

        logging.info("✅ Video generation completed successfully!")
        return os.path.join(save_path_base, image_name + '_out.mp4'), os.path.join(save_path_base,
                                                                                   image_name + '_outshow.mp4'), os.path.join(save_path_base, image_name + '_depth.mp4')

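# Illustrative invocation (hypothetical variable names; in this app the function runs
# as the "Generate Avatar" callback with the image path, working directory and target
# video supplied by the Gradio UI):
#   out_mp4, show_mp4, depth_mp4 = avatar_generation(
#       processed_image_path, working_dir, target_video_path, 'example', False, None)

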
def get_image_base64(path):
    with open(path, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode()
    return f"data:image/png;base64,{encoded_string}"


def assert_input_image(input_image):
    if input_image is None:
        raise gr.Error("No image selected or uploaded!")


def process_image(input_image, source_type, is_style, save_dir):
    """🎯 Process input_image; example images and user uploads take different paths."""
    process_img_input_dir = os.path.join(save_dir, 'input_image')
    process_img_save_dir = os.path.join(save_dir, 'processed_img')
    os.makedirs(process_img_save_dir, exist_ok=True)
    os.makedirs(process_img_input_dir, exist_ok=True)
    if source_type == "example":
        # Example images are already aligned; pass them through unchanged.
        return input_image, source_type
    else:
        # input_process_model.inference(input_image, process_img_save_dir)
        shutil.copy(input_image, process_img_input_dir)
        input_process_model.inference(process_img_input_dir, process_img_save_dir, is_img=True, is_video=False)
        img_name = os.path.basename(input_image)
        imge_dir = os.path.join(save_dir, 'processed_img/dataset/images512x512/input_image', img_name)
        return imge_dir, source_type  # For custom uploads, return the aligned 512x512 crop produced above


def style_transfer(processed_image, style_prompt, cfg, strength, save_base):
    """🎭 Restyle the processed image with the ControlNet img2img (SDEdit) pipeline."""
    src_img_pil = Image.open(processed_image)
    img_name = os.path.basename(processed_image)
    save_dir = os.path.join(save_base, 'style_img')
    os.makedirs(save_dir, exist_ok=True)
    control_image = generate_annotation(src_img_pil, max_faces=1)
    trg_img_pil = pipeline_sd(
        prompt=style_prompt,
        image=src_img_pil,
        strength=strength,
        control_image=Image.fromarray(control_image),
        guidance_scale=cfg,
        negative_prompt='worst quality, normal quality, low quality, low res, blurry',
        num_inference_steps=30,
        controlnet_conditioning_scale=1.5
    )['images'][0]
    trg_img_pil.save(os.path.join(save_dir, img_name))
    return os.path.join(save_dir, img_name)


def reset_flag():
    return False

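# Note: `pipeline_sd` used in style_transfer above is a module-level
# StableDiffusionControlNetImg2ImgPipeline (see the diffusers imports at the top);
# its construction presumably lives in the part of app.py below this excerpt.
# A minimal sketch of such a setup, with placeholder model IDs, would be:
#   controlnet = ControlNetModel.from_pretrained("<landmark-controlnet-id>", torch_dtype=torch.float16)
#   pipeline_sd = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
#       "<sd-base-model-id>", controlnet=controlnet, torch_dtype=torch.float16).to("cuda")
#   pipeline_sd.scheduler = DPMSolverMultistepScheduler.from_config(pipeline_sd.scheduler.config)

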
css = """
/* ✅ Center all Image components and let them scale with the layout */
.gr-image img {
    display: block;
    margin-left: auto;
    margin-right: auto;
    max-width: 100%;
    height: auto;
}

/* ✅ Center all Video components and let them scale with the layout */
.gr-video video {
    display: block;
    margin-left: auto;
    margin-right: auto;
    max-width: 100%;
    height: auto;
}

/* ✅ Optional: center the buttons and markdown in the generate block */
#generate_block {
    display: flex;
    flex-direction: column;
    align-items: center;
    justify-content: center;
    margin-top: 1rem;
}


/* Optional: widen the main container a little */
#main_container {
    max-width: 1280px;  /* ✅ e.g. cap the width at 1280px */
    margin-left: auto;  /* ✅ center horizontally */
    margin-right: auto;
    padding-left: 1rem;
    padding-right: 1rem;
}

"""

def launch_gradio_app():
    styles = {
        "Ghibli": "Ghibli style avatar, anime style",
        "Pixar": "a 3D render of a face in Pixar style",
        "Lego": "a 3D render of a head of a lego man 3D model",
        "Greek Statue": "a FHD photo of a white Greek statue",
        "Elf": "a FHD photo of a face of a beautiful elf with silver hair in live action movie",
        "Zombie": "a FHD photo of a face of a zombie",
        "Tekken": "a 3D render of a Tekken game character",
        "Devil": "a FHD photo of a face of a devil in fantasy movie",
        "Steampunk": "Steampunk style portrait, mechanical, brass and copper tones",
        "Mario": "a 3D render of a face of Super Mario",
        "Orc": "a FHD photo of a face of an orc in fantasy movie",
        "Masque": "a FHD photo of a face of a person in masquerade",
        "Skeleton": "a FHD photo of a face of a skeleton in fantasy movie",
        "Peking Opera": "a FHD photo of face of character in Peking opera with heavy make-up",
        "Yoda": "a FHD photo of a face of Yoda in Star Wars",
        "Hobbit": "a FHD photo of a face of Hobbit in Lord of the Rings",
        "Stained Glass": "Stained glass style, portrait, beautiful, translucent",
        "Graffiti": "Graffiti style portrait, street art, vibrant, urban, detailed, tag",
        "Pixel-art": "pixel art style portrait, low res, blocky, pixel art style",
        "Retro": "Retro game art style portrait, vibrant colors",
        "Ink": "a portrait in ink style, black and white image",
    }

    with gr.Blocks(analytics_enabled=False, delete_cache=[3600, 3600], css=css, elem_id="main_container") as demo:
        logo_url = "./docs/AvatarArtist.png"
        logo_base64 = get_image_base64(logo_url)
        # 🚀 Center the logo and align the title next to it
        gr.HTML(
            f"""
            <div style="display: flex; justify-content: center; align-items: center; text-align: center; margin-bottom: 20px;">
                <img src="{logo_base64}" style="height:50px; margin-right: 15px; display: block;" onerror="this.style.display='none'"/>
                <h1 style="font-size: 32px; font-weight: bold;">AvatarArtist: Open-Domain 4D Avatarization</h1>
            </div>
            """
        )

        # 🚀 Keep the badge links aligned on one row
        gr.HTML(
            """
            <div style="display: flex; justify-content: center; gap: 10px; margin-top: 10px;">
                <a title="Website" href="https://kumapowerliu.github.io/AvatarArtist/" target="_blank" rel="noopener noreferrer">
                    <img src="https://img.shields.io/badge/Website-Visit-blue?style=for-the-badge&logo=GoogleChrome">
                </a>
                <a title="arXiv" href="https://arxiv.org/abs/2503.19906" target="_blank" rel="noopener noreferrer">
                    <img src="https://img.shields.io/badge/arXiv-Paper-red?style=for-the-badge&logo=arXiv">
                </a>
                <a title="Github" href="https://github.com/ant-research/AvatarArtist" target="_blank" rel="noopener noreferrer">
                    <img src="https://img.shields.io/github/stars/ant-research/AvatarArtist?style=for-the-badge&logo=github&logoColor=white&color=orange">
                </a>
            </div>
            """
        )
        gr.HTML(
            """
            <div style="text-align: left; font-size: 16px; line-height: 1.6; margin-top: 20px; padding: 10px; border: 1px solid #ddd; border-radius: 8px; background-color: #f9f9f9;">
                <strong>🧑🎨 How to use this demo:</strong>
                <ol style="margin-top: 10px; padding-left: 20px;">
                    <li><strong>Select or upload a source image</strong> – this will be the avatar's face.</li>
                    <li><strong>Select or upload a target video</strong> – the avatar will mimic this motion.</li>
                    <li><strong>Click the <em>Process Image</em> button</strong> – this prepares the source image to meet our model's input requirements.</li>
                    <li><strong>(Optional)</strong> Click <em>Apply Style</em> to change the appearance of the processed image – we offer a variety of fun styles to choose from!</li>
                    <li><strong>Click <em>Generate Avatar</em></strong> to create the final animated result driven by the target video.</li>
                </ol>
                <p style="margin-top: 10px;"><strong>🎨 Tip:</strong> Try different styles to get various artistic effects for your avatar!</p>
            </div>
            """
        )
        # 🚀 Important-notes box
        gr.HTML(
            """
            <div style="background-color: #FFDDDD; padding: 15px; border-radius: 10px; border: 2px solid red; text-align: center; margin-top: 20px;">
                <h4 style="color: red; font-size: 18px;">
                    🚨 <strong>Important Notes:</strong> Please try to provide a <u>front-facing</u> or <u>full-face</u> image without obstructions.
                </h4>
                <p style="color: black; font-size: 16px;">
                    ❌ Our demo does <strong>not</strong> support uploading videos with specific motions because processing requires time.<br>
                    ✅ Feel free to check out our <a href="https://github.com/ant-research/AvatarArtist" target="_blank" style="color: red; font-weight: bold;">GitHub repository</a> to drive portraits using your desired motions.
                </p>
            </div>
            """
        )
        # DISPLAY
        image_folder = "./demo_data/source_img/img_generate_different_domain/images512x512/trained_input_imgs"
        video_folder = "./demo_data/target_video"

        examples_images = sorted(
            [os.path.join(image_folder, f) for f in os.listdir(image_folder) if
             f.lower().endswith(('.png', '.jpg', '.jpeg'))]
        )
        examples_videos = sorted(
            [os.path.join(video_folder, f) for f in os.listdir(video_folder) if f.lower().endswith('.mp4')]
        )
        print(examples_videos)
        source_type = gr.State("example")
        is_from_example = gr.State(value=True)
        is_styled = gr.State(value=False)
        working_dir = gr.State()

        with gr.Row():
            with gr.Column(variant='panel'):
                with gr.Tabs(elem_id="input_image"):
                    with gr.TabItem('🎨 Upload Image'):
                        input_image = gr.Image(
                            label="Upload Source Image",
                            value=os.path.join(image_folder, '02057_(2).png'),
                            image_mode="RGB", height=512, container=True,
                            sources="upload", type="filepath"
                        )

                def mark_as_example(example_image):
                    print("✅ mark_as_example called")
                    return "example", True, False

                def mark_as_custom(user_image, is_from_example_flag):
                    print("✅ mark_as_custom called")
                    if is_from_example_flag:
                        print("⚠️ Ignored mark_as_custom triggered by example")
                        return "example", False, False
                    return "custom", False, False

                input_image.change(
                    mark_as_custom,
                    inputs=[input_image, is_from_example],
                    outputs=[source_type, is_from_example, is_styled]  # ✅ Only update the state outputs; do not feed input_image back
                )

                # ✅ Give the `Examples` component its own row and bind its click event
                with gr.Row():
                    example_component = gr.Examples(
                        examples=examples_images,
                        inputs=[input_image],
                        examples_per_page=10,
                    )
                    # ✅ Listen for clicks on the `Examples` dataset
                    example_component.dataset.click(
                        fn=mark_as_example,
                        inputs=[input_image],
                        outputs=[source_type, is_from_example, is_styled]
                    )

            with gr.Column(variant='panel'):
                with gr.Tabs(elem_id="input_video"):
                    with gr.TabItem('🎬 Target Video'):
                        video_input = gr.Video(
                            label="Select Target Motion",
                            height=512, container=True, interactive=False, format="mp4",
                            value=examples_videos[0]
                        )

                with gr.Row():
                    gr.Examples(
                        examples=examples_videos,
                        inputs=[video_input],
                        examples_per_page=10,
                    )
            with gr.Column(variant='panel'):
                with gr.Tabs(elem_id="processed_image"):
                    with gr.TabItem('🖼️ Processed Image'):
                        processed_image = gr.Image(
                            label="Processed Image",
                            image_mode="RGB", type="filepath",
                            elem_id="processed_image",
                            height=512, container=True,
                            interactive=False
                        )
                        processed_image_button = gr.Button("🔧 Process Image", variant="primary")
            with gr.Column(variant='panel'):
                with gr.Tabs(elem_id="style_transfer"):
                    with gr.TabItem('🎭 Style Transfer'):
                        style_image = gr.Image(
                            label="Style Image",
                            image_mode="RGB", type="filepath",
                            elem_id="style_image",
                            height=512, container=True,
                            interactive=False
                        )
                        style_choice = gr.Dropdown(
                            choices=list(styles.keys()),
                            label="Choose Style",
                            value="Pixar"
                        )
                        cfg_slider = gr.Slider(
                            minimum=3.0, maximum=10.0, value=7.5, step=0.1,
                            label="CFG Scale"
                        )
                        strength_slider = gr.Slider(
                            minimum=0.4, maximum=0.85, value=0.65, step=0.05,
                            label="SDEdit Strength"
                        )
                        style_button = gr.Button("🎨 Apply Style", interactive=False)
|
815 |
+
gr.Markdown(
|
816 |
+
"⬅️ Please click **Process Image** first. "
|
817 |
+
"**Apply Style** will transform the image in the **Processed Image** panel "
|
818 |
+
"according to the selected style."
|
819 |
+
)
|
820 |
+
|
821 |
+
|
822 |
+
with gr.Row():
|
823 |
+
with gr.Tabs(elem_id="render_output"):
|
824 |
+
with gr.TabItem('🎥 Animation Results'):
|
825 |
+
# ✅ 让 `Generate Avatar` 按钮单独占一行
|
826 |
+
with gr.Row():
|
827 |
+
with gr.Column(scale=1, elem_id="generate_block", min_width=200):
|
828 |
+
submit = gr.Button('🚀 Generate Avatar', elem_id="avatarartist_generate", variant='primary',
|
829 |
+
interactive=False)
|
830 |
+
gr.Markdown("⬇️ Please click **Process Image** first before generating.",
|
831 |
+
elem_id="generate_tip")
|
832 |
+
|
833 |
+
# ✅ 让两个 `Animation Results` 窗口并排
|
834 |
+
with gr.Row():
|
835 |
+
output_video = gr.Video(
|
836 |
+
label="Generated Animation Input Video View",
|
837 |
+
format="mp4", height=512, width=512,
|
838 |
+
autoplay=True
|
839 |
+
)
|
840 |
+
|
841 |
+
output_video_2 = gr.Video(
|
842 |
+
label="Generated Animation Rotate View",
|
843 |
+
format="mp4", height=512, width=512,
|
844 |
+
autoplay=True
|
845 |
+
)
|
846 |
+
|
847 |
+
output_video_3 = gr.Video(
|
848 |
+
label="Generated Animation Rotate View Depth",
|
849 |
+
format="mp4", height=512, width=512,
|
850 |
+
autoplay=True
|
851 |
+
)
|
852 |
+
def apply_style_and_mark(processed_image, style_choice, cfg, strength, working_dir):
|
853 |
+
styled = style_transfer(processed_image, styles[style_choice], cfg, strength, working_dir)
|
854 |
+
return styled, True
|
855 |
+
|
856 |
+
def process_image_and_enable_style(input_image, source_type, is_styled, wd):
|
857 |
+
processed_result, updated_source_type = process_image(input_image, source_type, is_styled, wd)
|
858 |
+
return processed_result, updated_source_type, gr.update(interactive=True), gr.update(interactive=True)
|
859 |
+
processed_image_button.click(
|
860 |
+
fn=prepare_working_dir,
|
861 |
+
inputs=[working_dir, is_styled],
|
862 |
+
outputs=[working_dir],
|
863 |
+
queue=False,
|
864 |
+
).success(
|
865 |
+
fn=process_image_and_enable_style,
|
866 |
+
inputs=[input_image, source_type, is_styled, working_dir],
|
867 |
+
outputs=[processed_image, source_type, style_button, submit],
|
868 |
+
queue=True
|
869 |
+
)
|
870 |
+
style_button.click(
|
871 |
+
fn=apply_style_and_mark,
|
872 |
+
inputs=[processed_image, style_choice, cfg_slider, strength_slider, working_dir],
|
873 |
+
outputs=[style_image, is_styled]
|
874 |
+
)
|
875 |
+
submit.click(
|
876 |
+
fn=avatar_generation,
|
877 |
+
inputs=[processed_image, working_dir, video_input, source_type, is_styled, style_image],
|
878 |
+
outputs=[output_video, output_video_2, output_video_3], # ⏳ 稍后展示视频
|
879 |
+
queue=True
|
880 |
+
)
|
881 |
+
|
882 |
+
|
883 |
+
demo.queue()
|
884 |
+
demo.launch(server_name="0.0.0.0")
|
885 |
+
|
886 |
+
|
887 |
+
if __name__ == '__main__':
|
888 |
+
import torch.multiprocessing as mp
|
889 |
+
|
890 |
+
mp.set_start_method('spawn', force=True)
|
891 |
+
image_folder = "./demo_data/source_img/img_generate_different_domain/images512x512/trained_input_imgs"
|
892 |
+
example_img_names = os.listdir(image_folder)
|
893 |
+
render_model, sample_steps, DiT_model, \
|
894 |
+
vae_triplane, image_encoder, dinov2, dino_img_processor, clip_image_processor, std, mean, ws_avg, Faceverse, device, input_process_model = model_define()
|
895 |
+
controlnet_path = '/nas8/liuhongyu/model/ControlNetMediaPipeFaceold'
|
896 |
+
controlnet = ControlNetModel.from_pretrained(
|
897 |
+
controlnet_path, torch_dtype=torch.float16
|
898 |
+
)
|
899 |
+
sd_path = '/nas8/liuhongyu/model/stable-diffusion-2-1-base'
|
900 |
+
pipeline_sd = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
|
901 |
+
sd_path, torch_dtype=torch.float16,
|
902 |
+
use_safetensors=True, controlnet=controlnet, variant="fp16"
|
903 |
+
).to(device)
|
904 |
+
demo_cam = False
|
905 |
+
launch_gradio_app()
|
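Note that `controlnet_path` and `sd_path` in the `__main__` block above point to local NAS directories (`/nas8/liuhongyu/model/...`) that will not exist on a deployed Space. Below is a minimal sketch of a more portable setup; the environment-variable names, the local ControlNet folder, and the `stabilityai/stable-diffusion-2-1-base` Hub id are illustrative assumptions, not paths defined by this repo.

import os
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetImg2ImgPipeline

# Hypothetical overrides: fall back to a downloaded folder / Hub id when the NAS paths are absent.
controlnet_path = os.environ.get("CONTROLNET_PATH", "./pretrained_model/ControlNetMediaPipeFace")
sd_path = os.environ.get("SD_PATH", "stabilityai/stable-diffusion-2-1-base")

controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
pipeline_sd = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
    sd_path, torch_dtype=torch.float16, controlnet=controlnet
).to("cuda" if torch.cuda.is_available() else "cpu")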
inference.py
ADDED
@@ -0,0 +1,476 @@
import os
import sys
import warnings
import logging
import argparse
import json
import random
import subprocess  # required by images_to_video() for the FFmpeg call
from datetime import datetime

import torch
import numpy as np
import cv2
import imageio_ffmpeg as ffmpeg  # required by images_to_video() for ffmpeg.get_ffmpeg_exe()
from PIL import Image
from tqdm import tqdm
from natsort import natsorted, ns
from einops import rearrange
from omegaconf import OmegaConf
from huggingface_hub import snapshot_download

from transformers import (
    Dinov2Model, CLIPImageProcessor, CLIPVisionModelWithProjection, AutoImageProcessor
)
from Next3d.training_avatar_texture.camera_utils import LookAtPoseSampler, FOV_to_intrinsics

from data_process.lib.FaceVerse.renderer import Faceverse_manager
import recon.dnnlib as dnnlib
import recon.legacy as legacy

from DiT_VAE.diffusion.utils.misc import read_config
from DiT_VAE.vae.triplane_vae import AutoencoderKL as AutoencoderKLTriplane
from DiT_VAE.diffusion import IDDPM, DPMS
from DiT_VAE.diffusion.model.nets import TriDitCLIPDINO_XL_2
from DiT_VAE.diffusion.data.datasets import get_chunks

# Get the directory of the current script
father_path = os.path.dirname(os.path.abspath(__file__))

# Add necessary paths dynamically
sys.path.extend([
    os.path.join(father_path, 'recon'),
    os.path.join(father_path, 'Next3d')
])

# Suppress warnings (especially for PyTorch)
warnings.filterwarnings("ignore")

# Configure logging settings
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s"
)


def get_args():
    """Parse and return command-line arguments."""
    parser = argparse.ArgumentParser(description="4D Triplane Generation Arguments")

    # Configuration and model checkpoints
    parser.add_argument("--config", type=str, default="./configs/infer_config.py",
                        help="Path to the configuration file.")

    # Input data paths
    parser.add_argument("--target_path", type=str, required=True, default='./demo_data/target_video/data_obama',
                        help="Base path of the dataset.")
    parser.add_argument("--img_file", type=str, required=True,
                        default='./demo_data/source_img/img_generate_different_domain/images512x512/demo_imgs',
                        help="Directory containing input images.")
    parser.add_argument("--input_img_motion", type=str,
                        default="./demo_data/source_img/img_generate_different_domain/motions/demo_imgs",
                        help="Directory containing motion features.")
    parser.add_argument("--video_name", type=str, required=True, default='Obama',
                        help="Name of the video.")
    parser.add_argument("--input_img_fvid", type=str,
                        default="./demo_data/source_img/img_generate_different_domain/coeffs/demo_imgs",
                        help="Path to input image coefficients.")

    # Output settings
    parser.add_argument("--output_basedir", type=str, default="./output",
                        help="Base directory for saving output results.")

    # Generation parameters
    parser.add_argument("--bs", type=int, default=1,
                        help="Batch size for processing.")
    parser.add_argument("--cfg_scale", type=float, default=4.5,
                        help="CFG scale parameter.")
    parser.add_argument("--sampling_algo", type=str, default="dpm-solver",
                        choices=["iddpm", "dpm-solver"],
                        help="Sampling algorithm to be used.")
    parser.add_argument("--seed", type=int, default=0,
                        help="Random seed for reproducibility.")
    parser.add_argument("--select_img", type=str, default=None,
                        help="Optional: Select a specific image.")
    parser.add_argument('--step', default=-1, type=int)
    parser.add_argument('--use_demo_cam', action='store_true', help="Enable predefined camera parameters")
    return parser.parse_args()


def set_env(seed=0):
    """Set random seed for reproducibility across multiple frameworks."""
    torch.manual_seed(seed)  # Set PyTorch seed
    torch.cuda.manual_seed_all(seed)  # If using multi-GPU
    np.random.seed(seed)  # Set NumPy seed
    random.seed(seed)  # Set Python built-in random module seed
    torch.set_grad_enabled(False)  # Disable gradients for inference


def to_rgb_image(image: Image.Image):
    """Convert an image to RGB format if necessary."""
    if image.mode == 'RGB':
        return image
    elif image.mode == 'RGBA':
        img = Image.new("RGB", image.size, (127, 127, 127))
        img.paste(image, mask=image.getchannel('A'))
        return img
    else:
        raise ValueError(f"Unsupported image type: {image.mode}")


def image_process(image_path):
    """Preprocess an image for CLIP and DINO models."""
    image = to_rgb_image(Image.open(image_path))
    clip_image = clip_image_processor(images=image, return_tensors="pt").pixel_values.to(device)
    dino_image = dino_img_processor(images=image, return_tensors="pt").pixel_values.to(device)
    return dino_image, clip_image


def video_gen(frames_dir, output_path, fps=30):
    """Generate a video from image frames."""
    frame_files = natsorted(os.listdir(frames_dir), alg=ns.PATH)
    frames = [cv2.imread(os.path.join(frames_dir, f)) for f in frame_files]
    H, W = frames[0].shape[:2]
    video_writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'MP4V'), fps, (W, H))
    for frame in frames:
        video_writer.write(frame)
    video_writer.release()


def trans(tensor_img):
    """Convert a [-1, 1] image tensor to a uint8 HWC numpy array."""
    img = (tensor_img.permute(0, 2, 3, 1) * 0.5 + 0.5).clamp(0, 1) * 255.
    img = img.to(torch.uint8)
    img = img[0].detach().cpu().numpy()

    return img


def get_vert(vert_dir):
    """Load UV-coordinate images and binarize the mask channel."""
    uvcoords_image = np.load(os.path.join(vert_dir))[..., :3]
    uvcoords_image[..., -1][uvcoords_image[..., -1] < 0.5] = 0
    uvcoords_image[..., -1][uvcoords_image[..., -1] >= 0.5] = 1
    return torch.tensor(uvcoords_image.copy()).float().unsqueeze(0)


def generate_samples(DiT_model, cfg_scale, sample_steps, clip_feature, dino_feature, uncond_clip_feature,
                     uncond_dino_feature, device, latent_size, sampling_algo):
    """
    Generate latent samples using the specified diffusion model.

    Args:
        DiT_model (torch.nn.Module): The diffusion model.
        cfg_scale (float): The classifier-free guidance scale.
        sample_steps (int): Number of sampling steps.
        clip_feature (torch.Tensor): CLIP feature tensor.
        dino_feature (torch.Tensor): DINO feature tensor.
        uncond_clip_feature (torch.Tensor): Unconditional CLIP feature tensor.
        uncond_dino_feature (torch.Tensor): Unconditional DINO feature tensor.
        device (str): Device for computation.
        latent_size (tuple): The latent space size.
        sampling_algo (str): The sampling algorithm ('iddpm' or 'dpm-solver').

    Returns:
        torch.Tensor: The generated samples.
    """
    n = 1  # Batch size
    z = torch.randn(n, 8, latent_size[0], latent_size[1], device=device)

    if sampling_algo == 'iddpm':
        z = z.repeat(2, 1, 1, 1)  # Duplicate for classifier-free guidance
        model_kwargs = dict(y=torch.cat([clip_feature, uncond_clip_feature]),
                            img_feature=torch.cat([dino_feature, dino_feature]),
                            cfg_scale=cfg_scale)
        diffusion = IDDPM(str(sample_steps))
        samples = diffusion.p_sample_loop(DiT_model.forward_with_cfg, z.shape, z, clip_denoised=False,
                                          model_kwargs=model_kwargs, progress=True, device=device)
        samples, _ = samples.chunk(2, dim=0)  # Remove unconditional samples

    elif sampling_algo == 'dpm-solver':
        dpm_solver = DPMS(DiT_model.forward_with_dpmsolver,
                          condition=[clip_feature, dino_feature],
                          uncondition=[uncond_clip_feature, dino_feature],
                          cfg_scale=cfg_scale)
        samples = dpm_solver.sample(z, steps=sample_steps, order=2, skip_type="time_uniform", method="multistep")
    else:
        raise ValueError(f"Invalid sampling_algo '{sampling_algo}'. Choose either 'iddpm' or 'dpm-solver'.")

    return samples


def images_to_video(image_folder, output_video, fps=30):
    # Get all image files and ensure correct order
    images = [img for img in os.listdir(image_folder) if img.endswith((".png", ".jpg", ".jpeg"))]
    images = natsorted(images)  # Sort filenames naturally to preserve frame order

    if not images:
        print("❌ No images found in the directory!")
        return

    # Get the path to the FFmpeg executable
    ffmpeg_exe = ffmpeg.get_ffmpeg_exe()
    print(f"Using FFmpeg from: {ffmpeg_exe}")

    # Define input image pattern (expects images named like "%04d.png")
    image_pattern = os.path.join(image_folder, "%04d.png")

    # FFmpeg command to encode video
    command = [
        ffmpeg_exe, '-framerate', str(fps), '-i', image_pattern,
        '-c:v', 'libx264', '-preset', 'slow', '-crf', '18',  # High-quality H.264 encoding
        '-pix_fmt', 'yuv420p', '-b:v', '5000k',  # Ensure compatibility & increase bitrate
        output_video
    ]

    # Run FFmpeg command
    subprocess.run(command, check=True)

    print(f"✅ High-quality MP4 video has been generated: {output_video}")


@torch.inference_mode()
def avatar_generation(items, bs, sample_steps, cfg_scale, save_path_base, DiT_model, render_model, std, mean, ws_avg,
                      Faceverse, pitch_range=0.25, yaw_range=0.35, demo_cam=False):
    """
    Generate avatars from input images.

    Args:
        items (list): List of image paths.
        bs (int): Batch size.
        sample_steps (int): Number of sampling steps.
        cfg_scale (float): Classifier-free guidance scale.
        save_path_base (str): Base directory for saving results.
        DiT_model (torch.nn.Module): The diffusion model.
        render_model (torch.nn.Module): The rendering model.
        std (torch.Tensor): Standard deviation normalization tensor.
        mean (torch.Tensor): Mean normalization tensor.
        ws_avg (torch.Tensor): Latent average tensor.
    """
    for chunk in tqdm(list(get_chunks(items, bs)), unit='batch'):
        if bs != 1:
            raise ValueError("Batch size > 1 not implemented")

        image_dir = chunk[0]
        image_name = os.path.splitext(os.path.basename(image_dir))[0]
        dino_img, clip_image = image_process(image_dir)

        clip_feature = image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
        uncond_clip_feature = image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[-2]
        dino_feature = dinov2(dino_img).last_hidden_state
        uncond_dino_feature = dinov2(torch.zeros_like(dino_img)).last_hidden_state

        samples = generate_samples(DiT_model, cfg_scale, sample_steps, clip_feature, dino_feature,
                                   uncond_clip_feature, uncond_dino_feature, device, latent_size,
                                   args.sampling_algo)

        samples = (samples / default_config.scale_factor)
        samples = rearrange(samples, "b c (f h) w -> b c f h w", f=4)
        samples = vae_triplane.decode(samples)
        samples = rearrange(samples, "b c f h w -> b f c h w")
        samples = samples * std + mean
        torch.cuda.empty_cache()

        save_frames_path_combine = os.path.join(save_path_base, image_name, 'combine')
        save_frames_path_out = os.path.join(save_path_base, image_name, 'out')
        os.makedirs(save_frames_path_combine, exist_ok=True)
        os.makedirs(save_frames_path_out, exist_ok=True)

        img_ref = np.array(Image.open(image_dir))
        img_ref_out = img_ref.copy()
        img_ref = torch.from_numpy(img_ref.astype(np.float32) / 127.5 - 1).permute(2, 0, 1).unsqueeze(0).to(device)

        motion_app_dir = os.path.join(args.input_img_motion, image_name + '.npy')
        motion_app = torch.tensor(np.load(motion_app_dir), dtype=torch.float32).unsqueeze(0).to(device)

        id_motions = os.path.join(args.input_img_fvid, image_name + '.npy')

        all_pose = json.loads(open(label_file_test).read())['labels']
        all_pose = dict(all_pose)
        if os.path.exists(id_motions):
            coeff = np.load(id_motions).astype(np.float32)
            coeff = torch.from_numpy(coeff).to(device).float().unsqueeze(0)
            Faceverse.id_coeff = Faceverse.recon_model.split_coeffs(coeff)[0]
        motion_dir = os.path.join(motion_base_dir, args.video_name)
        exp_dir = os.path.join(exp_base_dir, args.video_name)
        for frame_index, motion_name in enumerate(
                tqdm(natsorted(os.listdir(motion_dir), alg=ns.PATH), desc="Processing Frames")):
            exp_each_dir_img = os.path.join(exp_img_base_dir, args.video_name, motion_name.replace('.npy', '.png'))
            exp_each_dir = os.path.join(exp_dir, motion_name)
            motion_each_dir = os.path.join(motion_dir, motion_name)

            # Load pose data
            pose_key = os.path.join(args.video_name, motion_name.replace('.npy', '.png'))
            if demo_cam:
                cam2world_pose = LookAtPoseSampler.sample(
                    3.14 / 2 + yaw_range * np.sin(2 * 3.14 * frame_index / len(os.listdir(motion_dir))),
                    3.14 / 2 - 0.05 + pitch_range * np.cos(2 * 3.14 * frame_index / len(os.listdir(motion_dir))),
                    torch.tensor([0, 0, 0], device=device), radius=2.7, device=device)
                pose = torch.cat([cam2world_pose.reshape(-1, 16),
                                  FOV_to_intrinsics(fov_degrees=18.837, device=device).reshape(-1, 9)], 1).to(device)
            else:
                pose = torch.tensor(np.array(all_pose[pose_key]).astype(np.float32)).float().unsqueeze(0).to(device)

            # Load and resize expression image
            exp_img = np.array(Image.open(exp_each_dir_img).resize((512, 512)))

            # Load expression coefficients
            exp_coeff = torch.from_numpy(np.load(exp_each_dir).astype(np.float32)).to(device).float().unsqueeze(0)
            exp_target = Faceverse.make_driven_rendering(exp_coeff, res=256)

            # Load motion data
            motion = torch.tensor(np.load(motion_each_dir)).float().unsqueeze(0).to(device)

            # Select refine_net processing method
            final_out = render_model(
                img_ref, None, motion_app, motion, c=pose, mesh=exp_target, triplane_recon=samples,
                ws_avg=ws_avg, motion_scale=1.
            )

            # Process output image
            final_out = trans(final_out['image_sr'])
            output_img_combine = np.hstack((img_ref_out, exp_img, final_out))

            # Save output images
            frame_name = f'{str(frame_index).zfill(4)}.png'
            Image.fromarray(output_img_combine, 'RGB').save(os.path.join(save_frames_path_combine, frame_name))
            Image.fromarray(final_out, 'RGB').save(os.path.join(save_frames_path_out, frame_name))

        # Generate videos
        images_to_video(save_frames_path_combine, os.path.join(save_path_base, image_name + '_combine.mp4'))
        images_to_video(save_frames_path_out, os.path.join(save_path_base, image_name + '_out.mp4'))
        logging.info("✅ Video generation completed successfully!")
        logging.info(f"📂 Combined video saved at: {os.path.join(save_path_base, image_name + '_combine.mp4')}")
        logging.info(f"📂 Output video saved at: {os.path.join(save_path_base, image_name + '_out.mp4')}")


def load_motion_aware_render_model(ckpt_path):
    """Load the motion-aware render model from a checkpoint."""
    logging.info("Loading motion-aware render model...")
    with dnnlib.util.open_url(ckpt_path, 'rb') as f:
        network = legacy.load_network_pkl(f)  # type: ignore
    logging.info("Motion-aware render model loaded.")
    return network['G_ema'].to(device)


def load_diffusion_model(ckpt_path, latent_size):
    """Load the diffusion model (DiT)."""
    logging.info("Loading diffusion model (DiT)...")

    DiT_model = TriDitCLIPDINO_XL_2(input_size=latent_size).to(device)
    ckpt = torch.load(ckpt_path, map_location="cpu")

    # Remove keys that can cause mismatches
    for key in ['pos_embed', 'base_model.pos_embed', 'model.pos_embed']:
        ckpt['state_dict'].pop(key, None)
        ckpt.get('state_dict_ema', {}).pop(key, None)

    state_dict = ckpt.get('state_dict_ema', ckpt)
    DiT_model.load_state_dict(state_dict, strict=False)
    DiT_model.eval()
    logging.info("Diffusion model (DiT) loaded.")
    return DiT_model


def load_vae_clip_dino(config, device):
    """Load VAE, CLIP, and DINO models."""
    logging.info("Loading VAE, CLIP, and DINO models...")

    # Load CLIP image encoder
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        config.image_encoder_path)
    image_encoder.requires_grad_(False)
    image_encoder.to(device)

    # Load VAE
    config_vae = OmegaConf.load(config.vae_triplane_config_path)
    vae_triplane = AutoencoderKLTriplane(ddconfig=config_vae['ddconfig'], lossconfig=None, embed_dim=8)
    vae_triplane.to(device)

    vae_ckpt_path = os.path.join(config.vae_pretrained, 'pytorch_model.bin')
    if not os.path.isfile(vae_ckpt_path):
        raise RuntimeError(f"VAE checkpoint not found at {vae_ckpt_path}")

    vae_triplane.load_state_dict(torch.load(vae_ckpt_path, map_location="cpu"))
    vae_triplane.requires_grad_(False)

    # Load DINO model
    dinov2 = Dinov2Model.from_pretrained(config.dino_pretrained)
    dinov2.requires_grad_(False)
    dinov2.to(device)

    # Load image processors
    dino_img_processor = AutoImageProcessor.from_pretrained(config.dino_pretrained)
    clip_image_processor = CLIPImageProcessor()

    logging.info("VAE, CLIP, and DINO models loaded.")
    return vae_triplane, image_encoder, dinov2, dino_img_processor, clip_image_processor


def prepare_image_list(img_dir, selected_img):
    """Prepare the list of image paths for processing."""
    if selected_img and selected_img in os.listdir(img_dir):
        return [os.path.join(img_dir, selected_img)]

    return sorted([os.path.join(img_dir, img) for img in os.listdir(img_dir)])


if __name__ == '__main__':

    model_path = "./pretrained_model"
    if not os.path.exists(model_path):
        logging.info("📥 Model not found. Downloading from Hugging Face...")
        snapshot_download(repo_id="KumaPower/AvatarArtist", local_dir=model_path)
        logging.info("✅ Model downloaded successfully!")
    else:
        logging.info("🎉 Pretrained model already exists. Skipping download.")

    args = get_args()
    exp_base_dir = os.path.join(args.target_path, 'coeffs')
    exp_img_base_dir = os.path.join(args.target_path, 'images512x512')
    motion_base_dir = os.path.join(args.target_path, 'motions')
    label_file_test = os.path.join(args.target_path, 'images512x512/dataset_realcam.json')
    set_env(args.seed)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    weight_dtype = torch.float32
    logging.info(f"Running inference with {weight_dtype}")

    # Load configuration
    default_config = read_config(args.config)

    # Ensure valid sampling algorithm
    assert args.sampling_algo in ['iddpm', 'dpm-solver', 'sa-solver']

    # Prepare image list
    items = prepare_image_list(args.img_file, args.select_img)
    logging.info(f"Input images: {items}")

    # Load motion-aware render model
    motion_aware_render_model = load_motion_aware_render_model(default_config.motion_aware_render_model_ckpt)

    # Load diffusion model (DiT)
    triplane_size = (256 * 4, 256)
    latent_size = (triplane_size[0] // 8, triplane_size[1] // 8)
    sample_steps = args.step if args.step != -1 else {'iddpm': 100, 'dpm-solver': 20, 'sa-solver': 25}[args.sampling_algo]
    DiT_model = load_diffusion_model(default_config.DiT_model_ckpt, latent_size)

    # Load VAE, CLIP, and DINO
    vae_triplane, image_encoder, dinov2, dino_img_processor, clip_image_processor = load_vae_clip_dino(default_config,
                                                                                                       device)

    # Load normalization parameters
    triplane_std = torch.load(default_config.std_dir).to(device).reshape(1, -1, 1, 1, 1)
    triplane_mean = torch.load(default_config.mean_dir).to(device).reshape(1, -1, 1, 1, 1)

    # Load average latent vector
    ws_avg = torch.load(default_config.ws_avg_pkl).to(device)[0]

    # Set up save directory
    save_root = os.path.join(args.output_basedir, f'{datetime.now().date()}', args.video_name)
    os.makedirs(save_root, exist_ok=True)

    # Set up FaceVerse for animation
    base_coff = np.load('pretrained_model/temp.npy').astype(np.float32)
    base_coff = torch.from_numpy(base_coff).float()
    Faceverse = Faceverse_manager(device=device, base_coeff=base_coff)

    # Run avatar generation
    avatar_generation(items, args.bs, sample_steps, args.cfg_scale, save_root, DiT_model, motion_aware_render_model,
                      triplane_std, triplane_mean, ws_avg, Faceverse, demo_cam=args.use_demo_cam)
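For reference, `inference.py` is driven by the flags defined in `get_args()` above. The sketch below simply mirrors the argparse defaults shown in that function; the demo paths exist only after the demo data has been downloaded.

import subprocess

subprocess.run([
    "python", "inference.py",
    "--target_path", "./demo_data/target_video/data_obama",
    "--img_file", "./demo_data/source_img/img_generate_different_domain/images512x512/demo_imgs",
    "--video_name", "Obama",
    "--sampling_algo", "dpm-solver",
    "--cfg_scale", "4.5",
], check=True)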
requirements.txt
ADDED
@@ -0,0 +1,36 @@
transformers==4.34.1
scipy==1.13.1
scikit-image==0.21.0
scikit-learn==1.2.2
opencv-python==4.10.0.84
Pillow==10.4.0
termcolor==1.1.0
PyYAML==6.0.2
tqdm==4.67.1
absl-py==2.1.0
tensorboard==2.17.0
tensorboardX==2.6.1
PyOpenGL==3.1.0
pyrender==0.1.45
trimesh==3.22.0
click==8.1.7
omegaconf==2.2.3
segmentation_models_pytorch
timm
psutil==5.9.5
lmdb==1.4.1
einops==0.8.1
kornia==0.6.7
gdown==5.2.0
plyfile==1.0.3
natsort
mmcv==1.7.0
xformers==0.0.22.post3
https://download.pytorch.org/whl/cu121/torch-2.1.1%2Bcu121-cp39-cp39-linux_x86_64.whl
https://download.pytorch.org/whl/cu121/torchvision-0.16.1%2Bcu121-cp39-cp39-linux_x86_64.whl
git+https://github.com/facebookresearch/[email protected]
matplotlib==3.9.1.post1
mediapipe
open_clip
imageio_ffmpeg
spaces
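After installing the pins above, a quick sanity check that the CUDA 12.1 wheels and xformers resolved correctly (a minimal sketch, assuming the Python 3.9 environment implied by the cp39 wheels):

import torch
import xformers

print(torch.__version__)          # expected: 2.1.1+cu121
print(torch.cuda.is_available())  # True on a GPU machine
print(xformers.__version__)       # expected: 0.0.22.post3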