Spaces:
Running
Running
Jae-Won Chung
commited on
Commit
·
e38f79f
1
Parent(s):
6393815
Do not standardize system prompts
Browse files
spitfight/colosseum/controller/controller.py
CHANGED
|
@@ -44,9 +44,10 @@ class RequestState(BaseModel):
|
|
| 44 |
This model is also serialized as is and logged.
|
| 45 |
"""
|
| 46 |
request_id: str
|
| 47 |
-
prompt: str
|
| 48 |
model_names: list[str]
|
| 49 |
-
|
|
|
|
|
|
|
| 50 |
energy_consumptions: list[float] = [0.0, 0.0]
|
| 51 |
response_victory_index: Optional[Literal[0, 1]] = None
|
| 52 |
extra_energy_was_worth: Optional[bool] = None
|
|
@@ -172,7 +173,7 @@ class Controller:
|
|
| 172 |
model_names = [worker.model_name for worker in workers]
|
| 173 |
self.request_states[request_id] = RequestState(
|
| 174 |
request_id=request_id,
|
| 175 |
-
|
| 176 |
model_names=model_names,
|
| 177 |
)
|
| 178 |
request_state = self.request_states[request_id]
|
|
@@ -185,11 +186,13 @@ class Controller:
|
|
| 185 |
except RuntimeError:
|
| 186 |
controller_logger.error("Worker %s is dead.", model_name)
|
| 187 |
raise
|
|
|
|
|
|
|
| 188 |
prompt, stop_str, stop_token_ids = apply_model_characteristics(
|
| 189 |
-
system_prompt=get_system_prompt("chat"),
|
| 190 |
prompt=prompt,
|
| 191 |
model_name=worker.model_id,
|
| 192 |
)
|
|
|
|
| 193 |
|
| 194 |
# Request the model worker to stream the response to the user's prompt.
|
| 195 |
response = ""
|
|
|
|
| 44 |
This model is also serialized as is and logged.
|
| 45 |
"""
|
| 46 |
request_id: str
|
|
|
|
| 47 |
model_names: list[str]
|
| 48 |
+
raw_prompt: str
|
| 49 |
+
responses: list[str] = ["UNSET", "UNSET"]
|
| 50 |
+
model_prompts: list[str] = ["UNSET", "UNSET"]
|
| 51 |
energy_consumptions: list[float] = [0.0, 0.0]
|
| 52 |
response_victory_index: Optional[Literal[0, 1]] = None
|
| 53 |
extra_energy_was_worth: Optional[bool] = None
|
|
|
|
| 173 |
model_names = [worker.model_name for worker in workers]
|
| 174 |
self.request_states[request_id] = RequestState(
|
| 175 |
request_id=request_id,
|
| 176 |
+
raw_prompt=prompt,
|
| 177 |
model_names=model_names,
|
| 178 |
)
|
| 179 |
request_state = self.request_states[request_id]
|
|
|
|
| 186 |
except RuntimeError:
|
| 187 |
controller_logger.error("Worker %s is dead.", model_name)
|
| 188 |
raise
|
| 189 |
+
|
| 190 |
+
# Models have different prompt formatting requirements and stopping criteria.
|
| 191 |
prompt, stop_str, stop_token_ids = apply_model_characteristics(
|
|
|
|
| 192 |
prompt=prompt,
|
| 193 |
model_name=worker.model_id,
|
| 194 |
)
|
| 195 |
+
request_state.model_prompts[model_index] = prompt
|
| 196 |
|
| 197 |
# Request the model worker to stream the response to the user's prompt.
|
| 198 |
response = ""
|
spitfight/prompt.py
CHANGED
|
@@ -45,14 +45,15 @@ def get_system_prompt(task: Task | str) -> str:
|
|
| 45 |
|
| 46 |
|
| 47 |
def apply_model_characteristics(
|
| 48 |
-
system_prompt: str,
|
| 49 |
prompt: str,
|
| 50 |
model_name: str,
|
|
|
|
| 51 |
) -> tuple[str, str | None, list[int]]:
|
| 52 |
"""Apply and return model-specific differences."""
|
| 53 |
conv = get_conversation_template(model_name)
|
| 54 |
|
| 55 |
-
|
|
|
|
| 56 |
conv.messages = []
|
| 57 |
conv.offset = 0
|
| 58 |
|
|
|
|
| 45 |
|
| 46 |
|
| 47 |
def apply_model_characteristics(
|
|
|
|
| 48 |
prompt: str,
|
| 49 |
model_name: str,
|
| 50 |
+
system_prompt: str | None = None,
|
| 51 |
) -> tuple[str, str | None, list[int]]:
|
| 52 |
"""Apply and return model-specific differences."""
|
| 53 |
conv = get_conversation_template(model_name)
|
| 54 |
|
| 55 |
+
if system_prompt is not None:
|
| 56 |
+
conv.system_message = system_prompt
|
| 57 |
conv.messages = []
|
| 58 |
conv.offset = 0
|
| 59 |
|