from .text_base import TextBaseDataset
from .utils import build_judge, DEBUG_MESSAGE
from ..smp import *


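# Text-only multiple-choice (MCQ) dataset base class: builds plain-text prompts
# from dataset rows and scores predictions with an LLM judge or exact matching.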
class TextMCQDataset(TextBaseDataset):
    TYPE = 'MCQ'

    DATASET_URL = {}

    DATASET_MD5 = {}

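    # Turn a dataset row (or its integer index) into a text-only prompt:
    # optional hint, the question, and any option columns A-Z present in the row.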
    def build_prompt(self, line):

        if isinstance(line, int):
            line = self.data.iloc[line]

        question = line['question']
        options = {
            cand: line[cand]
            for cand in string.ascii_uppercase
            if cand in line and not pd.isna(line[cand])
        }
        options_prompt = 'Options:\n'
        for key, item in options.items():
            options_prompt += f'{key}. {item}\n'
        hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
        prompt = ''
        if hint is not None:
            prompt += f'Hint: {hint}\n'
        prompt += f'Question: {question}\n'
        if len(options):
            prompt += options_prompt
            prompt += 'Please select the correct answer from the options above. \n'

        msgs = []

        msgs.append(dict(type='text', value=prompt))

        return msgs

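    # Score predictions in `eval_file`: use an LLM judge when one is configured
    # and reachable, otherwise fall back to exact matching, then report accuracy.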
    def evaluate(self, eval_file, **judge_kwargs):
        from .utils.multiple_choice import report_acc, report_acc_MMT, mcq_circular_eval, mcq_vanilla_eval

        dataset_map = {
            'MMBench_TEST_EN': 'MMBench', 'MMBench_TEST_EN_V11': 'MMBench_V11',
            'MMBench_TEST_CN': 'MMBench_CN', 'MMBench_TEST_CN_V11': 'MMBench_CN_V11'
        }
        dataset = self.dataset_name
        if dataset in dataset_map:
            dataset = dataset_map[dataset]
        nproc = judge_kwargs.pop('nproc', 4)

        circular = False

        suffix = eval_file.split('.')[-1]
        model = judge_kwargs.get('model', 'exact_matching')
        assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']
        name_str_map = {'chatgpt-0125': 'openai', 'gpt-4-0125': 'gpt4'}
        name_str = name_str_map[model] if model in name_str_map else model

        if model == 'exact_matching':
            model = None
        elif gpt_key_set():
            model = build_judge(**judge_kwargs)
            if not model.working():
                warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
                warnings.warn(DEBUG_MESSAGE)
                model = None
        else:
            warnings.warn('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
            model = None

        result_file = eval_file.replace(f'.{suffix}', f'_{name_str}_result.pkl')

        data = load(eval_file)
        data = data.sort_values(by='index')
        data['prediction'] = [str(x) for x in data['prediction']]
        for k in data.keys():
            data[k.lower() if k not in list(string.ascii_uppercase) else k] = data.pop(k)

        meta = self.data
        meta_q_map = {x: y for x, y in zip(meta['index'], meta['question'])}
        data_map = {x: y for x, y in zip(data['index'], data['question'])}
        for k in data_map:
            assert k in meta_q_map, (
                f'eval_file should be the same as or a subset of dataset {self.dataset_name}'
            )

        if circular:
            data = mcq_circular_eval(model, data, meta, nproc, result_file, self.dataset_name)
        else:
            data = mcq_vanilla_eval(model, data, meta, nproc, result_file, self.dataset_name)

        dump(data, eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))
        data = load(eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))

        if 'MMT' in dataset:
            acc = report_acc_MMT(data)
        else:
            acc = report_acc(data)

        score_file = eval_file.replace(f'.{suffix}', '_acc.csv')
        dump(acc, score_file)

        return acc


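# MCQ dataset backed by a user-supplied TSV located at <LMUDataRoot>/<dataset>.tsv;
# TSVs larger than 1 GB are converted to a *_local.tsv copy via LOCALIZE before loading.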
class CustomTextMCQDataset(TextMCQDataset):

    def load_data(self, dataset):
        data_path = osp.join(LMUDataRoot(), f'{dataset}.tsv')

        if file_size(data_path, 'GB') > 1:
            local_path = data_path.replace('.tsv', '_local.tsv')
            if not osp.exists(local_path) or os.environ.get('FORCE_LOCAL', None):
                from ..tools import LOCALIZE
                LOCALIZE(data_path, local_path)
            data_path = local_path
        return load(data_path)
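

# Example usage (a minimal sketch, kept as a comment; 'MyMCQBench' and the prediction
# file name are hypothetical, and the constructor is assumed to accept the dataset
# name as in TextBaseDataset):
#
#     dataset = CustomTextMCQDataset('MyMCQBench')        # expects <LMUDataRoot>/MyMCQBench.tsv
#     msgs = dataset.build_prompt(0)                      # prompt messages for the first row
#     acc = dataset.evaluate('MyMCQBench_preds.xlsx',     # per-sample model predictions
#                            model='exact_matching')      # score without an LLM judge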