ssbars commited on
Commit
12faaae
·
1 Parent(s): 2989d17
Files changed (6) hide show
  1. ML2_2025_nlp_ops1.ipynb +167 -0
  2. README.md +146 -1
  3. app.py +111 -44
  4. model.py +338 -35
  5. requirements.lock +64 -0
  6. requirements.txt +8 -1
ML2_2025_nlp_ops1.ipynb ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# __Девопсная домашка по трансформерам__\n",
8
+ "\n",
9
+ "## __Описание__\n",
10
+ "\n",
11
+ "![img](https://d35w6hwqhdq0in.cloudfront.net/521712556725591dcacec5bbdb32e047.png)\n",
12
+ "\n",
13
+ "Ваш главный квест на эту домашку - сделать свой простой сервис на трансформерах. Вот прям целый сервис: начиная с данных и заканчивая графическим интерфейсом где-то в интернете. Ваш сервис может решать либо одну из предложенных ниже задач, либо любую другую (что-то более дорогое лично вам).\n",
14
+ "\n",
15
+ "__Стандартная задача: классификатор статей.__ Нужно построить сервис который принимает название статьи и её abstract, и выдаёт наиболее вероятную тематику статьи: скажем, физика, биология или computer science. В интерфейсе должно быть можно ввести отдельно abstract, отдельно название -- и увидеть топ-95%* тематик, отсортированных по убыванию вероятности. Если abstract не ввели, нужно классифицировать статью только по названию. Ниже вас ждут инструкции и данные именно для этой задачи.\n",
16
+ "\n",
17
+ "<details><summary><u> Что значит Топ-95%?</u></summary>\n",
18
+ " Нужно выдавать темы по убыванию вероятности, пока их суммарная вероятность не превысит 95%. В зависимости от предсказанной вероятности, это может быть одна или более тем. Например, если модель предсказала вероятности [4%, 20%, 60%, 2%, 14%], нужно вывести 3 топ-3 класса. Если один из классов имеет вероятность 96%, достаточно вывести один этот класс.\n",
19
+ "</details>\n",
20
+ "\n",
21
+ "Альтернативно, вы можете отважиться сделать что-то своё, на данных из интернета или своих собственных. В вашей задаче обязательно должно быть _оправданное_ использование трансформеров. Использовать ML чтобы переводить часовые пояса - плохой план.\n",
22
+ "\n",
23
+ "Achtung: трансформеры круты, но не всемогущи. Далеко не любую задачу можно решить ощутимо лучше рандома. Для калибровки, вот несколько примеров решаемых задач (всё кликабельно):\n",
24
+ "\n",
25
+ "\n",
26
+ "<details><summary> - <b>[medium]</b> <u>Сгенерировать youtube-комментарии по _ссылке_ на видео</u></summary>\n",
27
+ " Всё просто, юзер постит ссылку на видео - вы его комментируете. Можно заранее обусловиться что видео только на английском или на русском. Нужно сочинить _несколько_ комментариев. Kudos если вместе с основным комментарием вы порождаете юзернеймы и-или ответы на него.\n",
28
+ " \n",
29
+ " Датасет для файнтюна можно [взять с kaggle](https://www.kaggle.com/tanmay111/youtube-comments-sentiment-analysis/data?select=UScomments.csv) или [собрать самостоятельно](https://towardsdatascience.com/how-to-build-your-own-dataset-of-youtube-comments-39a1e57aade).\n",
30
+ " \n",
31
+ " В качестве основной модели можно использовать [GPT-2 large](https://huggingface.co/gpt2-large). Вот как её файнтюнить: https://tinyurl.com/gpt2-finetune-colab . Если хотите больше - можно взять что-то из творчества https://huggingface.co/EleutherAI . Например, вот [тут](https://tinyurl.com/gpt-j-8bit) есть пример как файнтюнить GPT-J-6B (в 8 раз больше gpt2-large). Однако, этим стоит заниматься уже после того, как у вас заработал базовый сценарий с GPT2-large или даже base.\n",
32
+ " \n",
33
+ " В итоговом сервисе ��ожно дать пользователю вариировать параметры генерации: температура или top-p, если сэмплинг; beam size и length penalty, если beam search; сколько комментариев сгенерировать, etc. Отдельный респект если ваш код будет выводить комментарий по одному слову, прямо в процессе генерёжки - чтобы пользователь не ждал пока вы настругаете абзац целиком.\n",
34
+ "</details>\n",
35
+ "\n",
36
+ "<details><summary> - <b>[medium]</b> <u>Предсказать зарплату по профилю (симулятор Дудя).</u></summary>\n",
37
+ " Note: <details> <summary>Причём тут Дудь?</summary> <img src=https://www.meme-arsenal.com/memes/6dd85f126bbab4f9774ced71ffadbcb3.jpg> </details>\n",
38
+ " \n",
39
+ " Главная сложность задачи - достать хорошие данные. Если хороших данных не случилось - можно и трешовые :) Задание всё-таки про технологии а не про продукт. Для начала можно взять подмножество фичей [отсюда](https://www.kaggle.com/c/job-salary-prediction/data), которые вы можете восстановить из профиля linkedin - название профессии и компании. Название компании лучше заменить на фичи из открытых источников: сфера деятельности, размер, етц.\n",
40
+ " \n",
41
+ " А дальше файнтюним на этом BERT / T5 и радуемся. Ну или хотя бы смеёмся.\n",
42
+ "</details>\n",
43
+ "\n",
44
+ "\n",
45
+ "<details><summary> - <b>[hard]</b> <u>Мнения с географической окраской.</u></summary>\n",
46
+ " \n",
47
+ " Сервис который принимает на вход тему (хэштег или ключевую фразу) и рисует карту мира, где в каждом регионе показано, с какой эмоциональной окраской о ней высказываются в социальных сетях. В качестве социальной сети можно взять VK/twitter, в случая VK ожидается детализация не по странам, а по городам стран бывшего СССР.\n",
48
+ " \n",
49
+ " В минимальном варианте достаточно определять тональность твита в режиме \"позитивно-негативно\", зафайнтюнив условный BERT/T5 на одном из десятков {vk/twitter} sentiment classification датасетах. Географическую привязку можно получить из профиля пользователя. А дальше осталось собрать данные по странам и регионам.\n",
50
+ "\n",
51
+ "</details>\n",
52
+ "\n",
53
+ "\n",
54
+ "<details><summary> - <b>[very hard]</b> <u>Найти статью википедии по фото предмета статьи</u></summary>\n",
55
+ "\n",
56
+ " Чтобы можно было сфотать какую-нибудь неведомую чешуйню на телефон и получить сумму человеческих знаний о ней в форме вики-статьи.\n",
57
+ " \n",
58
+ " В качестве функции потерь можно использовать contrastive loss. Этот лосс неплохо описан в статье [CLIP](https://arxiv.org/abs/2103.00020). Вместо обучения с нуля предлагается взять, собственно, CLIP (text transformer + image transformer) отсюда: https://huggingface.co/docs/transformers/model_doc/clip. Модель будет сопоставлять каждой статьи и \n",
59
+ " \n",
60
+ " Данные для этого квеста можно собрать через API википедии: вики-статьи о предметах обычно содержит фото этого объекта и, собственно, текст статьи. Советуем собрать как минимум 10^4 пар картинка-статья. Картинки советуем дополнительно аугментировать как минимум стандартными картиночными аугами, как максимум - поиском похожих картинок в интернете / imagenet-е по тому же CLIP image encoder-у, но с исходными весами.\n",
61
+ " \n",
62
+ " На время отладки интерфейса рекомендуем ограничить��я небольшим списком статьей: условно, кошечки, собачки, птички, гаечные ключи, машины. Как станет понятно что оно работает \"на кошках\", можно расширить этот список до \"всех статей таких-то категорий\". Эмбединги статей лучше предпосчитать в файл. Если долго их перебирать - можно (но необязательно) воспользоваться быстрым поиском соседей, e.g. [faiss](https://github.com/facebookresearch/faiss) HNSW.\n",
63
+ "</details>\n",
64
+ "\n",
65
+ "\n",
66
+ "## __Как научить классификатор статей?__\n",
67
+ "\n",
68
+ "Данные для классификации статей можно скачать, например, [отсюда](https://www.kaggle.com/neelshah18/arxivdataset/). В этих данных есть заголовок и abstract статьи, а ещё поле __\"tag\"__: тематика статьи [по таксономии arxiv.org](https://arxiv.org/category_taxonomy). Вы можете расширить выборку, добавив в неё статьи за 2019-н.в. годы. Для этого можно [использовать arxiv API](https://github.com/lukasschwab/arxiv.py), самостоятельно распарсить arxiv с помощью [beautifulsoup](https://pypi.org/project/beautifulsoup4/), или поискать другие датасеты на kaggle, huggingface, etc.\n",
69
+ "\n",
70
+ "Когда данные собраны (и аккуратно нарезаны на train/test), можно что-нибудь и обучить. Мы советуем использовать для этого библиотеку `transformers`. Советуем, но не заставляем: если хочется, можно взять [fairseq roberta](https://github.com/pytorch/fairseq/blob/main/examples/roberta), [google t5](https://github.com/google-research/text-to-text-transfer-transformer) или даже написать всё с нуля.\n",
71
+ "\n",
72
+ "Мы разбирали transformers на [семинаре](https://lk.yandexdataschool.ru/courses/2025-spring/7.1332-machine-learning-2/classes/13138/), за любой дополнительной информацией - смотрите [документации HF](https://huggingface.co/docs).\n",
73
+ "\n",
74
+ "Начать лучше с простой модели, такой как [`distilbert-base-cased`](https://huggingface.co/distilbert-base-cased). Когда вы будете понимать, какие значения accuracy ожидать от базовой модели, можно поискать что-то получше. Два очевидных направления улучшения: (1) сильнее модель T5 или deberta v3, или (2) близкие данные, например взять модель которую предобучили на том же arxiv. И то и другое удобно [искать здесь](https://huggingface.co/models).\n",
75
+ "\n",
76
+ "## __Научили, и что теперь?__\n",
77
+ "\n",
78
+ "А теперь нужно сделать так, чтобы ваша обученная модель отвечала на запросы в интернете. Как и на прошлом этапе, вы можете сделать это кучей разных способов: от простого [streamlit](https://streamlit.io/) / [gradio](https://gradio.app/), минуя [TorchServe](https://pytorch.org/serve/) с [Triton/TensorRT](https://developer.nvidia.com/nvidia-triton-inference-server), и заканчивая экспортом модели в javascript с помощью [TensorFlow.js](https://www.tensorflow.org/js/tutorials) / [ONNX.js](https://github.com/elliotwaite/pytorch-to-javascript-with-onnx-js).\n",
79
+ "\n",
80
+ "На [семинаре](https://lk.yandexdataschool.ru/courses/2025-spring/7.1332-machine-learning-2/classes/13138/) мы разбирали основные вещи про то как работает streamlit и как сделать простое приложение с его помощью.\n",
81
+ "\n",
82
+ "Общая идея streamlit: вы [описываете](https://docs.streamlit.io/library/get-started/create-an-app) внешний вид приложения на питоне с помощью примитивов (кнопки, поля, любой html) -- а потом этот код выполняется на сервере и обслуживает каждого пользователя в отдельном процессе.\n",
83
+ "\n",
84
+ "__Для отладки__ можно запустить приложение локально, открыв консоль рядом с app.py:\n",
85
+ "* `pip install streamlit`\n",
86
+ "* `streamlit run app.py --server.port 8080`\n",
87
+ "* открыть в браузере localhost:8080, если он не открылся автоматически\n",
88
+ "\n",
89
+ "\n",
90
+ "## __Deployment time!__\n",
91
+ "\n",
92
+ "В этот раз вам нужно не просто написать код, __но и поднять ваше приложение с доступом из интернета__. И да, вы угадали, это можно сделать несколькими способами: [HuggingFace spaces](https://huggingface.co/spaces) (данный способ разбирали на [семинаре](https://lk.yandexdataschool.ru/courses/2025-spring/7.1332-machine-learning-2/classes/13138/)), [Streamlit Cloud](https://streamlit.io/cloud), а ещё вы можете купить или арендовать свой собственный сервер и захоститься там.\n",
93
+ "\n",
94
+ "Проще всего захостить на HF spaces, для этого вам нужно [зарегистрироваться](https://huggingface.co/join) и найти [меню создания нового приложения](https://huggingface.co/new-space). Название и лицензию можно выбрать на своё усмотрение, главное чтобы Space SDK был Streamlit, а доступ - public.\n",
95
+ "\n",
96
+ "Как создали - можно редактировать ваше приложение прямо на сайте, для этого откройте приложение и перейдите в Files and versions, и там в правом углу добавьте нужные файлы.\n",
97
+ "\n",
98
+ "На минималках вам потребуется 2 файла:\n",
99
+ "- `app.py`, о котором мы говорили выше\n",
100
+ "- `requirements.txt`, где вы укажете нужные вам библиотеки\n",
101
+ "\n",
102
+ "Вы можете разместить там же веса вашей обученной модели, любые необходимые данные, дополнительные файлы, ...\n",
103
+ "\n",
104
+ "После каждого изменения файлов, ваше приложение соберётся (обычно 1-5 минут) и будет доступно уже во вкладке App. Ну или не соберётся и покажет вам, где оно сломалось. И вуаля, теперь у вас есть ссылка, которую можно показать ~друзьям~ ассистентам курса и кому угодно в интернете.\n",
105
+ "\n",
106
+ "__Удобная работа с кодом.__ Пока у вас 2 файла, их легко редактивровать прямо в интерфейсе HF spaces. Если же у вас дюжина файлов, вам может быть удобнее редактировать их в любимом vscode/pycharm/.../emacs. Чтобы это не вызывало мучений, можно пользоваться HF spaces как git репозиторием ([подробности тут](https://huggingface.co/docs/hub/spaces#manage-app-with-github-actions)).\n",
107
+ "\n",
108
+ "## __Что нужно сдать__\n",
109
+ "\n",
110
+ "Вы сдаёте проект, который будет проверяться вручную, то что ожидается от каждого проекта:\n",
111
+ "- Текстовое сопровождение вашего конкретного проекта в любом удобно читаемом формате (pdf, html, текст в lk, ...) - что за задачу вы решали, где/как брали данные, какие использовали модели, какие проводили эксперименты, ...\n",
112
+ "- Ссылка на веб интерфейс, где можно протестировать демо вашего проекта - обязательно проверяйте что работает не только у вас (с другого устройства и из под incognito режима)\n",
113
+ "- Код обучения вашей модели (желательно ipynb с заполненными ячейками и не стёртыми выходами, переведённый в pdf / html), но если вы обучали не в ноутбуке, то сдавайте код в виде файла / архива файлов / git ссылки с readme.md описанием того как именно проходило обучение с помощью этого кода.\n",
114
+ "\n",
115
+ "## __Оценка__\n",
116
+ "\n",
117
+ "Мы будем оценивать проект целиком, включая идею и реализацию. Максимум за проект можно получить 10 баллов, но мы оставляем ещё до 5 баллов, котор��е можем выдать как бонусные за особенно интересные и качественно реализованные проекты.\n",
118
+ "\n",
119
+ "### __Тонкие места, за которые могут быть снижения баллов:__\n",
120
+ "\n",
121
+ "__1. Скорость работы.__\n",
122
+ "\n",
123
+ "По умолчанию, streamlit будет выполняет весь ваш код на каждое действие пользователя. То есть всякий раз, когда пользователь меняет что-то в тексте, оно будет заново загружать модель. Чтобы исправить это безобразие, вы можете закэшировать подготовленную модель в `@st.cache`. Подробности в [семинаре](https://lk.yandexdataschool.ru/courses/2025-spring/7.1332-machine-learning-2/classes/13138/), а также [читайте тут](https://docs.streamlit.io/library/advanced-features/caching).\n",
124
+ "\n",
125
+ "__Как будет оцениваться:__\n",
126
+ "\n",
127
+ "Вы не обязаны пользоваться кэшированием, но ваше приложение не должно неоправдано тормозить дольше, чем на 3 секунды. \"Оправданые\" тормоза это те, которые вы явно оправдали текстом в ЛМС :)\n",
128
+ "\n",
129
+ "-----\n",
130
+ "\n",
131
+ "__2. Понятный фронтенд.__\n",
132
+ "\n",
133
+ "Наколеночный графический интерфейс с семинара - пример того, как скорее не надо делать интерфейс приложения. Как надо - сложный вопрос, причём настолько сложный, что есть даже [Школа Разработки Интерфейсов](https://academy.yandex.ru/schools/frontend). Но для начала:\n",
134
+ "\n",
135
+ "- Выводить нужно человекочитаемый текст, а не просто JSON с индексами и метаданными.\n",
136
+ "- Пользователю должно быть понятно, куда и какие данные вводить. Пустые текстовые поля в вакууме - плохой тон.\n",
137
+ "- Сервис не должен падать с не_отловленными ошибками. Даже если пользователь введёт неправильные/пустые данные, нужно это обработать и написать, где произошла ошибка.\n",
138
+ "\n",
139
+ "__Как будет оцениваться:__\n",
140
+ "\n",
141
+ "Для полного балла достаточно соблюсти эти три правила и специально не стрелять себе в ногу.\n",
142
+ "\n",
143
+ "-----\n",
144
+ "\n",
145
+ "__3. Код обучения и инференса.__\n",
146
+ "\n",
147
+ "Сдавая проект мы будем также получать от вас код проекта (как обучения ваших моделей, так и код веб интерфейса).\n",
148
+ "\n",
149
+ "__Как будет оцениваться:__\n",
150
+ "\n",
151
+ "Код не будет отдельно проверяться как часть задания, поэтому пишите как хотите, однако - в спорных ситуациях мы оставляем за собой право проверить ваш код, за чем могут последовать потенциальные снижения баллов при любых нарушениях.\n"
152
+ ]
153
+ },
154
+ {
155
+ "cell_type": "markdown",
156
+ "metadata": {},
157
+ "source": []
158
+ }
159
+ ],
160
+ "metadata": {
161
+ "language_info": {
162
+ "name": "python"
163
+ }
164
+ },
165
+ "nbformat": 4,
166
+ "nbformat_minor": 2
167
+ }
README.md CHANGED
@@ -11,6 +11,8 @@ pinned: false
11
 
