oucgc1996 committed on
Commit
b80b30e
·
verified ·
1 Parent(s): b60c479

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -55
app.py CHANGED
@@ -1,55 +1,74 @@
1
- import torch
2
- import gradio as gr
3
- from utils import create_vocab, setup_seed
4
- from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab
5
# Fix the random seed before anything else runs.
setup_seed(4)

# Inference runs on the CPU only.
device = torch.device("cpu")

# Build the base vocabulary, then extend it with the extra tokens supplied by
# dataset_mlm.add_tokens_to_vocab.
vocab_mlm = create_vocab()
vocab_mlm = add_tokens_to_vocab(vocab_mlm)

save_path = 'mlm-model-27.pt'

# map_location remaps tensors saved on another device (e.g. CUDA) onto the CPU;
# without it a GPU-saved checkpoint cannot be loaded on a CPU-only host.
# weights_only=False is needed because the file holds a whole pickled model
# object rather than a state_dict (PyTorch >= 2.6 defaults to weights_only=True).
# SECURITY: this unpickles arbitrary Python objects -- only load trusted files.
model = torch.load(save_path, map_location=device, weights_only=False)
model = model.to(device)
12
-
13
def CTXGen(X1, X2, X3, top_k):
    """Fill every "X" placeholder of the query ``X1|X2|X3|||`` with the model.

    Placeholders are predicted one at a time, left to right; each prediction
    is written back into the sequence before the next position is predicted.

    Args:
        X1: Subtype label, or "X" to have it predicted.
        X2: Potency label, or "X" to have it predicted.
        X3: Sequence text; it may contain "X" placeholders as well.
        top_k: Number of candidates to keep per placeholder. The UI delivers
            it as text, so anything accepted by ``int()`` is allowed.

    Returns:
        A ``(Subtype, Potency, Topk)`` tuple. A label supplied by the caller
        is returned unchanged; a predicted label is returned as
        ``(token, probability)``. ``Topk`` holds the ``top_k`` candidate
        tokens for the first placeholder.

    Raises:
        ValueError: If ``top_k`` is not a positive integer, or if the padded
            sequence contains no "X" placeholder.
    """
    # The Gradio "text" input hands top_k over as a string, while torch.topk
    # requires an int.
    try:
        k = int(top_k)
    except (TypeError, ValueError) as err:
        raise ValueError(f"top_k must be a positive integer, got {top_k!r}") from err
    if k < 1:
        raise ValueError(f"top_k must be a positive integer, got {top_k!r}")

    model.eval()
    top_ids = []
    top_probs = []
    with torch.no_grad():
        # NOTE(review): "X" is mapped to index 4 -- presumably the id of an
        # existing special token; confirm against the vocabulary.
        vocab_mlm.token_to_idx["X"] = 4
        padded_seq, _, idx_msa, _ = get_paded_token_idx_gen(
            vocab_mlm, [f"{X1}|{X2}|{X3}|||"], None
        )
        idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)

        mask_positions = [i for i, token in enumerate(padded_seq) if token == "X"]
        if not mask_positions:
            raise ValueError("Nothing found in the sequence to predict.")

        for position in mask_positions:
            padded_seq[position] = "[MASK]"
            input_ids = torch.tensor([vocab_mlm[padded_seq]]).to(device)
            logits = model(input_ids, idx_msa)
            probabilities = torch.softmax(logits[0, position, :], dim=-1)
            best_probs, best_ids = torch.topk(probabilities, k=k)
            top_ids.append(best_ids)
            top_probs.append(best_probs[0].item())
            # Commit the best guess so later positions are conditioned on it.
            padded_seq[position] = vocab_mlm.idx_to_token[best_ids[0].item()]

    Topk = vocab_mlm.to_tokens(list(top_ids[0]))

    subtype_given = X1 != "X"
    potency_given = X2 != "X"
    if subtype_given and potency_given:
        # Neither label was predicted, so no probability belongs to them;
        # top_probs[0] refers to a placeholder elsewhere in the sequence.
        return X1, X2, Topk
    if subtype_given:
        return X1, (padded_seq[2], top_probs[0]), Topk
    if potency_given:
        return (padded_seq[1], top_probs[0]), X2, Topk
    return (padded_seq[1], top_probs[0]), (padded_seq[2], top_probs[1]), Topk
51
-
52
# Four free-text inputs feed CTXGen's four parameters in order; its three
# return values are shown in three text outputs.
iface = gr.Interface(
    fn=CTXGen,
    inputs=["text", "text", "text", "text"],
    outputs=["text", "text", "text"],
)
# Start the web UI as soon as the module is executed.
iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from utils import create_vocab, setup_seed
4
+ from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab
5
+ setup_seed(4)
6
+
7
# Most recently loaded checkpoint as a (path, model) pair. Only one model is
# kept so memory use stays at the level of a single loaded checkpoint.
_LAST_MODEL = None


def _load_model(model_name, device):
    """Return the model stored at *model_name*, on *device* and in eval mode."""
    global _LAST_MODEL
    if _LAST_MODEL is not None and _LAST_MODEL[0] == model_name:
        return _LAST_MODEL[1]
    # SECURITY: weights_only=False unpickles arbitrary Python objects, so
    # model_name must only ever name a trusted checkpoint file.
    loaded = torch.load(model_name, weights_only=False, map_location=device)
    loaded = loaded.to(device)
    loaded.eval()
    _LAST_MODEL = (model_name, loaded)
    return loaded


