tuandunghcmut's picture
Add files using upload-large-folder tool
c85c9e5 verified
raw
history blame
16.1 kB
import re
import json
import sympy as sp
import numpy as np
from sympy import simplify, Eq, sympify, Pow, pi
from sympy.parsing.latex import parse_latex
import sys
import math
import os
import argparse
from .image_base import ImageBaseDataset
from ..utils import track_progress_rich
from ..smp import load, dump
class AutoScoringJudge:
    r"""Rule-based judge deciding whether a predicted answer matches a reference.

    Both answers are normalised by :meth:`preprocess` (``\boxed{...}`` content
    is extracted, unit spellings and decorative LaTeX are removed), split into
    comma-separated items, and the items are paired up by :meth:`is_equal`,
    which tries in order: string equality, interval comparison, numeric
    comparison, symbolic comparison and equation comparison.

    The class relies on the module-level imports ``re``, ``math`` and the
    sympy names ``sp``, ``simplify``, ``Eq``, ``sympify``, ``Pow`` and
    ``parse_latex``.
    """

    # Literal substrings rewritten, in this order, while normalising an answer.
    # Full-width punctuation keys are spelled with escapes so they cannot be
    # silently folded into their ASCII look-alikes (which would turn the
    # entries into no-ops).
    # Disabled entries kept for reference: "∶" -> ":", "\\approx" -> "=",
    # "\\simeq" -> "=", "\\sim" -> "=", "^\\prime" -> "'", "^{\\prime}" -> "'",
    # "%" -> "".
    SPECIAL_SIGNAL_MAP = {
        "\\left": "",
        "\\right": "",
        "厘米": "",
        "\uff0c": ",",  # full-width comma
        "$": "",
        "\uff08": "(",  # full-width left parenthesis
        "\uff09": ")",  # full-width right parenthesis
        "\\infty": "oo",
        "\\colon ": ":",
        "\uff0b": "+",  # full-width plus sign
        "\\, ": "",
        "\\,": "",
        "^\\circ": "",
        "^{\\circ}": "",
    }

    # Unit spellings removed from answers, applied in this order. Squared
    # forms come before the bare forms: removing "cm" first would leave a
    # dangling "^2" that later squares the number itself.
    _UNIT_TOKENS = (
        "\\text{cm}^2",
        "\\text{cm}^{2}",
        "\\text{ cm}^2",
        "\\text{ cm}^{2}",
        "cm^{2}",
        "cm^2",
        "\\text{cm}",
        "\\,cm",
        "\\text{ cm}",
        "cm",
        "\\text{分米}^2",
        "\\ \\text{m}",
        "\\text{米}",
    )

    # One or more characters in the range U+4E00..U+9FFF.
    _CJK_RUN = re.compile(r"[\u4e00-\u9fff]+")

    def __init__(self):
        # Per-instance copy so edits to one judge's table do not leak to others.
        self.special_signal_map = dict(self.SPECIAL_SIGNAL_MAP)
        # Whatever parse_latex produces for "\pi"; sympy_sub_pi substitutes
        # math.pi for this object.
        self.pi = parse_latex("\\pi")
        # MM-Math default precision
        self.precision = 1e-2

    def trans_greater_sign_to_interval(self, expr: str):
        """Turn a chained inequality such as ``a<x<b`` into the string ``(a, b)``.

        Only the text before the first ``<`` and after the last ``<`` is kept.
        ``<=`` is not treated specially.
        """
        parts = expr.split("<")
        return "(" + parts[0] + ", " + parts[-1] + ")"

    def split_by_comma(self, expr: str):
        """Split ``expr`` on commas that are not inside ``()`` or ``[]``.

        Each piece is whitespace-stripped. An empty trailing piece is dropped,
        so an empty input yields an empty list.
        """
        depth = 0
        pieces = []
        piece_start = 0
        for pos, char in enumerate(expr):
            if char in ("(", "["):
                depth += 1
            elif char in (")", "]"):
                depth -= 1
            elif char == "," and depth == 0:
                pieces.append(expr[piece_start:pos].strip())
                piece_start = pos + 1
        if piece_start < len(expr):
            pieces.append(expr[piece_start:].strip())
        return pieces

    def trans_plus_minus_sign(self, expr_list: list):
        r"""Expand every item containing ``\pm`` into a ``+`` and a ``-`` variant.

        All ``\pm`` occurrences inside one item are replaced by the same sign.
        Items without ``\pm`` are passed through unchanged.
        """
        expanded = []
        for expr in expr_list:
            if "\\pm" in expr:
                expanded.append(expr.replace("\\pm", "+"))
                expanded.append(expr.replace("\\pm", "-"))
            else:
                expanded.append(expr)
        return expanded

    @classmethod
    def _drop_cjk(cls, text):
        """Remove CJK characters unless the whole string consists of them.

        Purely-CJK strings are kept intact so that they can still be compared
        as plain text.
        """
        if cls._CJK_RUN.fullmatch(text):
            return text
        return cls._CJK_RUN.sub('', text)

    def judge(self, expression1, expression2, precision=1e-2):
        """Return True when ``expression2`` matches ``expression1`` (the ground truth).

        Args:
            expression1: Reference answer text.
            expression2: Predicted answer text.
            precision: Allowed absolute error, either a single number applied
                to every item or a list with one entry per reference item.
                A list passed by the caller is never modified.

        Returns:
            bool: True if every reference item can be paired with a distinct
            matching prediction item, False otherwise (including when
            preprocessing fails).

        Raises:
            ValueError: If ``precision`` is a list with fewer entries than
                there are reference items.

        Note:
            ``self.precision`` is overwritten with the tolerance of the item
            currently being compared and keeps the last value used.
        """
        # Work on a copy: entries are removed as items get matched.
        tolerances = list(precision) if isinstance(precision, list) else [precision]

        try:
            expression1, expression2 = self.preprocess(expression1, expression2)
        except Exception:
            # Unparseable input (e.g. unbalanced \boxed braces, non-string
            # values) counts as a wrong answer.
            return False
        if expression1 == expression2:
            return True

        expression1 = self._drop_cjk(expression1)
        expression2 = self._drop_cjk(expression2)

        # "a < x < b" style answers are compared as intervals.
        if self.is_two_greater_sign(expression1):
            expression1 = self.trans_greater_sign_to_interval(expression1)
        if self.is_two_greater_sign(expression2):
            expression2 = self.trans_greater_sign_to_interval(expression2)

        references = self.trans_plus_minus_sign(self.split_by_comma(expression1))
        predictions = self.trans_plus_minus_sign(self.split_by_comma(expression2))

        # A single tolerance applies to every reference item.
        if len(tolerances) <= 1:
            tolerances = tolerances * len(references)
        if len(references) != len(predictions):
            return False
        if len(tolerances) < len(references):
            raise ValueError(
                f"precision has {len(tolerances)} entries but the reference "
                f"answer has {len(references)} items"
            )

        # Pair every reference item with the first prediction item it matches.
        # Matched entries are removed by position so that duplicates and
        # per-item tolerances stay aligned.
        idx = -1
        while references:
            idx = (idx + 1) % len(references)
            reference = references[idx]
            self.precision = tolerances[idx]
            match_idx = next(
                (pos for pos, candidate in enumerate(predictions)
                 if self.is_equal(reference, candidate)),
                None,
            )
            if match_idx is None:
                return False
            del references[idx]
            del tolerances[idx]
            del predictions[match_idx]
        return True

    def is_interval(self, expr):
        """Return True if ``expr`` starts with ``(``/``[`` and ends with ``)``/``]``."""
        return expr.startswith(("(", "[")) and expr.endswith((")", "]"))

    def is_two_greater_sign(self, expr):
        """Return True if ``expr`` contains exactly two ``<`` characters."""
        return expr.count("<") == 2

    def sympy_sub_pi(self, expression_sympy):
        """Substitute the numeric value ``math.pi`` for the parsed pi object."""
        return expression_sympy.subs(self.pi, math.pi)

    def is_equal(self, expression1, expression2):
        """Compare two single items; ``expression1`` is the ground truth.

        Strategies are tried from cheapest to most expensive. Each strategy
        may fail to parse its input; such failures are treated as "this
        strategy does not apply" rather than as errors.
        """
        if expression1 == expression2 and expression1 != "":
            return True

        # Interval comparison. A failure here ends the comparison.
        if self.is_interval(expression1) and self.is_interval(expression2):
            try:
                if self.interval_equal(expression1, expression2):
                    return True
            except Exception:
                return False

        # Numeric comparison (both sides must be float literals).
        try:
            if self.numerical_equal(expression1, expression2):
                return True
        except Exception:
            pass

        # Symbolic comparison, skipped when both sides are equations.
        both_equations = "=" in expression1 and "=" in expression2
        if not both_equations:
            try:
                if self.expression_equal(expression1, expression2):
                    return True
            except Exception:
                pass

        # Equation comparison.
        try:
            if self.equation_equal(expression1, expression2):
                return True
        except Exception:
            pass

        return False

    def numerical_equal(self, expression1: str, expression2: str, include_percentage: bool = True):
        """Compare two float literals within ``self.precision`` (plus 1% slack).

        With ``include_percentage`` the prediction may also match the
        reference divided or multiplied by 100.

        Raises:
            ValueError: If either argument is not a valid float literal.
        """
        reference = float(expression1)
        prediction = float(expression2)
        if include_percentage:
            candidates = [reference / 100, reference, reference * 100]
        else:
            candidates = [reference]
        allowed = self.precision * 1.01
        return any(abs(candidate - prediction) <= allowed for candidate in candidates)

    def expression_equal(self, exp1, exp2):
        """Check with sympy whether two LaTeX expressions are equivalent.

        For input containing ``=`` only the text between the first and second
        ``=`` is compared. Symbol-free expressions are compared numerically
        within ``self.precision``; expressions with symbols are compared by
        simplifying their difference.

        Returns:
            bool: True when the expressions are considered equivalent.

        Raises:
            Exception: Whatever ``parse_latex``/``sympify`` raise for input
                they cannot handle.
        """
        def extract_expression(expression):
            if "=" in expression:
                expression = expression.split("=")[1]
            return expression.strip()

        exp1 = extract_expression(exp1)
        exp2 = extract_expression(exp2)
        exp_too_long = len(exp1) > 300 or len(exp2) > 300

        expr1_sym = sympify(parse_latex(exp1))
        expr2_sym = sympify(parse_latex(exp2))
        if expr1_sym == expr2_sym:
            return True

        expr1_sym = self.sympy_sub_pi(expr1_sym)
        expr2_sym = self.sympy_sub_pi(expr2_sym)
        has_symbol1 = expr1_sym.has(sp.Symbol)
        has_symbol2 = expr2_sym.has(sp.Symbol)

        # A symbolic expression never equals a purely numeric one.
        if has_symbol1 != has_symbol2:
            return False

        if not has_symbol1:
            # Both sides are numeric: compare their values.
            try:
                if not (self.can_compute_power(expr1_sym) and self.can_compute_power(expr2_sym)):
                    print("These two numbers cannot be calculated by the current computer for: "
                          f"\"{str(expr1_sym)}\" and \"{str(expr2_sym)}\"")
                    return False
                if exp_too_long:
                    print(f'Expression {exp1} or {exp2} is too long to compute. ')
                    return False
                # bool() stays inside the try: an undecidable comparison
                # raises TypeError and counts as "not equal".
                return bool(abs(expr1_sym.evalf() - expr2_sym.evalf()) <= self.precision * 1.01)
            except Exception:
                return False

        if exp_too_long:
            print(f'Expression {exp1} or {exp2} is too long to compute. ')
            return False

        try:
            residual = simplify(expr1_sym - expr2_sym)
            return bool(abs(residual.evalf()) < 1e-3)
        except Exception:
            return False

    def equation_equal(self, expression1, expression2):
        """Check whether two LaTeX equations are equivalent.

        Each equation is reduced to ``lhs - rhs``. The equations are accepted
        when one reduced side is a non-zero integer multiple of the other.

        Raises:
            ValueError: If an argument does not contain exactly one ``=``.
        """
        def simplify_equation(latex_eq):
            lhs, rhs = latex_eq.split('=')
            equation = Eq(parse_latex(lhs), parse_latex(rhs))
            return simplify(equation.lhs - equation.rhs)

        expr1_sym = simplify_equation(expression1)
        expr2_sym = simplify_equation(expression2)

        ratio = simplify(expr1_sym / expr2_sym)
        inverse_ratio = simplify(expr2_sym / expr1_sym)
        return bool(
            (ratio.is_Integer and ratio != 0)
            or (inverse_ratio.is_Integer and inverse_ratio != 0)
        )

    def interval_equal(self, expression1, expression2):
        r"""Check whether two intervals, or ``\cup`` unions of intervals, match.

        Unions are compared piece by piece in the order written. Two pieces
        match when their bracket types agree, they have the same number of
        endpoints, and every endpoint pair satisfies :meth:`expression_equal`.
        """
        def compare_two_interval(inter1, inter2):
            # Pieces of a union usually carry whitespace around "\cup"; the
            # bracket characters must be read from the trimmed text.
            inter1 = inter1.strip()
            inter2 = inter2.strip()
            if not inter1 or not inter2:
                return False
            if inter1[0] != inter2[0] or inter1[-1] != inter2[-1]:
                return False
            items_1 = inter1.strip('[]()').split(',')
            items_2 = inter2.strip('[]()').split(',')
            # zip() alone would silently ignore surplus endpoints.
            if len(items_1) != len(items_2):
                return False
            return all(
                self.expression_equal(item_1, item_2)
                for item_1, item_2 in zip(items_1, items_2)
            )

        if expression1 == expression2:
            return True
        pieces1 = expression1.split("\\cup")
        pieces2 = expression2.split("\\cup")
        if len(pieces1) != len(pieces2):
            return False
        return all(
            compare_two_interval(piece1, piece2)
            for piece1, piece2 in zip(pieces1, pieces2)
        )

    def preprocess(self, expression1, expression2):
        r"""Extract the answer text from both inputs and normalise it.

        Extraction takes the content of every ``\boxed{...}``. Without any,
        it takes every ``$...$`` span on the last line, and without those the
        whole input. Normalisation removes unit spellings, applies
        ``self.special_signal_map``, rewrites ``\sin30`` style degree
        arguments into radians, trims punctuation at both ends and unwraps
        ``\mathrm{}``/``\mathbf{}``.

        Returns:
            tuple: ``(normalised expression1, normalised expression2)``.

        Raises:
            ValueError: If a ``\boxed{`` is never closed.
        """
        def extract_boxed_content(latex_str):
            contents = []
            for match in re.finditer(r'\\boxed{', latex_str):
                start_index = match.end()
                end_index = start_index
                depth = 1
                # Walk forward to the brace that closes this \boxed{.
                while depth > 0 and end_index < len(latex_str):
                    if latex_str[end_index] == '{':
                        depth += 1
                    elif latex_str[end_index] == '}':
                        depth -= 1
                    end_index += 1
                if depth != 0:
                    raise ValueError("Mismatched braces in LaTeX string.")
                contents.append(latex_str[start_index:end_index - 1])

            if not contents:
                last_line_ans = latex_str.strip().split("\n")[-1]
                contents = re.findall(r"\$(.*?)\$", last_line_ans)
            if not contents:
                return latex_str
            # Every extracted answer is followed by a comma; the final one is
            # trimmed during symbol replacement.
            return "".join(content + "," for content in contents)

        def special_symbol_replace(expression):
            for unit in self._UNIT_TOKENS:
                expression = expression.replace(unit, "")
            expression = expression.strip()
            # Drop a trailing "m" that follows at least one other character.
            expression = re.sub(r"(.+)m$", r"\1", expression)
            if "\\in " in expression:
                expression = expression.split("\\in ")[1]
            for signal, replacement in self.special_signal_map.items():
                expression = expression.replace(signal, replacement)
            # "\sin30" -> "\sin((30/180)\pi)": degrees to radians.
            expression = re.sub(r'(\\sin|\\cos|\\tan)(\d+)', r'\1((\2/180)\\pi)', expression)
            expression = expression.strip("\n,.:;^_=+`!@#%^&*~,。")
            expression = re.sub(r'\\(?:mathrm|mathbf)\{~?([^}]*)\}', r'\1', expression)
            return expression

        exp1 = special_symbol_replace(extract_boxed_content(expression1))
        exp2 = special_symbol_replace(extract_boxed_content(expression2))
        return exp1, exp2

    def can_compute_power(self, expr):
        """Return False for a top-level power that should not be evaluated.

        Only the outermost node of ``expr`` is inspected. A power is rejected
        when its base or exponent is not a number, or when the magnitude of
        the exponent exceeds 1000. Anything that is not a power is accepted.
        """
        if not isinstance(expr, Pow):
            return True  # Not a power expression, can compute
        base, exp = expr.as_base_exp()
        if not (base.is_number and exp.is_number):
            return False
        MAX_EXP = 1000  # Adjust based on computing environment
        if abs(exp.evalf()) > MAX_EXP:
            return False
        return True
