File size: 1,394 Bytes
c88ac20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from completions import calculate_log_probabilities, load_model, tokenize, split_into_words

# Load the model once, at module import time; both tests below reuse this
# (model, tokenizer, device) triple instead of loading it per test.
model, tokenizer, device = load_model()

def test_text_to_words():
    """A two-line input splits into words, with the newline as its own word."""
    sample = "Hello\nworld!"
    encoded = tokenize(sample, tokenizer, device)
    scored = calculate_log_probabilities(model, tokenizer, encoded)
    actual = [word.text for word in split_into_words(scored, tokenizer)]
    assert actual == ["Hello", "\n", "world", "!"]

def test_multiline():
    """Word splitting over a multi-line C snippet.

    Covers newlines, indentation runs ("    ", "        ") and punctuation
    that ends up as separate words.

    NOTE(review): the snippet indexes ``buf[y * HEIGHT + x]`` rather than
    ``y * WIDTH``; presumably this is an intentional oddity in the fixture
    text -- confirm before "fixing" it, since the expected words depend on it.
    """
    text = """// Context: C code from an image manipulation library.
for (int y = 0; y < HEIGHT; y++) {
    for (int x = 0; x < WIDTH; x++) {
        buf[y * HEIGHT + x] = 0;
    }
}"""
    tokenized = tokenize(text, tokenizer, device)
    token_probs = calculate_log_probabilities(model, tokenizer, tokenized)
    actual_words = [w.text for w in split_into_words(token_probs, tokenizer)]

    # One row per line of `text`, so a mismatch is easy to locate.
    # NOTE(review): words on the first line keep their leading space
    # (" Context") while later lines do not ("(" rather than " (");
    # confirm this matches split_into_words' whitespace handling.
    expected_words = [
        "//", " Context", ":", " C", " code", " from", " an", " image",
        " manipulation", " library", ".\n",
        "for", "(", "int", "y", "=", "0", ";", "y", "<", "HEIGHT", ";",
        "y", "+", "+", ")", "{", "\n",
        "    ", "for", "(", "int", "x", "=", "0", ";", "x", "<", "WIDTH", ";",
        "x", "+", "+", ")", "{", "\n",
        "        ", "buf", "[", "y", "*", "HEIGHT", "+", "x", "]", "=", "0", ";",
        "\n",
        "    ", "}", "\n",
        "}",
    ]
    # No debug prints needed: under pytest, assertion rewriting already
    # shows both lists and the first differing index on failure.
    assert actual_words == expected_words