File size: 3,365 Bytes
8918ac7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""

Pooling/classification head used when fine-tuning a pretrained language model (PLM) with LoRA-family methods.

"""
import torch
import gc
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
from .pooling import Attention1dPoolingHead, MeanPoolingHead, LightAttentionPoolingHead
from .pooling import MeanPooling, MeanPoolingProjection


class LoraModel(nn.Module):
    """
    Task head that sits on top of a pretrained language model (PLM).

    The PLM is not owned by this module: it is passed in on every call to
    ``forward`` / ``plm_embedding``. This module only owns the pooling /
    classification layers chosen by ``args.pooling_method``.
    """

    # Training methods for which the PLM forward pass runs with autograd
    # left as-is while this module is in training mode; for any other
    # method (or in eval mode) the PLM call is wrapped in torch.no_grad().
    _GRAD_METHODS = ('plm-lora', 'plm-qlora', 'plm-dora', 'plm-adalora', 'plm-ia3')

    def __init__(self, args) -> None:
        """
        Build the head selected by ``args.pooling_method``.

        Reads ``args.pooling_method``, ``args.hidden_size``,
        ``args.num_labels``, ``args.pooling_dropout`` and, for mean pooling,
        ``args.dataset``.

        Raises:
            ValueError: if ``args.pooling_method`` is not one of
                "attention1d", "mean" or "light_attention".
        """
        super().__init__()
        self.args = args
        method = args.pooling_method

        if method not in ("attention1d", "mean", "light_attention"):
            raise ValueError(f"classifier method {args.pooling_method} not supported")

        if method == "mean" and "PPI" in args.dataset:
            # PPI datasets get a separate pooling step and projection instead
            # of a single ``classifier``. Construction order is kept as-is
            # (pooling first) since it affects parameter init and state_dict order.
            self.pooling = MeanPooling()
            self.projection = MeanPoolingProjection(
                args.hidden_size, args.num_labels, args.pooling_dropout
            )
            return

        head_cls = {
            "attention1d": lambda: Attention1dPoolingHead,
            "mean": lambda: MeanPoolingHead,
            "light_attention": lambda: LightAttentionPoolingHead,
        }[method]()
        self.classifier = head_cls(
            args.hidden_size, args.num_labels, args.pooling_dropout
        )

    def plm_embedding(self, plm_model, aa_seq, attention_mask, stru_token=None):
        """
        Run ``plm_model`` and return its ``last_hidden_state``.

        Autograd is left enabled only when this module is in training mode
        and ``args.training_method`` is one of ``_GRAD_METHODS``; otherwise
        the call runs under ``torch.no_grad()``.

        Args:
            plm_model: callable PLM accepting ``input_ids`` / ``attention_mask``
                keyword arguments (plus ``ss_input_ids`` for ProSST models).
            aa_seq: token ids passed as ``input_ids``.
            attention_mask: passed through as ``attention_mask``.
            stru_token: structure tokens, forwarded as ``ss_input_ids`` only
                when ``args.plm_model`` contains "ProSST".

        Returns:
            ``outputs.last_hidden_state`` from the PLM (presumably
            (batch, seq_len, hidden) — TODO confirm against the PLM).
        """
        plm_name = self.args.plm_model
        keep_grad = self.training and self.args.training_method in self._GRAD_METHODS

        plm_kwargs = {"input_ids": aa_seq, "attention_mask": attention_mask}
        if "ProSST" in plm_name:
            plm_kwargs["ss_input_ids"] = stru_token
            plm_kwargs["output_hidden_states"] = True
        elif keep_grad and "Prime" in plm_name:
            # Only the grad path requests hidden states for Prime models.
            plm_kwargs["output_hidden_states"] = True

        if keep_grad:
            result = plm_model(**plm_kwargs)
        else:
            with torch.no_grad():
                result = plm_model(**plm_kwargs)

        hidden = result.last_hidden_state
        # NOTE(review): a full GC pass plus releasing cached CUDA blocks on
        # every call adds per-batch overhead; kept to preserve the existing
        # memory behaviour.
        gc.collect()
        torch.cuda.empty_cache()
        return hidden

    def forward(self, plm_model, batch):
        """
        Embed a batch with ``plm_model`` and apply ``self.classifier``.

        Reads ``batch["aa_seq_input_ids"]`` and ``batch["aa_seq_attention_mask"]``,
        plus ``batch["aa_seq_stru_tokens"]`` for ProSST models.

        NOTE(review): when ``__init__`` took the PPI mean-pooling branch no
        ``classifier`` attribute exists, so this method would raise
        AttributeError; presumably PPI batches are handled by a caller using
        ``pooling`` / ``projection`` — TODO confirm.
        """
        tokens = batch["aa_seq_input_ids"]
        mask = batch["aa_seq_attention_mask"]
        structure = (
            batch["aa_seq_stru_tokens"] if "ProSST" in self.args.plm_model else None
        )
        embeddings = self.plm_embedding(plm_model, tokens, mask, structure)
        return self.classifier(embeddings, mask)