12
  # 📚 Academic Paper Classifier
13
 
 
 
14
  This Streamlit application helps classify academic papers into different categories using a BERT-based model.
15
 
16
  ## Features
@@ -128,4 +130,147 @@ uv pip install -r requirements.lock
128
 
129
  ## Requirements
130
 
131
- See `requirements.txt` for a complete list of dependencies.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  # 📚 Academic Paper Classifier
13
 
14
+ [link](https://huggingface.co/spaces/ssbars/ysdaml4)
15
+
16
  This Streamlit application helps classify academic papers into different categories using a BERT-based model.
17
 
18
  ## Features
 
130
 
131
  ## Requirements
132
 
133
+ See `requirements.txt` for a complete list of dependencies.
134
+
135
+ # ArXiv Paper Classifier
136
+
137
+ This project implements a machine learning system for classifying academic papers into ArXiv categories using state-of-the-art transformer models.
138
+
139
+ ## Project Overview
140
+
141
+ The system uses pre-trained transformer models to classify academic papers into one of the main ArXiv categories:
142
+ - Computer Science (cs)
143
+ - Mathematics (math)
144
+ - Physics (physics)
145
+ - Quantitative Biology (q-bio)
146
+ - Quantitative Finance (q-fin)
147
+ - Statistics (stat)
148
+ - Electrical Engineering and Systems Science (eess)
149
+ - Economics (econ)
150
+
151
+ ## Features
152
+
153
+ - Multiple model support:
154
+ - DistilBERT: Lightweight and fast model, good for testing
155
+ - DeBERTa-v3: Advanced model with better performance
156
+ - RoBERTa: Advanced model with strong performance
157
+ - SciBERT: Specialized for scientific text
158
+ - BERT: Classic model with good all-round performance
159
+
160
+ - Flexible input handling:
161
+ - Can process both title and abstract
162
+ - Handles text preprocessing and tokenization
163
+ - Supports different maximum sequence lengths
164
+
165
+ - Robust error handling:
166
+ - Multiple fallback mechanisms for tokenizer initialization
167
+ - Graceful degradation to simpler models if needed
168
+ - Detailed error messages and logging
169
+
170
+ ## Installation
171
+
172
+ 1. Clone the repository
173
+ 2. Install dependencies:
174
+ ```bash
175
+ pip install -r requirements.txt
176
+ ```
177
+
178
+ ## Usage
179
+
180
+ ### Basic Usage
181
+
182
+ ```python
183
+ from model import PaperClassifier
184
+
185
+ # Initialize classifier with default model (DistilBERT)
186
+ classifier = PaperClassifier()
187
+
188
+ # Classify a paper
189
+ result = classifier.classify_paper(
190
+ title="Your paper title",
191
+ abstract="Your paper abstract"
192
+ )
193
+
194
+ # Print results
195
+ print(result)
196
+ ```
197
+
198
+ ### Using Different Models
199
+
200
+ ```python
201
+ # Initialize with DeBERTa-v3
202
+ classifier = PaperClassifier(model_type='deberta-v3')
203
+
204
+ # Initialize with RoBERTa
205
+ classifier = PaperClassifier(model_type='roberta')
206
+
207
+ # Initialize with SciBERT
208
+ classifier = PaperClassifier(model_type='scibert')
209
+
210
+ # Initialize with BERT
211
+ classifier = PaperClassifier(model_type='bert')
212
+ ```
213
+
214
+ ### Training on Custom Data
215
+
216
+ ```python
217
+ # Prepare your training data
218
+ train_texts = ["paper1 title and abstract", "paper2 title and abstract", ...]
219
+ train_labels = ["cs", "math", ...]
220
+
221
+ # Train the model
222
+ classifier.train_on_arxiv(
223
+ train_texts=train_texts,
224
+ train_labels=train_labels,
225
+ epochs=3,
226
+ batch_size=16,
227
+ learning_rate=2e-5
228
+ )
229
+ ```
230
+
231
+ ## Model Details
232
+
233
+ ### Available Models
234
+
235
+ 1. **DistilBERT** (`distilbert`)
236
+ - Model: `distilbert-base-cased`
237
+ - Max length: 512 tokens
238
+ - Fast tokenizer
239
+ - Good for testing and quick results
240
+
241
+ 2. **DeBERTa-v3** (`deberta-v3`)
242
+ - Model: `microsoft/deberta-v3-base`
243
+ - Max length: 512 tokens
244
+ - Uses DebertaV2TokenizerFast
245
+ - Advanced performance
246
+
247
+ 3. **RoBERTa** (`roberta`)
248
+ - Model: `roberta-base`
249
+ - Max length: 512 tokens
250
+ - Strong performance on various tasks
251
+
252
+ 4. **SciBERT** (`scibert`)
253
+ - Model: `allenai/scibert_scivocab_uncased`
254
+ - Max length: 512 tokens
255
+ - Specialized for scientific text
256
+
257
+ 5. **BERT** (`bert`)
258
+ - Model: `bert-base-uncased`
259
+ - Max length: 512 tokens
260
+ - Classic model with good all-round performance
261
+
262
+ ## Error Handling
263
+
264
+ The system includes robust error handling mechanisms:
265
+ - Multiple fallback levels for tokenizer initialization
266
+ - Graceful degradation to simpler models
267
+ - Detailed error messages and logging
268
+ - Automatic fallback to BERT tokenizer if needed
269
+
270
+ ## Requirements
271
+
272
+ - Python 3.7+
273
+ - PyTorch
274
+ - Transformers library
275
+ - NumPy
276
+ - Sacremoses (for tokenization support)
app.py CHANGED
@@ -11,19 +11,56 @@ import PyPDF2
11
  import io
12
  from model import PaperClassifier
13
 
14
- # Initialize the classifier
15
  @st.cache_resource
16
- def load_classifier():
17
- return PaperClassifier()
18
 
19
- classifier = load_classifier()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  # Title and description
22
  st.title("📚 Academic Paper Classification")
23
  st.markdown("""
24
  This service helps you classify academic papers into different categories.
25
  You can either:
26
- - Paste the paper's text directly
27
  - Upload a PDF file
28
  """)
29
 
@@ -31,28 +68,44 @@ You can either:
31
  col1, col2 = st.columns(2)
32
 
33
  with col1:
34
- st.subheader("Option 1: Text Input")
35
- text_input = st.text_area(
36
- "Paste your paper text here:",
 
 
 
 
 
 
 
 
37
  height=200,
38
- placeholder="Paste the paper's abstract or content here..."
39
  )
40
 
41
- if st.button("Classify Text"):
42
- if text_input.strip():
43
  with st.spinner("Classifying..."):
44
- result = classifier.classify_paper(text_input)
 
 
 
45
 
46
  st.success("Classification Complete!")
47
- st.write(f"**Predicted Category:** {result['category']}")
48
- st.write(f"**Confidence:** {result['confidence']:.2%}")
 
 
 
 
 
 
 
 
49
 
50
- # Show all probabilities
51
- st.subheader("Category Probabilities")
52
- for category, prob in result['all_probabilities'].items():
53
- st.progress(prob, text=f"{category}: {prob:.2%}")
54
  else:
55
- st.warning("Please enter some text to classify.")
56
 
57
  with col2:
58
  st.subheader("Option 2: PDF Upload")
@@ -62,38 +115,52 @@ with col2:
62
  if st.button("Classify PDF"):
63
  try:
64
  with st.spinner("Processing PDF..."):
65
- # Read PDF content
66
- pdf_reader = PyPDF2.PdfReader(io.BytesIO(uploaded_file.read()))
67
- text_content = ""
68
- for page in pdf_reader.pages:
69
- text_content += page.extract_text()
70
 
71
- # Classify the extracted text
72
- result = classifier.classify_paper(text_content)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
  st.success("Classification Complete!")
75
- st.write(f"**Predicted Category:** {result['category']}")
76
- st.write(f"**Confidence:** {result['confidence']:.2%}")
77
 
78
- # Show all probabilities
79
- st.subheader("Category Probabilities")
80
- for category, prob in result['all_probabilities'].items():
81
- st.progress(prob, text=f"{category}: {prob:.2%}")
 
 
 
 
 
82
  except Exception as e:
83
  st.error(f"Error processing PDF: {str(e)}")
84
 
85
- # Add information about the model
86
- st.sidebar.title("About")
87
- st.sidebar.info("""
88
- This application uses a BERT-based model to classify academic papers into different categories.
89
- The model analyzes the content and predicts the most likely academic field.
90
-
91
- **Categories:**
92
- - Computer Science
93
- - Mathematics
94
- - Physics
95
- - Biology
96
- - Economics
97
  """)
98
 
99
  # Add footer
 
11
  import io
12
  from model import PaperClassifier
13
 
14
+ # Initialize the classifier with model selection
15
  @st.cache_resource
16
+ def load_classifier(model_type):
17
+ return PaperClassifier(model_type)
18
 
19
+ # Cache the PDF text extraction
20
+ @st.cache_data
21
+ def extract_pdf_text(pdf_bytes):
22
+ """Extract text from PDF and try to separate title and abstract"""
23
+ pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_bytes))
24
+ text = ""
25
+ for page in pdf_reader.pages:
26
+ text += page.extract_text() + "\n"
27
+
28
+ # Try to extract title and abstract
29
+ lines = text.split('\n')
30
+ title = lines[0] if lines else ""
31
+ abstract = "\n".join(lines[1:]) if len(lines) > 1 else ""
32
+
33
+ return title.strip(), abstract.strip()
34
+
35
+ # Get available models for selection
36
+ available_models = list(PaperClassifier.AVAILABLE_MODELS.keys())
37
+
38
+ # Add model selection to sidebar
39
+ st.sidebar.title("Model Settings")
40
+ selected_model = st.sidebar.selectbox(
41
+ "Select Model",
42
+ available_models,
43
+ index=0,
44
+ help="Choose the model to use for classification"
45
+ )
46
+
47
+ # Display model information
48
+ model_info = PaperClassifier.AVAILABLE_MODELS[selected_model]
49
+ st.sidebar.markdown(f"""
50
+ ### Selected Model
51
+ **Name:** {model_info['name']}
52
+ **Description:** {model_info['description']}
53
+ """)
54
+
55
+ # Initialize the classifier with selected model
56
+ classifier = load_classifier(selected_model)
57
 
