File size: 4,335 Bytes
b80b30e eb91fcc 22b6e9f 5fcc1e9 22b6e9f f0f33dc b80b30e 97d4494 dca24e2 4bb1cf2 97d4494 3fefb08 97d4494 3fefb08 97d4494 3fefb08 97d4494 b80b30e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import functools

import gradio as gr
import torch

from utils import create_vocab, setup_seed
from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab
# Seed RNGs once, at import time, through the project helper (seed 4).
# NOTE(review): presumably covers Python/NumPy/torch generators — confirm in utils.setup_seed.
setup_seed(4)
@functools.lru_cache(maxsize=2)
def _load_model(model_name):
    """Load the checkpoint at ``model_name`` on the CPU and switch it to eval mode.

    Cached so repeated requests for the same checkpoint do not re-read it from
    disk on every call; ``maxsize`` bounds how many models stay in memory.

    SECURITY NOTE(review): ``weights_only=False`` unpickles arbitrary Python
    objects. Only ever pass trusted, local checkpoint files here.
    """
    device = torch.device("cpu")
    model = torch.load(model_name, weights_only=False, map_location=device)
    model = model.to(device)
    model.eval()
    return model


def CTXGen(X1, X2, X3, model_name):
    """Predict the masked labels (subtype and/or potency) of a conotoxin.

    Builds the model input ``"{X1}|{X2}|{X3}|||"``. Every ``"X"`` token of the
    padded sequence is then filled one position at a time, left to right. Each
    prediction is written back into the sequence before the next position is
    predicted.

    Args:
        X1: Subtype token (e.g. ``"<K11>"``), or ``"X"`` to predict it.
        X2: Potency token (``"<high>"`` / ``"<low>"``), or ``"X"`` to predict it.
        X3: Conotoxin sequence.
        model_name: Path of the checkpoint file to load.

    Returns:
        ``(Subtype, Potency, Topk)``. A predicted label is returned as a
        ``(token, probability)`` tuple; a label supplied as input is echoed
        back unchanged. ``Topk`` holds the top-5 tokens of the first predicted
        position, or ``X1`` itself when the subtype was supplied.

    Raises:
        ValueError: if neither label is ``"X"`` (nothing to predict), or if the
            padded sequence contains no ``"X"`` token.
    """
    if X1 != "X" and X2 != "X":
        # With both labels fixed, any remaining "X" lies inside the conotoxin
        # sequence. Reporting that residue's probability as the potency
        # confidence would be misleading, so refuse up front.
        raise ValueError("Nothing to predict: set Subtype and/or Potency to 'X'.")

    device = torch.device("cpu")
    vocab_mlm = add_tokens_to_vocab(create_vocab())
    model = _load_model(model_name)

    predicted_token_probability_all = []
    topk = []
    with torch.no_grad():
        seq = [f"{X1}|{X2}|{X3}|||"]
        # NOTE(review): maps the placeholder "X" to index 4; presumably the
        # mask token id — confirm against the vocabulary.
        vocab_mlm.token_to_idx["X"] = 4
        padded_seq, _, idx_msa, _ = get_paded_token_idx_gen(vocab_mlm, seq, None)
        idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)

        mask_positions = [i for i, token in enumerate(padded_seq) if token == "X"]
        if not mask_positions:
            raise ValueError("Nothing found in the sequence to predict.")

        for mask_position in mask_positions:
            padded_seq[mask_position] = "[MASK]"
            input_ids = torch.tensor([vocab_mlm[padded_seq]]).to(device)
            logits = model(input_ids, idx_msa)
            mask_logits = logits[0, mask_position, :]
            probs, token_ids = torch.topk(torch.softmax(mask_logits, dim=-1), k=5)
            topk.append(token_ids)
            predicted_token_probability_all.append(probs[0].item())
            # Feed the best guess back so later positions condition on it.
            padded_seq[mask_position] = vocab_mlm.idx_to_token[token_ids[0].item()]

    cls_pos = vocab_mlm.to_tokens(list(topk[0]))
    # padded_seq[1] / padded_seq[2] are read as the subtype / potency slots
    # (presumably after a leading special token — confirm in dataset_mlm).
    if X1 != "X":
        # Subtype supplied (so potency is "X"): only the potency was predicted.
        return X1, (padded_seq[2], predicted_token_probability_all[0]), X1
    if X2 != "X":
        # Potency supplied: only the subtype was predicted.
        return (padded_seq[1], predicted_token_probability_all[0]), X2, cls_pos
    # Both predicted: subtype first (index 0), then potency (index 1).
    subtype = (padded_seq[1], predicted_token_probability_all[0])
    potency = (padded_seq[2], predicted_token_probability_all[1])
    return subtype, potency, cls_pos
