Update app.py
app.py CHANGED
@@ -100,7 +100,7 @@ def CTXGen(X0, X1, X2, τ, g_num, model_name):
         if is_stopped:
             return pd.DataFrame(), "output.csv"
 
-        _, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq,
+        _, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
         idx_seq = torch.tensor(idx_seq).unsqueeze(0).to(device)
         idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
         attn_idx = torch.tensor(attn_idx).to(device)
@@ -115,7 +115,8 @@ def CTXGen(X0, X1, X2, τ, g_num, model_name):
         predicted_token = vocab_mlm.to_tokens(int(predicted_token_id))
         input_text[mask_position.item()] = predicted_token
         padded_seq[mask_position.item()] = predicted_token.strip()
-
+        new_seq = padded_seq
+
         generated_seq = input_text
 
         generated_seq[1] = "[MASK]"
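
For context, the added line makes the masked-token generation loop carry each round's partially filled sequence (padded_seq) forward as new_seq, so the next call to get_paded_token_idx_gen receives the updated sequence. The toy sketch below illustrates only that feedback pattern under stated assumptions: fill_one_mask, the example sequence, and the loop structure are hypothetical stand-ins, not the app's actual model step or code.

# Toy sketch (not the app's code) of the feedback pattern the diff introduces:
# each round writes a predicted token into padded_seq, and new_seq = padded_seq
# carries that result into the next round.
import random

def fill_one_mask(tokens):
    # Hypothetical stand-in for the model step: replace the first [MASK]
    # with a randomly chosen amino-acid letter.
    out = list(tokens)
    for i, tok in enumerate(out):
        if tok == "[MASK]":
            out[i] = random.choice("ACDEFGHIKLMNPQRSTVWY")
            break
    return out

padded_seq = ["G", "[MASK]", "C", "[MASK]", "W"]
new_seq = None
for _ in range(2):  # two masked positions to fill in this toy example
    current = padded_seq if new_seq is None else new_seq
    padded_seq = fill_one_mask(current)
    new_seq = padded_seq  # the line added in the diff: feed the result back

print("".join(new_seq))  # e.g. "GACLW" (random letters at the former masks)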