jiacheng-ye commited on
Commit
9ccfc13
·
verified ·
1 Parent(s): c667972

Upload model

Browse files
Files changed (1) hide show
  1. generation_utils.py +3 -1
generation_utils.py CHANGED
@@ -403,7 +403,9 @@ class DreamGenerationMixin:
403
  attention_mask = "full"
404
 
405
  timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
406
-
 
 
407
  for i in range(steps):
408
  mask_index = (x == mask_token_id)
409
  logits = self(x, attention_mask, tok_idx).logits
 
403
  attention_mask = "full"
404
 
405
  timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
406
+
407
+ # this allows user-defined token control of the intermediate steps
408
+ x = generation_tokens_hook_func(None, x, None)
409
  for i in range(steps):
410
  mask_index = (x == mask_token_id)
411
  logits = self(x, attention_mask, tok_idx).logits