|
from abc import abstractmethod |
|
from ..smp import * |
|
|
|
|
|
class TextBaseDataset:
    """Base class for text-only datasets backed by a TSV file.

    On construction the dataset table is obtained through ``load_data``, its
    ``index`` column is normalised, the table is stored on ``self.data`` and
    the ``post_build`` hook is invoked. ``load_data``, ``post_build``,
    ``build_prompt`` and ``evaluate`` are the natural override points.
    """

    MODALITY = 'TEXT'
    # Dataset name -> URL of its TSV file. Empty here; presumably populated
    # by concrete dataset classes.
    DATASET_URL = {}
    # Dataset name -> expected MD5 of the TSV file. An entry is optional:
    # datasets without one are loaded without checksum verification.
    DATASET_MD5 = {}

    def __init__(self, dataset='MMBench', **kwargs):
        """Load ``dataset`` and store its table on ``self.data``.

        Args:
            dataset: Name of the dataset; passed to ``load_data`` and
                ``post_build`` and kept as ``self.dataset_name``.
            **kwargs: Accepted but not used by this base implementation.
        """
        self.dataset_name = dataset

        data = self.load_data(dataset)

        # Bring every index to ``str`` first so the column is homogeneous,
        # then convert the whole column to ``int`` only if every entry
        # passes ``istype(x, int)`` (project helper; presumably checks that
        # the string denotes an integer -- confirm against its definition).
        data['index'] = [str(x) for x in data['index']]
        if np.all([istype(x, int) for x in data['index']]):
            data['index'] = [int(x) for x in data['index']]

        self.data = data
        self.post_build(dataset)

    def __len__(self):
        """Return ``len(self.data)``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the row at position ``idx`` of ``self.data`` as a dict."""
        return dict(self.data.iloc[idx])

    def prepare_tsv(self, url, file_md5=None):
        """Make sure the TSV behind ``url`` is available locally and load it.

        The file is stored under the directory returned by ``LMUDataRoot()``,
        named after the last path component of ``url``. It is (re)downloaded
        when it is missing or when ``file_md5`` is given and does not match
        the MD5 of the file on disk.

        Args:
            url: Download URL of the TSV file.
            file_md5: Expected MD5 of the file, or ``None`` to skip the
                checksum comparison.

        Returns:
            Whatever ``load`` returns for the resulting path.
        """
        data_root = LMUDataRoot()
        os.makedirs(data_root, exist_ok=True)

        file_name = url.split('/')[-1]
        data_path = osp.join(data_root, file_name)

        # The existing copy is usable when it exists and either no checksum
        # was requested or the checksum matches. ``md5`` is only evaluated
        # when the file exists and a checksum was supplied.
        up_to_date = osp.exists(data_path) and (
            file_md5 is None or md5(data_path) == file_md5
        )
        update_flag = False
        if not up_to_date:
            warnings.warn('The dataset tsv is not downloaded')
            download_file(url, data_path)
            update_flag = True

        # Files larger than 1 GB are loaded from a "*_local.tsv" sibling
        # produced by LOCALIZE. It is regenerated when missing, when the
        # FORCE_LOCAL environment variable is set to a non-empty value, or
        # when the source file was just downloaded.
        if file_size(data_path, 'GB') > 1:
            local_path = data_path.replace('.tsv', '_local.tsv')
            if not osp.exists(local_path) or os.environ.get('FORCE_LOCAL', None) or update_flag:
                from ..tools import LOCALIZE
                LOCALIZE(data_path, local_path)
            data_path = local_path
        return load(data_path)

    def dump_image(self, line):
        """Return an empty list; this base class produces no images for ``line``."""
        return []

    def display(self, line):
        """Show one record through ``mmqa_display``.

        Args:
            line: Either an integer row position in ``self.data`` or an
                already-extracted record (``pd.Series`` or ``dict``).
        """
        if isinstance(line, int):
            line = self.data.iloc[line]
        assert isinstance(line, (pd.Series, dict))
        mmqa_display(line)

    @classmethod
    def supported_datasets(cls):
        """Return the dataset names (keys of ``DATASET_URL``) as a list."""
        return list(cls.DATASET_URL)

    def load_data(self, dataset):
        """Look up the URL (and checksum, if any) of ``dataset`` and load it.

        Args:
            dataset: Dataset name; must be a key of ``DATASET_URL``.

        Returns:
            The result of ``prepare_tsv`` for that dataset.

        Raises:
            KeyError: If ``dataset`` is not a key of ``DATASET_URL``.
        """
        url = self.DATASET_URL[dataset]
        # The checksum is optional: ``prepare_tsv`` skips verification when
        # it receives ``None``, so a missing entry must not raise here.
        file_md5 = self.DATASET_MD5.get(dataset)
        return self.prepare_tsv(url, file_md5)

    def post_build(self, dataset):
        """Hook called at the end of ``__init__``; does nothing by default."""
        pass

    def build_prompt(self, line):
        """Build the message list for one record.

        Args:
            line: Either an integer row position in ``self.data`` or a
                record supporting ``line['question']``.

        Returns:
            A single-element list holding
            ``{'type': 'text', 'value': <question>}``.
        """
        if isinstance(line, int):
            line = self.data.iloc[line]

        question = line['question']
        return [dict(type='text', value=question)]

    # NOTE: the class does not derive from ``abc.ABC``, so this decorator
    # does not prevent instantiation; it only marks the method as the one
    # subclasses are expected to implement.
    @abstractmethod
    def evaluate(self, eval_file, **judge_kwargs):
        """Evaluate predictions stored in ``eval_file``; no-op in this base class."""
        pass
|
|