File size: 922 Bytes
adef9f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModel
from .model_config import CustomConfig

class LogRegClassifier(nn.Module):
    """Logistic-regression head: a single linear unit followed by a sigmoid.

    Args:
        transformer_output_dim: Size of the last dimension of the input
            features fed to ``forward``.
    """

    def __init__(self, transformer_output_dim: int):
        super().__init__()
        # The attribute name ``linear`` appears in state_dict keys
        # (``linear.weight`` / ``linear.bias``); keep it stable so saved
        # checkpoints still load.
        self.linear = nn.Linear(transformer_output_dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``sigmoid(linear(x))``.

        The last dimension of ``x`` is projected to size 1, so every output
        element lies in (0, 1).
        """
        logits = self.linear(x)
        return logits.sigmoid()

class CombinedModel(PreTrainedModel):
    """Transformer encoder with a logistic-regression head on the first token.

    The encoder is created with ``AutoModel.from_pretrained`` using
    ``config.transformer_type``. The hidden state at sequence position 0 is
    passed to a ``LogRegClassifier`` that outputs one probability per example.
    """

    config_class = CustomConfig

    def __init__(self, config):
        super().__init__(config)
        # NOTE(review): this loads pretrained encoder weights on every
        # construction. When CombinedModel itself is restored from a saved
        # checkpoint, those weights are presumably overwritten right away;
        # verify whether the extra load is acceptable.
        self.transformer = AutoModel.from_pretrained(config.transformer_type)
        # Assumes config.transformer_output_dim equals the encoder's hidden
        # size (the last dim of last_hidden_state) — TODO confirm.
        self.classifier = LogRegClassifier(config.transformer_output_dim)

    def forward(self, input_ids, attention_mask=None):
        """Run the encoder and classify the first-position hidden state.

        Args:
            input_ids: Token ids forwarded unchanged to the encoder.
            attention_mask: Optional mask forwarded unchanged to the encoder.
                ``None`` (the default) lets the encoder use its own default
                (presumably attending to every position).

        Returns:
            Tensor of probabilities in (0, 1) with shape ``(batch, 1)``.
        """
        outputs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
        # last_hidden_state is indexed as (batch, seq, hidden); take position 0
        # (the CLS token for BERT-style encoders) as the pooled representation.
        pooled_output = outputs.last_hidden_state[:, 0, :]
        return self.classifier(pooled_output)