Update generate.py
Browse files- generate.py +1 -1
generate.py
CHANGED
@@ -9,7 +9,7 @@ def generate(model, input_ids, generation_config, left_padding=None, **kwargs):
|
|
9 |
raise ValueError(f"left_padding must be an integer larger than 0, but is {left_padding}")
|
10 |
pad_token = kwargs.get("pad_token") or generation_config.pad_token_id or model.config.pad_token_id
|
11 |
if pad_token is None:
|
12 |
-
raise ValueError("pad_token is not defined)
|
13 |
batch_size = input_ids.shape[0]
|
14 |
pad_tensor = torch.full(size=(batch_size, left_padding), fill_value=pad_token)
|
15 |
input_ids = torch.cat((input_ids, pad_tensor), dim=1)
|
|
|
9 |
raise ValueError(f"left_padding must be an integer larger than 0, but is {left_padding}")
|
10 |
pad_token = kwargs.get("pad_token") or generation_config.pad_token_id or model.config.pad_token_id
|
11 |
if pad_token is None:
|
12 |
+
raise ValueError("pad_token is not defined")
|
13 |
batch_size = input_ids.shape[0]
|
14 |
pad_tensor = torch.full(size=(batch_size, left_padding), fill_value=pad_token)
|
15 |
input_ids = torch.cat((input_ids, pad_tensor), dim=1)
|