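"""Gradio app for streaming text continuation with quantized Gemma GGUF models (llama-cpp-python)."""
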
# Importing required libraries
import warnings
warnings.filterwarnings("ignore")
import os
import sys
from llama_cpp import Llama
import gradio as gr
from huggingface_hub import hf_hub_download
from typing import Generator
from logger import logging  # Local logging helper (logger.py)
from exception import CustomExceptionHandling  # Local exception wrapper (exception.py)
import spaces
# Download GGUF model files (simplified for the specified models)
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")  # Ensure the token is set in the environment


def download_model(repo_id, filename):
    try:
        hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            local_dir="./models",
            token=huggingface_token,
        )
        logging.info(f"Successfully downloaded {filename} from {repo_id}")
    except Exception as e:
        logging.error(f"Error downloading {filename} from {repo_id}: {e}")
        raise  # Re-raise to halt execution if the download fails

# Only download files that don't already exist; this avoids re-downloading on every restart.
if not os.path.exists("./models/google.gemma-3-1b-pt.Q4_K_M.gguf"):
    download_model("DevQuasar/google.gemma-3-1b-pt-GGUF", "google.gemma-3-1b-pt.Q4_K_M.gguf")
if not os.path.exists("./models/google.gemma-3-4b-pt.Q4_K_M.gguf"):
    download_model("DevQuasar/google.gemma-3-4b-pt-GGUF", "google.gemma-3-4b-pt.Q4_K_M.gguf")
if not os.path.exists("./models/google.gemma-3-12b-pt.Q4_K_M.gguf"):
    download_model("DevQuasar/google.gemma-3-12b-pt-GGUF", "google.gemma-3-12b-pt.Q4_K_M.gguf")
if not os.path.exists("./models/google.gemma-3-27b-pt.Q4_K_M.gguf"):
    download_model("DevQuasar/google.gemma-3-27b-pt-GGUF", "google.gemma-3-27b-pt.Q4_K_M.gguf")

# Set the title and description
title = "Gemma Text Generation"
description = """Gemma models for text generation and notebook continuation. This interface is designed for generating text continuations, not for interactive chat."""
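
# Cache the loaded model across calls so it is only reloaded when the selected file changes.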
llm = None
llm_model = None

@spaces.GPU
def generate_text(
    prompt: str,
    model: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repeat_penalty: float,
) -> Generator[str, None, None]:
    """
    Generates text based on a prompt, using the specified Gemma model.

    Args:
        prompt (str): The initial text to continue.
        model (str): The model file to use (without path).
        max_tokens (int): Maximum number of tokens to generate.
        temperature (float): Controls randomness.
        top_p (float): Nucleus sampling parameter.
        top_k (int): Top-k sampling parameter.
        repeat_penalty (float): Penalty for repeating tokens.

    Yields:
        str: Generated text chunks, streamed as they become available.
    """
    try:
        global llm
        global llm_model

        model_path = os.path.join("models", model)
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")

        # Load the model only if a different model was selected
        if llm is None or llm_model != model:
            logging.info(f"Loading model: {model}")
            llm = Llama(
                model_path=model_path,
                flash_attn=True,
                n_gpu_layers=999,  # Offload all layers; adjust based on your GPU availability
                n_ctx=4096,  # Context window size
                verbose=False,  # Reduce unnecessary verbosity
            )
            llm_model = model

        # llama_cpp streams tokens natively when stream=True
        for token in llm(
            prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repeat_penalty=repeat_penalty,
            stream=True,
            stop=["<|im_end|>", "<|endoftext|>", "<|file_separator|>"],  # Stop tokens
        ):
            text_chunk = token["choices"][0]["text"]
            yield text_chunk

    except Exception as e:
        raise CustomExceptionHandling(e, sys) from e
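
# Illustrative example (not executed): the generator can also be consumed directly, e.g.
#   for chunk in generate_text("Once upon a time", "google.gemma-3-1b-pt.Q4_K_M.gguf", 64, 1.0, 0.95, 40, 1.1):
#       print(chunk, end="")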

def clear_history():
    """Clears the text input."""
    return ""

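
# Build the Gradio interface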
with gr.Blocks(theme="Ocean", title=title) as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(description)

    with gr.Row():
        with gr.Column(scale=4):
            input_textbox = gr.Textbox(
                label="Input Prompt",
                placeholder="Enter text to continue...",
                lines=10,
            )
            clear_button = gr.Button("Clear Input")
            output_textbox = gr.Textbox(  # Textbox so streamed output can be shown incrementally
                label="Generated Text",
                lines=10,  # Room for longer outputs
                interactive=False,  # Output shouldn't be editable
            )
        with gr.Column(scale=1):
            submit_button = gr.Button("Generate", variant="primary")
            model_dropdown = gr.Dropdown(
                choices=[
                    "google.gemma-3-1b-pt.Q4_K_M.gguf",
                    "google.gemma-3-4b-pt.Q4_K_M.gguf",
                    "google.gemma-3-12b-pt.Q4_K_M.gguf",
                    "google.gemma-3-27b-pt.Q4_K_M.gguf",
                    # Add other models as needed (and downloaded above)
                ],
                value="google.gemma-3-1b-pt.Q4_K_M.gguf",  # Default model
                label="Model",
                info="Select the AI model",
            )
            max_tokens_slider = gr.Slider(
                minimum=32,
                maximum=8192,
                value=512,
                step=1,
                label="Max Tokens",
                info="Maximum length of generated text",
            )
            temperature_slider = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=1.0,
                step=0.05,
                label="Temperature",
                info="Controls randomness (higher = more creative)",
            )
            top_p_slider = gr.Slider(
                minimum=0.05,
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-p",
                info="Nucleus sampling threshold",
            )
            top_k_slider = gr.Slider(
                minimum=1,
                maximum=100,
                value=40,
                step=1,
                label="Top-k",
                info="Limit vocabulary choices to the top K tokens",
            )
            repeat_penalty_slider = gr.Slider(
                minimum=1.0,
                maximum=2.0,
                value=1.1,
                step=0.05,
                label="Repetition Penalty",
                info="Penalize repeated words (higher = less repetition)",
            )
    def streaming_output(prompt, model, max_tokens, temperature, top_p, top_k, repeat_penalty):
        """Wraps the generator for Gradio, yielding the accumulated text after each chunk."""
        generated_text = ""
        for text_chunk in generate_text(prompt, model, max_tokens, temperature, top_p, top_k, repeat_penalty):
            generated_text += text_chunk
            yield generated_text

    submit_button.click(
        streaming_output,
        [
            input_textbox,
            model_dropdown,
            max_tokens_slider,
            temperature_slider,
            top_p_slider,
            top_k_slider,
            repeat_penalty_slider,
        ],
        output_textbox,
    )
    clear_button.click(clear_history, [], input_textbox)

if __name__ == "__main__":
    demo.launch(debug=False, share=False)  # Local-only launch; set share=True to create a public link