wondervictor commited on
Commit
624dd8d
·
verified ·
1 Parent(s): b31a0bf

Update autoregressive/models/generate.py

Browse files
Files changed (1) hide show
  1. autoregressive/models/generate.py +4 -4
autoregressive/models/generate.py CHANGED
@@ -68,13 +68,13 @@ def sample(logits, temperature: float=1.0, top_k: int=2000, top_p: float=1.0, sa
68
  # mask = (probs == values).float()
69
  # probs = probs * (1 - mask)
70
  if sample_logits:
71
- # print(f'inf:{torch.any(torch.isinf(probs))}')
72
- # print(f'nan: {torch.any(torch.isnan(probs))}')
73
  # add to fix 'nan' and 'inf'
74
- probs = torch.where(torch.isnan(probs), torch.tensor(0.0), probs)
75
  probs = torch.clamp(probs, min=0, max=None)
76
  probs = probs / probs.sum()
77
-
 
78
 
79
  idx = torch.multinomial(probs, num_samples=1)
80
  else:
 
68
  # mask = (probs == values).float()
69
  # probs = probs * (1 - mask)
70
  if sample_logits:
71
+
 
72
  # add to fix 'nan' and 'inf'
73
+ probs = torch.where(torch.isinf(probs), torch.tensor(0.0), probs)
74
  probs = torch.clamp(probs, min=0, max=None)
75
  probs = probs / probs.sum()
76
+ print(f'inf:{torch.any(torch.isinf(probs))}')
77
+ print(f'nan: {torch.any(torch.isnan(probs))}')
78
 
79
  idx = torch.multinomial(probs, num_samples=1)
80
  else: