Sorting modalities in generate_output() backend.py for consistent generations

#2
Files changed (1) hide show
  1. src/backend.py +20 -29
src/backend.py CHANGED
@@ -242,25 +242,27 @@ def generate_output(s2l1c_input, s2l2a_input, s1rtc_input, dem_input, num_infere
242
  gr.Warning("You need to remove some of the inputs that you would like to generate. If all modalities are known, there is nothing to generate.")
243
  return s2l1c_input, s2l2a_input, s1rtc_input, dem_input
244
 
245
- images=[]
246
- condition_modalities=[]
247
  if s2l1c_active:
248
- images.append(s2l1c_input)
249
- condition_modalities.append('s2_l1c')
250
  if s2l2a_active:
251
- images.append(s2l2a_input)
252
- condition_modalities.append('s2_l2a')
253
  if s1rtc_active:
254
- images.append(s1rtc_input)
255
- condition_modalities.append('s1_rtc')
256
  if dem_active:
257
- images.append(dem_input)
258
- condition_modalities.append('dem')
259
-
 
 
 
 
 
260
  imgs_out = custom_inference(
261
- images=images,
262
  generate_modalities=[el for el in ['s2_l1c', 's2_l2a', 's1_rtc', 'dem'] if el not in condition_modalities],
263
- condition_modalities=condition_modalities,
264
  num_inference_steps=num_inference_steps_slider,
265
  seed=seed
266
  )
@@ -268,22 +270,11 @@ def generate_output(s2l1c_input, s2l2a_input, s1rtc_input, dem_input, num_infere
268
  output = []
269
 
270
  # Collect outputs
271
- if s2l1c_active:
272
- output.append(s2l1c_input)
273
- else:
274
- output.append(to_PIL(imgs_out['s2_l1c'][0]))
275
- if s2l2a_active:
276
- output.append(s2l2a_input)
277
- else:
278
- output.append(to_PIL(imgs_out['s2_l2a'][0]))
279
- if s1rtc_active:
280
- output.append(s1rtc_input)
281
- else:
282
- output.append(to_PIL(imgs_out['s1_rtc'][0]))
283
- if dem_active:
284
- output.append(dem_input)
285
- else:
286
- output.append(to_PIL(imgs_out['dem'][0]))
287
 
288
  return output
289
 
 
242
  gr.Warning("You need to remove some of the inputs that you would like to generate. If all modalities are known, there is nothing to generate.")
243
  return s2l1c_input, s2l2a_input, s1rtc_input, dem_input
244
 
245
+ # Instead of collecting in UI order, create ordered dictionaries
246
+ input_images = {}
247
  if s2l1c_active:
248
+ input_images['s2_l1c'] = s2l1c_input
 
249
  if s2l2a_active:
250
+ input_images['s2_l2a'] = s2l2a_input
 
251
  if s1rtc_active:
252
+ input_images['s1_rtc'] = s1rtc_input
 
253
  if dem_active:
254
+ input_images['dem'] = dem_input
255
+
256
+ condition_modalities = list(input_images.keys())
257
+
258
+ # Sort modalities and collect images in the same order
259
+ sorted_modalities = sorted(condition_modalities, key=lambda x: ['dem', 's1_rtc', 's2_l1c', 's2_l2a'].index(x))
260
+ sorted_images = [input_images[mod] for mod in sorted_modalities]
261
+
262
  imgs_out = custom_inference(
263
+ images=sorted_images,
264
  generate_modalities=[el for el in ['s2_l1c', 's2_l2a', 's1_rtc', 'dem'] if el not in condition_modalities],
265
+ condition_modalities=sorted_modalities,
266
  num_inference_steps=num_inference_steps_slider,
267
  seed=seed
268
  )
 
270
  output = []
271
 
272
  # Collect outputs
273
+ for modality in sorted_modalities:
274
+ if modality in input_images:
275
+ output.append(input_images[modality])
276
+ else:
277
+ output.append(to_PIL(imgs_out[modality][0]))
 
 
 
 
 
 
 
 
 
 
 
278
 
279
  return output
280