tdoehmen's picture
added duckdb-nsql benchmark setup
b247dc4
raw
history blame
4.04 kB
"""Azure client."""
import logging
import os
from typing import Any, Dict, Optional
from manifest.clients.openai_chat import OPENAICHAT_ENGINES, OpenAIChatClient
from manifest.request import LMRequest
logger = logging.getLogger(__name__)

# Azure deployment names may contain only letters, digits, hyphens ("-") and
# underscores ("_"), and may not end with a hyphen or underscore. OpenAI engine
# names such as "gpt-3.5-turbo" contain a dot, so the Azure deployment uses a
# dot-free spelling. The two tables below translate between the namings.

# OpenAI engine name -> Azure deployment name.
AZURE_DEPLOYMENT_NAME_MAPPING = {
    "gpt-3.5-turbo": "gpt-35-turbo",
    "gpt-3.5-turbo-0301": "gpt-35-turbo-0301",
}

# Azure deployment name -> OpenAI engine name; exact inverse of the table above,
# so the two can never drift apart.
OPENAI_DEPLOYMENT_NAME_MAPPING = {
    azure_name: openai_name
    for openai_name, azure_name in AZURE_DEPLOYMENT_NAME_MAPPING.items()
}
class AzureChatClient(OpenAIChatClient):
    """Azure OpenAI chat client.

    Reuses the OpenAI chat client's parameters, but sends requests to an
    Azure OpenAI deployment URL and authenticates with an ``api-key`` header.
    """

    # User param -> (client param, default value)
    PARAMS = OpenAIChatClient.PARAMS
    REQUEST_CLS = LMRequest
    NAME = "azureopenaichat"
    IS_CHAT = True
    # Azure OpenAI REST API version appended to every chat completion URL.
    # A class attribute so subclasses can target a different version.
    API_VERSION = "2023-05-15"

    def connect(
        self,
        connection_str: Optional[str] = None,
        client_args: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Connect to the AzureOpenAI server.

        ``connection_str`` is either ``AZURE_OPENAI_KEY`` or
        ``AZURE_OPENAI_KEY::AZURE_OPENAI_ENDPOINT``. Values given there take
        precedence; anything missing falls back to the ``AZURE_OPENAI_KEY``
        and ``AZURE_OPENAI_ENDPOINT`` environment variables.

        Args:
            connection_str: connection string.
            client_args: client arguments. Keys listed in ``PARAMS`` are
                popped from this dict, so the caller's dict is mutated.

        Raises:
            ValueError: if the connection string is malformed, the API key or
                endpoint cannot be resolved, or the engine is not supported.
        """
        # Avoid a shared mutable default: this dict is mutated via pop() below.
        if client_args is None:
            client_args = {}
        self.api_key, self.host = None, None
        if connection_str:
            connection_parts = connection_str.split("::")
            if len(connection_parts) == 1:
                self.api_key = connection_parts[0]
            elif len(connection_parts) == 2:
                self.api_key, self.host = connection_parts
            else:
                raise ValueError(
                    "Invalid connection string. "
                    "Must be either AZURE_OPENAI_KEY or "
                    "AZURE_OPENAI_KEY::AZURE_OPENAI_ENDPOINT"
                )
        # Empty strings (e.g. an env var that is set but blank) count as unset.
        self.api_key = self.api_key or os.environ.get("AZURE_OPENAI_KEY")
        if not self.api_key:
            raise ValueError(
                "AzureOpenAI API key not set. Set AZURE_OPENAI_KEY environment "
                "variable or pass through `client_connection`."
            )
        self.host = self.host or os.environ.get("AZURE_OPENAI_ENDPOINT")
        if not self.host:
            raise ValueError(
                "Azure Service URL not set "
                "(e.g. https://openai-azure-service.openai.azure.com/)."
                " Set AZURE_OPENAI_ENDPOINT or pass through `client_connection`"
                " as AZURE_OPENAI_KEY::AZURE_OPENAI_ENDPOINT."
            )
        # Strip trailing slashes so URL concatenation never yields "//".
        self.host = self.host.rstrip("/")
        for key in self.PARAMS:
            setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
        # Accept Azure-style deployment names (e.g. "gpt-35-turbo") by
        # normalizing them to the OpenAI engine name; get_generation_url maps
        # the engine back to the Azure deployment name when building the URL.
        engine = getattr(self, "engine")
        engine = OPENAI_DEPLOYMENT_NAME_MAPPING.get(engine, engine)
        setattr(self, "engine", engine)
        if engine not in OPENAICHAT_ENGINES:
            raise ValueError(
                f"Invalid engine {engine}. "
                f"Must be {OPENAICHAT_ENGINES}."
            )

    def get_generation_url(self) -> str:
        """Get generation URL.

        Returns:
            The chat completions URL for this client's engine, using the
            Azure deployment name (e.g. "gpt-35-turbo" for "gpt-3.5-turbo").
        """
        engine = getattr(self, "engine")
        deployment_name = AZURE_DEPLOYMENT_NAME_MAPPING.get(engine, engine)
        return (
            f"{self.host}/openai/deployments/{deployment_name}"
            f"/chat/completions?api-version={self.API_VERSION}"
        )

    def get_generation_header(self) -> Dict[str, str]:
        """
        Get generation header.

        Returns:
            header carrying the API key in Azure's ``api-key`` field.
        """
        return {"api-key": f"{self.api_key}"}

    def get_model_params(self) -> Dict:
        """
        Get model params.

        By getting model params from the server, we can add to request
        and make sure cache keys are unique to model.

        Returns:
            model params.
        """
        # IMPORTANT!!!
        # Azure models are the same as OpenAI models, so we want their cache
        # entries to be shared. Make sure we return the OpenAI client name here.
        return {"model_name": OpenAIChatClient.NAME, "engine": getattr(self, "engine")}