tlogandesigns committed on
Commit
0823edb
·
1 Parent(s): 5a92002

ML_POSITIVE_LABELS

Browse files
Files changed (1) hide show
  1. checker.py +45 -10
checker.py CHANGED
@@ -1,6 +1,3 @@
1
- """
2
- checker.py — core logic for Image + Text Compliance Check
3
- """
4
  from __future__ import annotations
5
  from pathlib import Path
6
  from typing import List, Optional, Dict, Any, Iterable, Union, Tuple
@@ -40,6 +37,12 @@ USE_TINY_ML = os.getenv("USE_TINY_ML", "1") == "1"
40
  HF_REPO = os.getenv("HF_REPO", "tlogandesigns/fairhousing-bert-tiny")
41
  HF_THRESH = float(os.getenv("HF_THRESH", "0.75"))
42
 
 
 
 
 
 
 
43
  BASE_DIR = Path(__file__).parent
44
  PHRASES_PATH = Path(os.getenv("PHRASES_PATH", str(BASE_DIR / "phrases.yaml")))
45
 
@@ -93,14 +96,12 @@ def contains_disclaimer(text: str, disclaimer: str) -> bool:
93
 
94
  return squeeze(disclaimer) in squeeze(text)
95
 
96
-
97
  @dataclass
98
  class Rule:
99
  regex: re.Pattern
100
  category: str
101
  suggests: list[str]
102
 
103
-
104
  PHRASE_RULES: list[Rule] = []
105
  PHRASES_ERROR: Optional[str] = None
106
 
@@ -173,6 +174,33 @@ if USE_TINY_ML:
173
  )
174
 
175
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  def fair_housing_flags(text: str) -> List[str]:
177
  flags: List[str] = []
178
  t = (text or "")[:1500]
@@ -185,11 +213,9 @@ def fair_housing_flags(text: str) -> List[str]:
185
  flags.append(rule.category)
186
  if hf_pipe:
187
  try:
188
- pred = hf_pipe(t)
189
- lbl = pred[0]["label"]
190
- score = float(pred[0]["score"])
191
- if (lbl in ("1", "LABEL_1", "violation", "POSITIVE")) and score >= HF_THRESH:
192
- flags.append(f"MLFlag: model={HF_REPO} label={lbl} score={score:.2f}")
193
  except Exception as e:
194
  flags.append(f"MLFlag: inference error: {e}")
195
  return flags
@@ -353,6 +379,13 @@ def run_check(
353
  )
354
  img_findings, img_spans = find_rule_matches(itxt)
355
  ptxt_findings, ptxt_spans = find_rule_matches(ptxt)
 
 
 
 
 
 
 
356
  results = {
357
  "Fair_Housing": fair_housing_block,
358
  "img": img_block,
@@ -371,6 +404,8 @@ def run_check(
371
  "OCR": pytesseract is not None,
372
  "Categories": sorted({r.category for r in PHRASE_RULES}),
373
  "DisclaimerRequiredOnNonSocial": REQUIRE_DISCLAIMER_ON_NON_SOCIAL,
 
 
374
  },
375
  }
376
  send_email_notification(results)
 
 
 
 
1
  from __future__ import annotations
2
  from pathlib import Path
3
  from typing import List, Optional, Dict, Any, Iterable, Union, Tuple
 
37
  HF_REPO = os.getenv("HF_REPO", "tlogandesigns/fairhousing-bert-tiny")
38
  HF_THRESH = float(os.getenv("HF_THRESH", "0.75"))
39
 
40
# Lower-cased classifier labels that count as a positive ("violation")
# prediction. Read from the comma-separated ML_POSITIVE_LABELS environment
# variable; surrounding whitespace is trimmed and empty entries are dropped.
ML_POSITIVE_LABELS = {
    label.lower()
    for label in map(
        str.strip,
        os.getenv(
            "ML_POSITIVE_LABELS",
            "Potential Violation,violation,positive,LABEL_1,1",
        ).split(","),
    )
    if label
}
45
+
46
  BASE_DIR = Path(__file__).parent
47
  PHRASES_PATH = Path(os.getenv("PHRASES_PATH", str(BASE_DIR / "phrases.yaml")))
48
 
 
96
 
97
  return squeeze(disclaimer) in squeeze(text)
98
 
 
99
@dataclass
class Rule:
    """A single phrase rule: a compiled pattern plus the metadata reported with it."""

    # Compiled regular expression for this rule.
    regex: re.Pattern
    # Category name; collected into the flag list and the "Categories" summary.
    category: str
    # Strings attached to the rule -- presumably suggested replacement
    # wording; confirm against the phrases loader.
    suggests: list[str]
104
 
 
105
  PHRASE_RULES: list[Rule] = []
