oucgc1996 commited on
Commit
2d573ef
·
verified ·
1 Parent(s): 5ab3f8b

Delete utils.py

Browse files
Files changed (1) hide show
  1. utils.py +0 -132
utils.py DELETED
@@ -1,132 +0,0 @@
1
- import torch.nn as nn
2
- import copy, math
3
- import torch
4
- import numpy as np
5
- import torch.nn.functional as F
6
-
7
- from vocab import PepVocab
8
-
9
- def create_vocab():
10
- vocab_mlm = PepVocab()
11
- vocab_mlm.vocab_from_txt('/home/ubuntu/work/gecheng/conoGen_final/vocab.txt')
12
- # vocab_mlm.token_to_idx['-'] = 23
13
- return vocab_mlm
14
-
15
- def show_parameters(model: nn.Module, show_all=False, show_trainable=True):
16
-
17
- mlp_pa = {name:param.requires_grad for name, param in model.named_parameters()}
18
-
19
- if show_all:
20
- print('All parameters:')
21
- print(mlp_pa)
22
-
23
- if show_trainable:
24
- print('Trainable parameters:')
25
- print(list(filter(lambda x: x[1], list(mlp_pa.items()))))
26
-
27
- class ContraLoss(nn.Module):
28
- def __init__(self, *args, **kwargs) -> None:
29
- super(ContraLoss, self).__init__(*args, **kwargs)
30
-
31
- self.temp = 0.07
32
-
33
- def contrastive_loss(self, proj1, proj2):
34
- proj1 = F.normalize(proj1, dim=1)
35
- proj2 = F.normalize(proj2, dim=1)
36
- dot = torch.matmul(proj1, proj2.T) / self.temp
37
- dot_max, _ = torch.max(dot, dim=1, keepdim=True)
38
- dot = dot - dot_max.detach()
39
-
40
- exp_dot = torch.exp(dot)
41
- log_prob = torch.diag(dot, 0) - torch.log(exp_dot.sum(1))
42
- cont_loss = -log_prob.mean()
43
- return cont_loss
44
-
45
- def forward(self, x, y, label=None):
46
- return self.contrastive_loss(x, y)
47
-
48
-
49
- import numpy as np
50
- from tqdm import tqdm
51
- import torch
52
- import torch.nn as nn
53
- import random
54
- from transformers import set_seed
55
-
56
- def show_parameters(model: nn.Module, show_all=False, show_trainable=True):
57
-
58
- mlp_pa = {name:param.requires_grad for name, param in model.named_parameters()}
59
-
60
- if show_all:
61
- print('All parameters:')
62
- print(mlp_pa)
63
-
64
- if show_trainable:
65
- print('Trainable parameters:')
66
- print(list(filter(lambda x: x[1], list(mlp_pa.items()))))
67
-
68
- def extract_args(text):
69
- str_list = []
70
- substr = ""
71
- for s in text:
72
- if s in ('(', ')', '=', ',', ' ', '\n', "'"):
73
- if substr != '':
74
- str_list.append(substr)
75
- substr = ''
76
- else:
77
- substr += s
78
-
79
- def eval_one_epoch(loader, cono_encoder):
80
- cono_encoder.eval()
81
- batch_loss = []
82
- for i, data in enumerate(tqdm(loader)):
83
-
84
- loss = cono_encoder.contra_forward(data)
85
- batch_loss.append(loss.item())
86
- print(f'[INFO] Test batch {i} loss: {loss.item()}')
87
-
88
- total_loss = np.mean(batch_loss)
89
- print(f'[INFO] Total loss: {total_loss}')
90
- return total_loss
91
-
92
- def setup_seed(seed):
93
- torch.manual_seed(seed)
94
- torch.cuda.manual_seed_all(seed)
95
- np.random.seed(seed)
96
- random.seed(seed)
97
- torch.backends.cudnn.deterministic = True
98
- set_seed(seed)
99
-
100
- class CrossEntropyLossWithMask(torch.nn.Module):
101
- def __init__(self, weight=None):
102
- super(CrossEntropyLossWithMask, self).__init__()
103
- self.criterion = nn.CrossEntropyLoss(reduction='none')
104
-
105
- def forward(self, y_pred, y_true, mask):
106
- (pos_mask, label_mask, seq_mask) = mask
107
- loss = self.criterion(y_pred, y_true) # (6912)
108
-
109
- pos_loss = (loss * pos_mask).sum() / torch.sum(pos_mask)
110
- label_loss = (loss * label_mask).sum() / torch.sum(label_mask)
111
- seq_loss = (loss * seq_mask).sum() / torch.sum(seq_mask)
112
-
113
- loss = pos_loss + label_loss/2 + seq_loss/3
114
-
115
- return loss
116
-
117
-
118
- def mask(x, start, end, time):
119
- ske_pos = np.where(np.array(x)=='C')[0] - start
120
- lables_pos = np.array([1, 2]) - start
121
- ske_pos = list(filter(lambda x: end-start >= x >= 0, ske_pos))
122
- lables_pos = list(filter(lambda x: x >= 0, lables_pos))
123
- weight = np.ones(end - start+1)
124
- rand = np.random.rand()
125
- if rand < 0.5:
126
- weight[lables_pos] = 100000
127
- else:
128
- weight[lables_pos] = 1
129
- mask_pos = np.random.choice(range(start, end+1), time, p=weight/np.sum(weight), replace=False)
130
- for idx in mask_pos:
131
- x[idx] = '[MASK]'
132
- return x