joaogante HF Staff commited on
Commit
32fe250
·
verified ·
1 Parent(s): f3a0f61

Update generate.py

Browse files
Files changed (1) hide show
  1. generate.py +1 -1
generate.py CHANGED
@@ -11,7 +11,7 @@ def generate(model, input_ids, generation_config, left_padding=None, **kwargs):
11
  if pad_token is None:
12
  raise ValueError("pad_token is not defined")
13
  batch_size = input_ids.shape[0]
14
- pad_tensor = torch.full(size=(batch_size, left_padding), fill_value=pad_token)
15
  input_ids = torch.cat((input_ids, pad_tensor), dim=1)
16
 
17
  # Simple greedy decoding loop
 
11
  if pad_token is None:
12
  raise ValueError("pad_token is not defined")
13
  batch_size = input_ids.shape[0]
14
+ pad_tensor = torch.full(size=(batch_size, left_padding), fill_value=pad_token).to(input_ids.device)
15
  input_ids = torch.cat((input_ids, pad_tensor), dim=1)
16
 
17
  # Simple greedy decoding loop