|
import random
|
|
import torch
|
|
import gradio as gr
|
|
import pandas as pd
|
|
from utils import create_vocab, setup_seed
|
|
from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab
|
|
|
|
# Pick a fresh seed for this process and hand it to the project's setup_seed helper.
# The same value is written to the "random_seed" output column, presumably so that a
# run can be reproduced later — confirm what setup_seed actually seeds.
seed = random.randint(0, 99999999)
setup_seed(seed)

# All inference runs on the CPU.
device = torch.device("cpu")

# Base vocabulary from utils, extended with the extra tokens defined in dataset_mlm.
vocab_mlm = add_tokens_to_vocab(create_vocab())

save_path = 'mlm-model-27.pt'

# Training sequences (the "Seq" column); generated peptides that already occur here
# are filtered out by CTXGen.
train_seqs = pd.read_csv('C0_seq.csv')
train_seq = train_seqs['Seq'].tolist()

# NOTE: weights_only=False unpickles arbitrary Python objects from the checkpoint —
# only ever point save_path at a trusted file.
model = torch.load(save_path, weights_only=False, map_location=device).to(device)
|
|
|
|
def temperature_sampling(logits, temperature):
    """Draw one token index per row from temperature-scaled logits.

    Args:
        logits: Tensor of unnormalised scores; the last dimension is the vocabulary.
        temperature: Divisor applied to the logits before the softmax. Values above 1
            flatten the distribution; it must be non-zero.

    Returns:
        A tensor of sampled indices with the last dimension replaced by size 1
        (shape ``(1,)`` for a 1-D ``logits``).
    """
    scaled = logits / temperature
    return torch.multinomial(torch.softmax(scaled, dim=-1), 1)
|
|
|
|
# Tokens that must never appear in an accepted peptide: non-standard residue letters,
# the bracketed label tokens (e.g. '<K16>', '<high>') and the special tokens.
_NON_AA_TOKENS = frozenset([
    "B", "O", "U", "Z", "X", '<K16>', '<α1β1γδ>', '<Ca22>', '<AChBP>', '<K13>', '<α1BAR>', '<α1β1ε>', '<α1AAR>', '<GluN3A>', '<α4β2>',
    '<GluN2B>', '<α75HT3>', '<Na14>', '<α7>', '<GluN2C>', '<NET>', '<NavBh>', '<α6β3β4>', '<Na11>', '<Ca13>',
    '<Ca12>', '<Na16>', '<α6α3β2>', '<GluN2A>', '<GluN2D>', '<K17>', '<α1β1δε>', '<GABA>', '<α9>', '<K12>',
    '<Kshaker>', '<α3β4>', '<Na18>', '<α3β2>', '<α6α3β2β3>', '<α1β1δ>', '<α6α3β4β3>', '<α2β2>', '<α6β4>', '<α2β4>',
    '<Na13>', '<Na12>', '<Na15>', '<α4β4>', '<α7α6β2>', '<α1β1γ>', '<NaTTXR>', '<K11>', '<Ca23>',
    '<α9α10>', '<α6α3β4>', '<NaTTXS>', '<Na17>', '<high>', '<low>', '[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]',
])


def _parse_length_bounds(start, end):
    """Convert the UI min/max length inputs to ints and check they form a usable range."""
    try:
        min_len, max_len = int(start), int(end)
    except (TypeError, ValueError) as err:
        raise ValueError(
            f"Min length and Max length must be integers, got {start!r} and {end!r}"
        ) from err
    # A negative length can never pass the length filter, and min > max is an empty range.
    if min_len < 0 or min_len > max_len:
        raise ValueError(f"Invalid length range: min={min_len}, max={max_len}")
    return min_len, max_len


def _fill_masks(seq, temperature):
    """Decode template ``seq`` by sampling one masked position at a time.

    Returns ``(tokens, idx_msa)``: the decoded token list and the MSA index tensor
    from the last decoding step, which is reused for the label-prediction pass.
    """
    # NOTE(review): forces "X" to id 4 before tokenising the template, presumably so
    # the placeholders get the mask id — confirm against create_vocab().
    vocab_mlm.token_to_idx["X"] = 4

    # Every generation starts from the fresh template. Reusing the previous
    # generation's decoded tokens here would hand the helper a sequence with no masks.
    new_seq = None
    padded_seq, _, _, _ = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
    tokens = ["[MASK]" if tok == "X" else tok for tok in padded_seq]
    n_masks = sum(1 for tok in tokens if tok == "[MASK]")

    # Stays None only if the template had no masks; CTXGen's template always has them.
    idx_msa = None
    for _ in range(n_masks):
        # Re-tokenise with the predictions so far (new_seq) so the model sees them.
        _, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
        idx_seq = torch.tensor(idx_seq).unsqueeze(0).to(device)
        idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
        attn_idx = torch.tensor(attn_idx).to(device)

        # Unmask positions in random order.
        mask_positions = [j for j, tok in enumerate(tokens) if tok == "[MASK]"]
        pos = mask_positions[torch.randint(len(mask_positions), (1,)).item()]

        logits = model(idx_seq, idx_msa, attn_idx)
        predicted_token_id = temperature_sampling(logits[0, pos, :], temperature)
        predicted_token = vocab_mlm.to_tokens(int(predicted_token_id))

        tokens[pos] = predicted_token
        padded_seq[pos] = predicted_token.strip()
        new_seq = padded_seq
    return tokens, idx_msa


