Upload model
Browse files- generation_utils.py +3 -1
generation_utils.py
CHANGED
@@ -403,7 +403,9 @@ class DreamGenerationMixin:
|
|
403 |
attention_mask = "full"
|
404 |
|
405 |
timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
|
406 |
-
|
|
|
|
|
407 |
for i in range(steps):
|
408 |
mask_index = (x == mask_token_id)
|
409 |
logits = self(x, attention_mask, tok_idx).logits
|
|
|
403 |
attention_mask = "full"
|
404 |
|
405 |
timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
|
406 |
+
|
407 |
+
# this allows user-defined token control of the intermediate steps
|
408 |
+
x = generation_tokens_hook_func(None, x, None)
|
409 |
for i in range(steps):
|
410 |
mask_index = (x == mask_token_id)
|
411 |
logits = self(x, attention_mask, tok_idx).logits
|