"""
use LoRA finetuning model
"""
import gc
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from .pooling import (
    Attention1dPoolingHead,
    LightAttentionPoolingHead,
    MeanPooling,
    MeanPoolingHead,
    MeanPoolingProjection,
)


class LoraModel(nn.Module):
    """
    Classification model that places a pooling head on top of a pretrained
    (and optionally LoRA-adapted) protein language model encoder.
    """

    def __init__(self, args) -> None:
        super().__init__()
        self.args = args
        # Choose the classification head according to the configured pooling method.
        if args.pooling_method == "attention1d":
            self.classifier = Attention1dPoolingHead(
                args.hidden_size, args.num_labels, args.pooling_dropout
            )
        elif args.pooling_method == "mean":
            if "PPI" in args.dataset:
                # PPI-style datasets keep pooling and projection as separate modules.
                self.pooling = MeanPooling()
                self.projection = MeanPoolingProjection(
                    args.hidden_size, args.num_labels, args.pooling_dropout
                )
            else:
                self.classifier = MeanPoolingHead(
                    args.hidden_size, args.num_labels, args.pooling_dropout
                )
        elif args.pooling_method == "light_attention":
            self.classifier = LightAttentionPoolingHead(
                args.hidden_size, args.num_labels, args.pooling_dropout
            )
        else:
            raise ValueError(f"pooling method {args.pooling_method} not supported")

    def plm_embedding(self, plm_model, aa_seq, attention_mask, stru_token=None):
        """Return per-residue embeddings from the PLM; gradients flow only when a
        parameter-efficient training method (LoRA/QLoRA/DoRA/AdaLoRA/IA3) is active."""
        if (
            self.training
            and hasattr(self, "args")
            and self.args.training_method in ["plm-lora", "plm-qlora", "plm-dora", "plm-adalora", "plm-ia3"]
        ):
            if "ProSST" in self.args.plm_model:
                # ProSST additionally consumes structure tokens.
                outputs = plm_model(
                    input_ids=aa_seq,
                    attention_mask=attention_mask,
                    ss_input_ids=stru_token,
                    output_hidden_states=True,
                )
            elif "Prime" in self.args.plm_model:
                outputs = plm_model(
                    input_ids=aa_seq, attention_mask=attention_mask, output_hidden_states=True
                )
            else:
                outputs = plm_model(input_ids=aa_seq, attention_mask=attention_mask)
        else:
            # Frozen backbone: no gradients are needed for the PLM forward pass.
            with torch.no_grad():
                if "ProSST" in self.args.plm_model:
                    outputs = plm_model(
                        input_ids=aa_seq,
                        attention_mask=attention_mask,
                        ss_input_ids=stru_token,
                        output_hidden_states=True,
                    )
                else:
                    outputs = plm_model(input_ids=aa_seq, attention_mask=attention_mask)
        seq_embeds = outputs.last_hidden_state
        # Release cached memory from the (potentially large) PLM forward pass.
        gc.collect()
        torch.cuda.empty_cache()
        return seq_embeds

    def forward(self, plm_model, batch):
        if "ProSST" in self.args.plm_model:
            # ProSST batches carry structure tokens alongside the amino-acid sequence.
            aa_seq, attention_mask, stru_token = (
                batch["aa_seq_input_ids"],
                batch["aa_seq_attention_mask"],
                batch["aa_seq_stru_tokens"],
            )
            seq_embeds = self.plm_embedding(plm_model, aa_seq, attention_mask, stru_token)
        else:
            aa_seq, attention_mask = (
                batch["aa_seq_input_ids"],
                batch["aa_seq_attention_mask"],
            )
            seq_embeds = self.plm_embedding(plm_model, aa_seq, attention_mask)
        logits = self.classifier(seq_embeds, attention_mask)
        return logits
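

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). It
# assumes `args` exposes the fields read above (plm_model, pooling_method,
# pooling_dropout, hidden_size, num_labels, dataset, training_method) and that
# the PLM is a Hugging Face ESM-2 encoder wrapped with a PEFT LoRA adapter;
# adapt the names to your own configuration. Because this file uses a relative
# import, run it as a module (python -m <package>.<this_module>).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from argparse import Namespace

    from peft import LoraConfig, get_peft_model
    from transformers import AutoModel, AutoTokenizer

    args = Namespace(
        plm_model="facebook/esm2_t12_35M_UR50D",  # hidden size 480
        pooling_method="mean",
        pooling_dropout=0.1,
        hidden_size=480,
        num_labels=2,
        dataset="example_classification",  # any non-PPI dataset name
        training_method="plm-lora",
    )

    tokenizer = AutoTokenizer.from_pretrained(args.plm_model)
    plm_model = AutoModel.from_pretrained(args.plm_model)
    # Inject LoRA adapters into the attention projections of the frozen backbone.
    plm_model = get_peft_model(
        plm_model, LoraConfig(r=8, lora_alpha=16, target_modules=["query", "value"])
    )

    model = LoraModel(args)
    enc = tokenizer(["MKTAYIAKQRQISFVK", "MGSSHHHHHHSSGLVPR"], return_tensors="pt", padding=True)
    batch = {
        "aa_seq_input_ids": enc["input_ids"],
        "aa_seq_attention_mask": enc["attention_mask"],
    }
    logits = model(plm_model, batch)
    print(logits.shape)  # -> (batch_size, num_labels)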