mebubo committed
Commit 9652314
Parent: c88ac20

Rewrite combining into words

Files changed (2):
  1. completions.py +30 -32
  2. completions_test.py +5 -1
completions.py CHANGED
@@ -1,4 +1,5 @@
 #%%
+from dataclasses import dataclass
 import time
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast, BatchEncoding
@@ -8,6 +9,8 @@ from models import ApiWord, Word
 
 type Tokenizer = PreTrainedTokenizer | PreTrainedTokenizerFast
 
+from combine import combine
+
 def starts_with_space(token: str) -> bool:
     return token.startswith(chr(9601)) or token.startswith(chr(288))
 
@@ -15,38 +18,33 @@ def is_newline(token: str) -> bool:
     return len(token) == 1 and ord(token[0]) == 266
 
 def split_into_words(token_probs: list[tuple[int, float]], tokenizer: Tokenizer) -> list[Word]:
-    words: list[Word] = []
-    current_word: list[int] = []
-    current_log_probs: list[float] = []
-    current_word_first_token_index: int = 0
-    all_tokens: list[int] = [token_id for token_id, _ in token_probs]
-
-    def append_word(word):
-        if word:
-            words.append(Word(word,
-                              tokenizer.decode(word),
-                              sum(current_log_probs),
-                              all_tokens[:current_word_first_token_index]))
-
-    for i, (token_id, logprob) in enumerate(token_probs):
-        token: str = tokenizer.convert_ids_to_tokens([token_id])[0]
-        token_str = tokenizer.decode([token_id])
-        print(f"-- {token_id=} {token=} {token_str=} {token_str.isalpha()=} {token_str.isspace()=}")
-        if (not starts_with_space(token) and token_str.isalpha()):
-            current_word.append(token_id)
-            current_log_probs.append(logprob)
-        else:
-            append_word(current_word)
-            current_word = [token_id]
-            current_log_probs = [logprob]
-            current_word_first_token_index = i
-            if is_newline(token):
-                append_word(current_word)
-                current_word = []
-                current_log_probs = []
-                current_word_first_token_index = i
-
-    append_word(current_word)
+
+    @dataclass
+    class Tok:
+        index: int
+        ids: list[int]
+        str: str
+        logprob: float
+
+    def is_beginning_of_word(s: str) -> bool:
+        return (s[0] == " " and s[1:].isalpha()) or s.isalpha()
+
+    def is_continuation_of_word(s: str) -> bool:
+        return s.isalpha()
+
+    def merge_tokens(a: Tok, b: Tok) -> Tok | None:
+        if is_beginning_of_word(a.str) and is_continuation_of_word(b.str):
+            return Tok(b.index, a.ids + b.ids, a.str + b.str, a.logprob * b.logprob)
+        return None
+
+    converted = [Tok(i, [token_id], tokenizer.decode([token_id]), logprob)
+                 for i, (token_id, logprob) in enumerate(token_probs)]
+
+    combined = combine(converted, merge_tokens)
+
+    ts = [t[0] for t in token_probs]
+
+    words = [Word(tok.ids, tok.str, tok.logprob, ts[:tok.index]) for tok in combined]
 
     return words
 
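The rewritten split_into_words imports combine from a combine module that is not part of this diff. A minimal sketch of what that helper might look like, assuming it folds the list left to right with the supplied pairwise merge function and starts a new element whenever the merge returns None (the name and signature below are assumptions, not taken from the repository):

from collections.abc import Callable

def combine[T](items: list[T], merge: Callable[[T, T], T | None]) -> list[T]:
    # Greedily fold the list: try to merge each element into the last
    # accumulated one; keep it as a separate element when the merge declines.
    result: list[T] = []
    for item in items:
        if result and (merged := merge(result[-1], item)) is not None:
            result[-1] = merged
        else:
            result.append(item)
    return result

Under that assumption, combine with merge_tokens folds a word-initial token together with the purely alphabetic tokens that follow it into a single Tok, which split_into_words then turns into one Word.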
 
completions_test.py CHANGED
@@ -23,5 +23,9 @@ for (int y = 0; y < HEIGHT; y++) {
 words = split_into_words(token_probs, tokenizer)
 print("---", [w.text for w in words])
 expected_words = ["//", " Context", ":", " C", " code", " from", " an", " image", " manipulation", " library", ".\n",
-    "for", "(", "int", "y", "=", "0", ";", "y", "<", "HEIGHT", ";", "y", "+", "+", ")", "{", "\n", " ", "for", "(", "int", "x", "=", "0", ";", "x", "<", "WIDTH", ";", "x", "+", "+", ")", "{", "\n", " ", "buf", "[", "y", "*", "HEIGHT", "+", "x", "]", "=", "0", ";", "\n", " ", "}", "\n", "}"]
+    "for", " (", "int", " y", " =", " ", "0", ";", " y", " <", " HEIGHT", ";", " y", "++)", " {\n",
+    " ", " for", " (", "int", " x", " =", " ", "0", ";", " x", " <", " WIDTH", ";", " x", "++)", " {\n",
+    " ", " buf", "[y", " *", " HEIGHT", " +", " x", "]", " =", " ", "0", ";\n",
+    " ", " }\n",
+    "}"]
 assert [w.text for w in words] == expected_words
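The updated expectations follow the merge rule in the new split_into_words: a token whose decoded text is a space followed by letters, or is purely alphabetic, absorbs the purely alphabetic tokens after it, so identifiers like " HEIGHT" come out as single words, while non-alphabetic pieces such as ";" or " {\n" are left exactly as the tokenizer produced them. A small illustrative check of that rule, assuming Tok and merge_tokens were lifted to module scope so they can be called directly; the token ids, decoded strings, and log-probabilities here are made up:

a = Tok(index=0, ids=[101], str=" HEI", logprob=-0.5)  # hypothetical word-initial piece
b = Tok(index=1, ids=[102], str="GHT", logprob=-0.1)   # hypothetical alphabetic continuation
c = Tok(index=2, ids=[103], str=";", logprob=-0.2)     # punctuation, never merged

assert merge_tokens(a, b) is not None  # " HEI" + "GHT" fold into a single " HEIGHT" entry
assert merge_tokens(b, c) is None      # ";" is not alphabetic, so it stays a separate entry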