refactor: Improve efficiency of generate_replacements by reducing model calls
app.py
@@ -30,19 +30,19 @@ def calculate_log_probabilities(model, tokenizer, inputs, input_ids):
 def generate_replacements(model, tokenizer, prefix, device, num_samples=5):
     input_context = tokenizer(prefix, return_tensors="pt").to(device)
     input_ids = input_context["input_ids"]
+    with torch.no_grad():
+        outputs = model.generate(
+            input_ids=input_ids,
+            max_length=input_ids.shape[-1] + 5,
+            num_return_sequences=num_samples,
+            temperature=1.0,
+            top_k=50,
+            top_p=0.95,
+            do_sample=True
+        )
     new_words = []
-    for _ in range(num_samples):
-
-        outputs = model.generate(
-            input_ids=input_ids,
-            max_length=input_ids.shape[-1] + 5,
-            num_return_sequences=1,
-            temperature=1.0,
-            top_k=50,
-            top_p=0.95,
-            do_sample=True
-        )
-        generated_ids = outputs[0][input_ids.shape[-1]:]
+    for i in range(num_samples):
+        generated_ids = outputs[i][input_ids.shape[-1]:]
         new_word = tokenizer.decode(generated_ids, skip_special_tokens=True).split()[0]
         new_words.append(new_word)
     return new_words
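For reference, the change collapses `num_samples` separate `model.generate` calls into a single batched call via `num_return_sequences=num_samples`, and wraps decoding in `torch.no_grad()` so no gradient state is tracked during generation. Below is a minimal usage sketch of the refactored function; the model name (`gpt2`) and the example prefix are illustrative assumptions, not taken from this repo.

```python
# Usage sketch (assumes generate_replacements from app.py is in scope;
# "gpt2" is purely an illustrative model choice).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

# One batched generate call now yields all num_samples continuations;
# the old code issued num_samples separate calls inside the loop.
replacements = generate_replacements(model, tokenizer, "The quick brown", device, num_samples=5)
print(replacements)  # e.g. ['fox', 'fox', 'dog', 'fox', 'cat']
```

Note that sampling (`do_sample=True`) is what makes the single batched call equivalent to the old loop: with greedy decoding, all `num_return_sequences` outputs would be identical.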