test
Browse files
mar.py
CHANGED
@@ -309,6 +309,8 @@ class MARBert(nn.Module):
|
|
309 |
mask_to_pred = torch.cat([mask_to_pred, mask_to_pred], dim=0)
|
310 |
|
311 |
# sample token latents for this step
|
|
|
|
|
312 |
z = z[mask_to_pred.nonzero(as_tuple=True)]
|
313 |
print(z.shape)
|
314 |
# cfg schedule follow Muse
|
|
|
309 |
mask_to_pred = torch.cat([mask_to_pred, mask_to_pred], dim=0)
|
310 |
|
311 |
# sample token latents for this step
|
312 |
+
print(z.shape)
|
313 |
+
print("-----------")
|
314 |
z = z[mask_to_pred.nonzero(as_tuple=True)]
|
315 |
print(z.shape)
|
316 |
# cfg schedule follow Muse
|