import torch
from typing import List

from peft import LoraConfig, PeftModel, get_peft_model
from transformers import (
    AutoModelForMaskedLM,
    BertModel,
    EsmModel,
    PreTrainedModel,
    T5EncoderModel,
)


def prepare_for_lora_model(
    base_model,
    lora_r: int = 8,
    lora_alpha: int = 32,
    lora_dropout: float = 0.1,
    target_modules: List[str] = ["key", "query", "value"],
):
    """Wrap a Hugging Face model with LoRA adapters and return the PEFT model."""
    if not isinstance(base_model, PreTrainedModel):
        raise TypeError("base_model must be a PreTrainedModel instance")

    # Validate that every requested target module exists in the model.
    available_modules = [name for name, _ in base_model.named_modules()]
    for module in target_modules:
        if not any(module in name for name in available_modules):
            raise ValueError(f"Target module {module} not found in model")

    # Build the LoRA configuration.
    lora_config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        target_modules=target_modules,
    )

    # Inject the LoRA adapters into the base model.
    model = get_peft_model(base_model, lora_config)
    print("LoRA model is ready! Trainable parameters:")
    model.print_trainable_parameters()
    return model
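
# Usage sketch (illustrative, not part of the pipeline itself): the checkpoint
# name below is an example ESM-2 model whose attention projections are named
# "key"/"query"/"value", matching the default target_modules.
#
#   base = EsmModel.from_pretrained("facebook/esm2_t12_35M_UR50D")
#   lora_model = prepare_for_lora_model(base, lora_r=8, lora_alpha=32)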


def load_lora_model(base_model, lora_ckpt_path):
    """Attach previously trained LoRA adapter weights to a base model."""
    model = PeftModel.from_pretrained(base_model, lora_ckpt_path)
    return model


def load_eval_base_model(plm_model):
    """Load a protein language model backbone for evaluation, selected by checkpoint name."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if "esm" in plm_model:
        base_model = EsmModel.from_pretrained(plm_model).to(device)
    elif "bert" in plm_model:
        base_model = BertModel.from_pretrained(plm_model).to(device)
    elif "prot_t5" in plm_model or "ankh" in plm_model:
        base_model = T5EncoderModel.from_pretrained(plm_model).to(device)
    elif "ProSST" in plm_model:
        base_model = AutoModelForMaskedLM.from_pretrained(plm_model).to(device)
    else:
        raise ValueError(f"Unsupported PLM checkpoint: {plm_model}")

    return base_model
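
# Evaluation-path sketch (assumed names): "facebook/esm2_t12_35M_UR50D" is an
# example checkpoint and "./lora_ckpt" is a hypothetical adapter directory saved
# earlier with model.save_pretrained(...).
#
#   backbone = load_eval_base_model("facebook/esm2_t12_35M_UR50D")
#   lora_model = load_lora_model(backbone, "./lora_ckpt")
#   lora_model.eval()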


def check_lora_params(model):
    """Print a quick summary of the LoRA parameters attached to a model."""
    lora_params = [
        (name, param) for name, param in model.named_parameters() if "lora_" in name
    ]
    print(f"\nNumber of LoRA parameters: {len(lora_params)}")

    if len(lora_params) == 0:
        print("Warning: no LoRA parameters found!")
    else:
        print("\nFirst LoRA parameter:")
        name, param = lora_params[0]
        print(f"name: {name}")
        print(f"param.shape: {param.shape}")
        print(f"param.dtype: {param.dtype}")
        print(f"param.device: {param.device}")
        # print(f"param_value:\n{param.data.cpu().numpy()}")
        print(f"requires_grad: {param.requires_grad}")