oucgc1996 commited on
Commit
2ef3fe3
·
verified ·
1 Parent(s): e6184b6

Upload 10 files

Browse files
Files changed (10) hide show
  1. C0_seq.csv +631 -0
  2. app.py +124 -0
  3. bertmodel.py +199 -0
  4. conoData_C5.csv +0 -0
  5. dataset_mlm.py +151 -0
  6. model.py +174 -0
  7. requirements.txt +4 -0
  8. utils.py +132 -0
  9. vocab.py +193 -0
  10. vocab.txt +30 -0
C0_seq.csv ADDED
@@ -0,0 +1,631 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Seq
2
+ GCCSHPACLVDHPEIC
3
+ RDPCCYHPTCNMANPQIC
4
+ GCCSDPRCAYDHPEIC
5
+ PACCTHPACHVNHPELC
6
+ GEDEYAEGIREYQLIHGKI
7
+ RDCQEKWEYCIVPILGFVYCCPGLICGPFVCV
8
+ GCCSHPACDVDHPEIC
9
+ RCCTGKKGSCSGRACKNLKCCA
10
+ GCCSHPACAGNNQHIC
11
+ TARSSGRYARSPYDRRRRYSRRITDASV
12
+ GCCSHPVCHARHPALC
13
+ TSRSSGRYSRSPYDRRRRYARRITDAAV
14
+ DECCSRPPCRVNNPHVCRRR
15
+ IRDACCSNPACRVNNPHVC
16
+ GCCSDPACNVNNPHIC
17
+ GCCSHPVCRARHPALC
18
+ GCCLHPACSVNHPELC
19
+ GCCSHPACNVNNPHICG
20
+ GEEEVAKMAAELARENIAKGCKVNCYP
21
+ LPSCCSLALRLCPVPACKRNPCCT
22
+ GCCSHPACSVRHPELC
23
+ CCNCSSKWCRDHSRCC
24
+ QCCSNPPCAHEHC
25
+ FNWRCCLIPACRRNHKKFC
26
+ GCCSDPRCLYDHPEIC
27
+ IRDECCSNPACRVNNPAVC
28
+ GCCSYPPCFATNPDC
29
+ ECCNPACGRHYSC
30
+ VRCLEKSGAQPNKLFRPPCCQKGPSFARHSRCVYYTQSRE
31
+ CKLKGQSCRKTSYDCCSGSCGRSGKC
32
+ GEEEYAEFI
33
+ GCCSNPPCAHEHC
34
+ CAIPNQKCFQHLDDCCSRKCNRFNKCV
35
+ GAGCCSHPVCAAMSPIC
36
+ DPCCYHPTCNMSNPQIC
37
+ CKPPGSKCSPSMYDCCTTCISYTKRCRKYY
38
+ CCNCSSKQCRDHSRCC
39
+ GEEELAEKAAEFARELAN
40
+ GCCSDPRCRYRC
41
+ GCCSHPACSVNHPNLC
42
+ TARSSGRYARSPYDRRRRYCRRITDACV
43
+ GCCSHPVCNVRHPEIC
44
+ KPCCSIHDNSCCGA
45
+ YKLCHPC
46
+ GCCSHPACNVNNPHIC
47
+ GACCHPACGKNYSC
48
+ GCCSHPVCYAMSPIC
49
+ GCCSHPACSVNAPELC
50
+ GCCSHPACNVAAPHIC
51
+ GCCSDPRCNYNHPEIC
52
+ RCCHPACGKKFNC
53
+ CRIPNQKCYQHLDDCCSRKCNRFNKCV
54
+ LPPCCTPPKKHCPAPACKYKPCCKS
55
+ GCCSHPACYVNHPELC
56
+ GGGPRIPNQKCFQHLDDCCSRKCNRFNKCVLPET
57
+ KPCCSIADNSCCGL
58
+ LPSCCSLNLRLCPVPACRKNPCCT
59
+ TCRSSGRYCRSPYDRRRRYARRITDAAV
60
+ KPCCSIHDNSCCGL
61
+ ARFLHPFQYYTLYRYLTRFLHRYPIYYIRY
62
+ GCCSNPVCHLAHSNLC
63
+ GCCSHPVCAAMSPIC
64
+ GCCSHPACSVNHPEAC
65
+ DECCSNPACRLNNPHDCRRR
66
+ IRDECCSNPQCRVNNPHVC
67
+ IRDECCSNPACRANNPHVC
68
+ GCCSHPACSVNHPEIC
69
+ GCCSLPPCALSNPDYC
70
+ DECCSNPACRLNNPHVCRRR
71
+ TCRSSGRYCRSPYDRRRRYSRRITDASV
72
+ GCCSDPRCAWRC
73
+ GCCSDPRCRYRCR
74
+ QNCCNGGCSSKWCRDHARCC
75
+ IRDECCSNPACRVNNPHYC
76
+ VTDRCCKGKRECGRWCRDHSRCC
77
+ GCCSHPACSANHPELC
78
+ GCCSHAACSVNHPELC
79
+ GRCCHPACGKNYSC
80
+ GCCSNAVCHLEHSNLC
81
+ RDCCTPPKKCKDRRCKPLKCCA
82
+ LHCHEISDLTPWILCSPEPLCGGKGCCAQEVCDCSGPACTCPPCL
83
+ GCCSDPRCRYRCK
84
+ GCCKDPRCNYDHPEIC
85
+ RDPCCSNPVCTVHNPQIC
86
+ DDESECIINTRDSPWGRCCRTRMCGSMCCPRNGCTCVYHWRRGHGCSCPG
87
+ RDCCTPPKKAKDRQCKPQRCAA
88
+ GCCSNPVCHLEHSNLC
89
+ GCCSNPVCALEHSNLC
90
+ GCCSHPACIVDHPEIC
91
+ IRDECCSNAACRVNNPHVC
92
+ GCCSHPACSVNHAELC
93
+ QGVCCGSKLCHPC
94
+ GEPEVAKWAEGLREKAASN
95
+ GRCCDVPNACSGRWCRDHAQCC
96
+ RDACCYHPTCNMSNPQIC
97
+ GCCSTPPCALAC
98
+ DMCCHPACGKHFNC
99
+ GCCTHPACHGNHPELC
100
+ QRLCCGFPKSCRSRQCKPHRCC
101
+ GCCSHPACSRNHPELC
102
+ NGRCCHPACGKHFSC
103
+ TSRSSGRYSRSPYDRRRRYSRRITDASV
104
+ DEPEYAEAIREYQLKYGKI
105
+ ACCSDRRCRWRC
106
+ GCCANPVCALEHSNLC
107
+ RPECCTHPACHVSHPELC
108
+ GCCSDPPCRNKHPDLCM
109
+ GCCSRPPCIANNPDLC
110
+ ATSGPMGWLPVFYRF
111
+ TYGIYDAKPPFSCAGLRGGCVLPPNLRPKFKE
112
+ IRDECCSNPACRVNNPHVC
113
+ ASGADTCCSNPACQVQHSDLC
114
+ CGYKLCHPC
115
+ PCQSVRPGRVWGKCCLTRLCSTMCCARADCTCVYHTWRGHGCSCVM
116
+ GFRSPCAPFC
117
+ RGCCNGRGGCSSRWCRDHARCC
118
+ SGCCSNPACRVDNPNIC
119
+ RDACTPPKKCKDRQAKPQRCCA
120
+ CKSPGSSCSPTSYNCCRSCNPYTKRCY
121
+ PCCYHPTCNMSNPQIC
122
+ QGVCCGYQLCHPC
123
+ RDPCCYHPTCNMSNAQIC
124
+ GCCSHPVCRARHRALC
125
+ GCCSHPACNVDAPEIC
126
+ ECCSNPACRVNNPHVC
127
+ PECCTHPACASHPELC
128
+ RCCHPACMNHFNC
129
+ GCCSNPVCHLEHSALC
130
+ GCCSHPVCSVNHPELC
131
+ GCCSHAACNVDHPEIC
132
+ GCCSDPRCGYDHPEIC
133
+ RDCQEKWEYCIVPIAGFVYCCPGLICGPFVCV
134
+ GCCSAPACSVNHPELC
135
+ QGVCCGKKLCHPC
136
+ DDEEYSEAI
137
+ PECCTHPACHVSHPELC
138
+ GCCSHPACSDNHPELC
139
+ QSPGCCWNPACVKNRC
140
+ GCCSLPPCRANNPDYC
141
+ GCCSEPRCRYRCR
142
+ TARSSGRYCRSPYDRRRRYCRRITDAAV
143
+ GEEEYSEFI
144
+ GCCANPVCHLAHSNAC
145
+ QGVCCGQKLCHPC
146
+ GCCSDPPCIANNPDLC
147
+ TSRSSGRYCRSPYDRRRRYCRRITDASV
148
+ GCCSHPACSVNHSELC
149
+ GCCRWPCPSRCGMARCCSS
150
+ GEDDLQDNQDLIRDKSN
151
+ CCSNPACQVQHSDLC
152
+ CRIPNQKCFQHLDDCCARKCNRFNKCV
153
+ GRCCHPACGKNWSC
154
+ CKSPGTPCSRGMRDCCTSCLLYSNKCRRY
155
+ IRDECCSAPACRVNNPHVC
156
+ CKRKGASCRRTSYDCCTGSCRNGKC
157
+ LASCCSLNLRLCPVPACKRNPCCT
158
+ GRCCAPACGKNYSC
159
+ GDEEYSKFIELARENIAKGCKVNCYP
160
+ QKCCSGGSCPLYFRDRLICPCC
161
+ GEEELAELAPEFARELAN
162
+ PACCTHPACHVSHPELC
163
+ ICCNPACGPHYSC
164
+ GCCSTPPCSVLYC
165
+ KACCSIHDNSCCGL
166
+ KCNFDKCKGTGVYNCGESCSCEGLHSCRCTYNIGSMKSGCACICTYY
167
+ GWCGDPGATCGKLRLYCCSGFCDCYTKTCKDKSSA
168
+ RDPCCYHPTCNMSNPQAC
169
+ GCCSHPVCFAMSPIC
170
+ GCCSNPVCHLAHSNAC
171
+ GEEECSEAI
172
+ RHGCCKGPKGCSSRECRPQHCC
173
+ GCCSTPPCAVLYC
174
+ ADEEYLKFIEEQRKQGKLDPTKFP
175
+ GPPCCLYGSCRPFPGCYNALCCRK
176
+ CKRKGSSCARTSYDCCTGSCRNGKC
177
+ IRDECCSNPVCRVNNPHVC
178
+ GCCSHPACSVNHPELC
179
+ VCCGYKLCHPC
180
+ ICCNPACGPNYSC
181
+ GCCSLPPCAANNPDYC
182
+ QGVCCGYKLCHEC
183
+ RDPCCYHPTCNMSNPQIC
184
+ DMCCHPACMKHFNC
185
+ GCCSHPVCKAMSPIC
186
+ PECCTHPACHVSNPELC
187
+ SGCCSNPACMVNNPNIC
188
+ GCCSNPVCALEHSNAC
189
+ GCCSHPACANHPELC
190
+ TSRSSGRYSRSPYDRRRRYCRRITDACV
191
+ NGVCCGYKLC
192
+ GCCSHPACSVNHPDLC
193
+ GEEEYSEAI
194
+ GRCCHPACGKNYAC
195
+ CCSNPACRVNNPHVC
196
+ GCCSRPPCRLNNPRYC
197
+ PECCTHPACHVNHPELC
198
+ IRDECCSNPACRSNNPHVC
199
+ GGGCCSHPACAANNQDYC
200
+ GCCSHPACSVNHPQLC
201
+ QGVCCGRKLCHPC
202
+ GCCSHPVCNVRHPELC
203
+ QGCCNVPNGCSGRWCRDHAQCC
204
+ GRCCHPACGKHFSC
205
+ GCCSHPACNVDHPEAC
206
+ DMCCHPACGNHFNC
207
+ IRDECCSNPACRVNAPHVC
208
+ RDPCCYHPACNMSNPQIC
209
+ IRDECCSNPACRYNNPHVC
210
+ GAGGAAGGCCSHPVCAAMSPIC
211
+ GCCSHPACSVNHPRLC
212
+ VKPCRKEGQLCDPIFQNCCRGWNCVLFCV
213
+ IRDECCSNPACRVNHPHVC
214
+ GCCSDPRCNMNNPDYC
215
+ GCCSNPVCHLEHPNAC
216
+ GCCSDPRCAYRC
217
+ QGVCCGYKSCHPC
218
+ APCCSIHDNSCCGL
219
+ TRLCSTMCCARADCTCVYHTWRGHGCSCVM
220
+ EACYAPGTFCGIKPGLCCSEFCLPGVCFG
221
+ SGCCSNPACDVNNPNIC
222
+ GCCSYPPCFATNPDCAGG
223
+ GCCSRPPCALNNPDYC
224
+ TARSSGRYARSPYDRRRRYARRITDAAV
225
+ GCCSNPACSVNHPELC
226
+ GCCSHPACADHPEIC
227
+ TGVCCGYKLCHPC
228
+ GCCSDPRCRYNHPEIC
229
+ QGVCCGWKLCHPC
230
+ GCCSYPPCFATNPDCGGAGGAG
231
+ RDPCCAHPTCNMSNPQIC
232
+ LPSCCSLNLRLCPVPACKRNPCCT
233
+ GCCSHPACSVNNPDIC
234
+ GCCSAPACNVDHPEIC
235
+ GCCSNPACHLEHSNLC
236
+ RDCCTPPKKCKDRQCKPQRCCA
237
+ GCCSHPVCHARHPELC
238
+ TCRSSGRYARSPYDRRRRYARRITDACV
239
+ GCCANPVCHLEHSNLC
240
+ QGVCCGYKL
241
+ DDEEYAEFIEQQREAGLV
242
+ CRIPNQKCFQALDDCCSRKCNRFNKCV
243
+ GCCSDPPCRNKHPDLC
244
+ CKGKGASCRRTSYDCCTGSCRSGRC
245
+ GCCSHPACKVDHPEIC
246
+ PECCTHPACHGSHPELC
247
+ DMCCHPACMNHFNC
248
+ GCCSDPPCRNKHPDLCG
249
+ GCCSYPPCFATNPDCA
250
+ CCGVPNAACPPCVCKNTC
251
+ GRCCHPACGKNASC
252
+ FPSCCSLNLRLCPVPACKRNPCCT
253
+ FGVCCGYKLCHPC
254
+ GCCSNPACMVNNPQIC
255
+ KPCCSAHDNSCCGL
256
+ GCCIHPACSVNHPELC
257
+ IRDECCSNPACANNPHVC
258
+ HPPCCLYGKCRPFPGCSSASCCQR
259
+ GCCSHPVCHAMSPIC
260
+ GRCCHPACGKNHSC
261
+ CRIPNQKCFQHLDDCCSRACNRFNKCV
262
+ RDCCSNPPCAHNNPDLC
263
+ QGVCCGYKLCEPC
264
+ GGCCSHPVCYTKNPNCG
265
+ GCCSHPACNVDHAEIC
266
+ IRDECCSNPACRINNPHVC
267
+ ACCSDPRCRYRCR
268
+ SCCARNPACRHNHPCV
269
+ GCCSHPACAGNNPYFC
270
+ QRCCNGRRGCSSRWCRDHSRCC
271
+ GCCSNPVCHAEHSNAC
272
+ ILRGILRNGVCC
273
+ QGVCCGYKLCFPC
274
+ AECCSNPACRVNNPHVC
275
+ GCCSHPACSTNHPELC
276
+ GCCSHPVCRAMSPIC
277
+ RDPCCYHPTCAMSNPQIC
278
+ ADECCSNPACRVNNPHVC
279
+ GCCSHPACHLDHPELC
280
+ GCCSHPVCLAMSPIC
281
+ QCCSNPPCAHEHCR
282
+ GCCSHPVCDAMSPIC
283
+ IRNECCSNPACRVNNPHVC
284
+ RDPGCCSNPVCHLEHSNLC
285
+ GPPCCLYGSCRPFPGCSSASCCRK
286
+ ADCCSNPPCAHNNPDC
287
+ ECCNPACGRAYSC
288
+ TCRSSGRYSRSPYDRRRRYSRRITDACV
289
+ ICCNPACGKKYSC
290
+ GCCSHPACHLNHPEIC
291
+ ADPCCYHPTCNMSNPQIC
292
+ GCCSHPACTVNHPELC
293
+ GCCSDPPCRNAHPDLC
294
+ GEEELAEKAPEFARELAN
295
+ GCCGPYPNAACHPCGCKVGRPPYCDRPSGG
296
+ QGVCCGYLLCHPC
297
+ RDPCCYAPTCNMSNPQIC
298
+ CRIPNQACFQHLDDCCSRKCNRFNKCV
299
+ RDCQEKWAYCIVPILGFVYCCPGLICGPFVCV
300
+ RDAATPPKKCKDRQAKPQRACA
301
+ RIKKPIFIAFPRF
302
+ GCCSDPRCRYKCR
303
+ GCCSNPPCIANNPDLC
304
+ GCCSRPPCILNNPDLC
305
+ CKGKGAKCSRLMYDCCTGSCRSGKC
306
+ NGVCCGYK
307
+ GCCKDPRCAYDHPEIC
308
+ GCCSDPRCIYDHPEIC
309
+ GCCVHPACSVNHPELC
310
+ CRAPNQKCFQHLDDCCSRKCNRFNKCV
311
+ ACCSHPACNVDHPEIC
312
+ QGVCCGYKLCHPC
313
+ VGERCCKNGKRGCGRWCRDHSRCC
314
+ GCCSDPLCAWRC
315
+ RDCQKKWKYCIVPILGFVYCCPGLICGPFVCV
316
+ IRDECCSNPACRVANPHVC
317
+ IRAECCSNPACRVNNPHVC
318
+ GCCSHPACNVAHPEIC
319
+ GCCSDPKCRYRCR
320
+ GRCCHPACGKAYSC
321
+ GCCSNPPCAHEHCR
322
+ GCCSAPPCALYC
323
+ LPSCCSLNLALCPVPACKRNPCCT
324
+ QGVCCGYKLCHKC
325
+ GCCSYPPCFATNPDCGAGAAG
326
+ GCCSHPACSVNHQELC
327
+ GCCSHPACDVNHPELC
328
+ GCCRDPRCNYDHPEIC
329
+ GCCSHPACSLNHPELC
330
+ GCCSHPACNVDHPEIC
331
+ CKRKGSSCRRTSYDCCTGSCRNGKC
332
+ GCCSHPACSVKHPELC
333
+ CRIPNQKCFQHLDDCCSRKCNRFNKCV
334
+ GCCSNPVCHLRHSNLC
335
+ GCCSLPPCALNNPDYC
336
+ GCCGSYPNAACHPCSCKDRPSYCGQ
337
+ GCCSHPACHLNHPELC
338
+ RDPCCYHPTCNMSAPQIC
339
+ GCCSDVRCRYRCR
340
+ GGAAGGGCCSHPVCAAMSPIC
341
+ GCCSNPVCHAEHSNLC
342
+ GCCSHPACHARHPELC
343
+ GCCSHPACSVNHPEVC
344
+ GCCSHPACWVNNPHIC
345
+ CCNCSSKRCRDHSRCC
346
+ RDCQEAWEYCIVPILGFVYCCPGLICGPFVCV
347
+ CKGKGASCRRTSYDCCTGSCRLGRC
348
+ GCCSHPACHVNHPELC
349
+ CKGKGASCRKTSYDCCTGSCRLGRC
350
+ CKGKGAKCSRLAYDCCTGSCRSGKC
351
+ IRDQCCSNPACRVNNPHVC
352
+ ACRKKWEYCIVPIIGFIYCCPGLICGPFVCV
353
+ QGVCCGFKLCHPC
354
+ GCCSHPACSGNNPYAC
355
+ GEEELAEKAEFARELAN
356
+ GCCSRAACAGIHQELC
357
+ ACCNPACGRHYSC
358
+ CKGKGAKCSRIMYDCCTGSCRSGKC
359
+ GCCSHPACSVEHPELC
360
+ GCCSAPVCHLEHSNLC
361
+ GCCSDPRCNYEHPAICGGAAGG
362
+ GCCSHPACRVNHPELC
363
+ QGVCCGYELCHPC
364
+ IRDECCSNPACRVNNPHAC
365
+ ICCNPACGPKYSC
366
+ IRDECCSNPSCRVNNPHVC
367
+ ECCNPACGRHASC
368
+ DDEEYAEFI
369
+ MPSCCSLNLRLCPVPACKRNPCCT
370
+ QGVCCGYKLCQPC
371
+ GGAAGGCCSHPVCAAMSPIC
372
+ CCNCSSKECRDHSRCC
373
+ IRDECCSNPACRVNNPHQC
374
+ RACCSNPPCAHNNPDC
375
+ GYKLCHPC
376
+ NGRCCHPACGKHFNC
377
+ GCCAYPPCFATNPDC
378
+ GCCSNPRCAWRC
379
+ CCNCSSKWCRDHSACC
380
+ GCCSHPACSVNHPALC
381
+ GCCSHPACNADHPEIC
382
+ GCCSHPACNVDHPAIC
383
+ GDEEYSEFIERERELVSSKIPR
384
+ GCCSHPACSGANPYFC
385
+ ARDECCSNPACRVNNPHVC
386
+ DECCSNPACRVNNPHVCRRR
387
+ NGVCCGYKLCHPC
388
+ QGVCCGYKLC
389
+ GCCSTPPCAALYC
390
+ GGAGGCCSHPVCAAMSPIC
391
+ IRDECCSNPTCRVNNPHVC
392
+ GRCCHAACGKNYSC
393
+ PECCTHPACHGNHPELC
394
+ RTCCSRPTCRMEYPELCG
395
+ RDCQEKWEYCIVPALGFVYCCPGLICGPFVCV
396
+ RDCQEKWEYCIVPILGFVWCCPGLICGPFVCV
397
+ GCCSYPPCFATNPDCAGGG
398
+ AARCCTYHGSCLKEKCRRKYCC
399
+ GCCSHPACSVNHPERC
400
+ GCCSHPACSVAHPELC
401
+ QGVCCGYKLCPPC
402
+ GCCSDPRCRYQCR
403
+ QGVCCGYILCHPC
404
+ GCCSHPACNVNHPELC
405
+ GCCAHPACSVNHPELC
406
+ QGCCNGPKGCSSKWCRDHARCC
407
+ IADECCSNPACRVNNPHVC
408
+ CCSNPPCAHEHC
409
+ CCNCSSKWCADHSRCC
410
+ GCCSYPPCFATNPDCAG
411
+ GCCSHPACHLEHPELC
412
+ GCCSDPRCRYDHPEIC
413
+ GCCTHPACSVNHPELC
414
+ VGVCCGYKLCHPC
415
+ RIKKPIFAFPRF
416
+ GRCCHPACGKNMSC
417
+ GDEEYSKFIEREREAGRLDLSKFP
418
+ GFRSACPPFC
419
+ GEEELQENQELIREKSN
420
+ GCCSHPACSVNHPELCGRRRRGGCCSHPACSVNHPELC
421
+ CRIPNQKCFQHLADCCSRKCNRFNKCV
422
+ GEEELAENQEFARELAN
423
+ CRIPNQKCFQHLDDCCSRKCNRANKCV
424
+ IRDECCSNPACRTNNPHVC
425
+ SGCCSNPACRVQNPNIC
426
+ GCCSHPACFVNNPHIC
427
+ GCCAHPACNVDHPEIC
428
+ IRDECCSNPACRLNNPHVC
429
+ GCCSNPVCHLEHANLC
430
+ RDCATPPKKCKDRQCKPQRACA
431
+ ECCNPACARHYSC
432
+ GCCSNPVCHLEHSNAC
433
+ SGCCSNPACRVNNPNIC
434
+ RDPCCYHATCNMSNPQIC
435
+ IRDECCSNPACRVNNAHVC
436
+ CCYHPTCNMSNPQIC
437
+ IRDECCANPACRVNNPHVC
438
+ CKRKGSSCRRTAYDCCTGSCRNGKC
439
+ GCCSDPRCRWRCR
440
+ SGVCCGYKLCHPC
441
+ LPSCCALNLRLCPVPACKRNPCCT
442
+ LPSCCSLNLRLCPVPACARNPCCT
443
+ CKGKGASCHRTSYDCCTGSCNRGKC
444
+ LPSCCSLNLRLCPVPACKANPCCT
445
+ GCCSTPPCALYC
446
+ SGCCSNPACFVLNPNIC
447
+ GCCSDPRCAYRCR
448
+ QGVCCGLKLCHPC
449
+ GCCSDPRCWYDHPEIC
450
+ DECCSNPACRLNNPHACRRR
451
+ CKRKGSSCRRTSYDCCTGSCRSGKC
452
+ GCCSAPPCALYCG
453
+ GCCSHPACHARHPALC
454
+ GEEEYAEKAPEFARELAN
455
+ GRCCHPACGKYYSC
456
+ KPSCCSLNLRLCPVPACKRNPCCT
457
+ GEDDYQDAQDLIRDKSN
458
+ GFRSPCPPFC
459
+ GCCSDPRCKYRCR
460
+ IKDECCSNPACRVNNPHVC
461
+ GCCSHPACYVNNPHIC
462
+ EPSCCSLNLRLCPVPACKRNPCCT
463
+ CCSNPACRVNNPNIC
464
+ IRDECCSNPACRWNNPHVC
465
+ CKPPGSKCSPSMRDCCTTCISYTKRCRKYY
466
+ GCCSNPACMLKNPNLC
467
+ GCCSDPRCAWEC
468
+ TCRSSGRYCRSPYDRRRRYCRRITDACV
469
+ GCCSLPPCRLNNPDYC
470
+ GCCARAACAGIHQELC
471
+ CCSDPRCRYDHPEIC
472
+ AECCTHPACHVSHPELC
473
+ QGVCCGYKLCLPC
474
+ CCGVPNAACHPCVCNNTC
475
+ RDPCCYHPTCNMSNPAIC
476
+ ACCSDRRCRYRC
477
+ RDCATPPKKAKDRQCKPQRAAA
478
+ WNGVCCGYKLCHPC
479
+ RDPCCYHPTCNASNPQIC
480
+ CKPPGSPCRVSSYNCCSSCKSYNKKCG
481
+ CRIPNQKCFQHLDDCCSRKCNRFNACV
482
+ GCCSDPRCSVNHPELC
483
+ QGVCCGIKLCHPC
484
+ RIRKPIFIAFPRF
485
+ GCCSNPACRVNNPNIC
486
+ SGSTCTCFTSTNCQGSCECLSPPGCYCSNNGIRQRGCSCTCPGT
487
+ GCCGKYPNAACHPCGCTVGRPPYCDRPSGG
488
+ LPSCCSLNLRLCPVPACKRNPCCA
489
+ GCCSHPVCEAMSPIC
490
+ CRIPNQKCFAHLDDCCSRKCNRFNKCV
491
+ GCCSYPPCFATNPDCGGAAG
492
+ RDCQEKWEYCIVPILGAVYCCPGLICGPFVCV
493
+ IRNQCCSNPACRVNNPHVC
494
+ CRIPNQKCFQHLDDCCSRKCNAFNKCV
495
+ GRCCHPACGANYSC
496
+ ECCNAACGRHYSC
497
+ CKGQSCSSCSTKEFCLSKGSRLMYDCCTGSCCGVKTAGVT
498
+ HPPCCLYGKCRRYPGCSSASCCQR
499
+ DCCPAKLLCCNP
500
+ RDACTPPKKAKDRQAKPQRCAA
501
+ ACCSNPVCHLEHSNLC
502
+ GCCSHPACSVDHPELC
503
+ RDCCSNPPCAHNNPDC
504
+ QGVCCGYKLCHP
505
+ CCGVPNAACHPCVCTGKC
506
+ GCCSTPPCAAYC
507
+ LGVCCGYKLCHPC
508
+ GCCSHPACNVANPHIC
509
+ GCCSDPRCFYDHPEIC
510
+ RDCQAKWEYCIVPILGFVYCCPGLICGPFVCV
511
+ GCCSHPACSINHPELC
512
+ CQIPNQKCFQHLDDCCSRKCNRFNKCV
513
+ GCCSDPRCRYRCGRRRRGGCCSDPRCRYRC
514
+ LPSCCSLNLRLCAVPACKRNPCCT
515
+ CKGKGASCRKTSYDCCTGSCRSGRC
516
+ QGVCCGYKVCHPC
517
+ GCCSHPACNVNAPHIC
518
+ GCCSHPACSVNHRELC
519
+ GCCSHPVCHVRHPELC
520
+ GGCCSHPACAANNQDYC
521
+ GVCCGYKLCHPC
522
+ GCCSDPRCRYRCGGAAGAG
523
+ CCGVPNAACPPCVCNKTCG
524
+ CRSSGSPCGVTSICCGRCYRGKCT
525
+ ACSKKWEYCIVPILGFVYCCPGLICGPFVCV
526
+ CKRKGSSCRRLSYDCCTGSCRNGKC
527
+ GCCSNPVCHLEASNLC
528
+ QGVCCGPKLCHPC
529
+ LPSCCSLNLRLCPVPACKRNACCT
530
+ GCCSHPVCSAMSPIC
531
+ GDEEVSKFIEREREAGRLDLSKFP
532
+ RDCCSNPPCAANNPDC
533
+ GCCSHPACGVDHPEIC
534
+ QGVCCGYKLCHYC
535
+ CKIPNQKCFQHLDDCCSRKCNRFNKCV
536
+ GCCADPRCRYRCR
537
+ CRIPNQRCFQHLDDCCSRKCNRFNKCV
538
+ RIRKPIFAFPRF
539
+ GCCSHPACSGNNPYFC
540
+ GCCSLPPCAASNPDYC
541
+ GTYLYPFSYYRLWRYFTRFLHKQPYYYVHI
542
+ GDEEVAKFIEREREAGRLDLSKFP
543
+ CKSPGSSCSKTSYNCCRSCNPYTKRCY
544
+ QGVCCGYRLCHPC
545
+ CCGVPNAACHPCVCKNTC
546
+ LPACCSLNLRLCPVPACKRNPCCT
547
+ IRDECCSNPACRFNNPHVC
548
+ RDCCSNPPCAHNNPD
549
+ GCCCNPACGPNYGCGTSCS
550
+ RDPGCCSNPVCHLRHSNLC
551
+ GCCSDPRCNYDHPEICGRRRRGGCCSDPRCNYDHPEIC
552
+ ACCSNPACRVNNPHVC
553
+ GILRNGVCCGYKLCHPC
554
+ LRNGVCCGYKLCHPC
555
+ CKGKGASCRKTMYDCCRGSCRSGRC
556
+ QGVCCGYKLCHDC
557
+ GCCSHPACSGNNAYFC
558
+ CKRKGSSCRRASYDCCTGSCRNGKC
559
+ CKRKGSSCRRTMYDCCTGSCRNGKC
560
+ CKPPGSKCSPSMRDCCTTCISYTKRCRKYYN
561
+ YPSCCSLNLRLCPVPACKRNPCCT
562
+ ECCNPACGAHYSC
563
+ GCCATPPCALYC
564
+ GCCSHPACSGNNPAFC
565
+ AVKKTCIRSTPGSNWGRCCLTKMCHTLCCARSDCTCVYRSGKGHGCSCTS
566
+ SGCCSNPACRVLNPNIC
567
+ RDCQEKWEYCIVPILGFVYCCPGLICGPAVCV
568
+ LPSCCSANLRLCPVPACKRNPCCT
569
+ APSCCSLNLRLCPVPACKRNPCCT
570
+ GCCSDPRCRYSHPEIC
571
+ GPPCCLYGSCRRYPGCYNALCCRK
572
+ SGCCSHPACRVNNPNIC
573
+ QGVCCGYKLCWPC
574
+ CKAKGSSCRRTSYDCCTGSCRNGKC
575
+ GCCSRPACAGIHQELC
576
+ PECCTHPACHVSAPELC
577
+ RAPCCYHPTCNMSNPQIC
578
+ GHCSDPRFAWRC
579
+ RDCQEKAEYCIVPILGFVYCCPGLICGPFVCV
580
+ CKAAGKPCSRLMYDCCTGSCRSGKC
581
+ GWCGDPGATCGKLRLYCCSGFCDSYTKTCKDKSSA
582
+ CKRKGSACRRTSYDCCTGSCRNGKC
583
+ GCCSDPRCRYRHPEIC
584
+ LPSCCSLNLRLCPAPACKRNPCCT
585
+ GCCSHPACSGNAPYFC
586
+ GCCSDPRCRYRCGGAAGG
587
+ DGVCCGYKLCHPC
588
+ RDCQEKWEYCIVPILGFVFCCPGLICGPFVCV
589
+ IRDECCSNPACRVNNPHVCRRR
590
+ LPSCCSLNLRLCPVPACKRAPCCT
591
+ HPSCCSLNLRLCPVPACKRNPCCT
592
+ GCCSDPRCRYRCY
593
+ QGVCCGYKICHPC
594
+ ECCNPACGRHYAC
595
+ CKGKGAKCSRLMYNCCTGSCRSGKC
596
+ GGVCCGYKLCHPC
597
+ CKGTGKSCSRIAYNCCTGSCRSGKC
598
+ CKRKGSSCRRTSYDCCTGSCRAGKC
599
+ AGVCCGYKLCHPC
600
+ QGVCCGYKRCHPC
601
+ GCCSHPVCTAMSPIC
602
+ CLSPGSSCSPTSYNCCRSCNPYSRKC
603
+ CKRKGSSCSRTSYDCCTGSCRNGKC
604
+ IRDECCSNPACRVNNPHNC
605
+ GGAGCCSHPVCAAMSPIC
606
+ RDAATPPKKAKDRQAKPQRAAA
607
+ RPSCCSLNLRLCPVPACKRNPCCT
608
+ GCCSHPACAANNQDYC
609
+ GCCSYPPCFATNPDCAGAGA
610
+ RGVCCGYKLCHPC
611
+ QGVCCGEKLCHPC
612
+ CKGKGSSCRRTSYDCCTGSCRNGKC
613
+ QGVCCGYWLCHPC
614
+ IRDECCSNPACRVNKPHVC
615
+ QGVCCGYKLCRPC
616
+ LPSCCSLNARLCPVPACKRNPCCT
617
+ CRIPNQKCFQHLDDCCSAKCNRFNKCV
618
+ KGVCCGYKLCHPC
619
+ CKSKGAKCSRLMYDCCSGSCSGTVGRC
620
+ QGVCCGYKLCYPC
621
+ CRIPNAKCFQHLDDCCSRKCNRFNKCV
622
+ CKGKGAKCSKLMYDCCTGSCRSGKC
623
+ FPRPRICNLACRAGIGHKYPFCHCR
624
+ QGVCCGYKQCHPC
625
+ CKGKGAKCSRIAYNCCTGSCRSGKC
626
+ CRIPNQKCAQHLDDCCSRKCNRFNKCV
627
+ IGVCCGYKLCHPC
628
+ RDCQEKWEYCIVPILGFAYCCPGLICGPFVCV
629
+ IPSCCSLNLRLCPVPACKRNPCCT
630
+ ECCAPACGRHYSC
631
+ QGVCCGYKPCHPC
app.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import torch
3
+ import gradio as gr
4
+ import pandas as pd
5
+ from utils import create_vocab, setup_seed
6
+ from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab
7
+
8
+ seed = random.randint(0,99999999)
9
+
10
+ setup_seed(seed)
11
+ device = torch.device("cpu")
12
+ vocab_mlm = create_vocab()
13
+ vocab_mlm = add_tokens_to_vocab(vocab_mlm)
14
+ save_path = 'mlm-model-27.pt' #1
15
+ train_seqs = pd.read_csv('C0_seq.csv') #2
16
+ train_seq = train_seqs['Seq'].tolist()
17
+ model = torch.load(save_path, weights_only=False, map_location=torch.device('cpu'))
18
+ model = model.to(device)
19
+
20
+ def temperature_sampling(logits, temperature):
21
+ logits = logits / temperature
22
+ probabilities = torch.softmax(logits, dim=-1)
23
+ sampled_token = torch.multinomial(probabilities, 1)
24
+ return sampled_token
25
+
26
+ def CTXGen(τ, g_num, start, end):
27
+ X1 = "X"
28
+ X2 = "X"
29
+ X4 = ""
30
+ X5 = ""
31
+ X6 = ""
32
+ model.eval()
33
+ with torch.no_grad():
34
+ new_seq = None
35
+ generated_seqs = []
36
+ generated_seqs_FINAL = []
37
+ cls_pos_all = []
38
+ cls_probability_all = []
39
+ act_pos_all = []
40
+ act_probability_all = []
41
+
42
+ count = 0
43
+ gen_num = int(g_num)
44
+ NON_AA = ["B","O","U","Z","X",'<K16>', '<α1β1γδ>', '<Ca22>', '<AChBP>', '<K13>', '<α1BAR>', '<α1β1ε>', '<α1AAR>', '<GluN3A>', '<α4β2>',
45
+ '<GluN2B>', '<α75HT3>', '<Na14>', '<α7>', '<GluN2C>', '<NET>', '<NavBh>', '<α6β3β4>', '<Na11>', '<Ca13>',
46
+ '<Ca12>', '<Na16>', '<α6α3β2>', '<GluN2A>', '<GluN2D>', '<K17>', '<α1β1δε>', '<GABA>', '<α9>', '<K12>',
47
+ '<Kshaker>', '<α3β4>', '<Na18>', '<α3β2>', '<α6α3β2β3>', '<α1β1δ>', '<α6α3β4β3>', '<α2β2>','<α6β4>', '<α2β4>',
48
+ '<Na13>', '<Na12>', '<Na15>', '<α4β4>', '<α7α6β2>', '<α1β1γ>', '<NaTTXR>', '<K11>', '<Ca23>',
49
+ '<α9α10>','<α6α3β4>', '<NaTTXS>', '<Na17>','<high>','<low>','[UNK]','[SEP]','[PAD]','[CLS]','[MASK]']
50
+
51
+ while count < gen_num:
52
+ gen_len = random.randint(int(start), int(end))
53
+ X3 = "X" * gen_len
54
+ seq = [f"{X1}|{X2}|{X3}|{X4}|{X5}|{X6}"]
55
+ vocab_mlm.token_to_idx["X"] = 4
56
+
57
+ padded_seq, _, _, _ = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
58
+ input_text = ["[MASK]" if i=="X" else i for i in padded_seq]
59
+
60
+ gen_length = len(input_text)
61
+ length = gen_length - sum(1 for x in input_text if x != '[MASK]')
62
+
63
+ for i in range(length):
64
+ _, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
65
+ idx_seq = torch.tensor(idx_seq).unsqueeze(0).to(device)
66
+ idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
67
+ attn_idx = torch.tensor(attn_idx).to(device)
68
+
69
+ mask_positions = [j for j in range(gen_length) if input_text[j] == "[MASK]"]
70
+ mask_position = torch.tensor([mask_positions[torch.randint(len(mask_positions), (1,))]])
71
+
72
+ logits = model(idx_seq,idx_msa, attn_idx)
73
+ mask_logits = logits[0, mask_position.item(), :]
74
+
75
+ predicted_token_id = temperature_sampling(mask_logits, τ)
76
+
77
+ predicted_token = vocab_mlm.to_tokens(int(predicted_token_id))
78
+ input_text[mask_position.item()] = predicted_token
79
+ padded_seq[mask_position.item()] = predicted_token.strip()
80
+ new_seq = padded_seq
81
+
82
+ generated_seq = input_text
83
+
84
+ generated_seq[1] = "[MASK]"
85
+ generated_seq[2] = "[MASK]"
86
+ input_ids = vocab_mlm.__getitem__(generated_seq)
87
+ logits = model(torch.tensor([input_ids]).to(device), idx_msa)
88
+
89
+ cls_mask_logits = logits[0, 1, :]
90
+ act_mask_logits = logits[0, 2, :]
91
+
92
+ cls_probability, cls_mask_probs = torch.topk((torch.softmax(cls_mask_logits, dim=-1)), k=1)
93
+ act_probability, act_mask_probs = torch.topk((torch.softmax(act_mask_logits, dim=-1)), k=1)
94
+
95
+ cls_pos = vocab_mlm.idx_to_token[cls_mask_probs[0].item()]
96
+ act_pos = vocab_mlm.idx_to_token[act_mask_probs[0].item()]
97
+
98
+ cls_probability = cls_probability[0].item()
99
+ act_probability = act_probability[0].item()
100
+ generated_seq = generated_seq[generated_seq.index('[MASK]') + 2:generated_seq.index('[SEP]')]
101
+ if generated_seq.count('C') % 2 == 0 and len("".join(generated_seq)) == gen_len:
102
+ generated_seqs.append("".join(generated_seq))
103
+ if "".join(generated_seq) not in train_seq and "".join(generated_seq) not in generated_seqs[0:-1] and all(x not in NON_AA for x in generated_seq):
104
+ generated_seqs_FINAL.append("".join(generated_seq))
105
+ cls_pos_all.append(cls_pos)
106
+ cls_probability_all.append(cls_probability)
107
+ act_pos_all.append(act_pos)
108
+ act_probability_all.append(act_probability)
109
+ out = pd.DataFrame({'Generated_seq': generated_seqs_FINAL, 'Subtype': cls_pos_all, 'Subtype_probability': cls_probability_all, 'Potency': act_pos_all, 'Potency_probability': act_probability_all, 'random_seed': seed})
110
+ out.to_csv("output.csv", index=False)
111
+ count += 1
112
+ return 'output.csv'
113
+
114
+ iface = gr.Interface(
115
+ fn=CTXGen,
116
+ inputs=[
117
+ gr.Slider(minimum=1, maximum=2, step=0.01, label="τ"),
118
+ gr.Dropdown(choices=[1,10,100,1000], label="Number of generations"),
119
+ gr.Textbox(label="Min length"),
120
+ gr.Textbox(label="Max length")
121
+ ],
122
+ outputs=["file"]
123
+ )
124
+ iface.launch()
bertmodel.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import copy, math
3
+ import torch
4
+ import numpy as np
5
+ import torch.nn.functional as F
6
+
7
+ class Bert(nn.Module):
8
+
9
+ def __init__(self, encoder, src_embed):
10
+ super(Bert, self).__init__()
11
+
12
+ self.encoder = encoder
13
+ self.src_embed = src_embed
14
+
15
+ def forward(self, src, src_mask):
16
+
17
+ return self.encoder(self.src_embed(src), src_mask)
18
+
19
+
20
+ class Encoder(nn.Module):
21
+ "Encoder是N个EncoderLayer的堆积而成"
22
+ def __init__(self, layer, N):
23
+ super(Encoder, self).__init__()
24
+ #layer是一个SubLayer,我们clone N个
25
+ self.layers = clones(layer, N)
26
+ #再加一个LayerNorm层
27
+ self.norm = LayerNorm(layer.size)
28
+
29
+ def forward(self, x, mask):
30
+ "把输入(x,mask)被逐层处理"
31
+ for layer in self.layers:
32
+ x = layer(x, mask)
33
+ return self.norm(x) #N个EncoderLayer处理完成之后还需要一个LayerNorm
34
+
35
+ class LayerNorm(nn.Module):
36
+ "构建一个layernorm模型"
37
+ def __init__(self, features, eps=1e-6):
38
+ super(LayerNorm, self).__init__()
39
+ self.a_2 = nn.Parameter(torch.ones(features))
40
+ self.b_2 = nn.Parameter(torch.zeros(features))
41
+ self.eps = eps
42
+
43
+ def forward(self, x):
44
+ mean = x.mean(-1, keepdim=True)
45
+ std = x.std(-1, keepdim=True)
46
+ return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
47
+
48
+ class SublayerConnection(nn.Module):
49
+ """
50
+ LayerNorm + sublayer(Self-Attenion/Dense) + dropout + 残差连接
51
+ 为了简单,把LayerNorm放到了前面,这和原始论文稍有不同,原始论文LayerNorm在最后
52
+ """
53
+ def __init__(self, size, dropout):
54
+ super(SublayerConnection, self).__init__()
55
+ self.norm = LayerNorm(size)
56
+ self.dropout = nn.Dropout(dropout)
57
+
58
+ def forward(self, x, sublayer):
59
+ #将残差连接应用于具有相同大小的任何子层
60
+ return x + self.dropout(sublayer(self.norm(x)))
61
+
62
+ class EncoderLayer(nn.Module):
63
+ "Encoder由self-attn and feed forward构成"
64
+ def __init__(self, size, self_attn, feed_forward, dropout):
65
+ super(EncoderLayer, self).__init__()
66
+ self.self_attn = self_attn
67
+ self.feed_forward = feed_forward
68
+ self.sublayer = clones(SublayerConnection(size, dropout), 2)
69
+ self.size = size
70
+
71
+ def forward(self, x, mask):
72
+ "如上图所示"
73
+ x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
74
+ return self.sublayer[1](x, self.feed_forward)
75
+
76
+ class PositionwiseFeedForward(nn.Module):
77
+ "Implements FFN equation."
78
+ def __init__(self, d_model, d_ff, dropout=0.1):
79
+ super(PositionwiseFeedForward, self).__init__()
80
+ self.w_1 = nn.Linear(d_model, d_ff)
81
+ self.w_2 = nn.Linear(d_ff, d_model)
82
+ self.dropout = nn.Dropout(dropout)
83
+
84
+ def forward(self, x):
85
+ return self.w_2(self.dropout(F.relu(self.w_1(x))))
86
+
87
+ def make_bert(src_vocab, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1):
88
+ "构建模型"
89
+ c = copy.deepcopy
90
+ attn = MultiHeadedAttention(h, d_model)
91
+ ff = PositionwiseFeedForward(d_model, d_ff, dropout)
92
+ position = PositionalEncoding(d_model, dropout)
93
+ model = Bert(
94
+ Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
95
+
96
+ nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
97
+ )
98
+
99
+ # 随机初始化参数,这非常重要用Glorot/fan_avg.
100
+ for p in model.parameters():
101
+ if p.dim() > 1:
102
+ nn.init.xavier_uniform_(p)
103
+ return model
104
+
105
+ def make_bert_without_emb(d_model=128, N=2, d_ff=512, h=8, dropout=0.1):
106
+ c = copy.deepcopy
107
+ attn = MultiHeadedAttention(h, d_model)
108
+ ff = PositionwiseFeedForward(d_model, d_ff, dropout)
109
+ trainable_encoder = Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N)
110
+
111
+ return trainable_encoder
112
+
113
+
114
+
115
+ def clones(module, N):
116
+ "克隆N个完全相同的SubLayer,使用了copy.deepcopy"
117
+ return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
118
+
119
+ def subsequent_mask(size):
120
+ "Mask out subsequent positions."
121
+ attn_shape = (1, size, size)
122
+ subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
123
+ return torch.from_numpy(subsequent_mask) == 0
124
+
125
+ def attention(query, key, value, mask=None, dropout=None):
126
+ "计算 'Scaled Dot Product Attention'"
127
+ d_k = query.size(-1)
128
+ scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
129
+ if mask is not None:
130
+ mask = mask.unsqueeze(-2)
131
+ scores = scores.masked_fill(mask == 0, -1e9)
132
+ p_attn = F.softmax(scores, dim = -1)
133
+ if dropout is not None:
134
+ p_attn = dropout(p_attn)
135
+ return torch.matmul(p_attn, value), p_attn
136
+
137
+ class MultiHeadedAttention(nn.Module):
138
+ def __init__(self, h, d_model, dropout=0.1):
139
+ "传入head个数及model的维度."
140
+ super(MultiHeadedAttention, self).__init__()
141
+ assert d_model % h == 0
142
+ # 这里假设d_v=d_k
143
+ self.d_k = d_model // h
144
+ self.h = h
145
+ self.linears = clones(nn.Linear(d_model, d_model), 4)
146
+ self.attn = None
147
+ self.dropout = nn.Dropout(p=dropout)
148
+
149
+ def forward(self, query, key, value, mask=None):
150
+ "Implements Figure 2"
151
+ if mask is not None:
152
+ # 相同的mask适应所有的head.
153
+ mask = mask.unsqueeze(1)
154
+ nbatches = query.size(0)
155
+
156
+ # 1) 首先使用线性变换,然后把d_model分配给h个Head,每个head为d_k=d_model/h
157
+ query, key, value = \
158
+ [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
159
+ for l, x in zip(self.linears, (query, key, value))]
160
+
161
+ # 2) 使用attention函数计算scaled-Dot-product-attention
162
+ x, self.attn = attention(query, key, value, mask=mask,
163
+ dropout=self.dropout)
164
+
165
+ # 3) 实现Multi-head attention,用view函数把8个head的64维向量拼接成一个512的向量。
166
+ #然后再使用一个线性变换(512,521),shape不变.
167
+ x = x.transpose(1, 2).contiguous() \
168
+ .view(nbatches, -1, self.h * self.d_k)
169
+ return self.linears[-1](x)
170
+
171
+ class Embeddings(nn.Module):
172
+ def __init__(self, d_model, vocab):
173
+ super(Embeddings, self).__init__()
174
+ self.lut = nn.Embedding(vocab, d_model)
175
+ self.d_model = d_model
176
+
177
+ def forward(self, x):
178
+ return self.lut(x) * math.sqrt(self.d_model)
179
+
180
+ class PositionalEncoding(nn.Module):
181
+ "实现PE函数"
182
+ def __init__(self, d_model, dropout, max_len=5000):
183
+ super(PositionalEncoding, self).__init__()
184
+ self.dropout = nn.Dropout(p=dropout)
185
+
186
+ # Compute the positional encodings once in log space.
187
+ pe = torch.zeros(max_len, d_model)
188
+ position = torch.arange(0, max_len).unsqueeze(1)
189
+ div_term = torch.exp(torch.arange(0, d_model, 2) *
190
+ -(math.log(10000.0) / d_model))
191
+ pe[:, 0::2] = torch.sin(position * div_term)
192
+ pe[:, 1::2] = torch.cos(position * div_term)
193
+ pe = pe.unsqueeze(0)
194
+ self.register_buffer('pe', pe)
195
+
196
+ def forward(self, x):
197
+ x = x + self.pe[:, :x.size(1)].clone().detach()
198
+ return self.dropout(x)
199
+
conoData_C5.csv ADDED
The diff for this file is too large to render. See raw diff
 
