File size: 1,417 Bytes
b834bad
 
33ad63f
ca51180
0373cfb
 
33ad63f
 
8baa630
f3a0f61
8baa630
33ad63f
411392e
8baa630
0bae8c6
8baa630
32fe250
a3f8f36
5ab5582
8baa630
 
b834bad
ca51180
b834bad
e108c35
3e8ad6e
b834bad
 
33ad63f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch

def generate(model, input_ids, generation_config, left_padding=None, **kwargs):
    """Greedily decode a continuation of ``input_ids`` with ``model``.

    The prompt is optionally left-padded with ``left_padding`` pad tokens, then
    the argmax token of the last position's logits is appended one step at a
    time until the length budget is reached.

    Args:
        model: Callable whose output has a ``.logits`` tensor indexed as
            ``[batch, position, vocab]``. ``model.generation_config`` is used
            when ``generation_config`` is None, and ``model.config.pad_token_id``
            is the last fallback for the pad token.
        input_ids: Token-id tensor of shape ``(batch, seq_len)``.
        generation_config: Config object read for ``max_new_tokens``,
            ``max_length`` and ``pad_token_id``. ``max_new_tokens`` (tokens to
            generate after the padded prompt) takes precedence over
            ``max_length`` (total sequence length, padding included).
        left_padding: Optional non-negative number of pad tokens to prepend.
        **kwargs: An optional ``pad_token`` entry overrides the pad token id.

    Returns:
        Tensor of shape ``(batch, total_len)``: the (padded) prompt followed by
        the generated token ids.

    Raises:
        ValueError: if ``left_padding`` is not a non-negative integer, padding
            is requested but no pad token id can be resolved, or neither
            ``max_new_tokens`` nor ``max_length`` is set.
    """
    if generation_config is None:
        generation_config = model.generation_config  # default to the model generation config

    # Example of custom argument: add `left_padding` (integer) pad tokens before the prompt
    if left_padding is not None:
        if not isinstance(left_padding, int) or left_padding < 0:
            raise ValueError(f"left_padding must be a non-negative integer, but is {left_padding}")

        # Resolve with explicit None checks: token id 0 is a valid pad token and
        # must not fall through the way it would with an `or` chain.
        pad_token = kwargs.pop("pad_token", None)
        if pad_token is None:
            pad_token = getattr(generation_config, "pad_token_id", None)
        if pad_token is None:
            pad_token = getattr(model.config, "pad_token_id", None)
        if pad_token is None:
            raise ValueError("pad_token is not defined")

        pad_tensor = torch.full(
            (input_ids.shape[0], left_padding),
            fill_value=pad_token,
            dtype=input_ids.dtype,
            device=input_ids.device,
        )
        input_ids = torch.cat((pad_tensor, input_ids), dim=1)

    cur_length = input_ids.shape[1]
    # The budget is computed after padding so that max_new_tokens counts only
    # generated tokens; per GenerationConfig semantics it overrides max_length.
    max_new_tokens = getattr(generation_config, "max_new_tokens", None)
    if max_new_tokens is not None:
        max_length = cur_length + max_new_tokens
    else:
        max_length = getattr(generation_config, "max_length", None)
        if max_length is None:
            raise ValueError("generation_config must set max_new_tokens or max_length")

    # Simple greedy decoding loop. Inference only, so skip autograd tracking.
    # NOTE(review): the full sequence is re-run each step (no KV cache) and no
    # attention mask is passed, so pad tokens are fed to the model as ordinary
    # input; whether that matters depends on the model.
    with torch.no_grad():
        while cur_length < max_length:
            logits = model(input_ids).logits
            next_tokens = torch.argmax(logits[:, -1, :], dim=-1)
            input_ids = torch.cat((input_ids, next_tokens[:, None]), dim=-1)
            cur_length += 1

    return input_ids