Lukas Helff
commited on
Commit
·
58596cd
1
Parent(s):
03d8e50
extract rule from NL
Browse files- VerifiableRewardsForScalableLogicalReasoning.py +50 -12
- app.py +1 -1
VerifiableRewardsForScalableLogicalReasoning.py
CHANGED
@@ -100,7 +100,7 @@ Returns:
|
|
100 |
"""
|
101 |
|
102 |
|
103 |
-
def _evaluate_with_prolog(
|
104 |
"""
|
105 |
Evaluates a predicted rule against the validation program using Prolog.
|
106 |
"""
|
@@ -108,6 +108,7 @@ def _evaluate_with_prolog(rule_to_evaluate, validation_program, eval_config, tim
|
|
108 |
positive_pred = eval_config.get("positive_predicate", "eastbound")
|
109 |
negative_pred = eval_config.get("negative_predicate", "westbound")
|
110 |
# extract predicate from rule_to_evaluate
|
|
|
111 |
if positive_pred not in rule_to_evaluate:
|
112 |
logger.warning(f"Rule '{rule_to_evaluate}' does not contain positive predicate '{positive_pred}'")
|
113 |
return {
|
@@ -137,8 +138,11 @@ check({vars}) :- neg({vars}), \\+ {positive_pred}({vars}). % negative rejected
|
|
137 |
|
138 |
% Count successful checks
|
139 |
check_count(Count) :-
|
140 |
-
setof(({vars}), ((pos({vars}); neg({vars})), check({vars})), CorrectExamples)
|
141 |
-
|
|
|
|
|
|
|
142 |
|
143 |
check_all :- forall((pos({vars});neg({vars})), check({vars})).
|
144 |
"""
|
@@ -165,12 +169,13 @@ check_all :- forall((pos({vars});neg({vars})), check({vars})).
|
|
165 |
timeout=timeout,
|
166 |
text=True
|
167 |
)
|
|
|
168 |
# Extract partial score from output
|
169 |
-
partial_score =
|
170 |
|
171 |
is_correct = True if partial_score == 1.0 else False
|
172 |
|
173 |
-
error = result.stderr if result.stderr else None
|
174 |
t1 = time.time()
|
175 |
|
176 |
return {
|
@@ -186,13 +191,50 @@ check_all :- forall((pos({vars});neg({vars})), check({vars})).
|
|
186 |
return {"is_correct": False, "partial_score": 0.0, "syntax_valid": False,
|
187 |
"error": f"Evaluation timed out after {timeout} seconds"}
|
188 |
except Exception as e:
|
189 |
-
logger.warning(f"Error evaluating rule '{rule_to_evaluate}' returns: '{result.stdout.strip() if result else 'No error message'}'")
|
190 |
-
|
191 |
-
|
192 |
finally:
|
193 |
if os.path.exists(temp_file):
|
194 |
os.remove(temp_file)
|
195 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
196 |
|
197 |
|
198 |
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
@@ -270,10 +312,6 @@ class VerifiableRewardsForScalableLogicalReasoning(evaluate.Metric):
|
|
270 |
if not validation_program:
|
271 |
raise ValueError(f"Example {i} does not contain validation program field")
|
272 |
|
273 |
-
# Make sure the prediction is a proper rule format
|
274 |
-
if not prediction.strip().endswith('.'):
|
275 |
-
prediction = prediction.strip() + '.'
|
276 |
-
|
277 |
eval_inputs.append((prediction, validation_program, eval_config))
|
278 |
|
279 |
# Process evaluations in parallel
|
|
|
100 |
"""
|
101 |
|
102 |
|
103 |
+
def _evaluate_with_prolog(prediction, validation_program, eval_config, timeout=5):
|
104 |
"""
|
105 |
Evaluates a predicted rule against the validation program using Prolog.
|
106 |
"""
|
|
|
108 |
positive_pred = eval_config.get("positive_predicate", "eastbound")
|
109 |
negative_pred = eval_config.get("negative_predicate", "westbound")
|
110 |
# extract predicate from rule_to_evaluate
|
111 |
+
rule_to_evaluate = extract_ilp_from_text_v2(prediction)
|
112 |
if positive_pred not in rule_to_evaluate:
|
113 |
logger.warning(f"Rule '{rule_to_evaluate}' does not contain positive predicate '{positive_pred}'")
|
114 |
return {
|
|
|
138 |
|
139 |
% Count successful checks
|
140 |
check_count(Count) :-
|
141 |
+
(setof(({vars}), ((pos({vars}); neg({vars})), check({vars})), CorrectExamples) ->
|
142 |
+
length(CorrectExamples, Count)
|
143 |
+
;
|
144 |
+
Count = 0
|
145 |
+
).
|
146 |
|
147 |
check_all :- forall((pos({vars});neg({vars})), check({vars})).
|
148 |
"""
|
|
|
169 |
timeout=timeout,
|
170 |
text=True
|
171 |
)
|
172 |
+
partial_score = 0.0 if result.stdout.strip() == '' else int(result.stdout.strip())
|
173 |
# Extract partial score from output
|
174 |
+
partial_score = partial_score / pos_negs if pos_negs > 0 else 0.0
|
175 |
|
176 |
is_correct = True if partial_score == 1.0 else False
|
177 |
|
178 |
+
error = f'Rule invalid "{rule_to_evaluate}" with' + result.stderr if result.stderr else None
|
179 |
t1 = time.time()
|
180 |
|
181 |
return {
|
|
|
191 |
return {"is_correct": False, "partial_score": 0.0, "syntax_valid": False,
|
192 |
"error": f"Evaluation timed out after {timeout} seconds"}
|
193 |
except Exception as e:
|
194 |
+
logger.warning(f"Error evaluating rule '{rule_to_evaluate}' returns: '{result.stdout.strip() if result else 'No error message'}' with error: {e}")
|
195 |
+
return {"is_correct": False, "partial_score": 0.0, "syntax_valid": False,
|
196 |
+
"error": f"Syntactically invalid rule '{rule_to_evaluate}'"}
|
197 |
finally:
|
198 |
if os.path.exists(temp_file):
|
199 |
os.remove(temp_file)
|
200 |
|
201 |
+
def extract_ilp_from_text(text):
|
202 |
+
rule_patterns = [
|
203 |
+
# Pattern with body (full rule with implication)
|
204 |
+
r'([a-zA-Z_][a-zA-Z0-9_]*\([^)]*\)\s*:-[^.]*\.)',
|
205 |
+
# Pattern for facts (no body)
|
206 |
+
# r'([a-zA-Z_][a-zA-Z0-9_]*\([^)]*\)\s*\.)'
|
207 |
+
]
|
208 |
+
p_code = ''
|
209 |
+
for pattern in rule_patterns:
|
210 |
+
matches = re.findall(pattern, text)
|
211 |
+
for match in matches:
|
212 |
+
# Ensure the rule ends with a period
|
213 |
+
statement = match.strip()
|
214 |
+
if not statement.endswith('.'):
|
215 |
+
statement += '.'
|
216 |
+
p_code += statement + '\n'
|
217 |
+
return p_code
|
218 |
+
|
219 |
+
|
220 |
+
def extract_ilp_from_text_v2(text, target_predicates=None):
|
221 |
+
# Pre-process: collapse code blocks to single lines
|
222 |
+
text = re.sub(r'\n\s*', ' ', text) # crude: flatten all to one line
|
223 |
+
# Optionally restrict to only some predicates
|
224 |
+
preds = '|'.join([re.escape(p) for p in (target_predicates or [])])
|
225 |
+
head_pat = rf"(?:{preds})" if preds else r"[a-zA-Z_][a-zA-Z0-9_]*"
|
226 |
+
# Rule pattern, across newlines
|
227 |
+
rule_pattern = re.compile(rf'({head_pat}\([^()]*\)\s*:-.*?\.)')
|
228 |
+
rules = set(rule_pattern.findall(text))
|
229 |
+
# Remove rules that are also captured as facts
|
230 |
+
p_code = ''
|
231 |
+
for rule in rules:
|
232 |
+
# Ensure the rule ends with a period
|
233 |
+
statement = rule.strip()
|
234 |
+
if not statement.endswith('.'):
|
235 |
+
statement += '.'
|
236 |
+
p_code += statement + '\n'
|
237 |
+
return p_code.strip() # Ensure no trailing whitespace
|
238 |
|
239 |
|
240 |
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
|
|
312 |
if not validation_program:
|
313 |
raise ValueError(f"Example {i} does not contain validation program field")
|
314 |
|
|
|
|
|
|
|
|
|
315 |
eval_inputs.append((prediction, validation_program, eval_config))
|
316 |
|
317 |
# Process evaluations in parallel
|
app.py
CHANGED
@@ -269,5 +269,5 @@ Evaluations performed by the symbolic judge are fully verifiable and grounded in
|
|
269 |
return demo
|
270 |
|
271 |
# Use a local path instead of a module name
|
272 |
-
module = evaluate.load("
|
273 |
create_interface(module).launch()
|
|
|
269 |
return demo
|
270 |
|
271 |
# Use a local path instead of a module name
|
272 |
+
module = evaluate.load("./VerifiableRewardsForScalableLogicalReasoning.py")
|
273 |
create_interface(module).launch()
|