Spaces:
Runtime error
Runtime error
File size: 1,083 Bytes
e202b16 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import torch
def top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    """
    Perform top-p (nucleus) sampling on a probability distribution.

    Tokens are ranked by probability. A token is kept while the total mass of
    the tokens ranked strictly above it is <= p; every other token is zeroed.
    The kept probabilities are renormalized to sum to 1 and one token is
    sampled from them. The highest-probability token is always kept for any
    p >= 0, because no tokens rank above it.

    Args:
        probs (torch.Tensor): Non-negative probabilities over the last
            dimension. Any leading dimensions are treated as batch dimensions
            (e.g. ``(vocab,)``, ``(batch, vocab)``, ``(batch, seq, vocab)``).
        p (float): Probability-mass threshold for top-p sampling.

    Returns:
        torch.Tensor: Sampled indices into the last dimension of ``probs``,
        with shape ``probs.shape[:-1] + (1,)``.

    Raises:
        RuntimeError: Raised by ``torch.multinomial`` when a distribution has
            no positive mass left after filtering (e.g. ``p < 0``, or an
            all-zero row in ``probs``).
    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Subtracting each token's own probability gives the exclusive cumulative
    # mass, i.e. the mass of all strictly higher-ranked tokens.
    mask = probs_sum - probs_sort > p
    probs_sort = probs_sort.masked_fill(mask, 0.0)
    # Renormalize the surviving nucleus so it is a proper distribution.
    probs_sort = probs_sort / probs_sort.sum(dim=-1, keepdim=True)
    # torch.multinomial only accepts 1-D or 2-D input, so fold all leading
    # dimensions into one batch dimension, sample, then restore the shape.
    vocab_size = probs_sort.shape[-1]
    flat_sample = torch.multinomial(
        probs_sort.reshape(-1, vocab_size), num_samples=1
    )
    next_token = flat_sample.reshape(*probs_sort.shape[:-1], 1)
    # Map positions in the sorted order back to original token indices.
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token
|