Spaces:
Sleeping
Sleeping
"""Computes the repetition metric. Adapted from: https://raw.githubusercontent.com/ari-holtzman/degen/master/metrics/repetition.py""" | |
def repetition(tokenized_texts, tokenizer):
    """Compute the fraction of generations that end in a repeating loop.

    Implements the repetition metric from https://arxiv.org/pdf/1904.09751.pdf.
    For each example, the tail of the token sequence is checked for a phrase of
    1 to 90 tokens that is repeated back-to-back up to the very end of the
    sequence. The phrase length with the most consecutive repeats is selected
    (the shortest length wins ties), and the example counts as "repeating"
    when that phrase occurs more than once and is either at least 3 tokens
    long or repeated more than 50 times.

    Args:
        tokenized_texts: (List[List[int]]) generated texts as token ids. The
            input lists are not modified.
        tokenizer: object exposing ``bos_token`` and ``encode``; the first id
            of ``tokenizer.encode(tokenizer.bos_token)`` is treated as a
            separator and ignored when it is the last token of an example.

    Returns:
        ``{"repetition": float}`` -- the share of examples flagged as
        repeating, or ``0.0`` when ``tokenized_texts`` is empty.
    """
    max_n = 90  # longest phrase length (in tokens) that is checked

    num_examples = len(tokenized_texts)
    if num_examples == 0:
        # Nothing to score; avoid dividing by zero below.
        return {"repetition": 0.0}

    sep = tokenizer.encode(tokenizer.bos_token)[0]
    n_repeated_examples = 0

    for tokenized_text in tokenized_texts:
        # Work on a copy so the caller's token lists are never mutated.
        tokens = list(tokenized_text)
        if tokens and tokens[-1] == sep:
            tokens.pop()

        # Reverse so that "the end of the generation" becomes a prefix.
        rev_gen = tokens[::-1]

        last_n_repeats = [0] * max_n
        for n in range(1, max_n + 1):
            phrase = rev_gen[:n]
            n_repeat = 1
            while True:
                window = rev_gen[n * n_repeat : n * (n_repeat + 1)]
                # The length check stops the scan at the end of the sequence
                # (and prevents an endless loop when the sequence is empty).
                if len(window) != n or window != phrase:
                    break
                n_repeat += 1
            last_n_repeats[n - 1] = n_repeat

        # max() returns the first maximal index, i.e. the shortest phrase
        # length among those with the highest repeat count.
        max_repeated_n = max(range(max_n), key=lambda x: last_n_repeats[x])
        repeated_times = last_n_repeats[max_repeated_n]
        phrase_length = max_repeated_n + 1

        if repeated_times > 1 and (phrase_length >= 3 or repeated_times > 50):
            n_repeated_examples += 1

    return {"repetition": n_repeated_examples * 1.0 / num_examples}