import os
import spaces
import gradio as gr
import random
import numpy as np
import torch
from torchvision.transforms.functional import center_crop

# Try to install detectron2 from source. Needed for semseg plotting functionality.
# Note: os.system does not raise on a failed install, so check the exit code instead of relying on an exception.
if os.system("python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'") != 0:
    print('detectron2 cannot be installed. Falling back to simple semseg visualization.')


# We recommend running this demo on an A100 GPU
if torch.cuda.is_available():
    device = "cuda"
    gpu_type = torch.cuda.get_device_name(torch.cuda.current_device())
    power_device = f"{gpu_type}"
    torch.cuda.max_memory_allocated(device=device)
else:
    device = "cpu"
    power_device = "CPU"
    os.system("pip uninstall -y xformers") # Only use xformers on GPU

from fourm.demo_4M_sampler import Demo4MSampler
from fourm.data.modality_transforms import RGBTransform


# The flag below controls whether to allow TF32 on matmul. This flag defaults to False in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True
# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

MAX_SEED = np.iinfo(np.int32).max

FM_MODEL_ID = 'EPFL-VILAB/4M-21_XL'
MODEL_NAME = FM_MODEL_ID.split('/')[1].replace('_', ' ')

# Human pose visualization is disabled, since it needs SMPL weights. To enable human pose prediction and rendering:
# 1) Install via `pip install timm yacs smplx pyrender pyopengl==3.1.4`
#    You may need to follow the pyrender install instructions: https://pyrender.readthedocs.io/en/latest/install/index.html
# 2) Download SMPL data from https://smpl.is.tue.mpg.de/. See https://github.com/shubham-goel/4D-Humans/ for an example
# 3) Copy the required SMPL files (smpl_mean_params.npz, SMPL_to_J19.pkl, smpl/SMPL_NEUTRAL.pkl) to fourm/utils/hmr2_utils/data .
MANUAL_MODS_OVERRIDE = [
    'color_palette', 'tok_depth@224', 'tok_imagebind@224', 'sam_instance', 'tok_dinov2_global', 
    'tok_normal@224', 'tok_sam_edge@224', 'det', 'tok_canny_edge@224', 'tok_semseg@224', 'rgb@224', 
    'caption', 't5_caption', 'tok_imagebind_global', 'tok_rgb@224', 'tok_clip@224', 'metadata', 'tok_dinov2@224'
]

sampler = Demo4MSampler(
    fm=FM_MODEL_ID, 
    fm_sr=None, 
    tok_human_poses=None, 
    tok_text='./text_tokenizer_4m_wordpiece_30k.json',
    mods=MANUAL_MODS_OVERRIDE,
).to(device)
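
# The human-pose tokenizer is disabled above (tok_human_poses=None). A hedged sketch of what
# enabling it could look like once the SMPL files are in place; the tokenizer path and the
# 'human_poses' modality key below are placeholders / assumptions, not verified checkpoint names:
#
# sampler_with_poses = Demo4MSampler(
#     fm=FM_MODEL_ID,
#     fm_sr=None,
#     tok_human_poses='path/to/human_poses_tokenizer',    # placeholder path, not a real checkpoint
#     tok_text='./text_tokenizer_4m_wordpiece_30k.json',
#     mods=MANUAL_MODS_OVERRIDE + ['human_poses'],        # assumed modality key
# ).to(device)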


def img_from_path(img_path: str):
    rgb_transform = RGBTransform(imagenet_default_mean_and_std=True)
    img_pil = rgb_transform.load(img_path)
    img_pil = rgb_transform.preprocess(img_pil)
    img_pil = center_crop(img_pil, (min(img_pil.size), min(img_pil.size))).resize((224,224))
    img = rgb_transform.postprocess(img_pil).unsqueeze(0)
    return img
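
# Minimal usage sketch for img_from_path (illustrative only; the example path comes from the
# `examples` list defined further below). The expected output is a normalized RGB tensor:
# >>> img = img_from_path('examples/example_0.png')
# >>> img.shape  # expected torch.Size([1, 3, 224, 224])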

@spaces.GPU(duration=100)
def infer(img_path, seed=0, randomize_seed=False, target_modalities=None, top_p=0.8, top_k=0.0):
    if randomize_seed:
        seed = None
    img = img_from_path(img_path).to(device)
    preds = sampler({'rgb@224': img}, seed=seed, target_modalities=target_modalities, top_p=top_p, top_k=top_k) 
    return sampler.modalities_to_pil(preds, use_fixed_plotting_order=True, resize=512)
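
# Direct-call sketch (illustrative; mirrors how the run button wires this function below):
# >>> preds = infer('examples/example_0.png', seed=0, target_modalities=['tok_depth@224', 'caption'])
# `preds` holds the PIL predictions that are displayed in the results gallery.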


examples = [
    'examples/example_0.png', 'examples/example_1.png', 'examples/example_2.png',
    'examples/example_3.png', 'examples/example_4.png', 'examples/example_5.png',
]

css="""
#col-container {
    margin: 0 auto;
    max-width: 1500px;
}
#col-input-container {
    margin: 0 auto;
    max-width: 400px;
}
#run-button {
    margin: 0 auto;
}
"""

with gr.Blocks(css=css, theme=gr.themes.Base()) as demo:
    
    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"""
        # 4M: Massively Multimodal Masked Modeling
        """)
        
        with gr.Row():
            with gr.Column(elem_id="col-input-container"):
                gr.Markdown(f"""
                *A framework for training any-to-any multimodal foundation models. Scalable. Open-sourced. Across tens of modalities and tasks.*
                [`Website`](https://4m.epfl.ch) | [`GitHub`](https://github.com/apple/ml-4m) <br>[`4M Paper (NeurIPS'23)`](https://arxiv.org/abs/2312.06647) | [`4M-21 Paper (arXiv'24)`](https://arxiv.org/abs/2406.09406)
                This demo runs [{FM_MODEL_ID}](https://huggingface.co/{FM_MODEL_ID}) on a *{power_device}* and predicts all modalities from a given RGB input.
                For more generative any-to-any examples, please see the [GitHub repository](https://github.com/apple/ml-4m#generation).
                """)
                
                img_path = gr.Image(label='RGB input image', type='filepath')
                run_button = gr.Button(f"Predict with {MODEL_NAME}", scale=0, elem_id="run-button")

                with gr.Accordion("Advanced Settings", open=False):
                    target_modalities = gr.CheckboxGroup(
                        choices=[
                            ('CLIP-B/16', 'tok_clip@224'), ('DINOv2-B/14', 'tok_dinov2@224'), ('ImageBind-H/14', 'tok_imagebind@224'), 
                            ('Depth', 'tok_depth@224'), ('Surface normals', 'tok_normal@224'), ('Semantic segmentation', 'tok_semseg@224'), 
                            ('Canny edges', 'tok_canny_edge@224'), ('SAM edges', 'tok_sam_edge@224'), ('Caption', 'caption'), 
                            ('Bounding boxes', 'det'), ('SAM instances (single pass*)', 'sam_instance'), ('Color palette', 'color_palette'), 
                            ('Metadata', 'metadata'),
                        ],
                        value=[
                            'tok_clip@224', 'tok_dinov2@224', 'tok_imagebind@224', 
                            'tok_depth@224', 'tok_normal@224', 'tok_semseg@224', 
                            'tok_canny_edge@224', 'tok_sam_edge@224', 'caption', 
                            'det', 'sam_instance', 'color_palette', 'metadata'
                        ],
                        label="Target modalities", 
                        info='Choose which modalities are predicted (in this order).'
                    )
                    gr.Markdown(f"""
                    **Information on modalities**:

                    \* *SAM instances* in this demo are generated in a single pass, so they may appear sparse. To sample dense SAM instances, please see the convenience function [`generate_sam_dense`](https://github.com/apple/ml-4m/blob/e11539965e45aa6731143d742c4493c46b4ef620/fourm/models/generate.py#L1230-L1273) of `fourm.models.generate.GenerationSampler`, and the [4M-21 interactive notebook](https://github.com/apple/ml-4m/blob/main/notebooks/generation_4M-21.ipynb) for a usage example.

                    \*\* The 4M-21 models are able to predict *4D human poses*, but visualizing them requires the SMPL model, which we cannot redistribute.
                    To visualize the poses, please follow these steps:

                    1) Install via `pip install timm yacs smplx pyrender pyopengl==3.1.4`.
                    You may need to follow the [pyrender install instructions](https://pyrender.readthedocs.io/en/latest/install/index.html).
                    2) Download SMPL data from https://smpl.is.tue.mpg.de/. See https://github.com/shubham-goel/4D-Humans/ for an example.
                    3) Copy the required SMPL files (`smpl_mean_params.npz`, `SMPL_to_J19.pkl`, `smpl/SMPL_NEUTRAL.pkl`) to `fourm/utils/hmr2_utils/data` .
                    """)
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
                    top_p = gr.Slider(label="Top-p", minimum=0.0, maximum=1.0, step=0.01, value=0.8)
                    top_k = gr.Slider(label="Top-k", minimum=0.0, maximum=1.0, step=0.01, value=0.0)
        
            result = gr.Gallery(
                label="Predictions", show_label=True, elem_id="gallery", type='pil',
                columns=[4], rows=None, object_fit="contain", height="auto"
            )

        gr.Examples(
            examples = examples,
            fn = infer,
            inputs = [img_path],
            outputs = [result],
            cache_examples='lazy',
        )

    run_button.click(
        fn = infer,
        inputs = [img_path, seed, randomize_seed, target_modalities, top_p, top_k],
        outputs = [result]
    )

demo.queue(max_size=10).launch()