""" | |
use LoRA finetuning model | |
""" | |
import gc
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from .pooling import Attention1dPoolingHead, MeanPoolingHead, LightAttentionPoolingHead
from .pooling import MeanPooling, MeanPoolingProjection


class LoraModel(nn.Module):
    """PLM encoder fine-tuned with parameter-efficient adapters (LoRA and variants) plus a pooling classification head."""
    def __init__(self, args) -> None:
        super().__init__()
        self.args = args
        if args.pooling_method == "attention1d":
            self.classifier = Attention1dPoolingHead(
                args.hidden_size, args.num_labels, args.pooling_dropout
            )
        elif args.pooling_method == "mean":
            if "PPI" in args.dataset:
                # PPI datasets keep pooling and projection as separate modules.
                self.pooling = MeanPooling()
                self.projection = MeanPoolingProjection(
                    args.hidden_size, args.num_labels, args.pooling_dropout
                )
            else:
                self.classifier = MeanPoolingHead(
                    args.hidden_size, args.num_labels, args.pooling_dropout
                )
        elif args.pooling_method == "light_attention":
            self.classifier = LightAttentionPoolingHead(
                args.hidden_size, args.num_labels, args.pooling_dropout
            )
        else:
            raise ValueError(f"Pooling method {args.pooling_method} is not supported")

    def plm_embedding(self, plm_model, aa_seq, attention_mask, stru_token=None):
        # Gradients flow through the PLM only when an adapter-based training
        # method (LoRA, QLoRA, DoRA, AdaLoRA, IA3) is active; otherwise the
        # encoder is frozen and run under torch.no_grad().
        if (
            self.training
            and hasattr(self, "args")
            and self.args.training_method in ['plm-lora', 'plm-qlora', 'plm-dora', 'plm-adalora', 'plm-ia3']
        ):
            if "ProSST" in self.args.plm_model:
                outputs = plm_model(input_ids=aa_seq, attention_mask=attention_mask, ss_input_ids=stru_token, output_hidden_states=True)
            elif "Prime" in self.args.plm_model:
                outputs = plm_model(input_ids=aa_seq, attention_mask=attention_mask, output_hidden_states=True)
            else:
                outputs = plm_model(input_ids=aa_seq, attention_mask=attention_mask)
        else:
            with torch.no_grad():
                if "ProSST" in self.args.plm_model:
                    outputs = plm_model(input_ids=aa_seq, attention_mask=attention_mask, ss_input_ids=stru_token, output_hidden_states=True)
                else:
                    outputs = plm_model(input_ids=aa_seq, attention_mask=attention_mask)
        seq_embeds = outputs.last_hidden_state
        # Release cached memory so long sequences do not accumulate fragments.
        gc.collect()
        torch.cuda.empty_cache()
        return seq_embeds

    def forward(self, plm_model, batch):
        if "ProSST" in self.args.plm_model:
            # ProSST additionally consumes structure tokens alongside the amino-acid sequence.
            aa_seq, attention_mask, stru_token = (
                batch["aa_seq_input_ids"],
                batch["aa_seq_attention_mask"],
                batch["aa_seq_stru_tokens"],
            )
            seq_embeds = self.plm_embedding(plm_model, aa_seq, attention_mask, stru_token)
        else:
            aa_seq, attention_mask = (
                batch["aa_seq_input_ids"],
                batch["aa_seq_attention_mask"],
            )
            seq_embeds = self.plm_embedding(plm_model, aa_seq, attention_mask)
        logits = self.classifier(seq_embeds, attention_mask)
        return logits
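

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the training pipeline).
# Assumptions: the `.pooling` module above is importable, and in a real run the
# encoder would be a Hugging Face model wrapped with PEFT LoRA adapters
# (e.g. via peft.get_peft_model). A dummy encoder stands in here so the sketch
# runs without downloading weights; all `args` values below are hypothetical.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        plm_model="facebook/esm2_t12_35M_UR50D",  # hypothetical checkpoint name
        pooling_method="mean",
        pooling_dropout=0.1,
        hidden_size=480,
        num_labels=2,
        dataset="DeepLoc",          # any non-PPI dataset name
        training_method="plm-lora",
    )

    class _DummyPLM(nn.Module):
        """Stand-in encoder returning random embeddings shaped like a real PLM output."""

        def forward(self, input_ids, attention_mask, **kwargs):
            hidden = torch.randn(input_ids.shape[0], input_ids.shape[1], args.hidden_size)
            return SimpleNamespace(last_hidden_state=hidden)

    model = LoraModel(args)
    batch = {
        "aa_seq_input_ids": torch.randint(0, 20, (2, 16)),
        "aa_seq_attention_mask": torch.ones(2, 16, dtype=torch.long),
    }
    logits = model(_DummyPLM(), batch)
    print(logits.shape)  # expected: (2, num_labels), given the pooling head's convention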