v2
Browse files- ML2_2025_nlp_ops1.ipynb +167 -0
- README.md +146 -1
- app.py +111 -44
- model.py +338 -35
- requirements.lock +64 -0
- requirements.txt +8 -1
ML2_2025_nlp_ops1.ipynb
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"# __Девопсная домашка по трансформерам__\n",
|
8 |
+
"\n",
|
9 |
+
"## __Описание__\n",
|
10 |
+
"\n",
|
11 |
+
"\n",
|
12 |
+
"\n",
|
13 |
+
"Ваш главный квест на эту домашку - сделать свой простой сервис на трансформерах. Вот прям целый сервис: начиная с данных и заканчивая графическим интерфейсом где-то в интернете. Ваш сервис может решать либо одну из предложенных ниже задач, либо любую другую (что-то более дорогое лично вам).\n",
|
14 |
+
"\n",
|
15 |
+
"__Стандартная задача: классификатор статей.__ Нужно построить сервис который принимает название статьи и её abstract, и выдаёт наиболее вероятную тематику статьи: скажем, физика, биология или computer science. В интерфейсе должно быть можно ввести отдельно abstract, отдельно название -- и увидеть топ-95%* тематик, отсортированных по убыванию вероятности. Если abstract не ввели, нужно классифицировать статью только по названию. Ниже вас ждут инструкции и данные именно для этой задачи.\n",
|
16 |
+
"\n",
|
17 |
+
"<details><summary><u> Что значит Топ-95%?</u></summary>\n",
|
18 |
+
" Нужно выдавать темы по убыванию вероятности, пока их суммарная вероятность не превысит 95%. В зависимости от предсказанной вероятности, это может быть одна или более тем. Например, если модель предсказала вероятности [4%, 20%, 60%, 2%, 14%], нужно вывести 3 топ-3 класса. Если один из классов имеет вероятность 96%, достаточно вывести один этот класс.\n",
|
19 |
+
"</details>\n",
|
20 |
+
"\n",
|
21 |
+
"Альтернативно, вы можете отважиться сделать что-то своё, на данных из интернета или своих собственных. В вашей задаче обязательно должно быть _оправданное_ использование трансформеров. Использовать ML чтобы переводить часовые пояса - плохой план.\n",
|
22 |
+
"\n",
|
23 |
+
"Achtung: трансформеры круты, но не всемогущи. Далеко не любую задачу можно решить ощутимо лучше рандома. Для калибровки, вот несколько примеров решаемых задач (всё кликабельно):\n",
|
24 |
+
"\n",
|
25 |
+
"\n",
|
26 |
+
"<details><summary> - <b>[medium]</b> <u>Сгенерировать youtube-комментарии по _ссылке_ на видео</u></summary>\n",
|
27 |
+
" Всё просто, юзер постит ссылку на видео - вы его комментируете. Можно заранее обусловиться что видео только на английском или на русском. Нужно сочинить _несколько_ комментариев. Kudos если вместе с основным комментарием вы порождаете юзернеймы и-или ответы на него.\n",
|
28 |
+
" \n",
|
29 |
+
" Датасет для файнтюна можно [взять с kaggle](https://www.kaggle.com/tanmay111/youtube-comments-sentiment-analysis/data?select=UScomments.csv) или [собрать самостоятельно](https://towardsdatascience.com/how-to-build-your-own-dataset-of-youtube-comments-39a1e57aade).\n",
|
30 |
+
" \n",
|
31 |
+
" В качестве основной модели можно использовать [GPT-2 large](https://huggingface.co/gpt2-large). Вот как её файнтюнить: https://tinyurl.com/gpt2-finetune-colab . Если хотите больше - можно взять что-то из творчества https://huggingface.co/EleutherAI . Например, вот [тут](https://tinyurl.com/gpt-j-8bit) есть пример как файнтюнить GPT-J-6B (в 8 раз больше gpt2-large). Однако, этим стоит заниматься уже после того, как у вас заработал базовый сценарий с GPT2-large или даже base.\n",
|
32 |
+
" \n",
|
33 |
+
" В итоговом сервисе ��ожно дать пользователю вариировать параметры генерации: температура или top-p, если сэмплинг; beam size и length penalty, если beam search; сколько комментариев сгенерировать, etc. Отдельный респект если ваш код будет выводить комментарий по одному слову, прямо в процессе генерёжки - чтобы пользователь не ждал пока вы настругаете абзац целиком.\n",
|
34 |
+
"</details>\n",
|
35 |
+
"\n",
|
36 |
+
"<details><summary> - <b>[medium]</b> <u>Предсказать зарплату по профилю (симулятор Дудя).</u></summary>\n",
|
37 |
+
" Note: <details> <summary>Причём тут Дудь?</summary> <img src=https://www.meme-arsenal.com/memes/6dd85f126bbab4f9774ced71ffadbcb3.jpg> </details>\n",
|
38 |
+
" \n",
|
39 |
+
" Главная сложность задачи - достать хорошие данные. Если хороших данных не случилось - можно и трешовые :) Задание всё-таки про технологии а не про продукт. Для начала можно взять подмножество фичей [отсюда](https://www.kaggle.com/c/job-salary-prediction/data), которые вы можете восстановить из профиля linkedin - название профессии и компании. Название компании лучше заменить на фичи из открытых источников: сфера деятельности, размер, етц.\n",
|
40 |
+
" \n",
|
41 |
+
" А дальше файнтюним на этом BERT / T5 и радуемся. Ну или хотя бы смеёмся.\n",
|
42 |
+
"</details>\n",
|
43 |
+
"\n",
|
44 |
+
"\n",
|
45 |
+
"<details><summary> - <b>[hard]</b> <u>Мнения с географической окраской.</u></summary>\n",
|
46 |
+
" \n",
|
47 |
+
" Сервис который принимает на вход тему (хэштег или ключевую фразу) и рисует карту мира, где в каждом регионе показано, с какой эмоциональной окраской о ней высказываются в социальных сетях. В качестве социальной сети можно взять VK/twitter, в случая VK ожидается детализация не по странам, а по городам стран бывшего СССР.\n",
|
48 |
+
" \n",
|
49 |
+
" В минимальном варианте достаточно определять тональность твита в режиме \"позитивно-негативно\", зафайнтюнив условный BERT/T5 на одном из десятков {vk/twitter} sentiment classification датасетах. Географическую привязку можно получить из профиля пользователя. А дальше осталось собрать данные по странам и регионам.\n",
|
50 |
+
"\n",
|
51 |
+
"</details>\n",
|
52 |
+
"\n",
|
53 |
+
"\n",
|
54 |
+
"<details><summary> - <b>[very hard]</b> <u>Найти статью википедии по фото предмета статьи</u></summary>\n",
|
55 |
+
"\n",
|
56 |
+
" Чтобы можно было сфотать какую-нибудь неведомую чешуйню на телефон и получить сумму человеческих знаний о ней в форме вики-статьи.\n",
|
57 |
+
" \n",
|
58 |
+
" В качестве функции потерь можно использовать contrastive loss. Этот лосс неплохо описан в статье [CLIP](https://arxiv.org/abs/2103.00020). Вместо обучения с нуля предлагается взять, собственно, CLIP (text transformer + image transformer) отсюда: https://huggingface.co/docs/transformers/model_doc/clip. Модель будет сопоставлять каждой статьи и \n",
|
59 |
+
" \n",
|
60 |
+
" Данные для этого квеста можно собрать через API википедии: вики-статьи о предметах обычно содержит фото этого объекта и, собственно, текст статьи. Советуем собрать как минимум 10^4 пар картинка-статья. Картинки советуем дополнительно аугментировать как минимум стандартными картиночными аугами, как максимум - поиском похожих картинок в интернете / imagenet-е по тому же CLIP image encoder-у, но с исходными весами.\n",
|
61 |
+
" \n",
|
62 |
+
" На время отладки интерфейса рекомендуем ограничить��я небольшим списком статьей: условно, кошечки, собачки, птички, гаечные ключи, машины. Как станет понятно что оно работает \"на кошках\", можно расширить этот список до \"всех статей таких-то категорий\". Эмбединги статей лучше предпосчитать в файл. Если долго их перебирать - можно (но необязательно) воспользоваться быстрым поиском соседей, e.g. [faiss](https://github.com/facebookresearch/faiss) HNSW.\n",
|
63 |
+
"</details>\n",
|
64 |
+
"\n",
|
65 |
+
"\n",
|
66 |
+
"## __Как научить классификатор статей?__\n",
|
67 |
+
"\n",
|
68 |
+
"Данные для классификации статей можно скачать, например, [отсюда](https://www.kaggle.com/neelshah18/arxivdataset/). В этих данных есть заголовок и abstract статьи, а ещё поле __\"tag\"__: тематика статьи [по таксономии arxiv.org](https://arxiv.org/category_taxonomy). Вы можете расширить выборку, добавив в неё статьи за 2019-н.в. годы. Для этого можно [использовать arxiv API](https://github.com/lukasschwab/arxiv.py), самостоятельно распарсить arxiv с помощью [beautifulsoup](https://pypi.org/project/beautifulsoup4/), или поискать другие датасеты на kaggle, huggingface, etc.\n",
|
69 |
+
"\n",
|
70 |
+
"Когда данные собраны (и аккуратно нарезаны на train/test), можно что-нибудь и обучить. Мы советуем использовать для этого библиотеку `transformers`. Советуем, но не заставляем: если хочется, можно взять [fairseq roberta](https://github.com/pytorch/fairseq/blob/main/examples/roberta), [google t5](https://github.com/google-research/text-to-text-transfer-transformer) или даже написать всё с нуля.\n",
|
71 |
+
"\n",
|
72 |
+
"Мы разбирали transformers на [семинаре](https://lk.yandexdataschool.ru/courses/2025-spring/7.1332-machine-learning-2/classes/13138/), за любой дополнительной информацией - смотрите [документации HF](https://huggingface.co/docs).\n",
|
73 |
+
"\n",
|
74 |
+
"Начать лучше с простой модели, такой как [`distilbert-base-cased`](https://huggingface.co/distilbert-base-cased). Когда вы будете понимать, какие значения accuracy ожидать от базовой модели, можно поискать что-то получше. Два очевидных направления улучшения: (1) сильнее модель T5 или deberta v3, или (2) близкие данные, например взять модель которую предобучили на том же arxiv. И то и другое удобно [искать здесь](https://huggingface.co/models).\n",
|
75 |
+
"\n",
|
76 |
+
"## __Научили, и что теперь?__\n",
|
77 |
+
"\n",
|
78 |
+
"А теперь нужно сделать так, чтобы ваша обученная модель отвечала на запросы в интернете. Как и на прошлом этапе, вы можете сделать это кучей разных способов: от простого [streamlit](https://streamlit.io/) / [gradio](https://gradio.app/), минуя [TorchServe](https://pytorch.org/serve/) с [Triton/TensorRT](https://developer.nvidia.com/nvidia-triton-inference-server), и заканчивая экспортом модели в javascript с помощью [TensorFlow.js](https://www.tensorflow.org/js/tutorials) / [ONNX.js](https://github.com/elliotwaite/pytorch-to-javascript-with-onnx-js).\n",
|
79 |
+
"\n",
|
80 |
+
"На [семинаре](https://lk.yandexdataschool.ru/courses/2025-spring/7.1332-machine-learning-2/classes/13138/) мы разбирали основные вещи про то как работает streamlit и как сделать простое приложение с его помощью.\n",
|
81 |
+
"\n",
|
82 |
+
"Общая идея streamlit: вы [описываете](https://docs.streamlit.io/library/get-started/create-an-app) внешний вид приложения на питоне с помощью примитивов (кнопки, поля, любой html) -- а потом этот код выполняется на сервере и обслуживает каждого пользователя в отдельном процессе.\n",
|
83 |
+
"\n",
|
84 |
+
"__Для отладки__ можно запустить приложение локально, открыв консоль рядом с app.py:\n",
|
85 |
+
"* `pip install streamlit`\n",
|
86 |
+
"* `streamlit run app.py --server.port 8080`\n",
|
87 |
+
"* открыть в браузере localhost:8080, если он не открылся автоматически\n",
|
88 |
+
"\n",
|
89 |
+
"\n",
|
90 |
+
"## __Deployment time!__\n",
|
91 |
+
"\n",
|
92 |
+
"В этот раз вам нужно не просто написать код, __но и поднять ваше приложение с доступом из интернета__. И да, вы угадали, это можно сделать несколькими способами: [HuggingFace spaces](https://huggingface.co/spaces) (данный способ разбирали на [семинаре](https://lk.yandexdataschool.ru/courses/2025-spring/7.1332-machine-learning-2/classes/13138/)), [Streamlit Cloud](https://streamlit.io/cloud), а ещё вы можете купить или арендовать свой собственный сервер и захоститься там.\n",
|
93 |
+
"\n",
|
94 |
+
"Проще всего захостить на HF spaces, для этого вам нужно [зарегистрироваться](https://huggingface.co/join) и найти [меню создания нового приложения](https://huggingface.co/new-space). Название и лицензию можно выбрать на своё усмотрение, главное чтобы Space SDK был Streamlit, а доступ - public.\n",
|
95 |
+
"\n",
|
96 |
+
"Как создали - можно редактировать ваше приложение прямо на сайте, для этого откройте приложение и перейдите в Files and versions, и там в правом углу добавьте нужные файлы.\n",
|
97 |
+
"\n",
|
98 |
+
"На минималках вам потребуется 2 файла:\n",
|
99 |
+
"- `app.py`, о котором мы говорили выше\n",
|
100 |
+
"- `requirements.txt`, где вы укажете нужные вам библиотеки\n",
|
101 |
+
"\n",
|
102 |
+
"Вы можете разместить там же веса вашей обученной модели, любые необходимые данные, дополнительные файлы, ...\n",
|
103 |
+
"\n",
|
104 |
+
"После каждого изменения файлов, ваше приложение соберётся (обычно 1-5 минут) и будет доступно уже во вкладке App. Ну или не соберётся и покажет вам, где оно сломалось. И вуаля, теперь у вас есть ссылка, которую можно показать ~друзьям~ ассистентам курса и кому угодно в интернете.\n",
|
105 |
+
"\n",
|
106 |
+
"__Удобная работа с кодом.__ Пока у вас 2 файла, их легко редактивровать прямо в интерфейсе HF spaces. Если же у вас дюжина файлов, вам может быть удобнее редактировать их в любимом vscode/pycharm/.../emacs. Чтобы это не вызывало мучений, можно пользоваться HF spaces как git репозиторием ([подробности тут](https://huggingface.co/docs/hub/spaces#manage-app-with-github-actions)).\n",
|
107 |
+
"\n",
|
108 |
+
"## __Что нужно сдать__\n",
|
109 |
+
"\n",
|
110 |
+
"Вы сдаёте проект, который будет проверяться вручную, то что ожидается от каждого проекта:\n",
|
111 |
+
"- Текстовое сопровождение вашего конкретного проекта в любом удобно читаемом формате (pdf, html, текст в lk, ...) - что за задачу вы решали, где/как брали данные, какие использовали модели, какие проводили эксперименты, ...\n",
|
112 |
+
"- Ссылка на веб интерфейс, где можно протестировать демо вашего проекта - обязательно проверяйте что работает не только у вас (с другого устройства и из под incognito режима)\n",
|
113 |
+
"- Код обучения вашей модели (желательно ipynb с заполненными ячейками и не стёртыми выходами, переведённый в pdf / html), но если вы обучали не в ноутбуке, то сдавайте код в виде файла / архива файлов / git ссылки с readme.md описанием того как именно проходило обучение с помощью этого кода.\n",
|
114 |
+
"\n",
|
115 |
+
"## __Оценка__\n",
|
116 |
+
"\n",
|
117 |
+
"Мы будем оценивать проект целиком, включая идею и реализацию. Максимум за проект можно получить 10 баллов, но мы оставляем ещё до 5 баллов, котор��е можем выдать как бонусные за особенно интересные и качественно реализованные проекты.\n",
|
118 |
+
"\n",
|
119 |
+
"### __Тонкие места, за которые могут быть снижения баллов:__\n",
|
120 |
+
"\n",
|
121 |
+
"__1. Скорость работы.__\n",
|
122 |
+
"\n",
|
123 |
+
"По умолчанию, streamlit будет выполняет весь ваш код на каждое действие пользователя. То есть всякий раз, когда пользователь меняет что-то в тексте, оно будет заново загружать модель. Чтобы исправить это безобразие, вы можете закэшировать подготовленную модель в `@st.cache`. Подробности в [семинаре](https://lk.yandexdataschool.ru/courses/2025-spring/7.1332-machine-learning-2/classes/13138/), а также [читайте тут](https://docs.streamlit.io/library/advanced-features/caching).\n",
|
124 |
+
"\n",
|
125 |
+
"__Как будет оцениваться:__\n",
|
126 |
+
"\n",
|
127 |
+
"Вы не обязаны пользоваться кэшированием, но ваше приложение не должно неоправдано тормозить дольше, чем на 3 секунды. \"Оправданые\" тормоза это те, которые вы явно оправдали текстом в ЛМС :)\n",
|
128 |
+
"\n",
|
129 |
+
"-----\n",
|
130 |
+
"\n",
|
131 |
+
"__2. Понятный фронтенд.__\n",
|
132 |
+
"\n",
|
133 |
+
"Наколеночный графический интерфейс с семинара - пример того, как скорее не надо делать интерфейс приложения. Как надо - сложный вопрос, причём настолько сложный, что есть даже [Школа Разработки Интерфейсов](https://academy.yandex.ru/schools/frontend). Но для начала:\n",
|
134 |
+
"\n",
|
135 |
+
"- Выводить нужно человекочитаемый текст, а не просто JSON с индексами и метаданными.\n",
|
136 |
+
"- Пользователю должно быть понятно, куда и какие данные вводить. Пустые текстовые поля в вакууме - плохой тон.\n",
|
137 |
+
"- Сервис не должен падать с не_отловленными ошибками. Даже если пользователь введёт неправильные/пустые данные, нужно это обработать и написать, где произошла ошибка.\n",
|
138 |
+
"\n",
|
139 |
+
"__Как будет оцениваться:__\n",
|
140 |
+
"\n",
|
141 |
+
"Для полного балла достаточно соблюсти эти три правила и специально не стрелять себе в ногу.\n",
|
142 |
+
"\n",
|
143 |
+
"-----\n",
|
144 |
+
"\n",
|
145 |
+
"__3. Код обучения и инференса.__\n",
|
146 |
+
"\n",
|
147 |
+
"Сдавая проект мы будем также получать от вас код проекта (как обучения ваших моделей, так и код веб интерфейса).\n",
|
148 |
+
"\n",
|
149 |
+
"__Как будет оцениваться:__\n",
|
150 |
+
"\n",
|
151 |
+
"Код не будет отдельно проверяться как часть задания, поэтому пишите как хотите, однако - в спорных ситуациях мы оставляем за собой право проверить ваш код, за чем могут последовать потенциальные снижения баллов при любых нарушениях.\n"
|
152 |
+
]
|
153 |
+
},
|
154 |
+
{
|
155 |
+
"cell_type": "markdown",
|
156 |
+
"metadata": {},
|
157 |
+
"source": []
|
158 |
+
}
|
159 |
+
],
|
160 |
+
"metadata": {
|
161 |
+
"language_info": {
|
162 |
+
"name": "python"
|
163 |
+
}
|
164 |
+
},
|
165 |
+
"nbformat": 4,
|
166 |
+
"nbformat_minor": 2
|
167 |
+
}
|
README.md
CHANGED
@@ -11,6 +11,8 @@ pinned: false
|
|
11 |
|
12 |
# 📚 Academic Paper Classifier
|
13 |
|
|
|
|
|
14 |
This Streamlit application helps classify academic papers into different categories using a BERT-based model.
|
15 |
|
16 |
## Features
|
@@ -128,4 +130,147 @@ uv pip install -r requirements.lock
|
|
128 |
|
129 |
## Requirements
|
130 |
|
131 |
-
See `requirements.txt` for a complete list of dependencies.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
# 📚 Academic Paper Classifier
|
13 |
|
14 |
+
[link](https://huggingface.co/spaces/ssbars/ysdaml4)
|
15 |
+
|
16 |
This Streamlit application helps classify academic papers into different categories using a BERT-based model.
|
17 |
|
18 |
## Features
|
|
|
130 |
|
131 |
## Requirements
|
132 |
|
133 |
+
See `requirements.txt` for a complete list of dependencies.
|
134 |
+
|
135 |
+
# ArXiv Paper Classifier
|
136 |
+
|
137 |
+
This project implements a machine learning system for classifying academic papers into ArXiv categories using state-of-the-art transformer models.
|
138 |
+
|
139 |
+
## Project Overview
|
140 |
+
|
141 |
+
The system uses pre-trained transformer models to classify academic papers into one of the main ArXiv categories:
|
142 |
+
- Computer Science (cs)
|
143 |
+
- Mathematics (math)
|
144 |
+
- Physics (physics)
|
145 |
+
- Quantitative Biology (q-bio)
|
146 |
+
- Quantitative Finance (q-fin)
|
147 |
+
- Statistics (stat)
|
148 |
+
- Electrical Engineering and Systems Science (eess)
|
149 |
+
- Economics (econ)
|
150 |
+
|
151 |
+
## Features
|
152 |
+
|
153 |
+
- Multiple model support:
|
154 |
+
- DistilBERT: Lightweight and fast model, good for testing
|
155 |
+
- DeBERTa-v3: Advanced model with better performance
|
156 |
+
- RoBERTa: Advanced model with strong performance
|
157 |
+
- SciBERT: Specialized for scientific text
|
158 |
+
- BERT: Classic model with good all-round performance
|
159 |
+
|
160 |
+
- Flexible input handling:
|
161 |
+
- Can process both title and abstract
|
162 |
+
- Handles text preprocessing and tokenization
|
163 |
+
- Supports different maximum sequence lengths
|
164 |
+
|
165 |
+
- Robust error handling:
|
166 |
+
- Multiple fallback mechanisms for tokenizer initialization
|
167 |
+
- Graceful degradation to simpler models if needed
|
168 |
+
- Detailed error messages and logging
|
169 |
+
|
170 |
+
## Installation
|
171 |
+
|
172 |
+
1. Clone the repository
|
173 |
+
2. Install dependencies:
|
174 |
+
```bash
|
175 |
+
pip install -r requirements.txt
|
176 |
+
```
|
177 |
+
|
178 |
+
## Usage
|
179 |
+
|
180 |
+
### Basic Usage
|
181 |
+
|
182 |
+
```python
|
183 |
+
from model import PaperClassifier
|
184 |
+
|
185 |
+
# Initialize classifier with default model (DistilBERT)
|
186 |
+
classifier = PaperClassifier()
|
187 |
+
|
188 |
+
# Classify a paper
|
189 |
+
result = classifier.classify_paper(
|
190 |
+
title="Your paper title",
|
191 |
+
abstract="Your paper abstract"
|
192 |
+
)
|
193 |
+
|
194 |
+
# Print results
|
195 |
+
print(result)
|
196 |
+
```
|
197 |
+
|
198 |
+
### Using Different Models
|
199 |
+
|
200 |
+
```python
|
201 |
+
# Initialize with DeBERTa-v3
|
202 |
+
classifier = PaperClassifier(model_type='deberta-v3')
|
203 |
+
|
204 |
+
# Initialize with RoBERTa
|
205 |
+
classifier = PaperClassifier(model_type='roberta')
|
206 |
+
|
207 |
+
# Initialize with SciBERT
|
208 |
+
classifier = PaperClassifier(model_type='scibert')
|
209 |
+
|
210 |
+
# Initialize with BERT
|
211 |
+
classifier = PaperClassifier(model_type='bert')
|
212 |
+
```
|
213 |
+
|
214 |
+
### Training on Custom Data
|
215 |
+
|
216 |
+
```python
|
217 |
+
# Prepare your training data
|
218 |
+
train_texts = ["paper1 title and abstract", "paper2 title and abstract", ...]
|
219 |
+
train_labels = ["cs", "math", ...]
|
220 |
+
|
221 |
+
# Train the model
|
222 |
+
classifier.train_on_arxiv(
|
223 |
+
train_texts=train_texts,
|
224 |
+
train_labels=train_labels,
|
225 |
+
epochs=3,
|
226 |
+
batch_size=16,
|
227 |
+
learning_rate=2e-5
|
228 |
+
)
|
229 |
+
```
|
230 |
+
|
231 |
+
## Model Details
|
232 |
+
|
233 |
+
### Available Models
|
234 |
+
|
235 |
+
1. **DistilBERT** (`distilbert`)
|
236 |
+
- Model: `distilbert-base-cased`
|
237 |
+
- Max length: 512 tokens
|
238 |
+
- Fast tokenizer
|
239 |
+
- Good for testing and quick results
|
240 |
+
|
241 |
+
2. **DeBERTa-v3** (`deberta-v3`)
|
242 |
+
- Model: `microsoft/deberta-v3-base`
|
243 |
+
- Max length: 512 tokens
|
244 |
+
- Uses DebertaV2TokenizerFast
|
245 |
+
- Advanced performance
|
246 |
+
|
247 |
+
3. **RoBERTa** (`roberta`)
|
248 |
+
- Model: `roberta-base`
|
249 |
+
- Max length: 512 tokens
|
250 |
+
- Strong performance on various tasks
|
251 |
+
|
252 |
+
4. **SciBERT** (`scibert`)
|
253 |
+
- Model: `allenai/scibert_scivocab_uncased`
|
254 |
+
- Max length: 512 tokens
|
255 |
+
- Specialized for scientific text
|
256 |
+
|
257 |
+
5. **BERT** (`bert`)
|
258 |
+
- Model: `bert-base-uncased`
|
259 |
+
- Max length: 512 tokens
|
260 |
+
- Classic model with good all-round performance
|
261 |
+
|
262 |
+
## Error Handling
|
263 |
+
|
264 |
+
The system includes robust error handling mechanisms:
|
265 |
+
- Multiple fallback levels for tokenizer initialization
|
266 |
+
- Graceful degradation to simpler models
|
267 |
+
- Detailed error messages and logging
|
268 |
+
- Automatic fallback to BERT tokenizer if needed
|
269 |
+
|
270 |
+
## Requirements
|
271 |
+
|
272 |
+
- Python 3.7+
|
273 |
+
- PyTorch
|
274 |
+
- Transformers library
|
275 |
+
- NumPy
|
276 |
+
- Sacremoses (for tokenization support)
|
app.py
CHANGED
@@ -11,19 +11,56 @@ import PyPDF2
|
|
11 |
import io
|
12 |
from model import PaperClassifier
|
13 |
|
14 |
-
# Initialize the classifier
|
15 |
@st.cache_resource
|
16 |
-
def load_classifier():
|
17 |
-
return PaperClassifier()
|
18 |
|
19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
# Title and description
|
22 |
st.title("📚 Academic Paper Classification")
|
23 |
st.markdown("""
|
24 |
This service helps you classify academic papers into different categories.
|
25 |
You can either:
|
26 |
-
-
|
27 |
- Upload a PDF file
|
28 |
""")
|
29 |
|
@@ -31,28 +68,44 @@ You can either:
|
|
31 |
col1, col2 = st.columns(2)
|
32 |
|
33 |
with col1:
|
34 |
-
st.subheader("Option 1:
|
35 |
-
|
36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
height=200,
|
38 |
-
placeholder="
|
39 |
)
|
40 |
|
41 |
-
if st.button("Classify
|
42 |
-
if
|
43 |
with st.spinner("Classifying..."):
|
44 |
-
result = classifier.classify_paper(
|
|
|
|
|
|
|
45 |
|
46 |
st.success("Classification Complete!")
|
47 |
-
st.write(f"**
|
48 |
-
st.write(f"**
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
-
|
51 |
-
st.subheader("Category Probabilities")
|
52 |
-
for category, prob in result['all_probabilities'].items():
|
53 |
-
st.progress(prob, text=f"{category}: {prob:.2%}")
|
54 |
else:
|
55 |
-
st.warning("Please enter
|
56 |
|
57 |
with col2:
|
58 |
st.subheader("Option 2: PDF Upload")
|
@@ -62,38 +115,52 @@ with col2:
|
|
62 |
if st.button("Classify PDF"):
|
63 |
try:
|
64 |
with st.spinner("Processing PDF..."):
|
65 |
-
#
|
66 |
-
|
67 |
-
text_content = ""
|
68 |
-
for page in pdf_reader.pages:
|
69 |
-
text_content += page.extract_text()
|
70 |
|
71 |
-
|
72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
|
74 |
st.success("Classification Complete!")
|
75 |
-
st.write(f"**
|
76 |
-
st.write(f"**
|
77 |
|
78 |
-
# Show
|
79 |
-
st.subheader("
|
80 |
-
|
81 |
-
|
|
|
|
|
|
|
|
|
|
|
82 |
except Exception as e:
|
83 |
st.error(f"Error processing PDF: {str(e)}")
|
84 |
|
85 |
-
# Add information about the
|
86 |
-
st.sidebar.
|
87 |
-
st.sidebar.
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
**
|
92 |
-
-
|
93 |
-
-
|
94 |
-
- Physics
|
95 |
-
- Biology
|
96 |
-
- Economics
|
97 |
""")
|
98 |
|
99 |
# Add footer
|
|
|
11 |
import io
|
12 |
from model import PaperClassifier
|
13 |
|
14 |
+
# Initialize the classifier with model selection
|
15 |
@st.cache_resource
|
16 |
+
def load_classifier(model_type):
|
17 |
+
return PaperClassifier(model_type)
|
18 |
|
19 |
+
# Cache the PDF text extraction
|
20 |
+
@st.cache_data
|
21 |
+
def extract_pdf_text(pdf_bytes):
|
22 |
+
"""Extract text from PDF and try to separate title and abstract"""
|
23 |
+
pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_bytes))
|
24 |
+
text = ""
|
25 |
+
for page in pdf_reader.pages:
|
26 |
+
text += page.extract_text() + "\n"
|
27 |
+
|
28 |
+
# Try to extract title and abstract
|
29 |
+
lines = text.split('\n')
|
30 |
+
title = lines[0] if lines else ""
|
31 |
+
abstract = "\n".join(lines[1:]) if len(lines) > 1 else ""
|
32 |
+
|
33 |
+
return title.strip(), abstract.strip()
|
34 |
+
|
35 |
+
# Get available models for selection
|
36 |
+
available_models = list(PaperClassifier.AVAILABLE_MODELS.keys())
|
37 |
+
|
38 |
+
# Add model selection to sidebar
|
39 |
+
st.sidebar.title("Model Settings")
|
40 |
+
selected_model = st.sidebar.selectbox(
|
41 |
+
"Select Model",
|
42 |
+
available_models,
|
43 |
+
index=0,
|
44 |
+
help="Choose the model to use for classification"
|
45 |
+
)
|
46 |
+
|
47 |
+
# Display model information
|
48 |
+
model_info = PaperClassifier.AVAILABLE_MODELS[selected_model]
|
49 |
+
st.sidebar.markdown(f"""
|
50 |
+
### Selected Model
|
51 |
+
**Name:** {model_info['name']}
|
52 |
+
**Description:** {model_info['description']}
|
53 |
+
""")
|
54 |
+
|
55 |
+
# Initialize the classifier with selected model
|
56 |
+
classifier = load_classifier(selected_model)
|
57 |
|
58 |
# Title and description
|
59 |
st.title("📚 Academic Paper Classification")
|
60 |
st.markdown("""
|
61 |
This service helps you classify academic papers into different categories.
|
62 |
You can either:
|
63 |
+
- Enter the paper's title and abstract separately
|
64 |
- Upload a PDF file
|
65 |
""")
|
66 |
|
|
|
68 |
col1, col2 = st.columns(2)
|
69 |
|
70 |
with col1:
|
71 |
+
st.subheader("Option 1: Manual Input")
|
72 |
+
|
73 |
+
# Title input
|
74 |
+
title_input = st.text_input(
|
75 |
+
"Paper Title:",
|
76 |
+
placeholder="Enter the paper title..."
|
77 |
+
)
|
78 |
+
|
79 |
+
# Abstract input
|
80 |
+
abstract_input = st.text_area(
|
81 |
+
"Paper Abstract (optional):",
|
82 |
height=200,
|
83 |
+
placeholder="Enter the paper abstract (optional)..."
|
84 |
)
|
85 |
|
86 |
+
if st.button("Classify Paper"):
|
87 |
+
if title_input.strip():
|
88 |
with st.spinner("Classifying..."):
|
89 |
+
result = classifier.classify_paper(
|
90 |
+
title=title_input,
|
91 |
+
abstract=abstract_input if abstract_input.strip() else None
|
92 |
+
)
|
93 |
|
94 |
st.success("Classification Complete!")
|
95 |
+
st.write(f"**Input Type:** {result['input_type'].replace('_', ' ').title()}")
|
96 |
+
st.write(f"**Model Used:** {result['model_used']}")
|
97 |
+
|
98 |
+
# Show top categories
|
99 |
+
st.subheader("Top Categories (95% Confidence)")
|
100 |
+
total_prob = 0
|
101 |
+
for cat_info in result['top_categories']:
|
102 |
+
prob = cat_info['probability']
|
103 |
+
total_prob += prob
|
104 |
+
st.progress(prob, text=f"{cat_info['category']} ({cat_info['arxiv_category']}): {prob:.1%}")
|
105 |
|
106 |
+
st.info(f"Total probability of shown categories: {total_prob:.1%}")
|
|
|
|
|
|
|
107 |
else:
|
108 |
+
st.warning("Please enter at least the paper title.")
|
109 |
|
110 |
with col2:
|
111 |
st.subheader("Option 2: PDF Upload")
|
|
|
115 |
if st.button("Classify PDF"):
|
116 |
try:
|
117 |
with st.spinner("Processing PDF..."):
|
118 |
+
# Extract title and abstract from PDF
|
119 |
+
title, abstract = extract_pdf_text(uploaded_file.read())
|
|
|
|
|
|
|
120 |
|
121 |
+
if not title:
|
122 |
+
st.error("Could not extract title from PDF.")
|
123 |
+
st.stop()
|
124 |
+
|
125 |
+
# Show extracted text
|
126 |
+
with st.expander("Show extracted text"):
|
127 |
+
st.write("**Extracted Title:**")
|
128 |
+
st.write(title)
|
129 |
+
if abstract:
|
130 |
+
st.write("**Extracted Abstract:**")
|
131 |
+
st.write(abstract)
|
132 |
+
|
133 |
+
# Classify the paper
|
134 |
+
result = classifier.classify_paper(
|
135 |
+
title=title,
|
136 |
+
abstract=abstract if abstract else None
|
137 |
+
)
|
138 |
|
139 |
st.success("Classification Complete!")
|
140 |
+
st.write(f"**Input Type:** {result['input_type'].replace('_', ' ').title()}")
|
141 |
+
st.write(f"**Model Used:** {result['model_used']}")
|
142 |
|
143 |
+
# Show top categories
|
144 |
+
st.subheader("Top Categories (95% Confidence)")
|
145 |
+
total_prob = 0
|
146 |
+
for cat_info in result['top_categories']:
|
147 |
+
prob = cat_info['probability']
|
148 |
+
total_prob += prob
|
149 |
+
st.progress(prob, text=f"{cat_info['category']} ({cat_info['arxiv_category']}): {prob:.1%}")
|
150 |
+
|
151 |
+
st.info(f"Total probability of shown categories: {total_prob:.1%}")
|
152 |
except Exception as e:
|
153 |
st.error(f"Error processing PDF: {str(e)}")
|
154 |
|
155 |
+
# Add information about the models
|
156 |
+
st.sidebar.markdown("---")
|
157 |
+
st.sidebar.title("Available Models")
|
158 |
+
st.sidebar.markdown("""
|
159 |
+
- **DistilBERT**: Fast and lightweight
|
160 |
+
- **DeBERTa v3**: Advanced performance
|
161 |
+
- **T5**: Versatile text-to-text
|
162 |
+
- **RoBERTa**: Strong performance
|
163 |
+
- **SciBERT**: Specialized for science
|
|
|
|
|
|
|
164 |
""")
|
165 |
|
166 |
# Add footer
|
model.py
CHANGED
@@ -1,53 +1,356 @@
|
|
1 |
-
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
2 |
import torch
|
3 |
import numpy as np
|
|
|
4 |
|
5 |
class PaperClassifier:
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
-
#
|
13 |
self.categories = [
|
14 |
-
"Computer Science
|
15 |
-
"
|
16 |
-
"
|
17 |
-
"
|
18 |
-
"
|
|
|
|
|
|
|
19 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
-
def preprocess_text(self,
|
22 |
-
|
23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
-
def classify_paper(self,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
# Preprocess the text
|
27 |
-
processed_text = self.preprocess_text(
|
28 |
|
29 |
# Tokenize
|
30 |
-
inputs = self.tokenizer(
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
|
|
|
|
35 |
|
36 |
# Get model predictions
|
37 |
with torch.no_grad():
|
38 |
outputs = self.model(**inputs)
|
39 |
-
predictions = torch.softmax(outputs.logits, dim=1)
|
40 |
-
|
41 |
-
# Get
|
42 |
-
|
43 |
-
confidence = predictions[0][predicted_idx].item()
|
44 |
|
45 |
-
# Return
|
46 |
return {
|
47 |
-
'
|
48 |
-
'
|
49 |
-
'
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig
|
2 |
import torch
|
3 |
import numpy as np
|
4 |
+
import logging
|
5 |
|
6 |
class PaperClassifier:
|
7 |
+
# Available models with their configurations
|
8 |
+
AVAILABLE_MODELS = {
|
9 |
+
'distilbert': {
|
10 |
+
'name': 'distilbert-base-cased',
|
11 |
+
'max_length': 512,
|
12 |
+
'description': 'Lightweight and fast model, good for testing',
|
13 |
+
'force_slow': False,
|
14 |
+
'tokenizer_class': None # Use default
|
15 |
+
},
|
16 |
+
'deberta-v3': {
|
17 |
+
'name': 'microsoft/deberta-v3-base',
|
18 |
+
'max_length': 512,
|
19 |
+
'description': 'Advanced model with better performance',
|
20 |
+
'force_slow': True, # Force slow tokenizer for DeBERTa
|
21 |
+
'tokenizer_class': 'DebertaV2TokenizerFast' # Specify tokenizer class
|
22 |
+
},
|
23 |
+
't5': {
|
24 |
+
'name': 'google/t5-v1_1-base',
|
25 |
+
'max_length': 512,
|
26 |
+
'description': 'Versatile text-to-text model',
|
27 |
+
'force_slow': False
|
28 |
+
},
|
29 |
+
'roberta': {
|
30 |
+
'name': 'roberta-base',
|
31 |
+
'max_length': 512,
|
32 |
+
'description': 'Advanced model with strong performance',
|
33 |
+
'force_slow': False,
|
34 |
+
'tokenizer_class': None # Use default
|
35 |
+
},
|
36 |
+
'scibert': {
|
37 |
+
'name': 'allenai/scibert_scivocab_uncased',
|
38 |
+
'max_length': 512,
|
39 |
+
'description': 'Specialized for scientific text',
|
40 |
+
'force_slow': False,
|
41 |
+
'tokenizer_class': None # Use default
|
42 |
+
},
|
43 |
+
'bert': {
|
44 |
+
'name': 'bert-base-uncased',
|
45 |
+
'max_length': 512,
|
46 |
+
'description': 'Classic BERT model, good all-round performance',
|
47 |
+
'force_slow': False,
|
48 |
+
'tokenizer_class': None # Use default
|
49 |
+
}
|
50 |
+
}
|
51 |
+
|
52 |
+
def __init__(self, model_type='distilbert'):
|
53 |
+
"""
|
54 |
+
Initialize the classifier with a specific model type
|
55 |
+
|
56 |
+
Args:
|
57 |
+
model_type (str): One of 'distilbert', 'deberta-v3', 't5', 'roberta', 'scibert'
|
58 |
+
"""
|
59 |
+
if model_type not in self.AVAILABLE_MODELS:
|
60 |
+
raise ValueError(f"Model type must be one of {list(self.AVAILABLE_MODELS.keys())}")
|
61 |
+
|
62 |
+
self.model_type = model_type
|
63 |
+
self.model_config = self.AVAILABLE_MODELS[model_type]
|
64 |
+
self.model_name = self.model_config['name']
|
65 |
|
66 |
+
# ArXiv main categories with descriptions
|
67 |
self.categories = [
|
68 |
+
"cs", # Computer Science
|
69 |
+
"math", # Mathematics
|
70 |
+
"physics", # Physics
|
71 |
+
"q-bio", # Quantitative Biology
|
72 |
+
"q-fin", # Quantitative Finance
|
73 |
+
"stat", # Statistics
|
74 |
+
"eess", # Electrical Engineering and Systems Science
|
75 |
+
"econ" # Economics
|
76 |
]
|
77 |
+
|
78 |
+
# Human readable category names
|
79 |
+
self.category_names = {
|
80 |
+
"cs": "Computer Science",
|
81 |
+
"math": "Mathematics",
|
82 |
+
"physics": "Physics",
|
83 |
+
"q-bio": "Biology",
|
84 |
+
"q-fin": "Finance",
|
85 |
+
"stat": "Statistics",
|
86 |
+
"eess": "Electrical Engineering",
|
87 |
+
"econ": "Economics"
|
88 |
+
}
|
89 |
+
|
90 |
+
# Initialize tokenizer with proper error handling
|
91 |
+
self._initialize_tokenizer()
|
92 |
+
|
93 |
+
# Initialize model with proper error handling
|
94 |
+
self._initialize_model()
|
95 |
+
|
96 |
+
# Print model info
|
97 |
+
print(f"Initialized {model_type} model: {self.model_name}")
|
98 |
+
print(f"Description: {self.model_config['description']}")
|
99 |
+
print("Note: This model needs to be fine-tuned on ArXiv data for accurate predictions.")
|
100 |
+
|
101 |
+
def _initialize_tokenizer(self):
|
102 |
+
"""Initialize the tokenizer with proper error handling"""
|
103 |
+
try:
|
104 |
+
# First try loading the tokenizer configuration
|
105 |
+
config = AutoConfig.from_pretrained(self.model_name)
|
106 |
+
|
107 |
+
# Try loading the tokenizer with specific class if specified
|
108 |
+
if self.model_config['tokenizer_class']:
|
109 |
+
from transformers import DebertaV2TokenizerFast
|
110 |
+
self.tokenizer = DebertaV2TokenizerFast.from_pretrained(
|
111 |
+
self.model_name,
|
112 |
+
model_max_length=self.model_config['max_length']
|
113 |
+
)
|
114 |
+
else:
|
115 |
+
# Try loading with AutoTokenizer
|
116 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
117 |
+
self.model_name,
|
118 |
+
model_max_length=self.model_config['max_length'],
|
119 |
+
use_fast=not self.model_config['force_slow'],
|
120 |
+
trust_remote_code=True
|
121 |
+
)
|
122 |
+
|
123 |
+
print(f"Successfully initialized tokenizer for {self.model_type}")
|
124 |
+
|
125 |
+
except Exception as e:
|
126 |
+
print(f"Error initializing tokenizer: {str(e)}")
|
127 |
+
print("Falling back to basic tokenizer...")
|
128 |
+
|
129 |
+
# Try one more time with minimal settings
|
130 |
+
try:
|
131 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
132 |
+
self.model_name,
|
133 |
+
use_fast=False,
|
134 |
+
trust_remote_code=True
|
135 |
+
)
|
136 |
+
except Exception as e:
|
137 |
+
# If all else fails, try using BERT tokenizer as last resort
|
138 |
+
print("Falling back to BERT tokenizer...")
|
139 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
140 |
+
'bert-base-uncased',
|
141 |
+
model_max_length=self.model_config['max_length']
|
142 |
+
)
|
143 |
+
|
144 |
+
def _initialize_model(self):
|
145 |
+
"""Initialize the model with proper error handling"""
|
146 |
+
try:
|
147 |
+
self.model = AutoModelForSequenceClassification.from_pretrained(
|
148 |
+
self.model_name,
|
149 |
+
num_labels=len(self.categories),
|
150 |
+
id2label={i: label for i, label in enumerate(self.categories)},
|
151 |
+
label2id={label: i for i, label in enumerate(self.categories)},
|
152 |
+
trust_remote_code=True # Allow custom code from hub
|
153 |
+
)
|
154 |
+
except Exception as e:
|
155 |
+
raise RuntimeError(f"Failed to initialize model: {str(e)}")
|
156 |
+
|
157 |
+
@classmethod
|
158 |
+
def list_available_models(cls):
|
159 |
+
"""List all available models with their descriptions"""
|
160 |
+
print("Available models:")
|
161 |
+
for model_type, config in cls.AVAILABLE_MODELS.items():
|
162 |
+
print(f"\n{model_type}:")
|
163 |
+
print(f" Model: {config['name']}")
|
164 |
+
print(f" Description: {config['description']}")
|
165 |
|
166 |
+
def preprocess_text(self, title, abstract=None):
|
167 |
+
"""
|
168 |
+
Preprocess title and abstract
|
169 |
+
|
170 |
+
Args:
|
171 |
+
title (str): Paper title
|
172 |
+
abstract (str, optional): Paper abstract
|
173 |
+
"""
|
174 |
+
if abstract:
|
175 |
+
text = f"Title: {title}\nAbstract: {abstract}"
|
176 |
+
else:
|
177 |
+
text = f"Title: {title}"
|
178 |
+
|
179 |
+
max_length = self.model_config['max_length']
|
180 |
+
|
181 |
+
if self.model_type == 't5':
|
182 |
+
text = "classify: " + text
|
183 |
+
|
184 |
+
return text[:max_length]
|
185 |
+
|
186 |
+
def get_top_categories(self, probabilities, threshold=0.95):
|
187 |
+
"""
|
188 |
+
Get top categories that sum up to the threshold
|
189 |
+
|
190 |
+
Args:
|
191 |
+
probabilities (torch.Tensor): Model predictions
|
192 |
+
threshold (float): Probability threshold (default: 0.95)
|
193 |
+
|
194 |
+
Returns:
|
195 |
+
list: List of (category, probability) tuples
|
196 |
+
"""
|
197 |
+
# Convert to numpy for easier manipulation
|
198 |
+
probs = probabilities.numpy()
|
199 |
+
|
200 |
+
# Sort indices by probability
|
201 |
+
sorted_indices = np.argsort(probs)[::-1]
|
202 |
+
|
203 |
+
# Calculate cumulative sum
|
204 |
+
cumsum = np.cumsum(probs[sorted_indices])
|
205 |
+
|
206 |
+
# Find how many categories we need to reach the threshold
|
207 |
+
mask = cumsum <= threshold
|
208 |
+
if not any(mask): # If first probability is already > threshold
|
209 |
+
mask[0] = True
|
210 |
+
|
211 |
+
# Get the selected indices
|
212 |
+
selected_indices = sorted_indices[mask]
|
213 |
+
|
214 |
+
# Return categories and their probabilities
|
215 |
+
return [
|
216 |
+
{
|
217 |
+
'category': self.category_names.get(self.categories[idx], self.categories[idx]),
|
218 |
+
'arxiv_category': self.categories[idx],
|
219 |
+
'probability': float(probs[idx])
|
220 |
+
}
|
221 |
+
for idx in selected_indices
|
222 |
+
]
|
223 |
|
224 |
+
def classify_paper(self, title, abstract=None):
|
225 |
+
"""
|
226 |
+
Classify a paper based on its title and optional abstract
|
227 |
+
|
228 |
+
Args:
|
229 |
+
title (str): Paper title
|
230 |
+
abstract (str, optional): Paper abstract
|
231 |
+
"""
|
232 |
# Preprocess the text
|
233 |
+
processed_text = self.preprocess_text(title, abstract)
|
234 |
|
235 |
# Tokenize
|
236 |
+
inputs = self.tokenizer(
|
237 |
+
processed_text,
|
238 |
+
return_tensors="pt",
|
239 |
+
truncation=True,
|
240 |
+
max_length=self.model_config['max_length'],
|
241 |
+
padding=True
|
242 |
+
)
|
243 |
|
244 |
# Get model predictions
|
245 |
with torch.no_grad():
|
246 |
outputs = self.model(**inputs)
|
247 |
+
predictions = torch.softmax(outputs.logits, dim=1)[0]
|
248 |
+
|
249 |
+
# Get top categories that sum to 95% probability
|
250 |
+
top_categories = self.get_top_categories(predictions)
|
|
|
251 |
|
252 |
+
# Return predictions
|
253 |
return {
|
254 |
+
'top_categories': top_categories,
|
255 |
+
'model_used': self.model_type,
|
256 |
+
'input_type': 'title_and_abstract' if abstract else 'title_only'
|
257 |
+
}
|
258 |
+
|
259 |
+
def train_on_arxiv(self, train_texts, train_labels, validation_texts=None, validation_labels=None,
|
260 |
+
epochs=3, batch_size=16, learning_rate=2e-5):
|
261 |
+
"""
|
262 |
+
Function to fine-tune the model on ArXiv data
|
263 |
+
|
264 |
+
Args:
|
265 |
+
train_texts (list): List of paper texts (title + abstract)
|
266 |
+
train_labels (list): List of corresponding ArXiv categories
|
267 |
+
validation_texts (list, optional): Validation texts
|
268 |
+
validation_labels (list, optional): Validation labels
|
269 |
+
epochs (int): Number of training epochs
|
270 |
+
batch_size (int): Training batch size
|
271 |
+
learning_rate (float): Learning rate for training
|
272 |
+
"""
|
273 |
+
from transformers import TrainingArguments, Trainer
|
274 |
+
import datasets
|
275 |
+
|
276 |
+
# Prepare datasets
|
277 |
+
train_encodings = self.tokenizer(
|
278 |
+
train_texts,
|
279 |
+
truncation=True,
|
280 |
+
padding=True,
|
281 |
+
max_length=self.model_config['max_length']
|
282 |
+
)
|
283 |
+
|
284 |
+
# Convert labels to ids
|
285 |
+
train_label_ids = [self.categories.index(label) for label in train_labels]
|
286 |
+
|
287 |
+
# Create training dataset
|
288 |
+
train_dataset = datasets.Dataset.from_dict({
|
289 |
+
'input_ids': train_encodings['input_ids'],
|
290 |
+
'attention_mask': train_encodings['attention_mask'],
|
291 |
+
'labels': train_label_ids
|
292 |
+
})
|
293 |
+
|
294 |
+
# Create validation dataset if provided
|
295 |
+
if validation_texts and validation_labels:
|
296 |
+
val_encodings = self.tokenizer(
|
297 |
+
validation_texts,
|
298 |
+
truncation=True,
|
299 |
+
padding=True,
|
300 |
+
max_length=self.model_config['max_length']
|
301 |
+
)
|
302 |
+
val_label_ids = [self.categories.index(label) for label in validation_labels]
|
303 |
+
validation_dataset = datasets.Dataset.from_dict({
|
304 |
+
'input_ids': val_encodings['input_ids'],
|
305 |
+
'attention_mask': val_encodings['attention_mask'],
|
306 |
+
'labels': val_label_ids
|
307 |
+
})
|
308 |
+
else:
|
309 |
+
validation_dataset = None
|
310 |
+
|
311 |
+
# Training arguments
|
312 |
+
training_args = TrainingArguments(
|
313 |
+
output_dir=f"./results_{self.model_type}",
|
314 |
+
num_train_epochs=epochs,
|
315 |
+
per_device_train_batch_size=batch_size,
|
316 |
+
per_device_eval_batch_size=batch_size,
|
317 |
+
warmup_steps=500,
|
318 |
+
weight_decay=0.01,
|
319 |
+
logging_dir=f"./logs_{self.model_type}",
|
320 |
+
logging_steps=10,
|
321 |
+
learning_rate=learning_rate,
|
322 |
+
evaluation_strategy="epoch" if validation_dataset else "no",
|
323 |
+
save_strategy="epoch",
|
324 |
+
load_best_model_at_end=True if validation_dataset else False,
|
325 |
+
)
|
326 |
+
|
327 |
+
# Initialize trainer
|
328 |
+
trainer = Trainer(
|
329 |
+
model=self.model,
|
330 |
+
args=training_args,
|
331 |
+
train_dataset=train_dataset,
|
332 |
+
eval_dataset=validation_dataset,
|
333 |
+
)
|
334 |
+
|
335 |
+
# Train the model
|
336 |
+
trainer.train()
|
337 |
+
|
338 |
+
# Save the fine-tuned model
|
339 |
+
save_dir = f"./fine_tuned_{self.model_type}"
|
340 |
+
self.model.save_pretrained(save_dir)
|
341 |
+
self.tokenizer.save_pretrained(save_dir)
|
342 |
+
print(f"Model saved to {save_dir}")
|
343 |
+
|
344 |
+
@classmethod
|
345 |
+
def load_fine_tuned(cls, model_type, model_path):
|
346 |
+
"""
|
347 |
+
Load a fine-tuned model from disk
|
348 |
+
|
349 |
+
Args:
|
350 |
+
model_type (str): The type of model that was fine-tuned
|
351 |
+
model_path (str): Path to the saved model
|
352 |
+
"""
|
353 |
+
classifier = cls(model_type)
|
354 |
+
classifier.model = AutoModelForSequenceClassification.from_pretrained(model_path)
|
355 |
+
classifier.tokenizer = AutoTokenizer.from_pretrained(model_path)
|
356 |
+
return classifier
|
requirements.lock
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit==1.32.0
|
2 |
+
altair==5.2.0
|
3 |
+
attrs==23.2.0
|
4 |
+
blinker==1.7.0
|
5 |
+
cachetools==5.3.3
|
6 |
+
certifi==2024.2.2
|
7 |
+
charset-normalizer==3.3.2
|
8 |
+
click==8.1.7
|
9 |
+
gitdb==4.0.11
|
10 |
+
gitpython==3.1.42
|
11 |
+
idna==3.6
|
12 |
+
importlib-metadata==7.0.2
|
13 |
+
jinja2==3.1.3
|
14 |
+
jsonschema==4.21.1
|
15 |
+
markdown-it-py==3.0.0
|
16 |
+
markupsafe==2.1.5
|
17 |
+
mdurl==0.1.2
|
18 |
+
numpy==1.26.4
|
19 |
+
packaging==23.2
|
20 |
+
pandas==2.2.0
|
21 |
+
pillow==10.2.0
|
22 |
+
protobuf==4.25.3
|
23 |
+
pyarrow==15.0.1
|
24 |
+
pydeck==0.8.1b0
|
25 |
+
pygments==2.17.2
|
26 |
+
python-dateutil==2.9.0
|
27 |
+
pytz==2024.1
|
28 |
+
requests==2.31.0
|
29 |
+
rich==13.7.1
|
30 |
+
six==1.16.0
|
31 |
+
smmap==5.0.1
|
32 |
+
tenacity==8.2.3
|
33 |
+
toml==0.10.2
|
34 |
+
toolz==0.12.1
|
35 |
+
tornado==6.4
|
36 |
+
typing-extensions==4.10.0
|
37 |
+
tzdata==2024.1
|
38 |
+
tzlocal==5.2
|
39 |
+
urllib3==2.2.1
|
40 |
+
validators==0.22.0
|
41 |
+
watchdog==4.0.0
|
42 |
+
zipp==3.17.0
|
43 |
+
torch==2.2.0
|
44 |
+
filelock==3.13.1
|
45 |
+
fsspec==2024.2.0
|
46 |
+
jinja2==3.1.3
|
47 |
+
networkx==3.2.1
|
48 |
+
sympy==1.12
|
49 |
+
typing-extensions==4.10.0
|
50 |
+
transformers==4.37.2
|
51 |
+
huggingface-hub==0.21.4
|
52 |
+
packaging==23.2
|
53 |
+
pyyaml==6.0.1
|
54 |
+
regex==2023.12.25
|
55 |
+
requests==2.31.0
|
56 |
+
tokenizers==0.15.2
|
57 |
+
tqdm==4.66.2
|
58 |
+
scikit-learn==1.4.0
|
59 |
+
joblib==1.3.2
|
60 |
+
numpy==1.26.4
|
61 |
+
scipy==1.12.0
|
62 |
+
threadpoolctl==3.3.0
|
63 |
+
PyPDF2==3.0.1
|
64 |
+
typing-extensions==4.10.0
|
requirements.txt
CHANGED
@@ -2,4 +2,11 @@ streamlit==1.32.0
|
|
2 |
torch==2.2.0
|
3 |
transformers==4.37.2
|
4 |
scikit-learn==1.4.0
|
5 |
-
PyPDF2==3.0.1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
torch==2.2.0
|
3 |
transformers==4.37.2
|
4 |
scikit-learn==1.4.0
|
5 |
+
PyPDF2==3.0.1
|
6 |
+
datasets==2.18.0
|
7 |
+
arxiv==2.1.0
|
8 |
+
beautifulsoup4==4.12.3
|
9 |
+
sentencepiece==0.2.0
|
10 |
+
tokenizers==0.15.2
|
11 |
+
protobuf==4.25.3
|
12 |
+
sacremoses==0.1.1
|