from __future__ import annotations

import json
import contextlib
from uuid import uuid4, UUID
from typing import Generator, Literal

import requests
import gradio as gr

from spitfight.colosseum.common import (
    COLOSSEUM_MODELS_ROUTE,
    COLOSSEUM_PROMPT_ROUTE,
    COLOSSEUM_RESP_VOTE_ROUTE,
    COLOSSEUM_ENERGY_VOTE_ROUTE,
    ModelsResponse,
    PromptRequest,
    ResponseVoteRequest,
    ResponseVoteResponse,
    EnergyVoteRequest,
    EnergyVoteResponse,
)

class ControllerClient:
    """Client for the Colosseum controller, to be used by Gradio."""

    def __init__(self, controller_addr: str, timeout: int = 15, request_id: UUID | None = None) -> None:
        """Initialize the controller client."""
        self.controller_addr = controller_addr
        self.timeout = timeout
        self.request_id = str(uuid4()) if request_id is None else str(request_id)

    def fork(self) -> ControllerClient:
        """Return a copy of the client with a new request ID."""
        return ControllerClient(
            controller_addr=self.controller_addr,
            timeout=self.timeout,
            request_id=uuid4(),
        )

    def get_available_models(self) -> list[str]:
        """Retrieve the list of available models."""
        with _catch_requests_exceptions():
            resp = requests.get(
                f"http://{self.controller_addr}{COLOSSEUM_MODELS_ROUTE}",
                timeout=self.timeout,
            )
        _check_response(resp)
        return ModelsResponse(**resp.json()).available_models
    def prompt(
        self,
        prompt: str,
        index: Literal[0, 1],
        model_preference: str,
    ) -> Generator[str, None, None]:
        """Generate the response of the `index`th model with the prompt.

        `model_preference` is the user's preference for the model to use. It can be
        `"Random"` or one of the models in the list returned by `get_available_models`.
        """
        prompt_request = PromptRequest(
            request_id=self.request_id,
            prompt=prompt,
            model_index=index,
            model_preference=model_preference,
        )
        with _catch_requests_exceptions():
            resp = requests.post(
                f"http://{self.controller_addr}{COLOSSEUM_PROMPT_ROUTE}",
                json=prompt_request.dict(),
                stream=True,
                timeout=self.timeout,
            )
        _check_response(resp)
        # XXX: Why can't the server just yield `text + "\n"` and here we just iter_lines?
        for chunk in resp.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                yield json.loads(chunk.decode("utf-8"))
    def response_vote(self, victory_index: Literal[0, 1]) -> ResponseVoteResponse:
        """Notify the controller of the user's vote for the response."""
        response_vote_request = ResponseVoteRequest(request_id=self.request_id, victory_index=victory_index)
        with _catch_requests_exceptions():
            resp = requests.post(
                f"http://{self.controller_addr}{COLOSSEUM_RESP_VOTE_ROUTE}",
                json=response_vote_request.dict(),
            )
        _check_response(resp)
        return ResponseVoteResponse(**resp.json())

    def energy_vote(self, is_worth: bool) -> EnergyVoteResponse:
        """Notify the controller of the user's vote for energy."""
        energy_vote_request = EnergyVoteRequest(request_id=self.request_id, is_worth=is_worth)
        with _catch_requests_exceptions():
            resp = requests.post(
                f"http://{self.controller_addr}{COLOSSEUM_ENERGY_VOTE_ROUTE}",
                json=energy_vote_request.dict(),
            )
        _check_response(resp)
        return EnergyVoteResponse(**resp.json())

@contextlib.contextmanager
def _catch_requests_exceptions():
    """Catch requests exceptions and raise gr.Error instead."""
    try:
        yield
    except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
        raise gr.Error("Failed to connect to the backend server. Please try again later.")

def _check_response(response: requests.Response) -> None:
    if 400 <= response.status_code < 500:
        raise gr.Error(response.json()["detail"])
    elif response.status_code >= 500:
        raise gr.Error("Failed to talk to our backend server. Please try again later.")