106
  PHRASES_ERROR: Optional[str] = None
107
 
 
174
  )
175
 
176
 
177
def _violation_score(pipe, text: str) -> float:
    """Return the score ``pipe`` assigns to the "violation" class for *text*.

    ``pipe`` is first called with ``return_all_scores=True``. If it rejects
    that keyword with ``TypeError`` it is called again with the text alone
    and its first prediction is treated as a top-1 result.

    With a full score set, the positive class is resolved in this order:

    1. labels listed in ``ML_POSITIVE_LABELS`` (highest score if several
       are present);
    2. the complement of a ``"non-violation"`` label;
    3. the highest-scoring label whose name contains a risk keyword and is
       not negated (``"non_toxic"`` must not count as ``"toxic"``);
    4. the highest score overall.

    With a top-1 result, the score is returned unchanged when its label is
    in ``ML_POSITIVE_LABELS``. Otherwise ``1.0 - score`` is returned, so a
    confident negative prediction is never reported as a violation. For
    scores that sum to 1 this is an upper bound on the positive class.

    Args:
        pipe: Callable classifier returning dicts with ``"label"`` and
            ``"score"`` keys, as a flat list or nested one level deeper.
        text: Text to classify.

    Returns:
        The resolved violation score, or ``0.0`` when ``pipe`` returns no
        usable predictions.

    Raises:
        Exception: Any error raised by ``pipe`` other than the ``TypeError``
            handled on the first call is propagated, so the caller can
            report the inference failure instead of reading it as
            "no violation".
    """
    risk_tokens = ("violat", "posit", "flag", "risk", "unsafe", "toxic")
    negation_prefixes = ("non", "not", "no ", "no_", "no-")

    try:
        raw = pipe(text, return_all_scores=True)
        top_only = False
    except TypeError:
        # The callable does not accept the keyword: only a top-1 result is available.
        raw = pipe(text)
        top_only = True

    # Accept both [[{...}, ...]] and [{...}, ...] (and a bare dict).
    entries = raw
    while (
        isinstance(entries, (list, tuple))
        and entries
        and isinstance(entries[0], (list, tuple))
    ):
        entries = entries[0]
    if isinstance(entries, dict):
        entries = [entries]
    if not isinstance(entries, (list, tuple)):
        return 0.0

    pairs = [
        (str(item["label"]).lower(), float(item["score"]))
        for item in entries
        if isinstance(item, dict) and "label" in item and "score" in item
    ]
    if not pairs:
        return 0.0

    if top_only:
        label, score = pairs[0]
        if label in ML_POSITIVE_LABELS:
            return score
        # Top prediction is not a positive label: report the complement.
        return min(1.0, max(0.0, 1.0 - score))

    scores = dict(pairs)

    configured = [scores[name] for name in ML_POSITIVE_LABELS if name in scores]
    if configured:
        return max(configured)

    if "non-violation" in scores:
        return 1.0 - scores["non-violation"]

    hinted = [
        value
        for name, value in scores.items()
        if any(tok in name for tok in risk_tokens)
        and not name.startswith(negation_prefixes)
    ]
    if hinted:
        return max(hinted)

    return max(scores.values())
202
+
203
+
204
  def fair_housing_flags(text: str) -> List[str]:
205
  flags: List[str] = []
206
  t = (text or "")[:1500]
 
213
  flags.append(rule.category)
214
  if hf_pipe:
215
  try:
216
+ score = _violation_score(hf_pipe, t)
217
+ if score >= HF_THRESH:
218
+ flags.append(f"MLFlag: model={HF_REPO} score={score:.2f}")
 
 
219
  except Exception as e:
220
  flags.append(f"MLFlag: inference error: {e}")
221
  return flags
 
379
  )
380
  img_findings, img_spans = find_rule_matches(itxt)
381
  ptxt_findings, ptxt_spans = find_rule_matches(ptxt)
382
+ model_labels = []
383
+ try:
384
+ if hf_pipe is not None and hasattr(hf_pipe, "model") and hasattr(hf_pipe.model, "config"):
385
+ labels_map = getattr(hf_pipe.model.config, "id2label", {}) or {}
386
+ model_labels = list(labels_map.values())
387
+ except Exception:
388
+ model_labels = []
389
  results = {
390
  "Fair_Housing": fair_housing_block,
391
  "img": img_block,
 
404
  "OCR": pytesseract is not None,
405
  "Categories": sorted({r.category for r in PHRASE_RULES}),
406
  "DisclaimerRequiredOnNonSocial": REQUIRE_DISCLAIMER_ON_NON_SOCIAL,
407
+ "ModelLabels": model_labels,
408
+ "MLPositiveLabels": sorted(list(ML_POSITIVE_LABELS)),
409
  },
410
  }
411
  send_email_notification(results)