Upload model
Browse files- generation_utils.py +11 -6
generation_utils.py
CHANGED
@@ -302,7 +302,7 @@ class DreamGenerationMixin:
|
|
302 |
**kwargs,
|
303 |
) -> Union[DreamModelOutput, torch.LongTensor]:
|
304 |
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
|
305 |
-
|
306 |
generation_config = self._prepare_generation_config(generation_config, **kwargs)
|
307 |
|
308 |
# 2. Define model inputs
|
@@ -355,6 +355,7 @@ class DreamGenerationMixin:
|
|
355 |
input_ids,
|
356 |
attention_mask=attention_mask,
|
357 |
generation_config=generation_config,
|
|
|
358 |
)
|
359 |
return result
|
360 |
|
@@ -363,6 +364,7 @@ class DreamGenerationMixin:
|
|
363 |
input_ids: torch.LongTensor,
|
364 |
attention_mask: Optional[torch.LongTensor],
|
365 |
generation_config: DreamGenerationConfig,
|
|
|
366 |
) -> Union[DreamModelOutput, torch.LongTensor]:
|
367 |
# init values
|
368 |
output_history = generation_config.output_history
|
@@ -402,7 +404,7 @@ class DreamGenerationMixin:
|
|
402 |
mask_index = (x == mask_token_id)
|
403 |
logits = self(x, attention_mask, tok_idx).logits
|
404 |
logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1)
|
405 |
-
|
406 |
t = timesteps[i]
|
407 |
s = timesteps[i + 1]
|
408 |
|
@@ -410,15 +412,15 @@ class DreamGenerationMixin:
|
|
410 |
p_transfer = 1 - s / t if i < steps - 1 else 1
|
411 |
x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id
|
412 |
transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
|
413 |
-
_, x0[transfer_index_t_s]= sample_tokens(
|
414 |
x[mask_index] = x0.clone()
|
415 |
else:
|
416 |
if alg == 'maskgit_plus':
|
417 |
-
confidence, x0 = sample_tokens(
|
418 |
elif alg == 'topk_margin':
|
419 |
-
confidence, x0 = sample_tokens(
|
420 |
elif alg == 'entropy':
|
421 |
-
confidence, x0 = sample_tokens(
|
422 |
else:
|
423 |
raise RuntimeError(f"Unknown alg: {alg}")
|
424 |
num_mask_token = mask_index.sum()
|
@@ -433,6 +435,9 @@ class DreamGenerationMixin:
|
|
433 |
x0_ = torch.zeros_like(x0, device=self.device, dtype=torch.long) + mask_token_id
|
434 |
x0_[transfer_index] = x0[transfer_index].clone()
|
435 |
x[mask_index] = x0_
|
|
|
|
|
|
|
436 |
|
437 |
if histories is not None:
|
438 |
histories.append(x.clone())
|
|
|
302 |
**kwargs,
|
303 |
) -> Union[DreamModelOutput, torch.LongTensor]:
|
304 |
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
|
305 |
+
generation_hook_func = kwargs.pop("generation_hook_func", lambda x, logits: (x, logits))
|
306 |
generation_config = self._prepare_generation_config(generation_config, **kwargs)
|
307 |
|
308 |
# 2. Define model inputs
|
|
|
355 |
input_ids,
|
356 |
attention_mask=attention_mask,
|
357 |
generation_config=generation_config,
|
358 |
+
generation_hook_func=generation_hook_func
|
359 |
)
|
360 |
return result
|
361 |
|
|
|
364 |
input_ids: torch.LongTensor,
|
365 |
attention_mask: Optional[torch.LongTensor],
|
366 |
generation_config: DreamGenerationConfig,
|
367 |
+
generation_hook_func
|
368 |
) -> Union[DreamModelOutput, torch.LongTensor]:
|
369 |
# init values
|
370 |
output_history = generation_config.output_history
|
|
|
404 |
mask_index = (x == mask_token_id)
|
405 |
logits = self(x, attention_mask, tok_idx).logits
|
406 |
logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1)
|
407 |
+
mask_logits = logits[mask_index]
|
408 |
t = timesteps[i]
|
409 |
s = timesteps[i + 1]
|
410 |
|
|
|
412 |
p_transfer = 1 - s / t if i < steps - 1 else 1
|
413 |
x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id
|
414 |
transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
|
415 |
+
_, x0[transfer_index_t_s]= sample_tokens(mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k)
|
416 |
x[mask_index] = x0.clone()
|
417 |
else:
|
418 |
if alg == 'maskgit_plus':
|
419 |
+
confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
|
420 |
elif alg == 'topk_margin':
|
421 |
+
confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True)
|
422 |
elif alg == 'entropy':
|
423 |
+
confidence, x0 = sample_tokens(mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True)
|
424 |
else:
|
425 |
raise RuntimeError(f"Unknown alg: {alg}")
|
426 |
num_mask_token = mask_index.sum()
|
|
|
435 |
x0_ = torch.zeros_like(x0, device=self.device, dtype=torch.long) + mask_token_id
|
436 |
x0_[transfer_index] = x0[transfer_index].clone()
|
437 |
x[mask_index] = x0_
|
438 |
+
|
439 |
+
# this allows user-defined control of the intermediate steps
|
440 |
+
x, logits = generation_hook_func(x, logits)
|
441 |
|
442 |
if histories is not None:
|
443 |
histories.append(x.clone())
|