# Import required libraries
import warnings

warnings.filterwarnings("ignore")

import os
import sys
from typing import Generator

import gradio as gr
from huggingface_hub import hf_hub_download
from llama_cpp import Llama

from logger import logging  # Local module: logger.py must expose a configured `logging` object
from exception import CustomExceptionHandling  # Local module: exception.py
import spaces  # Hugging Face ZeroGPU helper (provides the spaces.GPU decorator)
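# A minimal sketch of the dependencies this script assumes are installed
# (e.g. via the Space's requirements.txt); version pins are intentionally omitted:
#   llama-cpp-python, gradio, huggingface_hub, spaces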
# Download the GGUF model files (only the models listed below)
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")  # Must be set for gated/private repos


def download_model(repo_id: str, filename: str) -> None:
    """Download a GGUF file from the Hugging Face Hub into ./models."""
    try:
        hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            local_dir="./models",
            token=huggingface_token,
        )
        logging.info(f"Successfully downloaded {filename} from {repo_id}")
    except Exception as e:
        logging.error(f"Error downloading {filename} from {repo_id}: {e}")
        raise  # Re-raise to halt execution if the download fails


# Only download files that don't already exist. This is crucial.
if not os.path.exists("./models/google.gemma-3-1b-pt.Q4_K_M.gguf"):
    download_model("DevQuasar/google.gemma-3-1b-pt-GGUF", "google.gemma-3-1b-pt.Q4_K_M.gguf")
if not os.path.exists("./models/google.gemma-3-4b-pt.Q4_K_M.gguf"):
    download_model("DevQuasar/google.gemma-3-4b-pt-GGUF", "google.gemma-3-4b-pt.Q4_K_M.gguf")
if not os.path.exists("./models/google.gemma-3-12b-pt.Q4_K_M.gguf"):
    download_model("DevQuasar/google.gemma-3-12b-pt-GGUF", "google.gemma-3-12b-pt.Q4_K_M.gguf")
if not os.path.exists("./models/google.gemma-3-27b-pt.Q4_K_M.gguf"):
    download_model("DevQuasar/google.gemma-3-27b-pt-GGUF", "google.gemma-3-27b-pt.Q4_K_M.gguf")
# Title and description for the UI
title = "Gemma Text Generation"
description = """Gemma models for text generation and notebook continuation. This interface is designed for generating text continuations, not for interactive chat."""

# Globals holding the currently loaded model
llm = None
llm_model = None
@spaces.GPU  # Request a GPU on ZeroGPU Spaces; behaves as a pass-through outside that environment
def generate_text(
    prompt: str,
    model: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repeat_penalty: float,
) -> Generator[str, None, None]:
    """
    Generate text that continues the given prompt, using the specified Gemma model.

    Args:
        prompt (str): The initial text to continue.
        model (str): The model file to use (filename only, without path).
        max_tokens (int): Maximum number of tokens to generate.
        temperature (float): Controls randomness.
        top_p (float): Nucleus sampling parameter.
        top_k (int): Top-k sampling parameter.
        repeat_penalty (float): Penalty for repeating tokens.

    Yields:
        str: Generated text chunks, streamed as they become available.
    """
    try:
        global llm
        global llm_model

        model_path = os.path.join("models", model)
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")

        # Load the model only when a different model is selected
        if llm is None or llm_model != model:
            logging.info(f"Loading model: {model}")
            llm = Llama(
                model_path=model_path,
                flash_attn=True,
                n_gpu_layers=999,  # Offload all layers to the GPU; reduce for limited VRAM
                n_ctx=4096,  # Context window size; can be increased
                verbose=False,  # Reduce unnecessary verbosity
            )
            llm_model = model

        # llama_cpp streams tokens natively when stream=True
        for token in llm(
            prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repeat_penalty=repeat_penalty,
            stream=True,
            stop=["<|im_end|>", "<|endoftext|>", "<|file_separator|>"],  # Generic stop strings; adjust for the model in use
        ):
            text_chunk = token["choices"][0]["text"]
            yield text_chunk

    except Exception as e:
        raise CustomExceptionHandling(e, sys) from e
def clear_history():
    """Clear the input textbox."""
    return ""
# Build the Gradio interface
with gr.Blocks(theme="Ocean", title=title) as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(description)

    with gr.Row():
        with gr.Column(scale=4):
            input_textbox = gr.Textbox(
                label="Input Prompt",
                placeholder="Enter text to continue...",
                lines=10,
            )
            clear_button = gr.Button("Clear Input")
            output_textbox = gr.Textbox(  # Textbox so streamed chunks can be displayed
                label="Generated Text",
                lines=10,  # Room to display longer outputs
                interactive=False,  # Output shouldn't be editable
            )
        with gr.Column(scale=1):
            submit_button = gr.Button("Generate", variant="primary")
            model_dropdown = gr.Dropdown(
                choices=[
                    "google.gemma-3-1b-pt.Q4_K_M.gguf",
                    "google.gemma-3-4b-pt.Q4_K_M.gguf",
                    "google.gemma-3-12b-pt.Q4_K_M.gguf",
                    "google.gemma-3-27b-pt.Q4_K_M.gguf",
                    # Add other models here once they are downloaded
                ],
                value="google.gemma-3-1b-pt.Q4_K_M.gguf",  # Default model
                label="Model",
                info="Select the AI model",
            )
            max_tokens_slider = gr.Slider(
                minimum=32,
                maximum=8192,
                value=512,
                step=1,
                label="Max Tokens",
                info="Maximum length of generated text",
            )
            temperature_slider = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=1.0,
                step=0.05,  # Finer control
                label="Temperature",
                info="Controls randomness (higher = more creative)",
            )
            top_p_slider = gr.Slider(
                minimum=0.05,  # Allow lower top_p
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-p",
                info="Nucleus sampling threshold",
            )
            top_k_slider = gr.Slider(
                minimum=1,
                maximum=100,
                value=40,
                step=1,
                label="Top-k",
                info="Limit vocabulary choices to the top K tokens",
            )
            repeat_penalty_slider = gr.Slider(
                minimum=1.0,
                maximum=2.0,
                value=1.1,
                step=0.05,  # Finer control
                label="Repetition Penalty",
                info="Penalize repeated tokens (higher = less repetition)",
            )
    def streaming_output(prompt, model, max_tokens, temperature, top_p, top_k, repeat_penalty):
        """Accumulate streamed chunks so Gradio can update the output textbox incrementally."""
        generated_text = ""
        for text_chunk in generate_text(prompt, model, max_tokens, temperature, top_p, top_k, repeat_penalty):
            generated_text += text_chunk
            yield generated_text

    submit_button.click(
        streaming_output,
        inputs=[
            input_textbox,
            model_dropdown,
            max_tokens_slider,
            temperature_slider,
            top_p_slider,
            top_k_slider,
            repeat_penalty_slider,
        ],
        outputs=output_textbox,
    )
    clear_button.click(clear_history, inputs=[], outputs=input_textbox)
if __name__ == "__main__":
    demo.launch(debug=False, share=False)  # share=False keeps the app local when run outside Spaces
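# A minimal sketch (not part of the app) of exercising generate_text directly from a
# Python shell, assuming this file is saved as app.py and the 1B model above has
# already been downloaded to ./models; values shown are illustrative only:
#
#   from app import generate_text
#   for chunk in generate_text(
#       prompt="The quick brown fox",
#       model="google.gemma-3-1b-pt.Q4_K_M.gguf",
#       max_tokens=64,
#       temperature=1.0,
#       top_p=0.95,
#       top_k=40,
#       repeat_penalty=1.1,
#   ):
#       print(chunk, end="", flush=True)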