Create pipeline.py (#2)
Browse files- Create pipeline.py (7ac259c19ecc70c2aad0575c7f2ff667f2e95fac)
Co-authored-by: yrlee <[email protected]>
- pipeline.py +55 -0
pipeline.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re

import torch
from transformers import Pipeline
|
| 2 |
+
|
| 3 |
+
class MyPipeline(Pipeline):
    """Question-generation pipeline built on a seq2seq ``generate`` call.

    The input passage is stripped of unsupported characters, prefixed with
    the prompt ``"질문 생성: <unused0>"`` ("Question generation: ..."),
    tokenized, and fed to ``self.model.generate``. ``make_question``
    additionally slides a word window over a long text and generates one
    question per window.
    """

    # Generation options that callers may pass through the pipeline call.
    _GENERATION_KEYS = ("max_length", "num_beams")

    def _sanitize_parameters(self, **kwargs):
        """Route the supported generation options to ``preprocess``.

        Returns:
            A ``(preprocess_kwargs, forward_kwargs, postprocess_kwargs)``
            tuple; only the first dict is ever populated.
        """
        preprocess_kwargs = {
            key: kwargs[key] for key in self._GENERATION_KEYS if key in kwargs
        }
        return preprocess_kwargs, {}, {}

    def preprocess(self, inputs, **kwargs):
        """Clean and tokenize a passage and attach the generation options.

        Args:
            inputs: The passage text.
            **kwargs: May contain ``max_length`` and/or ``num_beams``; any
                other keys are ignored.

        Returns:
            A dict with ``"inputs"`` (a ``1 x seq_len`` tensor of token ids)
            plus whichever of ``max_length`` / ``num_beams`` were supplied.
        """
        # Keep only ASCII letters, Hangul syllables, digits, space and the
        # punctuation ",<>:&#"; every other character is dropped.
        cleaned = re.sub(r'[^A-Za-z가-힣,<>0-9:&# ]', '', inputs)
        prompt = "질문 생성: <unused0>" + cleaned

        tokenizer = self.tokenizer
        input_ids = (
            [tokenizer.bos_token_id]
            + tokenizer.encode(prompt)
            + [tokenizer.eos_token_id]
        )

        model_inputs = {"inputs": torch.tensor([input_ids])}
        # Forward only the options that were actually given so that a call
        # without them falls back to the model's own generation settings
        # instead of raising KeyError.
        for key in self._GENERATION_KEYS:
            if key in kwargs:
                model_inputs[key] = kwargs[key]
        return model_inputs

    def _forward(self, model_inputs):
        """Run ``generate`` on the prepared inputs.

        Returns:
            A dict whose ``"logits"`` entry holds the generated token ids
            (the key name is kept for compatibility with ``postprocess``;
            the values are ids, not logits).
        """
        generate_kwargs = {
            key: model_inputs[key]
            for key in self._GENERATION_KEYS
            if key in model_inputs
        }
        res_ids = self.model.generate(
            model_inputs["inputs"],
            eos_token_id=self.tokenizer.eos_token_id,
            # Forbid the unknown token from appearing in the output.
            bad_words_ids=[[self.tokenizer.unk_token_id]],
            **generate_kwargs,
        )
        return {"logits": res_ids}

    def postprocess(self, model_outputs):
        """Decode the first generated sequence and strip ``<s>`` / ``</s>``."""
        decoded = self.tokenizer.batch_decode(model_outputs["logits"].tolist())[0]
        return decoded.replace('<s>', '').replace('</s>', '')

    def _inference(self, paragraph, **kwargs):
        """Run preprocess -> forward -> postprocess on a single passage."""
        model_inputs = self.preprocess(paragraph, **kwargs)
        model_outputs = self._forward(model_inputs)
        return self.postprocess(model_outputs)

    def make_question(self, text, **kwargs):
        """Generate one question per sliding word window of ``text``.

        Args:
            text: The full text; it is split on single spaces.
            **kwargs: Must contain ``frame_size`` (words per window) and
                ``hop_length`` (words between window starts). All kwargs,
                including ``max_length`` / ``num_beams`` if present, are
                passed on to ``_inference``.

        Returns:
            A list of generated questions, one per window.

        Raises:
            KeyError: If ``frame_size`` or ``hop_length`` is missing.
        """
        words = text.split(" ")
        frame_size = kwargs['frame_size']
        hop_length = kwargs['hop_length']

        # NOTE: when the text has fewer words than ``frame_size`` this count
        # can be zero or negative, in which case no question is generated.
        steps = round((len(words) - frame_size) / hop_length) + 1

        outs = []
        for step in range(steps):
            start = step * hop_length
            # Slicing past the end of a list just truncates, so the final
            # window is naturally shorter and no error handling is needed.
            script = " ".join(words[start:start + frame_size])
            outs.append(self._inference(script, **kwargs))
        return outs
|