import random

import torch
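
# RetrievalProcessor gathers the tokens of the next word in a sequence,
# optionally corrupts the word (random typo) or splits it into per-character
# tokens, and returns the re-tokenized (context + word) input for the model.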


class RetrievalProcessor:
    def __init__(self, model, tokenizer, multi_token_kind, num_tokens_to_generate,
                 add_context, model_name, whitespace_token='Ġ'):
        self.model = model
        self.tokenizer = tokenizer
        self.multi_token_kind = multi_token_kind
        self.num_tokens_to_generate = num_tokens_to_generate
        self.add_context = add_context
        self.model_name = model_name
        self.whitespace_token = whitespace_token

    def get_next_word(self, tokens, i, max_length=1000, device='cuda'):
        # Collect the token at position i plus any continuation pieces that belong
        # to the same word (alphabetic tokens without a leading whitespace marker).
        token_str = self.tokenizer.convert_ids_to_tokens(tokens[i].item())
        j = i + 1
        word_tokens = [tokens[i]]
        if token_str.startswith(self.whitespace_token):
            while j < len(tokens) and self.is_alpha_not_prefix(tokens[j]):
                word_tokens.append(tokens[j])
                j += 1
        word = self.tokenizer.decode(word_tokens)
        original_word = word
        # Optionally prepend the decoded preceding context, then re-tokenize the result.
        context = self.tokenizer.decode(tokens[:i]) if self.add_context else ""
        combined_text = context + word
        tokenized_combined_text = self.tokenizer(combined_text, return_tensors='pt', truncation=True,
                                                 max_length=max_length).to(device)
        return j, word_tokens, word, context, tokenized_combined_text, combined_text, original_word

    def get_next_full_word_typo(self, tokens, i, max_length=1000, device='cuda'):
        tokens_str = self.tokenizer.convert_ids_to_tokens(tokens)
        word_tokens = [tokens[i]]
        word = self.tokenizer.decode(word_tokens)
        original_word = word
        # If the token at position i is a complete word, corrupt it with a random
        # typo and re-tokenize the corrupted word.
        if self.is_full_word(tokens_str, i, word, word_tokens):
            word = self.introduce_typo(word)
            word_tokens = self.tokenizer(word, return_tensors='pt', truncation=True,
                                         max_length=max_length).input_ids[0][1:]
        context = self.tokenizer.decode(tokens[:i]) if self.add_context else ""
        combined_text = context + word
        tokenized_combined_text = self.tokenizer(combined_text, return_tensors='pt', truncation=True,
                                                 max_length=max_length).to(device)
        # Index of the word's last token: relative to the combined input when context
        # is included, otherwise offset by the word's position in the original sequence.
        j = (len(tokenized_combined_text.input_ids[0]) - 1 if self.add_context
             else len(tokenized_combined_text.input_ids[0]) - 1 + i)
        return j, word_tokens, word, context, tokenized_combined_text, combined_text, original_word

    def get_next_full_word_separated(self, tokens, i, max_length=1000, device='cuda'):
        tokens_str = self.tokenizer.convert_ids_to_tokens(tokens)
        word_tokens = [tokens[i]]
        word = self.tokenizer.decode(word_tokens)
        original_word = word
        # If the token at position i is a complete word, replace it with its
        # per-character token ids; otherwise keep the original single token.
        if self.is_full_word(tokens_str, i, word, word_tokens):
            word = torch.tensor(self.separate_word(word)).unsqueeze(0).to(device)
        else:
            word = word_tokens[0].unsqueeze(0).unsqueeze(0).to(device)
        context = self.tokenizer.decode(tokens[:i]) if self.add_context else ""
        tokenized_combined_text = self.tokenizer(context, return_tensors='pt', truncation=True,
                                                 max_length=max_length).to(device)
        print(tokenized_combined_text.input_ids)
        print(word)
        # Append the (separated) word tokens to the tokenized context.
        tokenized_combined_text.input_ids = torch.cat((tokenized_combined_text.input_ids, word), dim=1)
        word_tokens = word
        j = i + 1
        return (j, word_tokens, word, context, tokenized_combined_text,
                self.tokenizer.decode(tokenized_combined_text.input_ids[0]), original_word)

    def is_alpha_not_prefix(self, token):
        # True for an alphabetic continuation piece: a token made of letters that
        # does not start a new word (no leading whitespace marker).
        token_str = self.tokenizer.convert_ids_to_tokens(token.item())
        return not token_str.startswith(self.whitespace_token) and token_str.isalpha()

    def introduce_typo(self, word, typo_type=None):
        letters = 'abcdefghijklmnopqrstuvwxyz'
        if typo_type is None:
            typo_type = random.choice(["substitution", "deletion", "insertion", "transposition"])
        if typo_type == "substitution":
            position = random.randint(1, len(word) - 1)
            original_char = word[position]
            typo_char = random.choice([c for c in letters if c != original_char])
            return word[:position] + typo_char + word[position + 1:]
        elif typo_type == "deletion":
            position = random.randint(1, len(word) - 1)
            return word[:position] + word[position + 1:]
        elif typo_type == "insertion":
            position = random.randint(1, len(word) - 1)
            typo_char = random.choice(letters)
            return word[:position] + typo_char + word[position:]
        elif typo_type == "transposition":
            position = random.randint(1, len(word) - 2)
            return word[:position] + word[position + 1] + word[position] + word[position + 2:]
        else:
            return word

    def separate_word(self, word):
        character_tokens = [self.tokenizer.encode(f'\n{char}')[-1] for char in ''.join(word)]
        character_tokens = character_tokens[3:]
        return character_tokens

    def is_full_word(self, token_str, i, token, word_tokens):
        # The token at position i counts as a full word when it is alphabetic,
        # long enough, starts with the whitespace marker, and the following token
        # does not continue it (i.e. is not an alphabetic piece without the marker).
        next_token = token_str[i + 1] if i + 1 < len(token_str) else ""
        continues_word = next_token.isalpha() and not next_token.startswith(self.whitespace_token)
        return (token[1:].isalpha() and
                len(token) > 5 and
                token_str[i].startswith(self.whitespace_token) and
                not continues_word)
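

# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal example of how this class might be driven, assuming a GPT-2-style
# BPE tokenizer that uses 'Ġ' as its whitespace marker. The model name, the
# multi_token_kind value, and the example sentence are placeholders, not values
# prescribed by this module.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    processor = RetrievalProcessor(
        model=model,
        tokenizer=tokenizer,
        multi_token_kind="typo",  # placeholder kind for this sketch
        num_tokens_to_generate=5,
        add_context=True,
        model_name="gpt2",
    )

    # Token ids of an example sentence; start at i=1 so the word has some context.
    tokens = tokenizer(" The transformers library is extraordinarily useful",
                       return_tensors='pt').input_ids[0]
    j, word_tokens, word, context, batch, combined_text, original_word = \
        processor.get_next_word(tokens, i=1, device='cpu')
    print(original_word, "->", combined_text)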