# Receptor-subtype tokens offered in the UI; "X" asks the model to predict it.
_SUBTYPE_CHOICES = [
    'X', '<AChBP>', '<Ca12>', '<Ca13>', '<Ca22>', '<Ca23>', '<GABA>',
    '<GluN2A>', '<GluN2B>', '<GluN2C>', '<GluN2D>', '<GluN3A>',
    '<K11>', '<K12>', '<K13>', '<K16>', '<K17>', '<Kshaker>',
    '<Na11>', '<Na12>', '<Na13>', '<Na14>', '<Na15>', '<Na16>', '<Na17>',
    '<Na18>', '<NaTTXR>', '<NaTTXS>', '<NavBh>', '<NET>',
    '<α1AAR>', '<α1BAR>', '<α1β1γ>', '<α1β1γδ>', '<α1β1δ>', '<α1β1δε>',
    '<α1β1ε>', '<α2β2>', '<α2β4>', '<α3β2>', '<α3β4>',
    '<α4β2>', '<α4β4>', '<α6α3β2>', '<α6α3β2β3>', '<α6α3β4>', '<α6α3β4β3>',
    '<α6β3β4>', '<α6β4>', '<α7>', '<α7α6β2>',
    '<α75HT3>', '<α9>', '<α9α10>',
]
# Potency tokens offered in the UI; "X" asks the model to predict it.
_POTENCY_CHOICES = ['X', '<high>', '<low>']
# Checkpoint files selectable in the UI (passed to CTXGen as model_name).
_MODEL_CHOICES = [
    'model_final.pt', 'model_C1.pt', 'model_C2.pt', 'model_C3.pt',
    'model_C4.pt', 'model_C5.pt', 'model_mlm.pt',
]
# Markdown shown under the title (kept byte-identical to the original text).
_DESCRIPTION = """
🔗 **[Label Prediction](https://huggingface.co/spaces/oucgc1996/CreoPep_Label_Prediction)**
🔗 **[Unconstrained Generation](https://huggingface.co/spaces/oucgc1996/CreoPep_Unconstrained_generation)**
🔗 **[Conditional Generation](https://huggingface.co/spaces/oucgc1996/CreoPep_conditional_generation)**
🔗 **[Optimization Generation](https://huggingface.co/spaces/oucgc1996/CreoPep_optimization_generation)**
✅ **Subtype**: X if needs to be predicted.
✅ **Potency**: X if needs to be predicted.
✅ **Conotoxin**: conotoxin needs to be predicted.
✅ **Model**: model parameters trained at different stages of data augmentation. Please refer to the paper for details.
"""

# Wire CTXGen's four positional inputs (subtype, potency, sequence, model)
# to UI widgets, and its three return values to output text boxes.
iface = gr.Interface(
    fn=CTXGen,
    inputs=[
        gr.Dropdown(choices=_SUBTYPE_CHOICES, label="Subtype"),
        gr.Dropdown(choices=_POTENCY_CHOICES, label="Potency"),
        gr.Textbox(label="Conotoxin"),
        gr.Dropdown(choices=_MODEL_CHOICES, label="Model"),
    ],
    outputs=[gr.Textbox(label=label) for label in ("Subtype", "Potency", "Top5")],
    title="Conotoxin Label Prediction",
    description=_DESCRIPTION,
)
# Launched unconditionally at import time, as in the original script.
iface.launch()