import copy |
|
import itertools |
|
import os |
|
import re |
|
import tempfile |
|
import unittest |
|
|
|
import pytest |
|
import torch |
|
from parameterized import parameterized |
|
from torch import nn |
|
from transformers import AutoModelForCausalLM |
|
|
|
from peft import ( |
|
AdaLoraConfig, |
|
LoHaConfig, |
|
LoKrConfig, |
|
LoraConfig, |
|
OFTConfig, |
|
PeftMixedModel, |
|
PrefixTuningConfig, |
|
get_peft_model, |
|
) |
|
from peft.tuners.tuners_utils import BaseTunerLayer |
|
from peft.utils import infer_device |
|
|
|
|
|
class SimpleNet(nn.Module): |
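    # Minimal two-layer MLP used as the base model for the mixed-adapter tests below.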
|
def __init__(self, bias=True): |
|
super().__init__() |
|
|
|
self.lin0 = nn.Linear(10, 20, bias=bias) |
|
self.relu = nn.ReLU() |
|
self.lin1 = nn.Linear(20, 16, bias=bias) |
|
|
|
def forward(self, X): |
|
X = X.float() |
|
X = self.lin0(X) |
|
X = self.relu(X) |
|
X = self.lin1(X) |
|
return X |
|
|
|
|
|
def _param_name_func(testcase_func, param_num, params): |
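    # Build readable parameterized test names from the pair of PEFT configs: the test name, the parameter index,
    # and the two config type names; identical types get an "_x2" suffix instead.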
|
|
|
config0, config1 = params[0] |
|
name0 = config0.__class__.__name__[: -len("Config")] |
|
name1 = config1.__class__.__name__[: -len("Config")] |
|
if name0 != name1: |
|
return f"{testcase_func.__name__}_{param_num}_{name0}_{name1}" |
|
return f"{testcase_func.__name__}_{param_num}_{name0}_x2" |
|
|
|
|
|
class TestMixedAdapterTypes(unittest.TestCase): |
|
torch_device = infer_device() |
|
|
|
def _get_model(self, model_cls, peft_config=None, adapter_name=None, seed=0, mixed=True): |
|
torch.manual_seed(0) |
|
base_model = model_cls().eval().to(self.torch_device) |
|
if peft_config is None: |
|
return base_model |
|
|
|
torch.manual_seed(seed) |
|
assert adapter_name is not None |
|
peft_model = get_peft_model(base_model, peft_config, adapter_name=adapter_name, mixed=mixed) |
|
return peft_model.eval().to(self.torch_device) |
|
|
|
def _check_mixed_outputs(self, model_cls, config0, config1, input, *, is_commutative): |
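        # Check that each adapter on its own changes the base output, that the mixed output differs from the
        # single-adapter outputs, and, if is_commutative=True, that the adapter deltas simply add up and do not
        # depend on the order in which the adapters are applied. The random seed is reset before each adapter is
        # created so that the same adapter gets identical weights across the different model instances.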
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
atol = 1e-5 |
|
rtol = 1e-5 |
|
seed0 = 0 |
|
seed1 = 1 |
|
|
|
|
|
        # base model without any adapter
        base_model = self._get_model(model_cls)
|
output_base = base_model(input) |
|
assert torch.isfinite(output_base).all() |
|
|
|
|
|
        # adapter 0 alone
        peft_model_0 = self._get_model(model_cls, config0, "adapter0", seed=seed0)
|
output_config0 = peft_model_0(input) |
|
|
|
assert torch.isfinite(output_config0).all() |
|
assert not torch.allclose(output_base, output_config0, atol=atol, rtol=rtol) |
|
|
|
|
|
        # adapter 1 alone
        peft_model_1 = self._get_model(model_cls, config1, "adapter1", seed=seed1)
|
output_config1 = peft_model_1(input) |
|
|
|
assert torch.isfinite(output_config1).all() |
|
assert not torch.allclose(output_base, output_config1, atol=atol, rtol=rtol) |
|
assert not torch.allclose(output_config0, output_config1, atol=atol, rtol=rtol) |
|
|
|
|
|
        # adapter 0 + 1
        peft_model_01 = self._get_model(model_cls, config0, "adapter0", seed=seed0)
|
torch.manual_seed(seed1) |
|
peft_model_01.add_adapter("adapter1", config1) |
|
peft_model_01.set_adapter(["adapter0", "adapter1"]) |
|
output_mixed_01 = peft_model_01(input) |
|
|
|
|
|
        # check the tuner layer types: one type if both configs are of the same type, two otherwise
        tuner_layers = [mod for mod in peft_model_01.modules() if isinstance(mod, BaseTunerLayer)]
|
tuner_types = {type(tuner_layer) for tuner_layer in tuner_layers} |
|
if type(config0) == type(config1): |
|
assert len(tuner_types) == 1 |
|
else: |
|
assert len(tuner_types) == 2 |
|
|
|
assert peft_model_01.active_adapters == ["adapter0", "adapter1"] |
|
assert torch.isfinite(output_mixed_01).all() |
|
assert not torch.allclose(output_config0, output_mixed_01, atol=atol, rtol=rtol) |
|
assert not torch.allclose(output_config1, output_mixed_01, atol=atol, rtol=rtol) |
|
if is_commutative: |
|
delta0 = output_config0 - output_base |
|
delta1 = output_config1 - output_base |
|
delta_mixed_01 = output_mixed_01 - output_base |
|
assert torch.allclose((delta0 + delta1), delta_mixed_01, atol=atol, rtol=rtol) |
|
|
|
|
|
        # adapter 1 + 0 (reversed insertion order)
        peft_model_10 = self._get_model(model_cls, config1, "adapter1", seed=seed1)
|
torch.manual_seed(seed0) |
|
peft_model_10.add_adapter("adapter0", config0) |
|
peft_model_10.set_adapter(["adapter1", "adapter0"]) |
|
output_mixed_10 = peft_model_10(input) |
|
|
|
|
|
tuner_layers = [mod for mod in peft_model_10.modules() if isinstance(mod, BaseTunerLayer)] |
|
tuner_types = {type(tuner_layer) for tuner_layer in tuner_layers} |
|
if type(config0) == type(config1): |
|
assert len(tuner_types) == 1 |
|
else: |
|
assert len(tuner_types) == 2 |
|
|
|
assert peft_model_10.active_adapters == ["adapter1", "adapter0"] |
|
assert torch.isfinite(output_mixed_10).all() |
|
assert not torch.allclose(output_config0, output_mixed_10, atol=atol, rtol=rtol) |
|
assert not torch.allclose(output_config1, output_mixed_10, atol=atol, rtol=rtol) |
|
if is_commutative: |
|
assert torch.allclose(output_mixed_01, output_mixed_10, atol=atol, rtol=rtol) |
|
|
|
|
|
        # swap the order in which the two adapters are activated on the same model
        peft_model_10.set_adapter(["adapter0", "adapter1"])
|
output_mixed_reversed = peft_model_10(input) |
|
|
|
|
|
tuner_layers = [mod for mod in peft_model_10.modules() if isinstance(mod, BaseTunerLayer)] |
|
tuner_types = {type(tuner_layer) for tuner_layer in tuner_layers} |
|
if type(config0) == type(config1): |
|
assert len(tuner_types) == 1 |
|
else: |
|
assert len(tuner_types) == 2 |
|
|
|
assert peft_model_10.active_adapters == ["adapter0", "adapter1"] |
|
assert torch.isfinite(output_mixed_reversed).all() |
|
assert not torch.allclose(output_mixed_reversed, output_config0, atol=atol, rtol=rtol) |
|
assert not torch.allclose(output_mixed_reversed, output_config1, atol=atol, rtol=rtol) |
|
if is_commutative: |
|
assert torch.allclose(output_mixed_reversed, output_mixed_01, atol=atol, rtol=rtol) |
|
assert torch.allclose(output_mixed_reversed, output_mixed_10, atol=atol, rtol=rtol) |
|
|
|
def _check_merging(self, model_cls, config0, config1, input): |
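        # Check that merge_and_unload on the mixed model reproduces the mixed output (within a slightly looser
        # tolerance, since merging changes the numerics).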
|
|
|
|
|
atol = 1e-4 |
|
rtol = 1e-4 |
|
seed0 = 0 |
|
seed1 = 1 |
|
|
|
|
|
        # adapter 0 + 1
        peft_model_01 = self._get_model(model_cls, config0, "adapter0", seed=seed0)
|
torch.manual_seed(seed1) |
|
peft_model_01.add_adapter("adapter1", config1) |
|
peft_model_01.set_adapter(["adapter0", "adapter1"]) |
|
output_mixed_01 = peft_model_01(input) |
|
|
|
        # merging the adapters into the base weights should reproduce the mixed output
        model_merged_01 = peft_model_01.merge_and_unload()
|
output_merged_01 = model_merged_01(input) |
|
assert torch.allclose(output_mixed_01, output_merged_01, atol=atol, rtol=rtol) |
|
|
|
|
|
        # adapter 1 + 0
        peft_model_10 = self._get_model(model_cls, config1, "adapter1", seed=seed1)
|
torch.manual_seed(seed0) |
|
peft_model_10.add_adapter("adapter0", config0) |
|
peft_model_10.set_adapter(["adapter1", "adapter0"]) |
|
output_mixed_10 = peft_model_10(input) |
|
|
|
model_merged_10 = peft_model_10.merge_and_unload() |
|
output_merged_10 = model_merged_10(input) |
|
assert torch.allclose(output_mixed_10, output_merged_10, atol=atol, rtol=rtol) |
|
|
|
def _check_unload(self, model_cls, config0, config1, input): |
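        # Check that unloading (without merging) removes the adapters and restores the base model output.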
|
|
|
atol = 1e-5 |
|
rtol = 1e-5 |
|
seed0 = 0 |
|
seed1 = 1 |
|
|
|
base_model = self._get_model(model_cls) |
|
output_base = base_model(input) |
|
|
|
|
|
        # adapter 0 + 1
        peft_model_01 = self._get_model(model_cls, config0, "adapter0", seed=seed0)
|
torch.manual_seed(seed1) |
|
peft_model_01.add_adapter("adapter1", config1) |
|
peft_model_01.set_adapter(["adapter0", "adapter1"]) |
|
output_mixed = peft_model_01(input) |
|
|
|
|
|
        # unloading should remove the adapters and restore the base model output
        model_unloaded = peft_model_01.unload()
|
output_unloaded = model_unloaded(input) |
|
|
|
assert not torch.allclose(output_mixed, output_unloaded, atol=atol, rtol=rtol) |
|
assert torch.allclose(output_base, output_unloaded, atol=atol, rtol=rtol) |
|
|
|
def _check_disable(self, model_cls, config0, config1, input): |
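        # Check that the disable_adapter context manager restores the base model output, both for single adapters
        # and for mixed adapters in either order.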
|
|
|
atol = 1e-5 |
|
rtol = 1e-5 |
|
seed0 = 0 |
|
seed1 = 1 |
|
|
|
|
|
        # base model without any adapter
        base_model = self._get_model(model_cls)
|
output_base = base_model(input) |
|
|
|
|
|
        # adapter 0 alone
        peft_model_0 = self._get_model(model_cls, config0, "adapter0", seed=seed0)
|
output_config0 = peft_model_0(input) |
|
with peft_model_0.disable_adapter(): |
|
output_disabled0 = peft_model_0(input) |
|
|
|
assert not torch.allclose(output_base, output_config0, atol=atol, rtol=rtol) |
|
assert torch.allclose(output_base, output_disabled0, atol=atol, rtol=rtol) |
|
|
|
|
|
        # adapter 1 alone
        peft_model_1 = self._get_model(model_cls, config1, "adapter1", seed=seed1)
|
output_config1 = peft_model_1(input) |
|
with peft_model_1.disable_adapter(): |
|
output_disabled1 = peft_model_1(input) |
|
|
|
assert not torch.allclose(output_base, output_config1, atol=atol, rtol=rtol) |
|
assert torch.allclose(output_base, output_disabled1, atol=atol, rtol=rtol) |
|
|
|
|
|
        # adapter 0 + 1
        peft_model_01 = self._get_model(model_cls, config0, "adapter0", seed=seed0)
|
torch.manual_seed(seed1) |
|
peft_model_01.add_adapter("adapter1", config1) |
|
peft_model_01.set_adapter(["adapter0", "adapter1"]) |
|
output_mixed_01 = peft_model_01(input) |
|
with peft_model_01.disable_adapter(): |
|
output_disabled01 = peft_model_01(input) |
|
|
|
assert not torch.allclose(output_base, output_mixed_01, atol=atol, rtol=rtol) |
|
assert torch.allclose(output_base, output_disabled01, atol=atol, rtol=rtol) |
|
|
|
|
|
        # adapter 1 + 0
        peft_model_10 = self._get_model(model_cls, config1, "adapter1", seed=seed1)
|
torch.manual_seed(seed0) |
|
peft_model_10.add_adapter("adapter0", config0) |
|
peft_model_10.set_adapter(["adapter1", "adapter0"]) |
|
output_mixed_10 = peft_model_10(input) |
|
with peft_model_10.disable_adapter(): |
|
output_disabled10 = peft_model_10(input) |
|
|
|
assert not torch.allclose(output_base, output_mixed_10, atol=atol, rtol=rtol) |
|
assert torch.allclose(output_base, output_disabled10, atol=atol, rtol=rtol) |
|
|
|
def _check_loading(self, model_cls, config0, config1, input, *, is_commutative): |
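        # Check that adapters saved from regular (non-mixed) PeftModels can be loaded into a PeftMixedModel and
        # reproduce the single-adapter and mixed outputs.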
|
|
|
|
|
atol = 1e-5 |
|
rtol = 1e-5 |
|
seed0 = 0 |
|
seed1 = 1 |
|
|
|
with tempfile.TemporaryDirectory() as tmp_dirname: |
|
|
|
|
|
            # save adapter 0 (from a regular, non-mixed PeftModel)
            peft_model_0 = self._get_model(model_cls, config0, "adapter0", seed=seed0, mixed=False)
|
output_config0 = peft_model_0(input) |
|
peft_model_0.save_pretrained(os.path.join(tmp_dirname, "adapter0")) |
|
|
|
|
|
            # save adapter 1
            peft_model_1 = self._get_model(model_cls, config1, "adapter1", seed=seed1, mixed=False)
|
output_config1 = peft_model_1(input) |
|
peft_model_1.save_pretrained(os.path.join(tmp_dirname, "adapter1")) |
|
|
|
|
|
            # reference output for mixed adapter 0 + 1
            peft_model_01 = self._get_model(model_cls, config0, "adapter0", seed=seed0)
|
torch.manual_seed(seed1) |
|
peft_model_01.add_adapter("adapter1", config1) |
|
peft_model_01.set_adapter(["adapter0", "adapter1"]) |
|
output_mixed_01 = peft_model_01(input) |
|
|
|
|
|
            # reference output for mixed adapter 1 + 0
            peft_model_10 = self._get_model(model_cls, config1, "adapter1", seed=seed1)
|
torch.manual_seed(seed0) |
|
peft_model_10.add_adapter("adapter0", config0) |
|
peft_model_10.set_adapter(["adapter1", "adapter0"]) |
|
output_mixed_10 = peft_model_10(input) |
|
|
|
|
|
|
|
            # load adapter 0 only
            base_model = self._get_model(model_cls)
|
|
|
|
|
|
|
|
|
            # sanity check: this seed should not affect the result, since the adapter weights are loaded from disk
            torch.manual_seed(123456)
|
peft_model_loaded0 = PeftMixedModel.from_pretrained( |
|
base_model, os.path.join(tmp_dirname, "adapter0", "adapter0"), "adapter0" |
|
) |
|
output_loaded0 = peft_model_loaded0(input) |
|
assert torch.allclose(output_config0, output_loaded0, atol=atol, rtol=rtol) |
|
|
|
|
|
            # load adapter 1 only
            base_model = self._get_model(model_cls)
|
torch.manual_seed(654321) |
|
peft_model_loaded1 = PeftMixedModel.from_pretrained( |
|
base_model, os.path.join(tmp_dirname, "adapter1", "adapter1"), "adapter1" |
|
) |
|
output_loaded1 = peft_model_loaded1(input) |
|
assert torch.allclose(output_config1, output_loaded1, atol=atol, rtol=rtol) |
|
|
|
|
|
            # load adapter 0, then adapter 1
            base_model = self._get_model(model_cls)
|
torch.manual_seed(97531) |
|
peft_model_loaded_01 = PeftMixedModel.from_pretrained( |
|
base_model, os.path.join(tmp_dirname, "adapter0", "adapter0"), "adapter0" |
|
) |
|
peft_model_loaded_01.load_adapter(os.path.join(tmp_dirname, "adapter1", "adapter1"), "adapter1") |
|
|
|
assert peft_model_loaded_01.active_adapters == ["adapter0"] |
|
output_loaded01_0 = peft_model_loaded_01(input) |
|
assert torch.allclose(output_config0, output_loaded01_0, atol=atol, rtol=rtol) |
|
|
|
peft_model_loaded_01.set_adapter(["adapter1"]) |
|
assert peft_model_loaded_01.active_adapters == ["adapter1"] |
|
output_loaded01_1 = peft_model_loaded_01(input) |
|
assert torch.allclose(output_config1, output_loaded01_1, atol=atol, rtol=rtol) |
|
|
|
peft_model_loaded_01.set_adapter(["adapter0", "adapter1"]) |
|
output_loaded01 = peft_model_loaded_01(input) |
|
assert torch.allclose(output_mixed_01, output_loaded01, atol=atol, rtol=rtol) |
|
|
|
|
|
            # load adapter 1, then adapter 0
            base_model = self._get_model(model_cls)
|
torch.manual_seed(445566) |
|
peft_model_loaded_10 = PeftMixedModel.from_pretrained( |
|
base_model, os.path.join(tmp_dirname, "adapter1", "adapter1"), "adapter1" |
|
) |
|
peft_model_loaded_10.load_adapter(os.path.join(tmp_dirname, "adapter0", "adapter0"), "adapter0") |
|
|
|
assert peft_model_loaded_10.active_adapters == ["adapter1"] |
|
output_loaded10_1 = peft_model_loaded_10(input) |
|
assert torch.allclose(output_config1, output_loaded10_1, atol=atol, rtol=rtol) |
|
|
|
peft_model_loaded_10.set_adapter(["adapter0"]) |
|
assert peft_model_loaded_10.active_adapters == ["adapter0"] |
|
output_loaded10_0 = peft_model_loaded_10(input) |
|
assert torch.allclose(output_config0, output_loaded10_0, atol=atol, rtol=rtol) |
|
|
|
peft_model_loaded_10.set_adapter(["adapter1", "adapter0"]) |
|
output_loaded10 = peft_model_loaded_10(input) |
|
assert torch.allclose(output_mixed_10, output_loaded10, atol=atol, rtol=rtol) |
|
|
|
if is_commutative: |
|
assert torch.allclose(output_loaded01, output_loaded10, atol=atol, rtol=rtol) |
|
assert torch.allclose(output_loaded10, output_mixed_01, atol=atol, rtol=rtol) |
|
|
|
@parameterized.expand( |
|
itertools.combinations( |
|
[ |
|
LoraConfig(target_modules=["lin0"], init_lora_weights=False), |
|
LoHaConfig(target_modules=["lin0"], init_weights=False), |
|
LoKrConfig(target_modules=["lin0"], init_weights=False), |
|
AdaLoraConfig(target_modules=["lin0"], init_lora_weights=False), |
|
OFTConfig(target_modules=["lin0"], init_weights=False), |
|
], |
|
r=2, |
|
), |
|
name_func=_param_name_func, |
|
) |
|
def test_target_first_layer(self, config0, config1): |
|
input = torch.arange(90).reshape(9, 10).to(self.torch_device) |
|
self._check_mixed_outputs(SimpleNet, config0, config1, input, is_commutative=False) |
|
self._check_merging(SimpleNet, config0, config1, input) |
|
self._check_unload(SimpleNet, config0, config1, input) |
|
self._check_disable(SimpleNet, config1, config0, input) |
|
self._check_loading(SimpleNet, config0, config1, input, is_commutative=False) |
|
|
|
@parameterized.expand( |
|
itertools.combinations( |
|
[ |
|
LoraConfig(target_modules=["lin1"], init_lora_weights=False), |
|
LoHaConfig(target_modules=["lin1"], init_weights=False), |
|
LoKrConfig(target_modules=["lin1"], init_weights=False), |
|
AdaLoraConfig(target_modules=["lin1"], init_lora_weights=False), |
|
OFTConfig(target_modules=["lin1"], init_weights=False), |
|
], |
|
r=2, |
|
), |
|
name_func=_param_name_func, |
|
) |
|
def test_target_last_layer(self, config0, config1): |
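        # Because the adapters target the last layer of SimpleNet, with no nonlinearity afterwards, their
        # contributions simply add to the base output, so the combination is expected to be commutative --
        # except for OFT, whose transform is multiplicative rather than additive.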
|
|
|
|
|
|
|
input = torch.arange(90).reshape(9, 10).to(self.torch_device) |
|
|
|
is_commutative = not any(isinstance(config, OFTConfig) for config in [config0, config1]) |
|
|
|
self._check_mixed_outputs(SimpleNet, config0, config1, input, is_commutative=is_commutative) |
|
self._check_merging(SimpleNet, config0, config1, input) |
|
self._check_unload(SimpleNet, config0, config1, input) |
|
self._check_disable(SimpleNet, config1, config0, input) |
|
self._check_loading(SimpleNet, config0, config1, input, is_commutative=is_commutative) |
|
|
|
@parameterized.expand( |
|
itertools.combinations( |
|
[ |
|
LoraConfig(init_lora_weights=False), |
|
LoHaConfig(init_weights=False), |
|
LoKrConfig(init_weights=False), |
|
AdaLoraConfig(init_lora_weights=False), |
|
OFTConfig(init_weights=False), |
|
], |
|
r=2, |
|
), |
|
name_func=_param_name_func, |
|
) |
|
def test_target_different_layers(self, config0, config1): |
|
input = torch.arange(90).reshape(9, 10).to(self.torch_device) |
|
|
|
        # adapter 0 targets the first layer, adapter 1 the last layer
        config0.target_modules = ["lin0"]
|
config1.target_modules = ["lin1"] |
|
self._check_mixed_outputs(SimpleNet, config0, config1, input, is_commutative=False) |
|
self._check_merging(SimpleNet, config0, config1, input) |
|
self._check_unload(SimpleNet, config0, config1, input) |
|
self._check_disable(SimpleNet, config0, config1, input) |
|
self._check_loading(SimpleNet, config0, config1, input, is_commutative=False) |
|
|
|
|
|
        # swap the targeted layers
        config0.target_modules = ["lin1"]
|
config1.target_modules = ["lin0"] |
|
self._check_mixed_outputs(SimpleNet, config1, config0, input, is_commutative=False) |
|
self._check_merging(SimpleNet, config1, config0, input) |
|
self._check_unload(SimpleNet, config1, config0, input) |
|
self._check_disable(SimpleNet, config1, config0, input) |
|
self._check_loading(SimpleNet, config1, config0, input, is_commutative=False) |
|
|
|
@parameterized.expand( |
|
[ |
|
( |
|
LoraConfig(target_modules=["lin1"], init_lora_weights=False), |
|
LoraConfig(target_modules=["lin1"], init_lora_weights=False), |
|
), |
|
( |
|
LoHaConfig(target_modules=["lin1"], init_weights=False), |
|
LoHaConfig(target_modules=["lin1"], init_weights=False), |
|
), |
|
( |
|
LoKrConfig(target_modules=["lin1"], init_weights=False), |
|
LoKrConfig(target_modules=["lin1"], init_weights=False), |
|
), |
|
( |
|
AdaLoraConfig(target_modules=["lin1"], init_lora_weights=False), |
|
AdaLoraConfig(target_modules=["lin1"], init_lora_weights=False), |
|
), |
|
( |
|
OFTConfig(target_modules=["lin1"], init_weights=False), |
|
OFTConfig(target_modules=["lin1"], init_weights=False), |
|
), |
|
], |
|
name_func=_param_name_func, |
|
) |
|
def test_target_last_layer_same_type(self, config0, config1): |
|
input = torch.arange(90).reshape(9, 10).to(self.torch_device) |
|
|
|
is_commutative = not any(isinstance(config, OFTConfig) for config in [config0, config1]) |
|
|
|
self._check_mixed_outputs(SimpleNet, config0, config1, input, is_commutative=is_commutative) |
|
self._check_merging(SimpleNet, config0, config1, input) |
|
self._check_unload(SimpleNet, config0, config1, input) |
|
self._check_disable(SimpleNet, config1, config0, input) |
|
|
|
@parameterized.expand( |
|
[ |
|
( |
|
LoraConfig(target_modules=["lin0"], init_lora_weights=False), |
|
LoraConfig(target_modules=["lin0"], init_lora_weights=False), |
|
), |
|
( |
|
LoHaConfig(target_modules=["lin0"], init_weights=False), |
|
LoHaConfig(target_modules=["lin0"], init_weights=False), |
|
), |
|
( |
|
LoKrConfig(target_modules=["lin0"], init_weights=False), |
|
LoKrConfig(target_modules=["lin0"], init_weights=False), |
|
), |
|
( |
|
AdaLoraConfig(target_modules=["lin0"], init_lora_weights=False), |
|
AdaLoraConfig(target_modules=["lin0"], init_lora_weights=False), |
|
), |
|
( |
|
OFTConfig(target_modules=["lin0"], init_weights=False), |
|
OFTConfig(target_modules=["lin0"], init_weights=False), |
|
), |
|
], |
|
name_func=_param_name_func, |
|
) |
|
def test_target_first_layer_same_type(self, config0, config1): |
|
input = torch.arange(90).reshape(9, 10).to(self.torch_device) |
|
self._check_mixed_outputs(SimpleNet, config0, config1, input, is_commutative=False) |
|
self._check_merging(SimpleNet, config0, config1, input) |
|
self._check_unload(SimpleNet, config0, config1, input) |
|
self._check_disable(SimpleNet, config1, config0, input) |
|
self._check_loading(SimpleNet, config0, config1, input, is_commutative=False) |
|
|
|
def test_deeply_nested(self): |
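        # Stack five different adapter types (LoRA, LoHa, AdaLoRA, LoKr, OFT) on the same small model and check
        # mixed inference, disabling, merging (full and partial), and unloading.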
|
|
|
atol = 1e-5 |
|
rtol = 1e-5 |
|
torch.manual_seed(0) |
|
|
|
model = SimpleNet().eval().to(self.torch_device) |
|
input = torch.arange(90).reshape(9, 10).to(self.torch_device) |
|
output_base = model(input) |
|
|
|
config0 = LoraConfig(r=4, lora_alpha=4, target_modules=["lin0", "lin1"], init_lora_weights=False) |
|
peft_model = get_peft_model(model, config0, "adapter0", mixed=True) |
|
|
|
config1 = LoHaConfig(r=4, alpha=4, target_modules=["lin0"], init_weights=False) |
|
peft_model.add_adapter("adapter1", config1) |
|
|
|
config2 = AdaLoraConfig(r=4, lora_alpha=4, target_modules=["lin1"], init_lora_weights=False) |
|
peft_model.add_adapter("adapter2", config2) |
|
|
|
config3 = LoKrConfig(r=4, alpha=4, target_modules=["lin0", "lin1"], init_weights=False) |
|
peft_model.add_adapter("adapter3", config3) |
|
|
|
config4 = OFTConfig(r=8, target_modules=["lin0", "lin1"], init_weights=False) |
|
peft_model.add_adapter("adapter4", config4) |
|
|
|
peft_model.set_adapter(["adapter0", "adapter1", "adapter2", "adapter3", "adapter4"]) |
|
output_mixed = peft_model(input) |
|
assert torch.isfinite(output_base).all() |
|
assert not torch.allclose(output_base, output_mixed, atol=atol, rtol=rtol) |
|
|
|
|
|
        # disabling all adapters should restore the base output
        with peft_model.disable_adapter():
|
output_disabled = peft_model(input) |
|
assert torch.isfinite(output_disabled).all() |
|
assert torch.allclose(output_base, output_disabled, atol=atol, rtol=rtol) |
|
assert not torch.allclose(output_mixed, output_disabled, atol=atol, rtol=rtol) |
|
|
|
|
|
        # merging all adapters into the base weights should reproduce the mixed output
        model_copy = copy.deepcopy(peft_model)
|
model = model_copy.merge_and_unload() |
|
output_merged = model(input) |
|
assert torch.isfinite(output_merged).all() |
|
assert torch.allclose(output_mixed, output_merged, atol=atol, rtol=rtol) |
|
|
|
|
|
        # merging and unloading only adapter1 and adapter3 should reproduce the output with just those two active
        model_copy = copy.deepcopy(peft_model)
|
model_copy.set_adapter(["adapter1", "adapter3"]) |
|
output_13 = model_copy(input) |
|
assert torch.isfinite(output_13).all() |
|
assert not torch.allclose(output_mixed, output_13, atol=atol, rtol=rtol) |
|
|
|
model_copy.set_adapter(["adapter0", "adapter1", "adapter2", "adapter3", "adapter4"]) |
|
model_merged_unloaded = model_copy.merge_and_unload(adapter_names=["adapter1", "adapter3"]) |
|
output_merged_13 = model_merged_unloaded(input) |
|
assert torch.isfinite(output_merged_13).all() |
|
assert torch.allclose(output_13, output_merged_13, atol=atol, rtol=rtol) |
|
|
|
|
|
        # unloading without merging should restore the base model output
        model_copy = copy.deepcopy(peft_model)
|
model_unloaded = model_copy.unload() |
|
output_unloaded = model_unloaded(input) |
|
assert torch.isfinite(output_unloaded).all() |
|
assert torch.allclose(output_base, output_unloaded, atol=atol, rtol=rtol) |
|
|
|
def test_delete_adapter(self): |
|
atol = 1e-5 |
|
rtol = 1e-5 |
|
torch.manual_seed(0) |
|
|
|
model = SimpleNet().eval().to(self.torch_device) |
|
input = torch.arange(90).reshape(9, 10).to(self.torch_device) |
|
output_base = model(input) |
|
|
|
|
|
        # create adapter 0
        torch.manual_seed(0)
|
config0 = LoraConfig(r=4, lora_alpha=4, target_modules=["lin0", "lin1"], init_lora_weights=False) |
|
peft_model = get_peft_model(model, config0, "adapter0", mixed=True) |
|
output_0 = peft_model(input) |
|
assert not torch.allclose(output_base, output_0, atol=atol, rtol=rtol) |
|
|
|
|
|
        # add adapter 1 on top
        torch.manual_seed(1)
|
config1 = LoHaConfig(r=4, alpha=4, target_modules=["lin0"], init_weights=False) |
|
peft_model.add_adapter("adapter1", config1) |
|
peft_model.set_adapter(["adapter0", "adapter1"]) |
|
output_01 = peft_model(input) |
|
assert not torch.allclose(output_base, output_01, atol=atol, rtol=rtol) |
|
assert not torch.allclose(output_0, output_01, atol=atol, rtol=rtol) |
|
|
|
|
|
        # delete adapter 1; the output should match adapter 0 alone again
        peft_model.delete_adapter("adapter1")
|
assert peft_model.active_adapters == ["adapter0"] |
|
output_deleted_1 = peft_model(input) |
|
assert torch.allclose(output_0, output_deleted_1, atol=atol, rtol=rtol) |
|
|
|
msg = re.escape("Adapter(s) ['adapter1'] not found, available adapters: ['adapter0']") |
|
with pytest.raises(ValueError, match=msg): |
|
peft_model.set_adapter(["adapter0", "adapter1"]) |
|
|
|
|
|
        # re-adding adapter 1 with the same seed should work
        torch.manual_seed(1)
|
peft_model.add_adapter("adapter1", config1) |
|
peft_model.set_adapter(["adapter0", "adapter1"]) |
|
output_01_readded = peft_model(input) |
|
assert not torch.allclose(output_base, output_01_readded, atol=atol, rtol=rtol) |
|
|
|
|
|
        # start over, this time deleting adapter 0 first
        torch.manual_seed(0)
|
model = SimpleNet().eval().to(self.torch_device) |
|
torch.manual_seed(0) |
|
peft_model = get_peft_model(model, config0, "adapter0", mixed=True) |
|
torch.manual_seed(1) |
|
peft_model.add_adapter("adapter1", config1) |
|
peft_model.delete_adapter("adapter0") |
|
assert peft_model.active_adapters == ["adapter1"] |
|
output_deleted_0 = peft_model(input) |
|
assert not torch.allclose(output_deleted_0, output_base, atol=atol, rtol=rtol) |
|
assert not torch.allclose(output_deleted_0, output_01, atol=atol, rtol=rtol) |
|
|
|
msg = re.escape("Adapter(s) ['adapter0'] not found, available adapters: ['adapter1']") |
|
with pytest.raises(ValueError, match=msg): |
|
peft_model.set_adapter(["adapter0", "adapter1"]) |
|
|
|
peft_model.delete_adapter("adapter1") |
|
assert peft_model.active_adapters == [] |
|
output_deleted_01 = peft_model(input) |
|
assert torch.allclose(output_deleted_01, output_base, atol=atol, rtol=rtol) |
|
|
|
def test_modules_to_save(self): |
|
model = SimpleNet().eval().to(self.torch_device) |
|
config0 = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"]) |
|
peft_model = get_peft_model(model, config0, "adapter0", mixed=True) |
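        # activating more than one adapter at a time is not supported for adapters that use modules_to_save,
        # so setting both adapters below is expected to raise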
|
|
|
|
|
|
|
config1 = LoHaConfig(target_modules=["lin0"], modules_to_save=["lin1"]) |
|
peft_model.add_adapter("adapter1", config1) |
|
with pytest.raises(ValueError, match="Only one adapter can be set at a time for modules_to_save"): |
|
peft_model.set_adapter(["adapter0", "adapter1"]) |
|
|
|
def test_get_nb_trainable_parameters(self): |
|
model = SimpleNet().eval().to(self.torch_device) |
|
params_base = sum(p.numel() for p in model.parameters()) |
|
|
|
config0 = LoraConfig(target_modules=["lin0"]) |
|
peft_model = get_peft_model(model, config0, "adapter0", mixed=True) |
|
trainable_params0, all_param0 = peft_model.get_nb_trainable_parameters() |
|
|
|
params_lora = sum(p.numel() for n, p in model.named_parameters() if "adapter0" in n) |
|
assert trainable_params0 == params_lora |
|
assert all_param0 == (params_base + params_lora) |
|
|
|
config1 = LoHaConfig(target_modules=["lin1"]) |
|
peft_model.add_adapter("adapter1", config1) |
|
peft_model.set_adapter(["adapter0", "adapter1"]) |
|
params_loha = sum(p.numel() for n, p in model.named_parameters() if "adapter1" in n) |
|
trainable_params1, all_param1 = peft_model.get_nb_trainable_parameters() |
|
assert trainable_params1 == (params_lora + params_loha) |
|
        assert all_param1 == (params_base + params_lora + params_loha)
|
|
|
config2 = AdaLoraConfig(target_modules=["lin0", "lin1"]) |
|
peft_model.add_adapter("adapter2", config2) |
|
peft_model.set_adapter(["adapter0", "adapter1", "adapter2"]) |
|
params_adalora = sum(p.numel() for n, p in model.named_parameters() if "adapter2" in n) |
|
trainable_params2, all_param2 = peft_model.get_nb_trainable_parameters() |
|
|
|
        # AdaLoRA adds one non-trainable `ranknum` parameter per targeted layer (lin0 and lin1 here), hence the -2
        assert trainable_params2 == (params_lora + params_loha + params_adalora - 2)

        assert all_param2 == (params_base + params_lora + params_loha + params_adalora)
|
|
|
def test_incompatible_config_raises(self): |
|
model = SimpleNet().eval().to(self.torch_device) |
|
config0 = LoraConfig(target_modules=["lin0"]) |
|
peft_model = get_peft_model(model, config0, "adapter0", mixed=True) |
|
|
|
config1 = PrefixTuningConfig() |
|
msg = "The provided `peft_type` 'PREFIX_TUNING' is not compatible with the `PeftMixedModel`." |
|
with pytest.raises(ValueError, match=msg): |
|
peft_model.add_adapter("adapter1", config1) |
|
|
|
def test_decoder_model(self): |
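        # Stack several adapter types on a tiny decoder-only LM and check that generate still works, that each
        # added adapter changes the output, and that disabling, merging, saving, and loading behave as expected.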
|
|
|
torch.manual_seed(0) |
|
|
|
model_id = "hf-internal-testing/tiny-random-OPTForCausalLM" |
|
model = AutoModelForCausalLM.from_pretrained(model_id).eval().to(self.torch_device) |
|
input_ids = torch.tensor([[1, 1, 1], [1, 2, 1]]).to(self.torch_device) |
|
attention_mask = torch.tensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device) |
|
input_dict = { |
|
"input_ids": input_ids, |
|
"attention_mask": attention_mask, |
|
} |
|
output_base = model.generate(**input_dict) |
|
|
|
torch.manual_seed(0) |
|
config0 = LoraConfig(task_type="CAUSAL_LM", init_lora_weights=False) |
|
peft_model = get_peft_model(model, config0, "adapter0", mixed=True) |
|
output0 = peft_model.generate(**input_dict) |
|
assert torch.isfinite(output0).all() |
|
assert not torch.allclose(output_base, output0) |
|
|
|
torch.manual_seed(1) |
|
config1 = LoHaConfig(task_type="CAUSAL_LM", target_modules=["q_proj", "v_proj"], init_weights=False) |
|
peft_model.add_adapter("adapter1", config1) |
|
peft_model.set_adapter(["adapter0", "adapter1"]) |
|
output1 = peft_model.generate(**input_dict) |
|
assert torch.isfinite(output1).all() |
|
assert not torch.allclose(output0, output1) |
|
|
|
torch.manual_seed(2) |
|
config2 = AdaLoraConfig(task_type="CAUSAL_LM", init_lora_weights=False) |
|
peft_model.add_adapter("adapter2", config2) |
|
peft_model.set_adapter(["adapter0", "adapter1", "adapter2"]) |
|
output2 = peft_model.generate(**input_dict) |
|
assert torch.isfinite(output2).all() |
|
assert not torch.allclose(output1, output2) |
|
|
|
torch.manual_seed(3) |
|
config3 = LoKrConfig(task_type="CAUSAL_LM", target_modules=["q_proj", "v_proj"], init_weights=False) |
|
peft_model.add_adapter("adapter3", config3) |
|
peft_model.set_adapter(["adapter0", "adapter1", "adapter2", "adapter3"]) |
|
output3 = peft_model.generate(**input_dict) |
|
assert torch.isfinite(output3).all() |
|
assert not torch.allclose(output2, output3) |
|
|
|
torch.manual_seed(4) |
|
config4 = OFTConfig(task_type="CAUSAL_LM", target_modules=["q_proj", "v_proj"], init_weights=False) |
|
peft_model.add_adapter("adapter4", config4) |
|
peft_model.set_adapter(["adapter0", "adapter1", "adapter2", "adapter3", "adapter4"]) |
|
output4 = peft_model.generate(**input_dict) |
|
assert torch.isfinite(output4).all() |
|
assert not torch.allclose(output3, output4) |
|
|
|
with peft_model.disable_adapter(): |
|
output_disabled = peft_model.generate(**input_dict) |
|
assert torch.isfinite(output_disabled).all() |
|
assert torch.allclose(output_base, output_disabled) |
|
|
|
model_unloaded = peft_model.merge_and_unload() |
|
output_unloaded = model_unloaded.generate(**input_dict) |
|
assert torch.isfinite(output_unloaded).all() |
|
assert torch.allclose(output4, output_unloaded) |
|
|
|
with tempfile.TemporaryDirectory() as tmp_dir: |
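            # save adapter 0 (from a regular, non-mixed PeftModel)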
|
|
|
torch.manual_seed(0) |
|
model = AutoModelForCausalLM.from_pretrained(model_id).eval().to(self.torch_device) |
|
torch.manual_seed(0) |
|
peft_model = get_peft_model(model, config0, "adapter0") |
|
output0_save = peft_model(**input_dict).logits |
|
assert torch.isfinite(output0_save).all() |
|
peft_model.save_pretrained(tmp_dir) |
|
|
|
|
|
            # save adapter 1
            torch.manual_seed(0)
|
model = AutoModelForCausalLM.from_pretrained(model_id).eval().to(self.torch_device) |
|
torch.manual_seed(1) |
|
peft_model = get_peft_model(model, config1, "adapter1") |
|
output1_save = peft_model(**input_dict).logits |
|
assert torch.isfinite(output1_save).all() |
|
peft_model.save_pretrained(tmp_dir) |
|
|
|
|
|
            # load both adapters into a fresh base model and activate them together
            model = AutoModelForCausalLM.from_pretrained(model_id).eval().to(self.torch_device)
|
peft_model = PeftMixedModel.from_pretrained(model, os.path.join(tmp_dir, "adapter0"), "adapter0") |
|
peft_model.load_adapter(os.path.join(tmp_dir, "adapter1"), "adapter1") |
|
peft_model.set_adapter(["adapter0", "adapter1"]) |
|
output01_loaded = peft_model(**input_dict).logits |
|
|
|
atol, rtol = 1e-3, 1e-3 |
|
assert torch.isfinite(output01_loaded).all() |
|
assert not torch.allclose(output0_save, output01_loaded, atol=atol, rtol=rtol) |
|
assert not torch.allclose(output1_save, output01_loaded, atol=atol, rtol=rtol) |
|
|