Delete model.py
model.py
DELETED
@@ -1,174 +0,0 @@
```python
import copy
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForMaskedLM, AutoConfig

from bertmodel import make_bert, make_bert_without_emb
from utils import ContraLoss


def load_pretrained_model():
    # model_checkpoint = "/home/ubuntu/work/zq/conoMLM/prot_bert/prot_bert"
    model_checkpoint = "/home/ubuntu/work/gecheng/conoGen_final/FinalCono/MLM/prot_bert_finetuned_model_mlm_best"
    config = AutoConfig.from_pretrained(model_checkpoint)
    # from_config builds the architecture only; the checkpoint's weights are NOT loaded here.
    model = AutoModelForMaskedLM.from_config(config)

    return model
```
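Note that `AutoModelForMaskedLM.from_config(config)` creates a freshly initialized model; it reads only the architecture from the checkpoint directory, not the fine-tuned weights. If reusing the fine-tuned MLM weights is the intent, `from_pretrained` is the usual call — a minimal sketch, assuming that directory contains saved weights alongside the config:

```python
from transformers import AutoModelForMaskedLM

# Same checkpoint directory as above; from_pretrained loads both config and weights.
model = AutoModelForMaskedLM.from_pretrained(
    "/home/ubuntu/work/gecheng/conoGen_final/FinalCono/MLM/prot_bert_finetuned_model_mlm_best"
)
```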
```python
class ConoEncoder(nn.Module):
    def __init__(self, encoder):
        super(ConoEncoder, self).__init__()

        self.encoder = encoder
        self.trainable_encoder = make_bert_without_emb()

        # Freeze the pretrained encoder; only trainable_encoder is updated.
        for param in self.encoder.parameters():
            param.requires_grad = False

    def forward(self, x, mask):  # x: (128, 54), mask: (128, 54)
        feat = self.encoder(x, attention_mask=mask)  # (128, 54, 128)
        feat = list(feat.values())[0]  # first tensor of the model output, (128, 54, 128)

        feat = self.trainable_encoder(feat, mask)  # (128, 54, 128)

        return feat


class MSABlock(nn.Module):
    def __init__(self, in_dim, out_dim, vocab_size):
        super(MSABlock, self).__init__()
        self.embedding = nn.Embedding(vocab_size, in_dim)
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.LeakyReLU(),
            nn.Linear(out_dim, out_dim)
        )
        self.init()

    def init(self):
        for layer in self.mlp.children():
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
        # nn.init.xavier_uniform_(self.embedding.weight)

    def forward(self, x):  # x: (128, 3, 54)
        x = self.embedding(x)  # (128, 3, 54, 128)
        x = self.mlp(x)  # (128, 3, 54, 128)
        return x
```
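`MSABlock` is self-contained, so its shape contract is easy to verify in isolation. A minimal sketch — the vocabulary size of 85 is inferred from the decoder-output comment in `ConoModel`, and the batch/depth/length sizes are the ones from the shape comments above:

```python
import torch

block = MSABlock(in_dim=128, out_dim=128, vocab_size=85)
msa = torch.randint(0, 85, (128, 3, 54))  # (batch, MSA depth, sequence length)
out = block(msa)
print(out.shape)  # torch.Size([128, 3, 54, 128])
```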
```python
class ConoModel(nn.Module):
    def __init__(self, encoder, msa_block, decoder):
        super(ConoModel, self).__init__()
        self.encoder = encoder
        self.msa_block = msa_block
        self.feature_combine = nn.Conv2d(in_channels=4, out_channels=1, kernel_size=1)
        self.decoder = decoder

    def forward(self, input_ids, msa, attn_idx=None):
        # Use only input_ids as input to obtain the encoder output.
        encoder_output = self.encoder.forward(input_ids, attn_idx)  # (128, 54, 128)
        msa_output = self.msa_block(msa)  # (128, 3, 54, 128)
        # msa_output = torch.mean(msa_output, dim=1)
        encoder_output = encoder_output.view(input_ids.shape[0], 54, -1).unsqueeze(1)  # (128, 1, 54, 128)

        # Stack the upweighted sequence features with the MSA features, then fuse channel-wise.
        output = torch.cat([encoder_output * 5, msa_output], dim=1)  # (128, 4, 54, 128)
        output = self.feature_combine(output)  # (128, 1, 54, 128)
        output = output.squeeze(1)  # (128, 54, 128)
        # The decoder decodes the encoder's fused output into per-position logits.
        logits = self.decoder(output)  # (128, 54, 85)

        return logits
```
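The fusion wiring in `ConoModel.forward` can be exercised end to end with stand-in modules. `DummyEncoder` and the `nn.Linear` decoder below are illustrative placeholders, not the repository's real `make_bert` encoder or decoder:

```python
import torch
import torch.nn as nn

class DummyEncoder(nn.Module):
    """Stand-in that maps token ids to (batch, 54, 128) features."""
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(85, 128)

    def forward(self, input_ids, attn_idx=None):
        return self.emb(input_ids)

model = ConoModel(
    encoder=DummyEncoder(),
    msa_block=MSABlock(128, 128, vocab_size=85),
    decoder=nn.Linear(128, 85),
)
input_ids = torch.randint(0, 85, (128, 54))
msa = torch.randint(0, 85, (128, 3, 54))
logits = model(input_ids, msa)
print(logits.shape)  # torch.Size([128, 54, 85])
```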
```python
class ContraModel(nn.Module):
    def __init__(self, cono_encoder):
        super(ContraModel, self).__init__()

        self.contra_loss = ContraLoss()

        self.encoder1 = cono_encoder
        self.encoder2 = make_bert(404, 6, 128)

        # contrastive decoder
        self.lstm = nn.LSTM(16, 16, batch_first=True)
        self.contra_decoder = nn.Sequential(
            nn.Linear(128, 64),
            nn.LeakyReLU(),
            nn.Linear(64, 32),
            nn.LeakyReLU(),
            nn.Linear(32, 16),
            nn.LeakyReLU(),
            nn.Dropout(0.1),
        )

        # classifier
        self.pre_classifer = nn.LSTM(128, 64, batch_first=True)
        self.classifer = nn.Sequential(
            nn.Linear(128, 32),
            nn.LeakyReLU(),
            nn.Linear(32, 6),
            nn.Softmax(dim=-1)  # note: F.cross_entropy below applies log-softmax again
        )

        self.init()

    def init(self):
        for layer in self.contra_decoder.children():
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
        for layer in self.classifer.children():
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
        # nn.LSTM has no nn.Linear children, so the next two loops are no-ops.
        for layer in self.pre_classifer.children():
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
        for layer in self.lstm.children():
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)

    def compute_class_loss(self, feat1, feat2, labels):
        # Pool each sequence with an LSTM; concatenate the final hidden and cell states.
        _, cls_feat1 = self.pre_classifer(feat1)
        _, cls_feat2 = self.pre_classifer(feat2)
        cls_feat1 = torch.cat([cls_feat1[0], cls_feat1[1]], dim=-1).squeeze(0)
        cls_feat2 = torch.cat([cls_feat2[0], cls_feat2[1]], dim=-1).squeeze(0)

        cls1_dis = self.classifer(cls_feat1)
        cls2_dis = self.classifer(cls_feat2)
        cls1_loss = F.cross_entropy(cls1_dis, labels.to('cuda:0'))
        cls2_loss = F.cross_entropy(cls2_dis, labels.to('cuda:0'))

        return cls1_loss, cls2_loss

    def compute_contrastive_loss(self, feat1, feat2):
        # Project features to 16 dims, pool with the LSTM, then compare the two views.
        contra_feat1 = self.contra_decoder(feat1)
        contra_feat2 = self.contra_decoder(feat2)

        _, feat1 = self.lstm(contra_feat1)
        _, feat2 = self.lstm(contra_feat2)
        feat1 = torch.cat([feat1[0], feat1[1]], dim=-1).squeeze(0)
        feat2 = torch.cat([feat2[0], feat2[1]], dim=-1).squeeze(0)

        ctr_loss = self.contra_loss(feat1, feat2)

        return ctr_loss

    def forward(self, x1, x2, labels=None):
        loss = dict()

        idx1, attn1 = x1
        idx2, attn2 = x2
        feat1 = self.encoder1(idx1.to('cuda:0'), attn1.to('cuda:0'))
        feat2 = self.encoder2(idx2.to('cuda:0'), attn2.to('cuda:0'))

        cls1_loss, cls2_loss = self.compute_class_loss(feat1, feat2, labels)

        ctr_loss = self.compute_contrastive_loss(feat1, feat2)

        loss['cls1_loss'] = cls1_loss
        loss['cls2_loss'] = cls2_loss
        loss['ctr_loss'] = ctr_loss

        return loss
```
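`ContraLoss` is imported from `utils` and not shown in this diff. For readers without that file, a common formulation for paired 32-dim pooled views like the ones above is an InfoNCE-style objective; the sketch below is an assumption about what `ContraLoss` might compute, not the repository's actual implementation:

```python
import torch
import torch.nn.functional as F

def info_nce(feat1, feat2, temperature=0.1):
    """Hypothetical stand-in for utils.ContraLoss: row i of feat1 and row i
    of feat2 are a positive pair; all other rows in the batch are negatives."""
    z1 = F.normalize(feat1, dim=-1)
    z2 = F.normalize(feat2, dim=-1)
    logits = z1 @ z2.t() / temperature  # (batch, batch) cosine similarities
    targets = torch.arange(z1.size(0), device=z1.device)
    return F.cross_entropy(logits, targets)
```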