| print("Loading Multi head pipeline") | |
| from transformers.pipelines import PIPELINE_REGISTRY | |
| from transformers import TextClassificationPipeline, AutoTokenizer, AutoModelForSequenceClassification | |
| class CustomTextClassificationPipeline(TextClassificationPipeline): | |
| def __init__(self, model, tokenizer=None, **kwargs): | |
| if tokenizer is None: | |
| tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path) | |
| super().__init__(model=model, tokenizer=tokenizer, **kwargs) | |
| def _sanitize_parameters(self, **kwargs): | |
| preprocess_kwargs = {} | |
| return preprocess_kwargs, {}, {} | |
| def preprocess(self, inputs): | |
| return self.tokenizer(inputs, return_tensors='pt', truncation=True, padding=True) | |
| def _forward(self, model_inputs): | |
| input_ids = model_inputs['input_ids'] | |
| attention_mask = (input_ids != 0).long() | |
| outputs = self.model(input_ids=input_ids, attention_mask=attention_mask) | |
| return outputs | |
| def postprocess(self, model_outputs): | |
| predictions = model_outputs.logits.argmax(dim=-1).squeeze().tolist() | |
| categories = ["Race/Origin", "Gender/Sex", "Religion", "Ability", "Violence", "Other"] | |
| return dict(zip(categories, predictions)) | |
| PIPELINE_REGISTRY.register_pipeline( | |
| "multi-head-text-classification", | |
| pipeline_class=CustomTextClassificationPipeline, | |
| pt_model=AutoModelForSequenceClassification, | |
| ) |
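

# Usage sketch (illustrative, not from the original): once the task is registered,
# it can be instantiated through the standard `pipeline` factory. The checkpoint
# name below is a placeholder for a fine-tuned multi-head sequence-classification model.
from transformers import pipeline

multi_head_classifier = pipeline(
    "multi-head-text-classification",
    model="path/to/multi-head-checkpoint",  # hypothetical checkpoint path or Hub id
)
print(multi_head_classifier("example text to score"))
# e.g. {"Race/Origin": 0, "Gender/Sex": 0, "Religion": 0, "Ability": 0, "Violence": 1, "Other": 0}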