Spaces:
Runtime error
Runtime error
Update superposed/llama/superpose.py
Browse files
superposed/llama/superpose.py
CHANGED
@@ -198,7 +198,7 @@ class Superpose(nn.Module):
|
|
198 |
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
|
199 |
next_token = torch.gather(probs_idx, -1, torch.topk(probs_sort, k, dim=-1)[1])
|
200 |
# Set all other probs to 0
|
201 |
-
new_probs_map = torch.zeros(probs.shape).bool()
|
202 |
new_probs_map[torch.repeat_interleave(torch.arange(n_prompts), k), torch.flatten(next_token)] = True
|
203 |
new_probs = torch.where(new_probs_map, probs, 0)
|
204 |
# Renormalize
|
|
|
198 |
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
|
199 |
next_token = torch.gather(probs_idx, -1, torch.topk(probs_sort, k, dim=-1)[1])
|
200 |
# Set all other probs to 0
|
201 |
+
new_probs_map = torch.zeros(probs.shape, device="cuda").bool()
|
202 |
new_probs_map[torch.repeat_interleave(torch.arange(n_prompts), k), torch.flatten(next_token)] = True
|
203 |
new_probs = torch.where(new_probs_map, probs, 0)
|
204 |
# Renormalize
|