Update generate.py
Browse files- generate.py +1 -1
generate.py
CHANGED
@@ -11,7 +11,7 @@ def generate(model, input_ids, generation_config, left_padding=None, **kwargs):
|
|
11 |
if pad_token is None:
|
12 |
raise ValueError("pad_token is not defined")
|
13 |
batch_size = input_ids.shape[0]
|
14 |
-
pad_tensor = torch.full(size=(batch_size, left_padding), fill_value=pad_token)
|
15 |
input_ids = torch.cat((input_ids, pad_tensor), dim=1)
|
16 |
|
17 |
# Simple greedy decoding loop
|
|
|
11 |
if pad_token is None:
|
12 |
raise ValueError("pad_token is not defined")
|
13 |
batch_size = input_ids.shape[0]
|
14 |
+
pad_tensor = torch.full(size=(batch_size, left_padding), fill_value=pad_token).to(input_ids.device)
|
15 |
input_ids = torch.cat((input_ids, pad_tensor), dim=1)
|
16 |
|
17 |
# Simple greedy decoding loop
|