#%%
from collections import defaultdict
from dataclasses import dataclass

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, BatchEncoding

from models import ApiWord, Word, Replacement
from combine import combine
from expand import *
from expand_llm import *


def starts_with_space(token: str) -> bool:
    # SentencePiece marks a leading space with U+2581 ("▁"); GPT-2-style
    # byte-level BPE uses U+0120 ("Ġ").
    return token.startswith(chr(9601)) or token.startswith(chr(288))


def is_newline(token: str) -> bool:
    # Byte-level BPE maps "\n" to U+010A ("Ċ").
    return len(token) == 1 and ord(token[0]) == 266


def split_into_words(token_probs: list[tuple[int, float]], tokenizer: Tokenizer) -> list[Word]:
    @dataclass
    class Tok:
        index: int
        ids: list[int]
        text: str
        logprob: float

    def is_beginning_of_word(s: str) -> bool:
        return (s[0] == " " and s[1:].isalpha()) or s.isalpha()

    def is_continuation_of_word(s: str) -> bool:
        return s.isalpha()

    def merge_tokens(a: Tok, b: Tok) -> Tok | None:
        # Merge a word-initial token with a continuation token; log-probabilities
        # add because the underlying probabilities multiply.
        if is_beginning_of_word(a.text) and is_continuation_of_word(b.text):
            return Tok(a.index, a.ids + b.ids, a.text + b.text, a.logprob + b.logprob)
        return None

    converted = [Tok(i, [token_id], tokenizer.decode([token_id]), logprob)
                 for i, (token_id, logprob) in enumerate(token_probs)]
    combined = combine(converted, merge_tokens)
    token_ids = [token_id for token_id, _ in token_probs]
    # A word's context is every token id that precedes its first token.
    return [Word(tok.ids, tok.text, tok.logprob, token_ids[:tok.index]) for tok in combined]
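
# Illustrative example (not from the original): with a GPT-2-style tokenizer,
# consecutive tokens decoding to " quick" and "ly" satisfy is_beginning_of_word
# and is_continuation_of_word respectively, so merge_tokens folds them into a
# single Word " quickly" whose logprob is the sum of the two token logprobs.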


def load_model_and_tokenizer(model_name: str, device: torch.device) -> tuple[PreTrainedModel, Tokenizer]:
    tokenizer: Tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
    # Causal LMs usually ship without a pad token; reuse EOS so batching works.
    tokenizer.pad_token = tokenizer.eos_token
    model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(model_name)
    model.to(device)
    return model, tokenizer


def tokenize(input_text: str, tokenizer: Tokenizer, device: torch.device) -> BatchEncoding:
    return tokenizer(input_text, return_tensors="pt").to(device)


def calculate_log_probabilities(model: PreTrainedModel, inputs: BatchEncoding) -> list[tuple[int, float]]:
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    # Logits at position t predict token t + 1, so drop the last position: B x (T-1) x V.
    logits: torch.Tensor = outputs.logits[:, :-1, :]
    log_probs: torch.Tensor = torch.log_softmax(logits, dim=-1)
    # The targets are the input shifted left by one: T - 1 token ids.
    tokens: torch.Tensor = input_ids[0][1:]
    # Log-probability the model assigned to each actual next token: T - 1 values.
    positions = torch.arange(log_probs.shape[1], device=log_probs.device)
    token_log_probs: torch.Tensor = log_probs[0, positions, tokens]
    return list(zip(tokens.tolist(), token_log_probs.tolist()))
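
#%%
# Hedged sanity-check sketch (hypothetical helper, not in the original; "gpt2"
# stands in for any causal LM): an input of T tokens yields exactly T - 1
# (token_id, logprob) pairs, one per position the model actually predicts.
def _check_logprob_count() -> None:
    device = torch.device("cpu")
    model, tokenizer = load_model_and_tokenizer("gpt2", device)
    enc = tokenize("Hello world", tokenizer, device)
    pairs = calculate_log_probabilities(model, enc)
    assert len(pairs) == enc["input_ids"].shape[1] - 1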

#%%
def load_model() -> tuple[PreTrainedModel, Tokenizer, torch.device]:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # model_name = "mistralai/Mistral-7B-v0.1"
    model_name = "unsloth/Llama-3.2-1B"
    model, tokenizer = load_model_and_tokenizer(model_name, device)
    return model, tokenizer, device


def check_text(input_text: str, model: PreTrainedModel, tokenizer: Tokenizer, device: torch.device) -> list[ApiWord]:
    inputs: BatchEncoding = tokenize(input_text, tokenizer, device)
    token_probs: list[tuple[int, float]] = calculate_log_probabilities(model, inputs)
    words = split_into_words(token_probs, tokenizer)

    # Words the model found surprising are candidates for replacement.
    log_prob_threshold = -5.0
    low_prob_words = [(i, word) for i, word in enumerate(words) if word.logprob < log_prob_threshold]
    contexts = [word.context for _, word in low_prob_words]

    # Generate alternative continuations from each low-probability word's context.
    expander = LLMBatchExpander(model, tokenizer, threshold=log_prob_threshold)
    series = [Series(id=i, tokens=context, budget=5.0) for i, context in enumerate(contexts)]
    batch = Batch(items=series)
    stopping_criterion = create_stopping_criterion_llm(tokenizer)
    expanded = expand(batch, expander, stopping_criterion)

    # Group expansions by series id.
    expanded_by_id: dict[int, list[list[Expansion]]] = defaultdict(list)
    for result in expanded.items:
        expanded_by_id[result.series.id].extend(result.expansions)

    # Decode each expansion into replacement text; summing per-token costs
    # gives the replacement's log-probability.
    replacements: list[list[Replacement]] = []
    for i, _ in enumerate(contexts):
        r = []
        for exp in expanded_by_id[i]:
            tokens = [e.token for e in exp]
            text = tokenizer.decode(tokens)
            logprob = sum(e.cost for e in exp)
            r.append(Replacement(text=text, logprob=logprob))
        replacements.append(r)

    replacements_by_index = {i: r for (i, _), r in zip(low_prob_words, replacements)}
    return [ApiWord(text=word.text, logprob=word.logprob,
                    replacements=replacements_by_index.get(i, []))
            for i, word in enumerate(words)]
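
#%%
# Minimal usage sketch (illustrative; the sample sentence and printing are not
# from the original). Words whose logprob fell below the threshold come back
# with candidate replacements attached.
if __name__ == "__main__":
    model, tokenizer, device = load_model()
    for word in check_text("She walked her dog threw the park.", model, tokenizer, device):
        print(f"{word.text!r}: {word.logprob:.2f}")
        for rep in word.replacements:
            print(f"    suggestion: {rep.text!r} ({rep.logprob:.2f})")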