import platform
from typing import Sequence, Sized, Union

import torch
from hfendpoints import EndpointConfig, Handler, __version__
from hfendpoints.openai import Context, run
from hfendpoints.openai.embeddings import Embedding, EmbeddingEndpoint, EmbeddingRequest, EmbeddingResponse, Usage
from loguru import logger
from sentence_transformers import SentenceTransformer
from torch.backends.mkldnn import VERBOSE_OFF, VERBOSE_ON_CREATION

# Dtypes eligible for torch.amp.autocast on CPU.
SUPPORTED_AMP_DTYPES = {torch.float32, torch.bfloat16}


def has_bf16_support() -> bool:
    """
    Helper to detect whether the hardware supports bfloat16.

    Note:
        Intel libraries such as oneDNN emulate bfloat16 even when the underlying hardware does not support it
        natively. This means a CPU with AVX512 will work, although not with the performance one could expect
        from a CPU with AVX512_BF16. Also, AMX_BF16 is implicitly assumed to be available when AVX512_BF16 is
        (as is the case on Intel Sapphire Rapids).

    :return: True if the hardware supports (or can emulate) bfloat16, False otherwise
    """
    return torch.cpu._is_avx512_bf16_supported() or torch.cpu._is_avx512_supported()
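
# Note: the two checks above rely on private PyTorch helpers. On PyTorch >= 2.0,
# the public torch.backends.cpu.get_cpu_capability() reports the highest ISA
# level the build dispatches to and can serve as a sanity check (illustrative
# output, depends on the host CPU):
#
#     >>> torch.backends.cpu.get_cpu_capability()
#     'AVX512'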


def get_usage(tokens: Union[Sized, Sequence[Sized]], is_batched: bool) -> Usage:
    """
    Compute the number of processed tokens and return it as a Usage object matching the OpenAI specification.

    :param tokens: List, or nested list, of tokens
    :param is_batched: Flag indicating whether the original request contained batched inputs
    :return: Usage object matching the OpenAI specification
    """
    if is_batched:
        num_tokens = sum(len(document) for document in tokens)
    else:
        num_tokens = len(tokens)

    return Usage(prompt_tokens=num_tokens, total_tokens=num_tokens)
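
# Illustrative example (made-up token ids, assuming Usage exposes its fields
# as attributes):
#
#     >>> get_usage([[101, 7592, 102], [101, 102]], is_batched=True).total_tokens
#     5
#     >>> get_usage([101, 7592, 102], is_batched=False).prompt_tokens
#     3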


class SentenceTransformerHandler(Handler):
    __slots__ = ("_config", "_dtype", "_model", "_model_name", "_use_amp")

    def __init__(self, config: EndpointConfig):
        self._config = config
        self._dtype = torch.float32
        self._model_name = config.model_id

        self._allocate_model()

    def _allocate_model(self):
        # Prefer bfloat16 whenever the CPU supports (or can emulate) it.
        dtype = torch.bfloat16 if has_bf16_support() else torch.float32
        model = SentenceTransformer(self._config.model_id, device="cpu", model_kwargs={"torch_dtype": dtype})

        if platform.machine() == "x86_64":
            import intel_extension_for_pytorch as ipex
            logger.info(f"x64 platform detected: {platform.processor()}")

            with torch.inference_mode():
                model = model.eval()
                model = model.to(memory_format=torch.channels_last)
                model = ipex.optimize(model, dtype=dtype, weights_prepack=False, graph_mode=True, concat_linear=True)
                model = torch.compile(model, dynamic=True, backend="ipex")
        else:
            # Non-x86 platforms fall back to the default torch.compile backend.
            model = torch.compile(model)

        self._model = model
        self._dtype = dtype
        self._use_amp = dtype in SUPPORTED_AMP_DTYPES
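
        # Optional warm-up (sketch): a dummy forward pass here would trigger the
        # torch.compile / IPEX graph capture at startup rather than on the first
        # user request, e.g.:
        #
        #     with torch.inference_mode():
        #         self._model.encode("warmup")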

    async def __call__(self, request: EmbeddingRequest, ctx: Context) -> EmbeddingResponse:
        with torch.backends.mkldnn.verbose(VERBOSE_ON_CREATION if self._config.is_debug else VERBOSE_OFF):
            with torch.inference_mode(), torch.amp.autocast("cpu", dtype=self._dtype, enabled=self._use_amp):
                tokens = self._model.tokenize(request.input)
                vectors = self._model.encode(request.input)

                if not request.is_batched:
                    embeddings = [Embedding(index=0, embedding=vectors.tolist())]
                else:
                    embeddings = [
                        Embedding(index=index, embedding=embedding)
                        for index, embedding in enumerate(vectors.tolist())
                    ]

                # tokenize() returns a dict of batched tensors; count tokens from the
                # input ids rather than iterating over the dict's keys. Note the count
                # includes padding tokens for batched requests.
                token_ids = tokens["input_ids"]
                usage = get_usage(token_ids if request.is_batched else token_ids[0], request.is_batched)
                return EmbeddingResponse(model=self._model_name, embeddings=embeddings, usage=usage)
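
    # For reference, the serialized response follows the OpenAI embeddings
    # schema, roughly (illustrative values):
    #
    #     {"object": "list",
    #      "data": [{"object": "embedding", "index": 0, "embedding": [0.1, ...]}],
    #      "model": "<model-id>", "usage": {"prompt_tokens": 3, "total_tokens": 3}}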


def entrypoint():
    config = EndpointConfig.from_env()

    logger.info(f"[Hugging Face Endpoint v{__version__}] Serving: {config.model_id}")

    endpoint = EmbeddingEndpoint(SentenceTransformerHandler(config))
    run(endpoint, config.interface, config.port)
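
# Once running, the endpoint accepts OpenAI-style embedding requests. A sketch
# of a client call (illustrative; the exact interface, port and route depend on
# the EndpointConfig and hfendpoints deployment):
#
#     curl http://<interface>:<port>/v1/embeddings \
#       -H "Content-Type: application/json" \
#       -d '{"input": "Hello, world!", "model": "<model-id>"}'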


if __name__ == "__main__":
    entrypoint()