Update generate.py
Browse files- generate.py +6 -5
generate.py
CHANGED
@@ -1,15 +1,16 @@
|
|
1 |
import torch
|
2 |
|
3 |
-
def generate(model,
|
4 |
-
|
|
|
5 |
max_length = generation_config.max_length or cur_length + generation_config.max_new_tokens
|
6 |
|
7 |
while cur_length < max_length:
|
8 |
-
logits = model(
|
9 |
next_token_logits = logits[:, -1, :]
|
10 |
next_tokens = torch.argmax(next_token_logits)
|
11 |
-
model_inputs["input_ids"] = torch.cat((
|
12 |
cur_length += 1
|
13 |
|
14 |
-
return
|
15 |
|
|
|
1 |
import torch
|
2 |
|
3 |
+
def generate(model, input_ids, generation_config, **kwargs):
|
4 |
+
generation_config = generation_config or model.generation_config # default to the model generation config
|
5 |
+
cur_length = input_ids.shape[1]
|
6 |
max_length = generation_config.max_length or cur_length + generation_config.max_new_tokens
|
7 |
|
8 |
while cur_length < max_length:
|
9 |
+
logits = model(input_ids).logits
|
10 |
next_token_logits = logits[:, -1, :]
|
11 |
next_tokens = torch.argmax(next_token_logits)
|
12 |
+
model_inputs["input_ids"] = torch.cat((input_ids, next_tokens), dim=-1)
|
13 |
cur_length += 1
|
14 |
|
15 |
+
return input_ids
|
16 |
|