# BertForStorySkillClassification / modeling_bert_classifier.py
# k999fff's picture
# Add new file
# be52a7a
from typing import Dict, List, Union
from transformers import BertPreTrainedModel, BertModel,PreTrainedTokenizer
import torch.nn as nn
import torch
class BertForStorySkillClassification(BertPreTrainedModel):
    """BERT encoder with a linear classification head on the first token.

    The hidden state at position 0 of each sequence is mapped by a single
    linear layer to ``config.num_labels`` logits.
    """

    def __init__(self, config):
        """Build the encoder and the classification head from ``config``.

        Args:
            config: Model configuration. ``num_labels`` and ``hidden_size``
                are read here; ``id2label`` is read later by ``predict``.
        """
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        self.classifier = nn.Linear(config.hidden_size, self.num_labels)
        # Library hook inherited from the pretrained-model base class; must
        # run after all sub-modules have been created.
        self.post_init()

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Return class logits, or the cross-entropy loss when labels are given.

        Args:
            input_ids: Token ids passed to the BERT encoder.
            attention_mask: Optional attention mask passed to the encoder.
            labels: Optional target class indices. When provided, the loss is
                returned *instead of* the logits.
            **kwargs: Accepted and ignored, so extra tokenizer outputs can be
                forwarded with ``self(**inputs)`` without raising.

        Returns:
            The loss tensor if ``labels`` is not None, otherwise the logits.
        """
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        # last_hidden_state is [batch_size, seq_len, hidden_size]; keeping only
        # position 0 yields [batch_size, hidden_size].
        cls_hidden_state = outputs.last_hidden_state[:, 0, :]
        logits = self.classifier(cls_hidden_state)  # [batch_size, num_labels]
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            return loss
        return logits

    def predict(
        self,
        texts: Union[str, List[str]],
        tokenizer: PreTrainedTokenizer,
        batch_size: int = 32,
        return_probabilities: bool = False,
        device: Union[str, torch.device, None] = None,
    ) -> List[Dict]:
        """Classify one or more input texts.

        Args:
            texts: A single text or a list of texts.
            tokenizer: Tokenizer instance compatible with this model.
            batch_size: Number of texts tokenized and scored per forward pass.
            return_probabilities: When True, each result also carries the
                softmax probability of the predicted class under ``"score"``.
            device: Device the tokenized inputs are moved to (e.g. ``"cuda"``
                or ``"cpu"``). When None (the default), the device the model
                currently lives on is used. The model itself is never moved.

        Returns:
            One dict per input text, in input order:
            ``{"text": <input text>, "label": <predicted label>}``, plus
            ``"score": <probability>`` when ``return_probabilities`` is True.
        """
        # Follow the model's own device unless the caller chose one; a fixed
        # 'cpu' default would break inference for a model living on a GPU.
        if device is None:
            device = self.device

        # Normalise the input to a list so slicing always yields a batch of
        # single texts for the tokenizer.
        if isinstance(texts, str):
            texts = [texts]
        else:
            texts = list(texts)

        predictions = []

        # Run in eval mode so dropout cannot perturb the predictions, and put
        # the previous mode back afterwards so an ongoing training run is not
        # silently switched to eval.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for start in range(0, len(texts), batch_size):
                    batch_texts = texts[start : start + batch_size]

                    # Tokenize the batch and move the tensors to the device.
                    inputs = tokenizer(
                        batch_texts,
                        padding=True,
                        truncation=True,
                        return_tensors="pt",
                        max_length=512,  # truncate to BERT's maximum length
                    ).to(device)

                    # No labels are passed, so forward() returns logits.
                    logits = self(**inputs)
                    probs = torch.softmax(logits, dim=-1)
                    scores, class_ids = torch.max(probs, dim=-1)

                    # Map class ids to label names and collect the results.
                    for text, class_id, score in zip(batch_texts, class_ids, scores):
                        label = self.config.id2label[class_id.item()]
                        result = {"text": text, "label": label}
                        if return_probabilities:
                            result["score"] = score.item()
                        predictions.append(result)
        finally:
            if was_training:
                self.train()

        return predictions