58
  # Title and description
59
  st.title("📚 Academic Paper Classification")
60
  st.markdown("""
61
  This service helps you classify academic papers into different categories.
62
  You can either:
63
+ - Enter the paper's title and abstract separately
64
  - Upload a PDF file
65
  """)
66
 
 
68
  col1, col2 = st.columns(2)
69
 
70
  with col1:
71
+ st.subheader("Option 1: Manual Input")
72
+
73
+ # Title input
74
+ title_input = st.text_input(
75
+ "Paper Title:",
76
+ placeholder="Enter the paper title..."
77
+ )
78
+
79
+ # Abstract input
80
+ abstract_input = st.text_area(
81
+ "Paper Abstract (optional):",
82
  height=200,
83
+ placeholder="Enter the paper abstract (optional)..."
84
  )
85
 
86
+ if st.button("Classify Paper"):
87
+ if title_input.strip():
88
  with st.spinner("Classifying..."):
89
+ result = classifier.classify_paper(
90
+ title=title_input,
91
+ abstract=abstract_input if abstract_input.strip() else None
92
+ )
93
 
94
  st.success("Classification Complete!")
95
+ st.write(f"**Input Type:** {result['input_type'].replace('_', ' ').title()}")
96
+ st.write(f"**Model Used:** {result['model_used']}")
97
+
98
+ # Show top categories
99
+ st.subheader("Top Categories (95% Confidence)")
100
+ total_prob = 0
101
+ for cat_info in result['top_categories']:
102
+ prob = cat_info['probability']
103
+ total_prob += prob
104
+ st.progress(prob, text=f"{cat_info['category']} ({cat_info['arxiv_category']}): {prob:.1%}")
105
 
