Upload 94 files
This view is limited to 50 files because it contains too many changes. See the raw diff for the full list.
- .gitattributes +7 -0
- UserGuide.md +160 -0
- app.py +278 -0
- assets/images/test.jpg +3 -0
- assets/images/test2.jpg +0 -0
- assets/images/test3.jpg +3 -0
- assets/masks/test.png +0 -0
- assets/masks/test2.png +0 -0
- assets/materials/gr_infer_demo.jpg +3 -0
- assets/materials/gr_pre_demo.jpg +3 -0
- assets/materials/tasks.png +3 -0
- assets/materials/teaser.jpg +3 -0
- assets/videos/test.mp4 +3 -0
- assets/videos/test2.mp4 +0 -0
- benchmarks/.gitkeep +0 -0
- models/.gitkeep +0 -0
- pyproject.toml +75 -0
- requirements.txt +1 -0
- requirements/annotator.txt +6 -0
- requirements/framework.txt +26 -0
- tests/test_annotators.py +568 -0
- vace/__init__.py +6 -0
- vace/annotators/__init__.py +24 -0
- vace/annotators/canvas.py +60 -0
- vace/annotators/common.py +62 -0
- vace/annotators/composition.py +155 -0
- vace/annotators/depth.py +51 -0
- vace/annotators/dwpose/__init__.py +2 -0
- vace/annotators/dwpose/onnxdet.py +127 -0
- vace/annotators/dwpose/onnxpose.py +362 -0
- vace/annotators/dwpose/util.py +299 -0
- vace/annotators/dwpose/wholebody.py +80 -0
- vace/annotators/face.py +55 -0
- vace/annotators/flow.py +53 -0
- vace/annotators/frameref.py +118 -0
- vace/annotators/gdino.py +88 -0
- vace/annotators/gray.py +24 -0
- vace/annotators/inpainting.py +283 -0
- vace/annotators/layout.py +161 -0
- vace/annotators/mask.py +79 -0
- vace/annotators/maskaug.py +181 -0
- vace/annotators/midas/__init__.py +2 -0
- vace/annotators/midas/api.py +166 -0
- vace/annotators/midas/base_model.py +18 -0
- vace/annotators/midas/blocks.py +391 -0
- vace/annotators/midas/dpt_depth.py +107 -0
- vace/annotators/midas/midas_net.py +80 -0
- vace/annotators/midas/midas_net_custom.py +167 -0
- vace/annotators/midas/transforms.py +231 -0
- vace/annotators/midas/utils.py +193 -0
.gitattributes
CHANGED
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/images/test.jpg filter=lfs diff=lfs merge=lfs -text
+assets/images/test3.jpg filter=lfs diff=lfs merge=lfs -text
+assets/materials/gr_infer_demo.jpg filter=lfs diff=lfs merge=lfs -text
+assets/materials/gr_pre_demo.jpg filter=lfs diff=lfs merge=lfs -text
+assets/materials/tasks.png filter=lfs diff=lfs merge=lfs -text
+assets/materials/teaser.jpg filter=lfs diff=lfs merge=lfs -text
+assets/videos/test.mp4 filter=lfs diff=lfs merge=lfs -text
UserGuide.md
ADDED
@@ -0,0 +1,160 @@
# VACE User Guide

## 1. Overall Steps

- Preparation: Identify the task type of your creative idea ([single task](#33-single-tasks) or [multi-task composition](#34-composition-task)), and prepare all the required materials (images, videos, prompts, etc.).
- Preprocessing: Select the appropriate preprocessing method based on the task name, then preprocess your materials to meet the model's input requirements.
- Inference: Run VACE inference on the preprocessed materials to obtain the results.

## 2. Preparations

### 2.1 Task Definition

VACE, as a unified video generation solution, simultaneously supports Video Generation, Video Editing, and complex composition tasks. Specifically:

- Video Generation: No video input. Concepts are injected into the model through semantic understanding of text and reference materials, covering the **T2V** (Text-to-Video Generation) and **R2V** (Reference-to-Video Generation) tasks.
- Video Editing: With video input. The input video is modified at the pixel level, globally or locally, including **V2V** (Video-to-Video Editing) and **MV2V** (Masked Video-to-Video Editing).
- Composition Task: Compose two or more of the single tasks above into a complex composition task, such as **Reference Anything** (Face R2V + Object R2V), **Move Anything** (Frame R2V + Layout V2V), **Animate Anything** (R2V + Pose V2V), **Swap Anything** (R2V + Inpainting MV2V), and **Expand Anything** (Object R2V + Frame R2V + Outpainting MV2V).

Single tasks and compositional tasks are illustrated in the diagram below:



### 2.2 Limitations

- Very high-resolution videos will be resized to an appropriate spatial size.
- Very long videos will be trimmed or uniformly sampled down to around 5 seconds; a sketch of uniform sampling follows this list.
- If you need longer videos, we recommend generating 5-second clips one by one, using the `firstclip` video extension task to keep temporal consistency.

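To make the sampling behavior concrete, here is a minimal, illustrative sketch of uniform temporal sampling. It is not VACE's actual preprocessing code; the 81-frame target is simply the default clip length used by the Wan demo later in this guide.

```python
import numpy as np

def uniform_sample_indices(total_frames: int, target_frames: int = 81) -> np.ndarray:
    """Pick `target_frames` evenly spaced frame indices from a longer clip."""
    if total_frames <= target_frames:
        return np.arange(total_frames)
    return np.round(np.linspace(0, total_frames - 1, target_frames)).astype(int)

# e.g. a 20 s clip at 30 fps (600 frames) is reduced to 81 frames (~5 s at 16 fps)
indices = uniform_sample_indices(600)
```
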
## 3. Preprocessing

### 3.1 VACE-Recognizable Inputs

User-collected materials need to be preprocessed into VACE-recognizable inputs, namely **`src_video`**, **`src_mask`**, **`src_ref_images`**, and **`prompt`**.
Specific descriptions are as follows:

- `src_video`: The video to be edited, fed into the model, such as a condition video (Depth, Pose, etc.) or an in/outpainting input video. **Gray areas** (pixel value 127) represent missing parts of the video. In the first-frame R2V task, the first frame is the reference frame while subsequent frames are left gray. The missing parts of an in/outpainting `src_video` are also set to gray.
- `src_mask`: A 3D mask with the same shape as `src_video`. **White areas** mark the parts to be generated, while **black areas** mark the parts to be retained.
- `src_ref_images`: Reference images for R2V. Salient object segmentation can be applied to keep the background white.
- `prompt`: A text describing the content of the output video. Prompt expansion can be used to achieve better generation results for LTX-Video and for English users of Wan2.1. Use descriptive prompts instead of instructions.

Among them, `prompt` is required, while `src_video`, `src_mask`, and `src_ref_images` are optional. For instance, the MV2V task requires `src_video`, `src_mask`, and `prompt`, while the R2V task only requires `src_ref_images` and `prompt`.

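As a concrete illustration of these conventions (not part of the VACE API), the sketch below builds an inpainting-style `src_video`/`src_mask` pair as numpy arrays. The frame count and resolution are the Wan demo defaults, and `frames` is a placeholder for the real decoded video.

```python
import numpy as np

num_frames, height, width = 81, 480, 832   # Wan demo defaults
frames = np.zeros((num_frames, height, width, 3), dtype=np.uint8)  # placeholder for real video frames

# Regenerate only the right half of every frame (an inpainting-style input):
src_video = frames.copy()
src_video[:, :, width // 2:] = 127          # gray (127) marks the missing part

src_mask = np.zeros((num_frames, height, width), dtype=np.uint8)
src_mask[:, :, width // 2:] = 255           # white = generate, black = retain
```

In practice, the preprocessing tools described in the next section produce these inputs as video files for you.
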
### 3.2 Preprocessing Tools

Both a command-line interface and a Gradio demo are supported.

1) Command Line: Refer to the `run_vace_preproccess.sh` script and invoke it according to the task type. An example command is as follows:
```bash
python vace/vace_preproccess.py --task depth --video assets/videos/test.mp4
```

2) Gradio Interactive: Launch the graphical interface for data preprocessing and run it interactively. The command is as follows:
```bash
python vace/gradios/preprocess_demo.py
```



### 3.3 Single Tasks

VACE is an all-in-one model supporting various task types, but each task type requires different preprocessing. The specific task types and descriptions are as follows (a Python sketch of calling an annotator directly follows the table):

| Task | Subtask | Annotator | Input modality | Params | Note |
|------------|----------------------|----------------------------|------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------|
| txt2vid | txt2vid | / | / | / | |
| control | depth | DepthVideoAnnotator | video | / | |
| control | flow | FlowVisAnnotator | video | / | |
| control | gray | GrayVideoAnnotator | video | / | |
| control | pose | PoseBodyFaceVideoAnnotator | video | / | |
| control | scribble | ScribbleVideoAnnotator | video | / | |
| control | layout_bbox | LayoutBboxAnnotator | two bboxes <br>'x1,y1,x2,y2 x1,y1,x2,y2' | / | Move linearly from the first box to the second box |
| control | layout_track | LayoutTrackAnnotator | video | mode='masktrack/bboxtrack/label/caption'<br>maskaug_mode(optional)='original/original_expand/hull/hull_expand/bbox/bbox_expand'<br>maskaug_ratio(optional)=0~1.0 | Mode represents different methods of subject tracking. |
| extension | frameref | FrameRefExpandAnnotator | image | mode='firstframe'<br>expand_num=80 (default) | |
| extension | frameref | FrameRefExpandAnnotator | image | mode='lastframe'<br>expand_num=80 (default) | |
| extension | frameref | FrameRefExpandAnnotator | two images<br>a.jpg,b.jpg | mode='firstlastframe'<br>expand_num=80 (default) | Images are separated by commas. |
| extension | clipref | FrameRefExpandAnnotator | video | mode='firstclip'<br>expand_num=80 (default) | |
| extension | clipref | FrameRefExpandAnnotator | video | mode='lastclip'<br>expand_num=80 (default) | |
| extension | clipref | FrameRefExpandAnnotator | two videos<br>a.mp4,b.mp4 | mode='firstlastclip'<br>expand_num=80 (default) | Videos are separated by commas. |
| repainting | inpainting_mask | InpaintingAnnotator | video | mode='salient' | Use salient as a fixed mask. |
| repainting | inpainting_mask | InpaintingAnnotator | video + mask | mode='mask' | Use mask as a fixed mask. |
| repainting | inpainting_bbox | InpaintingAnnotator | video + bbox<br>'x1, y1, x2, y2' | mode='bbox' | Use bbox as a fixed mask. |
| repainting | inpainting_masktrack | InpaintingAnnotator | video | mode='salientmasktrack' | Use salient mask for dynamic tracking. |
| repainting | inpainting_masktrack | InpaintingAnnotator | video | mode='salientbboxtrack' | Use salient bbox for dynamic tracking. |
| repainting | inpainting_masktrack | InpaintingAnnotator | video + mask | mode='masktrack' | Use mask for dynamic tracking. |
| repainting | inpainting_bboxtrack | InpaintingAnnotator | video + bbox<br>'x1, y1, x2, y2' | mode='bboxtrack' | Use bbox for dynamic tracking. |
| repainting | inpainting_label | InpaintingAnnotator | video + label | mode='label' | Use label for dynamic tracking. |
| repainting | inpainting_caption | InpaintingAnnotator | video + caption | mode='caption' | Use caption for dynamic tracking. |
| repainting | outpainting | OutpaintingVideoAnnotator | video | direction=left/right/up/down<br>expand_ratio=0~1.0 | Combine outpainting directions arbitrarily. |
| reference | image_reference | SubjectAnnotator | image | mode='salient/mask/bbox/salientmasktrack/salientbboxtrack/masktrack/bboxtrack/label/caption'<br>maskaug_mode(optional)='original/original_expand/hull/hull_expand/bbox/bbox_expand'<br>maskaug_ratio(optional)=0~1.0 | Use different methods to obtain the subject region. |
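
The annotators listed above can also be called directly from Python instead of through the CLI; the pattern below mirrors `tests/test_annotators.py`, and the checkpoint path assumes the annotator models have been downloaded to `models/VACE-Annotators/`. The output path is a placeholder.

```python
from vace.annotators.depth import DepthVideoAnnotator
from vace.annotators.utils import read_video_frames, save_one_video

# Config and checkpoint path follow tests/test_annotators.py
anno_ins = DepthVideoAnnotator({"PRETRAINED_MODEL": "models/VACE-Annotators/depth/dpt_hybrid-midas-501f0c75.pt"})

frames = read_video_frames("assets/videos/test.mp4")
depth_frames = anno_ins.forward(frames)     # one depth map per input frame
save_one_video("cache/src_video_depth.mp4", depth_frames, fps=16)  # placeholder output path
```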

### 3.4 Composition Task

Moreover, VACE supports combining tasks to accomplish more complex objectives. The following examples illustrate how tasks can be combined, but these combinations are not limited to the examples provided:

| Task | Subtask | Annotator | Input modality | Params | Note |
|-------------|--------------------|----------------------------|--------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------|
| composition | reference_anything | ReferenceAnythingAnnotator | image_list | mode='salientmasktrack/salientbboxtrack/masktrack/bboxtrack/label/caption' | Input no more than three images. |
| composition | animate_anything | AnimateAnythingAnnotator | image + video | mode='salientmasktrack/salientbboxtrack/masktrack/bboxtrack/label/caption' | Video for conditional redrawing; images for reference generation. |
| composition | swap_anything | SwapAnythingAnnotator | image + video | mode='masktrack/bboxtrack/label/caption'<br>maskaug_mode(optional)='original/original_expand/hull/hull_expand/bbox/bbox_expand'<br>maskaug_ratio(optional)=0~1.0 | Video for conditional redrawing; images for reference generation.<br>Comma-separated mode: first for video, second for images. |
| composition | expand_anything | ExpandAnythingAnnotator | image + image_list | mode='masktrack/bboxtrack/label/caption'<br>direction=left/right/up/down<br>expand_ratio=0~1.0<br>expand_num=80 (default) | First image for extension edit; others for reference.<br>Comma-separated mode: first for video, second for images. |
| composition | move_anything | MoveAnythingAnnotator | image + two bboxes | expand_num=80 (default) | First image for initial frame reference; others represented by linear bbox changes. |
| composition | more_anything | ... | ... | ... | ... |

## 4. Model Inference

### 4.1 Execution Methods

Both a command-line interface and a Gradio demo are supported.

1) Command Line: Refer to the `run_vace_ltx.sh` and `run_vace_wan.sh` scripts and invoke them according to the task type. The input data needs to be preprocessed to obtain parameters such as `src_video`, `src_mask`, `src_ref_images`, and `prompt`. An example command is as follows:
```bash
python vace/vace_wan_inference.py --src_video <path-to-src-video> --src_mask <path-to-src-mask> --src_ref_images <paths-to-src-ref-images> --prompt <prompt> # wan
python vace/vace_ltx_inference.py --src_video <path-to-src-video> --src_mask <path-to-src-mask> --src_ref_images <paths-to-src-ref-images> --prompt <prompt> # ltx
```

2) Gradio Interactive: Launch the graphical interface for model inference and run it interactively. The command is as follows:
```bash
python vace/gradios/vace_wan_demo.py # wan
python vace/gradios/vace_ltx_demo.py # ltx
```



3) End-to-End Inference: Refer to the `run_vace_pipeline.sh` script and invoke it according to the task type and input data. This pipeline performs both preprocessing and model inference, so it only requires user-provided materials, but it offers less flexibility. An example command is as follows:
```bash
python vace/vace_pipeline.py --base wan --task depth --video <path-to-video> --prompt <prompt> # wan
python vace/vace_pipeline.py --base ltx --task depth --video <path-to-video> --prompt <prompt> # ltx
```

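For programmatic use, the same Wan inference can be driven from Python. The following is a minimal sketch mirroring the Gradio demo in `app.py`: the checkpoint directory is the demo's default, while the input paths and prompt are placeholders, and `src_video`/`src_mask` are expected to come from the preprocessing step above.

```python
from vace.models.wan.wan_vace import WanVace
from vace.models.wan.configs import WAN_CONFIGS, SIZE_CONFIGS

pipe = WanVace(config=WAN_CONFIGS['vace-1.3B'],
               checkpoint_dir='models/VACE-Wan2.1-1.3B-Preview',
               device_id=0, rank=0, t5_fsdp=False, dit_fsdp=False, use_usp=False)

# Preprocessed inputs (placeholder paths); pass [[]] when no reference images are used
src_video, src_mask, src_ref_images = pipe.prepare_source(
    ['cache/src_video.mp4'], ['cache/src_mask.mp4'], [[]],
    num_frames=81, image_size=SIZE_CONFIGS['480*832'], device=pipe.device)

video = pipe.generate('a descriptive prompt of the output video',
                      src_video, src_mask, src_ref_images,
                      size=(832, 480), sampling_steps=25, guide_scale=6.0,
                      shift=8.0, seed=2025, offload_model=True)
# `video` is a tensor scaled to [-1, 1]; see app.py for how it is written to an .mp4 file
```
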
### 4.2 Inference Examples

We provide test examples for the different tasks so that users can validate them according to their needs. Each example includes the **task**, **sub-task**, **original inputs** (ori_video and ori_image), **model inputs** (src_video, src_mask, src_ref_images, prompt), and **model outputs**.

| task | subtask | src_video | src_mask | src_ref_images | out_video | prompt | ori_video | ori_images |
|-------------|--------------------|----------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| txt2vid | txt2vid | | | | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/txt2vid/out_video.mp4"></video> | 狂风巨浪的大海,镜头缓缓推进,一艘渺小的帆船在汹涌的波涛中挣扎漂荡。海面上白沫翻滚,帆船时隐时现,仿佛随时可能被巨浪吞噬。天空乌云密布,雷声轰鸣,海鸥在空中盘旋尖叫。帆船上的人们紧紧抓住缆绳,努力保持平衡。画面风格写实,充满紧张和动感。近景特写,强调风浪的冲击力和帆船的摇晃 | | |
| extension | firstframe | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/firstframe/src_video.mp4"></video> | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/firstframe/src_mask.mp4"></video> | | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/firstframe/out_video.mp4"></video> | 纪实摄影风格,前景是一位中国越野爱好者坐在越野车上,手持车载电台正在进行通联。他五官清晰,表情专注,眼神坚定地望向前方。越野车停在户外,车身略显脏污,显示出经历过的艰难路况。镜头从车外缓缓拉近,最后定格在人物的面部特写上,展现出他的坚定与热情。中景到近景,动态镜头运镜。 | | <img style="width: auto; height: 200px; object-fit: contain;" src="benchmarks/VACE-Benchmark/assets/examples/firstframe/ori_image_1.png"> |
| repainting | inpainting | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/inpainting/src_video.mp4"></video> | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/inpainting/src_mask.mp4"></video> | | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/inpainting/out_video.mp4"></video> | 一只巨大的金色凤凰从繁华的城市上空展翅飞过,羽毛如火焰般璀璨,闪烁着温暖的光辉,翅膀雄伟地展开。凤凰高昂着头,目光炯炯,轻轻扇动翅膀,散发出淡淡的光芒。下方是熙熙攘攘的市中心,人群惊叹,车水马龙,红蓝两色的霓虹灯在夜空下闪烁。镜头俯视城市街道,捕捉这一壮丽的景象,营造出既神秘又辉煌的氛围。 | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/inpainting/ori_video.mp4"></video> | |
| repainting | outpainting | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/outpainting/src_video.mp4"></video> | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/outpainting/src_mask.mp4"></video> | | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/outpainting/out_video.mp4"></video> | 赛博朋克风格,无人机俯瞰视角下的现代西安城墙,镜头穿过永宁门时泛起金色涟漪,城墙砖块化作数据流重组为唐代长安城。周围的街道上流动的人群和飞驰的机械交通工具交织在一起,现代与古代的交融,城墙上的灯光闪烁,形成时空隧道的效果。全息投影技术展现历史变迁,粒子重组特效细腻逼真。大远景逐渐过渡到特写,聚焦于城门特效。 | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/outpainting/ori_video.mp4"></video> | |
| control | depth | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/depth/src_video.mp4"></video> | | | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/depth/out_video.mp4"></video> | 一群年轻人在天空之城拍摄集体照。画面中,一对年轻情侣手牵手,轻声细语,相视而笑,周围是飞翔的彩色热气球和闪烁的星星,营造出浪漫的氛围。天空中,暖阳透过飘浮的云朵,洒下斑驳的光影。镜头以近景特写开始,随着情侣间的亲密互动,缓缓拉远。 | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/depth/ori_video.mp4"></video> | |
| control | flow | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/flow/src_video.mp4"></video> | | | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/flow/out_video.mp4"></video> | 纪实摄影风格,一颗鲜红的小番茄缓缓落入盛着牛奶的玻璃杯中,溅起晶莹的水花。画面以慢镜头捕捉这一瞬间,水花在空中绽放,形成美丽的弧线。玻璃杯中的牛奶纯白,番茄的鲜红与之形成鲜明对比。背景简洁,突出主体。近景特写,垂直俯视视角,展现细节之美。 | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/flow/ori_video.mp4"></video> | |
| control | gray | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/gray/src_video.mp4"></video> | | | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/gray/out_video.mp4"></video> | 镜头缓缓向右平移,身穿淡黄色坎肩长裙的长发女孩面对镜头露出灿烂的漏齿微笑。她的长发随风轻扬,眼神明亮而充满活力。背景是秋天红色和黄色的树叶,阳光透过树叶的缝隙洒下斑驳光影,营造出温馨自然的氛围。画面风格清新自然,仿佛夏日午后的一抹清凉。中景人像,强调自然光效和细腻的皮肤质感。 | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/gray/ori_video.mp4"></video> | |
| control | pose | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/pose/src_video.mp4"></video> | | | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/pose/out_video.mp4"></video> | 在一个热带的庆祝派对上,一家人围坐在椰子树下的长桌旁。桌上摆满了异国风味的美食。长辈们愉悦地交谈,年轻人兴奋地举杯碰撞,孩子们在沙滩上欢乐奔跑。背景中是湛蓝的海洋和明亮的阳光,营造出轻松的气氛。镜头以动态中景捕捉每个开心的瞬间,温暖的阳光映照着他们幸福的面庞。 | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/pose/ori_video.mp4"></video> | |
| control | scribble | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/scribble/src_video.mp4"></video> | | | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/scribble/out_video.mp4"></video> | 画面中荧光色彩的无人机从极低空高速掠过超现实主义风格的西安古城墙,尘埃反射着阳光。镜头快速切换至城墙上的砖石特写,阳光温暖地洒落,勾勒出每一块砖块的细腻纹理。整体画质清晰华丽,运镜流畅如水。 | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/scribble/ori_video.mp4"></video> | |
| control | layout | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/layout/src_video.mp4"></video> | | | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/layout/out_video.mp4"></video> | 视频展示了一只成鸟在树枝上的巢中喂养它的幼鸟。成鸟在喂食的过程中,幼鸟张开嘴巴等待食物。随后,成鸟飞走,幼鸟继续等待。成鸟再次飞回,带回食物喂养幼鸟。整个视频的拍摄角度固定,聚焦于巢穴和鸟类的互动,背景是模糊的绿色植被,强调了鸟类的自然行为和生态环境。 | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/layout/ori_video.mp4"></video> | |
| reference | face | | | <img style="width: auto; height: 200px; object-fit: contain;" src="benchmarks/VACE-Benchmark/assets/examples/face/src_ref_image_1.png"> | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/face/out_video.mp4"></video> | 视频展示了一位长着尖耳朵的老人,他有一头银白色的长发和小胡子,穿着一件色彩斑斓的长袍,内搭金色衬衫,散发出神秘与智慧的气息。背景为一个华丽宫殿的内部,金碧辉煌。灯光明亮,照亮他脸上的神采奕奕。摄像机旋转动态拍摄,捕捉老人轻松挥手的动作。 | | <img style="width: auto; height: 200px; object-fit: contain;" src="benchmarks/VACE-Benchmark/assets/examples/face/ori_image_1.png"> |
| reference | object | | | <img style="width: auto; height: 200px; object-fit: contain;" src="benchmarks/VACE-Benchmark/assets/examples/object/src_ref_image_1.png"> | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/object/out_video.mp4"></video> | 经典游戏角色马里奥在绿松石色水下世界中,四周环绕着珊瑚和各种各样的热带鱼。马里奥兴奋地向上跳起,摆出经典的欢快姿势,身穿鲜明的蓝色潜水服,红色的潜水面罩上印有“M”标志,脚上是一双潜水靴。背景中,水泡随波逐流,浮现出一个巨大而友好的海星。摄像机从水底向上快速移动,捕捉他跃出水面的瞬间,灯光明亮而流动。该场景融合了动画与幻想元素,令人惊叹。 | | <img style="width: auto; height: 200px; object-fit: contain;" src="benchmarks/VACE-Benchmark/assets/examples/object/ori_image_1.png"> |
| composition | reference_anything | | | <img style="width: auto; height: 200px; object-fit: contain;" src="benchmarks/VACE-Benchmark/assets/examples/reference_anything/src_ref_image_1.png">,<img style="width: auto; height: 200px; object-fit: contain;" src="benchmarks/VACE-Benchmark/assets/examples/reference_anything/src_ref_image_2.png"> | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/reference_anything/out_video.mp4"></video> | 一名打扮成超人的男子自信地站着,面对镜头,肩头有一只充满活力的毛绒黄色鸭子。他留着整齐的短发和浅色胡须,鸭子有橙色的喙和脚,它的翅膀稍微展开,脚分开以保持稳定。他的表情严肃而坚定。他穿着标志性的蓝红超人服装,胸前有黄色“S”标志。斗篷在他身后飘逸。背景有行人。相机位于视线水平,捕捉角色的整个上半身。灯光均匀明亮。 | | <img style="width: auto; height: 200px; object-fit: contain;" src="benchmarks/VACE-Benchmark/assets/examples/reference_anything/ori_image_1.png">,<img style="width: auto; height: 200px; object-fit: contain;" src="benchmarks/VACE-Benchmark/assets/examples/reference_anything/ori_image_2.png"> |
| composition | swap_anything | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/swap_anything/src_video.mp4"></video> | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/swap_anything/src_mask.mp4"></video> | <img style="width: auto; height: 200px; object-fit: contain;" src="benchmarks/VACE-Benchmark/assets/examples/swap_anything/src_ref_image_1.png"> | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/swap_anything/out_video.mp4"></video> | 视频展示了一个人在宽阔的草原上骑马。他有淡紫色长发,穿着传统服饰白上衣黑裤子,动画建模画风,看起来像是在进行某种户外活动或者是在进行某种表演。背景是壮观的山脉和多云的天空,给人一种宁静而广阔的感觉。整个视频的拍摄角度是固定的,重点展示了骑手和他的马。 | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/swap_anything/ori_video.mp4"></video> | <img style="width: auto; height: 200px; object-fit: contain;" src="benchmarks/VACE-Benchmark/assets/examples/swap_anything/ori_image_1.jpg"> |
| composition | expand_anything | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/expand_anything/src_video.mp4"></video> | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/expand_anything/src_mask.mp4"></video> | <img style="width: auto; height: 200px; object-fit: contain;" src="benchmarks/VACE-Benchmark/assets/examples/expand_anything/src_ref_image_1.png"> | <video controls height="200" src="benchmarks/VACE-Benchmark/assets/examples/expand_anything/out_video.mp4"></video> | 古典油画风格,背景是一条河边,画面中央一位成熟优雅的女人,穿着长裙坐在椅子上。她双手从怀里取出打开的红色心形墨镜戴上。固定机位。 | | <img style="width: auto; height: 200px; object-fit: contain;" src="benchmarks/VACE-Benchmark/assets/examples/expand_anything/ori_image_1.jpeg">,<img style="width: auto; height: 200px; object-fit: contain;" src="benchmarks/VACE-Benchmark/assets/examples/expand_anything/ori_image_2.png"> |

## 5. Limitations

- VACE-LTX-Video-0.9
  - The prompt significantly impacts video generation quality with LTX-Video. It must be extended following the method described in this [system prompt](https://huggingface.co/spaces/Lightricks/LTX-Video-Playground/blob/main/assets/system_prompt_i2v.txt). We also provide an input parameter for prompt extension (`--use_prompt_extend`).
  - This model is intended for experimental research validation within the VACE paper and may not guarantee performance in real-world scenarios. However, its inference is very fast (a video takes about 25 seconds with 40 steps on an A100 GPU), making it suitable for preliminary data and creative validation.
- VACE-Wan2.1-1.3B-Preview
  - This model largely preserves the video quality of the original Wan2.1-T2V-1.3B while supporting the various tasks.
  - If you encounter failure cases on specific tasks, we recommend trying again with a different seed and adjusting the prompt.
app.py
ADDED
@@ -0,0 +1,278 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.

import argparse
import os
import sys
import datetime
import imageio
import numpy as np
import torch
import gradio as gr

sys.path.insert(0, os.path.sep.join(os.path.realpath(__file__).split(os.path.sep)[:-3]))
import wan
from vace.models.wan.wan_vace import WanVace
from vace.models.wan.configs import WAN_CONFIGS, SIZE_CONFIGS


class FixedSizeQueue:
    # Keeps the `max_size` most recent items, newest first (used for the shared gallery).
    def __init__(self, max_size):
        self.max_size = max_size
        self.queue = []
    def add(self, item):
        self.queue.insert(0, item)
        if len(self.queue) > self.max_size:
            self.queue.pop()
    def get(self):
        return self.queue
    def __repr__(self):
        return str(self.queue)


class VACEInference:
    def __init__(self, cfg, skip_load=False, gallery_share=True, gallery_share_limit=5):
        self.cfg = cfg
        self.save_dir = cfg.save_dir
        self.gallery_share = gallery_share
        self.gallery_share_data = FixedSizeQueue(max_size=gallery_share_limit)
        if not skip_load:
            self.pipe = WanVace(
                config=WAN_CONFIGS['vace-1.3B'],
                checkpoint_dir=cfg.ckpt_dir,
                device_id=0,
                rank=0,
                t5_fsdp=False,
                dit_fsdp=False,
                use_usp=False,
            )

    def create_ui(self, *args, **kwargs):
        gr.Markdown("""
                    <div style="text-align: center; font-size: 24px; font-weight: bold; margin-bottom: 15px;">
                        <a href="https://ali-vilab.github.io/VACE-Page/" style="text-decoration: none; color: inherit;">VACE-WAN Demo</a>
                    </div>
                    """)
        with gr.Row(variant='panel', equal_height=True):
            with gr.Column(scale=1, min_width=0):
                self.src_video = gr.Video(
                    label="src_video",
                    sources=['upload'],
                    value=None,
                    interactive=True)
            with gr.Column(scale=1, min_width=0):
                self.src_mask = gr.Video(
                    label="src_mask",
                    sources=['upload'],
                    value=None,
                    interactive=True)
        #
        with gr.Row(variant='panel', equal_height=True):
            with gr.Column(scale=1, min_width=0):
                with gr.Row(equal_height=True):
                    self.src_ref_image_1 = gr.Image(label='src_ref_image_1',
                                                    height=200,
                                                    interactive=True,
                                                    type='filepath',
                                                    image_mode='RGB',
                                                    sources=['upload'],
                                                    elem_id="src_ref_image_1",
                                                    format='png')
                    self.src_ref_image_2 = gr.Image(label='src_ref_image_2',
                                                    height=200,
                                                    interactive=True,
                                                    type='filepath',
                                                    image_mode='RGB',
                                                    sources=['upload'],
                                                    elem_id="src_ref_image_2",
                                                    format='png')
                    self.src_ref_image_3 = gr.Image(label='src_ref_image_3',
                                                    height=200,
                                                    interactive=True,
                                                    type='filepath',
                                                    image_mode='RGB',
                                                    sources=['upload'],
                                                    elem_id="src_ref_image_3",
                                                    format='png')
        with gr.Row(variant='panel', equal_height=True):
            with gr.Column(scale=1):
                self.prompt = gr.Textbox(
                    show_label=False,
                    placeholder="positive_prompt_input",
                    elem_id='positive_prompt',
                    container=True,
                    autofocus=True,
                    elem_classes='type_row',
                    visible=True,
                    lines=2)
                self.negative_prompt = gr.Textbox(
                    show_label=False,
                    value=self.pipe.config.sample_neg_prompt,
                    placeholder="negative_prompt_input",
                    elem_id='negative_prompt',
                    container=True,
                    autofocus=False,
                    elem_classes='type_row',
                    visible=True,
                    interactive=True,
                    lines=1)
        #
        with gr.Row(variant='panel', equal_height=True):
            with gr.Column(scale=1, min_width=0):
                with gr.Row(equal_height=True):
                    self.shift_scale = gr.Slider(
                        label='shift_scale',
                        minimum=0.0,
                        maximum=10.0,
                        step=1.0,
                        value=8.0,
                        interactive=True)
                    self.sample_steps = gr.Slider(
                        label='sample_steps',
                        minimum=1,
                        maximum=100,
                        step=1,
                        value=25,
                        interactive=True)
                    self.context_scale = gr.Slider(
                        label='context_scale',
                        minimum=0.0,
                        maximum=2.0,
                        step=0.1,
                        value=1.0,
                        interactive=True)
                    self.guide_scale = gr.Slider(
                        label='guide_scale',
                        minimum=1,
                        maximum=10,
                        step=0.5,
                        value=6.0,
                        interactive=True)
                    self.infer_seed = gr.Slider(minimum=-1,
                                                maximum=10000000,
                                                value=2025,
                                                label="Seed")
        #
        with gr.Accordion(label="Usable without source video", open=False):
            with gr.Row(equal_height=True):
                self.output_height = gr.Textbox(
                    label='resolutions_height',
                    value=480,
                    interactive=True)
                self.output_width = gr.Textbox(
                    label='resolutions_width',
                    value=832,
                    interactive=True)
                self.frame_rate = gr.Textbox(
                    label='frame_rate',
                    value=16,
                    interactive=True)
                self.num_frames = gr.Textbox(
                    label='num_frames',
                    value=81,
                    interactive=True)
        #
        with gr.Row(equal_height=True):
            with gr.Column(scale=5):
                self.generate_button = gr.Button(
                    value='Run',
                    elem_classes='type_row',
                    elem_id='generate_button',
                    visible=True)
            with gr.Column(scale=1):
                self.refresh_button = gr.Button(value='\U0001f504')  # 🔄
        #
        self.output_gallery = gr.Gallery(
            label="output_gallery",
            value=[],
            interactive=False,
            allow_preview=True,
            preview=True)

    def generate(self, output_gallery, src_video, src_mask, src_ref_image_1, src_ref_image_2, src_ref_image_3, prompt, negative_prompt, shift_scale, sample_steps, context_scale, guide_scale, infer_seed, output_height, output_width, frame_rate, num_frames):
        output_height, output_width, frame_rate, num_frames = int(output_height), int(output_width), int(frame_rate), int(num_frames)
        src_ref_images = [x for x in [src_ref_image_1, src_ref_image_2, src_ref_image_3] if
                          x is not None]
        src_video, src_mask, src_ref_images = self.pipe.prepare_source([src_video],
                                                                       [src_mask],
                                                                       [src_ref_images],
                                                                       num_frames=num_frames,
                                                                       image_size=SIZE_CONFIGS[f"{output_height}*{output_width}"],
                                                                       device=self.pipe.device)
        video = self.pipe.generate(
            prompt,
            src_video,
            src_mask,
            src_ref_images,
            size=(output_width, output_height),
            context_scale=context_scale,
            shift=shift_scale,
            sampling_steps=sample_steps,
            guide_scale=guide_scale,
            n_prompt=negative_prompt,
            seed=infer_seed,
            offload_model=True)

        name = '{0:%Y%m%d-%H%M%S}'.format(datetime.datetime.now())
        video_path = os.path.join(self.save_dir, f'cur_gallery_{name}.mp4')
        # Map the generated tensor from [-1, 1] to uint8 HWC frames before writing.
        video_frames = (torch.clamp(video / 2 + 0.5, min=0.0, max=1.0).permute(1, 2, 3, 0) * 255).cpu().numpy().astype(np.uint8)

        try:
            writer = imageio.get_writer(video_path, fps=frame_rate, codec='libx264', quality=8, macro_block_size=1)
            for frame in video_frames:
                writer.append_data(frame)
            writer.close()
            print(video_path)
        except Exception as e:
            raise gr.Error(f"Video save error: {e}")

        if self.gallery_share:
            self.gallery_share_data.add(video_path)
            return self.gallery_share_data.get()
        else:
            return [video_path]

    def set_callbacks(self, **kwargs):
        self.gen_inputs = [self.output_gallery, self.src_video, self.src_mask, self.src_ref_image_1, self.src_ref_image_2, self.src_ref_image_3, self.prompt, self.negative_prompt, self.shift_scale, self.sample_steps, self.context_scale, self.guide_scale, self.infer_seed, self.output_height, self.output_width, self.frame_rate, self.num_frames]
        self.gen_outputs = [self.output_gallery]
        self.generate_button.click(self.generate,
                                   inputs=self.gen_inputs,
                                   outputs=self.gen_outputs,
                                   queue=True)
        self.refresh_button.click(lambda x: self.gallery_share_data.get() if self.gallery_share else x, inputs=[self.output_gallery], outputs=[self.output_gallery])


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Argparser for VACE-WAN Demo:\n')
    parser.add_argument('--server_port', dest='server_port', help='', type=int, default=7860)
    parser.add_argument('--server_name', dest='server_name', help='', default='0.0.0.0')
    parser.add_argument('--root_path', dest='root_path', help='', default=None)
    parser.add_argument('--save_dir', dest='save_dir', help='', default='cache')
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        default='models/VACE-Wan2.1-1.3B-Preview',
        help="The path to the checkpoint directory.",
    )
    parser.add_argument(
        "--offload_to_cpu",
        action="store_true",
        help="Offloading unnecessary computations to CPU.",
    )

    args = parser.parse_args()

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir, exist_ok=True)

    with gr.Blocks() as demo:
        infer_gr = VACEInference(args, skip_load=False, gallery_share=True, gallery_share_limit=5)
        infer_gr.create_ui()
        infer_gr.set_callbacks()
        allowed_paths = [args.save_dir]
        demo.queue(status_update_rate=1).launch(server_name=args.server_name,
                                                server_port=args.server_port,
                                                root_path=args.root_path,
                                                allowed_paths=allowed_paths,
                                                show_error=True, debug=True)
assets/images/test.jpg
ADDED
assets/images/test2.jpg
ADDED
assets/images/test3.jpg
ADDED
assets/masks/test.png
ADDED
assets/masks/test2.png
ADDED
assets/materials/gr_infer_demo.jpg
ADDED
assets/materials/gr_pre_demo.jpg
ADDED
assets/materials/tasks.png
ADDED
assets/materials/teaser.jpg
ADDED
assets/videos/test.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2195efbd92773f1ee262154577c700e9c3b7a4d7d04b1a2ac421db0879c696b0
size 737090
assets/videos/test2.mp4
ADDED
Binary file (79.6 kB)
benchmarks/.gitkeep
ADDED
File without changes
models/.gitkeep
ADDED
File without changes
pyproject.toml
ADDED
@@ -0,0 +1,75 @@
[build-system]
requires = ["setuptools>=42", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "vace"
version = "1.0.0"
description = "VACE: All-in-One Video Creation and Editing"
authors = [
    { name = "VACE Team", email = "[email protected]" }
]
requires-python = ">=3.10,<4.0"
readme = "README.md"
dependencies = [
    "torch>=2.5.1",
    "torchvision>=0.20.1",
    "opencv-python>=4.9.0.80",
    "diffusers>=0.31.0",
    "transformers>=4.49.0",
    "tokenizers>=0.20.3",
    "accelerate>=1.1.1",
    "gradio>=5.0.0",
    "numpy>=1.23.5,<2",
    "tqdm",
    "imageio",
    "easydict",
    "ftfy",
    "dashscope",
    "imageio-ffmpeg",
    "flash_attn",
    "decord",
    "einops",
    "scikit-image",
    "scikit-learn",
    "pycocotools",
    "timm",
    "onnxruntime-gpu",
    "BeautifulSoup4"
]

[project.optional-dependencies]
ltx = [
    "ltx-video@git+https://github.com/Lightricks/[email protected]"
]
wan = [
    "wan@git+https://github.com/Wan-Video/Wan2.1"
]
annotator = [
    "insightface",
    "sam-2@git+https://github.com/facebookresearch/sam2.git",
    "segment-anything@git+https://github.com/facebookresearch/segment-anything.git",
    "groundingdino@git+https://github.com/IDEA-Research/GroundingDINO.git",
    "ram@git+https://github.com/xinyu1205/recognize-anything.git",
    "raft@git+https://github.com/martin-chobanyan-sdc/RAFT.git"
]

[project.urls]
homepage = "https://ali-vilab.github.io/VACE-Page/"
documentation = "https://ali-vilab.github.io/VACE-Page/"
repository = "https://github.com/ali-vilab/VACE"
hfmodel = "https://huggingface.co/collections/ali-vilab/vace-67eca186ff3e3564726aff38"
msmodel = "https://modelscope.cn/collections/VACE-8fa5fcfd386e43"
paper = "https://arxiv.org/abs/2503.07598"

[tool.setuptools]
packages = { find = {} }

[tool.black]
line-length = 88

[tool.isort]
profile = "black"

[tool.mypy]
strict = true
requirements.txt
ADDED
@@ -0,0 +1 @@
-r requirements/framework.txt
requirements/annotator.txt
ADDED
@@ -0,0 +1,6 @@
insightface
git+https://github.com/facebookresearch/sam2.git
git+https://github.com/facebookresearch/segment-anything.git
git+https://github.com/IDEA-Research/GroundingDINO.git
git+https://github.com/xinyu1205/recognize-anything.git
git+https://github.com/martin-chobanyan-sdc/RAFT.git
requirements/framework.txt
ADDED
@@ -0,0 +1,26 @@
torch>=2.5.1
torchvision>=0.20.1
opencv-python>=4.9.0.80
diffusers>=0.31.0
transformers>=4.49.0
tokenizers>=0.20.3
accelerate>=1.1.1
gradio>=5.0.0
numpy>=1.23.5,<2
tqdm
imageio
easydict
ftfy
dashscope
imageio-ffmpeg
flash_attn
decord
einops
scikit-image
scikit-learn
pycocotools
timm
onnxruntime-gpu
BeautifulSoup4
#ltx-video@git+https://github.com/Lightricks/[email protected]
#wan@git+https://github.com/Wan-Video/Wan2.1
tests/test_annotators.py
ADDED
@@ -0,0 +1,568 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import unittest
import numpy as np
from PIL import Image

from vace.annotators.utils import read_video_frames
from vace.annotators.utils import save_one_video

class AnnotatorTest(unittest.TestCase):
    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
        self.save_dir = './cache/test_annotator'
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)
        # load test image
        self.image_path = './assets/images/test.jpg'
        self.image = Image.open(self.image_path).convert('RGB')
        # load test video
        self.video_path = './assets/videos/test.mp4'
        self.frames = read_video_frames(self.video_path)

    def tearDown(self):
        super().tearDown()

    @unittest.skip('')
    def test_annotator_gray_image(self):
        from vace.annotators.gray import GrayAnnotator
        cfg_dict = {}
        anno_ins = GrayAnnotator(cfg_dict)
        anno_image = anno_ins.forward(np.array(self.image))
        save_path = os.path.join(self.save_dir, 'test_gray_image.png')
        Image.fromarray(anno_image).save(save_path)
        print(('Testing %s: %s' % (type(self).__name__, save_path)))

    @unittest.skip('')
    def test_annotator_gray_video(self):
        from vace.annotators.gray import GrayAnnotator
        cfg_dict = {}
        anno_ins = GrayAnnotator(cfg_dict)
        ret_frames = []
        for frame in self.frames:
            anno_frame = anno_ins.forward(np.array(frame))
            ret_frames.append(anno_frame)
        save_path = os.path.join(self.save_dir, 'test_gray_video.mp4')
        save_one_video(save_path, ret_frames, fps=16)
        print(('Testing %s: %s' % (type(self).__name__, save_path)))

    @unittest.skip('')
    def test_annotator_gray_video_2(self):
        from vace.annotators.gray import GrayVideoAnnotator
        cfg_dict = {}
        anno_ins = GrayVideoAnnotator(cfg_dict)
        ret_frames = anno_ins.forward(self.frames)
        save_path = os.path.join(self.save_dir, 'test_gray_video_2.mp4')
        save_one_video(save_path, ret_frames, fps=16)
        print(('Testing %s: %s' % (type(self).__name__, save_path)))


    @unittest.skip('')
    def test_annotator_pose_image(self):
        from vace.annotators.pose import PoseBodyFaceAnnotator
        cfg_dict = {
            "DETECTION_MODEL": "models/VACE-Annotators/pose/yolox_l.onnx",
            "POSE_MODEL": "models/VACE-Annotators/pose/dw-ll_ucoco_384.onnx",
            "RESIZE_SIZE": 1024
        }
        anno_ins = PoseBodyFaceAnnotator(cfg_dict)
        anno_image = anno_ins.forward(np.array(self.image))
        save_path = os.path.join(self.save_dir, 'test_pose_image.png')
        Image.fromarray(anno_image).save(save_path)
        print(('Testing %s: %s' % (type(self).__name__, save_path)))

    @unittest.skip('')
    def test_annotator_pose_video(self):
        from vace.annotators.pose import PoseBodyFaceAnnotator
        cfg_dict = {
            "DETECTION_MODEL": "models/VACE-Annotators/pose/yolox_l.onnx",
            "POSE_MODEL": "models/VACE-Annotators/pose/dw-ll_ucoco_384.onnx",
            "RESIZE_SIZE": 1024
        }
        anno_ins = PoseBodyFaceAnnotator(cfg_dict)
        ret_frames = []
        for frame in self.frames:
            anno_frame = anno_ins.forward(np.array(frame))
            ret_frames.append(anno_frame)
        save_path = os.path.join(self.save_dir, 'test_pose_video.mp4')
        save_one_video(save_path, ret_frames, fps=16)
        print(('Testing %s: %s' % (type(self).__name__, save_path)))

    @unittest.skip('')
    def test_annotator_pose_video_2(self):
        from vace.annotators.pose import PoseBodyFaceVideoAnnotator
        cfg_dict = {
            "DETECTION_MODEL": "models/VACE-Annotators/pose/yolox_l.onnx",
            "POSE_MODEL": "models/VACE-Annotators/pose/dw-ll_ucoco_384.onnx",
            "RESIZE_SIZE": 1024
        }
        anno_ins = PoseBodyFaceVideoAnnotator(cfg_dict)
        ret_frames = anno_ins.forward(self.frames)
        save_path = os.path.join(self.save_dir, 'test_pose_video_2.mp4')
        save_one_video(save_path, ret_frames, fps=16)
        print(('Testing %s: %s' % (type(self).__name__, save_path)))

    @unittest.skip('')
    def test_annotator_depth_image(self):
        from vace.annotators.depth import DepthAnnotator
        cfg_dict = {
            "PRETRAINED_MODEL": "models/VACE-Annotators/depth/dpt_hybrid-midas-501f0c75.pt"
        }
        anno_ins = DepthAnnotator(cfg_dict)
        anno_image = anno_ins.forward(np.array(self.image))
        save_path = os.path.join(self.save_dir, 'test_depth_image.png')
        Image.fromarray(anno_image).save(save_path)
        print(('Testing %s: %s' % (type(self).__name__, save_path)))

    @unittest.skip('')
    def test_annotator_depth_video(self):
        from vace.annotators.depth import DepthAnnotator
        cfg_dict = {
            "PRETRAINED_MODEL": "models/VACE-Annotators/depth/dpt_hybrid-midas-501f0c75.pt"
        }
        anno_ins = DepthAnnotator(cfg_dict)
        ret_frames = []
        for frame in self.frames:
            anno_frame = anno_ins.forward(np.array(frame))
            ret_frames.append(anno_frame)
        save_path = os.path.join(self.save_dir, 'test_depth_video.mp4')
        save_one_video(save_path, ret_frames, fps=16)
        print(('Testing %s: %s' % (type(self).__name__, save_path)))

    @unittest.skip('')
    def test_annotator_depth_video_2(self):
        from vace.annotators.depth import DepthVideoAnnotator
        cfg_dict = {
            "PRETRAINED_MODEL": "models/VACE-Annotators/depth/dpt_hybrid-midas-501f0c75.pt"
        }
        anno_ins = DepthVideoAnnotator(cfg_dict)
        ret_frames = anno_ins.forward(self.frames)
        save_path = os.path.join(self.save_dir, 'test_depth_video_2.mp4')
        save_one_video(save_path, ret_frames, fps=16)
        print(('Testing %s: %s' % (type(self).__name__, save_path)))

    @unittest.skip('')
    def test_annotator_scribble_image(self):
        from vace.annotators.scribble import ScribbleAnnotator
        cfg_dict = {
            "PRETRAINED_MODEL": "models/VACE-Annotators/scribble/anime_style/netG_A_latest.pth"
        }
        anno_ins = ScribbleAnnotator(cfg_dict)
        anno_image = anno_ins.forward(np.array(self.image))
        save_path = os.path.join(self.save_dir, 'test_scribble_image.png')
        Image.fromarray(anno_image).save(save_path)
        print(('Testing %s: %s' % (type(self).__name__, save_path)))

    @unittest.skip('')
    def test_annotator_scribble_video(self):
        from vace.annotators.scribble import ScribbleAnnotator
        cfg_dict = {
            "PRETRAINED_MODEL": "models/VACE-Annotators/scribble/anime_style/netG_A_latest.pth"
        }
        anno_ins = ScribbleAnnotator(cfg_dict)
        ret_frames = []
        for frame in self.frames:
            anno_frame = anno_ins.forward(np.array(frame))
            ret_frames.append(anno_frame)
        save_path = os.path.join(self.save_dir, 'test_scribble_video.mp4')
        save_one_video(save_path, ret_frames, fps=16)
        print(('Testing %s: %s' % (type(self).__name__, save_path)))

    @unittest.skip('')
    def test_annotator_scribble_video_2(self):
        from vace.annotators.scribble import ScribbleVideoAnnotator
        cfg_dict = {
            "PRETRAINED_MODEL": "models/VACE-Annotators/scribble/anime_style/netG_A_latest.pth"
        }
        anno_ins = ScribbleVideoAnnotator(cfg_dict)
        ret_frames = anno_ins.forward(self.frames)
        save_path = os.path.join(self.save_dir, 'test_scribble_video_2.mp4')
        save_one_video(save_path, ret_frames, fps=16)
        print(('Testing %s: %s' % (type(self).__name__, save_path)))

    @unittest.skip('')
    def test_annotator_flow_video(self):
        from vace.annotators.flow import FlowVisAnnotator
        cfg_dict = {
            "PRETRAINED_MODEL": "models/VACE-Annotators/flow/raft-things.pth"
        }
        anno_ins = FlowVisAnnotator(cfg_dict)
        ret_frames = anno_ins.forward(self.frames)
        save_path = os.path.join(self.save_dir, 'test_flow_video.mp4')
        save_one_video(save_path, ret_frames, fps=16)
        print(('Testing %s: %s' % (type(self).__name__, save_path)))

    @unittest.skip('')
    def test_annotator_frameref_video_1(self):
        from vace.annotators.frameref import FrameRefExtractAnnotator
        cfg_dict = {
            "REF_CFG": [{"mode": "first", "proba": 0.1},
                        {"mode": "last", "proba": 0.1},
                        {"mode": "firstlast", "proba": 0.1},
                        {"mode": "random", "proba": 0.1}],
        }
        anno_ins = FrameRefExtractAnnotator(cfg_dict)
        ret_frames, ret_masks = anno_ins.forward(self.frames, ref_num=10)
        save_path = os.path.join(self.save_dir, 'test_frameref_video_1.mp4')
        save_one_video(save_path, ret_frames, fps=16)
        print(('Testing %s: %s' % (type(self).__name__, save_path)))
        save_path = os.path.join(self.save_dir, 'test_frameref_mask_1.mp4')
        save_one_video(save_path, ret_masks, fps=16)
        print(('Testing %s: %s' % (type(self).__name__, save_path)))

    @unittest.skip('')
    def test_annotator_frameref_video_2(self):
        from vace.annotators.frameref import FrameRefExpandAnnotator
        cfg_dict = {}
        anno_ins = FrameRefExpandAnnotator(cfg_dict)
        ret_frames, ret_masks = anno_ins.forward(frames=self.frames, mode='lastclip', expand_num=50)
        save_path = os.path.join(self.save_dir, 'test_frameref_video_2.mp4')
        save_one_video(save_path, ret_frames, fps=16)
        print(('Testing %s: %s' % (type(self).__name__, save_path)))
        save_path = os.path.join(self.save_dir, 'test_frameref_mask_2.mp4')
        save_one_video(save_path, ret_masks, fps=16)
        print(('Testing %s: %s' % (type(self).__name__, save_path)))


    @unittest.skip('')
    def test_annotator_outpainting_1(self):
        from vace.annotators.outpainting import OutpaintingAnnotator
        cfg_dict = {
            "RETURN_MASK": True,
            "KEEP_PADDING_RATIO": 1,
            "MASK_COLOR": "gray"
        }
        anno_ins = OutpaintingAnnotator(cfg_dict)
        ret_data = anno_ins.forward(self.image, direction=['right', 'up', 'down'], expand_ratio=0.5)
        save_path = os.path.join(self.save_dir, 'test_outpainting_image.png')
        Image.fromarray(ret_data['image']).save(save_path)
        print(('Testing %s: %s' % (type(self).__name__, save_path)))
        save_path = os.path.join(self.save_dir, 'test_outpainting_mask.png')
        Image.fromarray(ret_data['mask']).save(save_path)
        print(('Testing %s: %s' % (type(self).__name__, save_path)))

    @unittest.skip('')
    def test_annotator_outpainting_video_1(self):
        from vace.annotators.outpainting import OutpaintingVideoAnnotator
        cfg_dict = {
            "RETURN_MASK": True,
            "KEEP_PADDING_RATIO": 1,
            "MASK_COLOR": "gray"
        }
        anno_ins = OutpaintingVideoAnnotator(cfg_dict)
        ret_data = anno_ins.forward(frames=self.frames, direction=['right', 'up', 'down'], expand_ratio=0.5)
        save_path = os.path.join(self.save_dir, 'test_outpainting_video_1.mp4')
        save_one_video(save_path, ret_data['frames'], fps=16)
        print(('Testing %s: %s' % (type(self).__name__, save_path)))
        save_path = os.path.join(self.save_dir, 'test_outpainting_mask_1.mp4')
        save_one_video(save_path, ret_data['masks'], fps=16)
        print(('Testing %s: %s' % (type(self).__name__, save_path)))

    @unittest.skip('')
    def test_annotator_outpainting_inner_1(self):
        from vace.annotators.outpainting import OutpaintingInnerAnnotator
        cfg_dict = {
            "RETURN_MASK": True,
            "KEEP_PADDING_RATIO": 1,
            "MASK_COLOR": "gray"
        }
        anno_ins = OutpaintingInnerAnnotator(cfg_dict)
        ret_data = anno_ins.forward(self.image, direction=['right', 'up', 'down'], expand_ratio=0.15)
        save_path = os.path.join(self.save_dir, 'test_outpainting_inner_image.png')
        Image.fromarray(ret_data['image']).save(save_path)
        print(('Testing %s: %s' % (type(self).__name__, save_path)))
        save_path = os.path.join(self.save_dir, 'test_outpainting_inner_mask.png')
        Image.fromarray(ret_data['mask']).save(save_path)
        print(('Testing %s: %s' % (type(self).__name__, save_path)))

    @unittest.skip('')
    def test_annotator_outpainting_inner_video_1(self):
        from vace.annotators.outpainting import OutpaintingInnerVideoAnnotator
        cfg_dict = {
            "RETURN_MASK": True,
            "KEEP_PADDING_RATIO": 1,
            "MASK_COLOR": "gray"
        }
        anno_ins = OutpaintingInnerVideoAnnotator(cfg_dict)
        ret_data = anno_ins.forward(self.frames, direction=['right', 'up', 'down'], expand_ratio=0.15)
        save_path = os.path.join(self.save_dir, 'test_outpainting_inner_video_1.mp4')
|
291 |
+
save_one_video(save_path, ret_data['frames'], fps=16)
|
292 |
+
print(('Testing %s: %s' % (type(self).__name__, save_path)))
|
293 |
+
save_path = os.path.join(self.save_dir, 'test_outpainting_inner_mask_1.mp4')
|
294 |
+
save_one_video(save_path, ret_data['masks'], fps=16)
|
295 |
+
print(('Testing %s: %s' % (type(self).__name__, save_path)))
|
296 |
+
|
297 |
+
@unittest.skip('')
|
298 |
+
def test_annotator_salient(self):
|
299 |
+
from vace.annotators.salient import SalientAnnotator
|
300 |
+
cfg_dict = {
|
301 |
+
"PRETRAINED_MODEL": "models/VACE-Annotators/salient/u2net.pt",
|
302 |
+
}
|
303 |
+
anno_ins = SalientAnnotator(cfg_dict)
|
304 |
+
ret_data = anno_ins.forward(self.image)
|
305 |
+
save_path = os.path.join(self.save_dir, 'test_salient_image.png')
|
306 |
+
Image.fromarray(ret_data).save(save_path)
|
307 |
+
print(('Testing %s: %s' % (type(self).__name__, save_path)))
|
308 |
+
|
309 |
+
@unittest.skip('')
|
310 |
+
def test_annotator_salient_video(self):
|
311 |
+
from vace.annotators.salient import SalientVideoAnnotator
|
312 |
+
cfg_dict = {
|
313 |
+
"PRETRAINED_MODEL": "models/VACE-Annotators/salient/u2net.pt",
|
314 |
+
}
|
315 |
+
anno_ins = SalientVideoAnnotator(cfg_dict)
|
316 |
+
ret_frames = anno_ins.forward(self.frames)
|
317 |
+
save_path = os.path.join(self.save_dir, 'test_salient_video.mp4')
|
318 |
+
save_one_video(save_path, ret_frames, fps=16)
|
319 |
+
print(('Testing %s: %s' % (type(self).__name__, save_path)))
|
320 |
+
|
321 |
+
@unittest.skip('')
|
322 |
+
def test_annotator_layout_video(self):
|
323 |
+
from vace.annotators.layout import LayoutBboxAnnotator
|
324 |
+
cfg_dict = {
|
325 |
+
"RAM_TAG_COLOR_PATH": "models/VACE-Annotators/layout/ram_tag_color_list.txt",
|
326 |
+
}
|
327 |
+
anno_ins = LayoutBboxAnnotator(cfg_dict)
|
328 |
+
ret_frames = anno_ins.forward(bbox=[(544, 288, 744, 680), (1112, 240, 1280, 712)], frame_size=(720, 1280), num_frames=49, label='person')
|
329 |
+
save_path = os.path.join(self.save_dir, 'test_layout_video.mp4')
|
330 |
+
save_one_video(save_path, ret_frames, fps=16)
|
331 |
+
print(('Testing %s: %s' % (type(self).__name__, save_path)))
|
332 |
+
|
333 |
+
@unittest.skip('')
|
334 |
+
def test_annotator_layout_mask_video(self):
|
335 |
+
# salient
|
336 |
+
from vace.annotators.salient import SalientVideoAnnotator
|
337 |
+
cfg_dict = {
|
338 |
+
"PRETRAINED_MODEL": "models/VACE-Annotators/salient/u2net.pt",
|
339 |
+
}
|
340 |
+
anno_ins = SalientVideoAnnotator(cfg_dict)
|
341 |
+
salient_frames = anno_ins.forward(self.frames)
|
342 |
+
|
343 |
+
# mask layout
|
344 |
+
from vace.annotators.layout import LayoutMaskAnnotator
|
345 |
+
cfg_dict = {
|
346 |
+
"RAM_TAG_COLOR_PATH": "models/VACE-Annotators/layout/ram_tag_color_list.txt",
|
347 |
+
}
|
348 |
+
anno_ins = LayoutMaskAnnotator(cfg_dict)
|
349 |
+
ret_frames = anno_ins.forward(salient_frames, label='cat')
|
350 |
+
save_path = os.path.join(self.save_dir, 'test_mask_layout_video.mp4')
|
351 |
+
save_one_video(save_path, ret_frames, fps=16)
|
352 |
+
print(('Testing %s: %s' % (type(self).__name__, save_path)))
|
353 |
+
|
354 |
+
@unittest.skip('')
|
355 |
+
def test_annotator_layout_mask_video_2(self):
|
356 |
+
# salient
|
357 |
+
from vace.annotators.salient import SalientVideoAnnotator
|
358 |
+
cfg_dict = {
|
359 |
+
"PRETRAINED_MODEL": "models/VACE-Annotators/salient/u2net.pt",
|
360 |
+
}
|
361 |
+
anno_ins = SalientVideoAnnotator(cfg_dict)
|
362 |
+
salient_frames = anno_ins.forward(self.frames)
|
363 |
+
|
364 |
+
# mask layout
|
365 |
+
from vace.annotators.layout import LayoutMaskAnnotator
|
366 |
+
cfg_dict = {
|
367 |
+
"RAM_TAG_COLOR_PATH": "models/VACE-Annotators/layout/ram_tag_color_list.txt",
|
368 |
+
"USE_AUG": True
|
369 |
+
}
|
370 |
+
anno_ins = LayoutMaskAnnotator(cfg_dict)
|
371 |
+
ret_frames = anno_ins.forward(salient_frames, label='cat', mask_cfg={'mode': 'bbox_expand'})
|
372 |
+
save_path = os.path.join(self.save_dir, 'test_mask_layout_video_2.mp4')
|
373 |
+
save_one_video(save_path, ret_frames, fps=16)
|
374 |
+
print(('Testing %s: %s' % (type(self).__name__, save_path)))
|
375 |
+
|
376 |
+
|
377 |
+
@unittest.skip('')
|
378 |
+
def test_annotator_maskaug_video(self):
|
379 |
+
# salient
|
380 |
+
from vace.annotators.salient import SalientVideoAnnotator
|
381 |
+
cfg_dict = {
|
382 |
+
"PRETRAINED_MODEL": "models/VACE-Annotators/salient/u2net.pt",
|
383 |
+
}
|
384 |
+
anno_ins = SalientVideoAnnotator(cfg_dict)
|
385 |
+
salient_frames = anno_ins.forward(self.frames)
|
386 |
+
|
387 |
+
# mask aug
|
388 |
+
from vace.annotators.maskaug import MaskAugAnnotator
|
389 |
+
cfg_dict = {}
|
390 |
+
anno_ins = MaskAugAnnotator(cfg_dict)
|
391 |
+
ret_frames = anno_ins.forward(salient_frames, mask_cfg={'mode': 'hull_expand'})
|
392 |
+
save_path = os.path.join(self.save_dir, 'test_maskaug_video.mp4')
|
393 |
+
save_one_video(save_path, ret_frames, fps=16)
|
394 |
+
print(('Testing %s: %s' % (type(self).__name__, save_path)))
|
395 |
+
|
396 |
+
|
397 |
+
@unittest.skip('')
|
398 |
+
def test_annotator_ram(self):
|
399 |
+
from vace.annotators.ram import RAMAnnotator
|
400 |
+
cfg_dict = {
|
401 |
+
"TOKENIZER_PATH": "models/VACE-Annotators/ram/bert-base-uncased",
|
402 |
+
"PRETRAINED_MODEL": "models/VACE-Annotators/ram/ram_plus_swin_large_14m.pth",
|
403 |
+
}
|
404 |
+
anno_ins = RAMAnnotator(cfg_dict)
|
405 |
+
ret_data = anno_ins.forward(self.image)
|
406 |
+
print(ret_data)
|
407 |
+
|
408 |
+
@unittest.skip('')
|
409 |
+
def test_annotator_gdino_v1(self):
|
410 |
+
from vace.annotators.gdino import GDINOAnnotator
|
411 |
+
cfg_dict = {
|
412 |
+
"TOKENIZER_PATH": "models/VACE-Annotators/gdino/bert-base-uncased",
|
413 |
+
"CONFIG_PATH": "models/VACE-Annotators/gdino/GroundingDINO_SwinT_OGC_mod.py",
|
414 |
+
"PRETRAINED_MODEL": "models/VACE-Annotators/gdino/groundingdino_swint_ogc.pth",
|
415 |
+
}
|
416 |
+
anno_ins = GDINOAnnotator(cfg_dict)
|
417 |
+
ret_data = anno_ins.forward(self.image, caption="a cat and a vase")
|
418 |
+
print(ret_data)
|
419 |
+
|
420 |
+
@unittest.skip('')
|
421 |
+
def test_annotator_gdino_v2(self):
|
422 |
+
from vace.annotators.gdino import GDINOAnnotator
|
423 |
+
cfg_dict = {
|
424 |
+
"TOKENIZER_PATH": "models/VACE-Annotators/gdino/bert-base-uncased",
|
425 |
+
"CONFIG_PATH": "models/VACE-Annotators/gdino/GroundingDINO_SwinT_OGC_mod.py",
|
426 |
+
"PRETRAINED_MODEL": "models/VACE-Annotators/gdino/groundingdino_swint_ogc.pth",
|
427 |
+
}
|
428 |
+
anno_ins = GDINOAnnotator(cfg_dict)
|
429 |
+
ret_data = anno_ins.forward(self.image, classes=["cat", "vase"])
|
430 |
+
print(ret_data)
|
431 |
+
|
432 |
+
@unittest.skip('')
|
433 |
+
def test_annotator_gdino_with_ram(self):
|
434 |
+
from vace.annotators.gdino import GDINORAMAnnotator
|
435 |
+
cfg_dict = {
|
436 |
+
"RAM": {
|
437 |
+
"TOKENIZER_PATH": "models/VACE-Annotators/ram/bert-base-uncased",
|
438 |
+
"PRETRAINED_MODEL": "models/VACE-Annotators/ram/ram_plus_swin_large_14m.pth",
|
439 |
+
},
|
440 |
+
"GDINO": {
|
441 |
+
"TOKENIZER_PATH": "models/VACE-Annotators/gdino/bert-base-uncased",
|
442 |
+
"CONFIG_PATH": "models/VACE-Annotators/gdino/GroundingDINO_SwinT_OGC_mod.py",
|
443 |
+
"PRETRAINED_MODEL": "models/VACE-Annotators/gdino/groundingdino_swint_ogc.pth",
|
444 |
+
}
|
445 |
+
|
446 |
+
}
|
447 |
+
anno_ins = GDINORAMAnnotator(cfg_dict)
|
448 |
+
ret_data = anno_ins.forward(self.image)
|
449 |
+
print(ret_data)
|
450 |
+
|
451 |
+
@unittest.skip('')
|
452 |
+
def test_annotator_sam2(self):
|
453 |
+
from vace.annotators.sam2 import SAM2VideoAnnotator
|
454 |
+
from vace.annotators.utils import save_sam2_video
|
455 |
+
cfg_dict = {
|
456 |
+
"CONFIG_PATH": 'models/VACE-Annotators/sam2/configs/sam2.1/sam2.1_hiera_l.yaml',
|
457 |
+
"PRETRAINED_MODEL": 'models/VACE-Annotators/sam2/sam2.1_hiera_large.pt'
|
458 |
+
}
|
459 |
+
anno_ins = SAM2VideoAnnotator(cfg_dict)
|
460 |
+
ret_data = anno_ins.forward(video=self.video_path, input_box=[0, 0, 640, 480])
|
461 |
+
video_segments = ret_data['annotations']
|
462 |
+
save_path = os.path.join(self.save_dir, 'test_sam2_video')
|
463 |
+
if not os.path.exists(save_path):
|
464 |
+
os.makedirs(save_path)
|
465 |
+
save_sam2_video(video_path=self.video_path, video_segments=video_segments, output_video_path=save_path)
|
466 |
+
print(save_path)
|
467 |
+
|
468 |
+
|
469 |
+
@unittest.skip('')
|
470 |
+
def test_annotator_sam2salient(self):
|
471 |
+
from vace.annotators.sam2 import SAM2SalientVideoAnnotator
|
472 |
+
from vace.annotators.utils import save_sam2_video
|
473 |
+
cfg_dict = {
|
474 |
+
"SALIENT": {
|
475 |
+
"PRETRAINED_MODEL": "models/VACE-Annotators/salient/u2net.pt",
|
476 |
+
},
|
477 |
+
"SAM2": {
|
478 |
+
"CONFIG_PATH": 'models/VACE-Annotators/sam2/configs/sam2.1/sam2.1_hiera_l.yaml',
|
479 |
+
"PRETRAINED_MODEL": 'models/VACE-Annotators/sam2/sam2.1_hiera_large.pt'
|
480 |
+
}
|
481 |
+
|
482 |
+
}
|
483 |
+
anno_ins = SAM2SalientVideoAnnotator(cfg_dict)
|
484 |
+
ret_data = anno_ins.forward(video=self.video_path)
|
485 |
+
video_segments = ret_data['annotations']
|
486 |
+
save_path = os.path.join(self.save_dir, 'test_sam2salient_video')
|
487 |
+
if not os.path.exists(save_path):
|
488 |
+
os.makedirs(save_path)
|
489 |
+
save_sam2_video(video_path=self.video_path, video_segments=video_segments, output_video_path=save_path)
|
490 |
+
print(save_path)
|
491 |
+
|
492 |
+
|
493 |
+
@unittest.skip('')
|
494 |
+
def test_annotator_sam2gdinoram_video(self):
|
495 |
+
from vace.annotators.sam2 import SAM2GDINOVideoAnnotator
|
496 |
+
from vace.annotators.utils import save_sam2_video
|
497 |
+
cfg_dict = {
|
498 |
+
"GDINO": {
|
499 |
+
"TOKENIZER_PATH": "models/VACE-Annotators/gdino/bert-base-uncased",
|
500 |
+
"CONFIG_PATH": "models/VACE-Annotators/gdino/GroundingDINO_SwinT_OGC_mod.py",
|
501 |
+
"PRETRAINED_MODEL": "models/VACE-Annotators/gdino/groundingdino_swint_ogc.pth",
|
502 |
+
},
|
503 |
+
"SAM2": {
|
504 |
+
"CONFIG_PATH": 'models/VACE-Annotators/sam2/configs/sam2.1/sam2.1_hiera_l.yaml',
|
505 |
+
"PRETRAINED_MODEL": 'models/VACE-Annotators/sam2/sam2.1_hiera_large.pt'
|
506 |
+
}
|
507 |
+
}
|
508 |
+
anno_ins = SAM2GDINOVideoAnnotator(cfg_dict)
|
509 |
+
ret_data = anno_ins.forward(video=self.video_path, classes='cat')
|
510 |
+
video_segments = ret_data['annotations']
|
511 |
+
save_path = os.path.join(self.save_dir, 'test_sam2gdino_video')
|
512 |
+
if not os.path.exists(save_path):
|
513 |
+
os.makedirs(save_path)
|
514 |
+
save_sam2_video(video_path=self.video_path, video_segments=video_segments, output_video_path=save_path)
|
515 |
+
print(save_path)
|
516 |
+
|
517 |
+
@unittest.skip('')
|
518 |
+
def test_annotator_sam2_image(self):
|
519 |
+
from vace.annotators.sam2 import SAM2ImageAnnotator
|
520 |
+
cfg_dict = {
|
521 |
+
"CONFIG_PATH": 'models/VACE-Annotators/sam2/configs/sam2.1/sam2.1_hiera_l.yaml',
|
522 |
+
"PRETRAINED_MODEL": 'models/VACE-Annotators/sam2/sam2.1_hiera_large.pt'
|
523 |
+
}
|
524 |
+
anno_ins = SAM2ImageAnnotator(cfg_dict)
|
525 |
+
ret_data = anno_ins.forward(image=self.image, input_box=[0, 0, 640, 480])
|
526 |
+
print(ret_data)
|
527 |
+
|
528 |
+
# @unittest.skip('')
|
529 |
+
def test_annotator_prompt_extend(self):
|
530 |
+
from vace.annotators.prompt_extend import PromptExtendAnnotator
|
531 |
+
from vace.configs.prompt_preprocess import WAN_LM_ZH_SYS_PROMPT, WAN_LM_EN_SYS_PROMPT, LTX_LM_EN_SYS_PROMPT
|
532 |
+
cfg_dict = {
|
533 |
+
"MODEL_NAME": "models/VACE-Annotators/llm/Qwen2.5-3B-Instruct" # "Qwen2.5_3B"
|
534 |
+
}
|
535 |
+
anno_ins = PromptExtendAnnotator(cfg_dict)
|
536 |
+
ret_data = anno_ins.forward('一位男孩', system_prompt=WAN_LM_ZH_SYS_PROMPT)
|
537 |
+
print('wan_zh:', ret_data)
|
538 |
+
ret_data = anno_ins.forward('a boy', system_prompt=WAN_LM_EN_SYS_PROMPT)
|
539 |
+
print('wan_en:', ret_data)
|
540 |
+
ret_data = anno_ins.forward('a boy', system_prompt=WAN_LM_ZH_SYS_PROMPT)
|
541 |
+
print('wan_zh en:', ret_data)
|
542 |
+
ret_data = anno_ins.forward('a boy', system_prompt=LTX_LM_EN_SYS_PROMPT)
|
543 |
+
print('ltx_en:', ret_data)
|
544 |
+
|
545 |
+
from vace.annotators.utils import get_annotator
|
546 |
+
anno_ins = get_annotator(config_type='prompt', config_task='ltx_en', return_dict=False)
|
547 |
+
ret_data = anno_ins.forward('a boy', seed=2025)
|
548 |
+
print('ltx_en:', ret_data)
|
549 |
+
ret_data = anno_ins.forward('a boy')
|
550 |
+
print('ltx_en:', ret_data)
|
551 |
+
ret_data = anno_ins.forward('a boy', seed=2025)
|
552 |
+
print('ltx_en:', ret_data)
|
553 |
+
|
554 |
+
@unittest.skip('')
|
555 |
+
def test_annotator_prompt_extend_ds(self):
|
556 |
+
from vace.annotators.utils import get_annotator
|
557 |
+
# export DASH_API_KEY=''
|
558 |
+
anno_ins = get_annotator(config_type='prompt', config_task='wan_zh_ds', return_dict=False)
|
559 |
+
ret_data = anno_ins.forward('一位男孩', seed=2025)
|
560 |
+
print('wan_zh_ds:', ret_data)
|
561 |
+
ret_data = anno_ins.forward('a boy', seed=2025)
|
562 |
+
print('wan_zh_ds:', ret_data)
|
563 |
+
|
564 |
+
|
565 |
+
# ln -s your/path/annotator_models annotator_models
|
566 |
+
# PYTHONPATH=. python tests/test_annotators.py
|
567 |
+
if __name__ == '__main__':
|
568 |
+
unittest.main()
|
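For quick experiments outside unittest, the same annotators can be driven directly. Below is a minimal sketch mirroring test_annotator_maskaug_video above; it assumes the u2net checkpoint has been downloaded to models/VACE-Annotators/salient, and the random frames are only a stand-in for decoded video.

import numpy as np
from vace.annotators.salient import SalientVideoAnnotator
from vace.annotators.maskaug import MaskAugAnnotator

# synthetic stand-in for decoded video frames (RGB, uint8)
frames = [np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8) for _ in range(4)]

salient = SalientVideoAnnotator({"PRETRAINED_MODEL": "models/VACE-Annotators/salient/u2net.pt"})
maskaug = MaskAugAnnotator({})

salient_frames = salient.forward(frames)                                        # per-frame saliency masks
aug_masks = maskaug.forward(salient_frames, mask_cfg={'mode': 'hull_expand'})   # randomly expanded masks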
vace/__init__.py
ADDED
@@ -0,0 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
from . import annotators
from . import configs
from . import models
from . import gradios
vace/annotators/__init__.py
ADDED
@@ -0,0 +1,24 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
from .depth import DepthAnnotator, DepthVideoAnnotator
from .flow import FlowAnnotator, FlowVisAnnotator
from .frameref import FrameRefExtractAnnotator, FrameRefExpandAnnotator
from .gdino import GDINOAnnotator, GDINORAMAnnotator
from .gray import GrayAnnotator, GrayVideoAnnotator
from .inpainting import InpaintingAnnotator, InpaintingVideoAnnotator
from .layout import LayoutBboxAnnotator, LayoutMaskAnnotator, LayoutTrackAnnotator
from .maskaug import MaskAugAnnotator
from .outpainting import OutpaintingAnnotator, OutpaintingInnerAnnotator, OutpaintingVideoAnnotator, OutpaintingInnerVideoAnnotator
from .pose import PoseBodyFaceAnnotator, PoseBodyFaceVideoAnnotator, PoseAnnotator
from .ram import RAMAnnotator
from .salient import SalientAnnotator, SalientVideoAnnotator
from .sam import SAMImageAnnotator
from .sam2 import SAM2ImageAnnotator, SAM2VideoAnnotator, SAM2SalientVideoAnnotator, SAM2GDINOVideoAnnotator
from .scribble import ScribbleAnnotator, ScribbleVideoAnnotator
from .face import FaceAnnotator
from .subject import SubjectAnnotator
from .common import PlainImageAnnotator, PlainMaskAnnotator, PlainMaskAugAnnotator, PlainMaskVideoAnnotator, PlainVideoAnnotator, PlainMaskAugVideoAnnotator, PlainMaskAugInvertAnnotator, PlainMaskAugInvertVideoAnnotator, ExpandMaskVideoAnnotator
from .prompt_extend import PromptExtendAnnotator
from .composition import CompositionAnnotator, ReferenceAnythingAnnotator, AnimateAnythingAnnotator, SwapAnythingAnnotator, ExpandAnythingAnnotator, MoveAnythingAnnotator
from .mask import MaskDrawAnnotator
from .canvas import RegionCanvasAnnotator
vace/annotators/canvas.py
ADDED
@@ -0,0 +1,60 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import random

import cv2
import numpy as np

from .utils import convert_to_numpy


class RegionCanvasAnnotator:
    def __init__(self, cfg, device=None):
        self.scale_range = cfg.get('SCALE_RANGE', [0.75, 1.0])
        self.canvas_value = cfg.get('CANVAS_VALUE', 255)
        self.use_resize = cfg.get('USE_RESIZE', True)
        self.use_canvas = cfg.get('USE_CANVAS', True)
        self.use_aug = cfg.get('USE_AUG', False)
        if self.use_aug:
            from .maskaug import MaskAugAnnotator
            self.maskaug_anno = MaskAugAnnotator(cfg={})

    def forward(self, image, mask, mask_cfg=None):

        image = convert_to_numpy(image)
        mask = convert_to_numpy(mask)
        image_h, image_w = image.shape[:2]

        if self.use_aug:
            mask = self.maskaug_anno.forward(mask, mask_cfg)

        # get region with white bg
        image[np.array(mask) == 0] = self.canvas_value
        x, y, w, h = cv2.boundingRect(mask)
        region_crop = image[y:y + h, x:x + w]

        if self.use_resize:
            # resize region
            scale_min, scale_max = self.scale_range
            scale_factor = random.uniform(scale_min, scale_max)
            new_w, new_h = int(image_w * scale_factor), int(image_h * scale_factor)
            obj_scale_factor = min(new_w/w, new_h/h)

            new_w = int(w * obj_scale_factor)
            new_h = int(h * obj_scale_factor)
            region_crop_resized = cv2.resize(region_crop, (new_w, new_h), interpolation=cv2.INTER_AREA)
        else:
            region_crop_resized = region_crop

        if self.use_canvas:
            # plot region into canvas
            new_canvas = np.ones_like(image) * self.canvas_value
            max_x = max(0, image_w - new_w)
            max_y = max(0, image_h - new_h)
            new_x = random.randint(0, max_x)
            new_y = random.randint(0, max_y)

            new_canvas[new_y:new_y + new_h, new_x:new_x + new_w] = region_crop_resized
        else:
            new_canvas = region_crop_resized
        return new_canvas
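For reference, a toy usage sketch of RegionCanvasAnnotator: the synthetic image and mask below are illustrative only, and it assumes convert_to_numpy passes numpy arrays through unchanged. With the default config the masked region is cut out, randomly rescaled, and pasted at a random position on a white canvas of the original size.

import numpy as np
from vace.annotators.canvas import RegionCanvasAnnotator

image = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
mask = np.zeros((480, 640), dtype=np.uint8)
mask[100:300, 200:400] = 255                 # the region to cut out

anno = RegionCanvasAnnotator(cfg={})         # defaults: white canvas, random rescale
canvas = anno.forward(image, mask)           # region pasted at a random spot on a blank canvas
print(canvas.shape, canvas.dtype)            # (480, 640, 3) uint8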
vace/annotators/common.py
ADDED
@@ -0,0 +1,62 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.

class PlainImageAnnotator:
    def __init__(self, cfg):
        pass
    def forward(self, image):
        return image

class PlainVideoAnnotator:
    def __init__(self, cfg):
        pass
    def forward(self, frames):
        return frames

class PlainMaskAnnotator:
    def __init__(self, cfg):
        pass
    def forward(self, mask):
        return mask

class PlainMaskAugInvertAnnotator:
    def __init__(self, cfg):
        pass
    def forward(self, mask):
        return 255 - mask

class PlainMaskAugAnnotator:
    def __init__(self, cfg):
        pass
    def forward(self, mask):
        return mask

class PlainMaskVideoAnnotator:
    def __init__(self, cfg):
        pass
    def forward(self, mask):
        return mask

class PlainMaskAugVideoAnnotator:
    def __init__(self, cfg):
        pass
    def forward(self, masks):
        return masks

class PlainMaskAugInvertVideoAnnotator:
    def __init__(self, cfg):
        pass
    def forward(self, masks):
        return [255 - mask for mask in masks]

class ExpandMaskVideoAnnotator:
    def __init__(self, cfg):
        pass
    def forward(self, mask, expand_num):
        return [mask] * expand_num

class PlainPromptAnnotator:
    def __init__(self, cfg):
        pass
    def forward(self, prompt):
        return prompt
vace/annotators/composition.py
ADDED
@@ -0,0 +1,155 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np

class CompositionAnnotator:
    def __init__(self, cfg):
        self.process_types = ["repaint", "extension", "control"]
        self.process_map = {
            "repaint": "repaint",
            "extension": "extension",
            "control": "control",
            "inpainting": "repaint",
            "outpainting": "repaint",
            "frameref": "extension",
            "clipref": "extension",
            "depth": "control",
            "flow": "control",
            "gray": "control",
            "pose": "control",
            "scribble": "control",
            "layout": "control"
        }

    def forward(self, process_type_1, process_type_2, frames_1, frames_2, masks_1, masks_2):
        total_frames = min(len(frames_1), len(frames_2), len(masks_1), len(masks_2))
        combine_type = (self.process_map[process_type_1], self.process_map[process_type_2])
        if combine_type in [("extension", "repaint"), ("extension", "control"), ("extension", "extension")]:
            output_video = [frames_2[i] * masks_1[i] + frames_1[i] * (1 - masks_1[i]) for i in range(total_frames)]
            output_mask = [masks_1[i] * masks_2[i] * 255 for i in range(total_frames)]
        elif combine_type in [("repaint", "extension"), ("control", "extension"), ("repaint", "repaint")]:
            output_video = [frames_1[i] * (1 - masks_2[i]) + frames_2[i] * masks_2[i] for i in range(total_frames)]
            output_mask = [(masks_1[i] * (1 - masks_2[i]) + masks_2[i] * masks_2[i]) * 255 for i in range(total_frames)]
        elif combine_type in [("repaint", "control"), ("control", "repaint")]:
            if combine_type in [("control", "repaint")]:
                frames_1, frames_2, masks_1, masks_2 = frames_2, frames_1, masks_2, masks_1
            output_video = [frames_1[i] * (1 - masks_1[i]) + frames_2[i] * masks_1[i] for i in range(total_frames)]
            output_mask = [masks_1[i] * 255 for i in range(total_frames)]
        elif combine_type in [("control", "control")]:  # apply masks_2
            output_video = [frames_1[i] * (1 - masks_2[i]) + frames_2[i] * masks_2[i] for i in range(total_frames)]
            output_mask = [(masks_1[i] * (1 - masks_2[i]) + masks_2[i] * masks_2[i]) * 255 for i in range(total_frames)]
        else:
            raise Exception("Unknown combine type")
        return output_video, output_mask


class ReferenceAnythingAnnotator:
    def __init__(self, cfg):
        from .subject import SubjectAnnotator
        self.sbjref_ins = SubjectAnnotator(cfg['SUBJECT'] if 'SUBJECT' in cfg else cfg)
        self.key_map = {
            "image": "images",
            "mask": "masks"
        }
    def forward(self, images, mode=None, return_mask=None, mask_cfg=None):
        ret_data = {}
        for image in images:
            ret_one_data = self.sbjref_ins.forward(image=image, mode=mode, return_mask=return_mask, mask_cfg=mask_cfg)
            if isinstance(ret_one_data, dict):
                for key, val in ret_one_data.items():
                    if key in self.key_map:
                        new_key = self.key_map[key]
                    else:
                        continue
                    if new_key in ret_data:
                        ret_data[new_key].append(val)
                    else:
                        ret_data[new_key] = [val]
            else:
                if 'images' in ret_data:
                    ret_data['images'].append(ret_data)
                else:
                    ret_data['images'] = [ret_data]
        return ret_data


class AnimateAnythingAnnotator:
    def __init__(self, cfg):
        from .pose import PoseBodyFaceVideoAnnotator
        self.pose_ins = PoseBodyFaceVideoAnnotator(cfg['POSE'])
        self.ref_ins = ReferenceAnythingAnnotator(cfg['REFERENCE'])

    def forward(self, frames=None, images=None, mode=None, return_mask=None, mask_cfg=None):
        ret_data = {}
        ret_pose_data = self.pose_ins.forward(frames=frames)
        ret_data.update({"frames": ret_pose_data})

        ret_ref_data = self.ref_ins.forward(images=images, mode=mode, return_mask=return_mask, mask_cfg=mask_cfg)
        ret_data.update({"images": ret_ref_data['images']})

        return ret_data


class SwapAnythingAnnotator:
    def __init__(self, cfg):
        from .inpainting import InpaintingVideoAnnotator
        self.inp_ins = InpaintingVideoAnnotator(cfg['INPAINTING'])
        self.ref_ins = ReferenceAnythingAnnotator(cfg['REFERENCE'])

    def forward(self, video=None, frames=None, images=None, mode=None, mask=None, bbox=None, label=None, caption=None, return_mask=None, mask_cfg=None):
        ret_data = {}
        mode = mode.split(',') if ',' in mode else [mode, mode]

        ret_inp_data = self.inp_ins.forward(video=video, frames=frames, mode=mode[0], mask=mask, bbox=bbox, label=label, caption=caption, mask_cfg=mask_cfg)
        ret_data.update(ret_inp_data)

        ret_ref_data = self.ref_ins.forward(images=images, mode=mode[1], return_mask=return_mask, mask_cfg=mask_cfg)
        ret_data.update({"images": ret_ref_data['images']})

        return ret_data


class ExpandAnythingAnnotator:
    def __init__(self, cfg):
        from .outpainting import OutpaintingAnnotator
        from .frameref import FrameRefExpandAnnotator
        self.ref_ins = ReferenceAnythingAnnotator(cfg['REFERENCE'])
        self.frameref_ins = FrameRefExpandAnnotator(cfg['FRAMEREF'])
        self.outpainting_ins = OutpaintingAnnotator(cfg['OUTPAINTING'])

    def forward(self, images=None, mode=None, return_mask=None, mask_cfg=None, direction=None, expand_ratio=None, expand_num=None):
        ret_data = {}
        expand_image, reference_image = images[0], images[1:]
        mode = mode.split(',') if ',' in mode else ['firstframe', mode]

        outpainting_data = self.outpainting_ins.forward(expand_image, expand_ratio=expand_ratio, direction=direction)
        outpainting_image, outpainting_mask = outpainting_data['image'], outpainting_data['mask']

        frameref_data = self.frameref_ins.forward(outpainting_image, mode=mode[0], expand_num=expand_num)
        frames, masks = frameref_data['frames'], frameref_data['masks']
        masks[0] = outpainting_mask
        ret_data.update({"frames": frames, "masks": masks})

        ret_ref_data = self.ref_ins.forward(images=reference_image, mode=mode[1], return_mask=return_mask, mask_cfg=mask_cfg)
        ret_data.update({"images": ret_ref_data['images']})

        return ret_data


class MoveAnythingAnnotator:
    def __init__(self, cfg):
        from .layout import LayoutBboxAnnotator
        self.layout_bbox_ins = LayoutBboxAnnotator(cfg['LAYOUTBBOX'])

    def forward(self, image=None, bbox=None, label=None, expand_num=None):
        frame_size = image.shape[:2]  # [H, W]
        ret_layout_data = self.layout_bbox_ins.forward(bbox, frame_size=frame_size, num_frames=expand_num, label=label)

        out_frames = [image] + ret_layout_data
        out_mask = [np.zeros(frame_size, dtype=np.uint8)] + [np.ones(frame_size, dtype=np.uint8) * 255] * len(ret_layout_data)

        ret_data = {
            "frames": out_frames,
            "masks": out_mask
        }
        return ret_data
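A toy sketch of CompositionAnnotator follows, blending a "frameref" (extension) branch with a "depth" (control) branch; the arrays are synthetic, and the annotator expects frames and 0/1 masks of matching length per branch.

import numpy as np
from vace.annotators.composition import CompositionAnnotator

T, H, W = 8, 64, 64
frames_1 = [np.random.rand(H, W, 3) for _ in range(T)]    # extension branch
frames_2 = [np.random.rand(H, W, 3) for _ in range(T)]    # control branch
masks_1 = [np.ones((H, W, 1)) for _ in range(T)]          # 1 = region to be generated
masks_2 = [np.ones((H, W, 1)) for _ in range(T)]

comp = CompositionAnnotator(cfg={})
video, mask = comp.forward("frameref", "depth", frames_1, frames_2, masks_1, masks_2)
print(len(video), video[0].shape, mask[0].max())          # 8 (64, 64, 3) 255.0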
vace/annotators/depth.py
ADDED
@@ -0,0 +1,51 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.

import numpy as np
import torch
from einops import rearrange

from .utils import convert_to_numpy, resize_image, resize_image_ori


class DepthAnnotator:
    def __init__(self, cfg, device=None):
        from .midas.api import MiDaSInference
        pretrained_model = cfg['PRETRAINED_MODEL']
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        self.model = MiDaSInference(model_type='dpt_hybrid', model_path=pretrained_model).to(self.device)
        self.a = cfg.get('A', np.pi * 2.0)
        self.bg_th = cfg.get('BG_TH', 0.1)

    @torch.no_grad()
    @torch.inference_mode()
    @torch.autocast('cuda', enabled=False)
    def forward(self, image):
        image = convert_to_numpy(image)
        image_depth = image
        h, w, c = image.shape
        image_depth, k = resize_image(image_depth,
                                      1024 if min(h, w) > 1024 else min(h, w))
        image_depth = torch.from_numpy(image_depth).float().to(self.device)
        image_depth = image_depth / 127.5 - 1.0
        image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
        depth = self.model(image_depth)[0]

        depth_pt = depth.clone()
        depth_pt -= torch.min(depth_pt)
        depth_pt /= torch.max(depth_pt)
        depth_pt = depth_pt.cpu().numpy()
        depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8)
        depth_image = depth_image[..., None].repeat(3, 2)

        depth_image = resize_image_ori(h, w, depth_image, k)
        return depth_image


class DepthVideoAnnotator(DepthAnnotator):
    def forward(self, frames):
        ret_frames = []
        for frame in frames:
            anno_frame = super().forward(np.array(frame))
            ret_frames.append(anno_frame)
        return ret_frames
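A minimal sketch of driving DepthVideoAnnotator directly; the checkpoint path below is a placeholder following the models/VACE-Annotators layout used elsewhere in the tests (the exact MiDaS dpt_hybrid filename in your download may differ), and the random frames stand in for real video.

import numpy as np
from vace.annotators.depth import DepthVideoAnnotator

# PRETRAINED_MODEL path is an assumption; point it at your local MiDaS dpt_hybrid checkpoint.
anno = DepthVideoAnnotator({"PRETRAINED_MODEL": "models/VACE-Annotators/depth/dpt_hybrid-midas.pt"})
frames = [np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8) for _ in range(4)]
depth_frames = anno.forward(frames)    # list of HxWx3 uint8 depth visualizations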
vace/annotators/dwpose/__init__.py
ADDED
@@ -0,0 +1,2 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
vace/annotators/dwpose/onnxdet.py
ADDED
@@ -0,0 +1,127 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import cv2
import numpy as np

import onnxruntime

def nms(boxes, scores, nms_thr):
    """Single class NMS implemented in Numpy."""
    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]

    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])

        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        ovr = inter / (areas[i] + areas[order[1:]] - inter)

        inds = np.where(ovr <= nms_thr)[0]
        order = order[inds + 1]

    return keep

def multiclass_nms(boxes, scores, nms_thr, score_thr):
    """Multiclass NMS implemented in Numpy. Class-aware version."""
    final_dets = []
    num_classes = scores.shape[1]
    for cls_ind in range(num_classes):
        cls_scores = scores[:, cls_ind]
        valid_score_mask = cls_scores > score_thr
        if valid_score_mask.sum() == 0:
            continue
        else:
            valid_scores = cls_scores[valid_score_mask]
            valid_boxes = boxes[valid_score_mask]
            keep = nms(valid_boxes, valid_scores, nms_thr)
            if len(keep) > 0:
                cls_inds = np.ones((len(keep), 1)) * cls_ind
                dets = np.concatenate(
                    [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
                )
                final_dets.append(dets)
    if len(final_dets) == 0:
        return None
    return np.concatenate(final_dets, 0)

def demo_postprocess(outputs, img_size, p6=False):
    grids = []
    expanded_strides = []
    strides = [8, 16, 32] if not p6 else [8, 16, 32, 64]

    hsizes = [img_size[0] // stride for stride in strides]
    wsizes = [img_size[1] // stride for stride in strides]

    for hsize, wsize, stride in zip(hsizes, wsizes, strides):
        xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
        grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
        grids.append(grid)
        shape = grid.shape[:2]
        expanded_strides.append(np.full((*shape, 1), stride))

    grids = np.concatenate(grids, 1)
    expanded_strides = np.concatenate(expanded_strides, 1)
    outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
    outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides

    return outputs

def preprocess(img, input_size, swap=(2, 0, 1)):
    if len(img.shape) == 3:
        padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
    else:
        padded_img = np.ones(input_size, dtype=np.uint8) * 114

    r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
    resized_img = cv2.resize(
        img,
        (int(img.shape[1] * r), int(img.shape[0] * r)),
        interpolation=cv2.INTER_LINEAR,
    ).astype(np.uint8)
    padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img

    padded_img = padded_img.transpose(swap)
    padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
    return padded_img, r

def inference_detector(session, oriImg):
    input_shape = (640, 640)
    img, ratio = preprocess(oriImg, input_shape)

    ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
    output = session.run(None, ort_inputs)
    predictions = demo_postprocess(output[0], input_shape)[0]

    boxes = predictions[:, :4]
    scores = predictions[:, 4:5] * predictions[:, 5:]

    boxes_xyxy = np.ones_like(boxes)
    boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2.
    boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2.
    boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2.
    boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2.
    boxes_xyxy /= ratio
    dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
    if dets is not None:
        final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
        isscore = final_scores > 0.3
        iscat = final_cls_inds == 0
        isbbox = [i and j for (i, j) in zip(isscore, iscat)]
        final_boxes = final_boxes[isbbox]
    else:
        final_boxes = np.array([])

    return final_boxes
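The NumPy NMS above can be checked in isolation, without any ONNX model. The toy boxes below are illustrative: the second box overlaps the first well above the 0.45 IoU threshold and is suppressed, while the distant third box survives.

import numpy as np
from vace.annotators.dwpose.onnxdet import nms

boxes = np.array([[0, 0, 10, 10],
                  [1, 1, 11, 11],      # heavy overlap with the first box
                  [50, 50, 60, 60]], dtype=np.float32)
scores = np.array([0.9, 0.8, 0.7], dtype=np.float32)
print(nms(boxes, scores, nms_thr=0.45))   # indices 0 and 2 kept; the lower-scoring duplicate is dropped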
vace/annotators/dwpose/onnxpose.py
ADDED
@@ -0,0 +1,362 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import List, Tuple

import cv2
import numpy as np
import onnxruntime as ort

def preprocess(
    img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256)
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Do preprocessing for RTMPose model inference.

    Args:
        img (np.ndarray): Input image in shape.
        input_size (tuple): Input image size in shape (w, h).

    Returns:
        tuple:
        - resized_img (np.ndarray): Preprocessed image.
        - center (np.ndarray): Center of image.
        - scale (np.ndarray): Scale of image.
    """
    # get shape of image
    img_shape = img.shape[:2]
    out_img, out_center, out_scale = [], [], []
    if len(out_bbox) == 0:
        out_bbox = [[0, 0, img_shape[1], img_shape[0]]]
    for i in range(len(out_bbox)):
        x0 = out_bbox[i][0]
        y0 = out_bbox[i][1]
        x1 = out_bbox[i][2]
        y1 = out_bbox[i][3]
        bbox = np.array([x0, y0, x1, y1])

        # get center and scale
        center, scale = bbox_xyxy2cs(bbox, padding=1.25)

        # do affine transformation
        resized_img, scale = top_down_affine(input_size, scale, center, img)

        # normalize image
        mean = np.array([123.675, 116.28, 103.53])
        std = np.array([58.395, 57.12, 57.375])
        resized_img = (resized_img - mean) / std

        out_img.append(resized_img)
        out_center.append(center)
        out_scale.append(scale)

    return out_img, out_center, out_scale


def inference(sess: ort.InferenceSession, img: np.ndarray) -> np.ndarray:
    """Inference RTMPose model.

    Args:
        sess (ort.InferenceSession): ONNXRuntime session.
        img (np.ndarray): Input image in shape.

    Returns:
        outputs (np.ndarray): Output of RTMPose model.
    """
    all_out = []
    # build input
    for i in range(len(img)):
        input = [img[i].transpose(2, 0, 1)]

        # build output
        sess_input = {sess.get_inputs()[0].name: input}
        sess_output = []
        for out in sess.get_outputs():
            sess_output.append(out.name)

        # run model
        outputs = sess.run(sess_output, sess_input)
        all_out.append(outputs)

    return all_out


def postprocess(outputs: List[np.ndarray],
                model_input_size: Tuple[int, int],
                center: Tuple[int, int],
                scale: Tuple[int, int],
                simcc_split_ratio: float = 2.0
                ) -> Tuple[np.ndarray, np.ndarray]:
    """Postprocess for RTMPose model output.

    Args:
        outputs (np.ndarray): Output of RTMPose model.
        model_input_size (tuple): RTMPose model Input image size.
        center (tuple): Center of bbox in shape (x, y).
        scale (tuple): Scale of bbox in shape (w, h).
        simcc_split_ratio (float): Split ratio of simcc.

    Returns:
        tuple:
        - keypoints (np.ndarray): Rescaled keypoints.
        - scores (np.ndarray): Model predict scores.
    """
    all_key = []
    all_score = []
    for i in range(len(outputs)):
        # use simcc to decode
        simcc_x, simcc_y = outputs[i]
        keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio)

        # rescale keypoints
        keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2
        all_key.append(keypoints[0])
        all_score.append(scores[0])

    return np.array(all_key), np.array(all_score)


def bbox_xyxy2cs(bbox: np.ndarray,
                 padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]:
    """Transform the bbox format from (x,y,w,h) into (center, scale)

    Args:
        bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted
            as (left, top, right, bottom)
        padding (float): BBox padding factor that will be multiplied to scale.
            Default: 1.0

    Returns:
        tuple: A tuple containing center and scale.
        - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or
            (n, 2)
        - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or
            (n, 2)
    """
    # convert single bbox from (4, ) to (1, 4)
    dim = bbox.ndim
    if dim == 1:
        bbox = bbox[None, :]

    # get bbox center and scale
    x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3])
    center = np.hstack([x1 + x2, y1 + y2]) * 0.5
    scale = np.hstack([x2 - x1, y2 - y1]) * padding

    if dim == 1:
        center = center[0]
        scale = scale[0]

    return center, scale


def _fix_aspect_ratio(bbox_scale: np.ndarray,
                      aspect_ratio: float) -> np.ndarray:
    """Extend the scale to match the given aspect ratio.

    Args:
        scale (np.ndarray): The image scale (w, h) in shape (2, )
        aspect_ratio (float): The ratio of ``w/h``

    Returns:
        np.ndarray: The reshaped image scale in (2, )
    """
    w, h = np.hsplit(bbox_scale, [1])
    bbox_scale = np.where(w > h * aspect_ratio,
                          np.hstack([w, w / aspect_ratio]),
                          np.hstack([h * aspect_ratio, h]))
    return bbox_scale


def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray:
    """Rotate a point by an angle.

    Args:
        pt (np.ndarray): 2D point coordinates (x, y) in shape (2, )
        angle_rad (float): rotation angle in radian

    Returns:
        np.ndarray: Rotated point in shape (2, )
    """
    sn, cs = np.sin(angle_rad), np.cos(angle_rad)
    rot_mat = np.array([[cs, -sn], [sn, cs]])
    return rot_mat @ pt


def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """To calculate the affine matrix, three pairs of points are required. This
    function is used to get the 3rd point, given 2D points a & b.

    The 3rd point is defined by rotating vector `a - b` by 90 degrees
    anticlockwise, using b as the rotation center.

    Args:
        a (np.ndarray): The 1st point (x,y) in shape (2, )
        b (np.ndarray): The 2nd point (x,y) in shape (2, )

    Returns:
        np.ndarray: The 3rd point.
    """
    direction = a - b
    c = b + np.r_[-direction[1], direction[0]]
    return c


def get_warp_matrix(center: np.ndarray,
                    scale: np.ndarray,
                    rot: float,
                    output_size: Tuple[int, int],
                    shift: Tuple[float, float] = (0., 0.),
                    inv: bool = False) -> np.ndarray:
    """Calculate the affine transformation matrix that can warp the bbox area
    in the input image to the output size.

    Args:
        center (np.ndarray[2, ]): Center of the bounding box (x, y).
        scale (np.ndarray[2, ]): Scale of the bounding box
            wrt [width, height].
        rot (float): Rotation angle (degree).
        output_size (np.ndarray[2, ] | list(2,)): Size of the
            destination heatmaps.
        shift (0-100%): Shift translation ratio wrt the width/height.
            Default (0., 0.).
        inv (bool): Option to inverse the affine transform direction.
            (inv=False: src->dst or inv=True: dst->src)

    Returns:
        np.ndarray: A 2x3 transformation matrix
    """
    shift = np.array(shift)
    src_w = scale[0]
    dst_w = output_size[0]
    dst_h = output_size[1]

    # compute transformation matrix
    rot_rad = np.deg2rad(rot)
    src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad)
    dst_dir = np.array([0., dst_w * -0.5])

    # get four corners of the src rectangle in the original image
    src = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = center + scale * shift
    src[1, :] = center + src_dir + scale * shift
    src[2, :] = _get_3rd_point(src[0, :], src[1, :])

    # get four corners of the dst rectangle in the input image
    dst = np.zeros((3, 2), dtype=np.float32)
    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
    dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])

    if inv:
        warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src))
    else:
        warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst))

    return warp_mat


def top_down_affine(input_size: dict, bbox_scale: dict, bbox_center: dict,
                    img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Get the bbox image as the model input by affine transform.

    Args:
        input_size (dict): The input size of the model.
        bbox_scale (dict): The bbox scale of the img.
        bbox_center (dict): The bbox center of the img.
        img (np.ndarray): The original image.

    Returns:
        tuple: A tuple containing center and scale.
        - np.ndarray[float32]: img after affine transform.
        - np.ndarray[float32]: bbox scale after affine transform.
    """
    w, h = input_size
    warp_size = (int(w), int(h))

    # reshape bbox to fixed aspect ratio
    bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h)

    # get the affine matrix
    center = bbox_center
    scale = bbox_scale
    rot = 0
    warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h))

    # do affine transform
    img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR)

    return img, bbox_scale


def get_simcc_maximum(simcc_x: np.ndarray,
                      simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Get maximum response location and value from simcc representations.

    Note:
        instance number: N
        num_keypoints: K
        heatmap height: H
        heatmap width: W

    Args:
        simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx)
        simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy)

    Returns:
        tuple:
        - locs (np.ndarray): locations of maximum heatmap responses in shape
            (K, 2) or (N, K, 2)
        - vals (np.ndarray): values of maximum heatmap responses in shape
            (K,) or (N, K)
    """
    N, K, Wx = simcc_x.shape
    simcc_x = simcc_x.reshape(N * K, -1)
    simcc_y = simcc_y.reshape(N * K, -1)

    # get maximum value locations
    x_locs = np.argmax(simcc_x, axis=1)
    y_locs = np.argmax(simcc_y, axis=1)
    locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
    max_val_x = np.amax(simcc_x, axis=1)
    max_val_y = np.amax(simcc_y, axis=1)

    # get maximum value across x and y axis
    mask = max_val_x > max_val_y
    max_val_x[mask] = max_val_y[mask]
    vals = max_val_x
    locs[vals <= 0.] = -1

    # reshape
    locs = locs.reshape(N, K, 2)
    vals = vals.reshape(N, K)

    return locs, vals


def decode(simcc_x: np.ndarray, simcc_y: np.ndarray,
           simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]:
    """Modulate simcc distribution with Gaussian.

    Args:
        simcc_x (np.ndarray[K, Wx]): model predicted simcc in x.
        simcc_y (np.ndarray[K, Wy]): model predicted simcc in y.
        simcc_split_ratio (int): The split ratio of simcc.

    Returns:
        tuple: A tuple containing center and scale.
        - np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2)
        - np.ndarray[float32]: scores in shape (K,) or (n, K)
    """
    keypoints, scores = get_simcc_maximum(simcc_x, simcc_y)
    keypoints /= simcc_split_ratio

    return keypoints, scores


def inference_pose(session, out_bbox, oriImg):
    h, w = session.get_inputs()[0].shape[2:]
    model_input_size = (w, h)
    resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size)
    outputs = inference(session, resized_img)
    keypoints, scores = postprocess(outputs, model_input_size, center, scale)

    return keypoints, scores
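As a quick numeric check of bbox_xyxy2cs above: a 100x200 box anchored at the origin with the padding of 1.25 used in preprocess yields its midpoint as the center and the padded width/height as the scale.

import numpy as np
from vace.annotators.dwpose.onnxpose import bbox_xyxy2cs

center, scale = bbox_xyxy2cs(np.array([0, 0, 100, 200]), padding=1.25)
print(center)   # [ 50. 100.]
print(scale)    # [125. 250.]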
vace/annotators/dwpose/util.py
ADDED
@@ -0,0 +1,299 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import numpy as np
import matplotlib
import cv2


eps = 0.01


def smart_resize(x, s):
    Ht, Wt = s
    if x.ndim == 2:
        Ho, Wo = x.shape
        Co = 1
    else:
        Ho, Wo, Co = x.shape
    if Co == 3 or Co == 1:
        k = float(Ht + Wt) / float(Ho + Wo)
        return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
    else:
        return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2)


def smart_resize_k(x, fx, fy):
    if x.ndim == 2:
        Ho, Wo = x.shape
        Co = 1
    else:
        Ho, Wo, Co = x.shape
    Ht, Wt = Ho * fy, Wo * fx
    if Co == 3 or Co == 1:
        k = float(Ht + Wt) / float(Ho + Wo)
        return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
    else:
        return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2)


def padRightDownCorner(img, stride, padValue):
    h = img.shape[0]
    w = img.shape[1]

    pad = 4 * [None]
    pad[0] = 0  # up
    pad[1] = 0  # left
    pad[2] = 0 if (h % stride == 0) else stride - (h % stride)  # down
    pad[3] = 0 if (w % stride == 0) else stride - (w % stride)  # right

    img_padded = img
    pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1))
    img_padded = np.concatenate((pad_up, img_padded), axis=0)
    pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1))
    img_padded = np.concatenate((pad_left, img_padded), axis=1)
    pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1))
    img_padded = np.concatenate((img_padded, pad_down), axis=0)
    pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1))
    img_padded = np.concatenate((img_padded, pad_right), axis=1)

    return img_padded, pad


def transfer(model, model_weights):
    transfered_model_weights = {}
    for weights_name in model.state_dict().keys():
        transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])]
    return transfered_model_weights


def draw_bodypose(canvas, candidate, subset):
    H, W, C = canvas.shape
    candidate = np.array(candidate)
    subset = np.array(subset)

    stickwidth = 4

    limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
               [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
               [1, 16], [16, 18], [3, 17], [6, 18]]

    colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
              [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
              [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]

    for i in range(17):
        for n in range(len(subset)):
            index = subset[n][np.array(limbSeq[i]) - 1]
            if -1 in index:
                continue
            Y = candidate[index.astype(int), 0] * float(W)
            X = candidate[index.astype(int), 1] * float(H)
            mX = np.mean(X)
            mY = np.mean(Y)
            length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
            angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
            polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
            cv2.fillConvexPoly(canvas, polygon, colors[i])

    canvas = (canvas * 0.6).astype(np.uint8)

    for i in range(18):
        for n in range(len(subset)):
            index = int(subset[n][i])
            if index == -1:
                continue
            x, y = candidate[index][0:2]
            x = int(x * W)
            y = int(y * H)
            cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)

    return canvas


def draw_handpose(canvas, all_hand_peaks):
    H, W, C = canvas.shape

    edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
             [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
[10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
|
119 |
+
|
120 |
+
for peaks in all_hand_peaks:
|
121 |
+
peaks = np.array(peaks)
|
122 |
+
|
123 |
+
for ie, e in enumerate(edges):
|
124 |
+
x1, y1 = peaks[e[0]]
|
125 |
+
x2, y2 = peaks[e[1]]
|
126 |
+
x1 = int(x1 * W)
|
127 |
+
y1 = int(y1 * H)
|
128 |
+
x2 = int(x2 * W)
|
129 |
+
y2 = int(y2 * H)
|
130 |
+
if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
|
131 |
+
cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2)
|
132 |
+
|
133 |
+
for i, keypoint in enumerate(peaks):
|
134 |
+
x, y = keypoint
|
135 |
+
x = int(x * W)
|
136 |
+
y = int(y * H)
|
137 |
+
if x > eps and y > eps:
|
138 |
+
cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
|
139 |
+
return canvas
|
140 |
+
|
141 |
+
|
142 |
+
def draw_facepose(canvas, all_lmks):
|
143 |
+
H, W, C = canvas.shape
|
144 |
+
for lmks in all_lmks:
|
145 |
+
lmks = np.array(lmks)
|
146 |
+
for lmk in lmks:
|
147 |
+
x, y = lmk
|
148 |
+
x = int(x * W)
|
149 |
+
y = int(y * H)
|
150 |
+
if x > eps and y > eps:
|
151 |
+
cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1)
|
152 |
+
return canvas
|
153 |
+
|
154 |
+
|
155 |
+
# detect hand according to body pose keypoints
|
156 |
+
# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
|
157 |
+
def handDetect(candidate, subset, oriImg):
|
158 |
+
# right hand: wrist 4, elbow 3, shoulder 2
|
159 |
+
# left hand: wrist 7, elbow 6, shoulder 5
|
160 |
+
ratioWristElbow = 0.33
|
161 |
+
detect_result = []
|
162 |
+
image_height, image_width = oriImg.shape[0:2]
|
163 |
+
for person in subset.astype(int):
|
164 |
+
# if any of three not detected
|
165 |
+
has_left = np.sum(person[[5, 6, 7]] == -1) == 0
|
166 |
+
has_right = np.sum(person[[2, 3, 4]] == -1) == 0
|
167 |
+
if not (has_left or has_right):
|
168 |
+
continue
|
169 |
+
hands = []
|
170 |
+
#left hand
|
171 |
+
if has_left:
|
172 |
+
left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]]
|
173 |
+
x1, y1 = candidate[left_shoulder_index][:2]
|
174 |
+
x2, y2 = candidate[left_elbow_index][:2]
|
175 |
+
x3, y3 = candidate[left_wrist_index][:2]
|
176 |
+
hands.append([x1, y1, x2, y2, x3, y3, True])
|
177 |
+
# right hand
|
178 |
+
if has_right:
|
179 |
+
right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]]
|
180 |
+
x1, y1 = candidate[right_shoulder_index][:2]
|
181 |
+
x2, y2 = candidate[right_elbow_index][:2]
|
182 |
+
x3, y3 = candidate[right_wrist_index][:2]
|
183 |
+
hands.append([x1, y1, x2, y2, x3, y3, False])
|
184 |
+
|
185 |
+
for x1, y1, x2, y2, x3, y3, is_left in hands:
|
186 |
+
# pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox
|
187 |
+
# handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]);
|
188 |
+
# handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]);
|
189 |
+
# const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow);
|
190 |
+
# const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder);
|
191 |
+
# handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder);
|
192 |
+
x = x3 + ratioWristElbow * (x3 - x2)
|
193 |
+
y = y3 + ratioWristElbow * (y3 - y2)
|
194 |
+
distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
|
195 |
+
distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
|
196 |
+
width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
|
197 |
+
# x-y refers to the center --> offset to topLeft point
|
198 |
+
# handRectangle.x -= handRectangle.width / 2.f;
|
199 |
+
# handRectangle.y -= handRectangle.height / 2.f;
|
200 |
+
x -= width / 2
|
201 |
+
y -= width / 2 # width = height
|
202 |
+
# overflow the image
|
203 |
+
if x < 0: x = 0
|
204 |
+
if y < 0: y = 0
|
205 |
+
width1 = width
|
206 |
+
width2 = width
|
207 |
+
if x + width > image_width: width1 = image_width - x
|
208 |
+
if y + width > image_height: width2 = image_height - y
|
209 |
+
width = min(width1, width2)
|
210 |
+
# discard hand boxes smaller than 20 pixels
|
211 |
+
if width >= 20:
|
212 |
+
detect_result.append([int(x), int(y), int(width), is_left])
|
213 |
+
|
214 |
+
'''
|
215 |
+
return value: [[x, y, w, True if left hand else False]].
|
216 |
+
width=height since the network require squared input.
|
217 |
+
x, y is the coordinate of top left
|
218 |
+
'''
|
219 |
+
return detect_result
|
220 |
+
|
221 |
+
|
222 |
+
# Written by Lvmin
|
223 |
+
def faceDetect(candidate, subset, oriImg):
|
224 |
+
# left right eye ear 14 15 16 17
|
225 |
+
detect_result = []
|
226 |
+
image_height, image_width = oriImg.shape[0:2]
|
227 |
+
for person in subset.astype(int):
|
228 |
+
has_head = person[0] > -1
|
229 |
+
if not has_head:
|
230 |
+
continue
|
231 |
+
|
232 |
+
has_left_eye = person[14] > -1
|
233 |
+
has_right_eye = person[15] > -1
|
234 |
+
has_left_ear = person[16] > -1
|
235 |
+
has_right_ear = person[17] > -1
|
236 |
+
|
237 |
+
if not (has_left_eye or has_right_eye or has_left_ear or has_right_ear):
|
238 |
+
continue
|
239 |
+
|
240 |
+
head, left_eye, right_eye, left_ear, right_ear = person[[0, 14, 15, 16, 17]]
|
241 |
+
|
242 |
+
width = 0.0
|
243 |
+
x0, y0 = candidate[head][:2]
|
244 |
+
|
245 |
+
if has_left_eye:
|
246 |
+
x1, y1 = candidate[left_eye][:2]
|
247 |
+
d = max(abs(x0 - x1), abs(y0 - y1))
|
248 |
+
width = max(width, d * 3.0)
|
249 |
+
|
250 |
+
if has_right_eye:
|
251 |
+
x1, y1 = candidate[right_eye][:2]
|
252 |
+
d = max(abs(x0 - x1), abs(y0 - y1))
|
253 |
+
width = max(width, d * 3.0)
|
254 |
+
|
255 |
+
if has_left_ear:
|
256 |
+
x1, y1 = candidate[left_ear][:2]
|
257 |
+
d = max(abs(x0 - x1), abs(y0 - y1))
|
258 |
+
width = max(width, d * 1.5)
|
259 |
+
|
260 |
+
if has_right_ear:
|
261 |
+
x1, y1 = candidate[right_ear][:2]
|
262 |
+
d = max(abs(x0 - x1), abs(y0 - y1))
|
263 |
+
width = max(width, d * 1.5)
|
264 |
+
|
265 |
+
x, y = x0, y0
|
266 |
+
|
267 |
+
x -= width
|
268 |
+
y -= width
|
269 |
+
|
270 |
+
if x < 0:
|
271 |
+
x = 0
|
272 |
+
|
273 |
+
if y < 0:
|
274 |
+
y = 0
|
275 |
+
|
276 |
+
width1 = width * 2
|
277 |
+
width2 = width * 2
|
278 |
+
|
279 |
+
if x + width > image_width:
|
280 |
+
width1 = image_width - x
|
281 |
+
|
282 |
+
if y + width > image_height:
|
283 |
+
width2 = image_height - y
|
284 |
+
|
285 |
+
width = min(width1, width2)
|
286 |
+
|
287 |
+
if width >= 20:
|
288 |
+
detect_result.append([int(x), int(y), int(width)])
|
289 |
+
|
290 |
+
return detect_result
|
291 |
+
|
292 |
+
|
293 |
+
# get max index of 2d array
|
294 |
+
def npmax(array):
|
295 |
+
arrayindex = array.argmax(1)
|
296 |
+
arrayvalue = array.max(1)
|
297 |
+
i = arrayvalue.argmax()
|
298 |
+
j = arrayindex[i]
|
299 |
+
return i, j
|
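`util.py` bundles the OpenPose-style drawing helpers plus the heuristic hand/face box detectors. A small smoke test for the drawing utilities, using synthetic normalized landmarks (the function names are real, the data is random):

```python
# Smoke test for the drawing helpers with synthetic, normalized landmarks.
import numpy as np
from vace.annotators.dwpose.util import draw_facepose, draw_handpose

canvas = np.zeros((480, 640, 3), dtype=np.uint8)
face_lmks = np.random.rand(68, 2)        # 68 (x, y) face points in [0, 1]
hand_peaks = np.random.rand(2, 21, 2)    # two hands, 21 keypoints each, in [0, 1]

canvas = draw_facepose(canvas, [face_lmks])   # white dots
canvas = draw_handpose(canvas, hand_peaks)    # colored finger limbs + red joints
# cv2.imwrite("pose_debug.png", canvas)       # optional visual check
```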
vace/annotators/dwpose/wholebody.py
ADDED
@@ -0,0 +1,80 @@
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
import onnxruntime as ort
|
6 |
+
from .onnxdet import inference_detector
|
7 |
+
from .onnxpose import inference_pose
|
8 |
+
|
9 |
+
def HWC3(x):
|
10 |
+
assert x.dtype == np.uint8
|
11 |
+
if x.ndim == 2:
|
12 |
+
x = x[:, :, None]
|
13 |
+
assert x.ndim == 3
|
14 |
+
H, W, C = x.shape
|
15 |
+
assert C == 1 or C == 3 or C == 4
|
16 |
+
if C == 3:
|
17 |
+
return x
|
18 |
+
if C == 1:
|
19 |
+
return np.concatenate([x, x, x], axis=2)
|
20 |
+
if C == 4:
|
21 |
+
color = x[:, :, 0:3].astype(np.float32)
|
22 |
+
alpha = x[:, :, 3:4].astype(np.float32) / 255.0
|
23 |
+
y = color * alpha + 255.0 * (1.0 - alpha)
|
24 |
+
y = y.clip(0, 255).astype(np.uint8)
|
25 |
+
return y
|
26 |
+
|
27 |
+
|
28 |
+
def resize_image(input_image, resolution):
|
29 |
+
H, W, C = input_image.shape
|
30 |
+
H = float(H)
|
31 |
+
W = float(W)
|
32 |
+
k = float(resolution) / min(H, W)
|
33 |
+
H *= k
|
34 |
+
W *= k
|
35 |
+
H = int(np.round(H / 64.0)) * 64
|
36 |
+
W = int(np.round(W / 64.0)) * 64
|
37 |
+
img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
|
38 |
+
return img
|
39 |
+
|
40 |
+
class Wholebody:
|
41 |
+
def __init__(self, onnx_det, onnx_pose, device = 'cuda:0'):
|
42 |
+
|
43 |
+
providers = ['CPUExecutionProvider'
|
44 |
+
] if device == 'cpu' else ['CUDAExecutionProvider']
|
45 |
+
# onnx_det = 'annotator/ckpts/yolox_l.onnx'
|
46 |
+
# onnx_pose = 'annotator/ckpts/dw-ll_ucoco_384.onnx'
|
47 |
+
|
48 |
+
self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers)
|
49 |
+
self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers)
|
50 |
+
|
51 |
+
def __call__(self, ori_img):
|
52 |
+
det_result = inference_detector(self.session_det, ori_img)
|
53 |
+
keypoints, scores = inference_pose(self.session_pose, det_result, ori_img)
|
54 |
+
|
55 |
+
keypoints_info = np.concatenate(
|
56 |
+
(keypoints, scores[..., None]), axis=-1)
|
57 |
+
# compute neck joint
|
58 |
+
neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
|
59 |
+
# neck score when visualizing pred
|
60 |
+
neck[:, 2:4] = np.logical_and(
|
61 |
+
keypoints_info[:, 5, 2:4] > 0.3,
|
62 |
+
keypoints_info[:, 6, 2:4] > 0.3).astype(int)
|
63 |
+
new_keypoints_info = np.insert(
|
64 |
+
keypoints_info, 17, neck, axis=1)
|
65 |
+
mmpose_idx = [
|
66 |
+
17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3
|
67 |
+
]
|
68 |
+
openpose_idx = [
|
69 |
+
1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17
|
70 |
+
]
|
71 |
+
new_keypoints_info[:, openpose_idx] = \
|
72 |
+
new_keypoints_info[:, mmpose_idx]
|
73 |
+
keypoints_info = new_keypoints_info
|
74 |
+
|
75 |
+
keypoints, scores = keypoints_info[
|
76 |
+
..., :2], keypoints_info[..., 2]
|
77 |
+
|
78 |
+
return keypoints, scores, det_result
|
79 |
+
|
80 |
+
|
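`Wholebody` chains the ONNX detector and pose sessions, inserts a synthetic neck joint and remaps the body channels from MMPose to OpenPose ordering. A usage sketch with placeholder checkpoint paths:

```python
# Usage sketch for Wholebody; both .onnx paths and the image path are placeholders.
import cv2
from vace.annotators.dwpose.wholebody import Wholebody

estimator = Wholebody(onnx_det="models/yolox_l.onnx",
                      onnx_pose="models/dw-ll_ucoco_384.onnx",
                      device="cpu")              # 'cpu' selects CPUExecutionProvider
img = cv2.imread("test.jpg")
keypoints, scores, det_result = estimator(img)
# keypoints: (num_person, K, 2) after the neck insertion and channel remapping
# scores:    (num_person, K) confidence values
```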
vace/annotators/face.py
ADDED
@@ -0,0 +1,55 @@
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from .utils import convert_to_numpy
|
8 |
+
|
9 |
+
|
10 |
+
class FaceAnnotator:
|
11 |
+
def __init__(self, cfg, device=None):
|
12 |
+
from insightface.app import FaceAnalysis
|
13 |
+
self.return_raw = cfg.get('RETURN_RAW', True)
|
14 |
+
self.return_mask = cfg.get('RETURN_MASK', False)
|
15 |
+
self.return_dict = cfg.get('RETURN_DICT', False)
|
16 |
+
self.multi_face = cfg.get('MULTI_FACE', True)
|
17 |
+
pretrained_model = cfg['PRETRAINED_MODEL']
|
18 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
|
19 |
+
self.device_id = self.device.index if self.device.type == 'cuda' else None
|
20 |
+
ctx_id = self.device_id if self.device_id is not None else 0
|
21 |
+
self.model = FaceAnalysis(name=cfg.MODEL_NAME, root=pretrained_model, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
22 |
+
self.model.prepare(ctx_id=ctx_id, det_size=(640, 640))
|
23 |
+
|
24 |
+
def forward(self, image=None, return_mask=None, return_dict=None):
|
25 |
+
return_mask = return_mask if return_mask is not None else self.return_mask
|
26 |
+
return_dict = return_dict if return_dict is not None else self.return_dict
|
27 |
+
image = convert_to_numpy(image)
|
28 |
+
# [dict_keys(['bbox', 'kps', 'det_score', 'landmark_3d_68', 'pose', 'landmark_2d_106', 'gender', 'age', 'embedding'])]
|
29 |
+
faces = self.model.get(image)
|
30 |
+
if self.return_raw:
|
31 |
+
return faces
|
32 |
+
else:
|
33 |
+
crop_face_list, mask_list = [], []
|
34 |
+
if len(faces) > 0:
|
35 |
+
if not self.multi_face:
|
36 |
+
faces = faces[:1]
|
37 |
+
for face in faces:
|
38 |
+
x_min, y_min, x_max, y_max = face['bbox'].tolist()
|
39 |
+
crop_face = image[int(y_min): int(y_max) + 1, int(x_min): int(x_max) + 1]
|
40 |
+
crop_face_list.append(crop_face)
|
41 |
+
mask = np.zeros_like(image[:, :, 0])
|
42 |
+
mask[int(y_min): int(y_max) + 1, int(x_min): int(x_max) + 1] = 255
|
43 |
+
mask_list.append(mask)
|
44 |
+
if not self.multi_face:
|
45 |
+
crop_face_list = crop_face_list[0]
|
46 |
+
mask_list = mask_list[0]
|
47 |
+
if return_mask:
|
48 |
+
if return_dict:
|
49 |
+
return {'image': crop_face_list, 'mask': mask_list}
|
50 |
+
else:
|
51 |
+
return crop_face_list, mask_list
|
52 |
+
else:
|
53 |
+
return crop_face_list
|
54 |
+
else:
|
55 |
+
return None
|
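A hedged usage sketch for `FaceAnnotator`. The config below is illustrative: it assumes an attribute-style dict such as `easydict.EasyDict` (the class reads both `cfg['PRETRAINED_MODEL']` and `cfg.MODEL_NAME`), and the model-pack name and root directory are assumptions, not documented defaults:

```python
# Illustrative FaceAnnotator config; MODEL_NAME and PRETRAINED_MODEL are assumptions.
import cv2
from easydict import EasyDict
from vace.annotators.face import FaceAnnotator

cfg = EasyDict({
    'MODEL_NAME': 'antelopev2',                 # assumed insightface model pack
    'PRETRAINED_MODEL': 'models/insightface',   # placeholder model root
    'RETURN_RAW': False,
    'RETURN_MASK': True,
    'RETURN_DICT': True,
    'MULTI_FACE': False,
})
anno = FaceAnnotator(cfg)
out = anno.forward(cv2.imread("portrait.jpg"))  # placeholder image path
if out is not None:                             # None when no face is detected
    crop, mask = out['image'], out['mask']
```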
vace/annotators/flow.py
ADDED
@@ -0,0 +1,53 @@
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import argparse
|
6 |
+
|
7 |
+
from .utils import convert_to_numpy
|
8 |
+
|
9 |
+
class FlowAnnotator:
|
10 |
+
def __init__(self, cfg, device=None):
|
11 |
+
try:
|
12 |
+
from raft import RAFT
|
13 |
+
from raft.utils.utils import InputPadder
|
14 |
+
from raft.utils import flow_viz
|
15 |
+
except:
|
16 |
+
import warnings
|
17 |
+
warnings.warn(
|
18 |
+
"ignore raft import, please pip install raft package. you can refer to models/VACE-Annotators/flow/raft-1.0.0-py3-none-any.whl")
|
19 |
+
|
20 |
+
params = {
|
21 |
+
"small": False,
|
22 |
+
"mixed_precision": False,
|
23 |
+
"alternate_corr": False
|
24 |
+
}
|
25 |
+
params = argparse.Namespace(**params)
|
26 |
+
pretrained_model = cfg['PRETRAINED_MODEL']
|
27 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
|
28 |
+
self.model = RAFT(params)
|
29 |
+
self.model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(pretrained_model, map_location="cpu", weights_only=True).items()})
|
30 |
+
self.model = self.model.to(self.device).eval()
|
31 |
+
self.InputPadder = InputPadder
|
32 |
+
self.flow_viz = flow_viz
|
33 |
+
|
34 |
+
def forward(self, frames):
|
35 |
+
# frames / RGB
|
36 |
+
frames = [torch.from_numpy(convert_to_numpy(frame).astype(np.uint8)).permute(2, 0, 1).float()[None].to(self.device) for frame in frames]
|
37 |
+
flow_up_list, flow_up_vis_list = [], []
|
38 |
+
with torch.no_grad():
|
39 |
+
for i, (image1, image2) in enumerate(zip(frames[:-1], frames[1:])):
|
40 |
+
padder = self.InputPadder(image1.shape)
|
41 |
+
image1, image2 = padder.pad(image1, image2)
|
42 |
+
flow_low, flow_up = self.model(image1, image2, iters=20, test_mode=True)
|
43 |
+
flow_up = flow_up[0].permute(1, 2, 0).cpu().numpy()
|
44 |
+
flow_up_vis = self.flow_viz.flow_to_image(flow_up)
|
45 |
+
flow_up_list.append(flow_up)
|
46 |
+
flow_up_vis_list.append(flow_up_vis)
|
47 |
+
return flow_up_list, flow_up_vis_list # RGB
|
48 |
+
|
49 |
+
|
50 |
+
class FlowVisAnnotator(FlowAnnotator):
|
51 |
+
def forward(self, frames):
|
52 |
+
flow_up_list, flow_up_vis_list = super().forward(frames)
|
53 |
+
return flow_up_vis_list[:1] + flow_up_vis_list
|
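`FlowAnnotator` runs RAFT between consecutive frames; `FlowVisAnnotator` keeps only the color-coded visualizations and duplicates the first one so the output length matches the input. A sketch, with a placeholder RAFT checkpoint path and the raft package assumed to be installed:

```python
# Sketch of optical-flow visualization; the checkpoint path is a placeholder.
import numpy as np
from vace.annotators.flow import FlowVisAnnotator

anno = FlowVisAnnotator(cfg={'PRETRAINED_MODEL': 'models/raft-things.pth'})
frames = [np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8) for _ in range(4)]
flow_vis = anno.forward(frames)
# one visualization per input frame: len(flow_vis) == len(frames)
```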
vace/annotators/frameref.py
ADDED
@@ -0,0 +1,118 @@
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
import random
|
4 |
+
import numpy as np
|
5 |
+
from .utils import align_frames
|
6 |
+
|
7 |
+
|
8 |
+
class FrameRefExtractAnnotator:
|
9 |
+
para_dict = {}
|
10 |
+
|
11 |
+
def __init__(self, cfg, device=None):
|
12 |
+
# first / last / firstlast / random
|
13 |
+
self.ref_cfg = cfg.get('REF_CFG', [{"mode": "first", "proba": 0.1},
|
14 |
+
{"mode": "last", "proba": 0.1},
|
15 |
+
{"mode": "firstlast", "proba": 0.1},
|
16 |
+
{"mode": "random", "proba": 0.1}])
|
17 |
+
self.ref_num = cfg.get('REF_NUM', 1)
|
18 |
+
self.ref_color = cfg.get('REF_COLOR', 127.5)
|
19 |
+
self.return_dict = cfg.get('RETURN_DICT', True)
|
20 |
+
self.return_mask = cfg.get('RETURN_MASK', True)
|
21 |
+
|
22 |
+
|
23 |
+
def forward(self, frames, ref_cfg=None, ref_num=None, return_mask=None, return_dict=None):
|
24 |
+
return_mask = return_mask if return_mask is not None else self.return_mask
|
25 |
+
return_dict = return_dict if return_dict is not None else self.return_dict
|
26 |
+
ref_cfg = ref_cfg if ref_cfg is not None else self.ref_cfg
|
27 |
+
ref_cfg = [ref_cfg] if not isinstance(ref_cfg, list) else ref_cfg
|
28 |
+
probas = [item['proba'] if 'proba' in item else 1.0 / len(ref_cfg) for item in ref_cfg]
|
29 |
+
sel_ref_cfg = random.choices(ref_cfg, weights=probas, k=1)[0]
|
30 |
+
mode = sel_ref_cfg['mode'] if 'mode' in sel_ref_cfg else 'original'
|
31 |
+
ref_num = int(ref_num) if ref_num is not None else self.ref_num
|
32 |
+
|
33 |
+
frame_num = len(frames)
|
34 |
+
frame_num_range = list(range(frame_num))
|
35 |
+
if mode == "first":
|
36 |
+
sel_idx = frame_num_range[:ref_num]
|
37 |
+
elif mode == "last":
|
38 |
+
sel_idx = frame_num_range[-ref_num:]
|
39 |
+
elif mode == "firstlast":
|
40 |
+
sel_idx = frame_num_range[:ref_num] + frame_num_range[-ref_num:]
|
41 |
+
elif mode == "random":
|
42 |
+
sel_idx = random.sample(frame_num_range, ref_num)
|
43 |
+
else:
|
44 |
+
raise NotImplementedError
|
45 |
+
|
46 |
+
out_frames, out_masks = [], []
|
47 |
+
for i in range(frame_num):
|
48 |
+
if i in sel_idx:
|
49 |
+
out_frame = frames[i]
|
50 |
+
out_mask = np.zeros_like(frames[i][:, :, 0])
|
51 |
+
else:
|
52 |
+
out_frame = np.ones_like(frames[i]) * self.ref_color
|
53 |
+
out_mask = np.ones_like(frames[i][:, :, 0]) * 255
|
54 |
+
out_frames.append(out_frame)
|
55 |
+
out_masks.append(out_mask)
|
56 |
+
|
57 |
+
if return_dict:
|
58 |
+
ret_data = {"frames": out_frames}
|
59 |
+
if return_mask:
|
60 |
+
ret_data['masks'] = out_masks
|
61 |
+
return ret_data
|
62 |
+
else:
|
63 |
+
if return_mask:
|
64 |
+
return out_frames, out_masks
|
65 |
+
else:
|
66 |
+
return out_frames
|
67 |
+
|
68 |
+
|
69 |
+
|
70 |
+
class FrameRefExpandAnnotator:
|
71 |
+
para_dict = {}
|
72 |
+
|
73 |
+
def __init__(self, cfg, device=None):
|
74 |
+
# first / last / firstlast
|
75 |
+
self.ref_color = cfg.get('REF_COLOR', 127.5)
|
76 |
+
self.return_mask = cfg.get('RETURN_MASK', True)
|
77 |
+
self.return_dict = cfg.get('RETURN_DICT', True)
|
78 |
+
self.mode = cfg.get('MODE', "firstframe")
|
79 |
+
assert self.mode in ["firstframe", "lastframe", "firstlastframe", "firstclip", "lastclip", "firstlastclip", "all"]
|
80 |
+
|
81 |
+
def forward(self, image=None, image_2=None, frames=None, frames_2=None, mode=None, expand_num=None, return_mask=None, return_dict=None):
|
82 |
+
mode = mode if mode is not None else self.mode
|
83 |
+
return_mask = return_mask if return_mask is not None else self.return_mask
|
84 |
+
return_dict = return_dict if return_dict is not None else self.return_dict
|
85 |
+
|
86 |
+
if 'frame' in mode:
|
87 |
+
frames = [image] if image is not None and not isinstance(frames, list) else image
|
88 |
+
frames_2 = [image_2] if image_2 is not None and not isinstance(image_2, list) else image_2
|
89 |
+
|
90 |
+
expand_frames = [np.ones_like(frames[0]) * self.ref_color] * expand_num
|
91 |
+
expand_masks = [np.ones_like(frames[0][:, :, 0]) * 255] * expand_num
|
92 |
+
source_frames = frames
|
93 |
+
source_masks = [np.zeros_like(frames[0][:, :, 0])] * len(frames)
|
94 |
+
|
95 |
+
if mode in ["firstframe", "firstclip"]:
|
96 |
+
out_frames = source_frames + expand_frames
|
97 |
+
out_masks = source_masks + expand_masks
|
98 |
+
elif mode in ["lastframe", "lastclip"]:
|
99 |
+
out_frames = expand_frames + source_frames
|
100 |
+
out_masks = expand_masks + source_masks
|
101 |
+
elif mode in ["firstlastframe", "firstlastclip"]:
|
102 |
+
source_frames_2 = [align_frames(source_frames[0], f2) for f2 in frames_2]
|
103 |
+
source_masks_2 = [np.zeros_like(source_frames_2[0][:, :, 0])] * len(frames_2)
|
104 |
+
out_frames = source_frames + expand_frames + source_frames_2
|
105 |
+
out_masks = source_masks + expand_masks + source_masks_2
|
106 |
+
else:
|
107 |
+
raise NotImplementedError
|
108 |
+
|
109 |
+
if return_dict:
|
110 |
+
ret_data = {"frames": out_frames}
|
111 |
+
if return_mask:
|
112 |
+
ret_data['masks'] = out_masks
|
113 |
+
return ret_data
|
114 |
+
else:
|
115 |
+
if return_mask:
|
116 |
+
return out_frames, out_masks
|
117 |
+
else:
|
118 |
+
return out_frames
|
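`FrameRefExtractAnnotator` keeps a few reference frames untouched, fills the rest with `REF_COLOR`, and returns matching masks. An illustrative call that always keeps the first frame:

```python
# Keep the first frame as a reference, blank the rest, and get matching masks.
import numpy as np
from vace.annotators.frameref import FrameRefExtractAnnotator

anno = FrameRefExtractAnnotator(cfg={'REF_CFG': [{'mode': 'first', 'proba': 1.0}],
                                     'REF_NUM': 1})
frames = [np.zeros((480, 640, 3), dtype=np.uint8) for _ in range(8)]
out = anno.forward(frames)
# out['frames'][0] is untouched; the others are filled with REF_COLOR (127.5)
# out['masks'][0] is all zeros; the other masks are all 255
```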
vace/annotators/gdino.py
ADDED
@@ -0,0 +1,88 @@
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
import torchvision
|
8 |
+
from .utils import convert_to_numpy
|
9 |
+
|
10 |
+
|
11 |
+
class GDINOAnnotator:
|
12 |
+
def __init__(self, cfg, device=None):
|
13 |
+
try:
|
14 |
+
from groundingdino.util.inference import Model, load_model, load_image, predict
|
15 |
+
except:
|
16 |
+
import warnings
|
17 |
+
warnings.warn("please pip install groundingdino package, or you can refer to models/VACE-Annotators/gdino/groundingdino-0.1.0-cp310-cp310-linux_x86_64.whl")
|
18 |
+
|
19 |
+
grounding_dino_config_path = cfg['CONFIG_PATH']
|
20 |
+
grounding_dino_checkpoint_path = cfg['PRETRAINED_MODEL']
|
21 |
+
grounding_dino_tokenizer_path = cfg['TOKENIZER_PATH'] # TODO
|
22 |
+
self.box_threshold = cfg.get('BOX_THRESHOLD', 0.25)
|
23 |
+
self.text_threshold = cfg.get('TEXT_THRESHOLD', 0.2)
|
24 |
+
self.iou_threshold = cfg.get('IOU_THRESHOLD', 0.5)
|
25 |
+
self.use_nms = cfg.get('USE_NMS', True)
|
26 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
|
27 |
+
self.model = Model(model_config_path=grounding_dino_config_path,
|
28 |
+
model_checkpoint_path=grounding_dino_checkpoint_path,
|
29 |
+
device=self.device)
|
30 |
+
|
31 |
+
def forward(self, image, classes=None, caption=None):
|
32 |
+
image_bgr = convert_to_numpy(image)[..., ::-1] # bgr
|
33 |
+
|
34 |
+
if classes is not None:
|
35 |
+
classes = [classes] if isinstance(classes, str) else classes
|
36 |
+
detections = self.model.predict_with_classes(
|
37 |
+
image=image_bgr,
|
38 |
+
classes=classes,
|
39 |
+
box_threshold=self.box_threshold,
|
40 |
+
text_threshold=self.text_threshold
|
41 |
+
)
|
42 |
+
elif caption is not None:
|
43 |
+
detections, phrases = self.model.predict_with_caption(
|
44 |
+
image=image_bgr,
|
45 |
+
caption=caption,
|
46 |
+
box_threshold=self.box_threshold,
|
47 |
+
text_threshold=self.text_threshold
|
48 |
+
)
|
49 |
+
else:
|
50 |
+
raise NotImplementedError()
|
51 |
+
|
52 |
+
if self.use_nms:
|
53 |
+
nms_idx = torchvision.ops.nms(
|
54 |
+
torch.from_numpy(detections.xyxy),
|
55 |
+
torch.from_numpy(detections.confidence),
|
56 |
+
self.iou_threshold
|
57 |
+
).numpy().tolist()
|
58 |
+
detections.xyxy = detections.xyxy[nms_idx]
|
59 |
+
detections.confidence = detections.confidence[nms_idx]
|
60 |
+
detections.class_id = detections.class_id[nms_idx] if detections.class_id is not None else None
|
61 |
+
|
62 |
+
boxes = detections.xyxy
|
63 |
+
confidences = detections.confidence
|
64 |
+
class_ids = detections.class_id
|
65 |
+
class_names = [classes[_id] for _id in class_ids] if classes is not None else phrases
|
66 |
+
|
67 |
+
ret_data = {
|
68 |
+
"boxes": boxes.tolist() if boxes is not None else None,
|
69 |
+
"confidences": confidences.tolist() if confidences is not None else None,
|
70 |
+
"class_ids": class_ids.tolist() if class_ids is not None else None,
|
71 |
+
"class_names": class_names if class_names is not None else None,
|
72 |
+
}
|
73 |
+
return ret_data
|
74 |
+
|
75 |
+
|
76 |
+
class GDINORAMAnnotator:
|
77 |
+
def __init__(self, cfg, device=None):
|
78 |
+
from .ram import RAMAnnotator
|
79 |
+
from .gdino import GDINOAnnotator
|
80 |
+
self.ram_model = RAMAnnotator(cfg['RAM'], device=device)
|
81 |
+
self.gdino_model = GDINOAnnotator(cfg['GDINO'], device=device)
|
82 |
+
|
83 |
+
def forward(self, image):
|
84 |
+
ram_res = self.ram_model.forward(image)
|
85 |
+
classes = ram_res['tag_e'] if isinstance(ram_res, dict) else ram_res
|
86 |
+
gdino_res = self.gdino_model.forward(image, classes=classes)
|
87 |
+
return gdino_res
|
88 |
+
|
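A hedged sketch of `GDINOAnnotator` with open-vocabulary classes. Every path below is a placeholder, and the groundingdino wheel referenced in the warning above is assumed to be installed:

```python
# Hedged GDINOAnnotator sketch; all paths are placeholders, not shipped defaults.
from PIL import Image
from vace.annotators.gdino import GDINOAnnotator

cfg = {
    'CONFIG_PATH': 'models/gdino/GroundingDINO_SwinT_OGC.py',        # placeholder
    'PRETRAINED_MODEL': 'models/gdino/groundingdino_swint_ogc.pth',  # placeholder
    'TOKENIZER_PATH': 'models/gdino/bert-base-uncased',              # placeholder (see TODO above)
}
anno = GDINOAnnotator(cfg)
res = anno.forward(Image.open("street.jpg"), classes=["person", "car"])
print(res['boxes'], res['class_names'])   # NMS-filtered boxes with their class names
```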
vace/annotators/gray.py
ADDED
@@ -0,0 +1,24 @@
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import numpy as np
|
6 |
+
from .utils import convert_to_numpy
|
7 |
+
|
8 |
+
|
9 |
+
class GrayAnnotator:
|
10 |
+
def __init__(self, cfg):
|
11 |
+
pass
|
12 |
+
def forward(self, image):
|
13 |
+
image = convert_to_numpy(image)
|
14 |
+
gray_map = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
15 |
+
return gray_map[..., None].repeat(3, axis=2)
|
16 |
+
|
17 |
+
|
18 |
+
class GrayVideoAnnotator(GrayAnnotator):
|
19 |
+
def forward(self, frames):
|
20 |
+
ret_frames = []
|
21 |
+
for frame in frames:
|
22 |
+
anno_frame = super().forward(np.array(frame))
|
23 |
+
ret_frames.append(anno_frame)
|
24 |
+
return ret_frames
|
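`GrayAnnotator` is a pure OpenCV transform with no weights, so it can be exercised directly:

```python
# GrayVideoAnnotator converts each frame to grayscale replicated over 3 channels.
import numpy as np
from vace.annotators.gray import GrayVideoAnnotator

anno = GrayVideoAnnotator(cfg={})
frames = [np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8) for _ in range(3)]
gray_frames = anno.forward(frames)
print(gray_frames[0].shape)   # (64, 64, 3), all three channels identical
```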
vace/annotators/inpainting.py
ADDED
@@ -0,0 +1,283 @@
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
import cv2
|
4 |
+
import math
|
5 |
+
import random
|
6 |
+
from abc import ABCMeta
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
from PIL import Image, ImageDraw
|
11 |
+
from .utils import convert_to_numpy, convert_to_pil, single_rle_to_mask, get_mask_box, read_video_one_frame
|
12 |
+
|
13 |
+
class InpaintingAnnotator:
|
14 |
+
def __init__(self, cfg, device=None):
|
15 |
+
self.use_aug = cfg.get('USE_AUG', True)
|
16 |
+
self.return_mask = cfg.get('RETURN_MASK', True)
|
17 |
+
self.return_source = cfg.get('RETURN_SOURCE', True)
|
18 |
+
self.mask_color = cfg.get('MASK_COLOR', 128)
|
19 |
+
self.mode = cfg.get('MODE', "mask")
|
20 |
+
assert self.mode in ["salient", "mask", "bbox", "salientmasktrack", "salientbboxtrack", "maskpointtrack", "maskbboxtrack", "masktrack", "bboxtrack", "label", "caption", "all"]
|
21 |
+
if self.mode in ["salient", "salienttrack"]:
|
22 |
+
from .salient import SalientAnnotator
|
23 |
+
self.salient_model = SalientAnnotator(cfg['SALIENT'], device=device)
|
24 |
+
if self.mode in ['masktrack', 'bboxtrack', 'salienttrack']:
|
25 |
+
from .sam2 import SAM2ImageAnnotator
|
26 |
+
self.sam2_model = SAM2ImageAnnotator(cfg['SAM2'], device=device)
|
27 |
+
if self.mode in ['label', 'caption']:
|
28 |
+
from .gdino import GDINOAnnotator
|
29 |
+
from .sam2 import SAM2ImageAnnotator
|
30 |
+
self.gdino_model = GDINOAnnotator(cfg['GDINO'], device=device)
|
31 |
+
self.sam2_model = SAM2ImageAnnotator(cfg['SAM2'], device=device)
|
32 |
+
if self.mode in ['all']:
|
33 |
+
from .salient import SalientAnnotator
|
34 |
+
from .gdino import GDINOAnnotator
|
35 |
+
from .sam2 import SAM2ImageAnnotator
|
36 |
+
self.salient_model = SalientAnnotator(cfg['SALIENT'], device=device)
|
37 |
+
self.gdino_model = GDINOAnnotator(cfg['GDINO'], device=device)
|
38 |
+
self.sam2_model = SAM2ImageAnnotator(cfg['SAM2'], device=device)
|
39 |
+
if self.use_aug:
|
40 |
+
from .maskaug import MaskAugAnnotator
|
41 |
+
self.maskaug_anno = MaskAugAnnotator(cfg={})
|
42 |
+
|
43 |
+
def apply_plain_mask(self, image, mask, mask_color):
|
44 |
+
bool_mask = mask > 0
|
45 |
+
out_image = image.copy()
|
46 |
+
out_image[bool_mask] = mask_color
|
47 |
+
out_mask = np.where(bool_mask, 255, 0).astype(np.uint8)
|
48 |
+
return out_image, out_mask
|
49 |
+
|
50 |
+
def apply_seg_mask(self, image, mask, mask_color, mask_cfg=None):
|
51 |
+
out_mask = (mask * 255).astype('uint8')
|
52 |
+
if self.use_aug and mask_cfg is not None:
|
53 |
+
out_mask = self.maskaug_anno.forward(out_mask, mask_cfg)
|
54 |
+
bool_mask = out_mask > 0
|
55 |
+
out_image = image.copy()
|
56 |
+
out_image[bool_mask] = mask_color
|
57 |
+
return out_image, out_mask
|
58 |
+
|
59 |
+
def forward(self, image=None, mask=None, bbox=None, label=None, caption=None, mode=None, return_mask=None, return_source=None, mask_color=None, mask_cfg=None):
|
60 |
+
mode = mode if mode is not None else self.mode
|
61 |
+
return_mask = return_mask if return_mask is not None else self.return_mask
|
62 |
+
return_source = return_source if return_source is not None else self.return_source
|
63 |
+
mask_color = mask_color if mask_color is not None else self.mask_color
|
64 |
+
|
65 |
+
image = convert_to_numpy(image)
|
66 |
+
out_image, out_mask = None, None
|
67 |
+
if mode in ['salient']:
|
68 |
+
mask = self.salient_model.forward(image)
|
69 |
+
out_image, out_mask = self.apply_plain_mask(image, mask, mask_color)
|
70 |
+
elif mode in ['mask']:
|
71 |
+
mask_h, mask_w = mask.shape[:2]
|
72 |
+
h, w = image.shape[:2]
|
73 |
+
if (mask_h != h) or (mask_w != w):
|
74 |
+
mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
|
75 |
+
out_image, out_mask = self.apply_plain_mask(image, mask, mask_color)
|
76 |
+
elif mode in ['bbox']:
|
77 |
+
x1, y1, x2, y2 = bbox
|
78 |
+
h, w = image.shape[:2]
|
79 |
+
x1, y1 = int(max(0, x1)), int(max(0, y1))
|
80 |
+
x2, y2 = int(min(w, x2)), int(min(h, y2))
|
81 |
+
out_image = image.copy()
|
82 |
+
out_image[y1:y2, x1:x2] = mask_color
|
83 |
+
out_mask = np.zeros((h, w), dtype=np.uint8)
|
84 |
+
out_mask[y1:y2, x1:x2] = 255
|
85 |
+
elif mode in ['salientmasktrack']:
|
86 |
+
mask = self.salient_model.forward(image)
|
87 |
+
resize_mask = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_NEAREST)
|
88 |
+
out_mask = self.sam2_model.forward(image=image, mask=resize_mask, task_type='mask', return_mask=True)
|
89 |
+
out_image, out_mask = self.apply_seg_mask(image, out_mask, mask_color, mask_cfg)
|
90 |
+
elif mode in ['salientbboxtrack']:
|
91 |
+
mask = self.salient_model.forward(image)
|
92 |
+
bbox = get_mask_box(np.array(mask), threshold=1)
|
93 |
+
out_mask = self.sam2_model.forward(image=image, input_box=bbox, task_type='input_box', return_mask=True)
|
94 |
+
out_image, out_mask = self.apply_seg_mask(image, out_mask, mask_color, mask_cfg)
|
95 |
+
elif mode in ['maskpointtrack']:
|
96 |
+
out_mask = self.sam2_model.forward(image=image, mask=mask, task_type='mask_point', return_mask=True)
|
97 |
+
out_image, out_mask = self.apply_seg_mask(image, out_mask, mask_color, mask_cfg)
|
98 |
+
elif mode in ['maskbboxtrack']:
|
99 |
+
out_mask = self.sam2_model.forward(image=image, mask=mask, task_type='mask_box', return_mask=True)
|
100 |
+
out_image, out_mask = self.apply_seg_mask(image, out_mask, mask_color, mask_cfg)
|
101 |
+
elif mode in ['masktrack']:
|
102 |
+
resize_mask = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_NEAREST)
|
103 |
+
out_mask = self.sam2_model.forward(image=image, mask=resize_mask, task_type='mask', return_mask=True)
|
104 |
+
out_image, out_mask = self.apply_seg_mask(image, out_mask, mask_color, mask_cfg)
|
105 |
+
elif mode in ['bboxtrack']:
|
106 |
+
out_mask = self.sam2_model.forward(image=image, input_box=bbox, task_type='input_box', return_mask=True)
|
107 |
+
out_image, out_mask = self.apply_seg_mask(image, out_mask, mask_color, mask_cfg)
|
108 |
+
elif mode in ['label']:
|
109 |
+
gdino_res = self.gdino_model.forward(image, classes=label)
|
110 |
+
if 'boxes' in gdino_res and len(gdino_res['boxes']) > 0:
|
111 |
+
bboxes = gdino_res['boxes'][0]
|
112 |
+
else:
|
113 |
+
raise ValueError(f"Unable to find the corresponding boxes of label: {label}")
|
114 |
+
out_mask = self.sam2_model.forward(image=image, input_box=bboxes, task_type='input_box', return_mask=True)
|
115 |
+
out_image, out_mask = self.apply_seg_mask(image, out_mask, mask_color, mask_cfg)
|
116 |
+
elif mode in ['caption']:
|
117 |
+
gdino_res = self.gdino_model.forward(image, caption=caption)
|
118 |
+
if 'boxes' in gdino_res and len(gdino_res['boxes']) > 0:
|
119 |
+
bboxes = gdino_res['boxes'][0]
|
120 |
+
else:
|
121 |
+
raise ValueError(f"Unable to find the corresponding boxes of caption: {caption}")
|
122 |
+
out_mask = self.sam2_model.forward(image=image, input_box=bboxes, task_type='input_box', return_mask=True)
|
123 |
+
out_image, out_mask = self.apply_seg_mask(image, out_mask, mask_color, mask_cfg)
|
124 |
+
|
125 |
+
ret_data = {"image": out_image}
|
126 |
+
if return_mask:
|
127 |
+
ret_data["mask"] = out_mask
|
128 |
+
if return_source:
|
129 |
+
ret_data["src_image"] = image
|
130 |
+
return ret_data
|
131 |
+
|
132 |
+
|
133 |
+
|
134 |
+
|
135 |
+
class InpaintingVideoAnnotator:
|
136 |
+
def __init__(self, cfg, device=None):
|
137 |
+
self.use_aug = cfg.get('USE_AUG', True)
|
138 |
+
self.return_frame = cfg.get('RETURN_FRAME', True)
|
139 |
+
self.return_mask = cfg.get('RETURN_MASK', True)
|
140 |
+
self.return_source = cfg.get('RETURN_SOURCE', True)
|
141 |
+
self.mask_color = cfg.get('MASK_COLOR', 128)
|
142 |
+
self.mode = cfg.get('MODE', "mask")
|
143 |
+
assert self.mode in ["salient", "mask", "bbox", "salientmasktrack", "salientbboxtrack", "maskpointtrack", "maskbboxtrack", "masktrack", "bboxtrack", "label", "caption", "all"]
|
144 |
+
if self.mode in ["salient", "salienttrack"]:
|
145 |
+
from .salient import SalientAnnotator
|
146 |
+
self.salient_model = SalientAnnotator(cfg['SALIENT'], device=device)
|
147 |
+
if self.mode in ['masktrack', 'bboxtrack', 'salienttrack']:
|
148 |
+
from .sam2 import SAM2VideoAnnotator
|
149 |
+
self.sam2_model = SAM2VideoAnnotator(cfg['SAM2'], device=device)
|
150 |
+
if self.mode in ['label', 'caption']:
|
151 |
+
from .gdino import GDINOAnnotator
|
152 |
+
from .sam2 import SAM2VideoAnnotator
|
153 |
+
self.gdino_model = GDINOAnnotator(cfg['GDINO'], device=device)
|
154 |
+
self.sam2_model = SAM2VideoAnnotator(cfg['SAM2'], device=device)
|
155 |
+
if self.mode in ['all']:
|
156 |
+
from .salient import SalientAnnotator
|
157 |
+
from .gdino import GDINOAnnotator
|
158 |
+
from .sam2 import SAM2VideoAnnotator
|
159 |
+
self.salient_model = SalientAnnotator(cfg['SALIENT'], device=device)
|
160 |
+
self.gdino_model = GDINOAnnotator(cfg['GDINO'], device=device)
|
161 |
+
self.sam2_model = SAM2VideoAnnotator(cfg['SAM2'], device=device)
|
162 |
+
if self.use_aug:
|
163 |
+
from .maskaug import MaskAugAnnotator
|
164 |
+
self.maskaug_anno = MaskAugAnnotator(cfg={})
|
165 |
+
|
166 |
+
def apply_plain_mask(self, frames, mask, mask_color, return_frame=True):
|
167 |
+
out_frames = []
|
168 |
+
num_frames = len(frames)
|
169 |
+
bool_mask = mask > 0
|
170 |
+
out_masks = [np.where(bool_mask, 255, 0).astype(np.uint8)] * num_frames
|
171 |
+
if not return_frame:
|
172 |
+
return None, out_masks
|
173 |
+
for i in range(num_frames):
|
174 |
+
masked_frame = frames[i].copy()
|
175 |
+
masked_frame[bool_mask] = mask_color
|
176 |
+
out_frames.append(masked_frame)
|
177 |
+
return out_frames, out_masks
|
178 |
+
|
179 |
+
def apply_seg_mask(self, mask_data, frames, mask_color, mask_cfg=None, return_frame=True):
|
180 |
+
out_frames = []
|
181 |
+
out_masks = [(single_rle_to_mask(val[0]["mask"]) * 255).astype('uint8') for key, val in mask_data['annotations'].items()]
|
182 |
+
if not return_frame:
|
183 |
+
return None, out_masks
|
184 |
+
num_frames = min(len(out_masks), len(frames))
|
185 |
+
for i in range(num_frames):
|
186 |
+
sub_mask = out_masks[i]
|
187 |
+
if self.use_aug and mask_cfg is not None:
|
188 |
+
sub_mask = self.maskaug_anno.forward(sub_mask, mask_cfg)
|
189 |
+
out_masks[i] = sub_mask
|
190 |
+
bool_mask = sub_mask > 0
|
191 |
+
masked_frame = frames[i].copy()
|
192 |
+
masked_frame[bool_mask] = mask_color
|
193 |
+
out_frames.append(masked_frame)
|
194 |
+
out_masks = out_masks[:num_frames]
|
195 |
+
return out_frames, out_masks
|
196 |
+
|
197 |
+
def forward(self, frames=None, video=None, mask=None, bbox=None, label=None, caption=None, mode=None, return_frame=None, return_mask=None, return_source=None, mask_color=None, mask_cfg=None):
|
198 |
+
mode = mode if mode is not None else self.mode
|
199 |
+
return_frame = return_frame if return_frame is not None else self.return_frame
|
200 |
+
return_mask = return_mask if return_mask is not None else self.return_mask
|
201 |
+
return_source = return_source if return_source is not None else self.return_source
|
202 |
+
mask_color = mask_color if mask_color is not None else self.mask_color
|
203 |
+
|
204 |
+
out_frames, out_masks = [], []
|
205 |
+
if mode in ['salient']:
|
206 |
+
first_frame = frames[0] if frames is not None else read_video_one_frame(video)
|
207 |
+
mask = self.salient_model.forward(first_frame)
|
208 |
+
out_frames, out_masks = self.apply_plain_mask(frames, mask, mask_color, return_frame)
|
209 |
+
elif mode in ['mask']:
|
210 |
+
first_frame = frames[0] if frames is not None else read_video_one_frame(video)
|
211 |
+
mask_h, mask_w = mask.shape[:2]
|
212 |
+
h, w = first_frame.shape[:2]
|
213 |
+
if (mask_h != h) or (mask_w != w):
|
214 |
+
mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
|
215 |
+
out_frames, out_masks = self.apply_plain_mask(frames, mask, mask_color, return_frame)
|
216 |
+
elif mode in ['bbox']:
|
217 |
+
first_frame = frames[0] if frames is not None else read_video_one_frame(video)
|
218 |
+
num_frames = len(frames)
|
219 |
+
x1, y1, x2, y2 = bbox
|
220 |
+
h, w = first_frame.shape[:2]
|
221 |
+
x1, y1 = int(max(0, x1)), int(max(0, y1))
|
222 |
+
x2, y2 = int(min(w, x2)), int(min(h, y2))
|
223 |
+
mask = np.zeros((h, w), dtype=np.uint8)
|
224 |
+
mask[y1:y2, x1:x2] = 255
|
225 |
+
out_masks = [mask] * num_frames
|
226 |
+
if not return_frame:
|
227 |
+
out_frames = None
|
228 |
+
else:
|
229 |
+
for i in range(num_frames):
|
230 |
+
masked_frame = frames[i].copy()
|
231 |
+
masked_frame[y1:y2, x1:x2] = mask_color
|
232 |
+
out_frames.append(masked_frame)
|
233 |
+
elif mode in ['salientmasktrack']:
|
234 |
+
first_frame = frames[0] if frames is not None else read_video_one_frame(video)
|
235 |
+
salient_mask = self.salient_model.forward(first_frame)
|
236 |
+
mask_data = self.sam2_model.forward(video=video, mask=salient_mask, task_type='mask')
|
237 |
+
out_frames, out_masks = self.apply_seg_mask(mask_data, frames, mask_color, mask_cfg, return_frame)
|
238 |
+
elif mode in ['salientbboxtrack']:
|
239 |
+
first_frame = frames[0] if frames is not None else read_video_one_frame(video)
|
240 |
+
salient_mask = self.salient_model.forward(first_frame)
|
241 |
+
bbox = get_mask_box(np.array(salient_mask), threshold=1)
|
242 |
+
mask_data = self.sam2_model.forward(video=video, input_box=bbox, task_type='input_box')
|
243 |
+
out_frames, out_masks = self.apply_seg_mask(mask_data, frames, mask_color, mask_cfg, return_frame)
|
244 |
+
elif mode in ['maskpointtrack']:
|
245 |
+
mask_data = self.sam2_model.forward(video=video, mask=mask, task_type='mask_point')
|
246 |
+
out_frames, out_masks = self.apply_seg_mask(mask_data, frames, mask_color, mask_cfg, return_frame)
|
247 |
+
elif mode in ['maskbboxtrack']:
|
248 |
+
mask_data = self.sam2_model.forward(video=video, mask=mask, task_type='mask_box')
|
249 |
+
out_frames, out_masks = self.apply_seg_mask(mask_data, frames, mask_color, mask_cfg, return_frame)
|
250 |
+
elif mode in ['masktrack']:
|
251 |
+
mask_data = self.sam2_model.forward(video=video, mask=mask, task_type='mask')
|
252 |
+
out_frames, out_masks = self.apply_seg_mask(mask_data, frames, mask_color, mask_cfg, return_frame)
|
253 |
+
elif mode in ['bboxtrack']:
|
254 |
+
mask_data = self.sam2_model.forward(video=video, input_box=bbox, task_type='input_box')
|
255 |
+
out_frames, out_masks = self.apply_seg_mask(mask_data, frames, mask_color, mask_cfg, return_frame)
|
256 |
+
elif mode in ['label']:
|
257 |
+
first_frame = frames[0] if frames is not None else read_video_one_frame(video)
|
258 |
+
gdino_res = self.gdino_model.forward(first_frame, classes=label)
|
259 |
+
if 'boxes' in gdino_res and len(gdino_res['boxes']) > 0:
|
260 |
+
bboxes = gdino_res['boxes'][0]
|
261 |
+
else:
|
262 |
+
raise ValueError(f"Unable to find the corresponding boxes of label: {label}")
|
263 |
+
mask_data = self.sam2_model.forward(video=video, input_box=bboxes, task_type='input_box')
|
264 |
+
out_frames, out_masks = self.apply_seg_mask(mask_data, frames, mask_color, mask_cfg, return_frame)
|
265 |
+
elif mode in ['caption']:
|
266 |
+
first_frame = frames[0] if frames is not None else read_video_one_frame(video)
|
267 |
+
gdino_res = self.gdino_model.forward(first_frame, caption=caption)
|
268 |
+
if 'boxes' in gdino_res and len(gdino_res['boxes']) > 0:
|
269 |
+
bboxes = gdino_res['boxes'][0]
|
270 |
+
else:
|
271 |
+
raise ValueError(f"Unable to find the corresponding boxes of caption: {caption}")
|
272 |
+
mask_data = self.sam2_model.forward(video=video, input_box=bboxes, task_type='input_box')
|
273 |
+
out_frames, out_masks = self.apply_seg_mask(mask_data, frames, mask_color, mask_cfg, return_frame)
|
274 |
+
|
275 |
+
ret_data = {}
|
276 |
+
if return_frame:
|
277 |
+
ret_data["frames"] = out_frames
|
278 |
+
if return_mask:
|
279 |
+
ret_data["masks"] = out_masks
|
280 |
+
return ret_data
|
281 |
+
|
282 |
+
|
283 |
+
|
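The simplest path through `InpaintingAnnotator` is `bbox` mode, which needs none of the salient/GDINO/SAM2 dependencies. A minimal sketch:

```python
# Minimal InpaintingAnnotator run in 'bbox' mode (no auxiliary models required).
import numpy as np
from vace.annotators.inpainting import InpaintingAnnotator

anno = InpaintingAnnotator(cfg={'MODE': 'bbox', 'USE_AUG': False})
image = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
out = anno.forward(image=image, bbox=[100, 100, 300, 240])
masked, mask, src = out['image'], out['mask'], out['src_image']
# pixels inside the box are set to MASK_COLOR (128); mask is 255 there, 0 elsewhere
```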
vace/annotators/layout.py
ADDED
@@ -0,0 +1,161 @@
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from .utils import convert_to_numpy
|
8 |
+
|
9 |
+
|
10 |
+
class LayoutBboxAnnotator:
|
11 |
+
def __init__(self, cfg, device=None):
|
12 |
+
self.bg_color = cfg.get('BG_COLOR', [255, 255, 255])
|
13 |
+
self.box_color = cfg.get('BOX_COLOR', [0, 0, 0])
|
14 |
+
self.frame_size = cfg.get('FRAME_SIZE', [720, 1280]) # [H, W]
|
15 |
+
self.num_frames = cfg.get('NUM_FRAMES', 81)
|
16 |
+
ram_tag_color_path = cfg.get('RAM_TAG_COLOR_PATH', None)
|
17 |
+
self.color_dict = {'default': tuple(self.box_color)}
|
18 |
+
if ram_tag_color_path is not None:
|
19 |
+
lines = [id_name_color.strip().split('#;#') for id_name_color in open(ram_tag_color_path).readlines()]
|
20 |
+
self.color_dict.update({id_name_color[1]: tuple(eval(id_name_color[2])) for id_name_color in lines})
|
21 |
+
|
22 |
+
def forward(self, bbox, frame_size=None, num_frames=None, label=None, color=None):
|
23 |
+
frame_size = frame_size if frame_size is not None else self.frame_size
|
24 |
+
num_frames = num_frames if num_frames is not None else self.num_frames
|
25 |
+
assert len(bbox) == 2, 'bbox should be a list of two elements (start_bbox & end_bbox)'
|
26 |
+
# frame_size = [H, W]
|
27 |
+
# bbox = [x1, y1, x2, y2]
|
28 |
+
label = label[0] if label is not None and isinstance(label, list) else label
|
29 |
+
if label is not None and label in self.color_dict:
|
30 |
+
box_color = self.color_dict[label]
|
31 |
+
elif color is not None:
|
32 |
+
box_color = color
|
33 |
+
else:
|
34 |
+
box_color = self.color_dict['default']
|
35 |
+
start_bbox, end_bbox = bbox
|
36 |
+
start_bbox = [start_bbox[0], start_bbox[1], start_bbox[2] - start_bbox[0], start_bbox[3] - start_bbox[1]]
|
37 |
+
start_bbox = np.array(start_bbox, dtype=np.float32)
|
38 |
+
end_bbox = [end_bbox[0], end_bbox[1], end_bbox[2] - end_bbox[0], end_bbox[3] - end_bbox[1]]
|
39 |
+
end_bbox = np.array(end_bbox, dtype=np.float32)
|
40 |
+
bbox_increment = (end_bbox - start_bbox) / num_frames
|
41 |
+
ret_frames = []
|
42 |
+
for frame_idx in range(num_frames):
|
43 |
+
frame = np.zeros((frame_size[0], frame_size[1], 3), dtype=np.uint8)
|
44 |
+
frame[:] = self.bg_color
|
45 |
+
current_bbox = start_bbox + bbox_increment * frame_idx
|
46 |
+
current_bbox = current_bbox.astype(int)
|
47 |
+
x, y, w, h = current_bbox
|
48 |
+
cv2.rectangle(frame, (x, y), (x + w, y + h), box_color, 2)
|
49 |
+
ret_frames.append(frame[..., ::-1])
|
50 |
+
return ret_frames
|
51 |
+
|
52 |
+
|
53 |
+
|
54 |
+
|
55 |
+
class LayoutMaskAnnotator:
|
56 |
+
def __init__(self, cfg, device=None):
|
57 |
+
self.use_aug = cfg.get('USE_AUG', False)
|
58 |
+
self.bg_color = cfg.get('BG_COLOR', [255, 255, 255])
|
59 |
+
self.box_color = cfg.get('BOX_COLOR', [0, 0, 0])
|
60 |
+
ram_tag_color_path = cfg.get('RAM_TAG_COLOR_PATH', None)
|
61 |
+
self.color_dict = {'default': tuple(self.box_color)}
|
62 |
+
if ram_tag_color_path is not None:
|
63 |
+
lines = [id_name_color.strip().split('#;#') for id_name_color in open(ram_tag_color_path).readlines()]
|
64 |
+
self.color_dict.update({id_name_color[1]: tuple(eval(id_name_color[2])) for id_name_color in lines})
|
65 |
+
if self.use_aug:
|
66 |
+
from .maskaug import MaskAugAnnotator
|
67 |
+
self.maskaug_anno = MaskAugAnnotator(cfg={})
|
68 |
+
|
69 |
+
|
70 |
+
def find_contours(self, mask):
|
71 |
+
contours, hier = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
72 |
+
return contours
|
73 |
+
|
74 |
+
def draw_contours(self, canvas, contour, color):
|
75 |
+
canvas = np.ascontiguousarray(canvas, dtype=np.uint8)
|
76 |
+
canvas = cv2.drawContours(canvas, contour, -1, color, thickness=3)
|
77 |
+
return canvas
|
78 |
+
|
79 |
+
def forward(self, mask=None, color=None, label=None, mask_cfg=None):
|
80 |
+
if not isinstance(mask, list):
|
81 |
+
is_batch = False
|
82 |
+
mask = [mask]
|
83 |
+
else:
|
84 |
+
is_batch = True
|
85 |
+
|
86 |
+
if label is not None and label in self.color_dict:
|
87 |
+
color = self.color_dict[label]
|
88 |
+
elif color is not None:
|
89 |
+
color = color
|
90 |
+
else:
|
91 |
+
color = self.color_dict['default']
|
92 |
+
|
93 |
+
ret_data = []
|
94 |
+
for sub_mask in mask:
|
95 |
+
sub_mask = convert_to_numpy(sub_mask)
|
96 |
+
if self.use_aug:
|
97 |
+
sub_mask = self.maskaug_anno.forward(sub_mask, mask_cfg)
|
98 |
+
canvas = np.ones((sub_mask.shape[0], sub_mask.shape[1], 3)) * 255
|
99 |
+
contour = self.find_contours(sub_mask)
|
100 |
+
frame = self.draw_contours(canvas, contour, color)
|
101 |
+
ret_data.append(frame)
|
102 |
+
|
103 |
+
if is_batch:
|
104 |
+
return ret_data
|
105 |
+
else:
|
106 |
+
return ret_data[0]
|
107 |
+
|
108 |
+
|
109 |
+
|
110 |
+
|
111 |
+
class LayoutTrackAnnotator:
|
112 |
+
def __init__(self, cfg, device=None):
|
113 |
+
self.use_aug = cfg.get('USE_AUG', False)
|
114 |
+
self.bg_color = cfg.get('BG_COLOR', [255, 255, 255])
|
115 |
+
self.box_color = cfg.get('BOX_COLOR', [0, 0, 0])
|
116 |
+
ram_tag_color_path = cfg.get('RAM_TAG_COLOR_PATH', None)
|
117 |
+
self.color_dict = {'default': tuple(self.box_color)}
|
118 |
+
if ram_tag_color_path is not None:
|
119 |
+
lines = [id_name_color.strip().split('#;#') for id_name_color in open(ram_tag_color_path).readlines()]
|
120 |
+
self.color_dict.update({id_name_color[1]: tuple(eval(id_name_color[2])) for id_name_color in lines})
|
121 |
+
if self.use_aug:
|
122 |
+
from .maskaug import MaskAugAnnotator
|
123 |
+
self.maskaug_anno = MaskAugAnnotator(cfg={})
|
124 |
+
from .inpainting import InpaintingVideoAnnotator
|
125 |
+
self.inpainting_anno = InpaintingVideoAnnotator(cfg=cfg['INPAINTING'])
|
126 |
+
|
127 |
+
def find_contours(self, mask):
|
128 |
+
contours, hier = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
129 |
+
return contours
|
130 |
+
|
131 |
+
def draw_contours(self, canvas, contour, color):
|
132 |
+
canvas = np.ascontiguousarray(canvas, dtype=np.uint8)
|
133 |
+
canvas = cv2.drawContours(canvas, contour, -1, color, thickness=3)
|
134 |
+
return canvas
|
135 |
+
|
136 |
+
def forward(self, color=None, mask_cfg=None, frames=None, video=None, mask=None, bbox=None, label=None, caption=None, mode=None):
|
137 |
+
inp_data = self.inpainting_anno.forward(frames, video, mask, bbox, label, caption, mode)
|
138 |
+
inp_masks = inp_data['masks']
|
139 |
+
|
140 |
+
label = label[0] if label is not None and isinstance(label, list) else label
|
141 |
+
if label is not None and label in self.color_dict:
|
142 |
+
color = self.color_dict[label]
|
143 |
+
elif color is not None:
|
144 |
+
color = color
|
145 |
+
else:
|
146 |
+
color = self.color_dict['default']
|
147 |
+
|
148 |
+
num_frames = len(inp_masks)
|
149 |
+
ret_data = []
|
150 |
+
for i in range(num_frames):
|
151 |
+
sub_mask = inp_masks[i]
|
152 |
+
if self.use_aug and mask_cfg is not None:
|
153 |
+
sub_mask = self.maskaug_anno.forward(sub_mask, mask_cfg)
|
154 |
+
canvas = np.ones((sub_mask.shape[0], sub_mask.shape[1], 3)) * 255
|
155 |
+
contour = self.find_contours(sub_mask)
|
156 |
+
frame = self.draw_contours(canvas, contour, color)
|
157 |
+
ret_data.append(frame)
|
158 |
+
|
159 |
+
return ret_data
|
160 |
+
|
161 |
+
|
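`LayoutBboxAnnotator` interpolates a bounding box linearly between a start and an end position and rasterizes one frame per step. An illustrative call without a RAM tag color file:

```python
# Render a bounding box drifting linearly from a start to an end position.
from vace.annotators.layout import LayoutBboxAnnotator

anno = LayoutBboxAnnotator(cfg={'FRAME_SIZE': [480, 832], 'NUM_FRAMES': 49})
frames = anno.forward(bbox=[[100, 100, 300, 240],    # start box (x1, y1, x2, y2)
                            [400, 160, 600, 300]])   # end box
print(len(frames), frames[0].shape)                  # 49 frames of shape (480, 832, 3)
```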
vace/annotators/mask.py
ADDED
@@ -0,0 +1,79 @@
|
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.

import numpy as np
from scipy.spatial import ConvexHull
from skimage.draw import polygon
from scipy import ndimage

from .utils import convert_to_numpy


class MaskDrawAnnotator:
    def __init__(self, cfg, device=None):
        self.mode = cfg.get('MODE', 'maskpoint')
        self.return_dict = cfg.get('RETURN_DICT', True)
        assert self.mode in ['maskpoint', 'maskbbox', 'mask', 'bbox']

    def forward(self,
                mask=None,
                image=None,
                bbox=None,
                mode=None,
                return_dict=None):
        mode = mode if mode is not None else self.mode
        return_dict = return_dict if return_dict is not None else self.return_dict

        mask = convert_to_numpy(mask) if mask is not None else None
        image = convert_to_numpy(image) if image is not None else None

        mask_shape = mask.shape
        if mode == 'maskpoint':
            scribble = mask.transpose(1, 0)
            labeled_array, num_features = ndimage.label(scribble >= 255)
            centers = ndimage.center_of_mass(scribble, labeled_array,
                                             range(1, num_features + 1))
            centers = np.array(centers)
            out_mask = np.zeros(mask_shape, dtype=np.uint8)
            hull = ConvexHull(centers)
            hull_vertices = centers[hull.vertices]
            rr, cc = polygon(hull_vertices[:, 1], hull_vertices[:, 0], mask_shape)
            out_mask[rr, cc] = 255
        elif mode == 'maskbbox':
            scribble = mask.transpose(1, 0)
            labeled_array, num_features = ndimage.label(scribble >= 255)
            centers = ndimage.center_of_mass(scribble, labeled_array,
                                             range(1, num_features + 1))
            centers = np.array(centers)
            # (x1, y1, x2, y2)
            x_min = centers[:, 0].min()
            x_max = centers[:, 0].max()
            y_min = centers[:, 1].min()
            y_max = centers[:, 1].max()
            out_mask = np.zeros(mask_shape, dtype=np.uint8)
            out_mask[int(y_min):int(y_max) + 1, int(x_min):int(x_max) + 1] = 255
            if image is not None:
                out_image = image[int(y_min):int(y_max) + 1, int(x_min):int(x_max) + 1]
        elif mode == 'bbox':
            if isinstance(bbox, list):
                bbox = np.array(bbox)
            x_min, y_min, x_max, y_max = bbox
            out_mask = np.zeros(mask_shape, dtype=np.uint8)
            out_mask[int(y_min):int(y_max) + 1, int(x_min):int(x_max) + 1] = 255
            if image is not None:
                out_image = image[int(y_min):int(y_max) + 1, int(x_min):int(x_max) + 1]
        elif mode == 'mask':
            out_mask = mask
        else:
            raise NotImplementedError

        if return_dict:
            if image is not None:
                return {"image": out_image, "mask": out_mask}
            else:
                return {"mask": out_mask}
        else:
            if image is not None:
                return out_image, out_mask
            else:
                return out_mask
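Usage note (not part of the diff): a minimal sketch of MaskDrawAnnotator, assuming convert_to_numpy passes NumPy arrays through unchanged; the cfg keys mirror the defaults read in __init__:

import numpy as np
from vace.annotators.mask import MaskDrawAnnotator

anno = MaskDrawAnnotator(cfg={'MODE': 'maskbbox', 'RETURN_DICT': True})

mask = np.zeros((256, 256), dtype=np.uint8)
mask[100:120, 60:80] = 255                      # scribble / region of interest
image = np.zeros((256, 256, 3), dtype=np.uint8)

out = anno.forward(mask=mask, image=image)      # {'image': crop, 'mask': box mask}
print(out['mask'].shape, out['image'].shape)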
vace/annotators/maskaug.py
ADDED
@@ -0,0 +1,181 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.

import random
from functools import partial

import cv2
import numpy as np
from PIL import Image, ImageDraw

from .utils import convert_to_numpy


class MaskAugAnnotator:
    def __init__(self, cfg, device=None):
        # original / original_expand / hull / hull_expand / bbox / bbox_expand
        self.mask_cfg = cfg.get('MASK_CFG', [{"mode": "original", "proba": 0.1},
                                             {"mode": "original_expand", "proba": 0.1},
                                             {"mode": "hull", "proba": 0.1},
                                             {"mode": "hull_expand", "proba": 0.1, "kwargs": {"expand_ratio": 0.2}},
                                             {"mode": "bbox", "proba": 0.1},
                                             {"mode": "bbox_expand", "proba": 0.1, "kwargs": {"min_expand_ratio": 0.2, "max_expand_ratio": 0.5}}])

    def forward(self, mask, mask_cfg=None):
        mask_cfg = mask_cfg if mask_cfg is not None else self.mask_cfg
        if not isinstance(mask, list):
            is_batch = False
            masks = [mask]
        else:
            is_batch = True
            masks = mask

        mask_func = self.get_mask_func(mask_cfg)
        # print(mask_func)
        aug_masks = []
        for submask in masks:
            mask = convert_to_numpy(submask)
            valid, large, h, w, bbox = self.get_mask_info(mask)
            # print(valid, large, h, w, bbox)
            if valid:
                mask = mask_func(mask, bbox, h, w)
            else:
                mask = mask.astype(np.uint8)
            aug_masks.append(mask)
        return aug_masks if is_batch else aug_masks[0]

    def get_mask_info(self, mask):
        h, w = mask.shape
        locs = mask.nonzero()
        valid = True
        if len(locs) < 1 or locs[0].shape[0] < 1 or locs[1].shape[0] < 1:
            valid = False
            return valid, False, h, w, [0, 0, 0, 0]

        left, right = np.min(locs[1]), np.max(locs[1])
        top, bottom = np.min(locs[0]), np.max(locs[0])
        bbox = [left, top, right, bottom]

        large = False
        if (right - left + 1) * (bottom - top + 1) > 0.9 * h * w:
            large = True
        return valid, large, h, w, bbox

    def get_expand_params(self, mask_kwargs):
        if 'expand_ratio' in mask_kwargs:
            expand_ratio = mask_kwargs['expand_ratio']
        elif 'min_expand_ratio' in mask_kwargs and 'max_expand_ratio' in mask_kwargs:
            expand_ratio = random.uniform(mask_kwargs['min_expand_ratio'], mask_kwargs['max_expand_ratio'])
        else:
            expand_ratio = 0.3

        if 'expand_iters' in mask_kwargs:
            expand_iters = mask_kwargs['expand_iters']
        else:
            expand_iters = random.randint(1, 10)

        if 'expand_lrtp' in mask_kwargs:
            expand_lrtp = mask_kwargs['expand_lrtp']
        else:
            expand_lrtp = [random.random(), random.random(), random.random(), random.random()]

        return expand_ratio, expand_iters, expand_lrtp

    def get_mask_func(self, mask_cfg):
        if not isinstance(mask_cfg, list):
            mask_cfg = [mask_cfg]
        probas = [item['proba'] if 'proba' in item else 1.0 / len(mask_cfg) for item in mask_cfg]
        sel_mask_cfg = random.choices(mask_cfg, weights=probas, k=1)[0]
        mode = sel_mask_cfg['mode'] if 'mode' in sel_mask_cfg else 'original'
        mask_kwargs = sel_mask_cfg['kwargs'] if 'kwargs' in sel_mask_cfg else {}

        if mode == 'random':
            mode = random.choice(['original', 'original_expand', 'hull', 'hull_expand', 'bbox', 'bbox_expand'])
        if mode == 'original':
            mask_func = partial(self.generate_mask)
        elif mode == 'original_expand':
            expand_ratio, expand_iters, expand_lrtp = self.get_expand_params(mask_kwargs)
            mask_func = partial(self.generate_mask, expand_ratio=expand_ratio, expand_iters=expand_iters, expand_lrtp=expand_lrtp)
        elif mode == 'hull':
            clockwise = random.choice([True, False]) if 'clockwise' not in mask_kwargs else mask_kwargs['clockwise']
            mask_func = partial(self.generate_hull_mask, clockwise=clockwise)
        elif mode == 'hull_expand':
            expand_ratio, expand_iters, expand_lrtp = self.get_expand_params(mask_kwargs)
            clockwise = random.choice([True, False]) if 'clockwise' not in mask_kwargs else mask_kwargs['clockwise']
            mask_func = partial(self.generate_hull_mask, clockwise=clockwise, expand_ratio=expand_ratio, expand_iters=expand_iters, expand_lrtp=expand_lrtp)
        elif mode == 'bbox':
            mask_func = partial(self.generate_bbox_mask)
        elif mode == 'bbox_expand':
            expand_ratio, expand_iters, expand_lrtp = self.get_expand_params(mask_kwargs)
            mask_func = partial(self.generate_bbox_mask, expand_ratio=expand_ratio, expand_iters=expand_iters, expand_lrtp=expand_lrtp)
        else:
            raise NotImplementedError
        return mask_func

    def generate_mask(self, mask, bbox, h, w, expand_ratio=None, expand_iters=None, expand_lrtp=None):
        bin_mask = mask.astype(np.uint8)
        if expand_ratio:
            bin_mask = self.rand_expand_mask(bin_mask, bbox, h, w, expand_ratio, expand_iters, expand_lrtp)
        return bin_mask

    @staticmethod
    def rand_expand_mask(mask, bbox, h, w, expand_ratio=None, expand_iters=None, expand_lrtp=None):
        expand_ratio = 0.3 if expand_ratio is None else expand_ratio
        expand_iters = random.randint(1, 10) if expand_iters is None else expand_iters
        expand_lrtp = [random.random(), random.random(), random.random(), random.random()] if expand_lrtp is None else expand_lrtp
        # print('iters', expand_iters, 'expand_ratio', expand_ratio, 'expand_lrtp', expand_lrtp)
        # mask = np.squeeze(mask)
        left, top, right, bottom = bbox
        # mask expansion
        box_w = (right - left + 1) * expand_ratio
        box_h = (bottom - top + 1) * expand_ratio
        left_, right_ = int(expand_lrtp[0] * min(box_w, left / 2) / expand_iters), int(
            expand_lrtp[1] * min(box_w, (w - right) / 2) / expand_iters)
        top_, bottom_ = int(expand_lrtp[2] * min(box_h, top / 2) / expand_iters), int(
            expand_lrtp[3] * min(box_h, (h - bottom) / 2) / expand_iters)
        kernel_size = max(left_, right_, top_, bottom_)
        if kernel_size > 0:
            kernel = np.zeros((kernel_size * 2, kernel_size * 2), dtype=np.uint8)
            new_left, new_right = kernel_size - right_, kernel_size + left_
            new_top, new_bottom = kernel_size - bottom_, kernel_size + top_
            kernel[new_top:new_bottom + 1, new_left:new_right + 1] = 1
            mask = mask.astype(np.uint8)
            mask = cv2.dilate(mask, kernel, iterations=expand_iters).astype(np.uint8)
        # mask = new_mask - (mask / 2).astype(np.uint8)
        # mask = np.expand_dims(mask, axis=-1)
        return mask

    @staticmethod
    def _convexhull(image, clockwise):
        contours, hierarchy = cv2.findContours(image, 2, 1)
        cnt = np.concatenate(contours)  # merge all regions
        hull = cv2.convexHull(cnt, clockwise=clockwise)
        hull = np.squeeze(hull, axis=1).astype(np.float32).tolist()
        hull = [tuple(x) for x in hull]
        return hull  # b, 1, 2

    def generate_hull_mask(self, mask, bbox, h, w, clockwise=None, expand_ratio=None, expand_iters=None, expand_lrtp=None):
        clockwise = random.choice([True, False]) if clockwise is None else clockwise
        hull = self._convexhull(mask, clockwise)
        mask_img = Image.new('L', (w, h), 0)
        pt_list = hull
        mask_img_draw = ImageDraw.Draw(mask_img)
        mask_img_draw.polygon(pt_list, fill=255)
        bin_mask = np.array(mask_img).astype(np.uint8)
        if expand_ratio:
            bin_mask = self.rand_expand_mask(bin_mask, bbox, h, w, expand_ratio, expand_iters, expand_lrtp)
        return bin_mask

    def generate_bbox_mask(self, mask, bbox, h, w, expand_ratio=None, expand_iters=None, expand_lrtp=None):
        left, top, right, bottom = bbox
        bin_mask = np.zeros((h, w), dtype=np.uint8)
        bin_mask[top:bottom + 1, left:right + 1] = 255
        if expand_ratio:
            bin_mask = self.rand_expand_mask(bin_mask, bbox, h, w, expand_ratio, expand_iters, expand_lrtp)
        return bin_mask
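Usage note (not part of the diff): MaskAugAnnotator samples one augmentation mode from MASK_CFG; passing mask_cfg explicitly pins the mode. A short sketch, again assuming convert_to_numpy accepts NumPy input unchanged:

import numpy as np
from vace.annotators.maskaug import MaskAugAnnotator

anno = MaskAugAnnotator(cfg={})

mask = np.zeros((256, 256), dtype=np.uint8)
mask[80:160, 40:120] = 255

# force hull_expand instead of sampling from the default MASK_CFG
aug = anno.forward(mask, mask_cfg={'mode': 'hull_expand',
                                   'kwargs': {'expand_ratio': 0.2}})
print(aug.shape, aug.dtype)                     # same spatial size, uint8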
vace/annotators/midas/__init__.py
ADDED
@@ -0,0 +1,2 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
vace/annotators/midas/api.py
ADDED
@@ -0,0 +1,166 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
# based on https://github.com/isl-org/MiDaS

import cv2
import torch
import torch.nn as nn
from torchvision.transforms import Compose

from .dpt_depth import DPTDepthModel
from .midas_net import MidasNet
from .midas_net_custom import MidasNet_small
from .transforms import NormalizeImage, PrepareForNet, Resize

# ISL_PATHS = {
#     "dpt_large": "dpt_large-midas-2f21e586.pt",
#     "dpt_hybrid": "dpt_hybrid-midas-501f0c75.pt",
#     "midas_v21": "",
#     "midas_v21_small": "",
# }

# remote_model_path =
# "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt"


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


def load_midas_transform(model_type):
    # https://github.com/isl-org/MiDaS/blob/master/run.py
    # load transform only
    if model_type == 'dpt_large':  # DPT-Large
        net_w, net_h = 384, 384
        resize_mode = 'minimal'
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5],
                                       std=[0.5, 0.5, 0.5])

    elif model_type == 'dpt_hybrid':  # DPT-Hybrid
        net_w, net_h = 384, 384
        resize_mode = 'minimal'
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5],
                                       std=[0.5, 0.5, 0.5])

    elif model_type == 'midas_v21':
        net_w, net_h = 384, 384
        resize_mode = 'upper_bound'
        normalization = NormalizeImage(mean=[0.485, 0.456, 0.406],
                                       std=[0.229, 0.224, 0.225])

    elif model_type == 'midas_v21_small':
        net_w, net_h = 256, 256
        resize_mode = 'upper_bound'
        normalization = NormalizeImage(mean=[0.485, 0.456, 0.406],
                                       std=[0.229, 0.224, 0.225])

    else:
        assert False, f"model_type '{model_type}' not implemented, use: --model_type large"

    transform = Compose([
        Resize(
            net_w,
            net_h,
            resize_target=None,
            keep_aspect_ratio=True,
            ensure_multiple_of=32,
            resize_method=resize_mode,
            image_interpolation_method=cv2.INTER_CUBIC,
        ),
        normalization,
        PrepareForNet(),
    ])

    return transform


def load_model(model_type, model_path):
    # https://github.com/isl-org/MiDaS/blob/master/run.py
    # load network
    # model_path = ISL_PATHS[model_type]
    if model_type == 'dpt_large':  # DPT-Large
        model = DPTDepthModel(
            path=model_path,
            backbone='vitl16_384',
            non_negative=True,
        )
        net_w, net_h = 384, 384
        resize_mode = 'minimal'
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5],
                                       std=[0.5, 0.5, 0.5])

    elif model_type == 'dpt_hybrid':  # DPT-Hybrid
        model = DPTDepthModel(
            path=model_path,
            backbone='vitb_rn50_384',
            non_negative=True,
        )
        net_w, net_h = 384, 384
        resize_mode = 'minimal'
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5],
                                       std=[0.5, 0.5, 0.5])

    elif model_type == 'midas_v21':
        model = MidasNet(model_path, non_negative=True)
        net_w, net_h = 384, 384
        resize_mode = 'upper_bound'
        normalization = NormalizeImage(mean=[0.485, 0.456, 0.406],
                                       std=[0.229, 0.224, 0.225])

    elif model_type == 'midas_v21_small':
        model = MidasNet_small(model_path,
                               features=64,
                               backbone='efficientnet_lite3',
                               exportable=True,
                               non_negative=True,
                               blocks={'expand': True})
        net_w, net_h = 256, 256
        resize_mode = 'upper_bound'
        normalization = NormalizeImage(mean=[0.485, 0.456, 0.406],
                                       std=[0.229, 0.224, 0.225])

    else:
        print(
            f"model_type '{model_type}' not implemented, use: --model_type large"
        )
        assert False

    transform = Compose([
        Resize(
            net_w,
            net_h,
            resize_target=None,
            keep_aspect_ratio=True,
            ensure_multiple_of=32,
            resize_method=resize_mode,
            image_interpolation_method=cv2.INTER_CUBIC,
        ),
        normalization,
        PrepareForNet(),
    ])

    return model.eval(), transform


class MiDaSInference(nn.Module):
    MODEL_TYPES_TORCH_HUB = ['DPT_Large', 'DPT_Hybrid', 'MiDaS_small']
    MODEL_TYPES_ISL = [
        'dpt_large',
        'dpt_hybrid',
        'midas_v21',
        'midas_v21_small',
    ]

    def __init__(self, model_type, model_path):
        super().__init__()
        assert (model_type in self.MODEL_TYPES_ISL)
        model, _ = load_model(model_type, model_path)
        self.model = model
        self.model.train = disabled_train

    def forward(self, x):
        with torch.no_grad():
            prediction = self.model(x)
        return prediction
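Usage note (not part of the diff): a hedged end-to-end sketch for MiDaSInference. The checkpoint path is an assumption (use whatever local dpt_hybrid weights the VACE configs point at); the input must be the normalized NCHW tensor produced by the transform:

import torch
from vace.annotators.midas.api import MiDaSInference, load_midas_transform

transform = load_midas_transform('dpt_hybrid')
midas = MiDaSInference('dpt_hybrid', 'models/dpt_hybrid-midas-501f0c75.pt')  # hypothetical path

img = torch.rand(384, 384, 3).numpy()                 # RGB in [0, 1], HWC
sample = transform({'image': img})
x = torch.from_numpy(sample['image']).unsqueeze(0)    # 1 x 3 x H x W

with torch.no_grad():
    depth = midas(x)                                  # 1 x H x W relative depth
print(depth.shape)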
vace/annotators/midas/base_model.py
ADDED
@@ -0,0 +1,18 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch


class BaseModel(torch.nn.Module):
    def load(self, path):
        """Load model from file.

        Args:
            path (str): file path
        """
        parameters = torch.load(path, map_location=torch.device('cpu'), weights_only=True)

        if 'optimizer' in parameters:
            parameters = parameters['model']

        self.load_state_dict(parameters)
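Usage note (not part of the diff): BaseModel.load accepts either a bare state_dict or a training checkpoint that wraps the weights under 'model' next to an 'optimizer' entry. A toy sketch of both layouts (hypothetical file names):

import torch
from vace.annotators.midas.base_model import BaseModel


class TinyDepthNet(BaseModel):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 1)


net = TinyDepthNet()

torch.save(net.state_dict(), 'tiny_plain.pt')                             # bare state_dict
net.load('tiny_plain.pt')

torch.save({'model': net.state_dict(), 'optimizer': {}}, 'tiny_ckpt.pt')  # wrapped checkpoint
net.load('tiny_ckpt.pt')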
vace/annotators/midas/blocks.py
ADDED
@@ -0,0 +1,391 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn

from .vit import (_make_pretrained_vitb16_384, _make_pretrained_vitb_rn50_384,
                  _make_pretrained_vitl16_384)


def _make_encoder(
    backbone,
    features,
    use_pretrained,
    groups=1,
    expand=False,
    exportable=True,
    hooks=None,
    use_vit_only=False,
    use_readout='ignore',
):
    if backbone == 'vitl16_384':
        pretrained = _make_pretrained_vitl16_384(use_pretrained,
                                                 hooks=hooks,
                                                 use_readout=use_readout)
        scratch = _make_scratch(
            [256, 512, 1024, 1024], features, groups=groups,
            expand=expand)  # ViT-L/16 - 85.0% Top1 (backbone)
    elif backbone == 'vitb_rn50_384':
        pretrained = _make_pretrained_vitb_rn50_384(
            use_pretrained,
            hooks=hooks,
            use_vit_only=use_vit_only,
            use_readout=use_readout,
        )
        scratch = _make_scratch(
            [256, 512, 768, 768], features, groups=groups,
            expand=expand)  # ViT-H/16 - 85.0% Top1 (backbone)
    elif backbone == 'vitb16_384':
        pretrained = _make_pretrained_vitb16_384(use_pretrained,
                                                 hooks=hooks,
                                                 use_readout=use_readout)
        scratch = _make_scratch(
            [96, 192, 384, 768], features, groups=groups,
            expand=expand)  # ViT-B/16 - 84.6% Top1 (backbone)
    elif backbone == 'resnext101_wsl':
        pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
        scratch = _make_scratch([256, 512, 1024, 2048],
                                features,
                                groups=groups,
                                expand=expand)  # efficientnet_lite3
    elif backbone == 'efficientnet_lite3':
        pretrained = _make_pretrained_efficientnet_lite3(use_pretrained,
                                                         exportable=exportable)
        scratch = _make_scratch([32, 48, 136, 384],
                                features,
                                groups=groups,
                                expand=expand)  # efficientnet_lite3
    else:
        print(f"Backbone '{backbone}' not implemented")
        assert False

    return pretrained, scratch


def _make_scratch(in_shape, out_shape, groups=1, expand=False):
    scratch = nn.Module()

    out_shape1 = out_shape
    out_shape2 = out_shape
    out_shape3 = out_shape
    out_shape4 = out_shape
    if expand is True:
        out_shape1 = out_shape
        out_shape2 = out_shape * 2
        out_shape3 = out_shape * 4
        out_shape4 = out_shape * 8

    scratch.layer1_rn = nn.Conv2d(in_shape[0],
                                  out_shape1,
                                  kernel_size=3,
                                  stride=1,
                                  padding=1,
                                  bias=False,
                                  groups=groups)
    scratch.layer2_rn = nn.Conv2d(in_shape[1],
                                  out_shape2,
                                  kernel_size=3,
                                  stride=1,
                                  padding=1,
                                  bias=False,
                                  groups=groups)
    scratch.layer3_rn = nn.Conv2d(in_shape[2],
                                  out_shape3,
                                  kernel_size=3,
                                  stride=1,
                                  padding=1,
                                  bias=False,
                                  groups=groups)
    scratch.layer4_rn = nn.Conv2d(in_shape[3],
                                  out_shape4,
                                  kernel_size=3,
                                  stride=1,
                                  padding=1,
                                  bias=False,
                                  groups=groups)

    return scratch


def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
    efficientnet = torch.hub.load('rwightman/gen-efficientnet-pytorch',
                                  'tf_efficientnet_lite3',
                                  pretrained=use_pretrained,
                                  exportable=exportable)
    return _make_efficientnet_backbone(efficientnet)


def _make_efficientnet_backbone(effnet):
    pretrained = nn.Module()

    pretrained.layer1 = nn.Sequential(effnet.conv_stem, effnet.bn1,
                                      effnet.act1, *effnet.blocks[0:2])
    pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
    pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
    pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])

    return pretrained


def _make_resnet_backbone(resnet):
    pretrained = nn.Module()
    pretrained.layer1 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu,
                                      resnet.maxpool, resnet.layer1)

    pretrained.layer2 = resnet.layer2
    pretrained.layer3 = resnet.layer3
    pretrained.layer4 = resnet.layer4

    return pretrained


def _make_pretrained_resnext101_wsl(use_pretrained):
    resnet = torch.hub.load('facebookresearch/WSL-Images',
                            'resnext101_32x8d_wsl')
    return _make_resnet_backbone(resnet)


class Interpolate(nn.Module):
    """Interpolation module.
    """
    def __init__(self, scale_factor, mode, align_corners=False):
        """Init.

        Args:
            scale_factor (float): scaling
            mode (str): interpolation mode
        """
        super(Interpolate, self).__init__()

        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: interpolated data
        """

        x = self.interp(x,
                        scale_factor=self.scale_factor,
                        mode=self.mode,
                        align_corners=self.align_corners)

        return x


class ResidualConvUnit(nn.Module):
    """Residual convolution module.
    """
    def __init__(self, features):
        """Init.

        Args:
            features (int): number of features
        """
        super().__init__()

        self.conv1 = nn.Conv2d(features,
                               features,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=True)

        self.conv2 = nn.Conv2d(features,
                               features,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=True)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """
        out = self.relu(x)
        out = self.conv1(out)
        out = self.relu(out)
        out = self.conv2(out)

        return out + x


class FeatureFusionBlock(nn.Module):
    """Feature fusion block.
    """
    def __init__(self, features):
        """Init.

        Args:
            features (int): number of features
        """
        super(FeatureFusionBlock, self).__init__()

        self.resConfUnit1 = ResidualConvUnit(features)
        self.resConfUnit2 = ResidualConvUnit(features)

    def forward(self, *xs):
        """Forward pass.

        Returns:
            tensor: output
        """
        output = xs[0]

        if len(xs) == 2:
            output += self.resConfUnit1(xs[1])

        output = self.resConfUnit2(output)

        output = nn.functional.interpolate(output,
                                           scale_factor=2,
                                           mode='bilinear',
                                           align_corners=True)

        return output


class ResidualConvUnit_custom(nn.Module):
    """Residual convolution module.
    """
    def __init__(self, features, activation, bn):
        """Init.

        Args:
            features (int): number of features
        """
        super().__init__()

        self.bn = bn

        self.groups = 1

        self.conv1 = nn.Conv2d(features,
                               features,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=True,
                               groups=self.groups)

        self.conv2 = nn.Conv2d(features,
                               features,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=True,
                               groups=self.groups)

        if self.bn is True:
            self.bn1 = nn.BatchNorm2d(features)
            self.bn2 = nn.BatchNorm2d(features)

        self.activation = activation

        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """

        out = self.activation(x)
        out = self.conv1(out)
        if self.bn is True:
            out = self.bn1(out)

        out = self.activation(out)
        out = self.conv2(out)
        if self.bn is True:
            out = self.bn2(out)

        if self.groups > 1:
            out = self.conv_merge(out)

        return self.skip_add.add(out, x)

        # return out + x


class FeatureFusionBlock_custom(nn.Module):
    """Feature fusion block.
    """
    def __init__(self,
                 features,
                 activation,
                 deconv=False,
                 bn=False,
                 expand=False,
                 align_corners=True):
        """Init.

        Args:
            features (int): number of features
        """
        super(FeatureFusionBlock_custom, self).__init__()

        self.deconv = deconv
        self.align_corners = align_corners

        self.groups = 1

        self.expand = expand
        out_features = features
        if self.expand is True:
            out_features = features // 2

        self.out_conv = nn.Conv2d(features,
                                  out_features,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0,
                                  bias=True,
                                  groups=1)

        self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
        self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)

        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, *xs):
        """Forward pass.

        Returns:
            tensor: output
        """
        output = xs[0]

        if len(xs) == 2:
            res = self.resConfUnit1(xs[1])
            output = self.skip_add.add(output, res)
            # output += res

        output = self.resConfUnit2(output)

        output = nn.functional.interpolate(output,
                                           scale_factor=2,
                                           mode='bilinear',
                                           align_corners=self.align_corners)

        output = self.out_conv(output)

        return output
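Usage note (not part of the diff): each fusion block keeps the channel count (halving only when expand=True) and upsamples 2x; the second input, when given, is the lateral skip path. A small shape check with random tensors in place of real backbone features:

import torch
import torch.nn as nn
from vace.annotators.midas.blocks import FeatureFusionBlock_custom

block = FeatureFusionBlock_custom(64, nn.ReLU(False), bn=False, expand=False)

deep = torch.rand(1, 64, 24, 24)      # coarser decoder features
skip = torch.rand(1, 64, 24, 24)      # lateral (skip) features

out = block(deep, skip)
print(out.shape)                      # torch.Size([1, 64, 48, 48])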
vace/annotators/midas/dpt_depth.py
ADDED
@@ -0,0 +1,107 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn

from .base_model import BaseModel
from .blocks import FeatureFusionBlock_custom, Interpolate, _make_encoder
from .vit import forward_vit


def _make_fusion_block(features, use_bn):
    return FeatureFusionBlock_custom(
        features,
        nn.ReLU(False),
        deconv=False,
        bn=use_bn,
        expand=False,
        align_corners=True,
    )


class DPT(BaseModel):
    def __init__(
        self,
        head,
        features=256,
        backbone='vitb_rn50_384',
        readout='project',
        channels_last=False,
        use_bn=False,
    ):

        super(DPT, self).__init__()

        self.channels_last = channels_last

        hooks = {
            'vitb_rn50_384': [0, 1, 8, 11],
            'vitb16_384': [2, 5, 8, 11],
            'vitl16_384': [5, 11, 17, 23],
        }

        # Instantiate backbone and reassemble blocks
        self.pretrained, self.scratch = _make_encoder(
            backbone,
            features,
            False,  # Set to true if you want to train from scratch, uses ImageNet weights
            groups=1,
            expand=False,
            exportable=False,
            hooks=hooks[backbone],
            use_readout=readout,
        )

        self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet4 = _make_fusion_block(features, use_bn)

        self.scratch.output_conv = head

    def forward(self, x):
        if self.channels_last is True:
            x.contiguous(memory_format=torch.channels_last)

        layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv(path_1)

        return out


class DPTDepthModel(DPT):
    def __init__(self, path=None, non_negative=True, **kwargs):
        features = kwargs['features'] if 'features' in kwargs else 256

        head = nn.Sequential(
            nn.Conv2d(features,
                      features // 2,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            Interpolate(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True) if non_negative else nn.Identity(),
            nn.Identity(),
        )

        super().__init__(head, **kwargs)

        if path is not None:
            self.load(path)

    def forward(self, x):
        return super().forward(x).squeeze(dim=1)
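Usage note (not part of the diff): a construction sketch for DPTDepthModel, assuming the timm-based .vit helpers shipped in this upload are importable; path=None builds the architecture without loading a checkpoint, while a real path would load MiDaS weights:

import torch
from vace.annotators.midas.dpt_depth import DPTDepthModel

model = DPTDepthModel(path=None, backbone='vitb_rn50_384', non_negative=True).eval()

x = torch.rand(1, 3, 384, 384)        # normalized RGB batch
with torch.no_grad():
    depth = model(x)                   # (1, 384, 384) relative inverse depth
print(depth.shape)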
vace/annotators/midas/midas_net.py
ADDED
@@ -0,0 +1,80 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
"""MidasNet: Network for monocular depth estimation trained by mixing several datasets.
This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
"""
import torch
import torch.nn as nn

from .base_model import BaseModel
from .blocks import FeatureFusionBlock, Interpolate, _make_encoder


class MidasNet(BaseModel):
    """Network for monocular depth estimation.
    """
    def __init__(self, path=None, features=256, non_negative=True):
        """Init.

        Args:
            path (str, optional): Path to saved model. Defaults to None.
            features (int, optional): Number of features. Defaults to 256.
        """
        print('Loading weights: ', path)

        super(MidasNet, self).__init__()

        use_pretrained = False if path is None else True

        self.pretrained, self.scratch = _make_encoder(
            backbone='resnext101_wsl',
            features=features,
            use_pretrained=use_pretrained)

        self.scratch.refinenet4 = FeatureFusionBlock(features)
        self.scratch.refinenet3 = FeatureFusionBlock(features)
        self.scratch.refinenet2 = FeatureFusionBlock(features)
        self.scratch.refinenet1 = FeatureFusionBlock(features)

        self.scratch.output_conv = nn.Sequential(
            nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode='bilinear'),
            nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True) if non_negative else nn.Identity(),
        )

        if path:
            self.load(path)

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input data (image)

        Returns:
            tensor: depth
        """

        layer_1 = self.pretrained.layer1(x)
        layer_2 = self.pretrained.layer2(layer_1)
        layer_3 = self.pretrained.layer3(layer_2)
        layer_4 = self.pretrained.layer4(layer_3)

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv(path_1)

        return torch.squeeze(out, dim=1)
vace/annotators/midas/midas_net_custom.py
ADDED
@@ -0,0 +1,167 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
"""MidasNet: Network for monocular depth estimation trained by mixing several datasets.
This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
"""
import torch
import torch.nn as nn

from .base_model import BaseModel
from .blocks import FeatureFusionBlock_custom, Interpolate, _make_encoder


class MidasNet_small(BaseModel):
    """Network for monocular depth estimation.
    """
    def __init__(self,
                 path=None,
                 features=64,
                 backbone='efficientnet_lite3',
                 non_negative=True,
                 exportable=True,
                 channels_last=False,
                 align_corners=True,
                 blocks={'expand': True}):
        """Init.

        Args:
            path (str, optional): Path to saved model. Defaults to None.
            features (int, optional): Number of features. Defaults to 64.
            backbone (str, optional): Backbone network for encoder. Defaults to efficientnet_lite3.
        """
        print('Loading weights: ', path)

        super(MidasNet_small, self).__init__()

        use_pretrained = False if path else True

        self.channels_last = channels_last
        self.blocks = blocks
        self.backbone = backbone

        self.groups = 1

        features1 = features
        features2 = features
        features3 = features
        features4 = features
        self.expand = False
        if 'expand' in self.blocks and self.blocks['expand'] is True:
            self.expand = True
            features1 = features
            features2 = features * 2
            features3 = features * 4
            features4 = features * 8

        self.pretrained, self.scratch = _make_encoder(self.backbone,
                                                      features,
                                                      use_pretrained,
                                                      groups=self.groups,
                                                      expand=self.expand,
                                                      exportable=exportable)

        self.scratch.activation = nn.ReLU(False)

        self.scratch.refinenet4 = FeatureFusionBlock_custom(
            features4,
            self.scratch.activation,
            deconv=False,
            bn=False,
            expand=self.expand,
            align_corners=align_corners)
        self.scratch.refinenet3 = FeatureFusionBlock_custom(
            features3,
            self.scratch.activation,
            deconv=False,
            bn=False,
            expand=self.expand,
            align_corners=align_corners)
        self.scratch.refinenet2 = FeatureFusionBlock_custom(
            features2,
            self.scratch.activation,
            deconv=False,
            bn=False,
            expand=self.expand,
            align_corners=align_corners)
        self.scratch.refinenet1 = FeatureFusionBlock_custom(
            features1,
            self.scratch.activation,
            deconv=False,
            bn=False,
            align_corners=align_corners)

        self.scratch.output_conv = nn.Sequential(
            nn.Conv2d(features,
                      features // 2,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      groups=self.groups),
            Interpolate(scale_factor=2, mode='bilinear'),
            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
            self.scratch.activation,
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True) if non_negative else nn.Identity(),
            nn.Identity(),
        )

        if path:
            self.load(path)

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input data (image)

        Returns:
            tensor: depth
        """
        if self.channels_last is True:
            print('self.channels_last = ', self.channels_last)
            x.contiguous(memory_format=torch.channels_last)

        layer_1 = self.pretrained.layer1(x)
        layer_2 = self.pretrained.layer2(layer_1)
        layer_3 = self.pretrained.layer3(layer_2)
        layer_4 = self.pretrained.layer4(layer_3)

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv(path_1)

        return torch.squeeze(out, dim=1)


def fuse_model(m):
    prev_previous_type = nn.Identity()
    prev_previous_name = ''
    previous_type = nn.Identity()
    previous_name = ''
    for name, module in m.named_modules():
        if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(
                module) == nn.ReLU:
            # print("FUSED ", prev_previous_name, previous_name, name)
            torch.quantization.fuse_modules(
                m, [prev_previous_name, previous_name, name], inplace=True)
        elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
            # print("FUSED ", prev_previous_name, previous_name)
            torch.quantization.fuse_modules(
                m, [prev_previous_name, previous_name], inplace=True)
        # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
        #    print("FUSED ", previous_name, name)
        #    torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)

        prev_previous_type = previous_type
        prev_previous_name = previous_name
        previous_type = type(module)
        previous_name = name
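Usage note (not part of the diff): fuse_model is a post-training helper that walks named_modules and fuses Conv2d + BatchNorm2d (+ ReLU) runs in place for quantized or faster inference. The logic is easiest to see on a toy Sequential (not specific to MidasNet_small):

import torch.nn as nn
from vace.annotators.midas.midas_net_custom import fuse_model

m = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
)
m.eval()                               # fusion requires eval mode

fuse_model(m)                          # Conv2d+BatchNorm2d+ReLU -> one fused module
print(m)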
vace/annotators/midas/transforms.py
ADDED
@@ -0,0 +1,231 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import math

import cv2
import numpy as np


def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
    """Resize the sample to ensure the given size. Keeps aspect ratio.

    Args:
        sample (dict): sample
        size (tuple): image size

    Returns:
        tuple: new size
    """
    shape = list(sample['disparity'].shape)

    if shape[0] >= size[0] and shape[1] >= size[1]:
        return sample

    scale = [0, 0]
    scale[0] = size[0] / shape[0]
    scale[1] = size[1] / shape[1]

    scale = max(scale)

    shape[0] = math.ceil(scale * shape[0])
    shape[1] = math.ceil(scale * shape[1])

    # resize
    sample['image'] = cv2.resize(sample['image'],
                                 tuple(shape[::-1]),
                                 interpolation=image_interpolation_method)

    sample['disparity'] = cv2.resize(sample['disparity'],
                                     tuple(shape[::-1]),
                                     interpolation=cv2.INTER_NEAREST)
    sample['mask'] = cv2.resize(
        sample['mask'].astype(np.float32),
        tuple(shape[::-1]),
        interpolation=cv2.INTER_NEAREST,
    )
    sample['mask'] = sample['mask'].astype(bool)

    return tuple(shape)


class Resize(object):
    """Resize sample to given size (width, height).
    """
    def __init__(
        self,
        width,
        height,
        resize_target=True,
        keep_aspect_ratio=False,
        ensure_multiple_of=1,
        resize_method='lower_bound',
        image_interpolation_method=cv2.INTER_AREA,
    ):
        """Init.

        Args:
            width (int): desired output width
            height (int): desired output height
            resize_target (bool, optional):
                True: Resize the full sample (image, mask, target).
                False: Resize image only.
                Defaults to True.
            keep_aspect_ratio (bool, optional):
                True: Keep the aspect ratio of the input sample.
                Output sample might not have the given width and height, and
                resize behaviour depends on the parameter 'resize_method'.
                Defaults to False.
            ensure_multiple_of (int, optional):
                Output width and height is constrained to be multiple of this parameter.
                Defaults to 1.
            resize_method (str, optional):
                "lower_bound": Output will be at least as large as the given size.
                "upper_bound": Output will be at max as large as the given size.
                    (Output size might be smaller than given size.)
                "minimal": Scale as little as possible. (Output size might be smaller than given size.)
                Defaults to "lower_bound".
        """
        self.__width = width
        self.__height = height

        self.__resize_target = resize_target
        self.__keep_aspect_ratio = keep_aspect_ratio
        self.__multiple_of = ensure_multiple_of
        self.__resize_method = resize_method
        self.__image_interpolation_method = image_interpolation_method

    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
        y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if max_val is not None and y > max_val:
            y = (np.floor(x / self.__multiple_of) *
                 self.__multiple_of).astype(int)

        if y < min_val:
            y = (np.ceil(x / self.__multiple_of) *
                 self.__multiple_of).astype(int)

        return y

    def get_size(self, width, height):
        # determine new height and width
        scale_height = self.__height / height
        scale_width = self.__width / width

        if self.__keep_aspect_ratio:
            if self.__resize_method == 'lower_bound':
                # scale such that output size is lower bound
                if scale_width > scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == 'upper_bound':
                # scale such that output size is upper bound
                if scale_width < scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == 'minimal':
                # scale as little as possible
                if abs(1 - scale_width) < abs(1 - scale_height):
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            else:
                raise ValueError(
                    f'resize_method {self.__resize_method} not implemented')

        if self.__resize_method == 'lower_bound':
            new_height = self.constrain_to_multiple_of(scale_height * height,
                                                       min_val=self.__height)
            new_width = self.constrain_to_multiple_of(scale_width * width,
                                                      min_val=self.__width)
        elif self.__resize_method == 'upper_bound':
            new_height = self.constrain_to_multiple_of(scale_height * height,
                                                       max_val=self.__height)
            new_width = self.constrain_to_multiple_of(scale_width * width,
                                                      max_val=self.__width)
        elif self.__resize_method == 'minimal':
            new_height = self.constrain_to_multiple_of(scale_height * height)
            new_width = self.constrain_to_multiple_of(scale_width * width)
        else:
            raise ValueError(
                f'resize_method {self.__resize_method} not implemented')

        return (new_width, new_height)

    def __call__(self, sample):
        width, height = self.get_size(sample['image'].shape[1],
                                      sample['image'].shape[0])

        # resize sample
        sample['image'] = cv2.resize(
            sample['image'],
            (width, height),
            interpolation=self.__image_interpolation_method,
        )

        if self.__resize_target:
            if 'disparity' in sample:
                sample['disparity'] = cv2.resize(
                    sample['disparity'],
                    (width, height),
                    interpolation=cv2.INTER_NEAREST,
                )

            if 'depth' in sample:
                sample['depth'] = cv2.resize(sample['depth'], (width, height),
                                             interpolation=cv2.INTER_NEAREST)

            sample['mask'] = cv2.resize(
                sample['mask'].astype(np.float32),
                (width, height),
                interpolation=cv2.INTER_NEAREST,
            )
            sample['mask'] = sample['mask'].astype(bool)

        return sample


class NormalizeImage(object):
    """Normalize image by given mean and std.
    """
    def __init__(self, mean, std):
        self.__mean = mean
        self.__std = std

    def __call__(self, sample):
        sample['image'] = (sample['image'] - self.__mean) / self.__std

        return sample


class PrepareForNet(object):
    """Prepare sample for usage as network input.
    """
    def __init__(self):
        pass

    def __call__(self, sample):
        image = np.transpose(sample['image'], (2, 0, 1))
        sample['image'] = np.ascontiguousarray(image).astype(np.float32)

        if 'mask' in sample:
            sample['mask'] = sample['mask'].astype(np.float32)
            sample['mask'] = np.ascontiguousarray(sample['mask'])

        if 'disparity' in sample:
            disparity = sample['disparity'].astype(np.float32)
            sample['disparity'] = np.ascontiguousarray(disparity)

        if 'depth' in sample:
            depth = sample['depth'].astype(np.float32)
            sample['depth'] = np.ascontiguousarray(depth)

        return sample
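Usage note (not part of the diff): the transforms operate on a dict sample holding an HWC float RGB 'image'; a short sketch of the same Compose pipeline that load_midas_transform assembles:

import cv2
import numpy as np
from torchvision.transforms import Compose
from vace.annotators.midas.transforms import NormalizeImage, PrepareForNet, Resize

transform = Compose([
    Resize(384, 384,
           resize_target=None,
           keep_aspect_ratio=True,
           ensure_multiple_of=32,
           resize_method='minimal',
           image_interpolation_method=cv2.INTER_CUBIC),
    NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    PrepareForNet(),
])

sample = {'image': np.random.rand(480, 640, 3).astype(np.float32)}
out = transform(sample)
print(out['image'].shape)              # (3, H, W) with H, W multiples of 32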
vace/annotators/midas/utils.py
ADDED
@@ -0,0 +1,193 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
"""Utils for monoDepth."""
|
4 |
+
import re
|
5 |
+
import sys
|
6 |
+
|
7 |
+
import cv2
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
|
11 |
+
|
12 |
+
def read_pfm(path):
|
13 |
+
"""Read pfm file.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
path (str): path to file
|
17 |
+
|
18 |
+
Returns:
|
19 |
+
tuple: (data, scale)
|
20 |
+
"""
|
21 |
+
with open(path, 'rb') as file:
|
22 |
+
|
23 |
+
color = None
|
24 |
+
width = None
|
25 |
+
height = None
|
26 |
+
scale = None
|
27 |
+
endian = None
|
28 |
+
|
29 |
+
header = file.readline().rstrip()
|
30 |
+
if header.decode('ascii') == 'PF':
|
31 |
+
color = True
|
32 |
+
elif header.decode('ascii') == 'Pf':
|
33 |
+
color = False
|
34 |
+
else:
|
35 |
+
raise Exception('Not a PFM file: ' + path)
|
36 |
+
|
37 |
+
dim_match = re.match(r'^(\d+)\s(\d+)\s$',
|
38 |
+
file.readline().decode('ascii'))
|
39 |
+
if dim_match:
|
40 |
+
width, height = list(map(int, dim_match.groups()))
|
41 |
+
else:
|
42 |
+
raise Exception('Malformed PFM header.')
|
43 |
+
|
44 |
+
scale = float(file.readline().decode('ascii').rstrip())
|
45 |
+
if scale < 0:
|
46 |
+
# little-endian
|
47 |
+
endian = '<'
|
48 |
+
scale = -scale
|
49 |
+
else:
|
50 |
+
# big-endian
|
51 |
+
endian = '>'
|
52 |
+
|
53 |
+
data = np.fromfile(file, endian + 'f')
|
54 |
+
shape = (height, width, 3) if color else (height, width)
|
55 |
+
|
56 |
+
data = np.reshape(data, shape)
|
57 |
+
data = np.flipud(data)
|
58 |
+
|
59 |
+
return data, scale
|
60 |
+
|
61 |
+
|
def write_pfm(path, image, scale=1):
    """Write pfm file.

    Args:
        path (str): path to file
        image (array): data
        scale (int, optional): Scale. Defaults to 1.
    """

    with open(path, 'wb') as file:
        color = None

        if image.dtype.name != 'float32':
            raise Exception('Image dtype must be float32.')

        image = np.flipud(image)

        if len(image.shape) == 3 and image.shape[2] == 3:  # color image
            color = True
        elif (len(image.shape) == 2
              or len(image.shape) == 3 and image.shape[2] == 1):  # greyscale
            color = False
        else:
            raise Exception(
                'Image must have H x W x 3, H x W x 1 or H x W dimensions.')

        # write the PFM header as bytes
        file.write(('PF\n' if color else 'Pf\n').encode())
        file.write('%d %d\n'.encode() % (image.shape[1], image.shape[0]))

        endian = image.dtype.byteorder

        if endian == '<' or endian == '=' and sys.byteorder == 'little':
            scale = -scale

        file.write('%f\n'.encode() % scale)

        image.tofile(file)


def read_image(path):
    """Read image and output RGB image (0-1).

    Args:
        path (str): path to file

    Returns:
        array: RGB image (0-1)
    """
    img = cv2.imread(path)

    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0

    return img


def resize_image(img):
    """Resize image and make it fit for network.

    Args:
        img (array): image

    Returns:
        tensor: data ready for network
    """
    height_orig = img.shape[0]
    width_orig = img.shape[1]

    if width_orig > height_orig:
        scale = width_orig / 384
    else:
        scale = height_orig / 384

    height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
    width = (np.ceil(width_orig / scale / 32) * 32).astype(int)

    img_resized = cv2.resize(img, (width, height),
                             interpolation=cv2.INTER_AREA)

    img_resized = (torch.from_numpy(np.transpose(
        img_resized, (2, 0, 1))).contiguous().float())
    img_resized = img_resized.unsqueeze(0)

    return img_resized


def resize_depth(depth, width, height):
    """Resize depth map and bring to CPU (numpy).

    Args:
        depth (tensor): depth
        width (int): image width
        height (int): image height

    Returns:
        array: processed depth
    """
    depth = torch.squeeze(depth[0, :, :, :]).to('cpu')

    depth_resized = cv2.resize(depth.numpy(), (width, height),
                               interpolation=cv2.INTER_CUBIC)

    return depth_resized


def write_depth(path, depth, bits=1):
    """Write depth map to pfm and png file.

    Args:
        path (str): filepath without extension
        depth (array): depth
        bits (int, optional): bytes per png sample (1 -> uint8, 2 -> uint16).
            Defaults to 1.
    """
    write_pfm(path + '.pfm', depth.astype(np.float32))

    depth_min = depth.min()
    depth_max = depth.max()

    max_val = (2**(8 * bits)) - 1

    if depth_max - depth_min > np.finfo('float').eps:
        out = max_val * (depth - depth_min) / (depth_max - depth_min)
    else:
        out = np.zeros(depth.shape, dtype=depth.dtype)

    if bits == 1:
        cv2.imwrite(path + '.png', out.astype('uint8'))
    elif bits == 2:
        cv2.imwrite(path + '.png', out.astype('uint16'))

    return
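Together these helpers cover the usual monocular-depth I/O loop: read an RGB image, resize it to a network-friendly resolution, run the model, resize the prediction back to the original size, and dump it as .pfm plus .png. A minimal driver sketch, assuming model is any MiDaS-style network that maps a 1xCxHxW tensor to a 1xHxW depth prediction; the model, device handling, and function name are placeholders, not part of this file:

    import torch

    from vace.annotators.midas.utils import (read_image, resize_depth,
                                              resize_image, write_depth)

    def predict_depth(model, image_path, out_prefix, device='cpu'):
        img = read_image(image_path)               # HWC RGB in [0, 1]
        sample = resize_image(img).to(device)      # 1xCxHxW float32, long side ~384

        with torch.no_grad():
            prediction = model(sample)             # assumed shape: 1 x H' x W'

        # resize_depth indexes depth[0, :, :, :], so hand it a 4-D tensor
        depth = resize_depth(prediction.unsqueeze(0),
                             img.shape[1], img.shape[0])
        write_depth(out_prefix, depth, bits=2)     # writes out_prefix.pfm / out_prefix.png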