oucgc1996 commited on
Commit
197fa95
·
verified ·
1 Parent(s): aadb062

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -100,7 +100,7 @@ def CTXGen(X0, X1, X2, τ, g_num, model_name):
100
  if is_stopped:
101
  return pd.DataFrame(), "output.csv"
102
 
103
- _, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq, None)
104
  idx_seq = torch.tensor(idx_seq).unsqueeze(0).to(device)
105
  idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
106
  attn_idx = torch.tensor(attn_idx).to(device)
@@ -115,7 +115,8 @@ def CTXGen(X0, X1, X2, τ, g_num, model_name):
115
  predicted_token = vocab_mlm.to_tokens(int(predicted_token_id))
116
  input_text[mask_position.item()] = predicted_token
117
  padded_seq[mask_position.item()] = predicted_token.strip()
118
-
 
119
  generated_seq = input_text
120
 
121
  generated_seq[1] = "[MASK]"
 
100
  if is_stopped:
101
  return pd.DataFrame(), "output.csv"
102
 
103
+ _, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
104
  idx_seq = torch.tensor(idx_seq).unsqueeze(0).to(device)
105
  idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
106
  attn_idx = torch.tensor(attn_idx).to(device)
 
115
  predicted_token = vocab_mlm.to_tokens(int(predicted_token_id))
116
  input_text[mask_position.item()] = predicted_token
117
  padded_seq[mask_position.item()] = predicted_token.strip()
118
+ new_seq = padded_seq
119
+
120
  generated_seq = input_text
121
 
122
  generated_seq[1] = "[MASK]"