Spaces: Running on Zero
Sorting modalities in generate_output() backend.py for consistent generations
#2
by mespinosami - opened

src/backend.py  +20 -29
src/backend.py  CHANGED
@@ -242,25 +242,27 @@ def generate_output(s2l1c_input, s2l2a_input, s1rtc_input, dem_input, num_infere
         gr.Warning("You need to remove some of the inputs that you would like to generate. If all modalities are known, there is nothing to generate.")
         return s2l1c_input, s2l2a_input, s1rtc_input, dem_input
 
-
-
+    # Instead of collecting in UI order, create ordered dictionaries
+    input_images = {}
     if s2l1c_active:
-
-        condition_modalities.append('s2_l1c')
+        input_images['s2_l1c'] = s2l1c_input
     if s2l2a_active:
-
-        condition_modalities.append('s2_l2a')
+        input_images['s2_l2a'] = s2l2a_input
     if s1rtc_active:
-
-        condition_modalities.append('s1_rtc')
+        input_images['s1_rtc'] = s1rtc_input
     if dem_active:
-
-
-
+        input_images['dem'] = dem_input
+
+    condition_modalities = list(input_images.keys())
+
+    # Sort modalities and collect images in the same order
+    sorted_modalities = sorted(condition_modalities, key=lambda x: ['dem', 's1_rtc', 's2_l1c', 's2_l2a'].index(x))
+    sorted_images = [input_images[mod] for mod in sorted_modalities]
+
     imgs_out = custom_inference(
-        images=
+        images=sorted_images,
         generate_modalities=[el for el in ['s2_l1c', 's2_l2a', 's1_rtc', 'dem'] if el not in condition_modalities],
-        condition_modalities=
+        condition_modalities=sorted_modalities,
         num_inference_steps=num_inference_steps_slider,
         seed=seed
     )
@@ -268,22 +270,11 @@ def generate_output(s2l1c_input, s2l2a_input, s1rtc_input, dem_input, num_infere
     output = []
 
     # Collect outputs
-
-
-
-
-
-        output.append(s2l2a_input)
-    else:
-        output.append(to_PIL(imgs_out['s2_l2a'][0]))
-    if s1rtc_active:
-        output.append(s1rtc_input)
-    else:
-        output.append(to_PIL(imgs_out['s1_rtc'][0]))
-    if dem_active:
-        output.append(dem_input)
-    else:
-        output.append(to_PIL(imgs_out['dem'][0]))
+    for modality in sorted_modalities:
+        if modality in input_images:
+            output.append(input_images[modality])
+        else:
+            output.append(to_PIL(imgs_out[modality][0]))
 
     return output
 
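For context, a minimal standalone sketch of the ordering guarantee the patch relies on (illustrative only; CANONICAL_ORDER and canonical_sort are not names from the repo): sorting the condition modalities against a fixed key list makes the images/condition_modalities pair handed to custom_inference independent of the order in which the UI inputs happened to be collected.

# Illustrative sketch, not repository code.
CANONICAL_ORDER = ['dem', 's1_rtc', 's2_l1c', 's2_l2a']  # same key list used in the diff

def canonical_sort(modalities):
    """Return the given modality names in the fixed canonical order."""
    return sorted(modalities, key=CANONICAL_ORDER.index)

# The same set of active inputs always yields the same ordering,
# so the model sees a consistent (image, modality) pairing across runs.
assert canonical_sort(['s2_l2a', 'dem']) == ['dem', 's2_l2a']
assert canonical_sort(['dem', 's2_l2a']) == ['dem', 's2_l2a']
assert canonical_sort(['s2_l1c', 's1_rtc', 's2_l2a']) == ['s1_rtc', 's2_l1c', 's2_l2a']

Because input_images is filled in one pass and then re-read through sorted_modalities, sorted_images and sorted_modalities stay aligned index-for-index by construction.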