joaogante HF Staff committed on
Commit
b834bad
·
verified ·
1 Parent(s): 0116cc1

Create generate.py

Browse files
Files changed (1) hide show
  1. generate.py +15 -0
generate.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
def generate(model, model_inputs, generation_config, **kwargs):
    """Greedily extend ``model_inputs["input_ids"]`` one token at a time.

    Each step calls ``model`` on the full sequence generated so far, takes
    the logits at the last position, and appends the highest-scoring token
    id for every row of the batch.

    Args:
        model: Callable invoked as ``model(input_ids)``; the returned object
            must expose a ``.logits`` tensor indexed as
            ``[batch, position, vocab]``.
        model_inputs: Mapping holding an ``"input_ids"`` tensor whose second
            dimension is the sequence length. The mapping is not modified.
        generation_config: Object with ``max_length`` and ``max_new_tokens``
            attributes. A truthy ``max_length`` is used as the total target
            length; otherwise the target is the prompt length plus
            ``max_new_tokens``.
        **kwargs: Accepted for interface compatibility; currently unused.

    Returns:
        The token id tensor consisting of the prompt followed by the
        generated tokens.
    """
    input_ids = model_inputs["input_ids"]
    cur_length = input_ids.shape[1]
    # `or` binds looser than `+`: max_length wins when it is set (truthy),
    # and max_new_tokens is only consulted as a fallback.
    max_length = generation_config.max_length or (
        cur_length + generation_config.max_new_tokens
    )

    while cur_length < max_length:
        logits = model(input_ids).logits
        next_token_logits = logits[:, -1, :]
        # Reduce over the vocabulary axis only, and keep a trailing axis of
        # size 1 so the result can be concatenated onto the 2-D sequence.
        next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        # Build a new tensor rather than writing back into model_inputs, so
        # the caller's dict keeps its original prompt.
        input_ids = torch.cat((input_ids, next_tokens), dim=-1)
        cur_length += 1

    return input_ids
15
+