File size: 3,490 Bytes
b80b30e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import functools

import torch
import gradio as gr

from utils import create_vocab, setup_seed
from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab
# Fix the random seed once, at import time, before the UI is built.
# NOTE(review): setup_seed is project-defined (utils); presumably it seeds the
# Python/NumPy/torch RNGs for reproducibility — confirm there.
setup_seed(4)

@functools.lru_cache(maxsize=4)
def _load_model(model_name):
    """Load a checkpoint onto the CPU, switch it to eval mode and cache it by path.

    Caching avoids re-reading the checkpoint from disk on every request.
    """
    # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects, which
    # can execute code. Only pass trusted checkpoint paths here; the UI offers a
    # fixed Dropdown list of model files.
    model = torch.load(model_name, weights_only=False, map_location=torch.device("cpu"))
    model = model.to(torch.device("cpu"))
    model.eval()
    return model


def CTXGen(X1, X2, X3, model_name):
    """Fill in every "X" placeholder of a ``subtype|potency|sequence`` query.

    The query string ``f"{X1}|{X2}|{X3}|||"`` is padded and tokenized. Each
    "X" token is then replaced from left to right:
    - the position is set to "[MASK]" and the model is run;
    - the most probable token is written back, so later positions are
      predicted conditioned on earlier predictions.

    Args:
        X1: Subtype token (e.g. "<K16>"), or "X" to have the model predict it.
        X2: Potency token (e.g. "<high>"), or "X" to have the model predict it.
        X3: Conotoxin sequence. NOTE(review): presumably tokenized per residue,
            so "X" residues are generated as well — confirm in
            get_paded_token_idx_gen.
        model_name: Path of the checkpoint to load with ``torch.load``.

    Returns:
        A tuple ``(Subtype, Potency, Topk)``:
        - Subtype / Potency: a ``(token, probability)`` pair when the field
          was predicted, otherwise the value the caller supplied.
        - Topk: the 5 most probable subtype tokens when the subtype was
          predicted, otherwise X1.

    Raises:
        ValueError: if the padded query contains no "X" token to predict.
    """
    device = torch.device("cpu")
    vocab_mlm = create_vocab()
    vocab_mlm = add_tokens_to_vocab(vocab_mlm)
    model = _load_model(model_name)

    # Positions 1 and 2 of the padded sequence hold the subtype and potency
    # tokens (position 0 is presumably a leading special token — confirm).
    subtype_pos, potency_pos = 1, 2
    top1_prob = {}  # masked position -> probability of the token chosen there
    top5_ids = {}   # masked position -> ids of the 5 most probable tokens

    with torch.no_grad():
        seq = [f"{X1}|{X2}|{X3}|||"]
        # NOTE(review): maps the "X" placeholder to vocab index 4; presumably a
        # special/mask id — confirm against create_vocab.
        vocab_mlm.token_to_idx["X"] = 4
        padded_seq, _, idx_msa, _ = get_paded_token_idx_gen(vocab_mlm, seq, None)
        idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
        mask_positions = [i for i, token in enumerate(padded_seq) if token == "X"]
        if not mask_positions:
            raise ValueError("Nothing found in the sequence to predict.")

        for pos in mask_positions:
            padded_seq[pos] = "[MASK]"
            input_ids = torch.tensor([vocab_mlm[padded_seq]]).to(device)
            logits = model(input_ids, idx_msa)
            probs = torch.softmax(logits[0, pos, :], dim=-1)
            top_probs, top_ids = torch.topk(probs, k=5)
            top1_prob[pos] = top_probs[0].item()
            top5_ids[pos] = top_ids
            # Write the greedy choice back so later masks condition on it.
            padded_seq[pos] = vocab_mlm.idx_to_token[top_ids[0].item()]

    # Key probabilities by position, not list order. Otherwise a user-supplied
    # potency could be paired with the probability of an unrelated residue.
    if subtype_pos in top1_prob:
        Subtype = padded_seq[subtype_pos], top1_prob[subtype_pos]
        Topk = vocab_mlm.to_tokens(list(top5_ids[subtype_pos]))
    else:
        Subtype = X1
        Topk = X1

    if potency_pos in top1_prob:
        Potency = padded_seq[potency_pos], top1_prob[potency_pos]
    else:
        Potency = X2

    return Subtype, Potency, Topk

# Subtype conditioning tokens offered in the UI; "X" asks the model to predict it.
SUBTYPE_CHOICES = [
    'X', '<K16>', '<α1β1γδ>', '<Ca22>', '<AChBP>', '<K13>', '<α1BAR>', '<α1β1ε>', '<α1AAR>', '<GluN3A>', '<α4β2>',
    '<GluN2B>', '<α75HT3>', '<Na14>', '<α7>', '<GluN2C>', '<NET>', '<NavBh>', '<α6β3β4>', '<Na11>', '<Ca13>',
    '<Ca12>', '<Na16>', '<α6α3β2>', '<GluN2A>', '<GluN2D>', '<K17>', '<α1β1δε>', '<GABA>', '<α9>', '<K12>',
    '<Kshaker>', '<α3β4>', '<Na18>', '<α3β2>', '<α6α3β2β3>', '<α1β1δ>', '<α6α3β4β3>', '<α2β2>', '<α6β4>', '<α2β4>',
    '<Na13>', '<Na12>', '<Na15>', '<α4β4>', '<α7α6β2>', '<α1β1γ>', '<NaTTXR>', '<K11>', '<Ca23>',
    '<α9α10>', '<α6α3β4>', '<NaTTXS>', '<Na17>',
]

# Potency conditioning tokens. Bracketed like '<high>' and every subtype token;
# the previous bare 'low' presumably did not match the model's token.
POTENCY_CHOICES = ['X', '<high>', '<low>']

# Checkpoint files selectable in the UI (loaded by CTXGen via torch.load).
MODEL_CHOICES = [
    'model_final.pt', 'model_C1.pt', 'model_C2.pt', 'model_C3.pt',
    'model_C4.pt', 'model_C5.pt', 'model_mlm.pt',
]

iface = gr.Interface(
    fn=CTXGen,
    inputs=[
        gr.Dropdown(choices=SUBTYPE_CHOICES, label="Subtype"),
        gr.Dropdown(choices=POTENCY_CHOICES, label="Potency"),
        gr.Textbox(label="Conotoxin"),
        gr.Dropdown(choices=MODEL_CHOICES, label="Model"),
    ],
    outputs=[
        gr.Textbox(label="Subtype"),
        gr.Textbox(label="Potency"),
        gr.Textbox(label="Top5"),
    ],
)
iface.launch()