HReynaud committed
Commit 308c0d9 · 1 Parent(s): 766c801
Files changed (1):
  1. demo.py +5 -3
demo.py CHANGED
@@ -30,6 +30,8 @@ torch.set_grad_enabled(False)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 dtype = torch.float32
 
+print(f"Using device: {device}")
+
 # 4f4 latent space
 B, T, C, H, W = 1, 64, 4, 28, 28
 
@@ -141,7 +143,7 @@ def get_vae_scaler(path):
     return scaler
 
 
-generator = torch.Generator(device=device).manual_seed(0)
+# generator = torch.Generator(device=device).manual_seed(0)
 
 lifm = load_model("https://huggingface.co/HReynaud/EchoFlow/tree/main/lifm/FMiT-S2-4f4")
 lifm = lifm.to(device, dtype=dtype)
@@ -271,7 +273,7 @@ def generate_latent_image(mask, class_selection, sampling_steps=50):
         (B, C, H, W),
         device=device,
         dtype=dtype,
-        generator=generator,
+        # generator=generator,
     )
 
     lifm.forward_original = lifm.forward
@@ -439,7 +441,7 @@ def generate_animation(
         (B, C, T, H, W),
         device=device,
         dtype=dtype,
-        generator=generator,
+        # generator=generator,
     )
 
     # print(
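
Note on the change: with the fixed-seed generator commented out, the noise tensors drawn in both generate_latent_image and generate_animation now come from PyTorch's global RNG, so results are no longer reproducible across runs. Below is a minimal sketch of the two modes, assuming the same device/dtype and latent shapes as demo.py; the make_noise helper and its seed parameter are illustrative only, not part of the repository.

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float32

    def make_noise(shape, seed=None):
        # seed given: reproducible noise from a dedicated torch.Generator,
        # as demo.py did before this commit (manual_seed(0)).
        if seed is not None:
            generator = torch.Generator(device=device).manual_seed(seed)
            return torch.randn(shape, device=device, dtype=dtype, generator=generator)
        # seed omitted: fall back to the global RNG, which is what dropping
        # the generator= argument amounts to after this commit.
        return torch.randn(shape, device=device, dtype=dtype)

    # 4f4 latent space, same shapes as demo.py
    B, T, C, H, W = 1, 64, 4, 28, 28
    image_noise = make_noise((B, C, H, W), seed=0)   # deterministic across runs
    video_noise = make_noise((B, C, T, H, W))        # varies across runs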