Create generate.py
Browse files- generate.py +15 -0
generate.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
def generate(model, model_inputs, generation_config, **kwargs):
    """Greedily extend ``model_inputs["input_ids"]`` one token at a time.

    At every step the model is run on the whole sequence produced so far,
    the highest-scoring token at the last position is picked independently
    for each row, and that token is appended to the sequence.

    Args:
        model: Callable invoked as ``model(input_ids)``; the returned object
            must expose a ``.logits`` tensor indexable as
            ``logits[:, -1, :]`` (assumed ``(batch, seq_len, vocab)`` --
            confirm against the model actually used).
        model_inputs: Mapping holding an ``"input_ids"`` tensor whose second
            dimension is the current sequence length. The mapping itself is
            not modified.
        generation_config: Object with ``max_length`` and ``max_new_tokens``
            attributes. A truthy ``max_length`` is used as the total target
            length; otherwise the target is the prompt length plus
            ``max_new_tokens``.
        **kwargs: Accepted for call-site compatibility and ignored.

    Returns:
        The ``input_ids`` tensor extended along its last dimension until its
        length reaches the target. If the prompt is already at least that
        long, the input tensor is returned as-is.

    Raises:
        TypeError: If ``max_length`` is falsy and ``max_new_tokens`` is
            ``None`` (the length sum cannot be computed).
    """
    # Work on a local reference so the caller's dict is left untouched.
    input_ids = model_inputs["input_ids"]
    cur_length = input_ids.shape[1]
    max_length = (
        generation_config.max_length
        or cur_length + generation_config.max_new_tokens
    )

    while cur_length < max_length:
        logits = model(input_ids).logits
        next_token_logits = logits[:, -1, :]
        # Reduce over the vocabulary axis only: without ``dim`` argmax
        # flattens the tensor and yields a single index for the whole batch.
        next_tokens = torch.argmax(next_token_logits, dim=-1)
        # ``cat`` needs matching rank, so lift (batch,) to (batch, 1).
        input_ids = torch.cat((input_ids, next_tokens[:, None]), dim=-1)
        cur_length += 1

    return input_ids
|
15 |
+
|