dataset_mlm.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from copy import deepcopy
3
+
4
+ import torch
5
+ from torch.utils.data import TensorDataset, DataLoader
6
+ from sklearn.model_selection import train_test_split
7
+
8
+ from vocab import PepVocab
9
+ from utils import mask, create_vocab
10
+
11
+ addtition_tokens = ['<K16>', '<α1β1γδ>', '<Ca22>', '<AChBP>', '<K13>', '<α1BAR>', '<α1β1ε>', '<α1AAR>', '<GluN3A>', '<α4β2>',
12
+ '<GluN2B>', '<α75HT3>', '<Na14>', '<α7>', '<GluN2C>', '<NET>', '<NavBh>', '<α6β3β4>', '<Na11>', '<Ca13>',
13
+ '<Ca12>', '<Na16>', '<α6α3β2>', '<GluN2A>', '<GluN2D>', '<K17>', '<α1β1δε>', '<GABA>', '<α9>', '<K12>',
14
+ '<Kshaker>', '<α3β4>', '<Na18>', '<α3β2>', '<α6α3β2β3>', '<α1β1δ>', '<α6α3β4β3>', '<α2β2>','<α6β4>', '<α2β4>',
15
+ '<Na13>', '<Na12>', '<Na15>', '<α4β4>', '<α7α6β2>', '<α1β1γ>', '<NaTTXR>', '<K11>', '<Ca23>',
16
+ '<α9α10>','<α6α3β4>', '<NaTTXS>', '<Na17>','<high>','<low>']
17
+
18
+ def add_tokens_to_vocab(vocab_mlm: PepVocab):
19
+ vocab_mlm.add_special_token(addtition_tokens)
20
+ return vocab_mlm
21
+
22
+ def split_seq(seq, vocab, get_seq=False):
23
+ '''
24
+ note: the function is suitable for the sequences with the format of "label|label|sequence|msa1|msa2|msa3"
25
+ '''
26
+ start = '[CLS]'
27
+ end = '[SEP]'
28
+ pad = '[PAD]'
29
+ cls_label = seq.split('|')[0]
30
+ act_label = seq.split('|')[1]
31
+
32
+ if get_seq == True:
33
+ add = lambda x: [start] + [cls_label] + [act_label] + x + [end]
34
+ pep_seq = seq.split('|')[2]
35
+ # return [start] + [cls_label] + [act_label] + vocab.split_seq(pep_seq) + [end]
36
+ return add(vocab.split_seq(pep_seq))
37
+
38
+ else:
39
+ add = lambda x: [start] + [pad] + [pad] + x + [end]
40
+ msa1_seq = seq.split('|')[3]
41
+ msa2_seq = seq.split('|')[4]
42
+ msa3_seq = seq.split('|')[5]
43
+
44
+ # return [vocab.split_seq(msa1_seq)] + [vocab.split_seq(msa2_seq)] + [vocab.split_seq(msa3_seq)]
45
+ return [add(vocab.split_seq(msa1_seq))] + [add(vocab.split_seq(msa2_seq))] + [add(vocab.split_seq(msa3_seq))]
46
+
47
+ def get_paded_token_idx(vocab_mlm):
48
+ cono_path = '/home/ubuntu/work/gecheng/conoGen_final/FinalCono/new_cycle/conoData_C5.csv'
49
+ seq = pd.read_csv(cono_path)['Sequences']
50
+
51
+ splited_seq = list(seq.apply(split_seq, args=(vocab_mlm,True, )))
52
+ splited_msa = list(seq.apply(split_seq, args=(vocab_mlm, False, )))
53
+
54
+ vocab_mlm.set_get_attn(is_get=True)
55
+ padded_seq = vocab_mlm.truncate_pad(splited_seq, num_steps=54, padding_token='[PAD]')
56
+ attn_idx = vocab_mlm.get_attention_mask_mat()
57
+
58
+ vocab_mlm.set_get_attn(is_get=False)
59
+ padded_msa = vocab_mlm.truncate_pad(splited_msa, num_steps=54, padding_token='[PAD]')
60
+
61
+ idx_seq = vocab_mlm.__getitem__(padded_seq) # [b, 54] start, cls_label, act_label, sequence, end
62
+
63
+ idx_msa = vocab_mlm.__getitem__(padded_msa) # [b, 3, 50]
64
+
65
+ return padded_seq, idx_seq, idx_msa, attn_idx
66
+
67
+ def get_paded_token_idx_gen(vocab_mlm, seq):
68
+
69
+ splited_seq = split_seq(seq[0], vocab_mlm, True)
70
+ splited_msa = split_seq(seq[0], vocab_mlm, False)
71
+
72
+ vocab_mlm.set_get_attn(is_get=True)
73
+ padded_seq = vocab_mlm.truncate_pad(splited_seq, num_steps=54, padding_token='[PAD]')
74
+ attn_idx = vocab_mlm.get_attention_mask_mat()
75
+
76
+ vocab_mlm.set_get_attn(is_get=False)
77
+ padded_msa = vocab_mlm.truncate_pad(splited_msa, num_steps=54, padding_token='[PAD]')
78
+
79
+ idx_seq = vocab_mlm.__getitem__(padded_seq) # [b, 54] start, cls_label, act_label, sequence, end
80
+
81
+ idx_msa = vocab_mlm.__getitem__(padded_msa) # [b, 3, 50]
82
+
83
+ return padded_seq, idx_seq, idx_msa, attn_idx
84
+
85
+
86
+ def get_paded_token_idx_gen(vocab_mlm, seq, new_seq):
87
+ if new_seq == None:
88
+ splited_seq = split_seq(seq[0], vocab_mlm, True)
89
+ splited_msa = split_seq(seq[0], vocab_mlm, False)
90
+
91
+ vocab_mlm.set_get_attn(is_get=True)
92
+ padded_seq = vocab_mlm.truncate_pad(splited_seq, num_steps=54, padding_token='[PAD]')
93
+ attn_idx = vocab_mlm.get_attention_mask_mat()
94
+ vocab_mlm.set_get_attn(is_get=False)
95
+
96
+ padded_msa = vocab_mlm.truncate_pad(splited_msa, num_steps=54, padding_token='[PAD]')
97
+
98
+ idx_seq = vocab_mlm.__getitem__(padded_seq) # [b, 54] start, cls_label, act_label, sequence, end
99
+ idx_msa = vocab_mlm.__getitem__(padded_msa) # [b, 3, 50]
100
+ else:
101
+ splited_seq = split_seq(seq[0], vocab_mlm, True)
102
+ splited_msa = split_seq(seq[0], vocab_mlm, False)
103
+ vocab_mlm.set_get_attn(is_get=True)
104
+ padded_seq = vocab_mlm.truncate_pad(splited_seq, num_steps=54, padding_token='[PAD]')
105
+ attn_idx = vocab_mlm.get_attention_mask_mat()
106
+ vocab_mlm.set_get_attn(is_get=False)
107
+ padded_msa = vocab_mlm.truncate_pad(splited_msa, num_steps=54, padding_token='[PAD]')
108
+ idx_msa = vocab_mlm.__getitem__(padded_msa) # [b, 3, 50]
109
+
110
+ idx_seq = vocab_mlm.__getitem__(new_seq)
111
+ return padded_seq, idx_seq, idx_msa, attn_idx
112
+
113
+
114
+
115
+ def make_mask(seq_ser, start, end, time, vocab_mlm, labels, idx_msa, attn_idx):
116
+ seq_ser = pd.Series(seq_ser)
117
+ masked_seq = seq_ser.apply(mask, args=(start, end, time))
118
+ masked_idx = vocab_mlm.__getitem__(list(masked_seq))
119
+ masked_idx = torch.tensor(masked_idx)
120
+ device = torch.device('cuda:1')
121
+ data_arrays = (masked_idx.to(device), labels.to(device), idx_msa.to(device), attn_idx.to(device))
122
+ dataset = TensorDataset(*data_arrays)
123
+ train_dataset, test_dataset = train_test_split(dataset, test_size=0.1, random_state=42, shuffle=True)
124
+ train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
125
+ test_loader = DataLoader(test_dataset, batch_size=128, shuffle=True)
126
+
127
+ return train_loader, test_loader
128
+
129
+ if __name__ == '__main__':
130
+ # from add_args import parse_args
131
+ import numpy as np
132
+ # args = parse_args()
133
+
134
+ vocab_mlm = create_vocab()
135
+ vocab_mlm = add_tokens_to_vocab(vocab_mlm)
136
+ padded_seq, idx_seq, idx_msa, attn_idx = get_paded_token_idx(vocab_mlm)
137
+ labels = torch.tensor(idx_seq)
138
+ idx_msa = torch.tensor(idx_msa)
139
+ attn_idx = torch.tensor(attn_idx)
140
+
141
+ # time_step = args.mask_time_step
142
+ for t in np.arange(1, 50):
143
+ padded_seq_copy = deepcopy(padded_seq)
144
+ train_loader, test_loader = make_mask(padded_seq_copy, start=0, end=49, time=t,
145
+ vocab_mlm=vocab_mlm, labels=labels, idx_msa=idx_msa, attn_idx=attn_idx)
146
+ for i, (masked_idx, label, msa, attn) in enumerate(train_loader):
147
+ print(f"the {i}th batch is that masked_idx is {masked_idx.shape}, labels is {label.shape}, idx_msa is {msa.shape}")
148
+ print(f"the {t}th time step is done")
149
+
150
+
151
+
model.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import copy, math
3
+ import torch
4
+ import numpy as np
5
+ import torch.nn.functional as F
6
+ from transformers import AutoModelForMaskedLM, AutoConfig
7
+
8
+ from bertmodel import make_bert, make_bert_without_emb
9
+ from utils import ContraLoss
10
+
11
+ def load_pretrained_model():
12
+ # model_checkpoint = "/home/ubuntu/work/zq/conoMLM/prot_bert/prot_bert"
13
+ model_checkpoint = "/home/ubuntu/work/gecheng/conoGen_final/FinalCono/MLM/prot_bert_finetuned_model_mlm_best"
14
+ config = AutoConfig.from_pretrained(model_checkpoint)
15
+ model = AutoModelForMaskedLM.from_config(config)
16
+
17
+ return model
18
+
19
+ class ConoEncoder(nn.Module):
20
+ def __init__(self, encoder):
21
+ super(ConoEncoder, self).__init__()
22
+
23
+ self.encoder = encoder
24
+ self.trainable_encoder = make_bert_without_emb()
25
+
26
+
27
+ for param in self.encoder.parameters():
28
+ param.requires_grad = False
29
+
30
+
31
+ def forward(self, x, mask): # x:(128,54) mask:(128,54)
32
+ feat = self.encoder(x, attention_mask=mask) # (128,54,128)
33
+ feat = list(feat.values())[0] # (128,54,128)
34
+
35
+ feat = self.trainable_encoder(feat, mask) # (128,54,128)
36
+
37
+ return feat
38
+
39
+ class MSABlock(nn.Module):
40
+ def __init__(self, in_dim, out_dim, vocab_size):
41
+ super(MSABlock, self).__init__()
42
+ self.embedding = nn.Embedding(vocab_size, in_dim)
43
+ self.mlp = nn.Sequential(
44
+ nn.Linear(in_dim, out_dim),
45
+ nn.LeakyReLU(),
46
+ nn.Linear(out_dim, out_dim)
47
+ )
48
+ self.init()
49
+
50
+ def init(self):
51
+ for layer in self.mlp.children():
52
+ if isinstance(layer, nn.Linear):
53
+ nn.init.xavier_uniform_(layer.weight)
54
+ # nn.init.xavier_uniform_(self.embedding.weight)
55
+
56
+ def forward(self, x): # x: (128,3,54)
57
+ x = self.embedding(x) # x: (128,3,54,128)
58
+ x = self.mlp(x) # x: (128,3,54,128)
59
+ return x
60
+
61
+ class ConoModel(nn.Module):
62
+ def __init__(self, encoder, msa_block, decoder):
63
+ super(ConoModel, self).__init__()
64
+ self.encoder = encoder
65
+ self.msa_block = msa_block
66
+ self.feature_combine = nn.Conv2d(in_channels=4, out_channels=1, kernel_size=1)
67
+ self.decoder = decoder
68
+
69
+ def forward(self, input_ids, msa, attn_idx=None):
70
+ # 仅使用 input_ids 作为输入,获取编码器输出
71
+ encoder_output = self.encoder.forward(input_ids, attn_idx) # (128,54,128)
72
+ msa_output = self.msa_block(msa) # (128,3,54,128)
73
+ # msa_output = torch.mean(msa_output, dim=1)
74
+ encoder_output = encoder_output.view(input_ids.shape[0], 54, -1).unsqueeze(1) # (128,1,54,128)
75
+
76
+ output = torch.cat([encoder_output*5, msa_output], dim=1) # (128,4,54,128)
77
+ output = self.feature_combine(output) # (128,1,54,128)
78
+ output = output.squeeze(1) # (128,54,128)
79
+ # 解码器对编码器的输出进行解码
80
+ logits = self.decoder(output) # (128,54,85)
81
+
82
+ return logits
83
+
84
+ class ContraModel(nn.Module):
85
+ def __init__(self, cono_encoder):
86
+ super(ContraModel, self).__init__()
87
+
88
+ self.contra_loss = ContraLoss()
89
+
90
+ self.encoder1 = cono_encoder
91
+ self.encoder2 = make_bert(404, 6, 128)
92
+
93
+ # contrastive decoder
94
+ self.lstm = nn.LSTM(16, 16, batch_first=True)
95
+ self.contra_decoder = nn.Sequential(
96
+ nn.Linear(128, 64),
97
+ nn.LeakyReLU(),
98
+ nn.Linear(64, 32),
99
+ nn.LeakyReLU(),
100
+ nn.Linear(32, 16),
101
+ nn.LeakyReLU(),
102
+ nn.Dropout(0.1),
103
+ )
104
+
105
+ # classifier
106
+ self.pre_classifer = nn.LSTM(128, 64, batch_first=True)
107
+ self.classifer = nn.Sequential(
108
+ nn.Linear(128, 32),
109
+ nn.LeakyReLU(),
110
+ nn.Linear(32, 6),
111
+ nn.Softmax(dim=-1)
112
+ )
113
+
114
+ self.init()
115
+
116
+ def init(self):
117
+
118
+ for layer in self.contra_decoder.children():
119
+ if isinstance(layer, nn.Linear):
120
+ nn.init.xavier_uniform_(layer.weight)
121
+ for layer in self.classifer.children():
122
+ if isinstance(layer, nn.Linear):
123
+ nn.init.xavier_uniform_(layer.weight)
124
+ for layer in self.pre_classifer.children():
125
+ if isinstance(layer, nn.Linear):
126
+ nn.init.xavier_uniform_(layer.weight)
127
+ for layer in self.lstm.children():
128
+ if isinstance(layer, nn.Linear):
129
+ nn.init.xavier_uniform_(layer.weight)
130
+
131
+ def compute_class_loss(self, feat1, feat2, labels):
132
+ _, cls_feat1= self.pre_classifer(feat1)
133
+ _, cls_feat2 = self.pre_classifer(feat2)
134
+ cls_feat1 = torch.cat([cls_feat1[0], cls_feat1[1]], dim=-1).squeeze(0)
135
+ cls_feat2 = torch.cat([cls_feat2[0], cls_feat2[1]], dim=-1).squeeze(0)
136
+
137
+ cls1_dis = self.classifer(cls_feat1)
138
+ cls2_dis = self.classifer(cls_feat2)
139
+ cls1_loss = F.cross_entropy(cls1_dis, labels.to('cuda:0'))
140
+ cls2_loss = F.cross_entropy(cls2_dis, labels.to('cuda:0'))
141
+
142
+ return cls1_loss, cls2_loss
143
+
144
+ def compute_contrastive_loss(self, feat1, feat2):
145
+
146
+ contra_feat1 = self.contra_decoder(feat1)
147
+ contra_feat2 = self.contra_decoder(feat2)
148
+
149
+ _, feat1 = self.lstm(contra_feat1)
150
+ _, feat2 = self.lstm(contra_feat2)
151
+ feat1 = torch.cat([feat1[0], feat1[1]], dim=-1).squeeze(0)
152
+ feat2 = torch.cat([feat2[0], feat2[1]], dim=-1).squeeze(0)
153
+
154
+ ctr_loss = self.contra_loss(feat1, feat2)
155
+
156
+ return ctr_loss
157
+
158
+ def forward(self, x1, x2, labels=None):
159
+ loss = dict()
160
+
161
+ idx1, attn1 = x1
162
+ idx2, attn2 = x2
163
+ feat1 = self.encoder1(idx1.to('cuda:0'), attn1.to('cuda:0'))
164
+ feat2 = self.encoder2(idx2.to('cuda:0'), attn2.to('cuda:0'))
165
+
166
+ cls1_loss, cls2_loss = self.compute_class_loss(feat1, feat2, labels)
167
+
168
+ ctr_loss = self.compute_contrastive_loss(feat1, feat2)
169
+
170
+ loss['cls1_loss'] = cls1_loss
171
+ loss['cls2_loss'] = cls2_loss
172
+ loss['ctr_loss'] = ctr_loss
173
+
174
+ return loss
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ numpy
2
+ transformers
3
+ torch
4
+ pandas
utils.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import copy, math
3
+ import torch
4
+ import numpy as np
5
+ import torch.nn.functional as F
6
+
7
+ from vocab import PepVocab
8
+
9
+ def create_vocab():
10
+ vocab_mlm = PepVocab()
11
+ vocab_mlm.vocab_from_txt('/home/ubuntu/work/gecheng/conoGen_final/vocab.txt')
12
+ # vocab_mlm.token_to_idx['-'] = 23
13
+ return vocab_mlm
14
+
15
+ def show_parameters(model: nn.Module, show_all=False, show_trainable=True):
16
+
17
+ mlp_pa = {name:param.requires_grad for name, param in model.named_parameters()}
18
+
19
+ if show_all:
20
+ print('All parameters:')
21
+ print(mlp_pa)
22
+
23
+ if show_trainable:
24
+ print('Trainable parameters:')
25
+ print(list(filter(lambda x: x[1], list(mlp_pa.items()))))
26
+
27
+ class ContraLoss(nn.Module):
28
+ def __init__(self, *args, **kwargs) -> None:
29
+ super(ContraLoss, self).__init__(*args, **kwargs)
30
+
31
+ self.temp = 0.07
32
+
33
+ def contrastive_loss(self, proj1, proj2):
34
+ proj1 = F.normalize(proj1, dim=1)
35
+ proj2 = F.normalize(proj2, dim=1)
36
+ dot = torch.matmul(proj1, proj2.T) / self.temp
37
+ dot_max, _ = torch.max(dot, dim=1, keepdim=True)
38
+ dot = dot - dot_max.detach()
39
+
40
+ exp_dot = torch.exp(dot)
41
+ log_prob = torch.diag(dot, 0) - torch.log(exp_dot.sum(1))
42
+ cont_loss = -log_prob.mean()
43
+ return cont_loss
44
+
45
+ def forward(self, x, y, label=None):
46
+ return self.contrastive_loss(x, y)
47
+
48
+
49
+ import numpy as np
50
+ from tqdm import tqdm
51
+ import torch
52
+ import torch.nn as nn
53
+ import random
54
+ from transformers import set_seed
55
+
56
+ def show_parameters(model: nn.Module, show_all=False, show_trainable=True):
57
+
58
+ mlp_pa = {name:param.requires_grad for name, param in model.named_parameters()}
59
+
60
+ if show_all:
61
+ print('All parameters:')
62
+ print(mlp_pa)
63
+
64
+ if show_trainable:
65
+ print('Trainable parameters:')
66
+ print(list(filter(lambda x: x[1], list(mlp_pa.items()))))
67
+
68
+ def extract_args(text):
69
+ str_list = []
70
+ substr = ""
71
+ for s in text:
72
+ if s in ('(', ')', '=', ',', ' ', '\n', "'"):
73
+ if substr != '':
74
+ str_list.append(substr)
75
+ substr = ''
76
+ else:
77
+ substr += s
78
+
79
+ def eval_one_epoch(loader, cono_encoder):
80
+ cono_encoder.eval()
81
+ batch_loss = []
82
+ for i, data in enumerate(tqdm(loader)):
83
+
84
+ loss = cono_encoder.contra_forward(data)
85
+ batch_loss.append(loss.item())
86
+ print(f'[INFO] Test batch {i} loss: {loss.item()}')
87
+
88
+ total_loss = np.mean(batch_loss)
89
+ print(f'[INFO] Total loss: {total_loss}')
90
+ return total_loss
91
+
92
+ def setup_seed(seed):
93
+ torch.manual_seed(seed)
94
+ torch.cuda.manual_seed_all(seed)
95
+ np.random.seed(seed)
96
+ random.seed(seed)
97
+ torch.backends.cudnn.deterministic = True
98
+ set_seed(seed)
99
+
100
+ class CrossEntropyLossWithMask(torch.nn.Module):
101
+ def __init__(self, weight=None):
102
+ super(CrossEntropyLossWithMask, self).__init__()
103
+ self.criterion = nn.CrossEntropyLoss(reduction='none')
104
+
105
+ def forward(self, y_pred, y_true, mask):
106
+ (pos_mask, label_mask, seq_mask) = mask
107
+ loss = self.criterion(y_pred, y_true) # (6912)
108
+
109
+ pos_loss = (loss * pos_mask).sum() / torch.sum(pos_mask)
110
+ label_loss = (loss * label_mask).sum() / torch.sum(label_mask)
111
+ seq_loss = (loss * seq_mask).sum() / torch.sum(seq_mask)
112
+
113
+ loss = pos_loss + label_loss/2 + seq_loss/3
114
+
115
+ return loss
116
+
117
+
118
+ def mask(x, start, end, time):
119
+ ske_pos = np.where(np.array(x)=='C')[0] - start
120
+ lables_pos = np.array([1, 2]) - start
121
+ ske_pos = list(filter(lambda x: end-start >= x >= 0, ske_pos))
122
+ lables_pos = list(filter(lambda x: x >= 0, lables_pos))
123
+ weight = np.ones(end - start+1)
124
+ rand = np.random.rand()
125
+ if rand < 0.5:
126
+ weight[lables_pos] = 100000
127
+ else:
128
+ weight[lables_pos] = 1
129
+ mask_pos = np.random.choice(range(start, end+1), time, p=weight/np.sum(weight), replace=False)
130
+ for idx in mask_pos:
131
+ x[idx] = '[MASK]'
132
+ return x
vocab.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import pandas as pd
3
+
4
+ class PepVocab:
5
+ def __init__(self):
6
+ self.token_to_idx = {
7
+ '<MASK>': -1, '<PAD>': 0, 'A': 1, 'C': 2, 'E': 3, 'D': 4, 'F': 5, 'I': 6, 'H': 7,
8
+ 'K': 8, 'M': 9, 'L': 10, 'N': 11, 'Q': 12, 'P': 13, 'S': 14,
9
+ 'R': 15, 'T': 16, 'W': 17, 'V': 18, 'Y': 19, 'G': 20, 'O': 21, 'U': 22, 'Z': 23, 'X': 24}
10
+ self.idx_to_token = {
11
+ -1: '<MASK>', 0: '<PAD>', 1: 'A', 2: 'C', 3: 'E', 4: 'D', 5: 'F', 6: 'I', 7: 'H',
12
+ 8: 'K', 9: 'M', 10: 'L', 11: 'N', 12: 'Q', 13: 'P', 14: 'S',
13
+ 15: 'R', 16: 'T', 17: 'W', 18: 'V', 19: 'Y', 20: 'G', 21: 'O', 22: 'U', 23: 'Z', 24: 'X'}
14
+
15
+ self.get_attention_mask = False
16
+ self.attention_mask = []
17
+
18
+ def set_get_attn(self, is_get: bool):
19
+ self.get_attention_mask = is_get
20
+
21
+ def __len__(self):
22
+ return len(self.idx_to_token)
23
+
24
+ def __getitem__(self, tokens):
25
+ '''
26
+ note: input should a splited sequence
27
+
28
+ Args:
29
+ tokens: a token or token list of splited
30
+ '''
31
+ if not isinstance(tokens, (list, tuple)):
32
+ # return self.token_to_idx.get(tokens)
33
+ return self.token_to_idx[tokens]
34
+ return [self.__getitem__(token) for token in tokens]
35
+
36
+ def vocab_from_txt(self, path):
37
+ '''
38
+ note: this function use for constructing vocab mapping
39
+ but it is only suitable for special txt format
40
+ it support one column txt file, which column name is 0
41
+ '''
42
+ token_to_idx = {}
43
+ idx_to_token = {}
44
+ chr_idx = pd.read_csv(path, header=None, sep='\t')
45
+ if chr_idx.shape[1] == 1:
46
+ for idx, token in enumerate(chr_idx[0]):
47
+ token_to_idx[token] = idx
48
+ idx_to_token[idx] = token
49
+ self.token_to_idx = token_to_idx
50
+ self.idx_to_token = idx_to_token
51
+
52
+ def to_tokens(self, indices):
53
+ '''
54
+ note: input should a integer list
55
+ '''
56
+ if hasattr(indices, '__len__') and len(indices) > 1:
57
+ return [self.idx_to_token[int(index)] for index in indices]
58
+ return self.idx_to_token[indices]
59
+
60
+ def add_special_token(self, token: str|list|tuple) -> None:
61
+ if not isinstance(token, (list, tuple)):
62
+ if token in self.token_to_idx:
63
+ raise ValueError(f"token {token} already in the vocab")
64
+ self.idx_to_token[len(self.idx_to_token)] = token
65
+ self.token_to_idx[token] = len(self.token_to_idx)
66
+ else:
67
+ [self.add_special_token(t) for t in token]
68
+
69
+ def split_seq(self, seq: str|list|tuple) -> list:
70
+ if not isinstance(seq, (list, tuple)):
71
+ return re.findall(r"<[a-zA-Z0-9]+>|[a-zA-Z-]", seq)
72
+ return [self.split_seq(s) for s in seq] # a list of list
73
+
74
+ def truncate_pad(self, line, num_steps, padding_token='<PAD>') -> list:
75
+
76
+ if not isinstance(line[0], list):
77
+ if len(line) > num_steps:
78
+ if self.get_attention_mask:
79
+ self.attention_mask.append([1]*num_steps)
80
+ return line[:num_steps]
81
+ if self.get_attention_mask:
82
+ self.attention_mask.append([1] * len(line) + [0] * (num_steps - len(line)))
83
+ return line + [padding_token] * (num_steps - len(line))
84
+ else:
85
+ return [self.truncate_pad(l, num_steps, padding_token) for l in line] # a list of list
86
+
87
+ def get_attention_mask_mat(self):
88
+ attention_mask = self.attention_mask
89
+ self.attention_mask = []
90
+ return attention_mask
91
+
92
+ def seq_to_idx(self, seq: str|list|tuple, num_steps: int, padding_token='<PAD>') -> list:
93
+ '''
94
+ note: ensure to execut this function after add_special_token
95
+ '''
96
+
97
+ splited_seq = self.split_seq(seq)
98
+ # **********************
99
+ # after split, we need to mask sequence
100
+ # note:
101
+ # 1. mask tokens by probability
102
+ # 2. return a list or list of list
103
+ # **********************
104
+ padded_seq = self.truncate_pad(splited_seq, num_steps, padding_token)
105
+
106
+ return self.__getitem__(padded_seq)
107
+
108
+
109
+
110
+ class MutilVocab:
111
+ def __init__(self, data, AA_tok_len=2):
112
+ """
113
+ Args:
114
+ data (_type_):
115
+ AA_tok_len (int, optional): Defaults to 1.
116
+ start_token (bool, optional): True is required for encoder-based model.
117
+ """
118
+ ## Load train dataset
119
+ self.x_data = data
120
+ self.tok_AA_len = AA_tok_len
121
+ self.default_AA = list("RHKDESTNQCGPAVILMFYW")
122
+ # AAs which are not included in default_AA
123
+ self.tokens = self._token_gen(self.tok_AA_len)
124
+
125
+ self.token_to_idx = {k: i + 4 for i, k in enumerate(self.tokens)}
126
+ self.token_to_idx["[PAD]"] = 0 ## idx as 0 is PAD
127
+ self.token_to_idx["[CLS]"] = 1 ## idx as 1 is CLS
128
+ self.token_to_idx["[SEP]"] = 2 ## idx as 2 is SEP
129
+ self.token_to_idx["[MASK]"] = 3 ## idx as 3 is MASK
130
+
131
+ def split_seq(self):
132
+ self.X = [self._seq_to_tok(seq) for seq in self.x_data]
133
+ return self.X
134
+
135
+ def tok_idx(self, seqs):
136
+ '''
137
+ note: ensure to execut this function before truancate_pad
138
+ '''
139
+
140
+ seqs_idx = []
141
+ for seq in seqs:
142
+ seq_idx = []
143
+ for s in seq:
144
+ seq_idx.append(self.token_to_idx[s])
145
+ seqs_idx.append(seq_idx)
146
+
147
+ return seqs_idx
148
+
149
+
150
+
151
+ def _token_gen(self, tok_AA_len: int, st: str = "", curr_depth: int = 0):
152
+ """Generate tokens based on default amino acid residues
153
+ and also includes "X" as arbitrary residues.
154
+ Length of AAs in each token should be provided by "tok_AA_len"
155
+
156
+ Args:
157
+ tok_AA_len (int): Length of token
158
+ st (str, optional): Defaults to ''.
159
+ curr_depth (int, optional): Defaults to 0.
160
+
161
+ Returns:
162
+ List: List of tokens
163
+ """
164
+ curr_depth += 1
165
+ if curr_depth <= tok_AA_len:
166
+ l = [
167
+ st + t
168
+ for s in self.default_AA
169
+ for t in self._token_gen(tok_AA_len, s, curr_depth)
170
+ ]
171
+ return l
172
+ else:
173
+ return [st]
174
+
175
+ def _seq_to_tok(self, seq: str):
176
+ """Convert each token to index
177
+
178
+ Args:
179
+ seq (str): AA sequence
180
+
181
+ Returns:
182
+ list: A list of indexes
183
+ """
184
+
185
+ seq_idx = []
186
+
187
+ seq_idx += ["[CLS]"]
188
+
189
+ for i in range(len(seq) - self.tok_AA_len + 1):
190
+ curr_token = seq[i : i + self.tok_AA_len]
191
+ seq_idx.append(curr_token)
192
+ seq_idx += ['[SEP]']
193
+ return seq_idx
vocab.txt ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [PAD]
2
+ [UNK]
3
+ [CLS]
4
+ [SEP]
5
+ [MASK]
6
+ L
7
+ A
8
+ G
9
+ V
10
+ E
11
+ S
12
+ I
13
+ K
14
+ R
15
+ D
16
+ T
17
+ P
18
+ N
19
+ Q
20
+ F
21
+ Y
22
+ M
23
+ H
24
+ C
25
+ W
26
+ X
27
+ U
28
+ B
29
+ Z
30
+ O