# File size: 970 Bytes
# 6735ae4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from dataclasses import dataclass

@dataclass
class Word:
    """A run of consecutive tokens grouped into a single word."""

    # Token strings that make up the word, in their original order.
    # (They are joined with "".join to build `text`, so they are str, not int.)
    tokens: list[str]
    # Concatenation of `tokens`, including any word-start marker they carry.
    text: str
    # Sum of the per-token log-probabilities of `tokens`.
    logprob: float
    # Position of the word's first token in the token sequence it came from.
    first_token_index: int

def _build_word(pieces, piece_log_probs, first_token_index) -> Word:
    """Assemble a Word from the tokens and log-probabilities collected for it."""
    return Word(pieces, "".join(pieces), sum(piece_log_probs), first_token_index)


def split_into_words(
    tokens: list[str],
    log_probs: list[float],
    *,
    word_start_marker: str = "\u2581",
) -> list[Word]:
    """Group consecutive sub-word tokens into words.

    A token continues the current word when it does not start with
    ``word_start_marker`` and consists only of alphabetic characters
    (``str.isalpha``). Any other token (one carrying the marker, one
    containing digits or punctuation, or an empty string) closes the
    current word and becomes the first token of a new one.

    Args:
        tokens: Token strings, in sequence order.
        log_probs: Log-probabilities paired positionally with ``tokens``.
            Pairing uses ``zip``, so it stops at the shorter of the two
            inputs and any surplus items are ignored.
        word_start_marker: Prefix that flags the first token of a word.
            Defaults to U+2581 (``chr(9601)``), presumably the
            SentencePiece whitespace symbol; pass a different prefix for
            tokenizers that mark word starts differently.

    Returns:
        The words in order of appearance. Each word holds its tokens,
        their concatenation (marker included), the sum of their
        log-probabilities and the index of its first token in ``tokens``.
        An empty input yields an empty list.
    """
    words = []
    pieces = []
    piece_log_probs = []
    first_token_index = 0

    for index, (token, logprob) in enumerate(zip(tokens, log_probs)):
        if not token.startswith(word_start_marker) and token.isalpha():
            # Purely alphabetic, unmarked token: extend the word in progress.
            pieces.append(token)
            piece_log_probs.append(logprob)
            continue

        # Boundary token: emit the word collected so far, then start a new
        # one with fresh lists so the emitted Word is never mutated later.
        if pieces:
            words.append(_build_word(pieces, piece_log_probs, first_token_index))
        pieces = [token]
        piece_log_probs = [logprob]
        first_token_index = index

    # The last word has no following boundary token to trigger its emission.
    if pieces:
        words.append(_build_word(pieces, piece_log_probs, first_token_index))

    return words