File size: 6,408 Bytes
47170a5 a683732 47170a5 031925d 47170a5 55406ba 47170a5 031925d 47170a5 031925d 47170a5 031925d 47170a5 031925d 47170a5 031925d 47170a5 031925d 47170a5 031925d 47170a5 031925d 47170a5 031925d 47170a5 031925d a683732 47170a5 a683732 47170a5 a683732 47170a5 260c1a3 47170a5 a683732 47170a5 a683732 47170a5 a683732 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
import ast
import random
from collections import Counter, defaultdict

from langcodes import Language, standardize_tag
from rich import print

from .util import _get_dataset_config_names, _load_dataset
def print_counts(slug, subjects_dev, subjects_test):
    """Print a one-line, column-aligned summary of a dataset's dev/test
    category counts and sample counts."""
    n_test_categories = len(set(subjects_test))
    n_dev_categories = len(set(subjects_dev))
    print(
        f"{slug:<25} {n_test_categories:>3} test categories, {len(subjects_test):>6} samples, {n_dev_categories:>3} dev categories, {len(subjects_dev):>6} dev samples"
    )
def print_datasets_analysis():
    """Print a coverage analysis of candidate multilingual MMLU datasets.

    For each candidate dataset, prints category/sample counts, then
    summarizes which languages are covered by which datasets. This is a
    one-off exploratory helper, not used at runtime.
    """
    print("Category counts and sample counts per dataset:")
    # AfriMMLU: configs are named by language code; "eng" is used as reference.
    slug1 = "masakhane/afrimmlu"
    ds1 = _load_dataset(slug1, "eng")
    print_counts(slug1, ds1["dev"]["subject"], ds1["test"]["subject"])
    langs1 = _get_dataset_config_names(slug1)
    # Normalize config names to BCP-47 macrolanguage tags so datasets are comparable.
    langs1 = [standardize_tag(a, macro=True) for a in langs1]
    slug2 = "openai/MMMLU"  # does not have dev set! – but: these languages are all also present in Global-MMLU
    ds2 = _load_dataset(slug2, "FR_FR")
    # No dev split exists, hence the empty list for the dev subjects.
    print_counts(slug2, [], ds2["test"]["Subject"])
    langs2 = _get_dataset_config_names(slug2)
    # Configs look like "FR_FR"; keep only the lowercased language half.
    langs2 = [a.split("_")[0].lower() for a in langs2]
    langs2.remove("default")
    slug3 = "CohereForAI/Global-MMLU"
    ds3 = _load_dataset(slug3, "en")
    print_counts(slug3, ds3["dev"]["subject"], ds3["test"]["subject"])
    langs3 = _get_dataset_config_names(slug3)
    langs3 = [standardize_tag(a, macro=True) for a in langs3]
    slug4 = "lighteval/okapi_mmlu"
    ds4 = _load_dataset(slug4, "ar", trust_remote_code=True)
    # Okapi encodes the subject in the row id as "<subject>/<split>/<n>".
    print_counts(
        slug4,
        [a.split("/")[0] for a in ds4["dev"]["id"]],
        [a.split("/")[0] for a in ds4["test"]["id"]],
    )
    langs4 = _get_dataset_config_names(slug4)
    slug5 = "Eurolingua/mmlux"
    subsets = _get_dataset_config_names(slug5)
    # Subset names look like "<subject>_<LANG...>"; split off the language suffix.
    # NOTE(review): `subjects` is currently unused.
    subjects = set(a.rsplit("_", 1)[0] for a in subsets)
    # Count rows for the Danish ("_DA") subsets only — assumes per-subject
    # counts are identical across the language subsets; TODO confirm.
    rows_test = [
        _load_dataset(slug5, subset)["test"]["id"]
        for subset in subsets
        if "_DA" in subset
    ]
    rows_test = [a.split("/")[0] for l in rows_test for a in l]
    rows_dev = [
        _load_dataset(slug5, subset)["dev"]["id"]
        for subset in subsets
        if "_DA" in subset
    ]
    rows_dev = [a.split("/")[0] for l in rows_dev for a in l]
    print_counts(slug5, rows_dev, rows_test)
    langs5 = list(set(a.rsplit("_", 1)[1].split("-")[0].lower() for a in subsets))
    langs = langs1 + langs2 + langs3 + langs4 + langs5
    # Group dataset slugs by human-readable language name.
    lang_datasets = defaultdict(list)
    for slug, langs_list in [
        (slug1, langs1),
        (slug2, langs2),
        (slug3, langs3),
        (slug4, langs4),
        (slug5, langs5),
    ]:
        for lang in langs_list:
            lname = Language.get(lang).display_name()
            lang_datasets[lname].append(slug)
    print("Datasets per language:")
    print(sorted(lang_datasets.items()))
    print(len(set(langs)))
    print("Datasets per language for languages that are not in Global-MMLU:")
    print(
        sorted(
            (lang, datasets)
            for lang, datasets in lang_datasets.items()
            if slug3 not in datasets
        )
    )
    # How often each dataset covers a language that Global-MMLU misses.
    print(
        Counter(
            dataset
            for ds_list in lang_datasets.values()
            for dataset in ds_list
            if slug3 not in ds_list
        )
    )
    print(list(set(ds1["test"]["subject"])))
