oucgc1996 commited on
Commit
db4c65b
·
verified ·
1 Parent(s): e789509

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -29
app.py CHANGED
@@ -6,12 +6,9 @@ from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab
6
  import gradio as gr
7
  from gradio_rangeslider import RangeSlider
8
  import time
9
- import numba
10
- from numba import objmode
11
 
12
  is_stopped = False
13
 
14
- @numba.jit(nopython=True)
15
  def temperature_sampling(logits, temperature):
16
  logits = logits / temperature
17
  probabilities = torch.softmax(logits, dim=-1)
@@ -23,29 +20,6 @@ def stop_generation():
23
  is_stopped = True
24
  return "Generation stopped."
25
 
26
- @numba.jit(nopython=False)
27
- def generate_sequence(length, vocab_mlm, seq, new_seq, τ, input_text):
28
- for i in range(length):
29
- if is_stopped:
30
- return "output.csv", pd.DataFrame()
31
-
32
- _, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
33
- idx_seq = torch.tensor(idx_seq).unsqueeze(0).to(device)
34
- idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
35
- attn_idx = torch.tensor(attn_idx).to(device)
36
- mask_positions = [j for j in range(gen_length) if input_text[j] == "[MASK]"]
37
- mask_position = torch.tensor([mask_positions[torch.randint(len(mask_positions), (1,))]])
38
-
39
- logits = model(idx_seq, idx_msa, attn_idx)
40
- mask_logits = logits[0, mask_position.item(), :]
41
-
42
- predicted_token_id = temperature_sampling(mask_logits, τ)
43
- predicted_token = vocab_mlm.to_tokens(int(predicted_token_id))
44
- input_text[mask_position.item()] = predicted_token
45
- padded_seq[mask_position.item()] = predicted_token.strip()
46
- new_seq = padded_seq
47
- return input_text
48
-
49
  def CTXGen(X0, X3, X1, X2, τ, g_num, model_name, seed):
50
  if seed =='random':
51
  seed = random.randint(0,100000)
@@ -131,11 +105,31 @@ def CTXGen(X0, X3, X1, X2, τ, g_num, model_name, seed):
131
 
132
  padded_seq, _, _, _ = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
133
  input_text = ["[MASK]" if i=="X" else i for i in padded_seq]
 
134
  gen_length = len(input_text)
135
  length = gen_length - sum(1 for x in input_text if x != '[MASK]')
136
-
137
- #函数
138
- generated_seq = generate_sequence(length, vocab_mlm, seq, new_seq, τ, input_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
  generated_seq[1] = "[MASK]"
141
  input_ids = vocab_mlm.__getitem__(generated_seq)
 
6
  import gradio as gr
7
  from gradio_rangeslider import RangeSlider
8
  import time
 
 
9
 
10
  is_stopped = False
11
 
 
12
  def temperature_sampling(logits, temperature):
13
  logits = logits / temperature
14
  probabilities = torch.softmax(logits, dim=-1)
 
20
  is_stopped = True
21
  return "Generation stopped."
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  def CTXGen(X0, X3, X1, X2, τ, g_num, model_name, seed):
24
  if seed =='random':
25
  seed = random.randint(0,100000)
 
105
 
106
  padded_seq, _, _, _ = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
107
  input_text = ["[MASK]" if i=="X" else i for i in padded_seq]
108
+
109
  gen_length = len(input_text)
110
  length = gen_length - sum(1 for x in input_text if x != '[MASK]')
111
+ for i in range(length):
112
+ if is_stopped:
113
+ return "output.csv", pd.DataFrame()
114
+
115
+ _, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
116
+ idx_seq = torch.tensor(idx_seq).unsqueeze(0).to(device)
117
+ idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
118
+ attn_idx = torch.tensor(attn_idx).to(device)
119
+
120
+ mask_positions = [j for j in range(gen_length) if input_text[j] == "[MASK]"]
121
+ mask_position = torch.tensor([mask_positions[torch.randint(len(mask_positions), (1,))]])
122
+
123
+ logits = model(idx_seq,idx_msa, attn_idx)
124
+ mask_logits = logits[0, mask_position.item(), :]
125
+
126
+ predicted_token_id = temperature_sampling(mask_logits, τ)
127
+
128
+ predicted_token = vocab_mlm.to_tokens(int(predicted_token_id))
129
+ input_text[mask_position.item()] = predicted_token
130
+ padded_seq[mask_position.item()] = predicted_token.strip()
131
+ new_seq = padded_seq
132
+ generated_seq = input_text
133
 
134
  generated_seq[1] = "[MASK]"
135
  input_ids = vocab_mlm.__getitem__(generated_seq)