Fill-Mask
Transformers
Safetensors
PyTorch
Kazakh
Russian
English
bert
Eraly-ml commited on
Commit
2171bc2
·
verified ·
1 Parent(s): 4fd097f

training pipline

Browse files
Files changed (1) hide show
  1. script.py +144 -0
script.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %% [code]
2
+ import os
3
+ import math
4
+ import torch
5
+ from transformers import (
6
+ AutoTokenizer,
7
+ AutoModelForMaskedLM,
8
+ Trainer,
9
+ TrainingArguments,
10
+ )
11
+ from datasets import load_dataset
12
+
13
+ # Отключаем параллелизм токенизатора, чтобы избежать ворнингов
14
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
15
+
16
+ # Если запускаем с DDP, инициализуем процессную группу NCCL
17
+ if "LOCAL_RANK" in os.environ:
18
+ local_rank = int(os.environ["LOCAL_RANK"])
19
+ torch.distributed.init_process_group(backend="nccl")
20
+ device = torch.device("cuda", local_rank)
21
+ torch.cuda.set_device(device)
22
+ else:
23
+ local_rank = -1
24
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
+
26
+ # ================================
27
+ # 1. Загрузка токенизатора и тестирование
28
+ # ================================
29
+ # Здесь загружается токенизатор из указанного пути
30
+ tokenizer = AutoTokenizer.from_pretrained("/kaggle/input/kaz-eng-rus/pytorch/default/1")
31
+
32
+ # Пробное токенизирование
33
+ test_text = "Қазақ тілі өте әдемі."
34
+ tokens = tokenizer.tokenize(test_text)
35
+ ids = tokenizer.encode(test_text)
36
+ print(f"Tokens: {tokens}")
37
+ print(f"IDs: {ids}")
38
+
39
+ # ================================
40
+ # 2. Загрузка датасета для предобучения
41
+ # ================================
42
+ # Загрузка JSON датасета, где каждая строка содержит поля 'original_sentence' и 'masked_sentence'
43
+ dataset = load_dataset("json", data_files="/kaggle/input/kaz-rus-eng-wiki/train_pretrain.json")
44
+ print("Первый пример из датасета:", dataset["train"][0])
45
+
46
+ # ================================
47
+ # 3. Загрузка модели
48
+ # ================================
49
+ # Загружаем базовую модель BERT для Masked LM
50
+ model = AutoModelForMaskedLM.from_pretrained("bert-base-multilingual-cased")
51
+ model.to(device)
52
+
53
+ # ================================
54
+ # 4. Подготовка данных: токенизация и создание меток (labels)
55
+ # ================================
56
+ def preprocess_dataset(examples):
57
+ # Токенизация замаскированного текста
58
+ inputs = tokenizer(
59
+ examples["masked_sentence"],
60
+ truncation=True,
61
+ padding="max_length",
62
+ max_length=128,
63
+ )
64
+ # Токенизация оригинального текста для формирования labels
65
+ originals = tokenizer(
66
+ examples["original_sentence"],
67
+ truncation=True,
68
+ padding="max_length",
69
+ max_length=128,
70
+ )["input_ids"]
71
+
72
+ # Получаем id специального токена [MASK]
73
+ mask_token_id = tokenizer.convert_tokens_to_ids("[MASK]")
74
+
75
+ # Формируем метки: если токен не [MASK], то игнорируем (-100)
76
+ labels = [
77
+ [-100 if token_id != mask_token_id else orig_id
78
+ for token_id, orig_id in zip(input_ids, original_ids)]
79
+ for input_ids, original_ids in zip(inputs["input_ids"], originals)
80
+ ]
81
+ inputs["labels"] = labels
82
+ return inputs
83
+
84
+ # Токенизируем датасет (batched для ускорения)
85
+ tokenized_datasets = dataset.map(
86
+ preprocess_dataset,
87
+ batched=True,
88
+ remove_columns=dataset["train"].column_names,
89
+ batch_size=1000
90
+ )
91
+
92
+ # ================================
93
+ # 5. Настройка обучения
94
+ # ================================
95
+ training_args = TrainingArguments(
96
+ output_dir="./results",
97
+ per_device_train_batch_size=20, # Размер батча на один GPU
98
+ num_train_epochs=3,
99
+ weight_decay=0.01,
100
+ save_strategy="epoch",
101
+ fp16=True, # Используем mixed precision
102
+ dataloader_num_workers=4, # Количество воркеров для загрузчика данных
103
+ report_to="none", # Отключаем отчёты (wandb и т.п.)
104
+ )
105
+
106
+ # Создаем Trainer; если скрипт запущен через torchrun, Trainer автоматически использует DDP
107
+ trainer = Trainer(
108
+ model=model,
109
+ args=training_args,
110
+ train_dataset=tokenized_datasets["train"],
111
+ )
112
+
113
+ # ================================
114
+ # 6. Обучение модели
115
+ # ================================
116
+ trainer.train()
117
+
118
+ # ================================
119
+ # 7. Сохранение модели и токенизатора
120
+ # ================================
121
+ output_dir = "./KazBERT"
122
+ model.save_pretrained(output_dir)
123
+ tokenizer.save_pretrained(output_dir)
124
+ print(f"Модель сохранена в {output_dir}")
125
+
126
+ # ================================
127
+ # 8. Вычисление Perplexity на валидационном датасете
128
+ # ================================
129
+ # Загружаем валидационный датасет как текстовый (формат "text")
130
+ valid_dataset = load_dataset("text", data_files="/kaggle/input/kaz-rus-eng-wiki/valid.txt", split="train[:1%]")
131
+
132
+ def compute_perplexity(model, tokenizer, text):
133
+ # Токенизируем текст и отправляем на нужное устройство
134
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(device)
135
+ with torch.no_grad():
136
+ outputs = model(**inputs, labels=inputs["input_ids"])
137
+ loss = outputs.loss
138
+ return math.exp(loss.item())
139
+
140
+ # Вычисляем perplexity для каждого примера и выводим среднее значение
141
+ ppl_scores = [compute_perplexity(model, tokenizer, sample["text"]) for sample in valid_dataset]
142
+ avg_ppl = sum(ppl_scores) / len(ppl_scores)
143
+ print(f"Perplexity модели: {avg_ppl:.2f}")
144
+