TwT-6's picture
Upload 2667 files
256a159 verified
raw
history blame contribute delete
2.76 kB
import os
from typing import List
from mmengine.dataset import Compose
from torch.utils.data import Dataset
from opencompass.registry import DATASETS
@DATASETS.register_module()
class MMEDataset(Dataset):
    """Dataset to load MME dataset.

    Two on-disk layouts are accepted for every task directory
    ``<data_dir>/<task>``:

    * split layout: images under ``<task>/images`` and question files under
      ``<task>/questions_answers_YN``;
    * flat layout: ``*.txt`` question files stored next to their images
      directly inside ``<task>``.

    Question files and images are paired by sorted order and must share the
    same file-name prefix (the part before the first ``.``). The first two
    lines of each question file are read; each must have the form
    ``question<TAB>answer``. Every image therefore contributes two samples.

    Args:
        data_dir (str): The path of the dataset.
        pipeline (List[dict]): The data augmentation.
    """

    tasks = [
        'artwork', 'celebrity', 'code_reasoning', 'color',
        'commonsense_reasoning', 'count', 'existence', 'landmark',
        'numerical_calculation', 'OCR', 'position', 'posters', 'scene',
        'text_translation'
    ]
    # (image sub-directory, question sub-directory) used by the split layout.
    sub_dir_name = ('images', 'questions_answers_YN')

    def __init__(self, data_dir: str, pipeline: List[dict]) -> None:
        super().__init__()
        self.pipeline = Compose(pipeline)
        self.load_data(data_dir)

    def load_data(self, data_dir: str) -> None:
        """Scan ``data_dir`` and (re)build ``self.data_list``.

        Each entry is a dict with the keys ``img_path``, ``question``,
        ``answer`` and ``task``.

        Args:
            data_dir (str): Root directory holding one sub-directory per
                entry of ``self.tasks``.

        Raises:
            AssertionError: If a task has a different number of question
                files and images, or a question/image pair does not share
                the same file-name prefix.
            ValueError: If one of the first two lines of a question file is
                not of the form ``question<TAB>answer``.
        """
        self.data_list = []
        image_dir, question_dir = self.sub_dir_name
        for task in self.tasks:
            task_dir = os.path.join(data_dir, task)
            q_prefix = os.path.join(task_dir, question_dir)
            if os.path.exists(q_prefix):
                # Split layout: questions and images live in sibling folders.
                i_prefix = os.path.join(task_dir, image_dir)
                q_list = os.listdir(q_prefix)
                i_list = os.listdir(i_prefix)
            else:
                # Flat layout: separate question files from images by
                # extension in a single pass.
                q_prefix = i_prefix = task_dir
                q_list, i_list = [], []
                for fn in os.listdir(task_dir):
                    (q_list if fn.endswith('.txt') else i_list).append(fn)

            # Sorting both lists aligns each question file with its image.
            q_list.sort()
            i_list.sort()

            # Explicit raises instead of ``assert`` so the checks still run
            # under ``python -O``; otherwise ``zip`` would silently truncate
            # or mis-pair the two lists.
            if len(q_list) != len(i_list):
                raise AssertionError(
                    f'Task "{task}": found {len(q_list)} question files but '
                    f'{len(i_list)} images under {task_dir}')

            for q_fn, i_fn in zip(q_list, i_list):
                if q_fn.split('.')[0] != i_fn.split('.')[0]:
                    raise AssertionError(
                        f'Task "{task}": question file "{q_fn}" does not '
                        f'match image "{i_fn}"')
                q_path = os.path.join(q_prefix, q_fn)
                image_path = os.path.join(i_prefix, i_fn)
                for question, answer in self._read_qa_pairs(q_path):
                    self.data_list.append({
                        'img_path': image_path,
                        'question': question,
                        'answer': answer,
                        'task': task
                    })

    @staticmethod
    def _read_qa_pairs(q_path: str) -> List[tuple]:
        """Return the two ``(question, answer)`` pairs stored in ``q_path``.

        Only the first two lines of the file are read.

        Raises:
            ValueError: If a line does not split into exactly two
                tab-separated fields.
        """
        pairs = []
        # Decode explicitly as UTF-8 rather than relying on the platform
        # default encoding, which may not handle non-ASCII question text.
        with open(q_path, 'r', encoding='utf-8') as f:
            for line_no in (1, 2):
                fields = f.readline().strip().split('\t')
                if len(fields) != 2:
                    raise ValueError(
                        f'{q_path}: line {line_no} must be '
                        f'"question<TAB>answer", got {len(fields)} field(s)')
                pairs.append(tuple(fields))
        return pairs

    def __len__(self) -> int:
        """Return the number of samples (two per image)."""
        return len(self.data_list)

    def __getitem__(self, idx: int) -> dict:
        """Return sample ``idx`` after running it through the pipeline."""
        # Hand the pipeline a shallow copy: if a transform mutates its input
        # dict in place (cannot be verified from this file), the cached
        # annotation in ``self.data_list`` stays untouched. All cached values
        # are strings, so a shallow copy is sufficient.
        data_sample = dict(self.data_list[idx])
        return self.pipeline(data_sample)