import copy
import math
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from transformers import set_seed

from vocab import PepVocab

def create_vocab():
    """Build the peptide vocabulary from the bundled vocab.txt file."""
    vocab_mlm = PepVocab()
    vocab_mlm.vocab_from_txt('vocab.txt')
    # vocab_mlm.token_to_idx['-'] = 23
    return vocab_mlm

def show_parameters(model: nn.Module, show_all=False, show_trainable=True):
    """Print the requires_grad status of a model's parameters."""
    mlp_pa = {name: param.requires_grad for name, param in model.named_parameters()}
    if show_all:
        print('All parameters:')
        print(mlp_pa)
    if show_trainable:
        print('Trainable parameters:')
        print(list(filter(lambda x: x[1], mlp_pa.items())))
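
# Illustrative usage sketch (the toy model below is a stand-in, not part of the
# original module): freeze one layer and inspect which parameters remain trainable.
#
#     toy = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 2))
#     toy[0].weight.requires_grad = False    # freeze the first weight matrix
#     show_parameters(toy, show_all=True)    # prints every parameter's status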

class ContraLoss(nn.Module):
    """InfoNCE-style contrastive loss between two batches of projections."""

    def __init__(self, *args, **kwargs) -> None:
        super(ContraLoss, self).__init__(*args, **kwargs)
        self.temp = 0.07  # temperature applied to the similarity logits

    def contrastive_loss(self, proj1, proj2):
        # Cosine similarity matrix between the two projection batches.
        proj1 = F.normalize(proj1, dim=1)
        proj2 = F.normalize(proj2, dim=1)
        dot = torch.matmul(proj1, proj2.T) / self.temp
        # Subtract the per-row maximum for numerical stability.
        dot_max, _ = torch.max(dot, dim=1, keepdim=True)
        dot = dot - dot_max.detach()
        exp_dot = torch.exp(dot)
        # Log-probability of the matching (diagonal) pair within each row.
        log_prob = torch.diag(dot, 0) - torch.log(exp_dot.sum(1))
        cont_loss = -log_prob.mean()
        return cont_loss

    def forward(self, x, y, label=None):
        return self.contrastive_loss(x, y)
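
# Illustrative usage sketch (the projection tensors below are random stand-ins,
# not part of the original pipeline): row i of proj_a is treated as the positive
# pair of row i in proj_b, and every other row as a negative.
#
#     contra = ContraLoss()
#     proj_a = torch.randn(8, 128)   # (batch, embedding_dim)
#     proj_b = torch.randn(8, 128)   # paired views of the same 8 samples
#     loss = contra(proj_a, proj_b)  # scalar tensor, lower when pairs align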

def extract_args(text):
    """Split a call-like string (e.g. "f(a=1, b='x')") into bare tokens."""
    str_list = []
    substr = ""
    for s in text:
        if s in ('(', ')', '=', ',', ' ', '\n', "'"):
            # Delimiter reached: flush the token collected so far.
            if substr != '':
                str_list.append(substr)
                substr = ''
        else:
            substr += s
    # Flush any trailing token and return the collected list.
    if substr != '':
        str_list.append(substr)
    return str_list
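
# Illustrative usage sketch (the example string is a stand-in):
#
#     extract_args("Linear(in_features=16, out_features=8, bias=True)")
#     # -> ['Linear', 'in_features', '16', 'out_features', '8', 'bias', 'True']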

def eval_one_epoch(loader, cono_encoder):
    """Run one evaluation pass and return the mean contrastive loss."""
    cono_encoder.eval()
    batch_loss = []
    for i, data in enumerate(tqdm(loader)):
        loss = cono_encoder.contra_forward(data)
        batch_loss.append(loss.item())
        print(f'[INFO] Test batch {i} loss: {loss.item()}')
    total_loss = np.mean(batch_loss)
    print(f'[INFO] Total loss: {total_loss}')
    return total_loss

def setup_seed(seed):
    """Seed every RNG used in training (torch, CUDA, numpy, random, transformers)."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    set_seed(seed)

class CrossEntropyLossWithMask(torch.nn.Module):
    """Cross-entropy averaged separately over three position masks, then combined."""

    def __init__(self, weight=None):
        super(CrossEntropyLossWithMask, self).__init__()
        self.criterion = nn.CrossEntropyLoss(reduction='none')

    def forward(self, y_pred, y_true, mask):
        (pos_mask, label_mask, seq_mask) = mask
        loss = self.criterion(y_pred, y_true)  # per-position losses, e.g. shape (6912,)
        # Mean loss over each masked region, then a weighted sum of the three terms.
        pos_loss = (loss * pos_mask).sum() / torch.sum(pos_mask)
        label_loss = (loss * label_mask).sum() / torch.sum(label_mask)
        seq_loss = (loss * seq_mask).sum() / torch.sum(seq_mask)
        loss = pos_loss + label_loss / 2 + seq_loss / 3
        return loss
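
# Illustrative usage sketch (all tensors below are random stand-ins): logits and
# targets are flattened to one row per position, and each mask selects the
# positions contributing to its term of the combined loss.
#
#     criterion = CrossEntropyLossWithMask()
#     y_pred = torch.randn(10, 20)             # (positions, vocab_size)
#     y_true = torch.randint(0, 20, (10,))     # gold token ids
#     pos_mask = torch.zeros(10); pos_mask[:3] = 1.0
#     label_mask = torch.zeros(10); label_mask[3:5] = 1.0
#     seq_mask = torch.ones(10)
#     loss = criterion(y_pred, y_true, (pos_mask, label_mask, seq_mask))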

def mask(x, start, end, time):
    """Replace `time` randomly chosen tokens of x within [start, end] with '[MASK]'."""
    # Positions of the skeleton token 'C' and of the label tokens (indices 1, 2),
    # expressed relative to the window start.
    ske_pos = np.where(np.array(x) == 'C')[0] - start
    labels_pos = np.array([1, 2]) - start
    ske_pos = list(filter(lambda p: end - start >= p >= 0, ske_pos))
    labels_pos = list(filter(lambda p: p >= 0, labels_pos))
    # With probability 0.5, heavily up-weight the label positions so they are
    # almost always among the masked tokens; otherwise sample uniformly.
    weight = np.ones(end - start + 1)
    rand = np.random.rand()
    if rand < 0.5:
        weight[labels_pos] = 100000
    else:
        weight[labels_pos] = 1
    mask_pos = np.random.choice(range(start, end + 1), time, p=weight / np.sum(weight), replace=False)
    for idx in mask_pos:
        x[idx] = '[MASK]'
    return x
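
# Illustrative usage sketch (the toy token list below is a stand-in for a
# tokenized peptide sequence): mask 2 positions inside the window [3, 8].
# Note that x is modified in place as well as returned.
#
#     seq = ['[CLS]', 'A', 'C', 'D', 'E', 'F', 'G', 'C', 'H', '[SEP]']
#     masked = mask(seq, start=3, end=8, time=2)
#     # e.g. ['[CLS]', 'A', 'C', 'D', '[MASK]', 'F', '[MASK]', 'C', 'H', '[SEP]']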