| """ | |
| Module: custom_agent | |
| This module provides a custom class, CustomHfAgent, for interacting with the Hugging Face model API. | |
| Dependencies: | |
| - time: Standard Python time module for time-related operations. | |
| - requests: HTTP library for making requests. | |
| - transformers: Hugging Face's transformers library for NLP tasks. | |
| - utils.logger: Custom logger module for logging responses. | |
| Classes: | |
| - CustomHfAgent: A custom class for interacting with the Hugging Face model API. | |
| """ | |
import time

import requests
from transformers import Agent

from utils.logger import log_response
class CustomHfAgent(Agent):
    """A custom class for interacting with the Hugging Face model API."""

    def __init__(self, url_endpoint, token, chat_prompt_template=None,
                 run_prompt_template=None, additional_tools=None, input_params=None):
        """
        Initialize the CustomHfAgent.

        Args:
            url_endpoint (str): The URL endpoint for the Hugging Face model API.
            token (str): The authentication token required to access the API.
            chat_prompt_template (str, optional): Template for chat prompts.
            run_prompt_template (str, optional): Template for run prompts.
            additional_tools (list, optional): Additional tools for the agent.
            input_params (dict, optional): Extra generation parameters
                (e.g. ``max_new_tokens``). Defaults to an empty dict.
        """
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )
        self.url_endpoint = url_endpoint
        self.token = token
        # Bug fix: normalize to a dict so generate_one can call .get()
        # safely; the original stored the raw None default and crashed
        # with AttributeError on the first generate_one() call.
        self.input_params = input_params if input_params is not None else {}

    def generate_one(self, prompt, stop):
        """
        Generate one response from the Hugging Face model.

        Args:
            prompt (str): The prompt to generate a response for.
            stop (list): Strings indicating where to stop generating text.

        Returns:
            str: The generated text, with a single trailing stop sequence
            stripped if present.

        Raises:
            requests.Timeout: If the request exceeds the 300s timeout.
            requests.ConnectionError: If the endpoint is unreachable.
            ValueError: If the API responds with a non-retryable error status.
        """
        headers = {"Authorization": self.token}
        max_new_tokens = self.input_params.get("max_new_tokens", 192)
        parameters = {
            "max_new_tokens": max_new_tokens,
            "return_full_text": False,
            "stop": stop,
            "padding": True,
            "truncation": True,
        }
        inputs = {
            "inputs": prompt,
            "parameters": parameters,
        }
        # Route the payload through the module logger (was a debug print()).
        log_response(inputs)
        try:
            response = requests.post(self.url_endpoint, json=inputs, headers=headers, timeout=300)
        except (requests.Timeout, requests.ConnectionError) as err:
            # Bug fix: the original swallowed these with `pass` and then hit
            # a NameError on the unbound `response`, masking the real error.
            log_response(f"Request to {self.url_endpoint} failed: {err}")
            raise

        if response.status_code == 429:
            log_response("Getting rate-limited, waiting a tiny bit before trying again.")
            time.sleep(1)
            # Retry after a short backoff; recursion mirrors the original flow.
            return self.generate_one(prompt, stop)
        if response.status_code != 200:
            raise ValueError(f"Errors {inputs} {response.status_code}: {response.json()}")

        log_response(response)
        result = response.json()[0]["generated_text"]
        # Trim one trailing stop sequence, if the model emitted it verbatim.
        for stop_seq in stop:
            if result.endswith(stop_seq):
                return result[: -len(stop_seq)]
        return result