Spaces:
Sleeping
Sleeping
import nltk # Here to have a nice missing dependency error message early on | |
from filelock import FileLock | |
from transformers.utils import is_offline_mode | |
# Make sure the NLTK "punkt" sentence tokenizer data is present at import time.
# NOTE(review): recent NLTK releases may look up "punkt_tab" rather than
# "punkt" for sentence tokenization -- confirm against the NLTK version in use.
try:
    nltk.data.find("tokenizers/punkt")
except (LookupError, OSError):
    if not is_offline_mode():
        # Hold a file lock while downloading so the fetch is serialized on ".lock".
        with FileLock(".lock") as lock:
            nltk.download("punkt", quiet=True)
    else:
        # Downloading is not possible offline; tell the user how to get the data.
        raise LookupError(
            "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
        )
def string_to_float(string, default=-1.0):
    """Converts string to float, using default when conversion not possible.

    Args:
        string: Value to convert, normally a string.
        default: Value returned when ``string`` cannot be converted.

    Returns:
        ``float(string)`` on success, otherwise ``default``.
    """
    try:
        return float(string)
    except (ValueError, TypeError):
        # ValueError: text that is not a number (e.g. "abc", "").
        # TypeError: non-convertible objects such as None or a list.
        return default
def string_to_int(string, default=-1):
    """Converts string to int, using default when conversion not possible.

    Args:
        string: Value to convert, normally a string.
        default: Value returned when ``string`` cannot be converted.

    Returns:
        ``int(string)`` on success, otherwise ``default``. Note that text
        such as ``"3.5"`` is not a valid integer literal and yields ``default``.
    """
    try:
        return int(string)
    except (ValueError, TypeError):
        # ValueError: text that is not an integer literal (e.g. "abc", "3.5").
        # TypeError: non-convertible objects such as None or a list.
        return default
def get_post_processor(task):
    """Look up the converter applied to predictions/targets before metrics.

    Returns ``string_to_float`` for "stsb", ``string_to_int`` for "qqp",
    "cola" and "mrpc", and ``None`` for every other task.
    """
    if task == "stsb":
        return string_to_float
    if task in ("qqp", "cola", "mrpc"):
        return string_to_int
    # No post-processing is registered for any other task.
    return None
def postprocess_text_for_metric(metric, preds, labels=None, sources=None):
    """Strip and reshape texts into the layout the given metric expects.

    Args:
        metric: Metric name. Supported values are "sari", "rouge", "bleu",
            "bertscore", "bertscore_them" and "dist".
        preds: Iterable of predicted strings.
        labels: Iterable of reference strings. Required for every metric
            except "dist".
        sources: Iterable of source strings. Required for "sari" only.

    Returns:
        * "sari": ``(preds, labels, sources)``, each a list of stripped strings.
        * "rouge": ``(preds, labels)``, stripped, with the sentences of each
          text joined by newlines.
        * "bleu": ``(preds, labels)`` where every label is wrapped in a
          single-element list.
        * "bertscore" / "bertscore_them": ``(preds, labels)``, stripped.
        * "dist": the list of stripped predictions only.

    Raises:
        NotImplementedError: If ``metric`` is not one of the supported names.
        ValueError: If ``sources`` (for "sari") or ``labels`` (for any metric
            other than "dist") is missing.
    """
    supported = ("sari", "rouge", "bleu", "bertscore", "bertscore_them", "dist")
    if metric not in supported:
        raise NotImplementedError(f"Unsupported metric: {metric!r}")

    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if metric == "sari" and sources is None:
        raise ValueError("`sources` is required for the 'sari' metric")
    if metric != "dist" and labels is None:
        raise ValueError(f"`labels` is required for the {metric!r} metric")

    preds = [pred.strip() for pred in preds]
    if metric == "dist":
        # "dist" only looks at the predictions, so nothing else is returned.
        return preds

    labels = [label.strip() for label in labels]
    if metric == "sari":
        sources = [source.strip() for source in sources]
        return preds, labels, sources
    if metric == "rouge":
        # rougeLSum expects newline after each sentence
        preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
        labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
        return preds, labels
    if metric == "bleu":
        # Each prediction gets a list of references (here: exactly one).
        return preds, [[label] for label in labels]
    # "bertscore" / "bertscore_them": plain stripped pairs.
    return preds, labels