zero-shot-prompt-classifier-bart-ft
This model is a fine-tuned version of facebook/bart-large-mnli, trained and tested on the natural language inference (NLI) dataset reddgr/nli-chatbot-prompt-categorization.
The purpose of the model is to classify chatbot prompts into categories that are relevant when working with LLM conversational tools: coding assistance, language assistance, role play, creative writing, general knowledge questions, and others.
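For reference, here is a minimal usage sketch with the Transformers zero-shot classification pipeline; the example prompt and candidate labels are illustrative, not taken from the dataset:

```python
from transformers import pipeline

# Load the fine-tuned model with the zero-shot classification pipeline
classifier = pipeline(
    "zero-shot-classification",
    model="reddgr/zero-shot-prompt-classifier-bart-ft",
)

# Hypothetical chatbot prompt and a subset of candidate categories
prompt = "Can you help me debug this Python function?"
candidate_labels = [
    "coding assistance",
    "language assistance",
    "role play",
    "creative writing",
    "general knowledge",
]

result = classifier(prompt, candidate_labels)
print(result["labels"][0], result["scores"][0])  # top category and its score
```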
The confusion matrix below was calculated on zero-shot inferences for the 10 most popular categories in the Test split of reddgr/nli-chatbot-prompt-categorization at the time of the first model upload. The classification results of the base model on the same small test set are shown for comparison.
The current version of the fine-tuned model outperforms the base model facebook/bart-large-mnli by 34 percentage points (76% vs. 42% accuracy) on a test set with 10 candidate zero-shot classes (the most frequent categories in the Test split of reddgr/nli-chatbot-prompt-categorization).
The chart below compares the results for the 12 most popular candidate classes in the Test split, where the fine-tuned model outperforms the base model's zero-shot accuracy by 32 percentage points.
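A sketch of how such an accuracy comparison could be reproduced, reusing the `classifier` pipeline defined above. The prompts, labels, and candidate classes below are placeholders, not the dataset's actual schema or contents:

```python
from sklearn.metrics import accuracy_score, confusion_matrix

# Placeholder evaluation data: chatbot prompts and their true categories,
# restricted to the most frequent classes in the Test split (hypothetical)
prompts = ["How do I merge two dictionaries in Python?", "Pretend you are a pirate"]
true_labels = ["coding assistance", "role play"]
candidate_labels = ["coding assistance", "role play", "general knowledge"]

# Keep the top-scoring label as the predicted class for each prompt
predicted = [classifier(p, candidate_labels)["labels"][0] for p in prompts]

print("accuracy:", accuracy_score(true_labels, predicted))
print(confusion_matrix(true_labels, predicted, labels=candidate_labels))
```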
We can also use the model to perform zero-shot inference on combinations of categories formulated in natural language. The chart below compares the results for the 6 main category groups that classify conversations on Talking to Chatbots.
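Because zero-shot labels are free-form text, category groups can be expressed directly as natural-language phrases. A sketch follows; the group names and hypothesis template are illustrative, not the six groups actually used on Talking to Chatbots:

```python
# Hypothetical natural-language category groups (illustrative only)
group_labels = [
    "coding or scripting assistance",
    "language and writing assistance",
    "role play or creative writing",
]

result = classifier(
    "Rewrite this paragraph in a more formal tone",
    group_labels,
    hypothesis_template="This prompt is a request for {}.",
)
print(result["labels"][0])  # top-scoring category group
```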
The dataset and the model are continuously updated as they assist with content publishing on my website, Talking to Chatbots.
Model description
More information needed
Intended uses & limitations
More information needed
Training and evaluation data
More information needed
Training procedure
Training hyperparameters
The following hyperparameters were used during training:
- optimizer: Adam (learning_rate: 5e-06, beta_1: 0.9, beta_2: 0.999, epsilon: 1e-07, amsgrad: False, weight_decay: None, use_ema: False, ema_momentum: 0.99, ema_overwrite_frequency: None, jit_compile: False, is_legacy_optimizer: False; no gradient clipping: clipnorm, global_clipnorm, and clipvalue all None)
- training_precision: float32
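For reference, these settings correspond to a standard Keras Adam optimizer. A minimal sketch of recreating it (the variable name is arbitrary):

```python
import tensorflow as tf

# Adam optimizer matching the hyperparameters listed above
optimizer = tf.keras.optimizers.Adam(
    learning_rate=5e-6,
    beta_1=0.9,
    beta_2=0.999,
    epsilon=1e-7,
    amsgrad=False,
)
```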
Training results
| Epoch | Step | Eval loss | Eval runtime (s) | Eval samples/s | Eval steps/s |
|-------|------|-----------|------------------|----------------|--------------|
| 1.0   | 19   | 0.8466    | 57.90            | 6.667          | 0.846        |
| 2.0   | 38   | 0.8361    | 60.24            | 6.407          | 0.813        |
| 3.0   | 57   | 0.6992    | 60.82            | 6.347          | 0.806        |
| 4.0   | 76   | 0.8125    | 59.20            | 6.520          | 0.828        |

Train metrics (4 epochs, 76 steps): train_loss 0.7128, train_runtime 1626.46 s, 1.424 samples/s, 0.047 steps/s, total_flos 624333153618216.0
Framework versions
- Transformers 4.44.2
- TensorFlow 2.18.0-dev20240717
- Datasets 2.21.0
- Tokenizers 0.19.1