import copy
import math
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from transformers import set_seed

from vocab import PepVocab
def create_vocab():
    vocab_mlm = PepVocab()
    vocab_mlm.vocab_from_txt('vocab.txt')
    # vocab_mlm.token_to_idx['-'] = 23
    return vocab_mlm
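# Usage sketch (assumes vocab.txt sits in the working directory and that
# PepVocab exposes a token_to_idx mapping, as the commented line above
# suggests):
#   vocab = create_vocab()
#   idx = vocab.token_to_idx['C']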
def show_parameters(model: nn.Module, show_all=False, show_trainable=True):
    mlp_pa = {name: param.requires_grad for name, param in model.named_parameters()}
    if show_all:
        print('All parameters:')
        print(mlp_pa)
    if show_trainable:
        print('Trainable parameters:')
        print(list(filter(lambda x: x[1], mlp_pa.items())))
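# Minimal sketch: freeze the first layer of a toy model and confirm that only
# the second layer's parameters are reported as trainable.
def _show_parameters_demo():
    model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
    for param in model[0].parameters():
        param.requires_grad = False
    show_parameters(model)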
class ContraLoss(nn.Module):
    """InfoNCE-style contrastive loss: matched rows of the two projection
    batches are treated as positive pairs, all other rows as negatives."""

    def __init__(self, *args, **kwargs) -> None:
        super(ContraLoss, self).__init__(*args, **kwargs)
        self.temp = 0.07  # temperature for scaling the similarity logits

    def contrastive_loss(self, proj1, proj2):
        proj1 = F.normalize(proj1, dim=1)
        proj2 = F.normalize(proj2, dim=1)
        dot = torch.matmul(proj1, proj2.T) / self.temp
        # Subtract the per-row max for numerical stability (log-sum-exp trick).
        dot_max, _ = torch.max(dot, dim=1, keepdim=True)
        dot = dot - dot_max.detach()
        exp_dot = torch.exp(dot)
        # Log-probability of the positive (diagonal) entry in each row.
        log_prob = torch.diag(dot, 0) - torch.log(exp_dot.sum(1))
        cont_loss = -log_prob.mean()
        return cont_loss

    def forward(self, x, y, label=None):
        # label is accepted but not used by this loss.
        return self.contrastive_loss(x, y)
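# Minimal usage sketch with random data: two batches of eight 16-dimensional
# projections; each row of proj1 is the positive for the same row of proj2.
def _contra_loss_demo():
    proj1 = torch.randn(8, 16)
    proj2 = torch.randn(8, 16)
    loss = ContraLoss()(proj1, proj2)
    print(f'contrastive loss: {loss.item():.4f}')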
def extract_args(text):
    str_list = []
    substr = ""
    for s in text:
        if s in ('(', ')', '=', ',', ' ', '\n', "'"):
            if substr != '':
                str_list.append(substr)
                substr = ''
        else:
            substr += s
    # Flush a trailing token that is not followed by a delimiter.
    if substr != '':
        str_list.append(substr)
    return str_list
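# Example: extract_args("f(a=1, b='x')") returns ['f', 'a', '1', 'b', 'x'].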
def eval_one_epoch(loader, cono_encoder):
    cono_encoder.eval()
    batch_loss = []
    with torch.no_grad():  # no gradients are needed during evaluation
        for i, data in enumerate(tqdm(loader)):
            loss = cono_encoder.contra_forward(data)
            batch_loss.append(loss.item())
            print(f'[INFO] Test batch {i} loss: {loss.item()}')
    total_loss = np.mean(batch_loss)
    print(f'[INFO] Total loss: {total_loss}')
    return total_loss
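# Usage sketch (assumes a torch.utils.data.DataLoader over the test set and a
# model exposing the contra_forward method used above):
#   test_loss = eval_one_epoch(test_loader, cono_encoder)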
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    set_seed(seed)
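# Call once, before building models and dataloaders, so that all RNGs
# (PyTorch CPU/CUDA, NumPy, Python, and transformers) are seeded consistently:
#   setup_seed(42)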
class CrossEntropyLossWithMask(torch.nn.Module):
    """Cross-entropy over flattened token positions, averaged separately over
    three binary masks and combined with fixed weights."""

    def __init__(self, weight=None):
        super(CrossEntropyLossWithMask, self).__init__()
        self.criterion = nn.CrossEntropyLoss(weight=weight, reduction='none')

    def forward(self, y_pred, y_true, mask):
        (pos_mask, label_mask, seq_mask) = mask
        loss = self.criterion(y_pred, y_true)  # per-position loss, e.g. (6912,)
        pos_loss = (loss * pos_mask).sum() / torch.sum(pos_mask)
        label_loss = (loss * label_mask).sum() / torch.sum(label_mask)
        seq_loss = (loss * seq_mask).sum() / torch.sum(seq_mask)
        loss = pos_loss + label_loss / 2 + seq_loss / 3
        return loss
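# Minimal sketch with random data: 12 flattened positions over a hypothetical
# 21-token vocabulary, with three masks selecting disjoint position groups.
def _masked_ce_demo():
    y_pred = torch.randn(12, 21)
    y_true = torch.randint(0, 21, (12,))
    pos_mask = torch.zeros(12)
    pos_mask[:4] = 1
    label_mask = torch.zeros(12)
    label_mask[4:6] = 1
    seq_mask = torch.zeros(12)
    seq_mask[6:] = 1
    loss = CrossEntropyLossWithMask()(y_pred, y_true, (pos_mask, label_mask, seq_mask))
    print(f'masked CE loss: {loss.item():.4f}')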
def mask(x, start, end, time):
    # Positions of the cysteine skeleton ('C') relative to start; computed but
    # not used in the sampling weights below.
    ske_pos = np.where(np.array(x) == 'C')[0] - start
    labels_pos = np.array([1, 2]) - start
    ske_pos = list(filter(lambda p: end - start >= p >= 0, ske_pos))
    labels_pos = list(filter(lambda p: p >= 0, labels_pos))
    weight = np.ones(end - start + 1)
    # With probability 0.5, heavily up-weight the label positions so they are
    # almost certainly masked; otherwise sample positions uniformly.
    rand = np.random.rand()
    if rand < 0.5:
        weight[labels_pos] = 100000
    else:
        weight[labels_pos] = 1
    mask_pos = np.random.choice(range(start, end + 1), time, p=weight / np.sum(weight), replace=False)
    for idx in mask_pos:
        x[idx] = '[MASK]'
    return x
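# Minimal sketch: mask 3 positions of a toy cysteine-rich sequence. The
# sequence and window below are illustrative, not taken from the real data.
def _mask_demo():
    seq = ['G', 'C', 'C', 'S', 'D', 'P', 'R', 'C', 'A', 'W', 'R', 'C']
    print(mask(seq, start=0, end=len(seq) - 1, time=3))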