|
|
from transformers import Pipeline |
|
|
|
|
|
|
|
|
class MultitaskTokenClassificationPipeline(Pipeline):
    """Pipeline that feeds the raw input text directly to ``self.model``.

    No tokenization or decoding happens here: ``preprocess`` and
    ``postprocess`` pass their argument through unchanged, so the pipeline's
    result is whatever ``self.model`` returns when called with the input.
    """

    def _sanitize_parameters(self, **kwargs):
        """Split keyword arguments into per-stage parameter dictionaries.

        Only ``text`` is recognised; it is routed to ``preprocess``. The
        forward and postprocess stages take no extra parameters.

        :param kwargs: keyword arguments supplied to the pipeline.
        :return: ``(preprocess_kwargs, forward_kwargs, postprocess_kwargs)``.
        """
        preprocess_kwargs = {}
        # NOTE(review): ``preprocess`` already takes ``text`` as its first
        # positional parameter. If the base class invokes
        # ``self.preprocess(inputs, **preprocess_kwargs)`` (presumably the
        # case in current transformers -- confirm), forwarding a ``text``
        # keyword here raises ``TypeError`` (multiple values for 'text').
        # Kept as-is until the intended meaning of this option is confirmed.
        if "text" in kwargs:
            preprocess_kwargs["text"] = kwargs["text"]
        return preprocess_kwargs, {}, {}

    def preprocess(self, text, **kwargs):
        """Return the input unchanged; the model is called on raw text.

        :param text: the pipeline input.
        :param kwargs: preprocessing options; currently ignored.
        :return: ``text`` as-is.
        """
        return text

    def _forward(self, text):
        """Run the wrapped model on the preprocessed input.

        :param text: the value returned by ``preprocess`` (the raw input).
        :return: the model's output for ``text``.
        """
        # Return the model's result (not the input) so that ``postprocess``
        # receives the model outputs it is documented to handle.
        return self.model(text)

    def postprocess(self, text, **kwargs):
        """Postprocess the outputs of the model.

        Currently a pass-through: the model outputs are returned unchanged.
        The parameter keeps its historical name ``text`` for compatibility,
        although it holds the value returned by ``_forward``.

        :param text: the model outputs returned by ``_forward``.
        :param kwargs: postprocessing options; currently ignored.
        :return: ``text`` unchanged.
        """
        return text
|
|
|