---
license: gpl-3.0
language:
- en
base_model:
- openai-community/gpt2
datasets:
- gofilipa/heritage_gender
pipeline_tag: text-generation
---

A GPT-2 model fine-tuned on commentaries from The Heritage Foundation on the topic of "[gender](https://www.heritage.org/gender?f%5B0%5D=content_type%3Acommentary)."

Training setup, run locally on Apple Silicon via PyTorch's MPS backend:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig

# Helper to monitor MPS memory usage; call it periodically during training
def print_mps_memory():
    if torch.backends.mps.is_available():
        print(f"MPS allocated: {torch.mps.current_allocated_memory() / 1024**3:.2f} GB")
        print(f"MPS cached: {torch.mps.driver_allocated_memory() / 1024**3:.2f} GB")

print_mps_memory()

# Check if MPS is available
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("MPS device found.")
else:
    device = torch.device("cpu")
    print("MPS device not found, using CPU.")

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
model = model.to(device)  # move the model to the selected device (MPS or CPU)

ds = load_dataset("gofilipa/heritage_gender")  # dataset id as declared in the card metadata

# Limit the dataset for testing: use at most the first 2,000 samples
train_dataset = ds['train'].select(range(min(2000, len(ds['train']))))

# Clear cached MPS memory before training starts
if torch.backends.mps.is_available():
    torch.mps.empty_cache()

# Training parameters tuned for low memory usage on MPS
training_params = SFTConfig(
    output_dir="../checkpoints",
    per_device_train_batch_size=1,   # smallest possible batch to fit in memory
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,   # effective batch size of 2
    num_train_epochs=3,              # increased gradually from 1 to 3 as memory allowed
    learning_rate=2e-4,
    weight_decay=0.001,
    dataset_text_field="text",       # dataset column containing the raw commentary text
    report_to="none",
    bf16=False,                      # keep full fp32 precision on MPS
    fp16=False,
    dataloader_pin_memory=False,     # pinned memory is CUDA-only
    remove_unused_columns=False,
    max_seq_length=512,              # truncate long examples to cap memory use
    gradient_checkpointing=True,     # trade compute for lower activation memory
)

trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    processing_class=tokenizer,
    args=training_params,
)

trainer.train()
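
# Persist the fine-tuned weights and tokenizer once training finishes
# (the save path below is illustrative, not from the original script)
trainer.save_model("../checkpoints/final")
tokenizer.save_pretrained("../checkpoints/final")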

```
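
Once training finishes, the model can be queried with the standard `text-generation` pipeline. A minimal sketch, assuming the checkpoint was saved to the illustrative `../checkpoints/final` path above (a Hub repo id works the same way):

```python
from transformers import pipeline

# "../checkpoints/final" is the illustrative save path from the training
# script above; substitute the model's Hub repo id to load it remotely
generator = pipeline(
    "text-generation",
    model="../checkpoints/final",
    device="mps",  # fall back to "cpu" if MPS is unavailable
)

out = generator("Gender is", max_new_tokens=50, do_sample=True, top_p=0.95)
print(out[0]["generated_text"])
```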