|
---
license: gpl-3.0
language:
- en
base_model:
- openai-community/gpt2
datasets:
- gofilipa/heritage_gender
pipeline_tag: text-generation
---
|
|
|
A GPT-2 model fine-tuned on commentaries from The Heritage Foundation on the topic of "[gender](https://www.heritage.org/gender?f%5B0%5D=content_type%3Acommentary)."
|
|
|
Training setup: |
|
|
|
```python
import torch
from transformers import (
    pipeline,
    AutoModelForCausalLM,
    AutoTokenizer,
)
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig

# Helper to monitor MPS memory usage
def print_mps_memory():
    if torch.backends.mps.is_available():
        print(f"MPS allocated: {torch.mps.current_allocated_memory() / 1024**3:.2f} GB")
        print(f"MPS cached: {torch.mps.driver_allocated_memory() / 1024**3:.2f} GB")

# Call this periodically during training
print_mps_memory()

# Check if MPS is available
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("MPS device found.")
else:
    device = torch.device("cpu")
    print("MPS device not found, using CPU.")

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
tokenizer.padding_side = "right"

model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
model = model.to(device)  # Move model to MPS (or CPU)

ds = load_dataset("gofilipa/heritage_foundation-gender")

# Limit dataset size for testing: use only the first 2000 samples
train_dataset = ds["train"].select(range(min(2000, len(ds["train"]))))

# Clear MPS cache before training
if torch.backends.mps.is_available():
    torch.mps.empty_cache()

# Training parameters reduced for lower memory usage
training_params = SFTConfig(
    output_dir="../checkpoints",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
    num_train_epochs=3,           # slowly increased from 1 to 3 as memory allowed
    learning_rate=2e-4,
    weight_decay=0.001,
    dataset_text_field="text",
    report_to="none",
    bf16=False,
    fp16=False,
    dataloader_pin_memory=False,
    remove_unused_columns=False,
    max_seq_length=512,           # limit sequence length to save memory
    gradient_checkpointing=True,  # trade compute for memory
)

trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    processing_class=tokenizer,
    args=training_params,
)

trainer.train()
```
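
After training, the fine-tuned weights can be saved and then loaded back for generation through the `pipeline` helper imported above. A minimal sketch, continuing from the script; the `../checkpoints/final` path and the prompt are placeholders, not part of the original setup:

```python
# Save the fine-tuned model and tokenizer (placeholder path)
trainer.save_model("../checkpoints/final")
tokenizer.save_pretrained("../checkpoints/final")

# Generate text with the fine-tuned checkpoint
generator = pipeline(
    "text-generation",
    model="../checkpoints/final",
    device="mps" if torch.backends.mps.is_available() else "cpu",
)

output = generator("Gender is", max_new_tokens=50, do_sample=True)
print(output[0]["generated_text"])
```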