Update generate.py
Browse files- generate.py +1 -1
generate.py
CHANGED
@@ -8,7 +8,7 @@ def generate(model, input_ids, generation_config, **kwargs):
|
|
8 |
while cur_length < max_length:
|
9 |
logits = model(input_ids).logits
|
10 |
next_token_logits = logits[:, -1, :]
|
11 |
-
next_tokens = torch.argmax(next_token_logits)
|
12 |
model_inputs["input_ids"] = torch.cat((input_ids, next_tokens), dim=-1)
|
13 |
cur_length += 1
|
14 |
|
|
|
8 |
while cur_length < max_length:
|
9 |
logits = model(input_ids).logits
|
10 |
next_token_logits = logits[:, -1, :]
|
11 |
+
next_tokens = torch.argmax(next_token_logits, dim=-1)
|
12 |
model_inputs["input_ids"] = torch.cat((input_ids, next_tokens), dim=-1)
|
13 |
cur_length += 1
|
14 |
|