import os
import torch
import spaces
import gradio as gr
import numpy as np
from PIL import Image
import ml_collections
from torchvision.utils import save_image, make_grid
import torch.nn.functional as F
import einops
import random
import torchvision.transforms as standard_transforms

from huggingface_hub import hf_hub_download
hf_hub_download(repo_id="thu-ml/unidiffuser-v1", filename="autoencoder_kl.pth", local_dir='./models')
hf_hub_download(repo_id="mespinosami/COP-GEN-Beta", filename="nnet_ema_114000.pth", local_dir='./models')

import sys
sys.path.append('./src/COP-GEN-Beta')

import libs
from dpm_solver_pp import DPM_Solver, NoiseScheduleVP
from sample_n_triffuser import set_seed, stable_diffusion_beta_schedule, unpreprocess
import utils

from diffusers import AutoencoderKL
from Triffuser import *  # Triffuser.py is expected to be importable (e.g. alongside this script or on the appended sys.path); an absolute import avoids a relative-import error when run as a top-level script

# Function to load the denoising network and the latent autoencoder
def load_model(device='cuda'):
    nnet = Triffuser(num_modalities=4)
    checkpoint = torch.load('models/nnet_ema_114000.pth', map_location=device)
    nnet.load_state_dict(checkpoint)
    nnet.to(device)
    nnet.eval()

    autoencoder = libs.autoencoder.get_model(pretrained_path="models/autoencoder_kl.pth")
    autoencoder.to(device)
    autoencoder.eval()

    return nnet, autoencoder

print('Loading COP-GEN-Beta model...')
nnet, autoencoder = load_model()
to_PIL = standard_transforms.ToPILImage()
print('[DONE]')

def get_config(generate_modalities, condition_modalities, seed, num_inference_steps=50):
    config = ml_collections.ConfigDict()
    config.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    config.seed = seed
    config.n_samples = 1
    config.z_shape = (4, 32, 32)  # Shape of the latent vectors
    config.sample = {
        'sample_steps': num_inference_steps,
        'algorithm': "dpm_solver",
    }
    # Model config
    config.num_modalities = 4  # 4 modalities: DEM, S1RTC, S2L1C, S2L2A
    config.modalities = ['dem', 's1_rtc', 's2_l1c', 's2_l2a']
    # Network config
    config.nnet = {
        'name': 'triffuser_multi_post_ln',
        'img_size': 32,
        'in_chans': 4,
        'patch_size': 2,
        'embed_dim': 1024,
        'depth': 20,
        'num_heads': 16,
        'mlp_ratio': 4,
        'qkv_bias': False,
        'pos_drop_rate': 0.,
        'drop_rate': 0.,
        'attn_drop_rate': 0.,
        'mlp_time_embed': False,
        'num_modalities': 4,
        'use_checkpoint': True,
    }

    # Parse generate and condition modalities
    config.generate_modalities = generate_modalities
    config.generate_modalities = sorted(config.generate_modalities, key=lambda x: config.modalities.index(x))
    config.condition_modalities = condition_modalities if condition_modalities else []
    config.condition_modalities = sorted(config.condition_modalities, key=lambda x: config.modalities.index(x))
    config.generate_modalities_mask = [mod in config.generate_modalities for mod in config.modalities]
    config.condition_modalities_mask = [mod in config.condition_modalities for mod in config.modalities]
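    # For example (illustrative): generating s2_l2a conditioned on dem and s1_rtc yields
    #   generate_modalities_mask  = [False, False, False, True ]   # order: dem, s1_rtc, s2_l1c, s2_l2a
    #   condition_modalities_mask = [True,  True,  False, False]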
    # Validate modalities
    valid_modalities = {'s2_l1c', 's2_l2a', 's1_rtc', 'dem'}
    for mod in config.generate_modalities + config.condition_modalities:
        if mod not in valid_modalities:
            raise ValueError(f"Invalid modality: {mod}. Must be one of {valid_modalities}")
    # Check that generate and condition modalities don't overlap
    if set(config.generate_modalities) & set(config.condition_modalities):
        raise ValueError("Generate and condition modalities must be different")
    # Default data paths
    config.nnet_path = 'models/nnet_ema_114000.pth'
    #config.autoencoder = {"pretrained_path": "assets/stable-diffusion/autoencoder_kl_ema.pth"}

    return config

# Function to prepare input images for inference
def prepare_images(images):
    transforms = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(mean=(0.5,), std=(0.5,))
    ])
    img_tensors = []
    for img in images:
        img_tensors.append(transforms(img))  # Scale each image to [-1, 1]; the batch dimension is added later
    return img_tensors


def run_inference(config, nnet, autoencoder, img_tensors):
    set_seed(config.seed)
    img_tensors = [tensor.to(config.device) for tensor in img_tensors]
    # Create a context tensor for all modalities
    img_contexts = torch.randn(config.num_modalities, 1, 2 * config.z_shape[0],
                                config.z_shape[1], config.z_shape[2], device=config.device)
    with torch.no_grad():
        # Encode the input images with autoencoder
        z_conds = [autoencoder.encode_moments(tensor.unsqueeze(0)) for tensor in img_tensors]
        # Create mapping of conditional modalities indices to the encoded inputs
        cond_indices = [i for i, is_cond in enumerate(config.condition_modalities_mask) if is_cond]
        # Check if we have the right number of inputs
        if len(cond_indices) != len(z_conds):
            raise ValueError(f"Number of conditioning modalities ({len(cond_indices)}) must match number of input images ({len(z_conds)})")
        # Assign each encoded input to the corresponding modality
        for i, z_cond in zip(cond_indices, z_conds):
            img_contexts[i] = z_cond
    # Sample values from the distribution (mean and variance)
    z_imgs = torch.stack([autoencoder.sample(img_context) for img_context in img_contexts])
    # Generate initial noise for the modalities being generated
    _z_init = torch.randn(len(config.generate_modalities), 1, *z_imgs[0].shape[1:], device=config.device)

    def combine_joint(z_list):
        """Combine individual modality tensors into a single concatenated tensor"""
        return torch.concat([einops.rearrange(z_i, 'B C H W -> B (C H W)') for z_i in z_list], dim=-1)

    def split_joint(x, z_imgs, config):
        """
        Split the combined tensor back into individual modality tensors
        and arrange them according to the full set of modalities
        """
        C, H, W = config.z_shape
        z_dim = C * H * W
        z_generated = x.split([z_dim] * len(config.generate_modalities), dim=1)
        z_generated = {modality: einops.rearrange(z_i, 'B (C H W) -> B C H W', C=C, H=H, W=W)
                    for z_i, modality in zip(z_generated, config.generate_modalities)}
        z = []
        for i, modality in enumerate(config.modalities):
            if modality in config.generate_modalities: # Modalities that are being denoised
                z.append(z_generated[modality])
            elif modality in config.condition_modalities: # Modalities that are being conditioned on
                z.append(z_imgs[i])
            else: # Modalities that are ignored
                z.append(torch.randn(x.shape[0], C, H, W, device=config.device))

        return z

    _x_init = combine_joint(_z_init) # Initial tensor for the modalities being generated
    _betas = stable_diffusion_beta_schedule()
    N = len(_betas)

    def model_fn(x, t_continuous):
        t = t_continuous * N

        # Create timesteps for each modality based on the generate mask
        timesteps = [t if mask else torch.zeros_like(t) for mask in config.generate_modalities_mask]
        # Split the input into a list of tensors for all modalities
        z = split_joint(x, z_imgs, config)
        # Call the network with the right format
        z_out = nnet(z, t_imgs=timesteps)
        # Select only the generated modalities for the denoising process
        z_out_generated = [z_out[i]
                        for i, modality in enumerate(config.modalities)
                        if modality in config.generate_modalities]
        # Combine the outputs back into a single tensor
        return combine_joint(z_out_generated)

    # Sample using the DPM-Solver with exact parameters from sample_n_triffuser.py
    noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=config.device).float())
    dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)

    # Generate samples
    with torch.no_grad():
        with torch.autocast(device_type=config.device):
            x = dpm_solver.sample(_x_init, steps=config.sample.sample_steps, eps=1. / N, T=1.)

    # Split the result back into individual modality tensors
    _zs = split_joint(x, z_imgs, config)

    # Replace conditional modalities with the original images
    for i, mask in enumerate(config.condition_modalities_mask):
        if mask:
            _zs[i] = z_imgs[i]

    # Decode and unprocess the generated samples
    generated_samples = []
    for i, modality in enumerate(config.modalities):
        if modality in config.generate_modalities:
            sample = autoencoder.decode(_zs[i]) # Decode the latent representation
            sample = unpreprocess(sample) # Unpreprocess to [0, 1] range
            generated_samples.append((modality, sample))

    return generated_samples

def custom_inference(images, generate_modalities, condition_modalities, num_inference_steps, seed=None):
    """
    Run custom inference with user-specified parameters

    Args:
        generate_modalities: List of modalities to generate
        condition_modalities: List of modalities to condition on
        image_paths: Path to conditioning image or list of paths (ordered to match condition_modalities)

    Returns:
        Dict mapping modality names to generated tensors
    """
    if seed is None:
        seed = random.randint(0, int(1e8))

    img_tensors = prepare_images(images)

    config = get_config(generate_modalities, condition_modalities, seed=seed)
    config.sample.sample_steps = num_inference_steps
    generated_samples = run_inference(config, nnet, autoencoder, img_tensors)
    results = {modality: tensor for modality, tensor in generated_samples}
    
    return results
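
# Illustrative sketch (assumption, not used by the Gradio handler below): how
# custom_inference can be called directly. The helper name and the 256x256 input
# size are assumptions; 256x256 matches the 32x32 latent grid at the autoencoder's
# 8x spatial downsampling.
def _example_s1_to_s2(s1_img):
    """Given a 256x256 PIL S1 RTC thumbnail, generate both Sentinel-2 products conditioned on it."""
    results = custom_inference(
        images=[s1_img],
        generate_modalities=['s2_l1c', 's2_l2a'],
        condition_modalities=['s1_rtc'],
        num_inference_steps=50,
        seed=42,
    )
    # Generated tensors are (1, C, H, W) in [0, 1]; convert to PIL for display
    return {mod: to_PIL(t[0]) for mod, t in results.items()}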

@spaces.GPU
def generate_output(s2l1c_input, s2l2a_input, s1rtc_input, dem_input, num_inference_steps_slider, seed_number, ignore_seed):

    seed = seed_number if not ignore_seed else None

    s2l2a_active = s2l2a_input is not None
    s2l1c_active = s2l1c_input is not None
    s1rtc_active = s1rtc_input is not None
    dem_active = dem_input is not None

    if s2l2a_active and s2l1c_active and s1rtc_active and dem_active:
        gr.Warning("All four modalities were provided, so there is nothing to generate. Remove the input(s) you would like to generate.")
        return s2l1c_input, s2l2a_input, s1rtc_input, dem_input

    # Map each provided input to its modality name (keyed, so the ordering can be fixed below)
    input_images = {}
    if s2l1c_active:
        input_images['s2_l1c'] = s2l1c_input
    if s2l2a_active:
        input_images['s2_l2a'] = s2l2a_input
    if s1rtc_active:
        input_images['s1_rtc'] = s1rtc_input
    if dem_active:
        input_images['dem'] = dem_input
    
    condition_modalities = list(input_images.keys())
    
    # Sort modalities and collect images in the same order
    sorted_modalities = sorted(condition_modalities, key=lambda x: ['dem', 's1_rtc', 's2_l1c', 's2_l2a'].index(x))
    sorted_images = [input_images[mod] for mod in sorted_modalities]
    
    imgs_out = custom_inference(
        images=sorted_images,
        generate_modalities=[el for el in ['s2_l1c', 's2_l2a', 's1_rtc', 'dem'] if el not in condition_modalities],
        condition_modalities=sorted_modalities,
        num_inference_steps=num_inference_steps_slider,
        seed=seed
    )

    output = []

    # Collect outputs
    if s2l1c_active:
        output.append(s2l1c_input)
    else:
        output.append(to_PIL(imgs_out['s2_l1c'][0]))
    if s2l2a_active:
        output.append(s2l2a_input)
    else:
        output.append(to_PIL(imgs_out['s2_l2a'][0]))
    if s1rtc_active:
        output.append(s1rtc_input)
    else:
        output.append(to_PIL(imgs_out['s1_rtc'][0]))
    if dem_active:
        output.append(dem_input)
    else:
        output.append(to_PIL(imgs_out['dem'][0]))
        
    return output
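
# Minimal launch sketch (assumption: the actual Space may build its UI elsewhere).
# It shows how generate_output maps onto Gradio components: four optional image
# inputs, a steps slider, a seed box with an "ignore seed" checkbox, and four image
# outputs returned in the same s2_l1c / s2_l2a / s1_rtc / dem order as the handler.
if __name__ == "__main__":
    with gr.Blocks(title="COP-GEN-Beta demo") as demo:
        with gr.Row():
            s2l1c_in = gr.Image(label="S2 L1C", type="pil")
            s2l2a_in = gr.Image(label="S2 L2A", type="pil")
            s1rtc_in = gr.Image(label="S1 RTC", type="pil")
            dem_in = gr.Image(label="DEM", type="pil")
        steps = gr.Slider(minimum=10, maximum=100, value=50, step=1, label="Inference steps")
        seed = gr.Number(value=42, precision=0, label="Seed")
        ignore_seed = gr.Checkbox(value=True, label="Ignore seed (random each run)")
        run_btn = gr.Button("Generate")
        with gr.Row():
            s2l1c_out = gr.Image(label="S2 L1C (generated)")
            s2l2a_out = gr.Image(label="S2 L2A (generated)")
            s1rtc_out = gr.Image(label="S1 RTC (generated)")
            dem_out = gr.Image(label="DEM (generated)")
        run_btn.click(
            fn=generate_output,
            inputs=[s2l1c_in, s2l2a_in, s1rtc_in, dem_in, steps, seed, ignore_seed],
            outputs=[s2l1c_out, s2l2a_out, s1rtc_out, dem_out],
        )
    demo.launch()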