|
import torch |
|
import gradio as gr |
|
from utils import create_vocab, setup_seed |
|
from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab |
|
# Seed once at import time, before any model is loaded or run.
# NOTE(review): setup_seed comes from utils; presumably it seeds the Python /
# NumPy / torch RNGs so predictions are reproducible -- confirm in utils.
setup_seed(4)
|
|
|
def CTXGen(X1, X2, X3, model_name):
    """Predict the missing subtype and/or potency label of a conotoxin.

    Builds the prompt ``"{X1}|{X2}|{X3}|||"``, tokenizes it with the project
    vocabulary and fills every ``"X"`` token from left to right: each position
    is replaced by ``"[MASK]"``, scored by the model, and then overwritten
    with its most probable token before the next position is scored.

    Args:
        X1: Subtype label token, or ``"X"`` to have it predicted.  ``None`` or
            an empty string (an untouched dropdown) is treated as ``"X"``.
        X2: Potency label token, or ``"X"`` to have it predicted.  ``None`` or
            an empty string is treated as ``"X"``.
        X3: Conotoxin sequence.
        model_name: Path of the checkpoint handed to ``torch.load``.  The
            checkpoint is loaded from disk on every call.

    Returns:
        A ``(Subtype, Potency, Topk)`` tuple.  A label that was predicted is a
        ``(token, probability)`` pair; a label supplied by the caller is
        echoed back unchanged.  ``Topk`` holds the five most probable subtype
        tokens when the subtype was predicted, otherwise it is ``X1``.

    Raises:
        ValueError: If the conotoxin or the model name is missing, or if the
            tokenized prompt contains no ``"X"`` to predict.
    """
    # Validate before any expensive work so that a bad request fails fast with
    # a readable message instead of an opaque tokenizer / torch.load error.
    X1 = X1 or "X"
    X2 = X2 or "X"
    if not X3 or not X3.strip():
        raise ValueError("A conotoxin sequence is required.")
    if not model_name:
        raise ValueError("A model must be selected.")

    device = torch.device("cpu")
    vocab_mlm = add_tokens_to_vocab(create_vocab())

    # SECURITY: weights_only=False unpickles arbitrary objects, so model_name
    # must only ever point at trusted checkpoint files.
    model = torch.load(model_name, weights_only=False, map_location=device)
    model = model.to(device)
    model.eval()

    # Positions this function reads the subtype and potency tokens from.
    subtype_position, potency_position = 1, 2

    # Keyed by masked position rather than by prediction order, so an "X"
    # inside the conotoxin can never be mistaken for a label prediction.
    probability_at = {}
    top5_ids_at = {}

    with torch.no_grad():
        seq = [f"{X1}|{X2}|{X3}|||"]
        # NOTE(review): index 4 is presumably the id of the mask/placeholder
        # token in this vocabulary -- confirm against create_vocab.
        vocab_mlm.token_to_idx["X"] = 4
        padded_seq, _, idx_msa, _ = get_paded_token_idx_gen(vocab_mlm, seq, None)
        idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)

        mask_positions = [i for i, token in enumerate(padded_seq) if token == "X"]
        if not mask_positions:
            raise ValueError("Nothing found in the sequence to predict.")

        for mask_position in mask_positions:
            padded_seq[mask_position] = "[MASK]"
            input_ids = torch.tensor([vocab_mlm[padded_seq]]).to(device)
            logits = model(input_ids, idx_msa)
            probabilities = torch.softmax(logits[0, mask_position, :], dim=-1)
            top_probabilities, top_ids = torch.topk(probabilities, k=5)

            probability_at[mask_position] = top_probabilities[0].item()
            top5_ids_at[mask_position] = top_ids
            # Commit the best guess so later positions are scored with it in
            # place.
            padded_seq[mask_position] = vocab_mlm.idx_to_token[top_ids[0].item()]

    if subtype_position in probability_at:
        Subtype = padded_seq[subtype_position], probability_at[subtype_position]
        Topk = vocab_mlm.to_tokens(list(top5_ids_at[subtype_position]))
    else:
        Subtype = X1
        Topk = X1

    if potency_position in probability_at:
        Potency = padded_seq[potency_position], probability_at[potency_position]
    else:
        Potency = X2

    return Subtype, Potency, Topk
|
|
|
# Choices offered by the three dropdowns.  "X" is the placeholder the user
# picks for a label that should be predicted (see the description text below).
_subtype_choices = [
    'X', '<AChBP>', '<Ca12>', '<Ca13>', '<Ca22>', '<Ca23>', '<GABA>',
    '<GluN2A>', '<GluN2B>', '<GluN2C>', '<GluN2D>', '<GluN3A>',
    '<K11>', '<K12>', '<K13>', '<K16>', '<K17>', '<Kshaker>',
    '<Na11>', '<Na12>', '<Na13>', '<Na14>', '<Na15>', '<Na16>', '<Na17>',
    '<Na18>', '<NaTTXR>', '<NaTTXS>', '<NavBh>', '<NET>',
    '<α1AAR>', '<α1BAR>', '<α1β1γ>', '<α1β1γδ>', '<α1β1δ>', '<α1β1δε>',
    '<α1β1ε>', '<α2β2>', '<α2β4>', '<α3β2>', '<α3β4>',
    '<α4β2>', '<α4β4>', '<α6α3β2>', '<α6α3β2β3>', '<α6α3β4>', '<α6α3β4β3>',
    '<α6β3β4>', '<α6β4>', '<α7>', '<α7α6β2>',
    '<α75HT3>', '<α9>', '<α9α10>',
]
_potency_choices = ['X', '<high>', '<low>']
_model_choices = [
    'model_final.pt',
    'model_C1.pt',
    'model_C2.pt',
    'model_C3.pt',
    'model_C4.pt',
    'model_C5.pt',
    'model_mlm.pt',
]

# Markdown shown under the page title.
_description = """
🔗 **[Label Prediction](https://huggingface.co/spaces/oucgc1996/CreoPep_Label_Prediction)**
🔗 **[Unconstrained Generation](https://huggingface.co/spaces/oucgc1996/CreoPep_Unconstrained_generation)**
🔗 **[Conditional Generation](https://huggingface.co/spaces/oucgc1996/CreoPep_conditional_generation)**
🔗 **[Optimization Generation](https://huggingface.co/spaces/oucgc1996/CreoPep_optimization_generation)**

✅ **Subtype**: X if needs to be predicted.

✅ **Potency**: X if needs to be predicted.

✅ **Conotoxin**: conotoxin needs to be predicted.

✅ **Model**: model parameters trained at different stages of data augmentation. Please refer to the paper for details.

"""

# Input widgets, in the order CTXGen takes its arguments
# (X1, X2, X3, model_name).
_input_widgets = [
    gr.Dropdown(choices=_subtype_choices, label="Subtype"),
    gr.Dropdown(choices=_potency_choices, label="Potency"),
    gr.Textbox(label="Conotoxin"),
    gr.Dropdown(choices=_model_choices, label="Model"),
]

# Output widgets, in the order CTXGen returns its values
# (Subtype, Potency, Topk).
_output_widgets = [
    gr.Textbox(label="Subtype"),
    gr.Textbox(label="Potency"),
    gr.Textbox(label="Top5"),
]

iface = gr.Interface(
    fn=CTXGen,
    inputs=_input_widgets,
    outputs=_output_widgets,
    title="Conotoxin Label Prediction",
    description=_description,
)

# Start the web app as soon as the module is executed.
iface.launch()