File size: 970 Bytes
6735ae4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
from dataclasses import dataclass
@dataclass
class Word:
    """A word assembled from one or more consecutive subword tokens.

    Instances are produced by ``split_into_words`` in this module.
    """

    # Token strings making up the word, in order. Annotated ``list[str]``
    # (previously ``list[int]``) because ``split_into_words`` stores the token
    # strings themselves here -- it calls ``startswith``/``isalpha`` on them
    # and joins them with ``"".join`` -- not integer token ids.
    tokens: list[str]
    # Plain concatenation of ``tokens``; any leading U+2581 marker is kept.
    text: str
    # Sum of the per-token log-probabilities of ``tokens``.
    logprob: float
    # Index of the word's first token in the sequence it was split from.
    first_token_index: int
def split_into_words(tokens, log_probs) -> list[Word]:
    """Merge subword tokens into words, summing their log-probabilities.

    Tokens are walked in order alongside ``log_probs``. The two are paired
    with ``zip``, so iteration stops at the shorter of the two. A token is
    glued onto the word being built only when it does not start with
    ``chr(9601)`` (U+2581, the word-boundary marker used by SentencePiece-style
    tokenizers) and ``str.isalpha()`` is true for it. Any other token closes
    the current word and opens a new one, whether it is marker-prefixed,
    punctuation, digits or mixed.

    Args:
        tokens: token strings (must support ``startswith``/``isalpha``).
        log_probs: per-token log-probabilities, aligned with ``tokens``.

    Returns:
        One ``Word`` per group, in input order. ``Word.text`` is the plain
        concatenation of the group's tokens, so a leading U+2581 is kept;
        ``Word.first_token_index`` is the group's starting position.
        An empty input yields an empty list.
    """
    boundary = chr(9601)
    result = []
    pieces = []
    scores = []
    start = 0

    def flush():
        # Emit the word in progress, if any. Reads the *current* bindings of
        # pieces/scores/start from the enclosing scope.
        if pieces:
            result.append(Word(pieces, "".join(pieces), sum(scores), start))

    for position, (piece, score) in enumerate(zip(tokens, log_probs)):
        if not piece.startswith(boundary) and piece.isalpha():
            # Unmarked alphabetic piece: continuation of the current word.
            pieces.append(piece)
            scores.append(score)
            continue
        # Anything else starts a fresh word; rebind (not clear) the lists so
        # the Word just emitted keeps its own token list.
        flush()
        pieces = [piece]
        scores = [score]
        start = position

    flush()
    return result
|