jiacheng-ye committed on
Commit
60bd28a
·
verified ·
1 Parent(s): 5831947

Upload model

Browse files
Files changed (1) hide show
  1. generation_utils.py +11 -6
generation_utils.py CHANGED
@@ -302,7 +302,7 @@ class DreamGenerationMixin:
302
  **kwargs,
303
  ) -> Union[DreamModelOutput, torch.LongTensor]:
304
  # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
305
- tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria
306
  generation_config = self._prepare_generation_config(generation_config, **kwargs)
307
 
308
  # 2. Define model inputs
@@ -355,6 +355,7 @@ class DreamGenerationMixin:
355
  input_ids,
356
  attention_mask=attention_mask,
357
  generation_config=generation_config,
 
358
  )
359
  return result
360
 
@@ -363,6 +364,7 @@ class DreamGenerationMixin:
363
  input_ids: torch.LongTensor,
364
  attention_mask: Optional[torch.LongTensor],
365
  generation_config: DreamGenerationConfig,
 
366
  ) -> Union[DreamModelOutput, torch.LongTensor]:
367
  # init values
368
  output_history = generation_config.output_history
@@ -402,7 +404,7 @@ class DreamGenerationMixin:
402
  mask_index = (x == mask_token_id)
403
  logits = self(x, attention_mask, tok_idx).logits
404
  logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1)
405
- logits = logits[mask_index]
406
  t = timesteps[i]
407
  s = timesteps[i + 1]
408
 
@@ -410,15 +412,15 @@ class DreamGenerationMixin:
410
  p_transfer = 1 - s / t if i < steps - 1 else 1
411
  x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id
412
  transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
413
- _, x0[transfer_index_t_s]= sample_tokens(logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k)
414
  x[mask_index] = x0.clone()
415
  else:
416
  if alg == 'maskgit_plus':
417
- confidence, x0 = sample_tokens(logits, temperature=temperature, top_p=top_p, top_k=top_k)
418
  elif alg == 'topk_margin':
419
- confidence, x0 = sample_tokens(logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True)
420
  elif alg == 'entropy':
421
- confidence, x0 = sample_tokens(logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True)
422
  else:
423
  raise RuntimeError(f"Unknown alg: {alg}")
424
  num_mask_token = mask_index.sum()
@@ -433,6 +435,9 @@ class DreamGenerationMixin:
433
  x0_ = torch.zeros_like(x0, device=self.device, dtype=torch.long) + mask_token_id
434
  x0_[transfer_index] = x0[transfer_index].clone()
435
  x[mask_index] = x0_
 
 
 
436
 
437
  if histories is not None:
438
  histories.append(x.clone())
 
302
  **kwargs,
303
  ) -> Union[DreamModelOutput, torch.LongTensor]:
304
  # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
305
+ generation_hook_func = kwargs.pop("generation_hook_func", lambda x, logits: (x, logits))
306
  generation_config = self._prepare_generation_config(generation_config, **kwargs)
307
 
308
  # 2. Define model inputs
 
355
  input_ids,
356
  attention_mask=attention_mask,
357
  generation_config=generation_config,
358
+ generation_hook_func=generation_hook_func
359
  )
360
  return result
361
 
 
364
  input_ids: torch.LongTensor,
365
  attention_mask: Optional[torch.LongTensor],
366
  generation_config: DreamGenerationConfig,
367
+ generation_hook_func
368
  ) -> Union[DreamModelOutput, torch.LongTensor]:
369
  # init values
370
  output_history = generation_config.output_history
 
404
  mask_index = (x == mask_token_id)
405
  logits = self(x, attention_mask, tok_idx).logits
406
  logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1)
407
+ mask_logits = logits[mask_index]
408
  t = timesteps[i]
409
  s = timesteps[i + 1]
410
 
 
412
  p_transfer = 1 - s / t if i < steps - 1 else 1
413
  x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id
414
  transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
415
+ _, x0[transfer_index_t_s]= sample_tokens(mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k)
416
  x[mask_index] = x0.clone()
417
  else:
418
  if alg == 'maskgit_plus':
419
+ confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
420
  elif alg == 'topk_margin':
421
+ confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True)
422
  elif alg == 'entropy':
423
+ confidence, x0 = sample_tokens(mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True)
424
  else:
425
  raise RuntimeError(f"Unknown alg: {alg}")
426
  num_mask_token = mask_index.sum()
 
435
  x0_ = torch.zeros_like(x0, device=self.device, dtype=torch.long) + mask_token_id
436
  x0_[transfer_index] = x0[transfer_index].clone()
437
  x[mask_index] = x0_
438
+
439
+ # this allows user-defined control of the intermediate steps
440
+ x, logits = generation_hook_func(x, logits)
441
 
442
  if histories is not None:
443
  histories.append(x.clone())