import random
import torch
class RetrievalProcessor:
    """Prepares per-word model inputs: a word can be kept as-is, corrupted
    with a random typo, or broken into character-level tokens."""

    def __init__(self, model, tokenizer, multi_token_kind, num_tokens_to_generate,
                 add_context, model_name, whitespace_token='Ġ'):
        self.model = model
        self.tokenizer = tokenizer
        self.multi_token_kind = multi_token_kind
        self.num_tokens_to_generate = num_tokens_to_generate
        self.add_context = add_context
        self.model_name = model_name
        self.whitespace_token = whitespace_token
    def get_next_word(self, tokens, i, max_length=1000, device='cuda'):
        token_str = self.tokenizer.convert_ids_to_tokens(tokens[i].item())
        j = i + 1
        word_tokens = [tokens[i]]
        # If token i starts a new word, absorb the following sub-word tokens
        # until the next whitespace-prefixed or non-alphabetic token.
        if token_str.startswith(self.whitespace_token):
            while j < len(tokens) and self.is_alpha_not_prefix(tokens[j]):
                word_tokens.append(tokens[j])
                j += 1
        word = self.tokenizer.decode(word_tokens)
        original_word = word
        context = self.tokenizer.decode(tokens[:i]) if self.add_context else ""
        combined_text = context + word
        tokenized_combined_text = self.tokenizer(combined_text, return_tensors='pt', truncation=True,
                                                 max_length=max_length).to(device)
        return j, word_tokens, word, context, tokenized_combined_text, combined_text, original_word
    def get_next_full_word_typo(self, tokens, i, max_length=1000, device='cuda'):
        tokens_str = self.tokenizer.convert_ids_to_tokens(tokens)
        word_tokens = [tokens[i]]
        word = self.tokenizer.decode(word_tokens)
        original_word = word
        if self.is_full_word(tokens_str, i, word, word_tokens):
            # Corrupt the word and re-tokenize it, dropping the first id
            # (a special token such as BOS in tokenizers that prepend one).
            word = self.introduce_typo(word)
            word_tokens = self.tokenizer(word, return_tensors='pt', truncation=True,
                                         max_length=max_length).input_ids[0][1:]
        context = self.tokenizer.decode(tokens[:i]) if self.add_context else ""
        combined_text = context + word
        tokenized_combined_text = self.tokenizer(combined_text, return_tensors='pt', truncation=True,
                                                 max_length=max_length).to(device)
        # Index of the word's last token: relative to the combined input when
        # context is included, otherwise shifted back into the original sequence.
        last = len(tokenized_combined_text.input_ids[0]) - 1
        j = last if self.add_context else last + i
        return j, word_tokens, word, context, tokenized_combined_text, combined_text, original_word
    def get_next_full_word_separated(self, tokens, i, max_length=1000, device='cuda'):
        tokens_str = self.tokenizer.convert_ids_to_tokens(tokens)
        word_tokens = [tokens[i]]
        word = self.tokenizer.decode(word_tokens)
        original_word = word
        if self.is_full_word(tokens_str, i, word, word_tokens):
            # Replace the word with its character-level token ids, shaped (1, n).
            word = torch.tensor(self.separate_word(word)).unsqueeze(0)
        else:
            word = word_tokens[0].unsqueeze(0).unsqueeze(0)
        context = self.tokenizer.decode(tokens[:i]) if self.add_context else ""
        tokenized_combined_text = self.tokenizer(context, return_tensors='pt', truncation=True,
                                                 max_length=max_length).to(device)
        # Append the word's token ids after the tokenized context (moving them
        # to the same device first, so the concatenation is valid on GPU too).
        word = word.to(device)
        tokenized_combined_text.input_ids = torch.cat((tokenized_combined_text.input_ids, word), dim=1)
        word_tokens = word
        j = i + 1
        return (j, word_tokens, word, context, tokenized_combined_text,
                self.tokenizer.decode(tokenized_combined_text.input_ids[0]), original_word)
    def is_alpha_not_prefix(self, token):
        # True when the token continues the current word: alphabetic and not
        # starting a new whitespace-prefixed word.
        token_str = self.tokenizer.convert_ids_to_tokens(token.item())
        return not token_str.startswith(self.whitespace_token) and token_str.isalpha()
    def introduce_typo(self, word, typo_type=None):
        letters = 'abcdefghijklmnopqrstuvwxyz'
        if typo_type is None:
            typo_type = random.choice(["substitution", "deletion", "insertion", "transposition"])
        # Positions start at 1 so the leading whitespace character is never edited.
        if typo_type == "substitution":
            position = random.randint(1, len(word) - 1)
            original_char = word[position]
            typo_char = random.choice([c for c in letters if c != original_char])
            return word[:position] + typo_char + word[position + 1:]
        elif typo_type == "deletion":
            position = random.randint(1, len(word) - 1)
            return word[:position] + word[position + 1:]
        elif typo_type == "insertion":
            position = random.randint(1, len(word) - 1)
            typo_char = random.choice(letters)
            return word[:position] + typo_char + word[position:]
        elif typo_type == "transposition":
            position = random.randint(1, len(word) - 2)
            return word[:position] + word[position + 1] + word[position] + word[position + 2:]
        else:
            return word
    def separate_word(self, word):
        # Encode each character prefixed with a newline and keep only the
        # character's own token (the last id), then drop the first three entries.
        character_tokens = [self.tokenizer.encode(f'\n{char}')[-1] for char in word]
        return character_tokens[3:]
    def is_full_word(self, token_str, i, token, word_tokens):
        # A full word: alphabetic after the leading space, longer than five
        # characters, starting at a whitespace-prefixed token, and not
        # continued by the following token in the sequence.
        next_token = token_str[i + 1] if i + 1 < len(token_str) else ""
        continues_word = next_token.isalpha() and not next_token.startswith(self.whitespace_token)
        return (token[1:].isalpha() and
                len(token) > 5 and
                token_str[i].startswith(self.whitespace_token) and
                not continues_word)
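

# Example usage (a minimal sketch, not part of the class above): the 'gpt2'
# checkpoint and the sample sentence are illustrative assumptions; any BPE
# tokenizer whose word-initial tokens carry the 'Ġ' marker matches the
# default whitespace_token. Runs on CPU so no GPU is required.
if __name__ == '__main__':
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    model = AutoModelForCausalLM.from_pretrained('gpt2')
    processor = RetrievalProcessor(model, tokenizer, multi_token_kind='typo',
                                   num_tokens_to_generate=1, add_context=True,
                                   model_name='gpt2')

    tokens = tokenizer("The transformer architecture changed everything",
                       return_tensors='pt').input_ids[0]

    # Reassemble multi-token words one at a time.
    i = 0
    while i < len(tokens):
        (i, word_tokens, word, context, batch,
         combined, original) = processor.get_next_word(tokens, i, device='cpu')
        print(repr(word))

    # Corrupt a word with a random typo (the leading space is preserved).
    print(processor.introduce_typo(' architecture'))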