{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "### Dataset Background & Loading\n", "\n", "The training dataset was sourced from the publicly available BIS central bank speeches, downloaded using the `gingado` package:\n", "\n", "```python\n", "from gingado.datasets import load_CB_speeches\n", "all_speeches = load_CB_speeches()\n", "all_speeches.to_csv(\"central_bank_speeches.csv\", index=False)\n", "```\n", "\n", "A preprocessing script was applied to clean the text, lowercase it, split speeches into well-formed sentences, and filter out short/noisy segments. This generated over **2 million sentence-level samples**, saved as `speeches_data_preprocessed.csv`.\n", "\n", "For training on Kaggle, the preprocessed dataset was uploaded as an external file and loaded.\n", "\n", "This ensures clean and consistent input for masked language modeling (MLM)." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19", "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5", "execution": { "iopub.execute_input": "2025-07-19T17:12:03.395329Z", "iopub.status.busy": "2025-07-19T17:12:03.395050Z", "iopub.status.idle": "2025-07-19T17:12:16.719665Z", "shell.execute_reply": "2025-07-19T17:12:16.719049Z", "shell.execute_reply.started": "2025-07-19T17:12:03.395302Z" }, "trusted": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(19609, 8)\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
| | url | title | description | date | text | author | country | processed_text |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 0 | https://www.bis.org/review/r970211c.pdf | Mr. Chen discusses monetary relations between ... | Speech by the Deputy Governor of the People's ... | 1996-09-10 00:00:00 | Mr. Chen discusses monetary relations between ... | Chen Yuan | China | [\"mr. chen discusses monetary relations betwee... |
| 1 | https://www.bis.org/review/r970211b.pdf | Mr. Dai looks at the possibilities of strength... | Speech by the Governor of the People's Bank of... | 1996-11-13 00:00:00 | Mr. Dai looks at the possibilities of strength... | Dai Xianglong | China | [\"mr. dai looks at the possibilities of streng... |
| 2 | https://www.bis.org/review/r970211a.pdf | Mr. Dai assesses the outlook for Hong Kong as ... | Speech by the Governor of the People's Bank of... | 1996-09-30 00:00:00 | Mr. Dai assesses the outlook for Hong Kong as ... | Dai Xianglong | China | [\"mr. dai assesses the outlook for hong kong a... |
| 3 | https://www.bis.org/review/r970203b.pdf | Mr. Rangarajan examines the objectives of mone... | Address by the Governor of the Reserve Bank of... | 1996-12-28 00:00:00 | Mr. Rangarajan examines the objectives of mone... | Bimal Jalan | India | [\"mr. rangarajan examines the objectives of mo... |
| 4 | https://www.bis.org/review/r970115a.pdf | M. Trichet presents the monetary policy guidel... | BANK OF FRANCE, PRESS RELEASE, 17/12/96. | 1996-12-17 00:00:00 | M. Trichet presents the monetary policy guidel... | Bank of France | France | ['m. trichet presents the monetary policy guid... |
\n", "
" ], "text/plain": [ " url \\\n", "0 https://www.bis.org/review/r970211c.pdf \n", "1 https://www.bis.org/review/r970211b.pdf \n", "2 https://www.bis.org/review/r970211a.pdf \n", "3 https://www.bis.org/review/r970203b.pdf \n", "4 https://www.bis.org/review/r970115a.pdf \n", "\n", " title \\\n", "0 Mr. Chen discusses monetary relations between ... \n", "1 Mr. Dai looks at the possibilities of strength... \n", "2 Mr. Dai assesses the outlook for Hong Kong as ... \n", "3 Mr. Rangarajan examines the objectives of mone... \n", "4 M. Trichet presents the monetary policy guidel... \n", "\n", " description date \\\n", "0 Speech by the Deputy Governor of the People's ... 1996-09-10 00:00:00 \n", "1 Speech by the Governor of the People's Bank of... 1996-11-13 00:00:00 \n", "2 Speech by the Governor of the People's Bank of... 1996-09-30 00:00:00 \n", "3 Address by the Governor of the Reserve Bank of... 1996-12-28 00:00:00 \n", "4 BANK OF FRANCE, PRESS RELEASE, 17/12/96. 1996-12-17 00:00:00 \n", "\n", " text author country \\\n", "0 Mr. Chen discusses monetary relations between ... Chen Yuan China \n", "1 Mr. Dai looks at the possibilities of strength... Dai Xianglong China \n", "2 Mr. Dai assesses the outlook for Hong Kong as ... Dai Xianglong China \n", "3 Mr. Rangarajan examines the objectives of mone... Bimal Jalan India \n", "4 M. Trichet presents the monetary policy guidel... Bank of France France \n", "\n", " processed_text \n", "0 [\"mr. chen discusses monetary relations betwee... \n", "1 [\"mr. dai looks at the possibilities of streng... \n", "2 [\"mr. dai assesses the outlook for hong kong a... \n", "3 [\"mr. rangarajan examines the objectives of mo... \n", "4 ['m. trichet presents the monetary policy guid... 
" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "\n", "df = pd.read_csv('/kaggle/input/bis-speeches/speeches_data_preprocessed.csv')\n", "print(df.shape)\n", "df.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Tokenize BIS Sentences for MLM Training\n", "\n", "This section prepares the preprocessed central bank speech sentences for masked language modeling (MLM) by:\n", "\n", "- Flattening over 2 million cleaned sentences into a single list.\n", "- Converting them into a Hugging Face `Dataset` object.\n", "- Tokenizing using the `bert-base-uncased` tokenizer with:\n", " - `max_length=128` (chosen based on sentence length distribution: ~99% of sentences fall within this limit),\n", " - truncation and padding enabled.\n", "- Applying tokenization in parallel using `num_proc=4` for efficiency.\n", "- Saving the tokenized dataset locally for later training use.\n", "\n", "The tokenized dataset is saved to:\n", "\n", "```\n", "/kaggle/working/tokenized_bis_dataset\n", "```\n", "\n", "This ensures the input is consistently preprocessed and optimally sized for efficient MLM training.\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2025-07-19T17:12:20.157315Z", "iopub.status.busy": "2025-07-19T17:12:20.157050Z", "iopub.status.idle": "2025-07-19T17:16:28.711161Z", "shell.execute_reply": "2025-07-19T17:16:28.710348Z", "shell.execute_reply.started": "2025-07-19T17:12:20.157296Z" }, "trusted": true }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a51d2c2858e644a585bd2c6e07b2d618", "version_major": 2, "version_minor": 0 }, "text/plain": [ "tokenizer_config.json: 0%| | 0.00/48.0 [00:00\n", " \n", " \n", " [65238/65238 8:18:46, Epoch 1/1]\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
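| Step | Training Loss |
| --- | --- |
| 200 | 2.361400 |
| 400 | 2.249100 |
| 600 | 2.250100 |
| 800 | 2.226300 |
| 1000 | 2.183100 |
| 1200 | 2.198800 |
| 1400 | 2.145000 |
| 1600 | 2.174700 |
| 1800 | 2.134100 |
| 2000 | 2.133000 |
| 2200 | 2.108300 |
| 2400 | 2.088800 |
| 2600 | 2.109000 |
| 2800 | 2.093700 |
| 3000 | 2.081200 |
| 3200 | 2.078700 |
| 3400 | 2.106100 |
| 3600 | 2.068000 |
| 3800 | 2.068400 |
| 4000 | 2.053400 |
| 4200 | 2.038000 |
| 4400 | 2.060000 |
| 4600 | 2.049400 |
| 4800 | 2.038100 |
| 5000 | 2.023700 |
| 5200 | 2.031100 |
| 5400 | 2.050300 |
| 5600 | 2.021300 |
| 5800 | 2.022600 |
| 6000 | 1.988000 |
| 6200 | 2.005500 |
| 6400 | 2.015700 |
| 6600 | 1.987600 |
| 6800 | 1.994200 |
| 7000 | 1.997800 |
| 7200 | 1.981400 |
| 7400 | 1.982900 |
| 7600 | 2.018900 |
| 7800 | 1.989700 |
| 8000 | 1.985400 |
| 8200 | 1.960200 |
| 8400 | 1.977300 |
| 8600 | 1.970100 |
| 8800 | 1.979800 |
| 9000 | 1.978400 |
| 9200 | 1.942400 |
| 9400 | 1.962300 |
| 9600 | 1.941700 |
| 9800 | 1.931800 |
| 10000 | 1.940800 |
| 10200 | 1.919900 |
| 10400 | 1.899700 |
| 10600 | 1.963000 |
| 10800 | 1.986000 |
| 11000 | 1.922700 |
| 11200 | 1.938600 |
| 11400 | 1.940800 |
| 11600 | 1.950300 |
| 11800 | 1.939500 |
| 12000 | 1.938500 |
| 12200 | 1.934400 |
| 12400 | 1.933600 |
| 12600 | 1.891400 |
| 12800 | 1.962800 |
| 13000 | 1.916100 |
| 13200 | 1.912500 |
| 13400 | 1.880500 |
| 13600 | 1.910800 |
| 13800 | 1.907600 |
| 14000 | 1.913600 |
| 14200 | 1.898900 |
| 14400 | 1.934200 |
| 14600 | 1.896200 |
| 14800 | 1.936400 |
| 15000 | 1.903800 |
| 15200 | 1.881700 |
| 15400 | 1.889800 |
| 15600 | 1.887100 |
| 15800 | 1.881200 |
| 16000 | 1.871600 |
| 16200 | 1.870300 |
| 16400 | 1.884200 |
| 16600 | 1.884400 |
| 16800 | 1.841900 |
| 17000 | 1.875600 |
| 17200 | 1.849700 |
| 17400 | 1.854700 |
| 17600 | 1.868200 |
| 17800 | 1.853200 |
| 18000 | 1.857000 |
| 18200 | 1.924400 |
| 18400 | 1.885000 |
| 18600 | 1.873600 |
| 18800 | 1.873100 |
| 19000 | 1.868300 |
| 19200 | 1.873800 |
| 19400 | 1.870100 |
| 19600 | 1.868400 |
| 19800 | 1.834900 |
| 20000 | 1.840700 |
| 20200 | 1.846400 |
| 20400 | 1.856500 |
| 20600 | 1.859000 |
| 20800 | 1.873700 |
| 21000 | 1.820800 |
| 21200 | 1.849200 |
| 21400 | 1.839000 |
| 21600 | 1.833900 |
| 21800 | 1.841400 |
| 22000 | 1.827000 |
| 22200 | 1.858900 |
| 22400 | 1.825700 |
| 22600 | 1.845400 |
| 22800 | 1.820800 |
| 23000 | 1.829700 |
| 23200 | 1.834400 |
| 23400 | 1.822500 |
| 23600 | 1.812600 |
| 23800 | 1.803200 |
| 24000 | 1.817000 |
| 24200 | 1.829300 |
| 24400 | 1.821900 |
| 24600 | 1.829200 |
| 24800 | 1.838800 |
| 25000 | 1.846000 |
| 25200 | 1.810500 |
| 25400 | 1.791800 |
| 25600 | 1.832400 |
| 25800 | 1.806300 |
| 26000 | 1.815600 |
| 26200 | 1.783200 |
| 26400 | 1.796400 |
| 26600 | 1.800400 |
| 26800 | 1.775000 |
| 27000 | 1.795000 |
| 27200 | 1.821800 |
| 27400 | 1.818200 |
| 27600 | 1.821500 |
| 27800 | 1.823700 |
| 28000 | 1.784400 |
| 28200 | 1.802300 |
| 28400 | 1.793400 |
| 28600 | 1.818000 |
| 28800 | 1.759300 |
| 29000 | 1.765300 |
| 29200 | 1.781000 |
| 29400 | 1.787900 |
| 29600 | 1.801300 |
| 29800 | 1.778400 |
| 30000 | 1.703900 |
| 30200 | 1.808600 |
| 30400 | 1.798500 |
| 30600 | 1.774700 |
| 30800 | 1.769300 |
| 31000 | 1.812800 |
| 31200 | 1.815200 |
| 31400 | 1.763000 |
| 31600 | 1.770900 |
| 31800 | 1.755600 |
| 32000 | 1.774800 |
| 32200 | 1.792300 |
| 32400 | 1.748700 |
| 32600 | 1.764200 |
| 32800 | 1.770000 |
| 33000 | 1.785100 |
| 33200 | 1.772400 |
| 33400 | 1.742800 |
| 33600 | 1.779800 |
| 33800 | 1.722400 |
| 34000 | 1.758500 |
| 34200 | 1.754000 |
| 34400 | 1.787000 |
| 34600 | 1.758700 |
| 34800 | 1.738800 |
| 35000 | 1.734000 |
| 35200 | 1.755200 |
| 35400 | 1.745000 |
| 35600 | 1.737300 |
| 35800 | 1.736600 |
| 36000 | 1.739600 |
| 36200 | 1.718000 |
| 36400 | 1.755300 |
| 36600 | 1.749200 |
| 36800 | 1.757300 |
| 37000 | 1.730600 |
| 37200 | 1.768200 |
| 37400 | 1.735300 |
| 37600 | 1.731500 |
| 37800 | 1.733600 |
| 38000 | 1.712900 |
| 38200 | 1.727000 |
| 38400 | 1.736600 |
| 38600 | 1.710300 |
| 38800 | 1.728400 |
| 39000 | 1.734800 |
| 39200 | 1.726600 |
| 39400 | 1.681900 |
| 39600 | 1.752200 |
| 39800 | 1.702100 |
| 40000 | 1.731000 |
| 40200 | 1.713800 |
| 40400 | 1.719200 |
| 40600 | 1.714400 |
| 40800 | 1.694700 |
| 41000 | 1.747300 |
| 41200 | 1.747600 |
| 41400 | 1.703500 |
| 41600 | 1.723200 |
| 41800 | 1.707700 |
| 42000 | 1.693900 |
| 42200 | 1.703700 |
| 42400 | 1.732700 |
| 42600 | 1.665700 |
| 42800 | 1.710400 |
| 43000 | 1.708900 |
| 43200 | 1.720000 |
| 43400 | 1.690400 |
| 43600 | 1.696600 |
| 43800 | 1.671700 |
| 44000 | 1.705700 |
| 44200 | 1.725100 |
| 44400 | 1.726000 |
| 44600 | 1.700000 |
| 44800 | 1.718800 |
| 45000 | 1.666500 |
| 45200 | 1.715900 |
| 45400 | 1.704800 |
| 45600 | 1.675300 |
| 45800 | 1.718500 |
| 46000 | 1.710300 |
| 46200 | 1.705200 |
| 46400 | 1.675400 |
| 46600 | 1.676400 |
| 46800 | 1.683600 |
| 47000 | 1.669400 |
| 47200 | 1.701700 |
| 47400 | 1.693300 |
| 47600 | 1.707200 |
| 47800 | 1.666400 |
| 48000 | 1.665500 |
| 48200 | 1.668200 |
| 48400 | 1.688100 |
| 48600 | 1.714800 |
| 48800 | 1.653800 |
| 49000 | 1.679800 |
| 49200 | 1.676300 |
| 49400 | 1.709800 |
| 49600 | 1.667600 |
| 49800 | 1.667900 |
| 50000 | 1.656900 |
| 50200 | 1.686600 |
| 50400 | 1.679800 |
| 50600 | 1.667100 |
| 50800 | 1.675700 |
| 51000 | 1.689400 |
| 51200 | 1.682400 |
| 51400 | 1.663600 |
| 51600 | 1.669500 |
| 51800 | 1.653500 |
| 52000 | 1.673900 |
| 52200 | 1.653600 |
| 52400 | 1.650300 |
| 52600 | 1.646600 |
| 52800 | 1.657700 |
| 53000 | 1.665000 |
| 53200 | 1.661700 |
| 53400 | 1.670700 |
| 53600 | 1.643200 |
| 53800 | 1.613200 |
| 54000 | 1.644600 |
| 54200 | 1.667900 |
| 54400 | 1.662500 |
| 54600 | 1.669900 |
| 54800 | 1.677700 |
| 55000 | 1.631500 |
| 55200 | 1.663500 |
| 55400 | 1.656300 |
| 55600 | 1.654600 |
| 55800 | 1.648000 |
| 56000 | 1.657400 |
| 56200 | 1.648000 |
| 56400 | 1.669800 |
| 56600 | 1.642000 |
| 56800 | 1.654600 |
| 57000 | 1.666300 |
| 57200 | 1.646200 |
| 57400 | 1.614200 |
| 57600 | 1.639200 |
| 57800 | 1.660000 |
| 58000 | 1.649900 |
| 58200 | 1.664000 |
| 58400 | 1.638000 |
| 58600 | 1.607500 |
| 58800 | 1.636300 |
| 59000 | 1.652900 |
| 59200 | 1.620800 |
| 59400 | 1.634200 |
| 59600 | 1.628300 |
| 59800 | 1.659300 |
| 60000 | 1.622400 |
| 60200 | 1.660000 |
| 60400 | 1.627900 |
| 60600 | 1.645900 |
| 60800 | 1.647800 |
| 61000 | 1.605900 |
| 61200 | 1.628400 |
| 61400 | 1.623200 |
| 61600 | 1.649100 |
| 61800 | 1.646600 |
| 62000 | 1.642000 |
| 62200 | 1.632000 |
| 62400 | 1.626200 |
| 62600 | 1.653700 |
| 62800 | 1.641200 |
| 63000 | 1.646200 |
| 63200 | 1.617100 |
| 63400 | 1.629300 |
| 63600 | 1.644500 |
| 63800 | 1.665600 |
| 64000 | 1.619800 |
| 64200 | 1.648800 |
| 64400 | 1.601500 |
| 64600 | 1.627900 |
| 64800 | 1.624000 |
| 65000 | 1.648100 |
| 65200 | 1.639600 |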

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "โœ… Training completed at: 2025-07-20 01:35:50.877293\n", "๐ŸŽ‰ Final model saved to /kaggle/working/bert-mlm-bis\n" ] } ], "source": [ "# 1. Install required packages\n", "# !pip install -U transformers datasets --quiet\n", "\n", "# 2. Imports\n", "from transformers import (\n", " BertTokenizerFast,\n", " BertForMaskedLM,\n", " Trainer,\n", " TrainingArguments,\n", " DataCollatorForLanguageModeling\n", ")\n", "from datasets import load_from_disk\n", "from datetime import datetime\n", "import torch\n", "import os\n", "\n", "# 3. Force use of single GPU (for P100)\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n", "\n", "# 4. Load tokenizer and dataset\n", "tokenizer = BertTokenizerFast.from_pretrained(\"bert-base-uncased\")\n", "dataset = load_from_disk(\"/kaggle/working/tokenized_bis_dataset\")\n", "print(f\"โœ… Tokenized dataset loaded with {len(dataset)} samples.\")\n", "\n", "# 5. Load model\n", "model = BertForMaskedLM.from_pretrained(\"bert-base-uncased\")\n", "\n", "# 6. Data collator for MLM\n", "data_collator = DataCollatorForLanguageModeling(\n", " tokenizer=tokenizer,\n", " mlm=True,\n", " mlm_probability=0.15\n", ")\n", "\n", "# 7. 
Training arguments (gradient accumulation + smaller per-device batch)\n", "training_args = TrainingArguments(\n", "    output_dir=\"/kaggle/working/bert-mlm-bis\",\n", "    overwrite_output_dir=True,\n", "    num_train_epochs=1,  # ✅ Full dataset, 1 pass\n", "    per_device_train_batch_size=16,  # ✅ Lower memory per device\n", "    gradient_accumulation_steps=2,  # ✅ Effective batch size = 32\n", "    eval_strategy=\"no\",  # ✅ No eval during training\n", "    save_strategy=\"epoch\",  # ✅ Save once at end\n", "    logging_dir=\"/kaggle/working/logs\",\n", "    logging_steps=200,\n", "    fp16=torch.cuda.is_available(),  # ✅ Mixed precision\n", "    dataloader_num_workers=4,\n", "    save_total_limit=1,\n", "    report_to=\"none\"\n", ")\n", "\n", "# 8. Initialize Trainer\n", "trainer = Trainer(\n", "    model=model,\n", "    args=training_args,\n", "    train_dataset=dataset,\n", "    tokenizer=tokenizer,\n", "    data_collator=data_collator,\n", ")\n", "\n", "# 9. Train\n", "print(\"⏱️ Training started at:\", datetime.now())\n", "trainer.train()\n", "print(\"✅ Training completed at:\", datetime.now())\n", "\n", "# 10. Save final model\n", "trainer.save_model(\"/kaggle/working/bert-mlm-bis\")\n", "print(\"🎉 Final model saved to /kaggle/working/bert-mlm-bis\")\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Evaluate Trained Model and Compute Perplexity\n", "\n", "The quality of the pretrained CB-BERT-MLM model was assessed on a randomly sampled subset of 10,000 sentences from the tokenized dataset. 
This step computes:\n", "\n", "- **Evaluation loss** on masked language modeling (MLM)\n", "- **Perplexity**, a standard metric indicating how confidently the model predicts masked tokens (lower is better)\n", "\n", "```python\n", "import math\n", "\n", "from datasets import load_from_disk\n", "from transformers import (\n", "    AutoModelForMaskedLM,\n", "    AutoTokenizer,\n", "    DataCollatorForLanguageModeling,\n", "    Trainer,\n", "    TrainingArguments,\n", ")\n", "\n", "# Load trained model and tokenizer\n", "model = AutoModelForMaskedLM.from_pretrained(\"/kaggle/working/bert-mlm-bis\")\n", "tokenizer = AutoTokenizer.from_pretrained(\"/kaggle/working/bert-mlm-bis\")\n", "\n", "# Select a subset of 10,000 sentences for quick evaluation\n", "dataset = load_from_disk(\"/kaggle/working/tokenized_bis_dataset\")\n", "eval_dataset = dataset.shuffle(seed=42).select(range(10000))\n", "\n", "# Evaluate with the same 15% masking used during training\n", "trainer = Trainer(\n", "    model=model,\n", "    args=TrainingArguments(output_dir=\"/kaggle/working/tmp_eval\", per_device_eval_batch_size=16, report_to=\"none\"),\n", "    data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15),\n", "    eval_dataset=eval_dataset,\n", ")\n", "metrics = trainer.evaluate()\n", "eval_loss = metrics[\"eval_loss\"]\n", "perplexity = math.exp(eval_loss)\n", "```\n", "\n", "> The **perplexity score** is printed at the end of the cell. A lower perplexity indicates stronger masked token prediction and a better fit to the domain-specific language.\n", "\n", "This provides a quantitative baseline for how well the model reconstructs financial and monetary policy language. Because the subset is drawn from the same tokenized dataset used for training, the score measures fit to the training corpus rather than performance on held-out text.\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2025-07-20T02:00:40.964023Z", "iopub.status.busy": "2025-07-20T02:00:40.963322Z", "iopub.status.idle": "2025-07-20T02:01:32.119053Z", "shell.execute_reply": "2025-07-20T02:01:32.118331Z", "shell.execute_reply.started": "2025-07-20T02:00:40.963997Z" }, "trusted": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "🚀 Using device: cuda\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmp/ipykernel_36/4227637877.py:39: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n", " trainer = Trainer(\n" ] }, { "data": { "text/html": [ "\n", "

\n", " \n", " \n", " [625/625 00:50]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "📉 Evaluation Loss: 1.5392\n", "📊 Perplexity Score (subset of 10000): 4.66\n" ] } ], "source": [ "# 📦 Imports\n", "from transformers import (\n", "    AutoModelForMaskedLM,\n", "    AutoTokenizer,\n", "    DataCollatorForLanguageModeling,\n", "    Trainer,\n", "    TrainingArguments\n", ")\n", "from datasets import load_from_disk\n", "import torch\n", "import math\n", "\n", "# 🧠 Use the GPU if one is available, otherwise fall back to CPU\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(f\"🚀 Using device: {device}\")\n", "\n", "# 🔄 Load model and tokenizer from saved path\n", "model = AutoModelForMaskedLM.from_pretrained(\"/kaggle/working/bert-mlm-bis\").to(device)\n", "tokenizer = AutoTokenizer.from_pretrained(\"/kaggle/working/bert-mlm-bis\")\n", "\n", "# 📂 Load tokenized dataset and sample subset\n", "dataset = load_from_disk(\"/kaggle/working/tokenized_bis_dataset\")\n", "eval_dataset = dataset.shuffle(seed=42).select(range(10000))  # 🔽 reduce for speed\n", "\n", "# Data collator for masked LM\n", "data_collator = DataCollatorForLanguageModeling(\n", "    tokenizer=tokenizer,\n", "    mlm=True,\n", "    mlm_probability=0.15\n", ")\n", "\n", "# ⚙️ Trainer setup\n", "training_args = TrainingArguments(\n", "    output_dir=\"/kaggle/working/tmp_eval\",\n", "    per_device_eval_batch_size=16,\n", "    report_to=\"none\"\n", ")\n", "\n", "trainer = Trainer(\n", "    model=model,\n", "    args=training_args,\n", "    data_collator=data_collator,\n", "    eval_dataset=eval_dataset,\n", "    tokenizer=tokenizer,\n", ")\n", "\n", "# 📊 Evaluate and compute perplexity\n", "metrics = trainer.evaluate()\n", "eval_loss = metrics[\"eval_loss\"]\n", "perplexity = math.exp(eval_loss)\n", "\n", "print(f\"📉 Evaluation Loss: {eval_loss:.4f}\")\n", "print(f\"📊 Perplexity Score (subset of 10000): {perplexity:.2f}\")\n" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ "### Compare Perplexity: BERT-Base vs CB-BERT-MLM\n", "\n", "This section evaluates and compares the perplexity of the original `bert-base-uncased` model and the domain-adapted `cb-bert-mlm` on a subset of 10,000 masked sentences from the BIS corpus.\n", "\n", "#### Evaluation Setup:\n", "- Both models use the same evaluation subset and masking strategy (MLM probability = 15%)\n", "- Performed on GPU (P100) with batch size 16\n", "- Perplexity is calculated from the evaluation loss: `perplexity = exp(loss)`\n", "\n", "#### Output:\n", "- Perplexity scores are printed for both models\n", "- Lower perplexity indicates better performance in masked token prediction on financial text\n", "\n", "This comparison highlights the impact of domain adaptation through MLM pretraining on central bank communication data." ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2025-07-20T02:02:58.622861Z", "iopub.status.busy": "2025-07-20T02:02:58.622182Z", "iopub.status.idle": "2025-07-20T02:04:40.524560Z", "shell.execute_reply": "2025-07-20T02:04:40.523804Z", "shell.execute_reply.started": "2025-07-20T02:02:58.622839Z" }, "trusted": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "๐Ÿ“Š Evaluating: BERT-Base\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']\n", "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "/tmp/ipykernel_36/810192027.py:37: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n", " trainer = Trainer(\n" ] }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [625/625 00:50]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "📉 Eval Loss: 2.5698\n", "Perplexity: 13.06\n", "\n", "📊 Evaluating: BIS-BERT-MLM\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmp/ipykernel_36/810192027.py:37: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n", " trainer = Trainer(\n" ] }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [625/625 00:50]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "📉 Eval Loss: 1.5392\n", "Perplexity: 4.66\n", "\n", "🧾 Summary:\n", "➡️ BERT-Base Perplexity : 13.06\n", "➡️ BIS-BERT-MLM Perplexity : 4.66\n" ] } ], "source": [ "from transformers import (\n", "    AutoModelForMaskedLM,\n", "    AutoTokenizer,\n", "    DataCollatorForLanguageModeling,\n", "    Trainer,\n", "    TrainingArguments\n", ")\n", "from datasets import load_from_disk\n", "import math\n", "import torch\n", "\n", "# ✅ Load the tokenized dataset (use a subset for fast eval)\n", "dataset = load_from_disk(\"/kaggle/working/tokenized_bis_dataset\")\n", "eval_dataset = dataset.shuffle(seed=42).select(range(10000))  # adjust size if needed\n", "\n", "# ✅ Common data collator for both models\n", "def get_data_collator(tokenizer):\n", "    return DataCollatorForLanguageModeling(\n", "        tokenizer=tokenizer,\n", "        mlm=True,\n", "        mlm_probability=0.15\n", "    )\n", "\n", "# Evaluation function: returns the perplexity of one model on eval_dataset\n", "def evaluate_perplexity(model_path, label):\n", "    print(f\"\\n📊 Evaluating: {label}\")\n", "    tokenizer = AutoTokenizer.from_pretrained(model_path)\n", "    # Use the GPU if one is available, otherwise fall back to CPU\n", "    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "    model = AutoModelForMaskedLM.from_pretrained(model_path).to(device)\n", "\n", "    collator = get_data_collator(tokenizer)\n", "    args = TrainingArguments(\n", "        output_dir=\"/kaggle/working/tmp_eval_\" + label.replace(\"-\", \"_\"),\n", "        per_device_eval_batch_size=16,\n", "        report_to=\"none\"\n", "    )\n", "\n", "    trainer = Trainer(\n", "        model=model,\n", "        args=args,\n", "        eval_dataset=eval_dataset,\n", "        data_collator=collator,\n", "        tokenizer=tokenizer\n", "    )\n", "\n", "    metrics = trainer.evaluate()\n", "    loss = metrics[\"eval_loss\"]\n", "    perplexity = math.exp(loss)\n", "\n", "    print(f\"📉 Eval Loss: {loss:.4f}\")\n", "    print(f\"Perplexity: {perplexity:.2f}\")\n", "    return perplexity\n", "\n", "# ⚖️ Compare both models\n", "p1 = 
evaluate_perplexity(\"bert-base-uncased\", \"BERT-Base\")\n", "p2 = evaluate_perplexity(\"/kaggle/working/bert-mlm-bis\", \"BIS-BERT-MLM\")\n", "\n", "# 📈 Summary\n", "print(\"\\n🧾 Summary:\")\n", "print(f\"➡️ BERT-Base Perplexity : {p1:.2f}\")\n", "print(f\"➡️ BIS-BERT-MLM Perplexity : {p2:.2f}\")\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Manual Masked Sentence Evaluation\n", "\n", "This section tests the `cb-bert-mlm` model on 20 manually constructed masked sentences based on real central banking and financial policy language.\n", "\n", "Each sentence contains a single `[MASK]` token. A prediction counts as a match only if the model's top-1 token is exactly the expected token.\n", "\n", "#### Evaluation Highlights:\n", "- Sentences represent realistic use cases in financial regulation, digital currency, and monetary policy\n", "- The top-1 prediction matched the expected token exactly in 10 of the 20 sentences\n", "- Most mismatches were plausible alternatives (e.g., `stability` for `inclusion`, `flows` for `inflows`)\n", "\n", "The test suggests that the model has picked up domain-specific usage, particularly terminology common in central bank communication.\n", "Results are displayed in a table showing the masked sentence, expected token, predicted token, and whether they matched exactly.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": { "iopub.execute_input": "2025-07-20T02:15:43.730657Z", "iopub.status.busy": "2025-07-20T02:15:43.730330Z", "iopub.status.idle": "2025-07-20T02:15:45.482523Z", "shell.execute_reply": "2025-07-20T02:15:45.481827Z", "shell.execute_reply.started": "2025-07-20T02:15:43.730635Z" }, "trusted": true }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
|  | Sentence | Expected | Predicted | Match? |
|---|---|---|---|---|
| 0 | Central banks are exploring the potential of d... | currencies | ##isation | ❌ |
| 1 | The governor highlighted the importance of mon... | policy | policy | ✅ |
| 2 | Inflation expectations remain [MASK] anchored ... | well | well | ✅ |
| 3 | Cross-border [MASK] are still slow and expensive. | payments | payments | ✅ |
| 4 | Financial [MASK] is a key objective for many c... | inclusion | stability | ❌ |
| 5 | Stablecoins pose new [MASK] for regulators and... | challenges | challenges | ✅ |
| 6 | Monetary [MASK] must adapt to technological in... | policy | policy | ✅ |
| 7 | The BIS supports the development of secure dig... | payment | payment | ✅ |
| 8 | Central banks need to coordinate on [MASK] fra... | regulatory | these | ❌ |
| 9 | Emerging markets are experiencing strong capit... | inflows | flows | ❌ |
| 10 | The committee emphasized the need for macropru... | oversight | policies | ❌ |
| 11 | Tokenization of [MASK] could transform financi... | assets | risk | ❌ |
| 12 | Interoperability between payment [MASK] is cru... | systems | systems | ✅ |
| 13 | Cybersecurity [MASK] increase with digital fin... | risks | risks | ✅ |
| 14 | Central banks must ensure [MASK] in digital in... | resilience | trust | ❌ |
| 15 | The future of [MASK] may involve public and pr... | money | finance | ❌ |
| 16 | Pilot [MASK] help central banks understand new... | projects | exercises | ❌ |
| 17 | Legal frameworks need to [MASK] for modern fin... | evolve | evolve | ✅ |
| 18 | Foreign exchange [MASK] have remained relative... | markets | reserves | ❌ |
| 19 | The central bank raised its key interest [MASK... | rate | rate | ✅ |
\n", "
" ], "text/plain": [ " Sentence Expected Predicted \\\n", "0 Central banks are exploring the potential of d... currencies ##isation \n", "1 The governor highlighted the importance of mon... policy policy \n", "2 Inflation expectations remain [MASK] anchored ... well well \n", "3 Cross-border [MASK] are still slow and expensive. payments payments \n", "4 Financial [MASK] is a key objective for many c... inclusion stability \n", "5 Stablecoins pose new [MASK] for regulators and... challenges challenges \n", "6 Monetary [MASK] must adapt to technological in... policy policy \n", "7 The BIS supports the development of secure dig... payment payment \n", "8 Central banks need to coordinate on [MASK] fra... regulatory these \n", "9 Emerging markets are experiencing strong capit... inflows flows \n", "10 The committee emphasized the need for macropru... oversight policies \n", "11 Tokenization of [MASK] could transform financi... assets risk \n", "12 Interoperability between payment [MASK] is cru... systems systems \n", "13 Cybersecurity [MASK] increase with digital fin... risks risks \n", "14 Central banks must ensure [MASK] in digital in... resilience trust \n", "15 The future of [MASK] may involve public and pr... money finance \n", "16 Pilot [MASK] help central banks understand new... projects exercises \n", "17 Legal frameworks need to [MASK] for modern fin... evolve evolve \n", "18 Foreign exchange [MASK] have remained relative... markets reserves \n", "19 The central bank raised its key interest [MASK... rate rate \n", "\n", " Match? 
\n", "0 ❌ \n", "1 ✅ \n", "2 ✅ \n", "3 ✅ \n", "4 ❌ \n", "5 ✅ \n", "6 ✅ \n", "7 ✅ \n", "8 ❌ \n", "9 ❌ \n", "10 ❌ \n", "11 ❌ \n", "12 ✅ \n", "13 ✅ \n", "14 ❌ \n", "15 ❌ \n", "16 ❌ \n", "17 ✅ \n", "18 ❌ \n", "19 ✅ " ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from transformers import BertForMaskedLM, BertTokenizerFast\n", "import torch\n", "import pandas as pd\n", "from IPython.display import display\n", "\n", "# 1) Load trained MLM\n", "model_path = \"/kaggle/working/bert-mlm-bis\"\n", "tokenizer = BertTokenizerFast.from_pretrained(model_path)\n", "model = BertForMaskedLM.from_pretrained(model_path)\n", "model.eval()\n", "\n", "# 2) Manual masked-sentence test set\n", "masked_data = [\n", "    (\"Central banks are exploring the potential of digital [MASK].\", \"currencies\"),\n", "    (\"The governor highlighted the importance of monetary [MASK] transparency.\", \"policy\"),\n", "    (\"Inflation expectations remain [MASK] anchored across most economies.\", \"well\"),\n", "    (\"Cross-border [MASK] are still slow and expensive.\", \"payments\"),\n", "    (\"Financial [MASK] is a key objective for many central banks.\", \"inclusion\"),\n", "    (\"Stablecoins pose new [MASK] for regulators and policymakers.\", \"challenges\"),\n", "    (\"Monetary [MASK] must adapt to technological innovation.\", \"policy\"),\n", "    (\"The BIS supports the development of secure digital [MASK] systems.\", \"payment\"),\n", "    (\"Central banks need to coordinate on [MASK] frameworks.\", \"regulatory\"),\n", "    (\"Emerging markets are experiencing strong capital [MASK].\", \"inflows\"),\n", "    (\"The committee emphasized the need for macroprudential [MASK].\", \"oversight\"),\n", "    (\"Tokenization of [MASK] could transform financial markets.\", \"assets\"),\n", "    (\"Interoperability between payment [MASK] is crucial.\", \"systems\"),\n", "    (\"Cybersecurity [MASK] increase with digital financial services.\", \"risks\"),\n", "    (\"Central banks must 
ensure [MASK] in digital infrastructure.\", \"resilience\"),\n", "    (\"The future of [MASK] may involve public and private sector collaboration.\", \"money\"),\n", "    (\"Pilot [MASK] help central banks understand new financial instruments.\", \"projects\"),\n", "    (\"Legal frameworks need to [MASK] for modern financial technology.\", \"evolve\"),\n", "    (\"Foreign exchange [MASK] have remained relatively stable.\", \"markets\"),\n", "    (\"The central bank raised its key interest [MASK] by 25 basis points.\", \"rate\"),\n", "]\n", "\n", "# 3) Run predictions\n", "results = []\n", "for sent, true_word in masked_data:\n", "    # encode the sentence and locate the [MASK] position\n", "    inputs = tokenizer(sent, return_tensors=\"pt\")\n", "    mask_index = torch.where(inputs.input_ids[0] == tokenizer.mask_token_id)[0]\n", "\n", "    # forward pass\n", "    with torch.no_grad():\n", "        logits = model(**inputs).logits\n", "\n", "    # pick top-1\n", "    token_id = logits[0, mask_index, :].argmax(dim=-1).item()\n", "    pred = tokenizer.decode([token_id]).strip()\n", "\n", "    results.append({\n", "        \"Sentence\": sent,\n", "        \"Expected\": true_word,\n", "        \"Predicted\": pred,\n", "        \"Match?\": \"✅\" if pred.lower() == true_word.lower() else \"❌\"\n", "    })\n", "\n", "# 4) Show as DataFrame\n", "df = pd.DataFrame(results)\n", "display(df)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Top-K Accuracy Evaluation on 100,000 Randomly Masked Sentences\n", "\n", "This section evaluates the `cb-bert-mlm` model's ability to recover randomly masked words in context across **100,000 test sentences**.\n", "\n", "#### Procedure:\n", "\n", "1. **Sentence Sampling**  \n", "   100,000 random sentences were sampled from the BIS preprocessed dataset.\n", "\n", "2. **Masking Strategy**  \n", "   In each sentence, one randomly chosen eligible word (purely alphabetic) was replaced with `[MASK]`. Sentences with fewer than five words were skipped.\n", "\n", "3. 
**Prediction**  \n", "   The model generated **Top-K token predictions** for the masked position, with `k` ranging from 1 to 20.\n", "\n", "4. **Accuracy Computation**  \n", "   A prediction is considered correct if the original word appears among the top-`k` predicted tokens. Because the model predicts a single WordPiece token, a masked word that the tokenizer splits into several pieces cannot be matched and counts as a miss. The accuracy is computed as:  \n", "   \\[\n", "   \\text{Top-k Accuracy} = \\frac{\\text{\\# correct predictions}}{\\text{total samples}} \\times 100\n", "   \\]\n", "\n", "\n", "#### Results:\n", "\n", "> *Exact values are printed at the end of the cell and visualized in the curve below.*\n", "\n", "\n", "#### Top-K Accuracy Curve\n", "\n", "A line plot visualizes model performance across increasing values of `k`, showing how quickly accuracy levels off as `k` grows.\n", "\n", "On this benchmark the model reaches **63.84% Top-1**, **83.10% Top-5**, and **90.46% Top-20** accuracy on masked financial-domain tokens.\n" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "execution": { "iopub.execute_input": "2025-07-20T02:42:10.239306Z", "iopub.status.busy": "2025-07-20T02:42:10.239002Z", "iopub.status.idle": "2025-07-20T03:00:16.278533Z", "shell.execute_reply": "2025-07-20T03:00:16.277747Z", "shell.execute_reply.started": "2025-07-20T02:42:10.239285Z" }, "trusted": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "⚙️ Using device: cuda\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Evaluating (Top-k): 100%|██████████| 100000/100000 [17:43<00:00, 94.01it/s]\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Top- 1 Accuracy: 63.84%\n", "Top- 2 Accuracy: 74.24%\n", "Top- 3 Accuracy: 78.77%\n", "Top- 4 Accuracy: 81.41%\n", "Top- 5 Accuracy: 83.10%\n", "Top- 6 Accuracy: 84.45%\n", "Top- 7 Accuracy: 85.43%\n", "Top- 8 Accuracy: 86.25%\n", "Top- 9 Accuracy: 86.90%\n", "Top-10 Accuracy: 87.49%\n", "Top-11 Accuracy: 87.94%\n", "Top-12 Accuracy: 88.37%\n", "Top-13 Accuracy: 88.75%\n", "Top-14 Accuracy: 89.07%\n", "Top-15 Accuracy: 89.33%\n", "Top-16 Accuracy: 89.59%\n", "Top-17 Accuracy: 89.85%\n", "Top-18 Accuracy: 90.07%\n", "Top-19 Accuracy: 90.28%\n", "Top-20 Accuracy: 90.46%\n" ] } ], "source": [ "import ast\n", "import pandas as pd\n", "import random\n", "import torch\n", "from transformers import BertTokenizerFast, BertForMaskedLM\n", "from tqdm import tqdm\n", "import matplotlib.pyplot as plt\n", "\n", "# ===============================\n", "# 🔹 Step 1: Load raw BIS sentences\n", "# ===============================\n", "df = pd.read_csv(\"/kaggle/input/bis-speeches/speeches_data_preprocessed.csv\")\n", "df = df[df[\"processed_text\"].notna()]\n", "# Each cell stores a stringified list of sentences; literal_eval parses it\n", "# without executing arbitrary code (unlike eval)\n", "df[\"processed_text\"] = df[\"processed_text\"].apply(ast.literal_eval)\n", "sentences = [sentence for sublist in df[\"processed_text\"] for sentence in sublist]\n", "\n", "# ===============================\n", "# 🔹 Step 2: Setup device, model & tokenizer\n", "# ===============================\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(\"⚙️ Using device:\", device)\n", "\n", "model_path = \"/kaggle/working/bert-mlm-bis\"\n", "tokenizer = BertTokenizerFast.from_pretrained(model_path)\n", "model = BertForMaskedLM.from_pretrained(model_path).to(device)\n", "model.eval()\n", "\n", "# ===============================\n", "# 🔹 Step 3: Function to mask one word in a sentence\n", "# ===============================\n", "def mask_random_word(sentence):\n", "    words = sentence.strip().split()\n", "    
if len(words) < 5:\n", "        return None\n", "    # choose only alphabetic tokens\n", "    candidates = [i for i, w in enumerate(words) if w.isalpha()]\n", "    if not candidates:\n", "        return None\n", "    idx = random.choice(candidates)\n", "    true_word = words[idx]\n", "    words[idx] = \"[MASK]\"\n", "    return \" \".join(words), true_word\n", "\n", "# ===============================\n", "# 🔹 Step 4: Generate 100,000 masked test samples\n", "# ===============================\n", "masked_samples = []\n", "for sent in random.sample(sentences, len(sentences)):\n", "    pair = mask_random_word(sent)\n", "    if pair:\n", "        masked_samples.append(pair)\n", "    if len(masked_samples) >= 100000:\n", "        break\n", "\n", "df_masked = pd.DataFrame(masked_samples, columns=[\"Sentence with [MASK]\", \"Masked Word\"])\n", "\n", "# ===============================\n", "# 🔹 Step 5: Evaluate Top-k Accuracy\n", "# ===============================\n", "results = []\n", "max_k = 20\n", "\n", "for _, row in tqdm(df_masked.iterrows(), total=len(df_masked), desc=\"Evaluating (Top-k)\"):\n", "    masked_sentence = row[\"Sentence with [MASK]\"]\n", "    true_word = row[\"Masked Word\"].lower().strip()\n", "\n", "    # Tokenize with truncation & padding\n", "    inputs = tokenizer(\n", "        masked_sentence,\n", "        return_tensors=\"pt\",\n", "        truncation=True,\n", "        max_length=128,\n", "        padding=\"max_length\"\n", "    ).to(device)\n", "\n", "    # Skip samples where truncation removed the [MASK] token\n", "    mask_indices = torch.where(inputs.input_ids[0] == tokenizer.mask_token_id)[0]\n", "    if len(mask_indices) != 1:\n", "        continue\n", "    mask_idx = mask_indices.item()\n", "\n", "    # Forward pass\n", "    with torch.no_grad():\n", "        outputs = model(**inputs)\n", "        logits = outputs.logits\n", "\n", "    # Get top-k predictions\n", "    mask_logits = logits[0, mask_idx]\n", "    topk = torch.topk(mask_logits, k=max_k).indices.tolist()\n", "    top_tokens = [tokenizer.decode([tid]).strip().lower() for tid in topk]\n", "\n", "    results.append({\n", "        \"Masked Word\": true_word,\n", "        \"Top-k Predictions\": 
top_tokens\n", "    })\n", "\n", "# ===============================\n", "# 🔹 Step 6: Compute Top-k Accuracy Curve\n", "# ===============================\n", "k_range = list(range(1, max_k+1))\n", "accuracies = []\n", "total = len(results)\n", "\n", "for k in k_range:\n", "    # a sample is correct at k if the masked word is among the first k predictions\n", "    correct = sum(r[\"Masked Word\"] in r[\"Top-k Predictions\"][:k] for r in results)\n", "    accuracies.append(correct/total*100)\n", "\n", "# ===============================\n", "# 🔹 Step 7: Plot Top-k Curve\n", "# ===============================\n", "plt.figure(figsize=(10,6))\n", "plt.plot(k_range, accuracies, marker='o')\n", "plt.title(\"Top-k Accuracy Curve (BIS-BERT-MLM)\", fontsize=14)\n", "plt.xlabel(\"k\", fontsize=12)\n", "plt.ylabel(\"Accuracy (%)\", fontsize=12)\n", "plt.xticks(k_range)\n", "plt.grid(True)\n", "plt.ylim(0, 100)\n", "plt.show()\n", "\n", "# ===============================\n", "# 🔹 Step 8: Print Summary\n", "# ===============================\n", "for k, acc in zip(k_range, accuracies):\n", "    print(f\"Top-{k:2d} Accuracy: {acc:5.2f}%\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Corpus Statistics and Training Metadata Summary\n", "\n", "This section computes descriptive statistics for the corpus, tokenizer, and model, and documents the training configuration used to pretrain `cb-bert-mlm`.\n", "\n", "These figures support reproducibility and make explicit the scale and setup of the domain-adaptive masked language modeling run." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "✅ Loaded tokenized dataset with 2087615 sentences.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9a15e03981b9432da3ea1226c3269018", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/2087615 [00:00