copy
Browse files- README.md +16 -0
- generate.py +29 -0
README.md
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: apache-2.0
|
3 |
+
---
|
4 |
+
|
5 |
+
## Description
|
6 |
+
Test repo to experiment with calling `generate` from the hub. It is a simplified implementation of greedy decoding.
|
7 |
+
|
8 |
+
## Additional Arguments
|
9 |
+
`left_padding` (`int`, *optional*): number of padding tokens to add before the provided input
|
10 |
+
|
11 |
+
## Output Type changes
|
12 |
+
(none)
|
13 |
+
|
14 |
+
## Requirements
|
15 |
+
(none)
|
16 |
+
|
generate.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
def generate(model, input_ids, generation_config, left_padding=None, **kwargs):
|
4 |
+
generation_config = generation_config or model.generation_config # default to the model generation config
|
5 |
+
cur_length = input_ids.shape[1]
|
6 |
+
max_length = generation_config.max_length or cur_length + generation_config.max_new_tokens
|
7 |
+
|
8 |
+
# Example of custom argument: add left padding
|
9 |
+
if left_padding is not None:
|
10 |
+
if not isinstance(left_padding, int) or left_padding < 0:
|
11 |
+
raise ValueError(f"left_padding must be an integer larger than 0, but is {left_padding}")
|
12 |
+
pad_token = kwargs.get("pad_token") or generation_config.pad_token_id or model.config.pad_token_id
|
13 |
+
if pad_token is None:
|
14 |
+
raise ValueError("pad_token is not defined")
|
15 |
+
batch_size = input_ids.shape[0]
|
16 |
+
pad_tensor = torch.full(size=(batch_size, left_padding), fill_value=pad_token).to(input_ids.device)
|
17 |
+
input_ids = torch.cat((pad_tensor, input_ids), dim=1)
|
18 |
+
cur_length = input_ids.shape[1]
|
19 |
+
|
20 |
+
# Simple greedy decoding loop
|
21 |
+
while cur_length < max_length:
|
22 |
+
logits = model(input_ids).logits
|
23 |
+
next_token_logits = logits[:, -1, :]
|
24 |
+
next_tokens = torch.argmax(next_token_logits, dim=-1)
|
25 |
+
input_ids = torch.cat((input_ids, next_tokens[:, None]), dim=-1)
|
26 |
+
cur_length += 1
|
27 |
+
|
28 |
+
return input_ids
|
29 |
+
|