Spaces:
Running
Running
| """An abstraction layer for prompting different models.""" | |
| from __future__ import annotations | |
| import enum | |
| from fastchat.model.model_adapter import get_conversation_template | |
class Task(enum.Enum):
    """Supported system-prompt styles."""

    # Free-form conversational prompting.
    CHAT = "chat"
    # Conversational prompting, answers kept short.
    CHAT_CONCISE = "chat-concise"
    # Single-instruction completion.
    INSTRUCT = "instruct"
    # Single-instruction completion, response kept short.
    INSTRUCT_CONCISE = "instruct-concise"
# Shared preamble for the chat-style prompts; the concise variant appends
# one extra sentence to it.
_CHAT_PREAMBLE = (
    "A chat between a human user (prompter) and an artificial intelligence (AI) assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions. "
)

# Shared preamble for the instruction-style prompts.
_INSTRUCT_PREAMBLE = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request. "
)

# System prompt text for each Task style.
SYSTEM_PROMPTS = {
    Task.CHAT: _CHAT_PREAMBLE,
    Task.CHAT_CONCISE: _CHAT_PREAMBLE + "The assistant's answers are very concise. ",
    Task.INSTRUCT: _INSTRUCT_PREAMBLE,
    Task.INSTRUCT_CONCISE: _INSTRUCT_PREAMBLE + "The response should be very concise. ",
}
def get_system_prompt(task: Task | str) -> str:
    """Return the system prompt text for *task*.

    Accepts either a ``Task`` member or its string value; an unknown
    string raises ``ValueError`` via the ``Task`` constructor.
    """
    resolved = Task(task) if isinstance(task, str) else task
    return SYSTEM_PROMPTS[resolved]
def apply_model_characteristics(
    prompt: str,
    model_name: str,
    system_prompt: str | None = None,
) -> tuple[str, str | None, list[int]]:
    """Render *prompt* through the conversation template for *model_name*.

    Returns a tuple of (formatted prompt string, stop string or ``None``,
    stop token ids — empty list when the template defines none).
    """
    template = get_conversation_template(model_name)
    if system_prompt is not None:
        template.system_message = system_prompt
    # Reset any history the template ships with so only our turn remains.
    template.messages = []
    template.offset = 0
    template.append_message(template.roles[0], prompt)
    # An empty assistant turn marks where generation should continue.
    template.append_message(template.roles[1], "")
    # Normalize a falsy stop_str (None or empty) to None.
    stop = template.stop_str or None
    return template.get_prompt(), stop, (template.stop_token_ids or [])