--- |
|
|
license: apache-2.0 |
|
|
datasets: |
|
|
- WorkInTheDark/FairytaleQA |
|
|
language: |
|
|
- en |
|
|
metrics: |
|
|
- f1 |
|
|
- accuracy |
|
|
- recall |
|
|
base_model: |
|
|
- google-bert/bert-base-uncased |
|
|
pipeline_tag: text-classification |
|
|
library_name: transformers |
|
|
--- |
|
|
# BertForStorySkillClassification |
|
|
|
|
|
## Model Overview |
|
|
`BertForStorySkillClassification` is a BERT-based text classification model designed to categorize story-related questions into one of the following 7 classes: |
|
|
1. **Character** |
|
|
2. **Setting** |
|
|
3. **Feeling** |
|
|
4. **Action** |
|
|
5. **Causal Relationship** |
|
|
6. **Outcome Resolution** |
|
|
7. **Prediction** |
|
|
|
|
|
This model is suitable for applications in education, literary analysis, and story comprehension. |
|
|
|
|
|
--- |
|
|
|
|
|
## Model Architecture |
|
|
- **Base Model**: `bert-base-uncased` |
|
|
- **Classification Layer**: A fully connected layer on top of BERT for 7-class classification. |
|
|
- **Input**: a question alone (e.g., "Who is the main character in the story?"), a QA pair (e.g., "why could n't alice get a doll as a child ? \<SEP> because her family was very poor "), or a QA pair plus context (e.g., "why could n't alice get a doll as a child ? \<SEP> because her family was very poor \<context> alice is ... ").
|
|
- **Output**: Predicted label and confidence score. |
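The three input formats above are plain strings joined by the special markers shown in the examples. A minimal sketch (the question and answer come from the card's own example; the context text is an illustrative placeholder):

```python
# Assemble the three supported input formats as plain strings.
question = "why could n't alice get a doll as a child ?"
answer = "because her family was very poor"
context = "alice is a little girl from a very poor family ..."  # placeholder

question_only = question                            # format 1: question alone
qa_pair = f"{question} <SEP> {answer}"              # format 2: QA pair
qa_with_context = f"{qa_pair} <context> {context}"  # format 3: QA pair + context

print(qa_with_context)
```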
|
|
|
|
|
--- |
|
|
|
|
|
## Quick Start |
|
|
|
|
|
### Install Dependencies |
|
|
Ensure you have the `transformers` library installed: |
|
|
```bash |
|
|
pip install transformers |
|
|
``` |
|
|
|
|
|
### Load Model and Tokenizer |
|
|
|
|
|
```python |
|
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer |
|
|
|
|
|
model = AutoModelForSequenceClassification.from_pretrained("curious008/BertForStorySkillClassification") |
|
|
tokenizer = AutoTokenizer.from_pretrained("curious008/BertForStorySkillClassification") |
|
|
``` |
|
|
|
|
|
### Use the `predict` Method for Inference
|
|
|
|
|
```python |
|
|
# Single text prediction |
|
|
result = model.predict( |
|
|
texts="Where does this story take place?", |
|
|
tokenizer=tokenizer, |
|
|
return_probabilities=True |
|
|
) |
|
|
print(result) |
|
|
# Output: [{'text': 'Where does this story take place?', 'label': 'setting', 'score': 0.93178}] |
|
|
|
|
|
# Batch prediction |
|
|
results = model.predict( |
|
|
    texts=["Why is the character sad?", "How does the story end?", "why could n't alice get a doll as a child ? <SEP> because her family was very poor "],
|
|
tokenizer=tokenizer, |
|
|
batch_size=16, |
|
|
device="cuda" |
|
|
) |
|
|
print(results) |
|
|
""" |
|
|
output: |
|
|
[{'text': 'Why is the character sad?', 'label': 'causal relationship'}, |
|
|
{'text': 'How does the story end?', 'label': 'action'}, |
|
|
{'text': "why could n't alice get a doll as a child ? <SEP> because her family was very poor ", |
|
|
'label': 'causal relationship'}] |
|
|
""" |
|
|
``` |
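If the custom `predict` helper is unavailable in your environment, the same label and score can be recovered from a standard forward pass by applying a softmax over the seven class logits. The post-processing step is sketched below in plain Python; the logits are made-up example values, and the label order is an assumption, so consult the model's `config.id2label` for the actual mapping:

```python
import math

# Class names in an assumed order; check config.id2label for the real mapping.
labels = ["character", "setting", "feeling", "action",
          "causal relationship", "outcome resolution", "prediction"]

# Made-up example logits for one input, as a forward pass would produce.
logits = [0.2, 3.1, -0.5, 0.0, 1.2, -1.0, 0.3]

# Softmax turns raw logits into a probability distribution over the 7 classes.
exps = [math.exp(x) for x in logits]
probs = [e / sum(exps) for e in exps]

# The predicted label is the argmax; its probability is the confidence score.
best = max(range(len(labels)), key=lambda i: probs[i])
result = {"label": labels[best], "score": round(probs[best], 5)}
print(result)
```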
|
|
|
|
|
## Training Details |
|
|
### Dataset |
|
|
Source: [FairytaleQAData](https://github.com/uci-soe/FairytaleQAData) |
|
|
|
|
|
### Training Parameters |
|
|
- Learning Rate: 2e-5
- Batch Size: 32
- Epochs: 3
- Optimizer: AdamW
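These hyperparameters correspond to a straightforward `transformers` `TrainingArguments` setup. The sketch below is a hypothetical reconstruction: only the four listed values come from this card, while the output directory and optimizer string are assumptions.

```python
from transformers import TrainingArguments

# Hypothetical reconstruction of the listed hyperparameters.
training_args = TrainingArguments(
    output_dir="./bert-story-skill",  # assumed name, not stated in the card
    learning_rate=2e-5,
    per_device_train_batch_size=32,
    num_train_epochs=3,
    optim="adamw_torch",  # AdamW optimizer
)
```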
|
|
|
|
|
### Performance Metrics |
|
|
- Accuracy: 97.3%
- Recall: 96.59%
- F1 Score: 96.96%
|
|
|
|
|
## Notes |
|
|
1. **Input Length**: The model supports a maximum input length of 512 tokens. Longer texts are truncated.


2. **Device Support**: The model runs on both CPU and GPU. GPU is recommended for faster inference.


3. **Tokenizer**: Always load the matching tokenizer (via `AutoTokenizer`) for this model.
|
|
|
|
|
## Citation |
|
|
|
|
|
If you use this model, please cite the following: |
|
|
|
|
|
``` |
|
|
@misc{BertForStorySkillClassification, |
|
|
author = {curious}, |
|
|
title = {BertForStorySkillClassification: A BERT-based Model for Story Question Classification}, |
|
|
year = {2025}, |
|
|
publisher = {Hugging Face}, |
|
|
howpublished = {\url{https://huggingface.co/curious008/BertForStorySkillClassification}} |
|
|
} |
|
|
``` |
|
|
|
|
|
## License |
|
|
This model is open-sourced under the Apache 2.0 License. For details, see the [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0).