File size: 1,813 Bytes
256a159 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
import json
import os.path as osp
from typing import Dict, Optional
import mmengine
from datasets import Dataset, DatasetDict
from opencompass.registry import TEXT_POSTPROCESSORS
from ..base import BaseDataset
class TEvalDataset(BaseDataset):
    """Dataset built from a T-Eval style JSON file.

    Each record's ``origin_prompt`` message list is extended with an
    ``assistant`` turn holding the record's ground truth. The complete raw
    record is kept, JSON-encoded, under ``ground_truth``.
    """

    def __init__(self, reader_cfg: Optional[Dict] = None, **kwargs):
        """Build the dataset.

        Args:
            reader_cfg: Reader configuration forwarded to ``BaseDataset``.
                ``None`` (the default) is treated as an empty dict.
            **kwargs: Forwarded unchanged to ``BaseDataset.__init__``
                (presumably including ``path``/``name`` for ``load`` --
                confirm against BaseDataset).
        """
        # Avoid a shared mutable default: create a fresh dict per call.
        if reader_cfg is None:
            reader_cfg = {}
        super().__init__(reader_cfg=reader_cfg, **kwargs)

    def load(self, path: str, name: str):
        """Load ``<path>/<name>.json`` into identical ``test``/``train`` splits.

        Args:
            path: Directory containing the JSON file.
            name: File stem, without the ``.json`` suffix.

        Returns:
            A ``DatasetDict`` whose ``test`` and ``train`` splits hold the same
            rows. Each row has ``prompt`` (the original message list plus an
            assistant turn) and ``ground_truth`` (``json.dumps`` of the whole
            raw record).
        """
        dataset = DatasetDict()
        data = mmengine.load(osp.join(path, f'{name}.json'))
        raw_data = []
        # NOTE(review): assumes the file is a mapping of sample-id -> record;
        # only the records are used, in file order.
        for record in data.values():
            origin_prompt = record['origin_prompt']
            # The prompt may be stored either as a list or as a JSON string.
            if isinstance(origin_prompt, str):
                origin_prompt = json.loads(origin_prompt)
            # Aligning the default roles of opencompass.
            # A missing 'ground_truth' becomes the string 'None' here.
            prompt = origin_prompt + [
                dict(role='assistant',
                     content=str(record.get('ground_truth')))
            ]
            raw_data.append({
                'prompt': prompt,
                'ground_truth': json.dumps(record)
            })
        dataset['test'] = Dataset.from_list(raw_data)
        dataset['train'] = Dataset.from_list(raw_data)
        return dataset
@TEXT_POSTPROCESSORS.register_module('teval')
def teval_postprocess(text: str) -> str:
    """Clean a raw model response before T-Eval scoring.

    A non-string input is only converted with ``str``. A string is:

    1. truncated at each stop marker in turn, keeping only the text before
       the marker;
    2. trimmed, with a leading ```json fence and surrounding backticks
       removed;
    3. reduced from a doubled outer ``{{ ... }}`` pair to single braces.
    """
    if not isinstance(text, str):
        return str(text)

    # Presumably end-of-turn / next-turn markers emitted by various chat
    # templates; each cut keeps only what precedes the first occurrence.
    stop_markers = ('<eoa>', '<TOKENS_UNUSED_1>', '<|im_end|>',
                    '\nuser', '\nUSER', '[INST]')
    for marker in stop_markers:
        text = text.partition(marker)[0]
    cleaned = text.strip()

    json_fence = '```json'
    if cleaned.startswith(json_fence):
        cleaned = cleaned[len(json_fence):]
    cleaned = cleaned.strip('`').strip()

    # Collapse a doubled outer brace pair, e.g. "{{...}}" -> "{...}".
    if cleaned.startswith('{{') and cleaned.endswith('}}'):
        cleaned = cleaned[1:-1]
    return str(cleaned)
|