oucgc1996 commited on
Commit
48636f3
·
verified ·
1 Parent(s): 541538c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -4
app.py CHANGED
@@ -23,10 +23,15 @@ def stop_generation():
23
  return "Generation stopped."
24
 
25
  @numba.jit(nopython=True)
26
- def generate_sequence(input_text, model, vocab_mlm, idx_msa, τ):
27
- gen_length = len(input_text)
28
- length = gen_length - sum(1 for x in input_text if x != '[MASK]')
29
  for i in range(length):
 
 
 
 
 
 
 
30
  mask_positions = [j for j in range(gen_length) if input_text[j] == "[MASK]"]
31
  mask_position = torch.tensor([mask_positions[torch.randint(len(mask_positions), (1,))]])
32
 
@@ -125,8 +130,11 @@ def CTXGen(X0, X3, X1, X2, τ, g_num, model_name, seed):
125
 
126
  padded_seq, _, _, _ = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
127
  input_text = ["[MASK]" if i=="X" else i for i in padded_seq]
 
 
128
 
129
- generated_seq = generate_sequence(input_text, model, vocab_mlm, idx_msa, τ)
 
130
 
131
  generated_seq[1] = "[MASK]"
132
  input_ids = vocab_mlm.__getitem__(generated_seq)
 
23
  return "Generation stopped."
24
 
25
  @numba.jit(nopython=True)
26
+ def generate_sequence(length, vocab_mlm, seq, new_seq, τ, input_text):
 
 
27
  for i in range(length):
28
+ if is_stopped:
29
+ return "output.csv", pd.DataFrame()
30
+
31
+ _, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
32
+ idx_seq = torch.tensor(idx_seq).unsqueeze(0).to(device)
33
+ idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
34
+ attn_idx = torch.tensor(attn_idx).to(device)
35
  mask_positions = [j for j in range(gen_length) if input_text[j] == "[MASK]"]
36
  mask_position = torch.tensor([mask_positions[torch.randint(len(mask_positions), (1,))]])
37
 
 
130
 
131
  padded_seq, _, _, _ = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
132
  input_text = ["[MASK]" if i=="X" else i for i in padded_seq]
133
+ gen_length = len(input_text)
134
+ length = gen_length - sum(1 for x in input_text if x != '[MASK]')
135
 
136
+ #函数
137
+ generated_seq = generate_sequence(length, vocab_mlm, seq, new_seq, τ, input_text)
138
 
139
  generated_seq[1] = "[MASK]"
140
  input_ids = vocab_mlm.__getitem__(generated_seq)