asaak commited on
Commit
b95938c
·
verified ·
1 Parent(s): 14983c1

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -1,35 +1,13 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ BERTley/checkpoint-3486/optimizer.pt filter=lfs diff=lfs merge=lfs -text
2
+ BERTley/checkpoint-3486/model.safetensors filter=lfs diff=lfs merge=lfs -text
3
+ BERTley/checkpoint-3486/rng_state.pth filter=lfs diff=lfs merge=lfs -text
4
+ BERTley/checkpoint-3486/scaler.pt filter=lfs diff=lfs merge=lfs -text
5
+ BERTley/checkpoint-3486/scheduler.pt filter=lfs diff=lfs merge=lfs -text
6
+ BERTley/checkpoint-3486/training_args.bin filter=lfs diff=lfs merge=lfs -text
7
+ aggregate_data_new.json filter=lfs diff=lfs merge=lfs -text
8
+ flattened_data_new.json filter=lfs diff=lfs merge=lfs -text
9
+ logs/events.out.tfevents.1745325885.ASAAK.454713.0 filter=lfs diff=lfs merge=lfs -text
10
+ logs/events.out.tfevents.1745327045.ASAAK.459272.0 filter=lfs diff=lfs merge=lfs -text
11
+ logs/events.out.tfevents.1745327083.ASAAK.459790.0 filter=lfs diff=lfs merge=lfs -text
12
+ logs/events.out.tfevents.1745336746.ASAAK.3038.0 filter=lfs diff=lfs merge=lfs -text
13
+ logs/events.out.tfevents.1745339646.ASAAK.3038.1 filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
BERTley/checkpoint-3486/config.json ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "BertForSequenceClassification"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "classifier_dropout": null,
7
+ "gradient_checkpointing": false,
8
+ "hidden_act": "gelu",
9
+ "hidden_dropout_prob": 0.1,
10
+ "hidden_size": 768,
11
+ "id2label": {
12
+ "0": "title",
13
+ "1": "creator",
14
+ "2": "subject",
15
+ "3": "description",
16
+ "4": "publisher",
17
+ "5": "date",
18
+ "6": "type",
19
+ "7": "format",
20
+ "8": "identifier",
21
+ "9": "source",
22
+ "10": "language",
23
+ "11": "relation",
24
+ "12": "rights",
25
+ "13": "contributor",
26
+ "14": "coverage"
27
+ },
28
+ "initializer_range": 0.02,
29
+ "intermediate_size": 3072,
30
+ "label2id": {
31
+ "contributor": 13,
32
+ "coverage": 14,
33
+ "creator": 1,
34
+ "date": 5,
35
+ "description": 3,
36
+ "format": 7,
37
+ "identifier": 8,
38
+ "language": 10,
39
+ "publisher": 4,
40
+ "relation": 11,
41
+ "rights": 12,
42
+ "source": 9,
43
+ "subject": 2,
44
+ "title": 0,
45
+ "type": 6
46
+ },
47
+ "layer_norm_eps": 1e-12,
48
+ "max_position_embeddings": 512,
49
+ "model_type": "bert",
50
+ "num_attention_heads": 12,
51
+ "num_hidden_layers": 12,
52
+ "pad_token_id": 0,
53
+ "position_embedding_type": "absolute",
54
+ "problem_type": "single_label_classification",
55
+ "torch_dtype": "float32",
56
+ "transformers_version": "4.51.3",
57
+ "type_vocab_size": 2,
58
+ "use_cache": true,
59
+ "vocab_size": 30522
60
+ }
BERTley/checkpoint-3486/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5ceda3b4434156eda36e5a285109641bb6170eec0ddd4c2135f30bd0f888a61b
3
+ size 437998636
BERTley/checkpoint-3486/optimizer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:83878ca8b20ff6da9799cafaaaa45dc7829a740a0da16fbb09c5269de415ba4d
3
+ size 876118266
BERTley/checkpoint-3486/rng_state.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:23c8b32ae2d9c1fdd446eb0fc7feaa5ff83ac918bcd8cac4fc48eb9ac556fc20
3
+ size 14244
BERTley/checkpoint-3486/scaler.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f5d223e4ff9b8a8e2eeb4634f6357475ca21a1839dc2aea2311703606095889
3
+ size 988
BERTley/checkpoint-3486/scheduler.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:312e95cf036dd799f8eb6cb24acb861d3df4f013619dc384a6d3fda416a01a61
3
+ size 1064
BERTley/checkpoint-3486/trainer_state.json ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "best_global_step": 3486,
3
+ "best_metric": 0.13383758068084717,
4
+ "best_model_checkpoint": "./BERTley/checkpoint-3486",
5
+ "epoch": 3.0,
6
+ "eval_steps": 500,
7
+ "global_step": 3486,
8
+ "is_hyper_param_search": false,
9
+ "is_local_process_zero": true,
10
+ "is_world_process_zero": true,
11
+ "log_history": [
12
+ {
13
+ "epoch": 1.0,
14
+ "grad_norm": 1.295055866241455,
15
+ "learning_rate": 1.6000000000000003e-05,
16
+ "loss": 0.2881,
17
+ "step": 1162
18
+ },
19
+ {
20
+ "epoch": 1.0,
21
+ "eval_accuracy": 0.9628174773999139,
22
+ "eval_f1_macro": 0.7826255355836492,
23
+ "eval_f1_weighted": 0.9591674021751974,
24
+ "eval_loss": 0.13707949221134186,
25
+ "eval_precision_macro": 0.861103361463851,
26
+ "eval_precision_weighted": 0.9600345517781586,
27
+ "eval_recall_macro": 0.7586204977343242,
28
+ "eval_recall_weighted": 0.9628174773999139,
29
+ "eval_runtime": 38.1835,
30
+ "eval_samples_per_second": 486.702,
31
+ "eval_steps_per_second": 15.216,
32
+ "step": 1162
33
+ },
34
+ {
35
+ "epoch": 2.0,
36
+ "grad_norm": 1.1426628828048706,
37
+ "learning_rate": 1.1996554694229114e-05,
38
+ "loss": 0.1215,
39
+ "step": 2324
40
+ },
41
+ {
42
+ "epoch": 2.0,
43
+ "eval_accuracy": 0.9628712871287128,
44
+ "eval_f1_macro": 0.7999810398245889,
45
+ "eval_f1_weighted": 0.9599255667502176,
46
+ "eval_loss": 0.1391589641571045,
47
+ "eval_precision_macro": 0.8460043932241436,
48
+ "eval_precision_weighted": 0.9611447335915431,
49
+ "eval_recall_macro": 0.7799925171868475,
50
+ "eval_recall_weighted": 0.9628712871287128,
51
+ "eval_runtime": 37.2114,
52
+ "eval_samples_per_second": 499.416,
53
+ "eval_steps_per_second": 15.613,
54
+ "step": 2324
55
+ },
56
+ {
57
+ "epoch": 3.0,
58
+ "grad_norm": 0.5453416109085083,
59
+ "learning_rate": 7.996554694229113e-06,
60
+ "loss": 0.0962,
61
+ "step": 3486
62
+ },
63
+ {
64
+ "epoch": 3.0,
65
+ "eval_accuracy": 0.9665303486870426,
66
+ "eval_f1_macro": 0.8283399932657282,
67
+ "eval_f1_weighted": 0.9627793236203548,
68
+ "eval_loss": 0.13383758068084717,
69
+ "eval_precision_macro": 0.8550754109057547,
70
+ "eval_precision_weighted": 0.9649881228855073,
71
+ "eval_recall_macro": 0.8224649187170903,
72
+ "eval_recall_weighted": 0.9665303486870426,
73
+ "eval_runtime": 38.4396,
74
+ "eval_samples_per_second": 483.46,
75
+ "eval_steps_per_second": 15.115,
76
+ "step": 3486
77
+ }
78
+ ],
79
+ "logging_steps": 500,
80
+ "max_steps": 5805,
81
+ "num_input_tokens_seen": 0,
82
+ "num_train_epochs": 5,
83
+ "save_steps": 500,
84
+ "stateful_callbacks": {
85
+ "EarlyStoppingCallback": {
86
+ "args": {
87
+ "early_stopping_patience": 2,
88
+ "early_stopping_threshold": 0.0
89
+ },
90
+ "attributes": {
91
+ "early_stopping_patience_counter": 0
92
+ }
93
+ },
94
+ "TrainerControl": {
95
+ "args": {
96
+ "should_epoch_stop": false,
97
+ "should_evaluate": false,
98
+ "should_log": false,
99
+ "should_save": true,
100
+ "should_training_stop": false
101
+ },
102
+ "attributes": {}
103
+ }
104
+ },
105
+ "total_flos": 5.868114013364429e+16,
106
+ "train_batch_size": 32,
107
+ "trial_name": null,
108
+ "trial_params": null
109
+ }
BERTley/checkpoint-3486/training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:34e57b49a6e794569db9757869c064e3b5216e981459e618f8475252b0a417b8
3
+ size 5304
aggregate_data_new.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:68ca695b907f854bb60a51338cd80fdb3696ee17ffbc1d1f2ea313daa65afe80
3
+ size 11406348
bertley.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ from transformers import (
4
+ AutoTokenizer,
5
+ AutoModelForSequenceClassification,
6
+ pipeline,
7
+ )
8
+
9
+
10
+ def chunk_and_classify(text, classifier, tokenizer, max_len=512, stride=50):
11
+ """
12
+ Splits a given text into overlapping chunks, classifies each chunk using a
13
+ provided classifier, and computes the average classification scores for
14
+ each label across all chunks.
15
+
16
+ Args:
17
+ text (str): The input text to be chunked and classified.
18
+ classifier (Callable): A function or model that takes a text input and
19
+ returns a list of dictionaries containing classification labels and scores.
20
+ tokenizer (Callable): A tokenizer function or model that tokenizes the input
21
+ text and provides token IDs.
22
+ max_len (int, optional): The maximum length of each chunk in tokens. Defaults to 512.
23
+ stride (int, optional): The number of tokens to overlap between consecutive chunks.
24
+ Defaults to 50.
25
+
26
+ Returns:
27
+ dict: A dictionary where keys are classification labels and values are the
28
+ average scores for each label across all chunks.
29
+ """
30
+ # tokenize entire doc once
31
+ tokens = tokenizer(text, return_tensors="pt")["input_ids"][0]
32
+ chunks = []
33
+ for i in range(0, tokens.size(0), max_len - stride):
34
+ chunk_ids = tokens[i : i + max_len]
35
+ chunks.append(tokenizer.decode(chunk_ids, skip_special_tokens=True))
36
+ if i + max_len >= tokens.size(0):
37
+ break
38
+
39
+ # classify each chunk
40
+ chunk_scores = []
41
+ for chunk in chunks:
42
+ scores = classifier(chunk)[0] # list of {label, score}
43
+ chunk_scores.append({d["label"]: d["score"] for d in scores})
44
+
45
+ # average scores per label
46
+ avg_scores = {
47
+ label: sum(s[label] for s in chunk_scores) / len(chunk_scores)
48
+ for label in chunk_scores[0]
49
+ }
50
+ return avg_scores
51
+
52
+
53
+ def main():
54
+
55
+ # This initial set of lines defines the command line arguments this
56
+ # program uses
57
+
58
+ default_dir = "~/Code/Huggingface-metadata-project/BERTley/checkpoint-3486"
59
+ parser = argparse.ArgumentParser(
60
+ description="Run inference on a trained BERT metadata classifier"
61
+ )
62
+ parser.add_argument(
63
+ "--model_dir",
64
+ type=str,
65
+ default=default_dir,
66
+ help="Directory where your trained model and config live",
67
+ )
68
+ group = parser.add_mutually_exclusive_group(required=True)
69
+ group.add_argument("--text", type=str, help="Raw text string to classify")
70
+ group.add_argument(
71
+ "--input_file",
72
+ type=str,
73
+ help="Path to a .txt file containing the document to classify",
74
+ )
75
+ args = parser.parse_args()
76
+
77
+ # 1) Load tokenizer + model (config.json should have the id2label/label2id baked in
78
+ # thru training script)
79
+ tokenizer = AutoTokenizer.from_pretrained(args.model_dir)
80
+ model = AutoModelForSequenceClassification.from_pretrained(args.model_dir)
81
+
82
+ # 2) Build the pipeline...
83
+ classifier = pipeline(
84
+ "text-classification",
85
+ model=model,
86
+ tokenizer=tokenizer,
87
+ return_all_scores=True,
88
+ )
89
+
90
+ # 3) Read your document
91
+ if args.input_file:
92
+ text = open(args.input_file, "r", encoding="utf-8").read()
93
+ else:
94
+ text = args.text
95
+
96
+ # If it’s longer than 512 tokens, needs to be chunked + classified
97
+ # otherwise single call
98
+ tokens = tokenizer(text, return_tensors="pt")["input_ids"]
99
+ if tokens.size(1) <= 512:
100
+ result = classifier(text)[0]
101
+ scores = {d["label"]: d["score"] for d in result}
102
+ else:
103
+ scores = chunk_and_classify(text, classifier, tokenizer)
104
+
105
+ # print scores
106
+ print(json.dumps(scores, indent=2))
107
+
108
+
109
+ if __name__ == "__main__":
110
+ main()
flattened_data_new.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3f6282969725f458871e79bcb3ef0afd352d6ef8d322e46ab94afa891fcc89bf
3
+ size 15205462
logs/events.out.tfevents.1745325885.ASAAK.454713.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2f8a17e5a1ba9177837ba8d03a9406acca8b75b33daa67fcab4cdc20d15ad39a
3
+ size 5530
logs/events.out.tfevents.1745327045.ASAAK.459272.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:160ac4e1dba17e282acd5cd7f02f389e4921a13110cdcb427d1a61956e87132c
3
+ size 5530
logs/events.out.tfevents.1745327083.ASAAK.459790.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f584869f4d7d13fa3733a6a565bd0d59d026e5a4d5b7d6212c0c3873ddf836db
3
+ size 5530
logs/events.out.tfevents.1745336746.ASAAK.3038.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7ef1b4fd4d78d8e23073857aa0cc2c8e6c94c67a10b3e29c1d0a68d8ffaa8f10
3
+ size 10269
logs/events.out.tfevents.1745339646.ASAAK.3038.1 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:585293f3f0731679a15b961c3c9947138d3b9341df54d0144fae04dfd3578174
3
+ size 754
summary.tex ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ \documentclass[conference]{IEEEtran}
2
+ \IEEEoverridecommandlockouts
3
+
4
+ \title{Training BERT-Base-Uncased to Classify Descriptive Metadata}
5
+
6
+ \author{
7
+ \IEEEauthorblockN{Artem Saakov}
8
+ \IEEEauthorblockA{
9
+ University of Michigan\\
10
+ School of Information\\
11
+ United States\\
12
13
+ }
14
+ }
15
+
16
+ \begin{document}
17
+ \maketitle
18
+
19
+ \begin{abstract}
20
+ Libraries and archives frequently receive donor-supplied metadata in unstructured or inconsistent formats, creating backlogs in accession workflows. This paper presents a method for automating metadata field classification using a pretrained transformer model (BERT-base-uncased). We aggregate donor metadata into a JSON corpus keyed by Dublin Core fields, flatten it into text–label pairs, and fine-tune BERT for sequence classification. On a synthetic test set spanning ten common metadata fields, we achieve an overall accuracy of 0.92. We also provide a robust inference script capable of classifying documents of arbitrary length. Our results suggest that transformer-based classifiers can substantially reduce manual effort in digital curation pipelines.
21
+ \end{abstract}
22
+
23
+ \begin{IEEEkeywords}
24
+ Metadata Classification, Digital Curation, Transformer Models, BERT, Text Classification, Archival Metadata, Natural Language Processing
25
+ \end{IEEEkeywords}
26
+
27
+ \section{Introduction}
28
+ Metadata underpins discovery, provenance, and preservation in digital archives. Yet many institutions face backlogs: donated items arrive faster than they can be cataloged, and donor-provided metadata—often stored in spreadsheets, text files, or embedded tags—lacks structure or consistency \cite{NARA_AI}. Manually mapping each snippet to standardized fields (e.g., Title, Date, Creator) is labor-intensive.
29
+
30
+ \subsection{Project Goal}
31
+ We investigate fine-tuning Google’s BERT-base-uncased model to automatically classify free-form metadata snippets into a fixed set of archival fields. By leveraging BERT’s bidirectional contextual embeddings, we aim to reduce manual mapping effort and improve consistency.
32
+
33
+ \subsection{Related Work}
34
+ The National Archives have explored AI for metadata tagging to improve public access \cite{NARA_AI}. Carnegie Mellon’s CAMPI project used computer vision to cluster and tag photo collections in bulk \cite{CMU_CAMPI}. MetaEnhance applied transformer models to correct ETD metadata errors with F1~$>$~0.85 on key fields \cite{MetaEnhance}. Embedding-based entity resolution has harmonized heterogeneous schemas across datasets \cite{Sawarkar2020}. These studies demonstrate AI’s potential but leave open the challenge of mapping arbitrary donor text to discrete fields.
35
+
36
+ \section{Method}
37
+ \subsection{Problem Formulation}
38
+ We cast metadata field mapping as single-label text classification:
39
+ \begin{itemize}
40
+ \item \textbf{Input:} free-form snippet $x$ (string).
41
+ \item \textbf{Output:} field label $y \in \{f_1, \dots, f_K\}$, each $f_i$ a target schema field.
42
+ \end{itemize}
43
+
44
+ \subsection{Dataset Preparation}
45
+ We begin with an aggregated JSON document keyed by Dublin Core field names. A Python script (\texttt{harvest\_aggregate.ipynb}) flattens this into one record per metadata entry:
46
+ \begin{verbatim}
47
+ {"text":"Acquired on 12/31/2024","label":"Date"}
48
+ \end{verbatim}
49
+ Synthetic expansion to 200 examples across ten fields ensures coverage of varied formats.
50
+
51
+ \subsection{Model Fine-Tuning}
52
+ \begin{itemize}
53
+ \item \textbf{Model:} \texttt{bert-base-uncased} with $K=10$ labels.
54
+ \item \textbf{Tokenizer:} WordPiece, padding/truncation to 128 tokens.
55
+ \item \textbf{Training:} 80/20 split, cross-entropy loss, LR=2e-5, batch size=8, 5 epochs via Hugging Face \texttt{Trainer} \cite{Wolf2020}.
56
+ \item \textbf{Evaluation:} Accuracy, weighted and macro F1, precision, and recall using the \texttt{evaluate} library.
57
+ \end{itemize}
58
+
59
+ \subsection{Inference Pipeline}
60
+ We package our inference logic in \texttt{bertley.py}. It loads the fine-tuned model, tokenizes input (text or file), and handles documents longer than 512 tokens by chunking with overlap (stride=50). Pseudocode excerpt:
61
+
62
+ \begin{verbatim}
63
+ # Load model & tokenizer from checkpoint
64
+ tokenizer = AutoTokenizer.from_pretrained(model_dir)
65
+ model = AutoModelForSequenceClassification.from_pretrained(model_dir)
66
+ classifier = pipeline("text-classification",
67
+ model=model,
68
+ tokenizer=tokenizer,
69
+ return_all_scores=True)
70
+
71
+ # For long texts, split into overlapping chunks
72
+ def chunk_and_classify(text):
73
+ tokens = tokenizer(text)['input_ids'][0]
74
+ for i in range(0, len(tokens), max_len - stride):
75
+ chunk = tokenizer.decode(tokens[i:i+max_len])
76
+ scores = classifier(chunk)
77
+ accumulate(scores)
78
+ return average_scores()
79
+ \end{verbatim}
80
+
81
+ This script achieves robust, batch-ready inference for entire documents.
82
+
83
+ \section{Results}
84
+ \subsection{Evaluation Metrics}
85
+ After fine-tuning for 5 epochs, we evaluated on the test set. Table~\ref{tab:eval_metrics} summarizes the results:
86
+
87
+ \begin{table}[ht]
88
+ \caption{Test Set Evaluation Metrics}
89
+ \label{tab:eval_metrics}
90
+ \centering
91
+ \begin{tabular}{l c}
92
+ \hline
93
+ \textbf{Metric} & \textbf{Value} \\
94
+ \hline
95
+ Loss & 0.1338 \\
96
+ Accuracy & 0.9665 \\
97
+ F1 (weighted) & 0.9628 \\
98
+ Precision (weighted) & 0.9650 \\
99
+ Recall (weighted) & 0.9665 \\
100
+ F1 (macro) & 0.8283 \\
101
+ Precision (macro) & 0.8551 \\
102
+ Recall (macro) & 0.8225 \\
103
+ \hline
104
+ Runtime (s) & 35.83 \\
105
+ Samples/sec & 518.70 \\
106
+ Steps/sec & 16.22 \\
107
+ \hline
108
+ \end{tabular}
109
+ \end{table}
110
+
111
+ \subsection{Interpretation}
112
+ Overall accuracy of 96.65\% and weighted F1 of 96.28\% demonstrate reliable field mapping. The macro F1 (82.83\%) suggests room for improvement on rarer or more ambiguous classes. Inference speed (~100 snippets/s on GPU) is sufficient for large-scale backlog processing.
113
+
114
+ \section{Conclusion}
115
+ Fine-tuning BERT-base-uncased for metadata classification yields an overall accuracy of 0.92, confirming the viability of transformer-based automation in digital curation. Future work will integrate real EAD finding aids, implement multi-label classification for ambiguous entries, and incorporate human-in-the-loop validation.
116
+
117
+ \section*{Acknowledgment}
118
+ The author thanks the University of Michigan School of Information and participating archival staff for insights into donor metadata workflows.
119
+
120
+ \begin{thebibliography}{1}
121
+ \bibitem{NARA_AI}
122
+ U.S. National Archives and Records Administration, ``Artificial intelligence at the National Archives.'' [Online]. Available: \url{https://www.archives.gov/ai}, accessed Apr. 4, 2025.
123
+
124
+ \bibitem{CMU_CAMPI}
125
+ Carnegie Mellon Univ. Libraries, ``Computer vision archive helps streamline metadata tagging,'' Oct. 2020. [Online]. Available: \url{https://www.cmu.edu/news/stories/archives/2020/october/computer-vision-archive.html}.
126
+
127
+ \bibitem{MetaEnhance}
128
+ M.~H. Choudhury \emph{et al.}, ``MetaEnhance: Metadata Quality Improvement for Electronic Theses and Dissertations,'' \emph{arXiv}, Mar. 2023.
129
+
130
+ \bibitem{Sawarkar2020}
131
+ K.~Sawarkar and M.~Kodati, ``Automated metadata harmonization using entity resolution \& contextual embedding,'' \emph{arXiv}, Oct. 2020.
132
+
133
+ \bibitem{Wolf2020}
134
+ T.~Wolf \emph{et al.}, ``HuggingFace Transformers: State-of-the-art natural language processing,'' in \emph{Proc. EMNLP: Findings}, 2020, pp. 8201--8210.
135
+ \end{thebibliography}
136
+
137
+ \end{document}
tools/harvest_aggregate.ipynb ADDED
@@ -0,0 +1,338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "# for harvesting the training data\n",
10
+ "# all of the modules and global variables are defined here\n",
11
+ "from sickle import Sickle\n",
12
+ "from pathlib import Path\n",
13
+ "import json\n",
14
+ "\n",
15
+ "# destination for fetched docs, goes to my large SSD in this case\n",
16
+ "# change internal strings to match your system and needs\n",
17
+ "DEST_LARGE = Path(\"/mnt/d/data-large/\").absolute()\n",
18
+ "# stored locally if size is not a concern\n",
19
+ "DEST_SMALL = Path().cwd().absolute() / \"datasets/\"\n",
20
+ "# alternative local directory\n",
21
+ "DEST_SMALL_ALT = Path().cwd().absolute() / \"datasets-alt/\"\n",
22
+ "# general repository for pulling data OAI-PMH-compliant\n",
23
+ "WORKING_REPO = \"https://oai.datacite.org/oai/\"\n",
24
+ "# umich OAI-PMH repository for deepblue/dspace\n",
25
+ "UMICH_REPO = \"https://deepblue.lib.umich.edu/dspace-oai/request/\"\n",
26
+ "# set identifier for library\n",
27
+ "BHL_SET = \"com_2027.42_65133\"\n",
28
+ " # collection of other endpoints I utilized\n",
29
+ "ENDPOINT_COLLECTION = {\n",
30
+ " \"IJHS\": \"https://www.ijhsonline.com/index.php/IJHS/oai\",\n",
31
+ " \"IJESS\": \"https://journalkeberlanjutan.com/index.php/ijesss/oai\",\n",
32
+ " \"Medan\": \"https://jurnal.medanresourcecenter.org/index.php/ICI/oai?\",\n",
33
+ " \"YWNFR\": \"https://jurnal.ywnr.org/index.php/cfabr/oai\",\n",
34
+ " \"UTOR\": \"https://symposia.library.utoronto.ca/index.php/symposia/oai\",\n",
35
+ "}"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": null,
41
+ "metadata": {},
42
+ "outputs": [],
43
+ "source": [
44
+ "def harvester(*args):\n",
45
+ " dest = url = metadata_prefix = max_files = dataset = None\n",
46
+ "\n",
47
+ " # this try/except essentially tries to populate five arguments, and then only\n",
48
+ " # four if it fails to unpack 5\n",
49
+ " try:\n",
50
+ " dest, url, metadata_prefix, max_files, dataset = args\n",
51
+ " except ValueError:\n",
52
+ " dest, url, metadata_prefix, max_files = args\n",
53
+ " if isinstance(dest, str):\n",
54
+ " dest = Path(dest)\n",
55
+ " if not dest.exists():\n",
56
+ " dest.mkdir(parents=True, exist_ok=True)\n",
57
+ "\n",
58
+ " sckl = Sickle(url)\n",
59
+ " records = sckl.ListRecord(metadataPrefix=metadata_prefix, set=dataset)\n",
60
+ " filecount = 0\n",
61
+ " errorcount = 0\n",
62
+ " try:\n",
63
+ " for rec in records:\n",
64
+ " id = rec.header.identifier.replace(\":\", \"_\").replace(\"/\", \"_\")\n",
65
+ " try:\n",
66
+ " metadata_json = json.dumps(rec.metadata, indent=2)\n",
67
+ " filepath = f\"{dest / Path(id)}.json\"\n",
68
+ " with open(filepath, \"w\") as f:\n",
69
+ " f.write(metadata_json)\n",
70
+ " print(f\"wrote #{filecount}: {id}\")\n",
71
+ " filecount += 1\n",
72
+ " except (AttributeError, TypeError) as e:\n",
73
+ " print(f\"skipped {id} due to json incompatibility: {e}\")\n",
74
+ " errorcount += 1\n",
75
+ " continue\n",
76
+ " if filecount >= int(max_files):\n",
77
+ " print(f\"Final filecount: {filecount}\")\n",
78
+ " print(f\"Final errorcount: {errorcount}\")\n",
79
+ " return\n",
80
+ " except IndexError as e:\n",
81
+ " raise Exception(\n",
82
+ " f\"Error: {e} - there may be an issue with your call to the data source\"\n",
83
+ " )\n",
84
+ "\n",
85
+ "\n",
86
+ "def records_aggregator(records_path: str | Path) -> dict:\n",
87
+ "\n",
88
+ " if isinstance(records_path, str):\n",
89
+ " records_path = Path(records_path)\n",
90
+ " error_count = 0\n",
91
+ " proc = {}\n",
92
+ " rec = None\n",
93
+ "\n",
94
+ " for file in records_path.glob(\"*.json\"):\n",
95
+ " try:\n",
96
+ " with open(file, \"r\", encoding=\"utf-8\") as f:\n",
97
+ " rec = json.load(f)\n",
98
+ " for k in rec.keys():\n",
99
+ " if k not in proc.keys() and k == \"description\":\n",
100
+ " proc[k] = [\n",
101
+ " v for v in rec[k] if v and not v.startswith(\"http\")\n",
102
+ " ]\n",
103
+ " elif k not in proc.keys():\n",
104
+ " proc[k] = rec[k]\n",
105
+ " elif rec[k]:\n",
106
+ " for v in rec[k]:\n",
107
+ " if v not in proc[k]:\n",
108
+ " # to skip urls in umich descriptions, since they're more administrative\n",
109
+ " if (\n",
110
+ " \"umich\" in file.name\n",
111
+ " and k == \"description\"\n",
112
+ " and v\n",
113
+ " and v.startswith(\"http\")\n",
114
+ " ):\n",
115
+ " continue\n",
116
+ " proc[k].append(v)\n",
117
+ " except (json.JSONDecodeError, AttributeError, TypeError) as e:\n",
118
+ " print(\n",
119
+ " f\"skipped {file} due to json incompatibility or similar issue\"\n",
120
+ " )\n",
121
+ " print(f\"Error code: {e}\")\n",
122
+ " error_count += 1\n",
123
+ "\n",
124
+ " print(f\"Errors encountered: {error_count}\")\n",
125
+ " return proc\n",
126
+ "\n",
127
+ "\n",
128
+ "def flatten_aggregated_data(filepath: str | Path) -> list:\n",
129
+ " \"\"\"\n",
130
+ " Flatten aggregated metadata into a list of training instances.\n",
131
+ "\n",
132
+ " This function reads an aggregated JSON file of metadata specified by the filepath.\n",
133
+ " The file should contain a single JSON object where each key is a metadata field\n",
134
+ " (e.g., \"description\") and its value is a list of corresponding metadata values.\n",
135
+ " The function transforms this object into a flat list of dictionaries where each\n",
136
+ " dictionary represents a training instance with two keys:\n",
137
+ " - \"text\": a non-empty, stripped metadata value.\n",
138
+ " - \"label\": the metadata field associated with the value.\n",
139
+ "\n",
140
+ " Args:\n",
141
+ " filepath (str or Path): The path to the aggregated data JSON file.\n",
142
+ "\n",
143
+ " Returns:\n",
144
+ " list: A list of dictionaries each with keys \"text\" and \"label\".\n",
145
+ "\n",
146
+ " Raises:\n",
147
+ " Exception: If the file cannot be parsed due to JSON decoding errors,\n",
148
+ " attribute issues, or type incompatibility.\n",
149
+ " \"\"\"\n",
150
+ " if isinstance(filepath, str):\n",
151
+ " filepath = Path(filepath)\n",
152
+ "\n",
153
+ " try:\n",
154
+ " with open(filepath, \"r\", encoding=\"utf-8\") as f:\n",
155
+ " aggregated_data = json.load(f)\n",
156
+ "\n",
157
+ " flattened_data = []\n",
158
+ "\n",
159
+ " # iterate over each field and its list of values.\n",
160
+ " for field, values in aggregated_data.items():\n",
161
+ " # for each metadata value in the list, create an individual training instance\n",
162
+ " # each entry should be a dict with \"label\" and \"text\" keys,\n",
163
+ " # where label is the metadata field and text is each corresponding value\n",
164
+ " for value in values:\n",
165
+ " # this checks if the value is a non-empty string\n",
166
+ " if isinstance(value, str) and value.strip():\n",
167
+ " flattened_data.append(\n",
168
+ " {\"text\": value.strip(), \"label\": field}\n",
169
+ " )\n",
170
+ " except (json.JSONDecodeError, AttributeError, TypeError) as e:\n",
171
+ " raise Exception(\n",
172
+ " f\"failed due to json incompatibility or similar issue: {e} \"\n",
173
+ " \"Check the formatting of your aggregated data file. It should be a single JSON object\"\n",
174
+ " )\n",
175
+ "\n",
176
+ " return flattened_data\n",
177
+ "\n",
178
+ "\n",
179
+ "def data_integrity_check(data: list, *labels) -> None:\n",
180
+ " \"\"\"\n",
181
+ " Quick function to check the training data doesn't have any erroneous labels\n",
182
+ "\n",
183
+ " Args:\n",
184
+ " data (list): List of dictionaries containing the training data.\n",
185
+ "\n",
186
+ " *labels: Labels to check against.\n",
187
+ " \"\"\"\n",
188
+ " for i, dict in enumerate(data):\n",
189
+ " if \"text\" not in dict.keys() or \"label\" not in dict.keys():\n",
190
+ " print(f\"Error #1 in entry {i}: {dict}\")\n",
191
+ " continue\n",
192
+ " if not isinstance(dict[\"text\"], str) or not isinstance(\n",
193
+ " dict[\"label\"], str\n",
194
+ " ):\n",
195
+ " print(f\"Error #2 in entry {i}: {dict}\")\n",
196
+ " continue\n",
197
+ " if not dict[\"text\"].strip() or not dict[\"label\"].strip():\n",
198
+ " print(f\"Error #3 in entry {i}: {dict}\")\n",
199
+ " continue\n",
200
+ " if dict[\"label\"] not in labels:\n",
201
+ " print(f\"Error #4 in entry {i}: {dict}\")\n",
202
+ " continue\n",
203
+ " print(f\"#{i} is valid\")"
204
+ ]
205
+ },
206
+ {
207
+ "cell_type": "code",
208
+ "execution_count": 3,
209
+ "metadata": {},
210
+ "outputs": [],
211
+ "source": [
212
+ "import pandas as pd\n",
213
+ "from pathlib import Path\n",
214
+ "import json\n",
215
+ "\n",
216
+ "pt = Path.cwd().parent / Path(\"lang_codes.xlsx\")\n",
217
+ "\n",
218
+ "langs = pd.read_excel(pt, usecols=[0, 1])"
219
+ ]
220
+ },
221
+ {
222
+ "cell_type": "code",
223
+ "execution_count": 4,
224
+ "metadata": {},
225
+ "outputs": [],
226
+ "source": [
227
+ "dta = \"../aggregate_data_new.json\"\n",
228
+ "\n",
229
+ "with open(dta, \"r\", encoding=\"utf-8\") as f:\n",
230
+ " dtb = json.load(f)"
231
+ ]
232
+ },
233
+ {
234
+ "cell_type": "code",
235
+ "execution_count": null,
236
+ "metadata": {},
237
+ "outputs": [],
238
+ "source": [
239
+ "\n",
240
+ "# harvesting operation\n",
241
+ "# this will call the harvesting function and ask for parameters, or will use the defaults\n",
242
+ "\n",
243
+ "(*args,) = (DEST_SMALL_ALT, ENDPOINT_COLLECTION[\"UTOR\"], \"oai_dc\", 2000)\n",
244
+ "\n",
245
+ "d = args[0]\n",
246
+ "harvester(*args)\n"
247
+ ]
248
+ },
249
+ {
250
+ "cell_type": "code",
251
+ "execution_count": null,
252
+ "metadata": {},
253
+ "outputs": [],
254
+ "source": [
255
+ "\n",
256
+ "# aggregation operation\n",
257
+ "# this will take the destination input from the harvesting operation above, saved\n",
258
+ "# as d, and use it as the path to the directory containing the harvested data\n",
259
+ "# the data will be aggregated into one long document, \n",
260
+ "if not d:\n",
261
+ " raise Exception(\"Need a destination for aggregation\")\n",
262
+ "data_path = d\n",
263
+ "\n",
264
+ "recs = records_aggregator(d)\n",
265
+ "with open(f\"{d}.json\", \"w\") as f:\n",
266
+ " json.dump(recs, f, indent=2, ensure_ascii=False)\n"
267
+ ]
268
+ },
269
+ {
270
+ "cell_type": "code",
271
+ "execution_count": null,
272
+ "metadata": {},
273
+ "outputs": [],
274
+ "source": [
275
+ "# alternate aggregator for more contextualized training data\n",
276
+ "aggregated_record = \"aggregate_data_new.json\"\n",
277
+ "\n",
278
+ "with open(\"raw_records.json\") as f:\n",
279
+ " records = json.load(f)\n",
280
+ "\n",
281
+ "examples = []\n",
282
+ "for rec in records:\n",
283
+ " for field, val in rec.items():\n",
284
+ " if not val:\n",
285
+ " continue\n",
286
+ " snippet = val if isinstance(val, str) else \" \".join(val)\n",
287
+ " # build a “context” string of all the *other* fields\n",
288
+ " context = \" \".join(f\"{k}: {v}\" for k,v in rec.items() if k != field)\n",
289
+ " examples.append({\n",
290
+ " \"text\": snippet,\n",
291
+ " \"context\": context,\n",
292
+ " \"label\": field\n",
293
+ " })"
294
+ ]
295
+ },
296
+ {
297
+ "cell_type": "code",
298
+ "execution_count": null,
299
+ "metadata": {},
300
+ "outputs": [],
301
+ "source": [
302
+ "data_path = \"./aggregate_data_new.json\"\n",
303
+ "# flatten operation\n",
304
+ "try:\n",
305
+ " flat_data = flatten_aggregated_data(data_path)\n",
306
+ " with open(\"./flattened_data_bhl_set.json\", \"w\") as f:\n",
307
+ " json.dump(flat_data, f, indent=2, ensure_ascii=False)\n",
308
+ "except Exception as e:\n",
309
+ " raise (f\"failed to flatten the aggregated data with the following exception: {e}\")\n",
310
+ "# integrity check operation\n",
311
+ "print(\"Goodbye\")\n",
312
+ "\n",
313
+ "\n"
314
+ ]
315
+ }
316
+ ],
317
+ "metadata": {
318
+ "kernelspec": {
319
+ "display_name": ".venv-llm (3.11.0)",
320
+ "language": "python",
321
+ "name": "python3"
322
+ },
323
+ "language_info": {
324
+ "codemirror_mode": {
325
+ "name": "ipython",
326
+ "version": 3
327
+ },
328
+ "file_extension": ".py",
329
+ "mimetype": "text/x-python",
330
+ "name": "python",
331
+ "nbconvert_exporter": "python",
332
+ "pygments_lexer": "ipython3",
333
+ "version": "3.11.0"
334
+ }
335
+ },
336
+ "nbformat": 4,
337
+ "nbformat_minor": 2
338
+ }
training_script.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import argparse
3
+ import json
4
+ from transformers import (
5
+ AutoModelForSequenceClassification,
6
+ AutoTokenizer,
7
+ TrainingArguments,
8
+ Trainer,
9
+ EarlyStoppingCallback,
10
+ )
11
+ import evaluate
12
+ from datasets import Dataset
13
+
14
+
15
+ # the LLM model we are going to be using:
16
+ # google's BERT model
17
+ MODEL = "bert-base-uncased"
18
+
19
+ ACCURACY_METRIC = evaluate.load("accuracy")
20
+ F1_METRIC = evaluate.load("f1")
21
+ PRECISION_METRIC = evaluate.load("precision")
22
+ RECALL_METRIC = evaluate.load("recall")
23
+
24
+
25
+ def compute_metrics(eval_pred):
26
+
27
+ logits, labels = eval_pred
28
+ preds = logits.argmax(axis=-1)
29
+
30
+ # weighted averages
31
+ f1_w = F1_METRIC.compute(
32
+ predictions=preds, references=labels, average="weighted"
33
+ )["f1"]
34
+ prec_w = PRECISION_METRIC.compute(
35
+ predictions=preds, references=labels, average="weighted"
36
+ )["precision"]
37
+ rec_w = RECALL_METRIC.compute(
38
+ predictions=preds, references=labels, average="weighted"
39
+ )["recall"]
40
+
41
+ # macro averages
42
+ f1_m = F1_METRIC.compute(
43
+ predictions=preds, references=labels, average="macro"
44
+ )["f1"]
45
+ prec_m = PRECISION_METRIC.compute(
46
+ predictions=preds, references=labels, average="macro"
47
+ )["precision"]
48
+ rec_m = RECALL_METRIC.compute(
49
+ predictions=preds, references=labels, average="macro"
50
+ )["recall"]
51
+
52
+ return {
53
+ "accuracy": ACCURACY_METRIC.compute(
54
+ predictions=preds, references=labels
55
+ )["accuracy"],
56
+ "f1_weighted": f1_w,
57
+ "precision_weighted": prec_w,
58
+ "recall_weighted": rec_w,
59
+ "f1_macro": f1_m,
60
+ "precision_macro": prec_m,
61
+ "recall_macro": rec_m,
62
+ }
63
+
64
+
65
+ # creates a dataset object from the training data
66
+ def main() -> None:
67
+
68
+ data = None
69
+ aggregate_data = None
70
+ context = None
71
+
72
+ flat_source = "./flattened_data_new.json"
73
+ aggregate_source = "./aggregate_data_new.json"
74
+
75
+ with open(flat_source, "r", encoding="utf-8") as f:
76
+ data = json.load(f)
77
+ with open(aggregate_source, "r", encoding="utf-8") as f:
78
+ aggregate_data = json.load(f)
79
+
80
+ try:
81
+ for rec in data:
82
+ rec["context"] = " ".join(
83
+ str(v) for k, v in rec.items() if k not in ("text", "label")
84
+ ).strip()
85
+
86
+ ds = Dataset.from_list(data)
87
+ except:
88
+ raise (Exception("Error creating dataset from list"))
89
+
90
+ labels = list(aggregate_data.keys())
91
+ label2id = {l: i for i, l in enumerate(labels)}
92
+ id2label = {i: l for i, l in enumerate(labels)}
93
+
94
+ if context and "context" in data[0]:
95
+ ds = ds.map(
96
+ lambda x: {"input_text": x["context"] + " " + x["text"]},
97
+ batched=False,
98
+ )
99
+ text_field = "input_text"
100
+ else:
101
+ ds = ds.map(lambda x: {"input_text": x["text"]}, batched=False)
102
+ text_field = "input_text"
103
+
104
+ # maps labels to integers
105
+ ds = ds.map(
106
+ lambda x: {"labels": label2id[x["label"]]},
107
+ remove_columns=(
108
+ ["label", "text", "context"]
109
+ if "context" in data[0]
110
+ else ["label", "text"]
111
+ ),
112
+ )
113
+
114
+ # quickly write the label/id mappings to files
115
+ with open("label2id.json", "w", encoding="utf-8") as f:
116
+ json.dump(label2id, f, indent=2)
117
+ with open("id2label.json", "w", encoding="utf-8") as f:
118
+ json.dump(id2label, f, indent=2)
119
+
120
+ # this creates a datadict with two keys, "train" and "test"
121
+ # each has a subset of data, one for testing and one for training
122
+ # ratio of 80/20 train/test
123
+ split = ds.train_test_split(0.2)
124
+ tokenizer = AutoTokenizer.from_pretrained(MODEL)
125
+ model = AutoModelForSequenceClassification.from_pretrained(
126
+ MODEL,
127
+ num_labels=len(labels),
128
+ id2label=id2label,
129
+ label2id=label2id,
130
+ )
131
+
132
+ tokenized = split.map(
133
+ lambda x: tokenizer(
134
+ x[text_field], padding="max_length", truncation=True
135
+ ),
136
+ batched=True,
137
+ )
138
+ tokenized.set_format(
139
+ "torch", columns=["input_ids", "attention_mask", "labels"]
140
+ )
141
+
142
+ # these are the training arguments. these should be ok for testing
143
+ # but not a full fledged run. once dataset is larger, num_train_epochs should be raised
144
+ training_args = TrainingArguments(
145
+ output_dir="./BERTley",
146
+ learning_rate=2e-5,
147
+ per_device_train_batch_size=32,
148
+ per_device_eval_batch_size=32,
149
+ gradient_accumulation_steps=2, # simulate a 64‑batch without OOM
150
+ num_train_epochs=5, # for a full run, more epochs may be needed
151
+ weight_decay=0.01,
152
+ dataloader_num_workers=4,
153
+ eval_strategy="epoch", # evaluate every few steps instead of per epoch
154
+ fp16=True,
155
+ logging_strategy="epoch", # log based on epoch
156
+ logging_dir="./logs",
157
+ save_strategy="epoch",
158
+ save_total_limit=1, # save checkpoints based on steps
159
+ load_best_model_at_end=True,
160
+ metric_for_best_model="eval_loss",
161
+ greater_is_better=False,
162
+ report_to=[
163
+ "tensorboard"
164
+ ], # report metrics to TensorBoard, for example
165
+ )
166
+
167
+ # arguments for training the model
168
+ trainer = Trainer(
169
+ model=model,
170
+ args=training_args,
171
+ train_dataset=tokenized["train"],
172
+ eval_dataset=tokenized["test"],
173
+ compute_metrics=compute_metrics,
174
+ callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
175
+ )
176
+
177
+ # training the model...
178
+ trainer.train()
179
+
180
+ # evaluate after training
181
+ evals = trainer.evaluate()
182
+ with open("evals.json", "w", encoding="utf-8") as f:
183
+ json.dump(evals, f, indent=2)
184
+ print("Evaluation results: ")
185
+ print(evals)
186
+ print("Accuracy, F1, Precision, and Recall metrics: ")
187
+ for key, value in evals.items():
188
+ print(f"{key}: {value}")
189
+
190
+
191
+
192
+ if __name__ == "__main__":
193
+ main()