LimiTrix commited on
Commit
ad05bf8
·
1 Parent(s): 9bf7098
Files changed (1) hide show
  1. mar.py +2 -0
mar.py CHANGED
@@ -309,6 +309,8 @@ class MARBert(nn.Module):
309
  mask_to_pred = torch.cat([mask_to_pred, mask_to_pred], dim=0)
310
 
311
  # sample token latents for this step
 
 
312
  z = z[mask_to_pred.nonzero(as_tuple=True)]
313
  print(z.shape)
314
  # cfg schedule follow Muse
 
309
  mask_to_pred = torch.cat([mask_to_pred, mask_to_pred], dim=0)
310
 
311
  # sample token latents for this step
312
+ print(z.shape)
313
+ print("-----------")
314
  z = z[mask_to_pred.nonzero(as_tuple=True)]
315
  print(z.shape)
316
  # cfg schedule follow Muse