Upload model
Browse files- generation_utils.py +14 -6
generation_utils.py
CHANGED
@@ -302,8 +302,9 @@ class DreamGenerationMixin:
|
|
302 |
**kwargs,
|
303 |
) -> Union[DreamModelOutput, torch.LongTensor]:
|
304 |
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
|
305 |
-
generation_hook_func = kwargs.pop("generation_hook_func", lambda x, logits: (x, logits))
|
306 |
generation_config = self._prepare_generation_config(generation_config, **kwargs)
|
|
|
|
|
307 |
|
308 |
# 2. Define model inputs
|
309 |
assert inputs is not None
|
@@ -355,7 +356,8 @@ class DreamGenerationMixin:
|
|
355 |
input_ids,
|
356 |
attention_mask=attention_mask,
|
357 |
generation_config=generation_config,
|
358 |
-
|
|
|
359 |
)
|
360 |
return result
|
361 |
|
@@ -364,7 +366,8 @@ class DreamGenerationMixin:
|
|
364 |
input_ids: torch.LongTensor,
|
365 |
attention_mask: Optional[torch.LongTensor],
|
366 |
generation_config: DreamGenerationConfig,
|
367 |
-
|
|
|
368 |
) -> Union[DreamModelOutput, torch.LongTensor]:
|
369 |
# init values
|
370 |
output_history = generation_config.output_history
|
@@ -400,10 +403,15 @@ class DreamGenerationMixin:
|
|
400 |
attention_mask = "full"
|
401 |
|
402 |
timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
|
|
|
403 |
for i in range(steps):
|
404 |
mask_index = (x == mask_token_id)
|
405 |
logits = self(x, attention_mask, tok_idx).logits
|
406 |
logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1)
|
|
|
|
|
|
|
|
|
407 |
mask_logits = logits[mask_index]
|
408 |
t = timesteps[i]
|
409 |
s = timesteps[i + 1]
|
@@ -435,9 +443,9 @@ class DreamGenerationMixin:
|
|
435 |
x0_ = torch.zeros_like(x0, device=self.device, dtype=torch.long) + mask_token_id
|
436 |
x0_[transfer_index] = x0[transfer_index].clone()
|
437 |
x[mask_index] = x0_
|
438 |
-
|
439 |
-
# this allows user-defined control of the intermediate steps
|
440 |
-
x
|
441 |
|
442 |
if histories is not None:
|
443 |
histories.append(x.clone())
|
|
|
302 |
**kwargs,
|
303 |
) -> Union[DreamModelOutput, torch.LongTensor]:
|
304 |
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
|
|
|
305 |
generation_config = self._prepare_generation_config(generation_config, **kwargs)
|
306 |
+
generation_tokens_hook_func = kwargs.pop("generation_tokens_hook_func", lambda step, x, logits: x)
|
307 |
+
generation_logits_hook_func = kwargs.pop("generation_logits_hook_func", lambda step, x, logits: logits)
|
308 |
|
309 |
# 2. Define model inputs
|
310 |
assert inputs is not None
|
|
|
356 |
input_ids,
|
357 |
attention_mask=attention_mask,
|
358 |
generation_config=generation_config,
|
359 |
+
generation_tokens_hook_func=generation_tokens_hook_func,
|
360 |
+
generation_logits_hook_func=generation_logits_hook_func
|
361 |
)
|
362 |
return result
|
363 |
|
|
|
366 |
input_ids: torch.LongTensor,
|
367 |
attention_mask: Optional[torch.LongTensor],
|
368 |
generation_config: DreamGenerationConfig,
|
369 |
+
generation_tokens_hook_func,
|
370 |
+
generation_logits_hook_func
|
371 |
) -> Union[DreamModelOutput, torch.LongTensor]:
|
372 |
# init values
|
373 |
output_history = generation_config.output_history
|
|
|
403 |
attention_mask = "full"
|
404 |
|
405 |
timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
|
406 |
+
|
407 |
for i in range(steps):
|
408 |
mask_index = (x == mask_token_id)
|
409 |
logits = self(x, attention_mask, tok_idx).logits
|
410 |
logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1)
|
411 |
+
|
412 |
+
# this allows user-defined logits control of the intermediate steps
|
413 |
+
logits = generation_logits_hook_func(i, x, logits)
|
414 |
+
|
415 |
mask_logits = logits[mask_index]
|
416 |
t = timesteps[i]
|
417 |
s = timesteps[i + 1]
|
|
|
443 |
x0_ = torch.zeros_like(x0, device=self.device, dtype=torch.long) + mask_token_id
|
444 |
x0_[transfer_index] = x0[transfer_index].clone()
|
445 |
x[mask_index] = x0_
|
446 |
+
|
447 |
+
# this allows user-defined token control of the intermediate steps
|
448 |
+
x = generation_tokens_hook_func(i, x, logits)
|
449 |
|
450 |
if histories is not None:
|
451 |
histories.append(x.clone())
|