---
license: apache-2.0
datasets:
- WorkInTheDark/FairytaleQA
language:
- en
metrics:
- f1
- accuracy
- recall
base_model:
- google-bert/bert-base-uncased
pipeline_tag: text-classification
library_name: transformers
---
# BertForStorySkillClassification
## Model Overview
`BertForStorySkillClassification` is a BERT-based text classification model designed to categorize story-related questions into one of the following 7 classes:
1. **Character**
2. **Setting**
3. **Feeling**
4. **Action**
5. **Causal Relationship**
6. **Outcome Resolution**
7. **Prediction**

This model is suitable for applications in education, literary analysis, and story comprehension.

---
## Model Architecture
- **Base Model**: `bert-base-uncased`
- **Classification Layer**: A fully connected layer on top of BERT for 7-class classification.
- **Input**: text in one of three formats:
  - Question only, e.g. `Who is the main character in the story?`
  - Question-answer pair joined by `<SEP>`, e.g. `why could n't alice get a doll as a child ? <SEP> because her family was very poor`
  - Question-answer pair followed by story context after `<context>`, e.g. `why could n't alice get a doll as a child ? <SEP> because her family was very poor <context> alice is ...`
- **Output**: The predicted label, plus a confidence score when `return_probabilities=True` is passed to `predict`.
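The three input formats can be assembled with a small helper. This is a minimal sketch: the function name `build_input` is hypothetical, and it assumes the `<SEP>` and `<context>` markers are inserted as plain text separated by single spaces, mirroring the examples in the input description.

```python
from typing import Optional


def build_input(question: str, answer: Optional[str] = None,
                context: Optional[str] = None) -> str:
    """Hypothetical helper: build one of the three input formats as plain text.

    Assumes `<SEP>` and `<context>` are joined with single spaces,
    as in the examples on this model card.
    """
    if context is not None and answer is None:
        raise ValueError("context is only used together with an answer")
    text = question.strip()
    if answer is not None:
        # QA format: question <SEP> answer
        text += " <SEP> " + answer.strip()
    if context is not None:
        # QA + context format: question <SEP> answer <context> story text
        text += " <context> " + context.strip()
    return text


print(build_input("why could n't alice get a doll as a child ?",
                  "because her family was very poor"))
# prints: why could n't alice get a doll as a child ? <SEP> because her family was very poor
```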
---
## Quick Start
### Install Dependencies
Install the `transformers` library together with PyTorch:
```bash
pip install transformers torch
```
### Load Model and Tokenizer
```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer
model = AutoModelForSequenceClassification.from_pretrained("curious008/BertForStorySkillClassification")
tokenizer = AutoTokenizer.from_pretrained("curious008/BertForStorySkillClassification")
```
### Use the `predict` Method for Inference
```python
# Single text prediction
result = model.predict(
texts="Where does this story take place?",
tokenizer=tokenizer,
return_probabilities=True
)
print(result)
# Output: [{'text': 'Where does this story take place?', 'label': 'setting', 'score': 0.93178}]
# Batch prediction
results = model.predict(
texts=["Why is the character sad?", "How does the story end?", "why could n't alice get a doll as a child ? <SEP> because her family was very poor "],
tokenizer=tokenizer,
batch_size=16,
device="cuda"
)
print(results)
"""
output:
[{'text': 'Why is the character sad?', 'label': 'causal relationship'},
{'text': 'How does the story end?', 'label': 'action'},
{'text': "why could n't alice get a doll as a child ? <SEP> because her family was very poor ",
'label': 'causal relationship'}]
"""
```
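`predict` is a convenience method described by this card; it is not part of the standard `transformers` model API. If you work with raw outputs instead, such as the `logits` from a regular forward pass (`model(**inputs).logits`), a label and confidence score can be derived with a softmax followed by an argmax. The plain-Python sketch below assumes, without confirmation from the card, that the score reported by `predict` is the softmax probability of the top class. The label mapping would normally come from `model.config.id2label`; the `ID2LABEL` order used here is hypothetical.

```python
import math
from typing import Dict, List


def decode_logits(logits: List[float], id2label: Dict[int, str]) -> Dict[str, object]:
    """Turn one row of classification logits into a label and a softmax score."""
    # Subtract the max logit before exponentiating for numerical stability.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Pick the class with the highest probability (first one on ties).
    best = max(range(len(probs)), key=probs.__getitem__)
    return {"label": id2label[best], "score": round(probs[best], 5)}


# Hypothetical label order; read the real one from model.config.id2label.
ID2LABEL = {
    0: "character", 1: "setting", 2: "feeling", 3: "action",
    4: "causal relationship", 5: "outcome resolution", 6: "prediction",
}

print(decode_logits([0.1, 3.2, 0.0, -1.0, 0.5, 0.2, -0.3], ID2LABEL)["label"])
# prints: setting
```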
## Training Details
### Dataset
Source: [FairytaleQAData](https://github.com/uci-soe/FairytaleQAData)
### Training Parameters
- **Learning rate**: 2e-5
- **Batch size**: 32
- **Epochs**: 3
- **Optimizer**: AdamW
### Performance Metrics
- **Accuracy**: 97.3%
- **Recall**: 96.59%
- **F1 score**: 96.96%
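The card does not state how recall and F1 were averaged across the seven classes. As a reference for what these numbers measure, here is a plain-Python sketch that computes accuracy together with macro-averaged recall and F1. Macro averaging (an unweighted mean over classes) is an assumption here, not the documented evaluation protocol.

```python
from typing import Dict, List, Sequence


def classification_metrics(y_true: Sequence[str], y_pred: Sequence[str]) -> Dict[str, float]:
    """Accuracy plus macro-averaged recall and F1 (macro averaging is an assumption)."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("y_true and y_pred must be non-empty and equally long")
    labels = sorted(set(y_true) | set(y_pred))
    pairs = list(zip(y_true, y_pred))
    accuracy = sum(t == p for t, p in pairs) / len(pairs)
    recalls: List[float] = []
    f1s: List[float] = []
    for label in labels:
        # Per-class counts of true positives, false positives and false negatives.
        tp = sum(t == label and p == label for t, p in pairs)
        fp = sum(t != label and p == label for t, p in pairs)
        fn = sum(t == label and p != label for t, p in pairs)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        recalls.append(recall)
        f1s.append(f1)
    return {
        "accuracy": accuracy,
        "macro_recall": sum(recalls) / len(labels),
        "macro_f1": sum(f1s) / len(labels),
    }
```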
## Notes
1. **Input Length**: The model supports a maximum input length of 512 tokens. Longer inputs are truncated.
2. **Device Support**: The model supports both CPU and GPU inference. A GPU is recommended for faster inference.
3. **Tokenizer**: Always load the tokenizer from the same repository as the model (via `AutoTokenizer`).
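The 512-token limit matters most for the QA + context format. BERT reserves two positions for the `[CLS]` and `[SEP]` special tokens, so at most 510 tokens of text fit, and with right-side truncation (the tokenizer's default) it is the end of the input, typically the tail of the story context, that is dropped. The plain-Python sketch below illustrates this on a list of token IDs, assuming `predict` truncates the same way; the helper name `truncate_for_bert` is hypothetical, and in practice you would let the tokenizer do this by passing `truncation=True, max_length=512`.

```python
from typing import List

CLS_ID, SEP_ID = 101, 102  # [CLS] and [SEP] IDs in the bert-base-uncased vocabulary


def truncate_for_bert(content_ids: List[int], max_length: int = 512) -> List[int]:
    """Hypothetical helper: wrap content IDs in [CLS] ... [SEP], cutting from the right."""
    budget = max_length - 2  # two positions are reserved for [CLS] and [SEP]
    if budget < 0:
        raise ValueError("max_length must be at least 2")
    return [CLS_ID] + content_ids[:budget] + [SEP_ID]


# A 600-token input keeps its first 510 content tokens; the last 90 are discarded.
ids = list(range(1000, 1600))
print(len(truncate_for_bert(ids)))  # prints: 512
```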
## Citation
If you use this model, please cite the following:
```bibtex
@misc{BertForStorySkillClassification,
author = {curious},
title = {BertForStorySkillClassification: A BERT-based Model for Story Question Classification},
year = {2025},
publisher = {Hugging Face},
howpublished = {\url{https://huggingface.co/curious008/BertForStorySkillClassification}}
}
```
## License
This model is open-sourced under the Apache 2.0 License. For details, see the [Apache License 2.0 text](https://www.apache.org/licenses/LICENSE-2.0).