import torch
from peft import LoraConfig, PeftModel, get_peft_model
from transformers import (
    AutoModelForMaskedLM,
    BertModel,
    EsmModel,
    PreTrainedModel,
    T5EncoderModel,
)
from typing import List
def prepare_for_lora_model(
    based_model,
    lora_r: int = 8,
    lora_alpha: int = 32,
    lora_dropout: float = 0.1,
    target_modules: List[str] = ["key", "query", "value"],
):
    if not isinstance(based_model, PreTrainedModel):
        raise TypeError("based_model must be a PreTrainedModel instance")
    # Validate that every requested target module exists in the model.
    available_modules = [name for name, _ in based_model.named_modules()]
    for module in target_modules:
        if not any(module in name for name in available_modules):
            raise ValueError(f"Target module {module} not found in model")
    # Build the LoRA config.
    lora_config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        target_modules=target_modules,
    )
    # Wrap the base model with LoRA adapters.
    model = get_peft_model(based_model, lora_config)
    print("LoRA model is ready! Trainable parameters:")
    model.print_trainable_parameters()
    return model
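# Example usage (a minimal sketch; "facebook/esm2_t6_8M_UR50D" is an assumed
# small public ESM-2 checkpoint, not one referenced in this file):
#   base = EsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D")
#   lora_model = prepare_for_lora_model(base, lora_r=8, lora_alpha=32)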
def load_lora_model(base_model, lora_ckpt_path):
    # Attach trained LoRA adapter weights from a checkpoint to the base model.
    model = PeftModel.from_pretrained(base_model, lora_ckpt_path)
    return model
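# Example (a sketch; the adapter directory below is hypothetical and must
# point at a checkpoint produced by a prior LoRA training run):
#   base = EsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D")
#   model = load_lora_model(base, "checkpoints/my_lora_adapter")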
def load_eval_base_model(plm_model):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Dispatch on substrings of the checkpoint name to pick the model class.
    if "esm" in plm_model:
        base_model = EsmModel.from_pretrained(plm_model).to(device)
    elif "bert" in plm_model:
        base_model = BertModel.from_pretrained(plm_model).to(device)
    elif "prot_t5" in plm_model or "ankh" in plm_model:
        base_model = T5EncoderModel.from_pretrained(plm_model).to(device)
    elif "ProSST" in plm_model:
        base_model = AutoModelForMaskedLM.from_pretrained(plm_model).to(device)
    else:
        raise ValueError(f"Unsupported PLM checkpoint: {plm_model}")
    return base_model
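# Example (a sketch; assumes the public checkpoint name below, whose "esm"
# substring routes it to the EsmModel branch):
#   model = load_eval_base_model("facebook/esm2_t6_8M_UR50D")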
def check_lora_params(model):
    # Collect every parameter injected by LoRA (their names contain "lora_").
    lora_params = [
        (name, param) for name, param in model.named_parameters() if "lora_" in name
    ]
    print(f"\nNumber of LoRA params: {len(lora_params)}")
    if len(lora_params) == 0:
        print("Warning: no LoRA params found!")
    else:
        print("\nFirst LoRA param:")
        name, param = lora_params[0]
        print(f"name: {name}")
        print(f"param.shape: {param.shape}")
        print(f"param.dtype: {param.dtype}")
        print(f"param.device: {param.device}")
        # print(f"param_value:\n{param.data.cpu().numpy()}")
        print(f"requires_grad: {param.requires_grad}")