import torch
import gradio as gr

from utils import create_vocab, setup_seed
from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab

setup_seed(4)


def CTXGen(X1, X2, X3, model_name):
    device = torch.device("cpu")
    vocab_mlm = create_vocab()
    vocab_mlm = add_tokens_to_vocab(vocab_mlm)
    save_path = model_name
    model = torch.load(save_path, weights_only=False, map_location=torch.device('cpu'))
    model = model.to(device)
    predicted_token_probability_all = []
    model.eval()
    topk = []
    with torch.no_grad():
        new_seq = None
        seq = [f"{X1}|{X2}|{X3}|||"]
        vocab_mlm.token_to_idx["X"] = 4
        padded_seq, _, idx_msa, _ = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
        idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)

        # Every "X" token in the padded sequence is a position the model must fill in.
        mask_positions = [i for i, token in enumerate(padded_seq) if token == "X"]
        if not mask_positions:
            raise ValueError("No 'X' positions found in the sequence to predict.")

        # Predict the masked positions one at a time, writing each prediction back
        # into the sequence before moving on to the next position.
        for mask_position in mask_positions:
            padded_seq[mask_position] = "[MASK]"
            input_ids = vocab_mlm[padded_seq]
            input_ids = torch.tensor([input_ids]).to(device)
            logits = model(input_ids, idx_msa)
            mask_logits = logits[0, mask_position, :]
            predicted_token_probability, predicted_token_id = torch.topk(torch.softmax(mask_logits, dim=-1), k=5)
            topk.append(predicted_token_id)
            predicted_token = vocab_mlm.idx_to_token[predicted_token_id[0].item()]
            predicted_token_probability_all.append(predicted_token_probability[0].item())
            padded_seq[mask_position] = predicted_token

    # Top-5 candidate tokens for the first masked position.
    cls_pos = vocab_mlm.to_tokens(list(topk[0]))

    if X1 != "X":
        # The subtype was supplied by the user.
        Topk = X1
        Subtype = X1
        Potency = padded_seq[2], predicted_token_probability_all[0]
    elif X2 != "X":
        # The potency was supplied by the user.
        Topk = cls_pos
        Subtype = padded_seq[1], predicted_token_probability_all[0]
        Potency = X2
    else:
        # Both subtype and potency were predicted.
        Topk = cls_pos
        Subtype = padded_seq[1], predicted_token_probability_all[0]
        Potency = padded_seq[2], predicted_token_probability_all[1]

    return Subtype, Potency, Topk
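
# A minimal usage sketch (an assumption, not part of the original app): calling
# CTXGen directly, without the Gradio UI. The subtype token, masked conotoxin
# scaffold, and checkpoint name below are illustrative placeholders; "X"
# characters mark the positions the model is asked to fill. Left commented out
# so the Space still starts through the interface defined below.
#
#   subtype, potency, top5 = CTXGen(
#       "<α7>",            # hypothetical subtype choice from the dropdown
#       "X",               # potency left masked, so the model predicts it
#       "GCCSXXXCAXXXC",   # hypothetical conotoxin scaffold with masked residues
#       "model_final.pt",  # one of the checkpoints listed in the Model dropdown
#   )
#   print(subtype, potency, top5)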


iface = gr.Interface(
    fn=CTXGen,
    inputs=[
        gr.Dropdown(choices=['X', '<K16>', '<α1β1γδ>', '<Ca22>', '<AChBP>', '<K13>', '<α1BAR>', '<α1β1ε>', '<α1AAR>', '<GluN3A>', '<α4β2>',
                             '<GluN2B>', '<α75HT3>', '<Na14>', '<α7>', '<GluN2C>', '<NET>', '<NavBh>', '<α6β3β4>', '<Na11>', '<Ca13>',
                             '<Ca12>', '<Na16>', '<α6α3β2>', '<GluN2A>', '<GluN2D>', '<K17>', '<α1β1δε>', '<GABA>', '<α9>', '<K12>',
                             '<Kshaker>', '<α3β4>', '<Na18>', '<α3β2>', '<α6α3β2β3>', '<α1β1δ>', '<α6α3β4β3>', '<α2β2>', '<α6β4>', '<α2β4>',
                             '<Na13>', '<Na12>', '<Na15>', '<α4β4>', '<α7α6β2>', '<α1β1γ>', '<NaTTXR>', '<K11>', '<Ca23>',
                             '<α9α10>', '<α6α3β4>', '<NaTTXS>', '<Na17>'], label="Subtype"),
        gr.Dropdown(choices=['X', '<high>', '<low>'], label="Potency"),
gr.Textbox(label="Conotoxin"), | |
gr.Dropdown(choices=['model_final.pt','model_C1.pt','model_C2.pt','model_C3.pt','model_C4.pt','model_C5.pt','model_mlm.pt'], label="Model") | |
], | |
outputs=[ | |
gr.Textbox(label="Subtype"), | |
gr.Textbox(label="Potency"), | |
gr.Textbox(label="Top5") | |
] | |
) | |
iface.launch() |