joaogante HF Staff commited on
Commit
e108c35
·
verified ·
1 Parent(s): ca51180

Update generate.py

Browse files
Files changed (1) hide show
  1. generate.py +1 -1
generate.py CHANGED
@@ -8,7 +8,7 @@ def generate(model, input_ids, generation_config, **kwargs):
8
  while cur_length < max_length:
9
  logits = model(input_ids).logits
10
  next_token_logits = logits[:, -1, :]
11
- next_tokens = torch.argmax(next_token_logits)
12
  model_inputs["input_ids"] = torch.cat((input_ids, next_tokens), dim=-1)
13
  cur_length += 1
14
 
 
8
  while cur_length < max_length:
9
  logits = model(input_ids).logits
10
  next_token_logits = logits[:, -1, :]
11
+ next_tokens = torch.argmax(next_token_logits, dim=-1)
12
  model_inputs["input_ids"] = torch.cat((input_ids, next_tokens), dim=-1)
13
  cur_length += 1
14