RoBERTa for Multi-label Classification of Policy Instruments
This model fine-tunes roberta-large for multi-label classification of policy instruments, target groups, and themes.
Model Details
- Base model: roberta-large
- Max length: 512
- Output: 67 multi-label classes grouped under three main categories (PI: Policy Instrument, TG: Target Group, TH: Theme), each of which has further sub-categories.
- Threshold: 0.25 (a class is predicted when its sigmoid probability exceeds this value)
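The thresholding step can be illustrated without loading the model. The following is a minimal sketch, assuming each class is scored independently with a sigmoid; the logits and the helper name `labels_above_threshold` are hypothetical and chosen only for illustration.

```python
import math


def sigmoid(x: float) -> float:
    """Map a raw logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def labels_above_threshold(logits, class_names, threshold=0.25):
    """Return every class whose sigmoid probability exceeds the threshold."""
    return [
        name
        for logit, name in zip(logits, class_names)
        if sigmoid(logit) > threshold
    ]


# Hypothetical logits and label codes, for illustration only.
example_logits = [2.0, -3.0, -1.0, 0.5]
example_classes = ["PI008", "TG20", "TG21", "TH31"]

selected = labels_above_threshold(example_logits, example_classes)
print(selected)
```

In this sketch the logit -1.0 corresponds to a probability of about 0.27, so the label passes at 0.25 but would be rejected at the more common cut-off of 0.5. A lower threshold therefore tends to return more labels per document.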
Intended Use
Classify policy documents or government program descriptions into thematic categories.
How to Use
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import joblib
import requests

# Load the fine-tuned model and tokenizer from the Hugging Face Hub
model_path = "toqeerehsan/multilabel-indicator-classification-roberta-l"
model = AutoModelForSequenceClassification.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
model.eval()

# Download the label binarizer; mlb.classes_ maps output indices to label codes
mlb_url = "https://huggingface.co/toqeerehsan/multilabel-indicator-classification-roberta-l/resolve/main/mlb.pkl"
mlb_path = "mlb.pkl"
response = requests.get(mlb_url)
response.raise_for_status()  # stop here if the download failed
with open(mlb_path, "wb") as f:
    f.write(response.content)
mlb = joblib.load(mlb_path)

text = "This program supports clean technology and sustainable development in industries."
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)

# Run inference without tracking gradients
with torch.no_grad():
    logits = model(**inputs).logits
# One independent sigmoid probability per class; drop the batch dimension
probs = torch.sigmoid(logits).squeeze(0).numpy()

# Threshold
binary_preds = (probs > 0.25).astype(int)
predicted_labels = [label for i, label in enumerate(mlb.classes_) if binary_preds[i] == 1]
print("Predicted Labels:", predicted_labels)
# Predicted Labels: ['PI008', 'TG20', 'TG21', 'TG22', 'TG25', 'TG29', 'TH31', 'TH92']
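Because the predicted codes span the three main categories, it can be convenient to group them before further processing. The sketch below assumes that the first two characters of each code (PI, TG, TH) identify its main category, which matches the example output above but is not otherwise documented; the helper name `group_by_category` is hypothetical.

```python
from collections import defaultdict


def group_by_category(labels):
    """Group label codes by their two-letter prefix (assumed to be PI, TG, or TH)."""
    grouped = defaultdict(list)
    for label in labels:
        grouped[label[:2]].append(label)
    return dict(grouped)


# Labels taken from the example output above.
predicted_labels = ['PI008', 'TG20', 'TG21', 'TG22', 'TG25', 'TG29', 'TH31', 'TH92']

grouped = group_by_category(predicted_labels)
print(grouped)
# {'PI': ['PI008'], 'TG': ['TG20', 'TG21', 'TG22', 'TG25', 'TG29'], 'TH': ['TH31', 'TH92']}
```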