Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes.
- VLMEvalKit/vlmeval/api/__init__.py +26 -0
- VLMEvalKit/vlmeval/api/bailingmm.py +90 -0
- VLMEvalKit/vlmeval/api/base.py +289 -0
- VLMEvalKit/vlmeval/api/bluelm_v_api.py +120 -0
- VLMEvalKit/vlmeval/api/claude.py +130 -0
- VLMEvalKit/vlmeval/api/cloudwalk.py +107 -0
- VLMEvalKit/vlmeval/api/gemini.py +116 -0
- VLMEvalKit/vlmeval/api/glm_vision.py +95 -0
- VLMEvalKit/vlmeval/api/gpt.py +267 -0
- VLMEvalKit/vlmeval/api/hf_chat_model.py +246 -0
- VLMEvalKit/vlmeval/api/hunyuan.py +147 -0
- VLMEvalKit/vlmeval/api/jt_vl_chat.py +239 -0
- VLMEvalKit/vlmeval/api/qwen_api.py +75 -0
- VLMEvalKit/vlmeval/api/qwen_vl_api.py +219 -0
- VLMEvalKit/vlmeval/api/reka.py +60 -0
- VLMEvalKit/vlmeval/api/sensechat_vision.py +261 -0
- VLMEvalKit/vlmeval/api/siliconflow.py +269 -0
- VLMEvalKit/vlmeval/api/stepai.py +87 -0
- VLMEvalKit/vlmeval/api/taiyi.py +192 -0
- VLMEvalKit/vlmeval/dataset/__init__.py +230 -0
- VLMEvalKit/vlmeval/dataset/cmmmu.py +354 -0
- VLMEvalKit/vlmeval/dataset/dude.py +211 -0
- VLMEvalKit/vlmeval/dataset/dynamath.py +240 -0
- VLMEvalKit/vlmeval/dataset/image_base.py +172 -0
- VLMEvalKit/vlmeval/dataset/image_caption.py +75 -0
- VLMEvalKit/vlmeval/dataset/image_mcq.py +899 -0
- VLMEvalKit/vlmeval/dataset/image_mt.py +128 -0
- VLMEvalKit/vlmeval/dataset/image_vqa.py +1330 -0
- VLMEvalKit/vlmeval/dataset/image_yorn.py +95 -0
- VLMEvalKit/vlmeval/dataset/longvideobench.py +328 -0
- VLMEvalKit/vlmeval/dataset/miabench.py +167 -0
- VLMEvalKit/vlmeval/dataset/mlvu.py +455 -0
- VLMEvalKit/vlmeval/dataset/mmbench_video.py +256 -0
- VLMEvalKit/vlmeval/dataset/mmgenbench.py +69 -0
- VLMEvalKit/vlmeval/dataset/mmlongbench.py +584 -0
- VLMEvalKit/vlmeval/dataset/mmmath.py +446 -0
- VLMEvalKit/vlmeval/dataset/mvbench.py +668 -0
- VLMEvalKit/vlmeval/dataset/slidevqa.py +189 -0
- VLMEvalKit/vlmeval/dataset/tempcompass.py +639 -0
- VLMEvalKit/vlmeval/dataset/text_base.py +88 -0
- VLMEvalKit/vlmeval/dataset/text_mcq.py +123 -0
- VLMEvalKit/vlmeval/dataset/vcr.py +335 -0
- VLMEvalKit/vlmeval/dataset/video_base.py +126 -0
- VLMEvalKit/vlmeval/dataset/video_concat_dataset.py +83 -0
- VLMEvalKit/vlmeval/dataset/videomme.py +287 -0
- VLMEvalKit/vlmeval/dataset/wildvision.py +218 -0
- VLMEvalKit/vlmeval/smp/__init__.py +4 -0
- VLMEvalKit/vlmeval/smp/log.py +47 -0
- VLMEvalKit/vlmeval/smp/misc.py +280 -0
- VLMEvalKit/vlmeval/vlm/aria.py +206 -0
VLMEvalKit/vlmeval/api/__init__.py
ADDED
@@ -0,0 +1,26 @@
from .gpt import OpenAIWrapper, GPT4V
from .hf_chat_model import HFChatModel
from .gemini import GeminiWrapper, GeminiProVision
from .qwen_vl_api import QwenVLWrapper, QwenVLAPI, Qwen2VLAPI
from .qwen_api import QwenAPI
from .claude import Claude_Wrapper, Claude3V
from .reka import Reka
from .glm_vision import GLMVisionAPI
from .cloudwalk import CWWrapper
from .sensechat_vision import SenseChatVisionAPI
from .siliconflow import SiliconFlowAPI, TeleMMAPI
from .hunyuan import HunyuanVision
from .bailingmm import bailingMMAPI
from .bluelm_v_api import BlueLMWrapper, BlueLM_V_API
from .jt_vl_chat import JTVLChatAPI
from .taiyi import TaiyiAPI


__all__ = [
    'OpenAIWrapper', 'HFChatModel', 'GeminiWrapper', 'GPT4V',
    'GeminiProVision', 'QwenVLWrapper', 'QwenVLAPI', 'QwenAPI',
    'Claude3V', 'Claude_Wrapper', 'Reka', 'GLMVisionAPI',
    'CWWrapper', 'SenseChatVisionAPI', 'HunyuanVision', 'Qwen2VLAPI',
    'BlueLMWrapper', 'BlueLM_V_API', 'JTVLChatAPI', 'bailingMMAPI',
    'TaiyiAPI', 'TeleMMAPI', 'SiliconFlowAPI'
]
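
The package re-exports every wrapper, so downstream code only needs `vlmeval.api`. A minimal usage sketch (illustrative, not part of the committed files; it assumes the VLMEvalKit package is importable, `OPENAI_API_KEY` is set, and the image path is a placeholder):

    from vlmeval.api import GPT4V

    # GPT4V is a thin subclass of OpenAIWrapper (see gpt.py below).
    model = GPT4V(model='gpt-4-vision-preview', verbose=True)

    # Messages follow the BaseAPI convention: a list of {'type', 'value'} dicts.
    message = [
        dict(type='image', value='/path/to/example.jpg'),   # placeholder path
        dict(type='text', value='Describe this image in one sentence.'),
    ]
    print(model.generate(message))
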
VLMEvalKit/vlmeval/api/bailingmm.py
ADDED
@@ -0,0 +1,90 @@
import base64
from vlmeval.smp import *
from vlmeval.api.base import BaseAPI
from vlmeval.dataset import DATASET_TYPE
from vlmeval.smp.vlm import encode_image_file_to_base64
import time


class bailingMMWrapper(BaseAPI):

    is_api: bool = True

    def __init__(self,
                 model: str,
                 retry: int = 5,
                 wait: int = 5,
                 key: str = None,
                 verbose: bool = True,
                 system_prompt: str = None,
                 max_tokens: int = 1024,
                 proxy: str = None,
                 **kwargs):

        self.model = model
        self.fail_msg = 'Failed to obtain answer via bailingMM API.'
        if key is None:
            key = os.environ.get('BAILINGMM_API_KEY', None)
        assert key is not None, ('Please set the API Key for bailingMM.')
        self.key = key
        self.headers = {"Content-Type": "application/json"}
        super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)

    def image_to_base64(self, image_path):
        with open(image_path, 'rb') as image_file:
            encoded_string = str(base64.b64encode(image_file.read()), 'utf-8')
        return encoded_string

    def prepare_inputs(self, inputs):
        msgs = cp.deepcopy(inputs)
        content = []
        for i, msg in enumerate(msgs):
            if msg['type'] == 'text':
                pass
            else:
                try:
                    image_data = self.image_to_base64(msg['value'])
                except Exception as e:
                    if self.verbose:
                        self.logger.error(e)
                    image_data = ''
                msg['value'] = image_data
            content.append(msg)
        return content

    def generate_inner(self, inputs, **kwargs) -> str:
        assert isinstance(inputs, str) or isinstance(inputs, list)
        start = time.time()
        inputs = [inputs] if isinstance(inputs, str) else inputs

        messages = self.prepare_inputs(inputs)

        service_url = "https://bailingchat.alipay.com/api/proxy/eval/antgmm/completions"

        payload = {
            "structInput": messages,
            "sk": self.key,
            "timeout": 180000
        }
        response = requests.post(service_url, headers=self.headers, json=payload)
        if self.verbose:
            self.logger.info('Time for requesting is:')
            self.logger.info(time.time() - start)
        try:
            assert response.status_code == 200
            output = json.loads(response.text)
            answer = output['preds']['pred']
            if self.verbose:
                self.logger.info(f'inputs: {inputs}\nanswer: {answer}')
            return 0, answer, 'Succeeded! '
        except Exception as e:
            if self.verbose:
                self.logger.error(e)
                self.logger.error(f'The input messages are {inputs}.')
            return -1, self.fail_msg, ''


class bailingMMAPI(bailingMMWrapper):

    def generate(self, message, dataset=None):
        return super(bailingMMAPI, self).generate(message, dataset=dataset)
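
bailingMMWrapper posts the message list as `structInput`, with image entries replaced in place by their base64 encoding. A rough usage sketch (illustrative; assumes `BAILINGMM_API_KEY` is set, and the model name and image path are placeholders):

    from vlmeval.api import bailingMMAPI

    api = bailingMMAPI(model='bailingMM-mini', verbose=False)   # placeholder model name
    msg = [
        dict(type='image', value='/path/to/chart.png'),         # placeholder path
        dict(type='text', value='What is the title of this chart?'),
    ]
    print(api.generate(msg, dataset=None))
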
VLMEvalKit/vlmeval/api/base.py
ADDED
@@ -0,0 +1,289 @@
import time
import random as rd
from abc import abstractmethod
import os.path as osp
import copy as cp
from ..smp import get_logger, parse_file, concat_images_vlmeval, LMUDataRoot, md5, decode_base64_to_image_file


class BaseAPI:

    allowed_types = ['text', 'image']
    INTERLEAVE = True
    INSTALL_REQ = False

    def __init__(self,
                 retry=10,
                 wait=3,
                 system_prompt=None,
                 verbose=True,
                 fail_msg='Failed to obtain answer via API.',
                 **kwargs):
        """Base Class for all APIs.

        Args:
            retry (int, optional): The retry times for `generate_inner`. Defaults to 10.
            wait (int, optional): The wait time after each failed retry of `generate_inner`. Defaults to 3.
            system_prompt (str, optional): Defaults to None.
            verbose (bool, optional): Defaults to True.
            fail_msg (str, optional): The message to return when failed to obtain answer.
                Defaults to 'Failed to obtain answer via API.'.
            **kwargs: Other kwargs for `generate_inner`.
        """

        self.wait = wait
        self.retry = retry
        self.system_prompt = system_prompt
        self.verbose = verbose
        self.fail_msg = fail_msg
        self.logger = get_logger('ChatAPI')

        if len(kwargs):
            self.logger.info(f'BaseAPI received the following kwargs: {kwargs}')
            self.logger.info('Will try to use them as kwargs for `generate`. ')
        self.default_kwargs = kwargs

    @abstractmethod
    def generate_inner(self, inputs, **kwargs):
        """The inner function to generate the answer.

        Returns:
            tuple(int, str, str): ret_code, response, log
        """
        self.logger.warning('For APIBase, generate_inner is an abstract method. ')
        assert 0, 'generate_inner not defined'
        ret_code, answer, log = None, None, None
        # if ret_code is 0, means succeed
        return ret_code, answer, log

    def working(self):
        """If the API model is working, return True, else return False.

        Returns:
            bool: If the API model is working, return True, else return False.
        """
        self.old_timeout = None
        if hasattr(self, 'timeout'):
            self.old_timeout = self.timeout
            self.timeout = 120

        retry = 5
        while retry > 0:
            ret = self.generate('hello')
            if ret is not None and ret != '' and self.fail_msg not in ret:
                if self.old_timeout is not None:
                    self.timeout = self.old_timeout
                return True
            retry -= 1

        if self.old_timeout is not None:
            self.timeout = self.old_timeout
        return False

    def check_content(self, msgs):
        """Check the content type of the input. Four types are allowed: str, dict, liststr, listdict.

        Args:
            msgs: Raw input messages.

        Returns:
            str: The message type.
        """
        if isinstance(msgs, str):
            return 'str'
        if isinstance(msgs, dict):
            return 'dict'
        if isinstance(msgs, list):
            types = [self.check_content(m) for m in msgs]
            if all(t == 'str' for t in types):
                return 'liststr'
            if all(t == 'dict' for t in types):
                return 'listdict'
        return 'unknown'

    def preproc_content(self, inputs):
        """Convert the raw input messages to a list of dicts.

        Args:
            inputs: raw input messages.

        Returns:
            list(dict): The preprocessed input messages. Will return None if failed to preprocess the input.
        """
        if self.check_content(inputs) == 'str':
            return [dict(type='text', value=inputs)]
        elif self.check_content(inputs) == 'dict':
            assert 'type' in inputs and 'value' in inputs
            return [inputs]
        elif self.check_content(inputs) == 'liststr':
            res = []
            for s in inputs:
                mime, pth = parse_file(s)
                if mime is None or mime == 'unknown':
                    res.append(dict(type='text', value=s))
                else:
                    res.append(dict(type=mime.split('/')[0], value=pth))
            return res
        elif self.check_content(inputs) == 'listdict':
            for item in inputs:
                assert 'type' in item and 'value' in item
                mime, s = parse_file(item['value'])
                if mime is None:
                    assert item['type'] == 'text', item['value']
                else:
                    assert mime.split('/')[0] == item['type']
                    item['value'] = s
            return inputs
        else:
            return None

    # May exceed the context windows size, so try with different turn numbers.
    def chat_inner(self, inputs, **kwargs):
        _ = kwargs.pop('dataset', None)
        while len(inputs):
            try:
                return self.generate_inner(inputs, **kwargs)
            except Exception as e:
                if self.verbose:
                    self.logger.info(f'{type(e)}: {e}')
                inputs = inputs[1:]
                while len(inputs) and inputs[0]['role'] != 'user':
                    inputs = inputs[1:]
                continue
        return -1, self.fail_msg + ': ' + 'Failed with all possible conversation turns.', None

    def chat(self, messages, **kwargs1):
        """The main function for multi-turn chatting. Will call `chat_inner` with the preprocessed input messages."""
        assert hasattr(self, 'chat_inner'), 'The API model should has the `chat_inner` method. '
        for msg in messages:
            assert isinstance(msg, dict) and 'role' in msg and 'content' in msg, msg
            assert self.check_content(msg['content']) in ['str', 'dict', 'liststr', 'listdict'], msg
            msg['content'] = self.preproc_content(msg['content'])
        # merge kwargs
        kwargs = cp.deepcopy(self.default_kwargs)
        kwargs.update(kwargs1)

        answer = None
        # a very small random delay [0s - 0.5s]
        T = rd.random() * 0.5
        time.sleep(T)

        assert messages[-1]['role'] == 'user'

        for i in range(self.retry):
            try:
                ret_code, answer, log = self.chat_inner(messages, **kwargs)
                if ret_code == 0 and self.fail_msg not in answer and answer != '':
                    if self.verbose:
                        print(answer)
                    return answer
                elif self.verbose:
                    if not isinstance(log, str):
                        try:
                            log = log.text
                        except Exception as e:
                            self.logger.warning(f'Failed to parse {log} as an http response: {str(e)}. ')
                    self.logger.info(f'RetCode: {ret_code}\nAnswer: {answer}\nLog: {log}')
            except Exception as err:
                if self.verbose:
                    self.logger.error(f'An error occured during try {i}: ')
                    self.logger.error(f'{type(err)}: {err}')
            # delay before each retry
            T = rd.random() * self.wait * 2
            time.sleep(T)

        return self.fail_msg if answer in ['', None] else answer

    def preprocess_message_with_role(self, message):
        system_prompt = ''
        new_message = []

        for data in message:
            assert isinstance(data, dict)
            role = data.pop('role', 'user')
            if role == 'system':
                system_prompt += data['value'] + '\n'
            else:
                new_message.append(data)

        if system_prompt != '':
            if self.system_prompt is None:
                self.system_prompt = system_prompt
            else:
                self.system_prompt += '\n' + system_prompt
        return new_message

    def generate(self, message, **kwargs1):
        """The main function to generate the answer. Will call `generate_inner` with the preprocessed input messages.

        Args:
            message: raw input messages.

        Returns:
            str: The generated answer, or the fail message if no answer could be obtained.
        """
        if self.check_content(message) == 'listdict':
            message = self.preprocess_message_with_role(message)

        assert self.check_content(message) in ['str', 'dict', 'liststr', 'listdict'], f'Invalid input type: {message}'
        message = self.preproc_content(message)
        assert message is not None and self.check_content(message) == 'listdict'
        for item in message:
            assert item['type'] in self.allowed_types, f'Invalid input type: {item["type"]}'

        # merge kwargs
        kwargs = cp.deepcopy(self.default_kwargs)
        kwargs.update(kwargs1)

        answer = None
        # a very small random delay [0s - 0.5s]
        T = rd.random() * 0.5
        time.sleep(T)

        for i in range(self.retry):
            try:
                ret_code, answer, log = self.generate_inner(message, **kwargs)
                if ret_code == 0 and self.fail_msg not in answer and answer != '':
                    if self.verbose:
                        print(answer)
                    return answer
                elif self.verbose:
                    if not isinstance(log, str):
                        try:
                            log = log.text
                        except Exception as e:
                            self.logger.warning(f'Failed to parse {log} as an http response: {str(e)}. ')
                    self.logger.info(f'RetCode: {ret_code}\nAnswer: {answer}\nLog: {log}')
            except Exception as err:
                if self.verbose:
                    self.logger.error(f'An error occured during try {i}: ')
                    self.logger.error(f'{type(err)}: {err}')
            # delay before each retry
            T = rd.random() * self.wait * 2
            time.sleep(T)

        return self.fail_msg if answer in ['', None] else answer

    def message_to_promptimg(self, message, dataset=None):
        assert not self.INTERLEAVE
        model_name = self.__class__.__name__
        import warnings
        warnings.warn(
            f'Model {model_name} does not support interleaved input. '
            'Will use the first image and aggregated texts as prompt. ')
        num_images = len([x for x in message if x['type'] == 'image'])
        if num_images == 0:
            prompt = '\n'.join([x['value'] for x in message if x['type'] == 'text'])
            image = None
        elif num_images == 1:
            prompt = '\n'.join([x['value'] for x in message if x['type'] == 'text'])
            image = [x['value'] for x in message if x['type'] == 'image'][0]
        else:
            prompt = '\n'.join([x['value'] if x['type'] == 'text' else '<image>' for x in message])
            if dataset == 'BLINK':
                image = concat_images_vlmeval(
                    [x['value'] for x in message if x['type'] == 'image'],
                    target_size=512)
            else:
                image = [x['value'] for x in message if x['type'] == 'image'][0]
        return prompt, image
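
`BaseAPI.generate` accepts a plain string, a single `{'type', 'value'}` dict, a list of strings (classified as text or file paths by `parse_file`), or a list of such dicts; it normalizes the input with `preproc_content`, then calls `generate_inner` up to `retry` times with a randomized delay of at most `2 * wait` seconds between attempts. A toy subclass sketch showing the `(ret_code, answer, log)` contract that `generate_inner` must honour (`EchoAPI` is a made-up name, purely illustrative):

    from vlmeval.api.base import BaseAPI

    class EchoAPI(BaseAPI):
        """Toy subclass: echoes the text parts of the message back."""

        def generate_inner(self, inputs, **kwargs):
            # By the time generate_inner is called, `inputs` is a list of
            # {'type', 'value'} dicts (see preproc_content above).
            text = '\n'.join(x['value'] for x in inputs if x['type'] == 'text')
            # ret_code 0 means success; the third element is a free-form log.
            return 0, f'ECHO: {text}', 'Succeeded! '

    api = EchoAPI(retry=3, wait=1, verbose=False)
    print(api.generate('hello'))   # plain-string input is accepted
    print(api.working())           # True: the probe call returns a non-fail answer
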
VLMEvalKit/vlmeval/api/bluelm_v_api.py
ADDED
@@ -0,0 +1,120 @@
from vlmeval.smp import *
from vlmeval.api.base import BaseAPI
import os
import json


def multimodal(images, text, url, key, temperature=0, max_tokens=1024, history=[]):
    if images:
        pics = []
        for image in images:
            with open(image, 'rb') as f:
                pic = base64.b64encode(f.read()).decode('utf-8')
            pics.append(pic)
        data = {'images': pics, 'text': text, 'key': key, 'temperature': temperature, 'max_new_tokens': max_tokens}
    else:
        data = {'text': text, 'key': key, 'temperature': temperature, 'max_new_tokens': max_tokens}
    response = requests.post(url, json=data, headers={'Content-Type': 'application/json'})
    response = json.loads(response.text)
    return response


class BlueLMWrapper(BaseAPI):
    is_api: bool = True

    def __init__(self,
                 model: str = 'BlueLM-V-v3.0',
                 retry: int = 5,
                 wait: int = 5,
                 verbose: bool = True,
                 temperature: float = 0.0,
                 system_prompt: str = None,
                 max_tokens: int = 1024,
                 key: str = None,
                 url: str = 'http://api-ai.vivo.com.cn/multimodal',
                 **kwargs):

        self.model = model
        self.fail_msg = 'Failed to obtain answer BlueLM-V API. '
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.url = url
        self.key = key

        if self.key is None:
            self.key = os.environ.get('BLUELM_V_API_KEY', None)
        assert self.key is not None, (
            'Please set the API Key (obtain it here: '
            'contact by email : [email protected]'
        )

        super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)

    def message_to_promptimg(self, message, dataset=None):

        num_images = len([x for x in message if x['type'] == 'image'])
        if num_images == 0:
            prompt = '\n'.join([x['value'] for x in message if x['type'] == 'text'])
            image = None
        elif num_images == 1:
            prompt = '\n'.join([x['value'] for x in message if x['type'] == 'text'])
            image = [x['value'] for x in message if x['type'] == 'image']
        else:
            prompt = '\n'.join([x['value'] if x['type'] == 'text' else '<image>' for x in message])
            if dataset == 'BLINK':
                image = concat_images_vlmeval(
                    [x['value'] for x in message if x['type'] == 'image'],
                    target_size=512)
            else:
                image = [x['value'] for x in message if x['type'] == 'image']

        if dataset in ['MMBench_DEV_EN_V11', 'MMBench_DEV_CN_V11', 'MMBench_TEST_EN_V11', 'MMBench_TEST_CN_V11',
                       'AI2D_TEST', 'AI2D_TEST_TO_MASK', 'MMMU_DEV_VAL']:
            prompt = prompt.replace('Please select the correct answer from the options above.',
                                    'Answer with the option’s letter from the given choices directly.')
        elif dataset in ['ChartQA_TEST']:
            prompt = prompt.replace('Answer the question using a single word or phrase.',
                                    'Answer the question using a single number or phrase.')
        elif dataset in ['DocVQA_VAL', 'DocVQA_TEST', ]:
            prompt = prompt.replace('Answer the question using a single word or phrase.',
                                    'Give the short answer directly.')
        elif dataset in ['TextVQA_VAL']:
            prompt = prompt.replace('Answer the question using a single word or phrase.',
                                    'When the provided information is insufficient, respond with ’Unanswerable’.'
                                    'Answer the question using a single word or phrase.')
        elif dataset in ['MTVQA_TEST']:
            prompt = prompt.replace('\nAnswer the question using a word or phrase in the language of the question.', '')
        elif dataset in ['MathVista_MINI']:
            if 'Choices:' in prompt:
                prompt = prompt.replace('Choices:', 'Options:').replace('Hint:', 'Context:')
                for i in range(1, 7):  # replace A ~ F
                    prompt = prompt.replace(f'({chr(64 + i)})', f'{chr(64 + i)}.')
                prompt += '\nAnswer with the option’s letter from the given choices directly.'
            else:
                prompt += '\nAnswer the question using a single word or phrase.'

        return prompt, image

    def generate_inner(self, inputs, **kwargs) -> str:

        assert isinstance(inputs, str) or isinstance(inputs, list)
        pure_text = np.all([x['type'] == 'text' for x in inputs])
        assert not pure_text

        prompt, image_path = self.message_to_promptimg(inputs, kwargs['dataset'])

        try:
            response = multimodal(image_path, prompt, self.url, self.key, self.temperature, self.max_tokens)
            answer = response['result']
            return 0, answer, 'Succeeded! '
        except Exception as err:
            if self.verbose:
                self.logger.error(f'{type(err)}: {err}')
                self.logger.error(f'The input messages are {inputs}.')
            return -1, '', ''


class BlueLM_V_API(BlueLMWrapper):

    def generate(self, message, dataset=None):
        return super(BlueLM_V_API, self).generate(message, dataset=dataset)
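
Two things are easy to miss here: `generate_inner` reads `kwargs['dataset']` unconditionally, and the dataset name drives the prompt rewrites above, so calls should go through `BlueLM_V_API.generate(message, dataset=...)`. A rough sketch (illustrative; assumes `BLUELM_V_API_KEY` is set and the image path is a placeholder):

    from vlmeval.api import BlueLM_V_API

    api = BlueLM_V_API()
    msg = [
        dict(type='image', value='/path/to/doc_page.png'),   # placeholder path
        dict(type='text', value='What is the invoice number? '
                                'Answer the question using a single word or phrase.'),
    ]
    # dataset='DocVQA_VAL' rewrites the instruction to 'Give the short answer directly.'
    print(api.generate(msg, dataset='DocVQA_VAL'))
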
VLMEvalKit/vlmeval/api/claude.py
ADDED
@@ -0,0 +1,130 @@
from vlmeval.smp import *
from vlmeval.api.base import BaseAPI
from time import sleep
import base64
import mimetypes
from PIL import Image

alles_url = 'https://openxlab.org.cn/gw/alles-apin-hub/v1/claude/v1/text/chat'
alles_headers = {
    'alles-apin-token': '',
    'Content-Type': 'application/json'
}
official_url = 'https://api.anthropic.com/v1/messages'
official_headers = {
    'x-api-key': '',
    'anthropic-version': '2023-06-01',
    'content-type': 'application/json'
}


class Claude_Wrapper(BaseAPI):

    is_api: bool = True

    def __init__(self,
                 backend: str = 'alles',
                 model: str = 'claude-3-opus-20240229',
                 key: str = None,
                 retry: int = 10,
                 wait: int = 3,
                 system_prompt: str = None,
                 verbose: bool = True,
                 temperature: float = 0,
                 max_tokens: int = 1024,
                 **kwargs):

        if os.environ.get('ANTHROPIC_BACKEND', '') == 'official':
            backend = 'official'

        assert backend in ['alles', 'official'], f'Invalid backend: {backend}'
        self.backend = backend
        self.url = alles_url if backend == 'alles' else official_url
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.headers = alles_headers if backend == 'alles' else official_headers

        if key is not None:
            self.key = key
        else:
            self.key = os.environ.get('ALLES', '') if self.backend == 'alles' else os.environ.get('ANTHROPIC_API_KEY', '')  # noqa: E501

        if self.backend == 'alles':
            self.headers['alles-apin-token'] = self.key
        else:
            self.headers['x-api-key'] = self.key

        super().__init__(retry=retry, wait=wait, verbose=verbose, system_prompt=system_prompt, **kwargs)

    # inputs can be a lvl-2 nested list: [content1, content2, content3, ...]
    # content can be a string or a list of image & text
    def prepare_itlist(self, inputs):
        assert np.all([isinstance(x, dict) for x in inputs])
        has_images = np.sum([x['type'] == 'image' for x in inputs])
        if has_images:
            content_list = []
            for msg in inputs:
                if msg['type'] == 'text' and msg['value'] != '':
                    content_list.append(dict(type='text', text=msg['value']))
                elif msg['type'] == 'image':
                    pth = msg['value']
                    suffix = osp.splitext(pth)[-1].lower()
                    media_type = mimetypes.types_map.get(suffix, None)
                    assert media_type is not None

                    content_list.append(dict(
                        type='image',
                        source={
                            'type': 'base64',
                            'media_type': media_type,
                            'data': encode_image_file_to_base64(pth, target_size=4096)
                        }))
        else:
            assert all([x['type'] == 'text' for x in inputs])
            text = '\n'.join([x['value'] for x in inputs])
            content_list = [dict(type='text', text=text)]
        return content_list

    def prepare_inputs(self, inputs):
        input_msgs = []
        assert isinstance(inputs, list) and isinstance(inputs[0], dict)
        assert np.all(['type' in x for x in inputs]) or np.all(['role' in x for x in inputs]), inputs
        if 'role' in inputs[0]:
            assert inputs[-1]['role'] == 'user', inputs[-1]
            for item in inputs:
                input_msgs.append(dict(role=item['role'], content=self.prepare_itlist(item['content'])))
        else:
            input_msgs.append(dict(role='user', content=self.prepare_itlist(inputs)))
        return input_msgs

    def generate_inner(self, inputs, **kwargs) -> str:
        payload = {
            'model': self.model,
            'max_tokens': self.max_tokens,
            'messages': self.prepare_inputs(inputs),
            **kwargs
        }
        if self.system_prompt is not None:
            payload['system'] = self.system_prompt

        response = requests.request('POST', self.url, headers=self.headers, data=json.dumps(payload))
        ret_code = response.status_code
        ret_code = 0 if (200 <= int(ret_code) < 300) else ret_code
        answer = self.fail_msg

        try:
            resp_struct = json.loads(response.text)
            answer = resp_struct['data']['content'][0]['text'].strip()
        except Exception as err:
            if self.verbose:
                self.logger.error(f'{type(err)}: {err}')
                self.logger.error(response.text if hasattr(response, 'text') else response)

        return ret_code, answer, response


class Claude3V(Claude_Wrapper):

    def generate(self, message, dataset=None):
        return super(Claude_Wrapper, self).generate(message)
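
The wrapper talks either to the `alles` proxy (the default, with the key taken from the `ALLES` env var) or to the official Anthropic endpoint (key from `ANTHROPIC_API_KEY`), switched by the `backend` argument or the `ANTHROPIC_BACKEND` env var. A rough usage sketch (illustrative; assumes the relevant key is set and the image path is a placeholder):

    import os
    from vlmeval.api import Claude3V

    os.environ['ANTHROPIC_BACKEND'] = 'official'   # omit to use the 'alles' proxy
    api = Claude3V(model='claude-3-opus-20240229', max_tokens=512)
    msg = [
        dict(type='image', value='/path/to/photo.jpg'),   # placeholder path
        dict(type='text', value='How many people are in this photo?'),
    ]
    print(api.generate(msg))
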
VLMEvalKit/vlmeval/api/cloudwalk.py
ADDED
@@ -0,0 +1,107 @@
from ..smp import *
import os
from .base import BaseAPI


class CWWrapper(BaseAPI):

    is_api: bool = True

    def __init__(self,
                 model: str = 'cw-congrong-v1.5',
                 retry: int = 10,
                 wait: int = 5,
                 key: str = None,
                 verbose: bool = True,
                 system_prompt: str = None,
                 temperature: float = 0,
                 timeout: int = 600,
                 api_base: str = 'http://cwapi-vlm01.cw_rb.azurebot.tk/v1/chat/completions',
                 max_tokens: int = 1024,
                 img_size: int = 512,
                 img_detail: str = 'low',
                 **kwargs):

        self.model = model
        self.cur_idx = 0
        self.fail_msg = 'Failed to obtain answer via API. '
        self.max_tokens = max_tokens
        self.temperature = temperature

        base = os.environ.get('CW_API_BASE', None)
        self.api_base = base if base is not None else api_base

        env_key = os.environ.get('CW_API_KEY', None)
        self.key = env_key if env_key is not None else key
        assert self.key is not None, 'API key not provided. Please set CW_API_KEY environment variable or \
            pass it to the constructor.'

        assert img_size > 0 or img_size == -1
        self.img_size = -1  # always send the full-size image
        assert img_detail in ['high', 'low']
        self.img_detail = img_detail

        self.vision = True
        self.timeout = timeout

        super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)

    # inputs can be a lvl-2 nested list: [content1, content2, content3, ...]
    # content can be a string or a list of image & text
    def prepare_inputs(self, inputs):
        input_msgs = []
        if self.system_prompt is not None:
            input_msgs.append(dict(role='system', content=self.system_prompt))
        has_images = np.sum([x['type'] == 'image' for x in inputs])
        if has_images:
            content_list = []
            for msg in inputs:
                if msg['type'] == 'text':
                    content_list.append(dict(type='text', text=msg['value']))
                elif msg['type'] == 'image':
                    from PIL import Image
                    img = Image.open(msg['value'])
                    b64 = encode_image_to_base64(img, target_size=self.img_size)
                    img_struct = dict(url=f"data:image/jpeg;base64,{b64}", detail=self.img_detail)
                    content_list.append(dict(type='image_url', image_url=img_struct))
            input_msgs.append(dict(role='user', content=content_list))
        else:
            assert all([x['type'] == 'text' for x in inputs])
            text = '\n'.join([x['value'] for x in inputs])
            input_msgs.append(dict(role='user', content=text))
        return input_msgs

    def generate_inner(self, inputs, **kwargs) -> str:
        input_msgs = self.prepare_inputs(inputs)
        temperature = kwargs.pop('temperature', self.temperature)
        max_tokens = kwargs.pop('max_tokens', self.max_tokens)

        if 0 < max_tokens <= 100:
            self.logger.warning(
                'Less than 100 tokens left, '
                'may exceed the context window with some additional meta symbols. '
            )
        if max_tokens <= 0:
            return 0, self.fail_msg + 'Input string longer than context window. ', 'Length Exceeded. '

        headers = {'Content-Type': 'application/json', 'Authorization': f'{self.key}'}
        payload = dict(
            model=self.model,
            messages=input_msgs,
            max_tokens=max_tokens,
            n=1,
            temperature=temperature,
            **kwargs)
        response = requests.post(self.api_base, headers=headers, data=json.dumps(payload), timeout=self.timeout * 1.1)
        ret_code = response.status_code
        ret_code = 0 if (200 <= int(ret_code) < 300) else ret_code
        answer = self.fail_msg
        try:
            resp_struct = json.loads(response.text)
            answer = resp_struct['choices'][0]['message']['content'].strip()
        except Exception as err:
            if self.verbose:
                self.logger.error(f'{type(err)}: {err}')
                self.logger.error(response.text if hasattr(response, 'text') else response)

        return ret_code, answer, response
VLMEvalKit/vlmeval/api/gemini.py
ADDED
@@ -0,0 +1,116 @@
from vlmeval.smp import *
from vlmeval.api.base import BaseAPI

headers = 'Content-Type: application/json'


class GeminiWrapper(BaseAPI):

    is_api: bool = True

    def __init__(self,
                 model: str = 'gemini-1.0-pro',
                 retry: int = 5,
                 wait: int = 5,
                 key: str = None,
                 verbose: bool = True,
                 temperature: float = 0.0,
                 system_prompt: str = None,
                 max_tokens: int = 1024,
                 proxy: str = None,
                 backend='genai',
                 project_id='vlmeval',
                 **kwargs):

        self.model = model
        self.fail_msg = 'Failed to obtain answer via API. '
        self.max_tokens = max_tokens
        self.temperature = temperature
        if key is None:
            key = os.environ.get('GOOGLE_API_KEY', None)
        # Try to load backend from environment variable
        be = os.environ.get('GOOGLE_API_BACKEND', None)
        if be is not None and be in ['genai', 'vertex']:
            backend = be

        assert backend in ['genai', 'vertex']
        if backend == 'genai':
            # We have not evaluated Gemini-1.5 w. GenAI backend
            assert key is not None  # Vertex does not require API Key

        self.backend = backend
        self.project_id = project_id
        self.api_key = key

        if proxy is not None:
            proxy_set(proxy)
        super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)

    def build_msgs_genai(self, inputs):
        messages = [] if self.system_prompt is None else [self.system_prompt]
        for inp in inputs:
            if inp['type'] == 'text':
                messages.append(inp['value'])
            elif inp['type'] == 'image':
                messages.append(Image.open(inp['value']))
        return messages

    def build_msgs_vertex(self, inputs):
        from vertexai.generative_models import Part, Image
        messages = [] if self.system_prompt is None else [self.system_prompt]
        for inp in inputs:
            if inp['type'] == 'text':
                messages.append(inp['value'])
            elif inp['type'] == 'image':
                messages.append(Part.from_image(Image.load_from_file(inp['value'])))
        return messages

    def generate_inner(self, inputs, **kwargs) -> str:
        if self.backend == 'genai':
            import google.generativeai as genai
            assert isinstance(inputs, list)
            pure_text = np.all([x['type'] == 'text' for x in inputs])
            genai.configure(api_key=self.api_key)

            if pure_text and self.model == 'gemini-1.0-pro':
                model = genai.GenerativeModel('gemini-1.0-pro')
            else:
                model = genai.GenerativeModel(self.model)

            messages = self.build_msgs_genai(inputs)
            gen_config = dict(max_output_tokens=self.max_tokens, temperature=self.temperature)
            gen_config.update(kwargs)
            try:
                answer = model.generate_content(
                    messages,
                    generation_config=genai.types.GenerationConfig(**gen_config)).text
                return 0, answer, 'Succeeded! '
            except Exception as err:
                if self.verbose:
                    self.logger.error(f'{type(err)}: {err}')
                    self.logger.error(f'The input messages are {inputs}.')

                return -1, '', ''
        elif self.backend == 'vertex':
            import vertexai
            from vertexai.generative_models import GenerativeModel
            vertexai.init(project=self.project_id, location='us-central1')
            model_name = 'gemini-1.0-pro-vision' if self.model == 'gemini-1.0-pro' else self.model
            model = GenerativeModel(model_name=model_name)
            messages = self.build_msgs_vertex(inputs)
            try:
                resp = model.generate_content(messages)
                answer = resp.text
                return 0, answer, 'Succeeded! '
            except Exception as err:
                if self.verbose:
                    self.logger.error(f'{type(err)}: {err}')
                    self.logger.error(f'The input messages are {inputs}.')

                return -1, '', ''


class GeminiProVision(GeminiWrapper):

    def generate(self, message, dataset=None):
        return super(GeminiProVision, self).generate(message)
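
Backend selection: `genai` (the `google.generativeai` SDK, which requires `GOOGLE_API_KEY`) or `vertex` (Vertex AI, which uses `project_id` instead of a key); the `GOOGLE_API_BACKEND` env var overrides the constructor argument. A rough sketch for the `genai` backend (illustrative; the model id and image path are placeholders):

    from vlmeval.api import GeminiProVision

    api = GeminiProVision(model='gemini-1.5-flash', backend='genai')   # placeholder model id
    msg = [
        dict(type='image', value='/path/to/diagram.png'),              # placeholder path
        dict(type='text', value='Summarise this diagram in two sentences.'),
    ]
    print(api.generate(msg))
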
VLMEvalKit/vlmeval/api/glm_vision.py
ADDED
@@ -0,0 +1,95 @@
import requests
requests.packages.urllib3.disable_warnings()

from vlmeval.smp import *
from vlmeval.api.base import BaseAPI
from vlmeval.dataset import DATASET_TYPE
from vlmeval.smp.vlm import encode_image_file_to_base64


class GLMVisionWrapper(BaseAPI):

    is_api: bool = True

    def __init__(self,
                 model: str,
                 retry: int = 5,
                 wait: int = 5,
                 key: str = None,
                 verbose: bool = True,
                 system_prompt: str = None,
                 max_tokens: int = 4096,
                 proxy: str = None,
                 **kwargs):

        self.model = model
        self.fail_msg = 'Failed to obtain answer via API. '
        self.default_params = {
            'top_k': 1,
            'best_of': 1,
            'do_sample': False,
            'stream': False,
            'max_tokens': max_tokens,
            "skip_moderation": True
        }
        if key is None:
            key = os.environ.get('GLMV_API_KEY', None)
        assert key is not None, (
            'Please set the API Key (obtain it here: '
            'https://open.bigmodel.cn/dev/howuse/introduction)'
        )
        self.key = key
        super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)

    def build_msgs(self, msgs_raw, system_prompt=None, dataset=None):
        msgs = cp.deepcopy(msgs_raw)
        content = []
        for i, msg in enumerate(msgs):
            if msg['type'] == 'text':
                content.append(dict(type='text', text=msg['value']))
            elif msg['type'] == 'image':
                content.append(dict(type='image_url', image_url=dict(url=encode_image_file_to_base64(msg['value']))))
        if dataset in {'HallusionBench', 'POPE'}:
            content.append(dict(type="text", text="Please answer yes or no."))
        ret = [dict(role='user', content=content)]
        return ret

    def generate_inner(self, inputs, **kwargs) -> str:
        assert isinstance(inputs, str) or isinstance(inputs, list)
        inputs = [inputs] if isinstance(inputs, str) else inputs

        messages = self.build_msgs(msgs_raw=inputs, dataset=kwargs.get('dataset', None))

        url = 'https://api.chatglm.cn/v1/chat/completions'
        headers = {
            'Content-Type': 'application/json',
            'Request-Id': 'remote-test',
            'Authorization': f'Bearer {self.key}'
        }
        payload = {
            'model': self.model,
            'messages': messages,
            **self.default_params
        }
        response = requests.post(url, headers=headers, data=json.dumps(payload), verify=False)
        output = []
        try:
            assert response.status_code == 200
            for line in response.iter_lines():
                data = json.loads(line.decode('utf-8').lstrip('data: '))
                output.append(data['choices'][0]['message']['content'])
            answer = ''.join(output).replace('</s>', '')
            if self.verbose:
                self.logger.info(f'inputs: {inputs}\nanswer: {answer}')
            return 0, answer, 'Succeeded! '
        except Exception as err:
            if self.verbose:
                self.logger.error(f'{type(err)}: {err}')
                self.logger.error(f'The input messages are {inputs}.')
            return -1, self.fail_msg, ''


class GLMVisionAPI(GLMVisionWrapper):

    def generate(self, message, dataset=None):
        return super(GLMVisionAPI, self).generate(message, dataset=dataset)
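
Note that `build_msgs` appends 'Please answer yes or no.' when the dataset is HallusionBench or POPE, so the dataset name should be forwarded through `generate`. A rough sketch (illustrative; assumes `GLMV_API_KEY` is set, and the model name and image path are placeholders):

    from vlmeval.api import GLMVisionAPI

    api = GLMVisionAPI(model='glm-4v')                   # placeholder model name
    msg = [
        dict(type='image', value='/path/to/scene.jpg'),  # placeholder path
        dict(type='text', value='Is there a dog in the image?'),
    ]
    print(api.generate(msg, dataset='POPE'))             # appends the yes/no instruction
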
VLMEvalKit/vlmeval/api/gpt.py
ADDED
@@ -0,0 +1,267 @@
from ..smp import *
import os
import sys
from .base import BaseAPI

APIBASES = {
    'OFFICIAL': 'https://api.openai.com/v1/chat/completions',
}


def GPT_context_window(model):
    length_map = {
        'gpt-4': 8192,
        'gpt-4-0613': 8192,
        'gpt-4-turbo-preview': 128000,
        'gpt-4-1106-preview': 128000,
        'gpt-4-0125-preview': 128000,
        'gpt-4-vision-preview': 128000,
        'gpt-4-turbo': 128000,
        'gpt-4-turbo-2024-04-09': 128000,
        'gpt-3.5-turbo': 16385,
        'gpt-3.5-turbo-0125': 16385,
        'gpt-3.5-turbo-1106': 16385,
        'gpt-3.5-turbo-instruct': 4096,
    }
    if model in length_map:
        return length_map[model]
    else:
        return 128000


class OpenAIWrapper(BaseAPI):

    is_api: bool = True

    def __init__(self,
                 model: str = 'gpt-3.5-turbo-0613',
                 retry: int = 5,
                 wait: int = 5,
                 key: str = None,
                 verbose: bool = False,
                 system_prompt: str = None,
                 temperature: float = 0,
                 timeout: int = 60,
                 api_base: str = None,
                 max_tokens: int = 1024,
                 img_size: int = 512,
                 img_detail: str = 'low',
                 use_azure: bool = False,
                 **kwargs):

        self.model = model
        self.cur_idx = 0
        self.fail_msg = 'Failed to obtain answer via API. '
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.use_azure = use_azure

        if 'step' in model:
            env_key = os.environ.get('STEPAI_API_KEY', '')
            if key is None:
                key = env_key
        elif 'yi-vision' in model:
            env_key = os.environ.get('YI_API_KEY', '')
            if key is None:
                key = env_key
        elif 'internvl2-pro' in model:
            env_key = os.environ.get('InternVL2_PRO_KEY', '')
            if key is None:
                key = env_key
        elif 'abab' in model:
            env_key = os.environ.get('MiniMax_API_KEY', '')
            if key is None:
                key = env_key
        else:
            if use_azure:
                env_key = os.environ.get('AZURE_OPENAI_API_KEY', None)
                assert env_key is not None, 'Please set the environment variable AZURE_OPENAI_API_KEY. '

                if key is None:
                    key = env_key
                assert isinstance(key, str), (
                    'Please set the environment variable AZURE_OPENAI_API_KEY to your openai key. '
                )
            else:
                env_key = os.environ.get('OPENAI_API_KEY', '')
                if key is None:
                    key = env_key
                assert isinstance(key, str) and key.startswith('sk-'), (
                    f'Illegal openai_key {key}. '
                    'Please set the environment variable OPENAI_API_KEY to your openai key. '
                )

        self.key = key
        assert img_size > 0 or img_size == -1
        self.img_size = img_size
        assert img_detail in ['high', 'low']
        self.img_detail = img_detail
        self.timeout = timeout

        super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)

        if use_azure:
            api_base_template = (
                '{endpoint}openai/deployments/{deployment_name}/chat/completions?api-version={api_version}'
            )
            endpoint = os.getenv('AZURE_OPENAI_ENDPOINT', None)
            assert endpoint is not None, 'Please set the environment variable AZURE_OPENAI_ENDPOINT. '
            deployment_name = os.getenv('AZURE_OPENAI_DEPLOYMENT_NAME', None)
            assert deployment_name is not None, 'Please set the environment variable AZURE_OPENAI_DEPLOYMENT_NAME. '
            api_version = os.getenv('OPENAI_API_VERSION', None)
            assert api_version is not None, 'Please set the environment variable OPENAI_API_VERSION. '

            self.api_base = api_base_template.format(
                endpoint=os.getenv('AZURE_OPENAI_ENDPOINT'),
                deployment_name=os.getenv('AZURE_OPENAI_DEPLOYMENT_NAME'),
                api_version=os.getenv('OPENAI_API_VERSION')
            )
        else:
            if api_base is None:
                if 'OPENAI_API_BASE' in os.environ and os.environ['OPENAI_API_BASE'] != '':
                    self.logger.info('Environment variable OPENAI_API_BASE is set. Will use it as api_base. ')
                    api_base = os.environ['OPENAI_API_BASE']
                else:
                    api_base = 'OFFICIAL'

            assert api_base is not None

            if api_base in APIBASES:
                self.api_base = APIBASES[api_base]
            elif api_base.startswith('http'):
                self.api_base = api_base
            else:
                self.logger.error('Unknown API Base. ')
                raise NotImplementedError

        self.logger.info(f'Using API Base: {self.api_base}; API Key: {self.key}')

    # inputs can be a lvl-2 nested list: [content1, content2, content3, ...]
    # content can be a string or a list of image & text
    def prepare_itlist(self, inputs):
        assert np.all([isinstance(x, dict) for x in inputs])
        has_images = np.sum([x['type'] == 'image' for x in inputs])
        if has_images:
            content_list = []
            for msg in inputs:
                if msg['type'] == 'text':
                    content_list.append(dict(type='text', text=msg['value']))
                elif msg['type'] == 'image':
                    from PIL import Image
                    img = Image.open(msg['value'])
                    b64 = encode_image_to_base64(img, target_size=self.img_size)
                    img_struct = dict(url=f'data:image/jpeg;base64,{b64}', detail=self.img_detail)
                    content_list.append(dict(type='image_url', image_url=img_struct))
        else:
            assert all([x['type'] == 'text' for x in inputs])
            text = '\n'.join([x['value'] for x in inputs])
            content_list = [dict(type='text', text=text)]
        return content_list

    def prepare_inputs(self, inputs):
        input_msgs = []
        if self.system_prompt is not None:
            input_msgs.append(dict(role='system', content=self.system_prompt))
        assert isinstance(inputs, list) and isinstance(inputs[0], dict)
        assert np.all(['type' in x for x in inputs]) or np.all(['role' in x for x in inputs]), inputs
        if 'role' in inputs[0]:
            assert inputs[-1]['role'] == 'user', inputs[-1]
            for item in inputs:
                input_msgs.append(dict(role=item['role'], content=self.prepare_itlist(item['content'])))
        else:
            input_msgs.append(dict(role='user', content=self.prepare_itlist(inputs)))
        return input_msgs

    def generate_inner(self, inputs, **kwargs) -> str:
        input_msgs = self.prepare_inputs(inputs)
        temperature = kwargs.pop('temperature', self.temperature)
        max_tokens = kwargs.pop('max_tokens', self.max_tokens)

        context_window = GPT_context_window(self.model)
        new_max_tokens = min(max_tokens, context_window - self.get_token_len(inputs))
        if 0 < new_max_tokens <= 100 and new_max_tokens < max_tokens:
            self.logger.warning(
                'Less than 100 tokens left, '
                'may exceed the context window with some additional meta symbols. '
            )
        if new_max_tokens <= 0:
            return 0, self.fail_msg + 'Input string longer than context window. ', 'Length Exceeded. '
        max_tokens = new_max_tokens

        # Will send request if use Azure, dk how to use openai client for it
        if self.use_azure:
            headers = {'Content-Type': 'application/json', 'api-key': self.key}
        elif 'internvl2-pro' in self.model:
            headers = {'Content-Type': 'application/json', 'Authorization': self.key}
        else:
            headers = {'Content-Type': 'application/json', 'Authorization': f'Bearer {self.key}'}
        payload = dict(
            model=self.model,
            messages=input_msgs,
            max_tokens=max_tokens,
            n=1,
            temperature=temperature,
            **kwargs)
        response = requests.post(
            self.api_base,
            headers=headers, data=json.dumps(payload), timeout=self.timeout * 1.1)
        ret_code = response.status_code
        ret_code = 0 if (200 <= int(ret_code) < 300) else ret_code
        answer = self.fail_msg
        try:
            resp_struct = json.loads(response.text)
            answer = resp_struct['choices'][0]['message']['content'].strip()
        except Exception as err:
            if self.verbose:
                self.logger.error(f'{type(err)}: {err}')
                self.logger.error(response.text if hasattr(response, 'text') else response)

        return ret_code, answer, response

    def get_image_token_len(self, img_path, detail='low'):
        import math
        if detail == 'low':
            return 85

        im = Image.open(img_path)
        height, width = im.size
        if width > 1024 or height > 1024:
            if width > height:
                height = int(height * 1024 / width)
                width = 1024
            else:
                width = int(width * 1024 / height)
                height = 1024

        h = math.ceil(height / 512)
        w = math.ceil(width / 512)
        total = 85 + 170 * h * w
        return total

    def get_token_len(self, inputs) -> int:
        import tiktoken
        try:
            enc = tiktoken.encoding_for_model(self.model)
        except Exception as err:
            if 'gpt' in self.model.lower():
                if self.verbose:
                    self.logger.warning(f'{type(err)}: {err}')
                enc = tiktoken.encoding_for_model('gpt-4')
            else:
                return 0
        assert isinstance(inputs, list)
        tot = 0
        for item in inputs:
            if 'role' in item:
                tot += self.get_token_len(item['content'])
            elif item['type'] == 'text':
                tot += len(enc.encode(item['value']))
            elif item['type'] == 'image':
                tot += self.get_image_token_len(item['value'], detail=self.img_detail)
        return tot


class GPT4V(OpenAIWrapper):

    def generate(self, message, dataset=None):
        return super(GPT4V, self).generate(message)
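
Besides single-shot `generate`, the wrapper inherits `BaseAPI.chat`, which accepts role-tagged turns and is what `prepare_inputs` handles in its `'role'` branch. A rough multi-turn sketch (illustrative; assumes `OPENAI_API_KEY` is set):

    from vlmeval.api import OpenAIWrapper

    api = OpenAIWrapper(model='gpt-4-turbo', img_detail='low', temperature=0)

    # Each turn carries a role plus a content list of {'type', 'value'} dicts.
    messages = [
        dict(role='user', content=[dict(type='text', value='Name one primary colour.')]),
        dict(role='assistant', content=[dict(type='text', value='Red.')]),
        dict(role='user', content=[dict(type='text', value='Name another one.')]),
    ]
    print(api.chat(messages))
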
VLMEvalKit/vlmeval/api/hf_chat_model.py
ADDED
@@ -0,0 +1,246 @@
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import os.path as osp
|
4 |
+
import torch
|
5 |
+
from ..smp import *
|
6 |
+
|
7 |
+
|
8 |
+
def get_gpu_num(model_name):
|
9 |
+
model_name = model_name.lower()
|
10 |
+
kws = {
|
11 |
+
8: ['65b', '70b'],
|
12 |
+
4: ['30b', '33b', '35b', '40b'],
|
13 |
+
2: ['13b', '14b', '20b'],
|
14 |
+
1: ['6b', '7b', 'moss'],
|
15 |
+
}
|
16 |
+
for k in [8, 4, 2, 1]:
|
17 |
+
for keyword in kws[k]:
|
18 |
+
if keyword in model_name:
|
19 |
+
return k
|
20 |
+
return 8
|
21 |
+
|
22 |
+
|
23 |
+
validated_llms = [
|
24 |
+
'internlm/internlm-chat-7b', 'internlm/internlm-chat-7b-8k', 'internlm/internlm-chat-20b',
|
25 |
+
'Qwen/Qwen-7B-Chat', 'Qwen/Qwen-14B-Chat',
|
26 |
+
'THUDM/chatglm2-6b', 'THUDM/chatglm2-6b-32k', 'THUDM/chatglm3-6b', 'THUDM/chatglm3-6b-32k',
|
27 |
+
'baichuan-inc/Baichuan2-7B-Chat', 'baichuan-inc/Baichuan2-13B-Chat',
|
28 |
+
'lmsys/vicuna-7b-v1.5', 'lmsys/vicuna-13b-v1.5',
|
29 |
+
'meta-llama/Llama-2-7b-chat-hf'
|
30 |
+
]
|
31 |
+
Auto_model = ['chatglm']
|
32 |
+
|
33 |
+
|
34 |
+
class HFChatModel:
|
35 |
+
|
36 |
+
def _get_context_length(self, model, model_path):
|
37 |
+
# By default, we use model.config.seq_length
|
38 |
+
model_path = model_path.lower()
|
39 |
+
if 'baichuan' in model_path:
|
40 |
+
context_window = model.config.model_max_length
|
41 |
+
elif 'internlm' in model_path or 'llama' in model_path:
|
42 |
+
context_window = model.config.max_position_embeddings
|
43 |
+
elif 'vicuna' in model_path:
|
44 |
+
context_window = model.generation_config.max_length
|
45 |
+
else:
|
46 |
+
# chatglm & qwen
|
47 |
+
context_window = model.config.seq_length
|
48 |
+
return context_window
|
49 |
+
|
50 |
+
def _get_context_length_robust(self, model, model_path):
|
51 |
+
try:
|
52 |
+
context_window = self._get_context_length(model, model_path)
|
53 |
+
return context_window
|
54 |
+
except Exception as err:
|
55 |
+
self.logger.critical(f'{type(err)}: {err}')
|
56 |
+
self.logger.critical(
|
57 |
+
'Failed to extract context_window information from config / generation_config. '
|
58 |
+
'Please read the above code and check whether the logic works for your model path. '
|
59 |
+
)
|
60 |
+
raise NotImplementedError
|
61 |
+
|
62 |
+
def __init__(self,
|
63 |
+
model_path,
|
64 |
+
system_prompt: str = None,
|
65 |
+
**kwargs):
|
66 |
+
|
67 |
+
self.logger = get_logger('HFChatModel')
|
68 |
+
if 'vicuna' in model_path.lower():
|
69 |
+
try:
|
70 |
+
from fastchat.model import get_conversation_template
|
71 |
+
except Exception as err:
|
72 |
+
self.logger.critical('Please install fastchat first to use vicuna. ')
|
73 |
+
raise err
|
74 |
+
|
75 |
+
self.explicit_device = kwargs.pop('device', None)
|
76 |
+
|
77 |
+
if self.explicit_device is None:
|
78 |
+
# If CUDA_VISIBLE_DEVICES is not properly set
|
79 |
+
if 'CUDA_VISIBLE_DEVICES' not in os.environ or os.environ['CUDA_VISIBLE_DEVICES'] == '0,1,2,3,4,5,6,7':
|
80 |
+
num_gpu = get_gpu_num(model_path)
|
81 |
+
gpu_offset = kwargs.pop('gpu_offset', 0)
|
82 |
+
cuda_visible_devices = ','.join([str(i) for i in range(gpu_offset, gpu_offset + num_gpu)])
|
83 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = cuda_visible_devices
|
84 |
+
|
85 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
|
86 |
+
from transformers.generation import GenerationConfig
|
87 |
+
|
88 |
+
if model_path not in validated_llms:
|
89 |
+
self.logger.warning(f'{model_path} not in validated LLMs, may have inference troubles. ')
|
90 |
+
|
91 |
+
self.model_path = model_path
|
92 |
+
if listinstr(Auto_model, model_path):
|
93 |
+
LoadModel = AutoModel
|
94 |
+
else:
|
95 |
+
LoadModel = AutoModelForCausalLM
|
96 |
+
|
97 |
+
assert osp.exists(model_path) or len(model_path.split('/')) == 2
|
98 |
+
|
99 |
+
device = self.explicit_device if self.explicit_device else 'auto'
|
100 |
+
|
101 |
+
precision = {}
|
102 |
+
if 'internlm-chat-7b' in model_path:
|
103 |
+
precision = {'torch_dtype': torch.float16}
|
104 |
+
elif 'internlm-chat-20b' in model_path:
|
105 |
+
precision = {'torch_dtype': torch.bfloat16}
|
106 |
+
|
107 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
108 |
+
model = LoadModel.from_pretrained(model_path, trust_remote_code=True, device_map='cpu', **precision)
|
109 |
+
model = model.eval()
|
110 |
+
|
111 |
+
if device != 'cpu':
|
112 |
+
model = model.to(f'cuda:{device}' if isinstance(device, int) else 'cuda')
|
113 |
+
try:
|
114 |
+
model.generation_config = GenerationConfig.from_pretrained(
|
115 |
+
model_path, trust_remote_code=True, device_map=device)
|
116 |
+
except Exception as err:
|
117 |
+
self.logger.warning(f'{type(err)}: {err}')
|
118 |
+
|
119 |
+
torch.cuda.empty_cache()
|
120 |
+
self.model = model
|
121 |
+
self.context_length = self._get_context_length_robust(model=model, model_path=model_path)
|
122 |
+
self.answer_buffer = 192
|
123 |
+
self.system_prompt = system_prompt
|
124 |
+
for k, v in kwargs.items():
|
125 |
+
self.logger.info(f'Following args will be used for generation (If not set specifically), {k}: {v}. ')
|
126 |
+
self.kwargs = kwargs
|
127 |
+
|
128 |
+
def generate_str(self, input, **kwargs):
|
129 |
+
if 'baichuan' in self.model_path.lower():
|
130 |
+
messages = []
|
131 |
+
messages.append({'role': 'user', 'content': input})
|
132 |
+
resp = self.model.chat(self.tokenizer, messages, **kwargs)
|
133 |
+
elif 'vicuna' in self.model_path.lower():
|
134 |
+
from fastchat.model import get_conversation_template
|
135 |
+
conv = get_conversation_template('vicuna')
|
136 |
+
conv.append_message(conv.roles[0], input)
|
137 |
+
conv.append_message(conv.roles[1], None)
|
138 |
+
prompt = conv.get_prompt()
|
139 |
+
inputs = self.tokenizer([prompt], return_tensors='pt')
|
140 |
+
if torch.cuda.is_available():
|
141 |
+
for k in inputs:
|
142 |
+
inputs[k] = inputs[k].cuda()
|
143 |
+
|
144 |
+
params = dict(do_sample=True, temperature=0.7, repetition_penalty=1.0, max_new_tokens=512)
|
145 |
+
params.update(self.kwargs)
|
146 |
+
params.update(kwargs)
|
147 |
+
outputs = self.model.generate(**inputs, **params)
|
148 |
+
resp = self.tokenizer.decode(
|
149 |
+
outputs[0][len(inputs['input_ids'][0]):],
|
150 |
+
skip_special_tokens=True,
|
151 |
+
spaces_between_special_tokens=False)
|
152 |
+
|
153 |
+
else:
|
154 |
+
params = self.kwargs
|
155 |
+
params.update(kwargs)
|
156 |
+
resp, _ = self.model.chat(self.tokenizer, input, history=[], **params)
|
157 |
+
|
158 |
+
return resp
|
159 |
+
|
160 |
+
def length_ok(self, inputs):
|
161 |
+
tot = len(self.tokenizer.encode(self.system_prompt)) if self.system_prompt is not None else 0
|
162 |
+
for s in inputs:
|
163 |
+
tot += len(self.tokenizer.encode(s))
|
164 |
+
return tot + self.answer_buffer < self.context_length
|
165 |
+
|
166 |
+
def generate_list(self, full_inputs, offset=0, **kwargs):
|
167 |
+
assert isinstance(full_inputs, list)
|
168 |
+
|
169 |
+
inputs = full_inputs[offset:]
|
170 |
+
if not self.length_ok(inputs):
|
171 |
+
return self.generate_list(full_inputs, offset + 1, **kwargs)
|
172 |
+
|
173 |
+
model_path = self.model_path.lower()
|
174 |
+
|
175 |
+
if sum([x in model_path for x in ['baichuan']]):
|
176 |
+
input_msgs = []
|
177 |
+
if self.system_prompt is not None:
|
178 |
+
input_msgs.append(dict(role='user', content=self.system_prompt))
|
179 |
+
if len(inputs):
|
180 |
+
assert isinstance(inputs, list) and isinstance(inputs[0], str)
|
181 |
+
roles = ['user', 'assistant'] if len(inputs) % 2 == 1 else ['assistant', 'user']
|
182 |
+
roles = roles * len(inputs)
|
183 |
+
for role, msg in zip(roles, inputs):
|
184 |
+
input_msgs.append(dict(role=role, content=msg))
|
185 |
+
response = self.model.chat(self.tokenizer, input_msgs)
|
186 |
+
elif sum([x in model_path for x in ['vicuna']]):
|
187 |
+
from fastchat.model import get_conversation_template
|
188 |
+
conv = get_conversation_template('vicuna')
|
189 |
+
assert isinstance(inputs, list) and isinstance(inputs[0], str)
|
190 |
+
if len(inputs) % 2 == 1:
|
191 |
+
if self.system_prompt is not None:
|
192 |
+
conv.append_message(conv.roles[0], self.system_prompt)
|
193 |
+
for i in range(len(inputs) // 2):
|
194 |
+
conv.append_message(conv.roles[0], inputs[2 * i])
|
195 |
+
conv.append_message(conv.roles[1], inputs[2 * i + 1])
|
196 |
+
else:
|
197 |
+
assert self.system_prompt is not None
|
198 |
+
conv.append_message(conv.roles[0], self.system_prompt)
|
199 |
+
conv.append_message(conv.roles[1], inputs[0])
|
200 |
+
for i in range(len(inputs) // 2 - 1):
|
201 |
+
conv.append_message(conv.roles[0], inputs[2 * i + 1])
|
202 |
+
conv.append_message(conv.roles[1], inputs[2 * i + 2])
|
203 |
+
conv.append_message(conv.roles[0], inputs[-1])
|
204 |
+
conv.append_message(conv.roles[1], None)
|
205 |
+
prompt = conv.get_prompt()
|
206 |
+
inputs = self.tokenizer([prompt], return_tensors='pt')
|
207 |
+
if torch.cuda.is_available():
|
208 |
+
for k in inputs:
|
209 |
+
inputs[k] = inputs[k].cuda()
|
210 |
+
|
211 |
+
params = dict(do_sample=True, temperature=0.7, repetition_penalty=1.0, max_new_tokens=512)
|
212 |
+
params.update(self.kwargs)
|
213 |
+
params.update(kwargs)
|
214 |
+
|
215 |
+
outputs = self.model.generate(**inputs, **params)
|
216 |
+
response = self.tokenizer.decode(
|
217 |
+
outputs[0][len(inputs['input_ids'][0]):],
|
218 |
+
skip_special_tokens=True,
|
219 |
+
spaces_between_special_tokens=False)
|
220 |
+
response = response.lstrip('\n')
|
221 |
+
else:
|
222 |
+
# The default option, support internlm, chatglm, qwen
|
223 |
+
history, msg = [], None
|
224 |
+
if len(inputs) % 2 == 1:
|
225 |
+
if self.system_prompt is not None:
|
226 |
+
history = [(self.system_prompt, '')]
|
227 |
+
for i in range(len(inputs) // 2):
|
228 |
+
history.append((inputs[2 * i], inputs[2 * i + 1]))
|
229 |
+
else:
|
230 |
+
assert self.system_prompt is not None
|
231 |
+
history = [(self.system_prompt, inputs[0])]
|
232 |
+
for i in range(len(inputs) // 2 - 1):
|
233 |
+
history.append((inputs[2 * i + 1], inputs[2 * i + 2]))
|
234 |
+
msg = inputs[-1]
|
235 |
+
|
236 |
+
params = self.kwargs
|
237 |
+
params.update(kwargs)
|
238 |
+
response, _ = self.model.chat(self.tokenizer, msg, history=history, **params)
|
239 |
+
|
240 |
+
return response, offset
|
241 |
+
|
242 |
+
def generate(self, inputs, **kwargs):
|
243 |
+
if isinstance(inputs, str):
|
244 |
+
return self.generate_str(inputs, **kwargs)
|
245 |
+
elif isinstance(inputs, list):
|
246 |
+
return self.generate_list(inputs, **kwargs)
|
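A minimal usage sketch for HFChatModel (assuming the checkpoint is reachable locally or on the HuggingFace Hub and enough GPU memory is available; the model name below is just one of the validated LLMs listed above):

from vlmeval.api import HFChatModel

model = HFChatModel('internlm/internlm-chat-7b', system_prompt='You are a helpful assistant.')

# Single-turn: pass a plain string.
print(model.generate('Explain in one sentence what a context window is.'))

# Multi-turn: pass alternating user/assistant strings ending with a user turn.
reply, offset = model.generate(['Hi!', 'Hello! How can I help?', 'Name three LLM benchmarks.'])
print(reply)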
VLMEvalKit/vlmeval/api/hunyuan.py
ADDED
@@ -0,0 +1,147 @@
|
1 |
+
from vlmeval.smp import *
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
from vlmeval.api.base import BaseAPI
|
5 |
+
|
6 |
+
|
7 |
+
class HunyuanWrapper(BaseAPI):
|
8 |
+
|
9 |
+
is_api: bool = True
|
10 |
+
_apiVersion = '2023-09-01'
|
11 |
+
_service = 'hunyuan'
|
12 |
+
|
13 |
+
def __init__(self,
|
14 |
+
model: str = 'hunyuan-vision',
|
15 |
+
retry: int = 5,
|
16 |
+
wait: int = 5,
|
17 |
+
secret_key: str = None,
|
18 |
+
secret_id: str = None,
|
19 |
+
verbose: bool = True,
|
20 |
+
system_prompt: str = None,
|
21 |
+
temperature: float = 0,
|
22 |
+
timeout: int = 60,
|
23 |
+
api_base: str = 'hunyuan.tencentcloudapi.com',
|
24 |
+
**kwargs):
|
25 |
+
|
26 |
+
self.model = model
|
27 |
+
self.cur_idx = 0
|
28 |
+
self.fail_msg = 'Failed to obtain answer via API. '
|
29 |
+
self.temperature = temperature
|
30 |
+
|
31 |
+
warnings.warn('You may need to set the env variables HUNYUAN_SECRET_ID & HUNYUAN_SECRET_KEY to use Hunyuan. ')
|
32 |
+
|
33 |
+
secret_key = os.environ.get('HUNYUAN_SECRET_KEY', secret_key)
|
34 |
+
assert secret_key is not None, 'Please set the environment variable HUNYUAN_SECRET_KEY. '
|
35 |
+
secret_id = os.environ.get('HUNYUAN_SECRET_ID', secret_id)
|
36 |
+
assert secret_id is not None, 'Please set the environment variable HUNYUAN_SECRET_ID. '
|
37 |
+
|
38 |
+
self.model = model
|
39 |
+
self.endpoint = api_base
|
40 |
+
self.secret_id = secret_id
|
41 |
+
self.secret_key = secret_key
|
42 |
+
self.timeout = timeout
|
43 |
+
|
44 |
+
try:
|
45 |
+
from tencentcloud.common import credential
|
46 |
+
from tencentcloud.common.profile.client_profile import ClientProfile
|
47 |
+
from tencentcloud.common.profile.http_profile import HttpProfile
|
48 |
+
from tencentcloud.hunyuan.v20230901 import hunyuan_client
|
49 |
+
except ImportError as err:
|
50 |
+
self.logger.critical('Please install tencentcloud-sdk-python to use Hunyuan API. ')
|
51 |
+
raise err
|
52 |
+
|
53 |
+
super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)
|
54 |
+
|
55 |
+
cred = credential.Credential(self.secret_id, self.secret_key)
|
56 |
+
httpProfile = HttpProfile()
|
57 |
+
httpProfile.endpoint = self.endpoint
|
58 |
+
clientProfile = ClientProfile()
|
59 |
+
clientProfile.httpProfile = httpProfile
|
60 |
+
self.client = hunyuan_client.HunyuanClient(cred, 'ap-beijing', clientProfile)
|
61 |
+
self.logger.info(
|
62 |
+
f'Using Endpoint: {self.endpoint}; API Secret ID: {self.secret_id}; API Secret Key: {self.secret_key}'
|
63 |
+
)
|
64 |
+
|
65 |
+
# inputs can be a lvl-2 nested list: [content1, content2, content3, ...]
|
66 |
+
# content can be a string or a list of image & text
|
67 |
+
def prepare_itlist(self, inputs):
|
68 |
+
assert np.all([isinstance(x, dict) for x in inputs])
|
69 |
+
has_images = np.sum([x['type'] == 'image' for x in inputs])
|
70 |
+
if has_images:
|
71 |
+
content_list = []
|
72 |
+
for msg in inputs:
|
73 |
+
if msg['type'] == 'text':
|
74 |
+
content_list.append(dict(Type='text', Text=msg['value']))
|
75 |
+
elif msg['type'] == 'image':
|
76 |
+
from PIL import Image
|
77 |
+
img = Image.open(msg['value'])
|
78 |
+
b64 = encode_image_to_base64(img)
|
79 |
+
img_struct = dict(Url=f'data:image/jpeg;base64,{b64}')
|
80 |
+
content_list.append(dict(Type='image_url', ImageUrl=img_struct))
|
81 |
+
else:
|
82 |
+
assert all([x['type'] == 'text' for x in inputs])
|
83 |
+
text = '\n'.join([x['value'] for x in inputs])
|
84 |
+
content_list = [dict(Type='text', Text=text)]
|
85 |
+
return content_list
|
86 |
+
|
87 |
+
def prepare_inputs(self, inputs):
|
88 |
+
input_msgs = []
|
89 |
+
if self.system_prompt is not None:
|
90 |
+
input_msgs.append(dict(Role='system', Content=self.system_prompt))
|
91 |
+
assert isinstance(inputs, list) and isinstance(inputs[0], dict)
|
92 |
+
assert np.all(['type' in x for x in inputs]) or np.all(['role' in x for x in inputs]), inputs
|
93 |
+
if 'role' in inputs[0]:
|
94 |
+
assert inputs[-1]['role'] == 'user', inputs[-1]
|
95 |
+
for item in inputs:
|
96 |
+
input_msgs.append(dict(Role=item['role'], Contents=self.prepare_itlist(item['content'])))
|
97 |
+
else:
|
98 |
+
input_msgs.append(dict(Role='user', Contents=self.prepare_itlist(inputs)))
|
99 |
+
return input_msgs
|
100 |
+
|
101 |
+
def generate_inner(self, inputs, **kwargs) -> str:
|
102 |
+
from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
|
103 |
+
from tencentcloud.hunyuan.v20230901 import models
|
104 |
+
|
105 |
+
input_msgs = self.prepare_inputs(inputs)
|
106 |
+
temperature = kwargs.pop('temperature', self.temperature)
|
107 |
+
|
108 |
+
payload = dict(
|
109 |
+
Model=self.model,
|
110 |
+
Messages=input_msgs,
|
111 |
+
Temperature=temperature,
|
112 |
+
**kwargs)
|
113 |
+
|
114 |
+
retry_counter = 0
|
115 |
+
while retry_counter < 3:
|
116 |
+
try:
|
117 |
+
req = models.ChatCompletionsRequest()
|
118 |
+
req.from_json_string(json.dumps(payload))
|
119 |
+
resp = self.client.ChatCompletions(req)
|
120 |
+
resp = json.loads(resp.to_json_string())
|
121 |
+
answer = resp['Choices'][0]['Message']['Content']
|
122 |
+
return 0, answer, resp
|
123 |
+
except TencentCloudSDKException as e:
|
124 |
+
self.logger.error(f'Got error code: {e.get_code()}')
|
125 |
+
if e.get_code() == 'ClientNetworkError':
|
126 |
+
return -1, self.fail_msg + e.get_code(), None
|
127 |
+
elif e.get_code() in ['InternalError', 'ServerNetworkError']:
|
128 |
+
if retry_counter == 3:
|
129 |
+
return -1, self.fail_msg + e.get_code(), None
|
130 |
+
retry_counter += 1
|
131 |
+
continue
|
132 |
+
elif e.get_code() in ['LimitExceeded']:
|
133 |
+
time.sleep(5)
|
134 |
+
if retry_counter == 3:
|
135 |
+
return -1, self.fail_msg + e.get_code(), None
|
136 |
+
retry_counter += 1
|
137 |
+
continue
|
138 |
+
else:
|
139 |
+
return -1, self.fail_msg + str(e), None
|
140 |
+
|
141 |
+
return -1, self.fail_msg, None
|
142 |
+
|
143 |
+
|
144 |
+
class HunyuanVision(HunyuanWrapper):
|
145 |
+
|
146 |
+
def generate(self, message, dataset=None):
|
147 |
+
return super(HunyuanVision, self).generate(message)
|
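A minimal sketch of calling the Hunyuan wrapper (assuming tencentcloud-sdk-python is installed and the HUNYUAN_SECRET_ID / HUNYUAN_SECRET_KEY environment variables are set; the image path is a placeholder):

from vlmeval.api import HunyuanVision

model = HunyuanVision(temperature=0)
message = [
    dict(type='image', value='/path/to/example.jpg'),
    dict(type='text', value='Describe this image in one sentence.'),
]
print(model.generate(message))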
VLMEvalKit/vlmeval/api/jt_vl_chat.py
ADDED
@@ -0,0 +1,239 @@
|
1 |
+
import pandas as pd
|
2 |
+
import requests
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import base64
|
6 |
+
from vlmeval.smp import *
|
7 |
+
from vlmeval.api.base import BaseAPI
|
8 |
+
from vlmeval.dataset import DATASET_TYPE
|
9 |
+
from vlmeval.dataset import img_root_map
|
10 |
+
|
11 |
+
|
12 |
+
API_ENDPOINT = 'https://jiutian.10086.cn/kunlun/ingress/api/h3t-eeceff/92390745235a40a484d850be19e1f8b4/ai-5d7ae47ec93f4280953273c4001aafee/service-7544ea5ee3e841ad9d01e7af44acef7c/v1/chat/completions' # noqa: E501
|
13 |
+
APP_CODE = 'eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJzdWIiOiI5ZGQwNmQ2ZjU4YTU0ZGY0OGEzNjRhMjQyNGMwODEyNSIsImlzcyI6ImFwaS1hdXRoLWtleSIsImV4cCI6NDg4MjkwNDA3OX0.k5t_T-955xWMndzBbx4WQQNAgm5DpMos9mHm7vkFipQ3yebCFMfyufpSxORSfEVpBaDS3Nly0dd8ygQYGnDgIQcC72vQ1xtkjCP49LNcqlceoET4rGc1zwRi76XLPSGFES4GcwvEmr7Ilth7XtqZNxcDF_Z7HyHyf1-zF0JIQETYSoxenqLU-gNteNfqRUnlyCgaKh03DscAbYvtoMUxEaFa2ZqyRSwekdHI_SPKCq9aC9G19yDPHTjeiwl1ubtyC5uMy5pERn_ClRsZS3Wyb-GmD5QQsFofrWvCiU_fVJuUiez39pYZvEP8awH0R9B7SkpQ4XOzj3fdytTPYy3g6g' # noqa: E501
|
14 |
+
|
15 |
+
|
16 |
+
class JTVLChatWrapper(BaseAPI):
|
17 |
+
is_api: bool = True
|
18 |
+
INTERLEAVE = False
|
19 |
+
|
20 |
+
def __init__(self,
|
21 |
+
model: str = 'jt-vl-chat',
|
22 |
+
retry: int = 5,
|
23 |
+
wait: int = 5,
|
24 |
+
api_base: str = API_ENDPOINT,
|
25 |
+
key: str = APP_CODE,
|
26 |
+
verbose: bool = True,
|
27 |
+
system_prompt: str = None,
|
28 |
+
temperature: float = 0.7,
|
29 |
+
max_tokens: int = 256,
|
30 |
+
proxy: str = None,
|
31 |
+
**kwargs):
|
32 |
+
self.model = model
|
33 |
+
|
34 |
+
self.temperature = temperature
|
35 |
+
self.max_tokens = max_tokens
|
36 |
+
self.api_base = api_base
|
37 |
+
|
38 |
+
if key is None:
|
39 |
+
key = os.environ.get('JTVLChat_API_KEY', None)
|
40 |
+
assert key is not None, (
|
41 |
+
'Please set the API Key (also called app_code, obtain it here: https://github.com/jiutiancv/JT-VL-Chat)'
|
42 |
+
)
|
43 |
+
|
44 |
+
self.key = key
|
45 |
+
super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)
|
46 |
+
|
47 |
+
def dump_image(self, line, dataset):
|
48 |
+
"""Dump the image(s) of the input line to the corresponding dataset folder.
|
49 |
+
|
50 |
+
Args:
|
51 |
+
line (line of pd.DataFrame): The raw input line.
|
52 |
+
dataset (str): The name of the dataset.
|
53 |
+
|
54 |
+
Returns:
|
55 |
+
str | list[str]: The paths of the dumped images.
|
56 |
+
"""
|
57 |
+
ROOT = LMUDataRoot()
|
58 |
+
assert isinstance(dataset, str)
|
59 |
+
|
60 |
+
img_root = os.path.join(ROOT, 'images', img_root_map(dataset) if dataset in img_root_map(dataset) else dataset)
|
61 |
+
os.makedirs(img_root, exist_ok=True)
|
62 |
+
if 'image' in line:
|
63 |
+
if isinstance(line['image'], list):
|
64 |
+
tgt_path = []
|
65 |
+
assert 'image_path' in line
|
66 |
+
for img, im_name in zip(line['image'], line['image_path']):
|
67 |
+
path = osp.join(img_root, im_name)
|
68 |
+
if not read_ok(path):
|
69 |
+
decode_base64_to_image_file(img, path)
|
70 |
+
tgt_path.append(path)
|
71 |
+
else:
|
72 |
+
tgt_path = osp.join(img_root, f"{line['index']}.jpg")
|
73 |
+
if not read_ok(tgt_path):
|
74 |
+
decode_base64_to_image_file(line['image'], tgt_path)
|
75 |
+
tgt_path = [tgt_path]
|
76 |
+
else:
|
77 |
+
assert 'image_path' in line
|
78 |
+
tgt_path = toliststr(line['image_path'])
|
79 |
+
|
80 |
+
return tgt_path
|
81 |
+
|
82 |
+
def use_custom_prompt(self, dataset):
|
83 |
+
assert dataset is not None
|
84 |
+
if listinstr(['MMMU_DEV_VAL','MMMU_TEST'], dataset):
|
85 |
+
return False
|
86 |
+
else:
|
87 |
+
return True
|
88 |
+
|
89 |
+
def build_multi_choice_prompt(self, line, dataset=None):
|
90 |
+
question = line['question']
|
91 |
+
hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
|
92 |
+
if hint is not None:
|
93 |
+
question = hint + '\n' + question
|
94 |
+
|
95 |
+
options = {
|
96 |
+
cand: line[cand]
|
97 |
+
for cand in string.ascii_uppercase
|
98 |
+
if cand in line and not pd.isna(line[cand])
|
99 |
+
}
|
100 |
+
for key, item in options.items():
|
101 |
+
question += f'\n{key}. {item}'
|
102 |
+
prompt = question
|
103 |
+
|
104 |
+
if len(options):
|
105 |
+
prompt += '\n请直接回答选项字母。' if cn_string(
|
106 |
+
prompt) else "\nAnswer with the option's letter from the given choices directly."
|
107 |
+
else:
|
108 |
+
prompt += '\n请直接回答问题。' if cn_string(prompt) else '\nAnswer the question directly.'
|
109 |
+
|
110 |
+
return prompt
|
111 |
+
|
112 |
+
def build_prompt(self, line, dataset=None):
|
113 |
+
assert self.use_custom_prompt(dataset)
|
114 |
+
assert dataset is None or isinstance(dataset, str)
|
115 |
+
|
116 |
+
tgt_path = self.dump_image(line, dataset)
|
117 |
+
|
118 |
+
if dataset is not None and listinstr(['MME'], dataset):
|
119 |
+
question = line['question']
|
120 |
+
prompt = question + ' Answer the question using a single word or phrase.'
|
121 |
+
elif dataset is not None and listinstr(['HallusionBench'], dataset):
|
122 |
+
question = line['question']
|
123 |
+
prompt = question + ' Please answer yes or no. Answer the question using a single word or phrase.'
|
124 |
+
elif dataset is not None and DATASET_TYPE(dataset) == 'MCQ':
|
125 |
+
prompt = self.build_multi_choice_prompt(line, dataset)
|
126 |
+
elif dataset is not None and DATASET_TYPE(dataset) == 'VQA':
|
127 |
+
if listinstr(['MathVista', 'MathVision'], dataset):
|
128 |
+
prompt = line['question']
|
129 |
+
elif listinstr(['LLaVABench'], dataset):
|
130 |
+
question = line['question']
|
131 |
+
prompt = question + '\nAnswer this question in detail.'
|
132 |
+
elif listinstr(['MMVet'], dataset):
|
133 |
+
prompt = line['question']
|
134 |
+
else:
|
135 |
+
question = line['question']
|
136 |
+
prompt = question + '\nAnswer the question using a single word or phrase.'
|
137 |
+
else:
|
138 |
+
prompt = line['question']
|
139 |
+
message = [dict(type='text', value=prompt)]
|
140 |
+
message.extend([dict(type='image', value=s) for s in tgt_path])
|
141 |
+
return message
|
142 |
+
|
143 |
+
def message_to_promptimg(self, message, dataset=None):
|
144 |
+
assert not self.INTERLEAVE
|
145 |
+
model_name = self.__class__.__name__
|
146 |
+
import warnings
|
147 |
+
warnings.warn(
|
148 |
+
f'Model {model_name} does not support interleaved input. '
|
149 |
+
'Will use the first image and aggregated texts as prompt. ')
|
150 |
+
num_images = len([x for x in message if x['type'] == 'image'])
|
151 |
+
if num_images == 0:
|
152 |
+
prompt = '\n'.join([x['value'] for x in message if x['type'] == 'text'])
|
153 |
+
image = None
|
154 |
+
else:
|
155 |
+
prompt = '\n'.join([x['value'] for x in message if x['type'] == 'text'])
|
156 |
+
if dataset == 'BLINK':
|
157 |
+
image = concat_images_vlmeval(
|
158 |
+
[x['value'] for x in message if x['type'] == 'image'],
|
159 |
+
target_size=512)
|
160 |
+
else:
|
161 |
+
image = [x['value'] for x in message if x['type'] == 'image'][0]
|
162 |
+
return prompt, image
|
163 |
+
|
164 |
+
def get_send_data(self, prompt, image_path, temperature, max_tokens):
|
165 |
+
image = ''
|
166 |
+
with open(image_path, 'rb') as f:
|
167 |
+
image = str(base64.b64encode(f.read()), 'utf-8')
|
168 |
+
send_data = {
|
169 |
+
"messages": [
|
170 |
+
{
|
171 |
+
"role": "user",
|
172 |
+
"content": prompt
|
173 |
+
}
|
174 |
+
],
|
175 |
+
"image_base64": image,
|
176 |
+
"max_tokens": max_tokens,
|
177 |
+
"temperature": temperature
|
178 |
+
}
|
179 |
+
return send_data
|
180 |
+
|
181 |
+
def get_send_data_no_image(self, prompt, temperature, max_tokens):
|
182 |
+
send_data = {
|
183 |
+
"messages": [
|
184 |
+
{
|
185 |
+
"role": "user",
|
186 |
+
"content": prompt
|
187 |
+
}
|
188 |
+
],
|
189 |
+
"max_tokens": max_tokens,
|
190 |
+
"temperature": temperature
|
191 |
+
}
|
192 |
+
return send_data
|
193 |
+
|
194 |
+
def generate_inner(self, inputs, **kwargs) -> str:
|
195 |
+
assert isinstance(inputs, str) or isinstance(inputs, list)
|
196 |
+
inputs = [inputs] if isinstance(inputs, str) else inputs
|
197 |
+
dataset = kwargs.get('dataset', None)
|
198 |
+
prompt, image_path = self.message_to_promptimg(message=inputs, dataset=dataset)
|
199 |
+
# print("prompt:",prompt)
|
200 |
+
if image_path:
|
201 |
+
send_data = self.get_send_data(
|
202 |
+
prompt=prompt,
|
203 |
+
image_path=image_path,
|
204 |
+
temperature=self.temperature,
|
205 |
+
max_tokens=self.max_tokens)
|
206 |
+
else:
|
207 |
+
send_data = self.get_send_data_no_image(
|
208 |
+
prompt=prompt,
|
209 |
+
temperature=self.temperature,
|
210 |
+
max_tokens=self.max_tokens)
|
211 |
+
|
212 |
+
json_data = json.dumps(send_data)
|
213 |
+
|
214 |
+
header_dict = {'Content-Type': 'application/json', 'Authorization': 'Bearer ' + self.key}
|
215 |
+
|
216 |
+
r = requests.post(self.api_base, headers=header_dict, data=json_data, timeout=3000)
|
217 |
+
try:
|
218 |
+
assert r.status_code == 200
|
219 |
+
r_json = r.json()
|
220 |
+
output = r_json['choices'][0]['message']['content']
|
221 |
+
if self.verbose:
|
222 |
+
self.logger.info(f'inputs: {inputs}\nanswer: {output}')
|
223 |
+
|
224 |
+
return 0, output, 'Succeeded! '
|
225 |
+
|
226 |
+
except Exception:
|
227 |
+
error_msg = f'Error! code {r.status_code} content: {r.content}'
|
228 |
+
error_con = r.content.decode('utf-8')
|
229 |
+
if self.verbose:
|
230 |
+
self.logger.error(error_msg)
|
231 |
+
self.logger.error(error_con)
|
232 |
+
self.logger.error(f'The input messages are {inputs}.')
|
233 |
+
return -1, error_msg, ''
|
234 |
+
|
235 |
+
|
236 |
+
class JTVLChatAPI(JTVLChatWrapper):
|
237 |
+
|
238 |
+
def generate(self, message, dataset=None):
|
239 |
+
return super(JTVLChatAPI, self).generate(message, dataset=dataset)
|
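A minimal sketch for the JT-VL-Chat wrapper (assuming a valid app_code is supplied via the key argument or the JTVLChat_API_KEY environment variable; the image path is a placeholder). Since INTERLEAVE is False, only the first image of a message is sent:

from vlmeval.api import JTVLChatAPI

model = JTVLChatAPI(max_tokens=256)
message = [
    dict(type='text', value='What is written on the sign?'),
    dict(type='image', value='/path/to/sign.jpg'),
]
print(model.generate(message, dataset='MME'))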
VLMEvalKit/vlmeval/api/qwen_api.py
ADDED
@@ -0,0 +1,75 @@
|
1 |
+
from http import HTTPStatus
|
2 |
+
import os
|
3 |
+
from vlmeval.api.base import BaseAPI
|
4 |
+
from vlmeval.smp import *
|
5 |
+
|
6 |
+
|
7 |
+
# Note: This is a pure language model API.
|
8 |
+
class QwenAPI(BaseAPI):
|
9 |
+
|
10 |
+
is_api: bool = True
|
11 |
+
|
12 |
+
def __init__(self,
|
13 |
+
model: str = 'qwen-max-1201',
|
14 |
+
retry: int = 5,
|
15 |
+
wait: int = 5,
|
16 |
+
verbose: bool = True,
|
17 |
+
seed: int = 2680,
|
18 |
+
temperature: float = 0.0,
|
19 |
+
system_prompt: str = None,
|
20 |
+
key: str = None,
|
21 |
+
max_tokens: int = 1024,
|
22 |
+
proxy: str = None,
|
23 |
+
**kwargs):
|
24 |
+
|
25 |
+
assert model in ['qwen-turbo', 'qwen-plus', 'qwen-max', 'qwen-max-1201', 'qwen-max-longcontext']
|
26 |
+
self.model = model
|
27 |
+
import dashscope
|
28 |
+
self.fail_msg = 'Failed to obtain answer via API. '
|
29 |
+
self.max_tokens = max_tokens
|
30 |
+
self.temperature = temperature
|
31 |
+
self.seed = seed
|
32 |
+
if key is None:
|
33 |
+
key = os.environ.get('DASHSCOPE_API_KEY', None)
|
34 |
+
assert key is not None, (
|
35 |
+
'Please set the API Key (obtain it here: '
|
36 |
+
'https://help.aliyun.com/zh/dashscope/developer-reference/vl-plus-quick-start)'
|
37 |
+
)
|
38 |
+
dashscope.api_key = key
|
39 |
+
if proxy is not None:
|
40 |
+
proxy_set(proxy)
|
41 |
+
super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)
|
42 |
+
|
43 |
+
@staticmethod
|
44 |
+
def build_msgs(msgs_raw, system_prompt=None):
|
45 |
+
msgs = cp.deepcopy(msgs_raw)
|
46 |
+
ret = []
|
47 |
+
if system_prompt is not None:
|
48 |
+
ret.append(dict(role='system', content=system_prompt))
|
49 |
+
for i, msg in enumerate(msgs):
|
50 |
+
role = 'user' if i % 2 == 0 else 'assistant'
|
51 |
+
ret.append(dict(role=role, content=msg))
|
52 |
+
return ret
|
53 |
+
|
54 |
+
def generate_inner(self, inputs, **kwargs) -> str:
|
55 |
+
from dashscope import MultiModalConversation
|
56 |
+
assert isinstance(inputs, str) or isinstance(inputs, list)
|
57 |
+
inputs = [inputs] if isinstance(inputs, str) else inputs
|
58 |
+
messages = self.build_msgs(msgs_raw=inputs, system_prompt=self.system_prompt)
|
59 |
+
|
60 |
+
import dashscope
|
61 |
+
response = dashscope.Generation.call(
|
62 |
+
model=self.model,
|
63 |
+
messages=messages,
|
64 |
+
seed=self.seed,
|
65 |
+
temperature=self.temperature,
|
66 |
+
max_tokens=self.max_tokens,
|
67 |
+
result_format='message', # set the result to be "message" format.
|
68 |
+
)
|
69 |
+
if response.status_code != HTTPStatus.OK:
|
70 |
+
return -1, 'Error: Bad Response Status Code. ', f'The response status code is {response.status_code}. '
|
71 |
+
|
72 |
+
try:
|
73 |
+
return 0, response['output']['choices'][0]['message']['content'].strip(), 'Succeeded! '
|
74 |
+
except Exception as err:
|
75 |
+
return -1, f'Error: Failed to parse the response. {err}', response
|
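A minimal sketch for the text-only QwenAPI (assuming dashscope is installed and DASHSCOPE_API_KEY is set):

from vlmeval.api import QwenAPI

llm = QwenAPI(model='qwen-max-1201', temperature=0.0, max_tokens=512)
# A single string is treated as a one-turn user message.
print(llm.generate('Summarize in one sentence what visual question answering is.'))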
VLMEvalKit/vlmeval/api/qwen_vl_api.py
ADDED
@@ -0,0 +1,219 @@
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import os
|
4 |
+
import warnings
|
5 |
+
|
6 |
+
from vlmeval.smp import *
|
7 |
+
from vlmeval.api.base import BaseAPI
|
8 |
+
from vlmeval.vlm.qwen2_vl.prompt import Qwen2VLPromptMixin
|
9 |
+
|
10 |
+
|
11 |
+
def ensure_image_url(image: str) -> str:
|
12 |
+
prefixes = ['http://', 'https://', 'file://', 'data:image;']
|
13 |
+
if any(image.startswith(prefix) for prefix in prefixes):
|
14 |
+
return image
|
15 |
+
if os.path.exists(image):
|
16 |
+
return 'file://' + image
|
17 |
+
raise ValueError(f'Invalid image: {image}')
|
18 |
+
|
19 |
+
|
20 |
+
class Qwen2VLAPI(Qwen2VLPromptMixin, BaseAPI):
|
21 |
+
is_api: bool = True
|
22 |
+
|
23 |
+
def __init__(
|
24 |
+
self,
|
25 |
+
model: str = 'qwen-vl-max-0809',
|
26 |
+
key: str | None = None,
|
27 |
+
min_pixels: int | None = None,
|
28 |
+
max_pixels: int | None = None,
|
29 |
+
max_length=1024,
|
30 |
+
top_p=0.001,
|
31 |
+
top_k=1,
|
32 |
+
temperature=0.01,
|
33 |
+
repetition_penalty=1.0,
|
34 |
+
presence_penalty=0.0,
|
35 |
+
seed=3407,
|
36 |
+
use_custom_prompt: bool = True,
|
37 |
+
**kwargs,
|
38 |
+
):
|
39 |
+
import dashscope
|
40 |
+
|
41 |
+
self.model = model
|
42 |
+
self.min_pixels = min_pixels
|
43 |
+
self.max_pixels = max_pixels
|
44 |
+
self.generate_kwargs = dict(
|
45 |
+
max_length=max_length,
|
46 |
+
top_p=top_p,
|
47 |
+
top_k=top_k,
|
48 |
+
temperature=temperature,
|
49 |
+
repetition_penalty=repetition_penalty,
|
50 |
+
presence_penalty=presence_penalty,
|
51 |
+
seed=seed,
|
52 |
+
)
|
53 |
+
|
54 |
+
key = os.environ.get('DASHSCOPE_API_KEY', None) if key is None else key
|
55 |
+
assert key is not None, (
|
56 |
+
'Please set the API Key (obtain it here: '
|
57 |
+
'https://help.aliyun.com/zh/dashscope/developer-reference/vl-plus-quick-start)'
|
58 |
+
)
|
59 |
+
dashscope.api_key = key
|
60 |
+
super().__init__(use_custom_prompt=use_custom_prompt, **kwargs)
|
61 |
+
|
62 |
+
def _prepare_content(self, inputs: list[dict[str, str]], dataset: str | None = None) -> list[dict[str, str]]:
|
63 |
+
"""
|
64 |
+
inputs list[dict[str, str]], each dict has keys: ['type', 'value']
|
65 |
+
"""
|
66 |
+
content = []
|
67 |
+
for s in inputs:
|
68 |
+
if s['type'] == 'image':
|
69 |
+
item = {'type': 'image', 'image': ensure_image_url(s['value'])}
|
70 |
+
if dataset == 'OCRBench':
|
71 |
+
item['min_pixels'] = 10 * 10 * 28 * 28
|
72 |
+
warnings.warn(f"OCRBench dataset uses custom min_pixels={item['min_pixels']}")
|
73 |
+
if self.max_pixels is not None:
|
74 |
+
item['max_pixels'] = self.max_pixels
|
75 |
+
else:
|
76 |
+
if self.min_pixels is not None:
|
77 |
+
item['min_pixels'] = self.min_pixels
|
78 |
+
if self.max_pixels is not None:
|
79 |
+
item['max_pixels'] = self.max_pixels
|
80 |
+
elif s['type'] == 'text':
|
81 |
+
item = {'type': 'text', 'text': s['value']}
|
82 |
+
else:
|
83 |
+
raise ValueError(f"Invalid message type: {s['type']}, {s}")
|
84 |
+
content.append(item)
|
85 |
+
return content
|
86 |
+
|
87 |
+
def generate_inner(self, inputs, **kwargs) -> str:
|
88 |
+
import dashscope
|
89 |
+
|
90 |
+
messages = []
|
91 |
+
if self.system_prompt is not None:
|
92 |
+
messages.append({'role': 'system', 'content': self.system_prompt})
|
93 |
+
messages.append(
|
94 |
+
{'role': 'user', 'content': self._prepare_content(inputs, dataset=kwargs.get('dataset', None))}
|
95 |
+
)
|
96 |
+
if self.verbose:
|
97 |
+
print(f'\033[31m{messages}\033[0m')
|
98 |
+
|
99 |
+
# generate
|
100 |
+
generation_kwargs = self.generate_kwargs.copy()
|
101 |
+
kwargs.pop('dataset', None)
|
102 |
+
generation_kwargs.update(kwargs)
|
103 |
+
try:
|
104 |
+
response = dashscope.MultiModalConversation.call(
|
105 |
+
model=self.model,
|
106 |
+
messages=messages,
|
107 |
+
**generation_kwargs,
|
108 |
+
)
|
109 |
+
if self.verbose:
|
110 |
+
print(response)
|
111 |
+
answer = response.output.choices[0]['message']['content'][0]['text']
|
112 |
+
return 0, answer, 'Succeeded! '
|
113 |
+
except Exception as err:
|
114 |
+
if self.verbose:
|
115 |
+
self.logger.error(f'{type(err)}: {err}')
|
116 |
+
self.logger.error(f'The input messages are {inputs}.')
|
117 |
+
return -1, '', ''
|
118 |
+
|
119 |
+
|
120 |
+
class QwenVLWrapper(BaseAPI):
|
121 |
+
|
122 |
+
is_api: bool = True
|
123 |
+
|
124 |
+
def __init__(self,
|
125 |
+
model: str = 'qwen-vl-plus',
|
126 |
+
retry: int = 5,
|
127 |
+
wait: int = 5,
|
128 |
+
key: str = None,
|
129 |
+
verbose: bool = True,
|
130 |
+
temperature: float = 0.0,
|
131 |
+
system_prompt: str = None,
|
132 |
+
max_tokens: int = 1024,
|
133 |
+
proxy: str = None,
|
134 |
+
**kwargs):
|
135 |
+
|
136 |
+
assert model in ['qwen-vl-plus', 'qwen-vl-max']
|
137 |
+
self.model = model
|
138 |
+
import dashscope
|
139 |
+
self.fail_msg = 'Failed to obtain answer via API. '
|
140 |
+
self.max_tokens = max_tokens
|
141 |
+
self.temperature = temperature
|
142 |
+
if key is None:
|
143 |
+
key = os.environ.get('DASHSCOPE_API_KEY', None)
|
144 |
+
assert key is not None, (
|
145 |
+
'Please set the API Key (obtain it here: '
|
146 |
+
'https://help.aliyun.com/zh/dashscope/developer-reference/vl-plus-quick-start)'
|
147 |
+
)
|
148 |
+
dashscope.api_key = key
|
149 |
+
if proxy is not None:
|
150 |
+
proxy_set(proxy)
|
151 |
+
super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)
|
152 |
+
|
153 |
+
# inputs can be a lvl-2 nested list: [content1, content2, content3, ...]
|
154 |
+
# content can be a string or a list of image & text
|
155 |
+
def prepare_itlist(self, inputs):
|
156 |
+
assert np.all([isinstance(x, dict) for x in inputs])
|
157 |
+
has_images = np.sum([x['type'] == 'image' for x in inputs])
|
158 |
+
if has_images:
|
159 |
+
content_list = []
|
160 |
+
for msg in inputs:
|
161 |
+
if msg['type'] == 'text':
|
162 |
+
content_list.append(dict(text=msg['value']))
|
163 |
+
elif msg['type'] == 'image':
|
164 |
+
content_list.append(dict(image='file://' + msg['value']))
|
165 |
+
else:
|
166 |
+
assert all([x['type'] == 'text' for x in inputs])
|
167 |
+
text = '\n'.join([x['value'] for x in inputs])
|
168 |
+
content_list = [dict(text=text)]
|
169 |
+
return content_list
|
170 |
+
|
171 |
+
def prepare_inputs(self, inputs):
|
172 |
+
input_msgs = []
|
173 |
+
if self.system_prompt is not None:
|
174 |
+
input_msgs.append(dict(role='system', content=self.system_prompt))
|
175 |
+
assert isinstance(inputs, list) and isinstance(inputs[0], dict)
|
176 |
+
assert np.all(['type' in x for x in inputs]) or np.all(['role' in x for x in inputs]), inputs
|
177 |
+
if 'role' in inputs[0]:
|
178 |
+
assert inputs[-1]['role'] == 'user', inputs[-1]
|
179 |
+
for item in inputs:
|
180 |
+
input_msgs.append(dict(role=item['role'], content=self.prepare_itlist(item['content'])))
|
181 |
+
else:
|
182 |
+
input_msgs.append(dict(role='user', content=self.prepare_itlist(inputs)))
|
183 |
+
return input_msgs
|
184 |
+
|
185 |
+
def generate_inner(self, inputs, **kwargs) -> str:
|
186 |
+
from dashscope import MultiModalConversation
|
187 |
+
assert isinstance(inputs, str) or isinstance(inputs, list)
|
188 |
+
|
189 |
+
if 'type' in inputs[0]:
|
190 |
+
pure_text = np.all([x['type'] == 'text' for x in inputs])
|
191 |
+
else:
|
192 |
+
pure_text = True
|
193 |
+
for inp in inputs:
|
194 |
+
if not np.all([x['type'] == 'text' for x in inp['content']]):
|
195 |
+
pure_text = False
|
196 |
+
break
|
197 |
+
|
198 |
+
assert not pure_text
|
199 |
+
messages = self.prepare_inputs(inputs)
|
200 |
+
gen_config = dict(max_output_tokens=self.max_tokens, temperature=self.temperature)
|
201 |
+
gen_config.update(kwargs)
|
202 |
+
try:
|
203 |
+
response = MultiModalConversation.call(model=self.model, messages=messages)
|
204 |
+
if self.verbose:
|
205 |
+
print(response)
|
206 |
+
answer = response.output.choices[0]['message']['content'][0]['text']
|
207 |
+
return 0, answer, 'Succeeded! '
|
208 |
+
except Exception as err:
|
209 |
+
if self.verbose:
|
210 |
+
self.logger.error(f'{type(err)}: {err}')
|
211 |
+
self.logger.error(f'The input messages are {inputs}.')
|
212 |
+
|
213 |
+
return -1, '', ''
|
214 |
+
|
215 |
+
|
216 |
+
class QwenVLAPI(QwenVLWrapper):
|
217 |
+
|
218 |
+
def generate(self, message, dataset=None):
|
219 |
+
return super(QwenVLAPI, self).generate(message)
|
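A minimal sketch for the DashScope-backed Qwen2VLAPI (assuming dashscope is installed and DASHSCOPE_API_KEY is set; the image path and pixel budget are placeholders):

from vlmeval.api import Qwen2VLAPI

model = Qwen2VLAPI(model='qwen-vl-max-0809', max_pixels=1280 * 28 * 28)
message = [
    dict(type='image', value='/path/to/chart.png'),
    dict(type='text', value='What is the highest value shown in this chart?'),
]
print(model.generate(message))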
VLMEvalKit/vlmeval/api/reka.py
ADDED
@@ -0,0 +1,60 @@
|
1 |
+
from vlmeval.smp import *
|
2 |
+
from vlmeval.api.base import BaseAPI
|
3 |
+
from time import sleep
|
4 |
+
import mimetypes
|
5 |
+
|
6 |
+
|
7 |
+
class Reka_Wrapper(BaseAPI):
|
8 |
+
|
9 |
+
is_api: bool = True
|
10 |
+
INTERLEAVE: bool = False
|
11 |
+
|
12 |
+
def __init__(self,
|
13 |
+
model: str = 'reka-flash-20240226',
|
14 |
+
key: str = None,
|
15 |
+
retry: int = 10,
|
16 |
+
wait: int = 3,
|
17 |
+
system_prompt: str = None,
|
18 |
+
verbose: bool = True,
|
19 |
+
temperature: float = 0,
|
20 |
+
max_tokens: int = 1024,
|
21 |
+
**kwargs):
|
22 |
+
|
23 |
+
try:
|
24 |
+
import reka
|
25 |
+
except ImportError:
|
26 |
+
raise ImportError('Please install reka by running "pip install reka-api"')
|
27 |
+
|
28 |
+
self.model = model
|
29 |
+
default_kwargs = dict(temperature=temperature, request_output_len=max_tokens)
|
30 |
+
default_kwargs.update(kwargs)
|
31 |
+
self.kwargs = default_kwargs
|
32 |
+
if key is not None:
|
33 |
+
self.key = key
|
34 |
+
else:
|
35 |
+
self.key = os.environ.get('REKA_API_KEY', '')
|
36 |
+
super().__init__(retry=retry, wait=wait, verbose=verbose, system_prompt=system_prompt, **kwargs)
|
37 |
+
|
38 |
+
def generate_inner(self, inputs, **kwargs) -> str:
|
39 |
+
import reka
|
40 |
+
reka.API_KEY = self.key
|
41 |
+
dataset = kwargs.pop('dataset', None)
|
42 |
+
prompt, image_path = self.message_to_promptimg(inputs, dataset=dataset)
|
43 |
+
image_b64 = encode_image_file_to_base64(image_path)
|
44 |
+
|
45 |
+
response = reka.chat(
|
46 |
+
model_name=self.model,
|
47 |
+
human=prompt,
|
48 |
+
media_url=f'data:image/jpeg;base64,{image_b64}',
|
49 |
+
**self.kwargs)
|
50 |
+
|
51 |
+
try:
|
52 |
+
return 0, response['text'], response
|
53 |
+
except Exception as err:
|
54 |
+
return -1, self.fail_msg + str(err), response
|
55 |
+
|
56 |
+
|
57 |
+
class Reka(Reka_Wrapper):
|
58 |
+
|
59 |
+
def generate(self, message, dataset=None):
|
60 |
+
return super(Reka_Wrapper, self).generate(message)
|
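A minimal sketch for the Reka wrapper (assuming reka-api is installed and REKA_API_KEY is set; a single image only, since INTERLEAVE is False):

from vlmeval.api import Reka

model = Reka(model='reka-flash-20240226', temperature=0, max_tokens=512)
message = [
    dict(type='image', value='/path/to/photo.jpg'),
    dict(type='text', value='Is there a dog in this photo?'),
]
print(model.generate(message))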
VLMEvalKit/vlmeval/api/sensechat_vision.py
ADDED
@@ -0,0 +1,261 @@
|
1 |
+
from vlmeval.smp import *
|
2 |
+
from vlmeval.api.base import BaseAPI
|
3 |
+
from vlmeval.dataset import img_root_map
|
4 |
+
from vlmeval.dataset import DATASET_TYPE
|
5 |
+
|
6 |
+
|
7 |
+
class SenseChatVisionWrapper(BaseAPI):
|
8 |
+
|
9 |
+
is_api: bool = True
|
10 |
+
|
11 |
+
def __init__(self,
|
12 |
+
model: str = 'SenseChat-5-Vision',
|
13 |
+
retry: int = 5,
|
14 |
+
wait: int = 5,
|
15 |
+
ak: str = None,
|
16 |
+
sk: str = None,
|
17 |
+
verbose: bool = True,
|
18 |
+
system_prompt: str = None,
|
19 |
+
max_tokens: int = 1024,
|
20 |
+
proxy: str = None,
|
21 |
+
**kwargs):
|
22 |
+
|
23 |
+
self.model = model
|
24 |
+
self.fail_msg = 'Failed to obtain answer via API. '
|
25 |
+
self.ak = os.environ.get('SENSECHAT_AK', None) if ak is None else ak
|
26 |
+
self.sk = os.environ.get('SENSECHAT_SK', None) if sk is None else sk
|
27 |
+
assert self.ak is not None and self.sk is not None
|
28 |
+
self.max_new_tokens = max_tokens
|
29 |
+
super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)
|
30 |
+
|
31 |
+
def dump_image(self, line, dataset):
|
32 |
+
"""Dump the image(s) of the input line to the corresponding dataset folder.
|
33 |
+
|
34 |
+
Args:
|
35 |
+
line (line of pd.DataFrame): The raw input line.
|
36 |
+
dataset (str): The name of the dataset.
|
37 |
+
|
38 |
+
Returns:
|
39 |
+
str | list[str]: The paths of the dumped images.
|
40 |
+
"""
|
41 |
+
ROOT = LMUDataRoot()
|
42 |
+
assert isinstance(dataset, str)
|
43 |
+
img_root = osp.join(ROOT, 'images', img_root_map(dataset))
|
44 |
+
os.makedirs(img_root, exist_ok=True)
|
45 |
+
if 'image' in line:
|
46 |
+
if isinstance(line['image'], list):
|
47 |
+
tgt_path = []
|
48 |
+
assert 'image_path' in line
|
49 |
+
for img, im_name in zip(line['image'], line['image_path']):
|
50 |
+
path = osp.join(img_root, im_name)
|
51 |
+
if not read_ok(path):
|
52 |
+
decode_base64_to_image_file(img, path)
|
53 |
+
tgt_path.append(path)
|
54 |
+
else:
|
55 |
+
tgt_path = osp.join(img_root, f"{line['index']}.jpg")
|
56 |
+
if not read_ok(tgt_path):
|
57 |
+
decode_base64_to_image_file(line['image'], tgt_path)
|
58 |
+
tgt_path = [tgt_path]
|
59 |
+
else:
|
60 |
+
assert 'image_path' in line
|
61 |
+
tgt_path = toliststr(line['image_path'])
|
62 |
+
|
63 |
+
return tgt_path
|
64 |
+
|
65 |
+
def image_to_base64(self, image_path):
|
66 |
+
import base64
|
67 |
+
with open(image_path, 'rb') as image_file:
|
68 |
+
encoded_string = base64.b64encode(image_file.read())
|
69 |
+
return encoded_string.decode('utf-8')
|
70 |
+
|
71 |
+
def encode_jwt_token(self, ak, sk):
|
72 |
+
import jwt
|
73 |
+
headers = {'alg': 'HS256', 'typ': 'JWT'}
|
74 |
+
payload = {
|
75 |
+
'iss': ak,
|
76 |
+
'exp': int(time.time())
|
77 |
+
+ 1800,  # desired validity period; this example uses current time + 30 minutes
|
78 |
+
'nbf': int(time.time()) - 5,  # desired effective (not-before) time; this example uses current time - 5 seconds
|
79 |
+
}
|
80 |
+
token = jwt.encode(payload, sk, headers=headers)
|
81 |
+
return token
|
82 |
+
|
83 |
+
def use_custom_prompt(self, dataset):
|
84 |
+
return True
|
85 |
+
|
86 |
+
def build_multi_choice_prompt(self, line, dataset=None):
|
87 |
+
question = line['question']
|
88 |
+
hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
|
89 |
+
if hint is not None:
|
90 |
+
question = hint + '\n' + question
|
91 |
+
|
92 |
+
options = {
|
93 |
+
cand: line[cand]
|
94 |
+
for cand in string.ascii_uppercase
|
95 |
+
if cand in line and not pd.isna(line[cand])
|
96 |
+
}
|
97 |
+
for key, item in options.items():
|
98 |
+
question += f'\n{key}. {item}'
|
99 |
+
prompt = question
|
100 |
+
|
101 |
+
if len(options):
|
102 |
+
prompt += '\n请直接回答选项字母。' if cn_string(
|
103 |
+
prompt) else "\nAnswer with the option's letter from the given choices directly."
|
104 |
+
else:
|
105 |
+
prompt += '\n请直接回答问题。' if cn_string(prompt) else '\nAnswer the question directly.'
|
106 |
+
|
107 |
+
return prompt
|
108 |
+
|
109 |
+
def build_prompt(self, line, dataset=None):
|
110 |
+
assert self.use_custom_prompt(dataset)
|
111 |
+
assert dataset is None or isinstance(dataset, str)
|
112 |
+
|
113 |
+
tgt_path = self.dump_image(line, dataset)
|
114 |
+
|
115 |
+
if dataset is not None and listinstr(['MME'], dataset):
|
116 |
+
question = line['question']
|
117 |
+
prompt = question + ' Answer the question using a single word or phrase.'
|
118 |
+
elif dataset is not None and listinstr(['HallusionBench'], dataset):
|
119 |
+
question = line['question']
|
120 |
+
prompt = question + ' Please answer yes or no. Answer the question using a single word or phrase.'
|
121 |
+
elif dataset is not None and DATASET_TYPE(dataset) == 'MCQ' and 'MMMU' not in dataset:
|
122 |
+
prompt = self.build_multi_choice_prompt(line, dataset)
|
123 |
+
elif dataset is not None and DATASET_TYPE(dataset) == 'VQA':
|
124 |
+
if 'MathVista' in dataset:
|
125 |
+
prompt = line['question']
|
126 |
+
elif listinstr(['LLaVABench'], dataset):
|
127 |
+
question = line['question']
|
128 |
+
prompt = question + '\nAnswer this question in detail.'
|
129 |
+
elif listinstr(['MMVet'], dataset):
|
130 |
+
prompt = line['question']
|
131 |
+
else:
|
132 |
+
question = line['question']
|
133 |
+
prompt = question + '\nAnswer the question using a single word or phrase.'
|
134 |
+
elif dataset is not None and 'MMMU' in dataset:
|
135 |
+
question = line['question']
|
136 |
+
options = {
|
137 |
+
cand: line[cand]
|
138 |
+
for cand in string.ascii_uppercase
|
139 |
+
if cand in line and not pd.isna(line[cand])
|
140 |
+
}
|
141 |
+
for key, item in options.items():
|
142 |
+
question += f'\n{key}. {item}'
|
143 |
+
prompt = {
|
144 |
+
'multiple-choice': 'Answer with carefully thought step by step. Apply the thinking process recursively at both macro and micro levels. Verify consistency of reasoning and look for potential flaws or gaps during thinking. When realize mistakes, explain why the previous thinking was incorrect, fix it and then continue thinking.\n\n', # noqa
|
145 |
+
'open': 'Answer with carefully thought step by step. Apply the thinking process recursively at both macro and micro levels. Verify consistency of reasoning and look for potential flaws or gaps during thinking. When realize mistakes, explain why the previous thinking was incorrect, fix it and then continue thinking.\n\n' # noqa
|
146 |
+
}
|
147 |
+
subject = '_'.join(line['id'].split('_')[1:-1])
|
148 |
+
prompt = prompt[line['question_type']].format(subject, subject) + '\n' + question
|
149 |
+
else:
|
150 |
+
prompt = line['question']
|
151 |
+
|
152 |
+
message = [dict(type='text', value=prompt)]
|
153 |
+
message.extend([dict(type='image', value=s) for s in tgt_path])
|
154 |
+
|
155 |
+
return message
|
156 |
+
|
157 |
+
def message_to_promptimg(self, message, dataset=None):
|
158 |
+
if dataset is None or listinstr(['MMMU', 'BLINK'], dataset):
|
159 |
+
prompt = '\n'.join([x['value'] for x in message if x['type'] == 'text'])
|
160 |
+
image = [[x['value'] for x in message if x['type'] == 'image'][0]]
|
161 |
+
else:
|
162 |
+
prompt = '\n'.join([x['value'] for x in message if x['type'] == 'text'])
|
163 |
+
image = [x['value'] for x in message if x['type'] == 'image']
|
164 |
+
return prompt, image
|
165 |
+
|
166 |
+
def generate_inner(self, inputs, **kwargs) -> str:
|
167 |
+
assert isinstance(inputs, str) or isinstance(inputs, list)
|
168 |
+
inputs = [inputs] if isinstance(inputs, str) else inputs
|
169 |
+
dataset = kwargs.get('dataset', None)
|
170 |
+
|
171 |
+
if dataset is not None and listinstr(['ChartQA_TEST','MathVista_MINI'], dataset):
|
172 |
+
self.max_num = 12
|
173 |
+
elif dataset is not None and listinstr(['DocVQA_VAL', 'DocVQA_TEST'], dataset):
|
174 |
+
self.max_num = 18
|
175 |
+
elif dataset is not None and listinstr(['InfoVQA_VAL', 'InfoVQA_TEST', 'OCRBench'], dataset):
|
176 |
+
self.max_num = 24
|
177 |
+
else:
|
178 |
+
self.max_num = 6
|
179 |
+
|
180 |
+
if dataset is None:
|
181 |
+
pass
|
182 |
+
elif listinstr(['AI2D_TEST'], dataset):
|
183 |
+
self.max_new_tokens = 10
|
184 |
+
elif 'MMMU' in dataset:
|
185 |
+
self.max_new_tokens = 4096 # 1024
|
186 |
+
elif 'MMBench' in dataset:
|
187 |
+
self.max_new_tokens = 100
|
188 |
+
elif 'MathVista_MINI' in dataset:
|
189 |
+
self.max_new_tokens = 4096
|
190 |
+
|
191 |
+
prompt, image = self.message_to_promptimg(message=inputs, dataset=dataset)
|
192 |
+
|
193 |
+
url = 'https://api.sensenova.cn/v1/llm/chat-completions'
|
194 |
+
api_secret_key = self.encode_jwt_token(self.ak, self.sk)
|
195 |
+
|
196 |
+
content = [{
|
197 |
+
'image_base64': self.image_to_base64(item),
|
198 |
+
'image_file_id': '',
|
199 |
+
'image_url': '',
|
200 |
+
'text': '',
|
201 |
+
'text': '',
|
202 |
+
'type': 'image_base64'
|
203 |
+
} for item in image]
|
204 |
+
|
205 |
+
content.append({
|
206 |
+
'image_base64': '',
|
207 |
+
'image_file_id': '',
|
208 |
+
'image_url': '',
|
209 |
+
'text': prompt,
|
210 |
+
'type': 'text'
|
211 |
+
})
|
212 |
+
|
213 |
+
message = [{'content': content, 'role': 'user'}]
|
214 |
+
|
215 |
+
data = {
|
216 |
+
'messages': message,
|
217 |
+
'max_new_tokens': self.max_new_tokens, # 1024
|
218 |
+
'temperature': 0,
|
219 |
+
"top_k": 0,
|
220 |
+
"top_p": 0.99,
|
221 |
+
'repetition_penalty': 1.05,
|
222 |
+
'model': self.model,
|
223 |
+
'stream': False,
|
224 |
+
}
|
225 |
+
headers = {
|
226 |
+
'Content-type': 'application/json',
|
227 |
+
'Authorization': 'Bearer ' + api_secret_key
|
228 |
+
}
|
229 |
+
|
230 |
+
response = requests.post(
|
231 |
+
url,
|
232 |
+
headers=headers,
|
233 |
+
json=data,
|
234 |
+
)
|
235 |
+
request_id = response.headers['x-request-id']
|
236 |
+
|
237 |
+
time.sleep(1)
|
238 |
+
try:
|
239 |
+
assert response.status_code == 200
|
240 |
+
response = response.json()['data']['choices'][0]['message'].strip()
|
241 |
+
if self.verbose:
|
242 |
+
self.logger.info(f'inputs: {inputs}\nanswer: {response}')
|
243 |
+
return 0, response, 'Succeeded! '
|
244 |
+
except Exception as err:
|
245 |
+
if self.verbose:
|
246 |
+
self.logger.error('---------------------------ERROR---------------------------')
|
247 |
+
self.logger.error(response.json())
|
248 |
+
self.logger.error(err)
|
249 |
+
self.logger.error('---------------------------request_id---------------------------' + request_id)
|
250 |
+
self.logger.error(
|
251 |
+
'api error' + response.json()['error']['message']
|
252 |
+
+ str([input['value'] if input['type'] == 'image' else None for input in inputs])
|
253 |
+
)
|
254 |
+
self.logger.error(f'The input messages are {inputs}.')
|
255 |
+
return -1, response.json()['error']['message'], ''
|
256 |
+
|
257 |
+
|
258 |
+
class SenseChatVisionAPI(SenseChatVisionWrapper):
|
259 |
+
|
260 |
+
def generate(self, message, dataset=None):
|
261 |
+
return super(SenseChatVisionAPI, self).generate(message, dataset=dataset)
|
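The SenseChat wrapper above authenticates by signing a short-lived JWT from the AK/SK pair; a standalone sketch of that step (assuming PyJWT is installed; the keys are placeholders):

import time
import jwt  # PyJWT

ak, sk = '<SENSECHAT_AK>', '<SENSECHAT_SK>'
token = jwt.encode(
    {'iss': ak, 'exp': int(time.time()) + 1800, 'nbf': int(time.time()) - 5},
    sk,
    headers={'alg': 'HS256', 'typ': 'JWT'},
)
headers = {'Content-type': 'application/json', 'Authorization': 'Bearer ' + token}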
VLMEvalKit/vlmeval/api/siliconflow.py
ADDED
@@ -0,0 +1,269 @@
|
import math
from vlmeval.smp import *
from vlmeval.api.base import BaseAPI
from vlmeval.dataset import img_root_map

API_BASE = "https://api.siliconflow.cn/v1/chat/completions"


def resize_image(image: Image.Image, max_height: int, max_width: int) -> Image.Image:
    width, height = image.size
    if min(width, height) < 50:
        scale = 50 / min(width, height)
        image = image.resize((int(width * scale), int(height * scale)))
    current_pixels = width * height

    if current_pixels <= max_height * max_width:
        return image

    scale = math.sqrt(max_height * max_width / current_pixels)
    new_width = int(width * scale)
    new_height = int(height * scale)

    return image.resize((new_width, new_height), Image.Resampling.LANCZOS)


def encode_image(path: str, max_height: int = 1024, max_width: int = 1024) -> str:
    image = Image.open(path).convert("RGB")
    image = resize_image(image, max_height, max_width)
    height, width = image.size
    if min(height, width) < 50:
        scale = 50 / min(width, height)
        image = image.resize((int(width * scale), int(height * scale)))
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    img_bytes = buffered.getvalue()
    img_base64 = base64.b64encode(img_bytes).decode("utf-8")
    return img_base64


class SiliconFlowAPI(BaseAPI):

    is_api: bool = True

    def __init__(
        self,
        model: str = "deepseek-ai/DeepSeek-V2.5",
        retry: int = 5,
        wait: int = 5,
        key: str = None,
        api_base: str = API_BASE,
        verbose: bool = True,
        system_prompt: str = None,
        timeout: int = 60,
        **kwargs,
    ):

        self.model = model
        self.api_base = api_base

        default_kwargs = {
            "stream": False,
            "temperature": 0,
            "n": 1,
            "max_tokens": 1280,
        }
        for k, v in default_kwargs.items():
            if k not in kwargs:
                kwargs[k] = default_kwargs[k]
        if key is not None:
            self.key = key
        else:
            self.key = os.environ.get("SiliconFlow_API_KEY", "")
        headers = {"Authorization": "Bearer {}", "Content-Type": "application/json"}
        headers["Authorization"] = headers["Authorization"].format(self.key)
        self.headers = headers
        super().__init__(
            wait=wait,
            retry=retry,
            system_prompt=system_prompt,
            verbose=verbose,
            **kwargs,
        )

    @staticmethod
    def build_msgs(msgs_raw):
        messages = []
        message = {"role": "user", "content": []}
        image_b64 = None
        for msg in msgs_raw:
            if msg["type"] == "image" and not image_b64:
                image_b64 = encode_image(msg["value"])
                message["content"].append(
                    {"image_url": {"url": image_b64}, "type": "image_url"}
                )
            elif msg["type"] == "text":
                message["content"].append({"text": msg["value"], "type": "text"})

        messages.append(message)
        return messages

    def generate_inner(self, inputs, **kwargs) -> str:
        default_kwargs = self.default_kwargs
        default_kwargs.update(kwargs)

        payload = dict(
            model=self.model,
            messages=self.build_msgs(msgs_raw=inputs),
            **default_kwargs,
        )

        response = requests.post(
            self.api_base, headers=self.headers, data=json.dumps(payload)
        )
        ret_code = response.status_code
        ret_code = 0 if (200 <= int(ret_code) < 300) else ret_code

        answer = self.fail_msg
        try:
            resp_struct = json.loads(response.text)
            answer = resp_struct["choices"][0]["message"]["content"].strip()
        except:
            pass
        return ret_code, answer, response


class TeleMMAPI(SiliconFlowAPI):

    is_api: bool = True

    def __init__(
        self,
        model: str = "TeleAI/TeleMM",
        key: str = None,
        max_height: int = 1280,
        max_width: int = 784,
        **kwargs,
    ):
        super().__init__(model=model, key=key, **kwargs)
        self.max_height = max_height
        self.max_width = max_width

    def dump_image(self, line, dataset):
        """Dump the image(s) of the input line to the corresponding dataset folder.

        Args:
            line (line of pd.DataFrame): The raw input line.
            dataset (str): The name of the dataset.

        Returns:
            str | list[str]: The paths of the dumped images.
        """
        ROOT = LMUDataRoot()
        assert isinstance(dataset, str)
        # img_root = osp.join(ROOT, 'images', img_root_map[dataset] if dataset in img_root_map else dataset)
        img_root = osp.join(ROOT, "images", img_root_map(dataset))
        os.makedirs(img_root, exist_ok=True)
        if "image" in line:
            if isinstance(line["image"], list):
                tgt_path = []
                assert "image_path" in line
                for img, im_name in zip(line["image"], line["image_path"]):
                    path = osp.join(img_root, im_name)
                    if not read_ok(path):
                        decode_base64_to_image_file(img, path)
                    tgt_path.append(path)
            else:
                tgt_path = osp.join(img_root, f"{line['index']}.jpg")
                if not read_ok(tgt_path):
                    decode_base64_to_image_file(line["image"], tgt_path)
                tgt_path = [tgt_path]
        else:
            assert "image_path" in line
            tgt_path = toliststr(line["image_path"])
        return tgt_path

    def _prepare_content(
        self, inputs: list[dict[str, str]], dataset: str = None
    ) -> list[dict[str, str]]:
        """
        inputs list[dict[str, str]], each dict has keys: ['type', 'value']
        """
        content = []
        has_image = False
        for s in inputs:
            if s["type"] == "image":
                if not has_image:
                    item = {
                        "type": "image_url",
                        "image_url": {
                            "url": encode_image(
                                s["value"],
                                max_height=self.max_height,
                                max_width=self.max_width,
                            )
                        },
                    }
                    has_image = True
                else:
                    continue
            elif s["type"] == "text":
                prompt = s["value"]
                if len(prompt) == 0:
                    continue
                if dataset == "HallusionBench":
                    prompt += " Please answer yes or no directly, without any unnecessary explanation."
                elif dataset == "OCRBench":
                    prompt = (
                        prompt + "\nExtract the text from the image intactly and "
                        + "answer the question concisely and clearly if possible."
                    )

                elif (
                    dataset == "AI2D_TEST"
                    or dataset == "MMStar"
                    or dataset == "MMBench_TEST_EN_V11"
                    or dataset == "MMVet"
                ):
                    prompt = prompt.replace(
                        "Please select the correct answer from the options above. \n",
                        "Please select the correct option from the above choices based on the "
                        + "input image and question. The final output should only be one option, such as 'A'",
                    )
                elif dataset == "MMBench_TEST_CN_V11":
                    prompt = prompt.replace(
                        "Please select the correct answer from the options above. \n",
                        "请根据输入图像和问题从上述选项中选择正确选项,最终的输出只有一个选项,例如'A'",
                    )
                item = {"type": "text", "text": prompt}
            else:
                raise ValueError(f"Invalid message type: {s['type']}, {s}")
            content.append(item)

        return content

    def generate_inner(self, inputs, **kwargs) -> str:
        default_kwargs = self.default_kwargs
        default_kwargs.update(kwargs)

        messages = []
        messages.append(
            {
                "role": "user",
                "content": self._prepare_content(
                    inputs, dataset=kwargs.get("dataset", None)
                ),
            }
        )

        payload = dict(model=self.model, messages=messages, **default_kwargs)

        response = requests.post(
            self.api_base, headers=self.headers, data=json.dumps(payload)
        )
        ret_code = response.status_code
        ret_code = 0 if (200 <= int(ret_code) < 300) else ret_code

        answer = self.fail_msg
        try:
            resp_struct = json.loads(response.text)
            answer = resp_struct["choices"][0]["message"]["content"].strip()
            return ret_code, answer, response
        except Exception as err:
            import traceback

            traceback.print_exc()
            if self.verbose:
                self.logger.error(f"{type(err)}: {err}")
                self.logger.error(f"The input messages are {inputs}.")
            return -1, "", ""
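A minimal usage sketch for SiliconFlowAPI / TeleMMAPI (not part of the diff; the prompt text and image path are placeholders, and a key is assumed to be available either as an argument or via the SiliconFlow_API_KEY environment variable read above):

    from vlmeval.api import SiliconFlowAPI, TeleMMAPI

    api = SiliconFlowAPI(model='deepseek-ai/DeepSeek-V2.5')   # text-only chat model
    vlm = TeleMMAPI()                                         # multimodal; resizes images to max_height x max_width
    msgs = [dict(type='image', value='chart.png'), dict(type='text', value='Summarize this chart.')]
    print(vlm.generate(msgs))                                 # BaseAPI.generate wraps generate_inner with retries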
VLMEvalKit/vlmeval/api/stepai.py
ADDED
@@ -0,0 +1,87 @@
from vlmeval.smp import *
from vlmeval.api.base import BaseAPI

url = 'https://api.stepfun.com/v1/chat/completions'
headers = {
    'Content-Type': 'application/json',
    'Authorization': 'Bearer {}',
}


class StepAPI_INT(BaseAPI):

    is_api: bool = True

    def __init__(self,
                 model: str = 'step-1v-8k',
                 retry: int = 10,
                 wait: int = 3,
                 key: str = None,
                 temperature: float = 0,
                 max_tokens: int = 300,
                 verbose: bool = True,
                 system_prompt: str = None,
                 **kwargs):
        self.model = model
        self.fail_msg = 'Fail to obtain answer via API.'
        self.headers = headers
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.system_prompt = system_prompt
        if key is not None:
            self.key = key
        else:
            self.key = os.environ.get('STEPAI_API_KEY', '')
        headers['Authorization'] = headers['Authorization'].format(self.key)

        super().__init__(retry=retry, wait=wait, verbose=verbose, system_prompt=system_prompt, **kwargs)

    @staticmethod
    def build_msgs(msgs_raw):
        messages = []
        message = {'role': 'user', 'content': []}

        for msg in msgs_raw:
            if msg['type'] == 'image':
                image_b64 = encode_image_file_to_base64(msg['value'])
                message['content'].append({
                    'image_url': {'url': 'data:image/webp;base64,%s' % (image_b64)},
                    'type': 'image_url'
                })
            elif msg['type'] == 'text':
                message['content'].append({
                    'text': msg['value'],
                    'type': 'text'
                })

        messages.append(message)
        return messages

    def generate_inner(self, inputs, **kwargs) -> str:
        print(inputs, '\n')
        payload = dict(
            model=self.model,
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            messages=self.build_msgs(msgs_raw=inputs),
            **kwargs)
        response = requests.post(url, headers=headers, data=json.dumps(payload))
        ret_code = response.status_code
        ret_code = 0 if (200 <= int(ret_code) < 300) else ret_code

        answer = self.fail_msg
        try:
            resp_struct = json.loads(response.text)
            answer = resp_struct['choices'][0]['message']['content'].strip()
        except Exception as err:
            if self.verbose:
                self.logger.error(f'{type(err)}: {err}')
                self.logger.error(response.text if hasattr(response, 'text') else response)

        return ret_code, answer, response


class Step1V_INT(StepAPI_INT):

    def generate(self, message, dataset=None):
        return super(StepAPI_INT, self).generate(message)
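A usage sketch for the StepFun wrapper (not part of the diff; the image path is a placeholder and STEPAI_API_KEY is assumed to be exported, matching the key lookup in __init__ above):

    from vlmeval.api.stepai import Step1V_INT

    model = Step1V_INT(model='step-1v-8k', max_tokens=300)
    msgs = [dict(type='image', value='page.png'), dict(type='text', value='What does this page say?')]
    print(model.generate(msgs))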
VLMEvalKit/vlmeval/api/taiyi.py
ADDED
@@ -0,0 +1,192 @@
from vlmeval.smp import *
from vlmeval.api.base import BaseAPI
from vlmeval.dataset import DATASET_TYPE, img_root_map


class TaiyiWrapper(BaseAPI):

    is_api: bool = True

    def __init__(self,
                 model: str = 'taiyi',
                 retry: int = 5,
                 wait: int = 5,
                 key: str = None,
                 verbose: bool = False,
                 system_prompt: str = None,
                 temperature: float = 0,
                 timeout: int = 60,
                 url: str = "https://taiyi.megvii.com/v1/chat/completions",
                 max_tokens: int = 1024,
                 **kwargs):

        self.model = model
        self.fail_msg = 'Failed to obtain answer via API. '
        self.max_tokens = max_tokens
        self.temperature = temperature

        if key is None:
            key = os.environ.get('TAIYI_API_KEY', None)
        assert key is not None, ('Please set the API Key ')
        self.key = key

        self.timeout = timeout
        super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)
        assert url is not None, ('Please set the url ')
        self.url = url
        self.logger.info(f'Using url: {self.url}; API Key: {self.key}')

    def use_custom_prompt(self, dataset):
        if DATASET_TYPE(dataset) == 'Y/N' or DATASET_TYPE(dataset) == 'MCQ' or DATASET_TYPE(dataset) == 'VQA':
            return True
        return False

    def prepare_inputs(self, inputs):
        input_msgs = []
        if self.system_prompt is not None:
            input_msgs.append(dict(role='system', content=self.system_prompt))
        has_images = np.sum([x['type'] == 'image' for x in inputs])
        if has_images:
            content_list = []
            for msg in inputs:
                if msg['type'] == 'text':
                    content_list.append(dict(type='text', text=msg['value']))
                elif msg['type'] == 'image':
                    imgbytes = open(msg['value'], 'rb').read()
                    b64 = base64.b64encode(imgbytes).decode('ascii')
                    img_struct = dict(url=f'data:image/jpeg;base64,{b64}')
                    content_list.append(dict(type='image_url', image_url=img_struct))
            input_msgs.append(dict(role='user', content=content_list))
        else:
            assert all([x['type'] == 'text' for x in inputs])
            text = '\n'.join([x['value'] for x in inputs])
            input_msgs.append(dict(role='user', content=text))
        return input_msgs

    def set_dump_image(self, dump_image_func):
        self.dump_image_func = dump_image_func

    def dump_image(self, line, dataset):
        return self.dump_image_func(line)

    def image_first(self, msgs):
        nr_img = 0
        for s in msgs:
            if s['type'] == 'image':
                nr_img += 1

        if nr_img == 1:
            new_msgs = []
            img_msg = None
            for s in msgs:
                if s['type'] == 'text':
                    new_msgs.append(s)
                else:
                    img_msg = s
            new_msgs.insert(0, img_msg)
        else:
            new_msgs = msgs

        return new_msgs

    def build_multi_choice_prompt(self, line, dataset=None):
        question = line['question']
        hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
        if hint is not None:
            question = hint + '\n' + question

        options = {
            cand: line[cand]
            for cand in string.ascii_uppercase
            if cand in line and not pd.isna(line[cand])
        }
        for key, item in options.items():
            question += f'\n{key}. {item}'
        prompt = question

        if len(options):
            prompt += '\n请直接回答选项字母。' if cn_string(
                prompt) else "\nAnswer with the option's letter from the given choices directly."
        else:
            prompt += '\n请直接回答问题。' if cn_string(prompt) else '\nAnswer the question directly.'

        return prompt

    def build_yorn_prompt(self, line, dataset=None):
        if listinstr(['HallusionBench'], dataset):
            pre_prompt = 'Read the following question carefully, think and solve it step by step.\n\n'
        else:
            pre_prompt = ''

        prompt = pre_prompt + line['question'] + ' Please answer yes or no as the final answer.'

        return prompt

    def build_vqa_prompt(self, line, dataset=None):
        if listinstr(['OCRBench'], dataset):
            pre_prompt = 'Carefully identify the text in the image and answer the question.\n\n'
        else:
            pre_prompt = ''

        if listinstr(['MMVet'], dataset):
            post_prompt = '\nAnswer this question in detail.'
        else:
            post_prompt = ''

        prompt = pre_prompt + line['question'] + post_prompt

        return prompt

    def build_prompt(self, line, dataset=None):
        assert self.use_custom_prompt(dataset)
        assert dataset is None or isinstance(dataset, str)
        tgt_path = self.dump_image(line, dataset)

        if DATASET_TYPE(dataset) == 'MCQ':
            prompt = self.build_multi_choice_prompt(line, dataset)
        elif DATASET_TYPE(dataset) == 'Y/N':
            prompt = self.build_yorn_prompt(line, dataset)
        elif DATASET_TYPE(dataset) == 'VQA':
            prompt = self.build_vqa_prompt(line, dataset)
        else:
            raise RuntimeError(f'Invalid dataset type: {DATASET_TYPE(dataset)}')
        message = []
        message.extend([dict(type='image', value=s) for s in tgt_path])
        message.extend([dict(type='text', value=prompt)])

        # interleave dataset
        if dataset.startswith('MMMU_'):
            from .. import MMMUDataset
            message = MMMUDataset.split_MMMU(message)
            message = self.image_first(message)

        return message

    def generate_inner(self, inputs, **kwargs) -> str:

        input_msgs = self.prepare_inputs(inputs)
        temperature = kwargs.pop('temperature', self.temperature)

        headers = {'Authorization': f'Bearer {self.key}'}
        payload = dict(
            model=self.model,
            messages=input_msgs,
            n=1,
            temperature=temperature,
            **kwargs)
        response = requests.post(self.url, headers=headers, data=json.dumps(payload), timeout=self.timeout * 1.1)
        ret_code = response.status_code
        ret_code = 0 if (200 <= int(ret_code) < 300) else ret_code
        answer = self.fail_msg
        try:
            resp_struct = json.loads(response.text)
            answer = resp_struct['choices'][0]['message']['content'].strip()
        except:
            pass
        return ret_code, answer, response


class TaiyiAPI(TaiyiWrapper):

    def generate(self, message, dataset=None):
        return super(TaiyiAPI, self).generate(message)
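A usage sketch for the Taiyi wrapper (not part of the diff; the image path is a placeholder, and TAIYI_API_KEY is assumed to be set since __init__ asserts a key is present):

    from vlmeval.api import TaiyiAPI

    model = TaiyiAPI(verbose=True)
    # build_prompt / use_custom_prompt are only used by the evaluation pipeline; for a one-off call, pass messages directly:
    msgs = [dict(type='image', value='exam.jpg'), dict(type='text', value='Answer the question directly.')]
    print(model.generate(msgs))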
VLMEvalKit/vlmeval/dataset/__init__.py
ADDED
@@ -0,0 +1,230 @@
import warnings

from .image_base import img_root_map, ImageBaseDataset
from .image_caption import ImageCaptionDataset
from .image_yorn import ImageYORNDataset
from .image_mcq import (
    ImageMCQDataset, MMMUDataset, CustomMCQDataset, MUIRDataset, GMAIMMBenchDataset, MMERealWorld, HRBenchDataset,
    NaturalBenchDataset
)
from .image_mt import MMDUDataset
from .image_vqa import (
    ImageVQADataset, MathVision, OCRBench, MathVista, LLaVABench, MMVet, MTVQADataset, TableVQABench,
    CustomVQADataset, CRPE, MathVerse, OlympiadBench, QSpatial, VizWiz, MMNIAH
)

from .text_mcq import CustomTextMCQDataset, TextMCQDataset

from .vcr import VCRDataset
from .mmlongbench import MMLongBench
from .dude import DUDE
from .slidevqa import SlideVQA

from .mmbench_video import MMBenchVideo
from .videomme import VideoMME
from .mvbench import MVBench, MVBench_MP4
from .mlvu import MLVU, MLVU_MCQ, MLVU_OpenEnded
from .tempcompass import TempCompass, TempCompass_Captioning, TempCompass_MCQ, TempCompass_YorN
from .longvideobench import LongVideoBench
from .video_concat_dataset import ConcatVideoDataset
from .mmgenbench import MMGenBench

from .miabench import MIABench
from .cmmmu import CMMMU
from .wildvision import WildVision
from .mmmath import MMMath
from .dynamath import Dynamath
from .utils import *
from ..smp import *


class ConcatDataset(ImageBaseDataset):
    # This dataset takes multiple dataset names as input and aggregates them into a single dataset.
    # Each single dataset should not have a field named `SUB_DATASET`

    DATASET_SETS = {
        'MMMB': ['MMMB_ar', 'MMMB_cn', 'MMMB_en', 'MMMB_pt', 'MMMB_ru', 'MMMB_tr'],
        'MTL_MMBench_DEV': [
            'MMBench_dev_ar', 'MMBench_dev_cn', 'MMBench_dev_en',
            'MMBench_dev_pt', 'MMBench_dev_ru', 'MMBench_dev_tr'
        ]
    }

    def __init__(self, dataset):
        datasets = self.DATASET_SETS[dataset]
        self.dataset_map = {}
        # The name of the compilation
        self.dataset_name = dataset
        self.datasets = datasets
        for dname in datasets:
            dataset = build_dataset(dname)
            assert dataset is not None, dataset
            self.dataset_map[dname] = dataset
        TYPES = [x.TYPE for x in self.dataset_map.values()]
        MODALITIES = [x.MODALITY for x in self.dataset_map.values()]
        assert np.all([x == TYPES[0] for x in TYPES]), (datasets, TYPES)
        assert np.all([x == MODALITIES[0] for x in MODALITIES]), (datasets, MODALITIES)
        self.TYPE = TYPES[0]
        self.MODALITY = MODALITIES[0]
        data_all = []
        for dname in datasets:
            data = self.dataset_map[dname].data
            data['SUB_DATASET'] = [dname] * len(data)
            data_new = localize_df(data, dname, nproc=16)
            data_all.append(data_new)

        data = pd.concat(data_all)
        data['original_index'] = data.pop('index')
        data['index'] = np.arange(len(data))
        self.data = data

    def build_prompt(self, line):
        if isinstance(line, int):
            line = self.data.iloc[line]
        idx = line['original_index']
        dname = line['SUB_DATASET']
        org_data = self.dataset_map[dname].data
        org_line = cp.deepcopy(org_data[org_data['index'] == idx]).iloc[0]
        return self.dataset_map[dname].build_prompt(org_line)

    def dump_image(self, line):
        # Assert all images are pre-dumped
        assert 'image' not in line
        assert 'image_path' in line
        tgt_path = toliststr(line['image_path'])
        return tgt_path

    @classmethod
    def supported_datasets(cls):
        return list(cls.DATASET_SETS)

    def evaluate(self, eval_file, **judge_kwargs):
        suffix = eval_file.split('.')[-1]
        # First, split the eval_file by dataset
        data_all = load(eval_file)
        for dname in self.datasets:
            tgt = eval_file.replace(self.dataset_name, dname)
            data_sub = data_all[data_all['SUB_DATASET'] == dname]
            data_sub.pop('index')
            data_sub['index'] = data_sub.pop('original_index')
            data_sub.pop('SUB_DATASET')
            dump(data_sub, tgt)
        # Then, evaluate each dataset separately
        results_all = []
        for dname in self.datasets:
            tgt = eval_file.replace(self.dataset_name, dname)
            res = self.dataset_map[dname].evaluate(tgt, **judge_kwargs)
            assert isinstance(res, pd.DataFrame)
            res['DATASET'] = [dname] * len(res)
            results_all.append(res)
        result = pd.concat(results_all)
        score_file = eval_file.replace(f'.{suffix}', '_acc.csv')
        dump(result, score_file)
        return result


# Add new supported dataset class here
IMAGE_DATASET = [
    ImageCaptionDataset, ImageYORNDataset, ImageMCQDataset, ImageVQADataset, MathVision,
    MMMUDataset, OCRBench, MathVista, LLaVABench, MMVet, MTVQADataset, TableVQABench,
    MMLongBench, VCRDataset, MMDUDataset, DUDE, SlideVQA, MUIRDataset,
    GMAIMMBenchDataset, MMERealWorld, HRBenchDataset, CRPE, MathVerse, NaturalBenchDataset,
    MIABench, OlympiadBench, WildVision, MMMath, QSpatial, Dynamath, MMGenBench, VizWiz, MMNIAH,
    CMMMU
]

VIDEO_DATASET = [
    MMBenchVideo, VideoMME, MVBench, MVBench_MP4, LongVideoBench,
    MLVU, MLVU_MCQ, MLVU_OpenEnded,
    TempCompass, TempCompass_MCQ, TempCompass_Captioning, TempCompass_YorN
]

TEXT_DATASET = [
    TextMCQDataset
]

CUSTOM_DATASET = [
    CustomMCQDataset, CustomVQADataset, CustomTextMCQDataset
]

DATASET_COLLECTION = [ConcatDataset, ConcatVideoDataset]

DATASET_CLASSES = IMAGE_DATASET + VIDEO_DATASET + TEXT_DATASET + CUSTOM_DATASET + DATASET_COLLECTION
SUPPORTED_DATASETS = []
for DATASET_CLS in DATASET_CLASSES:
    SUPPORTED_DATASETS.extend(DATASET_CLS.supported_datasets())


def DATASET_TYPE(dataset, *, default: str = 'MCQ') -> str:
    for cls in DATASET_CLASSES:
        if dataset in cls.supported_datasets():
            if hasattr(cls, 'TYPE'):
                return cls.TYPE
    # Have to add specific routine to handle ConcatDataset
    if dataset in ConcatDataset.DATASET_SETS:
        dataset_list = ConcatDataset.DATASET_SETS[dataset]
        TYPES = [DATASET_TYPE(dname) for dname in dataset_list]
        assert np.all([x == TYPES[0] for x in TYPES]), (dataset_list, TYPES)
        return TYPES[0]

    if 'openended' in dataset.lower():
        return 'VQA'
    warnings.warn(f'Dataset {dataset} is a custom one and not annotated as `openended`, will treat as {default}. ')
    return default


def DATASET_MODALITY(dataset, *, default: str = 'IMAGE') -> str:
    if dataset is None:
        warnings.warn(f'Dataset is not specified, will treat modality as {default}. ')
        return default
    for cls in DATASET_CLASSES:
        if dataset in cls.supported_datasets():
            if hasattr(cls, 'MODALITY'):
                return cls.MODALITY
    # Have to add specific routine to handle ConcatDataset
    if dataset in ConcatDataset.DATASET_SETS:
        dataset_list = ConcatDataset.DATASET_SETS[dataset]
        MODALITIES = [DATASET_MODALITY(dname) for dname in dataset_list]
        assert np.all([x == MODALITIES[0] for x in MODALITIES]), (dataset_list, MODALITIES)
        return MODALITIES[0]

    if 'VIDEO' in dataset.lower():
        return 'VIDEO'
    elif 'IMAGE' in dataset.lower():
        return 'IMAGE'
    warnings.warn(f'Dataset {dataset} is a custom one, will treat modality as {default}. ')
    return default


def build_dataset(dataset_name, **kwargs):
    for cls in DATASET_CLASSES:
        if dataset_name in cls.supported_datasets():
            return cls(dataset=dataset_name, **kwargs)

    warnings.warn(f'Dataset {dataset_name} is not officially supported. ')

    data_file = osp.join(LMUDataRoot(), f'{dataset_name}.tsv')
    if not osp.exists(data_file):
        warnings.warn(f'Data file {data_file} does not exist. Dataset building failed. ')
        return None

    data = load(data_file)
    if 'question' not in [x.lower() for x in data.columns]:
        warnings.warn(f'Data file {data_file} does not have a `question` column. Dataset building failed. ')
        return None

    if 'A' in data and 'B' in data:
        if 'image' in data or 'image_path' in data:
            warnings.warn(f'Will assume unsupported dataset {dataset_name} as a Custom MCQ dataset. ')
            return CustomMCQDataset(dataset=dataset_name, **kwargs)
        else:
            warnings.warn(f'Will assume unsupported dataset {dataset_name} as a Custom Text MCQ dataset. ')
            return CustomTextMCQDataset(dataset=dataset_name, **kwargs)
    else:
        warnings.warn(f'Will assume unsupported dataset {dataset_name} as a Custom VQA dataset. ')
        return CustomVQADataset(dataset=dataset_name, **kwargs)


__all__ = [
    'build_dataset', 'img_root_map', 'build_judge', 'extract_answer_from_item', 'prefetch_answer', 'DEBUG_MESSAGE'
] + [cls.__name__ for cls in DATASET_CLASSES]
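A short sketch of how the registry above is typically consumed (not part of the diff; the dataset names are examples only and must exist among the supported_datasets() of the imported classes, otherwise build_dataset falls back to the Custom* classes or returns None):

    from vlmeval.dataset import build_dataset, DATASET_TYPE, DATASET_MODALITY

    dataset = build_dataset('MMBench_DEV_EN')        # example name; resolved via DATASET_CLASSES
    print(DATASET_TYPE('MMBench_DEV_EN'))            # e.g. 'MCQ'
    print(DATASET_MODALITY('Video-MME'))             # e.g. 'VIDEO'
    if dataset is not None:
        line = dataset.data.iloc[0]
        msgs = dataset.build_prompt(line)            # list of dict(type=..., value=...) messages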
VLMEvalKit/vlmeval/dataset/cmmmu.py
ADDED
@@ -0,0 +1,354 @@
1 |
+
from .image_base import ImageBaseDataset
|
2 |
+
import random
|
3 |
+
from collections import Counter
|
4 |
+
import os
|
5 |
+
import re
|
6 |
+
import tempfile
|
7 |
+
from ..smp import *
|
8 |
+
|
9 |
+
|
10 |
+
def get_multi_choice_prediction(response, all_choices, index2ans):
|
11 |
+
for char in [',', '.', '!', '?', ';', ':', "'"]:
|
12 |
+
response = response.strip(char)
|
13 |
+
response = " " + response + " " # add space to avoid partial match
|
14 |
+
|
15 |
+
candidates = []
|
16 |
+
|
17 |
+
for choice in all_choices: # (A) (B) (C) (D)
|
18 |
+
# Add the choice to candidates each time it appears in the response
|
19 |
+
candidates.extend([choice for _ in range(response.count(f'({choice})'))])
|
20 |
+
|
21 |
+
if len(candidates) == 0:
|
22 |
+
for choice in all_choices: # A B C D
|
23 |
+
# Similarly, add the choice for each occurrence
|
24 |
+
candidates.extend([choice for _ in range(response.count(f'{choice}'))])
|
25 |
+
|
26 |
+
if len(candidates) == 0 and len(response.split()) >= 1:
|
27 |
+
for index, ans in index2ans.items():
|
28 |
+
# Add index for each occurrence of ans in response
|
29 |
+
candidates.extend([index for _ in range(response.count(ans))])
|
30 |
+
|
31 |
+
# if all above doesn't get candidates, check if the content is larger than 5 tokens and try to parse the example
|
32 |
+
if len(candidates) == 0 and len(response.split()) >= 1:
|
33 |
+
for index, ans in index2ans.items():
|
34 |
+
if ans in response:
|
35 |
+
candidates.append(index)
|
36 |
+
# index_ans = False # it's content ans.
|
37 |
+
|
38 |
+
if len(candidates) == 0: # still not get answer, randomly choose one.
|
39 |
+
return random.choice(all_choices)
|
40 |
+
# return ''
|
41 |
+
else:
|
42 |
+
# Count the occurrence of each candidate
|
43 |
+
candidate_counts = Counter(candidates)
|
44 |
+
|
45 |
+
# Select the most frequent candidates
|
46 |
+
max_count = max(candidate_counts.values())
|
47 |
+
most_frequent_candidates = [c for c in all_choices if candidate_counts.get(c, 0) == max_count]
|
48 |
+
|
49 |
+
# Combine the most frequent candidates in ABCD order
|
50 |
+
return ''.join(most_frequent_candidates)
|
51 |
+
|
52 |
+
|
53 |
+
def extract_numbers(string):
|
54 |
+
# Pattern for numbers with Chinese commas
|
55 |
+
pattern_commas = r'-?\d{1,3}(?:,\d{3})+'
|
56 |
+
# Pattern for scientific notation
|
57 |
+
pattern_scientific = r'-?\d+(?:\.\d+)?[eE][+-]?\d+'
|
58 |
+
# Pattern for simple numbers without Chinese commas
|
59 |
+
pattern_simple = r'-?(?:\d+\.\d+|\.\d+|\d+)(?![eE][+-]?\d+)(?!,\d)'
|
60 |
+
|
61 |
+
# Extract numbers with Chinese commas
|
62 |
+
numbers_with_commas = re.findall(pattern_commas, string)
|
63 |
+
# Extract numbers in scientific notation
|
64 |
+
numbers_scientific = re.findall(pattern_scientific, string)
|
65 |
+
# Extract simple numbers without Chinese commas
|
66 |
+
numbers_simple = re.findall(pattern_simple, string)
|
67 |
+
|
68 |
+
# Combine all extracted numbers
|
69 |
+
all_numbers = numbers_with_commas + numbers_scientific + numbers_simple
|
70 |
+
return all_numbers
|
71 |
+
|
72 |
+
|
73 |
+
def check_is_number(string):
|
74 |
+
try:
|
75 |
+
float(string.replace(',', ''))
|
76 |
+
return True
|
77 |
+
except ValueError:
|
78 |
+
# check if there's comma inside
|
79 |
+
return False
|
80 |
+
|
81 |
+
|
82 |
+
def count_letters(string):
|
83 |
+
return sum(c.isalpha() and 'a' <= c <= 'z' or 'A' <= c <= 'Z' for c in string)
|
84 |
+
|
85 |
+
|
86 |
+
def normalize_str(string, answer):
|
87 |
+
# check if characters in the string
|
88 |
+
|
89 |
+
# if number, numerize it.
|
90 |
+
if string is None:
|
91 |
+
return [string]
|
92 |
+
string = string.strip()
|
93 |
+
|
94 |
+
is_number = check_is_number(string)
|
95 |
+
|
96 |
+
if is_number:
|
97 |
+
string = string.replace(',', '')
|
98 |
+
string = float(string)
|
99 |
+
# leave 2 decimal
|
100 |
+
string = round(string, 2)
|
101 |
+
return [string]
|
102 |
+
else: # it's likely to be a string
|
103 |
+
if len(string) > len(answer) + 20 or count_letters(string) > count_letters(answer) + 2:
|
104 |
+
return []
|
105 |
+
return [string]
|
106 |
+
|
107 |
+
|
108 |
+
def get_fill_blank_prediction(response, answer):
|
109 |
+
"""get the prediction from the generated response,
|
110 |
+
return a list of predicted strings or numbers"""
|
111 |
+
|
112 |
+
def get_key_subresponses(response):
|
113 |
+
response = response.strip("。").strip()
|
114 |
+
sub_responses = re.split(r'。|\n', response)
|
115 |
+
indicators_of_keys = ['是', '为', '所以', '等于', '方案', '选择',
|
116 |
+
'正确答案', '因此', '最后', '答案', '结果']
|
117 |
+
key_responses = []
|
118 |
+
for index, resp in enumerate(sub_responses):
|
119 |
+
# if last one, accept it's an equation (the entire response can be just one sentence with equation)
|
120 |
+
if index == len(sub_responses) - 1:
|
121 |
+
indicators_of_keys.extend(['='])
|
122 |
+
shortest_key_response = None
|
123 |
+
# the shortest response that may contain the answer (tail part of the response)
|
124 |
+
for indicator in indicators_of_keys:
|
125 |
+
if indicator in resp:
|
126 |
+
if not shortest_key_response:
|
127 |
+
shortest_key_response = resp.split(indicator)[-1].strip()
|
128 |
+
else:
|
129 |
+
if len(resp.split(indicator)[-1].strip()) < len(shortest_key_response):
|
130 |
+
shortest_key_response = resp.split(indicator)[-1].strip()
|
131 |
+
|
132 |
+
if shortest_key_response:
|
133 |
+
# and it's not trivial
|
134 |
+
if shortest_key_response.strip() not in [":", ",", ".", "!", "?", ";", ":", "'"]:
|
135 |
+
key_responses.append(shortest_key_response)
|
136 |
+
if len(key_responses) == 0: # did not found any
|
137 |
+
return [response]
|
138 |
+
return key_responses
|
139 |
+
|
140 |
+
key_responses = get_key_subresponses(response)
|
141 |
+
|
142 |
+
pred_list = key_responses.copy() # keep the original string response
|
143 |
+
for resp in key_responses:
|
144 |
+
pred_list.extend(extract_numbers(resp))
|
145 |
+
|
146 |
+
tmp_pred_list = []
|
147 |
+
for i in range(len(pred_list)):
|
148 |
+
tmp_pred_list.extend(normalize_str(pred_list[i], answer))
|
149 |
+
pred_list = tmp_pred_list
|
150 |
+
|
151 |
+
# remove duplicates
|
152 |
+
pred_list = list(set(pred_list))
|
153 |
+
|
154 |
+
return pred_list
|
155 |
+
|
156 |
+
|
157 |
+
def get_TF_prediction(response):
|
158 |
+
"""get the prediction from the generated response,
|
159 |
+
return a list of predicted strings or numbers"""
|
160 |
+
|
161 |
+
def get_key_subresponses(response):
|
162 |
+
response = response.strip("。").strip()
|
163 |
+
sub_responses = re.split(r'。|\n', response)
|
164 |
+
indicators_of_keys = ['是', '为', '所以', '判断',
|
165 |
+
'陈述', '说法', '表达', '答案', '结果']
|
166 |
+
key_responses = []
|
167 |
+
for index, resp in enumerate(sub_responses):
|
168 |
+
shortest_key_response = None
|
169 |
+
# the shortest response that may contain the answer (tail part of the response)
|
170 |
+
for indicator in indicators_of_keys:
|
171 |
+
if indicator in resp:
|
172 |
+
if not shortest_key_response:
|
173 |
+
shortest_key_response = resp.split(indicator)[-1].strip()
|
174 |
+
else:
|
175 |
+
if len(resp.split(indicator)[-1].strip()) < len(shortest_key_response):
|
176 |
+
shortest_key_response = resp.split(indicator)[-1].strip()
|
177 |
+
|
178 |
+
if shortest_key_response:
|
179 |
+
# and it's not trivial
|
180 |
+
if shortest_key_response.strip() not in [":", ",", ".", "!", "?", ";", ":", "'"]:
|
181 |
+
key_responses.append(shortest_key_response)
|
182 |
+
if len(key_responses) == 0: # did not found any
|
183 |
+
return [response]
|
184 |
+
return key_responses
|
185 |
+
|
186 |
+
key_responses = get_key_subresponses(response)
|
187 |
+
|
188 |
+
pred_list = key_responses.copy() # keep the original string response
|
189 |
+
# remove duplicates
|
190 |
+
pred_list = list(set(pred_list))
|
191 |
+
|
192 |
+
return pred_list
|
193 |
+
|
194 |
+
|
195 |
+
class CMMMU(ImageBaseDataset):
|
196 |
+
TYPE = 'VQA'
|
197 |
+
|
198 |
+
DATASET_URL = {
|
199 |
+
'CMMMU_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/CMMMU_VAL.tsv'
|
200 |
+
}
|
201 |
+
|
202 |
+
DATASET_MD5 = {
|
203 |
+
'CMMMU_VAL': 'b4727e2fce2415bf646379e60c11a726'
|
204 |
+
}
|
205 |
+
|
206 |
+
def dump_image(self, line):
|
207 |
+
os.makedirs(self.img_root, exist_ok=True)
|
208 |
+
|
209 |
+
tgt_path_z = []
|
210 |
+
if isinstance(line['image'], list):
|
211 |
+
for i in range(len(line['image'])):
|
212 |
+
tgt_path = osp.join(self.img_root, f"{line['index']}--{i + 1}.jpg")
|
213 |
+
if not read_ok(tgt_path):
|
214 |
+
decode_base64_to_image_file(line['image'][i], tgt_path)
|
215 |
+
tgt_path_z.append(tgt_path)
|
216 |
+
else:
|
217 |
+
tgt_path = osp.join(self.img_root, f"{line['index']}.jpg")
|
218 |
+
if not read_ok(tgt_path):
|
219 |
+
decode_base64_to_image_file(line['image'], tgt_path)
|
220 |
+
tgt_path_z.append(tgt_path)
|
221 |
+
return tgt_path_z
|
222 |
+
|
223 |
+
@classmethod
|
224 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
225 |
+
|
226 |
+
suffix = eval_file.split('.')[-1]
|
227 |
+
result_file = eval_file.replace(f'.{suffix}', '_acc.csv')
|
228 |
+
|
229 |
+
if not osp.exists(result_file):
|
230 |
+
data = load(eval_file)
|
231 |
+
assert 'answer' in data and 'prediction' in data
|
232 |
+
data['prediction'] = [str(x) for x in data['prediction']]
|
233 |
+
data['answer'] = [str(x) for x in data['answer']]
|
234 |
+
|
235 |
+
correct_count = 0
|
236 |
+
correct_category = {
|
237 |
+
'技术与工程': [0, 0],
|
238 |
+
'科学': [0, 0],
|
239 |
+
'健康与医学': [0, 0],
|
240 |
+
'商业': [0, 0],
|
241 |
+
'艺术与设计': [0, 0],
|
242 |
+
'人文社会科学': [0, 0],
|
243 |
+
}
|
244 |
+
|
245 |
+
for i in tqdm(data.iterrows()):
|
246 |
+
line = i[1]
|
247 |
+
correct_category[line['category']][0] += 1
|
248 |
+
|
249 |
+
# Options
|
250 |
+
if line['type'] == '选择':
|
251 |
+
index2ans = {
|
252 |
+
'A': line['option1'],
|
253 |
+
'B': line['option2'],
|
254 |
+
'C': line['option3'],
|
255 |
+
'D': line['option4']
|
256 |
+
}
|
257 |
+
fact_option = get_multi_choice_prediction(line['prediction'], ['A', 'B', 'C', 'D'], index2ans)
|
258 |
+
if fact_option == line['answer']:
|
259 |
+
correct_count += 1
|
260 |
+
correct_category[line['category']][1] += 1
|
261 |
+
|
262 |
+
# Binary
|
263 |
+
elif line['type'] == '判断':
|
264 |
+
positive_keywords = ['正确', '对', '准确', '肯定', '对的']
|
265 |
+
negative_keywords = ['不对', '错误', '不正确', '不准确', '不合适', '否定', '错的', '错']
|
266 |
+
ambiguous_keywords = ['对错', '是否正确', '否正确', '或者', '是否', '正确性', '对不']
|
267 |
+
|
268 |
+
def judge_similarity(pred_list, positive_keywords, negative_keywords):
|
269 |
+
positive_count = 0
|
270 |
+
negative_count = 0
|
271 |
+
|
272 |
+
for pred in pred_list:
|
273 |
+
if any(pos_word in pred for pos_word in positive_keywords):
|
274 |
+
positive_count += 1
|
275 |
+
elif any(neg_word in pred for neg_word in negative_keywords):
|
276 |
+
negative_count += 1
|
277 |
+
|
278 |
+
if positive_count > negative_count:
|
279 |
+
return "对"
|
280 |
+
elif negative_count > positive_count:
|
281 |
+
return "错"
|
282 |
+
else:
|
283 |
+
return random.choice(['对', '错'])
|
284 |
+
|
285 |
+
answer = get_TF_prediction(line['prediction'])
|
286 |
+
answer = [word for word in answer if not any(ambiguous in word for ambiguous in ambiguous_keywords)]
|
287 |
+
fact_answer = judge_similarity(answer, positive_keywords, negative_keywords)
|
288 |
+
if fact_answer == line['answer']:
|
289 |
+
correct_count += 1
|
290 |
+
correct_category[line['category']][1] += 1
|
291 |
+
|
292 |
+
# Fill the Blank
|
293 |
+
else:
|
294 |
+
norm_answers = normalize_str(line['answer'], line['answer'])
|
295 |
+
predicted_answer = get_fill_blank_prediction(line['prediction'], line['answer'])
|
296 |
+
|
297 |
+
for pred in predicted_answer:
|
298 |
+
# already normalized
|
299 |
+
if isinstance(pred, str): # if it's a string, then find if ans in the pred_i
|
300 |
+
for norm_ans in norm_answers:
|
301 |
+
# only see if the string answer in the string pred
|
302 |
+
# print(norm_ans, pred)
|
303 |
+
if isinstance(norm_ans, str) and norm_ans in pred:
|
304 |
+
correct_count += 1
|
305 |
+
correct_category[line['category']][1] += 1
|
306 |
+
else: # it's a number
|
307 |
+
if pred in norm_answers:
|
308 |
+
correct_count += 1
|
309 |
+
correct_category[line['category']][1] += 1
|
310 |
+
|
311 |
+
accuracyz = {}
|
312 |
+
accuracyz['总准确率'] = correct_count / len(data)
|
313 |
+
for i in correct_category.keys():
|
314 |
+
accuracyz[i] = correct_category[i][1] / correct_category[i][0]
|
315 |
+
|
316 |
+
accuracyz = d2df(accuracyz)
|
317 |
+
accuracyz.round(10)
|
318 |
+
dump(accuracyz, result_file)
|
319 |
+
|
320 |
+
result = pd.read_csv(result_file)
|
321 |
+
return result
|
322 |
+
|
323 |
+
def build_prompt(self, line):
|
324 |
+
if line['type'] == '选择':
|
325 |
+
tgt_path = self.dump_image(line)
|
326 |
+
question = line['question']
|
327 |
+
options_prompt = 'Options:\n'
|
328 |
+
|
329 |
+
for i in [['A', '1'], ['B', '2'], ['C', '3'], ['D', '4']]:
|
330 |
+
options_prompt += i[0] + '. ' + line['option' + i[1]] + '\n'
|
331 |
+
|
332 |
+
prompt = (f'问题: {question}\n' + options_prompt
|
333 |
+
+ '请回答上述多项选择题,并选出正确选项。这些题目可能包括单选和多选题型。如果所提供的信息不足以确定一个明确的答案,那么请根据可用的数据和你的判断来选择最可能正确的选项。')
|
334 |
+
|
335 |
+
msgs = []
|
336 |
+
if isinstance(tgt_path, list):
|
337 |
+
msgs.extend([dict(type='image', value=p) for p in tgt_path])
|
338 |
+
else:
|
339 |
+
msgs = [dict(type='image', value=tgt_path)]
|
340 |
+
msgs.append(dict(type='text', value=prompt))
|
341 |
+
|
342 |
+
return msgs
|
343 |
+
|
344 |
+
elif line['type'] == '判断':
|
345 |
+
msgs = super().build_prompt(line)
|
346 |
+
assert msgs[-1]['type'] == 'text'
|
347 |
+
msgs[-1]['value'] += '\n请回答上述判断题,并根据题目描述和所给的信息来判断问题中陈述的对错。如果信息不完整或不足以作出绝对判断,请运用你的逻辑推理和现有信息来做出最可能的判断。'
|
348 |
+
return msgs
|
349 |
+
|
350 |
+
else:
|
351 |
+
msgs = super().build_prompt(line)
|
352 |
+
assert msgs[-1]['type'] == 'text'
|
353 |
+
msgs[-1]['value'] += '\n请回答上述填空题,并根据题目的要求和所提供的信息来给出最恰当的答案。如果信息不足以确切回答,那么请依据现有的数据和你的推理能力来填写最合理的答案。'
|
354 |
+
return msgs
|
VLMEvalKit/vlmeval/dataset/dude.py
ADDED
@@ -0,0 +1,211 @@
1 |
+
import math
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
from .utils.judge_util import build_judge
|
5 |
+
from .image_base import ImageBaseDataset
|
6 |
+
from .mmlongbench import concat_images, MMLongBench_auxeval, anls_compute
|
7 |
+
from ..smp import *
|
8 |
+
|
9 |
+
|
10 |
+
FAIL_MSG = 'Failed to obtain answer via API.'
|
11 |
+
|
12 |
+
|
13 |
+
def DUDE_acc(result_file):
|
14 |
+
data = load(result_file)
|
15 |
+
overall_score = 0.0
|
16 |
+
score_list = list()
|
17 |
+
for i in range(len(data)):
|
18 |
+
item = data.iloc[i]
|
19 |
+
if isinstance(item['answer'], float) and math.isnan(item['answer']):
|
20 |
+
item['answer'] = 'Not answerable'
|
21 |
+
|
22 |
+
item['answer'] = item['answer'].lower()
|
23 |
+
item['pred'] = item['pred'].lower()
|
24 |
+
score = anls_compute(item['answer'], item['pred'])
|
25 |
+
score_list.append(score)
|
26 |
+
overall_score += score
|
27 |
+
|
28 |
+
data['score'] = score_list
|
29 |
+
dump(data, result_file)
|
30 |
+
|
31 |
+
res = dict()
|
32 |
+
res['category'], res['num'], res['avg_score'] = ['anls'], [len(data)], [overall_score / len(data)]
|
33 |
+
res = pd.DataFrame(res)
|
34 |
+
return res
|
35 |
+
|
36 |
+
|
37 |
+
class DUDE(ImageBaseDataset):
|
38 |
+
|
39 |
+
TYPE = 'VQA'
|
40 |
+
|
41 |
+
DATASET_URL = {
|
42 |
+
'DUDE': 'https://opencompass.openxlab.space/utils/VLMEval/DUDE.tsv',
|
43 |
+
'DUDE_MINI': 'https://opencompass.openxlab.space/utils/VLMEval/DUDE_MINI.tsv',
|
44 |
+
}
|
45 |
+
DATASET_MD5 = {
|
46 |
+
'DUDE': '130d860d08206e1e407cd77150c10d88',
|
47 |
+
'DUDE_MINI': 'e0c0d998114f0cca7516d12039d2b538',
|
48 |
+
}
|
49 |
+
|
50 |
+
SUPPORTED_MODELS = {
|
51 |
+
'GPT4': (1, 1),
|
52 |
+
'GPT4V': (1, 1),
|
53 |
+
'GPT4V_HIGH': (1, 1),
|
54 |
+
'GPT4o': (1, 1),
|
55 |
+
'GPT4o_HIGH': (1, 1),
|
56 |
+
'GPT4o_MINI': (1, 1),
|
57 |
+
'XComposer2d5': (1, -1),
|
58 |
+
'XComposer2_4KHD': (1, -1),
|
59 |
+
'MiniCPM-Llama3-V-2_5': (1, 5),
|
60 |
+
'InternVL-Chat-V1-5': (5, 2),
|
61 |
+
}
|
62 |
+
|
63 |
+
def __init__(self, dataset, **kwargs):
|
64 |
+
self.model_list = list(self.SUPPORTED_MODELS.keys())
|
65 |
+
model_name = kwargs['model']
|
66 |
+
if not listinstr(self.model_list, model_name):
|
67 |
+
raise AssertionError("{} doesn't support the evaluation on DUDE.".format(model_name))
|
68 |
+
super(DUDE, self).__init__(dataset)
|
69 |
+
|
70 |
+
self.is_api = True if listinstr(['GPT4'], model_name) else False
|
71 |
+
self.max_pages = 120
|
72 |
+
concat_num, column_num = self.SUPPORTED_MODELS.get(model_name)
|
73 |
+
self.concat_num = concat_num
|
74 |
+
self.column_num = column_num
|
75 |
+
|
76 |
+
def prepare_tsv(self, url, file_md5=None):
|
77 |
+
data_root = LMUDataRoot()
|
78 |
+
os.makedirs(data_root, exist_ok=True)
|
79 |
+
file_name = url.split('/')[-1]
|
80 |
+
data_path = osp.join(data_root, file_name)
|
81 |
+
if osp.exists(data_path) and (file_md5 is None or md5(data_path) == file_md5):
|
82 |
+
pass
|
83 |
+
else:
|
84 |
+
warnings.warn('The dataset tsv is not downloaded')
|
85 |
+
download_file(url, data_path)
|
86 |
+
return load(data_path)
|
87 |
+
|
88 |
+
def dump_image(self, origin_line):
|
89 |
+
os.makedirs(self.img_root, exist_ok=True)
|
90 |
+
try:
|
91 |
+
import fitz
|
92 |
+
except Exception as e:
|
93 |
+
logging.critical(f'{type(e)}: {e}')
|
94 |
+
logging.critical('Please use `pip install pymupdf` to parse PDF files.')
|
95 |
+
|
96 |
+
line = origin_line.copy()
|
97 |
+
if not isinstance(line['image_path'], List):
|
98 |
+
line['image_path'] = [line['image_path']]
|
99 |
+
line['image_path'] = line['image_path'][:self.max_pages]
|
100 |
+
skip_pdf_parse = True
|
101 |
+
for im_name in line['image_path']:
|
102 |
+
path = osp.join(self.img_root, im_name)
|
103 |
+
if not read_ok(path):
|
104 |
+
skip_pdf_parse = False
|
105 |
+
break
|
106 |
+
|
107 |
+
# Just for being compatible with the zooped loop: zip(line['image'], line['image_path'])
|
108 |
+
if skip_pdf_parse:
|
109 |
+
line['image'] = line['image_path']
|
110 |
+
else:
|
111 |
+
pdf_data = base64.b64decode(line['image'])
|
112 |
+
pdf_file = io.BytesIO(pdf_data)
|
113 |
+
encoded_images = []
|
114 |
+
with fitz.open(stream=pdf_file, filetype='pdf') as doc:
|
115 |
+
doc = doc[:self.max_pages]
|
116 |
+
for page in doc:
|
117 |
+
image = page.get_pixmap(dpi=144)
|
118 |
+
image_file = io.BytesIO(image.tobytes(output='png'))
|
119 |
+
image = Image.open(image_file)
|
120 |
+
encoded_image = encode_image_to_base64(image)
|
121 |
+
encoded_images.append(encoded_image)
|
122 |
+
line['image'] = encoded_images
|
123 |
+
print('process {}'.format(line['doc_id']))
|
124 |
+
|
125 |
+
if 'image' in line:
|
126 |
+
if isinstance(line['image'], list):
|
127 |
+
tgt_path = []
|
128 |
+
assert 'image_path' in line
|
129 |
+
for img, im_name in zip(line['image'], line['image_path']):
|
130 |
+
path = osp.join(self.img_root, im_name)
|
131 |
+
if not read_ok(path):
|
132 |
+
decode_base64_to_image_file(img, path)
|
133 |
+
tgt_path.append(path)
|
134 |
+
else:
|
135 |
+
tgt_path = osp.join(self.img_root, f"{line['index']}.jpg")
|
136 |
+
if not read_ok(tgt_path):
|
137 |
+
decode_base64_to_image_file(line['image'], tgt_path)
|
138 |
+
tgt_path = [tgt_path]
|
139 |
+
else:
|
140 |
+
assert 'image_path' in line
|
141 |
+
tgt_path = toliststr(line['image_path'])
|
142 |
+
|
143 |
+
if self.concat_num > 0 and not self.is_api:
|
144 |
+
concatenated_images = concat_images(tgt_path, max_concat=self.concat_num, column_num=self.column_num)
|
145 |
+
|
146 |
+
old_tgt_path = tgt_path
|
147 |
+
assert isinstance(old_tgt_path, list)
|
148 |
+
if self.column_num != -1:
|
149 |
+
tgt_path = [
|
150 |
+
'_'.join(old_tgt_path[0].split('_')[:-1]) + '_concat{}_{}.jpg'.format(self.concat_num, i)
|
151 |
+
for i in range(len(concatenated_images))
|
152 |
+
]
|
153 |
+
else:
|
154 |
+
tgt_path = ['_'.join(old_tgt_path[0].split('_')[:-1]) + '_concat_all.jpg']
|
155 |
+
|
156 |
+
for path, concatenated_image in zip(tgt_path, concatenated_images):
|
157 |
+
if not read_ok(path):
|
158 |
+
decode_base64_to_image_file(encode_image_to_base64(concatenated_image), path)
|
159 |
+
num_images, image_size = len(old_tgt_path), concatenated_image.size
|
160 |
+
print('concat {} images to a new one with size {}. save at {}'.format(num_images, image_size, path))
|
161 |
+
return tgt_path
|
162 |
+
|
163 |
+
@classmethod
|
164 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
165 |
+
logger = get_logger('Evaluation')
|
166 |
+
model = judge_kwargs['model']
|
167 |
+
|
168 |
+
suffix = eval_file.split('.')[-1]
|
169 |
+
storage = eval_file.replace(f'.{suffix}', f'_{model}.xlsx')
|
170 |
+
tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
|
171 |
+
|
172 |
+
if osp.exists(storage):
|
173 |
+
logger.warning(f'GPT scoring file {storage} already exists, will reuse it in DUDE_eval. ')
|
174 |
+
else:
|
175 |
+
data = load(eval_file)
|
176 |
+
model = build_judge(max_tokens=128, **judge_kwargs)
|
177 |
+
lt = len(data)
|
178 |
+
lines = [data.iloc[i] for i in range(lt)]
|
179 |
+
tups = [(model, line) for line in lines]
|
180 |
+
indices = [line['index'] for line in lines]
|
181 |
+
|
182 |
+
ans = {}
|
183 |
+
if osp.exists(tmp_file):
|
184 |
+
ans = load(tmp_file)
|
185 |
+
tups = [x for x, i in zip(tups, indices) if i not in ans]
|
186 |
+
indices = [i for i in indices if i not in ans]
|
187 |
+
|
188 |
+
if len(indices):
|
189 |
+
new_results = list()
|
190 |
+
for model, line in tqdm(tups):
|
191 |
+
res = MMLongBench_auxeval(model, line)
|
192 |
+
new_results.append(res)
|
193 |
+
|
194 |
+
log_map, res_map, pred_map = {}, {}, {}
|
195 |
+
all_inds = [line['index'] for line in lines]
|
196 |
+
for k, v in zip(all_inds, new_results):
|
197 |
+
log_map[k] = v['log']
|
198 |
+
res_map[k] = v['res']
|
199 |
+
pred_map[k] = v['pred']
|
200 |
+
data['res'] = [res_map[idx] for idx in data['index']]
|
201 |
+
data['log'] = [log_map[idx] for idx in data['index']]
|
202 |
+
data['pred'] = [pred_map[idx] for idx in data['index']]
|
203 |
+
dump(data, storage)
|
204 |
+
|
205 |
+
score = DUDE_acc(storage)
|
206 |
+
score_pth = storage.replace('.xlsx', '_score.csv')
|
207 |
+
|
208 |
+
dump(score, score_pth)
|
209 |
+
logger.info(f'DUDE successfully finished evaluating {eval_file}, results saved in {score_pth}')
|
210 |
+
logger.info('Score: ')
|
211 |
+
logger.info(score)
|
VLMEvalKit/vlmeval/dataset/dynamath.py
ADDED
@@ -0,0 +1,240 @@
import re
import json
import sympy as sp
import numpy as np
import pandas as pd
from sympy import simplify, Eq, sympify, Pow, pi
from sympy.parsing.latex import parse_latex
import sys
import math
import os
import os.path as osp
import argparse

from .image_base import ImageBaseDataset
from .utils import build_judge
from ..utils import track_progress_rich
from ..smp import load, dump, d2df, toliststr


def preprocess(str1):
    if 0 <= str1.find("{") < str1.rfind("}"):
        str1 = str1[str1.find("{"): str1.rfind("}") + 1]
    str2 = str1.replace("\\", "")
    str2 = str2.replace("\\n", "\n")
    return str2


def transfer(str1):
    if "\u03c0" in str1:
        strs = str1.split('\u03c0')
        str1 = strs[0]
        return float(str1) * np.pi
    else:
        return float(str1)


def parse_answer(answer, answer_type="multiple choice"):
    if answer_type == "float":
        if answer.isdigit():
            return True, float(answer)
        else:
            parts = answer.split(' ')
            answer = parts[0]
            try:
                answer = transfer(answer)
                return True, answer
            except:
                return False, None
    elif answer_type == "multiple choice":
        if len(answer) == 1:
            return True, answer.upper()
        else:
            in_flag = [ch in answer.upper() for ch in 'ABCDE']
            if sum(in_flag) == 1:
                for ch in 'ABCDE':
                    if ch in answer.upper():
                        return True, ch
            return False, None
    else:
        return True, answer


def DynaMath_auxeval(model, line):
    pred = line['prediction']
    pred = preprocess(pred)

    succeed, short_answer = None, None
    try:
        dj = json.loads(pred, strict=False)
        short_answer = dj.get("short answer")
        assert short_answer is not None
        succeed, short_answer = parse_answer(short_answer, answer_type=line['anwser_type'])
        assert succeed
    except:
        # Failed to parse the JSON, use an auxiliary LLM to get the short answer
        if line['answer_type'] == 'multiple choice':
            inst = "Output the corresponing choice option, such as 'A', 'B', 'C', 'D', in a single line."
        elif line['answer_type'] == 'float':
            inst = "Output a three-digit floating-point number in a single line."
        else:
            inst = (
                "Output a short answer in a single line. Any float numbers in the answer "
                "should be formatted as three-digit floating-point numbers."
            )

        prompt = f"Free-form answer: {pred}\nInstruction: {inst}"
        response = pred
        succeed, short_answer = parse_answer(response, line['answer_type'])
        if not succeed:
            response = model.generate(prompt)
            succeed, short_answer = parse_answer(response, line['answer_type'])

    if line['answer_type'] == 'float':
        if succeed:
            diff = float(short_answer) - float(line['answer'])
            if abs(diff) <= 0.001:
                return dict(parse=True, extracted=short_answer, correct=True)
            else:
                return dict(parse=True, extracted=short_answer, correct=False)
        else:
            return dict(parse=False, extracted=None, correct=False)
    elif line['answer_type'] == 'multiple choice':
        if succeed:
            return dict(parse=True, extracted=short_answer, correct=(short_answer == line['answer']))
        else:
            if line['answer'] in pred[:3].upper():
                return dict(parse=False, extracted=None, correct=True)
            else:
                return dict(parse=False, extracted=None, correct=False)
    else:
        if succeed:
            return dict(parse=True, extracted=short_answer, correct=(short_answer.lower() in line['answer'].lower()))
        else:
            return dict(parse=False, extracted=None, correct=(short_answer.lower() in line['answer'].lower()))


class Dynamath(ImageBaseDataset):

    TYPE = 'VQA'
    DATASET_URL = {'DynaMath': 'https://opencompass.openxlab.space/utils/VLMEval/DynaMath.tsv'}
    DATASET_MD5 = {'DynaMath': 'b8425ad9a7114571fc9366e013699494'}
    GUIDE = """
## Answer Instruction Please provide an answer to the question outlined above. Your response should adhere \
to the following JSON format, which includes two keys: 'solution' and 'short answer'. The 'solution' key can contain \
detailed steps needed to solve the question, and the 'short answer' key should provide a concise response. {INST}

Example of expected JSON response format:

"""
    EXAMPLE = {
        "solution": "[Detailed step-by-step explanation]",
        "short answer": "[Concise Answer]"
    }
    TEXT_EXAMPLE = json.dumps(EXAMPLE, indent=4)

    # Given one data record, return the built prompt (a multi-modal message), can override
    def build_prompt(self, line):
        if isinstance(line, int):
            line = self.data.iloc[line]

        if self.meta_only:
            tgt_path = toliststr(line['image_path'])
        else:
            tgt_path = self.dump_image(line)

        prompt = f"## Question\n {line['question']}"
        if line['answer_type'] == 'multiple choice':
            inst = "Provide the corresponing choice option in the 'short answer' key, such as 'A', 'B', 'C', or 'D'."
        elif line['answer_type'] == 'float':
            inst = "Format the answer as a three-digit floating-point number and provide it in the 'short answer' key."
        else:
            inst = "Float numbers in the answer should be formatted as three-digit floating-point numbers."

        prompt = prompt + self.GUIDE.format(INST=inst) + self.TEXT_EXAMPLE

        msgs = []
        if isinstance(tgt_path, list):
            msgs.extend([dict(type='image', value=p) for p in tgt_path])
        else:
            msgs = [dict(type='image', value=tgt_path)]
        msgs.append(dict(type='text', value=prompt))
        return msgs

    def evaluate(self, eval_file, **judge_kwargs):
        judge_name = judge_kwargs.pop('model', 'gpt-4o-mini')

        model = build_judge(model=judge_name, **judge_kwargs)
        suffix = eval_file.split('.')[-1]

        storage = eval_file.replace(f'.{suffix}', f'_{judge_name}.xlsx')  # noqa: F841
        score_file = eval_file.replace(f'.{suffix}', f'_{judge_name}_score.csv')  # noqa: F841
        tmp_file = eval_file.replace(f'.{suffix}', f'_{judge_name}.pkl')  # noqa: F841
        nproc = judge_kwargs.pop('nproc', 6)  # noqa: F841

        res = load(tmp_file) if os.path.exists(tmp_file) else {}
        res = {k: v for k, v in res.items() if v is not None}

        model.system_prompt = """\
You are a helpful assistant that helps me to format free-form answers into a short answer according to the instruction.
"""
        if not osp.exists(storage):
            data = load(eval_file)
            lt = len(data)
            payloads = [dict(model=model, line=data.iloc[i]) for i in range(lt) if data.iloc[i]['index'] not in res]
            keys = [idx for idx in data['index'] if idx not in res]

            if len(keys):
                results = track_progress_rich(DynaMath_auxeval, payloads, nproc=nproc, save=tmp_file, keys=keys)
                for k, r in zip(keys, results):
                    res[k] = r

            data['parse'] = [res[idx]['parse'] for idx in data['index']]
            data['extracted'] = [res[idx]['extracted'] for idx in data['index']]
            data['correct'] = [res[idx]['correct'] for idx in data['index']]
            dump(data, storage)

        data = load(storage)
        # Calculate Average Accuracy
        score_avg = {}
        score_avg['Overall'] = np.mean(data['correct'])

        subs = set(data['subject'])
        for sub in subs:
            data_sub = data[data['subject'] == sub]
            score_avg[f'Subject-{sub}'] = np.mean(data_sub['correct'])

        lvls = set(data['knowledge_level'])
        for lvl in lvls:
            data_lvl = data[data['knowledge_level'] == lvl]
            score_avg[f'Level-{lvl}'] = np.mean(data_lvl['correct'])

        # Calculate the Worst Case Accuracy
        score_worst = {}
        data_worst = data[data['varid'] == 1]
        qid2corr = {idx: True for idx in data_worst['index']}
        lt = len(data)
        for i in range(lt):
            item = data.iloc[i]
            qid2corr[item['qid']] *= item['correct']
        data_worst['correct'] = [qid2corr[idx] for idx in data_worst['qid']]
        score_worst['Overall'] = np.mean(data_worst['correct'])

        subs = set(data_worst['subject'])
        for sub in subs:
            data_sub = data_worst[data_worst['subject'] == sub]
            score_worst[f'Subject-{sub}'] = np.mean(data_sub['correct'])

        lvls = set(data_worst['knowledge_level'])
        for lvl in lvls:
            data_lvl = data_worst[data_worst['knowledge_level'] == lvl]
            score_worst[f'Level-{lvl}'] = np.mean(data_lvl['correct'])

        d1 = {'Setting': 'Average'}
        d1.update(score_avg)
        d2 = {'Setting': 'Worst Case'}
        d2.update(score_worst)
        score = pd.concat([d2df(d1), d2df(d2)], ignore_index=True)

        dump(score, score_file)
        return score
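As a quick illustration of the parsing helpers defined above, a minimal sketch (assuming the package is importable as `vlmeval`):

# Minimal sketch: how parse_answer / transfer normalise free-form answers before scoring.
from vlmeval.dataset.dynamath import parse_answer, transfer

print(parse_answer('(b)', answer_type='multiple choice'))  # -> (True, 'B')
print(parse_answer('3.142 cm', answer_type='float'))       # -> (True, 3.142)
print(transfer('0.5\u03c0'))                                # -> 0.5 * pi, about 1.571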
VLMEvalKit/vlmeval/dataset/image_base.py
ADDED
@@ -0,0 +1,172 @@
import pandas as pd
from abc import abstractmethod
from ..smp import *


def img_root_map(dataset):
    if 'MM_NIAH' in dataset:
        return 'MMNIAH'
    if 'CRPE' in dataset:
        return 'CRPE'
    if 'OCRVQA' in dataset:
        return 'OCRVQA'
    if 'COCO_VAL' == dataset:
        return 'COCO'
    if 'MMMU' in dataset:
        return 'MMMU'
    if "QSpatial" in dataset:
        return "QSpatial"

    mmbench_root_map = {
        'MMBench_DEV_EN': 'MMBench', 'MMBench_TEST_EN': 'MMBench',
        'MMBench_DEV_CN': 'MMBench', 'MMBench_TEST_CN': 'MMBench',
        'MMBench': 'MMBench', 'MMBench_CN': 'MMBench',
        'MMBench_DEV_EN_V11': 'MMBench_V11', 'MMBench_TEST_EN_V11': 'MMBench_V11',
        'MMBench_DEV_CN_V11': 'MMBench_V11', 'MMBench_TEST_CN_V11': 'MMBench_V11',
        'MMBench_V11': 'MMBench', 'MMBench_CN_V11': 'MMBench',
    }
    if dataset in mmbench_root_map:
        return mmbench_root_map[dataset]
    return dataset


class ImageBaseDataset:

    MODALITY = 'IMAGE'
    DATASET_URL = {}
    DATASET_MD5 = {}

    def __init__(self, dataset='MMBench', skip_noimg=True):
        ROOT = LMUDataRoot()
        # You can override this variable to save image files to a different directory
        self.dataset_name = dataset
        self.img_root = osp.join(ROOT, 'images', img_root_map(dataset))

        data = self.load_data(dataset)
        self.skip_noimg = skip_noimg
        if skip_noimg and 'image' in data:
            data = data[~pd.isna(data['image'])]

        data['index'] = [str(x) for x in data['index']]

        self.meta_only = True

        # The image field can store the base64 encoded image or another question index (for saving space)
        if 'image' in data:
            data['image'] = [str(x) for x in data['image']]
            image_map = {x: y for x, y in zip(data['index'], data['image'])}
            for k in image_map:
                if len(image_map[k]) <= 64:
                    idx = image_map[k]
                    assert idx in image_map and len(image_map[idx]) > 64
                    image_map[k] = image_map[idx]

            images = [toliststr(image_map[k]) for k in data['index']]
            data['image'] = [x[0] if len(x) == 1 else x for x in images]
            self.meta_only = False

        if 'image_path' in data:
            paths = [toliststr(x) for x in data['image_path']]
            data['image_path'] = [x[0] if len(x) == 1 else x for x in paths]

        if np.all([istype(x, int) for x in data['index']]):
            data['index'] = [int(x) for x in data['index']]

        self.data = data
        self.post_build(dataset)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return dict(self.data.iloc[idx])

    def prepare_tsv(self, url, file_md5=None):
        data_root = LMUDataRoot()
        os.makedirs(data_root, exist_ok=True)
        update_flag = False
        file_name = url.split('/')[-1]
        data_path = osp.join(data_root, file_name)
        if osp.exists(data_path) and (file_md5 is None or md5(data_path) == file_md5):
            pass
        else:
            warnings.warn('The dataset tsv is not downloaded')
            download_file(url, data_path)
            update_flag = True

        if file_size(data_path, 'GB') > 1:
            local_path = data_path.replace('.tsv', '_local.tsv')
            if not osp.exists(local_path) or os.environ.get('FORCE_LOCAL', None) or update_flag:
                from ..tools import LOCALIZE
                LOCALIZE(data_path, local_path)
            data_path = local_path
        return load(data_path)

    def dump_image(self, line):
        os.makedirs(self.img_root, exist_ok=True)

        if 'image' in line:
            if isinstance(line['image'], list):
                tgt_path = []
                assert 'image_path' in line
                for img, im_name in zip(line['image'], line['image_path']):
                    path = osp.join(self.img_root, im_name)
                    if not read_ok(path):
                        decode_base64_to_image_file(img, path)
                    tgt_path.append(path)
            else:
                tgt_path = osp.join(self.img_root, f"{line['index']}.jpg")
                if not read_ok(tgt_path):
                    decode_base64_to_image_file(line['image'], tgt_path)
                tgt_path = [tgt_path]
        else:
            assert 'image_path' in line
            tgt_path = toliststr(line['image_path'])

        return tgt_path

    def display(self, line):
        if isinstance(line, int):
            line = self.data.iloc[line]
        assert isinstance(line, pd.Series) or isinstance(line, dict)
        mmqa_display(line)

    # Return a list of dataset names that are supported by this class, can override
    @classmethod
    def supported_datasets(cls):
        return list(cls.DATASET_URL)

    # Given the dataset name, return the dataset as a pandas dataframe, can override
    def load_data(self, dataset):
        url = self.DATASET_URL[dataset]
        file_md5 = self.DATASET_MD5[dataset] if dataset in self.DATASET_MD5 else None
        return self.prepare_tsv(url, file_md5)

    # Post built hook, will be called after the dataset is built, can override
    def post_build(self, dataset):
        pass

    # Given one data record, return the built prompt (a multi-modal message), can override
    def build_prompt(self, line):
        if isinstance(line, int):
            line = self.data.iloc[line]

        if self.meta_only:
            tgt_path = toliststr(line['image_path'])
        else:
            tgt_path = self.dump_image(line)

        question = line['question']

        msgs = []
        if isinstance(tgt_path, list):
            msgs.extend([dict(type='image', value=p) for p in tgt_path])
        else:
            msgs = [dict(type='image', value=tgt_path)]
        msgs.append(dict(type='text', value=question))
        return msgs

    # Given the prediction file, return the evaluation results in the format of a dictionary or pandas dataframe
    @abstractmethod
    def evaluate(self, eval_file, **judge_kwargs):
        pass
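For reference, a new image benchmark is typically added by subclassing ImageBaseDataset and filling in the class attributes and the evaluate hook. A minimal sketch follows; the URL, dataset name, and metric here are hypothetical, not part of the repository:

# Minimal subclass sketch; 'MyToy', the URL, and the exact-match metric are illustrative assumptions.
from vlmeval.dataset.image_base import ImageBaseDataset
from vlmeval.smp import load


class MyToyDataset(ImageBaseDataset):
    TYPE = 'VQA'
    DATASET_URL = {'MyToy': 'https://example.com/MyToy.tsv'}  # hypothetical URL
    DATASET_MD5 = {}  # leave empty to skip the MD5 check in load_data

    def evaluate(self, eval_file, **judge_kwargs):
        data = load(eval_file)
        # Placeholder metric: exact string match between prediction and answer.
        hits = [str(p).strip() == str(a).strip() for p, a in zip(data['prediction'], data['answer'])]
        return dict(accuracy=sum(hits) / len(hits))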
VLMEvalKit/vlmeval/dataset/image_caption.py
ADDED
@@ -0,0 +1,75 @@
from .image_base import ImageBaseDataset
from ..smp import *


class COCO_Caption_Scorer():
    def __init__(self, ref, gt):
        from pycocoevalcap.bleu.bleu import Bleu
        from pycocoevalcap.rouge.rouge import Rouge
        from pycocoevalcap.cider.cider import Cider

        self.ref = ref
        self.gt = gt
        print('setting up scorers...')
        self.scorers = [
            (Bleu(4), ['Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4']),
            (Rouge(), 'ROUGE_L'),
            (Cider(), 'CIDEr'),
        ]

    def compute_scores(self):
        total_scores = {}
        for scorer, method in self.scorers:
            print('computing %s score...' % (scorer.method()))
            score, scores = scorer.compute_score(self.gt, self.ref)
            if isinstance(method, list):
                for sc, scs, m in zip(score, scores, method):
                    print('%s: %0.3f' % (m, sc * 100))
                total_scores['Bleu'] = [x * 100 for x in score]
            else:
                print('%s: %0.3f' % (method, score * 100))
                total_scores[method] = score * 100

        print('*****DONE*****')
        for key, value in total_scores.items():
            print('{}:{}'.format(key, value))
        return total_scores


class ImageCaptionDataset(ImageBaseDataset):

    TYPE = 'Caption'

    DATASET_URL = {
        'COCO_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/COCO_VAL.tsv',
    }

    DATASET_MD5 = {
        'COCO_VAL': '72a5079dead060269ac222c5aa5128af',
    }

    def load_data(self, dataset):
        data = super().load_data(dataset)
        if 'question' not in data:
            data['question'] = [(
                'Please describe this image in general. Directly provide the description, '
                'do not include prefix like "This image depicts". '
            )] * len(data)
        return data

    # It returns a dictionary of scores
    @classmethod
    def evaluate(self, eval_file, **kwargs):
        data = load(eval_file)
        lt = len(data)
        lines = [data.iloc[i] for i in range(lt)]
        ref, gt = {}, {}
        for i, line in enumerate(lines):
            ref[str(i)] = [str(line['prediction'])]
            gt[str(i)] = eval(line['answer'])

        scorer = COCO_Caption_Scorer(ref, gt)
        coco_caption_score_dict = scorer.compute_scores()
        score_pth = eval_file.replace('.xlsx', '_score.json')
        dump(coco_caption_score_dict, score_pth)
        return coco_caption_score_dict
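The scorer above expects predictions and ground-truth captions keyed by a shared string id. A toy sketch of the input format (the captions here are made up):

# Toy sketch of the COCO_Caption_Scorer input format; data is illustrative only.
ref = {'0': ['a dog runs on the grass']}                 # model predictions, one list per id
gt = {'0': ['a dog running across a grassy field',
            'a brown dog plays outside']}                # ground-truth captions per id

# scorer = COCO_Caption_Scorer(ref, gt)   # requires pycocoevalcap to be installed
# scores = scorer.compute_scores()        # returns Bleu / ROUGE_L / CIDEr values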
VLMEvalKit/vlmeval/dataset/image_mcq.py
ADDED
@@ -0,0 +1,899 @@
1 |
+
import warnings
|
2 |
+
|
3 |
+
from .image_base import ImageBaseDataset
|
4 |
+
from .utils import build_judge, DEBUG_MESSAGE
|
5 |
+
from ..smp import *
|
6 |
+
import pandas as pd
|
7 |
+
|
8 |
+
MMMB_URLS = {
|
9 |
+
'MMMB_ar': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmmb/mmmb_ar.tsv',
|
10 |
+
'MMMB_cn': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmmb/mmmb_cn.tsv',
|
11 |
+
'MMMB_en': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmmb/mmmb_en.tsv',
|
12 |
+
'MMMB_pt': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmmb/mmmb_pt.tsv',
|
13 |
+
'MMMB_ru': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmmb/mmmb_ru.tsv',
|
14 |
+
'MMMB_tr': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmmb/mmmb_tr.tsv',
|
15 |
+
}
|
16 |
+
|
17 |
+
MTL_MMBench_URLS = {
|
18 |
+
'MMBench_dev_ar': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmbench/mmbench_dev_ar.tsv',
|
19 |
+
'MMBench_dev_cn': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmbench/mmbench_dev_cn.tsv',
|
20 |
+
'MMBench_dev_en': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmbench/mmbench_dev_en.tsv',
|
21 |
+
'MMBench_dev_pt': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmbench/mmbench_dev_pt.tsv',
|
22 |
+
'MMBench_dev_tr': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmbench/mmbench_dev_tr.tsv',
|
23 |
+
'MMBench_dev_ru': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmbench/mmbench_dev_ru.tsv',
|
24 |
+
}
|
25 |
+
|
26 |
+
MMMB_MD5 = {
|
27 |
+
'MMMB_ar': 'f3a18b6385f1d9701840aa42de27aead', 'MMMB_cn': '13ed82fa89730037292fcaa27f08f430',
|
28 |
+
'MMMB_en': '1cd781a71ec5a2983c090b84105d6a01', 'MMMB_pt': '548ea2b3bb2da991790386f0015d30d1',
|
29 |
+
'MMMB_ru': 'ce1cc8a0533425ab0d86b326ebfc2984', 'MMMB_tr': '0733739d43090327975294292bc5cd67'
|
30 |
+
}
|
31 |
+
|
32 |
+
MTL_MMBench_MD5 = {
|
33 |
+
'MMBench_dev_ar': '4271b4a0d0200e1a86380a878e0d64a4', 'MMBench_dev_cn': '2ed5135326fed02c8e51ea50dda8222f',
|
34 |
+
'MMBench_dev_en': 'd9ab776fc018b3d45785e9a5c23431c2', 'MMBench_dev_pt': '4ddfbcd27ef12444b908c03831cd0295',
|
35 |
+
'MMBench_dev_tr': '4fab39d501389d3d6cc90264bb708f11', 'MMBench_dev_ru': '5ba1171ff2e68f80637bf78349e402a5'
|
36 |
+
}
|
37 |
+
|
38 |
+
|
39 |
+
class ImageMCQDataset(ImageBaseDataset):
|
40 |
+
|
41 |
+
TYPE = 'MCQ'
|
42 |
+
|
43 |
+
DATASET_URL = {
|
44 |
+
# MMBench v1.0
|
45 |
+
'MMBench_DEV_EN': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_DEV_EN.tsv',
|
46 |
+
'MMBench_TEST_EN': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_TEST_EN.tsv',
|
47 |
+
'MMBench_DEV_CN': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_DEV_CN.tsv',
|
48 |
+
'MMBench_TEST_CN': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_TEST_CN.tsv',
|
49 |
+
'MMBench': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench.tsv', # Internal
|
50 |
+
'MMBench_CN': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_CN.tsv', # Internal
|
51 |
+
# MMBench v1.1
|
52 |
+
'MMBench_DEV_EN_V11': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_DEV_EN_V11.tsv',
|
53 |
+
'MMBench_TEST_EN_V11': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_TEST_EN_V11.tsv',
|
54 |
+
'MMBench_DEV_CN_V11': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_DEV_CN_V11.tsv',
|
55 |
+
'MMBench_TEST_CN_V11': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_TEST_CN_V11.tsv',
|
56 |
+
'MMBench_V11': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_V11.tsv', # Internal
|
57 |
+
'MMBench_CN_V11': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_CN_V11.tsv', # Internal
|
58 |
+
# SEEDBench Series
|
59 |
+
'SEEDBench_IMG': 'https://opencompass.openxlab.space/utils/benchmarks/SEEDBench/SEEDBench_IMG.tsv',
|
60 |
+
'SEEDBench2': 'https://huggingface.co/datasets/VLMEval/SEEDBench2/resolve/main/SEEDBench2.tsv',
|
61 |
+
'SEEDBench2_Plus': 'https://opencompass.openxlab.space/utils/benchmarks/SEEDBench/SEEDBench2_Plus.tsv',
|
62 |
+
# ScienceQA Series
|
63 |
+
'ScienceQA_VAL': 'https://opencompass.openxlab.space/utils/benchmarks/ScienceQA/ScienceQA_VAL.tsv',
|
64 |
+
'ScienceQA_TEST': 'https://opencompass.openxlab.space/utils/benchmarks/ScienceQA/ScienceQA_TEST.tsv',
|
65 |
+
# MMT-Bench
|
66 |
+
'MMT-Bench_ALL_MI': 'https://opencompass.openxlab.space/utils/benchmarks/MMT-Bench/MMT-Bench_ALL_MI.tsv',
|
67 |
+
'MMT-Bench_ALL': 'https://opencompass.openxlab.space/utils/benchmarks/MMT-Bench/MMT-Bench_ALL.tsv',
|
68 |
+
'MMT-Bench_VAL_MI': 'https://opencompass.openxlab.space/utils/benchmarks/MMT-Bench/MMT-Bench_VAL_MI.tsv',
|
69 |
+
'MMT-Bench_VAL': 'https://opencompass.openxlab.space/utils/benchmarks/MMT-Bench/MMT-Bench_VAL.tsv',
|
70 |
+
# AesBench
|
71 |
+
'AesBench_VAL': 'https://huggingface.co/datasets/VLMEval/AesBench/resolve/main/AesBench_VAL.tsv',
|
72 |
+
'AesBench_TEST': 'https://huggingface.co/datasets/VLMEval/AesBench/resolve/main/AesBench_TEST.tsv',
|
73 |
+
# Q-Bench1
|
74 |
+
'Q-Bench1_VAL': 'https://huggingface.co/datasets/zhangzicheng/qbench_tsv/resolve/main/Q-Bench1_VAL.tsv',
|
75 |
+
'Q-Bench1_TEST': 'https://huggingface.co/datasets/zhangzicheng/qbench_tsv/resolve/main/Q-Bench1_TEST.tsv',
|
76 |
+
# A-Bench
|
77 |
+
'A-Bench_VAL': 'https://huggingface.co/datasets/zhangzicheng/abench_tsv/resolve/main/A-bench_VAL.tsv',
|
78 |
+
'A-Bench_TEST': 'https://huggingface.co/datasets/zhangzicheng/abench_tsv/resolve/main/A-bench_TEST.tsv',
|
79 |
+
# R-Bench
|
80 |
+
'R-Bench-Dis': 'https://huggingface.co/datasets/lcysyzxdxc/R-Bench/blob/main/R-bench-dis.tsv',
|
81 |
+
'R-Bench-Ref': 'https://huggingface.co/datasets/lcysyzxdxc/R-Bench/blob/main/R-bench-ref.tsv',
|
82 |
+
# Other Benchmarks
|
83 |
+
'CCBench': 'https://opencompass.openxlab.space/utils/VLMEval/CCBench.tsv',
|
84 |
+
'AI2D_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/AI2D_TEST.tsv',
|
85 |
+
'AI2D_TEST_NO_MASK': 'https://opencompass.openxlab.space/utils/VLMEval/AI2D_TEST_NO_MASK.tsv',
|
86 |
+
'MMStar': 'https://opencompass.openxlab.space/utils/VLMEval/MMStar.tsv',
|
87 |
+
'RealWorldQA': 'https://opencompass.openxlab.space/utils/VLMEval/RealWorldQA.tsv',
|
88 |
+
'MLLMGuard_DS': 'https://opencompass.openxlab.space/utils/VLMEval/MLLMGuard_DS.tsv',
|
89 |
+
'BLINK': 'https://opencompass.openxlab.space/utils/VLMEval/BLINK.tsv',
|
90 |
+
'TaskMeAnything_v1_imageqa_random': (
|
91 |
+
'https://huggingface.co/datasets/weikaih/TaskMeAnything-v1-imageqa-random/'
|
92 |
+
'resolve/main/TaskMeAnything-v1-imageqa-random.tsv'
|
93 |
+
),
|
94 |
+
'A-OKVQA': 'https://huggingface.co/datasets/Allen8/A-OKVQA/resolve/main/a-okvqa.tsv',
|
95 |
+
'WorldMedQA-V': 'https://opencompass.openxlab.space/utils/VLMEval/WorldMedQA-V.tsv',
|
96 |
+
'VisOnlyQA-VLMEvalKit': (
|
97 |
+
'https://huggingface.co/datasets/ryokamoi/VisOnlyQA_Eval_Real/'
|
98 |
+
'resolve/main/visonlyqa_vlmevalkit.tsv'
|
99 |
+
),
|
100 |
+
}
|
101 |
+
|
102 |
+
DATASET_MD5 = {
|
103 |
+
# MMBench v1.0
|
104 |
+
'MMBench_DEV_EN': 'b6caf1133a01c6bb705cf753bb527ed8',
|
105 |
+
'MMBench_TEST_EN': '6939fadb0ce626fefc0bdc9c64efc528',
|
106 |
+
'MMBench_DEV_CN': '08b8fc3324a5ed74155350f57be69fbd',
|
107 |
+
'MMBench_TEST_CN': '7e1239baf0ee4c8b513e19705a0f317e',
|
108 |
+
'MMBench': '4115aea3383f3dd0083be6a633e0f820', # Internal Only
|
109 |
+
'MMBench_CN': '2e053ffc90ea598b1feae13c36dc13ee', # Internal Only
|
110 |
+
# MMBench v1.1
|
111 |
+
'MMBench_DEV_EN_V11': '30c05be8f2f347a50be25aa067248184',
|
112 |
+
'MMBench_TEST_EN_V11': '26f0f15381a21720255091d3e0316ce6',
|
113 |
+
'MMBench_DEV_CN_V11': '593f9b5f6bea453d870a798b34ae4f37',
|
114 |
+
'MMBench_TEST_CN_V11': '74bbe4556dac745613c7cbe5ad787050',
|
115 |
+
'MMBench_V11': 'b9276414f57af1308dcc4d0cd9b42e7c', # Internal Only
|
116 |
+
'MMBench_CN_V11': '95f6980dd1b4de38e3cbffe0305a3f25', # Internal Only
|
117 |
+
# SEEDBench
|
118 |
+
'SEEDBench_IMG': '68017231464752261a2526d6ca3a10c0',
|
119 |
+
'SEEDBench2': '4ec15cf864c4f16274112284f531813e',
|
120 |
+
'SEEDBench2_Plus': 'e32d3216dc4f452b0fe497a52015d1fd',
|
121 |
+
# ScienceQA
|
122 |
+
'ScienceQA_VAL': '96320d05e142e585e7204e72affd29f3',
|
123 |
+
'ScienceQA_TEST': 'e42e9e00f9c59a80d8a5db35bc32b71f',
|
124 |
+
# MMT-Bench
|
125 |
+
'MMT-Bench_ALL_MI': '5272157097e19cdd7cb41e412ab3b7c7',
|
126 |
+
'MMT-Bench_ALL': 'b273a2f4c596fe4f2605de0494cd632f',
|
127 |
+
'MMT-Bench_VAL_MI': 'c7d7b998eb5cd9aa36c7d4f721472462',
|
128 |
+
'MMT-Bench_VAL': '8dd4b730f53dbf9c3aed90ca31c928e0',
|
129 |
+
# AesBench
|
130 |
+
'AesBench_VAL': '3edb0c319e9187aa0b97fe7a11700a8c',
|
131 |
+
'AesBench_TEST': '58b1f7ba2cc32e1d68896d6ee716bbf8',
|
132 |
+
# Q-Bench1
|
133 |
+
'Q-Bench1_VAL': '837bdb6cd2da571713543462815187b7',
|
134 |
+
'Q-Bench1_TEST': '15e759bfd58c9d5f30b23a317d347153',
|
135 |
+
# A-Bench
|
136 |
+
'A-Bench_VAL': '218563ec50d34bb336c814143a5bb9c1',
|
137 |
+
'A-Bench_TEST': '567013fb033a20cf23f51d8e865bd16c',
|
138 |
+
# R-Bench
|
139 |
+
'R-Bench-Dis': 'd6e961dbfc43350688af2560226830b4',
|
140 |
+
'R-Bench-Ref': '270c1cb555acb523f3fdb178ed57021d',
|
141 |
+
# Other Benchmarks
|
142 |
+
'CCBench': 'f5dde47f24dc5a6fb6e595b409b466ac',
|
143 |
+
'AI2D_TEST': '0f593e0d1c7df9a3d69bf1f947e71975',
|
144 |
+
'AI2D_TEST_NO_MASK': 'fd8f463634d4fe9fbd23b876e8eea5be',
|
145 |
+
'MMStar': 'e1ecd2140806c1b1bbf54b43372efb9e',
|
146 |
+
'RealWorldQA': '92321028d2bc29040284b6674721e48f',
|
147 |
+
'MLLMGuard_DS': '975fc0dd7119386e198c37d71e274b3f',
|
148 |
+
'BLINK': '3b6649b6a662184ea046908e5506260e',
|
149 |
+
'TaskMeAnything_v1_imageqa_random': '023fef69e2ca21827afb77c5ec3bc889',
|
150 |
+
'WorldMedQA-V': '441e63875e30c87f5750528b57b41285',
|
151 |
+
"VisOnlyQA-VLMEvalKit": 'cf460a31d2acb8d3a7cecd0e69298bfa',
|
152 |
+
}
|
153 |
+
|
154 |
+
DATASET_URL.update(MMMB_URLS)
|
155 |
+
DATASET_URL.update(MTL_MMBench_URLS)
|
156 |
+
DATASET_MD5.update(MMMB_MD5)
|
157 |
+
DATASET_MD5.update(MTL_MMBench_MD5)
|
158 |
+
|
159 |
+
def build_prompt(self, line):
|
160 |
+
|
161 |
+
if isinstance(line, int):
|
162 |
+
line = self.data.iloc[line]
|
163 |
+
|
164 |
+
if self.meta_only:
|
165 |
+
tgt_path = toliststr(line['image_path'])
|
166 |
+
else:
|
167 |
+
tgt_path = self.dump_image(line)
|
168 |
+
|
169 |
+
question = line['question']
|
170 |
+
options = {
|
171 |
+
cand: line[cand]
|
172 |
+
for cand in string.ascii_uppercase
|
173 |
+
if cand in line and not pd.isna(line[cand])
|
174 |
+
}
|
175 |
+
options_prompt = 'Options:\n'
|
176 |
+
for key, item in options.items():
|
177 |
+
options_prompt += f'{key}. {item}\n'
|
178 |
+
hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
|
179 |
+
prompt = ''
|
180 |
+
if hint is not None:
|
181 |
+
prompt += f'Hint: {hint}\n'
|
182 |
+
prompt += f'Question: {question}\n'
|
183 |
+
if len(options):
|
184 |
+
prompt += options_prompt
|
185 |
+
prompt += 'Please select the correct answer from the options above. \n'
|
186 |
+
|
187 |
+
msgs = []
|
188 |
+
if isinstance(tgt_path, list):
|
189 |
+
msgs.extend([dict(type='image', value=p) for p in tgt_path])
|
190 |
+
else:
|
191 |
+
msgs = [dict(type='image', value=tgt_path)]
|
192 |
+
msgs.append(dict(type='text', value=prompt))
|
193 |
+
|
194 |
+
return msgs
|
195 |
+
|
196 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
197 |
+
from .utils.multiple_choice import report_acc, report_acc_MMT, mcq_circular_eval, mcq_vanilla_eval
|
198 |
+
# assert dataset is not None
|
199 |
+
dataset_map = {
|
200 |
+
'MMBench_TEST_EN': 'MMBench', 'MMBench_TEST_EN_V11': 'MMBench_V11',
|
201 |
+
'MMBench_TEST_CN': 'MMBench_CN', 'MMBench_TEST_CN_V11': 'MMBench_CN_V11'
|
202 |
+
}
|
203 |
+
dataset = self.dataset_name
|
204 |
+
if dataset in dataset_map:
|
205 |
+
dataset = dataset_map[dataset]
|
206 |
+
nproc = judge_kwargs.pop('nproc', 4)
|
207 |
+
|
208 |
+
circular = False
|
209 |
+
if listinstr(['mmbench', 'ccbench'], dataset.lower()):
|
210 |
+
data = load(eval_file)
|
211 |
+
data['index'] = [int(x) for x in data['index']]
|
212 |
+
dump(data, eval_file)
|
213 |
+
circular = True
|
214 |
+
|
215 |
+
suffix = eval_file.split('.')[-1]
|
216 |
+
model = judge_kwargs.get('model', 'exact_matching')
|
217 |
+
assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']
|
218 |
+
name_str_map = {'chatgpt-0125': 'openai', 'gpt-4-0125': 'gpt4'}
|
219 |
+
name_str = name_str_map[model] if model in name_str_map else model
|
220 |
+
|
221 |
+
if model == 'exact_matching':
|
222 |
+
model = None
|
223 |
+
elif gpt_key_set():
|
224 |
+
model = build_judge(**judge_kwargs)
|
225 |
+
if not model.working():
|
226 |
+
warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
|
227 |
+
warnings.warn(DEBUG_MESSAGE)
|
228 |
+
model = None
|
229 |
+
else:
|
230 |
+
warnings.warn('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
|
231 |
+
model = None
|
232 |
+
|
233 |
+
result_file = eval_file.replace(f'.{suffix}', f'_{name_str}_result.pkl')
|
234 |
+
|
235 |
+
data = load(eval_file)
|
236 |
+
data = data.sort_values(by='index')
|
237 |
+
data['prediction'] = [str(x) for x in data['prediction']]
|
238 |
+
# If not choice label, then use lower case
|
239 |
+
for k in data.keys():
|
240 |
+
data[k.lower() if k not in list(string.ascii_uppercase) else k] = data.pop(k)
|
241 |
+
|
242 |
+
meta = self.data
|
243 |
+
meta_q_map = {x: y for x, y in zip(meta['index'], meta['question'])}
|
244 |
+
data_map = {x: y for x, y in zip(data['index'], data['question'])}
|
245 |
+
for k in data_map:
|
246 |
+
assert k in meta_q_map, (
|
247 |
+
f'eval_file should be the same as or a subset of dataset {self.dataset_name}'
|
248 |
+
)
|
249 |
+
|
250 |
+
if circular:
|
251 |
+
data = mcq_circular_eval(model, data, meta, nproc, result_file, self.dataset_name)
|
252 |
+
else:
|
253 |
+
data = mcq_vanilla_eval(model, data, meta, nproc, result_file, self.dataset_name)
|
254 |
+
|
255 |
+
# load split
|
256 |
+
dump(data, eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))
|
257 |
+
data = load(eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))
|
258 |
+
|
259 |
+
# May have different report acc functions for different datasets
|
260 |
+
if 'MMT' in dataset:
|
261 |
+
acc = report_acc_MMT(data)
|
262 |
+
else:
|
263 |
+
acc = report_acc(data)
|
264 |
+
|
265 |
+
score_file = eval_file.replace(f'.{suffix}', '_acc.csv')
|
266 |
+
dump(acc, score_file)
|
267 |
+
|
268 |
+
if dataset == 'AesBench_VAL':
|
269 |
+
warnings.warn('Note that AesBench VAL is just a toy version of AesBench TEST. For full results, \
|
270 |
+
please evaluate on AesBench TEST. The AesBench TEST dataset is more than 20 times \
|
271 |
+
larger than the VAL dataset and the leaderboard results are based on AesBench TEST.')
|
272 |
+
if dataset == 'VisOnlyQA-VLMEvalKit':
|
273 |
+
warnings.warn('Note that the results on VisOnlyQA-VLMEvalKit are different from the results on \
|
274 |
+
the original VisOnlyQA. VisOnlyQA-VLMEvalKit does not include the \
|
275 |
+
chemistry__shape_multi split and uses a different evaluation prompt. Please \
|
276 |
+
explicitly specify the version of the dataset when you report results.')
|
277 |
+
|
278 |
+
return acc
|
279 |
+
|
280 |
+
|
281 |
+
class MMMUDataset(ImageMCQDataset):
|
282 |
+
|
283 |
+
DATASET_URL = {
|
284 |
+
'MMMU_DEV_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/MMMU_DEV_VAL.tsv',
|
285 |
+
'MMMU_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/MMMU_TEST.tsv',
|
286 |
+
}
|
287 |
+
|
288 |
+
DATASET_MD5 = {
|
289 |
+
'MMMU_DEV_VAL': '521afc0f3bf341e6654327792781644d',
|
290 |
+
'MMMU_TEST': 'c19875d11a2d348d07e5eb4bdf33166d',
|
291 |
+
}
|
292 |
+
|
293 |
+
@staticmethod
|
294 |
+
def split_MMMU(msgs):
|
295 |
+
text, images = None, []
|
296 |
+
for s in msgs:
|
297 |
+
if s['type'] == 'image':
|
298 |
+
images.append(s['value'])
|
299 |
+
elif s['type'] == 'text':
|
300 |
+
assert text is None
|
301 |
+
text = s['value']
|
302 |
+
text_segs = text.split('<image ')
|
303 |
+
if len(text_segs) == 1:
|
304 |
+
return msgs
|
305 |
+
|
306 |
+
segs = [dict(type='text', value=text_segs[0])]
|
307 |
+
for i, seg in enumerate(text_segs):
|
308 |
+
if i == 0:
|
309 |
+
continue
|
310 |
+
assert istype(seg[0], int) and seg[1] == '>'
|
311 |
+
image_idx = int(seg[0]) - 1
|
312 |
+
segs.append(dict(type='image', value=images[image_idx]))
|
313 |
+
segs.append(dict(type='text', value=seg[2:]))
|
314 |
+
return segs
|
315 |
+
|
316 |
+
def build_prompt(self, line):
|
317 |
+
msgs = super().build_prompt(line)
|
318 |
+
msgs = self.split_MMMU(msgs)
|
319 |
+
return msgs
|
320 |
+
|
321 |
+
|
322 |
+
class MUIRDataset(ImageMCQDataset):
|
323 |
+
|
324 |
+
DATASET_URL = {
|
325 |
+
'MUIRBench': 'http://opencompass.openxxlab.com/utils/VLMEval/MUIRBench.tsv'
|
326 |
+
}
|
327 |
+
|
328 |
+
DATASET_MD5 = {
|
329 |
+
'MUIRBench': '2e5e6fd7699761b08a7cb3ab8c0c2ec8'
|
330 |
+
}
|
331 |
+
|
332 |
+
@staticmethod
|
333 |
+
def split_MUIR(msgs):
|
334 |
+
text, images = None, []
|
335 |
+
|
336 |
+
# Separate images and text from msgs
|
337 |
+
for s in msgs:
|
338 |
+
if s['type'] == 'image':
|
339 |
+
images.append(s['value'])
|
340 |
+
elif s['type'] == 'text':
|
341 |
+
assert text is None # Ensure only one text entry is expected
|
342 |
+
text = s['value']
|
343 |
+
|
344 |
+
# Split text by <image> tags
|
345 |
+
text_segs = text.split('<image>')
|
346 |
+
|
347 |
+
# Initialize the segments list
|
348 |
+
segs = []
|
349 |
+
|
350 |
+
# Iterate through the text segments and images
|
351 |
+
for i, seg in enumerate(text_segs):
|
352 |
+
# Append the image if this is not the first segment and there are still images left
|
353 |
+
if i > 0 and i - 1 < len(images):
|
354 |
+
segs.append(dict(type='image', value=images[i - 1]))
|
355 |
+
# Append the text segment (if it's non-empty)
|
356 |
+
if len(seg) > 0:
|
357 |
+
segs.append(dict(type='text', value=seg))
|
358 |
+
|
359 |
+
return segs
|
360 |
+
|
361 |
+
def build_prompt(self, line):
|
362 |
+
|
363 |
+
if isinstance(line, int):
|
364 |
+
line = self.data.iloc[line]
|
365 |
+
|
366 |
+
if self.meta_only:
|
367 |
+
tgt_path = toliststr(line['image_path'])
|
368 |
+
else:
|
369 |
+
tgt_path = self.dump_image(line)
|
370 |
+
|
371 |
+
question = line['question']
|
372 |
+
options = {
|
373 |
+
cand: line[cand]
|
374 |
+
for cand in string.ascii_uppercase
|
375 |
+
if cand in line and not pd.isna(line[cand])
|
376 |
+
}
|
377 |
+
# options_prompt = ''
|
378 |
+
options_prompt = '\n'.join([f'{key}. {item}' for key, item in options.items()])
|
379 |
+
# for key, item in options.items():
|
380 |
+
# options_prompt += f'{key}. {item}\n'
|
381 |
+
|
382 |
+
prompt = ''
|
383 |
+
|
384 |
+
prompt += f'{question}\n'
|
385 |
+
if len(options):
|
386 |
+
prompt += options_prompt
|
387 |
+
prompt += "\nAnswer with the option's letter from the given choices directly."
|
388 |
+
|
389 |
+
msgs = []
|
390 |
+
if isinstance(tgt_path, list):
|
391 |
+
msgs.extend([dict(type='image', value=p) for p in tgt_path])
|
392 |
+
else:
|
393 |
+
msgs = [dict(type='image', value=tgt_path)]
|
394 |
+
msgs.append(dict(type='text', value=prompt))
|
395 |
+
|
396 |
+
msgs = self.split_MUIR(msgs)
|
397 |
+
return msgs
|
398 |
+
|
399 |
+
|
400 |
+
class GMAIMMBenchDataset(ImageMCQDataset):
|
401 |
+
|
402 |
+
DATASET_URL = {
|
403 |
+
'GMAI-MMBench_VAL': 'https://huggingface.co/datasets/VLMEval/GMAI-MMBench/resolve/main/GMAI-MMBench_VAL.tsv',
|
404 |
+
'GMAI_mm_bench_TEST_part_1': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_1.tsv', # noqa: E501
|
405 |
+
'GMAI_mm_bench_TEST_part_2': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_2.tsv', # noqa: E501
|
406 |
+
'GMAI_mm_bench_TEST_part_3': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_3.tsv', # noqa: E501
|
407 |
+
'GMAI_mm_bench_TEST_part_4': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_4.tsv', # noqa: E501
|
408 |
+
'GMAI_mm_bench_TEST_part_5': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_5.tsv', # noqa: E501
|
409 |
+
'GMAI_mm_bench_TEST_part_6': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_6.tsv', # noqa: E501
|
410 |
+
'GMAI_mm_bench_TEST_part_7': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_7.tsv', # noqa: E501
|
411 |
+
'GMAI_mm_bench_TEST_part_8': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_8.tsv', # noqa: E501
|
412 |
+
'GMAI_mm_bench_TEST_part_9': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_9.tsv', # noqa: E501
|
413 |
+
'GMAI_mm_bench_TEST_part_10': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_10.tsv', # noqa: E501
|
414 |
+
'GMAI_mm_bench_TEST_part_11': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_11.tsv', # noqa: E501
|
415 |
+
}
|
416 |
+
|
417 |
+
DATASET_MD5 = {
|
418 |
+
'GMAI-MMBench_VAL': '254bd581627866f1c499d3d6b4422324',
|
419 |
+
'GMAI_mm_bench_TEST_part_1': '900d735231230a63f4ed45665c078ef4',
|
420 |
+
'GMAI_mm_bench_TEST_part_2': '1b27ab621386945d7e4a765ad2d22b0e',
|
421 |
+
'GMAI_mm_bench_TEST_part_3': '44bdc2b6267dd505d529b8cad06f0fb2',
|
422 |
+
'GMAI_mm_bench_TEST_part_4': '5a04a04fcac9f1466709f242fdb80acb',
|
423 |
+
'GMAI_mm_bench_TEST_part_5': 'c70baf8909eda9af0ddeab275c721336',
|
424 |
+
'GMAI_mm_bench_TEST_part_6': '825abc39596b644dead9350d0cfa3b96',
|
425 |
+
'GMAI_mm_bench_TEST_part_7': 'defb8aed2fb77365a76b6b9abd6a2701',
|
426 |
+
'GMAI_mm_bench_TEST_part_8': 'ff490d60b85f2bb0abb67a435b298c65',
|
427 |
+
'GMAI_mm_bench_TEST_part_9': 'ff67c86f40da93b09139ac1d1ba5dc6b',
|
428 |
+
'GMAI_mm_bench_TEST_part_10': '3dae94627b9ac0fe00180d4780fbf6dc',
|
429 |
+
'GMAI_mm_bench_TEST_part_11': 'd08dc813f0eb6bbab63cae2a9d113c4b',
|
430 |
+
}
|
431 |
+
|
432 |
+
@classmethod
|
433 |
+
def supported_datasets(cls):
|
434 |
+
return ['GMAI-MMBench_VAL', 'GMAI-MMBench_TEST']
|
435 |
+
|
436 |
+
def load_data(self, dataset):
|
437 |
+
if dataset == 'GMAI-MMBench_VAL':
|
438 |
+
data_path = osp.join(LMUDataRoot(), f'{dataset}.tsv')
|
439 |
+
if file_size(data_path, 'GB') > 1:
|
440 |
+
local_path = data_path.replace('.tsv', '_local.tsv')
|
441 |
+
if not osp.exists(local_path) or os.environ.get('FORCE_LOCAL'):
|
442 |
+
from ..tools import LOCALIZE
|
443 |
+
LOCALIZE(data_path, local_path)
|
444 |
+
data_path = local_path
|
445 |
+
return load(data_path)
|
446 |
+
elif dataset == 'GMAI-MMBench_TEST':
|
447 |
+
dfs = []
|
448 |
+
for part_num in range(1, 12):
|
449 |
+
part_name = f'GMAI_mm_bench_TEST_part_{part_num}'
|
450 |
+
url = self.DATASET_URL[part_name]
|
451 |
+
file_md5 = self.DATASET_MD5.get(part_name)
|
452 |
+
tsv_path = osp.join(LMUDataRoot(), f'{part_name}.tsv')
|
453 |
+
if not osp.exists(tsv_path) or (file_md5 and md5(tsv_path) != file_md5):
|
454 |
+
download_file(url, filename=tsv_path)
|
455 |
+
local_path = tsv_path.replace('.tsv', '_local.tsv')
|
456 |
+
if not osp.exists(local_path) or os.environ.get('FORCE_LOCAL'):
|
457 |
+
from ..tools import LOCALIZE
|
458 |
+
LOCALIZE(tsv_path, local_path)
|
459 |
+
tsv_path = local_path
|
460 |
+
# 加载数据
|
461 |
+
df = load(tsv_path)
|
462 |
+
dfs.append(df)
|
463 |
+
# 合并所有数据
|
464 |
+
data = pd.concat(dfs, ignore_index=True)
|
465 |
+
return data
|
466 |
+
else:
|
467 |
+
raise ValueError(f"未知的数据集:{dataset}")
|
468 |
+
|
469 |
+
def report_acc_by_groups(self, df, group_column):
|
470 |
+
res = defaultdict(list)
|
471 |
+
|
472 |
+
# Check for the 'split' column
|
473 |
+
if 'split' in df:
|
474 |
+
splits = list(set(df['split']))
|
475 |
+
res['split'] = splits
|
476 |
+
else:
|
477 |
+
df['split'] = ['none'] * len(df)
|
478 |
+
res['split'] = ['none']
|
479 |
+
|
480 |
+
res['Overall'] = [np.mean(df[df['split'] == sp]['hit']) for sp in res['split']]
|
481 |
+
|
482 |
+
if group_column not in df:
|
483 |
+
raise ValueError(f"Column '{group_column}' not found in dataframe.") # noqa: E713
|
484 |
+
|
485 |
+
abilities = list(set(df[group_column]))
|
486 |
+
abilities = ['None' if isinstance(ab, float) and pd.isna(ab) else ab for ab in abilities]
|
487 |
+
abilities.sort()
|
488 |
+
|
489 |
+
for ab in abilities:
|
490 |
+
ab_name = ab
|
491 |
+
sub_df = df[df[group_column] == ab]
|
492 |
+
res[ab_name] = [np.mean(sub_df[sub_df['split'] == sp]['hit']) for sp in res['split']]
|
493 |
+
|
494 |
+
return pd.DataFrame(res)
|
495 |
+
|
496 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
497 |
+
from .utils.multiple_choice import report_acc, mcq_vanilla_eval
|
498 |
+
nproc = judge_kwargs.pop('nproc', 4)
|
499 |
+
|
500 |
+
suffix = eval_file.split('.')[-1]
|
501 |
+
model = judge_kwargs.get('model', 'exact_matching')
|
502 |
+
assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']
|
503 |
+
name_str_map = {'chatgpt-0125': 'openai', 'gpt-4-0125': 'gpt4'}
|
504 |
+
name_str = name_str_map[model] if model in name_str_map else model
|
505 |
+
|
506 |
+
if model == 'exact_matching':
|
507 |
+
model = None
|
508 |
+
elif gpt_key_set():
|
509 |
+
model = build_judge(**judge_kwargs)
|
510 |
+
if not model.working():
|
511 |
+
warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
|
512 |
+
warnings.warn(DEBUG_MESSAGE)
|
513 |
+
model = None
|
514 |
+
else:
|
515 |
+
warnings.warn('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
|
516 |
+
model = None
|
517 |
+
|
518 |
+
result_file = eval_file.replace(f'.{suffix}', f'_{name_str}_result.pkl')
|
519 |
+
|
520 |
+
data = load(eval_file)
|
521 |
+
data = data.sort_values(by='index')
|
522 |
+
data['prediction'] = [str(x) for x in data['prediction']]
|
523 |
+
# If not choice label, then use lower case
|
524 |
+
for k in data.keys():
|
525 |
+
data[k.lower() if k not in list(string.ascii_uppercase) else k] = data.pop(k)
|
526 |
+
|
527 |
+
meta = self.data
|
528 |
+
meta_q_map = {x: y for x, y in zip(meta['index'], meta['question'])}
|
529 |
+
data_map = {x: y for x, y in zip(data['index'], data['question'])}
|
530 |
+
for k in data_map:
|
531 |
+
assert k in meta_q_map, (
|
532 |
+
f'eval_file should be the same as or a subset of dataset {self.dataset_name}'
|
533 |
+
)
|
534 |
+
|
535 |
+
data = mcq_vanilla_eval(model, data, meta, nproc, result_file, self.dataset_name)
|
536 |
+
|
537 |
+
# load split
|
538 |
+
dump(data, eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))
|
539 |
+
data = load(eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))
|
540 |
+
|
541 |
+
acc = report_acc(data)
|
542 |
+
|
543 |
+
for group_col in ['clinical vqa task', 'department', 'perceptual granularity']:
|
544 |
+
acc_grouped = self.report_acc_by_groups(data, group_col)
|
545 |
+
score_file_grouped = eval_file.replace(f'.{suffix}', f'_{group_col}_acc.csv')
|
546 |
+
dump(acc_grouped, score_file_grouped)
|
547 |
+
|
548 |
+
return acc
|
549 |
+
|
550 |
+
|
551 |
+
class MMERealWorld(ImageMCQDataset):
|
552 |
+
|
553 |
+
TYPE = 'MMERealWorld'
|
554 |
+
|
555 |
+
DATASET_MD5 = {
|
556 |
+
'MME-RealWorld': '271c33ec814c39533c467ec6fb8a6f36',
|
557 |
+
'MME-RealWorld-Lite': '4c17057d7d3b6c4a0d4397c3dae0881c',
|
558 |
+
'MME-RealWorld-CN': 'daaa763d52a760a38606d5dedb3fe444',
|
559 |
+
}
|
560 |
+
SYS = {
|
561 |
+
'MME-RealWorld': (
|
562 |
+
'Select the best answer to the above multiple-choice question based on the image. '
|
563 |
+
'Respond with only the letter (A, B, C, D, or E) of the correct option. \n'
|
564 |
+
'The best answer is:'
|
565 |
+
),
|
566 |
+
'MME-RealWorld-Lite': (
|
567 |
+
'Select the best answer to the above multiple-choice question based on the image. '
|
568 |
+
'Respond with only the letter (A, B, C, D, or E) of the correct option. \n'
|
569 |
+
'The best answer is:'
|
570 |
+
),
|
571 |
+
'MME-RealWorld-CN': (
|
572 |
+
'根据图像选择上述多项选择题的最佳答案。只需回答正确选项的字母(A, B, C, D 或 E)。\n'
|
573 |
+
'最佳答案为:'
|
574 |
+
),
|
575 |
+
}
|
576 |
+
|
577 |
+
@classmethod
|
578 |
+
def supported_datasets(cls):
|
579 |
+
return ['MME-RealWorld', 'MME-RealWorld-CN', 'MME-RealWorld-Lite',]
|
580 |
+
|
581 |
+
def load_data(
|
582 |
+
self, dataset="MME-RealWorld", repo_id="yifanzhang114/MME-RealWorld-Base64"
|
583 |
+
):
|
584 |
+
|
585 |
+
def check_integrity(pth):
|
586 |
+
data_file = osp.join(pth, f"{dataset}.tsv")
|
587 |
+
|
588 |
+
if not os.path.exists(data_file):
|
589 |
+
return False
|
590 |
+
|
591 |
+
if md5(data_file) != self.DATASET_MD5[dataset]:
|
592 |
+
return False
|
593 |
+
return True
|
594 |
+
|
595 |
+
def generate_tsv(pth):
|
596 |
+
tsv_file = os.path.join(pth, f"{dataset}.tsv")
|
597 |
+
|
598 |
+
if os.path.exists(tsv_file):
|
599 |
+
print(f"{tsv_file} already exists.")
|
600 |
+
return
|
601 |
+
|
602 |
+
json_dir = os.path.join(pth, dataset)
|
603 |
+
json_files = [f for f in os.listdir(json_dir) if f.endswith(".json")]
|
604 |
+
|
605 |
+
data_list = []
|
606 |
+
for json_file in json_files:
|
607 |
+
with open(os.path.join(json_dir, json_file), "r") as f:
|
608 |
+
data = json.load(f)
|
609 |
+
for item in tqdm(data):
|
610 |
+
choice_prompt = (
|
611 |
+
"The choices are listed below:\n"
|
612 |
+
if dataset in ["MME-RealWorld", "MME-RealWorld-Lite"]
|
613 |
+
else "选项如下所示:\n"
|
614 |
+
)
|
615 |
+
data_list.append(
|
616 |
+
{
|
617 |
+
"index": item["index"],
|
618 |
+
"image": item["image"],
|
619 |
+
"question": item["question"],
|
620 |
+
"multi-choice options": choice_prompt
|
621 |
+
+ "\n".join(item["multi-choice options"]),
|
622 |
+
"A": item["multi-choice options"][0][4:],
|
623 |
+
"B": item["multi-choice options"][1][4:],
|
624 |
+
"C": item["multi-choice options"][2][4:],
|
625 |
+
"D": item["multi-choice options"][3][4:],
|
626 |
+
"E": item["multi-choice options"][4][4:],
|
627 |
+
"answer": item["answer"],
|
628 |
+
"category": item["category"],
|
629 |
+
"l2-category": item["l2-category"],
|
630 |
+
}
|
631 |
+
)
|
632 |
+
df = pd.DataFrame(data_list)
|
633 |
+
df.to_csv(tsv_file, sep="\t", index=False)
|
634 |
+
print(f"TSV file saved to {tsv_file}")
|
635 |
+
|
636 |
+
# Check if dataset is cached and has integrity
|
637 |
+
if dataset == "MME-RealWorld-Lite":
|
638 |
+
url = 'https://huggingface.co/datasets/yifanzhang114/MME-RealWorld-Base64/resolve/main/mme_realworld_lite.tsv' # noqa: E501
|
639 |
+
file_md5 = (
|
640 |
+
self.DATASET_MD5[dataset] if dataset in self.DATASET_MD5 else None
|
641 |
+
)
|
642 |
+
datas = self.prepare_tsv(url, file_md5)
|
643 |
+
choice_prompt = "The choices are listed below:\n"
|
644 |
+
for index, item in datas.iterrows():
|
645 |
+
options = eval(item["multi-choice options"])
|
646 |
+
datas.loc[index, "multi-choice options"] = choice_prompt + "\n".join(
|
647 |
+
options
|
648 |
+
)
|
649 |
+
datas.loc[index, "A"] = options[0][4:]
|
650 |
+
datas.loc[index, "B"] = options[1][4:]
|
651 |
+
datas.loc[index, "C"] = options[2][4:]
|
652 |
+
datas.loc[index, "D"] = options[3][4:]
|
653 |
+
datas.loc[index, "E"] = options[4][4:]
|
654 |
+
return datas
|
655 |
+
|
656 |
+
update_flag = False
|
657 |
+
cache_path = get_cache_path(repo_id)
|
658 |
+
if cache_path is not None and check_integrity(cache_path):
|
659 |
+
dataset_path = cache_path
|
660 |
+
print(f"Using cached dataset from {cache_path}")
|
661 |
+
else:
|
662 |
+
from huggingface_hub import snapshot_download
|
663 |
+
|
664 |
+
# Download or find the dataset path
|
665 |
+
dataset_path = snapshot_download(repo_id=repo_id, repo_type="dataset")
|
666 |
+
generate_tsv(dataset_path)
|
667 |
+
update_flag = True
|
668 |
+
|
669 |
+
data_path = os.path.join(dataset_path, f"{dataset}.tsv")
|
670 |
+
if file_size(data_path, "GB") > 1:
|
671 |
+
local_path = data_path.replace(".tsv", "_local.tsv")
|
672 |
+
if (
|
673 |
+
not osp.exists(local_path)
|
674 |
+
or os.environ.get("FORCE_LOCAL", None)
|
675 |
+
or update_flag
|
676 |
+
):
|
677 |
+
from vlmeval.tools import LOCALIZE
|
678 |
+
|
679 |
+
LOCALIZE(data_path, local_path)
|
680 |
+
data_path = local_path
|
681 |
+
return load(data_path)
|
682 |
+
|
683 |
+
def post_build(self, dataset):
|
684 |
+
self.TYPE = 'MMERealWorld'
|
685 |
+
|
686 |
+
# Given one data record, return the built prompt (a multi-modal message), can override
|
687 |
+
def build_prompt(self, line):
|
688 |
+
if isinstance(line, int):
|
689 |
+
line = self.data.iloc[line]
|
690 |
+
|
691 |
+
if self.meta_only:
|
692 |
+
tgt_path = toliststr(line['image_path'])
|
693 |
+
else:
|
694 |
+
tgt_path = self.dump_image(line)
|
695 |
+
|
696 |
+
question = line['question']
|
697 |
+
|
698 |
+
choice_prompt = line['multi-choice options'] + '\n'
|
699 |
+
question += ' ' + choice_prompt + self.SYS[self.dataset_name]
|
700 |
+
|
701 |
+
msgs = []
|
702 |
+
if isinstance(tgt_path, list):
|
703 |
+
msgs.extend([dict(type='image', value=p) for p in tgt_path])
|
704 |
+
else:
|
705 |
+
msgs = [dict(type='image', value=tgt_path)]
|
706 |
+
msgs.append(dict(type='text', value=question))
|
707 |
+
return msgs
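
For a single-image record, the method above yields one image entry followed by one text entry that concatenates the question, the choice block, and the dataset-specific system suffix. A hypothetical example of the resulting message list (path and question invented for illustration):

msgs = [
    dict(type='image', value='images/MME-RealWorld/123.jpg'),  # hypothetical dumped image path
    dict(type='text', value=(
        'How many lanes are visible? '
        'The choices are listed below:\n(A) 2\n(B) 3\n(C) 4\n(D) 5\n(E) 6\n'
        '<dataset-specific answer instruction from self.SYS>'   # abbreviated suffix
    )),
]
assert msgs[-1]['type'] == 'text'
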
|
708 |
+
|
709 |
+
# It returns a dictionary
|
710 |
+
@classmethod
|
711 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
712 |
+
from .utils.multiple_choice import extract_characters_regex, get_dimension_rating
|
713 |
+
assert eval_file.endswith('.xlsx'), 'data file should be an xlsx file'
|
714 |
+
FAIL_MSG = 'Failed to obtain answer via API.'
|
715 |
+
tmp_file = eval_file.replace('.xlsx', '_tmp.pkl')
|
716 |
+
tgt_file = eval_file.replace('.xlsx', '_rating.json')
|
717 |
+
score_file = eval_file.replace('.xlsx', '_score.xlsx')
|
718 |
+
|
719 |
+
if not osp.exists(score_file):
|
720 |
+
|
721 |
+
res = {} if not osp.exists(tmp_file) else load(tmp_file)
|
722 |
+
res = {k: v for k, v in res.items() if FAIL_MSG not in v}
|
723 |
+
|
724 |
+
data = load(eval_file)
|
725 |
+
cnt_rejected = 0
|
726 |
+
data_un = data[~pd.isna(data['prediction'])]
|
727 |
+
|
728 |
+
for idx in data['index']:
|
729 |
+
ans = data.loc[data['index'] == idx, 'answer'].values[0]
|
730 |
+
pred = data.loc[data['index'] == idx, 'prediction'].values[0]
|
731 |
+
|
732 |
+
extract_pred = extract_characters_regex(pred)
|
733 |
+
if extract_pred == '':
|
734 |
+
cnt_rejected += 1
|
735 |
+
data.loc[data['index'] == idx, 'score'] = 0
|
736 |
+
else:
|
737 |
+
data.loc[data['index'] == idx, 'score'] = int(extract_pred == ans)
|
738 |
+
|
739 |
+
print(
|
740 |
+
f'Among {len(data)} questions, failed to obtain prediction for {len(data) - len(data_un)} questions, '
|
741 |
+
f'failed to obtain the score for another {cnt_rejected} questions. '
|
742 |
+
f'Those questions will be counted as 0 score in ALL rating.'
|
743 |
+
)
|
744 |
+
|
745 |
+
dump(data, score_file)
|
746 |
+
|
747 |
+
rating = get_dimension_rating(score_file)
|
748 |
+
dump(rating, tgt_file)
|
749 |
+
return rating
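
Scoring hinges on pulling a single option letter out of the free-form prediction; empty extractions are counted as rejected and scored 0. The helper below is an illustrative stand-in for extract_characters_regex, not the upstream implementation:

import re

def extract_letter(pred):
    # Illustrative only: grab the first standalone A-E letter in the prediction.
    match = re.search(r'\b([A-E])\b', str(pred).upper())
    return match.group(1) if match else ''

print(extract_letter('The answer is (C).'))  # 'C' -> scored as int('C' == answer)
print(extract_letter('I cannot tell.'))      # ''  -> counted as rejected, score 0
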
|
750 |
+
|
751 |
+
|
752 |
+
class HRBenchDataset(ImageMCQDataset):
|
753 |
+
|
754 |
+
DATASET_URL = {
|
755 |
+
'HRBench4K': 'https://huggingface.co/datasets/DreamMr/HR-Bench/resolve/main/hr_bench_4k.tsv',
|
756 |
+
'HRBench8K': 'https://huggingface.co/datasets/DreamMr/HR-Bench/resolve/main/hr_bench_8k.tsv',
|
757 |
+
}
|
758 |
+
|
759 |
+
DATASET_MD5 = {
|
760 |
+
'HRBench4K': 'f6b041b03d49543494b8a56d2e35be65',
|
761 |
+
'HRBench8K': '274c9c7f89329b804a4723178a00219c',
|
762 |
+
}
|
763 |
+
|
764 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
765 |
+
assert os.path.exists(eval_file), '{} does not exist!'.format(eval_file)
|
766 |
+
from .utils.multiple_choice import mcq_vanilla_eval
|
767 |
+
from .utils.hrbench import report_acc_hrbench
|
768 |
+
nproc = judge_kwargs.pop('nproc', 4)
|
769 |
+
|
770 |
+
suffix = eval_file.split('.')[-1]
|
771 |
+
model = judge_kwargs.get('model', 'exact_matching')
|
772 |
+
assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']
|
773 |
+
name_str_map = {'chatgpt-0125': 'openai', 'gpt-4-0125': 'gpt4'}
|
774 |
+
name_str = name_str_map[model] if model in name_str_map else model
|
775 |
+
|
776 |
+
if model == 'exact_matching':
|
777 |
+
model = None
|
778 |
+
elif gpt_key_set():
|
779 |
+
model = build_judge(**judge_kwargs)
|
780 |
+
if not model.working():
|
781 |
+
warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
|
782 |
+
warnings.warn(DEBUG_MESSAGE)
|
783 |
+
model = None
|
784 |
+
else:
|
785 |
+
warnings.warn('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
|
786 |
+
model = None
|
787 |
+
|
788 |
+
result_file = eval_file.replace(f'.{suffix}', f'_{name_str}_result.pkl')
|
789 |
+
|
790 |
+
data = load(eval_file)
|
791 |
+
data = data.sort_values(by='index')
|
792 |
+
data['prediction'] = [str(x) for x in data['prediction']]
|
793 |
+
# Lower-case all column names except the choice labels (A, B, C, ...)
|
794 |
+
for k in data.keys():
|
795 |
+
data[k.lower() if k not in list(string.ascii_uppercase) else k] = data.pop(k)
|
796 |
+
|
797 |
+
meta = self.data
|
798 |
+
meta_q_map = {x: y for x, y in zip(meta['index'], meta['question'])}
|
799 |
+
data_map = {x: y for x, y in zip(data['index'], data['question'])}
|
800 |
+
for k in data_map:
|
801 |
+
assert k in meta_q_map, (
|
802 |
+
f'eval_file should be the same as or a subset of dataset {self.dataset_name}'
|
803 |
+
)
|
804 |
+
|
805 |
+
score_file = eval_file.replace(f'.{suffix}', '_acc.csv')
|
806 |
+
|
807 |
+
if osp.exists(score_file):
|
808 |
+
acc = load(score_file)
|
809 |
+
return acc
|
810 |
+
data = mcq_vanilla_eval(model, data, meta, nproc, result_file, self.dataset_name)
|
811 |
+
dump(data, eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))
|
812 |
+
data = load(eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))
|
813 |
+
|
814 |
+
acc = report_acc_hrbench(data)
|
815 |
+
|
816 |
+
score_file = eval_file.replace(f'.{suffix}', '_acc.csv')
|
817 |
+
dump(acc, score_file)
|
818 |
+
|
819 |
+
return acc
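
The judge selection above degrades gracefully: an explicit 'exact_matching' request, a missing OPENAI_API_KEY, or an unreachable judge all end with model = None, i.e. plain string matching. A paraphrase of that fallback chain with boolean stand-ins for gpt_key_set() and model.working():

def pick_judge(requested, key_set, judge_working):
    # None means "fall back to exact matching", as in the branch above.
    if requested == 'exact_matching':
        return None
    if not key_set or not judge_working:
        return None
    return requested  # stands in for the build_judge(**judge_kwargs) instance

print(pick_judge('chatgpt-0125', key_set=True, judge_working=True))    # 'chatgpt-0125'
print(pick_judge('chatgpt-0125', key_set=False, judge_working=False))  # None
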
|
820 |
+
|
821 |
+
|
822 |
+
class CustomMCQDataset(ImageMCQDataset):
|
823 |
+
|
824 |
+
def load_data(self, dataset):
|
825 |
+
data_path = osp.join(LMUDataRoot(), f'{dataset}.tsv')
|
826 |
+
|
827 |
+
if file_size(data_path, 'GB') > 1:
|
828 |
+
local_path = data_path.replace('.tsv', '_local.tsv')
|
829 |
+
if not osp.exists(local_path) or os.environ.get('FORCE_LOCAL', None):
|
830 |
+
from ..tools import LOCALIZE
|
831 |
+
LOCALIZE(data_path, local_path)
|
832 |
+
data_path = local_path
|
833 |
+
return load(data_path)
|
834 |
+
|
835 |
+
|
836 |
+
class NaturalBenchDataset(ImageMCQDataset):
|
837 |
+
|
838 |
+
DATASET_URL = {
|
839 |
+
'NaturalBenchDataset': (
|
840 |
+
'https://huggingface.co/datasets/BaiqiL/'
|
841 |
+
'NaturalBench/resolve/main/NaturalBenchDataset.tsv'
|
842 |
+
),
|
843 |
+
}
|
844 |
+
DATASET_MD5 = {
|
845 |
+
'NaturalBenchDataset':'dbe25b044bc35696426381e9ba4fe930',
|
846 |
+
}
|
847 |
+
|
848 |
+
def build_prompt(self, line):
|
849 |
+
SUFFIX_FOR_VQA = {
|
850 |
+
"yes_no": "Please answer Yes or No.",
|
851 |
+
"multiple_choice": "Please output the letter corresponding to the correct option."
|
852 |
+
}
|
853 |
+
if isinstance(line, int):
|
854 |
+
line = self.data.iloc[line]
|
855 |
+
|
856 |
+
if self.meta_only:
|
857 |
+
tgt_path = toliststr(line['image_path'])
|
858 |
+
else:
|
859 |
+
tgt_path = self.dump_image(line)
|
860 |
+
|
861 |
+
question = line['question']
|
862 |
+
prompt = f'{question} {SUFFIX_FOR_VQA[line["type"]]}'
|
863 |
+
msgs = []
|
864 |
+
if isinstance(tgt_path, list):
|
865 |
+
msgs.extend([dict(type='image', value=p) for p in tgt_path])
|
866 |
+
else:
|
867 |
+
msgs = [dict(type='image', value=tgt_path)]
|
868 |
+
msgs.append(dict(type='text', value=prompt))
|
869 |
+
|
870 |
+
return msgs
|
871 |
+
|
872 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
873 |
+
from .utils.naturalbench import extract_answer, get_scores
|
874 |
+
|
875 |
+
data = load(eval_file)
|
876 |
+
data = data.sort_values(by='index')
|
877 |
+
predictions = [str(x) for x in data['prediction']]
|
878 |
+
answers = [str(x) for x in data['answer']]
|
879 |
+
indexs = [str(x) for x in data['index']]
|
880 |
+
meta = self.data
|
881 |
+
types = [str(x) for x in meta['type']]
|
882 |
+
results = {}
|
883 |
+
assert len(predictions) == len(answers) == len(indexs) == len(types) == (1900 * 4)
|
884 |
+
number_answered_samples = len(predictions) // 4
|
885 |
+
for i in range(number_answered_samples):
|
886 |
+
results[i] = {
|
887 |
+
"q0_i0": extract_answer(predictions[i * 4], types[i * 4]),
|
888 |
+
"q0_i1": extract_answer(predictions[i * 4 + 1], types[i * 4 + 1]),
|
889 |
+
"q1_i0": extract_answer(predictions[i * 4 + 2], types[i * 4 + 2]),
|
890 |
+
"q1_i1": extract_answer(predictions[i * 4 + 3], types[i * 4 + 3])
|
891 |
+
}
|
892 |
+
|
893 |
+
scores = get_scores(results)
|
894 |
+
print(scores)
|
895 |
+
score_file = 'NaturalBench_acc.csv'
|
896 |
+
df = pd.DataFrame(list(scores.items()), columns=['Metric', 'Score'])
|
897 |
+
dump(df, score_file)
|
898 |
+
|
899 |
+
return scores
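
Every NaturalBench sample occupies four consecutive rows (two questions x two images), which is why the assertion expects exactly 1900 x 4 predictions and why rows are regrouped in blocks of four. A sketch of that regrouping with hypothetical extracted answers:

predictions = ['Yes', 'No', 'No', 'Yes']        # rows i*4 .. i*4+3 for one sample
keys = ['q0_i0', 'q0_i1', 'q1_i0', 'q1_i1']     # question x image combinations
results = {0: dict(zip(keys, predictions))}
print(results)
# {0: {'q0_i0': 'Yes', 'q0_i1': 'No', 'q1_i0': 'No', 'q1_i1': 'Yes'}}
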
|
VLMEvalKit/vlmeval/dataset/image_mt.py
ADDED
@@ -0,0 +1,128 @@
1 |
+
from .image_base import ImageBaseDataset
|
2 |
+
from .utils.judge_util import build_judge
|
3 |
+
from ..smp import *
|
4 |
+
from ..utils import track_progress_rich
|
5 |
+
|
6 |
+
|
7 |
+
class ImageMTDataset(ImageBaseDataset):
|
8 |
+
|
9 |
+
TYPE = 'MT'
|
10 |
+
|
11 |
+
def build_prompt(self, line):
|
12 |
+
if isinstance(line, int):
|
13 |
+
line = self.data.iloc[line]
|
14 |
+
|
15 |
+
if self.meta_only:
|
16 |
+
tgt_path = toliststr(line['image_path'])
|
17 |
+
else:
|
18 |
+
tgt_path = self.dump_image(line)
|
19 |
+
|
20 |
+
questions = toliststr(line['question'])
|
21 |
+
if 'answer' in line:
|
22 |
+
answers = toliststr(line['answer'])
|
23 |
+
else:
|
24 |
+
answers = [''] * len(questions)
|
25 |
+
assert len(questions) == len(answers)
|
26 |
+
|
27 |
+
dlgs, pics_number = [], 0
|
28 |
+
for i in range(len(questions)):
|
29 |
+
q, a = questions[i], answers[i]
|
30 |
+
if '<ImageHere>' in q:
|
31 |
+
content = []
|
32 |
+
tag_number = q.count('<ImageHere>')
|
33 |
+
images = tgt_path[pics_number: pics_number + tag_number]
|
34 |
+
pics_number += tag_number
|
35 |
+
q_split = q.split('<ImageHere>')
|
36 |
+
for i in range(tag_number):
|
37 |
+
qsp, im = q_split[i], images[i]
|
38 |
+
if qsp != '':
|
39 |
+
content.append(dict(type='text', value=qsp))
|
40 |
+
content.append(dict(type='image', value=im))
|
41 |
+
if q_split[-1] != '':
|
42 |
+
content.append(dict(type='text', value=q_split[-1]))
|
43 |
+
else:
|
44 |
+
content = [dict(type='text', value=q)]
|
45 |
+
dlgs.append(dict(role='user', content=content))
|
46 |
+
assert '<ImageHere>' not in a, 'We currently do not support images in the answer. '
|
47 |
+
content = [dict(type='text', value=a)]
|
48 |
+
dlgs.append(dict(role='assistant', content=content))
|
49 |
+
return dlgs
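
The '<ImageHere>' handling above interleaves text and image entries in the order they appear in the question. A standalone sketch of that splitting, on a hypothetical single-image turn:

q = 'Compare <ImageHere> with the previous chart and explain the trend.'
images = ['img_0.jpg']  # hypothetical dumped image path

content, parts = [], q.split('<ImageHere>')
for part, im in zip(parts[:-1], images):
    if part:
        content.append(dict(type='text', value=part))
    content.append(dict(type='image', value=im))
if parts[-1]:
    content.append(dict(type='text', value=parts[-1]))

print(content)
# [{'type': 'text', 'value': 'Compare '},
#  {'type': 'image', 'value': 'img_0.jpg'},
#  {'type': 'text', 'value': ' with the previous chart and explain the trend.'}]
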
|
50 |
+
|
51 |
+
|
52 |
+
class MMDUDataset(ImageMTDataset):
|
53 |
+
|
54 |
+
DATASET_URL = {'MMDU': 'https://opencompass.openxlab.space/utils/VLMEval/MMDU.tsv'}
|
55 |
+
DATASET_MD5 = {'MMDU': '848b635a88a078f49aebcc6e39792061'}
|
56 |
+
DIMS = [
|
57 |
+
'Creativity', 'Richness', 'Visual Perception', 'Logical Coherence',
|
58 |
+
'Answer Accuracy', 'Image Relationship Understanding', 'Overall Score'
|
59 |
+
]
|
60 |
+
|
61 |
+
def calculate_metric(self, ans):
|
62 |
+
all = defaultdict(lambda: 0)
|
63 |
+
tot = defaultdict(lambda: 0)
|
64 |
+
valid = defaultdict(lambda: 0)
|
65 |
+
for k in ans:
|
66 |
+
res = ans[k]['res']
|
67 |
+
assert isinstance(res, pd.DataFrame)
|
68 |
+
lt = len(res)
|
69 |
+
for i in range(lt):
|
70 |
+
line = res.iloc[i]
|
71 |
+
for k in self.DIMS:
|
72 |
+
tot[k] += 1
|
73 |
+
if k in line and line[k] is not None:
|
74 |
+
try:
|
75 |
+
score = int(line[k])
|
76 |
+
score = np.clip(score, 0, 10)
|
77 |
+
all[k] += score
|
78 |
+
valid[k] += 1
|
79 |
+
except Exception as e:
|
80 |
+
print(f'Failed to parse the score: {str(e)}')
|
81 |
+
sp1 = {'set': 'all'}
|
82 |
+
sp1.update({k: all[k] / tot[k] * 10 for k in self.DIMS})
|
83 |
+
sp2 = {'set': 'valid'}
|
84 |
+
sp2.update({k: all[k] / valid[k] * 10 for k in self.DIMS})
|
85 |
+
|
86 |
+
return pd.DataFrame([sp1, sp2])
|
87 |
+
|
88 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
89 |
+
suffix = eval_file.split('.')[-1]
|
90 |
+
model = judge_kwargs['model']
|
91 |
+
|
92 |
+
tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
|
93 |
+
score_file = eval_file.replace(f'.{suffix}', f'_{model}_score.csv')
|
94 |
+
nproc = judge_kwargs.pop('nproc', 4)
|
95 |
+
|
96 |
+
data = load(eval_file)
|
97 |
+
model = judge_kwargs.pop('model', 'gpt-4o')
|
98 |
+
judge_model = build_judge(model=model, **judge_kwargs)
|
99 |
+
|
100 |
+
lt = len(data)
|
101 |
+
lines = [data.iloc[i] for i in range(lt)]
|
102 |
+
tups = [(judge_model, line) for line in lines]
|
103 |
+
indices = [line['index'] for line in lines]
|
104 |
+
|
105 |
+
ans = {}
|
106 |
+
if osp.exists(tmp_file):
|
107 |
+
ans = load(tmp_file)
|
108 |
+
|
109 |
+
tups = [x for x, i in zip(tups, indices) if i not in ans]
|
110 |
+
indices = [i for i in indices if i not in ans]
|
111 |
+
|
112 |
+
from .utils.mmdu import mmdu_score
|
113 |
+
|
114 |
+
if len(indices):
|
115 |
+
new_results = track_progress_rich(
|
116 |
+
mmdu_score,
|
117 |
+
tups,
|
118 |
+
nproc=nproc,
|
119 |
+
chunksize=nproc,
|
120 |
+
keys=indices,
|
121 |
+
save=tmp_file,)
|
122 |
+
ans = load(tmp_file)
|
123 |
+
for k, v in zip(indices, new_results):
|
124 |
+
assert k in ans
|
125 |
+
|
126 |
+
metric = self.calculate_metric(ans)
|
127 |
+
dump(metric, score_file)
|
128 |
+
return metric
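
The metric computed above reports each dimension twice: the 'all' row divides the score sum by every judged turn, while the 'valid' row divides only by turns whose score parsed as an integer. A numeric sketch of the difference:

scores = [8, 9, None, 7]                 # hypothetical per-turn 'Creativity' judgements
parsed = [s for s in scores if s is not None]
print('all  :', sum(parsed) / len(scores) * 10)   # 60.0 -- unparsed turns count as 0
print('valid:', sum(parsed) / len(parsed) * 10)   # 80.0 -- unparsed turns are excluded
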
|
VLMEvalKit/vlmeval/dataset/image_vqa.py
ADDED
@@ -0,0 +1,1330 @@
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import tempfile
|
4 |
+
from functools import partial
|
5 |
+
|
6 |
+
import pandas as pd
|
7 |
+
|
8 |
+
from .image_base import ImageBaseDataset
|
9 |
+
from .utils import build_judge, DEBUG_MESSAGE
|
10 |
+
from ..smp import *
|
11 |
+
from ..utils import track_progress_rich
|
12 |
+
|
13 |
+
|
14 |
+
class ImageVQADataset(ImageBaseDataset):
|
15 |
+
TYPE = 'VQA'
|
16 |
+
|
17 |
+
DATASET_URL = {
|
18 |
+
'OCRVQA_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/OCRVQA_TEST.tsv',
|
19 |
+
'OCRVQA_TESTCORE': 'https://opencompass.openxlab.space/utils/VLMEval/OCRVQA_TESTCORE.tsv',
|
20 |
+
'TextVQA_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/TextVQA_VAL.tsv',
|
21 |
+
'DocVQA_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/DocVQA_VAL.tsv',
|
22 |
+
'DocVQA_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/DocVQA_TEST.tsv',
|
23 |
+
'InfoVQA_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/InfoVQA_VAL.tsv',
|
24 |
+
'InfoVQA_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/InfoVQA_TEST.tsv',
|
25 |
+
'ChartQA_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/ChartQA_TEST.tsv',
|
26 |
+
'GQA_TestDev_Balanced': 'https://opencompass.openxlab.space/utils/VLMEval/GQA_TestDev_Balanced.tsv',
|
27 |
+
}
|
28 |
+
|
29 |
+
DATASET_MD5 = {
|
30 |
+
'OCRVQA_TEST': 'ca46a6d74b403e9d6c0b670f6fc00db9',
|
31 |
+
'OCRVQA_TESTCORE': 'c5239fe77db8bdc1f2ad8e55e0d1fe97',
|
32 |
+
'TextVQA_VAL': 'b233b31f551bbf4056f2f955da3a92cd',
|
33 |
+
'DocVQA_VAL': 'd5ee77e1926ff10690d469c56b73eabf',
|
34 |
+
'DocVQA_TEST': '6a2f28cac26ef2d3447374e8c6f6c8e9',
|
35 |
+
'InfoVQA_VAL': '2342e9c225222f0ef4dec545ebb126fe',
|
36 |
+
'InfoVQA_TEST': 'df535bf51b88dc9718252c34131a6227',
|
37 |
+
'ChartQA_TEST': 'c902e0aa9be5582a7aad6dcf52734b42',
|
38 |
+
'GQA_TestDev_Balanced': 'fead7df22befc1ed3ca2b62ea26fa17b',
|
39 |
+
}
|
40 |
+
|
41 |
+
def build_prompt(self, line):
|
42 |
+
msgs = super().build_prompt(line)
|
43 |
+
assert msgs[-1]['type'] == 'text'
|
44 |
+
msgs[-1]['value'] += '\nAnswer the question using a single word or phrase.'
|
45 |
+
return msgs
|
46 |
+
|
47 |
+
# It returns a DataFrame
|
48 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
49 |
+
from .utils.vqa_eval import hit_calculate, process_line
|
50 |
+
|
51 |
+
data = load(eval_file)
|
52 |
+
dataset = self.dataset_name
|
53 |
+
assert 'answer' in data and 'prediction' in data
|
54 |
+
data['prediction'] = [str(x) for x in data['prediction']]
|
55 |
+
data['answer'] = [str(x) for x in data['answer']]
|
56 |
+
lt = len(data)
|
57 |
+
pool = mp.Pool(16)
|
58 |
+
lines = [data.iloc[i] for i in range(lt)]
|
59 |
+
if listinstr(['TextVQA'], dataset):
|
60 |
+
res = pool.map(partial(process_line, method='vqa_score'), lines)
|
61 |
+
elif listinstr(['ChartQA'], dataset):
|
62 |
+
res = pool.map(partial(process_line, method='relaxed_accuracy'), lines)
|
63 |
+
elif listinstr(['OCRVQA', 'GQA'], dataset):
|
64 |
+
res = pool.map(partial(process_line, method='accuracy'), lines)
|
65 |
+
elif listinstr(['DocVQA', 'InfoVQA'], dataset):
|
66 |
+
res = pool.map(partial(process_line, method='anls'), lines)
|
67 |
+
else:  # default: score with vqa_score
|
68 |
+
res = pool.map(process_line, lines)
|
69 |
+
hit = hit_calculate(res, dataset)
|
70 |
+
ret = dict()
|
71 |
+
if 'split' in data:
|
72 |
+
splits = set(data['split'])
|
73 |
+
for sp in splits:
|
74 |
+
sub = [r for l, r in zip(lines, res) if l['split'] == sp]
|
75 |
+
# [np.mean(x['match']) >= full_score_weight for x in sub]
|
76 |
+
hit = hit_calculate(sub, dataset)
|
77 |
+
ret[sp] = np.mean(hit) * 100
|
78 |
+
sub = [r for l, r in zip(lines, res)]
|
79 |
+
hit = hit_calculate(sub, dataset)
|
80 |
+
ret['Overall'] = np.mean(hit) * 100
|
81 |
+
else:
|
82 |
+
ret['Overall'] = np.mean(hit) * 100
|
83 |
+
if 'category' in data:
|
84 |
+
cates = list(set(data['category']))
|
85 |
+
cates.sort()
|
86 |
+
for c in cates:
|
87 |
+
sub = [r for l, r in zip(lines, res) if l['category'] == c]
|
88 |
+
# [np.mean(x['match']) >= full_score_weight for x in sub]
|
89 |
+
hit = hit_calculate(sub, dataset)
|
90 |
+
ret[c] = np.mean(hit) * 100
|
91 |
+
ret = d2df(ret)
|
92 |
+
ret = ret.round(2)
|
93 |
+
|
94 |
+
suffix = eval_file.split('.')[-1]
|
95 |
+
result_file = eval_file.replace(f'.{suffix}', '_acc.csv')
|
96 |
+
dump(ret, result_file)
|
97 |
+
return ret
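
The dispatch above chooses the scoring rule by substring-matching the dataset name. Written out as a plain mapping (illustrative only; the real code uses listinstr and falls back to vqa_score):

METRIC_BY_DATASET = {
    'TextVQA_VAL': 'vqa_score',            # soft VQA accuracy over the annotator answers
    'ChartQA_TEST': 'relaxed_accuracy',    # numeric answers allowed a 5% tolerance
    'OCRVQA_TEST': 'accuracy',
    'GQA_TestDev_Balanced': 'accuracy',
    'DocVQA_VAL': 'anls',                  # average normalized Levenshtein similarity
    'InfoVQA_TEST': 'anls',
}
print(METRIC_BY_DATASET.get('ChartQA_TEST', 'vqa_score'))  # relaxed_accuracy
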
|
98 |
+
|
99 |
+
|
100 |
+
class VizWiz(ImageBaseDataset):
|
101 |
+
TYPE = 'VQA'
|
102 |
+
DATASET_URL = {
|
103 |
+
'VizWiz': 'https://opencompass.openxlab.space/utils/VLMEval/VizWiz.tsv'
|
104 |
+
}
|
105 |
+
DATASET_MD5 = {
|
106 |
+
'VizWiz': 'fa4ac4164467563ed2fac6eac6631bd0'
|
107 |
+
}
|
108 |
+
|
109 |
+
@classmethod
|
110 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
111 |
+
from .utils.vqa_eval import hit_calculate, process_line
|
112 |
+
|
113 |
+
suffix = eval_file.split('.')[-1]
|
114 |
+
result_file = eval_file.replace(f'.{suffix}', '_acc.csv')
|
115 |
+
|
116 |
+
if not osp.exists(result_file):
|
117 |
+
data = load(eval_file)
|
118 |
+
assert 'answers' in data and 'prediction' in data
|
119 |
+
data['prediction'] = [str(x) for x in data['prediction']]
|
120 |
+
data['answer'] = [str(x) for x in data['answers']]
|
121 |
+
|
122 |
+
lt = len(data)
|
123 |
+
pool = mp.Pool(16)
|
124 |
+
lines = [data.iloc[i] for i in range(lt)]
|
125 |
+
res = pool.map(process_line, lines)
|
126 |
+
|
127 |
+
hit = hit_calculate(res, 'VizWiz')
|
128 |
+
ret = dict()
|
129 |
+
|
130 |
+
ret['Overall'] = np.mean(hit) * 100
|
131 |
+
ret = d2df(ret)
|
132 |
+
ret = ret.round(2)
|
133 |
+
|
134 |
+
dump(ret, result_file)
|
135 |
+
|
136 |
+
retz = pd.read_csv(result_file)
|
137 |
+
return retz
|
138 |
+
|
139 |
+
|
140 |
+
class OCRBench(ImageBaseDataset):
|
141 |
+
TYPE = 'VQA'
|
142 |
+
DATASET_URL = {
|
143 |
+
'OCRBench': 'https://opencompass.openxlab.space/utils/VLMEval/OCRBench.tsv'
|
144 |
+
}
|
145 |
+
DATASET_MD5 = {'OCRBench': 'e953d98a987cc6e26ef717b61260b778'}
|
146 |
+
|
147 |
+
# It returns a dictionary
|
148 |
+
@classmethod
|
149 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
150 |
+
OCRBench_score = {
|
151 |
+
'Regular Text Recognition': 0,
|
152 |
+
'Irregular Text Recognition': 0,
|
153 |
+
'Artistic Text Recognition': 0,
|
154 |
+
'Handwriting Recognition': 0,
|
155 |
+
'Digit String Recognition': 0,
|
156 |
+
'Non-Semantic Text Recognition': 0,
|
157 |
+
'Scene Text-centric VQA': 0,
|
158 |
+
'Doc-oriented VQA': 0,
|
159 |
+
'Key Information Extraction': 0,
|
160 |
+
'Handwritten Mathematical Expression Recognition': 0,
|
161 |
+
}
|
162 |
+
|
163 |
+
data = load(eval_file)
|
164 |
+
lt = len(data)
|
165 |
+
lines = [data.iloc[i] for i in range(lt)]
|
166 |
+
for i in tqdm(range(len(lines))):
|
167 |
+
line = lines[i]
|
168 |
+
predict = str(line['prediction'])
|
169 |
+
answers = eval(line['answer'])
|
170 |
+
category = line['category']
|
171 |
+
if category == 'Handwritten Mathematical Expression Recognition':
|
172 |
+
for j in range(len(answers)):
|
173 |
+
answer = answers[j].strip().replace('\n', ' ').replace(' ', '')
|
174 |
+
predict = predict.strip().replace('\n', ' ').replace(' ', '')
|
175 |
+
if answer in predict:
|
176 |
+
OCRBench_score[category] += 1
|
177 |
+
break
|
178 |
+
else:
|
179 |
+
for j in range(len(answers)):
|
180 |
+
answer = answers[j].lower().strip().replace('\n', ' ')
|
181 |
+
predict = predict.lower().strip().replace('\n', ' ')
|
182 |
+
if answer in predict:
|
183 |
+
OCRBench_score[category] += 1
|
184 |
+
break
|
185 |
+
|
186 |
+
final_score_dict = {}
|
187 |
+
final_score_dict['Text Recognition'] = \
|
188 |
+
(OCRBench_score['Regular Text Recognition'] + OCRBench_score['Irregular Text Recognition']
|
189 |
+
+ OCRBench_score['Artistic Text Recognition'] + OCRBench_score['Handwriting Recognition']
|
190 |
+
+ OCRBench_score['Digit String Recognition'] + OCRBench_score['Non-Semantic Text Recognition'])
|
191 |
+
final_score_dict['Scene Text-centric VQA'] = OCRBench_score['Scene Text-centric VQA']
|
192 |
+
final_score_dict['Doc-oriented VQA'] = OCRBench_score['Doc-oriented VQA']
|
193 |
+
final_score_dict['Key Information Extraction'] = OCRBench_score['Key Information Extraction']
|
194 |
+
final_score_dict['Handwritten Mathematical Expression Recognition'] = \
|
195 |
+
(OCRBench_score['Handwritten Mathematical Expression Recognition'])
|
196 |
+
final_score_dict['Final Score'] = \
|
197 |
+
(final_score_dict['Text Recognition'] + final_score_dict['Scene Text-centric VQA']
|
198 |
+
+ final_score_dict['Doc-oriented VQA'] + final_score_dict['Key Information Extraction']
|
199 |
+
+ final_score_dict['Handwritten Mathematical Expression Recognition'])
|
200 |
+
final_score_dict['Final Score Norm'] = (float(final_score_dict['Final Score']) / 10)
|
201 |
+
score_pth = eval_file.replace('.xlsx', '_score.json')
|
202 |
+
dump(final_score_dict, score_pth)
|
203 |
+
return final_score_dict
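
Each OCRBench sample scores 1 as soon as any ground-truth string is contained in the normalised prediction (the handwritten-formula category additionally strips spaces and keeps case). A minimal sketch of the general-case check:

def ocr_hit(prediction, answers):
    # General-case normalisation used above: lowercase, strip, flatten newlines.
    pred = str(prediction).lower().strip().replace('\n', ' ')
    return any(a.lower().strip().replace('\n', ' ') in pred for a in answers)

print(ocr_hit('The sign reads "OPEN 24 HOURS".', ['open 24 hours']))  # True
print(ocr_hit('illegible', ['open 24 hours']))                        # False
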
|
204 |
+
|
205 |
+
|
206 |
+
class MathVista(ImageBaseDataset):
|
207 |
+
TYPE = 'VQA'
|
208 |
+
DATASET_URL = {
|
209 |
+
'MathVista_MINI': 'https://opencompass.openxlab.space/utils/VLMEval/MathVista_MINI.tsv'
|
210 |
+
}
|
211 |
+
DATASET_MD5 = {'MathVista_MINI': 'f199b98e178e5a2a20e7048f5dcb0464'}
|
212 |
+
|
213 |
+
# It returns a DataFrame
|
214 |
+
@classmethod
|
215 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
216 |
+
from .utils.mathvista import MathVista_auxeval, MathVista_acc
|
217 |
+
|
218 |
+
model = judge_kwargs['model']
|
219 |
+
suffix = eval_file.split('.')[-1]
|
220 |
+
storage = eval_file.replace(f'.{suffix}', f'_{model}.xlsx')
|
221 |
+
tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
|
222 |
+
nproc = judge_kwargs.pop('nproc', 4)
|
223 |
+
|
224 |
+
if not osp.exists(storage):
|
225 |
+
data = load(eval_file)
|
226 |
+
model = build_judge(max_tokens=128, **judge_kwargs)
|
227 |
+
assert model.working(), ('MathVista evaluation requires a working OPENAI API\n' + DEBUG_MESSAGE)
|
228 |
+
lt = len(data)
|
229 |
+
lines = [data.iloc[i] for i in range(lt)]
|
230 |
+
tups = [(model, line) for line in lines]
|
231 |
+
indices = [line['index'] for line in lines]
|
232 |
+
|
233 |
+
ans = {}
|
234 |
+
if osp.exists(tmp_file):
|
235 |
+
ans = load(tmp_file)
|
236 |
+
tups = [x for x, i in zip(tups, indices) if i not in ans]
|
237 |
+
indices = [i for i in indices if i not in ans]
|
238 |
+
|
239 |
+
if len(indices):
|
240 |
+
new_results = track_progress_rich(
|
241 |
+
MathVista_auxeval,
|
242 |
+
tups,
|
243 |
+
nproc=nproc,
|
244 |
+
chunksize=nproc,
|
245 |
+
keys=indices,
|
246 |
+
save=tmp_file,
|
247 |
+
)
|
248 |
+
ans = load(tmp_file)
|
249 |
+
for k, v in zip(indices, new_results):
|
250 |
+
assert k in ans
|
251 |
+
assert ans[k]['log'] == v['log'] and ans[k]['res'] == v['res']
|
252 |
+
|
253 |
+
data['res'] = [ans[idx]['res'] for idx in data['index']]
|
254 |
+
data['log'] = [ans[idx]['log'] for idx in data['index']]
|
255 |
+
dump(data, storage)
|
256 |
+
|
257 |
+
score = MathVista_acc(storage)
|
258 |
+
score_pth = storage.replace('.xlsx', '_score.csv')
|
259 |
+
dump(score, score_pth)
|
260 |
+
return score
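
MathVista (and the MathVerse, MATH-Vision and MMVet evaluators below) all reuse the same resume pattern: partial judge outputs are checkpointed to a .pkl keyed by sample index, and only missing indices are re-submitted. A self-contained sketch of that pattern, with a dummy function standing in for the judge call:

import os
import pickle

def resume_eval(indices, tmp_file, run_one):
    # Load previous partial results, evaluate only what is missing, re-save.
    ans = {}
    if os.path.exists(tmp_file):
        with open(tmp_file, 'rb') as f:
            ans = pickle.load(f)
    for idx in [i for i in indices if i not in ans]:
        ans[idx] = run_one(idx)          # stands in for MathVista_auxeval(model, line)
    with open(tmp_file, 'wb') as f:
        pickle.dump(ans, f)
    return ans

print(resume_eval([0, 1, 2], 'demo_tmp.pkl', run_one=lambda i: {'res': i, 'log': 'ok'}))
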
|
261 |
+
|
262 |
+
|
263 |
+
class MathVerse(ImageBaseDataset):
|
264 |
+
TYPE = 'VQA'
|
265 |
+
DATASET_URL = {
|
266 |
+
'MathVerse_MINI': 'http://opencompass.openxlab.space/utils/benchmarks/MathVerse/MathVerse_MINIV.tsv', # noqa
|
267 |
+
'MathVerse_MINI_Vision_Only': 'http://opencompass.openxlab.space/utils/benchmarks/MathVerse/MathVerse_MINIVOnly.tsv', # noqa
|
268 |
+
'MathVerse_MINI_Vision_Dominant': 'http://opencompass.openxlab.space/utils/benchmarks/MathVerse/MathVerse_MINIVDom.tsv', # noqa
|
269 |
+
'MathVerse_MINI_Vision_Intensive': 'http://opencompass.openxlab.space/utils/benchmarks/MathVerse/MathVerse_MINIVInt.tsv', # noqa
|
270 |
+
'MathVerse_MINI_Text_Lite': 'http://opencompass.openxlab.space/utils/benchmarks/MathVerse/MathVerse_MINITLite.tsv', # noqa
|
271 |
+
'MathVerse_MINI_Text_Dominant': 'http://opencompass.openxlab.space/utils/benchmarks/MathVerse/MathVerse_MINITDom.tsv', # noqa
|
272 |
+
}
|
273 |
+
DATASET_MD5 = {
|
274 |
+
'MathVerse_MINI': '5017caca32b7fa110c350a1bea861b65',
|
275 |
+
'MathVerse_MINI_Vision_Only': '68a11d4680014ac881fa37adeadea3a4',
|
276 |
+
'MathVerse_MINI_Vision_Dominant': 'b8fb63852d261ab2aaefba29cc2414d3',
|
277 |
+
'MathVerse_MINI_Vision_Intensive': '01cbd35be202bb0c4873a4186a63bc19',
|
278 |
+
'MathVerse_MINI_Text_Lite': '19e4b13bdd30b89a03b2e358bcfefa04',
|
279 |
+
'MathVerse_MINI_Text_Dominant': '4f5cd2fa6630ea00bb11d6fde1f6fe6a',
|
280 |
+
}
|
281 |
+
|
282 |
+
# It returns a DataFrame
|
283 |
+
@classmethod
|
284 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
285 |
+
from .utils.mathverse import MathVerse_auxeval_extract, MathVerse_auxeval_score, MathVerse_acc
|
286 |
+
|
287 |
+
model = judge_kwargs['model']
|
288 |
+
suffix = eval_file.split('.')[-1]
|
289 |
+
storage_extract = eval_file.replace(f'.{suffix}', f'_{model}_extract.xlsx')
|
290 |
+
tmp_file_extract = eval_file.replace(f'.{suffix}', f'_{model}_extract.pkl')
|
291 |
+
storage_score = eval_file.replace(f'.{suffix}', f'_{model}_score.xlsx')
|
292 |
+
tmp_file_score = eval_file.replace(f'.{suffix}', f'_{model}_score.pkl')
|
293 |
+
nproc = judge_kwargs.pop('nproc', 4)
|
294 |
+
# stage1: extract the answer
|
295 |
+
if not osp.exists(storage_extract):
|
296 |
+
data = load(eval_file)
|
297 |
+
model = build_judge(max_tokens=128, **judge_kwargs)
|
298 |
+
assert model.working(), ('MathVerse evaluation requires a working OPENAI API\n' + DEBUG_MESSAGE)
|
299 |
+
lt = len(data)
|
300 |
+
lines = [data.iloc[i] for i in range(lt)]
|
301 |
+
tups = [(model, line) for line in lines]
|
302 |
+
indices = [line['index'] for line in lines]
|
303 |
+
|
304 |
+
ans = {}
|
305 |
+
if osp.exists(tmp_file_extract):
|
306 |
+
ans = load(tmp_file_extract)
|
307 |
+
tups = [x for x, i in zip(tups, indices) if i not in ans]
|
308 |
+
indices = [i for i in indices if i not in ans]
|
309 |
+
|
310 |
+
if len(indices):
|
311 |
+
new_results = track_progress_rich(
|
312 |
+
MathVerse_auxeval_extract,
|
313 |
+
tups,
|
314 |
+
nproc=nproc,
|
315 |
+
chunksize=nproc,
|
316 |
+
keys=indices,
|
317 |
+
save=tmp_file_extract,
|
318 |
+
)
|
319 |
+
ans = load(tmp_file_extract)
|
320 |
+
for k, v in zip(indices, new_results):
|
321 |
+
assert k in ans
|
322 |
+
assert ans[k]['log_extract'] == v['log_extract'] and ans[k]['extract'] == v['extract']
|
323 |
+
|
324 |
+
data['extract'] = [ans[idx]['extract'] for idx in data['index']]
|
325 |
+
data['log_extract'] = [ans[idx]['log_extract'] for idx in data['index']]
|
326 |
+
dump(data, storage_extract)
|
327 |
+
|
328 |
+
# stage2: score the answer
|
329 |
+
if not osp.exists(storage_score):
|
330 |
+
data = load(storage_extract)
|
331 |
+
model = build_judge(max_tokens=128, **judge_kwargs)
|
332 |
+
assert model.working(), ('MathVerse evaluation requires a working OPENAI API\n' + DEBUG_MESSAGE)
|
333 |
+
lt = len(data)
|
334 |
+
lines = [data.iloc[i] for i in range(lt)]
|
335 |
+
tups = [(model, line) for line in lines]
|
336 |
+
indices = [line['index'] for line in lines]
|
337 |
+
|
338 |
+
ans = {}
|
339 |
+
if osp.exists(tmp_file_score):
|
340 |
+
ans = load(tmp_file_score)
|
341 |
+
tups = [x for x, i in zip(tups, indices) if i not in ans]
|
342 |
+
indices = [i for i in indices if i not in ans]
|
343 |
+
|
344 |
+
if len(indices):
|
345 |
+
new_results = track_progress_rich(
|
346 |
+
MathVerse_auxeval_score,
|
347 |
+
tups,
|
348 |
+
nproc=nproc,
|
349 |
+
chunksize=nproc,
|
350 |
+
keys=indices,
|
351 |
+
save=tmp_file_score,
|
352 |
+
)
|
353 |
+
ans = load(tmp_file_score)
|
354 |
+
for k, v in zip(indices, new_results):
|
355 |
+
assert k in ans
|
356 |
+
assert ans[k]['log_score'] == v['log_score'] and ans[k]['score'] == v['score']
|
357 |
+
|
358 |
+
data['score'] = [ans[idx]['score'] for idx in data['index']]
|
359 |
+
data['log_score'] = [ans[idx]['log_score'] for idx in data['index']]
|
360 |
+
dump(data, storage_score)
|
361 |
+
|
362 |
+
score = MathVerse_acc(storage_score)
|
363 |
+
score_pth = storage_score.replace('.xlsx', '.csv')
|
364 |
+
dump(score, score_pth)
|
365 |
+
return score
|
366 |
+
|
367 |
+
|
368 |
+
class MathVision(ImageBaseDataset):
|
369 |
+
TYPE = 'VQA'
|
370 |
+
DATASET_URL = {
|
371 |
+
'MathVision': 'https://opencompass.openxlab.space/utils/VLMEval/MathVision.tsv',
|
372 |
+
'MathVision_MINI': 'https://opencompass.openxlab.space/utils/VLMEval/MathVision_MINI.tsv'
|
373 |
+
}
|
374 |
+
DATASET_MD5 = {
|
375 |
+
'MathVision': '93f6de14f7916e598aa1b7165589831e',
|
376 |
+
'MathVision_MINI': '060fe4fa5d868987ce179307bd5f8a33'
|
377 |
+
}
|
378 |
+
|
379 |
+
# It returns a DataFrame
|
380 |
+
@classmethod
|
381 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
382 |
+
from .utils.mathv import MATH_V_auxeval, MATH_V_acc
|
383 |
+
|
384 |
+
if 'model' in judge_kwargs:
|
385 |
+
model = judge_kwargs['model']
|
386 |
+
else:
|
387 |
+
model = os.path.basename(os.environ.get('LOCAL_LLM'))
|
388 |
+
suffix = eval_file.split('.')[-1]
|
389 |
+
storage = eval_file.replace(f'.{suffix}', f'_{model}.xlsx')
|
390 |
+
tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
|
391 |
+
nproc = judge_kwargs.pop('nproc', 4)
|
392 |
+
|
393 |
+
if not osp.exists(storage):
|
394 |
+
data = load(eval_file)
|
395 |
+
model = build_judge(max_tokens=128, **judge_kwargs)
|
396 |
+
assert model.working(), ('MATH-Vision evaluation requires a working OPENAI API\n' + DEBUG_MESSAGE)
|
397 |
+
lt = len(data)
|
398 |
+
lines = [data.iloc[i] for i in range(lt)]
|
399 |
+
tups = [(model, line) for line in lines]
|
400 |
+
indices = [line['index'] for line in lines]
|
401 |
+
|
402 |
+
ans = {}
|
403 |
+
if osp.exists(tmp_file):
|
404 |
+
ans = load(tmp_file)
|
405 |
+
tups = [x for x, i in zip(tups, indices) if i not in ans]
|
406 |
+
indices = [i for i in indices if i not in ans]
|
407 |
+
|
408 |
+
if len(indices):
|
409 |
+
new_results = track_progress_rich(
|
410 |
+
MATH_V_auxeval,
|
411 |
+
tups,
|
412 |
+
nproc=nproc,
|
413 |
+
chunksize=nproc,
|
414 |
+
keys=indices,
|
415 |
+
save=tmp_file,
|
416 |
+
)
|
417 |
+
ans = load(tmp_file)
|
418 |
+
for k, v in zip(indices, new_results):
|
419 |
+
assert k in ans
|
420 |
+
assert ans[k]['log'] == v['log'] and ans[k]['res'] == v['res']
|
421 |
+
|
422 |
+
data['res'] = [ans[idx]['res'] for idx in data['index']]
|
423 |
+
data['log'] = [ans[idx]['log'] for idx in data['index']]
|
424 |
+
dump(data, storage)
|
425 |
+
|
426 |
+
score = MATH_V_acc(storage)
|
427 |
+
score_pth = storage.replace('.xlsx', '_score.csv')
|
428 |
+
dump(score, score_pth)
|
429 |
+
return score
|
430 |
+
|
431 |
+
|
432 |
+
class OlympiadBench(ImageBaseDataset):
|
433 |
+
TYPE = 'VQA_ex_prompt'
|
434 |
+
DATASET_URL = {
|
435 |
+
'OlympiadBench': 'https://opencompass.openxlab.space/utils/VLMEval/OlympiadBench.tsv',
|
436 |
+
'OlympiadBench_EN': 'https://opencompass.openxlab.space/utils/VLMEval/OlympiadBench_EN.tsv',
|
437 |
+
'OlympiadBench_CN': 'https://opencompass.openxlab.space/utils/VLMEval/OlympiadBench_CN.tsv'
|
438 |
+
}
|
439 |
+
DATASET_MD5 = {
|
440 |
+
'OlympiadBench': '9735ae0f0299eae1e7d07f5a7feab914',
|
441 |
+
'OlympiadBench_EN': '5c68e100d394351fc7049f29d4d4efed',
|
442 |
+
'OlympiadBench_CN': 'ea01b16788955702c79650c701e5b623'
|
443 |
+
}
|
444 |
+
|
445 |
+
def dump_image(self, line):
|
446 |
+
os.makedirs(self.img_root, exist_ok=True)
|
447 |
+
|
448 |
+
tgt_path_z = []
|
449 |
+
if isinstance(line['image'], list):
|
450 |
+
for i in range(len(line['image'])):
|
451 |
+
tgt_path = osp.join(self.img_root, f"{line['index']}--{i + 1}.jpg")
|
452 |
+
if not read_ok(tgt_path):
|
453 |
+
decode_base64_to_image_file(line['image'][i], tgt_path)
|
454 |
+
tgt_path_z.append(tgt_path)
|
455 |
+
else:
|
456 |
+
tgt_path = osp.join(self.img_root, f"{line['index']}.jpg")
|
457 |
+
if not read_ok(tgt_path):
|
458 |
+
decode_base64_to_image_file(line['image'], tgt_path)
|
459 |
+
tgt_path_z.append(tgt_path)
|
460 |
+
return tgt_path_z
|
461 |
+
|
462 |
+
def build_prompt(self, line):
|
463 |
+
|
464 |
+
from .utils.olympiadbench import get_answer_type_text, make_input
|
465 |
+
|
466 |
+
self.is_chinese = 'zh' in line['source']
|
467 |
+
self.is_math = 'maths' in line['source']
|
468 |
+
self.is_theorem_proving = 'TP' in line['source']
|
469 |
+
|
470 |
+
if self.is_chinese:
|
471 |
+
subject_content = '数学' if self.is_math else '物理'
|
472 |
+
if self.is_theorem_proving:
|
473 |
+
prompt = (
|
474 |
+
f"以下是中国{subject_content}竞赛中的证明题。请根据题目的要求,运用逻辑推理及常用定理证明题目中的命题。"
|
475 |
+
"证明过程中使用的变量和公式请使用LaTeX格式表示。"
|
476 |
+
)
|
477 |
+
else:
|
478 |
+
answer_type_text = get_answer_type_text(line['answer_type'], is_chinese=True,
|
479 |
+
multiple_answer=line['is_multiple_answer'])
|
480 |
+
if line['is_multiple_answer']:
|
481 |
+
multiple_answer_text = '\\boxed{用英文逗号连接的多个答案}'
|
482 |
+
else:
|
483 |
+
multiple_answer_text = '\\boxed{答案}'
|
484 |
+
unit_text = ''
|
485 |
+
if line['unit']:
|
486 |
+
multiple_answer_text += '(单位)'
|
487 |
+
unit_text = ',注意答案的单位不要放在\\boxed{}中'
|
488 |
+
prompt = (
|
489 |
+
f'以下是中国{subject_content}竞赛中的解答题{answer_type_text}。请根据题目的要求和所提供的信息计算得出答案。'
|
490 |
+
f'解答过程和结果中使用的变量和公式请使用LaTeX格式表示。请在最后以“所以最终答案是{multiple_answer_text}。”'
|
491 |
+
f'显式给出结果{unit_text}。'
|
492 |
+
)
|
493 |
+
else:
|
494 |
+
subject_content = 'Math' if self.is_math else 'Physics'
|
495 |
+
if self.is_theorem_proving:
|
496 |
+
prompt = (
|
497 |
+
f'The following is a theorem proving problem from an International {subject_content} competition. '
|
498 |
+
'Please use logical reasoning and common theorems to prove the proposition in the problem '
|
499 |
+
'according to the given requirements. '
|
500 |
+
'Please use LaTeX format to represent the variables and formulas used in the proof.'
|
501 |
+
)
|
502 |
+
else:
|
503 |
+
if line['is_multiple_answer']:
|
504 |
+
multiple_answer_text = '\\boxed{multiple answers connected with commas}'
|
505 |
+
else:
|
506 |
+
multiple_answer_text = '\\boxed{answer}'
|
507 |
+
unit_text = ''
|
508 |
+
if line['unit']:
|
509 |
+
multiple_answer_text += '(unit)'
|
510 |
+
unit_text = ', note that the unit of the answer should not be included in \\boxed{}'
|
511 |
+
answer_type_text = get_answer_type_text(line['answer_type'], is_chinese=False,
|
512 |
+
multiple_answer=line['is_multiple_answer'])
|
513 |
+
prompt = (
|
514 |
+
f'The following is an open-ended problem from an International {subject_content} competition. '
|
515 |
+
f'{answer_type_text}Please calculate the answer according to the given requirements and '
|
516 |
+
'the information provided. Please use LaTeX format to represent the variables and formulas '
|
517 |
+
'used in the solution process and results. Please end your solution with "So the final answer '
|
518 |
+
f'is {multiple_answer_text}." and give the result explicitly{unit_text}.'
|
519 |
+
)
|
520 |
+
|
521 |
+
if self.is_math:
|
522 |
+
input = make_input(prompt, line['question'])
|
523 |
+
else:
|
524 |
+
if 'context' in line.keys() and str(line['context']) != 'nan': # cannot be null
|
525 |
+
input = make_input(prompt, line['context'] + '\n' + line['question'])
|
526 |
+
else:
|
527 |
+
input = make_input(prompt, line['question'])
|
528 |
+
|
529 |
+
ret = [dict(type='text', value=input)]
|
530 |
+
tgt_path = self.dump_image(line)
|
531 |
+
|
532 |
+
ret.extend([dict(type='image', value=s) for s in tgt_path])
|
533 |
+
|
534 |
+
return ret
|
535 |
+
|
536 |
+
@classmethod
|
537 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
538 |
+
from .utils.olympiadbench import MathJudger, extract_answer
|
539 |
+
judger = MathJudger()
|
540 |
+
|
541 |
+
suffix = eval_file.split('.')[-1]
|
542 |
+
name_str1 = 'judge'
|
543 |
+
name_str2 = 'score'
|
544 |
+
result_file = eval_file.replace(f'.{suffix}', f'_{name_str1}_result.xlsx')
|
545 |
+
score_file = eval_file.replace(f'.{suffix}', f'_{name_str2}_result.csv')
|
546 |
+
|
547 |
+
if not osp.exists(result_file):
|
548 |
+
data = load(eval_file)
|
549 |
+
scorez = []
|
550 |
+
|
551 |
+
for i in tqdm(data.iterrows()):
|
552 |
+
line = i[1]
|
553 |
+
model_answer = line['prediction']
|
554 |
+
is_chinese = 'zh' in line['source']
|
555 |
+
model_answer = extract_answer(is_chinese, model_answer, is_deepseek=False)
|
556 |
+
answer_type = line['answer_type']
|
557 |
+
|
558 |
+
final_answer = line['final_answer'][2:-2]
|
559 |
+
|
560 |
+
if str(answer_type) != 'nan' and 'Tuple' in answer_type:
|
561 |
+
judge_result = judger.judge(model_answer, final_answer)
|
562 |
+
else:
|
563 |
+
if str(line['error']) != 'nan':
|
564 |
+
if ',' in line['error']:
|
565 |
+
precisions = line['error'].split(',')
|
566 |
+
precisions = [float(p) if p else 1e-8 for p in precisions]
|
567 |
+
judge_result = judger.judge(model_answer, final_answer, precisions)
|
568 |
+
else:
|
569 |
+
precision = float(line['error'])
|
570 |
+
judge_result = judger.judge(model_answer, final_answer, precision)
|
571 |
+
else:
|
572 |
+
judge_result = judger.judge(model_answer, final_answer)
|
573 |
+
scorez.append(judge_result)
|
574 |
+
|
575 |
+
data['score'] = scorez
|
576 |
+
dump(data, result_file)
|
577 |
+
|
578 |
+
judge_file = load(result_file)
|
579 |
+
|
580 |
+
if not osp.exists(score_file):
|
581 |
+
name_list = ['OE_MM_maths_en_COMP', 'OE_MM_maths_zh_CEE', 'OE_MM_maths_zh_COMP', 'OE_MM_physics_en_COMP',
|
582 |
+
'OE_MM_physics_zh_CEE','OE_TO_maths_en_COMP', 'OE_TO_maths_zh_CEE', 'OE_TO_maths_zh_COMP',
|
583 |
+
'OE_TO_physics_en_COMP', 'OE_TO_physics_zh_CEE']
|
584 |
+
|
585 |
+
sample_list = [[] for _ in range(len(name_list))]
|
586 |
+
for i in judge_file.iterrows():
|
587 |
+
line = i[1]
|
588 |
+
for j in range(len(name_list)):
|
589 |
+
if line['source'] == name_list[j]:
|
590 |
+
sample_list[j].append(line['score'])
|
591 |
+
|
592 |
+
acc_dict = {}
|
593 |
+
correct_list = []
|
594 |
+
|
595 |
+
# fine-grained
|
596 |
+
for i in range(len(name_list)):
|
597 |
+
correct_num = 0
|
598 |
+
for j in sample_list[i]:
|
599 |
+
if j:
|
600 |
+
correct_num += 1
|
601 |
+
correct_list.append(correct_num)
|
602 |
+
acc = 100 * correct_num / len(sample_list[i])
|
603 |
+
acc_dict[name_list[i]] = [acc]
|
604 |
+
|
605 |
+
# 4 grained
|
606 |
+
labela = ['zh', 'en']
|
607 |
+
labelb = ['maths', 'physics']
|
608 |
+
|
609 |
+
grain_list = [[x,y] for x in labela for y in labelb]
|
610 |
+
for j in grain_list:
|
611 |
+
dict_name = j[0] + "_" + j[1]
|
612 |
+
correct_num = 0
|
613 |
+
full_num = 0
|
614 |
+
for i in range(len(name_list)):
|
615 |
+
if all(k in name_list[i] for k in j):
|
616 |
+
correct_num += correct_list[i]
|
617 |
+
full_num += len(sample_list[i])
|
618 |
+
acc = 100 * correct_num / full_num
|
619 |
+
acc_dict[dict_name] = [acc]
|
620 |
+
|
621 |
+
# 2 grained
|
622 |
+
grain_list = ['maths', 'physics']
|
623 |
+
for j in grain_list:
|
624 |
+
dict_name = j
|
625 |
+
correct_num = 0
|
626 |
+
full_num = 0
|
627 |
+
for i in range(len(name_list)):
|
628 |
+
if j in name_list[i]:
|
629 |
+
correct_num += correct_list[i]
|
630 |
+
full_num += len(sample_list[i])
|
631 |
+
acc = 100 * correct_num / full_num
|
632 |
+
acc_dict[dict_name] = [acc]
|
633 |
+
|
634 |
+
# AVG
|
635 |
+
correct_num = sum(correct_list)
|
636 |
+
acc = 100 * correct_num / len(judge_file)
|
637 |
+
acc_dict['AVG'] = [acc]
|
638 |
+
|
639 |
+
acc_pd = pd.DataFrame(acc_dict)
|
640 |
+
acc_pd.to_csv(score_file, index=False, encoding='gbk')
|
641 |
+
|
642 |
+
accdz = pd.read_csv(score_file)
|
643 |
+
return accdz
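
The 'error' column drives the numeric tolerance handed to MathJudger: a comma-separated list yields one precision per sub-answer (empty entries fall back to 1e-8), a single value is used directly, and 'nan' means the judge default. A compressed sketch of that parsing:

def parse_precision(error):
    # Mirrors the branch above; returns None when the judge default should be used.
    if error is None or str(error) == 'nan':
        return None
    text = str(error)
    if ',' in text:
        return [float(p) if p else 1e-8 for p in text.split(',')]
    return float(text)

print(parse_precision('1e-2,,1e-3'))  # [0.01, 1e-08, 0.001]
print(parse_precision('0.5'))         # 0.5
print(parse_precision(float('nan')))  # None
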
|
644 |
+
|
645 |
+
|
646 |
+
class LLaVABench(ImageBaseDataset):
|
647 |
+
TYPE = 'VQA'
|
648 |
+
DATASET_URL = {'LLaVABench': 'https://opencompass.openxlab.space/utils/VLMEval/LLaVABench.tsv'}
|
649 |
+
DATASET_MD5 = {'LLaVABench': 'd382a093f749a697820d3dadd61c8428'}
|
650 |
+
|
651 |
+
# It returns a DataFrame
|
652 |
+
@classmethod
|
653 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
654 |
+
from .utils.llavabench import (
|
655 |
+
build_prompt,
|
656 |
+
LLaVABench_atomeval,
|
657 |
+
LLaVABench_score,
|
658 |
+
)
|
659 |
+
|
660 |
+
suffix = '.' + eval_file.split('.')[-1]
|
661 |
+
record_file = eval_file.replace(suffix, '_openai_result' + suffix)
|
662 |
+
score_file = eval_file.replace(suffix, '_score.csv')
|
663 |
+
nproc = judge_kwargs.pop('nproc', 4)
|
664 |
+
system_prompt = 'You are a helpful and precise assistant for checking the quality of the answer.'
|
665 |
+
|
666 |
+
if not osp.exists(record_file):
|
667 |
+
data = load(eval_file)
|
668 |
+
lines = [data.iloc[i] for i in range(len(data))]
|
669 |
+
model = build_judge(temperature=0.2, system_prompt=system_prompt, **judge_kwargs)
|
670 |
+
assert model.working(), ('LLaVABench evaluation requires a working OPENAI API\n' + DEBUG_MESSAGE)
|
671 |
+
|
672 |
+
prompts = [build_prompt(line) for line in lines]
|
673 |
+
tups = [(model, prompt) for prompt in prompts]
|
674 |
+
scores = track_progress_rich(LLaVABench_atomeval, tups, nproc=nproc, chunksize=nproc)
|
675 |
+
data['gpt4_score'] = [x[0] for x in scores]
|
676 |
+
data['score'] = [x[1] for x in scores]
|
677 |
+
dump(data, record_file)
|
678 |
+
|
679 |
+
data = load(record_file)
|
680 |
+
ret = LLaVABench_score(data).round(1)
|
681 |
+
dump(ret, score_file)
|
682 |
+
return ret
|
683 |
+
|
684 |
+
|
685 |
+
class MMVet(ImageBaseDataset):
|
686 |
+
TYPE = 'VQA'
|
687 |
+
DATASET_URL = {
|
688 |
+
'MMVet': 'https://opencompass.openxlab.space/utils/VLMEval/MMVet.tsv'
|
689 |
+
}
|
690 |
+
DATASET_MD5 = {'MMVet': '748aa6d4aa9d4de798306a63718455e3'}
|
691 |
+
|
692 |
+
# It returns a DataFrame
|
693 |
+
@classmethod
|
694 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
695 |
+
from .utils.mmvet import MMVet_auxeval, MMVet_acc
|
696 |
+
|
697 |
+
suffix = eval_file.split('.')[-1]
|
698 |
+
model = judge_kwargs['model']
|
699 |
+
storage = eval_file.replace(f'.{suffix}', f'_{model}.xlsx')
|
700 |
+
tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
|
701 |
+
nproc = judge_kwargs.pop('nproc', 4)
|
702 |
+
if not osp.exists(storage):
|
703 |
+
data = load(eval_file)
|
704 |
+
model = build_judge(max_tokens=3, **judge_kwargs)
|
705 |
+
assert model.working(), ('MMVet evaluation requires a working OPENAI API\n' + DEBUG_MESSAGE)
|
706 |
+
|
707 |
+
lt = len(data)
|
708 |
+
lines = [data.iloc[i] for i in range(lt)]
|
709 |
+
tups = [(model, line) for line in lines]
|
710 |
+
indices = [line['index'] for line in lines]
|
711 |
+
|
712 |
+
ans = load(tmp_file) if osp.exists(tmp_file) else {}
|
713 |
+
tups = [x for x, i in zip(tups, indices) if i not in ans]
|
714 |
+
indices = [i for i in indices if i not in ans]
|
715 |
+
|
716 |
+
if len(indices):
|
717 |
+
new_results = track_progress_rich(
|
718 |
+
MMVet_auxeval,
|
719 |
+
tups,
|
720 |
+
nproc=nproc,
|
721 |
+
chunksize=nproc,
|
722 |
+
keys=indices,
|
723 |
+
save=tmp_file,
|
724 |
+
)
|
725 |
+
ans = load(tmp_file)
|
726 |
+
for k, v in zip(indices, new_results):
|
727 |
+
assert k in ans
|
728 |
+
assert ans[k]['log'] == v['log'] and ans[k]['score'] == v['score']
|
729 |
+
data['score'] = [ans[idx]['score'] for idx in data['index']]
|
730 |
+
data['log'] = [ans[idx]['log'] for idx in data['index']]
|
731 |
+
dump(data, storage)
|
732 |
+
|
733 |
+
score, score_fine = MMVet_acc(storage)
|
734 |
+
score_pth = storage.replace('.xlsx', '_score.csv')
|
735 |
+
score_fine_pth = storage.replace('.xlsx', '_score_fine.csv')
|
736 |
+
dump(score, score_pth)
|
737 |
+
dump(score_fine, score_fine_pth)
|
738 |
+
return score
|
739 |
+
|
740 |
+
|
741 |
+
class MTVQADataset(ImageBaseDataset):
|
742 |
+
TYPE = 'VQA'
|
743 |
+
DATASET_URL = {'MTVQA_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/MTVQA_TEST.tsv'}
|
744 |
+
DATASET_MD5 = {'MTVQA_TEST': 'd87c17dbab934b7cd89c0a3c1c5657f4'}
|
745 |
+
|
746 |
+
@classmethod
|
747 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
748 |
+
data = load(eval_file)
|
749 |
+
assert 'answer' in data and 'prediction' in data and 'category' in data
|
750 |
+
data['prediction'] = [str(x) for x in data['prediction']]
|
751 |
+
data['answer'] = [str(x) for x in data['answer']]
|
752 |
+
if 'split' in data:
|
753 |
+
assert np.all([x.lower() == 'test' for x in data['split']]), 'We only support MTVQA_TEST for now. '
|
754 |
+
lt = len(data)
|
755 |
+
category_scores = defaultdict(list)
|
756 |
+
for i in range(lt):
|
757 |
+
line = data.iloc[i]
|
758 |
+
ans = line['answer'].strip().lower().replace('.', '')
|
759 |
+
pred = line['prediction'].strip().lower().replace('.', '')
|
760 |
+
cate = line['category']
|
761 |
+
score = 1.0 if ans in pred else 0.0
|
762 |
+
category_scores[cate].append(score)
|
763 |
+
category_scores['Average'].append(score)
|
764 |
+
# Calculate the average score for each category, the score is normalized to [0, 100]
|
765 |
+
category_averages = {category: np.mean(scores) * 100 for category, scores in category_scores.items()}
|
766 |
+
|
767 |
+
suffix = eval_file.split('.')[-1]
|
768 |
+
result_file = eval_file.replace(f'.{suffix}', '_acc.json')
|
769 |
+
dump(category_averages, result_file)
|
770 |
+
|
771 |
+
return category_averages
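
MTVQA scoring is a simple containment check per language category, averaged into [0, 100]. A standalone sketch with three hypothetical rows (category, answer, prediction):

from collections import defaultdict
import numpy as np

def norm(s):
    return s.strip().lower().replace('.', '')

rows = [('fr', 'oui', 'Oui.'), ('fr', 'non', 'oui'), ('de', 'ja', 'Ja, genau.')]
buckets = defaultdict(list)
for cate, ans, pred in rows:
    hit = 1.0 if norm(ans) in norm(pred) else 0.0
    buckets[cate].append(hit)
    buckets['Average'].append(hit)
print({k: round(float(np.mean(v)) * 100, 1) for k, v in buckets.items()})
# {'fr': 50.0, 'Average': 66.7, 'de': 100.0}
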
|
772 |
+
|
773 |
+
# MT-VQA adopts a custom prompt
|
774 |
+
def build_prompt(self, line):
|
775 |
+
msgs = super().build_prompt(line)
|
776 |
+
assert sum([x['type'] == 'text' for x in msgs]) == 1
|
777 |
+
for item in msgs:
|
778 |
+
if item['type'] == 'text':
|
779 |
+
item['value'] += '\nAnswer the question using a word or phrase in the language of the question.'
|
780 |
+
return msgs
|
781 |
+
|
782 |
+
|
783 |
+
class TableVQABench(ImageBaseDataset):
|
784 |
+
TYPE = 'VQA'
|
785 |
+
DATASET_URL = {
|
786 |
+
'TableVQABench': 'https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/mentor-vil/datasets/tablevqa-bench.tsv'
|
787 |
+
}
|
788 |
+
DATASET_MD5 = {'TableVQABench': '2550adc61bdc82d8e62f3b003de7c62d'}
|
789 |
+
|
790 |
+
from .utils.tablevqabench import FINTABNETQA_PROMPT, VTABFACT_PROMPT, VWTQ_PROMPT
|
791 |
+
|
792 |
+
# It returns a DataFrame
|
793 |
+
@classmethod
|
794 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
795 |
+
import pandas as pd
|
796 |
+
from .utils.tablevqabench import evaluate_fintabnet, evaluate_tabfact, evaluate_wtq
|
797 |
+
|
798 |
+
data = load(eval_file)
|
799 |
+
assert 'answer' in data and 'prediction' in data
|
800 |
+
|
801 |
+
        data['prediction'] = data['prediction'].str.replace('^Answer: ', '', regex=True)
        data_group = dict(tuple(data.groupby('split')))
        eval_result = {'split': [], 'average_scores': []}
        for split in ['fintabnetqa', 'vtabfact', 'vwtq', 'vwtq_syn']:
            data_split = data_group[split].to_dict(orient='records')
            if split == 'fintabnetqa':
                split_eval_meta = evaluate_fintabnet(data_split, ['accuracy'])
            elif split == 'vtabfact':
                split_eval_meta = evaluate_tabfact(data_split, ['accuracy'])
            elif split == 'vwtq' or split == 'vwtq_syn':
                split_eval_meta = evaluate_wtq(data_split, ['accuracy'])
            eval_result['split'].append(split)
            eval_result['average_scores'].append(split_eval_meta['average_scores'])

        suffix = eval_file.split('.')[-1]
        result_file = eval_file.replace(f'.{suffix}', '_acc.csv')
        eval_result = pd.DataFrame(eval_result)
        dump(eval_result, result_file)

        return eval_result

    # TableVQABench adopts a custom prompt
    def build_prompt(self, line):
        msgs = super().build_prompt(line)
        assert sum([x['type'] == 'text' for x in msgs]) == 1
        for item in msgs:
            if item['type'] == 'text':
                if line['split'] == 'fintabnetqa':
                    item['value'] = self.FINTABNETQA_PROMPT.format_map({'question': item['value']})
                elif line['split'] == 'vtabfact':
                    item['value'] = self.VTABFACT_PROMPT.format_map({'question': item['value']})
                elif line['split'] == 'vwtq_syn' or line['split'] == 'vwtq':
                    item['value'] = self.VWTQ_PROMPT.format_map({'question': item['value']})
        return msgs


class CustomVQADataset(ImageBaseDataset):
    TYPE = 'VQA'

    def load_data(self, dataset):
        data_path = osp.join(LMUDataRoot(), f'{dataset}.tsv')

        if file_size(data_path, 'GB') > 1:
            local_path = data_path.replace('.tsv', '_local.tsv')
            if not osp.exists(local_path) or os.environ.get('FORCE_LOCAL', None):
                from ..tools import LOCALIZE

                LOCALIZE(data_path, local_path)
            data_path = local_path
        return load(data_path)

    def evaluate(self, eval_file, **judge_kwargs):
        raise NotImplementedError


class CRPE(ImageBaseDataset):
    TYPE = 'VQA'
    DATASET_URL = {
        'CRPE_EXIST': 'https://huggingface.co/datasets/petter12321/crpe_vlmevalkit/resolve/main/CRPE_EXIST.tsv',
        'CRPE_RELATION': 'https://huggingface.co/datasets/petter12321/crpe_vlmevalkit/resolve/main/CRPE_RELATION.tsv'
    }
    DATASET_MD5 = {
        'CRPE_EXIST': '315584e23ac1ff7f8719ed3b7ad90f08',
        'CRPE_RELATION': 'bad7094cde0b572288f4b119c2d0c656'}

    @classmethod
    def evaluate(self, eval_file, **judge_kwargs):
        from .utils.crpe import is_correct
        # find-image, count-text, find-text,
        # infer-choose, count-image, visual-reasoning
        score = {
            'exist': 0,
            'subject': 0,
            'predicate': 0,
            'object': 0,
            'total': 0,
        }
        num = {
            'exist': 0,
            'subject': 0,
            'predicate': 0,
            'object': 0,
            'total': 0,
        }
        final_score_dict = {
            'exist': 0,
            'subject': 0,
            'predicate': 0,
            'object': 0,
            'total': 0,
        }
        data = load(eval_file)
        lt = len(data)
        lines = [data.iloc[i] for i in range(lt)]
        for i in tqdm(range(len(lines))):
            line = lines[i]
            predict = str(line['prediction'])
            answers = str(line['answer'])
            # print("predict =", predict)
            # print("answers =", answers)
            category = line['category']
            if is_correct(answers, predict):
                score[category] += 1
                score['total'] += 1
            num[category] += 1
            num['total'] += 1

        for category in ['exist', 'subject', 'predicate', 'object', 'total']:
            if num[category] != 0:
                final_score_dict[category] = score[category] / num[category]
            else:
                final_score_dict[category] = None

        score_pth = eval_file.replace('.xlsx', '_score.json')
        dump(final_score_dict, score_pth)
        return final_score_dict

    def build_prompt(self, line):
        ROOT = LMUDataRoot()
        msgs = super().build_prompt(line)
        for msg in msgs:
            if msg['type'] == 'image':
                msg['value'] = osp.join(osp.join(ROOT, 'images', self.dataset_name), msg['value'])
        return msgs

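To make the evaluation flow above concrete, here is a minimal, self-contained sketch of the per-category tally that CRPE.evaluate performs. The simple_match helper is a hypothetical stand-in for the is_correct function imported from .utils.crpe, not the real matcher.

# Minimal sketch of the per-category accuracy tally used by CRPE.evaluate above.
# simple_match is a hypothetical stand-in for .utils.crpe.is_correct.
from collections import defaultdict

def simple_match(answer, prediction):
    return answer.strip().lower() in prediction.strip().lower()

def tally(records):
    score, num = defaultdict(int), defaultdict(int)
    for rec in records:
        cat = rec['category']
        if simple_match(str(rec['answer']), str(rec['prediction'])):
            score[cat] += 1
            score['total'] += 1
        num[cat] += 1
        num['total'] += 1
    return {c: (score[c] / num[c] if num[c] else None) for c in num}

demo = [
    {'category': 'exist', 'answer': 'yes', 'prediction': 'Yes, there is.'},
    {'category': 'subject', 'answer': 'dog', 'prediction': 'A cat.'},
]
print(tally(demo))  # {'exist': 1.0, 'total': 0.5, 'subject': 0.0}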
class QSpatial(ImageBaseDataset):
|
928 |
+
TYPE = 'VQA'
|
929 |
+
DATASET_URL = {
|
930 |
+
'QSpatial_plus': '',
|
931 |
+
'QSpatial_scannet': ''
|
932 |
+
}
|
933 |
+
|
934 |
+
# NOTE: To evaluate Q-Spatial-ScanNet, you need to get the permission from ScanNet website
|
935 |
+
# Once you get the permission, you can use the helper code here to download and extract necessary images:
|
936 |
+
# https://github.com/andrewliao11/Q-Spatial-Bench-code?tab=readme-ov-file#for-qspatial_scannet
|
937 |
+
qspatial_root = "TO_BE_REPLACED_WITH_THE_PATH_TO_QSPATIAL_DATASET"
|
938 |
+
url = "https://raw.githubusercontent.com/andrewliao11/Q-Spatial-Bench-code/refs/heads/main/prompt_templates/"
|
939 |
+
|
940 |
+
def post_build(self, dataset):
|
941 |
+
# Download the prompt templates from github
|
942 |
+
|
943 |
+
links = [
|
944 |
+
self.url + "system_prompt.txt",
|
945 |
+
self.url + "spatial_prompt_single.txt",
|
946 |
+
self.url + "spatial_prompt_steps.txt",
|
947 |
+
self.url + "standard_prompt.txt",
|
948 |
+
self.url + "zero_shot_prompt.txt"
|
949 |
+
]
|
950 |
+
with tempfile.TemporaryDirectory() as temp_dir:
|
951 |
+
for link in links:
|
952 |
+
tgt_path = os.path.join(temp_dir, link.split("/")[-1])
|
953 |
+
os.system(f"wget {link} -O {tgt_path}")
|
954 |
+
|
955 |
+
self.system_prompt = open(os.path.join(temp_dir, "system_prompt.txt")).read()
|
956 |
+
self._prompt_templates = dict(
|
957 |
+
spatial_prompt_single=open(os.path.join(temp_dir, "spatial_prompt_single.txt")).read(),
|
958 |
+
spatial_prompt_steps=open(os.path.join(temp_dir, "spatial_prompt_steps.txt")).read(),
|
959 |
+
standard_prompt=open(os.path.join(temp_dir, "standard_prompt.txt")).read(),
|
960 |
+
zero_shot_prompt=open(os.path.join(temp_dir, "zero_shot_prompt.txt")).read(),
|
961 |
+
)
|
962 |
+
|
963 |
+
# Given one data record, return the built prompt (a multi-modal message), can override
|
964 |
+
def build_prompt(self, line):
|
965 |
+
from jinja2.sandbox import SandboxedEnvironment
|
966 |
+
text_prompt_template = self._prompt_templates["spatial_prompt_single"]
|
967 |
+
env = SandboxedEnvironment()
|
968 |
+
text_prompt = env.from_string(text_prompt_template).render(question=line["question"])
|
969 |
+
tgt_path = self.dump_image(line)
|
970 |
+
|
971 |
+
msgs = []
|
972 |
+
if isinstance(tgt_path, list):
|
973 |
+
msgs.extend([dict(type='image', value=p) for p in tgt_path])
|
974 |
+
else:
|
975 |
+
msgs = [dict(type='image', value=tgt_path)]
|
976 |
+
|
977 |
+
msgs.append(dict(type='text', value=f"{self.system_prompt}\n{text_prompt}"))
|
978 |
+
return msgs
|
979 |
+
|
980 |
+
# Given the dataset name, return the dataset as a pandas dataframe, can override
|
981 |
+
def load_data(self, dataset):
|
982 |
+
import io
|
983 |
+
import pandas as pd
|
984 |
+
from datasets import load_dataset
|
985 |
+
|
986 |
+
hf_dataset = load_dataset("andrewliao11/Q-Spatial-Bench", split=dataset)
|
987 |
+
df = hf_dataset.to_pandas()
|
988 |
+
|
989 |
+
df.reset_index(drop=True, inplace=True)
|
990 |
+
df['index'] = df.index
|
991 |
+
df['answer'] = list(zip(df['answer_value'], df['answer_unit']))
|
992 |
+
df = df[['index'] + [col for col in df.columns if col != 'index']]
|
993 |
+
|
994 |
+
if dataset == "QSpatial_scannet":
|
995 |
+
df = df.drop(columns=["image"])
|
996 |
+
df["image"] = [Image.open(os.path.join(self.qspatial_root, image_path)) for image_path in df["image_path"]]
|
997 |
+
else:
|
998 |
+
df["image"] = [Image.open(io.BytesIO(image_dict["bytes"])) for image_dict in df["image"]]
|
999 |
+
|
1000 |
+
df["image"] = [encode_image_to_base64(image) for image in df["image"]]
|
1001 |
+
return df
|
1002 |
+
|
1003 |
+
@classmethod
|
1004 |
+
def get_multiplier(self, unit):
|
1005 |
+
|
1006 |
+
unit = unit.lower()
|
1007 |
+
if unit in ["meters", "meter", "m", "metre", "metres"]:
|
1008 |
+
multiplier = 100
|
1009 |
+
elif unit in ["centimeters", "centimeter", "cm"]:
|
1010 |
+
multiplier = 1
|
1011 |
+
elif unit in ["feet", "foot", "ft"]:
|
1012 |
+
multiplier = 30.48
|
1013 |
+
elif unit in ["inch", "inches", "in"]:
|
1014 |
+
multiplier = 2.54
|
1015 |
+
elif unit in ["mm"]:
|
1016 |
+
multiplier = 0.1
|
1017 |
+
else:
|
1018 |
+
print(f"Unknown unit: {unit}")
|
1019 |
+
multiplier = 0.
|
1020 |
+
|
1021 |
+
return multiplier
|
1022 |
+
|
1023 |
+
@classmethod
|
1024 |
+
def parse_string(self, input_str):
|
1025 |
+
# Regular expression to match the pattern (number or range, text)
|
1026 |
+
match = re.match(r'\(([\d.-]+), (.+)\)', input_str)
|
1027 |
+
if match:
|
1028 |
+
number_part = match.group(1)
|
1029 |
+
text = match.group(2)
|
1030 |
+
|
1031 |
+
if '-' in number_part:
|
1032 |
+
start, end = map(float, number_part.split('-'))
|
1033 |
+
number = (start + end) / 2
|
1034 |
+
else:
|
1035 |
+
number = float(number_part)
|
1036 |
+
|
1037 |
+
return number * self.get_multiplier(text)
|
1038 |
+
else:
|
1039 |
+
print(f"Unable to parse the input string {input_str}")
|
1040 |
+
return 0
|
1041 |
+
|
1042 |
+
@classmethod
|
1043 |
+
def parse_prediction(self, vlm_response):
|
1044 |
+
# Value
|
1045 |
+
pattern = r'scalar{([^}]*)}'
|
1046 |
+
str_inside_scalar_boxes = re.findall(pattern, vlm_response)[-1]
|
1047 |
+
scalar_list = re.findall(r'\d+\.?\d*', str_inside_scalar_boxes)
|
1048 |
+
parsed_scalar = np.array(scalar_list).astype(float).mean()
|
1049 |
+
|
1050 |
+
# Unit
|
1051 |
+
pattern = r'distance_unit{([^}]*)}'
|
1052 |
+
str_inside_unit_boxes = re.findall(pattern, vlm_response)
|
1053 |
+
parsed_unit = str_inside_unit_boxes[-1]
|
1054 |
+
|
1055 |
+
pred_value_in_cms = parsed_scalar * self.get_multiplier(parsed_unit)
|
1056 |
+
return pred_value_in_cms
|
1057 |
+
|
1058 |
+
# It returns a dictionary
|
1059 |
+
@classmethod
|
1060 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
1061 |
+
|
1062 |
+
data = load(eval_file)
|
1063 |
+
if "model" in judge_kwargs:
|
1064 |
+
from .utils.qspatial import QSpatial_auxeval
|
1065 |
+
|
1066 |
+
# extract using model
|
1067 |
+
model = judge_kwargs['model']
|
1068 |
+
suffix = eval_file.split('.')[-1]
|
1069 |
+
storage = eval_file.replace(f'.{suffix}', f'_{model}.xlsx')
|
1070 |
+
tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
|
1071 |
+
nproc = judge_kwargs.pop('nproc', 4)
|
1072 |
+
|
1073 |
+
if not osp.exists(storage):
|
1074 |
+
model = build_judge(max_tokens=128, **judge_kwargs)
|
1075 |
+
|
1076 |
+
assert model.working(), ('Evaluation requires a working OPENAI API\n' + DEBUG_MESSAGE)
|
1077 |
+
lt = len(data)
|
1078 |
+
lines = [data.iloc[i] for i in range(lt)]
|
1079 |
+
tups = [(model, line) for line in lines]
|
1080 |
+
indices = [line['index'] for line in lines]
|
1081 |
+
|
1082 |
+
ans = {}
|
1083 |
+
if osp.exists(tmp_file):
|
1084 |
+
ans = load(tmp_file)
|
1085 |
+
tups = [x for x, i in zip(tups, indices) if i not in ans]
|
1086 |
+
indices = [i for i in indices if i not in ans]
|
1087 |
+
|
1088 |
+
if len(indices):
|
1089 |
+
new_results = track_progress_rich(
|
1090 |
+
QSpatial_auxeval,
|
1091 |
+
tups,
|
1092 |
+
nproc=nproc,
|
1093 |
+
chunksize=nproc,
|
1094 |
+
keys=indices,
|
1095 |
+
save=tmp_file,
|
1096 |
+
)
|
1097 |
+
ans = load(tmp_file)
|
1098 |
+
for k, v in zip(indices, new_results):
|
1099 |
+
assert k in ans
|
1100 |
+
assert ans[k]['log'] == v['log'] and ans[k]['res'] == v['res']
|
1101 |
+
|
1102 |
+
data['res'] = [ans[idx]['res'] for idx in data['index']]
|
1103 |
+
data['log'] = [ans[idx]['log'] for idx in data['index']]
|
1104 |
+
dump(data, storage)
|
1105 |
+
|
1106 |
+
data = load(storage)
|
1107 |
+
|
1108 |
+
pred_value_in_cms = []
|
1109 |
+
for res in data["res"]:
|
1110 |
+
try:
|
1111 |
+
pred_value_in_cms.append(self.parse_string(res))
|
1112 |
+
except ValueError:
|
1113 |
+
pred_value_in_cms.append(0.)
|
1114 |
+
|
1115 |
+
pred_value_in_cms = np.array(pred_value_in_cms) + 1e-8
|
1116 |
+
else:
|
1117 |
+
# regex parsing
|
1118 |
+
pred_value_in_cms = []
|
1119 |
+
n_errors_in_parsing = 0
|
1120 |
+
for pred in data["prediction"]:
|
1121 |
+
try:
|
1122 |
+
parsed_value = self.parse_prediction(pred)
|
1123 |
+
except IndexError:
|
1124 |
+
n_errors_in_parsing += 1
|
1125 |
+
parsed_value = 1e-8
|
1126 |
+
|
1127 |
+
pred_value_in_cms.append(parsed_value)
|
1128 |
+
|
1129 |
+
print(f"Encounter {n_errors_in_parsing} errors in parsing")
|
1130 |
+
pred_value_in_cms = np.array(pred_value_in_cms) + 1e-8
|
1131 |
+
|
1132 |
+
# Ground truth
|
1133 |
+
ground_truth_value_in_cms = []
|
1134 |
+
for answer in data["answer"]:
|
1135 |
+
value, unit = eval(answer)
|
1136 |
+
ground_truth_value_in_cms.append(value * self.get_multiplier(unit))
|
1137 |
+
ground_truth_value_in_cms = np.array(ground_truth_value_in_cms) + 1e-8
|
1138 |
+
|
1139 |
+
# Calculate the score
|
1140 |
+
pred_gt = pred_value_in_cms / ground_truth_value_in_cms
|
1141 |
+
gt_pred = ground_truth_value_in_cms / pred_value_in_cms
|
1142 |
+
delta_2 = np.stack([pred_gt, gt_pred]).max(0) < 2.
|
1143 |
+
delta_1_point_5 = np.stack([pred_gt, gt_pred]).max(0) < 1.5
|
1144 |
+
|
1145 |
+
data["eval_score_delta_2"] = delta_2
|
1146 |
+
data["eval_score_delta_1_point_5"] = delta_1_point_5
|
1147 |
+
|
1148 |
+
final_score_dict = {
|
1149 |
+
"delta_2": delta_2.mean(),
|
1150 |
+
"delta_1_point_5": delta_1_point_5.mean()
|
1151 |
+
}
|
1152 |
+
for question_type in set(data["question_type"]):
|
1153 |
+
filtered_data = data[data["question_type"] == question_type]
|
1154 |
+
delta_2_per_question_type = filtered_data["eval_score_delta_2"].mean()
|
1155 |
+
delta_1_point_5_per_question_type = filtered_data["eval_score_delta_1_point_5"].mean()
|
1156 |
+
final_score_dict.update({f"{question_type}_delta_2": delta_2_per_question_type})
|
1157 |
+
final_score_dict.update({f"{question_type}_delta_1_point_5": delta_1_point_5_per_question_type})
|
1158 |
+
|
1159 |
+
score_pth = eval_file.replace('.xlsx', '_score.json')
|
1160 |
+
dump(final_score_dict, score_pth)
|
1161 |
+
return final_score_dict
|
1162 |
+
|
1163 |
+
|
1164 |
+
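Before the next class, a small illustrative sketch of how the scalar{...} / distance_unit{...} parsing and the delta-threshold scoring in QSpatial above fit together; the response string and the ground-truth value are made up for the example.

# Illustrative sketch of the QSpatial parsing and delta scoring, on a made-up response.
import re
import numpy as np

UNIT_TO_CM = {'m': 100.0, 'meter': 100.0, 'meters': 100.0, 'cm': 1.0,
              'centimeter': 1.0, 'centimeters': 1.0, 'ft': 30.48, 'feet': 30.48,
              'foot': 30.48, 'in': 2.54, 'inch': 2.54, 'inches': 2.54, 'mm': 0.1}

def parse_to_cm(vlm_response):
    # take the last scalar{...} box, average any numbers inside it
    scalar_str = re.findall(r'scalar{([^}]*)}', vlm_response)[-1]
    scalar = np.array(re.findall(r'\d+\.?\d*', scalar_str)).astype(float).mean()
    # take the last distance_unit{...} box and convert to centimeters
    unit = re.findall(r'distance_unit{([^}]*)}', vlm_response)[-1].strip().lower()
    return scalar * UNIT_TO_CM.get(unit, 0.0)

pred_cm = parse_to_cm('The shelf is about scalar{1.5} distance_unit{meters} away.')
gt_cm = 100.0
ratio = max(pred_cm / gt_cm, gt_cm / pred_cm)
print(pred_cm, ratio < 2.0)  # 150.0 True -> this sample counts toward the delta_2 score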
class MMNIAH(ImageBaseDataset):
|
1165 |
+
TYPE = 'VQA'
|
1166 |
+
DATASET_URL = {
|
1167 |
+
'MM_NIAH_VAL':
|
1168 |
+
'https://huggingface.co/datasets/petter12321/MM-NIAH-VLMEvalKit/resolve/main/MM_NIAH_VAL.tsv',
|
1169 |
+
'MM_NIAH_TEST':
|
1170 |
+
['https://huggingface.co/datasets/petter12321/MM-NIAH-VLMEvalKit/resolve/main/part-aa',
|
1171 |
+
'https://huggingface.co/datasets/petter12321/MM-NIAH-VLMEvalKit/resolve/main/part-ab',
|
1172 |
+
'https://huggingface.co/datasets/petter12321/MM-NIAH-VLMEvalKit/resolve/main/part-ac',
|
1173 |
+
'https://huggingface.co/datasets/petter12321/MM-NIAH-VLMEvalKit/resolve/main/part-ad',
|
1174 |
+
'https://huggingface.co/datasets/petter12321/MM-NIAH-VLMEvalKit/resolve/main/part-ae']}
|
1175 |
+
DATASET_MD5 = {'MM_NIAH_VAL': '27e5a8c3cef7746cb38f89cd86c474c5',
|
1176 |
+
'MM_NIAH_TEST': 'f490eb2a43096307465fe9e7ef13497c'}
|
1177 |
+
|
1178 |
+
def prepare_tsv(self, url, file_md5=None):
|
1179 |
+
import os
|
1180 |
+
data_root = LMUDataRoot()
|
1181 |
+
os.makedirs(data_root, exist_ok=True)
|
1182 |
+
update_flag = False
|
1183 |
+
file_name = 'MM_NIAH_VAL.tsv' if 'MM_NIAH_VAL' in url else 'MM_NIAH_TEST.tsv'
|
1184 |
+
data_path = osp.join(data_root, file_name)
|
1185 |
+
if osp.exists(data_path) and (file_md5 is None or md5(data_path) == file_md5):
|
1186 |
+
pass
|
1187 |
+
elif file_name == 'MM_NIAH_TEST.tsv':
|
1188 |
+
warnings.warn('The dataset tsv is not downloaded')
|
1189 |
+
for i in range(len(url)):
|
1190 |
+
if osp.exists(osp.join(data_root, 'part-a' + chr(ord('a') + i))):
|
1191 |
+
print('part-a' + chr(ord('a') + i) + ' already exists')
|
1192 |
+
continue
|
1193 |
+
download_file(url[i], data_path)
|
1194 |
+
file_prefix = 'part-'
|
1195 |
+
output_file = data_path
|
1196 |
+
split_files = sorted([f for f in os.listdir(data_root) if f.startswith(file_prefix)])
|
1197 |
+
with open(output_file, 'wb') as outfile:
|
1198 |
+
# read each split part in turn and append it to the output file
|
1199 |
+
for filename in split_files:
|
1200 |
+
with open(osp.join(data_root, filename), 'rb') as infile:
|
1201 |
+
outfile.write(infile.read())
|
1202 |
+
update_flag = True
|
1203 |
+
else:
|
1204 |
+
warnings.warn('The dataset tsv is not downloaded')
|
1205 |
+
download_file(url, data_path)
|
1206 |
+
update_flag = True
|
1207 |
+
|
1208 |
+
if file_size(data_path, 'GB') > 1:
|
1209 |
+
local_path = data_path.replace('.tsv', '_local.tsv')
|
1210 |
+
if not osp.exists(local_path) or os.environ.get('FORCE_LOCAL', None) or update_flag:
|
1211 |
+
from ..tools import LOCALIZE
|
1212 |
+
LOCALIZE(data_path, local_path)
|
1213 |
+
data_path = local_path
|
1214 |
+
return load(data_path)
|
1215 |
+
|
1216 |
+
@classmethod
|
1217 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
1218 |
+
from .utils.mmniah import is_correct
|
1219 |
+
# find-image, count-text, find-text,
|
1220 |
+
# infer-choose, count-image, visual-reasoning
|
1221 |
+
MMNIAH_score = {
|
1222 |
+
'count-text': 0,
|
1223 |
+
'find-image': 0,
|
1224 |
+
'find-text': 0,
|
1225 |
+
'infer-choose': 0,
|
1226 |
+
'count-image': 0,
|
1227 |
+
'visual-reasoning': 0,
|
1228 |
+
'total': 0,
|
1229 |
+
}
|
1230 |
+
MMNIAH_num = {
|
1231 |
+
'count-text': 0,
|
1232 |
+
'find-image': 0,
|
1233 |
+
'find-text': 0,
|
1234 |
+
'infer-choose': 0,
|
1235 |
+
'count-image': 0,
|
1236 |
+
'visual-reasoning': 0,
|
1237 |
+
'total': 0,
|
1238 |
+
}
|
1239 |
+
final_score_dict = {
|
1240 |
+
'count-text': 0,
|
1241 |
+
'find-image': 0,
|
1242 |
+
'find-text': 0,
|
1243 |
+
'infer-choose': 0,
|
1244 |
+
'count-image': 0,
|
1245 |
+
'visual-reasoning': 0,
|
1246 |
+
'total': 0,
|
1247 |
+
}
|
1248 |
+
data = load(eval_file)
|
1249 |
+
lt = len(data)
|
1250 |
+
lines = [data.iloc[i] for i in range(lt)]
|
1251 |
+
for i in tqdm(range(len(lines))):
|
1252 |
+
line = lines[i]
|
1253 |
+
predict = line['prediction']
|
1254 |
+
answers = line['answer']
|
1255 |
+
category = line['category']
|
1256 |
+
if category in ['visual-reasoning', 'find-image']:
|
1257 |
+
answers = int(answers)
|
1258 |
+
if is_correct(answers, predict):
|
1259 |
+
MMNIAH_score[category] += 1
|
1260 |
+
MMNIAH_score['total'] += 1
|
1261 |
+
MMNIAH_num[category] += 1
|
1262 |
+
MMNIAH_num['total'] += 1
|
1263 |
+
|
1264 |
+
for category in ['find-image', 'count-text', 'find-text',
|
1265 |
+
'infer-choose', 'count-image', 'visual-reasoning', 'total']:
|
1266 |
+
if MMNIAH_num[category] != 0:
|
1267 |
+
final_score_dict[category] = MMNIAH_score[category] / MMNIAH_num[category]
|
1268 |
+
else:
|
1269 |
+
final_score_dict[category] = None
|
1270 |
+
|
1271 |
+
score_pth = eval_file.replace('.xlsx', '_score.json')
|
1272 |
+
dump(final_score_dict, score_pth)
|
1273 |
+
return final_score_dict
|
1274 |
+
|
1275 |
+
def build_prompt(self, line):
|
1276 |
+
msgs = super().build_prompt(line)
|
1277 |
+
if isinstance(line, int):
|
1278 |
+
line = self.data.iloc[line]
|
1279 |
+
totalchoice = line['multi-choice options']
|
1280 |
+
totalchoice = eval(totalchoice)
|
1281 |
+
# find-image, count-text, find-text,
|
1282 |
+
# infer-choose, count-image, visual-reasoning
|
1283 |
+
context = msgs[-1]['value']
|
1284 |
+
context = eval(context)
|
1285 |
+
question = context[0] + '\n' + context[1]
|
1286 |
+
# tgt_path is the list of all image paths
|
1287 |
+
tgt_path = []
|
1288 |
+
for i in range(len(msgs) - 1):
|
1289 |
+
tgt_path.append(msgs[i]['value'])
|
1290 |
+
choices = totalchoice[0]
|
1291 |
+
choices_image = totalchoice[1]
|
1292 |
+
if choices:
|
1293 |
+
for c_idx, c in enumerate(choices):
|
1294 |
+
question = f"{question}\n{chr(c_idx + ord('A'))}. {c}"
|
1295 |
+
question += "\nAnswer with the option's letter from the given choices directly."
|
1296 |
+
elif choices_image:
|
1297 |
+
for c_idx in range(len(choices_image)):
|
1298 |
+
question = f"{question}\n{chr(c_idx + ord('A'))}. <image>"
|
1299 |
+
question += "\nAnswer with the option's letter from the given choices directly."
|
1300 |
+
else:
|
1301 |
+
question += '\nAnswer the question using a single word or phrase.'
|
1302 |
+
question = '<start>' + question + '<end>'
|
1303 |
+
question = question.split('<image>')
|
1304 |
+
if choices_image:
|
1305 |
+
for i in range(len(question) - 5):
|
1306 |
+
question[i] = question[i] + '\n<image>'
|
1307 |
+
for i in range(len(question) - 5, len(question) - 1):
|
1308 |
+
question[i] = question[i] + '<image>'
|
1309 |
+
else:
|
1310 |
+
for i in range(len(question) - 1):
|
1311 |
+
question[i] = question[i] + '\n<image>'
|
1312 |
+
assert len(tgt_path) + 1 == len(question)
|
1313 |
+
context = []
|
1314 |
+
for i in range(len(tgt_path)):
|
1315 |
+
context.append(question[i])
|
1316 |
+
context.append(tgt_path[i])
|
1317 |
+
context.append(question[-1])
|
1318 |
+
context[0] = context[0][7:]
|
1319 |
+
context[-1] = context[-1][:-5]
|
1320 |
+
msgs = []
|
1321 |
+
for i in range(len(context)):
|
1322 |
+
if i % 2 == 0:
|
1323 |
+
msgs.append(dict(type='text', value=context[i]))
|
1324 |
+
else:
|
1325 |
+
ROOT = LMUDataRoot()
|
1326 |
+
msgs.append(dict(type='image', value=osp.join(osp.join(ROOT, 'images', self.dataset_name), context[i])))
|
1327 |
+
for element in msgs:
|
1328 |
+
if element['value'] == '':
|
1329 |
+
msgs.remove(element)
|
1330 |
+
return msgs
|
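As a condensed illustration of the interleaving idea in MMNIAH.build_prompt above (the real method also handles option lists and the <start>/<end> markers), a question can be split on its <image> placeholders and zipped with the image paths:

# Sketch of text/image interleaving: split on '<image>' and zip with image paths.
def interleave(question, image_paths, placeholder='<image>'):
    parts = question.split(placeholder)
    assert len(parts) == len(image_paths) + 1
    msgs = []
    for part, img in zip(parts, image_paths):
        if part:
            msgs.append(dict(type='text', value=part))
        msgs.append(dict(type='image', value=img))
    if parts[-1]:
        msgs.append(dict(type='text', value=parts[-1]))
    return msgs

demo = interleave('Look at <image> and <image>. Which is larger?', ['a.jpg', 'b.jpg'])
print([m['type'] for m in demo])  # ['text', 'image', 'text', 'image', 'text']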
VLMEvalKit/vlmeval/dataset/image_yorn.py
ADDED
@@ -0,0 +1,95 @@
from ..smp import *
from ..utils import *
from .image_base import ImageBaseDataset
from .utils import build_judge, DEBUG_MESSAGE


class ImageYORNDataset(ImageBaseDataset):

    TYPE = 'Y/N'

    DATASET_URL = {
        'MME': 'https://opencompass.openxlab.space/utils/VLMEval/MME.tsv',
        'HallusionBench': 'https://opencompass.openxlab.space/utils/VLMEval/HallusionBench.tsv',
        'POPE': 'https://opencompass.openxlab.space/utils/VLMEval/POPE.tsv',
        'AMBER': 'https://huggingface.co/datasets/yifanzhang114/AMBER_base64/resolve/main/AMBER.tsv',
    }

    DATASET_MD5 = {
        'MME': 'b36b43c3f09801f5d368627fb92187c3',
        'HallusionBench': '0c23ac0dc9ef46832d7a24504f2a0c7c',
        'POPE': 'c12f5acb142f2ef1f85a26ba2fbe41d5',
        'AMBER': '970d94c0410916166e0a76ba75da7934',
    }

    # It returns a dataframe
    def evaluate(self, eval_file, **judge_kwargs):
        from .utils.yorn import YOrN_Extraction, YOrN_auxeval
        from .utils.yorn import default_rating, MME_rating, Hallusion_rating, POPE_rating, AMBER_rating

        dataset = self.dataset_name
        data = load(eval_file)
        data['prediction'] = [str(x) for x in data['prediction']]
        storage = eval_file.replace('.xlsx', '_auxmatch.xlsx')
        tmp_file = eval_file.replace('.xlsx', '_tmp.pkl')
        nproc = judge_kwargs.pop('nproc', 4)

        if not osp.exists(storage):
            ans_map = {k: YOrN_Extraction(v) for k, v in zip(data['index'], data['prediction'])}
            if osp.exists(tmp_file):
                tmp = load(tmp_file)
                for k in tmp:
                    if ans_map[k] == 'Unknown' and tmp[k] != 'Unknown':
                        ans_map[k] = tmp[k]

            data['extracted'] = [ans_map[x] for x in data['index']]
            unknown = data[data['extracted'] == 'Unknown']

            model = judge_kwargs.get('model', 'exact_matching')
            if model == 'exact_matching':
                model = None
            elif gpt_key_set():
                model = build_judge(**judge_kwargs)
                if not model.working():
                    warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
                    warnings.warn(DEBUG_MESSAGE)
                    model = None
            else:
                model = None
                warnings.warn('OPENAI_API_KEY is not working properly, will use exact matching for evaluation')

            if model is not None:
                lt = len(unknown)
                lines = [unknown.iloc[i] for i in range(lt)]
                tups = [(model, line) for line in lines]
                indices = list(unknown['index'])
                if len(tups):
                    res = track_progress_rich(
                        YOrN_auxeval, tups, nproc=nproc, chunksize=nproc, keys=indices, save=tmp_file)
                    for k, v in zip(indices, res):
                        ans_map[k] = v

            data['extracted'] = [ans_map[x] for x in data['index']]
            dump(data, storage)

        data = load(storage)
        if listinstr(['AMBER'], dataset):
            data['score'] = (data['answer'].str.lower() == data['extracted'].str.lower())
        else:
            data['score'] = (data['answer'] == data['extracted'])
        dump(data, storage)

        if dataset is not None and listinstr(['MME'], dataset):
            score = MME_rating(storage)
        elif dataset is not None and listinstr(['Hallusion'], dataset):
            score = Hallusion_rating(storage)
        elif dataset is not None and listinstr(['POPE'], dataset):
            score = POPE_rating(storage)
        elif dataset is not None and listinstr(['AMBER'], dataset):
            score = AMBER_rating(storage)
        else:
            score = default_rating(storage)

        score_tgt = eval_file.replace('.xlsx', '_score.csv')
        dump(score, score_tgt)
        return score
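A rough sketch of the exact-matching fallback used by ImageYORNDataset.evaluate when no judge model is available; naive_yorn_extraction is a simplified stand-in for the YOrN_Extraction helper in vlmeval/dataset/utils/yorn.py, not the real extractor.

# Naive Yes/No extraction plus exact-match scoring, for illustration only.
def naive_yorn_extraction(prediction):
    words = prediction.strip().lower().replace('.', ' ').replace(',', ' ').split()
    has_yes, has_no = 'yes' in words, 'no' in words
    if has_yes and not has_no:
        return 'Yes'
    if has_no and not has_yes:
        return 'No'
    return 'Unknown'

preds = ['Yes, the dog is on the sofa.', 'No.', 'It is hard to tell.']
answers = ['Yes', 'No', 'Yes']
extracted = [naive_yorn_extraction(p) for p in preds]
scores = [int(e == a) for e, a in zip(extracted, answers)]
print(extracted, sum(scores), 'of', len(scores), 'exact matches')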
VLMEvalKit/vlmeval/dataset/longvideobench.py
ADDED
@@ -0,0 +1,328 @@
1 |
+
from huggingface_hub import snapshot_download
|
2 |
+
from ..smp import *
|
3 |
+
from .video_base import VideoBaseDataset
|
4 |
+
from .utils import build_judge, DEBUG_MESSAGE
|
5 |
+
from glob import glob
|
6 |
+
|
7 |
+
FAIL_MSG = 'Failed to obtain answer via API.'
|
8 |
+
|
9 |
+
|
10 |
+
def timestamp_to_seconds(timestamp):
|
11 |
+
# Split the timestamp into hours, minutes, and seconds
|
12 |
+
h, m, s = timestamp.split(":")
|
13 |
+
# Convert hours, minutes, and total seconds (including fractions) to float and compute total seconds
|
14 |
+
total_seconds = int(h) * 3600 + int(m) * 60 + float(s)
|
15 |
+
return total_seconds
|
16 |
+
|
17 |
+
|
18 |
+
def uniformly_subsample(lst, K):
|
19 |
+
n = len(lst)
|
20 |
+
if K >= n:
|
21 |
+
return lst
|
22 |
+
step = n / K
|
23 |
+
return [lst[int(i * step)] for i in range(K)]
|
24 |
+
|
25 |
+
|
26 |
+
def insert_subtitles_into_frames(
|
27 |
+
frames,
|
28 |
+
frame_timestamps,
|
29 |
+
subtitles,
|
30 |
+
starting_timestamp_for_subtitles,
|
31 |
+
duration,
|
32 |
+
):
|
33 |
+
interleaved_list = []
|
34 |
+
cur_i = 0
|
35 |
+
|
36 |
+
for subtitle in subtitles:
|
37 |
+
if "timestamp" in subtitle:
|
38 |
+
start, end = subtitle["timestamp"]
|
39 |
+
|
40 |
+
if not isinstance(end, float):
|
41 |
+
end = duration
|
42 |
+
|
43 |
+
start -= starting_timestamp_for_subtitles
|
44 |
+
end -= starting_timestamp_for_subtitles
|
45 |
+
|
46 |
+
subtitle_timestamp = (start + end) / 2
|
47 |
+
subtitle_text = subtitle["text"]
|
48 |
+
else:
|
49 |
+
start, end = subtitle["start"], subtitle["end"]
|
50 |
+
start = timestamp_to_seconds(start)
|
51 |
+
end = timestamp_to_seconds(end)
|
52 |
+
start -= starting_timestamp_for_subtitles
|
53 |
+
end -= starting_timestamp_for_subtitles
|
54 |
+
|
55 |
+
subtitle_timestamp = (start + end) / 2
|
56 |
+
subtitle_text = subtitle["line"]
|
57 |
+
|
58 |
+
for i, (frame, frame_timestamp) in enumerate(
|
59 |
+
zip(frames[cur_i:], frame_timestamps[cur_i:])
|
60 |
+
):
|
61 |
+
if frame_timestamp <= subtitle_timestamp:
|
62 |
+
# print("frame:", frame_timestamp)
|
63 |
+
interleaved_list.append({"type": "image", "value": frame})
|
64 |
+
cur_i += 1
|
65 |
+
else:
|
66 |
+
break
|
67 |
+
|
68 |
+
if end - start < 1:
|
69 |
+
end = subtitle_timestamp + 0.5
|
70 |
+
start = subtitle_timestamp - 0.5
|
71 |
+
|
72 |
+
covering_frames = False
|
73 |
+
for frame, frame_timestamp in zip(frames, frame_timestamps):
|
74 |
+
if frame_timestamp < end and frame_timestamp > start:
|
75 |
+
covering_frames = True
|
76 |
+
break
|
77 |
+
|
78 |
+
if covering_frames:
|
79 |
+
interleaved_list.append({"type": "text", "value": subtitle_text + "\n"})
|
80 |
+
else:
|
81 |
+
pass
|
82 |
+
|
83 |
+
for i, (frame, frame_timestamp) in enumerate(
|
84 |
+
zip(frames[cur_i:], frame_timestamps[cur_i:])
|
85 |
+
):
|
86 |
+
interleaved_list.append({"type": "image", "value": frame})
|
87 |
+
return interleaved_list
|
88 |
+
|
89 |
+
|
90 |
+
class LongVideoBench(VideoBaseDataset):
|
91 |
+
|
92 |
+
MD5 = '82905eae3a5ae7383c5a8ee9655e1ab9'
|
93 |
+
SYS = ''
|
94 |
+
|
95 |
+
TYPE = 'Video-MCQ'
|
96 |
+
|
97 |
+
def __init__(self, dataset='LongVideoBench', use_subtitle=False):
|
98 |
+
super().__init__(dataset=dataset)
|
99 |
+
self.use_subtitle = use_subtitle
|
100 |
+
self.dataset_name = dataset
|
101 |
+
|
102 |
+
@classmethod
|
103 |
+
def supported_datasets(cls):
|
104 |
+
return ['LongVideoBench']
|
105 |
+
|
106 |
+
def prepare_dataset(self, dataset_name='LongVideoBench', repo_id='longvideobench/LongVideoBench'):
|
107 |
+
|
108 |
+
def check_integrity(pth):
|
109 |
+
data_file = osp.join(pth, f'{dataset_name}.tsv')
|
110 |
+
|
111 |
+
if not osp.exists(data_file):
|
112 |
+
return False
|
113 |
+
|
114 |
+
if md5(data_file) != self.MD5:
|
115 |
+
print("md5 mismatch", md5(data_file), self.MD5)
|
116 |
+
return False
|
117 |
+
data = load(data_file)
|
118 |
+
for video_pth in data['video_path']:
|
119 |
+
if not osp.exists(osp.join(pth, video_pth)):
|
120 |
+
print(video_pth, "is not found")
|
121 |
+
return False
|
122 |
+
return True
|
123 |
+
|
124 |
+
if modelscope_flag_set():
|
125 |
+
repo_id = "AI-ModelScope/LongVideoBench"
|
126 |
+
|
127 |
+
cache_path = get_cache_path(repo_id)
|
128 |
+
if cache_path is not None and check_integrity(cache_path):
|
129 |
+
dataset_path = cache_path
|
130 |
+
else:
|
131 |
+
def generate_tsv(pth):
|
132 |
+
data_file = osp.join(pth, f'{dataset_name}.tsv')
|
133 |
+
if osp.exists(data_file) and md5(data_file) == self.MD5:
|
134 |
+
return
|
135 |
+
|
136 |
+
data_file = pd.read_json(osp.join(pth, 'lvb_val.json'))
|
137 |
+
data_file = data_file.assign(index=range(len(data_file)))
|
138 |
+
data_file['video'] = data_file['video_id']
|
139 |
+
data_file['video_path'] = data_file['video_path'].apply(lambda x: f'./videos/{x}')
|
140 |
+
|
141 |
+
data_file.to_csv(osp.join(pth, f'{dataset_name}.tsv'), sep='\t', index=False)
|
142 |
+
|
143 |
+
if modelscope_flag_set():
|
144 |
+
from modelscope import dataset_snapshot_download
|
145 |
+
dataset_snapshot_download(dataset_id=repo_id)
|
146 |
+
else:
|
147 |
+
snapshot_download(repo_id=repo_id, repo_type='dataset')
|
148 |
+
print("All videos are downloaded for LongVideoBench")
|
149 |
+
|
150 |
+
if not glob(osp.join(cache_path, "videos")):
|
151 |
+
tar_files = glob(osp.join(cache_path, "**/*.tar*"), recursive=True)
|
152 |
+
|
153 |
+
def untar_video_data(tar_file, cache_dir):
|
154 |
+
import tarfile
|
155 |
+
with tarfile.open(tar_file, "r") as tar_ref:
|
156 |
+
tar_ref.extractall(cache_dir)
|
157 |
+
print(f"Extracted all files from {tar_file} to {cache_dir}")
|
158 |
+
|
159 |
+
def concat_tar_parts(tar_parts, output_tar):
|
160 |
+
with open(output_tar, "wb") as out_tar:
|
161 |
+
from tqdm import tqdm
|
162 |
+
for part in tqdm(sorted(tar_parts)):
|
163 |
+
with open(part, "rb") as part_file:
|
164 |
+
out_tar.write(part_file.read())
|
165 |
+
print(f"Concatenated parts {tar_parts} into {output_tar}")
|
166 |
+
|
167 |
+
tar_parts_dict = {}
|
168 |
+
|
169 |
+
# Group tar parts together
|
170 |
+
for tar_file in tar_files:
|
171 |
+
base_name = tar_file.split(".tar")[0]
|
172 |
+
if base_name not in tar_parts_dict:
|
173 |
+
tar_parts_dict[base_name] = []
|
174 |
+
tar_parts_dict[base_name].append(tar_file)
|
175 |
+
|
176 |
+
# Concatenate and untar split parts
|
177 |
+
for base_name, parts in tar_parts_dict.items():
|
178 |
+
print(f"Extracting following tar files: {parts}")
|
179 |
+
output_tar = base_name + ".tar"
|
180 |
+
if not osp.exists(output_tar):
|
181 |
+
print('Start concatenating tar files')
|
182 |
+
|
183 |
+
concat_tar_parts(parts, output_tar)
|
184 |
+
print('Finish concatenating tar files')
|
185 |
+
|
186 |
+
if not osp.exists(osp.join(cache_path, osp.basename(base_name))):
|
187 |
+
untar_video_data(output_tar, cache_path)
|
188 |
+
|
189 |
+
print('All videos are extracted for LongVideoBench')
|
190 |
+
|
191 |
+
dataset_path = cache_path
|
192 |
+
generate_tsv(dataset_path)
|
193 |
+
|
194 |
+
data_file = osp.join(dataset_path, f'{dataset_name}.tsv')
|
195 |
+
|
196 |
+
return dict(data_file=data_file, root=dataset_path)
|
197 |
+
|
198 |
+
def save_video_frames(self, video_path, num_frames=8, fps=-1, video_llm=False):
|
199 |
+
|
200 |
+
vid_path = osp.join(self.data_root, video_path)
|
201 |
+
vid = decord.VideoReader(vid_path)
|
202 |
+
video_info = {
|
203 |
+
'fps': vid.get_avg_fps(),
|
204 |
+
'n_frames': len(vid),
|
205 |
+
}
|
206 |
+
if num_frames > 0 and fps < 0:
|
207 |
+
step_size = len(vid) / (num_frames + 1)
|
208 |
+
indices = [int(i * step_size) for i in range(1, num_frames + 1)]
|
209 |
+
frame_paths = self.frame_paths(video_path[:-4], num_frames)
|
210 |
+
elif fps > 0:
|
211 |
+
# not constrained by num_frames, get frames by fps
|
212 |
+
total_duration = video_info['n_frames'] / video_info['fps']
|
213 |
+
required_frames = int(total_duration * fps)
|
214 |
+
step_size = video_info['fps'] / fps
|
215 |
+
indices = [int(i * step_size) for i in range(required_frames)]
|
216 |
+
frame_paths = self.frame_paths_fps(video_path[:-4], len(indices), fps)
|
217 |
+
|
218 |
+
flag = np.all([osp.exists(p) for p in frame_paths])
|
219 |
+
|
220 |
+
if not flag:
|
221 |
+
images = [vid[i].asnumpy() for i in indices]
|
222 |
+
images = [Image.fromarray(arr) for arr in images]
|
223 |
+
for im, pth in zip(images, frame_paths):
|
224 |
+
if not osp.exists(pth) and not video_llm:
|
225 |
+
im.save(pth)
|
226 |
+
|
227 |
+
return frame_paths, indices, video_info
|
228 |
+
|
229 |
+
def save_video_into_images(self, line, num_frames=8):
|
230 |
+
frame_paths, indices, video_info = self.save_video_frames(line['video_path'], num_frames)
|
231 |
+
return frame_paths
|
232 |
+
|
233 |
+
def build_prompt(self, line, num_frames, video_llm, fps):
|
234 |
+
if isinstance(line, int):
|
235 |
+
assert line < len(self)
|
236 |
+
line = self.data.iloc[line]
|
237 |
+
|
238 |
+
frames, indices, video_info = self.save_video_frames(line['video_path'], num_frames, fps, video_llm)
|
239 |
+
fps = video_info["fps"]
|
240 |
+
|
241 |
+
message = [dict(type='text', value=self.SYS)]
|
242 |
+
if video_llm:
|
243 |
+
message.append(dict(type='video', value=osp.join(self.data_root, line['video_path'])))
|
244 |
+
else:
|
245 |
+
if not self.use_subtitle:
|
246 |
+
with open(osp.join(self.data_root, "subtitles", line["subtitle_path"])) as f:
|
247 |
+
subtitles = json.load(f)
|
248 |
+
|
249 |
+
frame_message = insert_subtitles_into_frames(
|
250 |
+
frames,
|
251 |
+
[ind_ / fps for ind_ in indices],
|
252 |
+
subtitles,
|
253 |
+
line["starting_timestamp_for_subtitles"],
|
254 |
+
line["duration"]
|
255 |
+
)
|
256 |
+
|
257 |
+
message += frame_message
|
258 |
+
else:
|
259 |
+
for im in frames:
|
260 |
+
message.append(dict(type='image', value=im))
|
261 |
+
|
262 |
+
line['question'] += '\n' + '\n'.join(
|
263 |
+
["{}. {}".format(chr(ord("A") + i), cand) for i, cand in enumerate(eval(line['candidates']))]
|
264 |
+
)
|
265 |
+
prompt = line["question"] + "\nAnswer with the option's letter from the given choices directly."
|
266 |
+
message.append(dict(type='text', value=prompt))
|
267 |
+
return message
|
268 |
+
|
269 |
+
# It returns a dictionary
|
270 |
+
@classmethod
|
271 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
272 |
+
from .utils.longvideobench import get_dimension_rating, extract_characters_regex, extract_option
|
273 |
+
|
274 |
+
assert eval_file.endswith('.xlsx'), 'data file should be an xlsx file'
|
275 |
+
|
276 |
+
tmp_file = eval_file.replace('.xlsx', '_tmp.pkl')
|
277 |
+
tgt_file = eval_file.replace('.xlsx', '_rating.json')
|
278 |
+
score_file = eval_file.replace('.xlsx', '_score.xlsx')
|
279 |
+
|
280 |
+
if not osp.exists(score_file):
|
281 |
+
model = judge_kwargs.get('model', 'exact_matching')
|
282 |
+
assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']
|
283 |
+
|
284 |
+
if model == 'exact_matching':
|
285 |
+
model = None
|
286 |
+
elif gpt_key_set():
|
287 |
+
model = build_judge(**judge_kwargs)
|
288 |
+
if not model.working():
|
289 |
+
warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
|
290 |
+
warnings.warn(DEBUG_MESSAGE)
|
291 |
+
model = None
|
292 |
+
else:
|
293 |
+
warnings.warn('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
|
294 |
+
model = None
|
295 |
+
res = {} if not osp.exists(tmp_file) else load(tmp_file)
|
296 |
+
res = {k: v for k, v in res.items() if FAIL_MSG not in v}
|
297 |
+
|
298 |
+
data = load(eval_file)
|
299 |
+
data_un = data[~pd.isna(data['prediction'])]
|
300 |
+
|
301 |
+
for idx in data['index']:
|
302 |
+
ans = data.loc[data['index'] == idx, 'correct_choice'].values[0]
|
303 |
+
ans = chr(ord("A") + ans)
|
304 |
+
pred = str(data.loc[data['index'] == idx, 'prediction'].values[0])
|
305 |
+
|
306 |
+
if extract_characters_regex(pred) == '':
|
307 |
+
extract_pred = extract_option(
|
308 |
+
model,
|
309 |
+
data.loc[data['index'] == idx].to_dict(orient='records')[0],
|
310 |
+
'LongVideoBench'
|
311 |
+
)
|
312 |
+
data.loc[idx, 'score'] = int(extract_pred == ans)
|
313 |
+
else:
|
314 |
+
data.loc[idx, 'score'] = int(extract_characters_regex(pred) == ans)
|
315 |
+
|
316 |
+
rejected = [x for x in data['score'] if x == -1]
|
317 |
+
|
318 |
+
print(
|
319 |
+
f'Among {len(data)} questions, failed to obtain prediction for {len(data) - len(data_un)} questions, '
|
320 |
+
f'failed to obtain the score for another {len(rejected)} questions. '
|
321 |
+
f'Those questions will be counted as -1 score in ALL rating, and will not be counted in VALID rating.'
|
322 |
+
)
|
323 |
+
|
324 |
+
dump(data, score_file)
|
325 |
+
|
326 |
+
rating = get_dimension_rating(score_file)
|
327 |
+
dump(rating, tgt_file)
|
328 |
+
return rating
|
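A simplified sketch of the timestamp arithmetic behind insert_subtitles_into_frames above: each subtitle is anchored at the midpoint of its interval and slotted between frames sampled at known times. The real function additionally handles the structured 'timestamp' format and the frame-coverage check; the values below are made up.

# Illustrative ordering of frames and one subtitle by midpoint timestamp.
def to_seconds(ts):
    h, m, s = ts.split(':')
    return int(h) * 3600 + int(m) * 60 + float(s)

frame_times = [10.0, 20.0, 30.0, 40.0]          # seconds of the sampled frames
sub = {'start': '00:00:24.0', 'end': '00:00:27.0', 'line': 'Hello there'}
midpoint = (to_seconds(sub['start']) + to_seconds(sub['end'])) / 2   # 25.5 s

sequence = []
for t in frame_times:
    if t <= midpoint:
        sequence.append(('image', t))
    else:
        break
sequence.append(('text', sub['line']))
sequence += [('image', t) for t in frame_times if t > midpoint]
print(sequence)  # frames at 10/20 s, then the subtitle, then frames at 30/40 s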
VLMEvalKit/vlmeval/dataset/miabench.py
ADDED
@@ -0,0 +1,167 @@
1 |
+
import json
|
2 |
+
import os
|
3 |
+
|
4 |
+
import pandas as pd
|
5 |
+
|
6 |
+
from .image_base import ImageBaseDataset
|
7 |
+
from ..smp import *
|
8 |
+
from .utils import build_judge, DEBUG_MESSAGE
|
9 |
+
from ..utils import track_progress_rich
|
10 |
+
|
11 |
+
|
12 |
+
def generate_prompt(d):
|
13 |
+
question = d['question']
|
14 |
+
weights = eval(d['component_weight'])
|
15 |
+
components = eval(d['components'])
|
16 |
+
num_of_component = int(d['num_of_component'])
|
17 |
+
response = d['prediction']
|
18 |
+
|
19 |
+
if num_of_component == 1:
|
20 |
+
components = f"The first component is: '{components[0]}'. "
|
21 |
+
score = f"The first component is worth: {weights[0]} scores. "
|
22 |
+
elif num_of_component == 2:
|
23 |
+
components = f"The first component is: '{components[0]}', and the second component is '{components[1]}'. "
|
24 |
+
score = f"The first and second component is each worth {weights[0]} and {weights[1]} scores. "
|
25 |
+
elif num_of_component == 3:
|
26 |
+
components = (
|
27 |
+
f"The first component is: '{components[0]}', and the second component is '{components[1]}', "
|
28 |
+
f"and the third component is '{components[2]}'. "
|
29 |
+
)
|
30 |
+
score = (
|
31 |
+
"The first, second, and third component is each worth "
|
32 |
+
f"{weights[0]}, {weights[1]}, and {weights[2]} scores."
|
33 |
+
)
|
34 |
+
elif num_of_component == 4:
|
35 |
+
components = (
|
36 |
+
f"The first component is: '{components[0]}', and the second component is '{components[1]}', "
|
37 |
+
f"and the third component is '{components[2]}', and the fourth component is '{components[3]}'. "
|
38 |
+
)
|
39 |
+
score = (
|
40 |
+
"The first, second, third, and fourth component is each worth "
|
41 |
+
f"{weights[0]}, {weights[1]}, {weights[2]}, and {weights[3]} scores."
|
42 |
+
)
|
43 |
+
elif num_of_component == 5:
|
44 |
+
components = (
|
45 |
+
f"The first component is: '{components[0]}', and the second component is '{components[1]}', "
|
46 |
+
f"and the third component is '{components[2]}', and the fourth component is '{components[3]}', "
|
47 |
+
f"and the fifth component is '{components[4]}'. "
|
48 |
+
)
|
49 |
+
score = (
|
50 |
+
"The first, second, third, fourth, and fifth component is each worth "
|
51 |
+
f"{weights[0]}, {weights[1]}, {weights[2]}, {weights[3]}, and {weights[4]} scores."
|
52 |
+
)
|
53 |
+
|
54 |
+
return (
|
55 |
+
"Here is an instruction for a multimodal LLM: '"
|
56 |
+
f"{question}"
|
57 |
+
"'. You need to grade if the response from the model follows each component of the instruction. "
|
58 |
+
f"{components}"
|
59 |
+
"The response is: '"
|
60 |
+
f"{response}"
|
61 |
+
"'. You need to score the response and be strict. The total score ranges from 0 to 10, "
|
62 |
+
"depending on if the response follows the instruction. "
|
63 |
+
f"{score}"
|
64 |
+
"List scores of each component, and the total score in one sentence in this format: "
|
65 |
+
"score of component 1: x/2, score of component 2: y/8, total score: z/10. Then explain your reasons."
|
66 |
+
)
|
67 |
+
|
68 |
+
|
69 |
+
def process_rawscore(component_type, raw_score):
|
70 |
+
first_sentence = raw_score.split('.')[0].split(',')
|
71 |
+
score_dict = {}
|
72 |
+
for i in range(len(first_sentence) - 1):
|
73 |
+
score_ = first_sentence[i].split(':')[1][1:].split('/')
|
74 |
+
score = int(score_[0]) / int(score_[1])
|
75 |
+
score_dict[component_type[i]] = score
|
76 |
+
total_score_ = first_sentence[i + 1].split(':')[1][1:].split('/')
|
77 |
+
total_score = int(total_score_[0]) / int(total_score_[1])
|
78 |
+
score_dict['total_score'] = total_score
|
79 |
+
return score_dict
|
80 |
+
|
81 |
+
|
82 |
+
def get_score_dict(data, score_raw):
|
83 |
+
cat_score_dict = {}
|
84 |
+
for i in range(len(data)):
|
85 |
+
try:
|
86 |
+
cmp = data['component_type'][i][2:-2]
|
87 |
+
cmp_list = cmp.split('\', \'')
|
88 |
+
score_dict = process_rawscore(cmp_list, score_raw[i])
|
89 |
+
for key, val in score_dict.items():
|
90 |
+
if key not in cat_score_dict.keys():
|
91 |
+
cat_score_dict[key] = [val]
|
92 |
+
else:
|
93 |
+
cat_score_dict[key].append(val)
|
94 |
+
except:
|
95 |
+
pass
|
96 |
+
cat_score_dict_average = {}
|
97 |
+
for key, val in cat_score_dict.items():
|
98 |
+
cat_score_dict_average[key] = sum(val) / len(val)
|
99 |
+
return cat_score_dict_average
|
100 |
+
|
101 |
+
|
102 |
+
class MIABench(ImageBaseDataset):
|
103 |
+
TYPE = 'VQA'
|
104 |
+
|
105 |
+
DATASET_URL = {
|
106 |
+
'MIA-Bench': 'https://opencompass.openxlab.space/utils/VLMEval/Mia-Bench.tsv',
|
107 |
+
}
|
108 |
+
DATASET_MD5 = {
|
109 |
+
'MIA-Bench': '0b9de595f4dd40af18a69b94d89aba82',
|
110 |
+
}
|
111 |
+
|
112 |
+
@classmethod
|
113 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
114 |
+
judge_name = judge_kwargs.pop('model', 'gpt-4o')
|
115 |
+
|
116 |
+
model = build_judge(model=judge_name, **judge_kwargs)
|
117 |
+
suffix = eval_file.split('.')[-1]
|
118 |
+
|
119 |
+
storage = eval_file.replace(f'.{suffix}', f'_{judge_name}.xlsx') # noqa: F841
|
120 |
+
tmp_file = eval_file.replace(f'.{suffix}', f'_{judge_name}.pkl') # noqa: F841
|
121 |
+
nproc = judge_kwargs.pop('nproc', 4) # noqa: F841
|
122 |
+
|
123 |
+
if not osp.exists(storage):
|
124 |
+
data = load(eval_file)
|
125 |
+
num_samples = len(data)
|
126 |
+
lines = [data.loc[i] for i in range(num_samples)]
|
127 |
+
prompts = [generate_prompt(line) for line in lines]
|
128 |
+
org_data = MIABench('MIA-Bench').data
|
129 |
+
img_map = {x: y for x, y in zip(org_data['index'], org_data['image'])}
|
130 |
+
image_b64 = [img_map[idx] for idx in data['index']]
|
131 |
+
indices = list(data['index'])
|
132 |
+
mm_messages = [
|
133 |
+
dict(message=[
|
134 |
+
dict(type='text', value=prompt),
|
135 |
+
dict(type='image', value=f'data:image/jpeg;base64,{b64}')
|
136 |
+
])
|
137 |
+
for prompt, b64 in zip(prompts, image_b64)
|
138 |
+
]
|
139 |
+
|
140 |
+
res = {}
|
141 |
+
if osp.exists(tmp_file):
|
142 |
+
res = load(tmp_file)
|
143 |
+
|
144 |
+
jobs = {k: v for k, v in zip(indices, mm_messages) if k not in res}
|
145 |
+
job_keys = list(jobs.keys())
|
146 |
+
job_vals = [jobs[k] for k in job_keys]
|
147 |
+
|
148 |
+
resps = track_progress_rich(
|
149 |
+
model.generate,
|
150 |
+
job_vals,
|
151 |
+
nproc=nproc,
|
152 |
+
chunksize=nproc,
|
153 |
+
keys=job_keys,
|
154 |
+
save=tmp_file,
|
155 |
+
)
|
156 |
+
for k, resp in zip(job_keys, resps):
|
157 |
+
res[k] = resp
|
158 |
+
data['score_raw'] = [res[idx] for idx in indices]
|
159 |
+
dump(data, storage)
|
160 |
+
|
161 |
+
goresult = load(storage)
|
162 |
+
results = get_score_dict(goresult, goresult['score_raw'])
|
163 |
+
result_pth = storage.replace('.xlsx', '_score.csv')
|
164 |
+
results_pd = pd.DataFrame.from_dict(list(results.items()))
|
165 |
+
dump(results_pd, result_pth)
|
166 |
+
|
167 |
+
return results
|
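A worked example of the score-sentence format that process_rawscore above expects from the judge; the two component names here are placeholders for illustration, not fixed MIA-Bench categories.

# Parse "score of component 1: x/2, ..., total score: z/10" from a made-up judge reply.
raw = ('score of component 1: 2/2, score of component 2: 6/8, total score: 8/10. '
       'The response follows the first component fully ...')

first_sentence = raw.split('.')[0].split(',')
parsed = {}
for i, comp in enumerate(['visual recognition', 'answer format']):   # placeholder names
    num, den = first_sentence[i].split(':')[1][1:].split('/')
    parsed[comp] = int(num) / int(den)
num, den = first_sentence[-1].split(':')[1][1:].split('/')
parsed['total_score'] = int(num) / int(den)
print(parsed)  # {'visual recognition': 1.0, 'answer format': 0.75, 'total_score': 0.8}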
VLMEvalKit/vlmeval/dataset/mlvu.py
ADDED
@@ -0,0 +1,455 @@
1 |
+
import huggingface_hub
|
2 |
+
from huggingface_hub import snapshot_download
|
3 |
+
from ..smp import *
|
4 |
+
from .video_concat_dataset import ConcatVideoDataset
|
5 |
+
from .video_base import VideoBaseDataset
|
6 |
+
from .utils import build_judge, DEBUG_MESSAGE
|
7 |
+
from ..utils import track_progress_rich
|
8 |
+
import torchvision.transforms as T
|
9 |
+
from torchvision import transforms
|
10 |
+
from torchvision.transforms.functional import InterpolationMode
|
11 |
+
from decord import VideoReader, cpu
|
12 |
+
import pandas as pd
|
13 |
+
import imageio
|
14 |
+
import cv2
|
15 |
+
import zipfile
|
16 |
+
import os
|
17 |
+
import glob
|
18 |
+
from .utils.mlvu import *
|
19 |
+
|
20 |
+
FAIL_MSG = 'Failed to obtain answer via API.'
|
21 |
+
|
22 |
+
|
23 |
+
class MLVU(ConcatVideoDataset):
|
24 |
+
def __init__(self, dataset='MLVU'):
|
25 |
+
self.DATASET_SETS[dataset] = ['MLVU_MCQ', 'MLVU_OpenEnded']
|
26 |
+
self.type_data_dict = {
|
27 |
+
'M-Avg':['plotQA', 'needle', 'ego', 'count', 'anomaly_reco', 'topic_reasoning'],
|
28 |
+
'G-Avg':['sub_scene', 'summary']
|
29 |
+
}
|
30 |
+
super().__init__(dataset=dataset)
|
31 |
+
|
32 |
+
@classmethod
|
33 |
+
def supported_datasets(cls):
|
34 |
+
return ['MLVU']
|
35 |
+
|
36 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
37 |
+
result = super().evaluate(eval_file=eval_file, **judge_kwargs)
|
38 |
+
suffix = eval_file.split('.')[-1]
|
39 |
+
score_file = eval_file.replace(f'.{suffix}', '_acc.csv')
|
40 |
+
for key in self.type_data_dict:
|
41 |
+
result.loc[key] = 0.0
|
42 |
+
for name, item in result.iterrows():
|
43 |
+
if name in self.type_data_dict[key]:
|
44 |
+
result.loc[key, 'success'] += item['success']
|
45 |
+
result.loc[key, 'overall'] += item['overall']
|
46 |
+
if key == 'G-Avg':
|
47 |
+
result.loc[key, 'acc'] = round(
|
48 |
+
result.loc[key, 'success'] / result.loc[key, 'overall'], 2
|
49 |
+
)
|
50 |
+
else:
|
51 |
+
result.loc[key, 'acc'] = round(
|
52 |
+
result.loc[key, 'success'] / result.loc[key, 'overall'] * 100, 1
|
53 |
+
)
|
54 |
+
result = result.reset_index().rename(columns={'index': 'task'})
|
55 |
+
dump(result, score_file)
|
56 |
+
return result
|
57 |
+
|
58 |
+
|
59 |
+
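A hedged sketch of the M-Avg / G-Avg roll-up that MLVU.evaluate performs above, on a made-up per-task success/overall table; the task names, counts, and grouping here are illustrative only.

# Aggregate per-task counts into M-Avg (percentage) and G-Avg (ratio) rows.
import pandas as pd

result = pd.DataFrame(
    {'success': [40.0, 35.0, 18.5], 'overall': [50.0, 50.0, 25.0]},
    index=['plotQA', 'needle', 'sub_scene'],
)
groups = {'M-Avg': ['plotQA', 'needle'], 'G-Avg': ['sub_scene']}

for key, tasks in groups.items():
    result.loc[key] = 0.0                       # start an aggregate row, as evaluate() does
    for name in tasks:
        result.loc[key, 'success'] += result.loc[name, 'success']
        result.loc[key, 'overall'] += result.loc[name, 'overall']
    if key == 'G-Avg':                          # open-ended split reported as a ratio
        result.loc[key, 'acc'] = round(result.loc[key, 'success'] / result.loc[key, 'overall'], 2)
    else:                                       # MCQ split reported as a percentage
        result.loc[key, 'acc'] = round(result.loc[key, 'success'] / result.loc[key, 'overall'] * 100, 1)

print(result.loc[['M-Avg', 'G-Avg'], 'acc'])    # 75.0 and 0.74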
class MLVU_MCQ(VideoBaseDataset):
|
60 |
+
|
61 |
+
MD5 = 'bb5c37e7cf8d43fc9a25c23d2b4633f5'
|
62 |
+
BASE_SYS = 'Carefully watch this video and pay attention to every detail. '
|
63 |
+
SYS = BASE_SYS + 'Based on your observations, select the best option that accurately addresses the question.'
|
64 |
+
TYPE = 'Video-MCQ'
|
65 |
+
|
66 |
+
def __init__(self, dataset='MLVU_MCQ'):
|
67 |
+
self.type_data_list = {
|
68 |
+
'plotQA': ('1_plotQA.json', './MLVU/video/1_plotQA', 'MCQ'),
|
69 |
+
'needle': ('2_needle.json', './MLVU/video/2_needle', 'MCQ'),
|
70 |
+
'ego': ('3_ego.json', './MLVU/video/3_ego', 'MCQ'),
|
71 |
+
'count': ('4_count.json', './MLVU/video/4_count', 'MCQ'),
|
72 |
+
'order': ('5_order.json', './MLVU/video/5_order', 'MCQ'),
|
73 |
+
'anomaly_reco': ('6_anomaly_reco.json', './MLVU/video/6_anomaly_reco', 'MCQ'),
|
74 |
+
'topic_reasoning': ('7_topic_reasoning.json', './MLVU/video/7_topic_reasoning', 'MCQ'),
|
75 |
+
}
|
76 |
+
super().__init__(dataset=dataset)
|
77 |
+
|
78 |
+
@classmethod
|
79 |
+
def supported_datasets(cls):
|
80 |
+
return ['MLVU_MCQ']
|
81 |
+
|
82 |
+
def prepare_dataset(self, dataset_name='MLVU_MCQ', repo_id='MLVU/MVLU'):
|
83 |
+
def check_integrity(pth):
|
84 |
+
data_file = osp.join(pth, f'{dataset_name}.tsv')
|
85 |
+
|
86 |
+
if not os.path.exists(data_file):
|
87 |
+
return False
|
88 |
+
|
89 |
+
if md5(data_file) != self.MD5:
|
90 |
+
return False
|
91 |
+
|
92 |
+
data = load(data_file)
|
93 |
+
for idx, item in data.iterrows():
|
94 |
+
if not osp.exists(osp.join(pth, item['prefix'], item['video'])):
|
95 |
+
return False
|
96 |
+
return True
|
97 |
+
|
98 |
+
if modelscope_flag_set():
|
99 |
+
repo_id = "AI-ModelScope/MLVU"
|
100 |
+
|
101 |
+
cache_path = get_cache_path(repo_id)
|
102 |
+
if cache_path is not None and check_integrity(cache_path):
|
103 |
+
dataset_path = cache_path
|
104 |
+
else:
|
105 |
+
def generate_tsv(pth):
|
106 |
+
data_file = osp.join(pth, f'{dataset_name}.tsv')
|
107 |
+
if os.path.exists(data_file) and md5(data_file) == self.MD5:
|
108 |
+
return
|
109 |
+
json_data_dir = os.path.join(dataset_path, 'MLVU', 'json')
|
110 |
+
self.data_list = []
|
111 |
+
for k, v in self.type_data_list.items():
|
112 |
+
with open(os.path.join(json_data_dir, v[0]), 'r') as f:
|
113 |
+
json_data = json.load(f)
|
114 |
+
for data in json_data:
|
115 |
+
self.data_list.append({
|
116 |
+
'task_type': k,
|
117 |
+
'prefix': v[1],
|
118 |
+
'duration': data['duration'],
|
119 |
+
'video': data['video'],
|
120 |
+
'question': data['question'],
|
121 |
+
'answer': data['answer'],
|
122 |
+
'candidates': data['candidates'],
|
123 |
+
})
|
124 |
+
|
125 |
+
data_df = pd.DataFrame(self.data_list)
|
126 |
+
data_df = data_df.assign(index=range(len(data_df)))
|
127 |
+
data_df.to_csv(data_file, sep='\t', index=False)
|
128 |
+
|
129 |
+
if modelscope_flag_set():
|
130 |
+
from modelscope import dataset_snapshot_download
|
131 |
+
dataset_path = dataset_snapshot_download(dataset_id=repo_id)
|
132 |
+
else:
|
133 |
+
hf_token = os.environ.get('HUGGINGFACE_TOKEN')
|
134 |
+
huggingface_hub.login(hf_token)
|
135 |
+
dataset_path = snapshot_download(repo_id=repo_id, repo_type='dataset')
|
136 |
+
|
137 |
+
generate_tsv(dataset_path)
|
138 |
+
|
139 |
+
data_file = osp.join(dataset_path, f'{dataset_name}.tsv')
|
140 |
+
return dict(root=dataset_path, data_file=data_file)
|
141 |
+
|
142 |
+
def qa_template(self, data):
|
143 |
+
question = f"Question: {data['question']}\n"
|
144 |
+
question += 'Options:\n'
|
145 |
+
answer = data['answer']
|
146 |
+
answer_idx = -1
|
147 |
+
for idx, c in enumerate(eval(data['candidates'])):
|
148 |
+
question += f"({chr(ord('A') + idx)}) {c}\n"
|
149 |
+
if c == answer:
|
150 |
+
answer_idx = idx
|
151 |
+
question = question.rstrip()
|
152 |
+
answer = f"({chr(ord('A') + answer_idx)}) {answer}"
|
153 |
+
return question, answer
|
154 |
+
|
155 |
+
def save_video_frames(self, line, num_frames=8, fps=-1):
|
156 |
+
suffix = line['video'].split('.')[-1]
|
157 |
+
video = line['video'].replace(f'.{suffix}','')
|
158 |
+
vid_path = osp.join(self.data_root, line['prefix'], line['video'])
|
159 |
+
vid = decord.VideoReader(vid_path)
|
160 |
+
video_info = {
|
161 |
+
'fps': vid.get_avg_fps(),
|
162 |
+
'n_frames': len(vid),
|
163 |
+
}
|
164 |
+
if num_frames > 0 and fps < 0:
|
165 |
+
step_size = len(vid) / (num_frames + 1)
|
166 |
+
indices = [int(i * step_size) for i in range(1, num_frames + 1)]
|
167 |
+
frame_paths = self.frame_paths(video, num_frames)
|
168 |
+
elif fps > 0:
|
169 |
+
# not constrained by num_frames, get frames by fps
|
170 |
+
total_duration = video_info['n_frames'] / video_info['fps']
|
171 |
+
required_frames = int(total_duration * fps)
|
172 |
+
step_size = video_info['fps'] / fps
|
173 |
+
indices = [int(i * step_size) for i in range(required_frames)]
|
174 |
+
frame_paths = self.frame_paths_fps(video, len(indices), fps)
|
175 |
+
|
176 |
+
flag = np.all([osp.exists(p) for p in frame_paths])
|
177 |
+
|
178 |
+
if not flag:
|
179 |
+
images = [vid[i].asnumpy() for i in indices]
|
180 |
+
images = [Image.fromarray(arr) for arr in images]
|
181 |
+
for im, pth in zip(images, frame_paths):
|
182 |
+
if not osp.exists(pth):
|
183 |
+
im.save(pth)
|
184 |
+
|
185 |
+
return frame_paths
|
186 |
+
|
187 |
+
def save_video_into_images(self, line, num_frames, fps):
|
188 |
+
frame_paths = self.save_video_frames(line, num_frames, fps)
|
189 |
+
return frame_paths
|
190 |
+
|
191 |
+
def build_prompt(self, line, num_frames, video_llm, fps=-1):
|
192 |
+
if isinstance(line, int):
|
193 |
+
assert line < len(self)
|
194 |
+
line = self.data.iloc[line]
|
195 |
+
|
196 |
+
question, answer = self.qa_template(line)
|
197 |
+
message = [dict(type='text', value=self.SYS, role='system')]
|
198 |
+
message.append(dict(type='text', value=question))
|
199 |
+
video_path = os.path.join(self.data_root, line['prefix'], line['video'])
|
200 |
+
if video_llm:
|
201 |
+
message.append(dict(type='video', value=video_path))
|
202 |
+
else:
|
203 |
+
img_frame_paths = self.save_video_into_images(line, num_frames, fps)
|
204 |
+
for im in img_frame_paths:
|
205 |
+
message.append(dict(type='image', value=im))
|
206 |
+
message.append(dict(type='text', value='\nOnly give the best option.'))
|
207 |
+
return message
|
208 |
+
|
209 |
+
@classmethod
|
210 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
211 |
+
assert eval_file.endswith('.xlsx'), 'data file should be an xlsx file'
|
212 |
+
|
213 |
+
tmp_file = eval_file.replace('.xlsx', '_tmp.pkl')
|
214 |
+
score_file = eval_file.replace('.xlsx', '_score.xlsx')
|
215 |
+
|
216 |
+
if not osp.exists(score_file):
|
217 |
+
model = judge_kwargs.setdefault('model', 'chatgpt-0125')
|
218 |
+
assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']
|
219 |
+
|
220 |
+
if model == 'exact_matching':
|
221 |
+
model = None
|
222 |
+
elif gpt_key_set():
|
223 |
+
model = build_judge(**judge_kwargs)
|
224 |
+
if not model.working():
|
225 |
+
warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
|
226 |
+
warnings.warn(DEBUG_MESSAGE)
|
227 |
+
model = None
|
228 |
+
else:
|
229 |
+
warnings.warn('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
|
230 |
+
model = None
|
231 |
+
res = {} if not osp.exists(tmp_file) else load(tmp_file)
|
232 |
+
res = {k: v for k, v in res.items() if FAIL_MSG not in v}
|
233 |
+
|
234 |
+
data = load(eval_file)
|
235 |
+
data_un = data[~pd.isna(data['prediction'])]
|
236 |
+
|
237 |
+
for idx in data['index']:
|
238 |
+
ans = data.loc[data['index'] == idx, 'answer'].values[0]
|
239 |
+
pred = data.loc[data['index'] == idx, 'prediction'].values[0]
|
240 |
+
options = eval(data.loc[data['index'] == idx, 'candidates'].values[0])
|
241 |
+
answer_idx = -1
|
242 |
+
for id, c in enumerate(options):
|
243 |
+
if c == ans:
|
244 |
+
answer_idx = id
|
245 |
+
ans = f"({chr(ord('A') + answer_idx)}) {ans}"
|
246 |
+
input_item = data.loc[data['index'] == idx].to_dict(orient='records')[0]
|
247 |
+
for id, option_content in enumerate(eval(input_item['candidates'])):
|
248 |
+
input_item[chr(ord('A') + id)] = option_content
|
249 |
+
if option_content == input_item['answer']:
|
250 |
+
input_item['answer'] = chr(ord('A') + id)
|
251 |
+
|
252 |
+
if FAIL_MSG in pred:
|
253 |
+
data.loc[idx, 'score'] = -1
|
254 |
+
else:
|
255 |
+
data.loc[idx, 'score'] = int(check_ans_with_model(
|
256 |
+
pred, ans, model,
|
257 |
+
input_item,
|
258 |
+
'MLVU_MCQ'
|
259 |
+
))
|
260 |
+
|
261 |
+
rejected = [x for x in data['score'] if x == -1]
|
262 |
+
|
263 |
+
print(
|
264 |
+
f'Among {len(data)} questions, failed to obtain prediction for {len(data) - len(data_un)} questions, '
|
265 |
+
f'failed to obtain the score for another {len(rejected)} questions. '
|
266 |
+
f'Those questions will be counted as -1 score in ALL rating, and will not be counted in VALID rating.'
|
267 |
+
)
|
268 |
+
|
269 |
+
dump(data, score_file)
|
270 |
+
|
271 |
+
rating = get_dimension_rating(score_file)
|
272 |
+
return rating
|
273 |
+
|
274 |
+
|
275 |
+
class MLVU_OpenEnded(VideoBaseDataset):
|
276 |
+
|
277 |
+
MD5 = 'cee573a3627c6ac434ded704c60511ba'
|
278 |
+
BASE_SYS = 'Carefully watch this video and pay attention to every detail. '
|
279 |
+
SYS = BASE_SYS + 'Based on your observations, answer the given questions.'
|
280 |
+
TYPE = 'Video-VQA'
|
281 |
+
|
282 |
+
def __init__(self, dataset='MLVU_OpenEnded'):
|
283 |
+
self.type_data_list = {
|
284 |
+
'sub_scene': ('8_sub_scene.json', './MLVU/video/8_sub_scene', 'VQA'),
|
285 |
+
'summary': ('9_summary.json', './MLVU/video/9_summary', 'VQA')
|
286 |
+
}
|
287 |
+
super().__init__(dataset=dataset)
|
288 |
+
|
289 |
+
@classmethod
|
290 |
+
def supported_datasets(cls):
|
291 |
+
return ['MLVU_OpenEnded']
|
292 |
+
|
293 |
+
def prepare_dataset(self, dataset_name='MLVU_OpenEnded', repo_id='MLVU/MVLU'):
|
294 |
+
def check_integrity(pth):
|
295 |
+
data_file = osp.join(pth, f'{dataset_name}.tsv')
|
296 |
+
|
297 |
+
if not os.path.exists(data_file):
|
298 |
+
return False
|
299 |
+
|
300 |
+
if md5(data_file) != self.MD5:
|
301 |
+
return False
|
302 |
+
|
303 |
+
data = load(data_file)
|
304 |
+
for idx, item in data.iterrows():
|
305 |
+
if not osp.exists(osp.join(pth, item['prefix'], item['video'])):
|
306 |
+
return False
|
307 |
+
return True
|
308 |
+
|
309 |
+
if modelscope_flag_set():
|
310 |
+
repo_id = "AI-ModelScope/MLVU"
|
311 |
+
|
312 |
+
cache_path = get_cache_path(repo_id)
|
313 |
+
if cache_path is not None and check_integrity(cache_path):
|
314 |
+
dataset_path = cache_path
|
315 |
+
else:
|
316 |
+
def generate_tsv(pth):
|
317 |
+
data_file = osp.join(pth, f'{dataset_name}.tsv')
|
318 |
+
if os.path.exists(data_file) and md5(data_file) == self.MD5:
|
319 |
+
return
|
320 |
+
json_data_dir = os.path.join(dataset_path, 'MLVU', 'json')
|
321 |
+
self.data_list = []
|
322 |
+
for k, v in self.type_data_list.items():
|
323 |
+
with open(os.path.join(json_data_dir, v[0]), 'r') as f:
|
324 |
+
json_data = json.load(f)
|
325 |
+
for data in json_data:
|
326 |
+
self.data_list.append({
|
327 |
+
'task_type': k,
|
328 |
+
'prefix': v[1],
|
329 |
+
'duration': data['duration'],
|
330 |
+
'video': data['video'],
|
331 |
+
'question': data['question'],
|
332 |
+
'answer': data['answer'],
|
333 |
+
'scoring_points': data['scoring_points'] if 'scoring_points' in data else ''
|
334 |
+
})
|
335 |
+
|
336 |
+
data_df = pd.DataFrame(self.data_list)
|
337 |
+
data_df = data_df.assign(index=range(len(data_df)))
|
338 |
+
data_df.to_csv(data_file, sep='\t', index=False)
|
339 |
+
|
340 |
+
if modelscope_flag_set():
|
341 |
+
from modelscope import dataset_snapshot_download
|
342 |
+
dataset_path = dataset_snapshot_download(dataset_id=repo_id)
|
343 |
+
else:
|
344 |
+
hf_token = os.environ.get('HUGGINGFACE_TOKEN')
|
345 |
+
huggingface_hub.login(hf_token)
|
346 |
+
dataset_path = snapshot_download(repo_id=repo_id, repo_type='dataset')
|
347 |
+
|
348 |
+
generate_tsv(dataset_path)
|
349 |
+
|
350 |
+
data_file = osp.join(dataset_path, f'{dataset_name}.tsv')
|
351 |
+
return dict(root=dataset_path, data_file=data_file)
|
352 |
+
|
353 |
+
def qa_template(self, data):
|
354 |
+
question = f"{data['question']}"
|
355 |
+
answer = data['answer']
|
356 |
+
return question, answer
|
357 |
+
|
358 |
+
def save_video_frames(self, line, num_frames=8, fps=-1):
|
359 |
+
suffix = line['video'].split('.')[-1]
|
360 |
+
video = line['video'].replace(f'.{suffix}','')
|
361 |
+
vid_path = osp.join(self.data_root, line['prefix'], line['video'])
|
362 |
+
vid = decord.VideoReader(vid_path)
|
363 |
+
video_info = {
|
364 |
+
'fps': vid.get_avg_fps(),
|
365 |
+
'n_frames': len(vid),
|
366 |
+
}
|
367 |
+
if num_frames > 0 and fps < 0:
|
368 |
+
step_size = len(vid) / (num_frames + 1)
|
369 |
+
indices = [int(i * step_size) for i in range(1, num_frames + 1)]
|
370 |
+
frame_paths = self.frame_paths(video, num_frames)
|
371 |
+
elif fps > 0:
|
372 |
+
# not constrained by num_frames, get frames by fps
|
373 |
+
total_duration = video_info['n_frames'] / video_info['fps']
|
374 |
+
required_frames = int(total_duration * fps)
|
375 |
+
step_size = video_info['fps'] / fps
|
376 |
+
indices = [int(i * step_size) for i in range(required_frames)]
|
377 |
+
frame_paths = self.frame_paths_fps(video, len(indices), fps)
|
378 |
+
|
379 |
+
flag = np.all([osp.exists(p) for p in frame_paths])
|
380 |
+
|
381 |
+
if not flag:
|
382 |
+
images = [vid[i].asnumpy() for i in indices]
|
383 |
+
images = [Image.fromarray(arr) for arr in images]
|
384 |
+
for im, pth in zip(images, frame_paths):
|
385 |
+
if not osp.exists(pth):
|
386 |
+
im.save(pth)
|
387 |
+
|
388 |
+
return frame_paths
|
389 |
+
|
390 |
+
def save_video_into_images(self, line, num_frames, fps):
|
391 |
+
frame_paths = self.save_video_frames(line, num_frames, fps)
|
392 |
+
return frame_paths
|
393 |
+
|
394 |
+
def build_prompt(self, line, num_frames, video_llm, fps=-1):
|
395 |
+
if isinstance(line, int):
|
396 |
+
assert line < len(self)
|
397 |
+
line = self.data.iloc[line]
|
398 |
+
|
399 |
+
question, answer = self.qa_template(line)
|
400 |
+
message = [dict(type='text', value=self.SYS, role='system')]
|
401 |
+
message.append(dict(type='text', value=question))
|
402 |
+
video_path = os.path.join(self.data_root, line['prefix'], line['video'])
|
403 |
+
if video_llm:
|
404 |
+
message.append(dict(type='video', value=video_path))
|
405 |
+
else:
|
406 |
+
img_frame_paths = self.save_video_into_images(line, num_frames, fps)
|
407 |
+
for im in img_frame_paths:
|
408 |
+
message.append(dict(type='image', value=im))
|
409 |
+
return message
|
410 |
+
|
411 |
+
@classmethod
|
412 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
413 |
+
|
414 |
+
model = judge_kwargs['model'] if 'model' in judge_kwargs else judge_kwargs.setdefault('model', 'gpt-4-0125')
|
415 |
+
if model != 'gpt-4-0125':
|
416 |
+
print('MLVU Open Ended default using gpt-4-0125! So judge model is changed to gpt-4-0125')
|
417 |
+
judge_kwargs['model'] = 'gpt-4-0125'
|
418 |
+
|
419 |
+
suffix = eval_file.split('.')[-1]
|
420 |
+
score_file = eval_file.replace(f'.{suffix}', f'_{model}_score.xlsx')
|
421 |
+
tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
|
422 |
+
nproc = judge_kwargs.pop('nproc', 4)
|
423 |
+
|
424 |
+
if not osp.exists(score_file):
|
425 |
+
data = load(eval_file)
|
426 |
+
model_dict = {
|
427 |
+
'sub_scene': build_judge(system_prompt=system_prompt_sub_scene, **judge_kwargs),
|
428 |
+
'summary': build_judge(system_prompt=system_prompt_summary, **judge_kwargs)
|
429 |
+
}
|
430 |
+
lt = len(data)
|
431 |
+
lines = [data.iloc[i] for i in range(lt)]
|
432 |
+
tups = [(model_dict[line['task_type']], line) for line in lines]
|
433 |
+
indices = [line['index'] for line in lines]
|
434 |
+
|
435 |
+
ans = {}
|
436 |
+
if osp.exists(tmp_file):
|
437 |
+
ans = load(tmp_file)
|
438 |
+
tups = [x for x, i in zip(tups, indices) if i not in ans]
|
439 |
+
indices = [i for i in indices if i not in ans]
|
440 |
+
|
441 |
+
if len(indices):
|
442 |
+
_ = track_progress_rich(
|
443 |
+
MLVU_OpenEnded_generate,
|
444 |
+
tups,
|
445 |
+
nproc=nproc,
|
446 |
+
chunksize=nproc,
|
447 |
+
keys=indices,
|
448 |
+
save=tmp_file,
|
449 |
+
)
|
450 |
+
ans = load(tmp_file)
|
451 |
+
data = MLVU_OpenEnded_extract(ans, data)
|
452 |
+
dump(data, score_file)
|
453 |
+
|
454 |
+
rating = get_dimension_rating(score_file)
|
455 |
+
return rating
|
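A minimal sketch of driving the MLVU_OpenEnded class above outside the full VLMEvalKit inference pipeline. It assumes the MLVU data has already been downloaded to the local cache (so prepare_dataset does not need to log in or fetch anything) and that decord is installed; num_frames=8 is just an illustrative value, and the loaded table being exposed as `.data` follows the usage inside build_prompt above.

# Sketch only: build one frame-based prompt from the first MLVU_OpenEnded sample.
from vlmeval.dataset.mlvu import MLVU_OpenEnded

dataset = MLVU_OpenEnded()                      # reuses the cached TSV and videos if present
sample = dataset.data.iloc[0]                   # one row of the generated TSV
# video_llm=False asks for sampled frame images instead of the raw video path
message = dataset.build_prompt(sample, num_frames=8, video_llm=False)
for item in message:
    print(item['type'], str(item['value'])[:80])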
VLMEvalKit/vlmeval/dataset/mmbench_video.py
ADDED
@@ -0,0 +1,256 @@
+from huggingface_hub import snapshot_download
+from ..smp import *
+from .video_base import VideoBaseDataset
+from .utils import build_judge, DEBUG_MESSAGE
+from ..utils import track_progress_rich
+
+
+FAIL_MSG = 'Failed to obtain answer via API.'
+
+
+def unwrap_hf_pkl(pth, suffix='.mp4'):
+    base_dir = os.path.join(pth, 'video_pkl/')
+    target_dir = os.path.join(pth, 'video/')
+    pickle_files = [os.path.join(base_dir, file) for file in os.listdir(base_dir)]
+    pickle_files.sort()
+
+    if not os.path.exists(target_dir):
+        os.makedirs(target_dir, exist_ok=True)
+        for pickle_file in pickle_files:
+            with open(pickle_file, 'rb') as file:
+                video_data = pickle.load(file)
+            # For each video file in the pickle file, write its contents to a new mp4 file
+            for video_name, video_content in video_data.items():
+                output_path = os.path.join(target_dir, f'{video_name}{suffix}')
+                with open(output_path, 'wb') as output_file:
+                    output_file.write(video_content)
+        print('The video file has been restored and stored from the pickle file.')
+    else:
+        print('The video file already exists.')
+
+
32 |
+
class MMBenchVideo(VideoBaseDataset):
|
33 |
+
|
34 |
+
MD5 = '98f7df3eb1007fc375ea6fe88a98e2ff'
|
35 |
+
SYS = 'You are an AI assistant responsible for answering questions about videos.'
|
36 |
+
FRAMES_TMPL_PACK = """
|
37 |
+
You will be provided with {} separate frames uniformly sampled from a video, \
|
38 |
+
the frames are provided in chronological order of the video.
|
39 |
+
Please analyze these images and provide the answer / answers to the \
|
40 |
+
following question / questions about the video content.
|
41 |
+
If multiple questions are provided (with indices I1, I2, I3, ...), \
|
42 |
+
you should organize your answers in the following json format:
|
43 |
+
{{
|
44 |
+
'I1': 'Answer to Question I1',
|
45 |
+
'I2': 'Answer to Question I2',
|
46 |
+
...
|
47 |
+
}}
|
48 |
+
Otherwise, please directly reply with your response to the only question.
|
49 |
+
Even if the information in these separate frames is not enough to give an answer,
|
50 |
+
PLEASE GIVE A RESPONSE TO EACH OF THE QUESTIONS IN THE FORMAT DESCRIBED ABOVE.
|
51 |
+
"""
|
52 |
+
|
53 |
+
FRAMES_TMPL_NOPACK = """
|
54 |
+
You will be provided with {} separate frames uniformly sampled from a video, \
|
55 |
+
the frames are provided in chronological order of the video.
|
56 |
+
Please analyze these images and provide the answer to the question about the video content.
|
57 |
+
Please directly reply with your response to the only question.
|
58 |
+
"""
|
59 |
+
|
60 |
+
TYPE = 'Video-VQA'
|
61 |
+
|
62 |
+
def __init__(self, dataset='MMBench-Video', pack=False):
|
63 |
+
super().__init__(dataset=dataset, pack=pack)
|
64 |
+
|
65 |
+
@classmethod
|
66 |
+
def supported_datasets(cls):
|
67 |
+
return ['MMBench-Video']
|
68 |
+
|
69 |
+
def prepare_dataset(self, dataset_name='MMBench-Video', repo_id='opencompass/MMBench-Video'):
|
70 |
+
def check_integrity(pth):
|
71 |
+
data_file = osp.join(pth, f'{dataset_name}.tsv')
|
72 |
+
if md5(data_file) != self.MD5:
|
73 |
+
return False
|
74 |
+
data = load(data_file)
|
75 |
+
for video_pth in data['video_path']:
|
76 |
+
if not osp.exists(osp.join(pth, video_pth)):
|
77 |
+
return False
|
78 |
+
return True
|
79 |
+
|
80 |
+
cache_path = get_cache_path(repo_id)
|
81 |
+
if cache_path is not None and check_integrity(cache_path):
|
82 |
+
dataset_path = cache_path
|
83 |
+
else:
|
84 |
+
if modelscope_flag_set():
|
85 |
+
from modelscope import dataset_snapshot_download
|
86 |
+
dataset_path = dataset_snapshot_download(dataset_id=repo_id)
|
87 |
+
else:
|
88 |
+
dataset_path = snapshot_download(repo_id=repo_id, repo_type='dataset')
|
89 |
+
unwrap_hf_pkl(dataset_path)
|
90 |
+
self.video_path = osp.join(dataset_path, 'video/')
|
91 |
+
data_file = osp.join(dataset_path, f'{dataset_name}.tsv')
|
92 |
+
|
93 |
+
return dict(data_file=data_file, root=osp.join(dataset_path, 'video'))
|
94 |
+
|
95 |
+
def build_prompt_pack(self, line, num_frames, fps=-1):
|
96 |
+
if isinstance(line, int):
|
97 |
+
assert line < len(self)
|
98 |
+
video = self.videos[line]
|
99 |
+
elif isinstance(line, pd.Series):
|
100 |
+
video = line['video']
|
101 |
+
elif isinstance(line, str):
|
102 |
+
video = line
|
103 |
+
|
104 |
+
frames = self.save_video_frames(video, num_frames, fps)
|
105 |
+
sub = self.data[self.data['video'] == video]
|
106 |
+
sys_prompt = self.SYS + self.FRAMES_TMPL_PACK.format(len(frames))
|
107 |
+
message = [dict(type='text', value=sys_prompt)]
|
108 |
+
for im in frames:
|
109 |
+
message.append(dict(type='image', value=im))
|
110 |
+
nq = len(sub)
|
111 |
+
prompt = 'Questions: \n{}\nAnswers: \n'
|
112 |
+
qs = {int(sub.iloc[i]['index']): sub.iloc[i]['question'] for i in range(nq)}
|
113 |
+
prompt = prompt.format(json.dumps(qs))
|
114 |
+
message.append(dict(type='text', value=prompt))
|
115 |
+
return message
|
116 |
+
|
117 |
+
def build_prompt_nopack(self, line, num_frames, video_llm, fps):
|
118 |
+
if isinstance(line, int):
|
119 |
+
assert line < len(self)
|
120 |
+
line = self.data.iloc[line]
|
121 |
+
if video_llm:
|
122 |
+
question = line['question']
|
123 |
+
prefix, video_idx_path = os.path.split(line['video_path'])
|
124 |
+
message = [dict(type='text', value=question)]
|
125 |
+
message.append(dict(type='video', value=os.path.join(self.video_path, video_idx_path)))
|
126 |
+
return message
|
127 |
+
else:
|
128 |
+
frames = self.save_video_frames(line['video'], num_frames, fps)
|
129 |
+
sys_prompt = self.FRAMES_TMPL_NOPACK.format(len(frames))
|
130 |
+
message = [dict(type='text', value=sys_prompt)]
|
131 |
+
for im in frames:
|
132 |
+
message.append(dict(type='image', value=im))
|
133 |
+
prompt = 'Question: {}\nAnswer: '.format(line['question'])
|
134 |
+
message.append(dict(type='text', value=prompt))
|
135 |
+
return message
|
136 |
+
|
137 |
+
def build_prompt(self, line, num_frames, video_llm, fps):
|
138 |
+
if self.pack and not video_llm:
|
139 |
+
return self.build_prompt_pack(line, num_frames, fps)
|
140 |
+
else:
|
141 |
+
return self.build_prompt_nopack(line, num_frames, video_llm, fps)
|
142 |
+
|
143 |
+
@staticmethod
|
144 |
+
def remove_side_quote(s, syms=[',', '"', "'"]):
|
145 |
+
if np.all([x in syms for x in s]):
|
146 |
+
return ''
|
147 |
+
while s[0] in syms:
|
148 |
+
s = s[1:]
|
149 |
+
while s[-1] in syms:
|
150 |
+
s = s[:-1]
|
151 |
+
return s
|
152 |
+
|
153 |
+
@staticmethod
|
154 |
+
def robust_json_load(s):
|
155 |
+
try:
|
156 |
+
jsons = list(extract_json_objects(s))
|
157 |
+
assert len(jsons) == 1
|
158 |
+
return jsons[0]
|
159 |
+
except:
|
160 |
+
if '{' in s and s.find('{') == s.rfind('{'):
|
161 |
+
sub_str = s[s.find('{') + 1:].strip()
|
162 |
+
lines = sub_str.split('\n')
|
163 |
+
res = {}
|
164 |
+
for l in lines:
|
165 |
+
l = l.strip()
|
166 |
+
if ': ' in l:
|
167 |
+
key = l.split(': ')[0].strip()
|
168 |
+
val = l.split(': ')[1].strip()
|
169 |
+
key = MMBenchVideo.remove_side_quote(key)
|
170 |
+
val = MMBenchVideo.remove_side_quote(val)
|
171 |
+
if len(key) and len(val):
|
172 |
+
res[key] = val
|
173 |
+
return res
|
174 |
+
return None
|
175 |
+
|
176 |
+
def load_pack_answers(self, data_raw):
|
177 |
+
vstats = defaultdict(lambda: 0)
|
178 |
+
data = defaultdict(lambda: {})
|
179 |
+
|
180 |
+
for k in data_raw:
|
181 |
+
ans = data_raw[k].strip()
|
182 |
+
if FAIL_MSG in ans:
|
183 |
+
vstats['GEN_FAIL'] += 1
|
184 |
+
continue
|
185 |
+
res = self.robust_json_load(ans)
|
186 |
+
if res is not None:
|
187 |
+
data[k] = res
|
188 |
+
vstats['PARSE_OK'] += 1
|
189 |
+
else:
|
190 |
+
vstats['PARSE_FAIL'] += 1
|
191 |
+
|
192 |
+
# return data
|
193 |
+
meta = cp.deepcopy(self.data)
|
194 |
+
lt = len(meta)
|
195 |
+
prediction = []
|
196 |
+
for i in range(lt):
|
197 |
+
line = meta.iloc[i]
|
198 |
+
vid = line['video']
|
199 |
+
idx = str(line['index'])
|
200 |
+
prediction.append(data[vid][idx] if idx in data[vid] else None)
|
201 |
+
meta['prediction'] = prediction
|
202 |
+
vstats['VALIDQ'] = len([x for x in prediction if x is not None])
|
203 |
+
vstats['INVALIDQ'] = len([x for x in prediction if x is None])
|
204 |
+
return meta, vstats
|
205 |
+
|
206 |
+
# It returns a dictionary
|
207 |
+
@classmethod
|
208 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
209 |
+
from .utils.mmbench_video import get_dimension_rating, system_prompt, build_prompt
|
210 |
+
|
211 |
+
assert eval_file.endswith('.xlsx'), 'data file should be an xlsx file'
|
212 |
+
judge = judge_kwargs['model']
|
213 |
+
nproc = judge_kwargs.pop('nproc', 4)
|
214 |
+
|
215 |
+
tmp_file = eval_file.replace('.xlsx', f'_{judge}_tmp.pkl')
|
216 |
+
tgt_file = eval_file.replace('.xlsx', f'_{judge}_rating.json')
|
217 |
+
score_file = eval_file.replace('.xlsx', f'_{judge}_score.xlsx')
|
218 |
+
|
219 |
+
model = build_judge(system_prompt=system_prompt, **judge_kwargs)
|
220 |
+
assert model.working(), 'MMBench-Video evaluation requires a working OPENAI API\n' + DEBUG_MESSAGE
|
221 |
+
|
222 |
+
if not osp.exists(score_file):
|
223 |
+
res = {} if not osp.exists(tmp_file) else load(tmp_file)
|
224 |
+
res = {k: v for k, v in res.items() if model.fail_msg not in v}
|
225 |
+
|
226 |
+
data = load(eval_file)
|
227 |
+
data_un = data[~data['index'].isin(res)]
|
228 |
+
data_un = data_un[~pd.isna(data_un['prediction'])]
|
229 |
+
lt = len(data_un)
|
230 |
+
prompts = [build_prompt(data_un.iloc[i]) for i in range(lt)]
|
231 |
+
indices = [data_un.iloc[i]['index'] for i in range(lt)]
|
232 |
+
|
233 |
+
if len(prompts):
|
234 |
+
_ = track_progress_rich(
|
235 |
+
model.generate,
|
236 |
+
prompts,
|
237 |
+
keys=indices,
|
238 |
+
save=tmp_file,
|
239 |
+
nproc=nproc,
|
240 |
+
chunksize=nproc
|
241 |
+
)
|
242 |
+
score_map = load(tmp_file)
|
243 |
+
data['score'] = [score_map[idx] if idx in score_map else -1 for idx in data['index']]
|
244 |
+
rejected = [x for x in score_map.values() if FAIL_MSG in x]
|
245 |
+
data['score'] = [int(x) if istype(x, int) else -1 for x in data['score']]
|
246 |
+
print(
|
247 |
+
f'Among {len(data)} questions, failed to obtain prediction for {len(data) - len(score_map)} questions, '
|
248 |
+
f'failed to obtain the score for another {len(rejected)} questions. '
|
249 |
+
f'Those questions will be counted as 0 score in ALL rating, and will not be counted in VALID rating.'
|
250 |
+
)
|
251 |
+
|
252 |
+
dump(data, score_file)
|
253 |
+
|
254 |
+
rating = get_dimension_rating(score_file)
|
255 |
+
dump(rating, tgt_file)
|
256 |
+
return rating
|
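For reference, a small sketch of the pack-mode answer parsing above: robust_json_load first tries to extract a single JSON object from the judge response and, failing that, falls back to line-by-line key/value recovery with remove_side_quote. The response string below is invented for illustration; only the class above is assumed.

# Illustration only: the raw string is invented, not real model output.
from vlmeval.dataset.mmbench_video import MMBenchVideo

raw = "Here are my answers:\n{\n  'I1': 'A red car',\n  'I2': 'Two people'\n"
parsed = MMBenchVideo.robust_json_load(raw)
print(parsed)   # typically yields {'I1': 'A red car', 'I2': 'Two people'}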
VLMEvalKit/vlmeval/dataset/mmgenbench.py
ADDED
@@ -0,0 +1,69 @@
+import warnings
+import pandas as pd
+from abc import abstractmethod
+from ..smp import *
+from .image_base import ImageBaseDataset
+
+
+class MMGenBench(ImageBaseDataset):
+
+    prompt_list = [
+        """
+# Role
+You are an expert in the field of image understanding, focusing on the \
+understanding of images and generating the image caption-prompt.
+
+# Definition Explanation
+image caption-prompt: Refers to the caption or description of an image, \
+used to provide to a Text-to-Image model to generate a new image.
+Text-to-Image model: Can generate a new image based on the provided image \
+caption-prompt, such as stable diffusion 3, flux, and other image generation models.
+
+# Task Description
+Generate an image caption-prompt based on the input image.
+
+# Key Points and Requirements
+1. Accurately understand the input image and precisely generate an image caption-prompt.
+2. The generated image caption-prompt, when provided to the Text-to-Image model, requires the \
+Text-to-Image model to generate a new image that is as consistent as possible with the input image.
+3. The generated image caption-prompt must conform to the preferences of the Text-to-Image model.
+4. The generated image caption-prompt should describe the input image in as much \
+detail as possible, and it should be between 20 to 60 words.
+
+# Output Format
+A string, that is the image caption-prompt. No extra output needed.
+"""
+    ]
+    TYPE = 'GenerateImgPrompt'
+    DATASET_URL = {
+        'MMGenBench-Test': 'https://huggingface.co/datasets/lerogo/MMGenBench/resolve/main/MMGenBench-Test.tsv',
+        'MMGenBench-Domain': 'https://huggingface.co/datasets/lerogo/MMGenBench/resolve/main/MMGenBench-Domain.tsv',
+    }
+    PROMPT_MAP = {
+        'MMGenBench-Test': prompt_list[0],
+        'MMGenBench-Domain': prompt_list[0],
+    }
+    DATASET_MD5 = {
+        'MMGenBench-Test': "94f8dac6bbf7c20be403f99adeaa73da",
+        'MMGenBench-Domain': "5c10daf6e2c5f08bdfb0701aa6db86bb",
+    }
+
+    def __init__(self, dataset='MMGenBench', **kwargs):
+        super().__init__(dataset, **kwargs)
+        warnings.warn('This dataset is for inference only and does not support direct output of evaluation results.\n')
+        warnings.warn('Please refer to "https://github.com/lerogo/MMGenBench" for more evaluation information.\n')
+
+    def load_data(self, dataset):
+        data = super().load_data(dataset)
+        if 'question' not in data:
+            data['question'] = [(
+                self.PROMPT_MAP[dataset]
+            )] * len(data)
+        return data
+
+    # Given the prediction file, return the evaluation results in the format of a dictionary or pandas dataframe
+    @abstractmethod
+    def evaluate(self, eval_file, **judge_kwargs):
+        warnings.warn('This evaluation method is not supported.\n')
+        warnings.warn('Please refer to "https://github.com/lerogo/MMGenBench" for more evaluation information.\n')
+        return None
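A quick sketch of how the class above injects its fixed caption-prompt when the TSV has no question column. It assumes the TSV can be fetched from the DATASET_URL above and that ImageBaseDataset exposes the loaded table as `.data`, as the other image datasets in this repository do; the dataset name and the print calls are only illustrative.

# Sketch: load MMGenBench-Domain and confirm every row carries the shared prompt.
from vlmeval.dataset.mmgenbench import MMGenBench

dataset = MMGenBench('MMGenBench-Domain')       # downloads the TSV on first use
print(len(dataset.data))                        # number of images to caption
print(dataset.data.iloc[0]['question'][:120])   # the shared '# Role ...' caption-prompt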
VLMEvalKit/vlmeval/dataset/mmlongbench.py
ADDED
@@ -0,0 +1,584 @@
1 |
+
import re
|
2 |
+
import math
|
3 |
+
from urllib.request import urlopen
|
4 |
+
from PIL import Image, ImageDraw, ImageFont
|
5 |
+
import torchvision.transforms as transforms
|
6 |
+
|
7 |
+
from vlmeval.dataset.utils import build_judge, levenshtein_distance
|
8 |
+
from vlmeval.smp import *
|
9 |
+
from .image_base import ImageBaseDataset
|
10 |
+
|
11 |
+
FAIL_MSG = 'Failed to obtain answer via API.'
|
12 |
+
|
13 |
+
|
14 |
+
def get_gpt4_ICE():
|
15 |
+
example_1 = """
|
16 |
+
---
|
17 |
+
Question: List the primary questions asked about the services in this report.
|
18 |
+
Analysis: The primary questions asked about the services in the report for The Limes Residential Home are:\n\n
|
19 |
+
1. Is the service safe?\n
|
20 |
+
2. Is the service effective?\n
|
21 |
+
3. Is the service caring?\n
|
22 |
+
4. Is the service responsive?\n
|
23 |
+
5. Is the service well-led?
|
24 |
+
Extracted answer: [
|
25 |
+
'Is the servife safe?',
|
26 |
+
'Is the service effective',
|
27 |
+
'Is the serve caring?',
|
28 |
+
'Is the service responsive?',
|
29 |
+
'Is the service well-led?'
|
30 |
+
]
|
31 |
+
Answer format: List\n
|
32 |
+
"""
|
33 |
+
|
34 |
+
example_2 = """
|
35 |
+
---
|
36 |
+
Question: How many regulations of the HSCA 2008 are breached in all according to this report?
|
37 |
+
Analysis: According to the report, the provider breached 10 Health and Social Care Act 2008 (Regulated Activities)
|
38 |
+
Regulations in total. Here are the specifics:\n\n1. Regulation 13: Safeguarding service users from abuse and
|
39 |
+
improper treatment\n2. Regulation 12: Safe care and treatment\n3. Regulation 18: Staffing\n4. Regulation 11:
|
40 |
+
Need for consent\n5. Regulation 10: Dignity and respect\n6. Regulation 9: Person-centred care\n7. Regulation 17:
|
41 |
+
Good governance\n8. Regulation 18 (CQC Registration Regulations 2009): Notification of other incidents\n9.
|
42 |
+
Regulation 18: Failure to maintain an accurate and up-to-date care plan\n10. Regulation 11: Failure to implement
|
43 |
+
the Mental Capacity Act 2005 code of practice effectively\n\nThese breaches involve issues concerning staffing,
|
44 |
+
safeguarding, medicines management, dignity and respect, consent, care planning, governance, and failure to
|
45 |
+
notify the CQC of incidents.
|
46 |
+
Extracted answer: 10
|
47 |
+
Answer format: Integer\n
|
48 |
+
"""
|
49 |
+
|
50 |
+
example_3 = """
|
51 |
+
---
|
52 |
+
Question: According to the survey that is the percentage of Chinese who are paying more or
|
53 |
+
about the same attention to politics after Trump's election?
|
54 |
+
Analysis: The survey provided does not specify the percentage of Chinese individuals specifically who are paying
|
55 |
+
more or about the same attention to politics after Trump's election. The report focuses primarily on American
|
56 |
+
demographics and does not include specific details about the Chinese population in relation to this question. If
|
57 |
+
you need information about a different demographic or a summary of the findings from the American demographic,
|
58 |
+
I can certainly help with that!
|
59 |
+
Extracted answer: Not answerable
|
60 |
+
Answer format: String\n
|
61 |
+
"""
|
62 |
+
|
63 |
+
example_4 = """
|
64 |
+
---
|
65 |
+
Question: How many quotations from male respondent over 50 years old are included in this report?
|
66 |
+
Analysis: The image you've provided appears to be a screenshot of a document with multiple charts. However, the
|
67 |
+
text is too small and blurry to read accurately. If you can provide a clearer image or more context, I might be
|
68 |
+
able to help you with your question.
|
69 |
+
Extracted answer: Fail to answer
|
70 |
+
Answer format: String\n
|
71 |
+
"""
|
72 |
+
|
73 |
+
return [example_1, example_2, example_3, example_4]
|
74 |
+
|
75 |
+
|
76 |
+
def build_mmlongbench_gpt4_prompt(line):
|
77 |
+
task_description = """
|
78 |
+
Given the question and analysis, you are tasked to extract answers with required formats from the free-form analysis.
|
79 |
+
- Your extracted answers should be one of the following formats: (1) Integer, (2) Float, (3) String and (4) List.
|
80 |
+
If you find the analysis the question can not be answered from the given documents, type "Not answerable".
|
81 |
+
Exception: If the analysis only tells you that it can not read/understand the images or documents,
|
82 |
+
type "Fail to answer".
|
83 |
+
- Please make your response as concise as possible. Also note that your response should be formatted as below:
|
84 |
+
```
|
85 |
+
Extracted answer: [answer]
|
86 |
+
Answer format: [answer format]
|
87 |
+
```
|
88 |
+
Please read the following example, then extract the answer from the model response
|
89 |
+
and type it at the end of the prompt.\n
|
90 |
+
"""
|
91 |
+
question = line['question']
|
92 |
+
prediction = str(line['prediction'])
|
93 |
+
prompt = task_description
|
94 |
+
examples = get_gpt4_ICE()
|
95 |
+
for example in examples:
|
96 |
+
prompt += example
|
97 |
+
prompt += '---\nQuestion:' + question + '\n'
|
98 |
+
prompt += 'Analysis: ' + prediction
|
99 |
+
return prompt
|
100 |
+
|
101 |
+
|
102 |
+
def anls_compute(groundtruth, prediction, threshold=0.5):
|
103 |
+
dist = levenshtein_distance(groundtruth, prediction)
|
104 |
+
length = max(len(groundtruth.upper()), len(prediction.upper()))
|
105 |
+
value = 0.0 if length == 0 else float(dist) / float(length)
|
106 |
+
anls = 1.0 - value
|
107 |
+
if anls <= threshold:
|
108 |
+
anls = 0.0
|
109 |
+
return anls
|
110 |
+
|
111 |
+
|
112 |
+
def is_float_equal(reference, prediction, include_percentage: bool = False, is_close: float = False) -> bool:
|
113 |
+
def get_precision(gt_ans: float) -> int:
|
114 |
+
precision = 3
|
115 |
+
if '.' in str(gt_ans):
|
116 |
+
precision = len(str(gt_ans).split('.')[-1])
|
117 |
+
return precision
|
118 |
+
|
119 |
+
reference = float(str(reference).strip().rstrip('%').strip())
|
120 |
+
try:
|
121 |
+
prediction = float(str(prediction).strip().rstrip('%').strip())
|
122 |
+
except:
|
123 |
+
return False
|
124 |
+
|
125 |
+
if include_percentage:
|
126 |
+
gt_result = [reference / 100, reference, reference * 100]
|
127 |
+
else:
|
128 |
+
gt_result = [reference]
|
129 |
+
for item in gt_result:
|
130 |
+
try:
|
131 |
+
if is_close:
|
132 |
+
if math.isclose(item, prediction, rel_tol=0.01):
|
133 |
+
return True
|
134 |
+
precision = max(min(get_precision(prediction), get_precision(item)), 2)
|
135 |
+
if round(prediction, precision) == round(item, precision):
|
136 |
+
return True
|
137 |
+
except Exception:
|
138 |
+
continue
|
139 |
+
return False
|
140 |
+
|
141 |
+
|
142 |
+
def get_clean_string(s):
|
143 |
+
s = str(s).lower().strip()
|
144 |
+
if s.endswith('mile'):
|
145 |
+
s.rstrip('mile').strip()
|
146 |
+
if s.endswith('miles'):
|
147 |
+
s.rstrip('miles').strip()
|
148 |
+
if s.endswith('million'):
|
149 |
+
s.rstrip('million').strip()
|
150 |
+
# remove parenthesis
|
151 |
+
s = re.sub(r'\s*\([^)]*\)', '', s).strip()
|
152 |
+
# remove quotes
|
153 |
+
s = re.sub(r"^['\"]|['\"]$", '', s).strip()
|
154 |
+
s = s.strip().lstrip('$').strip()
|
155 |
+
s = s.strip().rstrip('%').strip()
|
156 |
+
return s
|
157 |
+
|
158 |
+
|
159 |
+
def is_exact_match(s):
|
160 |
+
flag = False
|
161 |
+
# Website
|
162 |
+
if 'https://' in s:
|
163 |
+
flag = True
|
164 |
+
# code file
|
165 |
+
if s.endswith('.py') or s.endswith('ipynb'):
|
166 |
+
flag = True
|
167 |
+
if s.startswith('page'):
|
168 |
+
flag = True
|
169 |
+
# telephone number
|
170 |
+
if re.fullmatch(r'\b\d+(-\d+|\s\d+)?\b', s):
|
171 |
+
flag = True
|
172 |
+
# time
|
173 |
+
if 'a.m.' in s or 'p.m.' in s:
|
174 |
+
flag = True
|
175 |
+
# YYYY-MM-DD
|
176 |
+
if re.fullmatch(r'\b\d{4}[-\s]\d{2}[-\s]\d{2}\b', s):
|
177 |
+
flag = True
|
178 |
+
# YYYY-MM
|
179 |
+
if re.fullmatch(r'\b\d{4}[-\s]\d{2}\b', s):
|
180 |
+
flag = True
|
181 |
+
# Email address
|
182 |
+
if re.fullmatch(r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}', s):
|
183 |
+
flag = True
|
184 |
+
return flag
|
185 |
+
|
186 |
+
|
187 |
+
def isfloat(num):
|
188 |
+
try:
|
189 |
+
float(num)
|
190 |
+
return True
|
191 |
+
except ValueError:
|
192 |
+
return False
|
193 |
+
|
194 |
+
|
195 |
+
def get_font():
|
196 |
+
try:
|
197 |
+
truetype_url = "http://opencompass.openxlab.space/utils/Fonts/SimHei.ttf"
|
198 |
+
ff = urlopen(truetype_url)
|
199 |
+
font = ImageFont.truetype(ff, size=40)
|
200 |
+
except Exception as e:
|
201 |
+
logging.warning(f'{type(e)}: {e}')
|
202 |
+
logging.warning("Fail to download the font. Use the default one.")
|
203 |
+
font = ImageFont.load_default(size=40)
|
204 |
+
return font
|
205 |
+
|
206 |
+
|
207 |
+
def frame2img(img_path_list, font, save_path=None, idx_start=0):
|
208 |
+
imgs = [Image.open(img_path) for img_path in img_path_list]
|
209 |
+
|
210 |
+
new_imgs = []
|
211 |
+
for img in imgs:
|
212 |
+
w, h = img.size
|
213 |
+
scale = w / h
|
214 |
+
if w > h:
|
215 |
+
new_w = 560 * 2
|
216 |
+
new_h = int(560 * 2 / scale)
|
217 |
+
else:
|
218 |
+
new_w = int(560 * 2 * scale)
|
219 |
+
new_h = 560 * 2
|
220 |
+
img = transforms.functional.resize(img, [new_h, new_w],)
|
221 |
+
new_imgs.append(img)
|
222 |
+
imgs = new_imgs
|
223 |
+
new_w = 0
|
224 |
+
new_h = 0
|
225 |
+
pad = 40
|
226 |
+
if w > h:
|
227 |
+
for im in imgs:
|
228 |
+
w, h = im.size
|
229 |
+
new_w = max(new_w, w)
|
230 |
+
new_h += h + 10 + pad
|
231 |
+
new_img = Image.new("RGB", (new_w, new_h), "white")
|
232 |
+
draw = ImageDraw.Draw(new_img)
|
233 |
+
curr_h = 0
|
234 |
+
for idx, im in enumerate(imgs):
|
235 |
+
w, h = im.size
|
236 |
+
new_img.paste(im, (0, pad + curr_h))
|
237 |
+
draw.text((0, curr_h), f"<IMAGE {idx+idx_start}>", font=font, fill="black")
|
238 |
+
if idx + 1 < len(imgs):
|
239 |
+
draw.line([(0, pad + curr_h + h + 5), (new_w, pad + curr_h + h + 5)], fill='black', width=2)
|
240 |
+
curr_h += h + 10 + pad
|
241 |
+
else:
|
242 |
+
for im in imgs:
|
243 |
+
w, h = im.size
|
244 |
+
new_w += w + 10
|
245 |
+
new_h = max(new_h, h)
|
246 |
+
new_h += pad
|
247 |
+
new_img = Image.new('RGB', (new_w, new_h), 'white')
|
248 |
+
draw = ImageDraw.Draw(new_img)
|
249 |
+
curr_w = 0
|
250 |
+
for idx, im in enumerate(imgs):
|
251 |
+
w, h = im.size
|
252 |
+
new_img.paste(im, (curr_w, pad))
|
253 |
+
draw.text((curr_w, 0), f"<IMAGE {idx+idx_start}>", font=font, fill='black')
|
254 |
+
if idx + 1 < len(imgs):
|
255 |
+
draw.line([(curr_w + w + 5, 0), (curr_w + w + 5, new_h)], fill='black', width=2)
|
256 |
+
curr_w += w + 10
|
257 |
+
|
258 |
+
if save_path is not None:
|
259 |
+
new_img.save(save_path)
|
260 |
+
|
261 |
+
return new_img
|
262 |
+
|
263 |
+
|
264 |
+
def concat_images(image_list, max_concat=1, column_num=1):
|
265 |
+
concatenated_images = []
|
266 |
+
if column_num == -1:
|
267 |
+
MAX_COLUMN_NUM = 20
|
268 |
+
max_concat = 1
|
269 |
+
while len(image_list) / max_concat > MAX_COLUMN_NUM:
|
270 |
+
max_concat += 1
|
271 |
+
interval = max(math.ceil(len(image_list) / max_concat), 1)
|
272 |
+
for i in range(0, len(image_list), interval):
|
273 |
+
batch_images = image_list[i:i + interval]
|
274 |
+
concatenated_image = frame2img(batch_images, font=get_font(), idx_start=i)
|
275 |
+
concatenated_images.append(concatenated_image)
|
276 |
+
else:
|
277 |
+
interval = max(math.ceil(len(image_list) / max_concat), 1)
|
278 |
+
for i in range(0, len(image_list), interval):
|
279 |
+
batch_images = [Image.open(filename) for filename in image_list[i:i + interval]]
|
280 |
+
if column_num == 1:
|
281 |
+
total_height = batch_images[0].height * len(batch_images)
|
282 |
+
else:
|
283 |
+
total_height = batch_images[0].height * ((len(batch_images) - 1) // column_num + 1)
|
284 |
+
concatenated_image = Image.new('RGB', (batch_images[0].width * column_num, total_height), 'white')
|
285 |
+
|
286 |
+
x_offset, y_offset = 0, 0
|
287 |
+
for count, image in enumerate(batch_images):
|
288 |
+
concatenated_image.paste(image, (x_offset, y_offset))
|
289 |
+
x_offset += image.width
|
290 |
+
if (count + 1) % column_num == 0:
|
291 |
+
y_offset += image.height
|
292 |
+
x_offset = 0
|
293 |
+
concatenated_images.append(concatenated_image)
|
294 |
+
return concatenated_images
|
295 |
+
|
296 |
+
|
297 |
+
def eval_score(gt, pred, answer_type):
|
298 |
+
if answer_type == 'Int':
|
299 |
+
try:
|
300 |
+
gt, pred = int(gt), int(float(pred))
|
301 |
+
except:
|
302 |
+
pred = ''
|
303 |
+
score = (gt == pred)
|
304 |
+
elif answer_type == 'Float':
|
305 |
+
try:
|
306 |
+
gt = float(get_clean_string(str(gt)))
|
307 |
+
pred = float(get_clean_string(str(pred)))
|
308 |
+
except:
|
309 |
+
pred = ''
|
310 |
+
score = is_float_equal(gt, pred, include_percentage=True, is_close=True)
|
311 |
+
elif answer_type == 'Str':
|
312 |
+
gt = get_clean_string(gt)
|
313 |
+
pred = get_clean_string(pred)
|
314 |
+
if is_exact_match(gt):
|
315 |
+
score = (gt == pred)
|
316 |
+
else:
|
317 |
+
score = anls_compute(gt, pred)
|
318 |
+
else:
|
319 |
+
if isinstance(gt, str) and gt.startswith('['):
|
320 |
+
gt = eval(gt)
|
321 |
+
if not isinstance(gt, list):
|
322 |
+
gt = [gt]
|
323 |
+
if isinstance(pred, str) and pred.startswith('['):
|
324 |
+
pred = eval(pred)
|
325 |
+
if not isinstance(pred, list):
|
326 |
+
pred = [pred]
|
327 |
+
print(len(gt), len(pred))
|
328 |
+
if len(gt) != len(pred):
|
329 |
+
score = 0.0
|
330 |
+
else:
|
331 |
+
gt = sorted([get_clean_string(a) for a in gt])
|
332 |
+
pred = sorted([get_clean_string(a) for a in pred])
|
333 |
+
print(gt, pred)
|
334 |
+
if isfloat(gt[0]) or is_exact_match(gt[0]):
|
335 |
+
score = ('-'.join(gt) == '-'.join(pred))
|
336 |
+
else:
|
337 |
+
score = min([anls_compute(gt_v, pred_v) for gt_v, pred_v in zip(gt, pred)])
|
338 |
+
|
339 |
+
return float(score)
|
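To make the scoring rules above concrete, a few hand-made calls to eval_score; the answers and predictions are invented and only illustrate the Int, Float, and Str branches, they are not drawn from MMLongBench-Doc. The import assumes this module is installed as part of vlmeval.

# Invented examples; eval_score is the function defined above.
from vlmeval.dataset.mmlongbench import eval_score

print(eval_score('10', '10.0', 'Int'))                        # 1.0: integer match after float cast
print(eval_score('12.5', '12.48', 'Float'))                   # 1.0: within the 1% is_close tolerance
print(eval_score('The Limes Home', 'the limes home', 'Str'))  # 1.0: ANLS on the cleaned strings
print(eval_score('page 12', 'page 13', 'Str'))                # 0.0: exact-match category, different values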
340 |
+
|
341 |
+
|
342 |
+
def MMLongBench_auxeval(model, line):
|
343 |
+
prompt = build_mmlongbench_gpt4_prompt(line)
|
344 |
+
log = ''
|
345 |
+
retry = 5
|
346 |
+
|
347 |
+
for i in range(retry):
|
348 |
+
prediction = line['prediction']
|
349 |
+
res = model.generate(prompt, temperature=i * 0.5)
|
350 |
+
|
351 |
+
if FAIL_MSG in res:
|
352 |
+
log += f'Try {i}: output is {prediction}, failed to parse.\n'
|
353 |
+
else:
|
354 |
+
log += 'Succeed'
|
355 |
+
try:
|
356 |
+
pred = res.split('Answer format:')[0].split('Extracted answer:')[1].strip()
|
357 |
+
except:
|
358 |
+
pred = ''
|
359 |
+
return dict(log=log, res=res, pred=pred)
|
360 |
+
log += 'All 5 retries failed.\n'
|
361 |
+
return dict(log=log, res='', pred='')
|
362 |
+
|
363 |
+
|
364 |
+
def get_f1(data):
|
365 |
+
gt_pos_data = data[data.apply(lambda k: k['answer'] != 'Not answerable', axis=1)]
|
366 |
+
pred_pos_data = data[data.apply(lambda k: k['pred'] != 'Not answerable', axis=1)]
|
367 |
+
recall = sum(gt_pos_data['score'].tolist()) / len(gt_pos_data)
|
368 |
+
precision = sum(pred_pos_data['score'].tolist()) / len(pred_pos_data)
|
369 |
+
return 2 * recall * precision / (recall + precision)
|
370 |
+
|
371 |
+
|
372 |
+
def MMLongBench_acc(result_file):
|
373 |
+
data = load(result_file)
|
374 |
+
overall_score = 0.0
|
375 |
+
score_list = list()
|
376 |
+
for i in range(len(data)):
|
377 |
+
item = data.iloc[i]
|
378 |
+
try:
|
379 |
+
score = eval_score(item['answer'], item['pred'], item['answer_format'])
|
380 |
+
except:
|
381 |
+
score = 0.0
|
382 |
+
score_list.append(score)
|
383 |
+
overall_score += score
|
384 |
+
|
385 |
+
data['score'] = score_list
|
386 |
+
dump(data, result_file)
|
387 |
+
|
388 |
+
data_chart = data[data.apply(lambda k: 'Chart' in eval(k['evidence_sources']), axis=1)]
|
389 |
+
data_table = data[data.apply(lambda k: 'Table' in eval(k['evidence_sources']), axis=1)]
|
390 |
+
data_image = data[data.apply(lambda k: 'Figure' in eval(k['evidence_sources']), axis=1)]
|
391 |
+
data_text = data[data.apply(lambda k: 'Pure-text (Plain-text)' in eval(k['evidence_sources']), axis=1)]
|
392 |
+
data_layout = data[data.apply(lambda k: 'Generalized-text (Layout)' in eval(k['evidence_sources']), axis=1)]
|
393 |
+
|
394 |
+
data_single = data[data.apply(lambda k: len(eval(k['evidence_pages'])) == 1, axis=1)]
|
395 |
+
data_multi = data[data.apply(lambda k: len(eval(k['evidence_pages'])) > 1, axis=1)]
|
396 |
+
data_unans = data[data.apply(lambda k: len(eval(k['evidence_pages'])) == 0, axis=1)]
|
397 |
+
|
398 |
+
res = dict()
|
399 |
+
res['category'] = [
|
400 |
+
'overall_f1', 'overall_acc', 'text', 'layout', 'table', 'chart',
|
401 |
+
'image', 'single-page', 'multi-page', 'unanswerable'
|
402 |
+
]
|
403 |
+
res['num'] = [
|
404 |
+
len(data), len(data), len(data_text), len(data_layout), len(data_table),
|
405 |
+
len(data_chart), len(data_image), len(data_single), len(data_multi), len(data_unans)
|
406 |
+
]
|
407 |
+
res['avg_score'] = [
|
408 |
+
get_f1(data),
|
409 |
+
overall_score / len(data),
|
410 |
+
sum(data_text['score'].tolist()) / len(data_text) if len(data_text) > 0 else 0.0,
|
411 |
+
sum(data_layout['score'].tolist()) / len(data_layout) if len(data_layout) > 0 else 0.0,
|
412 |
+
sum(data_table['score'].tolist()) / len(data_table) if len(data_table) > 0 else 0.0,
|
413 |
+
sum(data_chart['score'].tolist()) / len(data_chart) if len(data_chart) > 0 else 0.0,
|
414 |
+
sum(data_image['score'].tolist()) / len(data_image) if len(data_image) > 0 else 0.0,
|
415 |
+
sum(data_single['score'].tolist()) / len(data_single) if len(data_single) > 0 else 0.0,
|
416 |
+
sum(data_multi['score'].tolist()) / len(data_multi) if len(data_multi) > 0 else 0.0,
|
417 |
+
sum(data_unans['score'].tolist()) / len(data_unans) if len(data_unans) > 0 else 0.0,
|
418 |
+
]
|
419 |
+
res = pd.DataFrame(res)
|
420 |
+
return res
|
421 |
+
|
422 |
+
|
423 |
+
class MMLongBench(ImageBaseDataset):
|
424 |
+
|
425 |
+
TYPE = 'VQA'
|
426 |
+
|
427 |
+
DATASET_URL = {
|
428 |
+
'MMLongBench_DOC': 'https://opencompass.openxlab.space/utils/VLMEval/MMLongBench_DOC.tsv',
|
429 |
+
}
|
430 |
+
DATASET_MD5 = {
|
431 |
+
'MMLongBench_DOC': '9b393e1f4c52718380d50586197eac9b',
|
432 |
+
}
|
433 |
+
|
434 |
+
SUPPORTED_MODELS = {
|
435 |
+
'GPT4': (1, 1),
|
436 |
+
'GPT4V': (1, 1),
|
437 |
+
'GPT4V_HIGH': (1, 1),
|
438 |
+
'GPT4o': (1, 1),
|
439 |
+
'GPT4o_HIGH': (1, 1),
|
440 |
+
'GPT4o_MINI': (1, 1),
|
441 |
+
'MiniCPM-Llama3-V-2_5': (1, 5),
|
442 |
+
'InternVL-Chat-V1-5': (5, 2),
|
443 |
+
'XComposer2_4KHD': (1, 5),
|
444 |
+
'XComposer2d5': (1, -1),
|
445 |
+
}
|
446 |
+
|
447 |
+
def __init__(self, dataset, **kwargs):
|
448 |
+
self.model_list = list(self.SUPPORTED_MODELS.keys())
|
449 |
+
model_name = kwargs['model']
|
450 |
+
if not listinstr(self.model_list, model_name):
|
451 |
+
raise AssertionError("{} doesn't support the evaluation on MMLongBench_DOC.".format(model_name))
|
452 |
+
super(MMLongBench, self).__init__(dataset)
|
453 |
+
|
454 |
+
self.is_api = True if listinstr(['GPT4'], model_name) else False
|
455 |
+
self.max_pages = 120
|
456 |
+
concat_num, column_num = self.SUPPORTED_MODELS.get(model_name)
|
457 |
+
self.concat_num = concat_num
|
458 |
+
self.column_num = column_num
|
459 |
+
|
460 |
+
def dump_image(self, origin_line):
|
461 |
+
os.makedirs(self.img_root, exist_ok=True)
|
462 |
+
try:
|
463 |
+
import fitz
|
464 |
+
except Exception as e:
|
465 |
+
logging.critical(f'{type(e)}: {e}')
|
466 |
+
logging.critical('Please use `pip install pymupdf` to parse PDF files.')
|
467 |
+
|
468 |
+
line = origin_line.copy()
|
469 |
+
line['image_path'] = line['image_path'][:self.max_pages]
|
470 |
+
skip_pdf_parse = True
|
471 |
+
for im_name in line['image_path']:
|
472 |
+
path = osp.join(self.img_root, im_name)
|
473 |
+
if not read_ok(path):
|
474 |
+
skip_pdf_parse = False
|
475 |
+
break
|
476 |
+
|
477 |
+
# Just for being compatible with the zooped loop: zip(line['image'], line['image_path'])
|
478 |
+
if skip_pdf_parse:
|
479 |
+
line['image'] = line['image_path']
|
480 |
+
else:
|
481 |
+
pdf_data = base64.b64decode(line['image'])
|
482 |
+
pdf_file = io.BytesIO(pdf_data)
|
483 |
+
encoded_images = []
|
484 |
+
with fitz.open(stream=pdf_file, filetype='pdf') as doc:
|
485 |
+
doc = doc[:self.max_pages]
|
486 |
+
for page in doc:
|
487 |
+
image = page.get_pixmap(dpi=144)
|
488 |
+
image_file = io.BytesIO(image.tobytes(output='png'))
|
489 |
+
image = Image.open(image_file)
|
490 |
+
encoded_image = encode_image_to_base64(image)
|
491 |
+
encoded_images.append(encoded_image)
|
492 |
+
line['image'] = encoded_images
|
493 |
+
print('process {}'.format(line['doc_id']))
|
494 |
+
|
495 |
+
if 'image' in line:
|
496 |
+
if isinstance(line['image'], list):
|
497 |
+
tgt_path = []
|
498 |
+
assert 'image_path' in line
|
499 |
+
for img, im_name in zip(line['image'], line['image_path']):
|
500 |
+
path = osp.join(self.img_root, im_name)
|
501 |
+
if not read_ok(path):
|
502 |
+
decode_base64_to_image_file(img, path)
|
503 |
+
tgt_path.append(path)
|
504 |
+
else:
|
505 |
+
tgt_path = osp.join(self.img_root, f"{line['index']}.jpg")
|
506 |
+
if not read_ok(tgt_path):
|
507 |
+
decode_base64_to_image_file(line['image'], tgt_path)
|
508 |
+
tgt_path = [tgt_path]
|
509 |
+
else:
|
510 |
+
assert 'image_path' in line
|
511 |
+
tgt_path = toliststr(line['image_path'])
|
512 |
+
|
513 |
+
if self.concat_num > 0 and not self.is_api:
|
514 |
+
concatenated_images = concat_images(tgt_path, max_concat=self.concat_num, column_num=self.column_num)
|
515 |
+
|
516 |
+
old_tgt_path = tgt_path
|
517 |
+
assert isinstance(old_tgt_path, list)
|
518 |
+
if self.column_num != -1:
|
519 |
+
tgt_path = [
|
520 |
+
'_'.join(old_tgt_path[0].split('_')[:-1]) + '_concat{}_{}.jpg'.format(self.concat_num, i)
|
521 |
+
for i in range(len(concatenated_images))
|
522 |
+
]
|
523 |
+
else:
|
524 |
+
tgt_path = [
|
525 |
+
'_'.join(old_tgt_path[0].split('_')[:-1]) + '_concat_all_{}.jpg'.format(i)
|
526 |
+
for i in range(len(concatenated_images))
|
527 |
+
]
|
528 |
+
|
529 |
+
for path, concatenated_image in zip(tgt_path, concatenated_images):
|
530 |
+
if not read_ok(path):
|
531 |
+
decode_base64_to_image_file(encode_image_to_base64(concatenated_image), path)
|
532 |
+
num_images, image_size = len(old_tgt_path), concatenated_image.size
|
533 |
+
print('concat {} images to a new one with size {}. save at {}'.format(num_images, image_size, path))
|
534 |
+
return tgt_path
|
535 |
+
|
536 |
+
@classmethod
|
537 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
538 |
+
logger = get_logger('Evaluation')
|
539 |
+
model = judge_kwargs['model']
|
540 |
+
|
541 |
+
suffix = eval_file.split('.')[-1]
|
542 |
+
storage = eval_file.replace(f'.{suffix}', f'_{model}.xlsx')
|
543 |
+
tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
|
544 |
+
|
545 |
+
if osp.exists(storage):
|
546 |
+
logger.warning(f'GPT scoring file {storage} already exists, will reuse it in MMLongBench_eval. ')
|
547 |
+
else:
|
548 |
+
data = load(eval_file)
|
549 |
+
model = build_judge(max_tokens=128, **judge_kwargs)
|
550 |
+
lt = len(data)
|
551 |
+
lines = [data.iloc[i] for i in range(lt)]
|
552 |
+
tups = [(model, line) for line in lines]
|
553 |
+
indices = [line['index'] for line in lines]
|
554 |
+
|
555 |
+
ans = {}
|
556 |
+
if osp.exists(tmp_file):
|
557 |
+
ans = load(tmp_file)
|
558 |
+
tups = [x for x, i in zip(tups, indices) if i not in ans]
|
559 |
+
indices = [i for i in indices if i not in ans]
|
560 |
+
|
561 |
+
if len(indices):
|
562 |
+
new_results = list()
|
563 |
+
for model, line in tqdm(tups):
|
564 |
+
res = MMLongBench_auxeval(model, line)
|
565 |
+
new_results.append(res)
|
566 |
+
|
567 |
+
log_map, res_map, pred_map = {}, {}, {}
|
568 |
+
all_inds = [line['index'] for line in lines]
|
569 |
+
for k, v in zip(all_inds, new_results):
|
570 |
+
log_map[k] = v['log']
|
571 |
+
res_map[k] = v['res']
|
572 |
+
pred_map[k] = v['pred']
|
573 |
+
data['res'] = [res_map[idx] for idx in data['index']]
|
574 |
+
data['log'] = [log_map[idx] for idx in data['index']]
|
575 |
+
data['pred'] = [pred_map[idx] for idx in data['index']]
|
576 |
+
dump(data, storage)
|
577 |
+
|
578 |
+
score = MMLongBench_acc(storage)
|
579 |
+
score_pth = storage.replace('.xlsx', '_score.csv')
|
580 |
+
|
581 |
+
dump(score, score_pth)
|
582 |
+
logger.info(f'MMLongBench_eval successfully finished evaluating {eval_file}, results saved in {score_pth}')
|
583 |
+
logger.info('Score: ')
|
584 |
+
logger.info(score)
|
VLMEvalKit/vlmeval/dataset/mmmath.py
ADDED
@@ -0,0 +1,446 @@
1 |
+
import re
|
2 |
+
import json
|
3 |
+
import sympy as sp
|
4 |
+
import numpy as np
|
5 |
+
from sympy import simplify, Eq, sympify, Pow, pi
|
6 |
+
from sympy.parsing.latex import parse_latex
|
7 |
+
import sys
|
8 |
+
import math
|
9 |
+
import os
|
10 |
+
import argparse
|
11 |
+
|
12 |
+
from .image_base import ImageBaseDataset
|
13 |
+
from ..utils import track_progress_rich
|
14 |
+
from ..smp import load, dump
|
15 |
+
|
16 |
+
|
17 |
+
class AutoScoringJudge:
|
18 |
+
def __init__(self):
|
19 |
+
# Map of special symbols to their replacements
|
20 |
+
self.special_signal_map = {
|
21 |
+
"\\left": "",
|
22 |
+
"\\right": "",
|
23 |
+
"厘米":"",
|
24 |
+
# "∶": ":",
|
25 |
+
",": ",",
|
26 |
+
"$": "",
|
27 |
+
"(":"(",
|
28 |
+
")":")",
|
29 |
+
"\\infty":"oo",
|
30 |
+
"\\colon ":":",
|
31 |
+
# "\\approx": "=",
|
32 |
+
# "\\simeq": "=",
|
33 |
+
# "\\sim": "=",
|
34 |
+
# "^\\prime": "'",
|
35 |
+
# "^{\\prime}": "'",
|
36 |
+
"+":"+",
|
37 |
+
"\\, ": "",
|
38 |
+
"\\,":"",
|
39 |
+
"^\\circ": "",
|
40 |
+
"^{\\circ}": "",
|
41 |
+
# "%": "",
|
42 |
+
}
|
43 |
+
self.pi = parse_latex("\\pi")
|
44 |
+
# MM-Math default precision
|
45 |
+
self.precision = 1e-2
|
46 |
+
|
47 |
+
def trans_greater_sign_to_interval(self, expr:str):
|
48 |
+
expr_tmp = expr.split("<")
|
49 |
+
return "(" + expr_tmp[0] + ", " + expr_tmp[-1] + ")"
|
50 |
+
|
51 |
+
def split_by_comma(self, expr: str):
|
52 |
+
# Splits expressions by commas outside of brackets
|
53 |
+
in_bracket_num = 0
|
54 |
+
splitted_expr = []
|
55 |
+
start_idx = 0
|
56 |
+
for i, char in enumerate(expr):
|
57 |
+
if char in ["(", "["]:
|
58 |
+
in_bracket_num += 1
|
59 |
+
elif char in [")", "]"]:
|
60 |
+
in_bracket_num -= 1
|
61 |
+
elif char == "," and in_bracket_num == 0:
|
62 |
+
splitted_expr.append(expr[start_idx:i].strip())
|
63 |
+
start_idx = i + 1
|
64 |
+
|
65 |
+
if start_idx < len(expr):
|
66 |
+
splitted_expr.append(expr[start_idx:].strip())
|
67 |
+
|
68 |
+
return splitted_expr
|
69 |
+
|
70 |
+
def trans_plus_minus_sign(self, expr_list: list):
|
71 |
+
# Translates plus-minus signs into separate expressions
|
72 |
+
new_expr_list = []
|
73 |
+
for expr in expr_list:
|
74 |
+
if "\\pm" in expr:
|
75 |
+
new_expr_list.append(expr.replace("\\pm", "+"))
|
76 |
+
new_expr_list.append(expr.replace("\\pm", "-"))
|
77 |
+
else:
|
78 |
+
new_expr_list.append(expr)
|
79 |
+
|
80 |
+
return new_expr_list
|
81 |
+
|
82 |
+
def judge(self, expression1, expression2, precision=1e-2):
|
83 |
+
# Judge if two expressions are equal (expression1 is considered as the Ground Truth)
|
84 |
+
# Default precision is a list for supporting multiple expressions
|
85 |
+
precision = precision if isinstance(precision, list) else [precision]
|
86 |
+
|
87 |
+
try:
|
88 |
+
expression1, expression2 = self.preprocess(expression1, expression2)
|
89 |
+
except:
|
90 |
+
return False
|
91 |
+
if expression1 == expression2:
|
92 |
+
# print("Exactly equal")
|
93 |
+
return True
|
94 |
+
|
95 |
+
# Remove Chinese characters from the string, as answers like "yes" or "no" in Chinese have been considered
|
96 |
+
expression1 = expression1 if re.fullmatch(r"[\u4e00-\u9fff]+", expression1) else re.sub(r'[\u4e00-\u9fff]+', '', expression1) # noqa: E501
|
97 |
+
expression2 = expression2 if re.fullmatch(r'[\u4e00-\u9fff]+', expression2) else re.sub(r'[\u4e00-\u9fff]+', '', expression2) # noqa: E501
|
98 |
+
# Check if two < or > in expression
|
99 |
+
if self.is_two_greater_sign(expression1):
|
100 |
+
expression1 = self.trans_greater_sign_to_interval(expression1)
|
101 |
+
|
102 |
+
if self.is_two_greater_sign(expression2):
|
103 |
+
expression2 = self.trans_greater_sign_to_interval(expression2)
|
104 |
+
|
105 |
+
expression1 = self.split_by_comma(expression1)
|
106 |
+
expression2 = self.split_by_comma(expression2)
|
107 |
+
|
108 |
+
temp_list1 = self.trans_plus_minus_sign(expression1)
|
109 |
+
temp_list2 = self.trans_plus_minus_sign(expression2)
|
110 |
+
|
111 |
+
# Set up a list for allowed errors
|
112 |
+
if len(precision) <= 1:
|
113 |
+
precision = precision * len(temp_list1)
|
114 |
+
|
115 |
+
if len(temp_list1) != len(temp_list2):
|
116 |
+
return False
|
117 |
+
|
118 |
+
# Check if elements in both lists can be paired and are equal
|
119 |
+
idx = -1
|
120 |
+
while len(temp_list1) != 0:
|
121 |
+
idx = (idx + 1) % len(temp_list1)
|
122 |
+
|
123 |
+
item1 = temp_list1[idx]
|
124 |
+
self.precision = precision[idx]
|
125 |
+
|
126 |
+
for item2 in temp_list2:
|
127 |
+
if self.is_equal(item1, item2):
|
128 |
+
temp_list1.remove(item1)
|
129 |
+
temp_list2.remove(item2)
|
130 |
+
precision.remove(self.precision)
|
131 |
+
break
|
132 |
+
else:
|
133 |
+
# If no match was found, return False
|
134 |
+
return False
|
135 |
+
|
136 |
+
# If all elements are matched, return True
|
137 |
+
return True
|
138 |
+
|
139 |
+
def is_interval(self, expr):
|
140 |
+
# Checks if an expression is an interval
|
141 |
+
return expr.startswith(("(", "[")) and expr.endswith((")", "]"))
|
142 |
+
|
143 |
+
def is_two_greater_sign(self, expr):
|
144 |
+
match = re.findall(r'<', expr)
|
145 |
+
return len(match) == 2
|
146 |
+
|
147 |
+
def sympy_sub_pi(self, expression_sympy):
|
148 |
+
# Replaces the symbol for pi in sympy expressions with its numerical value
|
149 |
+
return expression_sympy.subs(self.pi, math.pi)
|
150 |
+
|
151 |
+
def is_equal(self, expression1, expression2):
|
152 |
+
# Default first expression is ground truth. Check if expressions are equal in different aspects
|
153 |
+
if expression1 == expression2 and expression1 != "" and expression2 != "":
|
154 |
+
# print("Equivalent natively")
|
155 |
+
return True
|
156 |
+
|
157 |
+
# First check if both are intervals
|
158 |
+
if self.is_interval(expression1) and self.is_interval(expression2):
|
159 |
+
try:
|
160 |
+
if self.interval_equal(expression1, expression2):
|
161 |
+
# print("Interval equivalent")
|
162 |
+
return True
|
163 |
+
except:
|
164 |
+
return False
|
165 |
+
|
166 |
+
# Then check for numerical equality
|
167 |
+
try:
|
168 |
+
if self.numerical_equal(expression1, expression2):
|
169 |
+
# print("Numerically equivalent")
|
170 |
+
return True
|
171 |
+
except:
|
172 |
+
pass
|
173 |
+
# Then check if expressions are mathematically equal
|
174 |
+
try:
|
175 |
+
if self.expression_equal(expression1, expression2) and not ("=" in expression1 and "=" in expression2):
|
176 |
+
# print("Expression equivalent")
|
177 |
+
return True
|
178 |
+
except:
|
179 |
+
pass
|
180 |
+
|
181 |
+
# Lastly, check for equation equality
|
182 |
+
try:
|
183 |
+
if self.equation_equal(expression1, expression2):
|
184 |
+
# print("Equation equivalent")
|
185 |
+
return True
|
186 |
+
except:
|
187 |
+
pass
|
188 |
+
|
189 |
+
return False
|
190 |
+
|
191 |
+
def numerical_equal(self, expression1: str, expression2: str, include_percentage: bool = True):
|
192 |
+
# Check if two numerical values are equal within an allowed error range
|
193 |
+
# Includes possible percentage cases
|
194 |
+
reference = float(expression1)
|
195 |
+
prediction = float(expression2)
|
196 |
+
|
197 |
+
if include_percentage:
|
198 |
+
gt_result = [reference / 100, reference, reference * 100]
|
199 |
+
else:
|
200 |
+
gt_result = [reference]
|
201 |
+
|
202 |
+
for item in gt_result:
|
203 |
+
if abs(item - prediction) <= self.precision * 1.01:
|
204 |
+
return True
|
205 |
+
return False
|
206 |
+
|
207 |
+
def expression_equal(self, exp1, exp2):
|
208 |
+
# Check if two expressions are mathematically equivalent
|
209 |
+
# Extract expression and use sympy for equivalence checking
|
210 |
+
def extract_expression(expression):
|
211 |
+
if "=" in expression:
|
212 |
+
expression = expression.split("=")[1]
|
213 |
+
return expression.strip()
|
214 |
+
|
215 |
+
exp1 = extract_expression(exp1)
|
216 |
+
exp2 = extract_expression(exp2)
|
217 |
+
|
218 |
+
exp_too_long = len(exp1) > 300 or len(exp2) > 300
|
219 |
+
|
220 |
+
expr1_sym = sympify(parse_latex(exp1))
|
221 |
+
expr2_sym = sympify(parse_latex(exp2))
|
222 |
+
if expr1_sym == expr2_sym:
|
223 |
+
return True
|
224 |
+
else:
|
225 |
+
expr1_sym = self.sympy_sub_pi(expr1_sym)
|
226 |
+
expr2_sym = self.sympy_sub_pi(expr2_sym)
|
227 |
+
|
228 |
+
if (expr1_sym.has(sp.Symbol) and not expr2_sym.has(sp.Symbol)) or \
|
229 |
+
(not expr1_sym.has(sp.Symbol) and expr2_sym.has(sp.Symbol)):
|
230 |
+
return False
|
231 |
+
elif not expr1_sym.has(sp.Symbol) and not expr2_sym.has(sp.Symbol):
|
232 |
+
try:
|
233 |
+
if not (self.can_compute_power(expr1_sym) and self.can_compute_power(expr2_sym)):
|
234 |
+
print("These two numbers cannot be evaluated on the current machine: "
|
235 |
+
f"\"{str(expr1_sym)}\" and \"{str(expr2_sym)}\"")
|
236 |
+
return False
|
237 |
+
if exp_too_long:
|
238 |
+
print(f'Expression {exp1} or {exp2} is too long to compute. ')
|
239 |
+
return False
|
240 |
+
if abs(expr1_sym.evalf() - expr2_sym.evalf()) <= self.precision * 1.01:
|
241 |
+
return True
|
242 |
+
else:
|
243 |
+
return False
|
244 |
+
except:
|
245 |
+
return False
|
246 |
+
elif exp_too_long:
|
247 |
+
print(f'Expression {exp1} or {exp2} is too long to compute. ')
|
248 |
+
return False
|
249 |
+
else:
|
250 |
+
try:
|
251 |
+
simplified_expr = simplify(expr1_sym - expr2_sym)
|
252 |
+
num_value = simplified_expr.evalf()
|
253 |
+
return abs(num_value) < 1e-3
|
254 |
+
except:
|
255 |
+
return False
|
256 |
+
|
257 |
+
def equation_equal(self, expression1, expression2):
|
258 |
+
# Check if two equations are mathematically equivalent
|
259 |
+
# Simplify equations and use sympy for equivalence checking
|
260 |
+
def simplify_equation(latex_eq):
|
261 |
+
lhs, rhs = latex_eq.split('=')
|
262 |
+
|
263 |
+
lhs_expr = parse_latex(lhs)
|
264 |
+
rhs_expr = parse_latex(rhs)
|
265 |
+
|
266 |
+
equation = Eq(lhs_expr, rhs_expr)
|
267 |
+
|
268 |
+
simplified_eq = simplify(equation.lhs - equation.rhs)
|
269 |
+
|
270 |
+
return simplified_eq
|
271 |
+
|
272 |
+
expr1_sym = simplify_equation(expression1)
|
273 |
+
expr2_sym = simplify_equation(expression2)
|
274 |
+
|
275 |
+
division_result_1 = simplify(expr1_sym / expr2_sym)
|
276 |
+
division_result_2 = simplify(expr2_sym / expr1_sym)
|
277 |
+
|
278 |
+
if ((division_result_1.is_Integer and division_result_1 != 0) or # noqa: W504
|
279 |
+
(division_result_2.is_Integer and division_result_2 != 0)):
|
280 |
+
return True
|
281 |
+
else:
|
282 |
+
return False
|
283 |
+
|
284 |
+
def interval_equal(self, expression1, expression2):
|
285 |
+
# Check if two intervals are mathematically equivalent
|
286 |
+
def compare_two_interval(inter1, inter2):
|
287 |
+
if inter1[0] != inter2[0] or inter1[-1] != inter2[-1]:
|
288 |
+
return False
|
289 |
+
|
290 |
+
inter1 = inter1.strip('[]()')
|
291 |
+
inter2 = inter2.strip('[]()')
|
292 |
+
|
293 |
+
items_1 = inter1.split(',')
|
294 |
+
items_2 = inter2.split(',')
|
295 |
+
|
296 |
+
for item_1, item_2 in zip(items_1, items_2):
|
297 |
+
if not self.expression_equal(item_1, item_2):
|
298 |
+
return False
|
299 |
+
return True
|
300 |
+
|
301 |
+
interval1 = expression1
|
302 |
+
interval2 = expression2
|
303 |
+
|
304 |
+
if interval1 == interval2:
|
305 |
+
return True
|
306 |
+
else:
|
307 |
+
inter_list1 = interval1.split("\\cup")
|
308 |
+
inter_list2 = interval2.split("\\cup")
|
309 |
+
|
310 |
+
if len(inter_list1) != len(inter_list2):
|
311 |
+
return False
|
312 |
+
else:
|
313 |
+
for inter1, inter2 in zip(inter_list1, inter_list2):
|
314 |
+
if not compare_two_interval(inter1, inter2):
|
315 |
+
return False
|
316 |
+
return True
|
317 |
+
|
318 |
+
def preprocess(self, expression1, expression2):
|
319 |
+
# Preprocess expressions to extract and replace special symbols
|
320 |
+
def extract_boxed_content(latex_str):
|
321 |
+
boxed_matches = re.finditer(r'\\boxed{', latex_str)
|
322 |
+
results = ""
|
323 |
+
|
324 |
+
for match in boxed_matches:
|
325 |
+
start_index = match.end()
|
326 |
+
end_index = start_index
|
327 |
+
stack = 1
|
328 |
+
|
329 |
+
while stack > 0 and end_index < len(latex_str):
|
330 |
+
if latex_str[end_index] == '{':
|
331 |
+
stack += 1
|
332 |
+
elif latex_str[end_index] == '}':
|
333 |
+
stack -= 1
|
334 |
+
end_index += 1
|
335 |
+
|
336 |
+
if stack == 0:
|
337 |
+
content = latex_str[start_index:end_index - 1]
|
338 |
+
results += content + ","
|
339 |
+
else:
|
340 |
+
raise ValueError("Mismatched braces in LaTeX string.")
|
341 |
+
|
342 |
+
if results == "":
|
343 |
+
last_line_ans = latex_str.strip().split("\n")[-1]
|
344 |
+
dollar_pattern = r"\$(.*?)\$"
|
345 |
+
answers = re.findall(dollar_pattern, last_line_ans)
|
346 |
+
|
347 |
+
if answers:
|
348 |
+
for ans in answers:
|
349 |
+
results += ans + ","
|
350 |
+
else:
|
351 |
+
results = latex_str
|
352 |
+
|
353 |
+
return results
|
354 |
+
|
355 |
+
def special_symbol_replace(expression):
|
356 |
+
|
357 |
+
expression = expression.replace("\\text{cm}^2", '').replace("\\text{cm}", "").replace("\\,cm", '').replace("\\text{ cm}", '').replace("cm", '').replace("\\text{分米}^2", '').replace("cm^{2}", '').replace("60 \\text{ cm}^2",'').replace("\\ \\text{m}", "").replace("\\text{米}","").strip() # noqa: E501
|
358 |
+
|
359 |
+
expression = re.sub(r"(.+)m$", r"\1", expression)
|
360 |
+
|
361 |
+
if "\\in " in expression:
|
362 |
+
expression = expression.split("\\in ")[1]
|
363 |
+
|
364 |
+
for signal in self.special_signal_map:
|
365 |
+
expression = expression.replace(signal, self.special_signal_map[signal])
|
366 |
+
|
367 |
+
expression = re.sub(r'(\\sin|\\cos|\\tan)(\d+)', r'\1((\2/180)\\pi)', expression)
|
368 |
+
|
369 |
+
expression = expression.strip("\n,.:;^_=+`!@#%^&*~,。")
|
370 |
+
|
371 |
+
pattern = r'\\(?:mathrm|mathbf)\{~?([^}]*)\}'
|
372 |
+
expression = re.sub(pattern, r'\1', expression)
|
373 |
+
|
374 |
+
return expression
|
375 |
+
|
376 |
+
exp1, exp2 = extract_boxed_content(expression1), extract_boxed_content(expression2)
|
377 |
+
|
378 |
+
exp1, exp2 = special_symbol_replace(exp1), special_symbol_replace(exp2)
|
379 |
+
|
380 |
+
return exp1, exp2
|
381 |
+
|
382 |
+
def can_compute_power(self, expr):
|
383 |
+
# Checks if a power expression can be computed
|
384 |
+
if isinstance(expr, Pow):
|
385 |
+
base, exp = expr.as_base_exp()
|
386 |
+
if base.is_number and exp.is_number:
|
387 |
+
MAX_EXP = 1000 # Adjust based on computing environment
|
388 |
+
if abs(exp.evalf()) > MAX_EXP:
|
389 |
+
return False
|
390 |
+
else:
|
391 |
+
return True
|
392 |
+
else:
|
393 |
+
return False
|
394 |
+
else:
|
395 |
+
return True # Not a power expression, can compute
|
396 |
+
|
397 |
+
|
398 |
+
class MMMath(ImageBaseDataset):
|
399 |
+
|
400 |
+
TYPE = 'VQA'
|
401 |
+
|
402 |
+
DATASET_URL = {
|
403 |
+
'MM-Math': 'https://opencompass.openxlab.space/utils/VLMEval/MM-Math.tsv',
|
404 |
+
}
|
405 |
+
DATASET_MD5 = {
|
406 |
+
'MM-Math': '1f064ed7c4e0e8926a3fa65849419ca5',
|
407 |
+
}
|
408 |
+
|
409 |
+
@classmethod
|
410 |
+
def evaluate(self, eval_file, **kwargs):
|
411 |
+
|
412 |
+
data = load(eval_file)
|
413 |
+
judger = AutoScoringJudge()
|
414 |
+
func = judger.judge
|
415 |
+
|
416 |
+
tups = [dict(expression1=x, expression2=y) for x, y in zip(data['answer'], data['prediction'])]
|
417 |
+
|
418 |
+
res = track_progress_rich(func, tups, nproc=16)
|
419 |
+
data['hit'] = res
|
420 |
+
dump(data, eval_file)
|
421 |
+
|
422 |
+
score_file = eval_file.replace('.xlsx', '_score.json')
|
423 |
+
score = {}
|
424 |
+
score['overall'] = np.mean(data['hit'])
|
425 |
+
# Results by Difficulty
|
426 |
+
difficulties = set(data['difficulty'])
|
427 |
+
for d in difficulties:
|
428 |
+
score[f'Difficulty-{d}'] = np.mean(data[data['difficulty'] == d]['hit'])
|
429 |
+
|
430 |
+
# Results by Year
|
431 |
+
years = set(data['year'])
|
432 |
+
for y in years:
|
433 |
+
score[f'Year-{y}'] = np.mean(data[data['year'] == y]['hit'])
|
434 |
+
|
435 |
+
# Results by Knowledge-L1
|
436 |
+
points = set(data['knowledge_l1'])
|
437 |
+
for p in points:
|
438 |
+
score[f'Knowledge-L1-{p}'] = np.mean(data[data['knowledge_l1'] == p]['hit'])
|
439 |
+
|
440 |
+
# Results by Knowledge-L2
|
441 |
+
points = set(data['knowledge_l2'])
|
442 |
+
for p in points:
|
443 |
+
score[f'Knowledge-L2-{p}'] = np.mean(data[data['knowledge_l2'] == p]['hit'])
|
444 |
+
|
445 |
+
dump(score, score_file)
|
446 |
+
return score
|
VLMEvalKit/vlmeval/dataset/mvbench.py
ADDED
@@ -0,0 +1,668 @@
1 |
+
import huggingface_hub
|
2 |
+
from huggingface_hub import snapshot_download
|
3 |
+
from ..smp import *
|
4 |
+
from .video_base import VideoBaseDataset
|
5 |
+
from .utils import build_judge, DEBUG_MESSAGE
|
6 |
+
from ..utils import track_progress_rich
|
7 |
+
import torchvision.transforms as T
|
8 |
+
from torchvision import transforms
|
9 |
+
from torchvision.transforms.functional import InterpolationMode
|
10 |
+
from decord import VideoReader, cpu
|
11 |
+
import imageio
|
12 |
+
import cv2
|
13 |
+
import zipfile
|
14 |
+
import os
|
15 |
+
import glob
|
16 |
+
from .utils.mvbench import *
|
17 |
+
|
18 |
+
FAIL_MSG = 'Failed to obtain answer via API.'
|
19 |
+
|
20 |
+
|
21 |
+
class MVBench(VideoBaseDataset):
|
22 |
+
|
23 |
+
MD5 = 'fd21d36522cdedd46d84dc46715ad832'
|
24 |
+
SYS = """Carefully watch the video and pay attention to the cause and sequence of events, \
|
25 |
+
the detail and movement of objects, and the action and pose of persons. \
|
26 |
+
Based on your observations, select the best option that accurately addresses the question.
|
27 |
+
"""
|
28 |
+
|
29 |
+
TYPE = 'Video-MCQ'
|
30 |
+
|
31 |
+
def __init__(self, dataset='MVBench', pack=False):
|
32 |
+
self.type_data_list = {
|
33 |
+
'Action Sequence': ('action_sequence.json',
|
34 |
+
'your_data_path/star/Charades_v1_480/', 'video', True), # has start & end
|
35 |
+
'Action Prediction': ('action_prediction.json',
|
36 |
+
'your_data_path/star/Charades_v1_480/', 'video', True), # has start & end
|
37 |
+
'Action Antonym': ('action_antonym.json',
|
38 |
+
'your_data_path/ssv2_video/', 'video', False),
|
39 |
+
'Fine-grained Action': ('fine_grained_action.json',
|
40 |
+
'your_data_path/Moments_in_Time_Raw/videos/', 'video', False),
|
41 |
+
'Unexpected Action': ('unexpected_action.json',
|
42 |
+
'your_data_path/FunQA_test/test/', 'video', False),
|
43 |
+
'Object Existence': ('object_existence.json',
|
44 |
+
'your_data_path/clevrer/video_validation/', 'video', False),
|
45 |
+
'Object Interaction': ('object_interaction.json',
|
46 |
+
'your_data_path/star/Charades_v1_480/', 'video', True), # has start & end
|
47 |
+
'Object Shuffle': ('object_shuffle.json',
|
48 |
+
'your_data_path/perception/videos/', 'video', False),
|
49 |
+
'Moving Direction': ('moving_direction.json',
|
50 |
+
'your_data_path/clevrer/video_validation/', 'video', False),
|
51 |
+
'Action Localization': ('action_localization.json',
|
52 |
+
'your_data_path/sta/sta_video/', 'video', True), # has start & end
|
53 |
+
'Scene Transition': ('scene_transition.json',
|
54 |
+
'your_data_path/scene_qa/video/', 'video', False),
|
55 |
+
'Action Count': ('action_count.json',
|
56 |
+
'your_data_path/perception/videos/', 'video', False),
|
57 |
+
'Moving Count': ('moving_count.json',
|
58 |
+
'your_data_path/clevrer/video_validation/', 'video', False),
|
59 |
+
'Moving Attribute': ('moving_attribute.json',
|
60 |
+
'your_data_path/clevrer/video_validation/', 'video', False),
|
61 |
+
'State Change': ('state_change.json',
|
62 |
+
'your_data_path/perception/videos/', 'video', False),
|
63 |
+
'Fine-grained Pose': ('fine_grained_pose.json',
|
64 |
+
'your_data_path/nturgbd/', 'video', False),
|
65 |
+
'Character Order': ('character_order.json',
|
66 |
+
'your_data_path/perception/videos/', 'video', False),
|
67 |
+
'Egocentric Navigation': ('egocentric_navigation.json',
|
68 |
+
'your_data_path/vlnqa/', 'video', False),
|
69 |
+
'Episodic Reasoning': ('episodic_reasoning.json',
|
70 |
+
'your_data_path/tvqa/frames_fps3_hq/', 'frame', True), # has start & end, read frame
|
71 |
+
'Counterfactual Inference': ('counterfactual_inference.json',
|
72 |
+
'your_data_path/clevrer/video_validation/', 'video', False),
|
73 |
+
}
|
74 |
+
super().__init__(dataset=dataset, pack=pack)
|
75 |
+
|
76 |
+
@classmethod
|
77 |
+
def supported_datasets(cls):
|
78 |
+
return ['MVBench']
|
79 |
+
|
80 |
+
def prepare_dataset(self, dataset_name='MVBench', repo_id='OpenGVLab/MVBench'):
|
81 |
+
def check_integrity(pth):
|
82 |
+
data_file = osp.join(pth, f'{dataset_name}.tsv')
|
83 |
+
|
84 |
+
if not os.path.exists(data_file):
|
85 |
+
return False
|
86 |
+
|
87 |
+
if md5(data_file) != self.MD5:
|
88 |
+
return False
|
89 |
+
|
90 |
+
data = load(data_file)
|
91 |
+
for idx, item in data.iterrows():
|
92 |
+
if not osp.exists(osp.join(pth, item['prefix'], item['video'])):
|
93 |
+
return False
|
94 |
+
return True
|
95 |
+
|
96 |
+
if modelscope_flag_set():
|
97 |
+
repo_id = 'modelscope/MVBench'
|
98 |
+
|
99 |
+
cache_path = get_cache_path(repo_id, branch='main')
|
100 |
+
if cache_path is not None and check_integrity(cache_path):
|
101 |
+
dataset_path = cache_path
|
102 |
+
else:
|
103 |
+
def unzip_hf_zip(pth):
|
104 |
+
pth = os.path.join(pth, 'video/')
|
105 |
+
for filename in os.listdir(pth):
|
106 |
+
if filename.endswith('.zip'):
|
107 |
+
# Build the full file path
|
108 |
+
zip_path = os.path.join(pth, filename)
|
109 |
+
|
110 |
+
# Extract the ZIP file
|
111 |
+
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
|
112 |
+
zip_ref.extractall(pth)
|
113 |
+
|
114 |
+
def generate_tsv(pth):
|
115 |
+
data_file = osp.join(pth, f'{dataset_name}.tsv')
|
116 |
+
if os.path.exists(data_file) and md5(data_file) == self.MD5:
|
117 |
+
return
|
118 |
+
json_data_dir = os.path.join(pth, 'json')
|
119 |
+
self.data_list = []
|
120 |
+
for k, v in self.type_data_list.items():
|
121 |
+
with open(os.path.join(json_data_dir, v[0]), 'r') as f:
|
122 |
+
json_data = json.load(f)
|
123 |
+
for data in json_data:
|
124 |
+
if os.path.exists(os.path.join(pth, v[1].replace('your_data_path', 'video'), data['video'])):
|
125 |
+
self.data_list.append({
|
126 |
+
'task_type': k,
|
127 |
+
'prefix': v[1].replace('your_data_path', 'video'),
|
128 |
+
'data_type': v[2],
|
129 |
+
'bound': v[3],
|
130 |
+
'start': data['start'] if 'start' in data.keys() else None,
|
131 |
+
'end': data['end'] if 'end' in data.keys() else None,
|
132 |
+
'video': data['video'],
|
133 |
+
'question': data['question'],
|
134 |
+
'answer': data['answer'],
|
135 |
+
'candidates': data['candidates']
|
136 |
+
})
|
137 |
+
else:
|
138 |
+
print(
|
139 |
+
'The NTURGB-D zip file has been removed from MVBench; see '
|
140 |
+
'https://huggingface.co/datasets/OpenGVLab/MVBench for the detailed reason.'
|
141 |
+
)
|
142 |
+
raise Exception(
|
143 |
+
f"{os.path.join(v[1].replace('your_data_path', 'video'), data['video'])} does not exist"
|
144 |
+
)
|
145 |
+
|
146 |
+
data_df = pd.DataFrame(self.data_list)
|
147 |
+
data_df = data_df.assign(index=range(len(data_df)))
|
148 |
+
data_df.to_csv(data_file, sep='\t', index=False)
|
149 |
+
|
150 |
+
def move_files(pth):
|
151 |
+
src_folder = os.path.join(pth, 'video/data0613')
|
152 |
+
if not os.path.exists(src_folder):
|
153 |
+
return
|
154 |
+
for subdir in os.listdir(src_folder):
|
155 |
+
subdir_path = os.path.join(src_folder, subdir)
|
156 |
+
if os.path.isdir(subdir_path):
|
157 |
+
for subsubdir in os.listdir(subdir_path):
|
158 |
+
subsubdir_path = os.path.join(subdir_path, subsubdir)
|
159 |
+
if os.path.isdir(subsubdir_path):
|
160 |
+
for item in os.listdir(subsubdir_path):
|
161 |
+
item_path = os.path.join(subsubdir_path, item)
|
162 |
+
target_folder = os.path.join(pth, 'video', subdir, subsubdir)
|
163 |
+
if not os.path.exists(target_folder):
|
164 |
+
os.makedirs(target_folder)
|
165 |
+
target_path = os.path.join(target_folder, item)
|
166 |
+
try:
|
167 |
+
shutil.move(item_path, target_path)
|
168 |
+
except Exception as e:
|
169 |
+
print(f"Error moving {item_path} to {target_path}: {e}")
|
170 |
+
|
171 |
+
if modelscope_flag_set():
|
172 |
+
from modelscope import dataset_snapshot_download
|
173 |
+
dataset_path = dataset_snapshot_download(dataset_id=repo_id, revision='master')
|
174 |
+
else:
|
175 |
+
hf_token = os.environ.get('HUGGINGFACE_TOKEN')
|
176 |
+
huggingface_hub.login(hf_token)
|
177 |
+
dataset_path = snapshot_download(repo_id=repo_id, repo_type='dataset')
|
178 |
+
unzip_hf_zip(dataset_path)
|
179 |
+
move_files(dataset_path)
|
180 |
+
generate_tsv(dataset_path)
|
181 |
+
|
182 |
+
data_file = osp.join(dataset_path, f'{dataset_name}.tsv')
|
183 |
+
|
184 |
+
self.decord_method = {
|
185 |
+
'video': self.read_video,
|
186 |
+
'gif': self.read_gif,
|
187 |
+
'frame': self.read_frame,
|
188 |
+
}
|
189 |
+
|
190 |
+
self.nframe = 8
|
191 |
+
self.frame_fps = 3
|
192 |
+
|
193 |
+
# transform
|
194 |
+
self.transform = T.Compose([
|
195 |
+
Stack(),
|
196 |
+
ToTorchFormatTensor()
|
197 |
+
])
|
198 |
+
|
199 |
+
return dict(root=dataset_path, data_file=data_file)
|
200 |
+
|
201 |
+
def get_index(self, bound, fps, max_frame, first_idx=0):
|
202 |
+
if bound:
|
203 |
+
start, end = bound[0], bound[1]
|
204 |
+
else:
|
205 |
+
start, end = -100000, 100000
|
206 |
+
start_idx = max(first_idx, round(start * fps))
|
207 |
+
end_idx = min(round(end * fps), max_frame)
|
208 |
+
seg_size = float(end_idx - start_idx) / self.num_segments
|
209 |
+
frame_indices = np.array([
|
210 |
+
int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
|
211 |
+
for idx in range(self.num_segments)
|
212 |
+
])
|
213 |
+
return frame_indices
|
214 |
+
|
215 |
+
def read_video(self, video_path, bound=None):
|
216 |
+
vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
|
217 |
+
max_frame = len(vr) - 1
|
218 |
+
fps = float(vr.get_avg_fps())
|
219 |
+
|
220 |
+
images_group = list()
|
221 |
+
frame_indices = self.get_index(bound, fps, max_frame, first_idx=0)
|
222 |
+
for frame_index in frame_indices:
|
223 |
+
img = Image.fromarray(vr[frame_index].asnumpy())
|
224 |
+
images_group.append(img)
|
225 |
+
torch_imgs = self.transform(images_group)
|
226 |
+
return torch_imgs
|
227 |
+
|
228 |
+
def read_gif(self, video_path, bound=None, fps=25):
|
229 |
+
gif = imageio.get_reader(video_path)
|
230 |
+
max_frame = len(gif) - 1
|
231 |
+
|
232 |
+
images_group = list()
|
233 |
+
frame_indices = self.get_index(bound, fps, max_frame, first_idx=0)
|
234 |
+
for index, frame in enumerate(gif):
|
235 |
+
if index in frame_indices:
|
236 |
+
img = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
|
237 |
+
img = Image.fromarray(img)
|
238 |
+
images_group.append(img)
|
239 |
+
torch_imgs = self.transform(images_group)
|
240 |
+
return torch_imgs
|
241 |
+
|
242 |
+
def read_frame(self, video_path, bound=None, fps=3):
|
243 |
+
max_frame = len(os.listdir(video_path))
|
244 |
+
images_group = list()
|
245 |
+
frame_indices = self.get_index(bound, fps, max_frame, first_idx=1) # frame_idx starts from 1
|
246 |
+
for frame_index in frame_indices:
|
247 |
+
img = Image.open(os.path.join(video_path, f'{frame_index:05d}.jpg'))
|
248 |
+
images_group.append(img)
|
249 |
+
torch_imgs = self.transform(images_group)
|
250 |
+
return torch_imgs
|
251 |
+
|
252 |
+
def save_video_frames(self, imgs, video_name, frames):
|
253 |
+
|
254 |
+
frame_paths = self.frame_paths(video_name, frames)
|
255 |
+
flag = np.all([osp.exists(p) for p in frame_paths])
|
256 |
+
|
257 |
+
if not flag:
|
258 |
+
block_size = imgs.size(0) // frames
|
259 |
+
split_tensors = torch.split(imgs, block_size)
|
260 |
+
to_pil = transforms.ToPILImage()
|
261 |
+
images = [to_pil(arr) for arr in split_tensors]
|
262 |
+
for im, pth in zip(images, frame_paths):
|
263 |
+
if not osp.exists(pth):
|
264 |
+
im.save(pth)
|
265 |
+
|
266 |
+
return frame_paths
|
267 |
+
|
268 |
+
def qa_template(self, data):
|
269 |
+
question = f"Question: {data['question']}\n"
|
270 |
+
question += 'Options:\n'
|
271 |
+
answer = data['answer']
|
272 |
+
answer_idx = -1
|
273 |
+
for idx, c in enumerate(eval(data['candidates'])):
|
274 |
+
question += f"({chr(ord('A') + idx)}) {c}\n"
|
275 |
+
if c == answer:
|
276 |
+
answer_idx = idx
|
277 |
+
question = question.rstrip()
|
278 |
+
answer = f"({chr(ord('A') + answer_idx)}) {answer}"
|
279 |
+
return question, answer
|
280 |
+
|
281 |
+
def load_into_video_and_process(self, line):
|
282 |
+
try:
|
283 |
+
from moviepy.editor import VideoFileClip, ImageSequenceClip
|
284 |
+
except:
|
285 |
+
raise ImportError(
|
286 |
+
'MoviePy is not installed, please install it by running "pip install moviepy==1.0.3"'
|
287 |
+
)
|
288 |
+
video_path = os.path.join(self.data_root, line['prefix'], line['video'])
|
289 |
+
|
290 |
+
if line['data_type'] in ['gif'] or os.path.splitext(video_path)[1] in ['.webm']:
|
291 |
+
processed_video_path = video_path.replace(os.path.splitext(video_path)[1], '.mp4')
|
292 |
+
if not os.path.exists(processed_video_path):
|
293 |
+
# using MoviePy to transform GIF, webm into mp4 format
|
294 |
+
gif_clip = VideoFileClip(video_path)
|
295 |
+
gif_clip.write_videofile(processed_video_path, codec='libx264')
|
296 |
+
gif_clip.close()
|
297 |
+
elif line['data_type'] in ['frame']:
|
298 |
+
input_images = os.path.join(video_path, '*.jpg')
|
299 |
+
processed_video_path = f'{video_path}.mp4'
|
300 |
+
if not os.path.exists(processed_video_path):
|
301 |
+
# using MoviePy to transform images into mp4
|
302 |
+
image_files = sorted(glob.glob(input_images))
|
303 |
+
image_clip = ImageSequenceClip(image_files, fps=self.frame_fps)
|
304 |
+
image_clip.write_videofile(processed_video_path, codec='libx264')
|
305 |
+
image_clip.close()
|
306 |
+
else:
|
307 |
+
processed_video_path = video_path
|
308 |
+
|
309 |
+
if line['bound']:
|
310 |
+
base_name, suffix = os.path.splitext(processed_video_path)
|
311 |
+
output_video_path = f'{base_name}_processed{suffix}'
|
312 |
+
if not os.path.exists(output_video_path):
|
313 |
+
video_clip = VideoFileClip(processed_video_path)
|
314 |
+
clip = video_clip.subclip(line['start'], min(line['end'], video_clip.duration))
|
315 |
+
clip.write_videofile(output_video_path)
|
316 |
+
clip.close()
|
317 |
+
else:
|
318 |
+
output_video_path = processed_video_path
|
319 |
+
|
320 |
+
return output_video_path
|
321 |
+
|
322 |
+
def save_video_into_images(self, line, num_frames):
|
323 |
+
bound = None
|
324 |
+
if line['bound']:
|
325 |
+
bound = (
|
326 |
+
line['start'],
|
327 |
+
line['end'],
|
328 |
+
)
|
329 |
+
video_path = os.path.join(self.data_root, line['prefix'], line['video'])
|
330 |
+
decord_method = self.decord_method[line['data_type']]
|
331 |
+
self.num_segments = num_frames if num_frames > 0 else self.nframe
|
332 |
+
torch_imgs = decord_method(video_path, bound)
|
333 |
+
img_frame_paths = self.save_video_frames(torch_imgs, line['video'], self.num_segments)
|
334 |
+
return img_frame_paths
|
335 |
+
|
336 |
+
def build_prompt(self, line, num_frames, video_llm, fps):
|
337 |
+
if fps > 0:
|
338 |
+
raise ValueError('MVBench does not support the fps setting, please switch to MVBench_MP4!')
|
339 |
+
if isinstance(line, int):
|
340 |
+
assert line < len(self)
|
341 |
+
line = self.data.iloc[line]
|
342 |
+
|
343 |
+
question, answer = self.qa_template(line)
|
344 |
+
message = [dict(type='text', value=self.SYS, role='system')]
|
345 |
+
message.append(dict(type='text', value=question))
|
346 |
+
if video_llm:
|
347 |
+
new_video_path = self.load_into_video_and_process(line)
|
348 |
+
message.append(dict(type='video', value=new_video_path))
|
349 |
+
else:
|
350 |
+
img_frame_paths = self.save_video_into_images(line, num_frames)
|
351 |
+
for im in img_frame_paths:
|
352 |
+
message.append(dict(type='image', value=im))
|
353 |
+
message.append(dict(type='text', value='\nOnly give the best option.'))
|
354 |
+
message.append(dict(type='text', value='Best option:(', role='assistant'))
|
355 |
+
return message
|
356 |
+
|
357 |
+
@classmethod
|
358 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
359 |
+
|
360 |
+
assert eval_file.endswith('.xlsx'), 'data file should be an xlsx file'
|
361 |
+
|
362 |
+
tmp_file = eval_file.replace('.xlsx', '_tmp.pkl')
|
363 |
+
tgt_file = eval_file.replace('.xlsx', '_rating.json')
|
364 |
+
score_file = eval_file.replace('.xlsx', '_score.xlsx')
|
365 |
+
|
366 |
+
if not osp.exists(score_file):
|
367 |
+
model = judge_kwargs.setdefault('model', 'chatgpt-0125')
|
368 |
+
assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']
|
369 |
+
|
370 |
+
if model == 'exact_matching':
|
371 |
+
model = None
|
372 |
+
elif gpt_key_set():
|
373 |
+
model = build_judge(**judge_kwargs)
|
374 |
+
if not model.working():
|
375 |
+
warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
|
376 |
+
warnings.warn(DEBUG_MESSAGE)
|
377 |
+
model = None
|
378 |
+
else:
|
379 |
+
warnings.warn('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
|
380 |
+
model = None
|
381 |
+
res = {} if not osp.exists(tmp_file) else load(tmp_file)
|
382 |
+
res = {k: v for k, v in res.items() if FAIL_MSG not in v}
|
383 |
+
|
384 |
+
data = load(eval_file)
|
385 |
+
data_un = data[~pd.isna(data['prediction'])]
|
386 |
+
|
387 |
+
for idx in data_un['index']:
|
388 |
+
ans = data.loc[data['index'] == idx, 'answer'].values[0]
|
389 |
+
pred = data.loc[data['index'] == idx, 'prediction'].values[0]
|
390 |
+
options = eval(data.loc[data['index'] == idx, 'candidates'].values[0])
|
391 |
+
answer_idx = -1
|
392 |
+
for id, c in enumerate(options):
|
393 |
+
if c == ans:
|
394 |
+
answer_idx = id
|
395 |
+
ans = f"({chr(ord('A') + answer_idx)}) {ans}"
|
396 |
+
input_item = data.loc[data['index'] == idx].to_dict(orient='records')[0]
|
397 |
+
for id, option_content in enumerate(eval(input_item['candidates'])):
|
398 |
+
input_item[chr(ord('A') + id)] = option_content
|
399 |
+
if option_content == input_item['answer']:
|
400 |
+
input_item['answer'] = chr(ord('A') + id)
|
401 |
+
|
402 |
+
if FAIL_MSG in pred:
|
403 |
+
data.loc[idx, 'score'] = -1
|
404 |
+
else:
|
405 |
+
data.loc[idx, 'score'] = int(check_ans_with_model(
|
406 |
+
pred, ans, model,
|
407 |
+
input_item,
|
408 |
+
'MVBench'
|
409 |
+
))
|
410 |
+
|
411 |
+
rejected = [x for x in data['score'] if x == -1]
|
412 |
+
|
413 |
+
print(
|
414 |
+
f'Among {len(data)} questions, failed to obtain prediction for {len(data) - len(data_un)} questions, '
|
415 |
+
f'failed to obtain the score for another {len(rejected)} questions. '
|
416 |
+
f'Those questions will be counted as -1 score in ALL rating, and will not be counted in VALID rating.'
|
417 |
+
)
|
418 |
+
|
419 |
+
dump(data, score_file)
|
420 |
+
|
421 |
+
rating = get_dimension_rating(score_file)
|
422 |
+
dump(rating, tgt_file)
|
423 |
+
return rating
|
424 |
+
|
425 |
+
|
426 |
+
class MVBench_MP4(VideoBaseDataset):
|
427 |
+
|
428 |
+
MP4_MD5 = '5c8c6f8b7972c2de65a629590f7c42f5'
|
429 |
+
SYS = """Carefully watch the video and pay attention to the cause and sequence of events, \
|
430 |
+
the detail and movement of objects, and the action and pose of persons. \
|
431 |
+
Based on your observations, select the best option that accurately addresses the question.
|
432 |
+
"""
|
433 |
+
TYPE = 'Video-MCQ'
|
434 |
+
|
435 |
+
def __init__(self, dataset='MVBench_MP4', pack=False):
|
436 |
+
super().__init__(dataset=dataset, pack=pack)
|
437 |
+
|
438 |
+
@classmethod
|
439 |
+
def supported_datasets(cls):
|
440 |
+
return ['MVBench_MP4']
|
441 |
+
|
442 |
+
def prepare_dataset(self, dataset_name='MVBench_MP4', repo_id='OpenGVLab/MVBench'):
|
443 |
+
def check_integrity(pth):
|
444 |
+
data_file = osp.join(pth, f'{dataset_name}.tsv')
|
445 |
+
|
446 |
+
if not os.path.exists(data_file):
|
447 |
+
return False
|
448 |
+
|
449 |
+
if md5(data_file) != self.MP4_MD5:
|
450 |
+
return False
|
451 |
+
|
452 |
+
data = load(data_file)
|
453 |
+
for idx, item in data.iterrows():
|
454 |
+
if not osp.exists(osp.join(pth, item['prefix'], item['video'])):
|
455 |
+
return False
|
456 |
+
return True
|
457 |
+
|
458 |
+
if modelscope_flag_set():
|
459 |
+
repo_id = 'modelscope/MVBench'
|
460 |
+
|
461 |
+
cache_path = get_cache_path(repo_id, branch='video')
|
462 |
+
if cache_path is not None and check_integrity(cache_path):
|
463 |
+
dataset_path = cache_path
|
464 |
+
else:
|
465 |
+
def generate_tsv(pth):
|
466 |
+
data_file = osp.join(pth, f'{dataset_name}.tsv')
|
467 |
+
if os.path.exists(data_file) and md5(data_file) == self.MP4_MD5:
|
468 |
+
return
|
469 |
+
json_data_path = os.path.join(dataset_path, 'test.json')
|
470 |
+
json_data = load(json_data_path)
|
471 |
+
root_data_dict = json_data['root']
|
472 |
+
self.data_list = []
|
473 |
+
for k, v in json_data['meta'].items():
|
474 |
+
for item in v:
|
475 |
+
self.data_list.append({
|
476 |
+
'task_type': k,
|
477 |
+
'prefix': root_data_dict[k],
|
478 |
+
'video': item['video'],
|
479 |
+
'question': item['question'],
|
480 |
+
'answer': item['answer'],
|
481 |
+
'candidates': item['candidates']
|
482 |
+
})
|
483 |
+
data_df = pd.DataFrame(self.data_list)
|
484 |
+
data_df = data_df.assign(index=range(len(data_df)))
|
485 |
+
data_df.to_csv(data_file, sep='\t', index=False)
|
486 |
+
|
487 |
+
if modelscope_flag_set():
|
488 |
+
from modelscope import dataset_snapshot_download
|
489 |
+
dataset_path = dataset_snapshot_download(dataset_id=repo_id, revision='video')
|
490 |
+
else:
|
491 |
+
hf_token = os.environ.get('HUGGINGFACE_TOKEN')
|
492 |
+
huggingface_hub.login(hf_token)
|
493 |
+
dataset_path = snapshot_download(repo_id=repo_id, repo_type='dataset', revision='video')
|
494 |
+
generate_tsv(dataset_path)
|
495 |
+
|
496 |
+
data_file = osp.join(dataset_path, f'{dataset_name}.tsv')
|
497 |
+
|
498 |
+
self.nframe = 8
|
499 |
+
|
500 |
+
# transform
|
501 |
+
self.transform = T.Compose([
|
502 |
+
Stack(),
|
503 |
+
ToTorchFormatTensor()
|
504 |
+
])
|
505 |
+
|
506 |
+
return dict(root=dataset_path, data_file=data_file)
|
507 |
+
|
508 |
+
def qa_template(self, data):
|
509 |
+
question = f"Question: {data['question']}\n"
|
510 |
+
question += 'Options:\n'
|
511 |
+
answer = data['answer']
|
512 |
+
answer_idx = -1
|
513 |
+
for idx, c in enumerate(eval(data['candidates'])):
|
514 |
+
question += f"({chr(ord('A') + idx)}) {c}\n"
|
515 |
+
if c == answer:
|
516 |
+
answer_idx = idx
|
517 |
+
question = question.rstrip()
|
518 |
+
answer = f"({chr(ord('A') + answer_idx)}) {answer}"
|
519 |
+
return question, answer
|
520 |
+
|
521 |
+
def get_index_by_frame(self, max_frame):
|
522 |
+
seg_size = float(max_frame) / self.num_segments
|
523 |
+
frame_indices = np.array([
|
524 |
+
int((seg_size / 2) + np.round(seg_size * idx))
|
525 |
+
for idx in range(self.num_segments)
|
526 |
+
])
|
527 |
+
return frame_indices
|
528 |
+
|
529 |
+
def get_index_by_fps(self, vid, fps):
|
530 |
+
total_frames = len(vid)
|
531 |
+
video_fps = vid.get_avg_fps()
|
532 |
+
total_duration = total_frames / video_fps
|
533 |
+
required_frames = int(total_duration * fps)
|
534 |
+
step_size = video_fps / fps
|
535 |
+
frame_indices = np.array([int(i * step_size) for i in range(required_frames)])
|
536 |
+
self.num_segments = len(frame_indices)
|
537 |
+
return frame_indices
|
538 |
+
|
539 |
+
def read_video(self, video_path, fps=-1):
|
540 |
+
vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
|
541 |
+
max_frame = len(vr) - 1
|
542 |
+
|
543 |
+
images_group = list()
|
544 |
+
if fps < 0:
|
545 |
+
frame_indices = self.get_index_by_frame(max_frame)
|
546 |
+
else:
|
547 |
+
frame_indices = self.get_index_by_fps(vr, fps)
|
548 |
+
|
549 |
+
for frame_index in frame_indices:
|
550 |
+
img = Image.fromarray(vr[frame_index].asnumpy())
|
551 |
+
images_group.append(img)
|
552 |
+
torch_imgs = self.transform(images_group)
|
553 |
+
return torch_imgs
|
554 |
+
|
555 |
+
def save_video_frames(self, imgs, video_name, frames, fps):
|
556 |
+
if fps > 0:
|
557 |
+
frame_paths = self.frame_paths_fps(video_name, frames, fps)
|
558 |
+
else:
|
559 |
+
frame_paths = self.frame_paths(video_name, frames)
|
560 |
+
flag = np.all([osp.exists(p) for p in frame_paths])
|
561 |
+
|
562 |
+
if not flag:
|
563 |
+
block_size = imgs.size(0) // frames
|
564 |
+
split_tensors = torch.split(imgs, block_size)
|
565 |
+
to_pil = transforms.ToPILImage()
|
566 |
+
images = [to_pil(arr) for arr in split_tensors]
|
567 |
+
for im, pth in zip(images, frame_paths):
|
568 |
+
if not osp.exists(pth):
|
569 |
+
im.save(pth)
|
570 |
+
|
571 |
+
return frame_paths
|
572 |
+
|
573 |
+
def save_video_into_images(self, line, num_frames, fps=-1):
|
574 |
+
video_path = os.path.join(self.data_root, line['prefix'], line['video'])
|
575 |
+
if fps <= 0:
|
576 |
+
self.num_segments = num_frames if num_frames > 0 else self.nframe
|
577 |
+
else:
|
578 |
+
self.num_segments = 0
|
579 |
+
torch_imgs = self.read_video(video_path, fps)
|
580 |
+
img_frame_paths = self.save_video_frames(torch_imgs, line['video'], self.num_segments, fps)
|
581 |
+
return img_frame_paths
|
582 |
+
|
583 |
+
def build_prompt(self, line, num_frames, video_llm, fps):
|
584 |
+
if isinstance(line, int):
|
585 |
+
assert line < len(self)
|
586 |
+
line = self.data.iloc[line]
|
587 |
+
|
588 |
+
question, answer = self.qa_template(line)
|
589 |
+
message = [dict(type='text', value=self.SYS, role='system')]
|
590 |
+
message.append(dict(type='text', value=question))
|
591 |
+
video_path = os.path.join(self.data_root, line['prefix'], line['video'])
|
592 |
+
if video_llm:
|
593 |
+
message.append(dict(type='video', value=video_path))
|
594 |
+
else:
|
595 |
+
img_frame_paths = self.save_video_into_images(line, num_frames, fps)
|
596 |
+
for im in img_frame_paths:
|
597 |
+
message.append(dict(type='image', value=im))
|
598 |
+
message.append(dict(type='text', value='\nOnly give the best option.'))
|
599 |
+
message.append(dict(type='text', value='Best option:(', role='assistant'))
|
600 |
+
return message
|
601 |
+
|
602 |
+
@classmethod
|
603 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
604 |
+
|
605 |
+
assert eval_file.endswith('.xlsx'), 'data file should be an xlsx file'
|
606 |
+
|
607 |
+
tmp_file = eval_file.replace('.xlsx', '_tmp.pkl')
|
608 |
+
tgt_file = eval_file.replace('.xlsx', '_rating.json')
|
609 |
+
score_file = eval_file.replace('.xlsx', '_score.xlsx')
|
610 |
+
|
611 |
+
if not osp.exists(score_file):
|
612 |
+
model = judge_kwargs.setdefault('model', 'chatgpt-0125')
|
613 |
+
assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']
|
614 |
+
|
615 |
+
if model == 'exact_matching':
|
616 |
+
model = None
|
617 |
+
elif gpt_key_set():
|
618 |
+
model = build_judge(**judge_kwargs)
|
619 |
+
if not model.working():
|
620 |
+
warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
|
621 |
+
warnings.warn(DEBUG_MESSAGE)
|
622 |
+
model = None
|
623 |
+
else:
|
624 |
+
warnings.warn('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
|
625 |
+
model = None
|
626 |
+
res = {} if not osp.exists(tmp_file) else load(tmp_file)
|
627 |
+
res = {k: v for k, v in res.items() if FAIL_MSG not in v}
|
628 |
+
|
629 |
+
data = load(eval_file)
|
630 |
+
data_un = data[~pd.isna(data['prediction'])]
|
631 |
+
|
632 |
+
for idx in data_un['index']:
|
633 |
+
ans = data.loc[data['index'] == idx, 'answer'].values[0]
|
634 |
+
pred = data.loc[data['index'] == idx, 'prediction'].values[0]
|
635 |
+
options = eval(data.loc[data['index'] == idx, 'candidates'].values[0])
|
636 |
+
answer_idx = -1
|
637 |
+
for id, c in enumerate(options):
|
638 |
+
if c == ans:
|
639 |
+
answer_idx = id
|
640 |
+
ans = f"({chr(ord('A') + answer_idx)}) {ans}"
|
641 |
+
input_item = data.loc[data['index'] == idx].to_dict(orient='records')[0]
|
642 |
+
for id, option_content in enumerate(eval(input_item['candidates'])):
|
643 |
+
input_item[chr(ord('A') + id)] = option_content
|
644 |
+
if option_content == input_item['answer']:
|
645 |
+
input_item['answer'] = chr(ord('A') + id)
|
646 |
+
|
647 |
+
if FAIL_MSG in pred:
|
648 |
+
data.loc[idx, 'score'] = -1
|
649 |
+
else:
|
650 |
+
data.loc[idx, 'score'] = int(check_ans_with_model(
|
651 |
+
pred, ans, model,
|
652 |
+
input_item,
|
653 |
+
'MVBench_MP4'
|
654 |
+
))
|
655 |
+
|
656 |
+
rejected = [x for x in data['score'] if x == -1]
|
657 |
+
|
658 |
+
print(
|
659 |
+
f'Among {len(data)} questions, failed to obtain prediction for {len(data) - len(data_un)} questions, '
|
660 |
+
f'failed to obtain the score for another {len(rejected)} questions. '
|
661 |
+
f'Those questions will be counted as -1 score in ALL rating, and will not be counted in VALID rating.'
|
662 |
+
)
|
663 |
+
|
664 |
+
dump(data, score_file)
|
665 |
+
|
666 |
+
rating = get_dimension_rating(score_file)
|
667 |
+
dump(rating, tgt_file)
|
668 |
+
return rating
|
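
The frame sampler above (MVBench.get_index) splits the clip, or the bounded [start, end] segment, into num_segments equal chunks and takes the middle frame of each chunk. A small standalone sketch of the same rule (not part of the diff):

import numpy as np

def sample_frame_indices(start_idx, end_idx, num_segments):
    # Middle frame of each of the num_segments equal chunks in [start_idx, end_idx].
    seg_size = float(end_idx - start_idx) / num_segments
    return np.array([
        int(start_idx + seg_size / 2 + np.round(seg_size * idx))
        for idx in range(num_segments)
    ])

print(sample_frame_indices(0, 90, 8))  # 8 frame indices spread roughly uniformly over frames 0..89
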
VLMEvalKit/vlmeval/dataset/slidevqa.py
ADDED
@@ -0,0 +1,189 @@
1 |
+
import re
|
2 |
+
import math
|
3 |
+
from typing import List
|
4 |
+
|
5 |
+
from vlmeval.dataset.utils.judge_util import build_judge
|
6 |
+
from vlmeval.smp import *
|
7 |
+
from .image_base import ImageBaseDataset
|
8 |
+
from .mmlongbench import concat_images, MMLongBench_auxeval, anls_compute
|
9 |
+
|
10 |
+
|
11 |
+
FAIL_MSG = 'Failed to obtain answer via API.'
|
12 |
+
|
13 |
+
|
14 |
+
def get_f1(gt, pred):
|
15 |
+
gt_bow, pred_bow = gt.strip().split(), pred.strip().split()
|
16 |
+
if not gt_bow or not pred_bow:
|
17 |
+
return 0.0
|
18 |
+
|
19 |
+
recall = len([pred_e for pred_e in pred_bow if pred_e in gt_bow]) / len(gt_bow)
|
20 |
+
precision = len([pred_e for pred_e in pred_bow if pred_e in gt_bow]) / len(pred_bow)
|
21 |
+
f1 = 2 * recall * precision / (recall + precision) if (recall + precision) > 1e-4 else 0.0
|
22 |
+
return f1
|
23 |
+
|
24 |
+
|
25 |
+
def SlideVQA_acc(result_file):
|
26 |
+
data = load(result_file)
|
27 |
+
anls_list, em_list, f1_list = list(), list(), list()
|
28 |
+
for i in range(len(data)):
|
29 |
+
item = data.iloc[i]
|
30 |
+
if isinstance(item['answer'], float) and math.isnan(item['answer']):
|
31 |
+
item['answer'] = 'Not answerable'
|
32 |
+
|
33 |
+
item['answer'] = re.sub('\n', '', item['answer']).lower()
|
34 |
+
item['pred'] = str(item['pred']).lower()
|
35 |
+
anls_score = anls_compute(item['answer'], item['pred'])
|
36 |
+
em_score = (item['answer'].strip() == item['pred'].strip())
|
37 |
+
f1_score = get_f1(item['answer'], item['pred'])
|
38 |
+
anls_list.append(anls_score)
|
39 |
+
em_list.append(em_score)
|
40 |
+
f1_list.append(f1_score)
|
41 |
+
print('---------------------')
|
42 |
+
print(item['answer'], item['pred'], anls_score, em_score, f1_score)
|
43 |
+
|
44 |
+
data['anls'] = anls_list
|
45 |
+
data['em'] = em_list
|
46 |
+
data['f1'] = f1_list
|
47 |
+
dump(data, result_file)
|
48 |
+
|
49 |
+
res = dict()
|
50 |
+
res['category'], res['num'] = ['anls', 'EM', 'F1'], [len(data), len(data), len(data)]
|
51 |
+
res['avg'] = [sum(anls_list) / len(data), sum(em_list) / len(data), sum(f1_list) / len(data)]
|
52 |
+
res = pd.DataFrame(res)
|
53 |
+
return res
|
54 |
+
|
55 |
+
|
56 |
+
class SlideVQA(ImageBaseDataset):
|
57 |
+
|
58 |
+
TYPE = 'VQA'
|
59 |
+
|
60 |
+
DATASET_URL = {
|
61 |
+
'SLIDEVQA_MINI': 'https://opencompass.openxlab.space/utils/VLMEval/SLIDEVQA_MINI.tsv',
|
62 |
+
'SLIDEVQA': 'https://opencompass.openxlab.space/utils/VLMEval/SLIDEVQA.tsv',
|
63 |
+
}
|
64 |
+
DATASET_MD5 = {
|
65 |
+
'SLIDEVQA_MINI': '6d9a8d8814fa5b7669deb2af3a3208eb',
|
66 |
+
'SLIDEVQA': '5e822c2f800e94c1e23badfd478326b6',
|
67 |
+
}
|
68 |
+
|
69 |
+
SUPPORTED_MODELS = {
|
70 |
+
'GPT4': (1, 1),
|
71 |
+
'GPT4V': (1, 1),
|
72 |
+
'GPT4V_HIGH': (1, 1),
|
73 |
+
'GPT4o': (1, 1),
|
74 |
+
'GPT4o_HIGH': (1, 1),
|
75 |
+
'GPT4o_MINI': (1, 1),
|
76 |
+
'XComposer2d5': (1, -1),
|
77 |
+
'XComposer2_4KHD': (1, -1),
|
78 |
+
'MiniCPM-Llama3-V-2_5': (1, 5),
|
79 |
+
'InternVL-Chat-V1-5': (5, 2),
|
80 |
+
}
|
81 |
+
|
82 |
+
def __init__(self, dataset, **kwargs):
|
83 |
+
self.model_list = list(self.SUPPORTED_MODELS.keys())
|
84 |
+
model_name = kwargs['model']
|
85 |
+
if not listinstr(self.model_list, model_name):
|
86 |
+
raise AssertionError("{} doesn't support the evaluation on SlideVQA.".format(model_name))
|
87 |
+
super(SlideVQA, self).__init__(dataset)
|
88 |
+
|
89 |
+
self.is_api = True if listinstr(['GPT4'], model_name) else False
|
90 |
+
self.max_pages = 120
|
91 |
+
concat_num, column_num = self.SUPPORTED_MODELS.get(model_name)
|
92 |
+
self.concat_num = concat_num
|
93 |
+
self.column_num = column_num
|
94 |
+
|
95 |
+
def dump_image(self, origin_line):
|
96 |
+
os.makedirs(self.img_root, exist_ok=True)
|
97 |
+
|
98 |
+
line = origin_line.copy()
|
99 |
+
if not isinstance(line['image_path'], List):
|
100 |
+
line['image_path'] = [line['image_path']]
|
101 |
+
line['image_path'] = line['image_path'][:self.max_pages]
|
102 |
+
|
103 |
+
if 'image' in line:
|
104 |
+
if isinstance(line['image'], list):
|
105 |
+
tgt_path = []
|
106 |
+
assert 'image_path' in line
|
107 |
+
for img, im_name in zip(line['image'], line['image_path']):
|
108 |
+
path = osp.join(self.img_root, im_name)
|
109 |
+
if not read_ok(path):
|
110 |
+
decode_base64_to_image_file(img, path)
|
111 |
+
tgt_path.append(path)
|
112 |
+
else:
|
113 |
+
tgt_path = osp.join(self.img_root, f"{line['index']}.jpg")
|
114 |
+
if not read_ok(tgt_path):
|
115 |
+
decode_base64_to_image_file(line['image'], tgt_path)
|
116 |
+
tgt_path = [tgt_path]
|
117 |
+
else:
|
118 |
+
assert 'image_path' in line
|
119 |
+
tgt_path = toliststr(line['image_path'])
|
120 |
+
|
121 |
+
if self.concat_num > 0 and not self.is_api:
|
122 |
+
concatenated_images = concat_images(tgt_path, max_concat=self.concat_num, column_num=self.column_num)
|
123 |
+
|
124 |
+
old_tgt_path = tgt_path
|
125 |
+
assert isinstance(old_tgt_path, list)
|
126 |
+
if self.column_num != -1:
|
127 |
+
tgt_path = [
|
128 |
+
'_'.join(old_tgt_path[0].split('_')[:-1]) + '_concat{}_{}.jpg'.format(self.concat_num, i)
|
129 |
+
for i in range(len(concatenated_images))
|
130 |
+
]
|
131 |
+
else:
|
132 |
+
tgt_path = ['_'.join(old_tgt_path[0].split('_')[:-1]) + '_concat_all.jpg']
|
133 |
+
|
134 |
+
for path, concatenated_image in zip(tgt_path, concatenated_images):
|
135 |
+
if not read_ok(path):
|
136 |
+
decode_base64_to_image_file(encode_image_to_base64(concatenated_image), path)
|
137 |
+
num_images, image_size = len(old_tgt_path), concatenated_image.size
|
138 |
+
print('concat {} images to a new one with size {}. save at {}'.format(num_images, image_size, path))
|
139 |
+
return tgt_path
|
140 |
+
|
141 |
+
@classmethod
|
142 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
143 |
+
logger = get_logger('Evaluation')
|
144 |
+
model = judge_kwargs['model']
|
145 |
+
|
146 |
+
suffix = eval_file.split('.')[-1]
|
147 |
+
storage = eval_file.replace(f'.{suffix}', f'_{model}.xlsx')
|
148 |
+
tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
|
149 |
+
|
150 |
+
if osp.exists(storage):
|
151 |
+
logger.warning(f'GPT scoring file {storage} already exists, will reuse it in SlideVQA_eval. ')
|
152 |
+
else:
|
153 |
+
data = load(eval_file)
|
154 |
+
model = build_judge(max_tokens=128, **judge_kwargs)
|
155 |
+
lt = len(data)
|
156 |
+
lines = [data.iloc[i] for i in range(lt)]
|
157 |
+
tups = [(model, line) for line in lines]
|
158 |
+
indices = [line['index'] for line in lines]
|
159 |
+
|
160 |
+
ans = {}
|
161 |
+
if osp.exists(tmp_file):
|
162 |
+
ans = load(tmp_file)
|
163 |
+
tups = [x for x, i in zip(tups, indices) if i not in ans]
|
164 |
+
indices = [i for i in indices if i not in ans]
|
165 |
+
|
166 |
+
if len(indices):
|
167 |
+
new_results = list()
|
168 |
+
for model, line in tqdm(tups):
|
169 |
+
res = MMLongBench_auxeval(model, line)
|
170 |
+
new_results.append(res)
|
171 |
+
|
172 |
+
log_map, res_map, pred_map = {}, {}, {}
|
173 |
+
all_inds = [line['index'] for line in lines]
|
174 |
+
for k, v in zip(all_inds, new_results):
|
175 |
+
log_map[k] = v['log']
|
176 |
+
res_map[k] = v['res']
|
177 |
+
pred_map[k] = v['pred']
|
178 |
+
data['res'] = [res_map[idx] for idx in data['index']]
|
179 |
+
data['log'] = [log_map[idx] for idx in data['index']]
|
180 |
+
data['pred'] = [pred_map[idx] for idx in data['index']]
|
181 |
+
dump(data, storage)
|
182 |
+
|
183 |
+
score = SlideVQA_acc(storage)
|
184 |
+
score_pth = storage.replace('.xlsx', '_score.csv')
|
185 |
+
|
186 |
+
dump(score, score_pth)
|
187 |
+
logger.info(f'SlideVQA successfully finished evaluating {eval_file}, results saved in {score_pth}')
|
188 |
+
logger.info('Score: ')
|
189 |
+
logger.info(score)
|
VLMEvalKit/vlmeval/dataset/tempcompass.py
ADDED
@@ -0,0 +1,639 @@
1 |
+
import huggingface_hub
|
2 |
+
from huggingface_hub import snapshot_download
|
3 |
+
from ..smp import *
|
4 |
+
from .video_concat_dataset import ConcatVideoDataset
|
5 |
+
from .video_base import VideoBaseDataset
|
6 |
+
from .utils import build_judge, DEBUG_MESSAGE
|
7 |
+
from ..utils import track_progress_rich
|
8 |
+
import torchvision.transforms as T
|
9 |
+
from torchvision import transforms
|
10 |
+
from torchvision.transforms.functional import InterpolationMode
|
11 |
+
from decord import VideoReader, cpu
|
12 |
+
from .utils.tempcompass import *
|
13 |
+
|
14 |
+
|
15 |
+
FAIL_MSG = 'Failed to obtain answer via API.'
|
16 |
+
|
17 |
+
|
18 |
+
class TempCompass(ConcatVideoDataset):
|
19 |
+
def __init__(self, dataset='TempCompass'):
|
20 |
+
self.DATASET_SETS[dataset] = ['TempCompass_MCQ', 'TempCompass_Captioning', 'TempCompass_YorN']
|
21 |
+
super().__init__(dataset=dataset)
|
22 |
+
|
23 |
+
@classmethod
|
24 |
+
def supported_datasets(cls):
|
25 |
+
return ['TempCompass']
|
26 |
+
|
27 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
28 |
+
result = super().evaluate(eval_file=eval_file, **judge_kwargs)
|
29 |
+
suffix = eval_file.split('.')[-1]
|
30 |
+
result = result.reset_index().rename(columns={'index': 'dim.task_type'})
|
31 |
+
score_file = eval_file.replace(f'.{suffix}', '_acc.csv')
|
32 |
+
avg_dict = {}
|
33 |
+
for idx, item in result.iterrows():
|
34 |
+
dim, task_type = item['dim.task_type'].split('. ')
|
35 |
+
if dim not in avg_dict:
|
36 |
+
avg_dict[dim] = {'success': 0.0, 'overall': 0.0}
|
37 |
+
if task_type not in avg_dict:
|
38 |
+
avg_dict[task_type] = {'success': 0.0, 'overall': 0.0}
|
39 |
+
if 'overall' not in avg_dict:
|
40 |
+
avg_dict['overall'] = {'success': 0.0, 'overall': 0.0}
|
41 |
+
avg_dict[dim]['success'] += item['success']
|
42 |
+
avg_dict[dim]['overall'] += item['overall']
|
43 |
+
avg_dict[task_type]['success'] += item['success']
|
44 |
+
avg_dict[task_type]['overall'] += item['overall']
|
45 |
+
avg_dict['overall']['success'] += item['success']
|
46 |
+
avg_dict['overall']['overall'] += item['overall']
|
47 |
+
result.loc[idx, 'acc'] = round(item['success'] / item['overall'] * 100, 2)
|
48 |
+
for key, value in avg_dict.items():
|
49 |
+
# Append a new row using loc
|
50 |
+
result.loc[len(result)] = {
|
51 |
+
'dim.task_type': key,
|
52 |
+
'success': value['success'],
|
53 |
+
'overall': value['overall'],
|
54 |
+
'acc': round(value['success'] / value['overall'] * 100, 2)
|
55 |
+
}
|
56 |
+
dump(result, score_file)
|
57 |
+
return result
|
58 |
+
|
59 |
+
|
60 |
+
class TempCompass_MCQ(VideoBaseDataset):
|
61 |
+
|
62 |
+
MD5 = '7efbb9e6d9dabacd22daf274852691dd'
|
63 |
+
TYPE = 'Video-MCQ'
|
64 |
+
|
65 |
+
def __init__(self, dataset='TempCompass_MCQ'):
|
66 |
+
self.type_data_list = {
|
67 |
+
'multi-choice': ('multi-choice.json', './videos', '.mp4'),
|
68 |
+
'caption_matching': ('caption_matching.json', './videos', '.mp4'),
|
69 |
+
}
|
70 |
+
super().__init__(dataset=dataset)
|
71 |
+
|
72 |
+
@classmethod
|
73 |
+
def supported_datasets(cls):
|
74 |
+
return ['TempCompass_MCQ']
|
75 |
+
|
76 |
+
def prepare_dataset(self, dataset_name='TempCompass_MCQ', repo_id='lmms-lab/TempCompass'):
|
77 |
+
def check_integrity(pth):
|
78 |
+
data_file = osp.join(pth, f'{dataset_name}.tsv')
|
79 |
+
|
80 |
+
if not osp.exists(data_file):
|
81 |
+
return False
|
82 |
+
|
83 |
+
if md5(data_file) != self.MD5:
|
84 |
+
return False
|
85 |
+
|
86 |
+
data = load(data_file)
|
87 |
+
for idx, item in data.iterrows():
|
88 |
+
if not osp.exists(osp.join(pth, item['prefix'], item['video'] + item['suffix'])):
|
89 |
+
return False
|
90 |
+
return True
|
91 |
+
|
92 |
+
cache_path = get_cache_path(repo_id)
|
93 |
+
if cache_path is not None and check_integrity(cache_path):
|
94 |
+
dataset_path = cache_path
|
95 |
+
else:
|
96 |
+
def read_parquet(pth):
|
97 |
+
import pandas as pd
|
98 |
+
for task_name in self.type_data_list.keys():
|
99 |
+
if not osp.exists(osp.join(pth, f'{task_name}.json')):
|
100 |
+
data = pd.read_parquet(osp.join(pth, task_name, 'test-00000-of-00001.parquet'))
|
101 |
+
data.to_json(osp.join(pth, f'{task_name}.json'), orient='records', lines=False)
|
102 |
+
|
103 |
+
def unzip_videos(pth):
|
104 |
+
import zipfile
|
105 |
+
if not osp.exists(osp.join(pth, 'videos')):
|
106 |
+
zip_file = osp.join(pth, 'tempcompass_videos.zip')
|
107 |
+
with zipfile.ZipFile(zip_file, 'r') as zip_ref:
|
108 |
+
zip_ref.extractall(pth)
|
109 |
+
|
110 |
+
def generate_tsv(pth):
|
111 |
+
data_file = osp.join(pth, f'{dataset_name}.tsv')
|
112 |
+
if osp.exists(data_file) and md5(data_file) == self.MD5:
|
113 |
+
return
|
114 |
+
self.data_list = []
|
115 |
+
for k, v in self.type_data_list.items():
|
116 |
+
with open(osp.join(pth, v[0]), 'r') as f:
|
117 |
+
json_data = json.load(f)
|
118 |
+
for data in json_data:
|
119 |
+
self.data_list.append({
|
120 |
+
'task_type': k,
|
121 |
+
'prefix': v[1],
|
122 |
+
'suffix': v[2],
|
123 |
+
'video': data['video_id'],
|
124 |
+
'question': data['question'].split('\n')[0],
|
125 |
+
'answer': data['answer'],
|
126 |
+
'dim': data['dim'],
|
127 |
+
'candidates': data['question'].split('\n')[1:],
|
128 |
+
})
|
129 |
+
|
130 |
+
data_df = pd.DataFrame(self.data_list)
|
131 |
+
data_df = data_df.assign(index=range(len(data_df)))
|
132 |
+
data_df.to_csv(data_file, sep='\t', index=False)
|
133 |
+
|
134 |
+
if modelscope_flag_set():
|
135 |
+
from modelscope import dataset_snapshot_download
|
136 |
+
dataset_path = dataset_snapshot_download(dataset_id=repo_id)
|
137 |
+
else:
|
138 |
+
dataset_path = snapshot_download(repo_id=repo_id, repo_type='dataset')
|
139 |
+
read_parquet(dataset_path)
|
140 |
+
unzip_videos(dataset_path)
|
141 |
+
generate_tsv(dataset_path)
|
142 |
+
|
143 |
+
data_file = osp.join(dataset_path, f'{dataset_name}.tsv')
|
144 |
+
return dict(root=dataset_path, data_file=data_file)
|
145 |
+
|
146 |
+
def qa_template(self, data):
|
147 |
+
question = data['question'] + '\n' + '\n'.join(eval(data['candidates']))
|
148 |
+
answer = data['answer']
|
149 |
+
return question, answer
|
150 |
+
|
151 |
+
def save_video_frames(self, line, num_frames=8, fps=-1):
|
152 |
+
vid_path = osp.join(self.data_root, line['prefix'], line['video'] + line['suffix'])
|
153 |
+
vid = decord.VideoReader(vid_path)
|
154 |
+
video_info = {
|
155 |
+
'fps': vid.get_avg_fps(),
|
156 |
+
'n_frames': len(vid),
|
157 |
+
}
|
158 |
+
if num_frames > 0 and fps < 0:
|
159 |
+
step_size = len(vid) / (num_frames + 1)
|
160 |
+
indices = [int(i * step_size) for i in range(1, num_frames + 1)]
|
161 |
+
frame_paths = self.frame_paths(line['video'], num_frames)
|
162 |
+
elif fps > 0:
|
163 |
+
# not constrained by num_frames, get frames by fps
|
164 |
+
total_duration = video_info['n_frames'] / video_info['fps']
|
165 |
+
required_frames = int(total_duration * fps)
|
166 |
+
step_size = video_info['fps'] / fps
|
167 |
+
indices = [int(i * step_size) for i in range(required_frames)]
|
168 |
+
frame_paths = self.frame_paths_fps(line['video'], len(indices), fps)
|
169 |
+
|
170 |
+
flag = np.all([osp.exists(p) for p in frame_paths])
|
171 |
+
|
172 |
+
if not flag:
|
173 |
+
images = [vid[i].asnumpy() for i in indices]
|
174 |
+
images = [Image.fromarray(arr) for arr in images]
|
175 |
+
for im, pth in zip(images, frame_paths):
|
176 |
+
if not osp.exists(pth):
|
177 |
+
im.save(pth)
|
178 |
+
|
179 |
+
return frame_paths
|
180 |
+
|
181 |
+
def save_video_into_images(self, line, num_frames, fps):
|
182 |
+
frame_paths = self.save_video_frames(line, num_frames, fps)
|
183 |
+
return frame_paths
|
184 |
+
|
185 |
+
def build_prompt(self, line, num_frames, video_llm, fps=-1):
|
186 |
+
if isinstance(line, int):
|
187 |
+
assert line < len(self)
|
188 |
+
line = self.data.iloc[line]
|
189 |
+
|
190 |
+
question, answer = self.qa_template(line)
|
191 |
+
message = []
|
192 |
+
message.append(dict(type='text', value=question))
|
193 |
+
video_path = osp.join(self.data_root, line['prefix'], line['video'] + line['suffix'])
|
194 |
+
if video_llm:
|
195 |
+
message.append(dict(type='video', value=video_path))
|
196 |
+
else:
|
197 |
+
img_frame_paths = self.save_video_into_images(line, num_frames, fps)
|
198 |
+
for im in img_frame_paths:
|
199 |
+
message.append(dict(type='image', value=im))
|
200 |
+
message.append(dict(type='text', value='\nPlease directly give the best option:'))
|
201 |
+
return message
|
202 |
+
|
203 |
+
@classmethod
|
204 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
205 |
+
model = judge_kwargs.get('model', 'exact_matching')
|
206 |
+
assert model in ['chatgpt-1106', 'exact_matching']
|
207 |
+
judge_kwargs.update({
|
208 |
+
"max_tokens": 128,
|
209 |
+
"temperature": 1.0,
|
210 |
+
"top_p": 1,
|
211 |
+
"presence_penalty": 1,
|
212 |
+
})
|
213 |
+
|
214 |
+
suffix = eval_file.split('.')[-1]
|
215 |
+
score_file = eval_file.replace(f'.{suffix}', f'_{model}_score.xlsx')
|
216 |
+
tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
|
217 |
+
nproc = judge_kwargs.pop('nproc', 4)
|
218 |
+
|
219 |
+
if not osp.exists(score_file):
|
220 |
+
data = load(eval_file)
|
221 |
+
if model != 'exact_matching':
|
222 |
+
model = build_judge(system_prompt=sys_prompt, **judge_kwargs)
|
223 |
+
else:
|
224 |
+
model = None
|
225 |
+
|
226 |
+
lt = len(data)
|
227 |
+
lines = [data.iloc[i] for i in range(lt)]
|
228 |
+
tups = [(model, line) for line in lines]
|
229 |
+
indices = [line['index'] for line in lines]
|
230 |
+
|
231 |
+
ans = {}
|
232 |
+
if osp.exists(tmp_file):
|
233 |
+
ans = load(tmp_file)
|
234 |
+
tups = [x for x, i in zip(tups, indices) if i not in ans]
|
235 |
+
indices = [i for i in indices if i not in ans]
|
236 |
+
|
237 |
+
if len(indices):
|
238 |
+
_ = track_progress_rich(
|
239 |
+
evaluate_tempcompass_mcq,
|
240 |
+
tups,
|
241 |
+
nproc=nproc,
|
242 |
+
chunksize=nproc,
|
243 |
+
keys=indices,
|
244 |
+
save=tmp_file,
|
245 |
+
)
|
246 |
+
ans = load(tmp_file)
|
247 |
+
for idx, item in data.iterrows():
|
248 |
+
data.loc[idx, 'score'] = ans[idx]['rating']
|
249 |
+
dump(data, score_file)
|
250 |
+
|
251 |
+
rating = get_dimension_rating(score_file)
|
252 |
+
return rating
|
253 |
+
|
254 |
+
|
255 |
+
class TempCompass_Captioning(VideoBaseDataset):
|
256 |
+
|
257 |
+
MD5 = '35be9bf2581ea7767f02e9a8f37ae1ab'
|
258 |
+
TYPE = 'Video-VQA'
|
259 |
+
|
260 |
+
def __init__(self, dataset='TempCompass_Captioning'):
|
261 |
+
self.type_data_list = {
|
262 |
+
'captioning': ('captioning.json', './videos', '.mp4'),
|
263 |
+
}
|
264 |
+
super().__init__(dataset=dataset)
|
265 |
+
|
266 |
+
@classmethod
|
267 |
+
def supported_datasets(cls):
|
268 |
+
return ['TempCompass_Captioning']
|
269 |
+
|
270 |
+
def prepare_dataset(self, dataset_name='TempCompass_Captioning', repo_id='lmms-lab/TempCompass'):
|
271 |
+
def check_integrity(pth):
|
272 |
+
data_file = osp.join(pth, f'{dataset_name}.tsv')
|
273 |
+
|
274 |
+
if not osp.exists(data_file):
|
275 |
+
return False
|
276 |
+
|
277 |
+
if md5(data_file) != self.MD5:
|
278 |
+
return False
|
279 |
+
|
280 |
+
data = load(data_file)
|
281 |
+
for idx, item in data.iterrows():
|
282 |
+
if not osp.exists(osp.join(pth, item['prefix'], item['video'] + item['suffix'])):
|
283 |
+
return False
|
284 |
+
return True
|
285 |
+
|
286 |
+
cache_path = get_cache_path(repo_id)
|
287 |
+
if cache_path is not None and check_integrity(cache_path):
|
288 |
+
dataset_path = cache_path
|
289 |
+
else:
|
290 |
+
def read_parquet(pth):
|
291 |
+
import pandas as pd
|
292 |
+
for task_name in self.type_data_list.keys():
|
293 |
+
if not osp.exists(osp.join(pth, f'{task_name}.json')):
|
294 |
+
data = pd.read_parquet(osp.join(pth, task_name, 'test-00000-of-00001.parquet'))
|
295 |
+
data.to_json(osp.join(pth, f'{task_name}.json'), orient='records', lines=False)
|
296 |
+
|
297 |
+
def unzip_videos(pth):
|
298 |
+
import zipfile
|
299 |
+
if not osp.exists(osp.join(pth, 'videos')):
|
300 |
+
zip_file = osp.join(pth, 'tempcompass_videos.zip')
|
301 |
+
with zipfile.ZipFile(zip_file, 'r') as zip_ref:
|
302 |
+
zip_ref.extractall(pth)
|
303 |
+
|
304 |
+
def generate_tsv(pth):
|
305 |
+
data_file = osp.join(pth, f'{dataset_name}.tsv')
|
306 |
+
if osp.exists(data_file) and md5(data_file) == self.MD5:
|
307 |
+
return
|
308 |
+
self.data_list = []
|
309 |
+
for k, v in self.type_data_list.items():
|
310 |
+
with open(osp.join(pth, v[0]), 'r') as f:
|
311 |
+
json_data = json.load(f)
|
312 |
+
for data in json_data:
|
313 |
+
self.data_list.append({
|
314 |
+
'task_type': k,
|
315 |
+
'prefix': v[1],
|
316 |
+
'suffix': v[2],
|
317 |
+
'video': data['video_id'],
|
318 |
+
'question': data['question'],
|
319 |
+
'answer': data['answer'],
|
320 |
+
'dim': data['dim'],
|
321 |
+
'mc_question': data['mc_question'],
|
322 |
+
'mc_answer': data['mc_answer'],
|
323 |
+
})
|
324 |
+
|
325 |
+
data_df = pd.DataFrame(self.data_list)
|
326 |
+
data_df = data_df.assign(index=range(len(data_df)))
|
327 |
+
data_df.to_csv(data_file, sep='\t', index=False)
|
328 |
+
|
329 |
+
if modelscope_flag_set():
|
330 |
+
from modelscope import dataset_snapshot_download
|
331 |
+
dataset_path = dataset_snapshot_download(dataset_id=repo_id)
|
332 |
+
else:
|
333 |
+
dataset_path = snapshot_download(repo_id=repo_id, repo_type='dataset')
|
334 |
+
read_parquet(dataset_path)
|
335 |
+
unzip_videos(dataset_path)
|
336 |
+
generate_tsv(dataset_path)
|
337 |
+
|
338 |
+
data_file = osp.join(dataset_path, f'{dataset_name}.tsv')
|
339 |
+
return dict(root=dataset_path, data_file=data_file)
|
340 |
+
|
341 |
+
def qa_template(self, data):
|
342 |
+
question = data['question']
|
343 |
+
answer = data['answer']
|
344 |
+
return question, answer
|
345 |
+
|
346 |
+
def save_video_frames(self, line, num_frames=8, fps=-1):
|
347 |
+
vid_path = osp.join(self.data_root, line['prefix'], line['video'] + line['suffix'])
|
348 |
+
vid = decord.VideoReader(vid_path)
|
349 |
+
video_info = {
|
350 |
+
'fps': vid.get_avg_fps(),
|
351 |
+
'n_frames': len(vid),
|
352 |
+
}
|
353 |
+
if num_frames > 0 and fps < 0:
|
354 |
+
step_size = len(vid) / (num_frames + 1)
|
355 |
+
indices = [int(i * step_size) for i in range(1, num_frames + 1)]
|
356 |
+
frame_paths = self.frame_paths(line['video'], num_frames)
|
357 |
+
elif fps > 0:
|
358 |
+
# not constrained by num_frames, get frames by fps
|
359 |
+
total_duration = video_info['n_frames'] / video_info['fps']
|
360 |
+
required_frames = int(total_duration * fps)
|
361 |
+
step_size = video_info['fps'] / fps
|
362 |
+
indices = [int(i * step_size) for i in range(required_frames)]
|
363 |
+
frame_paths = self.frame_paths_fps(line['video'], len(indices), fps)
|
364 |
+
|
365 |
+
flag = np.all([osp.exists(p) for p in frame_paths])
|
366 |
+
|
367 |
+
if not flag:
|
368 |
+
images = [vid[i].asnumpy() for i in indices]
|
369 |
+
images = [Image.fromarray(arr) for arr in images]
|
370 |
+
for im, pth in zip(images, frame_paths):
|
371 |
+
if not osp.exists(pth):
|
372 |
+
im.save(pth)
|
373 |
+
|
374 |
+
return frame_paths
|
375 |
+
|
376 |
+
def save_video_into_images(self, line, num_frames, fps):
|
377 |
+
frame_paths = self.save_video_frames(line, num_frames, fps)
|
378 |
+
return frame_paths
|
379 |
+
|
380 |
+
def build_prompt(self, line, num_frames, video_llm, fps=-1):
|
381 |
+
if isinstance(line, int):
|
382 |
+
assert line < len(self)
|
383 |
+
line = self.data.iloc[line]
|
384 |
+
|
385 |
+
question, answer = self.qa_template(line)
|
386 |
+
message = []
|
387 |
+
message.append(dict(type='text', value=question))
|
388 |
+
video_path = osp.join(self.data_root, line['prefix'], line['video'] + line['suffix'])
|
389 |
+
if video_llm:
|
390 |
+
message.append(dict(type='video', value=video_path))
|
391 |
+
else:
|
392 |
+
img_frame_paths = self.save_video_into_images(line, num_frames, fps)
|
393 |
+
for im in img_frame_paths:
|
394 |
+
message.append(dict(type='image', value=im))
|
395 |
+
return message
|
396 |
+
|
397 |
+
@classmethod
|
398 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
399 |
+
model = judge_kwargs.get('model', 'exact_matching')
|
400 |
+
assert model in ['chatgpt-1106', 'exact_matching']
|
401 |
+
judge_kwargs.update({
|
402 |
+
"max_tokens": 128,
|
403 |
+
"temperature": 1.0,
|
404 |
+
"top_p": 1,
|
405 |
+
"presence_penalty": 1,
|
406 |
+
})
|
407 |
+
|
408 |
+
suffix = eval_file.split('.')[-1]
|
409 |
+
score_file = eval_file.replace(f'.{suffix}', f'_{model}_score.xlsx')
|
410 |
+
tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
|
411 |
+
nproc = judge_kwargs.pop('nproc', 4)
|
412 |
+
|
413 |
+
if not osp.exists(score_file):
|
414 |
+
data = load(eval_file)
|
415 |
+
if model != 'exact_matching':
|
416 |
+
model = build_judge(system_prompt=sys_prompt, **judge_kwargs)
|
417 |
+
else:
|
418 |
+
model = None
|
419 |
+
|
420 |
+
lt = len(data)
|
421 |
+
lines = [data.iloc[i] for i in range(lt)]
|
422 |
+
tups = [(model, line) for line in lines]
|
423 |
+
indices = [line['index'] for line in lines]
|
424 |
+
|
425 |
+
ans = {}
|
426 |
+
if osp.exists(tmp_file):
|
427 |
+
ans = load(tmp_file)
|
428 |
+
tups = [x for x, i in zip(tups, indices) if i not in ans]
|
429 |
+
indices = [i for i in indices if i not in ans]
|
430 |
+
|
431 |
+
if len(indices):
|
432 |
+
_ = track_progress_rich(
|
433 |
+
evaluate_tempcompass_captioning,
|
434 |
+
tups,
|
435 |
+
nproc=nproc,
|
436 |
+
chunksize=nproc,
|
437 |
+
keys=indices,
|
438 |
+
save=tmp_file,
|
439 |
+
)
|
440 |
+
ans = load(tmp_file)
|
441 |
+
for idx, item in data.iterrows():
|
442 |
+
data.loc[idx, 'score'] = ans[idx]['rating']
|
443 |
+
dump(data, score_file)
|
444 |
+
|
445 |
+
rating = get_dimension_rating(score_file)
|
446 |
+
return rating
|
447 |
+
|
448 |
+
|
449 |
+
class TempCompass_YorN(VideoBaseDataset):
|
450 |
+
|
451 |
+
MD5 = 'c72c046d7fa0e82c8cd7462f2e844ea8'
|
452 |
+
TYPE = 'Video-Y/N'
|
453 |
+
|
454 |
+
def __init__(self, dataset='TempCompass_YorN'):
|
455 |
+
self.type_data_list = {
|
456 |
+
'yes_no': ('yes_no.json', './videos', '.mp4'),
|
457 |
+
}
|
458 |
+
super().__init__(dataset=dataset)
|
459 |
+
|
460 |
+
@classmethod
|
461 |
+
def supported_datasets(cls):
|
462 |
+
return ['TempCompass_YorN']
|
463 |
+
|
464 |
+
def prepare_dataset(self, dataset_name='TempCompass_YorN', repo_id='lmms-lab/TempCompass'):
|
465 |
+
def check_integrity(pth):
|
466 |
+
data_file = osp.join(pth, f'{dataset_name}.tsv')
|
467 |
+
|
468 |
+
if not osp.exists(data_file):
|
469 |
+
return False
|
470 |
+
|
471 |
+
if md5(data_file) != self.MD5:
|
472 |
+
return False
|
473 |
+
|
474 |
+
data = load(data_file)
|
475 |
+
for idx, item in data.iterrows():
|
476 |
+
if not osp.exists(osp.join(pth, item['prefix'], item['video'] + item['suffix'])):
|
477 |
+
return False
|
478 |
+
return True
|
479 |
+
|
480 |
+
cache_path = get_cache_path(repo_id)
|
481 |
+
if cache_path is not None and check_integrity(cache_path):
|
482 |
+
dataset_path = cache_path
|
483 |
+
else:
|
484 |
+
def read_parquet(pth):
|
485 |
+
import pandas as pd
|
486 |
+
for task_name in self.type_data_list.keys():
|
487 |
+
if not osp.exists(osp.join(pth, f'{task_name}.json')):
|
488 |
+
data = pd.read_parquet(osp.join(pth, task_name, 'test-00000-of-00001.parquet'))
|
489 |
+
data.to_json(osp.join(pth, f'{task_name}.json'), orient='records', lines=False)
|
490 |
+
|
491 |
+
def unzip_videos(pth):
|
492 |
+
import zipfile
|
493 |
+
if not osp.exists(osp.join(pth, 'videos')):
|
494 |
+
zip_file = osp.join(pth, 'tempcompass_videos.zip')
|
495 |
+
with zipfile.ZipFile(zip_file, 'r') as zip_ref:
|
496 |
+
zip_ref.extractall(pth)
|
497 |
+
|
498 |
+
def generate_tsv(pth):
|
499 |
+
data_file = osp.join(pth, f'{dataset_name}.tsv')
|
500 |
+
if osp.exists(data_file) and md5(data_file) == self.MD5:
|
501 |
+
return
|
502 |
+
self.data_list = []
|
503 |
+
for k, v in self.type_data_list.items():
|
504 |
+
with open(osp.join(pth, v[0]), 'r') as f:
|
505 |
+
json_data = json.load(f)
|
506 |
+
for data in json_data:
|
507 |
+
self.data_list.append({
|
508 |
+
'task_type': k,
|
509 |
+
'prefix': v[1],
|
510 |
+
'suffix': v[2],
|
511 |
+
'video': data['video_id'],
|
512 |
+
'question': data['question'].split('\n')[0],
|
513 |
+
'answer': data['answer'],
|
514 |
+
'dim': data['dim']
|
515 |
+
})
|
516 |
+
|
517 |
+
data_df = pd.DataFrame(self.data_list)
|
518 |
+
data_df = data_df.assign(index=range(len(data_df)))
|
519 |
+
data_df.to_csv(data_file, sep='\t', index=False)
|
520 |
+
|
521 |
+
if modelscope_flag_set():
|
522 |
+
from modelscope import dataset_snapshot_download
|
523 |
+
dataset_path = dataset_snapshot_download(dataset_id=repo_id)
|
524 |
+
else:
|
525 |
+
dataset_path = snapshot_download(repo_id=repo_id, repo_type='dataset')
|
526 |
+
read_parquet(dataset_path)
|
527 |
+
unzip_videos(dataset_path)
|
528 |
+
generate_tsv(dataset_path)
|
529 |
+
|
530 |
+
data_file = osp.join(dataset_path, f'{dataset_name}.tsv')
|
531 |
+
return dict(root=dataset_path, data_file=data_file)
|
532 |
+
|
533 |
+
def qa_template(self, data):
|
534 |
+
question = data['question']
|
535 |
+
answer = data['answer']
|
536 |
+
return question, answer
|
537 |
+
|
538 |
+
def save_video_frames(self, line, num_frames=8, fps=-1):
|
539 |
+
vid_path = osp.join(self.data_root, line['prefix'], line['video'] + line['suffix'])
|
540 |
+
vid = decord.VideoReader(vid_path)
|
541 |
+
video_info = {
|
542 |
+
'fps': vid.get_avg_fps(),
|
543 |
+
'n_frames': len(vid),
|
544 |
+
}
|
545 |
+
if num_frames > 0 and fps < 0:
|
546 |
+
step_size = len(vid) / (num_frames + 1)
|
547 |
+
indices = [int(i * step_size) for i in range(1, num_frames + 1)]
|
548 |
+
frame_paths = self.frame_paths(line['video'], num_frames)
|
549 |
+
elif fps > 0:
|
550 |
+
# not constrained by num_frames, get frames by fps
|
551 |
+
total_duration = video_info['n_frames'] / video_info['fps']
|
552 |
+
required_frames = int(total_duration * fps)
|
553 |
+
step_size = video_info['fps'] / fps
|
554 |
+
indices = [int(i * step_size) for i in range(required_frames)]
|
555 |
+
frame_paths = self.frame_paths_fps(line['video'], len(indices), fps)
|
556 |
+
|
557 |
+
flag = np.all([osp.exists(p) for p in frame_paths])
|
558 |
+
|
559 |
+
if not flag:
|
560 |
+
images = [vid[i].asnumpy() for i in indices]
|
561 |
+
images = [Image.fromarray(arr) for arr in images]
|
562 |
+
for im, pth in zip(images, frame_paths):
|
563 |
+
if not osp.exists(pth):
|
564 |
+
im.save(pth)
|
565 |
+
|
566 |
+
return frame_paths
|
567 |
+
|
568 |
+
def save_video_into_images(self, line, num_frames, fps):
|
569 |
+
frame_paths = self.save_video_frames(line, num_frames, fps)
|
570 |
+
return frame_paths
|
571 |
+
|
572 |
+
def build_prompt(self, line, num_frames, video_llm, fps=-1):
|
573 |
+
if isinstance(line, int):
|
574 |
+
assert line < len(self)
|
575 |
+
line = self.data.iloc[line]
|
576 |
+
|
577 |
+
question, answer = self.qa_template(line)
|
578 |
+
message = []
|
579 |
+
message.append(dict(type='text', value=question))
|
580 |
+
video_path = osp.join(self.data_root, line['prefix'], line['video'] + line['suffix'])
|
581 |
+
if video_llm:
|
582 |
+
message.append(dict(type='video', value=video_path))
|
583 |
+
else:
|
584 |
+
img_frame_paths = self.save_video_into_images(line, num_frames, fps)
|
585 |
+
for im in img_frame_paths:
|
586 |
+
message.append(dict(type='image', value=im))
|
587 |
+
message.append(dict(type='text', value='\nPlease answer yes or no:'))
|
588 |
+
return message
|
589 |
+
|
590 |
+
@classmethod
|
591 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
592 |
+
model = judge_kwargs.get('model', 'exact_matching')
|
593 |
+
assert model in ['chatgpt-1106', 'exact_matching']
|
594 |
+
judge_kwargs.update({
|
595 |
+
"max_tokens": 128,
|
596 |
+
"temperature": 1.0,
|
597 |
+
"top_p": 1,
|
598 |
+
"presence_penalty": 1,
|
599 |
+
})
|
600 |
+
|
601 |
+
suffix = eval_file.split('.')[-1]
|
602 |
+
score_file = eval_file.replace(f'.{suffix}', f'_{model}_score.xlsx')
|
603 |
+
tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
|
604 |
+
nproc = judge_kwargs.pop('nproc', 4)
|
605 |
+
|
606 |
+
if not osp.exists(score_file):
|
607 |
+
data = load(eval_file)
|
608 |
+
if model != 'exact_matching':
|
609 |
+
model = build_judge(system_prompt=sys_prompt, **judge_kwargs)
|
610 |
+
else:
|
611 |
+
model = None
|
612 |
+
|
613 |
+
lt = len(data)
|
614 |
+
lines = [data.iloc[i] for i in range(lt)]
|
615 |
+
tups = [(model, line) for line in lines]
|
616 |
+
indices = [line['index'] for line in lines]
|
617 |
+
|
618 |
+
ans = {}
|
619 |
+
if osp.exists(tmp_file):
|
620 |
+
ans = load(tmp_file)
|
621 |
+
tups = [x for x, i in zip(tups, indices) if i not in ans]
|
622 |
+
indices = [i for i in indices if i not in ans]
|
623 |
+
|
624 |
+
if len(indices):
|
625 |
+
_ = track_progress_rich(
|
626 |
+
evaluate_tempcompass_YorN,
|
627 |
+
tups,
|
628 |
+
nproc=nproc,
|
629 |
+
chunksize=nproc,
|
630 |
+
keys=indices,
|
631 |
+
save=tmp_file,
|
632 |
+
)
|
633 |
+
ans = load(tmp_file)
|
634 |
+
for idx, item in data.iterrows():
|
635 |
+
data.loc[idx, 'score'] = ans[idx]['rating']
|
636 |
+
dump(data, score_file)
|
637 |
+
|
638 |
+
rating = get_dimension_rating(score_file)
|
639 |
+
return rating
|
VLMEvalKit/vlmeval/dataset/text_base.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import abstractmethod
|
2 |
+
from ..smp import *
|
3 |
+
|
4 |
+
|
5 |
+
class TextBaseDataset:
|
6 |
+
MODALITY = 'TEXT'
|
7 |
+
DATASET_URL = {}
|
8 |
+
DATASET_MD5 = {}
|
9 |
+
|
10 |
+
def __init__(self, dataset='MMBench', **kwargs):
|
11 |
+
self.dataset_name = dataset
|
12 |
+
|
13 |
+
data = self.load_data(dataset)
|
14 |
+
|
15 |
+
data['index'] = [str(x) for x in data['index']]
|
16 |
+
|
17 |
+
if np.all([istype(x, int) for x in data['index']]):
|
18 |
+
data['index'] = [int(x) for x in data['index']]
|
19 |
+
|
20 |
+
self.data = data
|
21 |
+
self.post_build(dataset)
|
22 |
+
|
23 |
+
def __len__(self):
|
24 |
+
return len(self.data)
|
25 |
+
|
26 |
+
def __getitem__(self, idx):
|
27 |
+
return dict(self.data.iloc[idx])
|
28 |
+
|
29 |
+
def prepare_tsv(self, url, file_md5=None):
|
30 |
+
data_root = LMUDataRoot()
|
31 |
+
os.makedirs(data_root, exist_ok=True)
|
32 |
+
update_flag = False
|
33 |
+
file_name = url.split('/')[-1]
|
34 |
+
data_path = osp.join(data_root, file_name)
|
35 |
+
if osp.exists(data_path) and (file_md5 is None or md5(data_path) == file_md5):
|
36 |
+
pass
|
37 |
+
else:
|
38 |
+
warnings.warn('The dataset tsv is not downloaded')
|
39 |
+
download_file(url, data_path)
|
40 |
+
update_flag = True
|
41 |
+
|
42 |
+
if file_size(data_path, 'GB') > 1:
|
43 |
+
local_path = data_path.replace('.tsv', '_local.tsv')
|
44 |
+
if not osp.exists(local_path) or os.environ.get('FORCE_LOCAL', None) or update_flag:
|
45 |
+
from ..tools import LOCALIZE
|
46 |
+
LOCALIZE(data_path, local_path)
|
47 |
+
data_path = local_path
|
48 |
+
return load(data_path)
|
49 |
+
|
50 |
+
def dump_image(self, line):
|
51 |
+
return []
|
52 |
+
|
53 |
+
def display(self, line):
|
54 |
+
if isinstance(line, int):
|
55 |
+
line = self.data.iloc[line]
|
56 |
+
assert isinstance(line, pd.Series) or isinstance(line, dict)
|
57 |
+
mmqa_display(line)
|
58 |
+
|
59 |
+
# Return a list of dataset names that are supported by this class, can override
|
60 |
+
@classmethod
|
61 |
+
def supported_datasets(cls):
|
62 |
+
return list(cls.DATASET_URL)
|
63 |
+
|
64 |
+
# Given the dataset name, return the dataset as a pandas dataframe, can override
|
65 |
+
def load_data(self, dataset):
|
66 |
+
url = self.DATASET_URL[dataset]
|
67 |
+
file_md5 = self.DATASET_MD5[dataset]
|
68 |
+
return self.prepare_tsv(url, file_md5)
|
69 |
+
|
70 |
+
# Post built hook, will be called after the dataset is built, can override
|
71 |
+
def post_build(self, dataset):
|
72 |
+
pass
|
73 |
+
|
74 |
+
# Given one data record, return the built prompt (a multi-modal message), can override
|
75 |
+
def build_prompt(self, line):
|
76 |
+
if isinstance(line, int):
|
77 |
+
line = self.data.iloc[line]
|
78 |
+
|
79 |
+
question = line['question']
|
80 |
+
|
81 |
+
msgs = []
|
82 |
+
msgs.append(dict(type='text', value=question))
|
83 |
+
return msgs
|
84 |
+
|
85 |
+
# Given the prediction file, return the evaluation results in the format of a dictionary or pandas dataframe
|
86 |
+
@abstractmethod
|
87 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
88 |
+
pass
|
VLMEvalKit/vlmeval/dataset/text_mcq.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .text_base import TextBaseDataset
|
2 |
+
from .utils import build_judge, DEBUG_MESSAGE
|
3 |
+
from ..smp import *
|
4 |
+
|
5 |
+
|
6 |
+
class TextMCQDataset(TextBaseDataset):
|
7 |
+
TYPE = 'MCQ'
|
8 |
+
|
9 |
+
DATASET_URL = {}
|
10 |
+
|
11 |
+
DATASET_MD5 = {}
|
12 |
+
|
13 |
+
def build_prompt(self, line):
|
14 |
+
|
15 |
+
if isinstance(line, int):
|
16 |
+
line = self.data.iloc[line]
|
17 |
+
|
18 |
+
question = line['question']
|
19 |
+
options = {
|
20 |
+
cand: line[cand]
|
21 |
+
for cand in string.ascii_uppercase
|
22 |
+
if cand in line and not pd.isna(line[cand])
|
23 |
+
}
|
24 |
+
options_prompt = 'Options:\n'
|
25 |
+
for key, item in options.items():
|
26 |
+
options_prompt += f'{key}. {item}\n'
|
27 |
+
hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
|
28 |
+
prompt = ''
|
29 |
+
if hint is not None:
|
30 |
+
prompt += f'Hint: {hint}\n'
|
31 |
+
prompt += f'Question: {question}\n'
|
32 |
+
if len(options):
|
33 |
+
prompt += options_prompt
|
34 |
+
prompt += 'Please select the correct answer from the options above. \n'
|
35 |
+
|
36 |
+
msgs = []
|
37 |
+
|
38 |
+
msgs.append(dict(type='text', value=prompt))
|
39 |
+
|
40 |
+
return msgs
|
41 |
+
|
42 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
43 |
+
from .utils.multiple_choice import report_acc, report_acc_MMT, mcq_circular_eval, mcq_vanilla_eval
|
44 |
+
# assert dataset is not None
|
45 |
+
dataset_map = {
|
46 |
+
'MMBench_TEST_EN': 'MMBench', 'MMBench_TEST_EN_V11': 'MMBench_V11',
|
47 |
+
'MMBench_TEST_CN': 'MMBench_CN', 'MMBench_TEST_CN_V11': 'MMBench_CN_V11'
|
48 |
+
}
|
49 |
+
dataset = self.dataset_name
|
50 |
+
if dataset in dataset_map:
|
51 |
+
dataset = dataset_map[dataset]
|
52 |
+
nproc = judge_kwargs.pop('nproc', 4)
|
53 |
+
|
54 |
+
circular = False
|
55 |
+
|
56 |
+
suffix = eval_file.split('.')[-1]
|
57 |
+
model = judge_kwargs.get('model', 'exact_matching')
|
58 |
+
assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']
|
59 |
+
name_str_map = {'chatgpt-0125': 'openai', 'gpt-4-0125': 'gpt4'}
|
60 |
+
name_str = name_str_map[model] if model in name_str_map else model
|
61 |
+
|
62 |
+
if model == 'exact_matching':
|
63 |
+
model = None
|
64 |
+
elif gpt_key_set():
|
65 |
+
model = build_judge(**judge_kwargs)
|
66 |
+
if not model.working():
|
67 |
+
warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
|
68 |
+
warnings.warn(DEBUG_MESSAGE)
|
69 |
+
model = None
|
70 |
+
else:
|
71 |
+
warnings.warn('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
|
72 |
+
model = None
|
73 |
+
|
74 |
+
result_file = eval_file.replace(f'.{suffix}', f'_{name_str}_result.pkl')
|
75 |
+
|
76 |
+
data = load(eval_file)
|
77 |
+
data = data.sort_values(by='index')
|
78 |
+
data['prediction'] = [str(x) for x in data['prediction']]
|
79 |
+
# If not choice label, then use lower case
|
80 |
+
for k in data.keys():
|
81 |
+
data[k.lower() if k not in list(string.ascii_uppercase) else k] = data.pop(k)
|
82 |
+
|
83 |
+
meta = self.data
|
84 |
+
meta_q_map = {x: y for x, y in zip(meta['index'], meta['question'])}
|
85 |
+
data_map = {x: y for x, y in zip(data['index'], data['question'])}
|
86 |
+
for k in data_map:
|
87 |
+
assert k in meta_q_map, (
|
88 |
+
f'eval_file should be the same as or a subset of dataset {self.dataset_name}'
|
89 |
+
)
|
90 |
+
|
91 |
+
if circular:
|
92 |
+
data = mcq_circular_eval(model, data, meta, nproc, result_file, self.dataset_name)
|
93 |
+
else:
|
94 |
+
data = mcq_vanilla_eval(model, data, meta, nproc, result_file, self.dataset_name)
|
95 |
+
|
96 |
+
# load split
|
97 |
+
dump(data, eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))
|
98 |
+
data = load(eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))
|
99 |
+
|
100 |
+
# May have different report acc functions for different datasets
|
101 |
+
if 'MMT' in dataset:
|
102 |
+
acc = report_acc_MMT(data)
|
103 |
+
else:
|
104 |
+
acc = report_acc(data)
|
105 |
+
|
106 |
+
score_file = eval_file.replace(f'.{suffix}', '_acc.csv')
|
107 |
+
dump(acc, score_file)
|
108 |
+
|
109 |
+
return acc
|
110 |
+
|
111 |
+
|
112 |
+
class CustomTextMCQDataset(TextMCQDataset):
|
113 |
+
|
114 |
+
def load_data(self, dataset):
|
115 |
+
data_path = osp.join(LMUDataRoot(), f'{dataset}.tsv')
|
116 |
+
|
117 |
+
if file_size(data_path, 'GB') > 1:
|
118 |
+
local_path = data_path.replace('.tsv', '_local.tsv')
|
119 |
+
if not osp.exists(local_path) or os.environ.get('FORCE_LOCAL', None):
|
120 |
+
from ..tools import LOCALIZE
|
121 |
+
LOCALIZE(data_path, local_path)
|
122 |
+
data_path = local_path
|
123 |
+
return load(data_path)
|
VLMEvalKit/vlmeval/dataset/vcr.py
ADDED
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import uuid
|
2 |
+
from functools import partial
|
3 |
+
from .image_base import ImageBaseDataset
|
4 |
+
from ..smp import *
|
5 |
+
|
6 |
+
rouge = None
|
7 |
+
nlp_en = None
|
8 |
+
nlp_zh = None
|
9 |
+
nlp = None
|
10 |
+
|
11 |
+
|
12 |
+
def initialize():
|
13 |
+
import evaluate
|
14 |
+
import spacy
|
15 |
+
|
16 |
+
global rouge, nlp_en, nlp_zh, nlp
|
17 |
+
|
18 |
+
try:
|
19 |
+
rouge = evaluate.load('rouge', experiment_id=str(uuid.uuid4()))
|
20 |
+
except Exception as e:
|
21 |
+
logging.critical(f'{type(e)}: {e}')
|
22 |
+
logging.critical('Please first `pip install rouge_score`.')
|
23 |
+
|
24 |
+
try:
|
25 |
+
nlp_en = spacy.load('en_core_web_sm')
|
26 |
+
except Exception as e:
|
27 |
+
logging.warning(f'{type(e)}: {e}')
|
28 |
+
logging.warning('Will automatically download en_core_web_sm via spacy.')
|
29 |
+
spacy.cli.download('en_core_web_sm')
|
30 |
+
nlp_en = spacy.load('en_core_web_sm')
|
31 |
+
|
32 |
+
try:
|
33 |
+
nlp_zh = spacy.load('zh_core_web_sm')
|
34 |
+
except Exception as e:
|
35 |
+
logging.warning(f'{type(e)}: {e}')
|
36 |
+
logging.warning('Will automatically download zh_core_web_sm via spacy.')
|
37 |
+
spacy.cli.download('zh_core_web_sm')
|
38 |
+
nlp_zh = spacy.load('zh_core_web_sm')
|
39 |
+
|
40 |
+
nlp = {'en': nlp_en, 'zh': nlp_zh}
|
41 |
+
|
42 |
+
|
43 |
+
def rough_filter(answer_text):
|
44 |
+
if "I can't" in answer_text:
|
45 |
+
return False
|
46 |
+
elif 'I cannot' in answer_text:
|
47 |
+
return False
|
48 |
+
elif 'sorry' in answer_text.lower():
|
49 |
+
return False
|
50 |
+
if '无法' in answer_text:
|
51 |
+
return False
|
52 |
+
elif '抱歉' in answer_text:
|
53 |
+
return False
|
54 |
+
else:
|
55 |
+
return True
|
56 |
+
|
57 |
+
|
58 |
+
def zero_template(crossed_text):
|
59 |
+
return {
|
60 |
+
'crossed_text': crossed_text,
|
61 |
+
'max_sim_val': 0,
|
62 |
+
'max_sim_string': '',
|
63 |
+
'precision': 0,
|
64 |
+
'recall': 0,
|
65 |
+
'f1': 0,
|
66 |
+
'jaccard': 0,
|
67 |
+
'rouge1': 0,
|
68 |
+
'exact_match': 0,
|
69 |
+
}
|
70 |
+
|
71 |
+
|
72 |
+
def tokenize(text, language):
|
73 |
+
"""
|
74 |
+
Tokenize the text and return the tokens.
|
75 |
+
|
76 |
+
Parameters:
|
77 |
+
text (str): The text to tokenize.
|
78 |
+
language (str): The language of the text.
|
79 |
+
|
80 |
+
Returns:
|
81 |
+
list: The list of tokens.
|
82 |
+
"""
|
83 |
+
assert language in ['en', 'zh']
|
84 |
+
nlp_language = nlp[language]
|
85 |
+
processed_text = nlp_language(text)
|
86 |
+
return [token.text for token in processed_text]
|
87 |
+
|
88 |
+
|
89 |
+
def find_best_match(needle, hay, language, rouge):
|
90 |
+
"""
|
91 |
+
Finds the best matching n-gram in the haystack for the given needle.
|
92 |
+
|
93 |
+
Parameters:
|
94 |
+
needle (str): The string to find.
|
95 |
+
hay (str): The text to search within.
|
96 |
+
|
97 |
+
Returns:
|
98 |
+
tuple: The highest similarity value and the best matching string.
|
99 |
+
"""
|
100 |
+
assert language in ['en', 'zh']
|
101 |
+
from nltk.util import ngrams
|
102 |
+
from difflib import SequenceMatcher as SM
|
103 |
+
|
104 |
+
tokens_hay = tokenize(hay, language)
|
105 |
+
tokens_needle = tokenize(needle, language)
|
106 |
+
|
107 |
+
splitter = '' if language == 'zh' else ' '
|
108 |
+
ngrams_ = ngrams(tokens_hay, len(tokens_needle))
|
109 |
+
max_sim_val = 0
|
110 |
+
max_sim_string = ''
|
111 |
+
max_sim_ngram = []
|
112 |
+
tokens_needle_set = set(tokens_needle)
|
113 |
+
ngrams_hasjoint = [
|
114 |
+
ngram
|
115 |
+
for ngram in ngrams_
|
116 |
+
if not set(ngram).isdisjoint(tokens_needle_set)
|
117 |
+
]
|
118 |
+
|
119 |
+
for ngram in ngrams_hasjoint:
|
120 |
+
hay_ngram = splitter.join(ngram)
|
121 |
+
similarity = SM(None, hay_ngram, needle).ratio()
|
122 |
+
if similarity > max_sim_val:
|
123 |
+
max_sim_val = similarity
|
124 |
+
max_sim_string = hay_ngram
|
125 |
+
max_sim_ngram = ngram
|
126 |
+
|
127 |
+
# Evaluate
|
128 |
+
if len(max_sim_ngram) == 0:
|
129 |
+
return {
|
130 |
+
'crossed_text': needle,
|
131 |
+
'max_sim_val': 0,
|
132 |
+
'max_sim_string': '',
|
133 |
+
'precision': 0,
|
134 |
+
'recall': 0,
|
135 |
+
'f1': 0,
|
136 |
+
'jaccard': 0,
|
137 |
+
'rouge1': 0,
|
138 |
+
'exact_match': 0,
|
139 |
+
}
|
140 |
+
pred_set = set(max_sim_ngram)
|
141 |
+
ref_set = set(tokens_needle)
|
142 |
+
correct_tokens = pred_set.intersection(ref_set)
|
143 |
+
len_correct_tokens = len(correct_tokens)
|
144 |
+
|
145 |
+
precision = len_correct_tokens / len(pred_set)
|
146 |
+
recall = len_correct_tokens / len(ref_set)
|
147 |
+
if (precision + recall) == 0:
|
148 |
+
f1 = 0
|
149 |
+
else:
|
150 |
+
f1 = 2 * precision * recall / (precision + recall)
|
151 |
+
union = pred_set.union(ref_set)
|
152 |
+
jaccard = len_correct_tokens / len(union) if len(union) > 0 else 0
|
153 |
+
rouge_1 = rouge.compute(
|
154 |
+
predictions=[max_sim_string],
|
155 |
+
references=[needle],
|
156 |
+
tokenizer=partial(tokenize, language=language),
|
157 |
+
rouge_types=['rouge1'],
|
158 |
+
)['rouge1']
|
159 |
+
exact_match = float(list(max_sim_ngram) == list(tokens_needle))
|
160 |
+
out = {
|
161 |
+
'crossed_text': needle,
|
162 |
+
'max_sim_string': max_sim_string,
|
163 |
+
'max_sim_val': max_sim_val,
|
164 |
+
'precision': precision,
|
165 |
+
'recall': recall,
|
166 |
+
'f1': f1,
|
167 |
+
'jaccard': jaccard,
|
168 |
+
'rouge1': rouge_1,
|
169 |
+
'exact_match': exact_match,
|
170 |
+
}
|
171 |
+
return out
|
172 |
+
|
173 |
+
|
174 |
+
def process_match_single_new(
|
175 |
+
image_id, prediction, answer, language, progress
|
176 |
+
):
|
177 |
+
"""
|
178 |
+
process the inference results for a single image and calculate the metrics
|
179 |
+
|
180 |
+
Parameters:
|
181 |
+
image_id (int): The image id (question id).
|
182 |
+
prediction (str): The prediction text.
|
183 |
+
answer (Union[str, List[str]]): The answer text, or a list of answer texts. The masked n-grams in the image.
|
184 |
+
language (str): The language of the text. Can be "en" or "zh".
|
185 |
+
rouge (rouge): The rouge metric object.
|
186 |
+
progress (multiprocessing.Queue): The progress queue.
|
187 |
+
|
188 |
+
Returns:
|
189 |
+
tuple: The image id (question_id, int) and the result per id (dict of dict of dict).
|
190 |
+
"""
|
191 |
+
result_per_id = {image_id: {}}
|
192 |
+
if isinstance(answer, str):
|
193 |
+
answer = eval(answer)
|
194 |
+
assert isinstance(answer, list)
|
195 |
+
result = prediction.split('Assistant: ')[-1]
|
196 |
+
for i, crossed_text in enumerate(answer):
|
197 |
+
if rough_filter(result):
|
198 |
+
find_best_match_result = find_best_match(
|
199 |
+
crossed_text, result, language, rouge
|
200 |
+
)
|
201 |
+
if i == 0:
|
202 |
+
result_per_id[image_id] = {str(i): find_best_match_result}
|
203 |
+
else:
|
204 |
+
result_per_id[image_id][str(i)] = find_best_match_result
|
205 |
+
else:
|
206 |
+
if i == 0:
|
207 |
+
result_per_id[image_id] = {str(i): zero_template(crossed_text)}
|
208 |
+
else:
|
209 |
+
result_per_id[image_id][str(i)] = zero_template(crossed_text)
|
210 |
+
progress.put(1)
|
211 |
+
return image_id, result_per_id
|
212 |
+
|
213 |
+
|
214 |
+
class VCRDataset(ImageBaseDataset):
|
215 |
+
TYPE = 'VQA'
|
216 |
+
|
217 |
+
URL_PREFIX = 'https://huggingface.co/datasets/vcr-org'
|
218 |
+
|
219 |
+
DATASET_URL = {
|
220 |
+
'VCR_EN_EASY_500': f'{URL_PREFIX}/VCR-wiki-en-easy-test-500/resolve/main/VCR-wiki-en-easy-test-500.tsv',
|
221 |
+
'VCR_EN_EASY_100': f'{URL_PREFIX}/VCR-wiki-en-easy-test-100/resolve/main/VCR-wiki-en-easy-test-100.tsv',
|
222 |
+
'VCR_EN_EASY_ALL': f'{URL_PREFIX}/VCR-wiki-en-easy-test/resolve/main/VCR-wiki-en-easy-test.tsv',
|
223 |
+
'VCR_EN_HARD_500': f'{URL_PREFIX}/VCR-wiki-en-hard-test-500/resolve/main/VCR-wiki-en-hard-test-500.tsv',
|
224 |
+
'VCR_EN_HARD_100': f'{URL_PREFIX}/VCR-wiki-en-hard-test-100/resolve/main/VCR-wiki-en-hard-test-100.tsv',
|
225 |
+
'VCR_EN_HARD_ALL': f'{URL_PREFIX}/VCR-wiki-en-hard-test/resolve/main/VCR-wiki-en-hard-test.tsv',
|
226 |
+
'VCR_ZH_EASY_500': f'{URL_PREFIX}/VCR-wiki-zh-easy-test-500/resolve/main/VCR-wiki-zh-easy-test-500.tsv',
|
227 |
+
'VCR_ZH_EASY_100': f'{URL_PREFIX}/VCR-wiki-zh-easy-test-100/resolve/main/VCR-wiki-zh-easy-test-100.tsv',
|
228 |
+
'VCR_ZH_EASY_ALL': f'{URL_PREFIX}/VCR-wiki-zh-easy-test/resolve/main/VCR-wiki-zh-easy-test.tsv',
|
229 |
+
'VCR_ZH_HARD_500': f'{URL_PREFIX}/VCR-wiki-zh-hard-test-500/resolve/main/VCR-wiki-zh-hard-test-500.tsv',
|
230 |
+
'VCR_ZH_HARD_100': f'{URL_PREFIX}/VCR-wiki-zh-hard-test-100/resolve/main/VCR-wiki-zh-hard-test-100.tsv',
|
231 |
+
'VCR_ZH_HARD_ALL': f'{URL_PREFIX}/VCR-wiki-zh-hard-test/resolve/main/VCR-wiki-zh-hard-test.tsv',
|
232 |
+
}
|
233 |
+
|
234 |
+
DATASET_MD5 = {
|
235 |
+
'VCR_EN_EASY_500': 'fd9258db52f8685dc710619a0ea0a261',
|
236 |
+
'VCR_EN_EASY_100': '9df5d7266683458621ecbe122beb72f0',
|
237 |
+
'VCR_EN_EASY_ALL': '8a9b96885f251d1c85f42f84073327f1',
|
238 |
+
'VCR_EN_HARD_500': '0a22a85080b6a1f52b1f95e302d43df4',
|
239 |
+
'VCR_EN_HARD_100': '1b20f5cbcbeae0b0bec77f7a36143958',
|
240 |
+
'VCR_EN_HARD_ALL': '2d8b8b1ee0eba0e0b618fd3aa7d9710e',
|
241 |
+
'VCR_ZH_EASY_500': 'beca5fd54176adf44cf94bd9b50cf048',
|
242 |
+
'VCR_ZH_EASY_100': '4a86a5678a79844d6d22ab0629c51cd5',
|
243 |
+
'VCR_ZH_EASY_ALL': '5050fe7f0027ad2068fd4c7f220edaea',
|
244 |
+
'VCR_ZH_HARD_500': '617e3360f75c54455625cb0a8da5c1e7',
|
245 |
+
'VCR_ZH_HARD_100': 'b0e38c85f5d5e63894a3b881c372a62b',
|
246 |
+
'VCR_ZH_HARD_ALL': '54bbfef448206518b03127ef8b61404c',
|
247 |
+
}
|
248 |
+
|
249 |
+
def __init__(self, dataset='VCR_EN_EASY_500', skip_noimg=True):
|
250 |
+
super().__init__(dataset, skip_noimg)
|
251 |
+
|
252 |
+
initialize()
|
253 |
+
self.language = 'en' if 'EN' in dataset else 'zh'
|
254 |
+
self.difficulty = 'easy' if 'EASY' in dataset else 'hard'
|
255 |
+
|
256 |
+
# def build_prompt(self, line):
|
257 |
+
# msgs = super().build_prompt(line)
|
258 |
+
# assert msgs[-1]['type'] == 'text'
|
259 |
+
# if self.language == 'zh':
|
260 |
+
# msgs[-1]['value'] += '图像中被覆盖的文本是什么?请在不输出解释的情况下还原被覆盖的文本。'
|
261 |
+
# else:
|
262 |
+
# msgs[-1]['value'] += ('What is the covered texts in the image? '
|
263 |
+
# 'Please restore the covered texts without outputting the explanations.')
|
264 |
+
# return msgs
|
265 |
+
|
266 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
267 |
+
import multiprocessing
|
268 |
+
|
269 |
+
vcr_score_list = {'Exact_Match': [], 'Jaccard': []}
|
270 |
+
vcr_score = {'Exact_Match': 0, 'Jaccard': 0}
|
271 |
+
logger = get_logger('Evaluation')
|
272 |
+
data = load(eval_file)
|
273 |
+
|
274 |
+
lt = len(data)
|
275 |
+
lines = [data.iloc[i] for i in range(lt)]
|
276 |
+
|
277 |
+
pool = multiprocessing.Pool()
|
278 |
+
manager = multiprocessing.Manager()
|
279 |
+
progress_queue = manager.Queue()
|
280 |
+
results = []
|
281 |
+
|
282 |
+
overall_results = {str(image_id): {} for image_id in range(len(lines))}
|
283 |
+
|
284 |
+
for instance_id, instance in enumerate(lines):
|
285 |
+
results.append(
|
286 |
+
pool.apply_async(
|
287 |
+
process_match_single_new,
|
288 |
+
args=(
|
289 |
+
str(instance_id),
|
290 |
+
instance['prediction'],
|
291 |
+
instance['answer'],
|
292 |
+
self.language,
|
293 |
+
progress_queue,
|
294 |
+
),
|
295 |
+
)
|
296 |
+
)
|
297 |
+
pool.close()
|
298 |
+
|
299 |
+
# Display progress bar
|
300 |
+
for _ in tqdm(range(len(results))):
|
301 |
+
progress_queue.get()
|
302 |
+
|
303 |
+
pool.join()
|
304 |
+
|
305 |
+
# Merging results into overall_result
|
306 |
+
for result in results:
|
307 |
+
image_id, result_per_id = result.get()
|
308 |
+
overall_results[str(image_id)].update(result_per_id[image_id])
|
309 |
+
for blank_id_str in result_per_id[image_id].keys():
|
310 |
+
vcr_score_list['Exact_Match'].append(
|
311 |
+
result_per_id[image_id][blank_id_str]['exact_match']
|
312 |
+
)
|
313 |
+
vcr_score_list['Jaccard'].append(
|
314 |
+
result_per_id[image_id][blank_id_str]['jaccard']
|
315 |
+
)
|
316 |
+
vcr_score['Exact_Match'] = np.mean(vcr_score_list['Exact_Match'])
|
317 |
+
vcr_score['Jaccard'] = np.mean(vcr_score_list['Jaccard'])
|
318 |
+
results_out = {
|
319 |
+
k: v for i in range(len(results)) for k, v in results[i].get()[1].items()
|
320 |
+
}
|
321 |
+
results_with_metrics = {
|
322 |
+
'Exact_Match': vcr_score['Exact_Match'],
|
323 |
+
'Jaccard': vcr_score['Jaccard'],
|
324 |
+
'Predictions': results_out,
|
325 |
+
}
|
326 |
+
score_pth = eval_file.replace(
|
327 |
+
'.xlsx', f'{self.language}_{self.difficulty}_score.json'
|
328 |
+
)
|
329 |
+
dump(results_with_metrics, score_pth)
|
330 |
+
logger.info(
|
331 |
+
f'VCR successfully finished evaluating {eval_file}, results saved in {score_pth}'
|
332 |
+
)
|
333 |
+
logger.info('Score: ')
|
334 |
+
for key, value in vcr_score.items():
|
335 |
+
logger.info('{}:{}'.format(key, value))
|
VLMEvalKit/vlmeval/dataset/video_base.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import abstractmethod
|
2 |
+
from ..smp import *
|
3 |
+
|
4 |
+
|
5 |
+
class VideoBaseDataset:
|
6 |
+
|
7 |
+
MODALITY = 'VIDEO'
|
8 |
+
|
9 |
+
def __init__(self,
|
10 |
+
dataset='MMBench-Video',
|
11 |
+
pack=False):
|
12 |
+
try:
|
13 |
+
import decord
|
14 |
+
except Exception as e:
|
15 |
+
logging.critical(f'{type(e)}: {e}')
|
16 |
+
logging.critical('Please install decord via `pip install decord`.')
|
17 |
+
|
18 |
+
self.dataset_name = dataset
|
19 |
+
ret = self.prepare_dataset(dataset)
|
20 |
+
assert ret is not None
|
21 |
+
lmu_root = LMUDataRoot()
|
22 |
+
self.frame_root = osp.join(lmu_root, 'images', dataset)
|
23 |
+
os.makedirs(self.frame_root, exist_ok=True)
|
24 |
+
self.frame_tmpl = 'frame-{}-of-{}.jpg'
|
25 |
+
self.frame_tmpl_fps = 'frame-{}-of-{}-{}fps.jpg'
|
26 |
+
|
27 |
+
self.data_root = ret['root']
|
28 |
+
self.data_file = ret['data_file']
|
29 |
+
self.data = load(self.data_file)
|
30 |
+
|
31 |
+
assert 'question' in self.data and 'video' in self.data
|
32 |
+
videos = list(set(self.data['video']))
|
33 |
+
videos.sort()
|
34 |
+
self.videos = videos
|
35 |
+
self.pack = pack
|
36 |
+
|
37 |
+
def __len__(self):
|
38 |
+
return len(self.videos) if self.pack else len(self.data)
|
39 |
+
|
40 |
+
def __getitem__(self, idx):
|
41 |
+
if self.pack:
|
42 |
+
assert idx < len(self.videos)
|
43 |
+
sub_data = self.data[self.data['video'] == self.videos[idx]]
|
44 |
+
return sub_data
|
45 |
+
else:
|
46 |
+
assert idx < len(self.data)
|
47 |
+
return dict(self.data.iloc[idx])
|
48 |
+
|
49 |
+
def frame_paths(self, video, num_frames=8):
|
50 |
+
frame_root = osp.join(self.frame_root, video)
|
51 |
+
os.makedirs(frame_root, exist_ok=True)
|
52 |
+
return [osp.join(frame_root, self.frame_tmpl.format(i, num_frames)) for i in range(1, num_frames + 1)]
|
53 |
+
|
54 |
+
def frame_paths_fps(self, video, num_frames=8, fps=-1):
|
55 |
+
frame_root = osp.join(self.frame_root, video)
|
56 |
+
os.makedirs(frame_root, exist_ok=True)
|
57 |
+
return [osp.join(frame_root, self.frame_tmpl_fps.format(i, num_frames, fps)) for i in range(1, num_frames + 1)]
|
58 |
+
|
59 |
+
def save_video_frames(self, video, num_frames=8, fps=-1):
|
60 |
+
if fps > 0:
|
61 |
+
vid_path = osp.join(self.data_root, video + '.mp4')
|
62 |
+
vid = decord.VideoReader(vid_path)
|
63 |
+
|
64 |
+
# 计算视频的总帧数和总时长
|
65 |
+
total_frames = len(vid)
|
66 |
+
video_fps = vid.get_avg_fps()
|
67 |
+
total_duration = total_frames / video_fps
|
68 |
+
|
69 |
+
# 计算需要提取的总帧数
|
70 |
+
required_frames = int(total_duration * fps)
|
71 |
+
|
72 |
+
# 计算提取帧的间隔
|
73 |
+
step_size = video_fps / fps
|
74 |
+
|
75 |
+
# 计算提取帧的索引
|
76 |
+
indices = [int(i * step_size) for i in range(required_frames)]
|
77 |
+
|
78 |
+
# 提取帧并保存
|
79 |
+
frame_paths = self.frame_paths_fps(video, len(indices), fps)
|
80 |
+
flag = np.all([osp.exists(p) for p in frame_paths])
|
81 |
+
if flag:
|
82 |
+
return frame_paths
|
83 |
+
|
84 |
+
images = [vid[i].asnumpy() for i in indices]
|
85 |
+
images = [Image.fromarray(arr) for arr in images]
|
86 |
+
for im, pth in zip(images, frame_paths):
|
87 |
+
if not osp.exists(pth):
|
88 |
+
im.save(pth)
|
89 |
+
return frame_paths
|
90 |
+
|
91 |
+
else:
|
92 |
+
frame_paths = self.frame_paths(video, num_frames)
|
93 |
+
flag = np.all([osp.exists(p) for p in frame_paths])
|
94 |
+
if flag:
|
95 |
+
return frame_paths
|
96 |
+
vid_path = osp.join(self.data_root, video + '.mp4')
|
97 |
+
vid = decord.VideoReader(vid_path)
|
98 |
+
step_size = len(vid) / (num_frames + 1)
|
99 |
+
indices = [int(i * step_size) for i in range(1, num_frames + 1)]
|
100 |
+
images = [vid[i].asnumpy() for i in indices]
|
101 |
+
images = [Image.fromarray(arr) for arr in images]
|
102 |
+
for im, pth in zip(images, frame_paths):
|
103 |
+
if not osp.exists(pth):
|
104 |
+
im.save(pth)
|
105 |
+
return frame_paths
|
106 |
+
|
107 |
+
# Return a list of dataset names that are supported by this class, can override
|
108 |
+
@classmethod
|
109 |
+
def supported_datasets(cls):
|
110 |
+
return ['MMBench-Video', 'Video-MME', 'MVBench', 'MVBench_MP4', 'LongVideoBench']
|
111 |
+
|
112 |
+
# Given the prediction file, return the evaluation results in the format of a dictionary or pandas dataframe
|
113 |
+
@abstractmethod
|
114 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
115 |
+
pass
|
116 |
+
|
117 |
+
@abstractmethod
|
118 |
+
def build_prompt(self, idx, num_frames=8):
|
119 |
+
pass
|
120 |
+
|
121 |
+
@abstractmethod
|
122 |
+
def prepare_dataset(self, dataset):
|
123 |
+
# The prepare_dataset function should return a dictionary containing:
|
124 |
+
# `root` (directory that containing video files)
|
125 |
+
# `data_file` (the TSV dataset file)
|
126 |
+
pass
|
VLMEvalKit/vlmeval/dataset/video_concat_dataset.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ..smp import *
|
2 |
+
from .video_base import VideoBaseDataset
|
3 |
+
|
4 |
+
|
5 |
+
class ConcatVideoDataset(VideoBaseDataset):
|
6 |
+
# This dataset takes multiple dataset names as input and aggregate them into a single dataset.
|
7 |
+
# Each single dataset should not have a field named `SUB_DATASET`
|
8 |
+
|
9 |
+
DATASET_SETS = {}
|
10 |
+
|
11 |
+
def __init__(self, dataset):
|
12 |
+
from . import build_dataset
|
13 |
+
datasets = self.DATASET_SETS[dataset]
|
14 |
+
self.dataset_map = {}
|
15 |
+
# The name of the compliation
|
16 |
+
self.dataset_name = dataset
|
17 |
+
self.datasets = datasets
|
18 |
+
for dname in datasets:
|
19 |
+
dataset = build_dataset(dname)
|
20 |
+
assert dataset is not None, dataset
|
21 |
+
self.dataset_map[dname] = dataset
|
22 |
+
TYPES = [x.TYPE for x in self.dataset_map.values()]
|
23 |
+
MODALITIES = [x.MODALITY for x in self.dataset_map.values()]
|
24 |
+
# assert np.all([x == TYPES[0] for x in TYPES]), (datasets, TYPES)
|
25 |
+
assert np.all([x == MODALITIES[0] for x in MODALITIES]), (datasets, MODALITIES)
|
26 |
+
self.TYPE = TYPES
|
27 |
+
self.MODALITY = MODALITIES[0]
|
28 |
+
data_all = []
|
29 |
+
for dname in datasets:
|
30 |
+
data = self.dataset_map[dname].data
|
31 |
+
data['SUB_DATASET'] = [dname] * len(data)
|
32 |
+
data_all.append(data)
|
33 |
+
|
34 |
+
data = pd.concat(data_all)
|
35 |
+
data['original_index'] = data.pop('index')
|
36 |
+
data['index'] = np.arange(len(data))
|
37 |
+
self.data = data
|
38 |
+
|
39 |
+
def build_prompt(self, line, num_frames, video_llm, fps):
|
40 |
+
if isinstance(line, int):
|
41 |
+
line = self.data.iloc[line]
|
42 |
+
idx = line['original_index']
|
43 |
+
dname = line['SUB_DATASET']
|
44 |
+
org_data = self.dataset_map[dname].data
|
45 |
+
org_line = cp.deepcopy(org_data[org_data['index'] == idx]).iloc[0]
|
46 |
+
return self.dataset_map[dname].build_prompt(org_line, num_frames, video_llm, fps)
|
47 |
+
|
48 |
+
def dump_image(self, line):
|
49 |
+
# Assert all images are pre-dumped
|
50 |
+
assert 'image' not in line
|
51 |
+
assert 'image_path' in line
|
52 |
+
tgt_path = toliststr(line['image_path'])
|
53 |
+
return tgt_path
|
54 |
+
|
55 |
+
@classmethod
|
56 |
+
def supported_datasets(cls):
|
57 |
+
return [] # list(cls.DATASET_SETS)
|
58 |
+
|
59 |
+
def evaluate(self, eval_file, **judge_kwargs):
|
60 |
+
suffix = eval_file.split('.')[-1]
|
61 |
+
# First, split the eval_file by dataset
|
62 |
+
data_all = load(eval_file)
|
63 |
+
for dname in self.datasets:
|
64 |
+
tgt = eval_file.replace(self.dataset_name, dname)
|
65 |
+
data_sub = data_all[data_all['SUB_DATASET'] == dname]
|
66 |
+
data_sub.pop('index')
|
67 |
+
data_sub['index'] = data_sub.pop('original_index')
|
68 |
+
data_sub.pop('SUB_DATASET')
|
69 |
+
dump(data_sub, tgt)
|
70 |
+
# Then, evaluate each dataset separately
|
71 |
+
results_all = {}
|
72 |
+
for dname in self.datasets:
|
73 |
+
tgt = eval_file.replace(self.dataset_name, dname)
|
74 |
+
res = self.dataset_map[dname].evaluate(tgt, **judge_kwargs)
|
75 |
+
results_all.update(res)
|
76 |
+
|
77 |
+
result = pd.DataFrame(results_all, index=['success', 'overall'])
|
78 |
+
result = result.T
|
79 |
+
for idx, item in result.iterrows():
|
80 |
+
result.loc[idx, 'acc'] = round(item['success'] / item['overall'] * 100, 1)
|
81 |
+
score_file = eval_file.replace(f'.{suffix}', '_acc.csv')
|
82 |
+
dump(result, score_file)
|
83 |
+
return result
|
VLMEvalKit/vlmeval/dataset/videomme.py
ADDED
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from huggingface_hub import snapshot_download
from ..smp import *
from .video_base import VideoBaseDataset
from .utils import build_judge, DEBUG_MESSAGE

FAIL_MSG = 'Failed to obtain answer via API.'


def unwrap_hf_pkl(pth, suffix='.mp4'):
    base_dir = os.path.join(pth, 'video_pkl/')
    target_dir = os.path.join(pth, 'video/')
    pickle_files = [os.path.join(base_dir, file) for file in os.listdir(base_dir)]
    pickle_files.sort()

    if not os.path.exists(target_dir):
        os.makedirs(target_dir, exist_ok=True)
        for pickle_file in pickle_files:
            with open(pickle_file, 'rb') as file:
                video_data = pickle.load(file)
            # For each video file in the pickle file, write its contents to a new mp4 file
            for video_name, video_content in video_data.items():
                output_path = os.path.join(target_dir, f'{video_name}{suffix}')
                with open(output_path, 'wb') as output_file:
                    output_file.write(video_content)
        print('The video file has been restored and stored from the pickle file.')
    else:
        print('The video file already exists.')


class VideoMME(VideoBaseDataset):

    MD5 = '85bdd91f9b29a99354c23b97ab7c113c'
    SYS = ''

    FRAMES_TMPL_NOSUB = """
These are the frames of a video. \
Select the best answer to the following multiple-choice question based on the video. \
Respond with only the letter (A, B, C, or D) of the correct option.
"""

    FRAMES_TMPL_SUB = """
These are the frames of a video. \
This video's subtitles are listed below:
{}
Select the best answer to the following multiple-choice question based on the video. \
Respond with only the letter (A, B, C, or D) of the correct option.
"""

    TYPE = 'Video-MCQ'

    def __init__(self, dataset='Video-MME', use_subtitle=False):
        super().__init__(dataset=dataset)
        self.use_subtitle = use_subtitle
        self.dataset_name = dataset

    @classmethod
    def supported_datasets(cls):
        return ['Video-MME']

    def prepare_dataset(self, dataset_name='Video-MME', repo_id='lmms-lab/Video-MME'):

        def check_integrity(pth):
            data_file = osp.join(pth, f'{dataset_name}.tsv')

            if not os.path.exists(data_file):
                return False

            if md5(data_file) != self.MD5:
                return False
            data = load(data_file)
            for video_pth in data['video_path']:
                if not osp.exists(osp.join(pth, video_pth)):
                    return False
            return True

        cache_path = get_cache_path(repo_id)
        if cache_path is not None and check_integrity(cache_path):
            dataset_path = cache_path
        else:

            def unzip_hf_zip(pth):
                import zipfile
                base_dir = pth
                target_dir = os.path.join(pth, 'video/')
                zip_files = [
                    os.path.join(base_dir, file) for file in os.listdir(base_dir)
                    if file.endswith('.zip') and file.startswith('video')
                ]
                zip_files.sort()

                if not os.path.exists(target_dir):
                    os.makedirs(target_dir, exist_ok=True)
                    for zip_file in zip_files:
                        with zipfile.ZipFile(zip_file, 'r') as zip_ref:
                            for member in zip_ref.namelist():
                                # Check if the member is a file (not a directory)
                                if not member.endswith('/'):
                                    # Extract the file to the specified directory
                                    source = zip_ref.open(member)
                                    target = open(os.path.join(target_dir, os.path.basename(member)), 'wb')
                                    with source, target:
                                        target.write(source.read())
                    print('The video file has been restored and stored from the zip file.')
                else:
                    print('The video file already exists.')

                subtitle_zip_file = os.path.join(base_dir, 'subtitle.zip')
                subtitle_target_dir = os.path.join(base_dir, 'subtitle')

                if not os.path.exists(subtitle_target_dir):
                    os.makedirs(subtitle_target_dir, exist_ok=True)
                    with zipfile.ZipFile(subtitle_zip_file, 'r') as zip_ref:
                        for member in zip_ref.namelist():
                            # Check if the member is a file (not a directory)
                            if not member.endswith('/'):
                                # Extract the file to the specified directory
                                source = zip_ref.open(member)
                                target = open(os.path.join(subtitle_target_dir, os.path.basename(member)), 'wb')
                                with source, target:
                                    target.write(source.read())
                    print('The subtitle file has been restored and stored from the zip file.')
                else:
                    print('The subtitle file already exists.')

            def generate_tsv(pth):

                data_file = osp.join(pth, f'{dataset_name}.tsv')
                if os.path.exists(data_file) and md5(data_file) == self.MD5:
                    return

                data_file = pd.read_parquet(os.path.join(pth, 'videomme/test-00000-of-00001.parquet'))
                data_file = data_file.assign(index=range(len(data_file)))
                data_file['video'] = data_file['videoID']
                data_file['video_path'] = data_file['videoID'].apply(lambda x: f'./video/{x}.mp4')
                data_file['subtitle_path'] = data_file['videoID'].apply(lambda x: f'./subtitle/{x}.srt')
                data_file['candidates'] = data_file['options'].apply(lambda x: x.tolist())

                data_file = data_file[['index', 'video', 'video_path', 'duration', 'domain', 'candidates',
                                       'sub_category', 'task_type', 'subtitle_path', 'question', 'answer']]

                data_file.to_csv(osp.join(pth, f'{dataset_name}.tsv'), sep='\t', index=False)

            if modelscope_flag_set():
                from modelscope import dataset_snapshot_download
                dataset_path = dataset_snapshot_download(dataset_id=repo_id)
            else:
                dataset_path = snapshot_download(repo_id=repo_id, repo_type='dataset')
            unzip_hf_zip(dataset_path)
            generate_tsv(dataset_path)

        data_file = osp.join(dataset_path, f'{dataset_name}.tsv')

        return dict(data_file=data_file, root=dataset_path)

    def save_video_frames(self, video, num_frames=8, fps=-1, video_llm=False):

        vid_path = osp.join(self.data_root, 'video', video + '.mp4')
        vid = decord.VideoReader(vid_path)
        video_info = {
            'fps': vid.get_avg_fps(),
            'n_frames': len(vid),
        }
        if num_frames > 0 and fps < 0:
            step_size = len(vid) / (num_frames + 1)
            indices = [int(i * step_size) for i in range(1, num_frames + 1)]
            frame_paths = self.frame_paths(video, num_frames)
        elif fps > 0:
            # not constrained by num_frames, get frames by fps
            total_duration = video_info['n_frames'] / video_info['fps']
            required_frames = int(total_duration * fps)
            step_size = video_info['fps'] / fps
            indices = [int(i * step_size) for i in range(required_frames)]
            frame_paths = self.frame_paths_fps(video, len(indices), fps)

        flag = np.all([osp.exists(p) for p in frame_paths])

        if not flag:
            images = [vid[i].asnumpy() for i in indices]
            images = [Image.fromarray(arr) for arr in images]
            for im, pth in zip(images, frame_paths):
                if not osp.exists(pth) and not video_llm:
                    im.save(pth)

        return frame_paths, indices, video_info

    def save_video_into_images(self, line, num_frames=8):
        frame_paths, indices, video_info = self.save_video_frames(line['video'], num_frames)
        return frame_paths

    def build_prompt(self, line, num_frames, video_llm, fps):
        if isinstance(line, int):
            assert line < len(self)
            line = self.data.iloc[line]

        frames, indices, video_info = self.save_video_frames(line['video'], num_frames, fps, video_llm)

        if self.use_subtitle and os.path.exists(osp.join(self.data_root, line['subtitle_path'])):
            import pysubs2
            subs = pysubs2.load(osp.join(self.data_root, line['subtitle_path']), encoding='utf-8')
            subtitles = []

            for seleced_frame_id in indices:
                sub_text = ''
                cur_time = pysubs2.make_time(fps=video_info['fps'], frames=seleced_frame_id)
                for sub in subs:
                    if sub.start < cur_time and sub.end > cur_time:
                        sub_text = sub.text.replace('\\N', ' ')
                        break
                if sub_text.strip():
                    subtitles.append(sub_text)
            subtitles = '\n'.join(subtitles)
        else:
            subtitles = ''

        message = [dict(type='text', value=self.SYS)]
        if video_llm:
            message.append(dict(type='video', value=osp.join(self.data_root, 'video', line['video'] + '.mp4')))
        else:
            for im in frames:
                message.append(dict(type='image', value=im))

        text_prompt = self.FRAMES_TMPL_NOSUB if not self.use_subtitle else self.FRAMES_TMPL_SUB.format(subtitles)
        message.append(dict(type='text', value=text_prompt))
        line['question'] += '\n' + '\n'.join(eval(line['candidates']))
        prompt = 'Question: {}\nAnswer: '.format(line['question'])
        message.append(dict(type='text', value=prompt))
        return message

    # It returns a dictionary
    @classmethod
    def evaluate(self, eval_file, **judge_kwargs):
        from .utils.videomme import get_dimension_rating, extract_characters_regex, extract_option

        assert eval_file.endswith('.xlsx'), 'data file should be an xlsx file'

        tmp_file = eval_file.replace('.xlsx', '_tmp.pkl')
        tgt_file = eval_file.replace('.xlsx', '_rating.json')
        score_file = eval_file.replace('.xlsx', '_score.xlsx')

        if not osp.exists(score_file):
            model = judge_kwargs.get('model', 'exact_matching')
            assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']

            if model == 'exact_matching':
                model = None
            elif gpt_key_set():
                model = build_judge(**judge_kwargs)
                if not model.working():
                    warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
                    warnings.warn(DEBUG_MESSAGE)
                    model = None
            else:
                warnings.warn('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
                model = None
            res = {} if not osp.exists(tmp_file) else load(tmp_file)
            res = {k: v for k, v in res.items() if FAIL_MSG not in v}

            data = load(eval_file)
            data_un = data[~pd.isna(data['prediction'])]

            for idx in data['index']:
                ans = data.loc[data['index'] == idx, 'answer'].values[0]
                pred = str(data.loc[data['index'] == idx, 'prediction'].values[0])

                if extract_characters_regex(pred) == '':
                    extract_pred = extract_option(
                        model,
                        data.loc[data['index'] == idx].to_dict(orient='records')[0],
                        'Video-MME'
                    )
                    data.loc[idx, 'score'] = int(extract_pred == ans)
                else:
                    data.loc[idx, 'score'] = int(extract_characters_regex(pred) == ans)

            rejected = [x for x in data['score'] if x == -1]

            print(
                f'Among {len(data)} questions, failed to obtain prediction for {len(data) - len(data_un)} questions, '
                f'failed to obtain the score for another {len(rejected)} questions. '
                f'Those questions will be counted as -1 score in ALL rating, and will not be counted in VALID rating.'
            )

            dump(data, score_file)

        rating = get_dimension_rating(score_file)
        dump(rating, tgt_file)
        return rating
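A hedged usage sketch of the class above; it assumes the lmms-lab/Video-MME data has already been downloaded and cached (instantiation triggers prepare_dataset), so it is illustrative rather than something to run blindly:

# Assumes the Video-MME TSV and videos are already cached locally.
dataset = VideoMME('Video-MME', use_subtitle=False)
msg = dataset.build_prompt(0, num_frames=8, video_llm=False, fps=-1)
# msg is a list of dicts: the (empty) system text, 8 frame images, the MCQ instruction, and the question block.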
VLMEvalKit/vlmeval/dataset/wildvision.py
ADDED
@@ -0,0 +1,218 @@
import re
from functools import partial

from .image_base import ImageBaseDataset
from .utils import build_judge, DEBUG_MESSAGE
from ..smp import *
from ..utils import track_progress_rich


SYSTEM_PROMPT = """\
Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user \
prompt displayed below. You will be given assistant A's answer and assistant B's answer. Your job is to evaluate \
which assistant's answer is better.

Begin your evaluation by generating your own answer to the prompt. You must provide your answers before judging any \
answers.

When evaluating the assistants' answers, compare both assistants' answers with your answer. \
You must identify and correct any mistakes or inaccurate information.

Then consider if the assistant's answers are helpful, relevant, and concise. Helpful means the answer correctly \
responds to the prompt or follows the instructions. Note when user prompt has any ambiguity or more than one \
interpretation, it is more helpful and appropriate to ask for clarifications or more information from the user than \
providing an answer based on assumptions. Relevant means all parts of the response closely connect or are appropriate \
to what is being asked. Concise means the response is clear and not verbose or excessive.

Then consider the creativity and novelty of the assistant's answers when needed. Finally, identify any missing \
important information in the assistants' answers that would be beneficial to include when responding to the user \
prompt.

After providing your explanation, you must output only one of the following choices as your final verdict with a label:

1. Assistant A is significantly better: [[A>>B]]
2. Assistant A is slightly better: [[A>B]]
3. Tie, relatively the same: [[A=B]]
4. Assistant B is slightly better: [[B>A]]
5. Assistant B is significantly better: [[B>>A]]

Example output: "My final verdict is tie: [[A=B]]".\
"""


PROMPT_TEMPLATE = """\
"<|User Prompt|>\n{question}

<|The Start of Assistant A's Answer|>\n{answer_1}\n<|The End of Assistant A's Answer|>

<|The Start of Assistant B's Answer|>\n{answer_2}\n<|The End of Assistant B's Answer|>
"""


REGEX_PATTERN = re.compile("\[\[([AB<>=]+)\]\]")  # noqa: W605


def get_score(judgement, pattern=REGEX_PATTERN):
    matches = pattern.findall(judgement)
    matches = [m for m in matches if m != ""]
    if len(set(matches)) == 0:
        return None, True
    elif len(set(matches)) == 1:
        return matches[0].strip("\n"), False
    else:
        return None, True


def WildVision_auxeval(model, line):
    config = dict(question=line['question'], answer_1=line['A'], answer_2=line['B'])
    prompt = PROMPT_TEMPLATE.format(**config)

    prefix = 'data:image/jpeg;base64,'
    img = prefix + line['image']

    messages = [
        dict(type='text', value=prompt),
        dict(type='image', value=img)
    ]

    retry = 2
    while retry:
        resp = model.generate(messages)
        score, try_again = get_score(resp)
        if not try_again:
            break
        retry -= 1

    if score is None:
        return 'Unknown'
    return score


class WildVision(ImageBaseDataset):
    TYPE = 'VQA'
    DATASET_URL = {
        'WildVision': 'https://opencompass.openxlab.space/utils/VLMEval/WildVision.tsv'
    }
    DATASET_MD5 = {'WildVision': 'b38f80156d49411c594772866b0d0b52'}

    score_map = {
        'A>>B': -2,
        'A>B': -1,
        'A=B': 0,
        'B>A': 1,
        'B>>A': 2
    }

    # Given one data record, return the built prompt (a multi-modal message), can override
    def build_prompt(self, line):
        if isinstance(line, int):
            line = self.data.iloc[line]

        if self.meta_only:
            tgt_path = toliststr(line['image_path'])
        else:
            tgt_path = self.dump_image(line)

        question = line['question']

        msgs = []
        if isinstance(tgt_path, list):
            msgs.extend([dict(type='image', value=p) for p in tgt_path])
        else:
            msgs = [dict(type='image', value=tgt_path)]
        # WildVision adopts text first
        msgs = [dict(type='text', value=question)] + msgs
        return msgs

    @classmethod
    def gen_eval_base(self, eval_file, b64_map):
        data = load(eval_file)
        data['B'] = data.pop('prediction')
        data['A'] = data.pop('claude3_sonnet')
        data['image'] = [b64_map[x] for x in data['index']]
        return data
        # rev = cp.deepcopy(data)
        # rev['A'] = data['B']
        # rev['B'] = data['A']
        # rev['index'] = [x + '_rev' for x in data['index']]
        # return pd.concat([data, rev], ignore_index=True)

    # It returns a DataFrame
    @classmethod
    def evaluate(self, eval_file, **judge_kwargs):
        # We adopt pairwise evaluation (twice for a pair) for this dataset
        suffix = eval_file.split('.')[-1]
        model = judge_kwargs['model']
        storage = eval_file.replace(f'.{suffix}', f'_{model}.xlsx')
        score_file = eval_file.replace(f'.{suffix}', f'_{model}_score.csv')
        tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
        nproc = judge_kwargs.pop('nproc', 4)

        if not osp.exists(storage):
            raw_data = WildVision('WildVision').data
            b64_map = {x: y for x, y in zip(raw_data['index'], raw_data['image'])}
            data = self.gen_eval_base(eval_file, b64_map)

            judge_kwargs['system_prompt'] = SYSTEM_PROMPT
            judge_kwargs['temperature'] = 0
            judge_kwargs['img_detail'] = 'high'
            judge_kwargs['timeout'] = 300
            model = build_judge(max_tokens=4096, **judge_kwargs)

            assert model.working(), ('WildVision evaluation requires a working OPENAI API\n' + DEBUG_MESSAGE)

            lt = len(data)
            lines = [data.iloc[i] for i in range(lt)]
            tups = [(model, line) for line in lines]
            indices = [line['index'] for line in lines]

            ans = load(tmp_file) if osp.exists(tmp_file) else {}
            tups = [x for x, i in zip(tups, indices) if i not in ans]
            indices = [i for i in indices if i not in ans]

            if len(indices):
                new_results = track_progress_rich(
                    WildVision_auxeval,
                    tups,
                    nproc=nproc,
                    chunksize=nproc,
                    keys=indices,
                    save=tmp_file,
                )
                ans = load(tmp_file)
                for k, v in zip(indices, new_results):
                    ans[k] = v

            data['score'] = [ans[idx] for idx in data['index']]
            data.pop('image')
            dump(data, storage)

        data = load(storage)
        lt = len(data)

        scores = defaultdict(lambda: 0)
        for i in range(lt):
            item = data.iloc[i]
            if item['score'] not in self.score_map:
                score = 0
            else:
                score = self.score_map[item['score']]
                if '_rev' in item['index']:
                    score = -score
            scores[score] += 1
        name_map = {
            2: 'Much Better',
            1: 'Better',
            0: 'Tie',
            -1: 'Worse',
            -2: 'Much Worse'
        }
        scores = {name_map[k]: v for k, v in scores.items()}
        scores['Reward'] = (
            100 * scores['Much Better'] + 50 * scores['Better'] - 50 * scores['Worse'] - 100 * scores['Much Worse']
        ) / lt
        scores['Win Rate'] = (scores['Better'] + scores['Much Better']) / lt
        scores = {k: [v] for k, v in scores.items()}
        scores = pd.DataFrame(scores)
        dump(scores, score_file)
        return scores
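A small illustration of how get_score and WildVision.score_map interact, using a made-up judge response (assumes the definitions above are in scope):

judgement = "Assistant B cites the image more precisely. My final verdict is: [[B>A]]"
verdict, need_retry = get_score(judgement)
print(verdict, need_retry)            # 'B>A' False
print(WildVision.score_map[verdict])  # 1, i.e. the evaluated model ("B") is slightly better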
VLMEvalKit/vlmeval/smp/__init__.py
ADDED
@@ -0,0 +1,4 @@
from .file import *
from .vlm import *
from .misc import *
from .log import *
VLMEvalKit/vlmeval/smp/log.py
ADDED
@@ -0,0 +1,47 @@
import logging
logging.basicConfig(
    format='[%(asctime)s] %(levelname)s - %(filename)s: %(funcName)s - %(lineno)d: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S')

logger_initialized = {}


def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
    logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger

    for logger_name in logger_initialized:
        if name.startswith(logger_name):
            return logger

    stream_handler = logging.StreamHandler()
    handlers = [stream_handler]

    try:
        import torch.distributed as dist
        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()
        else:
            rank = 0
    except ImportError:
        rank = 0

    if rank == 0 and log_file is not None:
        file_handler = logging.FileHandler(log_file, file_mode)
        handlers.append(file_handler)

    formatter = logging.Formatter(
        '[%(asctime)s] %(levelname)s - %(name)s - %(filename)s: %(funcName)s - %(lineno)d: %(message)s')
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    if rank == 0:
        logger.setLevel(log_level)
    else:
        logger.setLevel(logging.ERROR)

    logger_initialized[name] = True
    return logger
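A usage sketch for the helper above (the logger name is arbitrary; assumes the module above is imported):

logger = get_logger('vlmeval.demo', log_file=None, log_level=logging.INFO)
logger.info('evaluation started')  # emitted by the stream handler; on non-zero ranks the level is raised to ERROR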
VLMEvalKit/vlmeval/smp/misc.py
ADDED
@@ -0,0 +1,280 @@
# flake8: noqa: F401, F403
import abc
import argparse
import csv
import multiprocessing as mp
import os
import os.path as osp
from pathlib import Path
import copy as cp
import random as rd
import requests
import shutil
import subprocess
import warnings
import pandas as pd
from collections import OrderedDict, defaultdict
from multiprocessing import Pool, current_process
from tqdm import tqdm
import datetime
import matplotlib.pyplot as plt
from tabulate import tabulate
from json import JSONDecoder
from huggingface_hub import scan_cache_dir
from huggingface_hub.utils._cache_manager import _scan_cached_repo
from sty import fg, bg, ef, rs


def modelscope_flag_set():
    return os.environ.get('VLMEVALKIT_USE_MODELSCOPE', None) in ['1', 'True']


def process_punctuation(inText):
    import re
    outText = inText
    punct = [
        ';', r'/', '[', ']', '"', '{', '}', '(', ')', '=', '+', '\\', '_', '-',
        '>', '<', '@', '`', ',', '?', '!'
    ]
    commaStrip = re.compile('(\d)(,)(\d)')  # noqa: W605
    periodStrip = re.compile('(?!<=\d)(\.)(?!\d)')  # noqa: W605
    for p in punct:
        if (p + ' ' in inText or ' ' + p in inText) or (re.search(
                commaStrip, inText) is not None):
            outText = outText.replace(p, '')
        else:
            outText = outText.replace(p, ' ')
    outText = periodStrip.sub('', outText, re.UNICODE)
    return outText


def h2r(value):
    if value[0] == '#':
        value = value[1:]
    assert len(value) == 6
    return tuple(int(value[i:i + 2], 16) for i in range(0, 6, 2))


def r2h(rgb):
    return '#%02x%02x%02x' % rgb


def colored(s, color):
    if isinstance(color, str):
        if hasattr(fg, color):
            return getattr(fg, color) + s + fg.rs
        color = h2r(color)
    return fg(*color) + s + fg.rs


def istype(s, type):
    if isinstance(s, type):
        return True
    try:
        return isinstance(eval(s), type)
    except Exception as _:
        return False


def bincount(lst):
    bins = defaultdict(lambda: 0)
    for item in lst:
        bins[item] += 1
    return bins


def get_cache_path(repo_id, branch='main', repo_type='datasets'):
    try:
        if modelscope_flag_set():
            from modelscope.hub.file_download import create_temporary_directory_and_cache
            if repo_type == 'datasets':
                repo_type = 'dataset'
            _, cache = create_temporary_directory_and_cache(model_id=repo_id, repo_type=repo_type)
            cache_path = cache.get_root_location()
            return cache_path
        else:
            from .file import HFCacheRoot
            cache_path = HFCacheRoot()
            org, repo_name = repo_id.split('/')
            repo_path = Path(osp.join(cache_path, f'{repo_type}--{org}--{repo_name}/'))
            hf_cache_info = _scan_cached_repo(repo_path=repo_path)
            revs = {r.refs: r for r in hf_cache_info.revisions}
            if branch is not None:
                revs = {refs: r for refs, r in revs.items() if branch in refs}
            rev2keep = max(revs.values(), key=lambda r: r.last_modified)
            return str(rev2keep.snapshot_path)
    except Exception as e:
        import logging
        logging.warning(f'{type(e)}: {e}')
        return None


def proxy_set(s):
    import os
    for key in ['http_proxy', 'HTTP_PROXY', 'https_proxy', 'HTTPS_PROXY']:
        os.environ[key] = s


def get_rank_and_world_size():
    rank = int(os.environ.get('RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    return rank, world_size


def splitlen(s, sym='/'):
    return len(s.split(sym))


def listinstr(lst, s):
    assert isinstance(lst, list)
    for item in lst:
        if item in s:
            return True
    return False


def d2df(D):
    return pd.DataFrame({x: [D[x]] for x in D})


def cn_string(s):
    import re
    if re.search(u'[\u4e00-\u9fff]', s):
        return True
    return False


try:
    import decord
except ImportError:
    pass


def timestr(granularity='second'):
    s = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
    assert granularity in ['second', 'minute', 'hour', 'day']
    if granularity == 'second':
        return s
    elif granularity == 'minute':
        return s[:-2]
    elif granularity == 'hour':
        return s[:-4]
    elif granularity == 'day':
        return s[:-6]


def _minimal_ext_cmd(cmd, cwd=None):
    env = {}
    for k in ['SYSTEMROOT', 'PATH', 'HOME']:
        v = os.environ.get(k)
        if v is not None:
            env[k] = v
    env['LANGUAGE'] = 'C'
    env['LANG'] = 'C'
    env['LC_ALL'] = 'C'
    out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env, cwd=cwd).communicate()[0]
    return out


def githash(fallback='unknown', digits=8):
    if digits is not None and not isinstance(digits, int):
        raise TypeError('digits must be None or an integer')
    try:
        import vlmeval
    except ImportError as e:
        import logging
        logging.error(f'ImportError: {str(e)}')
        return fallback
    try:
        out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'], cwd=vlmeval.__path__[0])
        sha = out.strip().decode('ascii')
        if digits is not None:
            sha = sha[:digits]
    except OSError:
        sha = fallback
    return sha


def dict_merge(dct, merge_dct):
    for k, _ in merge_dct.items():
        if (k in dct and isinstance(dct[k], dict) and isinstance(merge_dct[k], dict)):  # noqa
            dict_merge(dct[k], merge_dct[k])
        else:
            dct[k] = merge_dct[k]


def youtube_dl(idx):
    cmd = f'youtube-dl -f best -f mp4 "{idx}" -o {idx}.mp4'
    os.system(cmd)


def run_command(cmd):
    if isinstance(cmd, str):
        cmd = cmd.split()
    return subprocess.check_output(cmd).decode()


def load_env():
    import logging
    logging.basicConfig(
        format='[%(asctime)s] %(levelname)s - %(filename)s: %(funcName)s - %(lineno)d: %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S')

    try:
        import vlmeval
    except ImportError:
        logging.error('VLMEval is not installed. Failed to import environment variables from .env file. ')
        return
    pth = osp.realpath(vlmeval.__path__[0])
    pth = osp.join(pth, '../.env')
    pth = osp.realpath(pth)
    if not osp.exists(pth):
        logging.error(f'Did not detect the .env file at {pth}, failed to load. ')
        return

    from dotenv import dotenv_values
    values = dotenv_values(pth)
    for k, v in values.items():
        if v is not None and len(v):
            os.environ[k] = v
    logging.info(f'API Keys successfully loaded from {pth}')


def pip_install_robust(package):
    import sys
    retry = 3
    while retry > 0:
        try:
            package_base = package.split('=')[0]
            module = __import__(package)
            return True
        except ImportError:
            subprocess.check_call([sys.executable, '-m', 'pip', 'install', package])
            retry -= 1
    return False


def version_cmp(v1, v2, op='eq'):
    from packaging import version
    import operator
    op_func = getattr(operator, op)
    return op_func(version.parse(v1), version.parse(v2))


def toliststr(s):
    if isinstance(s, str) and (s[0] == '[') and (s[-1] == ']'):
        return [str(x) for x in eval(s)]
    elif isinstance(s, str):
        return [s]
    elif isinstance(s, list):
        return [str(x) for x in s]
    raise NotImplementedError


def extract_json_objects(text, decoder=JSONDecoder()):
    pos = 0
    while True:
        match = text.find('{', pos)
        if match == -1:
            break
        try:
            result, index = decoder.raw_decode(text[match:])
            yield result
            pos = match + index
        except ValueError:
            pos = match + 1


def get_gpu_memory():
    import subprocess
    try:
        command = "nvidia-smi --query-gpu=memory.free --format=csv"
        memory_free_info = subprocess.check_output(command.split()).decode('ascii').split('\n')[:-1][1:]
        memory_free_values = [int(x.split()[0]) for i, x in enumerate(memory_free_info)]
        return memory_free_values
    except Exception as e:
        print(f'{type(e)}: {str(e)}')
        return []


def auto_split_flag():
    flag = os.environ.get('AUTO_SPLIT', '0')
    return flag == '1'
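A few of the helpers above on toy inputs (the values are illustrative only; assumes the module above is imported):

print(toliststr("['a.jpg', 'b.jpg']"))                 # ['a.jpg', 'b.jpg']
print(listinstr(['ChartQA', 'DocVQA'], 'DocVQA_VAL'))  # True
print(splitlen('lmms-lab/Video-MME'))                  # 2
print(d2df({'model': 'demo', 'acc': 61.2}))            # single-row DataFrame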
VLMEvalKit/vlmeval/vlm/aria.py
ADDED
@@ -0,0 +1,206 @@
import torch
import warnings
import copy as cp
from PIL import Image
import pandas as pd
import string
import re
from .base import BaseModel
from ..smp import isimg, listinstr, cn_string
from ..dataset import DATASET_TYPE, DATASET_MODALITY


class Aria(BaseModel):

    INSTALL_REQ = False
    INTERLEAVE = True

    def __init__(self, model_path='rhymes-ai/Aria', **kwargs):
        from transformers import AutoModelForCausalLM, AutoProcessor
        assert model_path is not None
        self.model_path = model_path
        processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
        tokenizer = processor.tokenizer
        tokenizer.padding_side = 'left'
        tokenizer.pad_token_id = tokenizer.unk_token_id
        self.processor = processor
        self.tokenizer = tokenizer
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map='cuda',
            torch_dtype=torch.bfloat16,
            trust_remote_code=True
        ).eval()
        default_kwargs = dict(
            do_sample=False,
            num_beams=1,
            max_new_tokens=512,
            min_new_tokens=1,
            num_return_sequences=1,
            use_cache=True,
            output_hidden_states=True,
            pad_token_id=tokenizer.unk_token_id,
            stop_strings=["<|im_end|>"],
            tokenizer=processor.tokenizer,
        )
        default_kwargs.update(kwargs)
        self.kwargs = default_kwargs
        warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')
        torch.cuda.empty_cache()

    def use_custom_prompt(self, dataset):
        assert dataset is not None
        if listinstr(['MMDU', 'MME-RealWorld', 'MME-RealWorld-CN'], dataset):
            # For Multi-Turn we don't have custom prompt
            return False
        if DATASET_MODALITY(dataset) == 'VIDEO':
            # For Video benchmarks we don't have custom prompt at here
            return False
        else:
            return True

    def build_prompt(self, line, dataset=None):
        assert self.use_custom_prompt(dataset)
        assert dataset is None or isinstance(dataset, str)
        tgt_path = self.dump_image(line, dataset)

        question = line['question']
        hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
        if hint is not None:
            question = hint + '\n' + question

        options = {
            cand: line[cand]
            for cand in string.ascii_uppercase
            if cand in line and not pd.isna(line[cand])
        }
        for key, item in options.items():
            question += f'\n{key}. {item}'
        prompt = question

        if len(options):
            prompt += (
                "\nAnswer with the option's letter from the given choices directly."
            )
        else:
            if listinstr(['MathVista', 'MathVision', 'VCR', 'MTVQA', 'MMVet', 'MathVerse'], dataset):
                prompt = prompt
            elif listinstr(['LLaVABench', 'MMBench-Video'], dataset):
                prompt += '\nAnswer this question in detail.'
            elif listinstr(['DocVQA'], dataset):
                prompt += '\nAnswer briefly and directly.'
            else:
                prompt += '\nAnswer the question using a single word or phrase.'

        message = [dict(type='image', value=s) for s in tgt_path]
        message.append(dict(type='text', value=prompt))
        return message

    def build_video_prompt(self, prompt, dataset=None):
        if listinstr(['MMBench-Video'], dataset):
            prompt = prompt.replace('\nAnswer:', '')
            prompt = prompt.replace(
                'Question: ',
                'Please carefully check the video and then answer the following question with details:'
            )
        elif listinstr(['Video-MME'], dataset):
            prompt = prompt.replace('\nAnswer:', '')
            prompt += "\nAnswer with the option's letter from the given choices directly."
        elif listinstr(['MVBench'], dataset):
            prompt = prompt.replace('Best option:(', '')
            system_prompt = 'Carefully watch the video and pay attention to the cause and sequence of events, the detail and movement of objects, and the action and pose of persons. Based on your observations, select the best option that accurately addresses the question.\n'  # noqa: E501
            prompt = prompt.replace(system_prompt, '')

        return prompt

    def adjust_kwargs(self, dataset):
        kwargs = cp.deepcopy(self.kwargs)
        kwargs["temperature"] = 0.0
        kwargs["do_sample"] = False

        if DATASET_MODALITY(dataset) == "VIDEO":
            kwargs["max_image_size"] = 490
        else:
            kwargs["max_image_size"] = 980

        kwargs["split_image"] = False

        if listinstr(['MMMU', 'MMStar', 'Math'], dataset):
            # These datasets may lead the model to work as a CoT-alike behaviour.
            # Allow to output longer.
            kwargs['max_new_tokens'] = 512
            return kwargs
        if DATASET_TYPE(dataset) in ['MCQ', 'Y/N']:
            kwargs['max_new_tokens'] = 64
        elif DATASET_TYPE(dataset) == 'Caption' and 'COCO' in dataset:
            kwargs['max_new_tokens'] = 64
        elif DATASET_TYPE(dataset) == 'VQA':
            if listinstr(['OCRVQA', 'ChartQA', 'DocVQA'], dataset):
                kwargs['max_new_tokens'] = 128
            elif listinstr(['TextVQA'], dataset):
                kwargs['max_new_tokens'] = 32

        if listinstr(['OCR', 'ChartQA', 'DocVQA', 'InfoVQA', 'TextVQA'], dataset):
            # OCR-related datasets that need to split image
            kwargs["split_image"] = True

        return kwargs

    def generate_inner(self, message, dataset=None):
        if dataset is not None:
            kwargs = self.adjust_kwargs(dataset)
        else:
            kwargs = self.kwargs

        max_image_size = kwargs.pop("max_image_size")
        split_image = kwargs.pop("split_image")

        prompt = '<|im_start|>user\n'
        images = []
        last_message_modality = "text"

        if listinstr(['MLVU', 'TempCompass', 'MVBench'], dataset):  # re-arrange the data
            new_message = []
            for s in message:
                if s['type'] == 'image':
                    new_message.append(s)
            for s in message:
                if s['type'] == 'text':
                    new_message.append(s)
            message = new_message

        for s in message:
            if s['type'] == 'image':
                prompt += '<fim_prefix><|img|><fim_suffix>'
                images.append(s['value'])
                last_message_modality = "image"
            elif s['type'] == 'text':
                text = re.sub(r"<image \d+>", "", s["value"])
                if last_message_modality == "image":
                    prompt += "\n"
                    last_message_modality = "text"
                prompt += text

        if DATASET_MODALITY(dataset) == 'VIDEO':
            prompt = self.build_video_prompt(prompt, dataset)

        prompt += '<|im_end|>\n<|im_start|>assistant\n'
        if images:
            images = [Image.open(s).convert('RGB') for s in images]
            encoded = self.processor(
                text=prompt,
                images=images,
                return_tensors='pt',
                padding='longest',
                max_image_size=max_image_size,
                split_image=split_image,
            )
        else:
            encoded = self.processor(text=prompt, return_tensors='pt', padding='longest')
        encoded["pixel_values"] = encoded["pixel_values"].to(self.model.dtype)
        encoded = {k: v.to(self.model.device) for k, v in encoded.items()}

        pred = self.model.generate(**encoded, **kwargs)
        answer = self.tokenizer.decode(pred[0][encoded['input_ids'].size(1):].cpu(), skip_special_tokens=True).strip()
        answer = answer.replace('<|im_end|>', '')
        return answer
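A hedged usage sketch for the wrapper above; it downloads the rhymes-ai/Aria weights, needs a CUDA GPU, and the image path and dataset name are placeholders:

model = Aria('rhymes-ai/Aria', max_new_tokens=64)
message = [
    dict(type='image', value='demo.jpg'),              # hypothetical local image
    dict(type='text', value='What is shown in the image?'),
]
print(model.generate_inner(message, dataset='MMVet'))  # dataset drives adjust_kwargs (max_image_size, token budget)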