Kaballas's picture
initial
644bdfe
from typing import List
from openai import OpenAI
from mcp_server_mariadb_vector.embeddings.base import EmbeddingProvider
class OpenAIEmbeddingProvider(EmbeddingProvider):
    """
    OpenAI implementation of the embedding provider.

    Args:
        model: The name of the OpenAI model to use.
        api_key: The API key handed to the OpenAI client.
    """

    # Vector length produced by each supported model.
    _EMBEDDING_LENGTHS = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
    }

    # Number of documents sent per embeddings request. Kept small on purpose:
    # per OpenAI's documented limits (8192 tokens per input, 300,000 tokens
    # summed over one request), 32 individually valid inputs can never exceed
    # the per-request total, so batching cannot reject input that the
    # one-request-per-document approach would have accepted.
    _MAX_BATCH_SIZE = 32

    def __init__(self, model: str, api_key: str):
        self.model = model
        self.client = OpenAI(api_key=api_key)

    def length_of_embedding(self) -> int:
        """Get the length of the embedding for the configured model.

        Returns:
            The number of dimensions of the vectors produced by ``self.model``.

        Raises:
            ValueError: If ``self.model`` is not one of the supported models.
        """
        length = self._EMBEDDING_LENGTHS.get(self.model)
        if length is None:
            raise ValueError(f"Unknown embedding model: {self.model}")
        return length

    def embed_documents(self, documents: List[str]) -> List[List[float]]:
        """Embed a list of documents into vectors.

        Documents are sent to the API in batches instead of one request per
        document, which cuts the number of network round trips.

        Args:
            documents: The texts to embed.

        Returns:
            One embedding per document, in the same order as ``documents``.
            An empty input yields an empty list without calling the API.
        """
        # Materialize once so any iterable can be sliced into batches.
        documents = list(documents)
        embeddings: List[List[float]] = []
        for start in range(0, len(documents), self._MAX_BATCH_SIZE):
            batch = documents[start : start + self._MAX_BATCH_SIZE]
            response = self.client.embeddings.create(
                model=self.model,
                input=batch,
            )
            # Each returned item carries the index of the input it belongs
            # to; sort on it rather than relying on response ordering.
            ordered = sorted(response.data, key=lambda item: item.index)
            embeddings.extend(item.embedding for item in ordered)
        return embeddings

    def embed_query(self, query: str) -> List[float]:
        """Embed a query into a vector.

        Args:
            query: The text to embed.

        Returns:
            The embedding of ``query``.
        """
        response = self.client.embeddings.create(
            model=self.model,
            input=query,
        )
        return response.data[0].embedding