joaogante HF Staff commited on
Commit
3e8ad6e
·
verified ·
1 Parent(s): f9c40f2

Update generate.py

Browse files
Files changed (1) hide show
  1. generate.py +1 -1
generate.py CHANGED
@@ -9,7 +9,7 @@ def generate(model, input_ids, generation_config, **kwargs):
9
  logits = model(input_ids).logits
10
  next_token_logits = logits[:, -1, :]
11
  next_tokens = torch.argmax(next_token_logits, dim=-1)
12
- model_inputs["input_ids"] = torch.cat((input_ids, next_tokens[:, None]), dim=-1)
13
  cur_length += 1
14
 
15
  return input_ids
 
9
  logits = model(input_ids).logits
10
  next_token_logits = logits[:, -1, :]
11
  next_tokens = torch.argmax(next_token_logits, dim=-1)
12
+ input_ids = torch.cat((input_ids, next_tokens[:, None]), dim=-1)
13
  cur_length += 1
14
 
15
  return input_ids