|
from typing import List |
|
|
|
from openai import OpenAI |
|
|
|
from mcp_server_mariadb_vector.embeddings.base import EmbeddingProvider |
|
|
|
|
|
class OpenAIEmbeddingProvider(EmbeddingProvider):
    """
    OpenAI implementation of the embedding provider.

    Args:
        model: The name of the OpenAI model to use.
        api_key: The API key passed to the OpenAI client.
    """

    # Default output vector size for each supported embedding model.
    _EMBEDDING_DIMENSIONS = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536,
    }

    # Number of documents sent per request. Kept small so that a batch made
    # entirely of maximum-length inputs still fits within the per-request
    # limits documented for the OpenAI embeddings endpoint.
    _MAX_BATCH_SIZE = 32

    def __init__(self, model: str, api_key: str):
        self.model = model
        self.client = OpenAI(api_key=api_key)

    def length_of_embedding(self) -> int:
        """Get the length of the embedding for the configured model.

        Returns:
            The number of dimensions of the vectors produced by ``self.model``.

        Raises:
            ValueError: If ``self.model`` is not a known embedding model.
        """
        dimensions = self._EMBEDDING_DIMENSIONS.get(self.model)
        if dimensions is None:
            raise ValueError(f"Unknown embedding model: {self.model}")
        return dimensions

    def embed_documents(self, documents: List[str]) -> List[List[float]]:
        """Embed a list of documents into vectors.

        Documents are sent to the API in batches rather than one request per
        document. An empty input yields an empty list without any API call.

        Args:
            documents: The texts to embed.

        Returns:
            One embedding per document, in the same order as ``documents``.
        """
        # Materialise once so any iterable can be sliced into batches.
        documents = list(documents)

        embeddings: List[List[float]] = []
        for start in range(0, len(documents), self._MAX_BATCH_SIZE):
            batch = documents[start : start + self._MAX_BATCH_SIZE]
            response = self.client.embeddings.create(
                model=self.model,
                input=batch,
            )
            # Each returned item carries the index of its input within the
            # batch; sort on it so results always line up with the inputs.
            ordered = sorted(response.data, key=lambda item: item.index)
            embeddings.extend(item.embedding for item in ordered)
        return embeddings

    def embed_query(self, query: str) -> List[float]:
        """Embed a query into a vector.

        Args:
            query: The text to embed.

        Returns:
            The embedding of ``query``.
        """
        response = self.client.embeddings.create(
            model=self.model,
            input=query,
        )
        return response.data[0].embedding
|
|