File size: 2,440 Bytes
6fc683c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from tempfile import tempdir
from fairseq.data.encoders.gpt2_bpe import GPT2BPE, GPT2BPEConfig
from fairseq.data.encoders import register_bpe
import logging

logger = logging.getLogger(__name__)

# Selects how GPT2BPEEnhancedSpace represents each space character:
#   0 ("insert")  - emit a "<s>" marker token AND keep the space attached to
#                   the start of the following chunk, so it is BPE-encoded
#                   together with that chunk; decode drops the marker.
#   1 ("replace") - emit only the "<s>" marker; the space itself is not
#                   BPE-encoded, and decode turns the marker back into " ".
# The value is read at call time by encode/decode.
INSERT_OR_REPLACE = 0  # 1 for replace and 0 for insert

@register_bpe("gpt2es", dataclass=GPT2BPEConfig)  # "gpt2es": GPT-2 BPE with enhanced space handling
class GPT2BPEEnhancedSpace(GPT2BPE):
    """GPT-2 BPE variant that marks every space character with a "<s>" token.

    How the space itself is handled is selected by the module-level
    ``INSERT_OR_REPLACE`` flag, which is read on every call:

    * ``0`` (insert): emit "<s>" and keep the space attached to the start of
      the following chunk, so it is BPE-encoded together with that chunk.
    * ``1`` (replace): emit "<s>" in place of the space; the space itself is
      not BPE-encoded.

    ``decode`` inverts ``encode`` for whichever mode is active.
    """

    # Marker token emitted for each space character.
    _SPACE_TOKEN = "<s>"
    # Non-numeric tokens handed to the BPE decoder as literal strings.
    _PASSTHROUGH_TOKENS = frozenset({"<unk>", "<mask>"})

    def __init__(self, cfg):
        logger.info('Using the GPT2BPEEnhancedSpace.')
        super().__init__(cfg)

    @staticmethod
    def _space_mode() -> int:
        """Return the active ``INSERT_OR_REPLACE`` mode (0 or 1).

        Raises:
            ValueError: for any other value. Previously, encode/decode silently
                returned None in that case.
        """
        mode = INSERT_OR_REPLACE
        if mode not in (0, 1):
            raise ValueError(
                f"INSERT_OR_REPLACE must be 0 (insert) or 1 (replace), got {mode!r}"
            )
        return mode

    def _encode_chunk(self, text: str) -> str:
        """BPE-encode ``text`` and return its token ids joined by single spaces."""
        return ' '.join(map(str, self.bpe.encode(text)))

    def encode(self, x: str) -> str:
        """BPE-encode ``x``, emitting a "<s>" marker token for every space.

        Returns one space-separated string of BPE ids and "<s>" markers. In
        replace mode (1) a space is represented only by its marker. In insert
        mode (0) the space is also kept as the first character of the chunk
        that follows it.

        Raises:
            AssertionError: if ``x`` starts or ends with a space.
            ValueError: if ``INSERT_OR_REPLACE`` is neither 0 nor 1.
        """
        # only for sroie: inputs are expected to have no leading/trailing spaces
        assert not x.startswith(' ')
        assert not x.endswith(' ')
        # Prefix given to every chunk that follows a space.
        carry = ' ' if self._space_mode() == 0 else ''
        pieces = []
        for idx, segment in enumerate(x.split(' ')):
            if idx:
                # Every split point corresponds to exactly one space in ``x``.
                pieces.append(self._SPACE_TOKEN)
                segment = carry + segment
            # Encode words as soon as they are produced. A word whose text is
            # literally "<s>" is then BPE-encoded instead of being mistaken
            # for a space marker.
            if segment:
                pieces.append(self._encode_chunk(segment))
        return ' '.join(pieces)

    def decode(self, x: str) -> str:
        """Invert ``encode``: turn a space-separated token string back into text.

        Numeric tokens are BPE ids, and "<unk>"/"<mask>" are passed through as
        literal strings. Any other non-numeric token raises ``ValueError``
        from ``int()``. Each "<s>" marker becomes " " in replace mode (1) and
        nothing in insert mode (0), where the space already belongs to the
        next chunk.

        Raises:
            ValueError: if ``INSERT_OR_REPLACE`` is neither 0 nor 1, or a token
                is not an integer, "<s>", "<unk>" or "<mask>".
        """
        separator = ' ' if self._space_mode() == 1 else ''
        # Decode the runs between markers separately. Decoding everything and
        # then str.replace()-ing "<s>" would also rewrite any literal "<s>"
        # that was present in the original text.
        segments = [[]]
        for tok in x.split():
            if tok == self._SPACE_TOKEN:
                segments.append([])
            elif tok in self._PASSTHROUGH_TOKENS:
                segments[-1].append(tok)
            else:
                segments[-1].append(int(tok))
        return separator.join(self.bpe.decode(segment) for segment in segments)

    def is_beginning_of_word(self, x: str) -> bool:
        """Return True if ``x`` decodes to text that starts with a space.

        A lone "<s>" token decodes to " " in replace mode (True) and to "" in
        insert mode (False).
        """
        return self.decode(x).startswith(" ")