import random
import torch


class RetrievalProcessor:
    """Builds (context, word) pairs from a token sequence, optionally perturbing the
    target word (random typos or character-level separation) before re-tokenization."""

    def __init__(self, model, tokenizer, multi_token_kind, num_tokens_to_generate,
                 add_context, model_name, whitespace_token='Ġ'):
        self.model = model
        self.tokenizer = tokenizer
        self.multi_token_kind = multi_token_kind
        self.num_tokens_to_generate = num_tokens_to_generate
        self.add_context = add_context
        self.model_name = model_name
        # Marker that BPE tokenizers (e.g. GPT-2) prepend to tokens starting a new word.
        self.whitespace_token = whitespace_token

    def get_next_word(self, tokens, i, max_length=1000, device='cuda'):
        # Collect the token at position i plus any following sub-word continuations,
        # so that `word` covers the complete surface word starting at i.
        token_str = self.tokenizer.convert_ids_to_tokens(tokens[i].item())
        j = i + 1
        word_tokens = [tokens[i]]
        if token_str.startswith(self.whitespace_token):
            while j < len(tokens) and self.is_alpha_not_prefix(tokens[j]):
                word_tokens.append(tokens[j])
                j += 1
        word = self.tokenizer.decode(word_tokens)
        original_word = word
        context = self.tokenizer.decode(tokens[:i]) if self.add_context else ""
        combined_text = context + word

        tokenized_combined_text = self.tokenizer(combined_text, return_tensors='pt', truncation=True,
                                                 max_length=max_length).to(device)
        return j, word_tokens, word, context, tokenized_combined_text, combined_text, original_word

    def get_next_full_word_typo(self, tokens, i, max_length=1000, device='cuda'):
        # Take the single token at position i; if it forms a complete word on its own,
        # corrupt it with a random typo before re-tokenizing it together with the context.
        tokens_str = self.tokenizer.convert_ids_to_tokens(tokens)
        word_tokens = [tokens[i]]
        word = self.tokenizer.decode(word_tokens)
        original_word = word
        if self.is_full_word(tokens_str, i, word, word_tokens):
            word = self.introduce_typo(word)
        # Re-tokenize the (possibly corrupted) word, dropping the first token the tokenizer adds.
        word_tokens = self.tokenizer(word, return_tensors='pt', truncation=True, max_length=max_length).input_ids[0][1:]
        context = self.tokenizer.decode(tokens[:i]) if self.add_context else ""
        combined_text = context + word

        tokenized_combined_text = self.tokenizer(combined_text, return_tensors='pt', truncation=True,
                                                 max_length=max_length).to(device)
        # Index of the last token of the word; without context, offset by the word's position i.
        j = len(tokenized_combined_text.input_ids[0]) - 1
        if not self.add_context:
            j += i
        return j, word_tokens, word, context, tokenized_combined_text, combined_text, original_word

    def get_next_full_word_separated(self, tokens, i, max_length=1000, device='cuda'):
        # Take the single token at position i; if it forms a complete word on its own,
        # replace it with its character-level tokens appended after the tokenized context.
        tokens_str = self.tokenizer.convert_ids_to_tokens(tokens)
        word_tokens = [tokens[i]]
        word = self.tokenizer.decode(word_tokens)
        original_word = word
        if self.is_full_word(tokens_str, i, word, word_tokens):
            word = torch.tensor(self.separate_word(word)).unsqueeze(0)
        else:
            word = word_tokens[0].unsqueeze(0).unsqueeze(0)
        context = self.tokenizer.decode(tokens[:i]) if self.add_context else ""
        tokenized_combined_text = self.tokenizer(context, return_tensors='pt', truncation=True,
                                                 max_length=max_length).to(device)
        # Append the word tokens to the context tokens, making sure both live on the same device.
        word = word.to(tokenized_combined_text.input_ids.device)
        tokenized_combined_text.input_ids = torch.cat((tokenized_combined_text.input_ids, word), dim=1)
        word_tokens = word
        j = i + 1
        return (j, word_tokens, word, context, tokenized_combined_text,
                self.tokenizer.decode(tokenized_combined_text.input_ids[0]), original_word)

    def is_alpha_not_prefix(self, token):
        # True when the token is an alphabetic sub-word continuation (no leading-space marker).
        token_str = self.tokenizer.convert_ids_to_tokens(token.item())
        return not token_str.startswith(self.whitespace_token) and token_str.isalpha()

    def introduce_typo(self, word, typo_type=None):
        # Inject a single random character-level typo, leaving the first character
        # (typically the leading space produced by decoding) untouched.
        letters = 'abcdefghijklmnopqrstuvwxyz'
        if typo_type is None:
            typo_type = random.choice(["substitution", "deletion", "insertion", "transposition"])

        if typo_type == "substitution":
            # Replace one character with a different random letter.
            position = random.randint(1, len(word) - 1)
            original_char = word[position]
            typo_char = random.choice([c for c in letters if c != original_char])
            return word[:position] + typo_char + word[position + 1:]
        elif typo_type == "deletion":
            # Drop one character.
            position = random.randint(1, len(word) - 1)
            return word[:position] + word[position + 1:]
        elif typo_type == "insertion":
            # Insert one random letter.
            position = random.randint(1, len(word) - 1)
            typo_char = random.choice(letters)
            return word[:position] + typo_char + word[position:]
        elif typo_type == "transposition":
            # Swap two adjacent characters.
            position = random.randint(1, len(word) - 2)
            return word[:position] + word[position + 1] + word[position] + word[position + 2:]
        else:
            return word

    def separate_word(self, word):
        # Encode each character preceded by a newline so it becomes a standalone token,
        # keep the last token id per character, and discard the first three ids.
        character_tokens = [self.tokenizer.encode(f'\n{char}')[-1] for char in word]
        return character_tokens[3:]

    def is_full_word(self, token_str, i, token, word_tokens):
        # Full word: alphabetic, reasonably long, starts with the leading-space marker,
        # and the next token in the sequence does not continue it alphabetically.
        next_token = token_str[i + 1] if i + 1 < len(token_str) else ""
        continues_word = next_token.isalpha() and not next_token.startswith(self.whitespace_token)
        return (token[1:].isalpha() and
                len(token) > 5 and
                token_str[i].startswith(self.whitespace_token) and
                not continues_word)
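

# --- Hypothetical usage sketch (not part of the original module) ---
# A minimal example of how RetrievalProcessor might be driven, assuming a GPT-2-style
# BPE tokenizer (which uses the 'Ġ' leading-space marker). The model name, sample
# sentence, and constructor arguments below are illustrative assumptions, not values
# taken from the original code.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "gpt2"  # assumed model; any causal LM with a BPE tokenizer should work
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    processor = RetrievalProcessor(
        model=model,
        tokenizer=tokenizer,
        multi_token_kind="typo",     # assumed value; not interpreted by the class itself
        num_tokens_to_generate=1,    # assumed value
        add_context=True,
        model_name=model_name,
    )

    tokens = tokenizer("The quick brown fox jumps over the lazy dog", return_tensors="pt").input_ids[0]
    # Retrieve the full word starting at position 4 together with its preceding context.
    j, word_tokens, word, context, tokenized, combined, original = processor.get_next_word(
        tokens, i=4, device="cpu")
    print(f"word={word!r}, context={context!r}, next index={j}")

    # Corrupt a decoded word with a single random character-level typo.
    print(processor.introduce_typo(" jumps"))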