jiacheng-ye committed on
Commit
c667972
·
verified ·
1 Parent(s): 60bd28a

Upload model

Browse files
Files changed (1) hide show
  1. generation_utils.py +14 -6
generation_utils.py CHANGED
@@ -302,8 +302,9 @@ class DreamGenerationMixin:
302
  **kwargs,
303
  ) -> Union[DreamModelOutput, torch.LongTensor]:
304
  # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
305
- generation_hook_func = kwargs.pop("generation_hook_func", lambda x, logits: (x, logits))
306
  generation_config = self._prepare_generation_config(generation_config, **kwargs)
 
 
307
 
308
  # 2. Define model inputs
309
  assert inputs is not None
@@ -355,7 +356,8 @@ class DreamGenerationMixin:
355
  input_ids,
356
  attention_mask=attention_mask,
357
  generation_config=generation_config,
358
- generation_hook_func=generation_hook_func
 
359
  )
360
  return result
361
 
@@ -364,7 +366,8 @@ class DreamGenerationMixin:
364
  input_ids: torch.LongTensor,
365
  attention_mask: Optional[torch.LongTensor],
366
  generation_config: DreamGenerationConfig,
367
- generation_hook_func
 
368
  ) -> Union[DreamModelOutput, torch.LongTensor]:
369
  # init values
370
  output_history = generation_config.output_history
@@ -400,10 +403,15 @@ class DreamGenerationMixin:
400
  attention_mask = "full"
401
 
402
  timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
 
403
  for i in range(steps):
404
  mask_index = (x == mask_token_id)
405
  logits = self(x, attention_mask, tok_idx).logits
406
  logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1)
 
 
 
 
407
  mask_logits = logits[mask_index]
408
  t = timesteps[i]
409
  s = timesteps[i + 1]
@@ -435,9 +443,9 @@ class DreamGenerationMixin:
435
  x0_ = torch.zeros_like(x0, device=self.device, dtype=torch.long) + mask_token_id
436
  x0_[transfer_index] = x0[transfer_index].clone()
437
  x[mask_index] = x0_
438
-
439
- # this allows user-defined control of the intermediate steps
440
- x, logits = generation_hook_func(x, logits)
441
 
442
  if histories is not None:
443
  histories.append(x.clone())
 
302
  **kwargs,
303
  ) -> Union[DreamModelOutput, torch.LongTensor]:
304
  # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
 
305
  generation_config = self._prepare_generation_config(generation_config, **kwargs)
306
+ generation_tokens_hook_func = kwargs.pop("generation_tokens_hook_func", lambda step, x, logits: x)
307
+ generation_logits_hook_func = kwargs.pop("generation_logits_hook_func", lambda step, x, logits: logits)
308
 
309
  # 2. Define model inputs
310
  assert inputs is not None
 
356
  input_ids,
357
  attention_mask=attention_mask,
358
  generation_config=generation_config,
359
+ generation_tokens_hook_func=generation_tokens_hook_func,
360
+ generation_logits_hook_func=generation_logits_hook_func
361
  )
362
  return result
363
 
 
366
  input_ids: torch.LongTensor,
367
  attention_mask: Optional[torch.LongTensor],
368
  generation_config: DreamGenerationConfig,
369
+ generation_tokens_hook_func,
370
+ generation_logits_hook_func
371
  ) -> Union[DreamModelOutput, torch.LongTensor]:
372
  # init values
373
  output_history = generation_config.output_history
 
403
  attention_mask = "full"
404
 
405
  timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
406
+
407
  for i in range(steps):
408
  mask_index = (x == mask_token_id)
409
  logits = self(x, attention_mask, tok_idx).logits
410
  logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1)
411
+
412
+ # this allows user-defined logits control of the intermediate steps
413
+ logits = generation_logits_hook_func(i, x, logits)
414
+
415
  mask_logits = logits[mask_index]
416
  t = timesteps[i]
417
  s = timesteps[i + 1]
 
443
  x0_ = torch.zeros_like(x0, device=self.device, dtype=torch.long) + mask_token_id
444
  x0_[transfer_index] = x0[transfer_index].clone()
445
  x[mask_index] = x0_
446
+
447
+ # this allows user-defined token control of the intermediate steps
448
+ x = generation_tokens_hook_func(i, x, logits)
449
 
450
  if histories is not None:
451
  histories.append(x.clone())