Update app.py
app.py CHANGED
@@ -6,12 +6,9 @@ from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab
 import gradio as gr
 from gradio_rangeslider import RangeSlider
 import time
-import numba
-from numba import objmode
 
 is_stopped = False
 
-@numba.jit(nopython=True)
 def temperature_sampling(logits, temperature):
     logits = logits / temperature
     probabilities = torch.softmax(logits, dim=-1)
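Review note: dropping numba here is the right call. @numba.jit(nopython=True) cannot compile torch tensor operations, so the decorated temperature_sampling would raise a TypingError on its first call, and the objmode import was unused. The hunk also cuts the function off after the softmax line. A minimal sketch of a typical completion, assuming multinomial sampling (the file's actual tail is not shown in this view):

    import torch

    def temperature_sampling(logits, temperature):
        # Temperature < 1 sharpens the distribution; > 1 flattens it.
        probabilities = torch.softmax(logits / temperature, dim=-1)
        # Hypothetical completion: draw one token id from the distribution.
        return torch.multinomial(probabilities, num_samples=1)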
@@ -23,29 +20,6 @@ def stop_generation():
     is_stopped = True
     return "Generation stopped."
 
-@numba.jit(nopython=False)
-def generate_sequence(length, vocab_mlm, seq, new_seq, τ, input_text):
-    for i in range(length):
-        if is_stopped:
-            return "output.csv", pd.DataFrame()
-
-        _, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
-        idx_seq = torch.tensor(idx_seq).unsqueeze(0).to(device)
-        idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
-        attn_idx = torch.tensor(attn_idx).to(device)
-        mask_positions = [j for j in range(gen_length) if input_text[j] == "[MASK]"]
-        mask_position = torch.tensor([mask_positions[torch.randint(len(mask_positions), (1,))]])
-
-        logits = model(idx_seq, idx_msa, attn_idx)
-        mask_logits = logits[0, mask_position.item(), :]
-
-        predicted_token_id = temperature_sampling(mask_logits, τ)
-        predicted_token = vocab_mlm.to_tokens(int(predicted_token_id))
-        input_text[mask_position.item()] = predicted_token
-        padded_seq[mask_position.item()] = predicted_token.strip()
-        new_seq = padded_seq
-    return input_text
-
 def CTXGen(X0, X3, X1, X2, τ, g_num, model_name, seed):
     if seed =='random':
         seed = random.randint(0,100000)
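Review note: deleting generate_sequence instead of repairing it makes sense. Beyond the decorator (nopython=False object mode would not have sped up torch code anyway), the body read gen_length, padded_seq, model, and device without receiving them as parameters, so it only worked through module-level state. A toy illustration of that bug pattern, with hypothetical names:

    # `scale` is a free variable, not a parameter: the function silently
    # depends on module state and raises NameError when that state is absent.
    def broken(values):
        return [v * scale for v in values]

    # Explicit dependency: behavior is determined by the arguments alone.
    def fixed(values, scale):
        return [v * scale for v in values]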
@@ -131,11 +105,31 @@ def CTXGen(X0, X3, X1, X2, τ, g_num, model_name, seed):
 
             padded_seq, _, _, _ = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
             input_text = ["[MASK]" if i=="X" else i for i in padded_seq]
+
             gen_length = len(input_text)
             length = gen_length - sum(1 for x in input_text if x != '[MASK]')
-
-
-
+            for i in range(length):
+                if is_stopped:
+                    return "output.csv", pd.DataFrame()
+
+                _, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
+                idx_seq = torch.tensor(idx_seq).unsqueeze(0).to(device)
+                idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
+                attn_idx = torch.tensor(attn_idx).to(device)
+
+                mask_positions = [j for j in range(gen_length) if input_text[j] == "[MASK]"]
+                mask_position = torch.tensor([mask_positions[torch.randint(len(mask_positions), (1,))]])
+
+                logits = model(idx_seq, idx_msa, attn_idx)
+                mask_logits = logits[0, mask_position.item(), :]
+
+                predicted_token_id = temperature_sampling(mask_logits, τ)
+
+                predicted_token = vocab_mlm.to_tokens(int(predicted_token_id))
+                input_text[mask_position.item()] = predicted_token
+                padded_seq[mask_position.item()] = predicted_token.strip()
+                new_seq = padded_seq
+                generated_seq = input_text
 
             generated_seq[1] = "[MASK]"
             input_ids = vocab_mlm.__getitem__(generated_seq)
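Review note: the loop inlined here (the three removed lines at old 136-138 are not rendered in this view; presumably they held the old generate_sequence call) implements iterative mask filling: each pass picks one remaining [MASK] position at random, takes the model's logits at that position, samples a token at temperature τ, and writes it back before re-encoding. A self-contained sketch of the same technique, with model_fn, token_to_id, and id_to_token as hypothetical stand-ins for the Space's model and vocab_mlm:

    import torch

    def temperature_sampling(logits, temperature):
        probabilities = torch.softmax(logits / temperature, dim=-1)
        return torch.multinomial(probabilities, num_samples=1).item()

    def iterative_mask_fill(tokens, model_fn, token_to_id, id_to_token, tau=1.0):
        # Fill every "[MASK]" in `tokens`, one randomly chosen slot per step.
        # `model_fn` is a hypothetical stand-in: (1, L) id tensor -> (1, L, V) logits.
        tokens = list(tokens)
        while True:
            mask_positions = [i for i, t in enumerate(tokens) if t == "[MASK]"]
            if not mask_positions:
                return tokens
            # One masked slot per model call, chosen at random as in the diff.
            pos = mask_positions[torch.randint(len(mask_positions), (1,)).item()]
            ids = torch.tensor([[token_to_id[t] for t in tokens]])
            logits = model_fn(ids)
            token_id = temperature_sampling(logits[0, pos, :], tau)
            tokens[pos] = id_to_token[token_id]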