Rewrite combining into words
- completions.py +30 -32
- completions_test.py +5 -1
completions.py
CHANGED
@@ -1,4 +1,5 @@
 #%%
+from dataclasses import dataclass
 import time
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast, BatchEncoding
@@ -8,6 +9,8 @@ from models import ApiWord, Word
 
 type Tokenizer = PreTrainedTokenizer | PreTrainedTokenizerFast
 
+from combine import combine
+
 def starts_with_space(token: str) -> bool:
     return token.startswith(chr(9601)) or token.startswith(chr(288))
 
@@ -15,38 +18,33 @@ def is_newline(token: str) -> bool:
     return len(token) == 1 and ord(token[0]) == 266
 
 def split_into_words(token_probs: list[tuple[int, float]], tokenizer: Tokenizer) -> list[Word]:
-    ...
-        current_word = []
-        current_log_probs = []
-        current_word_first_token_index = i
-
-    append_word(current_word)
+    @dataclass
+    class Tok:
+        index: int
+        ids: list[int]
+        str: str
+        logprob: float
+
+    def is_beginning_of_word(s: str) -> bool:
+        return (s[0] == " " and s[1:].isalpha()) or s.isalpha()
+
+    def is_continuation_of_word(s: str) -> bool:
+        return s.isalpha()
+
+    def merge_tokens(a: Tok, b: Tok) -> Tok | None:
+        if is_beginning_of_word(a.str) and is_continuation_of_word(b.str):
+            return Tok(b.index, a.ids + b.ids, a.str + b.str, a.logprob * b.logprob)
+        return None
+
+    converted = [Tok(i, [token_id], tokenizer.decode([token_id]), logprob)
+                 for i, (token_id, logprob) in enumerate(token_probs)]
+
+    combined = combine(converted, merge_tokens)
+
+    ts = [t[0] for t in token_probs]
+
+    words = [Word(tok.ids, tok.str, tok.logprob, ts[:tok.index]) for tok in combined]
 
     return words
 
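For reference, the magic numbers in starts_with_space and is_newline are tokenizer marker characters rather than arbitrary constants. A throwaway snippet makes them visible:

# The code points used by starts_with_space / is_newline, made visible:
for cp in (9601, 288, 266):
    print(cp, repr(chr(cp)))
# 9601 '▁'  U+2581: SentencePiece's leading-space marker
# 288  'Ġ'  U+0120: GPT-2 byte-level BPE's space marker (0x20 + 256)
# 266  'Ċ'  U+010A: GPT-2 byte-level BPE's newline marker (0x0A + 256)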
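The rewrite leans on a combine helper imported from a local combine module that is not part of this diff. Judging only from the call site, its contract appears to be: take a list and a pairwise merge function, and fold each element into its left neighbour whenever the function returns a merged value rather than None. A minimal sketch under that assumption (the real module may differ):

# Hypothetical reconstruction of the combine helper; only its call site
# appears in this commit, so the signature is inferred, not confirmed.
from typing import Callable, TypeVar

T = TypeVar("T")

def combine(items: list[T], merge: Callable[[T, T], T | None]) -> list[T]:
    out: list[T] = []
    for item in items:
        if out and (merged := merge(out[-1], item)) is not None:
            # Fold this element into its left neighbour, so a word that
            # spans several tokens collapses left to right into one entry.
            out[-1] = merged
        else:
            out.append(item)
    return out

With merge_tokens from the diff, tokens decoding to " manip" and "ulation" collapse into a single " manipulation" entry, which is what the updated test expects. One thing to watch: merge_tokens multiplies the logprob fields, which is the right combination rule for raw probabilities but would need to be a sum if the values are true log-probabilities.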
completions_test.py
CHANGED
@@ -23,5 +23,9 @@ for (int y = 0; y < HEIGHT; y++) {
 words = split_into_words(token_probs, tokenizer)
 print("---", [w.text for w in words])
 expected_words = ["//", " Context", ":", " C", " code", " from", " an", " image", " manipulation", " library", ".\n",
-                  "for", "(", "int", "
+                  "for", " (", "int", " y", " =", " ", "0", ";", " y", " <", " HEIGHT", ";", " y", "++)", " {\n",
+                  " ", " for", " (", "int", " x", " =", " ", "0", ";", " x", " <", " WIDTH", ";", " x", "++)", " {\n",
+                  " ", " buf", "[y", " *", " HEIGHT", " +", " x", "]", " =", " ", "0", ";\n",
+                  " ", " }\n",
+                  "}"]
 assert [w.text for w in words] == expected_words
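The grouping in expected_words follows directly from the two predicates in the diff: a token joins the previous one only when the pair looks like word-start plus word-continuation, so multi-token identifiers merge while punctuation and whitespace stay separate. A standalone check of that logic on decoded token strings (helper names copied from the diff; no tokenizer required):

# Merge predicate from split_into_words, applied to decoded token strings.
def is_beginning_of_word(s: str) -> bool:
    return (s[0] == " " and s[1:].isalpha()) or s.isalpha()

def is_continuation_of_word(s: str) -> bool:
    return s.isalpha()

pairs = [(" manip", "ulation"), (" HEIGHT", ";"), ("for", " ("), ("buf", "[y")]
for a, b in pairs:
    merges = is_beginning_of_word(a) and is_continuation_of_word(b)
    print(f"{a!r} + {b!r} -> {'merge' if merges else 'keep separate'}")
# ' manip' + 'ulation' merges; the other three pairs stay separate.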