class MMMath(ImageBaseDataset):
    """MM-Math dataset whose predictions are scored with ``AutoScoringJudge``."""

    TYPE = 'VQA'

    DATASET_URL = {
        'MM-Math': 'https://opencompass.openxlab.space/utils/VLMEval/MM-Math.tsv',
    }
    DATASET_MD5 = {
        'MM-Math': '1f064ed7c4e0e8926a3fa65849419ca5',
    }

    @classmethod
    def evaluate(cls, eval_file, **kwargs):
        """Judge every prediction in ``eval_file`` and write accuracy scores.

        The table loaded from ``eval_file`` must provide the columns
        ``answer``, ``prediction``, ``difficulty``, ``year``, ``knowledge_l1``
        and ``knowledge_l2``. A ``hit`` column holding the per-sample verdict
        is added and the table is written back to ``eval_file``. The scores
        are dumped next to it as ``<eval_file without extension>_score.json``.

        Args:
            eval_file: Path of the prediction file.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            dict: Mean of ``hit`` overall (key ``overall``) and per distinct
            value of each breakdown column (keys ``Difficulty-*``, ``Year-*``,
            ``Knowledge-L1-*``, ``Knowledge-L2-*``).
        """
        data = load(eval_file)
        judger = AutoScoringJudge()
        tups = [
            dict(expression1=answer, expression2=prediction)
            for answer, prediction in zip(data['answer'], data['prediction'])
        ]
        data['hit'] = track_progress_rich(judger.judge, tups, nproc=16)
        dump(data, eval_file)

        # Derive the score path from the file stem. A plain
        # str.replace('.xlsx', ...) returns the input path unchanged for any
        # other extension, and the score dump would then overwrite the
        # prediction file.
        score_file = os.path.splitext(eval_file)[0] + '_score.json'

        score = {'overall': np.mean(data['hit'])}
        breakdowns = (
            ('Difficulty', 'difficulty'),
            ('Year', 'year'),
            ('Knowledge-L1', 'knowledge_l1'),
            ('Knowledge-L2', 'knowledge_l2'),
        )
        for label, column in breakdowns:
            for value in set(data[column]):
                score[f'{label}-{value}'] = np.mean(data[data[column] == value]['hit'])

        dump(score, score_file)
        return score