import copy
import math
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from transformers import set_seed

from vocab import PepVocab

def create_vocab():
    """Build the peptide vocabulary for the masked-language model from vocab.txt."""
    vocab_mlm = PepVocab()
    vocab_mlm.vocab_from_txt('vocab.txt')
    # vocab_mlm.token_to_idx['-'] = 23
    return vocab_mlm

def show_parameters(model: nn.Module, show_all=False, show_trainable=True):
    """Print parameter names and whether each one requires gradients."""
    params = {name: param.requires_grad for name, param in model.named_parameters()}

    if show_all:
        print('All parameters:')
        print(params)

    if show_trainable:
        print('Trainable parameters:')
        print([item for item in params.items() if item[1]])
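
def _show_parameters_example():
    # Hedged usage sketch on a throwaway model (not part of the project):
    # freeze one parameter, then list what is still trainable.
    model = nn.Linear(4, 2)
    model.bias.requires_grad = False
    show_parameters(model, show_all=True)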

class ContraLoss(nn.Module):
    """InfoNCE-style contrastive loss between two batches of projections."""

    def __init__(self, *args, **kwargs) -> None:
        super(ContraLoss, self).__init__(*args, **kwargs)
        self.temp = 0.07  # softmax temperature

    def contrastive_loss(self, proj1, proj2):
        # Cosine-similarity matrix between the two (batch, dim) projections.
        proj1 = F.normalize(proj1, dim=1)
        proj2 = F.normalize(proj2, dim=1)
        dot = torch.matmul(proj1, proj2.T) / self.temp

        # Subtract the per-row max for numerical stability of the softmax.
        dot_max, _ = torch.max(dot, dim=1, keepdim=True)
        dot = dot - dot_max.detach()

        # Log-probability of the matching (diagonal) pair in each row.
        exp_dot = torch.exp(dot)
        log_prob = torch.diag(dot, 0) - torch.log(exp_dot.sum(1))
        return -log_prob.mean()

    def forward(self, x, y, label=None):
        return self.contrastive_loss(x, y)
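
def _contraloss_example():
    # Hedged usage sketch: batch size and projection dim are illustrative
    # assumptions. Paired rows of the two projections are the positive pairs.
    proj_a = torch.randn(8, 128)
    proj_b = torch.randn(8, 128)
    loss = ContraLoss()(proj_a, proj_b)
    print(f'[INFO] example contrastive loss: {loss.item():.4f}')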


def extract_args(text):
    """Split `text` into bare tokens, dropping (), =, commas, spaces and quotes."""
    str_list = []
    substr = ""
    for s in text:
        if s in ('(', ')', '=', ',', ' ', '\n', "'"):
            if substr != '':
                str_list.append(substr)
                substr = ''
        else:
            substr += s
    # Flush the trailing token.
    if substr != '':
        str_list.append(substr)
    return str_list
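
def _extract_args_example():
    # Hedged usage sketch: the call string below is made up for illustration.
    print(extract_args("f(a=1, b='x')"))  # -> ['f', 'a', '1', 'b', 'x']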

def eval_one_epoch(loader, cono_encoder):
    """Run one evaluation epoch and return the mean contrastive loss."""
    cono_encoder.eval()
    batch_loss = []
    with torch.no_grad():
        for i, data in enumerate(tqdm(loader)):
            loss = cono_encoder.contra_forward(data)
            batch_loss.append(loss.item())
            print(f'[INFO] Test batch {i} loss: {loss.item()}')

    total_loss = np.mean(batch_loss)
    print(f'[INFO] Total loss: {total_loss}')
    return total_loss

def setup_seed(seed):
    """Seed Python, NumPy, PyTorch and transformers RNGs for reproducibility."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    set_seed(seed)

class CrossEntropyLossWithMask(torch.nn.Module):
    """Cross entropy averaged separately over three mask regions, then combined."""

    def __init__(self, weight=None):
        super(CrossEntropyLossWithMask, self).__init__()
        self.criterion = nn.CrossEntropyLoss(weight=weight, reduction='none')

    def forward(self, y_pred, y_true, mask):
        (pos_mask, label_mask, seq_mask) = mask
        loss = self.criterion(y_pred, y_true)  # per-token loss, e.g. shape (6912,)

        # Mean loss inside each mask region.
        pos_loss = (loss * pos_mask).sum() / torch.sum(pos_mask)
        label_loss = (loss * label_mask).sum() / torch.sum(label_mask)
        seq_loss = (loss * seq_mask).sum() / torch.sum(seq_mask)

        # Weighted combination: position > label > sequence.
        return pos_loss + label_loss / 2 + seq_loss / 3
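
def _masked_ce_example():
    # Hedged usage sketch: token count and vocab size are illustrative
    # assumptions. Each mask is a 0/1 float tensor selecting one loss region.
    n_tokens, n_classes = 12, 20
    y_pred = torch.randn(n_tokens, n_classes)
    y_true = torch.randint(0, n_classes, (n_tokens,))
    pos_mask = torch.zeros(n_tokens)
    label_mask = torch.zeros(n_tokens)
    seq_mask = torch.zeros(n_tokens)
    pos_mask[:4] = 1.0
    label_mask[4:8] = 1.0
    seq_mask[8:] = 1.0
    loss = CrossEntropyLossWithMask()(y_pred, y_true, (pos_mask, label_mask, seq_mask))
    print(f'[INFO] example masked CE loss: {loss.item():.4f}')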


def mask(x, start, end, time):
    """Randomly replace `time` tokens of x[start:end+1] with '[MASK]'.

    With probability 0.5 the label positions (absolute indices 1 and 2) are
    heavily up-weighted so they are almost always chosen.
    """
    # Positions of 'C' tokens relative to `start` (computed but currently unused).
    ske_pos = np.where(np.array(x) == 'C')[0] - start
    labels_pos = np.array([1, 2]) - start
    ske_pos = list(filter(lambda p: end - start >= p >= 0, ske_pos))
    labels_pos = list(filter(lambda p: p >= 0, labels_pos))

    # Sampling weights over the window [start, end].
    weight = np.ones(end - start + 1)
    if np.random.rand() < 0.5:
        weight[labels_pos] = 100000  # make label positions near-certain picks
    else:
        weight[labels_pos] = 1
    mask_pos = np.random.choice(range(start, end + 1), time,
                                p=weight / np.sum(weight), replace=False)
    for idx in mask_pos:
        x[idx] = '[MASK]'
    return x
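
def _mask_example():
    # Hedged usage sketch: the token list is made up for illustration; this
    # masks time=2 positions inside the window [start, end] = [2, 7].
    tokens = ['[CLS]', 'X', 'A', 'C', 'G', 'C', 'T', 'A', '[SEP]']
    print(mask(list(tokens), start=2, end=7, time=2))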