|
from datasets import Dataset, load_dataset |
|
|
|
from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS |
|
|
|
from ..base import BaseDataset |
|
|
|
|
|
@TEXT_POSTPROCESSORS.register_module('gsm100_dataset') |
|
def gsm100_dataset_postprocess(text: str) -> str:
    """Remove every comma from a reference answer string.

    Args:
        text: The reference answer, e.g. ``'1,234'``.

    Returns:
        ``text`` with all ``','`` characters deleted, e.g. ``'1234'``.
    """
    return text.replace(',', '')
|
|
|
|
|
@TEXT_POSTPROCESSORS.register_module('gsm100') |
|
def gsm100_postprocess(text: str) -> str:
    """Extract the digits of the first number following 'The answer is'.

    Only the text between the first and the second occurrence of the
    marker ``'The answer is'`` is inspected. That text is split on single
    spaces, and the first piece containing at least one digit is selected.
    Every non-digit character of that piece is then dropped, so ``'$1,234.'``
    yields ``'1234'`` and ``'3.5'`` yields ``'35'``.

    Args:
        text: The raw model prediction.

    Returns:
        The digit characters of the selected piece, or ``''`` when the
        marker is missing or no piece after it contains a digit.
    """
    segs = text.split('The answer is')
    if len(segs) < 2:
        # The marker never appears, so there is nothing to extract.
        return ''

    # Splitting on ' ' (not arbitrary whitespace) is deliberate: pieces
    # joined by newlines or tabs stay together, as they always have.
    for piece in segs[1].split(' '):
        if any(ch.isdigit() for ch in piece):
            return ''.join(ch for ch in piece if ch.isdigit())
    return ''
|
|
|
|
|
@LOAD_DATASET.register_module()
class LEvalGSM100Dataset(BaseDataset):
    """GSM100 loader that emits one record per (question, answer) pair."""

    @staticmethod
    def load(**kwargs):
        """Load the dataset and flatten its ``'test'`` split.

        Each row of the ``'test'`` split is read through its ``'input'``,
        ``'instructions'`` and ``'outputs'`` fields. Instructions and
        outputs are paired positionally with ``zip``, so surplus entries on
        the longer side are ignored, and every pair becomes one record that
        repeats the row's ``'input'`` as ``'context'``.

        Args:
            **kwargs: Forwarded unchanged to ``datasets.load_dataset``.

        Returns:
            The object returned by ``load_dataset`` with its ``'test'``
            entry replaced by a ``Dataset`` of records with the keys
            ``'question'``, ``'context'`` and ``'answer'``.
        """
        dataset = load_dataset(**kwargs)
        split = 'test'
        raw_data = []
        # Walk the rows once rather than indexing dataset[split][column][i],
        # which re-reads whole columns on every iteration.
        for row in dataset[split]:
            context = row['input']
            for question, answer in zip(row['instructions'],
                                        row['outputs']):
                raw_data.append({
                    'question': question,
                    'context': context,
                    'answer': answer
                })
        dataset[split] = Dataset.from_list(raw_data)
        return dataset
|
|