abdullahzahid77 committed on
Commit
6872957
·
verified ·
1 Parent(s): 7144119

Upload code.ipynb

Browse files
Files changed (1) hide show
  1. code.ipynb +1742 -0
code.ipynb ADDED
@@ -0,0 +1,1742 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "source": [
6
+ "# !pip install torch\n",
7
+ "# !pip install pandas\n",
8
+ "!pip install datasets\n",
9
+ "# !pip install scikit-learn==1.3.0 # Install scikit-learn for metrics calculation\n"
10
+ ],
11
+ "metadata": {
12
+ "colab": {
13
+ "base_uri": "https://localhost:8080/"
14
+ },
15
+ "id": "S5L56kdMJSyW",
16
+ "outputId": "c0aad11f-5182-428a-9ad5-d482c9f552f7"
17
+ },
18
+ "execution_count": 6,
19
+ "outputs": [
20
+ {
21
+ "output_type": "stream",
22
+ "name": "stdout",
23
+ "text": [
24
+ "Collecting datasets\n",
25
+ " Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)\n",
26
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets) (3.16.1)\n",
27
+ "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.26.4)\n",
28
+ "Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (17.0.0)\n",
29
+ "Collecting dill<0.3.9,>=0.3.0 (from datasets)\n",
30
+ " Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)\n",
31
+ "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (2.2.2)\n",
32
+ "Requirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.32.3)\n",
33
+ "Requirement already satisfied: tqdm>=4.66.3 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.67.1)\n",
34
+ "Collecting xxhash (from datasets)\n",
35
+ " Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)\n",
36
+ "Collecting multiprocess<0.70.17 (from datasets)\n",
37
+ " Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)\n",
38
+ "Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)\n",
39
+ " Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)\n",
40
+ "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.11.11)\n",
41
+ "Requirement already satisfied: huggingface-hub>=0.23.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.27.1)\n",
42
+ "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (24.2)\n",
43
+ "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.2)\n",
44
+ "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (2.4.4)\n",
45
+ "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.2)\n",
46
+ "Requirement already satisfied: async-timeout<6.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n",
47
+ "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (24.3.0)\n",
48
+ "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.5.0)\n",
49
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.1.0)\n",
50
+ "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (0.2.1)\n",
51
+ "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.18.3)\n",
52
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.23.0->datasets) (4.12.2)\n",
53
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.4.1)\n",
54
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.10)\n",
55
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2.3.0)\n",
56
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2024.12.14)\n",
57
+ "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n",
58
+ "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.2)\n",
59
+ "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.2)\n",
60
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.17.0)\n",
61
+ "Downloading datasets-3.2.0-py3-none-any.whl (480 kB)\n",
62
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m480.6/480.6 kB\u001b[0m \u001b[31m12.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
63
+ "\u001b[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)\n",
64
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m10.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
65
+ "\u001b[?25hDownloading fsspec-2024.9.0-py3-none-any.whl (179 kB)\n",
66
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m179.3/179.3 kB\u001b[0m \u001b[31m15.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
67
+ "\u001b[?25hDownloading multiprocess-0.70.16-py310-none-any.whl (134 kB)\n",
68
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m13.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
69
+ "\u001b[?25hDownloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)\n",
70
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m19.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
71
+ "\u001b[?25hInstalling collected packages: xxhash, fsspec, dill, multiprocess, datasets\n",
72
+ " Attempting uninstall: fsspec\n",
73
+ " Found existing installation: fsspec 2024.10.0\n",
74
+ " Uninstalling fsspec-2024.10.0:\n",
75
+ " Successfully uninstalled fsspec-2024.10.0\n",
76
+ "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
77
+ "gcsfs 2024.10.0 requires fsspec==2024.10.0, but you have fsspec 2024.9.0 which is incompatible.\u001b[0m\u001b[31m\n",
78
+ "\u001b[0mSuccessfully installed datasets-3.2.0 dill-0.3.8 fsspec-2024.9.0 multiprocess-0.70.16 xxhash-3.5.0\n"
79
+ ]
80
+ }
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "code",
85
+ "execution_count": 7,
86
+ "metadata": {
87
+ "id": "8DBKKDMsIYmE"
88
+ },
89
+ "outputs": [],
90
+ "source": [
91
+ "from transformers import BertTokenizer, BertForSequenceClassification\n",
92
+ "import pandas as pd\n",
93
+ "from datasets import Dataset\n",
94
+ "from datasets import load_dataset\n",
95
+ "from torch.utils.data import DataLoader\n",
96
+ "from transformers import DataCollatorWithPadding\n",
97
+ "from transformers import AdamW\n",
98
+ "from transformers import Trainer, TrainingArguments\n",
99
+ "\n",
100
+ "\n",
101
+ "\n"
102
+ ]
103
+ },
104
+ {
105
+ "source": [
106
+ "import numpy as np\n",
107
+ "from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score\n",
108
+ "\n",
109
+ "def compute_metrics(pred):\n",
110
+ " \"\"\"\n",
111
+ " Computes and returns a dictionary of metrics (accuracy, precision, recall, F1-score).\n",
112
+ " \"\"\"\n",
113
+ " labels = pred.label_ids\n",
114
+ " preds = pred.predictions.argmax(axis=1)\n",
115
+ "\n",
116
+ " accuracy = accuracy_score(labels, preds)\n",
117
+ " precision = precision_score(labels, preds)\n",
118
+ " recall = recall_score(labels, preds)\n",
119
+ " f1 = f1_score(labels, preds)\n",
120
+ "\n",
121
+ " return {\"accuracy\": accuracy, \"precision\": precision, \"recall\": recall, \"f1\": f1}"
122
+ ],
123
+ "cell_type": "code",
124
+ "metadata": {
125
+ "id": "xAMF75YVSJFo"
126
+ },
127
+ "execution_count": null,
128
+ "outputs": []
129
+ },
130
+ {
131
+ "cell_type": "code",
132
+ "source": [
133
+ "from google.colab import drive\n",
134
+ "drive.mount('/content/drive')"
135
+ ],
136
+ "metadata": {
137
+ "colab": {
138
+ "base_uri": "https://localhost:8080/"
139
+ },
140
+ "id": "becVBilbJLXc",
141
+ "outputId": "c720c029-c497-4792-98f4-88284fe41045"
142
+ },
143
+ "execution_count": 1,
144
+ "outputs": [
145
+ {
146
+ "output_type": "stream",
147
+ "name": "stdout",
148
+ "text": [
149
+ "Mounted at /content/drive\n"
150
+ ]
151
+ }
152
+ ]
153
+ },
154
+ {
155
+ "cell_type": "code",
156
+ "execution_count": null,
157
+ "metadata": {
158
+ "id": "SBk5ob4wIYmJ"
159
+ },
160
+ "outputs": [],
161
+ "source": [
162
+ "# Load the dataset from the specified CSV file\n",
163
+ "raw_datasets = pd.read_csv(\"/content/drive/MyDrive/nlp/clickbait_data.csv\")\n"
164
+ ]
165
+ },
166
+ {
167
+ "cell_type": "code",
168
+ "execution_count": null,
169
+ "metadata": {
170
+ "colab": {
171
+ "base_uri": "https://localhost:8080/"
172
+ },
173
+ "id": "PL7q5PKFIYmK",
174
+ "outputId": "d8953fc0-f80c-4ce1-aaea-adde68eef8e0"
175
+ },
176
+ "outputs": [
177
+ {
178
+ "output_type": "stream",
179
+ "name": "stdout",
180
+ "text": [
181
+ " headline clickbait\n",
182
+ "0 Should I Get Bings 1\n",
183
+ "1 Which TV Female Friend Group Do You Belong In 1\n",
184
+ "2 The New \"Star Wars: The Force Awakens\" Trailer... 1\n",
185
+ "3 This Vine Of New York On \"Celebrity Big Brothe... 1\n",
186
+ "4 A Couple Did A Stunning Photo Shoot With Their... 1\n",
187
+ "... ... ...\n",
188
+ "31995 To Make Female Hearts Flutter in Iraq, Throw a... 0\n",
189
+ "31996 British Liberal Democrat Patsy Calton, 56, die... 0\n",
190
+ "31997 Drone smartphone app to help heart attack vict... 0\n",
191
+ "31998 Netanyahu Urges Pope Benedict, in Israel, to D... 0\n",
192
+ "31999 Computer Makers Prepare to Stake Bigger Claim ... 0\n",
193
+ "\n",
194
+ "[32000 rows x 2 columns]\n"
195
+ ]
196
+ }
197
+ ],
198
+ "source": [
199
+ "print(raw_datasets)"
200
+ ]
201
+ },
202
+ {
203
+ "cell_type": "code",
204
+ "execution_count": null,
205
+ "metadata": {
206
+ "id": "ucNw7B3lIYmK"
207
+ },
208
+ "outputs": [],
209
+ "source": [
210
+ "df = pd.DataFrame(raw_datasets, columns=[\"headline\", \"clickbait\"])\n",
211
+ "\n"
212
+ ]
213
+ },
214
+ {
215
+ "cell_type": "code",
216
+ "execution_count": null,
217
+ "metadata": {
218
+ "colab": {
219
+ "base_uri": "https://localhost:8080/"
220
+ },
221
+ "id": "Lev9BlWZIYmK",
222
+ "outputId": "e16cc1cc-f9ee-4b5c-9ddf-42b59569e4a2"
223
+ },
224
+ "outputs": [
225
+ {
226
+ "output_type": "stream",
227
+ "name": "stdout",
228
+ "text": [
229
+ " headline clickbait\n",
230
+ "0 Filipino activist arrested for disrupting Mani... 0\n",
231
+ "1 International Board fixes soccer field size, h... 0\n",
232
+ "2 24 Rules For Women On A First Date With A Man 1\n",
233
+ "3 Political fallout from the sacking of Professo... 0\n",
234
+ "4 Which \"Clueless\" Character Are You Based On Yo... 1\n",
235
+ "... ... ...\n",
236
+ "31995 Rocket strike near hotel in Afghan capital inj... 0\n",
237
+ "31996 How Well Do You Remember The First Episode Of ... 1\n",
238
+ "31997 16 Photos From The Delhi Queer Pride Parade Th... 1\n",
239
+ "31998 33 Of The Most Canadian Sentences Ever 1\n",
240
+ "31999 Man killed after shop robbery in West Yorkshir... 0\n",
241
+ "\n",
242
+ "[32000 rows x 2 columns]\n"
243
+ ]
244
+ }
245
+ ],
246
+ "source": [
247
+ "df = df.sample(frac=1, random_state=42).reset_index(drop=True)\n",
248
+ "\n",
249
+ "# Display the shuffled DataFrame (pandas abbreviates the printed rows)\n",
250
+ "print(df)"
251
+ ]
252
+ },
253
+ {
254
+ "cell_type": "code",
255
+ "execution_count": null,
256
+ "metadata": {
257
+ "id": "JGLQrK2oIYmL"
258
+ },
259
+ "outputs": [],
260
+ "source": [
261
+ "# Step 1: Clean text (lowercase, remove special characters, normalize spaces)\n",
262
+ "df['headline'] = df['headline'].str.lower() # Convert to lowercase\n",
263
+ "df['headline'] = df['headline'].str.replace(r'[^a-z0-9\\s]', '', regex=True) # Remove special characters\n",
264
+ "df['headline'] = df['headline'].str.replace(r'\\s+', ' ', regex=True) # Normalize multiple spaces\n",
265
+ "\n",
266
+ "# Step 2: Split into DatasetDict format (to be used with Hugging Face's `datasets` library)\n",
267
+ "# Convert the cleaned dataframe into a Dataset object for easy tokenization with Hugging Face\n"
268
+ ]
269
+ },
270
+ {
271
+ "cell_type": "code",
272
+ "execution_count": null,
273
+ "metadata": {
274
+ "colab": {
275
+ "base_uri": "https://localhost:8080/",
276
+ "height": 423
277
+ },
278
+ "id": "V0PLQ45hIYmL",
279
+ "outputId": "92045cbe-2f40-4a8a-9141-41b0f3c03a46"
280
+ },
281
+ "outputs": [
282
+ {
283
+ "output_type": "execute_result",
284
+ "data": {
285
+ "text/plain": [
286
+ " headline clickbait\n",
287
+ "0 filipino activist arrested for disrupting mani... 0\n",
288
+ "1 international board fixes soccer field size ha... 0\n",
289
+ "2 24 rules for women on a first date with a man 1\n",
290
+ "3 political fallout from the sacking of professo... 0\n",
291
+ "4 which clueless character are you based on your... 1\n",
292
+ "... ... ...\n",
293
+ "31995 rocket strike near hotel in afghan capital inj... 0\n",
294
+ "31996 how well do you remember the first episode of ... 1\n",
295
+ "31997 16 photos from the delhi queer pride parade th... 1\n",
296
+ "31998 33 of the most canadian sentences ever 1\n",
297
+ "31999 man killed after shop robbery in west yorkshir... 0\n",
298
+ "\n",
299
+ "[32000 rows x 2 columns]"
300
+ ],
301
+ "text/html": [
302
+ "\n",
303
+ " <div id=\"df-d5c6de19-cc3d-4e4a-b949-90707b35d029\" class=\"colab-df-container\">\n",
304
+ " <div>\n",
305
+ "<style scoped>\n",
306
+ " .dataframe tbody tr th:only-of-type {\n",
307
+ " vertical-align: middle;\n",
308
+ " }\n",
309
+ "\n",
310
+ " .dataframe tbody tr th {\n",
311
+ " vertical-align: top;\n",
312
+ " }\n",
313
+ "\n",
314
+ " .dataframe thead th {\n",
315
+ " text-align: right;\n",
316
+ " }\n",
317
+ "</style>\n",
318
+ "<table border=\"1\" class=\"dataframe\">\n",
319
+ " <thead>\n",
320
+ " <tr style=\"text-align: right;\">\n",
321
+ " <th></th>\n",
322
+ " <th>headline</th>\n",
323
+ " <th>clickbait</th>\n",
324
+ " </tr>\n",
325
+ " </thead>\n",
326
+ " <tbody>\n",
327
+ " <tr>\n",
328
+ " <th>0</th>\n",
329
+ " <td>filipino activist arrested for disrupting mani...</td>\n",
330
+ " <td>0</td>\n",
331
+ " </tr>\n",
332
+ " <tr>\n",
333
+ " <th>1</th>\n",
334
+ " <td>international board fixes soccer field size ha...</td>\n",
335
+ " <td>0</td>\n",
336
+ " </tr>\n",
337
+ " <tr>\n",
338
+ " <th>2</th>\n",
339
+ " <td>24 rules for women on a first date with a man</td>\n",
340
+ " <td>1</td>\n",
341
+ " </tr>\n",
342
+ " <tr>\n",
343
+ " <th>3</th>\n",
344
+ " <td>political fallout from the sacking of professo...</td>\n",
345
+ " <td>0</td>\n",
346
+ " </tr>\n",
347
+ " <tr>\n",
348
+ " <th>4</th>\n",
349
+ " <td>which clueless character are you based on your...</td>\n",
350
+ " <td>1</td>\n",
351
+ " </tr>\n",
352
+ " <tr>\n",
353
+ " <th>...</th>\n",
354
+ " <td>...</td>\n",
355
+ " <td>...</td>\n",
356
+ " </tr>\n",
357
+ " <tr>\n",
358
+ " <th>31995</th>\n",
359
+ " <td>rocket strike near hotel in afghan capital inj...</td>\n",
360
+ " <td>0</td>\n",
361
+ " </tr>\n",
362
+ " <tr>\n",
363
+ " <th>31996</th>\n",
364
+ " <td>how well do you remember the first episode of ...</td>\n",
365
+ " <td>1</td>\n",
366
+ " </tr>\n",
367
+ " <tr>\n",
368
+ " <th>31997</th>\n",
369
+ " <td>16 photos from the delhi queer pride parade th...</td>\n",
370
+ " <td>1</td>\n",
371
+ " </tr>\n",
372
+ " <tr>\n",
373
+ " <th>31998</th>\n",
374
+ " <td>33 of the most canadian sentences ever</td>\n",
375
+ " <td>1</td>\n",
376
+ " </tr>\n",
377
+ " <tr>\n",
378
+ " <th>31999</th>\n",
379
+ " <td>man killed after shop robbery in west yorkshir...</td>\n",
380
+ " <td>0</td>\n",
381
+ " </tr>\n",
382
+ " </tbody>\n",
383
+ "</table>\n",
384
+ "<p>32000 rows × 2 columns</p>\n",
385
+ "</div>\n",
386
+ " <div class=\"colab-df-buttons\">\n",
387
+ "\n",
388
+ " <div class=\"colab-df-container\">\n",
389
+ " <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-d5c6de19-cc3d-4e4a-b949-90707b35d029')\"\n",
390
+ " title=\"Convert this dataframe to an interactive table.\"\n",
391
+ " style=\"display:none;\">\n",
392
+ "\n",
393
+ " <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n",
394
+ " <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n",
395
+ " </svg>\n",
396
+ " </button>\n",
397
+ "\n",
398
+ " <style>\n",
399
+ " .colab-df-container {\n",
400
+ " display:flex;\n",
401
+ " gap: 12px;\n",
402
+ " }\n",
403
+ "\n",
404
+ " .colab-df-convert {\n",
405
+ " background-color: #E8F0FE;\n",
406
+ " border: none;\n",
407
+ " border-radius: 50%;\n",
408
+ " cursor: pointer;\n",
409
+ " display: none;\n",
410
+ " fill: #1967D2;\n",
411
+ " height: 32px;\n",
412
+ " padding: 0 0 0 0;\n",
413
+ " width: 32px;\n",
414
+ " }\n",
415
+ "\n",
416
+ " .colab-df-convert:hover {\n",
417
+ " background-color: #E2EBFA;\n",
418
+ " box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
419
+ " fill: #174EA6;\n",
420
+ " }\n",
421
+ "\n",
422
+ " .colab-df-buttons div {\n",
423
+ " margin-bottom: 4px;\n",
424
+ " }\n",
425
+ "\n",
426
+ " [theme=dark] .colab-df-convert {\n",
427
+ " background-color: #3B4455;\n",
428
+ " fill: #D2E3FC;\n",
429
+ " }\n",
430
+ "\n",
431
+ " [theme=dark] .colab-df-convert:hover {\n",
432
+ " background-color: #434B5C;\n",
433
+ " box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
434
+ " filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
435
+ " fill: #FFFFFF;\n",
436
+ " }\n",
437
+ " </style>\n",
438
+ "\n",
439
+ " <script>\n",
440
+ " const buttonEl =\n",
441
+ " document.querySelector('#df-d5c6de19-cc3d-4e4a-b949-90707b35d029 button.colab-df-convert');\n",
442
+ " buttonEl.style.display =\n",
443
+ " google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
444
+ "\n",
445
+ " async function convertToInteractive(key) {\n",
446
+ " const element = document.querySelector('#df-d5c6de19-cc3d-4e4a-b949-90707b35d029');\n",
447
+ " const dataTable =\n",
448
+ " await google.colab.kernel.invokeFunction('convertToInteractive',\n",
449
+ " [key], {});\n",
450
+ " if (!dataTable) return;\n",
451
+ "\n",
452
+ " const docLinkHtml = 'Like what you see? Visit the ' +\n",
453
+ " '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
454
+ " + ' to learn more about interactive tables.';\n",
455
+ " element.innerHTML = '';\n",
456
+ " dataTable['output_type'] = 'display_data';\n",
457
+ " await google.colab.output.renderOutput(dataTable, element);\n",
458
+ " const docLink = document.createElement('div');\n",
459
+ " docLink.innerHTML = docLinkHtml;\n",
460
+ " element.appendChild(docLink);\n",
461
+ " }\n",
462
+ " </script>\n",
463
+ " </div>\n",
464
+ "\n",
465
+ "\n",
466
+ "<div id=\"df-ee00eaa1-2597-4b9f-bfbe-34ee45fe0b29\">\n",
467
+ " <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-ee00eaa1-2597-4b9f-bfbe-34ee45fe0b29')\"\n",
468
+ " title=\"Suggest charts\"\n",
469
+ " style=\"display:none;\">\n",
470
+ "\n",
471
+ "<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
472
+ " width=\"24px\">\n",
473
+ " <g>\n",
474
+ " <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
475
+ " </g>\n",
476
+ "</svg>\n",
477
+ " </button>\n",
478
+ "\n",
479
+ "<style>\n",
480
+ " .colab-df-quickchart {\n",
481
+ " --bg-color: #E8F0FE;\n",
482
+ " --fill-color: #1967D2;\n",
483
+ " --hover-bg-color: #E2EBFA;\n",
484
+ " --hover-fill-color: #174EA6;\n",
485
+ " --disabled-fill-color: #AAA;\n",
486
+ " --disabled-bg-color: #DDD;\n",
487
+ " }\n",
488
+ "\n",
489
+ " [theme=dark] .colab-df-quickchart {\n",
490
+ " --bg-color: #3B4455;\n",
491
+ " --fill-color: #D2E3FC;\n",
492
+ " --hover-bg-color: #434B5C;\n",
493
+ " --hover-fill-color: #FFFFFF;\n",
494
+ " --disabled-bg-color: #3B4455;\n",
495
+ " --disabled-fill-color: #666;\n",
496
+ " }\n",
497
+ "\n",
498
+ " .colab-df-quickchart {\n",
499
+ " background-color: var(--bg-color);\n",
500
+ " border: none;\n",
501
+ " border-radius: 50%;\n",
502
+ " cursor: pointer;\n",
503
+ " display: none;\n",
504
+ " fill: var(--fill-color);\n",
505
+ " height: 32px;\n",
506
+ " padding: 0;\n",
507
+ " width: 32px;\n",
508
+ " }\n",
509
+ "\n",
510
+ " .colab-df-quickchart:hover {\n",
511
+ " background-color: var(--hover-bg-color);\n",
512
+ " box-shadow: 0 1px 2px rgba(60, 64, 67, 0.3), 0 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
513
+ " fill: var(--button-hover-fill-color);\n",
514
+ " }\n",
515
+ "\n",
516
+ " .colab-df-quickchart-complete:disabled,\n",
517
+ " .colab-df-quickchart-complete:disabled:hover {\n",
518
+ " background-color: var(--disabled-bg-color);\n",
519
+ " fill: var(--disabled-fill-color);\n",
520
+ " box-shadow: none;\n",
521
+ " }\n",
522
+ "\n",
523
+ " .colab-df-spinner {\n",
524
+ " border: 2px solid var(--fill-color);\n",
525
+ " border-color: transparent;\n",
526
+ " border-bottom-color: var(--fill-color);\n",
527
+ " animation:\n",
528
+ " spin 1s steps(1) infinite;\n",
529
+ " }\n",
530
+ "\n",
531
+ " @keyframes spin {\n",
532
+ " 0% {\n",
533
+ " border-color: transparent;\n",
534
+ " border-bottom-color: var(--fill-color);\n",
535
+ " border-left-color: var(--fill-color);\n",
536
+ " }\n",
537
+ " 20% {\n",
538
+ " border-color: transparent;\n",
539
+ " border-left-color: var(--fill-color);\n",
540
+ " border-top-color: var(--fill-color);\n",
541
+ " }\n",
542
+ " 30% {\n",
543
+ " border-color: transparent;\n",
544
+ " border-left-color: var(--fill-color);\n",
545
+ " border-top-color: var(--fill-color);\n",
546
+ " border-right-color: var(--fill-color);\n",
547
+ " }\n",
548
+ " 40% {\n",
549
+ " border-color: transparent;\n",
550
+ " border-right-color: var(--fill-color);\n",
551
+ " border-top-color: var(--fill-color);\n",
552
+ " }\n",
553
+ " 60% {\n",
554
+ " border-color: transparent;\n",
555
+ " border-right-color: var(--fill-color);\n",
556
+ " }\n",
557
+ " 80% {\n",
558
+ " border-color: transparent;\n",
559
+ " border-right-color: var(--fill-color);\n",
560
+ " border-bottom-color: var(--fill-color);\n",
561
+ " }\n",
562
+ " 90% {\n",
563
+ " border-color: transparent;\n",
564
+ " border-bottom-color: var(--fill-color);\n",
565
+ " }\n",
566
+ " }\n",
567
+ "</style>\n",
568
+ "\n",
569
+ " <script>\n",
570
+ " async function quickchart(key) {\n",
571
+ " const quickchartButtonEl =\n",
572
+ " document.querySelector('#' + key + ' button');\n",
573
+ " quickchartButtonEl.disabled = true; // To prevent multiple clicks.\n",
574
+ " quickchartButtonEl.classList.add('colab-df-spinner');\n",
575
+ " try {\n",
576
+ " const charts = await google.colab.kernel.invokeFunction(\n",
577
+ " 'suggestCharts', [key], {});\n",
578
+ " } catch (error) {\n",
579
+ " console.error('Error during call to suggestCharts:', error);\n",
580
+ " }\n",
581
+ " quickchartButtonEl.classList.remove('colab-df-spinner');\n",
582
+ " quickchartButtonEl.classList.add('colab-df-quickchart-complete');\n",
583
+ " }\n",
584
+ " (() => {\n",
585
+ " let quickchartButtonEl =\n",
586
+ " document.querySelector('#df-ee00eaa1-2597-4b9f-bfbe-34ee45fe0b29 button');\n",
587
+ " quickchartButtonEl.style.display =\n",
588
+ " google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
589
+ " })();\n",
590
+ " </script>\n",
591
+ "</div>\n",
592
+ "\n",
593
+ " <div id=\"id_dd256d94-7a78-4a73-8f42-8ee61fff3d06\">\n",
594
+ " <style>\n",
595
+ " .colab-df-generate {\n",
596
+ " background-color: #E8F0FE;\n",
597
+ " border: none;\n",
598
+ " border-radius: 50%;\n",
599
+ " cursor: pointer;\n",
600
+ " display: none;\n",
601
+ " fill: #1967D2;\n",
602
+ " height: 32px;\n",
603
+ " padding: 0 0 0 0;\n",
604
+ " width: 32px;\n",
605
+ " }\n",
606
+ "\n",
607
+ " .colab-df-generate:hover {\n",
608
+ " background-color: #E2EBFA;\n",
609
+ " box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
610
+ " fill: #174EA6;\n",
611
+ " }\n",
612
+ "\n",
613
+ " [theme=dark] .colab-df-generate {\n",
614
+ " background-color: #3B4455;\n",
615
+ " fill: #D2E3FC;\n",
616
+ " }\n",
617
+ "\n",
618
+ " [theme=dark] .colab-df-generate:hover {\n",
619
+ " background-color: #434B5C;\n",
620
+ " box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
621
+ " filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
622
+ " fill: #FFFFFF;\n",
623
+ " }\n",
624
+ " </style>\n",
625
+ " <button class=\"colab-df-generate\" onclick=\"generateWithVariable('df')\"\n",
626
+ " title=\"Generate code using this dataframe.\"\n",
627
+ " style=\"display:none;\">\n",
628
+ "\n",
629
+ " <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
630
+ " width=\"24px\">\n",
631
+ " <path d=\"M7,19H8.4L18.45,9,17,7.55,7,17.6ZM5,21V16.75L18.45,3.32a2,2,0,0,1,2.83,0l1.4,1.43a1.91,1.91,0,0,1,.58,1.4,1.91,1.91,0,0,1-.58,1.4L9.25,21ZM18.45,9,17,7.55Zm-12,3A5.31,5.31,0,0,0,4.9,8.1,5.31,5.31,0,0,0,1,6.5,5.31,5.31,0,0,0,4.9,4.9,5.31,5.31,0,0,0,6.5,1,5.31,5.31,0,0,0,8.1,4.9,5.31,5.31,0,0,0,12,6.5,5.46,5.46,0,0,0,6.5,12Z\"/>\n",
632
+ " </svg>\n",
633
+ " </button>\n",
634
+ " <script>\n",
635
+ " (() => {\n",
636
+ " const buttonEl =\n",
637
+ " document.querySelector('#id_dd256d94-7a78-4a73-8f42-8ee61fff3d06 button.colab-df-generate');\n",
638
+ " buttonEl.style.display =\n",
639
+ " google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
640
+ "\n",
641
+ " buttonEl.onclick = () => {\n",
642
+ " google.colab.notebook.generateWithVariable('df');\n",
643
+ " }\n",
644
+ " })();\n",
645
+ " </script>\n",
646
+ " </div>\n",
647
+ "\n",
648
+ " </div>\n",
649
+ " </div>\n"
650
+ ],
651
+ "application/vnd.google.colaboratory.intrinsic+json": {
652
+ "type": "dataframe",
653
+ "variable_name": "df",
654
+ "summary": "{\n \"name\": \"df\",\n \"rows\": 32000,\n \"fields\": [\n {\n \"column\": \"headline\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 31998,\n \"samples\": [\n \"prolific television producer aaron spelling dies at 83\",\n \"consumer prices remained steady in april\",\n \"jon hamm playing baseball will soothe the shit outta you\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"clickbait\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0,\n \"min\": 0,\n \"max\": 1,\n \"num_unique_values\": 2,\n \"samples\": [\n 1,\n 0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}"
655
+ }
656
+ },
657
+ "metadata": {},
658
+ "execution_count": 51
659
+ }
660
+ ],
661
+ "source": [
662
+ "df"
663
+ ]
664
+ },
665
+ {
666
+ "cell_type": "code",
667
+ "execution_count": null,
668
+ "metadata": {
669
+ "colab": {
670
+ "base_uri": "https://localhost:8080/"
671
+ },
672
+ "id": "oB8E9zfkIYmL",
673
+ "outputId": "0250cb50-7efa-4f87-f7ed-a5fca2cecce4"
674
+ },
675
+ "outputs": [
676
+ {
677
+ "output_type": "execute_result",
678
+ "data": {
679
+ "text/plain": [
680
+ "Dataset({\n",
681
+ " features: ['headline', 'clickbait'],\n",
682
+ " num_rows: 32000\n",
683
+ "})"
684
+ ]
685
+ },
686
+ "metadata": {},
687
+ "execution_count": 52
688
+ }
689
+ ],
690
+ "source": [
691
+ "pre_processed_dataset = Dataset.from_pandas(df)\n",
692
+ "pre_processed_dataset"
693
+ ]
694
+ },
695
+ {
696
+ "cell_type": "code",
697
+ "execution_count": null,
698
+ "metadata": {
699
+ "colab": {
700
+ "base_uri": "https://localhost:8080/"
701
+ },
702
+ "id": "oxKb3Y2jIYmM",
703
+ "outputId": "83845fae-7dbe-4857-c713-4359212078db"
704
+ },
705
+ "outputs": [
706
+ {
707
+ "output_type": "stream",
708
+ "name": "stderr",
709
+ "text": [
710
+ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
711
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
712
+ ]
713
+ }
714
+ ],
715
+ "source": [
716
+ "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') #Loads the tokenizer that matches the base BERT model trained on lowercased English text (e.g., \"hello\" and \"Hello\" are treated the same).\n",
717
+ "model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2) #Configures the model to handle a binary classification problem \"clickbait\" vs. \"non-clickbait\".\n"
718
+ ]
719
+ },
720
+ {
721
+ "cell_type": "code",
722
+ "execution_count": null,
723
+ "metadata": {
724
+ "id": "HRzOKDl_IYmM"
725
+ },
726
+ "outputs": [],
727
+ "source": [
728
+ "def tokenize_function(batch):\n",
729
+ " # Tokenize the 'headline' column\n",
730
+ " return tokenizer(batch['headline'], truncation=True, padding=True, max_length=512)"
731
+ ]
732
+ },
733
+ {
734
+ "cell_type": "code",
735
+ "execution_count": null,
736
+ "metadata": {
737
+ "colab": {
738
+ "base_uri": "https://localhost:8080/",
739
+ "height": 49,
740
+ "referenced_widgets": [
741
+ "dad32a0edd9841c9ab1ce052f0efd994",
742
+ "41099ccad51745aeb06c52636ff646ba",
743
+ "00f15ecb12df4d36aed9ef3135a71f08",
744
+ "a1121eb967374766bc15ed290c2dc09d",
745
+ "3b3db9c5766a4fc1908f82d5d86ee86d",
746
+ "86c7d5da9ac94f4aa7019866cb48da30",
747
+ "12e0c1a1d2e545db8c89accffb1072fb",
748
+ "cb1f378c68e3464db6bbe1756ee68b46",
749
+ "272a34ba164e4b2a88aa1d92a42510b8",
750
+ "810299000553408ca96ed7b45a8049b8",
751
+ "93d6dfc7a47d4c76bdbeb868dc0a3369"
752
+ ]
753
+ },
754
+ "id": "2AuxGVJJIYmM",
755
+ "outputId": "4cf193fb-c5ab-4974-aa52-f05beffb4746"
756
+ },
757
+ "outputs": [
758
+ {
759
+ "output_type": "display_data",
760
+ "data": {
761
+ "text/plain": [
762
+ "Map: 0%| | 0/32000 [00:00<?, ? examples/s]"
763
+ ],
764
+ "application/vnd.jupyter.widget-view+json": {
765
+ "version_major": 2,
766
+ "version_minor": 0,
767
+ "model_id": "dad32a0edd9841c9ab1ce052f0efd994"
768
+ }
769
+ },
770
+ "metadata": {}
771
+ }
772
+ ],
773
+ "source": [
774
+ "tokenized_datasets = pre_processed_dataset.map(tokenize_function, batched=True)"
775
+ ]
776
+ },
777
+ {
778
+ "cell_type": "code",
779
+ "execution_count": null,
780
+ "metadata": {
781
+ "colab": {
782
+ "base_uri": "https://localhost:8080/"
783
+ },
784
+ "id": "_mfTpdzeIYmN",
785
+ "outputId": "f32b8858-a65c-451d-ecc4-b8d9bf2296f1"
786
+ },
787
+ "outputs": [
788
+ {
789
+ "output_type": "stream",
790
+ "name": "stdout",
791
+ "text": [
792
+ "Dataset({\n",
793
+ " features: ['headline', 'clickbait', 'input_ids', 'token_type_ids', 'attention_mask'],\n",
794
+ " num_rows: 32000\n",
795
+ "})\n"
796
+ ]
797
+ }
798
+ ],
799
+ "source": [
800
+ "print(tokenized_datasets)"
801
+ ]
802
+ },
803
+ {
804
+ "cell_type": "code",
805
+ "source": [
806
+ "tokenized_datasets = tokenized_datasets.rename_column('clickbait', 'labels')"
807
+ ],
808
+ "metadata": {
809
+ "id": "YoXnzpMjNV9N"
810
+ },
811
+ "execution_count": null,
812
+ "outputs": []
813
+ },
814
+ {
815
+ "cell_type": "code",
816
+ "execution_count": null,
817
+ "metadata": {
818
+ "id": "RaVGS93BIYmN"
819
+ },
820
+ "outputs": [],
821
+ "source": [
822
+ "split_datasets = tokenized_datasets.train_test_split(test_size=0.2)\n",
823
+ "\n",
824
+ "# Further split the train data into train and validation (80% train, 20% validation)\n",
825
+ "train_val_split = split_datasets['train'].train_test_split(test_size=0.2)\n",
826
+ "\n",
827
+ "# Access the splits\n",
828
+ "train_dataset = train_val_split['train']\n",
829
+ "validation_dataset = train_val_split['test']\n",
830
+ "test_dataset = split_datasets['test']"
831
+ ]
832
+ },
833
+ {
834
+ "cell_type": "code",
835
+ "execution_count": null,
836
+ "metadata": {
837
+ "colab": {
838
+ "base_uri": "https://localhost:8080/"
839
+ },
840
+ "id": "8aGzBaiwIYmN",
841
+ "outputId": "bb8e0f0d-caef-4917-ceb0-c2119e387e63"
842
+ },
843
+ "outputs": [
844
+ {
845
+ "output_type": "execute_result",
846
+ "data": {
847
+ "text/plain": [
848
+ "DatasetDict({\n",
849
+ " train: Dataset({\n",
850
+ " features: ['headline', 'labels', 'input_ids', 'token_type_ids', 'attention_mask'],\n",
851
+ " num_rows: 25600\n",
852
+ " })\n",
853
+ " test: Dataset({\n",
854
+ " features: ['headline', 'labels', 'input_ids', 'token_type_ids', 'attention_mask'],\n",
855
+ " num_rows: 6400\n",
856
+ " })\n",
857
+ "})"
858
+ ]
859
+ },
860
+ "metadata": {},
861
+ "execution_count": 59
862
+ }
863
+ ],
864
+ "source": [
865
+ "split_datasets"
866
+ ]
867
+ },
868
+ {
869
+ "cell_type": "code",
870
+ "execution_count": null,
871
+ "metadata": {
872
+ "colab": {
873
+ "base_uri": "https://localhost:8080/"
874
+ },
875
+ "id": "EtWlbJzOIYmN",
876
+ "outputId": "512de166-e642-4099-f390-dc687be9a44a"
877
+ },
878
+ "outputs": [
879
+ {
880
+ "output_type": "execute_result",
881
+ "data": {
882
+ "text/plain": [
883
+ "DatasetDict({\n",
884
+ " train: Dataset({\n",
885
+ " features: ['headline', 'labels', 'input_ids', 'token_type_ids', 'attention_mask'],\n",
886
+ " num_rows: 20480\n",
887
+ " })\n",
888
+ " test: Dataset({\n",
889
+ " features: ['headline', 'labels', 'input_ids', 'token_type_ids', 'attention_mask'],\n",
890
+ " num_rows: 5120\n",
891
+ " })\n",
892
+ "})"
893
+ ]
894
+ },
895
+ "metadata": {},
896
+ "execution_count": 60
897
+ }
898
+ ],
899
+ "source": [
900
+ "train_val_split"
901
+ ]
902
+ },
903
+ {
904
+ "cell_type": "code",
905
+ "execution_count": null,
906
+ "metadata": {
907
+ "colab": {
908
+ "base_uri": "https://localhost:8080/"
909
+ },
910
+ "id": "eWrFzffIIYmO",
911
+ "outputId": "3bed5240-9259-4ee5-c063-2a33c1c96920"
912
+ },
913
+ "outputs": [
914
+ {
915
+ "output_type": "execute_result",
916
+ "data": {
917
+ "text/plain": [
918
+ "Dataset({\n",
919
+ " features: ['headline', 'labels', 'input_ids', 'token_type_ids', 'attention_mask'],\n",
920
+ " num_rows: 6400\n",
921
+ "})"
922
+ ]
923
+ },
924
+ "metadata": {},
925
+ "execution_count": 61
926
+ }
927
+ ],
928
+ "source": [
929
+ "test_dataset"
930
+ ]
931
+ },
932
+ {
933
+ "cell_type": "code",
934
+ "execution_count": null,
935
+ "metadata": {
936
+ "id": "POiHrmeeIYmO"
937
+ },
938
+ "outputs": [],
939
+ "source": [
940
+ "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)\n",
941
+ "train_dataloader = DataLoader(\n",
942
+ " train_dataset, batch_size=16, shuffle=True, collate_fn=data_collator\n",
943
+ ")"
944
+ ]
945
+ },
946
+ {
947
+ "cell_type": "code",
948
+ "execution_count": null,
949
+ "metadata": {
950
+ "colab": {
951
+ "base_uri": "https://localhost:8080/"
952
+ },
953
+ "id": "hOOVVNPwIYmO",
954
+ "outputId": "b6756add-09fe-4a2c-a0c2-751f8fc11194"
955
+ },
956
+ "outputs": [
957
+ {
958
+ "output_type": "stream",
959
+ "name": "stderr",
960
+ "text": [
961
+ "/usr/local/lib/python3.10/dist-packages/transformers/optimization.py:591: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
962
+ " warnings.warn(\n"
963
+ ]
964
+ }
965
+ ],
966
+ "source": [
967
+ "#AdamW: A type of optimizer that updates the model’s weights during training to minimize the loss.\n",
968
+ "#lr=5e-5: Sets the learning rate to 0.00005, controlling how much the model adjusts weights during training.\n",
969
+ "\n",
970
+ "optimizer = AdamW(model.parameters(), lr=5e-5)"
971
+ ]
972
+ },
973
+ {
974
+ "cell_type": "code",
975
+ "source": [
976
+ "import numpy as np\n",
977
+ "from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score\n",
978
+ "\n",
979
+ "def compute_metrics(pred):\n",
980
+ " \"\"\"\n",
981
+ " Computes and returns a dictionary of metrics (accuracy, precision, recall, F1-score).\n",
982
+ " \"\"\"\n",
983
+ " labels = pred.label_ids\n",
984
+ " preds = pred.predictions.argmax(axis=1)\n",
985
+ "\n",
986
+ " accuracy = accuracy_score(labels, preds)\n",
987
+ " precision = precision_score(labels, preds)\n",
988
+ " recall = recall_score(labels, preds)\n",
989
+ " f1 = f1_score(labels, preds)\n",
990
+ "\n",
991
+ " return {\"accuracy\": accuracy, \"precision\": precision, \"recall\": recall, \"f1\": f1}"
992
+ ],
993
+ "metadata": {
994
+ "id": "F9QmrZ6dSRCk"
995
+ },
996
+ "execution_count": null,
997
+ "outputs": []
998
+ },
999
+ {
1000
+ "cell_type": "code",
1001
+ "execution_count": null,
1002
+ "metadata": {
1003
+ "colab": {
1004
+ "base_uri": "https://localhost:8080/"
1005
+ },
1006
+ "id": "Nf8E4KTSIYmP",
1007
+ "outputId": "22929b6e-85d7-4f39-8c6a-2971fbafbc6b"
1008
+ },
1009
+ "outputs": [
1010
+ {
1011
+ "output_type": "stream",
1012
+ "name": "stderr",
1013
+ "text": [
1014
+ "/usr/local/lib/python3.10/dist-packages/transformers/training_args.py:1575: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
1015
+ " warnings.warn(\n"
1016
+ ]
1017
+ }
1018
+ ],
1019
+ "source": [
1020
+ "# output_dir: Directory where the Trainer writes checkpoints and other training outputs (here, a \"results\" folder on Google Drive).\n",
1021
+ "# evaluation_strategy=\"epoch\": Evaluates the model after every epoch (one pass through the dataset).\n",
1022
+ "# learning_rate: Sets the learning rate to 0.00002.\n",
1023
+ "# num_train_epochs=3: Trains for 3 epochs (3 full passes through the training dataset).\n",
1024
+ "# weight_decay=0.01: Helps reduce overfitting by slightly penalizing large model weights.\n",
1025
+ "\n",
1026
+ "training_args = TrainingArguments(\n",
1027
+ " output_dir=\"/content/drive/MyDrive/nlp/results\",\n",
1028
+ " evaluation_strategy=\"epoch\",\n",
1029
+ " learning_rate=2e-5,\n",
1030
+ " per_device_train_batch_size=16,\n",
1031
+ " per_device_eval_batch_size=16,\n",
1032
+ " num_train_epochs=3,\n",
1033
+ " weight_decay=0.01,\n",
1034
+ ")\n"
1035
+ ]
1036
+ },
1037
+ {
1038
+ "cell_type": "code",
1039
+ "execution_count": null,
1040
+ "metadata": {
1041
+ "colab": {
1042
+ "base_uri": "https://localhost:8080/"
1043
+ },
1044
+ "id": "-UvYmvbjIYmP",
1045
+ "outputId": "1c2f2a41-2cdc-45b7-9a38-8c87975973cd"
1046
+ },
1047
+ "outputs": [
1048
+ {
1049
+ "output_type": "stream",
1050
+ "name": "stderr",
1051
+ "text": [
1052
+ "<ipython-input-66-8ac8eaf1a160>:5: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n",
1053
+ " trainer = Trainer(\n"
1054
+ ]
1055
+ }
1056
+ ],
1057
+ "source": [
1058
+ "# Creates a Trainer object that automates:\n",
1059
+ "# Training: Feeds the training dataset into the model.\n",
1060
+ "# Evaluation: Tests the model's performance on the validation dataset.\n",
1061
+ "\n",
1062
+ "trainer = Trainer(\n",
1063
+ " model=model,\n",
1064
+ " args=training_args,\n",
1065
+ " train_dataset=train_dataset,\n",
1066
+ " eval_dataset=validation_dataset,\n",
1067
+ " tokenizer=tokenizer,\n",
1068
+ " data_collator=data_collator,\n",
1069
+ " compute_metrics=compute_metrics\n",
1070
+ "\n",
1071
+ ")\n"
1072
+ ]
1073
+ },
1074
+ {
1075
+ "cell_type": "code",
1076
+ "source": [
1077
+ "trainer.train()"
1078
+ ],
1079
+ "metadata": {
1080
+ "colab": {
1081
+ "base_uri": "https://localhost:8080/",
1082
+ "height": 239
1083
+ },
1084
+ "id": "qz65S7xdKKJi",
1085
+ "outputId": "5a2c06d6-89e3-47f3-cb0c-3f73132f564c"
1086
+ },
1087
+ "execution_count": null,
1088
+ "outputs": [
1089
+ {
1090
+ "output_type": "display_data",
1091
+ "data": {
1092
+ "text/plain": [
1093
+ "<IPython.core.display.HTML object>"
1094
+ ],
1095
+ "text/html": [
1096
+ "\n",
1097
+ " <div>\n",
1098
+ " \n",
1099
+ " <progress value='3840' max='3840' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1100
+ " [3840/3840 09:55, Epoch 3/3]\n",
1101
+ " </div>\n",
1102
+ " <table border=\"1\" class=\"dataframe\">\n",
1103
+ " <thead>\n",
1104
+ " <tr style=\"text-align: left;\">\n",
1105
+ " <th>Epoch</th>\n",
1106
+ " <th>Training Loss</th>\n",
1107
+ " <th>Validation Loss</th>\n",
1108
+ " <th>Accuracy</th>\n",
1109
+ " <th>Precision</th>\n",
1110
+ " <th>Recall</th>\n",
1111
+ " <th>F1</th>\n",
1112
+ " </tr>\n",
1113
+ " </thead>\n",
1114
+ " <tbody>\n",
1115
+ " <tr>\n",
1116
+ " <td>1</td>\n",
1117
+ " <td>0.062700</td>\n",
1118
+ " <td>0.051676</td>\n",
1119
+ " <td>0.986133</td>\n",
1120
+ " <td>0.989864</td>\n",
1121
+ " <td>0.982585</td>\n",
1122
+ " <td>0.986211</td>\n",
1123
+ " </tr>\n",
1124
+ " <tr>\n",
1125
+ " <td>2</td>\n",
1126
+ " <td>0.018100</td>\n",
1127
+ " <td>0.069427</td>\n",
1128
+ " <td>0.986328</td>\n",
1129
+ " <td>0.981979</td>\n",
1130
+ " <td>0.991099</td>\n",
1131
+ " <td>0.986518</td>\n",
1132
+ " </tr>\n",
1133
+ " <tr>\n",
1134
+ " <td>3</td>\n",
1135
+ " <td>0.005600</td>\n",
1136
+ " <td>0.069584</td>\n",
1137
+ " <td>0.989258</td>\n",
1138
+ " <td>0.990306</td>\n",
1139
+ " <td>0.988390</td>\n",
1140
+ " <td>0.989347</td>\n",
1141
+ " </tr>\n",
1142
+ " </tbody>\n",
1143
+ "</table><p>"
1144
+ ]
1145
+ },
1146
+ "metadata": {}
1147
+ },
1148
+ {
1149
+ "output_type": "execute_result",
1150
+ "data": {
1151
+ "text/plain": [
1152
+ "TrainOutput(global_step=3840, training_loss=0.03588072238489985, metrics={'train_runtime': 595.5357, 'train_samples_per_second': 103.168, 'train_steps_per_second': 6.448, 'total_flos': 1035958669377600.0, 'train_loss': 0.03588072238489985, 'epoch': 3.0})"
1153
+ ]
1154
+ },
1155
+ "metadata": {},
1156
+ "execution_count": 67
1157
+ }
1158
+ ]
1159
+ },
1160
+ {
1161
+ "cell_type": "code",
1162
+ "source": [
1163
+ "model.save_pretrained(\"/content/drive/MyDrive/nlp/fine_tuned_bert\")\n",
1164
+ "tokenizer.save_pretrained(\"/content/drive/MyDrive/nlp/fine_tuned_bert\")"
1165
+ ],
1166
+ "metadata": {
1167
+ "colab": {
1168
+ "base_uri": "https://localhost:8080/"
1169
+ },
1170
+ "id": "d91uceVJKMtJ",
1171
+ "outputId": "f47a45e0-ac5a-4be7-f6d3-5f04c431b9c7"
1172
+ },
1173
+ "execution_count": null,
1174
+ "outputs": [
1175
+ {
1176
+ "output_type": "execute_result",
1177
+ "data": {
1178
+ "text/plain": [
1179
+ "('/content/drive/MyDrive/nlp/fine_tuned_bert/tokenizer_config.json',\n",
1180
+ " '/content/drive/MyDrive/nlp/fine_tuned_bert/special_tokens_map.json',\n",
1181
+ " '/content/drive/MyDrive/nlp/fine_tuned_bert/vocab.txt',\n",
1182
+ " '/content/drive/MyDrive/nlp/fine_tuned_bert/added_tokens.json')"
1183
+ ]
1184
+ },
1185
+ "metadata": {},
1186
+ "execution_count": 68
1187
+ }
1188
+ ]
1189
+ },
1190
+ {
1191
+ "cell_type": "code",
1192
+ "source": [
1193
+ "trainer.save_model(\"/content/drive/MyDrive/nlp/api_saved_bert\")"
1194
+ ],
1195
+ "metadata": {
1196
+ "id": "zjeFrP65feAA"
1197
+ },
1198
+ "execution_count": null,
1199
+ "outputs": []
1200
+ },
1201
+ {
1202
+ "cell_type": "code",
1203
+ "source": [
1204
+ "results = trainer.evaluate(eval_dataset=test_dataset)\n",
1205
+ "print(results)"
1206
+ ],
1207
+ "metadata": {
1208
+ "colab": {
1209
+ "base_uri": "https://localhost:8080/",
1210
+ "height": 74
1211
+ },
1212
+ "id": "z-Z_oX4NQy64",
1213
+ "outputId": "9bebaf03-9b09-4e31-b819-48378162cc91"
1214
+ },
1215
+ "execution_count": null,
1216
+ "outputs": [
1217
+ {
1218
+ "output_type": "display_data",
1219
+ "data": {
1220
+ "text/plain": [
1221
+ "<IPython.core.display.HTML object>"
1222
+ ],
1223
+ "text/html": [
1224
+ "\n",
1225
+ " <div>\n",
1226
+ " \n",
1227
+ " <progress value='400' max='400' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1228
+ " [400/400 00:14]\n",
1229
+ " </div>\n",
1230
+ " "
1231
+ ]
1232
+ },
1233
+ "metadata": {}
1234
+ },
1235
+ {
1236
+ "output_type": "stream",
1237
+ "name": "stdout",
1238
+ "text": [
1239
+ "{'eval_loss': 0.06335476785898209, 'eval_accuracy': 0.99046875, 'eval_precision': 0.9921826141338337, 'eval_recall': 0.9887815518853226, 'eval_f1': 0.9904791634150147, 'eval_runtime': 14.213, 'eval_samples_per_second': 450.292, 'eval_steps_per_second': 28.143, 'epoch': 3.0}\n"
1240
+ ]
1241
+ }
1242
+ ]
1243
+ },
1244
+ {
1245
+ "cell_type": "code",
1246
+ "source": [
1247
+ "import torch\n",
1248
+ "\n",
1249
+ "# Use the GPU when one is available; otherwise fall back to the CPU\n",
1250
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
1251
+ "\n",
1252
+ "# Move your model to the device\n",
1253
+ "model.to(device)\n",
1254
+ "\n",
1255
+ "# Tokenize the input text and move the resulting tensors to the device\n",
1256
+ "text = \"How to get 6 pack abs in 5 days?\"\n",
1257
+ "inputs = tokenizer(text, return_tensors=\"pt\", truncation=True, max_length=512).to(device)\n",
1258
+ "\n",
1259
+ "# Now, run the inference\n",
1260
+ "outputs = model(**inputs)\n",
1261
+ "predicted_class = outputs.logits.argmax(dim=1).item()\n",
1262
+ "\n",
1263
+ "# Print the prediction\n",
1264
+ "print(\"Predicted class:\", \"Clickbait\" if predicted_class == 1 else \"Non-Clickbait\")"
1265
+ ],
1266
+ "metadata": {
1267
+ "colab": {
1268
+ "base_uri": "https://localhost:8080/"
1269
+ },
1270
+ "id": "KqChnsb6RWKG",
1271
+ "outputId": "779ac87d-672a-481d-d4a5-a7eb88ac254e"
1272
+ },
1273
+ "execution_count": null,
1274
+ "outputs": [
1275
+ {
1276
+ "output_type": "stream",
1277
+ "name": "stdout",
1278
+ "text": [
1279
+ "Predicted class: Clickbait\n"
1280
+ ]
1281
+ }
1282
+ ]
1283
+ },
1284
+ {
1285
+ "cell_type": "markdown",
1286
+ "source": [
1287
+ "## Loading trained model and testing"
1288
+ ],
1289
+ "metadata": {
1290
+ "id": "J2ayxrk7SCx4"
1291
+ }
1292
+ },
1293
+ {
1294
+ "cell_type": "code",
1295
+ "source": [
1296
+ "from transformers import BertTokenizer, BertForSequenceClassification\n",
1297
+ "import pandas as pd\n",
1298
+ "from datasets import Dataset\n",
1299
+ "from datasets import load_dataset\n",
1300
+ "from torch.utils.data import DataLoader\n",
1301
+ "from transformers import DataCollatorWithPadding\n",
1302
+ "from transformers import AdamW\n",
1303
+ "from transformers import Trainer, TrainingArguments\n",
1304
+ "import torch"
1305
+ ],
1306
+ "metadata": {
1307
+ "id": "ufJ_kSdJTB29"
1308
+ },
1309
+ "execution_count": 8,
1310
+ "outputs": []
1311
+ },
1312
+ {
1313
+ "cell_type": "code",
1314
+ "source": [
1315
+ "from transformers import BertTokenizer, BertForSequenceClassification\n",
1316
+ "\n",
1317
+ "# Load the saved model and tokenizer\n",
1318
+ "model = BertForSequenceClassification.from_pretrained(\"/content/drive/MyDrive/nlp/fine_tuned_bert\")\n",
1319
+ "tokenizer = BertTokenizer.from_pretrained(\"/content/drive/MyDrive/nlp/fine_tuned_bert\")\n"
1320
+ ],
1321
+ "metadata": {
1322
+ "colab": {
1323
+ "base_uri": "https://localhost:8080/",
1324
+ "height": 349
1325
+ },
1326
+ "id": "Ve8bgTaoSCM_",
1327
+ "outputId": "8fcfda48-209e-4468-b2f3-883d084ae0c4"
1328
+ },
1329
+ "execution_count": 10,
1330
+ "outputs": [
1331
+ {
1332
+ "output_type": "error",
1333
+ "ename": "OSError",
1334
+ "evalue": "Error no file named pytorch_model.bin, model.safetensors, tf_model.h5, model.ckpt.index or flax_model.msgpack found in directory /content/drive/MyDrive/nlp/fine_tuned_bert.",
1335
+ "traceback": [
1336
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1337
+ "\u001b[0;31mOSError\u001b[0m Traceback (most recent call last)",
1338
+ "\u001b[0;32m<ipython-input-10-4a37b2773130>\u001b[0m in \u001b[0;36m<cell line: 4>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;31m# Load the saved model and tokenizer\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mBertForSequenceClassification\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrom_pretrained\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"/content/drive/MyDrive/nlp/fine_tuned_bert\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0mtokenizer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mBertTokenizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrom_pretrained\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"/content/drive/MyDrive/nlp/fine_tuned_bert\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
1339
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/modeling_utils.py\u001b[0m in \u001b[0;36mfrom_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, weights_only, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m 3777\u001b[0m )\n\u001b[1;32m 3778\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3779\u001b[0;31m raise EnvironmentError(\n\u001b[0m\u001b[1;32m 3780\u001b[0m \u001b[0;34mf\"Error no file named {_add_variant(WEIGHTS_NAME, variant)}, {_add_variant(SAFE_WEIGHTS_NAME, variant)},\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3781\u001b[0m \u001b[0;34mf\" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME + '.index'} or {FLAX_WEIGHTS_NAME} found in directory\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
1340
+ "\u001b[0;31mOSError\u001b[0m: Error no file named pytorch_model.bin, model.safetensors, tf_model.h5, model.ckpt.index or flax_model.msgpack found in directory /content/drive/MyDrive/nlp/fine_tuned_bert."
1341
+ ]
1342
+ }
1343
+ ]
1344
+ },
1345
+ {
1346
+ "cell_type": "markdown",
1347
+ "source": [],
1348
+ "metadata": {
1349
+ "id": "P6PZkbpSanjH"
1350
+ }
1351
+ },
1352
+ {
1353
+ "cell_type": "code",
1354
+ "source": [
1355
+ "\n",
1356
+ "text = \"Is this a clickbait headline?\"\n",
1357
+ "inputs = tokenizer(text, return_tensors=\"pt\", truncation=True, max_length=512)\n",
1358
+ "outputs = model(**inputs)\n",
1359
+ "predicted_class = outputs.logits.argmax(dim=1).item()\n",
1360
+ "\n",
1361
+ "# Print the prediction\n",
1362
+ "print(\"Predicted class:\", \"Clickbait\" if predicted_class == 1 else \"Non-Clickbait\")\n"
1363
+ ],
1364
+ "metadata": {
1365
+ "id": "Cayo8SX1Sv2V"
1366
+ },
1367
+ "execution_count": null,
1368
+ "outputs": []
1369
+ }
1370
+ ],
1371
+ "metadata": {
1372
+ "kernelspec": {
1373
+ "display_name": "Python 3",
1374
+ "name": "python3"
1375
+ },
1376
+ "language_info": {
1377
+ "codemirror_mode": {
1378
+ "name": "ipython",
1379
+ "version": 3
1380
+ },
1381
+ "file_extension": ".py",
1382
+ "mimetype": "text/x-python",
1383
+ "name": "python",
1384
+ "nbconvert_exporter": "python",
1385
+ "pygments_lexer": "ipython3",
1386
+ "version": "3.12.4"
1387
+ },
1388
+ "colab": {
1389
+ "provenance": [],
1390
+ "gpuType": "T4"
1391
+ },
1392
+ "accelerator": "GPU",
1393
+ "widgets": {
1394
+ "application/vnd.jupyter.widget-state+json": {
1395
+ "dad32a0edd9841c9ab1ce052f0efd994": {
1396
+ "model_module": "@jupyter-widgets/controls",
1397
+ "model_name": "HBoxModel",
1398
+ "model_module_version": "1.5.0",
1399
+ "state": {
1400
+ "_dom_classes": [],
1401
+ "_model_module": "@jupyter-widgets/controls",
1402
+ "_model_module_version": "1.5.0",
1403
+ "_model_name": "HBoxModel",
1404
+ "_view_count": null,
1405
+ "_view_module": "@jupyter-widgets/controls",
1406
+ "_view_module_version": "1.5.0",
1407
+ "_view_name": "HBoxView",
1408
+ "box_style": "",
1409
+ "children": [
1410
+ "IPY_MODEL_41099ccad51745aeb06c52636ff646ba",
1411
+ "IPY_MODEL_00f15ecb12df4d36aed9ef3135a71f08",
1412
+ "IPY_MODEL_a1121eb967374766bc15ed290c2dc09d"
1413
+ ],
1414
+ "layout": "IPY_MODEL_3b3db9c5766a4fc1908f82d5d86ee86d"
1415
+ }
1416
+ },
1417
+ "41099ccad51745aeb06c52636ff646ba": {
1418
+ "model_module": "@jupyter-widgets/controls",
1419
+ "model_name": "HTMLModel",
1420
+ "model_module_version": "1.5.0",
1421
+ "state": {
1422
+ "_dom_classes": [],
1423
+ "_model_module": "@jupyter-widgets/controls",
1424
+ "_model_module_version": "1.5.0",
1425
+ "_model_name": "HTMLModel",
1426
+ "_view_count": null,
1427
+ "_view_module": "@jupyter-widgets/controls",
1428
+ "_view_module_version": "1.5.0",
1429
+ "_view_name": "HTMLView",
1430
+ "description": "",
1431
+ "description_tooltip": null,
1432
+ "layout": "IPY_MODEL_86c7d5da9ac94f4aa7019866cb48da30",
1433
+ "placeholder": "​",
1434
+ "style": "IPY_MODEL_12e0c1a1d2e545db8c89accffb1072fb",
1435
+ "value": "Map: 100%"
1436
+ }
1437
+ },
1438
+ "00f15ecb12df4d36aed9ef3135a71f08": {
1439
+ "model_module": "@jupyter-widgets/controls",
1440
+ "model_name": "FloatProgressModel",
1441
+ "model_module_version": "1.5.0",
1442
+ "state": {
1443
+ "_dom_classes": [],
1444
+ "_model_module": "@jupyter-widgets/controls",
1445
+ "_model_module_version": "1.5.0",
1446
+ "_model_name": "FloatProgressModel",
1447
+ "_view_count": null,
1448
+ "_view_module": "@jupyter-widgets/controls",
1449
+ "_view_module_version": "1.5.0",
1450
+ "_view_name": "ProgressView",
1451
+ "bar_style": "success",
1452
+ "description": "",
1453
+ "description_tooltip": null,
1454
+ "layout": "IPY_MODEL_cb1f378c68e3464db6bbe1756ee68b46",
1455
+ "max": 32000,
1456
+ "min": 0,
1457
+ "orientation": "horizontal",
1458
+ "style": "IPY_MODEL_272a34ba164e4b2a88aa1d92a42510b8",
1459
+ "value": 32000
1460
+ }
1461
+ },
1462
+ "a1121eb967374766bc15ed290c2dc09d": {
1463
+ "model_module": "@jupyter-widgets/controls",
1464
+ "model_name": "HTMLModel",
1465
+ "model_module_version": "1.5.0",
1466
+ "state": {
1467
+ "_dom_classes": [],
1468
+ "_model_module": "@jupyter-widgets/controls",
1469
+ "_model_module_version": "1.5.0",
1470
+ "_model_name": "HTMLModel",
1471
+ "_view_count": null,
1472
+ "_view_module": "@jupyter-widgets/controls",
1473
+ "_view_module_version": "1.5.0",
1474
+ "_view_name": "HTMLView",
1475
+ "description": "",
1476
+ "description_tooltip": null,
1477
+ "layout": "IPY_MODEL_810299000553408ca96ed7b45a8049b8",
1478
+ "placeholder": "​",
1479
+ "style": "IPY_MODEL_93d6dfc7a47d4c76bdbeb868dc0a3369",
1480
+ "value": " 32000/32000 [00:30&lt;00:00, 1034.18 examples/s]"
1481
+ }
1482
+ },
1483
+ "3b3db9c5766a4fc1908f82d5d86ee86d": {
1484
+ "model_module": "@jupyter-widgets/base",
1485
+ "model_name": "LayoutModel",
1486
+ "model_module_version": "1.2.0",
1487
+ "state": {
1488
+ "_model_module": "@jupyter-widgets/base",
1489
+ "_model_module_version": "1.2.0",
1490
+ "_model_name": "LayoutModel",
1491
+ "_view_count": null,
1492
+ "_view_module": "@jupyter-widgets/base",
1493
+ "_view_module_version": "1.2.0",
1494
+ "_view_name": "LayoutView",
1495
+ "align_content": null,
1496
+ "align_items": null,
1497
+ "align_self": null,
1498
+ "border": null,
1499
+ "bottom": null,
1500
+ "display": null,
1501
+ "flex": null,
1502
+ "flex_flow": null,
1503
+ "grid_area": null,
1504
+ "grid_auto_columns": null,
1505
+ "grid_auto_flow": null,
1506
+ "grid_auto_rows": null,
1507
+ "grid_column": null,
1508
+ "grid_gap": null,
1509
+ "grid_row": null,
1510
+ "grid_template_areas": null,
1511
+ "grid_template_columns": null,
1512
+ "grid_template_rows": null,
1513
+ "height": null,
1514
+ "justify_content": null,
1515
+ "justify_items": null,
1516
+ "left": null,
1517
+ "margin": null,
1518
+ "max_height": null,
1519
+ "max_width": null,
1520
+ "min_height": null,
1521
+ "min_width": null,
1522
+ "object_fit": null,
1523
+ "object_position": null,
1524
+ "order": null,
1525
+ "overflow": null,
1526
+ "overflow_x": null,
1527
+ "overflow_y": null,
1528
+ "padding": null,
1529
+ "right": null,
1530
+ "top": null,
1531
+ "visibility": null,
1532
+ "width": null
1533
+ }
1534
+ },
1535
+ "86c7d5da9ac94f4aa7019866cb48da30": {
1536
+ "model_module": "@jupyter-widgets/base",
1537
+ "model_name": "LayoutModel",
1538
+ "model_module_version": "1.2.0",
1539
+ "state": {
1540
+ "_model_module": "@jupyter-widgets/base",
1541
+ "_model_module_version": "1.2.0",
1542
+ "_model_name": "LayoutModel",
1543
+ "_view_count": null,
1544
+ "_view_module": "@jupyter-widgets/base",
1545
+ "_view_module_version": "1.2.0",
1546
+ "_view_name": "LayoutView",
1547
+ "align_content": null,
1548
+ "align_items": null,
1549
+ "align_self": null,
1550
+ "border": null,
1551
+ "bottom": null,
1552
+ "display": null,
1553
+ "flex": null,
1554
+ "flex_flow": null,
1555
+ "grid_area": null,
1556
+ "grid_auto_columns": null,
1557
+ "grid_auto_flow": null,
1558
+ "grid_auto_rows": null,
1559
+ "grid_column": null,
1560
+ "grid_gap": null,
1561
+ "grid_row": null,
1562
+ "grid_template_areas": null,
1563
+ "grid_template_columns": null,
1564
+ "grid_template_rows": null,
1565
+ "height": null,
1566
+ "justify_content": null,
1567
+ "justify_items": null,
1568
+ "left": null,
1569
+ "margin": null,
1570
+ "max_height": null,
1571
+ "max_width": null,
1572
+ "min_height": null,
1573
+ "min_width": null,
1574
+ "object_fit": null,
1575
+ "object_position": null,
1576
+ "order": null,
1577
+ "overflow": null,
1578
+ "overflow_x": null,
1579
+ "overflow_y": null,
1580
+ "padding": null,
1581
+ "right": null,
1582
+ "top": null,
1583
+ "visibility": null,
1584
+ "width": null
1585
+ }
1586
+ },
1587
+ "12e0c1a1d2e545db8c89accffb1072fb": {
1588
+ "model_module": "@jupyter-widgets/controls",
1589
+ "model_name": "DescriptionStyleModel",
1590
+ "model_module_version": "1.5.0",
1591
+ "state": {
1592
+ "_model_module": "@jupyter-widgets/controls",
1593
+ "_model_module_version": "1.5.0",
1594
+ "_model_name": "DescriptionStyleModel",
1595
+ "_view_count": null,
1596
+ "_view_module": "@jupyter-widgets/base",
1597
+ "_view_module_version": "1.2.0",
1598
+ "_view_name": "StyleView",
1599
+ "description_width": ""
1600
+ }
1601
+ },
1602
+ "cb1f378c68e3464db6bbe1756ee68b46": {
1603
+ "model_module": "@jupyter-widgets/base",
1604
+ "model_name": "LayoutModel",
1605
+ "model_module_version": "1.2.0",
1606
+ "state": {
1607
+ "_model_module": "@jupyter-widgets/base",
1608
+ "_model_module_version": "1.2.0",
1609
+ "_model_name": "LayoutModel",
1610
+ "_view_count": null,
1611
+ "_view_module": "@jupyter-widgets/base",
1612
+ "_view_module_version": "1.2.0",
1613
+ "_view_name": "LayoutView",
1614
+ "align_content": null,
1615
+ "align_items": null,
1616
+ "align_self": null,
1617
+ "border": null,
1618
+ "bottom": null,
1619
+ "display": null,
1620
+ "flex": null,
1621
+ "flex_flow": null,
1622
+ "grid_area": null,
1623
+ "grid_auto_columns": null,
1624
+ "grid_auto_flow": null,
1625
+ "grid_auto_rows": null,
1626
+ "grid_column": null,
1627
+ "grid_gap": null,
1628
+ "grid_row": null,
1629
+ "grid_template_areas": null,
1630
+ "grid_template_columns": null,
1631
+ "grid_template_rows": null,
1632
+ "height": null,
1633
+ "justify_content": null,
1634
+ "justify_items": null,
1635
+ "left": null,
1636
+ "margin": null,
1637
+ "max_height": null,
1638
+ "max_width": null,
1639
+ "min_height": null,
1640
+ "min_width": null,
1641
+ "object_fit": null,
1642
+ "object_position": null,
1643
+ "order": null,
1644
+ "overflow": null,
1645
+ "overflow_x": null,
1646
+ "overflow_y": null,
1647
+ "padding": null,
1648
+ "right": null,
1649
+ "top": null,
1650
+ "visibility": null,
1651
+ "width": null
1652
+ }
1653
+ },
1654
+ "272a34ba164e4b2a88aa1d92a42510b8": {
1655
+ "model_module": "@jupyter-widgets/controls",
1656
+ "model_name": "ProgressStyleModel",
1657
+ "model_module_version": "1.5.0",
1658
+ "state": {
1659
+ "_model_module": "@jupyter-widgets/controls",
1660
+ "_model_module_version": "1.5.0",
1661
+ "_model_name": "ProgressStyleModel",
1662
+ "_view_count": null,
1663
+ "_view_module": "@jupyter-widgets/base",
1664
+ "_view_module_version": "1.2.0",
1665
+ "_view_name": "StyleView",
1666
+ "bar_color": null,
1667
+ "description_width": ""
1668
+ }
1669
+ },
1670
+ "810299000553408ca96ed7b45a8049b8": {
1671
+ "model_module": "@jupyter-widgets/base",
1672
+ "model_name": "LayoutModel",
1673
+ "model_module_version": "1.2.0",
1674
+ "state": {
1675
+ "_model_module": "@jupyter-widgets/base",
1676
+ "_model_module_version": "1.2.0",
1677
+ "_model_name": "LayoutModel",
1678
+ "_view_count": null,
1679
+ "_view_module": "@jupyter-widgets/base",
1680
+ "_view_module_version": "1.2.0",
1681
+ "_view_name": "LayoutView",
1682
+ "align_content": null,
1683
+ "align_items": null,
1684
+ "align_self": null,
1685
+ "border": null,
1686
+ "bottom": null,
1687
+ "display": null,
1688
+ "flex": null,
1689
+ "flex_flow": null,
1690
+ "grid_area": null,
1691
+ "grid_auto_columns": null,
1692
+ "grid_auto_flow": null,
1693
+ "grid_auto_rows": null,
1694
+ "grid_column": null,
1695
+ "grid_gap": null,
1696
+ "grid_row": null,
1697
+ "grid_template_areas": null,
1698
+ "grid_template_columns": null,
1699
+ "grid_template_rows": null,
1700
+ "height": null,
1701
+ "justify_content": null,
1702
+ "justify_items": null,
1703
+ "left": null,
1704
+ "margin": null,
1705
+ "max_height": null,
1706
+ "max_width": null,
1707
+ "min_height": null,
1708
+ "min_width": null,
1709
+ "object_fit": null,
1710
+ "object_position": null,
1711
+ "order": null,
1712
+ "overflow": null,
1713
+ "overflow_x": null,
1714
+ "overflow_y": null,
1715
+ "padding": null,
1716
+ "right": null,
1717
+ "top": null,
1718
+ "visibility": null,
1719
+ "width": null
1720
+ }
1721
+ },
1722
+ "93d6dfc7a47d4c76bdbeb868dc0a3369": {
1723
+ "model_module": "@jupyter-widgets/controls",
1724
+ "model_name": "DescriptionStyleModel",
1725
+ "model_module_version": "1.5.0",
1726
+ "state": {
1727
+ "_model_module": "@jupyter-widgets/controls",
1728
+ "_model_module_version": "1.5.0",
1729
+ "_model_name": "DescriptionStyleModel",
1730
+ "_view_count": null,
1731
+ "_view_module": "@jupyter-widgets/base",
1732
+ "_view_module_version": "1.2.0",
1733
+ "_view_name": "StyleView",
1734
+ "description_width": ""
1735
+ }
1736
+ }
1737
+ }
1738
+ }
1739
+ },
1740
+ "nbformat": 4,
1741
+ "nbformat_minor": 0
1742
+ }