tuandunghcmut committed
Commit c85c9e5 · verified · 1 parent: a98cf39

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full set.
Files changed (50)
  1. VLMEvalKit/vlmeval/api/__init__.py +26 -0
  2. VLMEvalKit/vlmeval/api/bailingmm.py +90 -0
  3. VLMEvalKit/vlmeval/api/base.py +289 -0
  4. VLMEvalKit/vlmeval/api/bluelm_v_api.py +120 -0
  5. VLMEvalKit/vlmeval/api/claude.py +130 -0
  6. VLMEvalKit/vlmeval/api/cloudwalk.py +107 -0
  7. VLMEvalKit/vlmeval/api/gemini.py +116 -0
  8. VLMEvalKit/vlmeval/api/glm_vision.py +95 -0
  9. VLMEvalKit/vlmeval/api/gpt.py +267 -0
  10. VLMEvalKit/vlmeval/api/hf_chat_model.py +246 -0
  11. VLMEvalKit/vlmeval/api/hunyuan.py +147 -0
  12. VLMEvalKit/vlmeval/api/jt_vl_chat.py +239 -0
  13. VLMEvalKit/vlmeval/api/qwen_api.py +75 -0
  14. VLMEvalKit/vlmeval/api/qwen_vl_api.py +219 -0
  15. VLMEvalKit/vlmeval/api/reka.py +60 -0
  16. VLMEvalKit/vlmeval/api/sensechat_vision.py +261 -0
  17. VLMEvalKit/vlmeval/api/siliconflow.py +269 -0
  18. VLMEvalKit/vlmeval/api/stepai.py +87 -0
  19. VLMEvalKit/vlmeval/api/taiyi.py +192 -0
  20. VLMEvalKit/vlmeval/dataset/__init__.py +230 -0
  21. VLMEvalKit/vlmeval/dataset/cmmmu.py +354 -0
  22. VLMEvalKit/vlmeval/dataset/dude.py +211 -0
  23. VLMEvalKit/vlmeval/dataset/dynamath.py +240 -0
  24. VLMEvalKit/vlmeval/dataset/image_base.py +172 -0
  25. VLMEvalKit/vlmeval/dataset/image_caption.py +75 -0
  26. VLMEvalKit/vlmeval/dataset/image_mcq.py +899 -0
  27. VLMEvalKit/vlmeval/dataset/image_mt.py +128 -0
  28. VLMEvalKit/vlmeval/dataset/image_vqa.py +1330 -0
  29. VLMEvalKit/vlmeval/dataset/image_yorn.py +95 -0
  30. VLMEvalKit/vlmeval/dataset/longvideobench.py +328 -0
  31. VLMEvalKit/vlmeval/dataset/miabench.py +167 -0
  32. VLMEvalKit/vlmeval/dataset/mlvu.py +455 -0
  33. VLMEvalKit/vlmeval/dataset/mmbench_video.py +256 -0
  34. VLMEvalKit/vlmeval/dataset/mmgenbench.py +69 -0
  35. VLMEvalKit/vlmeval/dataset/mmlongbench.py +584 -0
  36. VLMEvalKit/vlmeval/dataset/mmmath.py +446 -0
  37. VLMEvalKit/vlmeval/dataset/mvbench.py +668 -0
  38. VLMEvalKit/vlmeval/dataset/slidevqa.py +189 -0
  39. VLMEvalKit/vlmeval/dataset/tempcompass.py +639 -0
  40. VLMEvalKit/vlmeval/dataset/text_base.py +88 -0
  41. VLMEvalKit/vlmeval/dataset/text_mcq.py +123 -0
  42. VLMEvalKit/vlmeval/dataset/vcr.py +335 -0
  43. VLMEvalKit/vlmeval/dataset/video_base.py +126 -0
  44. VLMEvalKit/vlmeval/dataset/video_concat_dataset.py +83 -0
  45. VLMEvalKit/vlmeval/dataset/videomme.py +287 -0
  46. VLMEvalKit/vlmeval/dataset/wildvision.py +218 -0
  47. VLMEvalKit/vlmeval/smp/__init__.py +4 -0
  48. VLMEvalKit/vlmeval/smp/log.py +47 -0
  49. VLMEvalKit/vlmeval/smp/misc.py +280 -0
  50. VLMEvalKit/vlmeval/vlm/aria.py +206 -0
VLMEvalKit/vlmeval/api/__init__.py ADDED
@@ -0,0 +1,26 @@
1
+ from .gpt import OpenAIWrapper, GPT4V
2
+ from .hf_chat_model import HFChatModel
3
+ from .gemini import GeminiWrapper, GeminiProVision
4
+ from .qwen_vl_api import QwenVLWrapper, QwenVLAPI, Qwen2VLAPI
5
+ from .qwen_api import QwenAPI
6
+ from .claude import Claude_Wrapper, Claude3V
7
+ from .reka import Reka
8
+ from .glm_vision import GLMVisionAPI
9
+ from .cloudwalk import CWWrapper
10
+ from .sensechat_vision import SenseChatVisionAPI
11
+ from .siliconflow import SiliconFlowAPI, TeleMMAPI
12
+ from .hunyuan import HunyuanVision
13
+ from .bailingmm import bailingMMAPI
14
+ from .bluelm_v_api import BlueLMWrapper, BlueLM_V_API
15
+ from .jt_vl_chat import JTVLChatAPI
16
+ from .taiyi import TaiyiAPI
17
+
18
+
19
+ __all__ = [
20
+ 'OpenAIWrapper', 'HFChatModel', 'GeminiWrapper', 'GPT4V',
21
+ 'GeminiProVision', 'QwenVLWrapper', 'QwenVLAPI', 'QwenAPI',
22
+ 'Claude3V', 'Claude_Wrapper', 'Reka', 'GLMVisionAPI',
23
+ 'CWWrapper', 'SenseChatVisionAPI', 'HunyuanVision', 'Qwen2VLAPI',
24
+ 'BlueLMWrapper', 'BlueLM_V_API', 'JTVLChatAPI', 'bailingMMAPI',
25
+ 'TaiyiAPI', 'TeleMMAPI', 'SiliconFlowAPI'
26
+ ]
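The module above only re-exports the API wrappers; for orientation, a minimal usage sketch follows (it assumes the package is importable as `vlmeval` and that a valid OPENAI_API_KEY is set — neither is part of this commit, and the model name and prompt are placeholders):

# Hypothetical usage of the exports above; not part of the committed code.
from vlmeval.api import OpenAIWrapper

model = OpenAIWrapper(model='gpt-4o', temperature=0)   # model name is illustrative
# `generate` accepts a plain string or a list of {'type', 'value'} dicts (see base.py below)
answer = model.generate('Name one benchmark commonly used to evaluate vision-language models.')
print(answer)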
VLMEvalKit/vlmeval/api/bailingmm.py ADDED
@@ -0,0 +1,90 @@
1
+ import base64
2
+ from vlmeval.smp import *
3
+ from vlmeval.api.base import BaseAPI
4
+ from vlmeval.dataset import DATASET_TYPE
5
+ from vlmeval.smp.vlm import encode_image_file_to_base64
6
+ import time
7
+
8
+
9
+ class bailingMMWrapper(BaseAPI):
10
+
11
+ is_api: bool = True
12
+
13
+ def __init__(self,
14
+ model: str,
15
+ retry: int = 5,
16
+ wait: int = 5,
17
+ key: str = None,
18
+ verbose: bool = True,
19
+ system_prompt: str = None,
20
+ max_tokens: int = 1024,
21
+ proxy: str = None,
22
+ **kwargs):
23
+
24
+ self.model = model
25
+ self.fail_msg = 'Failed to obtain answer via bailingMM API.'
26
+ if key is None:
27
+ key = os.environ.get('BAILINGMM_API_KEY', None)
28
+ assert key is not None, ('Please set the API Key for bailingMM.')
29
+ self.key = key
30
+ self.headers = {"Content-Type": "application/json"}
31
+ super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)
32
+
33
+ def image_to_base64(self, image_path):
34
+ with open(image_path, 'rb') as image_file:
35
+ encoded_string = str(base64.b64encode(image_file.read()), 'utf-8')
36
+ return encoded_string
37
+
38
+ def prepare_inputs(self, inputs):
39
+ msgs = cp.deepcopy(inputs)
40
+ content = []
41
+ for i, msg in enumerate(msgs):
42
+ if msg['type'] == 'text':
43
+ pass
44
+ else:
45
+ try:
46
+ image_data = self.image_to_base64(msg['value'])
47
+ except Exception as e:
48
+ if self.verbose:
49
+ self.logger.error(e)
50
+ image_data = ''
51
+ msg['value'] = image_data
52
+ content.append(msg)
53
+ return content
54
+
55
+ def generate_inner(self, inputs, **kwargs) -> str:
56
+ assert isinstance(inputs, str) or isinstance(inputs, list)
57
+ start = time.time()
58
+ inputs = [inputs] if isinstance(inputs, str) else inputs
59
+
60
+ messages = self.prepare_inputs(inputs)
61
+
62
+ service_url = "https://bailingchat.alipay.com/api/proxy/eval/antgmm/completions"
63
+
64
+ payload = {
65
+ "structInput": messages,
66
+ "sk": self.key,
67
+ "timeout": 180000
68
+ }
69
+ response = requests.post(service_url, headers=self.headers, json=payload)
70
+ if self.verbose:
71
+ self.logger.info('Time used for the request (seconds):')
72
+ self.logger.info(time.time() - start)
73
+ try:
74
+ assert response.status_code == 200
75
+ output = json.loads(response.text)
76
+ answer = output['preds']['pred']
77
+ if self.verbose:
78
+ self.logger.info(f'inputs: {inputs}\nanswer: {answer}')
79
+ return 0, answer, 'Succeeded! '
80
+ except Exception as e:
81
+ if self.verbose:
82
+ self.logger.error(e)
83
+ self.logger.error(f'The input messages are {inputs}.')
84
+ return -1, self.fail_msg, ''
85
+
86
+
87
+ class bailingMMAPI(bailingMMWrapper):
88
+
89
+ def generate(self, message, dataset=None):
90
+ return super(bailingMMAPI, self).generate(message, dataset=dataset)
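For context, `prepare_inputs` above consumes the interleaved message format used throughout VLMEvalKit (a list of dicts with `type` and `value` keys); a hedged sketch of calling this wrapper follows, with the image path, model name, and key purely as placeholders:

# Illustrative only: the path, model name, and key below are placeholders, not values from this commit.
message = [
    dict(type='image', value='/path/to/example.jpg'),         # local image file
    dict(type='text', value='What is shown in this image?'),
]
# api = bailingMMAPI(model='bailingMM', key='YOUR_KEY')        # or set BAILINGMM_API_KEY
# answer = api.generate(message)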
VLMEvalKit/vlmeval/api/base.py ADDED
@@ -0,0 +1,289 @@
1
+ import time
2
+ import random as rd
3
+ from abc import abstractmethod
4
+ import os.path as osp
5
+ import copy as cp
6
+ from ..smp import get_logger, parse_file, concat_images_vlmeval, LMUDataRoot, md5, decode_base64_to_image_file
7
+
8
+
9
+ class BaseAPI:
10
+
11
+ allowed_types = ['text', 'image']
12
+ INTERLEAVE = True
13
+ INSTALL_REQ = False
14
+
15
+ def __init__(self,
16
+ retry=10,
17
+ wait=3,
18
+ system_prompt=None,
19
+ verbose=True,
20
+ fail_msg='Failed to obtain answer via API.',
21
+ **kwargs):
22
+ """Base Class for all APIs.
23
+
24
+ Args:
25
+ retry (int, optional): The retry times for `generate_inner`. Defaults to 10.
26
+ wait (int, optional): The wait time after each failed retry of `generate_inner`. Defaults to 3.
27
+ system_prompt (str, optional): Defaults to None.
28
+ verbose (bool, optional): Defaults to True.
29
+ fail_msg (str, optional): The message to return when failed to obtain answer.
30
+ Defaults to 'Failed to obtain answer via API.'.
31
+ **kwargs: Other kwargs for `generate_inner`.
32
+ """
33
+
34
+ self.wait = wait
35
+ self.retry = retry
36
+ self.system_prompt = system_prompt
37
+ self.verbose = verbose
38
+ self.fail_msg = fail_msg
39
+ self.logger = get_logger('ChatAPI')
40
+
41
+ if len(kwargs):
42
+ self.logger.info(f'BaseAPI received the following kwargs: {kwargs}')
43
+ self.logger.info('Will try to use them as kwargs for `generate`. ')
44
+ self.default_kwargs = kwargs
45
+
46
+ @abstractmethod
47
+ def generate_inner(self, inputs, **kwargs):
48
+ """The inner function to generate the answer.
49
+
50
+ Returns:
51
+ tuple(int, str, str): ret_code, response, log
52
+ """
53
+ self.logger.warning('For BaseAPI, generate_inner is an abstract method. ')
54
+ assert 0, 'generate_inner not defined'
55
+ ret_code, answer, log = None, None, None
56
+ # if ret_code is 0, means succeed
57
+ return ret_code, answer, log
58
+
59
+ def working(self):
60
+ """If the API model is working, return True, else return False.
61
+
62
+ Returns:
63
+ bool: If the API model is working, return True, else return False.
64
+ """
65
+ self.old_timeout = None
66
+ if hasattr(self, 'timeout'):
67
+ self.old_timeout = self.timeout
68
+ self.timeout = 120
69
+
70
+ retry = 5
71
+ while retry > 0:
72
+ ret = self.generate('hello')
73
+ if ret is not None and ret != '' and self.fail_msg not in ret:
74
+ if self.old_timeout is not None:
75
+ self.timeout = self.old_timeout
76
+ return True
77
+ retry -= 1
78
+
79
+ if self.old_timeout is not None:
80
+ self.timeout = self.old_timeout
81
+ return False
82
+
83
+ def check_content(self, msgs):
84
+ """Check the content type of the input. Four types are allowed: str, dict, liststr, listdict.
85
+
86
+ Args:
87
+ msgs: Raw input messages.
88
+
89
+ Returns:
90
+ str: The message type.
91
+ """
92
+ if isinstance(msgs, str):
93
+ return 'str'
94
+ if isinstance(msgs, dict):
95
+ return 'dict'
96
+ if isinstance(msgs, list):
97
+ types = [self.check_content(m) for m in msgs]
98
+ if all(t == 'str' for t in types):
99
+ return 'liststr'
100
+ if all(t == 'dict' for t in types):
101
+ return 'listdict'
102
+ return 'unknown'
103
+
104
+ def preproc_content(self, inputs):
105
+ """Convert the raw input messages to a list of dicts.
106
+
107
+ Args:
108
+ inputs: raw input messages.
109
+
110
+ Returns:
111
+ list(dict): The preprocessed input messages. Will return None if failed to preprocess the input.
112
+ """
113
+ if self.check_content(inputs) == 'str':
114
+ return [dict(type='text', value=inputs)]
115
+ elif self.check_content(inputs) == 'dict':
116
+ assert 'type' in inputs and 'value' in inputs
117
+ return [inputs]
118
+ elif self.check_content(inputs) == 'liststr':
119
+ res = []
120
+ for s in inputs:
121
+ mime, pth = parse_file(s)
122
+ if mime is None or mime == 'unknown':
123
+ res.append(dict(type='text', value=s))
124
+ else:
125
+ res.append(dict(type=mime.split('/')[0], value=pth))
126
+ return res
127
+ elif self.check_content(inputs) == 'listdict':
128
+ for item in inputs:
129
+ assert 'type' in item and 'value' in item
130
+ mime, s = parse_file(item['value'])
131
+ if mime is None:
132
+ assert item['type'] == 'text', item['value']
133
+ else:
134
+ assert mime.split('/')[0] == item['type']
135
+ item['value'] = s
136
+ return inputs
137
+ else:
138
+ return None
139
+
140
+ # May exceed the context window size, so retry with progressively fewer conversation turns.
141
+ def chat_inner(self, inputs, **kwargs):
142
+ _ = kwargs.pop('dataset', None)
143
+ while len(inputs):
144
+ try:
145
+ return self.generate_inner(inputs, **kwargs)
146
+ except Exception as e:
147
+ if self.verbose:
148
+ self.logger.info(f'{type(e)}: {e}')
149
+ inputs = inputs[1:]
150
+ while len(inputs) and inputs[0]['role'] != 'user':
151
+ inputs = inputs[1:]
152
+ continue
153
+ return -1, self.fail_msg + ': ' + 'Failed with all possible conversation turns.', None
154
+
155
+ def chat(self, messages, **kwargs1):
156
+ """The main function for multi-turn chatting. Will call `chat_inner` with the preprocessed input messages."""
157
+ assert hasattr(self, 'chat_inner'), 'The API model should have the `chat_inner` method. '
158
+ for msg in messages:
159
+ assert isinstance(msg, dict) and 'role' in msg and 'content' in msg, msg
160
+ assert self.check_content(msg['content']) in ['str', 'dict', 'liststr', 'listdict'], msg
161
+ msg['content'] = self.preproc_content(msg['content'])
162
+ # merge kwargs
163
+ kwargs = cp.deepcopy(self.default_kwargs)
164
+ kwargs.update(kwargs1)
165
+
166
+ answer = None
167
+ # a very small random delay [0s - 0.5s]
168
+ T = rd.random() * 0.5
169
+ time.sleep(T)
170
+
171
+ assert messages[-1]['role'] == 'user'
172
+
173
+ for i in range(self.retry):
174
+ try:
175
+ ret_code, answer, log = self.chat_inner(messages, **kwargs)
176
+ if ret_code == 0 and self.fail_msg not in answer and answer != '':
177
+ if self.verbose:
178
+ print(answer)
179
+ return answer
180
+ elif self.verbose:
181
+ if not isinstance(log, str):
182
+ try:
183
+ log = log.text
184
+ except Exception as e:
185
+ self.logger.warning(f'Failed to parse {log} as an http response: {str(e)}. ')
186
+ self.logger.info(f'RetCode: {ret_code}\nAnswer: {answer}\nLog: {log}')
187
+ except Exception as err:
188
+ if self.verbose:
189
+ self.logger.error(f'An error occurred during try {i}: ')
190
+ self.logger.error(f'{type(err)}: {err}')
191
+ # delay before each retry
192
+ T = rd.random() * self.wait * 2
193
+ time.sleep(T)
194
+
195
+ return self.fail_msg if answer in ['', None] else answer
196
+
197
+ def preprocess_message_with_role(self, message):
198
+ system_prompt = ''
199
+ new_message = []
200
+
201
+ for data in message:
202
+ assert isinstance(data, dict)
203
+ role = data.pop('role', 'user')
204
+ if role == 'system':
205
+ system_prompt += data['value'] + '\n'
206
+ else:
207
+ new_message.append(data)
208
+
209
+ if system_prompt != '':
210
+ if self.system_prompt is None:
211
+ self.system_prompt = system_prompt
212
+ else:
213
+ self.system_prompt += '\n' + system_prompt
214
+ return new_message
215
+
216
+ def generate(self, message, **kwargs1):
217
+ """The main function to generate the answer. Will call `generate_inner` with the preprocessed input messages.
218
+
219
+ Args:
220
+ message: raw input messages.
221
+
222
+ Returns:
223
+ str: The generated answer, or the fail message if no answer could be obtained.
224
+ """
225
+ if self.check_content(message) == 'listdict':
226
+ message = self.preprocess_message_with_role(message)
227
+
228
+ assert self.check_content(message) in ['str', 'dict', 'liststr', 'listdict'], f'Invalid input type: {message}'
229
+ message = self.preproc_content(message)
230
+ assert message is not None and self.check_content(message) == 'listdict'
231
+ for item in message:
232
+ assert item['type'] in self.allowed_types, f'Invalid input type: {item["type"]}'
233
+
234
+ # merge kwargs
235
+ kwargs = cp.deepcopy(self.default_kwargs)
236
+ kwargs.update(kwargs1)
237
+
238
+ answer = None
239
+ # a very small random delay [0s - 0.5s]
240
+ T = rd.random() * 0.5
241
+ time.sleep(T)
242
+
243
+ for i in range(self.retry):
244
+ try:
245
+ ret_code, answer, log = self.generate_inner(message, **kwargs)
246
+ if ret_code == 0 and self.fail_msg not in answer and answer != '':
247
+ if self.verbose:
248
+ print(answer)
249
+ return answer
250
+ elif self.verbose:
251
+ if not isinstance(log, str):
252
+ try:
253
+ log = log.text
254
+ except Exception as e:
255
+ self.logger.warning(f'Failed to parse {log} as an http response: {str(e)}. ')
256
+ self.logger.info(f'RetCode: {ret_code}\nAnswer: {answer}\nLog: {log}')
257
+ except Exception as err:
258
+ if self.verbose:
259
+ self.logger.error(f'An error occurred during try {i}: ')
260
+ self.logger.error(f'{type(err)}: {err}')
261
+ # delay before each retry
262
+ T = rd.random() * self.wait * 2
263
+ time.sleep(T)
264
+
265
+ return self.fail_msg if answer in ['', None] else answer
266
+
267
+ def message_to_promptimg(self, message, dataset=None):
268
+ assert not self.INTERLEAVE
269
+ model_name = self.__class__.__name__
270
+ import warnings
271
+ warnings.warn(
272
+ f'Model {model_name} does not support interleaved input. '
273
+ 'Will use the first image and aggregated texts as prompt. ')
274
+ num_images = len([x for x in message if x['type'] == 'image'])
275
+ if num_images == 0:
276
+ prompt = '\n'.join([x['value'] for x in message if x['type'] == 'text'])
277
+ image = None
278
+ elif num_images == 1:
279
+ prompt = '\n'.join([x['value'] for x in message if x['type'] == 'text'])
280
+ image = [x['value'] for x in message if x['type'] == 'image'][0]
281
+ else:
282
+ prompt = '\n'.join([x['value'] if x['type'] == 'text' else '<image>' for x in message])
283
+ if dataset == 'BLINK':
284
+ image = concat_images_vlmeval(
285
+ [x['value'] for x in message if x['type'] == 'image'],
286
+ target_size=512)
287
+ else:
288
+ image = [x['value'] for x in message if x['type'] == 'image'][0]
289
+ return prompt, image
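`BaseAPI` expects subclasses to implement `generate_inner` and return a `(ret_code, answer, log)` tuple, where `ret_code == 0` means success; the retry loop in `generate` keeps calling it until a non-empty answer without the fail message comes back. A minimal hypothetical subclass (the echo behaviour is purely illustrative) could look like:

class EchoAPI(BaseAPI):
    # Toy subclass for illustration: echoes back the text parts of the message.
    def generate_inner(self, inputs, **kwargs):
        try:
            text = '\n'.join(x['value'] for x in inputs if x['type'] == 'text')
            return 0, f'Echo: {text}', 'Succeeded! '   # ret_code 0 signals success
        except Exception as err:
            return -1, self.fail_msg, str(err)          # non-zero ret_code triggers another retry

# echo = EchoAPI(retry=2, wait=1)
# print(echo.generate('hello'))   # preprocessing, retries, and delays are handled by BaseAPI.generate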
VLMEvalKit/vlmeval/api/bluelm_v_api.py ADDED
@@ -0,0 +1,120 @@
1
+ from vlmeval.smp import *
2
+ from vlmeval.api.base import BaseAPI
3
+ import os
4
+ import json
5
+
6
+
7
+ def multimodal(images, text, url, key, temperature=0, max_tokens=1024, history=[]):
8
+ if images:
9
+ pics = []
10
+ for image in images:
11
+ with open(image, 'rb') as f:
12
+ pic = base64.b64encode(f.read()).decode('utf-8')
13
+ pics.append(pic)
14
+ data = {'images': pics, 'text': text, 'key': key, 'temperature': temperature, 'max_new_tokens': max_tokens}
15
+ else:
16
+ data = {'text': text, 'key': key, 'temperature': temperature, 'max_new_tokens': max_tokens}
17
+ response = requests.post(url, json=data, headers={'Content-Type': 'application/json'})
18
+ response = json.loads(response.text)
19
+ return response
20
+
21
+
22
+ class BlueLMWrapper(BaseAPI):
23
+ is_api: bool = True
24
+
25
+ def __init__(self,
26
+ model: str = 'BlueLM-V-v3.0',
27
+ retry: int = 5,
28
+ wait: int = 5,
29
+ verbose: bool = True,
30
+ temperature: float = 0.0,
31
+ system_prompt: str = None,
32
+ max_tokens: int = 1024,
33
+ key: str = None,
34
+ url: str = 'http://api-ai.vivo.com.cn/multimodal',
35
+ **kwargs):
36
+
37
+ self.model = model
38
+ self.fail_msg = 'Failed to obtain answer via BlueLM-V API. '
39
+ self.max_tokens = max_tokens
40
+ self.temperature = temperature
41
+ self.url = url
42
+ self.key = key
43
+
44
+ if self.key is None:
45
+ self.key = os.environ.get('BLUELM_V_API_KEY', None)
46
+ assert self.key is not None, (
47
+ 'Please set the BlueLM-V API Key. To obtain one, '
48
+ 'contact by email: [email protected]'
49
+ )
50
+
51
+ super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)
52
+
53
+ def message_to_promptimg(self, message, dataset=None):
54
+
55
+ num_images = len([x for x in message if x['type'] == 'image'])
56
+ if num_images == 0:
57
+ prompt = '\n'.join([x['value'] for x in message if x['type'] == 'text'])
58
+ image = None
59
+ elif num_images == 1:
60
+ prompt = '\n'.join([x['value'] for x in message if x['type'] == 'text'])
61
+ image = [x['value'] for x in message if x['type'] == 'image']
62
+ else:
63
+ prompt = '\n'.join([x['value'] if x['type'] == 'text' else '<image>' for x in message])
64
+ if dataset == 'BLINK':
65
+ image = concat_images_vlmeval(
66
+ [x['value'] for x in message if x['type'] == 'image'],
67
+ target_size=512)
68
+ else:
69
+ image = [x['value'] for x in message if x['type'] == 'image']
70
+
71
+ if dataset in ['MMBench_DEV_EN_V11', 'MMBench_DEV_CN_V11', 'MMBench_TEST_EN_V11', 'MMBench_TEST_CN_V11',
72
+ 'AI2D_TEST', 'AI2D_TEST_TO_MASK', 'MMMU_DEV_VAL']:
73
+ prompt = prompt.replace('Please select the correct answer from the options above.',
74
+ 'Answer with the option’s letter from the given choices directly.')
75
+ elif dataset in ['ChartQA_TEST']:
76
+ prompt = prompt.replace('Answer the question using a single word or phrase.',
77
+ 'Answer the question using a single number or phrase.')
78
+ elif dataset in ['DocVQA_VAL', 'DocVQA_TEST', ]:
79
+ prompt = prompt.replace('Answer the question using a single word or phrase.',
80
+ 'Give the short answer directly.')
81
+ elif dataset in ['TextVQA_VAL']:
82
+ prompt = prompt.replace('Answer the question using a single word or phrase.',
83
+ 'When the provided information is insufficient, respond with ’Unanswerable’.'
84
+ 'Answer the question using a single word or phrase.')
85
+ elif dataset in ['MTVQA_TEST']:
86
+ prompt = prompt.replace('\nAnswer the question using a word or phrase in the language of the question.', '')
87
+ elif dataset in ['MathVista_MINI']:
88
+ if 'Choices:' in prompt:
89
+ prompt = prompt.replace('Choices:', 'Options:').replace('Hint:', 'Context:')
90
+ for i in range(1, 7): # replace A ~ F
91
+ prompt = prompt.replace(f'({chr(64 + i)})', f'{chr(64 + i)}.')
92
+ prompt += '\nAnswer with the option’s letter from the given choices directly.'
93
+ else:
94
+ prompt += '\nAnswer the question using a single word or phrase.'
95
+
96
+ return prompt, image
97
+
98
+ def generate_inner(self, inputs, **kwargs) -> str:
99
+
100
+ assert isinstance(inputs, str) or isinstance(inputs, list)
101
+ pure_text = np.all([x['type'] == 'text' for x in inputs])
102
+ assert not pure_text
103
+
104
+ prompt, image_path = self.message_to_promptimg(inputs, kwargs['dataset'])
105
+
106
+ try:
107
+ response = multimodal(image_path, prompt, self.url, self.key, self.temperature, self.max_tokens)
108
+ answer = response['result']
109
+ return 0, answer, 'Succeeded! '
110
+ except Exception as err:
111
+ if self.verbose:
112
+ self.logger.error(f'{type(err)}: {err}')
113
+ self.logger.error(f'The input messages are {inputs}.')
114
+ return -1, '', ''
115
+
116
+
117
+ class BlueLM_V_API(BlueLMWrapper):
118
+
119
+ def generate(self, message, dataset=None):
120
+ return super(BlueLM_V_API, self).generate(message, dataset=dataset)
VLMEvalKit/vlmeval/api/claude.py ADDED
@@ -0,0 +1,130 @@
1
+ from vlmeval.smp import *
2
+ from vlmeval.api.base import BaseAPI
3
+ from time import sleep
4
+ import base64
5
+ import mimetypes
6
+ from PIL import Image
7
+
8
+ alles_url = 'https://openxlab.org.cn/gw/alles-apin-hub/v1/claude/v1/text/chat'
9
+ alles_headers = {
10
+ 'alles-apin-token': '',
11
+ 'Content-Type': 'application/json'
12
+ }
13
+ official_url = 'https://api.anthropic.com/v1/messages'
14
+ official_headers = {
15
+ 'x-api-key': '',
16
+ 'anthropic-version': '2023-06-01',
17
+ 'content-type': 'application/json'
18
+ }
19
+
20
+
21
+ class Claude_Wrapper(BaseAPI):
22
+
23
+ is_api: bool = True
24
+
25
+ def __init__(self,
26
+ backend: str = 'alles',
27
+ model: str = 'claude-3-opus-20240229',
28
+ key: str = None,
29
+ retry: int = 10,
30
+ wait: int = 3,
31
+ system_prompt: str = None,
32
+ verbose: bool = True,
33
+ temperature: float = 0,
34
+ max_tokens: int = 1024,
35
+ **kwargs):
36
+
37
+ if os.environ.get('ANTHROPIC_BACKEND', '') == 'official':
38
+ backend = 'official'
39
+
40
+ assert backend in ['alles', 'official'], f'Invalid backend: {backend}'
41
+ self.backend = backend
42
+ self.url = alles_url if backend == 'alles' else official_url
43
+ self.model = model
44
+ self.temperature = temperature
45
+ self.max_tokens = max_tokens
46
+ self.headers = alles_headers if backend == 'alles' else official_headers
47
+
48
+ if key is not None:
49
+ self.key = key
50
+ else:
51
+ self.key = os.environ.get('ALLES', '') if self.backend == 'alles' else os.environ.get('ANTHROPIC_API_KEY', '') # noqa: E501
52
+
53
+ if self.backend == 'alles':
54
+ self.headers['alles-apin-token'] = self.key
55
+ else:
56
+ self.headers['x-api-key'] = self.key
57
+
58
+ super().__init__(retry=retry, wait=wait, verbose=verbose, system_prompt=system_prompt, **kwargs)
59
+
60
+ # inputs can be a lvl-2 nested list: [content1, content2, content3, ...]
61
+ # content can be a string or a list of image & text
62
+ def prepare_itlist(self, inputs):
63
+ assert np.all([isinstance(x, dict) for x in inputs])
64
+ has_images = np.sum([x['type'] == 'image' for x in inputs])
65
+ if has_images:
66
+ content_list = []
67
+ for msg in inputs:
68
+ if msg['type'] == 'text' and msg['value'] != '':
69
+ content_list.append(dict(type='text', text=msg['value']))
70
+ elif msg['type'] == 'image':
71
+ pth = msg['value']
72
+ suffix = osp.splitext(pth)[-1].lower()
73
+ media_type = mimetypes.types_map.get(suffix, None)
74
+ assert media_type is not None
75
+
76
+ content_list.append(dict(
77
+ type='image',
78
+ source={
79
+ 'type': 'base64',
80
+ 'media_type': media_type,
81
+ 'data': encode_image_file_to_base64(pth, target_size=4096)
82
+ }))
83
+ else:
84
+ assert all([x['type'] == 'text' for x in inputs])
85
+ text = '\n'.join([x['value'] for x in inputs])
86
+ content_list = [dict(type='text', text=text)]
87
+ return content_list
88
+
89
+ def prepare_inputs(self, inputs):
90
+ input_msgs = []
91
+ assert isinstance(inputs, list) and isinstance(inputs[0], dict)
92
+ assert np.all(['type' in x for x in inputs]) or np.all(['role' in x for x in inputs]), inputs
93
+ if 'role' in inputs[0]:
94
+ assert inputs[-1]['role'] == 'user', inputs[-1]
95
+ for item in inputs:
96
+ input_msgs.append(dict(role=item['role'], content=self.prepare_itlist(item['content'])))
97
+ else:
98
+ input_msgs.append(dict(role='user', content=self.prepare_itlist(inputs)))
99
+ return input_msgs
100
+
101
+ def generate_inner(self, inputs, **kwargs) -> str:
102
+ payload = {
103
+ 'model': self.model,
104
+ 'max_tokens': self.max_tokens,
105
+ 'messages': self.prepare_inputs(inputs),
106
+ **kwargs
107
+ }
108
+ if self.system_prompt is not None:
109
+ payload['system'] = self.system_prompt
110
+
111
+ response = requests.request('POST', self.url, headers=self.headers, data=json.dumps(payload))
112
+ ret_code = response.status_code
113
+ ret_code = 0 if (200 <= int(ret_code) < 300) else ret_code
114
+ answer = self.fail_msg
115
+
116
+ try:
117
+ resp_struct = json.loads(response.text)
118
+ answer = resp_struct['data']['content'][0]['text'].strip()
119
+ except Exception as err:
120
+ if self.verbose:
121
+ self.logger.error(f'{type(err)}: {err}')
122
+ self.logger.error(response.text if hasattr(response, 'text') else response)
123
+
124
+ return ret_code, answer, response
125
+
126
+
127
+ class Claude3V(Claude_Wrapper):
128
+
129
+ def generate(self, message, dataset=None):
130
+ return super(Claude3V, self).generate(message)
VLMEvalKit/vlmeval/api/cloudwalk.py ADDED
@@ -0,0 +1,107 @@
1
+ from ..smp import *
2
+ import os
3
+ from .base import BaseAPI
4
+
5
+
6
+ class CWWrapper(BaseAPI):
7
+
8
+ is_api: bool = True
9
+
10
+ def __init__(self,
11
+ model: str = 'cw-congrong-v1.5',
12
+ retry: int = 10,
13
+ wait: int = 5,
14
+ key: str = None,
15
+ verbose: bool = True,
16
+ system_prompt: str = None,
17
+ temperature: float = 0,
18
+ timeout: int = 600,
19
+ api_base: str = 'http://cwapi-vlm01.cw_rb.azurebot.tk/v1/chat/completions',
20
+ max_tokens: int = 1024,
21
+ img_size: int = 512,
22
+ img_detail: str = 'low',
23
+ **kwargs):
24
+
25
+ self.model = model
26
+ self.cur_idx = 0
27
+ self.fail_msg = 'Failed to obtain answer via API. '
28
+ self.max_tokens = max_tokens
29
+ self.temperature = temperature
30
+
31
+ base = os.environ.get('CW_API_BASE', None)
32
+ self.api_base = base if base is not None else api_base
33
+
34
+ env_key = os.environ.get('CW_API_KEY', None)
35
+ self.key = env_key if env_key is not None else key
36
+ assert self.key is not None, 'API key not provided. Please set CW_API_KEY environment variable or \
37
+ pass it to the constructor.'
38
+
39
+ assert img_size > 0 or img_size == -1
40
+ self.img_size = -1 # always send the full-size image
41
+ assert img_detail in ['high', 'low']
42
+ self.img_detail = img_detail
43
+
44
+ self.vision = True
45
+ self.timeout = timeout
46
+
47
+ super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)
48
+
49
+ # inputs can be a lvl-2 nested list: [content1, content2, content3, ...]
50
+ # content can be a string or a list of image & text
51
+ def prepare_inputs(self, inputs):
52
+ input_msgs = []
53
+ if self.system_prompt is not None:
54
+ input_msgs.append(dict(role='system', content=self.system_prompt))
55
+ has_images = np.sum([x['type'] == 'image' for x in inputs])
56
+ if has_images:
57
+ content_list = []
58
+ for msg in inputs:
59
+ if msg['type'] == 'text':
60
+ content_list.append(dict(type='text', text=msg['value']))
61
+ elif msg['type'] == 'image':
62
+ from PIL import Image
63
+ img = Image.open(msg['value'])
64
+ b64 = encode_image_to_base64(img, target_size=self.img_size)
65
+ img_struct = dict(url=f"data:image/jpeg;base64,{b64}", detail=self.img_detail)
66
+ content_list.append(dict(type='image_url', image_url=img_struct))
67
+ input_msgs.append(dict(role='user', content=content_list))
68
+ else:
69
+ assert all([x['type'] == 'text' for x in inputs])
70
+ text = '\n'.join([x['value'] for x in inputs])
71
+ input_msgs.append(dict(role='user', content=text))
72
+ return input_msgs
73
+
74
+ def generate_inner(self, inputs, **kwargs) -> str:
75
+ input_msgs = self.prepare_inputs(inputs)
76
+ temperature = kwargs.pop('temperature', self.temperature)
77
+ max_tokens = kwargs.pop('max_tokens', self.max_tokens)
78
+
79
+ if 0 < max_tokens <= 100:
80
+ self.logger.warning(
81
+ 'Less than 100 tokens left, '
82
+ 'may exceed the context window with some additional meta symbols. '
83
+ )
84
+ if max_tokens <= 0:
85
+ return 0, self.fail_msg + 'Input string longer than context window. ', 'Length Exceeded. '
86
+
87
+ headers = {'Content-Type': 'application/json', 'Authorization': f'{self.key}'}
88
+ payload = dict(
89
+ model=self.model,
90
+ messages=input_msgs,
91
+ max_tokens=max_tokens,
92
+ n=1,
93
+ temperature=temperature,
94
+ **kwargs)
95
+ response = requests.post(self.api_base, headers=headers, data=json.dumps(payload), timeout=self.timeout * 1.1)
96
+ ret_code = response.status_code
97
+ ret_code = 0 if (200 <= int(ret_code) < 300) else ret_code
98
+ answer = self.fail_msg
99
+ try:
100
+ resp_struct = json.loads(response.text)
101
+ answer = resp_struct['choices'][0]['message']['content'].strip()
102
+ except Exception as err:
103
+ if self.verbose:
104
+ self.logger.error(f'{type(err)}: {err}')
105
+ self.logger.error(response.text if hasattr(response, 'text') else response)
106
+
107
+ return ret_code, answer, response
VLMEvalKit/vlmeval/api/gemini.py ADDED
@@ -0,0 +1,116 @@
1
+ from vlmeval.smp import *
2
+ from vlmeval.api.base import BaseAPI
3
+
4
+ headers = 'Content-Type: application/json'
5
+
6
+
7
+ class GeminiWrapper(BaseAPI):
8
+
9
+ is_api: bool = True
10
+
11
+ def __init__(self,
12
+ model: str = 'gemini-1.0-pro',
13
+ retry: int = 5,
14
+ wait: int = 5,
15
+ key: str = None,
16
+ verbose: bool = True,
17
+ temperature: float = 0.0,
18
+ system_prompt: str = None,
19
+ max_tokens: int = 1024,
20
+ proxy: str = None,
21
+ backend='genai',
22
+ project_id='vlmeval',
23
+ **kwargs):
24
+
25
+ self.model = model
26
+ self.fail_msg = 'Failed to obtain answer via API. '
27
+ self.max_tokens = max_tokens
28
+ self.temperature = temperature
29
+ if key is None:
30
+ key = os.environ.get('GOOGLE_API_KEY', None)
31
+ # Try to load backend from environment variable
32
+ be = os.environ.get('GOOGLE_API_BACKEND', None)
33
+ if be is not None and be in ['genai', 'vertex']:
34
+ backend = be
35
+
36
+ assert backend in ['genai', 'vertex']
37
+ if backend == 'genai':
38
+ # We have not evaluated Gemini-1.5 w. GenAI backend
39
+ assert key is not None # Vertex does not require API Key
40
+
41
+ self.backend = backend
42
+ self.project_id = project_id
43
+ self.api_key = key
44
+
45
+ if proxy is not None:
46
+ proxy_set(proxy)
47
+ super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)
48
+
49
+ def build_msgs_genai(self, inputs):
50
+ messages = [] if self.system_prompt is None else [self.system_prompt]
51
+ for inp in inputs:
52
+ if inp['type'] == 'text':
53
+ messages.append(inp['value'])
54
+ elif inp['type'] == 'image':
55
+ messages.append(Image.open(inp['value']))
56
+ return messages
57
+
58
+ def build_msgs_vertex(self, inputs):
59
+ from vertexai.generative_models import Part, Image
60
+ messages = [] if self.system_prompt is None else [self.system_prompt]
61
+ for inp in inputs:
62
+ if inp['type'] == 'text':
63
+ messages.append(inp['value'])
64
+ elif inp['type'] == 'image':
65
+ messages.append(Part.from_image(Image.load_from_file(inp['value'])))
66
+ return messages
67
+
68
+ def generate_inner(self, inputs, **kwargs) -> str:
69
+ if self.backend == 'genai':
70
+ import google.generativeai as genai
71
+ assert isinstance(inputs, list)
72
+ pure_text = np.all([x['type'] == 'text' for x in inputs])
73
+ genai.configure(api_key=self.api_key)
74
+
75
+ if pure_text and self.model == 'gemini-1.0-pro':
76
+ model = genai.GenerativeModel('gemini-1.0-pro')
77
+ else:
78
+ model = genai.GenerativeModel(self.model)
79
+
80
+ messages = self.build_msgs_genai(inputs)
81
+ gen_config = dict(max_output_tokens=self.max_tokens, temperature=self.temperature)
82
+ gen_config.update(kwargs)
83
+ try:
84
+ answer = model.generate_content(
85
+ messages,
86
+ generation_config=genai.types.GenerationConfig(**gen_config)).text
87
+ return 0, answer, 'Succeeded! '
88
+ except Exception as err:
89
+ if self.verbose:
90
+ self.logger.error(f'{type(err)}: {err}')
91
+ self.logger.error(f'The input messages are {inputs}.')
92
+
93
+ return -1, '', ''
94
+ elif self.backend == 'vertex':
95
+ import vertexai
96
+ from vertexai.generative_models import GenerativeModel
97
+ vertexai.init(project=self.project_id, location='us-central1')
98
+ model_name = 'gemini-1.0-pro-vision' if self.model == 'gemini-1.0-pro' else self.model
99
+ model = GenerativeModel(model_name=model_name)
100
+ messages = self.build_msgs_vertex(inputs)
101
+ try:
102
+ resp = model.generate_content(messages)
103
+ answer = resp.text
104
+ return 0, answer, 'Succeeded! '
105
+ except Exception as err:
106
+ if self.verbose:
107
+ self.logger.error(f'{type(err)}: {err}')
108
+ self.logger.error(f'The input messages are {inputs}.')
109
+
110
+ return -1, '', ''
111
+
112
+
113
+ class GeminiProVision(GeminiWrapper):
114
+
115
+ def generate(self, message, dataset=None):
116
+ return super(GeminiProVision, self).generate(message)
VLMEvalKit/vlmeval/api/glm_vision.py ADDED
@@ -0,0 +1,95 @@
1
+ import requests
2
+ requests.packages.urllib3.disable_warnings()
3
+
4
+ from vlmeval.smp import *
5
+ from vlmeval.api.base import BaseAPI
6
+ from vlmeval.dataset import DATASET_TYPE
7
+ from vlmeval.smp.vlm import encode_image_file_to_base64
8
+
9
+
10
+ class GLMVisionWrapper(BaseAPI):
11
+
12
+ is_api: bool = True
13
+
14
+ def __init__(self,
15
+ model: str,
16
+ retry: int = 5,
17
+ wait: int = 5,
18
+ key: str = None,
19
+ verbose: bool = True,
20
+ system_prompt: str = None,
21
+ max_tokens: int = 4096,
22
+ proxy: str = None,
23
+ **kwargs):
24
+
25
+ self.model = model
26
+ self.fail_msg = 'Failed to obtain answer via API. '
27
+ self.default_params = {
28
+ 'top_k': 1,
29
+ 'best_of': 1,
30
+ 'do_sample': False,
31
+ 'stream': False,
32
+ 'max_tokens': max_tokens,
33
+ "skip_moderation": True
34
+ }
35
+ if key is None:
36
+ key = os.environ.get('GLMV_API_KEY', None)
37
+ assert key is not None, (
38
+ 'Please set the API Key (obtain it here: '
39
+ 'https://open.bigmodel.cn/dev/howuse/introduction)'
40
+ )
41
+ self.key = key
42
+ super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)
43
+
44
+ def build_msgs(self, msgs_raw, system_prompt=None, dataset=None):
45
+ msgs = cp.deepcopy(msgs_raw)
46
+ content = []
47
+ for i, msg in enumerate(msgs):
48
+ if msg['type'] == 'text':
49
+ content.append(dict(type='text', text=msg['value']))
50
+ elif msg['type'] == 'image':
51
+ content.append(dict(type='image_url', image_url=dict(url=encode_image_file_to_base64(msg['value']))))
52
+ if dataset in {'HallusionBench', 'POPE'}:
53
+ content.append(dict(type="text", text="Please answer yes or no."))
54
+ ret = [dict(role='user', content=content)]
55
+ return ret
56
+
57
+ def generate_inner(self, inputs, **kwargs) -> str:
58
+ assert isinstance(inputs, str) or isinstance(inputs, list)
59
+ inputs = [inputs] if isinstance(inputs, str) else inputs
60
+
61
+ messages = self.build_msgs(msgs_raw=inputs, dataset=kwargs.get('dataset', None))
62
+
63
+ url = 'https://api.chatglm.cn/v1/chat/completions'
64
+ headers = {
65
+ 'Content-Type': 'application/json',
66
+ 'Request-Id': 'remote-test',
67
+ 'Authorization': f'Bearer {self.key}'
68
+ }
69
+ payload = {
70
+ 'model': self.model,
71
+ 'messages': messages,
72
+ **self.default_params
73
+ }
74
+ response = requests.post(url, headers=headers, data=json.dumps(payload), verify=False)
75
+ output = []
76
+ try:
77
+ assert response.status_code == 200
78
+ for line in response.iter_lines():
79
+ data = json.loads(line.decode('utf-8').lstrip('data: '))
80
+ output.append(data['choices'][0]['message']['content'])
81
+ answer = ''.join(output).replace('</s>', '')
82
+ if self.verbose:
83
+ self.logger.info(f'inputs: {inputs}\nanswer: {answer}')
84
+ return 0, answer, 'Succeeded! '
85
+ except Exception as err:
86
+ if self.verbose:
87
+ self.logger.error(f'{type(err)}: {err}')
88
+ self.logger.error(f'The input messages are {inputs}.')
89
+ return -1, self.fail_msg, ''
90
+
91
+
92
+ class GLMVisionAPI(GLMVisionWrapper):
93
+
94
+ def generate(self, message, dataset=None):
95
+ return super(GLMVisionAPI, self).generate(message, dataset=dataset)
VLMEvalKit/vlmeval/api/gpt.py ADDED
@@ -0,0 +1,267 @@
1
+ from ..smp import *
2
+ import os
3
+ import sys
4
+ from .base import BaseAPI
5
+
6
+ APIBASES = {
7
+ 'OFFICIAL': 'https://api.openai.com/v1/chat/completions',
8
+ }
9
+
10
+
11
+ def GPT_context_window(model):
12
+ length_map = {
13
+ 'gpt-4': 8192,
14
+ 'gpt-4-0613': 8192,
15
+ 'gpt-4-turbo-preview': 128000,
16
+ 'gpt-4-1106-preview': 128000,
17
+ 'gpt-4-0125-preview': 128000,
18
+ 'gpt-4-vision-preview': 128000,
19
+ 'gpt-4-turbo': 128000,
20
+ 'gpt-4-turbo-2024-04-09': 128000,
21
+ 'gpt-3.5-turbo': 16385,
22
+ 'gpt-3.5-turbo-0125': 16385,
23
+ 'gpt-3.5-turbo-1106': 16385,
24
+ 'gpt-3.5-turbo-instruct': 4096,
25
+ }
26
+ if model in length_map:
27
+ return length_map[model]
28
+ else:
29
+ return 128000
30
+
31
+
32
+ class OpenAIWrapper(BaseAPI):
33
+
34
+ is_api: bool = True
35
+
36
+ def __init__(self,
37
+ model: str = 'gpt-3.5-turbo-0613',
38
+ retry: int = 5,
39
+ wait: int = 5,
40
+ key: str = None,
41
+ verbose: bool = False,
42
+ system_prompt: str = None,
43
+ temperature: float = 0,
44
+ timeout: int = 60,
45
+ api_base: str = None,
46
+ max_tokens: int = 1024,
47
+ img_size: int = 512,
48
+ img_detail: str = 'low',
49
+ use_azure: bool = False,
50
+ **kwargs):
51
+
52
+ self.model = model
53
+ self.cur_idx = 0
54
+ self.fail_msg = 'Failed to obtain answer via API. '
55
+ self.max_tokens = max_tokens
56
+ self.temperature = temperature
57
+ self.use_azure = use_azure
58
+
59
+ if 'step' in model:
60
+ env_key = os.environ.get('STEPAI_API_KEY', '')
61
+ if key is None:
62
+ key = env_key
63
+ elif 'yi-vision' in model:
64
+ env_key = os.environ.get('YI_API_KEY', '')
65
+ if key is None:
66
+ key = env_key
67
+ elif 'internvl2-pro' in model:
68
+ env_key = os.environ.get('InternVL2_PRO_KEY', '')
69
+ if key is None:
70
+ key = env_key
71
+ elif 'abab' in model:
72
+ env_key = os.environ.get('MiniMax_API_KEY', '')
73
+ if key is None:
74
+ key = env_key
75
+ else:
76
+ if use_azure:
77
+ env_key = os.environ.get('AZURE_OPENAI_API_KEY', None)
78
+ assert env_key is not None, 'Please set the environment variable AZURE_OPENAI_API_KEY. '
79
+
80
+ if key is None:
81
+ key = env_key
82
+ assert isinstance(key, str), (
83
+ 'Please set the environment variable AZURE_OPENAI_API_KEY to your openai key. '
84
+ )
85
+ else:
86
+ env_key = os.environ.get('OPENAI_API_KEY', '')
87
+ if key is None:
88
+ key = env_key
89
+ assert isinstance(key, str) and key.startswith('sk-'), (
90
+ f'Illegal openai_key {key}. '
91
+ 'Please set the environment variable OPENAI_API_KEY to your openai key. '
92
+ )
93
+
94
+ self.key = key
95
+ assert img_size > 0 or img_size == -1
96
+ self.img_size = img_size
97
+ assert img_detail in ['high', 'low']
98
+ self.img_detail = img_detail
99
+ self.timeout = timeout
100
+
101
+ super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)
102
+
103
+ if use_azure:
104
+ api_base_template = (
105
+ '{endpoint}openai/deployments/{deployment_name}/chat/completions?api-version={api_version}'
106
+ )
107
+ endpoint = os.getenv('AZURE_OPENAI_ENDPOINT', None)
108
+ assert endpoint is not None, 'Please set the environment variable AZURE_OPENAI_ENDPOINT. '
109
+ deployment_name = os.getenv('AZURE_OPENAI_DEPLOYMENT_NAME', None)
110
+ assert deployment_name is not None, 'Please set the environment variable AZURE_OPENAI_DEPLOYMENT_NAME. '
111
+ api_version = os.getenv('OPENAI_API_VERSION', None)
112
+ assert api_version is not None, 'Please set the environment variable OPENAI_API_VERSION. '
113
+
114
+ self.api_base = api_base_template.format(
115
+ endpoint=os.getenv('AZURE_OPENAI_ENDPOINT'),
116
+ deployment_name=os.getenv('AZURE_OPENAI_DEPLOYMENT_NAME'),
117
+ api_version=os.getenv('OPENAI_API_VERSION')
118
+ )
119
+ else:
120
+ if api_base is None:
121
+ if 'OPENAI_API_BASE' in os.environ and os.environ['OPENAI_API_BASE'] != '':
122
+ self.logger.info('Environment variable OPENAI_API_BASE is set. Will use it as api_base. ')
123
+ api_base = os.environ['OPENAI_API_BASE']
124
+ else:
125
+ api_base = 'OFFICIAL'
126
+
127
+ assert api_base is not None
128
+
129
+ if api_base in APIBASES:
130
+ self.api_base = APIBASES[api_base]
131
+ elif api_base.startswith('http'):
132
+ self.api_base = api_base
133
+ else:
134
+ self.logger.error('Unknown API Base. ')
135
+ raise NotImplementedError
136
+
137
+ self.logger.info(f'Using API Base: {self.api_base}; API Key: {self.key}')
138
+
139
+ # inputs can be a lvl-2 nested list: [content1, content2, content3, ...]
140
+ # content can be a string or a list of image & text
141
+ def prepare_itlist(self, inputs):
142
+ assert np.all([isinstance(x, dict) for x in inputs])
143
+ has_images = np.sum([x['type'] == 'image' for x in inputs])
144
+ if has_images:
145
+ content_list = []
146
+ for msg in inputs:
147
+ if msg['type'] == 'text':
148
+ content_list.append(dict(type='text', text=msg['value']))
149
+ elif msg['type'] == 'image':
150
+ from PIL import Image
151
+ img = Image.open(msg['value'])
152
+ b64 = encode_image_to_base64(img, target_size=self.img_size)
153
+ img_struct = dict(url=f'data:image/jpeg;base64,{b64}', detail=self.img_detail)
154
+ content_list.append(dict(type='image_url', image_url=img_struct))
155
+ else:
156
+ assert all([x['type'] == 'text' for x in inputs])
157
+ text = '\n'.join([x['value'] for x in inputs])
158
+ content_list = [dict(type='text', text=text)]
159
+ return content_list
160
+
161
+ def prepare_inputs(self, inputs):
162
+ input_msgs = []
163
+ if self.system_prompt is not None:
164
+ input_msgs.append(dict(role='system', content=self.system_prompt))
165
+ assert isinstance(inputs, list) and isinstance(inputs[0], dict)
166
+ assert np.all(['type' in x for x in inputs]) or np.all(['role' in x for x in inputs]), inputs
167
+ if 'role' in inputs[0]:
168
+ assert inputs[-1]['role'] == 'user', inputs[-1]
169
+ for item in inputs:
170
+ input_msgs.append(dict(role=item['role'], content=self.prepare_itlist(item['content'])))
171
+ else:
172
+ input_msgs.append(dict(role='user', content=self.prepare_itlist(inputs)))
173
+ return input_msgs
174
+
175
+ def generate_inner(self, inputs, **kwargs) -> str:
176
+ input_msgs = self.prepare_inputs(inputs)
177
+ temperature = kwargs.pop('temperature', self.temperature)
178
+ max_tokens = kwargs.pop('max_tokens', self.max_tokens)
179
+
180
+ context_window = GPT_context_window(self.model)
181
+ new_max_tokens = min(max_tokens, context_window - self.get_token_len(inputs))
182
+ if 0 < new_max_tokens <= 100 and new_max_tokens < max_tokens:
183
+ self.logger.warning(
184
+ 'Less than 100 tokens left, '
185
+ 'may exceed the context window with some additional meta symbols. '
186
+ )
187
+ if new_max_tokens <= 0:
188
+ return 0, self.fail_msg + 'Input string longer than context window. ', 'Length Exceeded. '
189
+ max_tokens = new_max_tokens
190
+
191
+ # For Azure, send a plain HTTP request; it is unclear how to use the openai client for this path.
192
+ if self.use_azure:
193
+ headers = {'Content-Type': 'application/json', 'api-key': self.key}
194
+ elif 'internvl2-pro' in self.model:
195
+ headers = {'Content-Type': 'application/json', 'Authorization': self.key}
196
+ else:
197
+ headers = {'Content-Type': 'application/json', 'Authorization': f'Bearer {self.key}'}
198
+ payload = dict(
199
+ model=self.model,
200
+ messages=input_msgs,
201
+ max_tokens=max_tokens,
202
+ n=1,
203
+ temperature=temperature,
204
+ **kwargs)
205
+ response = requests.post(
206
+ self.api_base,
207
+ headers=headers, data=json.dumps(payload), timeout=self.timeout * 1.1)
208
+ ret_code = response.status_code
209
+ ret_code = 0 if (200 <= int(ret_code) < 300) else ret_code
210
+ answer = self.fail_msg
211
+ try:
212
+ resp_struct = json.loads(response.text)
213
+ answer = resp_struct['choices'][0]['message']['content'].strip()
214
+ except Exception as err:
215
+ if self.verbose:
216
+ self.logger.error(f'{type(err)}: {err}')
217
+ self.logger.error(response.text if hasattr(response, 'text') else response)
218
+
219
+ return ret_code, answer, response
220
+
221
+ def get_image_token_len(self, img_path, detail='low'):
222
+ import math
223
+ if detail == 'low':
224
+ return 85
225
+
226
+ im = Image.open(img_path)
227
+ width, height = im.size  # PIL's Image.size returns (width, height)
228
+ if width > 1024 or height > 1024:
229
+ if width > height:
230
+ height = int(height * 1024 / width)
231
+ width = 1024
232
+ else:
233
+ width = int(width * 1024 / height)
234
+ height = 1024
235
+
236
+ h = math.ceil(height / 512)
237
+ w = math.ceil(width / 512)
238
+ total = 85 + 170 * h * w
239
+ return total
240
+
241
+ def get_token_len(self, inputs) -> int:
242
+ import tiktoken
243
+ try:
244
+ enc = tiktoken.encoding_for_model(self.model)
245
+ except Exception as err:
246
+ if 'gpt' in self.model.lower():
247
+ if self.verbose:
248
+ self.logger.warning(f'{type(err)}: {err}')
249
+ enc = tiktoken.encoding_for_model('gpt-4')
250
+ else:
251
+ return 0
252
+ assert isinstance(inputs, list)
253
+ tot = 0
254
+ for item in inputs:
255
+ if 'role' in item:
256
+ tot += self.get_token_len(item['content'])
257
+ elif item['type'] == 'text':
258
+ tot += len(enc.encode(item['value']))
259
+ elif item['type'] == 'image':
260
+ tot += self.get_image_token_len(item['value'], detail=self.img_detail)
261
+ return tot
262
+
263
+
264
+ class GPT4V(OpenAIWrapper):
265
+
266
+ def generate(self, message, dataset=None):
267
+ return super(GPT4V, self).generate(message)
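`OpenAIWrapper` converts the interleaved message list into an OpenAI-style chat payload, trims `max_tokens` against the estimated context window, and posts it to `api_base`. A hedged end-to-end sketch follows (the image path and model name are placeholders, and a valid OPENAI_API_KEY in the environment is assumed):

# Illustrative only: requires OPENAI_API_KEY (or an explicit `key=`) to actually run.
gpt4v = GPT4V(model='gpt-4-turbo', img_detail='low', max_tokens=512)
message = [
    dict(type='image', value='/path/to/chart.png'),           # placeholder image path
    dict(type='text', value='Summarize the trend shown in this chart.'),
]
# answer = gpt4v.generate(message)   # returns the answer string, or the fail message on error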
VLMEvalKit/vlmeval/api/hf_chat_model.py ADDED
@@ -0,0 +1,246 @@
1
+ import os
2
+ import sys
3
+ import os.path as osp
4
+ import torch
5
+ from ..smp import *
6
+
7
+
8
+ def get_gpu_num(model_name):
9
+ model_name = model_name.lower()
10
+ kws = {
11
+ 8: ['65b', '70b'],
12
+ 4: ['30b', '33b', '35b', '40b'],
13
+ 2: ['13b', '14b', '20b'],
14
+ 1: ['6b', '7b', 'moss'],
15
+ }
16
+ for k in [8, 4, 2, 1]:
17
+ for keyword in kws[k]:
18
+ if keyword in model_name:
19
+ return k
20
+ return 8
21
+
22
+
23
+ validated_llms = [
24
+ 'internlm/internlm-chat-7b', 'internlm/internlm-chat-7b-8k', 'internlm/internlm-chat-20b',
25
+ 'Qwen/Qwen-7B-Chat', 'Qwen/Qwen-14B-Chat',
26
+ 'THUDM/chatglm2-6b', 'THUDM/chatglm2-6b-32k', 'THUDM/chatglm3-6b', 'THUDM/chatglm3-6b-32k',
27
+ 'baichuan-inc/Baichuan2-7B-Chat', 'baichuan-inc/Baichuan2-13B-Chat',
28
+ 'lmsys/vicuna-7b-v1.5', 'lmsys/vicuna-13b-v1.5',
29
+ 'meta-llama/Llama-2-7b-chat-hf'
30
+ ]
31
+ Auto_model = ['chatglm']
32
+
33
+
34
+ class HFChatModel:
35
+
36
+ def _get_context_length(self, model, model_path):
37
+ # By default, we use model.config.seq_length
38
+ model_path = model_path.lower()
39
+ if 'baichuan' in model_path:
40
+ context_window = model.config.model_max_length
41
+ elif 'internlm' in model_path or 'llama' in model_path:
42
+ context_window = model.config.max_position_embeddings
43
+ elif 'vicuna' in model_path:
44
+ context_window = model.generation_config.max_length
45
+ else:
46
+ # chatglm & qwen
47
+ context_window = model.config.seq_length
48
+ return context_window
49
+
50
+ def _get_context_length_robust(self, model, model_path):
51
+ try:
52
+ context_window = self._get_context_length(model, model_path)
53
+ return context_window
54
+ except Exception as err:
55
+ self.logger.critical(f'{type(err)}: {err}')
56
+ self.logger.critical(
57
+ 'Failed to extract context_window information from config / generation_config. '
58
+ 'Please read the above code and check if the logic works for your model path. '
59
+ )
60
+ raise NotImplementedError
61
+
62
+ def __init__(self,
63
+ model_path,
64
+ system_prompt: str = None,
65
+ **kwargs):
66
+
67
+ self.logger = get_logger('HFChatModel')
68
+ if 'vicuna' in model_path.lower():
69
+ try:
70
+ from fastchat.model import get_conversation_template
71
+ except Exception as err:
72
+ self.logger.critical('Please install fastchat first to use vicuna. ')
73
+ raise err
74
+
75
+ self.explicit_device = kwargs.pop('device', None)
76
+
77
+ if self.explicit_device is None:
78
+ # If CUDA_VISIBLE_DEVICES is not properly set
79
+ if 'CUDA_VISIBLE_DEVICES' not in os.environ or os.environ['CUDA_VISIBLE_DEVICES'] == '0,1,2,3,4,5,6,7':
80
+ num_gpu = get_gpu_num(model_path)
81
+ gpu_offset = kwargs.pop('gpu_offset', 0)
82
+ cuda_visible_devices = ','.join([str(i) for i in range(gpu_offset, gpu_offset + num_gpu)])
83
+ os.environ['CUDA_VISIBLE_DEVICES'] = cuda_visible_devices
84
+
85
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
86
+ from transformers.generation import GenerationConfig
87
+
88
+ if model_path not in validated_llms:
89
+ self.logger.warning(f'{model_path} not in validated LLMs, may have inference troubles. ')
90
+
91
+ self.model_path = model_path
92
+ if listinstr(Auto_model, model_path):
93
+ LoadModel = AutoModel
94
+ else:
95
+ LoadModel = AutoModelForCausalLM
96
+
97
+ assert osp.exists(model_path) or len(model_path.split('/')) == 2
98
+
99
+ device = self.explicit_device if self.explicit_device else 'auto'
100
+
101
+ precision = {}
102
+ if 'internlm-chat-7b' in model_path:
103
+ precision = {'torch_dtype': torch.float16}
104
+ elif 'internlm-chat-20b' in model_path:
105
+ precision = {'torch_dtype': torch.bfloat16}
106
+
107
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
108
+ model = LoadModel.from_pretrained(model_path, trust_remote_code=True, device_map='cpu', **precision)
109
+ model = model.eval()
110
+
111
+ if device != 'cpu':
112
+ model = model.to(f'cuda:{device}' if isinstance(device, int) else 'cuda')
113
+ try:
114
+ model.generation_config = GenerationConfig.from_pretrained(
115
+ model_path, trust_remote_code=True, device_map=device)
116
+ except Exception as err:
117
+ self.logger.warning(f'{type(err)}: {err}')
118
+
119
+ torch.cuda.empty_cache()
120
+ self.model = model
121
+ self.context_length = self._get_context_length_robust(model=model, model_path=model_path)
122
+ self.answer_buffer = 192
123
+ self.system_prompt = system_prompt
124
+ for k, v in kwargs.items():
125
+ self.logger.info(f'Following args will be used for generation (If not set specifically), {k}: {v}. ')
126
+ self.kwargs = kwargs
127
+
128
+ def generate_str(self, input, **kwargs):
129
+ if 'baichuan' in self.model_path.lower():
130
+ messages = []
131
+ messages.append({'role': 'user', 'content': input})
132
+ resp = self.model.chat(self.tokenizer, messages, **kwargs)
133
+ elif 'vicuna' in self.model_path.lower():
134
+ from fastchat.model import get_conversation_template
135
+ conv = get_conversation_template('vicuna')
136
+ conv.append_message(conv.roles[0], input)
137
+ conv.append_message(conv.roles[1], None)
138
+ prompt = conv.get_prompt()
139
+ inputs = self.tokenizer([prompt], return_tensors='pt')
140
+ if torch.cuda.is_available():
141
+ for k in inputs:
142
+ inputs[k] = inputs[k].cuda()
143
+
144
+ params = dict(do_sample=True, temperature=0.7, repetition_penalty=1.0, max_new_tokens=512)
145
+ params.update(self.kwargs)
146
+ params.update(kwargs)
147
+ outputs = self.model.generate(**inputs, **params)
148
+ resp = self.tokenizer.decode(
149
+ outputs[0][len(inputs['input_ids'][0]):],
150
+ skip_special_tokens=True,
151
+ spaces_between_special_tokens=False)
152
+
153
+ else:
154
+ params = self.kwargs
155
+ params.update(kwargs)
156
+ resp, _ = self.model.chat(self.tokenizer, input, history=[], **params)
157
+
158
+ return resp
159
+
160
+ def length_ok(self, inputs):
161
+ tot = len(self.tokenizer.encode(self.system_prompt)) if self.system_prompt is not None else 0
162
+ for s in inputs:
163
+ tot += len(self.tokenizer.encode(s))
164
+ return tot + self.answer_buffer < self.context_length
165
+
166
+ def generate_list(self, full_inputs, offset=0, **kwargs):
167
+ assert isinstance(full_inputs, list)
168
+
169
+ inputs = full_inputs[offset:]
170
+ if not self.length_ok(inputs):
171
+ return self.generate_list(full_inputs, offset + 1, **kwargs)  # drop one more leading turn and retry; HFChatModel defines no `chat` method
172
+
173
+ model_path = self.model_path.lower()
174
+
175
+ if sum([x in model_path for x in ['baichuan']]):
176
+ input_msgs = []
177
+ if self.system_prompt is not None:
178
+ input_msgs.append(dict(role='user', content=self.system_prompt))
179
+ if len(inputs):
180
+ assert isinstance(inputs, list) and isinstance(inputs[0], str)
181
+ roles = ['user', 'assistant'] if len(inputs) % 2 == 1 else ['assistant', 'user']
182
+ roles = roles * len(inputs)
183
+ for role, msg in zip(roles, inputs):
184
+ input_msgs.append(dict(role=role, content=msg))
185
+ response = self.model.chat(self.tokenizer, input_msgs)
186
+ elif sum([x in model_path for x in ['vicuna']]):
187
+ from fastchat.model import get_conversation_template
188
+ conv = get_conversation_template('vicuna')
189
+ assert isinstance(inputs, list) and isinstance(inputs[0], str)
190
+ if len(inputs) % 2 == 1:
191
+ if self.system_prompt is not None:
192
+ conv.append_message(conv.roles[0], self.system_prompt)
193
+ for i in range(len(inputs) // 2):
194
+ conv.append_message(conv.roles[0], inputs[2 * i])
195
+ conv.append_message(conv.roles[1], inputs[2 * i + 1])
196
+ else:
197
+ assert self.system_prompt is not None
198
+ conv.append_message(conv.roles[0], self.system_prompt)
199
+ conv.append_message(conv.roles[1], inputs[0])
200
+ for i in range(len(inputs) // 2 - 1):
201
+ conv.append_message(conv.roles[0], inputs[2 * i + 1])
202
+ conv.append_message(conv.roles[1], inputs[2 * i + 2])
203
+ conv.append_message(conv.roles[0], inputs[-1])
204
+ conv.append_message(conv.roles[1], None)
205
+ prompt = conv.get_prompt()
206
+ inputs = self.tokenizer([prompt], return_tensors='pt')
207
+ if torch.cuda.is_available():
208
+ for k in inputs:
209
+ inputs[k] = inputs[k].cuda()
210
+
211
+ params = dict(do_sample=True, temperature=0.7, repetition_penalty=1.0, max_new_tokens=512)
212
+ params.update(self.kwargs)
213
+ params.update(kwargs)
214
+
215
+ outputs = self.model.generate(**inputs, **params)
216
+ response = self.tokenizer.decode(
217
+ outputs[0][len(inputs['input_ids'][0]):],
218
+ skip_special_tokens=True,
219
+ spaces_between_special_tokens=False)
220
+ response = response.lstrip('\n')
221
+ else:
222
+ # The default option, support internlm, chatglm, qwen
223
+ history, msg = [], None
224
+ if len(inputs) % 2 == 1:
225
+ if self.system_prompt is not None:
226
+ history = [(self.system_prompt, '')]
227
+ for i in range(len(inputs) // 2):
228
+ history.append((inputs[2 * i], inputs[2 * i + 1]))
229
+ else:
230
+ assert self.system_prompt is not None
231
+ history = [(self.system_prompt, inputs[0])]
232
+ for i in range(len(inputs) // 2 - 1):
233
+ history.append((inputs[2 * i + 1], inputs[2 * i + 2]))
234
+ msg = inputs[-1]
235
+
236
+ params = self.kwargs
237
+ params.update(kwargs)
238
+ response, _ = self.model.chat(self.tokenizer, msg, history=history, **params)
239
+
240
+ return response, offset
241
+
242
+ def generate(self, inputs, **kwargs):
243
+ if isinstance(inputs, str):
244
+ return self.generate_str(inputs, **kwargs)
245
+ elif isinstance(inputs, list):
246
+ return self.generate_list(inputs, **kwargs)
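For reference, a minimal sketch of how the two generation paths above are exercised. This is a sketch only: it assumes the class is constructed with a HuggingFace model path (as the loading code above implies), that enough GPU memory is available, and the model name here is purely illustrative.

    from vlmeval.api import HFChatModel

    model = HFChatModel('internlm/internlm-chat-7b')
    # A plain string is routed to generate_str
    print(model.generate('Briefly introduce the Great Wall.'))
    # A list of alternating user/assistant turns (ending with a user turn) is routed to
    # generate_list, which packs earlier turns into `history` and returns (response, offset)
    turns = [
        'Name three large Chinese cities.',    # user
        'Beijing, Shanghai and Guangzhou.',    # assistant
        'Which of them is the capital?',       # user
    ]
    response, offset = model.generate(turns)
    print(response)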
VLMEvalKit/vlmeval/api/hunyuan.py ADDED
@@ -0,0 +1,147 @@
1
+ from vlmeval.smp import *
2
+ import os
3
+ import sys
4
+ from vlmeval.api.base import BaseAPI
5
+
6
+
7
+ class HunyuanWrapper(BaseAPI):
8
+
9
+ is_api: bool = True
10
+ _apiVersion = '2023-09-01'
11
+ _service = 'hunyuan'
12
+
13
+ def __init__(self,
14
+ model: str = 'hunyuan-vision',
15
+ retry: int = 5,
16
+ wait: int = 5,
17
+ secret_key: str = None,
18
+ secret_id: str = None,
19
+ verbose: bool = True,
20
+ system_prompt: str = None,
21
+ temperature: float = 0,
22
+ timeout: int = 60,
23
+ api_base: str = 'hunyuan.tencentcloudapi.com',
24
+ **kwargs):
25
+
26
+ self.model = model
27
+ self.cur_idx = 0
28
+ self.fail_msg = 'Failed to obtain answer via API. '
29
+ self.temperature = temperature
30
+
31
+ warnings.warn('You may need to set the env variables HUNYUAN_SECRET_ID & HUNYUAN_SECRET_KEY to use Hunyuan. ')
32
+
33
+ secret_key = os.environ.get('HUNYUAN_SECRET_KEY', secret_key)
34
+ assert secret_key is not None, 'Please set the environment variable HUNYUAN_SECRET_KEY. '
35
+ secret_id = os.environ.get('HUNYUAN_SECRET_ID', secret_id)
36
+ assert secret_id is not None, 'Please set the environment variable HUNYUAN_SECRET_ID. '
37
+
38
+ self.model = model
39
+ self.endpoint = api_base
40
+ self.secret_id = secret_id
41
+ self.secret_key = secret_key
42
+ self.timeout = timeout
43
+
44
+ try:
45
+ from tencentcloud.common import credential
46
+ from tencentcloud.common.profile.client_profile import ClientProfile
47
+ from tencentcloud.common.profile.http_profile import HttpProfile
48
+ from tencentcloud.hunyuan.v20230901 import hunyuan_client
49
+ except ImportError as err:
50
+ self.logger.critical('Please install tencentcloud-sdk-python to use Hunyuan API. ')
51
+ raise err
52
+
53
+ super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)
54
+
55
+ cred = credential.Credential(self.secret_id, self.secret_key)
56
+ httpProfile = HttpProfile()
57
+ httpProfile.endpoint = self.endpoint
58
+ clientProfile = ClientProfile()
59
+ clientProfile.httpProfile = httpProfile
60
+ self.client = hunyuan_client.HunyuanClient(cred, 'ap-beijing', clientProfile)
61
+ self.logger.info(
62
+ f'Using Endpoint: {self.endpoint}; API Secret ID: {self.secret_id}; API Secret Key: {self.secret_key}'
63
+ )
64
+
65
+ # inputs can be a lvl-2 nested list: [content1, content2, content3, ...]
66
+ # content can be a string or a list of image & text
67
+ def prepare_itlist(self, inputs):
68
+ assert np.all([isinstance(x, dict) for x in inputs])
69
+ has_images = np.sum([x['type'] == 'image' for x in inputs])
70
+ if has_images:
71
+ content_list = []
72
+ for msg in inputs:
73
+ if msg['type'] == 'text':
74
+ content_list.append(dict(Type='text', Text=msg['value']))
75
+ elif msg['type'] == 'image':
76
+ from PIL import Image
77
+ img = Image.open(msg['value'])
78
+ b64 = encode_image_to_base64(img)
79
+ img_struct = dict(Url=f'data:image/jpeg;base64,{b64}')
80
+ content_list.append(dict(Type='image_url', ImageUrl=img_struct))
81
+ else:
82
+ assert all([x['type'] == 'text' for x in inputs])
83
+ text = '\n'.join([x['value'] for x in inputs])
84
+ content_list = [dict(Type='text', Text=text)]
85
+ return content_list
86
+
87
+ def prepare_inputs(self, inputs):
88
+ input_msgs = []
89
+ if self.system_prompt is not None:
90
+ input_msgs.append(dict(Role='system', Content=self.system_prompt))
91
+ assert isinstance(inputs, list) and isinstance(inputs[0], dict)
92
+ assert np.all(['type' in x for x in inputs]) or np.all(['role' in x for x in inputs]), inputs
93
+ if 'role' in inputs[0]:
94
+ assert inputs[-1]['role'] == 'user', inputs[-1]
95
+ for item in inputs:
96
+ input_msgs.append(dict(Role=item['role'], Contents=self.prepare_itlist(item['content'])))
97
+ else:
98
+ input_msgs.append(dict(Role='user', Contents=self.prepare_itlist(inputs)))
99
+ return input_msgs
100
+
101
+ def generate_inner(self, inputs, **kwargs) -> str:
102
+ from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
103
+ from tencentcloud.hunyuan.v20230901 import models
104
+
105
+ input_msgs = self.prepare_inputs(inputs)
106
+ temperature = kwargs.pop('temperature', self.temperature)
107
+
108
+ payload = dict(
109
+ Model=self.model,
110
+ Messages=input_msgs,
111
+ Temperature=temperature,
112
+ **kwargs)
113
+
114
+ retry_counter = 0
115
+ while retry_counter < 3:
116
+ try:
117
+ req = models.ChatCompletionsRequest()
118
+ req.from_json_string(json.dumps(payload))
119
+ resp = self.client.ChatCompletions(req)
120
+ resp = json.loads(resp.to_json_string())
121
+ answer = resp['Choices'][0]['Message']['Content']
122
+ return 0, answer, resp
123
+ except TencentCloudSDKException as e:
124
+ self.logger.error(f'Got error code: {e.get_code()}')
125
+ if e.get_code() == 'ClientNetworkError':
126
+ return -1, self.fail_msg + e.get_code(), None
127
+ elif e.get_code() in ['InternalError', 'ServerNetworkError']:
128
+ if retry_counter == 3:
129
+ return -1, self.fail_msg + e.get_code(), None
130
+ retry_counter += 1
131
+ continue
132
+ elif e.get_code() in ['LimitExceeded']:
133
+ time.sleep(5)
134
+ if retry_counter == 3:
135
+ return -1, self.fail_msg + e.get_code(), None
136
+ retry_counter += 1
137
+ continue
138
+ else:
139
+ return -1, self.fail_msg + str(e), None
140
+
141
+ return -1, self.fail_msg, None
142
+
143
+
144
+ class HunyuanVision(HunyuanWrapper):
145
+
146
+ def generate(self, message, dataset=None):
147
+ return super(HunyuanVision, self).generate(message)
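A minimal usage sketch of the wrapper above. Assumptions: tencentcloud-sdk-python is installed, HUNYUAN_SECRET_ID and HUNYUAN_SECRET_KEY are exported, and the image path is only a placeholder.

    from vlmeval.api import HunyuanVision

    model = HunyuanVision()  # defaults to the 'hunyuan-vision' model
    msgs = [
        dict(type='image', value='./demo.jpg'),
        dict(type='text', value='Describe this image in one sentence.'),
    ]
    # prepare_inputs() turns this into the Tencent-style payload, e.g.
    # [{'Role': 'user', 'Contents': [{'Type': 'image_url', 'ImageUrl': {...}},
    #                                {'Type': 'text', 'Text': '...'}]}]
    print(model.generate(msgs))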
VLMEvalKit/vlmeval/api/jt_vl_chat.py ADDED
@@ -0,0 +1,239 @@
1
+ import pandas as pd
2
+ import requests
3
+ import json
4
+ import os
5
+ import base64
6
+ from vlmeval.smp import *
7
+ from vlmeval.api.base import BaseAPI
8
+ from vlmeval.dataset import DATASET_TYPE
9
+ from vlmeval.dataset import img_root_map
10
+
11
+
12
+ API_ENDPOINT = 'https://jiutian.10086.cn/kunlun/ingress/api/h3t-eeceff/92390745235a40a484d850be19e1f8b4/ai-5d7ae47ec93f4280953273c4001aafee/service-7544ea5ee3e841ad9d01e7af44acef7c/v1/chat/completions' # noqa: E501
13
+ APP_CODE = 'eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJzdWIiOiI5ZGQwNmQ2ZjU4YTU0ZGY0OGEzNjRhMjQyNGMwODEyNSIsImlzcyI6ImFwaS1hdXRoLWtleSIsImV4cCI6NDg4MjkwNDA3OX0.k5t_T-955xWMndzBbx4WQQNAgm5DpMos9mHm7vkFipQ3yebCFMfyufpSxORSfEVpBaDS3Nly0dd8ygQYGnDgIQcC72vQ1xtkjCP49LNcqlceoET4rGc1zwRi76XLPSGFES4GcwvEmr7Ilth7XtqZNxcDF_Z7HyHyf1-zF0JIQETYSoxenqLU-gNteNfqRUnlyCgaKh03DscAbYvtoMUxEaFa2ZqyRSwekdHI_SPKCq9aC9G19yDPHTjeiwl1ubtyC5uMy5pERn_ClRsZS3Wyb-GmD5QQsFofrWvCiU_fVJuUiez39pYZvEP8awH0R9B7SkpQ4XOzj3fdytTPYy3g6g' # noqa: E501
14
+
15
+
16
+ class JTVLChatWrapper(BaseAPI):
17
+ is_api: bool = True
18
+ INTERLEAVE = False
19
+
20
+ def __init__(self,
21
+ model: str = 'jt-vl-chat',
22
+ retry: int = 5,
23
+ wait: int = 5,
24
+ api_base: str = API_ENDPOINT,
25
+ key: str = APP_CODE,
26
+ verbose: bool = True,
27
+ system_prompt: str = None,
28
+ temperature: float = 0.7,
29
+ max_tokens: int = 256,
30
+ proxy: str = None,
31
+ **kwargs):
32
+ self.model = model
33
+
34
+ self.temperature = temperature
35
+ self.max_tokens = max_tokens
36
+ self.api_base = api_base
37
+
38
+ if key is None:
39
+ key = os.environ.get('JTVLChat_API_KEY', None)
40
+ assert key is not None, (
41
+ 'Please set the API Key (also called app_code, obtain it here: https://github.com/jiutiancv/JT-VL-Chat)'
42
+ )
43
+
44
+ self.key = key
45
+ super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)
46
+
47
+ def dump_image(self, line, dataset):
48
+ """Dump the image(s) of the input line to the corresponding dataset folder.
49
+
50
+ Args:
51
+ line (line of pd.DataFrame): The raw input line.
52
+ dataset (str): The name of the dataset.
53
+
54
+ Returns:
55
+ str | list[str]: The paths of the dumped images.
56
+ """
57
+ ROOT = LMUDataRoot()
58
+ assert isinstance(dataset, str)
59
+
60
+ img_root = os.path.join(ROOT, 'images', img_root_map(dataset))
61
+ os.makedirs(img_root, exist_ok=True)
62
+ if 'image' in line:
63
+ if isinstance(line['image'], list):
64
+ tgt_path = []
65
+ assert 'image_path' in line
66
+ for img, im_name in zip(line['image'], line['image_path']):
67
+ path = osp.join(img_root, im_name)
68
+ if not read_ok(path):
69
+ decode_base64_to_image_file(img, path)
70
+ tgt_path.append(path)
71
+ else:
72
+ tgt_path = osp.join(img_root, f"{line['index']}.jpg")
73
+ if not read_ok(tgt_path):
74
+ decode_base64_to_image_file(line['image'], tgt_path)
75
+ tgt_path = [tgt_path]
76
+ else:
77
+ assert 'image_path' in line
78
+ tgt_path = toliststr(line['image_path'])
79
+
80
+ return tgt_path
81
+
82
+ def use_custom_prompt(self, dataset):
83
+ assert dataset is not None
84
+ if listinstr(['MMMU_DEV_VAL','MMMU_TEST'], dataset):
85
+ return False
86
+ else:
87
+ return True
88
+
89
+ def build_multi_choice_prompt(self, line, dataset=None):
90
+ question = line['question']
91
+ hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
92
+ if hint is not None:
93
+ question = hint + '\n' + question
94
+
95
+ options = {
96
+ cand: line[cand]
97
+ for cand in string.ascii_uppercase
98
+ if cand in line and not pd.isna(line[cand])
99
+ }
100
+ for key, item in options.items():
101
+ question += f'\n{key}. {item}'
102
+ prompt = question
103
+
104
+ if len(options):
105
+ prompt += '\n请直接回答选项字母。' if cn_string(
106
+ prompt) else "\nAnswer with the option's letter from the given choices directly."
107
+ else:
108
+ prompt += '\n请直接回答问题。' if cn_string(prompt) else '\nAnswer the question directly.'
109
+
110
+ return prompt
111
+
112
+ def build_prompt(self, line, dataset=None):
113
+ assert self.use_custom_prompt(dataset)
114
+ assert dataset is None or isinstance(dataset, str)
115
+
116
+ tgt_path = self.dump_image(line, dataset)
117
+
118
+ if dataset is not None and listinstr(['MME'], dataset):
119
+ question = line['question']
120
+ prompt = question + ' Answer the question using a single word or phrase.'
121
+ elif dataset is not None and listinstr(['HallusionBench'], dataset):
122
+ question = line['question']
123
+ prompt = question + ' Please answer yes or no. Answer the question using a single word or phrase.'
124
+ elif dataset is not None and DATASET_TYPE(dataset) == 'MCQ':
125
+ prompt = self.build_multi_choice_prompt(line, dataset)
126
+ elif dataset is not None and DATASET_TYPE(dataset) == 'VQA':
127
+ if listinstr(['MathVista', 'MathVision'], dataset):
128
+ prompt = line['question']
129
+ elif listinstr(['LLaVABench'], dataset):
130
+ question = line['question']
131
+ prompt = question + '\nAnswer this question in detail.'
132
+ elif listinstr(['MMVet'], dataset):
133
+ prompt = line['question']
134
+ else:
135
+ question = line['question']
136
+ prompt = question + '\nAnswer the question using a single word or phrase.'
137
+ else:
138
+ prompt = line['question']
139
+ message = [dict(type='text', value=prompt)]
140
+ message.extend([dict(type='image', value=s) for s in tgt_path])
141
+ return message
142
+
143
+ def message_to_promptimg(self, message, dataset=None):
144
+ assert not self.INTERLEAVE
145
+ model_name = self.__class__.__name__
146
+ import warnings
147
+ warnings.warn(
148
+ f'Model {model_name} does not support interleaved input. '
149
+ 'Will use the first image and aggregated texts as prompt. ')
150
+ num_images = len([x for x in message if x['type'] == 'image'])
151
+ if num_images == 0:
152
+ prompt = '\n'.join([x['value'] for x in message if x['type'] == 'text'])
153
+ image = None
154
+ else:
155
+ prompt = '\n'.join([x['value'] for x in message if x['type'] == 'text'])
156
+ if dataset == 'BLINK':
157
+ image = concat_images_vlmeval(
158
+ [x['value'] for x in message if x['type'] == 'image'],
159
+ target_size=512)
160
+ else:
161
+ image = [x['value'] for x in message if x['type'] == 'image'][0]
162
+ return prompt, image
163
+
164
+ def get_send_data(self, prompt, image_path, temperature, max_tokens):
165
+ image = ''
166
+ with open(image_path, 'rb') as f:
167
+ image = str(base64.b64encode(f.read()), 'utf-8')
168
+ send_data = {
169
+ "messages": [
170
+ {
171
+ "role": "user",
172
+ "content": prompt
173
+ }
174
+ ],
175
+ "image_base64": image,
176
+ "max_tokens": max_tokens,
177
+ "temperature": temperature
178
+ }
179
+ return send_data
180
+
181
+ def get_send_data_no_image(self, prompt, temperature, max_tokens):
182
+ send_data = {
183
+ "messages": [
184
+ {
185
+ "role": "user",
186
+ "content": prompt
187
+ }
188
+ ],
189
+ "max_tokens": max_tokens,
190
+ "temperature": temperature
191
+ }
192
+ return send_data
193
+
194
+ def generate_inner(self, inputs, **kwargs) -> str:
195
+ assert isinstance(inputs, str) or isinstance(inputs, list)
196
+ inputs = [inputs] if isinstance(inputs, str) else inputs
197
+ dataset = kwargs.get('dataset', None)
198
+ prompt, image_path = self.message_to_promptimg(message=inputs, dataset=dataset)
199
+ # print("prompt:",prompt)
200
+ if image_path:
201
+ send_data = self.get_send_data(
202
+ prompt=prompt,
203
+ image_path=image_path,
204
+ temperature=self.temperature,
205
+ max_tokens=self.max_tokens)
206
+ else:
207
+ send_data = self.get_send_data_no_image(
208
+ prompt=prompt,
209
+ temperature=self.temperature,
210
+ max_tokens=self.max_tokens)
211
+
212
+ json_data = json.dumps(send_data)
213
+
214
+ header_dict = {'Content-Type': 'application/json', 'Authorization': 'Bearer ' + self.key}
215
+
216
+ r = requests.post(self.api_base, headers=header_dict, data=json_data, timeout=3000)
217
+ try:
218
+ assert r.status_code == 200
219
+ r_json = r.json()
220
+ output = r_json['choices'][0]['message']['content']
221
+ if self.verbose:
222
+ self.logger.info(f'inputs: {inputs}\nanswer: {output}')
223
+
224
+ return 0, output, 'Succeeded! '
225
+
226
+ except:
227
+ error_msg = f'Error! code {r.status_code} content: {r.content}'
228
+ error_con = r.content.decode('utf-8')
229
+ if self.verbose:
230
+ self.logger.error(error_msg)
231
+ self.logger.error(error_con)
232
+ self.logger.error(f'The input messages are {inputs}.')
233
+ return -1, error_msg, ''
234
+
235
+
236
+ class JTVLChatAPI(JTVLChatWrapper):
237
+
238
+ def generate(self, message, dataset=None):
239
+ return super(JTVLChatAPI, self).generate(message, dataset=dataset)
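A usage sketch for the wrapper above. This is a sketch only: the image path and question are placeholders, the module's built-in endpoint and app code are used by default, and the messages follow the standard VLMEvalKit format produced by build_prompt.

    from vlmeval.api.jt_vl_chat import JTVLChatAPI

    model = JTVLChatAPI()
    msgs = [
        dict(type='text', value='What is written on the sign?'),
        dict(type='image', value='./sign.jpg'),
    ]
    # Internally the single image is base64-encoded and posted as a top-level
    # 'image_base64' field alongside 'messages', 'max_tokens' and 'temperature'.
    print(model.generate(msgs, dataset='OCRBench'))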
VLMEvalKit/vlmeval/api/qwen_api.py ADDED
@@ -0,0 +1,75 @@
1
+ from http import HTTPStatus
2
+ import os
3
+ from vlmeval.api.base import BaseAPI
4
+ from vlmeval.smp import *
5
+
6
+
7
+ # Note: This is a pure language model API.
8
+ class QwenAPI(BaseAPI):
9
+
10
+ is_api: bool = True
11
+
12
+ def __init__(self,
13
+ model: str = 'qwen-max-1201',
14
+ retry: int = 5,
15
+ wait: int = 5,
16
+ verbose: bool = True,
17
+ seed: int = 2680,
18
+ temperature: float = 0.0,
19
+ system_prompt: str = None,
20
+ key: str = None,
21
+ max_tokens: int = 1024,
22
+ proxy: str = None,
23
+ **kwargs):
24
+
25
+ assert model in ['qwen-turbo', 'qwen-plus', 'qwen-max', 'qwen-max-1201', 'qwen-max-longcontext']
26
+ self.model = model
27
+ import dashscope
28
+ self.fail_msg = 'Failed to obtain answer via API. '
29
+ self.max_tokens = max_tokens
30
+ self.temperature = temperature
31
+ self.seed = seed
32
+ if key is None:
33
+ key = os.environ.get('DASHSCOPE_API_KEY', None)
34
+ assert key is not None, (
35
+ 'Please set the API Key (obtain it here: '
36
+ 'https://help.aliyun.com/zh/dashscope/developer-reference/vl-plus-quick-start)'
37
+ )
38
+ dashscope.api_key = key
39
+ if proxy is not None:
40
+ proxy_set(proxy)
41
+ super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)
42
+
43
+ @staticmethod
44
+ def build_msgs(msgs_raw, system_prompt=None):
45
+ msgs = cp.deepcopy(msgs_raw)
46
+ ret = []
47
+ if system_prompt is not None:
48
+ ret.append(dict(role='system', content=system_prompt))
49
+ for i, msg in enumerate(msgs):
50
+ role = 'user' if i % 2 == 0 else 'assistant'
51
+ ret.append(dict(role=role, content=msg))
52
+ return ret
53
+
54
+ def generate_inner(self, inputs, **kwargs) -> str:
55
+ from dashscope import MultiModalConversation
56
+ assert isinstance(inputs, str) or isinstance(inputs, list)
57
+ inputs = [inputs] if isinstance(inputs, str) else inputs
58
+ messages = self.build_msgs(msgs_raw=inputs, system_prompt=self.system_prompt)
59
+
60
+ import dashscope
61
+ response = dashscope.Generation.call(
62
+ model=self.model,
63
+ messages=messages,
64
+ seed=self.seed,
65
+ temperature=self.temperature,
66
+ max_tokens=self.max_tokens,
67
+ result_format='message', # set the result to be "message" format.
68
+ )
69
+ if response.status_code != HTTPStatus.OK:
70
+ return -1, 'Error: Bad Response Status Code. ', f'The response status code is {response.status_code}. '
71
+
72
+ try:
73
+ return 0, response['output']['choices'][0]['message']['content'].strip(), 'Succeeded! '
74
+ except Exception as err:
75
+ return -1, f'Error: Failed to parse the response. {err}', response
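To make the role assignment above concrete, here is a small sketch that calls the build_msgs staticmethod directly; no API key is needed for this part, and the strings are placeholders.

    from vlmeval.api.qwen_api import QwenAPI

    # Even-indexed items become 'user' turns, odd-indexed items 'assistant' turns,
    # with the optional system prompt prepended.
    msgs = QwenAPI.build_msgs(
        ['Hi', 'Hello! How can I help?', 'Name the capital of France.'],
        system_prompt='You are a concise assistant.')
    for m in msgs:
        print(m['role'], '->', m['content'])
    # system -> You are a concise assistant.
    # user -> Hi
    # assistant -> Hello! How can I help?
    # user -> Name the capital of France.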
VLMEvalKit/vlmeval/api/qwen_vl_api.py ADDED
@@ -0,0 +1,219 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import warnings
5
+
6
+ from vlmeval.smp import *
7
+ from vlmeval.api.base import BaseAPI
8
+ from vlmeval.vlm.qwen2_vl.prompt import Qwen2VLPromptMixin
9
+
10
+
11
+ def ensure_image_url(image: str) -> str:
12
+ prefixes = ['http://', 'https://', 'file://', 'data:image;']
13
+ if any(image.startswith(prefix) for prefix in prefixes):
14
+ return image
15
+ if os.path.exists(image):
16
+ return 'file://' + image
17
+ raise ValueError(f'Invalid image: {image}')
18
+
19
+
20
+ class Qwen2VLAPI(Qwen2VLPromptMixin, BaseAPI):
21
+ is_api: bool = True
22
+
23
+ def __init__(
24
+ self,
25
+ model: str = 'qwen-vl-max-0809',
26
+ key: str | None = None,
27
+ min_pixels: int | None = None,
28
+ max_pixels: int | None = None,
29
+ max_length=1024,
30
+ top_p=0.001,
31
+ top_k=1,
32
+ temperature=0.01,
33
+ repetition_penalty=1.0,
34
+ presence_penalty=0.0,
35
+ seed=3407,
36
+ use_custom_prompt: bool = True,
37
+ **kwargs,
38
+ ):
39
+ import dashscope
40
+
41
+ self.model = model
42
+ self.min_pixels = min_pixels
43
+ self.max_pixels = max_pixels
44
+ self.generate_kwargs = dict(
45
+ max_length=max_length,
46
+ top_p=top_p,
47
+ top_k=top_k,
48
+ temperature=temperature,
49
+ repetition_penalty=repetition_penalty,
50
+ presence_penalty=presence_penalty,
51
+ seed=seed,
52
+ )
53
+
54
+ key = os.environ.get('DASHSCOPE_API_KEY', None) if key is None else key
55
+ assert key is not None, (
56
+ 'Please set the API Key (obtain it here: '
57
+ 'https://help.aliyun.com/zh/dashscope/developer-reference/vl-plus-quick-start)'
58
+ )
59
+ dashscope.api_key = key
60
+ super().__init__(use_custom_prompt=use_custom_prompt, **kwargs)
61
+
62
+ def _prepare_content(self, inputs: list[dict[str, str]], dataset: str | None = None) -> list[dict[str, str]]:
63
+ """
64
+ inputs list[dict[str, str]], each dict has keys: ['type', 'value']
65
+ """
66
+ content = []
67
+ for s in inputs:
68
+ if s['type'] == 'image':
69
+ item = {'type': 'image', 'image': ensure_image_url(s['value'])}
70
+ if dataset == 'OCRBench':
71
+ item['min_pixels'] = 10 * 10 * 28 * 28
72
+ warnings.warn(f"OCRBench dataset uses custom min_pixels={item['min_pixels']}")
73
+ if self.max_pixels is not None:
74
+ item['max_pixels'] = self.max_pixels
75
+ else:
76
+ if self.min_pixels is not None:
77
+ item['min_pixels'] = self.min_pixels
78
+ if self.max_pixels is not None:
79
+ item['max_pixels'] = self.max_pixels
80
+ elif s['type'] == 'text':
81
+ item = {'type': 'text', 'text': s['value']}
82
+ else:
83
+ raise ValueError(f"Invalid message type: {s['type']}, {s}")
84
+ content.append(item)
85
+ return content
86
+
87
+ def generate_inner(self, inputs, **kwargs) -> str:
88
+ import dashscope
89
+
90
+ messages = []
91
+ if self.system_prompt is not None:
92
+ messages.append({'role': 'system', 'content': self.system_prompt})
93
+ messages.append(
94
+ {'role': 'user', 'content': self._prepare_content(inputs, dataset=kwargs.get('dataset', None))}
95
+ )
96
+ if self.verbose:
97
+ print(f'\033[31m{messages}\033[0m')
98
+
99
+ # generate
100
+ generation_kwargs = self.generate_kwargs.copy()
101
+ kwargs.pop('dataset', None)
102
+ generation_kwargs.update(kwargs)
103
+ try:
104
+ response = dashscope.MultiModalConversation.call(
105
+ model=self.model,
106
+ messages=messages,
107
+ **generation_kwargs,
108
+ )
109
+ if self.verbose:
110
+ print(response)
111
+ answer = response.output.choices[0]['message']['content'][0]['text']
112
+ return 0, answer, 'Succeeded! '
113
+ except Exception as err:
114
+ if self.verbose:
115
+ self.logger.error(f'{type(err)}: {err}')
116
+ self.logger.error(f'The input messages are {inputs}.')
117
+ return -1, '', ''
118
+
119
+
120
+ class QwenVLWrapper(BaseAPI):
121
+
122
+ is_api: bool = True
123
+
124
+ def __init__(self,
125
+ model: str = 'qwen-vl-plus',
126
+ retry: int = 5,
127
+ wait: int = 5,
128
+ key: str = None,
129
+ verbose: bool = True,
130
+ temperature: float = 0.0,
131
+ system_prompt: str = None,
132
+ max_tokens: int = 1024,
133
+ proxy: str = None,
134
+ **kwargs):
135
+
136
+ assert model in ['qwen-vl-plus', 'qwen-vl-max']
137
+ self.model = model
138
+ import dashscope
139
+ self.fail_msg = 'Failed to obtain answer via API. '
140
+ self.max_tokens = max_tokens
141
+ self.temperature = temperature
142
+ if key is None:
143
+ key = os.environ.get('DASHSCOPE_API_KEY', None)
144
+ assert key is not None, (
145
+ 'Please set the API Key (obtain it here: '
146
+ 'https://help.aliyun.com/zh/dashscope/developer-reference/vl-plus-quick-start)'
147
+ )
148
+ dashscope.api_key = key
149
+ if proxy is not None:
150
+ proxy_set(proxy)
151
+ super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)
152
+
153
+ # inputs can be a lvl-2 nested list: [content1, content2, content3, ...]
154
+ # content can be a string or a list of image & text
155
+ def prepare_itlist(self, inputs):
156
+ assert np.all([isinstance(x, dict) for x in inputs])
157
+ has_images = np.sum([x['type'] == 'image' for x in inputs])
158
+ if has_images:
159
+ content_list = []
160
+ for msg in inputs:
161
+ if msg['type'] == 'text':
162
+ content_list.append(dict(text=msg['value']))
163
+ elif msg['type'] == 'image':
164
+ content_list.append(dict(image='file://' + msg['value']))
165
+ else:
166
+ assert all([x['type'] == 'text' for x in inputs])
167
+ text = '\n'.join([x['value'] for x in inputs])
168
+ content_list = [dict(text=text)]
169
+ return content_list
170
+
171
+ def prepare_inputs(self, inputs):
172
+ input_msgs = []
173
+ if self.system_prompt is not None:
174
+ input_msgs.append(dict(role='system', content=self.system_prompt))
175
+ assert isinstance(inputs, list) and isinstance(inputs[0], dict)
176
+ assert np.all(['type' in x for x in inputs]) or np.all(['role' in x for x in inputs]), inputs
177
+ if 'role' in inputs[0]:
178
+ assert inputs[-1]['role'] == 'user', inputs[-1]
179
+ for item in inputs:
180
+ input_msgs.append(dict(role=item['role'], content=self.prepare_itlist(item['content'])))
181
+ else:
182
+ input_msgs.append(dict(role='user', content=self.prepare_itlist(inputs)))
183
+ return input_msgs
184
+
185
+ def generate_inner(self, inputs, **kwargs) -> str:
186
+ from dashscope import MultiModalConversation
187
+ assert isinstance(inputs, str) or isinstance(inputs, list)
188
+
189
+ if 'type' in inputs[0]:
190
+ pure_text = np.all([x['type'] == 'text' for x in inputs])
191
+ else:
192
+ pure_text = True
193
+ for inp in inputs:
194
+ if not np.all([x['type'] == 'text' for x in inp['content']]):
195
+ pure_text = False
196
+ break
197
+
198
+ assert not pure_text
199
+ messages = self.prepare_inputs(inputs)
200
+ gen_config = dict(max_output_tokens=self.max_tokens, temperature=self.temperature)
201
+ gen_config.update(kwargs)
202
+ try:
203
+ response = MultiModalConversation.call(model=self.model, messages=messages)
204
+ if self.verbose:
205
+ print(response)
206
+ answer = response.output.choices[0]['message']['content'][0]['text']
207
+ return 0, answer, 'Succeeded! '
208
+ except Exception as err:
209
+ if self.verbose:
210
+ self.logger.error(f'{type(err)}: {err}')
211
+ self.logger.error(f'The input messages are {inputs}.')
212
+
213
+ return -1, '', ''
214
+
215
+
216
+ class QwenVLAPI(QwenVLWrapper):
217
+
218
+ def generate(self, message, dataset=None):
219
+ return super(QwenVLAPI, self).generate(message)
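A small sketch of the input conventions used above. Assumptions: dashscope is installed, DASHSCOPE_API_KEY is exported, and the file paths are placeholders.

    from vlmeval.api.qwen_vl_api import ensure_image_url
    from vlmeval.api import Qwen2VLAPI

    print(ensure_image_url('https://example.com/cat.png'))  # already a URL, returned unchanged
    # A local path that exists becomes 'file://<path>'; anything else raises ValueError.

    model = Qwen2VLAPI(model='qwen-vl-max-0809', max_pixels=1280 * 28 * 28)
    msgs = [
        dict(type='image', value='./chart.png'),
        dict(type='text', value='Which bar is the highest?'),
    ]
    print(model.generate(msgs))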
VLMEvalKit/vlmeval/api/reka.py ADDED
@@ -0,0 +1,60 @@
1
+ from vlmeval.smp import *
2
+ from vlmeval.api.base import BaseAPI
3
+ from time import sleep
4
+ import mimetypes
5
+
6
+
7
+ class Reka_Wrapper(BaseAPI):
8
+
9
+ is_api: bool = True
10
+ INTERLEAVE: bool = False
11
+
12
+ def __init__(self,
13
+ model: str = 'reka-flash-20240226',
14
+ key: str = None,
15
+ retry: int = 10,
16
+ wait: int = 3,
17
+ system_prompt: str = None,
18
+ verbose: bool = True,
19
+ temperature: float = 0,
20
+ max_tokens: int = 1024,
21
+ **kwargs):
22
+
23
+ try:
24
+ import reka
25
+ except ImportError:
26
+ raise ImportError('Please install reka by running "pip install reka-api"')
27
+
28
+ self.model = model
29
+ default_kwargs = dict(temperature=temperature, request_output_len=max_tokens)
30
+ default_kwargs.update(kwargs)
31
+ self.kwargs = default_kwargs
32
+ if key is not None:
33
+ self.key = key
34
+ else:
35
+ self.key = os.environ.get('REKA_API_KEY', '')
36
+ super().__init__(retry=retry, wait=wait, verbose=verbose, system_prompt=system_prompt, **kwargs)
37
+
38
+ def generate_inner(self, inputs, **kwargs) -> str:
39
+ import reka
40
+ reka.API_KEY = self.key
41
+ dataset = kwargs.pop('dataset', None)
42
+ prompt, image_path = self.message_to_promptimg(inputs, dataset=dataset)
43
+ image_b64 = encode_image_file_to_base64(image_path)
44
+
45
+ response = reka.chat(
46
+ model_name=self.model,
47
+ human=prompt,
48
+ media_url=f'data:image/jpeg;base64,{image_b64}',
49
+ **self.kwargs)
50
+
51
+ try:
52
+ return 0, response['text'], response
53
+ except Exception as err:
54
+ return -1, self.fail_msg + str(err), response
55
+
56
+
57
+ class Reka(Reka_Wrapper):
58
+
59
+ def generate(self, message, dataset=None):
60
+ return super(Reka_Wrapper, self).generate(message)
VLMEvalKit/vlmeval/api/sensechat_vision.py ADDED
@@ -0,0 +1,261 @@
1
+ from vlmeval.smp import *
2
+ from vlmeval.api.base import BaseAPI
3
+ from vlmeval.dataset import img_root_map
4
+ from vlmeval.dataset import DATASET_TYPE
5
+
6
+
7
+ class SenseChatVisionWrapper(BaseAPI):
8
+
9
+ is_api: bool = True
10
+
11
+ def __init__(self,
12
+ model: str = 'SenseChat-5-Vision',
13
+ retry: int = 5,
14
+ wait: int = 5,
15
+ ak: str = None,
16
+ sk: str = None,
17
+ verbose: bool = True,
18
+ system_prompt: str = None,
19
+ max_tokens: int = 1024,
20
+ proxy: str = None,
21
+ **kwargs):
22
+
23
+ self.model = model
24
+ self.fail_msg = 'Failed to obtain answer via API. '
25
+ self.ak = os.environ.get('SENSECHAT_AK', None) if ak is None else ak
26
+ self.sk = os.environ.get('SENSECHAT_SK', None) if sk is None else sk
27
+ assert self.ak is not None and self.sk is not None
28
+ self.max_new_tokens = max_tokens
29
+ super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)
30
+
31
+ def dump_image(self, line, dataset):
32
+ """Dump the image(s) of the input line to the corresponding dataset folder.
33
+
34
+ Args:
35
+ line (line of pd.DataFrame): The raw input line.
36
+ dataset (str): The name of the dataset.
37
+
38
+ Returns:
39
+ str | list[str]: The paths of the dumped images.
40
+ """
41
+ ROOT = LMUDataRoot()
42
+ assert isinstance(dataset, str)
43
+ img_root = osp.join(ROOT, 'images', img_root_map(dataset))
44
+ os.makedirs(img_root, exist_ok=True)
45
+ if 'image' in line:
46
+ if isinstance(line['image'], list):
47
+ tgt_path = []
48
+ assert 'image_path' in line
49
+ for img, im_name in zip(line['image'], line['image_path']):
50
+ path = osp.join(img_root, im_name)
51
+ if not read_ok(path):
52
+ decode_base64_to_image_file(img, path)
53
+ tgt_path.append(path)
54
+ else:
55
+ tgt_path = osp.join(img_root, f"{line['index']}.jpg")
56
+ if not read_ok(tgt_path):
57
+ decode_base64_to_image_file(line['image'], tgt_path)
58
+ tgt_path = [tgt_path]
59
+ else:
60
+ assert 'image_path' in line
61
+ tgt_path = toliststr(line['image_path'])
62
+
63
+ return tgt_path
64
+
65
+ def image_to_base64(self, image_path):
66
+ import base64
67
+ with open(image_path, 'rb') as image_file:
68
+ encoded_string = base64.b64encode(image_file.read())
69
+ return encoded_string.decode('utf-8')
70
+
71
+ def encode_jwt_token(self, ak, sk):
72
+ import jwt
73
+ headers = {'alg': 'HS256', 'typ': 'JWT'}
74
+ payload = {
75
+ 'iss': ak,
76
+ 'exp': int(time.time())
77
+ + 1800, # 填写您期望的有效时间,此处示例代表当前时间+30分钟
78
+ 'nbf': int(time.time()) - 5, # 填写您期望的生效时间,此处示例代表当前时间-5秒
79
+ }
80
+ token = jwt.encode(payload, sk, headers=headers)
81
+ return token
82
+
83
+ def use_custom_prompt(self, dataset):
84
+ return True
85
+
86
+ def build_multi_choice_prompt(self, line, dataset=None):
87
+ question = line['question']
88
+ hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
89
+ if hint is not None:
90
+ question = hint + '\n' + question
91
+
92
+ options = {
93
+ cand: line[cand]
94
+ for cand in string.ascii_uppercase
95
+ if cand in line and not pd.isna(line[cand])
96
+ }
97
+ for key, item in options.items():
98
+ question += f'\n{key}. {item}'
99
+ prompt = question
100
+
101
+ if len(options):
102
+ prompt += '\n请直接回答选项字母。' if cn_string(
103
+ prompt) else "\nAnswer with the option's letter from the given choices directly."
104
+ else:
105
+ prompt += '\n请直接回答问题。' if cn_string(prompt) else '\nAnswer the question directly.'
106
+
107
+ return prompt
108
+
109
+ def build_prompt(self, line, dataset=None):
110
+ assert self.use_custom_prompt(dataset)
111
+ assert dataset is None or isinstance(dataset, str)
112
+
113
+ tgt_path = self.dump_image(line, dataset)
114
+
115
+ if dataset is not None and listinstr(['MME'], dataset):
116
+ question = line['question']
117
+ prompt = question + ' Answer the question using a single word or phrase.'
118
+ elif dataset is not None and listinstr(['HallusionBench'], dataset):
119
+ question = line['question']
120
+ prompt = question + ' Please answer yes or no. Answer the question using a single word or phrase.'
121
+ elif dataset is not None and DATASET_TYPE(dataset) == 'MCQ' and 'MMMU' not in dataset:
122
+ prompt = self.build_multi_choice_prompt(line, dataset)
123
+ elif dataset is not None and DATASET_TYPE(dataset) == 'VQA':
124
+ if 'MathVista' in dataset:
125
+ prompt = line['question']
126
+ elif listinstr(['LLaVABench'], dataset):
127
+ question = line['question']
128
+ prompt = question + '\nAnswer this question in detail.'
129
+ elif listinstr(['MMVet'], dataset):
130
+ prompt = line['question']
131
+ else:
132
+ question = line['question']
133
+ prompt = question + '\nAnswer the question using a single word or phrase.'
134
+ elif dataset is not None and 'MMMU' in dataset:
135
+ question = line['question']
136
+ options = {
137
+ cand: line[cand]
138
+ for cand in string.ascii_uppercase
139
+ if cand in line and not pd.isna(line[cand])
140
+ }
141
+ for key, item in options.items():
142
+ question += f'\n{key}. {item}'
143
+ prompt = {
144
+ 'multiple-choice': 'Answer with carefully thought step by step. Apply the thinking process recursively at both macro and micro levels. Verify consistency of reasoning and look for potential flaws or gaps during thinking. When realize mistakes, explain why the previous thinking was incorrect, fix it and then continue thinking.\n\n', # noqa
145
+ 'open': 'Answer with carefully thought step by step. Apply the thinking process recursively at both macro and micro levels. Verify consistency of reasoning and look for potential flaws or gaps during thinking. When realize mistakes, explain why the previous thinking was incorrect, fix it and then continue thinking.\n\n' # noqa
146
+ }
147
+ subject = '_'.join(line['id'].split('_')[1:-1])
148
+ prompt = prompt[line['question_type']].format(subject, subject) + '\n' + question
149
+ else:
150
+ prompt = line['question']
151
+
152
+ message = [dict(type='text', value=prompt)]
153
+ message.extend([dict(type='image', value=s) for s in tgt_path])
154
+
155
+ return message
156
+
157
+ def message_to_promptimg(self, message, dataset=None):
158
+ if dataset is None or listinstr(['MMMU', 'BLINK'], dataset):
159
+ prompt = '\n'.join([x['value'] for x in message if x['type'] == 'text'])
160
+ image = [[x['value'] for x in message if x['type'] == 'image'][0]]
161
+ else:
162
+ prompt = '\n'.join([x['value'] for x in message if x['type'] == 'text'])
163
+ image = [x['value'] for x in message if x['type'] == 'image']
164
+ return prompt, image
165
+
166
+ def generate_inner(self, inputs, **kwargs) -> str:
167
+ assert isinstance(inputs, str) or isinstance(inputs, list)
168
+ inputs = [inputs] if isinstance(inputs, str) else inputs
169
+ dataset = kwargs.get('dataset', None)
170
+
171
+ if dataset is not None and listinstr(['ChartQA_TEST','MathVista_MINI'], dataset):
172
+ self.max_num = 12
173
+ elif dataset is not None and listinstr(['DocVQA_VAL', 'DocVQA_TEST'], dataset):
174
+ self.max_num = 18
175
+ elif dataset is not None and listinstr(['InfoVQA_VAL', 'InfoVQA_TEST', 'OCRBench'], dataset):
176
+ self.max_num = 24
177
+ else:
178
+ self.max_num = 6
179
+
180
+ if dataset is None:
181
+ pass
182
+ elif listinstr(['AI2D_TEST'], dataset):
183
+ self.max_new_tokens = 10
184
+ elif 'MMMU' in dataset:
185
+ self.max_new_tokens = 4096 # 1024
186
+ elif 'MMBench' in dataset:
187
+ self.max_new_tokens = 100
188
+ elif 'MathVista_MINI' in dataset:
189
+ self.max_new_tokens = 4096
190
+
191
+ prompt, image = self.message_to_promptimg(message=inputs, dataset=dataset)
192
+
193
+ url = 'https://api.sensenova.cn/v1/llm/chat-completions'
194
+ api_secret_key = self.encode_jwt_token(self.ak, self.sk)
195
+
196
+ content = [{
197
+ 'image_base64': self.image_to_base64(item),
198
+ 'image_file_id': '',
199
+ 'image_url': '',
200
+ 'text': '',
201
+ 'text': '',
202
+ 'type': 'image_base64'
203
+ } for item in image]
204
+
205
+ content.append({
206
+ 'image_base64': '',
207
+ 'image_file_id': '',
208
+ 'image_url': '',
209
+ 'text': prompt,
210
+ 'type': 'text'
211
+ })
212
+
213
+ message = [{'content': content, 'role': 'user'}]
214
+
215
+ data = {
216
+ 'messages': message,
217
+ 'max_new_tokens': self.max_new_tokens, # 1024
218
+ 'temperature': 0,
219
+ "top_k": 0,
220
+ "top_p": 0.99,
221
+ 'repetition_penalty': 1.05,
222
+ 'model': self.model,
223
+ 'stream': False,
224
+ }
225
+ headers = {
226
+ 'Content-type': 'application/json',
227
+ 'Authorization': 'Bearer ' + api_secret_key
228
+ }
229
+
230
+ response = requests.post(
231
+ url,
232
+ headers=headers,
233
+ json=data,
234
+ )
235
+ request_id = response.headers['x-request-id']
236
+
237
+ time.sleep(1)
238
+ try:
239
+ assert response.status_code == 200
240
+ response = response.json()['data']['choices'][0]['message'].strip()
241
+ if self.verbose:
242
+ self.logger.info(f'inputs: {inputs}\nanswer: {response}')
243
+ return 0, response, 'Succeeded! '
244
+ except Exception as err:
245
+ if self.verbose:
246
+ self.logger.error('---------------------------ERROR---------------------------')
247
+ self.logger.error(response.json())
248
+ self.logger.error(err)
249
+ self.logger.error('---------------------------request_id---------------------------' + request_id)
250
+ self.logger.error(
251
+ 'api error' + response.json()['error']['message']
252
+ + str([input['value'] if input['type'] == 'image' else None for input in inputs])
253
+ )
254
+ self.logger.error(f'The input messages are {inputs}.')
255
+ return -1, response.json()['error']['message'], ''
256
+
257
+
258
+ class SenseChatVisionAPI(SenseChatVisionWrapper):
259
+
260
+ def generate(self, message, dataset=None):
261
+ return super(SenseChatVisionAPI, self).generate(message, dataset=dataset)
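For clarity, the JWT construction inside encode_jwt_token boils down to the following standalone sketch (PyJWT); the ak/sk values here are placeholders, not real credentials.

    import time
    import jwt  # PyJWT

    ak, sk = 'demo-access-key-id', 'demo-secret-access-key'
    payload = {
        'iss': ak,
        'exp': int(time.time()) + 1800,  # valid for 30 minutes
        'nbf': int(time.time()) - 5,     # effective 5 seconds in the past to absorb clock skew
    }
    token = jwt.encode(payload, sk, headers={'alg': 'HS256', 'typ': 'JWT'})
    # The result is sent as 'Authorization: Bearer <token>' with every request.
    print(str(token)[:24], '...')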
VLMEvalKit/vlmeval/api/siliconflow.py ADDED
@@ -0,0 +1,269 @@
1
+ import math
2
+ from vlmeval.smp import *
3
+ from vlmeval.api.base import BaseAPI
4
+ from vlmeval.dataset import img_root_map
5
+
6
+ API_BASE = "https://api.siliconflow.cn/v1/chat/completions"
7
+
8
+
9
+ def resize_image(image: Image.Image, max_height: int, max_width: int) -> Image.Image:
10
+ width, height = image.size
11
+ if min(width, height) < 50:
12
+ scale = 50 / min(width, height)
13
+ image = image.resize((int(width * scale), int(height * scale)))
14
+ current_pixels = width * height
15
+
16
+ if current_pixels <= max_height * max_width:
17
+ return image
18
+
19
+ scale = math.sqrt(max_height * max_width / current_pixels)
20
+ new_width = int(width * scale)
21
+ new_height = int(height * scale)
22
+
23
+ return image.resize((new_width, new_height), Image.Resampling.LANCZOS)
24
+
25
+
26
+ def encode_image(path: str, max_height: int = 1024, max_width: int = 1024) -> str:
27
+ image = Image.open(path).convert("RGB")
28
+ image = resize_image(image, max_height, max_width)
29
+ height, width = image.size
30
+ if min(height, width) < 50:
31
+ scale = 50 / min(width, height)
32
+ image = image.resize((int(width * scale), int(height * scale)))
33
+ buffered = io.BytesIO()
34
+ image.save(buffered, format="PNG")
35
+ img_bytes = buffered.getvalue()
36
+ img_base64 = base64.b64encode(img_bytes).decode("utf-8")
37
+ return img_base64
38
+
39
+
40
+ class SiliconFlowAPI(BaseAPI):
41
+
42
+ is_api: bool = True
43
+
44
+ def __init__(
45
+ self,
46
+ model: str = "deepseek-ai/DeepSeek-V2.5",
47
+ retry: int = 5,
48
+ wait: int = 5,
49
+ key: str = None,
50
+ api_base: str = API_BASE,
51
+ verbose: bool = True,
52
+ system_prompt: str = None,
53
+ timeout: int = 60,
54
+ **kwargs,
55
+ ):
56
+
57
+ self.model = model
58
+ self.api_base = api_base
59
+
60
+ default_kwargs = {
61
+ "stream": False,
62
+ "temperature": 0,
63
+ "n": 1,
64
+ "max_tokens": 1280,
65
+ }
66
+ for k, v in default_kwargs.items():
67
+ if k not in kwargs:
68
+ kwargs[k] = default_kwargs[k]
69
+ if key is not None:
70
+ self.key = key
71
+ else:
72
+ self.key = os.environ.get("SiliconFlow_API_KEY", "")
73
+ headers = {"Authorization": "Bearer {}", "Content-Type": "application/json"}
74
+ headers["Authorization"] = headers["Authorization"].format(self.key)
75
+ self.headers = headers
76
+ super().__init__(
77
+ wait=wait,
78
+ retry=retry,
79
+ system_prompt=system_prompt,
80
+ verbose=verbose,
81
+ **kwargs,
82
+ )
83
+
84
+ @staticmethod
85
+ def build_msgs(msgs_raw):
86
+ messages = []
87
+ message = {"role": "user", "content": []}
88
+ image_b64 = None
89
+ for msg in msgs_raw:
90
+ if msg["type"] == "image" and not image_b64:
91
+ image_b64 = encode_image(msg["value"])
92
+ message["content"].append(
93
+ {"image_url": {"url": image_b64}, "type": "image_url"}
94
+ )
95
+ elif msg["type"] == "text":
96
+ message["content"].append({"text": msg["value"], "type": "text"})
97
+
98
+ messages.append(message)
99
+ return messages
100
+
101
+ def generate_inner(self, inputs, **kwargs) -> str:
102
+ default_kwargs = self.default_kwargs
103
+ default_kwargs.update(kwargs)
104
+
105
+ payload = dict(
106
+ model=self.model,
107
+ messages=self.build_msgs(msgs_raw=inputs),
108
+ **default_kwargs,
109
+ )
110
+
111
+ response = requests.post(
112
+ self.api_base, headers=self.headers, data=json.dumps(payload)
113
+ )
114
+ ret_code = response.status_code
115
+ ret_code = 0 if (200 <= int(ret_code) < 300) else ret_code
116
+
117
+ answer = self.fail_msg
118
+ try:
119
+ resp_struct = json.loads(response.text)
120
+ answer = resp_struct["choices"][0]["message"]["content"].strip()
121
+ except:
122
+ pass
123
+ return ret_code, answer, response
124
+
125
+
126
+ class TeleMMAPI(SiliconFlowAPI):
127
+
128
+ is_api: bool = True
129
+
130
+ def __init__(
131
+ self,
132
+ model: str = "TeleAI/TeleMM",
133
+ key: str = None,
134
+ max_height: int = 1280,
135
+ max_width: int = 784,
136
+ **kwargs,
137
+ ):
138
+ super().__init__(model=model, key=key, **kwargs)
139
+ self.max_height = max_height
140
+ self.max_width = max_width
141
+
142
+ def dump_image(self, line, dataset):
143
+ """Dump the image(s) of the input line to the corresponding dataset folder.
144
+
145
+ Args:
146
+ line (line of pd.DataFrame): The raw input line.
147
+ dataset (str): The name of the dataset.
148
+
149
+ Returns:
150
+ str | list[str]: The paths of the dumped images.
151
+ """
152
+ ROOT = LMUDataRoot()
153
+ assert isinstance(dataset, str)
154
+ # img_root = osp.join(ROOT, 'images', img_root_map[dataset] if dataset in img_root_map else dataset)
155
+ img_root = osp.join(ROOT, "images", img_root_map(dataset))
156
+ os.makedirs(img_root, exist_ok=True)
157
+ if "image" in line:
158
+ if isinstance(line["image"], list):
159
+ tgt_path = []
160
+ assert "image_path" in line
161
+ for img, im_name in zip(line["image"], line["image_path"]):
162
+ path = osp.join(img_root, im_name)
163
+ if not read_ok(path):
164
+ decode_base64_to_image_file(img, path)
165
+ tgt_path.append(path)
166
+ else:
167
+ tgt_path = osp.join(img_root, f"{line['index']}.jpg")
168
+ if not read_ok(tgt_path):
169
+ decode_base64_to_image_file(line["image"], tgt_path)
170
+ tgt_path = [tgt_path]
171
+ else:
172
+ assert "image_path" in line
173
+ tgt_path = toliststr(line["image_path"])
174
+ return tgt_path
175
+
176
+ def _prepare_content(
177
+ self, inputs: list[dict[str, str]], dataset: str = None
178
+ ) -> list[dict[str, str]]:
179
+ """
180
+ inputs list[dict[str, str]], each dict has keys: ['type', 'value']
181
+ """
182
+ content = []
183
+ has_image = False
184
+ for s in inputs:
185
+ if s["type"] == "image":
186
+ if not has_image:
187
+ item = {
188
+ "type": "image_url",
189
+ "image_url": {
190
+ "url": encode_image(
191
+ s["value"],
192
+ max_height=self.max_height,
193
+ max_width=self.max_width,
194
+ )
195
+ },
196
+ }
197
+ has_image = True
198
+ else:
199
+ continue
200
+ elif s["type"] == "text":
201
+ prompt = s["value"]
202
+ if len(prompt) == 0:
203
+ continue
204
+ if dataset == "HallusionBench":
205
+ prompt += " Please answer yes or no directly, without any unnecessary explanation."
206
+ elif dataset == "OCRBench":
207
+ prompt = (
208
+ prompt + "\nExtract the text from the image intactly and "
209
+ + "answer the question concisely and clearly if possible."
210
+ )
211
+
212
+ elif (
213
+ dataset == "AI2D_TEST"
214
+ or dataset == "MMStar"
215
+ or dataset == "MMBench_TEST_EN_V11"
216
+ or dataset == "MMVet"
217
+ ):
218
+ prompt = prompt.replace(
219
+ "Please select the correct answer from the options above. \n",
220
+ "Please select the correct option from the above choices based on the "
221
+ + "input image and question. The final output should only be one option, such as 'A'",
222
+ )
223
+ elif dataset == "MMBench_TEST_CN_V11":
224
+ prompt = prompt.replace(
225
+ "Please select the correct answer from the options above. \n",
226
+ "请根据输入图像和问题从上述选项中选择正确选项,最终的输出只有一个选项,例如'A'",
227
+ )
228
+ item = {"type": "text", "text": prompt}
229
+ else:
230
+ raise ValueError(f"Invalid message type: {s['type']}, {s}")
231
+ content.append(item)
232
+
233
+ return content
234
+
235
+ def generate_inner(self, inputs, **kwargs) -> str:
236
+ default_kwargs = self.default_kwargs
237
+ default_kwargs.update(kwargs)
238
+
239
+ messages = []
240
+ messages.append(
241
+ {
242
+ "role": "user",
243
+ "content": self._prepare_content(
244
+ inputs, dataset=kwargs.get("dataset", None)
245
+ ),
246
+ }
247
+ )
248
+
249
+ payload = dict(model=self.model, messages=messages, **default_kwargs)
250
+
251
+ response = requests.post(
252
+ self.api_base, headers=self.headers, data=json.dumps(payload)
253
+ )
254
+ ret_code = response.status_code
255
+ ret_code = 0 if (200 <= int(ret_code) < 300) else ret_code
256
+
257
+ answer = self.fail_msg
258
+ try:
259
+ resp_struct = json.loads(response.text)
260
+ answer = resp_struct["choices"][0]["message"]["content"].strip()
261
+ return ret_code, answer, response
262
+ except Exception as err:
263
+ import traceback
264
+
265
+ traceback.print_exc()
266
+ if self.verbose:
267
+ self.logger.error(f"{type(err)}: {err}")
268
+ self.logger.error(f"The input messages are {inputs}.")
269
+ return -1, "", ""
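A worked example of the downscaling arithmetic in resize_image above; it needs only Pillow, and the synthetic image stands in for a real photo.

    from PIL import Image
    from vlmeval.api.siliconflow import resize_image

    # A 4000x3000 image has 12,000,000 pixels. With a 1024x1024 budget the scale is
    # sqrt(1024 * 1024 / 12_000_000) ~= 0.2956, so the output is roughly 1182 x 886.
    img = Image.new('RGB', (4000, 3000))
    small = resize_image(img, max_height=1024, max_width=1024)
    print(small.size)  # about (1182, 886); the 4:3 aspect ratio is preserved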
VLMEvalKit/vlmeval/api/stepai.py ADDED
@@ -0,0 +1,87 @@
1
+ from vlmeval.smp import *
2
+ from vlmeval.api.base import BaseAPI
3
+
4
+ url = 'https://api.stepfun.com/v1/chat/completions'
5
+ headers = {
6
+ 'Content-Type': 'application/json',
7
+ 'Authorization': 'Bearer {}',
8
+ }
9
+
10
+
11
+ class StepAPI_INT(BaseAPI):
12
+
13
+ is_api: bool = True
14
+
15
+ def __init__(self,
16
+ model: str = 'step-1v-8k',
17
+ retry: int = 10,
18
+ wait: int = 3,
19
+ key: str = None,
20
+ temperature: float = 0,
21
+ max_tokens: int = 300,
22
+ verbose: bool = True,
23
+ system_prompt: str = None,
24
+ **kwargs):
25
+ self.model = model
26
+ self.fail_msg = 'Fail to obtain answer via API.'
27
+ self.headers = headers
28
+ self.temperature = temperature
29
+ self.max_tokens = max_tokens
30
+ self.system_prompt = system_prompt
31
+ if key is not None:
32
+ self.key = key
33
+ else:
34
+ self.key = os.environ.get('STEPAI_API_KEY', '')
35
+ headers['Authorization'] = headers['Authorization'].format(self.key)
36
+
37
+ super().__init__(retry=retry, wait=wait, verbose=verbose, system_prompt=system_prompt, **kwargs)
38
+
39
+ @staticmethod
40
+ def build_msgs(msgs_raw):
41
+ messages = []
42
+ message = {'role': 'user', 'content': []}
43
+
44
+ for msg in msgs_raw:
45
+ if msg['type'] == 'image':
46
+ image_b64 = encode_image_file_to_base64(msg['value'])
47
+ message['content'].append({
48
+ 'image_url': {'url': 'data:image/webp;base64,%s' % (image_b64)},
49
+ 'type': 'image_url'
50
+ })
51
+ elif msg['type'] == 'text':
52
+ message['content'].append({
53
+ 'text': msg['value'],
54
+ 'type': 'text'
55
+ })
56
+
57
+ messages.append(message)
58
+ return messages
59
+
60
+ def generate_inner(self, inputs, **kwargs) -> str:
61
+ print(inputs, '\n')
62
+ payload = dict(
63
+ model=self.model,
64
+ max_tokens=self.max_tokens,
65
+ temperature=self.temperature,
66
+ messages=self.build_msgs(msgs_raw=inputs),
67
+ **kwargs)
68
+ response = requests.post(url, headers=headers, data=json.dumps(payload))
69
+ ret_code = response.status_code
70
+ ret_code = 0 if (200 <= int(ret_code) < 300) else ret_code
71
+
72
+ answer = self.fail_msg
73
+ try:
74
+ resp_struct = json.loads(response.text)
75
+ answer = resp_struct['choices'][0]['message']['content'].strip()
76
+ except Exception as err:
77
+ if self.verbose:
78
+ self.logger.error(f'{type(err)}: {err}')
79
+ self.logger.error(response.text if hasattr(response, 'text') else response)
80
+
81
+ return ret_code, answer, response
82
+
83
+
84
+ class Step1V_INT(StepAPI_INT):
85
+
86
+ def generate(self, message, dataset=None):
87
+ return super(StepAPI_INT, self).generate(message)
VLMEvalKit/vlmeval/api/taiyi.py ADDED
@@ -0,0 +1,192 @@
1
+ from vlmeval.smp import *
2
+ from vlmeval.api.base import BaseAPI
3
+ from vlmeval.dataset import DATASET_TYPE, img_root_map
4
+
5
+
6
+ class TaiyiWrapper(BaseAPI):
7
+
8
+ is_api: bool = True
9
+
10
+ def __init__(self,
11
+ model: str = 'taiyi',
12
+ retry: int = 5,
13
+ wait: int = 5,
14
+ key: str = None,
15
+ verbose: bool = False,
16
+ system_prompt: str = None,
17
+ temperature: float = 0,
18
+ timeout: int = 60,
19
+ url: str = "https://taiyi.megvii.com/v1/chat/completions",
20
+ max_tokens: int = 1024,
21
+ **kwargs):
22
+
23
+ self.model = model
24
+ self.fail_msg = 'Failed to obtain answer via API. '
25
+ self.max_tokens = max_tokens
26
+ self.temperature = temperature
27
+
28
+ if key is None:
29
+ key = os.environ.get('TAIYI_API_KEY', None)
30
+ assert key is not None, ('Please set the API Key ')
31
+ self.key = key
32
+
33
+ self.timeout = timeout
34
+ super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)
35
+ assert url is not None, ('Please set the url ')
36
+ self.url = url
37
+ self.logger.info(f'Using url: {self.url}; API Key: {self.key}')
38
+
39
+ def use_custom_prompt(self, dataset):
40
+ if DATASET_TYPE(dataset) == 'Y/N' or DATASET_TYPE(dataset) == 'MCQ' or DATASET_TYPE(dataset) == 'VQA':
41
+ return True
42
+ return False
43
+
44
+ def prepare_inputs(self, inputs):
45
+ input_msgs = []
46
+ if self.system_prompt is not None:
47
+ input_msgs.append(dict(role='system', content=self.system_prompt))
48
+ has_images = np.sum([x['type'] == 'image' for x in inputs])
49
+ if has_images:
50
+ content_list = []
51
+ for msg in inputs:
52
+ if msg['type'] == 'text':
53
+ content_list.append(dict(type='text', text=msg['value']))
54
+ elif msg['type'] == 'image':
55
+ imgbytes = open(msg['value'],'rb').read()
56
+ b64 = base64.b64encode(imgbytes).decode('ascii')
57
+ img_struct = dict(url=f'data:image/jpeg;base64,{b64}')
58
+ content_list.append(dict(type='image_url', image_url=img_struct))
59
+ input_msgs.append(dict(role='user', content=content_list))
60
+ else:
61
+ assert all([x['type'] == 'text' for x in inputs])
62
+ text = '\n'.join([x['value'] for x in inputs])
63
+ input_msgs.append(dict(role='user', content=text))
64
+ return input_msgs
65
+
66
+ def set_dump_image(self, dump_image_func):
67
+ self.dump_image_func = dump_image_func
68
+
69
+ def dump_image(self, line, dataset):
70
+ return self.dump_image_func(line)
71
+
72
+ def image_first(self, msgs):
73
+ nr_img = 0
74
+ for s in msgs:
75
+ if s['type'] == 'image':
76
+ nr_img += 1
77
+
78
+ if nr_img == 1:
79
+ new_msgs = []
80
+ img_msg = None
81
+ for s in msgs:
82
+ if s['type'] == 'text':
83
+ new_msgs.append(s)
84
+ else:
85
+ img_msg = s
86
+ new_msgs.insert(0, img_msg)
87
+ else:
88
+ new_msgs = msgs
89
+
90
+ return new_msgs
91
+
92
+ def build_multi_choice_prompt(self, line, dataset=None):
93
+ question = line['question']
94
+ hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
95
+ if hint is not None:
96
+ question = hint + '\n' + question
97
+
98
+ options = {
99
+ cand: line[cand]
100
+ for cand in string.ascii_uppercase
101
+ if cand in line and not pd.isna(line[cand])
102
+ }
103
+ for key, item in options.items():
104
+ question += f'\n{key}. {item}'
105
+ prompt = question
106
+
107
+ if len(options):
108
+ prompt += '\n请直接回答选项字母。' if cn_string(
109
+ prompt) else "\nAnswer with the option's letter from the given choices directly."
110
+ else:
111
+ prompt += '\n请直接回答问题。' if cn_string(prompt) else '\nAnswer the question directly.'
112
+
113
+ return prompt
114
+
115
+ def build_yorn_prompt(self, line, dataset=None):
116
+ if listinstr(['HallusionBench'], dataset):
117
+ pre_prompt = 'Read the following question carefully, think and solve it step by step.\n\n'
118
+ else:
119
+ pre_prompt = ''
120
+
121
+ prompt = pre_prompt + line['question'] + ' Please answer yes or no as the final answer.'
122
+
123
+ return prompt
124
+
125
+ def build_vqa_prompt(self, line, dataset=None):
126
+ if listinstr(['OCRBench'], dataset):
127
+ pre_prompt = 'Carefully identify the text in the image and answer the question.\n\n'
128
+ else:
129
+ pre_prompt = ''
130
+
131
+ if listinstr(['MMVet'], dataset):
132
+ post_prompt = '\nAnswer this question in detail.'
133
+ else:
134
+ post_prompt = ''
135
+
136
+ prompt = pre_prompt + line['question'] + post_prompt
137
+
138
+ return prompt
139
+
140
+ def build_prompt(self, line, dataset=None):
141
+ assert self.use_custom_prompt(dataset)
142
+ assert dataset is None or isinstance(dataset, str)
143
+ tgt_path = self.dump_image(line, dataset)
144
+
145
+ if DATASET_TYPE(dataset) == 'MCQ':
146
+ prompt = self.build_multi_choice_prompt(line, dataset)
147
+ elif DATASET_TYPE(dataset) == 'Y/N':
148
+ prompt = self.build_yorn_prompt(line, dataset)
149
+ elif DATASET_TYPE(dataset) == 'VQA':
150
+ prompt = self.build_vqa_prompt(line, dataset)
151
+ else:
152
+ raise RuntimeError(f'Invalid dataset type: {DATASET_TYPE(dataset)}')
153
+ message = []
154
+ message.extend([dict(type='image', value=s) for s in tgt_path])
155
+ message.extend([dict(type='text', value=prompt)])
156
+
157
+ # interleave dataset
158
+ if dataset.startswith('MMMU_'):
159
+ from .. import MMMUDataset
160
+ message = MMMUDataset.split_MMMU(message)
161
+ message = self.image_first(message)
162
+
163
+ return message
164
+
165
+ def generate_inner(self, inputs, **kwargs) -> str:
166
+
167
+ input_msgs = self.prepare_inputs(inputs)
168
+ temperature = kwargs.pop('temperature', self.temperature)
169
+
170
+ headers = {'Authorization': f'Bearer {self.key}'}
171
+ payload = dict(
172
+ model=self.model,
173
+ messages=input_msgs,
174
+ n=1,
175
+ temperature=temperature,
176
+ **kwargs)
177
+ response = requests.post(self.url, headers=headers, data=json.dumps(payload), timeout=self.timeout * 1.1)
178
+ ret_code = response.status_code
179
+ ret_code = 0 if (200 <= int(ret_code) < 300) else ret_code
180
+ answer = self.fail_msg
181
+ try:
182
+ resp_struct = json.loads(response.text)
183
+ answer = resp_struct['choices'][0]['message']['content'].strip()
184
+ except:
185
+ pass
186
+ return ret_code, answer, response
187
+
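+ # Note on the return convention above: generate_inner returns (ret_code, answer, response),
+ # where ret_code is normalised to 0 for any 2xx HTTP status and answer falls back to
+ # self.fail_msg whenever the response body cannot be parsed as a chat completion.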
188
+
189
+ class TaiyiAPI(TaiyiWrapper):
190
+
191
+ def generate(self, message, dataset=None):
192
+ return super(TaiyiAPI, self).generate(message)
VLMEvalKit/vlmeval/dataset/__init__.py ADDED
@@ -0,0 +1,230 @@
1
+ import warnings
2
+
3
+ from .image_base import img_root_map, ImageBaseDataset
4
+ from .image_caption import ImageCaptionDataset
5
+ from .image_yorn import ImageYORNDataset
6
+ from .image_mcq import (
7
+ ImageMCQDataset, MMMUDataset, CustomMCQDataset, MUIRDataset, GMAIMMBenchDataset, MMERealWorld, HRBenchDataset,
8
+ NaturalBenchDataset
9
+ )
10
+ from .image_mt import MMDUDataset
11
+ from .image_vqa import (
12
+ ImageVQADataset, MathVision, OCRBench, MathVista, LLaVABench, MMVet, MTVQADataset, TableVQABench,
13
+ CustomVQADataset, CRPE, MathVerse, OlympiadBench, QSpatial, VizWiz, MMNIAH
14
+ )
15
+
16
+ from .text_mcq import CustomTextMCQDataset, TextMCQDataset
17
+
18
+ from .vcr import VCRDataset
19
+ from .mmlongbench import MMLongBench
20
+ from .dude import DUDE
21
+ from .slidevqa import SlideVQA
22
+
23
+ from .mmbench_video import MMBenchVideo
24
+ from .videomme import VideoMME
25
+ from .mvbench import MVBench, MVBench_MP4
26
+ from .mlvu import MLVU, MLVU_MCQ, MLVU_OpenEnded
27
+ from .tempcompass import TempCompass, TempCompass_Captioning, TempCompass_MCQ, TempCompass_YorN
28
+ from .longvideobench import LongVideoBench
29
+ from .video_concat_dataset import ConcatVideoDataset
30
+ from .mmgenbench import MMGenBench
31
+
32
+ from .miabench import MIABench
33
+ from .cmmmu import CMMMU
34
+ from .wildvision import WildVision
35
+ from .mmmath import MMMath
36
+ from .dynamath import Dynamath
37
+ from .utils import *
38
+ from ..smp import *
39
+
40
+
41
+ class ConcatDataset(ImageBaseDataset):
42
+ # This dataset takes multiple dataset names as input and aggregate them into a single dataset.
43
+ # Each single dataset should not have a field named `SUB_DATASET`
44
+
45
+ DATASET_SETS = {
46
+ 'MMMB': ['MMMB_ar', 'MMMB_cn', 'MMMB_en', 'MMMB_pt', 'MMMB_ru', 'MMMB_tr'],
47
+ 'MTL_MMBench_DEV': [
48
+ 'MMBench_dev_ar', 'MMBench_dev_cn', 'MMBench_dev_en',
49
+ 'MMBench_dev_pt', 'MMBench_dev_ru', 'MMBench_dev_tr'
50
+ ]
51
+ }
52
+
53
+ def __init__(self, dataset):
54
+ datasets = self.DATASET_SETS[dataset]
55
+ self.dataset_map = {}
56
+ # The name of the compilation
57
+ self.dataset_name = dataset
58
+ self.datasets = datasets
59
+ for dname in datasets:
60
+ dataset = build_dataset(dname)
61
+ assert dataset is not None, dataset
62
+ self.dataset_map[dname] = dataset
63
+ TYPES = [x.TYPE for x in self.dataset_map.values()]
64
+ MODALITIES = [x.MODALITY for x in self.dataset_map.values()]
65
+ assert np.all([x == TYPES[0] for x in TYPES]), (datasets, TYPES)
66
+ assert np.all([x == MODALITIES[0] for x in MODALITIES]), (datasets, MODALITIES)
67
+ self.TYPE = TYPES[0]
68
+ self.MODALITY = MODALITIES[0]
69
+ data_all = []
70
+ for dname in datasets:
71
+ data = self.dataset_map[dname].data
72
+ data['SUB_DATASET'] = [dname] * len(data)
73
+ data_new = localize_df(data, dname, nproc=16)
74
+ data_all.append(data_new)
75
+
76
+ data = pd.concat(data_all)
77
+ data['original_index'] = data.pop('index')
78
+ data['index'] = np.arange(len(data))
79
+ self.data = data
80
+
81
+ def build_prompt(self, line):
82
+ if isinstance(line, int):
83
+ line = self.data.iloc[line]
84
+ idx = line['original_index']
85
+ dname = line['SUB_DATASET']
86
+ org_data = self.dataset_map[dname].data
87
+ org_line = cp.deepcopy(org_data[org_data['index'] == idx]).iloc[0]
88
+ return self.dataset_map[dname].build_prompt(org_line)
89
+
90
+ def dump_image(self, line):
91
+ # Assert all images are pre-dumped
92
+ assert 'image' not in line
93
+ assert 'image_path' in line
94
+ tgt_path = toliststr(line['image_path'])
95
+ return tgt_path
96
+
97
+ @classmethod
98
+ def supported_datasets(cls):
99
+ return list(cls.DATASET_SETS)
100
+
101
+ def evaluate(self, eval_file, **judge_kwargs):
102
+ suffix = eval_file.split('.')[-1]
103
+ # First, split the eval_file by dataset
104
+ data_all = load(eval_file)
105
+ for dname in self.datasets:
106
+ tgt = eval_file.replace(self.dataset_name, dname)
107
+ data_sub = data_all[data_all['SUB_DATASET'] == dname]
108
+ data_sub.pop('index')
109
+ data_sub['index'] = data_sub.pop('original_index')
110
+ data_sub.pop('SUB_DATASET')
111
+ dump(data_sub, tgt)
112
+ # Then, evaluate each dataset separately
113
+ results_all = []
114
+ for dname in self.datasets:
115
+ tgt = eval_file.replace(self.dataset_name, dname)
116
+ res = self.dataset_map[dname].evaluate(tgt, **judge_kwargs)
117
+ assert isinstance(res, pd.DataFrame)
118
+ res['DATASET'] = [dname] * len(res)
119
+ results_all.append(res)
120
+ result = pd.concat(results_all)
121
+ score_file = eval_file.replace(f'.{suffix}', '_acc.csv')
122
+ dump(result, score_file)
123
+ return result
124
+
125
+
126
+ # Add new supported dataset class here
127
+ IMAGE_DATASET = [
128
+ ImageCaptionDataset, ImageYORNDataset, ImageMCQDataset, ImageVQADataset, MathVision,
129
+ MMMUDataset, OCRBench, MathVista, LLaVABench, MMVet, MTVQADataset, TableVQABench,
130
+ MMLongBench, VCRDataset, MMDUDataset, DUDE, SlideVQA, MUIRDataset,
131
+ GMAIMMBenchDataset, MMERealWorld, HRBenchDataset, CRPE, MathVerse, NaturalBenchDataset,
132
+ MIABench, OlympiadBench, WildVision, MMMath, QSpatial, Dynamath, MMGenBench, VizWiz, MMNIAH,
133
+ CMMMU
134
+ ]
135
+
136
+ VIDEO_DATASET = [
137
+ MMBenchVideo, VideoMME, MVBench, MVBench_MP4, LongVideoBench,
138
+ MLVU, MLVU_MCQ, MLVU_OpenEnded,
139
+ TempCompass, TempCompass_MCQ, TempCompass_Captioning, TempCompass_YorN
140
+ ]
141
+
142
+ TEXT_DATASET = [
143
+ TextMCQDataset
144
+ ]
145
+
146
+ CUSTOM_DATASET = [
147
+ CustomMCQDataset, CustomVQADataset, CustomTextMCQDataset
148
+ ]
149
+
150
+ DATASET_COLLECTION = [ConcatDataset, ConcatVideoDataset]
151
+
152
+ DATASET_CLASSES = IMAGE_DATASET + VIDEO_DATASET + TEXT_DATASET + CUSTOM_DATASET + DATASET_COLLECTION
153
+ SUPPORTED_DATASETS = []
154
+ for DATASET_CLS in DATASET_CLASSES:
155
+ SUPPORTED_DATASETS.extend(DATASET_CLS.supported_datasets())
156
+
157
+
158
+ def DATASET_TYPE(dataset, *, default: str = 'MCQ') -> str:
159
+ for cls in DATASET_CLASSES:
160
+ if dataset in cls.supported_datasets():
161
+ if hasattr(cls, 'TYPE'):
162
+ return cls.TYPE
163
+ # Have to add specific routine to handle ConcatDataset
164
+ if dataset in ConcatDataset.DATASET_SETS:
165
+ dataset_list = ConcatDataset.DATASET_SETS[dataset]
166
+ TYPES = [DATASET_TYPE(dname) for dname in dataset_list]
167
+ assert np.all([x == TYPES[0] for x in TYPES]), (dataset_list, TYPES)
168
+ return TYPES[0]
169
+
170
+ if 'openended' in dataset.lower():
171
+ return 'VQA'
172
+ warnings.warn(f'Dataset {dataset} is a custom one and not annotated as `openended`, will treat as {default}. ')
173
+ return default
174
+
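+ # Usage sketch (illustrative): DATASET_TYPE('MMBench_DEV_EN') resolves through the
+ # registered dataset classes and returns 'MCQ'; an unregistered name containing
+ # 'openended' falls back to 'VQA', and any other unknown name falls back to `default`.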
175
+
176
+ def DATASET_MODALITY(dataset, *, default: str = 'IMAGE') -> str:
177
+ if dataset is None:
178
+ warnings.warn(f'Dataset is not specified, will treat modality as {default}. ')
179
+ return default
180
+ for cls in DATASET_CLASSES:
181
+ if dataset in cls.supported_datasets():
182
+ if hasattr(cls, 'MODALITY'):
183
+ return cls.MODALITY
184
+ # Have to add specific routine to handle ConcatDataset
185
+ if dataset in ConcatDataset.DATASET_SETS:
186
+ dataset_list = ConcatDataset.DATASET_SETS[dataset]
187
+ MODALITIES = [DATASET_MODALITY(dname) for dname in dataset_list]
188
+ assert np.all([x == MODALITIES[0] for x in MODALITIES]), (dataset_list, MODALITIES)
189
+ return MODALITIES[0]
190
+
191
+ if 'video' in dataset.lower():
192
+ return 'VIDEO'
193
+ elif 'image' in dataset.lower():
194
+ return 'IMAGE'
195
+ warnings.warn(f'Dataset {dataset} is a custom one, will treat modality as {default}. ')
196
+ return default
197
+
198
+
199
+ def build_dataset(dataset_name, **kwargs):
200
+ for cls in DATASET_CLASSES:
201
+ if dataset_name in cls.supported_datasets():
202
+ return cls(dataset=dataset_name, **kwargs)
203
+
204
+ warnings.warn(f'Dataset {dataset_name} is not officially supported. ')
205
+
206
+ data_file = osp.join(LMUDataRoot(), f'{dataset_name}.tsv')
207
+ if not osp.exists(data_file):
208
+ warnings.warn(f'Data file {data_file} does not exist. Dataset building failed. ')
209
+ return None
210
+
211
+ data = load(data_file)
212
+ if 'question' not in [x.lower() for x in data.columns]:
213
+ warnings.warn(f'Data file {data_file} does not have a `question` column. Dataset building failed. ')
214
+ return None
215
+
216
+ if 'A' in data and 'B' in data:
217
+ if 'image' in data or 'image_path' in data:
218
+ warnings.warn(f'Will assume unsupported dataset {dataset_name} as a Custom MCQ dataset. ')
219
+ return CustomMCQDataset(dataset=dataset_name, **kwargs)
220
+ else:
221
+ warnings.warn(f'Will assume unsupported dataset {dataset_name} as a Custom Text MCQ dataset. ')
222
+ return CustomTextMCQDataset(dataset=dataset_name, **kwargs)
223
+ else:
224
+ warnings.warn(f'Will assume unsupported dataset {dataset_name} as a Custom VQA dataset. ')
225
+ return CustomVQADataset(dataset=dataset_name, **kwargs)
226
+
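+ # Usage sketch (illustrative): a registered name such as build_dataset('MMStar')
+ # instantiates the matching dataset class directly, while an unregistered name is
+ # loaded from `<name>.tsv` under LMUDataRoot() and routed to the Custom MCQ,
+ # Custom Text MCQ, or Custom VQA class depending on the columns found in that TSV.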
227
+
228
+ __all__ = [
229
+ 'build_dataset', 'img_root_map', 'build_judge', 'extract_answer_from_item', 'prefetch_answer', 'DEBUG_MESSAGE'
230
+ ] + [cls.__name__ for cls in DATASET_CLASSES]
VLMEvalKit/vlmeval/dataset/cmmmu.py ADDED
@@ -0,0 +1,354 @@
1
+ from .image_base import ImageBaseDataset
2
+ import random
3
+ from collections import Counter
4
+ import os
5
+ import re
6
+ import tempfile
7
+ from ..smp import *
8
+
9
+
10
+ def get_multi_choice_prediction(response, all_choices, index2ans):
11
+ for char in [',', '.', '!', '?', ';', ':', "'"]:
12
+ response = response.strip(char)
13
+ response = " " + response + " " # add space to avoid partial match
14
+
15
+ candidates = []
16
+
17
+ for choice in all_choices: # (A) (B) (C) (D)
18
+ # Add the choice to candidates each time it appears in the response
19
+ candidates.extend([choice for _ in range(response.count(f'({choice})'))])
20
+
21
+ if len(candidates) == 0:
22
+ for choice in all_choices: # A B C D
23
+ # Similarly, add the choice for each occurrence
24
+ candidates.extend([choice for _ in range(response.count(f'{choice}'))])
25
+
26
+ if len(candidates) == 0 and len(response.split()) >= 1:
27
+ for index, ans in index2ans.items():
28
+ # Add index for each occurrence of ans in response
29
+ candidates.extend([index for _ in range(response.count(ans))])
30
+
31
+ # if all the above fails, fall back to matching each option's full answer text inside the response
32
+ if len(candidates) == 0 and len(response.split()) >= 1:
33
+ for index, ans in index2ans.items():
34
+ if ans in response:
35
+ candidates.append(index)
36
+ # index_ans = False # it's content ans.
37
+
38
+ if len(candidates) == 0: # still not get answer, randomly choose one.
39
+ return random.choice(all_choices)
40
+ # return ''
41
+ else:
42
+ # Count the occurrence of each candidate
43
+ candidate_counts = Counter(candidates)
44
+
45
+ # Select the most frequent candidates
46
+ max_count = max(candidate_counts.values())
47
+ most_frequent_candidates = [c for c in all_choices if candidate_counts.get(c, 0) == max_count]
48
+
49
+ # Combine the most frequent candidates in ABCD order
50
+ return ''.join(most_frequent_candidates)
51
+
52
+
53
+ def extract_numbers(string):
54
+ # Pattern for numbers with Chinese commas
55
+ pattern_commas = r'-?\d{1,3}(?:,\d{3})+'
56
+ # Pattern for scientific notation
57
+ pattern_scientific = r'-?\d+(?:\.\d+)?[eE][+-]?\d+'
58
+ # Pattern for simple numbers without Chinese commas
59
+ pattern_simple = r'-?(?:\d+\.\d+|\.\d+|\d+)(?![eE][+-]?\d+)(?!,\d)'
60
+
61
+ # Extract numbers with Chinese commas
62
+ numbers_with_commas = re.findall(pattern_commas, string)
63
+ # Extract numbers in scientific notation
64
+ numbers_scientific = re.findall(pattern_scientific, string)
65
+ # Extract simple numbers without Chinese commas
66
+ numbers_simple = re.findall(pattern_simple, string)
67
+
68
+ # Combine all extracted numbers
69
+ all_numbers = numbers_with_commas + numbers_scientific + numbers_simple
70
+ return all_numbers
71
+
72
+
73
+ def check_is_number(string):
74
+ try:
75
+ float(string.replace(',', ''))
76
+ return True
77
+ except ValueError:
78
+ # check if there's comma inside
79
+ return False
80
+
81
+
82
+ def count_letters(string):
83
+ return sum(c.isalpha() and 'a' <= c <= 'z' or 'A' <= c <= 'Z' for c in string)
84
+
85
+
86
+ def normalize_str(string, answer):
87
+ # check whether the string is numeric or free text
88
+
89
+ # if it is a number, convert it to a float
90
+ if string is None:
91
+ return [string]
92
+ string = string.strip()
93
+
94
+ is_number = check_is_number(string)
95
+
96
+ if is_number:
97
+ string = string.replace(',', '')
98
+ string = float(string)
99
+ # leave 2 decimal
100
+ string = round(string, 2)
101
+ return [string]
102
+ else: # it's likely to be a string
103
+ if len(string) > len(answer) + 20 or count_letters(string) > count_letters(answer) + 2:
104
+ return []
105
+ return [string]
106
+
107
+
108
+ def get_fill_blank_prediction(response, answer):
109
+ """get the prediction from the generated response,
110
+ return a list of predicted strings or numbers"""
111
+
112
+ def get_key_subresponses(response):
113
+ response = response.strip("。").strip()
114
+ sub_responses = re.split(r'。|\n', response)
115
+ indicators_of_keys = ['是', '为', '所以', '等于', '方案', '选择',
116
+ '正确答案', '因此', '最后', '答案', '结果']
117
+ key_responses = []
118
+ for index, resp in enumerate(sub_responses):
119
+ # for the last sub-response, also accept '=' as an indicator (the whole response may be a single equation)
120
+ if index == len(sub_responses) - 1:
121
+ indicators_of_keys.extend(['='])
122
+ shortest_key_response = None
123
+ # the shortest response that may contain the answer (tail part of the response)
124
+ for indicator in indicators_of_keys:
125
+ if indicator in resp:
126
+ if not shortest_key_response:
127
+ shortest_key_response = resp.split(indicator)[-1].strip()
128
+ else:
129
+ if len(resp.split(indicator)[-1].strip()) < len(shortest_key_response):
130
+ shortest_key_response = resp.split(indicator)[-1].strip()
131
+
132
+ if shortest_key_response:
133
+ # and it's not trivial
134
+ if shortest_key_response.strip() not in [":", ",", ".", "!", "?", ";", ":", "'"]:
135
+ key_responses.append(shortest_key_response)
136
+ if len(key_responses) == 0: # did not find any
137
+ return [response]
138
+ return key_responses
139
+
140
+ key_responses = get_key_subresponses(response)
141
+
142
+ pred_list = key_responses.copy() # keep the original string response
143
+ for resp in key_responses:
144
+ pred_list.extend(extract_numbers(resp))
145
+
146
+ tmp_pred_list = []
147
+ for i in range(len(pred_list)):
148
+ tmp_pred_list.extend(normalize_str(pred_list[i], answer))
149
+ pred_list = tmp_pred_list
150
+
151
+ # remove duplicates
152
+ pred_list = list(set(pred_list))
153
+
154
+ return pred_list
155
+
156
+
157
+ def get_TF_prediction(response):
158
+ """get the prediction from the generated response,
159
+ return a list of predicted strings or numbers"""
160
+
161
+ def get_key_subresponses(response):
162
+ response = response.strip("。").strip()
163
+ sub_responses = re.split(r'。|\n', response)
164
+ indicators_of_keys = ['是', '为', '所以', '判断',
165
+ '陈述', '说法', '表达', '答案', '结果']
166
+ key_responses = []
167
+ for index, resp in enumerate(sub_responses):
168
+ shortest_key_response = None
169
+ # the shortest response that may contain the answer (tail part of the response)
170
+ for indicator in indicators_of_keys:
171
+ if indicator in resp:
172
+ if not shortest_key_response:
173
+ shortest_key_response = resp.split(indicator)[-1].strip()
174
+ else:
175
+ if len(resp.split(indicator)[-1].strip()) < len(shortest_key_response):
176
+ shortest_key_response = resp.split(indicator)[-1].strip()
177
+
178
+ if shortest_key_response:
179
+ # and it's not trivial
180
+ if shortest_key_response.strip() not in [":", ",", ".", "!", "?", ";", ":", "'"]:
181
+ key_responses.append(shortest_key_response)
182
+ if len(key_responses) == 0: # did not find any
183
+ return [response]
184
+ return key_responses
185
+
186
+ key_responses = get_key_subresponses(response)
187
+
188
+ pred_list = key_responses.copy() # keep the original string response
189
+ # remove duplicates
190
+ pred_list = list(set(pred_list))
191
+
192
+ return pred_list
193
+
194
+
195
+ class CMMMU(ImageBaseDataset):
196
+ TYPE = 'VQA'
197
+
198
+ DATASET_URL = {
199
+ 'CMMMU_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/CMMMU_VAL.tsv'
200
+ }
201
+
202
+ DATASET_MD5 = {
203
+ 'CMMMU_VAL': 'b4727e2fce2415bf646379e60c11a726'
204
+ }
205
+
206
+ def dump_image(self, line):
207
+ os.makedirs(self.img_root, exist_ok=True)
208
+
209
+ tgt_path_z = []
210
+ if isinstance(line['image'], list):
211
+ for i in range(len(line['image'])):
212
+ tgt_path = osp.join(self.img_root, f"{line['index']}--{i + 1}.jpg")
213
+ if not read_ok(tgt_path):
214
+ decode_base64_to_image_file(line['image'][i], tgt_path)
215
+ tgt_path_z.append(tgt_path)
216
+ else:
217
+ tgt_path = osp.join(self.img_root, f"{line['index']}.jpg")
218
+ if not read_ok(tgt_path):
219
+ decode_base64_to_image_file(line['image'], tgt_path)
220
+ tgt_path_z.append(tgt_path)
221
+ return tgt_path_z
222
+
223
+ @classmethod
224
+ def evaluate(self, eval_file, **judge_kwargs):
225
+
226
+ suffix = eval_file.split('.')[-1]
227
+ result_file = eval_file.replace(f'.{suffix}', '_acc.csv')
228
+
229
+ if not osp.exists(result_file):
230
+ data = load(eval_file)
231
+ assert 'answer' in data and 'prediction' in data
232
+ data['prediction'] = [str(x) for x in data['prediction']]
233
+ data['answer'] = [str(x) for x in data['answer']]
234
+
235
+ correct_count = 0
236
+ correct_category = {
237
+ '技术与工程': [0, 0],
238
+ '科学': [0, 0],
239
+ '健康与医学': [0, 0],
240
+ '商业': [0, 0],
241
+ '艺术与设计': [0, 0],
242
+ '人文社会科学': [0, 0],
243
+ }
244
+
245
+ for i in tqdm(data.iterrows()):
246
+ line = i[1]
247
+ correct_category[line['category']][0] += 1
248
+
249
+ # Options
250
+ if line['type'] == '选择':
251
+ index2ans = {
252
+ 'A': line['option1'],
253
+ 'B': line['option2'],
254
+ 'C': line['option3'],
255
+ 'D': line['option4']
256
+ }
257
+ fact_option = get_multi_choice_prediction(line['prediction'], ['A', 'B', 'C', 'D'], index2ans)
258
+ if fact_option == line['answer']:
259
+ correct_count += 1
260
+ correct_category[line['category']][1] += 1
261
+
262
+ # Binary
263
+ elif line['type'] == '判断':
264
+ positive_keywords = ['正确', '对', '准确', '肯定', '对的']
265
+ negative_keywords = ['不对', '错误', '不正确', '不准确', '不合适', '否定', '错的', '错']
266
+ ambiguous_keywords = ['对错', '是否正确', '否正确', '或者', '是否', '正确性', '对不']
267
+
268
+ def judge_similarity(pred_list, positive_keywords, negative_keywords):
269
+ positive_count = 0
270
+ negative_count = 0
271
+
272
+ for pred in pred_list:
273
+ if any(pos_word in pred for pos_word in positive_keywords):
274
+ positive_count += 1
275
+ elif any(neg_word in pred for neg_word in negative_keywords):
276
+ negative_count += 1
277
+
278
+ if positive_count > negative_count:
279
+ return "对"
280
+ elif negative_count > positive_count:
281
+ return "错"
282
+ else:
283
+ return random.choice(['对', '错'])
284
+
285
+ answer = get_TF_prediction(line['prediction'])
286
+ answer = [word for word in answer if not any(ambiguous in word for ambiguous in ambiguous_keywords)]
287
+ fact_answer = judge_similarity(answer, positive_keywords, negative_keywords)
288
+ if fact_answer == line['answer']:
289
+ correct_count += 1
290
+ correct_category[line['category']][1] += 1
291
+
292
+ # Fill the Blank
293
+ else:
294
+ norm_answers = normalize_str(line['answer'], line['answer'])
295
+ predicted_answer = get_fill_blank_prediction(line['prediction'], line['answer'])
296
+
297
+ for pred in predicted_answer:
298
+ # already normalized
299
+ if isinstance(pred, str): # if it's a string, then find if ans in the pred_i
300
+ for norm_ans in norm_answers:
301
+ # only see if the string answer in the string pred
302
+ # print(norm_ans, pred)
303
+ if isinstance(norm_ans, str) and norm_ans in pred:
304
+ correct_count += 1
305
+ correct_category[line['category']][1] += 1
306
+ else: # it's a number
307
+ if pred in norm_answers:
308
+ correct_count += 1
309
+ correct_category[line['category']][1] += 1
310
+
311
+ accuracyz = {}
312
+ accuracyz['总准确率'] = correct_count / len(data)
313
+ for i in correct_category.keys():
314
+ accuracyz[i] = correct_category[i][1] / correct_category[i][0]
315
+
316
+ accuracyz = d2df(accuracyz)
317
+ accuracyz.round(10)
318
+ dump(accuracyz, result_file)
319
+
320
+ result = pd.read_csv(result_file)
321
+ return result
322
+
323
+ def build_prompt(self, line):
324
+ if line['type'] == '选择':
325
+ tgt_path = self.dump_image(line)
326
+ question = line['question']
327
+ options_prompt = 'Options:\n'
328
+
329
+ for i in [['A', '1'], ['B', '2'], ['C', '3'], ['D', '4']]:
330
+ options_prompt += i[0] + '. ' + line['option' + i[1]] + '\n'
331
+
332
+ prompt = (f'问题: {question}\n' + options_prompt
333
+ + '请回答上述多项选择题,并选出正确选项。这些题目可能包括单选和多选题型。如果所提供的信息不足以确定一个明确的答案,那么请根据可用的数据和你的判断来选择最可能正确的选项。')
334
+
335
+ msgs = []
336
+ if isinstance(tgt_path, list):
337
+ msgs.extend([dict(type='image', value=p) for p in tgt_path])
338
+ else:
339
+ msgs = [dict(type='image', value=tgt_path)]
340
+ msgs.append(dict(type='text', value=prompt))
341
+
342
+ return msgs
343
+
344
+ elif line['type'] == '判断':
345
+ msgs = super().build_prompt(line)
346
+ assert msgs[-1]['type'] == 'text'
347
+ msgs[-1]['value'] += '\n请回答上述判断题,并根据题目描述和所给的信息来判断问题中陈述的对错。如果信息不完整或不足以作出绝对判断,请运用你的逻辑推理和现有信息来做出最可能的判断。'
348
+ return msgs
349
+
350
+ else:
351
+ msgs = super().build_prompt(line)
352
+ assert msgs[-1]['type'] == 'text'
353
+ msgs[-1]['value'] += '\n请回答上述填空题,并根据题目的要求和所提供的信息来给出最恰当的答案。如果信息不足以确切回答,那么请依据现有的数据和你的推理能力来填写最合理的答案。'
354
+ return msgs
VLMEvalKit/vlmeval/dataset/dude.py ADDED
@@ -0,0 +1,211 @@
1
+ import math
2
+ from typing import List
3
+
4
+ from .utils.judge_util import build_judge
5
+ from .image_base import ImageBaseDataset
6
+ from .mmlongbench import concat_images, MMLongBench_auxeval, anls_compute
7
+ from ..smp import *
8
+
9
+
10
+ FAIL_MSG = 'Failed to obtain answer via API.'
11
+
12
+
13
+ def DUDE_acc(result_file):
14
+ data = load(result_file)
15
+ overall_score = 0.0
16
+ score_list = list()
17
+ for i in range(len(data)):
18
+ item = data.iloc[i]
19
+ if isinstance(item['answer'], float) and math.isnan(item['answer']):
20
+ item['answer'] = 'Not answerable'
21
+
22
+ item['answer'] = item['answer'].lower()
23
+ item['pred'] = item['pred'].lower()
24
+ score = anls_compute(item['answer'], item['pred'])
25
+ score_list.append(score)
26
+ overall_score += score
27
+
28
+ data['score'] = score_list
29
+ dump(data, result_file)
30
+
31
+ res = dict()
32
+ res['category'], res['num'], res['avg_score'] = ['anls'], [len(data)], [overall_score / len(data)]
33
+ res = pd.DataFrame(res)
34
+ return res
35
+
36
+
37
+ class DUDE(ImageBaseDataset):
38
+
39
+ TYPE = 'VQA'
40
+
41
+ DATASET_URL = {
42
+ 'DUDE': 'https://opencompass.openxlab.space/utils/VLMEval/DUDE.tsv',
43
+ 'DUDE_MINI': 'https://opencompass.openxlab.space/utils/VLMEval/DUDE_MINI.tsv',
44
+ }
45
+ DATASET_MD5 = {
46
+ 'DUDE': '130d860d08206e1e407cd77150c10d88',
47
+ 'DUDE_MINI': 'e0c0d998114f0cca7516d12039d2b538',
48
+ }
49
+
50
+ SUPPORTED_MODELS = {
51
+ 'GPT4': (1, 1),
52
+ 'GPT4V': (1, 1),
53
+ 'GPT4V_HIGH': (1, 1),
54
+ 'GPT4o': (1, 1),
55
+ 'GPT4o_HIGH': (1, 1),
56
+ 'GPT4o_MINI': (1, 1),
57
+ 'XComposer2d5': (1, -1),
58
+ 'XComposer2_4KHD': (1, -1),
59
+ 'MiniCPM-Llama3-V-2_5': (1, 5),
60
+ 'InternVL-Chat-V1-5': (5, 2),
61
+ }
62
+
63
+ def __init__(self, dataset, **kwargs):
64
+ self.model_list = list(self.SUPPORTED_MODELS.keys())
65
+ model_name = kwargs['model']
66
+ if not listinstr(self.model_list, model_name):
67
+ raise AssertionError("{} doesn't support the evaluation on DUDE.".format(model_name))
68
+ super(DUDE, self).__init__(dataset)
69
+
70
+ self.is_api = True if listinstr(['GPT4'], model_name) else False
71
+ self.max_pages = 120
72
+ concat_num, column_num = self.SUPPORTED_MODELS.get(model_name)
73
+ self.concat_num = concat_num
74
+ self.column_num = column_num
75
+
76
+ def prepare_tsv(self, url, file_md5=None):
77
+ data_root = LMUDataRoot()
78
+ os.makedirs(data_root, exist_ok=True)
79
+ file_name = url.split('/')[-1]
80
+ data_path = osp.join(data_root, file_name)
81
+ if osp.exists(data_path) and (file_md5 is None or md5(data_path) == file_md5):
82
+ pass
83
+ else:
84
+ warnings.warn('The dataset tsv is not downloaded')
85
+ download_file(url, data_path)
86
+ return load(data_path)
87
+
88
+ def dump_image(self, origin_line):
89
+ os.makedirs(self.img_root, exist_ok=True)
90
+ try:
91
+ import fitz
92
+ except Exception as e:
93
+ logging.critical(f'{type(e)}: {e}')
94
+ logging.critical('Please use `pip install pymupdf` to parse PDF files.')
95
+
96
+ line = origin_line.copy()
97
+ if not isinstance(line['image_path'], List):
98
+ line['image_path'] = [line['image_path']]
99
+ line['image_path'] = line['image_path'][:self.max_pages]
100
+ skip_pdf_parse = True
101
+ for im_name in line['image_path']:
102
+ path = osp.join(self.img_root, im_name)
103
+ if not read_ok(path):
104
+ skip_pdf_parse = False
105
+ break
106
+
107
+ # Just for being compatible with the zooped loop: zip(line['image'], line['image_path'])
108
+ if skip_pdf_parse:
109
+ line['image'] = line['image_path']
110
+ else:
111
+ pdf_data = base64.b64decode(line['image'])
112
+ pdf_file = io.BytesIO(pdf_data)
113
+ encoded_images = []
114
+ with fitz.open(stream=pdf_file, filetype='pdf') as doc:
115
+ doc = doc[:self.max_pages]
116
+ for page in doc:
117
+ image = page.get_pixmap(dpi=144)
118
+ image_file = io.BytesIO(image.tobytes(output='png'))
119
+ image = Image.open(image_file)
120
+ encoded_image = encode_image_to_base64(image)
121
+ encoded_images.append(encoded_image)
122
+ line['image'] = encoded_images
123
+ print('process {}'.format(line['doc_id']))
124
+
125
+ if 'image' in line:
126
+ if isinstance(line['image'], list):
127
+ tgt_path = []
128
+ assert 'image_path' in line
129
+ for img, im_name in zip(line['image'], line['image_path']):
130
+ path = osp.join(self.img_root, im_name)
131
+ if not read_ok(path):
132
+ decode_base64_to_image_file(img, path)
133
+ tgt_path.append(path)
134
+ else:
135
+ tgt_path = osp.join(self.img_root, f"{line['index']}.jpg")
136
+ if not read_ok(tgt_path):
137
+ decode_base64_to_image_file(line['image'], tgt_path)
138
+ tgt_path = [tgt_path]
139
+ else:
140
+ assert 'image_path' in line
141
+ tgt_path = toliststr(line['image_path'])
142
+
143
+ if self.concat_num > 0 and not self.is_api:
144
+ concatenated_images = concat_images(tgt_path, max_concat=self.concat_num, column_num=self.column_num)
145
+
146
+ old_tgt_path = tgt_path
147
+ assert isinstance(old_tgt_path, list)
148
+ if self.column_num != -1:
149
+ tgt_path = [
150
+ '_'.join(old_tgt_path[0].split('_')[:-1]) + '_concat{}_{}.jpg'.format(self.concat_num, i)
151
+ for i in range(len(concatenated_images))
152
+ ]
153
+ else:
154
+ tgt_path = ['_'.join(old_tgt_path[0].split('_')[:-1]) + '_concat_all.jpg']
155
+
156
+ for path, concatenated_image in zip(tgt_path, concatenated_images):
157
+ if not read_ok(path):
158
+ decode_base64_to_image_file(encode_image_to_base64(concatenated_image), path)
159
+ num_images, image_size = len(old_tgt_path), concatenated_image.size
160
+ print('concat {} images to a new one with size {}. save at {}'.format(num_images, image_size, path))
161
+ return tgt_path
162
+
163
+ @classmethod
164
+ def evaluate(self, eval_file, **judge_kwargs):
165
+ logger = get_logger('Evaluation')
166
+ model = judge_kwargs['model']
167
+
168
+ suffix = eval_file.split('.')[-1]
169
+ storage = eval_file.replace(f'.{suffix}', f'_{model}.xlsx')
170
+ tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
171
+
172
+ if osp.exists(storage):
173
+ logger.warning(f'GPT scoring file {storage} already exists, will reuse it in DUDE_eval. ')
174
+ else:
175
+ data = load(eval_file)
176
+ model = build_judge(max_tokens=128, **judge_kwargs)
177
+ lt = len(data)
178
+ lines = [data.iloc[i] for i in range(lt)]
179
+ tups = [(model, line) for line in lines]
180
+ indices = [line['index'] for line in lines]
181
+
182
+ ans = {}
183
+ if osp.exists(tmp_file):
184
+ ans = load(tmp_file)
185
+ tups = [x for x, i in zip(tups, indices) if i not in ans]
186
+ indices = [i for i in indices if i not in ans]
187
+
188
+ if len(indices):
189
+ new_results = list()
190
+ for model, line in tqdm(tups):
191
+ res = MMLongBench_auxeval(model, line)
192
+ new_results.append(res)
193
+
194
+ log_map, res_map, pred_map = {}, {}, {}
195
+ all_inds = [line['index'] for line in lines]
196
+ for k, v in zip(all_inds, new_results):
197
+ log_map[k] = v['log']
198
+ res_map[k] = v['res']
199
+ pred_map[k] = v['pred']
200
+ data['res'] = [res_map[idx] for idx in data['index']]
201
+ data['log'] = [log_map[idx] for idx in data['index']]
202
+ data['pred'] = [pred_map[idx] for idx in data['index']]
203
+ dump(data, storage)
204
+
205
+ score = DUDE_acc(storage)
206
+ score_pth = storage.replace('.xlsx', '_score.csv')
207
+
208
+ dump(score, score_pth)
209
+ logger.info(f'DUDE successfully finished evaluating {eval_file}, results saved in {score_pth}')
210
+ logger.info('Score: ')
211
+ logger.info(score)
VLMEvalKit/vlmeval/dataset/dynamath.py ADDED
@@ -0,0 +1,240 @@
1
+ import re
2
+ import json
3
+ import sympy as sp
4
+ import numpy as np
5
+ import pandas as pd
6
+ from sympy import simplify, Eq, sympify, Pow, pi
7
+ from sympy.parsing.latex import parse_latex
8
+ import sys
9
+ import math
10
+ import os
11
+ import os.path as osp
12
+ import argparse
13
+
14
+ from .image_base import ImageBaseDataset
15
+ from .utils import build_judge
16
+ from ..utils import track_progress_rich
17
+ from ..smp import load, dump, d2df, toliststr
18
+
19
+
20
+ def preprocess(str1):
21
+ if 0 <= str1.find("{") < str1.rfind("}"):
22
+ str1 = str1[str1.find("{"): str1.rfind("}") + 1]
23
+ str2 = str1.replace("\\", "")
24
+ str2 = str2.replace("\\n", "\n")
25
+ return str2
26
+
27
+
28
+ def transfer(str1):
29
+ if "\u03c0" in str1:
30
+ strs = str1.split('\u03c0')
31
+ str1 = strs[0]
32
+ return float(str1) * np.pi
33
+ else:
34
+ return float(str1)
35
+
36
+
37
+ def parse_answer(answer, answer_type="multiple choice"):
38
+ if answer_type == "float":
39
+ if answer.isdigit():
40
+ return True, float(answer)
41
+ else:
42
+ parts = answer.split(' ')
43
+ answer = parts[0]
44
+ try:
45
+ answer = transfer(answer)
46
+ return True, answer
47
+ except:
48
+ return False, None
49
+ elif answer_type == "multiple choice":
50
+ if len(answer) == 1:
51
+ return True, answer.upper()
52
+ else:
53
+ in_flag = [ch in answer.upper() for ch in 'ABCDE']
54
+ if sum(in_flag) == 1:
55
+ for ch in 'ABCDE':
56
+ if ch in answer.upper():
57
+ return True, ch
58
+ return False, None
59
+ else:
60
+ return True, answer
61
+
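+ # Illustrative examples: parse_answer('(B)', 'multiple choice') -> (True, 'B');
+ # parse_answer('2.5 cm', 'float') -> (True, 2.5); parse_answer('0.5π', 'float')
+ # -> (True, ~1.571); an answer naming several option letters, e.g. 'A or B',
+ # cannot be resolved and returns (False, None).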
62
+
63
+ def DynaMath_auxeval(model, line):
64
+ pred = line['prediction']
65
+ pred = preprocess(pred)
66
+
67
+ succeed, short_answer = None, None
68
+ try:
69
+ dj = json.loads(pred, strict=False)
70
+ short_answer = dj.get("short answer")
71
+ assert short_answer is not None
72
+ succeed, short_answer = parse_answer(short_answer, answer_type=line['answer_type'])
73
+ assert succeed
74
+ except:
75
+ # Failed to parse the JSON, use an auxiliary LLM to get the short answer
76
+ if line['answer_type'] == 'multiple choice':
77
+ inst = "Output the corresponing choice option, such as 'A', 'B', 'C', 'D', in a single line."
78
+ elif line['answer_type'] == 'float':
79
+ inst = "Output a three-digit floating-point number in a single line."
80
+ else:
81
+ inst = (
82
+ "Output a short answer in a single line. Any float numbers in the answer "
83
+ "should be formatted as three-digit floating-point numbers."
84
+ )
85
+
86
+ prompt = f"Free-form answer: {pred}\nInstruction: {inst}"
87
+ response = pred
88
+ succeed, short_answer = parse_answer(response, line['answer_type'])
89
+ if not succeed:
90
+ response = model.generate(prompt)
91
+ succeed, short_answer = parse_answer(response, line['answer_type'])
92
+
93
+ if line['answer_type'] == 'float':
94
+ if succeed:
95
+ diff = float(short_answer) - float(line['answer'])
96
+ if abs(diff) <= 0.001:
97
+ return dict(parse=True, extracted=short_answer, correct=True)
98
+ else:
99
+ return dict(parse=True, extracted=short_answer, correct=False)
100
+ else:
101
+ return dict(parse=False, extracted=None, correct=False)
102
+ elif line['answer_type'] == 'multiple choice':
103
+ if succeed:
104
+ return dict(parse=True, extracted=short_answer, correct=(short_answer == line['answer']))
105
+ else:
106
+ if line['answer'] in pred[:3].upper():
107
+ return dict(parse=False, extracted=None, correct=True)
108
+ else:
109
+ return dict(parse=False, extracted=None, correct=False)
110
+ else:
111
+ if succeed:
112
+ return dict(parse=True, extracted=short_answer, correct=(short_answer.lower() in line['answer'].lower()))
113
+ else:
114
+ return dict(parse=False, extracted=None, correct=(short_answer.lower() in line['answer'].lower()))
115
+
116
+
117
+ class Dynamath(ImageBaseDataset):
118
+
119
+ TYPE = 'VQA'
120
+ DATASET_URL = {'DynaMath': 'https://opencompass.openxlab.space/utils/VLMEval/DynaMath.tsv'}
121
+ DATASET_MD5 = {'DynaMath': 'b8425ad9a7114571fc9366e013699494'}
122
+ GUIDE = """
123
+ ## Answer Instruction Please provide an answer to the question outlined above. Your response should adhere \
124
+ to the following JSON format, which includes two keys: 'solution' and 'short answer'. The 'solution' key can contain \
125
+ detailed steps needed to solve the question, and the 'short answer' key should provide a concise response. {INST}
126
+
127
+ Example of expected JSON response format:
128
+
129
+ """
130
+ EXAMPLE = {
131
+ "solution": "[Detailed step-by-step explanation]",
132
+ "short answer": "[Concise Answer]"
133
+ }
134
+ TEXT_EXAMPLE = json.dumps(EXAMPLE, indent=4)
135
+
136
+ # Given one data record, return the built prompt (a multi-modal message), can override
137
+ def build_prompt(self, line):
138
+ if isinstance(line, int):
139
+ line = self.data.iloc[line]
140
+
141
+ if self.meta_only:
142
+ tgt_path = toliststr(line['image_path'])
143
+ else:
144
+ tgt_path = self.dump_image(line)
145
+
146
+ prompt = f"## Question\n {line['question']}"
147
+ if line['answer_type'] == 'multiple choice':
148
+ inst = "Provide the corresponing choice option in the 'short answer' key, such as 'A', 'B', 'C', or 'D'."
149
+ elif line['answer_type'] == 'float':
150
+ inst = "Format the answer as a three-digit floating-point number and provide it in the 'short answer' key."
151
+ else:
152
+ inst = "Float numbers in the answer should be formatted as three-digit floating-point numbers."
153
+
154
+ prompt = prompt + self.GUIDE.format(INST=inst) + self.TEXT_EXAMPLE
155
+
156
+ msgs = []
157
+ if isinstance(tgt_path, list):
158
+ msgs.extend([dict(type='image', value=p) for p in tgt_path])
159
+ else:
160
+ msgs = [dict(type='image', value=tgt_path)]
161
+ msgs.append(dict(type='text', value=prompt))
162
+ return msgs
163
+
164
+ def evaluate(self, eval_file, **judge_kwargs):
165
+ judge_name = judge_kwargs.pop('model', 'gpt-4o-mini')
166
+
167
+ model = build_judge(model=judge_name, **judge_kwargs)
168
+ suffix = eval_file.split('.')[-1]
169
+
170
+ storage = eval_file.replace(f'.{suffix}', f'_{judge_name}.xlsx') # noqa: F841
171
+ score_file = eval_file.replace(f'.{suffix}', f'_{judge_name}_score.csv') # noqa: F841
172
+ tmp_file = eval_file.replace(f'.{suffix}', f'_{judge_name}.pkl') # noqa: F841
173
+ nproc = judge_kwargs.pop('nproc', 6) # noqa: F841
174
+
175
+ res = load(tmp_file) if os.path.exists(tmp_file) else {}
176
+ res = {k: v for k, v in res.items() if v is not None}
177
+
178
+ model.system_prompt = """\
179
+ You are a helpful assistant that helps me to format free-form answers into a short answer according to the instruction.
180
+ """
181
+ if not osp.exists(storage):
182
+ data = load(eval_file)
183
+ lt = len(data)
184
+ payloads = [dict(model=model, line=data.iloc[i]) for i in range(lt) if data.iloc[i]['index'] not in res]
185
+ keys = [idx for idx in data['index'] if idx not in res]
186
+
187
+ if len(keys):
188
+ results = track_progress_rich(DynaMath_auxeval, payloads, nproc=nproc, save=tmp_file, keys=keys)
189
+ for k, r in zip(keys, results):
190
+ res[k] = r
191
+
192
+ data['parse'] = [res[idx]['parse'] for idx in data['index']]
193
+ data['extracted'] = [res[idx]['extracted'] for idx in data['index']]
194
+ data['correct'] = [res[idx]['correct'] for idx in data['index']]
195
+ dump(data, storage)
196
+
197
+ data = load(storage)
198
+ # Calculate Average Accuracy
199
+ score_avg = {}
200
+ score_avg['Overall'] = np.mean(data['correct'])
201
+
202
+ subs = set(data['subject'])
203
+ for sub in subs:
204
+ data_sub = data[data['subject'] == sub]
205
+ score_avg[f'Subject-{sub}'] = np.mean(data_sub['correct'])
206
+
207
+ lvls = set(data['knowledge_level'])
208
+ for lvl in lvls:
209
+ data_lvl = data[data['knowledge_level'] == lvl]
210
+ score_avg[f'Level-{lvl}'] = np.mean(data_lvl['correct'])
211
+
212
+ # Calculate the Worst Case Accuracy
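+ # (Each question has several generated variants; under this metric a question,
+ # grouped by qid, only counts as correct when every one of its variants is correct.)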
213
+ score_worst = {}
214
+ data_worst = data[data['varid'] == 1]
215
+ qid2corr = {idx: True for idx in data_worst['index']}
216
+ lt = len(data)
217
+ for i in range(lt):
218
+ item = data.iloc[i]
219
+ qid2corr[item['qid']] *= item['correct']
220
+ data_worst['correct'] = [qid2corr[idx] for idx in data_worst['qid']]
221
+ score_worst['Overall'] = np.mean(data_worst['correct'])
222
+
223
+ subs = set(data_worst['subject'])
224
+ for sub in subs:
225
+ data_sub = data_worst[data_worst['subject'] == sub]
226
+ score_worst[f'Subject-{sub}'] = np.mean(data_sub['correct'])
227
+
228
+ lvls = set(data_worst['knowledge_level'])
229
+ for lvl in lvls:
230
+ data_lvl = data_worst[data_worst['knowledge_level'] == lvl]
231
+ score_worst[f'Level-{lvl}'] = np.mean(data_lvl['correct'])
232
+
233
+ d1 = {'Setting': 'Average'}
234
+ d1.update(score_avg)
235
+ d2 = {'Setting': 'Worst Case'}
236
+ d2.update(score_worst)
237
+ score = pd.concat([d2df(d1), d2df(d2)], ignore_index=True)
238
+
239
+ dump(score, score_file)
240
+ return score
VLMEvalKit/vlmeval/dataset/image_base.py ADDED
@@ -0,0 +1,172 @@
1
+ import pandas as pd
2
+ from abc import abstractmethod
3
+ from ..smp import *
4
+
5
+
6
+ def img_root_map(dataset):
7
+ if 'MM_NIAH' in dataset:
8
+ return 'MMNIAH'
9
+ if 'CRPE' in dataset:
10
+ return 'CRPE'
11
+ if 'OCRVQA' in dataset:
12
+ return 'OCRVQA'
13
+ if 'COCO_VAL' == dataset:
14
+ return 'COCO'
15
+ if 'MMMU' in dataset:
16
+ return 'MMMU'
17
+ if "QSpatial" in dataset:
18
+ return "QSpatial"
19
+
20
+ mmbench_root_map = {
21
+ 'MMBench_DEV_EN': 'MMBench', 'MMBench_TEST_EN': 'MMBench',
22
+ 'MMBench_DEV_CN': 'MMBench', 'MMBench_TEST_CN': 'MMBench',
23
+ 'MMBench': 'MMBench', 'MMBench_CN': 'MMBench',
24
+ 'MMBench_DEV_EN_V11': 'MMBench_V11', 'MMBench_TEST_EN_V11': 'MMBench_V11',
25
+ 'MMBench_DEV_CN_V11': 'MMBench_V11', 'MMBench_TEST_CN_V11': 'MMBench_V11',
26
+ 'MMBench_V11': 'MMBench', 'MMBench_CN_V11': 'MMBench',
27
+ }
28
+ if dataset in mmbench_root_map:
29
+ return mmbench_root_map[dataset]
30
+ return dataset
31
+
32
+
33
+ class ImageBaseDataset:
34
+
35
+ MODALITY = 'IMAGE'
36
+ DATASET_URL = {}
37
+ DATASET_MD5 = {}
38
+
39
+ def __init__(self, dataset='MMBench', skip_noimg=True):
40
+ ROOT = LMUDataRoot()
41
+ # You can override this variable to save image files to a different directory
42
+ self.dataset_name = dataset
43
+ self.img_root = osp.join(ROOT, 'images', img_root_map(dataset))
44
+
45
+ data = self.load_data(dataset)
46
+ self.skip_noimg = skip_noimg
47
+ if skip_noimg and 'image' in data:
48
+ data = data[~pd.isna(data['image'])]
49
+
50
+ data['index'] = [str(x) for x in data['index']]
51
+
52
+ self.meta_only = True
53
+
54
+ # The image field can store the base64 encoded image or another question index (for saving space)
55
+ if 'image' in data:
56
+ data['image'] = [str(x) for x in data['image']]
57
+ image_map = {x: y for x, y in zip(data['index'], data['image'])}
58
+ for k in image_map:
59
+ if len(image_map[k]) <= 64:
60
+ idx = image_map[k]
61
+ assert idx in image_map and len(image_map[idx]) > 64
62
+ image_map[k] = image_map[idx]
63
+
64
+ images = [toliststr(image_map[k]) for k in data['index']]
65
+ data['image'] = [x[0] if len(x) == 1 else x for x in images]
66
+ self.meta_only = False
67
+
68
+ if 'image_path' in data:
69
+ paths = [toliststr(x) for x in data['image_path']]
70
+ data['image_path'] = [x[0] if len(x) == 1 else x for x in paths]
71
+
72
+ if np.all([istype(x, int) for x in data['index']]):
73
+ data['index'] = [int(x) for x in data['index']]
74
+
75
+ self.data = data
76
+ self.post_build(dataset)
77
+
78
+ def __len__(self):
79
+ return len(self.data)
80
+
81
+ def __getitem__(self, idx):
82
+ return dict(self.data.iloc[idx])
83
+
84
+ def prepare_tsv(self, url, file_md5=None):
85
+ data_root = LMUDataRoot()
86
+ os.makedirs(data_root, exist_ok=True)
87
+ update_flag = False
88
+ file_name = url.split('/')[-1]
89
+ data_path = osp.join(data_root, file_name)
90
+ if osp.exists(data_path) and (file_md5 is None or md5(data_path) == file_md5):
91
+ pass
92
+ else:
93
+ warnings.warn('The dataset tsv is not downloaded')
94
+ download_file(url, data_path)
95
+ update_flag = True
96
+
97
+ if file_size(data_path, 'GB') > 1:
98
+ local_path = data_path.replace('.tsv', '_local.tsv')
99
+ if not osp.exists(local_path) or os.environ.get('FORCE_LOCAL', None) or update_flag:
100
+ from ..tools import LOCALIZE
101
+ LOCALIZE(data_path, local_path)
102
+ data_path = local_path
103
+ return load(data_path)
104
+
105
+ def dump_image(self, line):
106
+ os.makedirs(self.img_root, exist_ok=True)
107
+
108
+ if 'image' in line:
109
+ if isinstance(line['image'], list):
110
+ tgt_path = []
111
+ assert 'image_path' in line
112
+ for img, im_name in zip(line['image'], line['image_path']):
113
+ path = osp.join(self.img_root, im_name)
114
+ if not read_ok(path):
115
+ decode_base64_to_image_file(img, path)
116
+ tgt_path.append(path)
117
+ else:
118
+ tgt_path = osp.join(self.img_root, f"{line['index']}.jpg")
119
+ if not read_ok(tgt_path):
120
+ decode_base64_to_image_file(line['image'], tgt_path)
121
+ tgt_path = [tgt_path]
122
+ else:
123
+ assert 'image_path' in line
124
+ tgt_path = toliststr(line['image_path'])
125
+
126
+ return tgt_path
127
+
128
+ def display(self, line):
129
+ if isinstance(line, int):
130
+ line = self.data.iloc[line]
131
+ assert isinstance(line, pd.Series) or isinstance(line, dict)
132
+ mmqa_display(line)
133
+
134
+ # Return a list of dataset names that are supported by this class, can override
135
+ @classmethod
136
+ def supported_datasets(cls):
137
+ return list(cls.DATASET_URL)
138
+
139
+ # Given the dataset name, return the dataset as a pandas dataframe, can override
140
+ def load_data(self, dataset):
141
+ url = self.DATASET_URL[dataset]
142
+ file_md5 = self.DATASET_MD5[dataset] if dataset in self.DATASET_MD5 else None
143
+ return self.prepare_tsv(url, file_md5)
144
+
145
+ # Post built hook, will be called after the dataset is built, can override
146
+ def post_build(self, dataset):
147
+ pass
148
+
149
+ # Given one data record, return the built prompt (a multi-modal message), can override
150
+ def build_prompt(self, line):
151
+ if isinstance(line, int):
152
+ line = self.data.iloc[line]
153
+
154
+ if self.meta_only:
155
+ tgt_path = toliststr(line['image_path'])
156
+ else:
157
+ tgt_path = self.dump_image(line)
158
+
159
+ question = line['question']
160
+
161
+ msgs = []
162
+ if isinstance(tgt_path, list):
163
+ msgs.extend([dict(type='image', value=p) for p in tgt_path])
164
+ else:
165
+ msgs = [dict(type='image', value=tgt_path)]
166
+ msgs.append(dict(type='text', value=question))
167
+ return msgs
168
+
169
+ # Given the prediction file, return the evaluation results in the format of a dictionary or pandas dataframe
170
+ @abstractmethod
171
+ def evaluate(self, eval_file, **judge_kwargs):
172
+ pass
VLMEvalKit/vlmeval/dataset/image_caption.py ADDED
@@ -0,0 +1,75 @@
1
+ from .image_base import ImageBaseDataset
2
+ from ..smp import *
3
+
4
+
5
+ class COCO_Caption_Scorer():
6
+ def __init__(self, ref, gt):
7
+ from pycocoevalcap.bleu.bleu import Bleu
8
+ from pycocoevalcap.rouge.rouge import Rouge
9
+ from pycocoevalcap.cider.cider import Cider
10
+
11
+ self.ref = ref
12
+ self.gt = gt
13
+ print('setting up scorers...')
14
+ self.scorers = [
15
+ (Bleu(4), ['Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4']),
16
+ (Rouge(), 'ROUGE_L'),
17
+ (Cider(), 'CIDEr'),
18
+ ]
19
+
20
+ def compute_scores(self):
21
+ total_scores = {}
22
+ for scorer, method in self.scorers:
23
+ print('computing %s score...' % (scorer.method()))
24
+ score, scores = scorer.compute_score(self.gt, self.ref)
25
+ if isinstance(method, list):
26
+ for sc, scs, m in zip(score, scores, method):
27
+ print('%s: %0.3f' % (m, sc * 100))
28
+ total_scores['Bleu'] = [x * 100 for x in score]
29
+ else:
30
+ print('%s: %0.3f' % (method, score * 100))
31
+ total_scores[method] = score * 100
32
+
33
+ print('*****DONE*****')
34
+ for key, value in total_scores.items():
35
+ print('{}:{}'.format(key, value))
36
+ return total_scores
37
+
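+ # Usage sketch (illustrative values): `ref` and `gt` map the same string keys to
+ # lists of captions, e.g. ref = {'0': ['a dog on grass']} and
+ # gt = {'0': ['a dog runs on the grass', 'a brown dog outdoors']}; compute_scores()
+ # then reports BLEU-1..4, ROUGE_L and CIDEr on a 0-100 scale, matching how
+ # evaluate() below assembles predictions and reference answers.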
38
+
39
+ class ImageCaptionDataset(ImageBaseDataset):
40
+
41
+ TYPE = 'Caption'
42
+
43
+ DATASET_URL = {
44
+ 'COCO_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/COCO_VAL.tsv',
45
+ }
46
+
47
+ DATASET_MD5 = {
48
+ 'COCO_VAL': '72a5079dead060269ac222c5aa5128af',
49
+ }
50
+
51
+ def load_data(self, dataset):
52
+ data = super().load_data(dataset)
53
+ if 'question' not in data:
54
+ data['question'] = [(
55
+ 'Please describe this image in general. Directly provide the description, '
56
+ 'do not include prefix like "This image depicts". '
57
+ )] * len(data)
58
+ return data
59
+
60
+ # It returns a dictionary of scores
61
+ @classmethod
62
+ def evaluate(self, eval_file, **kwargs):
63
+ data = load(eval_file)
64
+ lt = len(data)
65
+ lines = [data.iloc[i] for i in range(lt)]
66
+ ref, gt = {}, {}
67
+ for i, line in enumerate(lines):
68
+ ref[str(i)] = [str(line['prediction'])]
69
+ gt[str(i)] = eval(line['answer'])
70
+
71
+ scorer = COCO_Caption_Scorer(ref, gt)
72
+ coco_caption_score_dict = scorer.compute_scores()
73
+ score_pth = eval_file.replace('.xlsx', '_score.json')
74
+ dump(coco_caption_score_dict, score_pth)
75
+ return coco_caption_score_dict
VLMEvalKit/vlmeval/dataset/image_mcq.py ADDED
@@ -0,0 +1,899 @@
1
+ import warnings
2
+
3
+ from .image_base import ImageBaseDataset
4
+ from .utils import build_judge, DEBUG_MESSAGE
5
+ from ..smp import *
6
+ import pandas as pd
7
+
8
+ MMMB_URLS = {
9
+ 'MMMB_ar': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmmb/mmmb_ar.tsv',
10
+ 'MMMB_cn': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmmb/mmmb_cn.tsv',
11
+ 'MMMB_en': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmmb/mmmb_en.tsv',
12
+ 'MMMB_pt': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmmb/mmmb_pt.tsv',
13
+ 'MMMB_ru': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmmb/mmmb_ru.tsv',
14
+ 'MMMB_tr': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmmb/mmmb_tr.tsv',
15
+ }
16
+
17
+ MTL_MMBench_URLS = {
18
+ 'MMBench_dev_ar': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmbench/mmbench_dev_ar.tsv',
19
+ 'MMBench_dev_cn': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmbench/mmbench_dev_cn.tsv',
20
+ 'MMBench_dev_en': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmbench/mmbench_dev_en.tsv',
21
+ 'MMBench_dev_pt': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmbench/mmbench_dev_pt.tsv',
22
+ 'MMBench_dev_tr': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmbench/mmbench_dev_tr.tsv',
23
+ 'MMBench_dev_ru': 'https://huggingface.co/datasets/AIDC-AI/Parrot-dataset/resolve/main/mmbench/mmbench_dev_ru.tsv',
24
+ }
25
+
26
+ MMMB_MD5 = {
27
+ 'MMMB_ar': 'f3a18b6385f1d9701840aa42de27aead', 'MMMB_cn': '13ed82fa89730037292fcaa27f08f430',
28
+ 'MMMB_en': '1cd781a71ec5a2983c090b84105d6a01', 'MMMB_pt': '548ea2b3bb2da991790386f0015d30d1',
29
+ 'MMMB_ru': 'ce1cc8a0533425ab0d86b326ebfc2984', 'MMMB_tr': '0733739d43090327975294292bc5cd67'
30
+ }
31
+
32
+ MTL_MMBench_MD5 = {
33
+ 'MMBench_dev_ar': '4271b4a0d0200e1a86380a878e0d64a4', 'MMBench_dev_cn': '2ed5135326fed02c8e51ea50dda8222f',
34
+ 'MMBench_dev_en': 'd9ab776fc018b3d45785e9a5c23431c2', 'MMBench_dev_pt': '4ddfbcd27ef12444b908c03831cd0295',
35
+ 'MMBench_dev_tr': '4fab39d501389d3d6cc90264bb708f11', 'MMBench_dev_ru': '5ba1171ff2e68f80637bf78349e402a5'
36
+ }
37
+
38
+
39
+ class ImageMCQDataset(ImageBaseDataset):
40
+
41
+ TYPE = 'MCQ'
42
+
43
+ DATASET_URL = {
44
+ # MMBench v1.0
45
+ 'MMBench_DEV_EN': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_DEV_EN.tsv',
46
+ 'MMBench_TEST_EN': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_TEST_EN.tsv',
47
+ 'MMBench_DEV_CN': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_DEV_CN.tsv',
48
+ 'MMBench_TEST_CN': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_TEST_CN.tsv',
49
+ 'MMBench': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench.tsv', # Internal
50
+ 'MMBench_CN': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_CN.tsv', # Internal
51
+ # MMBench v1.1
52
+ 'MMBench_DEV_EN_V11': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_DEV_EN_V11.tsv',
53
+ 'MMBench_TEST_EN_V11': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_TEST_EN_V11.tsv',
54
+ 'MMBench_DEV_CN_V11': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_DEV_CN_V11.tsv',
55
+ 'MMBench_TEST_CN_V11': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_TEST_CN_V11.tsv',
56
+ 'MMBench_V11': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_V11.tsv', # Internal
57
+ 'MMBench_CN_V11': 'https://opencompass.openxlab.space/utils/benchmarks/MMBench/MMBench_CN_V11.tsv', # Internal
58
+ # SEEDBench Series
59
+ 'SEEDBench_IMG': 'https://opencompass.openxlab.space/utils/benchmarks/SEEDBench/SEEDBench_IMG.tsv',
60
+ 'SEEDBench2': 'https://huggingface.co/datasets/VLMEval/SEEDBench2/resolve/main/SEEDBench2.tsv',
61
+ 'SEEDBench2_Plus': 'https://opencompass.openxlab.space/utils/benchmarks/SEEDBench/SEEDBench2_Plus.tsv',
62
+ # ScienceQA Series
63
+ 'ScienceQA_VAL': 'https://opencompass.openxlab.space/utils/benchmarks/ScienceQA/ScienceQA_VAL.tsv',
64
+ 'ScienceQA_TEST': 'https://opencompass.openxlab.space/utils/benchmarks/ScienceQA/ScienceQA_TEST.tsv',
65
+ # MMT-Bench
66
+ 'MMT-Bench_ALL_MI': 'https://opencompass.openxlab.space/utils/benchmarks/MMT-Bench/MMT-Bench_ALL_MI.tsv',
67
+ 'MMT-Bench_ALL': 'https://opencompass.openxlab.space/utils/benchmarks/MMT-Bench/MMT-Bench_ALL.tsv',
68
+ 'MMT-Bench_VAL_MI': 'https://opencompass.openxlab.space/utils/benchmarks/MMT-Bench/MMT-Bench_VAL_MI.tsv',
69
+ 'MMT-Bench_VAL': 'https://opencompass.openxlab.space/utils/benchmarks/MMT-Bench/MMT-Bench_VAL.tsv',
70
+ # AesBench
71
+ 'AesBench_VAL': 'https://huggingface.co/datasets/VLMEval/AesBench/resolve/main/AesBench_VAL.tsv',
72
+ 'AesBench_TEST': 'https://huggingface.co/datasets/VLMEval/AesBench/resolve/main/AesBench_TEST.tsv',
73
+ # Q-Bench1
74
+ 'Q-Bench1_VAL': 'https://huggingface.co/datasets/zhangzicheng/qbench_tsv/resolve/main/Q-Bench1_VAL.tsv',
75
+ 'Q-Bench1_TEST': 'https://huggingface.co/datasets/zhangzicheng/qbench_tsv/resolve/main/Q-Bench1_TEST.tsv',
76
+ # A-Bench
77
+ 'A-Bench_VAL': 'https://huggingface.co/datasets/zhangzicheng/abench_tsv/resolve/main/A-bench_VAL.tsv',
78
+ 'A-Bench_TEST': 'https://huggingface.co/datasets/zhangzicheng/abench_tsv/resolve/main/A-bench_TEST.tsv',
79
+ # R-Bench
80
+ 'R-Bench-Dis': 'https://huggingface.co/datasets/lcysyzxdxc/R-Bench/blob/main/R-bench-dis.tsv',
81
+ 'R-Bench-Ref': 'https://huggingface.co/datasets/lcysyzxdxc/R-Bench/blob/main/R-bench-ref.tsv',
82
+ # Other Benchmarks
83
+ 'CCBench': 'https://opencompass.openxlab.space/utils/VLMEval/CCBench.tsv',
84
+ 'AI2D_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/AI2D_TEST.tsv',
85
+ 'AI2D_TEST_NO_MASK': 'https://opencompass.openxlab.space/utils/VLMEval/AI2D_TEST_NO_MASK.tsv',
86
+ 'MMStar': 'https://opencompass.openxlab.space/utils/VLMEval/MMStar.tsv',
87
+ 'RealWorldQA': 'https://opencompass.openxlab.space/utils/VLMEval/RealWorldQA.tsv',
88
+ 'MLLMGuard_DS': 'https://opencompass.openxlab.space/utils/VLMEval/MLLMGuard_DS.tsv',
89
+ 'BLINK': 'https://opencompass.openxlab.space/utils/VLMEval/BLINK.tsv',
90
+ 'TaskMeAnything_v1_imageqa_random': (
91
+ 'https://huggingface.co/datasets/weikaih/TaskMeAnything-v1-imageqa-random/'
92
+ 'resolve/main/TaskMeAnything-v1-imageqa-random.tsv'
93
+ ),
94
+ 'A-OKVQA': 'https://huggingface.co/datasets/Allen8/A-OKVQA/resolve/main/a-okvqa.tsv',
95
+ 'WorldMedQA-V': 'https://opencompass.openxlab.space/utils/VLMEval/WorldMedQA-V.tsv',
96
+ 'VisOnlyQA-VLMEvalKit': (
97
+ 'https://huggingface.co/datasets/ryokamoi/VisOnlyQA_Eval_Real/'
98
+ 'resolve/main/visonlyqa_vlmevalkit.tsv'
99
+ ),
100
+ }
101
+
102
+ DATASET_MD5 = {
103
+ # MMBench v1.0
104
+ 'MMBench_DEV_EN': 'b6caf1133a01c6bb705cf753bb527ed8',
105
+ 'MMBench_TEST_EN': '6939fadb0ce626fefc0bdc9c64efc528',
106
+ 'MMBench_DEV_CN': '08b8fc3324a5ed74155350f57be69fbd',
107
+ 'MMBench_TEST_CN': '7e1239baf0ee4c8b513e19705a0f317e',
108
+ 'MMBench': '4115aea3383f3dd0083be6a633e0f820', # Internal Only
109
+ 'MMBench_CN': '2e053ffc90ea598b1feae13c36dc13ee', # Internal Only
110
+ # MMBench v1.1
111
+ 'MMBench_DEV_EN_V11': '30c05be8f2f347a50be25aa067248184',
112
+ 'MMBench_TEST_EN_V11': '26f0f15381a21720255091d3e0316ce6',
113
+ 'MMBench_DEV_CN_V11': '593f9b5f6bea453d870a798b34ae4f37',
114
+ 'MMBench_TEST_CN_V11': '74bbe4556dac745613c7cbe5ad787050',
115
+ 'MMBench_V11': 'b9276414f57af1308dcc4d0cd9b42e7c', # Internal Only
116
+ 'MMBench_CN_V11': '95f6980dd1b4de38e3cbffe0305a3f25', # Internal Only
117
+ # SEEDBench
118
+ 'SEEDBench_IMG': '68017231464752261a2526d6ca3a10c0',
119
+ 'SEEDBench2': '4ec15cf864c4f16274112284f531813e',
120
+ 'SEEDBench2_Plus': 'e32d3216dc4f452b0fe497a52015d1fd',
121
+ # ScienceQA
122
+ 'ScienceQA_VAL': '96320d05e142e585e7204e72affd29f3',
123
+ 'ScienceQA_TEST': 'e42e9e00f9c59a80d8a5db35bc32b71f',
124
+ # MMT-Bench
125
+ 'MMT-Bench_ALL_MI': '5272157097e19cdd7cb41e412ab3b7c7',
126
+ 'MMT-Bench_ALL': 'b273a2f4c596fe4f2605de0494cd632f',
127
+ 'MMT-Bench_VAL_MI': 'c7d7b998eb5cd9aa36c7d4f721472462',
128
+ 'MMT-Bench_VAL': '8dd4b730f53dbf9c3aed90ca31c928e0',
129
+ # AesBench
130
+ 'AesBench_VAL': '3edb0c319e9187aa0b97fe7a11700a8c',
131
+ 'AesBench_TEST': '58b1f7ba2cc32e1d68896d6ee716bbf8',
132
+ # Q-Bench1
133
+ 'Q-Bench1_VAL': '837bdb6cd2da571713543462815187b7',
134
+ 'Q-Bench1_TEST': '15e759bfd58c9d5f30b23a317d347153',
135
+ # A-Bench
136
+ 'A-Bench_VAL': '218563ec50d34bb336c814143a5bb9c1',
137
+ 'A-Bench_TEST': '567013fb033a20cf23f51d8e865bd16c',
138
+ # R-Bench
139
+ 'R-Bench-Dis': 'd6e961dbfc43350688af2560226830b4',
140
+ 'R-Bench-Ref': '270c1cb555acb523f3fdb178ed57021d',
141
+ # Other Benchmarks
142
+ 'CCBench': 'f5dde47f24dc5a6fb6e595b409b466ac',
143
+ 'AI2D_TEST': '0f593e0d1c7df9a3d69bf1f947e71975',
144
+ 'AI2D_TEST_NO_MASK': 'fd8f463634d4fe9fbd23b876e8eea5be',
145
+ 'MMStar': 'e1ecd2140806c1b1bbf54b43372efb9e',
146
+ 'RealWorldQA': '92321028d2bc29040284b6674721e48f',
147
+ 'MLLMGuard_DS': '975fc0dd7119386e198c37d71e274b3f',
148
+ 'BLINK': '3b6649b6a662184ea046908e5506260e',
149
+ 'TaskMeAnything_v1_imageqa_random': '023fef69e2ca21827afb77c5ec3bc889',
150
+ 'WorldMedQA-V': '441e63875e30c87f5750528b57b41285',
151
+ "VisOnlyQA-VLMEvalKit": 'cf460a31d2acb8d3a7cecd0e69298bfa',
152
+ }
153
+
154
+ DATASET_URL.update(MMMB_URLS)
155
+ DATASET_URL.update(MTL_MMBench_URLS)
156
+ DATASET_MD5.update(MMMB_MD5)
157
+ DATASET_MD5.update(MTL_MMBench_MD5)
158
+
159
+ def build_prompt(self, line):
160
+
161
+ if isinstance(line, int):
162
+ line = self.data.iloc[line]
163
+
164
+ if self.meta_only:
165
+ tgt_path = toliststr(line['image_path'])
166
+ else:
167
+ tgt_path = self.dump_image(line)
168
+
169
+ question = line['question']
170
+ options = {
171
+ cand: line[cand]
172
+ for cand in string.ascii_uppercase
173
+ if cand in line and not pd.isna(line[cand])
174
+ }
175
+ options_prompt = 'Options:\n'
176
+ for key, item in options.items():
177
+ options_prompt += f'{key}. {item}\n'
178
+ hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
179
+ prompt = ''
180
+ if hint is not None:
181
+ prompt += f'Hint: {hint}\n'
182
+ prompt += f'Question: {question}\n'
183
+ if len(options):
184
+ prompt += options_prompt
185
+ prompt += 'Please select the correct answer from the options above. \n'
186
+
187
+ msgs = []
188
+ if isinstance(tgt_path, list):
189
+ msgs.extend([dict(type='image', value=p) for p in tgt_path])
190
+ else:
191
+ msgs = [dict(type='image', value=tgt_path)]
192
+ msgs.append(dict(type='text', value=prompt))
193
+
194
+ return msgs
195
+
196
+ def evaluate(self, eval_file, **judge_kwargs):
197
+ from .utils.multiple_choice import report_acc, report_acc_MMT, mcq_circular_eval, mcq_vanilla_eval
198
+ # assert dataset is not None
199
+ dataset_map = {
200
+ 'MMBench_TEST_EN': 'MMBench', 'MMBench_TEST_EN_V11': 'MMBench_V11',
201
+ 'MMBench_TEST_CN': 'MMBench_CN', 'MMBench_TEST_CN_V11': 'MMBench_CN_V11'
202
+ }
203
+ dataset = self.dataset_name
204
+ if dataset in dataset_map:
205
+ dataset = dataset_map[dataset]
206
+ nproc = judge_kwargs.pop('nproc', 4)
207
+
208
+ circular = False
209
+ if listinstr(['mmbench', 'ccbench'], dataset.lower()):
210
+ data = load(eval_file)
211
+ data['index'] = [int(x) for x in data['index']]
212
+ dump(data, eval_file)
213
+ circular = True
214
+
215
+ suffix = eval_file.split('.')[-1]
216
+ model = judge_kwargs.get('model', 'exact_matching')
217
+ assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']
218
+ name_str_map = {'chatgpt-0125': 'openai', 'gpt-4-0125': 'gpt4'}
219
+ name_str = name_str_map[model] if model in name_str_map else model
220
+
221
+ if model == 'exact_matching':
222
+ model = None
223
+ elif gpt_key_set():
224
+ model = build_judge(**judge_kwargs)
225
+ if not model.working():
226
+ warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
227
+ warnings.warn(DEBUG_MESSAGE)
228
+ model = None
229
+ else:
230
+ warnings.warn('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
231
+ model = None
232
+
233
+ result_file = eval_file.replace(f'.{suffix}', f'_{name_str}_result.pkl')
234
+
235
+ data = load(eval_file)
236
+ data = data.sort_values(by='index')
237
+ data['prediction'] = [str(x) for x in data['prediction']]
238
+ # If not choice label, then use lower case
239
+ for k in data.keys():
240
+ data[k.lower() if k not in list(string.ascii_uppercase) else k] = data.pop(k)
241
+
242
+ meta = self.data
243
+ meta_q_map = {x: y for x, y in zip(meta['index'], meta['question'])}
244
+ data_map = {x: y for x, y in zip(data['index'], data['question'])}
245
+ for k in data_map:
246
+ assert k in meta_q_map, (
247
+ f'eval_file should be the same as or a subset of dataset {self.dataset_name}'
248
+ )
249
+
250
+ if circular:
251
+ data = mcq_circular_eval(model, data, meta, nproc, result_file, self.dataset_name)
252
+ else:
253
+ data = mcq_vanilla_eval(model, data, meta, nproc, result_file, self.dataset_name)
254
+
255
+ # load split
256
+ dump(data, eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))
257
+ data = load(eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))
258
+
259
+ # May have different report acc functions for different datasets
260
+ if 'MMT' in dataset:
261
+ acc = report_acc_MMT(data)
262
+ else:
263
+ acc = report_acc(data)
264
+
265
+ score_file = eval_file.replace(f'.{suffix}', '_acc.csv')
266
+ dump(acc, score_file)
267
+
268
+ if dataset == 'AesBench_VAL':
269
+ warnings.warn('Note that AesBench VAL is just a toy version of AesBench TEST. For full results, \
270
+ please evaluate on AesBench TEST. The AesBench TEST dataset is more than 20 times \
271
+ larger than the VAL dataset and the leaderboard results are based on AesBench TEST.')
272
+ if dataset == 'VisOnlyQA-VLMEvalKit':
273
+ warnings.warn('Note that the results on VisOnlyQA-VLMEvalKit are different from the results on \
274
+ the original VisOnlyQA. VisOnlyQA-VLMEvalKit does not include the \
275
+ chemistry__shape_multi split and uses a different evaluation prompt. Please \
276
+ explicitly specify the version of the dataset when you report results.')
277
+
278
+ return acc
279
+
280
+
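Note: the message list returned by ImageMCQDataset.build_prompt above is a flat list of typed dicts, images first and the assembled text prompt last. A minimal sketch with made-up values (the path, hint, question and options are illustrative only, not taken from any dataset):

    # Hypothetical output of build_prompt for one MCQ record
    msgs = [
        dict(type='image', value='/root/LMUData/images/MMBench/42.jpg'),  # dumped image path (illustrative)
        dict(type='text', value=(
            'Hint: Focus on the bar chart.\n'
            'Question: Which month has the highest sales?\n'
            'Options:\nA. January\nB. June\nC. October\n'
            'Please select the correct answer from the options above. \n'
        )),
    ]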
281
+ class MMMUDataset(ImageMCQDataset):
282
+
283
+ DATASET_URL = {
284
+ 'MMMU_DEV_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/MMMU_DEV_VAL.tsv',
285
+ 'MMMU_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/MMMU_TEST.tsv',
286
+ }
287
+
288
+ DATASET_MD5 = {
289
+ 'MMMU_DEV_VAL': '521afc0f3bf341e6654327792781644d',
290
+ 'MMMU_TEST': 'c19875d11a2d348d07e5eb4bdf33166d',
291
+ }
292
+
293
+ @staticmethod
294
+ def split_MMMU(msgs):
295
+ text, images = None, []
296
+ for s in msgs:
297
+ if s['type'] == 'image':
298
+ images.append(s['value'])
299
+ elif s['type'] == 'text':
300
+ assert text is None
301
+ text = s['value']
302
+ text_segs = text.split('<image ')
303
+ if len(text_segs) == 1:
304
+ return msgs
305
+
306
+ segs = [dict(type='text', value=text_segs[0])]
307
+ for i, seg in enumerate(text_segs):
308
+ if i == 0:
309
+ continue
310
+ assert istype(seg[0], int) and seg[1] == '>'
311
+ image_idx = int(seg[0]) - 1
312
+ segs.append(dict(type='image', value=images[image_idx]))
313
+ segs.append(dict(type='text', value=seg[2:]))
314
+ return segs
315
+
316
+ def build_prompt(self, line):
317
+ msgs = super().build_prompt(line)
318
+ msgs = self.split_MMMU(msgs)
319
+ return msgs
320
+
321
+
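Note: a small worked example (inputs invented for illustration) of what split_MMMU above does — the text is split at each '<image k>' marker and the k-th collected image is spliced back in at that position:

    msgs_in = [
        dict(type='image', value='img_1.jpg'),
        dict(type='text', value='Compare <image 1> with the structure described above.'),
    ]
    # split_MMMU(msgs_in) returns the interleaved form:
    msgs_out = [
        dict(type='text', value='Compare '),
        dict(type='image', value='img_1.jpg'),
        dict(type='text', value=' with the structure described above.'),
    ]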
322
+ class MUIRDataset(ImageMCQDataset):
323
+
324
+ DATASET_URL = {
325
+ 'MUIRBench': 'http://opencompass.openxxlab.com/utils/VLMEval/MUIRBench.tsv'
326
+ }
327
+
328
+ DATASET_MD5 = {
329
+ 'MUIRBench': '2e5e6fd7699761b08a7cb3ab8c0c2ec8'
330
+ }
331
+
332
+ @staticmethod
333
+ def split_MUIR(msgs):
334
+ text, images = None, []
335
+
336
+ # Separate images and text from msgs
337
+ for s in msgs:
338
+ if s['type'] == 'image':
339
+ images.append(s['value'])
340
+ elif s['type'] == 'text':
341
+ assert text is None # Ensure only one text entry is expected
342
+ text = s['value']
343
+
344
+ # Split text by <image> tags
345
+ text_segs = text.split('<image>')
346
+
347
+ # Initialize the segments list
348
+ segs = []
349
+
350
+ # Iterate through the text segments and images
351
+ for i, seg in enumerate(text_segs):
352
+ # Append the image if this is not the first segment and there are still images left
353
+ if i > 0 and i - 1 < len(images):
354
+ segs.append(dict(type='image', value=images[i - 1]))
355
+ # Append the text segment (if it's non-empty)
356
+ if len(seg) > 0:
357
+ segs.append(dict(type='text', value=seg))
358
+
359
+ return segs
360
+
361
+ def build_prompt(self, line):
362
+
363
+ if isinstance(line, int):
364
+ line = self.data.iloc[line]
365
+
366
+ if self.meta_only:
367
+ tgt_path = toliststr(line['image_path'])
368
+ else:
369
+ tgt_path = self.dump_image(line)
370
+
371
+ question = line['question']
372
+ options = {
373
+ cand: line[cand]
374
+ for cand in string.ascii_uppercase
375
+ if cand in line and not pd.isna(line[cand])
376
+ }
377
+ # options_prompt = ''
378
+ options_prompt = '\n'.join([f'{key}. {item}' for key, item in options.items()])
379
+ # for key, item in options.items():
380
+ # options_prompt += f'{key}. {item}\n'
381
+
382
+ prompt = ''
383
+
384
+ prompt += f'{question}\n'
385
+ if len(options):
386
+ prompt += options_prompt
387
+ prompt += "\nAnswer with the option's letter from the given choices directly."
388
+
389
+ msgs = []
390
+ if isinstance(tgt_path, list):
391
+ msgs.extend([dict(type='image', value=p) for p in tgt_path])
392
+ else:
393
+ msgs = [dict(type='image', value=tgt_path)]
394
+ msgs.append(dict(type='text', value=prompt))
395
+
396
+ msgs = self.split_MUIR(msgs)
397
+ return msgs
398
+
399
+
400
+ class GMAIMMBenchDataset(ImageMCQDataset):
401
+
402
+ DATASET_URL = {
403
+ 'GMAI-MMBench_VAL': 'https://huggingface.co/datasets/VLMEval/GMAI-MMBench/resolve/main/GMAI-MMBench_VAL.tsv',
404
+ 'GMAI_mm_bench_TEST_part_1': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_1.tsv', # noqa: E501
405
+ 'GMAI_mm_bench_TEST_part_2': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_2.tsv', # noqa: E501
406
+ 'GMAI_mm_bench_TEST_part_3': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_3.tsv', # noqa: E501
407
+ 'GMAI_mm_bench_TEST_part_4': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_4.tsv', # noqa: E501
408
+ 'GMAI_mm_bench_TEST_part_5': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_5.tsv', # noqa: E501
409
+ 'GMAI_mm_bench_TEST_part_6': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_6.tsv', # noqa: E501
410
+ 'GMAI_mm_bench_TEST_part_7': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_7.tsv', # noqa: E501
411
+ 'GMAI_mm_bench_TEST_part_8': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_8.tsv', # noqa: E501
412
+ 'GMAI_mm_bench_TEST_part_9': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_9.tsv', # noqa: E501
413
+ 'GMAI_mm_bench_TEST_part_10': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_10.tsv', # noqa: E501
414
+ 'GMAI_mm_bench_TEST_part_11': 'https://huggingface.co/datasets/OpenGVLab/GMAI-MMBench/resolve/main/GMAI_mm_bench_TEST_part_11.tsv', # noqa: E501
415
+ }
416
+
417
+ DATASET_MD5 = {
418
+ 'GMAI-MMBench_VAL': '254bd581627866f1c499d3d6b4422324',
419
+ 'GMAI_mm_bench_TEST_part_1': '900d735231230a63f4ed45665c078ef4',
420
+ 'GMAI_mm_bench_TEST_part_2': '1b27ab621386945d7e4a765ad2d22b0e',
421
+ 'GMAI_mm_bench_TEST_part_3': '44bdc2b6267dd505d529b8cad06f0fb2',
422
+ 'GMAI_mm_bench_TEST_part_4': '5a04a04fcac9f1466709f242fdb80acb',
423
+ 'GMAI_mm_bench_TEST_part_5': 'c70baf8909eda9af0ddeab275c721336',
424
+ 'GMAI_mm_bench_TEST_part_6': '825abc39596b644dead9350d0cfa3b96',
425
+ 'GMAI_mm_bench_TEST_part_7': 'defb8aed2fb77365a76b6b9abd6a2701',
426
+ 'GMAI_mm_bench_TEST_part_8': 'ff490d60b85f2bb0abb67a435b298c65',
427
+ 'GMAI_mm_bench_TEST_part_9': 'ff67c86f40da93b09139ac1d1ba5dc6b',
428
+ 'GMAI_mm_bench_TEST_part_10': '3dae94627b9ac0fe00180d4780fbf6dc',
429
+ 'GMAI_mm_bench_TEST_part_11': 'd08dc813f0eb6bbab63cae2a9d113c4b',
430
+ }
431
+
432
+ @classmethod
433
+ def supported_datasets(cls):
434
+ return ['GMAI-MMBench_VAL', 'GMAI-MMBench_TEST']
435
+
436
+ def load_data(self, dataset):
437
+ if dataset == 'GMAI-MMBench_VAL':
438
+ data_path = osp.join(LMUDataRoot(), f'{dataset}.tsv')
439
+ if file_size(data_path, 'GB') > 1:
440
+ local_path = data_path.replace('.tsv', '_local.tsv')
441
+ if not osp.exists(local_path) or os.environ.get('FORCE_LOCAL'):
442
+ from ..tools import LOCALIZE
443
+ LOCALIZE(data_path, local_path)
444
+ data_path = local_path
445
+ return load(data_path)
446
+ elif dataset == 'GMAI-MMBench_TEST':
447
+ dfs = []
448
+ for part_num in range(1, 12):
449
+ part_name = f'GMAI_mm_bench_TEST_part_{part_num}'
450
+ url = self.DATASET_URL[part_name]
451
+ file_md5 = self.DATASET_MD5.get(part_name)
452
+ tsv_path = osp.join(LMUDataRoot(), f'{part_name}.tsv')
453
+ if not osp.exists(tsv_path) or (file_md5 and md5(tsv_path) != file_md5):
454
+ download_file(url, filename=tsv_path)
455
+ local_path = tsv_path.replace('.tsv', '_local.tsv')
456
+ if not osp.exists(local_path) or os.environ.get('FORCE_LOCAL'):
457
+ from ..tools import LOCALIZE
458
+ LOCALIZE(tsv_path, local_path)
459
+ tsv_path = local_path
460
+ # Load this part's TSV data
461
+ df = load(tsv_path)
462
+ dfs.append(df)
463
+ # Concatenate all parts into a single DataFrame
464
+ data = pd.concat(dfs, ignore_index=True)
465
+ return data
466
+ else:
467
+ raise ValueError(f"Unknown dataset: {dataset}")
468
+
469
+ def report_acc_by_groups(self, df, group_column):
470
+ res = defaultdict(list)
471
+
472
+ # Check for the 'split' column
473
+ if 'split' in df:
474
+ splits = list(set(df['split']))
475
+ res['split'] = splits
476
+ else:
477
+ df['split'] = ['none'] * len(df)
478
+ res['split'] = ['none']
479
+
480
+ res['Overall'] = [np.mean(df[df['split'] == sp]['hit']) for sp in res['split']]
481
+
482
+ if group_column not in df:
483
+ raise ValueError(f"Column '{group_column}' not found in dataframe.") # noqa: E713
484
+
485
+ abilities = list(set(df[group_column]))
486
+ abilities = ['None' if isinstance(ab, float) and pd.isna(ab) else ab for ab in abilities]
487
+ abilities.sort()
488
+
489
+ for ab in abilities:
490
+ ab_name = ab
491
+ sub_df = df[df[group_column] == ab]
492
+ res[ab_name] = [np.mean(sub_df[sub_df['split'] == sp]['hit']) for sp in res['split']]
493
+
494
+ return pd.DataFrame(res)
495
+
496
+ def evaluate(self, eval_file, **judge_kwargs):
497
+ from .utils.multiple_choice import report_acc, mcq_vanilla_eval
498
+ nproc = judge_kwargs.pop('nproc', 4)
499
+
500
+ suffix = eval_file.split('.')[-1]
501
+ model = judge_kwargs.get('model', 'exact_matching')
502
+ assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']
503
+ name_str_map = {'chatgpt-0125': 'openai', 'gpt-4-0125': 'gpt4'}
504
+ name_str = name_str_map[model] if model in name_str_map else model
505
+
506
+ if model == 'exact_matching':
507
+ model = None
508
+ elif gpt_key_set():
509
+ model = build_judge(**judge_kwargs)
510
+ if not model.working():
511
+ warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
512
+ warnings.warn(DEBUG_MESSAGE)
513
+ model = None
514
+ else:
515
+ warnings.warn('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
516
+ model = None
517
+
518
+ result_file = eval_file.replace(f'.{suffix}', f'_{name_str}_result.pkl')
519
+
520
+ data = load(eval_file)
521
+ data = data.sort_values(by='index')
522
+ data['prediction'] = [str(x) for x in data['prediction']]
523
+ # If not choice label, then use lower case
524
+ for k in data.keys():
525
+ data[k.lower() if k not in list(string.ascii_uppercase) else k] = data.pop(k)
526
+
527
+ meta = self.data
528
+ meta_q_map = {x: y for x, y in zip(meta['index'], meta['question'])}
529
+ data_map = {x: y for x, y in zip(data['index'], data['question'])}
530
+ for k in data_map:
531
+ assert k in meta_q_map, (
532
+ f'eval_file should be the same as or a subset of dataset {self.dataset_name}'
533
+ )
534
+
535
+ data = mcq_vanilla_eval(model, data, meta, nproc, result_file, self.dataset_name)
536
+
537
+ # load split
538
+ dump(data, eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))
539
+ data = load(eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))
540
+
541
+ acc = report_acc(data)
542
+
543
+ for group_col in ['clinical vqa task', 'department', 'perceptual granularity']:
544
+ acc_grouped = self.report_acc_by_groups(data, group_col)
545
+ score_file_grouped = eval_file.replace(f'.{suffix}', f'_{group_col}_acc.csv')
546
+ dump(acc_grouped, score_file_grouped)
547
+
548
+ return acc
549
+
550
+
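Note: a toy illustration (made-up numbers) of the grouped-accuracy table that report_acc_by_groups above builds — one row per split, an 'Overall' column, then one column per value of the grouping column:

    import pandas as pd

    # Toy per-question results; 'hit' is 1 for a correct answer, 0 otherwise
    df = pd.DataFrame({
        'hit': [1, 0, 1, 1],
        'split': ['none'] * 4,
        'department': ['Cardiology', 'Cardiology', 'Radiology', 'Radiology'],
    })
    # report_acc_by_groups(df, 'department') then holds a single 'none' row with
    # Overall = 0.75, Cardiology = 0.5, Radiology = 1.0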
551
+ class MMERealWorld(ImageMCQDataset):
552
+
553
+ TYPE = 'MMERealWorld'
554
+
555
+ DATASET_MD5 = {
556
+ 'MME-RealWorld': '271c33ec814c39533c467ec6fb8a6f36',
557
+ 'MME-RealWorld-Lite': '4c17057d7d3b6c4a0d4397c3dae0881c',
558
+ 'MME-RealWorld-CN': 'daaa763d52a760a38606d5dedb3fe444',
559
+ }
560
+ SYS = {
561
+ 'MME-RealWorld': (
562
+ 'Select the best answer to the above multiple-choice question based on the image. '
563
+ 'Respond with only the letter (A, B, C, D, or E) of the correct option. \n'
564
+ 'The best answer is:'
565
+ ),
566
+ 'MME-RealWorld-Lite': (
567
+ 'Select the best answer to the above multiple-choice question based on the image. '
568
+ 'Respond with only the letter (A, B, C, D, or E) of the correct option. \n'
569
+ 'The best answer is:'
570
+ ),
571
+ 'MME-RealWorld-CN': (
572
+ '根据图像选择上述多项选择题的最佳答案。只需回答正确选项的字母(A, B, C, D 或 E)。\n'
573
+ '最佳答案为:'
574
+ ),
575
+ }
576
+
577
+ @classmethod
578
+ def supported_datasets(cls):
579
+ return ['MME-RealWorld', 'MME-RealWorld-CN', 'MME-RealWorld-Lite',]
580
+
581
+ def load_data(
582
+ self, dataset="MME-RealWorld", repo_id="yifanzhang114/MME-RealWorld-Base64"
583
+ ):
584
+
585
+ def check_integrity(pth):
586
+ data_file = osp.join(pth, f"{dataset}.tsv")
587
+
588
+ if not os.path.exists(data_file):
589
+ return False
590
+
591
+ if md5(data_file) != self.DATASET_MD5[dataset]:
592
+ return False
593
+ return True
594
+
595
+ def generate_tsv(pth):
596
+ tsv_file = os.path.join(pth, f"{dataset}.tsv")
597
+
598
+ if os.path.exists(tsv_file):
599
+ print(f"{tsv_file} already exists.")
600
+ return
601
+
602
+ json_dir = os.path.join(pth, dataset)
603
+ json_files = [f for f in os.listdir(json_dir) if f.endswith(".json")]
604
+
605
+ data_list = []
606
+ for json_file in json_files:
607
+ with open(os.path.join(json_dir, json_file), "r") as f:
608
+ data = json.load(f)
609
+ for item in tqdm(data):
610
+ choice_prompt = (
611
+ "The choices are listed below:\n"
612
+ if dataset in ["MME-RealWorld", "MME-RealWorld-Lite"]
613
+ else "选项如下所示:\n"
614
+ )
615
+ data_list.append(
616
+ {
617
+ "index": item["index"],
618
+ "image": item["image"],
619
+ "question": item["question"],
620
+ "multi-choice options": choice_prompt
621
+ + "\n".join(item["multi-choice options"]),
622
+ "A": item["multi-choice options"][0][4:],
623
+ "B": item["multi-choice options"][1][4:],
624
+ "C": item["multi-choice options"][2][4:],
625
+ "D": item["multi-choice options"][3][4:],
626
+ "E": item["multi-choice options"][4][4:],
627
+ "answer": item["answer"],
628
+ "category": item["category"],
629
+ "l2-category": item["l2-category"],
630
+ }
631
+ )
632
+ df = pd.DataFrame(data_list)
633
+ df.to_csv(tsv_file, sep="\t", index=False)
634
+ print(f"TSV file saved to {tsv_file}")
635
+
636
+ # Check if dataset is cached and has integrity
637
+ if dataset == "MME-RealWorld-Lite":
638
+ url = 'https://huggingface.co/datasets/yifanzhang114/MME-RealWorld-Base64/resolve/main/mme_realworld_lite.tsv' # noqa: E501
639
+ file_md5 = (
640
+ self.DATASET_MD5[dataset] if dataset in self.DATASET_MD5 else None
641
+ )
642
+ datas = self.prepare_tsv(url, file_md5)
643
+ choice_prompt = "The choices are listed below:\n"
644
+ for index, item in datas.iterrows():
645
+ options = eval(item["multi-choice options"])
646
+ datas.loc[index, "multi-choice options"] = choice_prompt + "\n".join(
647
+ options
648
+ )
649
+ datas.loc[index, "A"] = options[0][4:]
650
+ datas.loc[index, "B"] = options[1][4:]
651
+ datas.loc[index, "C"] = options[2][4:]
652
+ datas.loc[index, "D"] = options[3][4:]
653
+ datas.loc[index, "E"] = options[4][4:]
654
+ return datas
655
+
656
+ update_flag = False
657
+ cache_path = get_cache_path(repo_id)
658
+ if cache_path is not None and check_integrity(cache_path):
659
+ dataset_path = cache_path
660
+ print(f"Using cached dataset from {cache_path}")
661
+ else:
662
+ from huggingface_hub import snapshot_download
663
+
664
+ # Download or find the dataset path
665
+ dataset_path = snapshot_download(repo_id=repo_id, repo_type="dataset")
666
+ generate_tsv(dataset_path)
667
+ update_flag = True
668
+
669
+ data_path = os.path.join(dataset_path, f"{dataset}.tsv")
670
+ if file_size(data_path, "GB") > 1:
671
+ local_path = data_path.replace(".tsv", "_local.tsv")
672
+ if (
673
+ not osp.exists(local_path)
674
+ or os.environ.get("FORCE_LOCAL", None)
675
+ or update_flag
676
+ ):
677
+ from vlmeval.tools import LOCALIZE
678
+
679
+ LOCALIZE(data_path, local_path)
680
+ data_path = local_path
681
+ return load(data_path)
682
+
683
+ def post_build(self, dataset):
684
+ self.TYPE = 'MMERealWorld'
685
+
686
+ # Given one data record, return the built prompt (a multi-modal message), can override
687
+ def build_prompt(self, line):
688
+ if isinstance(line, int):
689
+ line = self.data.iloc[line]
690
+
691
+ if self.meta_only:
692
+ tgt_path = toliststr(line['image_path'])
693
+ else:
694
+ tgt_path = self.dump_image(line)
695
+
696
+ question = line['question']
697
+
698
+ choice_prompt = line['multi-choice options'] + '\n'
699
+ question += ' ' + choice_prompt + self.SYS[self.dataset_name]
700
+
701
+ msgs = []
702
+ if isinstance(tgt_path, list):
703
+ msgs.extend([dict(type='image', value=p) for p in tgt_path])
704
+ else:
705
+ msgs = [dict(type='image', value=tgt_path)]
706
+ msgs.append(dict(type='text', value=question))
707
+ return msgs
708
+
709
+ # It returns a dictionary
710
+ @classmethod
711
+ def evaluate(self, eval_file, **judge_kwargs):
712
+ from .utils.multiple_choice import extract_characters_regex, get_dimension_rating
713
+ assert eval_file.endswith('.xlsx'), 'data file should be an xlsx file'
714
+ FAIL_MSG = 'Failed to obtain answer via API.'
715
+ tmp_file = eval_file.replace('.xlsx', '_tmp.pkl')
716
+ tgt_file = eval_file.replace('.xlsx', '_rating.json')
717
+ score_file = eval_file.replace('.xlsx', '_score.xlsx')
718
+
719
+ if not osp.exists(score_file):
720
+
721
+ res = {} if not osp.exists(tmp_file) else load(tmp_file)
722
+ res = {k: v for k, v in res.items() if FAIL_MSG not in v}
723
+
724
+ data = load(eval_file)
725
+ cnt_rejected = 0
726
+ data_un = data[~pd.isna(data['prediction'])]
727
+
728
+ for idx in data['index']:
729
+ ans = data.loc[data['index'] == idx, 'answer'].values[0]
730
+ pred = data.loc[data['index'] == idx, 'prediction'].values[0]
731
+
732
+ extract_pred = extract_characters_regex(pred)
733
+ if extract_pred == '':
734
+ cnt_rejected += 1
735
+ data.loc[data['index'] == idx, 'score'] = 0
736
+ else:
737
+ data.loc[data['index'] == idx, 'score'] = int(extract_pred == ans)
738
+
739
+ print(
740
+ f'Among {len(data)} questions, failed to obtain prediction for {len(data) - len(data_un)} questions, '
741
+ f'failed to obtain the score for another {cnt_rejected} questions. '
742
+ f'Those questions will be counted as 0 score in ALL rating.'
743
+ )
744
+
745
+ dump(data, score_file)
746
+
747
+ rating = get_dimension_rating(score_file)
748
+ dump(rating, tgt_file)
749
+ return rating
750
+
751
+
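Note: a sketch of how build_prompt above assembles the text part of an MME-RealWorld query. The question and option strings are invented placeholders (the real ones come from the 'multi-choice options' column); the trailing instruction is the SYS string defined in the class:

    question = 'What is the speed limit shown on the road sign?'
    choice_prompt = ('The choices are listed below:\n'
                     '(A) 30\n(B) 40\n(C) 50\n(D) 60\n(E) 80') + '\n'
    sys_suffix = ('Select the best answer to the above multiple-choice question based on the image. '
                  'Respond with only the letter (A, B, C, D, or E) of the correct option. \n'
                  'The best answer is:')
    full_text = question + ' ' + choice_prompt + sys_suffix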
752
+ class HRBenchDataset(ImageMCQDataset):
753
+
754
+ DATASET_URL = {
755
+ 'HRBench4K': 'https://huggingface.co/datasets/DreamMr/HR-Bench/resolve/main/hr_bench_4k.tsv',
756
+ 'HRBench8K': 'https://huggingface.co/datasets/DreamMr/HR-Bench/resolve/main/hr_bench_8k.tsv',
757
+ }
758
+
759
+ DATASET_MD5 = {
760
+ 'HRBench4K': 'f6b041b03d49543494b8a56d2e35be65',
761
+ 'HRBench8K': '274c9c7f89329b804a4723178a00219c',
762
+ }
763
+
764
+ def evaluate(self, eval_file, **judge_kwargs):
765
+ assert os.path.exists(eval_file), '{} does not exist!'.format(eval_file)
766
+ from .utils.multiple_choice import mcq_vanilla_eval
767
+ from .utils.hrbench import report_acc_hrbench
768
+ nproc = judge_kwargs.pop('nproc', 4)
769
+
770
+ suffix = eval_file.split('.')[-1]
771
+ model = judge_kwargs.get('model', 'exact_matching')
772
+ assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']
773
+ name_str_map = {'chatgpt-0125': 'openai', 'gpt-4-0125': 'gpt4'}
774
+ name_str = name_str_map[model] if model in name_str_map else model
775
+
776
+ if model == 'exact_matching':
777
+ model = None
778
+ elif gpt_key_set():
779
+ model = build_judge(**judge_kwargs)
780
+ if not model.working():
781
+ warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
782
+ warnings.warn(DEBUG_MESSAGE)
783
+ model = None
784
+ else:
785
+ warnings.warn('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
786
+ model = None
787
+
788
+ result_file = eval_file.replace(f'.{suffix}', f'_{name_str}_result.pkl')
789
+
790
+ data = load(eval_file)
791
+ data = data.sort_values(by='index')
792
+ data['prediction'] = [str(x) for x in data['prediction']]
793
+ # If not choice label, then use lower case
794
+ for k in data.keys():
795
+ data[k.lower() if k not in list(string.ascii_uppercase) else k] = data.pop(k)
796
+
797
+ meta = self.data
798
+ meta_q_map = {x: y for x, y in zip(meta['index'], meta['question'])}
799
+ data_map = {x: y for x, y in zip(data['index'], data['question'])}
800
+ for k in data_map:
801
+ assert k in meta_q_map, (
802
+ f'eval_file should be the same as or a subset of dataset {self.dataset_name}'
803
+ )
804
+
805
+ score_file = eval_file.replace(f'.{suffix}', '_acc.csv')
806
+
807
+ if osp.exists(score_file):
808
+ acc = load(score_file)
809
+ return acc
810
+ data = mcq_vanilla_eval(model, data, meta, nproc, result_file, self.dataset_name)
811
+ dump(data, eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))
812
+ data = load(eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))
813
+
814
+ acc = report_acc_hrbench(data)
815
+
816
+ score_file = eval_file.replace(f'.{suffix}', '_acc.csv')
817
+ dump(acc, score_file)
818
+
819
+ return acc
820
+
821
+
822
+ class CustomMCQDataset(ImageMCQDataset):
823
+
824
+ def load_data(self, dataset):
825
+ data_path = osp.join(LMUDataRoot(), f'{dataset}.tsv')
826
+
827
+ if file_size(data_path, 'GB') > 1:
828
+ local_path = data_path.replace('.tsv', '_local.tsv')
829
+ if not osp.exists(local_path) or os.environ.get('FORCE_LOCAL', None):
830
+ from ..tools import LOCALIZE
831
+ LOCALIZE(data_path, local_path)
832
+ data_path = local_path
833
+ return load(data_path)
834
+
835
+
836
+ class NaturalBenchDataset(ImageMCQDataset):
837
+
838
+ DATASET_URL = {
839
+ 'NaturalBenchDataset': (
840
+ 'https://huggingface.co/datasets/BaiqiL/'
841
+ 'NaturalBench/resolve/main/NaturalBenchDataset.tsv'
842
+ ),
843
+ }
844
+ DATASET_MD5 = {
845
+ 'NaturalBenchDataset':'dbe25b044bc35696426381e9ba4fe930',
846
+ }
847
+
848
+ def build_prompt(self, line):
849
+ SUFFIX_FOR_VQA = {
850
+ "yes_no": "Please answer Yes or No.",
851
+ "multiple_choice": "Please output the letter corresponding to the correct option."
852
+ }
853
+ if isinstance(line, int):
854
+ line = self.data.iloc[line]
855
+
856
+ if self.meta_only:
857
+ tgt_path = toliststr(line['image_path'])
858
+ else:
859
+ tgt_path = self.dump_image(line)
860
+
861
+ question = line['question']
862
+ prompt = f'{question} {SUFFIX_FOR_VQA[line["type"]]}'
863
+ msgs = []
864
+ if isinstance(tgt_path, list):
865
+ msgs.extend([dict(type='image', value=p) for p in tgt_path])
866
+ else:
867
+ msgs = [dict(type='image', value=tgt_path)]
868
+ msgs.append(dict(type='text', value=prompt))
869
+
870
+ return msgs
871
+
872
+ def evaluate(self, eval_file, **judge_kwargs):
873
+ from .utils.naturalbench import extract_answer, get_scores
874
+
875
+ data = load(eval_file)
876
+ data = data.sort_values(by='index')
877
+ predictions = [str(x) for x in data['prediction']]
878
+ answers = [str(x) for x in data['answer']]
879
+ indexs = [str(x) for x in data['index']]
880
+ meta = self.data
881
+ types = [str(x) for x in meta['type']]
882
+ results = {}
883
+ assert len(predictions) == len(answers) == len(indexs) == len(types) == (1900 * 4)
884
+ number_answered_samples = len(predictions) // 4
885
+ for i in range(number_answered_samples):
886
+ results[i] = {
887
+ "q0_i0": extract_answer(predictions[i * 4], types[i * 4]),
888
+ "q0_i1": extract_answer(predictions[i * 4 + 1], types[i * 4 + 1]),
889
+ "q1_i0": extract_answer(predictions[i * 4 + 2], types[i * 4 + 2]),
890
+ "q1_i1": extract_answer(predictions[i * 4 + 3], types[i * 4 + 3])
891
+ }
892
+
893
+ scores = get_scores(results)
894
+ print(scores)
895
+ score_file = 'NaturalBench_acc.csv'
896
+ df = pd.DataFrame(list(scores.items()), columns=['Metric', 'Score'])
897
+ dump(df, score_file)
898
+
899
+ return scores
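Note: to make the grouping in NaturalBench's evaluate explicit — every sample contributes four consecutive rows (two questions paired with two images, 1900 groups in total), keyed as below. The question/image reading of the key names is inferred from the q/i naming:

    predictions = ['Yes', 'No', 'B', 'A']     # toy model outputs for one group (illustrative)
    i = 0
    group = {
        'q0_i0': predictions[i * 4],          # first question on the first image (assumed reading)
        'q0_i1': predictions[i * 4 + 1],      # first question on the second image
        'q1_i0': predictions[i * 4 + 2],      # second question on the first image
        'q1_i1': predictions[i * 4 + 3],      # second question on the second image
    }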
VLMEvalKit/vlmeval/dataset/image_mt.py ADDED
@@ -0,0 +1,128 @@
1
+ from .image_base import ImageBaseDataset
2
+ from .utils.judge_util import build_judge
3
+ from ..smp import *
4
+ from ..utils import track_progress_rich
5
+
6
+
7
+ class ImageMTDataset(ImageBaseDataset):
8
+
9
+ TYPE = 'MT'
10
+
11
+ def build_prompt(self, line):
12
+ if isinstance(line, int):
13
+ line = self.data.iloc[line]
14
+
15
+ if self.meta_only:
16
+ tgt_path = toliststr(line['image_path'])
17
+ else:
18
+ tgt_path = self.dump_image(line)
19
+
20
+ questions = toliststr(line['question'])
21
+ if 'answer' in line:
22
+ answers = toliststr(line['answer'])
23
+ else:
24
+ answers = [''] * len(questions)
25
+ assert len(questions) == len(answers)
26
+
27
+ dlgs, pics_number = [], 0
28
+ for i in range(len(questions)):
29
+ q, a = questions[i], answers[i]
30
+ if '<ImageHere>' in q:
31
+ content = []
32
+ tag_number = q.count('<ImageHere>')
33
+ images = tgt_path[pics_number: pics_number + tag_number]
34
+ pics_number += tag_number
35
+ q_split = q.split('<ImageHere>')
36
+ for j in range(tag_number):
37
+ qsp, im = q_split[j], images[j]
38
+ if qsp != '':
39
+ content.append(dict(type='text', value=qsp))
40
+ content.append(dict(type='image', value=im))
41
+ if q_split[-1] != '':
42
+ content.append(dict(type='text', value=q_split[-1]))
43
+ else:
44
+ content = [dict(type='text', value=q)]
45
+ dlgs.append(dict(role='user', content=content))
46
+ assert '<ImageHere>' not in a, 'We currently do not support images in the answer. '
47
+ content = [dict(type='text', value=a)]
48
+ dlgs.append(dict(role='assistant', content=content))
49
+ return dlgs
50
+
51
+
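Note: a sketch (made-up content) of the multi-turn structure that ImageMTDataset.build_prompt above returns when a question contains one <ImageHere> tag — user and assistant turns alternate, each carrying a list of typed content items:

    dlgs = [
        dict(role='user', content=[
            dict(type='text', value='Describe the scene in '),
            dict(type='image', value='/path/to/turn1.jpg'),   # substituted for <ImageHere> (illustrative path)
        ]),
        dict(role='assistant', content=[
            dict(type='text', value='A crowded market street at dusk.'),
        ]),
    ]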
52
+ class MMDUDataset(ImageMTDataset):
53
+
54
+ DATASET_URL = {'MMDU': 'https://opencompass.openxlab.space/utils/VLMEval/MMDU.tsv'}
55
+ DATASET_MD5 = {'MMDU': '848b635a88a078f49aebcc6e39792061'}
56
+ DIMS = [
57
+ 'Creativity', 'Richness', 'Visual Perception', 'Logical Coherence',
58
+ 'Answer Accuracy', 'Image Relationship Understanding', 'Overall Score'
59
+ ]
60
+
61
+ def calculate_metric(self, ans):
62
+ all = defaultdict(lambda: 0)
63
+ tot = defaultdict(lambda: 0)
64
+ valid = defaultdict(lambda: 0)
65
+ for k in ans:
66
+ res = ans[k]['res']
67
+ assert isinstance(res, pd.DataFrame)
68
+ lt = len(res)
69
+ for i in range(lt):
70
+ line = res.iloc[i]
71
+ for k in self.DIMS:
72
+ tot[k] += 1
73
+ if k in line and line[k] is not None:
74
+ try:
75
+ score = int(line[k])
76
+ score = np.clip(score, 0, 10)
77
+ all[k] += score
78
+ valid[k] += 1
79
+ except Exception as e:
80
+ print(f'Failed to parse the score: {str(e)}')
81
+ sp1 = {'set': 'all'}
82
+ sp1.update({k: all[k] / tot[k] * 10 for k in self.DIMS})
83
+ sp2 = {'set': 'valid'}
84
+ sp2.update({k: all[k] / valid[k] * 10 for k in self.DIMS})
85
+
86
+ return pd.DataFrame([sp1, sp2])
87
+
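Note: a quick numeric check of the normalisation above — per-dimension judge scores are clipped to 0-10, so dividing the accumulated sum by the count and multiplying by 10 reports each dimension on a 0-100 scale (toy numbers):

    scores = [7, 8, 9]                        # three judged turns on one dimension
    print(sum(scores) / len(scores) * 10)     # 80.0 on the 0-100 scale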
88
+ def evaluate(self, eval_file, **judge_kwargs):
89
+ suffix = eval_file.split('.')[-1]
90
+ model = judge_kwargs['model']
91
+
92
+ tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
93
+ score_file = eval_file.replace(f'.{suffix}', f'_{model}_score.csv')
94
+ nproc = judge_kwargs.pop('nproc', 4)
95
+
96
+ data = load(eval_file)
97
+ model = judge_kwargs.pop('model', 'gpt-4o')
98
+ judge_model = build_judge(model=model, **judge_kwargs)
99
+
100
+ lt = len(data)
101
+ lines = [data.iloc[i] for i in range(lt)]
102
+ tups = [(judge_model, line) for line in lines]
103
+ indices = [line['index'] for line in lines]
104
+
105
+ ans = {}
106
+ if osp.exists(tmp_file):
107
+ ans = load(tmp_file)
108
+
109
+ tups = [x for x, i in zip(tups, indices) if i not in ans]
110
+ indices = [i for i in indices if i not in ans]
111
+
112
+ from .utils.mmdu import mmdu_score
113
+
114
+ if len(indices):
115
+ new_results = track_progress_rich(
116
+ mmdu_score,
117
+ tups,
118
+ nproc=nproc,
119
+ chunksize=nproc,
120
+ keys=indices,
121
+ save=tmp_file,)
122
+ ans = load(tmp_file)
123
+ for k, v in zip(indices, new_results):
124
+ assert k in ans
125
+
126
+ metric = self.calculate_metric(ans)
127
+ dump(metric, score_file)
128
+ return metric
VLMEvalKit/vlmeval/dataset/image_vqa.py ADDED
@@ -0,0 +1,1330 @@
1
+ import os
2
+ import re
3
+ import tempfile
4
+ from functools import partial
5
+
6
+ import pandas as pd
7
+
8
+ from .image_base import ImageBaseDataset
9
+ from .utils import build_judge, DEBUG_MESSAGE
10
+ from ..smp import *
11
+ from ..utils import track_progress_rich
12
+
13
+
14
+ class ImageVQADataset(ImageBaseDataset):
15
+ TYPE = 'VQA'
16
+
17
+ DATASET_URL = {
18
+ 'OCRVQA_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/OCRVQA_TEST.tsv',
19
+ 'OCRVQA_TESTCORE': 'https://opencompass.openxlab.space/utils/VLMEval/OCRVQA_TESTCORE.tsv',
20
+ 'TextVQA_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/TextVQA_VAL.tsv',
21
+ 'DocVQA_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/DocVQA_VAL.tsv',
22
+ 'DocVQA_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/DocVQA_TEST.tsv',
23
+ 'InfoVQA_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/InfoVQA_VAL.tsv',
24
+ 'InfoVQA_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/InfoVQA_TEST.tsv',
25
+ 'ChartQA_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/ChartQA_TEST.tsv',
26
+ 'GQA_TestDev_Balanced': 'https://opencompass.openxlab.space/utils/VLMEval/GQA_TestDev_Balanced.tsv',
27
+ }
28
+
29
+ DATASET_MD5 = {
30
+ 'OCRVQA_TEST': 'ca46a6d74b403e9d6c0b670f6fc00db9',
31
+ 'OCRVQA_TESTCORE': 'c5239fe77db8bdc1f2ad8e55e0d1fe97',
32
+ 'TextVQA_VAL': 'b233b31f551bbf4056f2f955da3a92cd',
33
+ 'DocVQA_VAL': 'd5ee77e1926ff10690d469c56b73eabf',
34
+ 'DocVQA_TEST': '6a2f28cac26ef2d3447374e8c6f6c8e9',
35
+ 'InfoVQA_VAL': '2342e9c225222f0ef4dec545ebb126fe',
36
+ 'InfoVQA_TEST': 'df535bf51b88dc9718252c34131a6227',
37
+ 'ChartQA_TEST': 'c902e0aa9be5582a7aad6dcf52734b42',
38
+ 'GQA_TestDev_Balanced': 'fead7df22befc1ed3ca2b62ea26fa17b',
39
+ }
40
+
41
+ def build_prompt(self, line):
42
+ msgs = super().build_prompt(line)
43
+ assert msgs[-1]['type'] == 'text'
44
+ msgs[-1]['value'] += '\nAnswer the question using a single word or phrase.'
45
+ return msgs
46
+
47
+ # It returns a DataFrame
48
+ def evaluate(self, eval_file, **judge_kwargs):
49
+ from .utils.vqa_eval import hit_calculate, process_line
50
+
51
+ data = load(eval_file)
52
+ dataset = self.dataset_name
53
+ assert 'answer' in data and 'prediction' in data
54
+ data['prediction'] = [str(x) for x in data['prediction']]
55
+ data['answer'] = [str(x) for x in data['answer']]
56
+ lt = len(data)
57
+ pool = mp.Pool(16)
58
+ lines = [data.iloc[i] for i in range(lt)]
59
+ if listinstr(['TextVQA'], dataset):
60
+ res = pool.map(partial(process_line, method='vqa_score'), lines)
61
+ elif listinstr(['ChartQA'], dataset):
62
+ res = pool.map(partial(process_line, method='relaxed_accuracy'), lines)
63
+ elif listinstr(['OCRVQA', 'GQA'], dataset):
64
+ res = pool.map(partial(process_line, method='accuracy'), lines)
65
+ elif listinstr(['DocVQA', 'InfoVQA'], dataset):
66
+ res = pool.map(partial(process_line, method='anls'), lines)
67
+ else: # default using vqa_score to calculate score
68
+ res = pool.map(process_line, lines)
69
+ hit = hit_calculate(res, dataset)
70
+ ret = dict()
71
+ if 'split' in data:
72
+ splits = set(data['split'])
73
+ for sp in splits:
74
+ sub = [r for l, r in zip(lines, res) if l['split'] == sp]
75
+ # [np.mean(x['match']) >= full_score_weight for x in sub]
76
+ hit = hit_calculate(sub, dataset)
77
+ ret[sp] = np.mean(hit) * 100
78
+ sub = [r for l, r in zip(lines, res)]
79
+ hit = hit_calculate(sub, dataset)
80
+ ret['Overall'] = np.mean(hit) * 100
81
+ else:
82
+ ret['Overall'] = np.mean(hit) * 100
83
+ if 'category' in data:
84
+ cates = list(set(data['category']))
85
+ cates.sort()
86
+ for c in cates:
87
+ sub = [r for l, r in zip(lines, res) if l['category'] == c]
88
+ # [np.mean(x['match']) >= full_score_weight for x in sub]
89
+ hit = hit_calculate(sub, dataset)
90
+ ret[c] = np.mean(hit) * 100
91
+ ret = d2df(ret)
92
+ ret = ret.round(2)
93
+
94
+ suffix = eval_file.split('.')[-1]
95
+ result_file = eval_file.replace(f'.{suffix}', '_acc.csv')
96
+ dump(ret, result_file)
97
+ return ret
98
+
99
+
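Note: a one-place summary of the metric routing in evaluate above (a restatement of the branches, not part of the file):

    # scoring method passed to process_line, per dataset family
    METRIC_BY_DATASET = {
        'TextVQA': 'vqa_score',
        'ChartQA': 'relaxed_accuracy',
        'OCRVQA': 'accuracy',
        'GQA': 'accuracy',
        'DocVQA': 'anls',
        'InfoVQA': 'anls',
        # anything else falls back to the default vqa_score path
    }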
100
+ class VizWiz(ImageBaseDataset):
101
+ TYPE = 'VQA'
102
+ DATASET_URL = {
103
+ 'VizWiz': 'https://opencompass.openxlab.space/utils/VLMEval/VizWiz.tsv'
104
+ }
105
+ DATASET_MD5 = {
106
+ 'VizWiz': 'fa4ac4164467563ed2fac6eac6631bd0'
107
+ }
108
+
109
+ @classmethod
110
+ def evaluate(self, eval_file, **judge_kwargs):
111
+ from .utils.vqa_eval import hit_calculate, process_line
112
+
113
+ suffix = eval_file.split('.')[-1]
114
+ result_file = eval_file.replace(f'.{suffix}', '_acc.csv')
115
+
116
+ if not osp.exists(result_file):
117
+ data = load(eval_file)
118
+ assert 'answers' in data and 'prediction' in data
119
+ data['prediction'] = [str(x) for x in data['prediction']]
120
+ data['answer'] = [str(x) for x in data['answers']]
121
+
122
+ lt = len(data)
123
+ pool = mp.Pool(16)
124
+ lines = [data.iloc[i] for i in range(lt)]
125
+ res = pool.map(process_line, lines)
126
+
127
+ hit = hit_calculate(res, 'VizWiz')
128
+ ret = dict()
129
+
130
+ ret['Overall'] = np.mean(hit) * 100
131
+ ret = d2df(ret)
132
+ ret = ret.round(2)
133
+
134
+ dump(ret, result_file)
135
+
136
+ retz = pd.read_csv(result_file)
137
+ return retz
138
+
139
+
140
+ class OCRBench(ImageBaseDataset):
141
+ TYPE = 'VQA'
142
+ DATASET_URL = {
143
+ 'OCRBench': 'https://opencompass.openxlab.space/utils/VLMEval/OCRBench.tsv'
144
+ }
145
+ DATASET_MD5 = {'OCRBench': 'e953d98a987cc6e26ef717b61260b778'}
146
+
147
+ # It returns a dictionary
148
+ @classmethod
149
+ def evaluate(self, eval_file, **judge_kwargs):
150
+ OCRBench_score = {
151
+ 'Regular Text Recognition': 0,
152
+ 'Irregular Text Recognition': 0,
153
+ 'Artistic Text Recognition': 0,
154
+ 'Handwriting Recognition': 0,
155
+ 'Digit String Recognition': 0,
156
+ 'Non-Semantic Text Recognition': 0,
157
+ 'Scene Text-centric VQA': 0,
158
+ 'Doc-oriented VQA': 0,
159
+ 'Key Information Extraction': 0,
160
+ 'Handwritten Mathematical Expression Recognition': 0,
161
+ }
162
+
163
+ data = load(eval_file)
164
+ lt = len(data)
165
+ lines = [data.iloc[i] for i in range(lt)]
166
+ for i in tqdm(range(len(lines))):
167
+ line = lines[i]
168
+ predict = str(line['prediction'])
169
+ answers = eval(line['answer'])
170
+ category = line['category']
171
+ if category == 'Handwritten Mathematical Expression Recognition':
172
+ for j in range(len(answers)):
173
+ answer = answers[j].strip().replace('\n', ' ').replace(' ', '')
174
+ predict = predict.strip().replace('\n', ' ').replace(' ', '')
175
+ if answer in predict:
176
+ OCRBench_score[category] += 1
177
+ break
178
+ else:
179
+ for j in range(len(answers)):
180
+ answer = answers[j].lower().strip().replace('\n', ' ')
181
+ predict = predict.lower().strip().replace('\n', ' ')
182
+ if answer in predict:
183
+ OCRBench_score[category] += 1
184
+ break
185
+
186
+ final_score_dict = {}
187
+ final_score_dict['Text Recognition'] = \
188
+ (OCRBench_score['Regular Text Recognition'] + OCRBench_score['Irregular Text Recognition']
189
+ + OCRBench_score['Artistic Text Recognition'] + OCRBench_score['Handwriting Recognition']
190
+ + OCRBench_score['Digit String Recognition'] + OCRBench_score['Non-Semantic Text Recognition'])
191
+ final_score_dict['Scene Text-centric VQA'] = OCRBench_score['Scene Text-centric VQA']
192
+ final_score_dict['Doc-oriented VQA'] = OCRBench_score['Doc-oriented VQA']
193
+ final_score_dict['Key Information Extraction'] = OCRBench_score['Key Information Extraction']
194
+ final_score_dict['Handwritten Mathematical Expression Recognition'] = \
195
+ (OCRBench_score['Handwritten Mathematical Expression Recognition'])
196
+ final_score_dict['Final Score'] = \
197
+ (final_score_dict['Text Recognition'] + final_score_dict['Scene Text-centric VQA']
198
+ + final_score_dict['Doc-oriented VQA'] + final_score_dict['Key Information Extraction']
199
+ + final_score_dict['Handwritten Mathematical Expression Recognition'])
200
+ final_score_dict['Final Score Norm'] = (float(final_score_dict['Final Score']) / 10)
201
+ score_pth = eval_file.replace('.xlsx', '_score.json')
202
+ dump(final_score_dict, score_pth)
203
+ return final_score_dict
204
+
205
+
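Note: the aggregation at the end of OCRBench's evaluate is a plain sum of category hit counts followed by a division by ten; with made-up counts:

    text_recognition = 200        # sum of the six recognition sub-categories
    scene_text_vqa = 150
    doc_vqa = 120
    kie = 90
    hmer = 40
    final_score = text_recognition + scene_text_vqa + doc_vqa + kie + hmer   # 600
    final_score_norm = final_score / 10                                      # 60.0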
206
+ class MathVista(ImageBaseDataset):
207
+ TYPE = 'VQA'
208
+ DATASET_URL = {
209
+ 'MathVista_MINI': 'https://opencompass.openxlab.space/utils/VLMEval/MathVista_MINI.tsv'
210
+ }
211
+ DATASET_MD5 = {'MathVista_MINI': 'f199b98e178e5a2a20e7048f5dcb0464'}
212
+
213
+ # It returns a DataFrame
214
+ @classmethod
215
+ def evaluate(self, eval_file, **judge_kwargs):
216
+ from .utils.mathvista import MathVista_auxeval, MathVista_acc
217
+
218
+ model = judge_kwargs['model']
219
+ suffix = eval_file.split('.')[-1]
220
+ storage = eval_file.replace(f'.{suffix}', f'_{model}.xlsx')
221
+ tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
222
+ nproc = judge_kwargs.pop('nproc', 4)
223
+
224
+ if not osp.exists(storage):
225
+ data = load(eval_file)
226
+ model = build_judge(max_tokens=128, **judge_kwargs)
227
+ assert model.working(), ('MathVista evaluation requires a working OPENAI API\n' + DEBUG_MESSAGE)
228
+ lt = len(data)
229
+ lines = [data.iloc[i] for i in range(lt)]
230
+ tups = [(model, line) for line in lines]
231
+ indices = [line['index'] for line in lines]
232
+
233
+ ans = {}
234
+ if osp.exists(tmp_file):
235
+ ans = load(tmp_file)
236
+ tups = [x for x, i in zip(tups, indices) if i not in ans]
237
+ indices = [i for i in indices if i not in ans]
238
+
239
+ if len(indices):
240
+ new_results = track_progress_rich(
241
+ MathVista_auxeval,
242
+ tups,
243
+ nproc=nproc,
244
+ chunksize=nproc,
245
+ keys=indices,
246
+ save=tmp_file,
247
+ )
248
+ ans = load(tmp_file)
249
+ for k, v in zip(indices, new_results):
250
+ assert k in ans
251
+ assert ans[k]['log'] == v['log'] and ans[k]['res'] == v['res']
252
+
253
+ data['res'] = [ans[idx]['res'] for idx in data['index']]
254
+ data['log'] = [ans[idx]['log'] for idx in data['index']]
255
+ dump(data, storage)
256
+
257
+ score = MathVista_acc(storage)
258
+ score_pth = storage.replace('.xlsx', '_score.csv')
259
+ dump(score, score_pth)
260
+ return score
261
+
262
+
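Note: the evaluation loop above (repeated with minor variations by MathVerse, MathVision and MMDU) is resumable — partial judge outputs are cached in a .pkl keyed by sample index and only missing indices are re-sent. A compressed sketch of the pattern, reusing the helpers imported at the top of the file; judge_one is a hypothetical stand-in for the dataset-specific aux-eval function:

    ans = {}
    if osp.exists(tmp_file):               # resume from an earlier, possibly interrupted run
        ans = load(tmp_file)
    todo = [(model, line) for line in lines if line['index'] not in ans]
    keys = [line['index'] for line in lines if line['index'] not in ans]
    if keys:
        track_progress_rich(judge_one, todo, nproc=nproc, keys=keys, save=tmp_file)
        ans = load(tmp_file)               # now contains a result for every index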
263
+ class MathVerse(ImageBaseDataset):
264
+ TYPE = 'VQA'
265
+ DATASET_URL = {
266
+ 'MathVerse_MINI': 'http://opencompass.openxlab.space/utils/benchmarks/MathVerse/MathVerse_MINIV.tsv', # noqa
267
+ 'MathVerse_MINI_Vision_Only': 'http://opencompass.openxlab.space/utils/benchmarks/MathVerse/MathVerse_MINIVOnly.tsv', # noqa
268
+ 'MathVerse_MINI_Vision_Dominant': 'http://opencompass.openxlab.space/utils/benchmarks/MathVerse/MathVerse_MINIVDom.tsv', # noqa
269
+ 'MathVerse_MINI_Vision_Intensive': 'http://opencompass.openxlab.space/utils/benchmarks/MathVerse/MathVerse_MINIVInt.tsv', # noqa
270
+ 'MathVerse_MINI_Text_Lite': 'http://opencompass.openxlab.space/utils/benchmarks/MathVerse/MathVerse_MINITLite.tsv', # noqa
271
+ 'MathVerse_MINI_Text_Dominant': 'http://opencompass.openxlab.space/utils/benchmarks/MathVerse/MathVerse_MINITDom.tsv', # noqa
272
+ }
273
+ DATASET_MD5 = {
274
+ 'MathVerse_MINI': '5017caca32b7fa110c350a1bea861b65',
275
+ 'MathVerse_MINI_Vision_Only': '68a11d4680014ac881fa37adeadea3a4',
276
+ 'MathVerse_MINI_Vision_Dominant': 'b8fb63852d261ab2aaefba29cc2414d3',
277
+ 'MathVerse_MINI_Vision_Intensive': '01cbd35be202bb0c4873a4186a63bc19',
278
+ 'MathVerse_MINI_Text_Lite': '19e4b13bdd30b89a03b2e358bcfefa04',
279
+ 'MathVerse_MINI_Text_Dominant': '4f5cd2fa6630ea00bb11d6fde1f6fe6a',
280
+ }
281
+
282
+ # It returns a DataFrame
283
+ @classmethod
284
+ def evaluate(self, eval_file, **judge_kwargs):
285
+ from .utils.mathverse import MathVerse_auxeval_extract, MathVerse_auxeval_score, MathVerse_acc
286
+
287
+ model = judge_kwargs['model']
288
+ suffix = eval_file.split('.')[-1]
289
+ storage_extract = eval_file.replace(f'.{suffix}', f'_{model}_extract.xlsx')
290
+ tmp_file_extract = eval_file.replace(f'.{suffix}', f'_{model}_extract.pkl')
291
+ storage_score = eval_file.replace(f'.{suffix}', f'_{model}_score.xlsx')
292
+ tmp_file_score = eval_file.replace(f'.{suffix}', f'_{model}_score.pkl')
293
+ nproc = judge_kwargs.pop('nproc', 4)
294
+ # stage1: extract the answer
295
+ if not osp.exists(storage_extract):
296
+ data = load(eval_file)
297
+ model = build_judge(max_tokens=128, **judge_kwargs)
298
+ assert model.working(), ('MathVerse evaluation requires a working OPENAI API\n' + DEBUG_MESSAGE)
299
+ lt = len(data)
300
+ lines = [data.iloc[i] for i in range(lt)]
301
+ tups = [(model, line) for line in lines]
302
+ indices = [line['index'] for line in lines]
303
+
304
+ ans = {}
305
+ if osp.exists(tmp_file_extract):
306
+ ans = load(tmp_file_extract)
307
+ tups = [x for x, i in zip(tups, indices) if i not in ans]
308
+ indices = [i for i in indices if i not in ans]
309
+
310
+ if len(indices):
311
+ new_results = track_progress_rich(
312
+ MathVerse_auxeval_extract,
313
+ tups,
314
+ nproc=nproc,
315
+ chunksize=nproc,
316
+ keys=indices,
317
+ save=tmp_file_extract,
318
+ )
319
+ ans = load(tmp_file_extract)
320
+ for k, v in zip(indices, new_results):
321
+ assert k in ans
322
+ assert ans[k]['log_extract'] == v['log_extract'] and ans[k]['extract'] == v['extract']
323
+
324
+ data['extract'] = [ans[idx]['extract'] for idx in data['index']]
325
+ data['log_extract'] = [ans[idx]['log_extract'] for idx in data['index']]
326
+ dump(data, storage_extract)
327
+
328
+ # stage2: score the answer
329
+ if not osp.exists(storage_score):
330
+ data = load(storage_extract)
331
+ model = build_judge(max_tokens=128, **judge_kwargs)
332
+ assert model.working(), ('MathVerse evaluation requires a working OPENAI API\n' + DEBUG_MESSAGE)
333
+ lt = len(data)
334
+ lines = [data.iloc[i] for i in range(lt)]
335
+ tups = [(model, line) for line in lines]
336
+ indices = [line['index'] for line in lines]
337
+
338
+ ans = {}
339
+ if osp.exists(tmp_file_score):
340
+ ans = load(tmp_file_score)
341
+ tups = [x for x, i in zip(tups, indices) if i not in ans]
342
+ indices = [i for i in indices if i not in ans]
343
+
344
+ if len(indices):
345
+ new_results = track_progress_rich(
346
+ MathVerse_auxeval_score,
347
+ tups,
348
+ nproc=nproc,
349
+ chunksize=nproc,
350
+ keys=indices,
351
+ save=tmp_file_score,
352
+ )
353
+ ans = load(tmp_file_score)
354
+ for k, v in zip(indices, new_results):
355
+ assert k in ans
356
+ assert ans[k]['log_score'] == v['log_score'] and ans[k]['score'] == v['score']
357
+
358
+ data['score'] = [ans[idx]['score'] for idx in data['index']]
359
+ data['log_score'] = [ans[idx]['log_score'] for idx in data['index']]
360
+ dump(data, storage_score)
361
+
362
+ score = MathVerse_acc(storage_score)
363
+ score_pth = storage_score.replace('.xlsx', '.csv')
364
+ dump(score, score_pth)
365
+ return score
366
+
367
+
368
+ class MathVision(ImageBaseDataset):
369
+ TYPE = 'VQA'
370
+ DATASET_URL = {
371
+ 'MathVision': 'https://opencompass.openxlab.space/utils/VLMEval/MathVision.tsv',
372
+ 'MathVision_MINI': 'https://opencompass.openxlab.space/utils/VLMEval/MathVision_MINI.tsv'
373
+ }
374
+ DATASET_MD5 = {
375
+ 'MathVision': '93f6de14f7916e598aa1b7165589831e',
376
+ 'MathVision_MINI': '060fe4fa5d868987ce179307bd5f8a33'
377
+ }
378
+
379
+ # It returns a DataFrame
380
+ @classmethod
381
+ def evaluate(self, eval_file, **judge_kwargs):
382
+ from .utils.mathv import MATH_V_auxeval, MATH_V_acc
383
+
384
+ if 'model' in judge_kwargs:
385
+ model = judge_kwargs['model']
386
+ else:
387
+ model = os.path.basename(os.environ.get('LOCAL_LLM'))
388
+ suffix = eval_file.split('.')[-1]
389
+ storage = eval_file.replace(f'.{suffix}', f'_{model}.xlsx')
390
+ tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
391
+ nproc = judge_kwargs.pop('nproc', 4)
392
+
393
+ if not osp.exists(storage):
394
+ data = load(eval_file)
395
+ model = build_judge(max_tokens=128, **judge_kwargs)
396
+ assert model.working(), ('MATH-Vision evaluation requires a working OPENAI API\n' + DEBUG_MESSAGE)
397
+ lt = len(data)
398
+ lines = [data.iloc[i] for i in range(lt)]
399
+ tups = [(model, line) for line in lines]
400
+ indices = [line['index'] for line in lines]
401
+
402
+ ans = {}
403
+ if osp.exists(tmp_file):
404
+ ans = load(tmp_file)
405
+ tups = [x for x, i in zip(tups, indices) if i not in ans]
406
+ indices = [i for i in indices if i not in ans]
407
+
408
+ if len(indices):
409
+ new_results = track_progress_rich(
410
+ MATH_V_auxeval,
411
+ tups,
412
+ nproc=nproc,
413
+ chunksize=nproc,
414
+ keys=indices,
415
+ save=tmp_file,
416
+ )
417
+ ans = load(tmp_file)
418
+ for k, v in zip(indices, new_results):
419
+ assert k in ans
420
+ assert ans[k]['log'] == v['log'] and ans[k]['res'] == v['res']
421
+
422
+ data['res'] = [ans[idx]['res'] for idx in data['index']]
423
+ data['log'] = [ans[idx]['log'] for idx in data['index']]
424
+ dump(data, storage)
425
+
426
+ score = MATH_V_acc(storage)
427
+ score_pth = storage.replace('.xlsx', '_score.csv')
428
+ dump(score, score_pth)
429
+ return score
430
+
431
+
432
+ class OlympiadBench(ImageBaseDataset):
433
+ TYPE = 'VQA_ex_prompt'
434
+ DATASET_URL = {
435
+ 'OlympiadBench': 'https://opencompass.openxlab.space/utils/VLMEval/OlympiadBench.tsv',
436
+ 'OlympiadBench_EN': 'https://opencompass.openxlab.space/utils/VLMEval/OlympiadBench_EN.tsv',
437
+ 'OlympiadBench_CN': 'https://opencompass.openxlab.space/utils/VLMEval/OlympiadBench_CN.tsv'
438
+ }
439
+ DATASET_MD5 = {
440
+ 'OlympiadBench': '9735ae0f0299eae1e7d07f5a7feab914',
441
+ 'OlympiadBench_EN': '5c68e100d394351fc7049f29d4d4efed',
442
+ 'OlympiadBench_CN': 'ea01b16788955702c79650c701e5b623'
443
+ }
444
+
445
+ def dump_image(self, line):
446
+ os.makedirs(self.img_root, exist_ok=True)
447
+
448
+ tgt_path_z = []
449
+ if isinstance(line['image'], list):
450
+ for i in range(len(line['image'])):
451
+ tgt_path = osp.join(self.img_root, f"{line['index']}--{i + 1}.jpg")
452
+ if not read_ok(tgt_path):
453
+ decode_base64_to_image_file(line['image'][i], tgt_path)
454
+ tgt_path_z.append(tgt_path)
455
+ else:
456
+ tgt_path = osp.join(self.img_root, f"{line['index']}.jpg")
457
+ if not read_ok(tgt_path):
458
+ decode_base64_to_image_file(line['image'], tgt_path)
459
+ tgt_path_z.append(tgt_path)
460
+ return tgt_path_z
461
+
462
+ def build_prompt(self, line):
463
+
464
+ from .utils.olympiadbench import get_answer_type_text, make_input
465
+
466
+ self.is_chinese = 'zh' in line['source']
467
+ self.is_math = 'maths' in line['source']
468
+ self.is_theorem_proving = 'TP' in line['source']
469
+
470
+ if self.is_chinese:
471
+ subject_content = '数学' if self.is_math else '物理'
472
+ if self.is_theorem_proving:
473
+ prompt = (
474
+ f"以下是中国{subject_content}竞赛中的证明题。请根据题目的要求,运用逻辑推理及常用定理证明题目中的命题。"
475
+ "证明过程中使用的变量和公式请使用LaTeX格式表示。"
476
+ )
477
+ else:
478
+ answer_type_text = get_answer_type_text(line['answer_type'], is_chinese=True,
479
+ multiple_answer=line['is_multiple_answer'])
480
+ if line['is_multiple_answer']:
481
+ multiple_answer_text = '\\boxed{用英文逗号连接的多个答案}'
482
+ else:
483
+ multiple_answer_text = '\\boxed{答案}'
484
+ unit_text = ''
485
+ if line['unit']:
486
+ multiple_answer_text += '(单位)'
487
+ unit_text = ',注意答案的单位不要放在\\boxed{}中'
488
+ prompt = (
489
+ f'以下是中国{subject_content}竞赛中的解答题{answer_type_text}。请根据题目的要求和所提供的信息计算得出答案。'
490
+ f'解答过程和结果中使用的变量和公式请使用LaTeX格式表示。请在最后以“所以最终答案是{multiple_answer_text}。”'
491
+ f'显式给出结果{unit_text}。'
492
+ )
493
+ else:
494
+ subject_content = 'Math' if self.is_math else 'Physics'
495
+ if self.is_theorem_proving:
496
+ prompt = (
497
+ f'The following is a theorem proving problem from an International {subject_content} competition. '
498
+ 'Please use logical reasoning and common theorems to prove the proposition in the problem '
499
+ 'according to the given requirements. '
500
+ 'Please use LaTeX format to represent the variables and formulas used in the proof.'
501
+ )
502
+ else:
503
+ if line['is_multiple_answer']:
504
+ multiple_answer_text = '\\boxed{multiple answers connected with commas}'
505
+ else:
506
+ multiple_answer_text = '\\boxed{answer}'
507
+ unit_text = ''
508
+ if line['unit']:
509
+ multiple_answer_text += '(unit)'
510
+ unit_text = ', note that the unit of the answer should not be included in \\boxed{}'
511
+ answer_type_text = get_answer_type_text(line['answer_type'], is_chinese=False,
512
+ multiple_answer=line['is_multiple_answer'])
513
+ prompt = (
514
+ f'The following is an open-ended problem from an International {subject_content} competition. '
515
+ f'{answer_type_text}Please calculate the answer according to the given requirements and '
516
+ 'the information provided. Please use LaTeX format to represent the variables and formulas '
517
+ 'used in the solution process and results. Please end your solution with "So the final answer '
518
+ f'is {multiple_answer_text}." and give the result explicitly{unit_text}.'
519
+ )
520
+
521
+ if self.is_math:
522
+ input = make_input(prompt, line['question'])
523
+ else:
524
+ if 'context' in line.keys() and str(line['context']) != 'nan': # cannot be null
525
+ input = make_input(prompt, line['context'] + '\n' + line['question'])
526
+ else:
527
+ input = make_input(prompt, line['question'])
528
+
529
+ ret = [dict(type='text', value=input)]
530
+ tgt_path = self.dump_image(line)
531
+
532
+ ret.extend([dict(type='image', value=s) for s in tgt_path])
533
+
534
+ return ret
535
+
536
+ @classmethod
537
+ def evaluate(self, eval_file, **judge_kwargs):
538
+ from .utils.olympiadbench import MathJudger, extract_answer
539
+ judger = MathJudger()
540
+
541
+ suffix = eval_file.split('.')[-1]
542
+ name_str1 = 'judge'
543
+ name_str2 = 'score'
544
+ result_file = eval_file.replace(f'.{suffix}', f'_{name_str1}_result.xlsx')
545
+ score_file = eval_file.replace(f'.{suffix}', f'_{name_str2}_result.csv')
546
+
547
+ if not osp.exists(result_file):
548
+ data = load(eval_file)
549
+ scorez = []
550
+
551
+ for i in tqdm(data.iterrows()):
552
+ line = i[1]
553
+ model_answer = line['prediction']
554
+ is_chinese = 'zh' in line['source']
555
+ model_answer = extract_answer(is_chinese, model_answer, is_deepseek=False)
556
+ answer_type = line['answer_type']
557
+
558
+ final_answer = line['final_answer'][2:-2]
559
+
560
+ if str(answer_type) != 'nan' and 'Tuple' in answer_type:
561
+ judge_result = judger.judge(model_answer, final_answer)
562
+ else:
563
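+ # 'error' holds the allowed numeric tolerance for the judge; comma-separated values are split into per-answer precisions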
+ if str(line['error']) != 'nan':
564
+ if ',' in line['error']:
565
+ precisions = line['error'].split(',')
566
+ precisions = [float(p) if p else 1e-8 for p in precisions]
567
+ judge_result = judger.judge(model_answer, final_answer, precisions)
568
+ else:
569
+ precision = float(line['error'])
570
+ judge_result = judger.judge(model_answer, final_answer, precision)
571
+ else:
572
+ judge_result = judger.judge(model_answer, final_answer)
573
+ scorez.append(judge_result)
574
+
575
+ data['score'] = scorez
576
+ dump(data, result_file)
577
+
578
+ judge_file = load(result_file)
579
+
580
+ if not osp.exists(score_file):
581
+ name_list = ['OE_MM_maths_en_COMP', 'OE_MM_maths_zh_CEE', 'OE_MM_maths_zh_COMP', 'OE_MM_physics_en_COMP',
582
+ 'OE_MM_physics_zh_CEE','OE_TO_maths_en_COMP', 'OE_TO_maths_zh_CEE', 'OE_TO_maths_zh_COMP',
583
+ 'OE_TO_physics_en_COMP', 'OE_TO_physics_zh_CEE']
584
+
585
+ sample_list = [[] for _ in range(len(name_list))]
586
+ for i in judge_file.iterrows():
587
+ line = i[1]
588
+ for j in range(len(name_list)):
589
+ if line['source'] == name_list[j]:
590
+ sample_list[j].append(line['score'])
591
+
592
+ acc_dict = {}
593
+ correct_list = []
594
+
595
+ # fine-grained
596
+ for i in range(len(name_list)):
597
+ correct_num = 0
598
+ for j in sample_list[i]:
599
+ if j:
600
+ correct_num += 1
601
+ correct_list.append(correct_num)
602
+ acc = 100 * correct_num / len(sample_list[i])
603
+ acc_dict[name_list[i]] = [acc]
604
+
605
+ # 4-way grouping: language (zh/en) x subject (maths/physics)
606
+ labela = ['zh', 'en']
607
+ labelb = ['maths', 'physics']
608
+
609
+ grain_list = [[x,y] for x in labela for y in labelb]
610
+ for j in grain_list:
611
+ dict_name = j[0] + "_" + j[1]
612
+ correct_num = 0
613
+ full_num = 0
614
+ for i in range(len(name_list)):
615
+ if all(k in name_list[i] for k in j):
616
+ correct_num += correct_list[i]
617
+ full_num += len(sample_list[i])
618
+ acc = 100 * correct_num / full_num
619
+ acc_dict[dict_name] = [acc]
620
+
621
+ # 2-way grouping: subject (maths/physics)
622
+ grain_list = ['maths', 'physics']
623
+ for j in grain_list:
624
+ dict_name = j
625
+ correct_num = 0
626
+ full_num = 0
627
+ for i in range(len(name_list)):
628
+ if j in name_list[i]:
629
+ correct_num += correct_list[i]
630
+ full_num += len(sample_list[i])
631
+ acc = 100 * correct_num / full_num
632
+ acc_dict[dict_name] = [acc]
633
+
634
+ # AVG
635
+ correct_num = sum(correct_list)
636
+ acc = 100 * correct_num / len(judge_file)
637
+ acc_dict['AVG'] = [acc]
638
+
639
+ acc_pd = pd.DataFrame(acc_dict)
640
+ acc_pd.to_csv(score_file, index=False, encoding='gbk')
641
+
642
+ accdz = pd.read_csv(score_file)
643
+ return accdz
644
+
645
+
646
+ class LLaVABench(ImageBaseDataset):
647
+ TYPE = 'VQA'
648
+ DATASET_URL = {'LLaVABench': 'https://opencompass.openxlab.space/utils/VLMEval/LLaVABench.tsv'}
649
+ DATASET_MD5 = {'LLaVABench': 'd382a093f749a697820d3dadd61c8428'}
650
+
651
+ # It returns a DataFrame
652
+ @classmethod
653
+ def evaluate(self, eval_file, **judge_kwargs):
654
+ from .utils.llavabench import (
655
+ build_prompt,
656
+ LLaVABench_atomeval,
657
+ LLaVABench_score,
658
+ )
659
+
660
+ suffix = '.' + eval_file.split('.')[-1]
661
+ record_file = eval_file.replace(suffix, '_openai_result' + suffix)
662
+ score_file = eval_file.replace(suffix, '_score.csv')
663
+ nproc = judge_kwargs.pop('nproc', 4)
664
+ system_prompt = 'You are a helpful and precise assistant for checking the quality of the answer.'
665
+
666
+ if not osp.exists(record_file):
667
+ data = load(eval_file)
668
+ lines = [data.iloc[i] for i in range(len(data))]
669
+ model = build_judge(temperature=0.2, system_prompt=system_prompt, **judge_kwargs)
670
+ assert model.working(), ('LLaVABench evaluation requires a working OPENAI API\n' + DEBUG_MESSAGE)
671
+
672
+ prompts = [build_prompt(line) for line in lines]
673
+ tups = [(model, prompt) for prompt in prompts]
674
+ scores = track_progress_rich(LLaVABench_atomeval, tups, nproc=nproc, chunksize=nproc)
675
+ data['gpt4_score'] = [x[0] for x in scores]
676
+ data['score'] = [x[1] for x in scores]
677
+ dump(data, record_file)
678
+
679
+ data = load(record_file)
680
+ ret = LLaVABench_score(data).round(1)
681
+ dump(ret, score_file)
682
+ return ret
683
+
684
+
685
+ class MMVet(ImageBaseDataset):
686
+ TYPE = 'VQA'
687
+ DATASET_URL = {
688
+ 'MMVet': 'https://opencompass.openxlab.space/utils/VLMEval/MMVet.tsv'
689
+ }
690
+ DATASET_MD5 = {'MMVet': '748aa6d4aa9d4de798306a63718455e3'}
691
+
692
+ # It returns a DataFrame
693
+ @classmethod
694
+ def evaluate(self, eval_file, **judge_kwargs):
695
+ from .utils.mmvet import MMVet_auxeval, MMVet_acc
696
+
697
+ suffix = eval_file.split('.')[-1]
698
+ model = judge_kwargs['model']
699
+ storage = eval_file.replace(f'.{suffix}', f'_{model}.xlsx')
700
+ tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
701
+ nproc = judge_kwargs.pop('nproc', 4)
702
+ if not osp.exists(storage):
703
+ data = load(eval_file)
704
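+ # the judge only needs to return a short numeric score, so max_tokens is kept small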
+ model = build_judge(max_tokens=3, **judge_kwargs)
705
+ assert model.working(), ('MMVet evaluation requires a working OPENAI API\n' + DEBUG_MESSAGE)
706
+
707
+ lt = len(data)
708
+ lines = [data.iloc[i] for i in range(lt)]
709
+ tups = [(model, line) for line in lines]
710
+ indices = [line['index'] for line in lines]
711
+
712
+ ans = load(tmp_file) if osp.exists(tmp_file) else {}
713
+ tups = [x for x, i in zip(tups, indices) if i not in ans]
714
+ indices = [i for i in indices if i not in ans]
715
+
716
+ if len(indices):
717
+ new_results = track_progress_rich(
718
+ MMVet_auxeval,
719
+ tups,
720
+ nproc=nproc,
721
+ chunksize=nproc,
722
+ keys=indices,
723
+ save=tmp_file,
724
+ )
725
+ ans = load(tmp_file)
726
+ for k, v in zip(indices, new_results):
727
+ assert k in ans
728
+ assert ans[k]['log'] == v['log'] and ans[k]['score'] == v['score']
729
+ data['score'] = [ans[idx]['score'] for idx in data['index']]
730
+ data['log'] = [ans[idx]['log'] for idx in data['index']]
731
+ dump(data, storage)
732
+
733
+ score, score_fine = MMVet_acc(storage)
734
+ score_pth = storage.replace('.xlsx', '_score.csv')
735
+ score_fine_pth = storage.replace('.xlsx', '_score_fine.csv')
736
+ dump(score, score_pth)
737
+ dump(score_fine, score_fine_pth)
738
+ return score
739
+
740
+
741
+ class MTVQADataset(ImageBaseDataset):
742
+ TYPE = 'VQA'
743
+ DATASET_URL = {'MTVQA_TEST': 'https://opencompass.openxlab.space/utils/VLMEval/MTVQA_TEST.tsv'}
744
+ DATASET_MD5 = {'MTVQA_TEST': 'd87c17dbab934b7cd89c0a3c1c5657f4'}
745
+
746
+ @classmethod
747
+ def evaluate(self, eval_file, **judge_kwargs):
748
+ data = load(eval_file)
749
+ assert 'answer' in data and 'prediction' in data and 'category' in data
750
+ data['prediction'] = [str(x) for x in data['prediction']]
751
+ data['answer'] = [str(x) for x in data['answer']]
752
+ if 'split' in data:
753
+ assert np.all([x.lower() == 'test' for x in data['split']]), 'We only support MTVQA_TEST for now. '
754
+ lt = len(data)
755
+ category_scores = defaultdict(list)
756
+ for i in range(lt):
757
+ line = data.iloc[i]
758
+ ans = line['answer'].strip().lower().replace('.', '')
759
+ pred = line['prediction'].strip().lower().replace('.', '')
760
+ cate = line['category']
761
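+ # exact substring match on the lower-cased, period-stripped answer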
+ score = 1.0 if ans in pred else 0.0
762
+ category_scores[cate].append(score)
763
+ category_scores['Average'].append(score)
764
+ # Calculate the average score for each category, the score is normalized to [0, 100]
765
+ category_averages = {category: np.mean(scores) * 100 for category, scores in category_scores.items()}
766
+
767
+ suffix = eval_file.split('.')[-1]
768
+ result_file = eval_file.replace(f'.{suffix}', '_acc.json')
769
+ dump(category_averages, result_file)
770
+
771
+ return category_averages
772
+
773
+ # MT-VQA adopts a custom prompt
774
+ def build_prompt(self, line):
775
+ msgs = super().build_prompt(line)
776
+ assert sum([x['type'] == 'text' for x in msgs]) == 1
777
+ for item in msgs:
778
+ if item['type'] == 'text':
779
+ item['value'] += '\nAnswer the question using a word or phrase in the language of the question.'
780
+ return msgs
781
+
782
+
783
+ class TableVQABench(ImageBaseDataset):
784
+ TYPE = 'VQA'
785
+ DATASET_URL = {
786
+ 'TableVQABench': 'https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/mentor-vil/datasets/tablevqa-bench.tsv'
787
+ }
788
+ DATASET_MD5 = {'TableVQABench': '2550adc61bdc82d8e62f3b003de7c62d'}
789
+
790
+ from .utils.tablevqabench import FINTABNETQA_PROMPT, VTABFACT_PROMPT, VWTQ_PROMPT
791
+
792
+ # It returns a DataFrame
793
+ @classmethod
794
+ def evaluate(self, eval_file, **judge_kwargs):
795
+ import pandas as pd
796
+ from .utils.tablevqabench import evaluate_fintabnet, evaluate_tabfact, evaluate_wtq
797
+
798
+ data = load(eval_file)
799
+ assert 'answer' in data and 'prediction' in data
800
+
801
+ data['prediction'] = data['prediction'].str.replace('^Answer: ', '', regex=True)
802
+ data_group = dict(tuple(data.groupby('split')))
803
+ eval_result = {'split': [], 'average_scores': []}
804
+ for split in ['fintabnetqa', 'vtabfact', 'vwtq', 'vwtq_syn']:
805
+ data_split = data_group[split].to_dict(orient='records')
806
+ if split == 'fintabnetqa':
807
+ split_eval_meta = evaluate_fintabnet(data_split, ['accuracy'])
808
+ elif split == 'vtabfact':
809
+ split_eval_meta = evaluate_tabfact(data_split, ['accuracy'])
810
+ elif split == 'vwtq' or split == 'vwtq_syn':
811
+ split_eval_meta = evaluate_wtq(data_split, ['accuracy'])
812
+ eval_result['split'].append(split)
813
+ eval_result['average_scores'].append(split_eval_meta['average_scores'])
814
+
815
+ suffix = eval_file.split('.')[-1]
816
+ result_file = eval_file.replace(f'.{suffix}', '_acc.csv')
817
+ eval_result = pd.DataFrame(eval_result)
818
+ dump(eval_result, result_file)
819
+
820
+ return eval_result
821
+
822
+ # TableVQABench adopts a custom prompt
823
+ def build_prompt(self, line):
824
+ msgs = super().build_prompt(line)
825
+ assert sum([x['type'] == 'text' for x in msgs]) == 1
826
+ for item in msgs:
827
+ if item['type'] == 'text':
828
+ if line['split'] == 'fintabnetqa':
829
+ item['value'] = self.FINTABNETQA_PROMPT.format_map({'question': item['value']})
830
+ elif line['split'] == 'vtabfact':
831
+ item['value'] = self.VTABFACT_PROMPT.format_map({'question': item['value']})
832
+ elif line['split'] == 'vwtq_syn' or line['split'] == 'vwtq':
833
+ item['value'] = self.VWTQ_PROMPT.format_map({'question': item['value']})
834
+ return msgs
835
+
836
+
837
+ class CustomVQADataset(ImageBaseDataset):
838
+ TYPE = 'VQA'
839
+
840
+ def load_data(self, dataset):
841
+ data_path = osp.join(LMUDataRoot(), f'{dataset}.tsv')
842
+
843
+ if file_size(data_path, 'GB') > 1:
844
+ local_path = data_path.replace('.tsv', '_local.tsv')
845
+ if not osp.exists(local_path) or os.environ.get('FORCE_LOCAL', None):
846
+ from ..tools import LOCALIZE
847
+
848
+ LOCALIZE(data_path, local_path)
849
+ data_path = local_path
850
+ return load(data_path)
851
+
852
+ def evaluate(self, eval_file, **judge_kwargs):
853
+ raise NotImplementedError
854
+
855
+
856
+ class CRPE(ImageBaseDataset):
857
+ TYPE = 'VQA'
858
+ DATASET_URL = {
859
+ 'CRPE_EXIST': 'https://huggingface.co/datasets/petter12321/crpe_vlmevalkit/resolve/main/CRPE_EXIST.tsv',
860
+ 'CRPE_RELATION': 'https://huggingface.co/datasets/petter12321/crpe_vlmevalkit/resolve/main/CRPE_RELATION.tsv'
861
+ }
862
+ DATASET_MD5 = {
863
+ 'CRPE_EXIST': '315584e23ac1ff7f8719ed3b7ad90f08',
864
+ 'CRPE_RELATION': 'bad7094cde0b572288f4b119c2d0c656'}
865
+
866
+ @classmethod
867
+ def evaluate(self, eval_file, **judge_kwargs):
868
+ from .utils.crpe import is_correct
869
+ # find-image, count-text, find-text,
870
+ # infer-choose, count-image, visual-reasoning
871
+ score = {
872
+ 'exist': 0,
873
+ 'subject': 0,
874
+ 'predicate': 0,
875
+ 'object': 0,
876
+ 'total': 0,
877
+ }
878
+ num = {
879
+ 'exist': 0,
880
+ 'subject': 0,
881
+ 'predicate': 0,
882
+ 'object': 0,
883
+ 'total': 0,
884
+ }
885
+ final_score_dict = {
886
+ 'exist': 0,
887
+ 'subject': 0,
888
+ 'predicate': 0,
889
+ 'object': 0,
890
+ 'total': 0,
891
+ }
892
+ data = load(eval_file)
893
+ lt = len(data)
894
+ lines = [data.iloc[i] for i in range(lt)]
895
+ for i in tqdm(range(len(lines))):
896
+ line = lines[i]
897
+ predict = str(line['prediction'])
898
+ answers = str(line['answer'])
899
+ # print("predict =", predict)
900
+ # print("answers =", answers)
901
+ category = line['category']
902
+ if is_correct(answers, predict):
903
+ score[category] += 1
904
+ score['total'] += 1
905
+ num[category] += 1
906
+ num['total'] += 1
907
+
908
+ for category in ['exist', 'subject', 'predicate', 'object', 'total']:
909
+ if num[category] != 0:
910
+ final_score_dict[category] = score[category] / num[category]
911
+ else:
912
+ final_score_dict[category] = None
913
+
914
+ score_pth = eval_file.replace('.xlsx', '_score.json')
915
+ dump(final_score_dict, score_pth)
916
+ return final_score_dict
917
+
918
+ def build_prompt(self, line):
919
+ ROOT = LMUDataRoot()
920
+ msgs = super().build_prompt(line)
921
+ for msg in msgs:
922
+ if msg['type'] == 'image':
923
+ msg['value'] = osp.join(osp.join(ROOT, 'images', self.dataset_name), msg['value'])
924
+ return msgs
925
+
926
+
927
+ class QSpatial(ImageBaseDataset):
928
+ TYPE = 'VQA'
929
+ DATASET_URL = {
930
+ 'QSpatial_plus': '',
931
+ 'QSpatial_scannet': ''
932
+ }
933
+
934
+ # NOTE: To evaluate Q-Spatial-ScanNet, you need to get the permission from ScanNet website
935
+ # Once you get the permission, you can use the helper code here to download and extract necessary images:
936
+ # https://github.com/andrewliao11/Q-Spatial-Bench-code?tab=readme-ov-file#for-qspatial_scannet
937
+ qspatial_root = "TO_BE_REPLACED_WITH_THE_PATH_TO_QSPATIAL_DATASET"
938
+ url = "https://raw.githubusercontent.com/andrewliao11/Q-Spatial-Bench-code/refs/heads/main/prompt_templates/"
939
+
940
+ def post_build(self, dataset):
941
+ # Download the prompt templates from github
942
+
943
+ links = [
944
+ self.url + "system_prompt.txt",
945
+ self.url + "spatial_prompt_single.txt",
946
+ self.url + "spatial_prompt_steps.txt",
947
+ self.url + "standard_prompt.txt",
948
+ self.url + "zero_shot_prompt.txt"
949
+ ]
950
+ with tempfile.TemporaryDirectory() as temp_dir:
951
+ for link in links:
952
+ tgt_path = os.path.join(temp_dir, link.split("/")[-1])
953
+ os.system(f"wget {link} -O {tgt_path}")
954
+
955
+ self.system_prompt = open(os.path.join(temp_dir, "system_prompt.txt")).read()
956
+ self._prompt_templates = dict(
957
+ spatial_prompt_single=open(os.path.join(temp_dir, "spatial_prompt_single.txt")).read(),
958
+ spatial_prompt_steps=open(os.path.join(temp_dir, "spatial_prompt_steps.txt")).read(),
959
+ standard_prompt=open(os.path.join(temp_dir, "standard_prompt.txt")).read(),
960
+ zero_shot_prompt=open(os.path.join(temp_dir, "zero_shot_prompt.txt")).read(),
961
+ )
962
+
963
+ # Given one data record, return the built prompt (a multi-modal message), can override
964
+ def build_prompt(self, line):
965
+ from jinja2.sandbox import SandboxedEnvironment
966
+ text_prompt_template = self._prompt_templates["spatial_prompt_single"]
967
+ env = SandboxedEnvironment()
968
+ text_prompt = env.from_string(text_prompt_template).render(question=line["question"])
969
+ tgt_path = self.dump_image(line)
970
+
971
+ msgs = []
972
+ if isinstance(tgt_path, list):
973
+ msgs.extend([dict(type='image', value=p) for p in tgt_path])
974
+ else:
975
+ msgs = [dict(type='image', value=tgt_path)]
976
+
977
+ msgs.append(dict(type='text', value=f"{self.system_prompt}\n{text_prompt}"))
978
+ return msgs
979
+
980
+ # Given the dataset name, return the dataset as a pandas dataframe, can override
981
+ def load_data(self, dataset):
982
+ import io
983
+ import pandas as pd
984
+ from datasets import load_dataset
985
+
986
+ hf_dataset = load_dataset("andrewliao11/Q-Spatial-Bench", split=dataset)
987
+ df = hf_dataset.to_pandas()
988
+
989
+ df.reset_index(drop=True, inplace=True)
990
+ df['index'] = df.index
991
+ df['answer'] = list(zip(df['answer_value'], df['answer_unit']))
992
+ df = df[['index'] + [col for col in df.columns if col != 'index']]
993
+
994
+ if dataset == "QSpatial_scannet":
995
+ df = df.drop(columns=["image"])
996
+ df["image"] = [Image.open(os.path.join(self.qspatial_root, image_path)) for image_path in df["image_path"]]
997
+ else:
998
+ df["image"] = [Image.open(io.BytesIO(image_dict["bytes"])) for image_dict in df["image"]]
999
+
1000
+ df["image"] = [encode_image_to_base64(image) for image in df["image"]]
1001
+ return df
1002
+
1003
+ @classmethod
1004
+ def get_multiplier(self, unit):
1005
+
1006
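+ # map the unit string to a multiplier that converts the value into centimeters (0 for unknown units)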
+ unit = unit.lower()
1007
+ if unit in ["meters", "meter", "m", "metre", "metres"]:
1008
+ multiplier = 100
1009
+ elif unit in ["centimeters", "centimeter", "cm"]:
1010
+ multiplier = 1
1011
+ elif unit in ["feet", "foot", "ft"]:
1012
+ multiplier = 30.48
1013
+ elif unit in ["inch", "inches", "in"]:
1014
+ multiplier = 2.54
1015
+ elif unit in ["mm"]:
1016
+ multiplier = 0.1
1017
+ else:
1018
+ print(f"Unknown unit: {unit}")
1019
+ multiplier = 0.
1020
+
1021
+ return multiplier
1022
+
1023
+ @classmethod
1024
+ def parse_string(self, input_str):
1025
+ # Regular expression to match the pattern (number or range, text)
1026
+ match = re.match(r'\(([\d.-]+), (.+)\)', input_str)
1027
+ if match:
1028
+ number_part = match.group(1)
1029
+ text = match.group(2)
1030
+
1031
+ if '-' in number_part:
1032
+ start, end = map(float, number_part.split('-'))
1033
+ number = (start + end) / 2
1034
+ else:
1035
+ number = float(number_part)
1036
+
1037
+ return number * self.get_multiplier(text)
1038
+ else:
1039
+ print(f"Unable to parse the input string {input_str}")
1040
+ return 0
1041
+
1042
+ @classmethod
1043
+ def parse_prediction(self, vlm_response):
1044
+ # Value
1045
+ pattern = r'scalar{([^}]*)}'
1046
+ str_inside_scalar_boxes = re.findall(pattern, vlm_response)[-1]
1047
+ scalar_list = re.findall(r'\d+\.?\d*', str_inside_scalar_boxes)
1048
+ parsed_scalar = np.array(scalar_list).astype(float).mean()
1049
+
1050
+ # Unit
1051
+ pattern = r'distance_unit{([^}]*)}'
1052
+ str_inside_unit_boxes = re.findall(pattern, vlm_response)
1053
+ parsed_unit = str_inside_unit_boxes[-1]
1054
+
1055
+ pred_value_in_cms = parsed_scalar * self.get_multiplier(parsed_unit)
1056
+ return pred_value_in_cms
1057
+
1058
+ # It returns a dictionary
1059
+ @classmethod
1060
+ def evaluate(self, eval_file, **judge_kwargs):
1061
+
1062
+ data = load(eval_file)
1063
+ if "model" in judge_kwargs:
1064
+ from .utils.qspatial import QSpatial_auxeval
1065
+
1066
+ # extract using model
1067
+ model = judge_kwargs['model']
1068
+ suffix = eval_file.split('.')[-1]
1069
+ storage = eval_file.replace(f'.{suffix}', f'_{model}.xlsx')
1070
+ tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
1071
+ nproc = judge_kwargs.pop('nproc', 4)
1072
+
1073
+ if not osp.exists(storage):
1074
+ model = build_judge(max_tokens=128, **judge_kwargs)
1075
+
1076
+ assert model.working(), ('Evaluation requires a working OPENAI API\n' + DEBUG_MESSAGE)
1077
+ lt = len(data)
1078
+ lines = [data.iloc[i] for i in range(lt)]
1079
+ tups = [(model, line) for line in lines]
1080
+ indices = [line['index'] for line in lines]
1081
+
1082
+ ans = {}
1083
+ if osp.exists(tmp_file):
1084
+ ans = load(tmp_file)
1085
+ tups = [x for x, i in zip(tups, indices) if i not in ans]
1086
+ indices = [i for i in indices if i not in ans]
1087
+
1088
+ if len(indices):
1089
+ new_results = track_progress_rich(
1090
+ QSpatial_auxeval,
1091
+ tups,
1092
+ nproc=nproc,
1093
+ chunksize=nproc,
1094
+ keys=indices,
1095
+ save=tmp_file,
1096
+ )
1097
+ ans = load(tmp_file)
1098
+ for k, v in zip(indices, new_results):
1099
+ assert k in ans
1100
+ assert ans[k]['log'] == v['log'] and ans[k]['res'] == v['res']
1101
+
1102
+ data['res'] = [ans[idx]['res'] for idx in data['index']]
1103
+ data['log'] = [ans[idx]['log'] for idx in data['index']]
1104
+ dump(data, storage)
1105
+
1106
+ data = load(storage)
1107
+
1108
+ pred_value_in_cms = []
1109
+ for res in data["res"]:
1110
+ try:
1111
+ pred_value_in_cms.append(self.parse_string(res))
1112
+ except ValueError:
1113
+ pred_value_in_cms.append(0.)
1114
+
1115
+ pred_value_in_cms = np.array(pred_value_in_cms) + 1e-8
1116
+ else:
1117
+ # regex parsing
1118
+ pred_value_in_cms = []
1119
+ n_errors_in_parsing = 0
1120
+ for pred in data["prediction"]:
1121
+ try:
1122
+ parsed_value = self.parse_prediction(pred)
1123
+ except IndexError:
1124
+ n_errors_in_parsing += 1
1125
+ parsed_value = 1e-8
1126
+
1127
+ pred_value_in_cms.append(parsed_value)
1128
+
1129
+ print(f"Encountered {n_errors_in_parsing} errors in parsing")
1130
+ pred_value_in_cms = np.array(pred_value_in_cms) + 1e-8
1131
+
1132
+ # Ground truth
1133
+ ground_truth_value_in_cms = []
1134
+ for answer in data["answer"]:
1135
+ value, unit = eval(answer)
1136
+ ground_truth_value_in_cms.append(value * self.get_multiplier(unit))
1137
+ ground_truth_value_in_cms = np.array(ground_truth_value_in_cms) + 1e-8
1138
+
1139
+ # Calculate the score
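+ # a prediction counts as correct when max(pred/gt, gt/pred) is below 2 (delta_2) or 1.5 (delta_1_point_5)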
1140
+ pred_gt = pred_value_in_cms / ground_truth_value_in_cms
1141
+ gt_pred = ground_truth_value_in_cms / pred_value_in_cms
1142
+ delta_2 = np.stack([pred_gt, gt_pred]).max(0) < 2.
1143
+ delta_1_point_5 = np.stack([pred_gt, gt_pred]).max(0) < 1.5
1144
+
1145
+ data["eval_score_delta_2"] = delta_2
1146
+ data["eval_score_delta_1_point_5"] = delta_1_point_5
1147
+
1148
+ final_score_dict = {
1149
+ "delta_2": delta_2.mean(),
1150
+ "delta_1_point_5": delta_1_point_5.mean()
1151
+ }
1152
+ for question_type in set(data["question_type"]):
1153
+ filtered_data = data[data["question_type"] == question_type]
1154
+ delta_2_per_question_type = filtered_data["eval_score_delta_2"].mean()
1155
+ delta_1_point_5_per_question_type = filtered_data["eval_score_delta_1_point_5"].mean()
1156
+ final_score_dict.update({f"{question_type}_delta_2": delta_2_per_question_type})
1157
+ final_score_dict.update({f"{question_type}_delta_1_point_5": delta_1_point_5_per_question_type})
1158
+
1159
+ score_pth = eval_file.replace('.xlsx', '_score.json')
1160
+ dump(final_score_dict, score_pth)
1161
+ return final_score_dict
1162
+
1163
+
1164
+ class MMNIAH(ImageBaseDataset):
1165
+ TYPE = 'VQA'
1166
+ DATASET_URL = {
1167
+ 'MM_NIAH_VAL':
1168
+ 'https://huggingface.co/datasets/petter12321/MM-NIAH-VLMEvalKit/resolve/main/MM_NIAH_VAL.tsv',
1169
+ 'MM_NIAH_TEST':
1170
+ ['https://huggingface.co/datasets/petter12321/MM-NIAH-VLMEvalKit/resolve/main/part-aa',
1171
+ 'https://huggingface.co/datasets/petter12321/MM-NIAH-VLMEvalKit/resolve/main/part-ab',
1172
+ 'https://huggingface.co/datasets/petter12321/MM-NIAH-VLMEvalKit/resolve/main/part-ac',
1173
+ 'https://huggingface.co/datasets/petter12321/MM-NIAH-VLMEvalKit/resolve/main/part-ad',
1174
+ 'https://huggingface.co/datasets/petter12321/MM-NIAH-VLMEvalKit/resolve/main/part-ae']}
1175
+ DATASET_MD5 = {'MM_NIAH_VAL': '27e5a8c3cef7746cb38f89cd86c474c5',
1176
+ 'MM_NIAH_TEST': 'f490eb2a43096307465fe9e7ef13497c'}
1177
+
1178
+ def prepare_tsv(self, url, file_md5=None):
1179
+ import os
1180
+ data_root = LMUDataRoot()
1181
+ os.makedirs(data_root, exist_ok=True)
1182
+ update_flag = False
1183
+ file_name = 'MM_NIAH_VAL.tsv' if 'MM_NIAH_VAL' in url else 'MM_NIAH_TEST.tsv'
1184
+ data_path = osp.join(data_root, file_name)
1185
+ if osp.exists(data_path) and (file_md5 is None or md5(data_path) == file_md5):
1186
+ pass
1187
+ elif file_name == 'MM_NIAH_TEST.tsv':
1188
+ warnings.warn('The dataset tsv is not downloaded')
1189
+ for i in range(len(url)):
1190
+ if osp.exists(osp.join(data_root, 'part-a' + chr(ord('a') + i))):
1191
+ print('part-a' + chr(ord('a') + i) + ' already exists')
1192
+ continue
1193
+ download_file(url[i], data_path)
1194
+ file_prefix = 'part-'
1195
+ output_file = data_path
1196
+ split_files = sorted([f for f in os.listdir(data_root) if f.startswith(file_prefix)])
1197
+ with open(output_file, 'wb') as outfile:
1198
+ # read each split part in order and append it to the output file
1199
+ for filename in split_files:
1200
+ with open(osp.join(data_root, filename), 'rb') as infile:
1201
+ outfile.write(infile.read())
1202
+ update_flag = True
1203
+ else:
1204
+ warnings.warn('The dataset tsv is not downloaded')
1205
+ download_file(url, data_path)
1206
+ update_flag = True
1207
+
1208
+ if file_size(data_path, 'GB') > 1:
1209
+ local_path = data_path.replace('.tsv', '_local.tsv')
1210
+ if not osp.exists(local_path) or os.environ.get('FORCE_LOCAL', None) or update_flag:
1211
+ from ..tools import LOCALIZE
1212
+ LOCALIZE(data_path, local_path)
1213
+ data_path = local_path
1214
+ return load(data_path)
1215
+
1216
+ @classmethod
1217
+ def evaluate(self, eval_file, **judge_kwargs):
1218
+ from .utils.mmniah import is_correct
1219
+ # find-image, count-text, find-text,
1220
+ # infer-choose, count-image, visual-reasoning
1221
+ MMNIAH_score = {
1222
+ 'count-text': 0,
1223
+ 'find-image': 0,
1224
+ 'find-text': 0,
1225
+ 'infer-choose': 0,
1226
+ 'count-image': 0,
1227
+ 'visual-reasoning': 0,
1228
+ 'total': 0,
1229
+ }
1230
+ MMNIAH_num = {
1231
+ 'count-text': 0,
1232
+ 'find-image': 0,
1233
+ 'find-text': 0,
1234
+ 'infer-choose': 0,
1235
+ 'count-image': 0,
1236
+ 'visual-reasoning': 0,
1237
+ 'total': 0,
1238
+ }
1239
+ final_score_dict = {
1240
+ 'count-text': 0,
1241
+ 'find-image': 0,
1242
+ 'find-text': 0,
1243
+ 'infer-choose': 0,
1244
+ 'count-image': 0,
1245
+ 'visual-reasoning': 0,
1246
+ 'total': 0,
1247
+ }
1248
+ data = load(eval_file)
1249
+ lt = len(data)
1250
+ lines = [data.iloc[i] for i in range(lt)]
1251
+ for i in tqdm(range(len(lines))):
1252
+ line = lines[i]
1253
+ predict = line['prediction']
1254
+ answers = line['answer']
1255
+ category = line['category']
1256
+ if category in ['visual-reasoning', 'find-image']:
1257
+ answers = int(answers)
1258
+ if is_correct(answers, predict):
1259
+ MMNIAH_score[category] += 1
1260
+ MMNIAH_score['total'] += 1
1261
+ MMNIAH_num[category] += 1
1262
+ MMNIAH_num['total'] += 1
1263
+
1264
+ for category in ['find-image', 'count-text', 'find-text',
1265
+ 'infer-choose', 'count-image', 'visual-reasoning', 'total']:
1266
+ if MMNIAH_num[category] != 0:
1267
+ final_score_dict[category] = MMNIAH_score[category] / MMNIAH_num[category]
1268
+ else:
1269
+ final_score_dict[category] = None
1270
+
1271
+ score_pth = eval_file.replace('.xlsx', '_score.json')
1272
+ dump(final_score_dict, score_pth)
1273
+ return final_score_dict
1274
+
1275
+ def build_prompt(self, line):
1276
+ msgs = super().build_prompt(line)
1277
+ if isinstance(line, int):
1278
+ line = self.data.iloc[line]
1279
+ totalchoice = line['multi-choice options']
1280
+ totalchoice = eval(totalchoice)
1281
+ # find-image, count-text, find-text,
1282
+ # infer-choose, count-image, visual-reasoning
1283
+ context = msgs[-1]['value']
1284
+ context = eval(context)
1285
+ question = context[0] + '\n' + context[1]
1286
+ # tgt_path is the list of all image file paths
1287
+ tgt_path = []
1288
+ for i in range(len(msgs) - 1):
1289
+ tgt_path.append(msgs[i]['value'])
1290
+ choices = totalchoice[0]
1291
+ choices_image = totalchoice[1]
1292
+ if choices:
1293
+ for c_idx, c in enumerate(choices):
1294
+ question = f"{question}\n{chr(c_idx + ord('A'))}. {c}"
1295
+ question += "\nAnswer with the option's letter from the given choices directly."
1296
+ elif choices_image:
1297
+ for c_idx in range(len(choices_image)):
1298
+ question = f"{question}\n{chr(c_idx + ord('A'))}. <image>"
1299
+ question += "\nAnswer with the option's letter from the given choices directly."
1300
+ else:
1301
+ question += '\nAnswer the question using a single word or phrase.'
1302
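+ # wrap the question in <start>/<end> sentinels, split on <image> placeholders, then interleave text pieces with image paths; the sentinels are stripped at the end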
+ question = '<start>' + question + '<end>'
1303
+ question = question.split('<image>')
1304
+ if choices_image:
1305
+ for i in range(len(question) - 5):
1306
+ question[i] = question[i] + '\n<image>'
1307
+ for i in range(len(question) - 5, len(question) - 1):
1308
+ question[i] = question[i] + '<image>'
1309
+ else:
1310
+ for i in range(len(question) - 1):
1311
+ question[i] = question[i] + '\n<image>'
1312
+ assert len(tgt_path) + 1 == len(question)
1313
+ context = []
1314
+ for i in range(len(tgt_path)):
1315
+ context.append(question[i])
1316
+ context.append(tgt_path[i])
1317
+ context.append(question[-1])
1318
+ context[0] = context[0][7:]
1319
+ context[-1] = context[-1][:-5]
1320
+ msgs = []
1321
+ for i in range(len(context)):
1322
+ if i % 2 == 0:
1323
+ msgs.append(dict(type='text', value=context[i]))
1324
+ else:
1325
+ ROOT = LMUDataRoot()
1326
+ msgs.append(dict(type='image', value=osp.join(osp.join(ROOT, 'images', self.dataset_name), context[i])))
1327
+ # drop empty text segments; filtering avoids mutating msgs while iterating over it
1328
+ msgs = [element for element in msgs if element['value'] != '']
1330
+ return msgs
VLMEvalKit/vlmeval/dataset/image_yorn.py ADDED
@@ -0,0 +1,95 @@
1
+ from ..smp import *
2
+ from ..utils import *
3
+ from .image_base import ImageBaseDataset
4
+ from .utils import build_judge, DEBUG_MESSAGE
5
+
6
+
7
+ class ImageYORNDataset(ImageBaseDataset):
8
+
9
+ TYPE = 'Y/N'
10
+
11
+ DATASET_URL = {
12
+ 'MME': 'https://opencompass.openxlab.space/utils/VLMEval/MME.tsv',
13
+ 'HallusionBench': 'https://opencompass.openxlab.space/utils/VLMEval/HallusionBench.tsv',
14
+ 'POPE': 'https://opencompass.openxlab.space/utils/VLMEval/POPE.tsv',
15
+ 'AMBER': 'https://huggingface.co/datasets/yifanzhang114/AMBER_base64/resolve/main/AMBER.tsv',
16
+ }
17
+
18
+ DATASET_MD5 = {
19
+ 'MME': 'b36b43c3f09801f5d368627fb92187c3',
20
+ 'HallusionBench': '0c23ac0dc9ef46832d7a24504f2a0c7c',
21
+ 'POPE': 'c12f5acb142f2ef1f85a26ba2fbe41d5',
22
+ 'AMBER': '970d94c0410916166e0a76ba75da7934',
23
+ }
24
+
25
+ # It returns a dataframe
26
+ def evaluate(self, eval_file, **judge_kwargs):
27
+ from .utils.yorn import YOrN_Extraction, YOrN_auxeval
28
+ from .utils.yorn import default_rating, MME_rating, Hallusion_rating, POPE_rating, AMBER_rating
29
+
30
+ dataset = self.dataset_name
31
+ data = load(eval_file)
32
+ data['prediction'] = [str(x) for x in data['prediction']]
33
+ storage = eval_file.replace('.xlsx', '_auxmatch.xlsx')
34
+ tmp_file = eval_file.replace('.xlsx', '_tmp.pkl')
35
+ nproc = judge_kwargs.pop('nproc', 4)
36
+
37
+ if not osp.exists(storage):
38
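+ # rule-based Yes/No extraction first; answers left as 'Unknown' may be retried with the LLM judge below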
+ ans_map = {k: YOrN_Extraction(v) for k, v in zip(data['index'], data['prediction'])}
39
+ if osp.exists(tmp_file):
40
+ tmp = load(tmp_file)
41
+ for k in tmp:
42
+ if ans_map[k] == 'Unknown' and tmp[k] != 'Unknown':
43
+ ans_map[k] = tmp[k]
44
+
45
+ data['extracted'] = [ans_map[x] for x in data['index']]
46
+ unknown = data[data['extracted'] == 'Unknown']
47
+
48
+ model = judge_kwargs.get('model', 'exact_matching')
49
+ if model == 'exact_matching':
50
+ model = None
51
+ elif gpt_key_set():
52
+ model = build_judge(**judge_kwargs)
53
+ if not model.working():
54
+ warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
55
+ warnings.warn(DEBUG_MESSAGE)
56
+ model = None
57
+ else:
58
+ model = None
59
+ warnings.warn('OPENAI_API_KEY is not working properly, will use exact matching for evaluation')
60
+
61
+ if model is not None:
62
+ lt = len(unknown)
63
+ lines = [unknown.iloc[i] for i in range(lt)]
64
+ tups = [(model, line) for line in lines]
65
+ indices = list(unknown['index'])
66
+ if len(tups):
67
+ res = track_progress_rich(
68
+ YOrN_auxeval, tups, nproc=nproc, chunksize=nproc, keys=indices, save=tmp_file)
69
+ for k, v in zip(indices, res):
70
+ ans_map[k] = v
71
+
72
+ data['extracted'] = [ans_map[x] for x in data['index']]
73
+ dump(data, storage)
74
+
75
+ data = load(storage)
76
+ if listinstr(['AMBER'], dataset):
77
+ data['score'] = (data['answer'].str.lower() == data['extracted'].str.lower())
78
+ else:
79
+ data['score'] = (data['answer'] == data['extracted'])
80
+ dump(data, storage)
81
+
82
+ if dataset is not None and listinstr(['MME'], dataset):
83
+ score = MME_rating(storage)
84
+ elif dataset is not None and listinstr(['Hallusion'], dataset):
85
+ score = Hallusion_rating(storage)
86
+ elif dataset is not None and listinstr(['POPE'], dataset):
87
+ score = POPE_rating(storage)
88
+ elif dataset is not None and listinstr(['AMBER'], dataset):
89
+ score = AMBER_rating(storage)
90
+ else:
91
+ score = default_rating(storage)
92
+
93
+ score_tgt = eval_file.replace('.xlsx', '_score.csv')
94
+ dump(score, score_tgt)
95
+ return score
VLMEvalKit/vlmeval/dataset/longvideobench.py ADDED
@@ -0,0 +1,328 @@
1
+ from huggingface_hub import snapshot_download
2
+ from ..smp import *
3
+ from .video_base import VideoBaseDataset
4
+ from .utils import build_judge, DEBUG_MESSAGE
5
+ from glob import glob
6
+
7
+ FAIL_MSG = 'Failed to obtain answer via API.'
8
+
9
+
10
+ def timestamp_to_seconds(timestamp):
11
+ # Split the timestamp into hours, minutes, and seconds
12
+ h, m, s = timestamp.split(":")
13
+ # Convert hours, minutes, and total seconds (including fractions) to float and compute total seconds
14
+ total_seconds = int(h) * 3600 + int(m) * 60 + float(s)
15
+ return total_seconds
16
+
17
+
18
+ def uniformly_subsample(lst, K):
19
+ n = len(lst)
20
+ if K >= n:
21
+ return lst
22
+ step = n / K
23
+ return [lst[int(i * step)] for i in range(K)]
24
+
25
+
26
+ def insert_subtitles_into_frames(
27
+ frames,
28
+ frame_timestamps,
29
+ subtitles,
30
+ starting_timestamp_for_subtitles,
31
+ duration,
32
+ ):
33
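+ # interleave frames and subtitle lines in timestamp order; a subtitle is kept only if at least one frame falls inside its time window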
+ interleaved_list = []
34
+ cur_i = 0
35
+
36
+ for subtitle in subtitles:
37
+ if "timestamp" in subtitle:
38
+ start, end = subtitle["timestamp"]
39
+
40
+ if not isinstance(end, float):
41
+ end = duration
42
+
43
+ start -= starting_timestamp_for_subtitles
44
+ end -= starting_timestamp_for_subtitles
45
+
46
+ subtitle_timestamp = (start + end) / 2
47
+ subtitle_text = subtitle["text"]
48
+ else:
49
+ start, end = subtitle["start"], subtitle["end"]
50
+ start = timestamp_to_seconds(start)
51
+ end = timestamp_to_seconds(end)
52
+ start -= starting_timestamp_for_subtitles
53
+ end -= starting_timestamp_for_subtitles
54
+
55
+ subtitle_timestamp = (start + end) / 2
56
+ subtitle_text = subtitle["line"]
57
+
58
+ for i, (frame, frame_timestamp) in enumerate(
59
+ zip(frames[cur_i:], frame_timestamps[cur_i:])
60
+ ):
61
+ if frame_timestamp <= subtitle_timestamp:
62
+ # print("frame:", frame_timestamp)
63
+ interleaved_list.append({"type": "image", "value": frame})
64
+ cur_i += 1
65
+ else:
66
+ break
67
+
68
+ if end - start < 1:
69
+ end = subtitle_timestamp + 0.5
70
+ start = subtitle_timestamp - 0.5
71
+
72
+ covering_frames = False
73
+ for frame, frame_timestamp in zip(frames, frame_timestamps):
74
+ if frame_timestamp < end and frame_timestamp > start:
75
+ covering_frames = True
76
+ break
77
+
78
+ if covering_frames:
79
+ interleaved_list.append({"type": "text", "value": subtitle_text + "\n"})
80
+ else:
81
+ pass
82
+
83
+ for i, (frame, frame_timestamp) in enumerate(
84
+ zip(frames[cur_i:], frame_timestamps[cur_i:])
85
+ ):
86
+ interleaved_list.append({"type": "image", "value": frame})
87
+ return interleaved_list
88
+
89
+
90
+ class LongVideoBench(VideoBaseDataset):
91
+
92
+ MD5 = '82905eae3a5ae7383c5a8ee9655e1ab9'
93
+ SYS = ''
94
+
95
+ TYPE = 'Video-MCQ'
96
+
97
+ def __init__(self, dataset='LongVideoBench', use_subtitle=False):
98
+ super().__init__(dataset=dataset)
99
+ self.use_subtitle = use_subtitle
100
+ self.dataset_name = dataset
101
+
102
+ @classmethod
103
+ def supported_datasets(cls):
104
+ return ['LongVideoBench']
105
+
106
+ def prepare_dataset(self, dataset_name='LongVideoBench', repo_id='longvideobench/LongVideoBench'):
107
+
108
+ def check_integrity(pth):
109
+ data_file = osp.join(pth, f'{dataset_name}.tsv')
110
+
111
+ if not osp.exists(data_file):
112
+ return False
113
+
114
+ if md5(data_file) != self.MD5:
115
+ print("md5 mismatch", md5(data_file), self.MD5)
116
+ return False
117
+ data = load(data_file)
118
+ for video_pth in data['video_path']:
119
+ if not osp.exists(osp.join(pth, video_pth)):
120
+ print(video_pth, "is not found")
121
+ return False
122
+ return True
123
+
124
+ if modelscope_flag_set():
125
+ repo_id = "AI-ModelScope/LongVideoBench"
126
+
127
+ cache_path = get_cache_path(repo_id)
128
+ if cache_path is not None and check_integrity(cache_path):
129
+ dataset_path = cache_path
130
+ else:
131
+ def generate_tsv(pth):
132
+ data_file = osp.join(pth, f'{dataset_name}.tsv')
133
+ if osp.exists(data_file) and md5(data_file) == self.MD5:
134
+ return
135
+
136
+ data_file = pd.read_json(osp.join(pth, 'lvb_val.json'))
137
+ data_file = data_file.assign(index=range(len(data_file)))
138
+ data_file['video'] = data_file['video_id']
139
+ data_file['video_path'] = data_file['video_path'].apply(lambda x: f'./videos/{x}')
140
+
141
+ data_file.to_csv(osp.join(pth, f'{dataset_name}.tsv'), sep='\t', index=False)
142
+
143
+ if modelscope_flag_set():
144
+ from modelscope import dataset_snapshot_download
145
+ dataset_snapshot_download(dataset_id=repo_id)
146
+ else:
147
+ snapshot_download(repo_id=repo_id, repo_type='dataset')
148
+ print("All videos are downloaded for LongVideoBench")
149
+
150
+ if not glob(osp.join(cache_path, "videos")):
151
+ tar_files = glob(osp.join(cache_path, "**/*.tar*"), recursive=True)
152
+
153
+ def untar_video_data(tar_file, cache_dir):
154
+ import tarfile
155
+ with tarfile.open(tar_file, "r") as tar_ref:
156
+ tar_ref.extractall(cache_dir)
157
+ print(f"Extracted all files from {tar_file} to {cache_dir}")
158
+
159
+ def concat_tar_parts(tar_parts, output_tar):
160
+ with open(output_tar, "wb") as out_tar:
161
+ from tqdm import tqdm
162
+ for part in tqdm(sorted(tar_parts)):
163
+ with open(part, "rb") as part_file:
164
+ out_tar.write(part_file.read())
165
+ print(f"Concatenated parts {tar_parts} into {output_tar}")
166
+
167
+ tar_parts_dict = {}
168
+
169
+ # Group tar parts together
170
+ for tar_file in tar_files:
171
+ base_name = tar_file.split(".tar")[0]
172
+ if base_name not in tar_parts_dict:
173
+ tar_parts_dict[base_name] = []
174
+ tar_parts_dict[base_name].append(tar_file)
175
+
176
+ # Concatenate and untar split parts
177
+ for base_name, parts in tar_parts_dict.items():
178
+ print(f"Extracting following tar files: {parts}")
179
+ output_tar = base_name + ".tar"
180
+ if not osp.exists(output_tar):
181
+ print('Start concatenating tar files')
182
+
183
+ concat_tar_parts(parts, output_tar)
184
+ print('Finish concatenating tar files')
185
+
186
+ if not osp.exists(osp.join(cache_path, osp.basename(base_name))):
187
+ untar_video_data(output_tar, cache_path)
188
+
189
+ print('All videos are extracted for LongVideoBench')
190
+
191
+ dataset_path = cache_path
192
+ generate_tsv(dataset_path)
193
+
194
+ data_file = osp.join(dataset_path, f'{dataset_name}.tsv')
195
+
196
+ return dict(data_file=data_file, root=dataset_path)
197
+
198
+ def save_video_frames(self, video_path, num_frames=8, fps=-1, video_llm=False):
199
+
200
+ vid_path = osp.join(self.data_root, video_path)
201
+ vid = decord.VideoReader(vid_path)
202
+ video_info = {
203
+ 'fps': vid.get_avg_fps(),
204
+ 'n_frames': len(vid),
205
+ }
206
+ if num_frames > 0 and fps < 0:
207
+ step_size = len(vid) / (num_frames + 1)
208
+ indices = [int(i * step_size) for i in range(1, num_frames + 1)]
209
+ frame_paths = self.frame_paths(video_path[:-4], num_frames)
210
+ elif fps > 0:
211
+ # not constrained by num_frames, get frames by fps
212
+ total_duration = video_info['n_frames'] / video_info['fps']
213
+ required_frames = int(total_duration * fps)
214
+ step_size = video_info['fps'] / fps
215
+ indices = [int(i * step_size) for i in range(required_frames)]
216
+ frame_paths = self.frame_paths_fps(video_path[:-4], len(indices), fps)
217
+
218
+ flag = np.all([osp.exists(p) for p in frame_paths])
219
+
220
+ if not flag:
221
+ images = [vid[i].asnumpy() for i in indices]
222
+ images = [Image.fromarray(arr) for arr in images]
223
+ for im, pth in zip(images, frame_paths):
224
+ if not osp.exists(pth) and not video_llm:
225
+ im.save(pth)
226
+
227
+ return frame_paths, indices, video_info
228
+
229
+ def save_video_into_images(self, line, num_frames=8):
230
+ frame_paths, indices, video_info = self.save_video_frames(line['video_path'], num_frames)
231
+ return frame_paths
232
+
233
+ def build_prompt(self, line, num_frames, video_llm, fps):
234
+ if isinstance(line, int):
235
+ assert line < len(self)
236
+ line = self.data.iloc[line]
237
+
238
+ frames, indices, video_info = self.save_video_frames(line['video_path'], num_frames, fps, video_llm)
239
+ fps = video_info["fps"]
240
+
241
+ message = [dict(type='text', value=self.SYS)]
242
+ if video_llm:
243
+ message.append(dict(type='video', value=osp.join(self.data_root, line['video_path'])))
244
+ else:
245
+ if not self.use_subtitle:
246
+ with open(osp.join(self.data_root, "subtitles", line["subtitle_path"])) as f:
247
+ subtitles = json.load(f)
248
+
249
+ frame_message = insert_subtitles_into_frames(
250
+ frames,
251
+ [ind_ / fps for ind_ in indices],
252
+ subtitles,
253
+ line["starting_timestamp_for_subtitles"],
254
+ line["duration"]
255
+ )
256
+
257
+ message += frame_message
258
+ else:
259
+ for im in frames:
260
+ message.append(dict(type='image', value=im))
261
+
262
+ line['question'] += '\n' + '\n'.join(
263
+ ["{}. {}".format(chr(ord("A") + i), cand) for i, cand in enumerate(eval(line['candidates']))]
264
+ )
265
+ prompt = line["question"] + "\nAnswer with the option's letter from the given choices directly."
266
+ message.append(dict(type='text', value=prompt))
267
+ return message
268
+
269
+ # It returns a dictionary
270
+ @classmethod
271
+ def evaluate(self, eval_file, **judge_kwargs):
272
+ from .utils.longvideobench import get_dimension_rating, extract_characters_regex, extract_option
273
+
274
+ assert eval_file.endswith('.xlsx'), 'data file should be an xlsx file'
275
+
276
+ tmp_file = eval_file.replace('.xlsx', '_tmp.pkl')
277
+ tgt_file = eval_file.replace('.xlsx', '_rating.json')
278
+ score_file = eval_file.replace('.xlsx', '_score.xlsx')
279
+
280
+ if not osp.exists(score_file):
281
+ model = judge_kwargs.get('model', 'exact_matching')
282
+ assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']
283
+
284
+ if model == 'exact_matching':
285
+ model = None
286
+ elif gpt_key_set():
287
+ model = build_judge(**judge_kwargs)
288
+ if not model.working():
289
+ warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
290
+ warnings.warn(DEBUG_MESSAGE)
291
+ model = None
292
+ else:
293
+ warnings.warn('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
294
+ model = None
295
+ res = {} if not osp.exists(tmp_file) else load(tmp_file)
296
+ res = {k: v for k, v in res.items() if FAIL_MSG not in v}
297
+
298
+ data = load(eval_file)
299
+ data_un = data[~pd.isna(data['prediction'])]
300
+
301
+ for idx in data['index']:
302
+ ans = data.loc[data['index'] == idx, 'correct_choice'].values[0]
303
+ ans = chr(ord("A") + ans)
304
+ pred = str(data.loc[data['index'] == idx, 'prediction'].values[0])
305
+
306
+ if extract_characters_regex(pred) == '':
307
+ extract_pred = extract_option(
308
+ model,
309
+ data.loc[data['index'] == idx].to_dict(orient='records')[0],
310
+ 'LongVideoBench'
311
+ )
312
+ data.loc[idx, 'score'] = int(extract_pred == ans)
313
+ else:
314
+ data.loc[idx, 'score'] = int(extract_characters_regex(pred) == ans)
315
+
316
+ rejected = [x for x in data['score'] if x == -1]
317
+
318
+ print(
319
+ f'Among {len(data)} questions, failed to obtain prediction for {len(data) - len(data_un)} questions, '
320
+ f'failed to obtain the score for another {len(rejected)} questions. '
321
+ f'Those questions will be counted as -1 score in ALL rating, and will not be counted in VALID rating.'
322
+ )
323
+
324
+ dump(data, score_file)
325
+
326
+ rating = get_dimension_rating(score_file)
327
+ dump(rating, tgt_file)
328
+ return rating
VLMEvalKit/vlmeval/dataset/miabench.py ADDED
@@ -0,0 +1,167 @@
1
+ import json
2
+ import os
3
+
4
+ import pandas as pd
5
+
6
+ from .image_base import ImageBaseDataset
7
+ from ..smp import *
8
+ from .utils import build_judge, DEBUG_MESSAGE
9
+ from ..utils import track_progress_rich
10
+
11
+
12
+ def generate_prompt(d):
13
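+ # build the judge prompt: list every instruction component with its weight, then ask for per-component and total scores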
+ question = d['question']
14
+ weights = eval(d['component_weight'])
15
+ components = eval(d['components'])
16
+ num_of_component = int(d['num_of_component'])
17
+ response = d['prediction']
18
+
19
+ if num_of_component == 1:
20
+ components = f"The first component is: '{components[0]}'. "
21
+ score = f"The first component is worth: {weights[0]} scores. "
22
+ elif num_of_component == 2:
23
+ components = f"The first component is: '{components[0]}', and the second component is '{components[1]}'. "
24
+ score = f"The first and second component is each worth {weights[0]} and {weights[1]} scores. "
25
+ elif num_of_component == 3:
26
+ components = (
27
+ f"The first component is: '{components[0]}', and the second component is '{components[1]}', "
28
+ f"and the third component is '{components[2]}'. "
29
+ )
30
+ score = (
31
+ "The first, second, and third component is each worth "
32
+ f"{weights[0]}, {weights[1]}, and {weights[2]} scores."
33
+ )
34
+ elif num_of_component == 4:
35
+ components = (
36
+ f"The first component is: '{components[0]}', and the second component is '{components[1]}', "
37
+ f"and the third component is '{components[2]}', and the fourth component is '{components[3]}'. "
38
+ )
39
+ score = (
40
+ "The first, second, third, and fourth component is each worth "
41
+ f"{weights[0]}, {weights[1]}, {weights[2]}, and {weights[3]} scores."
42
+ )
43
+ elif num_of_component == 5:
44
+ components = (
45
+ f"The first component is: '{components[0]}', and the second component is '{components[1]}', "
46
+ f"and the third component is '{components[2]}', and the fourth component is '{components[3]}', "
47
+ f"and the fifth component is '{components[4]}'. "
48
+ )
49
+ score = (
50
+ "The first, second, third, fourth, and fifth component is each worth "
51
+ f"{weights[0]}, {weights[1]}, {weights[2]}, {weights[3]}, and {weights[4]} scores."
52
+ )
53
+
54
+ return (
55
+ "Here is an instruction for a multimodal LLM: '"
56
+ f"{question}"
57
+ "'. You need to grade if the response from the model follows each component of the instruction. "
58
+ f"{components}"
59
+ "The response is: '"
60
+ f"{response}"
61
+ "'. You need to score the response and be strict. The total score ranges from 0 to 10, "
62
+ "depending on if the response follows the instruction. "
63
+ f"{score}"
64
+ "List scores of each component, and the total score in one sentence in this format: "
65
+ "score of component 1: x/2, score of component 2: y/8, total score: z/10. Then explain your reasons."
66
+ )
67
+
68
+
69
+ def process_rawscore(component_type, raw_score):
70
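+ # parse the judge's first sentence, e.g. 'score of component 1: 2/2, score of component 2: 6/8, total score: 8/10'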
+ first_sentence = raw_score.split('.')[0].split(',')
71
+ score_dict = {}
72
+ for i in range(len(first_sentence) - 1):
73
+ score_ = first_sentence[i].split(':')[1][1:].split('/')
74
+ score = int(score_[0]) / int(score_[1])
75
+ score_dict[component_type[i]] = score
76
+ total_score_ = first_sentence[i + 1].split(':')[1][1:].split('/')
77
+ total_score = int(total_score_[0]) / int(total_score_[1])
78
+ score_dict['total_score'] = total_score
79
+ return score_dict
80
+
81
+
82
+ def get_score_dict(data, score_raw):
83
+ cat_score_dict = {}
84
+ for i in range(len(data)):
85
+ try:
86
+ cmp = data['component_type'][i][2:-2]
87
+ cmp_list = cmp.split('\', \'')
88
+ score_dict = process_rawscore(cmp_list, score_raw[i])
89
+ for key, val in score_dict.items():
90
+ if key not in cat_score_dict.keys():
91
+ cat_score_dict[key] = [val]
92
+ else:
93
+ cat_score_dict[key].append(val)
94
+ except:
95
+ pass
96
+ cat_score_dict_average = {}
97
+ for key, val in cat_score_dict.items():
98
+ cat_score_dict_average[key] = sum(val) / len(val)
99
+ return cat_score_dict_average
100
+
101
+
102
+ class MIABench(ImageBaseDataset):
103
+ TYPE = 'VQA'
104
+
105
+ DATASET_URL = {
106
+ 'MIA-Bench': 'https://opencompass.openxlab.space/utils/VLMEval/Mia-Bench.tsv',
107
+ }
108
+ DATASET_MD5 = {
109
+ 'MIA-Bench': '0b9de595f4dd40af18a69b94d89aba82',
110
+ }
111
+
112
+ @classmethod
113
+ def evaluate(self, eval_file, **judge_kwargs):
114
+ judge_name = judge_kwargs.pop('model', 'gpt-4o')
115
+
116
+ model = build_judge(model=judge_name, **judge_kwargs)
117
+ suffix = eval_file.split('.')[-1]
118
+
119
+ storage = eval_file.replace(f'.{suffix}', f'_{judge_name}.xlsx') # noqa: F841
120
+ tmp_file = eval_file.replace(f'.{suffix}', f'_{judge_name}.pkl') # noqa: F841
121
+ nproc = judge_kwargs.pop('nproc', 4) # noqa: F841
122
+
123
+ if not osp.exists(storage):
124
+ data = load(eval_file)
125
+ num_samples = len(data)
126
+ lines = [data.loc[i] for i in range(num_samples)]
127
+ prompts = [generate_prompt(line) for line in lines]
128
+ org_data = MIABench('MIA-Bench').data
129
+ img_map = {x: y for x, y in zip(org_data['index'], org_data['image'])}
130
+ image_b64 = [img_map[idx] for idx in data['index']]
131
+ indices = list(data['index'])
132
+ mm_messages = [
133
+ dict(message=[
134
+ dict(type='text', value=prompt),
135
+ dict(type='image', value=f'data:image/jpeg;base64,{b64}')
136
+ ])
137
+ for prompt, b64 in zip(prompts, image_b64)
138
+ ]
139
+
140
+ res = {}
141
+ if osp.exists(tmp_file):
142
+ res = load(tmp_file)
143
+
144
+ jobs = {k: v for k, v in zip(indices, mm_messages) if k not in res}
145
+ job_keys = list(jobs.keys())
146
+ job_vals = [jobs[k] for k in job_keys]
147
+
148
+ resps = track_progress_rich(
149
+ model.generate,
150
+ job_vals,
151
+ nproc=nproc,
152
+ chunksize=nproc,
153
+ keys=job_keys,
154
+ save=tmp_file,
155
+ )
156
+ for k, resp in zip(job_keys, resps):
157
+ res[k] = resp
158
+ data['score_raw'] = [res[idx] for idx in indices]
159
+ dump(data, storage)
160
+
161
+ goresult = load(storage)
162
+ results = get_score_dict(goresult, goresult['score_raw'])
163
+ result_pth = storage.replace('.xlsx', '_score.csv')
164
+ results_pd = pd.DataFrame.from_dict(list(results.items()))
165
+ dump(results_pd, result_pth)
166
+
167
+ return results
VLMEvalKit/vlmeval/dataset/mlvu.py ADDED
@@ -0,0 +1,455 @@
1
+ import huggingface_hub
2
+ from huggingface_hub import snapshot_download
3
+ from ..smp import *
4
+ from .video_concat_dataset import ConcatVideoDataset
5
+ from .video_base import VideoBaseDataset
6
+ from .utils import build_judge, DEBUG_MESSAGE
7
+ from ..utils import track_progress_rich
8
+ import torchvision.transforms as T
9
+ from torchvision import transforms
10
+ from torchvision.transforms.functional import InterpolationMode
11
+ import decord
+ from decord import VideoReader, cpu
12
+ import pandas as pd
13
+ import imageio
14
+ import cv2
15
+ import zipfile
16
+ import os
17
+ import glob
18
+ from .utils.mlvu import *
19
+
20
+ FAIL_MSG = 'Failed to obtain answer via API.'
21
+
22
+
23
+ class MLVU(ConcatVideoDataset):
24
+ def __init__(self, dataset='MLVU'):
25
+ self.DATASET_SETS[dataset] = ['MLVU_MCQ', 'MLVU_OpenEnded']
26
+ self.type_data_dict = {
27
+ 'M-Avg':['plotQA', 'needle', 'ego', 'count', 'anomaly_reco', 'topic_reasoning'],
28
+ 'G-Avg':['sub_scene', 'summary']
29
+ }
30
+ super().__init__(dataset=dataset)
31
+
32
+ @classmethod
33
+ def supported_datasets(cls):
34
+ return ['MLVU']
35
+
36
+ def evaluate(self, eval_file, **judge_kwargs):
37
+ result = super().evaluate(eval_file=eval_file, **judge_kwargs)
38
+ suffix = eval_file.split('.')[-1]
39
+ score_file = eval_file.replace(f'.{suffix}', '_acc.csv')
40
+ for key in self.type_data_dict:
41
+ result.loc[key] = 0.0
42
+ for name, item in result.iterrows():
43
+ if name in self.type_data_dict[key]:
44
+ result.loc[key, 'success'] += item['success']
45
+ result.loc[key, 'overall'] += item['overall']
46
+ if key == 'G-Avg':
47
+ result.loc[key, 'acc'] = round(
48
+ result.loc[key, 'success'] / result.loc[key, 'overall'], 2
49
+ )
50
+ else:
51
+ result.loc[key, 'acc'] = round(
52
+ result.loc[key, 'success'] / result.loc[key, 'overall'] * 100, 1
53
+ )
54
+ result = result.reset_index().rename(columns={'index': 'task'})
55
+ dump(result, score_file)
56
+ return result
57
+
58
+
59
+ class MLVU_MCQ(VideoBaseDataset):
60
+
61
+ MD5 = 'bb5c37e7cf8d43fc9a25c23d2b4633f5'
62
+ BASE_SYS = 'Carefully watch this video and pay attention to every detail. '
63
+ SYS = BASE_SYS + 'Based on your observations, select the best option that accurately addresses the question.'
64
+ TYPE = 'Video-MCQ'
65
+
66
+ def __init__(self, dataset='MLVU_MCQ'):
67
+ self.type_data_list = {
68
+ 'plotQA': ('1_plotQA.json', './MLVU/video/1_plotQA', 'MCQ'),
69
+ 'needle': ('2_needle.json', './MLVU/video/2_needle', 'MCQ'),
70
+ 'ego': ('3_ego.json', './MLVU/video/3_ego', 'MCQ'),
71
+ 'count': ('4_count.json', './MLVU/video/4_count', 'MCQ'),
72
+ 'order': ('5_order.json', './MLVU/video/5_order', 'MCQ'),
73
+ 'anomaly_reco': ('6_anomaly_reco.json', './MLVU/video/6_anomaly_reco', 'MCQ'),
74
+ 'topic_reasoning': ('7_topic_reasoning.json', './MLVU/video/7_topic_reasoning', 'MCQ'),
75
+ }
76
+ super().__init__(dataset=dataset)
77
+
78
+ @classmethod
79
+ def supported_datasets(cls):
80
+ return ['MLVU_MCQ']
81
+
82
+ def prepare_dataset(self, dataset_name='MLVU_MCQ', repo_id='MLVU/MVLU'):
83
+ def check_integrity(pth):
84
+ data_file = osp.join(pth, f'{dataset_name}.tsv')
85
+
86
+ if not os.path.exists(data_file):
87
+ return False
88
+
89
+ if md5(data_file) != self.MD5:
90
+ return False
91
+
92
+ data = load(data_file)
93
+ for idx, item in data.iterrows():
94
+ if not osp.exists(osp.join(pth, item['prefix'], item['video'])):
95
+ return False
96
+ return True
97
+
98
+ if modelscope_flag_set():
99
+ repo_id = "AI-ModelScope/MLVU"
100
+
101
+ cache_path = get_cache_path(repo_id)
102
+ if cache_path is not None and check_integrity(cache_path):
103
+ dataset_path = cache_path
104
+ else:
105
+ def generate_tsv(pth):
106
+ data_file = osp.join(pth, f'{dataset_name}.tsv')
107
+ if os.path.exists(data_file) and md5(data_file) == self.MD5:
108
+ return
109
+ json_data_dir = os.path.join(dataset_path, 'MLVU', 'json')
110
+ self.data_list = []
111
+ for k, v in self.type_data_list.items():
112
+ with open(os.path.join(json_data_dir, v[0]), 'r') as f:
113
+ json_data = json.load(f)
114
+ for data in json_data:
115
+ self.data_list.append({
116
+ 'task_type': k,
117
+ 'prefix': v[1],
118
+ 'duration': data['duration'],
119
+ 'video': data['video'],
120
+ 'question': data['question'],
121
+ 'answer': data['answer'],
122
+ 'candidates': data['candidates'],
123
+ })
124
+
125
+ data_df = pd.DataFrame(self.data_list)
126
+ data_df = data_df.assign(index=range(len(data_df)))
127
+ data_df.to_csv(data_file, sep='\t', index=False)
128
+
129
+ if modelscope_flag_set():
130
+ from modelscope import dataset_snapshot_download
131
+ dataset_path = dataset_snapshot_download(dataset_id=repo_id)
132
+ else:
133
+ hf_token = os.environ.get('HUGGINGFACE_TOKEN')
134
+ huggingface_hub.login(hf_token)
135
+ dataset_path = snapshot_download(repo_id=repo_id, repo_type='dataset')
136
+
137
+ generate_tsv(dataset_path)
138
+
139
+ data_file = osp.join(dataset_path, f'{dataset_name}.tsv')
140
+ return dict(root=dataset_path, data_file=data_file)
141
+
142
+ def qa_template(self, data):
143
+ question = f"Question: {data['question']}\n"
144
+ question += 'Options:\n'
145
+ answer = data['answer']
146
+ answer_idx = -1
147
+ for idx, c in enumerate(eval(data['candidates'])):
148
+ question += f"({chr(ord('A') + idx)}) {c}\n"
149
+ if c == answer:
150
+ answer_idx = idx
151
+ question = question.rstrip()
152
+ answer = f"({chr(ord('A') + answer_idx)}) {answer}"
153
+ return question, answer
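+ # Rough illustration (hypothetical row): for candidates "['A cat', 'A dog']" with
+ # answer 'A dog', qa_template should return something like
+ #   question = "Question: <question text>\nOptions:\n(A) A cat\n(B) A dog"
+ #   answer   = "(B) A dog"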
154
+
155
+ def save_video_frames(self, line, num_frames=8, fps=-1):
156
+ suffix = line['video'].split('.')[-1]
157
+ video = line['video'].replace(f'.{suffix}','')
158
+ vid_path = osp.join(self.data_root, line['prefix'], line['video'])
159
+ vid = decord.VideoReader(vid_path)
160
+ video_info = {
161
+ 'fps': vid.get_avg_fps(),
162
+ 'n_frames': len(vid),
163
+ }
164
+ if num_frames > 0 and fps < 0:
165
+ step_size = len(vid) / (num_frames + 1)
166
+ indices = [int(i * step_size) for i in range(1, num_frames + 1)]
167
+ frame_paths = self.frame_paths(video, num_frames)
168
+ elif fps > 0:
169
+ # not constrained by num_frames, get frames by fps
170
+ total_duration = video_info['n_frames'] / video_info['fps']
171
+ required_frames = int(total_duration * fps)
172
+ step_size = video_info['fps'] / fps
173
+ indices = [int(i * step_size) for i in range(required_frames)]
174
+ frame_paths = self.frame_paths_fps(video, len(indices), fps)
175
+
176
+ flag = np.all([osp.exists(p) for p in frame_paths])
177
+
178
+ if not flag:
179
+ images = [vid[i].asnumpy() for i in indices]
180
+ images = [Image.fromarray(arr) for arr in images]
181
+ for im, pth in zip(images, frame_paths):
182
+ if not osp.exists(pth):
183
+ im.save(pth)
184
+
185
+ return frame_paths
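+ # Sampling sketch (assumed numbers): for a 160-frame video with num_frames=8 and
+ # fps=-1, step_size = 160 / 9 is about 17.8, so the sampled indices are
+ # [17, 35, 53, 71, 88, 106, 124, 142], i.e. 8 roughly uniform frames that skip
+ # both endpoints of the video.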
186
+
187
+ def save_video_into_images(self, line, num_frames, fps):
188
+ frame_paths = self.save_video_frames(line, num_frames, fps)
189
+ return frame_paths
190
+
191
+ def build_prompt(self, line, num_frames, video_llm, fps=-1):
192
+ if isinstance(line, int):
193
+ assert line < len(self)
194
+ line = self.data.iloc[line]
195
+
196
+ question, answer = self.qa_template(line)
197
+ message = [dict(type='text', value=self.SYS, role='system')]
198
+ message.append(dict(type='text', value=question))
199
+ video_path = os.path.join(self.data_root, line['prefix'], line['video'])
200
+ if video_llm:
201
+ message.append(dict(type='video', value=video_path))
202
+ else:
203
+ img_frame_paths = self.save_video_into_images(line, num_frames, fps)
204
+ for im in img_frame_paths:
205
+ message.append(dict(type='image', value=im))
206
+ message.append(dict(type='text', value='\nOnly give the best option.'))
207
+ return message
208
+
209
+ @classmethod
210
+ def evaluate(self, eval_file, **judge_kwargs):
211
+ assert eval_file.endswith('.xlsx'), 'data file should be an xlsx file'
212
+
213
+ tmp_file = eval_file.replace('.xlsx', '_tmp.pkl')
214
+ score_file = eval_file.replace('.xlsx', '_score.xlsx')
215
+
216
+ if not osp.exists(score_file):
217
+ model = judge_kwargs.setdefault('model', 'chatgpt-0125')
218
+ assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']
219
+
220
+ if model == 'exact_matching':
221
+ model = None
222
+ elif gpt_key_set():
223
+ model = build_judge(**judge_kwargs)
224
+ if not model.working():
225
+ warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
226
+ warnings.warn(DEBUG_MESSAGE)
227
+ model = None
228
+ else:
229
+ warnings.warn('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
230
+ model = None
231
+ res = {} if not osp.exists(tmp_file) else load(tmp_file)
232
+ res = {k: v for k, v in res.items() if FAIL_MSG not in v}
233
+
234
+ data = load(eval_file)
235
+ data_un = data[~pd.isna(data['prediction'])]
236
+
237
+ for idx in data['index']:
238
+ ans = data.loc[data['index'] == idx, 'answer'].values[0]
239
+ pred = data.loc[data['index'] == idx, 'prediction'].values[0]
240
+ options = eval(data.loc[data['index'] == idx, 'candidates'].values[0])
241
+ answer_idx = -1
242
+ for id, c in enumerate(options):
243
+ if c == ans:
244
+ answer_idx = id
245
+ ans = f"({chr(ord('A') + answer_idx)}) {ans}"
246
+ input_item = data.loc[data['index'] == idx].to_dict(orient='records')[0]
247
+ for id, option_content in enumerate(eval(input_item['candidates'])):
248
+ input_item[chr(ord('A') + id)] = option_content
249
+ if option_content == input_item['answer']:
250
+ input_item['answer'] = chr(ord('A') + id)
251
+
252
+ if FAIL_MSG in pred:
253
+ data.loc[idx, 'score'] = -1
254
+ else:
255
+ data.loc[idx, 'score'] = int(check_ans_with_model(
256
+ pred, ans, model,
257
+ input_item,
258
+ 'MLVU_MCQ'
259
+ ))
260
+
261
+ rejected = [x for x in data['score'] if x == -1]
262
+
263
+ print(
264
+ f'Among {len(data)} questions, failed to obtain prediction for {len(data) - len(data_un)} questions, '
265
+ f'failed to obtain the score for another {len(rejected)} questions. '
266
+ f'Those questions will be counted as -1 score in ALL rating, and will not be counted in VALID rating.'
267
+ )
268
+
269
+ dump(data, score_file)
270
+
271
+ rating = get_dimension_rating(score_file)
272
+ return rating
273
+
274
+
275
+ class MLVU_OpenEnded(VideoBaseDataset):
276
+
277
+ MD5 = 'cee573a3627c6ac434ded704c60511ba'
278
+ BASE_SYS = 'Carefully watch this video and pay attention to every detail. '
279
+ SYS = BASE_SYS + 'Based on your observations, answer the given questions.'
280
+ TYPE = 'Video-VQA'
281
+
282
+ def __init__(self, dataset='MLVU_OpenEnded'):
283
+ self.type_data_list = {
284
+ 'sub_scene': ('8_sub_scene.json', './MLVU/video/8_sub_scene', 'VQA'),
285
+ 'summary': ('9_summary.json', './MLVU/video/9_summary', 'VQA')
286
+ }
287
+ super().__init__(dataset=dataset)
288
+
289
+ @classmethod
290
+ def supported_datasets(cls):
291
+ return ['MLVU_OpenEnded']
292
+
293
+ def prepare_dataset(self, dataset_name='MLVU_OpenEnded', repo_id='MLVU/MVLU'):
294
+ def check_integrity(pth):
295
+ data_file = osp.join(pth, f'{dataset_name}.tsv')
296
+
297
+ if not os.path.exists(data_file):
298
+ return False
299
+
300
+ if md5(data_file) != self.MD5:
301
+ return False
302
+
303
+ data = load(data_file)
304
+ for idx, item in data.iterrows():
305
+ if not osp.exists(osp.join(pth, item['prefix'], item['video'])):
306
+ return False
307
+ return True
308
+
309
+ if modelscope_flag_set():
310
+ repo_id = "AI-ModelScope/MLVU"
311
+
312
+ cache_path = get_cache_path(repo_id)
313
+ if cache_path is not None and check_integrity(cache_path):
314
+ dataset_path = cache_path
315
+ else:
316
+ def generate_tsv(pth):
317
+ data_file = osp.join(pth, f'{dataset_name}.tsv')
318
+ if os.path.exists(data_file) and md5(data_file) == self.MD5:
319
+ return
320
+ json_data_dir = os.path.join(dataset_path, 'MLVU', 'json')
321
+ self.data_list = []
322
+ for k, v in self.type_data_list.items():
323
+ with open(os.path.join(json_data_dir, v[0]), 'r') as f:
324
+ json_data = json.load(f)
325
+ for data in json_data:
326
+ self.data_list.append({
327
+ 'task_type': k,
328
+ 'prefix': v[1],
329
+ 'duration': data['duration'],
330
+ 'video': data['video'],
331
+ 'question': data['question'],
332
+ 'answer': data['answer'],
333
+ 'scoring_points': data['scoring_points'] if 'scoring_points' in data else ''
334
+ })
335
+
336
+ data_df = pd.DataFrame(self.data_list)
337
+ data_df = data_df.assign(index=range(len(data_df)))
338
+ data_df.to_csv(data_file, sep='\t', index=False)
339
+
340
+ if modelscope_flag_set():
341
+ from modelscope import dataset_snapshot_download
342
+ dataset_path = dataset_snapshot_download(dataset_id=repo_id)
343
+ else:
344
+ hf_token = os.environ.get('HUGGINGFACE_TOKEN')
345
+ huggingface_hub.login(hf_token)
346
+ dataset_path = snapshot_download(repo_id=repo_id, repo_type='dataset')
347
+
348
+ generate_tsv(dataset_path)
349
+
350
+ data_file = osp.join(dataset_path, f'{dataset_name}.tsv')
351
+ return dict(root=dataset_path, data_file=data_file)
352
+
353
+ def qa_template(self, data):
354
+ question = f"{data['question']}"
355
+ answer = data['answer']
356
+ return question, answer
357
+
358
+ def save_video_frames(self, line, num_frames=8, fps=-1):
359
+ suffix = line['video'].split('.')[-1]
360
+ video = line['video'].replace(f'.{suffix}','')
361
+ vid_path = osp.join(self.data_root, line['prefix'], line['video'])
362
+ vid = decord.VideoReader(vid_path)
363
+ video_info = {
364
+ 'fps': vid.get_avg_fps(),
365
+ 'n_frames': len(vid),
366
+ }
367
+ if num_frames > 0 and fps < 0:
368
+ step_size = len(vid) / (num_frames + 1)
369
+ indices = [int(i * step_size) for i in range(1, num_frames + 1)]
370
+ frame_paths = self.frame_paths(video, num_frames)
371
+ elif fps > 0:
372
+ # not constrained by num_frames, get frames by fps
373
+ total_duration = video_info['n_frames'] / video_info['fps']
374
+ required_frames = int(total_duration * fps)
375
+ step_size = video_info['fps'] / fps
376
+ indices = [int(i * step_size) for i in range(required_frames)]
377
+ frame_paths = self.frame_paths_fps(video, len(indices), fps)
378
+
379
+ flag = np.all([osp.exists(p) for p in frame_paths])
380
+
381
+ if not flag:
382
+ images = [vid[i].asnumpy() for i in indices]
383
+ images = [Image.fromarray(arr) for arr in images]
384
+ for im, pth in zip(images, frame_paths):
385
+ if not osp.exists(pth):
386
+ im.save(pth)
387
+
388
+ return frame_paths
389
+
390
+ def save_video_into_images(self, line, num_frames, fps):
391
+ frame_paths = self.save_video_frames(line, num_frames, fps)
392
+ return frame_paths
393
+
394
+ def build_prompt(self, line, num_frames, video_llm, fps=-1):
395
+ if isinstance(line, int):
396
+ assert line < len(self)
397
+ line = self.data.iloc[line]
398
+
399
+ question, answer = self.qa_template(line)
400
+ message = [dict(type='text', value=self.SYS, role='system')]
401
+ message.append(dict(type='text', value=question))
402
+ video_path = os.path.join(self.data_root, line['prefix'], line['video'])
403
+ if video_llm:
404
+ message.append(dict(type='video', value=video_path))
405
+ else:
406
+ img_frame_paths = self.save_video_into_images(line, num_frames, fps)
407
+ for im in img_frame_paths:
408
+ message.append(dict(type='image', value=im))
409
+ return message
410
+
411
+ @classmethod
412
+ def evaluate(self, eval_file, **judge_kwargs):
413
+
414
+ model = judge_kwargs['model'] if 'model' in judge_kwargs else judge_kwargs.setdefault('model', 'gpt-4-0125')
415
+ if model != 'gpt-4-0125':
416
+ print('MLVU OpenEnded evaluation defaults to gpt-4-0125, so the judge model is changed to gpt-4-0125.')
417
+ judge_kwargs['model'] = 'gpt-4-0125'
418
+
419
+ suffix = eval_file.split('.')[-1]
420
+ score_file = eval_file.replace(f'.{suffix}', f'_{model}_score.xlsx')
421
+ tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
422
+ nproc = judge_kwargs.pop('nproc', 4)
423
+
424
+ if not osp.exists(score_file):
425
+ data = load(eval_file)
426
+ model_dict = {
427
+ 'sub_scene': build_judge(system_prompt=system_prompt_sub_scene, **judge_kwargs),
428
+ 'summary': build_judge(system_prompt=system_prompt_summary, **judge_kwargs)
429
+ }
430
+ lt = len(data)
431
+ lines = [data.iloc[i] for i in range(lt)]
432
+ tups = [(model_dict[line['task_type']], line) for line in lines]
433
+ indices = [line['index'] for line in lines]
434
+
435
+ ans = {}
436
+ if osp.exists(tmp_file):
437
+ ans = load(tmp_file)
438
+ tups = [x for x, i in zip(tups, indices) if i not in ans]
439
+ indices = [i for i in indices if i not in ans]
440
+
441
+ if len(indices):
442
+ _ = track_progress_rich(
443
+ MLVU_OpenEnded_generate,
444
+ tups,
445
+ nproc=nproc,
446
+ chunksize=nproc,
447
+ keys=indices,
448
+ save=tmp_file,
449
+ )
450
+ ans = load(tmp_file)
451
+ data = MLVU_OpenEnded_extract(ans, data)
452
+ dump(data, score_file)
453
+
454
+ rating = get_dimension_rating(score_file)
455
+ return rating
VLMEvalKit/vlmeval/dataset/mmbench_video.py ADDED
@@ -0,0 +1,256 @@
1
+ from huggingface_hub import snapshot_download
2
+ from ..smp import *
3
+ from .video_base import VideoBaseDataset
4
+ from .utils import build_judge, DEBUG_MESSAGE
5
+ from ..utils import track_progress_rich
6
+
7
+
8
+ FAIL_MSG = 'Failed to obtain answer via API.'
9
+
10
+
11
+ def unwrap_hf_pkl(pth, suffix='.mp4'):
12
+ base_dir = os.path.join(pth, 'video_pkl/')
13
+ target_dir = os.path.join(pth, 'video/')
14
+ pickle_files = [os.path.join(base_dir, file) for file in os.listdir(base_dir)]
15
+ pickle_files.sort()
16
+
17
+ if not os.path.exists(target_dir):
18
+ os.makedirs(target_dir, exist_ok=True)
19
+ for pickle_file in pickle_files:
20
+ with open(pickle_file, 'rb') as file:
21
+ video_data = pickle.load(file)
22
+ # For each video file in the pickle file, write its contents to a new mp4 file
23
+ for video_name, video_content in video_data.items():
24
+ output_path = os.path.join(target_dir, f'{video_name}{suffix}')
25
+ with open(output_path, 'wb') as output_file:
26
+ output_file.write(video_content)
27
+ print('The video file has been restored and stored from the pickle file.')
28
+ else:
29
+ print('The video file already exists.')
30
+
31
+
32
+ class MMBenchVideo(VideoBaseDataset):
33
+
34
+ MD5 = '98f7df3eb1007fc375ea6fe88a98e2ff'
35
+ SYS = 'You are an AI assistant responsible for answering questions about videos.'
36
+ FRAMES_TMPL_PACK = """
37
+ You will be provided with {} separate frames uniformly sampled from a video, \
38
+ the frames are provided in chronological order of the video.
39
+ Please analyze these images and provide the answer / answers to the \
40
+ following question / questions about the video content.
41
+ If multiple questions are provided (with indices I1, I2, I3, ...), \
42
+ you should organize your answers in the following json format:
43
+ {{
44
+ 'I1': 'Answer to Question I1',
45
+ 'I2': 'Answer to Question I2',
46
+ ...
47
+ }}
48
+ Otherwise, please directly reply with your response to the only question.
49
+ Even if the information in these separate frames is not enough to give an answer,
50
+ PLEASE GIVE A RESPONSE TO EACH OF THE QUESTIONS IN THE FORMAT DESCRIBED ABOVE.
51
+ """
52
+
53
+ FRAMES_TMPL_NOPACK = """
54
+ You will be provided with {} separate frames uniformly sampled from a video, \
55
+ the frames are provided in chronological order of the video.
56
+ Please analyze these images and provide the answer to the question about the video content.
57
+ Please directly reply with your response to the only question.
58
+ """
59
+
60
+ TYPE = 'Video-VQA'
61
+
62
+ def __init__(self, dataset='MMBench-Video', pack=False):
63
+ super().__init__(dataset=dataset, pack=pack)
64
+
65
+ @classmethod
66
+ def supported_datasets(cls):
67
+ return ['MMBench-Video']
68
+
69
+ def prepare_dataset(self, dataset_name='MMBench-Video', repo_id='opencompass/MMBench-Video'):
70
+ def check_integrity(pth):
71
+ data_file = osp.join(pth, f'{dataset_name}.tsv')
72
+ if md5(data_file) != self.MD5:
73
+ return False
74
+ data = load(data_file)
75
+ for video_pth in data['video_path']:
76
+ if not osp.exists(osp.join(pth, video_pth)):
77
+ return False
78
+ return True
79
+
80
+ cache_path = get_cache_path(repo_id)
81
+ if cache_path is not None and check_integrity(cache_path):
82
+ dataset_path = cache_path
83
+ else:
84
+ if modelscope_flag_set():
85
+ from modelscope import dataset_snapshot_download
86
+ dataset_path = dataset_snapshot_download(dataset_id=repo_id)
87
+ else:
88
+ dataset_path = snapshot_download(repo_id=repo_id, repo_type='dataset')
89
+ unwrap_hf_pkl(dataset_path)
90
+ self.video_path = osp.join(dataset_path, 'video/')
91
+ data_file = osp.join(dataset_path, f'{dataset_name}.tsv')
92
+
93
+ return dict(data_file=data_file, root=osp.join(dataset_path, 'video'))
94
+
95
+ def build_prompt_pack(self, line, num_frames, fps=-1):
96
+ if isinstance(line, int):
97
+ assert line < len(self)
98
+ video = self.videos[line]
99
+ elif isinstance(line, pd.Series):
100
+ video = line['video']
101
+ elif isinstance(line, str):
102
+ video = line
103
+
104
+ frames = self.save_video_frames(video, num_frames, fps)
105
+ sub = self.data[self.data['video'] == video]
106
+ sys_prompt = self.SYS + self.FRAMES_TMPL_PACK.format(len(frames))
107
+ message = [dict(type='text', value=sys_prompt)]
108
+ for im in frames:
109
+ message.append(dict(type='image', value=im))
110
+ nq = len(sub)
111
+ prompt = 'Questions: \n{}\nAnswers: \n'
112
+ qs = {int(sub.iloc[i]['index']): sub.iloc[i]['question'] for i in range(nq)}
113
+ prompt = prompt.format(json.dumps(qs))
114
+ message.append(dict(type='text', value=prompt))
115
+ return message
116
+
117
+ def build_prompt_nopack(self, line, num_frames, video_llm, fps):
118
+ if isinstance(line, int):
119
+ assert line < len(self)
120
+ line = self.data.iloc[line]
121
+ if video_llm:
122
+ question = line['question']
123
+ prefix, video_idx_path = os.path.split(line['video_path'])
124
+ message = [dict(type='text', value=question)]
125
+ message.append(dict(type='video', value=os.path.join(self.video_path, video_idx_path)))
126
+ return message
127
+ else:
128
+ frames = self.save_video_frames(line['video'], num_frames, fps)
129
+ sys_prompt = self.FRAMES_TMPL_NOPACK.format(len(frames))
130
+ message = [dict(type='text', value=sys_prompt)]
131
+ for im in frames:
132
+ message.append(dict(type='image', value=im))
133
+ prompt = 'Question: {}\nAnswer: '.format(line['question'])
134
+ message.append(dict(type='text', value=prompt))
135
+ return message
136
+
137
+ def build_prompt(self, line, num_frames, video_llm, fps):
138
+ if self.pack and not video_llm:
139
+ return self.build_prompt_pack(line, num_frames, fps)
140
+ else:
141
+ return self.build_prompt_nopack(line, num_frames, video_llm, fps)
142
+
143
+ @staticmethod
144
+ def remove_side_quote(s, syms=[',', '"', "'"]):
145
+ if np.all([x in syms for x in s]):
146
+ return ''
147
+ while s[0] in syms:
148
+ s = s[1:]
149
+ while s[-1] in syms:
150
+ s = s[:-1]
151
+ return s
152
+
153
+ @staticmethod
154
+ def robust_json_load(s):
155
+ try:
156
+ jsons = list(extract_json_objects(s))
157
+ assert len(jsons) == 1
158
+ return jsons[0]
159
+ except:
160
+ if '{' in s and s.find('{') == s.rfind('{'):
161
+ sub_str = s[s.find('{') + 1:].strip()
162
+ lines = sub_str.split('\n')
163
+ res = {}
164
+ for l in lines:
165
+ l = l.strip()
166
+ if ': ' in l:
167
+ key = l.split(': ')[0].strip()
168
+ val = l.split(': ')[1].strip()
169
+ key = MMBenchVideo.remove_side_quote(key)
170
+ val = MMBenchVideo.remove_side_quote(val)
171
+ if len(key) and len(val):
172
+ res[key] = val
173
+ return res
174
+ return None
175
+
176
+ def load_pack_answers(self, data_raw):
177
+ vstats = defaultdict(lambda: 0)
178
+ data = defaultdict(lambda: {})
179
+
180
+ for k in data_raw:
181
+ ans = data_raw[k].strip()
182
+ if FAIL_MSG in ans:
183
+ vstats['GEN_FAIL'] += 1
184
+ continue
185
+ res = self.robust_json_load(ans)
186
+ if res is not None:
187
+ data[k] = res
188
+ vstats['PARSE_OK'] += 1
189
+ else:
190
+ vstats['PARSE_FAIL'] += 1
191
+
192
+ # return data
193
+ meta = cp.deepcopy(self.data)
194
+ lt = len(meta)
195
+ prediction = []
196
+ for i in range(lt):
197
+ line = meta.iloc[i]
198
+ vid = line['video']
199
+ idx = str(line['index'])
200
+ prediction.append(data[vid][idx] if idx in data[vid] else None)
201
+ meta['prediction'] = prediction
202
+ vstats['VALIDQ'] = len([x for x in prediction if x is not None])
203
+ vstats['INVALIDQ'] = len([x for x in prediction if x is None])
204
+ return meta, vstats
205
+
206
+ # It returns a dictionary
207
+ @classmethod
208
+ def evaluate(self, eval_file, **judge_kwargs):
209
+ from .utils.mmbench_video import get_dimension_rating, system_prompt, build_prompt
210
+
211
+ assert eval_file.endswith('.xlsx'), 'data file should be an xlsx file'
212
+ judge = judge_kwargs['model']
213
+ nproc = judge_kwargs.pop('nproc', 4)
214
+
215
+ tmp_file = eval_file.replace('.xlsx', f'_{judge}_tmp.pkl')
216
+ tgt_file = eval_file.replace('.xlsx', f'_{judge}_rating.json')
217
+ score_file = eval_file.replace('.xlsx', f'_{judge}_score.xlsx')
218
+
219
+ model = build_judge(system_prompt=system_prompt, **judge_kwargs)
220
+ assert model.working(), 'MMBench-Video evaluation requires a working OPENAI API\n' + DEBUG_MESSAGE
221
+
222
+ if not osp.exists(score_file):
223
+ res = {} if not osp.exists(tmp_file) else load(tmp_file)
224
+ res = {k: v for k, v in res.items() if model.fail_msg not in v}
225
+
226
+ data = load(eval_file)
227
+ data_un = data[~data['index'].isin(res)]
228
+ data_un = data_un[~pd.isna(data_un['prediction'])]
229
+ lt = len(data_un)
230
+ prompts = [build_prompt(data_un.iloc[i]) for i in range(lt)]
231
+ indices = [data_un.iloc[i]['index'] for i in range(lt)]
232
+
233
+ if len(prompts):
234
+ _ = track_progress_rich(
235
+ model.generate,
236
+ prompts,
237
+ keys=indices,
238
+ save=tmp_file,
239
+ nproc=nproc,
240
+ chunksize=nproc
241
+ )
242
+ score_map = load(tmp_file)
243
+ data['score'] = [score_map[idx] if idx in score_map else -1 for idx in data['index']]
244
+ rejected = [x for x in score_map.values() if FAIL_MSG in x]
245
+ data['score'] = [int(x) if istype(x, int) else -1 for x in data['score']]
246
+ print(
247
+ f'Among {len(data)} questions, failed to obtain prediction for {len(data) - len(score_map)} questions, '
248
+ f'failed to obtain the score for another {len(rejected)} questions. '
249
+ f'Those questions will be counted as 0 score in ALL rating, and will not be counted in VALID rating.'
250
+ )
251
+
252
+ dump(data, score_file)
253
+
254
+ rating = get_dimension_rating(score_file)
255
+ dump(rating, tgt_file)
256
+ return rating
VLMEvalKit/vlmeval/dataset/mmgenbench.py ADDED
@@ -0,0 +1,69 @@
1
+ import warnings
2
+ import pandas as pd
3
+ from abc import abstractmethod
4
+ from ..smp import *
5
+ from .image_base import ImageBaseDataset
6
+
7
+
8
+ class MMGenBench(ImageBaseDataset):
9
+
10
+ prompt_list = [
11
+ """
12
+ # Role
13
+ You are an expert in the field of image understanding, focusing on the \
14
+ understanding of images and generating the image caption-prompt.
15
+
16
+ # Definition Explanation
17
+ image caption-prompt: Refers to the caption or description of an image, \
18
+ used to provide to a Text-to-Image model to generate a new image.
19
+ Text-to-Image model: Can generate a new image based on the provided image \
20
+ caption-prompt, such as stable diffusion 3, flux, and other image generation models.
21
+
22
+ # Task Description
23
+ Generate an image caption-prompt based on the input image.
24
+
25
+ # Key Points and Requirements
26
+ 1. Accurately understand the input image and precisely generate an image caption-prompt.
27
+ 2. The generated image caption-prompt, when provided to the Text-to-Image model, requires the \
28
+ Text-to-Image model to generate a new image that is as consistent as possible with the input image.
29
+ 3. The generated image caption-prompt must conform to the preferences of the Text-to-Image model.
30
+ 4. The generated image caption-prompt should describe the input image in as much \
31
+ detail as possible, and it should be between 20 to 60 words.
32
+
33
+ # Output Format
34
+ A string, that is the image caption-prompt. No extra output needed.
35
+ """
36
+ ]
37
+ TYPE = 'GenerateImgPrompt'
38
+ DATASET_URL = {
39
+ 'MMGenBench-Test': 'https://huggingface.co/datasets/lerogo/MMGenBench/resolve/main/MMGenBench-Test.tsv',
40
+ 'MMGenBench-Domain': 'https://huggingface.co/datasets/lerogo/MMGenBench/resolve/main/MMGenBench-Domain.tsv',
41
+ }
42
+ PROMPT_MAP = {
43
+ 'MMGenBench-Test': prompt_list[0],
44
+ 'MMGenBench-Domain': prompt_list[0],
45
+ }
46
+ DATASET_MD5 = {
47
+ 'MMGenBench-Test': "94f8dac6bbf7c20be403f99adeaa73da",
48
+ 'MMGenBench-Domain': "5c10daf6e2c5f08bdfb0701aa6db86bb",
49
+ }
50
+
51
+ def __init__(self, dataset='MMGenBench', **kwargs):
52
+ super().__init__(dataset, **kwargs)
53
+ warnings.warn('This dataset is for inference only and does not support direct output of evaluation results.\n')
54
+ warnings.warn('Please refer to "https://github.com/lerogo/MMGenBench" for more evaluation information.\n')
55
+
56
+ def load_data(self, dataset):
57
+ data = super().load_data(dataset)
58
+ if 'question' not in data:
59
+ data['question'] = [(
60
+ self.PROMPT_MAP[dataset]
61
+ )] * len(data)
62
+ return data
63
+
64
+ # Given the prediction file, return the evaluation results in the format of a dictionary or pandas dataframe
65
+ @abstractmethod
66
+ def evaluate(self, eval_file, **judge_kwargs):
67
+ warnings.warn('This evaluation method is not supported.\n')
68
+ warnings.warn('Please refer to "https://github.com/lerogo/MMGenBench" for more evaluation information.\n')
69
+ return None
VLMEvalKit/vlmeval/dataset/mmlongbench.py ADDED
@@ -0,0 +1,584 @@
1
+ import re
2
+ import math
3
+ from urllib.request import urlopen
4
+ from PIL import Image, ImageDraw, ImageFont
5
+ import torchvision.transforms as transforms
6
+
7
+ from vlmeval.dataset.utils import build_judge, levenshtein_distance
8
+ from vlmeval.smp import *
9
+ from .image_base import ImageBaseDataset
10
+
11
+ FAIL_MSG = 'Failed to obtain answer via API.'
12
+
13
+
14
+ def get_gpt4_ICE():
15
+ example_1 = """
16
+ ---
17
+ Question: List the primary questions asked about the services in this report.
18
+ Analysis: The primary questions asked about the services in the report for The Limes Residential Home are:\n\n
19
+ 1. Is the service safe?\n
20
+ 2. Is the service effective?\n
21
+ 3. Is the service caring?\n
22
+ 4. Is the service responsive?\n
23
+ 5. Is the service well-led?
24
+ Extracted answer: [
25
+ 'Is the servife safe?',
26
+ 'Is the service effective',
27
+ 'Is the serve caring?',
28
+ 'Is the service responsive?',
29
+ 'Is the service well-led?'
30
+ ]
31
+ Answer format: List\n
32
+ """
33
+
34
+ example_2 = """
35
+ ---
36
+ Question: How many regulations of the HSCA 2008 are breached in all according to this report?
37
+ Analysis: According to the report, the provider breached 10 Health and Social Care Act 2008 (Regulated Activities)
38
+ Regulations in total. Here are the specifics:\n\n1. Regulation 13: Safeguarding service users from abuse and
39
+ improper treatment\n2. Regulation 12: Safe care and treatment\n3. Regulation 18: Staffing\n4. Regulation 11:
40
+ Need for consent\n5. Regulation 10: Dignity and respect\n6. Regulation 9: Person-centred care\n7. Regulation 17:
41
+ Good governance\n8. Regulation 18 (CQC Registration Regulations 2009): Notification of other incidents\n9.
42
+ Regulation 18: Failure to maintain an accurate and up-to-date care plan\n10. Regulation 11: Failure to implement
43
+ the Mental Capacity Act 2005 code of practice effectively\n\nThese breaches involve issues concerning staffing,
44
+ safeguarding, medicines management, dignity and respect, consent, care planning, governance, and failure to
45
+ notify the CQC of incidents.
46
+ Extracted answer: 10
47
+ Answer format: Integer\n
48
+ """
49
+
50
+ example_3 = """
51
+ ---
52
+ Question: According to the survey that is the percentage of Chinese who are paying more or
53
+ about the same attention to politics after Trump's election?
54
+ Analysis: The survey provided does not specify the percentage of Chinese individuals specifically who are paying
55
+ more or about the same attention to politics after Trump's election. The report focuses primarily on American
56
+ demographics and does not include specific details about the Chinese population in relation to this question. If
57
+ you need information about a different demographic or a summary of the findings from the American demographic,
58
+ I can certainly help with that!
59
+ Extracted answer: Not answerable
60
+ Answer format: String\n
61
+ """
62
+
63
+ example_4 = """
64
+ ---
65
+ Question: How many quotations from male respondent over 50 years old are included in this report?
66
+ Analysis: The image you've provided appears to be a screenshot of a document with multiple charts. However, the
67
+ text is too small and blurry to read accurately. If you can provide a clearer image or more context, I might be
68
+ able to help you with your question.
69
+ Extracted answer: Fail to answer
70
+ Answer format: String\n
71
+ """
72
+
73
+ return [example_1, example_2, example_3, example_4]
74
+
75
+
76
+ def build_mmlongbench_gpt4_prompt(line):
77
+ task_description = """
78
+ Given the question and analysis, you are tasked to extract answers with required formats from the free-form analysis.
79
+ - Your extracted answers should be one of the following formats: (1) Integer, (2) Float, (3) String and (4) List.
80
+ If you find the analysis the question can not be answered from the given documents, type "Not answerable".
81
+ Exception: If the analysis only tells you that it can not read/understand the images or documents,
82
+ type "Fail to answer".
83
+ - Please make your response as concise as possible. Also note that your response should be formatted as below:
84
+ ```
85
+ Extracted answer: [answer]
86
+ Answer format: [answer format]
87
+ ```
88
+ Please read the following example, then extract the answer from the model response
89
+ and type it at the end of the prompt.\n
90
+ """
91
+ question = line['question']
92
+ prediction = str(line['prediction'])
93
+ prompt = task_description
94
+ examples = get_gpt4_ICE()
95
+ for example in examples:
96
+ prompt += example
97
+ prompt += '---\nQuestion:' + question + '\n'
98
+ prompt += 'Analysis: ' + prediction
99
+ return prompt
100
+
101
+
102
+ def anls_compute(groundtruth, prediction, threshold=0.5):
103
+ dist = levenshtein_distance(groundtruth, prediction)
104
+ length = max(len(groundtruth.upper()), len(prediction.upper()))
105
+ value = 0.0 if length == 0 else float(dist) / float(length)
106
+ anls = 1.0 - value
107
+ if anls <= threshold:
108
+ anls = 0.0
109
+ return anls
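+ # ANLS sketch (made-up strings): for groundtruth 'report' vs prediction 'reprot' the
+ # edit distance is 2, so anls = 1 - 2 / 6 = 0.667, which is kept because it exceeds
+ # the 0.5 threshold; a score at or below the threshold is clamped to 0.0.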
110
+
111
+
112
+ def is_float_equal(reference, prediction, include_percentage: bool = False, is_close: float = False) -> bool:
113
+ def get_precision(gt_ans: float) -> int:
114
+ precision = 3
115
+ if '.' in str(gt_ans):
116
+ precision = len(str(gt_ans).split('.')[-1])
117
+ return precision
118
+
119
+ reference = float(str(reference).strip().rstrip('%').strip())
120
+ try:
121
+ prediction = float(str(prediction).strip().rstrip('%').strip())
122
+ except:
123
+ return False
124
+
125
+ if include_percentage:
126
+ gt_result = [reference / 100, reference, reference * 100]
127
+ else:
128
+ gt_result = [reference]
129
+ for item in gt_result:
130
+ try:
131
+ if is_close:
132
+ if math.isclose(item, prediction, rel_tol=0.01):
133
+ return True
134
+ precision = max(min(get_precision(prediction), get_precision(item)), 2)
135
+ if round(prediction, precision) == round(item, precision):
136
+ return True
137
+ except Exception:
138
+ continue
139
+ return False
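+ # Numeric-match sketch (assumed values): with include_percentage=True and is_close=True,
+ # a ground truth of '12%' is expanded to [0.12, 12.0, 1200.0], so predictions such as
+ # '0.12', '12' or '12.0' are all accepted, while '12.5' is rejected.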
140
+
141
+
142
+ def get_clean_string(s):
143
+ s = str(s).lower().strip()
144
+ if s.endswith('mile'):
145
+ s = s.rstrip('mile').strip()
146
+ if s.endswith('miles'):
147
+ s = s.rstrip('miles').strip()
148
+ if s.endswith('million'):
149
+ s = s.rstrip('million').strip()
150
+ # remove parenthesis
151
+ s = re.sub(r'\s*\([^)]*\)', '', s).strip()
152
+ # remove quotes
153
+ s = re.sub(r"^['\"]|['\"]$", '', s).strip()
154
+ s = s.strip().lstrip('$').strip()
155
+ s = s.strip().rstrip('%').strip()
156
+ return s
157
+
158
+
159
+ def is_exact_match(s):
160
+ flag = False
161
+ # Website
162
+ if 'https://' in s:
163
+ flag = True
164
+ # code file
165
+ if s.endswith('.py') or s.endswith('ipynb'):
166
+ flag = True
167
+ if s.startswith('page'):
168
+ flag = True
169
+ # telephone number
170
+ if re.fullmatch(r'\b\d+(-\d+|\s\d+)?\b', s):
171
+ flag = True
172
+ # time
173
+ if 'a.m.' in s or 'p.m.' in s:
174
+ flag = True
175
+ # YYYY-MM-DD
176
+ if re.fullmatch(r'\b\d{4}[-\s]\d{2}[-\s]\d{2}\b', s):
177
+ flag = True
178
+ # YYYY-MM
179
+ if re.fullmatch(r'\b\d{4}[-\s]\d{2}\b', s):
180
+ flag = True
181
+ # Email address
182
+ if re.fullmatch(r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}', s):
183
+ flag = True
184
+ return flag
185
+
186
+
187
+ def isfloat(num):
188
+ try:
189
+ float(num)
190
+ return True
191
+ except ValueError:
192
+ return False
193
+
194
+
195
+ def get_font():
196
+ try:
197
+ truetype_url = "http://opencompass.openxlab.space/utils/Fonts/SimHei.ttf"
198
+ ff = urlopen(truetype_url)
199
+ font = ImageFont.truetype(ff, size=40)
200
+ except Exception as e:
201
+ logging.warning(f'{type(e)}: {e}')
202
+ logging.warning("Fail to download the font. Use the default one.")
203
+ font = ImageFont.load_default(size=40)
204
+ return font
205
+
206
+
207
+ def frame2img(img_path_list, font, save_path=None, idx_start=0):
208
+ imgs = [Image.open(img_path) for img_path in img_path_list]
209
+
210
+ new_imgs = []
211
+ for img in imgs:
212
+ w, h = img.size
213
+ scale = w / h
214
+ if w > h:
215
+ new_w = 560 * 2
216
+ new_h = int(560 * 2 / scale)
217
+ else:
218
+ new_w = int(560 * 2 * scale)
219
+ new_h = 560 * 2
220
+ img = transforms.functional.resize(img, [new_h, new_w],)
221
+ new_imgs.append(img)
222
+ imgs = new_imgs
223
+ new_w = 0
224
+ new_h = 0
225
+ pad = 40
226
+ if w > h:
227
+ for im in imgs:
228
+ w, h = im.size
229
+ new_w = max(new_w, w)
230
+ new_h += h + 10 + pad
231
+ new_img = Image.new("RGB", (new_w, new_h), "white")
232
+ draw = ImageDraw.Draw(new_img)
233
+ curr_h = 0
234
+ for idx, im in enumerate(imgs):
235
+ w, h = im.size
236
+ new_img.paste(im, (0, pad + curr_h))
237
+ draw.text((0, curr_h), f"<IMAGE {idx+idx_start}>", font=font, fill="black")
238
+ if idx + 1 < len(imgs):
239
+ draw.line([(0, pad + curr_h + h + 5), (new_w, pad + curr_h + h + 5)], fill='black', width=2)
240
+ curr_h += h + 10 + pad
241
+ else:
242
+ for im in imgs:
243
+ w, h = im.size
244
+ new_w += w + 10
245
+ new_h = max(new_h, h)
246
+ new_h += pad
247
+ new_img = Image.new('RGB', (new_w, new_h), 'white')
248
+ draw = ImageDraw.Draw(new_img)
249
+ curr_w = 0
250
+ for idx, im in enumerate(imgs):
251
+ w, h = im.size
252
+ new_img.paste(im, (curr_w, pad))
253
+ draw.text((curr_w, 0), f"<IMAGE {idx+idx_start}>", font=font, fill='black')
254
+ if idx + 1 < len(imgs):
255
+ draw.line([(curr_w + w + 5, 0), (curr_w + w + 5, new_h)], fill='black', width=2)
256
+ curr_w += w + 10
257
+
258
+ if save_path is not None:
259
+ new_img.save(save_path)
260
+
261
+ return new_img
262
+
263
+
264
+ def concat_images(image_list, max_concat=1, column_num=1):
265
+ concatenated_images = []
266
+ if column_num == -1:
267
+ MAX_COLUMN_NUM = 20
268
+ max_concat = 1
269
+ while len(image_list) / max_concat > MAX_COLUMN_NUM:
270
+ max_concat += 1
271
+ interval = max(math.ceil(len(image_list) / max_concat), 1)
272
+ for i in range(0, len(image_list), interval):
273
+ batch_images = image_list[i:i + interval]
274
+ concatenated_image = frame2img(batch_images, font=get_font(), idx_start=i)
275
+ concatenated_images.append(concatenated_image)
276
+ else:
277
+ interval = max(math.ceil(len(image_list) / max_concat), 1)
278
+ for i in range(0, len(image_list), interval):
279
+ batch_images = [Image.open(filename) for filename in image_list[i:i + interval]]
280
+ if column_num == 1:
281
+ total_height = batch_images[0].height * len(batch_images)
282
+ else:
283
+ total_height = batch_images[0].height * ((len(batch_images) - 1) // column_num + 1)
284
+ concatenated_image = Image.new('RGB', (batch_images[0].width * column_num, total_height), 'white')
285
+
286
+ x_offset, y_offset = 0, 0
287
+ for count, image in enumerate(batch_images):
288
+ concatenated_image.paste(image, (x_offset, y_offset))
289
+ x_offset += image.width
290
+ if (count + 1) % column_num == 0:
291
+ y_offset += image.height
292
+ x_offset = 0
293
+ concatenated_images.append(concatenated_image)
294
+ return concatenated_images
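+ # Concatenation sketch (assumed sizes): with 10 page images, max_concat=5 and
+ # column_num=2, interval = ceil(10 / 5) = 2, so the pages are grouped into 5 batches
+ # and each batch is pasted into one new image that is 2 pages wide and 1 page tall;
+ # column_num=-1 instead stitches frames with "<IMAGE i>" banners via frame2img.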
295
+
296
+
297
+ def eval_score(gt, pred, answer_type):
298
+ if answer_type == 'Int':
299
+ try:
300
+ gt, pred = int(gt), int(float(pred))
301
+ except:
302
+ pred = ''
303
+ score = (gt == pred)
304
+ elif answer_type == 'Float':
305
+ try:
306
+ gt = float(get_clean_string(str(gt)))
307
+ pred = float(get_clean_string(str(pred)))
308
+ except:
309
+ pred = ''
310
+ score = is_float_equal(gt, pred, include_percentage=True, is_close=True)
311
+ elif answer_type == 'Str':
312
+ gt = get_clean_string(gt)
313
+ pred = get_clean_string(pred)
314
+ if is_exact_match(gt):
315
+ score = (gt == pred)
316
+ else:
317
+ score = anls_compute(gt, pred)
318
+ else:
319
+ if isinstance(gt, str) and gt.startswith('['):
320
+ gt = eval(gt)
321
+ if not isinstance(gt, list):
322
+ gt = [gt]
323
+ if isinstance(pred, str) and pred.startswith('['):
324
+ pred = eval(pred)
325
+ if not isinstance(pred, list):
326
+ pred = [pred]
327
+ print(len(gt), len(pred))
328
+ if len(gt) != len(pred):
329
+ score = 0.0
330
+ else:
331
+ gt = sorted([get_clean_string(a) for a in gt])
332
+ pred = sorted([get_clean_string(a) for a in pred])
333
+ print(gt, pred)
334
+ if isfloat(gt[0]) or is_exact_match(gt[0]):
335
+ score = ('-'.join(gt) == '-'.join(pred))
336
+ else:
337
+ score = min([anls_compute(gt_v, pred_v) for gt_v, pred_v in zip(gt, pred)])
338
+
339
+ return float(score)
340
+
341
+
342
+ def MMLongBench_auxeval(model, line):
343
+ prompt = build_mmlongbench_gpt4_prompt(line)
344
+ log = ''
345
+ retry = 5
346
+
347
+ for i in range(retry):
348
+ prediction = line['prediction']
349
+ res = model.generate(prompt, temperature=i * 0.5)
350
+
351
+ if FAIL_MSG in res:
352
+ log += f'Try {i}: output is {prediction}, failed to parse.\n'
353
+ else:
354
+ log += 'Succeed'
355
+ try:
356
+ pred = res.split('Answer format:')[0].split('Extracted answer:')[1].strip()
357
+ except:
358
+ pred = ''
359
+ return dict(log=log, res=res, pred=pred)
360
+ log += 'All 5 retries failed.\n'
361
+ return dict(log=log, res='', pred='')
362
+
363
+
364
+ def get_f1(data):
365
+ gt_pos_data = data[data.apply(lambda k: k['answer'] != 'Not answerable', axis=1)]
366
+ pred_pos_data = data[data.apply(lambda k: k['pred'] != 'Not answerable', axis=1)]
367
+ recall = sum(gt_pos_data['score'].tolist()) / len(gt_pos_data)
368
+ precision = sum(pred_pos_data['score'].tolist()) / len(pred_pos_data)
369
+ return 2 * recall * precision / (recall + precision)
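+ # F1 sketch: recall is the mean score over questions whose ground truth is answerable,
+ # precision is the mean score over questions the model did not predict as
+ # 'Not answerable'; e.g. recall = 0.6 and precision = 0.5 give F1 = 0.6 / 1.1, about 0.545.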
370
+
371
+
372
+ def MMLongBench_acc(result_file):
373
+ data = load(result_file)
374
+ overall_score = 0.0
375
+ score_list = list()
376
+ for i in range(len(data)):
377
+ item = data.iloc[i]
378
+ try:
379
+ score = eval_score(item['answer'], item['pred'], item['answer_format'])
380
+ except:
381
+ score = 0.0
382
+ score_list.append(score)
383
+ overall_score += score
384
+
385
+ data['score'] = score_list
386
+ dump(data, result_file)
387
+
388
+ data_chart = data[data.apply(lambda k: 'Chart' in eval(k['evidence_sources']), axis=1)]
389
+ data_table = data[data.apply(lambda k: 'Table' in eval(k['evidence_sources']), axis=1)]
390
+ data_image = data[data.apply(lambda k: 'Figure' in eval(k['evidence_sources']), axis=1)]
391
+ data_text = data[data.apply(lambda k: 'Pure-text (Plain-text)' in eval(k['evidence_sources']), axis=1)]
392
+ data_layout = data[data.apply(lambda k: 'Generalized-text (Layout)' in eval(k['evidence_sources']), axis=1)]
393
+
394
+ data_single = data[data.apply(lambda k: len(eval(k['evidence_pages'])) == 1, axis=1)]
395
+ data_multi = data[data.apply(lambda k: len(eval(k['evidence_pages'])) > 1, axis=1)]
396
+ data_unans = data[data.apply(lambda k: len(eval(k['evidence_pages'])) == 0, axis=1)]
397
+
398
+ res = dict()
399
+ res['category'] = [
400
+ 'overall_f1', 'overall_acc', 'text', 'layout', 'table', 'chart',
401
+ 'image', 'single-page', 'multi-page', 'unanswerable'
402
+ ]
403
+ res['num'] = [
404
+ len(data), len(data), len(data_text), len(data_layout), len(data_table),
405
+ len(data_chart), len(data_image), len(data_single), len(data_multi), len(data_unans)
406
+ ]
407
+ res['avg_score'] = [
408
+ get_f1(data),
409
+ overall_score / len(data),
410
+ sum(data_text['score'].tolist()) / len(data_text) if len(data_text) > 0 else 0.0,
411
+ sum(data_layout['score'].tolist()) / len(data_layout) if len(data_layout) > 0 else 0.0,
412
+ sum(data_table['score'].tolist()) / len(data_table) if len(data_table) > 0 else 0.0,
413
+ sum(data_chart['score'].tolist()) / len(data_chart) if len(data_chart) > 0 else 0.0,
414
+ sum(data_image['score'].tolist()) / len(data_image) if len(data_image) > 0 else 0.0,
415
+ sum(data_single['score'].tolist()) / len(data_single) if len(data_single) > 0 else 0.0,
416
+ sum(data_multi['score'].tolist()) / len(data_multi) if len(data_multi) > 0 else 0.0,
417
+ sum(data_unans['score'].tolist()) / len(data_unans) if len(data_unans) > 0 else 0.0,
418
+ ]
419
+ res = pd.DataFrame(res)
420
+ return res
421
+
422
+
423
+ class MMLongBench(ImageBaseDataset):
424
+
425
+ TYPE = 'VQA'
426
+
427
+ DATASET_URL = {
428
+ 'MMLongBench_DOC': 'https://opencompass.openxlab.space/utils/VLMEval/MMLongBench_DOC.tsv',
429
+ }
430
+ DATASET_MD5 = {
431
+ 'MMLongBench_DOC': '9b393e1f4c52718380d50586197eac9b',
432
+ }
433
+
434
+ SUPPORTED_MODELS = {
435
+ 'GPT4': (1, 1),
436
+ 'GPT4V': (1, 1),
437
+ 'GPT4V_HIGH': (1, 1),
438
+ 'GPT4o': (1, 1),
439
+ 'GPT4o_HIGH': (1, 1),
440
+ 'GPT4o_MINI': (1, 1),
441
+ 'MiniCPM-Llama3-V-2_5': (1, 5),
442
+ 'InternVL-Chat-V1-5': (5, 2),
443
+ 'XComposer2_4KHD': (1, 5),
444
+ 'XComposer2d5': (1, -1),
445
+ }
446
+
447
+ def __init__(self, dataset, **kwargs):
448
+ self.model_list = list(self.SUPPORTED_MODELS.keys())
449
+ model_name = kwargs['model']
450
+ if not listinstr(self.model_list, model_name):
451
+ raise AssertionError("{} doesn't support the evaluation on MMLongBench_DOC.".format(model_name))
452
+ super(MMLongBench, self).__init__(dataset)
453
+
454
+ self.is_api = True if listinstr(['GPT4'], model_name) else False
455
+ self.max_pages = 120
456
+ concat_num, column_num = self.SUPPORTED_MODELS.get(model_name)
457
+ self.concat_num = concat_num
458
+ self.column_num = column_num
459
+
460
+ def dump_image(self, origin_line):
461
+ os.makedirs(self.img_root, exist_ok=True)
462
+ try:
463
+ import fitz
464
+ except Exception as e:
465
+ logging.critical(f'{type(e)}: {e}')
466
+ logging.critical('Please use `pip install pymupdf` to parse PDF files.')
467
+
468
+ line = origin_line.copy()
469
+ line['image_path'] = line['image_path'][:self.max_pages]
470
+ skip_pdf_parse = True
471
+ for im_name in line['image_path']:
472
+ path = osp.join(self.img_root, im_name)
473
+ if not read_ok(path):
474
+ skip_pdf_parse = False
475
+ break
476
+
477
+ # Just for being compatible with the zipped loop below: zip(line['image'], line['image_path'])
478
+ if skip_pdf_parse:
479
+ line['image'] = line['image_path']
480
+ else:
481
+ pdf_data = base64.b64decode(line['image'])
482
+ pdf_file = io.BytesIO(pdf_data)
483
+ encoded_images = []
484
+ with fitz.open(stream=pdf_file, filetype='pdf') as doc:
485
+ doc = doc[:self.max_pages]
486
+ for page in doc:
487
+ image = page.get_pixmap(dpi=144)
488
+ image_file = io.BytesIO(image.tobytes(output='png'))
489
+ image = Image.open(image_file)
490
+ encoded_image = encode_image_to_base64(image)
491
+ encoded_images.append(encoded_image)
492
+ line['image'] = encoded_images
493
+ print('process {}'.format(line['doc_id']))
494
+
495
+ if 'image' in line:
496
+ if isinstance(line['image'], list):
497
+ tgt_path = []
498
+ assert 'image_path' in line
499
+ for img, im_name in zip(line['image'], line['image_path']):
500
+ path = osp.join(self.img_root, im_name)
501
+ if not read_ok(path):
502
+ decode_base64_to_image_file(img, path)
503
+ tgt_path.append(path)
504
+ else:
505
+ tgt_path = osp.join(self.img_root, f"{line['index']}.jpg")
506
+ if not read_ok(tgt_path):
507
+ decode_base64_to_image_file(line['image'], tgt_path)
508
+ tgt_path = [tgt_path]
509
+ else:
510
+ assert 'image_path' in line
511
+ tgt_path = toliststr(line['image_path'])
512
+
513
+ if self.concat_num > 0 and not self.is_api:
514
+ concatenated_images = concat_images(tgt_path, max_concat=self.concat_num, column_num=self.column_num)
515
+
516
+ old_tgt_path = tgt_path
517
+ assert isinstance(old_tgt_path, list)
518
+ if self.column_num != -1:
519
+ tgt_path = [
520
+ '_'.join(old_tgt_path[0].split('_')[:-1]) + '_concat{}_{}.jpg'.format(self.concat_num, i)
521
+ for i in range(len(concatenated_images))
522
+ ]
523
+ else:
524
+ tgt_path = [
525
+ '_'.join(old_tgt_path[0].split('_')[:-1]) + '_concat_all_{}.jpg'.format(i)
526
+ for i in range(len(concatenated_images))
527
+ ]
528
+
529
+ for path, concatenated_image in zip(tgt_path, concatenated_images):
530
+ if not read_ok(path):
531
+ decode_base64_to_image_file(encode_image_to_base64(concatenated_image), path)
532
+ num_images, image_size = len(old_tgt_path), concatenated_image.size
533
+ print('concat {} images to a new one with size {}. save at {}'.format(num_images, image_size, path))
534
+ return tgt_path
535
+
536
+ @classmethod
537
+ def evaluate(self, eval_file, **judge_kwargs):
538
+ logger = get_logger('Evaluation')
539
+ model = judge_kwargs['model']
540
+
541
+ suffix = eval_file.split('.')[-1]
542
+ storage = eval_file.replace(f'.{suffix}', f'_{model}.xlsx')
543
+ tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
544
+
545
+ if osp.exists(storage):
546
+ logger.warning(f'GPT scoring file {storage} already exists, will reuse it in MMLongBench_eval. ')
547
+ else:
548
+ data = load(eval_file)
549
+ model = build_judge(max_tokens=128, **judge_kwargs)
550
+ lt = len(data)
551
+ lines = [data.iloc[i] for i in range(lt)]
552
+ tups = [(model, line) for line in lines]
553
+ indices = [line['index'] for line in lines]
554
+
555
+ ans = {}
556
+ if osp.exists(tmp_file):
557
+ ans = load(tmp_file)
558
+ tups = [x for x, i in zip(tups, indices) if i not in ans]
559
+ indices = [i for i in indices if i not in ans]
560
+
561
+ if len(indices):
562
+ new_results = list()
563
+ for model, line in tqdm(tups):
564
+ res = MMLongBench_auxeval(model, line)
565
+ new_results.append(res)
566
+
567
+ log_map, res_map, pred_map = {}, {}, {}
568
+ all_inds = [line['index'] for line in lines]
569
+ for k, v in zip(all_inds, new_results):
570
+ log_map[k] = v['log']
571
+ res_map[k] = v['res']
572
+ pred_map[k] = v['pred']
573
+ data['res'] = [res_map[idx] for idx in data['index']]
574
+ data['log'] = [log_map[idx] for idx in data['index']]
575
+ data['pred'] = [pred_map[idx] for idx in data['index']]
576
+ dump(data, storage)
577
+
578
+ score = MMLongBench_acc(storage)
579
+ score_pth = storage.replace('.xlsx', '_score.csv')
580
+
581
+ dump(score, score_pth)
582
+ logger.info(f'MMLongBench_eval successfully finished evaluating {eval_file}, results saved in {score_pth}')
583
+ logger.info('Score: ')
584
+ logger.info(score)
VLMEvalKit/vlmeval/dataset/mmmath.py ADDED
@@ -0,0 +1,446 @@
1
+ import re
2
+ import json
3
+ import sympy as sp
4
+ import numpy as np
5
+ from sympy import simplify, Eq, sympify, Pow, pi
6
+ from sympy.parsing.latex import parse_latex
7
+ import sys
8
+ import math
9
+ import os
10
+ import argparse
11
+
12
+ from .image_base import ImageBaseDataset
13
+ from ..utils import track_progress_rich
14
+ from ..smp import load, dump
15
+
16
+
17
+ class AutoScoringJudge:
18
+ def __init__(self):
19
+ # Map of special symbols to their replacements
20
+ self.special_signal_map = {
21
+ "\\left": "",
22
+ "\\right": "",
23
+ "厘米":"",
24
+ # "∶": ":",
25
+ ",": ",",
26
+ "$": "",
27
+ "(":"(",
28
+ ")":")",
29
+ "\\infty":"oo",
30
+ "\\colon ":":",
31
+ # "\\approx": "=",
32
+ # "\\simeq": "=",
33
+ # "\\sim": "=",
34
+ # "^\\prime": "'",
35
+ # "^{\\prime}": "'",
36
+ "+":"+",
37
+ "\\, ": "",
38
+ "\\,":"",
39
+ "^\\circ": "",
40
+ "^{\\circ}": "",
41
+ # "%": "",
42
+ }
43
+ self.pi = parse_latex("\\pi")
44
+ # MM-Math default precision
45
+ self.precision = 1e-2
46
+
47
+ def trans_greater_sign_to_interval(self, expr:str):
48
+ expr_tmp = expr.split("<")
49
+ return "(" + expr_tmp[0] + ", " + expr_tmp[-1] + ")"
50
+
51
+ def split_by_comma(self, expr: str):
52
+ # Splits expressions by commas outside of brackets
53
+ in_bracket_num = 0
54
+ splitted_expr = []
55
+ start_idx = 0
56
+ for i, char in enumerate(expr):
57
+ if char in ["(", "["]:
58
+ in_bracket_num += 1
59
+ elif char in [")", "]"]:
60
+ in_bracket_num -= 1
61
+ elif char == "," and in_bracket_num == 0:
62
+ splitted_expr.append(expr[start_idx:i].strip())
63
+ start_idx = i + 1
64
+
65
+ if start_idx < len(expr):
66
+ splitted_expr.append(expr[start_idx:].strip())
67
+
68
+ return splitted_expr
69
+
70
+ def trans_plus_minus_sign(self, expr_list: list):
71
+ # Translates plus-minus signs into separate expressions
72
+ new_expr_list = []
73
+ for expr in expr_list:
74
+ if "\\pm" in expr:
75
+ new_expr_list.append(expr.replace("\\pm", "+"))
76
+ new_expr_list.append(expr.replace("\\pm", "-"))
77
+ else:
78
+ new_expr_list.append(expr)
79
+
80
+ return new_expr_list
81
+
82
+ def judge(self, expression1, expression2, precision=1e-2):
83
+ # Judge if two expressions are equal (expression1 is considered as the Ground Truth)
84
+ # Default precision is a list for supporting multiple expressions
85
+ precision = precision if isinstance(precision, list) else [precision]
86
+
87
+ try:
88
+ expression1, expression2 = self.preprocess(expression1, expression2)
89
+ except:
90
+ return False
91
+ if expression1 == expression2:
92
+ # print("Exactly equal")
93
+ return True
94
+
95
+ # Remove Chinese characters from the string, as answers like "yes" or "no" in Chinese have been considered
96
+ expression1 = expression1 if re.fullmatch(r"[\u4e00-\u9fff]+", expression1) else re.sub(r'[\u4e00-\u9fff]+', '', expression1) # noqa: E501
97
+ expression2 = expression2 if re.fullmatch(r'[\u4e00-\u9fff]+', expression2) else re.sub(r'[\u4e00-\u9fff]+', '', expression2) # noqa: E501
98
+ # Check if two < or > in expression
99
+ if self.is_two_greater_sign(expression1):
100
+ expression1 = self.trans_greater_sign_to_interval(expression1)
101
+
102
+ if self.is_two_greater_sign(expression2):
103
+ expression2 = self.trans_greater_sign_to_interval(expression2)
104
+
105
+ expression1 = self.split_by_comma(expression1)
106
+ expression2 = self.split_by_comma(expression2)
107
+
108
+ temp_list1 = self.trans_plus_minus_sign(expression1)
109
+ temp_list2 = self.trans_plus_minus_sign(expression2)
110
+
111
+ # Set up a list for allowed errors
112
+ if len(precision) <= 1:
113
+ precision = precision * len(temp_list1)
114
+
115
+ if len(temp_list1) != len(temp_list2):
116
+ return False
117
+
118
+ # Check if elements in both lists can be paired and are equal
119
+ idx = -1
120
+ while len(temp_list1) != 0:
121
+ idx = (idx + 1) % len(temp_list1)
122
+
123
+ item1 = temp_list1[idx]
124
+ self.precision = precision[idx]
125
+
126
+ for item2 in temp_list2:
127
+ if self.is_equal(item1, item2):
128
+ temp_list1.remove(item1)
129
+ temp_list2.remove(item2)
130
+ precision.remove(self.precision)
131
+ break
132
+ else:
133
+ # If no match was found, return False
134
+ return False
135
+
136
+ # If all elements are matched, return True
137
+ return True
138
+
139
+ def is_interval(self, expr):
140
+ # Checks if an expression is an interval
141
+ return expr.startswith(("(", "[")) and expr.endswith((")", "]"))
142
+
143
+ def is_two_greater_sign(self, expr):
144
+ match = re.findall(r'<', expr)
145
+ return len(match) == 2
146
+
147
+ def sympy_sub_pi(self, expression_sympy):
148
+ # Replaces the symbol for pi in sympy expressions with its numerical value
149
+ return expression_sympy.subs(self.pi, math.pi)
150
+
151
+ def is_equal(self, expression1, expression2):
152
+ # Default first expression is ground truth. Check if expressions are equal in different aspects
153
+ if expression1 == expression2 and expression1 != "" and expression2 != "":
154
+ # print("Equivalent natively")
155
+ return True
156
+
157
+ # First check if both are intervals
158
+ if self.is_interval(expression1) and self.is_interval(expression2):
159
+ try:
160
+ if self.interval_equal(expression1, expression2):
161
+ # print("Interval equivalent")
162
+ return True
163
+ except:
164
+ return False
165
+
166
+ # Then check for numerical equality
167
+ try:
168
+ if self.numerical_equal(expression1, expression2):
169
+ # print("Numerically equivalent")
170
+ return True
171
+ except:
172
+ pass
173
+ # Then check if expressions are mathematically equal
174
+ try:
175
+ if self.expression_equal(expression1, expression2) and not ("=" in expression1 and "=" in expression2):
176
+ # print("Expression equivalent")
177
+ return True
178
+ except:
179
+ pass
180
+
181
+ # Lastly, check for equation equality
182
+ try:
183
+ if self.equation_equal(expression1, expression2):
184
+ # print("Equation equivalent")
185
+ return True
186
+ except:
187
+ pass
188
+
189
+ return False
190
+
191
+ def numerical_equal(self, expression1: str, expression2: str, include_percentage: bool = True):
192
+ # Check if two numerical values are equal within an allowed error range
193
+ # Includes possible percentage cases
194
+ reference = float(expression1)
195
+ prediction = float(expression2)
196
+
197
+ if include_percentage:
198
+ gt_result = [reference / 100, reference, reference * 100]
199
+ else:
200
+ gt_result = [reference]
201
+
202
+ for item in gt_result:
203
+ if abs(item - prediction) <= self.precision * 1.01:
204
+ return True
205
+ return False
206
+
207
+ def expression_equal(self, exp1, exp2):
208
+ # Check if two expressions are mathematically equivalent
209
+ # Extract expression and use sympy for equivalence checking
210
+ def extract_expression(expression):
211
+ if "=" in expression:
212
+ expression = expression.split("=")[1]
213
+ return expression.strip()
214
+
215
+ exp1 = extract_expression(exp1)
216
+ exp2 = extract_expression(exp2)
217
+
218
+ exp_too_long = len(exp1) > 300 or len(exp2) > 300
219
+
220
+ expr1_sym = sympify(parse_latex(exp1))
221
+ expr2_sym = sympify(parse_latex(exp2))
222
+ if expr1_sym == expr2_sym:
223
+ return True
224
+ else:
225
+ expr1_sym = self.sympy_sub_pi(expr1_sym)
226
+ expr2_sym = self.sympy_sub_pi(expr2_sym)
227
+
228
+ if (expr1_sym.has(sp.Symbol) and not expr2_sym.has(sp.Symbol)) or \
229
+ (not expr1_sym.has(sp.Symbol) and expr2_sym.has(sp.Symbol)):
230
+ return False
231
+ elif not expr1_sym.has(sp.Symbol) and not expr2_sym.has(sp.Symbol):
232
+ try:
233
+ if not (self.can_compute_power(expr1_sym) and self.can_compute_power(expr2_sym)):
234
+ print("These two numbers cannot be calculated by the current computer for: "
235
+ f"\"{str(expr1_sym)}\" and \"{str(expr2_sym)}\"")
236
+ return False
237
+ if exp_too_long:
238
+ print(f'Expression {exp1} or {exp2} is too long to compute. ')
239
+ return False
240
+ if abs(expr1_sym.evalf() - expr2_sym.evalf()) <= self.precision * 1.01:
241
+ return True
242
+ else:
243
+ return False
244
+ except:
245
+ return False
246
+ elif exp_too_long:
247
+ print(f'Expression {exp1} or {exp2} is too long to compute. ')
248
+ return False
249
+ else:
250
+ try:
251
+ simplified_expr = simplify(expr1_sym - expr2_sym)
252
+ num_value = simplified_expr.evalf()
253
+ return abs(num_value) < 1e-3
254
+ except:
255
+ return False
256
+
257
+ def equation_equal(self, expression1, expression2):
258
+ # Check if two equations are mathematically equivalent
259
+ # Simplify equations and use sympy for equivalence checking
260
+ def simplify_equation(latex_eq):
261
+ lhs, rhs = latex_eq.split('=')
262
+
263
+ lhs_expr = parse_latex(lhs)
264
+ rhs_expr = parse_latex(rhs)
265
+
266
+ equation = Eq(lhs_expr, rhs_expr)
267
+
268
+ simplified_eq = simplify(equation.lhs - equation.rhs)
269
+
270
+ return simplified_eq
271
+
272
+ expr1_sym = simplify_equation(expression1)
273
+ expr2_sym = simplify_equation(expression2)
274
+
275
+ division_result_1 = simplify(expr1_sym / expr2_sym)
276
+ division_result_2 = simplify(expr2_sym / expr1_sym)
277
+
278
+ if ((division_result_1.is_Integer and division_result_1 != 0) or # noqa: W504
279
+ (division_result_2.is_Integer and division_result_2 != 0)):
280
+ return True
281
+ else:
282
+ return False
283
+
284
+ def interval_equal(self, expression1, expression2):
285
+ # Check if two intervals are mathematically equivalent
286
+ def compare_two_interval(inter1, inter2):
287
+ if inter1[0] != inter2[0] or inter1[-1] != inter2[-1]:
288
+ return False
289
+
290
+ inter1 = inter1.strip('[]()')
291
+ inter2 = inter2.strip('[]()')
292
+
293
+ items_1 = inter1.split(',')
294
+ items_2 = inter2.split(',')
295
+
296
+ for item_1, item_2 in zip(items_1, items_2):
297
+ if not self.expression_equal(item_1, item_2):
298
+ return False
299
+ return True
300
+
301
+ interval1 = expression1
302
+ interval2 = expression2
303
+
304
+ if interval1 == interval2:
305
+ return True
306
+ else:
307
+ inter_list1 = interval1.split("\\cup")
308
+ inter_list2 = interval2.split("\\cup")
309
+
310
+ if len(inter_list1) != len(inter_list2):
311
+ return False
312
+ else:
313
+ for inter1, inter2 in zip(inter_list1, inter_list2):
314
+ if not compare_two_interval(inter1, inter2):
315
+ return False
316
+ return True
317
+
318
+ def preprocess(self, expression1, expression2):
319
+ # Preprocess expressions to extract and replace special symbols
320
+ def extract_boxed_content(latex_str):
321
+ boxed_matches = re.finditer(r'\\boxed{', latex_str)
322
+ results = ""
323
+
324
+ for match in boxed_matches:
325
+ start_index = match.end()
326
+ end_index = start_index
327
+ stack = 1
328
+
329
+ while stack > 0 and end_index < len(latex_str):
330
+ if latex_str[end_index] == '{':
331
+ stack += 1
332
+ elif latex_str[end_index] == '}':
333
+ stack -= 1
334
+ end_index += 1
335
+
336
+ if stack == 0:
337
+ content = latex_str[start_index:end_index - 1]
338
+ results += content + ","
339
+ else:
340
+ raise ValueError("Mismatched braces in LaTeX string.")
341
+
342
+ if results == "":
343
+ last_line_ans = latex_str.strip().split("\n")[-1]
344
+ dollar_pattern = r"\$(.*?)\$"
345
+ answers = re.findall(dollar_pattern, last_line_ans)
346
+
347
+ if answers:
348
+ for ans in answers:
349
+ results += ans + ","
350
+ else:
351
+ results = latex_str
352
+
353
+ return results
354
+
355
+ def special_symbol_replace(expression):
356
+
357
+ expression = expression.replace("\\text{cm}^2", '').replace("\\text{cm}", "").replace("\\,cm", '').replace("\\text{ cm}", '').replace("cm", '').replace("\\text{分米}^2", '').replace("cm^{2}", '').replace("60 \\text{ cm}^2",'').replace("\\ \\text{m}", "").replace("\\text{米}","").strip() # noqa: E501
358
+
359
+ expression = re.sub(r"(.+)m$", r"\1", expression)
360
+
361
+ if "\\in " in expression:
362
+ expression = expression.split("\\in ")[1]
363
+
364
+ for signal in self.special_signal_map:
365
+ expression = expression.replace(signal, self.special_signal_map[signal])
366
+
367
+ expression = re.sub(r'(\\sin|\\cos|\\tan)(\d+)', r'\1((\2/180)\\pi)', expression)
368
+
369
+ expression = expression.strip("\n,.:;^_=+`!@#%^&*~,。")
370
+
371
+ pattern = r'\\(?:mathrm|mathbf)\{~?([^}]*)\}'
372
+ expression = re.sub(pattern, r'\1', expression)
373
+
374
+ return expression
375
+
376
+ exp1, exp2 = extract_boxed_content(expression1), extract_boxed_content(expression2)
377
+
378
+ exp1, exp2 = special_symbol_replace(exp1), special_symbol_replace(exp2)
379
+
380
+ return exp1, exp2
381
+
382
+ def can_compute_power(self, expr):
383
+ # Checks if a power expression can be computed
384
+ if isinstance(expr, Pow):
385
+ base, exp = expr.as_base_exp()
386
+ if base.is_number and exp.is_number:
387
+ MAX_EXP = 1000 # Adjust based on computing environment
388
+ if abs(exp.evalf()) > MAX_EXP:
389
+ return False
390
+ else:
391
+ return True
392
+ else:
393
+ return False
394
+ else:
395
+ return True # Not a power expression, can compute
396
+
397
+
398
+ class MMMath(ImageBaseDataset):
399
+
400
+ TYPE = 'VQA'
401
+
402
+ DATASET_URL = {
403
+ 'MM-Math': 'https://opencompass.openxlab.space/utils/VLMEval/MM-Math.tsv',
404
+ }
405
+ DATASET_MD5 = {
406
+ 'MM-Math': '1f064ed7c4e0e8926a3fa65849419ca5',
407
+ }
408
+
409
+ @classmethod
410
+ def evaluate(self, eval_file, **kwargs):
411
+
412
+ data = load(eval_file)
413
+ judger = AutoScoringJudge()
414
+ func = judger.judge
415
+
416
+ tups = [dict(expression1=x, expression2=y) for x, y in zip(data['answer'], data['prediction'])]
417
+
418
+ res = track_progress_rich(func, tups, nproc=16)
419
+ data['hit'] = res
420
+ dump(data, eval_file)
421
+
422
+ score_file = eval_file.replace('.xlsx', '_score.json')
423
+ score = {}
424
+ score['overall'] = np.mean(data['hit'])
425
+ # Results by Difficulty
426
+ difficulties = set(data['difficulty'])
427
+ for d in difficulties:
428
+ score[f'Difficulty-{d}'] = np.mean(data[data['difficulty'] == d]['hit'])
429
+
430
+ # Results by Year
431
+ years = set(data['year'])
432
+ for y in years:
433
+ score[f'Year-{y}'] = np.mean(data[data['year'] == y]['hit'])
434
+
435
+ # Results by Knowledge-L1
436
+ points = set(data['knowledge_l1'])
437
+ for p in points:
438
+ score[f'Knowledge-L1-{p}'] = np.mean(data[data['knowledge_l1'] == p]['hit'])
439
+
440
+ # Results by Knowledge-L2
441
+ points = set(data['knowledge_l2'])
442
+ for p in points:
443
+ score[f'Knowledge-L2-{p}'] = np.mean(data[data['knowledge_l2'] == p]['hit'])
444
+
445
+ dump(score, score_file)
446
+ return score
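A brief usage sketch of the `AutoScoringJudge` defined above. The inputs are illustrative, and `parse_latex` needs sympy's antlr4-based LaTeX parser to be installed for the symbolic comparisons to run:

from vlmeval.dataset.mmmath import AutoScoringJudge

judge = AutoScoringJudge()
# Ground truth first, prediction second; \boxed{...} answers are extracted automatically.
print(judge.judge('\\boxed{\\frac{1}{2}}', '0.5'))   # numeric match within 1e-2 -> True
print(judge.judge('(1, 3)', '(1,3)'))                # interval comparison -> True
print(judge.judge('x^2 - 1', '(x-1)(x+1)'))          # symbolic equivalence via sympy -> True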
VLMEvalKit/vlmeval/dataset/mvbench.py ADDED
@@ -0,0 +1,668 @@
1
+ import huggingface_hub
2
+ from huggingface_hub import snapshot_download
3
+ from ..smp import *
4
+ from .video_base import VideoBaseDataset
5
+ from .utils import build_judge, DEBUG_MESSAGE
6
+ from ..utils import track_progress_rich
7
+ import torchvision.transforms as T
8
+ from torchvision import transforms
9
+ from torchvision.transforms.functional import InterpolationMode
10
+ from decord import VideoReader, cpu
11
+ import imageio
12
+ import cv2
13
+ import zipfile
14
+ import os
15
+ import glob
16
+ from .utils.mvbench import *
17
+
18
+ FAIL_MSG = 'Failed to obtain answer via API.'
19
+
20
+
21
+ class MVBench(VideoBaseDataset):
22
+
23
+ MD5 = 'fd21d36522cdedd46d84dc46715ad832'
24
+ SYS = """Carefully watch the video and pay attention to the cause and sequence of events, \
25
+ the detail and movement of objects, and the action and pose of persons. \
26
+ Based on your observations, select the best option that accurately addresses the question.
27
+ """
28
+
29
+ TYPE = 'Video-MCQ'
30
+
31
+ def __init__(self, dataset='MVBench', pack=False):
32
+ self.type_data_list = {
33
+ 'Action Sequence': ('action_sequence.json',
34
+ 'your_data_path/star/Charades_v1_480/', 'video', True), # has start & end
35
+ 'Action Prediction': ('action_prediction.json',
36
+ 'your_data_path/star/Charades_v1_480/', 'video', True), # has start & end
37
+ 'Action Antonym': ('action_antonym.json',
38
+ 'your_data_path/ssv2_video/', 'video', False),
39
+ 'Fine-grained Action': ('fine_grained_action.json',
40
+ 'your_data_path/Moments_in_Time_Raw/videos/', 'video', False),
41
+ 'Unexpected Action': ('unexpected_action.json',
42
+ 'your_data_path/FunQA_test/test/', 'video', False),
43
+ 'Object Existence': ('object_existence.json',
44
+ 'your_data_path/clevrer/video_validation/', 'video', False),
45
+ 'Object Interaction': ('object_interaction.json',
46
+ 'your_data_path/star/Charades_v1_480/', 'video', True), # has start & end
47
+ 'Object Shuffle': ('object_shuffle.json',
48
+ 'your_data_path/perception/videos/', 'video', False),
49
+ 'Moving Direction': ('moving_direction.json',
50
+ 'your_data_path/clevrer/video_validation/', 'video', False),
51
+ 'Action Localization': ('action_localization.json',
52
+ 'your_data_path/sta/sta_video/', 'video', True), # has start & end
53
+ 'Scene Transition': ('scene_transition.json',
54
+ 'your_data_path/scene_qa/video/', 'video', False),
55
+ 'Action Count': ('action_count.json',
56
+ 'your_data_path/perception/videos/', 'video', False),
57
+ 'Moving Count': ('moving_count.json',
58
+ 'your_data_path/clevrer/video_validation/', 'video', False),
59
+ 'Moving Attribute': ('moving_attribute.json',
60
+ 'your_data_path/clevrer/video_validation/', 'video', False),
61
+ 'State Change': ('state_change.json',
62
+ 'your_data_path/perception/videos/', 'video', False),
63
+ 'Fine-grained Pose': ('fine_grained_pose.json',
64
+ 'your_data_path/nturgbd/', 'video', False),
65
+ 'Character Order': ('character_order.json',
66
+ 'your_data_path/perception/videos/', 'video', False),
67
+ 'Egocentric Navigation': ('egocentric_navigation.json',
68
+ 'your_data_path/vlnqa/', 'video', False),
69
+ 'Episodic Reasoning': ('episodic_reasoning.json',
70
+ 'your_data_path/tvqa/frames_fps3_hq/', 'frame', True), # has start & end, read frame
71
+ 'Counterfactual Inference': ('counterfactual_inference.json',
72
+ 'your_data_path/clevrer/video_validation/', 'video', False),
73
+ }
74
+ super().__init__(dataset=dataset, pack=pack)
75
+
76
+ @classmethod
77
+ def supported_datasets(cls):
78
+ return ['MVBench']
79
+
80
+ def prepare_dataset(self, dataset_name='MVBench', repo_id='OpenGVLab/MVBench'):
81
+ def check_integrity(pth):
82
+ data_file = osp.join(pth, f'{dataset_name}.tsv')
83
+
84
+ if not os.path.exists(data_file):
85
+ return False
86
+
87
+ if md5(data_file) != self.MD5:
88
+ return False
89
+
90
+ data = load(data_file)
91
+ for idx, item in data.iterrows():
92
+ if not osp.exists(osp.join(pth, item['prefix'], item['video'])):
93
+ return False
94
+ return True
95
+
96
+ if modelscope_flag_set():
97
+ repo_id = 'modelscope/MVBench'
98
+
99
+ cache_path = get_cache_path(repo_id, branch='main')
100
+ if cache_path is not None and check_integrity(cache_path):
101
+ dataset_path = cache_path
102
+ else:
103
+ def unzip_hf_zip(pth):
104
+ pth = os.path.join(pth, 'video/')
105
+ for filename in os.listdir(pth):
106
+ if filename.endswith('.zip'):
107
+ # Build the complete file path
108
+ zip_path = os.path.join(pth, filename)
109
+
110
+ # Extract the ZIP file
111
+ with zipfile.ZipFile(zip_path, 'r') as zip_ref:
112
+ zip_ref.extractall(pth)
113
+
114
+ def generate_tsv(pth):
115
+ data_file = osp.join(pth, f'{dataset_name}.tsv')
116
+ if os.path.exists(data_file) and md5(data_file) == self.MD5:
117
+ return
118
+ json_data_dir = os.path.join(pth, 'json')
119
+ self.data_list = []
120
+ for k, v in self.type_data_list.items():
121
+ with open(os.path.join(json_data_dir, v[0]), 'r') as f:
122
+ json_data = json.load(f)
123
+ for data in json_data:
124
+ if os.path.exists(os.path.join(pth, v[1].replace('your_data_path', 'video'), data['video'])):
125
+ self.data_list.append({
126
+ 'task_type': k,
127
+ 'prefix': v[1].replace('your_data_path', 'video'),
128
+ 'data_type': v[2],
129
+ 'bound': v[3],
130
+ 'start': data['start'] if 'start' in data.keys() else None,
131
+ 'end': data['end'] if 'end' in data.keys() else None,
132
+ 'video': data['video'],
133
+ 'question': data['question'],
134
+ 'answer': data['answer'],
135
+ 'candidates': data['candidates']
136
+ })
137
+ else:
138
+ print(
139
+ 'The NTURGB-D zip file has been removed from MVBench; see '
140
+ 'https://huggingface.co/datasets/OpenGVLab/MVBench for the detailed reason.'
141
+ )
142
+ raise Exception(
143
+ f"{os.path.join(v[1].replace('your_data_path', 'video'), data['video'])} does not exist"
144
+ )
145
+
146
+ data_df = pd.DataFrame(self.data_list)
147
+ data_df = data_df.assign(index=range(len(data_df)))
148
+ data_df.to_csv(data_file, sep='\t', index=False)
149
+
150
+ def move_files(pth):
151
+ src_folder = os.path.join(pth, 'video/data0613')
152
+ if not os.path.exists(src_folder):
153
+ return
154
+ for subdir in os.listdir(src_folder):
155
+ subdir_path = os.path.join(src_folder, subdir)
156
+ if os.path.isdir(subdir_path):
157
+ for subsubdir in os.listdir(subdir_path):
158
+ subsubdir_path = os.path.join(subdir_path, subsubdir)
159
+ if os.path.isdir(subsubdir_path):
160
+ for item in os.listdir(subsubdir_path):
161
+ item_path = os.path.join(subsubdir_path, item)
162
+ target_folder = os.path.join(pth, 'video', subdir, subsubdir)
163
+ if not os.path.exists(target_folder):
164
+ os.makedirs(target_folder)
165
+ target_path = os.path.join(target_folder, item)
166
+ try:
167
+ shutil.move(item_path, target_path)
168
+ except Exception as e:
169
+ print(f"Error moving {item_path} to {target_path}: {e}")
170
+
171
+ if modelscope_flag_set():
172
+ from modelscope import dataset_snapshot_download
173
+ dataset_path = dataset_snapshot_download(dataset_id=repo_id, revision='master')
174
+ else:
175
+ hf_token = os.environ.get('HUGGINGFACE_TOKEN')
176
+ huggingface_hub.login(hf_token)
177
+ dataset_path = snapshot_download(repo_id=repo_id, repo_type='dataset')
178
+ unzip_hf_zip(dataset_path)
179
+ move_files(dataset_path)
180
+ generate_tsv(dataset_path)
181
+
182
+ data_file = osp.join(dataset_path, f'{dataset_name}.tsv')
183
+
184
+ self.decord_method = {
185
+ 'video': self.read_video,
186
+ 'gif': self.read_gif,
187
+ 'frame': self.read_frame,
188
+ }
189
+
190
+ self.nframe = 8
191
+ self.frame_fps = 3
192
+
193
+ # transform
194
+ self.transform = T.Compose([
195
+ Stack(),
196
+ ToTorchFormatTensor()
197
+ ])
198
+
199
+ return dict(root=dataset_path, data_file=data_file)
200
+
201
+ def get_index(self, bound, fps, max_frame, first_idx=0):
202
+ if bound:
203
+ start, end = bound[0], bound[1]
204
+ else:
205
+ start, end = -100000, 100000
206
+ start_idx = max(first_idx, round(start * fps))
207
+ end_idx = min(round(end * fps), max_frame)
208
+ seg_size = float(end_idx - start_idx) / self.num_segments
209
+ frame_indices = np.array([
210
+ int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
211
+ for idx in range(self.num_segments)
212
+ ])
213
+ return frame_indices
214
+
215
+ def read_video(self, video_path, bound=None):
216
+ vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
217
+ max_frame = len(vr) - 1
218
+ fps = float(vr.get_avg_fps())
219
+
220
+ images_group = list()
221
+ frame_indices = self.get_index(bound, fps, max_frame, first_idx=0)
222
+ for frame_index in frame_indices:
223
+ img = Image.fromarray(vr[frame_index].asnumpy())
224
+ images_group.append(img)
225
+ torch_imgs = self.transform(images_group)
226
+ return torch_imgs
227
+
228
+ def read_gif(self, video_path, bound=None, fps=25):
229
+ gif = imageio.get_reader(video_path)
230
+ max_frame = len(gif) - 1
231
+
232
+ images_group = list()
233
+ frame_indices = self.get_index(bound, fps, max_frame, first_idx=0)
234
+ for index, frame in enumerate(gif):
235
+ if index in frame_indices:
236
+ img = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
237
+ img = Image.fromarray(img)
238
+ images_group.append(img)
239
+ torch_imgs = self.transform(images_group)
240
+ return torch_imgs
241
+
242
+ def read_frame(self, video_path, bound=None, fps=3):
243
+ max_frame = len(os.listdir(video_path))
244
+ images_group = list()
245
+ frame_indices = self.get_index(bound, fps, max_frame, first_idx=1) # frame_idx starts from 1
246
+ for frame_index in frame_indices:
247
+ img = Image.open(os.path.join(video_path, f'{frame_index:05d}.jpg'))
248
+ images_group.append(img)
249
+ torch_imgs = self.transform(images_group)
250
+ return torch_imgs
251
+
252
+ def save_video_frames(self, imgs, video_name, frames):
253
+
254
+ frame_paths = self.frame_paths(video_name, frames)
255
+ flag = np.all([osp.exists(p) for p in frame_paths])
256
+
257
+ if not flag:
258
+ block_size = imgs.size(0) // frames
259
+ split_tensors = torch.split(imgs, block_size)
260
+ to_pil = transforms.ToPILImage()
261
+ images = [to_pil(arr) for arr in split_tensors]
262
+ for im, pth in zip(images, frame_paths):
263
+ if not osp.exists(pth):
264
+ im.save(pth)
265
+
266
+ return frame_paths
267
+
268
+ def qa_template(self, data):
269
+ question = f"Question: {data['question']}\n"
270
+ question += 'Options:\n'
271
+ answer = data['answer']
272
+ answer_idx = -1
273
+ for idx, c in enumerate(eval(data['candidates'])):
274
+ question += f"({chr(ord('A') + idx)}) {c}\n"
275
+ if c == answer:
276
+ answer_idx = idx
277
+ question = question.rstrip()
278
+ answer = f"({chr(ord('A') + answer_idx)}) {answer}"
279
+ return question, answer
280
+
281
+ def load_into_video_and_process(self, line):
282
+ try:
283
+ from moviepy.editor import VideoFileClip, ImageSequenceClip
284
+ except:
285
+ raise ImportError(
286
+ 'MoviePy is not installed, please install it by running "pip install moviepy==1.0.3"'
287
+ )
288
+ video_path = os.path.join(self.data_root, line['prefix'], line['video'])
289
+
290
+ if line['data_type'] in ['gif'] or os.path.splitext(video_path)[1] in ['.webm']:
291
+ processed_video_path = video_path.replace(os.path.splitext(video_path)[1], '.mp4')
292
+ if not os.path.exists(processed_video_path):
293
+ # using MoviePy to transform GIF, webm into mp4 format
294
+ gif_clip = VideoFileClip(video_path)
295
+ gif_clip.write_videofile(processed_video_path, codec='libx264')
296
+ gif_clip.close()
297
+ elif line['data_type'] in ['frame']:
298
+ input_images = os.path.join(video_path, '*.jpg')
299
+ processed_video_path = f'{video_path}.mp4'
300
+ if not os.path.exists(processed_video_path):
301
+ # using MoviePy to transform images into mp4
302
+ image_files = sorted(glob.glob(input_images))
303
+ image_clip = ImageSequenceClip(image_files, fps=self.frame_fps)
304
+ image_clip.write_videofile(processed_video_path, codec='libx264')
305
+ image_clip.close()
306
+ else:
307
+ processed_video_path = video_path
308
+
309
+ if line['bound']:
310
+ base_name, suffix = os.path.splitext(processed_video_path)
311
+ output_video_path = f'{base_name}_processed{suffix}'
312
+ if not os.path.exists(output_video_path):
313
+ video_clip = VideoFileClip(processed_video_path)
314
+ clip = video_clip.subclip(line['start'], min(line['end'], video_clip.duration))
315
+ clip.write_videofile(output_video_path)
316
+ clip.close()
317
+ else:
318
+ output_video_path = processed_video_path
319
+
320
+ return output_video_path
321
+
322
+ def save_video_into_images(self, line, num_frames):
323
+ bound = None
324
+ if line['bound']:
325
+ bound = (
326
+ line['start'],
327
+ line['end'],
328
+ )
329
+ video_path = os.path.join(self.data_root, line['prefix'], line['video'])
330
+ decord_method = self.decord_method[line['data_type']]
331
+ self.num_segments = num_frames if num_frames > 0 else self.nframe
332
+ torch_imgs = decord_method(video_path, bound)
333
+ img_frame_paths = self.save_video_frames(torch_imgs, line['video'], self.num_segments)
334
+ return img_frame_paths
335
+
336
+ def build_prompt(self, line, num_frames, video_llm, fps):
337
+ if fps > 0:
338
+ raise ValueError('MVBench does not support fps setting, please transfer to MVBench_MP4!')
339
+ if isinstance(line, int):
340
+ assert line < len(self)
341
+ line = self.data.iloc[line]
342
+
343
+ question, answer = self.qa_template(line)
344
+ message = [dict(type='text', value=self.SYS, role='system')]
345
+ message.append(dict(type='text', value=question))
346
+ if video_llm:
347
+ new_video_path = self.load_into_video_and_process(line)
348
+ message.append(dict(type='video', value=new_video_path))
349
+ else:
350
+ img_frame_paths = self.save_video_into_images(line, num_frames)
351
+ for im in img_frame_paths:
352
+ message.append(dict(type='image', value=im))
353
+ message.append(dict(type='text', value='\nOnly give the best option.'))
354
+ message.append(dict(type='text', value='Best option:(', role='assistant'))
355
+ return message
356
+
357
+ @classmethod
358
+ def evaluate(self, eval_file, **judge_kwargs):
359
+
360
+ assert eval_file.endswith('.xlsx'), 'data file should be an xlsx file'
361
+
362
+ tmp_file = eval_file.replace('.xlsx', '_tmp.pkl')
363
+ tgt_file = eval_file.replace('.xlsx', '_rating.json')
364
+ score_file = eval_file.replace('.xlsx', '_score.xlsx')
365
+
366
+ if not osp.exists(score_file):
367
+ model = judge_kwargs.setdefault('model', 'chatgpt-0125')
368
+ assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']
369
+
370
+ if model == 'exact_matching':
371
+ model = None
372
+ elif gpt_key_set():
373
+ model = build_judge(**judge_kwargs)
374
+ if not model.working():
375
+ warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
376
+ warnings.warn(DEBUG_MESSAGE)
377
+ model = None
378
+ else:
379
+ warnings.warn('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
380
+ model = None
381
+ res = {} if not osp.exists(tmp_file) else load(tmp_file)
382
+ res = {k: v for k, v in res.items() if FAIL_MSG not in v}
383
+
384
+ data = load(eval_file)
385
+ data_un = data[~pd.isna(data['prediction'])]
386
+
387
+ for idx in data_un['index']:
388
+ ans = data.loc[data['index'] == idx, 'answer'].values[0]
389
+ pred = data.loc[data['index'] == idx, 'prediction'].values[0]
390
+ options = eval(data.loc[data['index'] == idx, 'candidates'].values[0])
391
+ answer_idx = -1
392
+ for id, c in enumerate(options):
393
+ if c == ans:
394
+ answer_idx = id
395
+ ans = f"({chr(ord('A') + answer_idx)}) {ans}"
396
+ input_item = data.loc[data['index'] == idx].to_dict(orient='records')[0]
397
+ for id, option_content in enumerate(eval(input_item['candidates'])):
398
+ input_item[chr(ord('A') + id)] = option_content
399
+ if option_content == input_item['answer']:
400
+ input_item['answer'] = chr(ord('A') + id)
401
+
402
+ if FAIL_MSG in pred:
403
+ data.loc[idx, 'score'] = -1
404
+ else:
405
+ data.loc[idx, 'score'] = int(check_ans_with_model(
406
+ pred, ans, model,
407
+ input_item,
408
+ 'MVBench'
409
+ ))
410
+
411
+ rejected = [x for x in data['score'] if x == -1]
412
+
413
+ print(
414
+ f'Among {len(data)} questions, failed to obtain prediction for {len(data) - len(data_un)} questions, '
415
+ f'failed to obtain the score for another {len(rejected)} questions. '
416
+ f'Those questions will be counted as -1 score in ALL rating, and will not be counted in VALID rating.'
417
+ )
418
+
419
+ dump(data, score_file)
420
+
421
+ rating = get_dimension_rating(score_file)
422
+ dump(rating, tgt_file)
423
+ return rating
424
+
425
+
426
+ class MVBench_MP4(VideoBaseDataset):
427
+
428
+ MP4_MD5 = '5c8c6f8b7972c2de65a629590f7c42f5'
429
+ SYS = """Carefully watch the video and pay attention to the cause and sequence of events, \
430
+ the detail and movement of objects, and the action and pose of persons. \
431
+ Based on your observations, select the best option that accurately addresses the question.
432
+ """
433
+ TYPE = 'Video-MCQ'
434
+
435
+ def __init__(self, dataset='MVBench_MP4', pack=False):
436
+ super().__init__(dataset=dataset, pack=pack)
437
+
438
+ @classmethod
439
+ def supported_datasets(cls):
440
+ return ['MVBench_MP4']
441
+
442
+ def prepare_dataset(self, dataset_name='MVBench_MP4', repo_id='OpenGVLab/MVBench'):
443
+ def check_integrity(pth):
444
+ data_file = osp.join(pth, f'{dataset_name}.tsv')
445
+
446
+ if not os.path.exists(data_file):
447
+ return False
448
+
449
+ if md5(data_file) != self.MP4_MD5:
450
+ return False
451
+
452
+ data = load(data_file)
453
+ for idx, item in data.iterrows():
454
+ if not osp.exists(osp.join(pth, item['prefix'], item['video'])):
455
+ return False
456
+ return True
457
+
458
+ if modelscope_flag_set():
459
+ repo_id = 'modelscope/MVBench'
460
+
461
+ cache_path = get_cache_path(repo_id, branch='video')
462
+ if cache_path is not None and check_integrity(cache_path):
463
+ dataset_path = cache_path
464
+ else:
465
+ def generate_tsv(pth):
466
+ data_file = osp.join(pth, f'{dataset_name}.tsv')
467
+ if os.path.exists(data_file) and md5(data_file) == self.MP4_MD5:
468
+ return
469
+ json_data_path = os.path.join(dataset_path, 'test.json')
470
+ json_data = load(json_data_path)
471
+ root_data_dict = json_data['root']
472
+ self.data_list = []
473
+ for k, v in json_data['meta'].items():
474
+ for item in v:
475
+ self.data_list.append({
476
+ 'task_type': k,
477
+ 'prefix': root_data_dict[k],
478
+ 'video': item['video'],
479
+ 'question': item['question'],
480
+ 'answer': item['answer'],
481
+ 'candidates': item['candidates']
482
+ })
483
+ data_df = pd.DataFrame(self.data_list)
484
+ data_df = data_df.assign(index=range(len(data_df)))
485
+ data_df.to_csv(data_file, sep='\t', index=False)
486
+
487
+ if modelscope_flag_set():
488
+ from modelscope import dataset_snapshot_download
489
+ dataset_path = dataset_snapshot_download(dataset_id=repo_id, revision='video')
490
+ else:
491
+ hf_token = os.environ.get('HUGGINGFACE_TOKEN')
492
+ huggingface_hub.login(hf_token)
493
+ dataset_path = snapshot_download(repo_id=repo_id, repo_type='dataset', revision='video')
494
+ generate_tsv(dataset_path)
495
+
496
+ data_file = osp.join(dataset_path, f'{dataset_name}.tsv')
497
+
498
+ self.nframe = 8
499
+
500
+ # transform
501
+ self.transform = T.Compose([
502
+ Stack(),
503
+ ToTorchFormatTensor()
504
+ ])
505
+
506
+ return dict(root=dataset_path, data_file=data_file)
507
+
508
+ def qa_template(self, data):
509
+ question = f"Question: {data['question']}\n"
510
+ question += 'Options:\n'
511
+ answer = data['answer']
512
+ answer_idx = -1
513
+ for idx, c in enumerate(eval(data['candidates'])):
514
+ question += f"({chr(ord('A') + idx)}) {c}\n"
515
+ if c == answer:
516
+ answer_idx = idx
517
+ question = question.rstrip()
518
+ answer = f"({chr(ord('A') + answer_idx)}) {answer}"
519
+ return question, answer
520
+
521
+ def get_index_by_frame(self, max_frame):
522
+ seg_size = float(max_frame) / self.num_segments
523
+ frame_indices = np.array([
524
+ int((seg_size / 2) + np.round(seg_size * idx))
525
+ for idx in range(self.num_segments)
526
+ ])
527
+ return frame_indices
528
+
529
+ def get_index_by_fps(self, vid, fps):
530
+ total_frames = len(vid)
531
+ video_fps = vid.get_avg_fps()
532
+ total_duration = total_frames / video_fps
533
+ required_frames = int(total_duration * fps)
534
+ step_size = video_fps / fps
535
+ frame_indices = np.array([int(i * step_size) for i in range(required_frames)])
536
+ self.num_segments = len(frame_indices)
537
+ return frame_indices
538
+
539
+ def read_video(self, video_path, fps=-1):
540
+ vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
541
+ max_frame = len(vr) - 1
542
+
543
+ images_group = list()
544
+ if fps < 0:
545
+ frame_indices = self.get_index_by_frame(max_frame)
546
+ else:
547
+ frame_indices = self.get_index_by_fps(vr, fps)
548
+
549
+ for frame_index in frame_indices:
550
+ img = Image.fromarray(vr[frame_index].asnumpy())
551
+ images_group.append(img)
552
+ torch_imgs = self.transform(images_group)
553
+ return torch_imgs
554
+
555
+ def save_video_frames(self, imgs, video_name, frames, fps):
556
+ if fps > 0:
557
+ frame_paths = self.frame_paths_fps(video_name, frames, fps)
558
+ else:
559
+ frame_paths = self.frame_paths(video_name, frames)
560
+ flag = np.all([osp.exists(p) for p in frame_paths])
561
+
562
+ if not flag:
563
+ block_size = imgs.size(0) // frames
564
+ split_tensors = torch.split(imgs, block_size)
565
+ to_pil = transforms.ToPILImage()
566
+ images = [to_pil(arr) for arr in split_tensors]
567
+ for im, pth in zip(images, frame_paths):
568
+ if not osp.exists(pth):
569
+ im.save(pth)
570
+
571
+ return frame_paths
572
+
573
+ def save_video_into_images(self, line, num_frames, fps=-1):
574
+ video_path = os.path.join(self.data_root, line['prefix'], line['video'])
575
+ if fps <= 0:
576
+ self.num_segments = num_frames if num_frames > 0 else self.nframe
577
+ else:
578
+ self.num_segments = 0
579
+ torch_imgs = self.read_video(video_path, fps)
580
+ img_frame_paths = self.save_video_frames(torch_imgs, line['video'], self.num_segments, fps)
581
+ return img_frame_paths
582
+
583
+ def build_prompt(self, line, num_frames, video_llm, fps):
584
+ if isinstance(line, int):
585
+ assert line < len(self)
586
+ line = self.data.iloc[line]
587
+
588
+ question, answer = self.qa_template(line)
589
+ message = [dict(type='text', value=self.SYS, role='system')]
590
+ message.append(dict(type='text', value=question))
591
+ video_path = os.path.join(self.data_root, line['prefix'], line['video'])
592
+ if video_llm:
593
+ message.append(dict(type='video', value=video_path))
594
+ else:
595
+ img_frame_paths = self.save_video_into_images(line, num_frames, fps)
596
+ for im in img_frame_paths:
597
+ message.append(dict(type='image', value=im))
598
+ message.append(dict(type='text', value='\nOnly give the best option.'))
599
+ message.append(dict(type='text', value='Best option:(', role='assistant'))
600
+ return message
601
+
602
+ @classmethod
603
+ def evaluate(self, eval_file, **judge_kwargs):
604
+
605
+ assert eval_file.endswith('.xlsx'), 'data file should be an xlsx file'
606
+
607
+ tmp_file = eval_file.replace('.xlsx', '_tmp.pkl')
608
+ tgt_file = eval_file.replace('.xlsx', '_rating.json')
609
+ score_file = eval_file.replace('.xlsx', '_score.xlsx')
610
+
611
+ if not osp.exists(score_file):
612
+ model = judge_kwargs.setdefault('model', 'chatgpt-0125')
613
+ assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']
614
+
615
+ if model == 'exact_matching':
616
+ model = None
617
+ elif gpt_key_set():
618
+ model = build_judge(**judge_kwargs)
619
+ if not model.working():
620
+ warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
621
+ warnings.warn(DEBUG_MESSAGE)
622
+ model = None
623
+ else:
624
+ warnings.warn('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
625
+ model = None
626
+ res = {} if not osp.exists(tmp_file) else load(tmp_file)
627
+ res = {k: v for k, v in res.items() if FAIL_MSG not in v}
628
+
629
+ data = load(eval_file)
630
+ data_un = data[~pd.isna(data['prediction'])]
631
+
632
+ for idx in data_un['index']:
633
+ ans = data.loc[data['index'] == idx, 'answer'].values[0]
634
+ pred = data.loc[data['index'] == idx, 'prediction'].values[0]
635
+ options = eval(data.loc[data['index'] == idx, 'candidates'].values[0])
636
+ answer_idx = -1
637
+ for id, c in enumerate(options):
638
+ if c == ans:
639
+ answer_idx = id
640
+ ans = f"({chr(ord('A') + answer_idx)}) {ans}"
641
+ input_item = data.loc[data['index'] == idx].to_dict(orient='records')[0]
642
+ for id, option_content in enumerate(eval(input_item['candidates'])):
643
+ input_item[chr(ord('A') + id)] = option_content
644
+ if option_content == input_item['answer']:
645
+ input_item['answer'] = chr(ord('A') + id)
646
+
647
+ if FAIL_MSG in pred:
648
+ data.loc[idx, 'score'] = -1
649
+ else:
650
+ data.loc[idx, 'score'] = int(check_ans_with_model(
651
+ pred, ans, model,
652
+ input_item,
653
+ 'MVBench_MP4'
654
+ ))
655
+
656
+ rejected = [x for x in data['score'] if x == -1]
657
+
658
+ print(
659
+ f'Among {len(data)} questions, failed to obtain prediction for {len(data) - len(data_un)} questions, '
660
+ f'failed to obtain the score for another {len(rejected)} questions. '
661
+ f'Those questions will be counted as -1 score in ALL rating, and will not be counted in VALID rating.'
662
+ )
663
+
664
+ dump(data, score_file)
665
+
666
+ rating = get_dimension_rating(score_file)
667
+ dump(rating, tgt_file)
668
+ return rating
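`get_index` above picks one frame from the middle of each of `num_segments` equally sized spans, optionally restricted to the annotated `[start, end]` window. A standalone sketch of that sampling rule with illustrative numbers (function and argument names are hypothetical):

import numpy as np

def uniform_frame_indices(n_frames, num_segments, fps=30.0, bound=None):
    # Restrict sampling to the annotated [start, end] window when one is given.
    start, end = bound if bound else (-1e9, 1e9)
    start_idx = max(0, round(start * fps))
    end_idx = min(round(end * fps), n_frames - 1)
    seg_size = float(end_idx - start_idx) / num_segments
    # One frame from the centre of each segment.
    return np.array([
        int(start_idx + seg_size / 2 + np.round(seg_size * i))
        for i in range(num_segments)
    ])

print(uniform_frame_indices(300, 8))                     # 8 indices spread over the whole clip
print(uniform_frame_indices(300, 8, bound=(2.0, 6.0)))   # 8 indices between frames 60 and 180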
VLMEvalKit/vlmeval/dataset/slidevqa.py ADDED
@@ -0,0 +1,189 @@
1
+ import re
2
+ import math
3
+ from typing import List
4
+
5
+ from vlmeval.dataset.utils.judge_util import build_judge
6
+ from vlmeval.smp import *
7
+ from .image_base import ImageBaseDataset
8
+ from .mmlongbench import concat_images, MMLongBench_auxeval, anls_compute
9
+
10
+
11
+ FAIL_MSG = 'Failed to obtain answer via API.'
12
+
13
+
14
+ def get_f1(gt, pred):
15
+ gt_bow, pred_bow = gt.strip().split(), pred.strip().split()
16
+ if not gt_bow or not pred_bow:
17
+ return 0.0
18
+
19
+ recall = len([pred_e for pred_e in pred_bow if pred_e in gt_bow]) / len(gt_bow)
20
+ precision = len([pred_e for pred_e in pred_bow if pred_e in gt_bow]) / len(pred_bow)
21
+ f1 = 2 * recall * precision / (recall + precision) if (recall + precision) > 1e-4 else 0.0
22
+ return f1
23
+
24
+
25
+ def SlideVQA_acc(result_file):
26
+ data = load(result_file)
27
+ anls_list, em_list, f1_list = list(), list(), list()
28
+ for i in range(len(data)):
29
+ item = data.iloc[i]
30
+ if isinstance(item['answer'], float) and math.isnan(item['answer']):
31
+ item['answer'] = 'Not answerable'
32
+
33
+ item['answer'] = re.sub('\n', '', item['answer']).lower()
34
+ item['pred'] = str(item['pred']).lower()
35
+ anls_score = anls_compute(item['answer'], item['pred'])
36
+ em_score = (item['answer'].strip() == item['pred'].strip())
37
+ f1_score = get_f1(item['answer'], item['pred'])
38
+ anls_list.append(anls_score)
39
+ em_list.append(em_score)
40
+ f1_list.append(f1_score)
41
+ print('---------------------')
42
+ print(item['answer'], item['pred'], anls_score, em_score, f1_score)
43
+
44
+ data['anls'] = anls_list
45
+ data['em'] = em_list
46
+ data['f1'] = f1_list
47
+ dump(data, result_file)
48
+
49
+ res = dict()
50
+ res['category'], res['num'] = ['anls', 'EM', 'F1'], [len(data), len(data), len(data)]
51
+ res['avg'] = [sum(anls_list) / len(data), sum(em_list) / len(data), sum(f1_list) / len(data)]
52
+ res = pd.DataFrame(res)
53
+ return res
54
+
55
+
56
+ class SlideVQA(ImageBaseDataset):
57
+
58
+ TYPE = 'VQA'
59
+
60
+ DATASET_URL = {
61
+ 'SLIDEVQA_MINI': 'https://opencompass.openxlab.space/utils/VLMEval/SLIDEVQA_MINI.tsv',
62
+ 'SLIDEVQA': 'https://opencompass.openxlab.space/utils/VLMEval/SLIDEVQA.tsv',
63
+ }
64
+ DATASET_MD5 = {
65
+ 'SLIDEVQA_MINI': '6d9a8d8814fa5b7669deb2af3a3208eb',
66
+ 'SLIDEVQA': '5e822c2f800e94c1e23badfd478326b6',
67
+ }
68
+
69
+ SUPPORTED_MODELS = {
70
+ 'GPT4': (1, 1),
71
+ 'GPT4V': (1, 1),
72
+ 'GPT4V_HIGH': (1, 1),
73
+ 'GPT4o': (1, 1),
74
+ 'GPT4o_HIGH': (1, 1),
75
+ 'GPT4o_MINI': (1, 1),
76
+ 'XComposer2d5': (1, -1),
77
+ 'XComposer2_4KHD': (1, -1),
78
+ 'MiniCPM-Llama3-V-2_5': (1, 5),
79
+ 'InternVL-Chat-V1-5': (5, 2),
80
+ }
81
+
82
+ def __init__(self, dataset, **kwargs):
83
+ self.model_list = list(self.SUPPORTED_MODELS.keys())
84
+ model_name = kwargs['model']
85
+ if not listinstr(self.model_list, model_name):
86
+ raise AssertionError("{} doesn't support the evaluation on SlideVQA.".format(model_name))
87
+ super(SlideVQA, self).__init__(dataset)
88
+
89
+ self.is_api = True if listinstr(['GPT4'], model_name) else False
90
+ self.max_pages = 120
91
+ concat_num, column_num = self.SUPPORTED_MODELS.get(model_name)
92
+ self.concat_num = concat_num
93
+ self.column_num = column_num
94
+
95
+ def dump_image(self, origin_line):
96
+ os.makedirs(self.img_root, exist_ok=True)
97
+
98
+ line = origin_line.copy()
99
+ if not isinstance(line['image_path'], List):
100
+ line['image_path'] = [line['image_path']]
101
+ line['image_path'] = line['image_path'][:self.max_pages]
102
+
103
+ if 'image' in line:
104
+ if isinstance(line['image'], list):
105
+ tgt_path = []
106
+ assert 'image_path' in line
107
+ for img, im_name in zip(line['image'], line['image_path']):
108
+ path = osp.join(self.img_root, im_name)
109
+ if not read_ok(path):
110
+ decode_base64_to_image_file(img, path)
111
+ tgt_path.append(path)
112
+ else:
113
+ tgt_path = osp.join(self.img_root, f"{line['index']}.jpg")
114
+ if not read_ok(tgt_path):
115
+ decode_base64_to_image_file(line['image'], tgt_path)
116
+ tgt_path = [tgt_path]
117
+ else:
118
+ assert 'image_path' in line
119
+ tgt_path = toliststr(line['image_path'])
120
+
121
+ if self.concat_num > 0 and not self.is_api:
122
+ concatenated_images = concat_images(tgt_path, max_concat=self.concat_num, column_num=self.column_num)
123
+
124
+ old_tgt_path = tgt_path
125
+ assert isinstance(old_tgt_path, list)
126
+ if self.column_num != -1:
127
+ tgt_path = [
128
+ '_'.join(old_tgt_path[0].split('_')[:-1]) + '_concat{}_{}.jpg'.format(self.concat_num, i)
129
+ for i in range(len(concatenated_images))
130
+ ]
131
+ else:
132
+ tgt_path = ['_'.join(old_tgt_path[0].split('_')[:-1]) + '_concat_all.jpg']
133
+
134
+ for path, concatenated_image in zip(tgt_path, concatenated_images):
135
+ if not read_ok(path):
136
+ decode_base64_to_image_file(encode_image_to_base64(concatenated_image), path)
137
+ num_images, image_size = len(old_tgt_path), concatenated_image.size
138
+ print('concat {} images to a new one with size {}. save at {}'.format(num_images, image_size, path))
139
+ return tgt_path
140
+
141
+ @classmethod
142
+ def evaluate(self, eval_file, **judge_kwargs):
143
+ logger = get_logger('Evaluation')
144
+ model = judge_kwargs['model']
145
+
146
+ suffix = eval_file.split('.')[-1]
147
+ storage = eval_file.replace(f'.{suffix}', f'_{model}.xlsx')
148
+ tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
149
+
150
+ if osp.exists(storage):
151
+ logger.warning(f'GPT scoring file {storage} already exists, will reuse it in SlideVQA_eval. ')
152
+ else:
153
+ data = load(eval_file)
154
+ model = build_judge(max_tokens=128, **judge_kwargs)
155
+ lt = len(data)
156
+ lines = [data.iloc[i] for i in range(lt)]
157
+ tups = [(model, line) for line in lines]
158
+ indices = [line['index'] for line in lines]
159
+
160
+ ans = {}
161
+ if osp.exists(tmp_file):
162
+ ans = load(tmp_file)
163
+ tups = [x for x, i in zip(tups, indices) if i not in ans]
164
+ indices = [i for i in indices if i not in ans]
165
+
166
+ if len(indices):
167
+ new_results = list()
168
+ for model, line in tqdm(tups):
169
+ res = MMLongBench_auxeval(model, line)
170
+ new_results.append(res)
171
+
172
+ log_map, res_map, pred_map = {}, {}, {}
173
+ all_inds = [line['index'] for line in lines]
174
+ for k, v in zip(all_inds, new_results):
175
+ log_map[k] = v['log']
176
+ res_map[k] = v['res']
177
+ pred_map[k] = v['pred']
178
+ data['res'] = [res_map[idx] for idx in data['index']]
179
+ data['log'] = [log_map[idx] for idx in data['index']]
180
+ data['pred'] = [pred_map[idx] for idx in data['index']]
181
+ dump(data, storage)
182
+
183
+ score = SlideVQA_acc(storage)
184
+ score_pth = storage.replace('.xlsx', '_score.csv')
185
+
186
+ dump(score, score_pth)
187
+ logger.info(f'SlideVQA successfully finished evaluating {eval_file}, results saved in {score_pth}')
188
+ logger.info('Score: ')
189
+ logger.info(score)
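`get_f1` above scores a prediction by token overlap with the ground truth. A worked example of that computation, with illustrative strings:

# Token-level F1, mirroring get_f1: count predicted tokens that appear in the ground truth.
gt = 'the eiffel tower in paris'
pred = 'eiffel tower paris france'
gt_bow, pred_bow = gt.split(), pred.split()
overlap = sum(tok in gt_bow for tok in pred_bow)      # 3 of the 4 predicted tokens
recall = overlap / len(gt_bow)                        # 3 / 5 = 0.60
precision = overlap / len(pred_bow)                   # 3 / 4 = 0.75
f1 = 2 * recall * precision / (recall + precision)    # ~0.667
print(round(f1, 3))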
VLMEvalKit/vlmeval/dataset/tempcompass.py ADDED
@@ -0,0 +1,639 @@
1
+ import huggingface_hub
2
+ from huggingface_hub import snapshot_download
3
+ from ..smp import *
4
+ from .video_concat_dataset import ConcatVideoDataset
5
+ from .video_base import VideoBaseDataset
6
+ from .utils import build_judge, DEBUG_MESSAGE
7
+ from ..utils import track_progress_rich
8
+ import torchvision.transforms as T
9
+ from torchvision import transforms
10
+ from torchvision.transforms.functional import InterpolationMode
11
+ from decord import VideoReader, cpu
12
+ from .utils.tempcompass import *
13
+
14
+
15
+ FAIL_MSG = 'Failed to obtain answer via API.'
16
+
17
+
18
+ class TempCompass(ConcatVideoDataset):
19
+ def __init__(self, dataset='TempCompass'):
20
+ self.DATASET_SETS[dataset] = ['TempCompass_MCQ', 'TempCompass_Captioning', 'TempCompass_YorN']
21
+ super().__init__(dataset=dataset)
22
+
23
+ @classmethod
24
+ def supported_datasets(cls):
25
+ return ['TempCompass']
26
+
27
+ def evaluate(self, eval_file, **judge_kwargs):
28
+ result = super().evaluate(eval_file=eval_file, **judge_kwargs)
29
+ suffix = eval_file.split('.')[-1]
30
+ result = result.reset_index().rename(columns={'index': 'dim.task_type'})
31
+ score_file = eval_file.replace(f'.{suffix}', '_acc.csv')
32
+ avg_dict = {}
33
+ for idx, item in result.iterrows():
34
+ dim, task_type = item['dim.task_type'].split('. ')
35
+ if dim not in avg_dict:
36
+ avg_dict[dim] = {'success': 0.0, 'overall': 0.0}
37
+ if task_type not in avg_dict:
38
+ avg_dict[task_type] = {'success': 0.0, 'overall': 0.0}
39
+ if 'overall' not in avg_dict:
40
+ avg_dict['overall'] = {'success': 0.0, 'overall': 0.0}
41
+ avg_dict[dim]['success'] += item['success']
42
+ avg_dict[dim]['overall'] += item['overall']
43
+ avg_dict[task_type]['success'] += item['success']
44
+ avg_dict[task_type]['overall'] += item['overall']
45
+ avg_dict['overall']['success'] += item['success']
46
+ avg_dict['overall']['overall'] += item['overall']
47
+ result.loc[idx, 'acc'] = round(item['success'] / item['overall'] * 100, 2)
48
+ for key, value in avg_dict.items():
49
+ # Use loc to append a new row
50
+ result.loc[len(result)] = {
51
+ 'dim.task_type': key,
52
+ 'success': value['success'],
53
+ 'overall': value['overall'],
54
+ 'acc': round(value['success'] / value['overall'] * 100, 2)
55
+ }
56
+ dump(result, score_file)
57
+ return result
58
+
59
+
60
+ class TempCompass_MCQ(VideoBaseDataset):
61
+
62
+ MD5 = '7efbb9e6d9dabacd22daf274852691dd'
63
+ TYPE = 'Video-MCQ'
64
+
65
+ def __init__(self, dataset='TempCompass_MCQ'):
66
+ self.type_data_list = {
67
+ 'multi-choice': ('multi-choice.json', './videos', '.mp4'),
68
+ 'caption_matching': ('caption_matching.json', './videos', '.mp4'),
69
+ }
70
+ super().__init__(dataset=dataset)
71
+
72
+ @classmethod
73
+ def supported_datasets(cls):
74
+ return ['TempCompass_MCQ']
75
+
76
+ def prepare_dataset(self, dataset_name='TempCompass_MCQ', repo_id='lmms-lab/TempCompass'):
77
+ def check_integrity(pth):
78
+ data_file = osp.join(pth, f'{dataset_name}.tsv')
79
+
80
+ if not osp.exists(data_file):
81
+ return False
82
+
83
+ if md5(data_file) != self.MD5:
84
+ return False
85
+
86
+ data = load(data_file)
87
+ for idx, item in data.iterrows():
88
+ if not osp.exists(osp.join(pth, item['prefix'], item['video'] + item['suffix'])):
89
+ return False
90
+ return True
91
+
92
+ cache_path = get_cache_path(repo_id)
93
+ if cache_path is not None and check_integrity(cache_path):
94
+ dataset_path = cache_path
95
+ else:
96
+ def read_parquet(pth):
97
+ import pandas as pd
98
+ for task_name in self.type_data_list.keys():
99
+ if not osp.exists(osp.join(pth, f'{task_name}.json')):
100
+ data = pd.read_parquet(osp.join(pth, task_name, 'test-00000-of-00001.parquet'))
101
+ data.to_json(osp.join(pth, f'{task_name}.json'), orient='records', lines=False)
102
+
103
+ def unzip_videos(pth):
104
+ import zipfile
105
+ if not osp.exists(osp.join(pth, 'videos')):
106
+ zip_file = osp.join(pth, 'tempcompass_videos.zip')
107
+ with zipfile.ZipFile(zip_file, 'r') as zip_ref:
108
+ zip_ref.extractall(pth)
109
+
110
+ def generate_tsv(pth):
111
+ data_file = osp.join(pth, f'{dataset_name}.tsv')
112
+ if osp.exists(data_file) and md5(data_file) == self.MD5:
113
+ return
114
+ self.data_list = []
115
+ for k, v in self.type_data_list.items():
116
+ with open(osp.join(pth, v[0]), 'r') as f:
117
+ json_data = json.load(f)
118
+ for data in json_data:
119
+ self.data_list.append({
120
+ 'task_type': k,
121
+ 'prefix': v[1],
122
+ 'suffix': v[2],
123
+ 'video': data['video_id'],
124
+ 'question': data['question'].split('\n')[0],
125
+ 'answer': data['answer'],
126
+ 'dim': data['dim'],
127
+ 'candidates': data['question'].split('\n')[1:],
128
+ })
129
+
130
+ data_df = pd.DataFrame(self.data_list)
131
+ data_df = data_df.assign(index=range(len(data_df)))
132
+ data_df.to_csv(data_file, sep='\t', index=False)
133
+
134
+ if modelscope_flag_set():
135
+ from modelscope import dataset_snapshot_download
136
+ dataset_path = dataset_snapshot_download(dataset_id=repo_id)
137
+ else:
138
+ dataset_path = snapshot_download(repo_id=repo_id, repo_type='dataset')
139
+ read_parquet(dataset_path)
140
+ unzip_videos(dataset_path)
141
+ generate_tsv(dataset_path)
142
+
143
+ data_file = osp.join(dataset_path, f'{dataset_name}.tsv')
144
+ return dict(root=dataset_path, data_file=data_file)
145
+
146
+ def qa_template(self, data):
147
+ question = data['question'] + '\n' + '\n'.join(eval(data['candidates']))
148
+ answer = data['answer']
149
+ return question, answer
150
+
151
+ def save_video_frames(self, line, num_frames=8, fps=-1):
152
+ vid_path = osp.join(self.data_root, line['prefix'], line['video'] + line['suffix'])
153
+ vid = decord.VideoReader(vid_path)
154
+ video_info = {
155
+ 'fps': vid.get_avg_fps(),
156
+ 'n_frames': len(vid),
157
+ }
158
+ if num_frames > 0 and fps < 0:
159
+ step_size = len(vid) / (num_frames + 1)
160
+ indices = [int(i * step_size) for i in range(1, num_frames + 1)]
161
+ frame_paths = self.frame_paths(line['video'], num_frames)
162
+ elif fps > 0:
163
+ # not constrained by num_frames, get frames by fps
164
+ total_duration = video_info['n_frames'] / video_info['fps']
165
+ required_frames = int(total_duration * fps)
166
+ step_size = video_info['fps'] / fps
167
+ indices = [int(i * step_size) for i in range(required_frames)]
168
+ frame_paths = self.frame_paths_fps(line['video'], len(indices), fps)
169
+
170
+ flag = np.all([osp.exists(p) for p in frame_paths])
171
+
172
+ if not flag:
173
+ images = [vid[i].asnumpy() for i in indices]
174
+ images = [Image.fromarray(arr) for arr in images]
175
+ for im, pth in zip(images, frame_paths):
176
+ if not osp.exists(pth):
177
+ im.save(pth)
178
+
179
+ return frame_paths
180
+
181
+ def save_video_into_images(self, line, num_frames, fps):
182
+ frame_paths = self.save_video_frames(line, num_frames, fps)
183
+ return frame_paths
184
+
185
+ def build_prompt(self, line, num_frames, video_llm, fps=-1):
186
+ if isinstance(line, int):
187
+ assert line < len(self)
188
+ line = self.data.iloc[line]
189
+
190
+ question, answer = self.qa_template(line)
191
+ message = []
192
+ message.append(dict(type='text', value=question))
193
+ video_path = osp.join(self.data_root, line['prefix'], line['video'] + line['suffix'])
194
+ if video_llm:
195
+ message.append(dict(type='video', value=video_path))
196
+ else:
197
+ img_frame_paths = self.save_video_into_images(line, num_frames, fps)
198
+ for im in img_frame_paths:
199
+ message.append(dict(type='image', value=im))
200
+ message.append(dict(type='text', value='\nPlease directly give the best option:'))
201
+ return message
202
+
203
+ @classmethod
204
+ def evaluate(self, eval_file, **judge_kwargs):
205
+ model = judge_kwargs.get('model', 'exact_matching')
206
+ assert model in ['chatgpt-1106', 'exact_matching']
207
+ judge_kwargs.update({
208
+ "max_tokens": 128,
209
+ "temperature": 1.0,
210
+ "top_p": 1,
211
+ "presence_penalty": 1,
212
+ })
213
+
214
+ suffix = eval_file.split('.')[-1]
215
+ score_file = eval_file.replace(f'.{suffix}', f'_{model}_score.xlsx')
216
+ tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
217
+ nproc = judge_kwargs.pop('nproc', 4)
218
+
219
+ if not osp.exists(score_file):
220
+ data = load(eval_file)
221
+ if model != 'exact_matching':
222
+ model = build_judge(system_prompt=sys_prompt, **judge_kwargs)
223
+ else:
224
+ model = None
225
+
226
+ lt = len(data)
227
+ lines = [data.iloc[i] for i in range(lt)]
228
+ tups = [(model, line) for line in lines]
229
+ indices = [line['index'] for line in lines]
230
+
231
+ ans = {}
232
+ if osp.exists(tmp_file):
233
+ ans = load(tmp_file)
234
+ tups = [x for x, i in zip(tups, indices) if i not in ans]
235
+ indices = [i for i in indices if i not in ans]
236
+
237
+ if len(indices):
238
+ _ = track_progress_rich(
239
+ evaluate_tempcompass_mcq,
240
+ tups,
241
+ nproc=nproc,
242
+ chunksize=nproc,
243
+ keys=indices,
244
+ save=tmp_file,
245
+ )
246
+ ans = load(tmp_file)
247
+ for idx, item in data.iterrows():
248
+ data.loc[idx, 'score'] = ans[idx]['rating']
249
+ dump(data, score_file)
250
+
251
+ rating = get_dimension_rating(score_file)
252
+ return rating
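The pickle file above doubles as a resume checkpoint: indices already judged in a previous run are filtered out before dispatching to track_progress_rich. A toy illustration of that filter (values are made up, not from the dataset):

ans = {0: {'rating': 1}}                      # pretend index 0 was already judged in a prior run
indices = [0, 1, 2]
tups = [('judge', 'line0'), ('judge', 'line1'), ('judge', 'line2')]
tups = [x for x, i in zip(tups, indices) if i not in ans]
indices = [i for i in indices if i not in ans]
assert indices == [1, 2] and len(tups) == 2   # only the unscored rows are re-submitted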
253
+
254
+
255
+ class TempCompass_Captioning(VideoBaseDataset):
256
+
257
+ MD5 = '35be9bf2581ea7767f02e9a8f37ae1ab'
258
+ TYPE = 'Video-VQA'
259
+
260
+ def __init__(self, dataset='TempCompass_Captioning'):
261
+ self.type_data_list = {
262
+ 'captioning': ('captioning.json', './videos', '.mp4'),
263
+ }
264
+ super().__init__(dataset=dataset)
265
+
266
+ @classmethod
267
+ def supported_datasets(cls):
268
+ return ['TempCompass_Captioning']
269
+
270
+ def prepare_dataset(self, dataset_name='TempCompass_Captioning', repo_id='lmms-lab/TempCompass'):
271
+ def check_integrity(pth):
272
+ data_file = osp.join(pth, f'{dataset_name}.tsv')
273
+
274
+ if not osp.exists(data_file):
275
+ return False
276
+
277
+ if md5(data_file) != self.MD5:
278
+ return False
279
+
280
+ data = load(data_file)
281
+ for idx, item in data.iterrows():
282
+ if not osp.exists(osp.join(pth, item['prefix'], item['video'] + item['suffix'])):
283
+ return False
284
+ return True
285
+
286
+ cache_path = get_cache_path(repo_id)
287
+ if cache_path is not None and check_integrity(cache_path):
288
+ dataset_path = cache_path
289
+ else:
290
+ def read_parquet(pth):
291
+ import pandas as pd
292
+ for task_name in self.type_data_list.keys():
293
+ if not osp.exists(osp.join(pth, f'{task_name}.json')):
294
+ data = pd.read_parquet(osp.join(pth, task_name, 'test-00000-of-00001.parquet'))
295
+ data.to_json(osp.join(pth, f'{task_name}.json'), orient='records', lines=False)
296
+
297
+ def unzip_videos(pth):
298
+ import zipfile
299
+ if not osp.exists(osp.join(pth, 'videos')):
300
+ zip_file = osp.join(pth, 'tempcompass_videos.zip')
301
+ with zipfile.ZipFile(zip_file, 'r') as zip_ref:
302
+ zip_ref.extractall(pth)
303
+
304
+ def generate_tsv(pth):
305
+ data_file = osp.join(pth, f'{dataset_name}.tsv')
306
+ if osp.exists(data_file) and md5(data_file) == self.MD5:
307
+ return
308
+ self.data_list = []
309
+ for k, v in self.type_data_list.items():
310
+ with open(osp.join(pth, v[0]), 'r') as f:
311
+ json_data = json.load(f)
312
+ for data in json_data:
313
+ self.data_list.append({
314
+ 'task_type': k,
315
+ 'prefix': v[1],
316
+ 'suffix': v[2],
317
+ 'video': data['video_id'],
318
+ 'question': data['question'],
319
+ 'answer': data['answer'],
320
+ 'dim': data['dim'],
321
+ 'mc_question': data['mc_question'],
322
+ 'mc_answer': data['mc_answer'],
323
+ })
324
+
325
+ data_df = pd.DataFrame(self.data_list)
326
+ data_df = data_df.assign(index=range(len(data_df)))
327
+ data_df.to_csv(data_file, sep='\t', index=False)
328
+
329
+ if modelscope_flag_set():
330
+ from modelscope import dataset_snapshot_download
331
+ dataset_path = dataset_snapshot_download(dataset_id=repo_id)
332
+ else:
333
+ dataset_path = snapshot_download(repo_id=repo_id, repo_type='dataset')
334
+ read_parquet(dataset_path)
335
+ unzip_videos(dataset_path)
336
+ generate_tsv(dataset_path)
337
+
338
+ data_file = osp.join(dataset_path, f'{dataset_name}.tsv')
339
+ return dict(root=dataset_path, data_file=data_file)
340
+
341
+ def qa_template(self, data):
342
+ question = data['question']
343
+ answer = data['answer']
344
+ return question, answer
345
+
346
+ def save_video_frames(self, line, num_frames=8, fps=-1):
347
+ vid_path = osp.join(self.data_root, line['prefix'], line['video'] + line['suffix'])
348
+ vid = decord.VideoReader(vid_path)
349
+ video_info = {
350
+ 'fps': vid.get_avg_fps(),
351
+ 'n_frames': len(vid),
352
+ }
353
+ if num_frames > 0 and fps < 0:
354
+ step_size = len(vid) / (num_frames + 1)
355
+ indices = [int(i * step_size) for i in range(1, num_frames + 1)]
356
+ frame_paths = self.frame_paths(line['video'], num_frames)
357
+ elif fps > 0:
358
+ # not constrained by num_frames, get frames by fps
359
+ total_duration = video_info['n_frames'] / video_info['fps']
360
+ required_frames = int(total_duration * fps)
361
+ step_size = video_info['fps'] / fps
362
+ indices = [int(i * step_size) for i in range(required_frames)]
363
+ frame_paths = self.frame_paths_fps(line['video'], len(indices), fps)
364
+
365
+ flag = np.all([osp.exists(p) for p in frame_paths])
366
+
367
+ if not flag:
368
+ images = [vid[i].asnumpy() for i in indices]
369
+ images = [Image.fromarray(arr) for arr in images]
370
+ for im, pth in zip(images, frame_paths):
371
+ if not osp.exists(pth):
372
+ im.save(pth)
373
+
374
+ return frame_paths
375
+
376
+ def save_video_into_images(self, line, num_frames, fps):
377
+ frame_paths = self.save_video_frames(line, num_frames, fps)
378
+ return frame_paths
379
+
380
+ def build_prompt(self, line, num_frames, video_llm, fps=-1):
381
+ if isinstance(line, int):
382
+ assert line < len(self)
383
+ line = self.data.iloc[line]
384
+
385
+ question, answer = self.qa_template(line)
386
+ message = []
387
+ message.append(dict(type='text', value=question))
388
+ video_path = osp.join(self.data_root, line['prefix'], line['video'] + line['suffix'])
389
+ if video_llm:
390
+ message.append(dict(type='video', value=video_path))
391
+ else:
392
+ img_frame_paths = self.save_video_into_images(line, num_frames, fps)
393
+ for im in img_frame_paths:
394
+ message.append(dict(type='image', value=im))
395
+ return message
396
+
397
+ @classmethod
398
+ def evaluate(self, eval_file, **judge_kwargs):
399
+ model = judge_kwargs.get('model', 'exact_matching')
400
+ assert model in ['chatgpt-1106', 'exact_matching']
401
+ judge_kwargs.update({
402
+ "max_tokens": 128,
403
+ "temperature": 1.0,
404
+ "top_p": 1,
405
+ "presence_penalty": 1,
406
+ })
407
+
408
+ suffix = eval_file.split('.')[-1]
409
+ score_file = eval_file.replace(f'.{suffix}', f'_{model}_score.xlsx')
410
+ tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
411
+ nproc = judge_kwargs.pop('nproc', 4)
412
+
413
+ if not osp.exists(score_file):
414
+ data = load(eval_file)
415
+ if model != 'exact_matching':
416
+ model = build_judge(system_prompt=sys_prompt, **judge_kwargs)
417
+ else:
418
+ model = None
419
+
420
+ lt = len(data)
421
+ lines = [data.iloc[i] for i in range(lt)]
422
+ tups = [(model, line) for line in lines]
423
+ indices = [line['index'] for line in lines]
424
+
425
+ ans = {}
426
+ if osp.exists(tmp_file):
427
+ ans = load(tmp_file)
428
+ tups = [x for x, i in zip(tups, indices) if i not in ans]
429
+ indices = [i for i in indices if i not in ans]
430
+
431
+ if len(indices):
432
+ _ = track_progress_rich(
433
+ evaluate_tempcompass_captioning,
434
+ tups,
435
+ nproc=nproc,
436
+ chunksize=nproc,
437
+ keys=indices,
438
+ save=tmp_file,
439
+ )
440
+ ans = load(tmp_file)
441
+ for idx, item in data.iterrows():
442
+ data.loc[idx, 'score'] = ans[idx]['rating']
443
+ dump(data, score_file)
444
+
445
+ rating = get_dimension_rating(score_file)
446
+ return rating
447
+
448
+
449
+ class TempCompass_YorN(VideoBaseDataset):
450
+
451
+ MD5 = 'c72c046d7fa0e82c8cd7462f2e844ea8'
452
+ TYPE = 'Video-Y/N'
453
+
454
+ def __init__(self, dataset='TempCompass_YorN'):
455
+ self.type_data_list = {
456
+ 'yes_no': ('yes_no.json', './videos', '.mp4'),
457
+ }
458
+ super().__init__(dataset=dataset)
459
+
460
+ @classmethod
461
+ def supported_datasets(cls):
462
+ return ['TempCompass_YorN']
463
+
464
+ def prepare_dataset(self, dataset_name='TempCompass_YorN', repo_id='lmms-lab/TempCompass'):
465
+ def check_integrity(pth):
466
+ data_file = osp.join(pth, f'{dataset_name}.tsv')
467
+
468
+ if not osp.exists(data_file):
469
+ return False
470
+
471
+ if md5(data_file) != self.MD5:
472
+ return False
473
+
474
+ data = load(data_file)
475
+ for idx, item in data.iterrows():
476
+ if not osp.exists(osp.join(pth, item['prefix'], item['video'] + item['suffix'])):
477
+ return False
478
+ return True
479
+
480
+ cache_path = get_cache_path(repo_id)
481
+ if cache_path is not None and check_integrity(cache_path):
482
+ dataset_path = cache_path
483
+ else:
484
+ def read_parquet(pth):
485
+ import pandas as pd
486
+ for task_name in self.type_data_list.keys():
487
+ if not osp.exists(osp.join(pth, f'{task_name}.json')):
488
+ data = pd.read_parquet(osp.join(pth, task_name, 'test-00000-of-00001.parquet'))
489
+ data.to_json(osp.join(pth, f'{task_name}.json'), orient='records', lines=False)
490
+
491
+ def unzip_videos(pth):
492
+ import zipfile
493
+ if not osp.exists(osp.join(pth, 'videos')):
494
+ zip_file = osp.join(pth, 'tempcompass_videos.zip')
495
+ with zipfile.ZipFile(zip_file, 'r') as zip_ref:
496
+ zip_ref.extractall(pth)
497
+
498
+ def generate_tsv(pth):
499
+ data_file = osp.join(pth, f'{dataset_name}.tsv')
500
+ if osp.exists(data_file) and md5(data_file) == self.MD5:
501
+ return
502
+ self.data_list = []
503
+ for k, v in self.type_data_list.items():
504
+ with open(osp.join(pth, v[0]), 'r') as f:
505
+ json_data = json.load(f)
506
+ for data in json_data:
507
+ self.data_list.append({
508
+ 'task_type': k,
509
+ 'prefix': v[1],
510
+ 'suffix': v[2],
511
+ 'video': data['video_id'],
512
+ 'question': data['question'].split('\n')[0],
513
+ 'answer': data['answer'],
514
+ 'dim': data['dim']
515
+ })
516
+
517
+ data_df = pd.DataFrame(self.data_list)
518
+ data_df = data_df.assign(index=range(len(data_df)))
519
+ data_df.to_csv(data_file, sep='\t', index=False)
520
+
521
+ if modelscope_flag_set():
522
+ from modelscope import dataset_snapshot_download
523
+ dataset_path = dataset_snapshot_download(dataset_id=repo_id)
524
+ else:
525
+ dataset_path = snapshot_download(repo_id=repo_id, repo_type='dataset')
526
+ read_parquet(dataset_path)
527
+ unzip_videos(dataset_path)
528
+ generate_tsv(dataset_path)
529
+
530
+ data_file = osp.join(dataset_path, f'{dataset_name}.tsv')
531
+ return dict(root=dataset_path, data_file=data_file)
532
+
533
+ def qa_template(self, data):
534
+ question = data['question']
535
+ answer = data['answer']
536
+ return question, answer
537
+
538
+ def save_video_frames(self, line, num_frames=8, fps=-1):
539
+ vid_path = osp.join(self.data_root, line['prefix'], line['video'] + line['suffix'])
540
+ vid = decord.VideoReader(vid_path)
541
+ video_info = {
542
+ 'fps': vid.get_avg_fps(),
543
+ 'n_frames': len(vid),
544
+ }
545
+ if num_frames > 0 and fps < 0:
546
+ step_size = len(vid) / (num_frames + 1)
547
+ indices = [int(i * step_size) for i in range(1, num_frames + 1)]
548
+ frame_paths = self.frame_paths(line['video'], num_frames)
549
+ elif fps > 0:
550
+ # not constrained by num_frames, get frames by fps
551
+ total_duration = video_info['n_frames'] / video_info['fps']
552
+ required_frames = int(total_duration * fps)
553
+ step_size = video_info['fps'] / fps
554
+ indices = [int(i * step_size) for i in range(required_frames)]
555
+ frame_paths = self.frame_paths_fps(line['video'], len(indices), fps)
556
+
557
+ flag = np.all([osp.exists(p) for p in frame_paths])
558
+
559
+ if not flag:
560
+ images = [vid[i].asnumpy() for i in indices]
561
+ images = [Image.fromarray(arr) for arr in images]
562
+ for im, pth in zip(images, frame_paths):
563
+ if not osp.exists(pth):
564
+ im.save(pth)
565
+
566
+ return frame_paths
567
+
568
+ def save_video_into_images(self, line, num_frames, fps):
569
+ frame_paths = self.save_video_frames(line, num_frames, fps)
570
+ return frame_paths
571
+
572
+ def build_prompt(self, line, num_frames, video_llm, fps=-1):
573
+ if isinstance(line, int):
574
+ assert line < len(self)
575
+ line = self.data.iloc[line]
576
+
577
+ question, answer = self.qa_template(line)
578
+ message = []
579
+ message.append(dict(type='text', value=question))
580
+ video_path = osp.join(self.data_root, line['prefix'], line['video'] + line['suffix'])
581
+ if video_llm:
582
+ message.append(dict(type='video', value=video_path))
583
+ else:
584
+ img_frame_paths = self.save_video_into_images(line, num_frames, fps)
585
+ for im in img_frame_paths:
586
+ message.append(dict(type='image', value=im))
587
+ message.append(dict(type='text', value='\nPlease answer yes or no:'))
588
+ return message
589
+
590
+ @classmethod
591
+ def evaluate(self, eval_file, **judge_kwargs):
592
+ model = judge_kwargs.get('model', 'exact_matching')
593
+ assert model in ['chatgpt-1106', 'exact_matching']
594
+ judge_kwargs.update({
595
+ "max_tokens": 128,
596
+ "temperature": 1.0,
597
+ "top_p": 1,
598
+ "presence_penalty": 1,
599
+ })
600
+
601
+ suffix = eval_file.split('.')[-1]
602
+ score_file = eval_file.replace(f'.{suffix}', f'_{model}_score.xlsx')
603
+ tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
604
+ nproc = judge_kwargs.pop('nproc', 4)
605
+
606
+ if not osp.exists(score_file):
607
+ data = load(eval_file)
608
+ if model != 'exact_matching':
609
+ model = build_judge(system_prompt=sys_prompt, **judge_kwargs)
610
+ else:
611
+ model = None
612
+
613
+ lt = len(data)
614
+ lines = [data.iloc[i] for i in range(lt)]
615
+ tups = [(model, line) for line in lines]
616
+ indices = [line['index'] for line in lines]
617
+
618
+ ans = {}
619
+ if osp.exists(tmp_file):
620
+ ans = load(tmp_file)
621
+ tups = [x for x, i in zip(tups, indices) if i not in ans]
622
+ indices = [i for i in indices if i not in ans]
623
+
624
+ if len(indices):
625
+ _ = track_progress_rich(
626
+ evaluate_tempcompass_YorN,
627
+ tups,
628
+ nproc=nproc,
629
+ chunksize=nproc,
630
+ keys=indices,
631
+ save=tmp_file,
632
+ )
633
+ ans = load(tmp_file)
634
+ for idx, item in data.iterrows():
635
+ data.loc[idx, 'score'] = ans[idx]['rating']
636
+ dump(data, score_file)
637
+
638
+ rating = get_dimension_rating(score_file)
639
+ return rating
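All three TempCompass splits share this prepare/sample/judge pipeline. A minimal driver sketch, not part of the diff, assuming the import path vlmeval.dataset.tempcompass and some model wrapper exposing generate(message):

from vlmeval.dataset.tempcompass import TempCompass_YorN  # assumed import path

dataset = TempCompass_YorN()                 # downloads the parquet/zip assets and builds the TSV on first use
line = dataset.data.iloc[0]
msgs = dataset.build_prompt(line, num_frames=8, video_llm=False, fps=-1)
# msgs = [question text] + [8 frame images] + ['\nPlease answer yes or no:']
# prediction = model.generate(msgs)          # `model` is hypothetical here

# Once predictions are dumped to an xlsx file, score them with exact matching:
# rating = TempCompass_YorN().evaluate('TempCompass_YorN_mymodel.xlsx', model='exact_matching')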
VLMEvalKit/vlmeval/dataset/text_base.py ADDED
@@ -0,0 +1,88 @@
1
+ from abc import abstractmethod
2
+ from ..smp import *
3
+
4
+
5
+ class TextBaseDataset:
6
+ MODALITY = 'TEXT'
7
+ DATASET_URL = {}
8
+ DATASET_MD5 = {}
9
+
10
+ def __init__(self, dataset='MMBench', **kwargs):
11
+ self.dataset_name = dataset
12
+
13
+ data = self.load_data(dataset)
14
+
15
+ data['index'] = [str(x) for x in data['index']]
16
+
17
+ if np.all([istype(x, int) for x in data['index']]):
18
+ data['index'] = [int(x) for x in data['index']]
19
+
20
+ self.data = data
21
+ self.post_build(dataset)
22
+
23
+ def __len__(self):
24
+ return len(self.data)
25
+
26
+ def __getitem__(self, idx):
27
+ return dict(self.data.iloc[idx])
28
+
29
+ def prepare_tsv(self, url, file_md5=None):
30
+ data_root = LMUDataRoot()
31
+ os.makedirs(data_root, exist_ok=True)
32
+ update_flag = False
33
+ file_name = url.split('/')[-1]
34
+ data_path = osp.join(data_root, file_name)
35
+ if osp.exists(data_path) and (file_md5 is None or md5(data_path) == file_md5):
36
+ pass
37
+ else:
38
+ warnings.warn('The dataset tsv is not downloaded')
39
+ download_file(url, data_path)
40
+ update_flag = True
41
+
42
+ if file_size(data_path, 'GB') > 1:
43
+ local_path = data_path.replace('.tsv', '_local.tsv')
44
+ if not osp.exists(local_path) or os.environ.get('FORCE_LOCAL', None) or update_flag:
45
+ from ..tools import LOCALIZE
46
+ LOCALIZE(data_path, local_path)
47
+ data_path = local_path
48
+ return load(data_path)
49
+
50
+ def dump_image(self, line):
51
+ return []
52
+
53
+ def display(self, line):
54
+ if isinstance(line, int):
55
+ line = self.data.iloc[line]
56
+ assert isinstance(line, pd.Series) or isinstance(line, dict)
57
+ mmqa_display(line)
58
+
59
+ # Return a list of dataset names that are supported by this class, can override
60
+ @classmethod
61
+ def supported_datasets(cls):
62
+ return list(cls.DATASET_URL)
63
+
64
+ # Given the dataset name, return the dataset as a pandas dataframe, can override
65
+ def load_data(self, dataset):
66
+ url = self.DATASET_URL[dataset]
67
+ file_md5 = self.DATASET_MD5[dataset]
68
+ return self.prepare_tsv(url, file_md5)
69
+
70
+ # Post built hook, will be called after the dataset is built, can override
71
+ def post_build(self, dataset):
72
+ pass
73
+
74
+ # Given one data record, return the built prompt (a multi-modal message), can override
75
+ def build_prompt(self, line):
76
+ if isinstance(line, int):
77
+ line = self.data.iloc[line]
78
+
79
+ question = line['question']
80
+
81
+ msgs = []
82
+ msgs.append(dict(type='text', value=question))
83
+ return msgs
84
+
85
+ # Given the prediction file, return the evaluation results in the format of a dictionary or pandas dataframe
86
+ @abstractmethod
87
+ def evaluate(self, eval_file, **judge_kwargs):
88
+ pass
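A minimal concrete subclass shows the intended overrides; the URL, MD5 value, and dataset name below are placeholders, and the smp helpers (load, pd, np) are assumed to be in scope as in this module:

class MyTextDataset(TextBaseDataset):
    TYPE = 'QA'
    DATASET_URL = {'MyTextDataset': 'https://example.com/MyTextDataset.tsv'}   # placeholder URL
    DATASET_MD5 = {'MyTextDataset': None}                                      # skip the md5 check

    def post_build(self, dataset):
        # e.g. drop rows that lack a question
        self.data = self.data[~pd.isna(self.data['question'])]

    def evaluate(self, eval_file, **judge_kwargs):
        data = load(eval_file)
        hits = [str(p).strip() == str(a).strip() for p, a in zip(data['prediction'], data['answer'])]
        return {'acc': float(np.mean(hits))}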
VLMEvalKit/vlmeval/dataset/text_mcq.py ADDED
@@ -0,0 +1,123 @@
1
+ from .text_base import TextBaseDataset
2
+ from .utils import build_judge, DEBUG_MESSAGE
3
+ from ..smp import *
4
+
5
+
6
+ class TextMCQDataset(TextBaseDataset):
7
+ TYPE = 'MCQ'
8
+
9
+ DATASET_URL = {}
10
+
11
+ DATASET_MD5 = {}
12
+
13
+ def build_prompt(self, line):
14
+
15
+ if isinstance(line, int):
16
+ line = self.data.iloc[line]
17
+
18
+ question = line['question']
19
+ options = {
20
+ cand: line[cand]
21
+ for cand in string.ascii_uppercase
22
+ if cand in line and not pd.isna(line[cand])
23
+ }
24
+ options_prompt = 'Options:\n'
25
+ for key, item in options.items():
26
+ options_prompt += f'{key}. {item}\n'
27
+ hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
28
+ prompt = ''
29
+ if hint is not None:
30
+ prompt += f'Hint: {hint}\n'
31
+ prompt += f'Question: {question}\n'
32
+ if len(options):
33
+ prompt += options_prompt
34
+ prompt += 'Please select the correct answer from the options above. \n'
35
+
36
+ msgs = []
37
+
38
+ msgs.append(dict(type='text', value=prompt))
39
+
40
+ return msgs
41
+
42
+ def evaluate(self, eval_file, **judge_kwargs):
43
+ from .utils.multiple_choice import report_acc, report_acc_MMT, mcq_circular_eval, mcq_vanilla_eval
44
+ # assert dataset is not None
45
+ dataset_map = {
46
+ 'MMBench_TEST_EN': 'MMBench', 'MMBench_TEST_EN_V11': 'MMBench_V11',
47
+ 'MMBench_TEST_CN': 'MMBench_CN', 'MMBench_TEST_CN_V11': 'MMBench_CN_V11'
48
+ }
49
+ dataset = self.dataset_name
50
+ if dataset in dataset_map:
51
+ dataset = dataset_map[dataset]
52
+ nproc = judge_kwargs.pop('nproc', 4)
53
+
54
+ circular = False
55
+
56
+ suffix = eval_file.split('.')[-1]
57
+ model = judge_kwargs.get('model', 'exact_matching')
58
+ assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']
59
+ name_str_map = {'chatgpt-0125': 'openai', 'gpt-4-0125': 'gpt4'}
60
+ name_str = name_str_map[model] if model in name_str_map else model
61
+
62
+ if model == 'exact_matching':
63
+ model = None
64
+ elif gpt_key_set():
65
+ model = build_judge(**judge_kwargs)
66
+ if not model.working():
67
+ warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
68
+ warnings.warn(DEBUG_MESSAGE)
69
+ model = None
70
+ else:
71
+ warnings.warn('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
72
+ model = None
73
+
74
+ result_file = eval_file.replace(f'.{suffix}', f'_{name_str}_result.pkl')
75
+
76
+ data = load(eval_file)
77
+ data = data.sort_values(by='index')
78
+ data['prediction'] = [str(x) for x in data['prediction']]
79
+ # If not choice label, then use lower case
80
+ for k in data.keys():
81
+ data[k.lower() if k not in list(string.ascii_uppercase) else k] = data.pop(k)
82
+
83
+ meta = self.data
84
+ meta_q_map = {x: y for x, y in zip(meta['index'], meta['question'])}
85
+ data_map = {x: y for x, y in zip(data['index'], data['question'])}
86
+ for k in data_map:
87
+ assert k in meta_q_map, (
88
+ f'eval_file should be the same as or a subset of dataset {self.dataset_name}'
89
+ )
90
+
91
+ if circular:
92
+ data = mcq_circular_eval(model, data, meta, nproc, result_file, self.dataset_name)
93
+ else:
94
+ data = mcq_vanilla_eval(model, data, meta, nproc, result_file, self.dataset_name)
95
+
96
+ # load split
97
+ dump(data, eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))
98
+ data = load(eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}'))
99
+
100
+ # May have different report acc functions for different datasets
101
+ if 'MMT' in dataset:
102
+ acc = report_acc_MMT(data)
103
+ else:
104
+ acc = report_acc(data)
105
+
106
+ score_file = eval_file.replace(f'.{suffix}', '_acc.csv')
107
+ dump(acc, score_file)
108
+
109
+ return acc
110
+
111
+
112
+ class CustomTextMCQDataset(TextMCQDataset):
113
+
114
+ def load_data(self, dataset):
115
+ data_path = osp.join(LMUDataRoot(), f'{dataset}.tsv')
116
+
117
+ if file_size(data_path, 'GB') > 1:
118
+ local_path = data_path.replace('.tsv', '_local.tsv')
119
+ if not osp.exists(local_path) or os.environ.get('FORCE_LOCAL', None):
120
+ from ..tools import LOCALIZE
121
+ LOCALIZE(data_path, local_path)
122
+ data_path = local_path
123
+ return load(data_path)
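CustomTextMCQDataset reads its TSV straight from LMUDataRoot(); a hypothetical row and the prompt it produces (file name, question, and options are invented):

# Assumes a file named MyMCQ.tsv under LMUDataRoot() with columns:
# index, question, A..D, answer (and optionally hint)
dataset = CustomTextMCQDataset(dataset='MyMCQ')
msgs = dataset.build_prompt(0)
print(msgs[0]['value'])
# Question: Which planet is known as the Red Planet?
# Options:
# A. Venus
# B. Mars
# C. Jupiter
# D. Mercury
# Please select the correct answer from the options above.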
VLMEvalKit/vlmeval/dataset/vcr.py ADDED
@@ -0,0 +1,335 @@
1
+ import uuid
2
+ from functools import partial
3
+ from .image_base import ImageBaseDataset
4
+ from ..smp import *
5
+
6
+ rouge = None
7
+ nlp_en = None
8
+ nlp_zh = None
9
+ nlp = None
10
+
11
+
12
+ def initialize():
13
+ import evaluate
14
+ import spacy
15
+
16
+ global rouge, nlp_en, nlp_zh, nlp
17
+
18
+ try:
19
+ rouge = evaluate.load('rouge', experiment_id=str(uuid.uuid4()))
20
+ except Exception as e:
21
+ logging.critical(f'{type(e)}: {e}')
22
+ logging.critical('Please first `pip install rouge_score`.')
23
+
24
+ try:
25
+ nlp_en = spacy.load('en_core_web_sm')
26
+ except Exception as e:
27
+ logging.warning(f'{type(e)}: {e}')
28
+ logging.warning('Will automatically download en_core_web_sm via spacy.')
29
+ spacy.cli.download('en_core_web_sm')
30
+ nlp_en = spacy.load('en_core_web_sm')
31
+
32
+ try:
33
+ nlp_zh = spacy.load('zh_core_web_sm')
34
+ except Exception as e:
35
+ logging.warning(f'{type(e)}: {e}')
36
+ logging.warning('Will automatically download zh_core_web_sm via spacy.')
37
+ spacy.cli.download('zh_core_web_sm')
38
+ nlp_zh = spacy.load('zh_core_web_sm')
39
+
40
+ nlp = {'en': nlp_en, 'zh': nlp_zh}
41
+
42
+
43
+ def rough_filter(answer_text):
44
+ if "I can't" in answer_text:
45
+ return False
46
+ elif 'I cannot' in answer_text:
47
+ return False
48
+ elif 'sorry' in answer_text.lower():
49
+ return False
50
+ if '无法' in answer_text:
51
+ return False
52
+ elif '抱歉' in answer_text:
53
+ return False
54
+ else:
55
+ return True
56
+
57
+
58
+ def zero_template(crossed_text):
59
+ return {
60
+ 'crossed_text': crossed_text,
61
+ 'max_sim_val': 0,
62
+ 'max_sim_string': '',
63
+ 'precision': 0,
64
+ 'recall': 0,
65
+ 'f1': 0,
66
+ 'jaccard': 0,
67
+ 'rouge1': 0,
68
+ 'exact_match': 0,
69
+ }
70
+
71
+
72
+ def tokenize(text, language):
73
+ """
74
+ Tokenize the text and return the tokens.
75
+
76
+ Parameters:
77
+ text (str): The text to tokenize.
78
+ language (str): The language of the text.
79
+
80
+ Returns:
81
+ list: The list of tokens.
82
+ """
83
+ assert language in ['en', 'zh']
84
+ nlp_language = nlp[language]
85
+ processed_text = nlp_language(text)
86
+ return [token.text for token in processed_text]
87
+
88
+
89
+ def find_best_match(needle, hay, language, rouge):
90
+ """
91
+ Finds the best matching n-gram in the haystack for the given needle.
92
+
93
+ Parameters:
94
+ needle (str): The string to find.
95
+ hay (str): The text to search within.
96
+
97
+ Returns:
98
+ tuple: The highest similarity value and the best matching string.
99
+ """
100
+ assert language in ['en', 'zh']
101
+ from nltk.util import ngrams
102
+ from difflib import SequenceMatcher as SM
103
+
104
+ tokens_hay = tokenize(hay, language)
105
+ tokens_needle = tokenize(needle, language)
106
+
107
+ splitter = '' if language == 'zh' else ' '
108
+ ngrams_ = ngrams(tokens_hay, len(tokens_needle))
109
+ max_sim_val = 0
110
+ max_sim_string = ''
111
+ max_sim_ngram = []
112
+ tokens_needle_set = set(tokens_needle)
113
+ ngrams_hasjoint = [
114
+ ngram
115
+ for ngram in ngrams_
116
+ if not set(ngram).isdisjoint(tokens_needle_set)
117
+ ]
118
+
119
+ for ngram in ngrams_hasjoint:
120
+ hay_ngram = splitter.join(ngram)
121
+ similarity = SM(None, hay_ngram, needle).ratio()
122
+ if similarity > max_sim_val:
123
+ max_sim_val = similarity
124
+ max_sim_string = hay_ngram
125
+ max_sim_ngram = ngram
126
+
127
+ # Evaluate
128
+ if len(max_sim_ngram) == 0:
129
+ return {
130
+ 'crossed_text': needle,
131
+ 'max_sim_val': 0,
132
+ 'max_sim_string': '',
133
+ 'precision': 0,
134
+ 'recall': 0,
135
+ 'f1': 0,
136
+ 'jaccard': 0,
137
+ 'rouge1': 0,
138
+ 'exact_match': 0,
139
+ }
140
+ pred_set = set(max_sim_ngram)
141
+ ref_set = set(tokens_needle)
142
+ correct_tokens = pred_set.intersection(ref_set)
143
+ len_correct_tokens = len(correct_tokens)
144
+
145
+ precision = len_correct_tokens / len(pred_set)
146
+ recall = len_correct_tokens / len(ref_set)
147
+ if (precision + recall) == 0:
148
+ f1 = 0
149
+ else:
150
+ f1 = 2 * precision * recall / (precision + recall)
151
+ union = pred_set.union(ref_set)
152
+ jaccard = len_correct_tokens / len(union) if len(union) > 0 else 0
153
+ rouge_1 = rouge.compute(
154
+ predictions=[max_sim_string],
155
+ references=[needle],
156
+ tokenizer=partial(tokenize, language=language),
157
+ rouge_types=['rouge1'],
158
+ )['rouge1']
159
+ exact_match = float(list(max_sim_ngram) == list(tokens_needle))
160
+ out = {
161
+ 'crossed_text': needle,
162
+ 'max_sim_string': max_sim_string,
163
+ 'max_sim_val': max_sim_val,
164
+ 'precision': precision,
165
+ 'recall': recall,
166
+ 'f1': f1,
167
+ 'jaccard': jaccard,
168
+ 'rouge1': rouge_1,
169
+ 'exact_match': exact_match,
170
+ }
171
+ return out
172
+
173
+
174
+ def process_match_single_new(
175
+ image_id, prediction, answer, language, progress
176
+ ):
177
+ """
178
+ process the inference results for a single image and calculate the metrics
179
+
180
+ Parameters:
181
+ image_id (int): The image id (question id).
182
+ prediction (str): The prediction text.
183
+ answer (Union[str, List[str]]): The answer text, or a list of answer texts. The masked n-grams in the image.
184
+ language (str): The language of the text. Can be "en" or "zh".
185
+ rouge (rouge): The rouge metric object.
186
+ progress (multiprocessing.Queue): The progress queue.
187
+
188
+ Returns:
189
+ tuple: The image id (question_id, int) and the result per id (dict of dict of dict).
190
+ """
191
+ result_per_id = {image_id: {}}
192
+ if isinstance(answer, str):
193
+ answer = eval(answer)
194
+ assert isinstance(answer, list)
195
+ result = prediction.split('Assistant: ')[-1]
196
+ for i, crossed_text in enumerate(answer):
197
+ if rough_filter(result):
198
+ find_best_match_result = find_best_match(
199
+ crossed_text, result, language, rouge
200
+ )
201
+ if i == 0:
202
+ result_per_id[image_id] = {str(i): find_best_match_result}
203
+ else:
204
+ result_per_id[image_id][str(i)] = find_best_match_result
205
+ else:
206
+ if i == 0:
207
+ result_per_id[image_id] = {str(i): zero_template(crossed_text)}
208
+ else:
209
+ result_per_id[image_id][str(i)] = zero_template(crossed_text)
210
+ progress.put(1)
211
+ return image_id, result_per_id
212
+
213
+
214
+ class VCRDataset(ImageBaseDataset):
215
+ TYPE = 'VQA'
216
+
217
+ URL_PREFIX = 'https://huggingface.co/datasets/vcr-org'
218
+
219
+ DATASET_URL = {
220
+ 'VCR_EN_EASY_500': f'{URL_PREFIX}/VCR-wiki-en-easy-test-500/resolve/main/VCR-wiki-en-easy-test-500.tsv',
221
+ 'VCR_EN_EASY_100': f'{URL_PREFIX}/VCR-wiki-en-easy-test-100/resolve/main/VCR-wiki-en-easy-test-100.tsv',
222
+ 'VCR_EN_EASY_ALL': f'{URL_PREFIX}/VCR-wiki-en-easy-test/resolve/main/VCR-wiki-en-easy-test.tsv',
223
+ 'VCR_EN_HARD_500': f'{URL_PREFIX}/VCR-wiki-en-hard-test-500/resolve/main/VCR-wiki-en-hard-test-500.tsv',
224
+ 'VCR_EN_HARD_100': f'{URL_PREFIX}/VCR-wiki-en-hard-test-100/resolve/main/VCR-wiki-en-hard-test-100.tsv',
225
+ 'VCR_EN_HARD_ALL': f'{URL_PREFIX}/VCR-wiki-en-hard-test/resolve/main/VCR-wiki-en-hard-test.tsv',
226
+ 'VCR_ZH_EASY_500': f'{URL_PREFIX}/VCR-wiki-zh-easy-test-500/resolve/main/VCR-wiki-zh-easy-test-500.tsv',
227
+ 'VCR_ZH_EASY_100': f'{URL_PREFIX}/VCR-wiki-zh-easy-test-100/resolve/main/VCR-wiki-zh-easy-test-100.tsv',
228
+ 'VCR_ZH_EASY_ALL': f'{URL_PREFIX}/VCR-wiki-zh-easy-test/resolve/main/VCR-wiki-zh-easy-test.tsv',
229
+ 'VCR_ZH_HARD_500': f'{URL_PREFIX}/VCR-wiki-zh-hard-test-500/resolve/main/VCR-wiki-zh-hard-test-500.tsv',
230
+ 'VCR_ZH_HARD_100': f'{URL_PREFIX}/VCR-wiki-zh-hard-test-100/resolve/main/VCR-wiki-zh-hard-test-100.tsv',
231
+ 'VCR_ZH_HARD_ALL': f'{URL_PREFIX}/VCR-wiki-zh-hard-test/resolve/main/VCR-wiki-zh-hard-test.tsv',
232
+ }
233
+
234
+ DATASET_MD5 = {
235
+ 'VCR_EN_EASY_500': 'fd9258db52f8685dc710619a0ea0a261',
236
+ 'VCR_EN_EASY_100': '9df5d7266683458621ecbe122beb72f0',
237
+ 'VCR_EN_EASY_ALL': '8a9b96885f251d1c85f42f84073327f1',
238
+ 'VCR_EN_HARD_500': '0a22a85080b6a1f52b1f95e302d43df4',
239
+ 'VCR_EN_HARD_100': '1b20f5cbcbeae0b0bec77f7a36143958',
240
+ 'VCR_EN_HARD_ALL': '2d8b8b1ee0eba0e0b618fd3aa7d9710e',
241
+ 'VCR_ZH_EASY_500': 'beca5fd54176adf44cf94bd9b50cf048',
242
+ 'VCR_ZH_EASY_100': '4a86a5678a79844d6d22ab0629c51cd5',
243
+ 'VCR_ZH_EASY_ALL': '5050fe7f0027ad2068fd4c7f220edaea',
244
+ 'VCR_ZH_HARD_500': '617e3360f75c54455625cb0a8da5c1e7',
245
+ 'VCR_ZH_HARD_100': 'b0e38c85f5d5e63894a3b881c372a62b',
246
+ 'VCR_ZH_HARD_ALL': '54bbfef448206518b03127ef8b61404c',
247
+ }
248
+
249
+ def __init__(self, dataset='VCR_EN_EASY_500', skip_noimg=True):
250
+ super().__init__(dataset, skip_noimg)
251
+
252
+ initialize()
253
+ self.language = 'en' if 'EN' in dataset else 'zh'
254
+ self.difficulty = 'easy' if 'EASY' in dataset else 'hard'
255
+
256
+ # def build_prompt(self, line):
257
+ # msgs = super().build_prompt(line)
258
+ # assert msgs[-1]['type'] == 'text'
259
+ # if self.language == 'zh':
260
+ # msgs[-1]['value'] += '图像中被覆盖的文本是什么?请在不输出解释的情况下还原被覆盖的文本。'
261
+ # else:
262
+ # msgs[-1]['value'] += ('What is the covered texts in the image? '
263
+ # 'Please restore the covered texts without outputting the explanations.')
264
+ # return msgs
265
+
266
+ def evaluate(self, eval_file, **judge_kwargs):
267
+ import multiprocessing
268
+
269
+ vcr_score_list = {'Exact_Match': [], 'Jaccard': []}
270
+ vcr_score = {'Exact_Match': 0, 'Jaccard': 0}
271
+ logger = get_logger('Evaluation')
272
+ data = load(eval_file)
273
+
274
+ lt = len(data)
275
+ lines = [data.iloc[i] for i in range(lt)]
276
+
277
+ pool = multiprocessing.Pool()
278
+ manager = multiprocessing.Manager()
279
+ progress_queue = manager.Queue()
280
+ results = []
281
+
282
+ overall_results = {str(image_id): {} for image_id in range(len(lines))}
283
+
284
+ for instance_id, instance in enumerate(lines):
285
+ results.append(
286
+ pool.apply_async(
287
+ process_match_single_new,
288
+ args=(
289
+ str(instance_id),
290
+ instance['prediction'],
291
+ instance['answer'],
292
+ self.language,
293
+ progress_queue,
294
+ ),
295
+ )
296
+ )
297
+ pool.close()
298
+
299
+ # Display progress bar
300
+ for _ in tqdm(range(len(results))):
301
+ progress_queue.get()
302
+
303
+ pool.join()
304
+
305
+ # Merging results into overall_result
306
+ for result in results:
307
+ image_id, result_per_id = result.get()
308
+ overall_results[str(image_id)].update(result_per_id[image_id])
309
+ for blank_id_str in result_per_id[image_id].keys():
310
+ vcr_score_list['Exact_Match'].append(
311
+ result_per_id[image_id][blank_id_str]['exact_match']
312
+ )
313
+ vcr_score_list['Jaccard'].append(
314
+ result_per_id[image_id][blank_id_str]['jaccard']
315
+ )
316
+ vcr_score['Exact_Match'] = np.mean(vcr_score_list['Exact_Match'])
317
+ vcr_score['Jaccard'] = np.mean(vcr_score_list['Jaccard'])
318
+ results_out = {
319
+ k: v for i in range(len(results)) for k, v in results[i].get()[1].items()
320
+ }
321
+ results_with_metrics = {
322
+ 'Exact_Match': vcr_score['Exact_Match'],
323
+ 'Jaccard': vcr_score['Jaccard'],
324
+ 'Predictions': results_out,
325
+ }
326
+ score_pth = eval_file.replace(
327
+ '.xlsx', f'{self.language}_{self.difficulty}_score.json'
328
+ )
329
+ dump(results_with_metrics, score_pth)
330
+ logger.info(
331
+ f'VCR successfully finished evaluating {eval_file}, results saved in {score_pth}'
332
+ )
333
+ logger.info('Score: ')
334
+ for key, value in vcr_score.items():
335
+ logger.info('{}:{}'.format(key, value))
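The per-blank metrics come from the n-gram matcher above; a hedged usage sketch with made-up strings (requires the evaluate, spacy, and nltk dependencies plus the en_core_web_sm model):

import vlmeval.dataset.vcr as vcr   # assumed import path

vcr.initialize()                    # loads the global rouge scorer and spacy pipelines
out = vcr.find_best_match('quick brown fox', 'the quick brown fox jumps', 'en', vcr.rouge)
# out['max_sim_string'] == 'quick brown fox'; precision, recall, f1, jaccard,
# rouge1 and exact_match are all 1.0 because the covered span is recovered verbatim.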
VLMEvalKit/vlmeval/dataset/video_base.py ADDED
@@ -0,0 +1,126 @@
1
+ from abc import abstractmethod
2
+ from ..smp import *
3
+
4
+
5
+ class VideoBaseDataset:
6
+
7
+ MODALITY = 'VIDEO'
8
+
9
+ def __init__(self,
10
+ dataset='MMBench-Video',
11
+ pack=False):
12
+ try:
13
+ import decord
14
+ except Exception as e:
15
+ logging.critical(f'{type(e)}: {e}')
16
+ logging.critical('Please install decord via `pip install decord`.')
17
+
18
+ self.dataset_name = dataset
19
+ ret = self.prepare_dataset(dataset)
20
+ assert ret is not None
21
+ lmu_root = LMUDataRoot()
22
+ self.frame_root = osp.join(lmu_root, 'images', dataset)
23
+ os.makedirs(self.frame_root, exist_ok=True)
24
+ self.frame_tmpl = 'frame-{}-of-{}.jpg'
25
+ self.frame_tmpl_fps = 'frame-{}-of-{}-{}fps.jpg'
26
+
27
+ self.data_root = ret['root']
28
+ self.data_file = ret['data_file']
29
+ self.data = load(self.data_file)
30
+
31
+ assert 'question' in self.data and 'video' in self.data
32
+ videos = list(set(self.data['video']))
33
+ videos.sort()
34
+ self.videos = videos
35
+ self.pack = pack
36
+
37
+ def __len__(self):
38
+ return len(self.videos) if self.pack else len(self.data)
39
+
40
+ def __getitem__(self, idx):
41
+ if self.pack:
42
+ assert idx < len(self.videos)
43
+ sub_data = self.data[self.data['video'] == self.videos[idx]]
44
+ return sub_data
45
+ else:
46
+ assert idx < len(self.data)
47
+ return dict(self.data.iloc[idx])
48
+
49
+ def frame_paths(self, video, num_frames=8):
50
+ frame_root = osp.join(self.frame_root, video)
51
+ os.makedirs(frame_root, exist_ok=True)
52
+ return [osp.join(frame_root, self.frame_tmpl.format(i, num_frames)) for i in range(1, num_frames + 1)]
53
+
54
+ def frame_paths_fps(self, video, num_frames=8, fps=-1):
55
+ frame_root = osp.join(self.frame_root, video)
56
+ os.makedirs(frame_root, exist_ok=True)
57
+ return [osp.join(frame_root, self.frame_tmpl_fps.format(i, num_frames, fps)) for i in range(1, num_frames + 1)]
58
+
59
+ def save_video_frames(self, video, num_frames=8, fps=-1):
60
+ if fps > 0:
61
+ vid_path = osp.join(self.data_root, video + '.mp4')
62
+ vid = decord.VideoReader(vid_path)
63
+
64
+ # Compute the total frame count and total duration of the video
65
+ total_frames = len(vid)
66
+ video_fps = vid.get_avg_fps()
67
+ total_duration = total_frames / video_fps
68
+
69
+ # Compute the total number of frames to extract
70
+ required_frames = int(total_duration * fps)
71
+
72
+ # Compute the sampling interval between extracted frames
73
+ step_size = video_fps / fps
74
+
75
+ # Compute the indices of the frames to extract
76
+ indices = [int(i * step_size) for i in range(required_frames)]
77
+
78
+ # Extract the frames and save them
79
+ frame_paths = self.frame_paths_fps(video, len(indices), fps)
80
+ flag = np.all([osp.exists(p) for p in frame_paths])
81
+ if flag:
82
+ return frame_paths
83
+
84
+ images = [vid[i].asnumpy() for i in indices]
85
+ images = [Image.fromarray(arr) for arr in images]
86
+ for im, pth in zip(images, frame_paths):
87
+ if not osp.exists(pth):
88
+ im.save(pth)
89
+ return frame_paths
90
+
91
+ else:
92
+ frame_paths = self.frame_paths(video, num_frames)
93
+ flag = np.all([osp.exists(p) for p in frame_paths])
94
+ if flag:
95
+ return frame_paths
96
+ vid_path = osp.join(self.data_root, video + '.mp4')
97
+ vid = decord.VideoReader(vid_path)
98
+ step_size = len(vid) / (num_frames + 1)
99
+ indices = [int(i * step_size) for i in range(1, num_frames + 1)]
100
+ images = [vid[i].asnumpy() for i in indices]
101
+ images = [Image.fromarray(arr) for arr in images]
102
+ for im, pth in zip(images, frame_paths):
103
+ if not osp.exists(pth):
104
+ im.save(pth)
105
+ return frame_paths
106
+
107
+ # Return a list of dataset names that are supported by this class, can override
108
+ @classmethod
109
+ def supported_datasets(cls):
110
+ return ['MMBench-Video', 'Video-MME', 'MVBench', 'MVBench_MP4', 'LongVideoBench']
111
+
112
+ # Given the prediction file, return the evaluation results in the format of a dictionary or pandas dataframe
113
+ @abstractmethod
114
+ def evaluate(self, eval_file, **judge_kwargs):
115
+ pass
116
+
117
+ @abstractmethod
118
+ def build_prompt(self, idx, num_frames=8):
119
+ pass
120
+
121
+ @abstractmethod
122
+ def prepare_dataset(self, dataset):
123
+ # The prepare_dataset function should return a dictionary containing:
124
+ # `root` (directory that containing video files)
125
+ # `data_file` (the TSV dataset file)
126
+ pass
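The abstract hooks at the end define the contract for concrete video datasets. A skeletal subclass, with an invented root directory and TSV name, might look like:

class MyVideoDataset(VideoBaseDataset):
    TYPE = 'Video-VQA'

    @classmethod
    def supported_datasets(cls):
        return ['MyVideoDataset']

    def prepare_dataset(self, dataset):
        root = osp.join(LMUDataRoot(), 'my_videos')       # directory holding <video>.mp4 files (hypothetical)
        data_file = osp.join(root, f'{dataset}.tsv')      # TSV with at least `video` and `question` columns
        return dict(root=root, data_file=data_file)

    def build_prompt(self, line, num_frames=8, video_llm=False, fps=-1):
        if isinstance(line, int):
            line = self.data.iloc[line]
        frames = self.save_video_frames(line['video'], num_frames)
        return [dict(type='image', value=p) for p in frames] + [dict(type='text', value=line['question'])]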
VLMEvalKit/vlmeval/dataset/video_concat_dataset.py ADDED
@@ -0,0 +1,83 @@
1
+ from ..smp import *
2
+ from .video_base import VideoBaseDataset
3
+
4
+
5
+ class ConcatVideoDataset(VideoBaseDataset):
6
+ # This dataset takes multiple dataset names as input and aggregates them into a single dataset.
7
+ # Each single dataset should not have a field named `SUB_DATASET`
8
+
9
+ DATASET_SETS = {}
10
+
11
+ def __init__(self, dataset):
12
+ from . import build_dataset
13
+ datasets = self.DATASET_SETS[dataset]
14
+ self.dataset_map = {}
15
+ # The name of the compilation
16
+ self.dataset_name = dataset
17
+ self.datasets = datasets
18
+ for dname in datasets:
19
+ dataset = build_dataset(dname)
20
+ assert dataset is not None, dataset
21
+ self.dataset_map[dname] = dataset
22
+ TYPES = [x.TYPE for x in self.dataset_map.values()]
23
+ MODALITIES = [x.MODALITY for x in self.dataset_map.values()]
24
+ # assert np.all([x == TYPES[0] for x in TYPES]), (datasets, TYPES)
25
+ assert np.all([x == MODALITIES[0] for x in MODALITIES]), (datasets, MODALITIES)
26
+ self.TYPE = TYPES
27
+ self.MODALITY = MODALITIES[0]
28
+ data_all = []
29
+ for dname in datasets:
30
+ data = self.dataset_map[dname].data
31
+ data['SUB_DATASET'] = [dname] * len(data)
32
+ data_all.append(data)
33
+
34
+ data = pd.concat(data_all)
35
+ data['original_index'] = data.pop('index')
36
+ data['index'] = np.arange(len(data))
37
+ self.data = data
38
+
39
+ def build_prompt(self, line, num_frames, video_llm, fps):
40
+ if isinstance(line, int):
41
+ line = self.data.iloc[line]
42
+ idx = line['original_index']
43
+ dname = line['SUB_DATASET']
44
+ org_data = self.dataset_map[dname].data
45
+ org_line = cp.deepcopy(org_data[org_data['index'] == idx]).iloc[0]
46
+ return self.dataset_map[dname].build_prompt(org_line, num_frames, video_llm, fps)
47
+
48
+ def dump_image(self, line):
49
+ # Assert all images are pre-dumped
50
+ assert 'image' not in line
51
+ assert 'image_path' in line
52
+ tgt_path = toliststr(line['image_path'])
53
+ return tgt_path
54
+
55
+ @classmethod
56
+ def supported_datasets(cls):
57
+ return [] # list(cls.DATASET_SETS)
58
+
59
+ def evaluate(self, eval_file, **judge_kwargs):
60
+ suffix = eval_file.split('.')[-1]
61
+ # First, split the eval_file by dataset
62
+ data_all = load(eval_file)
63
+ for dname in self.datasets:
64
+ tgt = eval_file.replace(self.dataset_name, dname)
65
+ data_sub = data_all[data_all['SUB_DATASET'] == dname]
66
+ data_sub.pop('index')
67
+ data_sub['index'] = data_sub.pop('original_index')
68
+ data_sub.pop('SUB_DATASET')
69
+ dump(data_sub, tgt)
70
+ # Then, evaluate each dataset separately
71
+ results_all = {}
72
+ for dname in self.datasets:
73
+ tgt = eval_file.replace(self.dataset_name, dname)
74
+ res = self.dataset_map[dname].evaluate(tgt, **judge_kwargs)
75
+ results_all.update(res)
76
+
77
+ result = pd.DataFrame(results_all, index=['success', 'overall'])
78
+ result = result.T
79
+ for idx, item in result.iterrows():
80
+ result.loc[idx, 'acc'] = round(item['success'] / item['overall'] * 100, 1)
81
+ score_file = eval_file.replace(f'.{suffix}', '_acc.csv')
82
+ dump(result, score_file)
83
+ return result
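Concrete compilations only need to fill DATASET_SETS; a hypothetical grouping (valid only if the member names are registered with build_dataset):

class TempCompassPack(ConcatVideoDataset):
    DATASET_SETS = {
        'TempCompass_Pack': ['TempCompass_Captioning', 'TempCompass_YorN'],
    }

# ds = TempCompassPack('TempCompass_Pack')
# ds.data gains a SUB_DATASET column and a fresh global `index`,
# while each row keeps its per-dataset id in `original_index`.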
VLMEvalKit/vlmeval/dataset/videomme.py ADDED
@@ -0,0 +1,287 @@
1
+ from huggingface_hub import snapshot_download
2
+ from ..smp import *
3
+ from .video_base import VideoBaseDataset
4
+ from .utils import build_judge, DEBUG_MESSAGE
5
+
6
+ FAIL_MSG = 'Failed to obtain answer via API.'
7
+
8
+
9
+ def unwrap_hf_pkl(pth, suffix='.mp4'):
10
+ base_dir = os.path.join(pth, 'video_pkl/')
11
+ target_dir = os.path.join(pth, 'video/')
12
+ pickle_files = [os.path.join(base_dir, file) for file in os.listdir(base_dir)]
13
+ pickle_files.sort()
14
+
15
+ if not os.path.exists(target_dir):
16
+ os.makedirs(target_dir, exist_ok=True)
17
+ for pickle_file in pickle_files:
18
+ with open(pickle_file, 'rb') as file:
19
+ video_data = pickle.load(file)
20
+ # For each video file in the pickle file, write its contents to a new mp4 file
21
+ for video_name, video_content in video_data.items():
22
+ output_path = os.path.join(target_dir, f'{video_name}{suffix}')
23
+ with open(output_path, 'wb') as output_file:
24
+ output_file.write(video_content)
25
+ print('The video file has been restored and stored from the pickle file.')
26
+ else:
27
+ print('The video file already exists.')
28
+
29
+
30
+ class VideoMME(VideoBaseDataset):
31
+
32
+ MD5 = '85bdd91f9b29a99354c23b97ab7c113c'
33
+ SYS = ''
34
+
35
+ FRAMES_TMPL_NOSUB = """
36
+ These are the frames of a video. \
37
+ Select the best answer to the following multiple-choice question based on the video. \
38
+ Respond with only the letter (A, B, C, or D) of the correct option.
39
+ """
40
+
41
+ FRAMES_TMPL_SUB = """
42
+ These are the frames of a video. \
43
+ This video's subtitles are listed below:
44
+ {}
45
+ Select the best answer to the following multiple-choice question based on the video. \
46
+ Respond with only the letter (A, B, C, or D) of the correct option.
47
+ """
48
+
49
+ TYPE = 'Video-MCQ'
50
+
51
+ def __init__(self, dataset='Video-MME', use_subtitle=False):
52
+ super().__init__(dataset=dataset)
53
+ self.use_subtitle = use_subtitle
54
+ self.dataset_name = dataset
55
+
56
+ @classmethod
57
+ def supported_datasets(cls):
58
+ return ['Video-MME']
59
+
60
+ def prepare_dataset(self, dataset_name='Video-MME', repo_id='lmms-lab/Video-MME'):
61
+
62
+ def check_integrity(pth):
63
+ data_file = osp.join(pth, f'{dataset_name}.tsv')
64
+
65
+ if not os.path.exists(data_file):
66
+ return False
67
+
68
+ if md5(data_file) != self.MD5:
69
+ return False
70
+ data = load(data_file)
71
+ for video_pth in data['video_path']:
72
+ if not osp.exists(osp.join(pth, video_pth)):
73
+ return False
74
+ return True
75
+
76
+ cache_path = get_cache_path(repo_id)
77
+ if cache_path is not None and check_integrity(cache_path):
78
+ dataset_path = cache_path
79
+ else:
80
+
81
+ def unzip_hf_zip(pth):
82
+ import zipfile
83
+ base_dir = pth
84
+ target_dir = os.path.join(pth, 'video/')
85
+ zip_files = [
86
+ os.path.join(base_dir, file) for file in os.listdir(base_dir)
87
+ if file.endswith('.zip') and file.startswith('video')
88
+ ]
89
+ zip_files.sort()
90
+
91
+ if not os.path.exists(target_dir):
92
+ os.makedirs(target_dir, exist_ok=True)
93
+ for zip_file in zip_files:
94
+ with zipfile.ZipFile(zip_file, 'r') as zip_ref:
95
+ for member in zip_ref.namelist():
96
+ # Check if the member is a file (not a directory)
97
+ if not member.endswith('/'):
98
+ # Extract the file to the specified directory
99
+ source = zip_ref.open(member)
100
+ target = open(os.path.join(target_dir, os.path.basename(member)), 'wb')
101
+ with source, target:
102
+ target.write(source.read())
103
+ print('The video file has been restored and stored from the zip file.')
104
+ else:
105
+ print('The video file already exists.')
106
+
107
+ subtitle_zip_file = os.path.join(base_dir, 'subtitle.zip')
108
+ subtitle_target_dir = os.path.join(base_dir, 'subtitle')
109
+
110
+ if not os.path.exists(subtitle_target_dir):
111
+ os.makedirs(subtitle_target_dir, exist_ok=True)
112
+ with zipfile.ZipFile(subtitle_zip_file, 'r') as zip_ref:
113
+ for member in zip_ref.namelist():
114
+ # Check if the member is a file (not a directory)
115
+ if not member.endswith('/'):
116
+ # Extract the file to the specified directory
117
+ source = zip_ref.open(member)
118
+ target = open(os.path.join(subtitle_target_dir, os.path.basename(member)), 'wb')
119
+ with source, target:
120
+ target.write(source.read())
121
+ print('The subtitle file has been restored and stored from the zip file.')
122
+ else:
123
+ print('The subtitle file already exists.')
124
+
125
+ def generate_tsv(pth):
126
+
127
+ data_file = osp.join(pth, f'{dataset_name}.tsv')
128
+ if os.path.exists(data_file) and md5(data_file) == self.MD5:
129
+ return
130
+
131
+ data_file = pd.read_parquet(os.path.join(pth, 'videomme/test-00000-of-00001.parquet'))
132
+ data_file = data_file.assign(index=range(len(data_file)))
133
+ data_file['video'] = data_file['videoID']
134
+ data_file['video_path'] = data_file['videoID'].apply(lambda x: f'./video/{x}.mp4')
135
+ data_file['subtitle_path'] = data_file['videoID'].apply(lambda x: f'./subtitle/{x}.srt')
136
+ data_file['candidates'] = data_file['options'].apply(lambda x: x.tolist())
137
+
138
+ data_file = data_file[['index', 'video', 'video_path', 'duration', 'domain', 'candidates',
139
+ 'sub_category', 'task_type', 'subtitle_path', 'question', 'answer']]
140
+
141
+ data_file.to_csv(osp.join(pth, f'{dataset_name}.tsv'), sep='\t', index=False)
142
+
143
+ if modelscope_flag_set():
144
+ from modelscope import dataset_snapshot_download
145
+ dataset_path = dataset_snapshot_download(dataset_id=repo_id)
146
+ else:
147
+ dataset_path = snapshot_download(repo_id=repo_id, repo_type='dataset')
148
+ unzip_hf_zip(dataset_path)
149
+ generate_tsv(dataset_path)
150
+
151
+ data_file = osp.join(dataset_path, f'{dataset_name}.tsv')
152
+
153
+ return dict(data_file=data_file, root=dataset_path)
154
+
155
+ def save_video_frames(self, video, num_frames=8, fps=-1, video_llm=False):
156
+
157
+ vid_path = osp.join(self.data_root, 'video', video + '.mp4')
158
+ vid = decord.VideoReader(vid_path)
159
+ video_info = {
160
+ 'fps': vid.get_avg_fps(),
161
+ 'n_frames': len(vid),
162
+ }
163
+ if num_frames > 0 and fps < 0:
164
+ step_size = len(vid) / (num_frames + 1)
165
+ indices = [int(i * step_size) for i in range(1, num_frames + 1)]
166
+ frame_paths = self.frame_paths(video, num_frames)
167
+ elif fps > 0:
168
+ # not constrained by num_frames, get frames by fps
169
+ total_duration = video_info['n_frames'] / video_info['fps']
170
+ required_frames = int(total_duration * fps)
171
+ step_size = video_info['fps'] / fps
172
+ indices = [int(i * step_size) for i in range(required_frames)]
173
+ frame_paths = self.frame_paths_fps(video, len(indices), fps)
174
+
175
+ flag = np.all([osp.exists(p) for p in frame_paths])
176
+
177
+ if not flag:
178
+ images = [vid[i].asnumpy() for i in indices]
179
+ images = [Image.fromarray(arr) for arr in images]
180
+ for im, pth in zip(images, frame_paths):
181
+ if not osp.exists(pth) and not video_llm:
182
+ im.save(pth)
183
+
184
+ return frame_paths, indices, video_info
185
+
186
+ def save_video_into_images(self, line, num_frames=8):
187
+ frame_paths, indices, video_info = self.save_video_frames(line['video'], num_frames)
188
+ return frame_paths
189
+
190
+ def build_prompt(self, line, num_frames, video_llm, fps):
191
+ if isinstance(line, int):
192
+ assert line < len(self)
193
+ line = self.data.iloc[line]
194
+
195
+ frames, indices, video_info = self.save_video_frames(line['video'], num_frames, fps, video_llm)
196
+
197
+ if self.use_subtitle and os.path.exists(osp.join(self.data_root, line['subtitle_path'])):
198
+ import pysubs2
199
+ subs = pysubs2.load(osp.join(self.data_root, line['subtitle_path']), encoding='utf-8')
200
+ subtitles = []
201
+
202
+ for seleced_frame_id in indices:
203
+ sub_text = ''
204
+ cur_time = pysubs2.make_time(fps=video_info['fps'], frames=seleced_frame_id)
205
+ for sub in subs:
206
+ if sub.start < cur_time and sub.end > cur_time:
207
+ sub_text = sub.text.replace('\\N', ' ')
208
+ break
209
+ if sub_text.strip():
210
+ subtitles.append(sub_text)
211
+ subtitles = '\n'.join(subtitles)
212
+ else:
213
+ subtitles = ''
214
+
215
+ message = [dict(type='text', value=self.SYS)]
216
+ if video_llm:
217
+ message.append(dict(type='video', value=osp.join(self.data_root, 'video', line['video'] + '.mp4')))
218
+ else:
219
+ for im in frames:
220
+ message.append(dict(type='image', value=im))
221
+
222
+ text_prompt = self.FRAMES_TMPL_NOSUB if not self.use_subtitle else self.FRAMES_TMPL_SUB.format(subtitles)
223
+ message.append(dict(type='text', value=text_prompt))
224
+ line['question'] += '\n' + '\n'.join(eval(line['candidates']))
225
+ prompt = 'Question: {}\nAnswer: '.format(line['question'])
226
+ message.append(dict(type='text', value=prompt))
227
+ return message
228
+
229
+ # It returns a dictionary
230
+ @classmethod
231
+ def evaluate(self, eval_file, **judge_kwargs):
232
+ from .utils.videomme import get_dimension_rating, extract_characters_regex, extract_option
233
+
234
+ assert eval_file.endswith('.xlsx'), 'data file should be an xlsx file'
235
+
236
+ tmp_file = eval_file.replace('.xlsx', '_tmp.pkl')
237
+ tgt_file = eval_file.replace('.xlsx', '_rating.json')
238
+ score_file = eval_file.replace('.xlsx', '_score.xlsx')
239
+
240
+ if not osp.exists(score_file):
241
+ model = judge_kwargs.get('model', 'exact_matching')
242
+ assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']
243
+
244
+ if model == 'exact_matching':
245
+ model = None
246
+ elif gpt_key_set():
247
+ model = build_judge(**judge_kwargs)
248
+ if not model.working():
249
+ warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
250
+ warnings.warn(DEBUG_MESSAGE)
251
+ model = None
252
+ else:
253
+ warnings.warn('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
254
+ model = None
255
+ res = {} if not osp.exists(tmp_file) else load(tmp_file)
256
+ res = {k: v for k, v in res.items() if FAIL_MSG not in v}
257
+
258
+ data = load(eval_file)
259
+ data_un = data[~pd.isna(data['prediction'])]
260
+
261
+ for idx in data['index']:
262
+ ans = data.loc[data['index'] == idx, 'answer'].values[0]
263
+ pred = str(data.loc[data['index'] == idx, 'prediction'].values[0])
264
+
265
+ if extract_characters_regex(pred) == '':
266
+ extract_pred = extract_option(
267
+ model,
268
+ data.loc[data['index'] == idx].to_dict(orient='records')[0],
269
+ 'Video-MME'
270
+ )
271
+ data.loc[idx, 'score'] = int(extract_pred == ans)
272
+ else:
273
+ data.loc[idx, 'score'] = int(extract_characters_regex(pred) == ans)
274
+
275
+ rejected = [x for x in data['score'] if x == -1]
276
+
277
+ print(
278
+ f'Among {len(data)} questions, failed to obtain prediction for {len(data) - len(data_un)} questions, '
279
+ f'failed to obtain the score for another {len(rejected)} questions. '
280
+ f'Those questions will be counted as -1 score in ALL rating, and will not be counted in VALID rating.'
281
+ )
282
+
283
+ dump(data, score_file)
284
+
285
+ rating = get_dimension_rating(score_file)
286
+ dump(rating, tgt_file)
287
+ return rating
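The two sampling modes used across these video datasets are easy to sanity-check with concrete numbers (values below are illustrative only):

n_frames, video_fps, fps = 900, 30.0, 1      # a 30 s clip at 30 fps, sampled at 1 fps
total_duration = n_frames / video_fps        # 30.0 s
required_frames = int(total_duration * fps)  # 30 frames requested
step_size = video_fps / fps                  # one sample every 30 source frames
indices = [int(i * step_size) for i in range(required_frames)]
assert indices[:3] == [0, 30, 60] and len(indices) == 30

num_frames = 8                               # fixed-count mode: uniform spacing, both ends skipped
step = n_frames / (num_frames + 1)           # 100.0
uniform = [int(i * step) for i in range(1, num_frames + 1)]
assert uniform == [100, 200, 300, 400, 500, 600, 700, 800]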
VLMEvalKit/vlmeval/dataset/wildvision.py ADDED
@@ -0,0 +1,218 @@
1
+ import re
2
+ from functools import partial
3
+
4
+ from .image_base import ImageBaseDataset
5
+ from .utils import build_judge, DEBUG_MESSAGE
6
+ from ..smp import *
7
+ from ..utils import track_progress_rich
8
+
9
+
10
+ SYSTEM_PROMPT = """\
11
+ Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user \
12
+ prompt displayed below. You will be given assistant A's answer and assistant B's answer. Your job is to evaluate \
13
+ which assistant's answer is better.
14
+
15
+ Begin your evaluation by generating your own answer to the prompt. You must provide your answers before judging any \
16
+ answers.
17
+
18
+ When evaluating the assistants' answers, compare both assistants' answers with your answer. \
19
+ You must identify and correct any mistakes or inaccurate information.
20
+
21
+ Then consider if the assistant's answers are helpful, relevant, and concise. Helpful means the answer correctly \
22
+ responds to the prompt or follows the instructions. Note when user prompt has any ambiguity or more than one \
23
+ interpretation, it is more helpful and appropriate to ask for clarifications or more information from the user than \
24
+ providing an answer based on assumptions. Relevant means all parts of the response closely connect or are appropriate \
25
+ to what is being asked. Concise means the response is clear and not verbose or excessive.
26
+
27
+ Then consider the creativity and novelty of the assistant's answers when needed. Finally, identify any missing \
28
+ important information in the assistants' answers that would be beneficial to include when responding to the user \
29
+ prompt.
30
+
31
+ After providing your explanation, you must output only one of the following choices as your final verdict with a label:
32
+
33
+ 1. Assistant A is significantly better: [[A>>B]]
34
+ 2. Assistant A is slightly better: [[A>B]]
35
+ 3. Tie, relatively the same: [[A=B]]
36
+ 4. Assistant B is slightly better: [[B>A]]
37
+ 5. Assistant B is significantly better: [[B>>A]]
38
+
39
+ Example output: "My final verdict is tie: [[A=B]]".\
40
+ """
41
+
42
+
43
+ PROMPT_TEMPLATE = """\
44
+ "<|User Prompt|>\n{question}
45
+
46
+ <|The Start of Assistant A's Answer|>\n{answer_1}\n<|The End of Assistant A's Answer|>
47
+
48
+ <|The Start of Assistant B's Answer|>\n{answer_2}\n<|The End of Assistant B's Answer|>
49
+ """
50
+
51
+
52
+ REGEX_PATTERN = re.compile("\[\[([AB<>=]+)\]\]") # noqa: W605
53
+
54
+
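+ # Parse the judge's verdict token (e.g. [[A>B]]) from its response.
+ # Returns (verdict, retry_needed); retry_needed is True when no verdict or
+ # conflicting verdicts are found.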
55
+ def get_score(judgement, pattern=REGEX_PATTERN):
56
+ matches = pattern.findall(judgement)
57
+ matches = [m for m in matches if m != ""]
58
+ if len(set(matches)) == 0:
59
+ return None, True
60
+ elif len(set(matches)) == 1:
61
+ return matches[0].strip("\n"), False
62
+ else:
63
+ return None, True
64
+
65
+
66
+ def WildVision_auxeval(model, line):
67
+ config = dict(question=line['question'], answer_1=line['A'], answer_2=line['B'])
68
+ prompt = PROMPT_TEMPLATE.format(**config)
69
+
70
+ prefix = 'data:image/jpeg;base64,'
71
+ img = prefix + line['image']
72
+
73
+ messages = [
74
+ dict(type='text', value=prompt),
75
+ dict(type='image', value=img)
76
+ ]
77
+
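+ # Query the judge up to two times, stopping once a single unambiguous verdict
+ # is parsed; otherwise the sample is marked 'Unknown'.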
78
+ retry = 2
79
+ while retry:
80
+ resp = model.generate(messages)
81
+ score, try_again = get_score(resp)
82
+ if not try_again:
83
+ break
84
+ retry -= 1
85
+
86
+ if score is None:
87
+ return 'Unknown'
88
+ return score
89
+
90
+
91
+ class WildVision(ImageBaseDataset):
92
+ TYPE = 'VQA'
93
+ DATASET_URL = {
94
+ 'WildVision': 'https://opencompass.openxlab.space/utils/VLMEval/WildVision.tsv'
95
+ }
96
+ DATASET_MD5 = {'WildVision': 'b38f80156d49411c594772866b0d0b52'}
97
+
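+ # Judge verdicts map to signed scores: A is the Claude-3-Sonnet reference answer
+ # and B is the evaluated model's prediction, so positive values favour the
+ # evaluated model.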
98
+ score_map = {
99
+ 'A>>B': -2,
100
+ 'A>B': -1,
101
+ 'A=B': 0,
102
+ 'B>A': 1,
103
+ 'B>>A': 2
104
+ }
105
+
106
+ # Given one data record, return the built prompt (a multi-modal message), can override
107
+ def build_prompt(self, line):
108
+ if isinstance(line, int):
109
+ line = self.data.iloc[line]
110
+
111
+ if self.meta_only:
112
+ tgt_path = toliststr(line['image_path'])
113
+ else:
114
+ tgt_path = self.dump_image(line)
115
+
116
+ question = line['question']
117
+
118
+ msgs = []
119
+ if isinstance(tgt_path, list):
120
+ msgs.extend([dict(type='image', value=p) for p in tgt_path])
121
+ else:
122
+ msgs = [dict(type='image', value=tgt_path)]
123
+ # WildVision adopts text first
124
+ msgs = [dict(type='text', value=question)] + msgs
125
+ return msgs
126
+
127
+ @classmethod
128
+ def gen_eval_base(self, eval_file, b64_map):
129
+ data = load(eval_file)
130
+ data['B'] = data.pop('prediction')
131
+ data['A'] = data.pop('claude3_sonnet')
132
+ data['image'] = [b64_map[x] for x in data['index']]
133
+ return data
134
+ # rev = cp.deepcopy(data)
135
+ # rev['A'] = data['B']
136
+ # rev['B'] = data['A']
137
+ # rev['index'] = [x + '_rev' for x in data['index']]
138
+ # return pd.concat([data, rev], ignore_index=True)
139
+
140
+ # It returns a DataFrame
141
+ @classmethod
142
+ def evaluate(self, eval_file, **judge_kwargs):
143
+ # Pairwise evaluation against the Claude-3-Sonnet reference; the reverse-order pass is currently disabled (see gen_eval_base)
144
+ suffix = eval_file.split('.')[-1]
145
+ model = judge_kwargs['model']
146
+ storage = eval_file.replace(f'.{suffix}', f'_{model}.xlsx')
147
+ score_file = eval_file.replace(f'.{suffix}', f'_{model}_score.csv')
148
+ tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
149
+ nproc = judge_kwargs.pop('nproc', 4)
150
+
151
+ if not osp.exists(storage):
152
+ raw_data = WildVision('WildVision').data
153
+ b64_map = {x: y for x, y in zip(raw_data['index'], raw_data['image'])}
154
+ data = self.gen_eval_base(eval_file, b64_map)
155
+
156
+ judge_kwargs['system_prompt'] = SYSTEM_PROMPT
157
+ judge_kwargs['temperature'] = 0
158
+ judge_kwargs['img_detail'] = 'high'
159
+ judge_kwargs['timeout'] = 300
160
+ model = build_judge(max_tokens=4096, **judge_kwargs)
161
+
162
+ assert model.working(), ('WildVision evaluation requires a working OPENAI API\n' + DEBUG_MESSAGE)
163
+
164
+ lt = len(data)
165
+ lines = [data.iloc[i] for i in range(lt)]
166
+ tups = [(model, line) for line in lines]
167
+ indices = [line['index'] for line in lines]
168
+
169
+ ans = load(tmp_file) if osp.exists(tmp_file) else {}
170
+ tups = [x for x, i in zip(tups, indices) if i not in ans]
171
+ indices = [i for i in indices if i not in ans]
172
+
173
+ if len(indices):
174
+ new_results = track_progress_rich(
175
+ WildVision_auxeval,
176
+ tups,
177
+ nproc=nproc,
178
+ chunksize=nproc,
179
+ keys=indices,
180
+ save=tmp_file,
181
+ )
182
+ ans = load(tmp_file)
183
+ for k, v in zip(indices, new_results):
184
+ ans[k] = v
185
+
186
+ data['score'] = [ans[idx] for idx in data['index']]
187
+ data.pop('image')
188
+ dump(data, storage)
189
+
190
+ data = load(storage)
191
+ lt = len(data)
192
+
193
+ scores = defaultdict(lambda: 0)
194
+ for i in range(lt):
195
+ item = data.iloc[i]
196
+ if item['score'] not in self.score_map:
197
+ score = 0
198
+ else:
199
+ score = self.score_map[item['score']]
200
+ if '_rev' in item['index']:
201
+ score = -score
202
+ scores[score] += 1
203
+ name_map = {
204
+ 2: 'Much Better',
205
+ 1: 'Better',
206
+ 0: 'Tie',
207
+ -1: 'Worse',
208
+ -2: 'Much Worse'
209
+ }
210
+ scores = {name_map[k]: scores[k] for k in name_map}  # defaultdict yields 0 for categories with no samples
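+ # Reward in [-100, 100]: +100 per 'Much Better', +50 per 'Better', -50 per 'Worse',
+ # -100 per 'Much Worse', averaged over all samples.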
211
+ scores['Reward'] = (
212
+ 100 * scores['Much Better'] + 50 * scores['Better'] - 50 * scores['Worse'] - 100 * scores['Much Worse']
213
+ ) / lt
214
+ scores['Win Rate'] = (scores['Better'] + scores['Much Better']) / lt
215
+ scores = {k: [v] for k, v in scores.items()}
216
+ scores = pd.DataFrame(scores)
217
+ dump(scores, score_file)
218
+ return scores
VLMEvalKit/vlmeval/smp/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .file import *
2
+ from .vlm import *
3
+ from .misc import *
4
+ from .log import *
VLMEvalKit/vlmeval/smp/log.py ADDED
@@ -0,0 +1,47 @@
1
+ import logging
2
+ logging.basicConfig(
3
+ format='[%(asctime)s] %(levelname)s - %(filename)s: %(funcName)s - %(lineno)d: %(message)s',
4
+ datefmt='%Y-%m-%d %H:%M:%S')
5
+
6
+ logger_initialized = {}
7
+
8
+
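+ # Return a named logger with a stream handler and an optional file handler.
+ # When torch.distributed is initialized, only rank 0 logs at log_level; other
+ # ranks log errors only.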
9
+ def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
10
+ logger = logging.getLogger(name)
11
+ if name in logger_initialized:
12
+ return logger
13
+
14
+ for logger_name in logger_initialized:
15
+ if name.startswith(logger_name):
16
+ return logger
17
+
18
+ stream_handler = logging.StreamHandler()
19
+ handlers = [stream_handler]
20
+
21
+ try:
22
+ import torch.distributed as dist
23
+ if dist.is_available() and dist.is_initialized():
24
+ rank = dist.get_rank()
25
+ else:
26
+ rank = 0
27
+ except ImportError:
28
+ rank = 0
29
+
30
+ if rank == 0 and log_file is not None:
31
+ file_handler = logging.FileHandler(log_file, file_mode)
32
+ handlers.append(file_handler)
33
+
34
+ formatter = logging.Formatter(
35
+ '[%(asctime)s] %(levelname)s - %(name)s - %(filename)s: %(funcName)s - %(lineno)d: %(message)s')
36
+ for handler in handlers:
37
+ handler.setFormatter(formatter)
38
+ handler.setLevel(log_level)
39
+ logger.addHandler(handler)
40
+
41
+ if rank == 0:
42
+ logger.setLevel(log_level)
43
+ else:
44
+ logger.setLevel(logging.ERROR)
45
+
46
+ logger_initialized[name] = True
47
+ return logger
VLMEvalKit/vlmeval/smp/misc.py ADDED
@@ -0,0 +1,280 @@
1
+ # flake8: noqa: F401, F403
2
+ import abc
3
+ import argparse
4
+ import csv
5
+ import multiprocessing as mp
6
+ import os
7
+ import os.path as osp
8
+ from pathlib import Path
9
+ import copy as cp
10
+ import random as rd
11
+ import requests
12
+ import shutil
13
+ import subprocess
14
+ import warnings
15
+ import pandas as pd
16
+ from collections import OrderedDict, defaultdict
17
+ from multiprocessing import Pool, current_process
18
+ from tqdm import tqdm
19
+ import datetime
20
+ import matplotlib.pyplot as plt
21
+ from tabulate import tabulate
22
+ from json import JSONDecoder
23
+ from huggingface_hub import scan_cache_dir
24
+ from huggingface_hub.utils._cache_manager import _scan_cached_repo
25
+ from sty import fg, bg, ef, rs
26
+
27
+
28
+ def modelscope_flag_set():
29
+ return os.environ.get('VLMEVALKIT_USE_MODELSCOPE', None) in ['1', 'True']
30
+
31
+
32
+ def process_punctuation(inText):
33
+ import re
34
+ outText = inText
35
+ punct = [
36
+ ';', r'/', '[', ']', '"', '{', '}', '(', ')', '=', '+', '\\', '_', '-',
37
+ '>', '<', '@', '`', ',', '?', '!'
38
+ ]
39
+ commaStrip = re.compile('(\d)(,)(\d)') # noqa: W605
40
+ periodStrip = re.compile('(?!<=\d)(\.)(?!\d)') # noqa: W605
41
+ for p in punct:
42
+ if (p + ' ' in inText or ' ' + p in inText) or (re.search(
43
+ commaStrip, inText) is not None):
44
+ outText = outText.replace(p, '')
45
+ else:
46
+ outText = outText.replace(p, ' ')
47
+ outText = periodStrip.sub('', outText, re.UNICODE)
48
+ return outText
49
+
50
+ def h2r(value):
51
+ if value[0] == '#':
52
+ value = value[1:]
53
+ assert len(value) == 6
54
+ return tuple(int(value[i:i + 2], 16) for i in range(0, 6, 2))
55
+
56
+ def r2h(rgb):
57
+ return '#%02x%02x%02x' % rgb
58
+
59
+ def colored(s, color):
60
+ if isinstance(color, str):
61
+ if hasattr(fg, color):
62
+ return getattr(fg, color) + s + fg.rs
63
+ color = h2r(color)
64
+ return fg(*color) + s + fg.rs
65
+
66
+ def istype(s, type):
67
+ if isinstance(s, type):
68
+ return True
69
+ try:
70
+ return isinstance(eval(s), type)
71
+ except Exception as _:
72
+ return False
73
+
74
+ def bincount(lst):
75
+ bins = defaultdict(lambda: 0)
76
+ for item in lst:
77
+ bins[item] += 1
78
+ return bins
79
+
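+ # Resolve the local cache directory of a HuggingFace / ModelScope repo; for HF,
+ # pick the most recently modified revision matching `branch`. Returns None when
+ # the repo is not cached or scanning fails.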
80
+ def get_cache_path(repo_id, branch='main', repo_type='datasets'):
81
+ try:
82
+ if modelscope_flag_set():
83
+ from modelscope.hub.file_download import create_temporary_directory_and_cache
84
+ if repo_type == 'datasets':
85
+ repo_type = 'dataset'
86
+ _, cache = create_temporary_directory_and_cache(model_id=repo_id, repo_type=repo_type)
87
+ cache_path = cache.get_root_location()
88
+ return cache_path
89
+ else:
90
+ from .file import HFCacheRoot
91
+ cache_path = HFCacheRoot()
92
+ org, repo_name = repo_id.split('/')
93
+ repo_path = Path(osp.join(cache_path, f'{repo_type}--{org}--{repo_name}/'))
94
+ hf_cache_info = _scan_cached_repo(repo_path=repo_path)
95
+ revs = {r.refs: r for r in hf_cache_info.revisions}
96
+ if branch is not None:
97
+ revs = {refs: r for refs, r in revs.items() if branch in refs}
98
+ rev2keep = max(revs.values(), key=lambda r: r.last_modified)
99
+ return str(rev2keep.snapshot_path)
100
+ except Exception as e:
101
+ import logging
102
+ logging.warning(f'{type(e)}: {e}')
103
+ return None
104
+
105
+ def proxy_set(s):
106
+ import os
107
+ for key in ['http_proxy', 'HTTP_PROXY', 'https_proxy', 'HTTPS_PROXY']:
108
+ os.environ[key] = s
109
+
110
+ def get_rank_and_world_size():
111
+ rank = int(os.environ.get('RANK', 0))
112
+ world_size = int(os.environ.get('WORLD_SIZE', 1))
113
+ return rank, world_size
114
+
115
+ def splitlen(s, sym='/'):
116
+ return len(s.split(sym))
117
+
118
+ def listinstr(lst, s):
119
+ assert isinstance(lst, list)
120
+ for item in lst:
121
+ if item in s:
122
+ return True
123
+ return False
124
+
125
+ def d2df(D):
126
+ return pd.DataFrame({x: [D[x]] for x in D})
127
+
128
+ def cn_string(s):
129
+ import re
130
+ if re.search(u'[\u4e00-\u9fff]', s):
131
+ return True
132
+ return False
133
+
134
+ try:
135
+ import decord
136
+ except ImportError:
137
+ pass
138
+
139
+ def timestr(granularity='second'):
140
+ s = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
141
+ assert granularity in ['second', 'minute', 'hour', 'day']
142
+ if granularity == 'second':
143
+ return s
144
+ elif granularity == 'minute':
145
+ return s[:-2]
146
+ elif granularity == 'hour':
147
+ return s[:-4]
148
+ elif granularity == 'day':
149
+ return s[:-6]
150
+
151
+ def _minimal_ext_cmd(cmd, cwd=None):
152
+ env = {}
153
+ for k in ['SYSTEMROOT', 'PATH', 'HOME']:
154
+ v = os.environ.get(k)
155
+ if v is not None:
156
+ env[k] = v
157
+ env['LANGUAGE'] = 'C'
158
+ env['LANG'] = 'C'
159
+ env['LC_ALL'] = 'C'
160
+ out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env, cwd=cwd).communicate()[0]
161
+ return out
162
+
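+ # Return the git commit hash (truncated to `digits` characters) of the vlmeval
+ # checkout, falling back to `fallback` when git or the package is unavailable.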
163
+ def githash(fallback='unknown', digits=8):
164
+ if digits is not None and not isinstance(digits, int):
165
+ raise TypeError('digits must be None or an integer')
166
+ try:
167
+ import vlmeval
168
+ except ImportError as e:
169
+ import logging
170
+ logging.error(f'ImportError: {str(e)}')
171
+ return fallback
172
+ try:
173
+ out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'], cwd=vlmeval.__path__[0])
174
+ sha = out.strip().decode('ascii')
175
+ if digits is not None:
176
+ sha = sha[:digits]
177
+ except OSError:
178
+ sha = fallback
179
+ return sha
180
+
181
+ def dict_merge(dct, merge_dct):
182
+ for k, _ in merge_dct.items():
183
+ if (k in dct and isinstance(dct[k], dict) and isinstance(merge_dct[k], dict)): #noqa
184
+ dict_merge(dct[k], merge_dct[k])
185
+ else:
186
+ dct[k] = merge_dct[k]
187
+
188
+ def youtube_dl(idx):
189
+ cmd = f'youtube-dl -f best -f mp4 "{idx}" -o {idx}.mp4'
190
+ os.system(cmd)
191
+
192
+ def run_command(cmd):
193
+ if isinstance(cmd, str):
194
+ cmd = cmd.split()
195
+ return subprocess.check_output(cmd).decode()
196
+
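+ # Load environment variables (e.g. API keys) from the .env file at the repo root into os.environ.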
197
+ def load_env():
198
+ import logging
199
+ logging.basicConfig(
200
+ format='[%(asctime)s] %(levelname)s - %(filename)s: %(funcName)s - %(lineno)d: %(message)s',
201
+ datefmt='%Y-%m-%d %H:%M:%S')
202
+
203
+ try:
204
+ import vlmeval
205
+ except ImportError:
206
+ logging.error('VLMEval is not installed. Failed to import environment variables from .env file. ')
207
+ return
208
+ pth = osp.realpath(vlmeval.__path__[0])
209
+ pth = osp.join(pth, '../.env')
210
+ pth = osp.realpath(pth)
211
+ if not osp.exists(pth):
212
+ logging.error(f'Did not detect the .env file at {pth}, failed to load. ')
213
+ return
214
+
215
+ from dotenv import dotenv_values
216
+ values = dotenv_values(pth)
217
+ for k, v in values.items():
218
+ if v is not None and len(v):
219
+ os.environ[k] = v
220
+ logging.info(f'API Keys successfully loaded from {pth}')
221
+
222
+ def pip_install_robust(package):
223
+ import sys
224
+ retry = 3
225
+ while retry > 0:
226
+ try:
227
+ package_base = package.split('=')[0]
228
+ __import__(package_base)  # import by the base package name, ignoring any version specifier
229
+ return True
230
+ except ImportError:
231
+ subprocess.check_call([sys.executable, '-m', 'pip', 'install', package])
232
+ retry -= 1
233
+ return False
234
+
235
+
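+ # Compare two version strings with an operator name from the `operator` module,
+ # e.g. version_cmp('4.37.0', '4.33.0', 'ge') -> True.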
236
+ def version_cmp(v1, v2, op='eq'):
237
+ from packaging import version
238
+ import operator
239
+ op_func = getattr(operator, op)
240
+ return op_func(version.parse(v1), version.parse(v2))
241
+
242
+
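+ # Normalize str / list / stringified-list inputs into a list of strings,
+ # e.g. "['a.jpg', 'b.jpg']" -> ['a.jpg', 'b.jpg'] and 'a.jpg' -> ['a.jpg'].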
243
+ def toliststr(s):
244
+ if isinstance(s, str) and (s[0] == '[') and (s[-1] == ']'):
245
+ return [str(x) for x in eval(s)]
246
+ elif isinstance(s, str):
247
+ return [s]
248
+ elif isinstance(s, list):
249
+ return [str(x) for x in s]
250
+ raise NotImplementedError
251
+
252
+
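+ # Yield every parseable JSON object embedded in free-form text by attempting a
+ # raw decode at each '{' and skipping spans that fail to parse.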
253
+ def extract_json_objects(text, decoder=JSONDecoder()):
254
+ pos = 0
255
+ while True:
256
+ match = text.find('{', pos)
257
+ if match == -1: break
258
+ try:
259
+ result, index = decoder.raw_decode(text[match:])
260
+ yield result
261
+ pos = match + index
262
+ except ValueError:
263
+ pos = match + 1
264
+
265
+
266
+ def get_gpu_memory():
267
+ import subprocess
268
+ try:
269
+ command = "nvidia-smi --query-gpu=memory.free --format=csv"
270
+ memory_free_info = subprocess.check_output(command.split()).decode('ascii').split('\n')[:-1][1:]
271
+ memory_free_values = [int(x.split()[0]) for i, x in enumerate(memory_free_info)]
272
+ return memory_free_values
273
+ except Exception as e:
274
+ print(f'{type(e)}: {str(e)}')
275
+ return []
276
+
277
+
278
+ def auto_split_flag():
279
+ flag = os.environ.get('AUTO_SPLIT', '0')
280
+ return flag == '1'
VLMEvalKit/vlmeval/vlm/aria.py ADDED
@@ -0,0 +1,206 @@
1
+ import torch
2
+ import warnings
3
+ import copy as cp
4
+ from PIL import Image
5
+ import pandas as pd
6
+ import string
7
+ import re
8
+ from .base import BaseModel
9
+ from ..smp import isimg, listinstr, cn_string
10
+ from ..dataset import DATASET_TYPE, DATASET_MODALITY
11
+
12
+
13
+ class Aria(BaseModel):
14
+
15
+ INSTALL_REQ = False
16
+ INTERLEAVE = True
17
+
18
+ def __init__(self, model_path='rhymes-ai/Aria', **kwargs):
19
+ from transformers import AutoModelForCausalLM, AutoProcessor
20
+ assert model_path is not None
21
+ self.model_path = model_path
22
+ processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
23
+ tokenizer = processor.tokenizer
24
+ tokenizer.padding_side = 'left'
25
+ tokenizer.pad_token_id = tokenizer.unk_token_id
26
+ self.processor = processor
27
+ self.tokenizer = tokenizer
28
+ self.model = AutoModelForCausalLM.from_pretrained(
29
+ model_path,
30
+ device_map='cuda',
31
+ torch_dtype=torch.bfloat16,
32
+ trust_remote_code=True
33
+ ).eval()
34
+ default_kwargs = dict(
35
+ do_sample=False,
36
+ num_beams=1,
37
+ max_new_tokens=512,
38
+ min_new_tokens=1,
39
+ num_return_sequences=1,
40
+ use_cache=True,
41
+ output_hidden_states=True,
42
+ pad_token_id=tokenizer.unk_token_id,
43
+ stop_strings=["<|im_end|>"],
44
+ tokenizer=processor.tokenizer,
45
+ )
46
+ default_kwargs.update(kwargs)
47
+ self.kwargs = default_kwargs
48
+ warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')
49
+ torch.cuda.empty_cache()
50
+
51
+ def use_custom_prompt(self, dataset):
52
+ assert dataset is not None
53
+ if listinstr(['MMDU', 'MME-RealWorld', 'MME-RealWorld-CN'], dataset):
54
+ # For Multi-Turn we don't have custom prompt
55
+ return False
56
+ if DATASET_MODALITY(dataset) == 'VIDEO':
57
+ # For Video benchmarks we don't have custom prompt at here
58
+ return False
59
+ else:
60
+ return True
61
+
62
+ def build_prompt(self, line, dataset=None):
63
+ assert self.use_custom_prompt(dataset)
64
+ assert dataset is None or isinstance(dataset, str)
65
+ tgt_path = self.dump_image(line, dataset)
66
+
67
+ question = line['question']
68
+ hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
69
+ if hint is not None:
70
+ question = hint + '\n' + question
71
+
72
+ options = {
73
+ cand: line[cand]
74
+ for cand in string.ascii_uppercase
75
+ if cand in line and not pd.isna(line[cand])
76
+ }
77
+ for key, item in options.items():
78
+ question += f'\n{key}. {item}'
79
+ prompt = question
80
+
81
+ if len(options):
82
+ prompt += (
83
+ "\nAnswer with the option's letter from the given choices directly."
84
+ )
85
+ else:
86
+ if listinstr(['MathVista', 'MathVision', 'VCR', 'MTVQA', 'MMVet', 'MathVerse'], dataset):
87
+ prompt = prompt
88
+ elif listinstr(['LLaVABench', 'MMBench-Video'], dataset):
89
+ prompt += '\nAnswer this question in detail.'
90
+ elif listinstr(['DocVQA'], dataset):
91
+ prompt += '\nAnswer briefly and directly.'
92
+ else:
93
+ prompt += '\nAnswer the question using a single word or phrase.'
94
+
95
+ message = [dict(type='image', value=s) for s in tgt_path]
96
+ message.append(dict(type='text', value=prompt))
97
+ return message
98
+
99
+ def build_video_prompt(self, prompt, dataset=None):
100
+ if listinstr(['MMBench-Video'], dataset):
101
+ prompt = prompt.replace('\nAnswer:', '')
102
+ prompt = prompt.replace(
103
+ 'Question: ',
104
+ 'Please carefully check the video and then answer the following question with details:'
105
+ )
106
+ elif listinstr(['Video-MME'], dataset):
107
+ prompt = prompt.replace('\nAnswer:', '')
108
+ prompt += "\nAnswer with the option's letter from the given choices directly."
109
+ elif listinstr(['MVBench'], dataset):
110
+ prompt = prompt.replace('Best option:(', '')
111
+ system_prompt = 'Carefully watch the video and pay attention to the cause and sequence of events, the detail and movement of objects, and the action and pose of persons. Based on your observations, select the best option that accurately addresses the question.\n' # noqa: E501
112
+ prompt = prompt.replace(system_prompt, '')
113
+
114
+ return prompt
115
+
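+ # Per-dataset generation overrides: greedy decoding everywhere, smaller images
+ # for video inputs, longer outputs for CoT-prone benchmarks (MMMU / MMStar / Math),
+ # shorter outputs for MCQ / caption / VQA, and image splitting for OCR-heavy sets.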
116
+ def adjust_kwargs(self, dataset):
117
+ kwargs = cp.deepcopy(self.kwargs)
118
+ kwargs["temperature"] = 0.0
119
+ kwargs["do_sample"] = False
120
+
121
+ if DATASET_MODALITY(dataset) == "VIDEO":
122
+ kwargs["max_image_size"] = 490
123
+ else:
124
+ kwargs["max_image_size"] = 980
125
+
126
+ kwargs["split_image"] = False
127
+
128
+ if listinstr(['MMMU', 'MMStar', 'Math'], dataset):
129
+ # These datasets may lead the model to work as a CoT-alike behaviour.
130
+ # Allow to output longer.
131
+ kwargs['max_new_tokens'] = 512
132
+ return kwargs
133
+ if DATASET_TYPE(dataset) in ['MCQ', 'Y/N']:
134
+ kwargs['max_new_tokens'] = 64
135
+ elif DATASET_TYPE(dataset) == 'Caption' and 'COCO' in dataset:
136
+ kwargs['max_new_tokens'] = 64
137
+ elif DATASET_TYPE(dataset) == 'VQA':
138
+ if listinstr(['OCRVQA', 'ChartQA', 'DocVQA'], dataset):
139
+ kwargs['max_new_tokens'] = 128
140
+ elif listinstr(['TextVQA'], dataset):
141
+ kwargs['max_new_tokens'] = 32
142
+
143
+ if listinstr(['OCR', 'ChartQA', 'DocVQA', 'InfoVQA', 'TextVQA'], dataset):
144
+ # OCR-related datasets that need to split image
145
+ kwargs["split_image"] = True
146
+
147
+ return kwargs
148
+
149
+ def generate_inner(self, message, dataset=None):
150
+ if dataset is not None:
151
+ kwargs = self.adjust_kwargs(dataset)
152
+ else:
153
+ kwargs = self.kwargs
154
+
155
+ max_image_size = kwargs.pop("max_image_size")
156
+ split_image = kwargs.pop("split_image")
157
+
158
+ prompt = '<|im_start|>user\n'
159
+ images = []
160
+ last_message_modality = "text"
161
+
162
+ if listinstr(['MLVU', 'TempCompass', 'MVBench'], dataset): # re-arrange the data
163
+ new_message = []
164
+ for s in message:
165
+ if s['type'] == 'image':
166
+ new_message.append(s)
167
+ for s in message:
168
+ if s['type'] == 'text':
169
+ new_message.append(s)
170
+ message = new_message
171
+
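+ # Interleave image placeholders and text into Aria's chat format; "<image N>"
+ # markers are stripped from the text since images are passed to the processor
+ # separately.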
172
+ for s in message:
173
+ if s['type'] == 'image':
174
+ prompt += '<fim_prefix><|img|><fim_suffix>'
175
+ images.append(s['value'])
176
+ last_message_modality = "image"
177
+ elif s['type'] == 'text':
178
+ text = re.sub(r"<image \d+>", "", s["value"])
179
+ if last_message_modality == "image":
180
+ prompt += "\n"
181
+ last_message_modality = "text"
182
+ prompt += text
183
+
184
+ if DATASET_MODALITY(dataset) == 'VIDEO':
185
+ prompt = self.build_video_prompt(prompt, dataset)
186
+
187
+ prompt += '<|im_end|>\n<|im_start|>assistant\n'
188
+ if images:
189
+ images = [Image.open(s).convert('RGB') for s in images]
190
+ encoded = self.processor(
191
+ text=prompt,
192
+ images=images,
193
+ return_tensors='pt',
194
+ padding='longest',
195
+ max_image_size=max_image_size,
196
+ split_image=split_image,
197
+ )
198
+ else:
199
+ encoded = self.processor(text=prompt, return_tensors='pt', padding='longest')
200
+ encoded["pixel_values"] = encoded["pixel_values"].to(self.model.dtype)
201
+ encoded = {k: v.to(self.model.device) for k, v in encoded.items()}
202
+
203
+ pred = self.model.generate(**encoded, **kwargs)
204
+ answer = self.tokenizer.decode(pred[0][encoded['input_ids'].size(1):].cpu(), skip_special_tokens=True).strip()
205
+ answer = answer.replace('<|im_end|>', '')
206
+ return answer