# VenusFactory: src/models/model_factory.py
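"""Model factory for VenusFactory.

Creates tokenizers and pre-trained protein language models (ESM, ProtBert,
ProtT5, Ankh, ProSST, Prime, ...), wraps them with the AdapterModel / LoraModel
task heads, and applies parameter-efficient fine-tuning methods such as LoRA,
QLoRA, DoRA, AdaLoRA, and IA3.
"""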
import torch
from transformers import (
EsmTokenizer, EsmModel,
BertTokenizer, BertModel,
T5Tokenizer, T5EncoderModel,
AutoTokenizer, PreTrainedModel,
AutoModelForMaskedLM, AutoModel
)
from peft import prepare_model_for_kbit_training
from .adapter_model import AdapterModel
from .lora_model import LoraModel
def create_models(args):
"""Create and initialize models and tokenizer."""
# Create tokenizer and PLM
tokenizer, plm_model = create_plm_and_tokenizer(args)
# Update hidden size based on PLM
args.hidden_size = get_hidden_size(plm_model, args.plm_model)
# Handle structure sequence vocabulary
if args.training_method == 'ses-adapter':
args.vocab_size = get_vocab_size(plm_model, args.structure_seq)
# Create adapter model
model = AdapterModel(args)
# Handle PLM parameters based on training method
if args.training_method != 'full':
freeze_plm_parameters(plm_model)
# if args.training_method == 'ses-adapter':
# plm_model=create_models(plm_model, args)
    if args.training_method == 'plm-lora':
        plm_model = setup_lora_plm(plm_model, args)
    elif args.training_method == 'plm-qlora':
        # 4-bit loading itself is handled in create_qlora_model / lora_factory;
        # here the frozen backbone is only prepared for k-bit training before LoRA is attached
        plm_model = prepare_model_for_kbit_training(plm_model)
        plm_model = setup_lora_plm(plm_model, args)
    elif args.training_method == 'plm-adalora':
        plm_model = setup_adalora_plm(plm_model, args)
    elif args.training_method == "plm-dora":
        plm_model = setup_dora_plm(plm_model, args)
    elif args.training_method == "plm-ia3":
        plm_model = setup_ia3_plm(plm_model, args)
# Move models to device
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
plm_model = plm_model.to(device)
return model, plm_model, tokenizer
def create_lora_model(args):
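    """Create a LoRA fine-tuning setup: PLM backbone with low-rank adapters, LoraModel task head, and tokenizer."""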
tokenizer, plm_model = create_plm_and_tokenizer(args)
# Update hidden size based on PLM
args.hidden_size = get_hidden_size(plm_model, args.plm_model)
model = LoraModel(args=args)
# Enable gradient checkpointing
plm_model.gradient_checkpointing_enable()
plm_model=setup_lora_plm(plm_model, args)
return model, plm_model, tokenizer
def create_qlora_model(args):
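    """Create a QLoRA setup: load the PLM with 4-bit quantization, prepare it for k-bit training, and attach LoRA adapters alongside the LoraModel task head."""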
qlora_config = setup_quantization_config()
tokenizer, plm_model = create_plm_and_tokenizer(args, qlora_config=qlora_config)
# Update hidden size based on PLM
args.hidden_size = get_hidden_size(plm_model, args.plm_model)
model = LoraModel(args=args)
# Enable gradient checkpointing
plm_model.gradient_checkpointing_enable()
plm_model = prepare_model_for_kbit_training(plm_model)
plm_model=setup_lora_plm(plm_model, args)
return model, plm_model, tokenizer
def create_dora_model(args):
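    """Create a DoRA (weight-decomposed LoRA) fine-tuning setup: PLM backbone, LoraModel task head, and tokenizer."""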
tokenizer, plm_model = create_plm_and_tokenizer(args)
# Update hidden size based on PLM
args.hidden_size = get_hidden_size(plm_model, args.plm_model)
model = LoraModel(args=args)
# Enable gradient checkpointing
plm_model.gradient_checkpointing_enable()
plm_model=setup_dora_plm(plm_model, args)
return model, plm_model, tokenizer
def create_adalora_model(args):
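    """Create an AdaLoRA fine-tuning setup: PLM backbone with adaptive-rank LoRA adapters, LoraModel task head, and tokenizer."""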
tokenizer, plm_model = create_plm_and_tokenizer(args)
# Update hidden size based on PLM
args.hidden_size = get_hidden_size(plm_model, args.plm_model)
model = LoraModel(args=args)
# Enable gradient checkpointing
plm_model.gradient_checkpointing_enable()
plm_model=setup_adalora_plm(plm_model, args)
print(" Using plm adalora ")
return model, plm_model, tokenizer
def create_ia3_model(args):
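    """Create an IA3 fine-tuning setup: PLM backbone with learned activation-scaling vectors, LoraModel task head, and tokenizer."""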
tokenizer, plm_model = create_plm_and_tokenizer(args)
args.hidden_size = get_hidden_size(plm_model, args.plm_model)
model = LoraModel(args=args)
plm_model.gradient_checkpointing_enable()
plm_model = prepare_model_for_kbit_training(plm_model)
plm_model=setup_ia3_plm(plm_model, args)
print(" Using plm IA3 ")
return model, plm_model, tokenizer
def lora_factory(args):
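    """Dispatch to the matching PEFT constructor based on args.training_method and move both models to the available device."""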
if args.training_method in "plm-lora":
model, plm_model, tokenizer = create_lora_model(args)
elif args.training_method == "plm-qlora":
model, plm_model, tokenizer = create_qlora_model(args)
elif args.training_method == "plm-dora":
model, plm_model, tokenizer = create_dora_model(args)
elif args.training_method == "plm-adalora":
model, plm_model, tokenizer = create_adalora_model(args)
elif args.training_method == "plm-ia3":
model, plm_model, tokenizer = create_ia3_model(args)
else:
raise ValueError(f"Unsupported lora training method: {args.training_method}")
# Move models to device
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
plm_model = plm_model.to(device)
return model, plm_model, tokenizer
def freeze_plm_parameters(plm_model):
"""Freeze all parameters in the pre-trained language model."""
for param in plm_model.parameters():
param.requires_grad = False
plm_model.eval() # Set to evaluation mode
def setup_quantization_config():
"""Setup quantization configuration."""
from transformers import BitsAndBytesConfig
# https://huggingface.co/docs/peft/v0.14.0/en/developer_guides/quantization#quantize-a-model
qlora_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16
)
return qlora_config
def setup_lora_plm(plm_model, args):
"""Setup LoRA for pre-trained language model."""
# Import LoRA configurations
from peft import get_peft_config, get_peft_model, LoraConfig, TaskType
if not isinstance(plm_model, PreTrainedModel):
raise TypeError("based_model must be a PreTrainedModel instance")
# validate lora_target_modules exist in model
available_modules = [name for name, _ in plm_model.named_modules()]
for module in args.lora_target_modules:
if not any(module in name for name in available_modules):
raise ValueError(f"Target module {module} not found in model")
# Configure LoRA
peft_config = LoraConfig(
task_type=TaskType.FEATURE_EXTRACTION,
inference_mode=False,
r=args.lora_r,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
target_modules=args.lora_target_modules,
)
# Apply LoRA to model
plm_model = get_peft_model(plm_model, peft_config)
plm_model.print_trainable_parameters()
return plm_model
def setup_dora_plm(plm_model, args):
"""Setup DoRA for pre-trained language model."""
# Import DoRA configurations
from peft import get_peft_config, get_peft_model, LoraConfig, TaskType
if not isinstance(plm_model, PreTrainedModel):
raise TypeError("based_model must be a PreTrainedModel instance")
    # validate lora_target_modules exist in model
available_modules = [name for name, _ in plm_model.named_modules()]
for module in args.lora_target_modules:
if not any(module in name for name in available_modules):
raise ValueError(f"Target module {module} not found in model")
# Configure DoRA
peft_config = LoraConfig(
task_type=TaskType.FEATURE_EXTRACTION,
inference_mode=False,
r=args.lora_r,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
target_modules=args.lora_target_modules,
use_dora=True
)
# Apply DoRA to model
plm_model = get_peft_model(plm_model, peft_config)
plm_model.print_trainable_parameters()
return plm_model
def setup_adalora_plm(plm_model, args):
"""Setup AdaLoRA for pre-trained language model."""
# Import AdaLoRA configurations
from peft import get_peft_config, get_peft_model, AdaLoraConfig, TaskType
if not isinstance(plm_model, PreTrainedModel):
raise TypeError("based_model must be a PreTrainedModel instance")
# validate lora_target_modules exist in model
available_modules = [name for name, _ in plm_model.named_modules()]
for module in args.lora_target_modules:
if not any(module in name for name in available_modules):
raise ValueError(f"Target module {module} not found in model")
# Configure AdaLoRA
peft_config = AdaLoraConfig(
task_type=TaskType.FEATURE_EXTRACTION,
peft_type="ADALORA",
init_r=12,
r=args.lora_r,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
target_modules=args.lora_target_modules
)
# Apply AdaLoRA to model
plm_model = get_peft_model(plm_model, peft_config)
plm_model.print_trainable_parameters()
return plm_model
def setup_ia3_plm(plm_model, args):
"""Setup IA3 for pre-trained language model."""
    # Import IA3 configurations
from peft import IA3Model, IA3Config, get_peft_model, TaskType
if not isinstance(plm_model, PreTrainedModel):
raise TypeError("based_model must be a PreTrainedModel instance")
# validate lora_target_modules exist in model
available_modules = [name for name, _ in plm_model.named_modules()]
print(available_modules)
for module in args.lora_target_modules:
if not any(module in name for name in available_modules):
raise ValueError(f"Target module {module} not found in model")
    # Configure IA3
peft_config = IA3Config(
task_type=TaskType.FEATURE_EXTRACTION,
peft_type="IA3",
target_modules=args.lora_target_modules,
feedforward_modules=args.feedforward_modules
)
    # Apply IA3 to model
plm_model = get_peft_model(plm_model, peft_config)
plm_model.print_trainable_parameters()
return plm_model
def create_plm_and_tokenizer(args, qlora_config=None):
"""Create pre-trained language model and tokenizer based on model type."""
if "esm" in args.plm_model:
tokenizer = EsmTokenizer.from_pretrained(args.plm_model)
if qlora_config:
plm_model = EsmModel.from_pretrained(args.plm_model, quantization_config=qlora_config)
else:
plm_model = EsmModel.from_pretrained(args.plm_model)
elif "bert" in args.plm_model:
tokenizer = BertTokenizer.from_pretrained(args.plm_model, do_lower_case=False)
if qlora_config:
plm_model = BertModel.from_pretrained(args.plm_model, quantization_config=qlora_config)
else:
plm_model = BertModel.from_pretrained(args.plm_model)
elif "prot_t5" in args.plm_model:
tokenizer = T5Tokenizer.from_pretrained(args.plm_model, do_lower_case=False)
if qlora_config:
plm_model = T5EncoderModel.from_pretrained(args.plm_model, quantization_config=qlora_config)
else:
plm_model = T5EncoderModel.from_pretrained(args.plm_model)
elif "ankh" in args.plm_model:
tokenizer = AutoTokenizer.from_pretrained(args.plm_model, do_lower_case=False)
if qlora_config:
plm_model = T5EncoderModel.from_pretrained(args.plm_model, quantization_config=qlora_config)
else:
plm_model = T5EncoderModel.from_pretrained(args.plm_model)
elif "ProSST" in args.plm_model:
tokenizer = AutoTokenizer.from_pretrained(args.plm_model, do_lower_case=False)
if qlora_config:
plm_model = AutoModelForMaskedLM.from_pretrained(args.plm_model, trust_remote_code=True, quantization_config=qlora_config)
else:
plm_model = AutoModelForMaskedLM.from_pretrained(args.plm_model, trust_remote_code=True)
elif "Prime" in args.plm_model:
tokenizer = AutoTokenizer.from_pretrained(args.plm_model, trust_remote_code=True, do_lower_case=False)
if qlora_config:
plm_model = AutoModel.from_pretrained(args.plm_model, trust_remote_code=True, quantization_config=qlora_config)
else:
plm_model = AutoModel.from_pretrained(args.plm_model, trust_remote_code=True)
elif "deep" in args.plm_model:
tokenizer = AutoTokenizer.from_pretrained(args.plm_model, do_lower_case=False)
if qlora_config:
plm_model = AutoModel.from_pretrained(args.plm_model, trust_remote_code=True, quantization_config=qlora_config)
else:
plm_model = AutoModel.from_pretrained(args.plm_model, trust_remote_code=True)
else:
raise ValueError(f"Unsupported model type: {args.plm_model}")
return tokenizer, plm_model
def get_hidden_size(plm_model, model_type):
"""Get hidden size based on model type."""
if "esm" in model_type:
return plm_model.config.hidden_size
elif "bert" in model_type:
return plm_model.config.hidden_size
elif "prot_t5" in model_type or "ankh" in model_type:
return plm_model.config.d_model
elif "ProSST" in model_type:
return plm_model.config.hidden_size
elif "Prime" in model_type:
return plm_model.config.hidden_size
elif "deep" in model_type:
return plm_model.config.hidden_size
else:
raise ValueError(f"Unsupported model type: {model_type}")
def get_vocab_size(plm_model, structure_seq):
"""Get vocabulary size for structure sequences."""
if 'esm3_structure_seq' in structure_seq:
return max(plm_model.config.vocab_size, 4100)
return plm_model.config.vocab_size
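# Illustrative usage sketch (kept as comments, not executed). The argument names
# below are the ones read in this module; the concrete values (checkpoint name,
# target modules, LoRA hyperparameters) are assumptions, and AdapterModel /
# LoraModel may require additional fields that are defined elsewhere in the
# training pipeline.
#
#   from argparse import Namespace
#   args = Namespace(
#       plm_model="facebook/esm2_t33_650M_UR50D",
#       training_method="plm-lora",
#       lora_r=8,
#       lora_alpha=16,
#       lora_dropout=0.1,
#       lora_target_modules=["query", "key", "value"],
#   )
#   model, plm_model, tokenizer = lora_factory(args)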