mikonvergence committed
Commit 886e812 · 1 Parent(s): ee8d564

First test of backend

Files changed (4)
  1. app.py +11 -5
  2. requirements.txt +1 -0
  3. src/backend.py +278 -0
  4. src/utils.py +2 -2
app.py CHANGED
@@ -1,6 +1,7 @@
import gradio as gr
import spaces
from src.utils import *
+ from src.backend import *

theme = gr.themes.Soft(primary_hue="cyan", secondary_hue="zinc", font=[gr.themes.GoogleFont("Source Sans 3", weights=(400, 600)),'arial'])

@@ -9,7 +10,9 @@ with gr.Blocks(theme=theme) as demo:
    gr.Markdown("# 🗾 COP-GEN-Beta: Unified Generative Modelling of COPernicus Imagery Thumbnails")
    gr.Markdown("### Miguel Espinosa, Valerio Marsocci, Yuru Jia, Elliot J. Crowley, Mikolaj Czerkawski")
    gr.Markdown('[[Website](https://miquel-espinosa.github.io/cop-gen-beta/)] [[GitHub](https://github.com/miquel-espinosa/COP-GEN-Beta)] [[Model](https://huggingface.co/mespinosami/COP-GEN-Beta)] [[Dataset](https://huggingface.co/Major-TOM)]')
-
+
+     gr.Markdown('⚠️ NOTE: This is a prototype Beta model of COP-GEN. It is based on image thumbnails of Major TOM and does not yet support raw source data. The hillshade visualisation is used for elevation. The full model COP-GEN is coming soon.')
+
    with gr.Column(elem_classes="abstract"):

        with gr.Accordion("Abstract", open=False) as abstract:
@@ -48,8 +51,7 @@ with gr.Blocks(theme=theme) as demo:
        dem_output = gr.Image(label="DEM (Elevation)", interactive=False)

    with gr.Accordion("Advanced Options", open=False) as advanced_options:
-         num_inference_steps_slider = gr.Slider(minimum=10, maximum=1000, step=10, value=50, label="Inference Steps")
-         guidance_scale_slider = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, value=7.5, label="Guidance Scale")
+         num_inference_steps_slider = gr.Slider(minimum=10, maximum=1000, step=10, value=10, label="Inference Steps")
        with gr.Row():
            seed_number = gr.Number(value=6378, label="Seed")
            seed_checkbox = gr.Checkbox(value=True, label="Random")
@@ -61,8 +63,12 @@ with gr.Blocks(theme=theme) as demo:

    generate_button.click(
        fn=generate_output,
-         inputs=[s2l1c_input, s2l2a_input, s1rtc_input, dem_input, num_inference_steps_slider, guidance_scale_slider, seed_number, seed_checkbox],
+         inputs=[s2l1c_input, s2l1c_active,
+                 s2l2a_input, s2l2a_active,
+                 s1rtc_input, s1rtc_active,
+                 dem_input, dem_active,
+                 num_inference_steps_slider, seed_number, seed_checkbox],
        outputs=[s2l1c_output, s2l2a_output, s1rtc_output, dem_output],
    )

- demo.queue().launch()
+ demo.queue().launch(share=True)
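For context, here is a minimal, self-contained sketch of the wiring pattern the updated click handler relies on: each modality has an image component paired with an "active" checkbox, and the callback receives them as interleaved (image, active) arguments. The checkbox components (`s2l1c_active`, etc.) and the full layout are defined elsewhere in app.py and are not shown in this diff, so the layout below is illustrative only, and the `generate_output` stub here is a placeholder for the real function in src/backend.py.

```python
import gradio as gr

# Hypothetical stand-in for src.backend.generate_output: it simply echoes the
# inputs back so the wiring can be exercised without the diffusion model.
def generate_output(s2l1c, s2l1c_on, s2l2a, s2l2a_on, s1rtc, s1rtc_on, dem, dem_on,
                    steps, seed, random_seed):
    return s2l1c, s2l2a, s1rtc, dem

with gr.Blocks() as sketch:
    with gr.Row():
        s2l1c_input = gr.Image(label="S2 L1C")
        s2l2a_input = gr.Image(label="S2 L2A")
        s1rtc_input = gr.Image(label="S1 RTC")
        dem_input = gr.Image(label="DEM")
    with gr.Row():
        s2l1c_active = gr.Checkbox(value=False, label="Condition on S2 L1C")
        s2l2a_active = gr.Checkbox(value=False, label="Condition on S2 L2A")
        s1rtc_active = gr.Checkbox(value=False, label="Condition on S1 RTC")
        dem_active = gr.Checkbox(value=False, label="Condition on DEM")
    num_inference_steps_slider = gr.Slider(minimum=10, maximum=1000, step=10, value=10,
                                           label="Inference Steps")
    seed_number = gr.Number(value=6378, label="Seed")
    seed_checkbox = gr.Checkbox(value=True, label="Random")
    generate_button = gr.Button("Generate")
    with gr.Row():
        s2l1c_output = gr.Image(label="S2 L1C", interactive=False)
        s2l2a_output = gr.Image(label="S2 L2A", interactive=False)
        s1rtc_output = gr.Image(label="S1 RTC", interactive=False)
        dem_output = gr.Image(label="DEM (Elevation)", interactive=False)

    # Same interleaved (input, active) ordering as the commit's click handler.
    generate_button.click(
        fn=generate_output,
        inputs=[s2l1c_input, s2l1c_active,
                s2l2a_input, s2l2a_active,
                s1rtc_input, s1rtc_active,
                dem_input, dem_active,
                num_inference_steps_slider, seed_number, seed_checkbox],
        outputs=[s2l1c_output, s2l2a_output, s1rtc_output, dem_output],
    )

if __name__ == "__main__":
    sketch.queue().launch()
```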
requirements.txt CHANGED
@@ -8,3 +8,4 @@ scikit-learn
huggingface_hub
transformers==4.51.1
accelerate==1.5.2
+ ml_collections
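The new `ml_collections` dependency is used by src/backend.py to build its sampling configuration: `ml_collections.ConfigDict` gives attribute-style access to nested settings such as `config.sample.sample_steps`. A tiny sketch of that pattern, with illustrative values mirroring `get_config` below:

```python
import ml_collections

config = ml_collections.ConfigDict()
config.seed = 6378
config.z_shape = (4, 32, 32)   # latent shape used by the autoencoder
config.sample = {
    'sample_steps': 10,        # default slider value in app.py
    'algorithm': 'dpm_solver',
}

# Nested dicts assigned to a ConfigDict become ConfigDicts themselves,
# so fields can be read and overridden with attribute access, exactly as
# custom_inference() does with config.sample.sample_steps.
config.sample.sample_steps = 50
print(config.sample.sample_steps)  # 50
```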
src/backend.py ADDED
@@ -0,0 +1,278 @@
+ import os
+ import torch
+ import numpy as np
+ from PIL import Image
+ import ml_collections
+ from torchvision.utils import save_image, make_grid
+ import torch.nn.functional as F
+ import einops
+ import random
+ import torchvision.transforms as standard_transforms
+
+ from huggingface_hub import hf_hub_download
+ hf_hub_download(repo_id="thu-ml/unidiffuser-v1", filename="autoencoder_kl.pth", local_dir='./models')
+ hf_hub_download(repo_id="mespinosami/COP-GEN-Beta", filename="nnet_ema_114000.pth", local_dir='./models')
+
+ import sys
+ sys.path.append('./src/COP-GEN-Beta')
+
+ import libs
+ from dpm_solver_pp import DPM_Solver, NoiseScheduleVP
+ from sample_n_triffuser import set_seed, stable_diffusion_beta_schedule, unpreprocess
+ import utils
+
+ from diffusers import AutoencoderKL
+ from .Triffuser import *
+
+ # Function to load model
+ def load_model(device='cuda'):
+     nnet = Triffuser(num_modalities=4)
+     checkpoint = torch.load('models/nnet_ema_114000.pth', map_location='cuda')
+     nnet.load_state_dict(checkpoint)
+     nnet.to(device)
+     nnet.eval()
+
+     autoencoder = libs.autoencoder.get_model(pretrained_path="models/autoencoder_kl.pth")
+     autoencoder.to(device)
+     autoencoder.eval()
+
+     return nnet, autoencoder
+
+ print('Loading COP-GEN-Beta model...')
+ nnet, autoencoder = load_model()
+ to_PIL = standard_transforms.ToPILImage()
+ print('[DONE]')
+
+ def get_config(generate_modalities, condition_modalities, seed, num_inference_steps=50):
+     config = ml_collections.ConfigDict()
+     config.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+     config.seed = seed
+     config.n_samples = 1
+     config.z_shape = (4, 32, 32)  # Shape of the latent vectors
+     config.sample = {
+         'sample_steps': num_inference_steps,
+         'algorithm': "dpm_solver",
+     }
+     # Model config
+     config.num_modalities = 4  # 4 modalities: DEM, S1RTC, S2L1C, S2L2A
+     config.modalities = ['dem', 's1_rtc', 's2_l1c', 's2_l2a']
+     # Network config
+     config.nnet = {
+         'name': 'triffuser_multi_post_ln',
+         'img_size': 32,
+         'in_chans': 4,
+         'patch_size': 2,
+         'embed_dim': 1024,
+         'depth': 20,
+         'num_heads': 16,
+         'mlp_ratio': 4,
+         'qkv_bias': False,
+         'pos_drop_rate': 0.,
+         'drop_rate': 0.,
+         'attn_drop_rate': 0.,
+         'mlp_time_embed': False,
+         'num_modalities': 4,
+         'use_checkpoint': True,
+     }
+
+     # Parse generate and condition modalities
+     config.generate_modalities = generate_modalities
+     config.generate_modalities = sorted(config.generate_modalities, key=lambda x: config.modalities.index(x))
+     config.condition_modalities = condition_modalities if condition_modalities else []
+     config.condition_modalities = sorted(config.condition_modalities, key=lambda x: config.modalities.index(x))
+     config.generate_modalities_mask = [mod in config.generate_modalities for mod in config.modalities]
+     config.condition_modalities_mask = [mod in config.condition_modalities for mod in config.modalities]
+     # Validate modalities
+     valid_modalities = {'s2_l1c', 's2_l2a', 's1_rtc', 'dem'}
+     for mod in config.generate_modalities + config.condition_modalities:
+         if mod not in valid_modalities:
+             raise ValueError(f"Invalid modality: {mod}. Must be one of {valid_modalities}")
+     # Check that generate and condition modalities don't overlap
+     if set(config.generate_modalities) & set(config.condition_modalities):
+         raise ValueError("Generate and condition modalities must be different")
+     # Default data paths
+     config.nnet_path = 'models/nnet_ema_114000.pth'
+     #config.autoencoder = {"pretrained_path": "assets/stable-diffusion/autoencoder_kl_ema.pth"}
+
+     return config
+
+ # Function to prepare images for inference
+ def prepare_images(images):
+     transforms = standard_transforms.Compose([
+         standard_transforms.ToTensor(),
+         standard_transforms.Normalize(mean=(0.5,), std=(0.5,))
+     ])
+     img_tensors = []
+     for img in images:
+         img_tensors.append(transforms(img))  # batch dimension is added later, in run_inference
+     return img_tensors
+
+
+ def run_inference(config, nnet, autoencoder, img_tensors):
+     set_seed(config.seed)
+     img_tensors = [tensor.to(config.device) for tensor in img_tensors]
+     # Create a context tensor for all modalities
+     img_contexts = torch.randn(config.num_modalities, 1, 2 * config.z_shape[0],
+                                config.z_shape[1], config.z_shape[2], device=config.device)
+     with torch.no_grad():
+         # Encode the input images with autoencoder
+         z_conds = [autoencoder.encode_moments(tensor.unsqueeze(0)) for tensor in img_tensors]
+         # Create mapping of conditional modalities indices to the encoded inputs
+         cond_indices = [i for i, is_cond in enumerate(config.condition_modalities_mask) if is_cond]
+         # Check if we have the right number of inputs
+         if len(cond_indices) != len(z_conds):
+             raise ValueError(f"Number of conditioning modalities ({len(cond_indices)}) must match number of input images ({len(z_conds)})")
+         # Assign each encoded input to the corresponding modality
+         for i, z_cond in zip(cond_indices, z_conds):
+             img_contexts[i] = z_cond
+     # Sample values from the distribution (mean and variance)
+     z_imgs = torch.stack([autoencoder.sample(img_context) for img_context in img_contexts])
+     # Generate initial noise for the modalities being generated
+     _z_init = torch.randn(len(config.generate_modalities), 1, *z_imgs[0].shape[1:], device=config.device)
+
+     def combine_joint(z_list):
+         """Combine individual modality tensors into a single concatenated tensor"""
+         return torch.concat([einops.rearrange(z_i, 'B C H W -> B (C H W)') for z_i in z_list], dim=-1)
+
+     def split_joint(x, z_imgs, config):
+         """
+         Split the combined tensor back into individual modality tensors
+         and arrange them according to the full set of modalities
+         """
+         C, H, W = config.z_shape
+         z_dim = C * H * W
+         z_generated = x.split([z_dim] * len(config.generate_modalities), dim=1)
+         z_generated = {modality: einops.rearrange(z_i, 'B (C H W) -> B C H W', C=C, H=H, W=W)
+                        for z_i, modality in zip(z_generated, config.generate_modalities)}
+         z = []
+         for i, modality in enumerate(config.modalities):
+             if modality in config.generate_modalities:  # Modalities that are being denoised
+                 z.append(z_generated[modality])
+             elif modality in config.condition_modalities:  # Modalities that are being conditioned on
+                 z.append(z_imgs[i])
+             else:  # Modalities that are ignored
+                 z.append(torch.randn(x.shape[0], C, H, W, device=config.device))
+
+         return z
+
+     _x_init = combine_joint(_z_init)  # Initial tensor for the modalities being generated
+     _betas = stable_diffusion_beta_schedule()
+     N = len(_betas)
+
+     def model_fn(x, t_continuous):
+         t = t_continuous * N
+
+         # Create timesteps for each modality based on the generate mask
+         timesteps = [t if mask else torch.zeros_like(t) for mask in config.generate_modalities_mask]
+         # Split the input into a list of tensors for all modalities
+         z = split_joint(x, z_imgs, config)
+         # Call the network with the right format
+         z_out = nnet(z, t_imgs=timesteps)
+         # Select only the generated modalities for the denoising process
+         z_out_generated = [z_out[i]
+                            for i, modality in enumerate(config.modalities)
+                            if modality in config.generate_modalities]
+         # Combine the outputs back into a single tensor
+         return combine_joint(z_out_generated)
+
+     # Sample using the DPM-Solver with exact parameters from sample_n_triffuser.py
+     noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=config.device).float())
+     dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
+
+     # Generate samples
+     with torch.no_grad():
+         with torch.autocast(device_type=config.device):
+             x = dpm_solver.sample(_x_init, steps=config.sample.sample_steps, eps=1. / N, T=1.)
+
+     # Split the result back into individual modality tensors
+     _zs = split_joint(x, z_imgs, config)
+
+     # Replace conditional modalities with the original images
+     for i, mask in enumerate(config.condition_modalities_mask):
+         if mask:
+             _zs[i] = z_imgs[i]
+
+     # Decode and unprocess the generated samples
+     generated_samples = []
+     for i, modality in enumerate(config.modalities):
+         if modality in config.generate_modalities:
+             sample = autoencoder.decode(_zs[i])  # Decode the latent representation
+             sample = unpreprocess(sample)  # Unpreprocess to [0, 1] range
+             generated_samples.append((modality, sample))
+
+     return generated_samples
+
+ def custom_inference(images, generate_modalities, condition_modalities, num_inference_steps, seed=None):
+     """
+     Run custom inference with user-specified parameters
+
+     Args:
+         generate_modalities: List of modalities to generate
+         condition_modalities: List of modalities to condition on
+         images: List of conditioning images (ordered to match condition_modalities)
+
+     Returns:
+         Dict mapping modality names to generated tensors
+     """
+     if seed is None:
+         seed = random.randint(0, int(1e8))
+
+     img_tensors = prepare_images(images)
+
+     config = get_config(generate_modalities, condition_modalities, seed=seed)
+     config.sample.sample_steps = num_inference_steps
+     generated_samples = run_inference(config, nnet, autoencoder, img_tensors)
+     results = {modality: tensor for modality, tensor in generated_samples}
+
+     return results
+
+ def generate_output(s2l1c_input, s2l1c_active, s2l2a_input, s2l2a_active, s1rtc_input, s1rtc_active, dem_input, dem_active, num_inference_steps_slider, seed_number, ignore_seed):
+
+     seed = seed_number if not ignore_seed else None
+
+     images = []
+     condition_modalities = []
+     if s2l2a_active:
+         images.append(s2l2a_input)
+         condition_modalities.append('s2_l2a')
+     if s2l1c_active:
+         images.append(s2l1c_input)
+         condition_modalities.append('s2_l1c')
+     if s1rtc_active:
+         images.append(s1rtc_input)
+         condition_modalities.append('s1_rtc')
+     if dem_active:
+         images.append(dem_input)
+         condition_modalities.append('dem')
+
+     imgs_out = custom_inference(
+         images=images,
+         generate_modalities=[el for el in ['s2_l2a', 's2_l1c', 's1_rtc', 'dem'] if el not in condition_modalities],
+         condition_modalities=condition_modalities,
+         num_inference_steps=num_inference_steps_slider,
+         seed=seed
+     )
+
+     output = []
+
+     # Collect outputs: pass conditioning images through unchanged, decode the rest
+     if s2l1c_active:
+         output.append(s2l1c_input)
+     else:
+         output.append(to_PIL(imgs_out['s2_l1c'][0]))
+     if s2l2a_active:
+         output.append(s2l2a_input)
+     else:
+         output.append(to_PIL(imgs_out['s2_l2a'][0]))
+     if s1rtc_active:
+         output.append(s1rtc_input)
+     else:
+         output.append(to_PIL(imgs_out['s1_rtc'][0]))
+     if dem_active:
+         output.append(dem_input)
+     else:
+         output.append(to_PIL(imgs_out['dem'][0]))
+
+     return output
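As a usage reference, the sketch below shows how the module's entry point could be called outside Gradio: conditioning on a single S2-L2A thumbnail and generating the remaining three modalities. The file path and the 256×256 thumbnail size are assumptions for illustration (the 32×32 latents suggest an 8× downsampling autoencoder); note that importing src.backend also downloads the checkpoints and loads the model, as shown above.

```python
from PIL import Image
from src.backend import custom_inference, to_PIL  # loads COP-GEN-Beta on import

# Hypothetical conditioning image: a Major TOM S2-L2A thumbnail (path is illustrative).
s2_l2a = Image.open("thumbnails/s2_l2a_example.png").convert("RGB").resize((256, 256))

results = custom_inference(
    images=[s2_l2a],                                  # ordered to match condition_modalities
    generate_modalities=['s2_l1c', 's1_rtc', 'dem'],  # everything not conditioned on
    condition_modalities=['s2_l2a'],
    num_inference_steps=10,
    seed=6378,
)

# Each value is a decoded tensor batch in [0, 1]; take the first sample and save it.
for modality, tensor in results.items():
    to_PIL(tensor[0]).save(f"generated_{modality}.png")
```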
src/utils.py CHANGED
@@ -100,8 +100,8 @@ def get_rows(grid_cell):
    l2a_df, l1c_df, rtc_df, and dem_df. It assumes these DataFrames are defined in the scope.
    Each element of the tuple is a Pandas Series representing a row.
    """
-     return l2a_df[l2a_df.grid_cell == grid_cell].iloc[0], \
-            l1c_df[l1c_df.grid_cell == grid_cell].iloc[0], \
+     return l1c_df[l1c_df.grid_cell == grid_cell].iloc[0], \
+            l2a_df[l2a_df.grid_cell == grid_cell].iloc[0], \
           rtc_df[rtc_df.grid_cell == grid_cell].iloc[0], \
           dem_df[dem_df.grid_cell == grid_cell].iloc[0]
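This change only reorders the returned tuple so the L1C row comes before the L2A row, matching the (s2l1c, s2l2a, s1rtc, dem) ordering used by the app. A minimal unpacking sketch; the grid cell ID is a made-up placeholder and the metadata DataFrames are assumed to be loaded inside src/utils.py as its docstring states:

```python
from src.utils import get_rows  # assumes l1c_df, l2a_df, rtc_df, dem_df are loaded in that module

# After this fix the tuple is ordered (L1C, L2A, RTC, DEM).
l1c_row, l2a_row, rtc_row, dem_row = get_rows("433U_183R")  # hypothetical grid cell
assert l1c_row.grid_cell == l2a_row.grid_cell == rtc_row.grid_cell == dem_row.grid_cell
```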