import os
import logging
from huggingface_hub import InferenceClient
import gradio as gr
from requests.exceptions import ConnectionError
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize the Hugging Face Inference Client
try:
    client = InferenceClient(
        model="mistralai/Mistral-7B-Instruct-v0.3",
        token=os.getenv("HF_TOKEN"),  # Ensure HF_TOKEN is set in your environment
        timeout=30,
    )
except Exception as e:
    logger.error(f"Failed to initialize InferenceClient: {e}")
    raise
def format_prompt(message, history):
    """Build a Mistral-Instruct prompt from the chat history and the new message."""
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt
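
# For example, with history = [("Hi", "Hello!")] and message = "How are you?",
# format_prompt returns:
#   "<s>[INST] Hi [/INST] Hello!</s> [INST] How are you? [/INST]"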
def generate(
    prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
):
    try:
        # Clamp temperature away from zero; the API rejects non-positive values.
        temperature = float(temperature)
        if temperature < 1e-2:
            temperature = 1e-2
        top_p = float(top_p)

        generate_kwargs = dict(
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            seed=42,
        )

        formatted_prompt = format_prompt(prompt, history)
        logger.info("Sending request to Hugging Face API")
        stream = client.text_generation(
            formatted_prompt,
            **generate_kwargs,
            stream=True,
            details=True,
            return_full_text=False,
        )
        # Accumulate and yield partial output so the UI updates token by token.
        output = ""
        for response in stream:
            output += response.token.text
            yield output
    except ConnectionError as e:
        logger.error(f"Network error: {e}")
        yield "Error: Unable to connect to the Hugging Face API. Please check your internet connection and try again."
    except Exception as e:
        logger.error(f"Error during text generation: {e}")
        yield f"Error: {str(e)}"
# Define additional inputs for Gradio interface
additional_inputs = [
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=512,
        minimum=0,
        maximum=1024,
        step=64,
        interactive=True,
        info="The maximum number of new tokens to generate",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]
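# Note: the slider order must match the extra parameters of `generate`
# (temperature, max_new_tokens, top_p, repetition_penalty), since
# gr.ChatInterface passes additional_inputs positionally after (message, history).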
# Create a Chatbot object
chatbot = gr.Chatbot(height=450, layout="bubble")
# Build the Gradio interface
with gr.Blocks() as demo:
    gr.HTML("<h1><center>🤖 Mistral-7B-Chat 💬</center></h1>")
    gr.ChatInterface(
        fn=generate,
        chatbot=chatbot,
        additional_inputs=additional_inputs,
        examples=[
            ["Give me the code for Binary Search in C++"],
            ["Explain the chapter 'The Grand Inquisitor' from The Brothers Karamazov."],
            ["Explain Newton's second law."],
        ],
    )
if __name__ == "__main__":
    logger.info("Starting Gradio application")
    demo.launch()
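
# To run locally (assuming a valid Hugging Face token with Inference API access):
#   HF_TOKEN=<your token> python app.py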