oucgc1996 committed
Commit e13ac09 · verified
1 Parent(s): 2d573ef

Upload 5 files

Files changed (5)
  1. bertmodel.py +171 -0
  2. dataset_mlm.py +151 -0
  3. model.py +171 -0
  4. utils.py +132 -0
  5. vocab.py +193 -0
bertmodel.py ADDED
@@ -0,0 +1,171 @@
+ import torch.nn as nn
+ import copy, math
+ import torch
+ import numpy as np
+ import torch.nn.functional as F
+
+ class Bert(nn.Module):
+
+     def __init__(self, encoder, src_embed):
+         super(Bert, self).__init__()
+
+         self.encoder = encoder
+         self.src_embed = src_embed
+
+     def forward(self, src, src_mask):
+
+         return self.encoder(self.src_embed(src), src_mask)
+
+
+ class Encoder(nn.Module):
+     def __init__(self, layer, N):
+         super(Encoder, self).__init__()
+         self.layers = clones(layer, N)
+         self.norm = LayerNorm(layer.size)
+
+     def forward(self, x, mask):
+         for layer in self.layers:
+             x = layer(x, mask)
+         return self.norm(x)
+
+ class LayerNorm(nn.Module):
+     def __init__(self, features, eps=1e-6):
+         super(LayerNorm, self).__init__()
+         self.a_2 = nn.Parameter(torch.ones(features))
+         self.b_2 = nn.Parameter(torch.zeros(features))
+         self.eps = eps
+
+     def forward(self, x):
+         mean = x.mean(-1, keepdim=True)
+         std = x.std(-1, keepdim=True)
+         return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
+
+ class SublayerConnection(nn.Module):
+     def __init__(self, size, dropout):
+         super(SublayerConnection, self).__init__()
+         self.norm = LayerNorm(size)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x, sublayer):
+         return x + self.dropout(sublayer(self.norm(x)))
+
+ class EncoderLayer(nn.Module):
+     def __init__(self, size, self_attn, feed_forward, dropout):
+         super(EncoderLayer, self).__init__()
+         self.self_attn = self_attn
+         self.feed_forward = feed_forward
+         self.sublayer = clones(SublayerConnection(size, dropout), 2)
+         self.size = size
+
+     def forward(self, x, mask):
+         x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
+         return self.sublayer[1](x, self.feed_forward)
+
+ class PositionwiseFeedForward(nn.Module):
+     def __init__(self, d_model, d_ff, dropout=0.1):
+         super(PositionwiseFeedForward, self).__init__()
+         self.w_1 = nn.Linear(d_model, d_ff)
+         self.w_2 = nn.Linear(d_ff, d_model)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         return self.w_2(self.dropout(F.relu(self.w_1(x))))
+
+ def make_bert(src_vocab, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1):
+     c = copy.deepcopy
+     attn = MultiHeadedAttention(h, d_model)
+     ff = PositionwiseFeedForward(d_model, d_ff, dropout)
+     position = PositionalEncoding(d_model, dropout)
+     model = Bert(
+         Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
+
+         nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
+     )
+
+     for p in model.parameters():
+         if p.dim() > 1:
+             nn.init.xavier_uniform_(p)
+     return model
+
+ def make_bert_without_emb(d_model=128, N=2, d_ff=512, h=8, dropout=0.1):
+     c = copy.deepcopy
+     attn = MultiHeadedAttention(h, d_model)
+     ff = PositionwiseFeedForward(d_model, d_ff, dropout)
+     trainable_encoder = Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N)
+
+     return trainable_encoder
+
+
+
+ def clones(module, N):
+     return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
+
+ def subsequent_mask(size):
+     attn_shape = (1, size, size)
+     subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
+     return torch.from_numpy(subsequent_mask) == 0
+
+ def attention(query, key, value, mask=None, dropout=None):
+     d_k = query.size(-1)
+     scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
+     if mask is not None:
+         mask = mask.unsqueeze(-2)
+         scores = scores.masked_fill(mask == 0, -1e9)
+     p_attn = F.softmax(scores, dim=-1)
+     if dropout is not None:
+         p_attn = dropout(p_attn)
+     return torch.matmul(p_attn, value), p_attn
+
+ class MultiHeadedAttention(nn.Module):
+     def __init__(self, h, d_model, dropout=0.1):
+         super(MultiHeadedAttention, self).__init__()
+         assert d_model % h == 0
+         self.d_k = d_model // h
+         self.h = h
+         self.linears = clones(nn.Linear(d_model, d_model), 4)
+         self.attn = None
+         self.dropout = nn.Dropout(p=dropout)
+
+     def forward(self, query, key, value, mask=None):
+         if mask is not None:
+             mask = mask.unsqueeze(1)
+         nbatches = query.size(0)
+
+         query, key, value = \
+             [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
+              for l, x in zip(self.linears, (query, key, value))]
+
+         x, self.attn = attention(query, key, value, mask=mask,
+                                  dropout=self.dropout)
+
+         x = x.transpose(1, 2).contiguous() \
+              .view(nbatches, -1, self.h * self.d_k)
+         return self.linears[-1](x)
+
+ class Embeddings(nn.Module):
+     def __init__(self, d_model, vocab):
+         super(Embeddings, self).__init__()
+         self.lut = nn.Embedding(vocab, d_model)
+         self.d_model = d_model
+
+     def forward(self, x):
+         return self.lut(x) * math.sqrt(self.d_model)
+
+ class PositionalEncoding(nn.Module):
+     def __init__(self, d_model, dropout, max_len=5000):
+         super(PositionalEncoding, self).__init__()
+         self.dropout = nn.Dropout(p=dropout)
+
+         pe = torch.zeros(max_len, d_model)
+         position = torch.arange(0, max_len).unsqueeze(1)
+         div_term = torch.exp(torch.arange(0, d_model, 2) *
+                              -(math.log(10000.0) / d_model))
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         pe = pe.unsqueeze(0)
+         self.register_buffer('pe', pe)
+
+     def forward(self, x):
+         x = x + self.pe[:, :x.size(1)].clone().detach()
+         return self.dropout(x)
+
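A quick way to sanity-check bertmodel.py is to push random token ids through make_bert; the sketch below is illustrative only (batch 4, length 54, vocab 100 are assumptions, not values from this commit):

import torch
from bertmodel import make_bert, make_bert_without_emb

model = make_bert(src_vocab=100, N=2, d_model=128, d_ff=512, h=8, dropout=0.1)
tokens = torch.randint(0, 100, (4, 54))               # (batch, seq_len) token ids
attn = torch.ones(4, 54, dtype=torch.long)            # 1 = real token, 0 = padding
features = model(tokens, attn)                        # -> (4, 54, 128) contextual features

encoder_only = make_bert_without_emb()                # encoder over pre-computed embeddings
feats = encoder_only(torch.randn(4, 54, 128), attn)   # -> (4, 54, 128)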
dataset_mlm.py ADDED
@@ -0,0 +1,151 @@
+ import pandas as pd
+ from copy import deepcopy
+
+ import torch
+ from torch.utils.data import TensorDataset, DataLoader
+ from sklearn.model_selection import train_test_split
+
+ from vocab import PepVocab
+ from utils import mask, create_vocab
+
+ addtition_tokens = ['<K16>', '<α1β1γδ>', '<Ca22>', '<AChBP>', '<K13>', '<α1BAR>', '<α1β1ε>', '<α1AAR>', '<GluN3A>', '<α4β2>',
+                     '<GluN2B>', '<α75HT3>', '<Na14>', '<α7>', '<GluN2C>', '<NET>', '<NavBh>', '<α6β3β4>', '<Na11>', '<Ca13>',
+                     '<Ca12>', '<Na16>', '<α6α3β2>', '<GluN2A>', '<GluN2D>', '<K17>', '<α1β1δε>', '<GABA>', '<α9>', '<K12>',
+                     '<Kshaker>', '<α3β4>', '<Na18>', '<α3β2>', '<α6α3β2β3>', '<α1β1δ>', '<α6α3β4β3>', '<α2β2>', '<α6β4>', '<α2β4>',
+                     '<Na13>', '<Na12>', '<Na15>', '<α4β4>', '<α7α6β2>', '<α1β1γ>', '<NaTTXR>', '<K11>', '<Ca23>',
+                     '<α9α10>', '<α6α3β4>', '<NaTTXS>', '<Na17>', '<high>', '<low>']
+
+ def add_tokens_to_vocab(vocab_mlm: PepVocab):
+     vocab_mlm.add_special_token(addtition_tokens)
+     return vocab_mlm
+
+ def split_seq(seq, vocab, get_seq=False):
+     '''
+     note: this function expects sequences in the format "label|label|sequence|msa1|msa2|msa3"
+     '''
+     start = '[CLS]'
+     end = '[SEP]'
+     pad = '[PAD]'
+     cls_label = seq.split('|')[0]
+     act_label = seq.split('|')[1]
+
+     if get_seq:
+         add = lambda x: [start] + [cls_label] + [act_label] + x + [end]
+         pep_seq = seq.split('|')[2]
+         # return [start] + [cls_label] + [act_label] + vocab.split_seq(pep_seq) + [end]
+         return add(vocab.split_seq(pep_seq))
+
+     else:
+         add = lambda x: [start] + [pad] + [pad] + x + [end]
+         msa1_seq = seq.split('|')[3]
+         msa2_seq = seq.split('|')[4]
+         msa3_seq = seq.split('|')[5]
+
+         # return [vocab.split_seq(msa1_seq)] + [vocab.split_seq(msa2_seq)] + [vocab.split_seq(msa3_seq)]
+         return [add(vocab.split_seq(msa1_seq))] + [add(vocab.split_seq(msa2_seq))] + [add(vocab.split_seq(msa3_seq))]
+
+ def get_paded_token_idx(vocab_mlm):
+     cono_path = 'conoData_C5.csv'
+     seq = pd.read_csv(cono_path)['Sequences']
+
+     splited_seq = list(seq.apply(split_seq, args=(vocab_mlm, True, )))
+     splited_msa = list(seq.apply(split_seq, args=(vocab_mlm, False, )))
+
+     vocab_mlm.set_get_attn(is_get=True)
+     padded_seq = vocab_mlm.truncate_pad(splited_seq, num_steps=54, padding_token='[PAD]')
+     attn_idx = vocab_mlm.get_attention_mask_mat()
+
+     vocab_mlm.set_get_attn(is_get=False)
+     padded_msa = vocab_mlm.truncate_pad(splited_msa, num_steps=54, padding_token='[PAD]')
+
+     idx_seq = vocab_mlm.__getitem__(padded_seq)  # [b, 54] start, cls_label, act_label, sequence, end
+
+     idx_msa = vocab_mlm.__getitem__(padded_msa)  # [b, 3, 50]
+
+     return padded_seq, idx_seq, idx_msa, attn_idx
+
+ def get_paded_token_idx_gen(vocab_mlm, seq):
+
+     splited_seq = split_seq(seq[0], vocab_mlm, True)
+     splited_msa = split_seq(seq[0], vocab_mlm, False)
+
+     vocab_mlm.set_get_attn(is_get=True)
+     padded_seq = vocab_mlm.truncate_pad(splited_seq, num_steps=54, padding_token='[PAD]')
+     attn_idx = vocab_mlm.get_attention_mask_mat()
+
+     vocab_mlm.set_get_attn(is_get=False)
+     padded_msa = vocab_mlm.truncate_pad(splited_msa, num_steps=54, padding_token='[PAD]')
+
+     idx_seq = vocab_mlm.__getitem__(padded_seq)  # [b, 54] start, cls_label, act_label, sequence, end
+
+     idx_msa = vocab_mlm.__getitem__(padded_msa)  # [b, 3, 50]
+
+     return padded_seq, idx_seq, idx_msa, attn_idx
+
+
+ def get_paded_token_idx_gen(vocab_mlm, seq, new_seq):
+     if new_seq is None:
+         splited_seq = split_seq(seq[0], vocab_mlm, True)
+         splited_msa = split_seq(seq[0], vocab_mlm, False)
+
+         vocab_mlm.set_get_attn(is_get=True)
+         padded_seq = vocab_mlm.truncate_pad(splited_seq, num_steps=54, padding_token='[PAD]')
+         attn_idx = vocab_mlm.get_attention_mask_mat()
+         vocab_mlm.set_get_attn(is_get=False)
+
+         padded_msa = vocab_mlm.truncate_pad(splited_msa, num_steps=54, padding_token='[PAD]')
+
+         idx_seq = vocab_mlm.__getitem__(padded_seq)  # [b, 54] start, cls_label, act_label, sequence, end
+         idx_msa = vocab_mlm.__getitem__(padded_msa)  # [b, 3, 50]
+     else:
+         splited_seq = split_seq(seq[0], vocab_mlm, True)
+         splited_msa = split_seq(seq[0], vocab_mlm, False)
+         vocab_mlm.set_get_attn(is_get=True)
+         padded_seq = vocab_mlm.truncate_pad(splited_seq, num_steps=54, padding_token='[PAD]')
+         attn_idx = vocab_mlm.get_attention_mask_mat()
+         vocab_mlm.set_get_attn(is_get=False)
+         padded_msa = vocab_mlm.truncate_pad(splited_msa, num_steps=54, padding_token='[PAD]')
+         idx_msa = vocab_mlm.__getitem__(padded_msa)  # [b, 3, 50]
+
+         idx_seq = vocab_mlm.__getitem__(new_seq)
+     return padded_seq, idx_seq, idx_msa, attn_idx
+
+
+
+ def make_mask(seq_ser, start, end, time, vocab_mlm, labels, idx_msa, attn_idx):
+     seq_ser = pd.Series(seq_ser)
+     masked_seq = seq_ser.apply(mask, args=(start, end, time))
+     masked_idx = vocab_mlm.__getitem__(list(masked_seq))
+     masked_idx = torch.tensor(masked_idx)
+     device = torch.device('cuda:1')
+     data_arrays = (masked_idx.to(device), labels.to(device), idx_msa.to(device), attn_idx.to(device))
+     dataset = TensorDataset(*data_arrays)
+     train_dataset, test_dataset = train_test_split(dataset, test_size=0.1, random_state=42, shuffle=True)
+     train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
+     test_loader = DataLoader(test_dataset, batch_size=128, shuffle=True)
+
+     return train_loader, test_loader
+
+ if __name__ == '__main__':
+     # from add_args import parse_args
+     import numpy as np
+     # args = parse_args()
+
+     vocab_mlm = create_vocab()
+     vocab_mlm = add_tokens_to_vocab(vocab_mlm)
+     padded_seq, idx_seq, idx_msa, attn_idx = get_paded_token_idx(vocab_mlm)
+     labels = torch.tensor(idx_seq)
+     idx_msa = torch.tensor(idx_msa)
+     attn_idx = torch.tensor(attn_idx)
+
+     # time_step = args.mask_time_step
+     for t in np.arange(1, 50):
+         padded_seq_copy = deepcopy(padded_seq)
+         train_loader, test_loader = make_mask(padded_seq_copy, start=0, end=49, time=t,
+                                               vocab_mlm=vocab_mlm, labels=labels, idx_msa=idx_msa, attn_idx=attn_idx)
+         for i, (masked_idx, label, msa, attn) in enumerate(train_loader):
+             print(f"the {i}th batch is that masked_idx is {masked_idx.shape}, labels is {label.shape}, idx_msa is {msa.shape}")
+         print(f"the {t}th time step is done")
+
+
+
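For orientation, here is a small illustrative walk-through of split_seq; the toy record below is made up (the real pipeline reads conoData_C5.csv and the vocab built from vocab.txt), and PepVocab's built-in residue table is used instead of create_vocab:

from vocab import PepVocab
from dataset_mlm import add_tokens_to_vocab, split_seq

vocab_mlm = add_tokens_to_vocab(PepVocab())
record = '<α7>|<high>|GCCSDPRCAWRC|GCCSDPRCAWRC|GCCSDPRCAWRC|GCCSDPRCAWRC'

tokens = split_seq(record, vocab_mlm, get_seq=True)
# ['[CLS]', '<α7>', '<high>', 'G', 'C', 'C', 'S', 'D', 'P', 'R', 'C', 'A', 'W', 'R', 'C', '[SEP]']
msa_tokens = split_seq(record, vocab_mlm, get_seq=False)
# three lists, each wrapped as ['[CLS]', '[PAD]', '[PAD]', ..., '[SEP]']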
model.py ADDED
@@ -0,0 +1,171 @@
+ import torch.nn as nn
+ import copy, math
+ import torch
+ import numpy as np
+ import torch.nn.functional as F
+ from transformers import AutoModelForMaskedLM, AutoConfig
+
+ from bertmodel import make_bert, make_bert_without_emb
+ from utils import ContraLoss
+
+ def load_pretrained_model():
+     model_checkpoint = "Rostlab/prot_bert"
+     config = AutoConfig.from_pretrained(model_checkpoint)
+     model = AutoModelForMaskedLM.from_config(config)
+
+     return model
+
+ class ConoEncoder(nn.Module):
+     def __init__(self, encoder):
+         super(ConoEncoder, self).__init__()
+
+         self.encoder = encoder
+         self.trainable_encoder = make_bert_without_emb()
+
+
+         for param in self.encoder.parameters():
+             param.requires_grad = False
+
+
+     def forward(self, x, mask):  # x:(128,54) mask:(128,54)
+         feat = self.encoder(x, attention_mask=mask)  # (128,54,128)
+         feat = list(feat.values())[0]  # (128,54,128)
+
+         feat = self.trainable_encoder(feat, mask)  # (128,54,128)
+
+         return feat
+
+ class MSABlock(nn.Module):
+     def __init__(self, in_dim, out_dim, vocab_size):
+         super(MSABlock, self).__init__()
+         self.embedding = nn.Embedding(vocab_size, in_dim)
+         self.mlp = nn.Sequential(
+             nn.Linear(in_dim, out_dim),
+             nn.LeakyReLU(),
+             nn.Linear(out_dim, out_dim)
+         )
+         self.init()
+
+     def init(self):
+         for layer in self.mlp.children():
+             if isinstance(layer, nn.Linear):
+                 nn.init.xavier_uniform_(layer.weight)
+         # nn.init.xavier_uniform_(self.embedding.weight)
+
+     def forward(self, x):  # x: (128,3,54)
+         x = self.embedding(x)  # x: (128,3,54,128)
+         x = self.mlp(x)  # x: (128,3,54,128)
+         return x
+
+ class ConoModel(nn.Module):
+     def __init__(self, encoder, msa_block, decoder):
+         super(ConoModel, self).__init__()
+         self.encoder = encoder
+         self.msa_block = msa_block
+         self.feature_combine = nn.Conv2d(in_channels=4, out_channels=1, kernel_size=1)
+         self.decoder = decoder
+
+     def forward(self, input_ids, msa, attn_idx=None):
+         encoder_output = self.encoder.forward(input_ids, attn_idx)  # (128,54,128)
+         msa_output = self.msa_block(msa)  # (128,3,54,128)
+         # msa_output = torch.mean(msa_output, dim=1)
+         encoder_output = encoder_output.view(input_ids.shape[0], 54, -1).unsqueeze(1)  # (128,1,54,128)
+
+         output = torch.cat([encoder_output*5, msa_output], dim=1)  # (128,4,54,128)
+         output = self.feature_combine(output)  # (128,1,54,128)
+         output = output.squeeze(1)  # (128,54,128)
+         logits = self.decoder(output)  # (128,54,85)
+
+         return logits
+
+ class ContraModel(nn.Module):
+     def __init__(self, cono_encoder):
+         super(ContraModel, self).__init__()
+
+         self.contra_loss = ContraLoss()
+
+         self.encoder1 = cono_encoder
+         self.encoder2 = make_bert(404, 6, 128)
+
+         # contrastive decoder
+         self.lstm = nn.LSTM(16, 16, batch_first=True)
+         self.contra_decoder = nn.Sequential(
+             nn.Linear(128, 64),
+             nn.LeakyReLU(),
+             nn.Linear(64, 32),
+             nn.LeakyReLU(),
+             nn.Linear(32, 16),
+             nn.LeakyReLU(),
+             nn.Dropout(0.1),
+         )
+
+         # classifier
+         self.pre_classifer = nn.LSTM(128, 64, batch_first=True)
+         self.classifer = nn.Sequential(
+             nn.Linear(128, 32),
+             nn.LeakyReLU(),
+             nn.Linear(32, 6),
+             nn.Softmax(dim=-1)
+         )
+
+         self.init()
+
+     def init(self):
+
+         for layer in self.contra_decoder.children():
+             if isinstance(layer, nn.Linear):
+                 nn.init.xavier_uniform_(layer.weight)
+         for layer in self.classifer.children():
+             if isinstance(layer, nn.Linear):
+                 nn.init.xavier_uniform_(layer.weight)
+         for layer in self.pre_classifer.children():
+             if isinstance(layer, nn.Linear):
+                 nn.init.xavier_uniform_(layer.weight)
+         for layer in self.lstm.children():
+             if isinstance(layer, nn.Linear):
+                 nn.init.xavier_uniform_(layer.weight)
+
+     def compute_class_loss(self, feat1, feat2, labels):
+         _, cls_feat1 = self.pre_classifer(feat1)
+         _, cls_feat2 = self.pre_classifer(feat2)
+         cls_feat1 = torch.cat([cls_feat1[0], cls_feat1[1]], dim=-1).squeeze(0)
+         cls_feat2 = torch.cat([cls_feat2[0], cls_feat2[1]], dim=-1).squeeze(0)
+
+         cls1_dis = self.classifer(cls_feat1)
+         cls2_dis = self.classifer(cls_feat2)
+         cls1_loss = F.cross_entropy(cls1_dis, labels.to('cuda:0'))
+         cls2_loss = F.cross_entropy(cls2_dis, labels.to('cuda:0'))
+
+         return cls1_loss, cls2_loss
+
+     def compute_contrastive_loss(self, feat1, feat2):
+
+         contra_feat1 = self.contra_decoder(feat1)
+         contra_feat2 = self.contra_decoder(feat2)
+
+         _, feat1 = self.lstm(contra_feat1)
+         _, feat2 = self.lstm(contra_feat2)
+         feat1 = torch.cat([feat1[0], feat1[1]], dim=-1).squeeze(0)
+         feat2 = torch.cat([feat2[0], feat2[1]], dim=-1).squeeze(0)
+
+         ctr_loss = self.contra_loss(feat1, feat2)
+
+         return ctr_loss
+
+     def forward(self, x1, x2, labels=None):
+         loss = dict()
+
+         idx1, attn1 = x1
+         idx2, attn2 = x2
+         feat1 = self.encoder1(idx1.to('cuda:0'), attn1.to('cuda:0'))
+         feat2 = self.encoder2(idx2.to('cuda:0'), attn2.to('cuda:0'))
+
+         cls1_loss, cls2_loss = self.compute_class_loss(feat1, feat2, labels)
+
+         ctr_loss = self.compute_contrastive_loss(feat1, feat2)
+
+         loss['cls1_loss'] = cls1_loss
+         loss['cls2_loss'] = cls2_loss
+         loss['ctr_loss'] = ctr_loss
+
+         return loss
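To see how ConoEncoder, MSABlock and ConoModel fit together, a hedged sketch follows. StubEncoder is a hypothetical stand-in for the frozen pretrained encoder (per the shape comments, something that returns a dict whose first value is a (batch, 54, 128) tensor); the vocab size of 85 is read off the (128,54,85) comment and is otherwise an assumption:

import torch
import torch.nn as nn
from model import ConoEncoder, MSABlock, ConoModel

class StubEncoder(nn.Module):
    """Hypothetical stand-in for the frozen pretrained encoder."""
    def __init__(self, vocab_size=85, d_model=128):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, d_model)
    def forward(self, x, attention_mask=None):
        return {'last_hidden_state': self.emb(x)}   # dict output, as ConoEncoder expects

vocab_size = 85
encoder = ConoEncoder(StubEncoder(vocab_size))
msa_block = MSABlock(in_dim=128, out_dim=128, vocab_size=vocab_size)
decoder = nn.Linear(128, vocab_size)
model = ConoModel(encoder, msa_block, decoder)

ids = torch.randint(0, vocab_size, (2, 54))        # (batch, 54) token ids
msa = torch.randint(0, vocab_size, (2, 3, 54))     # three aligned sequences per sample
attn = torch.ones(2, 54, dtype=torch.long)
logits = model(ids, msa, attn)                     # -> (2, 54, 85)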
utils.py ADDED
@@ -0,0 +1,132 @@
+ import torch.nn as nn
+ import copy, math
+ import torch
+ import numpy as np
+ import torch.nn.functional as F
+
+ from vocab import PepVocab
+
+ def create_vocab():
+     vocab_mlm = PepVocab()
+     vocab_mlm.vocab_from_txt('vocab.txt')
+     # vocab_mlm.token_to_idx['-'] = 23
+     return vocab_mlm
+
+ def show_parameters(model: nn.Module, show_all=False, show_trainable=True):
+
+     mlp_pa = {name: param.requires_grad for name, param in model.named_parameters()}
+
+     if show_all:
+         print('All parameters:')
+         print(mlp_pa)
+
+     if show_trainable:
+         print('Trainable parameters:')
+         print(list(filter(lambda x: x[1], list(mlp_pa.items()))))
+
+ class ContraLoss(nn.Module):
+     def __init__(self, *args, **kwargs) -> None:
+         super(ContraLoss, self).__init__(*args, **kwargs)
+
+         self.temp = 0.07
+
+     def contrastive_loss(self, proj1, proj2):
+         proj1 = F.normalize(proj1, dim=1)
+         proj2 = F.normalize(proj2, dim=1)
+         dot = torch.matmul(proj1, proj2.T) / self.temp
+         dot_max, _ = torch.max(dot, dim=1, keepdim=True)
+         dot = dot - dot_max.detach()
+
+         exp_dot = torch.exp(dot)
+         log_prob = torch.diag(dot, 0) - torch.log(exp_dot.sum(1))
+         cont_loss = -log_prob.mean()
+         return cont_loss
+
+     def forward(self, x, y, label=None):
+         return self.contrastive_loss(x, y)
+
+
+ import numpy as np
+ from tqdm import tqdm
+ import torch
+ import torch.nn as nn
+ import random
+ from transformers import set_seed
+
+ def show_parameters(model: nn.Module, show_all=False, show_trainable=True):
+
+     mlp_pa = {name: param.requires_grad for name, param in model.named_parameters()}
+
+     if show_all:
+         print('All parameters:')
+         print(mlp_pa)
+
+     if show_trainable:
+         print('Trainable parameters:')
+         print(list(filter(lambda x: x[1], list(mlp_pa.items()))))
+
+ def extract_args(text):
+     str_list = []
+     substr = ""
+     for s in text:
+         if s in ('(', ')', '=', ',', ' ', '\n', "'"):
+             if substr != '':
+                 str_list.append(substr)
+             substr = ''
+         else:
+             substr += s
+     # flush the trailing token and return the collected pieces
+     if substr != '':
+         str_list.append(substr)
+     return str_list
+
+ def eval_one_epoch(loader, cono_encoder):
+     cono_encoder.eval()
+     batch_loss = []
+     for i, data in enumerate(tqdm(loader)):
+
+         loss = cono_encoder.contra_forward(data)
+         batch_loss.append(loss.item())
+         print(f'[INFO] Test batch {i} loss: {loss.item()}')
+
+     total_loss = np.mean(batch_loss)
+     print(f'[INFO] Total loss: {total_loss}')
+     return total_loss
+
+ def setup_seed(seed):
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     np.random.seed(seed)
+     random.seed(seed)
+     torch.backends.cudnn.deterministic = True
+     set_seed(seed)
+
+ class CrossEntropyLossWithMask(torch.nn.Module):
+     def __init__(self, weight=None):
+         super(CrossEntropyLossWithMask, self).__init__()
+         self.criterion = nn.CrossEntropyLoss(reduction='none')
+
+     def forward(self, y_pred, y_true, mask):
+         (pos_mask, label_mask, seq_mask) = mask
+         loss = self.criterion(y_pred, y_true)  # (6912)
+
+         pos_loss = (loss * pos_mask).sum() / torch.sum(pos_mask)
+         label_loss = (loss * label_mask).sum() / torch.sum(label_mask)
+         seq_loss = (loss * seq_mask).sum() / torch.sum(seq_mask)
+
+         loss = pos_loss + label_loss/2 + seq_loss/3
+
+         return loss
+
+
+ def mask(x, start, end, time):
+     ske_pos = np.where(np.array(x) == 'C')[0] - start
+     lables_pos = np.array([1, 2]) - start
+     ske_pos = list(filter(lambda x: end-start >= x >= 0, ske_pos))
+     lables_pos = list(filter(lambda x: x >= 0, lables_pos))
+     weight = np.ones(end - start + 1)
+     rand = np.random.rand()
+     if rand < 0.5:
+         weight[lables_pos] = 100000
+     else:
+         weight[lables_pos] = 1
+     mask_pos = np.random.choice(range(start, end+1), time, p=weight/np.sum(weight), replace=False)
+     for idx in mask_pos:
+         x[idx] = '[MASK]'
+     return x
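The mask helper above is easiest to read with a toy example; the token list below is made up (positions 1-2 are where the label tokens sit in the padded sequences):

import numpy as np
from utils import mask, setup_seed

setup_seed(42)
toks = ['[CLS]', '<α7>', '<high>'] + list('GCCSDPRCAWRC') + ['[SEP]']
masked = mask(list(toks), start=0, end=len(toks) - 1, time=3)
print(masked)   # three positions become '[MASK]'; half the time the label slots 1-2 are heavily favoured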
vocab.py ADDED
@@ -0,0 +1,193 @@
+ import re
+ import pandas as pd
+
+ class PepVocab:
+     def __init__(self):
+         self.token_to_idx = {
+             '<MASK>': -1, '<PAD>': 0, 'A': 1, 'C': 2, 'E': 3, 'D': 4, 'F': 5, 'I': 6, 'H': 7,
+             'K': 8, 'M': 9, 'L': 10, 'N': 11, 'Q': 12, 'P': 13, 'S': 14,
+             'R': 15, 'T': 16, 'W': 17, 'V': 18, 'Y': 19, 'G': 20, 'O': 21, 'U': 22, 'Z': 23, 'X': 24}
+         self.idx_to_token = {
+             -1: '<MASK>', 0: '<PAD>', 1: 'A', 2: 'C', 3: 'E', 4: 'D', 5: 'F', 6: 'I', 7: 'H',
+             8: 'K', 9: 'M', 10: 'L', 11: 'N', 12: 'Q', 13: 'P', 14: 'S',
+             15: 'R', 16: 'T', 17: 'W', 18: 'V', 19: 'Y', 20: 'G', 21: 'O', 22: 'U', 23: 'Z', 24: 'X'}
+
+         self.get_attention_mask = False
+         self.attention_mask = []
+
+     def set_get_attn(self, is_get: bool):
+         self.get_attention_mask = is_get
+
+     def __len__(self):
+         return len(self.idx_to_token)
+
+     def __getitem__(self, tokens):
+         '''
+         note: the input should be a split sequence
+
+         Args:
+             tokens: a single token, or a (possibly nested) list of tokens
+         '''
+         if not isinstance(tokens, (list, tuple)):
+             # return self.token_to_idx.get(tokens)
+             return self.token_to_idx[tokens]
+         return [self.__getitem__(token) for token in tokens]
+
+     def vocab_from_txt(self, path):
+         '''
+         note: this function builds the vocab mapping, but it only supports a
+         specific txt format: a single-column txt file, read with the column
+         name 0
+         '''
+         token_to_idx = {}
+         idx_to_token = {}
+         chr_idx = pd.read_csv(path, header=None, sep='\t')
+         if chr_idx.shape[1] == 1:
+             for idx, token in enumerate(chr_idx[0]):
+                 token_to_idx[token] = idx
+                 idx_to_token[idx] = token
+             self.token_to_idx = token_to_idx
+             self.idx_to_token = idx_to_token
+
+     def to_tokens(self, indices):
+         '''
+         note: the input should be an integer list
+         '''
+         if hasattr(indices, '__len__') and len(indices) > 1:
+             return [self.idx_to_token[int(index)] for index in indices]
+         return self.idx_to_token[indices]
+
+     def add_special_token(self, token: str|list|tuple) -> None:
+         if not isinstance(token, (list, tuple)):
+             if token in self.token_to_idx:
+                 raise ValueError(f"token {token} already in the vocab")
+             self.idx_to_token[len(self.idx_to_token)] = token
+             self.token_to_idx[token] = len(self.token_to_idx)
+         else:
+             [self.add_special_token(t) for t in token]
+
+     def split_seq(self, seq: str|list|tuple) -> list:
+         if not isinstance(seq, (list, tuple)):
+             return re.findall(r"<[a-zA-Z0-9]+>|[a-zA-Z-]", seq)
+         return [self.split_seq(s) for s in seq]  # a list of lists
+
+     def truncate_pad(self, line, num_steps, padding_token='<PAD>') -> list:
+
+         if not isinstance(line[0], list):
+             if len(line) > num_steps:
+                 if self.get_attention_mask:
+                     self.attention_mask.append([1]*num_steps)
+                 return line[:num_steps]
+             if self.get_attention_mask:
+                 self.attention_mask.append([1] * len(line) + [0] * (num_steps - len(line)))
+             return line + [padding_token] * (num_steps - len(line))
+         else:
+             return [self.truncate_pad(l, num_steps, padding_token) for l in line]  # a list of lists
+
+     def get_attention_mask_mat(self):
+         attention_mask = self.attention_mask
+         self.attention_mask = []
+         return attention_mask
+
+     def seq_to_idx(self, seq: str|list|tuple, num_steps: int, padding_token='<PAD>') -> list:
+         '''
+         note: make sure to call this function after add_special_token
+         '''
+
+         splited_seq = self.split_seq(seq)
+         # **********************
+         # after splitting, we still need to mask the sequence
+         # note:
+         # 1. mask tokens by probability
+         # 2. return a list or a list of lists
+         # **********************
+         padded_seq = self.truncate_pad(splited_seq, num_steps, padding_token)
+
+         return self.__getitem__(padded_seq)
+
+
+
+ class MutilVocab:
+     def __init__(self, data, AA_tok_len=2):
+         """
+         Args:
+             data (_type_):
+             AA_tok_len (int, optional): Defaults to 2.
+             start_token (bool, optional): True is required for encoder-based models.
+         """
+         ## Load train dataset
+         self.x_data = data
+         self.tok_AA_len = AA_tok_len
+         self.default_AA = list("RHKDESTNQCGPAVILMFYW")
+         # AAs which are not included in default_AA
+         self.tokens = self._token_gen(self.tok_AA_len)
+
+         self.token_to_idx = {k: i + 4 for i, k in enumerate(self.tokens)}
+         self.token_to_idx["[PAD]"] = 0  ## idx 0 is PAD
+         self.token_to_idx["[CLS]"] = 1  ## idx 1 is CLS
+         self.token_to_idx["[SEP]"] = 2  ## idx 2 is SEP
+         self.token_to_idx["[MASK]"] = 3  ## idx 3 is MASK
+
+     def split_seq(self):
+         self.X = [self._seq_to_tok(seq) for seq in self.x_data]
+         return self.X
+
+     def tok_idx(self, seqs):
+         '''
+         note: make sure to call this function before truncate_pad
+         '''
+
+         seqs_idx = []
+         for seq in seqs:
+             seq_idx = []
+             for s in seq:
+                 seq_idx.append(self.token_to_idx[s])
+             seqs_idx.append(seq_idx)
+
+         return seqs_idx
+
+
+
+     def _token_gen(self, tok_AA_len: int, st: str = "", curr_depth: int = 0):
+         """Generate tokens based on the default amino acid residues.
+         The length of AAs in each token is given by "tok_AA_len".
+
+         Args:
+             tok_AA_len (int): Length of token
+             st (str, optional): Defaults to ''.
+             curr_depth (int, optional): Defaults to 0.
+
+         Returns:
+             List: List of tokens
+         """
+         curr_depth += 1
+         if curr_depth <= tok_AA_len:
+             l = [
+                 st + t
+                 for s in self.default_AA
+                 for t in self._token_gen(tok_AA_len, s, curr_depth)
+             ]
+             return l
+         else:
+             return [st]
+
+     def _seq_to_tok(self, seq: str):
+         """Split an AA sequence into overlapping tokens
+
+         Args:
+             seq (str): AA sequence
+
+         Returns:
+             list: A list of tokens, wrapped with [CLS] and [SEP]
+         """
+
+         seq_idx = []
+
+         seq_idx += ["[CLS]"]
+
+         for i in range(len(seq) - self.tok_AA_len + 1):
+             curr_token = seq[i : i + self.tok_AA_len]
+             seq_idx.append(curr_token)
+         seq_idx += ['[SEP]']
+         return seq_idx
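Finally, a short end-to-end example of PepVocab; it uses the built-in residue table rather than the vocab.txt file loaded by create_vocab, and the input string is illustrative:

from vocab import PepVocab

vocab = PepVocab()
vocab.add_special_token(['<high>', '<low>'])

tokens = vocab.split_seq('<high>ACDEFG')              # ['<high>', 'A', 'C', 'D', 'E', 'F', 'G']
padded = vocab.truncate_pad(tokens, num_steps=10, padding_token='<PAD>')
ids = vocab[padded]                                   # index lookup via __getitem__
print(padded)                                         # tokens padded to length 10 with '<PAD>'
print(ids)                                            # [26, 1, 2, 4, 3, 5, 20, 0, 0, 0]
print(vocab.to_tokens(ids))                           # back to the padded token list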