106
+ st.info(f"Total probability of shown categories: {total_prob:.1%}")
 
 
 
107
  else:
108
+ st.warning("Please enter at least the paper title.")
109
 
110
  with col2:
111
  st.subheader("Option 2: PDF Upload")
 
115
  if st.button("Classify PDF"):
116
  try:
117
  with st.spinner("Processing PDF..."):
118
+ # Extract title and abstract from PDF
119
+ title, abstract = extract_pdf_text(uploaded_file.read())
 
 
 
120
 
121
+ if not title:
122
+ st.error("Could not extract title from PDF.")
123
+ st.stop()
124
+
125
+ # Show extracted text
126
+ with st.expander("Show extracted text"):
127
+ st.write("**Extracted Title:**")
128
+ st.write(title)
129
+ if abstract:
130
+ st.write("**Extracted Abstract:**")
131
+ st.write(abstract)
132
+
133
+ # Classify the paper
134
+ result = classifier.classify_paper(
135
+ title=title,
136
+ abstract=abstract if abstract else None
137
+ )
138
 
139
  st.success("Classification Complete!")
140
+ st.write(f"**Input Type:** {result['input_type'].replace('_', ' ').title()}")
141
+ st.write(f"**Model Used:** {result['model_used']}")
142
 
143
+ # Show top categories
144
+ st.subheader("Top Categories (95% Confidence)")
145
+ total_prob = 0
146
+ for cat_info in result['top_categories']:
147
+ prob = cat_info['probability']
148
+ total_prob += prob
149
+ st.progress(prob, text=f"{cat_info['category']} ({cat_info['arxiv_category']}): {prob:.1%}")
150
+
151
+ st.info(f"Total probability of shown categories: {total_prob:.1%}")
152
  except Exception as e:
