Space: svjack (Runtime error)

yuandong513 committed
Commit 17cd746 · Parent: 13fa4fd

feat: init

This view is limited to 50 files because the commit contains too many changes.

Files changed (50)
  1. README.md +7 -5
  2. app.py +568 -0
  3. app_lam.py +433 -0
  4. app_preprocess.py +387 -0
  5. configs/inference/lam-20k-8gpu.yaml +130 -0
  6. configs/stylematte_config.json +2311 -0
  7. external/human_matting/__init__.py +1 -0
  8. external/human_matting/matting_engine.py +66 -0
  9. external/human_matting/stylematte.py +272 -0
  10. external/landmark_detection/FaceBoxesV2/__init__.py +2 -0
  11. external/landmark_detection/FaceBoxesV2/detector.py +39 -0
  12. external/landmark_detection/FaceBoxesV2/faceboxes_detector.py +97 -0
  13. external/landmark_detection/FaceBoxesV2/utils/__init__.py +0 -0
  14. external/landmark_detection/FaceBoxesV2/utils/box_utils.py +276 -0
  15. external/landmark_detection/FaceBoxesV2/utils/build.py +57 -0
  16. external/landmark_detection/FaceBoxesV2/utils/config.py +14 -0
  17. external/landmark_detection/FaceBoxesV2/utils/faceboxes.py +239 -0
  18. external/landmark_detection/FaceBoxesV2/utils/make.sh +3 -0
  19. external/landmark_detection/FaceBoxesV2/utils/nms/__init__.py +0 -0
  20. external/landmark_detection/FaceBoxesV2/utils/nms/cpu_nms.c +0 -0
  21. external/landmark_detection/FaceBoxesV2/utils/nms/cpu_nms.py +0 -0
  22. external/landmark_detection/FaceBoxesV2/utils/nms/cpu_nms.pyx +163 -0
  23. external/landmark_detection/FaceBoxesV2/utils/nms/gpu_nms.hpp +2 -0
  24. external/landmark_detection/FaceBoxesV2/utils/nms/gpu_nms.pyx +31 -0
  25. external/landmark_detection/FaceBoxesV2/utils/nms/nms_kernel.cu +144 -0
  26. external/landmark_detection/FaceBoxesV2/utils/nms/py_cpu_nms.py +38 -0
  27. external/landmark_detection/FaceBoxesV2/utils/nms_wrapper.py +15 -0
  28. external/landmark_detection/FaceBoxesV2/utils/prior_box.py +43 -0
  29. external/landmark_detection/FaceBoxesV2/utils/timer.py +40 -0
  30. external/landmark_detection/README.md +110 -0
  31. external/landmark_detection/conf/__init__.py +1 -0
  32. external/landmark_detection/conf/alignment.py +239 -0
  33. external/landmark_detection/conf/base.py +94 -0
  34. external/landmark_detection/config.json +15 -0
  35. external/landmark_detection/data_processor/CheckFaceKeyPoint.py +147 -0
  36. external/landmark_detection/data_processor/align.py +193 -0
  37. external/landmark_detection/data_processor/process_pcd.py +250 -0
  38. external/landmark_detection/evaluate.py +258 -0
  39. external/landmark_detection/infer_folder.py +253 -0
  40. external/landmark_detection/infer_image.py +251 -0
  41. external/landmark_detection/infer_video.py +287 -0
  42. external/landmark_detection/lib/__init__.py +9 -0
  43. external/landmark_detection/lib/backbone/__init__.py +5 -0
  44. external/landmark_detection/lib/backbone/core/coord_conv.py +157 -0
  45. external/landmark_detection/lib/backbone/stackedHGNetV1.py +307 -0
  46. external/landmark_detection/lib/dataset/__init__.py +11 -0
  47. external/landmark_detection/lib/dataset/alignmentDataset.py +316 -0
  48. external/landmark_detection/lib/dataset/augmentation.py +355 -0
  49. external/landmark_detection/lib/dataset/decoder/__init__.py +8 -0
  50. external/landmark_detection/lib/dataset/decoder/decoder_default.py +38 -0
README.md CHANGED
@@ -1,12 +1,14 @@
---
- title: LAM
- emoji: 🌍
- colorFrom: green
- colorTo: pink
+ title: LAM_test
+ emoji:
+ colorFrom: red
+ colorTo: indigo
sdk: gradio
- sdk_version: 5.23.3
+ sdk_version: 5.20.1
app_file: app.py
pinned: false
+ license: apache-2.0
+ short_description: Large Avatar Model for One-shot Animatable Gaussian Head
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,568 @@
1
+ # Copyright (c) 2024-2025, Yisheng He, Yuan Dong
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+
17
+ os.system("rm -rf /data-nvme/zerogpu-offload/")
18
+ os.system("pip install chumpy")
19
+ # os.system("pip uninstall -y basicsr")
20
+ os.system("pip install Cython")
21
+ os.system("pip install ./wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl")
22
+ os.system("pip install ./wheels/simple_knn-0.0.0-cp310-cp310-linux_x86_64.whl")
23
+ os.system("pip install ./wheels/nvdiffrast-0.3.3-cp310-cp310-linux_x86_64.whl --force-reinstall")
24
+ os.system(
25
+ "pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt240/download.html")
26
+ os.system("pip install numpy==1.23.0")
27
+
28
+ import cv2
29
+ import sys
30
+ import base64
31
+ import subprocess
32
+
33
+ import argparse
34
+ from glob import glob
35
+ import gradio as gr
36
+ import numpy as np
37
+ from PIL import Image
38
+ from omegaconf import OmegaConf
39
+
40
+ import torch
41
+ import moviepy.editor as mpy
42
+ from lam.runners.infer.head_utils import prepare_motion_seqs, preprocess_image
43
+ from lam.utils.ffmpeg_utils import images_to_video
44
+
45
+ import spaces
46
+
47
+
48
+ def compile_module(subfolder, script):
49
+ try:
50
+ # Save the current working directory
51
+ current_dir = os.getcwd()
52
+ # Change directory to the subfolder
53
+ os.chdir(os.path.join(current_dir, subfolder))
54
+ # Run the compilation command
55
+ result = subprocess.run(
56
+ ["sh", script],
57
+ capture_output=True,
58
+ text=True,
59
+ check=True
60
+ )
61
+ # Print the compilation output
62
+ print("Compilation output:", result.stdout)
63
+
64
+ except Exception as e:
65
+ # Print any error that occurred
66
+ print(f"An error occurred: {e}")
67
+ finally:
68
+ # Ensure returning to the original directory
69
+ os.chdir(current_dir)
70
+ print("Returned to the original directory.")
71
+
72
+
73
+ # compile flame_tracking dependence submodule
74
+ compile_module("external/landmark_detection/FaceBoxesV2/utils/", "make.sh")
75
+ from flame_tracking_single_image import FlameTrackingSingleImage
76
+
77
+
78
+ def launch_pretrained():
79
+ from huggingface_hub import snapshot_download, hf_hub_download
80
+ # launch pretrained for flame tracking.
81
+ hf_hub_download(repo_id='yuandong513/flametracking_model',
82
+ repo_type='model',
83
+ filename='pretrain_model.tar',
84
+ local_dir='./')
85
+ os.system('tar -xf pretrain_model.tar && rm pretrain_model.tar')
86
+ # launch human model files
87
+ hf_hub_download(repo_id='3DAIGC/LAM-assets',
88
+ repo_type='model',
89
+ filename='LAM_human_model.tar',
90
+ local_dir='./')
91
+ os.system('tar -xf LAM_human_model.tar && rm LAM_human_model.tar')
92
+ # launch pretrained for LAM
93
+ model_dir = hf_hub_download(repo_id="3DAIGC/LAM-20K", repo_type="model", local_dir="./exps/releases/lam/lam-20k/step_045500/", filename="config.json")
94
+ print(model_dir)
95
+ model_dir = hf_hub_download(repo_id="3DAIGC/LAM-20K", repo_type="model", local_dir="./exps/releases/lam/lam-20k/step_045500/", filename="model.safetensors")
96
+ print(model_dir)
97
+ model_dir = hf_hub_download(repo_id="3DAIGC/LAM-20K", repo_type="model", local_dir="./exps/releases/lam/lam-20k/step_045500/", filename="README.md")
98
+ print(model_dir)
99
+ # launch example for LAM
100
+ hf_hub_download(repo_id='3DAIGC/LAM-assets',
101
+ repo_type='model',
102
+ filename='LAM_assets.tar',
103
+ local_dir='./')
104
+ os.system('tar -xf LAM_assets.tar && rm LAM_assets.tar')
105
+ hf_hub_download(repo_id='3DAIGC/LAM-assets',
106
+ repo_type='model',
107
+ filename='config.json',
108
+ local_dir='./tmp/')
109
+
110
+
111
+ def launch_env_not_compile_with_cuda():
112
+ os.system('pip install chumpy')
113
+ os.system('pip install numpy==1.23.0')
114
+ os.system(
115
+ 'pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt251/download.html'
116
+ )
117
+
118
+
119
+ def assert_input_image(input_image):
120
+ if input_image is None:
121
+ raise gr.Error('No image selected or uploaded!')
122
+
123
+
124
+ def prepare_working_dir():
125
+ import tempfile
126
+ working_dir = tempfile.TemporaryDirectory()
127
+ return working_dir
128
+
129
+
130
+ def init_preprocessor():
131
+ from lam.utils.preprocess import Preprocessor
132
+ global preprocessor
133
+ preprocessor = Preprocessor()
134
+
135
+
136
+ def preprocess_fn(image_in: np.ndarray, remove_bg: bool, recenter: bool,
137
+ working_dir):
138
+ image_raw = os.path.join(working_dir.name, 'raw.png')
139
+ with Image.fromarray(image_in) as img:
140
+ img.save(image_raw)
141
+ image_out = os.path.join(working_dir.name, 'rembg.png')
142
+ success = preprocessor.preprocess(image_path=image_raw,
143
+ save_path=image_out,
144
+ rmbg=remove_bg,
145
+ recenter=recenter)
146
+ assert success, f'Failed under preprocess_fn!'
147
+ return image_out
148
+
149
+
150
+ def get_image_base64(path):
151
+ with open(path, 'rb') as image_file:
152
+ encoded_string = base64.b64encode(image_file.read()).decode()
153
+ return f'data:image/png;base64,{encoded_string}'
154
+
155
+
156
+ def save_imgs_2_video(imgs, v_pth, fps=30):
157
+ # moviepy example
158
+ from moviepy.editor import ImageSequenceClip, VideoFileClip
159
+ images = [image.astype(np.uint8) for image in imgs]
160
+ clip = ImageSequenceClip(images, fps=fps)
161
+ # final_duration = len(images) / fps
162
+ # clip = clip.subclip(0, final_duration)
163
+ clip = clip.subclip(0, len(images) / fps)
164
+ clip.write_videofile(v_pth, codec='libx264')
165
+
166
+ import cv2
167
+ cap = cv2.VideoCapture(v_pth)
168
+ nf = cap.get(cv2.CAP_PROP_FRAME_COUNT)
169
+ if nf != len(images):
170
+ print("="*100+f"\n{v_pth} moviepy saved video frame error."+"\n"+"="*100)
171
+ print(f"Video saved successfully at {v_pth}")
172
+
173
+
174
+ def add_audio_to_video(video_path, out_path, audio_path, fps=30):
175
+ # Import necessary modules from moviepy
176
+ from moviepy.editor import VideoFileClip, AudioFileClip
177
+
178
+ # Load video file into VideoFileClip object
179
+ video_clip = VideoFileClip(video_path)
180
+
181
+ # Load audio file into AudioFileClip object
182
+ audio_clip = AudioFileClip(audio_path)
183
+
184
+ # Hard code clip audio
185
+ if audio_clip.duration > 10:
186
+ audio_clip = audio_clip.subclip(0, 10)
187
+
188
+ # Attach audio clip to video clip (replaces existing audio)
189
+ video_clip_with_audio = video_clip.set_audio(audio_clip)
190
+
191
+ # Export final video with audio using standard codecs
192
+ video_clip_with_audio.write_videofile(out_path, codec='libx264', audio_codec='aac', fps=fps)
193
+
194
+ print(f"Audio added successfully at {out_path}")
195
+
196
+
197
+ def parse_configs():
198
+ parser = argparse.ArgumentParser()
199
+ parser.add_argument("--config", type=str)
200
+ parser.add_argument("--infer", type=str)
201
+ args, unknown = parser.parse_known_args()
202
+
203
+ cfg = OmegaConf.create()
204
+ cli_cfg = OmegaConf.from_cli(unknown)
205
+
206
+ # parse from ENV
207
+ if os.environ.get("APP_INFER") is not None:
208
+ args.infer = os.environ.get("APP_INFER")
209
+ if os.environ.get("APP_MODEL_NAME") is not None:
210
+ cli_cfg.model_name = os.environ.get("APP_MODEL_NAME")
211
+
212
+ args.config = args.infer if args.config is None else args.config
213
+
214
+ if args.config is not None:
215
+ cfg_train = OmegaConf.load(args.config)
216
+ cfg.source_size = cfg_train.dataset.source_image_res
217
+ try:
218
+ cfg.src_head_size = cfg_train.dataset.src_head_size
219
+ except:
220
+ cfg.src_head_size = 112
221
+ cfg.render_size = cfg_train.dataset.render_image.high
222
+ _relative_path = os.path.join(
223
+ cfg_train.experiment.parent,
224
+ cfg_train.experiment.child,
225
+ os.path.basename(cli_cfg.model_name).split("_")[-1],
226
+ )
227
+
228
+ cfg.save_tmp_dump = os.path.join("exps", "save_tmp", _relative_path)
229
+ cfg.image_dump = os.path.join("exps", "images", _relative_path)
230
+ cfg.video_dump = os.path.join("exps", "videos", _relative_path) # output path
231
+
232
+ if args.infer is not None:
233
+ cfg_infer = OmegaConf.load(args.infer)
234
+ cfg.merge_with(cfg_infer)
235
+ cfg.setdefault(
236
+ "save_tmp_dump", os.path.join("exps", cli_cfg.model_name, "save_tmp")
237
+ )
238
+ cfg.setdefault("image_dump", os.path.join("exps", cli_cfg.model_name, "images"))
239
+ cfg.setdefault(
240
+ "video_dump", os.path.join("dumps", cli_cfg.model_name, "videos")
241
+ )
242
+ cfg.setdefault("mesh_dump", os.path.join("dumps", cli_cfg.model_name, "meshes"))
243
+
244
+ cfg.motion_video_read_fps = 30
245
+ cfg.merge_with(cli_cfg)
246
+
247
+ cfg.setdefault("logger", "INFO")
248
+
249
+ assert cfg.model_name is not None, "model_name is required"
250
+
251
+ return cfg, cfg_train
252
+
253
+
254
+ def demo_lam(flametracking, lam, cfg):
255
+ @spaces.GPU(duration=80)
256
+ def core_fn(image_path: str, video_params, working_dir):
257
+ image_raw = os.path.join(working_dir.name, "raw.png")
258
+ with Image.open(image_path).convert('RGB') as img:
259
+ img.save(image_raw)
260
+
261
+ base_vid = os.path.basename(video_params).split(".")[0]
262
+ flame_params_dir = os.path.join("./assets/sample_motion/export", base_vid, "flame_param")
263
+ base_iid = os.path.basename(image_path).split('.')[0]
264
+ image_path = os.path.join("./assets/sample_input", base_iid, "images/00000_00.png")
265
+
266
+ dump_video_path = os.path.join(working_dir.name, "output.mp4")
267
+ dump_image_path = os.path.join(working_dir.name, "output.png")
268
+
269
+ # prepare dump paths
270
+ omit_prefix = os.path.dirname(image_raw)
271
+ image_name = os.path.basename(image_raw)
272
+ uid = image_name.split(".")[0]
273
+ subdir_path = os.path.dirname(image_raw).replace(omit_prefix, "")
274
+ subdir_path = (
275
+ subdir_path[1:] if subdir_path.startswith("/") else subdir_path
276
+ )
277
+ print("subdir_path and uid:", subdir_path, uid)
278
+
279
+ motion_seqs_dir = flame_params_dir
280
+
281
+ dump_image_dir = os.path.dirname(dump_image_path)
282
+ os.makedirs(dump_image_dir, exist_ok=True)
283
+
284
+ print(image_raw, motion_seqs_dir, dump_image_dir, dump_video_path)
285
+
286
+ dump_tmp_dir = dump_image_dir
287
+
288
+ if os.path.exists(dump_video_path):
289
+ return dump_image_path, dump_video_path
290
+
291
+ motion_img_need_mask = cfg.get("motion_img_need_mask", False) # False
292
+ vis_motion = cfg.get("vis_motion", False) # False
293
+
294
+ # preprocess input image: segmentation, flame params estimation
295
+ # """
296
+ return_code = flametracking.preprocess(image_raw)
297
+ assert (return_code == 0), "flametracking preprocess failed!"
298
+ return_code = flametracking.optimize()
299
+ assert (return_code == 0), "flametracking optimize failed!"
300
+ return_code, output_dir = flametracking.export()
301
+ assert (return_code == 0), "flametracking export failed!"
302
+ image_path = os.path.join(output_dir, "images/00000_00.png")
303
+ # """
304
+
305
+ mask_path = image_path.replace("/images/", "/fg_masks/").replace(".jpg", ".png")
306
+ print(image_path, mask_path)
307
+
308
+ aspect_standard = 1.0 / 1.0
309
+ source_size = cfg.source_size
310
+ render_size = cfg.render_size
311
+ render_fps = 30
312
+ # prepare reference image
313
+ image, _, _, shape_param = preprocess_image(image_path, mask_path=mask_path, intr=None, pad_ratio=0,
314
+ bg_color=1.,
315
+ max_tgt_size=None, aspect_standard=aspect_standard,
316
+ enlarge_ratio=[1.0, 1.0],
317
+ render_tgt_size=source_size, multiply=14, need_mask=True,
318
+ get_shape_param=True)
319
+
320
+ # save masked image for vis
321
+ save_ref_img_path = os.path.join(dump_tmp_dir, "output.png")
322
+ vis_ref_img = (image[0].permute(1, 2, 0).cpu().detach().numpy() * 255).astype(np.uint8)
323
+ Image.fromarray(vis_ref_img).save(save_ref_img_path)
324
+
325
+ # prepare motion seq
326
+ src = image_path.split('/')[-3]
327
+ driven = motion_seqs_dir.split('/')[-2]
328
+ src_driven = [src, driven]
329
+ motion_seq = prepare_motion_seqs(motion_seqs_dir, None, save_root=dump_tmp_dir, fps=render_fps,
330
+ bg_color=1., aspect_standard=aspect_standard, enlarge_ratio=[1.0, 1, 0],
331
+ render_image_res=render_size, multiply=16,
332
+ need_mask=motion_img_need_mask, vis_motion=vis_motion,
333
+ shape_param=shape_param, test_sample=False, cross_id=False,
334
+ src_driven=src_driven, max_squen_length=300)
335
+
336
+ # start inference
337
+ motion_seq["flame_params"]["betas"] = shape_param.unsqueeze(0)
338
+ device, dtype = "cuda", torch.float32
339
+ print("start to inference...................")
340
+ with torch.no_grad():
341
+ # TODO check device and dtype
342
+ res = lam.infer_single_view(image.unsqueeze(0).to(device, dtype), None, None,
343
+ render_c2ws=motion_seq["render_c2ws"].to(device),
344
+ render_intrs=motion_seq["render_intrs"].to(device),
345
+ render_bg_colors=motion_seq["render_bg_colors"].to(device),
346
+ flame_params={k: v.to(device) for k, v in motion_seq["flame_params"].items()})
347
+
348
+ rgb = res["comp_rgb"].detach().cpu().numpy() # [Nv, H, W, 3], 0-1
349
+ mask = res["comp_mask"].detach().cpu().numpy() # [Nv, H, W, 3], 0-1
350
+ mask[mask < 0.5] = 0.0
351
+ rgb = rgb * mask + (1 - mask) * 1
352
+ rgb = (np.clip(rgb, 0, 1.0) * 255).astype(np.uint8)
353
+ if vis_motion:
354
+ vis_ref_img = np.tile(
355
+ cv2.resize(vis_ref_img, (rgb[0].shape[1], rgb[0].shape[0]), interpolation=cv2.INTER_AREA)[None, :, :,
356
+ :],
357
+ (rgb.shape[0], 1, 1, 1),
358
+ )
359
+ rgb = np.concatenate([vis_ref_img, rgb, motion_seq["vis_motion_render"]], axis=2)
360
+
361
+ os.makedirs(os.path.dirname(dump_video_path), exist_ok=True)
362
+
363
+ print("==="*36, "\nrgb length:", rgb.shape, render_fps, "==="*36)
364
+ save_imgs_2_video(rgb, dump_video_path, render_fps)
365
+ # images_to_video(rgb, output_path=dump_video_path, fps=30, gradio_codec=False, verbose=True)
366
+ audio_path = os.path.join("./assets/sample_motion/export", base_vid, base_vid + ".wav")
367
+ dump_video_path_wa = dump_video_path.replace(".mp4", "_audio.mp4")
368
+ add_audio_to_video(dump_video_path, dump_video_path_wa, audio_path)
369
+
370
+ return dump_image_path, dump_video_path_wa
371
+
372
+ def core_fn_space(image_path: str, video_params, working_dir):
373
+ return core_fn(image_path, video_params, working_dir)
374
+
375
+ with gr.Blocks(analytics_enabled=False) as demo:
376
+
377
+ logo_url = './assets/images/logo.jpeg'
378
+ logo_base64 = get_image_base64(logo_url)
379
+ gr.HTML(f"""
380
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
381
+ <div>
382
+ <h1> <img src="{logo_base64}" style='height:35px; display:inline-block;'/> Large Avatar Model for One-shot Animatable Gaussian Head</h1>
383
+ </div>
384
+ </div>
385
+ """)
386
+
387
+ gr.HTML(
388
+ """
389
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center; margin: 20px; gap: 10px;">
390
+ <a class="flex-item" href="https://arxiv.org/abs/2502.17796" target="_blank">
391
+ <img src="https://img.shields.io/badge/Paper-arXiv-darkred.svg" alt="arXiv Paper">
392
+ </a>
393
+ <a class="flex-item" href="https://aigc3d.github.io/projects/LAM/" target="_blank">
394
+ <img src="https://img.shields.io/badge/Project-LAM-blue" alt="Project Page">
395
+ </a>
396
+ <a class="flex-item" href="https://github.com/aigc3d/LAM" target="_blank">
397
+ <img src="https://img.shields.io/github/stars/aigc3d/LAM?label=Github%20★&logo=github&color=C8C" alt="badge-github-stars">
398
+ </a>
399
+ <a class="flex-item" href="https://youtu.be/FrfE3RYSKhk" target="_blank">
400
+ <img src="https://img.shields.io/badge/Youtube-Video-red.svg" alt="Video">
401
+ </a>
402
+ </div>
403
+ """
404
+ )
405
+
406
+
407
+ gr.HTML("""<div style="margin-top: -10px">
408
+ <p style="margin: 4px 0; line-height: 1.2"><h4 style="color: red; margin: 2px 0">Notes1: Inputing front-face images or face orientation close to the driven signal gets better results.</h4></p>
409
+ <p style="margin: 4px 0; line-height: 1.2"><h4 style="color: red; margin: 2px 0">Notes2: Due to computational constraints with Hugging Face's ZeroGPU infrastructure, video generation requires ~1 minute per instance.</h4></p>
410
+ <p style="margin: 4px 0; line-height: 1.2"><h4 style="color: red; margin: 2px 0">Notes3: Using LAM-20K model (lower quality than premium LAM-80K) to mitigate processing latency.</h4></p>
411
+ </div>""")
412
+
413
+
414
+
415
+
416
+ # DISPLAY
417
+ with gr.Row():
418
+ with gr.Column(variant='panel', scale=1):
419
+ with gr.Tabs(elem_id='lam_input_image'):
420
+ with gr.TabItem('Input Image'):
421
+ with gr.Row():
422
+ input_image = gr.Image(label='Input Image',
423
+ image_mode='RGB',
424
+ height=480,
425
+ width=270,
426
+ sources='upload',
427
+ type='filepath',
428
+ elem_id='content_image')
429
+ # EXAMPLES
430
+ with gr.Row():
431
+ examples = [
432
+ ['assets/sample_input/messi.png'],
433
+ ['assets/sample_input/status.png'],
434
+ ['assets/sample_input/james.png'],
435
+ ['assets/sample_input/cluo.jpg'],
436
+ ['assets/sample_input/dufu.jpg'],
437
+ ['assets/sample_input/libai.jpg'],
438
+ ['assets/sample_input/barbara.jpg'],
439
+ ['assets/sample_input/pop.png'],
440
+ ['assets/sample_input/musk.jpg'],
441
+ ['assets/sample_input/speed.jpg'],
442
+ ['assets/sample_input/zhouxingchi.jpg'],
443
+ ]
444
+ gr.Examples(
445
+ examples=examples,
446
+ inputs=[input_image],
447
+ examples_per_page=20
448
+ )
449
+
450
+
451
+ with gr.Column():
452
+ with gr.Tabs(elem_id='lam_input_video'):
453
+ with gr.TabItem('Input Video'):
454
+ with gr.Row():
455
+ video_input = gr.Video(label='Input Video',
456
+ height=480,
457
+ width=270,
458
+ interactive=False)
459
+
460
+ examples = ['./assets/sample_motion/export/Speeding_Scandal/Speeding_Scandal.mp4',
461
+ './assets/sample_motion/export/Look_In_My_Eyes/Look_In_My_Eyes.mp4',
462
+ './assets/sample_motion/export/D_ANgelo_Dinero/D_ANgelo_Dinero.mp4',
463
+ './assets/sample_motion/export/Michael_Wayne_Rosen/Michael_Wayne_Rosen.mp4',
464
+ './assets/sample_motion/export/I_Am_Iron_Man/I_Am_Iron_Man.mp4',
465
+ './assets/sample_motion/export/Anti_Drugs/Anti_Drugs.mp4',
466
+ './assets/sample_motion/export/Pen_Pineapple_Apple_Pen/Pen_Pineapple_Apple_Pen.mp4',
467
+ './assets/sample_motion/export/Joe_Biden/Joe_Biden.mp4',
468
+ './assets/sample_motion/export/Donald_Trump/Donald_Trump.mp4',
469
+ './assets/sample_motion/export/Taylor_Swift/Taylor_Swift.mp4',
470
+ './assets/sample_motion/export/GEM/GEM.mp4',
471
+ './assets/sample_motion/export/The_Shawshank_Redemption/The_Shawshank_Redemption.mp4'
472
+ ]
473
+ print("Video example list {}".format(examples))
474
+
475
+ gr.Examples(
476
+ examples=examples,
477
+ inputs=[video_input],
478
+ examples_per_page=20,
479
+ )
480
+ with gr.Column(variant='panel', scale=1):
481
+ with gr.Tabs(elem_id='lam_processed_image'):
482
+ with gr.TabItem('Processed Image'):
483
+ with gr.Row():
484
+ processed_image = gr.Image(
485
+ label='Processed Image',
486
+ image_mode='RGBA',
487
+ type='filepath',
488
+ elem_id='processed_image',
489
+ height=480,
490
+ width=270,
491
+ interactive=False)
492
+
493
+ with gr.Column(variant='panel', scale=1):
494
+ with gr.Tabs(elem_id='lam_render_video'):
495
+ with gr.TabItem('Rendered Video'):
496
+ with gr.Row():
497
+ output_video = gr.Video(label='Rendered Video',
498
+ format='mp4',
499
+ height=480,
500
+ width=270,
501
+ autoplay=True)
502
+
503
+ # SETTING
504
+ with gr.Row():
505
+ with gr.Column(variant='panel', scale=1):
506
+ submit = gr.Button('Generate',
507
+ elem_id='lam_generate',
508
+ variant='primary')
509
+
510
+ main_fn = core_fn
511
+
512
+ working_dir = gr.State()
513
+ submit.click(
514
+ fn=assert_input_image,
515
+ inputs=[input_image],
516
+ queue=False,
517
+ ).success(
518
+ fn=prepare_working_dir,
519
+ outputs=[working_dir],
520
+ queue=False,
521
+ ).success(
522
+ fn=main_fn,
523
+ inputs=[input_image, video_input,
524
+ working_dir], # video_params refer to smpl dir
525
+ outputs=[processed_image, output_video],
526
+ )
527
+
528
+ demo.queue()
529
+ demo.launch()
530
+
531
+
532
+ def _build_model(cfg):
533
+ from lam.models import model_dict
534
+ from lam.utils.hf_hub import wrap_model_hub
535
+
536
+ hf_model_cls = wrap_model_hub(model_dict["lam"])
537
+ model = hf_model_cls.from_pretrained(cfg.model_name)
538
+
539
+ return model
540
+
541
+
542
+ def launch_gradio_app():
543
+ os.environ.update({
544
+ 'APP_ENABLED': '1',
545
+ 'APP_MODEL_NAME':
546
+ './exps/releases/lam/lam-20k/step_045500/',
547
+ 'APP_INFER': './configs/inference/lam-20k-8gpu.yaml',
548
+ 'APP_TYPE': 'infer.lam',
549
+ 'NUMBA_THREADING_LAYER': 'omp',
550
+ })
551
+
552
+ cfg, _ = parse_configs()
553
+ lam = _build_model(cfg)
554
+ lam.to('cuda')
555
+
556
+ flametracking = FlameTrackingSingleImage(output_dir='tracking_output',
557
+ alignment_model_path='./pretrain_model/68_keypoints_model.pkl',
558
+ vgghead_model_path='./pretrain_model/vgghead/vgg_heads_l.trcd',
559
+ human_matting_path='./pretrain_model/matting/stylematte_synth.pt',
560
+ facebox_model_path='./pretrain_model/FaceBoxesV2.pth',
561
+ detect_iris_landmarks=False)
562
+
563
+ demo_lam(flametracking, lam, cfg)
564
+
565
+
566
+ if __name__ == '__main__':
567
+ launch_pretrained()
568
+ launch_gradio_app()
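The Gradio handler core_fn() above strings together four stages: FLAME tracking on the input portrait, reference-image preprocessing, preparation of the driving FLAME motion sequence, and single-view LAM inference followed by video export. Below is a minimal, hedged sketch of that same flow outside the UI. run_once(), portrait_path and motion_dir are hypothetical names introduced here; lam, tracker and cfg are assumed to be constructed exactly as in launch_gradio_app(), and save_imgs_2_video() is the helper defined earlier in this file. Error handling and the audio-muxing step are omitted.

# Hypothetical helper; assumes the repository layout and pretrained assets downloaded by launch_pretrained().
import os
import torch
from lam.runners.infer.head_utils import prepare_motion_seqs, preprocess_image

def run_once(lam, tracker, cfg, portrait_path, motion_dir, out_path="output.mp4"):
    # 1) segmentation + FLAME parameter estimation for the input portrait
    assert tracker.preprocess(portrait_path) == 0
    assert tracker.optimize() == 0
    return_code, export_dir = tracker.export()
    assert return_code == 0

    image_path = os.path.join(export_dir, "images/00000_00.png")
    mask_path = image_path.replace("/images/", "/fg_masks/").replace(".jpg", ".png")

    # 2) reference image + FLAME shape parameters at the model's source resolution
    image, _, _, shape_param = preprocess_image(
        image_path, mask_path=mask_path, intr=None, pad_ratio=0, bg_color=1.,
        max_tgt_size=None, aspect_standard=1.0, enlarge_ratio=[1.0, 1.0],
        render_tgt_size=cfg.source_size, multiply=14, need_mask=True, get_shape_param=True)

    # 3) driving FLAME sequence (the flame_param folder exported next to each sample video)
    motion_seq = prepare_motion_seqs(
        motion_dir, None, save_root="./tmp", fps=30, bg_color=1., aspect_standard=1.0,
        enlarge_ratio=[1.0, 1.0], render_image_res=cfg.render_size, multiply=16,
        need_mask=False, vis_motion=False, shape_param=shape_param,
        test_sample=False, cross_id=False, src_driven=["src", "driven"])
    motion_seq["flame_params"]["betas"] = shape_param.unsqueeze(0)

    # 4) single-view LAM inference, then composite the head onto a white background
    with torch.no_grad():
        res = lam.infer_single_view(
            image.unsqueeze(0).to("cuda", torch.float32), None, None,
            render_c2ws=motion_seq["render_c2ws"].to("cuda"),
            render_intrs=motion_seq["render_intrs"].to("cuda"),
            render_bg_colors=motion_seq["render_bg_colors"].to("cuda"),
            flame_params={k: v.to("cuda") for k, v in motion_seq["flame_params"].items()})
    rgb = res["comp_rgb"].detach().cpu().numpy()    # [Nv, H, W, 3], 0-1
    mask = res["comp_mask"].detach().cpu().numpy()
    mask[mask < 0.5] = 0.0
    frames = ((rgb * mask + (1 - mask)) * 255).clip(0, 255).astype("uint8")
    save_imgs_2_video(frames, out_path, 30)         # helper defined above in app.py
    return out_path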
app_lam.py ADDED
@@ -0,0 +1,433 @@
1
+ # Copyright (c) 2024-2025, Yisheng He, Yuan Dong
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import cv2
17
+ import base64
18
+ import subprocess
19
+
20
+ import gradio as gr
21
+ import numpy as np
22
+ from PIL import Image
23
+ import argparse
24
+ from omegaconf import OmegaConf
25
+
26
+ import torch
27
+ from lam.runners.infer.head_utils import prepare_motion_seqs, preprocess_image
28
+ import moviepy.editor as mpy
29
+ from lam.utils.ffmpeg_utils import images_to_video
30
+ import sys
31
+ from flame_tracking_single_image import FlameTrackingSingleImage
32
+
33
+ try:
34
+ import spaces
35
+ except:
36
+ pass
37
+
38
+
39
+ def launch_pretrained():
40
+ from huggingface_hub import snapshot_download, hf_hub_download
41
+ hf_hub_download(repo_id='DyrusQZ/LHM_Runtime',
42
+ repo_type='model',
43
+ filename='assets.tar',
44
+ local_dir='./')
45
+ os.system('tar -xvf assets.tar && rm assets.tar')
46
+ hf_hub_download(repo_id='DyrusQZ/LHM_Runtime',
47
+ repo_type='model',
48
+ filename='LHM-0.5B.tar',
49
+ local_dir='./')
50
+ os.system('tar -xvf LHM-0.5B.tar && rm LHM-0.5B.tar')
51
+ hf_hub_download(repo_id='DyrusQZ/LHM_Runtime',
52
+ repo_type='model',
53
+ filename='LHM_prior_model.tar',
54
+ local_dir='./')
55
+ os.system('tar -xvf LHM_prior_model.tar && rm LHM_prior_model.tar')
56
+
57
+
58
+ def launch_env_not_compile_with_cuda():
59
+ os.system('pip install chumpy')
60
+ os.system('pip uninstall -y basicsr')
61
+ os.system('pip install git+https://github.com/hitsz-zuoqi/BasicSR/')
62
+ os.system('pip install numpy==1.23.0')
63
+ os.system(
64
+ 'pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt251/download.html'
65
+ )
66
+
67
+
68
+ def assert_input_image(input_image):
69
+ if input_image is None:
70
+ raise gr.Error('No image selected or uploaded!')
71
+
72
+
73
+ def prepare_working_dir():
74
+ import tempfile
75
+ working_dir = tempfile.TemporaryDirectory()
76
+ return working_dir
77
+
78
+
79
+ def init_preprocessor():
80
+ from lam.utils.preprocess import Preprocessor
81
+ global preprocessor
82
+ preprocessor = Preprocessor()
83
+
84
+
85
+ def preprocess_fn(image_in: np.ndarray, remove_bg: bool, recenter: bool,
86
+ working_dir):
87
+ image_raw = os.path.join(working_dir.name, 'raw.png')
88
+ with Image.fromarray(image_in) as img:
89
+ img.save(image_raw)
90
+ image_out = os.path.join(working_dir.name, 'rembg.png')
91
+ success = preprocessor.preprocess(image_path=image_raw,
92
+ save_path=image_out,
93
+ rmbg=remove_bg,
94
+ recenter=recenter)
95
+ assert success, f'Failed under preprocess_fn!'
96
+ return image_out
97
+
98
+
99
+ def get_image_base64(path):
100
+ with open(path, 'rb') as image_file:
101
+ encoded_string = base64.b64encode(image_file.read()).decode()
102
+ return f'data:image/png;base64,{encoded_string}'
103
+
104
+
105
+ def save_imgs_2_video(imgs, v_pth, fps):
106
+ img_lst = [imgs[i] for i in range(imgs.shape[0])]
107
+ # Convert the list of NumPy arrays to a list of ImageClip objects
108
+ clips = [mpy.ImageClip(img).set_duration(0.1) for img in img_lst] # 0.1 seconds per frame
109
+
110
+ # Concatenate the ImageClips into a single VideoClip
111
+ video = mpy.concatenate_videoclips(clips, method="compose")
112
+
113
+ # Write the VideoClip to a file
114
+ video.write_videofile(v_pth, fps=fps) # setting fps to 10 as example
115
+
116
+
117
+ def parse_configs():
118
+
119
+ parser = argparse.ArgumentParser()
120
+ parser.add_argument("--config", type=str)
121
+ parser.add_argument("--infer", type=str)
122
+ args, unknown = parser.parse_known_args()
123
+
124
+ cfg = OmegaConf.create()
125
+ cli_cfg = OmegaConf.from_cli(unknown)
126
+
127
+ # parse from ENV
128
+ if os.environ.get("APP_INFER") is not None:
129
+ args.infer = os.environ.get("APP_INFER")
130
+ if os.environ.get("APP_MODEL_NAME") is not None:
131
+ cli_cfg.model_name = os.environ.get("APP_MODEL_NAME")
132
+
133
+ args.config = args.infer if args.config is None else args.config
134
+
135
+ if args.config is not None:
136
+ cfg_train = OmegaConf.load(args.config)
137
+ cfg.source_size = cfg_train.dataset.source_image_res
138
+ try:
139
+ cfg.src_head_size = cfg_train.dataset.src_head_size
140
+ except:
141
+ cfg.src_head_size = 112
142
+ cfg.render_size = cfg_train.dataset.render_image.high
143
+ _relative_path = os.path.join(
144
+ cfg_train.experiment.parent,
145
+ cfg_train.experiment.child,
146
+ os.path.basename(cli_cfg.model_name).split("_")[-1],
147
+ )
148
+
149
+ cfg.save_tmp_dump = os.path.join("exps", "save_tmp", _relative_path)
150
+ cfg.image_dump = os.path.join("exps", "images", _relative_path)
151
+ cfg.video_dump = os.path.join("exps", "videos", _relative_path) # output path
152
+
153
+ if args.infer is not None:
154
+ cfg_infer = OmegaConf.load(args.infer)
155
+ cfg.merge_with(cfg_infer)
156
+ cfg.setdefault(
157
+ "save_tmp_dump", os.path.join("exps", cli_cfg.model_name, "save_tmp")
158
+ )
159
+ cfg.setdefault("image_dump", os.path.join("exps", cli_cfg.model_name, "images"))
160
+ cfg.setdefault(
161
+ "video_dump", os.path.join("dumps", cli_cfg.model_name, "videos")
162
+ )
163
+ cfg.setdefault("mesh_dump", os.path.join("dumps", cli_cfg.model_name, "meshes"))
164
+
165
+ cfg.motion_video_read_fps = 6
166
+ cfg.merge_with(cli_cfg)
167
+
168
+ cfg.setdefault("logger", "INFO")
169
+
170
+ assert cfg.model_name is not None, "model_name is required"
171
+
172
+ return cfg, cfg_train
173
+
174
+
175
+ def demo_lam(flametracking, lam, cfg):
176
+
177
+ # @spaces.GPU(duration=80)
178
+ def core_fn(image_path: str, video_params, working_dir):
179
+ image_raw = os.path.join(working_dir.name, "raw.png")
180
+ with Image.open(image_path).convert('RGB') as img:
181
+ img.save(image_raw)
182
+
183
+ base_vid = os.path.basename(video_params).split(".")[0]
184
+ flame_params_dir = os.path.join("./assets/sample_motion/export", base_vid, "flame_param")
185
+ base_iid = os.path.basename(image_path).split('.')[0]
186
+ image_path = os.path.join("./assets/sample_input", base_iid, "images/00000_00.png")
187
+
188
+ dump_video_path = os.path.join(working_dir.name, "output.mp4")
189
+ dump_image_path = os.path.join(working_dir.name, "output.png")
190
+
191
+ # prepare dump paths
192
+ omit_prefix = os.path.dirname(image_raw)
193
+ image_name = os.path.basename(image_raw)
194
+ uid = image_name.split(".")[0]
195
+ subdir_path = os.path.dirname(image_raw).replace(omit_prefix, "")
196
+ subdir_path = (
197
+ subdir_path[1:] if subdir_path.startswith("/") else subdir_path
198
+ )
199
+ print("subdir_path and uid:", subdir_path, uid)
200
+
201
+ motion_seqs_dir = flame_params_dir
202
+
203
+ dump_image_dir = os.path.dirname(dump_image_path)
204
+ os.makedirs(dump_image_dir, exist_ok=True)
205
+
206
+ print(image_raw, motion_seqs_dir, dump_image_dir, dump_video_path)
207
+
208
+ dump_tmp_dir = dump_image_dir
209
+
210
+ if os.path.exists(dump_video_path):
211
+ return dump_image_path, dump_video_path
212
+
213
+ motion_img_need_mask = cfg.get("motion_img_need_mask", False) # False
214
+ vis_motion = cfg.get("vis_motion", False) # False
215
+
216
+ # preprocess input image: segmentation, flame params estimation
217
+ return_code = flametracking.preprocess(image_raw)
218
+ assert (return_code == 0), "flametracking preprocess failed!"
219
+ return_code = flametracking.optimize()
220
+ assert (return_code == 0), "flametracking optimize failed!"
221
+ return_code, output_dir = flametracking.export()
222
+ assert (return_code == 0), "flametracking export failed!"
223
+
224
+ image_path = os.path.join(output_dir, "images/00000_00.png")
225
+ mask_path = image_path.replace("/images/", "/fg_masks/").replace(".jpg", ".png")
226
+ print(image_path, mask_path)
227
+
228
+ aspect_standard = 1.0/1.0
229
+ source_size = cfg.source_size
230
+ render_size = cfg.render_size
231
+ render_fps = 30
232
+ # prepare reference image
233
+ image, _, _, shape_param = preprocess_image(image_path, mask_path=mask_path, intr=None, pad_ratio=0, bg_color=1.,
234
+ max_tgt_size=None, aspect_standard=aspect_standard, enlarge_ratio=[1.0, 1.0],
235
+ render_tgt_size=source_size, multiply=14, need_mask=True, get_shape_param=True)
236
+
237
+ # save masked image for vis
238
+ save_ref_img_path = os.path.join(dump_tmp_dir, "output.png")
239
+ vis_ref_img = (image[0].permute(1, 2, 0).cpu().detach().numpy() * 255).astype(np.uint8)
240
+ Image.fromarray(vis_ref_img).save(save_ref_img_path)
241
+
242
+ # prepare motion seq
243
+ src = image_path.split('/')[-3]
244
+ driven = motion_seqs_dir.split('/')[-2]
245
+ src_driven = [src, driven]
246
+ motion_seq = prepare_motion_seqs(motion_seqs_dir, None, save_root=dump_tmp_dir, fps=render_fps,
247
+ bg_color=1., aspect_standard=aspect_standard, enlarge_ratio=[1.0, 1,0],
248
+ render_image_res=render_size, multiply=16,
249
+ need_mask=motion_img_need_mask, vis_motion=vis_motion,
250
+ shape_param=shape_param, test_sample=False, cross_id=False, src_driven=src_driven)
251
+
252
+ # start inference
253
+ motion_seq["flame_params"]["betas"] = shape_param.unsqueeze(0)
254
+ device, dtype = "cuda", torch.float32
255
+ print("start to inference...................")
256
+ with torch.no_grad():
257
+ # TODO check device and dtype
258
+ res = lam.infer_single_view(image.unsqueeze(0).to(device, dtype), None, None,
259
+ render_c2ws=motion_seq["render_c2ws"].to(device),
260
+ render_intrs=motion_seq["render_intrs"].to(device),
261
+ render_bg_colors=motion_seq["render_bg_colors"].to(device),
262
+ flame_params={k:v.to(device) for k, v in motion_seq["flame_params"].items()})
263
+
264
+ rgb = res["comp_rgb"].detach().cpu().numpy() # [Nv, H, W, 3], 0-1
265
+ mask = res["comp_mask"].detach().cpu().numpy() # [Nv, H, W, 3], 0-1
266
+ mask[mask < 0.5] = 0.0
267
+ rgb = rgb * mask + (1 - mask) * 1
268
+ rgb = (np.clip(rgb, 0, 1.0) * 255).astype(np.uint8)
269
+ if vis_motion:
270
+ vis_ref_img = np.tile(
271
+ cv2.resize(vis_ref_img, (rgb[0].shape[1], rgb[0].shape[0]), interpolation=cv2.INTER_AREA)[None, :, :, :],
272
+ (rgb.shape[0], 1, 1, 1),
273
+ )
274
+ rgb = np.concatenate([vis_ref_img, rgb, motion_seq["vis_motion_render"]], axis=2)
275
+
276
+ os.makedirs(os.path.dirname(dump_video_path), exist_ok=True)
277
+
278
+ save_imgs_2_video(rgb, dump_video_path, render_fps)
279
+ # images_to_video(rgb, output_path=dump_video_path, fps=30, gradio_codec=False, verbose=True)
280
+
281
+ return dump_image_path, dump_video_path
282
+
283
+ with gr.Blocks(analytics_enabled=False) as demo:
284
+
285
+ logo_url = './assets/images/logo.png'
286
+ logo_base64 = get_image_base64(logo_url)
287
+ gr.HTML(f"""
288
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
289
+ <div>
290
+ <h1> <img src="{logo_base64}" style='height:35px; display:inline-block;'/> LAM: Large Avatar Model for One-shot Animatable Gaussian Head</h1>
291
+ </div>
292
+ </div>
293
+ """)
294
+ gr.HTML(
295
+ """<p><h4 style="color: red;"> Notes: Inputing front-face images or face orientation close to the driven signal gets better results.</h4></p>"""
296
+ )
297
+
298
+ # DISPLAY
299
+ with gr.Row():
300
+
301
+ with gr.Column(variant='panel', scale=1):
302
+ with gr.Tabs(elem_id='lam_input_image'):
303
+ with gr.TabItem('Input Image'):
304
+ with gr.Row():
305
+ input_image = gr.Image(label='Input Image',
306
+ image_mode='RGB',
307
+ height=480,
308
+ width=270,
309
+ sources='upload',
310
+ type='filepath', # 'numpy',
311
+ elem_id='content_image')
312
+ # EXAMPLES
313
+ with gr.Row():
314
+ examples = [
315
+ ['assets/sample_input/2w01/images/2w01.png'],
316
+ ['assets/sample_input/2w02/images/2w02.png'],
317
+ ['assets/sample_input/2w03/images/2w03.png'],
318
+ ['assets/sample_input/2w04/images/2w04.png'],
319
+ ]
320
+ gr.Examples(
321
+ examples=examples,
322
+ inputs=[input_image],
323
+ examples_per_page=20,
324
+ )
325
+
326
+ with gr.Column():
327
+ with gr.Tabs(elem_id='lam_input_video'):
328
+ with gr.TabItem('Input Video'):
329
+ with gr.Row():
330
+ video_input = gr.Video(label='Input Video',
331
+ height=480,
332
+ width=270,
333
+ interactive=False)
334
+
335
+ examples = [
336
+ './assets/sample_motion/export/clip1/clip1.mp4',
337
+ './assets/sample_motion/export/clip2/clip2.mp4',
338
+ './assets/sample_motion/export/clip3/clip3.mp4',
339
+ ]
340
+
341
+ gr.Examples(
342
+ examples=examples,
343
+ inputs=[video_input],
344
+ examples_per_page=20,
345
+ )
346
+ with gr.Column(variant='panel', scale=1):
347
+ with gr.Tabs(elem_id='lam_processed_image'):
348
+ with gr.TabItem('Processed Image'):
349
+ with gr.Row():
350
+ processed_image = gr.Image(
351
+ label='Processed Image',
352
+ image_mode='RGBA',
353
+ type='filepath',
354
+ elem_id='processed_image',
355
+ height=480,
356
+ width=270,
357
+ interactive=False)
358
+
359
+ with gr.Column(variant='panel', scale=1):
360
+ with gr.Tabs(elem_id='lam_render_video'):
361
+ with gr.TabItem('Rendered Video'):
362
+ with gr.Row():
363
+ output_video = gr.Video(label='Rendered Video',
364
+ format='mp4',
365
+ height=480,
366
+ width=270,
367
+ autoplay=True)
368
+
369
+ # SETTING
370
+ with gr.Row():
371
+ with gr.Column(variant='panel', scale=1):
372
+ submit = gr.Button('Generate',
373
+ elem_id='lam_generate',
374
+ variant='primary')
375
+
376
+ working_dir = gr.State()
377
+ submit.click(
378
+ fn=assert_input_image,
379
+ inputs=[input_image],
380
+ queue=False,
381
+ ).success(
382
+ fn=prepare_working_dir,
383
+ outputs=[working_dir],
384
+ queue=False,
385
+ ).success(
386
+ fn=core_fn,
387
+ inputs=[input_image, video_input,
388
+ working_dir], # video_params refer to smpl dir
389
+ outputs=[processed_image, output_video],
390
+ )
391
+
392
+ demo.queue()
393
+ demo.launch()
394
+
395
+
396
+ def _build_model(cfg):
397
+ from lam.models import model_dict
398
+ from lam.utils.hf_hub import wrap_model_hub
399
+
400
+ hf_model_cls = wrap_model_hub(model_dict["lam"])
401
+ model = hf_model_cls.from_pretrained(cfg.model_name)
402
+
403
+ return model
404
+
405
+ def launch_gradio_app():
406
+
407
+ os.environ.update({
408
+ 'APP_ENABLED': '1',
409
+ 'APP_MODEL_NAME':
410
+ './exps/releases/lam/lam-20k/step_045500/',
411
+ 'APP_INFER': './configs/inference/lam-20k-8gpu.yaml',
412
+ 'APP_TYPE': 'infer.lam',
413
+ 'NUMBA_THREADING_LAYER': 'omp',
414
+ })
415
+
416
+ cfg, _ = parse_configs()
417
+ lam = _build_model(cfg)
418
+ lam.to('cuda')
419
+
420
+ flametracking = FlameTrackingSingleImage(output_dir='tracking_output',
421
+ alignment_model_path='./pretrain_model/68_keypoints_model.pkl',
422
+ vgghead_model_path='./pretrain_model/vgghead/vgg_heads_l.trcd',
423
+ human_matting_path='./pretrain_model/matting/stylematte_synth.pt',
424
+ facebox_model_path='./pretrain_model/FaceBoxesV2.pth',
425
+ detect_iris_landmarks=True)
426
+
427
+ demo_lam(flametracking, lam, cfg)
428
+
429
+
430
+ if __name__ == '__main__':
431
+ # launch_pretrained()
432
+ # launch_env_not_compile_with_cuda()
433
+ launch_gradio_app()
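One behavioural difference from app.py: the save_imgs_2_video() defined here gives every frame a fixed 0.1 s duration via ImageClip.set_duration() while also taking an fps argument, so the written video only plays at the intended speed when fps happens to be 10. A minimal sketch of a variant that keeps the same signature but derives frame timing from fps, using the same moviepy calls as the app.py version:

import numpy as np
import moviepy.editor as mpy

def save_imgs_2_video(imgs: np.ndarray, v_pth: str, fps: int) -> None:
    # Build the clip straight from the frame list so each frame lasts exactly 1/fps seconds.
    frames = [imgs[i].astype(np.uint8) for i in range(imgs.shape[0])]
    clip = mpy.ImageSequenceClip(frames, fps=fps)
    clip.write_videofile(v_pth, fps=fps, codec="libx264")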
app_preprocess.py ADDED
@@ -0,0 +1,387 @@
1
+ # Copyright (c) 2023-2024, Qi Zuo
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ os.system('rm -rf /data-nvme/zerogpu-offload/')
17
+ os.system('pip install numpy==1.23.0')
18
+ os.system('pip install ./wheels/pytorch3d-0.7.3-cp310-cp310-linux_x86_64.whl')
19
+
20
+ import argparse
21
+ import base64
22
+ import time
23
+
24
+ import cv2
25
+ import numpy as np
26
+ import torch
27
+ from omegaconf import OmegaConf
28
+ from PIL import Image
29
+
30
+ import gradio as gr
31
+ import spaces
32
+ from flame_tracking_single_image import FlameTrackingSingleImage
33
+ from ffmpeg_utils import images_to_video
34
+
35
+ # torch._dynamo.config.disable = True
36
+
37
+
38
+ def parse_configs():
39
+
40
+ parser = argparse.ArgumentParser()
41
+ parser.add_argument('--config', type=str)
42
+ parser.add_argument('--infer', type=str)
43
+ args, unknown = parser.parse_known_args()
44
+
45
+ cfg = OmegaConf.create()
46
+ cli_cfg = OmegaConf.from_cli(unknown)
47
+
48
+ # parse from ENV
49
+ if os.environ.get('APP_INFER') is not None:
50
+ args.infer = os.environ.get('APP_INFER')
51
+ if os.environ.get('APP_MODEL_NAME') is not None:
52
+ cli_cfg.model_name = os.environ.get('APP_MODEL_NAME')
53
+
54
+ args.config = args.infer if args.config is None else args.config
55
+
56
+ if args.config is not None:
57
+ cfg_train = OmegaConf.load(args.config)
58
+ cfg.source_size = cfg_train.dataset.source_image_res
59
+ try:
60
+ cfg.src_head_size = cfg_train.dataset.src_head_size
61
+ except:
62
+ cfg.src_head_size = 112
63
+ cfg.render_size = cfg_train.dataset.render_image.high
64
+ _relative_path = os.path.join(
65
+ cfg_train.experiment.parent,
66
+ cfg_train.experiment.child,
67
+ os.path.basename(cli_cfg.model_name).split('_')[-1],
68
+ )
69
+
70
+ cfg.save_tmp_dump = os.path.join('exps', 'save_tmp', _relative_path)
71
+ cfg.image_dump = os.path.join('exps', 'images', _relative_path)
72
+ cfg.video_dump = os.path.join('exps', 'videos',
73
+ _relative_path) # output path
74
+
75
+ if args.infer is not None:
76
+ cfg_infer = OmegaConf.load(args.infer)
77
+ cfg.merge_with(cfg_infer)
78
+ cfg.setdefault('save_tmp_dump',
79
+ os.path.join('exps', cli_cfg.model_name, 'save_tmp'))
80
+ cfg.setdefault('image_dump',
81
+ os.path.join('exps', cli_cfg.model_name, 'images'))
82
+ cfg.setdefault('video_dump',
83
+ os.path.join('dumps', cli_cfg.model_name, 'videos'))
84
+ cfg.setdefault('mesh_dump',
85
+ os.path.join('dumps', cli_cfg.model_name, 'meshes'))
86
+
87
+ cfg.motion_video_read_fps = 6
88
+ cfg.merge_with(cli_cfg)
89
+
90
+ cfg.setdefault('logger', 'INFO')
91
+
92
+ assert cfg.model_name is not None, 'model_name is required'
93
+
94
+ return cfg, cfg_train
95
+
96
+
97
+
98
+ def launch_pretrained():
99
+ from huggingface_hub import snapshot_download, hf_hub_download
100
+ hf_hub_download(repo_id='yuandong513/flametracking_model',
101
+ repo_type='model',
102
+ filename='pretrain_model.tar',
103
+ local_dir='./')
104
+ os.system('tar -xf pretrain_model.tar && rm pretrain_model.tar')
105
+
106
+ def animation_infer(renderer, gs_model_list, query_points, smplx_params,
107
+ render_c2ws, render_intrs, render_bg_colors):
108
+ '''Inference code avoid repeat forward.
109
+ '''
110
+ render_h, render_w = int(render_intrs[0, 0, 1, 2] * 2), int(
111
+ render_intrs[0, 0, 0, 2] * 2)
112
+ # render target views
113
+ render_res_list = []
114
+ num_views = render_c2ws.shape[1]
115
+ start_time = time.time()
116
+
117
+ # render target views
118
+ render_res_list = []
119
+
120
+ for view_idx in range(num_views):
121
+ render_res = renderer.forward_animate_gs(
122
+ gs_model_list,
123
+ query_points,
124
+ renderer.get_single_view_smpl_data(smplx_params, view_idx),
125
+ render_c2ws[:, view_idx:view_idx + 1],
126
+ render_intrs[:, view_idx:view_idx + 1],
127
+ render_h,
128
+ render_w,
129
+ render_bg_colors[:, view_idx:view_idx + 1],
130
+ )
131
+ render_res_list.append(render_res)
132
+ print(
133
+ f'time elpased(animate gs model per frame):{(time.time() - start_time)/num_views}'
134
+ )
135
+
136
+ out = defaultdict(list)
137
+ for res in render_res_list:
138
+ for k, v in res.items():
139
+ if isinstance(v[0], torch.Tensor):
140
+ out[k].append(v.detach().cpu())
141
+ else:
142
+ out[k].append(v)
143
+ for k, v in out.items():
144
+ # print(f"out key:{k}")
145
+ if isinstance(v[0], torch.Tensor):
146
+ out[k] = torch.concat(v, dim=1)
147
+ if k in ['comp_rgb', 'comp_mask', 'comp_depth']:
148
+ out[k] = out[k][0].permute(
149
+ 0, 2, 3,
150
+ 1) # [1, Nv, 3, H, W] -> [Nv, 3, H, W] - > [Nv, H, W, 3]
151
+ else:
152
+ out[k] = v
153
+ return out
154
+
155
+
156
+ def assert_input_image(input_image):
157
+ if input_image is None:
158
+ raise gr.Error('No image selected or uploaded!')
159
+
160
+
161
+ def prepare_working_dir():
162
+ import tempfile
163
+ working_dir = tempfile.TemporaryDirectory()
164
+ return working_dir
165
+
166
+ def get_image_base64(path):
167
+ with open(path, 'rb') as image_file:
168
+ encoded_string = base64.b64encode(image_file.read()).decode()
169
+ return f'data:image/png;base64,{encoded_string}'
170
+
171
+
172
+ def demo_lhm(flametracking):
173
+ @spaces.GPU(duration=80)
174
+ def core_fn(image: str, video_params, working_dir):
175
+ image_raw = os.path.join(working_dir.name, 'raw.png')
176
+ with Image.fromarray(image) as img:
177
+ img.save(image_raw)
178
+
179
+ base_vid = os.path.basename(video_params).split('_')[0]
180
+
181
+ dump_video_path = os.path.join(working_dir.name, 'output.mp4')
182
+ dump_image_path = os.path.join(working_dir.name, 'output.png')
183
+
184
+ # prepare dump paths
185
+ omit_prefix = os.path.dirname(image_raw)
186
+ image_name = os.path.basename(image_raw)
187
+ uid = image_name.split('.')[0]
188
+ subdir_path = os.path.dirname(image_raw).replace(omit_prefix, '')
189
+ subdir_path = (subdir_path[1:]
190
+ if subdir_path.startswith('/') else subdir_path)
191
+ print('==> subdir_path and uid:', subdir_path, uid)
192
+
193
+ dump_image_dir = os.path.dirname(dump_image_path)
194
+ os.makedirs(dump_image_dir, exist_ok=True)
195
+
196
+ print('==> path:', image_raw, dump_image_dir, dump_video_path)
197
+
198
+ dump_tmp_dir = dump_image_dir
199
+
200
+ return_code = flametracking.preprocess(image_raw)
201
+ return_code = flametracking.optimize()
202
+ return_code, output_dir = flametracking.export()
203
+
204
+ print("==> output_dir:", output_dir)
205
+
206
+
207
+ save_ref_img_path = os.path.join(dump_tmp_dir, 'output.png')
208
+ vis_ref_img = (image[0].permute(1, 2, 0).cpu().detach().numpy() *
209
+ 255).astype(np.uint8)
210
+ Image.fromarray(vis_ref_img).save(save_ref_img_path)
211
+
212
+ # rendering !!!!
213
+ start_time = time.time()
214
+ batch_dict = dict()
215
+
216
+ rgb = cv2.imread(os.path.join(output_dir,'images/00000_00.png'))
217
+
218
+ for i in range(30):
219
+ images_to_video(
220
+ rgb,
221
+ output_path=dump_video_path,
222
+ fps=30,
223
+ gradio_codec=False,
224
+ verbose=True,
225
+ )
226
+
227
+ return dump_image_path, dump_video_path
228
+
229
+ _TITLE = '''LHM: Large Animatable Human Model'''
230
+
231
+ _DESCRIPTION = '''
232
+ <strong>Reconstruct a human avatar in 0.2 seconds with A100!</strong>
233
+ '''
234
+
235
+ with gr.Blocks(analytics_enabled=False, delete_cache=[3600, 3600]) as demo:
236
+
237
+ # </div>
238
+ logo_url = './asset/logo.jpeg'
239
+ logo_base64 = get_image_base64(logo_url)
240
+ gr.HTML(f"""
241
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
242
+ <div>
243
+ <h1> <img src="{logo_base64}" style='height:35px; display:inline-block;'/> Large Animatable Human Model </h1>
244
+ </div>
245
+ </div>
246
+ """)
247
+
248
+ gr.HTML("""
249
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center; margin: 20px; gap: 10px;">
250
+ <a class="flex-item" href="https://arxiv.org/abs/2503.10625" target="_blank">
251
+ <img src="https://img.shields.io/badge/Paper-arXiv-darkred.svg" alt="arXiv Paper">
252
+ </a>
253
+ <a class="flex-item" href="https://lingtengqiu.github.io/LHM/" target="_blank">
254
+ <img src="https://img.shields.io/badge/Project-LHM-blue" alt="Project Page">
255
+ </a>
256
+ <a class="flex-item" href="https://github.com/aigc3d/LHM" target="_blank">
257
+ <img src="https://img.shields.io/github/stars/aigc3d/LHM?label=Github%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
258
+ </a>
259
+ <a class="flex-item" href="https://www.youtube.com/watch?v=tivEpz_yiEo" target="_blank">
260
+ <img src="https://img.shields.io/badge/Youtube-Video-red.svg" alt="Video">
261
+ </a>
262
+ </div>
263
+ """)
264
+
265
+ gr.HTML(
266
+ """<p><h4 style="color: red;"> Notes: Please input full-body image in case of detection errors. We simplify the pipeline in spaces: 1) using Rembg instead of SAM2; 2) limit the output video length to 10s; For best visual quality, try the inference code on Github instead.</h4></p>"""
267
+ )
268
+
269
+ # DISPLAY
270
+ with gr.Row():
271
+
272
+ with gr.Column(variant='panel', scale=1):
273
+ with gr.Tabs(elem_id='openlrm_input_image'):
274
+ with gr.TabItem('Input Image'):
275
+ with gr.Row():
276
+ input_image = gr.Image(label='Input Image',
277
+ image_mode='RGB',
278
+ height=480,
279
+ width=270,
280
+ sources='upload',
281
+ type='numpy',
282
+ elem_id='content_image')
283
+ # EXAMPLES
284
+ with gr.Row():
285
+ examples = [
286
+ ['asset/sample_input/00000.png'],
287
+ ]
288
+ gr.Examples(
289
+ examples=examples,
290
+ inputs=[input_image],
291
+ examples_per_page=10,
292
+ )
293
+
294
+ with gr.Column():
295
+ with gr.Tabs(elem_id='openlrm_input_video'):
296
+ with gr.TabItem('Input Video'):
297
+ with gr.Row():
298
+ video_input = gr.Video(label='Input Video',
299
+ height=480,
300
+ width=270,
301
+ interactive=False)
302
+
303
+ examples = [
304
+ './asset/sample_input/demo.mp4',
305
+ ]
306
+
307
+ gr.Examples(
308
+ examples=examples,
309
+ inputs=[video_input],
310
+ examples_per_page=20,
311
+ )
312
+ with gr.Column(variant='panel', scale=1):
313
+ with gr.Tabs(elem_id='openlrm_processed_image'):
314
+ with gr.TabItem('Processed Image'):
315
+ with gr.Row():
316
+ processed_image = gr.Image(
317
+ label='Processed Image',
318
+ image_mode='RGB',
319
+ type='filepath',
320
+ elem_id='processed_image',
321
+ height=480,
322
+ width=270,
323
+ interactive=False)
324
+
325
+ with gr.Column(variant='panel', scale=1):
326
+ with gr.Tabs(elem_id='openlrm_render_video'):
327
+ with gr.TabItem('Rendered Video'):
328
+ with gr.Row():
329
+ output_video = gr.Video(label='Rendered Video',
330
+ format='mp4',
331
+ height=480,
332
+ width=270,
333
+ autoplay=True)
334
+
335
+ # SETTING
336
+ with gr.Row():
337
+ with gr.Column(variant='panel', scale=1):
338
+ submit = gr.Button('Generate',
339
+ elem_id='openlrm_generate',
340
+ variant='primary')
341
+
342
+ working_dir = gr.State()
343
+ submit.click(
344
+ fn=assert_input_image,
345
+ inputs=[input_image],
346
+ queue=False,
347
+ ).success(
348
+ fn=prepare_working_dir,
349
+ outputs=[working_dir],
350
+ queue=False,
351
+ ).success(
352
+ fn=core_fn,
353
+ inputs=[input_image, video_input,
354
+ working_dir], # video_params refer to smpl dir
355
+ outputs=[processed_image, output_video],
356
+ )
357
+
358
+ demo.queue(max_size=1)
359
+ demo.launch()
360
+
361
+
362
+ def launch_gradio_app():
363
+
364
+ os.environ.update({
365
+ 'APP_ENABLED': '1',
366
+ 'APP_MODEL_NAME':
367
+ './exps/releases/video_human_benchmark/human-lrm-500M/step_060000/',
368
+ 'APP_INFER': './configs/inference/human-lrm-500M.yaml',
369
+ 'APP_TYPE': 'infer.human_lrm',
370
+ 'NUMBA_THREADING_LAYER': 'omp',
371
+ })
372
+
373
+ flametracking = FlameTrackingSingleImage(output_dir='tracking_output',
374
+ alignment_model_path='./pretrain_model/68_keypoints_model.pkl',
375
+ vgghead_model_path='./pretrain_model/vgghead/vgg_heads_l.trcd',
376
+ human_matting_path='./pretrain_model/matting/stylematte_synth.pt',
377
+ facebox_model_path='./pretrain_model/FaceBoxesV2.pth',
378
+ detect_iris_landmarks=True)
379
+
380
+
381
+ demo_lhm(flametracking)
382
+
383
+
384
+ if __name__ == '__main__':
385
+ launch_pretrained()
386
+ launch_gradio_app()
387
+
configs/inference/lam-20k-8gpu.yaml ADDED
@@ -0,0 +1,130 @@
1
+
2
+ experiment:
3
+ type: lam
4
+ seed: 42
5
+ parent: lam
6
+ child: lam_20k
7
+ model:
8
+ # image encoder
9
+ encoder_type: "dinov2_fusion"
10
+ encoder_model_name: "dinov2_vitl14_reg"
11
+ encoder_feat_dim: 1024
12
+ encoder_freeze: false
13
+
14
+ # points embeddings
15
+ latent_query_points_type: "e2e_flame"
16
+ pcl_dim: 1024
17
+
18
+ # transformer
19
+ transformer_type: "sd3_cond"
20
+ transformer_heads: 16
21
+ transformer_dim: 1024
22
+ transformer_layers: 10
23
+ tf_grad_ckpt: true
24
+ encoder_grad_ckpt: true
25
+
26
+ # for gs renderer
27
+ human_model_path: "./pretrained_models/human_model_files"
28
+ flame_subdivide_num: 1
29
+ flame_type: "flame"
30
+ gs_query_dim: 1024
31
+ gs_use_rgb: True
32
+ gs_sh: 3
33
+ gs_mlp_network_config:
34
+ n_neurons: 512
35
+ n_hidden_layers: 2
36
+ activation: silu
37
+ gs_xyz_offset_max_step: 0.2
38
+ gs_clip_scaling: 0.01
39
+ scale_sphere: false
40
+
41
+ expr_param_dim: 10
42
+ shape_param_dim: 10
43
+ add_teeth: false
44
+
45
+ fix_opacity: false
46
+ fix_rotation: false
47
+
48
+ has_disc: false
49
+
50
+ teeth_bs_flag: false
51
+ oral_mesh_flag: false
52
+
53
+ dataset:
54
+ subsets:
55
+ - name: video_head
56
+ root_dirs: "./train_data/vfhq_vhap_nooffset/export"
57
+ meta_path:
58
+ train: "./train_data/vfhq_vhap_nooffset/label/valid_id_train_list.json"
59
+ val: "./train_data/vfhq_vhap_nooffset/label/valid_id_val_list.json"
60
+ sample_rate: 1.0
61
+ sample_side_views: 7
62
+ sample_aug_views: 0
63
+ source_image_res: 512
64
+ render_image:
65
+ low: 512
66
+ high: 512
67
+ region: null
68
+ num_train_workers: 4
69
+ num_val_workers: 2
70
+ pin_mem: true
71
+ repeat_num: 1
72
+ gaga_track_type: "vfhq"
73
+
74
+ train:
75
+ mixed_precision: bf16 # REPLACE THIS BASED ON GPU TYPE
76
+ find_unused_parameters: false
77
+ loss:
78
+ pixel_weight: 0.0
79
+ pixel_loss_fn: "mse"
80
+ crop_face_weight: 0.
81
+ crop_mouth_weight: 0.
82
+ crop_eye_weight: 0.
83
+ masked_pixel_weight: 1.0
84
+ perceptual_weight: 1.0
85
+ tv_weight: -1
86
+ mask_weight: 0:1.0:0.5:10000
87
+ offset_reg_weight: 0.1
88
+ optim:
89
+ lr: 4e-4
90
+ weight_decay: 0.05
91
+ beta1: 0.9
92
+ beta2: 0.95
93
+ clip_grad_norm: 1.0
94
+ scheduler:
95
+ type: cosine
96
+ warmup_real_iters: 3000
97
+ batch_size: 4 # REPLACE THIS (PER GPU)
98
+ accum_steps: 1 # REPLACE THIS
99
+ epochs: 100 # REPLACE THIS
100
+ debug_global_steps: null
101
+ resume: ""
102
+
103
+ val:
104
+ batch_size: 2
105
+ global_step_period: 500
106
+ debug_batches: 10
107
+
108
+ saver:
109
+ auto_resume: true
110
+ load_model: null
111
+ checkpoint_root: ./exps/checkpoints
112
+ checkpoint_global_steps: 500
113
+ checkpoint_keep_level: 5
114
+
115
+ logger:
116
+ stream_level: WARNING
117
+ log_level: INFO
118
+ log_root: ./exps/logs
119
+ tracker_root: ./exps/trackers
120
+ enable_profiler: false
121
+ trackers:
122
+ - tensorboard
123
+ image_monitor:
124
+ train_global_steps: 500
125
+ samples_per_log: 4
126
+
127
+ compile:
128
+ suppress_errors: true
129
+ print_specializations: true
130
+ disable: true
configs/stylematte_config.json ADDED
@@ -0,0 +1,2311 @@
1
+ {
2
+ "_commit_hash": null,
3
+ "activation_function": "relu",
4
+ "architectures": [
5
+ "Mask2FormerForUniversalSegmentation"
6
+ ],
7
+ "backbone_config": {
8
+ "_name_or_path": "",
9
+ "add_cross_attention": false,
10
+ "architectures": [
11
+ "SwinForImageClassification"
12
+ ],
13
+ "attention_probs_dropout_prob": 0.0,
14
+ "bad_words_ids": null,
15
+ "begin_suppress_tokens": null,
16
+ "bos_token_id": null,
17
+ "chunk_size_feed_forward": 0,
18
+ "cross_attention_hidden_size": null,
19
+ "decoder_start_token_id": null,
20
+ "depths": [
21
+ 2,
22
+ 2,
23
+ 6,
24
+ 2
25
+ ],
26
+ "diversity_penalty": 0.0,
27
+ "do_sample": false,
28
+ "drop_path_rate": 0.3,
29
+ "early_stopping": false,
30
+ "embed_dim": 96,
31
+ "encoder_no_repeat_ngram_size": 0,
32
+ "encoder_stride": 32,
33
+ "eos_token_id": null,
34
+ "exponential_decay_length_penalty": null,
35
+ "finetuning_task": null,
36
+ "forced_bos_token_id": null,
37
+ "forced_eos_token_id": null,
38
+ "hidden_act": "gelu",
39
+ "hidden_dropout_prob": 0.0,
40
+ "hidden_size": 768,
41
+ "id2label": {
42
+ "0": "tench, Tinca tinca",
43
+ "1": "goldfish, Carassius auratus",
44
+ "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
45
+ "3": "tiger shark, Galeocerdo cuvieri",
46
+ "4": "hammerhead, hammerhead shark",
47
+ "5": "electric ray, crampfish, numbfish, torpedo",
48
+ "6": "stingray",
49
+ "7": "cock",
50
+ "8": "hen",
51
+ "9": "ostrich, Struthio camelus",
52
+ "10": "brambling, Fringilla montifringilla",
53
+ "11": "goldfinch, Carduelis carduelis",
54
+ "12": "house finch, linnet, Carpodacus mexicanus",
55
+ "13": "junco, snowbird",
56
+ "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
57
+ "15": "robin, American robin, Turdus migratorius",
58
+ "16": "bulbul",
59
+ "17": "jay",
60
+ "18": "magpie",
61
+ "19": "chickadee",
62
+ "20": "water ouzel, dipper",
63
+ "21": "kite",
64
+ "22": "bald eagle, American eagle, Haliaeetus leucocephalus",
65
+ "23": "vulture",
66
+ "24": "great grey owl, great gray owl, Strix nebulosa",
67
+ "25": "European fire salamander, Salamandra salamandra",
68
+ "26": "common newt, Triturus vulgaris",
69
+ "27": "eft",
70
+ "28": "spotted salamander, Ambystoma maculatum",
71
+ "29": "axolotl, mud puppy, Ambystoma mexicanum",
72
+ "30": "bullfrog, Rana catesbeiana",
73
+ "31": "tree frog, tree-frog",
74
+ "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
75
+ "33": "loggerhead, loggerhead turtle, Caretta caretta",
76
+ "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
77
+ "35": "mud turtle",
78
+ "36": "terrapin",
79
+ "37": "box turtle, box tortoise",
80
+ "38": "banded gecko",
81
+ "39": "common iguana, iguana, Iguana iguana",
82
+ "40": "American chameleon, anole, Anolis carolinensis",
83
+ "41": "whiptail, whiptail lizard",
84
+ "42": "agama",
85
+ "43": "frilled lizard, Chlamydosaurus kingi",
86
+ "44": "alligator lizard",
87
+ "45": "Gila monster, Heloderma suspectum",
88
+ "46": "green lizard, Lacerta viridis",
89
+ "47": "African chameleon, Chamaeleo chamaeleon",
90
+ "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
91
+ "49": "African crocodile, Nile crocodile, Crocodylus niloticus",
92
+ "50": "American alligator, Alligator mississipiensis",
93
+ "51": "triceratops",
94
+ "52": "thunder snake, worm snake, Carphophis amoenus",
95
+ "53": "ringneck snake, ring-necked snake, ring snake",
96
+ "54": "hognose snake, puff adder, sand viper",
97
+ "55": "green snake, grass snake",
98
+ "56": "king snake, kingsnake",
99
+ "57": "garter snake, grass snake",
100
+ "58": "water snake",
101
+ "59": "vine snake",
102
+ "60": "night snake, Hypsiglena torquata",
103
+ "61": "boa constrictor, Constrictor constrictor",
104
+ "62": "rock python, rock snake, Python sebae",
105
+ "63": "Indian cobra, Naja naja",
106
+ "64": "green mamba",
107
+ "65": "sea snake",
108
+ "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
109
+ "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus",
110
+ "68": "sidewinder, horned rattlesnake, Crotalus cerastes",
111
+ "69": "trilobite",
112
+ "70": "harvestman, daddy longlegs, Phalangium opilio",
113
+ "71": "scorpion",
114
+ "72": "black and gold garden spider, Argiope aurantia",
115
+ "73": "barn spider, Araneus cavaticus",
116
+ "74": "garden spider, Aranea diademata",
117
+ "75": "black widow, Latrodectus mactans",
118
+ "76": "tarantula",
119
+ "77": "wolf spider, hunting spider",
120
+ "78": "tick",
121
+ "79": "centipede",
122
+ "80": "black grouse",
123
+ "81": "ptarmigan",
124
+ "82": "ruffed grouse, partridge, Bonasa umbellus",
125
+ "83": "prairie chicken, prairie grouse, prairie fowl",
126
+ "84": "peacock",
127
+ "85": "quail",
128
+ "86": "partridge",
129
+ "87": "African grey, African gray, Psittacus erithacus",
130
+ "88": "macaw",
131
+ "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
132
+ "90": "lorikeet",
133
+ "91": "coucal",
134
+ "92": "bee eater",
135
+ "93": "hornbill",
136
+ "94": "hummingbird",
137
+ "95": "jacamar",
138
+ "96": "toucan",
139
+ "97": "drake",
140
+ "98": "red-breasted merganser, Mergus serrator",
141
+ "99": "goose",
142
+ "100": "black swan, Cygnus atratus",
143
+ "101": "tusker",
144
+ "102": "echidna, spiny anteater, anteater",
145
+ "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
146
+ "104": "wallaby, brush kangaroo",
147
+ "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
148
+ "106": "wombat",
149
+ "107": "jellyfish",
150
+ "108": "sea anemone, anemone",
151
+ "109": "brain coral",
152
+ "110": "flatworm, platyhelminth",
153
+ "111": "nematode, nematode worm, roundworm",
154
+ "112": "conch",
155
+ "113": "snail",
156
+ "114": "slug",
157
+ "115": "sea slug, nudibranch",
158
+ "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore",
159
+ "117": "chambered nautilus, pearly nautilus, nautilus",
160
+ "118": "Dungeness crab, Cancer magister",
161
+ "119": "rock crab, Cancer irroratus",
162
+ "120": "fiddler crab",
163
+ "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
164
+ "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus",
165
+ "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
166
+ "124": "crayfish, crawfish, crawdad, crawdaddy",
167
+ "125": "hermit crab",
168
+ "126": "isopod",
169
+ "127": "white stork, Ciconia ciconia",
170
+ "128": "black stork, Ciconia nigra",
171
+ "129": "spoonbill",
172
+ "130": "flamingo",
173
+ "131": "little blue heron, Egretta caerulea",
174
+ "132": "American egret, great white heron, Egretta albus",
175
+ "133": "bittern",
176
+ "134": "crane",
177
+ "135": "limpkin, Aramus pictus",
178
+ "136": "European gallinule, Porphyrio porphyrio",
179
+ "137": "American coot, marsh hen, mud hen, water hen, Fulica americana",
180
+ "138": "bustard",
181
+ "139": "ruddy turnstone, Arenaria interpres",
182
+ "140": "red-backed sandpiper, dunlin, Erolia alpina",
183
+ "141": "redshank, Tringa totanus",
184
+ "142": "dowitcher",
185
+ "143": "oystercatcher, oyster catcher",
186
+ "144": "pelican",
187
+ "145": "king penguin, Aptenodytes patagonica",
188
+ "146": "albatross, mollymawk",
189
+ "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
190
+ "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
191
+ "149": "dugong, Dugong dugon",
192
+ "150": "sea lion",
193
+ "151": "Chihuahua",
194
+ "152": "Japanese spaniel",
195
+ "153": "Maltese dog, Maltese terrier, Maltese",
196
+ "154": "Pekinese, Pekingese, Peke",
197
+ "155": "Shih-Tzu",
198
+ "156": "Blenheim spaniel",
199
+ "157": "papillon",
200
+ "158": "toy terrier",
201
+ "159": "Rhodesian ridgeback",
202
+ "160": "Afghan hound, Afghan",
203
+ "161": "basset, basset hound",
204
+ "162": "beagle",
205
+ "163": "bloodhound, sleuthhound",
206
+ "164": "bluetick",
207
+ "165": "black-and-tan coonhound",
208
+ "166": "Walker hound, Walker foxhound",
209
+ "167": "English foxhound",
210
+ "168": "redbone",
211
+ "169": "borzoi, Russian wolfhound",
212
+ "170": "Irish wolfhound",
213
+ "171": "Italian greyhound",
214
+ "172": "whippet",
215
+ "173": "Ibizan hound, Ibizan Podenco",
216
+ "174": "Norwegian elkhound, elkhound",
217
+ "175": "otterhound, otter hound",
218
+ "176": "Saluki, gazelle hound",
219
+ "177": "Scottish deerhound, deerhound",
220
+ "178": "Weimaraner",
221
+ "179": "Staffordshire bullterrier, Staffordshire bull terrier",
222
+ "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
223
+ "181": "Bedlington terrier",
224
+ "182": "Border terrier",
225
+ "183": "Kerry blue terrier",
226
+ "184": "Irish terrier",
227
+ "185": "Norfolk terrier",
228
+ "186": "Norwich terrier",
229
+ "187": "Yorkshire terrier",
230
+ "188": "wire-haired fox terrier",
231
+ "189": "Lakeland terrier",
232
+ "190": "Sealyham terrier, Sealyham",
233
+ "191": "Airedale, Airedale terrier",
234
+ "192": "cairn, cairn terrier",
235
+ "193": "Australian terrier",
236
+ "194": "Dandie Dinmont, Dandie Dinmont terrier",
237
+ "195": "Boston bull, Boston terrier",
238
+ "196": "miniature schnauzer",
239
+ "197": "giant schnauzer",
240
+ "198": "standard schnauzer",
241
+ "199": "Scotch terrier, Scottish terrier, Scottie",
242
+ "200": "Tibetan terrier, chrysanthemum dog",
243
+ "201": "silky terrier, Sydney silky",
244
+ "202": "soft-coated wheaten terrier",
245
+ "203": "West Highland white terrier",
246
+ "204": "Lhasa, Lhasa apso",
247
+ "205": "flat-coated retriever",
248
+ "206": "curly-coated retriever",
249
+ "207": "golden retriever",
250
+ "208": "Labrador retriever",
251
+ "209": "Chesapeake Bay retriever",
252
+ "210": "German short-haired pointer",
253
+ "211": "vizsla, Hungarian pointer",
254
+ "212": "English setter",
255
+ "213": "Irish setter, red setter",
256
+ "214": "Gordon setter",
257
+ "215": "Brittany spaniel",
258
+ "216": "clumber, clumber spaniel",
259
+ "217": "English springer, English springer spaniel",
260
+ "218": "Welsh springer spaniel",
261
+ "219": "cocker spaniel, English cocker spaniel, cocker",
262
+ "220": "Sussex spaniel",
263
+ "221": "Irish water spaniel",
264
+ "222": "kuvasz",
265
+ "223": "schipperke",
266
+ "224": "groenendael",
267
+ "225": "malinois",
268
+ "226": "briard",
269
+ "227": "kelpie",
270
+ "228": "komondor",
271
+ "229": "Old English sheepdog, bobtail",
272
+ "230": "Shetland sheepdog, Shetland sheep dog, Shetland",
273
+ "231": "collie",
274
+ "232": "Border collie",
275
+ "233": "Bouvier des Flandres, Bouviers des Flandres",
276
+ "234": "Rottweiler",
277
+ "235": "German shepherd, German shepherd dog, German police dog, alsatian",
278
+ "236": "Doberman, Doberman pinscher",
279
+ "237": "miniature pinscher",
280
+ "238": "Greater Swiss Mountain dog",
281
+ "239": "Bernese mountain dog",
282
+ "240": "Appenzeller",
283
+ "241": "EntleBucher",
284
+ "242": "boxer",
285
+ "243": "bull mastiff",
286
+ "244": "Tibetan mastiff",
287
+ "245": "French bulldog",
288
+ "246": "Great Dane",
289
+ "247": "Saint Bernard, St Bernard",
290
+ "248": "Eskimo dog, husky",
291
+ "249": "malamute, malemute, Alaskan malamute",
292
+ "250": "Siberian husky",
293
+ "251": "dalmatian, coach dog, carriage dog",
294
+ "252": "affenpinscher, monkey pinscher, monkey dog",
295
+ "253": "basenji",
296
+ "254": "pug, pug-dog",
297
+ "255": "Leonberg",
298
+ "256": "Newfoundland, Newfoundland dog",
299
+ "257": "Great Pyrenees",
300
+ "258": "Samoyed, Samoyede",
301
+ "259": "Pomeranian",
302
+ "260": "chow, chow chow",
303
+ "261": "keeshond",
304
+ "262": "Brabancon griffon",
305
+ "263": "Pembroke, Pembroke Welsh corgi",
306
+ "264": "Cardigan, Cardigan Welsh corgi",
307
+ "265": "toy poodle",
308
+ "266": "miniature poodle",
309
+ "267": "standard poodle",
310
+ "268": "Mexican hairless",
311
+ "269": "timber wolf, grey wolf, gray wolf, Canis lupus",
312
+ "270": "white wolf, Arctic wolf, Canis lupus tundrarum",
313
+ "271": "red wolf, maned wolf, Canis rufus, Canis niger",
314
+ "272": "coyote, prairie wolf, brush wolf, Canis latrans",
315
+ "273": "dingo, warrigal, warragal, Canis dingo",
316
+ "274": "dhole, Cuon alpinus",
317
+ "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
318
+ "276": "hyena, hyaena",
319
+ "277": "red fox, Vulpes vulpes",
320
+ "278": "kit fox, Vulpes macrotis",
321
+ "279": "Arctic fox, white fox, Alopex lagopus",
322
+ "280": "grey fox, gray fox, Urocyon cinereoargenteus",
323
+ "281": "tabby, tabby cat",
324
+ "282": "tiger cat",
325
+ "283": "Persian cat",
326
+ "284": "Siamese cat, Siamese",
327
+ "285": "Egyptian cat",
328
+ "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
329
+ "287": "lynx, catamount",
330
+ "288": "leopard, Panthera pardus",
331
+ "289": "snow leopard, ounce, Panthera uncia",
332
+ "290": "jaguar, panther, Panthera onca, Felis onca",
333
+ "291": "lion, king of beasts, Panthera leo",
334
+ "292": "tiger, Panthera tigris",
335
+ "293": "cheetah, chetah, Acinonyx jubatus",
336
+ "294": "brown bear, bruin, Ursus arctos",
337
+ "295": "American black bear, black bear, Ursus americanus, Euarctos americanus",
338
+ "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
339
+ "297": "sloth bear, Melursus ursinus, Ursus ursinus",
340
+ "298": "mongoose",
341
+ "299": "meerkat, mierkat",
342
+ "300": "tiger beetle",
343
+ "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
344
+ "302": "ground beetle, carabid beetle",
345
+ "303": "long-horned beetle, longicorn, longicorn beetle",
346
+ "304": "leaf beetle, chrysomelid",
347
+ "305": "dung beetle",
348
+ "306": "rhinoceros beetle",
349
+ "307": "weevil",
350
+ "308": "fly",
351
+ "309": "bee",
352
+ "310": "ant, emmet, pismire",
353
+ "311": "grasshopper, hopper",
354
+ "312": "cricket",
355
+ "313": "walking stick, walkingstick, stick insect",
356
+ "314": "cockroach, roach",
357
+ "315": "mantis, mantid",
358
+ "316": "cicada, cicala",
359
+ "317": "leafhopper",
360
+ "318": "lacewing, lacewing fly",
361
+ "319": "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
362
+ "320": "damselfly",
363
+ "321": "admiral",
364
+ "322": "ringlet, ringlet butterfly",
365
+ "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
366
+ "324": "cabbage butterfly",
367
+ "325": "sulphur butterfly, sulfur butterfly",
368
+ "326": "lycaenid, lycaenid butterfly",
369
+ "327": "starfish, sea star",
370
+ "328": "sea urchin",
371
+ "329": "sea cucumber, holothurian",
372
+ "330": "wood rabbit, cottontail, cottontail rabbit",
373
+ "331": "hare",
374
+ "332": "Angora, Angora rabbit",
375
+ "333": "hamster",
376
+ "334": "porcupine, hedgehog",
377
+ "335": "fox squirrel, eastern fox squirrel, Sciurus niger",
378
+ "336": "marmot",
379
+ "337": "beaver",
380
+ "338": "guinea pig, Cavia cobaya",
381
+ "339": "sorrel",
382
+ "340": "zebra",
383
+ "341": "hog, pig, grunter, squealer, Sus scrofa",
384
+ "342": "wild boar, boar, Sus scrofa",
385
+ "343": "warthog",
386
+ "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius",
387
+ "345": "ox",
388
+ "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
389
+ "347": "bison",
390
+ "348": "ram, tup",
391
+ "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
392
+ "350": "ibex, Capra ibex",
393
+ "351": "hartebeest",
394
+ "352": "impala, Aepyceros melampus",
395
+ "353": "gazelle",
396
+ "354": "Arabian camel, dromedary, Camelus dromedarius",
397
+ "355": "llama",
398
+ "356": "weasel",
399
+ "357": "mink",
400
+ "358": "polecat, fitch, foulmart, foumart, Mustela putorius",
401
+ "359": "black-footed ferret, ferret, Mustela nigripes",
402
+ "360": "otter",
403
+ "361": "skunk, polecat, wood pussy",
404
+ "362": "badger",
405
+ "363": "armadillo",
406
+ "364": "three-toed sloth, ai, Bradypus tridactylus",
407
+ "365": "orangutan, orang, orangutang, Pongo pygmaeus",
408
+ "366": "gorilla, Gorilla gorilla",
409
+ "367": "chimpanzee, chimp, Pan troglodytes",
410
+ "368": "gibbon, Hylobates lar",
411
+ "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus",
412
+ "370": "guenon, guenon monkey",
413
+ "371": "patas, hussar monkey, Erythrocebus patas",
414
+ "372": "baboon",
415
+ "373": "macaque",
416
+ "374": "langur",
417
+ "375": "colobus, colobus monkey",
418
+ "376": "proboscis monkey, Nasalis larvatus",
419
+ "377": "marmoset",
420
+ "378": "capuchin, ringtail, Cebus capucinus",
421
+ "379": "howler monkey, howler",
422
+ "380": "titi, titi monkey",
423
+ "381": "spider monkey, Ateles geoffroyi",
424
+ "382": "squirrel monkey, Saimiri sciureus",
425
+ "383": "Madagascar cat, ring-tailed lemur, Lemur catta",
426
+ "384": "indri, indris, Indri indri, Indri brevicaudatus",
427
+ "385": "Indian elephant, Elephas maximus",
428
+ "386": "African elephant, Loxodonta africana",
429
+ "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
430
+ "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
431
+ "389": "barracouta, snoek",
432
+ "390": "eel",
433
+ "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
434
+ "392": "rock beauty, Holocanthus tricolor",
435
+ "393": "anemone fish",
436
+ "394": "sturgeon",
437
+ "395": "gar, garfish, garpike, billfish, Lepisosteus osseus",
438
+ "396": "lionfish",
439
+ "397": "puffer, pufferfish, blowfish, globefish",
440
+ "398": "abacus",
441
+ "399": "abaya",
442
+ "400": "academic gown, academic robe, judge's robe",
443
+ "401": "accordion, piano accordion, squeeze box",
444
+ "402": "acoustic guitar",
445
+ "403": "aircraft carrier, carrier, flattop, attack aircraft carrier",
446
+ "404": "airliner",
447
+ "405": "airship, dirigible",
448
+ "406": "altar",
449
+ "407": "ambulance",
450
+ "408": "amphibian, amphibious vehicle",
451
+ "409": "analog clock",
452
+ "410": "apiary, bee house",
453
+ "411": "apron",
454
+ "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
455
+ "413": "assault rifle, assault gun",
456
+ "414": "backpack, back pack, knapsack, packsack, rucksack, haversack",
457
+ "415": "bakery, bakeshop, bakehouse",
458
+ "416": "balance beam, beam",
459
+ "417": "balloon",
460
+ "418": "ballpoint, ballpoint pen, ballpen, Biro",
461
+ "419": "Band Aid",
462
+ "420": "banjo",
463
+ "421": "bannister, banister, balustrade, balusters, handrail",
464
+ "422": "barbell",
465
+ "423": "barber chair",
466
+ "424": "barbershop",
467
+ "425": "barn",
468
+ "426": "barometer",
469
+ "427": "barrel, cask",
470
+ "428": "barrow, garden cart, lawn cart, wheelbarrow",
471
+ "429": "baseball",
472
+ "430": "basketball",
473
+ "431": "bassinet",
474
+ "432": "bassoon",
475
+ "433": "bathing cap, swimming cap",
476
+ "434": "bath towel",
477
+ "435": "bathtub, bathing tub, bath, tub",
478
+ "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
479
+ "437": "beacon, lighthouse, beacon light, pharos",
480
+ "438": "beaker",
481
+ "439": "bearskin, busby, shako",
482
+ "440": "beer bottle",
483
+ "441": "beer glass",
484
+ "442": "bell cote, bell cot",
485
+ "443": "bib",
486
+ "444": "bicycle-built-for-two, tandem bicycle, tandem",
487
+ "445": "bikini, two-piece",
488
+ "446": "binder, ring-binder",
489
+ "447": "binoculars, field glasses, opera glasses",
490
+ "448": "birdhouse",
491
+ "449": "boathouse",
492
+ "450": "bobsled, bobsleigh, bob",
493
+ "451": "bolo tie, bolo, bola tie, bola",
494
+ "452": "bonnet, poke bonnet",
495
+ "453": "bookcase",
496
+ "454": "bookshop, bookstore, bookstall",
497
+ "455": "bottlecap",
498
+ "456": "bow",
499
+ "457": "bow tie, bow-tie, bowtie",
500
+ "458": "brass, memorial tablet, plaque",
501
+ "459": "brassiere, bra, bandeau",
502
+ "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
503
+ "461": "breastplate, aegis, egis",
504
+ "462": "broom",
505
+ "463": "bucket, pail",
506
+ "464": "buckle",
507
+ "465": "bulletproof vest",
508
+ "466": "bullet train, bullet",
509
+ "467": "butcher shop, meat market",
510
+ "468": "cab, hack, taxi, taxicab",
511
+ "469": "caldron, cauldron",
512
+ "470": "candle, taper, wax light",
513
+ "471": "cannon",
514
+ "472": "canoe",
515
+ "473": "can opener, tin opener",
516
+ "474": "cardigan",
517
+ "475": "car mirror",
518
+ "476": "carousel, carrousel, merry-go-round, roundabout, whirligig",
519
+ "477": "carpenter's kit, tool kit",
520
+ "478": "carton",
521
+ "479": "car wheel",
522
+ "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
523
+ "481": "cassette",
524
+ "482": "cassette player",
525
+ "483": "castle",
526
+ "484": "catamaran",
527
+ "485": "CD player",
528
+ "486": "cello, violoncello",
529
+ "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
530
+ "488": "chain",
531
+ "489": "chainlink fence",
532
+ "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
533
+ "491": "chain saw, chainsaw",
534
+ "492": "chest",
535
+ "493": "chiffonier, commode",
536
+ "494": "chime, bell, gong",
537
+ "495": "china cabinet, china closet",
538
+ "496": "Christmas stocking",
539
+ "497": "church, church building",
540
+ "498": "cinema, movie theater, movie theatre, movie house, picture palace",
541
+ "499": "cleaver, meat cleaver, chopper",
542
+ "500": "cliff dwelling",
543
+ "501": "cloak",
544
+ "502": "clog, geta, patten, sabot",
545
+ "503": "cocktail shaker",
546
+ "504": "coffee mug",
547
+ "505": "coffeepot",
548
+ "506": "coil, spiral, volute, whorl, helix",
549
+ "507": "combination lock",
550
+ "508": "computer keyboard, keypad",
551
+ "509": "confectionery, confectionary, candy store",
552
+ "510": "container ship, containership, container vessel",
553
+ "511": "convertible",
554
+ "512": "corkscrew, bottle screw",
555
+ "513": "cornet, horn, trumpet, trump",
556
+ "514": "cowboy boot",
557
+ "515": "cowboy hat, ten-gallon hat",
558
+ "516": "cradle",
559
+ "517": "crane",
560
+ "518": "crash helmet",
561
+ "519": "crate",
562
+ "520": "crib, cot",
563
+ "521": "Crock Pot",
564
+ "522": "croquet ball",
565
+ "523": "crutch",
566
+ "524": "cuirass",
567
+ "525": "dam, dike, dyke",
568
+ "526": "desk",
569
+ "527": "desktop computer",
570
+ "528": "dial telephone, dial phone",
571
+ "529": "diaper, nappy, napkin",
572
+ "530": "digital clock",
573
+ "531": "digital watch",
574
+ "532": "dining table, board",
575
+ "533": "dishrag, dishcloth",
576
+ "534": "dishwasher, dish washer, dishwashing machine",
577
+ "535": "disk brake, disc brake",
578
+ "536": "dock, dockage, docking facility",
579
+ "537": "dogsled, dog sled, dog sleigh",
580
+ "538": "dome",
581
+ "539": "doormat, welcome mat",
582
+ "540": "drilling platform, offshore rig",
583
+ "541": "drum, membranophone, tympan",
584
+ "542": "drumstick",
585
+ "543": "dumbbell",
586
+ "544": "Dutch oven",
587
+ "545": "electric fan, blower",
588
+ "546": "electric guitar",
589
+ "547": "electric locomotive",
590
+ "548": "entertainment center",
591
+ "549": "envelope",
592
+ "550": "espresso maker",
593
+ "551": "face powder",
594
+ "552": "feather boa, boa",
595
+ "553": "file, file cabinet, filing cabinet",
596
+ "554": "fireboat",
597
+ "555": "fire engine, fire truck",
598
+ "556": "fire screen, fireguard",
599
+ "557": "flagpole, flagstaff",
600
+ "558": "flute, transverse flute",
601
+ "559": "folding chair",
602
+ "560": "football helmet",
603
+ "561": "forklift",
604
+ "562": "fountain",
605
+ "563": "fountain pen",
606
+ "564": "four-poster",
607
+ "565": "freight car",
608
+ "566": "French horn, horn",
609
+ "567": "frying pan, frypan, skillet",
610
+ "568": "fur coat",
611
+ "569": "garbage truck, dustcart",
612
+ "570": "gasmask, respirator, gas helmet",
613
+ "571": "gas pump, gasoline pump, petrol pump, island dispenser",
614
+ "572": "goblet",
615
+ "573": "go-kart",
616
+ "574": "golf ball",
617
+ "575": "golfcart, golf cart",
618
+ "576": "gondola",
619
+ "577": "gong, tam-tam",
620
+ "578": "gown",
621
+ "579": "grand piano, grand",
622
+ "580": "greenhouse, nursery, glasshouse",
623
+ "581": "grille, radiator grille",
624
+ "582": "grocery store, grocery, food market, market",
625
+ "583": "guillotine",
626
+ "584": "hair slide",
627
+ "585": "hair spray",
628
+ "586": "half track",
629
+ "587": "hammer",
630
+ "588": "hamper",
631
+ "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
632
+ "590": "hand-held computer, hand-held microcomputer",
633
+ "591": "handkerchief, hankie, hanky, hankey",
634
+ "592": "hard disc, hard disk, fixed disk",
635
+ "593": "harmonica, mouth organ, harp, mouth harp",
636
+ "594": "harp",
637
+ "595": "harvester, reaper",
638
+ "596": "hatchet",
639
+ "597": "holster",
640
+ "598": "home theater, home theatre",
641
+ "599": "honeycomb",
642
+ "600": "hook, claw",
643
+ "601": "hoopskirt, crinoline",
644
+ "602": "horizontal bar, high bar",
645
+ "603": "horse cart, horse-cart",
646
+ "604": "hourglass",
647
+ "605": "iPod",
648
+ "606": "iron, smoothing iron",
649
+ "607": "jack-o'-lantern",
650
+ "608": "jean, blue jean, denim",
651
+ "609": "jeep, landrover",
652
+ "610": "jersey, T-shirt, tee shirt",
653
+ "611": "jigsaw puzzle",
654
+ "612": "jinrikisha, ricksha, rickshaw",
655
+ "613": "joystick",
656
+ "614": "kimono",
657
+ "615": "knee pad",
658
+ "616": "knot",
659
+ "617": "lab coat, laboratory coat",
660
+ "618": "ladle",
661
+ "619": "lampshade, lamp shade",
662
+ "620": "laptop, laptop computer",
663
+ "621": "lawn mower, mower",
664
+ "622": "lens cap, lens cover",
665
+ "623": "letter opener, paper knife, paperknife",
666
+ "624": "library",
667
+ "625": "lifeboat",
668
+ "626": "lighter, light, igniter, ignitor",
669
+ "627": "limousine, limo",
670
+ "628": "liner, ocean liner",
671
+ "629": "lipstick, lip rouge",
672
+ "630": "Loafer",
673
+ "631": "lotion",
674
+ "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
675
+ "633": "loupe, jeweler's loupe",
676
+ "634": "lumbermill, sawmill",
677
+ "635": "magnetic compass",
678
+ "636": "mailbag, postbag",
679
+ "637": "mailbox, letter box",
680
+ "638": "maillot",
681
+ "639": "maillot, tank suit",
682
+ "640": "manhole cover",
683
+ "641": "maraca",
684
+ "642": "marimba, xylophone",
685
+ "643": "mask",
686
+ "644": "matchstick",
687
+ "645": "maypole",
688
+ "646": "maze, labyrinth",
689
+ "647": "measuring cup",
690
+ "648": "medicine chest, medicine cabinet",
691
+ "649": "megalith, megalithic structure",
692
+ "650": "microphone, mike",
693
+ "651": "microwave, microwave oven",
694
+ "652": "military uniform",
695
+ "653": "milk can",
696
+ "654": "minibus",
697
+ "655": "miniskirt, mini",
698
+ "656": "minivan",
699
+ "657": "missile",
700
+ "658": "mitten",
701
+ "659": "mixing bowl",
702
+ "660": "mobile home, manufactured home",
703
+ "661": "Model T",
704
+ "662": "modem",
705
+ "663": "monastery",
706
+ "664": "monitor",
707
+ "665": "moped",
708
+ "666": "mortar",
709
+ "667": "mortarboard",
710
+ "668": "mosque",
711
+ "669": "mosquito net",
712
+ "670": "motor scooter, scooter",
713
+ "671": "mountain bike, all-terrain bike, off-roader",
714
+ "672": "mountain tent",
715
+ "673": "mouse, computer mouse",
716
+ "674": "mousetrap",
717
+ "675": "moving van",
718
+ "676": "muzzle",
719
+ "677": "nail",
720
+ "678": "neck brace",
721
+ "679": "necklace",
722
+ "680": "nipple",
723
+ "681": "notebook, notebook computer",
724
+ "682": "obelisk",
725
+ "683": "oboe, hautboy, hautbois",
726
+ "684": "ocarina, sweet potato",
727
+ "685": "odometer, hodometer, mileometer, milometer",
728
+ "686": "oil filter",
729
+ "687": "organ, pipe organ",
730
+ "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO",
731
+ "689": "overskirt",
732
+ "690": "oxcart",
733
+ "691": "oxygen mask",
734
+ "692": "packet",
735
+ "693": "paddle, boat paddle",
736
+ "694": "paddlewheel, paddle wheel",
737
+ "695": "padlock",
738
+ "696": "paintbrush",
739
+ "697": "pajama, pyjama, pj's, jammies",
740
+ "698": "palace",
741
+ "699": "panpipe, pandean pipe, syrinx",
742
+ "700": "paper towel",
743
+ "701": "parachute, chute",
744
+ "702": "parallel bars, bars",
745
+ "703": "park bench",
746
+ "704": "parking meter",
747
+ "705": "passenger car, coach, carriage",
748
+ "706": "patio, terrace",
749
+ "707": "pay-phone, pay-station",
750
+ "708": "pedestal, plinth, footstall",
751
+ "709": "pencil box, pencil case",
752
+ "710": "pencil sharpener",
753
+ "711": "perfume, essence",
754
+ "712": "Petri dish",
755
+ "713": "photocopier",
756
+ "714": "pick, plectrum, plectron",
757
+ "715": "pickelhaube",
758
+ "716": "picket fence, paling",
759
+ "717": "pickup, pickup truck",
760
+ "718": "pier",
761
+ "719": "piggy bank, penny bank",
762
+ "720": "pill bottle",
763
+ "721": "pillow",
764
+ "722": "ping-pong ball",
765
+ "723": "pinwheel",
766
+ "724": "pirate, pirate ship",
767
+ "725": "pitcher, ewer",
768
+ "726": "plane, carpenter's plane, woodworking plane",
769
+ "727": "planetarium",
770
+ "728": "plastic bag",
771
+ "729": "plate rack",
772
+ "730": "plow, plough",
773
+ "731": "plunger, plumber's helper",
774
+ "732": "Polaroid camera, Polaroid Land camera",
775
+ "733": "pole",
776
+ "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
777
+ "735": "poncho",
778
+ "736": "pool table, billiard table, snooker table",
779
+ "737": "pop bottle, soda bottle",
780
+ "738": "pot, flowerpot",
781
+ "739": "potter's wheel",
782
+ "740": "power drill",
783
+ "741": "prayer rug, prayer mat",
784
+ "742": "printer",
785
+ "743": "prison, prison house",
786
+ "744": "projectile, missile",
787
+ "745": "projector",
788
+ "746": "puck, hockey puck",
789
+ "747": "punching bag, punch bag, punching ball, punchball",
790
+ "748": "purse",
791
+ "749": "quill, quill pen",
792
+ "750": "quilt, comforter, comfort, puff",
793
+ "751": "racer, race car, racing car",
794
+ "752": "racket, racquet",
795
+ "753": "radiator",
796
+ "754": "radio, wireless",
797
+ "755": "radio telescope, radio reflector",
798
+ "756": "rain barrel",
799
+ "757": "recreational vehicle, RV, R.V.",
800
+ "758": "reel",
801
+ "759": "reflex camera",
802
+ "760": "refrigerator, icebox",
803
+ "761": "remote control, remote",
804
+ "762": "restaurant, eating house, eating place, eatery",
805
+ "763": "revolver, six-gun, six-shooter",
806
+ "764": "rifle",
807
+ "765": "rocking chair, rocker",
808
+ "766": "rotisserie",
809
+ "767": "rubber eraser, rubber, pencil eraser",
810
+ "768": "rugby ball",
811
+ "769": "rule, ruler",
812
+ "770": "running shoe",
813
+ "771": "safe",
814
+ "772": "safety pin",
815
+ "773": "saltshaker, salt shaker",
816
+ "774": "sandal",
817
+ "775": "sarong",
818
+ "776": "sax, saxophone",
819
+ "777": "scabbard",
820
+ "778": "scale, weighing machine",
821
+ "779": "school bus",
822
+ "780": "schooner",
823
+ "781": "scoreboard",
824
+ "782": "screen, CRT screen",
825
+ "783": "screw",
826
+ "784": "screwdriver",
827
+ "785": "seat belt, seatbelt",
828
+ "786": "sewing machine",
829
+ "787": "shield, buckler",
830
+ "788": "shoe shop, shoe-shop, shoe store",
831
+ "789": "shoji",
832
+ "790": "shopping basket",
833
+ "791": "shopping cart",
834
+ "792": "shovel",
835
+ "793": "shower cap",
836
+ "794": "shower curtain",
837
+ "795": "ski",
838
+ "796": "ski mask",
839
+ "797": "sleeping bag",
840
+ "798": "slide rule, slipstick",
841
+ "799": "sliding door",
842
+ "800": "slot, one-armed bandit",
843
+ "801": "snorkel",
844
+ "802": "snowmobile",
845
+ "803": "snowplow, snowplough",
846
+ "804": "soap dispenser",
847
+ "805": "soccer ball",
848
+ "806": "sock",
849
+ "807": "solar dish, solar collector, solar furnace",
850
+ "808": "sombrero",
851
+ "809": "soup bowl",
852
+ "810": "space bar",
853
+ "811": "space heater",
854
+ "812": "space shuttle",
855
+ "813": "spatula",
856
+ "814": "speedboat",
857
+ "815": "spider web, spider's web",
858
+ "816": "spindle",
859
+ "817": "sports car, sport car",
860
+ "818": "spotlight, spot",
861
+ "819": "stage",
862
+ "820": "steam locomotive",
863
+ "821": "steel arch bridge",
864
+ "822": "steel drum",
865
+ "823": "stethoscope",
866
+ "824": "stole",
867
+ "825": "stone wall",
868
+ "826": "stopwatch, stop watch",
869
+ "827": "stove",
870
+ "828": "strainer",
871
+ "829": "streetcar, tram, tramcar, trolley, trolley car",
872
+ "830": "stretcher",
873
+ "831": "studio couch, day bed",
874
+ "832": "stupa, tope",
875
+ "833": "submarine, pigboat, sub, U-boat",
876
+ "834": "suit, suit of clothes",
877
+ "835": "sundial",
878
+ "836": "sunglass",
879
+ "837": "sunglasses, dark glasses, shades",
880
+ "838": "sunscreen, sunblock, sun blocker",
881
+ "839": "suspension bridge",
882
+ "840": "swab, swob, mop",
883
+ "841": "sweatshirt",
884
+ "842": "swimming trunks, bathing trunks",
885
+ "843": "swing",
886
+ "844": "switch, electric switch, electrical switch",
887
+ "845": "syringe",
888
+ "846": "table lamp",
889
+ "847": "tank, army tank, armored combat vehicle, armoured combat vehicle",
890
+ "848": "tape player",
891
+ "849": "teapot",
892
+ "850": "teddy, teddy bear",
893
+ "851": "television, television system",
894
+ "852": "tennis ball",
895
+ "853": "thatch, thatched roof",
896
+ "854": "theater curtain, theatre curtain",
897
+ "855": "thimble",
898
+ "856": "thresher, thrasher, threshing machine",
899
+ "857": "throne",
900
+ "858": "tile roof",
901
+ "859": "toaster",
902
+ "860": "tobacco shop, tobacconist shop, tobacconist",
903
+ "861": "toilet seat",
904
+ "862": "torch",
905
+ "863": "totem pole",
906
+ "864": "tow truck, tow car, wrecker",
907
+ "865": "toyshop",
908
+ "866": "tractor",
909
+ "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
910
+ "868": "tray",
911
+ "869": "trench coat",
912
+ "870": "tricycle, trike, velocipede",
913
+ "871": "trimaran",
914
+ "872": "tripod",
915
+ "873": "triumphal arch",
916
+ "874": "trolleybus, trolley coach, trackless trolley",
917
+ "875": "trombone",
918
+ "876": "tub, vat",
919
+ "877": "turnstile",
920
+ "878": "typewriter keyboard",
921
+ "879": "umbrella",
922
+ "880": "unicycle, monocycle",
923
+ "881": "upright, upright piano",
924
+ "882": "vacuum, vacuum cleaner",
925
+ "883": "vase",
926
+ "884": "vault",
927
+ "885": "velvet",
928
+ "886": "vending machine",
929
+ "887": "vestment",
930
+ "888": "viaduct",
931
+ "889": "violin, fiddle",
932
+ "890": "volleyball",
933
+ "891": "waffle iron",
934
+ "892": "wall clock",
935
+ "893": "wallet, billfold, notecase, pocketbook",
936
+ "894": "wardrobe, closet, press",
937
+ "895": "warplane, military plane",
938
+ "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
939
+ "897": "washer, automatic washer, washing machine",
940
+ "898": "water bottle",
941
+ "899": "water jug",
942
+ "900": "water tower",
943
+ "901": "whiskey jug",
944
+ "902": "whistle",
945
+ "903": "wig",
946
+ "904": "window screen",
947
+ "905": "window shade",
948
+ "906": "Windsor tie",
949
+ "907": "wine bottle",
950
+ "908": "wing",
951
+ "909": "wok",
952
+ "910": "wooden spoon",
953
+ "911": "wool, woolen, woollen",
954
+ "912": "worm fence, snake fence, snake-rail fence, Virginia fence",
955
+ "913": "wreck",
956
+ "914": "yawl",
957
+ "915": "yurt",
958
+ "916": "web site, website, internet site, site",
959
+ "917": "comic book",
960
+ "918": "crossword puzzle, crossword",
961
+ "919": "street sign",
962
+ "920": "traffic light, traffic signal, stoplight",
963
+ "921": "book jacket, dust cover, dust jacket, dust wrapper",
964
+ "922": "menu",
965
+ "923": "plate",
966
+ "924": "guacamole",
967
+ "925": "consomme",
968
+ "926": "hot pot, hotpot",
969
+ "927": "trifle",
970
+ "928": "ice cream, icecream",
971
+ "929": "ice lolly, lolly, lollipop, popsicle",
972
+ "930": "French loaf",
973
+ "931": "bagel, beigel",
974
+ "932": "pretzel",
975
+ "933": "cheeseburger",
976
+ "934": "hotdog, hot dog, red hot",
977
+ "935": "mashed potato",
978
+ "936": "head cabbage",
979
+ "937": "broccoli",
980
+ "938": "cauliflower",
981
+ "939": "zucchini, courgette",
982
+ "940": "spaghetti squash",
983
+ "941": "acorn squash",
984
+ "942": "butternut squash",
985
+ "943": "cucumber, cuke",
986
+ "944": "artichoke, globe artichoke",
987
+ "945": "bell pepper",
988
+ "946": "cardoon",
989
+ "947": "mushroom",
990
+ "948": "Granny Smith",
991
+ "949": "strawberry",
992
+ "950": "orange",
993
+ "951": "lemon",
994
+ "952": "fig",
995
+ "953": "pineapple, ananas",
996
+ "954": "banana",
997
+ "955": "jackfruit, jak, jack",
998
+ "956": "custard apple",
999
+ "957": "pomegranate",
1000
+ "958": "hay",
1001
+ "959": "carbonara",
1002
+ "960": "chocolate sauce, chocolate syrup",
1003
+ "961": "dough",
1004
+ "962": "meat loaf, meatloaf",
1005
+ "963": "pizza, pizza pie",
1006
+ "964": "potpie",
1007
+ "965": "burrito",
1008
+ "966": "red wine",
1009
+ "967": "espresso",
1010
+ "968": "cup",
1011
+ "969": "eggnog",
1012
+ "970": "alp",
1013
+ "971": "bubble",
1014
+ "972": "cliff, drop, drop-off",
1015
+ "973": "coral reef",
1016
+ "974": "geyser",
1017
+ "975": "lakeside, lakeshore",
1018
+ "976": "promontory, headland, head, foreland",
1019
+ "977": "sandbar, sand bar",
1020
+ "978": "seashore, coast, seacoast, sea-coast",
1021
+ "979": "valley, vale",
1022
+ "980": "volcano",
1023
+ "981": "ballplayer, baseball player",
1024
+ "982": "groom, bridegroom",
1025
+ "983": "scuba diver",
1026
+ "984": "rapeseed",
1027
+ "985": "daisy",
1028
+ "986": "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
1029
+ "987": "corn",
1030
+ "988": "acorn",
1031
+ "989": "hip, rose hip, rosehip",
1032
+ "990": "buckeye, horse chestnut, conker",
1033
+ "991": "coral fungus",
1034
+ "992": "agaric",
1035
+ "993": "gyromitra",
1036
+ "994": "stinkhorn, carrion fungus",
1037
+ "995": "earthstar",
1038
+ "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
1039
+ "997": "bolete",
1040
+ "998": "ear, spike, capitulum",
1041
+ "999": "toilet tissue, toilet paper, bathroom tissue"
1042
+ },
1043
+ "image_size": 224,
1044
+ "initializer_range": 0.02,
1045
+ "is_decoder": false,
1046
+ "is_encoder_decoder": false,
1047
+ "label2id": {
1048
+ "Afghan hound, Afghan": 160,
1049
+ "African chameleon, Chamaeleo chamaeleon": 47,
1050
+ "African crocodile, Nile crocodile, Crocodylus niloticus": 49,
1051
+ "African elephant, Loxodonta africana": 386,
1052
+ "African grey, African gray, Psittacus erithacus": 87,
1053
+ "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus": 275,
1054
+ "Airedale, Airedale terrier": 191,
1055
+ "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier": 180,
1056
+ "American alligator, Alligator mississipiensis": 50,
1057
+ "American black bear, black bear, Ursus americanus, Euarctos americanus": 295,
1058
+ "American chameleon, anole, Anolis carolinensis": 40,
1059
+ "American coot, marsh hen, mud hen, water hen, Fulica americana": 137,
1060
+ "American egret, great white heron, Egretta albus": 132,
1061
+ "American lobster, Northern lobster, Maine lobster, Homarus americanus": 122,
1062
+ "Angora, Angora rabbit": 332,
1063
+ "Appenzeller": 240,
1064
+ "Arabian camel, dromedary, Camelus dromedarius": 354,
1065
+ "Arctic fox, white fox, Alopex lagopus": 279,
1066
+ "Australian terrier": 193,
1067
+ "Band Aid": 419,
1068
+ "Bedlington terrier": 181,
1069
+ "Bernese mountain dog": 239,
1070
+ "Blenheim spaniel": 156,
1071
+ "Border collie": 232,
1072
+ "Border terrier": 182,
1073
+ "Boston bull, Boston terrier": 195,
1074
+ "Bouvier des Flandres, Bouviers des Flandres": 233,
1075
+ "Brabancon griffon": 262,
1076
+ "Brittany spaniel": 215,
1077
+ "CD player": 485,
1078
+ "Cardigan, Cardigan Welsh corgi": 264,
1079
+ "Chesapeake Bay retriever": 209,
1080
+ "Chihuahua": 151,
1081
+ "Christmas stocking": 496,
1082
+ "Crock Pot": 521,
1083
+ "Dandie Dinmont, Dandie Dinmont terrier": 194,
1084
+ "Doberman, Doberman pinscher": 236,
1085
+ "Dungeness crab, Cancer magister": 118,
1086
+ "Dutch oven": 544,
1087
+ "Egyptian cat": 285,
1088
+ "English foxhound": 167,
1089
+ "English setter": 212,
1090
+ "English springer, English springer spaniel": 217,
1091
+ "EntleBucher": 241,
1092
+ "Eskimo dog, husky": 248,
1093
+ "European fire salamander, Salamandra salamandra": 25,
1094
+ "European gallinule, Porphyrio porphyrio": 136,
1095
+ "French bulldog": 245,
1096
+ "French horn, horn": 566,
1097
+ "French loaf": 930,
1098
+ "German shepherd, German shepherd dog, German police dog, alsatian": 235,
1099
+ "German short-haired pointer": 210,
1100
+ "Gila monster, Heloderma suspectum": 45,
1101
+ "Gordon setter": 214,
1102
+ "Granny Smith": 948,
1103
+ "Great Dane": 246,
1104
+ "Great Pyrenees": 257,
1105
+ "Greater Swiss Mountain dog": 238,
1106
+ "Ibizan hound, Ibizan Podenco": 173,
1107
+ "Indian cobra, Naja naja": 63,
1108
+ "Indian elephant, Elephas maximus": 385,
1109
+ "Irish setter, red setter": 213,
1110
+ "Irish terrier": 184,
1111
+ "Irish water spaniel": 221,
1112
+ "Irish wolfhound": 170,
1113
+ "Italian greyhound": 171,
1114
+ "Japanese spaniel": 152,
1115
+ "Kerry blue terrier": 183,
1116
+ "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis": 48,
1117
+ "Labrador retriever": 208,
1118
+ "Lakeland terrier": 189,
1119
+ "Leonberg": 255,
1120
+ "Lhasa, Lhasa apso": 204,
1121
+ "Loafer": 630,
1122
+ "Madagascar cat, ring-tailed lemur, Lemur catta": 383,
1123
+ "Maltese dog, Maltese terrier, Maltese": 153,
1124
+ "Mexican hairless": 268,
1125
+ "Model T": 661,
1126
+ "Newfoundland, Newfoundland dog": 256,
1127
+ "Norfolk terrier": 185,
1128
+ "Norwegian elkhound, elkhound": 174,
1129
+ "Norwich terrier": 186,
1130
+ "Old English sheepdog, bobtail": 229,
1131
+ "Pekinese, Pekingese, Peke": 154,
1132
+ "Pembroke, Pembroke Welsh corgi": 263,
1133
+ "Persian cat": 283,
1134
+ "Petri dish": 712,
1135
+ "Polaroid camera, Polaroid Land camera": 732,
1136
+ "Pomeranian": 259,
1137
+ "Rhodesian ridgeback": 159,
1138
+ "Rottweiler": 234,
1139
+ "Saint Bernard, St Bernard": 247,
1140
+ "Saluki, gazelle hound": 176,
1141
+ "Samoyed, Samoyede": 258,
1142
+ "Scotch terrier, Scottish terrier, Scottie": 199,
1143
+ "Scottish deerhound, deerhound": 177,
1144
+ "Sealyham terrier, Sealyham": 190,
1145
+ "Shetland sheepdog, Shetland sheep dog, Shetland": 230,
1146
+ "Shih-Tzu": 155,
1147
+ "Siamese cat, Siamese": 284,
1148
+ "Siberian husky": 250,
1149
+ "Staffordshire bullterrier, Staffordshire bull terrier": 179,
1150
+ "Sussex spaniel": 220,
1151
+ "Tibetan mastiff": 244,
1152
+ "Tibetan terrier, chrysanthemum dog": 200,
1153
+ "Walker hound, Walker foxhound": 166,
1154
+ "Weimaraner": 178,
1155
+ "Welsh springer spaniel": 218,
1156
+ "West Highland white terrier": 203,
1157
+ "Windsor tie": 906,
1158
+ "Yorkshire terrier": 187,
1159
+ "abacus": 398,
1160
+ "abaya": 399,
1161
+ "academic gown, academic robe, judge's robe": 400,
1162
+ "accordion, piano accordion, squeeze box": 401,
1163
+ "acorn": 988,
1164
+ "acorn squash": 941,
1165
+ "acoustic guitar": 402,
1166
+ "admiral": 321,
1167
+ "affenpinscher, monkey pinscher, monkey dog": 252,
1168
+ "agama": 42,
1169
+ "agaric": 992,
1170
+ "aircraft carrier, carrier, flattop, attack aircraft carrier": 403,
1171
+ "airliner": 404,
1172
+ "airship, dirigible": 405,
1173
+ "albatross, mollymawk": 146,
1174
+ "alligator lizard": 44,
1175
+ "alp": 970,
1176
+ "altar": 406,
1177
+ "ambulance": 407,
1178
+ "amphibian, amphibious vehicle": 408,
1179
+ "analog clock": 409,
1180
+ "anemone fish": 393,
1181
+ "ant, emmet, pismire": 310,
1182
+ "apiary, bee house": 410,
1183
+ "apron": 411,
1184
+ "armadillo": 363,
1185
+ "artichoke, globe artichoke": 944,
1186
+ "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin": 412,
1187
+ "assault rifle, assault gun": 413,
1188
+ "axolotl, mud puppy, Ambystoma mexicanum": 29,
1189
+ "baboon": 372,
1190
+ "backpack, back pack, knapsack, packsack, rucksack, haversack": 414,
1191
+ "badger": 362,
1192
+ "bagel, beigel": 931,
1193
+ "bakery, bakeshop, bakehouse": 415,
1194
+ "balance beam, beam": 416,
1195
+ "bald eagle, American eagle, Haliaeetus leucocephalus": 22,
1196
+ "balloon": 417,
1197
+ "ballplayer, baseball player": 981,
1198
+ "ballpoint, ballpoint pen, ballpen, Biro": 418,
1199
+ "banana": 954,
1200
+ "banded gecko": 38,
1201
+ "banjo": 420,
1202
+ "bannister, banister, balustrade, balusters, handrail": 421,
1203
+ "barbell": 422,
1204
+ "barber chair": 423,
1205
+ "barbershop": 424,
1206
+ "barn": 425,
1207
+ "barn spider, Araneus cavaticus": 73,
1208
+ "barometer": 426,
1209
+ "barracouta, snoek": 389,
1210
+ "barrel, cask": 427,
1211
+ "barrow, garden cart, lawn cart, wheelbarrow": 428,
1212
+ "baseball": 429,
1213
+ "basenji": 253,
1214
+ "basketball": 430,
1215
+ "basset, basset hound": 161,
1216
+ "bassinet": 431,
1217
+ "bassoon": 432,
1218
+ "bath towel": 434,
1219
+ "bathing cap, swimming cap": 433,
1220
+ "bathtub, bathing tub, bath, tub": 435,
1221
+ "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon": 436,
1222
+ "beacon, lighthouse, beacon light, pharos": 437,
1223
+ "beagle": 162,
1224
+ "beaker": 438,
1225
+ "bearskin, busby, shako": 439,
1226
+ "beaver": 337,
1227
+ "bee": 309,
1228
+ "bee eater": 92,
1229
+ "beer bottle": 440,
1230
+ "beer glass": 441,
1231
+ "bell cote, bell cot": 442,
1232
+ "bell pepper": 945,
1233
+ "bib": 443,
1234
+ "bicycle-built-for-two, tandem bicycle, tandem": 444,
1235
+ "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis": 349,
1236
+ "bikini, two-piece": 445,
1237
+ "binder, ring-binder": 446,
1238
+ "binoculars, field glasses, opera glasses": 447,
1239
+ "birdhouse": 448,
1240
+ "bison": 347,
1241
+ "bittern": 133,
1242
+ "black and gold garden spider, Argiope aurantia": 72,
1243
+ "black grouse": 80,
1244
+ "black stork, Ciconia nigra": 128,
1245
+ "black swan, Cygnus atratus": 100,
1246
+ "black widow, Latrodectus mactans": 75,
1247
+ "black-and-tan coonhound": 165,
1248
+ "black-footed ferret, ferret, Mustela nigripes": 359,
1249
+ "bloodhound, sleuthhound": 163,
1250
+ "bluetick": 164,
1251
+ "boa constrictor, Constrictor constrictor": 61,
1252
+ "boathouse": 449,
1253
+ "bobsled, bobsleigh, bob": 450,
1254
+ "bolete": 997,
1255
+ "bolo tie, bolo, bola tie, bola": 451,
1256
+ "bonnet, poke bonnet": 452,
1257
+ "book jacket, dust cover, dust jacket, dust wrapper": 921,
1258
+ "bookcase": 453,
1259
+ "bookshop, bookstore, bookstall": 454,
1260
+ "borzoi, Russian wolfhound": 169,
1261
+ "bottlecap": 455,
1262
+ "bow": 456,
1263
+ "bow tie, bow-tie, bowtie": 457,
1264
+ "box turtle, box tortoise": 37,
1265
+ "boxer": 242,
1266
+ "brain coral": 109,
1267
+ "brambling, Fringilla montifringilla": 10,
1268
+ "brass, memorial tablet, plaque": 458,
1269
+ "brassiere, bra, bandeau": 459,
1270
+ "breakwater, groin, groyne, mole, bulwark, seawall, jetty": 460,
1271
+ "breastplate, aegis, egis": 461,
1272
+ "briard": 226,
1273
+ "broccoli": 937,
1274
+ "broom": 462,
1275
+ "brown bear, bruin, Ursus arctos": 294,
1276
+ "bubble": 971,
1277
+ "bucket, pail": 463,
1278
+ "buckeye, horse chestnut, conker": 990,
1279
+ "buckle": 464,
1280
+ "bulbul": 16,
1281
+ "bull mastiff": 243,
1282
+ "bullet train, bullet": 466,
1283
+ "bulletproof vest": 465,
1284
+ "bullfrog, Rana catesbeiana": 30,
1285
+ "burrito": 965,
1286
+ "bustard": 138,
1287
+ "butcher shop, meat market": 467,
1288
+ "butternut squash": 942,
1289
+ "cab, hack, taxi, taxicab": 468,
1290
+ "cabbage butterfly": 324,
1291
+ "cairn, cairn terrier": 192,
1292
+ "caldron, cauldron": 469,
1293
+ "can opener, tin opener": 473,
1294
+ "candle, taper, wax light": 470,
1295
+ "cannon": 471,
1296
+ "canoe": 472,
1297
+ "capuchin, ringtail, Cebus capucinus": 378,
1298
+ "car mirror": 475,
1299
+ "car wheel": 479,
1300
+ "carbonara": 959,
1301
+ "cardigan": 474,
1302
+ "cardoon": 946,
1303
+ "carousel, carrousel, merry-go-round, roundabout, whirligig": 476,
1304
+ "carpenter's kit, tool kit": 477,
1305
+ "carton": 478,
1306
+ "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM": 480,
1307
+ "cassette": 481,
1308
+ "cassette player": 482,
1309
+ "castle": 483,
1310
+ "catamaran": 484,
1311
+ "cauliflower": 938,
1312
+ "cello, violoncello": 486,
1313
+ "cellular telephone, cellular phone, cellphone, cell, mobile phone": 487,
1314
+ "centipede": 79,
1315
+ "chain": 488,
1316
+ "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour": 490,
1317
+ "chain saw, chainsaw": 491,
1318
+ "chainlink fence": 489,
1319
+ "chambered nautilus, pearly nautilus, nautilus": 117,
1320
+ "cheeseburger": 933,
1321
+ "cheetah, chetah, Acinonyx jubatus": 293,
1322
+ "chest": 492,
1323
+ "chickadee": 19,
1324
+ "chiffonier, commode": 493,
1325
+ "chime, bell, gong": 494,
1326
+ "chimpanzee, chimp, Pan troglodytes": 367,
1327
+ "china cabinet, china closet": 495,
1328
+ "chiton, coat-of-mail shell, sea cradle, polyplacophore": 116,
1329
+ "chocolate sauce, chocolate syrup": 960,
1330
+ "chow, chow chow": 260,
1331
+ "church, church building": 497,
1332
+ "cicada, cicala": 316,
1333
+ "cinema, movie theater, movie theatre, movie house, picture palace": 498,
1334
+ "cleaver, meat cleaver, chopper": 499,
1335
+ "cliff dwelling": 500,
1336
+ "cliff, drop, drop-off": 972,
1337
+ "cloak": 501,
1338
+ "clog, geta, patten, sabot": 502,
1339
+ "clumber, clumber spaniel": 216,
1340
+ "cock": 7,
1341
+ "cocker spaniel, English cocker spaniel, cocker": 219,
1342
+ "cockroach, roach": 314,
1343
+ "cocktail shaker": 503,
1344
+ "coffee mug": 504,
1345
+ "coffeepot": 505,
1346
+ "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch": 391,
1347
+ "coil, spiral, volute, whorl, helix": 506,
1348
+ "collie": 231,
1349
+ "colobus, colobus monkey": 375,
1350
+ "combination lock": 507,
1351
+ "comic book": 917,
1352
+ "common iguana, iguana, Iguana iguana": 39,
1353
+ "common newt, Triturus vulgaris": 26,
1354
+ "computer keyboard, keypad": 508,
1355
+ "conch": 112,
1356
+ "confectionery, confectionary, candy store": 509,
1357
+ "consomme": 925,
1358
+ "container ship, containership, container vessel": 510,
1359
+ "convertible": 511,
1360
+ "coral fungus": 991,
1361
+ "coral reef": 973,
1362
+ "corkscrew, bottle screw": 512,
1363
+ "corn": 987,
1364
+ "cornet, horn, trumpet, trump": 513,
1365
+ "coucal": 91,
1366
+ "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor": 286,
1367
+ "cowboy boot": 514,
1368
+ "cowboy hat, ten-gallon hat": 515,
1369
+ "coyote, prairie wolf, brush wolf, Canis latrans": 272,
1370
+ "cradle": 516,
1371
+ "crane": 517,
1372
+ "crash helmet": 518,
1373
+ "crate": 519,
1374
+ "crayfish, crawfish, crawdad, crawdaddy": 124,
1375
+ "crib, cot": 520,
1376
+ "cricket": 312,
1377
+ "croquet ball": 522,
1378
+ "crossword puzzle, crossword": 918,
1379
+ "crutch": 523,
1380
+ "cucumber, cuke": 943,
1381
+ "cuirass": 524,
1382
+ "cup": 968,
1383
+ "curly-coated retriever": 206,
1384
+ "custard apple": 956,
1385
+ "daisy": 985,
1386
+ "dalmatian, coach dog, carriage dog": 251,
1387
+ "dam, dike, dyke": 525,
1388
+ "damselfly": 320,
1389
+ "desk": 526,
1390
+ "desktop computer": 527,
1391
+ "dhole, Cuon alpinus": 274,
1392
+ "dial telephone, dial phone": 528,
1393
+ "diamondback, diamondback rattlesnake, Crotalus adamanteus": 67,
1394
+ "diaper, nappy, napkin": 529,
1395
+ "digital clock": 530,
1396
+ "digital watch": 531,
1397
+ "dingo, warrigal, warragal, Canis dingo": 273,
1398
+ "dining table, board": 532,
1399
+ "dishrag, dishcloth": 533,
1400
+ "dishwasher, dish washer, dishwashing machine": 534,
1401
+ "disk brake, disc brake": 535,
1402
+ "dock, dockage, docking facility": 536,
1403
+ "dogsled, dog sled, dog sleigh": 537,
1404
+ "dome": 538,
1405
+ "doormat, welcome mat": 539,
1406
+ "dough": 961,
1407
+ "dowitcher": 142,
1408
+ "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk": 319,
1409
+ "drake": 97,
1410
+ "drilling platform, offshore rig": 540,
1411
+ "drum, membranophone, tympan": 541,
1412
+ "drumstick": 542,
1413
+ "dugong, Dugong dugon": 149,
1414
+ "dumbbell": 543,
1415
+ "dung beetle": 305,
1416
+ "ear, spike, capitulum": 998,
1417
+ "earthstar": 995,
1418
+ "echidna, spiny anteater, anteater": 102,
1419
+ "eel": 390,
1420
+ "eft": 27,
1421
+ "eggnog": 969,
1422
+ "electric fan, blower": 545,
1423
+ "electric guitar": 546,
1424
+ "electric locomotive": 547,
1425
+ "electric ray, crampfish, numbfish, torpedo": 5,
1426
+ "entertainment center": 548,
1427
+ "envelope": 549,
1428
+ "espresso": 967,
1429
+ "espresso maker": 550,
1430
+ "face powder": 551,
1431
+ "feather boa, boa": 552,
1432
+ "fiddler crab": 120,
1433
+ "fig": 952,
1434
+ "file, file cabinet, filing cabinet": 553,
1435
+ "fire engine, fire truck": 555,
1436
+ "fire screen, fireguard": 556,
1437
+ "fireboat": 554,
1438
+ "flagpole, flagstaff": 557,
1439
+ "flamingo": 130,
1440
+ "flat-coated retriever": 205,
1441
+ "flatworm, platyhelminth": 110,
1442
+ "flute, transverse flute": 558,
1443
+ "fly": 308,
1444
+ "folding chair": 559,
1445
+ "football helmet": 560,
1446
+ "forklift": 561,
1447
+ "fountain": 562,
1448
+ "fountain pen": 563,
1449
+ "four-poster": 564,
1450
+ "fox squirrel, eastern fox squirrel, Sciurus niger": 335,
1451
+ "freight car": 565,
1452
+ "frilled lizard, Chlamydosaurus kingi": 43,
1453
+ "frying pan, frypan, skillet": 567,
1454
+ "fur coat": 568,
1455
+ "gar, garfish, garpike, billfish, Lepisosteus osseus": 395,
1456
+ "garbage truck, dustcart": 569,
1457
+ "garden spider, Aranea diademata": 74,
1458
+ "garter snake, grass snake": 57,
1459
+ "gas pump, gasoline pump, petrol pump, island dispenser": 571,
1460
+ "gasmask, respirator, gas helmet": 570,
1461
+ "gazelle": 353,
1462
+ "geyser": 974,
1463
+ "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca": 388,
1464
+ "giant schnauzer": 197,
1465
+ "gibbon, Hylobates lar": 368,
1466
+ "go-kart": 573,
1467
+ "goblet": 572,
1468
+ "golden retriever": 207,
1469
+ "goldfinch, Carduelis carduelis": 11,
1470
+ "goldfish, Carassius auratus": 1,
1471
+ "golf ball": 574,
1472
+ "golfcart, golf cart": 575,
1473
+ "gondola": 576,
1474
+ "gong, tam-tam": 577,
1475
+ "goose": 99,
1476
+ "gorilla, Gorilla gorilla": 366,
1477
+ "gown": 578,
1478
+ "grand piano, grand": 579,
1479
+ "grasshopper, hopper": 311,
1480
+ "great grey owl, great gray owl, Strix nebulosa": 24,
1481
+ "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias": 2,
1482
+ "green lizard, Lacerta viridis": 46,
1483
+ "green mamba": 64,
1484
+ "green snake, grass snake": 55,
1485
+ "greenhouse, nursery, glasshouse": 580,
1486
+ "grey fox, gray fox, Urocyon cinereoargenteus": 280,
1487
+ "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus": 147,
1488
+ "grille, radiator grille": 581,
1489
+ "grocery store, grocery, food market, market": 582,
1490
+ "groenendael": 224,
1491
+ "groom, bridegroom": 982,
1492
+ "ground beetle, carabid beetle": 302,
1493
+ "guacamole": 924,
1494
+ "guenon, guenon monkey": 370,
1495
+ "guillotine": 583,
1496
+ "guinea pig, Cavia cobaya": 338,
1497
+ "gyromitra": 993,
1498
+ "hair slide": 584,
1499
+ "hair spray": 585,
1500
+ "half track": 586,
1501
+ "hammer": 587,
1502
+ "hammerhead, hammerhead shark": 4,
1503
+ "hamper": 588,
1504
+ "hamster": 333,
1505
+ "hand blower, blow dryer, blow drier, hair dryer, hair drier": 589,
1506
+ "hand-held computer, hand-held microcomputer": 590,
1507
+ "handkerchief, hankie, hanky, hankey": 591,
1508
+ "hard disc, hard disk, fixed disk": 592,
1509
+ "hare": 331,
1510
+ "harmonica, mouth organ, harp, mouth harp": 593,
1511
+ "harp": 594,
1512
+ "hartebeest": 351,
1513
+ "harvester, reaper": 595,
1514
+ "harvestman, daddy longlegs, Phalangium opilio": 70,
1515
+ "hatchet": 596,
1516
+ "hay": 958,
1517
+ "head cabbage": 936,
1518
+ "hen": 8,
1519
+ "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa": 996,
1520
+ "hermit crab": 125,
1521
+ "hip, rose hip, rosehip": 989,
1522
+ "hippopotamus, hippo, river horse, Hippopotamus amphibius": 344,
1523
+ "hog, pig, grunter, squealer, Sus scrofa": 341,
1524
+ "hognose snake, puff adder, sand viper": 54,
1525
+ "holster": 597,
1526
+ "home theater, home theatre": 598,
1527
+ "honeycomb": 599,
1528
+ "hook, claw": 600,
1529
+ "hoopskirt, crinoline": 601,
1530
+ "horizontal bar, high bar": 602,
1531
+ "hornbill": 93,
1532
+ "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus": 66,
1533
+ "horse cart, horse-cart": 603,
1534
+ "hot pot, hotpot": 926,
1535
+ "hotdog, hot dog, red hot": 934,
1536
+ "hourglass": 604,
1537
+ "house finch, linnet, Carpodacus mexicanus": 12,
1538
+ "howler monkey, howler": 379,
1539
+ "hummingbird": 94,
1540
+ "hyena, hyaena": 276,
1541
+ "iPod": 605,
1542
+ "ibex, Capra ibex": 350,
1543
+ "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus": 296,
1544
+ "ice cream, icecream": 928,
1545
+ "ice lolly, lolly, lollipop, popsicle": 929,
1546
+ "impala, Aepyceros melampus": 352,
1547
+ "indigo bunting, indigo finch, indigo bird, Passerina cyanea": 14,
1548
+ "indri, indris, Indri indri, Indri brevicaudatus": 384,
1549
+ "iron, smoothing iron": 606,
1550
+ "isopod": 126,
1551
+ "jacamar": 95,
1552
+ "jack-o'-lantern": 607,
1553
+ "jackfruit, jak, jack": 955,
1554
+ "jaguar, panther, Panthera onca, Felis onca": 290,
1555
+ "jay": 17,
1556
+ "jean, blue jean, denim": 608,
1557
+ "jeep, landrover": 609,
1558
+ "jellyfish": 107,
1559
+ "jersey, T-shirt, tee shirt": 610,
1560
+ "jigsaw puzzle": 611,
1561
+ "jinrikisha, ricksha, rickshaw": 612,
1562
+ "joystick": 613,
1563
+ "junco, snowbird": 13,
1564
+ "keeshond": 261,
1565
+ "kelpie": 227,
1566
+ "killer whale, killer, orca, grampus, sea wolf, Orcinus orca": 148,
1567
+ "kimono": 614,
1568
+ "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica": 121,
1569
+ "king penguin, Aptenodytes patagonica": 145,
1570
+ "king snake, kingsnake": 56,
1571
+ "kit fox, Vulpes macrotis": 278,
1572
+ "kite": 21,
1573
+ "knee pad": 615,
1574
+ "knot": 616,
1575
+ "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus": 105,
1576
+ "komondor": 228,
1577
+ "kuvasz": 222,
1578
+ "lab coat, laboratory coat": 617,
1579
+ "lacewing, lacewing fly": 318,
1580
+ "ladle": 618,
1581
+ "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle": 301,
1582
+ "lakeside, lakeshore": 975,
1583
+ "lampshade, lamp shade": 619,
1584
+ "langur": 374,
1585
+ "laptop, laptop computer": 620,
1586
+ "lawn mower, mower": 621,
1587
+ "leaf beetle, chrysomelid": 304,
1588
+ "leafhopper": 317,
1589
+ "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea": 34,
1590
+ "lemon": 951,
1591
+ "lens cap, lens cover": 622,
1592
+ "leopard, Panthera pardus": 288,
1593
+ "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens": 387,
1594
+ "letter opener, paper knife, paperknife": 623,
1595
+ "library": 624,
1596
+ "lifeboat": 625,
1597
+ "lighter, light, igniter, ignitor": 626,
1598
+ "limousine, limo": 627,
1599
+ "limpkin, Aramus pictus": 135,
1600
+ "liner, ocean liner": 628,
1601
+ "lion, king of beasts, Panthera leo": 291,
1602
+ "lionfish": 396,
1603
+ "lipstick, lip rouge": 629,
1604
+ "little blue heron, Egretta caerulea": 131,
1605
+ "llama": 355,
1606
+ "loggerhead, loggerhead turtle, Caretta caretta": 33,
1607
+ "long-horned beetle, longicorn, longicorn beetle": 303,
1608
+ "lorikeet": 90,
1609
+ "lotion": 631,
1610
+ "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system": 632,
1611
+ "loupe, jeweler's loupe": 633,
1612
+ "lumbermill, sawmill": 634,
1613
+ "lycaenid, lycaenid butterfly": 326,
1614
+ "lynx, catamount": 287,
1615
+ "macaque": 373,
1616
+ "macaw": 88,
1617
+ "magnetic compass": 635,
1618
+ "magpie": 18,
1619
+ "mailbag, postbag": 636,
1620
+ "mailbox, letter box": 637,
1621
+ "maillot": 638,
1622
+ "maillot, tank suit": 639,
1623
+ "malamute, malemute, Alaskan malamute": 249,
1624
+ "malinois": 225,
1625
+ "manhole cover": 640,
1626
+ "mantis, mantid": 315,
1627
+ "maraca": 641,
1628
+ "marimba, xylophone": 642,
1629
+ "marmoset": 377,
1630
+ "marmot": 336,
1631
+ "mashed potato": 935,
1632
+ "mask": 643,
1633
+ "matchstick": 644,
1634
+ "maypole": 645,
1635
+ "maze, labyrinth": 646,
1636
+ "measuring cup": 647,
1637
+ "meat loaf, meatloaf": 962,
1638
+ "medicine chest, medicine cabinet": 648,
1639
+ "meerkat, mierkat": 299,
1640
+ "megalith, megalithic structure": 649,
1641
+ "menu": 922,
1642
+ "microphone, mike": 650,
1643
+ "microwave, microwave oven": 651,
1644
+ "military uniform": 652,
1645
+ "milk can": 653,
1646
+ "miniature pinscher": 237,
1647
+ "miniature poodle": 266,
1648
+ "miniature schnauzer": 196,
1649
+ "minibus": 654,
1650
+ "miniskirt, mini": 655,
1651
+ "minivan": 656,
1652
+ "mink": 357,
1653
+ "missile": 657,
1654
+ "mitten": 658,
1655
+ "mixing bowl": 659,
1656
+ "mobile home, manufactured home": 660,
1657
+ "modem": 662,
1658
+ "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus": 323,
1659
+ "monastery": 663,
1660
+ "mongoose": 298,
1661
+ "monitor": 664,
1662
+ "moped": 665,
1663
+ "mortar": 666,
1664
+ "mortarboard": 667,
1665
+ "mosque": 668,
1666
+ "mosquito net": 669,
1667
+ "motor scooter, scooter": 670,
1668
+ "mountain bike, all-terrain bike, off-roader": 671,
1669
+ "mountain tent": 672,
1670
+ "mouse, computer mouse": 673,
1671
+ "mousetrap": 674,
1672
+ "moving van": 675,
1673
+ "mud turtle": 35,
1674
+ "mushroom": 947,
1675
+ "muzzle": 676,
1676
+ "nail": 677,
1677
+ "neck brace": 678,
1678
+ "necklace": 679,
1679
+ "nematode, nematode worm, roundworm": 111,
1680
+ "night snake, Hypsiglena torquata": 60,
1681
+ "nipple": 680,
1682
+ "notebook, notebook computer": 681,
1683
+ "obelisk": 682,
1684
+ "oboe, hautboy, hautbois": 683,
1685
+ "ocarina, sweet potato": 684,
1686
+ "odometer, hodometer, mileometer, milometer": 685,
1687
+ "oil filter": 686,
1688
+ "orange": 950,
1689
+ "orangutan, orang, orangutang, Pongo pygmaeus": 365,
1690
+ "organ, pipe organ": 687,
1691
+ "oscilloscope, scope, cathode-ray oscilloscope, CRO": 688,
1692
+ "ostrich, Struthio camelus": 9,
1693
+ "otter": 360,
1694
+ "otterhound, otter hound": 175,
1695
+ "overskirt": 689,
1696
+ "ox": 345,
1697
+ "oxcart": 690,
1698
+ "oxygen mask": 691,
1699
+ "oystercatcher, oyster catcher": 143,
1700
+ "packet": 692,
1701
+ "paddle, boat paddle": 693,
1702
+ "paddlewheel, paddle wheel": 694,
1703
+ "padlock": 695,
1704
+ "paintbrush": 696,
1705
+ "pajama, pyjama, pj's, jammies": 697,
1706
+ "palace": 698,
1707
+ "panpipe, pandean pipe, syrinx": 699,
1708
+ "paper towel": 700,
1709
+ "papillon": 157,
1710
+ "parachute, chute": 701,
1711
+ "parallel bars, bars": 702,
1712
+ "park bench": 703,
1713
+ "parking meter": 704,
1714
+ "partridge": 86,
1715
+ "passenger car, coach, carriage": 705,
1716
+ "patas, hussar monkey, Erythrocebus patas": 371,
1717
+ "patio, terrace": 706,
1718
+ "pay-phone, pay-station": 707,
1719
+ "peacock": 84,
1720
+ "pedestal, plinth, footstall": 708,
1721
+ "pelican": 144,
1722
+ "pencil box, pencil case": 709,
1723
+ "pencil sharpener": 710,
1724
+ "perfume, essence": 711,
1725
+ "photocopier": 713,
1726
+ "pick, plectrum, plectron": 714,
1727
+ "pickelhaube": 715,
1728
+ "picket fence, paling": 716,
1729
+ "pickup, pickup truck": 717,
1730
+ "pier": 718,
1731
+ "piggy bank, penny bank": 719,
1732
+ "pill bottle": 720,
1733
+ "pillow": 721,
1734
+ "pineapple, ananas": 953,
1735
+ "ping-pong ball": 722,
1736
+ "pinwheel": 723,
1737
+ "pirate, pirate ship": 724,
1738
+ "pitcher, ewer": 725,
1739
+ "pizza, pizza pie": 963,
1740
+ "plane, carpenter's plane, woodworking plane": 726,
1741
+ "planetarium": 727,
1742
+ "plastic bag": 728,
1743
+ "plate": 923,
1744
+ "plate rack": 729,
1745
+ "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus": 103,
1746
+ "plow, plough": 730,
1747
+ "plunger, plumber's helper": 731,
1748
+ "pole": 733,
1749
+ "polecat, fitch, foulmart, foumart, Mustela putorius": 358,
1750
+ "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria": 734,
1751
+ "pomegranate": 957,
1752
+ "poncho": 735,
1753
+ "pool table, billiard table, snooker table": 736,
1754
+ "pop bottle, soda bottle": 737,
1755
+ "porcupine, hedgehog": 334,
1756
+ "pot, flowerpot": 738,
1757
+ "potpie": 964,
1758
+ "potter's wheel": 739,
1759
+ "power drill": 740,
1760
+ "prairie chicken, prairie grouse, prairie fowl": 83,
1761
+ "prayer rug, prayer mat": 741,
1762
+ "pretzel": 932,
1763
+ "printer": 742,
1764
+ "prison, prison house": 743,
1765
+ "proboscis monkey, Nasalis larvatus": 376,
1766
+ "projectile, missile": 744,
1767
+ "projector": 745,
1768
+ "promontory, headland, head, foreland": 976,
1769
+ "ptarmigan": 81,
1770
+ "puck, hockey puck": 746,
1771
+ "puffer, pufferfish, blowfish, globefish": 397,
1772
+ "pug, pug-dog": 254,
1773
+ "punching bag, punch bag, punching ball, punchball": 747,
1774
+ "purse": 748,
1775
+ "quail": 85,
1776
+ "quill, quill pen": 749,
1777
+ "quilt, comforter, comfort, puff": 750,
1778
+ "racer, race car, racing car": 751,
1779
+ "racket, racquet": 752,
1780
+ "radiator": 753,
1781
+ "radio telescope, radio reflector": 755,
1782
+ "radio, wireless": 754,
1783
+ "rain barrel": 756,
1784
+ "ram, tup": 348,
1785
+ "rapeseed": 984,
1786
+ "recreational vehicle, RV, R.V.": 757,
1787
+ "red fox, Vulpes vulpes": 277,
1788
+ "red wine": 966,
1789
+ "red wolf, maned wolf, Canis rufus, Canis niger": 271,
1790
+ "red-backed sandpiper, dunlin, Erolia alpina": 140,
1791
+ "red-breasted merganser, Mergus serrator": 98,
1792
+ "redbone": 168,
1793
+ "redshank, Tringa totanus": 141,
1794
+ "reel": 758,
1795
+ "reflex camera": 759,
1796
+ "refrigerator, icebox": 760,
1797
+ "remote control, remote": 761,
1798
+ "restaurant, eating house, eating place, eatery": 762,
1799
+ "revolver, six-gun, six-shooter": 763,
1800
+ "rhinoceros beetle": 306,
1801
+ "rifle": 764,
1802
+ "ringlet, ringlet butterfly": 322,
1803
+ "ringneck snake, ring-necked snake, ring snake": 53,
1804
+ "robin, American robin, Turdus migratorius": 15,
1805
+ "rock beauty, Holocanthus tricolor": 392,
1806
+ "rock crab, Cancer irroratus": 119,
1807
+ "rock python, rock snake, Python sebae": 62,
1808
+ "rocking chair, rocker": 765,
1809
+ "rotisserie": 766,
1810
+ "rubber eraser, rubber, pencil eraser": 767,
1811
+ "ruddy turnstone, Arenaria interpres": 139,
1812
+ "ruffed grouse, partridge, Bonasa umbellus": 82,
1813
+ "rugby ball": 768,
1814
+ "rule, ruler": 769,
1815
+ "running shoe": 770,
1816
+ "safe": 771,
1817
+ "safety pin": 772,
1818
+ "saltshaker, salt shaker": 773,
1819
+ "sandal": 774,
1820
+ "sandbar, sand bar": 977,
1821
+ "sarong": 775,
1822
+ "sax, saxophone": 776,
1823
+ "scabbard": 777,
1824
+ "scale, weighing machine": 778,
1825
+ "schipperke": 223,
1826
+ "school bus": 779,
1827
+ "schooner": 780,
1828
+ "scoreboard": 781,
1829
+ "scorpion": 71,
1830
+ "screen, CRT screen": 782,
1831
+ "screw": 783,
1832
+ "screwdriver": 784,
1833
+ "scuba diver": 983,
1834
+ "sea anemone, anemone": 108,
1835
+ "sea cucumber, holothurian": 329,
1836
+ "sea lion": 150,
1837
+ "sea slug, nudibranch": 115,
1838
+ "sea snake": 65,
1839
+ "sea urchin": 328,
1840
+ "seashore, coast, seacoast, sea-coast": 978,
1841
+ "seat belt, seatbelt": 785,
1842
+ "sewing machine": 786,
1843
+ "shield, buckler": 787,
1844
+ "shoe shop, shoe-shop, shoe store": 788,
1845
+ "shoji": 789,
1846
+ "shopping basket": 790,
1847
+ "shopping cart": 791,
1848
+ "shovel": 792,
1849
+ "shower cap": 793,
1850
+ "shower curtain": 794,
1851
+ "siamang, Hylobates syndactylus, Symphalangus syndactylus": 369,
1852
+ "sidewinder, horned rattlesnake, Crotalus cerastes": 68,
1853
+ "silky terrier, Sydney silky": 201,
1854
+ "ski": 795,
1855
+ "ski mask": 796,
1856
+ "skunk, polecat, wood pussy": 361,
1857
+ "sleeping bag": 797,
1858
+ "slide rule, slipstick": 798,
1859
+ "sliding door": 799,
1860
+ "slot, one-armed bandit": 800,
1861
+ "sloth bear, Melursus ursinus, Ursus ursinus": 297,
1862
+ "slug": 114,
1863
+ "snail": 113,
1864
+ "snorkel": 801,
1865
+ "snow leopard, ounce, Panthera uncia": 289,
1866
+ "snowmobile": 802,
1867
+ "snowplow, snowplough": 803,
1868
+ "soap dispenser": 804,
1869
+ "soccer ball": 805,
1870
+ "sock": 806,
1871
+ "soft-coated wheaten terrier": 202,
1872
+ "solar dish, solar collector, solar furnace": 807,
1873
+ "sombrero": 808,
1874
+ "sorrel": 339,
1875
+ "soup bowl": 809,
1876
+ "space bar": 810,
1877
+ "space heater": 811,
1878
+ "space shuttle": 812,
1879
+ "spaghetti squash": 940,
1880
+ "spatula": 813,
1881
+ "speedboat": 814,
1882
+ "spider monkey, Ateles geoffroyi": 381,
1883
+ "spider web, spider's web": 815,
1884
+ "spindle": 816,
1885
+ "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish": 123,
1886
+ "spoonbill": 129,
1887
+ "sports car, sport car": 817,
1888
+ "spotlight, spot": 818,
1889
+ "spotted salamander, Ambystoma maculatum": 28,
1890
+ "squirrel monkey, Saimiri sciureus": 382,
1891
+ "stage": 819,
1892
+ "standard poodle": 267,
1893
+ "standard schnauzer": 198,
1894
+ "starfish, sea star": 327,
1895
+ "steam locomotive": 820,
1896
+ "steel arch bridge": 821,
1897
+ "steel drum": 822,
1898
+ "stethoscope": 823,
1899
+ "stingray": 6,
1900
+ "stinkhorn, carrion fungus": 994,
1901
+ "stole": 824,
1902
+ "stone wall": 825,
1903
+ "stopwatch, stop watch": 826,
1904
+ "stove": 827,
1905
+ "strainer": 828,
1906
+ "strawberry": 949,
1907
+ "street sign": 919,
1908
+ "streetcar, tram, tramcar, trolley, trolley car": 829,
1909
+ "stretcher": 830,
1910
+ "studio couch, day bed": 831,
1911
+ "stupa, tope": 832,
1912
+ "sturgeon": 394,
1913
+ "submarine, pigboat, sub, U-boat": 833,
1914
+ "suit, suit of clothes": 834,
1915
+ "sulphur butterfly, sulfur butterfly": 325,
1916
+ "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita": 89,
1917
+ "sundial": 835,
1918
+ "sunglass": 836,
1919
+ "sunglasses, dark glasses, shades": 837,
1920
+ "sunscreen, sunblock, sun blocker": 838,
1921
+ "suspension bridge": 839,
1922
+ "swab, swob, mop": 840,
1923
+ "sweatshirt": 841,
1924
+ "swimming trunks, bathing trunks": 842,
1925
+ "swing": 843,
1926
+ "switch, electric switch, electrical switch": 844,
1927
+ "syringe": 845,
1928
+ "tabby, tabby cat": 281,
1929
+ "table lamp": 846,
1930
+ "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui": 32,
1931
+ "tank, army tank, armored combat vehicle, armoured combat vehicle": 847,
1932
+ "tape player": 848,
1933
+ "tarantula": 76,
1934
+ "teapot": 849,
1935
+ "teddy, teddy bear": 850,
1936
+ "television, television system": 851,
1937
+ "tench, Tinca tinca": 0,
1938
+ "tennis ball": 852,
1939
+ "terrapin": 36,
1940
+ "thatch, thatched roof": 853,
1941
+ "theater curtain, theatre curtain": 854,
1942
+ "thimble": 855,
1943
+ "three-toed sloth, ai, Bradypus tridactylus": 364,
1944
+ "thresher, thrasher, threshing machine": 856,
1945
+ "throne": 857,
1946
+ "thunder snake, worm snake, Carphophis amoenus": 52,
1947
+ "tick": 78,
1948
+ "tiger beetle": 300,
1949
+ "tiger cat": 282,
1950
+ "tiger shark, Galeocerdo cuvieri": 3,
1951
+ "tiger, Panthera tigris": 292,
1952
+ "tile roof": 858,
1953
+ "timber wolf, grey wolf, gray wolf, Canis lupus": 269,
1954
+ "titi, titi monkey": 380,
1955
+ "toaster": 859,
1956
+ "tobacco shop, tobacconist shop, tobacconist": 860,
1957
+ "toilet seat": 861,
1958
+ "toilet tissue, toilet paper, bathroom tissue": 999,
1959
+ "torch": 862,
1960
+ "totem pole": 863,
1961
+ "toucan": 96,
1962
+ "tow truck, tow car, wrecker": 864,
1963
+ "toy poodle": 265,
1964
+ "toy terrier": 158,
1965
+ "toyshop": 865,
1966
+ "tractor": 866,
1967
+ "traffic light, traffic signal, stoplight": 920,
1968
+ "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi": 867,
1969
+ "tray": 868,
1970
+ "tree frog, tree-frog": 31,
1971
+ "trench coat": 869,
1972
+ "triceratops": 51,
1973
+ "tricycle, trike, velocipede": 870,
1974
+ "trifle": 927,
1975
+ "trilobite": 69,
1976
+ "trimaran": 871,
1977
+ "tripod": 872,
1978
+ "triumphal arch": 873,
1979
+ "trolleybus, trolley coach, trackless trolley": 874,
1980
+ "trombone": 875,
1981
+ "tub, vat": 876,
1982
+ "turnstile": 877,
1983
+ "tusker": 101,
1984
+ "typewriter keyboard": 878,
1985
+ "umbrella": 879,
1986
+ "unicycle, monocycle": 880,
1987
+ "upright, upright piano": 881,
1988
+ "vacuum, vacuum cleaner": 882,
1989
+ "valley, vale": 979,
1990
+ "vase": 883,
1991
+ "vault": 884,
1992
+ "velvet": 885,
1993
+ "vending machine": 886,
1994
+ "vestment": 887,
1995
+ "viaduct": 888,
1996
+ "vine snake": 59,
1997
+ "violin, fiddle": 889,
1998
+ "vizsla, Hungarian pointer": 211,
1999
+ "volcano": 980,
2000
+ "volleyball": 890,
2001
+ "vulture": 23,
2002
+ "waffle iron": 891,
2003
+ "walking stick, walkingstick, stick insect": 313,
2004
+ "wall clock": 892,
2005
+ "wallaby, brush kangaroo": 104,
2006
+ "wallet, billfold, notecase, pocketbook": 893,
2007
+ "wardrobe, closet, press": 894,
2008
+ "warplane, military plane": 895,
2009
+ "warthog": 343,
2010
+ "washbasin, handbasin, washbowl, lavabo, wash-hand basin": 896,
2011
+ "washer, automatic washer, washing machine": 897,
2012
+ "water bottle": 898,
2013
+ "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis": 346,
2014
+ "water jug": 899,
2015
+ "water ouzel, dipper": 20,
2016
+ "water snake": 58,
2017
+ "water tower": 900,
2018
+ "weasel": 356,
2019
+ "web site, website, internet site, site": 916,
2020
+ "weevil": 307,
2021
+ "whippet": 172,
2022
+ "whiptail, whiptail lizard": 41,
2023
+ "whiskey jug": 901,
2024
+ "whistle": 902,
2025
+ "white stork, Ciconia ciconia": 127,
2026
+ "white wolf, Arctic wolf, Canis lupus tundrarum": 270,
2027
+ "wig": 903,
2028
+ "wild boar, boar, Sus scrofa": 342,
2029
+ "window screen": 904,
2030
+ "window shade": 905,
2031
+ "wine bottle": 907,
2032
+ "wing": 908,
2033
+ "wire-haired fox terrier": 188,
2034
+ "wok": 909,
2035
+ "wolf spider, hunting spider": 77,
2036
+ "wombat": 106,
2037
+ "wood rabbit, cottontail, cottontail rabbit": 330,
2038
+ "wooden spoon": 910,
2039
+ "wool, woolen, woollen": 911,
2040
+ "worm fence, snake fence, snake-rail fence, Virginia fence": 912,
2041
+ "wreck": 913,
2042
+ "yawl": 914,
2043
+ "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum": 986,
2044
+ "yurt": 915,
2045
+ "zebra": 340,
2046
+ "zucchini, courgette": 939
2047
+ },
2048
+ "layer_norm_eps": 1e-05,
2049
+ "length_penalty": 1.0,
2050
+ "max_length": 20,
2051
+ "min_length": 0,
2052
+ "mlp_ratio": 4.0,
2053
+ "model_type": "swin",
2054
+ "no_repeat_ngram_size": 0,
2055
+ "num_beam_groups": 1,
2056
+ "num_beams": 1,
2057
+ "num_channels": 3,
2058
+ "num_heads": [
2059
+ 3,
2060
+ 6,
2061
+ 12,
2062
+ 24
2063
+ ],
2064
+ "num_layers": 4,
2065
+ "num_return_sequences": 1,
2066
+ "out_features": [
2067
+ "stage1",
2068
+ "stage2",
2069
+ "stage3",
2070
+ "stage4"
2071
+ ],
2072
+ "output_attentions": false,
2073
+ "output_hidden_states": false,
2074
+ "output_scores": false,
2075
+ "pad_token_id": null,
2076
+ "patch_size": 4,
2077
+ "path_norm": true,
2078
+ "prefix": null,
2079
+ "problem_type": null,
2080
+ "pruned_heads": {},
2081
+ "qkv_bias": true,
2082
+ "remove_invalid_values": false,
2083
+ "repetition_penalty": 1.0,
2084
+ "return_dict": true,
2085
+ "return_dict_in_generate": false,
2086
+ "sep_token_id": null,
2087
+ "stage_names": [
2088
+ "stem",
2089
+ "stage1",
2090
+ "stage2",
2091
+ "stage3",
2092
+ "stage4"
2093
+ ],
2094
+ "suppress_tokens": null,
2095
+ "task_specific_params": null,
2096
+ "temperature": 1.0,
2097
+ "tf_legacy_loss": false,
2098
+ "tie_encoder_decoder": false,
2099
+ "tie_word_embeddings": true,
2100
+ "tokenizer_class": null,
2101
+ "top_k": 50,
2102
+ "top_p": 1.0,
2103
+ "torch_dtype": "float32",
2104
+ "torchscript": false,
2105
+ "transformers_version": "4.26.0.dev0",
2106
+ "typical_p": 1.0,
2107
+ "use_absolute_embeddings": false,
2108
+ "use_bfloat16": false,
2109
+ "window_size": 7
2110
+ },
2111
+ "class_weight": 2.0,
2112
+ "common_stride": 4,
2113
+ "decoder_layers": 10,
2114
+ "dice_weight": 5.0,
2115
+ "dim_feedforward": 2048,
2116
+ "dropout": 0.0,
2117
+ "encoder_feedforward_dim": 1024,
2118
+ "encoder_layers": 6,
2119
+ "enforce_input_proj": false,
2120
+ "enforce_input_projection": false,
2121
+ "feature_size": 256,
2122
+ "feature_strides": [
2123
+ 4,
2124
+ 8,
2125
+ 16,
2126
+ 32
2127
+ ],
2128
+ "hidden_dim": 256,
2129
+ "id2label": {
2130
+ "0": "person",
2131
+ "1": "bicycle",
2132
+ "2": "car",
2133
+ "3": "motorbike",
2134
+ "4": "aeroplane",
2135
+ "5": "bus",
2136
+ "6": "train",
2137
+ "7": "truck",
2138
+ "8": "boat",
2139
+ "9": "traffic light",
2140
+ "10": "fire hydrant",
2141
+ "11": "stop sign",
2142
+ "12": "parking meter",
2143
+ "13": "bench",
2144
+ "14": "bird",
2145
+ "15": "cat",
2146
+ "16": "dog",
2147
+ "17": "horse",
2148
+ "18": "sheep",
2149
+ "19": "cow",
2150
+ "20": "elephant",
2151
+ "21": "bear",
2152
+ "22": "zebra",
2153
+ "23": "giraffe",
2154
+ "24": "backpack",
2155
+ "25": "umbrella",
2156
+ "26": "handbag",
2157
+ "27": "tie",
2158
+ "28": "suitcase",
2159
+ "29": "frisbee",
2160
+ "30": "skis",
2161
+ "31": "snowboard",
2162
+ "32": "sports ball",
2163
+ "33": "kite",
2164
+ "34": "baseball bat",
2165
+ "35": "baseball glove",
2166
+ "36": "skateboard",
2167
+ "37": "surfboard",
2168
+ "38": "tennis racket",
2169
+ "39": "bottle",
2170
+ "40": "wine glass",
2171
+ "41": "cup",
2172
+ "42": "fork",
2173
+ "43": "knife",
2174
+ "44": "spoon",
2175
+ "45": "bowl",
2176
+ "46": "banana",
2177
+ "47": "apple",
2178
+ "48": "sandwich",
2179
+ "49": "orange",
2180
+ "50": "broccoli",
2181
+ "51": "carrot",
2182
+ "52": "hot dog",
2183
+ "53": "pizza",
2184
+ "54": "donut",
2185
+ "55": "cake",
2186
+ "56": "chair",
2187
+ "57": "sofa",
2188
+ "58": "pottedplant",
2189
+ "59": "bed",
2190
+ "60": "diningtable",
2191
+ "61": "toilet",
2192
+ "62": "tvmonitor",
2193
+ "63": "laptop",
2194
+ "64": "mouse",
2195
+ "65": "remote",
2196
+ "66": "keyboard",
2197
+ "67": "cell phone",
2198
+ "68": "microwave",
2199
+ "69": "oven",
2200
+ "70": "toaster",
2201
+ "71": "sink",
2202
+ "72": "refrigerator",
2203
+ "73": "book",
2204
+ "74": "clock",
2205
+ "75": "vase",
2206
+ "76": "scissors",
2207
+ "77": "teddy bear",
2208
+ "78": "hair drier",
2209
+ "79": "toothbrush"
2210
+ },
2211
+ "ignore_value": 255,
2212
+ "importance_sample_ratio": 0.75,
2213
+ "init_std": 0.02,
2214
+ "init_xavier_std": 1.0,
2215
+ "label2id": {
2216
+ "aeroplane": 4,
2217
+ "apple": 47,
2218
+ "backpack": 24,
2219
+ "banana": 46,
2220
+ "baseball bat": 34,
2221
+ "baseball glove": 35,
2222
+ "bear": 21,
2223
+ "bed": 59,
2224
+ "bench": 13,
2225
+ "bicycle": 1,
2226
+ "bird": 14,
2227
+ "boat": 8,
2228
+ "book": 73,
2229
+ "bottle": 39,
2230
+ "bowl": 45,
2231
+ "broccoli": 50,
2232
+ "bus": 5,
2233
+ "cake": 55,
2234
+ "car": 2,
2235
+ "carrot": 51,
2236
+ "cat": 15,
2237
+ "cell phone": 67,
2238
+ "chair": 56,
2239
+ "clock": 74,
2240
+ "cow": 19,
2241
+ "cup": 41,
2242
+ "diningtable": 60,
2243
+ "dog": 16,
2244
+ "donut": 54,
2245
+ "elephant": 20,
2246
+ "fire hydrant": 10,
2247
+ "fork": 42,
2248
+ "frisbee": 29,
2249
+ "giraffe": 23,
2250
+ "hair drier": 78,
2251
+ "handbag": 26,
2252
+ "horse": 17,
2253
+ "hot dog": 52,
2254
+ "keyboard": 66,
2255
+ "kite": 33,
2256
+ "knife": 43,
2257
+ "laptop": 63,
2258
+ "microwave": 68,
2259
+ "motorbike": 3,
2260
+ "mouse": 64,
2261
+ "orange": 49,
2262
+ "oven": 69,
2263
+ "parking meter": 12,
2264
+ "person": 0,
2265
+ "pizza": 53,
2266
+ "pottedplant": 58,
2267
+ "refrigerator": 72,
2268
+ "remote": 65,
2269
+ "sandwich": 48,
2270
+ "scissors": 76,
2271
+ "sheep": 18,
2272
+ "sink": 71,
2273
+ "skateboard": 36,
2274
+ "skis": 30,
2275
+ "snowboard": 31,
2276
+ "sofa": 57,
2277
+ "spoon": 44,
2278
+ "sports ball": 32,
2279
+ "stop sign": 11,
2280
+ "suitcase": 28,
2281
+ "surfboard": 37,
2282
+ "teddy bear": 77,
2283
+ "tennis racket": 38,
2284
+ "tie": 27,
2285
+ "toaster": 70,
2286
+ "toilet": 61,
2287
+ "toothbrush": 79,
2288
+ "traffic light": 9,
2289
+ "train": 6,
2290
+ "truck": 7,
2291
+ "tvmonitor": 62,
2292
+ "umbrella": 25,
2293
+ "vase": 75,
2294
+ "wine glass": 40,
2295
+ "zebra": 22
2296
+ },
2297
+ "mask_feature_size": 256,
2298
+ "mask_weight": 5.0,
2299
+ "model_type": "mask2former",
2300
+ "no_object_weight": 0.1,
2301
+ "num_attention_heads": 8,
2302
+ "num_hidden_layers": 10,
2303
+ "num_queries": 100,
2304
+ "output_auxiliary_logits": null,
2305
+ "oversample_ratio": 3.0,
2306
+ "pre_norm": false,
2307
+ "torch_dtype": "float32",
2308
+ "train_num_points": 12544,
2309
+ "transformers_version": null,
2310
+ "use_auxiliary_loss": true
2311
+ }
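The config above packs a Swin backbone (carrying the full ImageNet-1k label map) inside a Mask2Former segmentation config (COCO-style 80-class `id2label`/`label2id`). It is consumed by `external/human_matting/stylematte.py` below, which keeps only the pixel-level module. A minimal loading sketch, assuming `transformers` is installed and the file lives at `configs/stylematte_config.json` as added in this commit:

```python
# Sketch: load the JSON config added above and build the pixel-level module
# that StyleMatte relies on. Mirrors the call in external/human_matting/stylematte.py;
# weights stay randomly initialized until the matting checkpoint is loaded.
from transformers import Mask2FormerForUniversalSegmentation
from transformers.models.mask2former.configuration_mask2former import Mask2FormerConfig

config = Mask2FormerConfig.from_json_file('./configs/stylematte_config.json')
model = Mask2FormerForUniversalSegmentation(config)
pixel_level_module = model.base_model.pixel_level_module   # Swin encoder + pixel decoder
print(config.num_queries, len(config.id2label))             # 100 queries, 80 COCO classes
```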
external/human_matting/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .matting_engine import StyleMatteEngine
external/human_matting/matting_engine.py ADDED
@@ -0,0 +1,66 @@
1
+ import os
2
+ import torch
3
+ import inspect
4
+ import warnings
5
+ import torchvision
6
+ from .stylematte import StyleMatte
7
+
8
+ class StyleMatteEngine(torch.nn.Module):
9
+ def __init__(self, device='cpu', human_matting_path='./pretrain_model/matting/stylematte_synth.pt'):
10
+ super().__init__()
11
+ self._device = device
12
+ self.normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
13
+ self._init_models(human_matting_path)
14
+
15
+ def _init_models(self, _ckpt_path):
16
+ # load dict
17
+ state_dict = torch.load(_ckpt_path, map_location='cpu')
18
+ # build model
19
+ model = StyleMatte()
20
+ model.load_state_dict(state_dict)
21
+ self.model = model.to(self._device).eval()
22
+
23
+ @torch.no_grad()
24
+ def forward(self, input_image, return_type='matting', background_rgb=1.0):
25
+ if not hasattr(self, 'model'):
26
+ self._init_models()
27
+ if input_image.max() > 2.0:
28
+ warnings.warn('Image should be normalized to [0, 1].')
29
+ _, ori_h, ori_w = input_image.shape
30
+ input_image = input_image.to(self._device).float()
31
+ image = input_image.clone()
32
+ # resize
33
+ if max(ori_h, ori_w) > 1024:
34
+ scale = 1024.0 / max(ori_h, ori_w)
35
+ resized_h, resized_w = int(ori_h * scale), int(ori_w * scale)
36
+ image = torchvision.transforms.functional.resize(image, (resized_h, resized_w), antialias=True)
37
+ else:
38
+ resized_h, resized_w = ori_h, ori_w
39
+ # padding
40
+ if resized_h % 8 != 0 or resized_w % 8 != 0:
41
+ image = torchvision.transforms.functional.pad(image, ((8-resized_w % 8)%8, (8-resized_h % 8)%8, 0, 0, ), padding_mode='reflect')
42
+ # normalize and forwarding
43
+ image = self.normalize(image)[None]
44
+ predict = self.model(image)[0]
45
+ # undo padding
46
+ predict = predict[:, -resized_h:, -resized_w:]
47
+ # undo resize
48
+ if resized_h != ori_h or resized_w != ori_w:
49
+ predict = torchvision.transforms.functional.resize(predict, (ori_h, ori_w), antialias=True)
50
+
51
+ if return_type == 'alpha':
52
+ return predict[0]
53
+ elif return_type == 'matting':
54
+ predict = predict.expand(3, -1, -1)
55
+ matting_image = input_image.clone()
56
+ background_rgb = matting_image.new_ones(matting_image.shape) * background_rgb
57
+ matting_image = matting_image * predict + (1-predict) * background_rgb
58
+ return matting_image, predict[0]
59
+ elif return_type == 'all':
60
+ predict = predict.expand(3, -1, -1)
61
+ background_rgb = input_image.new_ones(input_image.shape) * background_rgb
62
+ foreground_image = input_image * predict + (1-predict) * background_rgb
63
+ background_image = input_image * (1-predict) + predict * background_rgb
64
+ return foreground_image, background_image
65
+ else:
66
+ raise NotImplementedError
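`StyleMatteEngine` above resizes the input so its long side is at most 1024 px, reflection-pads to a multiple of 8, runs `StyleMatte`, and undoes both steps before compositing against `background_rgb`. A usage sketch; the image path is a placeholder and the checkpoint path is the constructor default, which must exist locally:

```python
# Usage sketch for StyleMatteEngine (image path is illustrative; the checkpoint
# path defaults to ./pretrain_model/matting/stylematte_synth.pt).
import torch
import torchvision
from external.human_matting import StyleMatteEngine

device = 'cuda' if torch.cuda.is_available() else 'cpu'
engine = StyleMatteEngine(device=device)

# (3, H, W) float tensor in [0, 1], as the forward pass expects
image = torchvision.io.read_image('input.png', mode=torchvision.io.ImageReadMode.RGB).float() / 255.0
matted, alpha = engine(image, return_type='matting', background_rgb=1.0)   # composite on white
torchvision.utils.save_image(matted, 'foreground_on_white.png')
torchvision.utils.save_image(alpha.unsqueeze(0), 'alpha.png')
```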
external/human_matting/stylematte.py ADDED
@@ -0,0 +1,272 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from transformers import Mask2FormerForUniversalSegmentation
6
+ from transformers.models.mask2former.configuration_mask2former import Mask2FormerConfig
7
+
8
+ class StyleMatte(nn.Module):
9
+ def __init__(self):
10
+ super(StyleMatte, self).__init__()
11
+ self.fpn = FPN_fuse(feature_channels=[256, 256, 256, 256], fpn_out=256)
12
+ config = Mask2FormerConfig.from_json_file('./configs/stylematte_config.json')
13
+ self.pixel_decoder = Mask2FormerForUniversalSegmentation(config).base_model.pixel_level_module
14
+ self.fgf = FastGuidedFilter(eps=1e-4)
15
+ self.conv = nn.Conv2d(256, 1, kernel_size=3, padding=1)
16
+
17
+ def forward(self, image, normalize=False):
18
+ decoder_out = self.pixel_decoder(image)
19
+ decoder_states = list(decoder_out.decoder_hidden_states)
20
+ decoder_states.append(decoder_out.decoder_last_hidden_state)
21
+ out_pure = self.fpn(decoder_states)
22
+
23
+ image_lr = nn.functional.interpolate(image.mean(1, keepdim=True),
24
+ scale_factor=0.25,
25
+ mode='bicubic',
26
+ align_corners=True
27
+ )
28
+ out = self.conv(out_pure)
29
+ out = self.fgf(image_lr, out, image.mean(1, keepdim=True))
30
+
31
+ return torch.sigmoid(out)
32
+
33
+ def get_training_params(self):
34
+ return list(self.fpn.parameters())+list(self.conv.parameters())
35
+
36
+
37
+ def conv2d_relu(input_filters, output_filters, kernel_size=3, bias=True):
38
+ return nn.Sequential(
39
+ nn.Conv2d(input_filters, output_filters,
40
+ kernel_size=kernel_size, padding=kernel_size//2, bias=bias),
41
+ nn.LeakyReLU(0.2, inplace=True),
42
+ nn.BatchNorm2d(output_filters)
43
+ )
44
+
45
+
46
+ def up_and_add(x, y):
47
+ return F.interpolate(x, size=(y.size(2), y.size(3)), mode='bilinear', align_corners=True) + y
48
+
49
+
50
+ class FPN_fuse(nn.Module):
51
+ def __init__(self, feature_channels=[256, 512, 1024, 2048], fpn_out=256):
52
+ super(FPN_fuse, self).__init__()
53
+ assert feature_channels[0] == fpn_out
54
+ self.conv1x1 = nn.ModuleList([nn.Conv2d(ft_size, fpn_out, kernel_size=1)
55
+ for ft_size in feature_channels[1:]])
56
+ self.smooth_conv = nn.ModuleList([nn.Conv2d(fpn_out, fpn_out, kernel_size=3, padding=1)]
57
+ * (len(feature_channels)-1))
58
+ self.conv_fusion = nn.Sequential(
59
+ nn.Conv2d(2*fpn_out, fpn_out, kernel_size=3,
60
+ padding=1, bias=False),
61
+ nn.BatchNorm2d(fpn_out),
62
+ nn.ReLU(inplace=True),
63
+ )
64
+
65
+ def forward(self, features):
66
+
67
+ features[:-1] = [conv1x1(feature) for feature,
68
+ conv1x1 in zip(features[:-1], self.conv1x1)]
69
+ feature = up_and_add(self.smooth_conv[0](features[0]), features[1])
70
+ feature = up_and_add(self.smooth_conv[1](feature), features[2])
71
+ feature = up_and_add(self.smooth_conv[2](feature), features[3])
72
+
73
+ H, W = features[-1].size(2), features[-1].size(3)
74
+ x = [feature, features[-1]]
75
+ x = [F.interpolate(x_el, size=(H, W), mode='bilinear',
76
+ align_corners=True) for x_el in x]
77
+
78
+ x = self.conv_fusion(torch.cat(x, dim=1))
79
+
80
+ return x
81
+
82
+
83
+ class PSPModule(nn.Module):
84
+ # In the original implementation they use precise RoI pooling
85
+ # Instead of using adaptive average pooling
86
+ def __init__(self, in_channels, bin_sizes=[1, 2, 4, 6]):
87
+ super(PSPModule, self).__init__()
88
+ out_channels = in_channels // len(bin_sizes)
89
+ self.stages = nn.ModuleList([self._make_stages(in_channels, out_channels, b_s)
90
+ for b_s in bin_sizes])
91
+ self.bottleneck = nn.Sequential(
92
+ nn.Conv2d(in_channels+(out_channels * len(bin_sizes)), in_channels,
93
+ kernel_size=3, padding=1, bias=False),
94
+ nn.BatchNorm2d(in_channels),
95
+ nn.ReLU(inplace=True),
96
+ nn.Dropout2d(0.1)
97
+ )
98
+
99
+ def _make_stages(self, in_channels, out_channels, bin_sz):
100
+ prior = nn.AdaptiveAvgPool2d(output_size=bin_sz)
101
+ conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
102
+ bn = nn.BatchNorm2d(out_channels)
103
+ relu = nn.ReLU(inplace=True)
104
+ return nn.Sequential(prior, conv, bn, relu)
105
+
106
+ def forward(self, features):
107
+ h, w = features.size()[2], features.size()[3]
108
+ pyramids = [features]
109
+ pyramids.extend([F.interpolate(stage(features), size=(h, w), mode='bilinear',
110
+ align_corners=True) for stage in self.stages])
111
+ output = self.bottleneck(torch.cat(pyramids, dim=1))
112
+ return output
113
+
114
+
115
+ class GuidedFilter(nn.Module):
116
+ def __init__(self, r, eps=1e-8):
117
+ super(GuidedFilter, self).__init__()
118
+
119
+ self.r = r
120
+ self.eps = eps
121
+ self.boxfilter = BoxFilter(r)
122
+
123
+ def forward(self, x, y):
124
+ n_x, c_x, h_x, w_x = x.size()
125
+ n_y, c_y, h_y, w_y = y.size()
126
+
127
+ assert n_x == n_y
128
+ assert c_x == 1 or c_x == c_y
129
+ assert h_x == h_y and w_x == w_y
130
+ assert h_x > 2 * self.r + 1 and w_x > 2 * self.r + 1
131
+
132
+ # N
133
+ N = self.boxfilter((x.data.new().resize_((1, 1, h_x, w_x)).fill_(1.0)))
134
+
135
+ # mean_x
136
+ mean_x = self.boxfilter(x) / N
137
+ # mean_y
138
+ mean_y = self.boxfilter(y) / N
139
+ # cov_xy
140
+ cov_xy = self.boxfilter(x * y) / N - mean_x * mean_y
141
+ # var_x
142
+ var_x = self.boxfilter(x * x) / N - mean_x * mean_x
143
+
144
+ # A
145
+ A = cov_xy / (var_x + self.eps)
146
+ # b
147
+ b = mean_y - A * mean_x
148
+
149
+ # mean_A; mean_b
150
+ mean_A = self.boxfilter(A) / N
151
+ mean_b = self.boxfilter(b) / N
152
+
153
+ return mean_A * x + mean_b
154
+
155
+
156
+ class FastGuidedFilter(nn.Module):
157
+ def __init__(self, r=1, eps=1e-8):
158
+ super(FastGuidedFilter, self).__init__()
159
+
160
+ self.r = r
161
+ self.eps = eps
162
+ self.boxfilter = BoxFilter(r)
163
+
164
+ def forward(self, lr_x, lr_y, hr_x):
165
+ n_lrx, c_lrx, h_lrx, w_lrx = lr_x.size()
166
+ n_lry, c_lry, h_lry, w_lry = lr_y.size()
167
+ n_hrx, c_hrx, h_hrx, w_hrx = hr_x.size()
168
+
169
+ assert n_lrx == n_lry and n_lry == n_hrx
170
+ assert c_lrx == c_hrx and (c_lrx == 1 or c_lrx == c_lry)
171
+ assert h_lrx == h_lry and w_lrx == w_lry
172
+ assert h_lrx > 2*self.r+1 and w_lrx > 2*self.r+1
173
+
174
+ # N
175
+ N = self.boxfilter(lr_x.new().resize_((1, 1, h_lrx, w_lrx)).fill_(1.0))
176
+
177
+ # mean_x
178
+ mean_x = self.boxfilter(lr_x) / N
179
+ # mean_y
180
+ mean_y = self.boxfilter(lr_y) / N
181
+ # cov_xy
182
+ cov_xy = self.boxfilter(lr_x * lr_y) / N - mean_x * mean_y
183
+ # var_x
184
+ var_x = self.boxfilter(lr_x * lr_x) / N - mean_x * mean_x
185
+
186
+ # A
187
+ A = cov_xy / (var_x + self.eps)
188
+ # b
189
+ b = mean_y - A * mean_x
190
+
191
+ # mean_A; mean_b
192
+ mean_A = F.interpolate(
193
+ A, (h_hrx, w_hrx), mode='bilinear', align_corners=True)
194
+ mean_b = F.interpolate(
195
+ b, (h_hrx, w_hrx), mode='bilinear', align_corners=True)
196
+
197
+ return mean_A*hr_x+mean_b
198
+
199
+
200
+ class DeepGuidedFilterRefiner(nn.Module):
201
+ def __init__(self, hid_channels=16):
202
+ super().__init__()
203
+ self.box_filter = nn.Conv2d(
204
+ 4, 4, kernel_size=3, padding=1, bias=False, groups=4)
205
+ self.box_filter.weight.data[...] = 1 / 9
206
+ self.conv = nn.Sequential(
207
+ nn.Conv2d(4 * 2 + hid_channels, hid_channels,
208
+ kernel_size=1, bias=False),
209
+ nn.BatchNorm2d(hid_channels),
210
+ nn.ReLU(True),
211
+ nn.Conv2d(hid_channels, hid_channels, kernel_size=1, bias=False),
212
+ nn.BatchNorm2d(hid_channels),
213
+ nn.ReLU(True),
214
+ nn.Conv2d(hid_channels, 4, kernel_size=1, bias=True)
215
+ )
216
+
217
+ def forward(self, fine_src, base_src, base_fgr, base_pha, base_hid):
218
+ fine_x = torch.cat([fine_src, fine_src.mean(1, keepdim=True)], dim=1)
219
+ base_x = torch.cat([base_src, base_src.mean(1, keepdim=True)], dim=1)
220
+ base_y = torch.cat([base_fgr, base_pha], dim=1)
221
+
222
+ mean_x = self.box_filter(base_x)
223
+ mean_y = self.box_filter(base_y)
224
+ cov_xy = self.box_filter(base_x * base_y) - mean_x * mean_y
225
+ var_x = self.box_filter(base_x * base_x) - mean_x * mean_x
226
+
227
+ A = self.conv(torch.cat([cov_xy, var_x, base_hid], dim=1))
228
+ b = mean_y - A * mean_x
229
+
230
+ H, W = fine_src.shape[2:]
231
+ A = F.interpolate(A, (H, W), mode='bilinear', align_corners=False)
232
+ b = F.interpolate(b, (H, W), mode='bilinear', align_corners=False)
233
+
234
+ out = A * fine_x + b
235
+ fgr, pha = out.split([3, 1], dim=1)
236
+ return fgr, pha
237
+
238
+
239
+ def diff_x(input, r):
240
+ assert input.dim() == 4
241
+
242
+ left = input[:, :, r:2 * r + 1]
243
+ middle = input[:, :, 2 * r + 1:] - input[:, :, :-2 * r - 1]
244
+ right = input[:, :, -1:] - input[:, :, -2 * r - 1: -r - 1]
245
+
246
+ output = torch.cat([left, middle, right], dim=2)
247
+
248
+ return output
249
+
250
+
251
+ def diff_y(input, r):
252
+ assert input.dim() == 4
253
+
254
+ left = input[:, :, :, r:2 * r + 1]
255
+ middle = input[:, :, :, 2 * r + 1:] - input[:, :, :, :-2 * r - 1]
256
+ right = input[:, :, :, -1:] - input[:, :, :, -2 * r - 1: -r - 1]
257
+
258
+ output = torch.cat([left, middle, right], dim=3)
259
+
260
+ return output
261
+
262
+
263
+ class BoxFilter(nn.Module):
264
+ def __init__(self, r):
265
+ super(BoxFilter, self).__init__()
266
+
267
+ self.r = r
268
+
269
+ def forward(self, x):
270
+ assert x.dim() == 4
271
+
272
+ return diff_y(diff_x(x.cumsum(dim=2), self.r).cumsum(dim=3), self.r)
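`diff_x`/`diff_y` above convert per-axis cumulative sums into sliding-window sums, so `BoxFilter` computes an O(1)-per-pixel box sum with windows truncated at the image borders. A small self-check, not part of this commit, comparing it against an equivalent zero-padded convolution with a ones kernel:

```python
# Self-check sketch: the cumsum-based BoxFilter above should agree with a
# brute-force (2r+1) x (2r+1) box *sum* computed by zero-padded convolution.
import torch
import torch.nn.functional as F
from external.human_matting.stylematte import BoxFilter

r = 3
x = torch.rand(2, 1, 32, 32)
fast = BoxFilter(r)(x)
kernel = torch.ones(1, 1, 2 * r + 1, 2 * r + 1)
naive = F.conv2d(x, kernel, padding=r)
print(torch.allclose(fast, naive, atol=1e-4))   # expected: True
```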
external/landmark_detection/FaceBoxesV2/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from . import detector
2
+ from . import faceboxes_detector
external/landmark_detection/FaceBoxesV2/detector.py ADDED
@@ -0,0 +1,39 @@
1
+ import cv2
2
+
3
+ class Detector(object):
4
+ def __init__(self, model_arch, model_weights):
5
+ self.model_arch = model_arch
6
+ self.model_weights = model_weights
7
+
8
+ def detect(self, image, thresh):
9
+ raise NotImplementedError
10
+
11
+ def crop(self, image, detections):
12
+ crops = []
13
+ for det in detections:
14
+ xmin = max(det[2], 0)
15
+ ymin = max(det[3], 0)
16
+ width = det[4]
17
+ height = det[5]
18
+ xmax = min(xmin+width, image.shape[1])
19
+ ymax = min(ymin+height, image.shape[0])
20
+ cut = image[ymin:ymax, xmin:xmax,:]
21
+ crops.append(cut)
22
+
23
+ return crops
24
+
25
+ def draw(self, image, detections, im_scale=None):
26
+ if im_scale is not None:
27
+ image = cv2.resize(image, None, None, fx=im_scale, fy=im_scale, interpolation=cv2.INTER_LINEAR)
28
+ detections = [[det[0],det[1],int(det[2]*im_scale),int(det[3]*im_scale),int(det[4]*im_scale),int(det[5]*im_scale)] for det in detections]
29
+
30
+ for det in detections:
31
+ xmin = det[2]
32
+ ymin = det[3]
33
+ width = det[4]
34
+ height = det[5]
35
+ xmax = xmin + width
36
+ ymax = ymin + height
37
+ cv2.rectangle(image, (xmin, ymin), (xmax, ymax), (0, 0, 255), 2)
38
+
39
+ return image
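`crop` and `draw` above both expect each detection as `[label, score, xmin, ymin, width, height]` in pixel coordinates of the original image; `draw` optionally rescales the image and the boxes together. A tiny illustration with a hand-made detection (not part of the commit); note that importing the package runs `FaceBoxesV2/__init__.py`, which pulls in the full detector and therefore needs the Cython NMS extension built via `utils/make.sh`:

```python
# Illustration of the detection-list contract used by crop()/draw().
import numpy as np
from external.landmark_detection.FaceBoxesV2.detector import Detector

image = np.zeros((240, 320, 3), dtype=np.uint8)
detections = [['face', 0.99, 50, 40, 100, 120]]            # label, score, xmin, ymin, w, h

base = Detector(model_arch=None, model_weights=None)        # detect() is left to subclasses
print(base.crop(image, detections)[0].shape)                # (120, 100, 3)
print(base.draw(image.copy(), detections, im_scale=0.5).shape)   # (120, 160, 3)
```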
external/landmark_detection/FaceBoxesV2/faceboxes_detector.py ADDED
@@ -0,0 +1,97 @@
1
+ from .detector import Detector
2
+ import cv2, os
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from .utils.config import cfg
7
+ from .utils.prior_box import PriorBox
8
+ from .utils.nms_wrapper import nms
9
+ from .utils.faceboxes import FaceBoxesV2
10
+ from .utils.box_utils import decode
11
+ import time
12
+
13
+ class FaceBoxesDetector(Detector):
14
+ def __init__(self, model_arch, model_weights, use_gpu, device):
15
+ super().__init__(model_arch, model_weights)
16
+ self.name = 'FaceBoxesDetector'
17
+ self.net = FaceBoxesV2(phase='test', size=None, num_classes=2) # initialize detector
18
+ self.use_gpu = use_gpu
19
+ self.device = device
20
+
21
+ state_dict = torch.load(self.model_weights, map_location=self.device)
22
+ # create new OrderedDict that does not contain `module.`
23
+ from collections import OrderedDict
24
+ new_state_dict = OrderedDict()
25
+ for k, v in state_dict.items():
26
+ name = k[7:] # remove `module.`
27
+ new_state_dict[name] = v
28
+ # load params
29
+ self.net.load_state_dict(new_state_dict)
30
+ self.net = self.net.to(self.device)
31
+ self.net.eval()
32
+
33
+
34
+ def detect(self, image, thresh=0.6, im_scale=None):
35
+ # auto resize for large images
36
+ if im_scale is None:
37
+ height, width, _ = image.shape
38
+ if min(height, width) > 600:
39
+ im_scale = 600. / min(height, width)
40
+ else:
41
+ im_scale = 1
42
+ image_scale = cv2.resize(image, None, None, fx=im_scale, fy=im_scale, interpolation=cv2.INTER_LINEAR)
43
+
44
+ scale = torch.Tensor([image_scale.shape[1], image_scale.shape[0], image_scale.shape[1], image_scale.shape[0]])
45
+ image_scale = torch.from_numpy(image_scale.transpose(2,0,1)).to(self.device).int()
46
+ mean_tmp = torch.IntTensor([104, 117, 123]).to(self.device)
47
+ mean_tmp = mean_tmp.unsqueeze(1).unsqueeze(2)
48
+ image_scale -= mean_tmp
49
+ image_scale = image_scale.float().unsqueeze(0)
50
+ scale = scale.to(self.device)
51
+
52
+ with torch.no_grad():
53
+ out = self.net(image_scale)
54
+ #priorbox = PriorBox(cfg, out[2], (image_scale.size()[2], image_scale.size()[3]), phase='test')
55
+ priorbox = PriorBox(cfg, image_size=(image_scale.size()[2], image_scale.size()[3]))
56
+ priors = priorbox.forward()
57
+ priors = priors.to(self.device)
58
+ loc, conf = out
59
+ prior_data = priors.data
60
+ boxes = decode(loc.data.squeeze(0), prior_data, cfg['variance'])
61
+ boxes = boxes * scale
62
+ boxes = boxes.cpu().numpy()
63
+ scores = conf.data.cpu().numpy()[:, 1]
64
+
65
+ # ignore low scores
66
+ inds = np.where(scores > thresh)[0]
67
+ boxes = boxes[inds]
68
+ scores = scores[inds]
69
+
70
+ # keep top-K before NMS
71
+ order = scores.argsort()[::-1][:5000]
72
+ boxes = boxes[order]
73
+ scores = scores[order]
74
+
75
+ # do NMS
76
+ dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
77
+ keep = nms(dets, 0.3)
78
+ dets = dets[keep, :]
79
+
80
+ dets = dets[:750, :]
81
+ detections_scale = []
82
+ for i in range(dets.shape[0]):
83
+ xmin = int(dets[i][0])
84
+ ymin = int(dets[i][1])
85
+ xmax = int(dets[i][2])
86
+ ymax = int(dets[i][3])
87
+ score = dets[i][4]
88
+ width = xmax - xmin
89
+ height = ymax - ymin
90
+ detections_scale.append(['face', score, xmin, ymin, width, height])
91
+
92
+ # adapt bboxes to the original image size
93
+ if len(detections_scale) > 0:
94
+ detections_scale = [[det[0],det[1],int(det[2]/im_scale),int(det[3]/im_scale),int(det[4]/im_scale),int(det[5]/im_scale)] for det in detections_scale]
95
+
96
+ return detections_scale, im_scale
97
+
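`detect` above returns `(detections, im_scale)`, with each detection as `['face', score, xmin, ymin, width, height]` mapped back to original-image pixels (inputs whose short side exceeds 600 px are downscaled internally). A usage sketch; the weights path is a placeholder, and the Cython NMS extension under `utils/` must be compiled first (see `utils/make.sh`):

```python
# Usage sketch (weights path is a placeholder; requires the compiled NMS extension).
import cv2
import torch
from external.landmark_detection.FaceBoxesV2.faceboxes_detector import FaceBoxesDetector

use_gpu = torch.cuda.is_available()
device = torch.device('cuda:0' if use_gpu else 'cpu')
detector = FaceBoxesDetector('FaceBoxes', './pretrain_model/faceboxesv2.pth', use_gpu, device)

image = cv2.imread('input.png')                    # BGR, matching the [104, 117, 123] mean
detections, im_scale = detector.detect(image, thresh=0.6)
for label, score, xmin, ymin, w, h in detections:
    print(f'{label} {score:.2f} at ({xmin}, {ymin}, {w}, {h})')
cv2.imwrite('vis.png', detector.draw(image, detections))
```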
external/landmark_detection/FaceBoxesV2/utils/__init__.py ADDED
File without changes
external/landmark_detection/FaceBoxesV2/utils/box_utils.py ADDED
@@ -0,0 +1,276 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ def point_form(boxes):
6
+ """ Convert prior_boxes to (xmin, ymin, xmax, ymax)
7
+ representation for comparison to point form ground truth data.
8
+ Args:
9
+ boxes: (tensor) center-size default boxes from priorbox layers.
10
+ Return:
11
+ boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.
12
+ """
13
+ return torch.cat((boxes[:, :2] - boxes[:, 2:]/2, # xmin, ymin
14
+ boxes[:, :2] + boxes[:, 2:]/2), 1) # xmax, ymax
15
+
16
+
17
+ def center_size(boxes):
18
+ """ Convert prior_boxes to (cx, cy, w, h)
19
+ representation for comparison to center-size form ground truth data.
20
+ Args:
21
+ boxes: (tensor) point_form boxes
22
+ Return:
23
+ boxes: (tensor) Converted cx, cy, w, h form of boxes.
24
+ """
25
+ return torch.cat(((boxes[:, 2:] + boxes[:, :2])/2, # cx, cy
26
+ boxes[:, 2:] - boxes[:, :2]), 1) # w, h
27
+
28
+
29
+ def intersect(box_a, box_b):
30
+ """ We resize both tensors to [A,B,2] without new malloc:
31
+ [A,2] -> [A,1,2] -> [A,B,2]
32
+ [B,2] -> [1,B,2] -> [A,B,2]
33
+ Then we compute the area of intersect between box_a and box_b.
34
+ Args:
35
+ box_a: (tensor) bounding boxes, Shape: [A,4].
36
+ box_b: (tensor) bounding boxes, Shape: [B,4].
37
+ Return:
38
+ (tensor) intersection area, Shape: [A,B].
39
+ """
40
+ A = box_a.size(0)
41
+ B = box_b.size(0)
42
+ max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2),
43
+ box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
44
+ min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2),
45
+ box_b[:, :2].unsqueeze(0).expand(A, B, 2))
46
+ inter = torch.clamp((max_xy - min_xy), min=0)
47
+ return inter[:, :, 0] * inter[:, :, 1]
48
+
49
+
50
+ def jaccard(box_a, box_b):
51
+ """Compute the jaccard overlap of two sets of boxes. The jaccard overlap
52
+ is simply the intersection over union of two boxes. Here we operate on
53
+ ground truth boxes and default boxes.
54
+ E.g.:
55
+ A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
56
+ Args:
57
+ box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4]
58
+ box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4]
59
+ Return:
60
+ jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)]
61
+ """
62
+ inter = intersect(box_a, box_b)
63
+ area_a = ((box_a[:, 2]-box_a[:, 0]) *
64
+ (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B]
65
+ area_b = ((box_b[:, 2]-box_b[:, 0]) *
66
+ (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B]
67
+ union = area_a + area_b - inter
68
+ return inter / union # [A,B]
69
+
70
+
71
+ def matrix_iou(a, b):
72
+ """
73
+ return iou of a and b, numpy version for data augmentation
74
+ """
75
+ lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
76
+ rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
77
+
78
+ area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
79
+ area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
80
+ area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
81
+ return area_i / (area_a[:, np.newaxis] + area_b - area_i)
82
+
83
+
84
+ def matrix_iof(a, b):
85
+ """
86
+ return iof of a and b, numpy version for data augmentation
87
+ """
88
+ lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
89
+ rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
90
+
91
+ area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
92
+ area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
93
+ return area_i / np.maximum(area_a[:, np.newaxis], 1)
94
+
95
+
96
+ def match(threshold, truths, priors, variances, labels, loc_t, conf_t, idx):
97
+ """Match each prior box with the ground truth box of the highest jaccard
98
+ overlap, encode the bounding boxes, then return the matched indices
99
+ corresponding to both confidence and location preds.
100
+ Args:
101
+ threshold: (float) The overlap threshold used when matching boxes.
102
+ truths: (tensor) Ground truth boxes, Shape: [num_obj, num_priors].
103
+ priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4].
104
+ variances: (tensor) Variances corresponding to each prior coord,
105
+ Shape: [num_priors, 4].
106
+ labels: (tensor) All the class labels for the image, Shape: [num_obj].
107
+ loc_t: (tensor) Tensor to be filled w/ encoded location targets.
108
+ conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds.
109
+ idx: (int) current batch index
110
+ Return:
111
+ The matched indices corresponding to 1)location and 2)confidence preds.
112
+ """
113
+ # jaccard index
114
+ overlaps = jaccard(
115
+ truths,
116
+ point_form(priors)
117
+ )
118
+ # (Bipartite Matching)
119
+ # [1,num_objects] best prior for each ground truth
120
+ best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)
121
+
122
+ # ignore hard gt
123
+ valid_gt_idx = best_prior_overlap[:, 0] >= 0.2
124
+ best_prior_idx_filter = best_prior_idx[valid_gt_idx, :]
125
+ if best_prior_idx_filter.shape[0] <= 0:
126
+ loc_t[idx] = 0
127
+ conf_t[idx] = 0
128
+ return
129
+
130
+ # [1,num_priors] best ground truth for each prior
131
+ best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
132
+ best_truth_idx.squeeze_(0)
133
+ best_truth_overlap.squeeze_(0)
134
+ best_prior_idx.squeeze_(1)
135
+ best_prior_idx_filter.squeeze_(1)
136
+ best_prior_overlap.squeeze_(1)
137
+ best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2) # ensure best prior
138
+ # TODO refactor: index best_prior_idx with long tensor
139
+ # ensure every gt matches with its prior of max overlap
140
+ for j in range(best_prior_idx.size(0)):
141
+ best_truth_idx[best_prior_idx[j]] = j
142
+ matches = truths[best_truth_idx] # Shape: [num_priors,4]
143
+ conf = labels[best_truth_idx] # Shape: [num_priors]
144
+ conf[best_truth_overlap < threshold] = 0 # label as background
145
+ loc = encode(matches, priors, variances)
146
+ loc_t[idx] = loc # [num_priors,4] encoded offsets to learn
147
+ conf_t[idx] = conf # [num_priors] top class label for each prior
148
+
149
+
150
+ def encode(matched, priors, variances):
151
+ """Encode the variances from the priorbox layers into the ground truth boxes
152
+ we have matched (based on jaccard overlap) with the prior boxes.
153
+ Args:
154
+ matched: (tensor) Coords of ground truth for each prior in point-form
155
+ Shape: [num_priors, 4].
156
+ priors: (tensor) Prior boxes in center-offset form
157
+ Shape: [num_priors,4].
158
+ variances: (list[float]) Variances of priorboxes
159
+ Return:
160
+ encoded boxes (tensor), Shape: [num_priors, 4]
161
+ """
162
+
163
+ # dist b/t match center and prior's center
164
+ g_cxcy = (matched[:, :2] + matched[:, 2:])/2 - priors[:, :2]
165
+ # encode variance
166
+ g_cxcy /= (variances[0] * priors[:, 2:])
167
+ # match wh / prior wh
168
+ g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
169
+ g_wh = torch.log(g_wh) / variances[1]
170
+ # return target for smooth_l1_loss
171
+ return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
172
+
173
+
174
+ # Adapted from https://github.com/Hakuyume/chainer-ssd
175
+ def decode(loc, priors, variances):
176
+ """Decode locations from predictions using priors to undo
177
+ the encoding we did for offset regression at train time.
178
+ Args:
179
+ loc (tensor): location predictions for loc layers,
180
+ Shape: [num_priors,4]
181
+ priors (tensor): Prior boxes in center-offset form.
182
+ Shape: [num_priors,4].
183
+ variances: (list[float]) Variances of priorboxes
184
+ Return:
185
+ decoded bounding box predictions
186
+ """
187
+
188
+ boxes = torch.cat((
189
+ priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
190
+ priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
191
+ boxes[:, :2] -= boxes[:, 2:] / 2
192
+ boxes[:, 2:] += boxes[:, :2]
193
+ return boxes
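+
+ # Round-trip sketch (added for clarity, not in the original file): decode() inverts
+ # encode(), so for matched ground-truth boxes in point form and priors in
+ # center-offset form, decode(encode(matched, priors, variances), priors, variances)
+ # reproduces `matched` up to floating-point error.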
194
+
195
+
196
+ def log_sum_exp(x):
197
+ """Utility function for computing log_sum_exp while determining
198
+ This will be used to determine unaveraged confidence loss across
199
+ all examples in a batch.
200
+ Args:
201
+ x (Variable(tensor)): conf_preds from conf layers
202
+ """
203
+ x_max = x.data.max()
204
+ return torch.log(torch.sum(torch.exp(x-x_max), 1, keepdim=True)) + x_max
205
+
206
+
207
+ # Original author: Francisco Massa:
208
+ # https://github.com/fmassa/object-detection.torch
209
+ # Ported to PyTorch by Max deGroot (02/01/2017)
210
+ def nms(boxes, scores, overlap=0.5, top_k=200):
211
+ """Apply non-maximum suppression at test time to avoid detecting too many
212
+ overlapping bounding boxes for a given object.
213
+ Args:
214
+ boxes: (tensor) The location preds for the img, Shape: [num_priors,4].
215
+ scores: (tensor) The class prediction scores for the img, Shape: [num_priors].
216
+ overlap: (float) The overlap thresh for suppressing unnecessary boxes.
217
+ top_k: (int) The maximum number of box preds to consider.
218
+ Return:
219
+ The indices of the kept boxes with respect to num_priors.
220
+ """
221
+
222
+ keep = torch.Tensor(scores.size(0)).fill_(0).long()
223
+ if boxes.numel() == 0:
224
+ return keep
225
+ x1 = boxes[:, 0]
226
+ y1 = boxes[:, 1]
227
+ x2 = boxes[:, 2]
228
+ y2 = boxes[:, 3]
229
+ area = torch.mul(x2 - x1, y2 - y1)
230
+ v, idx = scores.sort(0) # sort in ascending order
231
+ # I = I[v >= 0.01]
232
+ idx = idx[-top_k:] # indices of the top-k largest vals
233
+ xx1 = boxes.new()
234
+ yy1 = boxes.new()
235
+ xx2 = boxes.new()
236
+ yy2 = boxes.new()
237
+ w = boxes.new()
238
+ h = boxes.new()
239
+
240
+ # keep = torch.Tensor()
241
+ count = 0
242
+ while idx.numel() > 0:
243
+ i = idx[-1] # index of current largest val
244
+ # keep.append(i)
245
+ keep[count] = i
246
+ count += 1
247
+ if idx.size(0) == 1:
248
+ break
249
+ idx = idx[:-1] # remove kept element from view
250
+ # load bboxes of next highest vals
251
+ torch.index_select(x1, 0, idx, out=xx1)
252
+ torch.index_select(y1, 0, idx, out=yy1)
253
+ torch.index_select(x2, 0, idx, out=xx2)
254
+ torch.index_select(y2, 0, idx, out=yy2)
255
+ # store element-wise max with next highest score
256
+ xx1 = torch.clamp(xx1, min=x1[i])
257
+ yy1 = torch.clamp(yy1, min=y1[i])
258
+ xx2 = torch.clamp(xx2, max=x2[i])
259
+ yy2 = torch.clamp(yy2, max=y2[i])
260
+ w.resize_as_(xx2)
261
+ h.resize_as_(yy2)
262
+ w = xx2 - xx1
263
+ h = yy2 - yy1
264
+ # check sizes of xx1 and xx2.. after each iteration
265
+ w = torch.clamp(w, min=0.0)
266
+ h = torch.clamp(h, min=0.0)
267
+ inter = w*h
268
+ # IoU = i / (area(a) + area(b) - i)
269
+ rem_areas = torch.index_select(area, 0, idx) # load remaining areas)
270
+ union = (rem_areas - inter) + area[i]
271
+ IoU = inter/union # store result in iou
272
+ # keep only elements with an IoU <= overlap
273
+ idx = idx[IoU.le(overlap)]
274
+ return keep, count
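+
+ # Usage sketch (added comment, hypothetical variable names): after decoding predictions,
+ #   ids, count = nms(decoded_boxes, scores, overlap=0.3, top_k=400)
+ #   dets = torch.cat((decoded_boxes[ids[:count]], scores[ids[:count]].unsqueeze(1)), 1)
+ # keeps the highest-scoring, mutually non-overlapping boxes.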
275
+
276
+
external/landmark_detection/FaceBoxesV2/utils/build.py ADDED
@@ -0,0 +1,57 @@
1
+ # coding: utf-8
2
+
3
+ # --------------------------------------------------------
4
+ # Fast R-CNN
5
+ # Copyright (c) 2015 Microsoft
6
+ # Licensed under The MIT License [see LICENSE for details]
7
+ # Written by Ross Girshick
8
+ # --------------------------------------------------------
9
+
10
+ import os
11
+ from os.path import join as pjoin
12
+ import numpy as np
13
+ from distutils.core import setup
14
+ from distutils.extension import Extension
15
+ from Cython.Distutils import build_ext
16
+
17
+
18
+ def find_in_path(name, path):
19
+ "Find a file in a search path"
20
+ # adapted from http://code.activestate.com/recipes/52224-find-a-file-given-a-search-path/
21
+ for dir in path.split(os.pathsep):
22
+ binpath = pjoin(dir, name)
23
+ if os.path.exists(binpath):
24
+ return os.path.abspath(binpath)
25
+ return None
26
+
27
+
28
+ # Obtain the numpy include directory. This logic works across numpy versions.
29
+ try:
30
+ numpy_include = np.get_include()
31
+ except AttributeError:
32
+ numpy_include = np.get_numpy_include()
33
+
34
+
35
+ # run the customize_compiler
36
+ class custom_build_ext(build_ext):
37
+ def build_extensions(self):
38
+ # customize_compiler_for_nvcc(self.compiler)
39
+ build_ext.build_extensions(self)
40
+
41
+
42
+ ext_modules = [
43
+ Extension(
44
+ "nms.cpu_nms",
45
+ ["nms/cpu_nms.pyx"],
46
+ # extra_compile_args={'gcc': ["-Wno-cpp", "-Wno-unused-function"]},
47
+ extra_compile_args=["-Wno-cpp", "-Wno-unused-function"],
48
+ include_dirs=[numpy_include]
49
+ )
50
+ ]
51
+
52
+ setup(
53
+ name='mot_utils',
54
+ ext_modules=ext_modules,
55
+ # inject our custom trigger
56
+ cmdclass={'build_ext': custom_build_ext},
57
+ )
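+
+ # To compile the Cython NMS extension in place (see make.sh), one would run:
+ #   python3 build.py build_ext --inplace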
external/landmark_detection/FaceBoxesV2/utils/config.py ADDED
@@ -0,0 +1,14 @@
1
+ # config.py
2
+
3
+ cfg = {
4
+ 'name': 'FaceBoxes',
5
+ #'min_dim': 1024,
6
+ #'feature_maps': [[32, 32], [16, 16], [8, 8]],
7
+ # 'aspect_ratios': [[1], [1], [1]],
8
+ 'min_sizes': [[32, 64, 128], [256], [512]],
9
+ 'steps': [32, 64, 128],
10
+ 'variance': [0.1, 0.2],
11
+ 'clip': False,
12
+ 'loc_weight': 2.0,
13
+ 'gpu_train': True
14
+ }
external/landmark_detection/FaceBoxesV2/utils/faceboxes.py ADDED
@@ -0,0 +1,239 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class BasicConv2d(nn.Module):
7
+
8
+ def __init__(self, in_channels, out_channels, **kwargs):
9
+ super(BasicConv2d, self).__init__()
10
+ self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
11
+ self.bn = nn.BatchNorm2d(out_channels, eps=1e-5)
12
+
13
+ def forward(self, x):
14
+ x = self.conv(x)
15
+ x = self.bn(x)
16
+ return F.relu(x, inplace=True)
17
+
18
+
19
+ class Inception(nn.Module):
20
+
21
+ def __init__(self):
22
+ super(Inception, self).__init__()
23
+ self.branch1x1 = BasicConv2d(128, 32, kernel_size=1, padding=0)
24
+ self.branch1x1_2 = BasicConv2d(128, 32, kernel_size=1, padding=0)
25
+ self.branch3x3_reduce = BasicConv2d(128, 24, kernel_size=1, padding=0)
26
+ self.branch3x3 = BasicConv2d(24, 32, kernel_size=3, padding=1)
27
+ self.branch3x3_reduce_2 = BasicConv2d(128, 24, kernel_size=1, padding=0)
28
+ self.branch3x3_2 = BasicConv2d(24, 32, kernel_size=3, padding=1)
29
+ self.branch3x3_3 = BasicConv2d(32, 32, kernel_size=3, padding=1)
30
+
31
+ def forward(self, x):
32
+ branch1x1 = self.branch1x1(x)
33
+
34
+ branch1x1_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
35
+ branch1x1_2 = self.branch1x1_2(branch1x1_pool)
36
+
37
+ branch3x3_reduce = self.branch3x3_reduce(x)
38
+ branch3x3 = self.branch3x3(branch3x3_reduce)
39
+
40
+ branch3x3_reduce_2 = self.branch3x3_reduce_2(x)
41
+ branch3x3_2 = self.branch3x3_2(branch3x3_reduce_2)
42
+ branch3x3_3 = self.branch3x3_3(branch3x3_2)
43
+
44
+ outputs = [branch1x1, branch1x1_2, branch3x3, branch3x3_3]
45
+ return torch.cat(outputs, 1)
46
+
47
+
48
+ class CRelu(nn.Module):
49
+
50
+ def __init__(self, in_channels, out_channels, **kwargs):
51
+ super(CRelu, self).__init__()
52
+ self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
53
+ self.bn = nn.BatchNorm2d(out_channels, eps=1e-5)
54
+
55
+ def forward(self, x):
56
+ x = self.conv(x)
57
+ x = self.bn(x)
58
+ x = torch.cat([x, -x], 1)
59
+ x = F.relu(x, inplace=True)
60
+ return x
61
+
62
+
63
+ class FaceBoxes(nn.Module):
64
+
65
+ def __init__(self, phase, size, num_classes):
66
+ super(FaceBoxes, self).__init__()
67
+ self.phase = phase
68
+ self.num_classes = num_classes
69
+ self.size = size
70
+
71
+ self.conv1 = CRelu(3, 24, kernel_size=7, stride=4, padding=3)
72
+ self.conv2 = CRelu(48, 64, kernel_size=5, stride=2, padding=2)
73
+
74
+ self.inception1 = Inception()
75
+ self.inception2 = Inception()
76
+ self.inception3 = Inception()
77
+
78
+ self.conv3_1 = BasicConv2d(128, 128, kernel_size=1, stride=1, padding=0)
79
+ self.conv3_2 = BasicConv2d(128, 256, kernel_size=3, stride=2, padding=1)
80
+
81
+ self.conv4_1 = BasicConv2d(256, 128, kernel_size=1, stride=1, padding=0)
82
+ self.conv4_2 = BasicConv2d(128, 256, kernel_size=3, stride=2, padding=1)
83
+
84
+ self.loc, self.conf = self.multibox(self.num_classes)
85
+
86
+ if self.phase == 'test':
87
+ self.softmax = nn.Softmax(dim=-1)
88
+
89
+ if self.phase == 'train':
90
+ for m in self.modules():
91
+ if isinstance(m, nn.Conv2d):
92
+ if m.bias is not None:
93
+ nn.init.xavier_normal_(m.weight.data)
94
+ m.bias.data.fill_(0.02)
95
+ else:
96
+ m.weight.data.normal_(0, 0.01)
97
+ elif isinstance(m, nn.BatchNorm2d):
98
+ m.weight.data.fill_(1)
99
+ m.bias.data.zero_()
100
+
101
+ def multibox(self, num_classes):
102
+ loc_layers = []
103
+ conf_layers = []
104
+ loc_layers += [nn.Conv2d(128, 21 * 4, kernel_size=3, padding=1)]
105
+ conf_layers += [nn.Conv2d(128, 21 * num_classes, kernel_size=3, padding=1)]
106
+ loc_layers += [nn.Conv2d(256, 1 * 4, kernel_size=3, padding=1)]
107
+ conf_layers += [nn.Conv2d(256, 1 * num_classes, kernel_size=3, padding=1)]
108
+ loc_layers += [nn.Conv2d(256, 1 * 4, kernel_size=3, padding=1)]
109
+ conf_layers += [nn.Conv2d(256, 1 * num_classes, kernel_size=3, padding=1)]
110
+ return nn.Sequential(*loc_layers), nn.Sequential(*conf_layers)
111
+
112
+ def forward(self, x):
113
+
114
+ detection_sources = list()
115
+ loc = list()
116
+ conf = list()
117
+
118
+ x = self.conv1(x)
119
+ x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
120
+ x = self.conv2(x)
121
+ x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
122
+ x = self.inception1(x)
123
+ x = self.inception2(x)
124
+ x = self.inception3(x)
125
+ detection_sources.append(x)
126
+
127
+ x = self.conv3_1(x)
128
+ x = self.conv3_2(x)
129
+ detection_sources.append(x)
130
+
131
+ x = self.conv4_1(x)
132
+ x = self.conv4_2(x)
133
+ detection_sources.append(x)
134
+
135
+ for (x, l, c) in zip(detection_sources, self.loc, self.conf):
136
+ loc.append(l(x).permute(0, 2, 3, 1).contiguous())
137
+ conf.append(c(x).permute(0, 2, 3, 1).contiguous())
138
+
139
+ loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)
140
+ conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)
141
+
142
+ if self.phase == "test":
143
+ output = (loc.view(loc.size(0), -1, 4),
144
+ self.softmax(conf.view(conf.size(0), -1, self.num_classes)))
145
+ else:
146
+ output = (loc.view(loc.size(0), -1, 4),
147
+ conf.view(conf.size(0), -1, self.num_classes))
148
+
149
+ return output
150
+
151
+ class FaceBoxesV2(nn.Module):
152
+
153
+ def __init__(self, phase, size, num_classes):
154
+ super(FaceBoxesV2, self).__init__()
155
+ self.phase = phase
156
+ self.num_classes = num_classes
157
+ self.size = size
158
+
159
+ self.conv1 = BasicConv2d(3, 8, kernel_size=3, stride=2, padding=1)
160
+ self.conv2 = BasicConv2d(8, 16, kernel_size=3, stride=2, padding=1)
161
+ self.conv3 = BasicConv2d(16, 32, kernel_size=3, stride=2, padding=1)
162
+ self.conv4 = BasicConv2d(32, 64, kernel_size=3, stride=2, padding=1)
163
+ self.conv5 = BasicConv2d(64, 128, kernel_size=3, stride=2, padding=1)
164
+
165
+ self.inception1 = Inception()
166
+ self.inception2 = Inception()
167
+ self.inception3 = Inception()
168
+
169
+ self.conv6_1 = BasicConv2d(128, 128, kernel_size=1, stride=1, padding=0)
170
+ self.conv6_2 = BasicConv2d(128, 256, kernel_size=3, stride=2, padding=1)
171
+
172
+ self.conv7_1 = BasicConv2d(256, 128, kernel_size=1, stride=1, padding=0)
173
+ self.conv7_2 = BasicConv2d(128, 256, kernel_size=3, stride=2, padding=1)
174
+
175
+ self.loc, self.conf = self.multibox(self.num_classes)
176
+
177
+ if self.phase == 'test':
178
+ self.softmax = nn.Softmax(dim=-1)
179
+
180
+ if self.phase == 'train':
181
+ for m in self.modules():
182
+ if isinstance(m, nn.Conv2d):
183
+ if m.bias is not None:
184
+ nn.init.xavier_normal_(m.weight.data)
185
+ m.bias.data.fill_(0.02)
186
+ else:
187
+ m.weight.data.normal_(0, 0.01)
188
+ elif isinstance(m, nn.BatchNorm2d):
189
+ m.weight.data.fill_(1)
190
+ m.bias.data.zero_()
191
+
192
+ def multibox(self, num_classes):
193
+ loc_layers = []
194
+ conf_layers = []
195
+ loc_layers += [nn.Conv2d(128, 21 * 4, kernel_size=3, padding=1)]
196
+ conf_layers += [nn.Conv2d(128, 21 * num_classes, kernel_size=3, padding=1)]
197
+ loc_layers += [nn.Conv2d(256, 1 * 4, kernel_size=3, padding=1)]
198
+ conf_layers += [nn.Conv2d(256, 1 * num_classes, kernel_size=3, padding=1)]
199
+ loc_layers += [nn.Conv2d(256, 1 * 4, kernel_size=3, padding=1)]
200
+ conf_layers += [nn.Conv2d(256, 1 * num_classes, kernel_size=3, padding=1)]
201
+ return nn.Sequential(*loc_layers), nn.Sequential(*conf_layers)
202
+
203
+ def forward(self, x):
204
+
205
+ sources = list()
206
+ loc = list()
207
+ conf = list()
208
+
209
+ x = self.conv1(x)
210
+ x = self.conv2(x)
211
+ x = self.conv3(x)
212
+ x = self.conv4(x)
213
+ x = self.conv5(x)
214
+ x = self.inception1(x)
215
+ x = self.inception2(x)
216
+ x = self.inception3(x)
217
+ sources.append(x)
218
+ x = self.conv6_1(x)
219
+ x = self.conv6_2(x)
220
+ sources.append(x)
221
+ x = self.conv7_1(x)
222
+ x = self.conv7_2(x)
223
+ sources.append(x)
224
+
225
+ for (x, l, c) in zip(sources, self.loc, self.conf):
226
+ loc.append(l(x).permute(0, 2, 3, 1).contiguous())
227
+ conf.append(c(x).permute(0, 2, 3, 1).contiguous())
228
+
229
+ loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)
230
+ conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)
231
+
232
+ if self.phase == "test":
233
+ output = (loc.view(loc.size(0), -1, 4),
234
+ self.softmax(conf.view(-1, self.num_classes)))
235
+ else:
236
+ output = (loc.view(loc.size(0), -1, 4),
237
+ conf.view(conf.size(0), -1, self.num_classes))
238
+
239
+ return output
external/landmark_detection/FaceBoxesV2/utils/make.sh ADDED
@@ -0,0 +1,3 @@
1
+ #!/usr/bin/env bash
2
+ python3 build.py build_ext --inplace
3
+
external/landmark_detection/FaceBoxesV2/utils/nms/__init__.py ADDED
File without changes
external/landmark_detection/FaceBoxesV2/utils/nms/cpu_nms.c ADDED
The diff for this file is too large to render. See raw diff
 
external/landmark_detection/FaceBoxesV2/utils/nms/cpu_nms.py ADDED
File without changes
external/landmark_detection/FaceBoxesV2/utils/nms/cpu_nms.pyx ADDED
@@ -0,0 +1,163 @@
1
+ # --------------------------------------------------------
2
+ # Fast R-CNN
3
+ # Copyright (c) 2015 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Ross Girshick
6
+ # --------------------------------------------------------
7
+
8
+ import numpy as np
9
+ cimport numpy as np
10
+
11
+ cdef inline np.float32_t max(np.float32_t a, np.float32_t b):
12
+ return a if a >= b else b
13
+
14
+ cdef inline np.float32_t min(np.float32_t a, np.float32_t b):
15
+ return a if a <= b else b
16
+
17
+ def cpu_nms(np.ndarray[np.float32_t, ndim=2] dets, float thresh):
18
+ cdef np.ndarray[np.float32_t, ndim=1] x1 = dets[:, 0]
19
+ cdef np.ndarray[np.float32_t, ndim=1] y1 = dets[:, 1]
20
+ cdef np.ndarray[np.float32_t, ndim=1] x2 = dets[:, 2]
21
+ cdef np.ndarray[np.float32_t, ndim=1] y2 = dets[:, 3]
22
+ cdef np.ndarray[np.float32_t, ndim=1] scores = dets[:, 4]
23
+
24
+ cdef np.ndarray[np.float32_t, ndim=1] areas = (x2 - x1 + 1) * (y2 - y1 + 1)
25
+ cdef np.ndarray[np.int_t, ndim=1] order = scores.argsort()[::-1]
26
+
27
+ cdef int ndets = dets.shape[0]
28
+ cdef np.ndarray[np.int_t, ndim=1] suppressed = \
29
+ np.zeros((ndets), dtype=np.int_)
30
+
31
+ # nominal indices
32
+ cdef int _i, _j
33
+ # sorted indices
34
+ cdef int i, j
35
+ # temp variables for box i's (the box currently under consideration)
36
+ cdef np.float32_t ix1, iy1, ix2, iy2, iarea
37
+ # variables for computing overlap with box j (lower scoring box)
38
+ cdef np.float32_t xx1, yy1, xx2, yy2
39
+ cdef np.float32_t w, h
40
+ cdef np.float32_t inter, ovr
41
+
42
+ keep = []
43
+ for _i in range(ndets):
44
+ i = order[_i]
45
+ if suppressed[i] == 1:
46
+ continue
47
+ keep.append(i)
48
+ ix1 = x1[i]
49
+ iy1 = y1[i]
50
+ ix2 = x2[i]
51
+ iy2 = y2[i]
52
+ iarea = areas[i]
53
+ for _j in range(_i + 1, ndets):
54
+ j = order[_j]
55
+ if suppressed[j] == 1:
56
+ continue
57
+ xx1 = max(ix1, x1[j])
58
+ yy1 = max(iy1, y1[j])
59
+ xx2 = min(ix2, x2[j])
60
+ yy2 = min(iy2, y2[j])
61
+ w = max(0.0, xx2 - xx1 + 1)
62
+ h = max(0.0, yy2 - yy1 + 1)
63
+ inter = w * h
64
+ ovr = inter / (iarea + areas[j] - inter)
65
+ if ovr >= thresh:
66
+ suppressed[j] = 1
67
+
68
+ return keep
69
+
70
+ def cpu_soft_nms(np.ndarray[float, ndim=2] boxes, float sigma=0.5, float Nt=0.3, float threshold=0.001, unsigned int method=0):
71
+ cdef unsigned int N = boxes.shape[0]
72
+ cdef float iw, ih, box_area
73
+ cdef float ua
74
+ cdef int pos = 0
75
+ cdef float maxscore = 0
76
+ cdef int maxpos = 0
77
+ cdef float x1,x2,y1,y2,tx1,tx2,ty1,ty2,ts,area,weight,ov
78
+
79
+ for i in range(N):
80
+ maxscore = boxes[i, 4]
81
+ maxpos = i
82
+
83
+ tx1 = boxes[i,0]
84
+ ty1 = boxes[i,1]
85
+ tx2 = boxes[i,2]
86
+ ty2 = boxes[i,3]
87
+ ts = boxes[i,4]
88
+
89
+ pos = i + 1
90
+ # get max box
91
+ while pos < N:
92
+ if maxscore < boxes[pos, 4]:
93
+ maxscore = boxes[pos, 4]
94
+ maxpos = pos
95
+ pos = pos + 1
96
+
97
+ # add max box as a detection
98
+ boxes[i,0] = boxes[maxpos,0]
99
+ boxes[i,1] = boxes[maxpos,1]
100
+ boxes[i,2] = boxes[maxpos,2]
101
+ boxes[i,3] = boxes[maxpos,3]
102
+ boxes[i,4] = boxes[maxpos,4]
103
+
104
+ # swap ith box with position of max box
105
+ boxes[maxpos,0] = tx1
106
+ boxes[maxpos,1] = ty1
107
+ boxes[maxpos,2] = tx2
108
+ boxes[maxpos,3] = ty2
109
+ boxes[maxpos,4] = ts
110
+
111
+ tx1 = boxes[i,0]
112
+ ty1 = boxes[i,1]
113
+ tx2 = boxes[i,2]
114
+ ty2 = boxes[i,3]
115
+ ts = boxes[i,4]
116
+
117
+ pos = i + 1
118
+ # NMS iterations, note that N changes if detection boxes fall below threshold
119
+ while pos < N:
120
+ x1 = boxes[pos, 0]
121
+ y1 = boxes[pos, 1]
122
+ x2 = boxes[pos, 2]
123
+ y2 = boxes[pos, 3]
124
+ s = boxes[pos, 4]
125
+
126
+ area = (x2 - x1 + 1) * (y2 - y1 + 1)
127
+ iw = (min(tx2, x2) - max(tx1, x1) + 1)
128
+ if iw > 0:
129
+ ih = (min(ty2, y2) - max(ty1, y1) + 1)
130
+ if ih > 0:
131
+ ua = float((tx2 - tx1 + 1) * (ty2 - ty1 + 1) + area - iw * ih)
132
+ ov = iw * ih / ua #iou between max box and detection box
133
+
134
+ if method == 1: # linear
135
+ if ov > Nt:
136
+ weight = 1 - ov
137
+ else:
138
+ weight = 1
139
+ elif method == 2: # gaussian
140
+ weight = np.exp(-(ov * ov)/sigma)
141
+ else: # original NMS
142
+ if ov > Nt:
143
+ weight = 0
144
+ else:
145
+ weight = 1
146
+
147
+ boxes[pos, 4] = weight*boxes[pos, 4]
148
+
149
+ # if box score falls below threshold, discard the box by swapping with last box
150
+ # update N
151
+ if boxes[pos, 4] < threshold:
152
+ boxes[pos,0] = boxes[N-1, 0]
153
+ boxes[pos,1] = boxes[N-1, 1]
154
+ boxes[pos,2] = boxes[N-1, 2]
155
+ boxes[pos,3] = boxes[N-1, 3]
156
+ boxes[pos,4] = boxes[N-1, 4]
157
+ N = N - 1
158
+ pos = pos - 1
159
+
160
+ pos = pos + 1
161
+
162
+ keep = [i for i in range(N)]
163
+ return keep
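+
+ # Summary (added comment): `method` selects the re-weighting rule applied to
+ # overlapping boxes: 1 -> linear (weight = 1 - IoU when IoU > Nt),
+ # 2 -> Gaussian (weight = exp(-IoU^2 / sigma)), anything else -> hard NMS
+ # (weight = 0 when IoU > Nt). Boxes whose re-weighted score drops below
+ # `threshold` are discarded.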
external/landmark_detection/FaceBoxesV2/utils/nms/gpu_nms.hpp ADDED
@@ -0,0 +1,2 @@
1
+ void _nms(int* keep_out, int* num_out, const float* boxes_host, int boxes_num,
2
+ int boxes_dim, float nms_overlap_thresh, int device_id);
external/landmark_detection/FaceBoxesV2/utils/nms/gpu_nms.pyx ADDED
@@ -0,0 +1,31 @@
1
+ # --------------------------------------------------------
2
+ # Faster R-CNN
3
+ # Copyright (c) 2015 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Ross Girshick
6
+ # --------------------------------------------------------
7
+
8
+ import numpy as np
9
+ cimport numpy as np
10
+
11
+ assert sizeof(int) == sizeof(np.int32_t)
12
+
13
+ cdef extern from "gpu_nms.hpp":
14
+ void _nms(np.int32_t*, int*, np.float32_t*, int, int, float, int)
15
+
16
+ def gpu_nms(np.ndarray[np.float32_t, ndim=2] dets, np.float thresh,
17
+ np.int32_t device_id=0):
18
+ cdef int boxes_num = dets.shape[0]
19
+ cdef int boxes_dim = dets.shape[1]
20
+ cdef int num_out
21
+ cdef np.ndarray[np.int32_t, ndim=1] \
22
+ keep = np.zeros(boxes_num, dtype=np.int32)
23
+ cdef np.ndarray[np.float32_t, ndim=1] \
24
+ scores = dets[:, 4]
25
+ cdef np.ndarray[np.int_t, ndim=1] \
26
+ order = scores.argsort()[::-1]
27
+ cdef np.ndarray[np.float32_t, ndim=2] \
28
+ sorted_dets = dets[order, :]
29
+ _nms(&keep[0], &num_out, &sorted_dets[0, 0], boxes_num, boxes_dim, thresh, device_id)
30
+ keep = keep[:num_out]
31
+ return list(order[keep])
external/landmark_detection/FaceBoxesV2/utils/nms/nms_kernel.cu ADDED
@@ -0,0 +1,144 @@
1
+ // ------------------------------------------------------------------
2
+ // Faster R-CNN
3
+ // Copyright (c) 2015 Microsoft
4
+ // Licensed under The MIT License [see fast-rcnn/LICENSE for details]
5
+ // Written by Shaoqing Ren
6
+ // ------------------------------------------------------------------
7
+
8
+ #include "gpu_nms.hpp"
9
+ #include <vector>
10
+ #include <iostream>
11
+
12
+ #define CUDA_CHECK(condition) \
13
+ /* Code block avoids redefinition of cudaError_t error */ \
14
+ do { \
15
+ cudaError_t error = condition; \
16
+ if (error != cudaSuccess) { \
17
+ std::cout << cudaGetErrorString(error) << std::endl; \
18
+ } \
19
+ } while (0)
20
+
21
+ #define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
22
+ int const threadsPerBlock = sizeof(unsigned long long) * 8;
23
+
24
+ __device__ inline float devIoU(float const * const a, float const * const b) {
25
+ float left = max(a[0], b[0]), right = min(a[2], b[2]);
26
+ float top = max(a[1], b[1]), bottom = min(a[3], b[3]);
27
+ float width = max(right - left + 1, 0.f), height = max(bottom - top + 1, 0.f);
28
+ float interS = width * height;
29
+ float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1);
30
+ float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1);
31
+ return interS / (Sa + Sb - interS);
32
+ }
33
+
34
+ __global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh,
35
+ const float *dev_boxes, unsigned long long *dev_mask) {
36
+ const int row_start = blockIdx.y;
37
+ const int col_start = blockIdx.x;
38
+
39
+ // if (row_start > col_start) return;
40
+
41
+ const int row_size =
42
+ min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
43
+ const int col_size =
44
+ min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);
45
+
46
+ __shared__ float block_boxes[threadsPerBlock * 5];
47
+ if (threadIdx.x < col_size) {
48
+ block_boxes[threadIdx.x * 5 + 0] =
49
+ dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0];
50
+ block_boxes[threadIdx.x * 5 + 1] =
51
+ dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1];
52
+ block_boxes[threadIdx.x * 5 + 2] =
53
+ dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2];
54
+ block_boxes[threadIdx.x * 5 + 3] =
55
+ dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3];
56
+ block_boxes[threadIdx.x * 5 + 4] =
57
+ dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4];
58
+ }
59
+ __syncthreads();
60
+
61
+ if (threadIdx.x < row_size) {
62
+ const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
63
+ const float *cur_box = dev_boxes + cur_box_idx * 5;
64
+ int i = 0;
65
+ unsigned long long t = 0;
66
+ int start = 0;
67
+ if (row_start == col_start) {
68
+ start = threadIdx.x + 1;
69
+ }
70
+ for (i = start; i < col_size; i++) {
71
+ if (devIoU(cur_box, block_boxes + i * 5) > nms_overlap_thresh) {
72
+ t |= 1ULL << i;
73
+ }
74
+ }
75
+ const int col_blocks = DIVUP(n_boxes, threadsPerBlock);
76
+ dev_mask[cur_box_idx * col_blocks + col_start] = t;
77
+ }
78
+ }
79
+
80
+ void _set_device(int device_id) {
81
+ int current_device;
82
+ CUDA_CHECK(cudaGetDevice(&current_device));
83
+ if (current_device == device_id) {
84
+ return;
85
+ }
86
+ // The call to cudaSetDevice must come before any calls to Get, which
87
+ // may perform initialization using the GPU.
88
+ CUDA_CHECK(cudaSetDevice(device_id));
89
+ }
90
+
91
+ void _nms(int* keep_out, int* num_out, const float* boxes_host, int boxes_num,
92
+ int boxes_dim, float nms_overlap_thresh, int device_id) {
93
+ _set_device(device_id);
94
+
95
+ float* boxes_dev = NULL;
96
+ unsigned long long* mask_dev = NULL;
97
+
98
+ const int col_blocks = DIVUP(boxes_num, threadsPerBlock);
99
+
100
+ CUDA_CHECK(cudaMalloc(&boxes_dev,
101
+ boxes_num * boxes_dim * sizeof(float)));
102
+ CUDA_CHECK(cudaMemcpy(boxes_dev,
103
+ boxes_host,
104
+ boxes_num * boxes_dim * sizeof(float),
105
+ cudaMemcpyHostToDevice));
106
+
107
+ CUDA_CHECK(cudaMalloc(&mask_dev,
108
+ boxes_num * col_blocks * sizeof(unsigned long long)));
109
+
110
+ dim3 blocks(DIVUP(boxes_num, threadsPerBlock),
111
+ DIVUP(boxes_num, threadsPerBlock));
112
+ dim3 threads(threadsPerBlock);
113
+ nms_kernel<<<blocks, threads>>>(boxes_num,
114
+ nms_overlap_thresh,
115
+ boxes_dev,
116
+ mask_dev);
117
+
118
+ std::vector<unsigned long long> mask_host(boxes_num * col_blocks);
119
+ CUDA_CHECK(cudaMemcpy(&mask_host[0],
120
+ mask_dev,
121
+ sizeof(unsigned long long) * boxes_num * col_blocks,
122
+ cudaMemcpyDeviceToHost));
123
+
124
+ std::vector<unsigned long long> remv(col_blocks);
125
+ memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
126
+
127
+ int num_to_keep = 0;
128
+ for (int i = 0; i < boxes_num; i++) {
129
+ int nblock = i / threadsPerBlock;
130
+ int inblock = i % threadsPerBlock;
131
+
132
+ if (!(remv[nblock] & (1ULL << inblock))) {
133
+ keep_out[num_to_keep++] = i;
134
+ unsigned long long *p = &mask_host[0] + i * col_blocks;
135
+ for (int j = nblock; j < col_blocks; j++) {
136
+ remv[j] |= p[j];
137
+ }
138
+ }
139
+ }
140
+ *num_out = num_to_keep;
141
+
142
+ CUDA_CHECK(cudaFree(boxes_dev));
143
+ CUDA_CHECK(cudaFree(mask_dev));
144
+ }
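+
+ // Added note: the kernel writes, for every box i, a bitmask over blocks of 64
+ // candidate boxes (threadsPerBlock bits per unsigned long long); bit j of word b is
+ // set when box i overlaps box b*64+j above the threshold. The host loop then keeps a
+ // box only if no previously kept box has marked it as suppressed.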
external/landmark_detection/FaceBoxesV2/utils/nms/py_cpu_nms.py ADDED
@@ -0,0 +1,38 @@
1
+ # --------------------------------------------------------
2
+ # Fast R-CNN
3
+ # Copyright (c) 2015 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Ross Girshick
6
+ # --------------------------------------------------------
7
+
8
+ import numpy as np
9
+
10
+ def py_cpu_nms(dets, thresh):
11
+ """Pure Python NMS baseline."""
12
+ x1 = dets[:, 0]
13
+ y1 = dets[:, 1]
14
+ x2 = dets[:, 2]
15
+ y2 = dets[:, 3]
16
+ scores = dets[:, 4]
17
+
18
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
19
+ order = scores.argsort()[::-1]
20
+
21
+ keep = []
22
+ while order.size > 0:
23
+ i = order[0]
24
+ keep.append(i)
25
+ xx1 = np.maximum(x1[i], x1[order[1:]])
26
+ yy1 = np.maximum(y1[i], y1[order[1:]])
27
+ xx2 = np.minimum(x2[i], x2[order[1:]])
28
+ yy2 = np.minimum(y2[i], y2[order[1:]])
29
+
30
+ w = np.maximum(0.0, xx2 - xx1 + 1)
31
+ h = np.maximum(0.0, yy2 - yy1 + 1)
32
+ inter = w * h
33
+ ovr = inter / (areas[i] + areas[order[1:]] - inter)
34
+
35
+ inds = np.where(ovr <= thresh)[0]
36
+ order = order[inds + 1]
37
+
38
+ return keep
external/landmark_detection/FaceBoxesV2/utils/nms_wrapper.py ADDED
@@ -0,0 +1,15 @@
1
+ # --------------------------------------------------------
2
+ # Fast R-CNN
3
+ # Copyright (c) 2015 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Ross Girshick
6
+ # --------------------------------------------------------
7
+
8
+ from .nms.cpu_nms import cpu_nms, cpu_soft_nms
9
+
10
+ def nms(dets, thresh):
11
+ """Dispatch to either CPU or GPU NMS implementations."""
12
+
13
+ if dets.shape[0] == 0:
14
+ return []
15
+ return cpu_nms(dets, thresh)
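+
+ # Usage sketch (added comment, hypothetical values): `dets` is an (N, 5) float32
+ # array with rows [x1, y1, x2, y2, score]; nms(dets, 0.3) returns the indices of
+ # the boxes to keep.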
external/landmark_detection/FaceBoxesV2/utils/prior_box.py ADDED
@@ -0,0 +1,43 @@
1
+ import torch
2
+ from itertools import product as product
3
+ import numpy as np
4
+ from math import ceil
5
+
6
+
7
+ class PriorBox(object):
8
+ def __init__(self, cfg, image_size=None, phase='train'):
9
+ super(PriorBox, self).__init__()
10
+ #self.aspect_ratios = cfg['aspect_ratios']
11
+ self.min_sizes = cfg['min_sizes']
12
+ self.steps = cfg['steps']
13
+ self.clip = cfg['clip']
14
+ self.image_size = image_size
15
+ self.feature_maps = [[ceil(self.image_size[0]/step), ceil(self.image_size[1]/step)] for step in self.steps]
16
+
17
+ def forward(self):
18
+ anchors = []
19
+ for k, f in enumerate(self.feature_maps):
20
+ min_sizes = self.min_sizes[k]
21
+ for i, j in product(range(f[0]), range(f[1])):
22
+ for min_size in min_sizes:
23
+ s_kx = min_size / self.image_size[1]
24
+ s_ky = min_size / self.image_size[0]
25
+ if min_size == 32:
26
+ dense_cx = [x*self.steps[k]/self.image_size[1] for x in [j+0, j+0.25, j+0.5, j+0.75]]
27
+ dense_cy = [y*self.steps[k]/self.image_size[0] for y in [i+0, i+0.25, i+0.5, i+0.75]]
28
+ for cy, cx in product(dense_cy, dense_cx):
29
+ anchors += [cx, cy, s_kx, s_ky]
30
+ elif min_size == 64:
31
+ dense_cx = [x*self.steps[k]/self.image_size[1] for x in [j+0, j+0.5]]
32
+ dense_cy = [y*self.steps[k]/self.image_size[0] for y in [i+0, i+0.5]]
33
+ for cy, cx in product(dense_cy, dense_cx):
34
+ anchors += [cx, cy, s_kx, s_ky]
35
+ else:
36
+ cx = (j + 0.5) * self.steps[k] / self.image_size[1]
37
+ cy = (i + 0.5) * self.steps[k] / self.image_size[0]
38
+ anchors += [cx, cy, s_kx, s_ky]
39
+ # back to torch land
40
+ output = torch.Tensor(anchors).view(-1, 4)
41
+ if self.clip:
42
+ output.clamp_(max=1, min=0)
43
+ return output
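+
+ # Note (added comment): the returned tensor holds one row per anchor in
+ # center-offset form (cx, cy, w, h), normalized by the image size; the decode()
+ # helper in FaceBoxesV2/utils/box_utils.py converts network offsets plus these
+ # priors back to corner-form boxes.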
external/landmark_detection/FaceBoxesV2/utils/timer.py ADDED
@@ -0,0 +1,40 @@
1
+ # --------------------------------------------------------
2
+ # Fast R-CNN
3
+ # Copyright (c) 2015 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Ross Girshick
6
+ # --------------------------------------------------------
7
+
8
+ import time
9
+
10
+
11
+ class Timer(object):
12
+ """A simple timer."""
13
+ def __init__(self):
14
+ self.total_time = 0.
15
+ self.calls = 0
16
+ self.start_time = 0.
17
+ self.diff = 0.
18
+ self.average_time = 0.
19
+
20
+ def tic(self):
21
+ # using time.time instead of time.clock because time.clock
22
+ # does not normalize for multithreading
23
+ self.start_time = time.time()
24
+
25
+ def toc(self, average=True):
26
+ self.diff = time.time() - self.start_time
27
+ self.total_time += self.diff
28
+ self.calls += 1
29
+ self.average_time = self.total_time / self.calls
30
+ if average:
31
+ return self.average_time
32
+ else:
33
+ return self.diff
34
+
35
+ def clear(self):
36
+ self.total_time = 0.
37
+ self.calls = 0
38
+ self.start_time = 0.
39
+ self.diff = 0.
40
+ self.average_time = 0.
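+
+ # Usage sketch (added comment): t = Timer(); t.tic(); do_work(); elapsed = t.toc(average=False)
+ # toc() with the default average=True returns the running mean over all calls.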
external/landmark_detection/README.md ADDED
@@ -0,0 +1,110 @@
1
+ # STAR Loss: Reducing Semantic Ambiguity in Facial Landmark Detection.
2
+
3
+ Paper Link: [arxiv](https://arxiv.org/abs/2306.02763) | [CVPR 2023](https://openaccess.thecvf.com/content/CVPR2023/papers/Zhou_STAR_Loss_Reducing_Semantic_Ambiguity_in_Facial_Landmark_Detection_CVPR_2023_paper.pdf)
4
+
5
+
6
+ - Pytorch implementation of **S**elf-adap**T**ive **A**mbiguity **R**eduction (**STAR**) loss.
7
+ - STAR loss is a self-adaptive anisotropic direction loss, which can be used in heatmap regression-based methods for facial landmark detection.
8
+ - Specifically, we find that semantic ambiguity results in an anisotropic predicted distribution, which inspires us to use the predicted distribution to represent semantic ambiguity. We use PCA to characterize the predicted distribution and thereby estimate the direction and intensity of semantic ambiguity (a minimal sketch of this step is given below the framework figure). Based on this, STAR loss adaptively suppresses the prediction error along the ambiguity direction to mitigate the impact of ambiguous annotations during training. More details can be found in our paper.
9
+ <p align="center">
10
+ <img src="./images/framework.png" width="80%">
11
+ </p>
12
+
13
+
14
+
15
+
16
+ ## Dependencies
17
+
18
+ * python==3.7.3
19
+ * PyTorch==1.6.0
20
+ * requirements.txt
21
+
22
+ ## Dataset Preparation
23
+
24
+ - Step1: Download the raw images from [COFW](http://www.vision.caltech.edu/xpburgos/ICCV13/#dataset), [300W](https://ibug.doc.ic.ac.uk/resources/300-W/), and [WFLW](https://wywu.github.io/projects/LAB/WFLW.html).
25
+ - Step2: We follow the data preprocess in [ADNet](https://openaccess.thecvf.com/content/ICCV2021/papers/Huang_ADNet_Leveraging_Error-Bias_Towards_Normal_Direction_in_Face_Alignment_ICCV_2021_paper.pdf), and the metadata can be download from [the corresponding repository](https://github.com/huangyangyu/ADNet).
26
+ - Step3: Make them look like this:
27
+ ```script
28
+ # the dataset directory:
29
+ |-- ${image_dir}
30
+ |-- WFLW
31
+ | -- WFLW_images
32
+ |-- 300W
33
+ | -- afw
34
+ | -- helen
35
+ | -- ibug
36
+ | -- lfpw
37
+ |-- COFW
38
+ | -- train
39
+ | -- test
40
+ |-- ${annot_dir}
41
+ |-- WFLW
42
+ |-- train.tsv, test.tsv
43
+ |-- 300W
44
+ |-- train.tsv, test.tsv
45
+ |--COFW
46
+ |-- train.tsv, test.tsv
47
+ ```
48
+
49
+ ## Usage
50
+ * Work directory: set the ${ckpt_dir} in ./conf/alignment.py.
51
+ * Pretrained model:
52
+
53
+ | Dataset | Model |
54
+ |:-----------------------------------------------------------------|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------|
55
+ | WFLW | [google](https://drive.google.com/file/d/1aOx0wYEZUfBndYy_8IYszLPG_D2fhxrT/view?usp=sharing) / [baidu](https://pan.baidu.com/s/10vvI-ovs3x9NrdmpnXK6sg?pwd=u0yu) |
56
+ | 300W | [google](https://drive.google.com/file/d/1Fiu3hjjkQRdKsWE9IgyNPdiJSz9_MzA5/view?usp=sharing) / [baidu](https://pan.baidu.com/s/1bjUhLq1zS1XSl1nX78fU7A?pwd=yb2s) |
57
+ | COFW | [google](https://drive.google.com/file/d/1NFcZ9jzql_jnn3ulaSzUlyhS05HWB9n_/view?usp=drive_link) / [baidu](https://pan.baidu.com/s/1XO6hDZ8siJLTgFcpyu1Tzw?pwd=m57n) |
58
+
59
+
60
+ ### Training
61
+ ```shell
62
+ python main.py --mode=train --device_ids=0,1,2,3 \
63
+ --image_dir=${image_dir} --annot_dir=${annot_dir} \
64
+ --data_definition={WFLW, 300W, COFW}
65
+ ```
66
+
67
+ ### Testing
68
+ ```shell
69
+ python main.py --mode=test --device_ids=0 \
70
+ --image_dir=${image_dir} --annot_dir=${annot_dir} \
71
+ --data_definition={WFLW, 300W, COFW} \
72
+ --pretrained_weight=${model_path} \
73
+ ```
74
+
75
+ ### Evaluation
76
+ ```shell
77
+ python evaluate.py --device_ids=0 \
78
+ --model_path=${model_path} --metadata_path=${metadata_path} \
79
+ --image_dir=${image_dir} --data_definition={WFLW, 300W, COFW} \
80
+ ```
81
+
82
+ To test on your own images, the following command can be used:
83
+ ```shell
84
+ python demo.py
85
+ ```
86
+
87
+
88
+ ## Results
89
+ The models trained with STAR Loss achieve **SOTA** performance on all of the COFW, 300W and WFLW datasets.
90
+
91
+ <p align="center">
92
+ <img src="./images/results.png" width="80%">
93
+ </p>
94
+
95
+ ## BibTeX Citation
96
+ Please consider citing our paper in your publications if this project helps your research. The BibTeX reference is as follows.
97
+ ```
98
+ @inproceedings{Zhou_2023_CVPR,
99
+ author = {Zhou, Zhenglin and Li, Huaxia and Liu, Hong and Wang, Nanyang and Yu, Gang and Ji, Rongrong},
100
+ title = {STAR Loss: Reducing Semantic Ambiguity in Facial Landmark Detection},
101
+ booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
102
+ month = {June},
103
+ year = {2023},
104
+ pages = {15475-15484}
105
+ }
106
+ ```
107
+
108
+ ## Acknowledgments
109
+ This repository is built on top of [ADNet](https://github.com/huangyangyu/ADNet).
110
+ Thanks for this strong baseline.
external/landmark_detection/conf/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .alignment import Alignment
external/landmark_detection/conf/alignment.py ADDED
@@ -0,0 +1,239 @@
1
+ import os.path as osp
2
+ from .base import Base
3
+
4
+
5
+ class Alignment(Base):
6
+ """
7
+ Alignment configure file, which contains training parameters of alignment.
8
+ """
9
+
10
+ def __init__(self, args):
11
+ super(Alignment, self).__init__('alignment')
12
+ self.ckpt_dir = '/mnt/workspace/humanAIGC/project/STAR/weights'
13
+ self.net = "stackedHGnet_v1"
14
+ self.nstack = 4
15
+ self.loader_type = "alignment"
16
+ self.data_definition = "300W" # COFW, 300W, WFLW
17
+ self.test_file = "test.tsv"
18
+
19
+ # image
20
+ self.channels = 3
21
+ self.width = 256
22
+ self.height = 256
23
+ self.means = (127.5, 127.5, 127.5)
24
+ self.scale = 1 / 127.5
25
+ self.aug_prob = 1.0
26
+
27
+ self.display_iteration = 10
28
+ self.val_epoch = 1
29
+ self.valset = "test.tsv"
30
+ self.norm_type = 'default'
31
+ self.encoder_type = 'default'
32
+ self.decoder_type = 'default'
33
+
34
+ # scheduler & optimizer
35
+ self.milestones = [200, 350, 450]
36
+ self.max_epoch = 260
37
+ self.optimizer = "adam"
38
+ self.learn_rate = 0.001
39
+ self.weight_decay = 0.00001
40
+ self.betas = [0.9, 0.999]
41
+ self.gamma = 0.1
42
+
43
+ # batch_size & workers
44
+ self.batch_size = 32
45
+ self.train_num_workers = 16
46
+ self.val_batch_size = 32
47
+ self.val_num_workers = 16
48
+ self.test_batch_size = 16
49
+ self.test_num_workers = 0
50
+
51
+ # tricks
52
+ self.ema = True
53
+ self.add_coord = True
54
+ self.use_AAM = True
55
+
56
+ # loss
57
+ self.loss_func = "STARLoss_v2"
58
+
59
+ # STAR Loss paras
60
+ self.star_w = 1
61
+ self.star_dist = 'smoothl1'
62
+
63
+ self.init_from_args(args)
64
+
65
+ # COFW
66
+ if self.data_definition == "COFW":
67
+ self.edge_info = (
68
+ (True, (0, 4, 2, 5)), # RightEyebrow
69
+ (True, (1, 6, 3, 7)), # LeftEyebrow
70
+ (True, (8, 12, 10, 13)), # RightEye
71
+ (False, (9, 14, 11, 15)), # LeftEye
72
+ (True, (18, 20, 19, 21)), # Nose
73
+ (True, (22, 26, 23, 27)), # LowerLip
74
+ (True, (22, 24, 23, 25)), # UpperLip
75
+ )
76
+ if self.norm_type == 'ocular':
77
+ self.nme_left_index = 8 # ocular
78
+ self.nme_right_index = 9 # ocular
79
+ elif self.norm_type in ['pupil', 'default']:
80
+ self.nme_left_index = 16 # pupil
81
+ self.nme_right_index = 17 # pupil
82
+ else:
83
+ raise NotImplementedError
84
+ self.classes_num = [29, 7, 29]
85
+ self.crop_op = True
86
+ self.flip_mapping = (
87
+ [0, 1], [4, 6], [2, 3], [5, 7], [8, 9], [10, 11], [12, 14], [16, 17], [13, 15], [18, 19], [22, 23],
88
+ )
89
+ self.image_dir = osp.join(self.image_dir, 'COFW')
90
+ # 300W
91
+ elif self.data_definition == "300W":
92
+ self.edge_info = (
93
+ (False, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)), # FaceContour
94
+ (False, (17, 18, 19, 20, 21)), # RightEyebrow
95
+ (False, (22, 23, 24, 25, 26)), # LeftEyebrow
96
+ (False, (27, 28, 29, 30)), # NoseLine
97
+ (False, (31, 32, 33, 34, 35)), # Nose
98
+ (True, (36, 37, 38, 39, 40, 41)), # RightEye
99
+ (True, (42, 43, 44, 45, 46, 47)), # LeftEye
100
+ (True, (48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59)), # OuterLip
101
+ (True, (60, 61, 62, 63, 64, 65, 66, 67)), # InnerLip
102
+ )
103
+ if self.norm_type in ['ocular', 'default']:
104
+ self.nme_left_index = 36 # ocular
105
+ self.nme_right_index = 45 # ocular
106
+ elif self.norm_type == 'pupil':
107
+ self.nme_left_index = [36, 37, 38, 39, 40, 41] # pupil
108
+ self.nme_right_index = [42, 43, 44, 45, 46, 47] # pupil
109
+ else:
110
+ raise NotImplementedError
111
+ self.classes_num = [68, 9, 68]
112
+ self.crop_op = True
113
+ self.flip_mapping = (
114
+ [0, 16], [1, 15], [2, 14], [3, 13], [4, 12], [5, 11], [6, 10], [7, 9],
115
+ [17, 26], [18, 25], [19, 24], [20, 23], [21, 22],
116
+ [31, 35], [32, 34],
117
+ [36, 45], [37, 44], [38, 43], [39, 42], [40, 47], [41, 46],
118
+ [48, 54], [49, 53], [50, 52], [61, 63], [60, 64], [67, 65], [58, 56], [59, 55],
119
+ )
120
+ self.image_dir = osp.join(self.image_dir, '300W')
121
+ # self.image_dir = osp.join(self.image_dir, '300VW_images')
122
+ # 300VW
123
+ elif self.data_definition == "300VW":
124
+ self.edge_info = (
125
+ (False, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)), # FaceContour
126
+ (False, (17, 18, 19, 20, 21)), # RightEyebrow
127
+ (False, (22, 23, 24, 25, 26)), # LeftEyebrow
128
+ (False, (27, 28, 29, 30)), # NoseLine
129
+ (False, (31, 32, 33, 34, 35)), # Nose
130
+ (True, (36, 37, 38, 39, 40, 41)), # RightEye
131
+ (True, (42, 43, 44, 45, 46, 47)), # LeftEye
132
+ (True, (48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59)), # OuterLip
133
+ (True, (60, 61, 62, 63, 64, 65, 66, 67)), # InnerLip
134
+ )
135
+ if self.norm_type in ['ocular', 'default']:
136
+ self.nme_left_index = 36 # ocular
137
+ self.nme_right_index = 45 # ocular
138
+ elif self.norm_type == 'pupil':
139
+ self.nme_left_index = [36, 37, 38, 39, 40, 41] # pupil
140
+ self.nme_right_index = [42, 43, 44, 45, 46, 47] # pupil
141
+ else:
142
+ raise NotImplementedError
143
+ self.classes_num = [68, 9, 68]
144
+ self.crop_op = True
145
+ self.flip_mapping = (
146
+ [0, 16], [1, 15], [2, 14], [3, 13], [4, 12], [5, 11], [6, 10], [7, 9],
147
+ [17, 26], [18, 25], [19, 24], [20, 23], [21, 22],
148
+ [31, 35], [32, 34],
149
+ [36, 45], [37, 44], [38, 43], [39, 42], [40, 47], [41, 46],
150
+ [48, 54], [49, 53], [50, 52], [61, 63], [60, 64], [67, 65], [58, 56], [59, 55],
151
+ )
152
+ self.image_dir = osp.join(self.image_dir, '300VW_Dataset_2015_12_14')
153
+ # WFLW
154
+ elif self.data_definition == "WFLW":
155
+ self.edge_info = (
156
+ (False, (
157
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26,
158
+ 27,
159
+ 28, 29, 30, 31, 32)), # FaceContour
160
+ (True, (33, 34, 35, 36, 37, 38, 39, 40, 41)), # RightEyebrow
161
+ (True, (42, 43, 44, 45, 46, 47, 48, 49, 50)), # LeftEyebrow
162
+ (False, (51, 52, 53, 54)), # NoseLine
163
+ (False, (55, 56, 57, 58, 59)), # Nose
164
+ (True, (60, 61, 62, 63, 64, 65, 66, 67)), # RightEye
165
+ (True, (68, 69, 70, 71, 72, 73, 74, 75)), # LeftEye
166
+ (True, (76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87)), # OuterLip
167
+ (True, (88, 89, 90, 91, 92, 93, 94, 95)), # InnerLip
168
+ )
169
+ if self.norm_type in ['ocular', 'default']:
170
+ self.nme_left_index = 60 # ocular
171
+ self.nme_right_index = 72 # ocular
172
+ elif self.norm_type == 'pupil':
173
+ self.nme_left_index = 96 # pupils
174
+ self.nme_right_index = 97 # pupils
175
+ else:
176
+ raise NotImplementedError
177
+ self.classes_num = [98, 9, 98]
178
+ self.crop_op = True
179
+ self.flip_mapping = (
180
+ [0, 32], [1, 31], [2, 30], [3, 29], [4, 28], [5, 27], [6, 26], [7, 25], [8, 24], [9, 23], [10, 22],
181
+ [11, 21], [12, 20], [13, 19], [14, 18], [15, 17], # cheek
182
+ [33, 46], [34, 45], [35, 44], [36, 43], [37, 42], [38, 50], [39, 49], [40, 48], [41, 47], # eyebrow
183
+ [60, 72], [61, 71], [62, 70], [63, 69], [64, 68], [65, 75], [66, 74], [67, 73],
184
+ [55, 59], [56, 58],
185
+ [76, 82], [77, 81], [78, 80], [87, 83], [86, 84],
186
+ [88, 92], [89, 91], [95, 93], [96, 97]
187
+ )
188
+ self.image_dir = osp.join(self.image_dir, 'WFLW', 'WFLW_images')
189
+
190
+ self.label_num = self.nstack * 3 if self.use_AAM else self.nstack
191
+ self.loss_weights, self.criterions, self.metrics = [], [], []
192
+ for i in range(self.nstack):
193
+ factor = (2 ** i) / (2 ** (self.nstack - 1))
194
+ if self.use_AAM:
195
+ self.loss_weights += [factor * weight for weight in [1.0, 10.0, 10.0]]
196
+ self.criterions += [self.loss_func, "AWingLoss", "AWingLoss"]
197
+ self.metrics += ["NME", None, None]
198
+ else:
199
+ self.loss_weights += [factor * weight for weight in [1.0]]
200
+ self.criterions += [self.loss_func, ]
201
+ self.metrics += ["NME", ]
202
+
203
+ self.key_metric_index = (self.nstack - 1) * 3 if self.use_AAM else (self.nstack - 1)
204
+
205
+ # data
206
+ self.folder = self.get_foldername()
207
+ self.work_dir = osp.join(self.ckpt_dir, self.data_definition, self.folder)
208
+ self.model_dir = osp.join(self.work_dir, 'model')
209
+ self.log_dir = osp.join(self.work_dir, 'log')
210
+
211
+ self.train_tsv_file = osp.join(self.annot_dir, self.data_definition, "train.tsv")
212
+ self.train_pic_dir = self.image_dir
213
+
214
+ self.val_tsv_file = osp.join(self.annot_dir, self.data_definition, self.valset)
215
+ self.val_pic_dir = self.image_dir
216
+
217
+ self.test_tsv_file = osp.join(self.annot_dir, self.data_definition, self.test_file)
218
+ self.test_pic_dir = self.image_dir
219
+
220
+ # self.train_tsv_file = osp.join(self.annot_dir, '300VW', "train.tsv")
221
+ # self.train_pic_dir = self.image_dir
222
+
223
+ # self.val_tsv_file = osp.join(self.annot_dir, '300VW', self.valset)
224
+ # self.val_pic_dir = self.image_dir
225
+
226
+ # self.test_tsv_file = osp.join(self.annot_dir, '300VW', self.test_file)
227
+ # self.test_pic_dir = self.image_dir
228
+
229
+
230
+ def get_foldername(self):
231
+ str = ''
232
+ str += '{}_{}x{}_{}_ep{}_lr{}_bs{}'.format(self.data_definition, self.height, self.width,
233
+ self.optimizer, self.max_epoch, self.learn_rate, self.batch_size)
234
+ str += '_{}'.format(self.loss_func)
235
+ str += '_{}_{}'.format(self.star_dist, self.star_w) if self.loss_func == 'STARLoss' else ''
236
+ str += '_AAM' if self.use_AAM else ''
237
+ str += '_{}'.format(self.valset[:-4]) if self.valset != 'test.tsv' else ''
238
+ str += '_{}'.format(self.id)
239
+ return str
external/landmark_detection/conf/base.py ADDED
@@ -0,0 +1,94 @@
1
+ import uuid
2
+ import logging
3
+ import os.path as osp
4
+ from argparse import Namespace
5
+ # from tensorboardX import SummaryWriter
6
+
7
+ class Base:
8
+ """
9
+ Base configuration, which contains the basic training parameters and should be inherited by task-specific configuration files.
10
+ """
11
+
12
+ def __init__(self, config_name, ckpt_dir='./', image_dir='./', annot_dir='./'):
13
+ self.type = config_name
14
+ self.id = str(uuid.uuid4())
15
+ self.note = ""
16
+
17
+ self.ckpt_dir = ckpt_dir
18
+ self.image_dir = image_dir
19
+ self.annot_dir = annot_dir
20
+
21
+ self.loader_type = "alignment"
22
+ self.loss_func = "STARLoss"
23
+
24
+ # train
25
+ self.batch_size = 128
26
+ self.val_batch_size = 1
27
+ self.test_batch_size = 32
28
+ self.channels = 3
29
+ self.width = 256
30
+ self.height = 256
31
+
32
+ # mean values in r, g, b channel.
33
+ self.means = (127, 127, 127)
34
+ self.scale = 0.0078125
35
+
36
+ self.display_iteration = 100
37
+ self.milestones = [50, 80]
38
+ self.max_epoch = 100
39
+
40
+ self.net = "stackedHGnet_v1"
41
+ self.nstack = 4
42
+
43
+ # ["adam", "sgd"]
44
+ self.optimizer = "adam"
45
+ self.learn_rate = 0.1
46
+ self.momentum = 0.01 # caffe: 0.99
47
+ self.weight_decay = 0.0
48
+ self.nesterov = False
49
+ self.scheduler = "MultiStepLR"
50
+ self.gamma = 0.1
51
+
52
+ self.loss_weights = [1.0]
53
+ self.criterions = ["SoftmaxWithLoss"]
54
+ self.metrics = ["Accuracy"]
55
+ self.key_metric_index = 0
56
+ self.classes_num = [1000]
57
+ self.label_num = len(self.classes_num)
58
+
59
+ # model
60
+ self.ema = False
61
+ self.use_AAM = True
62
+
63
+ # visualization
64
+ self.writer = None
65
+
66
+ # log file
67
+ self.logger = None
68
+
69
+ def init_instance(self):
70
+ # self.writer = SummaryWriter(logdir=self.log_dir, comment=self.type)
71
+ log_formatter = logging.Formatter("%(asctime)s %(levelname)-8s: %(message)s")
72
+ root_logger = logging.getLogger()
73
+ file_handler = logging.FileHandler(osp.join(self.log_dir, "log.txt"))
74
+ file_handler.setFormatter(log_formatter)
75
+ file_handler.setLevel(logging.NOTSET)
76
+ root_logger.addHandler(file_handler)
77
+ console_handler = logging.StreamHandler()
78
+ console_handler.setFormatter(log_formatter)
79
+ console_handler.setLevel(logging.NOTSET)
80
+ root_logger.addHandler(console_handler)
81
+ root_logger.setLevel(logging.NOTSET)
82
+ self.logger = root_logger
83
+
84
+ def __del__(self):
85
+ # tensorboard --logdir self.log_dir
86
+ if self.writer is not None:
87
+ # self.writer.export_scalars_to_json(self.log_dir + "visual.json")
88
+ self.writer.close()
89
+
90
+ def init_from_args(self, args: Namespace):
91
+ args_vars = vars(args)
92
+ for key, value in args_vars.items():
93
+ if hasattr(self, key) and value is not None:
94
+ setattr(self, key, value)
external/landmark_detection/config.json ADDED
@@ -0,0 +1,15 @@
1
+ {
2
+ "Token":"bpt4JPotFA6bpdknR9ZDCw",
3
+ "business_flag": "shadow_cv_face",
4
+ "model_local_file_path": "/apdcephfs_cq3/share_1134483/charlinzhou/Documents/awesome-tools/jizhi/",
5
+ "host_num": 1,
6
+ "host_gpu_num": 1,
7
+ "GPUName": "V100",
8
+ "is_elasticity": true,
9
+ "enable_evicted_pulled_up": true,
10
+ "task_name": "20230312_slpt_star_bb_init_eigen_box_align_smoothl1-1",
11
+ "task_flag": "20230312_slpt_star_bb_init_eigen_box_align_smoothl1-1",
12
+ "model_name": "20230312_slpt_star_bb_init_eigen_box_align_smoothl1-1",
13
+ "image_full_name": "mirrors.tencent.com/haroldzcli/py36-pytorch1.7.1-torchvision0.8.2-cuda10.1-cudnn7.6",
14
+ "start_cmd": "./start_slpt.sh /apdcephfs_cq3/share_1134483/charlinzhou/Documents/SLPT_Training train.py --loss_func=star --bb_init --eigen_box --dist_func=align_smoothl1"
15
+ }
external/landmark_detection/data_processor/CheckFaceKeyPoint.py ADDED
@@ -0,0 +1,147 @@
1
+ import os
2
+
3
+ import cv2
4
+ import numpy as np
5
+ from PIL import Image
6
+
7
+ selected_indices_old = [
8
+ 2311,
9
+ 2416,
10
+ 2437,
11
+ 2460,
12
+ 2495,
13
+ 2518,
14
+ 2520,
15
+ 2627,
16
+ 4285,
17
+ 4315,
18
+ 6223,
19
+ 6457,
20
+ 6597,
21
+ 6642,
22
+ 6974,
23
+ 7054,
24
+ 7064,
25
+ 7182,
26
+ 7303,
27
+ 7334,
28
+ 7351,
29
+ 7368,
30
+ 7374,
31
+ 7493,
32
+ 7503,
33
+ 7626,
34
+ 8443,
35
+ 8562,
36
+ 8597,
37
+ 8701,
38
+ 8817,
39
+ 8953,
40
+ 11213,
41
+ 11261,
42
+ 11317,
43
+ 11384,
44
+ 11600,
45
+ 11755,
46
+ 11852,
47
+ 11891,
48
+ 11945,
49
+ 12010,
50
+ 12354,
51
+ 12534,
52
+ 12736,
53
+ 12880,
54
+ 12892,
55
+ 13004,
56
+ 13323,
57
+ 13371,
58
+ 13534,
59
+ 13575,
60
+ 14874,
61
+ 14949,
62
+ 14977,
63
+ 15052,
64
+ 15076,
65
+ 15291,
66
+ 15620,
67
+ 15758,
68
+ 16309,
69
+ 16325,
70
+ 16348,
71
+ 16390,
72
+ 16489,
73
+ 16665,
74
+ 16891,
75
+ 17147,
76
+ 17183,
77
+ 17488,
78
+ 17549,
79
+ 17657,
80
+ 17932,
81
+ 19661,
82
+ 20162,
83
+ 20200,
84
+ 20238,
85
+ 20286,
86
+ 20432,
87
+ 20834,
88
+ 20954,
89
+ 21015,
90
+ 21036,
91
+ 21117,
92
+ 21299,
93
+ 21611,
94
+ 21632,
95
+ 21649,
96
+ 22722,
97
+ 22759,
98
+ 22873,
99
+ 23028,
100
+ 23033,
101
+ 23082,
102
+ 23187,
103
+ 23232,
104
+ 23302,
105
+ 23413,
106
+ 23430,
107
+ 23446,
108
+ 23457,
109
+ 23548,
110
+ 23636,
111
+ 32060,
112
+ 32245,
113
+ ]
114
+
115
+ selected_indices = list()
116
+ with open('/home/gyalex/Desktop/face_anno.txt', 'r') as f:
117
+ lines = f.readlines()
118
+ for line in lines:
119
+ hh = line.strip().split()
120
+ if len(hh) > 0:
121
+ pid = hh[0].find('.')
122
+ if pid != -1:
123
+ s = hh[0][pid+1:len(hh[0])]
124
+ print(s)
125
+ selected_indices.append(int(s))
126
+
127
+ f.close()
128
+
129
+ dir = '/media/gyalex/Data/face_ldk_dataset/MHC_LightingPreset_Portrait_RT_0_19/MHC_LightingPreset_Portrait_RT_seq_000015'
130
+
131
+ for idx in range(500):
132
+ img = os.path.join(dir, "view_1/MHC_LightingPreset_Portrait_RT_seq_000015_FinalImage_" + str(idx).zfill(4) + ".jpeg")
133
+ lmd = os.path.join(dir, "mesh/mesh_screen" + str(idx+5).zfill(7) + ".npy")
134
+
135
+ img = cv2.imread(img)
136
+ # c = 511 / 2
137
+ # lmd = np.load(lmd) * c + c
138
+ # lmd[:, 1] = 511 - lmd[:, 1]
139
+ lmd = np.load(lmd)[selected_indices]
140
+ for i in range(lmd.shape[0]):
141
+ p = lmd[i]
142
+ x, y = round(float(p[0])), round(float(p[1]))
143
+ print(p)
144
+ cv2.circle(img, (x, y), 2, (0, 0, 255), -1)
145
+
146
+ cv2.imshow('win', img)
147
+ cv2.waitKey(0)
external/landmark_detection/data_processor/align.py ADDED
@@ -0,0 +1,193 @@
1
+ import numpy as np
2
+ import open3d as o3d
3
+ from scipy.spatial.transform import Rotation
4
+ from scipy.linalg import orthogonal_procrustes
5
+
6
+ from open3d.pipelines.registration import registration_ransac_based_on_correspondence
7
+
8
+
9
+ def rigid_transform_3D(A, B):
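+ # Best-fit rotation and translation (no scale) mapping point set A onto B via SVD of the cross-covariance (Kabsch); returns a 4x4 homogeneous matrix.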
10
+ assert A.shape == B.shape, "Input arrays must have the same shape"
11
+ assert A.shape[1] == 3, "Input arrays must be Nx3"
12
+
13
+ N = A.shape[0] # Number of points
14
+
15
+ # Compute centroids of A and B
16
+ centroid_A = np.mean(A, axis=0)
17
+ centroid_B = np.mean(B, axis=0)
18
+
19
+ # Center the points around the centroids
20
+ AA = A - centroid_A
21
+ BB = B - centroid_B
22
+
23
+ # H = AA^T * BB
24
+ H = np.dot(AA.T, BB)
25
+
26
+ # Singular Value Decomposition
27
+ U, S, Vt = np.linalg.svd(H)
28
+
29
+ # Compute rotation
30
+ R = np.dot(Vt.T, U.T)
31
+
32
+ # Ensure a proper rotation (det(R) should be +1)
33
+ if np.linalg.det(R) < 0:
34
+ Vt[2, :] *= -1
35
+ R = np.dot(Vt.T, U.T)
36
+
37
+ # Compute translation
38
+ t = centroid_B - np.dot(R, centroid_A)
39
+
40
+ # Construct the transform matrix (4x4)
41
+ transform_matrix = np.eye(4)
42
+ transform_matrix[:3, :3] = R
43
+ transform_matrix[:3, 3] = t
44
+
45
+ return transform_matrix
46
+
47
+
48
+ def compute_rigid_transform(points1, points2):
49
+ """
50
+ Compute the rigid transform (scale, rotation and translation) that maps points1 onto points2.
51
+
52
+ Args:
53
+ points1, points2: np.ndarray of shape (68, 3); the two sets of corresponding 3D points.
54
+
55
+ Returns:
56
+ scale: float, scale factor
57
+ R: np.ndarray, 3x3 rotation matrix
58
+ t: np.ndarray, translation vector of length 3
59
+ """
60
+ # Center both point sets
61
+ mean1 = np.mean(points1, axis=0)
62
+ centered_points1 = points1 - mean1
63
+ mean2 = np.mean(points2, axis=0)
64
+ centered_points2 = points2 - mean2
65
+
66
+ # Use orthogonal_procrustes to compute the rotation
67
+ R, _ = orthogonal_procrustes(centered_points1, centered_points2)
68
+ t = mean2 - R @ mean1 # translation vector
69
+
70
+ # Scale factor
71
+ scale = np.mean(np.linalg.norm(centered_points2, axis=1) /
72
+ np.linalg.norm(centered_points1, axis=1))
73
+
74
+ return scale, R, t
75
+
76
+
77
+ def compute_rigid_transform_new(points_A, points_B):
78
+ # Center both point sets
79
+ center_A = np.mean(points_A, axis=0)
80
+ center_B = np.mean(points_B, axis=0)
81
+ points_A_centered = points_A - center_A
82
+ points_B_centered = points_B - center_B
83
+
84
+ # Covariance matrix
85
+ cov_matrix = np.dot(points_A_centered.T, points_B_centered)
86
+
87
+ # SVD decomposition
88
+ U, S, Vt = np.linalg.svd(cov_matrix)
89
+
90
+ # Ensure the rotation matrix is orthogonal and right-handed; take Vt.T @ U.T as the rotation
91
+ rotation_matrix = np.dot(Vt.T, U.T)
92
+
93
+ # If the determinant is -1 (a reflection, not a proper rotation), flip the sign of one row of Vt
94
+ if np.linalg.det(rotation_matrix) < 0:
95
+ Vt[2,:] *= -1
96
+ rotation_matrix = np.dot(Vt.T, U.T)
97
+
98
+ # Scale factor
99
+ scale = np.trace(np.dot(points_A_centered.T, points_B_centered)) / np.trace(np.dot(points_A_centered.T, points_A_centered))
100
+
101
+ # Translation vector
102
+ translation_vector = center_B - scale * np.dot(rotation_matrix, center_A)
103
+
104
+ return scale, rotation_matrix, translation_vector
105
+
106
+
107
+
108
+
109
+ # Example usage
110
+ obj_A = '/home/gyalex/Desktop/our_face.obj'
111
+ obj_B = '/home/gyalex/Desktop/Neutral.obj'
112
+
113
+ mesh_A = o3d.io.read_triangle_mesh(obj_A)
114
+ mesh_B = o3d.io.read_triangle_mesh(obj_B)
115
+
116
+ vertices_A = np.asarray(mesh_A.vertices)
117
+ vertices_B = np.asarray(mesh_B.vertices)
118
+
119
+ list_A = list()
120
+ list_B = list()
121
+ with open('/home/gyalex/Desktop/our_marker.txt', 'r') as f:
122
+ lines_A = f.readlines()
123
+ for line in lines_A:
124
+ hh = line.strip().split()
125
+ list_A.append(int(hh[0]))
126
+
127
+ with open('/home/gyalex/Desktop/ARKit_landmarks.txt', 'r') as f:
128
+ lines_B = f.readlines()
129
+ for line in lines_B:
130
+ hh = line.strip().split()
131
+ list_B.append(int(hh[0]))
132
+
133
+ A = vertices_A[list_A,:] # first set of 3D points
134
+ B = vertices_B[list_B,:] # second set of 3D points
135
+
136
+ # scale, R, t = compute_rigid_transform(A, B)
137
+
138
+ # # Define a scale transform matrix
139
+ # scale_matrix = np.eye(4)
140
+ # scale_matrix[0, 0] = scale # scale along x
141
+ # scale_matrix[1, 1] = scale # scale along y
142
+ # scale_matrix[2, 2] = scale # scale along z
143
+
144
+ # transform_matrix = np.eye(4)
145
+ # transform_matrix[:3, :3] = scale
146
+ # transform_matrix[:3, 3] = R*t
147
+
148
+ # mesh_A.transform(transform_matrix)
149
+ # # mesh_A.transform(scale_matrix)
150
+
151
+ # o3d.io.write_triangle_mesh('/home/gyalex/Desktop/our_face_new.obj', mesh_A)
152
+
153
+ pcd_source = o3d.utility.Vector3dVector(A) # example source point-cloud data
154
+ pcd_target = o3d.utility.Vector3dVector(B) # example target point-cloud data (for illustration only)
155
+
156
+ corres_source = list()
157
+ for idx in range(68): corres_source.append(idx)
158
+ corres_target = list()
159
+ for idx in range(68): corres_target.append(idx)
160
+
161
+ # Gather the actual corresponding point coordinates from the index lists
162
+ corres_source_points = pcd_source
163
+ corres_target_points = pcd_target
164
+
165
+ corres = o3d.utility.Vector2iVector([[src, tgt] for src, tgt in zip(corres_source, corres_target)])
166
+
167
+ # Correspondence-based registration with RANSAC
168
+ reg_result = registration_ransac_based_on_correspondence(
169
+ pcd_source,
170
+ pcd_target,
171
+ corres,
172
+ estimation_method=o3d.pipelines.registration.TransformationEstimationPointToPoint(),
173
+ ransac_n=3,
174
+ criteria=o3d.pipelines.registration.RANSACConvergenceCriteria(max_iteration=100000, epsilon=1e-6)
175
+ )
176
+
177
+ # # Register with RANSAC
178
+ # convergence_criteria = o3d.pipelines.registration.RANSACConvergenceCriteria(max_iteration=50000, max_validation=500)
179
+ # ransac_result = o3d.pipelines.registration.registration_ransac_based_on_correspondence(
180
+ # pcd_source,
181
+ # pcd_target,
182
+ # corres,
183
+ # o3d.pipelines.registration.TransformationEstimationPointToPoint(),
184
+ # 3, # RANSAC threshold; adjust for the data
185
+ # convergence_criteria,
186
+ # [o3d.pipelines.registration.CorrespondenceCheckerBasedOnEdgeLength(0.9),
187
+ # o3d.pipelines.registration.CorrespondenceCheckerBasedOnDistance(0.05)],
188
+ # o3d.pipelines.registration.RANSACLoss())
189
+
190
+ # Apply the transform to the source mesh
191
+ # mesh_source_aligned = mesh_source.transform(reg_result.transformation)
192
+
193
+ a = 0
external/landmark_detection/data_processor/process_pcd.py ADDED
@@ -0,0 +1,250 @@
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import open3d as o3d
5
+ # import pyrender
6
+ # from pyrender import mesh, DirectionalLight, Material, PerspectiveCamera
7
+
8
+ os.environ['__GL_THREADED_OPTIMIZATIONS'] = '1'
9
+
10
+ cord_list = []
11
+ with open('./cord.txt', 'r') as f:
12
+ lines = f.readlines()
13
+ for line in lines:
14
+ m = line.split()
15
+ x = int(m[0])
16
+ y = int(m[1])
17
+
18
+ x = 1000 - x
19
+ y = 1000 - y
20
+
21
+ cord_list.append([x, y])
22
+
23
+
24
+ # Assumed path of the input TXT point files
25
+ output_folder = '/media/gyalex/Data/face_det_dataset/rgbd_data/rgbd'
26
+ if not os.path.exists(output_folder):
27
+ os.mkdir(output_folder)
28
+
29
+ for idx in range(32, 33):
30
+ txt_file_path = '/media/gyalex/Data/face_det_dataset/rgbd_data/PointImage'+ str(idx) + '.txt'
31
+ _, name = os.path.split(txt_file_path)
32
+ print(txt_file_path)
33
+
34
+ with open(txt_file_path, 'r') as file:
35
+ points = []
36
+ rgb_list = []
37
+ ori_rgb_list = []
38
+ normal_list = []
39
+
40
+ # Read the data line by line
41
+ for line in file:
42
+ # Strip the trailing newline and split the line
43
+ x, y, z, r, g, b, nx, ny, nz, w = line.split()
44
+ # Convert the strings to floats
45
+ x = float(x)
46
+ y = float(y)
47
+ z = float(z)
48
+ r = float(r)
49
+ g = float(g)
50
+ b = float(b)
51
+ nx = float(nx)
52
+ ny = float(ny)
53
+ nz = float(nz)
54
+ # Append the point to the lists
55
+ points.append((x, y, z))
56
+ rgb_list.append((r/255.0, g/255.0 , b/255.0))
57
+ normal_list.append((nx, ny, nz))
58
+
59
+ ori_r = int(r)
60
+ ori_g = int(g)
61
+ ori_b = int(b)
62
+ ori_rgb_list.append((ori_r, ori_g , ori_b))
63
+
64
+ np_points = np.asarray(points)
65
+
66
+ np_points_a = np_points
67
+
68
+ np_colors = np.asarray(rgb_list)
69
+ np_normals = np.asarray(normal_list)
70
+
71
+ np_colors_ori = np.asarray(ori_rgb_list)
72
+
73
+ pcd = o3d.geometry.PointCloud()
74
+ pcd.points = o3d.utility.Vector3dVector(np_points)
75
+ pcd.colors = o3d.utility.Vector3dVector(np_colors)
76
+ pcd.normals = o3d.utility.Vector3dVector(np_normals)
77
+
78
+ map_dict = {}
79
+
80
+ image = np.ones((1000, 1000, 3),dtype=np.uint8)*255
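+ # Splat every point (shifted by +400) into a 1000x1000 canvas and record in map_dict which point index produced each pixel.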
81
+ for i in range(np.array(pcd.points).shape[0]):
82
+ x = np.array(pcd.points)[i,0]+400
83
+ y = np.array(pcd.points)[i,1]+400
84
+
85
+ image[int(x),int(y),:] = (np.array(pcd.colors)[i,:]*255).astype(np.uint8)
86
+ image[int(x+1),int(y),:] = (np.array(pcd.colors)[i,:]*255).astype(np.uint8)
87
+ image[int(x),int(y+1),:] = (np.array(pcd.colors)[i,:]*255).astype(np.uint8)
88
+ image[int(x-1),int(y),:] = (np.array(pcd.colors)[i,:]*255).astype(np.uint8)
89
+ image[int(x),int(y-1),:] = (np.array(pcd.colors)[i,:]*255).astype(np.uint8)
90
+
91
+ map_dict[str(int(x)) + '_' + str(int(y))] = i
92
+ map_dict[str(int(x+1)) + '_' + str(int(y))] = i
93
+ map_dict[str(int(x)) + '_' + str(int(y+1))] = i
94
+ map_dict[str(int(x-1)) + '_' + str(int(y))] = i
95
+ map_dict[str(int(x)) + '_' + str(int(y-1))] = i
96
+
97
+ # if [int(y), int(x)] in cord_list:
98
+ # image[int(x),int(y),:] = np.array([0, 255, 0])
99
+
100
+ # if [int(y), int(x+1)] in cord_list:
101
+ # image[int(x+1),int(y),:] = np.array([0, 255, 0])
102
+
103
+ # if [int(y+1), int(x)] in cord_list:
104
+ # image[int(x),int(y+1),:] = np.array([0, 255, 0])
105
+
106
+ # if [int(y), int(x-1)] in cord_list:
107
+ # image[int(x-1),int(y),:] = np.array([0, 255, 0])
108
+
109
+ # if [int(y-1), int(x)] in cord_list:
110
+ # image[int(x),int(y-1),:] = np.array([0, 255, 0])
111
+
112
+ # if [int(y-1), int(x-1)] in cord_list:
113
+ # image[int(x-1),int(y-1),:] = np.array([0, 255, 0])
114
+
115
+ # if [int(y+1), int(x+1)] in cord_list:
116
+ # image[int(x+1),int(y+1),:] = np.array([0, 255, 0])
117
+
118
+ h_list = []
119
+ for m in cord_list:
120
+ a, b = m[0], m[1]
121
+ c = image[int(b),int(a),:][0]
122
+
123
+ flag = False
124
+
125
+ if image[int(b),int(a),:][1] != 255:
126
+ h_list.append(str(int(b))+'_'+str(int(a)))
127
+ flag = True
128
+ else:
129
+ if image[int(b)-2,int(a)-2,:][1] != 255:
130
+ h_list.append(str(int(b)-2)+'_'+str(int(a)-2))
131
+ flag = True
132
+ elif image[int(b)+2,int(a)+2,:][1] != 255:
133
+ h_list.append(str(int(b)+2)+'_'+str(int(a)+2))
134
+ flag = True
135
+ elif image[int(b),int(a)-3,:][1] != 255:
136
+ h_list.append(str(int(b))+'_'+str(int(a)-3))
137
+ flag = True
138
+
139
+ # if flag == False:
140
+ # cc = image[int(b),int(a),:][1]
141
+
142
+ # cv2.circle(image, (465,505), 2, (0, 255, 0), -1)
143
+
144
+ # cv2.imshow('win', image)
145
+ # cv2.waitKey(0)
146
+
147
+ with open('pid.txt', 'w') as f:
148
+ for h in h_list:
149
+ pid = map_dict[h]
150
+ s = str(pid) + '\n'
151
+ f.write(s)
152
+
153
+ np_colors[pid,:] = np.array([0, 255, 0])
154
+
155
+ f.close()
156
+
157
+ pcd0 = o3d.geometry.PointCloud()
158
+ pcd0.points = o3d.utility.Vector3dVector(np_points)
159
+ pcd0.colors = o3d.utility.Vector3dVector(np_colors)
160
+ pcd0.normals = o3d.utility.Vector3dVector(np_normals)
161
+
162
+ o3d.io.write_point_cloud('aa.ply', pcd0)
163
+
164
+
165
+ mm = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
166
+ image3 = cv2.flip(mm, -1)
167
+
168
+ # cv2.imwrite('./rgb.png', image3)
169
+
170
+ with open('./cord.txt', 'r') as f:
171
+ lines = f.readlines()
172
+ for line in lines:
173
+ m = line.split()
174
+ x = int(m[0])
175
+ y = int(m[1])
176
+
177
+ x = 1000 - x
178
+ y = 1000 - y
179
+
180
+ cv2.circle(image, (x,y), 2, (0, 255, 0), -1)
181
+
182
+ idx = map_dict[str(x)+'_'+str(y)]
183
+
184
+ a = 0
185
+
186
+ # cv2.imshow("win", image)
187
+ # cv2.waitKey(0)
188
+
189
+
190
+
191
+
192
+
193
+
194
+
195
+
196
+
197
+
198
+
199
+
200
+
201
+
202
+ # import matplotlib.pyplot as plt
203
+ # plt.imshow(image)
204
+ # plt.show()
205
+
206
+ # save_pcd_path = os.path.join(output_folder, name[:-3]+'ply')
207
+ # # o3d.io.write_point_cloud(save_pcd_path, pcd)
208
+
209
+ # # render
210
+ # import trimesh
211
+ # # fuze_trimesh = trimesh.load('/home/gyalex/Desktop/PointImage32.obj')
212
+ # # mesh = pyrender.Mesh.from_trimesh(fuze_trimesh)
213
+ # mesh = pyrender.Mesh.from_points(np_points, np_colors_ori, np_normals)
214
+
215
+ # import math
216
+ # camera = PerspectiveCamera(yfov=math.pi / 3, aspectRatio=1.0)
217
+ # camera_pose = np.array([[-1.0, 0.0, 0.0, 0], \
218
+ # [0.0, 1.0, 0.0, 0], \
219
+ # [0.0, 0.0, -1.0, 0], \
220
+ # [0.0, 0.0, 0.0, 1.0]])
221
+
222
+ # # Create the scene
223
+ # scene = pyrender.Scene()
224
+ # scene.add(mesh)
225
+ # scene.add(camera, pose=camera_pose)
226
+
227
+ # # light = pyrender.SpotLight(color=np.ones(3), intensity=3.0, innerConeAngle=np.pi/16.0, outerConeAngle=np.pi/6.0)
228
+ # # scene.add(light, pose=camera_pose)
229
+
230
+ # # Render the scene
231
+ # renderer = pyrender.OffscreenRenderer(viewport_width=1280, viewport_height=1024)
232
+ # color, depth = renderer.render(scene)
233
+
234
+ # # # Set up the scene and light source
235
+ # # scene = pyrender.Scene()
236
+ # # scene.add(point_cloud_mesh, 'point_cloud')
237
+ # # camera = PerspectiveCamera(yfov=45.0, aspectRatio=1.0)
238
+ # # scene.add(camera)
239
+
240
+ # # # Render the scene
241
+ # # renderer = pyrender.OffscreenRenderer(viewport_width=1280, viewport_height=1024)
242
+ # # color, depth = renderer.render(scene)
243
+
244
+ # # Save the rendered result as an image
245
+ # import cv2
246
+ # cv2.imshow('win', color)
247
+
248
+ # rgb_img = cv2.imread('/media/gyalex/Data/face_det_dataset/rgbd_data/color_32.bmp')
249
+ # cv2.imshow('win0', rgb_img)
250
+ # cv2.waitKey(0)
external/landmark_detection/evaluate.py ADDED
@@ -0,0 +1,258 @@
1
+ import os
2
+ import cv2
3
+ import math
4
+ import argparse
5
+ import numpy as np
6
+ from tqdm import tqdm
7
+
8
+ import torch
9
+
10
+ # private package
11
+ from lib import utility
12
+
13
+
14
+
15
+ class GetCropMatrix():
16
+ """
17
+ from_shape -> transform_matrix
18
+ """
19
+
20
+ def __init__(self, image_size, target_face_scale, align_corners=False):
21
+ self.image_size = image_size
22
+ self.target_face_scale = target_face_scale
23
+ self.align_corners = align_corners
24
+
25
+ def _compose_rotate_and_scale(self, angle, scale, shift_xy, from_center, to_center):
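+ # Compose a 3x3 similarity transform: rotate/scale about from_center by (angle, scale), then translate onto to_center plus shift_xy.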
26
+ cosv = math.cos(angle)
27
+ sinv = math.sin(angle)
28
+
29
+ fx, fy = from_center
30
+ tx, ty = to_center
31
+
32
+ acos = scale * cosv
33
+ asin = scale * sinv
34
+
35
+ a0 = acos
36
+ a1 = -asin
37
+ a2 = tx - acos * fx + asin * fy + shift_xy[0]
38
+
39
+ b0 = asin
40
+ b1 = acos
41
+ b2 = ty - asin * fx - acos * fy + shift_xy[1]
42
+
43
+ rot_scale_m = np.array([
44
+ [a0, a1, a2],
45
+ [b0, b1, b2],
46
+ [0.0, 0.0, 1.0]
47
+ ], np.float32)
48
+ return rot_scale_m
49
+
50
+ def process(self, scale, center_w, center_h):
51
+ if self.align_corners:
52
+ to_w, to_h = self.image_size - 1, self.image_size - 1
53
+ else:
54
+ to_w, to_h = self.image_size, self.image_size
55
+
56
+ rot_mu = 0
57
+ scale_mu = self.image_size / (scale * self.target_face_scale * 200.0)
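+ # scale * 200 px is the reference face-box size; scale_mu rescales that span to fill the image_size crop.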
58
+ shift_xy_mu = (0, 0)
59
+ matrix = self._compose_rotate_and_scale(
60
+ rot_mu, scale_mu, shift_xy_mu,
61
+ from_center=[center_w, center_h],
62
+ to_center=[to_w / 2.0, to_h / 2.0])
63
+ return matrix
64
+
65
+
66
+ class TransformPerspective():
67
+ """
68
+ image, matrix3x3 -> transformed_image
69
+ """
70
+
71
+ def __init__(self, image_size):
72
+ self.image_size = image_size
73
+
74
+ def process(self, image, matrix):
75
+ return cv2.warpPerspective(
76
+ image, matrix, dsize=(self.image_size, self.image_size),
77
+ flags=cv2.INTER_LINEAR, borderValue=0)
78
+
79
+
80
+ class TransformPoints2D():
81
+ """
82
+ points (nx2), matrix (3x3) -> points (nx2)
83
+ """
84
+
85
+ def process(self, srcPoints, matrix):
86
+ # nx3
87
+ desPoints = np.concatenate([srcPoints, np.ones_like(srcPoints[:, [0]])], axis=1)
88
+ desPoints = desPoints @ np.transpose(matrix) # nx3
89
+ desPoints = desPoints[:, :2] / desPoints[:, [2, 2]]
90
+ return desPoints.astype(srcPoints.dtype)
91
+
92
+
93
+ class Alignment:
94
+ def __init__(self, args, model_path, dl_framework, device_ids):
95
+ self.input_size = 256
96
+ self.target_face_scale = 1.0
97
+ self.dl_framework = dl_framework
98
+
99
+ # model
100
+ if self.dl_framework == "pytorch":
101
+ # conf
102
+ self.config = utility.get_config(args)
103
+ self.config.device_id = device_ids[0]
104
+ # set environment
105
+ utility.set_environment(self.config)
106
+ self.config.init_instance()
107
+ if self.config.logger is not None:
108
+ self.config.logger.info("Loaded configure file %s: %s" % (args.config_name, self.config.id))
109
+ self.config.logger.info("\n" + "\n".join(["%s: %s" % item for item in self.config.__dict__.items()]))
110
+
111
+ net = utility.get_net(self.config)
112
+ if device_ids == [-1]:
113
+ checkpoint = torch.load(model_path, map_location="cpu")
114
+ else:
115
+ checkpoint = torch.load(model_path)
116
+ net.load_state_dict(checkpoint["net"])
117
+ net = net.to(self.config.device_id)
118
+ net.eval()
119
+ self.alignment = net
120
+ else:
121
+ assert False
122
+
123
+ self.getCropMatrix = GetCropMatrix(image_size=self.input_size, target_face_scale=self.target_face_scale,
124
+ align_corners=True)
125
+ self.transformPerspective = TransformPerspective(image_size=self.input_size)
126
+ self.transformPoints2D = TransformPoints2D()
127
+
128
+ def norm_points(self, points, align_corners=False):
129
+ if align_corners:
130
+ # [0, SIZE-1] -> [-1, +1]
131
+ return points / torch.tensor([self.input_size - 1, self.input_size - 1]).to(points).view(1, 1, 2) * 2 - 1
132
+ else:
133
+ # [-0.5, SIZE-0.5] -> [-1, +1]
134
+ return (points * 2 + 1) / torch.tensor([self.input_size, self.input_size]).to(points).view(1, 1, 2) - 1
135
+
136
+ def denorm_points(self, points, align_corners=False):
137
+ if align_corners:
138
+ # [-1, +1] -> [0, SIZE-1]
139
+ return (points + 1) / 2 * torch.tensor([self.input_size - 1, self.input_size - 1]).to(points).view(1, 1, 2)
140
+ else:
141
+ # [-1, +1] -> [-0.5, SIZE-0.5]
142
+ return ((points + 1) * torch.tensor([self.input_size, self.input_size]).to(points).view(1, 1, 2) - 1) / 2
143
+
144
+ def preprocess(self, image, scale, center_w, center_h):
145
+ matrix = self.getCropMatrix.process(scale, center_w, center_h)
146
+ input_tensor = self.transformPerspective.process(image, matrix)
147
+ input_tensor = input_tensor[np.newaxis, :]
148
+
149
+ input_tensor = torch.from_numpy(input_tensor)
150
+ input_tensor = input_tensor.float().permute(0, 3, 1, 2)
151
+ input_tensor = input_tensor / 255.0 * 2.0 - 1.0
152
+ input_tensor = input_tensor.to(self.config.device_id)
153
+ return input_tensor, matrix
154
+
155
+ def postprocess(self, srcPoints, coeff):
156
+ # dstPoints = self.transformPoints2D.process(srcPoints, coeff)
157
+ # matrix^(-1) * src = dst
158
+ # src = matrix * dst
159
+ dstPoints = np.zeros(srcPoints.shape, dtype=np.float32)
160
+ for i in range(srcPoints.shape[0]):
161
+ dstPoints[i][0] = coeff[0][0] * srcPoints[i][0] + coeff[0][1] * srcPoints[i][1] + coeff[0][2]
162
+ dstPoints[i][1] = coeff[1][0] * srcPoints[i][0] + coeff[1][1] * srcPoints[i][1] + coeff[1][2]
163
+ return dstPoints
164
+
165
+ def analyze(self, image, scale, center_w, center_h):
166
+ input_tensor, matrix = self.preprocess(image, scale, center_w, center_h)
167
+
168
+ if self.dl_framework == "pytorch":
169
+ with torch.no_grad():
170
+ output = self.alignment(input_tensor)
171
+ landmarks = output[-1][0]
172
+ else:
173
+ assert False
174
+
175
+ landmarks = self.denorm_points(landmarks)
176
+ landmarks = landmarks.data.cpu().numpy()[0]
177
+ landmarks = self.postprocess(landmarks, np.linalg.inv(matrix))
178
+
179
+ return landmarks
180
+
181
+
182
+ def L2(p1, p2):
183
+ return np.linalg.norm(p1 - p2)
184
+
185
+
186
+ def NME(landmarks_gt, landmarks_pv):
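+ # Normalized Mean Error: mean landmark-to-landmark distance divided by the distance between the two reference eye landmarks (index pair depends on the 29/68/98-point convention).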
187
+ pts_num = landmarks_gt.shape[0]
188
+ if pts_num == 29:
189
+ left_index = 16
190
+ right_index = 17
191
+ elif pts_num == 68:
192
+ left_index = 36
193
+ right_index = 45
194
+ elif pts_num == 98:
195
+ left_index = 60
196
+ right_index = 72
197
+
198
+ nme = 0
199
+ eye_span = L2(landmarks_gt[left_index], landmarks_gt[right_index])
200
+ for i in range(pts_num):
201
+ error = L2(landmarks_pv[i], landmarks_gt[i])
202
+ nme += error / eye_span
203
+ nme /= pts_num
204
+ return nme
205
+
206
+
207
+ def evaluate(args, model_path, metadata_path, device_ids, mode):
208
+ alignment = Alignment(args, model_path, dl_framework="pytorch", device_ids=device_ids)
209
+ config = alignment.config
210
+ nme_sum = 0
211
+ with open(metadata_path, 'r') as f:
212
+ lines = f.readlines()
213
+ for k, line in enumerate(tqdm(lines)):
214
+ item = line.strip().split("\t")
215
+ image_name, landmarks_5pts, landmarks_gt, scale, center_w, center_h = item[:6]
216
+ # image & keypoints alignment
217
+ image_name = image_name.replace('\\', '/')
218
+ image_name = image_name.replace('//msr-facestore/Workspace/MSRA_EP_Allergan/users/yanghuan/training_data/wflw/rawImages/', '')
219
+ image_name = image_name.replace('./rawImages/', '')
220
+ image_path = os.path.join(config.image_dir, image_name)
221
+ landmarks_gt = np.array(list(map(float, landmarks_gt.split(","))), dtype=np.float32).reshape(-1, 2)
222
+ scale, center_w, center_h = float(scale), float(center_w), float(center_h)
223
+
224
+ image = cv2.imread(image_path)
225
+ landmarks_pv = alignment.analyze(image, scale, center_w, center_h)
226
+
227
+ # NME
228
+ if mode == "nme":
229
+ nme = NME(landmarks_gt, landmarks_pv)
230
+ nme_sum += nme
231
+ # print("Current NME(%d): %f" % (k + 1, (nme_sum / (k + 1))))
232
+ else:
233
+ pass
234
+
235
+ if mode == "nme":
236
+ print("Final NME: %f" % (100*nme_sum / (k + 1)))
237
+ else:
238
+ pass
239
+
240
+
241
+ if __name__ == "__main__":
242
+ parser = argparse.ArgumentParser(description="Evaluation script")
243
+ parser.add_argument("--config_name", type=str, default="alignment", help="set configure file name")
244
+ parser.add_argument("--model_path", type=str, default="./train.pkl", help="the path of model")
245
+ parser.add_argument("--data_definition", type=str, default='WFLW', help="COFW/300W/WFLW")
246
+ parser.add_argument("--metadata_path", type=str, default="", help="the path of metadata")
247
+ parser.add_argument("--image_dir", type=str, default="", help="the path of image")
248
+ parser.add_argument("--device_ids", type=str, default="0", help="set device ids, -1 means use cpu device, >= 0 means use gpu device")
249
+ parser.add_argument("--mode", type=str, default="nme", help="set the evaluate mode: nme")
250
+ args = parser.parse_args()
251
+
252
+ device_ids = list(map(int, args.device_ids.split(",")))
253
+ evaluate(
254
+ args,
255
+ model_path=args.model_path,
256
+ metadata_path=args.metadata_path,
257
+ device_ids=device_ids,
258
+ mode=args.mode)
external/landmark_detection/infer_folder.py ADDED
@@ -0,0 +1,253 @@
1
+ import cv2
2
+ import math
3
+ import copy
4
+ import numpy as np
5
+ import argparse
6
+ import torch
7
+ import json
8
+
9
+ # private package
10
+ from lib import utility
11
+ from FaceBoxesV2.faceboxes_detector import *
12
+
13
+ class GetCropMatrix():
14
+ """
15
+ from_shape -> transform_matrix
16
+ """
17
+
18
+ def __init__(self, image_size, target_face_scale, align_corners=False):
19
+ self.image_size = image_size
20
+ self.target_face_scale = target_face_scale
21
+ self.align_corners = align_corners
22
+
23
+ def _compose_rotate_and_scale(self, angle, scale, shift_xy, from_center, to_center):
24
+ cosv = math.cos(angle)
25
+ sinv = math.sin(angle)
26
+
27
+ fx, fy = from_center
28
+ tx, ty = to_center
29
+
30
+ acos = scale * cosv
31
+ asin = scale * sinv
32
+
33
+ a0 = acos
34
+ a1 = -asin
35
+ a2 = tx - acos * fx + asin * fy + shift_xy[0]
36
+
37
+ b0 = asin
38
+ b1 = acos
39
+ b2 = ty - asin * fx - acos * fy + shift_xy[1]
40
+
41
+ rot_scale_m = np.array([
42
+ [a0, a1, a2],
43
+ [b0, b1, b2],
44
+ [0.0, 0.0, 1.0]
45
+ ], np.float32)
46
+ return rot_scale_m
47
+
48
+ def process(self, scale, center_w, center_h):
49
+ if self.align_corners:
50
+ to_w, to_h = self.image_size - 1, self.image_size - 1
51
+ else:
52
+ to_w, to_h = self.image_size, self.image_size
53
+
54
+ rot_mu = 0
55
+ scale_mu = self.image_size / (scale * self.target_face_scale * 200.0)
56
+ shift_xy_mu = (0, 0)
57
+ matrix = self._compose_rotate_and_scale(
58
+ rot_mu, scale_mu, shift_xy_mu,
59
+ from_center=[center_w, center_h],
60
+ to_center=[to_w / 2.0, to_h / 2.0])
61
+ return matrix
62
+
63
+
64
+ class TransformPerspective():
65
+ """
66
+ image, matrix3x3 -> transformed_image
67
+ """
68
+
69
+ def __init__(self, image_size):
70
+ self.image_size = image_size
71
+
72
+ def process(self, image, matrix):
73
+ return cv2.warpPerspective(
74
+ image, matrix, dsize=(self.image_size, self.image_size),
75
+ flags=cv2.INTER_LINEAR, borderValue=0)
76
+
77
+
78
+ class TransformPoints2D():
79
+ """
80
+ points (nx2), matrix (3x3) -> points (nx2)
81
+ """
82
+
83
+ def process(self, srcPoints, matrix):
84
+ # nx3
85
+ desPoints = np.concatenate([srcPoints, np.ones_like(srcPoints[:, [0]])], axis=1)
86
+ desPoints = desPoints @ np.transpose(matrix) # nx3
87
+ desPoints = desPoints[:, :2] / desPoints[:, [2, 2]]
88
+ return desPoints.astype(srcPoints.dtype)
89
+
90
+ class Alignment:
91
+ def __init__(self, args, model_path, dl_framework, device_ids):
92
+ self.input_size = 256
93
+ self.target_face_scale = 1.0
94
+ self.dl_framework = dl_framework
95
+
96
+ # model
97
+ if self.dl_framework == "pytorch":
98
+ # conf
99
+ self.config = utility.get_config(args)
100
+ self.config.device_id = device_ids[0]
101
+ # set environment
102
+ utility.set_environment(self.config)
103
+ # self.config.init_instance()
104
+ # if self.config.logger is not None:
105
+ # self.config.logger.info("Loaded configure file %s: %s" % (args.config_name, self.config.id))
106
+ # self.config.logger.info("\n" + "\n".join(["%s: %s" % item for item in self.config.__dict__.items()]))
107
+
108
+ net = utility.get_net(self.config)
109
+ if device_ids == [-1]:
110
+ checkpoint = torch.load(model_path, map_location="cpu")
111
+ else:
112
+ checkpoint = torch.load(model_path)
113
+ net.load_state_dict(checkpoint["net"])
114
+
115
+ if self.config.device_id == -1:
116
+ net = net.cpu()
117
+ else:
118
+ net = net.to(self.config.device_id)
119
+
120
+ net.eval()
121
+ self.alignment = net
122
+ else:
123
+ assert False
124
+
125
+ self.getCropMatrix = GetCropMatrix(image_size=self.input_size, target_face_scale=self.target_face_scale,
126
+ align_corners=True)
127
+ self.transformPerspective = TransformPerspective(image_size=self.input_size)
128
+ self.transformPoints2D = TransformPoints2D()
129
+
130
+ def norm_points(self, points, align_corners=False):
131
+ if align_corners:
132
+ # [0, SIZE-1] -> [-1, +1]
133
+ return points / torch.tensor([self.input_size - 1, self.input_size - 1]).to(points).view(1, 1, 2) * 2 - 1
134
+ else:
135
+ # [-0.5, SIZE-0.5] -> [-1, +1]
136
+ return (points * 2 + 1) / torch.tensor([self.input_size, self.input_size]).to(points).view(1, 1, 2) - 1
137
+
138
+ def denorm_points(self, points, align_corners=False):
139
+ if align_corners:
140
+ # [-1, +1] -> [0, SIZE-1]
141
+ return (points + 1) / 2 * torch.tensor([self.input_size - 1, self.input_size - 1]).to(points).view(1, 1, 2)
142
+ else:
143
+ # [-1, +1] -> [-0.5, SIZE-0.5]
144
+ return ((points + 1) * torch.tensor([self.input_size, self.input_size]).to(points).view(1, 1, 2) - 1) / 2
145
+
146
+ def preprocess(self, image, scale, center_w, center_h):
147
+ matrix = self.getCropMatrix.process(scale, center_w, center_h)
148
+ input_tensor = self.transformPerspective.process(image, matrix)
149
+ input_tensor = input_tensor[np.newaxis, :]
150
+
151
+ input_tensor = torch.from_numpy(input_tensor)
152
+ input_tensor = input_tensor.float().permute(0, 3, 1, 2)
153
+ input_tensor = input_tensor / 255.0 * 2.0 - 1.0
154
+
155
+ if self.config.device_id == -1:
156
+ input_tensor = input_tensor.cpu()
157
+ else:
158
+ input_tensor = input_tensor.to(self.config.device_id)
159
+
160
+ return input_tensor, matrix
161
+
162
+ def postprocess(self, srcPoints, coeff):
163
+ # dstPoints = self.transformPoints2D.process(srcPoints, coeff)
164
+ # matrix^(-1) * src = dst
165
+ # src = matrix * dst
166
+ dstPoints = np.zeros(srcPoints.shape, dtype=np.float32)
167
+ for i in range(srcPoints.shape[0]):
168
+ dstPoints[i][0] = coeff[0][0] * srcPoints[i][0] + coeff[0][1] * srcPoints[i][1] + coeff[0][2]
169
+ dstPoints[i][1] = coeff[1][0] * srcPoints[i][0] + coeff[1][1] * srcPoints[i][1] + coeff[1][2]
170
+ return dstPoints
171
+
172
+ def analyze(self, image, scale, center_w, center_h):
173
+ input_tensor, matrix = self.preprocess(image, scale, center_w, center_h)
174
+
175
+ if self.dl_framework == "pytorch":
176
+ with torch.no_grad():
177
+ output = self.alignment(input_tensor)
178
+ landmarks = output[-1][0]
179
+ else:
180
+ assert False
181
+
182
+ landmarks = self.denorm_points(landmarks)
183
+ landmarks = landmarks.data.cpu().numpy()[0]
184
+ landmarks = self.postprocess(landmarks, np.linalg.inv(matrix))
185
+
186
+ return landmarks
187
+
188
+ if __name__ == '__main__':
189
+ parser = argparse.ArgumentParser(description="inference script")
190
+ parser.add_argument('--folder_path', type=str, help='Path to image folder')
191
+ args = parser.parse_args()
192
+
193
+ # args.folder_path = '/media/gyalex/Data/flame/ph_test/head_images/flame/image'
194
+
195
+ current_path = os.getcwd()
196
+
197
+ use_gpu = True
198
+ ########### face detection ############
199
+ if use_gpu:
200
+ device = torch.device("cuda:0")
201
+ else:
202
+ device = torch.device("cpu")
203
+
204
+ current_path = os.getcwd()
205
+ det_model_path = os.path.join(current_path, 'preprocess', 'submodules', 'Landmark_detection', 'FaceBoxesV2/weights/FaceBoxesV2.pth')
206
+ detector = FaceBoxesDetector('FaceBoxes', det_model_path, use_gpu, device)
207
+
208
+ ########### facial alignment ############
209
+ model_path = os.path.join(current_path, 'preprocess', 'submodules', 'Landmark_detection', 'weights/68_keypoints_model.pkl')
210
+
211
+ if use_gpu:
212
+ device_ids = [0]
213
+ else:
214
+ device_ids = [-1]
215
+
216
+ args.config_name = 'alignment'
217
+ alignment = Alignment(args, model_path, dl_framework="pytorch", device_ids=device_ids)
218
+
219
+ img_path_list = os.listdir(args.folder_path)
220
+ kpts_code = dict()
221
+
222
+ ########### inference ############
223
+ for file_name in img_path_list:
224
+ abs_path = os.path.join(args.folder_path, file_name)
225
+
226
+ image = cv2.imread(abs_path)
227
+ image_draw = copy.deepcopy(image)
228
+
229
+ detections, _ = detector.detect(image, 0.6, 1)
230
+ for idx in range(len(detections)):
231
+ x1_ori = detections[idx][2]
232
+ y1_ori = detections[idx][3]
233
+ x2_ori = x1_ori + detections[idx][4]
234
+ y2_ori = y1_ori + detections[idx][5]
235
+
236
+ scale = max(x2_ori - x1_ori, y2_ori - y1_ori) / 180
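+ # Convert the detector box into the (scale, center_w, center_h) triplet expected by Alignment.analyze; the longer box side is normalized by 180 px.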
237
+ center_w = (x1_ori + x2_ori) / 2
238
+ center_h = (y1_ori + y2_ori) / 2
239
+ scale, center_w, center_h = float(scale), float(center_w), float(center_h)
240
+
241
+ landmarks_pv = alignment.analyze(image, scale, center_w, center_h)
242
+ landmarks_pv_list = landmarks_pv.tolist()
243
+
244
+ for num in range(landmarks_pv.shape[0]):
245
+ cv2.circle(image_draw, (round(landmarks_pv[num][0]), round(landmarks_pv[num][1])),
246
+ 2, (0, 255, 0), -1)
247
+
248
+ kpts_code[file_name] = landmarks_pv_list
249
+ save_path = args.folder_path[:-5] + 'landmark'
250
+ cv2.imwrite(os.path.join(save_path, file_name), image_draw)
251
+
252
+ path = args.folder_path[:-5]
253
+ json.dump(kpts_code, open(os.path.join(path, 'keypoint.json'), 'w'))
external/landmark_detection/infer_image.py ADDED
@@ -0,0 +1,251 @@
1
+ import cv2
2
+ import math
3
+ import copy
4
+ import numpy as np
5
+ import argparse
6
+ import torch
7
+
8
+ # private package
9
+ from external.landmark_detection.lib import utility
10
+ from external.landmark_detection.FaceBoxesV2.faceboxes_detector import *
11
+
12
+ class GetCropMatrix():
13
+ """
14
+ from_shape -> transform_matrix
15
+ """
16
+
17
+ def __init__(self, image_size, target_face_scale, align_corners=False):
18
+ self.image_size = image_size
19
+ self.target_face_scale = target_face_scale
20
+ self.align_corners = align_corners
21
+
22
+ def _compose_rotate_and_scale(self, angle, scale, shift_xy, from_center, to_center):
23
+ cosv = math.cos(angle)
24
+ sinv = math.sin(angle)
25
+
26
+ fx, fy = from_center
27
+ tx, ty = to_center
28
+
29
+ acos = scale * cosv
30
+ asin = scale * sinv
31
+
32
+ a0 = acos
33
+ a1 = -asin
34
+ a2 = tx - acos * fx + asin * fy + shift_xy[0]
35
+
36
+ b0 = asin
37
+ b1 = acos
38
+ b2 = ty - asin * fx - acos * fy + shift_xy[1]
39
+
40
+ rot_scale_m = np.array([
41
+ [a0, a1, a2],
42
+ [b0, b1, b2],
43
+ [0.0, 0.0, 1.0]
44
+ ], np.float32)
45
+ return rot_scale_m
46
+
47
+ def process(self, scale, center_w, center_h):
48
+ if self.align_corners:
49
+ to_w, to_h = self.image_size - 1, self.image_size - 1
50
+ else:
51
+ to_w, to_h = self.image_size, self.image_size
52
+
53
+ rot_mu = 0
54
+ scale_mu = self.image_size / (scale * self.target_face_scale * 200.0)
55
+ shift_xy_mu = (0, 0)
56
+ matrix = self._compose_rotate_and_scale(
57
+ rot_mu, scale_mu, shift_xy_mu,
58
+ from_center=[center_w, center_h],
59
+ to_center=[to_w / 2.0, to_h / 2.0])
60
+ return matrix
61
+
62
+
63
+ class TransformPerspective():
64
+ """
65
+ image, matrix3x3 -> transformed_image
66
+ """
67
+
68
+ def __init__(self, image_size):
69
+ self.image_size = image_size
70
+
71
+ def process(self, image, matrix):
72
+ return cv2.warpPerspective(
73
+ image, matrix, dsize=(self.image_size, self.image_size),
74
+ flags=cv2.INTER_LINEAR, borderValue=0)
75
+
76
+
77
+ class TransformPoints2D():
78
+ """
79
+ points (nx2), matrix (3x3) -> points (nx2)
80
+ """
81
+
82
+ def process(self, srcPoints, matrix):
83
+ # nx3
84
+ desPoints = np.concatenate([srcPoints, np.ones_like(srcPoints[:, [0]])], axis=1)
85
+ desPoints = desPoints @ np.transpose(matrix) # nx3
86
+ desPoints = desPoints[:, :2] / desPoints[:, [2, 2]]
87
+ return desPoints.astype(srcPoints.dtype)
88
+
89
+ class Alignment:
90
+ def __init__(self, args, model_path, dl_framework, device_ids):
91
+ self.input_size = 256
92
+ self.target_face_scale = 1.0
93
+ self.dl_framework = dl_framework
94
+
95
+ # model
96
+ if self.dl_framework == "pytorch":
97
+ # conf
98
+ self.config = utility.get_config(args)
99
+ self.config.device_id = device_ids[0]
100
+ # set environment
101
+ # utility.set_environment(self.config)
102
+ # self.config.init_instance()
103
+ # if self.config.logger is not None:
104
+ # self.config.logger.info("Loaded configure file %s: %s" % (args.config_name, self.config.id))
105
+ # self.config.logger.info("\n" + "\n".join(["%s: %s" % item for item in self.config.__dict__.items()]))
106
+
107
+ net = utility.get_net(self.config)
108
+ if device_ids == [-1]:
109
+ checkpoint = torch.load(model_path, map_location="cpu")
110
+ else:
111
+ checkpoint = torch.load(model_path)
112
+ net.load_state_dict(checkpoint["net"])
113
+
114
+ if self.config.device_id == -1:
115
+ net = net.cpu()
116
+ else:
117
+ net = net.to(self.config.device_id)
118
+
119
+ net.eval()
120
+ self.alignment = net
121
+ else:
122
+ assert False
123
+
124
+ self.getCropMatrix = GetCropMatrix(image_size=self.input_size, target_face_scale=self.target_face_scale,
125
+ align_corners=True)
126
+ self.transformPerspective = TransformPerspective(image_size=self.input_size)
127
+ self.transformPoints2D = TransformPoints2D()
128
+
129
+ def norm_points(self, points, align_corners=False):
130
+ if align_corners:
131
+ # [0, SIZE-1] -> [-1, +1]
132
+ return points / torch.tensor([self.input_size - 1, self.input_size - 1]).to(points).view(1, 1, 2) * 2 - 1
133
+ else:
134
+ # [-0.5, SIZE-0.5] -> [-1, +1]
135
+ return (points * 2 + 1) / torch.tensor([self.input_size, self.input_size]).to(points).view(1, 1, 2) - 1
136
+
137
+ def denorm_points(self, points, align_corners=False):
138
+ if align_corners:
139
+ # [-1, +1] -> [0, SIZE-1]
140
+ return (points + 1) / 2 * torch.tensor([self.input_size - 1, self.input_size - 1]).to(points).view(1, 1, 2)
141
+ else:
142
+ # [-1, +1] -> [-0.5, SIZE-0.5]
143
+ return ((points + 1) * torch.tensor([self.input_size, self.input_size]).to(points).view(1, 1, 2) - 1) / 2
144
+
145
+ def preprocess(self, image, scale, center_w, center_h):
146
+ matrix = self.getCropMatrix.process(scale, center_w, center_h)
147
+ input_tensor = self.transformPerspective.process(image, matrix)
148
+ input_tensor = input_tensor[np.newaxis, :]
149
+
150
+ input_tensor = torch.from_numpy(input_tensor)
151
+ input_tensor = input_tensor.float().permute(0, 3, 1, 2)
152
+ input_tensor = input_tensor / 255.0 * 2.0 - 1.0
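+ # Normalize pixels from [0, 255] to [-1, 1] before feeding the network.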
153
+
154
+ if self.config.device_id == -1:
155
+ input_tensor = input_tensor.cpu()
156
+ else:
157
+ input_tensor = input_tensor.to(self.config.device_id)
158
+
159
+ return input_tensor, matrix
160
+
161
+ def postprocess(self, srcPoints, coeff):
162
+ # dstPoints = self.transformPoints2D.process(srcPoints, coeff)
163
+ # matrix^(-1) * src = dst
164
+ # src = matrix * dst
165
+ dstPoints = np.zeros(srcPoints.shape, dtype=np.float32)
166
+ for i in range(srcPoints.shape[0]):
167
+ dstPoints[i][0] = coeff[0][0] * srcPoints[i][0] + coeff[0][1] * srcPoints[i][1] + coeff[0][2]
168
+ dstPoints[i][1] = coeff[1][0] * srcPoints[i][0] + coeff[1][1] * srcPoints[i][1] + coeff[1][2]
169
+ return dstPoints
170
+
171
+ def analyze(self, image, scale, center_w, center_h):
172
+ input_tensor, matrix = self.preprocess(image, scale, center_w, center_h)
173
+
174
+ if self.dl_framework == "pytorch":
175
+ with torch.no_grad():
176
+ output = self.alignment(input_tensor)
177
+ landmarks = output[-1][0]
178
+ else:
179
+ assert False
180
+
181
+ landmarks = self.denorm_points(landmarks)
182
+ landmarks = landmarks.data.cpu().numpy()[0]
183
+ landmarks = self.postprocess(landmarks, np.linalg.inv(matrix))
184
+
185
+ return landmarks
186
+
187
+ # parser = argparse.ArgumentParser(description="Evaluation script")
188
+ # args = parser.parse_args()
189
+ # image_path = './rgb.png'
190
+ # image = cv2.imread(image_path)
191
+ #
192
+ # use_gpu = False
193
+ # ########### face detection ############
194
+ # if use_gpu:
195
+ # device = torch.device("cuda:0")
196
+ # else:
197
+ # device = torch.device("cpu")
198
+ #
199
+ # detector = FaceBoxesDetector('FaceBoxes', 'FaceBoxesV2/weights/FaceBoxesV2.pth', use_gpu, device)
200
+ #
201
+ # ########### facial alignment ############
202
+ # model_path = './weights/68_keypoints_model.pkl'
203
+ #
204
+ # if use_gpu:
205
+ # device_ids = [0]
206
+ # else:
207
+ # device_ids = [-1]
208
+ #
209
+ # args.config_name = 'alignment'
210
+ # alignment = Alignment(args, model_path, dl_framework="pytorch", device_ids=device_ids)
211
+ # image_draw = copy.deepcopy(image)
212
+ #
213
+ # ########### inference ############
214
+ # ldk_list = []
215
+ #
216
+ # detections, _ = detector.detect(image, 0.9, 1)
217
+ # for idx in range(len(detections)):
218
+ # x1_ori = detections[idx][2]
219
+ # y1_ori = detections[idx][3]
220
+ # x2_ori = x1_ori + detections[idx][4]
221
+ # y2_ori = y1_ori + detections[idx][5]
222
+ #
223
+ # scale = max(x2_ori - x1_ori, y2_ori - y1_ori) / 180
224
+ # center_w = (x1_ori + x2_ori) / 2
225
+ # center_h = (y1_ori + y2_ori) / 2
226
+ # scale, center_w, center_h = float(scale), float(center_w), float(center_h)
227
+ #
228
+ # landmarks_pv = alignment.analyze(image, scale, center_w, center_h)
229
+ #
230
+ # for num in range(landmarks_pv.shape[0]):
231
+ # cv2.circle(image_draw, (round(landmarks_pv[num][0]), round(landmarks_pv[num][1])),
232
+ # 2, (0, 255, 0), -1)
233
+ #
234
+ # ldk_list.append([round(landmarks_pv[num][0]), round(landmarks_pv[num][1])])
235
+ #
236
+ # cv2.imshow("win", image_draw)
237
+ #
238
+ # # ldk_img = cv2.imread('/home/gyalex/Desktop/image_landmark_149/all.jpg')
239
+ # # cv2.imshow("win1", ldk_img)
240
+ #
241
+ # cv2.waitKey(0)
242
+ #
243
+ # with open('./cord.txt', 'w') as f:
244
+ # for num in range(len(ldk_list)):
245
+ # s = str(ldk_list[num][0]) + ' ' + str(ldk_list[num][1]) + '\n'
246
+ # f.write(s)
247
+ #
248
+ # f.close()
249
+
250
+
251
+
external/landmark_detection/infer_video.py ADDED
@@ -0,0 +1,287 @@
1
+ import cv2
2
+ import math
3
+ import copy
4
+ import numpy as np
5
+ import argparse
6
+ import torch
7
+ import json
8
+
9
+ # private package
10
+ from lib import utility
11
+ from FaceBoxesV2.faceboxes_detector import *
12
+
13
+ class GetCropMatrix():
14
+ """
15
+ from_shape -> transform_matrix
16
+ """
17
+
18
+ def __init__(self, image_size, target_face_scale, align_corners=False):
19
+ self.image_size = image_size
20
+ self.target_face_scale = target_face_scale
21
+ self.align_corners = align_corners
22
+
23
+ def _compose_rotate_and_scale(self, angle, scale, shift_xy, from_center, to_center):
24
+ cosv = math.cos(angle)
25
+ sinv = math.sin(angle)
26
+
27
+ fx, fy = from_center
28
+ tx, ty = to_center
29
+
30
+ acos = scale * cosv
31
+ asin = scale * sinv
32
+
33
+ a0 = acos
34
+ a1 = -asin
35
+ a2 = tx - acos * fx + asin * fy + shift_xy[0]
36
+
37
+ b0 = asin
38
+ b1 = acos
39
+ b2 = ty - asin * fx - acos * fy + shift_xy[1]
40
+
41
+ rot_scale_m = np.array([
42
+ [a0, a1, a2],
43
+ [b0, b1, b2],
44
+ [0.0, 0.0, 1.0]
45
+ ], np.float32)
46
+ return rot_scale_m
47
+
48
+ def process(self, scale, center_w, center_h):
49
+ if self.align_corners:
50
+ to_w, to_h = self.image_size - 1, self.image_size - 1
51
+ else:
52
+ to_w, to_h = self.image_size, self.image_size
53
+
54
+ rot_mu = 0
55
+ scale_mu = self.image_size / (scale * self.target_face_scale * 200.0)
56
+ shift_xy_mu = (0, 0)
57
+ matrix = self._compose_rotate_and_scale(
58
+ rot_mu, scale_mu, shift_xy_mu,
59
+ from_center=[center_w, center_h],
60
+ to_center=[to_w / 2.0, to_h / 2.0])
61
+ return matrix
62
+
63
+
64
+ class TransformPerspective():
65
+ """
66
+ image, matrix3x3 -> transformed_image
67
+ """
68
+
69
+ def __init__(self, image_size):
70
+ self.image_size = image_size
71
+
72
+ def process(self, image, matrix):
73
+ return cv2.warpPerspective(
74
+ image, matrix, dsize=(self.image_size, self.image_size),
75
+ flags=cv2.INTER_LINEAR, borderValue=0)
76
+
77
+
78
+ class TransformPoints2D():
79
+ """
80
+ points (nx2), matrix (3x3) -> points (nx2)
81
+ """
82
+
83
+ def process(self, srcPoints, matrix):
84
+ # nx3
85
+ desPoints = np.concatenate([srcPoints, np.ones_like(srcPoints[:, [0]])], axis=1)
86
+ desPoints = desPoints @ np.transpose(matrix) # nx3
87
+ desPoints = desPoints[:, :2] / desPoints[:, [2, 2]]
88
+ return desPoints.astype(srcPoints.dtype)
89
+
90
+ class Alignment:
91
+ def __init__(self, args, model_path, dl_framework, device_ids):
92
+ self.input_size = 256
93
+ self.target_face_scale = 1.0
94
+ self.dl_framework = dl_framework
95
+
96
+ # model
97
+ if self.dl_framework == "pytorch":
98
+ # conf
99
+ self.config = utility.get_config(args)
100
+ self.config.device_id = device_ids[0]
101
+ # set environment
102
+ utility.set_environment(self.config)
103
+ # self.config.init_instance()
104
+ # if self.config.logger is not None:
105
+ # self.config.logger.info("Loaded configure file %s: %s" % (args.config_name, self.config.id))
106
+ # self.config.logger.info("\n" + "\n".join(["%s: %s" % item for item in self.config.__dict__.items()]))
107
+
108
+ net = utility.get_net(self.config)
109
+ if device_ids == [-1]:
110
+ checkpoint = torch.load(model_path, map_location="cpu")
111
+ else:
112
+ checkpoint = torch.load(model_path)
113
+ net.load_state_dict(checkpoint["net"])
114
+
115
+ if self.config.device_id == -1:
116
+ net = net.cpu()
117
+ else:
118
+ net = net.to(self.config.device_id)
119
+
120
+ net.eval()
121
+ self.alignment = net
122
+ else:
123
+ assert False
124
+
125
+ self.getCropMatrix = GetCropMatrix(image_size=self.input_size, target_face_scale=self.target_face_scale,
126
+ align_corners=True)
127
+ self.transformPerspective = TransformPerspective(image_size=self.input_size)
128
+ self.transformPoints2D = TransformPoints2D()
129
+
130
+ def norm_points(self, points, align_corners=False):
131
+ if align_corners:
132
+ # [0, SIZE-1] -> [-1, +1]
133
+ return points / torch.tensor([self.input_size - 1, self.input_size - 1]).to(points).view(1, 1, 2) * 2 - 1
134
+ else:
135
+ # [-0.5, SIZE-0.5] -> [-1, +1]
136
+ return (points * 2 + 1) / torch.tensor([self.input_size, self.input_size]).to(points).view(1, 1, 2) - 1
137
+
138
+ def denorm_points(self, points, align_corners=False):
139
+ if align_corners:
140
+ # [-1, +1] -> [0, SIZE-1]
141
+ return (points + 1) / 2 * torch.tensor([self.input_size - 1, self.input_size - 1]).to(points).view(1, 1, 2)
142
+ else:
143
+ # [-1, +1] -> [-0.5, SIZE-0.5]
144
+ return ((points + 1) * torch.tensor([self.input_size, self.input_size]).to(points).view(1, 1, 2) - 1) / 2
145
+
146
+ def preprocess(self, image, scale, center_w, center_h):
147
+ matrix = self.getCropMatrix.process(scale, center_w, center_h)
148
+ input_tensor = self.transformPerspective.process(image, matrix)
149
+ input_tensor = input_tensor[np.newaxis, :]
150
+
151
+ input_tensor = torch.from_numpy(input_tensor)
152
+ input_tensor = input_tensor.float().permute(0, 3, 1, 2)
153
+ input_tensor = input_tensor / 255.0 * 2.0 - 1.0
154
+
155
+ if self.config.device_id == -1:
156
+ input_tensor = input_tensor.cpu()
157
+ else:
158
+ input_tensor = input_tensor.to(self.config.device_id)
159
+
160
+ return input_tensor, matrix
161
+
162
+ def postprocess(self, srcPoints, coeff):
163
+ # dstPoints = self.transformPoints2D.process(srcPoints, coeff)
164
+ # matrix^(-1) * src = dst
165
+ # src = matrix * dst
166
+ dstPoints = np.zeros(srcPoints.shape, dtype=np.float32)
167
+ for i in range(srcPoints.shape[0]):
168
+ dstPoints[i][0] = coeff[0][0] * srcPoints[i][0] + coeff[0][1] * srcPoints[i][1] + coeff[0][2]
169
+ dstPoints[i][1] = coeff[1][0] * srcPoints[i][0] + coeff[1][1] * srcPoints[i][1] + coeff[1][2]
170
+ return dstPoints
171
+
172
+ def analyze(self, image, scale, center_w, center_h):
173
+ input_tensor, matrix = self.preprocess(image, scale, center_w, center_h)
174
+
175
+ if self.dl_framework == "pytorch":
176
+ with torch.no_grad():
177
+ output = self.alignment(input_tensor)
178
+ landmarks = output[-1][0]
179
+ else:
180
+ assert False
181
+
182
+ landmarks = self.denorm_points(landmarks)
183
+ landmarks = landmarks.data.cpu().numpy()[0]
184
+ landmarks = self.postprocess(landmarks, np.linalg.inv(matrix))
185
+
186
+ return landmarks
187
+
188
+ if __name__ == '__main__':
189
+ parser = argparse.ArgumentParser(description="inference script")
190
+ parser.add_argument('--video_path', type=str, help='Path to videos',default='/media/yuanzhen/HH/DATASET/VFTH/TESTVIDEO/Clip+7CzHzeeVRlE+P0+C0+F101007-101139.mp4')
191
+ args = parser.parse_args()
192
+
193
+ # args.video_path = '/media/gyalex/Data/flame/ph_test/test.mp4'
194
+
195
+ current_path = os.getcwd()
196
+
197
+ use_gpu = True
198
+ ########### face detection ############
199
+ if use_gpu:
200
+ device = torch.device("cuda:0")
201
+ else:
202
+ device = torch.device("cpu")
203
+
204
+ current_path = os.getcwd()
205
+ det_model_path = '/home/yuanzhen/code/landmark_detection/FaceBoxesV2/weights/FaceBoxesV2.pth'
206
+ detector = FaceBoxesDetector('FaceBoxes', det_model_path, use_gpu, device)
207
+
208
+ ########### facial alignment ############
209
+ model_path = '/home/yuanzhen/code/landmark_detection/weights/68_keypoints_model.pkl'
210
+
211
+ if use_gpu:
212
+ device_ids = [0]
213
+ else:
214
+ device_ids = [-1]
215
+
216
+ args.config_name = 'alignment'
217
+ alignment = Alignment(args, model_path, dl_framework="pytorch", device_ids=device_ids)
218
+
219
+ video_file = args.video_path
220
+ cap = cv2.VideoCapture(video_file)
221
+ frame_width = int(cap.get(3))
222
+ frame_height = int(cap.get(4))
223
+
224
+ # out_video_file = './output_video.mp4'
225
+ # fps = 30
226
+ # size = (frame_width, frame_height)
227
+ # out = cv2.VideoWriter(out_video_file, cv2.VideoWriter_fourcc(*'mp4v'), fps, size)
228
+
229
+ count = 0
230
+ kpts_code = dict()
231
+
232
+ keypoint_data_path = args.video_path.replace('.mp4','.json')
233
+ with open(keypoint_data_path,'r') as f:
234
+ keypoint_data = json.load(f)
235
+
236
+ ########### inference ############
237
+ path = video_file[:-4]
238
+ while(cap.isOpened()):
239
+ ret, image = cap.read()
240
+
241
+ if ret:
242
+ detections, _ = detector.detect(image, 0.8, 1)
243
+ image_draw = copy.deepcopy(image)
244
+
245
+ cv2.imwrite(os.path.join(path, 'image', str(count+1)+'.png'), image_draw)
246
+
247
+ for idx in range(len(detections)):
248
+ x1_ori = detections[idx][2]
249
+ y1_ori = detections[idx][3]
250
+ x2_ori = x1_ori + detections[idx][4]
251
+ y2_ori = y1_ori + detections[idx][5]
252
+
253
+ scale = max(x2_ori - x1_ori, y2_ori - y1_ori) / 180
254
+ center_w = (x1_ori + x2_ori) / 2
255
+ center_h = (y1_ori + y2_ori) / 2
256
+ scale, center_w, center_h = float(scale), float(center_w), float(center_h)
257
+
258
+ # landmarks_pv = alignment.analyze(image, scale, center_w, center_h)
259
+ landmarks_pv = np.array(keypoint_data[str(count+1)+'.png'])
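+ # Landmarks are read from the precomputed keypoint JSON; the direct alignment.analyze call above is left commented out.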
260
+
261
+ landmarks_pv_list = landmarks_pv.tolist()
262
+
263
+ for num in range(landmarks_pv.shape[0]):
264
+ cv2.circle(image_draw, (round(landmarks_pv[num][0]), round(landmarks_pv[num][1])),
265
+ 2, (0, 255, 0), -1)
266
+ cv2.putText(image_draw, str(num),
267
+ (round(landmarks_pv[num][0]) + 5, round(landmarks_pv[num][1]) + 5), # text position
268
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1, cv2.LINE_AA)
269
+
270
+ kpts_code[str(count+1)+'.png'] = landmarks_pv_list
271
+ cv2.imwrite(os.path.join(path, 'landmark', str(count+1)+'.png'), image_draw)
272
+ else:
273
+ break
274
+
275
+ count += 1
276
+
277
+ cap.release()
278
+ # out.release()
279
+ # cv2.destroyAllWindows()
280
+
281
+ path = video_file[:-4]
282
+ json.dump(kpts_code, open(os.path.join(path, 'keypoint.json'), 'w'))
283
+
284
+ print(path)
285
+
286
+
287
+
external/landmark_detection/lib/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ from .dataset import get_encoder, get_decoder
2
+ from .dataset import AlignmentDataset, Augmentation
3
+ from .backbone import StackedHGNetV1
4
+ from .metric import NME, Accuracy
5
+ from .utils import time_print, time_string, time_for_file, time_string_short
6
+ from .utils import convert_secs2time, convert_size2str
7
+
8
+ from .utility import get_dataloader, get_config, get_net, get_criterions
9
+ from .utility import get_optimizer, get_scheduler
external/landmark_detection/lib/backbone/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .stackedHGNetV1 import StackedHGNetV1
2
+
3
+ __all__ = [
4
+ "StackedHGNetV1",
5
+ ]
external/landmark_detection/lib/backbone/core/coord_conv.py ADDED
@@ -0,0 +1,157 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class AddCoordsTh(nn.Module):
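+ # Appends normalized x/y coordinate channels (optionally a radius channel and boundary-masked coordinates) to the feature map, as in CoordConv.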
6
+ def __init__(self, x_dim, y_dim, with_r=False, with_boundary=False):
7
+ super(AddCoordsTh, self).__init__()
8
+ self.x_dim = x_dim
9
+ self.y_dim = y_dim
10
+ self.with_r = with_r
11
+ self.with_boundary = with_boundary
12
+
13
+ def forward(self, input_tensor, heatmap=None):
14
+ """
15
+ input_tensor: (batch, c, x_dim, y_dim)
16
+ """
17
+ batch_size_tensor = input_tensor.shape[0]
18
+
19
+ xx_ones = torch.ones([1, self.y_dim], dtype=torch.int32).to(input_tensor)
20
+ xx_ones = xx_ones.unsqueeze(-1)
21
+
22
+ xx_range = torch.arange(self.x_dim, dtype=torch.int32).unsqueeze(0).to(input_tensor)
23
+ xx_range = xx_range.unsqueeze(1)
24
+
25
+ xx_channel = torch.matmul(xx_ones.float(), xx_range.float())
26
+ xx_channel = xx_channel.unsqueeze(-1)
27
+
28
+ yy_ones = torch.ones([1, self.x_dim], dtype=torch.int32).to(input_tensor)
29
+ yy_ones = yy_ones.unsqueeze(1)
30
+
31
+ yy_range = torch.arange(self.y_dim, dtype=torch.int32).unsqueeze(0).to(input_tensor)
32
+ yy_range = yy_range.unsqueeze(-1)
33
+
34
+ yy_channel = torch.matmul(yy_range.float(), yy_ones.float())
35
+ yy_channel = yy_channel.unsqueeze(-1)
36
+
37
+ xx_channel = xx_channel.permute(0, 3, 2, 1)
38
+ yy_channel = yy_channel.permute(0, 3, 2, 1)
39
+
40
+ xx_channel = xx_channel / (self.x_dim - 1)
41
+ yy_channel = yy_channel / (self.y_dim - 1)
42
+
43
+ xx_channel = xx_channel * 2 - 1
44
+ yy_channel = yy_channel * 2 - 1
45
+
46
+ xx_channel = xx_channel.repeat(batch_size_tensor, 1, 1, 1)
47
+ yy_channel = yy_channel.repeat(batch_size_tensor, 1, 1, 1)
48
+
49
+ if self.with_boundary and type(heatmap) != type(None):
50
+ boundary_channel = torch.clamp(heatmap[:, -1:, :, :],
51
+ 0.0, 1.0)
52
+
53
+ zero_tensor = torch.zeros_like(xx_channel).to(xx_channel)
54
+ xx_boundary_channel = torch.where(boundary_channel>0.05,
55
+ xx_channel, zero_tensor)
56
+ yy_boundary_channel = torch.where(boundary_channel>0.05,
57
+ yy_channel, zero_tensor)
58
+ ret = torch.cat([input_tensor, xx_channel, yy_channel], dim=1)
59
+
60
+
61
+ if self.with_r:
62
+ rr = torch.sqrt(torch.pow(xx_channel, 2) + torch.pow(yy_channel, 2))
63
+ rr = rr / torch.max(rr)
64
+ ret = torch.cat([ret, rr], dim=1)
65
+
66
+ if self.with_boundary and type(heatmap) != type(None):
67
+ ret = torch.cat([ret, xx_boundary_channel,
68
+ yy_boundary_channel], dim=1)
69
+ return ret
70
+
71
+
72
+ class CoordConvTh(nn.Module):
73
+ """CoordConv layer as in the paper."""
74
+ def __init__(self, x_dim, y_dim, with_r, with_boundary,
75
+ in_channels, out_channels, first_one=False, relu=False, bn=False, *args, **kwargs):
76
+ super(CoordConvTh, self).__init__()
77
+ self.addcoords = AddCoordsTh(x_dim=x_dim, y_dim=y_dim, with_r=with_r,
78
+ with_boundary=with_boundary)
79
+ in_channels += 2
80
+ if with_r:
81
+ in_channels += 1
82
+ if with_boundary and not first_one:
83
+ in_channels += 2
84
+ self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, *args, **kwargs)
85
+ self.relu = nn.ReLU() if relu else None
86
+ self.bn = nn.BatchNorm2d(out_channels) if bn else None
87
+
88
+ self.with_boundary = with_boundary
89
+ self.first_one = first_one
90
+
91
+
92
+ def forward(self, input_tensor, heatmap=None):
93
+ assert (self.with_boundary and not self.first_one) == (heatmap is not None)
94
+ ret = self.addcoords(input_tensor, heatmap)
95
+ ret = self.conv(ret)
96
+ if self.bn is not None:
97
+ ret = self.bn(ret)
98
+ if self.relu is not None:
99
+ ret = self.relu(ret)
100
+
101
+ return ret
102
+
103
+
104
+ '''
105
+ An alternative implementation for PyTorch with auto-infering the x-y dimensions.
106
+ '''
107
+ class AddCoords(nn.Module):
108
+
109
+ def __init__(self, with_r=False):
110
+ super().__init__()
111
+ self.with_r = with_r
112
+
113
+ def forward(self, input_tensor):
114
+ """
115
+ Args:
116
+ input_tensor: shape(batch, channel, x_dim, y_dim)
117
+ """
118
+ batch_size, _, x_dim, y_dim = input_tensor.size()
119
+
120
+ xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1).to(input_tensor)
121
+ yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2).to(input_tensor)
122
+
123
+ xx_channel = xx_channel / (x_dim - 1)
124
+ yy_channel = yy_channel / (y_dim - 1)
125
+
126
+ xx_channel = xx_channel * 2 - 1
127
+ yy_channel = yy_channel * 2 - 1
128
+
129
+ xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)
130
+ yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)
131
+
132
+ ret = torch.cat([
133
+ input_tensor,
134
+ xx_channel.type_as(input_tensor),
135
+ yy_channel.type_as(input_tensor)], dim=1)
136
+
137
+ if self.with_r:
138
+ rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2))
139
+ ret = torch.cat([ret, rr], dim=1)
140
+
141
+ return ret
142
+
143
+
144
+ class CoordConv(nn.Module):
145
+
146
+ def __init__(self, in_channels, out_channels, with_r=False, **kwargs):
147
+ super().__init__()
148
+ self.addcoords = AddCoords(with_r=with_r)
149
+ in_channels += 2
150
+ if with_r:
151
+ in_channels += 1
152
+ self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
153
+
154
+ def forward(self, x):
155
+ ret = self.addcoords(x)
156
+ ret = self.conv(ret)
157
+ return ret
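
Note: the core idea of both CoordConv variants above is simply to concatenate normalized x/y coordinate channels (and optionally a radius channel) to the feature map before a convolution. The standalone sketch below illustrates that idea with plain PyTorch; the tensor shapes and channel counts are illustrative, not taken from the repository.

```python
# Minimal sketch of the coordinate-channel idea: append normalized x/y
# (and an optional radius) channels, then convolve. Shapes are illustrative.
import torch
import torch.nn as nn

def add_coords(feat: torch.Tensor, with_r: bool = False) -> torch.Tensor:
    b, _, h, w = feat.shape
    yy = torch.linspace(-1, 1, h, device=feat.device).view(1, 1, h, 1).expand(b, 1, h, w)
    xx = torch.linspace(-1, 1, w, device=feat.device).view(1, 1, 1, w).expand(b, 1, h, w)
    out = torch.cat([feat, xx, yy], dim=1)
    if with_r:
        rr = torch.sqrt(xx ** 2 + yy ** 2)   # distance from the image center
        out = torch.cat([out, rr], dim=1)
    return out

feat = torch.randn(2, 8, 64, 64)              # batch of 8-channel feature maps
augmented = add_coords(feat, with_r=True)      # 8 + 2 coord + 1 radius = 11 channels
conv = nn.Conv2d(11, 16, kernel_size=1)
print(conv(augmented).shape)                   # torch.Size([2, 16, 64, 64])
```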
external/landmark_detection/lib/backbone/stackedHGNetV1.py ADDED
@@ -0,0 +1,307 @@
+ import numpy as np
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from .core.coord_conv import CoordConvTh
+ from external.landmark_detection.lib.dataset import get_decoder
+
+
+ class Activation(nn.Module):
+     def __init__(self, kind: str = 'relu', channel=None):
+         super().__init__()
+         self.kind = kind
+
+         if '+' in kind:
+             norm_str, act_str = kind.split('+')
+         else:
+             norm_str, act_str = 'none', kind
+
+         self.norm_fn = {
+             'in': F.instance_norm,
+             'bn': nn.BatchNorm2d(channel),
+             'bn_noaffine': nn.BatchNorm2d(channel, affine=False, track_running_stats=True),
+             'none': None
+         }[norm_str]
+
+         self.act_fn = {
+             'relu': F.relu,
+             'softplus': nn.Softplus(),
+             'exp': torch.exp,
+             'sigmoid': torch.sigmoid,
+             'tanh': torch.tanh,
+             'none': None
+         }[act_str]
+
+         self.channel = channel
+
+     def forward(self, x):
+         if self.norm_fn is not None:
+             x = self.norm_fn(x)
+         if self.act_fn is not None:
+             x = self.act_fn(x)
+         return x
+
+     def extra_repr(self):
+         return f'kind={self.kind}, channel={self.channel}'
+
+
+ class ConvBlock(nn.Module):
+     def __init__(self, inp_dim, out_dim, kernel_size=3, stride=1, bn=False, relu=True, groups=1):
+         super(ConvBlock, self).__init__()
+         self.inp_dim = inp_dim
+         self.conv = nn.Conv2d(inp_dim, out_dim, kernel_size,
+                               stride, padding=(kernel_size - 1) // 2, groups=groups, bias=True)
+         self.relu = None
+         self.bn = None
+         if relu:
+             self.relu = nn.ReLU()
+         if bn:
+             self.bn = nn.BatchNorm2d(out_dim)
+
+     def forward(self, x):
+         x = self.conv(x)
+         if self.bn is not None:
+             x = self.bn(x)
+         if self.relu is not None:
+             x = self.relu(x)
+         return x
+
+
+ class ResBlock(nn.Module):
+     def __init__(self, inp_dim, out_dim, mid_dim=None):
+         super(ResBlock, self).__init__()
+         if mid_dim is None:
+             mid_dim = out_dim // 2
+         self.relu = nn.ReLU()
+         self.bn1 = nn.BatchNorm2d(inp_dim)
+         self.conv1 = ConvBlock(inp_dim, mid_dim, 1, relu=False)
+         self.bn2 = nn.BatchNorm2d(mid_dim)
+         self.conv2 = ConvBlock(mid_dim, mid_dim, 3, relu=False)
+         self.bn3 = nn.BatchNorm2d(mid_dim)
+         self.conv3 = ConvBlock(mid_dim, out_dim, 1, relu=False)
+         self.skip_layer = ConvBlock(inp_dim, out_dim, 1, relu=False)
+         if inp_dim == out_dim:
+             self.need_skip = False
+         else:
+             self.need_skip = True
+
+     def forward(self, x):
+         if self.need_skip:
+             residual = self.skip_layer(x)
+         else:
+             residual = x
+         out = self.bn1(x)
+         out = self.relu(out)
+         out = self.conv1(out)
+         out = self.bn2(out)
+         out = self.relu(out)
+         out = self.conv2(out)
+         out = self.bn3(out)
+         out = self.relu(out)
+         out = self.conv3(out)
+         out += residual
+         return out
+
+
+ class Hourglass(nn.Module):
+     def __init__(self, n, f, increase=0, up_mode='nearest',
+                  add_coord=False, first_one=False, x_dim=64, y_dim=64):
+         super(Hourglass, self).__init__()
+         nf = f + increase
+
+         Block = ResBlock
+
+         if add_coord:
+             self.coordconv = CoordConvTh(x_dim=x_dim, y_dim=y_dim,
+                                          with_r=True, with_boundary=True,
+                                          relu=False, bn=False,
+                                          in_channels=f, out_channels=f,
+                                          first_one=first_one,
+                                          kernel_size=1,
+                                          stride=1, padding=0)
+         else:
+             self.coordconv = None
+         self.up1 = Block(f, f)
+
+         # Lower branch
+         self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
+
+         self.low1 = Block(f, nf)
+         self.n = n
+         # Recursive hourglass
+         if self.n > 1:
+             self.low2 = Hourglass(n=n - 1, f=nf, increase=increase, up_mode=up_mode, add_coord=False)
+         else:
+             self.low2 = Block(nf, nf)
+         self.low3 = Block(nf, f)
+         self.up2 = nn.Upsample(scale_factor=2, mode=up_mode)
+
+     def forward(self, x, heatmap=None):
+         if self.coordconv is not None:
+             x = self.coordconv(x, heatmap)
+         up1 = self.up1(x)
+         pool1 = self.pool1(x)
+         low1 = self.low1(pool1)
+         low2 = self.low2(low1)
+         low3 = self.low3(low2)
+         up2 = self.up2(low3)
+         return up1 + up2
+
+
+ class E2HTransform(nn.Module):
+     def __init__(self, edge_info, num_points, num_edges):
+         super().__init__()
+
+         e2h_matrix = np.zeros([num_points, num_edges])
+         for edge_id, isclosed_indices in enumerate(edge_info):
+             is_closed, indices = isclosed_indices
+             for point_id in indices:
+                 e2h_matrix[point_id, edge_id] = 1
+         e2h_matrix = torch.from_numpy(e2h_matrix).float()
+
+         # pn x en x 1 x 1.
+         self.register_buffer('weight', e2h_matrix.view(
+             e2h_matrix.size(0), e2h_matrix.size(1), 1, 1))
+
+         # some keypoints are not covered by any edges,
+         # in these cases, we must add a constant bias to their heatmap weights.
+         bias = ((e2h_matrix @ torch.ones(e2h_matrix.size(1)).to(
+             e2h_matrix)) < 0.5).to(e2h_matrix)
+         # pn x 1.
+         self.register_buffer('bias', bias)
+
+     def forward(self, edgemaps):
+         # input: batch_size x en x hw x hh.
+         # output: batch_size x pn x hw x hh.
+         return F.conv2d(edgemaps, weight=self.weight, bias=self.bias)
+
+
+ class StackedHGNetV1(nn.Module):
+     def __init__(self, config, classes_num, edge_info,
+                  nstack=4, nlevels=4, in_channel=256, increase=0,
+                  add_coord=True, decoder_type='default'):
+         super(StackedHGNetV1, self).__init__()
+
+         self.cfg = config
+         self.coder_type = decoder_type
+         self.decoder = get_decoder(decoder_type=decoder_type)
+         self.nstack = nstack
+         self.add_coord = add_coord
+
+         self.num_heats = classes_num[0]
+
+         if self.add_coord:
+             convBlock = CoordConvTh(x_dim=self.cfg.width, y_dim=self.cfg.height,
+                                     with_r=True, with_boundary=False,
+                                     relu=True, bn=True,
+                                     in_channels=3, out_channels=64,
+                                     kernel_size=7,
+                                     stride=2, padding=3)
+         else:
+             convBlock = ConvBlock(3, 64, 7, 2, bn=True, relu=True)
+
+         pool = nn.MaxPool2d(kernel_size=2, stride=2)
+
+         Block = ResBlock
+
+         self.pre = nn.Sequential(
+             convBlock,
+             Block(64, 128),
+             pool,
+             Block(128, 128),
+             Block(128, in_channel)
+         )
+
+         self.hgs = nn.ModuleList(
+             [Hourglass(n=nlevels, f=in_channel, increase=increase, add_coord=self.add_coord, first_one=(_ == 0),
+                        x_dim=int(self.cfg.width / self.nstack), y_dim=int(self.cfg.height / self.nstack))
+              for _ in range(nstack)])
+
+         self.features = nn.ModuleList([
+             nn.Sequential(
+                 Block(in_channel, in_channel),
+                 ConvBlock(in_channel, in_channel, 1, bn=True, relu=True)
+             ) for _ in range(nstack)])
+
+         self.out_heatmaps = nn.ModuleList(
+             [ConvBlock(in_channel, self.num_heats, 1, relu=False, bn=False)
+              for _ in range(nstack)])
+
+         if self.cfg.use_AAM:
+             self.num_edges = classes_num[1]
+             self.num_points = classes_num[2]
+
+             self.e2h_transform = E2HTransform(edge_info, self.num_points, self.num_edges)
+             self.out_edgemaps = nn.ModuleList(
+                 [ConvBlock(in_channel, self.num_edges, 1, relu=False, bn=False)
+                  for _ in range(nstack)])
+             self.out_pointmaps = nn.ModuleList(
+                 [ConvBlock(in_channel, self.num_points, 1, relu=False, bn=False)
+                  for _ in range(nstack)])
+             self.merge_edgemaps = nn.ModuleList(
+                 [ConvBlock(self.num_edges, in_channel, 1, relu=False, bn=False)
+                  for _ in range(nstack - 1)])
+             self.merge_pointmaps = nn.ModuleList(
+                 [ConvBlock(self.num_points, in_channel, 1, relu=False, bn=False)
+                  for _ in range(nstack - 1)])
+             self.edgemap_act = Activation("sigmoid", self.num_edges)
+             self.pointmap_act = Activation("sigmoid", self.num_points)
+
+         self.merge_features = nn.ModuleList(
+             [ConvBlock(in_channel, in_channel, 1, relu=False, bn=False)
+              for _ in range(nstack - 1)])
+         self.merge_heatmaps = nn.ModuleList(
+             [ConvBlock(self.num_heats, in_channel, 1, relu=False, bn=False)
+              for _ in range(nstack - 1)])
+
+         self.nstack = nstack
+
+         self.heatmap_act = Activation("in+relu", self.num_heats)
+
+         self.inference = False
+
+     def set_inference(self, inference):
+         self.inference = inference
+
+     def forward(self, x):
+         x = self.pre(x)
+
+         y, fusionmaps = [], []
+         heatmaps = None
+         for i in range(self.nstack):
+             hg = self.hgs[i](x, heatmap=heatmaps)
+             feature = self.features[i](hg)
+
+             heatmaps0 = self.out_heatmaps[i](feature)
+             heatmaps = self.heatmap_act(heatmaps0)
+
+             if self.cfg.use_AAM:
+                 pointmaps0 = self.out_pointmaps[i](feature)
+                 pointmaps = self.pointmap_act(pointmaps0)
+                 edgemaps0 = self.out_edgemaps[i](feature)
+                 edgemaps = self.edgemap_act(edgemaps0)
+                 mask = self.e2h_transform(edgemaps) * pointmaps
+                 fusion_heatmaps = mask * heatmaps
+             else:
+                 fusion_heatmaps = heatmaps
+
+             landmarks = self.decoder.get_coords_from_heatmap(fusion_heatmaps)
+
+             if i < self.nstack - 1:
+                 x = x + self.merge_features[i](feature) + \
+                     self.merge_heatmaps[i](heatmaps)
+                 if self.cfg.use_AAM:
+                     x += self.merge_pointmaps[i](pointmaps)
+                     x += self.merge_edgemaps[i](edgemaps)
+
+             y.append(landmarks)
+             if self.cfg.use_AAM:
+                 y.append(pointmaps)
+                 y.append(edgemaps)
+
+             fusionmaps.append(fusion_heatmaps)
+
+         return y, fusionmaps, landmarks
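
Note: a hedged sketch of how this backbone might be instantiated for a dummy forward pass. It assumes the repository root is on `PYTHONPATH` and the dataset/encoder dependencies it imports are installed; the config fields, class counts, and input size below are illustrative guesses, not values taken from the training configs.

```python
# Sketch (assumptions noted above): 68-point landmarking, AAM heads disabled.
from types import SimpleNamespace

import torch

from external.landmark_detection.lib.backbone import StackedHGNetV1

cfg = SimpleNamespace(width=256, height=256, use_AAM=False)   # illustrative config
net = StackedHGNetV1(config=cfg,
                     classes_num=[68, 9, 68],  # only classes_num[0] is used when use_AAM=False
                     edge_info=[],             # unused when use_AAM=False
                     nstack=4, nlevels=4, in_channel=256,
                     add_coord=True, decoder_type='default')
net.eval()

with torch.no_grad():
    y, fusionmaps, landmarks = net(torch.randn(1, 3, 256, 256))

print(landmarks.shape)   # (1, 68, 2): per-landmark (x, y) in [-1, 1]
print(len(fusionmaps))   # one fused heatmap tensor per hourglass stack
```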
external/landmark_detection/lib/dataset/__init__.py ADDED
@@ -0,0 +1,11 @@
+ from .encoder import get_encoder
+ from .decoder import get_decoder
+ from .augmentation import Augmentation
+ from .alignmentDataset import AlignmentDataset
+
+ __all__ = [
+     "Augmentation",
+     "AlignmentDataset",
+     "get_encoder",
+     "get_decoder"
+ ]
external/landmark_detection/lib/dataset/alignmentDataset.py ADDED
@@ -0,0 +1,316 @@
+ import os
+ import sys
+ import cv2
+ import math
+ import copy
+ import hashlib
+ import imageio
+ import numpy as np
+ import pandas as pd
+ from scipy import interpolate
+ from PIL import Image, ImageEnhance, ImageFile
+
+ import torch
+ import torch.nn.functional as F
+ from torch.utils.data import Dataset
+
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+ sys.path.append("./")
+ from external.landmark_detection.lib.dataset.augmentation import Augmentation
+ from external.landmark_detection.lib.dataset.encoder import get_encoder
+
+
+ class AlignmentDataset(Dataset):
+
+     def __init__(self, tsv_flie, image_dir="", transform=None,
+                  width=256, height=256, channels=3,
+                  means=(127.5, 127.5, 127.5), scale=1 / 127.5,
+                  classes_num=None, crop_op=True, aug_prob=0.0, edge_info=None, flip_mapping=None, is_train=True,
+                  encoder_type='default',
+                  ):
+         super(AlignmentDataset, self).__init__()
+         self.use_AAM = True
+         self.encoder_type = encoder_type
+         self.encoder = get_encoder(height, width, encoder_type=encoder_type)
+         self.items = pd.read_csv(tsv_flie, sep="\t")
+         self.image_dir = image_dir
+         self.landmark_num = classes_num[0]
+         self.transform = transform
+
+         self.image_width = width
+         self.image_height = height
+         self.channels = channels
+         assert self.image_width == self.image_height
+
+         self.means = means
+         self.scale = scale
+
+         self.aug_prob = aug_prob
+         self.edge_info = edge_info
+         self.is_train = is_train
+         std_lmk_5pts = np.array([
+             196.0, 226.0,
+             316.0, 226.0,
+             256.0, 286.0,
+             220.0, 360.4,
+             292.0, 360.4], np.float32) / 256.0 - 1.0
+         std_lmk_5pts = np.reshape(std_lmk_5pts, (5, 2))  # [-1 1]
+         target_face_scale = 1.0 if crop_op else 1.25
+
+         self.augmentation = Augmentation(
+             is_train=self.is_train,
+             aug_prob=self.aug_prob,
+             image_size=self.image_width,
+             crop_op=crop_op,
+             std_lmk_5pts=std_lmk_5pts,
+             target_face_scale=target_face_scale,
+             flip_rate=0.5,
+             flip_mapping=flip_mapping,
+             random_shift_sigma=0.05,
+             random_rot_sigma=math.pi / 180 * 18,
+             random_scale_sigma=0.1,
+             random_gray_rate=0.2,
+             random_occ_rate=0.4,
+             random_blur_rate=0.3,
+             random_gamma_rate=0.2,
+             random_nose_fusion_rate=0.2)
+
+     def _circle(self, img, pt, sigma=1.0, label_type='Gaussian'):
+         # Check that any part of the gaussian is in-bounds
+         tmp_size = sigma * 3
+         ul = [int(pt[0] - tmp_size), int(pt[1] - tmp_size)]
+         br = [int(pt[0] + tmp_size + 1), int(pt[1] + tmp_size + 1)]
+         if (ul[0] > img.shape[1] - 1 or ul[1] > img.shape[0] - 1 or
+                 br[0] - 1 < 0 or br[1] - 1 < 0):
+             # If not, just return the image as is
+             return img
+
+         # Generate gaussian
+         size = 2 * tmp_size + 1
+         x = np.arange(0, size, 1, np.float32)
+         y = x[:, np.newaxis]
+         x0 = y0 = size // 2
+         # The gaussian is not normalized, we want the center value to equal 1
+         if label_type == 'Gaussian':
+             g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
+         else:
+             g = sigma / (((x - x0) ** 2 + (y - y0) ** 2 + sigma ** 2) ** 1.5)
+
+         # Usable gaussian range
+         g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0]
+         g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1]
+         # Image range
+         img_x = max(0, ul[0]), min(br[0], img.shape[1])
+         img_y = max(0, ul[1]), min(br[1], img.shape[0])
+
+         img[img_y[0]:img_y[1], img_x[0]:img_x[1]] = 255 * g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
+         return img
+
+     def _polylines(self, img, lmks, is_closed, color=255, thickness=1, draw_mode=cv2.LINE_AA,
+                    interpolate_mode=cv2.INTER_AREA, scale=4):
+         h, w = img.shape
+         img_scale = cv2.resize(img, (w * scale, h * scale), interpolation=interpolate_mode)
+         lmks_scale = (lmks * scale + 0.5).astype(np.int32)
+         cv2.polylines(img_scale, [lmks_scale], is_closed, color, thickness * scale, draw_mode)
+         img = cv2.resize(img_scale, (w, h), interpolation=interpolate_mode)
+         return img
+
+     def _generate_edgemap(self, points, scale=0.25, thickness=1):
+         h, w = self.image_height, self.image_width
+         edgemaps = []
+         for is_closed, indices in self.edge_info:
+             edgemap = np.zeros([h, w], dtype=np.float32)
+             # align_corners: False.
+             part = copy.deepcopy(points[np.array(indices)])
+
+             part = self._fit_curve(part, is_closed)
+             part[:, 0] = np.clip(part[:, 0], 0, w - 1)
+             part[:, 1] = np.clip(part[:, 1], 0, h - 1)
+             edgemap = self._polylines(edgemap, part, is_closed, 255, thickness)
+
+             edgemaps.append(edgemap)
+         edgemaps = np.stack(edgemaps, axis=0) / 255.0
+         edgemaps = torch.from_numpy(edgemaps).float().unsqueeze(0)
+         edgemaps = F.interpolate(edgemaps, size=(int(w * scale), int(h * scale)), mode='bilinear',
+                                  align_corners=False).squeeze()
+         return edgemaps
+
+     def _fit_curve(self, lmks, is_closed=False, density=5):
+         try:
+             x = lmks[:, 0].copy()
+             y = lmks[:, 1].copy()
+             if is_closed:
+                 x = np.append(x, x[0])
+                 y = np.append(y, y[0])
+             tck, u = interpolate.splprep([x, y], s=0, per=is_closed, k=3)
+             # bins = (x.shape[0] - 1) * density + 1
+             # lmk_x, lmk_y = interpolate.splev(np.linspace(0, 1, bins), f)
+             intervals = np.array([])
+             for i in range(len(u) - 1):
+                 intervals = np.concatenate((intervals, np.linspace(u[i], u[i + 1], density, endpoint=False)))
+             if not is_closed:
+                 intervals = np.concatenate((intervals, [u[-1]]))
+             lmk_x, lmk_y = interpolate.splev(intervals, tck, der=0)
+             # der_x, der_y = interpolate.splev(intervals, tck, der=1)
+             curve_lmks = np.stack([lmk_x, lmk_y], axis=-1)
+             # curve_ders = np.stack([der_x, der_y], axis=-1)
+             # origin_indices = np.arange(0, curve_lmks.shape[0], density)
+
+             return curve_lmks
+         except:
+             return lmks
+
+     def _image_id(self, image_path):
+         if not os.path.exists(image_path):
+             image_path = os.path.join(self.image_dir, image_path)
+         return hashlib.md5(open(image_path, "rb").read()).hexdigest()
+
+     def _load_image(self, image_path):
+         if not os.path.exists(image_path):
+             image_path = os.path.join(self.image_dir, image_path)
+
+         try:
+             # img = cv2.imdecode(np.fromfile(image_path, dtype=np.uint8), cv2.IMREAD_COLOR)  # HWC, BGR, [0-255]
+             img = cv2.imread(image_path, cv2.IMREAD_COLOR)  # HWC, BGR, [0-255]
+             assert img is not None and len(img.shape) == 3 and img.shape[2] == 3
+         except:
+             try:
+                 img = imageio.imread(image_path)  # HWC, RGB, [0-255]
+                 img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)  # HWC, BGR, [0-255]
+                 assert img is not None and len(img.shape) == 3 and img.shape[2] == 3
+             except:
+                 try:
+                     gifImg = imageio.mimread(image_path)  # BHWC, RGB, [0-255]
+                     img = gifImg[0]  # HWC, RGB, [0-255]
+                     img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)  # HWC, BGR, [0-255]
+                     assert img is not None and len(img.shape) == 3 and img.shape[2] == 3
+                 except:
+                     img = None
+         return img
+
+     def _compose_rotate_and_scale(self, angle, scale, shift_xy, from_center, to_center):
+         cosv = math.cos(angle)
+         sinv = math.sin(angle)
+
+         fx, fy = from_center
+         tx, ty = to_center
+
+         acos = scale * cosv
+         asin = scale * sinv
+
+         a0 = acos
+         a1 = -asin
+         a2 = tx - acos * fx + asin * fy + shift_xy[0]
+
+         b0 = asin
+         b1 = acos
+         b2 = ty - asin * fx - acos * fy + shift_xy[1]
+
+         rot_scale_m = np.array([
+             [a0, a1, a2],
+             [b0, b1, b2],
+             [0.0, 0.0, 1.0]
+         ], np.float32)
+         return rot_scale_m
+
+     def _transformPoints2D(self, points, matrix):
+         """
+         points (nx2), matrix (3x3) -> points (nx2)
+         """
+         dtype = points.dtype
+
+         # nx3
+         points = np.concatenate([points, np.ones_like(points[:, [0]])], axis=1)
+         points = points @ np.transpose(matrix)  # nx3
+         points = points[:, :2] / points[:, [2, 2]]
+         return points.astype(dtype)
+
+     def _transformPerspective(self, image, matrix, target_shape):
+         """
+         image, matrix3x3 -> transformed_image
+         """
+         return cv2.warpPerspective(
+             image, matrix,
+             dsize=(target_shape[1], target_shape[0]),
+             flags=cv2.INTER_LINEAR, borderValue=0)
+
+     def _norm_points(self, points, h, w, align_corners=False):
+         if align_corners:
+             # [0, SIZE-1] -> [-1, +1]
+             des_points = points / torch.tensor([w - 1, h - 1]).to(points).view(1, 2) * 2 - 1
+         else:
+             # [-0.5, SIZE-0.5] -> [-1, +1]
+             des_points = (points * 2 + 1) / torch.tensor([w, h]).to(points).view(1, 2) - 1
+         des_points = torch.clamp(des_points, -1, 1)
+         return des_points
+
+     def _denorm_points(self, points, h, w, align_corners=False):
+         if align_corners:
+             # [-1, +1] -> [0, SIZE-1]
+             des_points = (points + 1) / 2 * torch.tensor([w - 1, h - 1]).to(points).view(1, 1, 2)
+         else:
+             # [-1, +1] -> [-0.5, SIZE-0.5]
+             des_points = ((points + 1) * torch.tensor([w, h]).to(points).view(1, 1, 2) - 1) / 2
+         return des_points
+
+     def __len__(self):
+         return len(self.items)
+
+     def __getitem__(self, index):
+         sample = dict()
+
+         image_path = self.items.iloc[index, 0]
+         landmarks_5pts = self.items.iloc[index, 1]
+         landmarks_5pts = np.array(list(map(float, landmarks_5pts.split(","))), dtype=np.float32).reshape(5, 2)
+         landmarks_target = self.items.iloc[index, 2]
+         landmarks_target = np.array(list(map(float, landmarks_target.split(","))), dtype=np.float32).reshape(
+             self.landmark_num, 2)
+         scale = float(self.items.iloc[index, 3])
+         center_w, center_h = float(self.items.iloc[index, 4]), float(self.items.iloc[index, 5])
+         if len(self.items.iloc[index]) > 6:
+             tags = np.array(list(map(lambda x: int(float(x)), self.items.iloc[index, 6].split(","))))
+         else:
+             tags = np.array([])
+
+         # image & keypoints alignment
+         image_path = image_path.replace('\\', '/')
+         # wflw testset
+         image_path = image_path.replace(
+             '//msr-facestore/Workspace/MSRA_EP_Allergan/users/yanghuan/training_data/wflw/rawImages/', '')
+         # trainset
+         image_path = image_path.replace('./rawImages/', '')
+         image_path = os.path.join(self.image_dir, image_path)
+
+         # image path
+         sample["image_path"] = image_path
+
+         img = self._load_image(image_path)  # HWC, BGR, [0, 255]
+         assert img is not None
+
+         # augmentation
+         # landmarks_target = [-0.5, edge-0.5]
+         img, landmarks_target, matrix = \
+             self.augmentation.process(img, landmarks_target, landmarks_5pts, scale, center_w, center_h)
+
+         landmarks = self._norm_points(torch.from_numpy(landmarks_target), self.image_height, self.image_width)
+
+         sample["label"] = [landmarks, ]
+
+         if self.use_AAM:
+             pointmap = self.encoder.generate_heatmap(landmarks_target)
+             edgemap = self._generate_edgemap(landmarks_target)
+             sample["label"] += [pointmap, edgemap]
+
+         sample['matrix'] = matrix
+
+         # image normalization
+         img = img.transpose(2, 0, 1).astype(np.float32)  # CHW, BGR, [0, 255]
+         img[0, :, :] = (img[0, :, :] - self.means[0]) * self.scale
+         img[1, :, :] = (img[1, :, :] - self.means[1]) * self.scale
+         img[2, :, :] = (img[2, :, :] - self.means[2]) * self.scale
+         sample["data"] = torch.from_numpy(img)  # CHW, BGR, [-1, 1]
+
+         sample["tags"] = tags
+
+         return sample
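
Note: `_norm_points` above uses the align_corners=False convention, where a pixel coordinate `p` in `[-0.5, SIZE-0.5]` maps to `(2*p + 1)/SIZE - 1` in `[-1, 1]`. The standalone sketch below just restates that formula with a few illustrative values to make the mapping concrete.

```python
# Sketch of the align_corners=False normalization used by _norm_points.
# Input values are illustrative (a 256x256 crop).
import torch

def norm_points(points, h, w):
    # points: (n, 2) in (x, y) pixel coordinates
    return torch.clamp((points * 2 + 1) / torch.tensor([w, h], dtype=points.dtype) - 1, -1, 1)

pts = torch.tensor([[0.0, 0.0], [127.5, 127.5], [255.0, 255.0]])
print(norm_points(pts, h=256, w=256))
# tensor([[-0.9961, -0.9961],
#         [ 0.0000,  0.0000],
#         [ 0.9961,  0.9961]])
```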
external/landmark_detection/lib/dataset/augmentation.py ADDED
@@ -0,0 +1,355 @@
+ import os
+ import cv2
+ import math
+ import random
+ import numpy as np
+ from skimage import transform
+
+
+ class Augmentation:
+     def __init__(self,
+                  is_train=True,
+                  aug_prob=1.0,
+                  image_size=256,
+                  crop_op=True,
+                  std_lmk_5pts=None,
+                  target_face_scale=1.0,
+                  flip_rate=0.5,
+                  flip_mapping=None,
+                  random_shift_sigma=0.05,
+                  random_rot_sigma=math.pi / 180 * 18,
+                  random_scale_sigma=0.1,
+                  random_gray_rate=0.2,
+                  random_occ_rate=0.4,
+                  random_blur_rate=0.3,
+                  random_gamma_rate=0.2,
+                  random_nose_fusion_rate=0.2):
+         self.is_train = is_train
+         self.aug_prob = aug_prob
+         self.crop_op = crop_op
+         self._flip = Flip(flip_mapping, flip_rate)
+         if self.crop_op:
+             self._cropMatrix = GetCropMatrix(
+                 image_size=image_size,
+                 target_face_scale=target_face_scale,
+                 align_corners=True)
+         else:
+             self._alignMatrix = GetAlignMatrix(
+                 image_size=image_size,
+                 target_face_scale=target_face_scale,
+                 std_lmk_5pts=std_lmk_5pts)
+         self._randomGeometryMatrix = GetRandomGeometryMatrix(
+             target_shape=(image_size, image_size),
+             from_shape=(image_size, image_size),
+             shift_sigma=random_shift_sigma,
+             rot_sigma=random_rot_sigma,
+             scale_sigma=random_scale_sigma,
+             align_corners=True)
+         self._transform = Transform(image_size=image_size)
+         self._randomTexture = RandomTexture(
+             random_gray_rate=random_gray_rate,
+             random_occ_rate=random_occ_rate,
+             random_blur_rate=random_blur_rate,
+             random_gamma_rate=random_gamma_rate,
+             random_nose_fusion_rate=random_nose_fusion_rate)
+
+     def process(self, img, lmk, lmk_5pts=None, scale=1.0, center_w=0, center_h=0, is_train=True):
+         if self.is_train and random.random() < self.aug_prob:
+             img, lmk, lmk_5pts, center_w, center_h = self._flip.process(img, lmk, lmk_5pts, center_w, center_h)
+             matrix_geoaug = self._randomGeometryMatrix.process()
+             if self.crop_op:
+                 matrix_pre = self._cropMatrix.process(scale, center_w, center_h)
+             else:
+                 matrix_pre = self._alignMatrix.process(lmk_5pts)
+             matrix = matrix_geoaug @ matrix_pre
+             aug_img, aug_lmk = self._transform.process(img, lmk, matrix)
+             aug_img = self._randomTexture.process(aug_img)
+         else:
+             if self.crop_op:
+                 matrix = self._cropMatrix.process(scale, center_w, center_h)
+             else:
+                 matrix = self._alignMatrix.process(lmk_5pts)
+             aug_img, aug_lmk = self._transform.process(img, lmk, matrix)
+         return aug_img, aug_lmk, matrix
+
+
+ class GetCropMatrix:
+     def __init__(self, image_size, target_face_scale, align_corners=False):
+         self.image_size = image_size
+         self.target_face_scale = target_face_scale
+         self.align_corners = align_corners
+
+     def _compose_rotate_and_scale(self, angle, scale, shift_xy, from_center, to_center):
+         cosv = math.cos(angle)
+         sinv = math.sin(angle)
+
+         fx, fy = from_center
+         tx, ty = to_center
+
+         acos = scale * cosv
+         asin = scale * sinv
+
+         a0 = acos
+         a1 = -asin
+         a2 = tx - acos * fx + asin * fy + shift_xy[0]
+
+         b0 = asin
+         b1 = acos
+         b2 = ty - asin * fx - acos * fy + shift_xy[1]
+
+         rot_scale_m = np.array([
+             [a0, a1, a2],
+             [b0, b1, b2],
+             [0.0, 0.0, 1.0]
+         ], np.float32)
+         return rot_scale_m
+
+     def process(self, scale, center_w, center_h):
+         if self.align_corners:
+             to_w, to_h = self.image_size - 1, self.image_size - 1
+         else:
+             to_w, to_h = self.image_size, self.image_size
+
+         rot_mu = 0
+         scale_mu = self.image_size / (scale * self.target_face_scale * 200.0)
+         shift_xy_mu = (0, 0)
+         matrix = self._compose_rotate_and_scale(
+             rot_mu, scale_mu, shift_xy_mu,
+             from_center=[center_w, center_h],
+             to_center=[to_w / 2.0, to_h / 2.0])
+         return matrix
+
+
+ class GetAlignMatrix:
+     def __init__(self, image_size, target_face_scale, std_lmk_5pts):
+         """
+         points in std_lmk_5pts range from -1 to 1.
+         """
+         self.std_lmk_5pts = (std_lmk_5pts * target_face_scale + 1) * \
+                             np.array([image_size, image_size], np.float32) / 2.0
+
+     def process(self, lmk_5pts):
+         assert lmk_5pts.shape[-2:] == (5, 2)
+         tform = transform.SimilarityTransform()
+         tform.estimate(lmk_5pts, self.std_lmk_5pts)
+         return tform.params
+
+
+ class GetRandomGeometryMatrix:
+     def __init__(self, target_shape, from_shape,
+                  shift_sigma=0.1, rot_sigma=18 * math.pi / 180, scale_sigma=0.1,
+                  shift_mu=0.0, rot_mu=0.0, scale_mu=1.0,
+                  shift_normal=True, rot_normal=True, scale_normal=True,
+                  align_corners=False):
+         self.target_shape = target_shape
+         self.from_shape = from_shape
+         self.shift_config = (shift_mu, shift_sigma, shift_normal)
+         self.rot_config = (rot_mu, rot_sigma, rot_normal)
+         self.scale_config = (scale_mu, scale_sigma, scale_normal)
+         self.align_corners = align_corners
+
+     def _compose_rotate_and_scale(self, angle, scale, shift_xy, from_center, to_center):
+         cosv = math.cos(angle)
+         sinv = math.sin(angle)
+
+         fx, fy = from_center
+         tx, ty = to_center
+
+         acos = scale * cosv
+         asin = scale * sinv
+
+         a0 = acos
+         a1 = -asin
+         a2 = tx - acos * fx + asin * fy + shift_xy[0]
+
+         b0 = asin
+         b1 = acos
+         b2 = ty - asin * fx - acos * fy + shift_xy[1]
+
+         rot_scale_m = np.array([
+             [a0, a1, a2],
+             [b0, b1, b2],
+             [0.0, 0.0, 1.0]
+         ], np.float32)
+         return rot_scale_m
+
+     def _random(self, mu_sigma_normal, size=None):
+         mu, sigma, is_normal = mu_sigma_normal
+         if is_normal:
+             return np.random.normal(mu, sigma, size=size)
+         else:
+             return np.random.uniform(low=mu - sigma, high=mu + sigma, size=size)
+
+     def process(self):
+         if self.align_corners:
+             from_w, from_h = self.from_shape[1] - 1, self.from_shape[0] - 1
+             to_w, to_h = self.target_shape[1] - 1, self.target_shape[0] - 1
+         else:
+             from_w, from_h = self.from_shape[1], self.from_shape[0]
+             to_w, to_h = self.target_shape[1], self.target_shape[0]
+
+         if self.shift_config[:2] != (0.0, 0.0) or \
+                 self.rot_config[:2] != (0.0, 0.0) or \
+                 self.scale_config[:2] != (1.0, 0.0):
+             shift_xy = self._random(self.shift_config, size=[2]) * \
+                        min(to_h, to_w)
+             rot_angle = self._random(self.rot_config)
+             scale = self._random(self.scale_config)
+             matrix_geoaug = self._compose_rotate_and_scale(
+                 rot_angle, scale, shift_xy,
+                 from_center=[from_w / 2.0, from_h / 2.0],
+                 to_center=[to_w / 2.0, to_h / 2.0])
+
+         return matrix_geoaug
+
+
+ class Transform:
+     def __init__(self, image_size):
+         self.image_size = image_size
+
+     def _transformPoints2D(self, points, matrix):
+         """
+         points (nx2), matrix (3x3) -> points (nx2)
+         """
+         dtype = points.dtype
+
+         # nx3
+         points = np.concatenate([points, np.ones_like(points[:, [0]])], axis=1)
+         points = points @ np.transpose(matrix)
+         points = points[:, :2] / points[:, [2, 2]]
+         return points.astype(dtype)
+
+     def _transformPerspective(self, image, matrix):
+         """
+         image, matrix3x3 -> transformed_image
+         """
+         return cv2.warpPerspective(
+             image, matrix,
+             dsize=(self.image_size, self.image_size),
+             flags=cv2.INTER_LINEAR, borderValue=0)
+
+     def process(self, image, landmarks, matrix):
+         t_landmarks = self._transformPoints2D(landmarks, matrix)
+         t_image = self._transformPerspective(image, matrix)
+         return t_image, t_landmarks
+
+
+ class RandomTexture:
+     def __init__(self, random_gray_rate=0, random_occ_rate=0, random_blur_rate=0, random_gamma_rate=0, random_nose_fusion_rate=0):
+         self.random_gray_rate = random_gray_rate
+         self.random_occ_rate = random_occ_rate
+         self.random_blur_rate = random_blur_rate
+         self.random_gamma_rate = random_gamma_rate
+         self.random_nose_fusion_rate = random_nose_fusion_rate
+         self.texture_augs = (
+             (self.add_occ, self.random_occ_rate),
+             (self.add_blur, self.random_blur_rate),
+             (self.add_gamma, self.random_gamma_rate),
+             (self.add_nose_fusion, self.random_nose_fusion_rate)
+         )
+
+     def add_gray(self, image):
+         assert image.ndim == 3 and image.shape[-1] == 3
+         image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
+         image = np.tile(np.expand_dims(image, -1), [1, 1, 3])
+         return image
+
+     def add_occ(self, image):
+         h, w, c = image.shape
+         rh = 0.2 + 0.6 * random.random()  # [0.2, 0.8]
+         rw = rh - 0.2 + 0.4 * random.random()
+         cx = int((h - 1) * random.random())
+         cy = int((w - 1) * random.random())
+         dh = int(h / 2 * rh)
+         dw = int(w / 2 * rw)
+         x0 = max(0, cx - dw // 2)
+         y0 = max(0, cy - dh // 2)
+         x1 = min(w - 1, cx + dw // 2)
+         y1 = min(h - 1, cy + dh // 2)
+         image[y0:y1 + 1, x0:x1 + 1] = 0
+         return image
+
+     def add_blur(self, image):
+         blur_kratio = 0.05 * random.random()
+         blur_ksize = int((image.shape[0] + image.shape[1]) / 2 * blur_kratio)
+         if blur_ksize > 1:
+             image = cv2.blur(image, (blur_ksize, blur_ksize))
+         return image
+
+     def add_gamma(self, image):
+         if random.random() < 0.5:
+             gamma = 0.25 + 0.75 * random.random()
+         else:
+             gamma = 1.0 + 3.0 * random.random()
+         image = (((image / 255.0) ** gamma) * 255).astype("uint8")
+         return image
+
+     def add_nose_fusion(self, image):
+         h, w, c = image.shape
+         nose = np.array(bytearray(os.urandom(h * w * c)), dtype=image.dtype).reshape(h, w, c)
+         alpha = 0.5 * random.random()
+         image = (1 - alpha) * image + alpha * nose
+         return image.astype(np.uint8)
+
+     def process(self, image):
+         image = image.copy()
+         if random.random() < self.random_occ_rate:
+             image = self.add_occ(image)
+         if random.random() < self.random_blur_rate:
+             image = self.add_blur(image)
+         if random.random() < self.random_gamma_rate:
+             image = self.add_gamma(image)
+         if random.random() < self.random_nose_fusion_rate:
+             image = self.add_nose_fusion(image)
+         """
+         orders = list(range(len(self.texture_augs)))
+         random.shuffle(orders)
+         for order in orders:
+             if random.random() < self.texture_augs[order][1]:
+                 image = self.texture_augs[order][0](image)
+         """
+
+         if random.random() < self.random_gray_rate:
+             image = self.add_gray(image)
+
+         return image
+
+
+ class Flip:
+     def __init__(self, flip_mapping, random_rate):
+         self.flip_mapping = flip_mapping
+         self.random_rate = random_rate
+
+     def process(self, image, landmarks, landmarks_5pts, center_w, center_h):
+         if random.random() >= self.random_rate or self.flip_mapping is None:
+             return image, landmarks, landmarks_5pts, center_w, center_h
+
+         # COFW
+         if landmarks.shape[0] == 29:
+             flip_offset = 0
+         # 300W, WFLW
+         elif landmarks.shape[0] in (68, 98):
+             flip_offset = -1
+         else:
+             flip_offset = -1
+
+         h, w, _ = image.shape
+         # image_flip = cv2.flip(image, 1)
+         image_flip = np.fliplr(image).copy()
+         landmarks_flip = landmarks.copy()
+         for i, j in self.flip_mapping:
+             landmarks_flip[i] = landmarks[j]
+             landmarks_flip[j] = landmarks[i]
+         landmarks_flip[:, 0] = w + flip_offset - landmarks_flip[:, 0]
+         if landmarks_5pts is not None:
+             flip_mapping = ([0, 1], [3, 4])
+             landmarks_5pts_flip = landmarks_5pts.copy()
+             for i, j in flip_mapping:
+                 landmarks_5pts_flip[i] = landmarks_5pts[j]
+                 landmarks_5pts_flip[j] = landmarks_5pts[i]
+             landmarks_5pts_flip[:, 0] = w + flip_offset - landmarks_5pts_flip[:, 0]
+         else:
+             landmarks_5pts_flip = None
+
+         center_w = w + flip_offset - center_w
+         return image_flip, landmarks_flip, landmarks_5pts_flip, center_w, center_h
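
Note: `GetCropMatrix` builds a 3x3 similarity transform that rotates, scales, and shifts about the detected face center so that it lands at the crop center. The standalone sketch below re-states the same composition formula and checks that property numerically; the face center, bounding-box scale, and image size are illustrative values, not repository defaults.

```python
# Standalone sketch of the rotate/scale/shift composition used by GetCropMatrix.
# Numbers are illustrative.
import math
import numpy as np

def compose_rotate_and_scale(angle, scale, shift_xy, from_center, to_center):
    cosv, sinv = math.cos(angle), math.sin(angle)
    fx, fy = from_center
    tx, ty = to_center
    acos, asin = scale * cosv, scale * sinv
    return np.array([
        [acos, -asin, tx - acos * fx + asin * fy + shift_xy[0]],
        [asin,  acos, ty - asin * fx - acos * fy + shift_xy[1]],
        [0.0,   0.0,  1.0]], np.float32)

image_size, face_scale, bbox_scale = 256, 1.0, 1.2
center_w, center_h = 310.0, 205.0                      # face center in the source image
m = compose_rotate_and_scale(0.0,
                             image_size / (bbox_scale * face_scale * 200.0),
                             (0, 0),
                             from_center=[center_w, center_h],
                             to_center=[(image_size - 1) / 2.0, (image_size - 1) / 2.0])
src = np.array([center_w, center_h, 1.0], np.float32)
print(m @ src)   # ~[127.5, 127.5, 1.0]: the face center maps to the crop center
```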
external/landmark_detection/lib/dataset/decoder/__init__.py ADDED
@@ -0,0 +1,8 @@
+ from .decoder_default import decoder_default
+
+ def get_decoder(decoder_type='default'):
+     if decoder_type == 'default':
+         decoder = decoder_default()
+     else:
+         raise NotImplementedError
+     return decoder
external/landmark_detection/lib/dataset/decoder/decoder_default.py ADDED
@@ -0,0 +1,38 @@
+ import torch
+
+
+ class decoder_default:
+     def __init__(self, weight=1, use_weight_map=False):
+         self.weight = weight
+         self.use_weight_map = use_weight_map
+
+     def _make_grid(self, h, w):
+         yy, xx = torch.meshgrid(
+             torch.arange(h).float() / (h - 1) * 2 - 1,
+             torch.arange(w).float() / (w - 1) * 2 - 1)
+         return yy, xx
+
+     def get_coords_from_heatmap(self, heatmap):
+         """
+         inputs:
+         - heatmap: batch x npoints x h x w
+
+         outputs:
+         - coords: batch x npoints x 2 (x,y), [-1, +1]
+         - radius_sq: batch x npoints
+         """
+         batch, npoints, h, w = heatmap.shape
+         if self.use_weight_map:
+             heatmap = heatmap * self.weight
+
+         yy, xx = self._make_grid(h, w)
+         yy = yy.view(1, 1, h, w).to(heatmap)
+         xx = xx.view(1, 1, h, w).to(heatmap)
+
+         heatmap_sum = torch.clamp(heatmap.sum([2, 3]), min=1e-6)
+
+         yy_coord = (yy * heatmap).sum([2, 3]) / heatmap_sum  # batch x npoints
+         xx_coord = (xx * heatmap).sum([2, 3]) / heatmap_sum  # batch x npoints
+         coords = torch.stack([xx_coord, yy_coord], dim=-1)
+
+         return coords
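
Note: `get_coords_from_heatmap` is a soft-argmax, i.e. the heatmap-weighted average of a normalized coordinate grid where index `k` maps to `k/(size-1)*2 - 1`. The self-contained sketch below restates that computation and checks that a single hot pixel decodes to its grid position; the heatmap size and peak location are illustrative.

```python
# Self-contained check of the soft-argmax idea implemented by decoder_default.
# A one-hot heatmap should decode to the hot pixel's position on the [-1, 1] grid.
import torch

def soft_argmax(heatmap):
    b, n, h, w = heatmap.shape
    yy, xx = torch.meshgrid(
        torch.arange(h).float() / (h - 1) * 2 - 1,
        torch.arange(w).float() / (w - 1) * 2 - 1,
        indexing='ij')
    total = torch.clamp(heatmap.sum([2, 3]), min=1e-6)
    y = (yy.view(1, 1, h, w) * heatmap).sum([2, 3]) / total
    x = (xx.view(1, 1, h, w) * heatmap).sum([2, 3]) / total
    return torch.stack([x, y], dim=-1)

heatmap = torch.zeros(1, 1, 64, 64)
heatmap[0, 0, 16, 48] = 1.0          # row 16, column 48
print(soft_argmax(heatmap))          # ~[[[0.5238, -0.4921]]] = (48/63*2-1, 16/63*2-1)
```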