import torch
from transformers import (
    EsmTokenizer, EsmModel,
    BertTokenizer, BertModel,
    T5Tokenizer, T5EncoderModel,
    AutoTokenizer, PreTrainedModel,
    AutoModelForMaskedLM, AutoModel
)
from peft import prepare_model_for_kbit_training

from .adapter_model import AdapterModel
from .lora_model import LoraModel


def create_models(args):
    """Create and initialize models and tokenizer."""
    # Create tokenizer and PLM
    tokenizer, plm_model = create_plm_and_tokenizer(args)

    # Update hidden size based on PLM
    args.hidden_size = get_hidden_size(plm_model, args.plm_model)

    # Handle structure sequence vocabulary
    if args.training_method == 'ses-adapter':
        args.vocab_size = get_vocab_size(plm_model, args.structure_seq)

    # Create adapter model
    model = AdapterModel(args)

    # Handle PLM parameters based on training method
    if args.training_method != 'full':
        freeze_plm_parameters(plm_model)
    if args.training_method == 'plm-lora':
        plm_model = setup_lora_plm(plm_model, args)
    elif args.training_method == 'plm-qlora':
        # QLoRA normally also requires loading the PLM with a quantization config;
        # see create_qlora_model / lora_factory for that path.
        plm_model = prepare_model_for_kbit_training(plm_model)
        plm_model = setup_lora_plm(plm_model, args)
    elif args.training_method == 'plm-adalora':
        plm_model = setup_adalora_plm(plm_model, args)
    elif args.training_method == "plm-dora":
        plm_model = setup_dora_plm(plm_model, args)
    elif args.training_method == "plm-ia3":
        plm_model = setup_ia3_plm(plm_model, args)

    # Move models to device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    plm_model = plm_model.to(device)
    return model, plm_model, tokenizer
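
# Usage sketch for create_models (assumption: the Namespace fields below mirror the
# command-line options consumed elsewhere in this repository; the concrete values are
# illustrative only, and AdapterModel-specific options are omitted).
#
#   from argparse import Namespace
#   args = Namespace(
#       plm_model="facebook/esm2_t33_650M_UR50D",
#       training_method="ses-adapter",
#       structure_seq="esm3_structure_seq",
#   )
#   model, plm_model, tokenizer = create_models(args)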


def create_lora_model(args):
    """Create the LoRA head model, the LoRA-wrapped PLM, and the tokenizer."""
    tokenizer, plm_model = create_plm_and_tokenizer(args)
    # Update hidden size based on PLM
    args.hidden_size = get_hidden_size(plm_model, args.plm_model)
    model = LoraModel(args=args)
    # Enable gradient checkpointing
    plm_model.gradient_checkpointing_enable()
    plm_model = setup_lora_plm(plm_model, args)
    return model, plm_model, tokenizer


def create_qlora_model(args):
    """Create the LoRA head model and a 4-bit quantized, QLoRA-wrapped PLM."""
    qlora_config = setup_quantization_config()
    tokenizer, plm_model = create_plm_and_tokenizer(args, qlora_config=qlora_config)
    # Update hidden size based on PLM
    args.hidden_size = get_hidden_size(plm_model, args.plm_model)
    model = LoraModel(args=args)
    # Enable gradient checkpointing
    plm_model.gradient_checkpointing_enable()
    plm_model = prepare_model_for_kbit_training(plm_model)
    plm_model = setup_lora_plm(plm_model, args)
    return model, plm_model, tokenizer


def create_dora_model(args):
    """Create the LoRA head model and a DoRA-wrapped PLM."""
    tokenizer, plm_model = create_plm_and_tokenizer(args)
    # Update hidden size based on PLM
    args.hidden_size = get_hidden_size(plm_model, args.plm_model)
    model = LoraModel(args=args)
    # Enable gradient checkpointing
    plm_model.gradient_checkpointing_enable()
    plm_model = setup_dora_plm(plm_model, args)
    return model, plm_model, tokenizer


def create_adalora_model(args):
    """Create the LoRA head model and an AdaLoRA-wrapped PLM."""
    tokenizer, plm_model = create_plm_and_tokenizer(args)
    # Update hidden size based on PLM
    args.hidden_size = get_hidden_size(plm_model, args.plm_model)
    model = LoraModel(args=args)
    # Enable gradient checkpointing
    plm_model.gradient_checkpointing_enable()
    plm_model = setup_adalora_plm(plm_model, args)
    print("Using PLM AdaLoRA")
    return model, plm_model, tokenizer


def create_ia3_model(args):
    """Create the LoRA head model and an IA3-wrapped PLM."""
    tokenizer, plm_model = create_plm_and_tokenizer(args)
    args.hidden_size = get_hidden_size(plm_model, args.plm_model)
    model = LoraModel(args=args)
    plm_model.gradient_checkpointing_enable()
    plm_model = prepare_model_for_kbit_training(plm_model)
    plm_model = setup_ia3_plm(plm_model, args)
    print("Using PLM IA3")
    return model, plm_model, tokenizer


def lora_factory(args):
    """Dispatch to the appropriate LoRA-family constructor based on the training method."""
    if args.training_method == "plm-lora":
        model, plm_model, tokenizer = create_lora_model(args)
    elif args.training_method == "plm-qlora":
        model, plm_model, tokenizer = create_qlora_model(args)
    elif args.training_method == "plm-dora":
        model, plm_model, tokenizer = create_dora_model(args)
    elif args.training_method == "plm-adalora":
        model, plm_model, tokenizer = create_adalora_model(args)
    elif args.training_method == "plm-ia3":
        model, plm_model, tokenizer = create_ia3_model(args)
    else:
        raise ValueError(f"Unsupported lora training method: {args.training_method}")

    # Move models to device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    plm_model = plm_model.to(device)
    return model, plm_model, tokenizer


def freeze_plm_parameters(plm_model):
    """Freeze all parameters in the pre-trained language model."""
    for param in plm_model.parameters():
        param.requires_grad = False
    plm_model.eval()  # Set to evaluation mode


def setup_quantization_config():
    """Setup quantization configuration."""
    from transformers import BitsAndBytesConfig
    # https://huggingface.co/docs/peft/v0.14.0/en/developer_guides/quantization#quantize-a-model
    qlora_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )
    return qlora_config


def setup_lora_plm(plm_model, args):
    """Setup LoRA for pre-trained language model."""
    # Import LoRA configurations
    from peft import get_peft_model, LoraConfig, TaskType

    if not isinstance(plm_model, PreTrainedModel):
        raise TypeError("plm_model must be a PreTrainedModel instance")

    # Validate that lora_target_modules exist in the model
    available_modules = [name for name, _ in plm_model.named_modules()]
    for module in args.lora_target_modules:
        if not any(module in name for name in available_modules):
            raise ValueError(f"Target module {module} not found in model")

    # Configure LoRA
    peft_config = LoraConfig(
        task_type=TaskType.FEATURE_EXTRACTION,
        inference_mode=False,
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        target_modules=args.lora_target_modules,
    )

    # Apply LoRA to model
    plm_model = get_peft_model(plm_model, peft_config)
    plm_model.print_trainable_parameters()
    return plm_model


def setup_dora_plm(plm_model, args):
    """Setup DoRA for pre-trained language model."""
    # Import DoRA configurations
    from peft import get_peft_model, LoraConfig, TaskType

    if not isinstance(plm_model, PreTrainedModel):
        raise TypeError("plm_model must be a PreTrainedModel instance")

    # Validate that lora_target_modules exist in the model
    available_modules = [name for name, _ in plm_model.named_modules()]
    for module in args.lora_target_modules:
        if not any(module in name for name in available_modules):
            raise ValueError(f"Target module {module} not found in model")

    # Configure DoRA
    peft_config = LoraConfig(
        task_type=TaskType.FEATURE_EXTRACTION,
        inference_mode=False,
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        target_modules=args.lora_target_modules,
        use_dora=True
    )

    # Apply DoRA to model
    plm_model = get_peft_model(plm_model, peft_config)
    plm_model.print_trainable_parameters()
    return plm_model


def setup_adalora_plm(plm_model, args):
    """Setup AdaLoRA for pre-trained language model."""
    # Import AdaLoRA configurations
    from peft import get_peft_model, AdaLoraConfig, TaskType

    if not isinstance(plm_model, PreTrainedModel):
        raise TypeError("plm_model must be a PreTrainedModel instance")

    # Validate that lora_target_modules exist in the model
    available_modules = [name for name, _ in plm_model.named_modules()]
    for module in args.lora_target_modules:
        if not any(module in name for name in available_modules):
            raise ValueError(f"Target module {module} not found in model")

    # Configure AdaLoRA
    peft_config = AdaLoraConfig(
        task_type=TaskType.FEATURE_EXTRACTION,
        init_r=12,
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        target_modules=args.lora_target_modules
    )

    # Apply AdaLoRA to model
    plm_model = get_peft_model(plm_model, peft_config)
    plm_model.print_trainable_parameters()
    return plm_model


def setup_ia3_plm(plm_model, args):
    """Setup IA3 for pre-trained language model."""
    # Import IA3 configurations
    from peft import IA3Config, get_peft_model, TaskType

    if not isinstance(plm_model, PreTrainedModel):
        raise TypeError("plm_model must be a PreTrainedModel instance")

    # Validate that lora_target_modules exist in the model
    available_modules = [name for name, _ in plm_model.named_modules()]
    for module in args.lora_target_modules:
        if not any(module in name for name in available_modules):
            raise ValueError(f"Target module {module} not found in model")

    # Configure IA3
    peft_config = IA3Config(
        task_type=TaskType.FEATURE_EXTRACTION,
        target_modules=args.lora_target_modules,
        feedforward_modules=args.feedforward_modules
    )

    # Apply IA3 to model
    plm_model = get_peft_model(plm_model, peft_config)
    plm_model.print_trainable_parameters()
    return plm_model


def create_plm_and_tokenizer(args, qlora_config=None):
    """Create pre-trained language model and tokenizer based on model type."""
    if "esm" in args.plm_model:
        tokenizer = EsmTokenizer.from_pretrained(args.plm_model)
        if qlora_config:
            plm_model = EsmModel.from_pretrained(args.plm_model, quantization_config=qlora_config)
        else:
            plm_model = EsmModel.from_pretrained(args.plm_model)
    elif "bert" in args.plm_model:
        tokenizer = BertTokenizer.from_pretrained(args.plm_model, do_lower_case=False)
        if qlora_config:
            plm_model = BertModel.from_pretrained(args.plm_model, quantization_config=qlora_config)
        else:
            plm_model = BertModel.from_pretrained(args.plm_model)
    elif "prot_t5" in args.plm_model:
        tokenizer = T5Tokenizer.from_pretrained(args.plm_model, do_lower_case=False)
        if qlora_config:
            plm_model = T5EncoderModel.from_pretrained(args.plm_model, quantization_config=qlora_config)
        else:
            plm_model = T5EncoderModel.from_pretrained(args.plm_model)
    elif "ankh" in args.plm_model:
        tokenizer = AutoTokenizer.from_pretrained(args.plm_model, do_lower_case=False)
        if qlora_config:
            plm_model = T5EncoderModel.from_pretrained(args.plm_model, quantization_config=qlora_config)
        else:
            plm_model = T5EncoderModel.from_pretrained(args.plm_model)
    elif "ProSST" in args.plm_model:
        tokenizer = AutoTokenizer.from_pretrained(args.plm_model, do_lower_case=False)
        if qlora_config:
            plm_model = AutoModelForMaskedLM.from_pretrained(args.plm_model, trust_remote_code=True, quantization_config=qlora_config)
        else:
            plm_model = AutoModelForMaskedLM.from_pretrained(args.plm_model, trust_remote_code=True)
    elif "Prime" in args.plm_model:
        tokenizer = AutoTokenizer.from_pretrained(args.plm_model, trust_remote_code=True, do_lower_case=False)
        if qlora_config:
            plm_model = AutoModel.from_pretrained(args.plm_model, trust_remote_code=True, quantization_config=qlora_config)
        else:
            plm_model = AutoModel.from_pretrained(args.plm_model, trust_remote_code=True)
    elif "deep" in args.plm_model:
        tokenizer = AutoTokenizer.from_pretrained(args.plm_model, do_lower_case=False)
        if qlora_config:
            plm_model = AutoModel.from_pretrained(args.plm_model, trust_remote_code=True, quantization_config=qlora_config)
        else:
            plm_model = AutoModel.from_pretrained(args.plm_model, trust_remote_code=True)
    else:
        raise ValueError(f"Unsupported model type: {args.plm_model}")
    return tokenizer, plm_model


def get_hidden_size(plm_model, model_type):
    """Get hidden size based on model type."""
    if "prot_t5" in model_type or "ankh" in model_type:
        return plm_model.config.d_model
    elif any(key in model_type for key in ("esm", "bert", "ProSST", "Prime", "deep")):
        return plm_model.config.hidden_size
    else:
        raise ValueError(f"Unsupported model type: {model_type}")


def get_vocab_size(plm_model, structure_seq):
    """Get vocabulary size for structure sequences."""
    if 'esm3_structure_seq' in structure_seq:
        return max(plm_model.config.vocab_size, 4100)
    return plm_model.config.vocab_size
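

# Usage sketch for the LoRA factory (assumption: the LoRA fields below mirror the
# command-line options used by this repository's training scripts; the values are
# illustrative only, and the target module names must match modules that actually
# exist in the chosen PLM).
#
#   from argparse import Namespace
#   args = Namespace(
#       plm_model="facebook/esm2_t33_650M_UR50D",
#       training_method="plm-lora",
#       lora_r=8,
#       lora_alpha=16,
#       lora_dropout=0.1,
#       lora_target_modules=["query", "key", "value"],
#   )
#   model, plm_model, tokenizer = lora_factory(args)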