maffia committed
Commit 690f890 · verified · 1 Parent(s): 8519254

Upload 94 files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +7 -0
  2. UserGuide.md +160 -0
  3. app.py +278 -0
  4. assets/images/test.jpg +3 -0
  5. assets/images/test2.jpg +0 -0
  6. assets/images/test3.jpg +3 -0
  7. assets/masks/test.png +0 -0
  8. assets/masks/test2.png +0 -0
  9. assets/materials/gr_infer_demo.jpg +3 -0
  10. assets/materials/gr_pre_demo.jpg +3 -0
  11. assets/materials/tasks.png +3 -0
  12. assets/materials/teaser.jpg +3 -0
  13. assets/videos/test.mp4 +3 -0
  14. assets/videos/test2.mp4 +0 -0
  15. benchmarks/.gitkeep +0 -0
  16. models/.gitkeep +0 -0
  17. pyproject.toml +75 -0
  18. requirements.txt +1 -0
  19. requirements/annotator.txt +6 -0
  20. requirements/framework.txt +26 -0
  21. tests/test_annotators.py +568 -0
  22. vace/__init__.py +6 -0
  23. vace/annotators/__init__.py +24 -0
  24. vace/annotators/canvas.py +60 -0
  25. vace/annotators/common.py +62 -0
  26. vace/annotators/composition.py +155 -0
  27. vace/annotators/depth.py +51 -0
  28. vace/annotators/dwpose/__init__.py +2 -0
  29. vace/annotators/dwpose/onnxdet.py +127 -0
  30. vace/annotators/dwpose/onnxpose.py +362 -0
  31. vace/annotators/dwpose/util.py +299 -0
  32. vace/annotators/dwpose/wholebody.py +80 -0
  33. vace/annotators/face.py +55 -0
  34. vace/annotators/flow.py +53 -0
  35. vace/annotators/frameref.py +118 -0
  36. vace/annotators/gdino.py +88 -0
  37. vace/annotators/gray.py +24 -0
  38. vace/annotators/inpainting.py +283 -0
  39. vace/annotators/layout.py +161 -0
  40. vace/annotators/mask.py +79 -0
  41. vace/annotators/maskaug.py +181 -0
  42. vace/annotators/midas/__init__.py +2 -0
  43. vace/annotators/midas/api.py +166 -0
  44. vace/annotators/midas/base_model.py +18 -0
  45. vace/annotators/midas/blocks.py +391 -0
  46. vace/annotators/midas/dpt_depth.py +107 -0
  47. vace/annotators/midas/midas_net.py +80 -0
  48. vace/annotators/midas/midas_net_custom.py +167 -0
  49. vace/annotators/midas/transforms.py +231 -0
  50. vace/annotators/midas/utils.py +193 -0
.gitattributes CHANGED
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/images/test.jpg filter=lfs diff=lfs merge=lfs -text
37
+ assets/images/test3.jpg filter=lfs diff=lfs merge=lfs -text
38
+ assets/materials/gr_infer_demo.jpg filter=lfs diff=lfs merge=lfs -text
39
+ assets/materials/gr_pre_demo.jpg filter=lfs diff=lfs merge=lfs -text
40
+ assets/materials/tasks.png filter=lfs diff=lfs merge=lfs -text
41
+ assets/materials/teaser.jpg filter=lfs diff=lfs merge=lfs -text
42
+ assets/videos/test.mp4 filter=lfs diff=lfs merge=lfs -text
UserGuide.md ADDED
@@ -0,0 +1,160 @@
1
+ # VACE User Guide
2
+
3
+ ## 1. Overall Steps
4
+
5
+ - Preparation: Identify the task type ([single task](#33-single-tasks) or [multi-task composition](#34-composition-task)) of your creative idea, and prepare all the required materials (images, videos, prompts, etc.).
6
+ - Preprocessing: Select the appropriate preprocessing method based on the task name, then preprocess your materials to meet the model's input requirements.
7
+ - Inference: Based on the preprocessed materials, perform VACE inference to obtain results.
8
+
9
+ ## 2. Preparations
10
+
11
+ ### 2.1 Task Definition
12
+
13
+ VACE, as a unified video generation solution, simultaneously supports Video Generation, Video Editing, and complex composition tasks. Specifically:
14
+
15
+ - Video Generation: No video input. Concepts are injected into the model through semantic understanding of text and reference materials, covering the **T2V** (Text-to-Video Generation) and **R2V** (Reference-to-Video Generation) tasks.
16
+ - Video Editing: Takes a video as input and modifies it at the pixel level, globally or locally, including the **V2V** (Video-to-Video Editing) and **MV2V** (Masked Video-to-Video Editing) tasks.
17
+ - Composition Task: Combines two or more of the single tasks above into a more complex task, such as **Reference Anything** (Face R2V + Object R2V), **Move Anything** (Frame R2V + Layout V2V), **Animate Anything** (R2V + Pose V2V), **Swap Anything** (R2V + Inpainting MV2V), and **Expand Anything** (Object R2V + Frame R2V + Outpainting MV2V).
18
+
19
+ Single tasks and compositional tasks are illustrated in the diagram below:
20
+
21
+ ![vace_task](assets/materials/tasks.png)
22
+
23
+
24
+ ### 2.2 Limitations
25
+
26
+ - Very high-resolution videos will be resized to an appropriate spatial size.
27
+ - Very long videos will be trimmed or uniformly sampled down to around 5 seconds.
28
+ - If you need longer videos, we recommend generating 5-second clips one by one while using the `firstclip` video extension task to keep temporal consistency (see the sketch below).
29
+
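+ A minimal sketch of this extension workflow, mirroring the `FrameRefExpandAnnotator` usage in `tests/test_annotators.py` (the clip path is a placeholder):
+
+ ```python
+ # Extend a previously generated 5s clip with the `firstclip` mode.
+ from vace.annotators.frameref import FrameRefExpandAnnotator
+ from vace.annotators.utils import read_video_frames, save_one_video
+
+ frames = read_video_frames('previous_clip.mp4')   # the clip you just generated (placeholder path)
+ anno = FrameRefExpandAnnotator({})                # no special config needed
+ src_frames, src_masks = anno.forward(frames=frames, mode='firstclip', expand_num=80)
+ save_one_video('src_video.mp4', src_frames, fps=16)  # original clip kept, new frames left gray
+ save_one_video('src_mask.mp4', src_masks, fps=16)    # white where new content should be generated
+ ```
+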
30
+ ## 3. Preprocessing
31
+ ### 3.1 VACE-Recognizable Inputs
32
+
33
+ User-collected materials need to be preprocessed into VACE-recognizable inputs, namely **`src_video`**, **`src_mask`**, **`src_ref_images`**, and **`prompt`**.
34
+ Specific descriptions are as follows:
35
+
36
+ - `src_video`: The video fed into the model for editing, such as a condition video (Depth, Pose, etc.) or the input video for in/outpainting. **Gray areas** (pixel value 127) mark missing video content. In the first-frame R2V task, the first frame is the reference frame and the subsequent frames are left gray; the missing parts of an in/outpainting `src_video` are likewise set to gray.
37
+ - `src_mask`: A 3D mask with the same shape as `src_video`. **White areas** represent the parts to be generated, while **black areas** represent the parts to be retained.
38
+ - `src_ref_images`: Reference images for R2V. Salient object segmentation can be applied so that the background is kept white.
39
+ - `prompt`: Text describing the content of the output video. Prompt expansion can be used to achieve better generation results for LTX-Video and for English users of Wan2.1. Use descriptive prompts rather than instructions.
40
+
41
+ Among them, `prompt` is required, while `src_video`, `src_mask`, and `src_ref_images` are optional. For instance, the MV2V task requires `src_video`, `src_mask`, and `prompt`, whereas the R2V task only requires `src_ref_images` and `prompt`.
42
+
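+ As a purely illustrative example of these conventions (in practice the preprocessing tools in Section 3.2 produce these files for you), the sketch below hand-builds a first-frame reference `src_video`/`src_mask` pair with NumPy and imageio:
+
+ ```python
+ # Illustration only: the first frame is the reference, the rest is missing (gray, 127).
+ import imageio
+ import numpy as np
+
+ first = imageio.imread('assets/images/test.jpg')  # reference first frame, shape (H, W, 3)
+ h, w = first.shape[:2]
+ num_frames = 81
+
+ gray = np.full((h, w, 3), 127, dtype=np.uint8)    # 127 marks missing video content
+ src_video = [first] + [gray.copy() for _ in range(num_frames - 1)]
+
+ black = np.zeros((h, w, 3), dtype=np.uint8)       # black: keep the first frame
+ white = np.full((h, w, 3), 255, dtype=np.uint8)   # white: generate the rest
+ src_mask = [black] + [white.copy() for _ in range(num_frames - 1)]
+
+ imageio.mimwrite('src_video.mp4', src_video, fps=16)
+ imageio.mimwrite('src_mask.mp4', src_mask, fps=16)
+ ```
+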
43
+ ### 3.2 Preprocessing Tools
44
+ Both the command line and a Gradio demo are supported.
45
+
46
+ 1) Command Line: You can refer to the `run_vace_preproccess.sh` script and invoke it based on the different task types. An example command is as follows:
47
+ ```bash
48
+ python vace/vace_preproccess.py --task depth --video assets/videos/test.mp4
49
+ ```
50
+
51
+ 2) Gradio Interactive: Launch the graphical interface for data preprocessing and perform preprocessing on the interface. The specific command is as follows:
52
+ ```bash
53
+ python vace/gradios/preprocess_demo.py
54
+ ```
55
+
56
+ ![gr_pre_demo](assets/materials/gr_pre_demo.jpg)
57
+
58
+
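+ The annotators listed in the next section can also be invoked directly from Python, following the pattern used in `tests/test_annotators.py`; for example, depth preprocessing:
+
+ ```python
+ # Sketch: run the depth annotator from Python instead of via the CLI.
+ from vace.annotators.depth import DepthVideoAnnotator
+ from vace.annotators.utils import read_video_frames, save_one_video
+
+ frames = read_video_frames('assets/videos/test.mp4')
+ anno = DepthVideoAnnotator({'PRETRAINED_MODEL': 'models/VACE-Annotators/depth/dpt_hybrid-midas-501f0c75.pt'})
+ depth_frames = anno.forward(frames)
+ save_one_video('depth_src_video.mp4', depth_frames, fps=16)
+ ```
+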
59
+ ### 3.3 Single Tasks
60
+
61
+ VACE is an all-in-one model supporting various task types, but each task type requires different preprocessing. The supported task types are described below:
62
+
63
+ | Task | Subtask | Annotator | Input modality | Params | Note |
64
+ |------------|----------------------|----------------------------|------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------|
65
+ | txt2vid | txt2vid | / | / | / | |
66
+ | control | depth | DepthVideoAnnotator | video | / | |
67
+ | control | flow | FlowVisAnnotator | video | / | |
68
+ | control | gray | GrayVideoAnnotator | video | / | |
69
+ | control | pose | PoseBodyFaceVideoAnnotator | video | / | |
70
+ | control | scribble | ScribbleVideoAnnotator | video | / | |
71
+ | control | layout_bbox | LayoutBboxAnnotator | two bboxes <br>'x1,y1,x2,y2 x1,y1,x2,y2' | / | Move linearly from the first box to the second box |
72
+ | control | layout_track | LayoutTrackAnnotator | video | mode='masktrack/bboxtrack/label/caption'<br>maskaug_mode(optional)='original/original_expand/hull/hull_expand/bbox/bbox_expand'<br>maskaug_ratio(optional)=0~1.0 | Mode represents different methods of subject tracking. |
73
+ | extension | frameref | FrameRefExpandAnnotator | image | mode='firstframe'<br>expand_num=80 (default) | |
74
+ | extension | frameref | FrameRefExpandAnnotator | image | mode='lastframe'<br>expand_num=80 (default) | |
75
+ | extension | frameref | FrameRefExpandAnnotator | two images<br>a.jpg,b.jpg | mode='firstlastframe'<br>expand_num=80 (default) | Images are separated by commas. |
76
+ | extension | clipref | FrameRefExpandAnnotator | video | mode='firstclip'<br>expand_num=80 (default) | |
77
+ | extension | clipref | FrameRefExpandAnnotator | video | mode='lastclip'<br>expand_num=80 (default) | |
78
+ | extension | clipref | FrameRefExpandAnnotator | two videos<br>a.mp4,b.mp4 | mode='firstlastclip'<br>expand_num=80 (default) | Videos are separated by commas. |
79
+ | repainting | inpainting_mask | InpaintingAnnotator | video | mode='salient' | Use salient as a fixed mask. |
80
+ | repainting | inpainting_mask | InpaintingAnnotator | video + mask | mode='mask' | Use mask as a fixed mask. |
81
+ | repainting | inpainting_bbox | InpaintingAnnotator | video + bbox<br>'x1, y1, x2, y2' | mode='bbox' | Use bbox as a fixed mask. |
82
+ | repainting | inpainting_masktrack | InpaintingAnnotator | video | mode='salientmasktrack' | Use salient mask for dynamic tracking. |
83
+ | repainting | inpainting_masktrack | InpaintingAnnotator | video | mode='salientbboxtrack' | Use salient bbox for dynamic tracking. |
84
+ | repainting | inpainting_masktrack | InpaintingAnnotator | video + mask | mode='masktrack' | Use mask for dynamic tracking. |
85
+ | repainting | inpainting_bboxtrack | InpaintingAnnotator | video + bbox<br>'x1, y1, x2, y2' | mode='bboxtrack' | Use bbox for dynamic tracking. |
86
+ | repainting | inpainting_label | InpaintingAnnotator | video + label | mode='label' | Use label for dynamic tracking. |
87
+ | repainting | inpainting_caption | InpaintingAnnotator | video + caption | mode='caption' | Use caption for dynamic tracking. |
88
+ | repainting | outpainting | OutpaintingVideoAnnotator | video | direction=left/right/up/down<br>expand_ratio=0~1.0 | Combine outpainting directions arbitrarily. |
89
+ | reference | image_reference | SubjectAnnotator | image | mode='salient/mask/bbox/salientmasktrack/salientbboxtrack/masktrack/bboxtrack/label/caption'<br>maskaug_mode(optional)='original/original_expand/hull/hull_expand/bbox/bbox_expand'<br>maskaug_ratio(optional)=0~1.0 | Use different methods to obtain the subject region. |
90
+
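+ For instance, the `outpainting` row above maps to `OutpaintingVideoAnnotator`; a minimal sketch following `tests/test_annotators.py`:
+
+ ```python
+ # Sketch: produce outpainting src_video / src_mask from an input video.
+ from vace.annotators.outpainting import OutpaintingVideoAnnotator
+ from vace.annotators.utils import read_video_frames, save_one_video
+
+ frames = read_video_frames('assets/videos/test.mp4')
+ anno = OutpaintingVideoAnnotator({'RETURN_MASK': True, 'KEEP_PADDING_RATIO': 1, 'MASK_COLOR': 'gray'})
+ ret = anno.forward(frames=frames, direction=['left', 'right'], expand_ratio=0.5)
+ save_one_video('src_video.mp4', ret['frames'], fps=16)  # padded frames, new regions filled gray
+ save_one_video('src_mask.mp4', ret['masks'], fps=16)    # white where content will be generated
+ ```
+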
91
+ ### 3.4 Composition Task
92
+
93
+ Moreover, VACE supports combining tasks to accomplish more complex objectives. The following examples illustrate how tasks can be combined, but these combinations are not limited to the examples provided:
94
+
95
+ | Task | Subtask | Annotator | Input modality | Params | Note |
96
+ |-------------|--------------------|----------------------------|--------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------|
97
+ | composition | reference_anything | ReferenceAnythingAnnotator | image_list | mode='salientmasktrack/salientbboxtrack/masktrack/bboxtrack/label/caption' | Input no more than three images. |
98
+ | composition | animate_anything | AnimateAnythingAnnotator | image + video | mode='salientmasktrack/salientbboxtrack/masktrack/bboxtrack/label/caption' | Video for conditional redrawing; images for reference generation. |
99
+ | composition | swap_anything | SwapAnythingAnnotator | image + video | mode='masktrack/bboxtrack/label/caption'<br>maskaug_mode(optional)='original/original_expand/hull/hull_expand/bbox/bbox_expand'<br>maskaug_ratio(optional)=0~1.0 | Video for conditional redrawing; images for reference generation.<br>Comma-separated mode: first for video, second for images. |
100
+ | composition | expand_anything | ExpandAnythingAnnotator | image + image_list | mode='masktrack/bboxtrack/label/caption'<br>direction=left/right/up/down<br>expand_ratio=0~1.0<br>expand_num=80 (default) | First image for extension edit; others for reference.<br>Comma-separated mode: first for video, second for images. |
101
+ | composition | move_anything | MoveAnythingAnnotator | image + two bboxes | expand_num=80 (default) | First image for initial frame reference; others represented by linear bbox changes. |
102
+ | composition | more_anything | ... | ... | ... | ... |
103
+
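+ The composition annotators are expected to follow the same construct-then-`forward` pattern as the single-task annotators; the sketch below is hypothetical in its exact configuration and `forward()` signature:
+
+ ```python
+ # Hypothetical sketch: exact config keys and forward() arguments are assumptions.
+ from PIL import Image
+ from vace.annotators.composition import ReferenceAnythingAnnotator  # module location assumed
+
+ refs = [Image.open(p).convert('RGB')
+         for p in ('assets/images/test.jpg', 'assets/images/test2.jpg')]  # at most three images
+ anno = ReferenceAnythingAnnotator({})
+ ret = anno.forward(refs, mode='salientbboxtrack')
+ ```
+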
104
+
105
+ ## 4. Model Inference
106
+
107
+ ### 4.1 Execution Methods
108
+
109
+ Both the command line and a Gradio demo are supported.
110
+
111
+ 1) Command Line: Refer to the `run_vace_ltx.sh` and `run_vace_wan.sh` scripts and invoke them based on the different task types. The input data needs to be preprocessed to obtain parameters such as `src_video`, `src_mask`, `src_ref_images` and `prompt`. An example command is as follows:
112
+ ```bash
113
+ python vace/vace_wan_inference.py --src_video <path-to-src-video> --src_mask <path-to-src-mask> --src_ref_images <paths-to-src-ref-images> --prompt <prompt> # wan
114
+ python vace/vace_ltx_inference.py --src_video <path-to-src-video> --src_mask <path-to-src-mask> --src_ref_images <paths-to-src-ref-images> --prompt <prompt> # ltx
115
+ ```
116
+
117
+ 2) Gradio Interactive: Launch the graphical interface for model inference and perform inference through interactions on the interface. The specific command is as follows:
118
+ ```bash
119
+ python vace/gradios/vace_wan_demo.py # wan
120
+ python vace/gradios/vace_ltx_demo.py # ltx
121
+ ```
122
+
123
+ ![gr_infer_demo](assets/materials/gr_infer_demo.jpg)
124
+
125
+ 3) End-to-End Inference: Refer to the `run_vace_pipeline.sh` script and invoke it based on different task types and input data. This pipeline includes both preprocessing and model inference, thereby requiring only user-provided materials. However, it offers relatively less flexibility. An example command is as follows:
126
+ ```bash
127
+ python vace/vace_pipeline.py --base wan --task depth --video <path-to-video> --prompt <prompt> # wan
128
+ python vace/vace_pipeline.py --base ltx --task depth --video <path-to-video> --prompt <prompt> # ltx
129
+ ```
130
+
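+ Besides the scripts above, the Wan pipeline can also be driven programmatically, as `app.py` does; a condensed sketch of that flow (placeholder paths, demo defaults for sizes and sampling):
+
+ ```python
+ # Condensed from app.py: load VACE-Wan, prepare the sources, and generate.
+ from vace.models.wan.wan_vace import WanVace
+ from vace.models.wan.configs import WAN_CONFIGS, SIZE_CONFIGS
+
+ pipe = WanVace(config=WAN_CONFIGS['vace-1.3B'],
+                checkpoint_dir='models/VACE-Wan2.1-1.3B-Preview',
+                device_id=0, rank=0, t5_fsdp=False, dit_fsdp=False, use_usp=False)
+
+ src_video, src_mask, src_ref_images = pipe.prepare_source(
+     ['src_video.mp4'], ['src_mask.mp4'], [['src_ref_image_1.png']],
+     num_frames=81, image_size=SIZE_CONFIGS['480*832'], device=pipe.device)
+
+ video = pipe.generate('<prompt>', src_video, src_mask, src_ref_images,
+                       size=(832, 480), sampling_steps=25, guide_scale=6.0, shift=8.0,
+                       context_scale=1.0, n_prompt=pipe.config.sample_neg_prompt,
+                       seed=2025, offload_model=True)
+ ```
+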
131
+ ### 4.2 Inference Examples
132
+
133
+ We provide test examples for the different tasks so that users can validate them according to their needs. Each example includes the **task**, **subtask**, **original inputs** (ori_video and ori_images), **model inputs** (src_video, src_mask, src_ref_images, prompt), and **model outputs**.
134
+
135
+ | task | subtask | src_video | src_mask | src_ref_images | out_video | prompt | ori_video | ori_images |
136
+ |-------------|--------------------|----------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
137
+ | txt2vid | txt2vid | | | | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/txt2vid/out_video.mp4"></video> | 狂风巨浪的大海,镜头缓缓推进,一艘渺小的帆船在汹涌的波涛中挣扎漂荡。海面上白沫翻滚,帆船时隐时现,仿佛随时可能被巨浪吞噬。天空乌云密布,雷声轰鸣,海鸥在空中盘旋尖叫。帆船上的人们紧紧抓住缆绳,努力保持平衡。画面风格写实,充满紧张和动感。近景特写,强调风浪的冲击力和帆船的摇晃 | | |
138
+ | extension | firstframe | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/firstframe/src_video.mp4"></video> | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/firstframe/src_mask.mp4"></video> | | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/firstframe/out_video.mp4"></video> | 纪实摄影风格,前景是一位中国越野爱好者坐在越野车上,手持车载电台正在进行通联。他五官清晰,表情专注,眼神坚定地望向前方。越野车停在户外,车身略显脏污,显示出经历过的艰难路况。镜头从车外缓缓拉近,最后定格在人物的面部特写上,展现出他的坚定与热情。中景到近景,动态镜头运镜。 | | <img style="width: auto; height: 200px; object-fit: contain;" src="benchmarks/VACE-Benchmark/assets/examples/firstframe/ori_image_1.png"> |
139
+ | repainting | inpainting | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/inpainting/src_video.mp4"></video> | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/inpainting/src_mask.mp4"></video> | | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/inpainting/out_video.mp4"></video> | 一只巨大的金色凤凰从繁华的城市上空展翅飞过,羽毛如火焰般璀璨,闪烁着温暖的光辉,翅膀雄伟地展开。凤凰高昂着头,目光炯炯,轻轻扇动翅膀,散发出淡淡的光芒。下方是熙熙攘攘的市中心,人群惊叹,车水马龙,红蓝两色的霓虹灯在夜空下闪烁。镜头俯视城市街道,捕捉这一壮丽的景象,营造出既神秘又辉煌的氛围。 | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/inpainting/ori_video.mp4"></video> | |
140
+ | repainting | outpainting | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/outpainting/src_video.mp4"></video> | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/outpainting/src_mask.mp4"></video> | | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/outpainting/out_video.mp4"></video> | 赛博朋克风格,无人机俯瞰视角下的现代西安城墙,镜头穿过永宁门时泛起金色涟漪,城墙砖块化作数据流重组为唐代长安城。周围的街道上流动的人群和飞驰的机械交通工具交织在一起,现代与古代的交融,城墙上的灯光闪烁,形成时空隧道的效果。全息投影技术展现历史变迁,粒子重组特效细腻逼真。大远景逐渐过渡到特写,聚焦于城门特效。 | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/outpainting/ori_video.mp4"></video> | |
141
+ | control | depth | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/depth/src_video.mp4"></video> | | | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/depth/out_video.mp4"></video> | 一群年轻人在天空之城拍摄集体照。画面中,一对年轻情侣手牵手,轻声细语,相视而笑,周围是飞翔的彩色热气球和闪烁的星星,营造出浪漫的氛围。天空中,暖阳透过飘浮的云朵,洒下斑驳的光影。镜头以近景特写开始,随着情侣间的亲密互动,缓缓拉远。 | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/depth/ori_video.mp4"></video> | |
142
+ | control | flow | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/flow/src_video.mp4"></video> | | | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/flow/out_video.mp4"></video> | 纪实摄影风格,一颗鲜红的小番茄缓缓落入盛着牛奶的玻璃杯中,溅起晶莹的水花。画面以慢镜头捕捉这一瞬间,水花在空中绽放,形成美丽的弧线。玻璃杯中的牛奶纯白,番茄的鲜红与之形成鲜明对比。背景简洁,突出主体。近景特写,垂直俯视视角,展现细节之美。 | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/flow/ori_video.mp4"></video> | |
143
+ | control | gray | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/gray/src_video.mp4"></video> | | | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/gray/out_video.mp4"></video> | 镜头缓缓向右平移,身穿淡黄色坎肩长裙的长发女孩面对镜头露出灿烂的漏齿微笑。她的长发随风轻扬,眼神明亮而充满活力。背景是秋天红色和黄色的树叶,阳光透过树叶的缝隙洒下斑驳光影,营造出温馨自然的氛围。画面风格清新自然,仿佛夏日午后的一抹清凉。中景人像,强调自然光效和细腻的皮肤质感。 | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/gray/ori_video.mp4"></video> | |
144
+ | control | pose | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/pose/src_video.mp4"></video> | | | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/pose/out_video.mp4"></video> | 在一个热带的庆祝派对上,一家人围坐在椰子树下的长桌旁。桌上摆满了异国风味的美食。长辈们愉悦地交谈,年轻人兴奋地举杯碰撞,孩子们在沙滩上欢乐奔跑。背景中是湛蓝的海洋和明亮的阳光,营造出轻松的气氛。镜头以动态中景捕捉每个开心的瞬间,温暖的阳光映照着他们幸福的面庞。 | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/pose/ori_video.mp4"></video> | |
145
+ | control | scribble | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/scribble/src_video.mp4"></video> | | | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/scribble/out_video.mp4"></video> | 画面中荧光色彩的无人机从极低空高速掠过超现实主义风格的西安古城墙,尘埃反射着阳光。镜头快速切换至城墙上的砖石特写,阳光温暖地洒落,勾勒出每一块砖块的细腻纹理。整体画质清晰华丽,运镜流畅如水。 | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/scribble/ori_video.mp4"></video> | |
146
+ | control | layout | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/layout/src_video.mp4"></video> | | | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/layout/out_video.mp4"></video> | 视频展示了一只成鸟在树枝上的巢中喂养它的幼鸟。成鸟在喂食的过程中,幼鸟张开嘴巴等待食物。随后,成鸟飞走,幼鸟继续等待。成鸟再次飞回,带回食物喂养幼鸟。整个视频的拍摄角度固定,聚焦于巢穴和鸟类的互动,背景是模糊的绿色植被,强调了鸟类的自然行为和生态环境。 | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/layout/ori_video.mp4"></video> | |
147
+ | reference | face | | | <img style="width: auto; height: 200px; object-fit: contain;" src="benchmarks/VACE-Benchmark/assets/examples/face/src_ref_image_1.png"> | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/face/out_video.mp4"></video> | 视频展示了一位长着尖耳朵的老人,他有一头银白色的长发和小胡子,穿着一件色彩斑斓的长袍,内搭金色衬衫,散发出神秘与智慧的气息。背景为一个华丽宫殿的内部,金碧辉煌。灯光明亮,照亮他脸上的神采奕奕。摄像机旋转动态拍摄,捕捉老人轻松挥手的动作。 | | <img style="width: auto; height: 200px; object-fit: contain;" src="benchmarks/VACE-Benchmark/assets/examples/face/ori_image_1.png"> |
148
+ | reference | object | | | <img style="width: auto; height: 200px; object-fit: contain;" src="benchmarks/VACE-Benchmark/assets/examples/object/src_ref_image_1.png"> | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/object/out_video.mp4"></video> | 经典游戏角色马里奥在绿松石色水下世界中,四周环绕着珊瑚和各种各样的热带鱼。马里奥兴奋地向上跳起,摆出经典的欢快姿势,身穿鲜明的蓝色潜水服,红色的潜水面罩上印有“M”标志,脚上是一双潜水靴。背景中,水泡随波逐流,浮现出一个巨大而友好的海星。摄像机从水底向上快速移动,捕捉他跃出水面的瞬间,灯光明亮而流动。该场景融合了动画与幻想元素,令人惊叹。 | | <img style="width: auto; height: 200px; object-fit: contain;" src="benchmarks/VACE-Benchmark/assets/examples/object/ori_image_1.png"> |
149
+ | composition | reference_anything | | | <img style="width: auto; height: 200px; object-fit: contain;" src="benchmarks/VACE-Benchmark/assets/examples/reference_anything/src_ref_image_1.png">,<img style="width: auto; height: 200px; object-fit: contain;" src="benchmarks/VACE-Benchmark/assets/examples/reference_anything/src_ref_image_2.png"> | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/reference_anything/out_video.mp4"></video> | 一名打扮成超人的男子自信地站着,面对镜头,肩头有一只充满活力的毛绒黄色鸭子。他留着整齐的短发和浅色胡须,鸭子有橙色的喙和脚,它的翅膀稍微展开,脚分开以保持稳定。他的表情严肃而坚定。他穿着标志性的蓝红超人服装,胸前有黄色“S”标志。斗篷在他身后飘逸。背景有行人。相机位于视线水平,捕捉角色的整个上半身。灯光均匀明亮。 | | <img style="width: auto; height: 200px; object-fit: contain;" src="benchmarks/VACE-Benchmark/assets/examples/reference_anything/ori_image_1.png">,<img style="width: auto; height: 200px; object-fit: contain;" src="benchmarks/VACE-Benchmark/assets/examples/reference_anything/ori_image_2.png"> |
150
+ | composition | swap_anything | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/swap_anything/src_video.mp4"></video> | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/swap_anything/src_mask.mp4"></video> | <img style="width: auto; height: 200px; object-fit: contain;" src="benchmarks/VACE-Benchmark/assets/examples/swap_anything/src_ref_image_1.png"> | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/swap_anything/out_video.mp4"></video> | 视频展示了一个人在宽阔的草原上骑马。他有淡紫色长发,穿着传统服饰白上衣黑裤子,动画建模画风,看起来像是在进行某种户外活动或者是在进行某种表演。背景是壮观的山脉和多云的天空,给人一种宁静而广阔的感觉。整个视频的拍摄角度是固定的,重点展示了骑手和他的马。 | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/swap_anything/ori_video.mp4"></video> | <img style="width: auto; height: 200px; object-fit: contain;" src="benchmarks/VACE-Benchmark/assets/examples/swap_anything/ori_image_1.jpg"> |
151
+ | composition | expand_anything | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/expand_anything/src_video.mp4"></video> | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/expand_anything/src_mask.mp4"></video> | <img style="width: auto; height: 200px; object-fit: contain;" src="benchmarks/VACE-Benchmark/assets/examples/expand_anything/src_ref_image_1.png"> | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/expand_anything/out_video.mp4"></video> | 古典油画风格,背景是一条河边,画面中央一位成熟优雅的女人,穿着长裙坐在椅子上。她双手从怀里取出打开的红色心形墨镜戴上。固定机位。 | | <img style="width: auto; height: 200px; object-fit: contain;" src="benchmarks/VACE-Benchmark/assets/examples/expand_anything/ori_image_1.jpeg">,<img style="width: auto; height: 200px; object-fit: contain;" src="benchmarks/VACE-Benchmark/assets/examples/expand_anything/ori_image_2.png"> |
152
+
153
+ ## 5. Limitations
154
+
155
+ - VACE-LTX-Video-0.9
156
+ - The prompt significantly impacts video generation quality on LTX-Video. It must be extended in accordance with the method described in this [system prompt](https://huggingface.co/spaces/Lightricks/LTX-Video-Playground/blob/main/assets/system_prompt_i2v.txt). We also provide an input parameter (`--use_prompt_extend`) to enable prompt extension.
157
+ - This model is intended for experimental research validation within the VACE paper and may not guarantee performance in real-world scenarios. However, its inference speed is very fast, capable of creating a video in 25 seconds with 40 steps on an A100 GPU, making it suitable for preliminary data and creative validation.
158
+ - VACE-Wan2.1-1.3B-Preview
159
+ - This model mainly keeps the original Wan2.1-T2V-1.3B's video quality while supporting various tasks.
160
+ - When you encounter failure cases with specific tasks, we recommend trying again with a different seed and adjusting the prompt.
app.py ADDED
@@ -0,0 +1,278 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+
4
+ import argparse
5
+ import os
6
+ import sys
7
+ import datetime
8
+ import imageio
9
+ import numpy as np
10
+ import torch
11
+ import gradio as gr
12
+
13
+ sys.path.insert(0, os.path.sep.join(os.path.realpath(__file__).split(os.path.sep)[:-3]))
14
+ import wan
15
+ from vace.models.wan.wan_vace import WanVace
16
+ from vace.models.wan.configs import WAN_CONFIGS, SIZE_CONFIGS
17
+
18
+
19
+ class FixedSizeQueue:
20
+ def __init__(self, max_size):
21
+ self.max_size = max_size
22
+ self.queue = []
23
+ def add(self, item):
24
+ self.queue.insert(0, item)
25
+ if len(self.queue) > self.max_size:
26
+ self.queue.pop()
27
+ def get(self):
28
+ return self.queue
29
+ def __repr__(self):
30
+ return str(self.queue)
31
+
32
+
33
+ class VACEInference:
34
+ def __init__(self, cfg, skip_load=False, gallery_share=True, gallery_share_limit=5):
35
+ self.cfg = cfg
36
+ self.save_dir = cfg.save_dir
37
+ self.gallery_share = gallery_share
38
+ self.gallery_share_data = FixedSizeQueue(max_size=gallery_share_limit)
39
+ if not skip_load:
40
+ self.pipe = WanVace(
41
+ config=WAN_CONFIGS['vace-1.3B'],
42
+ checkpoint_dir=cfg.ckpt_dir,
43
+ device_id=0,
44
+ rank=0,
45
+ t5_fsdp=False,
46
+ dit_fsdp=False,
47
+ use_usp=False,
48
+ )
49
+
50
+ def create_ui(self, *args, **kwargs):
51
+ gr.Markdown("""
52
+ <div style="text-align: center; font-size: 24px; font-weight: bold; margin-bottom: 15px;">
53
+ <a href="https://ali-vilab.github.io/VACE-Page/" style="text-decoration: none; color: inherit;">VACE-WAN Demo</a>
54
+ </div>
55
+ """)
56
+ with gr.Row(variant='panel', equal_height=True):
57
+ with gr.Column(scale=1, min_width=0):
58
+ self.src_video = gr.Video(
59
+ label="src_video",
60
+ sources=['upload'],
61
+ value=None,
62
+ interactive=True)
63
+ with gr.Column(scale=1, min_width=0):
64
+ self.src_mask = gr.Video(
65
+ label="src_mask",
66
+ sources=['upload'],
67
+ value=None,
68
+ interactive=True)
69
+ #
70
+ with gr.Row(variant='panel', equal_height=True):
71
+ with gr.Column(scale=1, min_width=0):
72
+ with gr.Row(equal_height=True):
73
+ self.src_ref_image_1 = gr.Image(label='src_ref_image_1',
74
+ height=200,
75
+ interactive=True,
76
+ type='filepath',
77
+ image_mode='RGB',
78
+ sources=['upload'],
79
+ elem_id="src_ref_image_1",
80
+ format='png')
81
+ self.src_ref_image_2 = gr.Image(label='src_ref_image_2',
82
+ height=200,
83
+ interactive=True,
84
+ type='filepath',
85
+ image_mode='RGB',
86
+ sources=['upload'],
87
+ elem_id="src_ref_image_2",
88
+ format='png')
89
+ self.src_ref_image_3 = gr.Image(label='src_ref_image_3',
90
+ height=200,
91
+ interactive=True,
92
+ type='filepath',
93
+ image_mode='RGB',
94
+ sources=['upload'],
95
+ elem_id="src_ref_image_3",
96
+ format='png')
97
+ with gr.Row(variant='panel', equal_height=True):
98
+ with gr.Column(scale=1):
99
+ self.prompt = gr.Textbox(
100
+ show_label=False,
101
+ placeholder="positive_prompt_input",
102
+ elem_id='positive_prompt',
103
+ container=True,
104
+ autofocus=True,
105
+ elem_classes='type_row',
106
+ visible=True,
107
+ lines=2)
108
+ self.negative_prompt = gr.Textbox(
109
+ show_label=False,
110
+ value=self.pipe.config.sample_neg_prompt,
111
+ placeholder="negative_prompt_input",
112
+ elem_id='negative_prompt',
113
+ container=True,
114
+ autofocus=False,
115
+ elem_classes='type_row',
116
+ visible=True,
117
+ interactive=True,
118
+ lines=1)
119
+ #
120
+ with gr.Row(variant='panel', equal_height=True):
121
+ with gr.Column(scale=1, min_width=0):
122
+ with gr.Row(equal_height=True):
123
+ self.shift_scale = gr.Slider(
124
+ label='shift_scale',
125
+ minimum=0.0,
126
+ maximum=10.0,
127
+ step=1.0,
128
+ value=8.0,
129
+ interactive=True)
130
+ self.sample_steps = gr.Slider(
131
+ label='sample_steps',
132
+ minimum=1,
133
+ maximum=100,
134
+ step=1,
135
+ value=25,
136
+ interactive=True)
137
+ self.context_scale = gr.Slider(
138
+ label='context_scale',
139
+ minimum=0.0,
140
+ maximum=2.0,
141
+ step=0.1,
142
+ value=1.0,
143
+ interactive=True)
144
+ self.guide_scale = gr.Slider(
145
+ label='guide_scale',
146
+ minimum=1,
147
+ maximum=10,
148
+ step=0.5,
149
+ value=6.0,
150
+ interactive=True)
151
+ self.infer_seed = gr.Slider(minimum=-1,
152
+ maximum=10000000,
153
+ value=2025,
154
+ label="Seed")
155
+ #
156
+ with gr.Accordion(label="Usable without source video", open=False):
157
+ with gr.Row(equal_height=True):
158
+ self.output_height = gr.Textbox(
159
+ label='resolutions_height',
160
+ value=480,
161
+ interactive=True)
162
+ self.output_width = gr.Textbox(
163
+ label='resolutions_width',
164
+ value=832,
165
+ interactive=True)
166
+ self.frame_rate = gr.Textbox(
167
+ label='frame_rate',
168
+ value=16,
169
+ interactive=True)
170
+ self.num_frames = gr.Textbox(
171
+ label='num_frames',
172
+ value=81,
173
+ interactive=True)
174
+ #
175
+ with gr.Row(equal_height=True):
176
+ with gr.Column(scale=5):
177
+ self.generate_button = gr.Button(
178
+ value='Run',
179
+ elem_classes='type_row',
180
+ elem_id='generate_button',
181
+ visible=True)
182
+ with gr.Column(scale=1):
183
+ self.refresh_button = gr.Button(value='\U0001f504') # 🔄
184
+ #
185
+ self.output_gallery = gr.Gallery(
186
+ label="output_gallery",
187
+ value=[],
188
+ interactive=False,
189
+ allow_preview=True,
190
+ preview=True)
191
+
192
+
193
+ def generate(self, output_gallery, src_video, src_mask, src_ref_image_1, src_ref_image_2, src_ref_image_3, prompt, negative_prompt, shift_scale, sample_steps, context_scale, guide_scale, infer_seed, output_height, output_width, frame_rate, num_frames):
194
+ output_height, output_width, frame_rate, num_frames = int(output_height), int(output_width), int(frame_rate), int(num_frames)
195
+ src_ref_images = [x for x in [src_ref_image_1, src_ref_image_2, src_ref_image_3] if
196
+ x is not None]
197
+ src_video, src_mask, src_ref_images = self.pipe.prepare_source([src_video],
198
+ [src_mask],
199
+ [src_ref_images],
200
+ num_frames=num_frames,
201
+ image_size=SIZE_CONFIGS[f"{output_height}*{output_width}"],
202
+ device=self.pipe.device)
203
+ video = self.pipe.generate(
204
+ prompt,
205
+ src_video,
206
+ src_mask,
207
+ src_ref_images,
208
+ size=(output_width, output_height),
209
+ context_scale=context_scale,
210
+ shift=shift_scale,
211
+ sampling_steps=sample_steps,
212
+ guide_scale=guide_scale,
213
+ n_prompt=negative_prompt,
214
+ seed=infer_seed,
215
+ offload_model=True)
216
+
217
+ name = '{0:%Y%m%d%-H%M%S}'.format(datetime.datetime.now())
218
+ video_path = os.path.join(self.save_dir, f'cur_gallery_{name}.mp4')
219
+ video_frames = (torch.clamp(video / 2 + 0.5, min=0.0, max=1.0).permute(1, 2, 3, 0) * 255).cpu().numpy().astype(np.uint8)
220
+
221
+ try:
222
+ writer = imageio.get_writer(video_path, fps=frame_rate, codec='libx264', quality=8, macro_block_size=1)
223
+ for frame in video_frames:
224
+ writer.append_data(frame)
225
+ writer.close()
226
+ print(video_path)
227
+ except Exception as e:
228
+ raise gr.Error(f"Video save error: {e}")
229
+
230
+ if self.gallery_share:
231
+ self.gallery_share_data.add(video_path)
232
+ return self.gallery_share_data.get()
233
+ else:
234
+ return [video_path]
235
+
236
+ def set_callbacks(self, **kwargs):
237
+ self.gen_inputs = [self.output_gallery, self.src_video, self.src_mask, self.src_ref_image_1, self.src_ref_image_2, self.src_ref_image_3, self.prompt, self.negative_prompt, self.shift_scale, self.sample_steps, self.context_scale, self.guide_scale, self.infer_seed, self.output_height, self.output_width, self.frame_rate, self.num_frames]
238
+ self.gen_outputs = [self.output_gallery]
239
+ self.generate_button.click(self.generate,
240
+ inputs=self.gen_inputs,
241
+ outputs=self.gen_outputs,
242
+ queue=True)
243
+ self.refresh_button.click(lambda x: self.gallery_share_data.get() if self.gallery_share else x, inputs=[self.output_gallery], outputs=[self.output_gallery])
244
+
245
+
246
+ if __name__ == '__main__':
247
+ parser = argparse.ArgumentParser(description='Argparser for VACE-LTXV Demo:\n')
248
+ parser.add_argument('--server_port', dest='server_port', help='', type=int, default=7860)
249
+ parser.add_argument('--server_name', dest='server_name', help='', default='0.0.0.0')
250
+ parser.add_argument('--root_path', dest='root_path', help='', default=None)
251
+ parser.add_argument('--save_dir', dest='save_dir', help='', default='cache')
252
+ parser.add_argument(
253
+ "--ckpt_dir",
254
+ type=str,
255
+ default='models/VACE-Wan2.1-1.3B-Preview',
256
+ help="The path to the checkpoint directory.",
257
+ )
258
+ parser.add_argument(
259
+ "--offload_to_cpu",
260
+ action="store_true",
261
+ help="Offloading unnecessary computations to CPU.",
262
+ )
263
+
264
+ args = parser.parse_args()
265
+
266
+ if not os.path.exists(args.save_dir):
267
+ os.makedirs(args.save_dir, exist_ok=True)
268
+
269
+ with gr.Blocks() as demo:
270
+ infer_gr = VACEInference(args, skip_load=False, gallery_share=True, gallery_share_limit=5)
271
+ infer_gr.create_ui()
272
+ infer_gr.set_callbacks()
273
+ allowed_paths = [args.save_dir]
274
+ demo.queue(status_update_rate=1).launch(server_name=args.server_name,
275
+ server_port=args.server_port,
276
+ root_path=args.root_path,
277
+ allowed_paths=allowed_paths,
278
+ show_error=True, debug=True)
assets/images/test.jpg ADDED

Git LFS Details

  • SHA256: 71549d76843c4ee220f37f45e87f0dfc22079d1bc5fbe3f52fe2ded2b9454a3b
  • Pointer size: 131 Bytes
  • Size of remote file: 143 kB
assets/images/test2.jpg ADDED
assets/images/test3.jpg ADDED

Git LFS Details

  • SHA256: bee71955dac07594b21937c2354ab5b7bd3f3321447202476178dab5ceead497
  • Pointer size: 131 Bytes
  • Size of remote file: 214 kB
assets/masks/test.png ADDED
assets/masks/test2.png ADDED
assets/materials/gr_infer_demo.jpg ADDED

Git LFS Details

  • SHA256: 9b4f0df3c602da88e707262029d78284b3b5857e2bac413edef6f117e3ddb8be
  • Pointer size: 131 Bytes
  • Size of remote file: 320 kB
assets/materials/gr_pre_demo.jpg ADDED

Git LFS Details

  • SHA256: 6939180a97bd5abfc8d90bef6b31e949c591e2d75f5719e0eac150871d4aaae2
  • Pointer size: 131 Bytes
  • Size of remote file: 267 kB
assets/materials/tasks.png ADDED

Git LFS Details

  • SHA256: 1f1c4b3f3e6ae927880fbe2f9a46939cc98824bb56c2753c975a2e3c4820830b
  • Pointer size: 131 Bytes
  • Size of remote file: 709 kB
assets/materials/teaser.jpg ADDED

Git LFS Details

  • SHA256: 87ce75e8dcbf1536674d3a951326727e0aff80192f52cf7388b34c03f13f711f
  • Pointer size: 131 Bytes
  • Size of remote file: 892 kB
assets/videos/test.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2195efbd92773f1ee262154577c700e9c3b7a4d7d04b1a2ac421db0879c696b0
3
+ size 737090
assets/videos/test2.mp4 ADDED
Binary file (79.6 kB).
 
benchmarks/.gitkeep ADDED
File without changes
models/.gitkeep ADDED
File without changes
pyproject.toml ADDED
@@ -0,0 +1,75 @@
1
+ [build-system]
2
+ requires = ["setuptools>=42", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "vace"
7
+ version = "1.0.0"
8
+ description = "VACE: All-in-One Video Creation and Editing"
9
+ authors = [
10
+ { name = "VACE Team", email = "[email protected]" }
11
+ ]
12
+ requires-python = ">=3.10,<4.0"
13
+ readme = "README.md"
14
+ dependencies = [
15
+ "torch>=2.5.1",
16
+ "torchvision>=0.20.1",
17
+ "opencv-python>=4.9.0.80",
18
+ "diffusers>=0.31.0",
19
+ "transformers>=4.49.0",
20
+ "tokenizers>=0.20.3",
21
+ "accelerate>=1.1.1",
22
+ "gradio>=5.0.0",
23
+ "numpy>=1.23.5,<2",
24
+ "tqdm",
25
+ "imageio",
26
+ "easydict",
27
+ "ftfy",
28
+ "dashscope",
29
+ "imageio-ffmpeg",
30
+ "flash_attn",
31
+ "decord",
32
+ "einops",
33
+ "scikit-image",
34
+ "scikit-learn",
35
+ "pycocotools",
36
+ "timm",
37
+ "onnxruntime-gpu",
38
+ "BeautifulSoup4"
39
+ ]
40
+
41
+ [project.optional-dependencies]
42
+ ltx = [
43
+ "ltx-video@git+https://github.com/Lightricks/[email protected]"
44
+ ]
45
+ wan = [
46
+ "wan@git+https://github.com/Wan-Video/Wan2.1"
47
+ ]
48
+ annotator = [
49
+ "insightface",
50
+ "sam-2@git+https://github.com/facebookresearch/sam2.git",
51
+ "segment-anything@git+https://github.com/facebookresearch/segment-anything.git",
52
+ "groundingdino@git+https://github.com/IDEA-Research/GroundingDINO.git",
53
+ "ram@git+https://github.com/xinyu1205/recognize-anything.git",
54
+ "raft@git+https://github.com/martin-chobanyan-sdc/RAFT.git"
55
+ ]
56
+
57
+ [project.urls]
58
+ homepage = "https://ali-vilab.github.io/VACE-Page/"
59
+ documentation = "https://ali-vilab.github.io/VACE-Page/"
60
+ repository = "https://github.com/ali-vilab/VACE"
61
+ hfmodel = "https://huggingface.co/collections/ali-vilab/vace-67eca186ff3e3564726aff38"
62
+ msmodel = "https://modelscope.cn/collections/VACE-8fa5fcfd386e43"
63
+ paper = "https://arxiv.org/abs/2503.07598"
64
+
65
+ [tool.setuptools]
66
+ packages = { find = {} }
67
+
68
+ [tool.black]
69
+ line-length = 88
70
+
71
+ [tool.isort]
72
+ profile = "black"
73
+
74
+ [tool.mypy]
75
+ strict = true
requirements.txt ADDED
@@ -0,0 +1 @@
1
+ -r requirements/framework.txt
requirements/annotator.txt ADDED
@@ -0,0 +1,6 @@
1
+ insightface
2
+ git+https://github.com/facebookresearch/sam2.git
3
+ git+https://github.com/facebookresearch/segment-anything.git
4
+ git+https://github.com/IDEA-Research/GroundingDINO.git
5
+ git+https://github.com/xinyu1205/recognize-anything.git
6
+ git+https://github.com/martin-chobanyan-sdc/RAFT.git
requirements/framework.txt ADDED
@@ -0,0 +1,26 @@
1
+ torch>=2.5.1
2
+ torchvision>=0.20.1
3
+ opencv-python>=4.9.0.80
4
+ diffusers>=0.31.0
5
+ transformers>=4.49.0
6
+ tokenizers>=0.20.3
7
+ accelerate>=1.1.1
8
+ gradio>=5.0.0
9
+ numpy>=1.23.5,<2
10
+ tqdm
11
+ imageio
12
+ easydict
13
+ ftfy
14
+ dashscope
15
+ imageio-ffmpeg
16
+ flash_attn
17
+ decord
18
+ einops
19
+ scikit-image
20
+ scikit-learn
21
+ pycocotools
22
+ timm
23
+ onnxruntime-gpu
24
+ BeautifulSoup4
25
+ #ltx-video@git+https://github.com/Lightricks/[email protected]
26
+ #wan@git+https://github.com/Wan-Video/Wan2.1
tests/test_annotators.py ADDED
@@ -0,0 +1,568 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+
4
+ import os
5
+ import unittest
6
+ import numpy as np
7
+ from PIL import Image
8
+
9
+ from vace.annotators.utils import read_video_frames
10
+ from vace.annotators.utils import save_one_video
11
+
12
+ class AnnotatorTest(unittest.TestCase):
13
+ def setUp(self):
14
+ print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
15
+ self.save_dir = './cache/test_annotator'
16
+ if not os.path.exists(self.save_dir):
17
+ os.makedirs(self.save_dir)
18
+ # load test image
19
+ self.image_path = './assets/images/test.jpg'
20
+ self.image = Image.open(self.image_path).convert('RGB')
21
+ # load test video
22
+ self.video_path = './assets/videos/test.mp4'
23
+ self.frames = read_video_frames(self.video_path)
24
+
25
+ def tearDown(self):
26
+ super().tearDown()
27
+
28
+ @unittest.skip('')
29
+ def test_annotator_gray_image(self):
30
+ from vace.annotators.gray import GrayAnnotator
31
+ cfg_dict = {}
32
+ anno_ins = GrayAnnotator(cfg_dict)
33
+ anno_image = anno_ins.forward(np.array(self.image))
34
+ save_path = os.path.join(self.save_dir, 'test_gray_image.png')
35
+ Image.fromarray(anno_image).save(save_path)
36
+ print(('Testing %s: %s' % (type(self).__name__, save_path)))
37
+
38
+ @unittest.skip('')
39
+ def test_annotator_gray_video(self):
40
+ from vace.annotators.gray import GrayAnnotator
41
+ cfg_dict = {}
42
+ anno_ins = GrayAnnotator(cfg_dict)
43
+ ret_frames = []
44
+ for frame in self.frames:
45
+ anno_frame = anno_ins.forward(np.array(frame))
46
+ ret_frames.append(anno_frame)
47
+ save_path = os.path.join(self.save_dir, 'test_gray_video.mp4')
48
+ save_one_video(save_path, ret_frames, fps=16)
49
+ print(('Testing %s: %s' % (type(self).__name__, save_path)))
50
+
51
+ @unittest.skip('')
52
+ def test_annotator_gray_video_2(self):
53
+ from vace.annotators.gray import GrayVideoAnnotator
54
+ cfg_dict = {}
55
+ anno_ins = GrayVideoAnnotator(cfg_dict)
56
+ ret_frames = anno_ins.forward(self.frames)
57
+ save_path = os.path.join(self.save_dir, 'test_gray_video_2.mp4')
58
+ save_one_video(save_path, ret_frames, fps=16)
59
+ print(('Testing %s: %s' % (type(self).__name__, save_path)))
60
+
61
+
62
+ @unittest.skip('')
63
+ def test_annotator_pose_image(self):
64
+ from vace.annotators.pose import PoseBodyFaceAnnotator
65
+ cfg_dict = {
66
+ "DETECTION_MODEL": "models/VACE-Annotators/pose/yolox_l.onnx",
67
+ "POSE_MODEL": "models/VACE-Annotators/pose/dw-ll_ucoco_384.onnx",
68
+ "RESIZE_SIZE": 1024
69
+ }
70
+ anno_ins = PoseBodyFaceAnnotator(cfg_dict)
71
+ anno_image = anno_ins.forward(np.array(self.image))
72
+ save_path = os.path.join(self.save_dir, 'test_pose_image.png')
73
+ Image.fromarray(anno_image).save(save_path)
74
+ print(('Testing %s: %s' % (type(self).__name__, save_path)))
75
+
76
+ @unittest.skip('')
77
+ def test_annotator_pose_video(self):
78
+ from vace.annotators.pose import PoseBodyFaceAnnotator
79
+ cfg_dict = {
80
+ "DETECTION_MODEL": "models/VACE-Annotators/pose/yolox_l.onnx",
81
+ "POSE_MODEL": "models/VACE-Annotators/pose/dw-ll_ucoco_384.onnx",
82
+ "RESIZE_SIZE": 1024
83
+ }
84
+ anno_ins = PoseBodyFaceAnnotator(cfg_dict)
85
+ ret_frames = []
86
+ for frame in self.frames:
87
+ anno_frame = anno_ins.forward(np.array(frame))
88
+ ret_frames.append(anno_frame)
89
+ save_path = os.path.join(self.save_dir, 'test_pose_video.mp4')
90
+ save_one_video(save_path, ret_frames, fps=16)
91
+ print(('Testing %s: %s' % (type(self).__name__, save_path)))
92
+
93
+ @unittest.skip('')
94
+ def test_annotator_pose_video_2(self):
95
+ from vace.annotators.pose import PoseBodyFaceVideoAnnotator
96
+ cfg_dict = {
97
+ "DETECTION_MODEL": "models/VACE-Annotators/pose/yolox_l.onnx",
98
+ "POSE_MODEL": "models/VACE-Annotators/pose/dw-ll_ucoco_384.onnx",
99
+ "RESIZE_SIZE": 1024
100
+ }
101
+ anno_ins = PoseBodyFaceVideoAnnotator(cfg_dict)
102
+ ret_frames = anno_ins.forward(self.frames)
103
+ save_path = os.path.join(self.save_dir, 'test_pose_video_2.mp4')
104
+ save_one_video(save_path, ret_frames, fps=16)
105
+ print(('Testing %s: %s' % (type(self).__name__, save_path)))
106
+
107
+ @unittest.skip('')
108
+ def test_annotator_depth_image(self):
109
+ from vace.annotators.depth import DepthAnnotator
110
+ cfg_dict = {
111
+ "PRETRAINED_MODEL": "models/VACE-Annotators/depth/dpt_hybrid-midas-501f0c75.pt"
112
+ }
113
+ anno_ins = DepthAnnotator(cfg_dict)
114
+ anno_image = anno_ins.forward(np.array(self.image))
115
+ save_path = os.path.join(self.save_dir, 'test_depth_image.png')
116
+ Image.fromarray(anno_image).save(save_path)
117
+ print(('Testing %s: %s' % (type(self).__name__, save_path)))
118
+
119
+ @unittest.skip('')
120
+ def test_annotator_depth_video(self):
121
+ from vace.annotators.depth import DepthAnnotator
122
+ cfg_dict = {
123
+ "PRETRAINED_MODEL": "models/VACE-Annotators/depth/dpt_hybrid-midas-501f0c75.pt"
124
+ }
125
+ anno_ins = DepthAnnotator(cfg_dict)
126
+ ret_frames = []
127
+ for frame in self.frames:
128
+ anno_frame = anno_ins.forward(np.array(frame))
129
+ ret_frames.append(anno_frame)
130
+ save_path = os.path.join(self.save_dir, 'test_depth_video.mp4')
131
+ save_one_video(save_path, ret_frames, fps=16)
132
+ print(('Testing %s: %s' % (type(self).__name__, save_path)))
133
+
134
+ @unittest.skip('')
135
+ def test_annotator_depth_video_2(self):
136
+ from vace.annotators.depth import DepthVideoAnnotator
137
+ cfg_dict = {
138
+ "PRETRAINED_MODEL": "models/VACE-Annotators/depth/dpt_hybrid-midas-501f0c75.pt"
139
+ }
140
+ anno_ins = DepthVideoAnnotator(cfg_dict)
141
+ ret_frames = anno_ins.forward(self.frames)
142
+ save_path = os.path.join(self.save_dir, 'test_depth_video_2.mp4')
143
+ save_one_video(save_path, ret_frames, fps=16)
144
+ print(('Testing %s: %s' % (type(self).__name__, save_path)))
145
+
146
+ @unittest.skip('')
147
+ def test_annotator_scribble_image(self):
148
+ from vace.annotators.scribble import ScribbleAnnotator
149
+ cfg_dict = {
150
+ "PRETRAINED_MODEL": "models/VACE-Annotators/scribble/anime_style/netG_A_latest.pth"
151
+ }
152
+ anno_ins = ScribbleAnnotator(cfg_dict)
153
+ anno_image = anno_ins.forward(np.array(self.image))
154
+ save_path = os.path.join(self.save_dir, 'test_scribble_image.png')
155
+ Image.fromarray(anno_image).save(save_path)
156
+ print(('Testing %s: %s' % (type(self).__name__, save_path)))
157
+
158
+ @unittest.skip('')
159
+ def test_annotator_scribble_video(self):
160
+ from vace.annotators.scribble import ScribbleAnnotator
161
+ cfg_dict = {
162
+ "PRETRAINED_MODEL": "models/VACE-Annotators/scribble/anime_style/netG_A_latest.pth"
163
+ }
164
+ anno_ins = ScribbleAnnotator(cfg_dict)
165
+ ret_frames = []
166
+ for frame in self.frames:
167
+ anno_frame = anno_ins.forward(np.array(frame))
168
+ ret_frames.append(anno_frame)
169
+ save_path = os.path.join(self.save_dir, 'test_scribble_video.mp4')
170
+ save_one_video(save_path, ret_frames, fps=16)
171
+ print(('Testing %s: %s' % (type(self).__name__, save_path)))
172
+
173
+ @unittest.skip('')
174
+ def test_annotator_scribble_video_2(self):
175
+ from vace.annotators.scribble import ScribbleVideoAnnotator
176
+ cfg_dict = {
177
+ "PRETRAINED_MODEL": "models/VACE-Annotators/scribble/anime_style/netG_A_latest.pth"
178
+ }
179
+ anno_ins = ScribbleVideoAnnotator(cfg_dict)
180
+ ret_frames = anno_ins.forward(self.frames)
181
+ save_path = os.path.join(self.save_dir, 'test_scribble_video_2.mp4')
182
+ save_one_video(save_path, ret_frames, fps=16)
183
+ print(('Testing %s: %s' % (type(self).__name__, save_path)))
184
+
185
+ @unittest.skip('')
186
+ def test_annotator_flow_video(self):
187
+ from vace.annotators.flow import FlowVisAnnotator
188
+ cfg_dict = {
189
+ "PRETRAINED_MODEL": "models/VACE-Annotators/flow/raft-things.pth"
190
+ }
191
+ anno_ins = FlowVisAnnotator(cfg_dict)
192
+ ret_frames = anno_ins.forward(self.frames)
193
+ save_path = os.path.join(self.save_dir, 'test_flow_video.mp4')
194
+ save_one_video(save_path, ret_frames, fps=16)
195
+ print(('Testing %s: %s' % (type(self).__name__, save_path)))
196
+
197
+ @unittest.skip('')
198
+ def test_annotator_frameref_video_1(self):
199
+ from vace.annotators.frameref import FrameRefExtractAnnotator
200
+ cfg_dict = {
201
+ "REF_CFG": [{"mode": "first", "proba": 0.1},
202
+ {"mode": "last", "proba": 0.1},
203
+ {"mode": "firstlast", "proba": 0.1},
204
+ {"mode": "random", "proba": 0.1}],
205
+ }
206
+ anno_ins = FrameRefExtractAnnotator(cfg_dict)
207
+ ret_frames, ret_masks = anno_ins.forward(self.frames, ref_num=10)
208
+ save_path = os.path.join(self.save_dir, 'test_frameref_video_1.mp4')
209
+ save_one_video(save_path, ret_frames, fps=16)
210
+ print(('Testing %s: %s' % (type(self).__name__, save_path)))
211
+ save_path = os.path.join(self.save_dir, 'test_frameref_mask_1.mp4')
212
+ save_one_video(save_path, ret_masks, fps=16)
213
+ print(('Testing %s: %s' % (type(self).__name__, save_path)))
214
+
215
+ @unittest.skip('')
216
+ def test_annotator_frameref_video_2(self):
217
+ from vace.annotators.frameref import FrameRefExpandAnnotator
218
+ cfg_dict = {}
219
+ anno_ins = FrameRefExpandAnnotator(cfg_dict)
220
+ ret_frames, ret_masks = anno_ins.forward(frames=self.frames, mode='lastclip', expand_num=50)
221
+ save_path = os.path.join(self.save_dir, 'test_frameref_video_2.mp4')
222
+ save_one_video(save_path, ret_frames, fps=16)
223
+ print(('Testing %s: %s' % (type(self).__name__, save_path)))
224
+ save_path = os.path.join(self.save_dir, 'test_frameref_mask_2.mp4')
225
+ save_one_video(save_path, ret_masks, fps=16)
226
+ print(('Testing %s: %s' % (type(self).__name__, save_path)))
227
+
228
+
229
+ @unittest.skip('')
230
+ def test_annotator_outpainting_1(self):
231
+ from vace.annotators.outpainting import OutpaintingAnnotator
232
+ cfg_dict = {
233
+ "RETURN_MASK": True,
234
+ "KEEP_PADDING_RATIO": 1,
235
+ "MASK_COLOR": "gray"
236
+ }
237
+ anno_ins = OutpaintingAnnotator(cfg_dict)
238
+ ret_data = anno_ins.forward(self.image, direction=['right', 'up', 'down'], expand_ratio=0.5)
239
+ save_path = os.path.join(self.save_dir, 'test_outpainting_image.png')
240
+ Image.fromarray(ret_data['image']).save(save_path)
241
+ print(('Testing %s: %s' % (type(self).__name__, save_path)))
242
+ save_path = os.path.join(self.save_dir, 'test_outpainting_mask.png')
243
+ Image.fromarray(ret_data['mask']).save(save_path)
244
+ print(('Testing %s: %s' % (type(self).__name__, save_path)))
245
+
246
+ @unittest.skip('')
247
+ def test_annotator_outpainting_video_1(self):
248
+ from vace.annotators.outpainting import OutpaintingVideoAnnotator
249
+ cfg_dict = {
250
+ "RETURN_MASK": True,
251
+ "KEEP_PADDING_RATIO": 1,
252
+ "MASK_COLOR": "gray"
253
+ }
254
+ anno_ins = OutpaintingVideoAnnotator(cfg_dict)
255
+ ret_data = anno_ins.forward(frames=self.frames, direction=['right', 'up', 'down'], expand_ratio=0.5)
256
+ save_path = os.path.join(self.save_dir, 'test_outpainting_video_1.mp4')
257
+ save_one_video(save_path, ret_data['frames'], fps=16)
258
+ print(('Testing %s: %s' % (type(self).__name__, save_path)))
259
+ save_path = os.path.join(self.save_dir, 'test_outpainting_mask_1.mp4')
260
+ save_one_video(save_path, ret_data['masks'], fps=16)
261
+ print(('Testing %s: %s' % (type(self).__name__, save_path)))
262
+
263
+ @unittest.skip('')
264
+ def test_annotator_outpainting_inner_1(self):
265
+ from vace.annotators.outpainting import OutpaintingInnerAnnotator
266
+ cfg_dict = {
267
+ "RETURN_MASK": True,
268
+ "KEEP_PADDING_RATIO": 1,
269
+ "MASK_COLOR": "gray"
270
+ }
271
+ anno_ins = OutpaintingInnerAnnotator(cfg_dict)
272
+ ret_data = anno_ins.forward(self.image, direction=['right', 'up', 'down'], expand_ratio=0.15)
273
+ save_path = os.path.join(self.save_dir, 'test_outpainting_inner_image.png')
274
+ Image.fromarray(ret_data['image']).save(save_path)
275
+ print(('Testing %s: %s' % (type(self).__name__, save_path)))
276
+ save_path = os.path.join(self.save_dir, 'test_outpainting_inner_mask.png')
277
+ Image.fromarray(ret_data['mask']).save(save_path)
278
+ print(('Testing %s: %s' % (type(self).__name__, save_path)))
279
+
280
+ @unittest.skip('')
281
+ def test_annotator_outpainting_inner_video_1(self):
282
+ from vace.annotators.outpainting import OutpaintingInnerVideoAnnotator
283
+ cfg_dict = {
284
+ "RETURN_MASK": True,
285
+ "KEEP_PADDING_RATIO": 1,
286
+ "MASK_COLOR": "gray"
287
+ }
288
+ anno_ins = OutpaintingInnerVideoAnnotator(cfg_dict)
289
+ ret_data = anno_ins.forward(self.frames, direction=['right', 'up', 'down'], expand_ratio=0.15)
290
+ save_path = os.path.join(self.save_dir, 'test_outpainting_inner_video_1.mp4')
291
+ save_one_video(save_path, ret_data['frames'], fps=16)
292
+ print(('Testing %s: %s' % (type(self).__name__, save_path)))
293
+ save_path = os.path.join(self.save_dir, 'test_outpainting_inner_mask_1.mp4')
294
+ save_one_video(save_path, ret_data['masks'], fps=16)
295
+ print(('Testing %s: %s' % (type(self).__name__, save_path)))
296
+
297
+ @unittest.skip('')
298
+ def test_annotator_salient(self):
299
+ from vace.annotators.salient import SalientAnnotator
300
+ cfg_dict = {
301
+ "PRETRAINED_MODEL": "models/VACE-Annotators/salient/u2net.pt",
302
+ }
303
+ anno_ins = SalientAnnotator(cfg_dict)
304
+ ret_data = anno_ins.forward(self.image)
305
+ save_path = os.path.join(self.save_dir, 'test_salient_image.png')
306
+ Image.fromarray(ret_data).save(save_path)
307
+ print(('Testing %s: %s' % (type(self).__name__, save_path)))
308
+
309
+ @unittest.skip('')
310
+ def test_annotator_salient_video(self):
311
+ from vace.annotators.salient import SalientVideoAnnotator
312
+ cfg_dict = {
313
+ "PRETRAINED_MODEL": "models/VACE-Annotators/salient/u2net.pt",
314
+ }
315
+ anno_ins = SalientVideoAnnotator(cfg_dict)
316
+ ret_frames = anno_ins.forward(self.frames)
317
+ save_path = os.path.join(self.save_dir, 'test_salient_video.mp4')
318
+ save_one_video(save_path, ret_frames, fps=16)
319
+ print(('Testing %s: %s' % (type(self).__name__, save_path)))
320
+
321
+ @unittest.skip('')
322
+ def test_annotator_layout_video(self):
323
+ from vace.annotators.layout import LayoutBboxAnnotator
324
+ cfg_dict = {
325
+ "RAM_TAG_COLOR_PATH": "models/VACE-Annotators/layout/ram_tag_color_list.txt",
326
+ }
327
+ anno_ins = LayoutBboxAnnotator(cfg_dict)
328
+ ret_frames = anno_ins.forward(bbox=[(544, 288, 744, 680), (1112, 240, 1280, 712)], frame_size=(720, 1280), num_frames=49, label='person')
329
+ save_path = os.path.join(self.save_dir, 'test_layout_video.mp4')
330
+ save_one_video(save_path, ret_frames, fps=16)
331
+ print(('Testing %s: %s' % (type(self).__name__, save_path)))
332
+
333
+ @unittest.skip('')
334
+ def test_annotator_layout_mask_video(self):
335
+ # salient
336
+ from vace.annotators.salient import SalientVideoAnnotator
337
+ cfg_dict = {
338
+ "PRETRAINED_MODEL": "models/VACE-Annotators/salient/u2net.pt",
339
+ }
340
+ anno_ins = SalientVideoAnnotator(cfg_dict)
341
+ salient_frames = anno_ins.forward(self.frames)
342
+
343
+ # mask layout
344
+ from vace.annotators.layout import LayoutMaskAnnotator
345
+ cfg_dict = {
346
+ "RAM_TAG_COLOR_PATH": "models/VACE-Annotators/layout/ram_tag_color_list.txt",
347
+ }
348
+ anno_ins = LayoutMaskAnnotator(cfg_dict)
349
+ ret_frames = anno_ins.forward(salient_frames, label='cat')
350
+ save_path = os.path.join(self.save_dir, 'test_mask_layout_video.mp4')
351
+ save_one_video(save_path, ret_frames, fps=16)
352
+ print(('Testing %s: %s' % (type(self).__name__, save_path)))
353
+
354
+ @unittest.skip('')
355
+ def test_annotator_layout_mask_video_2(self):
356
+ # salient
357
+ from vace.annotators.salient import SalientVideoAnnotator
358
+ cfg_dict = {
359
+ "PRETRAINED_MODEL": "models/VACE-Annotators/salient/u2net.pt",
360
+ }
361
+ anno_ins = SalientVideoAnnotator(cfg_dict)
362
+ salient_frames = anno_ins.forward(self.frames)
363
+
364
+ # mask layout
365
+ from vace.annotators.layout import LayoutMaskAnnotator
366
+ cfg_dict = {
367
+ "RAM_TAG_COLOR_PATH": "models/VACE-Annotators/layout/ram_tag_color_list.txt",
368
+ "USE_AUG": True
369
+ }
370
+ anno_ins = LayoutMaskAnnotator(cfg_dict)
371
+ ret_frames = anno_ins.forward(salient_frames, label='cat', mask_cfg={'mode': 'bbox_expand'})
372
+ save_path = os.path.join(self.save_dir, 'test_mask_layout_video_2.mp4')
373
+ save_one_video(save_path, ret_frames, fps=16)
374
+ print(('Testing %s: %s' % (type(self).__name__, save_path)))
375
+
376
+
377
+ @unittest.skip('')
378
+ def test_annotator_maskaug_video(self):
379
+ # salient
380
+ from vace.annotators.salient import SalientVideoAnnotator
381
+ cfg_dict = {
382
+ "PRETRAINED_MODEL": "models/VACE-Annotators/salient/u2net.pt",
383
+ }
384
+ anno_ins = SalientVideoAnnotator(cfg_dict)
385
+ salient_frames = anno_ins.forward(self.frames)
386
+
387
+ # mask aug
388
+ from vace.annotators.maskaug import MaskAugAnnotator
389
+ cfg_dict = {}
390
+ anno_ins = MaskAugAnnotator(cfg_dict)
391
+ ret_frames = anno_ins.forward(salient_frames, mask_cfg={'mode': 'hull_expand'})
392
+ save_path = os.path.join(self.save_dir, 'test_maskaug_video.mp4')
393
+ save_one_video(save_path, ret_frames, fps=16)
394
+ print(('Testing %s: %s' % (type(self).__name__, save_path)))
395
+
396
+
397
+ @unittest.skip('')
398
+ def test_annotator_ram(self):
399
+ from vace.annotators.ram import RAMAnnotator
400
+ cfg_dict = {
401
+ "TOKENIZER_PATH": "models/VACE-Annotators/ram/bert-base-uncased",
402
+ "PRETRAINED_MODEL": "models/VACE-Annotators/ram/ram_plus_swin_large_14m.pth",
403
+ }
404
+ anno_ins = RAMAnnotator(cfg_dict)
405
+ ret_data = anno_ins.forward(self.image)
406
+ print(ret_data)
407
+
408
+ @unittest.skip('')
409
+ def test_annotator_gdino_v1(self):
410
+ from vace.annotators.gdino import GDINOAnnotator
411
+ cfg_dict = {
412
+ "TOKENIZER_PATH": "models/VACE-Annotators/gdino/bert-base-uncased",
413
+ "CONFIG_PATH": "models/VACE-Annotators/gdino/GroundingDINO_SwinT_OGC_mod.py",
414
+ "PRETRAINED_MODEL": "models/VACE-Annotators/gdino/groundingdino_swint_ogc.pth",
415
+ }
416
+ anno_ins = GDINOAnnotator(cfg_dict)
417
+ ret_data = anno_ins.forward(self.image, caption="a cat and a vase")
418
+ print(ret_data)
419
+
420
+ @unittest.skip('')
421
+ def test_annotator_gdino_v2(self):
422
+ from vace.annotators.gdino import GDINOAnnotator
423
+ cfg_dict = {
424
+ "TOKENIZER_PATH": "models/VACE-Annotators/gdino/bert-base-uncased",
425
+ "CONFIG_PATH": "models/VACE-Annotators/gdino/GroundingDINO_SwinT_OGC_mod.py",
426
+ "PRETRAINED_MODEL": "models/VACE-Annotators/gdino/groundingdino_swint_ogc.pth",
427
+ }
428
+ anno_ins = GDINOAnnotator(cfg_dict)
429
+ ret_data = anno_ins.forward(self.image, classes=["cat", "vase"])
430
+ print(ret_data)
431
+
432
+ @unittest.skip('')
433
+ def test_annotator_gdino_with_ram(self):
434
+ from vace.annotators.gdino import GDINORAMAnnotator
435
+ cfg_dict = {
436
+ "RAM": {
437
+ "TOKENIZER_PATH": "models/VACE-Annotators/ram/bert-base-uncased",
438
+ "PRETRAINED_MODEL": "models/VACE-Annotators/ram/ram_plus_swin_large_14m.pth",
439
+ },
440
+ "GDINO": {
441
+ "TOKENIZER_PATH": "models/VACE-Annotators/gdino/bert-base-uncased",
442
+ "CONFIG_PATH": "models/VACE-Annotators/gdino/GroundingDINO_SwinT_OGC_mod.py",
443
+ "PRETRAINED_MODEL": "models/VACE-Annotators/gdino/groundingdino_swint_ogc.pth",
444
+ }
445
+
446
+ }
447
+ anno_ins = GDINORAMAnnotator(cfg_dict)
448
+ ret_data = anno_ins.forward(self.image)
449
+ print(ret_data)
450
+
451
+ @unittest.skip('')
452
+ def test_annotator_sam2(self):
453
+ from vace.annotators.sam2 import SAM2VideoAnnotator
454
+ from vace.annotators.utils import save_sam2_video
455
+ cfg_dict = {
456
+ "CONFIG_PATH": 'models/VACE-Annotators/sam2/configs/sam2.1/sam2.1_hiera_l.yaml',
457
+ "PRETRAINED_MODEL": 'models/VACE-Annotators/sam2/sam2.1_hiera_large.pt'
458
+ }
459
+ anno_ins = SAM2VideoAnnotator(cfg_dict)
460
+ ret_data = anno_ins.forward(video=self.video_path, input_box=[0, 0, 640, 480])
461
+ video_segments = ret_data['annotations']
462
+ save_path = os.path.join(self.save_dir, 'test_sam2_video')
463
+ if not os.path.exists(save_path):
464
+ os.makedirs(save_path)
465
+ save_sam2_video(video_path=self.video_path, video_segments=video_segments, output_video_path=save_path)
466
+ print(save_path)
467
+
468
+
469
+ @unittest.skip('')
470
+ def test_annotator_sam2salient(self):
471
+ from vace.annotators.sam2 import SAM2SalientVideoAnnotator
472
+ from vace.annotators.utils import save_sam2_video
473
+ cfg_dict = {
474
+ "SALIENT": {
475
+ "PRETRAINED_MODEL": "models/VACE-Annotators/salient/u2net.pt",
476
+ },
477
+ "SAM2": {
478
+ "CONFIG_PATH": 'models/VACE-Annotators/sam2/configs/sam2.1/sam2.1_hiera_l.yaml',
479
+ "PRETRAINED_MODEL": 'models/VACE-Annotators/sam2/sam2.1_hiera_large.pt'
480
+ }
481
+
482
+ }
483
+ anno_ins = SAM2SalientVideoAnnotator(cfg_dict)
484
+ ret_data = anno_ins.forward(video=self.video_path)
485
+ video_segments = ret_data['annotations']
486
+ save_path = os.path.join(self.save_dir, 'test_sam2salient_video')
487
+ if not os.path.exists(save_path):
488
+ os.makedirs(save_path)
489
+ save_sam2_video(video_path=self.video_path, video_segments=video_segments, output_video_path=save_path)
490
+ print(save_path)
491
+
492
+
493
+ @unittest.skip('')
494
+ def test_annotator_sam2gdinoram_video(self):
495
+ from vace.annotators.sam2 import SAM2GDINOVideoAnnotator
496
+ from vace.annotators.utils import save_sam2_video
497
+ cfg_dict = {
498
+ "GDINO": {
499
+ "TOKENIZER_PATH": "models/VACE-Annotators/gdino/bert-base-uncased",
500
+ "CONFIG_PATH": "models/VACE-Annotators/gdino/GroundingDINO_SwinT_OGC_mod.py",
501
+ "PRETRAINED_MODEL": "models/VACE-Annotators/gdino/groundingdino_swint_ogc.pth",
502
+ },
503
+ "SAM2": {
504
+ "CONFIG_PATH": 'models/VACE-Annotators/sam2/configs/sam2.1/sam2.1_hiera_l.yaml',
505
+ "PRETRAINED_MODEL": 'models/VACE-Annotators/sam2/sam2.1_hiera_large.pt'
506
+ }
507
+ }
508
+ anno_ins = SAM2GDINOVideoAnnotator(cfg_dict)
509
+ ret_data = anno_ins.forward(video=self.video_path, classes='cat')
510
+ video_segments = ret_data['annotations']
511
+ save_path = os.path.join(self.save_dir, 'test_sam2gdino_video')
512
+ if not os.path.exists(save_path):
513
+ os.makedirs(save_path)
514
+ save_sam2_video(video_path=self.video_path, video_segments=video_segments, output_video_path=save_path)
515
+ print(save_path)
516
+
517
+ @unittest.skip('')
518
+ def test_annotator_sam2_image(self):
519
+ from vace.annotators.sam2 import SAM2ImageAnnotator
520
+ cfg_dict = {
521
+ "CONFIG_PATH": 'models/VACE-Annotators/sam2/configs/sam2.1/sam2.1_hiera_l.yaml',
522
+ "PRETRAINED_MODEL": 'models/VACE-Annotators/sam2/sam2.1_hiera_large.pt'
523
+ }
524
+ anno_ins = SAM2ImageAnnotator(cfg_dict)
525
+ ret_data = anno_ins.forward(image=self.image, input_box=[0, 0, 640, 480])
526
+ print(ret_data)
527
+
528
+ # @unittest.skip('')
529
+ def test_annotator_prompt_extend(self):
530
+ from vace.annotators.prompt_extend import PromptExtendAnnotator
531
+ from vace.configs.prompt_preprocess import WAN_LM_ZH_SYS_PROMPT, WAN_LM_EN_SYS_PROMPT, LTX_LM_EN_SYS_PROMPT
532
+ cfg_dict = {
533
+ "MODEL_NAME": "models/VACE-Annotators/llm/Qwen2.5-3B-Instruct" # "Qwen2.5_3B"
534
+ }
535
+ anno_ins = PromptExtendAnnotator(cfg_dict)
536
+ ret_data = anno_ins.forward('一位男孩', system_prompt=WAN_LM_ZH_SYS_PROMPT)
537
+ print('wan_zh:', ret_data)
538
+ ret_data = anno_ins.forward('a boy', system_prompt=WAN_LM_EN_SYS_PROMPT)
539
+ print('wan_en:', ret_data)
540
+ ret_data = anno_ins.forward('a boy', system_prompt=WAN_LM_ZH_SYS_PROMPT)
541
+ print('wan_zh en:', ret_data)
542
+ ret_data = anno_ins.forward('a boy', system_prompt=LTX_LM_EN_SYS_PROMPT)
543
+ print('ltx_en:', ret_data)
544
+
545
+ from vace.annotators.utils import get_annotator
546
+ anno_ins = get_annotator(config_type='prompt', config_task='ltx_en', return_dict=False)
547
+ ret_data = anno_ins.forward('a boy', seed=2025)
548
+ print('ltx_en:', ret_data)
549
+ ret_data = anno_ins.forward('a boy')
550
+ print('ltx_en:', ret_data)
551
+ ret_data = anno_ins.forward('a boy', seed=2025)
552
+ print('ltx_en:', ret_data)
553
+
554
+ @unittest.skip('')
555
+ def test_annotator_prompt_extend_ds(self):
556
+ from vace.annotators.utils import get_annotator
557
+ # export DASH_API_KEY=''
558
+ anno_ins = get_annotator(config_type='prompt', config_task='wan_zh_ds', return_dict=False)
559
+ ret_data = anno_ins.forward('一位男孩', seed=2025)
560
+ print('wan_zh_ds:', ret_data)
561
+ ret_data = anno_ins.forward('a boy', seed=2025)
562
+ print('wan_zh_ds:', ret_data)
563
+
564
+
565
+ # ln -s your/path/annotator_models annotator_models
566
+ # PYTHONPATH=. python tests/test_annotators.py
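+ # To run a single case (hedged example; assumes Python >= 3.7 and that `tests` is importable as a package):
+ # PYTHONPATH=. python -m unittest tests.test_annotators -k test_annotator_prompt_extend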
567
+ if __name__ == '__main__':
568
+ unittest.main()
vace/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ from . import annotators
4
+ from . import configs
5
+ from . import models
6
+ from . import gradios
vace/annotators/__init__.py ADDED
@@ -0,0 +1,24 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ from .depth import DepthAnnotator, DepthVideoAnnotator
4
+ from .flow import FlowAnnotator, FlowVisAnnotator
5
+ from .frameref import FrameRefExtractAnnotator, FrameRefExpandAnnotator
6
+ from .gdino import GDINOAnnotator, GDINORAMAnnotator
7
+ from .gray import GrayAnnotator, GrayVideoAnnotator
8
+ from .inpainting import InpaintingAnnotator, InpaintingVideoAnnotator
9
+ from .layout import LayoutBboxAnnotator, LayoutMaskAnnotator, LayoutTrackAnnotator
10
+ from .maskaug import MaskAugAnnotator
11
+ from .outpainting import OutpaintingAnnotator, OutpaintingInnerAnnotator, OutpaintingVideoAnnotator, OutpaintingInnerVideoAnnotator
12
+ from .pose import PoseBodyFaceAnnotator, PoseBodyFaceVideoAnnotator, PoseAnnotator
13
+ from .ram import RAMAnnotator
14
+ from .salient import SalientAnnotator, SalientVideoAnnotator
15
+ from .sam import SAMImageAnnotator
16
+ from .sam2 import SAM2ImageAnnotator, SAM2VideoAnnotator, SAM2SalientVideoAnnotator, SAM2GDINOVideoAnnotator
17
+ from .scribble import ScribbleAnnotator, ScribbleVideoAnnotator
18
+ from .face import FaceAnnotator
19
+ from .subject import SubjectAnnotator
20
+ from .common import PlainImageAnnotator, PlainMaskAnnotator, PlainMaskAugAnnotator, PlainMaskVideoAnnotator, PlainVideoAnnotator, PlainMaskAugVideoAnnotator, PlainMaskAugInvertAnnotator, PlainMaskAugInvertVideoAnnotator, ExpandMaskVideoAnnotator
21
+ from .prompt_extend import PromptExtendAnnotator
22
+ from .composition import CompositionAnnotator, ReferenceAnythingAnnotator, AnimateAnythingAnnotator, SwapAnythingAnnotator, ExpandAnythingAnnotator, MoveAnythingAnnotator
23
+ from .mask import MaskDrawAnnotator
24
+ from .canvas import RegionCanvasAnnotator
vace/annotators/canvas.py ADDED
@@ -0,0 +1,60 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import random
4
+
5
+ import cv2
6
+ import numpy as np
7
+
8
+ from .utils import convert_to_numpy
9
+
10
+
11
+ class RegionCanvasAnnotator:
12
+ def __init__(self, cfg, device=None):
13
+ self.scale_range = cfg.get('SCALE_RANGE', [0.75, 1.0])
14
+ self.canvas_value = cfg.get('CANVAS_VALUE', 255)
15
+ self.use_resize = cfg.get('USE_RESIZE', True)
16
+ self.use_canvas = cfg.get('USE_CANVAS', True)
17
+ self.use_aug = cfg.get('USE_AUG', False)
18
+ if self.use_aug:
19
+ from .maskaug import MaskAugAnnotator
20
+ self.maskaug_anno = MaskAugAnnotator(cfg={})
21
+
22
+ def forward(self, image, mask, mask_cfg=None):
23
+
24
+ image = convert_to_numpy(image)
25
+ mask = convert_to_numpy(mask)
26
+ image_h, image_w = image.shape[:2]
27
+
28
+ if self.use_aug:
29
+ mask = self.maskaug_anno.forward(mask, mask_cfg)
30
+
31
+ # get region with white bg
32
+ image[np.array(mask) == 0] = self.canvas_value
33
+ x, y, w, h = cv2.boundingRect(mask)
34
+ region_crop = image[y:y + h, x:x + w]
35
+
36
+ if self.use_resize:
37
+ # resize region
38
+ scale_min, scale_max = self.scale_range
39
+ scale_factor = random.uniform(scale_min, scale_max)
40
+ new_w, new_h = int(image_w * scale_factor), int(image_h * scale_factor)
41
+ obj_scale_factor = min(new_w/w, new_h/h)
42
+
43
+ new_w = int(w * obj_scale_factor)
44
+ new_h = int(h * obj_scale_factor)
45
+ region_crop_resized = cv2.resize(region_crop, (new_w, new_h), interpolation=cv2.INTER_AREA)
46
+ else:
47
+ region_crop_resized = region_crop
48
+
49
+ if self.use_canvas:
50
+ # plot region into canvas
51
+ new_canvas = np.ones_like(image) * self.canvas_value
52
+ max_x = max(0, image_w - new_w)
53
+ max_y = max(0, image_h - new_h)
54
+ new_x = random.randint(0, max_x)
55
+ new_y = random.randint(0, max_y)
56
+
57
+ new_canvas[new_y:new_y + new_h, new_x:new_x + new_w] = region_crop_resized
58
+ else:
59
+ new_canvas = region_crop_resized
60
+ return new_canvas
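+
+ # Usage sketch (illustrative only; `image` is an HxWx3 uint8 array, `mask` an HxW uint8 array with 0/255 values):
+ # anno = RegionCanvasAnnotator({'SCALE_RANGE': [0.75, 1.0], 'CANVAS_VALUE': 255, 'USE_RESIZE': True, 'USE_CANVAS': True})
+ # canvas = anno.forward(image, mask)  # masked region re-pasted at a random position on a plain canvas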
vace/annotators/common.py ADDED
@@ -0,0 +1,62 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+
4
+ class PlainImageAnnotator:
5
+ def __init__(self, cfg):
6
+ pass
7
+ def forward(self, image):
8
+ return image
9
+
10
+ class PlainVideoAnnotator:
11
+ def __init__(self, cfg):
12
+ pass
13
+ def forward(self, frames):
14
+ return frames
15
+
16
+ class PlainMaskAnnotator:
17
+ def __init__(self, cfg):
18
+ pass
19
+ def forward(self, mask):
20
+ return mask
21
+
22
+ class PlainMaskAugInvertAnnotator:
23
+ def __init__(self, cfg):
24
+ pass
25
+ def forward(self, mask):
26
+ return 255 - mask
27
+
28
+ class PlainMaskAugAnnotator:
29
+ def __init__(self, cfg):
30
+ pass
31
+ def forward(self, mask):
32
+ return mask
33
+
34
+ class PlainMaskVideoAnnotator:
35
+ def __init__(self, cfg):
36
+ pass
37
+ def forward(self, mask):
38
+ return mask
39
+
40
+ class PlainMaskAugVideoAnnotator:
41
+ def __init__(self, cfg):
42
+ pass
43
+ def forward(self, masks):
44
+ return masks
45
+
46
+ class PlainMaskAugInvertVideoAnnotator:
47
+ def __init__(self, cfg):
48
+ pass
49
+ def forward(self, masks):
50
+ return [255 - mask for mask in masks]
51
+
52
+ class ExpandMaskVideoAnnotator:
53
+ def __init__(self, cfg):
54
+ pass
55
+ def forward(self, mask, expand_num):
56
+ return [mask] * expand_num
57
+
58
+ class PlainPromptAnnotator:
59
+ def __init__(self, cfg):
60
+ pass
61
+ def forward(self, prompt):
62
+ return prompt
vace/annotators/composition.py ADDED
@@ -0,0 +1,155 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import numpy as np
4
+
5
+ class CompositionAnnotator:
6
+ def __init__(self, cfg):
7
+ self.process_types = ["repaint", "extension", "control"]
8
+ self.process_map = {
9
+ "repaint": "repaint",
10
+ "extension": "extension",
11
+ "control": "control",
12
+ "inpainting": "repaint",
13
+ "outpainting": "repaint",
14
+ "frameref": "extension",
15
+ "clipref": "extension",
16
+ "depth": "control",
17
+ "flow": "control",
18
+ "gray": "control",
19
+ "pose": "control",
20
+ "scribble": "control",
21
+ "layout": "control"
22
+ }
23
+
24
+ def forward(self, process_type_1, process_type_2, frames_1, frames_2, masks_1, masks_2):
25
+ total_frames = min(len(frames_1), len(frames_2), len(masks_1), len(masks_2))
26
+ combine_type = (self.process_map[process_type_1], self.process_map[process_type_2])
27
+ if combine_type in [("extension", "repaint"), ("extension", "control"), ("extension", "extension")]:
28
+ output_video = [frames_2[i] * masks_1[i] + frames_1[i] * (1 - masks_1[i]) for i in range(total_frames)]
29
+ output_mask = [masks_1[i] * masks_2[i] * 255 for i in range(total_frames)]
30
+ elif combine_type in [("repaint", "extension"), ("control", "extension"), ("repaint", "repaint")]:
31
+ output_video = [frames_1[i] * (1 - masks_2[i]) + frames_2[i] * masks_2[i] for i in range(total_frames)]
32
+ output_mask = [(masks_1[i] * (1 - masks_2[i]) + masks_2[i] * masks_2[i]) * 255 for i in range(total_frames)]
33
+ elif combine_type in [("repaint", "control"), ("control", "repaint")]:
34
+ if combine_type in [("control", "repaint")]:
35
+ frames_1, frames_2, masks_1, masks_2 = frames_2, frames_1, masks_2, masks_1
36
+ output_video = [frames_1[i] * (1 - masks_1[i]) + frames_2[i] * masks_1[i] for i in range(total_frames)]
37
+ output_mask = [masks_1[i] * 255 for i in range(total_frames)]
38
+ elif combine_type in [("control", "control")]: # apply masks_2
39
+ output_video = [frames_1[i] * (1 - masks_2[i]) + frames_2[i] * masks_2[i] for i in range(total_frames)]
40
+ output_mask = [(masks_1[i] * (1 - masks_2[i]) + masks_2[i] * masks_2[i]) * 255 for i in range(total_frames)]
41
+ else:
42
+ raise Exception("Unknown combine type")
43
+ return output_video, output_mask
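+
+ # Usage sketch (illustrative; frames_*/masks_* are equal-length lists of float arrays, with mask values in {0, 1}):
+ # anno = CompositionAnnotator(cfg={})
+ # out_video, out_mask = anno.forward('depth', 'inpainting', frames_1, frames_2, masks_1, masks_2)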
44
+
45
+
46
+ class ReferenceAnythingAnnotator:
47
+ def __init__(self, cfg):
48
+ from .subject import SubjectAnnotator
49
+ self.sbjref_ins = SubjectAnnotator(cfg['SUBJECT'] if 'SUBJECT' in cfg else cfg)
50
+ self.key_map = {
51
+ "image": "images",
52
+ "mask": "masks"
53
+ }
54
+ def forward(self, images, mode=None, return_mask=None, mask_cfg=None):
55
+ ret_data = {}
56
+ for image in images:
57
+ ret_one_data = self.sbjref_ins.forward(image=image, mode=mode, return_mask=return_mask, mask_cfg=mask_cfg)
58
+ if isinstance(ret_one_data, dict):
59
+ for key, val in ret_one_data.items():
60
+ if key in self.key_map:
61
+ new_key = self.key_map[key]
62
+ else:
63
+ continue
64
+ if new_key in ret_data:
65
+ ret_data[new_key].append(val)
66
+ else:
67
+ ret_data[new_key] = [val]
68
+ else:
69
+ if 'images' in ret_data:
70
+ ret_data['images'].append(ret_one_data)
71
+ else:
72
+ ret_data['images'] = [ret_one_data]
73
+ return ret_data
74
+
75
+
76
+ class AnimateAnythingAnnotator:
77
+ def __init__(self, cfg):
78
+ from .pose import PoseBodyFaceVideoAnnotator
79
+ self.pose_ins = PoseBodyFaceVideoAnnotator(cfg['POSE'])
80
+ self.ref_ins = ReferenceAnythingAnnotator(cfg['REFERENCE'])
81
+
82
+ def forward(self, frames=None, images=None, mode=None, return_mask=None, mask_cfg=None):
83
+ ret_data = {}
84
+ ret_pose_data = self.pose_ins.forward(frames=frames)
85
+ ret_data.update({"frames": ret_pose_data})
86
+
87
+ ret_ref_data = self.ref_ins.forward(images=images, mode=mode, return_mask=return_mask, mask_cfg=mask_cfg)
88
+ ret_data.update({"images": ret_ref_data['images']})
89
+
90
+ return ret_data
91
+
92
+
93
+ class SwapAnythingAnnotator:
94
+ def __init__(self, cfg):
95
+ from .inpainting import InpaintingVideoAnnotator
96
+ self.inp_ins = InpaintingVideoAnnotator(cfg['INPAINTING'])
97
+ self.ref_ins = ReferenceAnythingAnnotator(cfg['REFERENCE'])
98
+
99
+ def forward(self, video=None, frames=None, images=None, mode=None, mask=None, bbox=None, label=None, caption=None, return_mask=None, mask_cfg=None):
100
+ ret_data = {}
101
+ mode = mode.split(',') if ',' in mode else [mode, mode]
102
+
103
+ ret_inp_data = self.inp_ins.forward(video=video, frames=frames, mode=mode[0], mask=mask, bbox=bbox, label=label, caption=caption, mask_cfg=mask_cfg)
104
+ ret_data.update(ret_inp_data)
105
+
106
+ ret_ref_data = self.ref_ins.forward(images=images, mode=mode[1], return_mask=return_mask, mask_cfg=mask_cfg)
107
+ ret_data.update({"images": ret_ref_data['images']})
108
+
109
+ return ret_data
110
+
111
+
112
+ class ExpandAnythingAnnotator:
113
+ def __init__(self, cfg):
114
+ from .outpainting import OutpaintingAnnotator
115
+ from .frameref import FrameRefExpandAnnotator
116
+ self.ref_ins = ReferenceAnythingAnnotator(cfg['REFERENCE'])
117
+ self.frameref_ins = FrameRefExpandAnnotator(cfg['FRAMEREF'])
118
+ self.outpainting_ins = OutpaintingAnnotator(cfg['OUTPAINTING'])
119
+
120
+ def forward(self, images=None, mode=None, return_mask=None, mask_cfg=None, direction=None, expand_ratio=None, expand_num=None):
121
+ ret_data = {}
122
+ expand_image, reference_image = images[0], images[1:]
123
+ mode = mode.split(',') if ',' in mode else ['firstframe', mode]
124
+
125
+ outpainting_data = self.outpainting_ins.forward(expand_image, expand_ratio=expand_ratio, direction=direction)
126
+ outpainting_image, outpainting_mask = outpainting_data['image'], outpainting_data['mask']
127
+
128
+ frameref_data = self.frameref_ins.forward(outpainting_image, mode=mode[0], expand_num=expand_num)
129
+ frames, masks = frameref_data['frames'], frameref_data['masks']
130
+ masks[0] = outpainting_mask
131
+ ret_data.update({"frames": frames, "masks": masks})
132
+
133
+ ret_ref_data = self.ref_ins.forward(images=reference_image, mode=mode[1], return_mask=return_mask, mask_cfg=mask_cfg)
134
+ ret_data.update({"images": ret_ref_data['images']})
135
+
136
+ return ret_data
137
+
138
+
139
+ class MoveAnythingAnnotator:
140
+ def __init__(self, cfg):
141
+ from .layout import LayoutBboxAnnotator
142
+ self.layout_bbox_ins = LayoutBboxAnnotator(cfg['LAYOUTBBOX'])
143
+
144
+ def forward(self, image=None, bbox=None, label=None, expand_num=None):
145
+ frame_size = image.shape[:2] # [H, W]
146
+ ret_layout_data = self.layout_bbox_ins.forward(bbox, frame_size=frame_size, num_frames=expand_num, label=label)
147
+
148
+ out_frames = [image] + ret_layout_data
149
+ out_mask = [np.zeros(frame_size, dtype=np.uint8)] + [np.ones(frame_size, dtype=np.uint8) * 255] * len(ret_layout_data)
150
+
151
+ ret_data = {
152
+ "frames": out_frames,
153
+ "masks": out_mask
154
+ }
155
+ return ret_data
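+
+ # Usage sketch (illustrative; the path and boxes are assumptions, bbox gives the start and end positions of the object):
+ # anno = MoveAnythingAnnotator({'LAYOUTBBOX': {'RAM_TAG_COLOR_PATH': 'models/VACE-Annotators/layout/ram_tag_color_list.txt'}})
+ # out = anno.forward(image=first_frame, bbox=[[120, 80, 360, 400], [600, 80, 840, 400]], label='person', expand_num=80)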
vace/annotators/depth.py ADDED
@@ -0,0 +1,51 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+
4
+ import numpy as np
5
+ import torch
6
+ from einops import rearrange
7
+
8
+ from .utils import convert_to_numpy, resize_image, resize_image_ori
9
+
10
+
11
+ class DepthAnnotator:
12
+ def __init__(self, cfg, device=None):
13
+ from .midas.api import MiDaSInference
14
+ pretrained_model = cfg['PRETRAINED_MODEL']
15
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
16
+ self.model = MiDaSInference(model_type='dpt_hybrid', model_path=pretrained_model).to(self.device)
17
+ self.a = cfg.get('A', np.pi * 2.0)
18
+ self.bg_th = cfg.get('BG_TH', 0.1)
19
+
20
+ @torch.no_grad()
21
+ @torch.inference_mode()
22
+ @torch.autocast('cuda', enabled=False)
23
+ def forward(self, image):
24
+ image = convert_to_numpy(image)
25
+ image_depth = image
26
+ h, w, c = image.shape
27
+ image_depth, k = resize_image(image_depth,
28
+ 1024 if min(h, w) > 1024 else min(h, w))
29
+ image_depth = torch.from_numpy(image_depth).float().to(self.device)
30
+ image_depth = image_depth / 127.5 - 1.0
31
+ image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
32
+ depth = self.model(image_depth)[0]
33
+
34
+ depth_pt = depth.clone()
35
+ depth_pt -= torch.min(depth_pt)
36
+ depth_pt /= torch.max(depth_pt)
37
+ depth_pt = depth_pt.cpu().numpy()
38
+ depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8)
39
+ depth_image = depth_image[..., None].repeat(3, 2)
40
+
41
+ depth_image = resize_image_ori(h, w, depth_image, k)
42
+ return depth_image
43
+
44
+
45
+ class DepthVideoAnnotator(DepthAnnotator):
46
+ def forward(self, frames):
47
+ ret_frames = []
48
+ for frame in frames:
49
+ anno_frame = super().forward(np.array(frame))
50
+ ret_frames.append(anno_frame)
51
+ return ret_frames
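+
+ # Usage sketch (the checkpoint path is an assumption; PRETRAINED_MODEL should point at a MiDaS dpt_hybrid checkpoint):
+ # anno = DepthAnnotator({'PRETRAINED_MODEL': 'models/VACE-Annotators/depth/dpt_hybrid-midas-501f0c75.pt'})
+ # depth_vis = anno.forward(image)  # HxWx3 uint8 depth visualization at the input resolution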
vace/annotators/dwpose/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
vace/annotators/dwpose/onnxdet.py ADDED
@@ -0,0 +1,127 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import cv2
4
+ import numpy as np
5
+
6
+ import onnxruntime
7
+
8
+ def nms(boxes, scores, nms_thr):
9
+ """Single class NMS implemented in Numpy."""
10
+ x1 = boxes[:, 0]
11
+ y1 = boxes[:, 1]
12
+ x2 = boxes[:, 2]
13
+ y2 = boxes[:, 3]
14
+
15
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
16
+ order = scores.argsort()[::-1]
17
+
18
+ keep = []
19
+ while order.size > 0:
20
+ i = order[0]
21
+ keep.append(i)
22
+ xx1 = np.maximum(x1[i], x1[order[1:]])
23
+ yy1 = np.maximum(y1[i], y1[order[1:]])
24
+ xx2 = np.minimum(x2[i], x2[order[1:]])
25
+ yy2 = np.minimum(y2[i], y2[order[1:]])
26
+
27
+ w = np.maximum(0.0, xx2 - xx1 + 1)
28
+ h = np.maximum(0.0, yy2 - yy1 + 1)
29
+ inter = w * h
30
+ ovr = inter / (areas[i] + areas[order[1:]] - inter)
31
+
32
+ inds = np.where(ovr <= nms_thr)[0]
33
+ order = order[inds + 1]
34
+
35
+ return keep
36
+
37
+ def multiclass_nms(boxes, scores, nms_thr, score_thr):
38
+ """Multiclass NMS implemented in Numpy. Class-aware version."""
39
+ final_dets = []
40
+ num_classes = scores.shape[1]
41
+ for cls_ind in range(num_classes):
42
+ cls_scores = scores[:, cls_ind]
43
+ valid_score_mask = cls_scores > score_thr
44
+ if valid_score_mask.sum() == 0:
45
+ continue
46
+ else:
47
+ valid_scores = cls_scores[valid_score_mask]
48
+ valid_boxes = boxes[valid_score_mask]
49
+ keep = nms(valid_boxes, valid_scores, nms_thr)
50
+ if len(keep) > 0:
51
+ cls_inds = np.ones((len(keep), 1)) * cls_ind
52
+ dets = np.concatenate(
53
+ [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
54
+ )
55
+ final_dets.append(dets)
56
+ if len(final_dets) == 0:
57
+ return None
58
+ return np.concatenate(final_dets, 0)
59
+
60
+ def demo_postprocess(outputs, img_size, p6=False):
61
+ grids = []
62
+ expanded_strides = []
63
+ strides = [8, 16, 32] if not p6 else [8, 16, 32, 64]
64
+
65
+ hsizes = [img_size[0] // stride for stride in strides]
66
+ wsizes = [img_size[1] // stride for stride in strides]
67
+
68
+ for hsize, wsize, stride in zip(hsizes, wsizes, strides):
69
+ xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
70
+ grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
71
+ grids.append(grid)
72
+ shape = grid.shape[:2]
73
+ expanded_strides.append(np.full((*shape, 1), stride))
74
+
75
+ grids = np.concatenate(grids, 1)
76
+ expanded_strides = np.concatenate(expanded_strides, 1)
77
+ outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
78
+ outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
79
+
80
+ return outputs
81
+
82
+ def preprocess(img, input_size, swap=(2, 0, 1)):
83
+ if len(img.shape) == 3:
84
+ padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
85
+ else:
86
+ padded_img = np.ones(input_size, dtype=np.uint8) * 114
87
+
88
+ r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
89
+ resized_img = cv2.resize(
90
+ img,
91
+ (int(img.shape[1] * r), int(img.shape[0] * r)),
92
+ interpolation=cv2.INTER_LINEAR,
93
+ ).astype(np.uint8)
94
+ padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
95
+
96
+ padded_img = padded_img.transpose(swap)
97
+ padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
98
+ return padded_img, r
99
+
100
+ def inference_detector(session, oriImg):
101
+ input_shape = (640,640)
102
+ img, ratio = preprocess(oriImg, input_shape)
103
+
104
+ ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
105
+ output = session.run(None, ort_inputs)
106
+ predictions = demo_postprocess(output[0], input_shape)[0]
107
+
108
+ boxes = predictions[:, :4]
109
+ scores = predictions[:, 4:5] * predictions[:, 5:]
110
+
111
+ boxes_xyxy = np.ones_like(boxes)
112
+ boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2.
113
+ boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2.
114
+ boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2.
115
+ boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2.
116
+ boxes_xyxy /= ratio
117
+ dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
118
+ if dets is not None:
119
+ final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
120
+ isscore = final_scores>0.3
121
+ iscat = final_cls_inds == 0
122
+ isbbox = [ i and j for (i, j) in zip(isscore, iscat)]
123
+ final_boxes = final_boxes[isbbox]
124
+ else:
125
+ final_boxes = np.array([])
126
+
127
+ return final_boxes
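+
+ # Usage sketch (the ONNX path is an assumption; any YOLOX detection export with this output layout should work):
+ # session = onnxruntime.InferenceSession('models/VACE-Annotators/pose/yolox_l.onnx', providers=['CPUExecutionProvider'])
+ # person_boxes = inference_detector(session, bgr_image)  # (N, 4) xyxy boxes, class 0 (person) only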
vace/annotators/dwpose/onnxpose.py ADDED
@@ -0,0 +1,362 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ from typing import List, Tuple
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import onnxruntime as ort
8
+
9
+ def preprocess(
10
+ img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256)
11
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
12
+ """Do preprocessing for RTMPose model inference.
13
+
14
+ Args:
15
+ img (np.ndarray): Input image in shape.
16
+ input_size (tuple): Input image size in shape (w, h).
17
+
18
+ Returns:
19
+ tuple:
20
+ - resized_img (np.ndarray): Preprocessed image.
21
+ - center (np.ndarray): Center of image.
22
+ - scale (np.ndarray): Scale of image.
23
+ """
24
+ # get shape of image
25
+ img_shape = img.shape[:2]
26
+ out_img, out_center, out_scale = [], [], []
27
+ if len(out_bbox) == 0:
28
+ out_bbox = [[0, 0, img_shape[1], img_shape[0]]]
29
+ for i in range(len(out_bbox)):
30
+ x0 = out_bbox[i][0]
31
+ y0 = out_bbox[i][1]
32
+ x1 = out_bbox[i][2]
33
+ y1 = out_bbox[i][3]
34
+ bbox = np.array([x0, y0, x1, y1])
35
+
36
+ # get center and scale
37
+ center, scale = bbox_xyxy2cs(bbox, padding=1.25)
38
+
39
+ # do affine transformation
40
+ resized_img, scale = top_down_affine(input_size, scale, center, img)
41
+
42
+ # normalize image
43
+ mean = np.array([123.675, 116.28, 103.53])
44
+ std = np.array([58.395, 57.12, 57.375])
45
+ resized_img = (resized_img - mean) / std
46
+
47
+ out_img.append(resized_img)
48
+ out_center.append(center)
49
+ out_scale.append(scale)
50
+
51
+ return out_img, out_center, out_scale
52
+
53
+
54
+ def inference(sess: ort.InferenceSession, img: np.ndarray) -> np.ndarray:
55
+ """Inference RTMPose model.
56
+
57
+ Args:
58
+ sess (ort.InferenceSession): ONNXRuntime session.
59
+ img (np.ndarray): Input image in shape.
60
+
61
+ Returns:
62
+ outputs (np.ndarray): Output of RTMPose model.
63
+ """
64
+ all_out = []
65
+ # build input
66
+ for i in range(len(img)):
67
+ input = [img[i].transpose(2, 0, 1)]
68
+
69
+ # build output
70
+ sess_input = {sess.get_inputs()[0].name: input}
71
+ sess_output = []
72
+ for out in sess.get_outputs():
73
+ sess_output.append(out.name)
74
+
75
+ # run model
76
+ outputs = sess.run(sess_output, sess_input)
77
+ all_out.append(outputs)
78
+
79
+ return all_out
80
+
81
+
82
+ def postprocess(outputs: List[np.ndarray],
83
+ model_input_size: Tuple[int, int],
84
+ center: Tuple[int, int],
85
+ scale: Tuple[int, int],
86
+ simcc_split_ratio: float = 2.0
87
+ ) -> Tuple[np.ndarray, np.ndarray]:
88
+ """Postprocess for RTMPose model output.
89
+
90
+ Args:
91
+ outputs (np.ndarray): Output of RTMPose model.
92
+ model_input_size (tuple): RTMPose model Input image size.
93
+ center (tuple): Center of bbox in shape (x, y).
94
+ scale (tuple): Scale of bbox in shape (w, h).
95
+ simcc_split_ratio (float): Split ratio of simcc.
96
+
97
+ Returns:
98
+ tuple:
99
+ - keypoints (np.ndarray): Rescaled keypoints.
100
+ - scores (np.ndarray): Model predict scores.
101
+ """
102
+ all_key = []
103
+ all_score = []
104
+ for i in range(len(outputs)):
105
+ # use simcc to decode
106
+ simcc_x, simcc_y = outputs[i]
107
+ keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio)
108
+
109
+ # rescale keypoints
110
+ keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2
111
+ all_key.append(keypoints[0])
112
+ all_score.append(scores[0])
113
+
114
+ return np.array(all_key), np.array(all_score)
115
+
116
+
117
+ def bbox_xyxy2cs(bbox: np.ndarray,
118
+ padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]:
119
+ """Transform the bbox format from (x,y,w,h) into (center, scale)
120
+
121
+ Args:
122
+ bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted
123
+ as (left, top, right, bottom)
124
+ padding (float): BBox padding factor that will be multiplied to scale.
125
+ Default: 1.0
126
+
127
+ Returns:
128
+ tuple: A tuple containing center and scale.
129
+ - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or
130
+ (n, 2)
131
+ - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or
132
+ (n, 2)
133
+ """
134
+ # convert single bbox from (4, ) to (1, 4)
135
+ dim = bbox.ndim
136
+ if dim == 1:
137
+ bbox = bbox[None, :]
138
+
139
+ # get bbox center and scale
140
+ x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3])
141
+ center = np.hstack([x1 + x2, y1 + y2]) * 0.5
142
+ scale = np.hstack([x2 - x1, y2 - y1]) * padding
143
+
144
+ if dim == 1:
145
+ center = center[0]
146
+ scale = scale[0]
147
+
148
+ return center, scale
149
+
150
+
151
+ def _fix_aspect_ratio(bbox_scale: np.ndarray,
152
+ aspect_ratio: float) -> np.ndarray:
153
+ """Extend the scale to match the given aspect ratio.
154
+
155
+ Args:
156
+ scale (np.ndarray): The image scale (w, h) in shape (2, )
157
+ aspect_ratio (float): The ratio of ``w/h``
158
+
159
+ Returns:
160
+ np.ndarray: The reshaped image scale in (2, )
161
+ """
162
+ w, h = np.hsplit(bbox_scale, [1])
163
+ bbox_scale = np.where(w > h * aspect_ratio,
164
+ np.hstack([w, w / aspect_ratio]),
165
+ np.hstack([h * aspect_ratio, h]))
166
+ return bbox_scale
167
+
168
+
169
+ def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray:
170
+ """Rotate a point by an angle.
171
+
172
+ Args:
173
+ pt (np.ndarray): 2D point coordinates (x, y) in shape (2, )
174
+ angle_rad (float): rotation angle in radian
175
+
176
+ Returns:
177
+ np.ndarray: Rotated point in shape (2, )
178
+ """
179
+ sn, cs = np.sin(angle_rad), np.cos(angle_rad)
180
+ rot_mat = np.array([[cs, -sn], [sn, cs]])
181
+ return rot_mat @ pt
182
+
183
+
184
+ def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray:
185
+ """To calculate the affine matrix, three pairs of points are required. This
186
+ function is used to get the 3rd point, given 2D points a & b.
187
+
188
+ The 3rd point is defined by rotating vector `a - b` by 90 degrees
189
+ anticlockwise, using b as the rotation center.
190
+
191
+ Args:
192
+ a (np.ndarray): The 1st point (x,y) in shape (2, )
193
+ b (np.ndarray): The 2nd point (x,y) in shape (2, )
194
+
195
+ Returns:
196
+ np.ndarray: The 3rd point.
197
+ """
198
+ direction = a - b
199
+ c = b + np.r_[-direction[1], direction[0]]
200
+ return c
201
+
202
+
203
+ def get_warp_matrix(center: np.ndarray,
204
+ scale: np.ndarray,
205
+ rot: float,
206
+ output_size: Tuple[int, int],
207
+ shift: Tuple[float, float] = (0., 0.),
208
+ inv: bool = False) -> np.ndarray:
209
+ """Calculate the affine transformation matrix that can warp the bbox area
210
+ in the input image to the output size.
211
+
212
+ Args:
213
+ center (np.ndarray[2, ]): Center of the bounding box (x, y).
214
+ scale (np.ndarray[2, ]): Scale of the bounding box
215
+ wrt [width, height].
216
+ rot (float): Rotation angle (degree).
217
+ output_size (np.ndarray[2, ] | list(2,)): Size of the
218
+ destination heatmaps.
219
+ shift (0-100%): Shift translation ratio wrt the width/height.
220
+ Default (0., 0.).
221
+ inv (bool): Option to inverse the affine transform direction.
222
+ (inv=False: src->dst or inv=True: dst->src)
223
+
224
+ Returns:
225
+ np.ndarray: A 2x3 transformation matrix
226
+ """
227
+ shift = np.array(shift)
228
+ src_w = scale[0]
229
+ dst_w = output_size[0]
230
+ dst_h = output_size[1]
231
+
232
+ # compute transformation matrix
233
+ rot_rad = np.deg2rad(rot)
234
+ src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad)
235
+ dst_dir = np.array([0., dst_w * -0.5])
236
+
237
+ # get four corners of the src rectangle in the original image
238
+ src = np.zeros((3, 2), dtype=np.float32)
239
+ src[0, :] = center + scale * shift
240
+ src[1, :] = center + src_dir + scale * shift
241
+ src[2, :] = _get_3rd_point(src[0, :], src[1, :])
242
+
243
+ # get four corners of the dst rectangle in the input image
244
+ dst = np.zeros((3, 2), dtype=np.float32)
245
+ dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
246
+ dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
247
+ dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
248
+
249
+ if inv:
250
+ warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src))
251
+ else:
252
+ warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst))
253
+
254
+ return warp_mat
255
+
256
+
257
+ def top_down_affine(input_size: dict, bbox_scale: dict, bbox_center: dict,
258
+ img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
259
+ """Get the bbox image as the model input by affine transform.
260
+
261
+ Args:
262
+ input_size (dict): The input size of the model.
263
+ bbox_scale (dict): The bbox scale of the img.
264
+ bbox_center (dict): The bbox center of the img.
265
+ img (np.ndarray): The original image.
266
+
267
+ Returns:
268
+ tuple: A tuple containing the transformed image and the bbox scale.
269
+ - np.ndarray[float32]: img after affine transform.
270
+ - np.ndarray[float32]: bbox scale after affine transform.
271
+ """
272
+ w, h = input_size
273
+ warp_size = (int(w), int(h))
274
+
275
+ # reshape bbox to fixed aspect ratio
276
+ bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h)
277
+
278
+ # get the affine matrix
279
+ center = bbox_center
280
+ scale = bbox_scale
281
+ rot = 0
282
+ warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h))
283
+
284
+ # do affine transform
285
+ img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR)
286
+
287
+ return img, bbox_scale
288
+
289
+
290
+ def get_simcc_maximum(simcc_x: np.ndarray,
291
+ simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
292
+ """Get maximum response location and value from simcc representations.
293
+
294
+ Note:
295
+ instance number: N
296
+ num_keypoints: K
297
+ heatmap height: H
298
+ heatmap width: W
299
+
300
+ Args:
301
+ simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx)
302
+ simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy)
303
+
304
+ Returns:
305
+ tuple:
306
+ - locs (np.ndarray): locations of maximum heatmap responses in shape
307
+ (K, 2) or (N, K, 2)
308
+ - vals (np.ndarray): values of maximum heatmap responses in shape
309
+ (K,) or (N, K)
310
+ """
311
+ N, K, Wx = simcc_x.shape
312
+ simcc_x = simcc_x.reshape(N * K, -1)
313
+ simcc_y = simcc_y.reshape(N * K, -1)
314
+
315
+ # get maximum value locations
316
+ x_locs = np.argmax(simcc_x, axis=1)
317
+ y_locs = np.argmax(simcc_y, axis=1)
318
+ locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
319
+ max_val_x = np.amax(simcc_x, axis=1)
320
+ max_val_y = np.amax(simcc_y, axis=1)
321
+
322
+ # get maximum value across x and y axis
323
+ mask = max_val_x > max_val_y
324
+ max_val_x[mask] = max_val_y[mask]
325
+ vals = max_val_x
326
+ locs[vals <= 0.] = -1
327
+
328
+ # reshape
329
+ locs = locs.reshape(N, K, 2)
330
+ vals = vals.reshape(N, K)
331
+
332
+ return locs, vals
333
+
334
+
335
+ def decode(simcc_x: np.ndarray, simcc_y: np.ndarray,
336
+ simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]:
337
+ """Modulate simcc distribution with Gaussian.
338
+
339
+ Args:
340
+ simcc_x (np.ndarray[K, Wx]): model predicted simcc in x.
341
+ simcc_y (np.ndarray[K, Wy]): model predicted simcc in y.
342
+ simcc_split_ratio (int): The split ratio of simcc.
343
+
344
+ Returns:
345
+ tuple: A tuple containing keypoints and scores.
346
+ - np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2)
347
+ - np.ndarray[float32]: scores in shape (K,) or (n, K)
348
+ """
349
+ keypoints, scores = get_simcc_maximum(simcc_x, simcc_y)
350
+ keypoints /= simcc_split_ratio
351
+
352
+ return keypoints, scores
353
+
354
+
355
+ def inference_pose(session, out_bbox, oriImg):
356
+ h, w = session.get_inputs()[0].shape[2:]
357
+ model_input_size = (w, h)
358
+ resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size)
359
+ outputs = inference(session, resized_img)
360
+ keypoints, scores = postprocess(outputs, model_input_size, center, scale)
361
+
362
+ return keypoints, scores
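+
+ # Usage sketch (the path is an assumption; expects an RTMPose/DWPose ONNX export such as dw-ll_ucoco_384.onnx):
+ # pose_sess = ort.InferenceSession('models/VACE-Annotators/pose/dw-ll_ucoco_384.onnx', providers=['CPUExecutionProvider'])
+ # keypoints, scores = inference_pose(pose_sess, person_boxes, bgr_image)  # one whole-body keypoint set per box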
vace/annotators/dwpose/util.py ADDED
@@ -0,0 +1,299 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import math
4
+ import numpy as np
5
+ import matplotlib
6
+ import cv2
7
+
8
+
9
+ eps = 0.01
10
+
11
+
12
+ def smart_resize(x, s):
13
+ Ht, Wt = s
14
+ if x.ndim == 2:
15
+ Ho, Wo = x.shape
16
+ Co = 1
17
+ else:
18
+ Ho, Wo, Co = x.shape
19
+ if Co == 3 or Co == 1:
20
+ k = float(Ht + Wt) / float(Ho + Wo)
21
+ return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
22
+ else:
23
+ return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2)
24
+
25
+
26
+ def smart_resize_k(x, fx, fy):
27
+ if x.ndim == 2:
28
+ Ho, Wo = x.shape
29
+ Co = 1
30
+ else:
31
+ Ho, Wo, Co = x.shape
32
+ Ht, Wt = Ho * fy, Wo * fx
33
+ if Co == 3 or Co == 1:
34
+ k = float(Ht + Wt) / float(Ho + Wo)
35
+ return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
36
+ else:
37
+ return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2)
38
+
39
+
40
+ def padRightDownCorner(img, stride, padValue):
41
+ h = img.shape[0]
42
+ w = img.shape[1]
43
+
44
+ pad = 4 * [None]
45
+ pad[0] = 0 # up
46
+ pad[1] = 0 # left
47
+ pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down
48
+ pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right
49
+
50
+ img_padded = img
51
+ pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1))
52
+ img_padded = np.concatenate((pad_up, img_padded), axis=0)
53
+ pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1))
54
+ img_padded = np.concatenate((pad_left, img_padded), axis=1)
55
+ pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1))
56
+ img_padded = np.concatenate((img_padded, pad_down), axis=0)
57
+ pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1))
58
+ img_padded = np.concatenate((img_padded, pad_right), axis=1)
59
+
60
+ return img_padded, pad
61
+
62
+
63
+ def transfer(model, model_weights):
64
+ transfered_model_weights = {}
65
+ for weights_name in model.state_dict().keys():
66
+ transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])]
67
+ return transfered_model_weights
68
+
69
+
70
+ def draw_bodypose(canvas, candidate, subset):
71
+ H, W, C = canvas.shape
72
+ candidate = np.array(candidate)
73
+ subset = np.array(subset)
74
+
75
+ stickwidth = 4
76
+
77
+ limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
78
+ [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
79
+ [1, 16], [16, 18], [3, 17], [6, 18]]
80
+
81
+ colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
82
+ [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
83
+ [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
84
+
85
+ for i in range(17):
86
+ for n in range(len(subset)):
87
+ index = subset[n][np.array(limbSeq[i]) - 1]
88
+ if -1 in index:
89
+ continue
90
+ Y = candidate[index.astype(int), 0] * float(W)
91
+ X = candidate[index.astype(int), 1] * float(H)
92
+ mX = np.mean(X)
93
+ mY = np.mean(Y)
94
+ length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
95
+ angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
96
+ polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
97
+ cv2.fillConvexPoly(canvas, polygon, colors[i])
98
+
99
+ canvas = (canvas * 0.6).astype(np.uint8)
100
+
101
+ for i in range(18):
102
+ for n in range(len(subset)):
103
+ index = int(subset[n][i])
104
+ if index == -1:
105
+ continue
106
+ x, y = candidate[index][0:2]
107
+ x = int(x * W)
108
+ y = int(y * H)
109
+ cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
110
+
111
+ return canvas
112
+
113
+
114
+ def draw_handpose(canvas, all_hand_peaks):
115
+ H, W, C = canvas.shape
116
+
117
+ edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
118
+ [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
119
+
120
+ for peaks in all_hand_peaks:
121
+ peaks = np.array(peaks)
122
+
123
+ for ie, e in enumerate(edges):
124
+ x1, y1 = peaks[e[0]]
125
+ x2, y2 = peaks[e[1]]
126
+ x1 = int(x1 * W)
127
+ y1 = int(y1 * H)
128
+ x2 = int(x2 * W)
129
+ y2 = int(y2 * H)
130
+ if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
131
+ cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2)
132
+
133
+ for i, keypoint in enumerate(peaks):
134
+ x, y = keypoint
135
+ x = int(x * W)
136
+ y = int(y * H)
137
+ if x > eps and y > eps:
138
+ cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
139
+ return canvas
140
+
141
+
142
+ def draw_facepose(canvas, all_lmks):
143
+ H, W, C = canvas.shape
144
+ for lmks in all_lmks:
145
+ lmks = np.array(lmks)
146
+ for lmk in lmks:
147
+ x, y = lmk
148
+ x = int(x * W)
149
+ y = int(y * H)
150
+ if x > eps and y > eps:
151
+ cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1)
152
+ return canvas
153
+
154
+
155
+ # detect hand according to body pose keypoints
156
+ # please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
157
+ def handDetect(candidate, subset, oriImg):
158
+ # right hand: wrist 4, elbow 3, shoulder 2
159
+ # left hand: wrist 7, elbow 6, shoulder 5
160
+ ratioWristElbow = 0.33
161
+ detect_result = []
162
+ image_height, image_width = oriImg.shape[0:2]
163
+ for person in subset.astype(int):
164
+ # if any of three not detected
165
+ has_left = np.sum(person[[5, 6, 7]] == -1) == 0
166
+ has_right = np.sum(person[[2, 3, 4]] == -1) == 0
167
+ if not (has_left or has_right):
168
+ continue
169
+ hands = []
170
+ #left hand
171
+ if has_left:
172
+ left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]]
173
+ x1, y1 = candidate[left_shoulder_index][:2]
174
+ x2, y2 = candidate[left_elbow_index][:2]
175
+ x3, y3 = candidate[left_wrist_index][:2]
176
+ hands.append([x1, y1, x2, y2, x3, y3, True])
177
+ # right hand
178
+ if has_right:
179
+ right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]]
180
+ x1, y1 = candidate[right_shoulder_index][:2]
181
+ x2, y2 = candidate[right_elbow_index][:2]
182
+ x3, y3 = candidate[right_wrist_index][:2]
183
+ hands.append([x1, y1, x2, y2, x3, y3, False])
184
+
185
+ for x1, y1, x2, y2, x3, y3, is_left in hands:
186
+ # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox
187
+ # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]);
188
+ # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]);
189
+ # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow);
190
+ # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder);
191
+ # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder);
192
+ x = x3 + ratioWristElbow * (x3 - x2)
193
+ y = y3 + ratioWristElbow * (y3 - y2)
194
+ distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
195
+ distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
196
+ width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
197
+ # x-y refers to the center --> offset to topLeft point
198
+ # handRectangle.x -= handRectangle.width / 2.f;
199
+ # handRectangle.y -= handRectangle.height / 2.f;
200
+ x -= width / 2
201
+ y -= width / 2 # width = height
202
+ # overflow the image
203
+ if x < 0: x = 0
204
+ if y < 0: y = 0
205
+ width1 = width
206
+ width2 = width
207
+ if x + width > image_width: width1 = image_width - x
208
+ if y + width > image_height: width2 = image_height - y
209
+ width = min(width1, width2)
210
+ # keep the hand box only if its width is at least 20 pixels
211
+ if width >= 20:
212
+ detect_result.append([int(x), int(y), int(width), is_left])
213
+
214
+ '''
215
+ return value: [[x, y, w, True if left hand else False]].
216
+ width=height since the network requires a square input.
217
+ x, y is the coordinate of top left
218
+ '''
219
+ return detect_result
220
+
221
+
222
+ # Written by Lvmin
223
+ def faceDetect(candidate, subset, oriImg):
224
+ # left right eye ear 14 15 16 17
225
+ detect_result = []
226
+ image_height, image_width = oriImg.shape[0:2]
227
+ for person in subset.astype(int):
228
+ has_head = person[0] > -1
229
+ if not has_head:
230
+ continue
231
+
232
+ has_left_eye = person[14] > -1
233
+ has_right_eye = person[15] > -1
234
+ has_left_ear = person[16] > -1
235
+ has_right_ear = person[17] > -1
236
+
237
+ if not (has_left_eye or has_right_eye or has_left_ear or has_right_ear):
238
+ continue
239
+
240
+ head, left_eye, right_eye, left_ear, right_ear = person[[0, 14, 15, 16, 17]]
241
+
242
+ width = 0.0
243
+ x0, y0 = candidate[head][:2]
244
+
245
+ if has_left_eye:
246
+ x1, y1 = candidate[left_eye][:2]
247
+ d = max(abs(x0 - x1), abs(y0 - y1))
248
+ width = max(width, d * 3.0)
249
+
250
+ if has_right_eye:
251
+ x1, y1 = candidate[right_eye][:2]
252
+ d = max(abs(x0 - x1), abs(y0 - y1))
253
+ width = max(width, d * 3.0)
254
+
255
+ if has_left_ear:
256
+ x1, y1 = candidate[left_ear][:2]
257
+ d = max(abs(x0 - x1), abs(y0 - y1))
258
+ width = max(width, d * 1.5)
259
+
260
+ if has_right_ear:
261
+ x1, y1 = candidate[right_ear][:2]
262
+ d = max(abs(x0 - x1), abs(y0 - y1))
263
+ width = max(width, d * 1.5)
264
+
265
+ x, y = x0, y0
266
+
267
+ x -= width
268
+ y -= width
269
+
270
+ if x < 0:
271
+ x = 0
272
+
273
+ if y < 0:
274
+ y = 0
275
+
276
+ width1 = width * 2
277
+ width2 = width * 2
278
+
279
+ if x + width > image_width:
280
+ width1 = image_width - x
281
+
282
+ if y + width > image_height:
283
+ width2 = image_height - y
284
+
285
+ width = min(width1, width2)
286
+
287
+ if width >= 20:
288
+ detect_result.append([int(x), int(y), int(width)])
289
+
290
+ return detect_result
291
+
292
+
293
+ # get max index of 2d array
294
+ def npmax(array):
295
+ arrayindex = array.argmax(1)
296
+ arrayvalue = array.max(1)
297
+ i = arrayvalue.argmax()
298
+ j = arrayindex[i]
299
+ return i, j
vace/annotators/dwpose/wholebody.py ADDED
@@ -0,0 +1,80 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import cv2
4
+ import numpy as np
5
+ import onnxruntime as ort
6
+ from .onnxdet import inference_detector
7
+ from .onnxpose import inference_pose
8
+
9
+ def HWC3(x):
10
+ assert x.dtype == np.uint8
11
+ if x.ndim == 2:
12
+ x = x[:, :, None]
13
+ assert x.ndim == 3
14
+ H, W, C = x.shape
15
+ assert C == 1 or C == 3 or C == 4
16
+ if C == 3:
17
+ return x
18
+ if C == 1:
19
+ return np.concatenate([x, x, x], axis=2)
20
+ if C == 4:
21
+ color = x[:, :, 0:3].astype(np.float32)
22
+ alpha = x[:, :, 3:4].astype(np.float32) / 255.0
23
+ y = color * alpha + 255.0 * (1.0 - alpha)
24
+ y = y.clip(0, 255).astype(np.uint8)
25
+ return y
26
+
27
+
28
+ def resize_image(input_image, resolution):
29
+ H, W, C = input_image.shape
30
+ H = float(H)
31
+ W = float(W)
32
+ k = float(resolution) / min(H, W)
33
+ H *= k
34
+ W *= k
35
+ H = int(np.round(H / 64.0)) * 64
36
+ W = int(np.round(W / 64.0)) * 64
37
+ img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
38
+ return img
39
+
40
+ class Wholebody:
41
+ def __init__(self, onnx_det, onnx_pose, device = 'cuda:0'):
42
+
43
+ providers = ['CPUExecutionProvider'
44
+ ] if device == 'cpu' else ['CUDAExecutionProvider']
45
+ # onnx_det = 'annotator/ckpts/yolox_l.onnx'
46
+ # onnx_pose = 'annotator/ckpts/dw-ll_ucoco_384.onnx'
47
+
48
+ self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers)
49
+ self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers)
50
+
51
+ def __call__(self, ori_img):
52
+ det_result = inference_detector(self.session_det, ori_img)
53
+ keypoints, scores = inference_pose(self.session_pose, det_result, ori_img)
54
+
55
+ keypoints_info = np.concatenate(
56
+ (keypoints, scores[..., None]), axis=-1)
57
+ # compute neck joint
58
+ neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
59
+ # neck score when visualizing pred
60
+ neck[:, 2:4] = np.logical_and(
61
+ keypoints_info[:, 5, 2:4] > 0.3,
62
+ keypoints_info[:, 6, 2:4] > 0.3).astype(int)
63
+ new_keypoints_info = np.insert(
64
+ keypoints_info, 17, neck, axis=1)
65
+ mmpose_idx = [
66
+ 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3
67
+ ]
68
+ openpose_idx = [
69
+ 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17
70
+ ]
71
+ new_keypoints_info[:, openpose_idx] = \
72
+ new_keypoints_info[:, mmpose_idx]
73
+ keypoints_info = new_keypoints_info
74
+
75
+ keypoints, scores = keypoints_info[
76
+ ..., :2], keypoints_info[..., 2]
77
+
78
+ return keypoints, scores, det_result
79
+
80
+
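For context, `Wholebody` above chains a YOLOX ONNX detector with a DWPose ONNX pose model. A minimal usage sketch, assuming the two checkpoints named in the commented-out defaults have been downloaded to a local `ckpts/` directory (placeholder path):

```python
# Hedged sketch; 'ckpts/...' is a placeholder location, the filenames follow
# the commented-out defaults inside the class.
import cv2
from vace.annotators.dwpose.wholebody import Wholebody

estimator = Wholebody(onnx_det='ckpts/yolox_l.onnx',
                      onnx_pose='ckpts/dw-ll_ucoco_384.onnx',
                      device='cpu')                 # or 'cuda:0'
img = cv2.imread('assets/images/test.jpg')          # HxWx3 uint8 array
keypoints, scores, det_result = estimator(img)      # per-person keypoints and confidences
print(keypoints.shape, scores.shape)
```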
vace/annotators/face.py ADDED
@@ -0,0 +1,55 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+ from .utils import convert_to_numpy
8
+
9
+
10
+ class FaceAnnotator:
11
+ def __init__(self, cfg, device=None):
12
+ from insightface.app import FaceAnalysis
13
+ self.return_raw = cfg.get('RETURN_RAW', True)
14
+ self.return_mask = cfg.get('RETURN_MASK', False)
15
+ self.return_dict = cfg.get('RETURN_DICT', False)
16
+ self.multi_face = cfg.get('MULTI_FACE', True)
17
+ pretrained_model = cfg['PRETRAINED_MODEL']
18
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
19
+ self.device_id = self.device.index if self.device.type == 'cuda' else None
20
+ ctx_id = self.device_id if self.device_id is not None else 0
21
+ self.model = FaceAnalysis(name=cfg.MODEL_NAME, root=pretrained_model, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
22
+ self.model.prepare(ctx_id=ctx_id, det_size=(640, 640))
23
+
24
+ def forward(self, image=None, return_mask=None, return_dict=None):
25
+ return_mask = return_mask if return_mask is not None else self.return_mask
26
+ return_dict = return_dict if return_dict is not None else self.return_dict
27
+ image = convert_to_numpy(image)
28
+ # [dict_keys(['bbox', 'kps', 'det_score', 'landmark_3d_68', 'pose', 'landmark_2d_106', 'gender', 'age', 'embedding'])]
29
+ faces = self.model.get(image)
30
+ if self.return_raw:
31
+ return faces
32
+ else:
33
+ crop_face_list, mask_list = [], []
34
+ if len(faces) > 0:
35
+ if not self.multi_face:
36
+ faces = faces[:1]
37
+ for face in faces:
38
+ x_min, y_min, x_max, y_max = face['bbox'].tolist()
39
+ crop_face = image[int(y_min): int(y_max) + 1, int(x_min): int(x_max) + 1]
40
+ crop_face_list.append(crop_face)
41
+ mask = np.zeros_like(image[:, :, 0])
42
+ mask[int(y_min): int(y_max) + 1, int(x_min): int(x_max) + 1] = 255
43
+ mask_list.append(mask)
44
+ if not self.multi_face:
45
+ crop_face_list = crop_face_list[0]
46
+ mask_list = mask_list[0]
47
+ if return_mask:
48
+ if return_dict:
49
+ return {'image': crop_face_list, 'mask': mask_list}
50
+ else:
51
+ return crop_face_list, mask_list
52
+ else:
53
+ return crop_face_list
54
+ else:
55
+ return None
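A hedged construction sketch for `FaceAnnotator`. Because the config is read with both key access (`cfg['PRETRAINED_MODEL']`) and attribute access (`cfg.MODEL_NAME`), an attribute-style dict such as `easydict.EasyDict` is assumed; the model name and checkpoint root below are placeholders:

```python
import numpy as np
from easydict import EasyDict  # assumption: cfg needs attribute access (cfg.MODEL_NAME)
from vace.annotators.face import FaceAnnotator

cfg = EasyDict({
    'MODEL_NAME': 'antelopev2',                         # placeholder insightface model pack
    'PRETRAINED_MODEL': 'models/VACE-Annotators/face',  # placeholder checkpoint root
    'RETURN_RAW': False,
    'RETURN_MASK': True,
    'RETURN_DICT': True,
    'MULTI_FACE': False,
})
anno = FaceAnnotator(cfg)
image = np.zeros((512, 512, 3), dtype=np.uint8)         # stand-in RGB frame
res = anno.forward(image=image)                         # None if no face is detected
if res is not None:
    print(res['image'].shape, res['mask'].shape)
```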
vace/annotators/flow.py ADDED
@@ -0,0 +1,53 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import torch
4
+ import numpy as np
5
+ import argparse
6
+
7
+ from .utils import convert_to_numpy
8
+
9
+ class FlowAnnotator:
10
+ def __init__(self, cfg, device=None):
11
+ try:
12
+ from raft import RAFT
13
+ from raft.utils.utils import InputPadder
14
+ from raft.utils import flow_viz
15
+ except ImportError:
16
+ import warnings
17
+ warnings.warn(
18
+ "Failed to import RAFT; please install the raft package, e.g. the wheel at models/VACE-Annotators/flow/raft-1.0.0-py3-none-any.whl")
19
+
20
+ params = {
21
+ "small": False,
22
+ "mixed_precision": False,
23
+ "alternate_corr": False
24
+ }
25
+ params = argparse.Namespace(**params)
26
+ pretrained_model = cfg['PRETRAINED_MODEL']
27
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
28
+ self.model = RAFT(params)
29
+ self.model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(pretrained_model, map_location="cpu", weights_only=True).items()})
30
+ self.model = self.model.to(self.device).eval()
31
+ self.InputPadder = InputPadder
32
+ self.flow_viz = flow_viz
33
+
34
+ def forward(self, frames):
35
+ # frames / RGB
36
+ frames = [torch.from_numpy(convert_to_numpy(frame).astype(np.uint8)).permute(2, 0, 1).float()[None].to(self.device) for frame in frames]
37
+ flow_up_list, flow_up_vis_list = [], []
38
+ with torch.no_grad():
39
+ for i, (image1, image2) in enumerate(zip(frames[:-1], frames[1:])):
40
+ padder = self.InputPadder(image1.shape)
41
+ image1, image2 = padder.pad(image1, image2)
42
+ flow_low, flow_up = self.model(image1, image2, iters=20, test_mode=True)
43
+ flow_up = flow_up[0].permute(1, 2, 0).cpu().numpy()
44
+ flow_up_vis = self.flow_viz.flow_to_image(flow_up)
45
+ flow_up_list.append(flow_up)
46
+ flow_up_vis_list.append(flow_up_vis)
47
+ return flow_up_list, flow_up_vis_list # RGB
48
+
49
+
50
+ class FlowVisAnnotator(FlowAnnotator):
51
+ def forward(self, frames):
52
+ flow_up_list, flow_up_vis_list = super().forward(frames)
53
+ return flow_up_vis_list[:1] + flow_up_vis_list
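A minimal sketch of driving `FlowVisAnnotator` with a list of RGB frames; the RAFT weight path below is a placeholder:

```python
import numpy as np
from vace.annotators.flow import FlowVisAnnotator

# PRETRAINED_MODEL is a placeholder path to RAFT weights.
anno = FlowVisAnnotator({'PRETRAINED_MODEL': 'models/VACE-Annotators/flow/raft-things.pth'})
frames = [np.zeros((480, 640, 3), dtype=np.uint8) for _ in range(8)]  # stand-in RGB frames
vis = anno.forward(frames)
print(len(vis), vis[0].shape)   # one flow visualization per input frame
```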
vace/annotators/frameref.py ADDED
@@ -0,0 +1,118 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import random
4
+ import numpy as np
5
+ from .utils import align_frames
6
+
7
+
8
+ class FrameRefExtractAnnotator:
9
+ para_dict = {}
10
+
11
+ def __init__(self, cfg, device=None):
12
+ # first / last / firstlast / random
13
+ self.ref_cfg = cfg.get('REF_CFG', [{"mode": "first", "proba": 0.1},
14
+ {"mode": "last", "proba": 0.1},
15
+ {"mode": "firstlast", "proba": 0.1},
16
+ {"mode": "random", "proba": 0.1}])
17
+ self.ref_num = cfg.get('REF_NUM', 1)
18
+ self.ref_color = cfg.get('REF_COLOR', 127.5)
19
+ self.return_dict = cfg.get('RETURN_DICT', True)
20
+ self.return_mask = cfg.get('RETURN_MASK', True)
21
+
22
+
23
+ def forward(self, frames, ref_cfg=None, ref_num=None, return_mask=None, return_dict=None):
24
+ return_mask = return_mask if return_mask is not None else self.return_mask
25
+ return_dict = return_dict if return_dict is not None else self.return_dict
26
+ ref_cfg = ref_cfg if ref_cfg is not None else self.ref_cfg
27
+ ref_cfg = [ref_cfg] if not isinstance(ref_cfg, list) else ref_cfg
28
+ probas = [item['proba'] if 'proba' in item else 1.0 / len(ref_cfg) for item in ref_cfg]
29
+ sel_ref_cfg = random.choices(ref_cfg, weights=probas, k=1)[0]
30
+ mode = sel_ref_cfg['mode'] if 'mode' in sel_ref_cfg else 'original'
31
+ ref_num = int(ref_num) if ref_num is not None else self.ref_num
32
+
33
+ frame_num = len(frames)
34
+ frame_num_range = list(range(frame_num))
35
+ if mode == "first":
36
+ sel_idx = frame_num_range[:ref_num]
37
+ elif mode == "last":
38
+ sel_idx = frame_num_range[-ref_num:]
39
+ elif mode == "firstlast":
40
+ sel_idx = frame_num_range[:ref_num] + frame_num_range[-ref_num:]
41
+ elif mode == "random":
42
+ sel_idx = random.sample(frame_num_range, ref_num)
43
+ else:
44
+ raise NotImplementedError
45
+
46
+ out_frames, out_masks = [], []
47
+ for i in range(frame_num):
48
+ if i in sel_idx:
49
+ out_frame = frames[i]
50
+ out_mask = np.zeros_like(frames[i][:, :, 0])
51
+ else:
52
+ out_frame = np.ones_like(frames[i]) * self.ref_color
53
+ out_mask = np.ones_like(frames[i][:, :, 0]) * 255
54
+ out_frames.append(out_frame)
55
+ out_masks.append(out_mask)
56
+
57
+ if return_dict:
58
+ ret_data = {"frames": out_frames}
59
+ if return_mask:
60
+ ret_data['masks'] = out_masks
61
+ return ret_data
62
+ else:
63
+ if return_mask:
64
+ return out_frames, out_masks
65
+ else:
66
+ return out_frames
67
+
68
+
69
+
70
+ class FrameRefExpandAnnotator:
71
+ para_dict = {}
72
+
73
+ def __init__(self, cfg, device=None):
74
+ # first / last / firstlast
75
+ self.ref_color = cfg.get('REF_COLOR', 127.5)
76
+ self.return_mask = cfg.get('RETURN_MASK', True)
77
+ self.return_dict = cfg.get('RETURN_DICT', True)
78
+ self.mode = cfg.get('MODE', "firstframe")
79
+ assert self.mode in ["firstframe", "lastframe", "firstlastframe", "firstclip", "lastclip", "firstlastclip", "all"]
80
+
81
+ def forward(self, image=None, image_2=None, frames=None, frames_2=None, mode=None, expand_num=None, return_mask=None, return_dict=None):
82
+ mode = mode if mode is not None else self.mode
83
+ return_mask = return_mask if return_mask is not None else self.return_mask
84
+ return_dict = return_dict if return_dict is not None else self.return_dict
85
+
86
+ if 'frame' in mode:
87
+ frames = [image] if image is not None and not isinstance(image, list) else image
88
+ frames_2 = [image_2] if image_2 is not None and not isinstance(image_2, list) else image_2
89
+
90
+ expand_frames = [np.ones_like(frames[0]) * self.ref_color] * expand_num
91
+ expand_masks = [np.ones_like(frames[0][:, :, 0]) * 255] * expand_num
92
+ source_frames = frames
93
+ source_masks = [np.zeros_like(frames[0][:, :, 0])] * len(frames)
94
+
95
+ if mode in ["firstframe", "firstclip"]:
96
+ out_frames = source_frames + expand_frames
97
+ out_masks = source_masks + expand_masks
98
+ elif mode in ["lastframe", "lastclip"]:
99
+ out_frames = expand_frames + source_frames
100
+ out_masks = expand_masks + source_masks
101
+ elif mode in ["firstlastframe", "firstlastclip"]:
102
+ source_frames_2 = [align_frames(source_frames[0], f2) for f2 in frames_2]
103
+ source_masks_2 = [np.zeros_like(source_frames_2[0][:, :, 0])] * len(frames_2)
104
+ out_frames = source_frames + expand_frames + source_frames_2
105
+ out_masks = source_masks + expand_masks + source_masks_2
106
+ else:
107
+ raise NotImplementedError
108
+
109
+ if return_dict:
110
+ ret_data = {"frames": out_frames}
111
+ if return_mask:
112
+ ret_data['masks'] = out_masks
113
+ return ret_data
114
+ else:
115
+ if return_mask:
116
+ return out_frames, out_masks
117
+ else:
118
+ return out_frames
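`FrameRefExpandAnnotator` only manipulates arrays, so it can be sketched without any checkpoints; in `firstframe` mode it keeps the given frame and appends `expand_num` gray placeholder frames plus the matching masks:

```python
import numpy as np
from vace.annotators.frameref import FrameRefExpandAnnotator

anno = FrameRefExpandAnnotator({'MODE': 'firstframe'})
first = np.zeros((480, 832, 3), dtype=np.uint8)     # stand-in reference frame
out = anno.forward(image=first, expand_num=80)      # keep frame 0, pad 80 gray frames after it
print(len(out['frames']), len(out['masks']))        # 81 81
```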
vace/annotators/gdino.py ADDED
@@ -0,0 +1,88 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+
4
+ import cv2
5
+ import torch
6
+ import numpy as np
7
+ import torchvision
8
+ from .utils import convert_to_numpy
9
+
10
+
11
+ class GDINOAnnotator:
12
+ def __init__(self, cfg, device=None):
13
+ try:
14
+ from groundingdino.util.inference import Model, load_model, load_image, predict
15
+ except ImportError:
16
+ import warnings
17
+ warnings.warn("Failed to import groundingdino; please install the groundingdino package, e.g. the wheel at models/VACE-Annotators/gdino/groundingdino-0.1.0-cp310-cp310-linux_x86_64.whl")
18
+
19
+ grounding_dino_config_path = cfg['CONFIG_PATH']
20
+ grounding_dino_checkpoint_path = cfg['PRETRAINED_MODEL']
21
+ grounding_dino_tokenizer_path = cfg['TOKENIZER_PATH'] # TODO
22
+ self.box_threshold = cfg.get('BOX_THRESHOLD', 0.25)
23
+ self.text_threshold = cfg.get('TEXT_THRESHOLD', 0.2)
24
+ self.iou_threshold = cfg.get('IOU_THRESHOLD', 0.5)
25
+ self.use_nms = cfg.get('USE_NMS', True)
26
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
27
+ self.model = Model(model_config_path=grounding_dino_config_path,
28
+ model_checkpoint_path=grounding_dino_checkpoint_path,
29
+ device=self.device)
30
+
31
+ def forward(self, image, classes=None, caption=None):
32
+ image_bgr = convert_to_numpy(image)[..., ::-1] # bgr
33
+
34
+ if classes is not None:
35
+ classes = [classes] if isinstance(classes, str) else classes
36
+ detections = self.model.predict_with_classes(
37
+ image=image_bgr,
38
+ classes=classes,
39
+ box_threshold=self.box_threshold,
40
+ text_threshold=self.text_threshold
41
+ )
42
+ elif caption is not None:
43
+ detections, phrases = self.model.predict_with_caption(
44
+ image=image_bgr,
45
+ caption=caption,
46
+ box_threshold=self.box_threshold,
47
+ text_threshold=self.text_threshold
48
+ )
49
+ else:
50
+ raise NotImplementedError()
51
+
52
+ if self.use_nms:
53
+ nms_idx = torchvision.ops.nms(
54
+ torch.from_numpy(detections.xyxy),
55
+ torch.from_numpy(detections.confidence),
56
+ self.iou_threshold
57
+ ).numpy().tolist()
58
+ detections.xyxy = detections.xyxy[nms_idx]
59
+ detections.confidence = detections.confidence[nms_idx]
60
+ detections.class_id = detections.class_id[nms_idx] if detections.class_id is not None else None
61
+
62
+ boxes = detections.xyxy
63
+ confidences = detections.confidence
64
+ class_ids = detections.class_id
65
+ class_names = [classes[_id] for _id in class_ids] if classes is not None else phrases
66
+
67
+ ret_data = {
68
+ "boxes": boxes.tolist() if boxes is not None else None,
69
+ "confidences": confidences.tolist() if confidences is not None else None,
70
+ "class_ids": class_ids.tolist() if class_ids is not None else None,
71
+ "class_names": class_names if class_names is not None else None,
72
+ }
73
+ return ret_data
74
+
75
+
76
+ class GDINORAMAnnotator:
77
+ def __init__(self, cfg, device=None):
78
+ from .ram import RAMAnnotator
79
+ from .gdino import GDINOAnnotator
80
+ self.ram_model = RAMAnnotator(cfg['RAM'], device=device)
81
+ self.gdino_model = GDINOAnnotator(cfg['GDINO'], device=device)
82
+
83
+ def forward(self, image):
84
+ ram_res = self.ram_model.forward(image)
85
+ classes = ram_res['tag_e'] if isinstance(ram_res, dict) else ram_res
86
+ gdino_res = self.gdino_model.forward(image, classes=classes)
87
+ return gdino_res
88
+
vace/annotators/gray.py ADDED
@@ -0,0 +1,24 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+
4
+ import cv2
5
+ import numpy as np
6
+ from .utils import convert_to_numpy
7
+
8
+
9
+ class GrayAnnotator:
10
+ def __init__(self, cfg):
11
+ pass
12
+ def forward(self, image):
13
+ image = convert_to_numpy(image)
14
+ gray_map = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
15
+ return gray_map[..., None].repeat(3, axis=2)
16
+
17
+
18
+ class GrayVideoAnnotator(GrayAnnotator):
19
+ def forward(self, frames):
20
+ ret_frames = []
21
+ for frame in frames:
22
+ anno_frame = super().forward(np.array(frame))
23
+ ret_frames.append(anno_frame)
24
+ return ret_frames
vace/annotators/inpainting.py ADDED
@@ -0,0 +1,283 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import cv2
4
+ import math
5
+ import random
6
+ from abc import ABCMeta
7
+
8
+ import numpy as np
9
+ import torch
10
+ from PIL import Image, ImageDraw
11
+ from .utils import convert_to_numpy, convert_to_pil, single_rle_to_mask, get_mask_box, read_video_one_frame
12
+
13
+ class InpaintingAnnotator:
14
+ def __init__(self, cfg, device=None):
15
+ self.use_aug = cfg.get('USE_AUG', True)
16
+ self.return_mask = cfg.get('RETURN_MASK', True)
17
+ self.return_source = cfg.get('RETURN_SOURCE', True)
18
+ self.mask_color = cfg.get('MASK_COLOR', 128)
19
+ self.mode = cfg.get('MODE', "mask")
20
+ assert self.mode in ["salient", "mask", "bbox", "salientmasktrack", "salientbboxtrack", "maskpointtrack", "maskbboxtrack", "masktrack", "bboxtrack", "label", "caption", "all"]
21
+ if self.mode in ["salient", "salienttrack"]:
22
+ from .salient import SalientAnnotator
23
+ self.salient_model = SalientAnnotator(cfg['SALIENT'], device=device)
24
+ if self.mode in ['masktrack', 'bboxtrack', 'salienttrack']:
25
+ from .sam2 import SAM2ImageAnnotator
26
+ self.sam2_model = SAM2ImageAnnotator(cfg['SAM2'], device=device)
27
+ if self.mode in ['label', 'caption']:
28
+ from .gdino import GDINOAnnotator
29
+ from .sam2 import SAM2ImageAnnotator
30
+ self.gdino_model = GDINOAnnotator(cfg['GDINO'], device=device)
31
+ self.sam2_model = SAM2ImageAnnotator(cfg['SAM2'], device=device)
32
+ if self.mode in ['all']:
33
+ from .salient import SalientAnnotator
34
+ from .gdino import GDINOAnnotator
35
+ from .sam2 import SAM2ImageAnnotator
36
+ self.salient_model = SalientAnnotator(cfg['SALIENT'], device=device)
37
+ self.gdino_model = GDINOAnnotator(cfg['GDINO'], device=device)
38
+ self.sam2_model = SAM2ImageAnnotator(cfg['SAM2'], device=device)
39
+ if self.use_aug:
40
+ from .maskaug import MaskAugAnnotator
41
+ self.maskaug_anno = MaskAugAnnotator(cfg={})
42
+
43
+ def apply_plain_mask(self, image, mask, mask_color):
44
+ bool_mask = mask > 0
45
+ out_image = image.copy()
46
+ out_image[bool_mask] = mask_color
47
+ out_mask = np.where(bool_mask, 255, 0).astype(np.uint8)
48
+ return out_image, out_mask
49
+
50
+ def apply_seg_mask(self, image, mask, mask_color, mask_cfg=None):
51
+ out_mask = (mask * 255).astype('uint8')
52
+ if self.use_aug and mask_cfg is not None:
53
+ out_mask = self.maskaug_anno.forward(out_mask, mask_cfg)
54
+ bool_mask = out_mask > 0
55
+ out_image = image.copy()
56
+ out_image[bool_mask] = mask_color
57
+ return out_image, out_mask
58
+
59
+ def forward(self, image=None, mask=None, bbox=None, label=None, caption=None, mode=None, return_mask=None, return_source=None, mask_color=None, mask_cfg=None):
60
+ mode = mode if mode is not None else self.mode
61
+ return_mask = return_mask if return_mask is not None else self.return_mask
62
+ return_source = return_source if return_source is not None else self.return_source
63
+ mask_color = mask_color if mask_color is not None else self.mask_color
64
+
65
+ image = convert_to_numpy(image)
66
+ out_image, out_mask = None, None
67
+ if mode in ['salient']:
68
+ mask = self.salient_model.forward(image)
69
+ out_image, out_mask = self.apply_plain_mask(image, mask, mask_color)
70
+ elif mode in ['mask']:
71
+ mask_h, mask_w = mask.shape[:2]
72
+ h, w = image.shape[:2]
73
+ if (mask_h != h) or (mask_w != w):
74
+ mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
75
+ out_image, out_mask = self.apply_plain_mask(image, mask, mask_color)
76
+ elif mode in ['bbox']:
77
+ x1, y1, x2, y2 = bbox
78
+ h, w = image.shape[:2]
79
+ x1, y1 = int(max(0, x1)), int(max(0, y1))
80
+ x2, y2 = int(min(w, x2)), int(min(h, y2))
81
+ out_image = image.copy()
82
+ out_image[y1:y2, x1:x2] = mask_color
83
+ out_mask = np.zeros((h, w), dtype=np.uint8)
84
+ out_mask[y1:y2, x1:x2] = 255
85
+ elif mode in ['salientmasktrack']:
86
+ mask = self.salient_model.forward(image)
87
+ resize_mask = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_NEAREST)
88
+ out_mask = self.sam2_model.forward(image=image, mask=resize_mask, task_type='mask', return_mask=True)
89
+ out_image, out_mask = self.apply_seg_mask(image, out_mask, mask_color, mask_cfg)
90
+ elif mode in ['salientbboxtrack']:
91
+ mask = self.salient_model.forward(image)
92
+ bbox = get_mask_box(np.array(mask), threshold=1)
93
+ out_mask = self.sam2_model.forward(image=image, input_box=bbox, task_type='input_box', return_mask=True)
94
+ out_image, out_mask = self.apply_seg_mask(image, out_mask, mask_color, mask_cfg)
95
+ elif mode in ['maskpointtrack']:
96
+ out_mask = self.sam2_model.forward(image=image, mask=mask, task_type='mask_point', return_mask=True)
97
+ out_image, out_mask = self.apply_seg_mask(image, out_mask, mask_color, mask_cfg)
98
+ elif mode in ['maskbboxtrack']:
99
+ out_mask = self.sam2_model.forward(image=image, mask=mask, task_type='mask_box', return_mask=True)
100
+ out_image, out_mask = self.apply_seg_mask(image, out_mask, mask_color, mask_cfg)
101
+ elif mode in ['masktrack']:
102
+ resize_mask = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_NEAREST)
103
+ out_mask = self.sam2_model.forward(image=image, mask=resize_mask, task_type='mask', return_mask=True)
104
+ out_image, out_mask = self.apply_seg_mask(image, out_mask, mask_color, mask_cfg)
105
+ elif mode in ['bboxtrack']:
106
+ out_mask = self.sam2_model.forward(image=image, input_box=bbox, task_type='input_box', return_mask=True)
107
+ out_image, out_mask = self.apply_seg_mask(image, out_mask, mask_color, mask_cfg)
108
+ elif mode in ['label']:
109
+ gdino_res = self.gdino_model.forward(image, classes=label)
110
+ if 'boxes' in gdino_res and len(gdino_res['boxes']) > 0:
111
+ bboxes = gdino_res['boxes'][0]
112
+ else:
113
+ raise ValueError(f"Unable to find the corresponding boxes of label: {label}")
114
+ out_mask = self.sam2_model.forward(image=image, input_box=bboxes, task_type='input_box', return_mask=True)
115
+ out_image, out_mask = self.apply_seg_mask(image, out_mask, mask_color, mask_cfg)
116
+ elif mode in ['caption']:
117
+ gdino_res = self.gdino_model.forward(image, caption=caption)
118
+ if 'boxes' in gdino_res and len(gdino_res['boxes']) > 0:
119
+ bboxes = gdino_res['boxes'][0]
120
+ else:
121
+ raise ValueError(f"Unable to find the corresponding boxes of caption: {caption}")
122
+ out_mask = self.sam2_model.forward(image=image, input_box=bboxes, task_type='input_box', return_mask=True)
123
+ out_image, out_mask = self.apply_seg_mask(image, out_mask, mask_color, mask_cfg)
124
+
125
+ ret_data = {"image": out_image}
126
+ if return_mask:
127
+ ret_data["mask"] = out_mask
128
+ if return_source:
129
+ ret_data["src_image"] = image
130
+ return ret_data
131
+
132
+
133
+
134
+
135
+ class InpaintingVideoAnnotator:
136
+ def __init__(self, cfg, device=None):
137
+ self.use_aug = cfg.get('USE_AUG', True)
138
+ self.return_frame = cfg.get('RETURN_FRAME', True)
139
+ self.return_mask = cfg.get('RETURN_MASK', True)
140
+ self.return_source = cfg.get('RETURN_SOURCE', True)
141
+ self.mask_color = cfg.get('MASK_COLOR', 128)
142
+ self.mode = cfg.get('MODE', "mask")
143
+ assert self.mode in ["salient", "mask", "bbox", "salientmasktrack", "salientbboxtrack", "maskpointtrack", "maskbboxtrack", "masktrack", "bboxtrack", "label", "caption", "all"]
144
+ if self.mode in ["salient", "salienttrack"]:
145
+ from .salient import SalientAnnotator
146
+ self.salient_model = SalientAnnotator(cfg['SALIENT'], device=device)
147
+ if self.mode in ['masktrack', 'bboxtrack', 'salienttrack']:
148
+ from .sam2 import SAM2VideoAnnotator
149
+ self.sam2_model = SAM2VideoAnnotator(cfg['SAM2'], device=device)
150
+ if self.mode in ['label', 'caption']:
151
+ from .gdino import GDINOAnnotator
152
+ from .sam2 import SAM2VideoAnnotator
153
+ self.gdino_model = GDINOAnnotator(cfg['GDINO'], device=device)
154
+ self.sam2_model = SAM2VideoAnnotator(cfg['SAM2'], device=device)
155
+ if self.mode in ['all']:
156
+ from .salient import SalientAnnotator
157
+ from .gdino import GDINOAnnotator
158
+ from .sam2 import SAM2VideoAnnotator
159
+ self.salient_model = SalientAnnotator(cfg['SALIENT'], device=device)
160
+ self.gdino_model = GDINOAnnotator(cfg['GDINO'], device=device)
161
+ self.sam2_model = SAM2VideoAnnotator(cfg['SAM2'], device=device)
162
+ if self.use_aug:
163
+ from .maskaug import MaskAugAnnotator
164
+ self.maskaug_anno = MaskAugAnnotator(cfg={})
165
+
166
+ def apply_plain_mask(self, frames, mask, mask_color, return_frame=True):
167
+ out_frames = []
168
+ num_frames = len(frames)
169
+ bool_mask = mask > 0
170
+ out_masks = [np.where(bool_mask, 255, 0).astype(np.uint8)] * num_frames
171
+ if not return_frame:
172
+ return None, out_masks
173
+ for i in range(num_frames):
174
+ masked_frame = frames[i].copy()
175
+ masked_frame[bool_mask] = mask_color
176
+ out_frames.append(masked_frame)
177
+ return out_frames, out_masks
178
+
179
+ def apply_seg_mask(self, mask_data, frames, mask_color, mask_cfg=None, return_frame=True):
180
+ out_frames = []
181
+ out_masks = [(single_rle_to_mask(val[0]["mask"]) * 255).astype('uint8') for key, val in mask_data['annotations'].items()]
182
+ if not return_frame:
183
+ return None, out_masks
184
+ num_frames = min(len(out_masks), len(frames))
185
+ for i in range(num_frames):
186
+ sub_mask = out_masks[i]
187
+ if self.use_aug and mask_cfg is not None:
188
+ sub_mask = self.maskaug_anno.forward(sub_mask, mask_cfg)
189
+ out_masks[i] = sub_mask
190
+ bool_mask = sub_mask > 0
191
+ masked_frame = frames[i].copy()
192
+ masked_frame[bool_mask] = mask_color
193
+ out_frames.append(masked_frame)
194
+ out_masks = out_masks[:num_frames]
195
+ return out_frames, out_masks
196
+
197
+ def forward(self, frames=None, video=None, mask=None, bbox=None, label=None, caption=None, mode=None, return_frame=None, return_mask=None, return_source=None, mask_color=None, mask_cfg=None):
198
+ mode = mode if mode is not None else self.mode
199
+ return_frame = return_frame if return_frame is not None else self.return_frame
200
+ return_mask = return_mask if return_mask is not None else self.return_mask
201
+ return_source = return_source if return_source is not None else self.return_source
202
+ mask_color = mask_color if mask_color is not None else self.mask_color
203
+
204
+ out_frames, out_masks = [], []
205
+ if mode in ['salient']:
206
+ first_frame = frames[0] if frames is not None else read_video_one_frame(video)
207
+ mask = self.salient_model.forward(first_frame)
208
+ out_frames, out_masks = self.apply_plain_mask(frames, mask, mask_color, return_frame)
209
+ elif mode in ['mask']:
210
+ first_frame = frames[0] if frames is not None else read_video_one_frame(video)
211
+ mask_h, mask_w = mask.shape[:2]
212
+ h, w = first_frame.shape[:2]
213
+ if (mask_h != h) or (mask_w != w):
214
+ mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
215
+ out_frames, out_masks = self.apply_plain_mask(frames, mask, mask_color, return_frame)
216
+ elif mode in ['bbox']:
217
+ first_frame = frames[0] if frames is not None else read_video_one_frame(video)
218
+ num_frames = len(frames)
219
+ x1, y1, x2, y2 = bbox
220
+ h, w = first_frame.shape[:2]
221
+ x1, y1 = int(max(0, x1)), int(max(0, y1))
222
+ x2, y2 = int(min(w, x2)), int(min(h, y2))
223
+ mask = np.zeros((h, w), dtype=np.uint8)
224
+ mask[y1:y2, x1:x2] = 255
225
+ out_masks = [mask] * num_frames
226
+ if not return_frame:
227
+ out_frames = None
228
+ else:
229
+ for i in range(num_frames):
230
+ masked_frame = frames[i].copy()
231
+ masked_frame[y1:y2, x1:x2] = mask_color
232
+ out_frames.append(masked_frame)
233
+ elif mode in ['salientmasktrack']:
234
+ first_frame = frames[0] if frames is not None else read_video_one_frame(video)
235
+ salient_mask = self.salient_model.forward(first_frame)
236
+ mask_data = self.sam2_model.forward(video=video, mask=salient_mask, task_type='mask')
237
+ out_frames, out_masks = self.apply_seg_mask(mask_data, frames, mask_color, mask_cfg, return_frame)
238
+ elif mode in ['salientbboxtrack']:
239
+ first_frame = frames[0] if frames is not None else read_video_one_frame(video)
240
+ salient_mask = self.salient_model.forward(first_frame)
241
+ bbox = get_mask_box(np.array(salient_mask), threshold=1)
242
+ mask_data = self.sam2_model.forward(video=video, input_box=bbox, task_type='input_box')
243
+ out_frames, out_masks = self.apply_seg_mask(mask_data, frames, mask_color, mask_cfg, return_frame)
244
+ elif mode in ['maskpointtrack']:
245
+ mask_data = self.sam2_model.forward(video=video, mask=mask, task_type='mask_point')
246
+ out_frames, out_masks = self.apply_seg_mask(mask_data, frames, mask_color, mask_cfg, return_frame)
247
+ elif mode in ['maskbboxtrack']:
248
+ mask_data = self.sam2_model.forward(video=video, mask=mask, task_type='mask_box')
249
+ out_frames, out_masks = self.apply_seg_mask(mask_data, frames, mask_color, mask_cfg, return_frame)
250
+ elif mode in ['masktrack']:
251
+ mask_data = self.sam2_model.forward(video=video, mask=mask, task_type='mask')
252
+ out_frames, out_masks = self.apply_seg_mask(mask_data, frames, mask_color, mask_cfg, return_frame)
253
+ elif mode in ['bboxtrack']:
254
+ mask_data = self.sam2_model.forward(video=video, input_box=bbox, task_type='input_box')
255
+ out_frames, out_masks = self.apply_seg_mask(mask_data, frames, mask_color, mask_cfg, return_frame)
256
+ elif mode in ['label']:
257
+ first_frame = frames[0] if frames is not None else read_video_one_frame(video)
258
+ gdino_res = self.gdino_model.forward(first_frame, classes=label)
259
+ if 'boxes' in gdino_res and len(gdino_res['boxes']) > 0:
260
+ bboxes = gdino_res['boxes'][0]
261
+ else:
262
+ raise ValueError(f"Unable to find the corresponding boxes of label: {label}")
263
+ mask_data = self.sam2_model.forward(video=video, input_box=bboxes, task_type='input_box')
264
+ out_frames, out_masks = self.apply_seg_mask(mask_data, frames, mask_color, mask_cfg, return_frame)
265
+ elif mode in ['caption']:
266
+ first_frame = frames[0] if frames is not None else read_video_one_frame(video)
267
+ gdino_res = self.gdino_model.forward(first_frame, caption=caption)
268
+ if 'boxes' in gdino_res and len(gdino_res['boxes']) > 0:
269
+ bboxes = gdino_res['boxes'][0]
270
+ else:
271
+ raise ValueError(f"Unable to find the corresponding boxes of caption: {caption}")
272
+ mask_data = self.sam2_model.forward(video=video, input_box=bboxes, task_type='input_box')
273
+ out_frames, out_masks = self.apply_seg_mask(mask_data, frames, mask_color, mask_cfg, return_frame)
274
+
275
+ ret_data = {}
276
+ if return_frame:
277
+ ret_data["frames"] = out_frames
278
+ if return_mask:
279
+ ret_data["masks"] = out_masks
280
+ return ret_data
281
+
282
+
283
+
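A minimal sketch of the image annotator in `bbox` mode, the one branch that needs no SAM2/GroundingDINO checkpoints; sizes and box coordinates below are illustrative only:

```python
import numpy as np
from vace.annotators.inpainting import InpaintingAnnotator

# 'bbox' mode works standalone: no segmentation or detection models are loaded.
anno = InpaintingAnnotator({'MODE': 'bbox'})
image = np.zeros((480, 832, 3), dtype=np.uint8)      # stand-in RGB frame
res = anno.forward(image=image, bbox=[100, 120, 300, 360])
print(res['image'].shape, int(res['mask'].max()))    # masked image, 255 inside the box
```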
vace/annotators/layout.py ADDED
@@ -0,0 +1,161 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+
4
+ import cv2
5
+ import numpy as np
6
+
7
+ from .utils import convert_to_numpy
8
+
9
+
10
+ class LayoutBboxAnnotator:
11
+ def __init__(self, cfg, device=None):
12
+ self.bg_color = cfg.get('BG_COLOR', [255, 255, 255])
13
+ self.box_color = cfg.get('BOX_COLOR', [0, 0, 0])
14
+ self.frame_size = cfg.get('FRAME_SIZE', [720, 1280]) # [H, W]
15
+ self.num_frames = cfg.get('NUM_FRAMES', 81)
16
+ ram_tag_color_path = cfg.get('RAM_TAG_COLOR_PATH', None)
17
+ self.color_dict = {'default': tuple(self.box_color)}
18
+ if ram_tag_color_path is not None:
19
+ lines = [id_name_color.strip().split('#;#') for id_name_color in open(ram_tag_color_path).readlines()]
20
+ self.color_dict.update({id_name_color[1]: tuple(eval(id_name_color[2])) for id_name_color in lines})
21
+
22
+ def forward(self, bbox, frame_size=None, num_frames=None, label=None, color=None):
23
+ frame_size = frame_size if frame_size is not None else self.frame_size
24
+ num_frames = num_frames if num_frames is not None else self.num_frames
25
+ assert len(bbox) == 2, 'bbox should be a list of two elements (start_bbox & end_bbox)'
26
+ # frame_size = [H, W]
27
+ # bbox = [x1, y1, x2, y2]
28
+ label = label[0] if label is not None and isinstance(label, list) else label
29
+ if label is not None and label in self.color_dict:
30
+ box_color = self.color_dict[label]
31
+ elif color is not None:
32
+ box_color = color
33
+ else:
34
+ box_color = self.color_dict['default']
35
+ start_bbox, end_bbox = bbox
36
+ start_bbox = [start_bbox[0], start_bbox[1], start_bbox[2] - start_bbox[0], start_bbox[3] - start_bbox[1]]
37
+ start_bbox = np.array(start_bbox, dtype=np.float32)
38
+ end_bbox = [end_bbox[0], end_bbox[1], end_bbox[2] - end_bbox[0], end_bbox[3] - end_bbox[1]]
39
+ end_bbox = np.array(end_bbox, dtype=np.float32)
40
+ bbox_increment = (end_bbox - start_bbox) / num_frames
41
+ ret_frames = []
42
+ for frame_idx in range(num_frames):
43
+ frame = np.zeros((frame_size[0], frame_size[1], 3), dtype=np.uint8)
44
+ frame[:] = self.bg_color
45
+ current_bbox = start_bbox + bbox_increment * frame_idx
46
+ current_bbox = current_bbox.astype(int)
47
+ x, y, w, h = current_bbox
48
+ cv2.rectangle(frame, (x, y), (x + w, y + h), box_color, 2)
49
+ ret_frames.append(frame[..., ::-1])
50
+ return ret_frames
51
+
52
+
53
+
54
+
55
+ class LayoutMaskAnnotator:
56
+ def __init__(self, cfg, device=None):
57
+ self.use_aug = cfg.get('USE_AUG', False)
58
+ self.bg_color = cfg.get('BG_COLOR', [255, 255, 255])
59
+ self.box_color = cfg.get('BOX_COLOR', [0, 0, 0])
60
+ ram_tag_color_path = cfg.get('RAM_TAG_COLOR_PATH', None)
61
+ self.color_dict = {'default': tuple(self.box_color)}
62
+ if ram_tag_color_path is not None:
63
+ lines = [id_name_color.strip().split('#;#') for id_name_color in open(ram_tag_color_path).readlines()]
64
+ self.color_dict.update({id_name_color[1]: tuple(eval(id_name_color[2])) for id_name_color in lines})
65
+ if self.use_aug:
66
+ from .maskaug import MaskAugAnnotator
67
+ self.maskaug_anno = MaskAugAnnotator(cfg={})
68
+
69
+
70
+ def find_contours(self, mask):
71
+ contours, hier = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
72
+ return contours
73
+
74
+ def draw_contours(self, canvas, contour, color):
75
+ canvas = np.ascontiguousarray(canvas, dtype=np.uint8)
76
+ canvas = cv2.drawContours(canvas, contour, -1, color, thickness=3)
77
+ return canvas
78
+
79
+ def forward(self, mask=None, color=None, label=None, mask_cfg=None):
80
+ if not isinstance(mask, list):
81
+ is_batch = False
82
+ mask = [mask]
83
+ else:
84
+ is_batch = True
85
+
86
+ if label is not None and label in self.color_dict:
87
+ color = self.color_dict[label]
88
+ elif color is not None:
89
+ color = color
90
+ else:
91
+ color = self.color_dict['default']
92
+
93
+ ret_data = []
94
+ for sub_mask in mask:
95
+ sub_mask = convert_to_numpy(sub_mask)
96
+ if self.use_aug:
97
+ sub_mask = self.maskaug_anno.forward(sub_mask, mask_cfg)
98
+ canvas = np.ones((sub_mask.shape[0], sub_mask.shape[1], 3)) * 255
99
+ contour = self.find_contours(sub_mask)
100
+ frame = self.draw_contours(canvas, contour, color)
101
+ ret_data.append(frame)
102
+
103
+ if is_batch:
104
+ return ret_data
105
+ else:
106
+ return ret_data[0]
107
+
108
+
109
+
110
+
111
+ class LayoutTrackAnnotator:
112
+ def __init__(self, cfg, device=None):
113
+ self.use_aug = cfg.get('USE_AUG', False)
114
+ self.bg_color = cfg.get('BG_COLOR', [255, 255, 255])
115
+ self.box_color = cfg.get('BOX_COLOR', [0, 0, 0])
116
+ ram_tag_color_path = cfg.get('RAM_TAG_COLOR_PATH', None)
117
+ self.color_dict = {'default': tuple(self.box_color)}
118
+ if ram_tag_color_path is not None:
119
+ lines = [id_name_color.strip().split('#;#') for id_name_color in open(ram_tag_color_path).readlines()]
120
+ self.color_dict.update({id_name_color[1]: tuple(eval(id_name_color[2])) for id_name_color in lines})
121
+ if self.use_aug:
122
+ from .maskaug import MaskAugAnnotator
123
+ self.maskaug_anno = MaskAugAnnotator(cfg={})
124
+ from .inpainting import InpaintingVideoAnnotator
125
+ self.inpainting_anno = InpaintingVideoAnnotator(cfg=cfg['INPAINTING'])
126
+
127
+ def find_contours(self, mask):
128
+ contours, hier = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
129
+ return contours
130
+
131
+ def draw_contours(self, canvas, contour, color):
132
+ canvas = np.ascontiguousarray(canvas, dtype=np.uint8)
133
+ canvas = cv2.drawContours(canvas, contour, -1, color, thickness=3)
134
+ return canvas
135
+
136
+ def forward(self, color=None, mask_cfg=None, frames=None, video=None, mask=None, bbox=None, label=None, caption=None, mode=None):
137
+ inp_data = self.inpainting_anno.forward(frames, video, mask, bbox, label, caption, mode)
138
+ inp_masks = inp_data['masks']
139
+
140
+ label = label[0] if label is not None and isinstance(label, list) else label
141
+ if label is not None and label in self.color_dict:
142
+ color = self.color_dict[label]
143
+ elif color is not None:
144
+ color = color
145
+ else:
146
+ color = self.color_dict['default']
147
+
148
+ num_frames = len(inp_masks)
149
+ ret_data = []
150
+ for i in range(num_frames):
151
+ sub_mask = inp_masks[i]
152
+ if self.use_aug and mask_cfg is not None:
153
+ sub_mask = self.maskaug_anno.forward(sub_mask, mask_cfg)
154
+ canvas = np.ones((sub_mask.shape[0], sub_mask.shape[1], 3)) * 255
155
+ contour = self.find_contours(sub_mask)
156
+ frame = self.draw_contours(canvas, contour, color)
157
+ ret_data.append(frame)
158
+
159
+ return ret_data
160
+
161
+
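`LayoutBboxAnnotator` linearly interpolates a box from a start to an end position and renders one frame per step; a checkpoint-free sketch:

```python
from vace.annotators.layout import LayoutBboxAnnotator

# Without RAM_TAG_COLOR_PATH every box falls back to BOX_COLOR (black on white).
anno = LayoutBboxAnnotator({'FRAME_SIZE': [480, 832], 'NUM_FRAMES': 49})
start_box = [100, 120, 300, 360]     # [x1, y1, x2, y2] at the first frame
end_box = [400, 120, 600, 360]       # [x1, y1, x2, y2] at the last frame
frames = anno.forward(bbox=[start_box, end_box])
print(len(frames), frames[0].shape)  # 49 (480, 832, 3)
```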
vace/annotators/mask.py ADDED
@@ -0,0 +1,79 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+
4
+ import numpy as np
5
+ from scipy.spatial import ConvexHull
6
+ from skimage.draw import polygon
7
+ from scipy import ndimage
8
+
9
+ from .utils import convert_to_numpy
10
+
11
+
12
+ class MaskDrawAnnotator:
13
+ def __init__(self, cfg, device=None):
14
+ self.mode = cfg.get('MODE', 'maskpoint')
15
+ self.return_dict = cfg.get('RETURN_DICT', True)
16
+ assert self.mode in ['maskpoint', 'maskbbox', 'mask', 'bbox']
17
+
18
+ def forward(self,
19
+ mask=None,
20
+ image=None,
21
+ bbox=None,
22
+ mode=None,
23
+ return_dict=None):
24
+ mode = mode if mode is not None else self.mode
25
+ return_dict = return_dict if return_dict is not None else self.return_dict
26
+
27
+ mask = convert_to_numpy(mask) if mask is not None else None
28
+ image = convert_to_numpy(image) if image is not None else None
29
+
30
+ mask_shape = mask.shape
31
+ if mode == 'maskpoint':
32
+ scribble = mask.transpose(1, 0)
33
+ labeled_array, num_features = ndimage.label(scribble >= 255)
34
+ centers = ndimage.center_of_mass(scribble, labeled_array,
35
+ range(1, num_features + 1))
36
+ centers = np.array(centers)
37
+ out_mask = np.zeros(mask_shape, dtype=np.uint8)
38
+ hull = ConvexHull(centers)
39
+ hull_vertices = centers[hull.vertices]
40
+ rr, cc = polygon(hull_vertices[:, 1], hull_vertices[:, 0], mask_shape)
41
+ out_mask[rr, cc] = 255
42
+ elif mode == 'maskbbox':
43
+ scribble = mask.transpose(1, 0)
44
+ labeled_array, num_features = ndimage.label(scribble >= 255)
45
+ centers = ndimage.center_of_mass(scribble, labeled_array,
46
+ range(1, num_features + 1))
47
+ centers = np.array(centers)
48
+ # (x1, y1, x2, y2)
49
+ x_min = centers[:, 0].min()
50
+ x_max = centers[:, 0].max()
51
+ y_min = centers[:, 1].min()
52
+ y_max = centers[:, 1].max()
53
+ out_mask = np.zeros(mask_shape, dtype=np.uint8)
54
+ out_mask[int(y_min) : int(y_max) + 1, int(x_min) : int(x_max) + 1] = 255
55
+ if image is not None:
56
+ out_image = image[int(y_min) : int(y_max) + 1, int(x_min) : int(x_max) + 1]
57
+ elif mode == 'bbox':
58
+ if isinstance(bbox, list):
59
+ bbox = np.array(bbox)
60
+ x_min, y_min, x_max, y_max = bbox
61
+ out_mask = np.zeros(mask_shape, dtype=np.uint8)
62
+ out_mask[int(y_min) : int(y_max) + 1, int(x_min) : int(x_max) + 1] = 255
63
+ if image is not None:
64
+ out_image = image[int(y_min) : int(y_max) + 1, int(x_min) : int(x_max) + 1]
65
+ elif mode == 'mask':
66
+ out_mask = mask
67
+ else:
68
+ raise NotImplementedError
69
+
70
+ if return_dict:
71
+ if image is not None:
72
+ return {"image": out_image, "mask": out_mask}
73
+ else:
74
+ return {"mask": out_mask}
75
+ else:
76
+ if image is not None:
77
+ return out_image, out_mask
78
+ else:
79
+ return out_mask
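A small sketch of `MaskDrawAnnotator` in `bbox` mode, where the mask argument is only used for its shape:

```python
import numpy as np
from vace.annotators.mask import MaskDrawAnnotator

anno = MaskDrawAnnotator({'MODE': 'bbox', 'RETURN_DICT': True})
mask = np.zeros((480, 832), dtype=np.uint8)        # only its shape is used in 'bbox' mode
res = anno.forward(mask=mask, bbox=[100, 120, 300, 360])
print(res['mask'].shape, int(res['mask'].max()))   # (480, 832) 255
```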
vace/annotators/maskaug.py ADDED
@@ -0,0 +1,181 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+
4
+
5
+ import random
6
+ from functools import partial
7
+
8
+ import cv2
9
+ import numpy as np
10
+ from PIL import Image, ImageDraw
11
+
12
+ from .utils import convert_to_numpy
13
+
14
+
15
+
16
+ class MaskAugAnnotator:
17
+ def __init__(self, cfg, device=None):
18
+ # original / original_expand / hull / hull_expand / bbox / bbox_expand
19
+ self.mask_cfg = cfg.get('MASK_CFG', [{"mode": "original", "proba": 0.1},
20
+ {"mode": "original_expand", "proba": 0.1},
21
+ {"mode": "hull", "proba": 0.1},
22
+ {"mode": "hull_expand", "proba":0.1, "kwargs": {"expand_ratio": 0.2}},
23
+ {"mode": "bbox", "proba": 0.1},
24
+ {"mode": "bbox_expand", "proba": 0.1, "kwargs": {"min_expand_ratio": 0.2, "max_expand_ratio": 0.5}}])
25
+
26
+ def forward(self, mask, mask_cfg=None):
27
+ mask_cfg = mask_cfg if mask_cfg is not None else self.mask_cfg
28
+ if not isinstance(mask, list):
29
+ is_batch = False
30
+ masks = [mask]
31
+ else:
32
+ is_batch = True
33
+ masks = mask
34
+
35
+ mask_func = self.get_mask_func(mask_cfg)
36
+ # print(mask_func)
37
+ aug_masks = []
38
+ for submask in masks:
39
+ mask = convert_to_numpy(submask)
40
+ valid, large, h, w, bbox = self.get_mask_info(mask)
41
+ # print(valid, large, h, w, bbox)
42
+ if valid:
43
+ mask = mask_func(mask, bbox, h, w)
44
+ else:
45
+ mask = mask.astype(np.uint8)
46
+ aug_masks.append(mask)
47
+ return aug_masks if is_batch else aug_masks[0]
48
+
49
+ def get_mask_info(self, mask):
50
+ h, w = mask.shape
51
+ locs = mask.nonzero()
52
+ valid = True
53
+ if len(locs) < 1 or locs[0].shape[0] < 1 or locs[1].shape[0] < 1:
54
+ valid = False
55
+ return valid, False, h, w, [0, 0, 0, 0]
56
+
57
+ left, right = np.min(locs[1]), np.max(locs[1])
58
+ top, bottom = np.min(locs[0]), np.max(locs[0])
59
+ bbox = [left, top, right, bottom]
60
+
61
+ large = False
62
+ if (right - left + 1) * (bottom - top + 1) > 0.9 * h * w:
63
+ large = True
64
+ return valid, large, h, w, bbox
65
+
66
+ def get_expand_params(self, mask_kwargs):
67
+ if 'expand_ratio' in mask_kwargs:
68
+ expand_ratio = mask_kwargs['expand_ratio']
69
+ elif 'min_expand_ratio' in mask_kwargs and 'max_expand_ratio' in mask_kwargs:
70
+ expand_ratio = random.uniform(mask_kwargs['min_expand_ratio'], mask_kwargs['max_expand_ratio'])
71
+ else:
72
+ expand_ratio = 0.3
73
+
74
+ if 'expand_iters' in mask_kwargs:
75
+ expand_iters = mask_kwargs['expand_iters']
76
+ else:
77
+ expand_iters = random.randint(1, 10)
78
+
79
+ if 'expand_lrtp' in mask_kwargs:
80
+ expand_lrtp = mask_kwargs['expand_lrtp']
81
+ else:
82
+ expand_lrtp = [random.random(), random.random(), random.random(), random.random()]
83
+
84
+ return expand_ratio, expand_iters, expand_lrtp
85
+
86
+ def get_mask_func(self, mask_cfg):
87
+ if not isinstance(mask_cfg, list):
88
+ mask_cfg = [mask_cfg]
89
+ probas = [item['proba'] if 'proba' in item else 1.0 / len(mask_cfg) for item in mask_cfg]
90
+ sel_mask_cfg = random.choices(mask_cfg, weights=probas, k=1)[0]
91
+ mode = sel_mask_cfg['mode'] if 'mode' in sel_mask_cfg else 'original'
92
+ mask_kwargs = sel_mask_cfg['kwargs'] if 'kwargs' in sel_mask_cfg else {}
93
+
94
+ if mode == 'random':
95
+ mode = random.choice(['original', 'original_expand', 'hull', 'hull_expand', 'bbox', 'bbox_expand'])
96
+ if mode == 'original':
97
+ mask_func = partial(self.generate_mask)
98
+ elif mode == 'original_expand':
99
+ expand_ratio, expand_iters, expand_lrtp = self.get_expand_params(mask_kwargs)
100
+ mask_func = partial(self.generate_mask, expand_ratio=expand_ratio, expand_iters=expand_iters, expand_lrtp=expand_lrtp)
101
+ elif mode == 'hull':
102
+ clockwise = random.choice([True, False]) if 'clockwise' not in mask_kwargs else mask_kwargs['clockwise']
103
+ mask_func = partial(self.generate_hull_mask, clockwise=clockwise)
104
+ elif mode == 'hull_expand':
105
+ expand_ratio, expand_iters, expand_lrtp = self.get_expand_params(mask_kwargs)
106
+ clockwise = random.choice([True, False]) if 'clockwise' not in mask_kwargs else mask_kwargs['clockwise']
107
+ mask_func = partial(self.generate_hull_mask, clockwise=clockwise, expand_ratio=expand_ratio, expand_iters=expand_iters, expand_lrtp=expand_lrtp)
108
+ elif mode == 'bbox':
109
+ mask_func = partial(self.generate_bbox_mask)
110
+ elif mode == 'bbox_expand':
111
+ expand_ratio, expand_iters, expand_lrtp = self.get_expand_params(mask_kwargs)
112
+ mask_func = partial(self.generate_bbox_mask, expand_ratio=expand_ratio, expand_iters=expand_iters, expand_lrtp=expand_lrtp)
113
+ else:
114
+ raise NotImplementedError
115
+ return mask_func
116
+
117
+
118
+ def generate_mask(self, mask, bbox, h, w, expand_ratio=None, expand_iters=None, expand_lrtp=None):
119
+ bin_mask = mask.astype(np.uint8)
120
+ if expand_ratio:
121
+ bin_mask = self.rand_expand_mask(bin_mask, bbox, h, w, expand_ratio, expand_iters, expand_lrtp)
122
+ return bin_mask
123
+
124
+
125
+ @staticmethod
126
+ def rand_expand_mask(mask, bbox, h, w, expand_ratio=None, expand_iters=None, expand_lrtp=None):
127
+ expand_ratio = 0.3 if expand_ratio is None else expand_ratio
128
+ expand_iters = random.randint(1, 10) if expand_iters is None else expand_iters
129
+ expand_lrtp = [random.random(), random.random(), random.random(), random.random()] if expand_lrtp is None else expand_lrtp
130
+ # print('iters', expand_iters, 'expand_ratio', expand_ratio, 'expand_lrtp', expand_lrtp)
131
+ # mask = np.squeeze(mask)
132
+ left, top, right, bottom = bbox
133
+ # mask expansion
134
+ box_w = (right - left + 1) * expand_ratio
135
+ box_h = (bottom - top + 1) * expand_ratio
136
+ left_, right_ = int(expand_lrtp[0] * min(box_w, left / 2) / expand_iters), int(
137
+ expand_lrtp[1] * min(box_w, (w - right) / 2) / expand_iters)
138
+ top_, bottom_ = int(expand_lrtp[2] * min(box_h, top / 2) / expand_iters), int(
139
+ expand_lrtp[3] * min(box_h, (h - bottom) / 2) / expand_iters)
140
+ kernel_size = max(left_, right_, top_, bottom_)
141
+ if kernel_size > 0:
142
+ kernel = np.zeros((kernel_size * 2, kernel_size * 2), dtype=np.uint8)
143
+ new_left, new_right = kernel_size - right_, kernel_size + left_
144
+ new_top, new_bottom = kernel_size - bottom_, kernel_size + top_
145
+ kernel[new_top:new_bottom + 1, new_left:new_right + 1] = 1
146
+ mask = mask.astype(np.uint8)
147
+ mask = cv2.dilate(mask, kernel, iterations=expand_iters).astype(np.uint8)
148
+ # mask = new_mask - (mask / 2).astype(np.uint8)
149
+ # mask = np.expand_dims(mask, axis=-1)
150
+ return mask
151
+
152
+
153
+ @staticmethod
154
+ def _convexhull(image, clockwise):
155
+ contours, hierarchy = cv2.findContours(image, 2, 1)
156
+ cnt = np.concatenate(contours) # merge all regions
157
+ hull = cv2.convexHull(cnt, clockwise=clockwise)
158
+ hull = np.squeeze(hull, axis=1).astype(np.float32).tolist()
159
+ hull = [tuple(x) for x in hull]
160
+ return hull # b, 1, 2
161
+
162
+ def generate_hull_mask(self, mask, bbox, h, w, clockwise=None, expand_ratio=None, expand_iters=None, expand_lrtp=None):
163
+ clockwise = random.choice([True, False]) if clockwise is None else clockwise
164
+ hull = self._convexhull(mask, clockwise)
165
+ mask_img = Image.new('L', (w, h), 0)
166
+ pt_list = hull
167
+ mask_img_draw = ImageDraw.Draw(mask_img)
168
+ mask_img_draw.polygon(pt_list, fill=255)
169
+ bin_mask = np.array(mask_img).astype(np.uint8)
170
+ if expand_ratio:
171
+ bin_mask = self.rand_expand_mask(bin_mask, bbox, h, w, expand_ratio, expand_iters, expand_lrtp)
172
+ return bin_mask
173
+
174
+
175
+ def generate_bbox_mask(self, mask, bbox, h, w, expand_ratio=None, expand_iters=None, expand_lrtp=None):
176
+ left, top, right, bottom = bbox
177
+ bin_mask = np.zeros((h, w), dtype=np.uint8)
178
+ bin_mask[top:bottom + 1, left:right + 1] = 255
179
+ if expand_ratio:
180
+ bin_mask = self.rand_expand_mask(bin_mask, bbox, h, w, expand_ratio, expand_iters, expand_lrtp)
181
+ return bin_mask
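A sketch of `MaskAugAnnotator` applying one of its augmentation modes to a plain rectangular mask; the `mask_cfg` values below are illustrative:

```python
import numpy as np
from vace.annotators.maskaug import MaskAugAnnotator

anno = MaskAugAnnotator(cfg={})
mask = np.zeros((480, 832), dtype=np.uint8)
mask[100:200, 150:400] = 255                        # a simple rectangular region
aug = anno.forward(mask, mask_cfg={'mode': 'hull_expand',
                                   'kwargs': {'expand_ratio': 0.2}})
print(aug.shape, aug.dtype)                         # (480, 832) uint8
```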
vace/annotators/midas/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
vace/annotators/midas/api.py ADDED
@@ -0,0 +1,166 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ # based on https://github.com/isl-org/MiDaS
4
+
5
+ import cv2
6
+ import torch
7
+ import torch.nn as nn
8
+ from torchvision.transforms import Compose
9
+
10
+ from .dpt_depth import DPTDepthModel
11
+ from .midas_net import MidasNet
12
+ from .midas_net_custom import MidasNet_small
13
+ from .transforms import NormalizeImage, PrepareForNet, Resize
14
+
15
+ # ISL_PATHS = {
16
+ # "dpt_large": "dpt_large-midas-2f21e586.pt",
17
+ # "dpt_hybrid": "dpt_hybrid-midas-501f0c75.pt",
18
+ # "midas_v21": "",
19
+ # "midas_v21_small": "",
20
+ # }
21
+
22
+ # remote_model_path =
23
+ # "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt"
24
+
25
+
26
+ def disabled_train(self, mode=True):
27
+ """Overwrite model.train with this function to make sure train/eval mode
28
+ does not change anymore."""
29
+ return self
30
+
31
+
32
+ def load_midas_transform(model_type):
33
+ # https://github.com/isl-org/MiDaS/blob/master/run.py
34
+ # load transform only
35
+ if model_type == 'dpt_large': # DPT-Large
36
+ net_w, net_h = 384, 384
37
+ resize_mode = 'minimal'
38
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5],
39
+ std=[0.5, 0.5, 0.5])
40
+
41
+ elif model_type == 'dpt_hybrid': # DPT-Hybrid
42
+ net_w, net_h = 384, 384
43
+ resize_mode = 'minimal'
44
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5],
45
+ std=[0.5, 0.5, 0.5])
46
+
47
+ elif model_type == 'midas_v21':
48
+ net_w, net_h = 384, 384
49
+ resize_mode = 'upper_bound'
50
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406],
51
+ std=[0.229, 0.224, 0.225])
52
+
53
+ elif model_type == 'midas_v21_small':
54
+ net_w, net_h = 256, 256
55
+ resize_mode = 'upper_bound'
56
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406],
57
+ std=[0.229, 0.224, 0.225])
58
+
59
+ else:
60
+ assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
61
+
62
+ transform = Compose([
63
+ Resize(
64
+ net_w,
65
+ net_h,
66
+ resize_target=None,
67
+ keep_aspect_ratio=True,
68
+ ensure_multiple_of=32,
69
+ resize_method=resize_mode,
70
+ image_interpolation_method=cv2.INTER_CUBIC,
71
+ ),
72
+ normalization,
73
+ PrepareForNet(),
74
+ ])
75
+
76
+ return transform
77
+
78
+
79
+ def load_model(model_type, model_path):
80
+ # https://github.com/isl-org/MiDaS/blob/master/run.py
81
+ # load network
82
+ # model_path = ISL_PATHS[model_type]
83
+ if model_type == 'dpt_large': # DPT-Large
84
+ model = DPTDepthModel(
85
+ path=model_path,
86
+ backbone='vitl16_384',
87
+ non_negative=True,
88
+ )
89
+ net_w, net_h = 384, 384
90
+ resize_mode = 'minimal'
91
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5],
92
+ std=[0.5, 0.5, 0.5])
93
+
94
+ elif model_type == 'dpt_hybrid': # DPT-Hybrid
95
+ model = DPTDepthModel(
96
+ path=model_path,
97
+ backbone='vitb_rn50_384',
98
+ non_negative=True,
99
+ )
100
+ net_w, net_h = 384, 384
101
+ resize_mode = 'minimal'
102
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5],
103
+ std=[0.5, 0.5, 0.5])
104
+
105
+ elif model_type == 'midas_v21':
106
+ model = MidasNet(model_path, non_negative=True)
107
+ net_w, net_h = 384, 384
108
+ resize_mode = 'upper_bound'
109
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406],
110
+ std=[0.229, 0.224, 0.225])
111
+
112
+ elif model_type == 'midas_v21_small':
113
+ model = MidasNet_small(model_path,
114
+ features=64,
115
+ backbone='efficientnet_lite3',
116
+ exportable=True,
117
+ non_negative=True,
118
+ blocks={'expand': True})
119
+ net_w, net_h = 256, 256
120
+ resize_mode = 'upper_bound'
121
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406],
122
+ std=[0.229, 0.224, 0.225])
123
+
124
+ else:
125
+ print(
126
+ f"model_type '{model_type}' not implemented, use: --model_type large"
127
+ )
128
+ assert False
129
+
130
+ transform = Compose([
131
+ Resize(
132
+ net_w,
133
+ net_h,
134
+ resize_target=None,
135
+ keep_aspect_ratio=True,
136
+ ensure_multiple_of=32,
137
+ resize_method=resize_mode,
138
+ image_interpolation_method=cv2.INTER_CUBIC,
139
+ ),
140
+ normalization,
141
+ PrepareForNet(),
142
+ ])
143
+
144
+ return model.eval(), transform
145
+
146
+
147
+ class MiDaSInference(nn.Module):
148
+ MODEL_TYPES_TORCH_HUB = ['DPT_Large', 'DPT_Hybrid', 'MiDaS_small']
149
+ MODEL_TYPES_ISL = [
150
+ 'dpt_large',
151
+ 'dpt_hybrid',
152
+ 'midas_v21',
153
+ 'midas_v21_small',
154
+ ]
155
+
156
+ def __init__(self, model_type, model_path):
157
+ super().__init__()
158
+ assert (model_type in self.MODEL_TYPES_ISL)
159
+ model, _ = load_model(model_type, model_path)
160
+ self.model = model
161
+ self.model.train = disabled_train
162
+
163
+ def forward(self, x):
164
+ with torch.no_grad():
165
+ prediction = self.model(x)
166
+ return prediction
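A hedged sketch of `MiDaSInference`; the checkpoint path is a placeholder (the file name mirrors the commented-out `remote_model_path` above), and the input is assumed to be an already-normalized 384x384 batch such as the one produced by `load_midas_transform('dpt_hybrid')`:

```python
import torch
from vace.annotators.midas.api import MiDaSInference

# Placeholder checkpoint location; download the dpt_hybrid weights separately.
model = MiDaSInference('dpt_hybrid', 'models/dpt_hybrid-midas-501f0c75.pt')
x = torch.randn(1, 3, 384, 384)      # stand-in for a normalized RGB batch
depth = model(x)                     # inverse-depth prediction
print(depth.shape)
```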
vace/annotators/midas/base_model.py ADDED
@@ -0,0 +1,18 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import torch
4
+
5
+
6
+ class BaseModel(torch.nn.Module):
7
+ def load(self, path):
8
+ """Load model from file.
9
+
10
+ Args:
11
+ path (str): file path
12
+ """
13
+ parameters = torch.load(path, map_location=torch.device('cpu'), weights_only=True)
14
+
15
+ if 'optimizer' in parameters:
16
+ parameters = parameters['model']
17
+
18
+ self.load_state_dict(parameters)
vace/annotators/midas/blocks.py ADDED
@@ -0,0 +1,391 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from .vit import (_make_pretrained_vitb16_384, _make_pretrained_vitb_rn50_384,
7
+ _make_pretrained_vitl16_384)
8
+
9
+
10
+ def _make_encoder(
11
+ backbone,
12
+ features,
13
+ use_pretrained,
14
+ groups=1,
15
+ expand=False,
16
+ exportable=True,
17
+ hooks=None,
18
+ use_vit_only=False,
19
+ use_readout='ignore',
20
+ ):
21
+ if backbone == 'vitl16_384':
22
+ pretrained = _make_pretrained_vitl16_384(use_pretrained,
23
+ hooks=hooks,
24
+ use_readout=use_readout)
25
+ scratch = _make_scratch(
26
+ [256, 512, 1024, 1024], features, groups=groups,
27
+ expand=expand) # ViT-L/16 - 85.0% Top1 (backbone)
28
+ elif backbone == 'vitb_rn50_384':
29
+ pretrained = _make_pretrained_vitb_rn50_384(
30
+ use_pretrained,
31
+ hooks=hooks,
32
+ use_vit_only=use_vit_only,
33
+ use_readout=use_readout,
34
+ )
35
+ scratch = _make_scratch(
36
+ [256, 512, 768, 768], features, groups=groups,
37
+ expand=expand) # ViT-H/16 - 85.0% Top1 (backbone)
38
+ elif backbone == 'vitb16_384':
39
+ pretrained = _make_pretrained_vitb16_384(use_pretrained,
40
+ hooks=hooks,
41
+ use_readout=use_readout)
42
+ scratch = _make_scratch(
43
+ [96, 192, 384, 768], features, groups=groups,
44
+ expand=expand) # ViT-B/16 - 84.6% Top1 (backbone)
45
+ elif backbone == 'resnext101_wsl':
46
+ pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
47
+ scratch = _make_scratch([256, 512, 1024, 2048],
48
+ features,
49
+ groups=groups,
50
+ expand=expand) # efficientnet_lite3
51
+ elif backbone == 'efficientnet_lite3':
52
+ pretrained = _make_pretrained_efficientnet_lite3(use_pretrained,
53
+ exportable=exportable)
54
+ scratch = _make_scratch([32, 48, 136, 384],
55
+ features,
56
+ groups=groups,
57
+ expand=expand) # efficientnet_lite3
58
+ else:
59
+ raise NotImplementedError(f"Backbone '{backbone}' not implemented")
61
+
62
+ return pretrained, scratch
63
+
64
+
65
+ def _make_scratch(in_shape, out_shape, groups=1, expand=False):
66
+ scratch = nn.Module()
67
+
68
+ out_shape1 = out_shape
69
+ out_shape2 = out_shape
70
+ out_shape3 = out_shape
71
+ out_shape4 = out_shape
72
+ if expand is True:
73
+ out_shape1 = out_shape
74
+ out_shape2 = out_shape * 2
75
+ out_shape3 = out_shape * 4
76
+ out_shape4 = out_shape * 8
77
+
78
+ scratch.layer1_rn = nn.Conv2d(in_shape[0],
79
+ out_shape1,
80
+ kernel_size=3,
81
+ stride=1,
82
+ padding=1,
83
+ bias=False,
84
+ groups=groups)
85
+ scratch.layer2_rn = nn.Conv2d(in_shape[1],
86
+ out_shape2,
87
+ kernel_size=3,
88
+ stride=1,
89
+ padding=1,
90
+ bias=False,
91
+ groups=groups)
92
+ scratch.layer3_rn = nn.Conv2d(in_shape[2],
93
+ out_shape3,
94
+ kernel_size=3,
95
+ stride=1,
96
+ padding=1,
97
+ bias=False,
98
+ groups=groups)
99
+ scratch.layer4_rn = nn.Conv2d(in_shape[3],
100
+ out_shape4,
101
+ kernel_size=3,
102
+ stride=1,
103
+ padding=1,
104
+ bias=False,
105
+ groups=groups)
106
+
107
+ return scratch
108
+
109
+
110
+ def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
111
+ efficientnet = torch.hub.load('rwightman/gen-efficientnet-pytorch',
112
+ 'tf_efficientnet_lite3',
113
+ pretrained=use_pretrained,
114
+ exportable=exportable)
115
+ return _make_efficientnet_backbone(efficientnet)
116
+
117
+
118
+ def _make_efficientnet_backbone(effnet):
119
+ pretrained = nn.Module()
120
+
121
+ pretrained.layer1 = nn.Sequential(effnet.conv_stem, effnet.bn1,
122
+ effnet.act1, *effnet.blocks[0:2])
123
+ pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
124
+ pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
125
+ pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
126
+
127
+ return pretrained
128
+
129
+
130
+ def _make_resnet_backbone(resnet):
131
+ pretrained = nn.Module()
132
+ pretrained.layer1 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu,
133
+ resnet.maxpool, resnet.layer1)
134
+
135
+ pretrained.layer2 = resnet.layer2
136
+ pretrained.layer3 = resnet.layer3
137
+ pretrained.layer4 = resnet.layer4
138
+
139
+ return pretrained
140
+
141
+
142
+ def _make_pretrained_resnext101_wsl(use_pretrained):
143
+ resnet = torch.hub.load('facebookresearch/WSL-Images',
144
+ 'resnext101_32x8d_wsl')
145
+ return _make_resnet_backbone(resnet)
146
+
147
+
148
+ class Interpolate(nn.Module):
149
+ """Interpolation module.
150
+ """
151
+ def __init__(self, scale_factor, mode, align_corners=False):
152
+ """Init.
153
+
154
+ Args:
155
+ scale_factor (float): scaling
156
+ mode (str): interpolation mode
157
+ """
158
+ super(Interpolate, self).__init__()
159
+
160
+ self.interp = nn.functional.interpolate
161
+ self.scale_factor = scale_factor
162
+ self.mode = mode
163
+ self.align_corners = align_corners
164
+
165
+ def forward(self, x):
166
+ """Forward pass.
167
+
168
+ Args:
169
+ x (tensor): input
170
+
171
+ Returns:
172
+ tensor: interpolated data
173
+ """
174
+
175
+ x = self.interp(x,
176
+ scale_factor=self.scale_factor,
177
+ mode=self.mode,
178
+ align_corners=self.align_corners)
179
+
180
+ return x
181
+
182
+
183
+ class ResidualConvUnit(nn.Module):
184
+ """Residual convolution module.
185
+ """
186
+ def __init__(self, features):
187
+ """Init.
188
+
189
+ Args:
190
+ features (int): number of features
191
+ """
192
+ super().__init__()
193
+
194
+ self.conv1 = nn.Conv2d(features,
195
+ features,
196
+ kernel_size=3,
197
+ stride=1,
198
+ padding=1,
199
+ bias=True)
200
+
201
+ self.conv2 = nn.Conv2d(features,
202
+ features,
203
+ kernel_size=3,
204
+ stride=1,
205
+ padding=1,
206
+ bias=True)
207
+
208
+ self.relu = nn.ReLU(inplace=True)
209
+
210
+ def forward(self, x):
211
+ """Forward pass.
212
+
213
+ Args:
214
+ x (tensor): input
215
+
216
+ Returns:
217
+ tensor: output
218
+ """
219
+ out = self.relu(x)
220
+ out = self.conv1(out)
221
+ out = self.relu(out)
222
+ out = self.conv2(out)
223
+
224
+ return out + x
225
+
226
+
227
+ class FeatureFusionBlock(nn.Module):
228
+ """Feature fusion block.
229
+ """
230
+ def __init__(self, features):
231
+ """Init.
232
+
233
+ Args:
234
+ features (int): number of features
235
+ """
236
+ super(FeatureFusionBlock, self).__init__()
237
+
238
+ self.resConfUnit1 = ResidualConvUnit(features)
239
+ self.resConfUnit2 = ResidualConvUnit(features)
240
+
241
+ def forward(self, *xs):
242
+ """Forward pass.
243
+
244
+ Returns:
245
+ tensor: output
246
+ """
247
+ output = xs[0]
248
+
249
+ if len(xs) == 2:
250
+ output += self.resConfUnit1(xs[1])
251
+
252
+ output = self.resConfUnit2(output)
253
+
254
+ output = nn.functional.interpolate(output,
255
+ scale_factor=2,
256
+ mode='bilinear',
257
+ align_corners=True)
258
+
259
+ return output
260
+
261
+
262
+ class ResidualConvUnit_custom(nn.Module):
263
+ """Residual convolution module.
264
+ """
265
+ def __init__(self, features, activation, bn):
266
+ """Init.
267
+
268
+ Args:
269
+ features (int): number of features
270
+ """
271
+ super().__init__()
272
+
273
+ self.bn = bn
274
+
275
+ self.groups = 1
276
+
277
+ self.conv1 = nn.Conv2d(features,
278
+ features,
279
+ kernel_size=3,
280
+ stride=1,
281
+ padding=1,
282
+ bias=True,
283
+ groups=self.groups)
284
+
285
+ self.conv2 = nn.Conv2d(features,
286
+ features,
287
+ kernel_size=3,
288
+ stride=1,
289
+ padding=1,
290
+ bias=True,
291
+ groups=self.groups)
292
+
293
+ if self.bn is True:
294
+ self.bn1 = nn.BatchNorm2d(features)
295
+ self.bn2 = nn.BatchNorm2d(features)
296
+
297
+ self.activation = activation
298
+
299
+ self.skip_add = nn.quantized.FloatFunctional()
300
+
301
+ def forward(self, x):
302
+ """Forward pass.
303
+
304
+ Args:
305
+ x (tensor): input
306
+
307
+ Returns:
308
+ tensor: output
309
+ """
310
+
311
+ out = self.activation(x)
312
+ out = self.conv1(out)
313
+ if self.bn is True:
314
+ out = self.bn1(out)
315
+
316
+ out = self.activation(out)
317
+ out = self.conv2(out)
318
+ if self.bn is True:
319
+ out = self.bn2(out)
320
+
321
+ if self.groups > 1:
322
+ out = self.conv_merge(out)
323
+
324
+ return self.skip_add.add(out, x)
325
+
326
+ # return out + x
327
+
328
+
329
+ class FeatureFusionBlock_custom(nn.Module):
330
+ """Feature fusion block.
331
+ """
332
+ def __init__(self,
333
+ features,
334
+ activation,
335
+ deconv=False,
336
+ bn=False,
337
+ expand=False,
338
+ align_corners=True):
339
+ """Init.
340
+
341
+ Args:
342
+ features (int): number of features
343
+ """
344
+ super(FeatureFusionBlock_custom, self).__init__()
345
+
346
+ self.deconv = deconv
347
+ self.align_corners = align_corners
348
+
349
+ self.groups = 1
350
+
351
+ self.expand = expand
352
+ out_features = features
353
+ if self.expand is True:
354
+ out_features = features // 2
355
+
356
+ self.out_conv = nn.Conv2d(features,
357
+ out_features,
358
+ kernel_size=1,
359
+ stride=1,
360
+ padding=0,
361
+ bias=True,
362
+ groups=1)
363
+
364
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
365
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
366
+
367
+ self.skip_add = nn.quantized.FloatFunctional()
368
+
369
+ def forward(self, *xs):
370
+ """Forward pass.
371
+
372
+ Returns:
373
+ tensor: output
374
+ """
375
+ output = xs[0]
376
+
377
+ if len(xs) == 2:
378
+ res = self.resConfUnit1(xs[1])
379
+ output = self.skip_add.add(output, res)
380
+ # output += res
381
+
382
+ output = self.resConfUnit2(output)
383
+
384
+ output = nn.functional.interpolate(output,
385
+ scale_factor=2,
386
+ mode='bilinear',
387
+ align_corners=self.align_corners)
388
+
389
+ output = self.out_conv(output)
390
+
391
+ return output
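
To make the decoder contract in this file concrete, here is a small shape check for `FeatureFusionBlock_custom`. This is only a sketch; it assumes `timm` is installed, since importing `blocks` pulls in the ViT helpers from `.vit`.

```python
import torch
import torch.nn as nn
from vace.annotators.midas.blocks import FeatureFusionBlock_custom

block = FeatureFusionBlock_custom(features=64, activation=nn.ReLU(False), bn=False)
deep = torch.randn(1, 64, 12, 12)   # coarser decoder path
skip = torch.randn(1, 64, 12, 12)   # encoder skip feature at the same resolution
out = block(deep, skip)
print(out.shape)                    # torch.Size([1, 64, 24, 24]): fused, then upsampled 2x
```
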
vace/annotators/midas/dpt_depth.py ADDED
@@ -0,0 +1,107 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from .base_model import BaseModel
7
+ from .blocks import FeatureFusionBlock_custom, Interpolate, _make_encoder
8
+ from .vit import forward_vit
9
+
10
+
11
+ def _make_fusion_block(features, use_bn):
12
+ return FeatureFusionBlock_custom(
13
+ features,
14
+ nn.ReLU(False),
15
+ deconv=False,
16
+ bn=use_bn,
17
+ expand=False,
18
+ align_corners=True,
19
+ )
20
+
21
+
22
+ class DPT(BaseModel):
23
+ def __init__(
24
+ self,
25
+ head,
26
+ features=256,
27
+ backbone='vitb_rn50_384',
28
+ readout='project',
29
+ channels_last=False,
30
+ use_bn=False,
31
+ ):
32
+
33
+ super(DPT, self).__init__()
34
+
35
+ self.channels_last = channels_last
36
+
37
+ hooks = {
38
+ 'vitb_rn50_384': [0, 1, 8, 11],
39
+ 'vitb16_384': [2, 5, 8, 11],
40
+ 'vitl16_384': [5, 11, 17, 23],
41
+ }
42
+
43
+ # Instantiate backbone and reassemble blocks
44
+ self.pretrained, self.scratch = _make_encoder(
45
+ backbone,
46
+ features,
47
+ False,  # use_pretrained: set to True to load ImageNet-pretrained backbone weights
48
+ groups=1,
49
+ expand=False,
50
+ exportable=False,
51
+ hooks=hooks[backbone],
52
+ use_readout=readout,
53
+ )
54
+
55
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
56
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
57
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
58
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
59
+
60
+ self.scratch.output_conv = head
61
+
62
+ def forward(self, x):
63
+ if self.channels_last is True:
64
+ x.contiguous(memory_format=torch.channels_last)
65
+
66
+ layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
67
+
68
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
69
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
70
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
71
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
72
+
73
+ path_4 = self.scratch.refinenet4(layer_4_rn)
74
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
75
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
76
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
77
+
78
+ out = self.scratch.output_conv(path_1)
79
+
80
+ return out
81
+
82
+
83
+ class DPTDepthModel(DPT):
84
+ def __init__(self, path=None, non_negative=True, **kwargs):
85
+ features = kwargs['features'] if 'features' in kwargs else 256
86
+
87
+ head = nn.Sequential(
88
+ nn.Conv2d(features,
89
+ features // 2,
90
+ kernel_size=3,
91
+ stride=1,
92
+ padding=1),
93
+ Interpolate(scale_factor=2, mode='bilinear', align_corners=True),
94
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
95
+ nn.ReLU(True),
96
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
97
+ nn.ReLU(True) if non_negative else nn.Identity(),
98
+ nn.Identity(),
99
+ )
100
+
101
+ super().__init__(head, **kwargs)
102
+
103
+ if path is not None:
104
+ self.load(path)
105
+
106
+ def forward(self, x):
107
+ return super().forward(x).squeeze(dim=1)
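
A quick wiring check for `DPTDepthModel`, as a sketch only: it assumes `timm` is available to build the hybrid ViT backbone, and with `path=None` the weights are random, so only the shapes are meaningful.

```python
import torch
from vace.annotators.midas.dpt_depth import DPTDepthModel

model = DPTDepthModel(path=None, non_negative=True).eval()   # default vitb_rn50_384 backbone
with torch.no_grad():
    depth = model(torch.randn(1, 3, 384, 384))
print(depth.shape)   # torch.Size([1, 384, 384]): one inverse-depth map per input image
```
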
vace/annotators/midas/midas_net.py ADDED
@@ -0,0 +1,80 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ """MidasNet: Network for monocular depth estimation trained by mixing several datasets.
4
+ This file contains code that is adapted from
5
+ https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
6
+ """
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from .base_model import BaseModel
11
+ from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
12
+
13
+
14
+ class MidasNet(BaseModel):
15
+ """Network for monocular depth estimation.
16
+ """
17
+ def __init__(self, path=None, features=256, non_negative=True):
18
+ """Init.
19
+
20
+ Args:
21
+ path (str, optional): Path to saved model. Defaults to None.
22
+ features (int, optional): Number of features. Defaults to 256.
23
+ non_negative (bool, optional): If True, clamp the output to be non-negative. Defaults to True.
24
+ """
25
+ print('Loading weights: ', path)
26
+
27
+ super(MidasNet, self).__init__()
28
+
29
+ use_pretrained = False if path is None else True
30
+
31
+ self.pretrained, self.scratch = _make_encoder(
32
+ backbone='resnext101_wsl',
33
+ features=features,
34
+ use_pretrained=use_pretrained)
35
+
36
+ self.scratch.refinenet4 = FeatureFusionBlock(features)
37
+ self.scratch.refinenet3 = FeatureFusionBlock(features)
38
+ self.scratch.refinenet2 = FeatureFusionBlock(features)
39
+ self.scratch.refinenet1 = FeatureFusionBlock(features)
40
+
41
+ self.scratch.output_conv = nn.Sequential(
42
+ nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
43
+ Interpolate(scale_factor=2, mode='bilinear'),
44
+ nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
45
+ nn.ReLU(True),
46
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
47
+ nn.ReLU(True) if non_negative else nn.Identity(),
48
+ )
49
+
50
+ if path:
51
+ self.load(path)
52
+
53
+ def forward(self, x):
54
+ """Forward pass.
55
+
56
+ Args:
57
+ x (tensor): input data (image)
58
+
59
+ Returns:
60
+ tensor: depth
61
+ """
62
+
63
+ layer_1 = self.pretrained.layer1(x)
64
+ layer_2 = self.pretrained.layer2(layer_1)
65
+ layer_3 = self.pretrained.layer3(layer_2)
66
+ layer_4 = self.pretrained.layer4(layer_3)
67
+
68
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
69
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
70
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
71
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
72
+
73
+ path_4 = self.scratch.refinenet4(layer_4_rn)
74
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
75
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
76
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
77
+
78
+ out = self.scratch.output_conv(path_1)
79
+
80
+ return torch.squeeze(out, dim=1)
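
The decoder loop in `MidasNet.forward` relies on each `FeatureFusionBlock` doubling the spatial resolution; below is a small sketch of that contract (it assumes `timm` is installed, because importing `blocks` pulls in the ViT helpers).

```python
import torch
from vace.annotators.midas.blocks import FeatureFusionBlock

fuse4 = FeatureFusionBlock(256)
fuse3 = FeatureFusionBlock(256)
layer_4_rn = torch.randn(1, 256, 12, 12)   # deepest projected encoder feature
layer_3_rn = torch.randn(1, 256, 24, 24)   # next encoder feature, twice the resolution
path_4 = fuse4(layer_4_rn)                 # -> (1, 256, 24, 24)
path_3 = fuse3(path_4, layer_3_rn)         # -> (1, 256, 48, 48)
```
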
vace/annotators/midas/midas_net_custom.py ADDED
@@ -0,0 +1,167 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ """MidasNet: Network for monocular depth estimation trained by mixing several datasets.
4
+ This file contains code that is adapted from
5
+ https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
6
+ """
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from .base_model import BaseModel
11
+ from .blocks import FeatureFusionBlock_custom, Interpolate, _make_encoder
12
+
13
+
14
+ class MidasNet_small(BaseModel):
15
+ """Network for monocular depth estimation.
16
+ """
17
+ def __init__(self,
18
+ path=None,
19
+ features=64,
20
+ backbone='efficientnet_lite3',
21
+ non_negative=True,
22
+ exportable=True,
23
+ channels_last=False,
24
+ align_corners=True,
25
+ blocks={'expand': True}):
26
+ """Init.
27
+
28
+ Args:
29
+ path (str, optional): Path to saved model. Defaults to None.
30
+ features (int, optional): Number of features. Defaults to 64.
31
+ backbone (str, optional): Backbone network for encoder. Defaults to efficientnet_lite3.
32
+ """
33
+ print('Loading weights: ', path)
34
+
35
+ super(MidasNet_small, self).__init__()
36
+
37
+ use_pretrained = False if path else True
38
+
39
+ self.channels_last = channels_last
40
+ self.blocks = blocks
41
+ self.backbone = backbone
42
+
43
+ self.groups = 1
44
+
45
+ features1 = features
46
+ features2 = features
47
+ features3 = features
48
+ features4 = features
49
+ self.expand = False
50
+ if 'expand' in self.blocks and self.blocks['expand'] is True:
51
+ self.expand = True
52
+ features1 = features
53
+ features2 = features * 2
54
+ features3 = features * 4
55
+ features4 = features * 8
56
+
57
+ self.pretrained, self.scratch = _make_encoder(self.backbone,
58
+ features,
59
+ use_pretrained,
60
+ groups=self.groups,
61
+ expand=self.expand,
62
+ exportable=exportable)
63
+
64
+ self.scratch.activation = nn.ReLU(False)
65
+
66
+ self.scratch.refinenet4 = FeatureFusionBlock_custom(
67
+ features4,
68
+ self.scratch.activation,
69
+ deconv=False,
70
+ bn=False,
71
+ expand=self.expand,
72
+ align_corners=align_corners)
73
+ self.scratch.refinenet3 = FeatureFusionBlock_custom(
74
+ features3,
75
+ self.scratch.activation,
76
+ deconv=False,
77
+ bn=False,
78
+ expand=self.expand,
79
+ align_corners=align_corners)
80
+ self.scratch.refinenet2 = FeatureFusionBlock_custom(
81
+ features2,
82
+ self.scratch.activation,
83
+ deconv=False,
84
+ bn=False,
85
+ expand=self.expand,
86
+ align_corners=align_corners)
87
+ self.scratch.refinenet1 = FeatureFusionBlock_custom(
88
+ features1,
89
+ self.scratch.activation,
90
+ deconv=False,
91
+ bn=False,
92
+ align_corners=align_corners)
93
+
94
+ self.scratch.output_conv = nn.Sequential(
95
+ nn.Conv2d(features,
96
+ features // 2,
97
+ kernel_size=3,
98
+ stride=1,
99
+ padding=1,
100
+ groups=self.groups),
101
+ Interpolate(scale_factor=2, mode='bilinear'),
102
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
103
+ self.scratch.activation,
104
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
105
+ nn.ReLU(True) if non_negative else nn.Identity(),
106
+ nn.Identity(),
107
+ )
108
+
109
+ if path:
110
+ self.load(path)
111
+
112
+ def forward(self, x):
113
+ """Forward pass.
114
+
115
+ Args:
116
+ x (tensor): input data (image)
117
+
118
+ Returns:
119
+ tensor: depth
120
+ """
121
+ if self.channels_last is True:
122
+ print('self.channels_last = ', self.channels_last)
123
+ x.contiguous(memory_format=torch.channels_last)
124
+
125
+ layer_1 = self.pretrained.layer1(x)
126
+ layer_2 = self.pretrained.layer2(layer_1)
127
+ layer_3 = self.pretrained.layer3(layer_2)
128
+ layer_4 = self.pretrained.layer4(layer_3)
129
+
130
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
131
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
132
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
133
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
134
+
135
+ path_4 = self.scratch.refinenet4(layer_4_rn)
136
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
137
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
138
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
139
+
140
+ out = self.scratch.output_conv(path_1)
141
+
142
+ return torch.squeeze(out, dim=1)
143
+
144
+
145
+ def fuse_model(m):
146
+ prev_previous_type = nn.Identity()
147
+ prev_previous_name = ''
148
+ previous_type = nn.Identity()
149
+ previous_name = ''
150
+ for name, module in m.named_modules():
151
+ if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(
152
+ module) == nn.ReLU:
153
+ # print("FUSED ", prev_previous_name, previous_name, name)
154
+ torch.quantization.fuse_modules(
155
+ m, [prev_previous_name, previous_name, name], inplace=True)
156
+ elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
157
+ # print("FUSED ", prev_previous_name, previous_name)
158
+ torch.quantization.fuse_modules(
159
+ m, [prev_previous_name, previous_name], inplace=True)
160
+ # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
161
+ # print("FUSED ", previous_name, name)
162
+ # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
163
+
164
+ prev_previous_type = previous_type
165
+ prev_previous_name = previous_name
166
+ previous_type = type(module)
167
+ previous_name = name
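
`fuse_model` folds Conv+BN(+ReLU) triples in preparation for quantization. A self-contained sketch on a toy module follows (fusion is only valid in eval mode; importing this module also pulls in the ViT helpers, so `timm` is assumed to be installed).

```python
import torch.nn as nn
from vace.annotators.midas.midas_net_custom import fuse_model

m = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
m.eval()        # folding BN into the conv is only valid outside training
fuse_model(m)
print(m)        # BN is folded into the conv; the BN/ReLU slots become Identity
```
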
vace/annotators/midas/transforms.py ADDED
@@ -0,0 +1,231 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import math
4
+
5
+ import cv2
6
+ import numpy as np
7
+
8
+
9
+ """Resize the sample to be at least as large as the given size. Keeps aspect ratio.
10
+ """Rezise the sample to ensure the given size. Keeps aspect ratio.
11
+
12
+ Args:
13
+ sample (dict): sample
14
+ size (tuple): image size
15
+
16
+ Returns:
17
+ tuple: new size
18
+ """
19
+ shape = list(sample['disparity'].shape)
20
+
21
+ if shape[0] >= size[0] and shape[1] >= size[1]:
22
+ return sample
23
+
24
+ scale = [0, 0]
25
+ scale[0] = size[0] / shape[0]
26
+ scale[1] = size[1] / shape[1]
27
+
28
+ scale = max(scale)
29
+
30
+ shape[0] = math.ceil(scale * shape[0])
31
+ shape[1] = math.ceil(scale * shape[1])
32
+
33
+ # resize
34
+ sample['image'] = cv2.resize(sample['image'],
35
+ tuple(shape[::-1]),
36
+ interpolation=image_interpolation_method)
37
+
38
+ sample['disparity'] = cv2.resize(sample['disparity'],
39
+ tuple(shape[::-1]),
40
+ interpolation=cv2.INTER_NEAREST)
41
+ sample['mask'] = cv2.resize(
42
+ sample['mask'].astype(np.float32),
43
+ tuple(shape[::-1]),
44
+ interpolation=cv2.INTER_NEAREST,
45
+ )
46
+ sample['mask'] = sample['mask'].astype(bool)
47
+
48
+ return tuple(shape)
49
+
50
+
51
+ class Resize(object):
52
+ """Resize sample to given size (width, height).
53
+ """
54
+ def __init__(
55
+ self,
56
+ width,
57
+ height,
58
+ resize_target=True,
59
+ keep_aspect_ratio=False,
60
+ ensure_multiple_of=1,
61
+ resize_method='lower_bound',
62
+ image_interpolation_method=cv2.INTER_AREA,
63
+ ):
64
+ """Init.
65
+
66
+ Args:
67
+ width (int): desired output width
68
+ height (int): desired output height
69
+ resize_target (bool, optional):
70
+ True: Resize the full sample (image, mask, target).
71
+ False: Resize image only.
72
+ Defaults to True.
73
+ keep_aspect_ratio (bool, optional):
74
+ True: Keep the aspect ratio of the input sample.
75
+ Output sample might not have the given width and height, and
76
+ resize behaviour depends on the parameter 'resize_method'.
77
+ Defaults to False.
78
+ ensure_multiple_of (int, optional):
79
+ Output width and height is constrained to be multiple of this parameter.
80
+ Defaults to 1.
81
+ resize_method (str, optional):
82
+ "lower_bound": Output will be at least as large as the given size.
83
+ "upper_bound": Output will be at most as large as the given size.
84
+ (Output size might be smaller than given size.)
85
+ "minimal": Scale as little as possible. (Output size might be smaller than given size.)
86
+ Defaults to "lower_bound".
87
+ """
88
+ self.__width = width
89
+ self.__height = height
90
+
91
+ self.__resize_target = resize_target
92
+ self.__keep_aspect_ratio = keep_aspect_ratio
93
+ self.__multiple_of = ensure_multiple_of
94
+ self.__resize_method = resize_method
95
+ self.__image_interpolation_method = image_interpolation_method
96
+
97
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
98
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
99
+
100
+ if max_val is not None and y > max_val:
101
+ y = (np.floor(x / self.__multiple_of) *
102
+ self.__multiple_of).astype(int)
103
+
104
+ if y < min_val:
105
+ y = (np.ceil(x / self.__multiple_of) *
106
+ self.__multiple_of).astype(int)
107
+
108
+ return y
109
+
110
+ def get_size(self, width, height):
111
+ # determine new height and width
112
+ scale_height = self.__height / height
113
+ scale_width = self.__width / width
114
+
115
+ if self.__keep_aspect_ratio:
116
+ if self.__resize_method == 'lower_bound':
117
+ # scale such that output size is lower bound
118
+ if scale_width > scale_height:
119
+ # fit width
120
+ scale_height = scale_width
121
+ else:
122
+ # fit height
123
+ scale_width = scale_height
124
+ elif self.__resize_method == 'upper_bound':
125
+ # scale such that output size is upper bound
126
+ if scale_width < scale_height:
127
+ # fit width
128
+ scale_height = scale_width
129
+ else:
130
+ # fit height
131
+ scale_width = scale_height
132
+ elif self.__resize_method == 'minimal':
133
+ # scale as least as possbile
134
+ if abs(1 - scale_width) < abs(1 - scale_height):
135
+ # fit width
136
+ scale_height = scale_width
137
+ else:
138
+ # fit height
139
+ scale_width = scale_height
140
+ else:
141
+ raise ValueError(
142
+ f'resize_method {self.__resize_method} not implemented')
143
+
144
+ if self.__resize_method == 'lower_bound':
145
+ new_height = self.constrain_to_multiple_of(scale_height * height,
146
+ min_val=self.__height)
147
+ new_width = self.constrain_to_multiple_of(scale_width * width,
148
+ min_val=self.__width)
149
+ elif self.__resize_method == 'upper_bound':
150
+ new_height = self.constrain_to_multiple_of(scale_height * height,
151
+ max_val=self.__height)
152
+ new_width = self.constrain_to_multiple_of(scale_width * width,
153
+ max_val=self.__width)
154
+ elif self.__resize_method == 'minimal':
155
+ new_height = self.constrain_to_multiple_of(scale_height * height)
156
+ new_width = self.constrain_to_multiple_of(scale_width * width)
157
+ else:
158
+ raise ValueError(
159
+ f'resize_method {self.__resize_method} not implemented')
160
+
161
+ return (new_width, new_height)
162
+
163
+ def __call__(self, sample):
164
+ width, height = self.get_size(sample['image'].shape[1],
165
+ sample['image'].shape[0])
166
+
167
+ # resize sample
168
+ sample['image'] = cv2.resize(
169
+ sample['image'],
170
+ (width, height),
171
+ interpolation=self.__image_interpolation_method,
172
+ )
173
+
174
+ if self.__resize_target:
175
+ if 'disparity' in sample:
176
+ sample['disparity'] = cv2.resize(
177
+ sample['disparity'],
178
+ (width, height),
179
+ interpolation=cv2.INTER_NEAREST,
180
+ )
181
+
182
+ if 'depth' in sample:
183
+ sample['depth'] = cv2.resize(sample['depth'], (width, height),
184
+ interpolation=cv2.INTER_NEAREST)
185
+
186
+ sample['mask'] = cv2.resize(
187
+ sample['mask'].astype(np.float32),
188
+ (width, height),
189
+ interpolation=cv2.INTER_NEAREST,
190
+ )
191
+ sample['mask'] = sample['mask'].astype(bool)
192
+
193
+ return sample
194
+
195
+
196
+ class NormalizeImage(object):
197
+ """Normalize image by given mean and std.
198
+ """
199
+ def __init__(self, mean, std):
200
+ self.__mean = mean
201
+ self.__std = std
202
+
203
+ def __call__(self, sample):
204
+ sample['image'] = (sample['image'] - self.__mean) / self.__std
205
+
206
+ return sample
207
+
208
+
209
+ class PrepareForNet(object):
210
+ """Prepare sample for usage as network input.
211
+ """
212
+ def __init__(self):
213
+ pass
214
+
215
+ def __call__(self, sample):
216
+ image = np.transpose(sample['image'], (2, 0, 1))
217
+ sample['image'] = np.ascontiguousarray(image).astype(np.float32)
218
+
219
+ if 'mask' in sample:
220
+ sample['mask'] = sample['mask'].astype(np.float32)
221
+ sample['mask'] = np.ascontiguousarray(sample['mask'])
222
+
223
+ if 'disparity' in sample:
224
+ disparity = sample['disparity'].astype(np.float32)
225
+ sample['disparity'] = np.ascontiguousarray(disparity)
226
+
227
+ if 'depth' in sample:
228
+ depth = sample['depth'].astype(np.float32)
229
+ sample['depth'] = np.ascontiguousarray(depth)
230
+
231
+ return sample
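
The three transforms are meant to be chained; here is a runnable sketch on a dummy sample. The ImageNet mean/std values are an example choice, not taken from this file.

```python
import numpy as np
from vace.annotators.midas.transforms import Resize, NormalizeImage, PrepareForNet

resize = Resize(384, 384, resize_target=False, keep_aspect_ratio=True,
                ensure_multiple_of=32, resize_method='lower_bound')
normalize = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
prepare = PrepareForNet()

sample = {'image': np.random.rand(480, 640, 3).astype(np.float32)}
sample = prepare(normalize(resize(sample)))
print(sample['image'].shape, sample['image'].dtype)   # (3, 384, 512) float32: CHW, sides multiples of 32
```
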
vace/annotators/midas/utils.py ADDED
@@ -0,0 +1,193 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ """Utils for monoDepth."""
4
+ import re
5
+ import sys
6
+
7
+ import cv2
8
+ import numpy as np
9
+ import torch
10
+
11
+
12
+ def read_pfm(path):
13
+ """Read pfm file.
14
+
15
+ Args:
16
+ path (str): path to file
17
+
18
+ Returns:
19
+ tuple: (data, scale)
20
+ """
21
+ with open(path, 'rb') as file:
22
+
23
+ color = None
24
+ width = None
25
+ height = None
26
+ scale = None
27
+ endian = None
28
+
29
+ header = file.readline().rstrip()
30
+ if header.decode('ascii') == 'PF':
31
+ color = True
32
+ elif header.decode('ascii') == 'Pf':
33
+ color = False
34
+ else:
35
+ raise Exception('Not a PFM file: ' + path)
36
+
37
+ dim_match = re.match(r'^(\d+)\s(\d+)\s$',
38
+ file.readline().decode('ascii'))
39
+ if dim_match:
40
+ width, height = list(map(int, dim_match.groups()))
41
+ else:
42
+ raise Exception('Malformed PFM header.')
43
+
44
+ scale = float(file.readline().decode('ascii').rstrip())
45
+ if scale < 0:
46
+ # little-endian
47
+ endian = '<'
48
+ scale = -scale
49
+ else:
50
+ # big-endian
51
+ endian = '>'
52
+
53
+ data = np.fromfile(file, endian + 'f')
54
+ shape = (height, width, 3) if color else (height, width)
55
+
56
+ data = np.reshape(data, shape)
57
+ data = np.flipud(data)
58
+
59
+ return data, scale
60
+
61
+
62
+ def write_pfm(path, image, scale=1):
63
+ """Write pfm file.
64
+
65
+ Args:
66
+ path (str): path to file
67
+ image (array): data
68
+ scale (int, optional): Scale. Defaults to 1.
69
+ """
70
+
71
+ with open(path, 'wb') as file:
72
+ color = None
73
+
74
+ if image.dtype.name != 'float32':
75
+ raise Exception('Image dtype must be float32.')
76
+
77
+ image = np.flipud(image)
78
+
79
+ if len(image.shape) == 3 and image.shape[2] == 3: # color image
80
+ color = True
81
+ elif (len(image.shape) == 2
82
+ or len(image.shape) == 3 and image.shape[2] == 1): # greyscale
83
+ color = False
84
+ else:
85
+ raise Exception(
86
+ 'Image must have H x W x 3, H x W x 1 or H x W dimensions.')
87
+
88
+ file.write(('PF\n' if color else 'Pf\n').encode())
89
+ file.write('%d %d\n'.encode() % (image.shape[1], image.shape[0]))
90
+
91
+ endian = image.dtype.byteorder
92
+
93
+ if endian == '<' or endian == '=' and sys.byteorder == 'little':
94
+ scale = -scale
95
+
96
+ file.write('%f\n'.encode() % scale)
97
+
98
+ image.tofile(file)
99
+
100
+
101
+ def read_image(path):
102
+ """Read image and output RGB image (0-1).
103
+
104
+ Args:
105
+ path (str): path to file
106
+
107
+ Returns:
108
+ array: RGB image (0-1)
109
+ """
110
+ img = cv2.imread(path)
111
+
112
+ if img.ndim == 2:
113
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
114
+
115
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
116
+
117
+ return img
118
+
119
+
120
+ def resize_image(img):
121
+ """Resize image and make it fit for network.
122
+
123
+ Args:
124
+ img (array): image
125
+
126
+ Returns:
127
+ tensor: data ready for network
128
+ """
129
+ height_orig = img.shape[0]
130
+ width_orig = img.shape[1]
131
+
132
+ if width_orig > height_orig:
133
+ scale = width_orig / 384
134
+ else:
135
+ scale = height_orig / 384
136
+
137
+ height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
138
+ width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
139
+
140
+ img_resized = cv2.resize(img, (width, height),
141
+ interpolation=cv2.INTER_AREA)
142
+
143
+ img_resized = (torch.from_numpy(np.transpose(
144
+ img_resized, (2, 0, 1))).contiguous().float())
145
+ img_resized = img_resized.unsqueeze(0)
146
+
147
+ return img_resized
148
+
149
+
150
+ def resize_depth(depth, width, height):
151
+ """Resize depth map and bring to CPU (numpy).
152
+
153
+ Args:
154
+ depth (tensor): depth
155
+ width (int): image width
156
+ height (int): image height
157
+
158
+ Returns:
159
+ array: processed depth
160
+ """
161
+ depth = torch.squeeze(depth[0, :, :, :]).to('cpu')
162
+
163
+ depth_resized = cv2.resize(depth.numpy(), (width, height),
164
+ interpolation=cv2.INTER_CUBIC)
165
+
166
+ return depth_resized
167
+
168
+
169
+ def write_depth(path, depth, bits=1):
170
+ """Write depth map to pfm and png file.
171
+
172
+ Args:
173
+ path (str): filepath without extension
174
+ depth (array): depth
175
+ """
176
+ write_pfm(path + '.pfm', depth.astype(np.float32))
177
+
178
+ depth_min = depth.min()
179
+ depth_max = depth.max()
180
+
181
+ max_val = (2**(8 * bits)) - 1
182
+
183
+ if depth_max - depth_min > np.finfo('float').eps:
184
+ out = max_val * (depth - depth_min) / (depth_max - depth_min)
185
+ else:
186
+ out = np.zeros(depth.shape, dtype=depth.dtype)
187
+
188
+ if bits == 1:
189
+ cv2.imwrite(path + '.png', out.astype('uint8'))
190
+ elif bits == 2:
191
+ cv2.imwrite(path + '.png', out.astype('uint16'))
192
+
193
+ return
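
Finally, a small round-trip sketch for the PFM/PNG helpers above; the file paths are placeholders.

```python
import numpy as np
from vace.annotators.midas.utils import read_pfm, write_depth

depth = (np.random.rand(240, 320) * 10.0).astype(np.float32)
write_depth('/tmp/depth_demo', depth, bits=2)     # writes /tmp/depth_demo.pfm and a 16-bit .png
data, scale = read_pfm('/tmp/depth_demo.pfm')
print(data.shape, scale)                          # (240, 320) 1.0
```
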