oucgc1996 committed
Commit 2452cd6 · verified · 1 Parent(s): e13ac09

Update app.py

Files changed (1)
  1. app.py +123 -123
app.py CHANGED
@@ -1,124 +1,124 @@
 import random
 import torch
 import gradio as gr
 import pandas as pd
 from utils import create_vocab, setup_seed
 from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab

 seed = random.randint(0,99999999)

 setup_seed(seed)
 device = torch.device("cpu")
 vocab_mlm = create_vocab()
 vocab_mlm = add_tokens_to_vocab(vocab_mlm)
 save_path = 'mlm-model-27.pt' #1
 train_seqs = pd.read_csv('C0_seq.csv') #2
 train_seq = train_seqs['Seq'].tolist()
 model = torch.load(save_path, weights_only=False, map_location=torch.device('cpu'))
 model = model.to(device)

 def temperature_sampling(logits, temperature):
     logits = logits / temperature
     probabilities = torch.softmax(logits, dim=-1)
     sampled_token = torch.multinomial(probabilities, 1)
     return sampled_token

 def CTXGen(τ, g_num, start, end):
     X1 = "X"
     X2 = "X"
     X4 = ""
     X5 = ""
     X6 = ""
     model.eval()
     with torch.no_grad():
         new_seq = None
         generated_seqs = []
         generated_seqs_FINAL = []
         cls_pos_all = []
         cls_probability_all = []
         act_pos_all = []
         act_probability_all = []

         count = 0
         gen_num = int(g_num)
         NON_AA = ["B","O","U","Z","X",'<K16>', '<α1β1γδ>', '<Ca22>', '<AChBP>', '<K13>', '<α1BAR>', '<α1β1ε>', '<α1AAR>', '<GluN3A>', '<α4β2>',
                   '<GluN2B>', '<α75HT3>', '<Na14>', '<α7>', '<GluN2C>', '<NET>', '<NavBh>', '<α6β3β4>', '<Na11>', '<Ca13>',
                   '<Ca12>', '<Na16>', '<α6α3β2>', '<GluN2A>', '<GluN2D>', '<K17>', '<α1β1δε>', '<GABA>', '<α9>', '<K12>',
                   '<Kshaker>', '<α3β4>', '<Na18>', '<α3β2>', '<α6α3β2β3>', '<α1β1δ>', '<α6α3β4β3>', '<α2β2>','<α6β4>', '<α2β4>',
                   '<Na13>', '<Na12>', '<Na15>', '<α4β4>', '<α7α6β2>', '<α1β1γ>', '<NaTTXR>', '<K11>', '<Ca23>',
                   '<α9α10>','<α6α3β4>', '<NaTTXS>', '<Na17>','<high>','<low>','[UNK]','[SEP]','[PAD]','[CLS]','[MASK]']

         while count < gen_num:
             gen_len = random.randint(int(start), int(end))
             X3 = "X" * gen_len
             seq = [f"{X1}|{X2}|{X3}|{X4}|{X5}|{X6}"]
             vocab_mlm.token_to_idx["X"] = 4

             padded_seq, _, _, _ = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
             input_text = ["[MASK]" if i=="X" else i for i in padded_seq]

             gen_length = len(input_text)
             length = gen_length - sum(1 for x in input_text if x != '[MASK]')

             for i in range(length):
                 _, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
                 idx_seq = torch.tensor(idx_seq).unsqueeze(0).to(device)
                 idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
                 attn_idx = torch.tensor(attn_idx).to(device)

                 mask_positions = [j for j in range(gen_length) if input_text[j] == "[MASK]"]
                 mask_position = torch.tensor([mask_positions[torch.randint(len(mask_positions), (1,))]])

                 logits = model(idx_seq, idx_msa, attn_idx)
                 mask_logits = logits[0, mask_position.item(), :]

                 predicted_token_id = temperature_sampling(mask_logits, τ)

                 predicted_token = vocab_mlm.to_tokens(int(predicted_token_id))
                 input_text[mask_position.item()] = predicted_token
                 padded_seq[mask_position.item()] = predicted_token.strip()
                 new_seq = padded_seq

             generated_seq = input_text

             generated_seq[1] = "[MASK]"
             generated_seq[2] = "[MASK]"
             input_ids = vocab_mlm.__getitem__(generated_seq)
             logits = model(torch.tensor([input_ids]).to(device), idx_msa)

             cls_mask_logits = logits[0, 1, :]
             act_mask_logits = logits[0, 2, :]

             cls_probability, cls_mask_probs = torch.topk((torch.softmax(cls_mask_logits, dim=-1)), k=1)
             act_probability, act_mask_probs = torch.topk((torch.softmax(act_mask_logits, dim=-1)), k=1)

             cls_pos = vocab_mlm.idx_to_token[cls_mask_probs[0].item()]
             act_pos = vocab_mlm.idx_to_token[act_mask_probs[0].item()]

             cls_probability = cls_probability[0].item()
             act_probability = act_probability[0].item()
             generated_seq = generated_seq[generated_seq.index('[MASK]') + 2:generated_seq.index('[SEP]')]
             if generated_seq.count('C') % 2 == 0 and len("".join(generated_seq)) == gen_len:
                 generated_seqs.append("".join(generated_seq))
                 if "".join(generated_seq) not in train_seq and "".join(generated_seq) not in generated_seqs[0:-1] and all(x not in NON_AA for x in generated_seq):
                     generated_seqs_FINAL.append("".join(generated_seq))
                     cls_pos_all.append(cls_pos)
                     cls_probability_all.append(cls_probability)
                     act_pos_all.append(act_pos)
                     act_probability_all.append(act_probability)
                     out = pd.DataFrame({'Generated_seq': generated_seqs_FINAL, 'Subtype': cls_pos_all, 'Subtype_probability': cls_probability_all, 'Potency': act_pos_all, 'Potency_probability': act_probability_all, 'random_seed': seed})
                     out.to_csv("output.csv", index=False)
                     count += 1
     return 'output.csv'

 iface = gr.Interface(
     fn=CTXGen,
     inputs=[
         gr.Slider(minimum=1, maximum=2, step=0.01, label="τ"),
-        gr.Dropdown(choices=[1,10,100,1000], label="Number of generations"),
+        gr.Dropdown(choices=[1,10,100], label="Number of generations"),
         gr.Textbox(label="Min length"),
         gr.Textbox(label="Max length")
     ],
     outputs=["file"]
 )
 iface.launch()
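
For readers skimming the diff: the net change in this commit is that the 1000 option was removed from the "Number of generations" dropdown; the rest of app.py is unchanged. Since temperature_sampling is the heart of the generation loop, below is a minimal, self-contained sketch (not part of the commit) of how that step behaves, using made-up logits and assuming only PyTorch. Within the slider's [1, 2] range, a larger τ flattens the distribution and yields more diverse tokens.

import torch

def temperature_sampling(logits, temperature):
    # Same logic as in app.py: rescale logits by τ, then draw one token id
    probabilities = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probabilities, 1)

# Hypothetical 4-token vocabulary; values chosen only for illustration
logits = torch.tensor([2.0, 1.0, 0.5, 0.1])
for tau in (1.0, 1.5, 2.0):  # the app's τ slider covers [1, 2]
    probs = torch.softmax(logits / tau, dim=-1)
    sample = temperature_sampling(logits, tau).item()
    print(f"τ={tau}: probs={[round(p, 3) for p in probs.tolist()]}, sampled id={sample}")

Running the sketch shows the top token's probability shrinking as τ grows, which is why the app exposes τ as a diversity knob alongside the generation count.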