from huggingface_hub import snapshot_download

from ..smp import *
from .video_base import VideoBaseDataset
from .utils import build_judge, DEBUG_MESSAGE

FAIL_MSG = 'Failed to obtain answer via API.'


def unwrap_hf_pkl(pth, suffix='.mp4'):
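    """Restore video files from pickle archives under `pth/video_pkl/` into `pth/video/`.

    Each pickle file maps video names to raw video bytes, which are written out
    as `<video_name><suffix>` files.
    """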
    base_dir = os.path.join(pth, 'video_pkl/')
    target_dir = os.path.join(pth, 'video/')
    pickle_files = [os.path.join(base_dir, file) for file in os.listdir(base_dir)]
    pickle_files.sort()

    if not os.path.exists(target_dir):
        os.makedirs(target_dir, exist_ok=True)
        for pickle_file in pickle_files:
            with open(pickle_file, 'rb') as file:
                video_data = pickle.load(file)
            for video_name, video_content in video_data.items():
                output_path = os.path.join(target_dir, f'{video_name}{suffix}')
                with open(output_path, 'wb') as output_file:
                    output_file.write(video_content)
        print('Video files have been restored from the pickle files.')
    else:
        print('Video files already exist.')


class VideoMME(VideoBaseDataset):
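    """Video-MME: a multiple-choice video QA benchmark, optionally using subtitles."""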

    MD5 = '85bdd91f9b29a99354c23b97ab7c113c'
    SYS = ''

    FRAMES_TMPL_NOSUB = """
These are the frames of a video. \
Select the best answer to the following multiple-choice question based on the video. \
Respond with only the letter (A, B, C, or D) of the correct option.
"""

    FRAMES_TMPL_SUB = """
These are the frames of a video. \
This video's subtitles are listed below:
{}
Select the best answer to the following multiple-choice question based on the video. \
Respond with only the letter (A, B, C, or D) of the correct option.
"""

    TYPE = 'Video-MCQ'

    def __init__(self, dataset='Video-MME', use_subtitle=False):
        super().__init__(dataset=dataset)
        self.use_subtitle = use_subtitle
        self.dataset_name = dataset

    @classmethod
    def supported_datasets(cls):
        return ['Video-MME']

    def prepare_dataset(self, dataset_name='Video-MME', repo_id='lmms-lab/Video-MME'):
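        """Download the dataset (HF or ModelScope) if the local cache is invalid,
        unpack videos and subtitles, and generate the TSV index.

        Returns a dict with `data_file` (TSV path) and `root` (dataset root).
        """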

        def check_integrity(pth):
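            # A cached copy is valid only if the TSV exists, its MD5 matches,
            # and every referenced video file is present on disk.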
            data_file = osp.join(pth, f'{dataset_name}.tsv')

            if not os.path.exists(data_file):
                return False

            if md5(data_file) != self.MD5:
                return False

            data = load(data_file)
            for video_pth in data['video_path']:
                if not osp.exists(osp.join(pth, video_pth)):
                    return False
            return True

        cache_path = get_cache_path(repo_id)
        if cache_path is not None and check_integrity(cache_path):
            dataset_path = cache_path
        else:

            def unzip_hf_zip(pth):
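                # Extract each video*.zip archive into `pth/video/` and
                # subtitle.zip into `pth/subtitle/`, flattening any directory
                # structure inside the archives.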
                import zipfile
                base_dir = pth
                target_dir = os.path.join(pth, 'video/')
                zip_files = [
                    os.path.join(base_dir, file) for file in os.listdir(base_dir)
                    if file.endswith('.zip') and file.startswith('video')
                ]
                zip_files.sort()

                if not os.path.exists(target_dir):
                    os.makedirs(target_dir, exist_ok=True)
                    for zip_file in zip_files:
                        with zipfile.ZipFile(zip_file, 'r') as zip_ref:
                            for member in zip_ref.namelist():
                                if not member.endswith('/'):
                                    source = zip_ref.open(member)
                                    target = open(os.path.join(target_dir, os.path.basename(member)), 'wb')
                                    with source, target:
                                        target.write(source.read())
                    print('Video files have been extracted from the zip archives.')
                else:
                    print('Video files already exist.')

                subtitle_zip_file = os.path.join(base_dir, 'subtitle.zip')
                subtitle_target_dir = os.path.join(base_dir, 'subtitle')

                if not os.path.exists(subtitle_target_dir):
                    os.makedirs(subtitle_target_dir, exist_ok=True)
                    with zipfile.ZipFile(subtitle_zip_file, 'r') as zip_ref:
                        for member in zip_ref.namelist():
                            if not member.endswith('/'):
                                source = zip_ref.open(member)
                                target = open(os.path.join(subtitle_target_dir, os.path.basename(member)), 'wb')
                                with source, target:
                                    target.write(source.read())
                    print('Subtitle files have been extracted from subtitle.zip.')
                else:
                    print('Subtitle files already exist.')

            def generate_tsv(pth):
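                # Build the TSV index from the HF parquet annotations;
                # skip when a TSV with the expected MD5 already exists.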
                data_file = osp.join(pth, f'{dataset_name}.tsv')
                if os.path.exists(data_file) and md5(data_file) == self.MD5:
                    return

                data_file = pd.read_parquet(os.path.join(pth, 'videomme/test-00000-of-00001.parquet'))
                data_file = data_file.assign(index=range(len(data_file)))
                data_file['video'] = data_file['videoID']
                data_file['video_path'] = data_file['videoID'].apply(lambda x: f'./video/{x}.mp4')
                data_file['subtitle_path'] = data_file['videoID'].apply(lambda x: f'./subtitle/{x}.srt')
                data_file['candidates'] = data_file['options'].apply(lambda x: x.tolist())

                data_file = data_file[['index', 'video', 'video_path', 'duration', 'domain', 'candidates',
                                       'sub_category', 'task_type', 'subtitle_path', 'question', 'answer']]

                data_file.to_csv(osp.join(pth, f'{dataset_name}.tsv'), sep='\t', index=False)

            if modelscope_flag_set():
                from modelscope import dataset_snapshot_download
                dataset_path = dataset_snapshot_download(dataset_id=repo_id)
            else:
                dataset_path = snapshot_download(repo_id=repo_id, repo_type='dataset')
            unzip_hf_zip(dataset_path)
            generate_tsv(dataset_path)

        data_file = osp.join(dataset_path, f'{dataset_name}.tsv')

        return dict(data_file=data_file, root=dataset_path)

    def save_video_frames(self, video, num_frames=8, fps=-1, video_llm=False):
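        """Sample frame indices uniformly (fixed `num_frames`) or at a target
        `fps`, dump the frames as images (skipped for video LLMs), and return
        the frame paths, sampled indices, and basic video info."""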
        vid_path = osp.join(self.data_root, 'video', video + '.mp4')
        vid = decord.VideoReader(vid_path)
        video_info = {
            'fps': vid.get_avg_fps(),
            'n_frames': len(vid),
        }
        if num_frames > 0 and fps < 0:
            step_size = len(vid) / (num_frames + 1)
            indices = [int(i * step_size) for i in range(1, num_frames + 1)]
            frame_paths = self.frame_paths(video, num_frames)
        elif fps > 0:
            total_duration = video_info['n_frames'] / video_info['fps']
            required_frames = int(total_duration * fps)
            step_size = video_info['fps'] / fps
            indices = [int(i * step_size) for i in range(required_frames)]
            frame_paths = self.frame_paths_fps(video, len(indices), fps)
        else:
            raise ValueError('Either num_frames > 0 or fps > 0 must be specified.')

        flag = np.all([osp.exists(p) for p in frame_paths])

        if not flag:
            images = [vid[i].asnumpy() for i in indices]
            images = [Image.fromarray(arr) for arr in images]
            for im, pth in zip(images, frame_paths):
                if not osp.exists(pth) and not video_llm:
                    im.save(pth)

        return frame_paths, indices, video_info

    def save_video_into_images(self, line, num_frames=8):
        frame_paths, _, _ = self.save_video_frames(line['video'], num_frames)
        return frame_paths

    def build_prompt(self, line, num_frames, video_llm, fps):
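        """Assemble the interleaved message for one sample: system prompt, the
        raw video (for video LLMs) or sampled frames, the instruction template
        (with subtitles if enabled), and the question with its options."""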
        if isinstance(line, int):
            assert line < len(self)
            line = self.data.iloc[line]

        frames, indices, video_info = self.save_video_frames(line['video'], num_frames, fps, video_llm)

        if self.use_subtitle and os.path.exists(osp.join(self.data_root, line['subtitle_path'])):
            import pysubs2
            subs = pysubs2.load(osp.join(self.data_root, line['subtitle_path']), encoding='utf-8')
            subtitles = []

            for selected_frame_id in indices:
                sub_text = ''
                cur_time = pysubs2.make_time(fps=video_info['fps'], frames=selected_frame_id)
                for sub in subs:
                    if sub.start < cur_time and sub.end > cur_time:
                        sub_text = sub.text.replace('\\N', ' ')
                        break
                if sub_text.strip():
                    subtitles.append(sub_text)
            subtitles = '\n'.join(subtitles)
        else:
            subtitles = ''

        message = [dict(type='text', value=self.SYS)]
        if video_llm:
            message.append(dict(type='video', value=osp.join(self.data_root, 'video', line['video'] + '.mp4')))
        else:
            for im in frames:
                message.append(dict(type='image', value=im))

        text_prompt = self.FRAMES_TMPL_NOSUB if not self.use_subtitle else self.FRAMES_TMPL_SUB.format(subtitles)
        message.append(dict(type='text', value=text_prompt))
        line['question'] += '\n' + '\n'.join(eval(line['candidates']))
        prompt = 'Question: {}\nAnswer: '.format(line['question'])
        message.append(dict(type='text', value=prompt))
        return message

    @classmethod
    def evaluate(cls, eval_file, **judge_kwargs):
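        """Score an .xlsx prediction file, writing `_score.xlsx` and `_rating.json`.

        Uses exact matching by default; when a judge model is configured and
        working, it is used to extract the chosen option whenever regex
        matching fails.
        """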
        from .utils.videomme import get_dimension_rating, extract_characters_regex, extract_option

        assert eval_file.endswith('.xlsx'), 'data file should be an xlsx file'

        tmp_file = eval_file.replace('.xlsx', '_tmp.pkl')
        tgt_file = eval_file.replace('.xlsx', '_rating.json')
        score_file = eval_file.replace('.xlsx', '_score.xlsx')

        if not osp.exists(score_file):
            model = judge_kwargs.get('model', 'exact_matching')
            assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']

            if model == 'exact_matching':
                model = None
            elif gpt_key_set():
                model = build_judge(**judge_kwargs)
                if not model.working():
                    warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
                    warnings.warn(DEBUG_MESSAGE)
                    model = None
            else:
                warnings.warn('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
                model = None
            res = {} if not osp.exists(tmp_file) else load(tmp_file)
            res = {k: v for k, v in res.items() if FAIL_MSG not in v}

            data = load(eval_file)
            data_un = data[~pd.isna(data['prediction'])]

            for idx in data['index']:
                ans = data.loc[data['index'] == idx, 'answer'].values[0]
                pred = str(data.loc[data['index'] == idx, 'prediction'].values[0])

                if extract_characters_regex(pred) == '':
                    extract_pred = extract_option(
                        model,
                        data.loc[data['index'] == idx].to_dict(orient='records')[0],
                        'Video-MME'
                    )
                    data.loc[idx, 'score'] = int(extract_pred == ans)
                else:
                    data.loc[idx, 'score'] = int(extract_characters_regex(pred) == ans)

            rejected = [x for x in data['score'] if x == -1]

            print(
                f'Among {len(data)} questions, failed to obtain prediction for {len(data) - len(data_un)} questions, '
                f'failed to obtain the score for another {len(rejected)} questions. '
                f'Those questions will be counted as -1 score in ALL rating, and will not be counted in VALID rating.'
            )

            dump(data, score_file)

        rating = get_dimension_rating(score_file)
        dump(rating, tgt_file)
        return rating