def _predict_labels(tokens, idx_msa):
    """Return (subtype, subtype_prob, potency, potency_prob) read at positions 1 and 2."""
    input_ids = vocab_mlm[tokens]
    logits = model(torch.tensor([input_ids]).to(device), idx_msa)

    cls_probability, cls_idx = torch.topk(torch.softmax(logits[0, 1, :], dim=-1), k=1)
    act_probability, act_idx = torch.topk(torch.softmax(logits[0, 2, :], dim=-1), k=1)
    return (
        vocab_mlm.idx_to_token[cls_idx[0].item()],
        cls_probability[0].item(),
        vocab_mlm.idx_to_token[act_idx[0].item()],
        act_probability[0].item(),
    )


def CTXGen(τ, g_num, start, end):
    """Generate novel peptide sequences and predict a subtype and potency for each.

    For each attempt:
    - A peptide length ``gen_len`` is drawn uniformly from ``[start, end]``.
    - A template is built whose peptide body is ``gen_len`` masked positions, and the
      masks are filled by temperature sampling.
    - Positions 1 and 2 are re-masked, and the model's most likely tokens there are
      reported as Subtype and Potency.

    A candidate is kept only if all of the following hold:
    - It has an even number of 'C' tokens.
    - Its joined length equals ``gen_len``.
    - It is not in ``train_seq``.
    - It was not produced earlier in this call.
    - It contains no token from ``_NON_AA_TOKENS``.

    Args:
        τ: Sampling temperature passed to ``temperature_sampling``; must be > 0.
        g_num: Number of accepted sequences to produce (converted with ``int``).
        start: Minimum peptide length (int or numeric string).
        end: Maximum peptide length (int or numeric string), inclusive.

    Returns:
        The path ``'output.csv'``. The accepted sequences are written there with
        columns Generated_seq, Subtype, Subtype_probability, Potency,
        Potency_probability and random_seed.

    Raises:
        ValueError: if ``g_num`` is missing or the length bounds are not a valid range.

    NOTE(review): the loop runs until ``g_num`` candidates are accepted. If the
    filters can never be satisfied for the requested lengths, it never terminates.
    """
    if g_num is None:
        raise ValueError("Number of generations must be selected")
    gen_num = int(g_num)
    min_len, max_len = _parse_length_bounds(start, end)

    # Template fields: two masked label slots, the masked peptide body (X3), and three
    # empty trailing fields.
    X1 = "X"
    X2 = "X"
    X4 = ""
    X5 = ""
    X6 = ""

    train_set = set(train_seq)  # O(1) membership instead of scanning the list
    seen = set()                # every candidate that passed the shape filter
    generated_seqs_FINAL = []
    cls_pos_all = []
    cls_probability_all = []
    act_pos_all = []
    act_probability_all = []

    model.eval()
    with torch.no_grad():
        while len(generated_seqs_FINAL) < gen_num:
            gen_len = random.randint(min_len, max_len)
            X3 = "X" * gen_len
            seq = [f"{X1}|{X2}|{X3}|{X4}|{X5}|{X6}"]

            tokens, idx_msa = _fill_masks(seq, τ)

            # Re-mask the two label slots so they are predicted conditioned on the
            # freshly generated peptide.
            tokens[1] = "[MASK]"
            tokens[2] = "[MASK]"
            cls_pos, cls_probability, act_pos, act_probability = _predict_labels(tokens, idx_msa)

            # Peptide body: from two past the first [MASK] (presumably index 1) up to [SEP].
            peptide = tokens[tokens.index('[MASK]') + 2:tokens.index('[SEP]')]
            candidate = "".join(peptide)
            if peptide.count('C') % 2 != 0 or len(candidate) != gen_len:
                continue

            is_new = candidate not in seen
            seen.add(candidate)
            if is_new and candidate not in train_set and _NON_AA_TOKENS.isdisjoint(peptide):
                generated_seqs_FINAL.append(candidate)
                cls_pos_all.append(cls_pos)
                cls_probability_all.append(cls_probability)
                act_pos_all.append(act_pos)
                act_probability_all.append(act_probability)

    # Write once at the end (also produces a valid, empty file when g_num is 0).
    out = pd.DataFrame({
        'Generated_seq': generated_seqs_FINAL,
        'Subtype': cls_pos_all,
        'Subtype_probability': cls_probability_all,
        'Potency': act_pos_all,
        'Potency_probability': act_probability_all,
        'random_seed': seed,
    })
    out.to_csv("output.csv", index=False)
    return 'output.csv'
|
|
|
|
# Gradio UI. The four inputs are passed positionally to CTXGen(τ, g_num, start, end),
# and the CSV path it returns is offered back as a file.
_ui_inputs = [
    gr.Slider(minimum=1, maximum=2, step=0.01, label="τ"),
    gr.Dropdown(choices=[1, 10, 100, 1000], label="Number of generations"),
    gr.Textbox(label="Min length"),
    gr.Textbox(label="Max length"),
]

iface = gr.Interface(fn=CTXGen, inputs=_ui_inputs, outputs=["file"])

iface.launch()