LimiTrix commited on
Commit
af9ffc2
·
1 Parent(s): dade80f
Files changed (1) hide show
  1. mar.py +5 -2
mar.py CHANGED
@@ -117,7 +117,7 @@ class MARBert(nn.Module):
117
  print("test")
118
  print("test")
119
  print("test")
120
-
121
  def initialize_weights(self):
122
  # parameters
123
  torch.nn.init.normal_(self.class_emb.weight, std=.02)
@@ -599,7 +599,7 @@ class MAR(nn.Module):
599
  for step in indices:
600
  cur_tokens = tokens.clone()
601
  print(cur_tokens.shape)
602
-
603
  # class embedding and CFG
604
  if labels is not None:
605
  class_embedding = self.class_emb(labels)
@@ -635,7 +635,10 @@ class MAR(nn.Module):
635
  mask_to_pred = torch.cat([mask_to_pred, mask_to_pred], dim=0)
636
 
637
  # sample token latents for this step
 
638
  z = z[mask_to_pred.nonzero(as_tuple=True)]
 
 
639
  # cfg schedule follow Muse
640
  if cfg_schedule == "linear":
641
  cfg_iter = 1 + (cfg - 1) * (self.seq_len - mask_len[0]) / self.seq_len
 
117
  print("test")
118
  print("test")
119
  print("test")
120
+
121
  def initialize_weights(self):
122
  # parameters
123
  torch.nn.init.normal_(self.class_emb.weight, std=.02)
 
599
  for step in indices:
600
  cur_tokens = tokens.clone()
601
  print(cur_tokens.shape)
602
+ print("+++++++++++++++++++++++")
603
  # class embedding and CFG
604
  if labels is not None:
605
  class_embedding = self.class_emb(labels)
 
635
  mask_to_pred = torch.cat([mask_to_pred, mask_to_pred], dim=0)
636
 
637
  # sample token latents for this step
638
+ print(z.shape)
639
  z = z[mask_to_pred.nonzero(as_tuple=True)]
640
+ print(z.shape)
641
+ print("==============================")
642
  # cfg schedule follow Muse
643
  if cfg_schedule == "linear":
644
  cfg_iter = 1 + (cfg - 1) * (self.seq_len - mask_len[0]) / self.seq_len