File size: 495 Bytes
9d2cd90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import json
import numpy as np
from typing import List, Union


def input_fn(input_data, content_type):
    """Decode a JSON request body and extract its ``inputs`` field.

    Args:
        input_data: Raw request body (``str`` or ``bytes``) holding a JSON object.
        content_type: Declared content type of the request; not consulted here.

    Returns:
        Whatever value is stored under the ``inputs`` key of the decoded object.

    Raises:
        json.JSONDecodeError: If the body is not valid JSON.
        KeyError: If the decoded object has no ``inputs`` key.
    """
    return json.loads(input_data)['inputs']


def predict_fn(data: Union[List[str], str], model):
    """Compute one mean-pooled embedding per input text.

    Args:
        data: A single text or a list of texts.
        model: Callable invoked as ``model(texts, padding=False, truncation=True)``.
            Presumably a feature-extraction pipeline that, for a list input,
            returns one nested list per text shaped (1, seq_len, hidden) —
            TODO confirm against the deployed model.

    Returns:
        A list of embeddings (each a list of floats), one per input text.
        A single string input yields a one-element list; an empty list
        yields an empty list.
    """
    # Always hand the model a batch. For a bare string, feature-extraction
    # style models typically omit the per-text outer list in their output.
    # The loop below would then walk token rows instead of texts and average
    # across the hidden axis, producing scalars rather than embedding vectors.
    texts = [data] if isinstance(data, str) else data
    if not texts:
        return []
    outputs = model(texts, padding=False, truncation=True)
    # r[0] drops the leading axis of size 1; mean(axis=0) then averages the
    # token vectors (mean pooling over sequence positions).
    return [np.array(r[0]).mean(axis=0).tolist() for r in outputs]


def output_fn(prediction, accept):
    """Serialize predictions into a JSON response body.

    Args:
        prediction: JSON-serializable prediction payload.
        accept: Requested response content type; not consulted here, because
            the body is always JSON.

    Returns:
        A JSON string of the form ``{"outputs": <prediction>}``.
    """
    response_body = {"outputs": prediction}
    return json.dumps(response_body)