mebubo committed on
Commit 91515a1 · 1 Parent(s): ada166c

refactor: Improve efficiency of generate_replacements by reducing model calls

Files changed (1)
  1. app.py +12 -12
app.py CHANGED
@@ -30,19 +30,19 @@ def calculate_log_probabilities(model, tokenizer, inputs, input_ids):
 def generate_replacements(model, tokenizer, prefix, device, num_samples=5):
     input_context = tokenizer(prefix, return_tensors="pt").to(device)
     input_ids = input_context["input_ids"]
+    with torch.no_grad():
+        outputs = model.generate(
+            input_ids=input_ids,
+            max_length=input_ids.shape[-1] + 5,
+            num_return_sequences=num_samples,
+            temperature=1.0,
+            top_k=50,
+            top_p=0.95,
+            do_sample=True
+        )
     new_words = []
-    for _ in range(num_samples):
-        with torch.no_grad():
-            outputs = model.generate(
-                input_ids=input_ids,
-                max_length=input_ids.shape[-1] + 5,
-                num_return_sequences=1,
-                temperature=1.0,
-                top_k=50,
-                top_p=0.95,
-                do_sample=True
-            )
-        generated_ids = outputs[0][input_ids.shape[-1]:]
+    for i in range(num_samples):
+        generated_ids = outputs[i][input_ids.shape[-1]:]
         new_word = tokenizer.decode(generated_ids, skip_special_tokens=True).split()[0]
         new_words.append(new_word)
     return new_words
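
For reference, a minimal self-contained sketch of the refactored function as it reads after this commit. The model name ("gpt2"), the surrounding setup, and the example prefix are placeholders for illustration, not part of the commit:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
    model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

    def generate_replacements(model, tokenizer, prefix, device, num_samples=5):
        input_context = tokenizer(prefix, return_tensors="pt").to(device)
        input_ids = input_context["input_ids"]
        # One batched call: num_return_sequences asks generate() for all
        # num_samples continuations at once, instead of invoking the model
        # num_samples times in a Python loop as the previous version did.
        with torch.no_grad():
            outputs = model.generate(
                input_ids=input_ids,
                max_length=input_ids.shape[-1] + 5,
                num_return_sequences=num_samples,
                temperature=1.0,
                top_k=50,
                top_p=0.95,
                do_sample=True
            )
        new_words = []
        for i in range(num_samples):
            # Drop the shared prefix tokens, keeping only the sampled continuation.
            generated_ids = outputs[i][input_ids.shape[-1]:]
            # As in the committed code, this assumes each continuation decodes
            # to at least one whitespace-delimited word.
            new_word = tokenizer.decode(generated_ids, skip_special_tokens=True).split()[0]
            new_words.append(new_word)
        return new_words

    print(generate_replacements(model, tokenizer, "The quick brown", device))

Since all num_samples sequences share the same prompt, batching them through a single generate() call runs one batch of size num_samples instead of num_samples sequential model calls, which is where the efficiency gain named in the commit message comes from.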