# based on this analysis:
# - we drop the OpenAI dataset, since it does not have a dev set, and since every language that it has is also present in Global-MMLU
# - we stick to the 5 categories of the AfriMMLU dataset, since this is the most restricted dataset, and these 5 categories are present in all datasets, so this is good for comparability
# AfriMMLU is human-translated, but has only 5 task categories
# Global-MMLU is mixed-translated; specifically, the 15 languages that are also present in Global-MMLU-Lite are mostly from MMMLU; the rest are translated using Google Translate
# Okapi-MMLU is translated using ChatGPT (version unclear)
# MMLUX is translated using DeepL
# Therefore, the priority is: AfriMMLU, Global-MMLU, MMLUX, Okapi-MMLU
# print_datasets_analysis()
def parse_choices(row):
    """Ensure ``row["choices"]`` is a real list.

    Some dataset variants store the answer choices as a stringified
    Python list (e.g. ``"['A', 'B', 'C', 'D']"``); decode that string
    when present and return the (mutated) row.
    """
    if not isinstance(row["choices"], list):
        # ast.literal_eval accepts only Python literals — unlike eval(),
        # it cannot execute arbitrary code embedded in dataset text.
        row["choices"] = ast.literal_eval(row["choices"])
    return row
def add_choices(row):
    """Collect the four ``option_a``..``option_d`` columns into a single
    ``choices`` list and return the (mutated) row."""
    row["choices"] = [row[f"option_{letter}"] for letter in "abcd"]
    return row
def load_mmlu(language_bcp_47, nr):
    """Deterministically select one MMLU-style task for a language.

    Datasets are tried in priority order: AfriMMLU (human-translated),
    Global-MMLU, Okapi-MMLU (currently disabled via FIXME), MMLUX (not
    yet implemented). Returns a tuple of (dataset slug, few-shot example
    rows, task row), or (None, None, None) when the language is not
    covered by any supported dataset.
    """
    # nr picks both the category (round-robin) and, via the seeded RNG,
    # a reproducible index into the test split.
    subject_pool = sorted(set(_load_dataset("masakhane/afrimmlu", "eng")["dev"]["subject"]))
    category = subject_pool[nr % len(subject_pool)]
    random.seed(nr)
    i = random.randint(0, 100)
    # Map normalized BCP-47 macro tags back to each dataset's raw config name.
    tags_afrimmlu = {}
    for cfg in _get_dataset_config_names("masakhane/afrimmlu"):
        tags_afrimmlu[standardize_tag(cfg, macro=True)] = cfg
    tags_global_mmlu = {}
    for cfg in _get_dataset_config_names("CohereForAI/Global-MMLU"):
        tags_global_mmlu[standardize_tag(cfg, macro=True)] = cfg
    tags_okapi = _get_dataset_config_names("lighteval/okapi_mmlu")
    tags_mmlux = {
        cfg.rsplit("_", 1)[1].split("-")[0].lower()
        for cfg in _get_dataset_config_names("Eurolingua/mmlux")
    }
    if language_bcp_47 in tags_afrimmlu:
        ds = _load_dataset("masakhane/afrimmlu", tags_afrimmlu[language_bcp_47])
        ds = ds.map(parse_choices)
        examples = ds["dev"].filter(lambda row: row["subject"] == category)
        task = ds["test"].filter(lambda row: row["subject"] == category)[i]
        return "masakhane/afrimmlu", examples, task
    if language_bcp_47 in tags_global_mmlu:
        ds = _load_dataset("CohereForAI/Global-MMLU", tags_global_mmlu[language_bcp_47])
        ds = ds.map(add_choices)
        examples = ds["dev"].filter(lambda row: row["subject"] == category)
        task = ds["test"].filter(lambda row: row["subject"] == category)[i]
        return "CohereForAI/Global-MMLU", examples, task
    if language_bcp_47 in tags_okapi:
        return None, None, None  # FIXME
        # Unreachable until the FIXME above is resolved; kept as reference.
        ds = _load_dataset(
            "lighteval/okapi_mmlu", language_bcp_47, trust_remote_code=True
        )
        examples = ds["dev"].filter(lambda row: row["subject"] == category)
        task = ds["test"].filter(lambda row: row["id"] == f"{category}/test/{i}")[0]
        return "lighteval/okapi_mmlu", examples, task
    if language_bcp_47 in tags_mmlux:
        # Loading MMLUX is more complicated; not implemented yet (todo).
        return None, None, None
    return None, None, None
|