File size: 495 Bytes
9d2cd90 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
import json
import numpy as np
from typing import List, Union
def input_fn(input_data, content_type):
    """Decode a JSON request body and return the value under its ``"inputs"`` key.

    Args:
        input_data: JSON document (anything ``json.loads`` accepts).
        content_type: Accepted for interface compatibility; not inspected.

    Returns:
        Whatever the decoded document stores under ``"inputs"``.

    Raises:
        json.JSONDecodeError: If ``input_data`` is not valid JSON.
        KeyError: If the decoded object has no ``"inputs"`` key.
    """
    return json.loads(input_data)['inputs']
def predict_fn(data: Union[List[str], str], model):
    """Embed each input text by mean-pooling the vectors returned by ``model``.

    Args:
        data: A single text or a list of texts.
        model: Callable invoked as ``model(texts, padding=False, truncation=True)``.
            Presumably a Hugging Face feature-extraction pipeline whose result
            holds, per text, a ``[1, tokens, dim]`` nested list -- TODO confirm
            against the code that loads the model.

    Returns:
        A list with one embedding (a list of floats) per input text. A single
        string yields a one-element list.
    """
    # Wrap a bare string so the model output always has one entry per text.
    # Without this, the pooling below would index into the token axis of the
    # lone result and average a single token vector down to one scalar.
    texts = [data] if isinstance(data, str) else data
    outputs = model(texts, padding=False, truncation=True)
    # result[0] drops the leading singleton axis; the mean over axis 0 then
    # averages across the remaining rows (assumed to be tokens).
    return [np.array(result[0]).mean(axis=0).tolist() for result in outputs]
def output_fn(prediction, accept):
    """Serialize the prediction as a JSON object with a single ``"outputs"`` key.

    Args:
        prediction: JSON-serializable value to return to the caller.
        accept: Accepted for interface compatibility; not inspected.

    Returns:
        A JSON string of the form ``{"outputs": <prediction>}``.
    """
    response = {"outputs": prediction}
    return json.dumps(response)
|