from huggingface_hub import snapshot_download

from ..smp import *
from .video_base import VideoBaseDataset
from .utils import build_judge, DEBUG_MESSAGE
from ..utils import track_progress_rich

FAIL_MSG = 'Failed to obtain answer via API.'
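
# FAIL_MSG is the sentinel substring this module searches for in model / judge
# outputs to detect failed API calls (see load_pack_answers and evaluate below).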


def unwrap_hf_pkl(pth, suffix='.mp4'):
    base_dir = os.path.join(pth, 'video_pkl/')
    target_dir = os.path.join(pth, 'video/')
    pickle_files = [os.path.join(base_dir, file) for file in os.listdir(base_dir)]
    pickle_files.sort()

    if not os.path.exists(target_dir):
        os.makedirs(target_dir, exist_ok=True)
        for pickle_file in pickle_files:
            with open(pickle_file, 'rb') as file:
                video_data = pickle.load(file)

            for video_name, video_content in video_data.items():
                output_path = os.path.join(target_dir, f'{video_name}{suffix}')
                with open(output_path, 'wb') as output_file:
                    output_file.write(video_content)
        print('The video files have been restored from the pickle files.')
    else:
        print('The video files already exist.')
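
# A sketch of the snapshot layout unwrap_hf_pkl expects (names illustrative):
#   <dataset_path>/video_pkl/*.pkl   # each pickle maps video_name -> raw video bytes
#   <dataset_path>/video/            # created here; one <video_name>.mp4 per entry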


class MMBenchVideo(VideoBaseDataset):

    MD5 = '98f7df3eb1007fc375ea6fe88a98e2ff'
    SYS = 'You are an AI assistant responsible for answering questions about videos.'

    FRAMES_TMPL_PACK = """
You will be provided with {} separate frames uniformly sampled from a video, \
the frames are provided in chronological order of the video.
Please analyze these images and provide the answer / answers to the \
following question / questions about the video content.
If multiple questions are provided (with indices I1, I2, I3, ...), \
you should organize your answers in the following json format:
{{
    'I1': 'Answer to Question I1',
    'I2': 'Answer to Question I2',
    ...
}}
Otherwise, please directly reply with your response to the only question.
Even if the information in these separate frames is not enough to give an answer,
PLEASE GIVE A RESPONSE TO EACH OF THE QUESTIONS IN THE FORMAT DESCRIBED ABOVE.
"""

    FRAMES_TMPL_NOPACK = """
You will be provided with {} separate frames uniformly sampled from a video, \
the frames are provided in chronological order of the video.
Please analyze these images and provide the answer to the question about the video content.
Please directly reply with your response to the only question.
"""

    TYPE = 'Video-VQA'
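
    # In pack mode, all questions about one video are batched into a single
    # multi-image request and the model must answer in the JSON format above;
    # in non-pack mode each question is asked on its own (see build_prompt below).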

    def __init__(self, dataset='MMBench-Video', pack=False):
        super().__init__(dataset=dataset, pack=pack)

    @classmethod
    def supported_datasets(cls):
        return ['MMBench-Video']

    def prepare_dataset(self, dataset_name='MMBench-Video', repo_id='opencompass/MMBench-Video'):
        def check_integrity(pth):
            data_file = osp.join(pth, f'{dataset_name}.tsv')
            # Guard against a partially populated cache: a missing tsv should
            # trigger a fresh download rather than an exception from md5().
            if not osp.exists(data_file):
                return False
            if md5(data_file) != self.MD5:
                return False
            data = load(data_file)
            for video_pth in data['video_path']:
                if not osp.exists(osp.join(pth, video_pth)):
                    return False
            return True

        cache_path = get_cache_path(repo_id)
        if cache_path is not None and check_integrity(cache_path):
            dataset_path = cache_path
        else:
            if modelscope_flag_set():
                from modelscope import dataset_snapshot_download
                dataset_path = dataset_snapshot_download(dataset_id=repo_id)
            else:
                dataset_path = snapshot_download(repo_id=repo_id, repo_type='dataset')
            unwrap_hf_pkl(dataset_path)
        self.video_path = osp.join(dataset_path, 'video/')
        data_file = osp.join(dataset_path, f'{dataset_name}.tsv')

        return dict(data_file=data_file, root=osp.join(dataset_path, 'video'))
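
    # The returned dict is consumed by VideoBaseDataset; roughly (paths are
    # illustrative): {'data_file': '<cache>/MMBench-Video.tsv', 'root': '<cache>/video'}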

    def build_prompt_pack(self, line, num_frames, fps=-1):
        if isinstance(line, int):
            assert line < len(self)
            video = self.videos[line]
        elif isinstance(line, pd.Series):
            video = line['video']
        elif isinstance(line, str):
            video = line
        else:
            # Fail early instead of hitting a NameError on `video` below.
            raise TypeError(f'Unsupported line type: {type(line)}')

        frames = self.save_video_frames(video, num_frames, fps)
        sub = self.data[self.data['video'] == video]
        sys_prompt = self.SYS + self.FRAMES_TMPL_PACK.format(len(frames))
        message = [dict(type='text', value=sys_prompt)]
        for im in frames:
            message.append(dict(type='image', value=im))
        nq = len(sub)
        prompt = 'Questions: \n{}\nAnswers: \n'
        qs = {int(sub.iloc[i]['index']): sub.iloc[i]['question'] for i in range(nq)}
        prompt = prompt.format(json.dumps(qs))
        message.append(dict(type='text', value=prompt))
        return message
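
    # A packed message is a flat list of typed segments, e.g. (sketch):
    #   [{'type': 'text', 'value': SYS + frames preamble},
    #    {'type': 'image', 'value': '/path/frame-1.jpg'}, ...,
    #    {'type': 'text', 'value': 'Questions: \n{"1": "..."}\nAnswers: \n'}]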

    def build_prompt_nopack(self, line, num_frames, video_llm, fps):
        if isinstance(line, int):
            assert line < len(self)
            line = self.data.iloc[line]
        if video_llm:
            question = line['question']
            _, video_idx_path = os.path.split(line['video_path'])
            message = [dict(type='text', value=question)]
            message.append(dict(type='video', value=os.path.join(self.video_path, video_idx_path)))
            return message
        else:
            frames = self.save_video_frames(line['video'], num_frames, fps)
            sys_prompt = self.FRAMES_TMPL_NOPACK.format(len(frames))
            message = [dict(type='text', value=sys_prompt)]
            for im in frames:
                message.append(dict(type='image', value=im))
            prompt = 'Question: {}\nAnswer: '.format(line['question'])
            message.append(dict(type='text', value=prompt))
            return message
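
    # With video_llm=True the raw video file is handed to the backend, which is
    # assumed to do its own frame sampling; otherwise num_frames (or fps-based)
    # frames are extracted locally and sent as images.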

    def build_prompt(self, line, num_frames, video_llm, fps):
        if self.pack and not video_llm:
            return self.build_prompt_pack(line, num_frames, fps)
        else:
            return self.build_prompt_nopack(line, num_frames, video_llm, fps)

    @staticmethod
    def remove_side_quote(s, syms=[',', '"', "'"]):
        if np.all([x in syms for x in s]):
            return ''
        while s[0] in syms:
            s = s[1:]
        while s[-1] in syms:
            s = s[:-1]
        return s
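
    # e.g. remove_side_quote("'I1'") == 'I1' and remove_side_quote("'Yes.',") == 'Yes.';
    # a string made up entirely of quote characters collapses to ''.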

    @staticmethod
    def robust_json_load(s):
        try:
            jsons = list(extract_json_objects(s))
            assert len(jsons) == 1
            return jsons[0]
        except Exception:
            if '{' in s and s.find('{') == s.rfind('{'):
                sub_str = s[s.find('{') + 1:].strip()
                lines = sub_str.split('\n')
                res = {}
                for l in lines:
                    l = l.strip()
                    if ': ' in l:
                        # Split on the first ': ' only, so values that themselves
                        # contain ': ' are not truncated.
                        key, val = l.split(': ', 1)
                        key = MMBenchVideo.remove_side_quote(key.strip())
                        val = MMBenchVideo.remove_side_quote(val.strip())
                        if len(key) and len(val):
                            res[key] = val
                return res
            return None
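
    # e.g. robust_json_load('{"I1": "Yes"}') -> {'I1': 'Yes'}; if the reply is
    # only near-JSON, the single {...} block is recovered line by line;
    # anything else yields None.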

    def load_pack_answers(self, data_raw):
        vstats = defaultdict(int)
        data = defaultdict(dict)

        for k in data_raw:
            ans = data_raw[k].strip()
            if FAIL_MSG in ans:
                vstats['GEN_FAIL'] += 1
                continue
            res = self.robust_json_load(ans)
            if res is not None:
                data[k] = res
                vstats['PARSE_OK'] += 1
            else:
                vstats['PARSE_FAIL'] += 1

        meta = cp.deepcopy(self.data)
        lt = len(meta)
        prediction = []
        for i in range(lt):
            line = meta.iloc[i]
            vid = line['video']
            idx = str(line['index'])
            prediction.append(data[vid][idx] if idx in data[vid] else None)
        meta['prediction'] = prediction
        vstats['VALIDQ'] = len([x for x in prediction if x is not None])
        vstats['INVALIDQ'] = len([x for x in prediction if x is None])
        return meta, vstats
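
    # vstats counts per-video parse outcomes (GEN_FAIL / PARSE_OK / PARSE_FAIL)
    # plus per-question totals (VALIDQ / INVALIDQ).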

    @classmethod
    def evaluate(cls, eval_file, **judge_kwargs):
        from .utils.mmbench_video import get_dimension_rating, system_prompt, build_prompt

        assert eval_file.endswith('.xlsx'), 'data file should be an xlsx file'
        judge = judge_kwargs['model']
        nproc = judge_kwargs.pop('nproc', 4)

        tmp_file = eval_file.replace('.xlsx', f'_{judge}_tmp.pkl')
        tgt_file = eval_file.replace('.xlsx', f'_{judge}_rating.json')
        score_file = eval_file.replace('.xlsx', f'_{judge}_score.xlsx')

        model = build_judge(system_prompt=system_prompt, **judge_kwargs)
        assert model.working(), 'MMBench-Video evaluation requires a working OPENAI API\n' + DEBUG_MESSAGE

        if not osp.exists(score_file):
            res = {} if not osp.exists(tmp_file) else load(tmp_file)
            res = {k: v for k, v in res.items() if model.fail_msg not in v}

            data = load(eval_file)
            data_un = data[~data['index'].isin(res)]
            data_un = data_un[~pd.isna(data_un['prediction'])]
            lt = len(data_un)
            prompts = [build_prompt(data_un.iloc[i]) for i in range(lt)]
            indices = [data_un.iloc[i]['index'] for i in range(lt)]

            if len(prompts):
                _ = track_progress_rich(
                    model.generate,
                    prompts,
                    keys=indices,
                    save=tmp_file,
                    nproc=nproc,
                    chunksize=nproc
                )
            score_map = load(tmp_file)
            data['score'] = [score_map[idx] if idx in score_map else -1 for idx in data['index']]
            rejected = [x for x in score_map.values() if FAIL_MSG in x]
            data['score'] = [int(x) if istype(x, int) else -1 for x in data['score']]
            print(
                f'Among {len(data)} questions, failed to obtain prediction for '
                f'{len(data) - len(score_map)} questions, failed to obtain the score '
                f'for another {len(rejected)} questions. Those questions will be '
                f'counted as 0 score in ALL rating and will not be counted in VALID rating.'
            )

            dump(data, score_file)

        rating = get_dimension_rating(score_file)
        dump(rating, tgt_file)
        return rating
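
# A typical end-to-end flow, as a sketch (file names and the judge model are
# illustrative; in practice the VLMEvalKit runner drives these calls):
#
#   dataset = MMBenchVideo('MMBench-Video', pack=False)
#   message = dataset.build_prompt(0, num_frames=8, video_llm=False, fps=-1)
#   ... run inference, dump predictions to 'MMBench-Video_model.xlsx' ...
#   rating = MMBenchVideo.evaluate('MMBench-Video_model.xlsx', model='gpt-4o-0806')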