Spaces:
Sleeping
Sleeping
Upload marlenezw/audio-driven-animations/MakeItTalk with huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +3 -0
- marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/.gitignore +8 -0
- marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/LICENSE +201 -0
- marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/README.md +82 -0
- marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/__init__.py +0 -0
- marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/__pycache__/__init__.cpython-37.pyc +0 -0
- marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/__pycache__/__init__.cpython-39.pyc +0 -0
- marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/ckpt/.gitkeep +0 -0
- marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/__init__.py +0 -0
- marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/__pycache__/__init__.cpython-37.pyc +0 -0
- marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/__pycache__/__init__.cpython-39.pyc +0 -0
- marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/__pycache__/coord_conv.cpython-37.pyc +0 -0
- marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/__pycache__/coord_conv.cpython-39.pyc +0 -0
- marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/__pycache__/models.cpython-37.pyc +0 -0
- marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/__pycache__/models.cpython-39.pyc +0 -0
- marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/coord_conv.py +157 -0
- marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/dataloader.py +368 -0
- marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/evaler.py +151 -0
- marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/models.py +228 -0
- marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/eval.py +77 -0
- marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/images/wflw.png +3 -0
- marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/images/wflw_table.png +3 -0
- marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/requirements.txt +12 -0
- marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/scripts/eval_wflw.sh +10 -0
- marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/utils/__init__.py +0 -0
- marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/utils/__pycache__/__init__.cpython-37.pyc +0 -0
- marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/utils/__pycache__/__init__.cpython-39.pyc +0 -0
- marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/utils/__pycache__/utils.cpython-37.pyc +0 -0
- marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/utils/__pycache__/utils.cpython-39.pyc +0 -0
- marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/utils/utils.py +354 -0
- marlenezw/audio-driven-animations/MakeItTalk/__init__.py +0 -0
- marlenezw/audio-driven-animations/MakeItTalk/__pycache__/__init__.cpython-37.pyc +0 -0
- marlenezw/audio-driven-animations/MakeItTalk/__pycache__/__init__.cpython-39.pyc +0 -0
- marlenezw/audio-driven-animations/MakeItTalk/face_of_art/CODEOWNERS +1 -0
- marlenezw/audio-driven-animations/MakeItTalk/face_of_art/LICENCE.txt +21 -0
- marlenezw/audio-driven-animations/MakeItTalk/face_of_art/README.md +98 -0
- marlenezw/audio-driven-animations/MakeItTalk/face_of_art/__init__.py +0 -0
- marlenezw/audio-driven-animations/MakeItTalk/face_of_art/__init__.pyc +0 -0
- marlenezw/audio-driven-animations/MakeItTalk/face_of_art/__pycache__/__init__.cpython-36.pyc +0 -0
- marlenezw/audio-driven-animations/MakeItTalk/face_of_art/__pycache__/data_loading_functions.cpython-36.pyc +0 -0
- marlenezw/audio-driven-animations/MakeItTalk/face_of_art/__pycache__/deep_heatmaps_model_fusion_net.cpython-36.pyc +0 -0
- marlenezw/audio-driven-animations/MakeItTalk/face_of_art/__pycache__/deformation_functions.cpython-36.pyc +0 -0
- marlenezw/audio-driven-animations/MakeItTalk/face_of_art/__pycache__/logging_functions.cpython-36.pyc +0 -0
- marlenezw/audio-driven-animations/MakeItTalk/face_of_art/__pycache__/menpo_functions.cpython-36.pyc +0 -0
- marlenezw/audio-driven-animations/MakeItTalk/face_of_art/__pycache__/ops.cpython-36.pyc +0 -0
- marlenezw/audio-driven-animations/MakeItTalk/face_of_art/__pycache__/pdm_clm_functions.cpython-36.pyc +0 -0
- marlenezw/audio-driven-animations/MakeItTalk/face_of_art/crop_training_set.py +38 -0
- marlenezw/audio-driven-animations/MakeItTalk/face_of_art/data_loading_functions.py +161 -0
- marlenezw/audio-driven-animations/MakeItTalk/face_of_art/data_loading_functions.pyc +0 -0
- marlenezw/audio-driven-animations/MakeItTalk/face_of_art/deep_heatmaps_model_fusion_net.py +872 -0
.gitattributes
CHANGED
@@ -34,3 +34,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
marlenezw/audio-driven-animations/MakeItTalk/examples/ckpt filter=lfs diff=lfs merge=lfs -text
|
36 |
MakeItTalk/examples/ckpt filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
marlenezw/audio-driven-animations/MakeItTalk/examples/ckpt filter=lfs diff=lfs merge=lfs -text
|
36 |
MakeItTalk/examples/ckpt filter=lfs diff=lfs merge=lfs -text
|
37 |
+
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/images/wflw.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/images/wflw_table.png filter=lfs diff=lfs merge=lfs -text
|
39 |
+
marlenezw/audio-driven-animations/MakeItTalk/face_of_art/old/teaser.png filter=lfs diff=lfs merge=lfs -text
|
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/.gitignore
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Python generated files
|
2 |
+
*.pyc
|
3 |
+
|
4 |
+
# Project related files
|
5 |
+
ckpt/*.pth
|
6 |
+
dataset/*
|
7 |
+
!dataset/!.py
|
8 |
+
experiments/*
|
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/README.md
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# AdaptiveWingLoss
|
2 |
+
## [arXiv](https://arxiv.org/abs/1904.07399)
|
3 |
+
Pytorch Implementation of Adaptive Wing Loss for Robust Face Alignment via Heatmap Regression.
|
4 |
+
|
5 |
+
<img src='images/wflw.png' width="1000px">
|
6 |
+
|
7 |
+
## Update Logs:
|
8 |
+
### October 28, 2019
|
9 |
+
* Pretrained Model and evaluation code on WFLW dataset is released.
|
10 |
+
|
11 |
+
## Installation
|
12 |
+
#### Note: Code was originally developed under Python2.X and Pytorch 0.4. This released version was revisioned from original code and was tested on Python3.5.7 and Pytorch 1.3.0.
|
13 |
+
|
14 |
+
Install system requirements:
|
15 |
+
```
|
16 |
+
sudo apt-get install python3-dev python3-pip python3-tk libglib2.0-0
|
17 |
+
```
|
18 |
+
|
19 |
+
Install python dependencies:
|
20 |
+
```
|
21 |
+
pip3 install -r requirements.txt
|
22 |
+
```
|
23 |
+
|
24 |
+
## Run Evaluation on WFLW dataset
|
25 |
+
1. Download and process WFLW dataset
|
26 |
+
* Download WFLW dataset and annotation from [Here](https://wywu.github.io/projects/LAB/WFLW.html).
|
27 |
+
* Unzip WFLW dataset and annotations and move files into ```./dataset``` directory. Your directory should look like this:
|
28 |
+
```
|
29 |
+
AdaptiveWingLoss
|
30 |
+
└───dataset
|
31 |
+
│
|
32 |
+
└───WFLW_annotations
|
33 |
+
│ └───list_98pt_rect_attr_train_test
|
34 |
+
│ │
|
35 |
+
│ └───list_98pt_test
|
36 |
+
│
|
37 |
+
└───WFLW_images
|
38 |
+
└───0--Parade
|
39 |
+
│
|
40 |
+
└───...
|
41 |
+
```
|
42 |
+
* Inside ```./dataset``` directory, run:
|
43 |
+
```
|
44 |
+
python convert_WFLW.py
|
45 |
+
```
|
46 |
+
A new directory ```./dataset/WFLW_test``` should be generated with 2500 processed testing images and corresponding landmarks.
|
47 |
+
|
48 |
+
2. Download pretrained model from [Google Drive](https://drive.google.com/file/d/1HZaSjLoorQ4QCEx7PRTxOmg0bBPYSqhH/view?usp=sharing) and put it in ```./ckpt``` directory.
|
49 |
+
|
50 |
+
3. Within ```./Scripts``` directory, run following command:
|
51 |
+
```
|
52 |
+
sh eval_wflw.sh
|
53 |
+
```
|
54 |
+
|
55 |
+
<img src='images/wflw_table.png' width="800px">
|
56 |
+
*GTBbox indicates the ground truth landmarks are used as bounding box to crop faces.
|
57 |
+
|
58 |
+
## Future Plans
|
59 |
+
- [x] Release evaluation code and pretrained model on WFLW dataset.
|
60 |
+
|
61 |
+
- [ ] Release training code on WFLW dataset.
|
62 |
+
|
63 |
+
- [ ] Release pretrained model and code on 300W, AFLW and COFW dataset.
|
64 |
+
|
65 |
+
- [ ] Replease facial landmark detection API
|
66 |
+
|
67 |
+
|
68 |
+
## Citation
|
69 |
+
If you find this useful for your research, please cite the following paper.
|
70 |
+
|
71 |
+
```
|
72 |
+
@InProceedings{Wang_2019_ICCV,
|
73 |
+
author = {Wang, Xinyao and Bo, Liefeng and Fuxin, Li},
|
74 |
+
title = {Adaptive Wing Loss for Robust Face Alignment via Heatmap Regression},
|
75 |
+
booktitle = {The IEEE International Conference on Computer Vision (ICCV)},
|
76 |
+
month = {October},
|
77 |
+
year = {2019}
|
78 |
+
}
|
79 |
+
```
|
80 |
+
|
81 |
+
## Acknowledgments
|
82 |
+
This repository borrows or partially modifies hourglass model and data processing code from [face alignment](https://github.com/1adrianb/face-alignment) and [pose-hg-train](https://github.com/princeton-vl/pose-hg-train).
|
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/__init__.py
ADDED
File without changes
|
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (164 Bytes). View file
|
|
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (179 Bytes). View file
|
|
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/ckpt/.gitkeep
ADDED
File without changes
|
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/__init__.py
ADDED
File without changes
|
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (169 Bytes). View file
|
|
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (184 Bytes). View file
|
|
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/__pycache__/coord_conv.cpython-37.pyc
ADDED
Binary file (4.33 kB). View file
|
|
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/__pycache__/coord_conv.cpython-39.pyc
ADDED
Binary file (4.38 kB). View file
|
|
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/__pycache__/models.cpython-37.pyc
ADDED
Binary file (5.77 kB). View file
|
|
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/__pycache__/models.cpython-39.pyc
ADDED
Binary file (5.83 kB). View file
|
|
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/coord_conv.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
|
6 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
7 |
+
|
8 |
+
class AddCoordsTh(nn.Module):
|
9 |
+
def __init__(self, x_dim=64, y_dim=64, with_r=False, with_boundary=False):
|
10 |
+
super(AddCoordsTh, self).__init__()
|
11 |
+
self.x_dim = x_dim
|
12 |
+
self.y_dim = y_dim
|
13 |
+
self.with_r = with_r
|
14 |
+
self.with_boundary = with_boundary
|
15 |
+
|
16 |
+
def forward(self, input_tensor, heatmap=None):
|
17 |
+
"""
|
18 |
+
input_tensor: (batch, c, x_dim, y_dim)
|
19 |
+
"""
|
20 |
+
batch_size_tensor = input_tensor.shape[0]
|
21 |
+
|
22 |
+
xx_ones = torch.ones([1, self.y_dim], dtype=torch.int32).to(device)
|
23 |
+
xx_ones = xx_ones.unsqueeze(-1)
|
24 |
+
|
25 |
+
xx_range = torch.arange(self.x_dim, dtype=torch.int32).unsqueeze(0).to(device)
|
26 |
+
xx_range = xx_range.unsqueeze(1)
|
27 |
+
|
28 |
+
xx_channel = torch.matmul(xx_ones.float(), xx_range.float())
|
29 |
+
xx_channel = xx_channel.unsqueeze(-1)
|
30 |
+
|
31 |
+
|
32 |
+
yy_ones = torch.ones([1, self.x_dim], dtype=torch.int32).to(device)
|
33 |
+
yy_ones = yy_ones.unsqueeze(1)
|
34 |
+
|
35 |
+
yy_range = torch.arange(self.y_dim, dtype=torch.int32).unsqueeze(0).to(device)
|
36 |
+
yy_range = yy_range.unsqueeze(-1)
|
37 |
+
|
38 |
+
yy_channel = torch.matmul(yy_range.float(), yy_ones.float())
|
39 |
+
yy_channel = yy_channel.unsqueeze(-1)
|
40 |
+
|
41 |
+
xx_channel = xx_channel.permute(0, 3, 2, 1)
|
42 |
+
yy_channel = yy_channel.permute(0, 3, 2, 1)
|
43 |
+
|
44 |
+
xx_channel = xx_channel / (self.x_dim - 1)
|
45 |
+
yy_channel = yy_channel / (self.y_dim - 1)
|
46 |
+
|
47 |
+
xx_channel = xx_channel * 2 - 1
|
48 |
+
yy_channel = yy_channel * 2 - 1
|
49 |
+
|
50 |
+
xx_channel = xx_channel.repeat(batch_size_tensor, 1, 1, 1)
|
51 |
+
yy_channel = yy_channel.repeat(batch_size_tensor, 1, 1, 1)
|
52 |
+
|
53 |
+
if self.with_boundary and type(heatmap) != type(None):
|
54 |
+
boundary_channel = torch.clamp(heatmap[:, -1:, :, :],
|
55 |
+
0.0, 1.0)
|
56 |
+
|
57 |
+
zero_tensor = torch.zeros_like(xx_channel)
|
58 |
+
xx_boundary_channel = torch.where(boundary_channel>0.05,
|
59 |
+
xx_channel, zero_tensor)
|
60 |
+
yy_boundary_channel = torch.where(boundary_channel>0.05,
|
61 |
+
yy_channel, zero_tensor)
|
62 |
+
if self.with_boundary and type(heatmap) != type(None):
|
63 |
+
xx_boundary_channel = xx_boundary_channel.to(device)
|
64 |
+
yy_boundary_channel = yy_boundary_channel.to(device)
|
65 |
+
|
66 |
+
ret = torch.cat([input_tensor, xx_channel, yy_channel], dim=1)
|
67 |
+
|
68 |
+
|
69 |
+
if self.with_r:
|
70 |
+
rr = torch.sqrt(torch.pow(xx_channel, 2) + torch.pow(yy_channel, 2))
|
71 |
+
rr = rr / torch.max(rr)
|
72 |
+
ret = torch.cat([ret, rr], dim=1)
|
73 |
+
|
74 |
+
if self.with_boundary and type(heatmap) != type(None):
|
75 |
+
ret = torch.cat([ret, xx_boundary_channel,
|
76 |
+
yy_boundary_channel], dim=1)
|
77 |
+
return ret
|
78 |
+
|
79 |
+
|
80 |
+
class CoordConvTh(nn.Module):
|
81 |
+
"""CoordConv layer as in the paper."""
|
82 |
+
def __init__(self, x_dim, y_dim, with_r, with_boundary,
|
83 |
+
in_channels, first_one=False, *args, **kwargs):
|
84 |
+
super(CoordConvTh, self).__init__()
|
85 |
+
self.addcoords = AddCoordsTh(x_dim=x_dim, y_dim=y_dim, with_r=with_r,
|
86 |
+
with_boundary=with_boundary)
|
87 |
+
in_channels += 2
|
88 |
+
if with_r:
|
89 |
+
in_channels += 1
|
90 |
+
if with_boundary and not first_one:
|
91 |
+
in_channels += 2
|
92 |
+
self.conv = nn.Conv2d(in_channels=in_channels, *args, **kwargs)
|
93 |
+
|
94 |
+
def forward(self, input_tensor, heatmap=None):
|
95 |
+
ret = self.addcoords(input_tensor, heatmap)
|
96 |
+
last_channel = ret[:, -2:, :, :]
|
97 |
+
ret = self.conv(ret)
|
98 |
+
return ret, last_channel
|
99 |
+
|
100 |
+
|
101 |
+
'''
|
102 |
+
An alternative implementation for PyTorch with auto-infering the x-y dimensions.
|
103 |
+
'''
|
104 |
+
class AddCoords(nn.Module):
|
105 |
+
|
106 |
+
def __init__(self, with_r=False):
|
107 |
+
super().__init__()
|
108 |
+
self.with_r = with_r
|
109 |
+
|
110 |
+
def forward(self, input_tensor):
|
111 |
+
"""
|
112 |
+
Args:
|
113 |
+
input_tensor: shape(batch, channel, x_dim, y_dim)
|
114 |
+
"""
|
115 |
+
batch_size, _, x_dim, y_dim = input_tensor.size()
|
116 |
+
|
117 |
+
xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1)
|
118 |
+
yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2)
|
119 |
+
|
120 |
+
xx_channel = xx_channel / (x_dim - 1)
|
121 |
+
yy_channel = yy_channel / (y_dim - 1)
|
122 |
+
|
123 |
+
xx_channel = xx_channel * 2 - 1
|
124 |
+
yy_channel = yy_channel * 2 - 1
|
125 |
+
|
126 |
+
xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)
|
127 |
+
yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)
|
128 |
+
|
129 |
+
if input_tensor.is_cuda:
|
130 |
+
xx_channel = xx_channel.to(device)
|
131 |
+
yy_channel = yy_channel.to(device)
|
132 |
+
|
133 |
+
ret = torch.cat([
|
134 |
+
input_tensor,
|
135 |
+
xx_channel.type_as(input_tensor),
|
136 |
+
yy_channel.type_as(input_tensor)], dim=1)
|
137 |
+
|
138 |
+
if self.with_r:
|
139 |
+
rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2))
|
140 |
+
if input_tensor.is_cuda:
|
141 |
+
rr = rr.to(device)
|
142 |
+
ret = torch.cat([ret, rr], dim=1)
|
143 |
+
|
144 |
+
return ret
|
145 |
+
|
146 |
+
|
147 |
+
class CoordConv(nn.Module):
|
148 |
+
|
149 |
+
def __init__(self, in_channels, out_channels, with_r=False, **kwargs):
|
150 |
+
super().__init__()
|
151 |
+
self.addcoords = AddCoords(with_r=with_r)
|
152 |
+
self.conv = nn.Conv2d(in_channels + 2, out_channels, **kwargs)
|
153 |
+
|
154 |
+
def forward(self, x):
|
155 |
+
ret = self.addcoords(x)
|
156 |
+
ret = self.conv(ret)
|
157 |
+
return ret
|
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/dataloader.py
ADDED
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
import glob
|
5 |
+
import torch
|
6 |
+
from skimage import io
|
7 |
+
from skimage import transform as ski_transform
|
8 |
+
from skimage.color import rgb2gray
|
9 |
+
import scipy.io as sio
|
10 |
+
from scipy import interpolate
|
11 |
+
import numpy as np
|
12 |
+
import matplotlib.pyplot as plt
|
13 |
+
from torch.utils.data import Dataset, DataLoader
|
14 |
+
from torchvision import transforms, utils
|
15 |
+
from torchvision.transforms import Lambda, Compose
|
16 |
+
from torchvision.transforms.functional import adjust_brightness, adjust_contrast, adjust_saturation, adjust_hue
|
17 |
+
from utils.utils import cv_crop, cv_rotate, draw_gaussian, transform, power_transform, shuffle_lr, fig2data, generate_weight_map
|
18 |
+
from PIL import Image
|
19 |
+
import cv2
|
20 |
+
import copy
|
21 |
+
import math
|
22 |
+
from imgaug import augmenters as iaa
|
23 |
+
|
24 |
+
|
25 |
+
class AddBoundary(object):
|
26 |
+
def __init__(self, num_landmarks=68):
|
27 |
+
self.num_landmarks = num_landmarks
|
28 |
+
|
29 |
+
def __call__(self, sample):
|
30 |
+
landmarks_64 = np.floor(sample['landmarks'] / 4.0)
|
31 |
+
if self.num_landmarks == 68:
|
32 |
+
boundaries = {}
|
33 |
+
boundaries['cheek'] = landmarks_64[0:17]
|
34 |
+
boundaries['left_eyebrow'] = landmarks_64[17:22]
|
35 |
+
boundaries['right_eyebrow'] = landmarks_64[22:27]
|
36 |
+
boundaries['uper_left_eyelid'] = landmarks_64[36:40]
|
37 |
+
boundaries['lower_left_eyelid'] = np.array([landmarks_64[i] for i in [36, 41, 40, 39]])
|
38 |
+
boundaries['upper_right_eyelid'] = landmarks_64[42:46]
|
39 |
+
boundaries['lower_right_eyelid'] = np.array([landmarks_64[i] for i in [42, 47, 46, 45]])
|
40 |
+
boundaries['noise'] = landmarks_64[27:31]
|
41 |
+
boundaries['noise_bot'] = landmarks_64[31:36]
|
42 |
+
boundaries['upper_outer_lip'] = landmarks_64[48:55]
|
43 |
+
boundaries['upper_inner_lip'] = np.array([landmarks_64[i] for i in [60, 61, 62, 63, 64]])
|
44 |
+
boundaries['lower_outer_lip'] = np.array([landmarks_64[i] for i in [48, 59, 58, 57, 56, 55, 54]])
|
45 |
+
boundaries['lower_inner_lip'] = np.array([landmarks_64[i] for i in [60, 67, 66, 65, 64]])
|
46 |
+
elif self.num_landmarks == 98:
|
47 |
+
boundaries = {}
|
48 |
+
boundaries['cheek'] = landmarks_64[0:33]
|
49 |
+
boundaries['left_eyebrow'] = landmarks_64[33:38]
|
50 |
+
boundaries['right_eyebrow'] = landmarks_64[42:47]
|
51 |
+
boundaries['uper_left_eyelid'] = landmarks_64[60:65]
|
52 |
+
boundaries['lower_left_eyelid'] = np.array([landmarks_64[i] for i in [60, 67, 66, 65, 64]])
|
53 |
+
boundaries['upper_right_eyelid'] = landmarks_64[68:73]
|
54 |
+
boundaries['lower_right_eyelid'] = np.array([landmarks_64[i] for i in [68, 75, 74, 73, 72]])
|
55 |
+
boundaries['noise'] = landmarks_64[51:55]
|
56 |
+
boundaries['noise_bot'] = landmarks_64[55:60]
|
57 |
+
boundaries['upper_outer_lip'] = landmarks_64[76:83]
|
58 |
+
boundaries['upper_inner_lip'] = np.array([landmarks_64[i] for i in [88, 89, 90, 91, 92]])
|
59 |
+
boundaries['lower_outer_lip'] = np.array([landmarks_64[i] for i in [76, 87, 86, 85, 84, 83, 82]])
|
60 |
+
boundaries['lower_inner_lip'] = np.array([landmarks_64[i] for i in [88, 95, 94, 93, 92]])
|
61 |
+
elif self.num_landmarks == 19:
|
62 |
+
boundaries = {}
|
63 |
+
boundaries['left_eyebrow'] = landmarks_64[0:3]
|
64 |
+
boundaries['right_eyebrow'] = landmarks_64[3:5]
|
65 |
+
boundaries['left_eye'] = landmarks_64[6:9]
|
66 |
+
boundaries['right_eye'] = landmarks_64[9:12]
|
67 |
+
boundaries['noise'] = landmarks_64[12:15]
|
68 |
+
|
69 |
+
elif self.num_landmarks == 29:
|
70 |
+
boundaries = {}
|
71 |
+
boundaries['upper_left_eyebrow'] = np.stack([
|
72 |
+
landmarks_64[0],
|
73 |
+
landmarks_64[4],
|
74 |
+
landmarks_64[2]
|
75 |
+
], axis=0)
|
76 |
+
boundaries['lower_left_eyebrow'] = np.stack([
|
77 |
+
landmarks_64[0],
|
78 |
+
landmarks_64[5],
|
79 |
+
landmarks_64[2]
|
80 |
+
], axis=0)
|
81 |
+
boundaries['upper_right_eyebrow'] = np.stack([
|
82 |
+
landmarks_64[1],
|
83 |
+
landmarks_64[6],
|
84 |
+
landmarks_64[3]
|
85 |
+
], axis=0)
|
86 |
+
boundaries['lower_right_eyebrow'] = np.stack([
|
87 |
+
landmarks_64[1],
|
88 |
+
landmarks_64[7],
|
89 |
+
landmarks_64[3]
|
90 |
+
], axis=0)
|
91 |
+
boundaries['upper_left_eye'] = np.stack([
|
92 |
+
landmarks_64[8],
|
93 |
+
landmarks_64[12],
|
94 |
+
landmarks_64[10]
|
95 |
+
], axis=0)
|
96 |
+
boundaries['lower_left_eye'] = np.stack([
|
97 |
+
landmarks_64[8],
|
98 |
+
landmarks_64[13],
|
99 |
+
landmarks_64[10]
|
100 |
+
], axis=0)
|
101 |
+
boundaries['upper_right_eye'] = np.stack([
|
102 |
+
landmarks_64[9],
|
103 |
+
landmarks_64[14],
|
104 |
+
landmarks_64[11]
|
105 |
+
], axis=0)
|
106 |
+
boundaries['lower_right_eye'] = np.stack([
|
107 |
+
landmarks_64[9],
|
108 |
+
landmarks_64[15],
|
109 |
+
landmarks_64[11]
|
110 |
+
], axis=0)
|
111 |
+
boundaries['noise'] = np.stack([
|
112 |
+
landmarks_64[18],
|
113 |
+
landmarks_64[21],
|
114 |
+
landmarks_64[19]
|
115 |
+
], axis=0)
|
116 |
+
boundaries['outer_upper_lip'] = np.stack([
|
117 |
+
landmarks_64[22],
|
118 |
+
landmarks_64[24],
|
119 |
+
landmarks_64[23]
|
120 |
+
], axis=0)
|
121 |
+
boundaries['inner_upper_lip'] = np.stack([
|
122 |
+
landmarks_64[22],
|
123 |
+
landmarks_64[25],
|
124 |
+
landmarks_64[23]
|
125 |
+
], axis=0)
|
126 |
+
boundaries['outer_lower_lip'] = np.stack([
|
127 |
+
landmarks_64[22],
|
128 |
+
landmarks_64[26],
|
129 |
+
landmarks_64[23]
|
130 |
+
], axis=0)
|
131 |
+
boundaries['inner_lower_lip'] = np.stack([
|
132 |
+
landmarks_64[22],
|
133 |
+
landmarks_64[27],
|
134 |
+
landmarks_64[23]
|
135 |
+
], axis=0)
|
136 |
+
functions = {}
|
137 |
+
|
138 |
+
for key, points in boundaries.items():
|
139 |
+
temp = points[0]
|
140 |
+
new_points = points[0:1, :]
|
141 |
+
for point in points[1:]:
|
142 |
+
if point[0] == temp[0] and point[1] == temp[1]:
|
143 |
+
continue
|
144 |
+
else:
|
145 |
+
new_points = np.concatenate((new_points, np.expand_dims(point, 0)), axis=0)
|
146 |
+
temp = point
|
147 |
+
points = new_points
|
148 |
+
if points.shape[0] == 1:
|
149 |
+
points = np.concatenate((points, points+0.001), axis=0)
|
150 |
+
k = min(4, points.shape[0])
|
151 |
+
functions[key] = interpolate.splprep([points[:, 0], points[:, 1]], k=k-1,s=0)
|
152 |
+
|
153 |
+
boundary_map = np.zeros((64, 64))
|
154 |
+
|
155 |
+
fig = plt.figure(figsize=[64/96.0, 64/96.0], dpi=96)
|
156 |
+
|
157 |
+
ax = fig.add_axes([0, 0, 1, 1])
|
158 |
+
|
159 |
+
ax.axis('off')
|
160 |
+
|
161 |
+
ax.imshow(boundary_map, interpolation='nearest', cmap='gray')
|
162 |
+
#ax.scatter(landmarks[:, 0], landmarks[:, 1], s=1, marker=',', c='w')
|
163 |
+
|
164 |
+
for key in functions.keys():
|
165 |
+
xnew = np.arange(0, 1, 0.01)
|
166 |
+
out = interpolate.splev(xnew, functions[key][0], der=0)
|
167 |
+
plt.plot(out[0], out[1], ',', linewidth=1, color='w')
|
168 |
+
|
169 |
+
img = fig2data(fig)
|
170 |
+
|
171 |
+
plt.close()
|
172 |
+
|
173 |
+
sigma = 1
|
174 |
+
temp = 255-img[:,:,1]
|
175 |
+
temp = cv2.distanceTransform(temp, cv2.DIST_L2, cv2.DIST_MASK_PRECISE)
|
176 |
+
temp = temp.astype(np.float32)
|
177 |
+
temp = np.where(temp < 3*sigma, np.exp(-(temp*temp)/(2*sigma*sigma)), 0 )
|
178 |
+
|
179 |
+
fig = plt.figure(figsize=[64/96.0, 64/96.0], dpi=96)
|
180 |
+
|
181 |
+
ax = fig.add_axes([0, 0, 1, 1])
|
182 |
+
|
183 |
+
ax.axis('off')
|
184 |
+
ax.imshow(temp, cmap='gray')
|
185 |
+
plt.close()
|
186 |
+
|
187 |
+
boundary_map = fig2data(fig)
|
188 |
+
|
189 |
+
sample['boundary'] = boundary_map[:, :, 0]
|
190 |
+
|
191 |
+
return sample
|
192 |
+
|
193 |
+
class AddWeightMap(object):
|
194 |
+
def __call__(self, sample):
|
195 |
+
heatmap= sample['heatmap']
|
196 |
+
boundary = sample['boundary']
|
197 |
+
heatmap = np.concatenate((heatmap, np.expand_dims(boundary, axis=0)), 0)
|
198 |
+
weight_map = np.zeros_like(heatmap)
|
199 |
+
for i in range(heatmap.shape[0]):
|
200 |
+
weight_map[i] = generate_weight_map(weight_map[i],
|
201 |
+
heatmap[i])
|
202 |
+
sample['weight_map'] = weight_map
|
203 |
+
return sample
|
204 |
+
|
205 |
+
class ToTensor(object):
|
206 |
+
"""Convert ndarrays in sample to Tensors."""
|
207 |
+
|
208 |
+
def __call__(self, sample):
|
209 |
+
image, heatmap, landmarks, boundary, weight_map= sample['image'], sample['heatmap'], sample['landmarks'], sample['boundary'], sample['weight_map']
|
210 |
+
|
211 |
+
# swap color axis because
|
212 |
+
# numpy image: H x W x C
|
213 |
+
# torch image: C X H X W
|
214 |
+
if len(image.shape) == 2:
|
215 |
+
image = np.expand_dims(image, axis=2)
|
216 |
+
image_small = np.expand_dims(image_small, axis=2)
|
217 |
+
image = image.transpose((2, 0, 1))
|
218 |
+
boundary = np.expand_dims(boundary, axis=2)
|
219 |
+
boundary = boundary.transpose((2, 0, 1))
|
220 |
+
return {'image': torch.from_numpy(image).float().div(255.0),
|
221 |
+
'heatmap': torch.from_numpy(heatmap).float(),
|
222 |
+
'landmarks': torch.from_numpy(landmarks).float(),
|
223 |
+
'boundary': torch.from_numpy(boundary).float().div(255.0),
|
224 |
+
'weight_map': torch.from_numpy(weight_map).float()}
|
225 |
+
|
226 |
+
class FaceLandmarksDataset(Dataset):
|
227 |
+
"""Face Landmarks dataset."""
|
228 |
+
|
229 |
+
def __init__(self, img_dir, landmarks_dir, num_landmarks=68, gray_scale=False,
|
230 |
+
detect_face=False, enhance=False, center_shift=0,
|
231 |
+
transform=None,):
|
232 |
+
"""
|
233 |
+
Args:
|
234 |
+
landmark_dir (string): Path to the mat file with landmarks saved.
|
235 |
+
img_dir (string): Directory with all the images.
|
236 |
+
transform (callable, optional): Optional transform to be applied
|
237 |
+
on a sample.
|
238 |
+
"""
|
239 |
+
self.img_dir = img_dir
|
240 |
+
self.landmarks_dir = landmarks_dir
|
241 |
+
self.num_lanmdkars = num_landmarks
|
242 |
+
self.transform = transform
|
243 |
+
self.img_names = glob.glob(self.img_dir+'*.jpg') + \
|
244 |
+
glob.glob(self.img_dir+'*.png')
|
245 |
+
self.gray_scale = gray_scale
|
246 |
+
self.detect_face = detect_face
|
247 |
+
self.enhance = enhance
|
248 |
+
self.center_shift = center_shift
|
249 |
+
if self.detect_face:
|
250 |
+
self.face_detector = MTCNN(thresh=[0.5, 0.6, 0.7])
|
251 |
+
def __len__(self):
|
252 |
+
return len(self.img_names)
|
253 |
+
|
254 |
+
def __getitem__(self, idx):
|
255 |
+
img_name = self.img_names[idx]
|
256 |
+
pil_image = Image.open(img_name)
|
257 |
+
if pil_image.mode != "RGB":
|
258 |
+
# if input is grayscale image, convert it to 3 channel image
|
259 |
+
if self.enhance:
|
260 |
+
pil_image = power_transform(pil_image, 0.5)
|
261 |
+
temp_image = Image.new('RGB', pil_image.size)
|
262 |
+
temp_image.paste(pil_image)
|
263 |
+
pil_image = temp_image
|
264 |
+
image = np.array(pil_image)
|
265 |
+
if self.gray_scale:
|
266 |
+
image = rgb2gray(image)
|
267 |
+
image = np.expand_dims(image, axis=2)
|
268 |
+
image = np.concatenate((image, image, image), axis=2)
|
269 |
+
image = image * 255.0
|
270 |
+
image = image.astype(np.uint8)
|
271 |
+
if not self.detect_face:
|
272 |
+
center = [450//2, 450//2+0]
|
273 |
+
if self.center_shift != 0:
|
274 |
+
center[0] += int(np.random.uniform(-self.center_shift,
|
275 |
+
self.center_shift))
|
276 |
+
center[1] += int(np.random.uniform(-self.center_shift,
|
277 |
+
self.center_shift))
|
278 |
+
scale = 1.8
|
279 |
+
else:
|
280 |
+
detected_faces = self.face_detector.detect_image(image)
|
281 |
+
if len(detected_faces) > 0:
|
282 |
+
box = detected_faces[0]
|
283 |
+
left, top, right, bottom, _ = box
|
284 |
+
center = [right - (right - left) / 2.0,
|
285 |
+
bottom - (bottom - top) / 2.0]
|
286 |
+
center[1] = center[1] - (bottom - top) * 0.12
|
287 |
+
scale = (right - left + bottom - top) / 195.0
|
288 |
+
else:
|
289 |
+
center = [450//2, 450//2+0]
|
290 |
+
scale = 1.8
|
291 |
+
if self.center_shift != 0:
|
292 |
+
shift = self.center * self.center_shift / 450
|
293 |
+
center[0] += int(np.random.uniform(-shift, shift))
|
294 |
+
center[1] += int(np.random.uniform(-shift, shift))
|
295 |
+
base_name = os.path.basename(img_name)
|
296 |
+
landmarks_base_name = base_name[:-4] + '_pts.mat'
|
297 |
+
landmarks_name = os.path.join(self.landmarks_dir, landmarks_base_name)
|
298 |
+
if os.path.isfile(landmarks_name):
|
299 |
+
mat_data = sio.loadmat(landmarks_name)
|
300 |
+
landmarks = mat_data['pts_2d']
|
301 |
+
elif os.path.isfile(landmarks_name[:-8] + '.pts.npy'):
|
302 |
+
landmarks = np.load(landmarks_name[:-8] + '.pts.npy')
|
303 |
+
else:
|
304 |
+
landmarks = []
|
305 |
+
heatmap = []
|
306 |
+
|
307 |
+
if landmarks != []:
|
308 |
+
new_image, new_landmarks = cv_crop(image, landmarks, center,
|
309 |
+
scale, 256, self.center_shift)
|
310 |
+
tries = 0
|
311 |
+
while self.center_shift != 0 and tries < 5 and (np.max(new_landmarks) > 240 or np.min(new_landmarks) < 15):
|
312 |
+
center = [450//2, 450//2+0]
|
313 |
+
scale += 0.05
|
314 |
+
center[0] += int(np.random.uniform(-self.center_shift,
|
315 |
+
self.center_shift))
|
316 |
+
center[1] += int(np.random.uniform(-self.center_shift,
|
317 |
+
self.center_shift))
|
318 |
+
|
319 |
+
new_image, new_landmarks = cv_crop(image, landmarks,
|
320 |
+
center, scale, 256,
|
321 |
+
self.center_shift)
|
322 |
+
tries += 1
|
323 |
+
if np.max(new_landmarks) > 250 or np.min(new_landmarks) < 5:
|
324 |
+
center = [450//2, 450//2+0]
|
325 |
+
scale = 2.25
|
326 |
+
new_image, new_landmarks = cv_crop(image, landmarks,
|
327 |
+
center, scale, 256,
|
328 |
+
100)
|
329 |
+
assert (np.min(new_landmarks) > 0 and np.max(new_landmarks) < 256), \
|
330 |
+
"Landmarks out of boundary!"
|
331 |
+
image = new_image
|
332 |
+
landmarks = new_landmarks
|
333 |
+
heatmap = np.zeros((self.num_lanmdkars, 64, 64))
|
334 |
+
for i in range(self.num_lanmdkars):
|
335 |
+
if landmarks[i][0] > 0:
|
336 |
+
heatmap[i] = draw_gaussian(heatmap[i], landmarks[i]/4.0+1, 1)
|
337 |
+
sample = {'image': image, 'heatmap': heatmap, 'landmarks': landmarks}
|
338 |
+
if self.transform:
|
339 |
+
sample = self.transform(sample)
|
340 |
+
|
341 |
+
return sample
|
342 |
+
|
343 |
+
def get_dataset(val_img_dir, val_landmarks_dir, batch_size,
|
344 |
+
num_landmarks=68, rotation=0, scale=0,
|
345 |
+
center_shift=0, random_flip=False,
|
346 |
+
brightness=0, contrast=0, saturation=0,
|
347 |
+
blur=False, noise=False, jpeg_effect=False,
|
348 |
+
random_occlusion=False, gray_scale=False,
|
349 |
+
detect_face=False, enhance=False):
|
350 |
+
val_transforms = transforms.Compose([AddBoundary(num_landmarks),
|
351 |
+
AddWeightMap(),
|
352 |
+
ToTensor()])
|
353 |
+
|
354 |
+
val_dataset = FaceLandmarksDataset(val_img_dir, val_landmarks_dir,
|
355 |
+
num_landmarks=num_landmarks,
|
356 |
+
gray_scale=gray_scale,
|
357 |
+
detect_face=detect_face,
|
358 |
+
enhance=enhance,
|
359 |
+
transform=val_transforms)
|
360 |
+
|
361 |
+
val_dataloader = torch.utils.data.DataLoader(val_dataset,
|
362 |
+
batch_size=batch_size,
|
363 |
+
shuffle=False,
|
364 |
+
num_workers=6)
|
365 |
+
data_loaders = {'val': val_dataloader}
|
366 |
+
dataset_sizes = {}
|
367 |
+
dataset_sizes['val'] = len(val_dataset)
|
368 |
+
return data_loaders, dataset_sizes
|
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/evaler.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib
|
2 |
+
matplotlib.use('Agg')
|
3 |
+
import math
|
4 |
+
import torch
|
5 |
+
import copy
|
6 |
+
import time
|
7 |
+
from torch.autograd import Variable
|
8 |
+
import shutil
|
9 |
+
from skimage import io
|
10 |
+
import numpy as np
|
11 |
+
from utils.utils import fan_NME, show_landmarks, get_preds_fromhm
|
12 |
+
from PIL import Image, ImageDraw
|
13 |
+
import os
|
14 |
+
import sys
|
15 |
+
import cv2
|
16 |
+
import matplotlib.pyplot as plt
|
17 |
+
|
18 |
+
|
19 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
20 |
+
|
21 |
+
def eval_model(model, dataloaders, dataset_sizes,
|
22 |
+
writer, use_gpu=True, epoches=5, dataset='val',
|
23 |
+
save_path='./', num_landmarks=68):
|
24 |
+
global_nme = 0
|
25 |
+
model.eval()
|
26 |
+
for epoch in range(epoches):
|
27 |
+
running_loss = 0
|
28 |
+
step = 0
|
29 |
+
total_nme = 0
|
30 |
+
total_count = 0
|
31 |
+
fail_count = 0
|
32 |
+
nmes = []
|
33 |
+
# running_corrects = 0
|
34 |
+
|
35 |
+
# Iterate over data.
|
36 |
+
with torch.no_grad():
|
37 |
+
for data in dataloaders[dataset]:
|
38 |
+
total_runtime = 0
|
39 |
+
run_count = 0
|
40 |
+
step_start = time.time()
|
41 |
+
step += 1
|
42 |
+
# get the inputs
|
43 |
+
inputs = data['image'].type(torch.FloatTensor)
|
44 |
+
labels_heatmap = data['heatmap'].type(torch.FloatTensor)
|
45 |
+
labels_boundary = data['boundary'].type(torch.FloatTensor)
|
46 |
+
landmarks = data['landmarks'].type(torch.FloatTensor)
|
47 |
+
loss_weight_map = data['weight_map'].type(torch.FloatTensor)
|
48 |
+
# wrap them in Variable
|
49 |
+
if use_gpu:
|
50 |
+
inputs = inputs.to(device)
|
51 |
+
labels_heatmap = labels_heatmap.to(device)
|
52 |
+
labels_boundary = labels_boundary.to(device)
|
53 |
+
loss_weight_map = loss_weight_map.to(device)
|
54 |
+
else:
|
55 |
+
inputs, labels_heatmap = Variable(inputs), Variable(labels_heatmap)
|
56 |
+
labels_boundary = Variable(labels_boundary)
|
57 |
+
labels = torch.cat((labels_heatmap, labels_boundary), 1)
|
58 |
+
single_start = time.time()
|
59 |
+
outputs, boundary_channels = model(inputs)
|
60 |
+
single_end = time.time()
|
61 |
+
total_runtime += time.time() - single_start
|
62 |
+
run_count += 1
|
63 |
+
step_end = time.time()
|
64 |
+
for i in range(inputs.shape[0]):
|
65 |
+
print(inputs.shape)
|
66 |
+
img = inputs[i]
|
67 |
+
img = img.cpu().numpy()
|
68 |
+
img = img.transpose((1, 2, 0)) #*255.0
|
69 |
+
# img = img.astype(np.uint8)
|
70 |
+
# img = Image.fromarray(img)
|
71 |
+
# pred_heatmap = outputs[-1][i].detach().cpu()[:-1, :, :]
|
72 |
+
pred_heatmap = outputs[-1][:, :-1, :, :][i].detach().cpu()
|
73 |
+
pred_landmarks, _ = get_preds_fromhm(pred_heatmap.unsqueeze(0))
|
74 |
+
pred_landmarks = pred_landmarks.squeeze().numpy()
|
75 |
+
|
76 |
+
gt_landmarks = data['landmarks'][i].numpy()
|
77 |
+
print(pred_landmarks, gt_landmarks)
|
78 |
+
import cv2
|
79 |
+
while(True):
|
80 |
+
imgshow = vis_landmark_on_img(cv2.UMat(img), pred_landmarks*4)
|
81 |
+
cv2.imshow('img', imgshow)
|
82 |
+
|
83 |
+
if(cv2.waitKey(10) == ord('q')):
|
84 |
+
break
|
85 |
+
|
86 |
+
|
87 |
+
if num_landmarks == 68:
|
88 |
+
left_eye = np.average(gt_landmarks[36:42], axis=0)
|
89 |
+
right_eye = np.average(gt_landmarks[42:48], axis=0)
|
90 |
+
norm_factor = np.linalg.norm(left_eye - right_eye)
|
91 |
+
# norm_factor = np.linalg.norm(gt_landmarks[36]- gt_landmarks[45])
|
92 |
+
|
93 |
+
elif num_landmarks == 98:
|
94 |
+
norm_factor = np.linalg.norm(gt_landmarks[60]- gt_landmarks[72])
|
95 |
+
elif num_landmarks == 19:
|
96 |
+
left, top = gt_landmarks[-2, :]
|
97 |
+
right, bottom = gt_landmarks[-1, :]
|
98 |
+
norm_factor = math.sqrt(abs(right - left)*abs(top-bottom))
|
99 |
+
gt_landmarks = gt_landmarks[:-2, :]
|
100 |
+
elif num_landmarks == 29:
|
101 |
+
# norm_factor = np.linalg.norm(gt_landmarks[8]- gt_landmarks[9])
|
102 |
+
norm_factor = np.linalg.norm(gt_landmarks[16]- gt_landmarks[17])
|
103 |
+
single_nme = (np.sum(np.linalg.norm(pred_landmarks*4 - gt_landmarks, axis=1)) / pred_landmarks.shape[0]) / norm_factor
|
104 |
+
|
105 |
+
nmes.append(single_nme)
|
106 |
+
total_count += 1
|
107 |
+
if single_nme > 0.1:
|
108 |
+
fail_count += 1
|
109 |
+
if step % 10 == 0:
|
110 |
+
print('Step {} Time: {:.6f} Input Mean: {:.6f} Output Mean: {:.6f}'.format(
|
111 |
+
step, step_end - step_start,
|
112 |
+
torch.mean(labels),
|
113 |
+
torch.mean(outputs[0])))
|
114 |
+
# gt_landmarks = landmarks.numpy()
|
115 |
+
# pred_heatmap = outputs[-1].to('cpu').numpy()
|
116 |
+
gt_landmarks = landmarks
|
117 |
+
batch_nme = fan_NME(outputs[-1][:, :-1, :, :].detach().cpu(), gt_landmarks, num_landmarks)
|
118 |
+
# batch_nme = 0
|
119 |
+
total_nme += batch_nme
|
120 |
+
epoch_nme = total_nme / dataset_sizes['val']
|
121 |
+
global_nme += epoch_nme
|
122 |
+
nme_save_path = os.path.join(save_path, 'nme_log.npy')
|
123 |
+
np.save(nme_save_path, np.array(nmes))
|
124 |
+
print('NME: {:.6f} Failure Rate: {:.6f} Total Count: {:.6f} Fail Count: {:.6f}'.format(epoch_nme, fail_count/total_count, total_count, fail_count))
|
125 |
+
print('Evaluation done! Average NME: {:.6f}'.format(global_nme/epoches))
|
126 |
+
print('Everage runtime for a single batch: {:.6f}'.format(total_runtime/run_count))
|
127 |
+
return model
|
128 |
+
|
129 |
+
|
130 |
+
def vis_landmark_on_img(img, shape, linewidth=2):
|
131 |
+
'''
|
132 |
+
Visualize landmark on images.
|
133 |
+
'''
|
134 |
+
|
135 |
+
def draw_curve(idx_list, color=(0, 255, 0), loop=False, lineWidth=linewidth):
|
136 |
+
for i in idx_list:
|
137 |
+
cv2.line(img, (shape[i, 0], shape[i, 1]), (shape[i + 1, 0], shape[i + 1, 1]), color, lineWidth)
|
138 |
+
if (loop):
|
139 |
+
cv2.line(img, (shape[idx_list[0], 0], shape[idx_list[0], 1]),
|
140 |
+
(shape[idx_list[-1] + 1, 0], shape[idx_list[-1] + 1, 1]), color, lineWidth)
|
141 |
+
|
142 |
+
draw_curve(list(range(0, 32))) # jaw
|
143 |
+
draw_curve(list(range(33, 41)), color=(0, 0, 255), loop=True) # eye brow
|
144 |
+
draw_curve(list(range(42, 50)), color=(0, 0, 255), loop=True)
|
145 |
+
draw_curve(list(range(51, 59))) # nose
|
146 |
+
draw_curve(list(range(60, 67)), loop=True) # eyes
|
147 |
+
draw_curve(list(range(68, 75)), loop=True)
|
148 |
+
draw_curve(list(range(76, 87)), loop=True, color=(0, 255, 255)) # mouth
|
149 |
+
draw_curve(list(range(88, 95)), loop=True, color=(255, 255, 0))
|
150 |
+
|
151 |
+
return img
|
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/models.py
ADDED
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import math
|
5 |
+
from core.coord_conv import CoordConvTh
|
6 |
+
|
7 |
+
|
8 |
+
def conv3x3(in_planes, out_planes, strd=1, padding=1,
|
9 |
+
bias=False,dilation=1):
|
10 |
+
"3x3 convolution with padding"
|
11 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3,
|
12 |
+
stride=strd, padding=padding, bias=bias,
|
13 |
+
dilation=dilation)
|
14 |
+
|
15 |
+
class BasicBlock(nn.Module):
|
16 |
+
expansion = 1
|
17 |
+
|
18 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
19 |
+
super(BasicBlock, self).__init__()
|
20 |
+
self.conv1 = conv3x3(inplanes, planes, stride)
|
21 |
+
# self.bn1 = nn.BatchNorm2d(planes)
|
22 |
+
self.relu = nn.ReLU(inplace=True)
|
23 |
+
self.conv2 = conv3x3(planes, planes)
|
24 |
+
# self.bn2 = nn.BatchNorm2d(planes)
|
25 |
+
self.downsample = downsample
|
26 |
+
self.stride = stride
|
27 |
+
|
28 |
+
def forward(self, x):
|
29 |
+
residual = x
|
30 |
+
|
31 |
+
out = self.conv1(x)
|
32 |
+
# out = self.bn1(out)
|
33 |
+
out = self.relu(out)
|
34 |
+
|
35 |
+
out = self.conv2(out)
|
36 |
+
# out = self.bn2(out)
|
37 |
+
|
38 |
+
if self.downsample is not None:
|
39 |
+
residual = self.downsample(x)
|
40 |
+
|
41 |
+
out += residual
|
42 |
+
out = self.relu(out)
|
43 |
+
|
44 |
+
return out
|
45 |
+
|
46 |
+
class ConvBlock(nn.Module):
|
47 |
+
def __init__(self, in_planes, out_planes):
|
48 |
+
super(ConvBlock, self).__init__()
|
49 |
+
self.bn1 = nn.BatchNorm2d(in_planes)
|
50 |
+
self.conv1 = conv3x3(in_planes, int(out_planes / 2))
|
51 |
+
self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
|
52 |
+
self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4),
|
53 |
+
padding=1, dilation=1)
|
54 |
+
self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
|
55 |
+
self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4),
|
56 |
+
padding=1, dilation=1)
|
57 |
+
|
58 |
+
if in_planes != out_planes:
|
59 |
+
self.downsample = nn.Sequential(
|
60 |
+
nn.BatchNorm2d(in_planes),
|
61 |
+
nn.ReLU(True),
|
62 |
+
nn.Conv2d(in_planes, out_planes,
|
63 |
+
kernel_size=1, stride=1, bias=False),
|
64 |
+
)
|
65 |
+
else:
|
66 |
+
self.downsample = None
|
67 |
+
|
68 |
+
def forward(self, x):
|
69 |
+
residual = x
|
70 |
+
|
71 |
+
out1 = self.bn1(x)
|
72 |
+
out1 = F.relu(out1, True)
|
73 |
+
out1 = self.conv1(out1)
|
74 |
+
|
75 |
+
out2 = self.bn2(out1)
|
76 |
+
out2 = F.relu(out2, True)
|
77 |
+
out2 = self.conv2(out2)
|
78 |
+
|
79 |
+
out3 = self.bn3(out2)
|
80 |
+
out3 = F.relu(out3, True)
|
81 |
+
out3 = self.conv3(out3)
|
82 |
+
|
83 |
+
out3 = torch.cat((out1, out2, out3), 1)
|
84 |
+
|
85 |
+
if self.downsample is not None:
|
86 |
+
residual = self.downsample(residual)
|
87 |
+
|
88 |
+
out3 += residual
|
89 |
+
|
90 |
+
return out3
|
91 |
+
|
92 |
+
class HourGlass(nn.Module):
|
93 |
+
def __init__(self, num_modules, depth, num_features, first_one=False):
|
94 |
+
super(HourGlass, self).__init__()
|
95 |
+
self.num_modules = num_modules
|
96 |
+
self.depth = depth
|
97 |
+
self.features = num_features
|
98 |
+
self.coordconv = CoordConvTh(x_dim=64, y_dim=64,
|
99 |
+
with_r=True, with_boundary=True,
|
100 |
+
in_channels=256, first_one=first_one,
|
101 |
+
out_channels=256,
|
102 |
+
kernel_size=1,
|
103 |
+
stride=1, padding=0)
|
104 |
+
self._generate_network(self.depth)
|
105 |
+
|
106 |
+
def _generate_network(self, level):
|
107 |
+
self.add_module('b1_' + str(level), ConvBlock(256, 256))
|
108 |
+
|
109 |
+
self.add_module('b2_' + str(level), ConvBlock(256, 256))
|
110 |
+
|
111 |
+
if level > 1:
|
112 |
+
self._generate_network(level - 1)
|
113 |
+
else:
|
114 |
+
self.add_module('b2_plus_' + str(level), ConvBlock(256, 256))
|
115 |
+
|
116 |
+
self.add_module('b3_' + str(level), ConvBlock(256, 256))
|
117 |
+
|
118 |
+
def _forward(self, level, inp):
|
119 |
+
# Upper branch
|
120 |
+
up1 = inp
|
121 |
+
up1 = self._modules['b1_' + str(level)](up1)
|
122 |
+
|
123 |
+
# Lower branch
|
124 |
+
low1 = F.avg_pool2d(inp, 2, stride=2)
|
125 |
+
low1 = self._modules['b2_' + str(level)](low1)
|
126 |
+
|
127 |
+
if level > 1:
|
128 |
+
low2 = self._forward(level - 1, low1)
|
129 |
+
else:
|
130 |
+
low2 = low1
|
131 |
+
low2 = self._modules['b2_plus_' + str(level)](low2)
|
132 |
+
|
133 |
+
low3 = low2
|
134 |
+
low3 = self._modules['b3_' + str(level)](low3)
|
135 |
+
|
136 |
+
up2 = F.upsample(low3, scale_factor=2, mode='nearest')
|
137 |
+
|
138 |
+
return up1 + up2
|
139 |
+
|
140 |
+
def forward(self, x, heatmap):
|
141 |
+
x, last_channel = self.coordconv(x, heatmap)
|
142 |
+
return self._forward(self.depth, x), last_channel
|
143 |
+
|
144 |
+
class FAN(nn.Module):
|
145 |
+
|
146 |
+
def __init__(self, num_modules=1, end_relu=False, gray_scale=False,
|
147 |
+
num_landmarks=68):
|
148 |
+
super(FAN, self).__init__()
|
149 |
+
self.num_modules = num_modules
|
150 |
+
self.gray_scale = gray_scale
|
151 |
+
self.end_relu = end_relu
|
152 |
+
self.num_landmarks = num_landmarks
|
153 |
+
|
154 |
+
# Base part
|
155 |
+
if self.gray_scale:
|
156 |
+
self.conv1 = CoordConvTh(x_dim=256, y_dim=256,
|
157 |
+
with_r=True, with_boundary=False,
|
158 |
+
in_channels=3, out_channels=64,
|
159 |
+
kernel_size=7,
|
160 |
+
stride=2, padding=3)
|
161 |
+
else:
|
162 |
+
self.conv1 = CoordConvTh(x_dim=256, y_dim=256,
|
163 |
+
with_r=True, with_boundary=False,
|
164 |
+
in_channels=3, out_channels=64,
|
165 |
+
kernel_size=7,
|
166 |
+
stride=2, padding=3)
|
167 |
+
self.bn1 = nn.BatchNorm2d(64)
|
168 |
+
self.conv2 = ConvBlock(64, 128)
|
169 |
+
self.conv3 = ConvBlock(128, 128)
|
170 |
+
self.conv4 = ConvBlock(128, 256)
|
171 |
+
|
172 |
+
# Stacking part
|
173 |
+
for hg_module in range(self.num_modules):
|
174 |
+
if hg_module == 0:
|
175 |
+
first_one = True
|
176 |
+
else:
|
177 |
+
first_one = False
|
178 |
+
self.add_module('m' + str(hg_module), HourGlass(1, 4, 256,
|
179 |
+
first_one))
|
180 |
+
self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256))
|
181 |
+
self.add_module('conv_last' + str(hg_module),
|
182 |
+
nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
|
183 |
+
self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
|
184 |
+
self.add_module('l' + str(hg_module), nn.Conv2d(256,
|
185 |
+
num_landmarks+1, kernel_size=1, stride=1, padding=0))
|
186 |
+
|
187 |
+
if hg_module < self.num_modules - 1:
|
188 |
+
self.add_module(
|
189 |
+
'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
|
190 |
+
self.add_module('al' + str(hg_module), nn.Conv2d(num_landmarks+1,
|
191 |
+
256, kernel_size=1, stride=1, padding=0))
|
192 |
+
|
193 |
+
def forward(self, x):
|
194 |
+
x, _ = self.conv1(x)
|
195 |
+
x = F.relu(self.bn1(x), True)
|
196 |
+
# x = F.relu(self.bn1(self.conv1(x)), True)
|
197 |
+
x = F.avg_pool2d(self.conv2(x), 2, stride=2)
|
198 |
+
x = self.conv3(x)
|
199 |
+
x = self.conv4(x)
|
200 |
+
|
201 |
+
previous = x
|
202 |
+
|
203 |
+
outputs = []
|
204 |
+
boundary_channels = []
|
205 |
+
tmp_out = None
|
206 |
+
for i in range(self.num_modules):
|
207 |
+
hg, boundary_channel = self._modules['m' + str(i)](previous,
|
208 |
+
tmp_out)
|
209 |
+
|
210 |
+
ll = hg
|
211 |
+
ll = self._modules['top_m_' + str(i)](ll)
|
212 |
+
|
213 |
+
ll = F.relu(self._modules['bn_end' + str(i)]
|
214 |
+
(self._modules['conv_last' + str(i)](ll)), True)
|
215 |
+
|
216 |
+
# Predict heatmaps
|
217 |
+
tmp_out = self._modules['l' + str(i)](ll)
|
218 |
+
if self.end_relu:
|
219 |
+
tmp_out = F.relu(tmp_out) # HACK: Added relu
|
220 |
+
outputs.append(tmp_out)
|
221 |
+
boundary_channels.append(boundary_channel)
|
222 |
+
|
223 |
+
if i < self.num_modules - 1:
|
224 |
+
ll = self._modules['bl' + str(i)](ll)
|
225 |
+
tmp_out_ = self._modules['al' + str(i)](tmp_out)
|
226 |
+
previous = previous + ll + tmp_out_
|
227 |
+
|
228 |
+
return outputs, boundary_channels
|
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/eval.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import print_function, division
|
2 |
+
import torch
|
3 |
+
import argparse
|
4 |
+
import numpy as np
|
5 |
+
import torch.nn as nn
|
6 |
+
import time
|
7 |
+
import os
|
8 |
+
from core.evaler import eval_model
|
9 |
+
from core.dataloader import get_dataset
|
10 |
+
from core import models
|
11 |
+
from tensorboardX import SummaryWriter
|
12 |
+
|
13 |
+
# Parse arguments
|
14 |
+
parser = argparse.ArgumentParser()
|
15 |
+
# Dataset paths
|
16 |
+
parser.add_argument('--val_img_dir', type=str,
|
17 |
+
help='Validation image directory')
|
18 |
+
parser.add_argument('--val_landmarks_dir', type=str,
|
19 |
+
help='Validation landmarks directory')
|
20 |
+
parser.add_argument('--num_landmarks', type=int, default=68,
|
21 |
+
help='Number of landmarks')
|
22 |
+
|
23 |
+
# Checkpoint and pretrained weights
|
24 |
+
parser.add_argument('--ckpt_save_path', type=str,
|
25 |
+
help='a directory to save checkpoint file')
|
26 |
+
parser.add_argument('--pretrained_weights', type=str,
|
27 |
+
help='a directory to save pretrained_weights')
|
28 |
+
|
29 |
+
# Eval options
|
30 |
+
parser.add_argument('--batch_size', type=int, default=25,
|
31 |
+
help='learning rate decay after each epoch')
|
32 |
+
|
33 |
+
# Network parameters
|
34 |
+
parser.add_argument('--hg_blocks', type=int, default=4,
|
35 |
+
help='Number of HG blocks to stack')
|
36 |
+
parser.add_argument('--gray_scale', type=str, default="False",
|
37 |
+
help='Whether to convert RGB image into gray scale during training')
|
38 |
+
parser.add_argument('--end_relu', type=str, default="False",
|
39 |
+
help='Whether to add relu at the end of each HG module')
|
40 |
+
|
41 |
+
args = parser.parse_args()
|
42 |
+
|
43 |
+
VAL_IMG_DIR = args.val_img_dir
|
44 |
+
VAL_LANDMARKS_DIR = args.val_landmarks_dir
|
45 |
+
CKPT_SAVE_PATH = args.ckpt_save_path
|
46 |
+
BATCH_SIZE = args.batch_size
|
47 |
+
PRETRAINED_WEIGHTS = args.pretrained_weights
|
48 |
+
GRAY_SCALE = False if args.gray_scale == 'False' else True
|
49 |
+
HG_BLOCKS = args.hg_blocks
|
50 |
+
END_RELU = False if args.end_relu == 'False' else True
|
51 |
+
NUM_LANDMARKS = args.num_landmarks
|
52 |
+
|
53 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
54 |
+
|
55 |
+
writer = SummaryWriter(CKPT_SAVE_PATH)
|
56 |
+
|
57 |
+
dataloaders, dataset_sizes = get_dataset(VAL_IMG_DIR, VAL_LANDMARKS_DIR,
|
58 |
+
BATCH_SIZE, NUM_LANDMARKS)
|
59 |
+
use_gpu = torch.cuda.is_available()
|
60 |
+
model_ft = models.FAN(HG_BLOCKS, END_RELU, GRAY_SCALE, NUM_LANDMARKS)
|
61 |
+
|
62 |
+
if PRETRAINED_WEIGHTS != "None":
|
63 |
+
checkpoint = torch.load(PRETRAINED_WEIGHTS)
|
64 |
+
if 'state_dict' not in checkpoint:
|
65 |
+
model_ft.load_state_dict(checkpoint)
|
66 |
+
else:
|
67 |
+
pretrained_weights = checkpoint['state_dict']
|
68 |
+
model_weights = model_ft.state_dict()
|
69 |
+
pretrained_weights = {k: v for k, v in pretrained_weights.items() \
|
70 |
+
if k in model_weights}
|
71 |
+
model_weights.update(pretrained_weights)
|
72 |
+
model_ft.load_state_dict(model_weights)
|
73 |
+
|
74 |
+
model_ft = model_ft.to(device)
|
75 |
+
|
76 |
+
model_ft = eval_model(model_ft, dataloaders, dataset_sizes, writer, use_gpu, 1, 'val', CKPT_SAVE_PATH, NUM_LANDMARKS)
|
77 |
+
|
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/images/wflw.png
ADDED
![]() |
Git LFS Details
|
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/images/wflw_table.png
ADDED
![]() |
Git LFS Details
|
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/requirements.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
opencv-python
|
2 |
+
scipy>=0.17.0
|
3 |
+
scikit-image
|
4 |
+
numpy
|
5 |
+
matplotlib
|
6 |
+
Pillow>=4.3.0
|
7 |
+
imgaug
|
8 |
+
tensorflow
|
9 |
+
git+https://github.com/lanpa/tensorboardX
|
10 |
+
joblib
|
11 |
+
torch==1.3.0
|
12 |
+
torchvision==0.4.1
|
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/scripts/eval_wflw.sh
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CUDA_VISIBLE_DEVICES=1 python ../eval.py \
|
2 |
+
--val_img_dir='../dataset/WFLW_test/images/' \
|
3 |
+
--val_landmarks_dir='../dataset/WFLW_test/landmarks/' \
|
4 |
+
--ckpt_save_path='../experiments/eval_iccv_0620' \
|
5 |
+
--hg_blocks=4 \
|
6 |
+
--pretrained_weights='../ckpt/WFLW_4HG.pth' \
|
7 |
+
--num_landmarks=98 \
|
8 |
+
--end_relu='False' \
|
9 |
+
--batch_size=20 \
|
10 |
+
|
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/utils/__init__.py
ADDED
File without changes
|
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/utils/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (170 Bytes). View file
|
|
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/utils/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (185 Bytes). View file
|
|
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/utils/__pycache__/utils.cpython-37.pyc
ADDED
Binary file (11.8 kB). View file
|
|
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/utils/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (11.6 kB). View file
|
|
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/utils/utils.py
ADDED
@@ -0,0 +1,354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import print_function, division
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import math
|
5 |
+
import torch
|
6 |
+
import cv2
|
7 |
+
from PIL import Image
|
8 |
+
from skimage import io
|
9 |
+
from skimage import transform as ski_transform
|
10 |
+
from scipy import ndimage
|
11 |
+
import numpy as np
|
12 |
+
import matplotlib
|
13 |
+
import matplotlib.pyplot as plt
|
14 |
+
from torch.utils.data import Dataset, DataLoader
|
15 |
+
from torchvision import transforms, utils
|
16 |
+
|
17 |
+
def _gaussian(
|
18 |
+
size=3, sigma=0.25, amplitude=1, normalize=False, width=None,
|
19 |
+
height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5,
|
20 |
+
mean_vert=0.5):
|
21 |
+
# handle some defaults
|
22 |
+
if width is None:
|
23 |
+
width = size
|
24 |
+
if height is None:
|
25 |
+
height = size
|
26 |
+
if sigma_horz is None:
|
27 |
+
sigma_horz = sigma
|
28 |
+
if sigma_vert is None:
|
29 |
+
sigma_vert = sigma
|
30 |
+
center_x = mean_horz * width + 0.5
|
31 |
+
center_y = mean_vert * height + 0.5
|
32 |
+
gauss = np.empty((height, width), dtype=np.float32)
|
33 |
+
# generate kernel
|
34 |
+
for i in range(height):
|
35 |
+
for j in range(width):
|
36 |
+
gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / (
|
37 |
+
sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0))
|
38 |
+
if normalize:
|
39 |
+
gauss = gauss / np.sum(gauss)
|
40 |
+
return gauss
|
41 |
+
|
42 |
+
def draw_gaussian(image, point, sigma):
|
43 |
+
# Check if the gaussian is inside
|
44 |
+
ul = [np.floor(np.floor(point[0]) - 3 * sigma),
|
45 |
+
np.floor(np.floor(point[1]) - 3 * sigma)]
|
46 |
+
br = [np.floor(np.floor(point[0]) + 3 * sigma),
|
47 |
+
np.floor(np.floor(point[1]) + 3 * sigma)]
|
48 |
+
if (ul[0] > image.shape[1] or ul[1] >
|
49 |
+
image.shape[0] or br[0] < 1 or br[1] < 1):
|
50 |
+
return image
|
51 |
+
size = 6 * sigma + 1
|
52 |
+
g = _gaussian(size)
|
53 |
+
g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) -
|
54 |
+
int(max(1, ul[0])) + int(max(1, -ul[0]))]
|
55 |
+
g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) -
|
56 |
+
int(max(1, ul[1])) + int(max(1, -ul[1]))]
|
57 |
+
img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))]
|
58 |
+
img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))]
|
59 |
+
assert (g_x[0] > 0 and g_y[1] > 0)
|
60 |
+
correct = False
|
61 |
+
while not correct:
|
62 |
+
try:
|
63 |
+
image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]
|
64 |
+
] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]]
|
65 |
+
correct = True
|
66 |
+
except:
|
67 |
+
print('img_x: {}, img_y: {}, g_x:{}, g_y:{}, point:{}, g_shape:{}, ul:{}, br:{}'.format(img_x, img_y, g_x, g_y, point, g.shape, ul, br))
|
68 |
+
ul = [np.floor(np.floor(point[0]) - 3 * sigma),
|
69 |
+
np.floor(np.floor(point[1]) - 3 * sigma)]
|
70 |
+
br = [np.floor(np.floor(point[0]) + 3 * sigma),
|
71 |
+
np.floor(np.floor(point[1]) + 3 * sigma)]
|
72 |
+
g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) -
|
73 |
+
int(max(1, ul[0])) + int(max(1, -ul[0]))]
|
74 |
+
g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) -
|
75 |
+
int(max(1, ul[1])) + int(max(1, -ul[1]))]
|
76 |
+
img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))]
|
77 |
+
img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))]
|
78 |
+
pass
|
79 |
+
image[image > 1] = 1
|
80 |
+
return image
|
81 |
+
|
82 |
+
def transform(point, center, scale, resolution, rotation=0, invert=False):
|
83 |
+
_pt = np.ones(3)
|
84 |
+
_pt[0] = point[0]
|
85 |
+
_pt[1] = point[1]
|
86 |
+
|
87 |
+
h = 200.0 * scale
|
88 |
+
t = np.eye(3)
|
89 |
+
t[0, 0] = resolution / h
|
90 |
+
t[1, 1] = resolution / h
|
91 |
+
t[0, 2] = resolution * (-center[0] / h + 0.5)
|
92 |
+
t[1, 2] = resolution * (-center[1] / h + 0.5)
|
93 |
+
|
94 |
+
if rotation != 0:
|
95 |
+
rotation = -rotation
|
96 |
+
r = np.eye(3)
|
97 |
+
ang = rotation * math.pi / 180.0
|
98 |
+
s = math.sin(ang)
|
99 |
+
c = math.cos(ang)
|
100 |
+
r[0][0] = c
|
101 |
+
r[0][1] = -s
|
102 |
+
r[1][0] = s
|
103 |
+
r[1][1] = c
|
104 |
+
|
105 |
+
t_ = np.eye(3)
|
106 |
+
t_[0][2] = -resolution / 2.0
|
107 |
+
t_[1][2] = -resolution / 2.0
|
108 |
+
t_inv = torch.eye(3)
|
109 |
+
t_inv[0][2] = resolution / 2.0
|
110 |
+
t_inv[1][2] = resolution / 2.0
|
111 |
+
t = reduce(np.matmul, [t_inv, r, t_, t])
|
112 |
+
|
113 |
+
if invert:
|
114 |
+
t = np.linalg.inv(t)
|
115 |
+
new_point = (np.matmul(t, _pt))[0:2]
|
116 |
+
|
117 |
+
return new_point.astype(int)
|
118 |
+
|
119 |
+
def cv_crop(image, landmarks, center, scale, resolution=256, center_shift=0):
|
120 |
+
new_image = cv2.copyMakeBorder(image, center_shift,
|
121 |
+
center_shift,
|
122 |
+
center_shift,
|
123 |
+
center_shift,
|
124 |
+
cv2.BORDER_CONSTANT, value=[0,0,0])
|
125 |
+
new_landmarks = landmarks.copy()
|
126 |
+
if center_shift != 0:
|
127 |
+
center[0] += center_shift
|
128 |
+
center[1] += center_shift
|
129 |
+
new_landmarks = new_landmarks + center_shift
|
130 |
+
length = 200 * scale
|
131 |
+
top = int(center[1] - length // 2)
|
132 |
+
bottom = int(center[1] + length // 2)
|
133 |
+
left = int(center[0] - length // 2)
|
134 |
+
right = int(center[0] + length // 2)
|
135 |
+
y_pad = abs(min(top, new_image.shape[0] - bottom, 0))
|
136 |
+
x_pad = abs(min(left, new_image.shape[1] - right, 0))
|
137 |
+
top, bottom, left, right = top + y_pad, bottom + y_pad, left + x_pad, right + x_pad
|
138 |
+
new_image = cv2.copyMakeBorder(new_image, y_pad,
|
139 |
+
y_pad,
|
140 |
+
x_pad,
|
141 |
+
x_pad,
|
142 |
+
cv2.BORDER_CONSTANT, value=[0,0,0])
|
143 |
+
new_image = new_image[top:bottom, left:right]
|
144 |
+
new_image = cv2.resize(new_image, dsize=(int(resolution), int(resolution)),
|
145 |
+
interpolation=cv2.INTER_LINEAR)
|
146 |
+
new_landmarks[:, 0] = (new_landmarks[:, 0] + x_pad - left) * resolution / length
|
147 |
+
new_landmarks[:, 1] = (new_landmarks[:, 1] + y_pad - top) * resolution / length
|
148 |
+
return new_image, new_landmarks
|
149 |
+
|
150 |
+
def cv_rotate(image, landmarks, heatmap, rot, scale, resolution=256):
|
151 |
+
img_mat = cv2.getRotationMatrix2D((resolution//2, resolution//2), rot, scale)
|
152 |
+
ones = np.ones(shape=(landmarks.shape[0], 1))
|
153 |
+
stacked_landmarks = np.hstack([landmarks, ones])
|
154 |
+
new_landmarks = img_mat.dot(stacked_landmarks.T).T
|
155 |
+
if np.max(new_landmarks) > 255 or np.min(new_landmarks) < 0:
|
156 |
+
return image, landmarks, heatmap
|
157 |
+
else:
|
158 |
+
new_image = cv2.warpAffine(image, img_mat, (resolution, resolution))
|
159 |
+
if heatmap is not None:
|
160 |
+
new_heatmap = np.zeros((heatmap.shape[0], 64, 64))
|
161 |
+
for i in range(heatmap.shape[0]):
|
162 |
+
if new_landmarks[i][0] > 0:
|
163 |
+
new_heatmap[i] = draw_gaussian(new_heatmap[i],
|
164 |
+
new_landmarks[i]/4.0+1, 1)
|
165 |
+
return new_image, new_landmarks, new_heatmap
|
166 |
+
|
167 |
+
def show_landmarks(image, heatmap, gt_landmarks, gt_heatmap):
|
168 |
+
"""Show image with pred_landmarks"""
|
169 |
+
pred_landmarks = []
|
170 |
+
pred_landmarks, _ = get_preds_fromhm(torch.from_numpy(heatmap).unsqueeze(0))
|
171 |
+
pred_landmarks = pred_landmarks.squeeze()*4
|
172 |
+
|
173 |
+
# pred_landmarks2 = get_preds_fromhm2(heatmap)
|
174 |
+
heatmap = np.max(gt_heatmap, axis=0)
|
175 |
+
heatmap = heatmap / np.max(heatmap)
|
176 |
+
# image = ski_transform.resize(image, (64, 64))*255
|
177 |
+
image = image.astype(np.uint8)
|
178 |
+
heatmap = np.max(gt_heatmap, axis=0)
|
179 |
+
heatmap = ski_transform.resize(heatmap, (image.shape[0], image.shape[1]))
|
180 |
+
heatmap *= 255
|
181 |
+
heatmap = heatmap.astype(np.uint8)
|
182 |
+
heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
|
183 |
+
plt.imshow(image)
|
184 |
+
plt.scatter(gt_landmarks[:, 0], gt_landmarks[:, 1], s=0.5, marker='.', c='g')
|
185 |
+
plt.scatter(pred_landmarks[:, 0], pred_landmarks[:, 1], s=0.5, marker='.', c='r')
|
186 |
+
plt.pause(0.001) # pause a bit so that plots are updated
|
187 |
+
|
188 |
+
def fan_NME(pred_heatmaps, gt_landmarks, num_landmarks=68):
|
189 |
+
'''
|
190 |
+
Calculate total NME for a batch of data
|
191 |
+
|
192 |
+
Args:
|
193 |
+
pred_heatmaps: torch tensor of size [batch, points, height, width]
|
194 |
+
gt_landmarks: torch tesnsor of size [batch, points, x, y]
|
195 |
+
|
196 |
+
Returns:
|
197 |
+
nme: sum of nme for this batch
|
198 |
+
'''
|
199 |
+
nme = 0
|
200 |
+
pred_landmarks, _ = get_preds_fromhm(pred_heatmaps)
|
201 |
+
pred_landmarks = pred_landmarks.numpy()
|
202 |
+
gt_landmarks = gt_landmarks.numpy()
|
203 |
+
for i in range(pred_landmarks.shape[0]):
|
204 |
+
pred_landmark = pred_landmarks[i] * 4.0
|
205 |
+
gt_landmark = gt_landmarks[i]
|
206 |
+
|
207 |
+
if num_landmarks == 68:
|
208 |
+
left_eye = np.average(gt_landmark[36:42], axis=0)
|
209 |
+
right_eye = np.average(gt_landmark[42:48], axis=0)
|
210 |
+
norm_factor = np.linalg.norm(left_eye - right_eye)
|
211 |
+
# norm_factor = np.linalg.norm(gt_landmark[36]- gt_landmark[45])
|
212 |
+
elif num_landmarks == 98:
|
213 |
+
norm_factor = np.linalg.norm(gt_landmark[60]- gt_landmark[72])
|
214 |
+
elif num_landmarks == 19:
|
215 |
+
left, top = gt_landmark[-2, :]
|
216 |
+
right, bottom = gt_landmark[-1, :]
|
217 |
+
norm_factor = math.sqrt(abs(right - left)*abs(top-bottom))
|
218 |
+
gt_landmark = gt_landmark[:-2, :]
|
219 |
+
elif num_landmarks == 29:
|
220 |
+
# norm_factor = np.linalg.norm(gt_landmark[8]- gt_landmark[9])
|
221 |
+
norm_factor = np.linalg.norm(gt_landmark[16]- gt_landmark[17])
|
222 |
+
nme += (np.sum(np.linalg.norm(pred_landmark - gt_landmark, axis=1)) / pred_landmark.shape[0]) / norm_factor
|
223 |
+
return nme
|
224 |
+
|
225 |
+
def fan_NME_hm(pred_heatmaps, gt_heatmaps, num_landmarks=68):
|
226 |
+
'''
|
227 |
+
Calculate total NME for a batch of data
|
228 |
+
|
229 |
+
Args:
|
230 |
+
pred_heatmaps: torch tensor of size [batch, points, height, width]
|
231 |
+
gt_landmarks: torch tesnsor of size [batch, points, x, y]
|
232 |
+
|
233 |
+
Returns:
|
234 |
+
nme: sum of nme for this batch
|
235 |
+
'''
|
236 |
+
nme = 0
|
237 |
+
pred_landmarks, _ = get_index_fromhm(pred_heatmaps)
|
238 |
+
pred_landmarks = pred_landmarks.numpy()
|
239 |
+
gt_landmarks = gt_landmarks.numpy()
|
240 |
+
for i in range(pred_landmarks.shape[0]):
|
241 |
+
pred_landmark = pred_landmarks[i] * 4.0
|
242 |
+
gt_landmark = gt_landmarks[i]
|
243 |
+
if num_landmarks == 68:
|
244 |
+
left_eye = np.average(gt_landmark[36:42], axis=0)
|
245 |
+
right_eye = np.average(gt_landmark[42:48], axis=0)
|
246 |
+
norm_factor = np.linalg.norm(left_eye - right_eye)
|
247 |
+
else:
|
248 |
+
norm_factor = np.linalg.norm(gt_landmark[60]- gt_landmark[72])
|
249 |
+
nme += (np.sum(np.linalg.norm(pred_landmark - gt_landmark, axis=1)) / pred_landmark.shape[0]) / norm_factor
|
250 |
+
return nme
|
251 |
+
|
252 |
+
def power_transform(img, power):
|
253 |
+
img = np.array(img)
|
254 |
+
img_new = np.power((img/255.0), power) * 255.0
|
255 |
+
img_new = img_new.astype(np.uint8)
|
256 |
+
img_new = Image.fromarray(img_new)
|
257 |
+
return img_new
|
258 |
+
|
259 |
+
def get_preds_fromhm(hm, center=None, scale=None, rot=None):
|
260 |
+
max, idx = torch.max(
|
261 |
+
hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
|
262 |
+
idx += 1
|
263 |
+
preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
|
264 |
+
preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
|
265 |
+
preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
|
266 |
+
|
267 |
+
for i in range(preds.size(0)):
|
268 |
+
for j in range(preds.size(1)):
|
269 |
+
hm_ = hm[i, j, :]
|
270 |
+
pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
|
271 |
+
if pX > 0 and pX < 63 and pY > 0 and pY < 63:
|
272 |
+
diff = torch.FloatTensor(
|
273 |
+
[hm_[pY, pX + 1] - hm_[pY, pX - 1],
|
274 |
+
hm_[pY + 1, pX] - hm_[pY - 1, pX]])
|
275 |
+
preds[i, j].add_(diff.sign_().mul_(.25))
|
276 |
+
|
277 |
+
preds.add_(-0.5)
|
278 |
+
|
279 |
+
preds_orig = torch.zeros(preds.size())
|
280 |
+
if center is not None and scale is not None:
|
281 |
+
for i in range(hm.size(0)):
|
282 |
+
for j in range(hm.size(1)):
|
283 |
+
preds_orig[i, j] = transform(
|
284 |
+
preds[i, j], center, scale, hm.size(2), rot, True)
|
285 |
+
|
286 |
+
return preds, preds_orig
|
287 |
+
|
288 |
+
def get_index_fromhm(hm):
|
289 |
+
max, idx = torch.max(
|
290 |
+
hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
|
291 |
+
preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
|
292 |
+
preds[..., 0].remainder_(hm.size(3))
|
293 |
+
preds[..., 1].div_(hm.size(2)).floor_()
|
294 |
+
|
295 |
+
for i in range(preds.size(0)):
|
296 |
+
for j in range(preds.size(1)):
|
297 |
+
hm_ = hm[i, j, :]
|
298 |
+
pX, pY = int(preds[i, j, 0]), int(preds[i, j, 1])
|
299 |
+
if pX > 0 and pX < 63 and pY > 0 and pY < 63:
|
300 |
+
diff = torch.FloatTensor(
|
301 |
+
[hm_[pY, pX + 1] - hm_[pY, pX - 1],
|
302 |
+
hm_[pY + 1, pX] - hm_[pY - 1, pX]])
|
303 |
+
preds[i, j].add_(diff.sign_().mul_(.25))
|
304 |
+
|
305 |
+
return preds
|
306 |
+
|
307 |
+
def shuffle_lr(parts, num_landmarks=68, pairs=None):
|
308 |
+
if num_landmarks == 68:
|
309 |
+
if pairs is None:
|
310 |
+
pairs = [[0, 16], [1, 15], [2, 14], [3, 13], [4, 12], [5, 11], [6, 10],
|
311 |
+
[7, 9], [17, 26], [18, 25], [19, 24], [20, 23], [21, 22], [36, 45],
|
312 |
+
[37, 44], [38, 43], [39, 42], [41, 46], [40, 47], [31, 35], [32, 34],
|
313 |
+
[50, 52], [49, 53], [48, 54], [61, 63], [60, 64], [67, 65], [59, 55], [58, 56]]
|
314 |
+
elif num_landmarks == 98:
|
315 |
+
if pairs is None:
|
316 |
+
pairs = [[0, 32], [1,31], [2, 30], [3, 29], [4, 28], [5, 27], [6, 26], [7, 25], [8, 24], [9, 23], [10, 22], [11, 21], [12, 20], [13, 19], [14, 18], [15, 17], [33, 46], [34, 45], [35, 44], [36, 43], [37, 42], [38, 50], [39, 49], [40, 48], [41, 47], [60, 72], [61, 71], [62, 70], [63, 69], [64, 68], [65, 75], [66, 74], [67, 73], [96, 97], [55, 59], [56, 58], [76, 82], [77, 81], [78, 80], [88, 92], [89, 91], [95, 93], [87, 83], [86, 84]]
|
317 |
+
elif num_landmarks == 19:
|
318 |
+
if pairs is None:
|
319 |
+
pairs = [[0, 5], [1, 4], [2, 3], [6, 11], [7, 10], [8, 9], [12, 14], [15, 17]]
|
320 |
+
elif num_landmarks == 29:
|
321 |
+
if pairs is None:
|
322 |
+
pairs = [[0, 1], [4, 6], [5, 7], [2, 3], [8, 9], [12, 14], [16, 17], [13, 15], [10, 11], [18, 19], [22, 23]]
|
323 |
+
for matched_p in pairs:
|
324 |
+
idx1, idx2 = matched_p[0], matched_p[1]
|
325 |
+
tmp = np.copy(parts[idx1])
|
326 |
+
np.copyto(parts[idx1], parts[idx2])
|
327 |
+
np.copyto(parts[idx2], tmp)
|
328 |
+
return parts
|
329 |
+
|
330 |
+
|
331 |
+
def generate_weight_map(weight_map,heatmap):
|
332 |
+
|
333 |
+
k_size = 3
|
334 |
+
dilate = ndimage.grey_dilation(heatmap ,size=(k_size,k_size))
|
335 |
+
weight_map[np.where(dilate>0.2)] = 1
|
336 |
+
return weight_map
|
337 |
+
|
338 |
+
def fig2data(fig):
|
339 |
+
"""
|
340 |
+
@brief Convert a Matplotlib figure to a 4D numpy array with RGBA channels and return it
|
341 |
+
@param fig a matplotlib figure
|
342 |
+
@return a numpy 3D array of RGBA values
|
343 |
+
"""
|
344 |
+
# draw the renderer
|
345 |
+
fig.canvas.draw ( )
|
346 |
+
|
347 |
+
# Get the RGB buffer from the figure
|
348 |
+
w,h = fig.canvas.get_width_height()
|
349 |
+
buf = np.fromstring (fig.canvas.tostring_rgb(), dtype=np.uint8)
|
350 |
+
buf.shape = (w, h, 3)
|
351 |
+
|
352 |
+
# canvas.tostring_argb give pixmap in ARGB mode. Roll the ALPHA channel to have it in RGBA mode
|
353 |
+
buf = np.roll (buf, 3, axis=2)
|
354 |
+
return buf
|
marlenezw/audio-driven-animations/MakeItTalk/__init__.py
ADDED
File without changes
|
marlenezw/audio-driven-animations/MakeItTalk/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (147 Bytes). View file
|
|
marlenezw/audio-driven-animations/MakeItTalk/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (162 Bytes). View file
|
|
marlenezw/audio-driven-animations/MakeItTalk/face_of_art/CODEOWNERS
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
* @papulke
|
marlenezw/audio-driven-animations/MakeItTalk/face_of_art/LICENCE.txt
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2019 Jordan Yaniv
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
16 |
+
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
17 |
+
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
18 |
+
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
|
19 |
+
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
|
20 |
+
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
|
21 |
+
OR OTHER DEALINGS IN THE SOFTWARE.
|
marlenezw/audio-driven-animations/MakeItTalk/face_of_art/README.md
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# The Face of Art: Landmark Detection and Geometric Style in Portraits
|
2 |
+
|
3 |
+
Code for the landmark detection framework described in [The Face of Art: Landmark Detection and Geometric Style in Portraits](http://www.faculty.idc.ac.il/arik/site/foa/face-of-art.asp) (SIGGRAPH 2019)
|
4 |
+
|
5 |
+

|
6 |
+
<sub><sup>Top: landmark detection results on artistic portraits with different styles allows to define the geometric style of an artist. Bottom: results of the style transfer of portraits using various artists' geometric style, including Amedeo Modigliani, Pablo Picasso, Margaret Keane, Fernand Léger, and Tsuguharu Foujita. Top right portrait is from 'Woman with Peanuts,' ©1962, Estate of Roy Lichtenstein.</sup></sub>
|
7 |
+
|
8 |
+
## Getting Started
|
9 |
+
|
10 |
+
### Requirements
|
11 |
+
|
12 |
+
* python
|
13 |
+
* anaconda
|
14 |
+
|
15 |
+
### Download
|
16 |
+
|
17 |
+
#### Model
|
18 |
+
download model weights from [here](https://www.dropbox.com/sh/hrxcyug1bmbj6cs/AAAxq_zI5eawcLjM8zvUwaXha?dl=0).
|
19 |
+
|
20 |
+
#### Datasets
|
21 |
+
* The datasets used for training and evaluating our model can be found [here](https://ibug.doc.ic.ac.uk/resources/facial-point-annotations/).
|
22 |
+
|
23 |
+
* The Artistic-Faces dataset can be found [here](http://www.faculty.idc.ac.il/arik/site/foa/artistic-faces-dataset.asp).
|
24 |
+
|
25 |
+
* Training images with texture augmentation can be found [here](https://www.dropbox.com/sh/av2k1i1082z0nie/AAC5qV1E2UkqpDLVsv7TazMta?dl=0).
|
26 |
+
before applying texture style transfer, the training images were cropped to the ground-truth face bounding-box with 25% margin. To crop training images, run the script `crop_training_set.py`.
|
27 |
+
|
28 |
+
* our model expects the following directory structure of landmark detection datasets:
|
29 |
+
```
|
30 |
+
landmark_detection_datasets
|
31 |
+
├── training
|
32 |
+
├── test
|
33 |
+
├── challenging
|
34 |
+
├── common
|
35 |
+
├── full
|
36 |
+
├── crop_gt_margin_0.25 (cropped images of training set)
|
37 |
+
└── crop_gt_margin_0.25_ns (cropped images of training set + texture style transfer)
|
38 |
+
```
|
39 |
+
### Install
|
40 |
+
|
41 |
+
Create a virtual environment and install the following:
|
42 |
+
* opencv
|
43 |
+
* menpo
|
44 |
+
* menpofit
|
45 |
+
* tensorflow-gpu
|
46 |
+
|
47 |
+
for python 2:
|
48 |
+
```
|
49 |
+
conda create -n foa_env python=2.7 anaconda
|
50 |
+
source activate foa_env
|
51 |
+
conda install -c menpo opencv
|
52 |
+
conda install -c menpo menpo
|
53 |
+
conda install -c menpo menpofit
|
54 |
+
pip install tensorflow-gpu
|
55 |
+
|
56 |
+
```
|
57 |
+
|
58 |
+
for python 3:
|
59 |
+
```
|
60 |
+
conda create -n foa_env python=3.5 anaconda
|
61 |
+
source activate foa_env
|
62 |
+
conda install -c menpo opencv
|
63 |
+
conda install -c menpo menpo
|
64 |
+
conda install -c menpo menpofit
|
65 |
+
pip3 install tensorflow-gpu
|
66 |
+
|
67 |
+
```
|
68 |
+
|
69 |
+
Clone repository:
|
70 |
+
|
71 |
+
```
|
72 |
+
git clone https://github.com/papulke/deep_face_heatmaps
|
73 |
+
```
|
74 |
+
|
75 |
+
## Instructions
|
76 |
+
|
77 |
+
### Training
|
78 |
+
|
79 |
+
To train the network you need to run `train_heatmaps_network.py`
|
80 |
+
|
81 |
+
example for training a model with texture augmentation (100% of images) and geometric augmentation (~70% of images):
|
82 |
+
```
|
83 |
+
python train_heatmaps_network.py --output_dir='test_artistic_aug' --augment_geom=True \
|
84 |
+
--augment_texture=True --p_texture=1. --p_geom=0.7
|
85 |
+
```
|
86 |
+
|
87 |
+
### Testing
|
88 |
+
|
89 |
+
For using the detection framework to predict landmarks, run the script `predict_landmarks.py`
|
90 |
+
|
91 |
+
## Acknowledgments
|
92 |
+
|
93 |
+
* [ect](https://github.com/HongwenZhang/ECT-FaceAlignment)
|
94 |
+
* [menpo](https://github.com/menpo/menpo)
|
95 |
+
* [menpofit](https://github.com/menpo/menpofit)
|
96 |
+
* [mdm](https://github.com/trigeorgis/mdm)
|
97 |
+
* [style transfer implementation](https://github.com/woodrush/neural-art-tf)
|
98 |
+
* [painter-by-numbers dataset](https://www.kaggle.com/c/painter-by-numbers/data)
|
marlenezw/audio-driven-animations/MakeItTalk/face_of_art/__init__.py
ADDED
File without changes
|
marlenezw/audio-driven-animations/MakeItTalk/face_of_art/__init__.pyc
ADDED
Binary file (161 Bytes). View file
|
|
marlenezw/audio-driven-animations/MakeItTalk/face_of_art/__pycache__/__init__.cpython-36.pyc
ADDED
Binary file (157 Bytes). View file
|
|
marlenezw/audio-driven-animations/MakeItTalk/face_of_art/__pycache__/data_loading_functions.cpython-36.pyc
ADDED
Binary file (4.56 kB). View file
|
|
marlenezw/audio-driven-animations/MakeItTalk/face_of_art/__pycache__/deep_heatmaps_model_fusion_net.cpython-36.pyc
ADDED
Binary file (21.6 kB). View file
|
|
marlenezw/audio-driven-animations/MakeItTalk/face_of_art/__pycache__/deformation_functions.cpython-36.pyc
ADDED
Binary file (9 kB). View file
|
|
marlenezw/audio-driven-animations/MakeItTalk/face_of_art/__pycache__/logging_functions.cpython-36.pyc
ADDED
Binary file (5.81 kB). View file
|
|
marlenezw/audio-driven-animations/MakeItTalk/face_of_art/__pycache__/menpo_functions.cpython-36.pyc
ADDED
Binary file (9.22 kB). View file
|
|
marlenezw/audio-driven-animations/MakeItTalk/face_of_art/__pycache__/ops.cpython-36.pyc
ADDED
Binary file (3.6 kB). View file
|
|
marlenezw/audio-driven-animations/MakeItTalk/face_of_art/__pycache__/pdm_clm_functions.cpython-36.pyc
ADDED
Binary file (6.34 kB). View file
|
|
marlenezw/audio-driven-animations/MakeItTalk/face_of_art/crop_training_set.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from scipy.misc import imsave
|
2 |
+
from menpo_functions import *
|
3 |
+
from data_loading_functions import *
|
4 |
+
|
5 |
+
|
6 |
+
# define paths & parameters for cropping dataset
|
7 |
+
img_dir = '~/landmark_detection_datasets/'
|
8 |
+
dataset = 'training'
|
9 |
+
bb_type = 'gt'
|
10 |
+
margin = 0.25
|
11 |
+
image_size = 256
|
12 |
+
|
13 |
+
# load bounding boxes
|
14 |
+
bb_dir = os.path.join(img_dir, 'Bounding_Boxes')
|
15 |
+
bb_dictionary = load_bb_dictionary(bb_dir, mode='TRAIN', test_data=dataset)
|
16 |
+
|
17 |
+
# directory for saving face crops
|
18 |
+
outdir = os.path.join(img_dir, 'crop_'+bb_type+'_margin_'+str(margin))
|
19 |
+
if not os.path.exists(outdir):
|
20 |
+
os.mkdir(outdir)
|
21 |
+
|
22 |
+
# load images
|
23 |
+
imgs_to_crop = load_menpo_image_list(
|
24 |
+
img_dir=img_dir, train_crop_dir=None, img_dir_ns=None, mode='TRAIN', bb_dictionary=bb_dictionary,
|
25 |
+
image_size=image_size, margin=margin, bb_type=bb_type, augment_basic=False)
|
26 |
+
|
27 |
+
# save cropped images with matching landmarks
|
28 |
+
print ("\ncropping dataset from: "+os.path.join(img_dir, dataset))
|
29 |
+
print ("\nsaving cropped dataset to: "+outdir)
|
30 |
+
for im in imgs_to_crop:
|
31 |
+
if im.pixels.shape[0] == 1:
|
32 |
+
im_pixels = gray2rgb(np.squeeze(im.pixels))
|
33 |
+
else:
|
34 |
+
im_pixels = np.rollaxis(im.pixels, 0, 3)
|
35 |
+
imsave(os.path.join(outdir, im.path.name.split('.')[0]+'.png'), im_pixels)
|
36 |
+
mio.export_landmark_file(im.landmarks['PTS'], os.path.join(outdir, im.path.name.split('.')[0]+'.pts'))
|
37 |
+
|
38 |
+
print ("\ncropping dataset completed!")
|
marlenezw/audio-driven-animations/MakeItTalk/face_of_art/data_loading_functions.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import os
|
3 |
+
from skimage.color import gray2rgb
|
4 |
+
|
5 |
+
|
6 |
+
def train_val_shuffle_inds_per_epoch(valid_inds, train_inds, train_iter, batch_size, log_path, save_log=True):
|
7 |
+
"""shuffle image indices for each training epoch and save to log"""
|
8 |
+
|
9 |
+
np.random.seed(0)
|
10 |
+
num_train_images = len(train_inds)
|
11 |
+
num_epochs = int(np.ceil((1. * train_iter) / (1. * num_train_images / batch_size)))+1
|
12 |
+
epoch_inds_shuffle = np.zeros((num_epochs, num_train_images)).astype(int)
|
13 |
+
img_inds = np.arange(num_train_images)
|
14 |
+
for i in range(num_epochs):
|
15 |
+
np.random.shuffle(img_inds)
|
16 |
+
epoch_inds_shuffle[i, :] = img_inds
|
17 |
+
|
18 |
+
if save_log:
|
19 |
+
with open(os.path.join(log_path, "train_val_shuffle_inds.csv"), "wb") as f:
|
20 |
+
if valid_inds is not None:
|
21 |
+
f.write(b'valid inds\n')
|
22 |
+
np.savetxt(f, valid_inds.reshape(1, -1), fmt='%i', delimiter=",")
|
23 |
+
f.write(b'train inds\n')
|
24 |
+
np.savetxt(f, train_inds.reshape(1, -1), fmt='%i', delimiter=",")
|
25 |
+
f.write(b'shuffle inds\n')
|
26 |
+
np.savetxt(f, epoch_inds_shuffle, fmt='%i', delimiter=",")
|
27 |
+
|
28 |
+
return epoch_inds_shuffle
|
29 |
+
|
30 |
+
|
31 |
+
def gaussian(x, y, x0, y0, sigma=6):
|
32 |
+
return 1./(np.sqrt(2*np.pi)*sigma) * np.exp(-0.5 * ((x-x0)**2 + (y-y0)**2) / sigma**2)
|
33 |
+
|
34 |
+
|
35 |
+
def create_gaussian_filter(sigma=6, win_mult=3.5):
|
36 |
+
win_size = int(win_mult * sigma)
|
37 |
+
x, y = np.mgrid[0:2*win_size+1, 0:2*win_size+1]
|
38 |
+
gauss_filt = (8./3)*sigma*gaussian(x, y, win_size, win_size, sigma=sigma) # same as in ECT
|
39 |
+
return gauss_filt
|
40 |
+
|
41 |
+
|
42 |
+
def load_images(img_list, batch_inds, image_size=256, c_dim=3, scale=255):
|
43 |
+
|
44 |
+
""" load images as a numpy array from menpo image list """
|
45 |
+
|
46 |
+
num_inputs = len(batch_inds)
|
47 |
+
batch_menpo_images = img_list[batch_inds]
|
48 |
+
|
49 |
+
images = np.zeros([num_inputs, image_size, image_size, c_dim]).astype('float32')
|
50 |
+
|
51 |
+
for ind, img in enumerate(batch_menpo_images):
|
52 |
+
if img.n_channels < 3 and c_dim == 3:
|
53 |
+
images[ind, :, :, :] = gray2rgb(img.pixels_with_channels_at_back())
|
54 |
+
else:
|
55 |
+
images[ind, :, :, :] = img.pixels_with_channels_at_back()
|
56 |
+
|
57 |
+
if scale is 255:
|
58 |
+
images *= 255
|
59 |
+
elif scale is 0:
|
60 |
+
images = 2 * images - 1
|
61 |
+
|
62 |
+
return images
|
63 |
+
|
64 |
+
|
65 |
+
# loading functions with pre-allocation and approx heat-map generation
|
66 |
+
|
67 |
+
|
68 |
+
def create_approx_heat_maps_alloc_once(landmarks, maps, gauss_filt=None, win_mult=3.5, num_landmarks=68, image_size=256,
|
69 |
+
sigma=6):
|
70 |
+
""" create heatmaps from input landmarks"""
|
71 |
+
maps.fill(0.)
|
72 |
+
|
73 |
+
win_size = int(win_mult * sigma)
|
74 |
+
filt_size = 2 * win_size + 1
|
75 |
+
landmarks = landmarks.astype(int)
|
76 |
+
|
77 |
+
if gauss_filt is None:
|
78 |
+
x_small, y_small = np.mgrid[0:2 * win_size + 1, 0:2 * win_size + 1]
|
79 |
+
gauss_filt = (8. / 3) * sigma * gaussian(x_small, y_small, win_size, win_size, sigma=sigma) # same as in ECT
|
80 |
+
|
81 |
+
for i in range(num_landmarks):
|
82 |
+
|
83 |
+
min_row = landmarks[i, 0] - win_size
|
84 |
+
max_row = landmarks[i, 0] + win_size + 1
|
85 |
+
min_col = landmarks[i, 1] - win_size
|
86 |
+
max_col = landmarks[i, 1] + win_size + 1
|
87 |
+
|
88 |
+
if min_row < 0:
|
89 |
+
min_row_gap = -1 * min_row
|
90 |
+
min_row = 0
|
91 |
+
else:
|
92 |
+
min_row_gap = 0
|
93 |
+
|
94 |
+
if min_col < 0:
|
95 |
+
min_col_gap = -1 * min_col
|
96 |
+
min_col = 0
|
97 |
+
else:
|
98 |
+
min_col_gap = 0
|
99 |
+
|
100 |
+
if max_row > image_size:
|
101 |
+
max_row_gap = max_row - image_size
|
102 |
+
max_row = image_size
|
103 |
+
else:
|
104 |
+
max_row_gap = 0
|
105 |
+
|
106 |
+
if max_col > image_size:
|
107 |
+
max_col_gap = max_col - image_size
|
108 |
+
max_col = image_size
|
109 |
+
else:
|
110 |
+
max_col_gap = 0
|
111 |
+
|
112 |
+
maps[min_row:max_row, min_col:max_col, i] =\
|
113 |
+
gauss_filt[min_row_gap:filt_size - 1 * max_row_gap, min_col_gap:filt_size - 1 * max_col_gap]
|
114 |
+
|
115 |
+
|
116 |
+
def load_images_landmarks_approx_maps_alloc_once(
|
117 |
+
img_list, batch_inds, images, maps_small, maps, landmarks, image_size=256, num_landmarks=68,
|
118 |
+
scale=255, gauss_filt_large=None, gauss_filt_small=None, win_mult=3.5, sigma=6, save_landmarks=False):
|
119 |
+
|
120 |
+
""" load images and gt landmarks from menpo image list, and create matching heatmaps """
|
121 |
+
|
122 |
+
batch_menpo_images = img_list[batch_inds]
|
123 |
+
c_dim = images.shape[-1]
|
124 |
+
grp_name = batch_menpo_images[0].landmarks.group_labels[0]
|
125 |
+
|
126 |
+
win_size_large = int(win_mult * sigma)
|
127 |
+
win_size_small = int(win_mult * (1.*sigma/4))
|
128 |
+
|
129 |
+
if gauss_filt_small is None:
|
130 |
+
x_small, y_small = np.mgrid[0:2 * win_size_small + 1, 0:2 * win_size_small + 1]
|
131 |
+
gauss_filt_small = (8. / 3) * (1.*sigma/4) * gaussian(
|
132 |
+
x_small, y_small, win_size_small, win_size_small, sigma=1.*sigma/4) # same as in ECT
|
133 |
+
if gauss_filt_large is None:
|
134 |
+
x_large, y_large = np.mgrid[0:2 * win_size_large + 1, 0:2 * win_size_large + 1]
|
135 |
+
gauss_filt_large = (8. / 3) * sigma * gaussian(x_large, y_large, win_size_large, win_size_large, sigma=sigma) # same as in ECT
|
136 |
+
|
137 |
+
for ind, img in enumerate(batch_menpo_images):
|
138 |
+
if img.n_channels < 3 and c_dim == 3:
|
139 |
+
images[ind, :, :, :] = gray2rgb(img.pixels_with_channels_at_back())
|
140 |
+
else:
|
141 |
+
images[ind, :, :, :] = img.pixels_with_channels_at_back()
|
142 |
+
|
143 |
+
lms = img.landmarks[grp_name].points
|
144 |
+
lms = np.minimum(lms, image_size - 1)
|
145 |
+
create_approx_heat_maps_alloc_once(
|
146 |
+
landmarks=lms, maps=maps[ind, :, :, :], gauss_filt=gauss_filt_large, win_mult=win_mult,
|
147 |
+
num_landmarks=num_landmarks, image_size=image_size, sigma=sigma)
|
148 |
+
|
149 |
+
lms_small = img.resize([image_size / 4, image_size / 4]).landmarks[grp_name].points
|
150 |
+
lms_small = np.minimum(lms_small, image_size / 4 - 1)
|
151 |
+
create_approx_heat_maps_alloc_once(
|
152 |
+
landmarks=lms_small, maps=maps_small[ind, :, :, :], gauss_filt=gauss_filt_small, win_mult=win_mult,
|
153 |
+
num_landmarks=num_landmarks, image_size=image_size / 4, sigma=1. * sigma / 4)
|
154 |
+
|
155 |
+
if save_landmarks:
|
156 |
+
landmarks[ind, :, :] = lms
|
157 |
+
|
158 |
+
if scale is 255:
|
159 |
+
images *= 255
|
160 |
+
elif scale is 0:
|
161 |
+
images = 2 * images - 1
|
marlenezw/audio-driven-animations/MakeItTalk/face_of_art/data_loading_functions.pyc
ADDED
Binary file (5.95 kB). View file
|
|
marlenezw/audio-driven-animations/MakeItTalk/face_of_art/deep_heatmaps_model_fusion_net.py
ADDED
@@ -0,0 +1,872 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import scipy.io
|
2 |
+
import scipy.misc
|
3 |
+
from glob import glob
|
4 |
+
import os
|
5 |
+
import numpy as np
|
6 |
+
from thirdparty.face_of_art.ops import *
|
7 |
+
import tensorflow as tf
|
8 |
+
from tensorflow import contrib
|
9 |
+
from thirdparty.face_of_art.menpo_functions import *
|
10 |
+
from thirdparty.face_of_art.logging_functions import *
|
11 |
+
from thirdparty.face_of_art.data_loading_functions import *
|
12 |
+
|
13 |
+
|
14 |
+
class DeepHeatmapsModel(object):
|
15 |
+
|
16 |
+
"""facial landmark localization Network"""
|
17 |
+
|
18 |
+
def __init__(self, mode='TRAIN', train_iter=100000, batch_size=10, learning_rate=1e-3, l_weight_primary=1.,
|
19 |
+
l_weight_fusion=1.,l_weight_upsample=3.,adam_optimizer=True,momentum=0.95,step=100000, gamma=0.1,reg=0,
|
20 |
+
weight_initializer='xavier', weight_initializer_std=0.01, bias_initializer=0.0, image_size=256,c_dim=3,
|
21 |
+
num_landmarks=68, sigma=1.5, scale=1, margin=0.25, bb_type='gt', win_mult=3.33335,
|
22 |
+
augment_basic=True,augment_texture=False, p_texture=0., augment_geom=False, p_geom=0.,
|
23 |
+
output_dir='output', save_model_path='model',
|
24 |
+
save_sample_path='sample', save_log_path='logs', test_model_path='model/deep_heatmaps-50000',
|
25 |
+
pre_train_path='model/deep_heatmaps-50000', load_pretrain=False, load_primary_only=False,
|
26 |
+
img_path='data', test_data='full', valid_data='full', valid_size=0, log_valid_every=5,
|
27 |
+
train_crop_dir='crop_gt_margin_0.25', img_dir_ns='crop_gt_margin_0.25_ns',
|
28 |
+
print_every=100, save_every=5000, sample_every=5000, sample_grid=9, sample_to_log=True,
|
29 |
+
debug_data_size=20, debug=False, epoch_data_dir='epoch_data', use_epoch_data=False, menpo_verbose=True):
|
30 |
+
|
31 |
+
# define some extra parameters
|
32 |
+
|
33 |
+
self.log_histograms = False # save weight + gradient histogram to log
|
34 |
+
self.save_valid_images = True # sample heat maps of validation images
|
35 |
+
self.sample_per_channel = False # sample heatmaps separately for each landmark
|
36 |
+
|
37 |
+
# for fine-tuning, choose reset_training_op==True. when resuming training, reset_training_op==False
|
38 |
+
self.reset_training_op = False
|
39 |
+
|
40 |
+
self.fast_img_gen = True
|
41 |
+
|
42 |
+
self.compute_nme = True # compute normalized mean error
|
43 |
+
|
44 |
+
self.config = tf.ConfigProto()
|
45 |
+
self.config.gpu_options.allow_growth = True
|
46 |
+
|
47 |
+
# sampling and logging parameters
|
48 |
+
self.print_every = print_every # print losses to screen + log
|
49 |
+
self.save_every = save_every # save model
|
50 |
+
self.sample_every = sample_every # save images of gen heat maps compared to GT
|
51 |
+
self.sample_grid = sample_grid # number of training images in sample
|
52 |
+
self.sample_to_log = sample_to_log # sample images to log instead of disk
|
53 |
+
self.log_valid_every = log_valid_every # log validation loss (in epochs)
|
54 |
+
|
55 |
+
self.debug = debug
|
56 |
+
self.debug_data_size = debug_data_size
|
57 |
+
self.use_epoch_data = use_epoch_data
|
58 |
+
self.epoch_data_dir = epoch_data_dir
|
59 |
+
|
60 |
+
self.load_pretrain = load_pretrain
|
61 |
+
self.load_primary_only = load_primary_only
|
62 |
+
self.pre_train_path = pre_train_path
|
63 |
+
|
64 |
+
self.mode = mode
|
65 |
+
self.train_iter = train_iter
|
66 |
+
self.learning_rate = learning_rate
|
67 |
+
|
68 |
+
self.image_size = image_size
|
69 |
+
self.c_dim = c_dim
|
70 |
+
self.batch_size = batch_size
|
71 |
+
|
72 |
+
self.num_landmarks = num_landmarks
|
73 |
+
|
74 |
+
self.save_log_path = save_log_path
|
75 |
+
self.save_sample_path = save_sample_path
|
76 |
+
self.save_model_path = save_model_path
|
77 |
+
self.test_model_path = test_model_path
|
78 |
+
self.img_path=img_path
|
79 |
+
|
80 |
+
self.momentum = momentum
|
81 |
+
self.step = step # for lr decay
|
82 |
+
self.gamma = gamma # for lr decay
|
83 |
+
self.reg = reg # weight decay scale
|
84 |
+
self.l_weight_primary = l_weight_primary # primary loss weight
|
85 |
+
self.l_weight_fusion = l_weight_fusion # fusion loss weight
|
86 |
+
self.l_weight_upsample = l_weight_upsample # upsample loss weight
|
87 |
+
|
88 |
+
self.weight_initializer = weight_initializer # random_normal or xavier
|
89 |
+
self.weight_initializer_std = weight_initializer_std
|
90 |
+
self.bias_initializer = bias_initializer
|
91 |
+
self.adam_optimizer = adam_optimizer
|
92 |
+
|
93 |
+
self.sigma = sigma # sigma for heatmap generation
|
94 |
+
self.scale = scale # scale for image normalization 255 / 1 / 0
|
95 |
+
self.win_mult = win_mult # gaussian filter size for cpu/gpu approximation: 2 * sigma * win_mult + 1
|
96 |
+
|
97 |
+
self.test_data = test_data # if mode is TEST, this choose the set to use full/common/challenging/test/art
|
98 |
+
self.train_crop_dir = train_crop_dir
|
99 |
+
self.img_dir_ns = os.path.join(img_path,img_dir_ns)
|
100 |
+
self.augment_basic = augment_basic # perform basic augmentation (rotation,flip,crop)
|
101 |
+
self.augment_texture = augment_texture # perform artistic texture augmentation (NS)
|
102 |
+
self.p_texture = p_texture # initial probability of artistic texture augmentation
|
103 |
+
self.augment_geom = augment_geom # perform artistic geometric augmentation
|
104 |
+
self.p_geom = p_geom # initial probability of artistic geometric augmentation
|
105 |
+
|
106 |
+
self.valid_size = valid_size
|
107 |
+
self.valid_data = valid_data
|
108 |
+
|
109 |
+
# load image, bb and landmark data using menpo
|
110 |
+
self.bb_dir = os.path.join(img_path, 'Bounding_Boxes')
|
111 |
+
self.bb_dictionary = load_bb_dictionary(self.bb_dir, mode, test_data=self.test_data)
|
112 |
+
|
113 |
+
# use pre-augmented data, to save time during training
|
114 |
+
if self.use_epoch_data:
|
115 |
+
epoch_0 = os.path.join(self.epoch_data_dir, '0')
|
116 |
+
self.img_menpo_list = load_menpo_image_list(
|
117 |
+
img_path, train_crop_dir=epoch_0, img_dir_ns=None, mode=mode, bb_dictionary=self.bb_dictionary,
|
118 |
+
image_size=self.image_size, test_data=self.test_data, augment_basic=False, augment_texture=False,
|
119 |
+
augment_geom=False, verbose=menpo_verbose)
|
120 |
+
else:
|
121 |
+
self.img_menpo_list = load_menpo_image_list(
|
122 |
+
img_path, train_crop_dir, self.img_dir_ns, mode, bb_dictionary=self.bb_dictionary,
|
123 |
+
image_size=self.image_size, margin=margin, bb_type=bb_type, test_data=self.test_data,
|
124 |
+
augment_basic=augment_basic, augment_texture=augment_texture, p_texture=p_texture,
|
125 |
+
augment_geom=augment_geom, p_geom=p_geom, verbose=menpo_verbose)
|
126 |
+
|
127 |
+
if mode == 'TRAIN':
|
128 |
+
|
129 |
+
train_params = locals()
|
130 |
+
print_training_params_to_file(train_params) # save init parameters
|
131 |
+
|
132 |
+
self.train_inds = np.arange(len(self.img_menpo_list))
|
133 |
+
|
134 |
+
if self.debug:
|
135 |
+
self.train_inds = self.train_inds[:self.debug_data_size]
|
136 |
+
self.img_menpo_list = self.img_menpo_list[self.train_inds]
|
137 |
+
|
138 |
+
if valid_size > 0:
|
139 |
+
|
140 |
+
self.valid_bb_dictionary = load_bb_dictionary(self.bb_dir, 'TEST', test_data=self.valid_data)
|
141 |
+
self.valid_img_menpo_list = load_menpo_image_list(
|
142 |
+
img_path, train_crop_dir, self.img_dir_ns, 'TEST', bb_dictionary=self.valid_bb_dictionary,
|
143 |
+
image_size=self.image_size, margin=margin, bb_type=bb_type, test_data=self.valid_data,
|
144 |
+
verbose=menpo_verbose)
|
145 |
+
|
146 |
+
np.random.seed(0)
|
147 |
+
self.val_inds = np.arange(len(self.valid_img_menpo_list))
|
148 |
+
np.random.shuffle(self.val_inds)
|
149 |
+
self.val_inds = self.val_inds[:self.valid_size]
|
150 |
+
|
151 |
+
self.valid_img_menpo_list = self.valid_img_menpo_list[self.val_inds]
|
152 |
+
|
153 |
+
self.valid_images_loaded =\
|
154 |
+
np.zeros([self.valid_size, self.image_size, self.image_size, self.c_dim]).astype('float32')
|
155 |
+
self.valid_gt_maps_small_loaded =\
|
156 |
+
np.zeros([self.valid_size, self.image_size / 4, self.image_size / 4,
|
157 |
+
self.num_landmarks]).astype('float32')
|
158 |
+
self.valid_gt_maps_loaded =\
|
159 |
+
np.zeros([self.valid_size, self.image_size, self.image_size, self.num_landmarks]
|
160 |
+
).astype('float32')
|
161 |
+
self.valid_landmarks_loaded = np.zeros([self.valid_size, num_landmarks, 2]).astype('float32')
|
162 |
+
self.valid_landmarks_pred = np.zeros([self.valid_size, self.num_landmarks, 2]).astype('float32')
|
163 |
+
|
164 |
+
load_images_landmarks_approx_maps_alloc_once(
|
165 |
+
self.valid_img_menpo_list, np.arange(self.valid_size), images=self.valid_images_loaded,
|
166 |
+
maps_small=self.valid_gt_maps_small_loaded, maps=self.valid_gt_maps_loaded,
|
167 |
+
landmarks=self.valid_landmarks_loaded, image_size=self.image_size,
|
168 |
+
num_landmarks=self.num_landmarks, scale=self.scale, win_mult=self.win_mult, sigma=self.sigma,
|
169 |
+
save_landmarks=self.compute_nme)
|
170 |
+
|
171 |
+
if self.valid_size > self.sample_grid:
|
172 |
+
self.valid_gt_maps_loaded = self.valid_gt_maps_loaded[:self.sample_grid]
|
173 |
+
self.valid_gt_maps_small_loaded = self.valid_gt_maps_small_loaded[:self.sample_grid]
|
174 |
+
else:
|
175 |
+
self.val_inds = None
|
176 |
+
|
177 |
+
self.epoch_inds_shuffle = train_val_shuffle_inds_per_epoch(
|
178 |
+
self.val_inds, self.train_inds, train_iter, batch_size, save_log_path)
|
179 |
+
|
180 |
+
def add_placeholders(self):
|
181 |
+
|
182 |
+
if self.mode == 'TEST':
|
183 |
+
self.images = tf.placeholder(
|
184 |
+
tf.float32, [None, self.image_size, self.image_size, self.c_dim], 'images')
|
185 |
+
|
186 |
+
self.heatmaps = tf.placeholder(
|
187 |
+
tf.float32, [None, self.image_size, self.image_size, self.num_landmarks], 'heatmaps')
|
188 |
+
|
189 |
+
self.heatmaps_small = tf.placeholder(
|
190 |
+
tf.float32, [None, int(self.image_size/4), int(self.image_size/4), self.num_landmarks], 'heatmaps_small')
|
191 |
+
self.lms = tf.placeholder(tf.float32, [None, self.num_landmarks, 2], 'lms')
|
192 |
+
self.pred_lms = tf.placeholder(tf.float32, [None, self.num_landmarks, 2], 'pred_lms')
|
193 |
+
|
194 |
+
elif self.mode == 'TRAIN':
|
195 |
+
self.images = tf.placeholder(
|
196 |
+
tf.float32, [None, self.image_size, self.image_size, self.c_dim], 'train_images')
|
197 |
+
|
198 |
+
self.heatmaps = tf.placeholder(
|
199 |
+
tf.float32, [None, self.image_size, self.image_size, self.num_landmarks], 'train_heatmaps')
|
200 |
+
|
201 |
+
self.heatmaps_small = tf.placeholder(
|
202 |
+
tf.float32, [None, int(self.image_size/4), int(self.image_size/4), self.num_landmarks], 'train_heatmaps_small')
|
203 |
+
|
204 |
+
self.train_lms = tf.placeholder(tf.float32, [None, self.num_landmarks, 2], 'train_lms')
|
205 |
+
self.train_pred_lms = tf.placeholder(tf.float32, [None, self.num_landmarks, 2], 'train_pred_lms')
|
206 |
+
|
207 |
+
self.valid_lms = tf.placeholder(tf.float32, [None, self.num_landmarks, 2], 'valid_lms')
|
208 |
+
self.valid_pred_lms = tf.placeholder(tf.float32, [None, self.num_landmarks, 2], 'valid_pred_lms')
|
209 |
+
|
210 |
+
# self.p_texture_log = tf.placeholder(tf.float32, [])
|
211 |
+
# self.p_geom_log = tf.placeholder(tf.float32, [])
|
212 |
+
|
213 |
+
# self.sparse_hm_small = tf.placeholder(tf.float32, [None, int(self.image_size/4), int(self.image_size/4), 1])
|
214 |
+
# self.sparse_hm = tf.placeholder(tf.float32, [None, self.image_size, self.image_size, 1])
|
215 |
+
|
216 |
+
if self.sample_to_log:
|
217 |
+
row = int(np.sqrt(self.sample_grid))
|
218 |
+
self.log_image_map_small = tf.placeholder(
|
219 |
+
tf.uint8, [None, row * int(self.image_size/4), 3 * row * int(self.image_size/4), self.c_dim],
|
220 |
+
'sample_img_map_small')
|
221 |
+
self.log_image_map = tf.placeholder(
|
222 |
+
tf.uint8, [None, row * self.image_size, 3 * row * self.image_size, self.c_dim],
|
223 |
+
'sample_img_map')
|
224 |
+
if self.sample_per_channel:
|
225 |
+
row = np.ceil(np.sqrt(self.num_landmarks)).astype(np.int64)
|
226 |
+
self.log_map_channels_small = tf.placeholder(
|
227 |
+
tf.uint8, [None, row * int(self.image_size/4), 2 * row * int(self.image_size/4), self.c_dim],
|
228 |
+
'sample_map_channels_small')
|
229 |
+
self.log_map_channels = tf.placeholder(
|
230 |
+
tf.uint8, [None, row * self.image_size, 2 * row * self.image_size, self.c_dim],
|
231 |
+
'sample_map_channels')
|
232 |
+
|
233 |
+
def heatmaps_network(self, input_images, reuse=None, name='pred_heatmaps'):
|
234 |
+
|
235 |
+
with tf.name_scope(name):
|
236 |
+
|
237 |
+
if self.weight_initializer == 'xavier':
|
238 |
+
weight_initializer = contrib.layers.xavier_initializer()
|
239 |
+
else:
|
240 |
+
weight_initializer = tf.random_normal_initializer(stddev=self.weight_initializer_std)
|
241 |
+
|
242 |
+
bias_init = tf.constant_initializer(self.bias_initializer)
|
243 |
+
|
244 |
+
with tf.variable_scope('heatmaps_network'):
|
245 |
+
with tf.name_scope('primary_net'):
|
246 |
+
|
247 |
+
l1 = conv_relu_pool(input_images, 5, 128, conv_ker_init=weight_initializer, conv_bias_init=bias_init,
|
248 |
+
reuse=reuse, var_scope='conv_1')
|
249 |
+
l2 = conv_relu_pool(l1, 5, 128, conv_ker_init=weight_initializer, conv_bias_init=bias_init,
|
250 |
+
reuse=reuse, var_scope='conv_2')
|
251 |
+
l3 = conv_relu(l2, 5, 128, conv_ker_init=weight_initializer, conv_bias_init=bias_init,
|
252 |
+
reuse=reuse, var_scope='conv_3')
|
253 |
+
|
254 |
+
l4_1 = conv_relu(l3, 3, 128, conv_dilation=1, conv_ker_init=weight_initializer,
|
255 |
+
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_4_1')
|
256 |
+
l4_2 = conv_relu(l3, 3, 128, conv_dilation=2, conv_ker_init=weight_initializer,
|
257 |
+
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_4_2')
|
258 |
+
l4_3 = conv_relu(l3, 3, 128, conv_dilation=3, conv_ker_init=weight_initializer,
|
259 |
+
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_4_3')
|
260 |
+
l4_4 = conv_relu(l3, 3, 128, conv_dilation=4, conv_ker_init=weight_initializer,
|
261 |
+
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_4_4')
|
262 |
+
|
263 |
+
l4 = tf.concat([l4_1, l4_2, l4_3, l4_4], 3, name='conv_4')
|
264 |
+
|
265 |
+
l5_1 = conv_relu(l4, 3, 256, conv_dilation=1, conv_ker_init=weight_initializer,
|
266 |
+
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_5_1')
|
267 |
+
l5_2 = conv_relu(l4, 3, 256, conv_dilation=2, conv_ker_init=weight_initializer,
|
268 |
+
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_5_2')
|
269 |
+
l5_3 = conv_relu(l4, 3, 256, conv_dilation=3, conv_ker_init=weight_initializer,
|
270 |
+
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_5_3')
|
271 |
+
l5_4 = conv_relu(l4, 3, 256, conv_dilation=4, conv_ker_init=weight_initializer,
|
272 |
+
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_5_4')
|
273 |
+
|
274 |
+
l5 = tf.concat([l5_1, l5_2, l5_3, l5_4], 3, name='conv_5')
|
275 |
+
|
276 |
+
l6 = conv_relu(l5, 1, 512, conv_ker_init=weight_initializer,
|
277 |
+
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_6')
|
278 |
+
l7 = conv_relu(l6, 1, 256, conv_ker_init=weight_initializer,
|
279 |
+
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_7')
|
280 |
+
primary_out = conv(l7, 1, self.num_landmarks, conv_ker_init=weight_initializer,
|
281 |
+
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_8')
|
282 |
+
|
283 |
+
with tf.name_scope('fusion_net'):
|
284 |
+
|
285 |
+
l_fsn_0 = tf.concat([l3, l7], 3, name='conv_3_7_fsn')
|
286 |
+
|
287 |
+
l_fsn_1_1 = conv_relu(l_fsn_0, 3, 64, conv_dilation=1, conv_ker_init=weight_initializer,
|
288 |
+
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_1_1')
|
289 |
+
l_fsn_1_2 = conv_relu(l_fsn_0, 3, 64, conv_dilation=2, conv_ker_init=weight_initializer,
|
290 |
+
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_1_2')
|
291 |
+
l_fsn_1_3 = conv_relu(l_fsn_0, 3, 64, conv_dilation=3, conv_ker_init=weight_initializer,
|
292 |
+
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_1_3')
|
293 |
+
|
294 |
+
l_fsn_1 = tf.concat([l_fsn_1_1, l_fsn_1_2, l_fsn_1_3], 3, name='conv_fsn_1')
|
295 |
+
|
296 |
+
l_fsn_2_1 = conv_relu(l_fsn_1, 3, 64, conv_dilation=1, conv_ker_init=weight_initializer,
|
297 |
+
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_2_1')
|
298 |
+
l_fsn_2_2 = conv_relu(l_fsn_1, 3, 64, conv_dilation=2, conv_ker_init=weight_initializer,
|
299 |
+
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_2_2')
|
300 |
+
l_fsn_2_3 = conv_relu(l_fsn_1, 3, 64, conv_dilation=4, conv_ker_init=weight_initializer,
|
301 |
+
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_2_3')
|
302 |
+
l_fsn_2_4 = conv_relu(l_fsn_1, 5, 64, conv_dilation=3, conv_ker_init=weight_initializer,
|
303 |
+
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_2_4')
|
304 |
+
|
305 |
+
l_fsn_2 = tf.concat([l_fsn_2_1, l_fsn_2_2, l_fsn_2_3, l_fsn_2_4], 3, name='conv_fsn_2')
|
306 |
+
|
307 |
+
l_fsn_3_1 = conv_relu(l_fsn_2, 3, 128, conv_dilation=1, conv_ker_init=weight_initializer,
|
308 |
+
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_3_1')
|
309 |
+
l_fsn_3_2 = conv_relu(l_fsn_2, 3, 128, conv_dilation=2, conv_ker_init=weight_initializer,
|
310 |
+
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_3_2')
|
311 |
+
l_fsn_3_3 = conv_relu(l_fsn_2, 3, 128, conv_dilation=4, conv_ker_init=weight_initializer,
|
312 |
+
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_3_3')
|
313 |
+
l_fsn_3_4 = conv_relu(l_fsn_2, 5, 128, conv_dilation=3, conv_ker_init=weight_initializer,
|
314 |
+
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_3_4')
|
315 |
+
|
316 |
+
l_fsn_3 = tf.concat([l_fsn_3_1, l_fsn_3_2, l_fsn_3_3, l_fsn_3_4], 3, name='conv_fsn_3')
|
317 |
+
|
318 |
+
l_fsn_4 = conv_relu(l_fsn_3, 1, 256, conv_ker_init=weight_initializer,
|
319 |
+
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_4')
|
320 |
+
fusion_out = conv(l_fsn_4, 1, self.num_landmarks, conv_ker_init=weight_initializer,
|
321 |
+
conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_5')
|
322 |
+
|
323 |
+
with tf.name_scope('upsample_net'):
|
324 |
+
|
325 |
+
out = deconv(fusion_out, 8, self.num_landmarks, conv_stride=4,
|
326 |
+
conv_ker_init=deconv2d_bilinear_upsampling_initializer(
|
327 |
+
[8, 8, self.num_landmarks, self.num_landmarks]), conv_bias_init=bias_init,
|
328 |
+
reuse=reuse, var_scope='deconv_1')
|
329 |
+
|
330 |
+
self.all_layers = [l1, l2, l3, l4, l5, l6, l7, primary_out, l_fsn_1, l_fsn_2, l_fsn_3, l_fsn_4,
|
331 |
+
fusion_out, out]
|
332 |
+
|
333 |
+
return primary_out, fusion_out, out
|
334 |
+
|
335 |
+
def build_model(self):
|
336 |
+
self.pred_hm_p, self.pred_hm_f, self.pred_hm_u = self.heatmaps_network(self.images,name='heatmaps_prediction')
|
337 |
+
|
338 |
+
def create_loss_ops(self):
|
339 |
+
|
340 |
+
def nme_norm_eyes(pred_landmarks, real_landmarks, normalize=True, name='NME'):
|
341 |
+
"""calculate normalized mean error on landmarks - normalize with inter pupil distance"""
|
342 |
+
|
343 |
+
with tf.name_scope(name):
|
344 |
+
with tf.name_scope('real_pred_landmarks_rmse'):
|
345 |
+
# calculate RMS ERROR between GT and predicted lms
|
346 |
+
landmarks_rms_err = tf.reduce_mean(
|
347 |
+
tf.sqrt(tf.reduce_sum(tf.square(pred_landmarks - real_landmarks), axis=2)), axis=1)
|
348 |
+
if normalize:
|
349 |
+
# normalize RMS ERROR with inter-pupil distance of GT lms
|
350 |
+
with tf.name_scope('inter_pupil_dist'):
|
351 |
+
with tf.name_scope('left_eye_center'):
|
352 |
+
p1 = tf.reduce_mean(tf.slice(real_landmarks, [0, 42, 0], [-1, 6, 2]), axis=1)
|
353 |
+
with tf.name_scope('right_eye_center'):
|
354 |
+
p2 = tf.reduce_mean(tf.slice(real_landmarks, [0, 36, 0], [-1, 6, 2]), axis=1)
|
355 |
+
|
356 |
+
eye_dist = tf.sqrt(tf.reduce_sum(tf.square(p1 - p2), axis=1))
|
357 |
+
|
358 |
+
return landmarks_rms_err / eye_dist
|
359 |
+
else:
|
360 |
+
return landmarks_rms_err
|
361 |
+
|
362 |
+
if self.mode is 'TRAIN':
|
363 |
+
|
364 |
+
# calculate L2 loss between ideal and predicted heatmaps
|
365 |
+
primary_maps_diff = self.pred_hm_p - self.heatmaps_small
|
366 |
+
fusion_maps_diff = self.pred_hm_f - self.heatmaps_small
|
367 |
+
upsample_maps_diff = self.pred_hm_u - self.heatmaps
|
368 |
+
|
369 |
+
self.l2_primary = tf.reduce_mean(tf.square(primary_maps_diff))
|
370 |
+
self.l2_fusion = tf.reduce_mean(tf.square(fusion_maps_diff))
|
371 |
+
self.l2_upsample = tf.reduce_mean(tf.square(upsample_maps_diff))
|
372 |
+
|
373 |
+
self.total_loss = 1000.*(self.l_weight_primary * self.l2_primary + self.l_weight_fusion * self.l2_fusion +
|
374 |
+
self.l_weight_upsample * self.l2_upsample)
|
375 |
+
|
376 |
+
# add weight decay
|
377 |
+
self.total_loss += self.reg * tf.add_n(
|
378 |
+
[tf.nn.l2_loss(v) for v in tf.trainable_variables() if 'bias' not in v.name])
|
379 |
+
|
380 |
+
# compute normalized mean error on gt vs. predicted landmarks (for validation)
|
381 |
+
if self.compute_nme:
|
382 |
+
self.nme_loss = tf.reduce_mean(nme_norm_eyes(self.train_pred_lms, self.train_lms))
|
383 |
+
|
384 |
+
if self.valid_size > 0 and self.compute_nme:
|
385 |
+
self.valid_nme_loss = tf.reduce_mean(nme_norm_eyes(self.valid_pred_lms, self.valid_lms))
|
386 |
+
|
387 |
+
elif self.mode == 'TEST' and self.compute_nme:
|
388 |
+
self.nme_per_image = nme_norm_eyes(self.pred_lms, self.lms)
|
389 |
+
self.nme_loss = tf.reduce_mean(self.nme_per_image)
|
390 |
+
|
391 |
+
def predict_valid_landmarks_in_batches(self, images, session):
|
392 |
+
|
393 |
+
num_images=int(images.shape[0])
|
394 |
+
num_batches = int(1.*num_images/self.batch_size)
|
395 |
+
if num_batches == 0:
|
396 |
+
batch_size = num_images
|
397 |
+
num_batches = 1
|
398 |
+
else:
|
399 |
+
batch_size = self.batch_size
|
400 |
+
|
401 |
+
for j in range(num_batches):
|
402 |
+
|
403 |
+
batch_images = images[j * batch_size:(j + 1) * batch_size,:,:,:]
|
404 |
+
batch_maps_pred = session.run(self.pred_hm_u, {self.images: batch_images})
|
405 |
+
batch_heat_maps_to_landmarks_alloc_once(
|
406 |
+
batch_maps=batch_maps_pred, batch_landmarks=self.valid_landmarks_pred[j * batch_size:(j + 1) * batch_size, :, :],
|
407 |
+
batch_size=batch_size,image_size=self.image_size,num_landmarks=self.num_landmarks)
|
408 |
+
|
409 |
+
reminder = num_images-num_batches*batch_size
|
410 |
+
if reminder > 0:
|
411 |
+
batch_images = images[-reminder:, :, :, :]
|
412 |
+
batch_maps_pred = session.run(self.pred_hm_u, {self.images: batch_images})
|
413 |
+
|
414 |
+
batch_heat_maps_to_landmarks_alloc_once(
|
415 |
+
batch_maps=batch_maps_pred,
|
416 |
+
batch_landmarks=self.valid_landmarks_pred[-reminder:, :, :],
|
417 |
+
batch_size=reminder, image_size=self.image_size, num_landmarks=self.num_landmarks)
|
418 |
+
|
419 |
+
def create_summary_ops(self):
|
420 |
+
"""create summary ops for logging"""
|
421 |
+
|
422 |
+
# loss summary
|
423 |
+
l2_primary = tf.summary.scalar('l2_primary', self.l2_primary)
|
424 |
+
l2_fusion = tf.summary.scalar('l2_fusion', self.l2_fusion)
|
425 |
+
l2_upsample = tf.summary.scalar('l2_upsample', self.l2_upsample)
|
426 |
+
|
427 |
+
l_total = tf.summary.scalar('l_total', self.total_loss)
|
428 |
+
self.batch_summary_op = tf.summary.merge([l2_primary,l2_fusion,l2_upsample,l_total])
|
429 |
+
|
430 |
+
if self.compute_nme:
|
431 |
+
nme = tf.summary.scalar('nme', self.nme_loss)
|
432 |
+
self.batch_summary_op = tf.summary.merge([self.batch_summary_op, nme])
|
433 |
+
|
434 |
+
if self.log_histograms:
|
435 |
+
var_summary = [tf.summary.histogram(var.name,var) for var in tf.trainable_variables()]
|
436 |
+
grads = tf.gradients(self.total_loss, tf.trainable_variables())
|
437 |
+
grads = list(zip(grads, tf.trainable_variables()))
|
438 |
+
grad_summary = [tf.summary.histogram(var.name+'/grads',grad) for grad,var in grads]
|
439 |
+
activ_summary = [tf.summary.histogram(layer.name, layer) for layer in self.all_layers]
|
440 |
+
self.batch_summary_op = tf.summary.merge([self.batch_summary_op, var_summary, grad_summary, activ_summary])
|
441 |
+
|
442 |
+
if self.valid_size > 0 and self.compute_nme:
|
443 |
+
self.valid_summary = tf.summary.scalar('valid_nme', self.valid_nme_loss)
|
444 |
+
|
445 |
+
if self.sample_to_log:
|
446 |
+
img_map_summary_small = tf.summary.image('compare_map_to_gt_small', self.log_image_map_small)
|
447 |
+
img_map_summary = tf.summary.image('compare_map_to_gt', self.log_image_map)
|
448 |
+
|
449 |
+
if self.sample_per_channel:
|
450 |
+
map_channels_summary = tf.summary.image('compare_map_channels_to_gt', self.log_map_channels)
|
451 |
+
map_channels_summary_small = tf.summary.image('compare_map_channels_to_gt_small',
|
452 |
+
self.log_map_channels_small)
|
453 |
+
self.img_summary = tf.summary.merge(
|
454 |
+
[img_map_summary, img_map_summary_small,map_channels_summary,map_channels_summary_small])
|
455 |
+
else:
|
456 |
+
self.img_summary = tf.summary.merge([img_map_summary, img_map_summary_small])
|
457 |
+
|
458 |
+
if self.valid_size >= self.sample_grid:
|
459 |
+
img_map_summary_valid_small = tf.summary.image('compare_map_to_gt_small_valid', self.log_image_map_small)
|
460 |
+
img_map_summary_valid = tf.summary.image('compare_map_to_gt_valid', self.log_image_map)
|
461 |
+
|
462 |
+
if self.sample_per_channel:
|
463 |
+
map_channels_summary_valid_small = tf.summary.image('compare_map_channels_to_gt_small_valid',
|
464 |
+
self.log_map_channels_small)
|
465 |
+
map_channels_summary_valid = tf.summary.image('compare_map_channels_to_gt_valid',
|
466 |
+
self.log_map_channels)
|
467 |
+
self.img_summary_valid = tf.summary.merge(
|
468 |
+
[img_map_summary_valid,img_map_summary_valid_small,map_channels_summary_valid,
|
469 |
+
map_channels_summary_valid_small])
|
470 |
+
else:
|
471 |
+
self.img_summary_valid = tf.summary.merge([img_map_summary_valid, img_map_summary_valid_small])
|
472 |
+
|
473 |
+
def train(self):
|
474 |
+
# set random seed
|
475 |
+
tf.set_random_seed(1234)
|
476 |
+
np.random.seed(1234)
|
477 |
+
# build a graph
|
478 |
+
# add placeholders
|
479 |
+
self.add_placeholders()
|
480 |
+
# build model
|
481 |
+
self.build_model()
|
482 |
+
# create loss ops
|
483 |
+
self.create_loss_ops()
|
484 |
+
# create summary ops
|
485 |
+
self.create_summary_ops()
|
486 |
+
|
487 |
+
# create optimizer and training op
|
488 |
+
global_step = tf.Variable(0, trainable=False)
|
489 |
+
lr = tf.train.exponential_decay(self.learning_rate,global_step, self.step, self.gamma, staircase=True)
|
490 |
+
if self.adam_optimizer:
|
491 |
+
optimizer = tf.train.AdamOptimizer(lr)
|
492 |
+
else:
|
493 |
+
optimizer = tf.train.MomentumOptimizer(lr, self.momentum)
|
494 |
+
|
495 |
+
train_op = optimizer.minimize(self.total_loss,global_step=global_step)
|
496 |
+
|
497 |
+
with tf.Session(config=self.config) as sess:
|
498 |
+
|
499 |
+
tf.global_variables_initializer().run()
|
500 |
+
|
501 |
+
# load pre trained weights if load_pretrain==True
|
502 |
+
if self.load_pretrain:
|
503 |
+
print
|
504 |
+
print('*** loading pre-trained weights from: '+self.pre_train_path+' ***')
|
505 |
+
if self.load_primary_only:
|
506 |
+
print('*** loading primary-net only ***')
|
507 |
+
primary_var = [v for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) if
|
508 |
+
('deconv_' not in v.name) and ('_fsn_' not in v.name)]
|
509 |
+
loader = tf.train.Saver(var_list=primary_var)
|
510 |
+
else:
|
511 |
+
loader = tf.train.Saver()
|
512 |
+
loader.restore(sess, self.pre_train_path)
|
513 |
+
print("*** Model restore finished, current global step: %d" % global_step.eval())
|
514 |
+
|
515 |
+
# for fine-tuning, choose reset_training_op==True. when resuming training, reset_training_op==False
|
516 |
+
if self.reset_training_op:
|
517 |
+
print ("resetting optimizer and global step")
|
518 |
+
opt_var_list = [optimizer.get_slot(var, name) for name in optimizer.get_slot_names()
|
519 |
+
for var in tf.global_variables() if optimizer.get_slot(var, name) is not None]
|
520 |
+
opt_var_list_init = tf.variables_initializer(opt_var_list)
|
521 |
+
opt_var_list_init.run()
|
522 |
+
sess.run(global_step.initializer)
|
523 |
+
|
524 |
+
# create model saver and file writer
|
525 |
+
summary_writer = tf.summary.FileWriter(logdir=self.save_log_path, graph=tf.get_default_graph())
|
526 |
+
saver = tf.train.Saver()
|
527 |
+
|
528 |
+
print('\n*** Start Training ***')
|
529 |
+
|
530 |
+
# initialize some variables before training loop
|
531 |
+
resume_step = global_step.eval()
|
532 |
+
num_train_images = len(self.img_menpo_list)
|
533 |
+
batches_in_epoch = int(float(num_train_images) / float(self.batch_size))
|
534 |
+
epoch = int(resume_step / batches_in_epoch)
|
535 |
+
img_inds = self.epoch_inds_shuffle[epoch, :]
|
536 |
+
log_valid = True
|
537 |
+
log_valid_images = True
|
538 |
+
|
539 |
+
# allocate space for batch images, maps and landmarks
|
540 |
+
batch_images = np.zeros([self.batch_size, self.image_size, self.image_size, self.c_dim]).astype(
|
541 |
+
'float32')
|
542 |
+
batch_lms = np.zeros([self.batch_size, self.num_landmarks, 2]).astype('float32')
|
543 |
+
batch_lms_pred = np.zeros([self.batch_size, self.num_landmarks, 2]).astype('float32')
|
544 |
+
|
545 |
+
batch_maps_small = np.zeros((self.batch_size, int(self.image_size/4),
|
546 |
+
int(self.image_size/4), self.num_landmarks)).astype('float32')
|
547 |
+
batch_maps = np.zeros((self.batch_size, self.image_size, self.image_size,
|
548 |
+
self.num_landmarks)).astype('float32')
|
549 |
+
|
550 |
+
# create gaussians for heatmap generation
|
551 |
+
gaussian_filt_large = create_gaussian_filter(sigma=self.sigma, win_mult=self.win_mult)
|
552 |
+
gaussian_filt_small = create_gaussian_filter(sigma=1.*self.sigma/4, win_mult=self.win_mult)
|
553 |
+
|
554 |
+
# training loop
|
555 |
+
for step in range(resume_step, self.train_iter):
|
556 |
+
|
557 |
+
j = step % batches_in_epoch # j==0 if we finished an epoch
|
558 |
+
|
559 |
+
# if we finished an epoch and this isn't the first step
|
560 |
+
if step > resume_step and j == 0:
|
561 |
+
epoch += 1
|
562 |
+
img_inds = self.epoch_inds_shuffle[epoch, :] # get next shuffled image inds
|
563 |
+
log_valid = True
|
564 |
+
log_valid_images = True
|
565 |
+
if self.use_epoch_data: # if using pre-augmented data, load epoch directory
|
566 |
+
epoch_dir = os.path.join(self.epoch_data_dir, str(epoch))
|
567 |
+
self.img_menpo_list = load_menpo_image_list(
|
568 |
+
self.img_path, train_crop_dir=epoch_dir, img_dir_ns=None, mode=self.mode,
|
569 |
+
bb_dictionary=self.bb_dictionary, image_size=self.image_size, test_data=self.test_data,
|
570 |
+
augment_basic=False, augment_texture=False, augment_geom=False)
|
571 |
+
|
572 |
+
# get batch indices
|
573 |
+
batch_inds = img_inds[j * self.batch_size:(j + 1) * self.batch_size]
|
574 |
+
|
575 |
+
# load batch images, gt maps and landmarks
|
576 |
+
load_images_landmarks_approx_maps_alloc_once(
|
577 |
+
self.img_menpo_list, batch_inds, images=batch_images, maps_small=batch_maps_small,
|
578 |
+
maps=batch_maps, landmarks=batch_lms, image_size=self.image_size,
|
579 |
+
num_landmarks=self.num_landmarks, scale=self.scale, gauss_filt_large=gaussian_filt_large,
|
580 |
+
gauss_filt_small=gaussian_filt_small, win_mult=self.win_mult, sigma=self.sigma,
|
581 |
+
save_landmarks=self.compute_nme)
|
582 |
+
|
583 |
+
feed_dict_train = {self.images: batch_images, self.heatmaps: batch_maps,
|
584 |
+
self.heatmaps_small: batch_maps_small}
|
585 |
+
|
586 |
+
# train on batch
|
587 |
+
sess.run(train_op, feed_dict_train)
|
588 |
+
|
589 |
+
# save to log and print status
|
590 |
+
if step == resume_step or (step + 1) % self.print_every == 0:
|
591 |
+
|
592 |
+
# train data log
|
593 |
+
if self.compute_nme:
|
594 |
+
batch_maps_pred = sess.run(self.pred_hm_u, {self.images: batch_images})
|
595 |
+
|
596 |
+
batch_heat_maps_to_landmarks_alloc_once(
|
597 |
+
batch_maps=batch_maps_pred,batch_landmarks=batch_lms_pred,
|
598 |
+
batch_size=self.batch_size, image_size=self.image_size,
|
599 |
+
num_landmarks=self.num_landmarks)
|
600 |
+
|
601 |
+
train_feed_dict_log = {
|
602 |
+
self.images: batch_images, self.heatmaps: batch_maps,
|
603 |
+
self.heatmaps_small: batch_maps_small, self.train_lms: batch_lms,
|
604 |
+
self.train_pred_lms: batch_lms_pred}
|
605 |
+
|
606 |
+
summary, l_p, l_f, l_t, nme = sess.run(
|
607 |
+
[self.batch_summary_op, self.l2_primary, self.l2_fusion, self.total_loss,
|
608 |
+
self.nme_loss],
|
609 |
+
train_feed_dict_log)
|
610 |
+
|
611 |
+
print (
|
612 |
+
'epoch: [%d] step: [%d/%d] primary loss: [%.6f] fusion loss: [%.6f]'
|
613 |
+
' total loss: [%.6f] NME: [%.6f]' % (
|
614 |
+
epoch, step + 1, self.train_iter, l_p, l_f, l_t, nme))
|
615 |
+
else:
|
616 |
+
train_feed_dict_log = {self.images: batch_images, self.heatmaps: batch_maps,
|
617 |
+
self.heatmaps_small: batch_maps_small}
|
618 |
+
|
619 |
+
summary, l_p, l_f, l_t = sess.run(
|
620 |
+
[self.batch_summary_op, self.l2_primary, self.l2_fusion, self.total_loss],
|
621 |
+
train_feed_dict_log)
|
622 |
+
print (
|
623 |
+
'epoch: [%d] step: [%d/%d] primary loss: [%.6f] fusion loss: [%.6f] total loss: [%.6f]'
|
624 |
+
% (epoch, step + 1, self.train_iter, l_p, l_f, l_t))
|
625 |
+
|
626 |
+
summary_writer.add_summary(summary, step)
|
627 |
+
|
628 |
+
# valid data log
|
629 |
+
if self.valid_size > 0 and (log_valid and epoch % self.log_valid_every == 0) \
|
630 |
+
and self.compute_nme:
|
631 |
+
log_valid = False
|
632 |
+
|
633 |
+
self.predict_valid_landmarks_in_batches(self.valid_images_loaded, sess)
|
634 |
+
valid_feed_dict_log = {
|
635 |
+
self.valid_lms: self.valid_landmarks_loaded,
|
636 |
+
self.valid_pred_lms: self.valid_landmarks_pred}
|
637 |
+
|
638 |
+
v_summary, v_nme = sess.run([self.valid_summary, self.valid_nme_loss],
|
639 |
+
valid_feed_dict_log)
|
640 |
+
summary_writer.add_summary(v_summary, step)
|
641 |
+
print (
|
642 |
+
'epoch: [%d] step: [%d/%d] valid NME: [%.6f]' % (
|
643 |
+
epoch, step + 1, self.train_iter, v_nme))
|
644 |
+
|
645 |
+
# save model
|
646 |
+
if (step + 1) % self.save_every == 0:
|
647 |
+
saver.save(sess, os.path.join(self.save_model_path, 'deep_heatmaps'), global_step=step + 1)
|
648 |
+
print ('model/deep-heatmaps-%d saved' % (step + 1))
|
649 |
+
|
650 |
+
# save images
|
651 |
+
if step == resume_step or (step + 1) % self.sample_every == 0:
|
652 |
+
|
653 |
+
batch_maps_small_pred = sess.run(self.pred_hm_p, {self.images: batch_images})
|
654 |
+
if not self.compute_nme:
|
655 |
+
batch_maps_pred = sess.run(self.pred_hm_u, {self.images: batch_images})
|
656 |
+
batch_lms_pred = None
|
657 |
+
|
658 |
+
merged_img = merge_images_landmarks_maps_gt(
|
659 |
+
batch_images.copy(), batch_maps_pred, batch_maps, landmarks=batch_lms_pred,
|
660 |
+
image_size=self.image_size, num_landmarks=self.num_landmarks, num_samples=self.sample_grid,
|
661 |
+
scale=self.scale, circle_size=2, fast=self.fast_img_gen)
|
662 |
+
|
663 |
+
merged_img_small = merge_images_landmarks_maps_gt(
|
664 |
+
batch_images.copy(), batch_maps_small_pred, batch_maps_small,
|
665 |
+
image_size=self.image_size,
|
666 |
+
num_landmarks=self.num_landmarks, num_samples=self.sample_grid, scale=self.scale,
|
667 |
+
circle_size=0, fast=self.fast_img_gen)
|
668 |
+
|
669 |
+
if self.sample_per_channel:
|
670 |
+
map_per_channel = map_comapre_channels(
|
671 |
+
batch_images.copy(), batch_maps_pred, batch_maps, image_size=self.image_size,
|
672 |
+
num_landmarks=self.num_landmarks, scale=self.scale)
|
673 |
+
|
674 |
+
map_per_channel_small = map_comapre_channels(
|
675 |
+
batch_images.copy(), batch_maps_small_pred, batch_maps_small, image_size=int(self.image_size/4),
|
676 |
+
num_landmarks=self.num_landmarks, scale=self.scale)
|
677 |
+
|
678 |
+
if self.sample_to_log: # save heatmap images to log
|
679 |
+
if self.sample_per_channel:
|
680 |
+
summary_img = sess.run(
|
681 |
+
self.img_summary, {self.log_image_map: np.expand_dims(merged_img, 0),
|
682 |
+
self.log_map_channels: np.expand_dims(map_per_channel, 0),
|
683 |
+
self.log_image_map_small: np.expand_dims(merged_img_small, 0),
|
684 |
+
self.log_map_channels_small: np.expand_dims(map_per_channel_small, 0)})
|
685 |
+
else:
|
686 |
+
summary_img = sess.run(
|
687 |
+
self.img_summary, {self.log_image_map: np.expand_dims(merged_img, 0),
|
688 |
+
self.log_image_map_small: np.expand_dims(merged_img_small, 0)})
|
689 |
+
summary_writer.add_summary(summary_img, step)
|
690 |
+
|
691 |
+
if (self.valid_size >= self.sample_grid) and self.save_valid_images and\
|
692 |
+
(log_valid_images and epoch % self.log_valid_every == 0):
|
693 |
+
log_valid_images = False
|
694 |
+
|
695 |
+
batch_maps_small_pred_val,batch_maps_pred_val =\
|
696 |
+
sess.run([self.pred_hm_p,self.pred_hm_u],
|
697 |
+
{self.images: self.valid_images_loaded[:self.sample_grid]})
|
698 |
+
|
699 |
+
merged_img_small = merge_images_landmarks_maps_gt(
|
700 |
+
self.valid_images_loaded[:self.sample_grid].copy(), batch_maps_small_pred_val,
|
701 |
+
self.valid_gt_maps_small_loaded, image_size=self.image_size,
|
702 |
+
num_landmarks=self.num_landmarks, num_samples=self.sample_grid,
|
703 |
+
scale=self.scale, circle_size=0, fast=self.fast_img_gen)
|
704 |
+
|
705 |
+
merged_img = merge_images_landmarks_maps_gt(
|
706 |
+
self.valid_images_loaded[:self.sample_grid].copy(), batch_maps_pred_val,
|
707 |
+
self.valid_gt_maps_loaded, image_size=self.image_size,
|
708 |
+
num_landmarks=self.num_landmarks, num_samples=self.sample_grid,
|
709 |
+
scale=self.scale, circle_size=2, fast=self.fast_img_gen)
|
710 |
+
|
711 |
+
if self.sample_per_channel:
|
712 |
+
map_per_channel_small = map_comapre_channels(
|
713 |
+
self.valid_images_loaded[:self.sample_grid].copy(), batch_maps_small_pred_val,
|
714 |
+
self.valid_gt_maps_small_loaded, image_size=int(self.image_size / 4),
|
715 |
+
num_landmarks=self.num_landmarks, scale=self.scale)
|
716 |
+
|
717 |
+
map_per_channel = map_comapre_channels(
|
718 |
+
self.valid_images_loaded[:self.sample_grid].copy(), batch_maps_pred,
|
719 |
+
self.valid_gt_maps_loaded, image_size=self.image_size,
|
720 |
+
num_landmarks=self.num_landmarks, scale=self.scale)
|
721 |
+
|
722 |
+
summary_img = sess.run(
|
723 |
+
self.img_summary_valid,
|
724 |
+
{self.log_image_map: np.expand_dims(merged_img, 0),
|
725 |
+
self.log_map_channels: np.expand_dims(map_per_channel, 0),
|
726 |
+
self.log_image_map_small: np.expand_dims(merged_img_small, 0),
|
727 |
+
self.log_map_channels_small: np.expand_dims(map_per_channel_small, 0)})
|
728 |
+
else:
|
729 |
+
summary_img = sess.run(
|
730 |
+
self.img_summary_valid,
|
731 |
+
{self.log_image_map: np.expand_dims(merged_img, 0),
|
732 |
+
self.log_image_map_small: np.expand_dims(merged_img_small, 0)})
|
733 |
+
|
734 |
+
summary_writer.add_summary(summary_img, step)
|
735 |
+
else: # save heatmap images to directory
|
736 |
+
sample_path_imgs = os.path.join(
|
737 |
+
self.save_sample_path, 'epoch-%d-train-iter-%d-1.png' % (epoch, step + 1))
|
738 |
+
sample_path_imgs_small = os.path.join(
|
739 |
+
self.save_sample_path, 'epoch-%d-train-iter-%d-1-s.png' % (epoch, step + 1))
|
740 |
+
scipy.misc.imsave(sample_path_imgs, merged_img)
|
741 |
+
scipy.misc.imsave(sample_path_imgs_small, merged_img_small)
|
742 |
+
|
743 |
+
if self.sample_per_channel:
|
744 |
+
sample_path_ch_maps = os.path.join(
|
745 |
+
self.save_sample_path, 'epoch-%d-train-iter-%d-3.png' % (epoch, step + 1))
|
746 |
+
sample_path_ch_maps_small = os.path.join(
|
747 |
+
self.save_sample_path, 'epoch-%d-train-iter-%d-3-s.png' % (epoch, step + 1))
|
748 |
+
scipy.misc.imsave(sample_path_ch_maps, map_per_channel)
|
749 |
+
scipy.misc.imsave(sample_path_ch_maps_small, map_per_channel_small)
|
750 |
+
|
751 |
+
print('*** Finished Training ***')
|
752 |
+
|
753 |
+
def get_image_maps(self, test_image, reuse=None, norm=False):
|
754 |
+
""" returns heatmaps of input image (menpo image object)"""
|
755 |
+
|
756 |
+
self.add_placeholders()
|
757 |
+
# build model
|
758 |
+
pred_hm_p, pred_hm_f, pred_hm_u = self.heatmaps_network(self.images, reuse=reuse)
|
759 |
+
|
760 |
+
with tf.Session(config=self.config) as sess:
|
761 |
+
# load trained parameters
|
762 |
+
saver = tf.train.Saver()
|
763 |
+
saver.restore(sess, self.test_model_path)
|
764 |
+
_, model_name = os.path.split(self.test_model_path)
|
765 |
+
|
766 |
+
test_image = test_image.pixels_with_channels_at_back().astype('float32')
|
767 |
+
if norm:
|
768 |
+
if self.scale is '255':
|
769 |
+
test_image *= 255
|
770 |
+
elif self.scale is '0':
|
771 |
+
test_image = 2 * test_image - 1
|
772 |
+
|
773 |
+
map_primary, map_fusion, map_upsample = sess.run(
|
774 |
+
[pred_hm_p, pred_hm_f, pred_hm_u], {self.images: np.expand_dims(test_image, 0)})
|
775 |
+
|
776 |
+
return map_primary, map_fusion, map_upsample
|
777 |
+
|
778 |
+
def get_landmark_predictions(self, img_list, pdm_models_dir, clm_model_path, reuse=None, map_to_input_size=False):
|
779 |
+
|
780 |
+
"""returns dictionary with landmark predictions of each step of the ECpTp algorithm and ECT"""
|
781 |
+
|
782 |
+
from thirdparty.face_of_art.pdm_clm_functions import feature_based_pdm_corr, clm_correct
|
783 |
+
|
784 |
+
jaw_line_inds = np.arange(0, 17)
|
785 |
+
left_brow_inds = np.arange(17, 22)
|
786 |
+
right_brow_inds = np.arange(22, 27)
|
787 |
+
|
788 |
+
self.add_placeholders()
|
789 |
+
# build model
|
790 |
+
_, _, pred_hm_u = self.heatmaps_network(self.images, reuse=reuse)
|
791 |
+
|
792 |
+
with tf.Session(config=self.config) as sess:
|
793 |
+
# load trained parameters
|
794 |
+
saver = tf.train.Saver()
|
795 |
+
saver.restore(sess, self.test_model_path)
|
796 |
+
_, model_name = os.path.split(self.test_model_path)
|
797 |
+
e_list = []
|
798 |
+
ect_list = []
|
799 |
+
ecp_list = []
|
800 |
+
ecpt_list = []
|
801 |
+
ecptp_jaw_list = []
|
802 |
+
ecptp_out_list = []
|
803 |
+
|
804 |
+
for test_image in img_list:
|
805 |
+
|
806 |
+
if map_to_input_size:
|
807 |
+
test_image_transform = test_image[1]
|
808 |
+
test_image=test_image[0]
|
809 |
+
|
810 |
+
# get landmarks for estimation stage
|
811 |
+
if test_image.n_channels < 3:
|
812 |
+
test_image_map = sess.run(
|
813 |
+
pred_hm_u, {self.images: np.expand_dims(
|
814 |
+
gray2rgb(test_image.pixels_with_channels_at_back()).astype('float32'), 0)})
|
815 |
+
else:
|
816 |
+
test_image_map = sess.run(
|
817 |
+
pred_hm_u, {self.images: np.expand_dims(
|
818 |
+
test_image.pixels_with_channels_at_back().astype('float32'), 0)})
|
819 |
+
init_lms = heat_maps_to_landmarks(np.squeeze(test_image_map))
|
820 |
+
|
821 |
+
# get landmarks for part-based correction stage
|
822 |
+
p_pdm_lms = feature_based_pdm_corr(lms_init=init_lms, models_dir=pdm_models_dir, train_type='basic')
|
823 |
+
|
824 |
+
# get landmarks for part-based tuning stage
|
825 |
+
try: # clm may not converge
|
826 |
+
pdm_clm_lms = clm_correct(
|
827 |
+
clm_model_path=clm_model_path, image=test_image, map=test_image_map, lms_init=p_pdm_lms)
|
828 |
+
except:
|
829 |
+
pdm_clm_lms = p_pdm_lms.copy()
|
830 |
+
|
831 |
+
# get landmarks ECT
|
832 |
+
try: # clm may not converge
|
833 |
+
ect_lms = clm_correct(
|
834 |
+
clm_model_path=clm_model_path, image=test_image, map=test_image_map, lms_init=init_lms)
|
835 |
+
except:
|
836 |
+
ect_lms = p_pdm_lms.copy()
|
837 |
+
|
838 |
+
# get landmarks for ECpTp_out (tune jaw and eyebrows)
|
839 |
+
ecptp_out = p_pdm_lms.copy()
|
840 |
+
ecptp_out[left_brow_inds] = pdm_clm_lms[left_brow_inds]
|
841 |
+
ecptp_out[right_brow_inds] = pdm_clm_lms[right_brow_inds]
|
842 |
+
ecptp_out[jaw_line_inds] = pdm_clm_lms[jaw_line_inds]
|
843 |
+
|
844 |
+
# get landmarks for ECpTp_jaw (tune jaw)
|
845 |
+
ecptp_jaw = p_pdm_lms.copy()
|
846 |
+
ecptp_jaw[jaw_line_inds] = pdm_clm_lms[jaw_line_inds]
|
847 |
+
|
848 |
+
if map_to_input_size:
|
849 |
+
ecptp_jaw = test_image_transform.apply(ecptp_jaw)
|
850 |
+
ecptp_out = test_image_transform.apply(ecptp_out)
|
851 |
+
ect_lms = test_image_transform.apply(ect_lms)
|
852 |
+
init_lms = test_image_transform.apply(init_lms)
|
853 |
+
p_pdm_lms = test_image_transform.apply(p_pdm_lms)
|
854 |
+
pdm_clm_lms = test_image_transform.apply(pdm_clm_lms)
|
855 |
+
|
856 |
+
ecptp_jaw_list.append(ecptp_jaw) # E + p-correction + p-tuning (ECpTp_jaw)
|
857 |
+
ecptp_out_list.append(ecptp_out) # E + p-correction + p-tuning (ECpTp_out)
|
858 |
+
ect_list.append(ect_lms) # ECT prediction
|
859 |
+
e_list.append(init_lms) # init prediction from heatmap network (E)
|
860 |
+
ecp_list.append(p_pdm_lms) # init prediction + part pdm correction (ECp)
|
861 |
+
ecpt_list.append(pdm_clm_lms) # init prediction + part pdm correction + global tuning (ECpT)
|
862 |
+
|
863 |
+
pred_dict = {
|
864 |
+
'E': e_list,
|
865 |
+
'ECp': ecp_list,
|
866 |
+
'ECpT': ecpt_list,
|
867 |
+
'ECT': ect_list,
|
868 |
+
'ECpTp_jaw': ecptp_jaw_list,
|
869 |
+
'ECpTp_out': ecptp_out_list
|
870 |
+
}
|
871 |
+
|
872 |
+
return pred_dict
|