test
mar.py
CHANGED
@@ -117,7 +117,7 @@ class MARBert(nn.Module):
         print("test")
         print("test")
         print("test")
-
+
     def initialize_weights(self):
         # parameters
         torch.nn.init.normal_(self.class_emb.weight, std=.02)
@@ -599,7 +599,7 @@ class MAR(nn.Module):
         for step in indices:
             cur_tokens = tokens.clone()
             print(cur_tokens.shape)
-
+            print("+++++++++++++++++++++++")
             # class embedding and CFG
             if labels is not None:
                 class_embedding = self.class_emb(labels)
@@ -635,7 +635,10 @@ class MAR(nn.Module):
             mask_to_pred = torch.cat([mask_to_pred, mask_to_pred], dim=0)

             # sample token latents for this step
+            print(z.shape)
             z = z[mask_to_pred.nonzero(as_tuple=True)]
+            print(z.shape)
+            print("==============================")
             # cfg schedule follow Muse
             if cfg_schedule == "linear":
                 cfg_iter = 1 + (cfg - 1) * (self.seq_len - mask_len[0]) / self.seq_len
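For reference, a minimal standalone sketch (not part of this commit) of the "linear" CFG schedule the last hunk touches. The helper name linear_cfg_scale and the example values are hypothetical; cfg, seq_len, and mask_len mirror the names used in mar.py.

# Sketch of the linear cfg schedule (follows Muse): the guidance scale ramps
# from 1 (all tokens still masked) up to `cfg` (no tokens left masked).
def linear_cfg_scale(cfg: float, seq_len: int, mask_len: int) -> float:
    return 1 + (cfg - 1) * (seq_len - mask_len) / seq_len

# Example: with cfg=4.0 and seq_len=256 the scale grows as decoding proceeds
# and fewer tokens remain masked at each step.
for remaining in (256, 192, 128, 64, 0):
    print(remaining, linear_cfg_scale(cfg=4.0, seq_len=256, mask_len=remaining))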