joaogante HF Staff commited on
Commit
a3f8f36
·
verified ·
1 Parent(s): 32fe250

Update generate.py

Browse files
Files changed (1) hide show
  1. generate.py +1 -1
generate.py CHANGED
@@ -12,7 +12,7 @@ def generate(model, input_ids, generation_config, left_padding=None, **kwargs):
12
  raise ValueError("pad_token is not defined")
13
  batch_size = input_ids.shape[0]
14
  pad_tensor = torch.full(size=(batch_size, left_padding), fill_value=pad_token).to(input_ids.device)
15
- input_ids = torch.cat((input_ids, pad_tensor), dim=1)
16
 
17
  # Simple greedy decoding loop
18
  cur_length = input_ids.shape[1]
 
12
  raise ValueError("pad_token is not defined")
13
  batch_size = input_ids.shape[0]
14
  pad_tensor = torch.full(size=(batch_size, left_padding), fill_value=pad_token).to(input_ids.device)
15
+ input_ids = torch.cat((pad_tensor, input_ids), dim=1)
16
 
17
  # Simple greedy decoding loop
18
  cur_length = input_ids.shape[1]