def CTXGen(X1, X2, X3, model_name, top_k=5):
    """Fill every "X" placeholder of the query ``X1|X2|X3|||`` with a model.

    Placeholders are predicted one at a time, left to right; each prediction
    is written back into the sequence before the next position is predicted.

    Args:
        X1: Subtype label, or "X" to have it predicted.
        X2: Potency label, or "X" to have it predicted.
        X3: Conotoxin sequence text; it may contain "X" placeholders as well.
        model_name: Path of the checkpoint file to run.
        top_k: Number of candidates kept per placeholder (default 5).

    Returns:
        A ``(Subtype, Potency, Topk)`` tuple. A label supplied by the caller
        is returned unchanged; a predicted label is returned as
        ``(token, probability)``. ``Topk`` is ``X1`` when the subtype was
        supplied, otherwise the ``top_k`` candidate tokens for the first
        placeholder.

    Raises:
        ValueError: If the padded sequence contains no "X" placeholder.
    """
    device = torch.device("cpu")
    vocab_mlm = add_tokens_to_vocab(create_vocab())
    # Reuse the previous checkpoint when the same file is requested again
    # instead of unpickling it from disk on every call.
    model = _load_model(model_name, device)

    top_ids = []
    top_probs = []
    with torch.no_grad():
        # NOTE(review): "X" is mapped to index 4 -- presumably the id of an
        # existing special token; confirm against the vocabulary.
        vocab_mlm.token_to_idx["X"] = 4
        padded_seq, _, idx_msa, _ = get_paded_token_idx_gen(
            vocab_mlm, [f"{X1}|{X2}|{X3}|||"], None
        )
        idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)

        mask_positions = [i for i, token in enumerate(padded_seq) if token == "X"]
        if not mask_positions:
            raise ValueError("Nothing found in the sequence to predict.")

        for position in mask_positions:
            padded_seq[position] = "[MASK]"
            input_ids = torch.tensor([vocab_mlm[padded_seq]]).to(device)
            logits = model(input_ids, idx_msa)
            probabilities = torch.softmax(logits[0, position, :], dim=-1)
            best_probs, best_ids = torch.topk(probabilities, k=int(top_k))
            top_ids.append(best_ids)
            top_probs.append(best_probs[0].item())
            # Commit the best guess so later positions are conditioned on it.
            padded_seq[position] = vocab_mlm.idx_to_token[best_ids[0].item()]

    first_candidates = vocab_mlm.to_tokens(list(top_ids[0]))

    subtype_given = X1 != "X"
    potency_given = X2 != "X"
    if subtype_given and potency_given:
        # Neither label was predicted, so no probability belongs to them;
        # top_probs[0] refers to a placeholder elsewhere in the sequence.
        return X1, X2, X1
    if subtype_given:
        return X1, (padded_seq[2], top_probs[0]), X1
    if potency_given:
        return (padded_seq[1], top_probs[0]), X2, first_candidates
    return (
        (padded_seq[1], top_probs[0]),
        (padded_seq[2], top_probs[1]),
        first_candidates,
    )
54
+
55
# Web form for CTXGen: choosing "X" in a dropdown asks the model to predict
# that field instead of taking it as given.
iface = gr.Interface(
    fn=CTXGen,
    inputs=[
        gr.Dropdown(
            choices=[
                'X', '<K16>', '<α1β1γδ>', '<Ca22>', '<AChBP>', '<K13>', '<α1BAR>', '<α1β1ε>', '<α1AAR>', '<GluN3A>', '<α4β2>',
                '<GluN2B>', '<α75HT3>', '<Na14>', '<α7>', '<GluN2C>', '<NET>', '<NavBh>', '<α6β3β4>', '<Na11>', '<Ca13>',
                '<Ca12>', '<Na16>', '<α6α3β2>', '<GluN2A>', '<GluN2D>', '<K17>', '<α1β1δε>', '<GABA>', '<α9>', '<K12>',
                '<Kshaker>', '<α3β4>', '<Na18>', '<α3β2>', '<α6α3β2β3>', '<α1β1δ>', '<α6α3β4β3>', '<α2β2>', '<α6β4>', '<α2β4>',
                '<Na13>', '<Na12>', '<Na15>', '<α4β4>', '<α7α6β2>', '<α1β1γ>', '<NaTTXR>', '<K11>', '<Ca23>',
                '<α9α10>', '<α6α3β4>', '<NaTTXS>', '<Na17>',
            ],
            label="Subtype",
        ),
        # Every other label token is wrapped in angle brackets; the bare 'low'
        # looked like a typo and is written '<low>' to match '<high>'.
        # NOTE(review): confirm '<low>' is the spelling used in the vocabulary.
        gr.Dropdown(choices=['X', '<high>', '<low>'], label="Potency"),
        gr.Textbox(label="Conotoxin"),
        # Checkpoint files that CTXGen loads by path.
        gr.Dropdown(
            choices=[
                'model_final.pt', 'model_C1.pt', 'model_C2.pt', 'model_C3.pt',
                'model_C4.pt', 'model_C5.pt', 'model_mlm.pt',
            ],
            label="Model",
        ),
    ],
    outputs=[
        gr.Textbox(label="Subtype"),
        gr.Textbox(label="Potency"),
        gr.Textbox(label="Top5"),
    ],
)
# Start the web UI as soon as the module is executed.
iface.launch()