"""
Test client.

We just test the dummy client.
"""
from manifest.clients.dummy import DummyClient


def test_init() -> None:
    """Test client initialization."""
    client = DummyClient(connection_str=None)
    assert client.n == 1  # type: ignore

    args = {"n": 3}
    client = DummyClient(connection_str=None, client_args=args)
    assert client.n == 3  # type: ignore


def test_get_params() -> None:
    """Test get param functions."""
    client = DummyClient(connection_str=None)
    assert client.get_model_params() == {
        "engine": "dummy",
        "model": "text-davinci-003",
    }
    assert client.get_model_inputs() == [
        "engine",
        "temperature",
        "max_tokens",
        "n",
        "top_p",
        "top_k",
        "batch_size",
    ]


def test_get_request() -> None:
    """Test client get request."""
    args = {"n": 3}
    client = DummyClient(connection_str=None, client_args=args)
    request_params = client.get_request("hello", {})
    response = client.run_request(request_params)
    assert client.get_cache_key(request_params) == {
        "prompt": "hello",
        "model": "text-davinci-003",
        "n": 3,
        "temperature": 0.0,
        "max_tokens": 10,
        "top_p": 1.0,
        "best_of": 1,
        "engine": "dummy",
        "request_cls": "LMRequest",
    }
    assert response.get_json_response() == {
        "choices": [
            {
                "text": " probsuib.FirstName>- commodityting segunda inserted signals Religious",  # noqa: E501
                "token_logprobs": [
                    -0.2649905035732101,
                    -1.210794839387105,
                    -1.2173929801003434,
                    -0.7758233850171001,
                    -0.7165940659570416,
                    -1.7430328887209088,
                    -1.5379414228820203,
                    -1.7838011423472508,
                    -1.139095076944217,
                    -0.6321855879833425,
                ],
                "tokens": [
                    "70470",
                    "80723",
                    "52693",
                    "39743",
                    "38983",
                    "1303",
                    "56072",
                    "22306",
                    "17738",
                    "53176",
                ],
            }
        ]
        * 3
    }
    assert response.get_usage_obj().dict() == {
        "usages": [{"prompt_tokens": 1, "completion_tokens": 10, "total_tokens": 11}]
        * 3,
    }

    request_params = client.get_request("hello", {"n": 5})
    response = client.run_request(request_params)
    assert client.get_cache_key(request_params) == {
        "prompt": "hello",
        "model": "text-davinci-003",
        "n": 5,
        "temperature": 0.0,
        "max_tokens": 10,
        "top_p": 1.0,
        "best_of": 1,
        "engine": "dummy",
        "request_cls": "LMRequest",
    }
    assert response.get_json_response() == {
        "choices": [
            {
                "text": " probsuib.FirstName>- commodityting segunda inserted signals Religious",  # noqa: E501
                "token_logprobs": [
                    -0.2649905035732101,
                    -1.210794839387105,
                    -1.2173929801003434,
                    -0.7758233850171001,
                    -0.7165940659570416,
                    -1.7430328887209088,
                    -1.5379414228820203,
                    -1.7838011423472508,
                    -1.139095076944217,
                    -0.6321855879833425,
                ],
                "tokens": [
                    "70470",
                    "80723",
                    "52693",
                    "39743",
                    "38983",
                    "1303",
                    "56072",
                    "22306",
                    "17738",
                    "53176",
                ],
            }
        ]
        * 5
    }
    assert response.get_usage_obj().dict() == {
        "usages": [{"prompt_tokens": 1, "completion_tokens": 10, "total_tokens": 11}]
        * 5,
    }

    request_params = client.get_request(["hello"] * 5, {"n": 1})
    response = client.run_request(request_params)
    assert client.get_cache_key(request_params) == {
        "prompt": ["hello"] * 5,
        "model": "text-davinci-003",
        "n": 1,
        "temperature": 0.0,
        "max_tokens": 10,
        "top_p": 1.0,
        "best_of": 1,
        "engine": "dummy",
        "request_cls": "LMRequest",
    }
    assert response.get_json_response() == {
        "choices": [
            {
                "text": " probsuib.FirstName>- commodityting segunda inserted signals Religious",  # noqa: E501
                "token_logprobs": [
                    -0.2649905035732101,
                    -1.210794839387105,
                    -1.2173929801003434,
                    -0.7758233850171001,
                    -0.7165940659570416,
                    -1.7430328887209088,
                    -1.5379414228820203,
                    -1.7838011423472508,
                    -1.139095076944217,
                    -0.6321855879833425,
                ],
                "tokens": [
                    "70470",
                    "80723",
                    "52693",
                    "39743",
                    "38983",
                    "1303",
                    "56072",
                    "22306",
                    "17738",
                    "53176",
                ],
            }
        ]
        * 5
    }
    assert response.get_usage_obj().dict() == {
        "usages": [{"prompt_tokens": 1, "completion_tokens": 10, "total_tokens": 11}]
        * 5,
    }