joaogante HF Staff commited on
Commit
5ab5582
·
verified ·
1 Parent(s): 0373cfb

Update generate.py

Browse files
Files changed (1) hide show
  1. generate.py +1 -0
generate.py CHANGED
@@ -15,6 +15,7 @@ def generate(model, input_ids, generation_config, left_padding=None, **kwargs):
15
  batch_size = input_ids.shape[0]
16
  pad_tensor = torch.full(size=(batch_size, left_padding), fill_value=pad_token).to(input_ids.device)
17
  input_ids = torch.cat((pad_tensor, input_ids), dim=1)
 
18
 
19
  # Simple greedy decoding loop
20
  while cur_length < max_length:
 
15
  batch_size = input_ids.shape[0]
16
  pad_tensor = torch.full(size=(batch_size, left_padding), fill_value=pad_token).to(input_ids.device)
17
  input_ids = torch.cat((pad_tensor, input_ids), dim=1)
18
+ cur_length = input_ids.shape[1]
19
 
20
  # Simple greedy decoding loop
21
  while cur_length < max_length: