Spaces:
Configuration error
Configuration error
Fedir Zadniprovskyi
commited on
Commit
·
ede9e6a
1
Parent(s):
74ecebe
tests: proper `get_config` dependency override
Browse files
src/faster_whisper_server/dependencies.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
from functools import lru_cache
|
|
|
|
| 2 |
from typing import Annotated
|
| 3 |
|
| 4 |
from fastapi import Depends, HTTPException, status
|
|
@@ -11,7 +12,13 @@ from openai.resources.chat.completions import AsyncCompletions
|
|
| 11 |
from faster_whisper_server.config import Config
|
| 12 |
from faster_whisper_server.model_manager import PiperModelManager, WhisperModelManager
|
| 13 |
|
|
|
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
@lru_cache
|
| 16 |
def get_config() -> Config:
|
| 17 |
return Config()
|
|
@@ -22,7 +29,7 @@ ConfigDependency = Annotated[Config, Depends(get_config)]
|
|
| 22 |
|
| 23 |
@lru_cache
|
| 24 |
def get_model_manager() -> WhisperModelManager:
|
| 25 |
-
config = get_config()
|
| 26 |
return WhisperModelManager(config.whisper)
|
| 27 |
|
| 28 |
|
|
@@ -31,8 +38,8 @@ ModelManagerDependency = Annotated[WhisperModelManager, Depends(get_model_manage
|
|
| 31 |
|
| 32 |
@lru_cache
|
| 33 |
def get_piper_model_manager() -> PiperModelManager:
|
| 34 |
-
config = get_config()
|
| 35 |
-
return PiperModelManager(config.whisper.ttl) # HACK
|
| 36 |
|
| 37 |
|
| 38 |
PiperModelManagerDependency = Annotated[PiperModelManager, Depends(get_piper_model_manager)]
|
|
@@ -53,7 +60,7 @@ ApiKeyDependency = Depends(verify_api_key)
|
|
| 53 |
|
| 54 |
@lru_cache
|
| 55 |
def get_completion_client() -> AsyncCompletions:
|
| 56 |
-
config = get_config()
|
| 57 |
oai_client = AsyncOpenAI(base_url=config.chat_completion_base_url, api_key=config.chat_completion_api_key)
|
| 58 |
return oai_client.chat.completions
|
| 59 |
|
|
@@ -63,9 +70,9 @@ CompletionClientDependency = Annotated[AsyncCompletions, Depends(get_completion_
|
|
| 63 |
|
| 64 |
@lru_cache
|
| 65 |
def get_speech_client() -> AsyncSpeech:
|
| 66 |
-
config = get_config()
|
| 67 |
if config.speech_base_url is None:
|
| 68 |
-
# this might not work as expected if
|
| 69 |
from faster_whisper_server.routers.speech import (
|
| 70 |
router as speech_router,
|
| 71 |
)
|
|
@@ -86,7 +93,7 @@ SpeechClientDependency = Annotated[AsyncSpeech, Depends(get_speech_client)]
|
|
| 86 |
def get_transcription_client() -> AsyncTranscriptions:
|
| 87 |
config = get_config()
|
| 88 |
if config.transcription_base_url is None:
|
| 89 |
-
# this might not work as expected if
|
| 90 |
from faster_whisper_server.routers.stt import (
|
| 91 |
router as stt_router,
|
| 92 |
)
|
|
|
|
| 1 |
from functools import lru_cache
|
| 2 |
+
import logging
|
| 3 |
from typing import Annotated
|
| 4 |
|
| 5 |
from fastapi import Depends, HTTPException, status
|
|
|
|
| 12 |
from faster_whisper_server.config import Config
|
| 13 |
from faster_whisper_server.model_manager import PiperModelManager, WhisperModelManager
|
| 14 |
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
|
| 17 |
+
# NOTE: `get_config` is called directly instead of using sub-dependencies so that these functions could be used outside of `FastAPI` # noqa: E501
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# https://fastapi.tiangolo.com/advanced/settings/?h=setti#creating-the-settings-only-once-with-lru_cache
|
| 21 |
+
# WARN: Any new module that ends up calling this function directly (not through `FastAPI` dependency injection) should be patched in `tests/conftest.py` # noqa: E501
|
| 22 |
@lru_cache
|
| 23 |
def get_config() -> Config:
|
| 24 |
return Config()
|
|
|
|
| 29 |
|
| 30 |
@lru_cache
|
| 31 |
def get_model_manager() -> WhisperModelManager:
|
| 32 |
+
config = get_config()
|
| 33 |
return WhisperModelManager(config.whisper)
|
| 34 |
|
| 35 |
|
|
|
|
| 38 |
|
| 39 |
@lru_cache
|
| 40 |
def get_piper_model_manager() -> PiperModelManager:
|
| 41 |
+
config = get_config()
|
| 42 |
+
return PiperModelManager(config.whisper.ttl) # HACK: should have its own config
|
| 43 |
|
| 44 |
|
| 45 |
PiperModelManagerDependency = Annotated[PiperModelManager, Depends(get_piper_model_manager)]
|
|
|
|
| 60 |
|
| 61 |
@lru_cache
|
| 62 |
def get_completion_client() -> AsyncCompletions:
|
| 63 |
+
config = get_config()
|
| 64 |
oai_client = AsyncOpenAI(base_url=config.chat_completion_base_url, api_key=config.chat_completion_api_key)
|
| 65 |
return oai_client.chat.completions
|
| 66 |
|
|
|
|
| 70 |
|
| 71 |
@lru_cache
|
| 72 |
def get_speech_client() -> AsyncSpeech:
|
| 73 |
+
config = get_config()
|
| 74 |
if config.speech_base_url is None:
|
| 75 |
+
# this might not work as expected if `speech_router` won't have shared state (access to the same `model_manager`) with the main FastAPI `app`. TODO: verify # noqa: E501
|
| 76 |
from faster_whisper_server.routers.speech import (
|
| 77 |
router as speech_router,
|
| 78 |
)
|
|
|
|
| 93 |
def get_transcription_client() -> AsyncTranscriptions:
|
| 94 |
config = get_config()
|
| 95 |
if config.transcription_base_url is None:
|
| 96 |
+
# this might not work as expected if `transcription_router` won't have shared state (access to the same `model_manager`) with the main FastAPI `app`. TODO: verify # noqa: E501
|
| 97 |
from faster_whisper_server.routers.stt import (
|
| 98 |
router as stt_router,
|
| 99 |
)
|
src/faster_whisper_server/logger.py
CHANGED
|
@@ -1,11 +1,8 @@
|
|
| 1 |
import logging
|
| 2 |
|
| 3 |
-
from faster_whisper_server.dependencies import get_config
|
| 4 |
|
| 5 |
-
|
| 6 |
-
def setup_logger() -> None:
|
| 7 |
-
config = get_config() # HACK
|
| 8 |
logging.getLogger().setLevel(logging.INFO)
|
| 9 |
logger = logging.getLogger(__name__)
|
| 10 |
-
logger.setLevel(
|
| 11 |
logging.basicConfig(format="%(asctime)s:%(levelname)s:%(name)s:%(funcName)s:%(lineno)d:%(message)s")
|
|
|
|
| 1 |
import logging
|
| 2 |
|
|
|
|
| 3 |
|
| 4 |
+
def setup_logger(log_level: str) -> None:
|
|
|
|
|
|
|
| 5 |
logging.getLogger().setLevel(logging.INFO)
|
| 6 |
logger = logging.getLogger(__name__)
|
| 7 |
+
logger.setLevel(log_level.upper())
|
| 8 |
logging.basicConfig(format="%(asctime)s:%(levelname)s:%(name)s:%(funcName)s:%(lineno)d:%(message)s")
|
src/faster_whisper_server/main.py
CHANGED
|
@@ -27,10 +27,12 @@ if TYPE_CHECKING:
|
|
| 27 |
|
| 28 |
|
| 29 |
def create_app() -> FastAPI:
|
| 30 |
-
|
| 31 |
-
|
| 32 |
logger = logging.getLogger(__name__)
|
| 33 |
|
|
|
|
|
|
|
| 34 |
if platform.machine() == "x86_64":
|
| 35 |
from faster_whisper_server.routers.speech import (
|
| 36 |
router as speech_router,
|
|
@@ -39,9 +41,6 @@ def create_app() -> FastAPI:
|
|
| 39 |
logger.warning("`/v1/audio/speech` is only supported on x86_64 machines")
|
| 40 |
speech_router = None
|
| 41 |
|
| 42 |
-
config = get_config() # HACK
|
| 43 |
-
logger.debug(f"Config: {config}")
|
| 44 |
-
|
| 45 |
model_manager = get_model_manager() # HACK
|
| 46 |
|
| 47 |
@asynccontextmanager
|
|
|
|
| 27 |
|
| 28 |
|
| 29 |
def create_app() -> FastAPI:
|
| 30 |
+
config = get_config() # HACK
|
| 31 |
+
setup_logger(config.log_level)
|
| 32 |
logger = logging.getLogger(__name__)
|
| 33 |
|
| 34 |
+
logger.debug(f"Config: {config}")
|
| 35 |
+
|
| 36 |
if platform.machine() == "x86_64":
|
| 37 |
from faster_whisper_server.routers.speech import (
|
| 38 |
router as speech_router,
|
|
|
|
| 41 |
logger.warning("`/v1/audio/speech` is only supported on x86_64 machines")
|
| 42 |
speech_router = None
|
| 43 |
|
|
|
|
|
|
|
|
|
|
| 44 |
model_manager = get_model_manager() # HACK
|
| 45 |
|
| 46 |
@asynccontextmanager
|
tests/conftest.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
| 1 |
from collections.abc import AsyncGenerator, Generator
|
|
|
|
| 2 |
import logging
|
| 3 |
import os
|
|
|
|
| 4 |
|
| 5 |
from fastapi.testclient import TestClient
|
| 6 |
from httpx import ASGITransport, AsyncClient
|
|
@@ -8,19 +10,31 @@ from huggingface_hub import snapshot_download
|
|
| 8 |
from openai import AsyncOpenAI
|
| 9 |
import pytest
|
| 10 |
import pytest_asyncio
|
|
|
|
| 11 |
|
|
|
|
|
|
|
| 12 |
from faster_whisper_server.main import create_app
|
| 13 |
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
|
| 17 |
def pytest_configure() -> None:
|
| 18 |
-
for logger_name in
|
| 19 |
logger = logging.getLogger(logger_name)
|
| 20 |
logger.disabled = True
|
| 21 |
|
| 22 |
|
| 23 |
-
# NOTE: not being used. Keeping just in case
|
| 24 |
@pytest.fixture
|
| 25 |
def client() -> Generator[TestClient, None, None]:
|
| 26 |
os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en"
|
|
@@ -28,10 +42,37 @@ def client() -> Generator[TestClient, None, None]:
|
|
| 28 |
yield client
|
| 29 |
|
| 30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
@pytest_asyncio.fixture()
|
| 32 |
-
async def
|
| 33 |
-
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
yield aclient
|
| 36 |
|
| 37 |
|
|
@@ -43,11 +84,13 @@ def openai_client(aclient: AsyncClient) -> AsyncOpenAI:
|
|
| 43 |
@pytest.fixture
|
| 44 |
def actual_openai_client() -> AsyncOpenAI:
|
| 45 |
return AsyncOpenAI(
|
| 46 |
-
base_url
|
| 47 |
-
|
|
|
|
| 48 |
|
| 49 |
|
| 50 |
# TODO: remove the download after running the tests
|
|
|
|
| 51 |
@pytest.fixture(scope="session", autouse=True)
|
| 52 |
def download_piper_voices() -> None:
|
| 53 |
# Only download `voices.json` and the default voice
|
|
|
|
| 1 |
from collections.abc import AsyncGenerator, Generator
|
| 2 |
+
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
| 3 |
import logging
|
| 4 |
import os
|
| 5 |
+
from typing import Protocol
|
| 6 |
|
| 7 |
from fastapi.testclient import TestClient
|
| 8 |
from httpx import ASGITransport, AsyncClient
|
|
|
|
| 10 |
from openai import AsyncOpenAI
|
| 11 |
import pytest
|
| 12 |
import pytest_asyncio
|
| 13 |
+
from pytest_mock import MockerFixture
|
| 14 |
|
| 15 |
+
from faster_whisper_server.config import Config, WhisperConfig
|
| 16 |
+
from faster_whisper_server.dependencies import get_config
|
| 17 |
from faster_whisper_server.main import create_app
|
| 18 |
|
| 19 |
+
DISABLE_LOGGERS = ["multipart.multipart", "faster_whisper"]
|
| 20 |
+
OPENAI_BASE_URL = "https://api.openai.com/v1"
|
| 21 |
+
DEFAULT_WHISPER_MODEL = "Systran/faster-whisper-tiny.en"
|
| 22 |
+
# TODO: figure out a way to initialize the config without parsing environment variables, as those may interfere with the tests # noqa: E501
|
| 23 |
+
DEFAULT_WHISPER_CONFIG = WhisperConfig(model=DEFAULT_WHISPER_MODEL, ttl=0)
|
| 24 |
+
DEFAULT_CONFIG = Config(
|
| 25 |
+
whisper=DEFAULT_WHISPER_CONFIG,
|
| 26 |
+
# disable the UI as it slightly increases the app startup time due to the imports it's doing
|
| 27 |
+
enable_ui=False,
|
| 28 |
+
)
|
| 29 |
|
| 30 |
|
| 31 |
def pytest_configure() -> None:
|
| 32 |
+
for logger_name in DISABLE_LOGGERS:
|
| 33 |
logger = logging.getLogger(logger_name)
|
| 34 |
logger.disabled = True
|
| 35 |
|
| 36 |
|
| 37 |
+
# NOTE: not being used. Keeping just in case. Needs to be modified to work similarly to `aclient_factory`
|
| 38 |
@pytest.fixture
|
| 39 |
def client() -> Generator[TestClient, None, None]:
|
| 40 |
os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en"
|
|
|
|
| 42 |
yield client
|
| 43 |
|
| 44 |
|
| 45 |
+
# https://stackoverflow.com/questions/74890214/type-hint-callback-function-with-optional-parameters-aka-callable-with-optional
|
| 46 |
+
class AclientFactory(Protocol):
|
| 47 |
+
def __call__(self, config: Config = DEFAULT_CONFIG) -> AbstractAsyncContextManager[AsyncClient]: ...
|
| 48 |
+
|
| 49 |
+
|
| 50 |
@pytest_asyncio.fixture()
|
| 51 |
+
async def aclient_factory(mocker: MockerFixture) -> AclientFactory:
|
| 52 |
+
"""Returns a context manager that provides an `AsyncClient` instance with `app` using the provided configuration."""
|
| 53 |
+
|
| 54 |
+
@asynccontextmanager
|
| 55 |
+
async def inner(config: Config = DEFAULT_CONFIG) -> AsyncGenerator[AsyncClient, None]:
|
| 56 |
+
# NOTE: all calls to `get_config` should be patched. One way to test that this works is to update the original `get_config` to raise an exception and see if the tests fail # noqa: E501
|
| 57 |
+
mocker.patch("faster_whisper_server.dependencies.get_config", return_value=config)
|
| 58 |
+
mocker.patch("faster_whisper_server.main.get_config", return_value=config)
|
| 59 |
+
# NOTE: I couldn't get the following to work but it shouldn't matter
|
| 60 |
+
# mocker.patch(
|
| 61 |
+
# "faster_whisper_server.text_utils.Transcription._ensure_no_word_overlap.get_config", return_value=config
|
| 62 |
+
# )
|
| 63 |
+
|
| 64 |
+
app = create_app()
|
| 65 |
+
# https://fastapi.tiangolo.com/advanced/testing-dependencies/
|
| 66 |
+
app.dependency_overrides[get_config] = lambda: config
|
| 67 |
+
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as aclient:
|
| 68 |
+
yield aclient
|
| 69 |
+
|
| 70 |
+
return inner
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
@pytest_asyncio.fixture()
|
| 74 |
+
async def aclient(aclient_factory: AclientFactory) -> AsyncGenerator[AsyncClient, None]:
|
| 75 |
+
async with aclient_factory() as aclient:
|
| 76 |
yield aclient
|
| 77 |
|
| 78 |
|
|
|
|
| 84 |
@pytest.fixture
|
| 85 |
def actual_openai_client() -> AsyncOpenAI:
|
| 86 |
return AsyncOpenAI(
|
| 87 |
+
# `base_url` is provided in case `OPENAI_BASE_URL` is set to a different value
|
| 88 |
+
base_url=OPENAI_BASE_URL
|
| 89 |
+
)
|
| 90 |
|
| 91 |
|
| 92 |
# TODO: remove the download after running the tests
|
| 93 |
+
# TODO: do not download when not needed
|
| 94 |
@pytest.fixture(scope="session", autouse=True)
|
| 95 |
def download_piper_voices() -> None:
|
| 96 |
# Only download `voices.json` and the default voice
|
tests/model_manager_test.py
CHANGED
|
@@ -1,23 +1,22 @@
|
|
| 1 |
import asyncio
|
| 2 |
-
import os
|
| 3 |
|
| 4 |
import anyio
|
| 5 |
-
from httpx import ASGITransport, AsyncClient
|
| 6 |
import pytest
|
| 7 |
|
| 8 |
-
from faster_whisper_server.
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
@pytest.mark.asyncio
|
| 12 |
-
async def test_model_unloaded_after_ttl() -> None:
|
| 13 |
ttl = 5
|
| 14 |
-
model =
|
| 15 |
-
|
| 16 |
-
os.environ["ENABLE_UI"] = "false"
|
| 17 |
-
async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
|
| 18 |
res = (await aclient.get("/api/ps")).json()
|
| 19 |
assert len(res["models"]) == 0
|
| 20 |
-
await aclient.post(f"/api/ps/{
|
| 21 |
res = (await aclient.get("/api/ps")).json()
|
| 22 |
assert len(res["models"]) == 1
|
| 23 |
await asyncio.sleep(ttl + 1) # wait for the model to be unloaded
|
|
@@ -26,13 +25,11 @@ async def test_model_unloaded_after_ttl() -> None:
|
|
| 26 |
|
| 27 |
|
| 28 |
@pytest.mark.asyncio
|
| 29 |
-
async def test_ttl_resets_after_usage() -> None:
|
| 30 |
ttl = 5
|
| 31 |
-
model =
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
|
| 35 |
-
await aclient.post(f"/api/ps/{model}")
|
| 36 |
res = (await aclient.get("/api/ps")).json()
|
| 37 |
assert len(res["models"]) == 1
|
| 38 |
await asyncio.sleep(ttl - 2) # sleep for less than the ttl. The model should not be unloaded
|
|
@@ -43,7 +40,9 @@ async def test_ttl_resets_after_usage() -> None:
|
|
| 43 |
data = await f.read()
|
| 44 |
res = (
|
| 45 |
await aclient.post(
|
| 46 |
-
"/v1/audio/transcriptions",
|
|
|
|
|
|
|
| 47 |
)
|
| 48 |
).json()
|
| 49 |
res = (await aclient.get("/api/ps")).json()
|
|
@@ -60,28 +59,28 @@ async def test_ttl_resets_after_usage() -> None:
|
|
| 60 |
# this just ensures the model can be loaded again after being unloaded
|
| 61 |
res = (
|
| 62 |
await aclient.post(
|
| 63 |
-
"/v1/audio/transcriptions",
|
|
|
|
|
|
|
| 64 |
)
|
| 65 |
).json()
|
| 66 |
|
| 67 |
|
| 68 |
@pytest.mark.asyncio
|
| 69 |
-
async def test_model_cant_be_unloaded_when_used() -> None:
|
| 70 |
ttl = 0
|
| 71 |
-
model =
|
| 72 |
-
|
| 73 |
-
os.environ["ENABLE_UI"] = "false"
|
| 74 |
-
async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
|
| 75 |
async with await anyio.open_file("audio.wav", "rb") as f:
|
| 76 |
data = await f.read()
|
| 77 |
|
| 78 |
task = asyncio.create_task(
|
| 79 |
aclient.post(
|
| 80 |
-
"/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model":
|
| 81 |
)
|
| 82 |
)
|
| 83 |
await asyncio.sleep(0.1) # wait for the server to start processing the request
|
| 84 |
-
res = await aclient.delete(f"/api/ps/{
|
| 85 |
assert res.status_code == 409
|
| 86 |
|
| 87 |
await task
|
|
@@ -90,27 +89,23 @@ async def test_model_cant_be_unloaded_when_used() -> None:
|
|
| 90 |
|
| 91 |
|
| 92 |
@pytest.mark.asyncio
|
| 93 |
-
async def test_model_cant_be_loaded_twice() -> None:
|
| 94 |
ttl = -1
|
| 95 |
-
model =
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
|
| 99 |
-
res = await aclient.post(f"/api/ps/{model}")
|
| 100 |
assert res.status_code == 201
|
| 101 |
-
res = await aclient.post(f"/api/ps/{
|
| 102 |
assert res.status_code == 409
|
| 103 |
res = (await aclient.get("/api/ps")).json()
|
| 104 |
assert len(res["models"]) == 1
|
| 105 |
|
| 106 |
|
| 107 |
@pytest.mark.asyncio
|
| 108 |
-
async def test_model_is_unloaded_after_request_when_ttl_is_zero() -> None:
|
| 109 |
ttl = 0
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
os.environ["ENABLE_UI"] = "false"
|
| 113 |
-
async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
|
| 114 |
async with await anyio.open_file("audio.wav", "rb") as f:
|
| 115 |
data = await f.read()
|
| 116 |
res = await aclient.post(
|
|
|
|
| 1 |
import asyncio
|
|
|
|
| 2 |
|
| 3 |
import anyio
|
|
|
|
| 4 |
import pytest
|
| 5 |
|
| 6 |
+
from faster_whisper_server.config import Config, WhisperConfig
|
| 7 |
+
from tests.conftest import DEFAULT_WHISPER_MODEL, AclientFactory
|
| 8 |
+
|
| 9 |
+
MODEL = DEFAULT_WHISPER_MODEL # just to make the test more readable
|
| 10 |
|
| 11 |
|
| 12 |
@pytest.mark.asyncio
|
| 13 |
+
async def test_model_unloaded_after_ttl(aclient_factory: AclientFactory) -> None:
|
| 14 |
ttl = 5
|
| 15 |
+
config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
|
| 16 |
+
async with aclient_factory(config) as aclient:
|
|
|
|
|
|
|
| 17 |
res = (await aclient.get("/api/ps")).json()
|
| 18 |
assert len(res["models"]) == 0
|
| 19 |
+
await aclient.post(f"/api/ps/{MODEL}")
|
| 20 |
res = (await aclient.get("/api/ps")).json()
|
| 21 |
assert len(res["models"]) == 1
|
| 22 |
await asyncio.sleep(ttl + 1) # wait for the model to be unloaded
|
|
|
|
| 25 |
|
| 26 |
|
| 27 |
@pytest.mark.asyncio
|
| 28 |
+
async def test_ttl_resets_after_usage(aclient_factory: AclientFactory) -> None:
|
| 29 |
ttl = 5
|
| 30 |
+
config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
|
| 31 |
+
async with aclient_factory(config) as aclient:
|
| 32 |
+
await aclient.post(f"/api/ps/{MODEL}")
|
|
|
|
|
|
|
| 33 |
res = (await aclient.get("/api/ps")).json()
|
| 34 |
assert len(res["models"]) == 1
|
| 35 |
await asyncio.sleep(ttl - 2) # sleep for less than the ttl. The model should not be unloaded
|
|
|
|
| 40 |
data = await f.read()
|
| 41 |
res = (
|
| 42 |
await aclient.post(
|
| 43 |
+
"/v1/audio/transcriptions",
|
| 44 |
+
files={"file": ("audio.wav", data, "audio/wav")},
|
| 45 |
+
data={"model": MODEL},
|
| 46 |
)
|
| 47 |
).json()
|
| 48 |
res = (await aclient.get("/api/ps")).json()
|
|
|
|
| 59 |
# this just ensures the model can be loaded again after being unloaded
|
| 60 |
res = (
|
| 61 |
await aclient.post(
|
| 62 |
+
"/v1/audio/transcriptions",
|
| 63 |
+
files={"file": ("audio.wav", data, "audio/wav")},
|
| 64 |
+
data={"model": MODEL},
|
| 65 |
)
|
| 66 |
).json()
|
| 67 |
|
| 68 |
|
| 69 |
@pytest.mark.asyncio
|
| 70 |
+
async def test_model_cant_be_unloaded_when_used(aclient_factory: AclientFactory) -> None:
|
| 71 |
ttl = 0
|
| 72 |
+
config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
|
| 73 |
+
async with aclient_factory(config) as aclient:
|
|
|
|
|
|
|
| 74 |
async with await anyio.open_file("audio.wav", "rb") as f:
|
| 75 |
data = await f.read()
|
| 76 |
|
| 77 |
task = asyncio.create_task(
|
| 78 |
aclient.post(
|
| 79 |
+
"/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": MODEL}
|
| 80 |
)
|
| 81 |
)
|
| 82 |
await asyncio.sleep(0.1) # wait for the server to start processing the request
|
| 83 |
+
res = await aclient.delete(f"/api/ps/{MODEL}")
|
| 84 |
assert res.status_code == 409
|
| 85 |
|
| 86 |
await task
|
|
|
|
| 89 |
|
| 90 |
|
| 91 |
@pytest.mark.asyncio
|
| 92 |
+
async def test_model_cant_be_loaded_twice(aclient_factory: AclientFactory) -> None:
|
| 93 |
ttl = -1
|
| 94 |
+
config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
|
| 95 |
+
async with aclient_factory(config) as aclient:
|
| 96 |
+
res = await aclient.post(f"/api/ps/{MODEL}")
|
|
|
|
|
|
|
| 97 |
assert res.status_code == 201
|
| 98 |
+
res = await aclient.post(f"/api/ps/{MODEL}")
|
| 99 |
assert res.status_code == 409
|
| 100 |
res = (await aclient.get("/api/ps")).json()
|
| 101 |
assert len(res["models"]) == 1
|
| 102 |
|
| 103 |
|
| 104 |
@pytest.mark.asyncio
|
| 105 |
+
async def test_model_is_unloaded_after_request_when_ttl_is_zero(aclient_factory: AclientFactory) -> None:
|
| 106 |
ttl = 0
|
| 107 |
+
config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
|
| 108 |
+
async with aclient_factory(config) as aclient:
|
|
|
|
|
|
|
| 109 |
async with await anyio.open_file("audio.wav", "rb") as f:
|
| 110 |
data = await f.read()
|
| 111 |
res = await aclient.post(
|