Fill-Mask
Transformers
Safetensors
PyTorch
Kazakh
Russian
English
bert
KazBERT / script.py
Eraly-ml's picture
training pipline
2171bc2 verified
raw
history blame
5.85 kB
# %% [code]
import os
import math
import torch
from transformers import (
AutoTokenizer,
AutoModelForMaskedLM,
Trainer,
TrainingArguments,
)
from datasets import load_dataset
# Отключаем параллелизм токенизатора, чтобы избежать ворнингов
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Если запускаем с DDP, инициализуем процессную группу NCCL
if "LOCAL_RANK" in os.environ:
local_rank = int(os.environ["LOCAL_RANK"])
torch.distributed.init_process_group(backend="nccl")
device = torch.device("cuda", local_rank)
torch.cuda.set_device(device)
else:
local_rank = -1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ================================
# 1. Загрузка токенизатора и тестирование
# ================================
# Здесь загружается токенизатор из указанного пути
tokenizer = AutoTokenizer.from_pretrained("/kaggle/input/kaz-eng-rus/pytorch/default/1")
# Пробное токенизирование
test_text = "Қазақ тілі өте әдемі."
tokens = tokenizer.tokenize(test_text)
ids = tokenizer.encode(test_text)
print(f"Tokens: {tokens}")
print(f"IDs: {ids}")
# ================================
# 2. Загрузка датасета для предобучения
# ================================
# Загрузка JSON датасета, где каждая строка содержит поля 'original_sentence' и 'masked_sentence'
dataset = load_dataset("json", data_files="/kaggle/input/kaz-rus-eng-wiki/train_pretrain.json")
print("Первый пример из датасета:", dataset["train"][0])
# ================================
# 3. Загрузка модели
# ================================
# Загружаем базовую модель BERT для Masked LM
model = AutoModelForMaskedLM.from_pretrained("bert-base-multilingual-cased")
model.to(device)
# ================================
# 4. Подготовка данных: токенизация и создание меток (labels)
# ================================
def preprocess_dataset(examples):
# Токенизация замаскированного текста
inputs = tokenizer(
examples["masked_sentence"],
truncation=True,
padding="max_length",
max_length=128,
)
# Токенизация оригинального текста для формирования labels
originals = tokenizer(
examples["original_sentence"],
truncation=True,
padding="max_length",
max_length=128,
)["input_ids"]
# Получаем id специального токена [MASK]
mask_token_id = tokenizer.convert_tokens_to_ids("[MASK]")
# Формируем метки: если токен не [MASK], то игнорируем (-100)
labels = [
[-100 if token_id != mask_token_id else orig_id
for token_id, orig_id in zip(input_ids, original_ids)]
for input_ids, original_ids in zip(inputs["input_ids"], originals)
]
inputs["labels"] = labels
return inputs
# Токенизируем датасет (batched для ускорения)
tokenized_datasets = dataset.map(
preprocess_dataset,
batched=True,
remove_columns=dataset["train"].column_names,
batch_size=1000
)
# ================================
# 5. Настройка обучения
# ================================
training_args = TrainingArguments(
output_dir="./results",
per_device_train_batch_size=20, # Размер батча на один GPU
num_train_epochs=3,
weight_decay=0.01,
save_strategy="epoch",
fp16=True, # Используем mixed precision
dataloader_num_workers=4, # Количество воркеров для загрузчика данных
report_to="none", # Отключаем отчёты (wandb и т.п.)
)
# Создаем Trainer; если скрипт запущен через torchrun, Trainer автоматически использует DDP
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_datasets["train"],
)
# ================================
# 6. Обучение модели
# ================================
trainer.train()
# ================================
# 7. Сохранение модели и токенизатора
# ================================
output_dir = "./KazBERT"
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
print(f"Модель сохранена в {output_dir}")
# ================================
# 8. Вычисление Perplexity на валидационном датасете
# ================================
# Загружаем валидационный датасет как текстовый (формат "text")
valid_dataset = load_dataset("text", data_files="/kaggle/input/kaz-rus-eng-wiki/valid.txt", split="train[:1%]")
def compute_perplexity(model, tokenizer, text):
# Токенизируем текст и отправляем на нужное устройство
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(device)
with torch.no_grad():
outputs = model(**inputs, labels=inputs["input_ids"])
loss = outputs.loss
return math.exp(loss.item())
# Вычисляем perplexity для каждого примера и выводим среднее значение
ppl_scores = [compute_perplexity(model, tokenizer, sample["text"]) for sample in valid_dataset]
avg_ppl = sum(ppl_scores) / len(ppl_scores)
print(f"Perplexity модели: {avg_ppl:.2f}")