SinclairSchneider committed · Commit 37c517d · verified · 1 Parent(s): 7a4663e

Create train.py

Files changed (1)
  1. train.py +73 -0
train.py ADDED
@@ -0,0 +1,73 @@
import pandas as pd
import torch
from transformers import DebertaV2ForSequenceClassification, DebertaV2Tokenizer, DataCollatorWithPadding, Trainer, TrainingArguments
from tqdm import tqdm
from datasets import Dataset, load_dataset
import numpy as np
import wandb
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

# Configuration
output_dir = './german_politic_DeBERTa-v2-base'
model_name = "ikim-uk-essen/geberta-base"
max_length = 512
id2label = {0: 'other', 1: 'politic'}
label2id = {'other': 0, 'politic': 1}

wandb.init(project="german_politic_yes_no_classifier", entity="xxx", name="german_politic_DeBERTa")

# Load the pretrained German DeBERTa model and tokenizer for binary sequence classification
model = DebertaV2ForSequenceClassification.from_pretrained(model_name, num_labels=2, id2label=id2label, label2id=label2id, output_attentions=False, output_hidden_states=False)
tokenizer = DebertaV2Tokenizer.from_pretrained(model_name, do_lower_case=False, max_length=max_length, TOKENIZERS_PARALLELISM=True)

# Load the dataset and hold out 20% of the training split for evaluation
dataset = load_dataset("SinclairSchneider/trainset_political_text_yes_no_german")
dataset = dataset['train'].train_test_split(0.2)

# Tokenize the text column; padding is applied per batch by the data collator
def preprocess(sample):
    return tokenizer(sample["text"], truncation=True)

dataset_tokenized = dataset.map(preprocess, batched=True)

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Accuracy plus weighted precision/recall/F1 over the two classes
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }

training_args = TrainingArguments(
    output_dir=output_dir,
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=4,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    report_to="wandb",
    fp16=False,
    logging_steps=8,
    disable_tqdm=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset_tokenized["train"],
    eval_dataset=dataset_tokenized["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

# Save the fine-tuned model and tokenizer
model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
model_to_save.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)