joaogante HF Staff commited on
Commit
40081d1
·
1 Parent(s): 594965c
Files changed (2) hide show
  1. README.md +16 -0
  2. generate.py +29 -0
README.md ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ ---
4
+
5
+ ## Description
6
+ Test repo to experiment with calling `generate` from the hub. It is a simplified implementation of greedy decoding.
7
+
8
+ ## Additional Arguments
9
+ `left_padding` (`int`, *optional*): number of padding tokens to add before the provided input
10
+
11
+ ## Output Type changes
12
+ (none)
13
+
14
+ ## Requirements
15
+ (none)
16
+
generate.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ def generate(model, input_ids, generation_config, left_padding=None, **kwargs):
4
+ generation_config = generation_config or model.generation_config # default to the model generation config
5
+ cur_length = input_ids.shape[1]
6
+ max_length = generation_config.max_length or cur_length + generation_config.max_new_tokens
7
+
8
+ # Example of custom argument: add left padding
9
+ if left_padding is not None:
10
+ if not isinstance(left_padding, int) or left_padding < 0:
11
+ raise ValueError(f"left_padding must be an integer larger than 0, but is {left_padding}")
12
+ pad_token = kwargs.get("pad_token") or generation_config.pad_token_id or model.config.pad_token_id
13
+ if pad_token is None:
14
+ raise ValueError("pad_token is not defined")
15
+ batch_size = input_ids.shape[0]
16
+ pad_tensor = torch.full(size=(batch_size, left_padding), fill_value=pad_token).to(input_ids.device)
17
+ input_ids = torch.cat((pad_tensor, input_ids), dim=1)
18
+ cur_length = input_ids.shape[1]
19
+
20
+ # Simple greedy decoding loop
21
+ while cur_length < max_length:
22
+ logits = model(input_ids).logits
23
+ next_token_logits = logits[:, -1, :]
24
+ next_tokens = torch.argmax(next_token_logits, dim=-1)
25
+ input_ids = torch.cat((input_ids, next_tokens[:, None]), dim=-1)
26
+ cur_length += 1
27
+
28
+ return input_ids
29
+