Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import torch
def top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    """
    Perform top-p (nucleus) sampling on a probability distribution.

    Args:
        probs (torch.Tensor): probability distribution tensor; candidates
            are ranked and sampled along the last dimension. The tensor is
            not modified.
        p (float): probability threshold for top-p sampling. With ``p = 0``
            only the single most likely token remains eligible.

    Returns:
        torch.Tensor: sampled token indices (positions along the last
            dimension of ``probs``), with size 1 along the last dimension.

    Note:
        Top-p sampling selects the smallest set of tokens whose cumulative
        probability mass exceeds the threshold p. The distribution is
        renormalized based on the selected tokens.
    """
    # Rank candidates from most to least likely, remembering where each one
    # came from so the sampled rank can be mapped back to a token index.
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)

    # `probs_sum - probs_sort` is the mass of all strictly more likely tokens.
    # A token is dropped once that preceding mass already exceeds p, so the
    # token that crosses the threshold is itself kept.
    mask = probs_sum - probs_sort > p
    probs_sort = probs_sort.masked_fill(mask, 0.0)

    # Renormalize the surviving mass so it forms a proper distribution again.
    probs_sort = probs_sort / probs_sort.sum(dim=-1, keepdim=True)

    # Sample a rank in sorted order, then translate it to the original index.
    next_token = torch.multinomial(probs_sort, num_samples=1)
    return torch.gather(probs_idx, -1, next_token)