metadata
license: apache-2.0
datasets:
  - WorkInTheDark/FairytaleQA
language:
  - en
metrics:
  - f1
  - accuracy
  - recall
base_model:
  - google-bert/bert-base-uncased
pipeline_tag: text-classification
library_name: transformers

BertForStorySkillClassification

Model Overview

BertForStorySkillClassification is a BERT-based text classification model designed to categorize story-related questions into one of the following 7 classes:

  1. Character
  2. Setting
  3. Feeling
  4. Action
  5. Causal Relationship
  6. Outcome Resolution
  7. Prediction

This model is suitable for applications in education, literary analysis, and story comprehension.


Model Architecture

  • Base Model: bert-base-uncased
  • Classification Layer: A fully connected layer on top of BERT for 7-class classification.
  • Input: One of three text formats: a question alone (e.g., "Who is the main character in the story?"); a question-answer (QA) pair joined by <SEP> (e.g., "why could n't alice get a doll as a child ? <SEP> because her family was very poor "); or a QA pair followed by story context after a <context> marker (e.g., "why could n't alice get a doll as a child ? <SEP> because her family was very poor <context> alice is ... ").
  • Output: Predicted label, plus a confidence score when return_probabilities=True is passed to predict.
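The three input formats above differ only in how the question, answer, and context are concatenated. The sketch below builds them as plain strings. It is a minimal illustration, assuming the literal `<SEP>` and `<context>` markers and single-space separators seen in the examples; `build_input` is a hypothetical helper, not part of this repository, and the exact formatting used during training is not documented here.

```python
from typing import Optional


def build_input(question: str,
                answer: Optional[str] = None,
                context: Optional[str] = None) -> str:
    """Hypothetical helper: build one of the three documented input formats.

    Assumption: markers and spacing are copied from the model card examples.
    """
    if answer is None:
        if context is not None:
            # The card only shows context appended after a QA pair.
            raise ValueError("context is only used together with an answer")
        return question  # format 1: question only
    text = f"{question} <SEP> {answer}"  # format 2: question + answer
    if context is not None:
        text = f"{text} <context> {context}"  # format 3: QA pair + context
    return text
```

For example, `build_input("why could n't alice get a doll as a child ?", "because her family was very poor")` yields the QA-pair string from the examples (without its trailing space).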

Quick Start

Install Dependencies

Ensure you have the transformers library and PyTorch installed:

pip install transformers torch

Load Model and Tokenizer

from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained("curious008/BertForStorySkillClassification")
tokenizer = AutoTokenizer.from_pretrained("curious008/BertForStorySkillClassification")

Use the predict Method for Inference

# Single text prediction
result = model.predict(
    texts="Where does this story take place?",
    tokenizer=tokenizer,
    return_probabilities=True
)
print(result)
# Output: [{'text': 'Where does this story take place?', 'label': 'setting', 'score': 0.93178}]

# Batch prediction
results = model.predict(
    texts=["Why is the character sad?", "How does the story end?", "why could n't alice get a doll as a child ? <SEP> because her family was very poor "],
    tokenizer=tokenizer,
    batch_size=16,
    device="cuda"
)
print(results)
"""
output:
[{'text': 'Why is the character sad?', 'label': 'causal relationship'},
 {'text': 'How does the story end?', 'label': 'action'},
 {'text': "why could n't alice get a doll as a child ? <SEP> because her family was very poor ",
  'label': 'causal relationship'}]
"""

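`predict` is not a method of standard transformers models, so it is presumably supplied by this repository's own model code. If you call the model directly instead (for example `model(**tokenizer(text, return_tensors="pt")).logits`), you get raw logits that still need to be mapped to a label. The pure-Python sketch below shows one plausible way to produce records shaped like the outputs above. It assumes the score is the softmax probability of the top class, rounded to 5 decimals as in the example output; `ID2LABEL` is a hypothetical index order, and with a loaded model you should use `model.config.id2label` instead.

```python
import math
from typing import Dict, List

# Hypothetical index order; the real mapping lives in model.config.id2label.
ID2LABEL: Dict[int, str] = {
    0: "character",
    1: "setting",
    2: "feeling",
    3: "action",
    4: "causal relationship",
    5: "outcome resolution",
    6: "prediction",
}


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def to_record(text: str,
              logits: List[float],
              id2label: Dict[int, str] = ID2LABEL,
              return_probabilities: bool = False) -> dict:
    """Pick the arg-max class; optionally attach its probability as 'score'."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    record = {"text": text, "label": id2label[best]}
    if return_probabilities:
        # Assumption: rounding mirrors the 5-decimal score in the example.
        record["score"] = round(probs[best], 5)
    return record
```

Taking the arg-max of the logits and of the probabilities gives the same class; the softmax is only needed when a confidence score is requested.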
Training Details

Dataset

Source: FairytaleQA (FairytaleQAData; available on the Hugging Face Hub as WorkInTheDark/FairytaleQA)

Training Parameters

  • Learning rate: 2e-5
  • Batch size: 32
  • Epochs: 3
  • Optimizer: AdamW

Performance Metrics

Accuracy: 97.3%

Recall: 96.59%

F1 Score: 96.96%
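The card does not say how recall and F1 are averaged across the 7 classes, or on which split they were measured. As a reference for reproducing comparable numbers, the sketch below computes accuracy plus macro-averaged recall and F1 (an unweighted mean over classes). Macro averaging is an assumption made for illustration, not a statement of how the reported figures were obtained; `evaluate` is a hypothetical helper.

```python
from typing import Dict, List


def evaluate(y_true: List[str], y_pred: List[str]) -> Dict[str, float]:
    """Accuracy plus macro-averaged recall and F1 (assumed averaging scheme)."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("need two non-empty lists of equal length")
    pairs = list(zip(y_true, y_pred))
    labels = sorted(set(y_true) | set(y_pred))
    accuracy = sum(t == p for t, p in pairs) / len(pairs)
    recalls, f1s = [], []
    for label in labels:
        tp = sum(t == label and p == label for t, p in pairs)
        fn = sum(t == label and p != label for t, p in pairs)
        fp = sum(t != label and p == label for t, p in pairs)
        recall = tp / (tp + fn) if tp + fn else 0.0
        precision = tp / (tp + fp) if tp + fp else 0.0
        f1 = (2 * precision * recall / (precision + recall)
              if precision + recall else 0.0)
        recalls.append(recall)
        f1s.append(f1)
    return {
        "accuracy": accuracy,
        "macro_recall": sum(recalls) / len(labels),
        "macro_f1": sum(f1s) / len(labels),
    }
```

Under a different averaging scheme (for example, weighting each class by its support), the same predictions can give noticeably different recall and F1 values, so the scheme matters when comparing against the figures above.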

Notes

  1. Input Length: The model supports a maximum input length of 512 tokens. Longer texts will be truncated.
  2. Device Support: The model supports both CPU and GPU inference. GPU is recommended for faster performance.
  3. Tokenizer: Always load the tokenizer from the same repository as the model (curious008/BertForStorySkillClassification, via AutoTokenizer).
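The 512-token limit includes BERT's special [CLS] and [SEP] tokens, so at most 510 content tokens survive. Because the context comes last in the QA pair + context format, right-side truncation removes the end of the context first and keeps the question and answer intact. The sketch below illustrates this on token lists. It assumes the input is encoded as a single sequence and truncated from the right (the usual default for BERT tokenizers, e.g. `tokenizer(text, truncation=True, max_length=512)`); the token strings are stand-ins for real WordPiece tokens, and `truncate_single_sequence` is a hypothetical helper.

```python
from typing import List


def truncate_single_sequence(tokens: List[str], max_length: int = 512) -> List[str]:
    """Wrap tokens in [CLS] ... [SEP], keeping the total at most max_length.

    Assumption: single-sequence input, truncated from the right.
    """
    budget = max_length - 2  # two positions are reserved for [CLS] and [SEP]
    if budget < 0:
        raise ValueError("max_length must be at least 2")
    return ["[CLS]"] + tokens[:budget] + ["[SEP]"]


# Stand-in tokens: a 10-token question, a separator, a 5-token answer,
# and a 600-token context that is too long to fit.
tokens = ["q"] * 10 + ["sep"] + ["a"] * 5 + ["ctx"] * 600
kept = truncate_single_sequence(tokens)
print(len(kept), kept.count("ctx"))  # question and answer survive in full
```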

Citation

If you use this model, please cite the following:

@misc{BertForStorySkillClassification,
  author = {curious},
  title = {BertForStorySkillClassification: A BERT-based Model for Story Question Classification},
  year = {2025},
  publisher = {Hugging Face},
  howpublished = {\url{https://huggingface.co/curious008/BertForStorySkillClassification}}
}

License

This model is open-sourced under the Apache 2.0 License. For more details, see the LICENSE file.