153
  st.error(f"Error processing PDF: {str(e)}")
154
 
155
+ # Add information about the models
156
+ st.sidebar.markdown("---")
157
+ st.sidebar.title("Available Models")
158
+ st.sidebar.markdown("""
159
+ - **DistilBERT**: Fast and lightweight
160
+ - **DeBERTa v3**: Advanced performance
161
+ - **T5**: Versatile text-to-text
162
+ - **RoBERTa**: Strong performance
163
+ - **SciBERT**: Specialized for science
 
 
 
164
  """)
165
 
166
  # Add footer
model.py CHANGED
@@ -1,53 +1,356 @@
1
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
  import torch
3
  import numpy as np
 
4
 
5
  class PaperClassifier:
6
- def __init__(self):
7
- # Using BERT model fine-tuned on arXiv categories
8
- self.model_name = "bert-base-uncased" # This is a placeholder, you can replace with your fine-tuned model
9
- self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
10
- self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- # Define paper categories (example categories, can be modified based on needs)
13
  self.categories = [
14
- "Computer Science",
15
- "Mathematics",
16
- "Physics",
17
- "Biology",
18
- "Economics"
 
 
 
19
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
- def preprocess_text(self, text):
22
- # Truncate text to model's maximum length
23
- return text[:512]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
- def classify_paper(self, text):
 
 
 
 
 
 
 
26
  # Preprocess the text
27
- processed_text = self.preprocess_text(text)
28
 
29
  # Tokenize
30
- inputs = self.tokenizer(processed_text,
31
- return_tensors="pt",
32
- truncation=True,
33
- max_length=512,
34
- padding=True)
 
 
35
 
36
  # Get model predictions
37
  with torch.no_grad():
38
  outputs = self.model(**inputs)
39
- predictions = torch.softmax(outputs.logits, dim=1)
40
-
41
- # Get predicted category and confidence
42
- predicted_idx = torch.argmax(predictions).item()
43
- confidence = predictions[0][predicted_idx].item()
44
 
45
- # Return prediction and confidence
46
  return {
47
- 'category': self.categories[predicted_idx],
48
- 'confidence': confidence,
49
- 'all_probabilities': {
50
- cat: prob.item()
51
- for cat, prob in zip(self.categories, predictions[0])
52
- }
53
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig
2
  import torch
3
  import numpy as np
4
+ import logging
5
 
6
  class PaperClassifier:
7
+ # Available models with their configurations
8
+ AVAILABLE_MODELS = {
9
+ 'distilbert': {
10
+ 'name': 'distilbert-base-cased',
11
+ 'max_length': 512,
12
+ 'description': 'Lightweight and fast model, good for testing',
13
+ 'force_slow': False,
14
+ 'tokenizer_class': None # Use default
15
+ },
16
+ 'deberta-v3': {
17
+ 'name': 'microsoft/deberta-v3-base',
18
+ 'max_length': 512,
19
+ 'description': 'Advanced model with better performance',
20
+ 'force_slow': True, # Force slow tokenizer for DeBERTa
21
+ 'tokenizer_class': 'DebertaV2TokenizerFast' # Specify tokenizer class
22
+ },
23
+ 't5': {
24
+ 'name': 'google/t5-v1_1-base',
25
+ 'max_length': 512,
26
+ 'description': 'Versatile text-to-text model',
27
+ 'force_slow': False
28
+ },
29
+ 'roberta': {
30
+ 'name': 'roberta-base',
31
+ 'max_length': 512,
32
+ 'description': 'Advanced model with strong performance',
33
+ 'force_slow': False,
34
+ 'tokenizer_class': None # Use default
35
+ },
36
+ 'scibert': {
37
+ 'name': 'allenai/scibert_scivocab_uncased',
38
+ 'max_length': 512,
39
+ 'description': 'Specialized for scientific text',
40
+ 'force_slow': False,
41
+ 'tokenizer_class': None # Use default
42
+ },
43
+ 'bert': {
44
+ 'name': 'bert-base-uncased',
45
+ 'max_length': 512,
46
+ 'description': 'Classic BERT model, good all-round performance',
47
+ 'force_slow': False,
48
+ 'tokenizer_class': None # Use default
49
+ }
50
+ }
51
+
52
+ def __init__(self, model_type='distilbert'):
53
+ """
54
+ Initialize the classifier with a specific model type
55
+
56
+ Args:
57
+ model_type (str): One of 'distilbert', 'deberta-v3', 't5', 'roberta', 'scibert'
58
+ """
59
+ if model_type not in self.AVAILABLE_MODELS:
60
+ raise ValueError(f"Model type must be one of {list(self.AVAILABLE_MODELS.keys())}")
61
+
62
+ self.model_type = model_type
63
+ self.model_config = self.AVAILABLE_MODELS[model_type]
64
+ self.model_name = self.model_config['name']
65
 
66
+ # ArXiv main categories with descriptions
67
  self.categories = [
68
+ "cs", # Computer Science
69
+ "math", # Mathematics
70
+ "physics", # Physics
71
+ "q-bio", # Quantitative Biology
72
+ "q-fin", # Quantitative Finance
73
+ "stat", # Statistics
74
+ "eess", # Electrical Engineering and Systems Science
75
+ "econ" # Economics
76
  ]
77
+
78
+ # Human readable category names
79
+ self.category_names = {
80
+ "cs": "Computer Science",
81
+ "math": "Mathematics",
82
+ "physics": "Physics",
83
+ "q-bio": "Biology",
84
+ "q-fin": "Finance",
85
+ "stat": "Statistics",
86
+ "eess": "Electrical Engineering",
87
+ "econ": "Economics"
88
+ }
89
+
90
+ # Initialize tokenizer with proper error handling
91
+ self._initialize_tokenizer()
92
+
93
+ # Initialize model with proper error handling
94
+ self._initialize_model()
95
+
96
+ # Print model info
97
+ print(f"Initialized {model_type} model: {self.model_name}")
98
+ print(f"Description: {self.model_config['description']}")
99
+ print("Note: This model needs to be fine-tuned on ArXiv data for accurate predictions.")
100
+
101
+ def _initialize_tokenizer(self):
102
+ """Initialize the tokenizer with proper error handling"""
103
+ try:
104
+ # First try loading the tokenizer configuration
105
+ config = AutoConfig.from_pretrained(self.model_name)
106
+
107
+ # Try loading the tokenizer with specific class if specified
108
+ if self.model_config['tokenizer_class']:
109
+ from transformers import DebertaV2TokenizerFast
110
+ self.tokenizer = DebertaV2TokenizerFast.from_pretrained(
111
+ self.model_name,
112
+ model_max_length=self.model_config['max_length']
113
+ )
114
+ else:
115
+ # Try loading with AutoTokenizer
116
+ self.tokenizer = AutoTokenizer.from_pretrained(
117
+ self.model_name,
118
+ model_max_length=self.model_config['max_length'],
119
+ use_fast=not self.model_config['force_slow'],
120
+ trust_remote_code=True
121
+ )
122
+
123
+ print(f"Successfully initialized tokenizer for {self.model_type}")
124
+
125
+ except Exception as e:
126
+ print(f"Error initializing tokenizer: {str(e)}")
127
+ print("Falling back to basic tokenizer...")
128
+
129
+ # Try one more time with minimal settings
130
+ try:
131
+ self.tokenizer = AutoTokenizer.from_pretrained(
132
+ self.model_name,
133
+ use_fast=False,
134
+ trust_remote_code=True
135
+ )
136
+ except Exception as e:
137
+ # If all else fails, try using BERT tokenizer as last resort
138
+ print("Falling back to BERT tokenizer...")
139
+ self.tokenizer = AutoTokenizer.from_pretrained(
140
+ 'bert-base-uncased',
141
+ model_max_length=self.model_config['max_length']
142
+ )
143
+
144
+ def _initialize_model(self):
145
+ """Initialize the model with proper error handling"""
146
+ try:
147
+ self.model = AutoModelForSequenceClassification.from_pretrained(
148
+ self.model_name,
149
+ num_labels=len(self.categories),
150
+ id2label={i: label for i, label in enumerate(self.categories)},
151
+ label2id={label: i for i, label in enumerate(self.categories)},
152
+ trust_remote_code=True # Allow custom code from hub
153
+ )
154
+ except Exception as e:
155
+ raise RuntimeError(f"Failed to initialize model: {str(e)}")
156
+
157
+ @classmethod
158
+ def list_available_models(cls):
159
+ """List all available models with their descriptions"""
160
+ print("Available models:")
161
+ for model_type, config in cls.AVAILABLE_MODELS.items():
162
+ print(f"\n{model_type}:")
163
+ print(f" Model: {config['name']}")
164
+ print(f" Description: {config['description']}")
165
 
166
+ def preprocess_text(self, title, abstract=None):
167
+ """
168
+ Preprocess title and abstract
169
+
170
+ Args:
171
+ title (str): Paper title
172
+ abstract (str, optional): Paper abstract
173
+ """
174
+ if abstract:
175
+ text = f"Title: {title}\nAbstract: {abstract}"
176
+ else:
177
+ text = f"Title: {title}"
178
+
179
+ max_length = self.model_config['max_length']
180
+
181
+ if self.model_type == 't5':
182
+ text = "classify: " + text
183
+
184
+ return text[:max_length]
185
+
186
+ def get_top_categories(self, probabilities, threshold=0.95):
187
+ """
188
+ Get top categories that sum up to the threshold
189
+
190
+ Args:
191
+ probabilities (torch.Tensor): Model predictions
192
+ threshold (float): Probability threshold (default: 0.95)
193
+
194
+ Returns:
195
+ list: List of (category, probability) tuples
196
+ """
197
+ # Convert to numpy for easier manipulation
198
+ probs = probabilities.numpy()
199
+
200
+ # Sort indices by probability
201
+ sorted_indices = np.argsort(probs)[::-1]
202
+
203
+ # Calculate cumulative sum
204
+ cumsum = np.cumsum(probs[sorted_indices])
205
+
206
+ # Find how many categories we need to reach the threshold
207
+ mask = cumsum <= threshold
208
+ if not any(mask): # If first probability is already > threshold
209
+ mask[0] = True
210
+
211
+ # Get the selected indices
212
+ selected_indices = sorted_indices[mask]
213
+
214
+ # Return categories and their probabilities
215
+ return [
216
+ {
217
+ 'category': self.category_names.get(self.categories[idx], self.categories[idx]),
218
+ 'arxiv_category': self.categories[idx],
219
+ 'probability': float(probs[idx])
220
+ }
221
+ for idx in selected_indices
222
+ ]
223
 
224
+ def classify_paper(self, title, abstract=None):
225
+ """
226
+ Classify a paper based on its title and optional abstract
227
+
228
+ Args:
229
+ title (str): Paper title
230
+ abstract (str, optional): Paper abstract
231
+ """
232
  # Preprocess the text
233
+ processed_text = self.preprocess_text(title, abstract)
234
 
235
  # Tokenize
236
+ inputs = self.tokenizer(
237
+ processed_text,
238
+ return_tensors="pt",
239
+ truncation=True,
240
+ max_length=self.model_config['max_length'],
241
+ padding=True
242
+ )
243
 
244
  # Get model predictions
245
  with torch.no_grad():
246
  outputs = self.model(**inputs)
247
+ predictions = torch.softmax(outputs.logits, dim=1)[0]
248
+
249
+ # Get top categories that sum to 95% probability
250
+ top_categories = self.get_top_categories(predictions)
 
251
 
252
+ # Return predictions
253
  return {
254
+ 'top_categories': top_categories,
255
+ 'model_used': self.model_type,
256
+ 'input_type': 'title_and_abstract' if abstract else 'title_only'
257
+ }
258
+
259
+ def train_on_arxiv(self, train_texts, train_labels, validation_texts=None, validation_labels=None,
260
+ epochs=3, batch_size=16, learning_rate=2e-5):
261
+ """
262
+ Function to fine-tune the model on ArXiv data
263
+
264
+ Args:
265
+ train_texts (list): List of paper texts (title + abstract)
266
+ train_labels (list): List of corresponding ArXiv categories
267
+ validation_texts (list, optional): Validation texts
268
+ validation_labels (list, optional): Validation labels
269
+ epochs (int): Number of training epochs
270
+ batch_size (int): Training batch size
271
+ learning_rate (float): Learning rate for training
272
+ """
273
+ from transformers import TrainingArguments, Trainer
274
+ import datasets
275
+
276
+ # Prepare datasets
277
+ train_encodings = self.tokenizer(
278
+ train_texts,
279
+ truncation=True,
280
+ padding=True,
281
+ max_length=self.model_config['max_length']
282
+ )
283
+
284
+ # Convert labels to ids
285
+ train_label_ids = [self.categories.index(label) for label in train_labels]
286
+
287
+ # Create training dataset
288
+ train_dataset = datasets.Dataset.from_dict({
289
+ 'input_ids': train_encodings['input_ids'],
290
+ 'attention_mask': train_encodings['attention_mask'],
291
+ 'labels': train_label_ids
292
+ })
293
+
294
+ # Create validation dataset if provided
295
+ if validation_texts and validation_labels:
296
+ val_encodings = self.tokenizer(
297
+ validation_texts,
298
+ truncation=True,
299
+ padding=True,
300
+ max_length=self.model_config['max_length']
301
+ )
302
+ val_label_ids = [self.categories.index(label) for label in validation_labels]
303
+ validation_dataset = datasets.Dataset.from_dict({
304
+ 'input_ids': val_encodings['input_ids'],
305
+ 'attention_mask': val_encodings['attention_mask'],
306
+ 'labels': val_label_ids
307
+ })
308
+ else:
309
+ validation_dataset = None
310
+
311
+ # Training arguments
312
+ training_args = TrainingArguments(
313
+ output_dir=f"./results_{self.model_type}",
314
+ num_train_epochs=epochs,
315
+ per_device_train_batch_size=batch_size,
316
+ per_device_eval_batch_size=batch_size,
317
+ warmup_steps=500,
318
+ weight_decay=0.01,
319
+ logging_dir=f"./logs_{self.model_type}",
320
+ logging_steps=10,
321
+ learning_rate=learning_rate,
322
+ evaluation_strategy="epoch" if validation_dataset else "no",
323
+ save_strategy="epoch",
324
+ load_best_model_at_end=True if validation_dataset else False,
325
+ )
326
+
327
+ # Initialize trainer
328
+ trainer = Trainer(
329
+ model=self.model,
330
+ args=training_args,
331
+ train_dataset=train_dataset,
332
+ eval_dataset=validation_dataset,
333
+ )
334
+
335
+ # Train the model
336
+ trainer.train()
337
+
338
+ # Save the fine-tuned model
339
+ save_dir = f"./fine_tuned_{self.model_type}"
340
+ self.model.save_pretrained(save_dir)
341
+ self.tokenizer.save_pretrained(save_dir)
342
+ print(f"Model saved to {save_dir}")
343
+
344
+ @classmethod
345
+ def load_fine_tuned(cls, model_type, model_path):
346
+ """
347
+ Load a fine-tuned model from disk
348
+
349
+ Args:
350
+ model_type (str): The type of model that was fine-tuned
351
+ model_path (str): Path to the saved model
352
+ """
353
+ classifier = cls(model_type)
354
+ classifier.model = AutoModelForSequenceClassification.from_pretrained(model_path)
355
+ classifier.tokenizer = AutoTokenizer.from_pretrained(model_path)
356
+ return classifier
requirements.lock ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit==1.32.0
2
+ altair==5.2.0
3
+ attrs==23.2.0
4
+ blinker==1.7.0
5
+ cachetools==5.3.3
6
+ certifi==2024.2.2
7
+ charset-normalizer==3.3.2
8
+ click==8.1.7
9
+ gitdb==4.0.11
10
+ gitpython==3.1.42
11
+ idna==3.6
12
+ importlib-metadata==7.0.2
13
+ jinja2==3.1.3
14
+ jsonschema==4.21.1
15
+ markdown-it-py==3.0.0
16
+ markupsafe==2.1.5
17
+ mdurl==0.1.2
18
+ numpy==1.26.4
19
+ packaging==23.2
20
+ pandas==2.2.0
21
+ pillow==10.2.0
22
+ protobuf==4.25.3
23
+ pyarrow==15.0.1
24
+ pydeck==0.8.1b0
25
+ pygments==2.17.2
26
+ python-dateutil==2.9.0
27
+ pytz==2024.1
28
+ requests==2.31.0
29
+ rich==13.7.1
30
+ six==1.16.0
31
+ smmap==5.0.1
32
+ tenacity==8.2.3
33
+ toml==0.10.2
34
+ toolz==0.12.1
35
+ tornado==6.4
36
+ typing-extensions==4.10.0
37
+ tzdata==2024.1
38
+ tzlocal==5.2
39
+ urllib3==2.2.1
40
+ validators==0.22.0
41
+ watchdog==4.0.0
42
+ zipp==3.17.0
43
+ torch==2.2.0
44
+ filelock==3.13.1
45
+ fsspec==2024.2.0
46
+ jinja2==3.1.3
47
+ networkx==3.2.1
48
+ sympy==1.12
49
+ typing-extensions==4.10.0
50
+ transformers==4.37.2
51
+ huggingface-hub==0.21.4
52
+ packaging==23.2
53
+ pyyaml==6.0.1
54
+ regex==2023.12.25
55
+ requests==2.31.0
56
+ tokenizers==0.15.2
57
+ tqdm==4.66.2
58
+ scikit-learn==1.4.0
59
+ joblib==1.3.2
60
+ numpy==1.26.4
61
+ scipy==1.12.0
62
+ threadpoolctl==3.3.0
63
+ PyPDF2==3.0.1
64
+ typing-extensions==4.10.0
requirements.txt CHANGED
@@ -2,4 +2,11 @@ streamlit==1.32.0
2
  torch==2.2.0
3
  transformers==4.37.2
4
  scikit-learn==1.4.0
5
- PyPDF2==3.0.1
 
 
 
 
 
 
 
 
2
  torch==2.2.0
3
  transformers==4.37.2
4
  scikit-learn==1.4.0
5
+ PyPDF2==3.0.1
6
+ datasets==2.18.0
7
+ arxiv==2.1.0
8
+ beautifulsoup4==4.12.3
9
+ sentencepiece==0.2.0
10
+ tokenizers==0.15.2
11
+ protobuf==4.25.3
12
+ sacremoses==0.1.1