rmtariq commited on
Commit
6832b36
Β·
verified Β·
1 Parent(s): 216f971

Upload classifier.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. classifier.py +365 -0
classifier.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Path to the locally fine-tuned model
5
+ LOCAL_MODEL_PATH = "./models/finetuned_classification"
6
+
7
+ # Hugging Face model name (fallback)
8
+ MODEL_NAME = "rmtariq/malay_classification"
9
+
10
+ # Categories from the new dataset
11
+ CATEGORIES = ["Politik", "Perpaduan", "Keluarga", "Belia", "Perumahan", "Internet", "Pengguna", "Makanan", "Pekerjaan", "Pengangkutan", "Sukan", "Ekonomi", "Hiburan", "Jenayah", "Alam Sekitar", "Teknologi", "Pendidikan", "Agama", "Sosial", "Kesihatan", "Halal"]
12
+
13
+ """
14
+ Claim Classifier
15
+ ---------------
16
+
17
+ Classifies claims based on priority index data, sentiment analysis, and content patterns.
18
+ Also provides functions for classifying claims into categories using a fine-tuned model.
19
+ """
20
+
21
+ import json
22
+ import os
23
+ import re
24
+ import torch
25
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
26
+
27
+
28
+ def classify_specific_claims(claim):
29
+ """
30
+ Classify specific claims that the model might not handle correctly.
31
+
32
+ Args:
33
+ claim (str): The claim text to classify
34
+
35
+ Returns:
36
+ tuple: (category, confidence) or (None, None) if not a specific claim
37
+ """
38
+ claim_lower = claim.lower()
39
+
40
+ # Specific claim patterns and their categories
41
+ specific_claims = [
42
+ {
43
+ "pattern": r"ketua polis|kpn|tan sri razarudin|saman|ugutan",
44
+ "category": "Jenayah",
45
+ "confidence": 0.95
46
+ },
47
+ {
48
+ "pattern": r"zakat fitrah|zakat|beras|dimakan",
49
+ "category": "Agama",
50
+ "confidence": 0.95
51
+ },
52
+ {
53
+ "pattern": r"kerajaan.+cukai|cukai.+minyak sawit|minyak sawit mentah",
54
+ "category": "Ekonomi",
55
+ "confidence": 0.95
56
+ },
57
+ {
58
+ "pattern": r"kanta lekap|dijual.+dalam talian|online",
59
+ "category": "Pengguna",
60
+ "confidence": 0.95
61
+ },
62
+ {
63
+ "pattern": r"kelongsong|peluru|dijajah|musuh",
64
+ "category": "Politik",
65
+ "confidence": 0.95
66
+ }
67
+ ]
68
+
69
+ # Check if the claim matches any of the specific patterns
70
+ for specific_claim in specific_claims:
71
+ if re.search(specific_claim["pattern"], claim_lower):
72
+ return specific_claim["category"], specific_claim["confidence"]
73
+
74
+ # If no match, return None
75
+ return None, None
76
+ def load_model():
77
+ """
78
+ Load the classification model and tokenizer.
79
+ First tries to load from local path, then falls back to Hugging Face.
80
+ """
81
+ try:
82
+ # Try to load from local path first
83
+ if os.path.exists(LOCAL_MODEL_PATH):
84
+ print(f"Loading model from local path: {LOCAL_MODEL_PATH}")
85
+ tokenizer = AutoTokenizer.from_pretrained(LOCAL_MODEL_PATH)
86
+ model = AutoModelForSequenceClassification.from_pretrained(LOCAL_MODEL_PATH)
87
+ return model, tokenizer
88
+ else:
89
+ # Fall back to Hugging Face
90
+ print(f"Local model not found. Loading from Hugging Face: {MODEL_NAME}")
91
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
92
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
93
+ return model, tokenizer
94
+ except Exception as e:
95
+ print(f"Error loading model: {str(e)}")
96
+ # Fall back to bert-base-multilingual-cased if all else fails
97
+ print("Falling back to bert-base-multilingual-cased")
98
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
99
+ model = AutoModelForSequenceClassification.from_pretrained(
100
+ "bert-base-multilingual-cased",
101
+ num_labels=len(CATEGORIES)
102
+ )
103
+ return model, tokenizer
104
+
105
+
106
+ def classify_claim(claim, model=None, tokenizer=None):
107
+ """
108
+ Classify a claim into one of the categories.
109
+
110
+ Args:
111
+ claim (str): The claim text to classify
112
+ model: Optional pre-loaded model
113
+ tokenizer: Optional pre-loaded tokenizer
114
+
115
+ Returns:
116
+ tuple: (category, confidence)
117
+ """
118
+ # First check if it's a specific claim
119
+ category, confidence = classify_specific_claims(claim)
120
+ if category is not None:
121
+ return category, confidence
122
+
123
+ # If not a specific claim, use the model
124
+ if model is None or tokenizer is None:
125
+ model, tokenizer = load_model()
126
+
127
+ # Prepare the input
128
+ inputs = tokenizer(claim, return_tensors="pt", truncation=True, max_length=128)
129
+
130
+ # Get the prediction
131
+ with torch.no_grad():
132
+ outputs = model(**inputs)
133
+
134
+ # Get the predicted class
135
+ logits = outputs.logits
136
+ predicted_class_id = logits.argmax().item()
137
+
138
+ # Get the confidence score
139
+ probabilities = torch.nn.functional.softmax(logits, dim=1)[0]
140
+ confidence = probabilities[predicted_class_id].item()
141
+
142
+ # Map to category
143
+ try:
144
+ # Try to use the model's id2label mapping
145
+ if hasattr(model.config, 'id2label'):
146
+ category = model.config.id2label[predicted_class_id]
147
+ else:
148
+ # Fall back to our CATEGORIES list
149
+ category = CATEGORIES[predicted_class_id]
150
+ except (IndexError, KeyError):
151
+ # If the predicted class ID is out of range, fall back to a default category
152
+ category = "Lain-lain"
153
+ confidence = 0.0
154
+
155
+ return category, confidence
156
+ def classify(priority_data):
157
+ """
158
+ Classify a claim based on priority data.
159
+
160
+ Args:
161
+ priority_data (dict): Dictionary containing priority flags and other data
162
+
163
+ Returns:
164
+ str: Classification verdict (TRUE, FALSE, PARTIALLY_TRUE, UNVERIFIED)
165
+ """
166
+ # Extract priority flags from the data
167
+ if isinstance(priority_data, dict):
168
+ if "priority_flags" in priority_data:
169
+ priority_flags = priority_data["priority_flags"]
170
+ else:
171
+ # Assume the dictionary itself contains the flags
172
+ priority_flags = priority_data
173
+ else:
174
+ raise ValueError("Input must be a dictionary containing priority flags.")
175
+
176
+ # Get sentiment counts if available
177
+ sentiment_counts = {}
178
+ if "sentiment_counts" in priority_data:
179
+ sentiment_counts = priority_data["sentiment_counts"]
180
+ # Convert keys to strings if they're not already
181
+ if any(not isinstance(k, str) for k in sentiment_counts.keys()):
182
+ sentiment_counts = {str(k): v for k, v in sentiment_counts.items()}
183
+
184
+ # Get priority score if available
185
+ priority_score = priority_data.get("priority_score", sum(priority_flags.values()))
186
+
187
+ # Get claim and keywords
188
+ claim = priority_data.get("claim", "").lower()
189
+ keywords = priority_data.get("keywords", [])
190
+ keywords_lower = [k.lower() for k in keywords]
191
+
192
+ # Check for specific claim patterns
193
+ is_azan_claim = any(word in claim for word in ["azan", "larang", "masjid", "pembesar suara"])
194
+ is_religious_claim = any(word in claim for word in ["islam", "agama", "masjid", "surau", "sembahyang", "solat", "zakat"])
195
+
196
+ # Check for economic impact
197
+ economic_related = priority_flags.get("economic_impact", 0) == 1
198
+
199
+ # Check for government involvement
200
+ government_related = priority_flags.get("affects_government", 0) == 1
201
+
202
+ # Check for law-related content
203
+ law_related = priority_flags.get("law_related", 0) == 1
204
+
205
+ # Check for confusion potential
206
+ causes_confusion = priority_flags.get("cause_confusion", 0) == 1
207
+
208
+ # Check for negative sentiment dominance
209
+ negative_dominant = False
210
+ if sentiment_counts:
211
+ pos = int(sentiment_counts.get("positive", sentiment_counts.get("1", 0)))
212
+ neg = int(sentiment_counts.get("negative", sentiment_counts.get("2", 0)))
213
+ neu = int(sentiment_counts.get("neutral", sentiment_counts.get("0", 0)))
214
+ negative_dominant = neg > pos and neg > neu
215
+
216
+ # Special case for azan claim (like the example provided)
217
+ if is_azan_claim and is_religious_claim and "larangan" in claim:
218
+ return "FALSE" # Claim about banning azan is false
219
+
220
+ # Determine verdict based on multiple factors
221
+ if priority_score >= 7.0 and negative_dominant and (government_related or law_related):
222
+ return "FALSE"
223
+ elif priority_score >= 5.0 and causes_confusion:
224
+ return "PARTIALLY_TRUE"
225
+ elif priority_score <= 3.0 and not negative_dominant:
226
+ return "TRUE"
227
+ elif economic_related and government_related:
228
+ # Special case for economic policies by government
229
+ if negative_dominant:
230
+ return "FALSE"
231
+ elif causes_confusion:
232
+ return "PARTIALLY_TRUE"
233
+ else:
234
+ return "TRUE"
235
+ else:
236
+ return "UNVERIFIED"
237
+
238
+ def get_verdict(priority_data):
239
+ """
240
+ Get verdict from priority data, which can be a file path or dictionary.
241
+
242
+ Args:
243
+ priority_data (str or dict): File path to JSON or dictionary with priority data
244
+
245
+ Returns:
246
+ str: Classification verdict
247
+ """
248
+ if isinstance(priority_data, str):
249
+ try:
250
+ if not os.path.exists(priority_data):
251
+ print(f"⚠️ Warning: File not found: {priority_data}")
252
+ return "UNVERIFIED"
253
+ try:
254
+ with open(priority_data, "r") as f:
255
+ priority_data = json.load(f)
256
+ except Exception as e:
257
+ print(f"⚠️ Error reading file: {e}")
258
+ return "UNVERIFIED"
259
+ except Exception as e:
260
+ print(f"⚠️ Error checking file existence: {e}")
261
+ return "UNVERIFIED"
262
+
263
+ if not isinstance(priority_data, dict):
264
+ print("⚠️ Warning: Input is not a dictionary")
265
+ return "UNVERIFIED"
266
+
267
+ return classify(priority_data)
268
+
269
+ def get_verdict_explanation(verdict):
270
+ """
271
+ Get a human-readable explanation for a verdict.
272
+
273
+ Args:
274
+ verdict (str): Classification verdict
275
+
276
+ Returns:
277
+ tuple: (explanation text, color)
278
+ """
279
+ if verdict == "TRUE":
280
+ return ("Claim appears to be factually accurate based on available data and sentiment analysis.", "#009933") # Green
281
+ elif verdict == "FALSE":
282
+ return ("Claim appears to be false based on available data and sentiment analysis.", "#FF0000") # Red
283
+ elif verdict == "PARTIALLY_TRUE":
284
+ return ("Claim contains a mix of accurate and inaccurate information based on available data.", "#FFCC00") # Amber
285
+ else: # UNVERIFIED
286
+ return ("Insufficient data to verify this claim. More information is needed.", "#0099CC") # Blue
287
+
288
+ # Example CLI usage:
289
+ if __name__ == "__main__":
290
+ import argparse
291
+
292
+ parser = argparse.ArgumentParser(description="Classify a claim based on priority data or category")
293
+ parser.add_argument("--json", help="Path to priority JSON file")
294
+ parser.add_argument("--claim-id", type=int, help="Claim ID to analyze")
295
+ parser.add_argument("--db", default="data/claims.db", help="Path to database file")
296
+ parser.add_argument("--claim", help="Claim text to classify into a category")
297
+ parser.add_argument("--category", action="store_true", help="Classify claim into a category")
298
+
299
+ args = parser.parse_args()
300
+
301
+ if args.category or args.claim:
302
+ # Use the new classification model
303
+ if not args.claim:
304
+ print("[❌] Error: --claim must be provided with --category")
305
+ exit(1)
306
+
307
+ print(f"[πŸ“₯] Classifying claim: {args.claim}")
308
+ category, confidence = classify_claim(args.claim)
309
+ print(f"[🏁] Category: {category}")
310
+ print(f"[πŸ“Š] Confidence: {confidence:.4f}")
311
+
312
+ elif args.json:
313
+ print(f"[πŸ“₯] Reading priority flags from: {args.json}")
314
+ verdict = get_verdict(args.json)
315
+ explanation, color = get_verdict_explanation(verdict)
316
+ print(f"[🏁] Final Verdict: {verdict}")
317
+ print(f"[πŸ“] Explanation: {explanation}")
318
+
319
+ elif args.claim_id:
320
+ try:
321
+ # Import only if needed
322
+ try:
323
+ from priority_indexer import calculate_priority_from_db
324
+ print(f"[πŸ“₯] Calculating priority for claim ID: {args.claim_id}")
325
+ priority_data = calculate_priority_from_db(args.claim_id, args.db)
326
+ if priority_data:
327
+ verdict = classify(priority_data)
328
+ else:
329
+ verdict = "UNVERIFIED"
330
+ except ImportError:
331
+ print("[⚠️] Warning: priority_indexer module not found")
332
+ verdict = "UNVERIFIED"
333
+
334
+ explanation, color = get_verdict_explanation(verdict)
335
+ print(f"[🏁] Final Verdict: {verdict}")
336
+ print(f"[πŸ“] Explanation: {explanation}")
337
+
338
+ except Exception as e:
339
+ print(f"[❌] Error: {e}")
340
+ verdict = "UNVERIFIED"
341
+ explanation, color = get_verdict_explanation(verdict)
342
+ print(f"[🏁] Final Verdict: {verdict}")
343
+ print(f"[πŸ“] Explanation: {explanation}")
344
+ else:
345
+ print("[❌] Error: Either --json, --claim-id, or --claim with --category must be provided")
346
+ exit(1)
347
+
348
+ # Test the classification model with sample claims
349
+ if args.category and not args.claim:
350
+ print("\n[πŸ§ͺ] Testing classification model with sample claims:")
351
+ test_claims = [
352
+ "Projek mega kerajaan penuh dengan ketirisan.",
353
+ "Harga barang keperluan naik setiap bulan.",
354
+ "Program vaksinasi tidak mencakupi golongan luar bandar.",
355
+ "Makanan di hotel lima bintang tidak jelas status halalnya."
356
+ ]
357
+
358
+ model, tokenizer = load_model()
359
+
360
+ for claim in test_claims:
361
+ category, confidence = classify_claim(claim, model, tokenizer)
362
+ print(f"Claim: {claim}")
363
+ print(f"Category: {category}")
364
+ print(f"Confidence: {confidence:.4f}")
365
+ print("-" * 50)