# My Gemma-like Model from Scratch
This model is a custom implementation of a Gemma-like architecture, trained from scratch.
## Training Details
- Architecture: An 18-layer decoder-only transformer with Grouped-Query Attention (see the sketch after this list).
- Data: Trained on the Wikitext-2 dataset.
- Training Script: The training script is available on GitHub at https://github.com/your_github_repo.
- Parameters: 330.64 million trainable parameters.
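The Grouped-Query Attention setup pairs all 8 query heads with a single shared key/value head, matching the `num_attention_heads` and `num_key_value_heads` values in the loading example further down. The snippet below is a minimal, illustrative sketch of that head-sharing idea; it is not code from the training script, and the tensor shapes are chosen only to match the card's configuration (hidden size 1024 over 8 heads gives a head dimension of 128).

```python
import torch

# Illustrative Grouped-Query Attention head sharing (not from the training script).
# With 8 query heads and 1 key/value head, the single K/V head is repeated so
# every query head can attend to it.
batch, seq_len, head_dim = 2, 16, 128
num_q_heads, num_kv_heads = 8, 1

q = torch.randn(batch, num_q_heads, seq_len, head_dim)
k = torch.randn(batch, num_kv_heads, seq_len, head_dim)
v = torch.randn(batch, num_kv_heads, seq_len, head_dim)

# Broadcast the shared key/value head across all query heads.
k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1)
v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1)

scores = q @ k.transpose(-2, -1) / head_dim**0.5   # (batch, heads, seq, seq)
attn = torch.softmax(scores, dim=-1) @ v            # (batch, heads, seq, head_dim)
print(attn.shape)  # torch.Size([2, 8, 16, 128])
```

Sharing the key/value projections across query heads shrinks the key/value cache at inference time, which is the main motivation for Grouped-Query Attention.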
## Checkpointing
The training script includes a checkpointing mechanism. It automatically saves the model's progress every 50 steps and at the end of each epoch to a file named `checkpoint.pt`. You can resume training by simply re-running the script. The final model is saved as `pytorch_model.bin`.
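A save-and-resume loop of this kind can be sketched roughly as follows. This is an illustration only, not the actual training script: the tiny model, optimizer, and synthetic data are placeholders.

```python
import os
import torch
from torch import nn

# Sketch of a save-and-resume checkpointing loop (placeholder model and data,
# not the real training script).
CHECKPOINT_PATH = "checkpoint.pt"
model = nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
num_epochs, global_step, start_epoch = 3, 0, 0

if os.path.exists(CHECKPOINT_PATH):
    # Resume: restore model/optimizer state and the progress counters.
    ckpt = torch.load(CHECKPOINT_PATH, map_location="cpu")
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    start_epoch, global_step = ckpt["epoch"], ckpt["step"]

def save_checkpoint(epoch: int, step: int) -> None:
    torch.save({"model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "epoch": epoch, "step": step}, CHECKPOINT_PATH)

for epoch in range(start_epoch, num_epochs):
    for _ in range(100):  # stand-in for iterating over the real dataloader
        loss = model(torch.randn(4, 8)).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        global_step += 1
        if global_step % 50 == 0:              # save every 50 steps
            save_checkpoint(epoch, global_step)
    save_checkpoint(epoch + 1, global_step)    # save at the end of each epoch
```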
## Early Stopping
To prevent overfitting, the training process includes early stopping based on the validation loss. The script will monitor the loss on a dedicated validation set and stop training if it does not improve for 2 consecutive epochs.
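Roughly, the patience logic works like the sketch below. The two helper functions are stand-ins for the real training and validation steps, which are defined in the training script.

```python
import random

# Sketch of patience-based early stopping on validation loss.
# The helpers below are placeholders, not the real train/validation code.
def train_one_epoch() -> None:
    pass  # placeholder: run one epoch of training

def evaluate_validation() -> float:
    return random.random()  # placeholder: compute the validation loss

best_val_loss = float("inf")
epochs_without_improvement = 0
patience = 2   # stop after 2 consecutive epochs with no improvement
num_epochs = 20

for epoch in range(num_epochs):
    train_one_epoch()
    val_loss = evaluate_validation()
    if val_loss < best_val_loss:
        best_val_loss, epochs_without_improvement = val_loss, 0
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            print(f"Early stopping at epoch {epoch}: no improvement for {patience} epochs")
            break
```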
## Loading and Chatting with the Model
Because this model uses a custom architecture, you need the model class definitions from the training script in order to load it.
Here's a step-by-step guide to get started:
**Install Required Libraries:**

```bash
pip install torch huggingface-hub tokenizers
```
**Copy the Model Architecture:** Copy the `GemmaForCausalLM` class and all its required sub-classes (`RMSNorm`, `RotaryPositionalEmbedding`, `MultiHeadAttention`, `MLP`, `TransformerBlock`) from this training script into your new Python file.

**Load the Model and Tokenizer:**
```python
import torch
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer

# Define your model's hyperparameters
config = {
    "vocab_size": 30000,
    "hidden_size": 1024,
    "num_attention_heads": 8,
    "num_key_value_heads": 1,
    "num_layers": 18,
    "intermediate_size": 4096,
    "max_position_embeddings": 32768,
    "attention_dropout": 0.0,
    "hidden_dropout": 0.0,
    "sliding_window": 512,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
}

# Instantiate the custom model and load the weights
model = GemmaForCausalLM(config)
model_path = hf_hub_download(repo_id="your_username/gemma-from-scratch", filename="pytorch_model.bin")
model.load_state_dict(torch.load(model_path, map_location=config["device"]))
model.to(config["device"]).eval()

# Load the tokenizer
tokenizer_path = hf_hub_download(repo_id="your_username/gemma-from-scratch", filename="tokenizer.json")
tokenizer = Tokenizer.from_file(tokenizer_path)
```
**Generate Text:**
```python
def generate_text(model, tokenizer, prompt, max_length=50):
    input_ids = tokenizer.encode(prompt).ids
    input_tensor = torch.tensor(input_ids).unsqueeze(0).to(config["device"])
    with torch.no_grad():
        for _ in range(max_length):
            logits, _ = model(input_tensor)
            next_token_logits = logits[:, -1, :]
            next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)
            input_tensor = torch.cat([input_tensor, next_token], dim=-1)
            # Stop if we generate the end-of-sentence token
            if next_token.item() == tokenizer.token_to_id("</s>"):
                break
    return tokenizer.decode(input_tensor[0].tolist(), skip_special_tokens=True)

# Example usage
prompt = "The early bird catches the worm, but the second mouse gets the "
generated_text = generate_text(model, tokenizer, prompt)
print("Generated Text:")
print(generated_text)
```
Note: This model is for demonstration purposes. Its custom architecture is not directly compatible with the Hugging Face `transformers` library out of the box. To use the model, you must also include the full model class definitions in your script.