import os
import torch
import spaces
import gradio as gr
import numpy as np
from PIL import Image
import ml_collections
from torchvision.utils import save_image, make_grid
import torch.nn.functional as F
import einops
import random
import torchvision.transforms as standard_transforms
from huggingface_hub import hf_hub_download

# Download the pretrained weights into ./models
hf_hub_download(repo_id="thu-ml/unidiffuser-v1", filename="autoencoder_kl.pth", local_dir='./models')
hf_hub_download(repo_id="mespinosami/COP-GEN-Beta", filename="nnet_ema_114000.pth", local_dir='./models')

import sys
sys.path.append('./src/COP-GEN-Beta')

import libs
from dpm_solver_pp import DPM_Solver, NoiseScheduleVP
from sample_n_triffuser import set_seed, stable_diffusion_beta_schedule, unpreprocess
import utils
from diffusers import AutoencoderKL
from Triffuser import *  # resolved via the COP-GEN-Beta path appended above


# Function to load the model
def load_model(device='cuda'):
    nnet = Triffuser(num_modalities=4)
    checkpoint = torch.load('models/nnet_ema_114000.pth', map_location=device)
    nnet.load_state_dict(checkpoint)
    nnet.to(device)
    nnet.eval()

    autoencoder = libs.autoencoder.get_model(pretrained_path="models/autoencoder_kl.pth")
    autoencoder.to(device)
    autoencoder.eval()

    return nnet, autoencoder


print('Loading COP-GEN-Beta model...')
nnet, autoencoder = load_model()
to_PIL = standard_transforms.ToPILImage()
print('[DONE]')


def get_config(generate_modalities, condition_modalities, seed, num_inference_steps=50):
    config = ml_collections.ConfigDict()
    config.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    config.seed = seed
    config.n_samples = 1
    config.z_shape = (4, 32, 32)  # Shape of the latent vectors

    config.sample = {
        'sample_steps': num_inference_steps,
        'algorithm': "dpm_solver",
    }

    # Model config
    config.num_modalities = 4  # 4 modalities: DEM, S1RTC, S2L1C, S2L2A
    config.modalities = ['dem', 's1_rtc', 's2_l1c', 's2_l2a']

    # Network config
    config.nnet = {
        'name': 'triffuser_multi_post_ln',
        'img_size': 32,
        'in_chans': 4,
        'patch_size': 2,
        'embed_dim': 1024,
        'depth': 20,
        'num_heads': 16,
        'mlp_ratio': 4,
        'qkv_bias': False,
        'pos_drop_rate': 0.,
        'drop_rate': 0.,
        'attn_drop_rate': 0.,
        'mlp_time_embed': False,
        'num_modalities': 4,
        'use_checkpoint': True,
    }

    # Parse generate and condition modalities, keeping the canonical modality order
    config.generate_modalities = sorted(generate_modalities, key=lambda x: config.modalities.index(x))
    config.condition_modalities = condition_modalities if condition_modalities else []
    config.condition_modalities = sorted(config.condition_modalities, key=lambda x: config.modalities.index(x))

    config.generate_modalities_mask = [mod in config.generate_modalities for mod in config.modalities]
    config.condition_modalities_mask = [mod in config.condition_modalities for mod in config.modalities]

    # Validate modalities
    valid_modalities = {'s2_l1c', 's2_l2a', 's1_rtc', 'dem'}
    for mod in config.generate_modalities + config.condition_modalities:
        if mod not in valid_modalities:
            raise ValueError(f"Invalid modality: {mod}. Must be one of {valid_modalities}")
    # Check that generate and condition modalities don't overlap
    if set(config.generate_modalities) & set(config.condition_modalities):
        raise ValueError("Generate and condition modalities must be different")

    # Default data paths
    config.nnet_path = 'models/nnet_ema_114000.pth'
    # config.autoencoder = {"pretrained_path": "assets/stable-diffusion/autoencoder_kl_ema.pth"}

    return config


# Function to prepare images for inference
def prepare_images(images):
    transforms = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ])
    # Normalize each image to [-1, 1]; the batch dimension is added later in run_inference
    return [transforms(img) for img in images]


def run_inference(config, nnet, autoencoder, img_tensors):
    set_seed(config.seed)
    img_tensors = [tensor.to(config.device) for tensor in img_tensors]

    # Create a context tensor for all modalities
    img_contexts = torch.randn(config.num_modalities, 1, 2 * config.z_shape[0],
                               config.z_shape[1], config.z_shape[2], device=config.device)

    with torch.no_grad():
        # Encode the input images with the autoencoder
        z_conds = [autoencoder.encode_moments(tensor.unsqueeze(0)) for tensor in img_tensors]

        # Map conditional modality indices to the encoded inputs
        cond_indices = [i for i, is_cond in enumerate(config.condition_modalities_mask) if is_cond]

        # Check that we have the right number of inputs
        if len(cond_indices) != len(z_conds):
            raise ValueError(f"Number of conditioning modalities ({len(cond_indices)}) "
                             f"must match number of input images ({len(z_conds)})")

        # Assign each encoded input to the corresponding modality
        for i, z_cond in zip(cond_indices, z_conds):
            img_contexts[i] = z_cond

    # Sample latents from the encoded distributions (mean and variance)
    z_imgs = torch.stack([autoencoder.sample(img_context) for img_context in img_contexts])

    # Generate initial noise for the modalities being generated
    _z_init = torch.randn(len(config.generate_modalities), 1, *z_imgs[0].shape[1:], device=config.device)

    def combine_joint(z_list):
        """Combine individual modality tensors into a single concatenated tensor"""
        return torch.concat([einops.rearrange(z_i, 'B C H W -> B (C H W)') for z_i in z_list], dim=-1)

    def split_joint(x, z_imgs, config):
        """
        Split the combined tensor back into individual modality tensors
        and arrange them according to the full set of modalities
        """
        C, H, W = config.z_shape
        z_dim = C * H * W
        z_generated = x.split([z_dim] * len(config.generate_modalities), dim=1)
        z_generated = {modality: einops.rearrange(z_i, 'B (C H W) -> B C H W', C=C, H=H, W=W)
                       for z_i, modality in zip(z_generated, config.generate_modalities)}

        z = []
        for i, modality in enumerate(config.modalities):
            if modality in config.generate_modalities:
                # Modalities that are being denoised
                z.append(z_generated[modality])
            elif modality in config.condition_modalities:
                # Modalities that are being conditioned on
                z.append(z_imgs[i])
            else:
                # Modalities that are ignored
                z.append(torch.randn(x.shape[0], C, H, W, device=config.device))

        return z

    _x_init = combine_joint(_z_init)  # Initial tensor for the modalities being generated

    _betas = stable_diffusion_beta_schedule()
    N = len(_betas)

    def model_fn(x, t_continuous):
        t = t_continuous * N
        # Create timesteps for each modality based on the generate mask
        timesteps = [t if mask else torch.zeros_like(t) for mask in config.generate_modalities_mask]

        # Split the input into a list of tensors for all modalities
        z = split_joint(x, z_imgs, config)

        # Call the network with the right format
        z_out = nnet(z, t_imgs=timesteps)
        # Select only the generated modalities for the denoising process
        z_out_generated = [z_out[i] for i, modality in enumerate(config.modalities)
                           if modality in config.generate_modalities]

        # Combine the outputs back into a single tensor
        return combine_joint(z_out_generated)

    # Sample using the DPM-Solver with the exact parameters from sample_n_triffuser.py
    noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=config.device).float())
    dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)

    # Generate samples
    with torch.no_grad():
        with torch.autocast(device_type=config.device):
            x = dpm_solver.sample(_x_init, steps=config.sample.sample_steps, eps=1. / N, T=1.)

    # Split the result back into individual modality tensors
    _zs = split_joint(x, z_imgs, config)

    # Replace conditional modalities with the original images
    for i, mask in enumerate(config.condition_modalities_mask):
        if mask:
            _zs[i] = z_imgs[i]

    # Decode and unpreprocess the generated samples
    generated_samples = []
    for i, modality in enumerate(config.modalities):
        if modality in config.generate_modalities:
            sample = autoencoder.decode(_zs[i])  # Decode the latent representation
            sample = unpreprocess(sample)        # Unpreprocess to the [0, 1] range
            generated_samples.append((modality, sample))

    return generated_samples


def custom_inference(images, generate_modalities, condition_modalities, num_inference_steps, seed=None):
    """
    Run custom inference with user-specified parameters.

    Args:
        images: List of conditioning images, ordered to match condition_modalities
        generate_modalities: List of modalities to generate
        condition_modalities: List of modalities to condition on
        num_inference_steps: Number of DPM-Solver sampling steps
        seed: Random seed; a random one is drawn if None

    Returns:
        Dict mapping modality names to generated tensors
    """
    if seed is None:
        seed = random.randint(0, int(1e8))

    img_tensors = prepare_images(images)

    config = get_config(generate_modalities, condition_modalities, seed=seed)
    config.sample.sample_steps = num_inference_steps

    generated_samples = run_inference(config, nnet, autoencoder, img_tensors)

    results = {modality: tensor for modality, tensor in generated_samples}
    return results


@spaces.GPU
def generate_output(s2l1c_input, s2l2a_input, s1rtc_input, dem_input, num_inference_steps_slider, seed_number, ignore_seed):
    seed = seed_number if not ignore_seed else None

    s2l2a_active = s2l2a_input is not None
    s2l1c_active = s2l1c_input is not None
    s1rtc_active = s1rtc_input is not None
    dem_active = dem_input is not None

    if s2l2a_active and s2l1c_active and s1rtc_active and dem_active:
        gr.Warning("All four modalities were provided, so there is nothing to generate. "
                   "Remove at least one of the inputs you would like to generate.")
        return s2l1c_input, s2l2a_input, s1rtc_input, dem_input

    # Collect the provided inputs in a dictionary keyed by modality (rather than UI order)
    input_images = {}
    if s2l1c_active:
        input_images['s2_l1c'] = s2l1c_input
    if s2l2a_active:
        input_images['s2_l2a'] = s2l2a_input
    if s1rtc_active:
        input_images['s1_rtc'] = s1rtc_input
    if dem_active:
        input_images['dem'] = dem_input

    condition_modalities = list(input_images.keys())

    # Sort modalities and collect the images in the same order
    sorted_modalities = sorted(condition_modalities, key=lambda x: ['dem', 's1_rtc', 's2_l1c', 's2_l2a'].index(x))
    sorted_images = [input_images[mod] for mod in sorted_modalities]

    imgs_out = custom_inference(
        images=sorted_images,
        generate_modalities=[el for el in ['s2_l1c', 's2_l2a', 's1_rtc', 'dem'] if el not in condition_modalities],
        condition_modalities=sorted_modalities,
        num_inference_steps=num_inference_steps_slider,
        seed=seed,
    )

    # Collect outputs: pass provided inputs through unchanged, convert generated tensors to PIL images
    output = []
    if s2l1c_active:
        output.append(s2l1c_input)
    else:
        output.append(to_PIL(imgs_out['s2_l1c'][0]))
    if s2l2a_active:
        output.append(s2l2a_input)
    else:
        output.append(to_PIL(imgs_out['s2_l2a'][0]))
    if s1rtc_active:
        output.append(s1rtc_input)
    else:
        output.append(to_PIL(imgs_out['s1_rtc'][0]))
    if dem_active:
        output.append(dem_input)
    else:
        output.append(to_PIL(imgs_out['dem'][0]))

    return output
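
# ---------------------------------------------------------------------------
# Illustrative usage sketch (an assumption, not the original Space UI): a
# minimal Gradio Blocks layout that wires the four image inputs to
# generate_output. Component labels, the slider range, and default values are
# guesses; the actual interface for this app may be defined elsewhere and
# differ.
# ---------------------------------------------------------------------------
def build_demo_sketch():
    with gr.Blocks() as demo:
        with gr.Row():
            s2l1c = gr.Image(label="S2 L1C", type="pil")
            s2l2a = gr.Image(label="S2 L2A", type="pil")
            s1rtc = gr.Image(label="S1 RTC", type="pil")
            dem = gr.Image(label="DEM", type="pil")
        steps = gr.Slider(minimum=10, maximum=100, value=50, step=1, label="Inference steps")
        seed = gr.Number(value=42, precision=0, label="Seed")
        ignore_seed = gr.Checkbox(value=True, label="Use a random seed")
        run = gr.Button("Generate")
        # Provided inputs are passed through unchanged; missing ones are generated.
        run.click(
            generate_output,
            inputs=[s2l1c, s2l2a, s1rtc, dem, steps, seed, ignore_seed],
            outputs=[s2l1c, s2l2a, s1rtc, dem],
        )
    return demo

# Example: build_demo_sketch().launch()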