# VenusFactory: src/utils/lora_tools.py
import torch
from typing import List, Optional

from transformers import (
    AutoModelForMaskedLM,
    BertModel,
    EsmModel,
    PreTrainedModel,
    T5EncoderModel,
)
from peft import LoraConfig, PeftModel, get_peft_model
def prepare_for_lora_model(
    based_model,
    lora_r: int = 8,
    lora_alpha: int = 32,
    lora_dropout: float = 0.1,
    target_modules: Optional[List[str]] = None,
):
    """Wrap a pretrained model with LoRA adapters on the given target modules."""
    # Avoid a mutable default argument: fall back to the attention projections.
    if target_modules is None:
        target_modules = ["key", "query", "value"]
    if not isinstance(based_model, PreTrainedModel):
        raise TypeError("based_model must be a PreTrainedModel instance")
# validate target_modules exist in model
available_modules = [name for name, _ in based_model.named_modules()]
for module in target_modules:
if not any(module in name for name in available_modules):
raise ValueError(f"Target module {module} not found in model")
# get lora config
lora_config = LoraConfig(
r=lora_r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
target_modules=target_modules,
)
# get lora model
model = get_peft_model(based_model, lora_config)
print("Lora model is ready! num of trainable_parameters: ")
model.print_trainable_parameters()
return model
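
# Usage sketch (illustrative, not part of the original pipeline): wrap a small
# ESM-2 encoder with LoRA adapters. The checkpoint name below is an assumed
# example; any model whose attention layers expose "key"/"query"/"value"
# submodules works the same way.
def _demo_prepare_for_lora_model():
    base = EsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D")
    return prepare_for_lora_model(base, lora_r=8, lora_alpha=32, lora_dropout=0.1)
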
def load_lora_model(base_model, lora_ckpt_path):
    """Attach trained LoRA adapter weights from a checkpoint directory to the base model."""
    model = PeftModel.from_pretrained(base_model, lora_ckpt_path)
return model
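
# Usage sketch (illustrative): reload a trained adapter for inference.
# "ckpt/esm2_lora" is a hypothetical directory produced by calling
# save_pretrained() on a PEFT-wrapped model; substitute your own path.
def _demo_load_lora_model():
    base = EsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D")
    return load_lora_model(base, "ckpt/esm2_lora")
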
def load_eval_base_model(plm_model):
    """Load the matching base encoder for evaluation, keyed on the checkpoint name."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if "esm" in plm_model:
        base_model = EsmModel.from_pretrained(plm_model).to(device)
    elif "bert" in plm_model:
        base_model = BertModel.from_pretrained(plm_model).to(device)
    elif "prot_t5" in plm_model or "ankh" in plm_model:
        base_model = T5EncoderModel.from_pretrained(plm_model).to(device)
    elif "ProSST" in plm_model:
        base_model = AutoModelForMaskedLM.from_pretrained(plm_model).to(device)
    else:
        # Previously fell through with base_model unbound, raising UnboundLocalError.
        raise ValueError(f"Unrecognized PLM checkpoint name: {plm_model}")
    return base_model
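
# Usage sketch (illustrative): the dispatch above keys on substrings of the
# checkpoint name, e.g. "facebook/esm2_t6_8M_UR50D" hits the "esm" branch and
# "Rostlab/prot_t5_xl_uniref50" the "prot_t5" branch (both names are assumed
# examples, not requirements).
def _demo_load_eval_base_model():
    model = load_eval_base_model("facebook/esm2_t6_8M_UR50D")
    model.eval()
    return model
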
def check_lora_params(model):
    """Report how many LoRA parameters the model contains and inspect the first one."""
    lora_params = [
        (name, param) for name, param in model.named_parameters() if "lora_" in name
    ]
    print(f"\nNumber of LoRA params: {len(lora_params)}")
    if len(lora_params) == 0:
        print("Warning: no LoRA params found!")
    else:
        print("\nFirst LoRA param:")
name, param = lora_params[0]
print(f"name: {name}")
print(f"param.shape: {param.shape}")
print(f"param.dtype: {param.dtype}")
print(f"param.device: {param.device}")
# print(f"param_value:\n{param.data.cpu().numpy()}")
print(f"requires_grad: {param.requires_grad}")