Upload ONNX weights (#1) - opened by Xenova (HF staff)
import torch
from sentence_transformers import SentenceTransformer
class WrappedModel(torch.nn.Module):
    """Expose an embedding layer through a padded ``(input_ids, attention_mask)`` API.

    The wrapped ``embedding`` module is called with a flat 1-D tensor of
    token ids plus one start offset per sequence, so the padded batch is
    flattened here before the call.
    """

    def __init__(self, m):
        """Keep only the embedding layer of the first element of ``m``.

        Args:
            m: Indexable container whose first element has an ``embedding``
                attribute. NOTE(review): presumably an
                ``torch.nn.EmbeddingBag``-style module that accepts
                ``(indices, offsets)`` - confirm against the wrapped model.
        """
        super().__init__()
        self.embedding = m[0].embedding

    def forward(self, input_ids, attention_mask):
        """Embed each row using only the positions whose mask equals 1.

        Args:
            input_ids: 2-D tensor of token ids, one row per sequence.
            attention_mask: Tensor of the same shape as ``input_ids``;
                positions equal to 1 are kept, everything else is dropped.

        Returns:
            The result of ``self.embedding(indices, offsets)``.
        """
        # Boolean indexing flattens the kept ids row by row into one 1-D tensor.
        indices = input_ids[attention_mask == 1]

        # Start of each row inside ``indices`` = number of kept tokens in all
        # previous rows (exclusive prefix sum). Deriving it from ``lengths``
        # keeps the offsets on the same device as the inputs and yields an
        # empty tensor for an empty batch, unlike concatenating a CPU
        # ``torch.tensor([0])`` constant.
        lengths = attention_mask.sum(dim=1)
        offsets = lengths.cumsum(dim=0) - lengths
        return self.embedding(indices, offsets)
# Dummy batch handed to the exporter: 3 rows of 4 positions each.
# Every position whose attention_mask entry is 0 holds a filler id (-1 or 0).
shape = (3, 4)
input_ids = torch.tensor(
    [1, 2, 3, 4,
     5, 6, -1, -1,
     1, 1, 1, 0]
).view(shape)
attention_mask = torch.tensor(
    [1, 1, 1, 1,
     1, 1, 0, 0,
     1, 1, 1, 0]
).view(shape)

# Load the sentence-transformers model and wrap it so that the exported
# graph takes (input_ids, attention_mask) as its two inputs.
model_id = "sentence-transformers/static-similarity-mrl-multilingual-v1"
model = SentenceTransformer(model_id)
wrapped = WrappedModel(model)

# Both inputs may vary in batch size and sequence length; the output only
# varies along the batch dimension.
padded_axes = {0: 'batch_size', 1: 'sequence_length'}
dynamic_axes = {
    'input_ids': dict(padded_axes),
    'attention_mask': dict(padded_axes),
    'sentence_embedding': {0: 'batch_size'},
}

# Export the wrapped model to "model.onnx" in the working directory.
torch.onnx.export(
    wrapped,
    (input_ids, attention_mask),
    "model.onnx",
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=['input_ids', 'attention_mask'],
    output_names=['sentence_embedding'],
    dynamic_axes=dynamic_axes,
)
tomaarsen changed pull request status to merged