My Gemma-like Model from Scratch

This model is a custom implementation of a Gemma-like architecture, trained from scratch.

Training Details

  • Architecture: An 18-layer decoder-only transformer with Grouped-Query Attention.
  • Data: Trained on the Wikitext-2 dataset.
  • Training Script: Available on GitHub at https://github.com/your_github_repo.
  • Parameters: 330.64 million trainable parameters (reproducible with the snippet below).
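
The parameter count can be verified once the model is instantiated. A minimal sketch, assuming the GemmaForCausalLM class and the config dict from the loading guide further down this card:

    # Count trainable parameters; GemmaForCausalLM and config are defined
    # later in this card (steps 2 and 3 of the loading guide).
    model = GemmaForCausalLM(config)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total trainable parameters: {n_params / 1e6:.2f} million")  # expected: 330.64 million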

Checkpointing

The training script includes a checkpointing mechanism: it saves the model's progress every 50 steps and at the end of each epoch to a file named checkpoint.pt, so you can resume training by re-running the script. The final model is saved as pytorch_model.bin.
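
A minimal sketch of such a mechanism (the function names and structure here are illustrative, not the script's own):

    import os
    import torch

    CHECKPOINT_PATH = "checkpoint.pt"

    def save_checkpoint(model, optimizer, epoch, step):
        # Persist everything needed to resume: weights, optimizer state, and position
        torch.save({
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "epoch": epoch,
            "step": step,
        }, CHECKPOINT_PATH)

    def maybe_resume(model, optimizer):
        # Called at startup: re-running the script resumes from the last checkpoint
        if os.path.exists(CHECKPOINT_PATH):
            ckpt = torch.load(CHECKPOINT_PATH, map_location="cpu")
            model.load_state_dict(ckpt["model_state_dict"])
            optimizer.load_state_dict(ckpt["optimizer_state_dict"])
            return ckpt["epoch"], ckpt["step"]
        return 0, 0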

Early Stopping

To prevent overfitting, the training process includes early stopping based on validation loss: the script monitors the loss on a dedicated validation set and stops training if the loss does not improve for 2 consecutive epochs.
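
The logic amounts to tracking the best validation loss with a patience counter. A minimal sketch, where num_epochs, train_one_epoch, and evaluate stand in for the script's own loop and evaluation code:

    best_val_loss = float("inf")
    patience = 2  # stop after 2 consecutive epochs without improvement
    epochs_without_improvement = 0

    for epoch in range(num_epochs):
        train_one_epoch(model)       # placeholder: one pass over the training data
        val_loss = evaluate(model)   # placeholder: loss on the validation set

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                print(f"Early stopping at epoch {epoch}")
                break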

Loading and Chatting with the Model

Since this model uses a custom architecture, loading it requires the model class definitions from the training script.

Here's a step-by-step guide to get started:

  1. Install Required Libraries:

    pip install torch huggingface-hub tokenizers
    
  2. Copy the Model Architecture: Copy the GemmaForCausalLM class and all of its required sub-classes (RMSNorm, RotaryPositionalEmbedding, MultiHeadAttention, MLP, TransformerBlock) from the training script into your new Python file.

  3. Load the Model and Tokenizer:

    import torch
    from huggingface_hub import hf_hub_download
    from tokenizers import Tokenizer
    
    # Define your model's hyperparameters
    config = {
        "vocab_size": 30000,
        "hidden_size": 1024,
        "num_attention_heads": 8,
        "num_key_value_heads": 1,
        "num_layers": 18,
        "intermediate_size": 4096,
        "max_position_embeddings": 32768,
        "attention_dropout": 0.0,
        "hidden_dropout": 0.0,
        "sliding_window": 512,
        "device": "cuda" if torch.cuda.is_available() else "cpu"
    }
    
    # Instantiate the custom model (the class copied in step 2) and load the weights
    model = GemmaForCausalLM(config)
    model_path = hf_hub_download(repo_id="your_username/gemma-from-scratch", filename="pytorch_model.bin")
    model.load_state_dict(torch.load(model_path, map_location=config["device"]))
    model.to(config["device"]).eval()
    
    # Load the tokenizer
    tokenizer_path = hf_hub_download(repo_id="your_username/gemma-from-scratch", filename="tokenizer.json")
    tokenizer = Tokenizer.from_file(tokenizer_path)
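    
    # Optional sanity check (an addition to this card, not part of the original
    # script): run a short prompt through the model to confirm the forward pass
    # works. Assumes a (logits, ...) return tuple, as used in step 4 below.
    ids = tokenizer.encode("Hello world").ids
    with torch.no_grad():
        logits, _ = model(torch.tensor([ids], device=config["device"]))
    print(logits.shape)  # expected: (1, sequence_length, 30000)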
    
  4. Generate Text:

    def generate_text(model, tokenizer, prompt, max_length=50):
        # max_length counts new tokens generated beyond the prompt
        input_ids = tokenizer.encode(prompt).ids
        input_tensor = torch.tensor(input_ids).unsqueeze(0).to(config["device"])
        eos_id = tokenizer.token_to_id("</s>")  # None if the vocab has no </s> token
        
        with torch.no_grad():
            for _ in range(max_length):
                logits, _ = model(input_tensor)
                next_token_logits = logits[:, -1, :]  # logits for the last position
                next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)  # greedy decoding
                input_tensor = torch.cat([input_tensor, next_token], dim=-1)
                
                # Stop if we generate the end-of-sentence token
                if eos_id is not None and next_token.item() == eos_id:
                    break
        
        return tokenizer.decode(input_tensor[0].tolist(), skip_special_tokens=True)
    
    # Example usage
    prompt = "The early bird catches the worm, but the second mouse gets the "
    generated_text = generate_text(model, tokenizer, prompt)
    print("Generated Text:")
    print(generated_text)
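
The loop above decodes greedily (argmax at every step), which can produce repetitive text. A common alternative is temperature sampling; a minimal sketch of a drop-in replacement (the temperature value is an illustrative choice, not something the original script uses):

    def sample_next_token(next_token_logits, temperature=0.8):
        # Soften or sharpen the distribution, then sample instead of taking argmax
        probs = torch.softmax(next_token_logits / temperature, dim=-1)
        return torch.multinomial(probs, num_samples=1)  # shape (1, 1), same as the argmax path

    # Inside generate_text, replace the torch.argmax line with:
    # next_token = sample_next_token(next_token_logits)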
    

Note: This model is for demonstration purposes. Its custom architecture is not directly compatible with the Hugging Face transformers library out of the box; to use the model, you must include the full model class definitions in your script, as described in step 2 above.
