|
from typing import Dict, List, Union |
|
from transformers import BertPreTrainedModel, BertModel,PreTrainedTokenizer |
|
import torch.nn as nn |
|
import torch |
|
class BertForStorySkillClassification(BertPreTrainedModel):
    """BERT encoder with a linear classification head on the first token.

    The head maps the final hidden state of the first token of each sequence
    to ``config.num_labels`` logits. ``predict`` wraps tokenization, batching
    and label decoding (via ``config.id2label``) for inference.
    """

    def __init__(self, config):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        self.classifier = nn.Linear(config.hidden_size, self.num_labels)

        # transformers convention: initialise weights / final setup.
        self.post_init()

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Run the encoder and the classification head.

        Args:
            input_ids: Token ids, passed straight to ``BertModel``.
            attention_mask: Optional attention mask for ``BertModel``.
            labels: Optional target class ids. When given, the cross-entropy
                loss is returned instead of the logits.
            **kwargs: Extra tokenizer outputs. ``token_type_ids`` is forwarded
                to ``BertModel`` (needed for sentence-pair inputs); any other
                keys are ignored.

        Returns:
            The cross-entropy loss if ``labels`` is provided, otherwise logits
            of shape ``(batch, num_labels)``.
        """
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=kwargs.get("token_type_ids"),
        )

        # Classify from the hidden state of the first token
        # (the [CLS] token for BERT-style tokenizers).
        cls_hidden_state = outputs.last_hidden_state[:, 0, :]
        logits = self.classifier(cls_hidden_state)

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            return loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return logits

    def predict(
        self,
        texts: Union[str, List[str]],
        tokenizer: PreTrainedTokenizer,
        batch_size: int = 32,
        return_probabilities: bool = False,
        device: Union[str, torch.device, None] = None,
        max_length: int = 512,
    ) -> List[Dict]:
        """Classify one text or a list of texts.

        Args:
            texts: A single text or a list of texts, e.g.
                "Who is the character in the story?" or ["question 1", "question 2"].
            tokenizer: Tokenizer instance compatible with this model.
            batch_size: Number of texts tokenized and run per forward pass.
                Must be >= 1.
            return_probabilities: If True, each result also carries a
                ``"score"`` entry with the softmax probability of the
                predicted label. By default only the label is returned.
            device: Device to place the tokenized inputs on (e.g. "cuda" or
                "cpu"). Defaults to the model's current device.
            max_length: Maximum sequence length; longer inputs are truncated.

        Returns:
            A list of dicts, one per input text, in input order:
            ``[{"text": ..., "label": ...}, ...]``, with an extra
            ``"score": float`` key when ``return_probabilities`` is True.

        Raises:
            ValueError: If ``batch_size`` is less than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        # Inputs must live on the same device as the model's weights.
        if device is None:
            device = self.device

        if isinstance(texts, str):
            texts = [texts]

        predictions = []

        # Disable dropout so predictions are deterministic, then restore
        # whatever mode the caller had the model in.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for start in range(0, len(texts), batch_size):
                    batch_texts = texts[start : start + batch_size]

                    inputs = tokenizer(
                        batch_texts,
                        padding=True,
                        truncation=True,
                        return_tensors="pt",
                        max_length=max_length,
                    ).to(device)

                    logits = self(**inputs)
                    probs = torch.softmax(logits, dim=-1)
                    scores, class_ids = torch.max(probs, dim=-1)

                    for text, class_id, score in zip(batch_texts, class_ids, scores):
                        label = self.config.id2label[class_id.item()]
                        result = {"text": text, "label": label}
                        if return_probabilities:
                            result["score"] = score.item()
                        predictions.append(result)
        finally:
            self.train(was_training)

        return predictions