oucgc1996 committed
Commit 965e421 · verified · 1 Parent(s): 9ca4aa5

Delete model.py

Files changed (1)
  1. model.py +0 -174
model.py DELETED
@@ -1,174 +0,0 @@
- import torch.nn as nn
- import copy, math
- import torch
- import numpy as np
- import torch.nn.functional as F
- from transformers import AutoModelForMaskedLM, AutoConfig
-
- from bertmodel import make_bert, make_bert_without_emb
- from utils import ContraLoss
-
- def load_pretrained_model():
-     # Checkpoint directory of the MLM-fine-tuned ProtBERT model.
-     # model_checkpoint = "/home/ubuntu/work/zq/conoMLM/prot_bert/prot_bert"
-     model_checkpoint = "/home/ubuntu/work/gecheng/conoGen_final/FinalCono/MLM/prot_bert_finetuned_model_mlm_best"
-     config = AutoConfig.from_pretrained(model_checkpoint)
-     model = AutoModelForMaskedLM.from_config(config)
-
-     return model
-
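- # Note: AutoModelForMaskedLM.from_config builds the architecture with randomly
- # initialized weights; it does not load the fine-tuned parameters. If loading
- # the checkpoint weights is intended, the standard transformers call would be:
- #   model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)
-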
- class ConoEncoder(nn.Module):
-     def __init__(self, encoder):
-         super(ConoEncoder, self).__init__()
-
-         self.encoder = encoder  # pretrained encoder, frozen below
-         self.trainable_encoder = make_bert_without_emb()  # trainable BERT stack without embeddings
-
-         # Freeze the pretrained encoder; only trainable_encoder receives gradients.
-         for param in self.encoder.parameters():
-             param.requires_grad = False
-
-     def forward(self, x, mask):  # x: (128,54)  mask: (128,54)
-         feat = self.encoder(x, attention_mask=mask)
-         feat = list(feat.values())[0]  # first output field, the hidden states: (128,54,128)
-
-         feat = self.trainable_encoder(feat, mask)  # (128,54,128)
-
-         return feat
-
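- # Usage sketch (illustrative). The `.bert` attribute is an assumption about the
- # checkpoint architecture; shapes follow the comments above.
- #   encoder = ConoEncoder(load_pretrained_model().bert)
- #   feat = encoder(input_ids, attention_mask)  # (128, 54, 128)
-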
- class MSABlock(nn.Module):
-     def __init__(self, in_dim, out_dim, vocab_size):
-         super(MSABlock, self).__init__()
-         self.embedding = nn.Embedding(vocab_size, in_dim)
-         self.mlp = nn.Sequential(
-             nn.Linear(in_dim, out_dim),
-             nn.LeakyReLU(),
-             nn.Linear(out_dim, out_dim)
-         )
-         self.init()
-
-     def init(self):
-         # Xavier-initialize the weights of the MLP's linear layers.
-         for layer in self.mlp.children():
-             if isinstance(layer, nn.Linear):
-                 nn.init.xavier_uniform_(layer.weight)
-         # nn.init.xavier_uniform_(self.embedding.weight)
-
-     def forward(self, x):  # x: (128,3,54)
-         x = self.embedding(x)  # (128,3,54,128)
-         x = self.mlp(x)  # (128,3,54,128)
-         return x
-
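- # Shape check (illustrative; the vocabulary size 85 is a guess based on the
- # decoder output shape noted below):
- #   block = MSABlock(128, 128, 85)
- #   block(torch.zeros(128, 3, 54, dtype=torch.long)).shape  # (128, 3, 54, 128)
-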
- class ConoModel(nn.Module):
-     def __init__(self, encoder, msa_block, decoder):
-         super(ConoModel, self).__init__()
-         self.encoder = encoder
-         self.msa_block = msa_block
-         # 1x1 convolution fusing the 4 stacked feature maps into one.
-         self.feature_combine = nn.Conv2d(in_channels=4, out_channels=1, kernel_size=1)
-         self.decoder = decoder
-
-     def forward(self, input_ids, msa, attn_idx=None):
-         # Use only input_ids as input to obtain the encoder output.
-         encoder_output = self.encoder(input_ids, attn_idx)  # (128,54,128)
-         msa_output = self.msa_block(msa)  # (128,3,54,128)
-         # msa_output = torch.mean(msa_output, dim=1)
-         encoder_output = encoder_output.view(input_ids.shape[0], 54, -1).unsqueeze(1)  # (128,1,54,128)
-
-         # Upweight the encoder features (x5) and stack them with the MSA features.
-         output = torch.cat([encoder_output * 5, msa_output], dim=1)  # (128,4,54,128)
-         output = self.feature_combine(output)  # (128,1,54,128)
-         output = output.squeeze(1)  # (128,54,128)
-         # The decoder decodes the fused encoder output into per-position logits.
-         logits = self.decoder(output)  # (128,54,85)
-
-         return logits
-
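- # Assembly sketch (illustrative): any decoder mapping 128-dim features to the
- # 85 output logits fits; the linear head here is a hypothetical stand-in.
- #   model = ConoModel(
- #       encoder=ConoEncoder(load_pretrained_model().bert),
- #       msa_block=MSABlock(128, 128, 85),
- #       decoder=nn.Linear(128, 85),
- #   )
- #   logits = model(input_ids, msa, attn_idx)  # (128, 54, 85)
-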
- class ContraModel(nn.Module):
-     def __init__(self, cono_encoder):
-         super(ContraModel, self).__init__()
-
-         self.contra_loss = ContraLoss()
-
-         self.encoder1 = cono_encoder
-         self.encoder2 = make_bert(404, 6, 128)
-
-         # contrastive decoder
-         self.lstm = nn.LSTM(16, 16, batch_first=True)
-         self.contra_decoder = nn.Sequential(
-             nn.Linear(128, 64),
-             nn.LeakyReLU(),
-             nn.Linear(64, 32),
-             nn.LeakyReLU(),
-             nn.Linear(32, 16),
-             nn.LeakyReLU(),
-             nn.Dropout(0.1),
-         )
-
-         # classifier; emits raw logits, since F.cross_entropy below applies
-         # log-softmax internally (a trailing nn.Softmax would be applied twice)
-         self.pre_classifer = nn.LSTM(128, 64, batch_first=True)
-         self.classifer = nn.Sequential(
-             nn.Linear(128, 32),
-             nn.LeakyReLU(),
-             nn.Linear(32, 6),
-         )
-
-         self.init()
-
-     def init(self):
-         # Xavier-initialize the weights of all Linear layers.
-         for layer in self.contra_decoder.children():
-             if isinstance(layer, nn.Linear):
-                 nn.init.xavier_uniform_(layer.weight)
-         for layer in self.classifer.children():
-             if isinstance(layer, nn.Linear):
-                 nn.init.xavier_uniform_(layer.weight)
-         # nn.LSTM exposes no Linear children, so initialize its weight
-         # matrices directly.
-         for name, param in self.pre_classifer.named_parameters():
-             if 'weight' in name:
-                 nn.init.xavier_uniform_(param)
-         for name, param in self.lstm.named_parameters():
-             if 'weight' in name:
-                 nn.init.xavier_uniform_(param)
-
-     def compute_class_loss(self, feat1, feat2, labels):
-         # Use the LSTM's final hidden and cell states as a pooled representation.
-         _, cls_feat1 = self.pre_classifer(feat1)
-         _, cls_feat2 = self.pre_classifer(feat2)
-         cls_feat1 = torch.cat([cls_feat1[0], cls_feat1[1]], dim=-1).squeeze(0)  # (128, 128)
-         cls_feat2 = torch.cat([cls_feat2[0], cls_feat2[1]], dim=-1).squeeze(0)  # (128, 128)
-
-         cls1_dis = self.classifer(cls_feat1)
-         cls2_dis = self.classifer(cls_feat2)
-         cls1_loss = F.cross_entropy(cls1_dis, labels.to('cuda:0'))
-         cls2_loss = F.cross_entropy(cls2_dis, labels.to('cuda:0'))
-
-         return cls1_loss, cls2_loss
-
-     def compute_contrastive_loss(self, feat1, feat2):
-         contra_feat1 = self.contra_decoder(feat1)
-         contra_feat2 = self.contra_decoder(feat2)
-
-         # Pool each projected sequence via the LSTM's final hidden/cell states.
-         _, feat1 = self.lstm(contra_feat1)
-         _, feat2 = self.lstm(contra_feat2)
-         feat1 = torch.cat([feat1[0], feat1[1]], dim=-1).squeeze(0)  # (128, 32)
-         feat2 = torch.cat([feat2[0], feat2[1]], dim=-1).squeeze(0)  # (128, 32)
-
-         ctr_loss = self.contra_loss(feat1, feat2)
-
-         return ctr_loss
-
-     def forward(self, x1, x2, labels=None):
-         loss = dict()
-
-         idx1, attn1 = x1
-         idx2, attn2 = x2
-         feat1 = self.encoder1(idx1.to('cuda:0'), attn1.to('cuda:0'))
-         feat2 = self.encoder2(idx2.to('cuda:0'), attn2.to('cuda:0'))
-
-         cls1_loss, cls2_loss = self.compute_class_loss(feat1, feat2, labels)
-
-         ctr_loss = self.compute_contrastive_loss(feat1, feat2)
-
-         loss['cls1_loss'] = cls1_loss
-         loss['cls2_loss'] = cls2_loss
-         loss['ctr_loss'] = ctr_loss
-
-         return loss
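
For reference, a minimal training-step sketch for the deleted ContraModel, assuming cono_encoder is a ConoEncoder instance, CUDA is available, and the three losses are summed with equal weight (the batch tensors idx1/attn1, idx2/attn2, and labels are illustrative):

    model = ContraModel(cono_encoder).to('cuda:0')
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)
    losses = model((idx1, attn1), (idx2, attn2), labels)
    total = losses['cls1_loss'] + losses['cls2_loss'] + losses['ctr_loss']
    optimizer.zero_grad()
    total.backward()
    optimizer.step()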