mebubo committed on
Commit
bb48904
·
1 Parent(s): 2fb63bf

Working expand_llm

Browse files
Files changed (5) hide show
  1. completions.py +13 -29
  2. expand.py +2 -5
  3. expand_llm.py +20 -0
  4. expand_test.py +12 -12
  5. run.py +49 -0
completions.py CHANGED
@@ -95,7 +95,7 @@ def generate_outputs(model: PreTrainedModel, inputs: BatchEncoding, num_samples:
95
  )
96
  return outputs
97
 
98
- def find_next_tokens(model: PreTrainedModel, inputs: BatchEncoding, tokenizer: Tokenizer, min_p: float) -> list[list[tuple[int, str, float]]]:
99
  input_ids = inputs["input_ids"]
100
  attention_mask = inputs["attention_mask"]
101
  with torch.no_grad():
@@ -109,6 +109,18 @@ def find_next_tokens(model: PreTrainedModel, inputs: BatchEncoding, tokenizer: T
109
  result.append([(i, tokenizer.convert_ids_to_tokens([i])[0], p) for i, p in enumerate(probs) if p > min_p])
110
  return result
111
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  def extract_replacements(outputs: GenerateOutput | torch.LongTensor, tokenizer: Tokenizer, num_inputs: int, input_len: int, num_samples: int = 5) -> list[list[str]]:
113
  all_new_words = []
114
  for i in range(num_inputs):
@@ -161,31 +173,3 @@ def check_text(input_text: str, model: PreTrainedModel, tokenizer: Tokenizer, de
161
  else:
162
  result.append(ApiWord(text=word.text, logprob=word.logprob, replacements=[]))
163
  return result
164
-
165
- # %%
166
- model, tokenizer, device = load_model()
167
-
168
- #%%
169
- input_text = "The quick brown fox jumpz over"
170
- inputs: BatchEncoding = tokenize(input_text, tokenizer, device)
171
-
172
- #%%
173
- token_probs: list[tuple[int, float]] = calculate_log_probabilities(model, tokenizer, inputs)
174
-
175
- #%%
176
- words = split_into_words(token_probs, tokenizer)
177
- log_prob_threshold = -5.0
178
- low_prob_words = [(i, word) for i, word in enumerate(words) if word.logprob < log_prob_threshold]
179
-
180
- #%%
181
- contexts = [word.context for _, word in low_prob_words]
182
- inputs = prepare_inputs(contexts, tokenizer, device)
183
- input_ids = inputs["input_ids"]
184
-
185
- #%%
186
- next_tokens = find_next_tokens(model, inputs, tokenizer, min_p=-5)
187
-
188
- #%%
189
- next_tokens
190
-
191
- # %%
 
95
  )
96
  return outputs
97
 
98
+ def find_next_tokens_0(model: PreTrainedModel, inputs: BatchEncoding, tokenizer: Tokenizer, min_p: float) -> list[list[tuple[int, str, float]]]:
99
  input_ids = inputs["input_ids"]
100
  attention_mask = inputs["attention_mask"]
101
  with torch.no_grad():
 
109
  result.append([(i, tokenizer.convert_ids_to_tokens([i])[0], p) for i, p in enumerate(probs) if p > min_p])
110
  return result
111
 
112
def find_next_tokens(model: PreTrainedModel, inputs: BatchEncoding, tokenizer: Tokenizer) -> list[list[tuple[int, float]]]:
    """Return the log-probability of every candidate next token for each input row.

    Runs a single forward pass over the batch, takes the logits at the last
    sequence position of every row and normalises them with ``log_softmax``.

    Args:
        model: Called as ``model(input_ids=..., attention_mask=...)``; the
            result must expose a ``.logits`` tensor.
        inputs: Mapping that provides ``"input_ids"`` and ``"attention_mask"``.
        tokenizer: Not used by this function; kept so existing callers that
            pass it keep working.

    Returns:
        One list per batch row containing ``(token_id, log_prob)`` pairs for
        every index of the last logits dimension. ``log_prob`` is a plain
        Python ``float``, matching the return annotation.
    """
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        # Only the final position predicts the token that follows the input.
        # NOTE(review): this assumes the last position of every row is a real
        # token (e.g. left padding) - confirm against prepare_inputs.
        logits: torch.Tensor = outputs.logits[:, -1, :]
        log_probs: torch.Tensor = torch.log_softmax(logits, dim=-1)
    # tolist() converts the whole tensor to Python floats in one call instead
    # of creating one 0-dim tensor per vocabulary entry while iterating.
    return [list(enumerate(row)) for row in log_probs.tolist()]
123
+
124
  def extract_replacements(outputs: GenerateOutput | torch.LongTensor, tokenizer: Tokenizer, num_inputs: int, input_len: int, num_samples: int = 5) -> list[list[str]]:
125
  all_new_words = []
126
  for i in range(num_inputs):
 
173
  else:
174
  result.append(ApiWord(text=word.text, logprob=word.logprob, replacements=[]))
175
  return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
expand.py CHANGED
@@ -2,10 +2,6 @@ from collections import defaultdict
2
  from dataclasses import dataclass
3
  from typing import Protocol
4
 
5
- # import torch
6
- # from transformers import PreTrainedModel
7
- # from completions import find_next_tokens, Tokenizer
8
-
9
  @dataclass
10
  class Series:
11
  id: int
@@ -46,7 +42,7 @@ class ExpansionResultBatch:
46
  def compute_new_series(result: ExpansionOneResult) -> list[Series]:
47
  results = []
48
  for expansion in result.expansions:
49
- results.append(Series(id=result.series.id, tokens=result.series.tokens + [expansion.token], budget=result.series.budget - expansion.cost))
50
  return results
51
 
52
  def compute_expansions(original_series: list[Series], expanded_series: list[Series]) -> ExpansionResultBatch:
@@ -73,6 +69,7 @@ def expand(batch: Batch, expander: ExpanderOneBatch) -> ExpansionResultBatch:
73
  completed_series: list[Series] = []
74
  current_batch = batch
75
  while len(current_batch.items) > 0:
 
76
  current_batch_items = []
77
  expanded = expander.expand(current_batch)
78
  for item in expanded.items:
 
2
  from dataclasses import dataclass
3
  from typing import Protocol
4
 
 
 
 
 
5
  @dataclass
6
  class Series:
7
  id: int
 
42
  def compute_new_series(result: ExpansionOneResult) -> list[Series]:
43
  results = []
44
  for expansion in result.expansions:
45
+ results.append(Series(id=result.series.id, tokens=result.series.tokens + [expansion.token], budget=result.series.budget + expansion.cost))
46
  return results
47
 
48
  def compute_expansions(original_series: list[Series], expanded_series: list[Series]) -> ExpansionResultBatch:
 
69
  completed_series: list[Series] = []
70
  current_batch = batch
71
  while len(current_batch.items) > 0:
72
+ print(f"Expanding {len(current_batch.items)} series: {current_batch.items}")
73
  current_batch_items = []
74
  expanded = expander.expand(current_batch)
75
  for item in expanded.items:
expand_llm.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from expand import *
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast, BatchEncoding
3
+ from dataclasses import dataclass
4
+ from completions import prepare_inputs, find_next_tokens
5
+
6
+ type Tokenizer = PreTrainedTokenizer | PreTrainedTokenizerFast
7
+
8
@dataclass
class ExpanderOneBatchLLM:
    """Expander that proposes one-token continuations using a language model.

    For every series in a batch it asks the model for next-token
    log-probabilities and keeps the tokens the series' budget can still pay for.
    """

    # Model handed to find_next_tokens; its ``device`` is passed to prepare_inputs.
    model: PreTrainedModel
    # Tokenizer forwarded to prepare_inputs and find_next_tokens.
    tokenizer: Tokenizer

    def expand(self, batch: Batch) -> ExpansionOneResultBatch:
        """Expand every series in ``batch`` by one token.

        Args:
            batch: Batch whose ``items`` each provide ``tokens`` and ``budget``.

        Returns:
            A result batch with one ``ExpansionOneResult`` per input series, in
            input order. Each result lists the ``(token, cost)`` expansions for
            which ``logprob + budget >= 0``; ``cost`` is the log-probability.
        """
        inputs = prepare_inputs([s.tokens for s in batch.items], self.tokenizer, self.model.device)
        next_tokens_per_series = find_next_tokens(self.model, inputs, self.tokenizer)
        results = []
        # Distinct loop names: the per-series candidates no longer rebind the
        # list that is being iterated.
        for series, candidates in zip(batch.items, next_tokens_per_series):
            expansions = [
                ExpansionOne(token=token, cost=logprob)
                for token, logprob in candidates
                if logprob + series.budget >= 0
            ]
            results.append(ExpansionOneResult(series=series, expansions=expansions))
        return ExpansionOneResultBatch(items=results)
expand_test.py CHANGED
@@ -12,8 +12,8 @@ possible_sequences = [
12
  def expand_series(series: Series) -> list[ExpansionOne]:
13
  l = len(series.tokens)
14
  items = [s[l] for s in possible_sequences if s[:l] == series.tokens and len(s) > l]
15
- candidates = [ExpansionOne(token=l, cost=1.0) for l in dict.fromkeys(items)]
16
- return [c for c in candidates if c.cost <= series.budget]
17
 
18
  class HardcodedExpanderOneBatch(ExpanderOneBatch):
19
  def expand(self, batch: Batch) -> ExpansionOneResultBatch:
@@ -38,8 +38,8 @@ def test_expander_budget_one():
38
  expanded = expander.expand(Batch(items=[s]))
39
  expected = ExpansionOneResultBatch(
40
  items=[ExpansionOneResult(series=s, expansions=[
41
- ExpansionOne(token=21, cost=1.0),
42
- ExpansionOne(token=22, cost=1.0),
43
  ])]
44
  )
45
  assert expected == expanded
@@ -49,8 +49,8 @@ def test_expander_budget_two():
49
  expanded = expander.expand(Batch(items=[s]))
50
  expected = ExpansionOneResultBatch(
51
  items=[ExpansionOneResult(series=s, expansions=[
52
- ExpansionOne(token=21, cost=1.0),
53
- ExpansionOne(token=22, cost=1.0),
54
  ])]
55
  )
56
  assert expected == expanded
@@ -68,8 +68,8 @@ def test_expander_budget_one_two_tokens():
68
  expanded = expander.expand(Batch(items=[s]))
69
  expected = ExpansionOneResultBatch(
70
  items=[ExpansionOneResult(series=s, expansions=[
71
- ExpansionOne(token=33, cost=1.0),
72
- ExpansionOne(token=34, cost=1.0),
73
  ])]
74
  )
75
  assert expected == expanded
@@ -81,12 +81,12 @@ def test_expander_budget_one_two_tokens_two_series():
81
  expected = ExpansionOneResultBatch(
82
  items=[
83
  ExpansionOneResult(series=s1, expansions=[
84
- ExpansionOne(token=41, cost=1.0),
85
- ExpansionOne(token=42, cost=1.0),
86
  ]),
87
  ExpansionOneResult(series=s2, expansions=[
88
- ExpansionOne(token=33, cost=1.0),
89
- ExpansionOne(token=34, cost=1.0),
90
  ])
91
  ]
92
  )
 
12
  def expand_series(series: Series) -> list[ExpansionOne]:
13
  l = len(series.tokens)
14
  items = [s[l] for s in possible_sequences if s[:l] == series.tokens and len(s) > l]
15
+ candidates = [ExpansionOne(token=l, cost=-1.0) for l in dict.fromkeys(items)]
16
+ return [c for c in candidates if c.cost + series.budget >= 0]
17
 
18
  class HardcodedExpanderOneBatch(ExpanderOneBatch):
19
  def expand(self, batch: Batch) -> ExpansionOneResultBatch:
 
38
  expanded = expander.expand(Batch(items=[s]))
39
  expected = ExpansionOneResultBatch(
40
  items=[ExpansionOneResult(series=s, expansions=[
41
+ ExpansionOne(token=21, cost=-1.0),
42
+ ExpansionOne(token=22, cost=-1.0),
43
  ])]
44
  )
45
  assert expected == expanded
 
49
  expanded = expander.expand(Batch(items=[s]))
50
  expected = ExpansionOneResultBatch(
51
  items=[ExpansionOneResult(series=s, expansions=[
52
+ ExpansionOne(token=21, cost=-1.0),
53
+ ExpansionOne(token=22, cost=-1.0),
54
  ])]
55
  )
56
  assert expected == expanded
 
68
  expanded = expander.expand(Batch(items=[s]))
69
  expected = ExpansionOneResultBatch(
70
  items=[ExpansionOneResult(series=s, expansions=[
71
+ ExpansionOne(token=33, cost=-1.0),
72
+ ExpansionOne(token=34, cost=-1.0),
73
  ])]
74
  )
75
  assert expected == expanded
 
81
  expected = ExpansionOneResultBatch(
82
  items=[
83
  ExpansionOneResult(series=s1, expansions=[
84
+ ExpansionOne(token=41, cost=-1.0),
85
+ ExpansionOne(token=42, cost=-1.0),
86
  ]),
87
  ExpansionOneResult(series=s2, expansions=[
88
+ ExpansionOne(token=33, cost=-1.0),
89
+ ExpansionOne(token=34, cost=-1.0),
90
  ])
91
  ]
92
  )
run.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#%%
# Interactive driver script: each "#%%" marker starts a notebook-style cell.
from completions import *
from expand_llm import *
from expand import *

# %%
# load_model() yields the model, its tokenizer and the device in one call.
model, tokenizer, device = load_model()

#%%
# Alternative sample inputs, kept for quick switching:
# input_text = "The quick brown fox jumpz over"
# input_text = "He asked me to prostate myself before the king"
input_text = "Здравствуйте, я хочу предвыполнить заказ"
inputs: BatchEncoding = tokenize(input_text, tokenizer, device)

#%%
# Per-token (token id, log-probability) pairs, per the annotation.
token_probs: list[tuple[int, float]] = calculate_log_probabilities(model, tokenizer, inputs)

#%%
words = split_into_words(token_probs, tokenizer)
# Words whose logprob falls below this threshold are selected for expansion.
log_prob_threshold = -5.0
low_prob_words = [(i, word) for i, word in enumerate(words) if word.logprob < log_prob_threshold]

#%%
# The context of each low-probability word becomes the start of a series.
# NOTE(review): presumably the token ids preceding the word - confirm in split_into_words.
contexts = [word.context for _, word in low_prob_words]

#%%
expander = ExpanderOneBatchLLM(model, tokenizer)

#%%
# One series per context, each starting with a budget of 5.0.
series = []
for i, x in enumerate(contexts):
    series.append(Series(id=i, tokens=x, budget=5.0))

#%%
batch = Batch(items=series)

#%%
# Run the expansion with the LLM-backed expander.
expanded = expand(batch, expander)
39
+
40
+ # %%
41
def print_expansions(expansions: ExpansionResultBatch):
    """Print every expansion of every result together with its decoded text.

    Each printed line has the form ``<series id>: <expansion> <decoded text>``.
    Decoding uses the module-level ``tokenizer``.
    """
    for item in expansions.items:
        for candidate in item.expansions:
            # NOTE(review): assumes each expansion is something
            # tokenizer.decode accepts (e.g. a list of token ids) - confirm.
            decoded = tokenizer.decode(candidate)
            print(f"{item.series.id}: {candidate} {decoded}")
47
+
48
+ print_expansions(expanded)
49
+ # %%