Update generate.py
Browse files- generate.py +1 -1
generate.py
CHANGED
@@ -12,7 +12,7 @@ def generate(model, input_ids, generation_config, left_padding=None, **kwargs):
|
|
12 |
raise ValueError("pad_token is not defined")
|
13 |
batch_size = input_ids.shape[0]
|
14 |
pad_tensor = torch.full(size=(batch_size, left_padding), fill_value=pad_token).to(input_ids.device)
|
15 |
-
input_ids = torch.cat((
|
16 |
|
17 |
# Simple greedy decoding loop
|
18 |
cur_length = input_ids.shape[1]
|
|
|
12 |
raise ValueError("pad_token is not defined")
|
13 |
batch_size = input_ids.shape[0]
|
14 |
pad_tensor = torch.full(size=(batch_size, left_padding), fill_value=pad_token).to(input_ids.device)
|
15 |
+
input_ids = torch.cat((pad_tensor, input_ids), dim=1)
|
16 |
|
17 |
# Simple greedy decoding loop
|
18 |
cur_length = input_ids.shape[1]
|