Update app.py
Browse files
app.py
CHANGED
@@ -23,10 +23,15 @@ def stop_generation():
|
|
23 |
return "Generation stopped."
|
24 |
|
25 |
@numba.jit(nopython=True)
|
26 |
-
def generate_sequence(
|
27 |
-
gen_length = len(input_text)
|
28 |
-
length = gen_length - sum(1 for x in input_text if x != '[MASK]')
|
29 |
for i in range(length):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
mask_positions = [j for j in range(gen_length) if input_text[j] == "[MASK]"]
|
31 |
mask_position = torch.tensor([mask_positions[torch.randint(len(mask_positions), (1,))]])
|
32 |
|
@@ -125,8 +130,11 @@ def CTXGen(X0, X3, X1, X2, τ, g_num, model_name, seed):
|
|
125 |
|
126 |
padded_seq, _, _, _ = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
|
127 |
input_text = ["[MASK]" if i=="X" else i for i in padded_seq]
|
|
|
|
|
128 |
|
129 |
-
|
|
|
130 |
|
131 |
generated_seq[1] = "[MASK]"
|
132 |
input_ids = vocab_mlm.__getitem__(generated_seq)
|
|
|
23 |
return "Generation stopped."
|
24 |
|
25 |
@numba.jit(nopython=True)
|
26 |
+
def generate_sequence(length, vocab_mlm, seq, new_seq, τ, input_text):
|
|
|
|
|
27 |
for i in range(length):
|
28 |
+
if is_stopped:
|
29 |
+
return "output.csv", pd.DataFrame()
|
30 |
+
|
31 |
+
_, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
|
32 |
+
idx_seq = torch.tensor(idx_seq).unsqueeze(0).to(device)
|
33 |
+
idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
|
34 |
+
attn_idx = torch.tensor(attn_idx).to(device)
|
35 |
mask_positions = [j for j in range(gen_length) if input_text[j] == "[MASK]"]
|
36 |
mask_position = torch.tensor([mask_positions[torch.randint(len(mask_positions), (1,))]])
|
37 |
|
|
|
130 |
|
131 |
padded_seq, _, _, _ = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
|
132 |
input_text = ["[MASK]" if i=="X" else i for i in padded_seq]
|
133 |
+
gen_length = len(input_text)
|
134 |
+
length = gen_length - sum(1 for x in input_text if x != '[MASK]')
|
135 |
|
136 |
+
#函数
|
137 |
+
generated_seq = generate_sequence(length, vocab_mlm, seq, new_seq, τ, input_text)
|
138 |
|
139 |
generated_seq[1] = "[MASK]"
|
140 |
input_ids = vocab_mlm.__getitem__(generated_seq)
|