Update generate.py
Browse files- generate.py +1 -0
generate.py
CHANGED
@@ -15,6 +15,7 @@ def generate(model, input_ids, generation_config, left_padding=None, **kwargs):
|
|
15 |
batch_size = input_ids.shape[0]
|
16 |
pad_tensor = torch.full(size=(batch_size, left_padding), fill_value=pad_token).to(input_ids.device)
|
17 |
input_ids = torch.cat((pad_tensor, input_ids), dim=1)
|
|
|
18 |
|
19 |
# Simple greedy decoding loop
|
20 |
while cur_length < max_length:
|
|
|
15 |
batch_size = input_ids.shape[0]
|
16 |
pad_tensor = torch.full(size=(batch_size, left_padding), fill_value=pad_token).to(input_ids.device)
|
17 |
input_ids = torch.cat((pad_tensor, input_ids), dim=1)
|
18 |
+
cur_length = input_ids.shape[1]
|
19 |
|
20 |
# Simple greedy decoding loop
|
21 |
while cur_length < max_length:
|