mebubo committed
Commit 6735ae4 · 1 Parent(s): 91f2f92

refactor: separate text processing logic into a new module for better organization

Files changed (2)
  1. app.py +1 -29
  2. text_processing.py +30 -0
app.py CHANGED
@@ -1,5 +1,5 @@
 #%%
-from dataclasses import dataclass
+from text_processing import split_into_words, Word
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from pprint import pprint
@@ -73,34 +73,6 @@ for word, avg_logprob in words:
 
 # %%
 
-@dataclass
-class Word:
-    tokens: list[int]
-    text: str
-    logprob: float
-    first_token_index: int
-
-def split_into_words(tokens, log_probs) -> list[Word]:
-    words = []
-    current_word = []
-    current_log_probs = []
-    current_word_first_token_index = 0
-
-    for i, (token, logprob) in enumerate(zip(tokens, log_probs)):
-        if not token.startswith(chr(9601)) and token.isalpha():
-            current_word.append(token)
-            current_log_probs.append(logprob)
-        else:
-            if current_word:
-                words.append(Word(current_word, "".join(current_word), sum(current_log_probs), current_word_first_token_index))
-            current_word = [token]
-            current_log_probs = [logprob]
-            current_word_first_token_index = i
-
-    if current_word:
-        words.append(Word(current_word, "".join(current_word), sum(current_log_probs), current_word_first_token_index))
-
-    return words
 
 words = split_into_words(tokens[1:], token_log_probs)
 
text_processing.py ADDED
@@ -0,0 +1,30 @@
+from dataclasses import dataclass
+
+@dataclass
+class Word:
+    tokens: list[int]
+    text: str
+    logprob: float
+    first_token_index: int
+
+def split_into_words(tokens, log_probs) -> list[Word]:
+    words = []
+    current_word = []
+    current_log_probs = []
+    current_word_first_token_index = 0
+
+    for i, (token, logprob) in enumerate(zip(tokens, log_probs)):
+        if not token.startswith(chr(9601)) and token.isalpha():
+            current_word.append(token)
+            current_log_probs.append(logprob)
+        else:
+            if current_word:
+                words.append(Word(current_word, "".join(current_word), sum(current_log_probs), current_word_first_token_index))
+            current_word = [token]
+            current_log_probs = [logprob]
+            current_word_first_token_index = i
+
+    if current_word:
+        words.append(Word(current_word, "".join(current_word), sum(current_log_probs), current_word_first_token_index))
+
+    return words
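
For reference, a minimal usage sketch of the new module. The token strings and log-probabilities below are invented stand-ins for tokenizer/model output; only split_into_words and the Word dataclass come from this commit.

# Hypothetical inputs; "▁" (chr(9601)) is the SentencePiece marker for a word-initial token.
from text_processing import split_into_words

tokens = ["▁The", "▁quick", "est", "▁fox", "."]
log_probs = [-0.5, -1.2, -0.8, -2.0, -0.1]

for word in split_into_words(tokens, log_probs):
    # Each Word carries its tokens, the joined text (marker included),
    # the summed logprob, and the index of its first token in the input.
    print(word.text, word.logprob, word.first_token_index)

# Grouping produced: "▁The", "▁quickest", "▁fox", "."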