joaogante HF Staff commited on
Commit
ca51180
·
verified ·
1 Parent(s): 1539b58

Update generate.py

Browse files
Files changed (1) hide show
  1. generate.py +6 -5
generate.py CHANGED
@@ -1,15 +1,16 @@
1
  import torch
2
 
3
- def generate(model, model_inputs, generation_config, **kwargs):
4
- cur_length = model_inputs["input_ids"].shape[1]
 
5
  max_length = generation_config.max_length or cur_length + generation_config.max_new_tokens
6
 
7
  while cur_length < max_length:
8
- logits = model(model_inputs["input_ids"]).logits
9
  next_token_logits = logits[:, -1, :]
10
  next_tokens = torch.argmax(next_token_logits)
11
- model_inputs["input_ids"] = torch.cat((model_inputs["input_ids"], next_tokens), dim=-1)
12
  cur_length += 1
13
 
14
- return model_inputs["input_ids"]
15
 
 
1
  import torch
2
 
3
+ def generate(model, input_ids, generation_config, **kwargs):
4
+ generation_config = generation_config or model.generation_config # default to the model generation config
5
+ cur_length = input_ids.shape[1]
6
  max_length = generation_config.max_length or cur_length + generation_config.max_new_tokens
7
 
8
  while cur_length < max_length:
9
+ logits = model(input_ids).logits
10
  next_token_logits = logits[:, -1, :]
11
  next_tokens = torch.argmax(next_token_logits)
12
+ model_inputs["input_ids"] = torch.cat((input_ids, next_tokens), dim=-1)
13
  cur_length += 1
14
 
15
+ return input_ids
16